Coverage for src/jquantstats/_plots/_portfolio.py: 100%
166 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-23 06:13 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-23 06:13 +0000
"""Plotting utilities for portfolio analytics using Plotly.

This module defines the PortfolioPlots facade, which renders common portfolio
visuals such as NAV/drawdown snapshots, lagged and smoothed-holdings NAV
curves, lead/lag information-ratio bar charts, rolling Sharpe and volatility
lines, annual Sharpe bars, correlation and monthly-returns heatmaps, and
trading-cost impact curves. Designed for notebook use.
"""
from __future__ import annotations

import math
from typing import TYPE_CHECKING

import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
import polars as pl
from plotly.subplots import make_subplots

from ._data import _apply_base_layout

if TYPE_CHECKING:
    from ._protocol import PortfolioLike
# Ensure Plotly figures render in Marimo notebooks by selecting the
# "plotly_mimetype" renderer. This is an import-time side effect: it changes
# the global default renderer of plotly.io, so it affects every Plotly figure
# shown afterwards, not only the ones built by this module. It is set after
# the imports so that linters do not flag module-level code before imports.
pio.renderers.default = "plotly_mimetype"
class PortfolioPlots:
    """Facade for portfolio plots built with Plotly.

    Provides convenience methods to visualize portfolio performance and
    diagnostics directly from a Portfolio instance (e.g., snapshot charts,
    lagged performance, smoothed holdings, lead/lag IR, rolling and annual
    statistics, correlation/monthly heatmaps, and trading-cost impact).
    """

    __slots__ = ("_portfolio",)

    def __init__(self, portfolio: PortfolioLike) -> None:
        """Store *portfolio*; every plot method queries it when called."""
        self._portfolio = portfolio

    def lead_lag_ir_plot(self, start: int = -10, end: int = 19) -> go.Figure:
        """Plot Sharpe ratio (IR) across lead/lag variants of the portfolio.

        Builds portfolios with cash positions lagged from ``start`` to ``end``
        (inclusive) and plots a bar chart of the Sharpe ratio for each lag.
        Positive lags delay weights; negative lags lead them. If ``start`` is
        greater than ``end`` the two bounds are swapped.

        Args:
            start: First lag to include (default: -10).
            end: Last lag to include (default: +19).

        Returns:
            A Plotly Figure with one bar per lag labeled by the lag value; the
            lag-0 bar is drawn in red, all others in blue.

        Raises:
            TypeError: If ``start`` or ``end`` is not an integer.
        """
        if not isinstance(start, int) or not isinstance(end, int):
            raise TypeError(f"start and end must be integers, got {start!r} and {end!r}")  # noqa: TRY003
        if start > end:
            start, end = end, start

        lags = list(range(start, end + 1))

        sharpe_by_lag: list[float] = []
        for n in lags:
            # Lag 0 is the portfolio itself; no need to rebuild it.
            pf = self._portfolio if n == 0 else self._portfolio.lag(n)
            # Stats.sharpe() returns a mapping column -> value; read the "returns" entry.
            sharpe_val = pf.stats.sharpe().get("returns", float("nan"))
            sharpe_by_lag.append(float(sharpe_val) if sharpe_val is not None else float("nan"))

        colors = ["red" if n == 0 else "#1f77b4" for n in lags]
        fig = go.Figure(
            data=[
                go.Bar(x=lags, y=sharpe_by_lag, name="Sharpe by lag", marker_color=colors),
            ]
        )
        fig.update_layout(
            title="Lead/Lag Information Ratio (Sharpe) by Lag",
            xaxis_title="Lag (steps)",
            yaxis_title="Sharpe ratio",
            plot_bgcolor="white",
            hovermode="x",
        )
        fig.update_xaxes(showgrid=True, gridwidth=0.5, gridcolor="lightgrey")
        fig.update_yaxes(showgrid=True, gridwidth=0.5, gridcolor="lightgrey")
        return fig

    def snapshot(self, log_scale: bool = False) -> go.Figure:
        """Return a snapshot dashboard of NAV and drawdown.

        When the portfolio has a non-zero ``cost_model.cost_per_unit``, an additional
        ``"Net-of-Cost NAV"`` trace is overlaid on the NAV panel showing the
        realised NAV path after deducting position-delta trading costs.

        Args:
            log_scale (bool, optional): If True, display NAV on a log scale. Defaults to False.

        Returns:
            plotly.graph_objects.Figure: A Figure with accumulated NAV (including tilt/timing)
            and drawdown shaded area, equipped with a range selector.
        """
        pf = self._portfolio

        # Two stacked panels sharing the x-axis: NAV curves on top, drawdown below.
        fig = make_subplots(
            rows=2,
            cols=1,
            shared_xaxes=True,
            row_heights=[0.66, 0.33],
            subplot_titles=["Accumulated Profit", "Drawdown"],
            vertical_spacing=0.05,
        )

        # --- Row 1: accumulated NAV of the portfolio and its tilt/timing variants.
        # Each property is read once per curve instead of once per axis.
        nav_curves = (
            ("NAV", pf.nav_accumulated),
            ("Tilt", pf.tilt.nav_accumulated),
            ("Timing", pf.timing.nav_accumulated),
        )
        for label, nav in nav_curves:
            fig.add_trace(
                go.Scatter(
                    x=nav["date"],
                    y=nav["NAV_accumulated"],
                    mode="lines",
                    name=label,
                    showlegend=False,
                ),
                row=1,
                col=1,
            )

        # Net-of-cost NAV overlay (only when a cost model is active)
        if pf.cost_model.cost_per_unit > 0:
            net_nav_df = pf.net_cost_nav
            # The net-of-cost frame may lack a date column; fall back to index positions.
            x_dates = net_nav_df["date"] if "date" in net_nav_df.columns else None
            fig.add_trace(
                go.Scatter(
                    x=x_dates,
                    y=net_nav_df["NAV_accumulated_net"],
                    mode="lines",
                    name="Net-of-Cost NAV",
                    line={"dash": "dash"},
                    showlegend=True,
                ),
                row=1,
                col=1,
            )

        # --- Row 2: drawdown, shaded down to zero.
        drawdown = pf.drawdown
        fig.add_trace(
            go.Scatter(
                x=drawdown["date"],
                y=drawdown["drawdown_pct"],
                mode="lines",
                fill="tozeroy",
                name="Drawdown",
                showlegend=False,
            ),
            row=2,
            col=1,
        )

        fig.add_hline(y=0, line_width=1, line_color="gray", row=2, col=1)

        _apply_base_layout(fig, "Performance Dashboard", height=1200)

        fig.update_yaxes(title_text="NAV (accumulated)", row=1, col=1, tickformat=".2s")
        fig.update_yaxes(title_text="Drawdown", row=2, col=1, tickformat=".0%")

        if log_scale:
            fig.update_yaxes(type="log", row=1, col=1)
            # Ensure the first y-axis is explicitly set for environments
            # where subplot updates may not propagate to layout alias.
            if hasattr(fig.layout, "yaxis"):  # pragma: no branch — plotly figures always have .yaxis
                fig.layout.yaxis.type = "log"

        return fig

    @staticmethod
    def _apply_nav_layout(fig: go.Figure, title: str, log_scale: bool = False) -> None:
        """Apply common NAV-accumulated layout to *fig* in-place.

        Configures the plot background, legend, hover mode, x-axis date range
        selector, y-axis label, grid lines, and optional logarithmic y-scale.
        Shared by `lagged_performance_plot` and
        `smoothed_holdings_performance_plot`.

        Args:
            fig: The Plotly Figure to configure.
            title: Chart title text.
            log_scale: If True, set the primary y-axis to logarithmic scale.
        """
        _apply_base_layout(fig, title)
        fig.update_yaxes(title_text="NAV (accumulated)")

        if log_scale:
            fig.update_yaxes(type="log")
            # Also set the primary axis directly, mirroring `snapshot`.
            if hasattr(fig.layout, "yaxis"):  # pragma: no branch — plotly figures always have .yaxis
                fig.layout.yaxis.type = "log"

    @staticmethod
    def _rolling_lines(rolling: pl.DataFrame) -> go.Figure:
        """Return a figure with one thin line per non-``date`` column of *rolling*.

        The ``date`` column, when present, supplies the x-axis; otherwise
        Plotly falls back to row positions.
        """
        fig = go.Figure()
        dates = rolling["date"] if "date" in rolling.columns else None
        for col in rolling.columns:
            if col == "date":
                continue
            fig.add_trace(
                go.Scatter(
                    x=dates,
                    y=rolling[col],
                    mode="lines",
                    name=col,
                    line={"width": 1},
                )
            )
        return fig

    def lagged_performance_plot(self, lags: list[int] | None = None, log_scale: bool = False) -> go.Figure:
        """Plot NAV_accumulated for multiple lagged portfolios.

        Creates a Plotly figure with one line per lag value showing the
        accumulated NAV series for the portfolio with cash positions
        shifted by that lag. By default, lags [0, 1, 2, 3, 4] are used.

        Args:
            lags: A list of integer lags to apply; defaults to [0, 1, 2, 3, 4].
            log_scale: If True, set the primary y-axis to logarithmic scale.

        Returns:
            A Plotly Figure containing one trace per requested lag.

        Raises:
            TypeError: If ``lags`` is not a list of integers.
        """
        if lags is None:
            lags = [0, 1, 2, 3, 4]
        if not isinstance(lags, list) or not all(isinstance(x, int) for x in lags):
            raise TypeError(f"lags must be a list of integers, got {lags!r}")  # noqa: TRY003

        fig = go.Figure()
        for lag in lags:
            # Lag 0 is the portfolio itself; no need to rebuild it.
            pf = self._portfolio if lag == 0 else self._portfolio.lag(lag)
            nav = pf.nav_accumulated
            fig.add_trace(
                go.Scatter(
                    x=nav["date"],
                    y=nav["NAV_accumulated"],
                    mode="lines",
                    name=f"lag {lag}",
                    line={"width": 1},
                )
            )

        self._apply_nav_layout(fig, title="NAV accumulated by lag", log_scale=log_scale)
        return fig

    def rolling_sharpe_plot(self, window: int = 63) -> go.Figure:
        """Plot rolling annualised Sharpe ratio over time.

        Computes the rolling Sharpe for each asset column using the given
        window and renders one line per asset, with a dashed zero line.

        Args:
            window: Rolling-window size in periods. Defaults to 63.

        Returns:
            A Plotly Figure with one trace per asset.

        Raises:
            ValueError: If ``window`` is not a positive integer.
        """
        if not isinstance(window, int) or window <= 0:
            raise ValueError(f"window must be a positive integer, got {window!r}")  # noqa: TRY003

        rolling = self._portfolio.stats.rolling_sharpe(rolling_period=window)
        fig = self._rolling_lines(rolling)

        fig.add_hline(y=0, line_width=1, line_dash="dash", line_color="gray")

        _apply_base_layout(fig, f"Rolling Sharpe Ratio ({window}-period window)")
        fig.update_yaxes(title_text="Sharpe ratio")
        return fig

    def rolling_volatility_plot(self, window: int = 63) -> go.Figure:
        """Plot rolling annualised volatility over time.

        Computes the rolling volatility for each asset column using the given
        window and renders one line per asset.

        Args:
            window: Rolling-window size in periods. Defaults to 63.

        Returns:
            A Plotly Figure with one trace per asset.

        Raises:
            ValueError: If ``window`` is not a positive integer.
        """
        if not isinstance(window, int) or window <= 0:
            raise ValueError(f"window must be a positive integer, got {window!r}")  # noqa: TRY003

        rolling = self._portfolio.stats.rolling_volatility(rolling_period=window)
        fig = self._rolling_lines(rolling)

        _apply_base_layout(fig, f"Rolling Volatility ({window}-period window)")
        fig.update_yaxes(title_text="Annualised volatility")
        return fig

    def annual_sharpe_plot(self) -> go.Figure:
        """Plot annualised Sharpe ratio broken down by calendar year.

        Reads the ``sharpe`` rows of the portfolio's annual breakdown and
        renders a grouped bar chart with one bar per year per asset.

        Returns:
            A Plotly Figure with one bar group per asset.
        """
        breakdown = self._portfolio.stats.annual_breakdown()

        # Keep only the sharpe metric; every remaining non-key column is an asset.
        sharpe_rows = breakdown.filter(pl.col("metric") == "sharpe")
        asset_cols = [c for c in sharpe_rows.columns if c not in ("year", "metric")]

        fig = go.Figure()
        for asset in asset_cols:
            fig.add_trace(
                go.Bar(
                    x=sharpe_rows["year"],
                    y=sharpe_rows[asset],
                    name=asset,
                )
            )

        fig.add_hline(y=0, line_width=1, line_color="gray")

        fig.update_layout(
            title="Annual Sharpe Ratio by Year",
            barmode="group",
            hovermode="x unified",
            plot_bgcolor="white",
            legend={"orientation": "h", "yanchor": "bottom", "y": 1.02, "xanchor": "right", "x": 1},
        )
        fig.update_yaxes(title_text="Sharpe ratio")
        fig.update_xaxes(showgrid=True, gridwidth=0.5, gridcolor="lightgrey", title_text="Year")
        fig.update_yaxes(showgrid=True, gridwidth=0.5, gridcolor="lightgrey")
        return fig

    def correlation_heatmap(
        self,
        frame: pl.DataFrame | None = None,
        name: str = "portfolio",
        title: str = "Correlation heatmap",
    ) -> go.Figure:
        """Plot a correlation heatmap for assets and the portfolio.

        If ``frame`` is None, uses the portfolio's prices. The portfolio's
        profit series is appended under ``name`` before computing the
        correlation matrix.

        Args:
            frame: Optional Polars DataFrame with at least the asset price
                columns. If omitted, uses ``self._portfolio.prices``.
            name: Column name under which to include the portfolio profit.
            title: Plot title.

        Returns:
            A Plotly Figure rendering the correlation matrix as a heatmap,
            with a diverging colour scale pinned to [-1, 1].
        """
        if frame is None:
            frame = self._portfolio.prices

        corr = self._portfolio.correlation(frame, name=name)

        # Interactive heatmap; the colour range is fixed to the full correlation range.
        fig = px.imshow(
            corr,
            x=corr.columns,
            y=corr.columns,
            text_auto=".2f",  # show correlation values
            color_continuous_scale="RdBu_r",  # red-blue diverging colormap
            zmin=-1,
            zmax=1,  # correlation range
            title=title,
        )

        fig.update_layout(
            xaxis_title="", yaxis_title="", width=700, height=600, coloraxis_colorbar={"title": "Correlation"}
        )

        return fig

    def monthly_returns_heatmap(self) -> go.Figure:
        """Plot a monthly returns calendar heatmap.

        Groups portfolio returns by calendar year and month, then renders a
        Plotly heatmap with months on the x-axis and years on the y-axis.
        Green cells indicate positive months; red cells indicate negative
        months. Cell text shows the percentage return for that month.
        Months with no row, or with a null return, are left blank.

        Returns:
            A Plotly Figure with a calendar heatmap of monthly returns.

        Raises:
            ValueError: If the portfolio has no ``date`` column (presumably
                raised while computing ``self._portfolio.monthly``).
        """
        monthly = self._portfolio.monthly

        years = monthly["year"].unique().sort().to_list()
        month_names = ["Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"]

        # Index every (year, month) return in a single pass instead of filtering
        # the frame once per cell; setdefault keeps the first row per pair.
        returns_by_month: dict[tuple[object, object], object] = {}
        for year, month, ret in monthly.select("year", "month", "returns").iter_rows():
            returns_by_month.setdefault((year, month), ret)

        z: list[list[float | None]] = []
        text: list[list[str]] = []
        for year in years:
            year_row: list[float | None] = []
            year_text: list[str] = []
            for m in range(1, 13):
                ret = returns_by_month.get((year, m))
                if ret is None:
                    # Missing month or null return: leave the cell blank.
                    year_row.append(None)
                    year_text.append("")
                else:
                    pct = float(ret) * 100.0
                    year_row.append(pct)
                    year_text.append(f"{pct:.1f}%")
            z.append(year_row)
            text.append(year_text)

        fig = go.Figure(
            data=go.Heatmap(
                z=z,
                x=month_names,
                y=[str(y) for y in years],
                text=text,
                texttemplate="%{text}",
                colorscale="RdYlGn",
                zmid=0,
                colorbar={"title": "Return (%)"},
                hovertemplate="<b>%{y} %{x}</b><br>Return: %{text}<extra></extra>",
            )
        )

        fig.update_layout(
            title="Monthly Returns Heatmap",
            xaxis_title="Month",
            yaxis_title="Year",
            plot_bgcolor="white",
            yaxis={"type": "category"},  # years as discrete labels, not a numeric axis
        )

        return fig

    def smoothed_holdings_performance_plot(
        self,
        windows: list[int] | None = None,
        log_scale: bool = False,
    ) -> go.Figure:
        """Plot NAV_accumulated for smoothed-holding portfolios.

        Builds portfolios with cash positions smoothed by a trailing rolling
        mean over the previous ``n`` steps (window size n+1) for n in
        ``windows`` (defaults to [0, 1, 2, 3, 4]) and plots their
        accumulated NAV curves.

        Args:
            windows: List of non-negative integers specifying smoothing steps
                to include; defaults to [0, 1, 2, 3, 4].
            log_scale: If True, set the primary y-axis to logarithmic scale.

        Returns:
            A Plotly Figure containing one line per requested smoothing level.

        Raises:
            TypeError: If ``windows`` is not a list of non-negative integers.
        """
        if windows is None:
            windows = [0, 1, 2, 3, 4]
        if not isinstance(windows, list) or not all(isinstance(x, int) and x >= 0 for x in windows):
            raise TypeError(f"windows must be a list of non-negative integers, got {windows!r}")  # noqa: TRY003

        fig = go.Figure()
        for n in windows:
            # n == 0 means no smoothing: use the portfolio itself.
            pf = self._portfolio if n == 0 else self._portfolio.smoothed_holding(n)
            nav = pf.nav_accumulated
            fig.add_trace(
                go.Scatter(
                    x=nav["date"],
                    y=nav["NAV_accumulated"],
                    mode="lines",
                    name=f"smooth {n}",
                    line={"width": 1},
                )
            )

        self._apply_nav_layout(fig, title="NAV accumulated by smoothed holdings", log_scale=log_scale)
        return fig

    def trading_cost_impact_plot(self, max_bps: int = 20) -> go.Figure:
        """Plot the Sharpe ratio as a function of one-way trading costs.

        Evaluates the portfolio's annualised Sharpe ratio at each integer
        cost level from 0 up to ``max_bps`` basis points and renders the
        result as a line chart. The zero-cost Sharpe is shown as a
        reference horizontal line (when it is finite) so that the reader can
        quickly gauge at what cost level the strategy's edge is eroded.

        Args:
            max_bps: Maximum one-way trading cost to evaluate, in basis
                points. Defaults to 20.

        Returns:
            A Plotly Figure with one line trace showing Sharpe vs. cost.

        Raises:
            ValueError: If ``max_bps`` is not a positive integer (validated by
                ``self._portfolio.trading_cost_impact``, not in this method).
        """
        impact = self._portfolio.trading_cost_impact(max_bps=max_bps)

        cost_vals = impact["cost_bps"].to_list()
        sharpe_vals = impact["sharpe"].to_list()

        # Baseline Sharpe at zero cost (first row); NaN when absent or null.
        baseline = float(sharpe_vals[0]) if sharpe_vals and sharpe_vals[0] is not None else float("nan")

        fig = go.Figure()
        fig.add_trace(
            go.Scatter(
                x=cost_vals,
                y=sharpe_vals,
                mode="lines+markers",
                name="Sharpe (cost-adjusted)",
                marker={"size": 6},
                line={"width": 2, "color": "#1f77b4"},
            )
        )
        # Only draw the reference line for a finite baseline; NaN or ±inf
        # cannot be placed on the axis.
        if math.isfinite(baseline):
            fig.add_hline(
                y=baseline,
                line_width=1,
                line_dash="dash",
                line_color="gray",
                annotation_text="0 bps baseline",
                annotation_position="top right",
            )

        fig.update_layout(
            title=f"Trading Cost Impact on Sharpe Ratio (0\u2013{max_bps} bps)",
            hovermode="x unified",
            plot_bgcolor="white",
        )
        fig.update_xaxes(
            title_text="One-way cost (basis points)",
            showgrid=True,
            gridwidth=0.5,
            gridcolor="lightgrey",
            dtick=1,
        )
        fig.update_yaxes(
            title_text="Annualised Sharpe ratio",
            showgrid=True,
            gridwidth=0.5,
            gridcolor="lightgrey",
        )
        return fig