Coverage for src/basanos/math/_stream.py: 99%

299 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-23 05:58 +0000

"""Incremental (streaming) API for BasanosEngine.

This private module defines three symbols:

* `_StreamState` — private mutable dataclass that persists all incremental
  accumulator state between consecutive `step` calls. Kept separate
  from the engine so the state layout can be read and tested in isolation.
* `StepResult` — public frozen dataclass returned by each
  `step` call.
* `BasanosStream` — public incremental façade with a
  `from_warmup` classmethod and a
  `step` method.

EWM correlation state model
----------------------------
In EWM mode the correlation at each step is recomputed by calling
``ewm_covariance`` from ``cvx.linalg`` over the full growing history of
vol-adjusted returns stored in ``corr_ret_buf``. This keeps the incremental
and batch paths numerically identical at the cost of O(T·N²) time per step
(acceptable for small N or short warmup histories).

The volatility accumulators (``vola_*``, ``pct_*``) use a simpler scalar
recurrence and store the running sums directly as ``(N,)`` arrays.

Memory
------
Total incremental state is O(T·N) for the growing EWM history buffer, plus
twelve ``(N,)`` arrays and O(1) scalars. The twelve arrays are:

* the five ``vola_*`` accumulators;
* the five ``pct_*`` accumulators;
* ``prev_price`` and ``prev_cash_pos``.

For the SlidingWindowConfig the growing buffer is replaced by a fixed
``(W, N)`` array independent of T.
"""

31 

32from __future__ import annotations 

33 

34import dataclasses 

35import logging 

36import os 

37from typing import Any, cast 

38 

39import numpy as np 

40import polars as pl 

41from cvx.linalg import cov_to_corr 

42from cvx.linalg.ewm_cov import ewm_covariance 

43from scipy.signal import lfilter 

44 

45from ..exceptions import MissingDateColumnError, StreamStateCorruptError 

46from ._config import BasanosConfig, EwmaShrinkConfig, SlidingWindowConfig 

47from ._engine_solve import MatrixBundle, SolveStatus, _SolveMixin 

48from ._factor_model import FactorModel 

49from ._signal import shrink2id 

50 

_logger: logging.Logger = logging.getLogger(__name__)

#: Layout version of the `save` archive. Bump it whenever that layout changes
#: in a backward-incompatible way: `load` compares the stored value before
#: deserialising anything, so callers get a clear error instead of a silent
#: ``KeyError`` or wrong state.
_SAVE_FORMAT_VERSION: int = 3

58 

59 

@dataclasses.dataclass
class _StreamState:
    """Mutable state carrier for one `BasanosStream` instance.

    Fields are updated in place or replaced by ``BasanosStream.step()``.
    For example, ``sw_ret_buf`` is shifted in place, while the accumulator
    arrays are replaced with freshly computed ones. The class is
    intentionally *not* frozen so that the step method can modify fields
    directly without creating a new object on every tick.

    Which correlation buffer is populated depends on the covariance mode.
    ``BasanosStream.from_warmup`` sets ``corr_ret_buf`` (and leaves
    ``sw_ret_buf`` as ``None``) for ``EwmaShrinkConfig``. It does the
    reverse for ``SlidingWindowConfig``.

    EWM correlation state
    ~~~~~~~~~~~~~~~~~~~~~
    ``corr_ret_buf`` holds the growing history of vol-adjusted returns used by
    ``ewm_covariance`` to recompute the correlation matrix on each step. It
    is ``None`` for ``SlidingWindowConfig`` (which uses ``sw_ret_buf`` instead).

    EWM accumulator state (volatility)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    ``vola_*`` and ``pct_*`` accumulate the running weighted sums needed to
    compute exponentially-weighted standard deviations:

    * ``s_x``  — EWM sum of x (numerator of the mean)
    * ``s_x2`` — EWM sum of x² (numerator of the second moment)
    * ``s_w``  — EWM sum of weights (denominator)
    * ``s_w2`` — EWM sum of squared weights (for bias correction)

    ``beta_vola = (cfg.vola - 1) / cfg.vola`` (from ``com = cfg.vola - 1``)

    Attributes:
        corr_ret_buf: Growing history of vol-adjusted returns used by
            ``ewm_covariance``; shape ``(T, N)`` for EwmaShrinkConfig.
            ``None`` for SlidingWindowConfig.
        vola_s_x: EWM sum of volatility-adjusted log-returns; shape ``(N,)``.
        vola_s_x2: EWM sum of squared vol-adj log-returns; shape ``(N,)``.
        vola_s_w: EWM weight sum for vol accumulators; shape ``(N,)``.
        vola_s_w2: EWM squared-weight sum for vol accumulators; shape ``(N,)``.
        vola_count: Cumulative finite observation count for vol; shape ``(N,)``
            dtype int.
        pct_s_x: EWM sum of pct-returns; shape ``(N,)``.
        pct_s_x2: EWM sum of squared pct-returns; shape ``(N,)``.
        pct_s_w: EWM weight sum for pct accumulators; shape ``(N,)``.
        pct_s_w2: EWM squared-weight sum for pct accumulators; shape ``(N,)``.
        pct_count: Cumulative finite observation count for pct-return vol;
            shape ``(N,)`` dtype int.
        prev_price: Last price row seen, used to compute returns on the next
            step; shape ``(N,)``.
        prev_cash_pos: Last cash position, used to apply the turnover constraint
            on the next step; shape ``(N,)``.
        step_count: Number of rows processed so far. ``from_warmup``
            initialises it to the number of warmup rows; each ``step`` call
            then increments it by one.
        sw_ret_buf: Rolling window of the last W vol-adjusted returns, oldest
            row first; shape ``(W, N)`` for SlidingWindowConfig (NaN-padded
            at the front when the warmup batch was shorter than W).
            ``None`` for EwmaShrinkConfig.
    """

    # ── EWM correlation history buffer — (T, N) for EwmaShrinkConfig ─────────
    corr_ret_buf: np.ndarray | None  # (T, N) growing history of vol-adj returns; None for SlidingWindowConfig

    # ── EWMA accumulators for vol_adj (log-return std; com=vola-1, min_samples=1) ──
    vola_s_x: np.ndarray  # (N,)
    vola_s_x2: np.ndarray  # (N,)
    vola_s_w: np.ndarray  # (N,)
    vola_s_w2: np.ndarray  # (N,)
    vola_count: np.ndarray  # (N,) int

    # ── EWMA accumulators for vola (pct-return std; com=vola-1, min_samples=vola) ──
    pct_s_x: np.ndarray  # (N,)
    pct_s_x2: np.ndarray  # (N,)
    pct_s_w: np.ndarray  # (N,)
    pct_s_w2: np.ndarray  # (N,)
    pct_count: np.ndarray  # (N,) int

    # ── Last-seen vectors and row counter ──────────────────────────────────────
    prev_price: np.ndarray  # (N,) last price row (to compute returns at next step)
    prev_cash_pos: np.ndarray  # (N,) last cash position (for turnover constraint at next step)
    step_count: int  # rows seen: warmup length after from_warmup, +1 per step

    # ── SlidingWindowConfig state — None for EwmaShrinkConfig ────────────────
    # shape (W, N): last W vol-adjusted returns (oldest row first); None when
    # using EwmaShrinkConfig. corr_ret_buf above is unused (None) in this mode;
    # sw_ret_buf carries all the correlation state instead.
    sw_ret_buf: np.ndarray | None = None  # (W, N) rolling buffer, or None

