Coverage for src / basanos / math / _engine_solve.py: 100%

190 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-02 17:47 +0000

1"""Solve/position mixin for BasanosEngine. 

2 

3This private module contains :class:`_SolveMixin`, which provides the 

4``_iter_matrices`` and ``_iter_solve`` generator methods. Separating them 

5from :mod:`basanos.math.optimizer` keeps the engine facade lean and makes 

6the per-timestamp solve logic independently readable and testable. 

7""" 

8 

9from __future__ import annotations 

10 

11import dataclasses 

12import datetime 

13import logging 

14from collections.abc import Generator 

15from enum import StrEnum 

16from typing import TYPE_CHECKING, TypeAlias, cast 

17 

18import numpy as np 

19 

20from ..exceptions import SingularMatrixError 

21from ._config import EwmaShrinkConfig, SlidingWindowConfig 

22from ._ewm_corr import _ewm_corr_with_final_state, _EwmCorrState 

23from ._factor_model import FactorModel 

24from ._linalg import inv_a_norm, solve 

25from ._signal import shrink2id 

26 

27if TYPE_CHECKING: 

28 from ._engine_protocol import _EngineProtocol 

29 

30_logger = logging.getLogger(__name__) 

31 

32 

33class SolveStatus(StrEnum): 

34 """Solver outcome labels for each timestamp. 

35 

36 Since :class:`SolveStatus` inherits from :class:`str` via ``StrEnum``, 

37 values compare equal to their string equivalents (e.g. 

38 ``SolveStatus.VALID == "valid"``), preserving backward compatibility 

39 with code that matches on string literals. 

40 

41 Attributes: 

42 WARMUP: Insufficient history for the sliding-window covariance mode. 

43 ZERO_SIGNAL: The expected-return vector was all-zero; positions zeroed. 

44 DEGENERATE: Normalisation denominator was non-finite, solve failed, or 

45 no asset had a finite price; positions zeroed for safety. 

46 VALID: Linear system solved successfully; positions are non-trivially 

47 non-zero. 

48 """ 

49 

50 WARMUP = "warmup" 

51 ZERO_SIGNAL = "zero_signal" 

52 DEGENERATE = "degenerate" 

53 VALID = "valid" 

54 

55 

@dataclasses.dataclass(frozen=True)
class MatrixBundle:
    """Immutable wrapper around a covariance matrix plus mode-specific extras.

    :meth:`_SolveMixin._compute_position` receives this bundle rather than a
    bare array.  That indirection lets future covariance modes (e.g.
    DCC-GARCH, RMT-cleaned) add further fields to the bundle while the method
    signature stays the same.

    Attributes:
        matrix: Covariance sub-matrix of shape ``(n_active, n_active)`` for
            the assets that are active at a given timestamp.
    """

    matrix: np.ndarray

71 

72 

#: Yield type for :meth:`_SolveMixin._iter_matrices`:
#: ``(i, t, mask, bundle)`` where ``bundle`` is ``None`` during warmup/no-data.
MatrixYield: TypeAlias = tuple[
    int,  # i: row index
    datetime.date,  # t: timestamp
    np.ndarray,  # mask: boolean active-asset mask
    MatrixBundle | None,  # bundle: covariance bundle, or None
]

#: Yield type for :meth:`_SolveMixin._iter_solve`:
#: ``(i, t, mask, pos_or_none, status)`` where ``pos_or_none`` is ``None`` only for warmup rows.
SolveYield: TypeAlias = tuple[
    int,  # i: row index
    datetime.date,  # t: timestamp
    np.ndarray,  # mask: boolean active-asset mask
    np.ndarray | None,  # pos_or_none: positions for the active assets, or None
    SolveStatus,  # status: solver outcome for the row
]

80 

81 

@dataclasses.dataclass(frozen=True)
class WarmupState:
    """Snapshot of solver state after a complete batch solve.

    :meth:`BasanosEngine.warmup_state` produces this object and
    :meth:`BasanosStream.from_warmup` consumes it to initialise the streaming
    state, so the stream never has to touch the private
    :meth:`~_SolveMixin._iter_solve` generator itself.

    Attributes:
        prev_cash_pos: Cash positions at the last warmup row, shape
            ``(n_assets,)``.  ``NaN`` marks assets that were still inside
            their own warmup period.
        corr_iir_state: Final IIR filter memory from the EWM correlation
            pass, or ``None`` when :class:`~basanos.math.SlidingWindowConfig`
            is in use.  :meth:`BasanosStream.from_warmup` reads these arrays
            to seed the incremental ``lfilter`` state, avoiding a second pass
            over the warmup data.
    """

    prev_cash_pos: np.ndarray
    corr_iir_state: _EwmCorrState | None = None

