Coverage for src/jquantstats/_utils/_data.py: 100%
99 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-23 06:13 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-23 06:13 +0000
1"""Utility methods for Data objects — the jquantstats equivalent of qs.utils."""
3from __future__ import annotations
5import math
6from collections.abc import Callable, Hashable
8import numpy as np
9import polars as pl
11from jquantstats._protocol import DataLike
13from ..exceptions import MissingDateColumnError
15__all__ = ["DataUtils"]
# Human-readable period names accepted by ``DataUtils.group_returns``, each
# mapped to the Polars interval string handed to ``group_by_dynamic(every=...)``.
# Any period that is not a key of this table is passed through unchanged.
_PERIOD_ALIASES: dict[str, str] = {
    # sub-monthly
    "daily": "1d",
    "weekly": "1w",
    # monthly and longer
    "monthly": "1mo",
    "quarterly": "1q",
    # two spellings of the same one-year interval
    "annual": "1y",
    "yearly": "1y",
}
class DataUtils:
    """Utility transforms and conversions for financial returns data.

    Mirrors the public API of ``quantstats.utils`` but operates on Polars
    DataFrames and integrates with `Data` via the ``data.utils`` property.

    Every transform reads ``data.index`` and ``data.returns`` and returns a new
    object; the wrapped data is never modified.
    """

    __slots__ = ("_data",)

    def __init__(self, data: DataLike) -> None:
        """Keep a reference to *data*, whose index and returns the methods read."""
        self._data = data

    def __repr__(self) -> str:
        """Return a string representation of the DataUtils object."""
        return f"DataUtils(assets={list(self._data.returns.columns)})"

    # ── helpers ───────────────────────────────────────────────────────────────

    def _combined(self) -> pl.DataFrame:
        """Return index hstacked with returns (no benchmark)."""
        return pl.concat([self._data.index, self._data.returns], how="horizontal")

    def _asset_cols(self) -> list[str]:
        """Return the asset column names from returns (excluding benchmark)."""
        return list(self._data.returns.columns)

    def _require_temporal_index(self, method: str) -> str:
        """Return the date column name, or raise if the index is not temporal.

        Args:
            method: Name of the calling method, forwarded to the exception.

        Returns:
            The first entry of ``data.date_col``.

        Raises:
            MissingDateColumnError: If there is no date column, or its dtype
                is not temporal.

        """
        date_cols = self._data.date_col
        if not date_cols:
            raise MissingDateColumnError(method)  # pragma: no cover
        date_col = date_cols[0]
        if not self._data.index[date_col].dtype.is_temporal():
            raise MissingDateColumnError(method)
        return date_col

    # ── public API ────────────────────────────────────────────────────────────

    def to_prices(self, base: float = 1e5) -> pl.DataFrame:
        """Convert returns to a cumulative price series.

        Computes ``base * prod(1 + r_t)`` for each asset column, matching the
        behaviour of ``quantstats.utils.to_prices``. Null returns are treated
        as ``0.0``.

        Args:
            base: Starting value for the price series. Defaults to ``1e5``.

        Returns:
            DataFrame with the index column(s) and one price column per asset.

        """
        asset_cols = self._asset_cols()
        return self._combined().with_columns(
            [(pl.col(c).fill_null(0.0) + 1.0).cum_prod().mul(base).alias(c) for c in asset_cols]
        )

    def to_log_returns(self) -> pl.DataFrame:
        """Convert simple returns to log returns: ``ln(1 + r)``.

        Matches ``quantstats.utils.to_log_returns``. Null returns are treated
        as ``0.0``.

        Returns:
            DataFrame with the same columns as the input returns, values
            replaced by their log-return equivalents.

        """
        asset_cols = self._asset_cols()
        return self._combined().with_columns(
            [(pl.col(c).fill_null(0.0) + 1.0).log(base=math.e).alias(c) for c in asset_cols]
        )

    def to_volatility_adjusted_returns(
        self,
        window: int = 60,
        vol_estimator: Callable[[pl.Expr], pl.Expr] | None = None,
    ) -> pl.DataFrame:
        """Convert simple returns to volatility adjusted returns.

        Divides each return by a lagged volatility estimate to avoid
        look-ahead bias: ``vol_adjusted_r_t = r_t / vol(r_{t-1})``.

        By default the volatility estimate is
        ``pl.Expr.rolling_std(window)``. Pass *vol_estimator* to
        override with any function that maps a ``pl.Expr`` to a
        ``pl.Expr`` (e.g. an EWMA standard deviation).

        Matches ``quantstats.utils.to_volatility_adjusted_returns``.

        Args:
            window: Rolling lookback for the default volatility
                estimator. Ignored when *vol_estimator* is provided.
                Defaults to ``60``.
            vol_estimator: A callable ``(pl.Expr) -> pl.Expr`` that
                produces a volatility series from a returns expression.
                Defaults to ``None`` (uses ``rolling_std(window)``).

        Returns:
            DataFrame with the same columns as the input returns, values
            replaced by their volatility adjusted equivalents.

        """
        if vol_estimator is None:

            def vol_estimator(expr: pl.Expr) -> pl.Expr:
                """Return rolling standard deviation over *window*."""
                return expr.rolling_std(window)

        asset_cols = self._asset_cols()
        # shift(1) lags the estimate one row so r_t is scaled by volatility
        # known at t-1; the explicit alias mirrors the sibling transforms.
        return self._combined().with_columns(
            [(pl.col(c) / vol_estimator(pl.col(c)).shift(1)).alias(c) for c in asset_cols]
        )

    def log_returns(self) -> pl.DataFrame:
        """Alias for `to_log_returns`.

        Matches ``quantstats.utils.log_returns``.

        Returns:
            DataFrame of log returns.

        """
        return self.to_log_returns()

    def rebase(self, base: float = 100.0) -> pl.DataFrame:
        """Normalise the returns as a price series that starts at *base*.

        Converts returns to prices via `to_prices` and then rescales
        each column so its first observation equals *base* exactly, matching
        the behaviour of ``quantstats.utils.rebase``.

        Args:
            base: Target starting value. Defaults to ``100.0``.

        Returns:
            DataFrame with price columns anchored to *base* at t = 0.

        """
        prices_df = self.to_prices(base=1.0)
        asset_cols = self._asset_cols()
        return prices_df.with_columns([(pl.col(c) / pl.col(c).first() * base).alias(c) for c in asset_cols])

    def winsorise(self, window: int = 7, n_sigma: float = 3.0) -> pl.DataFrame:
        """Winsorise returns by clipping to within *n_sigma* rolling standard deviations.

        For each asset column, values outside
        ``rolling_mean ± n_sigma * rolling_std`` (computed over *window*)
        are clipped to the respective bound. Both rolling statistics are
        lagged by one row, so the bounds applied to a value are built only
        from the observations that precede it.

        Args:
            window: Rolling lookback for mean and standard deviation.
                Defaults to ``7``.
            n_sigma: Number of standard deviations for the clip bounds.
                Defaults to ``3.0``.

        Returns:
            DataFrame with the same columns as the input returns, extreme
            values clipped.

        """
        asset_cols = self._asset_cols()
        df = self._combined()
        exprs = []
        for c in asset_cols:
            col = pl.col(c)
            # Lag both statistics so the current value does not influence
            # its own clip bounds.
            r_mean = col.rolling_mean(window).shift(1)
            r_std = col.rolling_std(window).shift(1)
            lower = r_mean - n_sigma * r_std
            upper = r_mean + n_sigma * r_std
            exprs.append(col.clip(lower_bound=lower, upper_bound=upper).alias(c))
        return df.with_columns(exprs)

    def group_returns(self, period: str = "1mo", compounded: bool = True) -> pl.DataFrame:
        """Aggregate returns by a calendar period.

        Requires a temporal (Date/Datetime) index; raises
        `MissingDateColumnError` for integer-indexed data.

        Human-readable aliases are accepted alongside native Polars interval
        strings (``"1mo"``, ``"1q"``, ``"1y"``, ``"1w"``, ``"1d"``):

        ``"daily"``, ``"weekly"``, ``"monthly"``, ``"quarterly"``,
        ``"annual"`` / ``"yearly"``.

        Args:
            period: Aggregation period. Defaults to ``"1mo"`` (monthly).
            compounded: When ``True`` (default) compound the returns
                ``prod(1 + r) - 1``; when ``False`` sum them. Null returns
                count as ``0.0`` in both modes.

        Returns:
            DataFrame with one row per period and one column per asset.

        Raises:
            MissingDateColumnError: If the index is not temporal.

        """
        date_col = self._require_temporal_index("group_returns")
        # Anything that is not a known alias is handed to Polars unchanged.
        polars_period = _PERIOD_ALIASES.get(period, period)
        asset_cols = self._asset_cols()

        if compounded:
            agg_exprs = [((pl.col(c).fill_null(0.0) + 1.0).product() - 1.0).alias(c) for c in asset_cols]
        else:
            agg_exprs = [pl.col(c).fill_null(0.0).sum().alias(c) for c in asset_cols]

        return (
            self._combined()
            .sort(date_col)
            .group_by_dynamic(date_col, every=polars_period)
            .agg(agg_exprs)
            .sort(date_col)
        )

    def aggregate_returns(self, period: str = "1mo", compounded: bool = True) -> pl.DataFrame:
        """Alias for `group_returns`.

        Matches ``quantstats.utils.aggregate_returns``.

        Args:
            period: Aggregation period. See `group_returns` for accepted values.
            compounded: Whether to compound returns. Defaults to ``True``.

        Returns:
            DataFrame with one row per period and one column per asset.

        """
        return self.group_returns(period=period, compounded=compounded)

    def to_excess_returns(self, rf: float = 0.0, nperiods: int | None = None) -> pl.DataFrame:
        """Subtract a risk-free rate from returns.

        When *nperiods* is supplied the annual *rf* is converted to a
        per-period rate via ``(1 + rf)^(1/nperiods) - 1``, matching
        ``quantstats.utils.to_excess_returns``.

        Args:
            rf: Annual risk-free rate as a decimal (e.g. ``0.05`` for 5 %).
                Defaults to ``0.0``.
            nperiods: Number of return periods per year used to convert *rf*
                to a per-period rate. When ``None`` *rf* is applied as-is.

        Returns:
            DataFrame of excess returns with the same columns as the input.

        """
        rf_per_period = ((1.0 + rf) ** (1.0 / nperiods) - 1.0) if nperiods is not None else rf
        asset_cols = self._asset_cols()
        return self._combined().with_columns([(pl.col(c) - rf_per_period).alias(c) for c in asset_cols])

    def exponential_stdev(self, window: int = 30, is_halflife: bool = False) -> pl.DataFrame:
        """Compute the exponentially weighted standard deviation of returns.

        Matches ``quantstats.utils.exponential_stdev``. Uses Polars
        ``ewm_std`` under the hood.

        Args:
            window: Span (default) or half-life (when *is_halflife* is
                ``True``) of the exponential decay. Defaults to ``30``.
            is_halflife: When ``True`` *window* is interpreted as the
                half-life; otherwise it is the EWMA span. Defaults to
                ``False``.

        Returns:
            DataFrame of rolling EWMA standard deviations with the same
            columns as the input returns.

        """
        asset_cols = self._asset_cols()
        if is_halflife:
            exprs = [pl.col(c).ewm_std(half_life=window, min_samples=1).alias(c) for c in asset_cols]
        else:
            exprs = [pl.col(c).ewm_std(span=window, min_samples=1).alias(c) for c in asset_cols]
        return self._combined().with_columns(exprs)

    def exponential_cov(
        self, window: int = 30, is_halflife: bool = False, warmup: int = 0
    ) -> dict[Hashable, np.ndarray]:
        """Compute the exponentially weighted covariance matrix of returns.

        EWM covariance uses the identity
        ``Cov(X, Y) = EWM(X*Y) - EWM(X)*EWM(Y)`` applied to the
        *common non-null observations* of each pair, which is intended to
        be equivalent to ``pandas.DataFrame.ewm(span).cov(bias=True)``.

        Each date is included in the result as long as at least one
        matrix entry is non-NaN. Cells involving a late-starting asset
        are ``NaN`` until that asset has enough observations; the date is
        never dropped on account of a single asset being unavailable.
        Dates where every cell is NaN (before the warmup period is met
        for any asset) are omitted.

        Args:
            window: Span (default) or half-life (when *is_halflife* is
                ``True``) of the exponential decay. Defaults to ``30``.
            is_halflife: When ``True`` *window* is interpreted as the
                half-life; otherwise it is the EWMA span. Defaults to
                ``False``.
            warmup: Minimum number of common observations required before
                a pair's cell is non-NaN. Defaults to ``0`` (cells are
                non-NaN from the first shared observation).

        Returns:
            Dictionary keyed by index value (date or integer) mapping to
            a square symmetric ``numpy.ndarray`` of shape ``(n, n)``
            where ``n`` is the number of assets. Row/column order
            matches the order of the returns columns. Unavailable cells
            are ``NaN``. Empty when there are no asset columns.

        Raises:
            TypeError: If *warmup* is not an ``int`` (``bool`` is rejected).
            ValueError: If *warmup* is negative.

        """
        if isinstance(warmup, bool) or not isinstance(warmup, int):
            raise TypeError(f"warmup must be an integer, got {type(warmup).__name__}")  # noqa: TRY003
        if warmup < 0:
            raise ValueError(f"warmup must be a non-negative integer, got {warmup}")  # noqa: TRY003

        asset_cols = self._asset_cols()
        n = len(asset_cols)
        if n == 0:
            # No assets means no matrix cells; every date would be omitted anyway.
            return {}
        min_samples = 1 if warmup == 0 else warmup

        def _ewm(expr: pl.Expr) -> pl.Expr:
            """Apply EWM mean with the configured span or half-life."""
            if is_halflife:
                return expr.ewm_mean(half_life=window, min_samples=min_samples)
            return expr.ewm_mean(span=window, min_samples=min_samples)

        # Upper-triangle pairs in row-major order, the same order in which
        # np.triu_indices(n) enumerates (i, j) below.
        pairs = [(a, b) for i, a in enumerate(asset_cols) for b in asset_cols[i:]]

        # For each pair restrict both series to their common non-null rows
        # so that all three EWMs use the same observation set.
        # Aliases are positional ("cov_<k>") rather than built from the asset
        # names: a name-derived alias such as "x_x" can equal an asset or index
        # column name and silently clobber it.
        cov_exprs = [
            (
                _ewm(pl.col(a) * pl.col(b))
                - _ewm(pl.when(pl.col(b).is_null()).then(None).otherwise(pl.col(a)))
                * _ewm(pl.when(pl.col(a).is_null()).then(None).otherwise(pl.col(b)))
            ).alias(f"cov_{k}")
            for k, (a, b) in enumerate(pairs)
        ]

        index_col = self._data.index.columns[0]
        frame = self._combined()
        all_keys = frame[index_col].to_list()
        # select() projects only the covariance columns, so neither asset
        # columns nor additional index columns can leak into the matrix.
        pair_arr = frame.select(cov_exprs).to_numpy()

        ii, jj = np.triu_indices(n)
        cube = np.full((len(all_keys), n, n), np.nan)
        cube[:, ii, jj] = pair_arr
        cube[:, jj, ii] = pair_arr  # mirror into the lower triangle

        # Drop dates where every cell is NaN (warmup not yet met for any asset)
        has_data = ~np.all(np.isnan(cube), axis=(1, 2))
        return {k: cube[t] for t, k in enumerate(all_keys) if has_data[t]}