Coverage for src/basanos/math/_engine_solve.py: 100%

184 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-23 05:58 +0000

1"""Solve/position mixin for BasanosEngine. 

2 

3This private module contains `_SolveMixin`, which provides the 

4``_iter_matrices`` and ``_iter_solve`` generator methods. Separating them 

5from `optimizer` keeps the engine facade lean and makes 

6the per-timestamp solve logic independently readable and testable. 

7""" 

8 

9from __future__ import annotations 

10 

11import dataclasses 

12import datetime 

13import logging 

14from collections.abc import Generator 

15from enum import StrEnum 

16from typing import TYPE_CHECKING, TypeAlias, cast 

17 

18import numpy as np 

19from cvx.linalg import SingularMatrixError, inv_a_norm, solve 

20 

21from ._config import EwmaShrinkConfig, SlidingWindowConfig 

22from ._factor_model import FactorModel 

23from ._signal import shrink2id 

24 

25if TYPE_CHECKING: 

26 from ._engine_protocol import _EngineProtocol 

27 

28_logger = logging.getLogger(__name__) 

29 

30 

31class SolveStatus(StrEnum): 

32 """Solver outcome labels for each timestamp. 

33 

34 Since `SolveStatus` inherits from `str` via ``StrEnum``, 

35 values compare equal to their string equivalents (e.g. 

36 ``SolveStatus.VALID == "valid"``), preserving backward compatibility 

37 with code that matches on string literals. 

38 

39 Attributes: 

40 WARMUP: Insufficient history for the sliding-window covariance mode. 

41 ZERO_SIGNAL: The expected-return vector was all-zero; positions zeroed. 

42 DEGENERATE: Normalisation denominator was non-finite, solve failed, or 

43 no asset had a finite price; positions zeroed for safety. 

44 VALID: Linear system solved successfully; positions are non-trivially 

45 non-zero. 

46 """ 

47 

48 WARMUP = "warmup" 

49 ZERO_SIGNAL = "zero_signal" 

50 DEGENERATE = "degenerate" 

51 VALID = "valid" 

52 

53 

@dataclasses.dataclass(frozen=True)
class MatrixBundle:
    """Immutable carrier for a covariance matrix plus any mode-specific extras.

    `_compute_position` receives this wrapper rather than a bare array, so a
    future covariance mode (e.g. DCC-GARCH, RMT-cleaned) can add fields here
    without the method signature having to change.

    Attributes:
        matrix: Covariance sub-matrix of shape ``(n_active, n_active)`` for
            the assets that are active at a given timestamp.
    """

    matrix: np.ndarray

69 

70 

#: Item type produced by `_iter_matrices`: ``(i, t, mask, bundle)``.
#: ``bundle`` is ``None`` for warmup rows and rows without usable data.
MatrixYield: TypeAlias = tuple[
    int,
    datetime.date,
    np.ndarray,
    MatrixBundle | None,
]

#: Item type produced by `_iter_solve`: ``(i, t, mask, pos_or_none, status)``.
#: ``pos_or_none`` is ``None`` only for warmup rows.
SolveYield: TypeAlias = tuple[
    int,
    datetime.date,
    np.ndarray,
    np.ndarray | None,
    SolveStatus,
]

78 

79 

@dataclasses.dataclass(frozen=True)
class WarmupState:
    """Snapshot of the final state after a full batch solve.

    `warmup_state` builds this object and `from_warmup` consumes it to seed
    the streaming state, so `from_warmup` does not need to touch the private
    `_iter_solve` generator.

    Attributes:
        prev_cash_pos: Cash positions at the last warmup row, shape
            ``(n_assets,)``. Entries are ``NaN`` for assets that were still
            in their own warmup period.
    """

    prev_cash_pos: np.ndarray

95 

96 

