Coverage for src/jquantstats/_plots/_portfolio.py: 100%

166 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-23 06:13 +0000

"""Plotting utilities for portfolio analytics using Plotly.

This module defines the PortfolioPlots facade, which renders common portfolio
visuals: NAV/drawdown snapshots, lagged-performance and smoothed-holdings NAV
curves, lead/lag Sharpe (information ratio) bar charts, rolling Sharpe and
rolling volatility lines, annual Sharpe bars, correlation and monthly-return
heatmaps, and a trading-cost impact chart. Designed for notebook use.

Importing this module sets ``plotly.io.renderers.default`` to
``"plotly_mimetype"``.
"""

7 

8from __future__ import annotations 

9 

10from typing import TYPE_CHECKING 

11 

12import plotly.express as px 

13import plotly.graph_objects as go 

14import plotly.io as pio 

15import polars as pl 

16from plotly.subplots import make_subplots 

17 

18from ._data import _apply_base_layout 

19 

20if TYPE_CHECKING: 

21 from ._protocol import PortfolioLike 

22 

# Ensure Plotly works with Marimo by selecting the "plotly_mimetype" renderer
# (set after imports to satisfy linters).
# NOTE(review): this assignment runs at import time and replaces Plotly's
# process-wide default renderer, overriding any renderer a user configured
# before importing this module.
pio.renderers.default = "plotly_mimetype"

25 

26 

