Coverage for src/jquantstats/_stats/_drawdown.py: 100%
56 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-23 06:13 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-23 06:13 +0000
1"""Drawdown and cumulative-return metrics for financial returns data."""
3from __future__ import annotations
5from typing import TYPE_CHECKING, cast
7import polars as pl
9from ._core import columnwise_stat, to_frame
10from ._internals import _nav_series
12if TYPE_CHECKING:
13 from ..data import Data
15# ── Drawdown statistics mixin ─────────────────────────────────────────────────
class _DrawdownMixin:
    """Mixin providing cumulative-return and drawdown metrics.

    Covers: compounded cumulative returns (``compsum``), the drawdown series,
    price (NAV) conversion, maximum drawdown, and per-episode drawdown details.

    Sign conventions differ between methods: ``drawdown`` reports drawdowns as
    positive magnitudes (``-(nav / peak - 1)``), whereas ``max_drawdown`` and
    the ``max_drawdown`` column of ``drawdown_details`` are reported as
    ``nav / peak - 1``, i.e. negative fractions while underwater.
    """

    _data: Data
    all: pl.DataFrame

    if TYPE_CHECKING:
        from .._protocol import DataLike

        data: DataLike

    # ── Cumulative returns ────────────────────────────────────────────────────

    @to_frame
    def compsum(self, series: pl.Series) -> pl.Series:
        """Calculate the rolling compounded (cumulative) returns.

        Computed as cumprod(1 + r) - 1 for each period.

        Args:
            series (pl.Series): The series to calculate cumulative returns for.

        Returns:
            pl.Series: Cumulative compounded returns per period.

        """
        return (1.0 + series).cum_prod() - 1.0

    # ── Drawdown ──────────────────────────────────────────────────────────────

    @to_frame
    def drawdown(self, series: pl.Series) -> pl.Series:
        """Calculate the drawdown series for returns.

        Each value is ``-(nav / running_max(nav) - 1)``: the decline from the
        running peak expressed as a positive magnitude (e.g. ``0.2`` for 20%).
        Note that ``max_drawdown`` uses the opposite (negative) sign convention.

        Args:
            series (pl.Series): The series to calculate drawdown for.

        Returns:
            pl.Series: The drawdown series.

        """
        equity = self.prices(series)
        d = (equity / equity.cum_max()) - 1
        return -d

    @staticmethod
    def prices(series: pl.Series) -> pl.Series:
        """Convert returns series to price series.

        Args:
            series (pl.Series): The returns series to convert.

        Returns:
            pl.Series: The price series.

        """
        return _nav_series(series)

    @staticmethod
    def max_drawdown_single_series(series: pl.Series) -> float:
        """Compute the maximum drawdown for a single returns series.

        Args:
            series: A Polars Series of returns values.

        Returns:
            float: The minimum of ``price / running_peak - 1``. This is a negative
            fraction (e.g. ``-0.2`` for a 20% drawdown) or ``0.0`` when the
            price never falls below its running peak. ``0.0`` is also returned
            when the minimum is null (e.g. an empty series).

        """
        price = _DrawdownMixin.prices(series)
        peak = price.cum_max()
        drawdown = price / peak - 1
        # ``min()`` yields None for an empty / all-null series.
        dd_min = cast("float | None", drawdown.min())
        return dd_min if dd_min is not None else 0.0

    @columnwise_stat
    def max_drawdown(self, series: pl.Series) -> float:
        """Calculate the maximum drawdown for each column.

        Args:
            series (pl.Series): The series to calculate maximum drawdown for.

        Returns:
            float: The maximum drawdown value as a negative fraction
            (e.g. ``-0.2`` for 20%), or ``0.0`` if there is no drawdown.

        """
        return _DrawdownMixin.max_drawdown_single_series(series)

    @staticmethod
    def _empty_drawdown_details(date_dtype: pl.DataType) -> pl.DataFrame:
        """Build an empty ``drawdown_details`` frame with the expected column types."""
        return pl.DataFrame(
            {
                "start": pl.Series([], dtype=date_dtype),
                "valley": pl.Series([], dtype=date_dtype),
                "end": pl.Series([], dtype=date_dtype),
                "duration": pl.Series([], dtype=pl.Int64),
                "max_drawdown": pl.Series([], dtype=pl.Float64),
                "recovery_duration": pl.Series([], dtype=pl.Int64),
            }
        )

    @staticmethod
    def _with_durations(dd_periods: pl.DataFrame, has_date: bool) -> pl.DataFrame:
        """Add ``duration`` and ``recovery_duration`` columns to per-period stats.

        Spans are measured in whole days when ``has_date`` is true and in rows
        otherwise. A recovered period's duration runs from ``start`` to ``end``.
        An unrecovered period's duration counts ``start`` through
        ``last_dd_date`` inclusive, and its ``recovery_duration`` is null.
        """

        def span(later: str, earlier: str) -> pl.Expr:
            diff = pl.col(later) - pl.col(earlier)
            return diff.dt.total_days() if has_date else diff

        recovered = pl.col("end").is_not_null()
        return dd_periods.with_columns(
            [
                pl.when(recovered)
                .then(span("end", "start"))
                .otherwise(span("last_dd_date", "start") + 1)
                .cast(pl.Int64)
                .alias("duration"),
                pl.when(recovered)
                .then(span("end", "valley").cast(pl.Int64))
                .otherwise(pl.lit(None, dtype=pl.Int64))
                .alias("recovery_duration"),
            ]
        )

    def drawdown_details(self) -> dict[str, pl.DataFrame]:
        """Return detailed statistics for each individual drawdown period.

        An underwater episode is a contiguous run of observations where the NAV
        is strictly below its running maximum. For each episode, this records
        the start date, valley (worst point), recovery date, total duration,
        maximum drawdown, and recovery duration.

        When the data has a temporal date column, ``start``, ``valley`` and
        ``end`` are dates and durations are counted in days. Otherwise they are
        zero-based row positions and durations are counted in rows.

        Returns:
            dict[str, pl.DataFrame]: Per-asset DataFrames with columns
            ``start``, ``valley``, ``end``, ``duration``, ``max_drawdown``,
            ``recovery_duration``, sorted by ``start``.

        Note:
            ``end`` and ``recovery_duration`` are ``null`` for drawdown periods
            that have not yet recovered by the last observation.
            ``max_drawdown`` is a negative fraction (e.g. ``-0.2`` for 20%).
            If several observations tie for the lowest NAV within a period,
            ``valley`` is the earliest of them.

        """
        all_df = self.all
        date_col_name = self._data.date_col[0] if self._data.date_col else None
        has_date = date_col_name is not None and all_df[date_col_name].dtype.is_temporal()
        # The date column does not depend on the asset, so fetch it once.
        date_series = all_df[date_col_name] if has_date and date_col_name is not None else None

        result: dict[str, pl.DataFrame] = {}
        for col, series in self._data.items():
            nav = _nav_series(series)
            hwm = nav.cum_max()
            in_dd = nav < hwm
            dd_pct = nav / hwm - 1  # negative or zero

            # Without a temporal date column, fall back to zero-based row positions.
            dates = (
                date_series
                if date_series is not None
                else pl.Series(list(range(len(series))), dtype=pl.Int64)
            )

            frame = (
                pl.DataFrame({"date": dates, "nav": nav, "dd_pct": dd_pct, "in_dd": in_dd})
                .with_row_index("row_idx")
                # rle_id gives each run of equal ``in_dd`` values its own id.
                # Underwater and above-water runs therefore alternate with
                # consecutive ids.
                .with_columns(pl.col("in_dd").rle_id().cast(pl.Int64).alias("run_id"))
            )

            dd_frame = frame.filter(pl.col("in_dd"))

            # A NAV that never falls below its running peak has no underwater rows,
            # so drawdown_details returns an empty but correctly typed frame.
            if dd_frame.is_empty():
                result[col] = _DrawdownMixin._empty_drawdown_details(dates.dtype)
                continue

            # Per-period stats. The high-water mark is constant within an
            # underwater run, so the lowest NAV is also the deepest drawdown.
            # Sorting by ``row_idx`` as a second key makes ties resolve to the
            # earliest observation deterministically.
            dd_periods = dd_frame.group_by("run_id").agg(
                [
                    pl.col("date").first().alias("start"),
                    pl.col("date").last().alias("last_dd_date"),
                    pl.col("date").sort_by(["nav", "row_idx"]).first().alias("valley"),
                    pl.col("dd_pct").min().alias("max_drawdown"),
                ]
            )

            # The first date of each above-water run is the recovery date of the
            # underwater run immediately before it (whose run_id is one lower).
            recoveries = (
                frame.filter(~pl.col("in_dd"))
                .group_by("run_id")
                .agg(pl.col("date").first().alias("end"))
                .with_columns((pl.col("run_id") - 1).alias("run_id"))
            )

            # Sort after the join: neither group_by nor join guarantees row order.
            dd_periods = dd_periods.join(recoveries.select(["run_id", "end"]), on="run_id", how="left").sort(
                "start"
            )

            dd_periods = _DrawdownMixin._with_durations(dd_periods, has_date)
            result[col] = dd_periods.select(["start", "valley", "end", "duration", "max_drawdown", "recovery_duration"])

        return result