Coverage for src/fast_minimum_variance/proximal.py: 100%
58 statements
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-02 13:28 +0000
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-02 13:28 +0000
1"""Fast solver for 0.5 ||mat @ x - vec||^2 s. t. {x >= 0, sum(x) = 1}.
3This module implements proximal gradient descent for constrained linear least squares
4optimization on the probability simplex. The algorithm is based on iterative projection
5using the efficient simplex projection from Duchi et al. (2008).
7The gradient is evaluated matrix-free as mat.T @ (mat @ w), avoiding explicit assembly
8of the n x n normal matrix mat.T @ mat. The Lipschitz constant is estimated via power
9iteration, also matrix-free.
11References:
12----------
13Duchi, J., Shalev-Shwartz, S., Singer, Y., & Chandra, T. (2008).
14"Efficient Projections onto the l1-Ball for Learning in High Dimensions."
15Proceedings of the 25th International Conference on Machine Learning (ICML).
16"""
18from __future__ import annotations
20from typing import TYPE_CHECKING, cast
22import numpy as np
23from cvx.linalg import power_iteration
25if TYPE_CHECKING:
26 from collections.abc import Callable
28 from numpy.typing import NDArray
def proj_simplex(
    vec: NDArray[np.floating],
    rad: float = 1.0,
) -> NDArray[np.floating]:
    """Project a vector onto the probability simplex.

    Computes the Euclidean projection of ``vec`` onto the scaled simplex
    ``{x : x >= 0, sum(x) = rad}``, i.e. the point of that set closest to
    ``vec`` in the 2-norm.

    The algorithm is based on Duchi et al. (2008) "Efficient Projections onto the
    l1-Ball for Learning in High Dimensions":
    1. Sort the entries in decreasing order.
    2. Find the largest index ``rho`` whose sorted entry still exceeds the
       running threshold ``(cumsum - rad) / (k + 1)``.
    3. Shift every entry by that threshold and clip at zero.

    The cost is O(n log n), dominated by the sort.

    Parameters
    ----------
    vec : NDArray[np.floating]
        One-dimensional input vector that is to be projected onto the simplex.
    rad : float, optional
        Radius of the simplex. The projected vector will have components summing
        to this value. Must be positive. Default is 1.0.

    Returns:
    -------
    NDArray[np.floating]
        The projected vector: non-negative entries summing to ``rad``.

    Raises:
    ------
    ValueError
        If the input vector is empty, or if no valid threshold exists. The
        latter happens for ``rad <= 0`` or when ``vec`` contains NaN or +inf.

    Examples:
    --------
    >>> import numpy as np
    >>> vec = np.array([1.0, 2.0, 3.0])
    >>> result = proj_simplex(vec)
    >>> bool(np.isclose(result.sum(), 1.0))
    True
    >>> bool(np.all(result >= 0))
    True

    """
    arr = np.asarray(vec)
    if arr.size == 0:
        raise ValueError("proj_simplex requires a non-empty input vector")

    # Entries sorted in decreasing order, and the candidate thresholds
    # theta_k = (sum of the k+1 largest entries - rad) / (k + 1).
    muu = np.sort(arr)[::-1]
    cummeans = (np.cumsum(muu) - rad) / np.arange(1, arr.shape[0] + 1)

    # Indices come back in increasing order, so the last one is rho.
    # Vectorised replacement for a Python-level max() over the index array.
    support = np.flatnonzero(muu > cummeans)
    if support.size == 0:
        # For finite input and rad > 0, index 0 always qualifies; reaching
        # this branch means rad <= 0 or non-finite entries.
        raise ValueError(
            "proj_simplex found no valid threshold; rad must be positive and "
            "vec must be finite"
        )

    result: NDArray[np.floating] = np.maximum(arr - cummeans[support[-1]], 0)
    return result
