Coverage for src / cvx / linalg / ewm_cov.py: 100%

26 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-19 05:40 +0000

1"""Exponentially weighted covariance matrix computation.""" 

2 

3from __future__ import annotations 

4 

5from collections.abc import Hashable 

6 

7import numpy as np 

8import polars as pl 

9 

10 

class NegativeWarmupError(ValueError):
    """Signal that the ``warmup`` argument was a negative integer.

    Subclasses :class:`ValueError`, so callers that already catch
    ``ValueError`` handle it without any change.
    """

13 

14 

def ewm_covariance(
    data: pl.DataFrame,
    assets: list[str],
    index_col: str,
    window: int = 30,
    is_halflife: bool = False,
    warmup: int = 0,
) -> dict[Hashable, np.ndarray]:
    """Compute the exponentially weighted covariance matrix of returns.

    EWM covariance uses the identity
    ``Cov(X, Y) = EWM(X*Y) - EWM(X)*EWM(Y)`` applied to the
    *common non-null observations* of each pair, which is equivalent
    to ``pandas.DataFrame.ewm(span).cov(bias=True)``.

    Each date is included in the result as long as at least one
    matrix entry is non-NaN.  Cells involving a late-starting asset
    are ``NaN`` until that asset has enough observations; the date is
    never dropped on account of a single asset being unavailable.
    Dates where every cell is NaN (before the warmup period is met
    for any asset) are omitted.

    Only the columns named in *assets* and *index_col* are read, so
    *data* may carry additional, unrelated columns.

    Args:
        data: Polars DataFrame containing the index column and asset columns.
        assets: Ordered list of asset column names.  An empty list
            yields an empty result.
        index_col: Name of the index (e.g. date) column in *data*.
        window: Span (default) or half-life (when *is_halflife* is
            ``True``) of the exponential decay.  Defaults to ``30``.
        is_halflife: When ``True`` *window* is interpreted as the
            half-life; otherwise it is the EWMA span.  Defaults to
            ``False``.
        warmup: Minimum number of common observations required before
            a pair's cell is non-NaN.  Defaults to ``0`` (cells are
            non-NaN from the first shared observation).

    Returns:
        Dictionary keyed by index value (date or integer) mapping to
        a square symmetric ``numpy.ndarray`` of shape ``(n, n)``
        where ``n`` is the number of assets.  Row/column order
        matches *assets*.  Unavailable cells are ``NaN``.

    Raises:
        TypeError: If *warmup* is not an ``int`` (``bool`` is rejected too).
        NegativeWarmupError: If *warmup* is negative.

    """
    # ``bool`` is a subclass of ``int``, so it has to be rejected explicitly.
    if isinstance(warmup, bool) or not isinstance(warmup, int):
        msg = f"warmup must be an int, got {type(warmup).__name__}"
        raise TypeError(msg)
    if warmup < 0:
        msg = f"warmup must be non-negative, got {warmup}"
        raise NegativeWarmupError(msg)

    n = len(assets)
    if n == 0:
        # No assets means no matrix cell can ever hold data.
        return {}

    # warmup == 0 still needs one observation for the mean to be defined.
    min_samples = 1 if warmup == 0 else warmup

    def _ewm(expr: pl.Expr) -> pl.Expr:
        """Apply EWM mean with the configured span or half-life."""
        if is_halflife:
            return expr.ewm_mean(half_life=window, min_samples=min_samples)
        return expr.ewm_mean(span=window, min_samples=min_samples)

    # Upper-triangle pairs in row-major order; this is the same order in
    # which ``np.triu_indices`` enumerates cells further down.
    pairs = [(a, b) for i, a in enumerate(assets) for b in assets[i:]]

    # Each side of the product of means is masked by the other asset's nulls
    # so that both means run over the pair's common observations only.
    # Aliases are positional so they cannot clash with each other, whatever
    # the asset names look like.
    cov_exprs = [
        (
            _ewm(pl.col(a) * pl.col(b))
            - _ewm(pl.when(pl.col(b).is_null()).then(None).otherwise(pl.col(a)))
            * _ewm(pl.when(pl.col(a).is_null()).then(None).otherwise(pl.col(b)))
        ).alias(f"cov_{k}")
        for k, (a, b) in enumerate(pairs)
    ]

    all_keys = data[index_col].to_list()
    # Select only the pair expressions: columns of *data* that are neither
    # assets nor the index never reach the array, and a pair alias cannot
    # overwrite an input column.
    pair_arr = data.select(cov_exprs).to_numpy()

    # Scatter the flat pair columns into both triangles of a (T, n, n) cube.
    ii, jj = np.triu_indices(n)
    cube = np.full((len(all_keys), n, n), np.nan)
    cube[:, ii, jj] = pair_arr
    cube[:, jj, ii] = pair_arr

    # Keep only the index values for which at least one cell is available.
    has_data = ~np.all(np.isnan(cube), axis=(1, 2))
    return {k: cube[t] for t, k in enumerate(all_keys) if has_data[t]}