Coverage for src/basanos/math/_stream.py: 100%
317 statements
coverage.py v7.13.5, created at 2026-04-02 17:47 +0000
"""Incremental (streaming) API for BasanosEngine.

This private module defines three public symbols:

* :class:`_StreamState` — mutable dataclass that persists all O(N²) IIR
  filter and EWMA accumulator state between consecutive
  :meth:`BasanosStream.step` calls. Kept separate from the engine so the
  state layout can be read and tested in isolation.
* :class:`StepResult` — frozen dataclass returned by each
  :meth:`BasanosStream.step` call.
* :class:`BasanosStream` — incremental façade with a
  :meth:`~BasanosStream.from_warmup` classmethod and a
  :meth:`~BasanosStream.step` method.

IIR state model
---------------
The EWM recurrence ``s[t] = beta·s[t-1] + v[t]`` is a causal, single-pole IIR
filter. When the full history is available, ``scipy.signal.lfilter`` solves
all N² pairs in one vectorised call and discards the intermediate array.

In the *incremental* setting there is no history — only the current sample
arrives at each ``step()``. To continue the same recurrence across calls we
need the *filter memory*, i.e. the value of the accumulator at the end of the
previous call.

``scipy.signal.lfilter`` exposes this directly: when called as::

    y, zf = lfilter(b, a, x, zi=zi)

``zi`` is the initial state (shape ``(max(len(a), len(b)) - 1, …)`` = ``(1, …)``
for our first-order filter) and ``zf`` is the *final* state after processing
``x``. Passing the returned ``zf`` back as ``zi`` in the next call is
mathematically equivalent to having run ``lfilter`` over the concatenated
input, so the incremental and batch paths produce bit-for-bit identical
results.

The four correlation accumulators (``corr_zi_x``, ``corr_zi_x2``,
``corr_zi_xy``, ``corr_zi_w``) follow exactly this pattern. Each has shape
``(1, N, N)`` — the leading 1 is the IIR filter order required by
``lfilter``'s ``zi`` argument.

The volatility accumulators (``vola_*``, ``pct_*``) use a simpler scalar
recurrence and store the running sums directly as ``(N,)`` arrays.

Memory
------
Total incremental state is 4×(1, N, N) + (N, N) + 8×(N,) + O(1) scalars,
giving **O(N²)** memory independent of the number of timesteps processed.
"""
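The ``zi``/``zf`` chaining described above can be checked in isolation with a
toy first-order filter (a standalone sketch, independent of the engine; the
``beta`` value and the 60/40 chunk split are arbitrary):

```python
import numpy as np
from scipy.signal import lfilter

beta = 0.9
b, a = [1.0], [1.0, -beta]  # realises s[t] = beta*s[t-1] + v[t]
rng = np.random.default_rng(42)
x = rng.normal(size=100)

# Batch: one call over the full history.
y_batch = lfilter(b, a, x)

# Incremental: two chunks, carrying the final state zf forward as the next zi.
zi = np.zeros(1)  # first-order filter => one state value
y1, zf = lfilter(b, a, x[:60], zi=zi)
y2, _ = lfilter(b, a, x[60:], zi=zf)
y_inc = np.concatenate([y1, y2])

# Same sequential arithmetic either way — identical, not merely close.
assert np.array_equal(y_batch, y_inc)
```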
from __future__ import annotations

import dataclasses
import logging
import os
from typing import TYPE_CHECKING, Any, cast

if TYPE_CHECKING:
    from ._ewm_corr import _EwmCorrState

import numpy as np
import polars as pl
from scipy.signal import lfilter

from ..exceptions import MissingDateColumnError, StreamStateCorruptError
from ._config import BasanosConfig, EwmaShrinkConfig, SlidingWindowConfig
from ._engine_solve import MatrixBundle, SolveStatus, _SolveMixin
from ._ewm_corr import _corr_from_ewm_accumulators
from ._factor_model import FactorModel
from ._signal import shrink2id

_logger = logging.getLogger(__name__)

#: Increment this when the :func:`BasanosStream.save` archive layout changes in
#: a backward-incompatible way. :func:`BasanosStream.load` asserts the stored
#: value matches before deserialising anything, so callers get a clear error
#: instead of a silent ``KeyError`` or wrong state.
_SAVE_FORMAT_VERSION: int = 2
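A minimal sketch of the version-gating pattern this constant supports
(illustrative only — ``FORMAT_VERSION`` and the in-memory ``.npz`` round-trip
stand in for the real ``save``/``load`` methods, which are not shown here):

```python
import io

import numpy as np

FORMAT_VERSION = 2  # hypothetical stand-in for _SAVE_FORMAT_VERSION

# "save": write the version alongside the state arrays.
buf = io.BytesIO()
np.savez(buf, format_version=np.array(FORMAT_VERSION), s_x=np.zeros(3))
buf.seek(0)

# "load": check the version before touching any state.
with np.load(buf) as npz:
    stored = int(npz["format_version"])
    if stored != FORMAT_VERSION:
        raise ValueError(f"unsupported save format {stored}")
    s_x = npz["s_x"]

assert s_x.shape == (3,)
```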
@dataclasses.dataclass
class _StreamState:
    """Mutable state carrier for one :class:`BasanosStream` instance.

    All arrays are updated in-place (or replaced) by ``BasanosStream.step()``.
    The class is intentionally *not* frozen so that the step method can modify
    fields directly without creating a new object on every tick.

    IIR filter state (correlation)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    The four ``corr_zi_*`` fields are the *final conditions* (``zf``) returned
    by ``scipy.signal.lfilter`` after the previous step. They are passed back
    as the ``zi`` argument on the next step so that the incremental filter
    produces the same numerical result as a single batch call over all history.
    See the module docstring for the full derivation.

    ``beta_corr = cfg.corr / (1 + cfg.corr)`` (from ``com = cfg.corr``)

    EWM accumulator state (volatility)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    ``vola_*`` and ``pct_*`` accumulate the running weighted sums needed to
    compute exponentially-weighted standard deviations:

    * ``s_x`` — EWM sum of x (numerator of the mean)
    * ``s_x2`` — EWM sum of x² (numerator of the second moment)
    * ``s_w`` — EWM sum of weights (denominator)
    * ``s_w2`` — EWM sum of squared weights (for bias correction)

    ``beta_vola = (cfg.vola - 1) / cfg.vola`` (from ``com = cfg.vola - 1``)

    Attributes:
        corr_zi_x: IIR filter state for the x-accumulator of the EWM
            correlation; shape ``(1, N, N)``.
        corr_zi_x2: IIR filter state for the x²-accumulator; shape
            ``(1, N, N)``.
        corr_zi_xy: IIR filter state for the xy-accumulator; shape
            ``(1, N, N)``.
        corr_zi_w: IIR filter state for the weight-accumulator; shape
            ``(1, N, N)``.
        corr_count: Cumulative joint-finite observation count per asset pair;
            shape ``(N, N)``, dtype int.
        vola_s_x: EWM sum of volatility-adjusted log-returns; shape ``(N,)``.
        vola_s_x2: EWM sum of squared vol-adj log-returns; shape ``(N,)``.
        vola_s_w: EWM weight sum for vol accumulators; shape ``(N,)``.
        vola_s_w2: EWM squared-weight sum for vol accumulators; shape ``(N,)``.
        vola_count: Cumulative finite observation count for vol; shape
            ``(N,)``, dtype int.
        pct_s_x: EWM sum of pct-returns; shape ``(N,)``.
        pct_s_x2: EWM sum of squared pct-returns; shape ``(N,)``.
        pct_s_w: EWM weight sum for pct accumulators; shape ``(N,)``.
        pct_s_w2: EWM squared-weight sum for pct accumulators; shape ``(N,)``.
        pct_count: Cumulative finite observation count for pct-return vol;
            shape ``(N,)``, dtype int.
        prev_price: Last price row seen, used to compute returns on the next
            step; shape ``(N,)``.
        prev_cash_pos: Last cash position, used to apply the turnover constraint
            on the next step; shape ``(N,)``.
        step_count: Number of rows processed so far. Initialised to the warmup
            length by :meth:`BasanosStream.from_warmup` and incremented once
            per :meth:`BasanosStream.step` call.
    """

    # ── IIR filter state for EWM correlation — shape (1, N, N) each ──────────
    corr_zi_x: np.ndarray  # (1, N, N)
    corr_zi_x2: np.ndarray  # (1, N, N)
    corr_zi_xy: np.ndarray  # (1, N, N)
    corr_zi_w: np.ndarray  # (1, N, N)
    corr_count: np.ndarray  # (N, N) int — cumulative joint-finite observation count

    # ── EWMA accumulators for vol_adj (log-return std; com=vola-1, min_samples=1) ──
    vola_s_x: np.ndarray  # (N,)
    vola_s_x2: np.ndarray  # (N,)
    vola_s_w: np.ndarray  # (N,)
    vola_s_w2: np.ndarray  # (N,)
    vola_count: np.ndarray  # (N,) int

    # ── EWMA accumulators for vola (pct-return std; com=vola-1, min_samples=vola) ──
    pct_s_x: np.ndarray  # (N,)
    pct_s_x2: np.ndarray  # (N,)
    pct_s_w: np.ndarray  # (N,)
    pct_s_w2: np.ndarray  # (N,)
    pct_count: np.ndarray  # (N,) int

    # ── Scalars ───────────────────────────────────────────────────────────────
    prev_price: np.ndarray  # (N,) last price row (to compute returns at next step)
    prev_cash_pos: np.ndarray  # (N,) last cash position (for turnover constraint at next step)
    step_count: int

    # ── SlidingWindowConfig state — None for EwmaShrinkConfig ────────────────
    # shape (W, N): last W vol-adjusted returns (oldest row first); None when
    # using EwmaShrinkConfig. The corr_zi_* fields above are unused (zeros)
    # in this mode; sw_ret_buf carries all the correlation state instead.
    sw_ret_buf: np.ndarray | None = None  # (W, N) rolling buffer, or None
#: Keys that :meth:`BasanosStream.save` writes to the ``.npz`` archive for
#: :class:`_StreamState` fields. Derived automatically from
#: :func:`dataclasses.fields` so that adding a new field to ``_StreamState``
#: is sufficient — no manual update here is required.
#:
#: The three non-state keys (``format_version``, ``cfg_json``, ``assets``) are
#: added explicitly because they are not fields of ``_StreamState`` itself.
_REQUIRED_KEYS: frozenset[str] = frozenset(
    {f.name for f in dataclasses.fields(_StreamState)} | {"format_version", "cfg_json", "assets"}
)
@dataclasses.dataclass(frozen=True)
class StepResult:
    """Frozen dataclass representing the output of a single ``BasanosStream`` step.

    Each call to ``BasanosStream.step()`` returns one ``StepResult`` capturing
    the optimised cash positions, the per-asset volatility estimate, the step
    date, and a status label that describes the solver outcome for that
    timestep.

    Attributes:
        date: The timestamp or date label for this step. The type mirrors
            whatever is stored in the ``'date'`` column of the input prices
            DataFrame (typically a Python :class:`datetime.date`,
            :class:`datetime.datetime`, or a Polars temporal scalar).
        cash_position: Optimised cash-position vector, shape ``(N,)``.
            Entries are ``NaN`` for assets that are still in the EWMA warmup
            period or that are otherwise inactive at this step.
        status: Solver outcome label for this timestep
            (:class:`~basanos.math.SolveStatus`). Since :class:`SolveStatus`
            is a ``StrEnum``, values compare equal to their string equivalents
            (e.g. ``result.status == "valid"`` is ``True``):

            * ``'warmup'`` — fewer rows have been seen than the EWMA warmup
              requires; all positions are ``NaN``.
            * ``'zero_signal'`` — the expected-return signal vector ``mu`` is
              identically zero; positions are set to zero rather than solved.
            * ``'degenerate'`` — the covariance matrix is ill-conditioned or
              numerically singular; positions cannot be computed reliably and
              are returned as ``NaN``.
            * ``'valid'`` — normal operation; ``cash_position`` holds the
              optimised allocations.
        vola: Per-asset EWMA percentage-return volatility, shape ``(N,)``.
            Values are ``NaN`` during the warmup period before the EWMA has
            accumulated sufficient history.

    Examples:
        >>> import numpy as np
        >>> result = StepResult(
        ...     date="2024-01-02",
        ...     cash_position=np.array([1000.0, -500.0]),
        ...     status="valid",
        ...     vola=np.array([0.012, 0.018]),
        ... )
        >>> result.status
        'valid'
        >>> result.cash_position.shape
        (2,)
    """

    date: object
    cash_position: np.ndarray
    status: SolveStatus
    vola: np.ndarray
# ---------------------------------------------------------------------------
# Helper: unbiased EWMA std from running accumulators
# ---------------------------------------------------------------------------


def _ewm_std_from_state(
    s_x: np.ndarray,
    s_x2: np.ndarray,
    s_w: np.ndarray,
    s_w2: np.ndarray,
    count: np.ndarray,
    min_samples: int,
) -> np.ndarray:
    r"""Compute the unbiased EWMA standard deviation from running accumulators.

    Implements the same Bessel-corrected formula used by
    ``polars.Expr.ewm_std(adjust=True)``::

        var_biased = s_x2/s_w - (s_x/s_w)^2
        correction = s_w^2 / (s_w^2 - s_w2)   # Bessel correction
        var_unbiased = var_biased * correction
        std = sqrt(max(0, var_unbiased))

    where ``s_w2 = sum(wi^2)`` is the sum of squared EWM weights.

    Parameters
    ----------
    s_x, s_x2, s_w, s_w2:
        Running accumulators, each of shape ``(N,)``.
    count:
        Integer count of finite observations per asset, shape ``(N,)``.
    min_samples:
        Minimum number of finite observations required before returning a
        non-NaN value.

    Returns
    -------
    np.ndarray of shape ``(N,)`` with per-asset standard deviations.
    NaN is returned for assets where ``count < min_samples``.
    """
    n = len(s_x)
    result = np.full(n, np.nan, dtype=float)
    ok = count >= min_samples
    if not ok.any():
        return result

    with np.errstate(divide="ignore", invalid="ignore"):
        mean = np.where(s_w > 0, s_x / s_w, 0.0)
        mean_sq = np.where(s_w > 0, s_x2 / s_w, 0.0)
        var_biased = np.maximum(mean_sq - mean**2, 0.0)
        denom_corr = s_w**2 - s_w2
        # denom_corr > 0 iff count >= 2; equals 0 when count == 1
        var_unbiased = np.where(denom_corr > 0, var_biased * s_w**2 / denom_corr, 0.0)
        std = np.sqrt(var_unbiased)

    return np.where(ok, std, np.nan)
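The accumulator formula can be cross-checked against the direct weighted-sum
definition (``adjust=True`` weights ``w_i = beta**(T-1-i)``, newest sample
last); a standalone numpy sketch with arbitrary ``beta`` and series length:

```python
import numpy as np

beta = 0.8
rng = np.random.default_rng(0)
x = rng.normal(size=50)

# Recurrence form: feed one sample at a time into the four accumulators.
s_x = s_x2 = s_w = s_w2 = 0.0
for v in x:
    s_x = beta * s_x + v
    s_x2 = beta * s_x2 + v * v
    s_w = beta * s_w + 1.0
    s_w2 = beta**2 * s_w2 + 1.0

# Direct form: explicit weights w_i = beta**(T-1-i).
w = beta ** np.arange(len(x) - 1, -1, -1)
var_biased = np.sum(w * x**2) / w.sum() - (np.sum(w * x) / w.sum()) ** 2
std_direct = np.sqrt(var_biased * w.sum() ** 2 / (w.sum() ** 2 - np.sum(w**2)))

# Same quantities from the accumulators, via the docstring's formula.
var_b = s_x2 / s_w - (s_x / s_w) ** 2
std_rec = np.sqrt(max(0.0, var_b) * s_w**2 / (s_w**2 - s_w2))

assert np.isclose(std_rec, std_direct)
```

The two paths differ only in summation order, so they agree to floating-point
tolerance rather than exactly.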
# ---------------------------------------------------------------------------
# Helper: batch EWMA volatility accumulators from a returns matrix
# ---------------------------------------------------------------------------


def _ewm_vol_accumulators_from_batch(
    returns: np.ndarray,
    beta: float,
    beta_sq: float,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    r"""Compute final EWMA volatility accumulators from a batch of returns.

    Implements the same IIR recurrence as :meth:`BasanosStream.step` but
    vectorised over *T* timesteps using ``scipy.signal.lfilter``. The five
    returned arrays are identical to the accumulators that would result from
    feeding each row of *returns* through the scalar step-by-step recurrence::

        s_x[t]  = beta * s_x[t-1]    + (x[t]   if finite else 0)
        s_x2[t] = beta * s_x2[t-1]   + (x[t]^2 if finite else 0)
        s_w[t]  = beta * s_w[t-1]    + (1      if finite else 0)
        s_w2[t] = beta^2 * s_w2[t-1] + (1      if finite else 0)

    Parameters
    ----------
    returns:
        Float array of shape ``(T, N)``. NaN entries are treated as missing
        observations — they contribute nothing to the numerator sums and do
        not increment the weight accumulators.
    beta:
        EWM decay factor for ``s_x``, ``s_x2``, and ``s_w``
        (``beta = com / (1 + com)`` for ``com = cfg.vola - 1``).
    beta_sq:
        Squared decay factor used for ``s_w2``. Must equal ``beta ** 2``.

    Returns
    -------
    s_x, s_x2, s_w, s_w2 : np.ndarray of shape ``(N,)``
        Final EWMA running accumulators after processing all *T* rows.
    count : np.ndarray of shape ``(N,)``, dtype int
        Number of finite observations per asset.

    Notes
    -----
    This function is the shared implementation used by
    :meth:`BasanosStream.from_warmup` for both the log-return (``vola_*``)
    and pct-return (``pct_*``) accumulators. Keeping a single implementation
    here guarantees that the batch and incremental paths stay in sync when the
    recurrence definition changes.
    """
    fin_mask = np.isfinite(returns)  # (T, N) bool
    fin = fin_mask.astype(np.float64)  # (T, N) 0/1 weights
    x_z = np.where(fin_mask, returns, 0.0)  # (T, N) NaN -> 0
    filt_a = np.array([1.0, -beta])
    filt_a2 = np.array([1.0, -beta_sq])

    s_x: np.ndarray = lfilter([1.0], filt_a, x_z, axis=0)[-1]
    s_x2: np.ndarray = lfilter([1.0], filt_a, x_z**2, axis=0)[-1]
    s_w: np.ndarray = lfilter([1.0], filt_a, fin, axis=0)[-1]
    s_w2: np.ndarray = lfilter([1.0], filt_a2, fin, axis=0)[-1]
    count: np.ndarray = fin_mask.sum(axis=0).astype(int)

    return s_x, s_x2, s_w, s_w2, count
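The batch/scalar equivalence (including the NaN-gap handling) can be verified
directly with a toy series — a standalone sketch using the same ``lfilter``
call shape as the helper above, with an arbitrary ``beta``:

```python
import numpy as np
from scipy.signal import lfilter

beta = 0.9
returns = np.array([[0.01], [np.nan], [0.02], [-0.01]])  # (T=4, N=1), one gap

fin = np.isfinite(returns)
x_z = np.where(fin, returns, 0.0)

# Batch: the last row of the filtered sequence is the final accumulator.
s_x_batch = lfilter([1.0], [1.0, -beta], x_z, axis=0)[-1]
s_w_batch = lfilter([1.0], [1.0, -beta], fin.astype(float), axis=0)[-1]

# Scalar recurrence for comparison: NaN rows decay the state but add nothing.
s_x = s_w = 0.0
for t in range(returns.shape[0]):
    v = returns[t, 0]
    ok = np.isfinite(v)
    s_x = beta * s_x + (v if ok else 0.0)
    s_w = beta * s_w + (1.0 if ok else 0.0)

assert np.allclose(s_x_batch[0], s_x)
assert np.allclose(s_w_batch[0], s_w)
```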
# ---------------------------------------------------------------------------
# BasanosStream
# ---------------------------------------------------------------------------


class BasanosStream:
    """Incremental (streaming) optimiser backed by a single :class:`_StreamState`.

    After warming up on a historical batch via :meth:`from_warmup`, each call
    to :meth:`step` advances the internal state by exactly one row in
    O(N²) time — without revisiting the full warmup history.

    Attributes:
        assets: Ordered list of asset column names (read-only).

    Examples:
        >>> import numpy as np
        >>> import polars as pl
        >>> from datetime import date, timedelta
        >>> from basanos.math import BasanosConfig, BasanosStream, StepResult
        >>> rng = np.random.default_rng(0)
        >>> warmup_len = 60
        >>> dates = pl.date_range(
        ...     start=date(2024, 1, 1),
        ...     end=date(2024, 1, 1) + timedelta(days=warmup_len),
        ...     interval="1d",
        ...     eager=True,
        ... )
        >>> prices = pl.DataFrame({
        ...     "date": dates,
        ...     "A": np.cumprod(1 + rng.normal(0.001, 0.02, warmup_len + 1)) * 100.0,
        ...     "B": np.cumprod(1 + rng.normal(0.001, 0.02, warmup_len + 1)) * 150.0,
        ... })
        >>> mu = pl.DataFrame({
        ...     "date": dates,
        ...     "A": rng.normal(0, 0.5, warmup_len + 1),
        ...     "B": rng.normal(0, 0.5, warmup_len + 1),
        ... })
        >>> cfg = BasanosConfig(vola=5, corr=10, clip=3.0, shrink=0.5, aum=1e6)
        >>> stream = BasanosStream.from_warmup(prices.head(warmup_len), mu.head(warmup_len), cfg)
        >>> result = stream.step(
        ...     prices.select(["A", "B"]).to_numpy()[warmup_len],
        ...     mu.select(["A", "B"]).to_numpy()[warmup_len],
        ...     prices["date"][warmup_len],
        ... )
        >>> isinstance(result, StepResult)
        True
        >>> result.cash_position.shape
        (2,)
    """

    _cfg: BasanosConfig
    _assets: list[str]
    _state: _StreamState

    def __init__(self, cfg: BasanosConfig, assets: list[str], state: _StreamState) -> None:
        """Initialise from an explicit config, asset list, and state container."""
        object.__setattr__(self, "_cfg", cfg)
        object.__setattr__(self, "_assets", assets)
        object.__setattr__(self, "_state", state)

    def __setattr__(self, name: str, value: object) -> None:
        """Prevent accidental attribute mutation — BasanosStream is immutable."""
        raise dataclasses.FrozenInstanceError(f"{type(self).__name__}.{name}")

    @property
    def assets(self) -> list[str]:
        """Ordered list of asset column names."""
        return self._assets

    # ------------------------------------------------------------------
    # from_warmup
    # ------------------------------------------------------------------

    @classmethod
    def from_warmup(
        cls,
        prices: pl.DataFrame,
        mu: pl.DataFrame,
        cfg: BasanosConfig,
    ) -> BasanosStream:
        """Build a :class:`BasanosStream` from a historical warmup batch.

        Runs :class:`~basanos.math.BasanosEngine` on the full warmup batch
        exactly once and extracts the minimal IIR-filter state required for
        subsequent :meth:`step` calls. After this call, each :meth:`step`
        advances the optimiser in O(N²) time without touching the warmup
        data again.

        Parameters
        ----------
        prices:
            Historical price DataFrame. Must contain a ``'date'`` column and
            at least one numeric asset column with strictly positive,
            non-monotonic values.
        mu:
            Expected-return signal DataFrame aligned row-by-row with
            ``prices``.
        cfg:
            Engine configuration. Both :class:`~basanos.math.EwmaShrinkConfig`
            and :class:`~basanos.math.SlidingWindowConfig` are supported.

        Returns
        -------
        BasanosStream
            A stream instance whose :meth:`step` method is ready to accept the
            row immediately following the last warmup row.

        Notes
        -----
        **Short-warmup behaviour with** ``SlidingWindowConfig``: when
        ``len(prices) < cfg.covariance_config.window``, the internal rolling
        buffer (``sw_ret_buf``) is NaN-padded for the missing prefix rows.
        :meth:`step` returns ``StepResult(status="warmup")`` for each of the
        first ``window - len(prices)`` calls, exactly matching the EWM warmup
        semantics. By the time :meth:`step` returns the first non-warmup
        result the buffer contains only real data — no NaN-padded rows remain.

        Raises
        ------
        MissingDateColumnError
            If ``'date'`` is absent from ``prices``.
        """
        # 1. Validate -------------------------------------------------------
        if "date" not in prices.columns:
            raise MissingDateColumnError("prices")

        # 2. Build the engine on the full warmup batch ----------------------
        # Import here to avoid a circular dependency at module level.
        from .optimizer import BasanosEngine

        engine = BasanosEngine(prices=prices, mu=mu, cfg=cfg)
        assets = engine.assets
        n_assets = len(assets)
        n_rows = prices.height
        prices_np = prices.select(assets).to_numpy()  # (n_rows, n_assets)

        # 3. Extract mode-specific state from WarmupState --------------------
        ws = engine.warmup_state()
        if isinstance(cfg.covariance_config, EwmaShrinkConfig):
            # EWM: seed the per-step lfilter from the IIR state captured
            # during the single batch pass in warmup_state().
            iir = cast("_EwmCorrState", ws.corr_iir_state)
            corr_zi_x = iir.corr_zi_x
            corr_zi_x2 = iir.corr_zi_x2
            corr_zi_xy = iir.corr_zi_xy
            corr_zi_w = iir.corr_zi_w
            corr_count: np.ndarray = iir.count
            sw_ret_buf: np.ndarray | None = None
        else:
            # SW: carry the last W vol-adjusted returns as a rolling buffer.
            # The IIR fields are initialised to zeros and left unused.
            sw_config = cast(SlidingWindowConfig, cfg.covariance_config)
            win_w = sw_config.window
            ret_adj_np = engine.ret_adj.select(assets).to_numpy()  # (n_rows, N)
            if n_rows >= win_w:
                sw_ret_buf = ret_adj_np[-win_w:].copy()
            else:
                sw_ret_buf = np.full((win_w, n_assets), np.nan)
                sw_ret_buf[-n_rows:] = ret_adj_np
            corr_zi_x = np.zeros((1, n_assets, n_assets))
            corr_zi_x2 = np.zeros((1, n_assets, n_assets))
            corr_zi_xy = np.zeros((1, n_assets, n_assets))
            corr_zi_w = np.zeros((1, n_assets, n_assets))
            corr_count = np.zeros((n_assets, n_assets), dtype=np.int64)

        # 4. Derive EWMA volatility accumulators (vectorised) ---------------
        # Both log-return (for vol_adj) and pct-return (for vola) use the
        # same beta = (vola-1)/vola. NaN observations (leading NaN at row 0
        # from diff/pct_change) are skipped — the filter input is 0 for NaN
        # rows and the weight accumulator (s_w) only increments for finite
        # observations, matching Polars' effective behaviour for a
        # leading-NaN series.
        #
        # Delegate to the shared helper _ewm_vol_accumulators_from_batch so
        # that the batch and incremental recurrences share a single definition.
        beta_vola: float = (cfg.vola - 1) / cfg.vola
        beta_vola_sq: float = beta_vola**2

        log_ret = np.full((n_rows, n_assets), np.nan, dtype=float)
        pct_ret = np.full((n_rows, n_assets), np.nan, dtype=float)
        if n_rows > 1:
            with np.errstate(divide="ignore", invalid="ignore"):
                log_ret[1:] = np.log(prices_np[1:] / prices_np[:-1])
                pct_ret[1:] = prices_np[1:] / prices_np[:-1] - 1.0

        vola_s_x, vola_s_x2, vola_s_w, vola_s_w2, vola_count = _ewm_vol_accumulators_from_batch(
            log_ret, beta_vola, beta_vola_sq
        )
        pct_s_x, pct_s_x2, pct_s_w, pct_s_w2, pct_count = _ewm_vol_accumulators_from_batch(
            pct_ret, beta_vola, beta_vola_sq
        )

        # 5. Extract prev_cash_pos from WarmupState --------------------------
        prev_cash_pos: np.ndarray = ws.prev_cash_pos
        prev_price: np.ndarray = prices_np[-1].copy()

        # 6. Construct _StreamState and return ------------------------------
        state = _StreamState(
            corr_zi_x=corr_zi_x,
            corr_zi_x2=corr_zi_x2,
            corr_zi_xy=corr_zi_xy,
            corr_zi_w=corr_zi_w,
            corr_count=corr_count,
            vola_s_x=vola_s_x,
            vola_s_x2=vola_s_x2,
            vola_s_w=vola_s_w,
            vola_s_w2=vola_s_w2,
            vola_count=vola_count,
            pct_s_x=pct_s_x,
            pct_s_x2=pct_s_x2,
            pct_s_w=pct_s_w,
            pct_s_w2=pct_s_w2,
            pct_count=pct_count,
            prev_price=prev_price,
            prev_cash_pos=prev_cash_pos,
            step_count=n_rows,
            sw_ret_buf=sw_ret_buf,
        )
        return cls(cfg=cfg, assets=assets, state=state)
584 # ------------------------------------------------------------------
585 # step
586 # ------------------------------------------------------------------
588 def step(
589 self,
590 new_prices: np.ndarray | dict[str, float],
591 new_mu: np.ndarray | dict[str, float],
592 date: Any = None,
593 ) -> StepResult:
594 """Advance the stream by one row and return the new optimised position.
596 Parameters
597 ----------
598 new_prices:
599 Per-asset prices for the new timestep. Either a numpy array of
600 shape ``(N,)`` (assets ordered as in :attr:`assets`) or a dict
601 mapping asset names to price values.
602 new_mu:
603 Per-asset expected-return signals, same format as ``new_prices``.
604 date:
605 Timestamp for this step (stored in :attr:`StepResult.date`
606 verbatim; not used in any computation).
608 Returns:
609 -------
610 StepResult
611 Frozen dataclass with ``cash_position``, ``vola``, ``status``, and
612 ``date`` for this timestep.
613 """
614 from ..exceptions import SingularMatrixError
616 cfg = self._cfg
617 assets = self._assets
618 state = self._state
619 n_assets = len(assets)
621 # ── Check if still in the warmup period ──────────────────────────────
622 # step_count is initialised to n_rows in from_warmup.
623 #
624 # EwmaShrinkConfig: in_warmup is True for the first (cfg.corr - n_rows)
625 # calls when the warmup batch was shorter than cfg.corr (not enough rows
626 # to populate the EWM correlation matrix).
627 #
628 # SlidingWindowConfig: in_warmup is True for the first (window - n_rows)
629 # calls when the warmup batch was shorter than the window. During this
630 # period sw_ret_buf still contains NaN-padded prefix rows; each step
631 # shifts one NaN out and appends a real row, so the buffer is fully
632 # populated with real data exactly when in_warmup becomes False.
633 #
634 # In both modes all accumulators are still updated during warmup so that
635 # the state is ready the moment the warmup period ends.
636 _warmup_thresh = (
637 cfg.covariance_config.window if isinstance(cfg.covariance_config, SlidingWindowConfig) else cfg.corr
638 )
639 in_warmup: bool = state.step_count < _warmup_thresh
641 # ── Resolve inputs to (N,) float64 arrays ──────────────────────────
642 if isinstance(new_prices, dict):
643 new_p = np.array([float(new_prices[a]) for a in assets], dtype=float)
644 else:
645 new_p = np.asarray(new_prices, dtype=float).ravel()
647 if isinstance(new_mu, dict):
648 new_m = np.array([float(new_mu[a]) for a in assets], dtype=float)
649 else:
650 new_m = np.asarray(new_mu, dtype=float).ravel()
652 if new_p.shape != (n_assets,):
653 raise ValueError(f"new_prices must have shape ({n_assets},); got {new_p.shape}") # noqa: TRY003
654 if new_m.shape != (n_assets,):
655 raise ValueError(f"new_mu must have shape ({n_assets},); got {new_m.shape}") # noqa: TRY003
657 prev_p = state.prev_price
658 beta_vola: float = (cfg.vola - 1) / cfg.vola
659 beta_vola_sq: float = beta_vola**2
660 beta_corr: float = cfg.corr / (1.0 + cfg.corr)
662 # ── Compute new log-returns and pct-returns ─────────────────────────
663 with np.errstate(divide="ignore", invalid="ignore"):
664 ratio = np.where(
665 np.isfinite(new_p) & np.isfinite(prev_p) & (prev_p > 0),
666 new_p / prev_p,
667 np.nan,
668 )
669 log_ret = np.log(ratio)
670 pct_ret = ratio - 1.0
672 # ── Update log-return EWMA accumulators ────────────────────────────
673 fin_log = np.isfinite(log_ret)
674 vola_s_x = beta_vola * state.vola_s_x + np.where(fin_log, log_ret, 0.0)
675 vola_s_x2 = beta_vola * state.vola_s_x2 + np.where(fin_log, log_ret**2, 0.0)
676 vola_s_w = beta_vola * state.vola_s_w + fin_log.astype(float)
677 vola_s_w2 = beta_vola_sq * state.vola_s_w2 + fin_log.astype(float)
678 vola_count = state.vola_count + fin_log.astype(int)
680 # ── Update pct-return EWMA accumulators ────────────────────────────
681 fin_pct = np.isfinite(pct_ret)
682 pct_s_x = beta_vola * state.pct_s_x + np.where(fin_pct, pct_ret, 0.0)
683 pct_s_x2 = beta_vola * state.pct_s_x2 + np.where(fin_pct, pct_ret**2, 0.0)
684 pct_s_w = beta_vola * state.pct_s_w + fin_pct.astype(float)
685 pct_s_w2 = beta_vola_sq * state.pct_s_w2 + fin_pct.astype(float)
686 pct_count = state.pct_count + fin_pct.astype(int)
688 # ── Compute vol-adjusted return (for the correlation IIR input) ─────
689 log_vol = _ewm_std_from_state(vola_s_x, vola_s_x2, vola_s_w, vola_s_w2, vola_count, min_samples=1)
690 # Divide; std == 0 yields ±inf → clipped to ±cfg.clip (matches Polars)
691 with np.errstate(divide="ignore", invalid="ignore"):
692 vol_adj_val = np.where(
693 fin_log,
694 np.clip(log_ret / log_vol, -cfg.clip, cfg.clip),
695 np.nan,
696 )
698 # ── Mode-specific correlation state update ───────────────────────────
699 if isinstance(cfg.covariance_config, SlidingWindowConfig):
700 # SW: shift the rolling window buffer in-place and append this row.
701 # The corr_zi_* fields are unused; alias them to their old values so
702 # the early-return and persist blocks below can reference them safely.
703 buf = state.sw_ret_buf # (W, N), already owned by state
704 buf[:-1] = buf[1:] # type: ignore[index]
705 buf[-1] = vol_adj_val # type: ignore[index]
706 corr_zi_x = state.corr_zi_x
707 corr_zi_x2 = state.corr_zi_x2
708 corr_zi_xy = state.corr_zi_xy
709 corr_zi_w = state.corr_zi_w
710 corr_count = state.corr_count
711 else:
712 # EWM: Update IIR filter state for EWM correlation
713 fin_va = np.isfinite(vol_adj_val)
714 va_f = np.where(fin_va, vol_adj_val, 0.0)
715 joint_fin = fin_va[:, np.newaxis] & fin_va[np.newaxis, :] # (N, N)
717 new_v_x = (va_f[:, np.newaxis] * joint_fin)[np.newaxis] # (1, N, N)
718 new_v_x2 = ((va_f**2)[:, np.newaxis] * joint_fin)[np.newaxis] # (1, N, N)
719 new_v_xy = (va_f[:, np.newaxis] * va_f[np.newaxis, :])[np.newaxis] # (1, N, N)
720 new_v_w = joint_fin.astype(np.float64)[np.newaxis] # (1, N, N)
722 filt_a_corr = np.array([1.0, -beta_corr])
723 # y_x[0] is the current-step EWM state (filter output); corr_zi_x is
724 # the new filter memory (zf = beta * y[0]) passed as zi next step.
725 y_x, corr_zi_x = lfilter([1.0], filt_a_corr, new_v_x, axis=0, zi=state.corr_zi_x)
726 y_x2, corr_zi_x2 = lfilter([1.0], filt_a_corr, new_v_x2, axis=0, zi=state.corr_zi_x2)
727 y_xy, corr_zi_xy = lfilter([1.0], filt_a_corr, new_v_xy, axis=0, zi=state.corr_zi_xy)
728 y_w, corr_zi_w = lfilter([1.0], filt_a_corr, new_v_w, axis=0, zi=state.corr_zi_w)
729 corr_count = state.corr_count + joint_fin.astype(np.int64)
731 # ── Early return during EWM warmup period ───────────────────────────
732 # All accumulators are already updated above; skip the O(N²) matrix
733 # reconstruction and O(N³) Cholesky solve which are wasteful during
734 # warmup — the computed positions would be discarded anyway.
735 if in_warmup:
736 state.corr_zi_x = corr_zi_x
737 state.corr_zi_x2 = corr_zi_x2
738 state.corr_zi_xy = corr_zi_xy
739 state.corr_zi_w = corr_zi_w
740 state.corr_count = corr_count
741 state.vola_s_x = vola_s_x
742 state.vola_s_x2 = vola_s_x2
743 state.vola_s_w = vola_s_w
744 state.vola_s_w2 = vola_s_w2
745 state.vola_count = vola_count
746 state.pct_s_x = pct_s_x
747 state.pct_s_x2 = pct_s_x2
748 state.pct_s_w = pct_s_w
749 state.pct_s_w2 = pct_s_w2
750 state.pct_count = pct_count
751 state.prev_price = new_p.copy()
752 state.step_count += 1
753 return StepResult(
754 date=date,
755 cash_position=np.full(n_assets, np.nan),
756 status=SolveStatus.WARMUP,
757 vola=np.full(n_assets, np.nan),
758 )

        # ── Compute EWMA volatility (pct-return std) — shared ───────────────
        vola_vec = _ewm_std_from_state(pct_s_x, pct_s_x2, pct_s_w, pct_s_w2, pct_count, min_samples=cfg.vola)

        # ── Solve for position ───────────────────────────────────────────────
        new_cash_pos = np.full(n_assets, np.nan, dtype=float)
        status = SolveStatus.DEGENERATE

        mask = np.isfinite(new_p)

        if isinstance(cfg.covariance_config, SlidingWindowConfig):
            # ── SW path: FactorModel solve via Woodbury identity ─────────────
            sw_config = cfg.covariance_config
            if not mask.any():
                status = SolveStatus.DEGENERATE
            else:
                win_w = sw_config.window
                win_k = sw_config.n_factors
                window_ret = np.where(
                    np.isfinite(state.sw_ret_buf[:, mask]),  # type: ignore[index]
                    state.sw_ret_buf[:, mask],  # type: ignore[index]
                    0.0,
                )
                n_sub = int(mask.sum())
                k_eff = min(win_k, win_w, n_sub)
                if sw_config.max_components is not None:
                    k_eff = min(k_eff, sw_config.max_components)
                try:
                    fm = FactorModel.from_returns(window_ret, k=k_eff)
                except (np.linalg.LinAlgError, ValueError) as exc:
                    _logger.debug("Sliding window SVD failed at date=%s: %s", date, exc)
                    new_cash_pos[mask] = 0.0
                    status = SolveStatus.DEGENERATE
                else:
                    expected_mu = np.nan_to_num(new_m[mask])
                    if np.allclose(expected_mu, 0.0):
                        new_cash_pos[mask] = 0.0
                        status = SolveStatus.ZERO_SIGNAL
                    else:
                        try:
                            x = fm.solve(expected_mu)
                            denom_val = float(np.sqrt(max(0.0, float(np.dot(expected_mu, x)))))
                        except (SingularMatrixError, np.linalg.LinAlgError) as exc:
                            _logger.warning("Woodbury solve failed at date=%s: %s", date, exc)
                            new_cash_pos[mask] = 0.0
                            status = SolveStatus.DEGENERATE
                        else:
                            if not np.isfinite(denom_val) or denom_val <= cfg.denom_tol:
                                _logger.warning(
                                    "Positions zeroed at date=%s (sliding_window): normalisation "
                                    "denominator degenerate (denom=%s, denom_tol=%s).",
                                    date,
                                    denom_val,
                                    cfg.denom_tol,
                                )
                                new_cash_pos[mask] = 0.0
                                status = SolveStatus.DEGENERATE
                            else:
                                risk_pos = x / denom_val
                                vola_sub = vola_vec[mask]
                                with np.errstate(invalid="ignore"):
                                    new_cash_pos[mask] = risk_pos / vola_sub
                                status = SolveStatus.VALID
        else:
            # ── EWM path: shared _corr_from_ewm_accumulators + _SolveMixin solve ──
            # Reconstruct the EWM correlation matrix from filter outputs using
            # the shared helper — the same formula as _ewm_corr_with_final_state
            # but operating on (N, N) slices instead of (T, N, N) tensors.
            # Use y_*[0] (the filter OUTPUT for this step), not zf[0].
            corr = _corr_from_ewm_accumulators(
                y_x[0],
                y_x2[0],
                y_xy[0],
                y_w[0],
                corr_count,
                min_periods=cfg.corr,
                min_corr_denom=cfg.min_corr_denom,
            )
            matrix = shrink2id(corr, lamb=cfg.shrink)

            # Delegate the signal check and linear solve to _SolveMixin so that
            # any algorithm change (denominator guard, status labels, etc.) only
            # needs to be applied in _engine_solve.py.
            expected_mu, early = _SolveMixin._row_early_check(state.step_count, date, mask, new_m)
            if early is not None:
                _, _, _, pos, status = early
                new_cash_pos[mask] = pos
            else:
                corr_sub = matrix[np.ix_(mask, mask)]
                _, _, _, pos, status = _SolveMixin._compute_position(
                    state.step_count, date, mask, expected_mu, MatrixBundle(matrix=corr_sub), cfg.denom_tol
                )
                if status == SolveStatus.VALID:
                    new_cash_pos[mask] = _SolveMixin._scale_to_cash(cast(np.ndarray, pos), vola_vec[mask])
                else:
                    new_cash_pos[mask] = pos

        # ── Apply turnover constraint ─────────────────────────────────────────
        if cfg.max_turnover is not None and status == SolveStatus.VALID:
            new_cash_pos[mask] = _SolveMixin._apply_turnover_constraint(
                new_cash_pos[mask],
                state.prev_cash_pos[mask],
                cfg.max_turnover,
            )

        # ── Persist updated state ───────────────────────────────────────────
        state.corr_zi_x = corr_zi_x
        state.corr_zi_x2 = corr_zi_x2
        state.corr_zi_xy = corr_zi_xy
        state.corr_zi_w = corr_zi_w
        state.corr_count = corr_count
        state.vola_s_x = vola_s_x
        state.vola_s_x2 = vola_s_x2
        state.vola_s_w = vola_s_w
        state.vola_s_w2 = vola_s_w2
        state.vola_count = vola_count
        state.pct_s_x = pct_s_x
        state.pct_s_x2 = pct_s_x2
        state.pct_s_w = pct_s_w
        state.pct_s_w2 = pct_s_w2
        state.pct_count = pct_count
        state.prev_price = new_p.copy()
        state.prev_cash_pos = new_cash_pos.copy()
        state.step_count += 1

        return StepResult(
            date=date,
            cash_position=new_cash_pos,
            status=status,
            vola=vola_vec,
        )

    def save(self, path: str | os.PathLike[str]) -> None:
        """Serialise the stream to a ``.npz`` archive at *path*.

        All :class:`_StreamState` arrays, the configuration, and the asset
        list are written in a single :func:`numpy.savez` call. A stream
        restored via :meth:`load` produces bit-for-bit identical
        :meth:`step` output.

        Args:
            path: Destination file path. :func:`numpy.savez` appends
                ``.npz`` automatically when the suffix is absent.

        Examples:
            >>> import tempfile, pathlib, numpy as np
            >>> import polars as pl
            >>> from datetime import date, timedelta
            >>> from basanos.math import BasanosConfig, BasanosStream
            >>> rng = np.random.default_rng(0)
            >>> n = 60
            >>> end = date(2024, 1, 1) + timedelta(days=n - 1)
            >>> dates = pl.date_range(
            ...     date(2024, 1, 1), end, interval="1d", eager=True
            ... )
            >>> prices = pl.DataFrame({
            ...     "date": dates,
            ...     "A": np.cumprod(1 + rng.normal(0.001, 0.02, n)) * 100.0,
            ...     "B": np.cumprod(1 + rng.normal(0.001, 0.02, n)) * 150.0,
            ... })
            >>> mu = pl.DataFrame({
            ...     "date": dates,
            ...     "A": rng.normal(0, 0.5, n),
            ...     "B": rng.normal(0, 0.5, n),
            ... })
            >>> cfg = BasanosConfig(vola=5, corr=10, clip=3.0, shrink=0.5, aum=1e6)
            >>> stream = BasanosStream.from_warmup(prices, mu, cfg)
            >>> with tempfile.TemporaryDirectory() as tmp:
            ...     p = pathlib.Path(tmp) / "stream.npz"
            ...     stream.save(p)
            ...     restored = BasanosStream.load(p)
            ...     restored.assets == stream.assets
            True
        """
        state = self._state
        # Build the per-field dict automatically from _StreamState so that any
        # new field added to the dataclass is included without manual updates.
        state_arrays: dict[str, Any] = {}
        for field in dataclasses.fields(_StreamState):
            value = getattr(state, field.name)
            if field.name == "sw_ret_buf":
                # Sentinel: use an empty (0, 0) array to represent None so the
                # key is always present in the archive and load() can detect it.
                state_arrays[field.name] = value if value is not None else np.empty((0, 0), dtype=float)
            elif field.name == "step_count":
                state_arrays[field.name] = np.array(value)
            else:
                state_arrays[field.name] = value
        np.savez(
            path,
            format_version=np.array(_SAVE_FORMAT_VERSION),
            cfg_json=np.array(self._cfg.model_dump_json()),
            assets=np.array(self._assets),
            **state_arrays,
        )

    @classmethod
    def load(cls, path: str | os.PathLike[str]) -> BasanosStream:
        """Restore a stream previously saved with :meth:`save`.

        Args:
            path: Path to a ``.npz`` archive written by :meth:`save`.

        Returns:
            A :class:`BasanosStream` whose :meth:`step` output is
            bit-for-bit identical to the original stream at the time
            :meth:`save` was called.

        Examples:
            >>> import tempfile, pathlib, numpy as np
            >>> import polars as pl
            >>> from datetime import date, timedelta
            >>> from basanos.math import BasanosConfig, BasanosStream
            >>> rng = np.random.default_rng(1)
            >>> n = 60
            >>> end = date(2024, 1, 1) + timedelta(days=n - 1)
            >>> dates = pl.date_range(
            ...     date(2024, 1, 1), end, interval="1d", eager=True
            ... )
            >>> prices = pl.DataFrame({
            ...     "date": dates,
            ...     "A": np.cumprod(1 + rng.normal(0.001, 0.02, n)) * 100.0,
            ...     "B": np.cumprod(1 + rng.normal(0.001, 0.02, n)) * 150.0,
            ... })
            >>> mu = pl.DataFrame({
            ...     "date": dates,
            ...     "A": rng.normal(0, 0.5, n),
            ...     "B": rng.normal(0, 0.5, n),
            ... })
            >>> cfg = BasanosConfig(vola=5, corr=10, clip=3.0, shrink=0.5, aum=1e6)
            >>> stream = BasanosStream.from_warmup(prices, mu, cfg)
            >>> with tempfile.TemporaryDirectory() as tmp:
            ...     p = pathlib.Path(tmp) / "stream.npz"
            ...     stream.save(p)
            ...     restored = BasanosStream.load(p)
            ...     restored.assets == stream.assets
            True
        """
        data = np.load(path, allow_pickle=False)
        if "format_version" not in data:
            raise ValueError(  # noqa: TRY003
                "Stream file is missing a format version tag. "
                "It was written with an incompatible version of BasanosStream. "
                "Re-generate it via BasanosStream.from_warmup()."
            )
        found = int(data["format_version"])
        if found != _SAVE_FORMAT_VERSION:
            raise ValueError(  # noqa: TRY003
                f"Stream file was written with format version {found}, "
                f"but the current version is {_SAVE_FORMAT_VERSION}. "
                "Re-generate it via BasanosStream.from_warmup()."
            )
        # Validate that every required key is present. This catches archives
        # that were produced by an older codebase missing a newly added field,
        # or archives that have been manually edited, with a descriptive error
        # instead of a bare KeyError.
        archive_keys = frozenset(data.files)
        missing = _REQUIRED_KEYS - archive_keys
        if missing:
            raise StreamStateCorruptError(missing)
        cfg = BasanosConfig.model_validate_json(data["cfg_json"].item())
        assets: list[str] = list(data["assets"])
        state_kwargs: dict[str, Any] = {}
        for field in dataclasses.fields(_StreamState):
            raw = data[field.name]
            if field.name == "sw_ret_buf":
                state_kwargs[field.name] = raw if raw.size > 0 else None
            elif field.name == "step_count":
                state_kwargs[field.name] = int(raw)
            else:
                state_kwargs[field.name] = raw
        state = _StreamState(**state_kwargs)
        return cls(cfg=cfg, assets=assets, state=state)
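
The module docstring's claim that threading `zf` back in as `zi` is bit-for-bit identical to a single batch `lfilter` call can be verified in isolation. A minimal standalone sketch (illustrative coefficients and data, not part of this module's API):

```python
import numpy as np
from scipy.signal import lfilter

# First-order EWM recurrence s[t] = beta*s[t-1] + v[t] as an IIR filter:
# a = [1, -beta] puts the feedback term on the output side.
beta = 0.94
b, a = [1.0], [1.0, -beta]

x = np.random.default_rng(0).normal(size=10)

# Batch: one call over the full history, starting from zero state.
y_batch, _ = lfilter(b, a, x, zi=np.zeros(1))

# Incremental: two calls, passing the final state zf back in as zi.
y1, zf = lfilter(b, a, x[:6], zi=np.zeros(1))
y2, _ = lfilter(b, a, x[6:], zi=zf)

# The per-sample operations are identical regardless of the split,
# so the outputs match exactly, not merely approximately.
assert np.array_equal(np.concatenate([y1, y2]), y_batch)
```

The same argument applies per accumulator to the `corr_zi_*`, `vola_s_*`, and `pct_s_*` state threaded through `step()` above, which is why the incremental and batch paths agree.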