Commit c2aecba

gram solver && unit test
1 parent b8fd539 commit c2aecba

File tree

2 files changed: +166 -0 lines changed

skglm/solvers/gram.py

Lines changed: 109 additions & 0 deletions
import numpy as np
from numba import njit


def gram_solver(X, y, penalty, max_iter=20, max_epoch=1000, p0=10, tol=1e-4,
                verbose=False):
    """Run a Gram solver by reformulating the problem as follows.

    Minimize::
        w.T @ Q @ w / (2*n_samples) - b.T @ w / n_samples + penalty(w)

    where::
        Q = X.T @ X
        b = X.T @ y

    This matches the usual least-squares objective up to the constant
    ``y @ y / (2 * n_samples)``.
    """
    n_features = X.shape[1]
    XtX = X.T @ X
    Xty = X.T @ y
    all_features = np.arange(n_features)
    p_objs_out = []

    w = np.zeros(n_features)
    XtXw = np.zeros(n_features)

    for t in range(max_iter):
        # compute scores
        grad = _construct_grad(y, XtXw, Xty, all_features)
        opt = penalty.subdiff_distance(w, grad, all_features)

        # check convergence
        stop_crit = np.max(opt)
        if verbose:
            p_obj = _quadratic_value(X, w, XtXw, Xty) + penalty.value(w)
            print(
                f"Iteration {t+1}: {p_obj:.10f}, "
                f"stopping crit: {stop_crit:.2e}"
            )

        if stop_crit <= tol:
            if verbose:
                print(f"Stopping criterion max violation: {stop_crit:.2e}")
            break

        # build working set (ws)
        gsupp_size = penalty.generalized_support(w).sum()
        ws_size = max(min(p0, n_features),
                      min(n_features, 2 * gsupp_size))
        # similar to np.argsort(opt)[-ws_size:] but without full sorting
        ws = np.argpartition(opt, -ws_size)[-ws_size:]
        tol_in = 0.3 * stop_crit

        for epoch in range(max_epoch):
            # inplace update of w, XtXw
            _gram_cd_epoch(y, XtX, Xty, w, XtXw, penalty, ws)

            if epoch % 10 == 0:
                grad = _construct_grad(y, XtXw, Xty, ws)
                opt_in = penalty.subdiff_distance(w, grad, ws)

                stop_crit_in = np.max(opt_in)
                if verbose > 1:
                    p_obj = _quadratic_value(X, w, XtXw, Xty) + penalty.value(w)
                    print(
                        f"Epoch {epoch+1}: {p_obj:.10f}, "
                        f"stopping crit in: {stop_crit_in:.2e}"
                    )

                if stop_crit_in <= tol_in:
                    if verbose > 1:
                        print("Early exit")
                    break

        p_obj = _quadratic_value(X, w, XtXw, Xty) + penalty.value(w)
        p_objs_out.append(p_obj)
    return w, p_objs_out, stop_crit


@njit
def _gram_cd_epoch(y, XtX, Xty, w, XtXw, penalty, ws):
    # inplace update of w, XtXw
    for j in ws:
        # skip if X[:, j] == 0
        if XtX[j, j] == 0:
            continue

        old_w_j = w[j]
        grad_j = (XtXw[j] - Xty[j]) / len(y)
        step = 1 / XtX[j, j]  # 1 / lipschitz_j

        w[j] = penalty.prox_1d(old_w_j - step * grad_j, step, j)

        # Gram matrix update
        if w[j] != old_w_j:
            XtXw += (w[j] - old_w_j) * XtX[:, j]


@njit
def _construct_grad(y, XtXw, Xty, ws):
    n_samples = len(y)
    grad = np.zeros(len(ws))
    for idx, j in enumerate(ws):
        grad[idx] = (XtXw[j] - Xty[j]) / n_samples
    return grad


@njit
def _quadratic_value(X, w, XtXw, Xty):
    n_samples = X.shape[0]
    return w @ XtXw / (2 * n_samples) - Xty @ w / n_samples
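
For illustration, a minimal usage sketch (it assumes only the skglm helpers make_correlated_data, compiled_clone and the L1 penalty, which also appear in the test file below; the problem size and the alpha value are arbitrary). It fits an L1 problem with gram_solver and checks that the Gram objective agrees with the usual least-squares objective up to the constant y @ y / (2 * n_samples):

import numpy as np
from numpy.linalg import norm

from skglm.penalties import L1
from skglm.solvers.gram import gram_solver
from skglm.utils import make_correlated_data, compiled_clone

X, y, _ = make_correlated_data(100, 40, random_state=0)
n_samples = len(y)
# arbitrary fraction of alpha_max, chosen for illustration
alpha = 0.1 * norm(X.T @ y, ord=np.inf) / n_samples

penalty = compiled_clone(L1(alpha))
w, p_objs, stop_crit = gram_solver(X, y, penalty, tol=1e-8)

# the two formulations differ only by the constant y @ y / (2 * n_samples)
lsq_obj = norm(y - X @ w) ** 2 / (2 * n_samples) + penalty.value(w)
gram_obj = (w @ (X.T @ X) @ w / (2 * n_samples) - (X.T @ y) @ w / n_samples
            + y @ y / (2 * n_samples) + penalty.value(w))
np.testing.assert_allclose(lsq_obj, gram_obj)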

skglm/tests/test_gram_solver.py

Lines changed: 57 additions & 0 deletions
import pytest
from itertools import product

import numpy as np
from numpy.linalg import norm
from sklearn.linear_model import Lasso

from skglm.penalties import L1
from skglm.solvers.gram import gram_solver
from skglm.utils import make_correlated_data, compiled_clone


@pytest.mark.parametrize("n_samples, n_features",
                         product([100, 200], [50, 90]))
def test_alpha_max(n_samples, n_features):
    X, y, _ = make_correlated_data(n_samples, n_features, random_state=0)
    alpha_max = norm(X.T @ y, ord=np.inf) / n_samples

    l1_penalty = compiled_clone(L1(alpha_max))
    w = gram_solver(X, y, l1_penalty, tol=1e-9, verbose=2)[0]

    np.testing.assert_equal(w, 0)


@pytest.mark.parametrize("n_samples, n_features, rho",
                         product([50, 100], [20, 80], [1e-1, 1e-2]))
def test_vs_lasso_sklearn(n_samples, n_features, rho):
    X, y, _ = make_correlated_data(n_samples, n_features, random_state=0)
    alpha_max = norm(X.T @ y, ord=np.inf) / n_samples
    alpha = rho * alpha_max

    sk_lasso = Lasso(alpha, fit_intercept=False, tol=1e-9)
    sk_lasso.fit(X, y)

    l1_penalty = compiled_clone(L1(alpha))
    w = gram_solver(X, y, l1_penalty, tol=1e-9, verbose=0, p0=10)[0]

    print(
        f"skglm: {compute_obj(X, y, alpha, w)}\n"
        f"sklearn: {compute_obj(X, y, alpha, sk_lasso.coef_.flatten())}"
    )

    np.testing.assert_allclose(w, sk_lasso.coef_.flatten(), rtol=1e-5, atol=1e-5)


def compute_obj(X, y, alpha, coef):
    return norm(y - X @ coef) ** 2 / (2 * len(y)) + alpha * norm(coef, ord=1)


if __name__ == '__main__':
    test_vs_lasso_sklearn(50, 80, 0.01)
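
As a side note, test_alpha_max relies on the standard KKT fact that w = 0 solves the Lasso as soon as alpha >= norm(X.T @ y, ord=np.inf) / n_samples. A quick, illustrative sanity sketch of that fact using scikit-learn alone (the problem size is arbitrary):

import numpy as np
from numpy.linalg import norm
from sklearn.linear_model import Lasso
from skglm.utils import make_correlated_data

X, y, _ = make_correlated_data(100, 50, random_state=0)
alpha_max = norm(X.T @ y, ord=np.inf) / len(y)

# at alpha = alpha_max the subdifferential condition holds at 0,
# so the solution is exactly the zero vector
clf = Lasso(alpha=alpha_max, fit_intercept=False, tol=1e-9).fit(X, y)
np.testing.assert_equal(clf.coef_, 0)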
