Coverage for src/cvx/linalg/solve/lstsq.py: 100%
22 statements
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-03 18:56 +0000
1"""Least-squares solver with NaN-aware row filtering."""
3from __future__ import annotations
5import numpy as np
7from ..core.exceptions import DEFAULT_COND_THRESHOLD, DimensionMismatchError
8from ..core.exceptions import warn_ill_conditioned as _warn_ill_conditioned
9from ..core.types import Matrix, Vector
def lstsq(
    matrix: Matrix,
    rhs: Vector,
    cond_threshold: float = DEFAULT_COND_THRESHOLD,
) -> tuple[Vector, Vector, int, Vector]:
    """Solve an overdetermined or underdetermined system in the least-squares sense.

    Rows where any entry in *matrix* or in the corresponding row of *rhs* is
    non-finite (NaN or ±inf) are excluded before solving. The returned
    solution always has one entry per column of *matrix*. When the 2-norm
    condition number of the valid sub-matrix (largest over smallest singular
    value) exceeds *cond_threshold*, an ``IllConditionedMatrixWarning`` is
    emitted.

    Args:
        matrix: Coefficient matrix of shape ``(m, n)``. Any array-like is
            accepted and converted with :func:`numpy.asarray`.
        rhs: Right-hand side of length ``m``. A 2-D array of shape ``(m, k)``
            is also accepted; a row is then dropped if any of its ``k``
            entries is non-finite.
        cond_threshold: Condition-number threshold above which a warning is
            emitted. Defaults to ``DEFAULT_COND_THRESHOLD``.

    Returns:
        A four-tuple ``(x, residuals, rank, sv)`` matching the convention of
        :func:`numpy.linalg.lstsq`:

        - ``x`` — least-squares solution of shape ``(n,)`` (``(n, k)`` for a
          2-D *rhs*). All NaN when no row is valid.
        - ``residuals`` — sums of squared residuals, shape ``(1,)`` (``(k,)``
          for a 2-D *rhs*). Empty when the valid sub-matrix is rank deficient,
          when it has no more valid rows than columns (this includes square,
          uniquely solvable systems), or when no row is valid.
        - ``rank`` — effective rank of the valid sub-matrix (0 if no row is
          valid).
        - ``sv`` — singular values of the valid sub-matrix in descending order
          (empty if no row is valid).

    Raises:
        DimensionMismatchError: If the length of *rhs* does not match the
            number of rows in *matrix*.

    Example:
        >>> import numpy as np
        >>> from cvx.linalg import lstsq
        >>> A = np.array([[1.0, 1.0], [1.0, 2.0], [1.0, 3.0]])
        >>> b = np.array([6.0, 5.0, 7.0])
        >>> x, res, rank, sv = lstsq(A, b)
        >>> int(rank)
        2
        >>> x.round(6).tolist()
        [5.0, 0.5]
        >>> res.round(6).tolist()
        [1.5]

        NaN rows are silently dropped; here two valid rows remain for two
        unknowns, so the system is solved exactly and residuals are empty:

        >>> A_nan = np.array([[1.0, 1.0], [np.nan, 2.0], [1.0, 3.0]])
        >>> b_nan = np.array([6.0, 5.0, 7.0])
        >>> x2, res2, rank2, _ = lstsq(A_nan, b_nan)
        >>> int(rank2)
        2
        >>> x2.round(6).tolist()
        [5.5, 0.5]
        >>> res2.size
        0
    """
    # Accept nested lists/tuples as well as arrays; no copy for ndarrays.
    matrix = np.asarray(matrix)
    rhs = np.asarray(rhs)

    if rhs.shape[0] != matrix.shape[0]:
        raise DimensionMismatchError(rhs.shape[0], matrix.shape[0])

    n_cols = matrix.shape[1]

    # A row is usable only if all its coefficients and all its right-hand-side
    # entries are finite. For a 2-D rhs, reduce over its columns so the mask
    # stays 1-D and aligned with the rows of *matrix* (a raw 2-D isfinite mask
    # would broadcast against the row mask and select the wrong elements).
    rhs_finite = np.isfinite(rhs)
    if rhs_finite.ndim > 1:
        rhs_finite = rhs_finite.all(axis=1)
    row_mask = np.isfinite(matrix).all(axis=1) & rhs_finite

    sub_matrix = matrix[row_mask]
    sub_rhs = rhs[row_mask]

    if sub_matrix.shape[0] == 0:
        # Nothing to fit: report an all-NaN solution shaped like a real one.
        return (
            np.full((n_cols, *rhs.shape[1:]), np.nan),
            np.array([]),
            0,
            np.array([]),
        )

    x, residuals, rank, sv = np.linalg.lstsq(sub_matrix, sub_rhs, rcond=None)

    # 2-norm condition number from the (descending) singular values. A zero
    # smallest singular value means the valid sub-matrix is singular.
    if sv.size == 0:
        cond = 1.0
    elif sv[-1] > 0:
        cond = float(sv[0] / sv[-1])
    else:
        cond = float("inf")

    _warn_ill_conditioned(cond, cond_threshold)

    return (
        x.astype(np.float64, copy=False),
        residuals.astype(np.float64, copy=False),
        int(rank),
        sv.astype(np.float64, copy=False),
    )