Coverage for src/basanos/math/_engine_solve.py: 100%
184 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-23 05:58 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-23 05:58 +0000
"""Solve/position mixin for BasanosEngine.

This private module contains `_SolveMixin`, which provides the
``_iter_matrices`` and ``_iter_solve`` generator methods together with the
position replay (``_replay_positions``) and ``warmup_state``. It also defines
the supporting types `SolveStatus`, `MatrixBundle` and `WarmupState`.
Separating this logic from `optimizer` keeps the engine facade lean and makes
the per-timestamp solve logic independently readable and testable.
"""
9from __future__ import annotations
11import dataclasses
12import datetime
13import logging
14from collections.abc import Generator
15from enum import StrEnum
16from typing import TYPE_CHECKING, TypeAlias, cast
18import numpy as np
19from cvx.linalg import SingularMatrixError, inv_a_norm, solve
21from ._config import EwmaShrinkConfig, SlidingWindowConfig
22from ._factor_model import FactorModel
23from ._signal import shrink2id
25if TYPE_CHECKING:
26 from ._engine_protocol import _EngineProtocol
# Module-scoped logger; solve-time warnings (degenerate denominators,
# factor-model fit failures) are emitted through it.
_logger = logging.getLogger(name=__name__)
31class SolveStatus(StrEnum):
32 """Solver outcome labels for each timestamp.
34 Since `SolveStatus` inherits from `str` via ``StrEnum``,
35 values compare equal to their string equivalents (e.g.
36 ``SolveStatus.VALID == "valid"``), preserving backward compatibility
37 with code that matches on string literals.
39 Attributes:
40 WARMUP: Insufficient history for the sliding-window covariance mode.
41 ZERO_SIGNAL: The expected-return vector was all-zero; positions zeroed.
42 DEGENERATE: Normalisation denominator was non-finite, solve failed, or
43 no asset had a finite price; positions zeroed for safety.
44 VALID: Linear system solved successfully; positions are non-trivially
45 non-zero.
46 """
48 WARMUP = "warmup"
49 ZERO_SIGNAL = "zero_signal"
50 DEGENERATE = "degenerate"
51 VALID = "valid"
@dataclasses.dataclass(frozen=True)
class MatrixBundle:
    """Immutable container for a per-timestamp solve matrix plus mode-specific state.

    Wrapping the matrix in a dataclass decouples `_compute_position` from the
    raw array, so future covariance modes (e.g. DCC-GARCH, RMT-cleaned) can
    carry additional fields through the same interface without changing the
    method signature.

    Attributes:
        matrix: Square ``(n_active, n_active)`` matrix for the active assets
            at a given timestamp. In the EWMA-shrink mode this is the shrunk
            correlation sub-matrix; in the sliding-window mode it is the
            factor-model covariance.
    """

    matrix: np.ndarray
#: Yield type for `_iter_matrices`:
#: ``(i, t, mask, bundle)`` where ``bundle`` is ``None`` when no matrix is
#: available for the row (sliding-window warmup, no finite prices, or a failed
#: factor-model fit).
MatrixYield: TypeAlias = tuple[int, datetime.date, np.ndarray, MatrixBundle | None]

#: Yield type for `_iter_solve`:
#: ``(i, t, mask, pos_or_none, status)`` where ``pos_or_none`` is ``None`` only
#: for ``SolveStatus.WARMUP`` rows.
SolveYield: TypeAlias = tuple[int, datetime.date, np.ndarray, np.ndarray | None, SolveStatus]
@dataclasses.dataclass(frozen=True)
class WarmupState:
    """Final state of a full batch replay; consumed by `from_warmup`.

    Produced by `warmup_state` so that `from_warmup` can initialise the
    streaming state without coupling to the private `_iter_solve` generator.

    Attributes:
        prev_cash_pos: Cash positions at the last replayed row, shape
            ``(n_assets,)``. Entries are ``NaN`` where no position was written
            on that row (assets without a finite price, or a row still in the
            sliding-window warmup). They are also ``NaN`` where the volatility
            used for cash scaling was ``NaN``.
    """

    prev_cash_pos: np.ndarray
