Coverage for src/cvx/linalg/core/valid.py: 100%

10 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-03 18:56 +0000

1"""Matrix validation utilities for handling non-finite values. 

2 

This module provides a helper for validating and cleaning square matrices
(typically covariance matrices) whose diagonal may contain non-finite values
(NaN or infinity). This is particularly useful when working with financial
data, where missing values are common.

6 

7Example: 

8 Extract the valid submatrix from a covariance matrix with missing data: 

9 

10 >>> import numpy as np 

11 >>> from cvx.linalg import valid 

12 >>> # Create a covariance matrix with some NaN values on diagonal 

13 >>> cov = np.array([[np.nan, 0.5, 0.2], 

14 ... [0.5, 2.0, 0.3], 

15 ... [0.2, 0.3, np.nan]]) 

16 >>> # Get valid indicator and submatrix 

17 >>> v, submatrix = valid(cov) 

18 >>> v # Second row/column is valid 

19 array([False, True, False]) 

20 >>> submatrix 

21 array([[2.]]) 

22 

23""" 

24 

25from __future__ import annotations 

26 

27import numpy as np 

28import numpy.typing as npt 

29 

30from .exceptions import NonSquareMatrixError 

31from .types import Matrix 

32 

33 

def valid(matrix: Matrix) -> tuple[npt.NDArray[np.bool_], Matrix]:
    """Extract the valid submatrix of a square matrix based on its diagonal.

    Row/column ``i`` is considered valid when the diagonal entry
    ``matrix[i, i]`` is finite (neither NaN nor +/-infinity). Rows and columns
    whose diagonal entry is non-finite are removed. The function returns both
    the boolean indicator vector and the resulting submatrix.

    This is useful when working with covariance matrices where some assets
    may have missing or invalid data.

    Args:
        matrix: A square n x n matrix to be validated, typically a covariance
            or correlation matrix. Anything accepted by ``np.asanyarray``
            (e.g. a nested list) is accepted; ndarray inputs are used as-is
            without copying.

    Returns:
        A tuple containing:
            - v: Boolean vector of shape (n,) indicating which rows/columns are
              valid (True for valid, False for invalid).
            - submatrix: The submatrix with invalid rows/columns removed.
              Shape is (k, k) where k is the number of True values in v.

    Raises:
        ValueError: If the input is not two-dimensional.
        NonSquareMatrixError: If the input matrix is not square (n x n).

    Example:
        Basic usage with a covariance matrix:

        >>> import numpy as np
        >>> from cvx.linalg import valid
        >>> # Create a 3x3 matrix with one invalid diagonal entry
        >>> cov = np.array([[1.0, 0.5, 0.2],
        ...                 [0.5, np.nan, 0.3],
        ...                 [0.2, 0.3, 1.0]])
        >>> v, submatrix = valid(cov)
        >>> v
        array([ True, False,  True])
        >>> submatrix
        array([[1. , 0.2],
               [0.2, 1. ]])

        Handling a fully valid matrix:

        >>> cov = np.array([[1.0, 0.5], [0.5, 1.0]])
        >>> v, submatrix = valid(cov)
        >>> v
        array([ True,  True])
        >>> np.allclose(submatrix, cov)
        True

        Handling infinity values:

        >>> cov = np.array([[1.0, 0.5, 0.2],
        ...                 [0.5, np.inf, 0.3],
        ...                 [0.2, 0.3, 1.0]])
        >>> v, submatrix = valid(cov)
        >>> v
        array([ True, False,  True])
        >>> submatrix
        array([[1. , 0.2],
               [0.2, 1. ]])

        Multiple invalid entries:

        >>> cov = np.array([[np.nan, 0.1, 0.2, 0.3],
        ...                 [0.1, 2.0, 0.4, 0.5],
        ...                 [0.2, 0.4, np.nan, 0.6],
        ...                 [0.3, 0.5, 0.6, 3.0]])
        >>> v, submatrix = valid(cov)
        >>> v
        array([False,  True, False,  True])
        >>> submatrix.shape
        (2, 2)
        >>> submatrix
        array([[2. , 0.5],
               [0.5, 3. ]])

        Plain nested lists are accepted as well:

        >>> v, submatrix = valid([[4.0, 1.0], [1.0, np.nan]])
        >>> v
        array([ True, False])
        >>> submatrix
        array([[4.]])

        Non-square matrix raises NonSquareMatrixError:

        >>> try:
        ...     valid(np.array([[1, 2, 3], [4, 5, 6]]))
        ... except NonSquareMatrixError:
        ...     print("Caught NonSquareMatrixError for non-square matrix")
        Caught NonSquareMatrixError for non-square matrix

    Note:
        Only the diagonal elements are checked for validity. Non-finite
        off-diagonal entries in a row/column whose diagonal is finite are
        kept as-is in the returned submatrix. Treating a finite diagonal as
        marking the whole row/column valid is a common assumption for
        covariance matrices.

    """
    # Accept array-likes (e.g. nested lists); ndarrays and their subclasses
    # pass through unchanged, so existing callers see no difference.
    arr = np.asanyarray(matrix)

    # Check dimensionality before reading shape[1]. A 1-D input would
    # otherwise fail with an unhelpful IndexError.
    if arr.ndim != 2:
        msg = f"Expected a 2-dimensional matrix, got {arr.ndim} dimension(s)."
        raise ValueError(msg)

    rows, cols = arr.shape
    if rows != cols:
        raise NonSquareMatrixError(rows, cols)

    v = np.isfinite(np.diag(arr))
    # np.ix_ selects the flagged rows AND columns in a single indexing step.
    return v, arr[np.ix_(v, v)]