def _lipschitz(
    mat: NDArray[np.floating],
    extra_matvec: Callable[[NDArray[np.floating]], NDArray[np.floating]] | None = None,
    n_iter: int = 30,
    rng: np.random.Generator | None = None,
) -> float:
    """Return a power-iteration estimate of lambda_max(mat.T @ mat + extra).

    The normal operator ``v -> mat.T @ (mat @ v)`` is applied matrix-free,
    plus ``extra_matvec(v)`` when that callable is supplied. Each application
    therefore costs two products with ``mat``, and the cols x cols normal
    matrix is never built. The iteration itself is delegated to cvx-linalg's
    operator-aware ``power_iteration``.

    Parameters
    ----------
    mat : NDArray[np.floating]
        Matrix whose normal operator is probed. Its column count sets the
        dimension of the power-iteration vector.
    extra_matvec : callable, optional
        ``v -> extra @ v`` for a second contribution added to the normal
        operator (presumably symmetric PSD; not checked here).
    n_iter : int, optional
        Number of iterations forwarded to ``power_iteration``. Default is 30.
    rng : numpy.random.Generator, optional
        When given, one integer is drawn from it and passed as ``seed``.
        Otherwise ``seed=None`` is passed.

    Returns:
    -------
    float
        The eigenvalue estimate, floored at 0.0 via ``max(..., 0.0)``.

    NOTE(review): if ``power_iteration`` returns a Rayleigh-quotient style
    estimate, it approaches lambda_max from below after finitely many steps.
    Callers using ``1 / L`` as a step size may want a small safety margin.
    Confirm against cvx-linalg.

    """

    def apply_normal(vv: NDArray[np.floating]) -> NDArray[np.floating]:
        """Return (mat.T @ mat + extra) @ vv without forming mat.T @ mat."""
        gram_v = mat.T @ (mat @ vv)
        return gram_v if extra_matvec is None else gram_v + extra_matvec(vv)

    # Tie power_iteration's randomness to the caller's generator when one is given.
    if rng is None:
        seed = None
    else:
        seed = int(rng.integers(np.iinfo(np.int64).max))

    typed_op = cast("Callable[[NDArray[np.float64]], NDArray[np.float64]]", apply_normal)
    lam, _ = power_iteration(typed_op, n=mat.shape[1], n_iter=n_iter, seed=seed)
    return max(float(lam), 0.0)
def fista_gradient(
    mat: NDArray[np.floating],
    vec: NDArray[np.floating],
    *,
    extra_grad: Callable[[NDArray[np.floating]], NDArray[np.floating]] | None = None,
    eps_rel: float = 1e-6,
    max_iter: int = 100000,
) -> tuple[NDArray[np.floating], int]:
    r"""FISTA (Nesterov-accelerated proximal gradient) on the probability simplex.

    Same interface as ``prox_gradient``, but uses the Beck-Teboulle momentum
    sequence $t_{k+1} = (1 + \sqrt{1+4t_k^2})/2$. This guarantees an
    $O(1/k^2)$ rate in objective value for convex objectives, versus $O(1/k)$
    for plain proximal gradient descent.

    For strongly convex $f$ with condition number $\kappa$, the accelerated
    linear rate of order $(1 - 1/\sqrt{\kappa})^k$ is obtained with a
    $\kappa$-aware momentum or with adaptive restarting (O'Donoghue & Candès,
    2015). Neither is used here, so only the $O(1/k^2)$ guarantee applies.

    The gradient is evaluated at the extrapolated point $y_k$, and the simplex
    projection is applied to obtain $x_k$. The momentum update then forms
    $y_{k+1} = x_k + \frac{t_k-1}{t_{k+1}}(x_k - x_{k-1})$.

    Parameters
    ----------
    mat : NDArray[np.floating]
        A matrix of shape (n_samples, n_features).
    vec : NDArray[np.floating]
        A vector of shape (n_samples,).
    extra_grad : callable, optional
        v -> additional gradient term, e.g. ``alpha * target @ v`` for
        Ledoit-Wolf shrinkage. The linear operator it represents should be
        symmetric positive semi-definite for convergence guarantees. When
        provided, the Lipschitz estimate accounts for this term.
    eps_rel : float, optional
        Stopping tolerance on the Euclidean norm of the change between
        consecutive iterates ``x``. Despite its name, this is an absolute test.
        Default is 1e-6.
    max_iter : int, optional
        Maximum number of iterations. Default is 100000.

    Returns:
    -------
    tuple[NDArray[np.floating], int]
        ``(w, n_iters)``: the weight vector of shape (n_features,) and the
        number of gradient steps taken. When ``max_iter <= 0``, the result is
        the projected random starting point and 0.

    References:
    ----------
    Beck, A., & Teboulle, M. (2009). "A Fast Iterative Shrinkage-Thresholding
    Algorithm for Linear Inverse Problems." SIAM Journal on Imaging Sciences.

    O'Donoghue, B., & Candès, E. (2015). "Adaptive Restart for Accelerated
    Gradient Schemes." Foundations of Computational Mathematics.

    Examples:
    --------
    >>> import numpy as np
    >>> mat = np.array([[1.0, 0.5], [0.5, 1.0]])
    >>> vec = np.ones(2)
    >>> result, _ = fista_gradient(mat, vec)
    >>> bool(np.isclose(result.sum(), 1.0))
    True

    """
    rng = np.random.default_rng()
    lip = _lipschitz(mat, extra_matvec=extra_grad, rng=rng)
    # Fall back to a unit step when the operator is (numerically) zero.
    step = 1.0 / lip if lip > 1e-15 else 1.0
    # mat.T @ vec does not change across iterations; compute it once.
    out_prod = mat.T @ vec

    # Random feasible starting point; y is the extrapolated point.
    x = proj_simplex(np.asarray(rng.standard_normal(mat.shape[1])))
    y = x.copy()
    t = 1.0

    # Pre-initialised so that max_iter <= 0 returns (x, 0) instead of raising
    # UnboundLocalError at the ``return`` below.
    ite = 0
    for ite in range(1, max_iter + 1):  # noqa: B007
        grad = mat.T @ (mat @ y) - out_prod
        if extra_grad is not None:
            grad = grad + extra_grad(y)
        x_new = proj_simplex(y - step * grad)

        # Beck-Teboulle momentum sequence and extrapolation step.
        t_new = 0.5 * (1.0 + np.sqrt(1.0 + 4.0 * t * t))
        y = x_new + ((t - 1.0) / t_new) * (x_new - x)

        err = float(np.linalg.norm(x - x_new))
        x = x_new
        t = t_new

        if err < eps_rel:
            break

    return x, ite
