Coverage for src/cvx/linalg/covariance/ewm_cov.py: 100%

29 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-03 18:56 +0000

1"""Exponentially weighted covariance matrix computation. 

2 

3This module requires the optional ``polars`` dependency. Install it with 

4``pip install cvx-linalg[ewm]``. 

5""" 

6 

7from __future__ import annotations 

8 

9from collections.abc import Hashable 

10 

11import numpy as np 

12 

13from ..core.exceptions import NegativeWarmupError as NegativeWarmupError 

14from ..core.exceptions import NonIntegerWarmupError 

15from ..core.types import Matrix 

16 

17try: 

18 import polars as pl 

19except ImportError as exc: # pragma: no cover 

20 _msg = ( 

21 "polars is required for cvx.linalg.covariance.ewm_cov; " 

22 "install it with `pip install cvx-linalg[ewm]` or `pip install polars`." 

23 ) 

24 raise ImportError(_msg) from exc 

25 

26 

def ewm_covariance(
    data: pl.DataFrame,
    assets: list[str],
    index_col: str,
    window: int = 30,
    is_halflife: bool = False,
    warmup: int = 0,
) -> dict[Hashable, Matrix]:
    """Compute the exponentially weighted covariance matrix of returns.

    EWM covariance uses the identity
    ``Cov(X, Y) = EWM(X*Y) - EWM(X)*EWM(Y)`` applied to the
    *common non-null observations* of each pair, which is equivalent
    to ``pandas.DataFrame.ewm(span).cov(bias=True)``.

    Each date is included in the result as long as at least one
    matrix entry is non-NaN.  Cells involving a late-starting asset
    are ``NaN`` until that asset has enough observations; the date is
    never dropped on account of a single asset being unavailable.
    Dates where every cell is NaN (before the warmup period is met
    for any asset) are omitted.

    Only the columns named in *assets* and *index_col* are read; any
    other columns present in *data* are ignored.

    Args:
        data: Polars DataFrame containing the index column and asset columns.
        assets: Ordered list of asset column names.  An empty list
            yields an empty result.
        index_col: Name of the index (e.g. date) column in *data*.
        window: Span (default) or half-life (when *is_halflife* is
            ``True``) of the exponential decay.  Defaults to ``30``.
        is_halflife: When ``True`` *window* is interpreted as the
            half-life; otherwise it is the EWMA span.  Defaults to
            ``False``.
        warmup: Minimum number of common observations required before
            a pair's cell is non-NaN.  Defaults to ``0`` (cells are
            non-NaN from the first shared observation).

    Returns:
        Dictionary keyed by index value (date or integer) mapping to
        a square symmetric ``numpy.ndarray`` of shape ``(n, n)``
        where ``n`` is the number of assets.  Row/column order
        matches *assets*.  Unavailable cells are ``NaN``.

    Raises:
        NonIntegerWarmupError: If *warmup* is not an integer (booleans included).
        NegativeWarmupError: If *warmup* is negative.

    """
    # bool is a subclass of int, so it has to be rejected explicitly.
    if isinstance(warmup, bool) or not isinstance(warmup, int):
        raise NonIntegerWarmupError(warmup)
    if warmup < 0:
        raise NegativeWarmupError(warmup)

    n = len(assets)
    if n == 0:
        # Without assets there are no cells, so every date is vacuously
        # "all NaN" and therefore omitted from the result.
        return {}

    # A warmup of 0 is treated like 1: a cell is available from the
    # first shared observation.
    min_samples = 1 if warmup == 0 else warmup

    def _ewm(expr: pl.Expr) -> pl.Expr:
        """Apply EWM mean with the configured span or half-life."""
        if is_halflife:
            return expr.ewm_mean(half_life=window, min_samples=min_samples)
        return expr.ewm_mean(span=window, min_samples=min_samples)

    def _masked(name: str, other: str) -> pl.Expr:
        """Return column *name* nulled out wherever column *other* is null."""
        return pl.when(pl.col(other).is_null()).then(None).otherwise(pl.col(name))

    # Upper-triangle pairs in row-major order, i.e. the same order that
    # ``np.triu_indices(n)`` produces below.
    pairs = [(a, b) for i, a in enumerate(assets) for b in assets[i:]]

    # Aliases are derived from the pair position rather than the asset
    # names so they can never collide with each other (e.g. "a_b"+"c" vs
    # "a"+"b_c", or repeated assets).
    cov_exprs = [
        (
            _ewm(pl.col(a) * pl.col(b))
            - _ewm(_masked(a, b)) * _ewm(_masked(b, a))
        ).alias(f"__ewm_cov_{k}")
        for k, (a, b) in enumerate(pairs)
    ]

    # Select only the covariance expressions so that unrelated columns in
    # *data* cannot leak into the numeric array; the keys are read
    # straight from the index column.
    all_keys = data[index_col].to_list()
    pair_arr = data.select(cov_exprs).to_numpy()

    # Scatter the flat upper-triangle values into a (T, n, n) cube and
    # mirror them to the lower triangle to make each matrix symmetric.
    ii, jj = np.triu_indices(n)
    cube = np.full((len(all_keys), n, n), np.nan)
    cube[:, ii, jj] = pair_arr
    cube[:, jj, ii] = pair_arr

    # Keep only the dates that have at least one non-NaN cell.
    has_data = ~np.all(np.isnan(cube), axis=(1, 2))
    return {k: cube[t] for t, k in enumerate(all_keys) if has_data[t]}