Coverage for src / basanos / analytics / portfolio.py: 100%
208 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-19 05:23 +0000
1"""Portfolio utilities for computing profits, NAV, and Sharpe using Polars and Plotly.
3This module provides a Portfolio dataclass and helpers to compute per-asset profits,
4aggregate portfolio profit, NAV, Sharpe ratio, and produce Plotly visualizations.
5"""
7import dataclasses
9import polars as pl
10import polars.selectors as cs
12from ..exceptions import (
13 IntegerIndexBoundError,
14 InvalidCashPositionTypeError,
15 InvalidPricesTypeError,
16 MissingDateColumnError,
17 NonPositiveAumError,
18 RowCountMismatchError,
19)
20from ._plots import Plots
21from ._report import Report
22from ._stats import Stats
@dataclasses.dataclass(frozen=True)
class Portfolio:
    """Store prices, positions, and compute portfolio statistics.

    Attributes:
        cashposition: Polars DataFrame of positions per asset over time (includes date column if present).
        prices: Polars DataFrame of prices per asset over time (includes date column if present).
        aum: Assets under management used as base NAV offset.

    Examples:
        >>> import polars as pl
        >>> from datetime import date
        >>> prices = pl.DataFrame({"date": [date(2020, 1, 1), date(2020, 1, 2)], "A": [100.0, 110.0]})
        >>> pos = pl.DataFrame({"date": [date(2020, 1, 1), date(2020, 1, 2)], "A": [1000.0, 1000.0]})
        >>> pf = Portfolio(prices=prices, cashposition=pos)
        >>> pf.assets
        ['A']
    """

    # Cash exposure per asset over time; row count is validated against prices in __post_init__.
    cashposition: pl.DataFrame
    # Price levels per asset over time; its numeric columns define the asset universe.
    prices: pl.DataFrame
    # Base NAV offset; must be strictly positive (validated in __post_init__).
    aum: float = 1e8