def prox_gradient(
    mat: NDArray[np.floating],
    vec: NDArray[np.floating],
    *,
    extra_grad: Callable[[NDArray[np.floating]], NDArray[np.floating]] | None = None,
    eps_rel: float = 1e-6,
    max_iter: int = 100000,
) -> tuple[NDArray[np.floating], int]:
    """Perform proximal gradient descent to solve a constrained optimization problem.

    Solves the optimization problem:
        minimize 0.5 ||mat @ x - vec||^2 + g(x)
        subject to x >= 0, sum(x) = 1

    where the gradient of g is supplied via ``extra_grad``.

    The gradient is evaluated matrix-free at each step; the normal matrix
    mat.T @ mat is never formed. The Lipschitz constant is estimated once via
    power iteration at O(n_power_iter * rows * cols) setup cost. The step size
    is its reciprocal, or 1.0 if the estimate is (numerically) zero.

    Parameters
    ----------
    mat : NDArray[np.floating]
        A matrix of shape (n_samples, n_features).
    vec : NDArray[np.floating]
        A vector of shape (n_samples,).
    extra_grad : callable, optional
        v -> additional gradient term, e.g. ``alpha * target @ v`` for
        Ledoit-Wolf shrinkage. The linear operator it represents should be
        symmetric positive semi-definite for convergence guarantees. When
        provided, the Lipschitz estimate accounts for this term.
    eps_rel : float, optional
        Stopping tolerance on the Euclidean norm of the change between
        consecutive iterates. Iteration stops once that norm is <= eps_rel.
        Despite its name, this is an absolute test. Default is 1e-6.
    max_iter : int, optional
        Maximum number of iterations. Default is 100000.

    Returns:
    -------
    tuple[NDArray[np.floating], int]
        ``(w, n_iters)``: the weight vector of shape (n_features,) on the
        simplex, and the number of gradient steps taken.

    Examples:
    --------
    >>> import numpy as np
    >>> mat = np.array([[1.0, 0.5], [0.5, 1.0]])
    >>> vec = np.ones(2)
    >>> result, _ = prox_gradient(mat, vec)
    >>> bool(np.isclose(result.sum(), 1.0))
    True

    """
    rng = np.random.default_rng()
    # Random start, projected onto the simplex so the returned point is feasible
    # even if no iteration runs (max_iter <= 0). This matches the starting
    # point used by fista_gradient.
    prim_var: NDArray[np.floating] = proj_simplex(
        np.asarray(rng.standard_normal(size=mat.shape[1]))
    )
    lip = _lipschitz(mat, extra_matvec=extra_grad, rng=rng)
    step = 1.0 / lip if lip > 1e-15 else 1.0

    # Precompute mat.T @ vec once; zero for minimum-variance (vec = 0).
    out_prod = mat.T @ vec
    ite = 0
    err_rel = eps_rel + 1
    while err_rel > eps_rel and ite < max_iter:
        grad = mat.T @ (mat @ prim_var) - out_prod
        if extra_grad is not None:
            grad = grad + extra_grad(prim_var)
        prim_var_new = proj_simplex(prim_var - step * grad)
        err_rel = float(np.linalg.norm(prim_var - prim_var_new, 2))
        # proj_simplex returns a freshly allocated array (np.maximum), so no
        # defensive copy is needed here.
        prim_var = prim_var_new
        ite += 1
    return prim_var, ite