Coverage for src/cvx/linalg/lstsq.py: 100%
24 statements
coverage.py v7.14.0, created at 2026-05-19 05:40 +0000
1"""Least-squares solver with NaN-aware row filtering."""
3from __future__ import annotations
5import warnings
7import numpy as np
9from .exceptions import DimensionMismatchError, IllConditionedMatrixWarning
_DEFAULT_COND_THRESHOLD: float = 1e12
"""Default condition-number threshold above which a warning is emitted."""


def lstsq(
    matrix: np.ndarray,
    rhs: np.ndarray,
    cond_threshold: float = _DEFAULT_COND_THRESHOLD,
) -> tuple[np.ndarray, np.ndarray, int, np.ndarray]:
    """Solve an overdetermined or underdetermined system in the least-squares sense.

    Rows where any entry in *matrix* or the corresponding entry of *rhs* is
    non-finite are excluded before solving; for a two-dimensional *rhs* a row
    is excluded if any of its columns is non-finite. The returned solution
    always has ``n`` rows, where ``n`` is the number of columns in *matrix*.
    When the condition number of the valid sub-matrix (largest over smallest
    singular value, infinite if the smallest is zero) exceeds
    *cond_threshold*, an ``IllConditionedMatrixWarning`` is emitted.

    Args:
        matrix: Coefficient matrix of shape ``(m, n)``. Array-likes are
            converted with :func:`numpy.asarray`.
        rhs: Right-hand side of shape ``(m,)`` or ``(m, k)``. Array-likes are
            converted with :func:`numpy.asarray`.
        cond_threshold: Condition-number threshold above which a warning is
            emitted. Defaults to ``1e12``.

    Returns:
        A four-tuple ``(x, residuals, rank, sv)`` matching the convention of
        :func:`numpy.linalg.lstsq`:

        - ``x`` — least-squares solution of shape ``(n,)`` or ``(n, k)``;
          all NaN when no row is valid.
        - ``residuals`` — sums of squared residuals, shape ``(1,)`` for a
          one-dimensional *rhs* or ``(k,)`` for a two-dimensional one; empty
          when the valid sub-matrix has rank below ``n``, has no more rows
          than columns, or when no row is valid.
        - ``rank`` — effective rank of the valid sub-matrix (``0`` when no
          row is valid).
        - ``sv`` — singular values of the valid sub-matrix in descending
          order (empty when no row is valid).

    Raises:
        DimensionMismatchError: If the length of ``rhs`` does not match the
            number of rows in *matrix*.

    Example:
        >>> import numpy as np
        >>> from cvx.linalg import lstsq
        >>> A = np.array([[1.0, 1.0], [1.0, 2.0], [1.0, 3.0]])
        >>> b = np.array([6.0, 5.0, 7.0])
        >>> x, res, rank, sv = lstsq(A, b)
        >>> int(rank)
        2

        NaN rows are silently dropped:

        >>> A_nan = np.array([[1.0, 1.0], [np.nan, 2.0], [1.0, 3.0]])
        >>> b_nan = np.array([6.0, 5.0, 7.0])
        >>> x2, _, rank2, _ = lstsq(A_nan, b_nan)
        >>> int(rank2)
        2

        Several right-hand sides can be solved at once:

        >>> B = np.column_stack([b, 2.0 * b])
        >>> x3, _, _, _ = lstsq(A, B)
        >>> x3.shape
        (2, 2)
    """
    matrix = np.asarray(matrix)
    rhs = np.asarray(rhs)

    if rhs.shape[0] != matrix.shape[0]:
        raise DimensionMismatchError(rhs.shape[0], matrix.shape[0])

    n_cols = matrix.shape[1]

    # A row is usable only if every coefficient and every RHS value in it is
    # finite. For a multi-column rhs, collapse per-column finiteness to one
    # flag per row first; combining an (m, k) mask with an (m,) mask directly
    # would broadcast to the wrong shape.
    rhs_finite = np.isfinite(rhs)
    if rhs_finite.ndim > 1:
        rhs_finite = rhs_finite.all(axis=tuple(range(1, rhs_finite.ndim)))
    row_mask = np.isfinite(matrix).all(axis=1) & rhs_finite
    sub_matrix = matrix[row_mask]
    sub_rhs = rhs[row_mask]

    if sub_matrix.shape[0] == 0:
        # Nothing to solve: return an all-NaN solution shaped like a
        # successful solve would be, with empty residuals and singular values.
        x_nan = np.full((n_cols,) + rhs.shape[1:], np.nan)
        return x_nan, np.array([]), 0, np.array([])

    # rcond=None uses numpy's default singular-value cutoff
    # (machine precision times max(M, N)).
    x, residuals, rank, sv = np.linalg.lstsq(sub_matrix, sub_rhs, rcond=None)

    # Condition number from the descending singular values. An empty sv
    # (zero-column matrix) is treated as perfectly conditioned; a zero
    # smallest singular value means the sub-matrix is singular.
    if sv.size == 0:
        cond = 1.0
    elif sv[-1] > 0:
        cond = float(sv[0] / sv[-1])
    else:
        cond = float("inf")

    if cond > cond_threshold:
        warnings.warn(
            f"Matrix condition number {cond:.3e} exceeds threshold {cond_threshold:.3e}; "
            "results may be numerically unreliable.",
            IllConditionedMatrixWarning,
            stacklevel=2,
        )

    return x, residuals, rank, sv