49 def __post_init__(self) -> None:
50 """Validate input types, shapes, and parameters post-initialization."""
51 # Input validation
52 if not isinstance(self.prices, pl.DataFrame):
53 raise InvalidPricesTypeError(type(self.prices).__name__)
54 if not isinstance(self.cashposition, pl.DataFrame):
55 raise InvalidCashPositionTypeError(type(self.cashposition).__name__)
57 if self.cashposition.shape[0] != self.prices.shape[0]:
58 raise RowCountMismatchError(self.prices.shape[0], self.cashposition.shape[0])
59 if self.aum <= 0.0:
60 raise NonPositiveAumError(self.aum)
62 @classmethod
63 def from_risk_position(
64 cls, prices: pl.DataFrame, risk_position: pl.DataFrame, vola: int = 32, aum: float = 1e8
65 ) -> "Portfolio":
66 """Create a Portfolio from per-asset risk positions by de-volatizing with EWMA volatility.
68 Args:
69 prices: Price levels per asset over time (may include a date column).
70 risk_position: Risk units per asset (e.g., target risk exposure) aligned with prices.
71 vola: EWMA lookback (span-equivalent) used to estimate volatility in trading days.
72 aum: Assets under management used as the base NAV offset.
74 Returns:
75 A Portfolio instance whose cash positions are risk_position divided by EWMA volatility.
76 """
77 assets = [col for col, dtype in prices.schema.items() if dtype.is_numeric()]
79 def vol(col_name: str, vola: int) -> pl.Expr: # pragma: no cover
80 """Return an EWMA volatility expression for the given column and lookback."""
81 return pl.col(col_name).pct_change().ewm_std(com=vola - 1, adjust=True, min_samples=vola)
83 # Join prices to risk_position to compute volatility from price data
84 cash_position = risk_position.with_columns(
85 (pl.col(asset) / prices[asset].pct_change().ewm_std(com=vola - 1, adjust=True, min_samples=vola)).alias(
86 asset
87 )
88 for asset in assets
89 )
90 return cls(prices=prices, cashposition=cash_position, aum=aum)
92 @classmethod
93 def from_cash_position(cls, prices: pl.DataFrame, cash_position: pl.DataFrame, aum: float = 1e8) -> "Portfolio":
94 """Create a Portfolio directly from cash positions aligned with prices.
96 Args:
97 prices: Price levels per asset over time (may include a date column).
98 cash_position: Cash exposure per asset over time (same shape/index as prices).
99 aum: Assets under management used as the base NAV offset.
101 Returns:
102 A Portfolio instance with the provided cash positions.
103 """
104 return cls(prices=prices, cashposition=cash_position, aum=aum)
106 @property
107 def profits(self) -> pl.DataFrame:
108 """Compute per-asset daily cash profits, preserving non-numeric columns.
110 Returns:
111 pl.DataFrame: Per-asset daily profit series along with any non-numeric
112 columns (e.g., 'date').
114 Examples:
115 >>> import polars as pl
116 >>> prices = pl.DataFrame({"A": [100.0, 110.0, 105.0]})
117 >>> pos = pl.DataFrame({"A": [1000.0, 1000.0, 1000.0]})
118 >>> pf = Portfolio(prices=prices, cashposition=pos)
119 >>> pf.profits.columns
120 ['A']
121 """
122 assets = [c for c in self.prices.columns if self.prices[c].dtype.is_numeric()]
124 # Compute daily profits per numeric column
125 profits = self.prices.with_columns(
126 (self.prices[asset].pct_change().fill_null(0.0) * self.cashposition[asset].shift(n=1).fill_null(0.0)).alias(
127 asset
128 )
129 for asset in assets
130 )
132 # Ensure there are no Nulls/NaNs/Infs in numeric profit columns
133 # - Fill nulls with 0.0 (should already be handled above, but double-guard)
134 # - Replace non-finite values (NaN/Inf) with 0.0
135 if assets:
136 profits = profits.with_columns(
137 pl.when(pl.col(c).is_finite()).then(pl.col(c)).otherwise(0.0).fill_null(0.0).alias(c) for c in assets
138 )
139 # Guards to guarantee cleanliness
140 for c in assets:
141 s = profits[c]
142 if int(s.null_count()) != 0:
143 raise ValueError # pragma: no cover
144 if not bool(pl.Series(s).is_finite().all()):
145 raise ValueError # pragma: no cover
147 return profits
149 @staticmethod
150 def _assert_clean_series(series: pl.Series, name: str = "") -> None:
151 """Raise ValueError if the series contains nulls or non-finite values."""
152 if series.null_count() != 0:
153 raise ValueError
154 if not series.is_finite().all():
155 raise ValueError
157 @property
158 def profit(self) -> pl.DataFrame:
159 """Return total daily portfolio profit including the 'date' column.
161 Ensures that no day's total profit is NaN/null by asserting the
162 'profit' column has zero nulls.
163 """
164 df_profits = self.profits
165 assets = [c for c in df_profits.columns if df_profits[c].dtype.is_numeric()]
167 if not assets:
168 raise ValueError
170 non_assets = [c for c in df_profits.columns if c not in set(assets)]
171 # numeric_cols, non_numeric_cols = split_numeric_non_numeric(df_profits)
173 # Row-wise sum of numeric columns
174 portfolio_daily_profit = pl.sum_horizontal([pl.col(c).fill_null(0.0) for c in assets]).alias("profit")
176 # Combine with non-numeric columns (like 'date')
177 result = df_profits.select([*non_assets, portfolio_daily_profit])
179 # Guard: profit must not contain NaN/null values
180 # Use null_count to cover both nulls and NaNs (Polars treats NaNs as not-null but we ensure
181 # inputs are numeric and filled; additional check for finite values guards against NaN/Inf)
182 self._assert_clean_series(series=result["profit"])
184 return result
186 @property
187 def nav_accumulated(self) -> pl.DataFrame:
188 """Compute cumulative NAV of the portfolio including 'date'."""
189 # Compute cumulative sum of profit column and expose as 'NAV'
190 return self.profit.with_columns((pl.col("profit").cum_sum() + self.aum).alias("NAV_accumulated"))
    @property
    def returns(self) -> pl.DataFrame:
        """Return the NAV frame extended with a 'returns' column.

        Builds on ``nav_accumulated`` (which carries 'date' when present,
        'profit', and 'NAV_accumulated') and appends a 'returns' column equal
        to the per-period profit divided by AUM. The existing columns —
        including 'profit' — are left unchanged.
        """
        return self.nav_accumulated.with_columns(
            (pl.col("profit") / self.aum).alias("returns"),
        )
    @property
    def monthly(self) -> pl.DataFrame:
        """Return monthly compounded returns and calendar columns.

        Aggregates daily returns (profit/AUM) by calendar month and computes
        the compounded monthly return: prod(1 + r_d) - 1. The resulting frame
        includes:
        - date: month label as a Polars Date; because ``label="left"`` is used
          below, this is the start of the monthly grouping window
        - returns: compounded monthly return
        - NAV_accumulated: last NAV within the month
        - profit: summed profit within the month
        - year: integer year (e.g., 2020)
        - month: integer month number (1-12)
        - month_name: abbreviated month name (e.g., "Jan", "Feb")

        Raises:
            MissingDateColumnError: If the portfolio data has no 'date' column.
                Monthly aggregation requires temporal date information.
        """
        if "date" not in self.prices.columns:
            raise MissingDateColumnError("monthly")
        daily = self.returns.select(["date", "returns", "profit", "NAV_accumulated"])  # ensure only required columns
        monthly = (
            daily.group_by_dynamic(
                "date",
                every="1mo",
                period="1mo",
                label="left",
                closed="right",
            )
            .agg(
                [
                    pl.col("profit").sum().alias("profit"),
                    pl.col("NAV_accumulated").last().alias("NAV_accumulated"),
                    (pl.col("returns") + 1.0).product().alias("gross"),
                ]
            )
            .with_columns((pl.col("gross") - 1.0).alias("returns"))
            .select(["date", "returns", "NAV_accumulated", "profit"])  # keep window-start label as 'date'
            .with_columns(
                [
                    pl.col("date").dt.year().alias("year"),
                    pl.col("date").dt.month().alias("month"),
                    pl.col("date").dt.strftime("%b").alias("month_name"),
                ]
            )
            .sort("date")
        )
        return monthly