136 

137 

#: Every key that `save` writes to the ``.npz`` archive. One key per
#: `_StreamState` field is collected through `dataclasses.fields`, so adding
#: a field to ``_StreamState`` extends this set automatically. On top of those
#: come the three archive-level keys that are not ``_StreamState`` fields:
#: ``format_version``, ``cfg_json`` and ``assets``.
_REQUIRED_KEYS: frozenset[str] = frozenset(
    [state_field.name for state_field in dataclasses.fields(_StreamState)]
    + ["format_version", "cfg_json", "assets"]
)

148 

149 

@dataclasses.dataclass(frozen=True)
class StepResult:
    """Frozen dataclass representing the output of a single ``BasanosStream`` step.

    Each call to ``BasanosStream.step()`` returns one ``StepResult`` capturing
    the optimised cash positions, the per-asset volatility estimate, the step
    date, and a status label that describes the solver outcome for that
    timestep.

    Attributes:
        date: The timestamp or date label for this step. The type mirrors
            whatever is stored in the ``'date'`` column of the input prices
            DataFrame (typically a Python `date`,
            `datetime`, or a Polars temporal scalar).
        cash_position: Optimised cash-position vector, shape ``(N,)``.
            Entries are ``NaN`` for assets that are still in the EWMA warmup
            period or that are otherwise inactive at this step.
        status: Solver outcome label for this timestep
            (`SolveStatus`). Since `SolveStatus`
            is a ``StrEnum``, values compare equal to their string equivalents
            (e.g. ``result.status == "valid"`` is ``True``):

            * ``'warmup'`` — fewer rows have been seen than the warmup
              threshold requires. The threshold is ``cfg.corr`` rows for
              EWMA mode and ``window`` rows for sliding-window mode. All
              positions are ``NaN``.
            * ``'zero_signal'`` — the expected-return signal ``mu`` of the
              active assets is (numerically) zero; positions are set to zero
              rather than solved.
            * ``'degenerate'`` — no reliable solution could be computed. This
              covers several cases: the covariance/factor model is
              ill-conditioned or singular, the solve failed, the
              normalisation denominator is non-finite or below
              ``cfg.denom_tol``, or no asset is active. Depending on the
              mode and failure path, active-asset positions are zeroed (the
              sliding-window solver does this) or left as ``NaN``.
            * ``'valid'`` — normal operation; ``cash_position`` holds the
              optimised allocations.
        vola: Per-asset EWMA percentage-return volatility, shape ``(N,)``.
            Values are ``NaN`` during the warmup period before the EWMA has
            accumulated sufficient history.

    Examples:
        >>> import numpy as np
        >>> result = StepResult(
        ...     date="2024-01-02",
        ...     cash_position=np.array([1000.0, -500.0]),
        ...     status="valid",
        ...     vola=np.array([0.012, 0.018]),
        ... )
        >>> result.status
        'valid'
        >>> result.cash_position.shape
        (2,)
    """

    date: object
    cash_position: np.ndarray
    status: SolveStatus
    vola: np.ndarray

203 

204 

205# --------------------------------------------------------------------------- 

206# Helper: unbiased EWMA std from running accumulators 

207# --------------------------------------------------------------------------- 

208 

209 

def _ewm_std_from_state(
    s_x: np.ndarray,
    s_x2: np.ndarray,
    s_w: np.ndarray,
    s_w2: np.ndarray,
    count: np.ndarray,
    min_samples: int,
) -> np.ndarray:
    r"""Return the Bessel-corrected EWMA standard deviation per asset.

    Mirrors ``polars.Expr.ewm_std(adjust=True)`` using only the running sums::

        var_biased   = s_x2/s_w - (s_x/s_w)^2
        correction   = s_w^2 / (s_w^2 - s_w2)      # Bessel correction
        var_unbiased = var_biased * correction
        std          = sqrt(max(0, var_unbiased))

    ``s_w2 = sum(wi^2)`` is the running sum of squared EWM weights.

    Parameters
    ----------
    s_x, s_x2, s_w, s_w2:
        Running accumulators, one value per asset (shape ``(N,)``).
    count:
        Number of finite observations per asset, shape ``(N,)``.
    min_samples:
        Assets with fewer than this many finite observations yield NaN.

    Returns
    -------
    np.ndarray
        Float array of shape ``(N,)``. Entries are NaN where
        ``count < min_samples``. Entries are 0.0 where the Bessel denominator
        is not positive (a single observation) or where no weight has been
        accumulated.
    """
    nan_vector = np.full(len(s_x), np.nan, dtype=float)
    eligible = count >= min_samples
    if not np.any(eligible):
        # Nothing qualifies: skip the arithmetic entirely.
        return nan_vector

    with np.errstate(divide="ignore", invalid="ignore"):
        has_weight = s_w > 0
        first_moment = np.where(has_weight, s_x / s_w, 0.0)
        second_moment = np.where(has_weight, s_x2 / s_w, 0.0)
        biased = np.maximum(second_moment - first_moment**2, 0.0)
        weight_sq = s_w**2
        # Positive once at least two observations carry weight; exactly zero
        # for a single observation, in which case the variance is reported as 0.
        bessel_denom = weight_sq - s_w2
        unbiased = np.where(bessel_denom > 0, biased * weight_sq / bessel_denom, 0.0)
        std = np.sqrt(unbiased)

    return np.where(eligible, std, np.nan)

