Coverage for src/cvx/linalg/solve/solve.py: 100%
21 statements
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-03 18:56 +0000
1"""Linear-system helpers that ignore matrix rows and columns with non-finite diagonals."""
3from __future__ import annotations
5import numpy as np
7from ..core.exceptions import (
8 DEFAULT_COND_THRESHOLD,
9 DimensionMismatchError,
10 NonSquareMatrixError,
11 SingularMatrixError,
12)
13from ..core.exceptions import (
14 check_and_warn_condition as _check_and_warn_condition,
15)
16from ..core.types import Matrix, Vector
17from ..core.valid import valid
18from ..decomposition.cholesky import cholesky_solve as _cholesky_solve
def solve(
    matrix: Matrix,
    rhs: Vector | Matrix,
    cond_threshold: float = DEFAULT_COND_THRESHOLD,
) -> Vector | Matrix:
    """Solve ``matrix @ x = rhs`` using only the valid rows and columns.

    An index ``i`` is dropped from the system when the diagonal entry
    ``matrix[i, i]`` is not finite; the selection itself is made by
    ``valid``. The reduced system is handed to ``cholesky_solve``, which
    tries a Cholesky factorisation first and falls back to LU decomposition
    for sub-matrices that are not positive definite. Before solving, the
    reduced matrix is passed to ``check_and_warn_condition``, which emits an
    ``IllConditionedMatrixWarning`` when its condition number exceeds
    *cond_threshold*. Every entry of the result that belongs to a dropped
    index is left as ``NaN``.

    Args:
        matrix: Coefficient matrix; must be square with shape ``(n, n)``.
        rhs: Right-hand side, either a vector of length ``n`` or a matrix
            of shape ``(n, k)``.
        cond_threshold: Condition number above which the warning is
            emitted. Defaults to ``DEFAULT_COND_THRESHOLD``.

    Returns:
        A floating-point array with the same shape as *rhs*. Rows that
        belong to dropped indices are ``NaN``. If no index is valid, the
        whole result is ``NaN`` and no solve is attempted.

    Raises:
        NonSquareMatrixError: *matrix* has a different number of rows and
            columns.
        DimensionMismatchError: ``rhs.shape[0]`` differs from the size of
            *matrix*.
        SingularMatrixError: Solving the reduced system raised
            ``numpy.linalg.LinAlgError`` (for example, because it is
            singular). The numpy error is chained as the cause.

    Example:
        >>> import numpy as np
        >>> from cvx.linalg import solve
        >>> solve(np.eye(2), np.array([1.0, 2.0])).tolist()
        [1.0, 2.0]

        NaN-masked entries are skipped:

        >>> matrix = np.array([[4.0, 0.0], [0.0, np.nan]])
        >>> solve(matrix, np.array([8.0, 1.0])).tolist()
        [2.0, nan]

        Matrix right-hand sides are supported:

        >>> solve(np.eye(2), np.array([[1.0, 2.0], [3.0, 4.0]])).tolist()
        [[1.0, 2.0], [3.0, 4.0]]
    """
    size = matrix.shape[0]
    width = matrix.shape[1]
    if size != width:
        raise NonSquareMatrixError(size, width)

    leading = rhs.shape[0]
    if leading != size:
        raise DimensionMismatchError(leading, size)

    # Begin with an all-NaN answer; only the valid rows are overwritten below.
    result = np.full(rhs.shape, np.nan)

    keep, reduced = valid(matrix)
    if not keep.any():
        # No usable diagonal entry at all: nothing to solve.
        return result

    # Only warns about a poorly conditioned reduced system before solving it.
    _check_and_warn_condition(reduced, cond_threshold)

    try:
        partial = _cholesky_solve(reduced, rhs[keep])
    except np.linalg.LinAlgError as err:
        # Surface numpy's failure as the package's own exception type,
        # keeping the original error as the cause.
        raise SingularMatrixError(str(err)) from err

    result[keep] = partial
    return result