Coverage for src / basanos / math / _engine_solve.py: 100%
190 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-02 17:47 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-02 17:47 +0000
1"""Solve/position mixin for BasanosEngine.
3This private module contains :class:`_SolveMixin`, which provides the
4``_iter_matrices`` and ``_iter_solve`` generator methods. Separating them
5from :mod:`basanos.math.optimizer` keeps the engine facade lean and makes
6the per-timestamp solve logic independently readable and testable.
7"""
9from __future__ import annotations
11import dataclasses
12import datetime
13import logging
14from collections.abc import Generator
15from enum import StrEnum
16from typing import TYPE_CHECKING, TypeAlias, cast
18import numpy as np
20from ..exceptions import SingularMatrixError
21from ._config import EwmaShrinkConfig, SlidingWindowConfig
22from ._ewm_corr import _ewm_corr_with_final_state, _EwmCorrState
23from ._factor_model import FactorModel
24from ._linalg import inv_a_norm, solve
25from ._signal import shrink2id
27if TYPE_CHECKING:
28 from ._engine_protocol import _EngineProtocol
30_logger = logging.getLogger(__name__)
33class SolveStatus(StrEnum):
34 """Solver outcome labels for each timestamp.
36 Since :class:`SolveStatus` inherits from :class:`str` via ``StrEnum``,
37 values compare equal to their string equivalents (e.g.
38 ``SolveStatus.VALID == "valid"``), preserving backward compatibility
39 with code that matches on string literals.
41 Attributes:
42 WARMUP: Insufficient history for the sliding-window covariance mode.
43 ZERO_SIGNAL: The expected-return vector was all-zero; positions zeroed.
44 DEGENERATE: Normalisation denominator was non-finite, solve failed, or
45 no asset had a finite price; positions zeroed for safety.
46 VALID: Linear system solved successfully; positions are non-trivially
47 non-zero.
48 """
50 WARMUP = "warmup"
51 ZERO_SIGNAL = "zero_signal"
52 DEGENERATE = "degenerate"
53 VALID = "valid"
@dataclasses.dataclass(frozen=True)
class MatrixBundle:
    """Immutable wrapper around a covariance matrix plus mode-specific extras.

    :meth:`_SolveMixin._compute_position` receives this bundle rather than a
    bare array. Future covariance modes (e.g. DCC-GARCH, RMT-cleaned) can
    then add fields here and pass them through the same interface without
    touching that method's signature.

    Attributes:
        matrix: Covariance sub-matrix of shape ``(n_active, n_active)``
            restricted to the assets that are active at a given timestamp.
    """

    matrix: np.ndarray
#: Item type produced by :meth:`_SolveMixin._iter_matrices`, laid out as
#: ``(i, t, mask, bundle)``; ``bundle`` is ``None`` during warmup/no-data.
MatrixYield: TypeAlias = tuple[
    int,  # i: row index
    datetime.date,  # t: timestamp
    np.ndarray,  # mask: boolean active-asset mask
    MatrixBundle | None,  # bundle: covariance bundle, if available
]

#: Item type produced by :meth:`_SolveMixin._iter_solve`, laid out as
#: ``(i, t, mask, pos_or_none, status)``; ``pos_or_none`` is ``None`` only
#: for warmup rows.
SolveYield: TypeAlias = tuple[
    int,  # i: row index
    datetime.date,  # t: timestamp
    np.ndarray,  # mask: boolean active-asset mask
    np.ndarray | None,  # pos_or_none: positions for the active assets
    SolveStatus,  # status: solver outcome label
]
@dataclasses.dataclass(frozen=True)
class WarmupState:
    """Snapshot left behind by a full batch solve, used to seed streaming.

    :meth:`BasanosEngine.warmup_state` builds this object and
    :meth:`BasanosStream.from_warmup` consumes it to initialise its streaming
    state, so the stream never has to reach into the private
    :meth:`~_SolveMixin._iter_solve` generator.

    Attributes:
        prev_cash_pos: Cash positions at the final warmup row, shape
            ``(n_assets,)``. Entries are ``NaN`` for assets that were still
            inside their own warmup period.
        corr_iir_state: IIR filter memory left at the end of the EWM
            correlation pass, or ``None`` when
            :class:`~basanos.math.SlidingWindowConfig` is in use.
            :meth:`BasanosStream.from_warmup` reads these arrays to seed the
            incremental ``lfilter`` state, avoiding a second pass over the
            warmup data.
    """

    prev_cash_pos: np.ndarray
    corr_iir_state: _EwmCorrState | None = None
