Coverage for src/basanos/math/_stream.py: 100%

317 statements  

coverage.py v7.13.5, created at 2026-04-02 17:47 +0000

"""Incremental (streaming) API for BasanosEngine.

This private module defines three classes (two public, plus the private
:class:`_StreamState`):

* :class:`_StreamState` — mutable dataclass that persists all O(N²) IIR
  filter and EWMA accumulator state between consecutive
  :meth:`BasanosStream.step` calls. Kept separate from the engine so the
  state layout can be read and tested in isolation.
* :class:`StepResult` — frozen dataclass returned by each
  :meth:`BasanosStream.step` call.
* :class:`BasanosStream` — incremental façade with a
  :meth:`~BasanosStream.from_warmup` classmethod and a
  :meth:`~BasanosStream.step` method.

IIR state model
---------------
The EWM recurrence ``s[t] = beta·s[t-1] + v[t]`` is a causal, single-pole IIR
filter. When the full history is available, ``scipy.signal.lfilter`` solves
all N² pairs in one vectorised call and discards the intermediate array.
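
A minimal check of that equivalence, written as a standalone sketch (the
values of ``beta`` and ``v`` are arbitrary and not taken from any
configuration): the explicit loop and a first-order ``lfilter`` call with
``b = [1]`` and ``a = [1, -beta]`` should agree.

```python
import numpy as np
from scipy.signal import lfilter

beta = 0.8
v = np.array([1.0, 2.0, -0.5, 0.25, 3.0])

# Explicit recurrence s[t] = beta * s[t-1] + v[t], starting from zero memory.
s_loop = np.empty_like(v)
acc = 0.0
for t, v_t in enumerate(v):
    acc = beta * acc + v_t
    s_loop[t] = acc

# The same recurrence as a single-pole IIR filter: b = [1], a = [1, -beta].
s_filt = lfilter([1.0], [1.0, -beta], v)

assert np.allclose(s_loop, s_filt)
```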

In the *incremental* setting there is no history — only the current sample
arrives at each ``step()``. To continue the same recurrence across calls we
need the *filter memory*, i.e. the value of the accumulator at the end of the
previous call.

``scipy.signal.lfilter`` exposes this directly: when called as::

    y, zf = lfilter(b, a, x, zi=zi)

``zi`` is the initial state (shape ``(max(len(a), len(b)) - 1, …)`` = ``(1, …)``
for our first-order filter) and ``zf`` is the *final* state after processing
``x``. Passing the returned ``zf`` back as ``zi`` in the next call is
mathematically equivalent to having run ``lfilter`` over the concatenated
input, so the incremental and batch paths produce bit-for-bit identical
results.
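
The ``zf`` to ``zi`` hand-off can be sketched the same way. This standalone
example (arbitrary ``beta``, random input) feeds one sample per call, as a
streaming step would, and compares the result with a single batch call; it is
an illustration, not code used by this module.

```python
import numpy as np
from scipy.signal import lfilter

beta = 0.9
b, a = [1.0], [1.0, -beta]
x = np.random.default_rng(0).normal(size=12)

# Batch: one call over the full history.
y_batch = lfilter(b, a, x)

# Incremental: one sample per call, carrying zf forward as the next zi.
zi = np.zeros(1)  # shape (max(len(a), len(b)) - 1,) == (1,)
y_steps = []
for x_t in x:
    y_t, zi = lfilter(b, a, [x_t], zi=zi)
    y_steps.append(y_t[0])

assert np.allclose(y_batch, y_steps)
```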

The four correlation accumulators (``corr_zi_x``, ``corr_zi_x2``,
``corr_zi_xy``, ``corr_zi_w``) follow exactly this pattern. Each has shape
``(1, N, N)`` — the leading 1 is the IIR filter order required by
``lfilter``'s ``zi`` argument.
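
For the first-order filter used here, a one-sample ``lfilter`` call on a
``(1, N, N)`` slice along ``axis=0`` reduces to two elementwise operations:
the output is ``v + zi`` and the new memory is ``beta * output``. A sketch with
random stand-in arrays (``n = 3`` and ``beta = 10/11``, i.e. a hypothetical
``cfg.corr = 10``, are illustrative choices):

```python
import numpy as np
from scipy.signal import lfilter

beta = 10.0 / 11.0  # beta_corr for a hypothetical cfg.corr = 10
n = 3
rng = np.random.default_rng(1)
zi = rng.normal(size=(1, n, n))   # filter memory carried over from the previous step
v_t = rng.normal(size=(1, n, n))  # this step's (1, N, N) input slice

# One-sample update along axis 0, as in a single streaming step.
y, zf = lfilter([1.0], [1.0, -beta], v_t, axis=0, zi=zi)

assert y.shape == zf.shape == (1, n, n)
assert np.allclose(y[0], v_t[0] + zi[0])  # current EWM accumulator value
assert np.allclose(zf, beta * y)          # memory handed to the next step
```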

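The volatility accumulators (``vola_*``, ``pct_*``) use a simpler scalar
recurrence and store the running sums directly as ``(N,)`` arrays.

Memory
------
Total incremental state is 4x(1,N,N) + (N,N) + 12x(N,) + O(1) scalars, where
the 12 vectors are the 8 EWMA accumulators, the 2 observation counts,
``prev_price`` and ``prev_cash_pos``. This gives **O(N^2)** memory (plus a
(W,N) rolling buffer under ``SlidingWindowConfig``), independent of the number
of timesteps processed.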
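
To put the bound in bytes, a rough sketch that counts the ``_StreamState``
arrays defined below. ``state_nbytes`` is a hypothetical helper; it assumes
8-byte ``float64``/``int64`` elements and ignores the optional ``(W, N)``
buffer and Python object overhead.

```python
def state_nbytes(n_assets: int, itemsize: int = 8) -> int:
    # Four (1, N, N) IIR memories plus the (N, N) joint-observation counts.
    square_elems = 5 * n_assets * n_assets
    # Eight EWMA accumulators, two counts, prev_price and prev_cash_pos, each (N,).
    vector_elems = 12 * n_assets
    return (square_elems + vector_elems) * itemsize


# Grows with N**2 but not with the number of processed steps.
print(state_nbytes(500))  # 10048000, roughly 10 MB
```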
"""

from __future__ import annotations

import dataclasses
import logging
import os
from typing import TYPE_CHECKING, Any, cast

if TYPE_CHECKING:
    from ._ewm_corr import _EwmCorrState

import numpy as np
import polars as pl
from scipy.signal import lfilter

from ..exceptions import MissingDateColumnError, StreamStateCorruptError
from ._config import BasanosConfig, EwmaShrinkConfig, SlidingWindowConfig
from ._engine_solve import MatrixBundle, SolveStatus, _SolveMixin
from ._ewm_corr import _corr_from_ewm_accumulators
from ._factor_model import FactorModel
from ._signal import shrink2id

_logger = logging.getLogger(__name__)

#: Increment this when the :meth:`BasanosStream.save` archive layout changes in
#: a backward-incompatible way. :meth:`BasanosStream.load` checks that the stored
#: value matches before deserialising anything, so callers get a clear error
#: instead of a silent ``KeyError`` or wrong state.
_SAVE_FORMAT_VERSION: int = 2


@dataclasses.dataclass
class _StreamState:
    """Mutable state carrier for one :class:`BasanosStream` instance.

    All arrays are updated in-place (or replaced) by ``BasanosStream.step()``.
    The class is intentionally *not* frozen so that the step method can modify
    fields directly without creating a new object on every tick.

    IIR filter state (correlation)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    The four ``corr_zi_*`` fields are the *final conditions* (``zf``) returned
    by ``scipy.signal.lfilter`` after the previous step. They are passed back
    as the ``zi`` argument on the next step so that the incremental filter
    produces the same numerical result as a single batch call over all history.
    See the module docstring for the full derivation.

    ``beta_corr = cfg.corr / (1 + cfg.corr)`` (from ``com = cfg.corr``)

    EWM accumulator state (volatility)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    ``vola_*`` and ``pct_*`` accumulate the running weighted sums needed to
    compute exponentially-weighted standard deviations:

    * ``s_x`` — EWM sum of x (numerator of the mean)
    * ``s_x2`` — EWM sum of x² (numerator of the second moment)
    * ``s_w`` — EWM sum of weights (denominator)
    * ``s_w2`` — EWM sum of squared weights (for bias correction)

    ``beta_vola = (cfg.vola - 1) / cfg.vola`` (from ``com = cfg.vola - 1``)
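
    Both closed forms follow from the centre-of-mass convention used by Polars
    (and pandas), ``alpha = 1 / (1 + com)`` with ``beta = 1 - alpha``. A small
    exact-arithmetic sketch, using hypothetical values ``corr = 10`` and
    ``vola = 5``:

    ```python
    from fractions import Fraction


    def beta_from_com(com: int) -> Fraction:
        # Decay factor beta = 1 - alpha, where alpha = 1 / (1 + com).
        return 1 - Fraction(1, 1 + com)


    corr, vola = 10, 5  # hypothetical config values
    assert beta_from_com(corr) == Fraction(corr, 1 + corr)  # beta_corr
    assert beta_from_com(vola - 1) == Fraction(vola - 1, vola)  # beta_vola
    ```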

    Attributes:
        corr_zi_x: IIR filter state for the x-accumulator of the EWM
            correlation; shape ``(1, N, N)``.
        corr_zi_x2: IIR filter state for the x²-accumulator; shape
            ``(1, N, N)``.
        corr_zi_xy: IIR filter state for the xy-accumulator; shape
            ``(1, N, N)``.
        corr_zi_w: IIR filter state for the weight-accumulator; shape
            ``(1, N, N)``.
        corr_count: Cumulative joint-finite observation count per asset pair;
            shape ``(N, N)`` dtype int.
        vola_s_x: EWM sum of raw log-returns (the input to the volatility
            adjustment); shape ``(N,)``.
        vola_s_x2: EWM sum of squared log-returns; shape ``(N,)``.
        vola_s_w: EWM weight sum for vol accumulators; shape ``(N,)``.
        vola_s_w2: EWM squared-weight sum for vol accumulators; shape ``(N,)``.
        vola_count: Cumulative finite observation count for vol; shape ``(N,)``
            dtype int.
        pct_s_x: EWM sum of pct-returns; shape ``(N,)``.
        pct_s_x2: EWM sum of squared pct-returns; shape ``(N,)``.
        pct_s_w: EWM weight sum for pct accumulators; shape ``(N,)``.
        pct_s_w2: EWM squared-weight sum for pct accumulators; shape ``(N,)``.
        pct_count: Cumulative finite observation count for pct-return vol;
            shape ``(N,)`` dtype int.
        prev_price: Last price row seen, used to compute returns on the next
            step; shape ``(N,)``.
        prev_cash_pos: Last cash position, used to apply the turnover constraint
            on the next step; shape ``(N,)``.
        step_count: Number of rows processed so far, including the warmup
            batch (:meth:`BasanosStream.from_warmup` initialises it to the
            number of warmup rows).
        sw_ret_buf: Rolling buffer of the last W vol-adjusted returns, shape
            ``(W, N)`` with the oldest row first; ``None`` when using
            :class:`EwmaShrinkConfig`.
    """

    # ── IIR filter state for EWM correlation — shape (1, N, N) each ──────────
    corr_zi_x: np.ndarray  # (1, N, N)
    corr_zi_x2: np.ndarray  # (1, N, N)
    corr_zi_xy: np.ndarray  # (1, N, N)
    corr_zi_w: np.ndarray  # (1, N, N)
    corr_count: np.ndarray  # (N, N) int — cumulative joint-finite observation count

    # ── EWMA accumulators for vol_adj (log-return std; com=vola-1, min_samples=1) ──
    vola_s_x: np.ndarray  # (N,)
    vola_s_x2: np.ndarray  # (N,)
    vola_s_w: np.ndarray  # (N,)
    vola_s_w2: np.ndarray  # (N,)
    vola_count: np.ndarray  # (N,) int

    # ── EWMA accumulators for vola (pct-return std; com=vola-1, min_samples=vola) ──
    pct_s_x: np.ndarray  # (N,)
    pct_s_x2: np.ndarray  # (N,)
    pct_s_w: np.ndarray  # (N,)
    pct_s_w2: np.ndarray  # (N,)
    pct_count: np.ndarray  # (N,) int

    # ── Previous-row state and step counter ──────────────────────────────────
    prev_price: np.ndarray  # (N,) last price row (to compute returns at next step)
    prev_cash_pos: np.ndarray  # (N,) last cash position (for turnover constraint at next step)
    step_count: int

    # ── SlidingWindowConfig state — None for EwmaShrinkConfig ────────────────
    # shape (W, N): last W vol-adjusted returns (oldest row first); None when
    # using EwmaShrinkConfig. The corr_zi_* fields above are unused (zeros)
    # in this mode; sw_ret_buf carries all the correlation state instead.
    sw_ret_buf: np.ndarray | None = None  # (W, N) rolling buffer, or None


#: Keys that :meth:`BasanosStream.save` writes to the ``.npz`` archive for
#: :class:`_StreamState` fields. Derived automatically from
#: :func:`dataclasses.fields` so that adding a new field to ``_StreamState``
#: is sufficient — no manual update here is required.
#:
#: The three non-state keys (``format_version``, ``cfg_json``, ``assets``) are
#: added explicitly because they are not fields of ``_StreamState`` itself.
_REQUIRED_KEYS: frozenset[str] = frozenset(
    {f.name for f in dataclasses.fields(_StreamState)} | {"format_version", "cfg_json", "assets"}
)


@dataclasses.dataclass(frozen=True)
class StepResult:
    """Frozen dataclass representing the output of a single ``BasanosStream`` step.

    Each call to ``BasanosStream.step()`` returns one ``StepResult`` capturing
    the optimised cash positions, the per-asset volatility estimate, the step
    date, and a status label that describes the solver outcome for that
    timestep.

    Attributes:
        date: The timestamp or date label for this step. The type mirrors
            whatever is stored in the ``'date'`` column of the input prices
            DataFrame (typically a Python :class:`datetime.date`,
            :class:`datetime.datetime`, or a Polars temporal scalar).
        cash_position: Optimised cash-position vector, shape ``(N,)``.
            Entries are ``NaN`` for assets that are still in the EWMA warmup
            period or that are otherwise inactive at this step.
        status: Solver outcome label for this timestep
            (:class:`~basanos.math.SolveStatus`). Since :class:`SolveStatus`
            is a ``StrEnum``, values compare equal to their string equivalents
            (e.g. ``result.status == "valid"`` is ``True``):

            * ``'warmup'`` — fewer rows have been seen than the warmup
              threshold requires (``cfg.corr`` for ``EwmaShrinkConfig``, the
              window length for ``SlidingWindowConfig``); all positions are
              ``NaN``.
            * ``'zero_signal'`` — the expected-return signal vector ``mu`` is
              identically zero; positions are set to zero rather than solved.
            * ``'degenerate'`` — the covariance matrix is ill-conditioned or
              numerically singular; positions cannot be computed reliably and
              are returned as ``NaN`` (or as zero when the sliding-window
              factor decomposition fails).
            * ``'valid'`` — normal operation; ``cash_position`` holds the
              optimised allocations.
        vola: Per-asset EWMA percentage-return volatility, shape ``(N,)``.
            Values are ``NaN`` during the warmup period before the EWMA has
            accumulated sufficient history.
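
    The string comparison works because each ``StrEnum`` member is itself a
    ``str``. A minimal stand-in (``_StatusSketch`` is hypothetical and only
    mirrors the four labels above; it is not :class:`SolveStatus`, and
    ``enum.StrEnum`` requires Python 3.11+):

    ```python
    from enum import StrEnum


    class _StatusSketch(StrEnum):
        WARMUP = "warmup"
        ZERO_SIGNAL = "zero_signal"
        DEGENERATE = "degenerate"
        VALID = "valid"


    status = _StatusSketch.VALID
    assert status == "valid"  # compares equal to the plain string
    assert isinstance(status, str)  # members are str instances
    assert _StatusSketch("warmup") is _StatusSketch.WARMUP  # lookup by value
    ```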

    Examples:
        >>> import numpy as np
        >>> result = StepResult(
        ...     date="2024-01-02",
        ...     cash_position=np.array([1000.0, -500.0]),
        ...     status="valid",
        ...     vola=np.array([0.012, 0.018]),
        ... )
        >>> result.status
        'valid'
        >>> result.cash_position.shape
        (2,)
    """

    date: object
    cash_position: np.ndarray
    status: SolveStatus
    vola: np.ndarray


# ---------------------------------------------------------------------------
# Helper: unbiased EWMA std from running accumulators
# ---------------------------------------------------------------------------


def _ewm_std_from_state(
    s_x: np.ndarray,
    s_x2: np.ndarray,
    s_w: np.ndarray,
    s_w2: np.ndarray,
    count: np.ndarray,
    min_samples: int,
) -> np.ndarray:
    r"""Compute the unbiased EWMA standard deviation from running accumulators.

    Implements the same Bessel-corrected formula used by
    ``polars.Expr.ewm_std(adjust=True)``::

        var_biased = s_x2/s_w - (s_x/s_w)^2
        correction = s_w^2 / (s_w^2 - s_w2)  # Bessel correction
        var_unbiased = var_biased * correction
        std = sqrt(max(0, var_unbiased))

    where ``s_w2 = sum(w_i^2)`` is the sum of squared EWM weights.
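
    As a sketch of why the accumulator form is a weighted variance with a
    reliability-weights correction, the same estimate can be written with
    explicit weights ``w_i = beta**age_i`` (assuming no missing values, with
    arbitrary ``beta`` and data chosen for illustration):

    ```python
    import numpy as np

    beta = 0.75
    x = np.array([0.01, -0.02, 0.015, 0.005, -0.01])
    w = beta ** np.arange(len(x))[::-1]  # newest observation gets weight 1

    # Accumulator form, as in the formula above.
    s_x, s_x2 = (w * x).sum(), (w * x**2).sum()
    s_w, s_w2 = w.sum(), (w**2).sum()
    var_acc = (s_x2 / s_w - (s_x / s_w) ** 2) * s_w**2 / (s_w**2 - s_w2)

    # Explicit form: weighted squared deviations over (V1 - V2 / V1).
    mean = np.average(x, weights=w)
    var_explicit = np.sum(w * (x - mean) ** 2) / (s_w - s_w2 / s_w)

    assert np.isclose(var_acc, var_explicit, rtol=1e-10, atol=0.0)
    ```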

    Parameters
    ----------
    s_x, s_x2, s_w, s_w2:
        Running accumulators, each of shape ``(N,)``.
    count:
        Integer count of finite observations per asset, shape ``(N,)``.
    min_samples:
        Minimum number of finite observations required before returning a
        non-NaN value.

    Returns:
    -------
    np.ndarray of shape ``(N,)`` with per-asset standard deviations.
    NaN is returned for assets where ``count < min_samples``.
    """
    n = len(s_x)
    result = np.full(n, np.nan, dtype=float)
    ok = count >= min_samples
    if not ok.any():
        return result

    with np.errstate(divide="ignore", invalid="ignore"):
        mean = np.where(s_w > 0, s_x / s_w, 0.0)
        mean_sq = np.where(s_w > 0, s_x2 / s_w, 0.0)
        var_biased = np.maximum(mean_sq - mean**2, 0.0)
        denom_corr = s_w**2 - s_w2
        # In exact arithmetic denom_corr > 0 iff count >= 2, and equals 0 when count == 1
        var_unbiased = np.where(denom_corr > 0, var_biased * s_w**2 / denom_corr, 0.0)
        std = np.sqrt(var_unbiased)

    return np.where(ok, std, np.nan)


# ---------------------------------------------------------------------------
# Helper: batch EWMA volatility accumulators from a returns matrix
# ---------------------------------------------------------------------------


def _ewm_vol_accumulators_from_batch(
    returns: np.ndarray,
    beta: float,
    beta_sq: float,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    r"""Compute final EWMA volatility accumulators from a batch of returns.

    Implements the same IIR recurrence as :meth:`BasanosStream.step` but
    vectorised over *T* timesteps using ``scipy.signal.lfilter``. The five
    returned arrays are identical to the accumulators that would result from
    feeding each row of *returns* through the scalar step-by-step recurrence::

        s_x[t]   = beta * s_x[t-1] + (x[t] if finite else 0)
        s_x2[t]  = beta * s_x2[t-1] + (x[t]^2 if finite else 0)
        s_w[t]   = beta * s_w[t-1] + (1 if finite else 0)
        s_w2[t]  = beta^2 * s_w2[t-1] + (1 if finite else 0)
        count[t] = count[t-1] + (1 if finite else 0)
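
    A self-contained sketch of that equivalence, using a hypothetical
    two-asset returns matrix with a leading NaN row (as a first difference
    produces) and one interior gap. It restates the vectorised body of this
    function inline rather than importing it, and checks two of the
    accumulators against the scalar loop:

    ```python
    import numpy as np
    from scipy.signal import lfilter

    beta = 0.8
    returns = np.array([
        [np.nan, np.nan],
        [0.01, -0.02],
        [np.nan, 0.03],
        [-0.005, 0.01],
    ])

    # Vectorised path: zero-fill missing values and filter down axis 0.
    fin = np.isfinite(returns).astype(np.float64)
    x_z = np.where(fin.astype(bool), returns, 0.0)
    s_x_batch = lfilter([1.0], [1.0, -beta], x_z, axis=0)[-1]
    s_w2_batch = lfilter([1.0], [1.0, -beta**2], fin, axis=0)[-1]

    # Scalar path: the step-by-step recurrence, one row at a time.
    s_x = np.zeros(2)
    s_w2 = np.zeros(2)
    for row in returns:
        ok = np.isfinite(row)
        s_x = beta * s_x + np.where(ok, row, 0.0)
        s_w2 = beta**2 * s_w2 + ok.astype(float)

    assert np.allclose(s_x_batch, s_x, rtol=1e-12, atol=0.0)
    assert np.allclose(s_w2_batch, s_w2, rtol=1e-12, atol=0.0)
    ```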

    Parameters
    ----------
    returns:
        Float array of shape ``(T, N)``. NaN entries are treated as missing
        observations — they contribute nothing to the numerator sums and do
        not increment the weight accumulators.
    beta:
        EWM decay factor for ``s_x``, ``s_x2``, and ``s_w``
        (``beta = com / (1 + com)`` for ``com = cfg.vola - 1``).
    beta_sq:
        Squared decay factor used for ``s_w2``. Must equal ``beta ** 2``.

    Returns:
    -------
    s_x, s_x2, s_w, s_w2 : np.ndarray of shape ``(N,)``
        Final EWMA running accumulators after processing all *T* rows.
    count : np.ndarray of shape ``(N,)`` dtype int
        Number of finite observations per asset.

    Notes:
    -----
    This function is the shared implementation used by
    :meth:`BasanosStream.from_warmup` for both the log-return (``vola_*``)
    and pct-return (``pct_*``) accumulators, so the two batch computations
    share a single definition of the recurrence. :meth:`BasanosStream.step`
    applies the same recurrence inline, one row at a time, so any change to
    the recurrence must be mirrored there.
    """
    fin = np.isfinite(returns).astype(np.float64)  # (T, N)
    x_z = np.where(fin.astype(bool), returns, 0.0)  # (T, N)
    filt_a = np.array([1.0, -beta])
    filt_a2 = np.array([1.0, -beta_sq])

    s_x: np.ndarray = lfilter([1.0], filt_a, x_z, axis=0)[-1]
    s_x2: np.ndarray = lfilter([1.0], filt_a, x_z**2, axis=0)[-1]
    s_w: np.ndarray = lfilter([1.0], filt_a, fin, axis=0)[-1]
    s_w2: np.ndarray = lfilter([1.0], filt_a2, fin, axis=0)[-1]
    count: np.ndarray = fin.sum(axis=0).astype(int)

    return s_x, s_x2, s_w, s_w2, count


# ---------------------------------------------------------------------------
# BasanosStream
# ---------------------------------------------------------------------------


class BasanosStream:
    """Incremental (streaming) optimiser backed by a single :class:`_StreamState`.

    After warming up on a historical batch via :meth:`from_warmup`, each call
    to :meth:`step` advances the internal state by exactly one row without
    revisiting the warmup history. The state update itself costs O(N^2); the
    position solve that follows can add up to O(N^3) work per step.

    Attributes:
        assets: Ordered list of asset column names (read-only).

    Examples:
        >>> import numpy as np
        >>> import polars as pl
        >>> from datetime import date, timedelta
        >>> from basanos.math import BasanosConfig, BasanosStream
        >>> rng = np.random.default_rng(0)
        >>> warmup_len = 60
        >>> dates = pl.date_range(
        ...     start=date(2024, 1, 1),
        ...     end=date(2024, 1, 1) + timedelta(days=warmup_len),
        ...     interval="1d",
        ...     eager=True,
        ... )
        >>> prices = pl.DataFrame({
        ...     "date": dates,
        ...     "A": np.cumprod(1 + rng.normal(0.001, 0.02, warmup_len + 1)) * 100.0,
        ...     "B": np.cumprod(1 + rng.normal(0.001, 0.02, warmup_len + 1)) * 150.0,
        ... })
        >>> mu = pl.DataFrame({
        ...     "date": dates,
        ...     "A": rng.normal(0, 0.5, warmup_len + 1),
        ...     "B": rng.normal(0, 0.5, warmup_len + 1),
        ... })
        >>> cfg = BasanosConfig(vola=5, corr=10, clip=3.0, shrink=0.5, aum=1e6)
        >>> stream = BasanosStream.from_warmup(prices.head(warmup_len), mu.head(warmup_len), cfg)
        >>> result = stream.step(
        ...     prices.select(["A", "B"]).to_numpy()[warmup_len],
        ...     mu.select(["A", "B"]).to_numpy()[warmup_len],
        ...     prices["date"][warmup_len],
        ... )
        >>> isinstance(result, StepResult)
        True
        >>> result.cash_position.shape
        (2,)
    """

    _cfg: BasanosConfig
    _assets: list[str]
    _state: _StreamState

    def __init__(self, cfg: BasanosConfig, assets: list[str], state: _StreamState) -> None:
        """Initialise from an explicit config, asset list, and state container."""
        object.__setattr__(self, "_cfg", cfg)
        object.__setattr__(self, "_assets", assets)
        object.__setattr__(self, "_state", state)

    def __setattr__(self, name: str, value: object) -> None:
        """Block attribute rebinding; ``_state`` itself is still updated in place by :meth:`step`."""
        raise dataclasses.FrozenInstanceError(f"{type(self).__name__}.{name}")

    @property
    def assets(self) -> list[str]:
        """Ordered list of asset column names."""
        return self._assets

    # ------------------------------------------------------------------
    # from_warmup
    # ------------------------------------------------------------------

    @classmethod
    def from_warmup(
        cls,
        prices: pl.DataFrame,
        mu: pl.DataFrame,
        cfg: BasanosConfig,
    ) -> BasanosStream:
        """Build a :class:`BasanosStream` from a historical warmup batch.

        Runs :class:`~basanos.math.BasanosEngine` on the full warmup batch
        exactly once and extracts the minimal state required for subsequent
        :meth:`step` calls: the correlation IIR filter memory (or, for
        ``SlidingWindowConfig``, a rolling buffer of recent vol-adjusted
        returns), the EWMA volatility accumulators, and the last price and
        cash-position rows. After this call, each :meth:`step` updates that
        state in O(N^2) time without touching the warmup data again.

        Parameters
        ----------
        prices:
            Historical price DataFrame. Must contain a ``'date'`` column and
            at least one numeric asset column with strictly positive,
            non-monotonic values.
        mu:
            Expected-return signal DataFrame aligned row-by-row with
            ``prices``.
        cfg:
            Engine configuration. Its ``covariance_config`` may be either an
            :class:`~basanos.math.EwmaShrinkConfig` or a
            :class:`~basanos.math.SlidingWindowConfig`.

        Returns:
        -------
        BasanosStream
            A stream instance whose :meth:`step` method is ready to accept the
            row immediately following the last warmup row.

        Notes:
        ------
        **Short-warmup behaviour with** ``SlidingWindowConfig``: when
        ``len(prices) < cfg.covariance_config.window``, the internal rolling
        buffer (``sw_ret_buf``) is NaN-padded for the missing prefix rows.
        :meth:`step` returns ``StepResult(status="warmup")`` for each of the
        first ``window - len(prices)`` calls, exactly matching the EWM warmup
        semantics. By the time :meth:`step` returns the first non-warmup
        result the buffer contains only real data — no NaN-padded rows remain.
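
        The padding arithmetic can be checked in isolation. This sketch uses
        hypothetical sizes (window ``W = 5``, 3 warmup rows, 2 assets) and
        random stand-ins for the vol-adjusted returns; it mirrors the buffer
        construction in this method and the in-place shift performed by
        :meth:`step`, ignoring all other state:

        ```python
        import numpy as np

        win_w, n_rows, n_assets = 5, 3, 2
        rng = np.random.default_rng(0)
        ret_adj = rng.normal(size=(n_rows, n_assets))  # stand-in warmup returns

        # NaN-pad the missing prefix rows (oldest row first).
        buf = np.full((win_w, n_assets), np.nan)
        buf[-n_rows:] = ret_adj

        # Each streaming step drops the oldest row and appends the newest one.
        for _ in range(win_w - n_rows):
            buf[:-1] = buf[1:]
            buf[-1] = rng.normal(size=n_assets)

        # After window - len(prices) steps, no padding rows remain.
        assert np.isfinite(buf).all()
        ```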

        Raises:
        ------
        MissingDateColumnError
            If ``'date'`` is absent from ``prices``.
        """
        # 1. Validate -------------------------------------------------------
        if "date" not in prices.columns:
            raise MissingDateColumnError("prices")

        # 2. Build the engine on the full warmup batch ----------------------
        # Import here to avoid a circular dependency at module level.
        from .optimizer import BasanosEngine

        engine = BasanosEngine(prices=prices, mu=mu, cfg=cfg)
        assets = engine.assets
        n_assets = len(assets)
        n_rows = prices.height
        prices_np = prices.select(assets).to_numpy()  # (n_rows, n_assets)

        # 3. Extract mode-specific state from WarmupState --------------------
        ws = engine.warmup_state()
        if isinstance(cfg.covariance_config, EwmaShrinkConfig):
            # EWM: seed the per-step lfilter from the IIR state captured
            # during the single batch pass in warmup_state().
            iir = cast("_EwmCorrState", ws.corr_iir_state)
            corr_zi_x = iir.corr_zi_x
            corr_zi_x2 = iir.corr_zi_x2
            corr_zi_xy = iir.corr_zi_xy
            corr_zi_w = iir.corr_zi_w
            corr_count: np.ndarray = iir.count
            sw_ret_buf: np.ndarray | None = None
        else:
            # SW: carry the last W vol-adjusted returns as a rolling buffer.
            # The IIR fields are initialised to zeros and left unused.
            sw_config = cast(SlidingWindowConfig, cfg.covariance_config)
            win_w = sw_config.window
            ret_adj_np = engine.ret_adj.select(assets).to_numpy()  # (n_rows, N)
            if n_rows >= win_w:
                sw_ret_buf = ret_adj_np[-win_w:].copy()
            else:
                sw_ret_buf = np.full((win_w, n_assets), np.nan)
                sw_ret_buf[-n_rows:] = ret_adj_np
            corr_zi_x = np.zeros((1, n_assets, n_assets))
            corr_zi_x2 = np.zeros((1, n_assets, n_assets))
            corr_zi_xy = np.zeros((1, n_assets, n_assets))
            corr_zi_w = np.zeros((1, n_assets, n_assets))
            corr_count = np.zeros((n_assets, n_assets), dtype=np.int64)

        # 4. Derive EWMA volatility accumulators (vectorised) ---------------
        # Both log-return (for vol_adj) and pct-return (for vola) use the
        # same beta = (vola-1)/vola. NaN observations (leading NaN at row 0
        # from diff/pct_change) are skipped — the filter input is 0 for NaN
        # rows and the weight accumulator (s_w) only increments for finite
        # observations, matching Polars' effective behaviour for a
        # leading-NaN series.
        #
        # Delegate to the shared helper _ewm_vol_accumulators_from_batch so
        # that both batch accumulator families share one definition of the
        # recurrence; step() applies the same recurrence inline, row by row.
        beta_vola: float = (cfg.vola - 1) / cfg.vola
        beta_vola_sq: float = beta_vola**2

        log_ret = np.full((n_rows, n_assets), np.nan, dtype=float)
        pct_ret = np.full((n_rows, n_assets), np.nan, dtype=float)
        if n_rows > 1:
            with np.errstate(divide="ignore", invalid="ignore"):
                log_ret[1:] = np.log(prices_np[1:] / prices_np[:-1])
                pct_ret[1:] = prices_np[1:] / prices_np[:-1] - 1.0

        vola_s_x, vola_s_x2, vola_s_w, vola_s_w2, vola_count = _ewm_vol_accumulators_from_batch(
            log_ret, beta_vola, beta_vola_sq
        )
        pct_s_x, pct_s_x2, pct_s_w, pct_s_w2, pct_count = _ewm_vol_accumulators_from_batch(
            pct_ret, beta_vola, beta_vola_sq
        )

        # 5. Extract prev_cash_pos from WarmupState --------------------------
        prev_cash_pos: np.ndarray = ws.prev_cash_pos
        prev_price: np.ndarray = prices_np[-1].copy()

        # 6. Construct _StreamState and return ------------------------------
        state = _StreamState(
            corr_zi_x=corr_zi_x,
            corr_zi_x2=corr_zi_x2,
            corr_zi_xy=corr_zi_xy,
            corr_zi_w=corr_zi_w,
            corr_count=corr_count,
            vola_s_x=vola_s_x,
            vola_s_x2=vola_s_x2,
            vola_s_w=vola_s_w,
            vola_s_w2=vola_s_w2,
            vola_count=vola_count,
            pct_s_x=pct_s_x,
            pct_s_x2=pct_s_x2,
            pct_s_w=pct_s_w,
            pct_s_w2=pct_s_w2,
            pct_count=pct_count,
            prev_price=prev_price,
            prev_cash_pos=prev_cash_pos,
            step_count=n_rows,
            sw_ret_buf=sw_ret_buf,
        )
        return cls(cfg=cfg, assets=assets, state=state)

583 

584 # ------------------------------------------------------------------ 

585 # step 

586 # ------------------------------------------------------------------ 

587 

588 def step( 

589 self, 

590 new_prices: np.ndarray | dict[str, float], 

591 new_mu: np.ndarray | dict[str, float], 

592 date: Any = None, 

593 ) -> StepResult: 

594 """Advance the stream by one row and return the new optimised position. 

595 

596 Parameters 

597 ---------- 

598 new_prices: 

599 Per-asset prices for the new timestep. Either a numpy array of 

600 shape ``(N,)`` (assets ordered as in :attr:`assets`) or a dict 

601 mapping asset names to price values. 

602 new_mu: 

603 Per-asset expected-return signals, same format as ``new_prices``. 

604 date: 

605 Timestamp for this step (stored in :attr:`StepResult.date` 

606 verbatim; not used in any computation). 

607 

608 Returns: 

609 ------- 

610 StepResult 

611 Frozen dataclass with ``cash_position``, ``vola``, ``status``, and 

612 ``date`` for this timestep. 

613 """ 

614 from ..exceptions import SingularMatrixError 

615 

616 cfg = self._cfg 

617 assets = self._assets 

618 state = self._state 

619 n_assets = len(assets) 

620 

621 # ── Check if still in the warmup period ────────────────────────────── 

622 # step_count is initialised to n_rows in from_warmup. 

623 # 

624 # EwmaShrinkConfig: in_warmup is True for the first (cfg.corr - n_rows) 

625 # calls when the warmup batch was shorter than cfg.corr (not enough rows 

626 # to populate the EWM correlation matrix). 

627 # 

628 # SlidingWindowConfig: in_warmup is True for the first (window - n_rows) 

629 # calls when the warmup batch was shorter than the window. During this 

630 # period sw_ret_buf still contains NaN-padded prefix rows; each step 

631 # shifts one NaN out and appends a real row, so the buffer is fully 

632 # populated with real data exactly when in_warmup becomes False. 

633 # 

634 # In both modes all accumulators are still updated during warmup so that 

635 # the state is ready the moment the warmup period ends. 

636 _warmup_thresh = ( 

637 cfg.covariance_config.window if isinstance(cfg.covariance_config, SlidingWindowConfig) else cfg.corr 

638 ) 

639 in_warmup: bool = state.step_count < _warmup_thresh 

640 

641 # ── Resolve inputs to (N,) float64 arrays ────────────────────────── 

642 if isinstance(new_prices, dict): 

643 new_p = np.array([float(new_prices[a]) for a in assets], dtype=float) 

644 else: 

645 new_p = np.asarray(new_prices, dtype=float).ravel() 

646 

647 if isinstance(new_mu, dict): 

648 new_m = np.array([float(new_mu[a]) for a in assets], dtype=float) 

649 else: 

650 new_m = np.asarray(new_mu, dtype=float).ravel() 

651 

652 if new_p.shape != (n_assets,): 

653 raise ValueError(f"new_prices must have shape ({n_assets},); got {new_p.shape}") # noqa: TRY003 

654 if new_m.shape != (n_assets,): 

655 raise ValueError(f"new_mu must have shape ({n_assets},); got {new_m.shape}") # noqa: TRY003 

656 

657 prev_p = state.prev_price 

658 beta_vola: float = (cfg.vola - 1) / cfg.vola 

659 beta_vola_sq: float = beta_vola**2 

660 beta_corr: float = cfg.corr / (1.0 + cfg.corr) 

661 

662 # ── Compute new log-returns and pct-returns ───────────────────────── 

663 with np.errstate(divide="ignore", invalid="ignore"): 

664 ratio = np.where( 

665 np.isfinite(new_p) & np.isfinite(prev_p) & (prev_p > 0), 

666 new_p / prev_p, 

667 np.nan, 

668 ) 

669 log_ret = np.log(ratio) 

670 pct_ret = ratio - 1.0 

671 

672 # ── Update log-return EWMA accumulators ──────────────────────────── 

673 fin_log = np.isfinite(log_ret) 

674 vola_s_x = beta_vola * state.vola_s_x + np.where(fin_log, log_ret, 0.0) 

675 vola_s_x2 = beta_vola * state.vola_s_x2 + np.where(fin_log, log_ret**2, 0.0) 

676 vola_s_w = beta_vola * state.vola_s_w + fin_log.astype(float) 

677 vola_s_w2 = beta_vola_sq * state.vola_s_w2 + fin_log.astype(float) 

678 vola_count = state.vola_count + fin_log.astype(int) 

679 

680 # ── Update pct-return EWMA accumulators ──────────────────────────── 

681 fin_pct = np.isfinite(pct_ret) 

682 pct_s_x = beta_vola * state.pct_s_x + np.where(fin_pct, pct_ret, 0.0) 

683 pct_s_x2 = beta_vola * state.pct_s_x2 + np.where(fin_pct, pct_ret**2, 0.0) 

684 pct_s_w = beta_vola * state.pct_s_w + fin_pct.astype(float) 

685 pct_s_w2 = beta_vola_sq * state.pct_s_w2 + fin_pct.astype(float) 

686 pct_count = state.pct_count + fin_pct.astype(int) 

687 

688 # ── Compute vol-adjusted return (for the correlation IIR input) ───── 

689 log_vol = _ewm_std_from_state(vola_s_x, vola_s_x2, vola_s_w, vola_s_w2, vola_count, min_samples=1) 

690 # Divide; std == 0 yields ±inf → clipped to ±cfg.clip (matches Polars) 

691 with np.errstate(divide="ignore", invalid="ignore"): 

692 vol_adj_val = np.where( 

693 fin_log, 

694 np.clip(log_ret / log_vol, -cfg.clip, cfg.clip), 

695 np.nan, 

696 ) 

697 

698 # ── Mode-specific correlation state update ─────────────────────────── 

699 if isinstance(cfg.covariance_config, SlidingWindowConfig): 

700 # SW: shift the rolling window buffer in-place and append this row. 

701 # The corr_zi_* fields are unused; alias them to their old values so 

702 # the early-return and persist blocks below can reference them safely. 

703 buf = state.sw_ret_buf # (W, N), already owned by state 

704 buf[:-1] = buf[1:] # type: ignore[index] 

705 buf[-1] = vol_adj_val # type: ignore[index] 

706 corr_zi_x = state.corr_zi_x 

707 corr_zi_x2 = state.corr_zi_x2 

708 corr_zi_xy = state.corr_zi_xy 

709 corr_zi_w = state.corr_zi_w 

710 corr_count = state.corr_count 

711 else: 

712 # EWM: Update IIR filter state for EWM correlation 

713 fin_va = np.isfinite(vol_adj_val) 

714 va_f = np.where(fin_va, vol_adj_val, 0.0) 

715 joint_fin = fin_va[:, np.newaxis] & fin_va[np.newaxis, :] # (N, N) 

716 

717 new_v_x = (va_f[:, np.newaxis] * joint_fin)[np.newaxis] # (1, N, N) 

718 new_v_x2 = ((va_f**2)[:, np.newaxis] * joint_fin)[np.newaxis] # (1, N, N) 

719 new_v_xy = (va_f[:, np.newaxis] * va_f[np.newaxis, :])[np.newaxis] # (1, N, N) 

720 new_v_w = joint_fin.astype(np.float64)[np.newaxis] # (1, N, N) 

721 

722 filt_a_corr = np.array([1.0, -beta_corr]) 

723 # y_x[0] is the current-step EWM state (filter output); corr_zi_x is 

724 # the new filter memory (zf = beta * y[0]) passed as zi next step. 

725 y_x, corr_zi_x = lfilter([1.0], filt_a_corr, new_v_x, axis=0, zi=state.corr_zi_x) 

726 y_x2, corr_zi_x2 = lfilter([1.0], filt_a_corr, new_v_x2, axis=0, zi=state.corr_zi_x2) 

727 y_xy, corr_zi_xy = lfilter([1.0], filt_a_corr, new_v_xy, axis=0, zi=state.corr_zi_xy) 

728 y_w, corr_zi_w = lfilter([1.0], filt_a_corr, new_v_w, axis=0, zi=state.corr_zi_w) 

729 corr_count = state.corr_count + joint_fin.astype(np.int64) 

        # ── Early return during EWM warmup period ───────────────────────────
        # All accumulators are already updated above; skip the O(N²) matrix
        # reconstruction and O(N³) Cholesky solve which are wasteful during
        # warmup — the computed positions would be discarded anyway.
        if in_warmup:
            state.corr_zi_x = corr_zi_x
            state.corr_zi_x2 = corr_zi_x2
            state.corr_zi_xy = corr_zi_xy
            state.corr_zi_w = corr_zi_w
            state.corr_count = corr_count
            state.vola_s_x = vola_s_x
            state.vola_s_x2 = vola_s_x2
            state.vola_s_w = vola_s_w
            state.vola_s_w2 = vola_s_w2
            state.vola_count = vola_count
            state.pct_s_x = pct_s_x
            state.pct_s_x2 = pct_s_x2
            state.pct_s_w = pct_s_w
            state.pct_s_w2 = pct_s_w2
            state.pct_count = pct_count
            state.prev_price = new_p.copy()
            state.step_count += 1
            return StepResult(
                date=date,
                cash_position=np.full(n_assets, np.nan),
                status=SolveStatus.WARMUP,
                vola=np.full(n_assets, np.nan),
            )

        # ── Compute EWMA volatility (pct-return std) — shared ───────────────
        vola_vec = _ewm_std_from_state(pct_s_x, pct_s_x2, pct_s_w, pct_s_w2, pct_count, min_samples=cfg.vola)

        # ── Solve for position ───────────────────────────────────────────────
        new_cash_pos = np.full(n_assets, np.nan, dtype=float)
        status = SolveStatus.DEGENERATE

        mask = np.isfinite(new_p)

        if isinstance(cfg.covariance_config, SlidingWindowConfig):
            # ── SW path: FactorModel solve via Woodbury identity ─────────────
            sw_config = cfg.covariance_config
            if not mask.any():
                status = SolveStatus.DEGENERATE
            else:
                win_w = sw_config.window
                win_k = sw_config.n_factors
                window_ret = np.where(
                    np.isfinite(state.sw_ret_buf[:, mask]),  # type: ignore[index]
                    state.sw_ret_buf[:, mask],  # type: ignore[index]
                    0.0,
                )
                n_sub = int(mask.sum())
                k_eff = min(win_k, win_w, n_sub)
                if sw_config.max_components is not None:
                    k_eff = min(k_eff, sw_config.max_components)
                try:
                    fm = FactorModel.from_returns(window_ret, k=k_eff)
                except (np.linalg.LinAlgError, ValueError) as exc:
                    _logger.debug("Sliding window SVD failed at date=%s: %s", date, exc)
                    new_cash_pos[mask] = 0.0
                    status = SolveStatus.DEGENERATE
                else:
                    expected_mu = np.nan_to_num(new_m[mask])
                    if np.allclose(expected_mu, 0.0):
                        new_cash_pos[mask] = 0.0
                        status = SolveStatus.ZERO_SIGNAL
                    else:
                        try:
                            x = fm.solve(expected_mu)
                            denom_val = float(np.sqrt(max(0.0, float(np.dot(expected_mu, x)))))
                        except (SingularMatrixError, np.linalg.LinAlgError) as exc:
                            _logger.warning("Woodbury solve failed at date=%s: %s", date, exc)
                            new_cash_pos[mask] = 0.0
                            status = SolveStatus.DEGENERATE
                        else:
                            if not np.isfinite(denom_val) or denom_val <= cfg.denom_tol:
                                _logger.warning(
                                    "Positions zeroed at date=%s (sliding_window): normalisation "
                                    "denominator degenerate (denom=%s, denom_tol=%s).",
                                    date,
                                    denom_val,
                                    cfg.denom_tol,
                                )
                                new_cash_pos[mask] = 0.0
                                status = SolveStatus.DEGENERATE
                            else:
                                risk_pos = x / denom_val
                                vola_sub = vola_vec[mask]
                                with np.errstate(invalid="ignore"):
                                    new_cash_pos[mask] = risk_pos / vola_sub
                                status = SolveStatus.VALID
        else:
            # ── EWM path: shared _corr_from_ewm_accumulators + _SolveMixin solve ──
            # Reconstruct the EWM correlation matrix from filter outputs using
            # the shared helper — the same formula as _ewm_corr_with_final_state
            # but operating on (N, N) slices instead of (T, N, N) tensors.
            # Use y_*[0] (the filter OUTPUT for this step), not zf[0].
            corr = _corr_from_ewm_accumulators(
                y_x[0],
                y_x2[0],
                y_xy[0],
                y_w[0],
                corr_count,
                min_periods=cfg.corr,
                min_corr_denom=cfg.min_corr_denom,
            )
            matrix = shrink2id(corr, lamb=cfg.shrink)

            # Delegate the signal check and linear solve to _SolveMixin so that
            # any algorithm change (denominator guard, status labels, etc.) only
            # needs to be applied in _engine_solve.py.
            expected_mu, early = _SolveMixin._row_early_check(state.step_count, date, mask, new_m)
            if early is not None:
                _, _, _, pos, status = early
                new_cash_pos[mask] = pos
            else:
                corr_sub = matrix[np.ix_(mask, mask)]
                _, _, _, pos, status = _SolveMixin._compute_position(
                    state.step_count, date, mask, expected_mu, MatrixBundle(matrix=corr_sub), cfg.denom_tol
                )
                if status == SolveStatus.VALID:
                    new_cash_pos[mask] = _SolveMixin._scale_to_cash(cast(np.ndarray, pos), vola_vec[mask])
                else:
                    new_cash_pos[mask] = pos

        # ── Apply turnover constraint ─────────────────────────────────────────
        if cfg.max_turnover is not None and status == SolveStatus.VALID:
            new_cash_pos[mask] = _SolveMixin._apply_turnover_constraint(
                new_cash_pos[mask],
                state.prev_cash_pos[mask],
                cfg.max_turnover,
            )

        # ── Persist updated state ───────────────────────────────────────────
        state.corr_zi_x = corr_zi_x
        state.corr_zi_x2 = corr_zi_x2
        state.corr_zi_xy = corr_zi_xy
        state.corr_zi_w = corr_zi_w
        state.corr_count = corr_count
        state.vola_s_x = vola_s_x
        state.vola_s_x2 = vola_s_x2
        state.vola_s_w = vola_s_w
        state.vola_s_w2 = vola_s_w2
        state.vola_count = vola_count
        state.pct_s_x = pct_s_x
        state.pct_s_x2 = pct_s_x2
        state.pct_s_w = pct_s_w
        state.pct_s_w2 = pct_s_w2
        state.pct_count = pct_count
        state.prev_price = new_p.copy()
        state.prev_cash_pos = new_cash_pos.copy()
        state.step_count += 1

        return StepResult(
            date=date,
            cash_position=new_cash_pos,
            status=status,
            vola=vola_vec,
        )

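The EWM branch of `step` advances each correlation accumulator by a single sample and hands the returned filter memory `zf` to the next call as `zi`. The sketch below checks that carrying the state this way reproduces a single batch `lfilter` call over the same (T, N, N) tensor. It assumes synthetic random inputs and an arbitrary `beta = 0.9`; the helper names `ewm_step` and `batch_vs_incremental` are hypothetical and do not belong to the module.

```python
import numpy as np
from scipy.signal import lfilter


def ewm_step(v_t, zi, beta):
    """Advance s[t] = beta * s[t-1] + v[t] by one sample (hypothetical helper).

    v_t and zi both have shape (1, N, N). Returns (y, zf): y[0] is the
    accumulator after this sample, zf is the memory to pass as zi next time.
    """
    a = np.array([1.0, -beta])  # denominator of the single-pole IIR filter
    return lfilter([1.0], a, v_t, axis=0, zi=zi)


def batch_vs_incremental(T=50, N=3, beta=0.9, seed=0):
    """Run the same recurrence in one batch call and one sample at a time."""
    rng = np.random.default_rng(seed)
    v = rng.normal(size=(T, N, N))  # synthetic inputs, illustration only
    a = np.array([1.0, -beta])

    # Batch path: the full history in one call, starting from zero state.
    y_batch = lfilter([1.0], a, v, axis=0)

    # Incremental path: one (1, N, N) slice per call, carrying zf -> zi.
    zi = np.zeros((1, N, N))
    y_inc = np.empty_like(v)
    for t in range(T):
        y_t, zi = ewm_step(v[t : t + 1], zi, beta)
        y_inc[t] = y_t[0]
    return y_batch, y_inc, zi
```

After `y_batch, y_inc, zf = batch_vs_incremental()`, the two outputs agree to floating-point precision, and `zf` equals `0.9 * y_inc[-1:]`. This matches the `zf = beta * y[0]` comment in the EWM branch.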

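In the sliding-window branch, the raw solution `x` is divided by `sqrt(mu · x)` and then by the per-asset volatility. For a symmetric matrix C with x = C⁻¹μ, we have xᵀCx = μᵀC⁻¹μ = μ·x, so the scaled vector `x / sqrt(mu · x)` has unit risk under C. The sketch below is a minimal check of that identity. It assumes a dense `np.linalg.solve` in place of `FactorModel.solve`, and a fixed `denom_tol` standing in for `cfg.denom_tol`. The function name `unit_risk_position` is hypothetical.

```python
import numpy as np


def unit_risk_position(corr, mu, vola, denom_tol=1e-12):
    """Hypothetical sketch of the sliding-window normalisation in step().

    Returns (cash_position, risk_position), or None when the normalisation
    denominator is degenerate (step() zeroes the positions in that case).
    """
    x = np.linalg.solve(corr, mu)  # dense stand-in for FactorModel.solve
    denom = float(np.sqrt(max(0.0, float(np.dot(mu, x)))))
    if not np.isfinite(denom) or denom <= denom_tol:
        return None
    risk_pos = x / denom  # risk_pos @ corr @ risk_pos == 1 up to rounding
    return risk_pos / vola, risk_pos
```

Dividing the unit-risk vector by `vola` converts it into cash units, which is what the `risk_pos / vola_sub` line does. With a zero signal, `mu · x` is 0, so the denominator guard fires and the sketch returns `None`.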
    def save(self, path: str | os.PathLike[str]) -> None:
        """Serialise the stream to a ``.npz`` archive at *path*.

        All :class:`_StreamState` arrays, the configuration, and the asset
        list are written in a single :func:`numpy.savez` call. A stream
        restored via :meth:`load` produces bit-for-bit identical
        :meth:`step` output.

        Args:
            path: Destination file path. :func:`numpy.savez` appends
                ``.npz`` automatically when the suffix is absent.

        Examples:
            >>> import tempfile, pathlib, numpy as np
            >>> import polars as pl
            >>> from datetime import date, timedelta
            >>> from basanos.math import BasanosConfig, BasanosStream
            >>> rng = np.random.default_rng(0)
            >>> n = 60
            >>> end = date(2024, 1, 1) + timedelta(days=n - 1)
            >>> dates = pl.date_range(
            ...     date(2024, 1, 1), end, interval="1d", eager=True
            ... )
            >>> prices = pl.DataFrame({
            ...     "date": dates,
            ...     "A": np.cumprod(1 + rng.normal(0.001, 0.02, n)) * 100.0,
            ...     "B": np.cumprod(1 + rng.normal(0.001, 0.02, n)) * 150.0,
            ... })
            >>> mu = pl.DataFrame({
            ...     "date": dates,
            ...     "A": rng.normal(0, 0.5, n),
            ...     "B": rng.normal(0, 0.5, n),
            ... })
            >>> cfg = BasanosConfig(vola=5, corr=10, clip=3.0, shrink=0.5, aum=1e6)
            >>> stream = BasanosStream.from_warmup(prices, mu, cfg)
            >>> with tempfile.TemporaryDirectory() as tmp:
            ...     p = pathlib.Path(tmp) / "stream.npz"
            ...     stream.save(p)
            ...     restored = BasanosStream.load(p)
            ...     restored.assets == stream.assets
            True
        """
        state = self._state
        # Build the per-field dict automatically from _StreamState so that any
        # new field added to the dataclass is included without manual updates.
        state_arrays: dict[str, Any] = {}
        for field in dataclasses.fields(_StreamState):
            value = getattr(state, field.name)
            if field.name == "sw_ret_buf":
                # Sentinel: use an empty (0, 0) array to represent None so the
                # key is always present in the archive and load() can detect it.
                state_arrays[field.name] = value if value is not None else np.empty((0, 0), dtype=float)
            elif field.name == "step_count":
                state_arrays[field.name] = np.array(value)
            else:
                state_arrays[field.name] = value
        np.savez(
            path,
            format_version=np.array(_SAVE_FORMAT_VERSION),
            cfg_json=np.array(self._cfg.model_dump_json()),
            assets=np.array(self._assets),
            **state_arrays,
        )

    @classmethod
    def load(cls, path: str | os.PathLike[str]) -> BasanosStream:
        """Restore a stream previously saved with :meth:`save`.

        Args:
            path: Path to a ``.npz`` archive written by :meth:`save`.
                Include the ``.npz`` suffix: :func:`numpy.load` does not
                append it, even when :meth:`save` did.

        Returns:
            A :class:`BasanosStream` whose :meth:`step` output is
            bit-for-bit identical to that of the original stream at the
            time :meth:`save` was called.

        Examples:
            >>> import tempfile, pathlib, numpy as np
            >>> import polars as pl
            >>> from datetime import date, timedelta
            >>> from basanos.math import BasanosConfig, BasanosStream
            >>> rng = np.random.default_rng(1)
            >>> n = 60
            >>> end = date(2024, 1, 1) + timedelta(days=n - 1)
            >>> dates = pl.date_range(
            ...     date(2024, 1, 1), end, interval="1d", eager=True
            ... )
            >>> prices = pl.DataFrame({
            ...     "date": dates,
            ...     "A": np.cumprod(1 + rng.normal(0.001, 0.02, n)) * 100.0,
            ...     "B": np.cumprod(1 + rng.normal(0.001, 0.02, n)) * 150.0,
            ... })
            >>> mu = pl.DataFrame({
            ...     "date": dates,
            ...     "A": rng.normal(0, 0.5, n),
            ...     "B": rng.normal(0, 0.5, n),
            ... })
            >>> cfg = BasanosConfig(vola=5, corr=10, clip=3.0, shrink=0.5, aum=1e6)
            >>> stream = BasanosStream.from_warmup(prices, mu, cfg)
            >>> with tempfile.TemporaryDirectory() as tmp:
            ...     p = pathlib.Path(tmp) / "stream.npz"
            ...     stream.save(p)
            ...     restored = BasanosStream.load(p)
            ...     restored.assets == stream.assets
            True
        """
        data = np.load(path, allow_pickle=False)
        if "format_version" not in data:
            raise ValueError(  # noqa: TRY003
                "Stream file is missing a format version tag. "
                "It was written with an incompatible version of BasanosStream. "
                "Re-generate it via BasanosStream.from_warmup()."
            )
        found = int(data["format_version"])
        if found != _SAVE_FORMAT_VERSION:
            raise ValueError(  # noqa: TRY003
                f"Stream file was written with format version {found}, "
                f"but the current version is {_SAVE_FORMAT_VERSION}. "
                "Re-generate it via BasanosStream.from_warmup()."
            )
        # Validate that every required key is present. Archives produced by an
        # older codebase that lacks a newly added field, or archives edited by
        # hand, then fail with a descriptive error instead of a bare KeyError.
        archive_keys = frozenset(data.files)
        missing = _REQUIRED_KEYS - archive_keys
        if missing:
            raise StreamStateCorruptError(missing)
        cfg = BasanosConfig.model_validate_json(data["cfg_json"].item())
        assets: list[str] = list(data["assets"])
        state_kwargs: dict[str, Any] = {}
        for field in dataclasses.fields(_StreamState):
            raw = data[field.name]
            if field.name == "sw_ret_buf":
                state_kwargs[field.name] = raw if raw.size > 0 else None
            elif field.name == "step_count":
                state_kwargs[field.name] = int(raw)
            else:
                state_kwargs[field.name] = raw
        state = _StreamState(**state_kwargs)
        return cls(cfg=cfg, assets=assets, state=state)
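`load` rejects an archive in two stages. It first checks for a missing or mismatched `format_version` tag, and only then checks for required keys absent from `data.files`. The sketch below reproduces that ordering against in-memory archives. It substitutes a hypothetical `FORMAT_VERSION`, a hypothetical `REQUIRED_KEYS` set, and plain `ValueError`/`KeyError` for `_SAVE_FORMAT_VERSION`, `_REQUIRED_KEYS`, and `StreamStateCorruptError`. The helper names `validate_archive` and `make_archive` are also hypothetical.

```python
import io

import numpy as np

FORMAT_VERSION = 1  # hypothetical; stands in for _SAVE_FORMAT_VERSION
REQUIRED_KEYS = frozenset({"format_version", "cfg_json", "assets", "step_count"})  # hypothetical


def validate_archive(source):
    """Return the sorted key list of a valid archive, else raise."""
    with np.load(source, allow_pickle=False) as data:
        # Stage 1: the version tag must exist and match.
        if "format_version" not in data:
            raise ValueError("missing format version tag")
        found = int(data["format_version"])
        if found != FORMAT_VERSION:
            raise ValueError(f"format version {found}, expected {FORMAT_VERSION}")
        # Stage 2: every required key must be present.
        missing = REQUIRED_KEYS - frozenset(data.files)
        if missing:
            raise KeyError(sorted(missing))
        return sorted(data.files)


def make_archive(**arrays):
    """Write *arrays* to an in-memory .npz and rewind it for reading."""
    buf = io.BytesIO()
    np.savez(buf, **arrays)
    buf.seek(0)
    return buf
```

Checking the version first means an archive from an incompatible release gets the "re-generate" message rather than a list of unrelated missing keys.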