103 

104 

class _SolveMixin:
    """Mixin that provides ``_iter_matrices`` and ``_iter_solve`` generators.

    Consumers must also inherit from (or satisfy the interface of)
    :class:`~basanos.math._engine_protocol._EngineProtocol` so that
    ``self.assets``, ``self.prices``, ``self.mu``, ``self.cfg``, ``self.cor``,
    ``self.ret_adj`` and ``self.vola`` are all available.
    """

    @staticmethod
    def _compute_mask(prices_row: np.ndarray) -> np.ndarray:
        """Return boolean mask indicating which assets have finite prices in the given row."""
        return np.isfinite(prices_row)

    @staticmethod
    def _check_signal(mu: np.ndarray, mask: np.ndarray) -> SolveStatus | None:
        """Return ``ZERO_SIGNAL`` when the masked expected-return vector is all-zero.

        ``NaN`` entries are treated as zero and the comparison uses
        :func:`numpy.allclose` with its default tolerances.

        Returns ``None`` when the signal is non-trivially non-zero, indicating
        that the caller should proceed to the linear solve.
        """
        if np.allclose(np.nan_to_num(mu[mask]), 0.0):
            return SolveStatus.ZERO_SIGNAL
        return None

    @staticmethod
    def _scale_to_cash(pos: np.ndarray, vola_active: np.ndarray) -> np.ndarray:
        """Convert raw solver positions to cash-adjusted positions.

        Divides *pos* by *vola_active* (volatility for the active asset subset)
        to get cash positions.  ``np.errstate(invalid="ignore")`` is applied
        internally so invalid-operation floating-point warnings are suppressed
        during the division.
        """
        with np.errstate(invalid="ignore"):
            return pos / vola_active

    @staticmethod
    def _row_early_check(
        i: int,
        t: datetime.date,
        mask: np.ndarray,
        mu_row: np.ndarray,
    ) -> tuple[np.ndarray, SolveYield | None]:
        """Validate the price mask and expected-return signal for a single row.

        Returns an ``(expected_mu, early_yield)`` pair.  When ``early_yield``
        is not ``None``, the caller should ``yield early_yield; continue``
        immediately — the row is either degenerate (empty mask) or has an
        all-zero signal.  When ``early_yield`` is ``None`` the row is ready
        for the mode-specific solve step.

        Args:
            i: Row index.
            t: Timestamp.
            mask: Boolean array of shape ``(n_assets,)`` indicating finite prices.
            mu_row: Expected-return row of shape ``(n_assets,)``.

        Returns:
            tuple: ``(expected_mu, early_yield)`` where ``expected_mu`` is
            ``np.nan_to_num(mu_row[mask])`` and ``early_yield`` is either a
            complete :data:`SolveYield` tuple (when the caller should yield
            and continue) or ``None`` (when the caller should proceed to solve).
        """
        if not mask.any():
            # No asset has a finite price: nothing to solve for this row.
            return np.zeros(0), (i, t, mask, np.zeros(0), SolveStatus.DEGENERATE)
        expected_mu = np.nan_to_num(mu_row[mask])
        sig_status = _SolveMixin._check_signal(mu_row, mask)
        if sig_status is not None:
            return expected_mu, (i, t, mask, np.zeros_like(expected_mu), sig_status)
        return expected_mu, None

    @staticmethod
    def _denom_guard_yield(
        i: int,
        t: datetime.date,
        mask: np.ndarray,
        expected_mu: np.ndarray,
        pos_raw: np.ndarray,
        denom: float,
        denom_tol: float,
    ) -> SolveYield:
        """Apply the normalisation-denominator guard and return the appropriate yield tuple.

        Emits a :data:`~logging.WARNING` and returns a
        :attr:`~SolveStatus.DEGENERATE` yield when *denom* is non-finite or at
        or below *denom_tol*; otherwise returns a :attr:`~SolveStatus.VALID`
        yield with normalised positions ``pos_raw / denom``.

        Args:
            i: Row index.
            t: Timestamp.
            mask: Boolean asset mask of shape ``(n_assets,)``.
            expected_mu: Masked expected-return vector of shape ``(n_active,)``.
            pos_raw: Raw (pre-normalisation) position vector of shape ``(n_active,)``.
            denom: Computed normalisation denominator.
            denom_tol: Tolerance threshold below which *denom* is treated as degenerate.

        Returns:
            SolveYield: Either a degenerate or valid ``(i, t, mask, pos, status)`` tuple.
        """
        n_active = len(expected_mu)
        if not np.isfinite(denom) or denom <= denom_tol:
            _logger.warning(
                "Positions zeroed at t=%s: normalisation denominator is degenerate "
                "(denom=%s, denom_tol=%s). Check signal magnitude and covariance matrix.",
                t,
                denom,
                denom_tol,
                extra={
                    "context": {
                        "t": str(t),
                        "denom": denom,
                        "denom_tol": denom_tol,
                    }
                },
            )
            return i, t, mask, np.zeros(n_active), SolveStatus.DEGENERATE
        return i, t, mask, pos_raw / denom, SolveStatus.VALID

    @staticmethod
    def _compute_position(
        i: int,
        t: datetime.date,
        mask: np.ndarray,
        expected_mu: np.ndarray,
        bundle: MatrixBundle,
        denom_tol: float,
    ) -> SolveYield:
        """Shared solve step used by both covariance branches.

        Computes the normalisation denominator via :func:`~basanos.math._linalg.inv_a_norm`
        and solves the linear system via :func:`~basanos.math._linalg.solve`, then
        delegates to :meth:`_denom_guard_yield`.  Handles
        :exc:`~basanos.exceptions.SingularMatrixError` from both calls.

        Accepting a :class:`MatrixBundle` instead of a raw array means future
        covariance modes can attach auxiliary state to the bundle without
        changing this method's signature.

        Args:
            i: Row index.
            t: Timestamp.
            mask: Boolean asset mask of shape ``(n_assets,)``.
            expected_mu: Masked expected-return vector of shape ``(n_active,)``.
            bundle: Covariance bundle whose ``matrix`` field is an
                ``(n_active, n_active)`` covariance matrix for the active assets.
            denom_tol: Tolerance threshold for the normalisation denominator.

        Returns:
            SolveYield: A degenerate or valid ``(i, t, mask, pos, status)`` tuple.
        """
        matrix = bundle.matrix
        try:
            denom = inv_a_norm(expected_mu, matrix)
        except SingularMatrixError:
            # A NaN denominator is routed to DEGENERATE by _denom_guard_yield.
            denom = float("nan")
        try:
            pos = solve(matrix, expected_mu)
        except SingularMatrixError:
            return i, t, mask, np.zeros_like(expected_mu), SolveStatus.DEGENERATE
        return _SolveMixin._denom_guard_yield(i, t, mask, expected_mu, pos, denom, denom_tol)

    @staticmethod
    def _apply_turnover_constraint(
        new_cash: np.ndarray,
        prev_cash: np.ndarray,
        max_turnover: float,
    ) -> np.ndarray:
        """Cap the L1 norm of the position change to *max_turnover*.

        When ``sum(|new_cash - prev_cash|) > max_turnover``, the delta is
        scaled back proportionally toward *prev_cash* so that the constraint
        is exactly met.  When the constraint is already satisfied the input is
        returned unchanged.

        Args:
            new_cash: Proposed cash positions after the solve step, shape
                ``(n_active,)`` — ``NaN`` values treated as zero.
            prev_cash: Cash positions at the previous step, shape
                ``(n_active,)`` — ``NaN`` values treated as zero.
            max_turnover: Maximum allowed L1 norm of the position change.

        Returns:
            np.ndarray: The (possibly scaled) new cash positions.
        """
        curr = np.nan_to_num(new_cash, nan=0.0)
        prev = np.nan_to_num(prev_cash, nan=0.0)
        delta = curr - prev
        total_delta = float(np.sum(np.abs(delta)))
        if total_delta > max_turnover:
            scale = max_turnover / total_delta
            return prev + delta * scale
        return new_cash

    def _replay_positions(
        self: _EngineProtocol,
        risk_pos_np: np.ndarray,
        cash_pos_np: np.ndarray,
        vola_np: np.ndarray,
    ) -> None:
        """Replay positions across all rows, filling position arrays.

        Iterates :meth:`_iter_solve`, writes risk and cash positions into the
        provided pre-allocated arrays.  Both arrays are mutated **in-place**.
        Rows whose position is ``None`` (warmup) are left untouched.

        When :attr:`BasanosConfig.max_turnover` is set, the L1 norm of the
        position change ``sum(|x_t - x_{t-1}|)`` is capped at that value by
        proportionally scaling the delta toward the previous position before
        writing to ``cash_pos_np``.

        Args:
            risk_pos_np: Pre-allocated ``(T, N)`` array for risk positions.
            cash_pos_np: Pre-allocated ``(T, N)`` array for cash positions.
            vola_np: ``(T, N)`` EWMA volatility array.
        """
        max_to: float | None = self.cfg.max_turnover
        for i, _t, mask, pos, _status in self._iter_solve():
            if pos is not None:
                new_cash = _SolveMixin._scale_to_cash(pos, vola_np[i, mask])
                # The first row has no predecessor, so the cap only applies from i = 1.
                if max_to is not None and i > 0:
                    new_cash = _SolveMixin._apply_turnover_constraint(new_cash, cash_pos_np[i - 1, mask], max_to)
                risk_pos_np[i, mask] = new_cash * vola_np[i, mask]
                cash_pos_np[i, mask] = new_cash

    def _iter_matrices(self: _EngineProtocol) -> Generator[MatrixYield, None, None]:
        r"""Yield ``(i, t, mask, bundle)`` for every timestamp.

        ``bundle`` is a :class:`MatrixBundle` wrapping the effective
        :math:`(n_{\text{sub}},\ n_{\text{sub}})` correlation matrix for the
        active assets (those with finite prices at timestamp *t*).  Yields
        ``None`` when no valid matrix is available (e.g., before the warm-up
        period has elapsed or when no assets have finite prices).

        The behaviour depends on :attr:`BasanosConfig.covariance_config`:

        * :class:`EwmaShrinkConfig`: Applies :func:`~basanos.math._signal.shrink2id` to
          the EWMA correlation matrix (same computation as
          :attr:`cash_position`).
        * :class:`SlidingWindowConfig`: Builds a
          :class:`~basanos.math._factor_model.FactorModel` from the last
          ``cfg.covariance_config.window`` rows of vol-adjusted returns and returns its
          :attr:`~basanos.math._factor_model.FactorModel.covariance`.  A fit that
          raises ``LinAlgError`` or ``ValueError`` is logged as a warning and
          yields ``None`` for that row.

        Yields:
            tuple: ``(i, t, mask, bundle)`` where

            * ``i`` (*int*): Row index into ``self.prices``.
            * ``t``: Timestamp value from ``self.prices["date"]``.
            * ``mask`` (*np.ndarray[bool]*): Shape ``(n_assets,)``; ``True``
              for assets with finite prices at row *i*.
            * ``bundle`` (:class:`MatrixBundle` | ``None``): Covariance bundle
              of shape ``(mask.sum(), mask.sum())``, or ``None``.
        """
        assets = self.assets
        prices_num = self.prices.select(assets).to_numpy()
        dates = self.prices["date"].to_list()

        if isinstance(self.cfg.covariance_config, EwmaShrinkConfig):
            cor = self.cor
            shrink = self.cfg.shrink  # loop-invariant
            for i, t in enumerate(dates):
                mask = _SolveMixin._compute_mask(prices_num[i])
                if not mask.any():
                    yield i, t, mask, None
                    continue
                matrix = shrink2id(cor[t], lamb=shrink)[np.ix_(mask, mask)]
                yield i, t, mask, MatrixBundle(matrix=matrix)
        else:
            sw_config = cast(SlidingWindowConfig, self.cfg.covariance_config)
            win_w: int = sw_config.window
            win_k: int = sw_config.n_factors
            ret_adj_np = self.ret_adj.select(assets).to_numpy()
            for i, t in enumerate(dates):
                mask = _SolveMixin._compute_mask(prices_num[i])
                if not mask.any() or i + 1 < win_w:
                    yield i, t, mask, None
                    continue
                window_ret = ret_adj_np[i + 1 - win_w : i + 1][:, mask]
                # Non-finite returns inside the window are replaced by zero.
                window_ret = np.where(np.isfinite(window_ret), window_ret, 0.0)
                n_sub = int(mask.sum())
                k_eff = min(win_k, win_w, n_sub)
                # Keep the ``yield`` outside the ``try``: an exception thrown
                # into this generator by its consumer must not be mistaken for
                # (and swallowed as) a factor-model failure.
                try:
                    covariance = FactorModel.from_returns(window_ret, k=k_eff).covariance
                except (np.linalg.LinAlgError, ValueError) as exc:
                    _logger.warning("Factor model fit failed at t=%s: %s", t, exc)
                    yield i, t, mask, None
                    continue
                yield i, t, mask, MatrixBundle(matrix=covariance)

    @staticmethod
    def _batched_denoms(mu_stack: np.ndarray, pos_stack: np.ndarray) -> np.ndarray:
        """Return per-row denominators ``sqrt(mu_i · pos_i)`` for a batch.

        Rows whose dot product is not strictly positive (including ``NaN``)
        get ``NaN``.  The square root is evaluated only where the dot product
        is positive, so no ``RuntimeWarning`` is raised for negative entries.

        Args:
            mu_stack: Right-hand sides, shape ``(G, n)``.
            pos_stack: Batch solutions, shape ``(G, n)``.

        Returns:
            np.ndarray: Denominators of shape ``(G,)``.
        """
        dots = (mu_stack * pos_stack).sum(axis=1)  # (G,)
        denoms = np.full(dots.shape, np.nan, dtype=float)
        np.sqrt(dots, out=denoms, where=dots > 0.0)
        return denoms

    @staticmethod
    def _batched_solve_group(
        group: list[tuple[int, datetime.date, np.ndarray, np.ndarray, np.ndarray]],
        denom_tol: float,
    ) -> dict[int, SolveYield]:
        """Solve a batch of linear systems sharing the same active-asset mask.

        Stacks the ``len(group)`` systems into a ``(G, n, n)`` coefficient tensor
        and a ``(G, n)`` right-hand-side matrix, then dispatches a single
        ``numpy.linalg.solve`` call.  Denominators are computed directly from
        the batch result as ``sqrt(mu_i · pos_i)`` — algebraically identical to
        the per-row :func:`~basanos.math._linalg.inv_a_norm` call.

        Falls back to row-by-row :meth:`_compute_position` when
        ``numpy.linalg.solve`` raises ``LinAlgError`` (any matrix in the batch
        is singular).

        Args:
            group: List of ``(i, t, mask, expected_mu, matrix)`` tuples; all
                entries share the same boolean mask and therefore the same
                ``n_active x n_active`` matrix shape.  An empty list yields
                an empty result.
            denom_tol: Passed through to :meth:`_denom_guard_yield`.

        Returns:
            dict: Mapping from row index ``i`` to its :data:`SolveYield`.
        """
        results: dict[int, SolveYield] = {}
        if not group:
            # np.stack rejects an empty sequence; nothing to solve anyway.
            return results

        a_stack = np.stack([row[4] for row in group])  # (G, n, n)
        mu_stack = np.stack([row[3] for row in group])  # (G, n)

        try:
            # numpy.linalg.solve requires the RHS to be (..., M, K) when a is (..., M, M).
            # Reshape mu_stack from (G, n) → (G, n, 1) so core dims match, then squeeze.
            pos_stack = np.linalg.solve(a_stack, mu_stack[..., np.newaxis])[..., 0]  # (G, n)
        except np.linalg.LinAlgError:
            # At least one matrix is singular — fall back to sequential per-row solve.
            for i, t, mask, expected_mu, matrix in group:
                results[i] = _SolveMixin._compute_position(
                    i, t, mask, expected_mu, MatrixBundle(matrix=matrix), denom_tol
                )
            return results

        # Denominators: sqrt(mu_i^T A_i^{-1} mu_i) = sqrt(mu_i · pos_i).
        denoms = _SolveMixin._batched_denoms(mu_stack, pos_stack)

        for (i, t, mask, expected_mu, _matrix), pos, denom in zip(group, pos_stack, denoms, strict=True):
            results[i] = _SolveMixin._denom_guard_yield(i, t, mask, expected_mu, pos, float(denom), denom_tol)

        return results

    @staticmethod
    def _iter_solve_ewma_batched(
        mu_np: np.ndarray,
        matrix_yields: list[MatrixYield],
        denom_tol: float,
    ) -> Generator[SolveYield, None, None]:
        r"""Vectorised EwmaShrink solve: batch ``numpy.linalg.solve`` across timestamps.

        Groups rows by their boolean asset mask so all systems within a group
        share the same ``(n_active, n_active)`` shape, then stacks them into a
        ``(G, n, n)`` tensor and calls ``numpy.linalg.solve`` once per unique
        mask pattern.  Results are collected in a dict and yielded in the
        order the rows appear in *matrix_yields*.

        Denominators are derived from the batch solution as
        :math:`\sqrt{\mu_i \cdot \mathbf{pos}_i} = \sqrt{\mu_i^\top \Sigma_i^{-1} \mu_i}`,
        matching the scalar :func:`~basanos.math._linalg.inv_a_norm` result up
        to float64 rounding.

        Any group whose batch solve raises ``LinAlgError`` (singular matrix in
        the batch) falls back to sequential :meth:`_compute_position` for that
        group only.

        Args:
            mu_np: Signal matrix, shape ``(T, n_assets)``.
            matrix_yields: Pre-collected list from :meth:`_iter_matrices`
                (the EwmaShrinkConfig branch).
            denom_tol: Denominator guard tolerance.

        Yields:
            :data:`SolveYield` tuples in original row order.
        """
        # First pass: categorise each row as early-exit or a solve candidate.
        all_results: dict[int, SolveYield] = {}
        # mask.tobytes() → list of (i, t, mask, expected_mu, matrix)
        solve_groups: dict[bytes, list[tuple[int, datetime.date, np.ndarray, np.ndarray, np.ndarray]]] = {}

        for i, t, mask, bundle in matrix_yields:
            if bundle is None:
                all_results[i] = (i, t, mask, np.zeros(int(mask.sum())), SolveStatus.DEGENERATE)
                continue
            expected_mu, early = _SolveMixin._row_early_check(i, t, mask, mu_np[i])
            if early is not None:
                all_results[i] = early
                continue
            solve_groups.setdefault(mask.tobytes(), []).append((i, t, mask, expected_mu, bundle.matrix))

        # Second pass: batch-solve each mask group.
        for group in solve_groups.values():
            all_results.update(_SolveMixin._batched_solve_group(group, denom_tol))

        # Every input row received a result above; emit them in input order.
        for i, _t, _mask, _bundle in matrix_yields:
            yield all_results[i]

    def _iter_solve(self: _EngineProtocol) -> Generator[SolveYield, None, None]:
        r"""Yield ``(i, t, mask, pos_or_none, status)`` for every timestamp.

        Iterates :meth:`_iter_matrices` for the per-row covariance sub-matrix,
        then applies :meth:`_row_early_check` (mask/signal guard) and
        :meth:`_compute_position` (linear solve and denominator guard).  The two
        covariance modes differ only in how ``matrix`` is built, which
        :meth:`_iter_matrices` already encapsulates.

        * ``bundle is None`` → :attr:`~SolveStatus.WARMUP` (sliding-window before
          sufficient history) or :attr:`~SolveStatus.DEGENERATE` otherwise.
        * Signal all-zero → :attr:`~SolveStatus.ZERO_SIGNAL`.
        * Singular or degenerate solve → :attr:`~SolveStatus.DEGENERATE`.
        * Success → :attr:`~SolveStatus.VALID`.

        For the :class:`~basanos.math.EwmaShrinkConfig` path the solve step is
        vectorised: rows are grouped by their active-asset mask pattern and each
        group is solved via a single batched ``numpy.linalg.solve`` call (see
        :meth:`_iter_solve_ewma_batched`).  The :class:`~basanos.math.SlidingWindowConfig`
        path retains a sequential per-row solve because the factor-model matrices
        are constructed lazily and may vary in numerical character across rows.

        .. note::

            **Dual-path maintenance obligation**: this method dispatches to two
            fundamentally different implementations.  Any change to solve
            semantics — a new edge case, a new :class:`SolveStatus` value, or a
            change to denominator logic — **must be applied to both branches**:

            * :meth:`_iter_solve_ewma_batched` / :meth:`_batched_solve_group`
              (EwmaShrink vectorised path)
            * The sequential ``_compute_position`` loop below
              (SlidingWindow path)

            The cross-path numerical consistency test
            ``test_ewma_batch_and_sequential_paths_agree`` in
            ``tests/test_math/test_numerical_regression.py`` will fail
            whenever the two paths drift apart, surfacing the divergence
            before it reaches production.

        Yields:
            SolveYield: ``(i, t, mask, pos_or_none, status)`` — see
            :data:`SolveYield` for detailed field descriptions.
        """
        mu_np = self.mu.select(self.assets).to_numpy()
        is_sw = isinstance(self.cfg.covariance_config, SlidingWindowConfig)

        if not is_sw:
            # EwmaShrinkConfig path: vectorised batch solve grouped by mask pattern.
            yield from _SolveMixin._iter_solve_ewma_batched(mu_np, list(self._iter_matrices()), self.cfg.denom_tol)
            return

        # SlidingWindowConfig path: sequential per-row solve (lazy factor models).
        win_w: int = cast(SlidingWindowConfig, self.cfg.covariance_config).window

        for i, t, mask, bundle in self._iter_matrices():
            if bundle is None:
                # Distinguish SW warmup (insufficient history) from no-data / model-failure.
                if mask.any() and i + 1 < win_w:
                    yield i, t, mask, None, SolveStatus.WARMUP
                else:
                    yield i, t, mask, np.zeros(int(mask.sum())), SolveStatus.DEGENERATE
                continue
            expected_mu, early = _SolveMixin._row_early_check(i, t, mask, mu_np[i])
            if early is not None:
                yield early
                continue
            yield _SolveMixin._compute_position(i, t, mask, expected_mu, bundle, self.cfg.denom_tol)

    def warmup_state(self: _EngineProtocol) -> WarmupState:
        """Return the final :class:`WarmupState` after replaying the full batch.

        Encapsulates the position replay loop that was previously duplicated
        inside :meth:`BasanosStream.from_warmup`.  By centralising the loop
        here, :meth:`~BasanosStream.from_warmup` no longer needs to call the
        private :meth:`_iter_solve` generator directly.

        Returns:
            WarmupState: A frozen dataclass with:

            * ``prev_cash_pos`` - cash-position vector at the last row,
              shape ``(n_assets,)``.
            * ``corr_iir_state`` - final EWM-correlation filter state when
              the covariance config is :class:`EwmaShrinkConfig`, otherwise
              ``None``.

        Examples:
            >>> import numpy as np
            >>> import polars as pl
            >>> from basanos.math import BasanosConfig, BasanosEngine
            >>> rng = np.random.default_rng(0)
            >>> dates = list(range(30))
            >>> prices = pl.DataFrame({
            ...     "date": dates,
            ...     "A": np.cumprod(1 + rng.normal(0.001, 0.02, 30)) * 100.0,
            ...     "B": np.cumprod(1 + rng.normal(0.001, 0.02, 30)) * 150.0,
            ... })
            >>> mu = pl.DataFrame({
            ...     "date": dates,
            ...     "A": rng.normal(0, 0.5, 30),
            ...     "B": rng.normal(0, 0.5, 30),
            ... })
            >>> cfg = BasanosConfig(vola=5, corr=10, clip=3.0, shrink=0.5, aum=1e6)
            >>> engine = BasanosEngine(prices=prices, mu=mu, cfg=cfg)
            >>> ws = engine.warmup_state()
            >>> ws.prev_cash_pos.shape
            (2,)
        """
        assets = self.assets
        n_rows = self.prices.height
        vola_np = self.vola.select(assets).to_numpy()

        # NaN-filled so rows the replay never writes (e.g. warmup) stay NaN.
        risk_pos_np = np.full((n_rows, len(assets)), np.nan, dtype=float)
        cash_pos_np = np.full((n_rows, len(assets)), np.nan, dtype=float)

        if isinstance(self.cfg.covariance_config, EwmaShrinkConfig):
            # Compute the IIR filter state in a single pass over the warmup data
            # so BasanosStream.from_warmup() can seed the incremental lfilter
            # without a second sweep.
            ret_adj_np = self.ret_adj.select(assets).to_numpy()
            _, iir_state = _ewm_corr_with_final_state(
                ret_adj_np,
                com=self.cfg.corr,
                min_periods=self.cfg.corr,
                min_corr_denom=self.cfg.min_corr_denom,
            )
        else:
            iir_state = None

        _SolveMixin._replay_positions(self, risk_pos_np, cash_pos_np, vola_np)
        prev_cash_pos = cash_pos_np[-1].copy()
        return WarmupState(
            prev_cash_pos=prev_cash_pos,
            corr_iir_state=iir_state,
        )