class PortfolioPlots:
    """Facade for portfolio plots built with Plotly.

    Provides convenience methods to visualize portfolio performance and
    diagnostics directly from a Portfolio instance: NAV/drawdown snapshots,
    lagged and smoothed-holdings NAV curves, lead/lag Sharpe bars, rolling
    Sharpe and volatility, annual Sharpe, correlation and monthly-return
    heatmaps, and trading-cost sensitivity.
    """

    __slots__ = ("_portfolio",)

    def __init__(self, portfolio: PortfolioLike) -> None:
        """Store the portfolio that every plot method reads from.

        Args:
            portfolio: Object satisfying the ``PortfolioLike`` protocol.
        """
        self._portfolio = portfolio

    def lead_lag_ir_plot(self, start: int = -10, end: int = 19) -> go.Figure:
        """Plot Sharpe ratio (IR) across lead/lag variants of the portfolio.

        Builds portfolios with cash positions lagged from ``start`` to ``end``
        (inclusive) and plots a bar chart of the Sharpe ratio for each lag.
        Positive lags delay weights; negative lags lead them. If ``start`` is
        greater than ``end`` the two bounds are swapped.

        Args:
            start: First lag to include (default: -10).
            end: Last lag to include (default: +19).

        Returns:
            A Plotly Figure with one bar per lag labeled by the lag value; the
            unlagged (lag 0) bar is drawn in red.

        Raises:
            TypeError: If ``start`` or ``end`` is not an integer.
        """
        if not isinstance(start, int) or not isinstance(end, int):
            raise TypeError(f"start and end must be integers, got {start!r} and {end!r}")  # noqa: TRY003
        if start > end:
            start, end = end, start

        lags = list(range(start, end + 1))

        sharpe_vals: list[float] = []
        for n in lags:
            # Lag 0 is the portfolio itself; no need to build a shifted copy.
            pf = self._portfolio if n == 0 else self._portfolio.lag(n)
            # ``stats.sharpe()`` returns a mapping keyed by series name; take the
            # portfolio's "returns" entry, defaulting to NaN when absent.
            sharpe_val = pf.stats.sharpe().get("returns", float("nan"))
            sharpe_vals.append(float(sharpe_val) if sharpe_val is not None else float("nan"))

        # Highlight the unlagged bar so it reads as the reference point.
        colors = ["red" if n == 0 else "#1f77b4" for n in lags]
        fig = go.Figure(
            data=[
                go.Bar(x=lags, y=sharpe_vals, name="Sharpe by lag", marker_color=colors),
            ]
        )
        fig.update_layout(
            title="Lead/Lag Information Ratio (Sharpe) by Lag",
            xaxis_title="Lag (steps)",
            yaxis_title="Sharpe ratio",
            plot_bgcolor="white",
            hovermode="x",
        )
        fig.update_xaxes(showgrid=True, gridwidth=0.5, gridcolor="lightgrey")
        fig.update_yaxes(showgrid=True, gridwidth=0.5, gridcolor="lightgrey")
        return fig

    def snapshot(self, log_scale: bool = False) -> go.Figure:
        """Return a snapshot dashboard of NAV and drawdown.

        The top panel overlays the accumulated NAV of the portfolio with the
        accumulated NAV of its ``tilt`` and ``timing`` components. The bottom
        panel shows the drawdown as a filled area with a zero reference line.

        When the portfolio has a non-zero ``cost_model.cost_per_unit``, an additional
        ``"Net-of-Cost NAV"`` trace is overlaid on the NAV panel showing the
        realised NAV path after deducting position-delta trading costs.

        Args:
            log_scale (bool, optional): If True, display NAV on a log scale. Defaults to False.

        Returns:
            plotly.graph_objects.Figure: A two-row Figure (NAV on top, drawdown
            below). Shared styling comes from ``_apply_base_layout``.
        """
        pf = self._portfolio

        # Two stacked panels sharing the date axis: NAV (2/3) over drawdown (1/3).
        fig = make_subplots(
            rows=2,
            cols=1,
            shared_xaxes=True,
            row_heights=[0.66, 0.33],
            subplot_titles=["Accumulated Profit", "Drawdown"],
            vertical_spacing=0.05,
        )

        # --- Row 1: accumulated NAV of the portfolio and its tilt/timing parts
        for trace_name, source in (("NAV", pf), ("Tilt", pf.tilt), ("Timing", pf.timing)):
            nav = source.nav_accumulated
            fig.add_trace(
                go.Scatter(
                    x=nav["date"],
                    y=nav["NAV_accumulated"],
                    mode="lines",
                    name=trace_name,
                    showlegend=False,
                ),
                row=1,
                col=1,
            )

        # Net-of-cost NAV overlay (only when a cost model is active)
        if pf.cost_model.cost_per_unit > 0:
            net_nav_df = pf.net_cost_nav
            fig.add_trace(
                go.Scatter(
                    # Without a date column Plotly falls back to positional x.
                    x=net_nav_df["date"] if "date" in net_nav_df.columns else None,
                    y=net_nav_df["NAV_accumulated_net"],
                    mode="lines",
                    name="Net-of-Cost NAV",
                    line={"dash": "dash"},
                    showlegend=True,
                ),
                row=1,
                col=1,
            )

        # --- Row 2: drawdown as a filled area down to zero
        drawdown = pf.drawdown
        fig.add_trace(
            go.Scatter(
                x=drawdown["date"],
                y=drawdown["drawdown_pct"],
                mode="lines",
                fill="tozeroy",
                name="Drawdown",
                showlegend=False,
            ),
            row=2,
            col=1,
        )

        fig.add_hline(y=0, line_width=1, line_color="gray", row=2, col=1)

        _apply_base_layout(fig, "Performance Dashboard", height=1200)

        fig.update_yaxes(title_text="NAV (accumulated)", row=1, col=1, tickformat=".2s")
        fig.update_yaxes(title_text="Drawdown", row=2, col=1, tickformat=".0%")

        if log_scale:
            fig.update_yaxes(type="log", row=1, col=1)
            # Ensure the first y-axis is explicitly set for environments
            # where subplot updates may not propagate to layout alias.
            if hasattr(fig.layout, "yaxis"):  # pragma: no branch — plotly figures always have .yaxis
                fig.layout.yaxis.type = "log"

        return fig

    @staticmethod
    def _apply_nav_layout(fig: go.Figure, title: str, log_scale: bool = False) -> None:
        """Apply common NAV-accumulated layout to *fig* in-place.

        Delegates shared styling to ``_apply_base_layout``, then labels the
        y-axis and optionally switches it to a logarithmic scale. Shared by
        `lagged_performance_plot` and `smoothed_holdings_performance_plot`.

        Args:
            fig: The Plotly Figure to configure.
            title: Chart title text.
            log_scale: If True, set the primary y-axis to logarithmic scale.
        """
        _apply_base_layout(fig, title)
        fig.update_yaxes(title_text="NAV (accumulated)")

        if log_scale:
            fig.update_yaxes(type="log")
            if hasattr(fig.layout, "yaxis"):  # pragma: no branch — plotly figures always have .yaxis
                fig.layout.yaxis.type = "log"

    def lagged_performance_plot(self, lags: list[int] | None = None, log_scale: bool = False) -> go.Figure:
        """Plot NAV_accumulated for multiple lagged portfolios.

        Creates a Plotly figure with one line per lag value showing the
        accumulated NAV series for the portfolio with cash positions
        shifted by that lag. By default, lags [0, 1, 2, 3, 4] are used.

        Args:
            lags: A list of integer lags to apply; defaults to [0, 1, 2, 3, 4].
            log_scale: If True, set the primary y-axis to logarithmic scale.

        Returns:
            A Plotly Figure containing one trace per requested lag.

        Raises:
            TypeError: If ``lags`` is not a list of integers.
        """
        if lags is None:
            lags = [0, 1, 2, 3, 4]
        if not isinstance(lags, list) or not all(isinstance(x, int) for x in lags):
            raise TypeError(f"lags must be a list of integers, got {lags!r}")  # noqa: TRY003

        fig = go.Figure()
        for lag in lags:
            pf = self._portfolio if lag == 0 else self._portfolio.lag(lag)
            nav = pf.nav_accumulated
            fig.add_trace(
                go.Scatter(
                    x=nav["date"],
                    y=nav["NAV_accumulated"],
                    mode="lines",
                    name=f"lag {lag}",
                    line={"width": 1},
                )
            )

        self._apply_nav_layout(fig, title="NAV accumulated by lag", log_scale=log_scale)
        return fig

    @staticmethod
    def _check_window(window: int) -> None:
        """Raise ValueError unless *window* is a positive integer."""
        if not isinstance(window, int) or window <= 0:
            raise ValueError(f"window must be a positive integer, got {window!r}")  # noqa: TRY003

    @staticmethod
    def _rolling_lines_figure(
        rolling: pl.DataFrame,
        title: str,
        yaxis_title: str,
        *,
        zero_line: bool = False,
    ) -> go.Figure:
        """Render one thin line per non-``date`` column of a rolling-metric frame.

        Args:
            rolling: Frame of rolling metric values, optionally with a ``date`` column.
            title: Chart title passed to ``_apply_base_layout``.
            yaxis_title: Label for the y-axis.
            zero_line: If True, draw a dashed grey horizontal line at y=0.

        Returns:
            The configured Plotly Figure.
        """
        fig = go.Figure()
        # Without a date column Plotly falls back to positional x values.
        dates = rolling["date"] if "date" in rolling.columns else None
        for col in rolling.columns:
            if col == "date":
                continue
            fig.add_trace(
                go.Scatter(
                    x=dates,
                    y=rolling[col],
                    mode="lines",
                    name=col,
                    line={"width": 1},
                )
            )

        if zero_line:
            fig.add_hline(y=0, line_width=1, line_dash="dash", line_color="gray")

        _apply_base_layout(fig, title)
        fig.update_yaxes(title_text=yaxis_title)
        return fig

    def rolling_sharpe_plot(self, window: int = 63) -> go.Figure:
        """Plot rolling annualised Sharpe ratio over time.

        Computes the rolling Sharpe for each asset column using the given
        window and renders one line per asset, with a dashed zero line.

        Args:
            window: Rolling-window size in periods. Defaults to 63.

        Returns:
            A Plotly Figure with one trace per asset.

        Raises:
            ValueError: If ``window`` is not a positive integer.
        """
        self._check_window(window)
        rolling = self._portfolio.stats.rolling_sharpe(rolling_period=window)
        return self._rolling_lines_figure(
            rolling,
            f"Rolling Sharpe Ratio ({window}-period window)",
            "Sharpe ratio",
            zero_line=True,
        )

    def rolling_volatility_plot(self, window: int = 63) -> go.Figure:
        """Plot rolling annualised volatility over time.

        Computes the rolling volatility for each asset column using the given
        window and renders one line per asset.

        Args:
            window: Rolling-window size in periods. Defaults to 63.

        Returns:
            A Plotly Figure with one trace per asset.

        Raises:
            ValueError: If ``window`` is not a positive integer.
        """
        self._check_window(window)
        rolling = self._portfolio.stats.rolling_volatility(rolling_period=window)
        return self._rolling_lines_figure(
            rolling,
            f"Rolling Volatility ({window}-period window)",
            "Annualised volatility",
        )

    def annual_sharpe_plot(self) -> go.Figure:
        """Plot annualised Sharpe ratio broken down by calendar year.

        Takes the ``metric == "sharpe"`` rows of the portfolio's annual
        breakdown and renders a grouped bar chart with one bar per year per
        asset column.

        Returns:
            A Plotly Figure with one bar group per asset.
        """
        breakdown = self._portfolio.stats.annual_breakdown()

        # Extract the sharpe row for each year
        sharpe_rows = breakdown.filter(pl.col("metric") == "sharpe")
        asset_cols = [c for c in sharpe_rows.columns if c not in ("year", "metric")]

        fig = go.Figure()
        for asset in asset_cols:
            fig.add_trace(
                go.Bar(
                    x=sharpe_rows["year"],
                    y=sharpe_rows[asset],
                    name=asset,
                )
            )

        fig.add_hline(y=0, line_width=1, line_color="gray")

        fig.update_layout(
            title="Annual Sharpe Ratio by Year",
            barmode="group",
            hovermode="x unified",
            plot_bgcolor="white",
            legend={"orientation": "h", "yanchor": "bottom", "y": 1.02, "xanchor": "right", "x": 1},
        )
        fig.update_yaxes(title_text="Sharpe ratio")
        fig.update_xaxes(showgrid=True, gridwidth=0.5, gridcolor="lightgrey", title_text="Year")
        fig.update_yaxes(showgrid=True, gridwidth=0.5, gridcolor="lightgrey")
        return fig

    def correlation_heatmap(
        self,
        frame: pl.DataFrame | None = None,
        name: str = "portfolio",
        title: str = "Correlation heatmap",
    ) -> go.Figure:
        """Plot a correlation heatmap for assets and the portfolio.

        If ``frame`` is None, uses the portfolio's prices. The correlation
        matrix itself comes from ``portfolio.correlation(frame, name=name)``,
        which is expected to include the portfolio series under ``name``.

        Args:
            frame: Optional Polars DataFrame with at least the asset price
                columns. If omitted, uses ``self._portfolio.prices``.
            name: Column name under which to include the portfolio profit.
            title: Plot title.

        Returns:
            A Plotly Figure rendering the correlation matrix as a heatmap on a
            fixed [-1, 1] diverging colour scale.
        """
        if frame is None:
            frame = self._portfolio.prices

        corr = self._portfolio.correlation(frame, name=name)

        # Create an interactive heatmap
        fig = px.imshow(
            corr,
            x=corr.columns,
            y=corr.columns,
            text_auto=".2f",  # show correlation values
            color_continuous_scale="RdBu_r",  # red-blue diverging colormap
            zmin=-1,
            zmax=1,  # correlation range
            title=title,
        )

        # Adjust layout
        fig.update_layout(
            xaxis_title="", yaxis_title="", width=700, height=600, coloraxis_colorbar={"title": "Correlation"}
        )

        return fig

    @staticmethod
    def _monthly_matrix(
        rows: list[tuple[int, int, float | None]],
    ) -> tuple[list[int], list[list[float | None]], list[list[str]]]:
        """Pivot ``(year, month, return)`` rows into a year-by-month grid.

        Rows whose year or month is null are skipped. If a (year, month) pair
        repeats, the first occurrence wins. A null return yields an empty cell.

        Args:
            rows: Iterable of ``(year, month, return)`` tuples; ``month`` is 1-12
                and ``return`` is a fraction (0.01 == 1%).

        Returns:
            ``(years, z, text)``: ``years`` sorted ascending; ``z[i][m - 1]`` is
            the month-``m`` return of ``years[i]`` in percent (``None`` when
            missing); ``text[i][m - 1]`` is its ``"x.x%"`` label (``""`` when
            missing).
        """
        lookup: dict[tuple[int, int], float | None] = {}
        for year, month, ret in rows:
            if year is None or month is None:
                continue
            lookup.setdefault((year, month), ret)

        years = sorted({year for year, _ in lookup})

        z: list[list[float | None]] = []
        text: list[list[str]] = []
        for year in years:
            z_row: list[float | None] = []
            text_row: list[str] = []
            for month in range(1, 13):
                ret = lookup.get((year, month))
                if ret is None:
                    z_row.append(None)
                    text_row.append("")
                else:
                    pct = float(ret) * 100.0
                    z_row.append(pct)
                    text_row.append(f"{pct:.1f}%")
            z.append(z_row)
            text.append(text_row)
        return years, z, text

    def monthly_returns_heatmap(self) -> go.Figure:
        """Plot a monthly returns calendar heatmap.

        Groups portfolio returns by calendar year and month, then renders a
        Plotly heatmap with months on the x-axis and years on the y-axis.
        Green cells indicate positive months; red cells indicate negative
        months. Cell text shows the percentage return for that month; months
        with no data (or a null return) are left blank.

        Returns:
            A Plotly Figure with a calendar heatmap of monthly returns.

        Raises:
            ValueError: If the portfolio has no ``date`` column (presumably
                raised by ``portfolio.monthly``; not checked here).
        """
        monthly = self._portfolio.monthly
        month_names = ["Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"]

        # One pass over the frame instead of a full filter per (year, month) cell.
        years, z, text = self._monthly_matrix(monthly.select("year", "month", "returns").rows())

        fig = go.Figure(
            data=go.Heatmap(
                z=z,
                x=month_names,
                y=[str(y) for y in years],
                text=text,
                texttemplate="%{text}",
                colorscale="RdYlGn",
                zmid=0,
                colorbar={"title": "Return (%)"},
                hovertemplate="<b>%{y} %{x}</b><br>Return: %{text}<extra></extra>",
            )
        )

        fig.update_layout(
            title="Monthly Returns Heatmap",
            xaxis_title="Month",
            yaxis_title="Year",
            plot_bgcolor="white",
            yaxis={"type": "category"},
        )

        return fig

    def smoothed_holdings_performance_plot(
        self,
        windows: list[int] | None = None,
        log_scale: bool = False,
    ) -> go.Figure:
        """Plot NAV_accumulated for smoothed-holding portfolios.

        Builds portfolios with cash positions smoothed by a trailing rolling
        mean over the previous ``n`` steps (window size n+1) for n in
        ``windows`` (defaults to [0, 1, 2, 3, 4]) and plots their
        accumulated NAV curves.

        Args:
            windows: List of non-negative integers specifying smoothing steps
                to include; defaults to [0, 1, 2, 3, 4].
            log_scale: If True, set the primary y-axis to logarithmic scale.

        Returns:
            A Plotly Figure containing one line per requested smoothing level.

        Raises:
            TypeError: If ``windows`` is not a list of non-negative integers.
        """
        if windows is None:
            windows = [0, 1, 2, 3, 4]
        if not isinstance(windows, list) or not all(isinstance(x, int) and x >= 0 for x in windows):
            raise TypeError(f"windows must be a list of non-negative integers, got {windows!r}")  # noqa: TRY003

        fig = go.Figure()
        for n in windows:
            # n == 0 means "no smoothing": reuse the portfolio as-is.
            pf = self._portfolio if n == 0 else self._portfolio.smoothed_holding(n)
            nav = pf.nav_accumulated
            fig.add_trace(
                go.Scatter(
                    x=nav["date"],
                    y=nav["NAV_accumulated"],
                    mode="lines",
                    name=f"smooth {n}",
                    line={"width": 1},
                )
            )

        self._apply_nav_layout(fig, title="NAV accumulated by smoothed holdings", log_scale=log_scale)
        return fig

    def trading_cost_impact_plot(self, max_bps: int = 20) -> go.Figure:
        """Plot the Sharpe ratio as a function of one-way trading costs.

        Evaluates the portfolio's annualised Sharpe ratio at each integer
        cost level from 0 up to ``max_bps`` basis points and renders the
        result as a line chart. The zero-cost Sharpe is shown as a
        reference horizontal line (when it is finite) so that the reader can
        quickly gauge at what cost level the strategy's edge is eroded.

        Args:
            max_bps: Maximum one-way trading cost to evaluate, in basis
                points. Defaults to 20.

        Returns:
            A Plotly Figure with one line trace showing Sharpe vs. cost.

        Raises:
            ValueError: If ``max_bps`` is not a positive integer (raised by
                ``portfolio.trading_cost_impact``).
        """
        impact = self._portfolio.trading_cost_impact(max_bps=max_bps)

        cost_vals = impact["cost_bps"].to_list()
        sharpe_vals = impact["sharpe"].to_list()

        # Baseline Sharpe at zero cost (first row of the impact table)
        baseline = float(sharpe_vals[0]) if sharpe_vals and sharpe_vals[0] is not None else float("nan")

        fig = go.Figure()
        fig.add_trace(
            go.Scatter(
                x=cost_vals,
                y=sharpe_vals,
                mode="lines+markers",
                name="Sharpe (cost-adjusted)",
                marker={"size": 6},
                line={"width": 2, "color": "#1f77b4"},
            )
        )
        # Only draw the reference line for a finite baseline: the chained
        # comparison is False for NaN as well as for +/-inf.
        if float("-inf") < baseline < float("inf"):
            fig.add_hline(
                y=baseline,
                line_width=1,
                line_dash="dash",
                line_color="gray",
                annotation_text="0 bps baseline",
                annotation_position="top right",
            )

        fig.update_layout(
            title=f"Trading Cost Impact on Sharpe Ratio (0\u2013{max_bps} bps)",
            hovermode="x unified",
            plot_bgcolor="white",
        )
        fig.update_xaxes(
            title_text="One-way cost (basis points)",
            showgrid=True,
            gridwidth=0.5,
            gridcolor="lightgrey",
            dtick=1,
        )
        fig.update_yaxes(
            title_text="Annualised Sharpe ratio",
            showgrid=True,
            gridwidth=0.5,
            gridcolor="lightgrey",
        )
        return fig