Coverage for src/cvx/linalg/inv.py: 100%

22 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-19 05:40 +0000

1"""Matrix inversion helpers that ignore rows and columns with non-finite diagonals.""" 

2 

3from __future__ import annotations 

4 

5import numpy as np 

6 

7from .exceptions import ( 

8 NonSquareMatrixError, 

9 SingularMatrixError, 

10) 

11from .exceptions import ( 

12 check_and_warn_condition as _check_and_warn_condition, 

13) 

14from .valid import valid 

15 

16_DEFAULT_COND_THRESHOLD: float = 1e12 

17"""Default condition-number threshold above which a warning is emitted.""" 

18 

19 

def inv(
    matrix: np.ndarray,
    cond_threshold: float = _DEFAULT_COND_THRESHOLD,
) -> np.ndarray:
    """Invert a matrix restricted to the valid submatrix.

    Rows and columns with non-finite diagonal entries are excluded from the
    inversion; the corresponding rows and columns in the result are set to NaN.
    When the condition number of the valid sub-matrix exceeds *cond_threshold*,
    an ``IllConditionedMatrixWarning`` is emitted.

    Args:
        matrix: Square matrix to invert.
        cond_threshold: Condition-number threshold above which a warning is
            emitted. Defaults to ``1e12``.

    Returns:
        An inverted matrix with the same shape as *matrix*. Rows and columns
        mapped to invalid entries are returned as ``NaN``. Real (and integer
        or boolean) inputs produce a ``float64`` result; complex inputs
        produce a complex result so the imaginary part of the inverse is kept.

    Raises:
        NonSquareMatrixError: If the matrix is not square.
        SingularMatrixError: If the valid sub-matrix is singular.

    Example:
        >>> import numpy as np
        >>> from cvx.linalg import inv
        >>> np.allclose(inv(np.eye(2)), np.eye(2))
        True

        NaN-masked entries are skipped:

        >>> matrix = np.array([[4.0, 0.0], [0.0, np.nan]])
        >>> result = inv(matrix)
        >>> float(result[0, 0])
        0.25
        >>> bool(np.isnan(result[0, 1]) and np.isnan(result[1, 0]) and np.isnan(result[1, 1]))
        True
    """
    if matrix.shape[0] != matrix.shape[1]:
        raise NonSquareMatrixError(matrix.shape[0], matrix.shape[1])

    n = matrix.shape[0]
    # Promote to at least float64 (the previous, unconditional result dtype
    # for real inputs) while keeping complex dtypes: writing a complex
    # inverse into a float64 buffer would silently drop its imaginary part.
    dtype = np.result_type(matrix.dtype, np.float64)
    result = np.full((n, n), np.nan, dtype=dtype)
    mask, submatrix = valid(matrix)

    if not mask.any():
        # No valid rows/columns: every entry of the result stays NaN.
        return result

    _check_and_warn_condition(submatrix, cond_threshold)
    try:
        sub_inv = np.linalg.inv(submatrix)
    except np.linalg.LinAlgError as exc:
        raise SingularMatrixError(str(exc)) from exc

    # Scatter the sub-inverse back into the positions of the valid rows/columns.
    idx = np.flatnonzero(mask)
    result[np.ix_(idx, idx)] = sub_inv
    return result