Coverage for src / cvx / linalg / lstsq.py: 100%

24 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-19 05:40 +0000

1"""Least-squares solver with NaN-aware row filtering.""" 

2 

3from __future__ import annotations 

4 

5import warnings 

6 

7import numpy as np 

8 

9from .exceptions import DimensionMismatchError, IllConditionedMatrixWarning 

10 

# Fallback value for the ``cond_threshold`` parameter of :func:`lstsq`.
_DEFAULT_COND_THRESHOLD: float = 1.0e12
"""Condition number beyond which ``lstsq`` warns by default."""

13 

14 

def lstsq(
    matrix: np.ndarray,
    rhs: np.ndarray,
    cond_threshold: float = _DEFAULT_COND_THRESHOLD,
) -> tuple[np.ndarray, np.ndarray, int, np.ndarray]:
    """Solve an overdetermined or underdetermined system in the least-squares sense.

    Rows where any entry in *matrix* or the corresponding entry (or row) in
    *rhs* is non-finite (NaN or infinite) are excluded before solving. The
    returned solution always has as many rows as *matrix* has columns. When
    the condition number of the valid sub-matrix exceeds *cond_threshold*,
    an ``IllConditionedMatrixWarning`` is emitted.

    Args:
        matrix: Coefficient matrix of shape ``(m, n)``. Array-likes such as
            nested lists are converted with :func:`numpy.asarray`.
        rhs: Right-hand side of shape ``(m,)``, or ``(m, k)`` to solve for
            ``k`` right-hand sides at once. Array-likes are accepted.
        cond_threshold: Condition-number threshold above which a warning is
            emitted. Defaults to ``1e12``.

    Returns:
        A four-tuple ``(x, residuals, rank, sv)`` matching the convention of
        :func:`numpy.linalg.lstsq`:

        - ``x`` — least-squares solution of shape ``(n,)`` (or ``(n, k)`` for
          a two-dimensional *rhs*); filled with NaN when every row is invalid.
        - ``residuals`` — sum of squared residuals; empty when the solution is
          not unique or all rows are invalid.
        - ``rank`` — effective rank of the valid sub-matrix (``0`` when all
          rows are invalid).
        - ``sv`` — singular values of the valid sub-matrix in descending order
          (empty when all rows are invalid).

    Raises:
        DimensionMismatchError: If the first dimension of ``rhs`` does not
            match the number of rows in *matrix*.

    Warns:
        IllConditionedMatrixWarning: If the ratio of the largest to the
            smallest singular value exceeds *cond_threshold*.

    Example:
        >>> import numpy as np
        >>> from cvx.linalg import lstsq
        >>> A = np.array([[1.0, 1.0], [1.0, 2.0], [1.0, 3.0]])
        >>> b = np.array([6.0, 5.0, 7.0])
        >>> x, res, rank, sv = lstsq(A, b)
        >>> int(rank)
        2

        NaN rows are silently dropped:

        >>> A_nan = np.array([[1.0, 1.0], [np.nan, 2.0], [1.0, 3.0]])
        >>> b_nan = np.array([6.0, 5.0, 7.0])
        >>> x2, _, rank2, _ = lstsq(A_nan, b_nan)
        >>> int(rank2)
        2
    """
    # Accept array-likes; ndarray inputs pass through without a copy.
    matrix = np.asarray(matrix)
    rhs = np.asarray(rhs)

    if rhs.shape[0] != matrix.shape[0]:
        raise DimensionMismatchError(rhs.shape[0], matrix.shape[0])

    n_cols = matrix.shape[1]

    # Build a per-row validity mask. A two-dimensional rhs must be reduced
    # along its columns first, otherwise ``&`` would broadcast an (m,) mask
    # against an (m, k) mask instead of combining them row by row.
    rhs_finite = np.isfinite(rhs)
    if rhs_finite.ndim == 2:
        rhs_finite = rhs_finite.all(axis=1)
    row_mask = np.isfinite(matrix).all(axis=1) & rhs_finite

    sub_matrix = matrix[row_mask]
    sub_rhs = rhs[row_mask]

    if sub_matrix.shape[0] == 0:
        # Nothing left to fit: return a NaN solution with the same shape the
        # solver would have produced, plus empty residuals / singular values.
        x_shape = (n_cols,) + rhs.shape[1:]
        return np.full(x_shape, np.nan), np.array([]), 0, np.array([])

    x, residuals, rank, sv = np.linalg.lstsq(sub_matrix, sub_rhs, rcond=None)

    # Condition number = largest / smallest singular value. A non-positive
    # smallest singular value means the matrix is singular (infinite
    # condition number); no singular values at all is treated as perfectly
    # conditioned so that no warning is raised.
    if sv.size == 0:
        cond = 1.0
    elif sv[-1] > 0:
        cond = float(sv[0] / sv[-1])
    else:
        cond = float("inf")

    if cond > cond_threshold:
        # stacklevel=2 attributes the warning to the caller of lstsq().
        warnings.warn(
            f"Matrix condition number {cond:.3e} exceeds threshold {cond_threshold:.3e}; "
            "results may be numerically unreliable.",
            IllConditionedMatrixWarning,
            stacklevel=2,
        )

    return x, residuals, rank, sv