Coverage for src/basanos/math/_stream.py: 99%
299 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-23 05:58 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-23 05:58 +0000
1"""Incremental (streaming) API for BasanosEngine.
3This private module defines three public symbols:
5* `_StreamState` — mutable dataclass that persists all O(N²)
6 accumulator state between consecutive `step` calls. Kept separate
7 from the engine so the state layout can be read and tested in isolation.
8* `StepResult` — frozen dataclass returned by each
9 `step` call.
10* `BasanosStream` — incremental façade with a
11 `from_warmup` classmethod and a
12 `step` method.
14EWM correlation state model
15----------------------------
16In EWM mode the correlation at each step is recomputed by calling
17``ewm_covariance`` from ``cvx.linalg`` over the full growing history of
18vol-adjusted returns stored in ``corr_ret_buf``. This keeps the incremental
19and batch paths numerically identical at the cost of O(T·N²) time per step
20(acceptable for small N or short warmup histories).
22The volatility accumulators (``vola_*``, ``pct_*``) use a simpler scalar
23recurrence and store the running sums directly as ``(N,)`` arrays.
25Memory
26------
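Total incremental state is O(T·N) for the growing EWM history buffer, plus
12 × (N,) arrays (the ten ``vola_*``/``pct_*`` accumulators, ``prev_price``
and ``prev_cash_pos``) and O(1) scalars. With SlidingWindowConfig the growing
buffer is replaced by a fixed (W, N) rolling array, independent of T.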
30"""
32from __future__ import annotations
34import dataclasses
35import logging
36import os
37from typing import Any, cast
39import numpy as np
40import polars as pl
41from cvx.linalg import cov_to_corr
42from cvx.linalg.ewm_cov import ewm_covariance
43from scipy.signal import lfilter
45from ..exceptions import MissingDateColumnError, StreamStateCorruptError
46from ._config import BasanosConfig, EwmaShrinkConfig, SlidingWindowConfig
47from ._engine_solve import MatrixBundle, SolveStatus, _SolveMixin
48from ._factor_model import FactorModel
49from ._signal import shrink2id
_logger: logging.Logger = logging.getLogger(__name__)

#: Archive layout version stamped into every `save` output. Bump it whenever
#: the layout changes in a backward-incompatible way. `load` compares the
#: stored value before deserialising anything, so a mismatch surfaces as a
#: clear error rather than a silent ``KeyError`` or wrongly populated state.
_SAVE_FORMAT_VERSION: int = 3
@dataclasses.dataclass
class _StreamState:
    """Mutable state carrier for one `BasanosStream` instance.

    ``BasanosStream.step()`` writes new values into these fields on every
    tick. The sliding-window buffer ``sw_ret_buf`` is shifted in place; every
    other field is rebound to a freshly computed array. The class is
    intentionally *not* frozen, so the step method can modify fields directly
    without creating a new object on every tick.

    Field order matters. It defines the positional ``__init__`` signature,
    and the field names are also collected into the module's
    ``_REQUIRED_KEYS`` set.

    EWM correlation state
    ~~~~~~~~~~~~~~~~~~~~~
    ``corr_ret_buf`` holds the growing history of vol-adjusted returns.
    ``ewm_covariance`` uses it to recompute the correlation matrix on each
    step. It is ``None`` for ``SlidingWindowConfig``, which uses
    ``sw_ret_buf`` instead.

    EWM accumulator state (volatility)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    ``vola_*`` and ``pct_*`` accumulate the running weighted sums needed to
    compute exponentially-weighted standard deviations:

    * ``s_x``  — EWM sum of x (numerator of the mean)
    * ``s_x2`` — EWM sum of x² (numerator of the second moment)
    * ``s_w``  — EWM sum of weights (denominator)
    * ``s_w2`` — EWM sum of squared weights (for bias correction)

    The ``vola_*`` sums are fed log-returns; the ``pct_*`` sums are fed
    pct-returns. Both decay with ``beta_vola = (cfg.vola - 1) / cfg.vola``
    (from ``com = cfg.vola - 1``), and ``s_w2`` decays with ``beta_vola**2``.

    Attributes:
        corr_ret_buf: Growing history of vol-adjusted returns used by
            ``ewm_covariance``; shape ``(T, N)`` for EwmaShrinkConfig.
            ``None`` for SlidingWindowConfig.
        vola_s_x: EWM sum of log-returns; shape ``(N,)``.
        vola_s_x2: EWM sum of squared log-returns; shape ``(N,)``.
        vola_s_w: EWM weight sum for vol accumulators; shape ``(N,)``.
        vola_s_w2: EWM squared-weight sum for vol accumulators; shape ``(N,)``.
        vola_count: Cumulative finite observation count for vol; shape ``(N,)``
            dtype int.
        pct_s_x: EWM sum of pct-returns; shape ``(N,)``.
        pct_s_x2: EWM sum of squared pct-returns; shape ``(N,)``.
        pct_s_w: EWM weight sum for pct accumulators; shape ``(N,)``.
        pct_s_w2: EWM squared-weight sum for pct accumulators; shape ``(N,)``.
        pct_count: Cumulative finite observation count for pct-return vol;
            shape ``(N,)`` dtype int.
        prev_price: Last price row seen, used to compute returns on the next
            step; shape ``(N,)``.
        prev_cash_pos: Last cash position, used to apply the turnover constraint
            on the next step; shape ``(N,)``.
        step_count: Number of rows processed so far. ``from_warmup``
            initialises it to the number of warmup rows, and each ``step``
            call increments it by one. Compared against the warmup threshold.
        sw_ret_buf: Rolling buffer of the last ``W`` vol-adjusted returns
            (oldest row first), shape ``(W, N)``, used by SlidingWindowConfig.
            NaN-padded at the front when the warmup batch was shorter than
            ``W``. ``None`` for EwmaShrinkConfig.
    """

    # ── EWM correlation history buffer — (T, N) for EwmaShrinkConfig ─────────
    corr_ret_buf: np.ndarray | None  # (T, N) growing history of vol-adj returns; None for SlidingWindowConfig

    # ── EWMA accumulators for vol_adj (log-return std; com=vola-1, min_samples=1) ──
    vola_s_x: np.ndarray  # (N,)
    vola_s_x2: np.ndarray  # (N,)
    vola_s_w: np.ndarray  # (N,)
    vola_s_w2: np.ndarray  # (N,)
    vola_count: np.ndarray  # (N,) int

    # ── EWMA accumulators for vola (pct-return std; com=vola-1, min_samples=vola) ──
    pct_s_x: np.ndarray  # (N,)
    pct_s_x2: np.ndarray  # (N,)
    pct_s_w: np.ndarray  # (N,)
    pct_s_w2: np.ndarray  # (N,)
    pct_count: np.ndarray  # (N,) int

    # ── Last-seen vectors and step counter ────────────────────────────────────
    prev_price: np.ndarray  # (N,) last price row (to compute returns at next step)
    prev_cash_pos: np.ndarray  # (N,) last cash position (for turnover constraint at next step)
    step_count: int  # rows processed so far (starts at the warmup row count)

    # ── SlidingWindowConfig state — None for EwmaShrinkConfig ────────────────
    # shape (W, N): last W vol-adjusted returns (oldest row first); None when
    # using EwmaShrinkConfig. corr_ret_buf above is unused (None) in this mode;
    # sw_ret_buf carries all the correlation state instead and is shifted in
    # place by step().
    sw_ret_buf: np.ndarray | None = None  # (W, N) rolling buffer, or None
#: Set of keys that `save` writes to the ``.npz`` archive.
#:
#: It is the union of every `_StreamState` field name and three metadata keys.
#: Because the field names come from ``dataclasses.fields``, adding a field to
#: ``_StreamState`` automatically adds its key here. The metadata keys
#: (``format_version``, ``cfg_json`` and ``assets``) are archive entries but
#: not ``_StreamState`` fields, so they are listed explicitly.
_REQUIRED_KEYS: frozenset[str] = frozenset(
    [field.name for field in dataclasses.fields(_StreamState)]
    + ["format_version", "cfg_json", "assets"]
)
@dataclasses.dataclass(frozen=True)
class StepResult:
    """Frozen dataclass representing the output of a single ``BasanosStream`` step.

    Each call to ``BasanosStream.step()`` returns one ``StepResult`` capturing
    the optimised cash positions, the per-asset volatility estimate, the step
    date, and a status label that describes the solver outcome for that
    timestep.

    Attributes:
        date: The label passed as the ``date`` argument of
            ``BasanosStream.step()``, stored verbatim (``None`` when the
            caller omits it). Typically a Python `date`, `datetime`, or a
            Polars temporal scalar taken from the caller's price data.
        cash_position: Optimised cash-position vector, shape ``(N,)``.
            Entries are ``NaN`` during warmup and for assets that are
            inactive at this step (e.g. no finite price).
        status: Solver outcome label for this timestep
            (`SolveStatus`). Since `SolveStatus`
            is a ``StrEnum``, values compare equal to their string equivalents
            (e.g. ``result.status == "valid"`` is ``True``):

            * ``'warmup'`` — fewer rows have been seen than the warmup
              requires (``cfg.corr`` rows in EWM mode, ``window`` rows in
              sliding-window mode). All positions and vola are ``NaN``.
            * ``'zero_signal'`` — the expected-return signal ``mu`` is
              (numerically) zero for every active asset. Positions are set to
              zero rather than solved.
            * ``'degenerate'`` — no reliable position could be computed, for
              example because the factor fit or linear solve failed, or the
              normalisation denominator was not above ``cfg.denom_tol``. In
              sliding-window mode active assets are set to ``0.0`` and
              inactive ones stay ``NaN``. If no asset is active, everything is
              ``NaN``. NOTE(review): the EWM path fills in whatever
              ``_SolveMixin._compute_position`` returns; confirm its values.
            * ``'valid'`` — normal operation; ``cash_position`` holds the
              optimised allocations.
        vola: Per-asset EWMA percentage-return volatility, shape ``(N,)``.
            Values are ``NaN`` during warmup, and for an asset until at
            least ``cfg.vola`` finite pct-returns have been observed.

    Examples:
        >>> import numpy as np
        >>> result = StepResult(
        ...     date="2024-01-02",
        ...     cash_position=np.array([1000.0, -500.0]),
        ...     status="valid",
        ...     vola=np.array([0.012, 0.018]),
        ... )
        >>> result.status
        'valid'
        >>> result.cash_position.shape
        (2,)
    """

    date: object
    cash_position: np.ndarray
    status: SolveStatus
    vola: np.ndarray
205# ---------------------------------------------------------------------------
206# Helper: unbiased EWMA std from running accumulators
207# ---------------------------------------------------------------------------
def _ewm_std_from_state(
    s_x: np.ndarray,
    s_x2: np.ndarray,
    s_w: np.ndarray,
    s_w2: np.ndarray,
    count: np.ndarray,
    min_samples: int,
) -> np.ndarray:
    r"""Turn running EWMA accumulators into a Bessel-corrected standard deviation.

    Uses the same formula as ``polars.Expr.ewm_std(adjust=True)``::

        var_biased   = s_x2/s_w - (s_x/s_w)^2
        correction   = s_w^2 / (s_w^2 - s_w2)      # Bessel correction
        var_unbiased = var_biased * correction
        std          = sqrt(max(0, var_unbiased))

    Here ``s_w2 = sum(wi^2)`` is the sum of squared EWM weights.

    Parameters
    ----------
    s_x, s_x2, s_w, s_w2:
        Running accumulators, each of shape ``(N,)``.
    count:
        Integer count of finite observations per asset, shape ``(N,)``.
    min_samples:
        Minimum number of finite observations before an asset gets a
        non-NaN result.

    Returns:
    -------
    np.ndarray of shape ``(N,)`` holding one standard deviation per asset.
    Assets with ``count < min_samples`` are NaN. Assets whose correction
    denominator is not positive (a single observation) report 0.0.
    """
    eligible = count >= min_samples
    if not np.any(eligible):
        # No asset has enough history yet, so the whole vector is NaN.
        return np.full(len(s_x), np.nan, dtype=float)

    with np.errstate(divide="ignore", invalid="ignore"):
        has_weight = s_w > 0
        first_moment = np.where(has_weight, s_x / s_w, 0.0)
        second_moment = np.where(has_weight, s_x2 / s_w, 0.0)
        biased_var = np.maximum(second_moment - first_moment**2, 0.0)
        w_total_sq = s_w**2
        # Exactly zero when only one observation has been seen.
        bessel_denom = w_total_sq - s_w2
        unbiased_var = np.where(bessel_denom > 0, biased_var * w_total_sq / bessel_denom, 0.0)
        sigma = np.sqrt(unbiased_var)

    return np.where(eligible, sigma, np.nan)
263# ---------------------------------------------------------------------------
264# Helper: batch EWMA volatility accumulators from a returns matrix
265# ---------------------------------------------------------------------------
def _ewm_vol_accumulators_from_batch(
    returns: np.ndarray,
    beta: float,
    beta_sq: float,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    r"""Compute final EWMA volatility accumulators from a batch of returns.

    Implements the same IIR recurrence as `step`, vectorised over *T*
    timesteps with ``scipy.signal.lfilter``. The five returned arrays are
    identical to the accumulators obtained by feeding each row of *returns*
    through the scalar step-by-step recurrence, starting from all-zero state::

        s_x[t]  = beta   * s_x[t-1]  + (x[t]   if finite else 0)
        s_x2[t] = beta   * s_x2[t-1] + (x[t]^2 if finite else 0)
        s_w[t]  = beta   * s_w[t-1]  + (1      if finite else 0)
        s_w2[t] = beta^2 * s_w2[t-1] + (1      if finite else 0)

    Parameters
    ----------
    returns:
        Float array of shape ``(T, N)``. Non-finite entries (NaN or ±inf)
        are treated as missing observations. They contribute nothing to the
        numerator sums and do not increment the weight accumulators (which
        still decay). ``T`` may be 0.
    beta:
        EWM decay factor for ``s_x``, ``s_x2``, and ``s_w``
        (``beta = com / (1 + com)`` for ``com = cfg.vola - 1``).
    beta_sq:
        Squared decay factor used for ``s_w2``. Must equal ``beta ** 2``.

    Returns:
    -------
    s_x, s_x2, s_w, s_w2 : np.ndarray of shape ``(N,)``
        Final EWMA running accumulators after processing all *T* rows.
        For an empty batch (``T == 0``) these are all zero, which is the
        initial state of the recurrence.
    count : np.ndarray of shape ``(N,)`` dtype int
        Number of finite observations per asset.

    Notes:
    -----
    This function is the shared implementation used by
    `from_warmup` for both the log-return (``vola_*``)
    and pct-return (``pct_*``) accumulators. Keeping a single implementation
    here guarantees that the batch and incremental paths stay in sync when the
    recurrence definition changes.
    """
    returns = np.asarray(returns, dtype=float)
    if returns.shape[0] == 0:
        # lfilter(...)[-1] would raise IndexError on an empty time axis.
        # Return the recurrence's zero initial state instead, as distinct arrays.
        tail = returns.shape[1:]
        return (
            np.zeros(tail),
            np.zeros(tail),
            np.zeros(tail),
            np.zeros(tail),
            np.zeros(tail, dtype=int),
        )

    finite = np.isfinite(returns)  # (T, N) bool
    weights = finite.astype(np.float64)  # 1.0 for observed, 0.0 for missing
    x_z = np.where(finite, returns, 0.0)  # missing entries contribute 0
    filt_a = np.array([1.0, -beta])
    filt_a2 = np.array([1.0, -beta_sq])

    # Each lfilter call runs y[t] = x[t] + decay * y[t-1] down axis 0.
    # Only the final row (the state after all T rows) is kept.
    s_x: np.ndarray = lfilter([1.0], filt_a, x_z, axis=0)[-1]
    s_x2: np.ndarray = lfilter([1.0], filt_a, x_z**2, axis=0)[-1]
    s_w: np.ndarray = lfilter([1.0], filt_a, weights, axis=0)[-1]
    s_w2: np.ndarray = lfilter([1.0], filt_a2, weights, axis=0)[-1]
    count: np.ndarray = finite.sum(axis=0).astype(int)

    return s_x, s_x2, s_w, s_w2, count
327def _resolve_step_vector(
328 values: np.ndarray | dict[str, float],
329 assets: list[str],
330 n_assets: int,
331 arg_name: str,
332) -> np.ndarray:
333 """Resolve one step input to a validated ``(N,)`` float vector.
335 Args:
336 values: Raw input provided to ``step`` as either dict or array-like.
337 assets: Ordered asset names used when ``values`` is a mapping.
338 n_assets: Expected vector length.
339 arg_name: Argument label used in shape-mismatch errors.
341 Returns:
342 A float64 numpy vector of shape ``(n_assets,)``.
344 Raises:
345 ValueError: If the resolved vector does not match ``(n_assets,)``.
346 """
347 if isinstance(values, dict):
348 vector = np.array([float(values[a]) for a in assets], dtype=float)
349 else:
350 vector = np.asarray(values, dtype=float).ravel()
351 if vector.shape != (n_assets,):
352 raise ValueError(f"{arg_name} must have shape ({n_assets},); got {vector.shape}") # noqa: TRY003
353 return vector
356# ---------------------------------------------------------------------------
357# BasanosStream
358# ---------------------------------------------------------------------------
361class BasanosStream:
362 """Incremental (streaming) optimiser backed by a single `_StreamState`.
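    After warming up on a historical batch via `from_warmup`, each call
    to `step` advances the internal state by exactly one row. With
    `SlidingWindowConfig` only a fixed ``(W, N)`` rolling buffer is carried
    forward. With `EwmaShrinkConfig` the correlation matrix is recomputed
    from the full stored history of vol-adjusted returns (warmup rows
    included), so each step costs O(T·N²) for T rows seen so far.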
368 Attributes:
369 assets: Ordered list of asset column names (read-only).
371 Examples:
372 >>> import numpy as np
373 >>> import polars as pl
374 >>> from datetime import date, timedelta
375 >>> from basanos.math import BasanosConfig, BasanosStream
376 >>> rng = np.random.default_rng(0)
377 >>> warmup_len = 60
378 >>> dates = pl.date_range(
379 ... start=date(2024, 1, 1),
380 ... end=date(2024, 1, 1) + timedelta(days=warmup_len),
381 ... interval="1d",
382 ... eager=True,
383 ... )
384 >>> prices = pl.DataFrame({
385 ... "date": dates,
386 ... "A": np.cumprod(1 + rng.normal(0.001, 0.02, warmup_len + 1)) * 100.0,
387 ... "B": np.cumprod(1 + rng.normal(0.001, 0.02, warmup_len + 1)) * 150.0,
388 ... })
389 >>> mu = pl.DataFrame({
390 ... "date": dates,
391 ... "A": rng.normal(0, 0.5, warmup_len + 1),
392 ... "B": rng.normal(0, 0.5, warmup_len + 1),
393 ... })
394 >>> cfg = BasanosConfig(vola=5, corr=10, clip=3.0, shrink=0.5, aum=1e6)
395 >>> stream = BasanosStream.from_warmup(prices.head(warmup_len), mu.head(warmup_len), cfg)
396 >>> result = stream.step(
397 ... prices.select(["A", "B"]).to_numpy()[warmup_len],
398 ... mu.select(["A", "B"]).to_numpy()[warmup_len],
399 ... prices["date"][warmup_len],
400 ... )
401 >>> isinstance(result, StepResult)
402 True
403 >>> result.cash_position.shape
404 (2,)
405 """
407 _cfg: BasanosConfig
408 _assets: list[str]
409 _state: _StreamState
411 def __init__(self, cfg: BasanosConfig, assets: list[str], state: _StreamState) -> None:
412 """Initialise from an explicit config, asset list, and state container."""
413 object.__setattr__(self, "_cfg", cfg)
414 object.__setattr__(self, "_assets", assets)
415 object.__setattr__(self, "_state", state)
417 def __setattr__(self, name: str, value: object) -> None:
418 """Prevent accidental attribute mutation — BasanosStream is immutable."""
419 raise dataclasses.FrozenInstanceError(f"{type(self).__name__}.{name}")
421 @property
422 def assets(self) -> list[str]:
423 """Ordered list of asset column names."""
424 return self._assets
426 # ------------------------------------------------------------------
427 # from_warmup
428 # ------------------------------------------------------------------
430 @classmethod
431 def from_warmup(
432 cls,
433 prices: pl.DataFrame,
434 mu: pl.DataFrame,
435 cfg: BasanosConfig,
436 ) -> BasanosStream:
437 """Build a `BasanosStream` from a historical warmup batch.
439 Runs `BasanosEngine` on the full warmup batch
440 exactly once and extracts the minimal IIR-filter state required for
441 subsequent `step` calls. After this call, each `step`
442 advances the optimiser in O(N^2) time without touching the warmup
443 data again.
445 Parameters
446 ----------
447 prices:
448 Historical price DataFrame. Must contain a ``'date'`` column and
449 at least one numeric asset column with strictly positive,
450 non-monotonic values.
451 mu:
452 Expected-return signal DataFrame aligned row-by-row with
453 ``prices``.
454 cfg:
455 Engine configuration. Both `EwmaShrinkConfig`
456 and `SlidingWindowConfig` are supported.
458 Returns:
459 -------
460 BasanosStream
461 A stream instance whose `step` method is ready to accept the
462 row immediately following the last warmup row.
464 Notes:
465 ------
466 **Short-warmup behaviour with** ``SlidingWindowConfig``: when
467 ``len(prices) < cfg.covariance_config.window``, the internal rolling
468 buffer (``sw_ret_buf``) is NaN-padded for the missing prefix rows.
469 `step` returns ``StepResult(status="warmup")`` for each of the
470 first ``window - len(prices)`` calls, exactly matching the EWM warmup
471 semantics. By the time `step` returns the first non-warmup
472 result the buffer contains only real data — no NaN-padded rows remain.
474 Raises:
475 ------
476 MissingDateColumnError
477 If ``'date'`` is absent from ``prices``.
478 """
479 # 1. Validate -------------------------------------------------------
480 if "date" not in prices.columns:
481 raise MissingDateColumnError("prices")
483 # 2. Build the engine on the full warmup batch ----------------------
484 # Import here to avoid a circular dependency at module level.
485 from .optimizer import BasanosEngine
487 engine = BasanosEngine(prices=prices, mu=mu, cfg=cfg)
488 assets = engine.assets
489 n_assets = len(assets)
490 n_rows = prices.height
491 prices_np = prices.select(assets).to_numpy() # (n_rows, n_assets)
493 # 3. Extract mode-specific state from WarmupState --------------------
494 ws = engine.warmup_state()
495 if isinstance(cfg.covariance_config, EwmaShrinkConfig):
496 # EWM: seed the growing history buffer from engine.ret_adj so that
497 # each subsequent step() can call ewm_covariance over the full history.
498 ret_adj_np = engine.ret_adj.select(assets).to_numpy()
499 corr_ret_buf: np.ndarray | None = ret_adj_np
500 sw_ret_buf: np.ndarray | None = None
501 else:
502 # SW: carry the last W vol-adjusted returns as a rolling buffer.
503 sw_config = cfg.covariance_config
504 win_w = sw_config.window
505 ret_adj_np = engine.ret_adj.select(assets).to_numpy() # (n_rows, N)
506 if n_rows >= win_w:
507 sw_ret_buf = ret_adj_np[-win_w:].copy()
508 else:
509 sw_ret_buf = np.full((win_w, n_assets), np.nan)
510 sw_ret_buf[-n_rows:] = ret_adj_np
511 corr_ret_buf = None
513 # 4. Derive EWMA volatility accumulators (vectorised) ---------------
514 # Both log-return (for vol_adj) and pct-return (for vola) use the
515 # same beta = (vola-1)/vola. NaN observations (leading NaN at row 0
516 # from diff/pct_change) are skipped — the filter input is 0 for NaN
517 # rows and the weight accumulator (s_w) only increments for finite
518 # observations, matching Polars' effective behaviour for a
519 # leading-NaN series.
520 #
521 # Delegate to the shared helper _ewm_vol_accumulators_from_batch so
522 # that the batch and incremental recurrences share a single definition.
523 beta_vola: float = (cfg.vola - 1) / cfg.vola
524 beta_vola_sq: float = beta_vola**2
526 log_ret = np.full((n_rows, n_assets), np.nan, dtype=float)
527 pct_ret = np.full((n_rows, n_assets), np.nan, dtype=float)
528 if n_rows > 1:
529 with np.errstate(divide="ignore", invalid="ignore"):
530 log_ret[1:] = np.log(prices_np[1:] / prices_np[:-1])
531 pct_ret[1:] = prices_np[1:] / prices_np[:-1] - 1.0
533 vola_s_x, vola_s_x2, vola_s_w, vola_s_w2, vola_count = _ewm_vol_accumulators_from_batch(
534 log_ret, beta_vola, beta_vola_sq
535 )
536 pct_s_x, pct_s_x2, pct_s_w, pct_s_w2, pct_count = _ewm_vol_accumulators_from_batch(
537 pct_ret, beta_vola, beta_vola_sq
538 )
540 # 5. Extract prev_cash_pos from WarmupState --------------------------
541 prev_cash_pos: np.ndarray = ws.prev_cash_pos
542 prev_price: np.ndarray = prices_np[-1].copy()
544 # 6. Construct _StreamState and return ------------------------------
545 state = _StreamState(
546 corr_ret_buf=corr_ret_buf,
547 vola_s_x=vola_s_x,
548 vola_s_x2=vola_s_x2,
549 vola_s_w=vola_s_w,
550 vola_s_w2=vola_s_w2,
551 vola_count=vola_count,
552 pct_s_x=pct_s_x,
553 pct_s_x2=pct_s_x2,
554 pct_s_w=pct_s_w,
555 pct_s_w2=pct_s_w2,
556 pct_count=pct_count,
557 prev_price=prev_price,
558 prev_cash_pos=prev_cash_pos,
559 step_count=n_rows,
560 sw_ret_buf=sw_ret_buf,
561 )
562 return cls(cfg=cfg, assets=assets, state=state)
564 # ------------------------------------------------------------------
565 # step
566 # ------------------------------------------------------------------
568 @staticmethod
569 def _warmup_threshold(cfg: BasanosConfig) -> int:
570 """Return the step count at which warmup ends for the configured mode."""
571 if isinstance(cfg.covariance_config, SlidingWindowConfig):
572 return cfg.covariance_config.window
573 return cfg.corr
575 @staticmethod
576 def _persist_state(
577 state: _StreamState,
578 *,
579 corr_ret_buf: np.ndarray | None,
580 vola_s_x: np.ndarray,
581 vola_s_x2: np.ndarray,
582 vola_s_w: np.ndarray,
583 vola_s_w2: np.ndarray,
584 vola_count: np.ndarray,
585 pct_s_x: np.ndarray,
586 pct_s_x2: np.ndarray,
587 pct_s_w: np.ndarray,
588 pct_s_w2: np.ndarray,
589 pct_count: np.ndarray,
590 new_price: np.ndarray,
591 new_cash_pos: np.ndarray | None = None,
592 ) -> None:
593 """Persist accumulators, last-seen vectors, and increment step count."""
594 state.corr_ret_buf = corr_ret_buf
595 state.vola_s_x = vola_s_x
596 state.vola_s_x2 = vola_s_x2
597 state.vola_s_w = vola_s_w
598 state.vola_s_w2 = vola_s_w2
599 state.vola_count = vola_count
600 state.pct_s_x = pct_s_x
601 state.pct_s_x2 = pct_s_x2
602 state.pct_s_w = pct_s_w
603 state.pct_s_w2 = pct_s_w2
604 state.pct_count = pct_count
605 state.prev_price = new_price.copy()
606 if new_cash_pos is not None:
607 state.prev_cash_pos = new_cash_pos.copy()
608 state.step_count += 1
610 @staticmethod
611 def _warmup_result(n_assets: int, date: Any) -> StepResult:
612 """Build a standard warmup ``StepResult`` payload."""
613 return StepResult(
614 date=date,
615 cash_position=np.full(n_assets, np.nan),
616 status=SolveStatus.WARMUP,
617 vola=np.full(n_assets, np.nan),
618 )
    def _solve_sliding_window_position(
        self,
        *,
        cfg: BasanosConfig,
        state: _StreamState,
        mask: np.ndarray,
        new_m: np.ndarray,
        vola_vec: np.ndarray,
        n_assets: int,
        date: Any,
    ) -> tuple[np.ndarray, SolveStatus]:
        """Solve one step in SlidingWindow mode and return cash position + status.

        Fits a `FactorModel` on the rolling buffer ``state.sw_ret_buf``,
        restricted to the active assets. It then solves for the signal,
        normalises by ``sqrt(mu · x)``, and divides by per-asset volatility.

        Args:
            cfg: Engine configuration; ``cfg.covariance_config`` is expected
                to be a `SlidingWindowConfig`.
            state: Stream state; only ``state.sw_ret_buf`` is read.
            mask: Boolean ``(N,)`` mask of active assets (``step`` passes
                ``np.isfinite(new_p)``).
            new_m: Expected-return signal for this step, shape ``(N,)``.
            vola_vec: Per-asset pct-return volatility, shape ``(N,)``.
            n_assets: Number of assets ``N``.
            date: Step label, used only in log messages.

        Returns:
            ``(new_cash_pos, status)``. Inactive assets are always ``NaN``.
            Active assets are:

            * no active assets: everything ``NaN``, ``DEGENERATE``;
            * factor fit raises ``LinAlgError``/``ValueError``: ``0.0``,
              ``DEGENERATE`` (logged at debug level);
            * signal numerically zero: ``0.0``, ``ZERO_SIGNAL``;
            * solve raises: ``0.0``, ``DEGENERATE`` (warning logged);
            * normalisation denominator non-finite or ``<= cfg.denom_tol``:
              ``0.0``, ``DEGENERATE`` (warning logged);
            * otherwise the scaled position and ``VALID``.

        Note: ``self`` is not used by the body.
        """
        from cvx.linalg import SingularMatrixError

        new_cash_pos = np.full(n_assets, np.nan, dtype=float)
        # Default status for every early exit below except ZERO_SIGNAL.
        status = SolveStatus.DEGENERATE
        sw_config = cast(SlidingWindowConfig, cfg.covariance_config)
        if not mask.any():
            return new_cash_pos, status

        win_w = sw_config.window
        win_k = sw_config.n_factors
        sw_ret_buf = cast(np.ndarray, state.sw_ret_buf)
        # Non-finite entries (NaN-padded warmup rows, missing returns) are
        # replaced by 0 so the factor fit sees a dense matrix.
        window_ret = np.where(
            np.isfinite(sw_ret_buf[:, mask]),
            sw_ret_buf[:, mask],
            0.0,
        )
        n_sub = int(mask.sum())
        # Number of factors is capped by the configured count, the window
        # length, the number of active assets and the optional max_components.
        k_eff = min(win_k, win_w, n_sub)
        if sw_config.max_components is not None:
            k_eff = min(k_eff, sw_config.max_components)
        try:
            fm = FactorModel.from_returns(window_ret, k=k_eff)
        except (np.linalg.LinAlgError, ValueError) as exc:
            _logger.debug("Sliding window SVD failed at date=%s: %s", date, exc)
            new_cash_pos[mask] = 0.0
            return new_cash_pos, status

        # NaN signals become 0 (np.nan_to_num also maps ±inf to large finite values).
        expected_mu = np.nan_to_num(new_m[mask])
        if np.allclose(expected_mu, 0.0):
            new_cash_pos[mask] = 0.0
            return new_cash_pos, SolveStatus.ZERO_SIGNAL

        try:
            x = fm.solve(expected_mu)
            # A NaN dot product collapses to 0.0 through max(0.0, nan), so it
            # is handled by the denominator check below (assuming
            # cfg.denom_tol >= 0).
            denom_val = float(np.sqrt(max(0.0, float(np.dot(expected_mu, x)))))
        except (SingularMatrixError, np.linalg.LinAlgError) as exc:
            _logger.warning("Woodbury solve failed at date=%s: %s", date, exc)
            new_cash_pos[mask] = 0.0
            return new_cash_pos, status

        if not np.isfinite(denom_val) or denom_val <= cfg.denom_tol:
            _logger.warning(
                "Positions zeroed at date=%s (sliding_window): normalisation "
                "denominator degenerate (denom=%s, denom_tol=%s).",
                date,
                denom_val,
                cfg.denom_tol,
            )
            new_cash_pos[mask] = 0.0
            return new_cash_pos, status

        risk_pos = x / denom_val
        vola_sub = vola_vec[mask]
        # Only "invalid" warnings (e.g. NaN vola) are silenced here.
        # NOTE(review): a zero vola still emits a divide-by-zero warning.
        with np.errstate(invalid="ignore"):
            new_cash_pos[mask] = risk_pos / vola_sub
        return new_cash_pos, SolveStatus.VALID
689 @staticmethod
690 def _solve_ewma_position(
691 *,
692 cfg: BasanosConfig,
693 state: _StreamState,
694 corr_ret_buf: np.ndarray,
695 mask: np.ndarray,
696 new_m: np.ndarray,
697 vola_vec: np.ndarray,
698 assets: list[str],
699 n_assets: int,
700 date: Any,
701 ) -> tuple[np.ndarray, SolveStatus]:
702 """Solve one step in EWMA mode and return cash position + status."""
703 new_cash_pos = np.full(n_assets, np.nan, dtype=float)
704 buf = corr_ret_buf # (T, N) — already includes the new row
705 span = 2 * cfg.corr + 1
706 t = buf.shape[0]
707 cols = [pl.Series(a, buf[:, i]).fill_nan(None) for i, a in enumerate(assets)]
708 pl_df = pl.DataFrame([pl.Series("t", list(range(t))), *cols])
709 cov_dict = ewm_covariance(pl_df, assets=assets, index_col="t", window=span, warmup=cfg.corr)
710 if not cov_dict:
711 corr = np.full((n_assets, n_assets), np.nan)
712 else:
713 # keys are the integer ``t`` index values built from ``range(t)`` above
714 latest = max(cov_dict, key=lambda k: cast("int", k))
715 corr = cov_to_corr(cov_dict[latest], cfg.min_corr_denom)
716 matrix = shrink2id(corr, lamb=cfg.shrink)
717 expected_mu, early = _SolveMixin._row_early_check(state.step_count, date, mask, new_m)
718 if early is not None:
719 _, _, _, pos, status = early
720 new_cash_pos[mask] = pos
721 return new_cash_pos, status
723 corr_sub = matrix[np.ix_(mask, mask)]
724 _, _, _, pos, status = _SolveMixin._compute_position(
725 state.step_count, date, mask, expected_mu, MatrixBundle(matrix=corr_sub), cfg.denom_tol
726 )
727 if status == SolveStatus.VALID:
728 new_cash_pos[mask] = _SolveMixin._scale_to_cash(cast(np.ndarray, pos), vola_vec[mask])
729 else:
730 new_cash_pos[mask] = pos
731 return new_cash_pos, status
733 def step(
734 self,
735 new_prices: np.ndarray | dict[str, float],
736 new_mu: np.ndarray | dict[str, float],
737 date: Any = None,
738 ) -> StepResult:
739 """Advance the stream by one row and return the new optimised position.
741 Parameters
742 ----------
743 new_prices:
744 Per-asset prices for the new timestep. Either a numpy array of
745 shape ``(N,)`` (assets ordered as in `assets`) or a dict
746 mapping asset names to price values.
747 new_mu:
748 Per-asset expected-return signals, same format as ``new_prices``.
749 date:
750 Timestamp for this step (stored in `date`
751 verbatim; not used in any computation).
753 Returns:
754 -------
755 StepResult
756 Frozen dataclass with ``cash_position``, ``vola``, ``status``, and
757 ``date`` for this timestep.
758 """
759 cfg = self._cfg
760 assets = self._assets
761 state = self._state
762 n_assets = len(assets)
764 # ── Check if still in the warmup period ──────────────────────────────
765 # step_count is initialised to n_rows in from_warmup.
766 #
767 # EwmaShrinkConfig: in_warmup is True for the first (cfg.corr - n_rows)
768 # calls when the warmup batch was shorter than cfg.corr (not enough rows
769 # to populate the EWM correlation matrix).
770 #
771 # SlidingWindowConfig: in_warmup is True for the first (window - n_rows)
772 # calls when the warmup batch was shorter than the window. During this
773 # period sw_ret_buf still contains NaN-padded prefix rows; each step
774 # shifts one NaN out and appends a real row, so the buffer is fully
775 # populated with real data exactly when in_warmup becomes False.
776 #
777 # In both modes all accumulators are still updated during warmup so that
778 # the state is ready the moment the warmup period ends.
779 _warmup_thresh = self._warmup_threshold(cfg)
780 in_warmup: bool = state.step_count < _warmup_thresh
782 # ── Resolve inputs to (N,) float64 arrays ──────────────────────────
783 new_p = _resolve_step_vector(new_prices, assets, n_assets, "new_prices")
784 new_m = _resolve_step_vector(new_mu, assets, n_assets, "new_mu")
786 prev_p = state.prev_price
787 beta_vola: float = (cfg.vola - 1) / cfg.vola
788 beta_vola_sq: float = beta_vola**2
790 # ── Compute new log-returns and pct-returns ─────────────────────────
791 with np.errstate(divide="ignore", invalid="ignore"):
792 ratio = np.where(
793 np.isfinite(new_p) & np.isfinite(prev_p) & (prev_p > 0),
794 new_p / prev_p,
795 np.nan,
796 )
797 log_ret = np.log(ratio)
798 pct_ret = ratio - 1.0
800 # ── Update log-return EWMA accumulators ────────────────────────────
801 fin_log = np.isfinite(log_ret)
802 vola_s_x = beta_vola * state.vola_s_x + np.where(fin_log, log_ret, 0.0)
803 vola_s_x2 = beta_vola * state.vola_s_x2 + np.where(fin_log, log_ret**2, 0.0)
804 vola_s_w = beta_vola * state.vola_s_w + fin_log.astype(float)
805 vola_s_w2 = beta_vola_sq * state.vola_s_w2 + fin_log.astype(float)
806 vola_count = state.vola_count + fin_log.astype(int)
808 # ── Update pct-return EWMA accumulators ────────────────────────────
809 fin_pct = np.isfinite(pct_ret)
810 pct_s_x = beta_vola * state.pct_s_x + np.where(fin_pct, pct_ret, 0.0)
811 pct_s_x2 = beta_vola * state.pct_s_x2 + np.where(fin_pct, pct_ret**2, 0.0)
812 pct_s_w = beta_vola * state.pct_s_w + fin_pct.astype(float)
813 pct_s_w2 = beta_vola_sq * state.pct_s_w2 + fin_pct.astype(float)
814 pct_count = state.pct_count + fin_pct.astype(int)
816 # ── Compute vol-adjusted return (for the correlation IIR input) ─────
817 log_vol = _ewm_std_from_state(vola_s_x, vola_s_x2, vola_s_w, vola_s_w2, vola_count, min_samples=1)
818 # Divide; std == 0 yields ±inf → clipped to ±cfg.clip (matches Polars)
819 with np.errstate(divide="ignore", invalid="ignore"):
820 vol_adj_val = np.where(
821 fin_log,
822 np.clip(log_ret / log_vol, -cfg.clip, cfg.clip),
823 np.nan,
824 )
826 # ── Mode-specific correlation state update ───────────────────────────
827 if isinstance(cfg.covariance_config, SlidingWindowConfig):
828 # SW: shift the rolling window buffer in-place and append this row.
829 buf = cast(np.ndarray, state.sw_ret_buf) # (W, N), already owned by state
830 buf[:-1] = buf[1:]
831 buf[-1] = vol_adj_val
832 corr_ret_buf = state.corr_ret_buf # None for SW; pass through
833 else:
834 # EWM: append new vol-adjusted return to the growing history buffer.
835 new_row = vol_adj_val[np.newaxis] # (1, N)
836 corr_ret_buf = np.vstack([cast(np.ndarray, state.corr_ret_buf), new_row])
838 # ── Early return during EWM warmup period ───────────────────────────
839 # All accumulators are already updated above; skip the O(N²) matrix
840 # reconstruction and O(N³) Cholesky solve which are wasteful during
841 # warmup — the computed positions would be discarded anyway.
842 if in_warmup:
843 self._persist_state(
844 state,
845 corr_ret_buf=corr_ret_buf,
846 vola_s_x=vola_s_x,
847 vola_s_x2=vola_s_x2,
848 vola_s_w=vola_s_w,
849 vola_s_w2=vola_s_w2,
850 vola_count=vola_count,
851 pct_s_x=pct_s_x,
852 pct_s_x2=pct_s_x2,
853 pct_s_w=pct_s_w,
854 pct_s_w2=pct_s_w2,
855 pct_count=pct_count,
856 new_price=new_p,
857 )
858 return self._warmup_result(n_assets, date)
860 # ── Compute EWMA volatility (pct-return std) — shared ───────────────
861 vola_vec = _ewm_std_from_state(pct_s_x, pct_s_x2, pct_s_w, pct_s_w2, pct_count, min_samples=cfg.vola)
863 # ── Solve for position ───────────────────────────────────────────────
864 mask = np.isfinite(new_p)
865 if isinstance(cfg.covariance_config, SlidingWindowConfig):
866 new_cash_pos, status = self._solve_sliding_window_position(
867 cfg=cfg,
868 state=state,
869 mask=mask,
870 new_m=new_m,
871 vola_vec=vola_vec,
872 n_assets=n_assets,
873 date=date,
874 )
875 else:
876 new_cash_pos, status = self._solve_ewma_position(
877 cfg=cfg,
878 state=state,
879 corr_ret_buf=cast(np.ndarray, corr_ret_buf),
880 mask=mask,
881 new_m=new_m,
882 vola_vec=vola_vec,
883 assets=list(self._assets),
884 n_assets=n_assets,
885 date=date,
886 )
888 # ── Apply turnover constraint ─────────────────────────────────────────
889 if cfg.max_turnover is not None and status == SolveStatus.VALID:
890 new_cash_pos[mask] = _SolveMixin._apply_turnover_constraint(
891 new_cash_pos[mask],
892 state.prev_cash_pos[mask],
893 cfg.max_turnover,
894 )
896 # ── Persist updated state ───────────────────────────────────────────
897 self._persist_state(
898 state,
899 corr_ret_buf=corr_ret_buf,
900 vola_s_x=vola_s_x,
901 vola_s_x2=vola_s_x2,
902 vola_s_w=vola_s_w,
903 vola_s_w2=vola_s_w2,
904 vola_count=vola_count,
905 pct_s_x=pct_s_x,
906 pct_s_x2=pct_s_x2,
907 pct_s_w=pct_s_w,
908 pct_s_w2=pct_s_w2,
909 pct_count=pct_count,
910 new_price=new_p,
911 new_cash_pos=new_cash_pos,
912 )
914 return StepResult(
915 date=date,
916 cash_position=new_cash_pos,
917 status=status,
918 vola=vola_vec,
919 )
def save(self, path: str | os.PathLike[str]) -> None:
    """Write the complete stream state to a ``.npz`` archive at *path*.

    The archive holds a format-version tag, the configuration serialised
    as JSON, the asset list, and one entry per `_StreamState` field, all
    emitted by a single `savez` call. Passing the archive to `load`
    yields a stream whose `step` output is bit-for-bit identical.

    The two optional buffers (``sw_ret_buf`` and ``corr_ret_buf``) are
    written as an empty ``(0, 0)`` float array when they are ``None`` so
    that every field key is always present in the archive.

    Args:
        path: Destination file path. `savez` appends ``.npz``
            automatically when the suffix is absent.

    Examples:
        >>> import tempfile, pathlib, numpy as np
        >>> import polars as pl
        >>> from datetime import date, timedelta
        >>> from basanos.math import BasanosConfig, BasanosStream
        >>> rng = np.random.default_rng(0)
        >>> n = 60
        >>> end = date(2024, 1, 1) + timedelta(days=n - 1)
        >>> dates = pl.date_range(
        ...     date(2024, 1, 1), end, interval="1d", eager=True
        ... )
        >>> prices = pl.DataFrame({
        ...     "date": dates,
        ...     "A": np.cumprod(1 + rng.normal(0.001, 0.02, n)) * 100.0,
        ...     "B": np.cumprod(1 + rng.normal(0.001, 0.02, n)) * 150.0,
        ... })
        >>> mu = pl.DataFrame({
        ...     "date": dates,
        ...     "A": rng.normal(0, 0.5, n),
        ...     "B": rng.normal(0, 0.5, n),
        ... })
        >>> cfg = BasanosConfig(vola=5, corr=10, clip=3.0, shrink=0.5, aum=1e6)
        >>> stream = BasanosStream.from_warmup(prices, mu, cfg)
        >>> with tempfile.TemporaryDirectory() as tmp:
        ...     p = pathlib.Path(tmp) / "stream.npz"
        ...     stream.save(p)
        ...     restored = BasanosStream.load(p)
        ...     restored.assets == stream.assets
        True
    """
    optional_buffers = ("sw_ret_buf", "corr_ret_buf")
    snapshot = self._state
    payload: dict[str, Any] = {}
    # Walk the dataclass fields instead of hard-coding names, so a field
    # added to _StreamState is persisted without editing this method.
    for spec in dataclasses.fields(_StreamState):
        key = spec.name
        item = getattr(snapshot, key)
        if key in optional_buffers:
            # Sentinel: an empty (0, 0) array stands in for None, keeping
            # the key present so load() can detect the missing buffer.
            payload[key] = np.empty((0, 0), dtype=float) if item is None else item
        elif key == "step_count":
            payload[key] = np.array(item)
        else:
            payload[key] = item
    header = {
        "format_version": np.array(_SAVE_FORMAT_VERSION),
        "cfg_json": np.array(self._cfg.model_dump_json()),
        "assets": np.array(self._assets),
    }
    np.savez(path, **header, **payload)
@classmethod
def load(cls, path: str | os.PathLike[str]) -> BasanosStream:
    """Restore a stream previously saved with `save`.

    The archive is validated before any state is rebuilt:

    * the format-version tag must be present and equal to
      ``_SAVE_FORMAT_VERSION``;
    * every key in ``_REQUIRED_KEYS`` must be present.

    `savez` appends ``.npz`` to a string/path destination that lacks the
    suffix. So a *path* that was passed verbatim to `save` is accepted here
    too: if *path* does not exist but ``path + ".npz"`` does, the latter is
    opened.

    Args:
        path: Path to a ``.npz`` archive written by `save`, with or
            without the ``.npz`` suffix.

    Returns:
        A `BasanosStream` whose `step` output is
        bit-for-bit identical to the original stream at the time
        `save` was called.

    Raises:
        ValueError: If the archive has no format-version tag, or its
            version differs from the current ``_SAVE_FORMAT_VERSION``.
        StreamStateCorruptError: If any required key is missing from the
            archive.

    Examples:
        >>> import tempfile, pathlib, numpy as np
        >>> import polars as pl
        >>> from datetime import date, timedelta
        >>> from basanos.math import BasanosConfig, BasanosStream
        >>> rng = np.random.default_rng(1)
        >>> n = 60
        >>> end = date(2024, 1, 1) + timedelta(days=n - 1)
        >>> dates = pl.date_range(
        ...     date(2024, 1, 1), end, interval="1d", eager=True
        ... )
        >>> prices = pl.DataFrame({
        ...     "date": dates,
        ...     "A": np.cumprod(1 + rng.normal(0.001, 0.02, n)) * 100.0,
        ...     "B": np.cumprod(1 + rng.normal(0.001, 0.02, n)) * 150.0,
        ... })
        >>> mu = pl.DataFrame({
        ...     "date": dates,
        ...     "A": rng.normal(0, 0.5, n),
        ...     "B": rng.normal(0, 0.5, n),
        ... })
        >>> cfg = BasanosConfig(vola=5, corr=10, clip=3.0, shrink=0.5, aum=1e6)
        >>> stream = BasanosStream.from_warmup(prices, mu, cfg)
        >>> with tempfile.TemporaryDirectory() as tmp:
        ...     p = pathlib.Path(tmp) / "stream.npz"
        ...     stream.save(p)
        ...     restored = BasanosStream.load(p)
        ...     restored.assets == stream.assets
        True
    """
    # Mirror savez's suffix handling so save("x") / load("x") round-trips.
    # Only str/PathLike inputs are inspected; anything else (e.g. an open
    # file object) is handed to np.load unchanged, as before.
    if isinstance(path, (str, os.PathLike)):
        candidate = os.fsdecode(path)
        if not candidate.endswith(".npz") and not os.path.exists(candidate):
            suffixed = candidate + ".npz"
            if os.path.exists(suffixed):
                path = suffixed

    with np.load(path, allow_pickle=False) as data:
        if "format_version" not in data:
            raise ValueError(  # noqa: TRY003
                "Stream file is missing a format version tag. "
                "It was written with an incompatible version of BasanosStream. "
                "Re-generate it via BasanosStream.from_warmup()."
            )
        found = int(data["format_version"])
        if found != _SAVE_FORMAT_VERSION:
            raise ValueError(  # noqa: TRY003
                f"Stream file was written with format version {found}, "
                f"but the current version is {_SAVE_FORMAT_VERSION}. "
                "Re-generate it via BasanosStream.from_warmup()."
            )
        # Validate that every required key is present. This catches archives
        # that were produced by an older codebase missing a newly added field,
        # or archives that have been manually edited, with a descriptive error
        # instead of a bare KeyError.
        archive_keys = frozenset(data.files)
        missing = _REQUIRED_KEYS - archive_keys
        if missing:
            raise StreamStateCorruptError(missing)
        cfg = BasanosConfig.model_validate_json(data["cfg_json"].item())
        # save() stores the asset names as a numpy array; converting each
        # element back to a built-in str avoids leaking np.str_ scalars.
        assets: list[str] = [str(name) for name in data["assets"]]
        state_kwargs: dict[str, Any] = {}
        for field in dataclasses.fields(_StreamState):
            raw = data[field.name]
            if field.name in ("sw_ret_buf", "corr_ret_buf"):
                # Only the exact (0, 0) sentinel written by save() means None.
                # Any other array, including an empty (0, N) buffer, is
                # restored as-is rather than being collapsed to None.
                state_kwargs[field.name] = None if raw.shape == (0, 0) else raw
            elif field.name == "step_count":
                state_kwargs[field.name] = int(raw)
            else:
                state_kwargs[field.name] = raw
    # Arrays read from an .npz archive are fully materialised in memory, so
    # they remain valid after the archive has been closed.
    state = _StreamState(**state_kwargs)
    return cls(cfg=cfg, assets=assets, state=state)