class _SolveMixin:
    """Mixin that provides the per-timestamp matrix and solve generators.

    The main entry points are ``_iter_matrices``, ``_iter_solve``,
    ``_replay_positions`` and ``warmup_state``; the remaining static methods
    are small helpers shared by both covariance modes.

    Consumers must also inherit from (or satisfy the interface of)
    `_EngineProtocol` so that ``self.assets``, ``self.prices``, ``self.mu``,
    ``self.cfg``, ``self.cor``, ``self.ret_adj``, ``self.vola`` and
    ``self._iter_solve`` / ``self._iter_matrices`` are all available.
    """

    @staticmethod
    def _compute_mask(prices_row: np.ndarray) -> np.ndarray:
        """Return boolean mask indicating which assets have finite prices in the given row."""
        mask: np.ndarray = np.isfinite(prices_row)
        return mask

    @staticmethod
    def _check_signal(mu: np.ndarray, mask: np.ndarray) -> SolveStatus | None:
        """Return ``ZERO_SIGNAL`` when the masked expected-return vector is all-zero.

        ``NaN`` entries are treated as zero before the comparison.

        Returns ``None`` when the signal is non-trivially non-zero, indicating
        that the caller should proceed to the linear solve.
        """
        if np.allclose(np.nan_to_num(mu[mask]), 0.0):
            return SolveStatus.ZERO_SIGNAL
        return None

    @staticmethod
    def _scale_to_cash(pos: np.ndarray, vola_active: np.ndarray) -> np.ndarray:
        """Convert raw solver positions to cash-adjusted positions.

        Divides *pos* by *vola_active* (volatility for the active asset subset)
        to get cash positions.  ``np.errstate(invalid="ignore")`` is applied
        internally so NaN volatility values propagate quietly.
        """
        with np.errstate(invalid="ignore"):
            return cast("np.ndarray", pos / vola_active)

    @staticmethod
    def _row_early_check(
        i: int,
        t: datetime.date,
        mask: np.ndarray,
        mu_row: np.ndarray,
    ) -> tuple[np.ndarray, SolveYield | None]:
        """Validate the price mask and expected-return signal for a single row.

        Returns an ``(expected_mu, early_yield)`` pair.  When ``early_yield``
        is not ``None``, the caller should ``yield early_yield; continue``
        immediately — the row is either degenerate (empty mask) or has an
        all-zero signal.  When ``early_yield`` is ``None`` the row is ready
        for the mode-specific solve step.

        Args:
            i: Row index.
            t: Timestamp.
            mask: Boolean array of shape ``(n_assets,)`` indicating finite prices.
            mu_row: Expected-return row of shape ``(n_assets,)``.

        Returns:
            tuple: ``(expected_mu, early_yield)`` where ``expected_mu`` is
            ``np.nan_to_num(mu_row[mask])`` and ``early_yield`` is either a
            complete `SolveYield` tuple (when the caller should yield
            and continue) or ``None`` (when the caller should proceed to solve).
        """
        if not mask.any():
            # No asset has a finite price: nothing to solve for.
            return np.zeros(0), (i, t, mask, np.zeros(0), SolveStatus.DEGENERATE)
        expected_mu = np.nan_to_num(mu_row[mask])
        sig_status = _SolveMixin._check_signal(mu_row, mask)
        if sig_status is not None:
            return expected_mu, (i, t, mask, np.zeros_like(expected_mu), sig_status)
        return expected_mu, None

    @staticmethod
    def _denom_guard_yield(
        i: int,
        t: datetime.date,
        mask: np.ndarray,
        expected_mu: np.ndarray,
        pos_raw: np.ndarray,
        denom: float,
        denom_tol: float,
    ) -> SolveYield:
        """Apply the normalisation-denominator guard and return the appropriate yield tuple.

        Emits a `WARNING` and returns a
        `DEGENERATE` yield when *denom* is non-finite or at
        or below *denom_tol*; otherwise returns a `VALID`
        yield with normalised positions ``pos_raw / denom``.

        Args:
            i: Row index.
            t: Timestamp.
            mask: Boolean asset mask of shape ``(n_assets,)``.
            expected_mu: Masked expected-return vector of shape ``(n_active,)``.
            pos_raw: Raw (pre-normalisation) position vector of shape ``(n_active,)``.
            denom: Computed normalisation denominator.
            denom_tol: Tolerance threshold below which *denom* is treated as degenerate.

        Returns:
            SolveYield: Either a degenerate or valid ``(i, t, mask, pos, status)`` tuple.
        """
        n_active = len(expected_mu)
        if not np.isfinite(denom) or denom <= denom_tol:
            _logger.warning(
                "Positions zeroed at t=%s: normalisation denominator is degenerate "
                "(denom=%s, denom_tol=%s). Check signal magnitude and covariance matrix.",
                t,
                denom,
                denom_tol,
                extra={
                    "context": {
                        "t": str(t),
                        "denom": denom,
                        "denom_tol": denom_tol,
                    }
                },
            )
            return i, t, mask, np.zeros(n_active), SolveStatus.DEGENERATE
        return i, t, mask, pos_raw / denom, SolveStatus.VALID

    @staticmethod
    def _compute_position(
        i: int,
        t: datetime.date,
        mask: np.ndarray,
        expected_mu: np.ndarray,
        bundle: MatrixBundle,
        denom_tol: float,
    ) -> SolveYield:
        """Shared per-row solve step.

        Computes the normalisation denominator via `inv_a_norm`
        and solves the linear system via `solve`, then
        delegates to `_denom_guard_yield`.  ``SingularMatrixError`` (as
        imported from ``cvx.linalg``) is handled for both calls: a failure in
        `inv_a_norm` turns the denominator into ``NaN`` (so the guard reports
        `DEGENERATE`), and a failure in `solve` returns a `DEGENERATE` yield
        with zero positions directly.

        Accepting a `MatrixBundle` instead of a raw array means future
        covariance modes can attach auxiliary state to the bundle without
        changing this method's signature.

        Args:
            i: Row index.
            t: Timestamp.
            mask: Boolean asset mask of shape ``(n_assets,)``.
            expected_mu: Masked expected-return vector of shape ``(n_active,)``.
            bundle: Covariance bundle whose ``matrix`` field is an
                ``(n_active, n_active)`` covariance matrix for the active assets.
            denom_tol: Tolerance threshold for the normalisation denominator.

        Returns:
            SolveYield: A degenerate or valid ``(i, t, mask, pos, status)`` tuple.
        """
        matrix = bundle.matrix
        try:
            denom = inv_a_norm(expected_mu, matrix)
        except SingularMatrixError:
            denom = float("nan")
        try:
            pos = solve(matrix, expected_mu)
        except SingularMatrixError:
            return i, t, mask, np.zeros_like(expected_mu), SolveStatus.DEGENERATE
        return _SolveMixin._denom_guard_yield(i, t, mask, expected_mu, pos, denom, denom_tol)

    @staticmethod
    def _apply_turnover_constraint(
        new_cash: np.ndarray,
        prev_cash: np.ndarray,
        max_turnover: float,
    ) -> np.ndarray:
        """Cap the L1 norm of the position change to *max_turnover*.

        When ``sum(|new_cash - prev_cash|) > max_turnover``, the delta is
        scaled back proportionally toward *prev_cash* so that the constraint
        is exactly met.  When the constraint is already satisfied the input is
        returned unchanged (including any ``NaN`` entries it contains).

        Args:
            new_cash: Proposed cash positions after the solve step, shape
                ``(n_active,)`` — ``NaN`` values treated as zero.
            prev_cash: Cash positions at the previous step, shape
                ``(n_active,)`` — ``NaN`` values treated as zero.
            max_turnover: Maximum allowed L1 norm of the position change.

        Returns:
            np.ndarray: The (possibly scaled) new cash positions.
        """
        curr = np.nan_to_num(new_cash, nan=0.0)
        prev = np.nan_to_num(prev_cash, nan=0.0)
        delta = curr - prev
        total_delta = float(np.sum(np.abs(delta)))
        if total_delta > max_turnover:
            scale = max_turnover / total_delta
            return cast("np.ndarray", prev + delta * scale)
        return new_cash

    def _replay_positions(
        self: _EngineProtocol,
        risk_pos_np: np.ndarray,
        cash_pos_np: np.ndarray,
        vola_np: np.ndarray,
    ) -> None:
        """Replay positions across all rows, filling position arrays.

        Iterates `_iter_solve`, writes risk and cash positions into the
        provided pre-allocated arrays.  Both arrays are mutated **in-place**.
        Rows whose position is ``None`` (warmup) are left untouched.

        When `max_turnover` is set, the L1 norm of the
        position change ``sum(|x_t - x_{t-1}|)`` is capped at that value by
        proportionally scaling the delta toward the previous position before
        writing to ``cash_pos_np``.  The first row (``i == 0``) has no
        predecessor and is never capped.

        Args:
            risk_pos_np: Pre-allocated ``(T, N)`` array for risk positions.
            cash_pos_np: Pre-allocated ``(T, N)`` array for cash positions.
            vola_np: ``(T, N)`` EWMA volatility array.
        """
        max_to: float | None = self.cfg.max_turnover
        for i, _t, mask, pos, _status in self._iter_solve():
            if pos is None:
                continue
            # Boolean indexing copies, so slice the active volatilities once per row.
            vola_active = vola_np[i, mask]
            new_cash = _SolveMixin._scale_to_cash(pos, vola_active)
            if max_to is not None and i > 0:
                new_cash = _SolveMixin._apply_turnover_constraint(new_cash, cash_pos_np[i - 1, mask], max_to)
            risk_pos_np[i, mask] = new_cash * vola_active
            cash_pos_np[i, mask] = new_cash

    def _iter_matrices(self: _EngineProtocol) -> Generator[MatrixYield, None, None]:
        r"""Yield ``(i, t, mask, bundle)`` for every timestamp.

        ``bundle`` is a `MatrixBundle` wrapping the effective
        $(n_{\text{sub}},\ n_{\text{sub}})$ correlation matrix for the
        active assets (those with finite prices at timestamp *t*).  Yields
        ``None`` when no valid matrix is available (e.g., before the warm-up
        period has elapsed or when no assets have finite prices).

        The behaviour depends on `covariance_config`:

        * `EwmaShrinkConfig`: Applies `shrink2id` to
          the EWMA correlation matrix (same computation as
          `cash_position`).
        * `SlidingWindowConfig`: Builds a
          `FactorModel` from the last
          ``cfg.covariance_config.window`` rows of vol-adjusted returns and returns its
          `covariance`.  A failed factor-model fit is logged and
          yields ``None`` for that row.

        Yields:
            tuple: ``(i, t, mask, bundle)`` where

            * ``i`` (*int*): Row index into ``self.prices``.
            * ``t``: Timestamp value from ``self.prices["date"]``.
            * ``mask`` (*np.ndarray[bool]*): Shape ``(n_assets,)``; ``True``
              for assets with finite prices at row *i*.
            * ``bundle`` (`MatrixBundle` | ``None``): Covariance bundle
              of shape ``(mask.sum(), mask.sum())``, or ``None``.
        """
        assets = self.assets
        prices_num = self.prices.select(assets).to_numpy()
        dates = self.prices["date"].to_list()

        if isinstance(self.cfg.covariance_config, EwmaShrinkConfig):
            cor = self.cor
            for i, t in enumerate(dates):
                mask = _SolveMixin._compute_mask(prices_num[i])
                if not mask.any():
                    yield i, t, mask, None
                    continue
                corr_n = cor[t]
                # Shrink the full matrix first, then restrict to the active assets.
                matrix = shrink2id(corr_n, lamb=self.cfg.shrink)[np.ix_(mask, mask)]
                yield i, t, mask, MatrixBundle(matrix=matrix)
        else:
            sw_config = self.cfg.covariance_config
            win_w: int = sw_config.window
            win_k: int = sw_config.n_factors
            ret_adj_np = self.ret_adj.select(assets).to_numpy()
            for i, t in enumerate(dates):
                mask = _SolveMixin._compute_mask(prices_num[i])
                if not mask.any() or i + 1 < win_w:
                    yield i, t, mask, None
                    continue
                # Trailing window of ``win_w`` rows ending at (and including) row i.
                window_ret = ret_adj_np[i + 1 - win_w : i + 1][:, mask]
                window_ret = np.where(np.isfinite(window_ret), window_ret, 0.0)
                n_sub = int(mask.sum())
                # The factor count cannot exceed the window length or the active-asset count.
                k_eff = min(win_k, win_w, n_sub)
                try:
                    fm = FactorModel.from_returns(window_ret, k=k_eff)
                    yield i, t, mask, MatrixBundle(matrix=fm.covariance)
                except (np.linalg.LinAlgError, ValueError) as exc:
                    _logger.warning("Factor model fit failed at t=%s: %s", t, exc)
                    yield i, t, mask, None

    @staticmethod
    def _batched_solve_group(
        group: list[tuple[int, datetime.date, np.ndarray, np.ndarray, np.ndarray]],
        denom_tol: float,
    ) -> dict[int, SolveYield]:
        """Solve a batch of linear systems sharing the same active-asset mask.

        Stacks the ``len(group)`` systems into a ``(G, n, n)`` coefficient tensor
        and a ``(G, n)`` right-hand-side matrix, then dispatches a single
        ``numpy.linalg.solve`` call.  Denominators are computed directly from
        the batch result as ``sqrt(mu_i · pos_i)`` — algebraically identical
        to the per-row `inv_a_norm` call.  Rows whose ``mu_i · pos_i`` is not
        strictly positive (or is ``NaN``) get a ``NaN`` denominator and are
        therefore reported as `DEGENERATE` by `_denom_guard_yield`.

        Falls back to row-by-row `_compute_position` when
        ``numpy.linalg.solve`` raises ``LinAlgError`` (any matrix in the batch
        is singular).

        Args:
            group: List of ``(i, t, mask, expected_mu, matrix)`` tuples; all
                entries share the same boolean mask and therefore the same
                ``n_active x n_active`` matrix shape.  An empty list yields an
                empty result.
            denom_tol: Passed through to `_denom_guard_yield`.

        Returns:
            dict: Mapping from row index ``i`` to its `SolveYield`.
        """
        results: dict[int, SolveYield] = {}
        if not group:
            # ``np.stack`` rejects an empty sequence; there is nothing to solve.
            return results

        a_stack = np.stack([row[4] for row in group])  # (G, n, n)
        mu_stack = np.stack([row[3] for row in group])  # (G, n)

        try:
            # numpy.linalg.solve requires the RHS to be (..., M, K) when a is (..., M, M).
            # Reshape mu_stack from (G, n) → (G, n, 1) so core dims match, then squeeze.
            pos_stack = np.linalg.solve(a_stack, mu_stack[..., np.newaxis])[..., 0]  # (G, n)
        except np.linalg.LinAlgError:
            # At least one matrix is singular — fall back to sequential per-row solve.
            for i, t, mask, expected_mu, matrix in group:
                results[i] = _SolveMixin._compute_position(
                    i, t, mask, expected_mu, MatrixBundle(matrix=matrix), denom_tol
                )
            return results

        # Denominators: sqrt(mu_i^T A_i^{-1} mu_i) = sqrt(mu_i · pos_i).
        dots = (mu_stack * pos_stack).sum(axis=1)  # (G,)
        # Replace non-positive entries by NaN *before* the square root: taking
        # sqrt of a negative number would emit a spurious RuntimeWarning even
        # though that entry is discarded anyway.  sqrt(NaN) is silent.
        denoms = np.sqrt(np.where(dots > 0.0, dots, np.nan))

        for (i, t, mask, expected_mu, _matrix), pos, denom in zip(group, pos_stack, denoms, strict=True):
            results[i] = _SolveMixin._denom_guard_yield(i, t, mask, expected_mu, pos, float(denom), denom_tol)

        return results

    @staticmethod
    def _iter_solve_ewma_batched(
        mu_np: np.ndarray,
        matrix_yields: list[MatrixYield],
        denom_tol: float,
    ) -> Generator[SolveYield, None, None]:
        r"""Vectorised EwmaShrink solve: batch ``numpy.linalg.solve`` across timestamps.

        Groups rows by their boolean asset mask so all systems within a group
        share the same ``(n_active, n_active)`` shape, then stacks them into a
        ``(G, n, n)`` tensor and calls ``numpy.linalg.solve`` once per unique
        mask pattern.  Results are collected in a dict and yielded in the
        order in which the rows appear in *matrix_yields*.

        Denominators are derived from the batch solution as
        $\sqrt{\mu_i \cdot \mathbf{pos}_i} = \sqrt{\mu_i^\top \Sigma_i^{-1} \mu_i}$,
        matching the scalar `inv_a_norm` result up
        to float64 rounding.

        Any group whose batch solve raises ``LinAlgError`` (singular matrix in
        the batch) falls back to sequential `_compute_position` for that
        group only.

        Args:
            mu_np: Signal matrix, shape ``(T, n_assets)``.
            matrix_yields: Pre-collected list from `_iter_matrices`
                (the EwmaShrinkConfig branch).
            denom_tol: Denominator guard tolerance.

        Yields:
            `SolveYield` tuples in original row order.
        """
        # First pass: categorise each row as early-exit or a solve candidate.
        all_results: dict[int, SolveYield] = {}
        # mask.tobytes() → list of (i, t, mask, expected_mu, matrix)
        solve_groups: dict[bytes, list[tuple[int, datetime.date, np.ndarray, np.ndarray, np.ndarray]]] = {}

        for i, t, mask, bundle in matrix_yields:
            if bundle is None:
                all_results[i] = (i, t, mask, np.zeros(int(mask.sum())), SolveStatus.DEGENERATE)
                continue
            expected_mu, early = _SolveMixin._row_early_check(i, t, mask, mu_np[i])
            if early is not None:
                all_results[i] = early
                continue
            solve_groups.setdefault(mask.tobytes(), []).append((i, t, mask, expected_mu, bundle.matrix))

        # Second pass: batch-solve each mask group.
        for group in solve_groups.values():
            all_results.update(_SolveMixin._batched_solve_group(group, denom_tol))

        # Yield in original row order.  Every input row received exactly one
        # entry above, so look results up by the recorded row index instead of
        # assuming the indices are exactly 0..len(matrix_yields)-1.
        for i, _t, _mask, _bundle in matrix_yields:
            yield all_results[i]

    def _iter_solve(self: _EngineProtocol) -> Generator[SolveYield, None, None]:
        r"""Yield ``(i, t, mask, pos_or_none, status)`` for every timestamp.

        Iterates `_iter_matrices` for the per-row covariance sub-matrix,
        then applies `_row_early_check` (mask/signal guard) and
        `_compute_position` (linear solve and denominator guard).  The two
        covariance modes differ only in how ``matrix`` is built, which
        `_iter_matrices` already encapsulates.

        * ``matrix is None`` → `WARMUP` (sliding-window before
          sufficient history) or `DEGENERATE` otherwise.
        * Signal all-zero → `ZERO_SIGNAL`.
        * Singular or degenerate solve → `DEGENERATE`.
        * Success → `VALID`.

        For the `EwmaShrinkConfig` path the solve step is
        vectorised: rows are grouped by their active-asset mask pattern and each
        group is solved via a single batched ``numpy.linalg.solve`` call (see
        `_iter_solve_ewma_batched`).  The `SlidingWindowConfig`
        path retains a sequential per-row solve because the factor-model matrices
        are constructed lazily and may vary in numerical character across rows.

        .. note::

            **Dual-path maintenance obligation**: this method dispatches to two
            fundamentally different implementations.  Any change to solve
            semantics — a new edge case, a new `SolveStatus` value, or a
            change to denominator logic — **must be applied to both branches**:

            * `_iter_solve_ewma_batched` / `_batched_solve_group`
              (EwmaShrink vectorised path)
            * The sequential ``_compute_position`` loop below
              (SlidingWindow path)

            The cross-path numerical consistency test
            ``test_ewma_batch_and_sequential_paths_agree`` in
            ``tests/test_math/test_numerical_regression.py`` will fail
            whenever the two paths drift apart, surfacing the divergence
            before it reaches production.

        Yields:
            SolveYield: ``(i, t, mask, pos_or_none, status)`` — see
            `SolveYield` for detailed field descriptions.
        """
        mu_np = self.mu.select(self.assets).to_numpy()
        cov_config = self.cfg.covariance_config

        if not isinstance(cov_config, SlidingWindowConfig):
            # EwmaShrinkConfig path: vectorised batch solve grouped by mask pattern.
            yield from _SolveMixin._iter_solve_ewma_batched(mu_np, list(self._iter_matrices()), self.cfg.denom_tol)
            return

        # SlidingWindowConfig path: sequential per-row solve (lazy factor models).
        win_w: int = cov_config.window

        for i, t, mask, bundle in self._iter_matrices():
            if bundle is None:
                # Distinguish SW warmup (insufficient history) from no-data / model-failure.
                if mask.any() and i + 1 < win_w:
                    yield i, t, mask, None, SolveStatus.WARMUP
                else:
                    yield i, t, mask, np.zeros(int(mask.sum())), SolveStatus.DEGENERATE
                continue
            expected_mu, early = _SolveMixin._row_early_check(i, t, mask, mu_np[i])
            if early is not None:
                yield early
                continue
            yield _SolveMixin._compute_position(i, t, mask, expected_mu, bundle, self.cfg.denom_tol)

    def warmup_state(self: _EngineProtocol) -> WarmupState:
        """Return the final `WarmupState` after replaying the full batch.

        Encapsulates the position replay loop that was previously duplicated
        inside `from_warmup`.  By centralising the loop
        here, `from_warmup` no longer needs to call the
        private `_iter_solve` generator directly.

        Returns:
            WarmupState: A frozen dataclass with:

            * ``prev_cash_pos`` - cash-position vector at the last row,
              shape ``(n_assets,)``.

        Examples:
            >>> import numpy as np
            >>> import polars as pl
            >>> from basanos.math import BasanosConfig, BasanosEngine
            >>> rng = np.random.default_rng(0)
            >>> dates = list(range(30))
            >>> prices = pl.DataFrame({
            ...     "date": dates,
            ...     "A": np.cumprod(1 + rng.normal(0.001, 0.02, 30)) * 100.0,
            ...     "B": np.cumprod(1 + rng.normal(0.001, 0.02, 30)) * 150.0,
            ... })
            >>> mu = pl.DataFrame({
            ...     "date": dates,
            ...     "A": rng.normal(0, 0.5, 30),
            ...     "B": rng.normal(0, 0.5, 30),
            ... })
            >>> cfg = BasanosConfig(vola=5, corr=10, clip=3.0, shrink=0.5, aum=1e6)
            >>> engine = BasanosEngine(prices=prices, mu=mu, cfg=cfg)
            >>> ws = engine.warmup_state()
            >>> ws.prev_cash_pos.shape
            (2,)
        """
        assets = self.assets
        n_rows = self.prices.height
        vola_np = self.vola.select(assets).to_numpy()

        # Start from all-NaN so rows/assets that never receive a position stay NaN.
        risk_pos_np = np.full((n_rows, len(assets)), np.nan, dtype=float)
        cash_pos_np = np.full((n_rows, len(assets)), np.nan, dtype=float)

        _SolveMixin._replay_positions(self, risk_pos_np, cash_pos_np, vola_np)
        prev_cash_pos = cash_pos_np[-1].copy()
        return WarmupState(prev_cash_pos=prev_cash_pos)