Coverage for src / jquantstats / _utils / _data.py: 100%

75 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-07 15:52 +0000

1"""Utility methods for Data objects — the jquantstats equivalent of qs.utils.""" 

2 

3from __future__ import annotations 

4 

5import dataclasses 

6import math 

7from collections.abc import Callable 

8 

9import polars as pl 

10 

11from ..exceptions import MissingDateColumnError 

12from ._protocol import DataLike 

13 

14__all__ = ["DataUtils"] 

15 

# Maps human-readable period aliases (e.g. "monthly") to Polars `every`
# interval strings (e.g. "1mo"). Lookups fall back to the key itself
# (`_PERIOD_ALIASES.get(period, period)` in `group_returns`), so native
# Polars intervals like "1mo" / "1q" / "1y" pass through unchanged.
_PERIOD_ALIASES: dict[str, str] = {
    "daily": "1d",
    "weekly": "1w",
    "monthly": "1mo",
    "quarterly": "1q",
    "annual": "1y",
    "yearly": "1y",
}

25 

26 

@dataclasses.dataclass(frozen=True)
class DataUtils:
    """Transforms and conversions for financial returns data.

    Provides the same public surface as ``quantstats.utils`` while working
    on Polars DataFrames; reached from a `Data` instance through its
    ``data.utils`` property.

    Attributes:
        data: Any object implementing the `DataLike` protocol — in
            practice a `Data` instance.

    """

    data: DataLike

    def __repr__(self) -> str:
        """Return a string representation of the DataUtils object."""
        return f"DataUtils(assets={list(self.data.returns.columns)})"

    # ── internal helpers ──────────────────────────────────────────────────────

    def _combined(self) -> pl.DataFrame:
        """Horizontally stack the index columns with the returns (no benchmark)."""
        return pl.concat([self.data.index, self.data.returns], how="horizontal")

    def _asset_cols(self) -> list[str]:
        """List the asset column names from returns (benchmark excluded)."""
        return [*self.data.returns.columns]

    def _require_temporal_index(self, method: str) -> str:
        """Return the date column name, raising if the index is not temporal.

        Args:
            method: Name of the calling method, used in the error message.

        Raises:
            MissingDateColumnError: If there is no date column or its dtype
                is not a Date/Datetime.

        """
        candidates = self.data.date_col
        if not candidates:
            raise MissingDateColumnError(method)  # pragma: no cover
        name = candidates[0]
        if not self.data.index[name].dtype.is_temporal():
            raise MissingDateColumnError(method)
        return name

    # ── public API ────────────────────────────────────────────────────────────

    def to_prices(self, base: float = 1e5) -> pl.DataFrame:
        """Convert returns to a cumulative price series.

        Each asset column becomes ``base * prod(1 + r_t)``, mirroring
        ``quantstats.utils.to_prices``. Nulls are treated as zero returns.

        Args:
            base: Starting value for the price series. Defaults to ``1e5``.

        Returns:
            DataFrame with the same date column (if present) and one price
            column per asset.

        """
        exprs: list[pl.Expr] = []
        for name in self._asset_cols():
            growth = (pl.col(name).fill_null(0.0) + 1.0).cum_prod()
            exprs.append(growth.mul(base).alias(name))
        return self._combined().with_columns(exprs)

    def to_log_returns(self) -> pl.DataFrame:
        """Convert simple returns into log returns, ``ln(1 + r)``.

        Equivalent to ``quantstats.utils.to_log_returns``. Nulls are treated
        as zero returns before taking the logarithm.

        Returns:
            DataFrame with the same columns as the input returns, each value
            replaced by its log-return counterpart.

        """
        exprs = [
            (pl.col(name).fill_null(0.0) + 1.0).log(base=math.e).alias(name)
            for name in self._asset_cols()
        ]
        return self._combined().with_columns(exprs)

    def to_volatility_adjusted_returns(
        self,
        window: int = 60,
        vol_estimator: Callable[[pl.Expr], pl.Expr] | None = None,
    ) -> pl.DataFrame:
        """Convert simple returns to volatility adjusted returns.

        Each return is divided by a one-period-lagged volatility estimate,
        ``vol_adjusted_r_t = r_t / vol(r_{t-1})``, so the scaling never uses
        information from the current period (no look-ahead bias).

        The default volatility estimate is ``pl.Expr.rolling_std(window)``;
        supply *vol_estimator* to use any other ``pl.Expr -> pl.Expr``
        mapping (e.g. an EWMA standard deviation).

        Equivalent to ``quantstats.utils.to_volatility_adjusted_returns``.

        Args:
            window: Rolling lookback for the default estimator; ignored when
                *vol_estimator* is given. Defaults to ``60``.
            vol_estimator: Callable producing a volatility series from a
                returns expression. Defaults to ``None`` (rolling std).

        Returns:
            DataFrame with the same columns as the input returns, each value
            replaced by its volatility adjusted counterpart.

        """
        estimator = vol_estimator
        if estimator is None:

            def estimator(expr: pl.Expr) -> pl.Expr:
                """Default estimator: rolling standard deviation over *window*."""
                return expr.rolling_std(window)

        adjusted = [
            (pl.col(name) / estimator(pl.col(name)).shift(1)).alias(name)
            for name in self._asset_cols()
        ]
        return self._combined().with_columns(adjusted)

    def log_returns(self) -> pl.DataFrame:
        """Alias for `to_log_returns` (matches ``quantstats.utils.log_returns``).

        Returns:
            DataFrame of log returns.

        """
        return self.to_log_returns()

    def rebase(self, base: float = 100.0) -> pl.DataFrame:
        """Normalise the returns as a price series starting at *base*.

        Builds a price series via `to_prices` and rescales every column so
        its first observation equals *base* exactly, mirroring
        ``quantstats.utils.rebase``.

        Args:
            base: Target starting value. Defaults to ``100.0``.

        Returns:
            DataFrame with price columns anchored to *base* at t = 0.

        """
        priced = self.to_prices(base=1.0)
        rescaled = [
            (pl.col(name) / pl.col(name).first() * base).alias(name)
            for name in self._asset_cols()
        ]
        return priced.with_columns(rescaled)

    def winsorise(self, window: int = 7, n_sigma: float = 3.0) -> pl.DataFrame:
        """Winsorise returns by clipping to *n_sigma* rolling standard deviations.

        Each asset column is clipped to ``rolling_mean ± n_sigma *
        rolling_std`` (both computed over *window* and lagged one period so
        the bounds never include the current observation).

        Args:
            window: Rolling lookback for mean and standard deviation.
                Defaults to ``7``.
            n_sigma: Number of standard deviations for the clip bounds.
                Defaults to ``3.0``.

        Returns:
            DataFrame with the same columns as the input returns, extreme
            values clipped.

        """

        def _clipped(name: str) -> pl.Expr:
            """Clip one column to its lagged rolling mean ± n_sigma * std band."""
            series = pl.col(name)
            mu = series.rolling_mean(window).shift(1)
            sigma = series.rolling_std(window).shift(1)
            return series.clip(
                lower_bound=mu - n_sigma * sigma,
                upper_bound=mu + n_sigma * sigma,
            ).alias(name)

        return self._combined().with_columns([_clipped(name) for name in self._asset_cols()])

    def group_returns(self, period: str = "1mo", compounded: bool = True) -> pl.DataFrame:
        """Aggregate returns by a calendar period.

        Requires a temporal (Date/Datetime) index; integer-indexed data
        raises `MissingDateColumnError`.

        Accepts human-readable aliases as well as native Polars interval
        strings (``"1mo"``, ``"1q"``, ``"1y"``, ``"1w"``, ``"1d"``):
        ``"daily"``, ``"weekly"``, ``"monthly"``, ``"quarterly"``,
        ``"annual"`` / ``"yearly"``.

        Args:
            period: Aggregation period. Defaults to ``"1mo"`` (monthly).
            compounded: When ``True`` (default) compound the returns as
                ``prod(1 + r) - 1``; when ``False`` sum them.

        Returns:
            DataFrame with one row per period and one column per asset.

        """
        date_col = self._require_temporal_index("group_returns")
        every = _PERIOD_ALIASES.get(period, period)
        cols = self._asset_cols()

        # Compounded growth vs. plain sum; nulls count as zero either way.
        aggs = (
            [((pl.col(c).fill_null(0.0) + 1.0).product() - 1.0).alias(c) for c in cols]
            if compounded
            else [pl.col(c).fill_null(0.0).sum().alias(c) for c in cols]
        )

        # group_by_dynamic requires a sorted key; sort again afterwards so
        # the output order is deterministic.
        sorted_frame = self._combined().sort(date_col)
        return sorted_frame.group_by_dynamic(date_col, every=every).agg(aggs).sort(date_col)

    def aggregate_returns(self, period: str = "1mo", compounded: bool = True) -> pl.DataFrame:
        """Alias for `group_returns` (matches ``quantstats.utils.aggregate_returns``).

        Args:
            period: Aggregation period. See `group_returns` for accepted values.
            compounded: Whether to compound returns. Defaults to ``True``.

        Returns:
            DataFrame with one row per period and one column per asset.

        """
        return self.group_returns(period=period, compounded=compounded)

    def to_excess_returns(self, rf: float = 0.0, nperiods: int | None = None) -> pl.DataFrame:
        """Subtract a risk-free rate from returns.

        When *nperiods* is given, the annual *rf* is deannualised to a
        per-period rate via ``(1 + rf)^(1/nperiods) - 1``, mirroring
        ``quantstats.utils.to_excess_returns``.

        Args:
            rf: Annual risk-free rate as a decimal (e.g. ``0.05`` for 5 %).
                Defaults to ``0.0``.
            nperiods: Return periods per year used to convert *rf* to a
                per-period rate. When ``None`` *rf* is applied as-is.

        Returns:
            DataFrame of excess returns with the same columns as the input.

        """
        if nperiods is None:
            per_period_rf = rf
        else:
            per_period_rf = (1.0 + rf) ** (1.0 / nperiods) - 1.0
        shifted = [(pl.col(name) - per_period_rf).alias(name) for name in self._asset_cols()]
        return self._combined().with_columns(shifted)

    def exponential_stdev(self, window: int = 30, is_halflife: bool = False) -> pl.DataFrame:
        """Compute the exponentially weighted standard deviation of returns.

        Mirrors ``quantstats.utils.exponential_stdev``, implemented with
        Polars ``ewm_std``.

        Args:
            window: Span (default) or half-life (when *is_halflife* is
                ``True``) of the exponential decay. Defaults to ``30``.
            is_halflife: When ``True`` *window* is interpreted as the
                half-life; otherwise as the EWMA span. Defaults to
                ``False``.

        Returns:
            DataFrame of rolling EWMA standard deviations with the same
            columns as the input returns.

        """
        # Select the decay parameterisation once, then apply it to every column.
        decay = {"half_life": window} if is_halflife else {"span": window}
        exprs = [
            pl.col(name).ewm_std(min_samples=1, **decay).alias(name)
            for name in self._asset_cols()
        ]
        return self._combined().with_columns(exprs)