255 @property
256 def nav_compounded(self) -> pl.DataFrame:
257 """Compute compounded NAV from returns (profit/AUM), preserving 'date'."""
258 # self.returns contains 'date' and scaled 'profit' (i.e., returns)
259 return self.returns.with_columns(((pl.col("returns") + 1.0).cum_prod() * self.aum).alias("NAV_compounded"))
261 @property
262 def highwater(self) -> pl.DataFrame:
263 """Return the cumulative maximum of NAV as the high-water mark series.
265 The resulting DataFrame preserves the 'date' column and adds a
266 'highwater' column computed as the cumulative maximum of
267 'NAV_accumulated'.
268 """
269 return self.returns.with_columns(pl.col("NAV_accumulated").cum_max().alias("highwater"))
271 @property
272 def drawdown(self) -> pl.DataFrame:
273 """Return drawdown as the distance from high-water mark to current NAV.
275 Computes 'drawdown' = 'highwater' - 'NAV_accumulated' and preserves the
276 'date' column alongside the intermediate columns.
277 """
278 return self.highwater.with_columns(
279 (pl.col("highwater") - pl.col("NAV_accumulated")).alias("drawdown"),
280 ((pl.col("highwater") - pl.col("NAV_accumulated")) / pl.col("highwater")).alias("drawdown_pct"),
281 )
283 @property
284 def all(self) -> pl.DataFrame:
285 """Return a merged view of drawdown and compounded NAV.
287 When a 'date' column is present the two frames are joined on that
288 column to ensure temporal alignment. When the data is integer-indexed
289 (no 'date' column) the frames are stacked horizontally - they are
290 guaranteed to have identical row counts because both are derived from
291 the same source portfolio.
292 """
293 # Start with drawdown (includes date, NAV_accumulated, highwater, drawdown, drawdown_pct, etc.)
294 left = self.drawdown
295 # From nav_compounded, only take the additional compounded NAV column to avoid duplicate fields
296 if "date" in left.columns:
297 right = self.nav_compounded.select(["date", "NAV_compounded"])
298 return left.join(right, on="date", how="inner")
299 else:
300 right = self.nav_compounded.select(["NAV_compounded"])
301 return left.hstack(right)
303 @property
304 def stats(self) -> Stats:
305 """Return a Stats object built from the portfolio's daily returns.
307 Constructs a basanos.analytics.Stats instance from the portfolio
308 returns. When a 'date' column is present both 'date' and 'returns'
309 are passed to Stats; otherwise only 'returns' is used.
310 """
311 cols = ["date", "returns"] if "date" in self.returns.columns else ["returns"]
312 return Stats(data=self.returns.select(cols))
314 def truncate(self, start: object = None, end: object = None) -> "Portfolio":
315 """Return a new Portfolio truncated to the inclusive [start, end] range.
317 When a 'date' column is present in both prices and cash positions,
318 truncation is performed by comparing the 'date' column against
319 ``start`` and ``end`` (which should be date/datetime values or strings
320 parseable by Polars).
322 When the 'date' column is absent, integer-based row slicing is used
323 instead. In this case ``start`` and ``end`` must be non-negative
324 integers representing 0-based row indices. Passing non-integer bounds
325 to an integer-indexed portfolio raises ``TypeError``.
327 In all cases the ``aum`` value is preserved.
329 Args:
330 start: Optional lower bound (inclusive). A date/datetime or
331 Polars-parseable string when a 'date' column exists; a
332 non-negative int row index when the data has no 'date' column.
333 end: Optional upper bound (inclusive). Same type rules as ``start``.
335 Returns:
336 A new Portfolio instance with prices and cash positions filtered to
337 the specified range.
339 Raises:
340 TypeError: When the portfolio has no 'date' column and a non-integer
341 bound is supplied.
342 """
343 has_date = "date" in self.prices.columns
344 if has_date:
345 cond = pl.lit(True)
346 if start is not None:
347 cond = cond & (pl.col("date") >= pl.lit(start))
348 if end is not None:
349 cond = cond & (pl.col("date") <= pl.lit(end))
350 pr = self.prices.filter(cond)
351 cp = self.cashposition.filter(cond)
352 else:
353 # Integer row-index slicing for date-free portfolios
354 if start is not None and not isinstance(start, int):
355 raise IntegerIndexBoundError("start", type(start).__name__)
356 if end is not None and not isinstance(end, int):
357 raise IntegerIndexBoundError("end", type(end).__name__)
358 row_start = int(start) if start is not None else 0
359 row_end = int(end) + 1 if end is not None else self.prices.height
360 length = max(0, row_end - row_start)
361 pr = self.prices.slice(row_start, length)
362 cp = self.cashposition.slice(row_start, length)
363 return Portfolio(prices=pr, cashposition=cp, aum=self.aum)
365 def lag(self, n: int) -> "Portfolio":
366 """Return a new Portfolio with cash positions lagged by ``n`` steps.
368 This method shifts the numeric asset columns in the cashposition
369 DataFrame by ``n`` rows, preserving the 'date' column and any
370 non-numeric columns unchanged. Positive ``n`` delays weights
371 (moves them down); negative ``n`` leads them (moves them up);
372 ``n == 0`` returns the current portfolio unchanged.
374 Notes:
375 - Missing values introduced by the shift are left as nulls;
376 downstream profit computation already guards and treats
377 nulls as zero when multiplying by returns.
379 Args:
380 n: Number of rows to shift (can be negative, zero, or positive).
382 Returns:
383 A new Portfolio instance with lagged cash positions and the same
384 prices/AUM as the original.
385 """
386 if not isinstance(n, int):
387 raise TypeError
388 if n == 0:
389 return self
391 # Identify numeric asset columns (exclude 'date')
392 assets = [c for c in self.cashposition.columns if c != "date" and self.cashposition[c].dtype.is_numeric()]
394 # Shift numeric columns by n; keep others as-is
395 cp_lagged = self.cashposition.with_columns(pl.col(c).shift(n) for c in assets)
396 return Portfolio(prices=self.prices, cashposition=cp_lagged, aum=self.aum)
398 def smoothed_holding(self, n: int) -> "Portfolio":
399 """Return a new Portfolio with cash positions smoothed by a rolling mean.
401 Applies a trailing window average over the last ``n`` steps for each
402 numeric asset column (excluding 'date'). The window length is ``n + 1``
403 so that:
404 - n=0 returns the original weights (no smoothing),
405 - n=1 averages the current and previous weights,
406 - n=k averages the current and last k weights.
408 Args:
409 n: Non-negative integer specifying how many previous steps to include.
411 Returns:
412 A new Portfolio with smoothed cash positions and the same prices/AUM.
413 """
414 if not isinstance(n, int):
415 raise TypeError
416 if n < 0:
417 raise ValueError
418 if n == 0:
419 return self
421 # Identify numeric asset columns (exclude 'date')
422 assets = [c for c in self.cashposition.columns if c != "date" and self.cashposition[c].dtype.is_numeric()]
424 window = n + 1
425 # Apply rolling mean per numeric asset column; keep others unchanged
426 cp_smoothed = self.cashposition.with_columns(
427 pl.col(c).rolling_mean(window_size=window, min_samples=1).alias(c) for c in assets
428 )
429 return Portfolio(prices=self.prices, cashposition=cp_smoothed, aum=self.aum)
431 @property
432 def plots(self) -> Plots:
433 """Convenience accessor returning a Plots facade for this portfolio.
435 Use this to create Plotly visualizations such as snapshots, lagged
436 performance curves, and lead/lag IR charts.
438 Returns:
439 basanos.analytics._plots.Plots: Helper object with plotting methods.
440 """
441 return Plots(self)
443 @property
444 def report(self) -> Report:
445 """Convenience accessor returning a Report facade for this portfolio.
447 Use this to generate a self-contained HTML performance report
448 containing statistics tables and interactive charts.
450 Returns:
451 basanos.analytics._report.Report: Helper object with report methods.
452 """
453 return Report(self)
455 @property
456 def assets(self) -> list[str]:
457 """List the asset column names from prices (numeric columns).
459 Returns:
460 list[str]: Names of numeric columns in prices; typically excludes 'date'.
461 """
462 return [c for c in self.prices.columns if self.prices[c].dtype.is_numeric()]
464 @property
465 def tilt(self) -> "Portfolio":
466 """Return the 'tilt' portfolio with constant average weights.
468 Computes the time-average of each asset's cash position (ignoring nulls/NaNs)
469 and builds a new Portfolio with those constant weights applied across time.
470 Prices and AUM are preserved.
471 """
472 const_position = self.cashposition.with_columns(
473 pl.col(col).drop_nulls().drop_nans().mean().alias(col) for col in self.assets
474 )
476 return Portfolio.from_cash_position(self.prices, const_position, aum=self.aum)
478 @property
479 def timing(self) -> "Portfolio":
480 """Return the 'timing' portfolio capturing deviations from the tilt.
482 Constructs weights as original cash positions minus the tilt's constant
483 positions, per asset. This isolates timing (alloc-demeaned) effects.
484 Prices and AUM are preserved.
485 """
486 const_position = self.tilt.cashposition
487 # subtracting frames is subtle as it would also try to subtract the date column
488 position = self.cashposition.with_columns((pl.col(col) - const_position[col]).alias(col) for col in self.assets)
489 return Portfolio.from_cash_position(self.prices, position, aum=self.aum)
    @property
    def tilt_timing_decomp(self) -> pl.DataFrame:
        """Return the portfolio's tilt/timing NAV decomposition.

        Produces a frame with columns 'portfolio', 'tilt', and 'timing' (plus
        'date' when available), each holding the accumulated NAV of the
        corresponding portfolio. With a 'date' column the three NAV series are
        joined on it; for integer-indexed data the frames are stacked
        horizontally.
        """
        if "date" in self.nav_accumulated.columns:
            nav_portfolio = self.nav_accumulated.select(["date", "NAV_accumulated"])
            nav_tilt = self.tilt.nav_accumulated.select(["date", "NAV_accumulated"])
            nav_timing = self.timing.nav_accumulated.select(["date", "NAV_accumulated"])

            # Join all three DataFrames on the 'date' column; the suffixes
            # disambiguate the duplicated 'NAV_accumulated' columns.
            merged_df = nav_portfolio.join(nav_tilt, on="date", how="inner", suffix="_tilt").join(
                nav_timing, on="date", how="inner", suffix="_timing"
            )
        else:
            # No date column: rename up front so the horizontal stack has
            # unique column names matching the join-suffix convention above.
            nav_portfolio = self.nav_accumulated.select(["NAV_accumulated"])
            nav_tilt = self.tilt.nav_accumulated.select(["NAV_accumulated"]).rename(
                {"NAV_accumulated": "NAV_accumulated_tilt"}
            )
            nav_timing = self.timing.nav_accumulated.select(["NAV_accumulated"]).rename(
                {"NAV_accumulated": "NAV_accumulated_timing"}
            )
            merged_df = nav_portfolio.hstack(nav_tilt).hstack(nav_timing)

        # Final user-facing column names.
        merged_df = merged_df.rename(
            {"NAV_accumulated_tilt": "tilt", "NAV_accumulated_timing": "timing", "NAV_accumulated": "portfolio"}
        )

        return merged_df