97class _SolveMixin:
98 """Mixin that provides ``_iter_matrices`` and ``_iter_solve`` generators.
100 Consumers must also inherit from (or satisfy the interface of)
101 `_EngineProtocol` so that
102 ``self.assets``, ``self.prices``, ``self.mu``, ``self.cfg``, ``self.cor``,
103 and ``self.ret_adj`` are all available.
104 """
106 @staticmethod
107 def _compute_mask(prices_row: np.ndarray) -> np.ndarray:
108 """Return boolean mask indicating which assets have finite prices in the given row."""
109 mask: np.ndarray = np.isfinite(prices_row)
110 return mask
112 @staticmethod
113 def _check_signal(mu: np.ndarray, mask: np.ndarray) -> SolveStatus | None:
114 """Return ``ZERO_SIGNAL`` when the masked expected-return vector is all-zero.
116 Returns ``None`` when the signal is non-trivially non-zero, indicating
117 that the caller should proceed to the linear solve.
118 """
119 if np.allclose(np.nan_to_num(mu[mask]), 0.0):
120 return SolveStatus.ZERO_SIGNAL
121 return None
123 @staticmethod
124 def _scale_to_cash(pos: np.ndarray, vola_active: np.ndarray) -> np.ndarray:
125 """Convert raw solver positions to cash-adjusted positions.
127 Divides *pos* by *vola_active* (volatility for the active asset subset)
128 to get cash positions. ``np.errstate(invalid="ignore")`` is applied
129 internally so NaN volatility values propagate quietly.
130 """
131 with np.errstate(invalid="ignore"):
132 return cast("np.ndarray", pos / vola_active)
134 @staticmethod
135 def _row_early_check(
136 i: int,
137 t: datetime.date,
138 mask: np.ndarray,
139 mu_row: np.ndarray,
140 ) -> tuple[np.ndarray, SolveYield | None]:
141 """Validate the price mask and expected-return signal for a single row.
143 Returns an ``(expected_mu, early_yield)`` pair. When ``early_yield``
144 is not ``None``, the caller should ``yield early_yield; continue``
145 immediately — the row is either degenerate (empty mask) or has an
146 all-zero signal. When ``early_yield`` is ``None`` the row is ready
147 for the mode-specific solve step.
149 Args:
150 i: Row index.
151 t: Timestamp.
152 mask: Boolean array of shape ``(n_assets,)`` indicating finite prices.
153 mu_row: Expected-return row of shape ``(n_assets,)``.
155 Returns:
156 tuple: ``(expected_mu, early_yield)`` where ``expected_mu`` is
157 ``np.nan_to_num(mu_row[mask])`` and ``early_yield`` is either a
158 complete `SolveYield` tuple (when the caller should yield
159 and continue) or ``None`` (when the caller should proceed to solve).
160 """
161 if not mask.any():
162 return np.zeros(0), (i, t, mask, np.zeros(0), SolveStatus.DEGENERATE)
163 expected_mu = np.nan_to_num(mu_row[mask])
164 sig_status = _SolveMixin._check_signal(mu_row, mask)
165 if sig_status is not None:
166 return expected_mu, (i, t, mask, np.zeros_like(expected_mu), sig_status)
167 return expected_mu, None
169 @staticmethod
170 def _denom_guard_yield(
171 i: int,
172 t: datetime.date,
173 mask: np.ndarray,
174 expected_mu: np.ndarray,
175 pos_raw: np.ndarray,
176 denom: float,
177 denom_tol: float,
178 ) -> SolveYield:
179 """Apply the normalisation-denominator guard and return the appropriate yield tuple.
181 Emits a `WARNING` and returns a
182 `DEGENERATE` yield when *denom* is non-finite or at
183 or below *denom_tol*; otherwise returns a `VALID`
184 yield with normalised positions ``pos_raw / denom``.
186 Args:
187 i: Row index.
188 t: Timestamp.
189 mask: Boolean asset mask of shape ``(n_assets,)``.
190 expected_mu: Masked expected-return vector of shape ``(n_active,)``.
191 pos_raw: Raw (pre-normalisation) position vector of shape ``(n_active,)``.
192 denom: Computed normalisation denominator.
193 denom_tol: Tolerance threshold below which *denom* is treated as degenerate.
195 Returns:
196 SolveYield: Either a degenerate or valid ``(i, t, mask, pos, status)`` tuple.
197 """
198 n_active = len(expected_mu)
199 if not np.isfinite(denom) or denom <= denom_tol:
200 _logger.warning(
201 "Positions zeroed at t=%s: normalisation denominator is degenerate "
202 "(denom=%s, denom_tol=%s). Check signal magnitude and covariance matrix.",
203 t,
204 denom,
205 denom_tol,
206 extra={
207 "context": {
208 "t": str(t),
209 "denom": denom,
210 "denom_tol": denom_tol,
211 }
212 },
213 )
214 return i, t, mask, np.zeros(n_active), SolveStatus.DEGENERATE
215 return i, t, mask, pos_raw / denom, SolveStatus.VALID
217 @staticmethod
218 def _compute_position(
219 i: int,
220 t: datetime.date,
221 mask: np.ndarray,
222 expected_mu: np.ndarray,
223 bundle: MatrixBundle,
224 denom_tol: float,
225 ) -> SolveYield:
226 """Shared solve step used by both covariance branches.
228 Computes the normalisation denominator via `inv_a_norm`
229 and solves the linear system via `solve`, then
230 delegates to `_denom_guard_yield`. Handles
231 :exc:`~basanos.exceptions.SingularMatrixError` from both calls.
233 Accepting a `MatrixBundle` instead of a raw array means future
234 covariance modes can attach auxiliary state to the bundle without
235 changing this method's signature.
237 Args:
238 i: Row index.
239 t: Timestamp.
240 mask: Boolean asset mask of shape ``(n_assets,)``.
241 expected_mu: Masked expected-return vector of shape ``(n_active,)``.
242 bundle: Covariance bundle whose ``matrix`` field is an
243 ``(n_active, n_active)`` covariance matrix for the active assets.
244 denom_tol: Tolerance threshold for the normalisation denominator.
246 Returns:
247 SolveYield: A degenerate or valid ``(i, t, mask, pos, status)`` tuple.
248 """
249 matrix = bundle.matrix
250 try:
251 denom = inv_a_norm(expected_mu, matrix)
252 except SingularMatrixError:
253 denom = float("nan")
254 try:
255 pos = solve(matrix, expected_mu)
256 except SingularMatrixError:
257 return i, t, mask, np.zeros_like(expected_mu), SolveStatus.DEGENERATE
258 return _SolveMixin._denom_guard_yield(i, t, mask, expected_mu, pos, denom, denom_tol)
260 @staticmethod
261 def _apply_turnover_constraint(
262 new_cash: np.ndarray,
263 prev_cash: np.ndarray,
264 max_turnover: float,
265 ) -> np.ndarray:
266 """Cap the L1 norm of the position change to *max_turnover*.
268 When ``sum(|new_cash - prev_cash|) > max_turnover``, the delta is
269 scaled back proportionally toward *prev_cash* so that the constraint
270 is exactly met. When the constraint is already satisfied the input is
271 returned unchanged.
273 Args:
274 new_cash: Proposed cash positions after the solve step, shape
275 ``(n_active,)`` — ``NaN`` values treated as zero.
276 prev_cash: Cash positions at the previous step, shape
277 ``(n_active,)`` — ``NaN`` values treated as zero.
278 max_turnover: Maximum allowed L1 norm of the position change.
280 Returns:
281 np.ndarray: The (possibly scaled) new cash positions.
282 """
283 curr = np.nan_to_num(new_cash, nan=0.0)
284 prev = np.nan_to_num(prev_cash, nan=0.0)
285 delta = curr - prev
286 total_delta = float(np.sum(np.abs(delta)))
287 if total_delta > max_turnover:
288 scale = max_turnover / total_delta
289 return cast("np.ndarray", prev + delta * scale)
290 return new_cash
292 def _replay_positions(
293 self: _EngineProtocol,
294 risk_pos_np: np.ndarray,
295 cash_pos_np: np.ndarray,
296 vola_np: np.ndarray,
297 ) -> None:
298 """Replay positions across all rows, filling position arrays.
300 Iterates `_iter_solve`, writes risk and cash positions into the
301 provided pre-allocated arrays. Both arrays are mutated **in-place**.
303 When `max_turnover` is set, the L1 norm of the
304 position change ``sum(|x_t - x_{t-1}|)`` is capped at that value by
305 proportionally scaling the delta toward the previous position before
306 writing to ``cash_pos_np``.
308 Args:
309 risk_pos_np: Pre-allocated ``(T, N)`` array for risk positions.
310 cash_pos_np: Pre-allocated ``(T, N)`` array for cash positions.
311 vola_np: ``(T, N)`` EWMA volatility array.
312 """
313 max_to: float | None = self.cfg.max_turnover
314 for i, _t, mask, pos, _status in self._iter_solve():
315 if pos is not None:
316 new_cash = _SolveMixin._scale_to_cash(pos, vola_np[i, mask])
317 if max_to is not None and i > 0:
318 new_cash = _SolveMixin._apply_turnover_constraint(new_cash, cash_pos_np[i - 1, mask], max_to)
319 risk_pos_np[i, mask] = new_cash * vola_np[i, mask]
320 cash_pos_np[i, mask] = new_cash
322 def _iter_matrices(self: _EngineProtocol) -> Generator[MatrixYield, None, None]:
323 r"""Yield ``(i, t, mask, bundle)`` for every timestamp.
325 ``bundle`` is a `MatrixBundle` wrapping the effective
326 $(n_{\text{sub}},\ n_{\text{sub}})$ correlation matrix for the
327 active assets (those with finite prices at timestamp *t*). Yields
328 ``None`` when no valid matrix is available (e.g., before the warm-up
329 period has elapsed or when no assets have finite prices).
331 The behaviour depends on `covariance_config`:
333 * `EwmaShrinkConfig`: Applies `shrink2id` to
334 the EWMA correlation matrix (same computation as
335 `cash_position`).
336 * `SlidingWindowConfig`: Builds a
337 `FactorModel` from the last
338 ``cfg.covariance_config.window`` rows of vol-adjusted returns and returns its
339 `covariance`.
341 Yields:
342 tuple: ``(i, t, mask, bundle)`` where
344 * ``i`` (*int*): Row index into ``self.prices``.
345 * ``t``: Timestamp value from ``self.prices["date"]``.
346 * ``mask`` (*np.ndarray[bool]*): Shape ``(n_assets,)``; ``True``
347 for assets with finite prices at row *i*.
348 * ``bundle`` (`MatrixBundle` | ``None``): Covariance bundle
349 of shape ``(mask.sum(), mask.sum())``, or ``None``.
350 """
351 assets = self.assets
352 prices_num = self.prices.select(assets).to_numpy()
353 dates = self.prices["date"].to_list()
355 if isinstance(self.cfg.covariance_config, EwmaShrinkConfig):
356 cor = self.cor
357 for i, t in enumerate(dates):
358 mask = _SolveMixin._compute_mask(prices_num[i])
359 if not mask.any():
360 yield i, t, mask, None
361 continue
362 corr_n = cor[t]
363 matrix = shrink2id(corr_n, lamb=self.cfg.shrink)[np.ix_(mask, mask)]
364 yield i, t, mask, MatrixBundle(matrix=matrix)
365 else:
366 sw_config = self.cfg.covariance_config
367 win_w: int = sw_config.window
368 win_k: int = sw_config.n_factors
369 ret_adj_np = self.ret_adj.select(assets).to_numpy()
370 for i, t in enumerate(dates):
371 mask = _SolveMixin._compute_mask(prices_num[i])
372 if not mask.any() or i + 1 < win_w:
373 yield i, t, mask, None
374 continue
375 window_ret = ret_adj_np[i + 1 - win_w : i + 1][:, mask]
376 window_ret = np.where(np.isfinite(window_ret), window_ret, 0.0)
377 n_sub = int(mask.sum())
378 k_eff = min(win_k, win_w, n_sub)
379 try:
380 fm = FactorModel.from_returns(window_ret, k=k_eff)
381 yield i, t, mask, MatrixBundle(matrix=fm.covariance)
382 except (np.linalg.LinAlgError, ValueError) as exc:
383 _logger.warning("Factor model fit failed at t=%s: %s", t, exc)
384 yield i, t, mask, None
386 @staticmethod
387 def _batched_solve_group(
388 group: list[tuple[int, datetime.date, np.ndarray, np.ndarray, np.ndarray]],
389 denom_tol: float,
390 ) -> dict[int, SolveYield]:
391 """Solve a batch of linear systems sharing the same active-asset mask.
393 Stacks the ``len(group)`` systems into a ``(G, n, n)`` coefficient tensor
394 and a ``(G, n)`` right-hand-side matrix, then dispatches a single
395 ``numpy.linalg.solve`` call (which maps to a single batched LAPACK
396 routine). Denominators are computed directly from the batch result as
397 ``sqrt(mu_i · pos_i)`` — algebraically identical to the per-row
398 `inv_a_norm` call.
400 Falls back to row-by-row `_compute_position` when
401 ``numpy.linalg.solve`` raises ``LinAlgError`` (any matrix in the batch
402 is singular).
404 Args:
405 group: List of ``(i, t, mask, expected_mu, matrix)`` tuples; all
406 entries share the same boolean mask and therefore the same
407 ``n_active x n_active`` matrix shape.
408 denom_tol: Passed through to `_denom_guard_yield`.
410 Returns:
411 dict: Mapping from row index ``i`` to its `SolveYield`.
412 """
413 results: dict[int, SolveYield] = {}
414 a_stack = np.stack([row[4] for row in group]) # (G, n, n)
415 mu_stack = np.stack([row[3] for row in group]) # (G, n)
417 try:
418 # numpy.linalg.solve requires the RHS to be (..., M, K) when a is (..., M, M).
419 # Reshape mu_stack from (G, n) → (G, n, 1) so core dims match, then squeeze.
420 pos_stack = np.linalg.solve(a_stack, mu_stack[..., np.newaxis])[..., 0] # (G, n)
421 except np.linalg.LinAlgError:
422 # At least one matrix is singular — fall back to sequential per-row solve.
423 for i, t, mask, expected_mu, matrix in group:
424 results[i] = _SolveMixin._compute_position(
425 i, t, mask, expected_mu, MatrixBundle(matrix=matrix), denom_tol
426 )
427 return results
429 # Denominators: sqrt(mu_i^T A_i^{-1} mu_i) = sqrt(mu_i · pos_i).
430 dots = (mu_stack * pos_stack).sum(axis=1) # (G,)
431 denoms = np.where(dots > 0.0, np.sqrt(dots), np.nan)
433 for (i, t, mask, expected_mu, _matrix), pos, denom in zip(group, pos_stack, denoms, strict=True):
434 results[i] = _SolveMixin._denom_guard_yield(i, t, mask, expected_mu, pos, float(denom), denom_tol)
436 return results
438 @staticmethod
439 def _iter_solve_ewma_batched(
440 mu_np: np.ndarray,
441 matrix_yields: list[MatrixYield],
442 denom_tol: float,
443 ) -> Generator[SolveYield, None, None]:
444 r"""Vectorised EwmaShrink solve: batch ``numpy.linalg.solve`` across timestamps.
446 Groups rows by their boolean asset mask so all systems within a group
447 share the same ``(n_active, n_active)`` shape, then stacks them into a
448 ``(G, n, n)`` tensor and calls ``numpy.linalg.solve`` once per unique
449 mask pattern. Results are collected in a dict and yielded in original
450 row order.
452 Denominators are derived from the batch solution as
453 $\sqrt{\mu_i \cdot \mathbf{pos}_i} = \sqrt{\mu_i^\top \Sigma_i^{-1} \mu_i}$,
454 matching the scalar `inv_a_norm` result up
455 to float64 rounding.
457 Any group whose batch solve raises ``LinAlgError`` (singular matrix in
458 the batch) falls back to sequential `_compute_position` for that
459 group only.
461 Args:
462 mu_np: Signal matrix, shape ``(T, n_assets)``.
463 matrix_yields: Pre-collected list from `_iter_matrices`
464 (the EwmaShrinkConfig branch).
465 denom_tol: Denominator guard tolerance.
467 Yields:
468 `SolveYield` tuples in original row order.
469 """
470 # First pass: categorise each row as early-exit or a solve candidate.
471 all_results: dict[int, SolveYield] = {}
472 # mask.tobytes() → list of (i, t, mask, expected_mu, matrix)
473 solve_groups: dict[bytes, list[tuple[int, datetime.date, np.ndarray, np.ndarray, np.ndarray]]] = {}
475 for i, t, mask, bundle in matrix_yields:
476 if bundle is None:
477 all_results[i] = (i, t, mask, np.zeros(int(mask.sum())), SolveStatus.DEGENERATE)
478 continue
479 expected_mu, early = _SolveMixin._row_early_check(i, t, mask, mu_np[i])
480 if early is not None:
481 all_results[i] = early
482 continue
483 mask_key = mask.tobytes()
484 if mask_key not in solve_groups:
485 solve_groups[mask_key] = []
486 solve_groups[mask_key].append((i, t, mask, expected_mu, bundle.matrix))
488 # Second pass: batch-solve each mask group.
489 for group in solve_groups.values():
490 all_results.update(_SolveMixin._batched_solve_group(group, denom_tol))
492 # Yield in original row order.
493 for i in range(len(matrix_yields)):
494 if i in all_results:
495 yield all_results[i]
497 def _iter_solve(self: _EngineProtocol) -> Generator[SolveYield, None, None]:
498 r"""Yield ``(i, t, mask, pos_or_none, status)`` for every timestamp.
500 Iterates `_iter_matrices` for the per-row covariance sub-matrix,
501 then applies `_row_early_check` (mask/signal guard) and
502 `_compute_position` (linear solve and denominator guard). The two
503 covariance modes differ only in how ``matrix`` is built, which
504 `_iter_matrices` already encapsulates.
506 * ``matrix is None`` → `WARMUP` (sliding-window before
507 sufficient history) or `DEGENERATE` otherwise.
508 * Signal all-zero → `ZERO_SIGNAL`.
509 * Singular or degenerate solve → `DEGENERATE`.
510 * Success → `VALID`.
512 For the `EwmaShrinkConfig` path the solve step is
513 vectorised: rows are grouped by their active-asset mask pattern and each
514 group is solved via a single batched ``numpy.linalg.solve`` call (see
515 `_iter_solve_ewma_batched`). The `SlidingWindowConfig`
516 path retains a sequential per-row solve because the factor-model matrices
517 are constructed lazily and may vary in numerical character across rows.
519 .. note::
521 **Dual-path maintenance obligation**: this method dispatches to two
522 fundamentally different implementations. Any change to solve
523 semantics — a new edge case, a new `SolveStatus` value, or a
524 change to denominator logic — **must be applied to both branches**:
526 * `_iter_solve_ewma_batched` / `_batched_solve_group`
527 (EwmaShrink vectorised path)
528 * The sequential ``_compute_position`` loop below
529 (SlidingWindow path)
531 The cross-path numerical consistency test
532 ``test_ewma_batch_and_sequential_paths_agree`` in
533 ``tests/test_math/test_numerical_regression.py`` will fail
534 whenever the two paths drift apart, surfacing the divergence
535 before it reaches production.
537 Yields:
538 SolveYield: ``(i, t, mask, pos_or_none, status)`` — see
539 `SolveYield` for detailed field descriptions.
540 """
541 mu_np = self.mu.select(self.assets).to_numpy()
542 cov_config = self.cfg.covariance_config
544 if not isinstance(cov_config, SlidingWindowConfig):
545 # EwmaShrinkConfig path: vectorised batch solve grouped by mask pattern.
546 yield from _SolveMixin._iter_solve_ewma_batched(mu_np, list(self._iter_matrices()), self.cfg.denom_tol)
547 return
549 # SlidingWindowConfig path: sequential per-row solve (lazy factor models).
550 win_w: int = cov_config.window
552 for i, t, mask, bundle in self._iter_matrices():
553 if bundle is None:
554 # Distinguish SW warmup (insufficient history) from no-data / model-failure.
555 if mask.any() and i + 1 < win_w:
556 yield i, t, mask, None, SolveStatus.WARMUP
557 else:
558 yield i, t, mask, np.zeros(int(mask.sum())), SolveStatus.DEGENERATE
559 continue
560 expected_mu, early = _SolveMixin._row_early_check(i, t, mask, mu_np[i])
561 if early is not None:
562 yield early
563 continue
564 yield _SolveMixin._compute_position(i, t, mask, expected_mu, bundle, self.cfg.denom_tol)
566 def warmup_state(self: _EngineProtocol) -> WarmupState:
567 """Return the final `WarmupState` after replaying the full batch.
569 Encapsulates the position replay loop that was previously duplicated
570 inside `from_warmup`. By centralising the loop
571 here, `from_warmup` no longer needs to call the
572 private `_iter_solve` generator directly.
574 Returns:
575 WarmupState: A frozen dataclass with:
577 * ``prev_cash_pos`` - cash-position vector at the last row,
578 shape ``(n_assets,)``.
580 Examples:
581 >>> import numpy as np
582 >>> import polars as pl
583 >>> from basanos.math import BasanosConfig, BasanosEngine
584 >>> rng = np.random.default_rng(0)
585 >>> dates = list(range(30))
586 >>> prices = pl.DataFrame({
587 ... "date": dates,
588 ... "A": np.cumprod(1 + rng.normal(0.001, 0.02, 30)) * 100.0,
589 ... "B": np.cumprod(1 + rng.normal(0.001, 0.02, 30)) * 150.0,
590 ... })
591 >>> mu = pl.DataFrame({
592 ... "date": dates,
593 ... "A": rng.normal(0, 0.5, 30),
594 ... "B": rng.normal(0, 0.5, 30),
595 ... })
596 >>> cfg = BasanosConfig(vola=5, corr=10, clip=3.0, shrink=0.5, aum=1e6)
597 >>> engine = BasanosEngine(prices=prices, mu=mu, cfg=cfg)
598 >>> ws = engine.warmup_state()
599 >>> ws.prev_cash_pos.shape
600 (2,)
601 """
602 assets = self.assets
603 n_rows = self.prices.height
604 vola_np = self.vola.select(assets).to_numpy()
606 risk_pos_np = np.full((n_rows, len(assets)), np.nan, dtype=float)
607 cash_pos_np = np.full((n_rows, len(assets)), np.nan, dtype=float)
609 _SolveMixin._replay_positions(self, risk_pos_np, cash_pos_np, vola_np)
610 prev_cash_pos = cash_pos_np[-1].copy()
611 return WarmupState(prev_cash_pos=prev_cash_pos)