62 >>> from cvx.linalg import valid 

63 >>> # Create a 3x3 matrix with one invalid entry 

64 >>> cov = np.array([[1.0, 0.5, 0.2], 

65 ... [0.5, np.nan, 0.3], 

66 ... [0.2, 0.3, 1.0]]) 

67 >>> v, submatrix = valid(cov) 

68 >>> v 

69 array([ True, False, True]) 

70 >>> submatrix 

71 array([[1. , 0.2], 

72 [0.2, 1. ]]) 

73 

74 Handling a fully valid matrix: 

75 

76 >>> cov = np.array([[1.0, 0.5], [0.5, 1.0]]) 

77 >>> v, submatrix = valid(cov) 

78 >>> v 

79 array([ True, True]) 

80 >>> np.allclose(submatrix, cov) 

81 True 

82 

83 Handling infinity values: 

84 

85 >>> cov = np.array([[1.0, 0.5, 0.2], 

86 ... [0.5, np.inf, 0.3], 

87 ... [0.2, 0.3, 1.0]]) 

88 >>> v, submatrix = valid(cov) 

89 >>> v 

90 array([ True, False, True]) 

91 >>> submatrix 

92 array([[1. , 0.2], 

93 [0.2, 1. ]]) 

94 

95 Multiple invalid entries: 

96 

97 >>> cov = np.array([[np.nan, 0.1, 0.2, 0.3], 

98 ... [0.1, 2.0, 0.4, 0.5], 

99 ... [0.2, 0.4, np.nan, 0.6], 

100 ... [0.3, 0.5, 0.6, 3.0]]) 

101 >>> v, submatrix = valid(cov) 

102 >>> v 

103 array([False, True, False, True]) 

104 >>> submatrix.shape 

105 (2, 2) 

106 >>> submatrix 

107 array([[2. , 0.5], 

108 [0.5, 3. ]]) 

109 

110 Non-square matrix raises NonSquareMatrixError: 

111 

112 >>> try: 

113 ... valid(np.array([[1, 2, 3], [4, 5, 6]])) 

114 ... except NonSquareMatrixError: 

115 ... print("Caught NonSquareMatrixError for non-square matrix") 

116 Caught NonSquareMatrixError for non-square matrix 

117 

118 Note: 

119 The function checks only the diagonal elements for validity. It assumes 

120 that if the diagonal is finite, the entire row/column is valid. This is 

121 a common assumption for covariance matrices. 

122 

123 """ 

124 if matrix.shape[0] != matrix.shape[1]: 

125 raise NonSquareMatrixError(matrix.shape[0], matrix.shape[1]) 

126 

127 v = np.isfinite(np.diag(matrix)) 

128 return v, matrix[:, v][v]