Coverage for src/jquantstats/_portfolio_cost.py: 100%
56 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-23 06:13 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-23 06:13 +0000
1"""Cost analysis mixin for Portfolio."""
3from __future__ import annotations
5import math
7import polars as pl
9from ._portfolio_base import _PortfolioMembers
10from .exceptions import InvalidMaxBpsError, NegativeCostBpsError
class PortfolioCostMixin(_PortfolioMembers):
    """Mixin providing cost analysis methods for Portfolio.

    The methods below read ``cashposition``, ``cost_per_unit``, ``cost_bps``,
    ``aum``, ``profit``, ``returns``, ``turnover`` and ``data`` from the host
    instance.
    """

    @property
    def position_delta_costs(self) -> pl.DataFrame:
        """Daily trading cost using the position-delta model.

        Computes the per-period cost as::

            cost_t = sum_i( |x_{i,t} - x_{i,t-1}| ) * cost_per_unit

        where ``x_{i,t}`` is the cash position in asset *i* at time *t* and
        ``cost_per_unit`` is the one-way cost per unit of traded notional.
        The first row is always zero because there is no prior position to
        form a difference against.

        Returns:
            pl.DataFrame: Frame with an optional ``'date'`` column and a
            ``'cost'`` column (absolute cash cost per period).

        Examples:
            >>> from jquantstats.portfolio import Portfolio
            >>> import polars as pl
            >>> from datetime import date
            >>> _d = [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 3)]
            >>> prices = pl.DataFrame({"date": _d, "A": [100.0, 110.0, 121.0]})
            >>> pos = pl.DataFrame({"date": _d, "A": [1000.0, 1200.0, 900.0]})
            >>> pf = Portfolio(prices=prices, cashposition=pos, aum=1e5, cost_per_unit=0.01)
            >>> pf.position_delta_costs["cost"].to_list()
            [0.0, 2.0, 3.0]
        """
        # Every numeric column other than "date" is treated as an asset.
        assets = [c for c in self.cashposition.columns if c != "date" and self.cashposition[c].dtype.is_numeric()]
        # diff() leaves a null in the first row; null/NaN changes count as "no trade".
        abs_position_changes = pl.sum_horizontal(
            pl.col(c).diff().abs().fill_null(0.0).fill_nan(0.0) for c in assets
        )
        daily_cost = (abs_position_changes * self.cost_per_unit).alias("cost")
        cols: list[str | pl.Expr] = []
        if "date" in self.cashposition.columns:
            cols.append("date")
        cols.append(daily_cost)
        return self.cashposition.select(cols)

    @property
    def net_cost_nav(self) -> pl.DataFrame:
        """Net-of-cost cumulative additive NAV using the position-delta cost model.

        Deducts `position_delta_costs` from daily portfolio profit and
        computes the running cumulative sum offset by AUM. The result
        represents the realised NAV path a strategy would achieve after paying
        ``cost_per_unit`` on every unit of position change.

        When ``cost_per_unit`` is zero the result equals `nav_accumulated`.

        Returns:
            pl.DataFrame: Frame with an optional ``'date'`` column,
            ``'profit'``, ``'cost'``, and ``'NAV_accumulated_net'`` columns.

        Examples:
            >>> from jquantstats.portfolio import Portfolio
            >>> import polars as pl
            >>> from datetime import date
            >>> _d = [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 3)]
            >>> prices = pl.DataFrame({"date": _d, "A": [100.0, 110.0, 121.0]})
            >>> pos = pl.DataFrame({"date": _d, "A": [1000.0, 1200.0, 900.0]})
            >>> pf = Portfolio(prices=prices, cashposition=pos, aum=1e5, cost_per_unit=0.0)
            >>> net = pf.net_cost_nav
            >>> list(net.columns)
            ['date', 'profit', 'cost', 'NAV_accumulated_net']
        """
        profit_df = self.profit
        cost_df = self.position_delta_costs
        if "date" in profit_df.columns:
            # Align cost to profit by date when a date axis is available.
            df = profit_df.join(cost_df, on="date", how="left")
        else:
            # No date axis: rows are paired positionally.
            df = profit_df.hstack(cost_df.select(["cost"]))
        return df.with_columns(((pl.col("profit") - pl.col("cost")).cum_sum() + self.aum).alias("NAV_accumulated_net"))

    def cost_adjusted_returns(self, cost_bps: float | None = None) -> pl.DataFrame:
        """Return daily portfolio returns net of estimated one-way trading costs.

        Trading costs are modelled as a linear function of daily one-way
        turnover: for every unit of AUM traded, the strategy incurs
        ``cost_bps`` basis points (i.e. ``cost_bps / 10_000`` fractional
        cost). The daily cost deduction is therefore::

            daily_cost = turnover * (cost_bps / 10_000)

        where ``turnover`` is the fraction-of-AUM one-way turnover already
        computed by `turnover`. The deduction is applied to the
        ``returns`` column of `returns`, leaving all other columns
        (including ``date``) untouched.

        Args:
            cost_bps: One-way trading cost in basis points per unit of AUM
                traded. Must be non-negative. Defaults to ``self.cost_bps``
                set at construction time.

        Returns:
            pl.DataFrame: Same schema as `returns` but with the
            ``returns`` column reduced by the per-period trading cost.

        Raises:
            TypeError: If ``cost_bps`` is not a number (booleans are rejected).
            ValueError: If ``cost_bps`` is not finite (NaN or infinity).
            NegativeCostBpsError: If ``cost_bps`` is negative.

        Examples:
            >>> from jquantstats.portfolio import Portfolio
            >>> import polars as pl
            >>> from datetime import date
            >>> _d = [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 3)]
            >>> prices = pl.DataFrame({"date": _d, "A": [100.0, 110.0, 121.0]})
            >>> pos = pl.DataFrame({"date": _d, "A": [1000.0, 1200.0, 900.0]})
            >>> pf = Portfolio(prices=prices, cashposition=pos, aum=1e5)
            >>> adj = pf.cost_adjusted_returns(0.0)
            >>> float(adj["returns"][1]) == float(pf.returns["returns"][1])
            True
        """
        effective_bps = cost_bps if cost_bps is not None else self.cost_bps
        # bool is a subclass of int, so it has to be excluded explicitly.
        if isinstance(effective_bps, bool) or not isinstance(effective_bps, int | float):
            raise TypeError(f"cost_bps must be a number, got {type(effective_bps).__name__}")  # noqa: TRY003
        effective_bps = float(effective_bps)
        if not math.isfinite(effective_bps):
            raise ValueError(f"cost_bps must be finite, got {effective_bps}")  # noqa: TRY003
        if effective_bps < 0:
            raise NegativeCostBpsError(effective_bps)
        base = self.returns
        daily_cost = self.turnover["turnover"] * (effective_bps / 10_000.0)
        return base.with_columns((pl.col("returns") - daily_cost).alias("returns"))

    def trading_cost_impact(self, max_bps: int = 20) -> pl.DataFrame:
        """Estimate the impact of trading costs on the Sharpe ratio.

        Computes the annualised Sharpe ratio of cost-adjusted returns for
        each integer cost level from 0 up to and including ``max_bps`` basis
        points (1 bp = 0.01 %). The result lets you quickly assess at what
        cost level the strategy's edge is eroded.

        Args:
            max_bps: Maximum one-way trading cost to evaluate, in basis
                points. Defaults to 20 (i.e., evaluates 0, 1, 2, …, 20
                bps). Must be a positive integer; booleans are rejected.

        Returns:
            pl.DataFrame: Frame with columns ``'cost_bps'`` (Int64) and
            ``'sharpe'`` (Float64), one row per cost level from 0 to
            ``max_bps`` inclusive.

        Raises:
            InvalidMaxBpsError: If ``max_bps`` is not a positive integer
                (``True``/``False`` are not accepted as integers).

        Examples:
            >>> from jquantstats.portfolio import Portfolio
            >>> import polars as pl
            >>> from datetime import date, timedelta
            >>> import numpy as np
            >>> start = date(2020, 1, 1)
            >>> dates = pl.date_range(
            ...     start=start, end=start + timedelta(days=99), interval="1d", eager=True
            ... )
            >>> rng = np.random.default_rng(0)
            >>> prices = pl.DataFrame({
            ...     "date": dates,
            ...     "A": pl.Series(np.cumprod(1 + rng.normal(0.001, 0.01, 100)) * 100),
            ... })
            >>> pos = pl.DataFrame({"date": dates, "A": pl.Series(np.ones(100) * 1000.0)})
            >>> pf = Portfolio(prices=prices, cashposition=pos, aum=1e5)
            >>> impact = pf.trading_cost_impact(max_bps=5)
            >>> list(impact["cost_bps"])
            [0, 1, 2, 3, 4, 5]
        """
        # bool is a subclass of int; reject it explicitly so that
        # trading_cost_impact(True) fails instead of sweeping 0..1 bps,
        # matching the bool handling in cost_adjusted_returns.
        if isinstance(max_bps, bool) or not isinstance(max_bps, int) or max_bps < 1:
            raise InvalidMaxBpsError(max_bps)

        # Kept as a function-local import, as in the original
        # (presumably to avoid an import cycle — confirm).
        from ._stats._core import _std_is_negligible

        periods = self.data._periods_per_year  # read once, outside the loop
        sqrt_periods = math.sqrt(periods)
        cost_levels = list(range(max_bps + 1))

        # Extract base returns and turnover once, independent of max_bps.
        base_rets = self.returns["returns"]
        turnover_s = self.turnover["turnover"]

        # One column per cost level, then a single mean pass and a single std pass.
        sweep = pl.DataFrame({str(bps): base_rets - turnover_s * (bps / 10_000.0) for bps in cost_levels})
        means_row = sweep.mean().row(0)
        stds_row = sweep.std(ddof=1).row(0)

        sharpe_values: list[float] = []
        # Both rows come from the same frame, so their lengths must agree;
        # strict=True surfaces any mismatch instead of silently truncating.
        for mean_raw, std_raw in zip(means_row, stds_row, strict=True):
            mean_val = 0.0 if mean_raw is None else float(mean_raw)
            # NOTE(review): presumably _std_is_negligible also covers a missing
            # (None) std — confirm against its definition.
            if _std_is_negligible(std_raw, mean_val):
                sharpe_values.append(float("nan"))
            else:
                sharpe_values.append(mean_val / float(std_raw) * sqrt_periods)
        return pl.DataFrame({"cost_bps": pl.Series(cost_levels, dtype=pl.Int64), "sharpe": pl.Series(sharpe_values)})