261 

262 

263# --------------------------------------------------------------------------- 

264# Helper: batch EWMA volatility accumulators from a returns matrix 

265# --------------------------------------------------------------------------- 

266 

267 

def _ewm_vol_accumulators_from_batch(
    returns: np.ndarray,
    beta: float,
    beta_sq: float,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    r"""Compute final EWMA volatility accumulators from a batch of returns.

    Implements the same IIR recurrence as `step` but vectorised over *T*
    timesteps using ``scipy.signal.lfilter``. The five returned arrays are
    identical to the accumulators that would result from feeding each row of
    *returns* through the scalar step-by-step recurrence, starting from zero::

        s_x[t]  = beta   * s_x[t-1]  + (x[t]   if finite else 0)
        s_x2[t] = beta   * s_x2[t-1] + (x[t]^2 if finite else 0)
        s_w[t]  = beta   * s_w[t-1]  + (1      if finite else 0)
        s_w2[t] = beta^2 * s_w2[t-1] + (1      if finite else 0)

    Parameters
    ----------
    returns:
        Float array of shape ``(T, N)``. Non-finite entries (NaN, ±inf) are
        treated as missing observations. They contribute nothing to the
        numerator sums and do not increment the weight accumulators.
    beta:
        EWM decay factor for ``s_x``, ``s_x2``, and ``s_w``
        (``beta = com / (1 + com)`` for ``com = cfg.vola - 1``).
    beta_sq:
        Squared decay factor used for ``s_w2``. Must equal ``beta ** 2``.

    Returns
    -------
    s_x, s_x2, s_w, s_w2 : np.ndarray of shape ``(N,)``
        Final EWMA running accumulators after processing all *T* rows.
    count : np.ndarray of shape ``(N,)`` dtype int
        Number of finite observations per asset.

    Notes
    -----
    This function is the shared implementation used by
    `from_warmup` for both the log-return (``vola_*``)
    and pct-return (``pct_*``) accumulators. Keeping a single implementation
    here guarantees that the batch and incremental paths stay in sync when the
    recurrence definition changes.

    An empty batch (``T == 0``) returns all-zero accumulators, which is the
    state before any observation has been seen.
    """
    returns = np.asarray(returns, dtype=float)
    if returns.shape[0] == 0:
        # ``lfilter(...)[-1]`` would raise IndexError on an empty time axis;
        # the accumulators of an empty history are simply zero.
        tail_shape = returns.shape[1:]
        return (
            np.zeros(tail_shape, dtype=float),
            np.zeros(tail_shape, dtype=float),
            np.zeros(tail_shape, dtype=float),
            np.zeros(tail_shape, dtype=float),
            np.zeros(tail_shape, dtype=int),
        )

    finite = np.isfinite(returns)  # (T, N) bool mask of usable observations
    weights = finite.astype(np.float64)  # 1.0 where observed, 0.0 where missing
    x_z = np.where(finite, returns, 0.0)  # missing entries add nothing to the sums
    filt_a = np.array([1.0, -beta])
    filt_a2 = np.array([1.0, -beta_sq])

    # lfilter with b=[1], a=[1, -beta] realises y[t] = beta*y[t-1] + x[t]
    # with zero initial conditions; only the final row is needed.
    s_x: np.ndarray = lfilter([1.0], filt_a, x_z, axis=0)[-1]
    s_x2: np.ndarray = lfilter([1.0], filt_a, x_z**2, axis=0)[-1]
    s_w: np.ndarray = lfilter([1.0], filt_a, weights, axis=0)[-1]
    s_w2: np.ndarray = lfilter([1.0], filt_a2, weights, axis=0)[-1]
    count: np.ndarray = finite.sum(axis=0).astype(int)

    return s_x, s_x2, s_w, s_w2, count

325 

326 

327def _resolve_step_vector( 

328 values: np.ndarray | dict[str, float], 

329 assets: list[str], 

330 n_assets: int, 

331 arg_name: str, 

332) -> np.ndarray: 

333 """Resolve one step input to a validated ``(N,)`` float vector. 

334 

335 Args: 

336 values: Raw input provided to ``step`` as either dict or array-like. 

337 assets: Ordered asset names used when ``values`` is a mapping. 

338 n_assets: Expected vector length. 

339 arg_name: Argument label used in shape-mismatch errors. 

340 

341 Returns: 

342 A float64 numpy vector of shape ``(n_assets,)``. 

343 

344 Raises: 

345 ValueError: If the resolved vector does not match ``(n_assets,)``. 

346 """ 

347 if isinstance(values, dict): 

348 vector = np.array([float(values[a]) for a in assets], dtype=float) 

349 else: 

350 vector = np.asarray(values, dtype=float).ravel() 

351 if vector.shape != (n_assets,): 

352 raise ValueError(f"{arg_name} must have shape ({n_assets},); got {vector.shape}") # noqa: TRY003 

353 return vector 

354 

355 

356# --------------------------------------------------------------------------- 

357# BasanosStream 

358# --------------------------------------------------------------------------- 

359 

360 

361class BasanosStream: 

362 """Incremental (streaming) optimiser backed by a single `_StreamState`. 

363 

364 After warming up on a historical batch via `from_warmup`, each call 

365 to `step` advances the internal state by exactly one row in 

366 O(N^2) time — without revisiting the full warmup history. 

367 

368 Attributes: 

369 assets: Ordered list of asset column names (read-only). 

370 

371 Examples: 

372 >>> import numpy as np 

373 >>> import polars as pl 

374 >>> from datetime import date, timedelta 

375 >>> from basanos.math import BasanosConfig, BasanosStream 

376 >>> rng = np.random.default_rng(0) 

377 >>> warmup_len = 60 

378 >>> dates = pl.date_range( 

379 ... start=date(2024, 1, 1), 

380 ... end=date(2024, 1, 1) + timedelta(days=warmup_len), 

381 ... interval="1d", 

382 ... eager=True, 

383 ... ) 

384 >>> prices = pl.DataFrame({ 

385 ... "date": dates, 

386 ... "A": np.cumprod(1 + rng.normal(0.001, 0.02, warmup_len + 1)) * 100.0, 

387 ... "B": np.cumprod(1 + rng.normal(0.001, 0.02, warmup_len + 1)) * 150.0, 

388 ... }) 

389 >>> mu = pl.DataFrame({ 

390 ... "date": dates, 

391 ... "A": rng.normal(0, 0.5, warmup_len + 1), 

392 ... "B": rng.normal(0, 0.5, warmup_len + 1), 

393 ... }) 

394 >>> cfg = BasanosConfig(vola=5, corr=10, clip=3.0, shrink=0.5, aum=1e6) 

395 >>> stream = BasanosStream.from_warmup(prices.head(warmup_len), mu.head(warmup_len), cfg) 

396 >>> result = stream.step( 

397 ... prices.select(["A", "B"]).to_numpy()[warmup_len], 

398 ... mu.select(["A", "B"]).to_numpy()[warmup_len], 

399 ... prices["date"][warmup_len], 

400 ... ) 

401 >>> isinstance(result, StepResult) 

402 True 

403 >>> result.cash_position.shape 

404 (2,) 

405 """ 

406 

407 _cfg: BasanosConfig 

408 _assets: list[str] 

409 _state: _StreamState 

410 

411 def __init__(self, cfg: BasanosConfig, assets: list[str], state: _StreamState) -> None: 

412 """Initialise from an explicit config, asset list, and state container.""" 

413 object.__setattr__(self, "_cfg", cfg) 

414 object.__setattr__(self, "_assets", assets) 

415 object.__setattr__(self, "_state", state) 

416 

417 def __setattr__(self, name: str, value: object) -> None: 

418 """Prevent accidental attribute mutation — BasanosStream is immutable.""" 

419 raise dataclasses.FrozenInstanceError(f"{type(self).__name__}.{name}") 

420 

421 @property 

422 def assets(self) -> list[str]: 

423 """Ordered list of asset column names.""" 

424 return self._assets 

425 

426 # ------------------------------------------------------------------ 

427 # from_warmup 

428 # ------------------------------------------------------------------ 

429 

430 @classmethod 

431 def from_warmup( 

432 cls, 

433 prices: pl.DataFrame, 

434 mu: pl.DataFrame, 

435 cfg: BasanosConfig, 

436 ) -> BasanosStream: 

437 """Build a `BasanosStream` from a historical warmup batch. 

438 

439 Runs `BasanosEngine` on the full warmup batch 

440 exactly once and extracts the minimal IIR-filter state required for 

441 subsequent `step` calls. After this call, each `step` 

442 advances the optimiser in O(N^2) time without touching the warmup 

443 data again. 

444 

445 Parameters 

446 ---------- 

447 prices: 

448 Historical price DataFrame. Must contain a ``'date'`` column and 

449 at least one numeric asset column with strictly positive, 

450 non-monotonic values. 

451 mu: 

452 Expected-return signal DataFrame aligned row-by-row with 

453 ``prices``. 

454 cfg: 

455 Engine configuration. Both `EwmaShrinkConfig` 

456 and `SlidingWindowConfig` are supported. 

457 

458 Returns: 

459 ------- 

460 BasanosStream 

461 A stream instance whose `step` method is ready to accept the 

462 row immediately following the last warmup row. 

463 

464 Notes: 

465 ------ 

466 **Short-warmup behaviour with** ``SlidingWindowConfig``: when 

467 ``len(prices) < cfg.covariance_config.window``, the internal rolling 

468 buffer (``sw_ret_buf``) is NaN-padded for the missing prefix rows. 

469 `step` returns ``StepResult(status="warmup")`` for each of the 

470 first ``window - len(prices)`` calls, exactly matching the EWM warmup 

471 semantics. By the time `step` returns the first non-warmup 

472 result the buffer contains only real data — no NaN-padded rows remain. 

473 

474 Raises: 

475 ------ 

476 MissingDateColumnError 

477 If ``'date'`` is absent from ``prices``. 

478 """ 

479 # 1. Validate ------------------------------------------------------- 

480 if "date" not in prices.columns: 

481 raise MissingDateColumnError("prices") 

482 

483 # 2. Build the engine on the full warmup batch ---------------------- 

484 # Import here to avoid a circular dependency at module level. 

485 from .optimizer import BasanosEngine 

486 

487 engine = BasanosEngine(prices=prices, mu=mu, cfg=cfg) 

488 assets = engine.assets 

489 n_assets = len(assets) 

490 n_rows = prices.height 

491 prices_np = prices.select(assets).to_numpy() # (n_rows, n_assets) 

492 

493 # 3. Extract mode-specific state from WarmupState -------------------- 

494 ws = engine.warmup_state() 

495 if isinstance(cfg.covariance_config, EwmaShrinkConfig): 

496 # EWM: seed the growing history buffer from engine.ret_adj so that 

497 # each subsequent step() can call ewm_covariance over the full history. 

498 ret_adj_np = engine.ret_adj.select(assets).to_numpy() 

499 corr_ret_buf: np.ndarray | None = ret_adj_np 

500 sw_ret_buf: np.ndarray | None = None 

501 else: 

502 # SW: carry the last W vol-adjusted returns as a rolling buffer. 

503 sw_config = cfg.covariance_config 

504 win_w = sw_config.window 

505 ret_adj_np = engine.ret_adj.select(assets).to_numpy() # (n_rows, N) 

506 if n_rows >= win_w: 

507 sw_ret_buf = ret_adj_np[-win_w:].copy() 

508 else: 

509 sw_ret_buf = np.full((win_w, n_assets), np.nan) 

510 sw_ret_buf[-n_rows:] = ret_adj_np 

511 corr_ret_buf = None 

512 

513 # 4. Derive EWMA volatility accumulators (vectorised) --------------- 

514 # Both log-return (for vol_adj) and pct-return (for vola) use the 

515 # same beta = (vola-1)/vola. NaN observations (leading NaN at row 0 

516 # from diff/pct_change) are skipped — the filter input is 0 for NaN 

517 # rows and the weight accumulator (s_w) only increments for finite 

518 # observations, matching Polars' effective behaviour for a 

519 # leading-NaN series. 

520 # 

521 # Delegate to the shared helper _ewm_vol_accumulators_from_batch so 

522 # that the batch and incremental recurrences share a single definition. 

523 beta_vola: float = (cfg.vola - 1) / cfg.vola 

524 beta_vola_sq: float = beta_vola**2 

525 

526 log_ret = np.full((n_rows, n_assets), np.nan, dtype=float) 

527 pct_ret = np.full((n_rows, n_assets), np.nan, dtype=float) 

528 if n_rows > 1: 

529 with np.errstate(divide="ignore", invalid="ignore"): 

530 log_ret[1:] = np.log(prices_np[1:] / prices_np[:-1]) 

531 pct_ret[1:] = prices_np[1:] / prices_np[:-1] - 1.0 

532 

533 vola_s_x, vola_s_x2, vola_s_w, vola_s_w2, vola_count = _ewm_vol_accumulators_from_batch( 

534 log_ret, beta_vola, beta_vola_sq 

535 ) 

536 pct_s_x, pct_s_x2, pct_s_w, pct_s_w2, pct_count = _ewm_vol_accumulators_from_batch( 

537 pct_ret, beta_vola, beta_vola_sq 

538 ) 

539 

540 # 5. Extract prev_cash_pos from WarmupState -------------------------- 

541 prev_cash_pos: np.ndarray = ws.prev_cash_pos 

542 prev_price: np.ndarray = prices_np[-1].copy() 

543 

544 # 6. Construct _StreamState and return ------------------------------ 

545 state = _StreamState( 

546 corr_ret_buf=corr_ret_buf, 

547 vola_s_x=vola_s_x, 

548 vola_s_x2=vola_s_x2, 

549 vola_s_w=vola_s_w, 

550 vola_s_w2=vola_s_w2, 

551 vola_count=vola_count, 

552 pct_s_x=pct_s_x, 

553 pct_s_x2=pct_s_x2, 

554 pct_s_w=pct_s_w, 

555 pct_s_w2=pct_s_w2, 

556 pct_count=pct_count, 

557 prev_price=prev_price, 

558 prev_cash_pos=prev_cash_pos, 

559 step_count=n_rows, 

560 sw_ret_buf=sw_ret_buf, 

561 ) 

562 return cls(cfg=cfg, assets=assets, state=state) 

563 

564 # ------------------------------------------------------------------ 

565 # step 

566 # ------------------------------------------------------------------ 

567 

568 @staticmethod 

569 def _warmup_threshold(cfg: BasanosConfig) -> int: 

570 """Return the step count at which warmup ends for the configured mode.""" 

571 if isinstance(cfg.covariance_config, SlidingWindowConfig): 

572 return cfg.covariance_config.window 

573 return cfg.corr 

574 

575 @staticmethod 

576 def _persist_state( 

577 state: _StreamState, 

578 *, 

579 corr_ret_buf: np.ndarray | None, 

580 vola_s_x: np.ndarray, 

581 vola_s_x2: np.ndarray, 

582 vola_s_w: np.ndarray, 

583 vola_s_w2: np.ndarray, 

584 vola_count: np.ndarray, 

585 pct_s_x: np.ndarray, 

586 pct_s_x2: np.ndarray, 

587 pct_s_w: np.ndarray, 

588 pct_s_w2: np.ndarray, 

589 pct_count: np.ndarray, 

590 new_price: np.ndarray, 

591 new_cash_pos: np.ndarray | None = None, 

592 ) -> None: 

593 """Persist accumulators, last-seen vectors, and increment step count.""" 

594 state.corr_ret_buf = corr_ret_buf 

595 state.vola_s_x = vola_s_x 

596 state.vola_s_x2 = vola_s_x2 

597 state.vola_s_w = vola_s_w 

598 state.vola_s_w2 = vola_s_w2 

599 state.vola_count = vola_count 

600 state.pct_s_x = pct_s_x 

601 state.pct_s_x2 = pct_s_x2 

602 state.pct_s_w = pct_s_w 

603 state.pct_s_w2 = pct_s_w2 

604 state.pct_count = pct_count 

605 state.prev_price = new_price.copy() 

606 if new_cash_pos is not None: 

607 state.prev_cash_pos = new_cash_pos.copy() 

608 state.step_count += 1 

609 

610 @staticmethod 

611 def _warmup_result(n_assets: int, date: Any) -> StepResult: 

612 """Build a standard warmup ``StepResult`` payload.""" 

613 return StepResult( 

614 date=date, 

615 cash_position=np.full(n_assets, np.nan), 

616 status=SolveStatus.WARMUP, 

617 vola=np.full(n_assets, np.nan), 

618 ) 

619 

620 def _solve_sliding_window_position( 

621 self, 

622 *, 

623 cfg: BasanosConfig, 

624 state: _StreamState, 

625 mask: np.ndarray, 

626 new_m: np.ndarray, 

627 vola_vec: np.ndarray, 

628 n_assets: int, 

629 date: Any, 

630 ) -> tuple[np.ndarray, SolveStatus]: 

631 """Solve one step in SlidingWindow mode and return cash position + status.""" 

632 from cvx.linalg import SingularMatrixError 

633 

634 new_cash_pos = np.full(n_assets, np.nan, dtype=float) 

635 status = SolveStatus.DEGENERATE 

636 sw_config = cast(SlidingWindowConfig, cfg.covariance_config) 

637 if not mask.any(): 

638 return new_cash_pos, status 

639 

640 win_w = sw_config.window 

641 win_k = sw_config.n_factors 

642 sw_ret_buf = cast(np.ndarray, state.sw_ret_buf) 

643 window_ret = np.where( 

644 np.isfinite(sw_ret_buf[:, mask]), 

645 sw_ret_buf[:, mask], 

646 0.0, 

647 ) 

648 n_sub = int(mask.sum()) 

649 k_eff = min(win_k, win_w, n_sub) 

650 if sw_config.max_components is not None: 

651 k_eff = min(k_eff, sw_config.max_components) 

652 try: 

653 fm = FactorModel.from_returns(window_ret, k=k_eff) 

654 except (np.linalg.LinAlgError, ValueError) as exc: 

655 _logger.debug("Sliding window SVD failed at date=%s: %s", date, exc) 

656 new_cash_pos[mask] = 0.0 

657 return new_cash_pos, status 

658 

659 expected_mu = np.nan_to_num(new_m[mask]) 

660 if np.allclose(expected_mu, 0.0): 

661 new_cash_pos[mask] = 0.0 

662 return new_cash_pos, SolveStatus.ZERO_SIGNAL 

663 

664 try: 

665 x = fm.solve(expected_mu) 

666 denom_val = float(np.sqrt(max(0.0, float(np.dot(expected_mu, x))))) 

667 except (SingularMatrixError, np.linalg.LinAlgError) as exc: 

668 _logger.warning("Woodbury solve failed at date=%s: %s", date, exc) 

669 new_cash_pos[mask] = 0.0 

670 return new_cash_pos, status 

671 

672 if not np.isfinite(denom_val) or denom_val <= cfg.denom_tol: 

673 _logger.warning( 

674 "Positions zeroed at date=%s (sliding_window): normalisation " 

675 "denominator degenerate (denom=%s, denom_tol=%s).", 

676 date, 

677 denom_val, 

678 cfg.denom_tol, 

679 ) 

680 new_cash_pos[mask] = 0.0 

681 return new_cash_pos, status 

682 

683 risk_pos = x / denom_val 

684 vola_sub = vola_vec[mask] 

685 with np.errstate(invalid="ignore"): 

686 new_cash_pos[mask] = risk_pos / vola_sub 

687 return new_cash_pos, SolveStatus.VALID 

688 

689 @staticmethod 

690 def _solve_ewma_position( 

691 *, 

692 cfg: BasanosConfig, 

693 state: _StreamState, 

694 corr_ret_buf: np.ndarray, 

695 mask: np.ndarray, 

696 new_m: np.ndarray, 

697 vola_vec: np.ndarray, 

698 assets: list[str], 

699 n_assets: int, 

700 date: Any, 

701 ) -> tuple[np.ndarray, SolveStatus]: 

702 """Solve one step in EWMA mode and return cash position + status.""" 

703 new_cash_pos = np.full(n_assets, np.nan, dtype=float) 

704 buf = corr_ret_buf # (T, N) — already includes the new row 

705 span = 2 * cfg.corr + 1 

706 t = buf.shape[0] 

707 cols = [pl.Series(a, buf[:, i]).fill_nan(None) for i, a in enumerate(assets)] 

708 pl_df = pl.DataFrame([pl.Series("t", list(range(t))), *cols]) 

709 cov_dict = ewm_covariance(pl_df, assets=assets, index_col="t", window=span, warmup=cfg.corr) 

710 if not cov_dict: 

711 corr = np.full((n_assets, n_assets), np.nan) 

712 else: 

713 # keys are the integer ``t`` index values built from ``range(t)`` above 

714 latest = max(cov_dict, key=lambda k: cast("int", k)) 

715 corr = cov_to_corr(cov_dict[latest], cfg.min_corr_denom) 

716 matrix = shrink2id(corr, lamb=cfg.shrink) 

717 expected_mu, early = _SolveMixin._row_early_check(state.step_count, date, mask, new_m) 

718 if early is not None: 

719 _, _, _, pos, status = early 

720 new_cash_pos[mask] = pos 

721 return new_cash_pos, status 

722 

723 corr_sub = matrix[np.ix_(mask, mask)] 

724 _, _, _, pos, status = _SolveMixin._compute_position( 

725 state.step_count, date, mask, expected_mu, MatrixBundle(matrix=corr_sub), cfg.denom_tol 

726 ) 

727 if status == SolveStatus.VALID: 

728 new_cash_pos[mask] = _SolveMixin._scale_to_cash(cast(np.ndarray, pos), vola_vec[mask]) 

729 else: 

730 new_cash_pos[mask] = pos 

731 return new_cash_pos, status 

732 

733 def step( 

734 self, 

735 new_prices: np.ndarray | dict[str, float], 

736 new_mu: np.ndarray | dict[str, float], 

737 date: Any = None, 

738 ) -> StepResult: 

739 """Advance the stream by one row and return the new optimised position. 

740 

741 Parameters 

742 ---------- 

743 new_prices: 

744 Per-asset prices for the new timestep. Either a numpy array of 

745 shape ``(N,)`` (assets ordered as in `assets`) or a dict 

746 mapping asset names to price values. 

747 new_mu: 

748 Per-asset expected-return signals, same format as ``new_prices``. 

749 date: 

750 Timestamp for this step (stored in `date` 

751 verbatim; not used in any computation). 

752 

753 Returns: 

754 ------- 

755 StepResult 

756 Frozen dataclass with ``cash_position``, ``vola``, ``status``, and 

757 ``date`` for this timestep. 

758 """ 

759 cfg = self._cfg 

760 assets = self._assets 

761 state = self._state 

762 n_assets = len(assets) 

763 

764 # ── Check if still in the warmup period ────────────────────────────── 

765 # step_count is initialised to n_rows in from_warmup. 

766 # 

767 # EwmaShrinkConfig: in_warmup is True for the first (cfg.corr - n_rows) 

768 # calls when the warmup batch was shorter than cfg.corr (not enough rows 

769 # to populate the EWM correlation matrix). 

770 # 

771 # SlidingWindowConfig: in_warmup is True for the first (window - n_rows) 

772 # calls when the warmup batch was shorter than the window. During this 

773 # period sw_ret_buf still contains NaN-padded prefix rows; each step 

774 # shifts one NaN out and appends a real row, so the buffer is fully 

775 # populated with real data exactly when in_warmup becomes False. 

776 # 

777 # In both modes all accumulators are still updated during warmup so that 

778 # the state is ready the moment the warmup period ends. 

779 _warmup_thresh = self._warmup_threshold(cfg) 

780 in_warmup: bool = state.step_count < _warmup_thresh 

781 

782 # ── Resolve inputs to (N,) float64 arrays ────────────────────────── 

783 new_p = _resolve_step_vector(new_prices, assets, n_assets, "new_prices") 

784 new_m = _resolve_step_vector(new_mu, assets, n_assets, "new_mu") 

785 

786 prev_p = state.prev_price 

787 beta_vola: float = (cfg.vola - 1) / cfg.vola 

788 beta_vola_sq: float = beta_vola**2 

789 

790 # ── Compute new log-returns and pct-returns ───────────────────────── 

791 with np.errstate(divide="ignore", invalid="ignore"): 

792 ratio = np.where( 

793 np.isfinite(new_p) & np.isfinite(prev_p) & (prev_p > 0), 

794 new_p / prev_p, 

795 np.nan, 

796 ) 

797 log_ret = np.log(ratio) 

798 pct_ret = ratio - 1.0 

799 

800 # ── Update log-return EWMA accumulators ──────────────────────────── 

801 fin_log = np.isfinite(log_ret) 

802 vola_s_x = beta_vola * state.vola_s_x + np.where(fin_log, log_ret, 0.0) 

803 vola_s_x2 = beta_vola * state.vola_s_x2 + np.where(fin_log, log_ret**2, 0.0) 

804 vola_s_w = beta_vola * state.vola_s_w + fin_log.astype(float) 

805 vola_s_w2 = beta_vola_sq * state.vola_s_w2 + fin_log.astype(float) 

806 vola_count = state.vola_count + fin_log.astype(int) 

807 

808 # ── Update pct-return EWMA accumulators ──────────────────────────── 

809 fin_pct = np.isfinite(pct_ret) 

810 pct_s_x = beta_vola * state.pct_s_x + np.where(fin_pct, pct_ret, 0.0) 

811 pct_s_x2 = beta_vola * state.pct_s_x2 + np.where(fin_pct, pct_ret**2, 0.0) 

812 pct_s_w = beta_vola * state.pct_s_w + fin_pct.astype(float) 

813 pct_s_w2 = beta_vola_sq * state.pct_s_w2 + fin_pct.astype(float) 

814 pct_count = state.pct_count + fin_pct.astype(int) 

815 

816 # ── Compute vol-adjusted return (for the correlation IIR input) ───── 

817 log_vol = _ewm_std_from_state(vola_s_x, vola_s_x2, vola_s_w, vola_s_w2, vola_count, min_samples=1) 

818 # Divide; std == 0 yields ±inf → clipped to ±cfg.clip (matches Polars) 

819 with np.errstate(divide="ignore", invalid="ignore"): 

820 vol_adj_val = np.where( 

821 fin_log, 

822 np.clip(log_ret / log_vol, -cfg.clip, cfg.clip), 

823 np.nan, 

824 ) 

825 

826 # ── Mode-specific correlation state update ─────────────────────────── 

827 if isinstance(cfg.covariance_config, SlidingWindowConfig): 

828 # SW: shift the rolling window buffer in-place and append this row. 

829 buf = cast(np.ndarray, state.sw_ret_buf) # (W, N), already owned by state 

830 buf[:-1] = buf[1:] 

831 buf[-1] = vol_adj_val 

832 corr_ret_buf = state.corr_ret_buf # None for SW; pass through 

833 else: 

834 # EWM: append new vol-adjusted return to the growing history buffer. 

835 new_row = vol_adj_val[np.newaxis] # (1, N) 

836 corr_ret_buf = np.vstack([cast(np.ndarray, state.corr_ret_buf), new_row]) 

837 

838 # ── Early return during EWM warmup period ─────────────────────────── 

839 # All accumulators are already updated above; skip the O(N²) matrix 

840 # reconstruction and O(N³) Cholesky solve which are wasteful during 

841 # warmup — the computed positions would be discarded anyway. 

842 if in_warmup: 

843 self._persist_state( 

844 state, 

845 corr_ret_buf=corr_ret_buf, 

846 vola_s_x=vola_s_x, 

847 vola_s_x2=vola_s_x2, 

848 vola_s_w=vola_s_w, 

849 vola_s_w2=vola_s_w2, 

850 vola_count=vola_count, 

851 pct_s_x=pct_s_x, 

852 pct_s_x2=pct_s_x2, 

853 pct_s_w=pct_s_w, 

854 pct_s_w2=pct_s_w2, 

855 pct_count=pct_count, 

856 new_price=new_p, 

857 ) 

858 return self._warmup_result(n_assets, date) 

859 

860 # ── Compute EWMA volatility (pct-return std) — shared ─────────────── 

861 vola_vec = _ewm_std_from_state(pct_s_x, pct_s_x2, pct_s_w, pct_s_w2, pct_count, min_samples=cfg.vola) 

862 

863 # ── Solve for position ─────────────────────────────────────────────── 

864 mask = np.isfinite(new_p) 

865 if isinstance(cfg.covariance_config, SlidingWindowConfig): 

866 new_cash_pos, status = self._solve_sliding_window_position( 

867 cfg=cfg, 

868 state=state, 

869 mask=mask, 

870 new_m=new_m, 

871 vola_vec=vola_vec, 

872 n_assets=n_assets, 

873 date=date, 

874 ) 

875 else: 

876 new_cash_pos, status = self._solve_ewma_position( 

877 cfg=cfg, 

878 state=state, 

879 corr_ret_buf=cast(np.ndarray, corr_ret_buf), 

880 mask=mask, 

881 new_m=new_m, 

882 vola_vec=vola_vec, 

883 assets=list(self._assets), 

884 n_assets=n_assets, 

885 date=date, 

886 ) 

887 

888 # ── Apply turnover constraint ───────────────────────────────────────── 

889 if cfg.max_turnover is not None and status == SolveStatus.VALID: 

890 new_cash_pos[mask] = _SolveMixin._apply_turnover_constraint( 

891 new_cash_pos[mask], 

892 state.prev_cash_pos[mask], 

893 cfg.max_turnover, 

894 ) 

895 

896 # ── Persist updated state ─────────────────────────────────────────── 

897 self._persist_state( 

898 state, 

899 corr_ret_buf=corr_ret_buf, 

900 vola_s_x=vola_s_x, 

901 vola_s_x2=vola_s_x2, 

902 vola_s_w=vola_s_w, 

903 vola_s_w2=vola_s_w2, 

904 vola_count=vola_count, 

905 pct_s_x=pct_s_x, 

906 pct_s_x2=pct_s_x2, 

907 pct_s_w=pct_s_w, 

908 pct_s_w2=pct_s_w2, 

909 pct_count=pct_count, 

910 new_price=new_p, 

911 new_cash_pos=new_cash_pos, 

912 ) 

913 

914 return StepResult( 

915 date=date, 

916 cash_position=new_cash_pos, 

917 status=status, 

918 vola=vola_vec, 

919 ) 

920 

def save(self, path: str | os.PathLike[str]) -> None:
    """Write this stream's full state to a ``.npz`` archive at *path*.

    The archive holds a format-version tag, the configuration serialised as
    JSON, the asset list, and one entry per `_StreamState` field. All of them
    are written with a single `savez` call. Loading the archive with `load`
    yields a stream whose `step` output is bit-for-bit identical to this one.

    Args:
        path: Target file path. `savez` adds the ``.npz`` suffix when it is
            missing.

    Examples:
        >>> import tempfile, pathlib, numpy as np
        >>> import polars as pl
        >>> from datetime import date, timedelta
        >>> from basanos.math import BasanosConfig, BasanosStream
        >>> rng = np.random.default_rng(0)
        >>> n = 60
        >>> end = date(2024, 1, 1) + timedelta(days=n - 1)
        >>> dates = pl.date_range(
        ...     date(2024, 1, 1), end, interval="1d", eager=True
        ... )
        >>> prices = pl.DataFrame({
        ...     "date": dates,
        ...     "A": np.cumprod(1 + rng.normal(0.001, 0.02, n)) * 100.0,
        ...     "B": np.cumprod(1 + rng.normal(0.001, 0.02, n)) * 150.0,
        ... })
        >>> mu = pl.DataFrame({
        ...     "date": dates,
        ...     "A": rng.normal(0, 0.5, n),
        ...     "B": rng.normal(0, 0.5, n),
        ... })
        >>> cfg = BasanosConfig(vola=5, corr=10, clip=3.0, shrink=0.5, aum=1e6)
        >>> stream = BasanosStream.from_warmup(prices, mu, cfg)
        >>> with tempfile.TemporaryDirectory() as tmp:
        ...     p = pathlib.Path(tmp) / "stream.npz"
        ...     stream.save(p)
        ...     restored = BasanosStream.load(p)
        ...     restored.assets == stream.assets
        True
    """
    optional_buffers = ("sw_ret_buf", "corr_ret_buf")
    snapshot = self._state

    def _encode(name: str, value: Any) -> Any:
        # Buffers that may be None are stored as an empty (0, 0) placeholder,
        # so the key is always present and load() can tell the case apart.
        if name in optional_buffers:
            return np.empty((0, 0), dtype=float) if value is None else value
        # The integer step counter is wrapped as a 0-d array for savez.
        if name == "step_count":
            return np.array(value)
        return value

    # The keys are derived from the dataclass itself, so any field added to
    # _StreamState is persisted without editing this method.
    payload: dict[str, Any] = {
        fld.name: _encode(fld.name, getattr(snapshot, fld.name))
        for fld in dataclasses.fields(_StreamState)
    }
    np.savez(
        path,
        format_version=np.array(_SAVE_FORMAT_VERSION),
        cfg_json=np.array(self._cfg.model_dump_json()),
        assets=np.array(self._assets),
        **payload,
    )

984 

@classmethod
def load(cls, path: str | os.PathLike[str]) -> BasanosStream:
    """Restore a stream previously saved with `save`.

    Args:
        path: Path to a ``.npz`` archive written by `save`.

    Returns:
        A `BasanosStream` whose `step` output is
        bit-for-bit identical to the original stream at the time
        `save` was called.

    Raises:
        ValueError: If the archive has no ``format_version`` entry, or its
            version differs from ``_SAVE_FORMAT_VERSION``.
        StreamStateCorruptError: If any key listed in ``_REQUIRED_KEYS`` is
            absent from the archive.

    Examples:
        >>> import tempfile, pathlib, numpy as np
        >>> import polars as pl
        >>> from datetime import date, timedelta
        >>> from basanos.math import BasanosConfig, BasanosStream
        >>> rng = np.random.default_rng(1)
        >>> n = 60
        >>> end = date(2024, 1, 1) + timedelta(days=n - 1)
        >>> dates = pl.date_range(
        ...     date(2024, 1, 1), end, interval="1d", eager=True
        ... )
        >>> prices = pl.DataFrame({
        ...     "date": dates,
        ...     "A": np.cumprod(1 + rng.normal(0.001, 0.02, n)) * 100.0,
        ...     "B": np.cumprod(1 + rng.normal(0.001, 0.02, n)) * 150.0,
        ... })
        >>> mu = pl.DataFrame({
        ...     "date": dates,
        ...     "A": rng.normal(0, 0.5, n),
        ...     "B": rng.normal(0, 0.5, n),
        ... })
        >>> cfg = BasanosConfig(vola=5, corr=10, clip=3.0, shrink=0.5, aum=1e6)
        >>> stream = BasanosStream.from_warmup(prices, mu, cfg)
        >>> with tempfile.TemporaryDirectory() as tmp:
        ...     p = pathlib.Path(tmp) / "stream.npz"
        ...     stream.save(p)
        ...     restored = BasanosStream.load(p)
        ...     restored.assets == stream.assets
        True
    """
    with np.load(path, allow_pickle=False) as data:
        if "format_version" not in data:
            raise ValueError(  # noqa: TRY003
                "Stream file is missing a format version tag. "
                "It was written with an incompatible version of BasanosStream. "
                "Re-generate it via BasanosStream.from_warmup()."
            )
        found = int(data["format_version"])
        if found != _SAVE_FORMAT_VERSION:
            raise ValueError(  # noqa: TRY003
                f"Stream file was written with format version {found}, "
                f"but the current version is {_SAVE_FORMAT_VERSION}. "
                "Re-generate it via BasanosStream.from_warmup()."
            )
        # Validate that every required key is present. This catches archives
        # produced by an older codebase that lacks a newly added field, and
        # archives that have been edited by hand. It raises a descriptive
        # error instead of a bare KeyError.
        missing = _REQUIRED_KEYS - frozenset(data.files)
        if missing:
            raise StreamStateCorruptError(missing)
        cfg = BasanosConfig.model_validate_json(data["cfg_json"].item())
        # Indexing a NumPy string array yields numpy.str_ scalars. Convert
        # them to builtin str so the restored asset list has the same
        # element type as the list that was saved.
        assets: list[str] = [str(asset) for asset in data["assets"]]
        state_kwargs: dict[str, Any] = {}
        for field in dataclasses.fields(_StreamState):
            raw = data[field.name]
            if field.name in ("sw_ret_buf", "corr_ret_buf"):
                # save() writes exactly a (0, 0) placeholder for None. Match
                # on that shape, not on size, so a genuinely empty buffer such
                # as (0, N) is restored as an array rather than as None.
                state_kwargs[field.name] = None if raw.shape == (0, 0) else raw
            elif field.name == "step_count":
                state_kwargs[field.name] = int(raw)
            else:
                state_kwargs[field.name] = raw
        state = _StreamState(**state_kwargs)
        return cls(cfg=cfg, assets=assets, state=state)