523 @property
524 def turnover(self) -> pl.DataFrame:
525 """Daily one-way portfolio turnover as a fraction of AUM.
527 Computes the sum of absolute position changes across all assets for each
528 period, normalised by AUM. The first row is always zero because there is
529 no prior position to form a difference against.
531 Returns:
532 pl.DataFrame: Frame with an optional ``'date'`` column and a
533 ``'turnover'`` column (dimensionless fraction of AUM).
535 Examples:
536 >>> import polars as pl
537 >>> from datetime import date
538 >>> _d = [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 3)]
539 >>> prices = pl.DataFrame({"date": _d, "A": [100.0, 110.0, 121.0]})
540 >>> pos = pl.DataFrame({"date": prices["date"], "A": [1000.0, 1200.0, 900.0]})
541 >>> pf = Portfolio(prices=prices, cashposition=pos, aum=1e5)
542 >>> pf.turnover["turnover"].to_list()
543 [0.0, 0.002, 0.003]
544 """
545 assets = [c for c in self.cashposition.columns if c != "date" and self.cashposition[c].dtype.is_numeric()]
546 daily_abs_chg = (pl.sum_horizontal(pl.col(c).diff().abs().fill_null(0.0) for c in assets) / self.aum).alias(
547 "turnover"
548 )
549 cols: list[str | pl.Expr] = []
550 if "date" in self.cashposition.columns:
551 cols.append("date")
552 cols.append(daily_abs_chg)
553 return self.cashposition.select(cols)
    @property
    def turnover_weekly(self) -> pl.DataFrame:
        """Weekly aggregated one-way portfolio turnover as a fraction of AUM.

        When a temporal ``'date'`` column is present, sums the daily turnover
        within each calendar week (``group_by_dynamic`` with a one-week
        window). Without a temporal date column, a rolling 5-period sum with
        ``min_samples=5`` is returned instead (the first four rows will be
        ``null``) — so this property never raises for date-free portfolios.

        Returns:
            pl.DataFrame: Frame with an optional ``'date'`` column (week start)
            and a ``'turnover'`` column (fraction of AUM, summed over the week).
        """
        daily = self.turnover
        if "date" not in daily.columns or not daily["date"].dtype.is_temporal():
            return daily.with_columns(pl.col("turnover").rolling_sum(window_size=5, min_samples=5))
        return daily.group_by_dynamic("date", every="1w").agg(pl.col("turnover").sum()).sort("date")