105class _SolveMixin:
106 """Mixin that provides ``_iter_matrices`` and ``_iter_solve`` generators.
108 Consumers must also inherit from (or satisfy the interface of)
109 :class:`~basanos.math._engine_protocol._EngineProtocol` so that
110 ``self.assets``, ``self.prices``, ``self.mu``, ``self.cfg``, ``self.cor``,
111 and ``self.ret_adj`` are all available.
112 """
114 @staticmethod
115 def _compute_mask(prices_row: np.ndarray) -> np.ndarray:
116 """Return boolean mask indicating which assets have finite prices in the given row."""
117 return np.isfinite(prices_row)
119 @staticmethod
120 def _check_signal(mu: np.ndarray, mask: np.ndarray) -> SolveStatus | None:
121 """Return ``ZERO_SIGNAL`` when the masked expected-return vector is all-zero.
123 Returns ``None`` when the signal is non-trivially non-zero, indicating
124 that the caller should proceed to the linear solve.
125 """
126 if np.allclose(np.nan_to_num(mu[mask]), 0.0):
127 return SolveStatus.ZERO_SIGNAL
128 return None
130 @staticmethod
131 def _scale_to_cash(pos: np.ndarray, vola_active: np.ndarray) -> np.ndarray:
132 """Convert raw solver positions to cash-adjusted positions.
134 Divides *pos* by *vola_active* (volatility for the active asset subset)
135 to get cash positions. ``np.errstate(invalid="ignore")`` is applied
136 internally so NaN volatility values propagate quietly.
137 """
138 with np.errstate(invalid="ignore"):
139 return pos / vola_active
141 @staticmethod
142 def _row_early_check(
143 i: int,
144 t: datetime.date,
145 mask: np.ndarray,
146 mu_row: np.ndarray,
147 ) -> tuple[np.ndarray, SolveYield | None]:
148 """Validate the price mask and expected-return signal for a single row.
150 Returns an ``(expected_mu, early_yield)`` pair. When ``early_yield``
151 is not ``None``, the caller should ``yield early_yield; continue``
152 immediately — the row is either degenerate (empty mask) or has an
153 all-zero signal. When ``early_yield`` is ``None`` the row is ready
154 for the mode-specific solve step.
156 Args:
157 i: Row index.
158 t: Timestamp.
159 mask: Boolean array of shape ``(n_assets,)`` indicating finite prices.
160 mu_row: Expected-return row of shape ``(n_assets,)``.
162 Returns:
163 tuple: ``(expected_mu, early_yield)`` where ``expected_mu`` is
164 ``np.nan_to_num(mu_row[mask])`` and ``early_yield`` is either a
165 complete :data:`SolveYield` tuple (when the caller should yield
166 and continue) or ``None`` (when the caller should proceed to solve).
167 """
168 if not mask.any():
169 return np.zeros(0), (i, t, mask, np.zeros(0), SolveStatus.DEGENERATE)
170 expected_mu = np.nan_to_num(mu_row[mask])
171 sig_status = _SolveMixin._check_signal(mu_row, mask)
172 if sig_status is not None:
173 return expected_mu, (i, t, mask, np.zeros_like(expected_mu), sig_status)
174 return expected_mu, None
176 @staticmethod
177 def _denom_guard_yield(
178 i: int,
179 t: datetime.date,
180 mask: np.ndarray,
181 expected_mu: np.ndarray,
182 pos_raw: np.ndarray,
183 denom: float,
184 denom_tol: float,
185 ) -> SolveYield:
186 """Apply the normalisation-denominator guard and return the appropriate yield tuple.
188 Emits a :data:`~logging.WARNING` and returns a
189 :attr:`~SolveStatus.DEGENERATE` yield when *denom* is non-finite or at
190 or below *denom_tol*; otherwise returns a :attr:`~SolveStatus.VALID`
191 yield with normalised positions ``pos_raw / denom``.
193 Args:
194 i: Row index.
195 t: Timestamp.
196 mask: Boolean asset mask of shape ``(n_assets,)``.
197 expected_mu: Masked expected-return vector of shape ``(n_active,)``.
198 pos_raw: Raw (pre-normalisation) position vector of shape ``(n_active,)``.
199 denom: Computed normalisation denominator.
200 denom_tol: Tolerance threshold below which *denom* is treated as degenerate.
202 Returns:
203 SolveYield: Either a degenerate or valid ``(i, t, mask, pos, status)`` tuple.
204 """
205 n_active = len(expected_mu)
206 if not np.isfinite(denom) or denom <= denom_tol:
207 _logger.warning(
208 "Positions zeroed at t=%s: normalisation denominator is degenerate "
209 "(denom=%s, denom_tol=%s). Check signal magnitude and covariance matrix.",
210 t,
211 denom,
212 denom_tol,
213 extra={
214 "context": {
215 "t": str(t),
216 "denom": denom,
217 "denom_tol": denom_tol,
218 }
219 },
220 )
221 return i, t, mask, np.zeros(n_active), SolveStatus.DEGENERATE
222 return i, t, mask, pos_raw / denom, SolveStatus.VALID
224 @staticmethod
225 def _compute_position(
226 i: int,
227 t: datetime.date,
228 mask: np.ndarray,
229 expected_mu: np.ndarray,
230 bundle: MatrixBundle,
231 denom_tol: float,
232 ) -> SolveYield:
233 """Shared solve step used by both covariance branches.
235 Computes the normalisation denominator via :func:`~basanos.math._linalg.inv_a_norm`
236 and solves the linear system via :func:`~basanos.math._linalg.solve`, then
237 delegates to :meth:`_denom_guard_yield`. Handles
238 :exc:`~basanos.exceptions.SingularMatrixError` from both calls.
240 Accepting a :class:`MatrixBundle` instead of a raw array means future
241 covariance modes can attach auxiliary state to the bundle without
242 changing this method's signature.
244 Args:
245 i: Row index.
246 t: Timestamp.
247 mask: Boolean asset mask of shape ``(n_assets,)``.
248 expected_mu: Masked expected-return vector of shape ``(n_active,)``.
249 bundle: Covariance bundle whose ``matrix`` field is an
250 ``(n_active, n_active)`` covariance matrix for the active assets.
251 denom_tol: Tolerance threshold for the normalisation denominator.
253 Returns:
254 SolveYield: A degenerate or valid ``(i, t, mask, pos, status)`` tuple.
255 """
256 matrix = bundle.matrix
257 try:
258 denom = inv_a_norm(expected_mu, matrix)
259 except SingularMatrixError:
260 denom = float("nan")
261 try:
262 pos = solve(matrix, expected_mu)
263 except SingularMatrixError:
264 return i, t, mask, np.zeros_like(expected_mu), SolveStatus.DEGENERATE
265 return _SolveMixin._denom_guard_yield(i, t, mask, expected_mu, pos, denom, denom_tol)
267 @staticmethod
268 def _apply_turnover_constraint(
269 new_cash: np.ndarray,
270 prev_cash: np.ndarray,
271 max_turnover: float,
272 ) -> np.ndarray:
273 """Cap the L1 norm of the position change to *max_turnover*.
275 When ``sum(|new_cash - prev_cash|) > max_turnover``, the delta is
276 scaled back proportionally toward *prev_cash* so that the constraint
277 is exactly met. When the constraint is already satisfied the input is
278 returned unchanged.
280 Args:
281 new_cash: Proposed cash positions after the solve step, shape
282 ``(n_active,)`` — ``NaN`` values treated as zero.
283 prev_cash: Cash positions at the previous step, shape
284 ``(n_active,)`` — ``NaN`` values treated as zero.
285 max_turnover: Maximum allowed L1 norm of the position change.
287 Returns:
288 np.ndarray: The (possibly scaled) new cash positions.
289 """
290 curr = np.nan_to_num(new_cash, nan=0.0)
291 prev = np.nan_to_num(prev_cash, nan=0.0)
292 delta = curr - prev
293 total_delta = float(np.sum(np.abs(delta)))
294 if total_delta > max_turnover:
295 scale = max_turnover / total_delta
296 return prev + delta * scale
297 return new_cash
299 def _replay_positions(
300 self: _EngineProtocol,
301 risk_pos_np: np.ndarray,
302 cash_pos_np: np.ndarray,
303 vola_np: np.ndarray,
304 ) -> None:
305 """Replay positions across all rows, filling position arrays.
307 Iterates :meth:`_iter_solve`, writes risk and cash positions into the
308 provided pre-allocated arrays. Both arrays are mutated **in-place**.
310 When :attr:`BasanosConfig.max_turnover` is set, the L1 norm of the
311 position change ``sum(|x_t - x_{t-1}|)`` is capped at that value by
312 proportionally scaling the delta toward the previous position before
313 writing to ``cash_pos_np``.
315 Args:
316 risk_pos_np: Pre-allocated ``(T, N)`` array for risk positions.
317 cash_pos_np: Pre-allocated ``(T, N)`` array for cash positions.
318 vola_np: ``(T, N)`` EWMA volatility array.
319 """
320 max_to: float | None = self.cfg.max_turnover
321 for i, _t, mask, pos, _status in self._iter_solve():
322 if pos is not None:
323 new_cash = _SolveMixin._scale_to_cash(pos, vola_np[i, mask])
324 if max_to is not None and i > 0:
325 new_cash = _SolveMixin._apply_turnover_constraint(new_cash, cash_pos_np[i - 1, mask], max_to)
326 risk_pos_np[i, mask] = new_cash * vola_np[i, mask]
327 cash_pos_np[i, mask] = new_cash
329 def _iter_matrices(self: _EngineProtocol) -> Generator[MatrixYield, None, None]:
330 r"""Yield ``(i, t, mask, bundle)`` for every timestamp.
332 ``bundle`` is a :class:`MatrixBundle` wrapping the effective
333 :math:`(n_{\text{sub}},\ n_{\text{sub}})` correlation matrix for the
334 active assets (those with finite prices at timestamp *t*). Yields
335 ``None`` when no valid matrix is available (e.g., before the warm-up
336 period has elapsed or when no assets have finite prices).
338 The behaviour depends on :attr:`BasanosConfig.covariance_config`:
340 * :class:`EwmaShrinkConfig`: Applies :func:`~basanos.math._signal.shrink2id` to
341 the EWMA correlation matrix (same computation as
342 :attr:`cash_position`).
343 * :class:`SlidingWindowConfig`: Builds a
344 :class:`~basanos.math._factor_model.FactorModel` from the last
345 ``cfg.covariance_config.window`` rows of vol-adjusted returns and returns its
346 :attr:`~basanos.math._factor_model.FactorModel.covariance`.
348 Yields:
349 tuple: ``(i, t, mask, bundle)`` where
351 * ``i`` (*int*): Row index into ``self.prices``.
352 * ``t``: Timestamp value from ``self.prices["date"]``.
353 * ``mask`` (*np.ndarray[bool]*): Shape ``(n_assets,)``; ``True``
354 for assets with finite prices at row *i*.
355 * ``bundle`` (:class:`MatrixBundle` | ``None``): Covariance bundle
356 of shape ``(mask.sum(), mask.sum())``, or ``None``.
357 """
358 assets = self.assets
359 prices_num = self.prices.select(assets).to_numpy()
360 dates = self.prices["date"].to_list()
362 if isinstance(self.cfg.covariance_config, EwmaShrinkConfig):
363 cor = self.cor
364 for i, t in enumerate(dates):
365 mask = _SolveMixin._compute_mask(prices_num[i])
366 if not mask.any():
367 yield i, t, mask, None
368 continue
369 corr_n = cor[t]
370 matrix = shrink2id(corr_n, lamb=self.cfg.shrink)[np.ix_(mask, mask)]
371 yield i, t, mask, MatrixBundle(matrix=matrix)
372 else:
373 sw_config = cast(SlidingWindowConfig, self.cfg.covariance_config)
374 win_w: int = sw_config.window
375 win_k: int = sw_config.n_factors
376 ret_adj_np = self.ret_adj.select(assets).to_numpy()
377 for i, t in enumerate(dates):
378 mask = _SolveMixin._compute_mask(prices_num[i])
379 if not mask.any() or i + 1 < win_w:
380 yield i, t, mask, None
381 continue
382 window_ret = ret_adj_np[i + 1 - win_w : i + 1][:, mask]
383 window_ret = np.where(np.isfinite(window_ret), window_ret, 0.0)
384 n_sub = int(mask.sum())
385 k_eff = min(win_k, win_w, n_sub)
386 try:
387 fm = FactorModel.from_returns(window_ret, k=k_eff)
388 yield i, t, mask, MatrixBundle(matrix=fm.covariance)
389 except (np.linalg.LinAlgError, ValueError) as exc:
390 _logger.warning("Factor model fit failed at t=%s: %s", t, exc)
391 yield i, t, mask, None
393 @staticmethod
394 def _batched_solve_group(
395 group: list[tuple[int, datetime.date, np.ndarray, np.ndarray, np.ndarray]],
396 denom_tol: float,
397 ) -> dict[int, SolveYield]:
398 """Solve a batch of linear systems sharing the same active-asset mask.
400 Stacks the ``len(group)`` systems into a ``(G, n, n)`` coefficient tensor
401 and a ``(G, n)`` right-hand-side matrix, then dispatches a single
402 ``numpy.linalg.solve`` call (which maps to a single batched LAPACK
403 routine). Denominators are computed directly from the batch result as
404 ``sqrt(mu_i · pos_i)`` — algebraically identical to the per-row
405 :func:`~basanos.math._linalg.inv_a_norm` call.
407 Falls back to row-by-row :meth:`_compute_position` when
408 ``numpy.linalg.solve`` raises ``LinAlgError`` (any matrix in the batch
409 is singular).
411 Args:
412 group: List of ``(i, t, mask, expected_mu, matrix)`` tuples; all
413 entries share the same boolean mask and therefore the same
414 ``n_active x n_active`` matrix shape.
415 denom_tol: Passed through to :meth:`_denom_guard_yield`.
417 Returns:
418 dict: Mapping from row index ``i`` to its :data:`SolveYield`.
419 """
420 results: dict[int, SolveYield] = {}
421 a_stack = np.stack([row[4] for row in group]) # (G, n, n)
422 mu_stack = np.stack([row[3] for row in group]) # (G, n)
424 try:
425 # numpy.linalg.solve requires the RHS to be (..., M, K) when a is (..., M, M).
426 # Reshape mu_stack from (G, n) → (G, n, 1) so core dims match, then squeeze.
427 pos_stack = np.linalg.solve(a_stack, mu_stack[..., np.newaxis])[..., 0] # (G, n)
428 except np.linalg.LinAlgError:
429 # At least one matrix is singular — fall back to sequential per-row solve.
430 for i, t, mask, expected_mu, matrix in group:
431 results[i] = _SolveMixin._compute_position(
432 i, t, mask, expected_mu, MatrixBundle(matrix=matrix), denom_tol
433 )
434 return results
436 # Denominators: sqrt(mu_i^T A_i^{-1} mu_i) = sqrt(mu_i · pos_i).
437 dots = (mu_stack * pos_stack).sum(axis=1) # (G,)
438 denoms = np.where(dots > 0.0, np.sqrt(dots), np.nan)
440 for (i, t, mask, expected_mu, _matrix), pos, denom in zip(group, pos_stack, denoms, strict=True):
441 results[i] = _SolveMixin._denom_guard_yield(i, t, mask, expected_mu, pos, float(denom), denom_tol)
443 return results
445 @staticmethod
446 def _iter_solve_ewma_batched(
447 mu_np: np.ndarray,
448 matrix_yields: list[MatrixYield],
449 denom_tol: float,
450 ) -> Generator[SolveYield, None, None]:
451 r"""Vectorised EwmaShrink solve: batch ``numpy.linalg.solve`` across timestamps.
453 Groups rows by their boolean asset mask so all systems within a group
454 share the same ``(n_active, n_active)`` shape, then stacks them into a
455 ``(G, n, n)`` tensor and calls ``numpy.linalg.solve`` once per unique
456 mask pattern. Results are collected in a dict and yielded in original
457 row order.
459 Denominators are derived from the batch solution as
460 :math:`\sqrt{\mu_i \cdot \mathbf{pos}_i} = \sqrt{\mu_i^\top \Sigma_i^{-1} \mu_i}`,
461 matching the scalar :func:`~basanos.math._linalg.inv_a_norm` result up
462 to float64 rounding.
464 Any group whose batch solve raises ``LinAlgError`` (singular matrix in
465 the batch) falls back to sequential :meth:`_compute_position` for that
466 group only.
468 Args:
469 mu_np: Signal matrix, shape ``(T, n_assets)``.
470 matrix_yields: Pre-collected list from :meth:`_iter_matrices`
471 (the EwmaShrinkConfig branch).
472 denom_tol: Denominator guard tolerance.
474 Yields:
475 :data:`SolveYield` tuples in original row order.
476 """
477 # First pass: categorise each row as early-exit or a solve candidate.
478 all_results: dict[int, SolveYield] = {}
479 # mask.tobytes() → list of (i, t, mask, expected_mu, matrix)
480 solve_groups: dict[bytes, list[tuple[int, datetime.date, np.ndarray, np.ndarray, np.ndarray]]] = {}
482 for i, t, mask, bundle in matrix_yields:
483 if bundle is None:
484 all_results[i] = (i, t, mask, np.zeros(int(mask.sum())), SolveStatus.DEGENERATE)
485 continue
486 expected_mu, early = _SolveMixin._row_early_check(i, t, mask, mu_np[i])
487 if early is not None:
488 all_results[i] = early
489 continue
490 mask_key = mask.tobytes()
491 if mask_key not in solve_groups:
492 solve_groups[mask_key] = []
493 solve_groups[mask_key].append((i, t, mask, expected_mu, bundle.matrix))
495 # Second pass: batch-solve each mask group.
496 for group in solve_groups.values():
497 all_results.update(_SolveMixin._batched_solve_group(group, denom_tol))
499 # Yield in original row order.
500 for i in range(len(matrix_yields)):
501 if i in all_results:
502 yield all_results[i]
504 def _iter_solve(self: _EngineProtocol) -> Generator[SolveYield, None, None]:
505 r"""Yield ``(i, t, mask, pos_or_none, status)`` for every timestamp.
507 Iterates :meth:`_iter_matrices` for the per-row covariance sub-matrix,
508 then applies :meth:`_row_early_check` (mask/signal guard) and
509 :meth:`_compute_position` (linear solve and denominator guard). The two
510 covariance modes differ only in how ``matrix`` is built, which
511 :meth:`_iter_matrices` already encapsulates.
513 * ``matrix is None`` → :attr:`~SolveStatus.WARMUP` (sliding-window before
514 sufficient history) or :attr:`~SolveStatus.DEGENERATE` otherwise.
515 * Signal all-zero → :attr:`~SolveStatus.ZERO_SIGNAL`.
516 * Singular or degenerate solve → :attr:`~SolveStatus.DEGENERATE`.
517 * Success → :attr:`~SolveStatus.VALID`.
519 For the :class:`~basanos.math.EwmaShrinkConfig` path the solve step is
520 vectorised: rows are grouped by their active-asset mask pattern and each
521 group is solved via a single batched ``numpy.linalg.solve`` call (see
522 :meth:`_iter_solve_ewma_batched`). The :class:`~basanos.math.SlidingWindowConfig`
523 path retains a sequential per-row solve because the factor-model matrices
524 are constructed lazily and may vary in numerical character across rows.
526 .. note::
528 **Dual-path maintenance obligation**: this method dispatches to two
529 fundamentally different implementations. Any change to solve
530 semantics — a new edge case, a new :class:`SolveStatus` value, or a
531 change to denominator logic — **must be applied to both branches**:
533 * :meth:`_iter_solve_ewma_batched` / :meth:`_batched_solve_group`
534 (EwmaShrink vectorised path)
535 * The sequential ``_compute_position`` loop below
536 (SlidingWindow path)
538 The cross-path numerical consistency test
539 ``test_ewma_batch_and_sequential_paths_agree`` in
540 ``tests/test_math/test_numerical_regression.py`` will fail
541 whenever the two paths drift apart, surfacing the divergence
542 before it reaches production.
544 Yields:
545 SolveYield: ``(i, t, mask, pos_or_none, status)`` — see
546 :data:`SolveYield` for detailed field descriptions.
547 """
548 mu_np = self.mu.select(self.assets).to_numpy()
549 is_sw = isinstance(self.cfg.covariance_config, SlidingWindowConfig)
551 if not is_sw:
552 # EwmaShrinkConfig path: vectorised batch solve grouped by mask pattern.
553 yield from _SolveMixin._iter_solve_ewma_batched(mu_np, list(self._iter_matrices()), self.cfg.denom_tol)
554 return
556 # SlidingWindowConfig path: sequential per-row solve (lazy factor models).
557 win_w: int = cast(SlidingWindowConfig, self.cfg.covariance_config).window
559 for i, t, mask, bundle in self._iter_matrices():
560 if bundle is None:
561 # Distinguish SW warmup (insufficient history) from no-data / model-failure.
562 if mask.any() and i + 1 < win_w:
563 yield i, t, mask, None, SolveStatus.WARMUP
564 else:
565 yield i, t, mask, np.zeros(int(mask.sum())), SolveStatus.DEGENERATE
566 continue
567 expected_mu, early = _SolveMixin._row_early_check(i, t, mask, mu_np[i])
568 if early is not None:
569 yield early
570 continue
571 yield _SolveMixin._compute_position(i, t, mask, expected_mu, bundle, self.cfg.denom_tol)
573 def warmup_state(self: _EngineProtocol) -> WarmupState:
574 """Return the final :class:`WarmupState` after replaying the full batch.
576 Encapsulates the position replay loop that was previously duplicated
577 inside :meth:`BasanosStream.from_warmup`. By centralising the loop
578 here, :meth:`~BasanosStream.from_warmup` no longer needs to call the
579 private :meth:`_iter_solve` generator directly.
581 Returns:
582 WarmupState: A frozen dataclass with:
584 * ``prev_cash_pos`` - cash-position vector at the last row,
585 shape ``(n_assets,)``.
587 Examples:
588 >>> import numpy as np
589 >>> import polars as pl
590 >>> from basanos.math import BasanosConfig, BasanosEngine
591 >>> rng = np.random.default_rng(0)
592 >>> dates = list(range(30))
593 >>> prices = pl.DataFrame({
594 ... "date": dates,
595 ... "A": np.cumprod(1 + rng.normal(0.001, 0.02, 30)) * 100.0,
596 ... "B": np.cumprod(1 + rng.normal(0.001, 0.02, 30)) * 150.0,
597 ... })
598 >>> mu = pl.DataFrame({
599 ... "date": dates,
600 ... "A": rng.normal(0, 0.5, 30),
601 ... "B": rng.normal(0, 0.5, 30),
602 ... })
603 >>> cfg = BasanosConfig(vola=5, corr=10, clip=3.0, shrink=0.5, aum=1e6)
604 >>> engine = BasanosEngine(prices=prices, mu=mu, cfg=cfg)
605 >>> ws = engine.warmup_state()
606 >>> ws.prev_cash_pos.shape
607 (2,)
608 """
609 assets = self.assets
610 n_rows = self.prices.height
611 vola_np = self.vola.select(assets).to_numpy()
613 risk_pos_np = np.full((n_rows, len(assets)), np.nan, dtype=float)
614 cash_pos_np = np.full((n_rows, len(assets)), np.nan, dtype=float)
616 if isinstance(self.cfg.covariance_config, EwmaShrinkConfig):
617 # Compute the IIR filter state in a single pass over the warmup data
618 # so BasanosStream.from_warmup() can seed the incremental lfilter
619 # without a second sweep.
620 ret_adj_np = self.ret_adj.select(assets).to_numpy()
621 _, iir_state = _ewm_corr_with_final_state(
622 ret_adj_np,
623 com=self.cfg.corr,
624 min_periods=self.cfg.corr,
625 min_corr_denom=self.cfg.min_corr_denom,
626 )
627 else:
628 iir_state = None
630 _SolveMixin._replay_positions(self, risk_pos_np, cash_pos_np, vola_np)
631 prev_cash_pos = cash_pos_np[-1].copy()
632 return WarmupState(
633 prev_cash_pos=prev_cash_pos,
634 corr_iir_state=iir_state,
635 )