Coverage for src/jquantstats/_utils/_data.py: 100%

99 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-23 06:13 +0000

1"""Utility methods for Data objects — the jquantstats equivalent of qs.utils.""" 

2 

3from __future__ import annotations 

4 

5import math 

6from collections.abc import Callable, Hashable 

7 

8import numpy as np 

9import polars as pl 

10 

11from jquantstats._protocol import DataLike 

12 

13from ..exceptions import MissingDateColumnError 

14 

15__all__ = ["DataUtils"] 

16 

# Human-readable period names accepted by ``DataUtils.group_returns``, each
# mapped to the interval string Polars expects for ``group_by_dynamic(every=...)``.
_PERIOD_ALIASES: dict[str, str] = {
    # alias         Polars interval
    "daily": "1d",
    "weekly": "1w",
    "monthly": "1mo",
    "quarterly": "1q",
    # "annual" and "yearly" are synonyms for one calendar year.
    "annual": "1y",
    "yearly": "1y",
}

26 

27 

class DataUtils:
    """Utility transforms and conversions for financial returns data.

    Mirrors the public API of ``quantstats.utils`` but operates on Polars
    DataFrames and integrates with `Data` via the ``data.utils`` property.

    All transforms act on the asset columns of ``returns`` only; benchmark
    data is never included.
    """

    __slots__ = ("_data",)

    def __init__(self, data: DataLike) -> None:
        """Store the wrapped data object.

        Args:
            data: Object exposing ``index``, ``returns`` and ``date_col``.
                It is stored by reference; nothing is copied here.
        """
        self._data = data

    def __repr__(self) -> str:
        """Return a string representation of the DataUtils object."""
        return f"DataUtils(assets={self._asset_cols()})"

    # ── helpers ───────────────────────────────────────────────────────────────

    def _combined(self) -> pl.DataFrame:
        """Return index hstacked with returns (no benchmark)."""
        return pl.concat([self._data.index, self._data.returns], how="horizontal")

    def _asset_cols(self) -> list[str]:
        """Return the asset column names from returns (excluding benchmark)."""
        return list(self._data.returns.columns)

    def _require_temporal_index(self, method: str) -> str:
        """Return the date column name, raising if the index is not temporal.

        Args:
            method: Name of the calling public method, forwarded to the error.

        Returns:
            The first entry of ``data.date_col``.

        Raises:
            MissingDateColumnError: If no date column is configured, or the
                first one does not have a temporal dtype.
        """
        date_cols = self._data.date_col
        if not date_cols:
            raise MissingDateColumnError(method)  # pragma: no cover
        date_col = date_cols[0]
        if not self._data.index[date_col].dtype.is_temporal():
            raise MissingDateColumnError(method)
        return date_col

    @staticmethod
    def _ewm_decay(window: int, is_halflife: bool) -> tuple[int | None, int | None]:
        """Return ``(span, half_life)`` with exactly one of them set to *window*.

        Polars' EWM methods accept both keywords and require exactly one
        decay parameter to be non-``None``.
        """
        return (None, window) if is_halflife else (window, None)

    # ── public API ────────────────────────────────────────────────────────────

    def to_prices(self, base: float = 1e5) -> pl.DataFrame:
        """Convert returns to a cumulative price series.

        Computes ``base * prod(1 + r_t)`` for each asset column, matching the
        behaviour of ``quantstats.utils.to_prices``. Missing returns are
        treated as ``0`` (price unchanged for that period).

        Args:
            base: Starting value for the price series. Defaults to ``1e5``.

        Returns:
            DataFrame with the same date column (if present) and one price
            column per asset.

        """
        asset_cols = self._asset_cols()
        return self._combined().with_columns(
            [(pl.col(c).fill_null(0.0) + 1.0).cum_prod().mul(base).alias(c) for c in asset_cols]
        )

    def to_log_returns(self) -> pl.DataFrame:
        """Convert simple returns to log returns: ``ln(1 + r)``.

        Matches ``quantstats.utils.to_log_returns``. Missing returns are
        treated as ``0`` and therefore map to a log return of ``0``.

        Returns:
            DataFrame with the same columns as the input returns, values
            replaced by their log-return equivalents.

        """
        asset_cols = self._asset_cols()
        return self._combined().with_columns(
            [(pl.col(c).fill_null(0.0) + 1.0).log(base=math.e).alias(c) for c in asset_cols]
        )

    def to_volatility_adjusted_returns(
        self,
        window: int = 60,
        vol_estimator: Callable[[pl.Expr], pl.Expr] | None = None,
    ) -> pl.DataFrame:
        """Convert simple returns to volatility adjusted returns.

        Divides each return by a lagged volatility estimate to avoid
        look-ahead bias: ``vol_adjusted_r_t = r_t / vol(r_{t-1})``.

        By default the volatility estimate is
        ``pl.Expr.rolling_std(window)``. Pass *vol_estimator* to
        override with any function that maps a ``pl.Expr`` to a
        ``pl.Expr`` (e.g. an EWMA standard deviation).

        Matches ``quantstats.utils.to_volatility_adjusted_returns``.

        Args:
            window: Rolling lookback for the default volatility
                estimator. Ignored when *vol_estimator* is provided.
                Defaults to ``60``.
            vol_estimator: A callable ``(pl.Expr) -> pl.Expr`` that
                produces a volatility series from a returns expression.
                Defaults to ``None`` (uses ``rolling_std(window)``).

        Returns:
            DataFrame with the same columns as the input returns, values
            replaced by their volatility adjusted equivalents.

        """
        if vol_estimator is None:

            def _rolling_std(expr: pl.Expr) -> pl.Expr:
                """Return rolling standard deviation over *window*."""
                return expr.rolling_std(window)

            estimator: Callable[[pl.Expr], pl.Expr] = _rolling_std
        else:
            estimator = vol_estimator

        asset_cols = self._asset_cols()
        # shift(1): r_t is scaled by volatility known up to t-1 only.
        return self._combined().with_columns(
            [(pl.col(c) / estimator(pl.col(c)).shift(1)).alias(c) for c in asset_cols]
        )

    def log_returns(self) -> pl.DataFrame:
        """Alias for `to_log_returns`.

        Matches ``quantstats.utils.log_returns``.

        Returns:
            DataFrame of log returns.

        """
        return self.to_log_returns()

    def rebase(self, base: float = 100.0) -> pl.DataFrame:
        """Normalise the returns as a price series that starts at *base*.

        Converts returns to prices via `to_prices` and then rescales
        each column so its first observation equals *base* exactly, matching
        the behaviour of ``quantstats.utils.rebase``.

        Args:
            base: Target starting value. Defaults to ``100.0``.

        Returns:
            DataFrame with price columns anchored to *base* at t = 0.

        """
        prices_df = self.to_prices(base=1.0)
        asset_cols = self._asset_cols()
        return prices_df.with_columns([(pl.col(c) / pl.col(c).first() * base).alias(c) for c in asset_cols])

    def winsorise(self, window: int = 7, n_sigma: float = 3.0) -> pl.DataFrame:
        """Winsorise returns by clipping to within *n_sigma* rolling standard deviations.

        For each asset column, values outside
        ``rolling_mean ± n_sigma * rolling_std`` are clipped to the
        respective bound. The rolling statistics are computed over *window*
        rows and shifted by one row, so each value is compared against
        bounds built from strictly earlier observations.

        Args:
            window: Rolling lookback for mean and standard deviation.
                Defaults to ``7``.
            n_sigma: Number of standard deviations for the clip bounds.
                Defaults to ``3.0``.

        Returns:
            DataFrame with the same columns as the input returns, extreme
            values clipped.

        """
        asset_cols = self._asset_cols()
        df = self._combined()
        exprs = []
        for c in asset_cols:
            col = pl.col(c)
            # Lag the statistics so the current value does not shape its own bounds.
            r_mean = col.rolling_mean(window).shift(1)
            r_std = col.rolling_std(window).shift(1)
            lower = r_mean - n_sigma * r_std
            upper = r_mean + n_sigma * r_std
            exprs.append(col.clip(lower_bound=lower, upper_bound=upper).alias(c))
        return df.with_columns(exprs)

    def group_returns(self, period: str = "1mo", compounded: bool = True) -> pl.DataFrame:
        """Aggregate returns by a calendar period.

        Requires a temporal (Date/Datetime) index; raises
        `MissingDateColumnError` for integer-indexed data.

        Human-readable aliases are accepted (case-insensitively) alongside
        native Polars interval strings (``"1mo"``, ``"1q"``, ``"1y"``,
        ``"1w"``, ``"1d"``):

        ``"daily"``, ``"weekly"``, ``"monthly"``, ``"quarterly"``,
        ``"annual"`` / ``"yearly"``.

        Args:
            period: Aggregation period. Defaults to ``"1mo"`` (monthly).
            compounded: When ``True`` (default) compound the returns
                ``prod(1 + r) - 1``; when ``False`` sum them.

        Returns:
            DataFrame with one row per period and one column per asset.

        Raises:
            MissingDateColumnError: If the index is not temporal.

        """
        date_col = self._require_temporal_index("group_returns")
        # Alias lookup ignores case/surrounding whitespace; any non-alias is
        # forwarded to Polars exactly as given.
        polars_period = _PERIOD_ALIASES.get(period.strip().lower(), period)
        asset_cols = self._asset_cols()

        if compounded:
            agg_exprs = [((pl.col(c).fill_null(0.0) + 1.0).product() - 1.0).alias(c) for c in asset_cols]
        else:
            agg_exprs = [pl.col(c).fill_null(0.0).sum().alias(c) for c in asset_cols]

        return (
            self._combined()
            .sort(date_col)
            .group_by_dynamic(date_col, every=polars_period)
            .agg(agg_exprs)
            .sort(date_col)
        )

    def aggregate_returns(self, period: str = "1mo", compounded: bool = True) -> pl.DataFrame:
        """Alias for `group_returns`.

        Matches ``quantstats.utils.aggregate_returns``.

        Args:
            period: Aggregation period. See `group_returns` for accepted values.
            compounded: Whether to compound returns. Defaults to ``True``.

        Returns:
            DataFrame with one row per period and one column per asset.

        """
        return self.group_returns(period=period, compounded=compounded)

    def to_excess_returns(self, rf: float = 0.0, nperiods: int | None = None) -> pl.DataFrame:
        """Subtract a risk-free rate from returns.

        When *nperiods* is supplied the annual *rf* is converted to a
        per-period rate via ``(1 + rf)^(1/nperiods) - 1``, matching
        ``quantstats.utils.to_excess_returns``.

        Args:
            rf: Annual risk-free rate as a decimal (e.g. ``0.05`` for 5 %).
                Defaults to ``0.0``.
            nperiods: Number of return periods per year used to convert *rf*
                to a per-period rate. When ``None`` *rf* is applied as-is.

        Returns:
            DataFrame of excess returns with the same columns as the input.

        """
        rf_per_period = ((1.0 + rf) ** (1.0 / nperiods) - 1.0) if nperiods is not None else rf
        asset_cols = self._asset_cols()
        return self._combined().with_columns([(pl.col(c) - rf_per_period).alias(c) for c in asset_cols])

    def exponential_stdev(self, window: int = 30, is_halflife: bool = False) -> pl.DataFrame:
        """Compute the exponentially weighted standard deviation of returns.

        Matches ``quantstats.utils.exponential_stdev``. Uses Polars
        ``ewm_std`` under the hood.

        Args:
            window: Span (default) or half-life (when *is_halflife* is
                ``True``) of the exponential decay. Defaults to ``30``.
            is_halflife: When ``True`` *window* is interpreted as the
                half-life; otherwise it is the EWMA span. Defaults to
                ``False``.

        Returns:
            DataFrame of rolling EWMA standard deviations with the same
            columns as the input returns.

        """
        span, half_life = self._ewm_decay(window, is_halflife)
        asset_cols = self._asset_cols()
        return self._combined().with_columns(
            [pl.col(c).ewm_std(span=span, half_life=half_life, min_samples=1).alias(c) for c in asset_cols]
        )

    def exponential_cov(
        self, window: int = 30, is_halflife: bool = False, warmup: int = 0
    ) -> dict[Hashable, np.ndarray]:
        """Compute the exponentially weighted covariance matrix of returns.

        EWM covariance uses the identity
        ``Cov(X, Y) = EWM(X*Y) - EWM(X)*EWM(Y)`` applied to the
        *common non-null observations* of each pair, which is equivalent
        to ``pandas.DataFrame.ewm(span).cov(bias=True)``.

        Each date is included in the result as long as at least one
        matrix entry is non-NaN. Cells involving a late-starting asset
        are ``NaN`` until that asset has enough observations; the date is
        never dropped on account of a single asset being unavailable.
        Dates where every cell is NaN (before the warmup period is met
        for any asset) are omitted.

        Args:
            window: Span (default) or half-life (when *is_halflife* is
                ``True``) of the exponential decay. Defaults to ``30``.
            is_halflife: When ``True`` *window* is interpreted as the
                half-life; otherwise it is the EWMA span. Defaults to
                ``False``.
            warmup: Minimum number of common observations required before
                a pair's cell is non-NaN. Defaults to ``0`` (cells are
                non-NaN from the first shared observation).

        Returns:
            Dictionary keyed by index value (date or integer) mapping to
            a square symmetric ``numpy.ndarray`` of shape ``(n, n)``
            where ``n`` is the number of assets. Row/column order
            matches ``data.assets``. Unavailable cells are ``NaN``.
            An empty dict is returned when there are no asset columns.

        Raises:
            TypeError: If *warmup* is not an ``int`` (``bool`` is rejected).
            ValueError: If *warmup* is negative.

        """
        if isinstance(warmup, bool) or not isinstance(warmup, int):
            raise TypeError(f"warmup must be an integer, got {type(warmup).__name__}")  # noqa: TRY003
        if warmup < 0:
            raise ValueError(f"warmup must be a non-negative integer, got {warmup}")  # noqa: TRY003

        asset_cols = self._asset_cols()
        n = len(asset_cols)
        if n == 0:
            # No matrix entries exist, so no date can have a non-NaN cell.
            return {}

        min_samples = max(warmup, 1)
        span, half_life = self._ewm_decay(window, is_halflife)

        def _ewm(expr: pl.Expr) -> pl.Expr:
            """Apply EWM mean with the configured span or half-life."""
            return expr.ewm_mean(span=span, half_life=half_life, min_samples=min_samples)

        def _pair_cov(a: str, b: str) -> pl.Expr:
            """Return the EWM covariance of columns *a* and *b* on their common rows."""
            # Restrict each series to rows where the other is also observed so
            # all three EWMs use the same observation set; the product is
            # already null wherever either side is null.
            a_common = pl.when(pl.col(b).is_not_null()).then(pl.col(a))
            b_common = pl.when(pl.col(a).is_not_null()).then(pl.col(b))
            return _ewm(pl.col(a) * pl.col(b)) - _ewm(a_common) * _ewm(b_common)

        ii, jj = np.triu_indices(n)
        # Alias by position, not by asset name: names like f"{a}_{b}" can
        # collide with each other (assets "x", "x_y", "y_z", "z" both yield
        # "x_y_z") or with an asset column (assets "a", "a_a").  Generating
        # the expressions from (ii, jj) also keeps the numpy column order
        # aligned with the scatter indices below by construction.
        cov_exprs = [
            _pair_cov(asset_cols[i], asset_cols[j]).alias(f"__cov_{i}_{j}")
            for i, j in zip(ii.tolist(), jj.tolist())
        ]

        combined = self._combined()
        index_col = self._data.index.columns[0]
        all_keys = combined[index_col].to_list()
        # select() returns exactly the pair columns; no name-based drop needed.
        pair_arr = combined.select(cov_exprs).to_numpy()

        cube = np.full((len(all_keys), n, n), np.nan)
        cube[:, ii, jj] = pair_arr
        cube[:, jj, ii] = pair_arr

        # Drop dates where every cell is NaN (warmup not yet met for any asset)
        has_data = ~np.all(np.isnan(cube), axis=(1, 2))
        return {k: cube[t] for t, k in enumerate(all_keys) if has_data[t]}