577 def turnover_summary(self) -> pl.DataFrame:
578 """Return a summary DataFrame of turnover statistics.
580 Computes three metrics from the daily turnover series:
582 - ``mean_daily_turnover``: mean of daily one-way turnover (fraction of AUM).
583 - ``mean_weekly_turnover``: mean of weekly-aggregated turnover (fraction of AUM).
584 - ``turnover_std``: standard deviation of daily turnover (fraction of AUM);
585 complements the mean to detect regime switches.
587 Returns:
588 pl.DataFrame: One row per metric with columns ``'metric'`` and
589 ``'value'``.
591 Examples:
592 >>> import polars as pl
593 >>> from datetime import date, timedelta
594 >>> import numpy as np
595 >>> start = date(2020, 1, 1)
596 >>> dates = pl.date_range(start=start, end=start + timedelta(days=9), interval="1d", eager=True)
597 >>> prices = pl.DataFrame({"date": dates, "A": pl.Series(np.ones(10) * 100.0)})
598 >>> pos = pl.DataFrame({"date": dates, "A": pl.Series([float(i) * 100 for i in range(10)])})
599 >>> pf = Portfolio(prices=prices, cashposition=pos, aum=1e4)
600 >>> summary = pf.turnover_summary()
601 >>> list(summary["metric"])
602 ['mean_daily_turnover', 'mean_weekly_turnover', 'turnover_std']
603 """
604 daily_col = self.turnover["turnover"]
605 _mean = daily_col.mean()
606 mean_daily = float(_mean) if isinstance(_mean, (int, float)) else 0.0
607 _std = daily_col.std()
608 std_daily = float(_std) if isinstance(_std, (int, float)) else 0.0
609 weekly_col = self.turnover_weekly["turnover"].drop_nulls()
610 _weekly_mean = weekly_col.mean()
611 mean_weekly = (
612 float(_weekly_mean) if weekly_col.len() > 0 and isinstance(_weekly_mean, (int, float)) else float("nan")
613 )
614 return pl.DataFrame(
615 {
616 "metric": ["mean_daily_turnover", "mean_weekly_turnover", "turnover_std"],
617 "value": [mean_daily, mean_weekly, std_daily],
618 }
619 )
621 def cost_adjusted_returns(self, cost_bps: float) -> pl.DataFrame:
622 """Return daily portfolio returns net of estimated one-way trading costs.
624 Trading costs are modelled as a linear function of daily one-way
625 turnover: for every unit of AUM traded, the strategy incurs
626 ``cost_bps`` basis points (i.e. ``cost_bps / 10_000`` fractional cost).
627 The daily cost deduction is therefore::
629 daily_cost = turnover * (cost_bps / 10_000)
631 where ``turnover`` is the fraction-of-AUM one-way turnover already
632 computed by :py:attr:`turnover`. The deduction is applied to the
633 ``returns`` column of :py:attr:`returns`, leaving all other columns
634 (including ``date``) untouched.
636 Args:
637 cost_bps: One-way trading cost in basis points per unit of AUM
638 traded. Must be non-negative.
640 Returns:
641 pl.DataFrame: Same schema as :py:attr:`returns` but with the
642 ``returns`` column reduced by the per-period trading cost.
644 Raises:
645 ValueError: If ``cost_bps`` is negative.
647 Examples:
648 >>> import polars as pl
649 >>> from datetime import date
650 >>> _d = [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 3)]
651 >>> prices = pl.DataFrame({"date": _d, "A": [100.0, 110.0, 121.0]})
652 >>> pos = pl.DataFrame({"date": _d, "A": [1000.0, 1200.0, 900.0]})
653 >>> pf = Portfolio(prices=prices, cashposition=pos, aum=1e5)
654 >>> adj = pf.cost_adjusted_returns(0.0)
655 >>> float(adj["returns"][1]) == float(pf.returns["returns"][1])
656 True
657 """
658 if cost_bps < 0:
659 raise ValueError
660 base = self.returns
661 daily_cost = self.turnover["turnover"] * (cost_bps / 10_000.0)
662 return base.with_columns((pl.col("returns") - daily_cost).alias("returns"))
664 def trading_cost_impact(self, max_bps: int = 20) -> pl.DataFrame:
665 """Estimate the impact of trading costs on the Sharpe ratio.
667 Computes the annualised Sharpe ratio of cost-adjusted returns for
668 each integer cost level from 0 up to and including ``max_bps``
669 basis points (1 bp = 0.01 %). The result lets you quickly assess at
670 what cost level the strategy's edge is eroded.
672 Args:
673 max_bps: Maximum one-way trading cost to evaluate, in basis
674 points. Defaults to 20 (i.e., evaluates 0, 1, 2, …, 20 bps).
675 Must be a positive integer.
677 Returns:
678 pl.DataFrame: Frame with columns ``'cost_bps'`` (Int64) and
679 ``'sharpe'`` (Float64), one row per cost level from 0 to
680 ``max_bps`` inclusive.
682 Raises:
683 ValueError: If ``max_bps`` is not a positive integer.
685 Examples:
686 >>> import polars as pl
687 >>> from datetime import date, timedelta
688 >>> import numpy as np
689 >>> start = date(2020, 1, 1)
690 >>> dates = pl.date_range(
691 ... start=start, end=start + timedelta(days=99), interval="1d", eager=True
692 ... )
693 >>> rng = np.random.default_rng(0)
694 >>> prices = pl.DataFrame({
695 ... "date": dates,
696 ... "A": pl.Series(np.cumprod(1 + rng.normal(0.001, 0.01, 100)) * 100),
697 ... })
698 >>> pos = pl.DataFrame({"date": dates, "A": pl.Series(np.ones(100) * 1000.0)})
699 >>> pf = Portfolio(prices=prices, cashposition=pos, aum=1e5)
700 >>> impact = pf.trading_cost_impact(max_bps=5)
701 >>> list(impact["cost_bps"])
702 [0, 1, 2, 3, 4, 5]
703 """
704 if not isinstance(max_bps, int) or max_bps < 1:
705 raise ValueError
706 cost_levels = list(range(0, max_bps + 1))
707 sharpe_values: list[float] = []
708 for bps in cost_levels:
709 adj = self.cost_adjusted_returns(float(bps))
710 cols = ["date", "returns"] if "date" in adj.columns else ["returns"]
711 sharpe_val = Stats(data=adj.select(cols)).sharpe().get("returns", float("nan"))
712 sharpe_values.append(float(sharpe_val) if sharpe_val is not None else float("nan"))
713 return pl.DataFrame({"cost_bps": pl.Series(cost_levels, dtype=pl.Int64), "sharpe": pl.Series(sharpe_values)})
715 def correlation(self, frame: pl.DataFrame, name: str = "portfolio") -> pl.DataFrame:
716 """Compute a correlation matrix of asset returns plus the portfolio.
718 Computes percentage changes for all numeric columns in ``frame``,
719 appends the portfolio profit series under the provided ``name``, and
720 returns the Pearson correlation matrix across all numeric columns.
722 Args:
723 frame: A Polars DataFrame containing at least the asset price
724 columns (and a date column which will be ignored if non-numeric).
725 name: The column name to use when adding the portfolio profit
726 series to the input frame.
728 Returns:
729 A square Polars DataFrame where each cell is the correlation
730 between a pair of series (values in [-1, 1]).
731 """
732 # 1. Compute percentage change for all float columns
733 p = frame.with_columns(cs.by_dtype(pl.Float32, pl.Float64).pct_change())
735 # 2. Add the portfolio column from self.profit["profit"]
736 p = p.with_columns(pl.Series(name, self.profit["profit"]))
738 # 3. Compute correlation matrix
739 corr_matrix = p.select(cs.numeric()).fill_null(0.0).corr()
741 return corr_matrix