Coverage report for src/jquantstats/_utils/_data.py — 100% of 75 statements covered.
Generated by coverage.py v7.13.5 at 2026-05-07 15:52 +0000.
1"""Utility methods for Data objects — the jquantstats equivalent of qs.utils."""
3from __future__ import annotations
5import dataclasses
6import math
7from collections.abc import Callable
9import polars as pl
11from ..exceptions import MissingDateColumnError
12from ._protocol import DataLike
__all__ = ["DataUtils"]

# Maps human-readable period aliases to the Polars "every"-interval strings
# accepted by ``DataFrame.group_by_dynamic`` (used by ``DataUtils.group_returns``).
_PERIOD_ALIASES: dict[str, str] = {
    "daily": "1d",
    "weekly": "1w",
    "monthly": "1mo",
    "quarterly": "1q",
    "annual": "1y",
    "yearly": "1y",
}
@dataclasses.dataclass(frozen=True)
class DataUtils:
    """Utility transforms and conversions for financial returns data.

    Mirrors the public API of ``quantstats.utils`` but operates on Polars
    DataFrames and integrates with `Data` via the
    ``data.utils`` property.

    Attributes:
        data: Any object satisfying the `DataLike`
            protocol — typically a `Data` instance.

    """

    data: DataLike

    def __repr__(self) -> str:
        """Return a string representation of the DataUtils object."""
        return f"DataUtils(assets={list(self.data.returns.columns)})"

    # ── helpers ───────────────────────────────────────────────────────────────

    def _combined(self) -> pl.DataFrame:
        """Return index hstacked with returns (no benchmark)."""
        return pl.concat([self.data.index, self.data.returns], how="horizontal")

    def _asset_cols(self) -> list[str]:
        """Return the asset column names from returns (excluding benchmark)."""
        return list(self.data.returns.columns)

    def _require_temporal_index(self, method: str) -> str:
        """Raise MissingDateColumnError if the index is not temporal, else return date col name.

        Args:
            method: Name of the calling method, used in the error message.

        Returns:
            Name of the first (temporal) date column.

        Raises:
            MissingDateColumnError: If no date column exists or it is not
                a Date/Datetime dtype.
        """
        date_cols = self.data.date_col
        if not date_cols:
            raise MissingDateColumnError(method)  # pragma: no cover
        date_col = date_cols[0]
        if not self.data.index[date_col].dtype.is_temporal():
            raise MissingDateColumnError(method)
        return date_col

    # ── public API ────────────────────────────────────────────────────────────

    def to_prices(self, base: float = 1e5) -> pl.DataFrame:
        """Convert returns to a cumulative price series.

        Computes ``base * prod(1 + r_t)`` for each asset column, matching the
        behaviour of ``quantstats.utils.to_prices``.

        Args:
            base: Starting value for the price series. Defaults to ``1e5``.

        Returns:
            DataFrame with the same date column (if present) and one price
            column per asset.

        """
        asset_cols = self._asset_cols()
        # Nulls are treated as zero returns so they do not break the product.
        return self._combined().with_columns(
            [(pl.col(c).fill_null(0.0) + 1.0).cum_prod().mul(base).alias(c) for c in asset_cols]
        )

    def to_log_returns(self) -> pl.DataFrame:
        """Convert simple returns to log returns: ``ln(1 + r)``.

        Matches ``quantstats.utils.to_log_returns``.

        Returns:
            DataFrame with the same columns as the input returns, values
            replaced by their log-return equivalents.

        """
        asset_cols = self._asset_cols()
        # Nulls are treated as zero returns, so ln(1 + 0) = 0.
        return self._combined().with_columns(
            [(pl.col(c).fill_null(0.0) + 1.0).log(base=math.e).alias(c) for c in asset_cols]
        )

    def to_volatility_adjusted_returns(
        self,
        window: int = 60,
        vol_estimator: Callable[[pl.Expr], pl.Expr] | None = None,
    ) -> pl.DataFrame:
        """Convert simple returns to volatility adjusted returns.

        Divides each return by a lagged volatility estimate to avoid
        look-ahead bias: ``vol_adjusted_r_t = r_t / vol(r_{t-1})``.

        By default the volatility estimate is
        ``pl.Expr.rolling_std(window)``. Pass *vol_estimator* to
        override with any function that maps a ``pl.Expr`` to a
        ``pl.Expr`` (e.g. an EWMA standard deviation).

        Matches ``quantstats.utils.to_volatility_adjusted_returns``.

        Args:
            window: Rolling lookback for the default volatility
                estimator. Ignored when *vol_estimator* is provided.
                Defaults to ``60``.
            vol_estimator: A callable ``(pl.Expr) -> pl.Expr`` that
                produces a volatility series from a returns expression.
                Defaults to ``None`` (uses ``rolling_std(window)``).

        Returns:
            DataFrame with the same columns as the input returns, values
            replaced by their volatility adjusted equivalents.

        """
        if vol_estimator is None:
            # Bind the default under a distinct name rather than shadowing
            # the *vol_estimator* parameter with an inner ``def``.
            def _default_vol(expr: pl.Expr) -> pl.Expr:
                """Return rolling standard deviation over *window*."""
                return expr.rolling_std(window)

            vol_estimator = _default_vol

        asset_cols = self._asset_cols()
        # ``shift(1)`` lags the volatility estimate by one period so each
        # return is scaled only by information available before it occurred.
        return self._combined().with_columns([pl.col(c) / vol_estimator(pl.col(c)).shift(1) for c in asset_cols])

    def log_returns(self) -> pl.DataFrame:
        """Alias for `to_log_returns`.

        Matches ``quantstats.utils.log_returns``.

        Returns:
            DataFrame of log returns.

        """
        return self.to_log_returns()

    def rebase(self, base: float = 100.0) -> pl.DataFrame:
        """Normalise the returns as a price series that starts at *base*.

        Converts returns to prices via `to_prices` and then rescales
        each column so its first observation equals *base* exactly, matching
        the behaviour of ``quantstats.utils.rebase``.

        Args:
            base: Target starting value. Defaults to ``100.0``.

        Returns:
            DataFrame with price columns anchored to *base* at t = 0.

        """
        # Build prices at base 1.0, then rescale so the first row is exactly
        # *base* regardless of the first period's return.
        prices_df = self.to_prices(base=1.0)
        asset_cols = self._asset_cols()
        return prices_df.with_columns([(pl.col(c) / pl.col(c).first() * base).alias(c) for c in asset_cols])

    def winsorise(self, window: int = 7, n_sigma: float = 3.0) -> pl.DataFrame:
        """Winsorise returns by clipping to within *n_sigma* rolling standard deviations.

        For each asset column, values outside
        ``rolling_mean ± n_sigma * rolling_std`` (computed over *window*)
        are clipped to the respective bound. Both statistics are lagged by
        one period (``shift(1)``) so the clip bounds for each observation
        use only information available beforehand, avoiding look-ahead bias.
        Rows whose bounds are null (the warm-up window) are left unclipped.

        Args:
            window: Rolling lookback for mean and standard deviation.
                Defaults to ``7``.
            n_sigma: Number of standard deviations for the clip bounds.
                Defaults to ``3.0``.

        Returns:
            DataFrame with the same columns as the input returns, extreme
            values clipped.
        """
        asset_cols = self._asset_cols()
        df = self._combined()
        exprs = []
        for c in asset_cols:
            col = pl.col(c)
            # Lag both statistics by one period — see docstring.
            r_mean = col.rolling_mean(window).shift(1)
            r_std = col.rolling_std(window).shift(1)
            lower = r_mean - n_sigma * r_std
            upper = r_mean + n_sigma * r_std
            exprs.append(col.clip(lower_bound=lower, upper_bound=upper).alias(c))
        return df.with_columns(exprs)

    def group_returns(self, period: str = "1mo", compounded: bool = True) -> pl.DataFrame:
        """Aggregate returns by a calendar period.

        Requires a temporal (Date/Datetime) index; raises
        `MissingDateColumnError` for integer-indexed data.

        Human-readable aliases are accepted alongside native Polars interval
        strings (``"1mo"``, ``"1q"``, ``"1y"``, ``"1w"``, ``"1d"``):

        ``"daily"``, ``"weekly"``, ``"monthly"``, ``"quarterly"``,
        ``"annual"`` / ``"yearly"``. Aliases are matched case-insensitively.

        Args:
            period: Aggregation period. Defaults to ``"1mo"`` (monthly).
            compounded: When ``True`` (default) compound the returns
                ``prod(1 + r) - 1``; when ``False`` sum them.

        Returns:
            DataFrame with one row per period and one column per asset.

        """
        date_col = self._require_temporal_index("group_returns")
        # Case-insensitive alias lookup; unknown strings fall through
        # unchanged so native Polars intervals keep working.
        polars_period = _PERIOD_ALIASES.get(period.lower(), period)
        asset_cols = self._asset_cols()

        if compounded:
            agg_exprs = [((pl.col(c).fill_null(0.0) + 1.0).product() - 1.0).alias(c) for c in asset_cols]
        else:
            agg_exprs = [pl.col(c).fill_null(0.0).sum().alias(c) for c in asset_cols]

        # group_by_dynamic requires sorted input; sort again afterwards so
        # the output is ordered by period start.
        return (
            self._combined()
            .sort(date_col)
            .group_by_dynamic(date_col, every=polars_period)
            .agg(agg_exprs)
            .sort(date_col)
        )

    def aggregate_returns(self, period: str = "1mo", compounded: bool = True) -> pl.DataFrame:
        """Alias for `group_returns`.

        Matches ``quantstats.utils.aggregate_returns``.

        Args:
            period: Aggregation period. See `group_returns` for accepted values.
            compounded: Whether to compound returns. Defaults to ``True``.

        Returns:
            DataFrame with one row per period and one column per asset.

        """
        return self.group_returns(period=period, compounded=compounded)

    def to_excess_returns(self, rf: float = 0.0, nperiods: int | None = None) -> pl.DataFrame:
        """Subtract a risk-free rate from returns.

        When *nperiods* is supplied the annual *rf* is converted to a
        per-period rate via ``(1 + rf)^(1/nperiods) - 1``, matching
        ``quantstats.utils.to_excess_returns``.

        Args:
            rf: Annual risk-free rate as a decimal (e.g. ``0.05`` for 5 %).
                Defaults to ``0.0``.
            nperiods: Number of return periods per year used to convert *rf*
                to a per-period rate. When ``None`` *rf* is applied as-is.

        Returns:
            DataFrame of excess returns with the same columns as the input.

        """
        # De-annualise via geometric conversion when nperiods is given.
        rf_per_period = ((1.0 + rf) ** (1.0 / nperiods) - 1.0) if nperiods is not None else rf
        asset_cols = self._asset_cols()
        return self._combined().with_columns([(pl.col(c) - rf_per_period).alias(c) for c in asset_cols])

    def exponential_stdev(self, window: int = 30, is_halflife: bool = False) -> pl.DataFrame:
        """Compute the exponentially weighted standard deviation of returns.

        Matches ``quantstats.utils.exponential_stdev``. Uses Polars
        ``ewm_std`` under the hood.

        Args:
            window: Span (default) or half-life (when *is_halflife* is
                ``True``) of the exponential decay. Defaults to ``30``.
            is_halflife: When ``True`` *window* is interpreted as the
                half-life; otherwise it is the EWMA span. Defaults to
                ``False``.

        Returns:
            DataFrame of rolling EWMA standard deviations with the same
            columns as the input returns.

        """
        asset_cols = self._asset_cols()
        if is_halflife:
            exprs = [pl.col(c).ewm_std(half_life=window, min_samples=1).alias(c) for c in asset_cols]
        else:
            exprs = [pl.col(c).ewm_std(span=window, min_samples=1).alias(c) for c in asset_cols]
        return self._combined().with_columns(exprs)