From 999b89ec44538f5126bbbc029938c99dd4175c35 Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Thu, 21 Apr 2022 00:07:44 +0200 Subject: [PATCH 01/21] add gram_solver --- skglm/gram_solver.py | 170 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 170 insertions(+) create mode 100644 skglm/gram_solver.py diff --git a/skglm/gram_solver.py b/skglm/gram_solver.py new file mode 100644 index 000000000..0dbd60b53 --- /dev/null +++ b/skglm/gram_solver.py @@ -0,0 +1,170 @@ +from time import time +import numpy as np +from numpy.linalg import norm +from numba import njit +from celer import Lasso, GroupLasso +from benchopt.datasets.simulated import make_correlated_data +from skglm.utils import BST, ST + + +def _grp_converter(groups, n_features): + if isinstance(groups, int): + grp_size = groups + if n_features % grp_size != 0: + raise ValueError("n_features (%d) is not a multiple of the desired" + " group size (%d)" % (n_features, grp_size)) + n_groups = n_features // grp_size + grp_ptr = grp_size * np.arange(n_groups + 1) + grp_indices = np.arange(n_features) + elif isinstance(groups, list) and isinstance(groups[0], int): + grp_indices = np.arange(n_features).astype(np.int32) + grp_ptr = np.cumsum(np.hstack([[0], groups])) + elif isinstance(groups, list) and isinstance(groups[0], list): + grp_sizes = np.array([len(ls) for ls in groups]) + grp_ptr = np.cumsum(np.hstack([[0], grp_sizes])) + grp_indices = np.array([idx for grp in groups for idx in grp]) + else: + raise ValueError("Unsupported group format.") + return grp_ptr.astype(np.int32), grp_indices.astype(np.int32) + + +@njit +def primal(alpha, y, X, w): + r = y - X @ w + p_obj = (r @ r) / (2 * len(y)) + return p_obj + alpha * np.sum(np.abs(w)) + + +@njit +def primal_grp(alpha, y, X, w, grp_ptr, grp_indices): + r = y - X @ w + p_obj = (r @ r) / (2 * len(y)) + for g in range(len(grp_ptr) - 1): + w_g = w[grp_indices[grp_ptr[g]:grp_ptr[g + 1]]] + p_obj += alpha * norm(w_g, ord=2) + return p_obj + + +@njit +def cd_epoch(X, G, grads, w, alpha, lipschitz): + n_features = X.shape[1] + for j in range(n_features): + if lipschitz[j] == 0.: + continue + old_w_j = w[j] + w[j] = ST(w[j] + grads[j] / lipschitz[j], alpha / lipschitz[j]) + if old_w_j != w[j]: + grads += G[j, :] * (old_w_j - w[j]) / len(X) + + +@njit +def bcd_epoch(X, G, grads, w, alpha, lipschitz, grp_indices, grp_ptr): + n_groups = len(grp_ptr) - 1 + for g in range(n_groups): + if lipschitz[g] == 0.: + continue + idx = grp_indices[grp_ptr[g]:grp_ptr[g + 1]] + old_w_g = w[idx].copy() + w[idx] = BST(w[idx] + grads[idx] / lipschitz[g], alpha / lipschitz[g]) + diff = old_w_g - w[idx] + if np.any(diff != 0.): + grads += diff @ G[idx, :] / len(X) + + +def lasso(X, y, alpha, max_iter, tol, check_freq=10): + p_obj_prev = np.inf + n_features = X.shape[1] + # Initialization + grads = X.T @ y / len(y) + G = X.T @ X + lipschitz = np.zeros(n_features, dtype=X.dtype) + for j in range(n_features): + lipschitz[j] = (X[:, j] ** 2).sum() / len(y) + w = np.zeros(n_features) + # CD + for n_iter in range(max_iter): + cd_epoch(X, G, grads, w, alpha, lipschitz) + if n_iter % check_freq == 0: + p_obj = primal(alpha, y, X, w) + if p_obj_prev - p_obj < tol: + print("Convergence reached!") + break + print(f"iter {n_iter} :: p_obj {p_obj}") + p_obj_prev = p_obj + return w + + +def group_lasso(X, y, alpha, groups, max_iter, tol, check_freq=50): + p_obj_prev = np.inf + n_features = X.shape[1] + grp_ptr, grp_indices = _grp_converter(groups, X.shape[1]) + n_groups = len(grp_ptr) - 1 + # Initialization + 
grads = X.T @ y / len(y) + G = X.T @ X + lipschitz = np.zeros(n_groups, dtype=X.dtype) + for g in range(n_groups): + X_g = X[:, grp_indices[grp_ptr[g]:grp_ptr[g + 1]]] + lipschitz[g] = norm(X_g, ord=2) ** 2 / len(y) + w = np.zeros(n_features) + # BCD + for n_iter in range(max_iter): + bcd_epoch(X, G, grads, w, alpha, lipschitz, grp_indices, grp_ptr) + if n_iter % check_freq == 0: + p_obj = primal_grp(alpha, y, X, w, grp_ptr, grp_indices) + if p_obj_prev - p_obj < tol: + print("Convergence reached!") + break + print(f"iter {n_iter} :: p_obj {p_obj}") + p_obj_prev = p_obj + return w + + +n_samples, n_features = 1_000_000, 300 +X, y, w_star = make_correlated_data( + n_samples=n_samples, n_features=n_features, random_state=0) +alpha_max = norm(X.T @ y, ord=np.inf) + +# Hyperparameters +max_iter = 1000 +tol = 1e-8 +reg = 0.1 +group_size = 3 + +alpha = alpha_max * reg / n_samples + +# Lasso +print("#" * 15) +print("Lasso") +print("#" * 15) +start = time() +w = lasso(X, y, alpha, max_iter, tol) +gram_lasso_time = time() - start +clf_sk = Lasso(alpha, tol=tol, fit_intercept=False) +start = time() +clf_sk.fit(X, y) +celer_lasso_time = time() - start +np.testing.assert_allclose(w, clf_sk.coef_, rtol=1e-5) + +print("\n") +print("Celer: %.2f" % celer_lasso_time) +print("Gram: %.2f" % gram_lasso_time) +print("\n") + +# Group Lasso +print("#" * 15) +print("Group Lasso") +print("#" * 15) +start = time() +w = group_lasso(X, y, alpha, group_size, max_iter, tol) +gram_group_lasso_time = time() - start +clf_celer = GroupLasso(group_size, alpha, tol=tol, fit_intercept=False) +start = time() +clf_celer.fit(X, y) +celer_group_lasso_time = time() - start +np.testing.assert_allclose(w, clf_celer.coef_, rtol=1e-1) + +print("\n") +print("Celer: %.2f" % celer_group_lasso_time) +print("Gram: %.2f" % gram_group_lasso_time) +print("\n") From af666a57156f27bc1d48d49a56b40d5e297d80f4 Mon Sep 17 00:00:00 2001 From: mathurinm Date: Thu, 21 Apr 2022 08:57:51 +0200 Subject: [PATCH 02/21] test with large data --- gram_test.py | 57 ++++++++++++++++++++++++++ skglm/gram_solver.py | 97 ++++++++++++++++++++++---------------------- 2 files changed, 106 insertions(+), 48 deletions(-) create mode 100644 gram_test.py diff --git a/gram_test.py b/gram_test.py new file mode 100644 index 000000000..acb907a6f --- /dev/null +++ b/gram_test.py @@ -0,0 +1,57 @@ +# data available at https://www.dropbox.com/sh/32b3mr3xghi496g/AACNRS_NOsUXU-hrSLixNg0ja?dl=0 + + +import time +from numpy.linalg import norm +import matplotlib.pyplot as plt +import numpy as np +from celer import GroupLasso +from skglm.gram_solver import group_lasso + +X = np.load("design_matrix.npy") +y = np.load("target.npy") +groups = np.load("groups.npy") +weights = np.load("weights.npy") +# grps = [list(np.where(groups == i)[0]) for i in range(1, 33)] + + +alpha_ratio = 1e-2 +n_alphas = 10 + + +# Case 1: slower runtime for (very) small alphas +# alpha_max = 0.003471727067743962 +alpha_max = np.max(np.linalg.norm((X.T @ y).reshape(-1, 5), axis=1)) / len(y) +alpha = alpha_max / 50 +clf = GroupLasso(fit_intercept=False, + groups=5, alpha=alpha, verbose=1) + +t0 = time.time() +clf.fit(X, y) +t1 = time.time() + +print(f"Celer: {t1 - t0:.3f} s") + +t0 = time.time() +res = group_lasso(X, y, alpha, groups=5, tol=1e-10, max_iter=10_000, check_freq=10) +t1 = time.time() + +print(f"skglm gram: {t1 - t0:.3f} s") + +# # Case 2: slower runtime for (very) small alphas with weights +# # alpha_max_w = 0.0001897719130007628 +# alpha_max_w = np.max(norm((X.T @ y).reshape(-1, 5) / +# weights[:, 
None], axis=1)) / len(y) + + +# alpha_ratio = 0.1 +# grid_w = np.geomspace(alpha_max_w*alpha_ratio, alpha_max_w, n_alphas)[::-1] +# clf = GroupLasso(fit_intercept=False, +# weights=weights, groups=grps, warm_start=True) + +# # for alpha in grid_w: +# # clf.alpha = alpha +# # t0 = time.time() +# # clf.fit(X, y) +# t1 = time.time() +# print(f"Finished tuning with {alpha:.2e}. Took {t1-t0:.2f} seconds!") diff --git a/skglm/gram_solver.py b/skglm/gram_solver.py index 0dbd60b53..c50708862 100644 --- a/skglm/gram_solver.py +++ b/skglm/gram_solver.py @@ -120,51 +120,52 @@ def group_lasso(X, y, alpha, groups, max_iter, tol, check_freq=50): return w -n_samples, n_features = 1_000_000, 300 -X, y, w_star = make_correlated_data( - n_samples=n_samples, n_features=n_features, random_state=0) -alpha_max = norm(X.T @ y, ord=np.inf) - -# Hyperparameters -max_iter = 1000 -tol = 1e-8 -reg = 0.1 -group_size = 3 - -alpha = alpha_max * reg / n_samples - -# Lasso -print("#" * 15) -print("Lasso") -print("#" * 15) -start = time() -w = lasso(X, y, alpha, max_iter, tol) -gram_lasso_time = time() - start -clf_sk = Lasso(alpha, tol=tol, fit_intercept=False) -start = time() -clf_sk.fit(X, y) -celer_lasso_time = time() - start -np.testing.assert_allclose(w, clf_sk.coef_, rtol=1e-5) - -print("\n") -print("Celer: %.2f" % celer_lasso_time) -print("Gram: %.2f" % gram_lasso_time) -print("\n") - -# Group Lasso -print("#" * 15) -print("Group Lasso") -print("#" * 15) -start = time() -w = group_lasso(X, y, alpha, group_size, max_iter, tol) -gram_group_lasso_time = time() - start -clf_celer = GroupLasso(group_size, alpha, tol=tol, fit_intercept=False) -start = time() -clf_celer.fit(X, y) -celer_group_lasso_time = time() - start -np.testing.assert_allclose(w, clf_celer.coef_, rtol=1e-1) - -print("\n") -print("Celer: %.2f" % celer_group_lasso_time) -print("Gram: %.2f" % gram_group_lasso_time) -print("\n") +if __name__ == "__main__": + n_samples, n_features = 1_000_000, 300 + X, y, w_star = make_correlated_data( + n_samples=n_samples, n_features=n_features, random_state=0) + alpha_max = norm(X.T @ y, ord=np.inf) + + # Hyperparameters + max_iter = 1000 + tol = 1e-8 + reg = 0.1 + group_size = 3 + + alpha = alpha_max * reg / n_samples + + # Lasso + print("#" * 15) + print("Lasso") + print("#" * 15) + start = time() + w = lasso(X, y, alpha, max_iter, tol) + gram_lasso_time = time() - start + clf_sk = Lasso(alpha, tol=tol, fit_intercept=False) + start = time() + clf_sk.fit(X, y) + celer_lasso_time = time() - start + np.testing.assert_allclose(w, clf_sk.coef_, rtol=1e-5) + + print("\n") + print("Celer: %.2f" % celer_lasso_time) + print("Gram: %.2f" % gram_lasso_time) + print("\n") + + # Group Lasso + print("#" * 15) + print("Group Lasso") + print("#" * 15) + start = time() + w = group_lasso(X, y, alpha, group_size, max_iter, tol) + gram_group_lasso_time = time() - start + clf_celer = GroupLasso(group_size, alpha, tol=tol, fit_intercept=False) + start = time() + clf_celer.fit(X, y) + celer_group_lasso_time = time() - start + np.testing.assert_allclose(w, clf_celer.coef_, rtol=1e-1) + + print("\n") + print("Celer: %.2f" % celer_group_lasso_time) + print("Gram: %.2f" % gram_group_lasso_time) + print("\n") From 787c8c21c1d8185dfa9d9660e358e8628a5b9b48 Mon Sep 17 00:00:00 2001 From: mathurinm Date: Thu, 21 Apr 2022 09:15:39 +0200 Subject: [PATCH 03/21] isolate gram solver in solvers submodule --- gram_test.py | 24 +---- skglm/gram_solver.py | 217 ++++++++++-------------------------------- skglm/solvers/gram.py | 96 +++++++++++++++++++ 3 files 
changed, 152 insertions(+), 185 deletions(-) create mode 100644 skglm/solvers/gram.py diff --git a/gram_test.py b/gram_test.py index acb907a6f..e3d25e69b 100644 --- a/gram_test.py +++ b/gram_test.py @@ -6,7 +6,7 @@ import matplotlib.pyplot as plt import numpy as np from celer import GroupLasso -from skglm.gram_solver import group_lasso +from skglm.solvers.gram import gram_group_lasso X = np.load("design_matrix.npy") y = np.load("target.npy") @@ -22,7 +22,7 @@ # Case 1: slower runtime for (very) small alphas # alpha_max = 0.003471727067743962 alpha_max = np.max(np.linalg.norm((X.T @ y).reshape(-1, 5), axis=1)) / len(y) -alpha = alpha_max / 50 +alpha = alpha_max / 100 clf = GroupLasso(fit_intercept=False, groups=5, alpha=alpha, verbose=1) @@ -32,26 +32,12 @@ print(f"Celer: {t1 - t0:.3f} s") +# beware: stopping criterion is not the same, tol here needs to be lower +# to get meaningful comparison t0 = time.time() res = group_lasso(X, y, alpha, groups=5, tol=1e-10, max_iter=10_000, check_freq=10) t1 = time.time() print(f"skglm gram: {t1 - t0:.3f} s") -# # Case 2: slower runtime for (very) small alphas with weights -# # alpha_max_w = 0.0001897719130007628 -# alpha_max_w = np.max(norm((X.T @ y).reshape(-1, 5) / -# weights[:, None], axis=1)) / len(y) - - -# alpha_ratio = 0.1 -# grid_w = np.geomspace(alpha_max_w*alpha_ratio, alpha_max_w, n_alphas)[::-1] -# clf = GroupLasso(fit_intercept=False, -# weights=weights, groups=grps, warm_start=True) - -# # for alpha in grid_w: -# # clf.alpha = alpha -# # t0 = time.time() -# # clf.fit(X, y) -# t1 = time.time() -# print(f"Finished tuning with {alpha:.2e}. Took {t1-t0:.2f} seconds!") +# TODO support weights in gram solver diff --git a/skglm/gram_solver.py b/skglm/gram_solver.py index c50708862..6ec2aaadd 100644 --- a/skglm/gram_solver.py +++ b/skglm/gram_solver.py @@ -1,171 +1,56 @@ from time import time import numpy as np from numpy.linalg import norm -from numba import njit from celer import Lasso, GroupLasso from benchopt.datasets.simulated import make_correlated_data -from skglm.utils import BST, ST - - -def _grp_converter(groups, n_features): - if isinstance(groups, int): - grp_size = groups - if n_features % grp_size != 0: - raise ValueError("n_features (%d) is not a multiple of the desired" - " group size (%d)" % (n_features, grp_size)) - n_groups = n_features // grp_size - grp_ptr = grp_size * np.arange(n_groups + 1) - grp_indices = np.arange(n_features) - elif isinstance(groups, list) and isinstance(groups[0], int): - grp_indices = np.arange(n_features).astype(np.int32) - grp_ptr = np.cumsum(np.hstack([[0], groups])) - elif isinstance(groups, list) and isinstance(groups[0], list): - grp_sizes = np.array([len(ls) for ls in groups]) - grp_ptr = np.cumsum(np.hstack([[0], grp_sizes])) - grp_indices = np.array([idx for grp in groups for idx in grp]) - else: - raise ValueError("Unsupported group format.") - return grp_ptr.astype(np.int32), grp_indices.astype(np.int32) - - -@njit -def primal(alpha, y, X, w): - r = y - X @ w - p_obj = (r @ r) / (2 * len(y)) - return p_obj + alpha * np.sum(np.abs(w)) - - -@njit -def primal_grp(alpha, y, X, w, grp_ptr, grp_indices): - r = y - X @ w - p_obj = (r @ r) / (2 * len(y)) - for g in range(len(grp_ptr) - 1): - w_g = w[grp_indices[grp_ptr[g]:grp_ptr[g + 1]]] - p_obj += alpha * norm(w_g, ord=2) - return p_obj - - -@njit -def cd_epoch(X, G, grads, w, alpha, lipschitz): - n_features = X.shape[1] - for j in range(n_features): - if lipschitz[j] == 0.: - continue - old_w_j = w[j] - w[j] = ST(w[j] + grads[j] / lipschitz[j], 
alpha / lipschitz[j]) - if old_w_j != w[j]: - grads += G[j, :] * (old_w_j - w[j]) / len(X) - - -@njit -def bcd_epoch(X, G, grads, w, alpha, lipschitz, grp_indices, grp_ptr): - n_groups = len(grp_ptr) - 1 - for g in range(n_groups): - if lipschitz[g] == 0.: - continue - idx = grp_indices[grp_ptr[g]:grp_ptr[g + 1]] - old_w_g = w[idx].copy() - w[idx] = BST(w[idx] + grads[idx] / lipschitz[g], alpha / lipschitz[g]) - diff = old_w_g - w[idx] - if np.any(diff != 0.): - grads += diff @ G[idx, :] / len(X) - - -def lasso(X, y, alpha, max_iter, tol, check_freq=10): - p_obj_prev = np.inf - n_features = X.shape[1] - # Initialization - grads = X.T @ y / len(y) - G = X.T @ X - lipschitz = np.zeros(n_features, dtype=X.dtype) - for j in range(n_features): - lipschitz[j] = (X[:, j] ** 2).sum() / len(y) - w = np.zeros(n_features) - # CD - for n_iter in range(max_iter): - cd_epoch(X, G, grads, w, alpha, lipschitz) - if n_iter % check_freq == 0: - p_obj = primal(alpha, y, X, w) - if p_obj_prev - p_obj < tol: - print("Convergence reached!") - break - print(f"iter {n_iter} :: p_obj {p_obj}") - p_obj_prev = p_obj - return w - - -def group_lasso(X, y, alpha, groups, max_iter, tol, check_freq=50): - p_obj_prev = np.inf - n_features = X.shape[1] - grp_ptr, grp_indices = _grp_converter(groups, X.shape[1]) - n_groups = len(grp_ptr) - 1 - # Initialization - grads = X.T @ y / len(y) - G = X.T @ X - lipschitz = np.zeros(n_groups, dtype=X.dtype) - for g in range(n_groups): - X_g = X[:, grp_indices[grp_ptr[g]:grp_ptr[g + 1]]] - lipschitz[g] = norm(X_g, ord=2) ** 2 / len(y) - w = np.zeros(n_features) - # BCD - for n_iter in range(max_iter): - bcd_epoch(X, G, grads, w, alpha, lipschitz, grp_indices, grp_ptr) - if n_iter % check_freq == 0: - p_obj = primal_grp(alpha, y, X, w, grp_ptr, grp_indices) - if p_obj_prev - p_obj < tol: - print("Convergence reached!") - break - print(f"iter {n_iter} :: p_obj {p_obj}") - p_obj_prev = p_obj - return w - - -if __name__ == "__main__": - n_samples, n_features = 1_000_000, 300 - X, y, w_star = make_correlated_data( - n_samples=n_samples, n_features=n_features, random_state=0) - alpha_max = norm(X.T @ y, ord=np.inf) - - # Hyperparameters - max_iter = 1000 - tol = 1e-8 - reg = 0.1 - group_size = 3 - - alpha = alpha_max * reg / n_samples - - # Lasso - print("#" * 15) - print("Lasso") - print("#" * 15) - start = time() - w = lasso(X, y, alpha, max_iter, tol) - gram_lasso_time = time() - start - clf_sk = Lasso(alpha, tol=tol, fit_intercept=False) - start = time() - clf_sk.fit(X, y) - celer_lasso_time = time() - start - np.testing.assert_allclose(w, clf_sk.coef_, rtol=1e-5) - - print("\n") - print("Celer: %.2f" % celer_lasso_time) - print("Gram: %.2f" % gram_lasso_time) - print("\n") - - # Group Lasso - print("#" * 15) - print("Group Lasso") - print("#" * 15) - start = time() - w = group_lasso(X, y, alpha, group_size, max_iter, tol) - gram_group_lasso_time = time() - start - clf_celer = GroupLasso(group_size, alpha, tol=tol, fit_intercept=False) - start = time() - clf_celer.fit(X, y) - celer_group_lasso_time = time() - start - np.testing.assert_allclose(w, clf_celer.coef_, rtol=1e-1) - - print("\n") - print("Celer: %.2f" % celer_group_lasso_time) - print("Gram: %.2f" % gram_group_lasso_time) - print("\n") +from skglm.solvers.gram import gram_lasso, gram_group_lasso + + +n_samples, n_features = 1_000_000, 300 +X, y, w_star = make_correlated_data( + n_samples=n_samples, n_features=n_features, random_state=0) +alpha_max = norm(X.T @ y, ord=np.inf) + +# Hyperparameters +max_iter = 1000 +tol = 1e-8 
+reg = 0.1 +group_size = 3 + +alpha = alpha_max * reg / n_samples + +# Lasso +print("#" * 15) +print("Lasso") +print("#" * 15) +start = time() +w = gram_lasso(X, y, alpha, max_iter, tol) +gram_lasso_time = time() - start +clf_sk = Lasso(alpha, tol=tol, fit_intercept=False) +start = time() +clf_sk.fit(X, y) +celer_lasso_time = time() - start +np.testing.assert_allclose(w, clf_sk.coef_, rtol=1e-5) + +print("\n") +print("Celer: %.2f" % celer_lasso_time) +print("Gram: %.2f" % gram_lasso_time) +print("\n") + +# Group Lasso +print("#" * 15) +print("Group Lasso") +print("#" * 15) +start = time() +w = gram_group_lasso(X, y, alpha, group_size, max_iter, tol) +gram_group_lasso_time = time() - start +clf_celer = GroupLasso(group_size, alpha, tol=tol, fit_intercept=False) +start = time() +clf_celer.fit(X, y) +celer_group_lasso_time = time() - start +np.testing.assert_allclose(w, clf_celer.coef_, rtol=1e-1) + +print("\n") +print("Celer: %.2f" % celer_group_lasso_time) +print("Gram: %.2f" % gram_group_lasso_time) +print("\n") diff --git a/skglm/solvers/gram.py b/skglm/solvers/gram.py new file mode 100644 index 000000000..38a238abb --- /dev/null +++ b/skglm/solvers/gram.py @@ -0,0 +1,96 @@ +import numpy as np +from numba import njit +from numpy.linalg import norm +from celer.homotopy import _grp_converter + +from skglm.utils import BST, ST + + +@njit +def primal(alpha, y, X, w): + r = y - X @ w + p_obj = (r @ r) / (2 * len(y)) + return p_obj + alpha * np.sum(np.abs(w)) + + +@njit +def primal_grp(alpha, y, X, w, grp_ptr, grp_indices): + r = y - X @ w + p_obj = (r @ r) / (2 * len(y)) + for g in range(len(grp_ptr) - 1): + w_g = w[grp_indices[grp_ptr[g]:grp_ptr[g + 1]]] + p_obj += alpha * norm(w_g, ord=2) + return p_obj + + +def gram_lasso(X, y, alpha, max_iter, tol, check_freq=10): + p_obj_prev = np.inf + n_features = X.shape[1] + grads = X.T @ y / len(y) + G = X.T @ X + lipschitz = np.zeros(n_features, dtype=X.dtype) + for j in range(n_features): + lipschitz[j] = (X[:, j] ** 2).sum() / len(y) + w = np.zeros(n_features) + # CD + for n_iter in range(max_iter): + cd_epoch(X, G, grads, w, alpha, lipschitz) + if n_iter % check_freq == 0: + p_obj = primal(alpha, y, X, w) + if p_obj_prev - p_obj < tol: + print("Convergence reached!") + break + print(f"iter {n_iter} :: p_obj {p_obj}") + p_obj_prev = p_obj + return w + + +def gram_group_lasso(X, y, alpha, groups, max_iter, tol, check_freq=50): + p_obj_prev = np.inf + n_features = X.shape[1] + grp_ptr, grp_indices = _grp_converter(groups, X.shape[1]) + n_groups = len(grp_ptr) - 1 + grads = X.T @ y / len(y) + G = X.T @ X + lipschitz = np.zeros(n_groups, dtype=X.dtype) + for g in range(n_groups): + X_g = X[:, grp_indices[grp_ptr[g]:grp_ptr[g + 1]]] + lipschitz[g] = norm(X_g, ord=2) ** 2 / len(y) + w = np.zeros(n_features) + # BCD + for n_iter in range(max_iter): + bcd_epoch(X, G, grads, w, alpha, lipschitz, grp_indices, grp_ptr) + if n_iter % check_freq == 0: + p_obj = primal_grp(alpha, y, X, w, grp_ptr, grp_indices) + if p_obj_prev - p_obj < tol: + print("Convergence reached!") + break + print(f"iter {n_iter} :: p_obj {p_obj}") + p_obj_prev = p_obj + return w + + +@njit +def cd_epoch(X, G, grads, w, alpha, lipschitz): + n_features = X.shape[1] + for j in range(n_features): + if lipschitz[j] == 0.: + continue + old_w_j = w[j] + w[j] = ST(w[j] + grads[j] / lipschitz[j], alpha / lipschitz[j]) + if old_w_j != w[j]: + grads += G[j, :] * (old_w_j - w[j]) / len(X) + + +@njit +def bcd_epoch(X, G, grads, w, alpha, lipschitz, grp_indices, grp_ptr): + n_groups = len(grp_ptr) - 
1 + for g in range(n_groups): + if lipschitz[g] == 0.: + continue + idx = grp_indices[grp_ptr[g]:grp_ptr[g + 1]] + old_w_g = w[idx].copy() + w[idx] = BST(w[idx] + grads[idx] / lipschitz[g], alpha / lipschitz[g]) + diff = old_w_g - w[idx] + if np.any(diff != 0.): + grads += diff @ G[idx, :] / len(X) From 51b4cfed75ddc500a984445d6397040392b7e0af Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Fri, 22 Apr 2022 17:53:27 +0200 Subject: [PATCH 04/21] added weights and warm_start --- skglm/gram_solver.py | 12 ++++++++---- skglm/solvers/gram.py | 40 ++++++++++++++++++++++------------------ 2 files changed, 30 insertions(+), 22 deletions(-) diff --git a/skglm/gram_solver.py b/skglm/gram_solver.py index 6ec2aaadd..4e64b746a 100644 --- a/skglm/gram_solver.py +++ b/skglm/gram_solver.py @@ -19,14 +19,17 @@ alpha = alpha_max * reg / n_samples +weights = np.random.normal(2, 0.4, n_features) +weights_grp = np.random.normal(2, 0.4, n_features // group_size) + # Lasso print("#" * 15) print("Lasso") print("#" * 15) start = time() -w = gram_lasso(X, y, alpha, max_iter, tol) +w = gram_lasso(X, y, alpha, max_iter, tol, weights=weights) gram_lasso_time = time() - start -clf_sk = Lasso(alpha, tol=tol, fit_intercept=False) +clf_sk = Lasso(alpha, weights=weights, tol=tol, fit_intercept=False) start = time() clf_sk.fit(X, y) celer_lasso_time = time() - start @@ -42,9 +45,10 @@ print("Group Lasso") print("#" * 15) start = time() -w = gram_group_lasso(X, y, alpha, group_size, max_iter, tol) +w = gram_group_lasso(X, y, alpha, group_size, max_iter, tol, weights=weights_grp) gram_group_lasso_time = time() - start -clf_celer = GroupLasso(group_size, alpha, tol=tol, fit_intercept=False) +clf_celer = GroupLasso(group_size, alpha, weights=weights_grp, tol=tol, + fit_intercept=False) start = time() clf_celer.fit(X, y) celer_group_lasso_time = time() - start diff --git a/skglm/solvers/gram.py b/skglm/solvers/gram.py index 38a238abb..915ae8741 100644 --- a/skglm/solvers/gram.py +++ b/skglm/solvers/gram.py @@ -7,23 +7,23 @@ @njit -def primal(alpha, y, X, w): +def primal(alpha, y, X, w, weights): r = y - X @ w p_obj = (r @ r) / (2 * len(y)) - return p_obj + alpha * np.sum(np.abs(w)) + return p_obj + alpha * np.sum(np.abs(w * weights)) @njit -def primal_grp(alpha, y, X, w, grp_ptr, grp_indices): +def primal_grp(alpha, y, X, w, grp_ptr, grp_indices, weights): r = y - X @ w p_obj = (r @ r) / (2 * len(y)) for g in range(len(grp_ptr) - 1): w_g = w[grp_indices[grp_ptr[g]:grp_ptr[g + 1]]] - p_obj += alpha * norm(w_g, ord=2) + p_obj += alpha * norm(w_g * weights[g], ord=2) return p_obj -def gram_lasso(X, y, alpha, max_iter, tol, check_freq=10): +def gram_lasso(X, y, alpha, max_iter, tol, w_init=None, weights=None, check_freq=10): p_obj_prev = np.inf n_features = X.shape[1] grads = X.T @ y / len(y) @@ -31,12 +31,13 @@ def gram_lasso(X, y, alpha, max_iter, tol, check_freq=10): lipschitz = np.zeros(n_features, dtype=X.dtype) for j in range(n_features): lipschitz[j] = (X[:, j] ** 2).sum() / len(y) - w = np.zeros(n_features) + w = w_init if w_init is not None else np.zeros(n_features) + weights = weights if weights is not None else np.ones(n_features) # CD for n_iter in range(max_iter): - cd_epoch(X, G, grads, w, alpha, lipschitz) + cd_epoch(X, G, grads, w, alpha, lipschitz, weights) if n_iter % check_freq == 0: - p_obj = primal(alpha, y, X, w) + p_obj = primal(alpha, y, X, w, weights) if p_obj_prev - p_obj < tol: print("Convergence reached!") break @@ -45,7 +46,8 @@ def gram_lasso(X, y, alpha, max_iter, tol, check_freq=10): 
return w -def gram_group_lasso(X, y, alpha, groups, max_iter, tol, check_freq=50): +def gram_group_lasso(X, y, alpha, groups, max_iter, tol, w_init=None, weights=None, + check_freq=50): p_obj_prev = np.inf n_features = X.shape[1] grp_ptr, grp_indices = _grp_converter(groups, X.shape[1]) @@ -56,12 +58,13 @@ def gram_group_lasso(X, y, alpha, groups, max_iter, tol, check_freq=50): for g in range(n_groups): X_g = X[:, grp_indices[grp_ptr[g]:grp_ptr[g + 1]]] lipschitz[g] = norm(X_g, ord=2) ** 2 / len(y) - w = np.zeros(n_features) + w = w_init if w_init is not None else np.zeros(n_features) + weights = weights if weights is not None else np.ones(n_groups) # BCD for n_iter in range(max_iter): - bcd_epoch(X, G, grads, w, alpha, lipschitz, grp_indices, grp_ptr) + bcd_epoch(X, G, grads, w, alpha, lipschitz, grp_indices, grp_ptr, weights) if n_iter % check_freq == 0: - p_obj = primal_grp(alpha, y, X, w, grp_ptr, grp_indices) + p_obj = primal_grp(alpha, y, X, w, grp_ptr, grp_indices, weights) if p_obj_prev - p_obj < tol: print("Convergence reached!") break @@ -71,26 +74,27 @@ def gram_group_lasso(X, y, alpha, groups, max_iter, tol, check_freq=50): @njit -def cd_epoch(X, G, grads, w, alpha, lipschitz): +def cd_epoch(X, G, grads, w, alpha, lipschitz, weights): n_features = X.shape[1] for j in range(n_features): - if lipschitz[j] == 0.: + if lipschitz[j] == 0. or weights[j] == np.inf: continue old_w_j = w[j] - w[j] = ST(w[j] + grads[j] / lipschitz[j], alpha / lipschitz[j]) + w[j] = ST(w[j] + grads[j] / lipschitz[j], alpha / lipschitz[j] * weights[j]) if old_w_j != w[j]: grads += G[j, :] * (old_w_j - w[j]) / len(X) @njit -def bcd_epoch(X, G, grads, w, alpha, lipschitz, grp_indices, grp_ptr): +def bcd_epoch(X, G, grads, w, alpha, lipschitz, grp_indices, grp_ptr, weights): n_groups = len(grp_ptr) - 1 for g in range(n_groups): - if lipschitz[g] == 0.: + if lipschitz[g] == 0. 
and weights[g] == np.inf: continue idx = grp_indices[grp_ptr[g]:grp_ptr[g + 1]] old_w_g = w[idx].copy() - w[idx] = BST(w[idx] + grads[idx] / lipschitz[g], alpha / lipschitz[g]) + w[idx] = BST(w[idx] + grads[idx] / lipschitz[g], alpha / lipschitz[g] + * weights[g]) diff = old_w_g - w[idx] if np.any(diff != 0.): grads += diff @ G[idx, :] / len(X) From a7db68b7459e79afbb4efa72078cb6e0e5c2913c Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Fri, 22 Apr 2022 19:13:49 +0200 Subject: [PATCH 05/21] WIP FISTA --- skglm/solvers/gram.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/skglm/solvers/gram.py b/skglm/solvers/gram.py index 915ae8741..317cea99b 100644 --- a/skglm/solvers/gram.py +++ b/skglm/solvers/gram.py @@ -31,11 +31,15 @@ def gram_lasso(X, y, alpha, max_iter, tol, w_init=None, weights=None, check_freq lipschitz = np.zeros(n_features, dtype=X.dtype) for j in range(n_features): lipschitz[j] = (X[:, j] ** 2).sum() / len(y) - w = w_init if w_init is not None else np.zeros(n_features) + w = w_init.copy() if w_init is not None else np.zeros(n_features) + z = w_init.copy() if w_init is not None else np.zeros(n_features) + beta_0 = beta_1 = 1 weights = weights if weights is not None else np.ones(n_features) # CD for n_iter in range(max_iter): - cd_epoch(X, G, grads, w, alpha, lipschitz, weights) + beta_1 = (1 + np.sqrt(1 + 4 * beta_0 ** 2)) / 2 + cd_epoch(X, G, grads, w, z, alpha, beta_1, beta_0, lipschitz, weights) + beta_0 = beta_1 if n_iter % check_freq == 0: p_obj = primal(alpha, y, X, w, weights) if p_obj_prev - p_obj < tol: @@ -58,7 +62,7 @@ def gram_group_lasso(X, y, alpha, groups, max_iter, tol, w_init=None, weights=No for g in range(n_groups): X_g = X[:, grp_indices[grp_ptr[g]:grp_ptr[g + 1]]] lipschitz[g] = norm(X_g, ord=2) ** 2 / len(y) - w = w_init if w_init is not None else np.zeros(n_features) + w = w_init.copy() if w_init is not None else np.zeros(n_features) weights = weights if weights is not None else np.ones(n_groups) # BCD for n_iter in range(max_iter): @@ -74,15 +78,17 @@ def gram_group_lasso(X, y, alpha, groups, max_iter, tol, w_init=None, weights=No @njit -def cd_epoch(X, G, grads, w, alpha, lipschitz, weights): +def cd_epoch(X, G, grads, w, z, alpha, beta_1, beta_0, lipschitz, weights): n_features = X.shape[1] for j in range(n_features): if lipschitz[j] == 0. 
or weights[j] == np.inf: continue old_w_j = w[j] - w[j] = ST(w[j] + grads[j] / lipschitz[j], alpha / lipschitz[j] * weights[j]) - if old_w_j != w[j]: - grads += G[j, :] * (old_w_j - w[j]) / len(X) + old_z_j = z[j] + w[j] = ST(z[j] + grads[j] / lipschitz[j], alpha / lipschitz[j] * weights[j]) + z[j] = w[j] + ((beta_0 - 1) / beta_1) * (w[j] - old_w_j) + if old_z_j != z[j]: + grads += G[j, :] * (old_z_j - z[j]) / len(X) @njit From fbee02d4f0530b60324e9f7f1ea558a6cbde2801 Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Sun, 24 Apr 2022 23:32:43 +0200 Subject: [PATCH 06/21] added FISTA gram --- skglm/fista_gram.py | 91 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) create mode 100644 skglm/fista_gram.py diff --git a/skglm/fista_gram.py b/skglm/fista_gram.py new file mode 100644 index 000000000..9fce87397 --- /dev/null +++ b/skglm/fista_gram.py @@ -0,0 +1,91 @@ +import numpy as np +from numpy.linalg import norm +from numba import njit +from celer import Lasso +from benchopt.datasets.simulated import make_correlated_data +from skglm.utils import ST_vec + + +@njit +def primal(alpha, r, w): + p_obj = (r @ r) / (2 * len(r)) + return p_obj + alpha * np.sum(np.abs(w)) + +@njit +def dual(alpha, norm_y2, theta, y): + d_obj = - np.sum((y / (alpha * len(y)) - theta) ** 2) + d_obj *= 0.5 * alpha ** 2 * len(y) + d_obj += norm_y2 / (2 * len(y)) + return d_obj + +@njit +def dnorm_l1(theta, X): + n_features = X.shape[1] + scal = 0. + for j in range(n_features): + Xj_theta = X[:, j] @ theta + scal = max(scal, Xj_theta) + return scal + +@njit +def create_dual_point(r, alpha, X): + theta = r / (alpha * len(y)) + scal = dnorm_l1(theta, X) + if scal > 1.: + theta /= scal + return theta + +@njit +def dual_gap(alpha, norm_y2, y, X, w): + r = y - X @ w + p_obj = primal(alpha, r, w) + theta = create_dual_point(r, alpha, X) + d_obj = dual(alpha, norm_y2, theta, y) + return p_obj, d_obj, p_obj - d_obj + + +n_samples, n_features = 30, 50 +X, y, w_star = make_correlated_data( + n_samples=n_samples, n_features=n_features, random_state=0) +alpha_max = norm(X.T @ y, ord=np.inf) + +# Hyperparameters +max_iter = 1_000 +tol = 1e-5 +reg = 0.1 +group_size = 3 +check_gap_freq = 100 + +alpha = alpha_max * reg / n_samples + +L = np.linalg.norm(X, ord=2) ** 2 / n_samples + +G = X.T @ X +Xty = X.T @ y + +w = np.zeros(n_features) +z = np.zeros(n_features) + +norm_y2 = y @ y + +t_new = 1 + +for n_iter in range(max_iter): + t_old = t_new + t_new = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2 + w_old = w.copy() + z -= (G @ z - Xty) / L / n_samples + w = ST_vec(z, alpha / L) + z = w + (t_old - 1.) 
/ t_new * (w - w_old) + + if n_iter % check_gap_freq == 0: + p_obj, d_obj, gap = dual_gap(alpha, norm_y2, y, X, w) + print(f"iter {n_iter} :: p_obj {p_obj:.5f} :: d_obj {d_obj:.5f} " + + f":: gap {gap:.5f}") + if gap < tol: + print("Convergence reached!") + break + +clf = Lasso(alpha, tol=tol, fit_intercept=False) +clf.fit(X, y) +np.testing.assert_allclose(w, clf.coef_, rtol=1e-3) From 5ffd02b81e64d7e729850503d792190672111cf2 Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Sun, 24 Apr 2022 23:36:41 +0200 Subject: [PATCH 07/21] larger examples --- skglm/fista_gram.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/skglm/fista_gram.py b/skglm/fista_gram.py index 9fce87397..a6a50e1b0 100644 --- a/skglm/fista_gram.py +++ b/skglm/fista_gram.py @@ -44,14 +44,14 @@ def dual_gap(alpha, norm_y2, y, X, w): return p_obj, d_obj, p_obj - d_obj -n_samples, n_features = 30, 50 +n_samples, n_features = 1_000_000, 300 X, y, w_star = make_correlated_data( n_samples=n_samples, n_features=n_features, random_state=0) alpha_max = norm(X.T @ y, ord=np.inf) # Hyperparameters -max_iter = 1_000 -tol = 1e-5 +max_iter = 1000 +tol = 1e-8 reg = 0.1 group_size = 3 check_gap_freq = 100 @@ -88,4 +88,4 @@ def dual_gap(alpha, norm_y2, y, X, w): clf = Lasso(alpha, tol=tol, fit_intercept=False) clf.fit(X, y) -np.testing.assert_allclose(w, clf.coef_, rtol=1e-3) +np.testing.assert_allclose(w, clf.coef_, rtol=1e-5) From d88df2aa1ded4207d07004ae7ce24998579dd195 Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Sun, 24 Apr 2022 23:43:48 +0200 Subject: [PATCH 08/21] added weights --- skglm/fista_gram.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/skglm/fista_gram.py b/skglm/fista_gram.py index a6a50e1b0..80cd1ac90 100644 --- a/skglm/fista_gram.py +++ b/skglm/fista_gram.py @@ -7,9 +7,9 @@ @njit -def primal(alpha, r, w): +def primal(alpha, r, w, weights): p_obj = (r @ r) / (2 * len(r)) - return p_obj + alpha * np.sum(np.abs(w)) + return p_obj + alpha * np.sum(np.abs(w * weights)) @njit def dual(alpha, norm_y2, theta, y): @@ -19,36 +19,38 @@ def dual(alpha, norm_y2, theta, y): return d_obj @njit -def dnorm_l1(theta, X): +def dnorm_l1(theta, X, weights): n_features = X.shape[1] scal = 0. for j in range(n_features): Xj_theta = X[:, j] @ theta - scal = max(scal, Xj_theta) + scal = max(scal, Xj_theta / weights[j]) return scal @njit -def create_dual_point(r, alpha, X): +def create_dual_point(r, alpha, X, weights): theta = r / (alpha * len(y)) - scal = dnorm_l1(theta, X) + scal = dnorm_l1(theta, X, weights) if scal > 1.: theta /= scal return theta @njit -def dual_gap(alpha, norm_y2, y, X, w): +def dual_gap(alpha, norm_y2, y, X, w, weights): r = y - X @ w - p_obj = primal(alpha, r, w) - theta = create_dual_point(r, alpha, X) + p_obj = primal(alpha, r, w, weights) + theta = create_dual_point(r, alpha, X, weights) d_obj = dual(alpha, norm_y2, theta, y) return p_obj, d_obj, p_obj - d_obj -n_samples, n_features = 1_000_000, 300 +n_samples, n_features = 1_000, 300 X, y, w_star = make_correlated_data( n_samples=n_samples, n_features=n_features, random_state=0) alpha_max = norm(X.T @ y, ord=np.inf) +weights = np.random.normal(2, 0.4, n_features) + # Hyperparameters max_iter = 1000 tol = 1e-8 @@ -75,17 +77,17 @@ def dual_gap(alpha, norm_y2, y, X, w): t_new = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2 w_old = w.copy() z -= (G @ z - Xty) / L / n_samples - w = ST_vec(z, alpha / L) + w = ST_vec(z, alpha / L * weights) z = w + (t_old - 1.) 
/ t_new * (w - w_old) if n_iter % check_gap_freq == 0: - p_obj, d_obj, gap = dual_gap(alpha, norm_y2, y, X, w) + p_obj, d_obj, gap = dual_gap(alpha, norm_y2, y, X, w, weights) print(f"iter {n_iter} :: p_obj {p_obj:.5f} :: d_obj {d_obj:.5f} " + f":: gap {gap:.5f}") if gap < tol: print("Convergence reached!") break -clf = Lasso(alpha, tol=tol, fit_intercept=False) +clf = Lasso(alpha, tol=tol, weights=weights, fit_intercept=False) clf.fit(X, y) np.testing.assert_allclose(w, clf.coef_, rtol=1e-5) From c6342e9e8888220d2c0de48b77e97eb4c0a949b7 Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Sun, 24 Apr 2022 23:58:33 +0200 Subject: [PATCH 09/21] ENH dual gap criterion --- skglm/solvers/gram.py | 105 +++++++++++++++++++++++++++++++++--------- 1 file changed, 84 insertions(+), 21 deletions(-) diff --git a/skglm/solvers/gram.py b/skglm/solvers/gram.py index 317cea99b..fb3226caf 100644 --- a/skglm/solvers/gram.py +++ b/skglm/solvers/gram.py @@ -1,17 +1,49 @@ +from re import L import numpy as np from numba import njit from numpy.linalg import norm from celer.homotopy import _grp_converter -from skglm.utils import BST, ST +from skglm.utils import BST, ST, ST_vec @njit -def primal(alpha, y, X, w, weights): - r = y - X @ w - p_obj = (r @ r) / (2 * len(y)) +def primal(alpha, r, w, weights): + p_obj = (r @ r) / (2 * len(r)) return p_obj + alpha * np.sum(np.abs(w * weights)) +@njit +def dual(alpha, norm_y2, theta, y): + d_obj = - np.sum((y / (alpha * len(y)) - theta) ** 2) + d_obj *= 0.5 * alpha ** 2 * len(y) + d_obj += norm_y2 / (2 * len(y)) + return d_obj + +@njit +def dnorm_l1(theta, X, weights): + n_features = X.shape[1] + scal = 0. + for j in range(n_features): + Xj_theta = X[:, j] @ theta + scal = max(scal, Xj_theta / weights[j]) + return scal + +@njit +def create_dual_point(r, alpha, X, y, weights): + theta = r / (alpha * len(y)) + scal = dnorm_l1(theta, X, weights) + if scal > 1.: + theta /= scal + return theta + +@njit +def dual_gap(alpha, norm_y2, y, X, w, weights): + r = y - X @ w + p_obj = primal(alpha, r, w, weights) + theta = create_dual_point(r, alpha, X, y, weights) + d_obj = dual(alpha, norm_y2, theta, y) + return p_obj, d_obj, p_obj - d_obj + @njit def primal_grp(alpha, y, X, w, grp_ptr, grp_indices, weights): @@ -23,33 +55,66 @@ def primal_grp(alpha, y, X, w, grp_ptr, grp_indices, weights): return p_obj -def gram_lasso(X, y, alpha, max_iter, tol, w_init=None, weights=None, check_freq=10): - p_obj_prev = np.inf +@njit +def compute_lipschitz(X, y): n_features = X.shape[1] - grads = X.T @ y / len(y) - G = X.T @ X lipschitz = np.zeros(n_features, dtype=X.dtype) for j in range(n_features): lipschitz[j] = (X[:, j] ** 2).sum() / len(y) + return lipschitz + + +def gram_lasso(X, y, alpha, max_iter, tol, w_init=None, weights=None, check_freq=10): + n_features = X.shape[1] + norm_y2 = y @ y + grads = X.T @ y / len(y) + G = X.T @ X + lipschitz = compute_lipschitz(X, y) w = w_init.copy() if w_init is not None else np.zeros(n_features) - z = w_init.copy() if w_init is not None else np.zeros(n_features) - beta_0 = beta_1 = 1 weights = weights if weights is not None else np.ones(n_features) # CD for n_iter in range(max_iter): - beta_1 = (1 + np.sqrt(1 + 4 * beta_0 ** 2)) / 2 - cd_epoch(X, G, grads, w, z, alpha, beta_1, beta_0, lipschitz, weights) - beta_0 = beta_1 + cd_epoch(X, G, grads, w, alpha, lipschitz, weights) if n_iter % check_freq == 0: - p_obj = primal(alpha, y, X, w, weights) + p_obj, d_obj, d_gap = dual_gap(alpha, norm_y2, y, X, w, weights) + print(f"iter {n_iter} :: p_obj 
{p_obj:.5f} :: d_obj {d_obj:.5f}" + + f" :: gap {d_gap:.5f}") + if d_gap < tol: + print("Convergence reached!") + break + return w + + +def gram_fista_lasso(X, y, alpha, max_iter, tol, w_init=None, weights=None, + check_freq=10): + n_samples, n_features = X.shape + p_obj_prev = np.inf + t_new = 1 + w = w_init.copy() if w_init is not None else np.zeros(n_features) + z = w_init.copy() if w_init is not None else np.zeros(n_features) + weights = weights if weights is not None else np.ones(n_features) + G = X.T @ X + Xty = X.T @ y + L = np.linalg.norm(X, ord=2) ** 2 / n_samples + for n_iter in range(max_iter): + t_old = t_new + t_new = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2 + w_old = w.copy() + z -= (G @ z - Xty) / L / n_samples + w = ST_vec(z, alpha / L) + z = w + (t_old - 1.) / t_new * (w - w_old) + if n_iter % check_freq == 0: + r = y - X @ w + p_obj = primal(alpha, r, w, weights) if p_obj_prev - p_obj < tol: print("Convergence reached!") break - print(f"iter {n_iter} :: p_obj {p_obj}") + print(f"iter {n_iter} :: p_obj {p_obj:.5f}") p_obj_prev = p_obj return w + def gram_group_lasso(X, y, alpha, groups, max_iter, tol, w_init=None, weights=None, check_freq=50): p_obj_prev = np.inf @@ -78,17 +143,15 @@ def gram_group_lasso(X, y, alpha, groups, max_iter, tol, w_init=None, weights=No @njit -def cd_epoch(X, G, grads, w, z, alpha, beta_1, beta_0, lipschitz, weights): +def cd_epoch(X, G, grads, w, alpha, lipschitz, weights): n_features = X.shape[1] for j in range(n_features): if lipschitz[j] == 0. or weights[j] == np.inf: continue old_w_j = w[j] - old_z_j = z[j] - w[j] = ST(z[j] + grads[j] / lipschitz[j], alpha / lipschitz[j] * weights[j]) - z[j] = w[j] + ((beta_0 - 1) / beta_1) * (w[j] - old_w_j) - if old_z_j != z[j]: - grads += G[j, :] * (old_z_j - z[j]) / len(X) + w[j] = ST(w[j] + grads[j] / lipschitz[j], alpha / lipschitz[j] * weights[j]) + if old_w_j != w[j]: + grads += G[j, :] * (old_w_j - w[j]) / len(X) @njit From dc7a0eb92c8df86a8eed2c7dc9f1b24d24ab5423 Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Mon, 25 Apr 2022 00:11:17 +0200 Subject: [PATCH 10/21] CLN Gram FISTA solver --- skglm/fista_gram.py | 93 ------------------------------------------- skglm/gram_solver.py | 45 +++++++++++---------- skglm/solvers/gram.py | 17 ++++---- 3 files changed, 35 insertions(+), 120 deletions(-) delete mode 100644 skglm/fista_gram.py diff --git a/skglm/fista_gram.py b/skglm/fista_gram.py deleted file mode 100644 index 80cd1ac90..000000000 --- a/skglm/fista_gram.py +++ /dev/null @@ -1,93 +0,0 @@ -import numpy as np -from numpy.linalg import norm -from numba import njit -from celer import Lasso -from benchopt.datasets.simulated import make_correlated_data -from skglm.utils import ST_vec - - -@njit -def primal(alpha, r, w, weights): - p_obj = (r @ r) / (2 * len(r)) - return p_obj + alpha * np.sum(np.abs(w * weights)) - -@njit -def dual(alpha, norm_y2, theta, y): - d_obj = - np.sum((y / (alpha * len(y)) - theta) ** 2) - d_obj *= 0.5 * alpha ** 2 * len(y) - d_obj += norm_y2 / (2 * len(y)) - return d_obj - -@njit -def dnorm_l1(theta, X, weights): - n_features = X.shape[1] - scal = 0. 
- for j in range(n_features): - Xj_theta = X[:, j] @ theta - scal = max(scal, Xj_theta / weights[j]) - return scal - -@njit -def create_dual_point(r, alpha, X, weights): - theta = r / (alpha * len(y)) - scal = dnorm_l1(theta, X, weights) - if scal > 1.: - theta /= scal - return theta - -@njit -def dual_gap(alpha, norm_y2, y, X, w, weights): - r = y - X @ w - p_obj = primal(alpha, r, w, weights) - theta = create_dual_point(r, alpha, X, weights) - d_obj = dual(alpha, norm_y2, theta, y) - return p_obj, d_obj, p_obj - d_obj - - -n_samples, n_features = 1_000, 300 -X, y, w_star = make_correlated_data( - n_samples=n_samples, n_features=n_features, random_state=0) -alpha_max = norm(X.T @ y, ord=np.inf) - -weights = np.random.normal(2, 0.4, n_features) - -# Hyperparameters -max_iter = 1000 -tol = 1e-8 -reg = 0.1 -group_size = 3 -check_gap_freq = 100 - -alpha = alpha_max * reg / n_samples - -L = np.linalg.norm(X, ord=2) ** 2 / n_samples - -G = X.T @ X -Xty = X.T @ y - -w = np.zeros(n_features) -z = np.zeros(n_features) - -norm_y2 = y @ y - -t_new = 1 - -for n_iter in range(max_iter): - t_old = t_new - t_new = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2 - w_old = w.copy() - z -= (G @ z - Xty) / L / n_samples - w = ST_vec(z, alpha / L * weights) - z = w + (t_old - 1.) / t_new * (w - w_old) - - if n_iter % check_gap_freq == 0: - p_obj, d_obj, gap = dual_gap(alpha, norm_y2, y, X, w, weights) - print(f"iter {n_iter} :: p_obj {p_obj:.5f} :: d_obj {d_obj:.5f} " + - f":: gap {gap:.5f}") - if gap < tol: - print("Convergence reached!") - break - -clf = Lasso(alpha, tol=tol, weights=weights, fit_intercept=False) -clf.fit(X, y) -np.testing.assert_allclose(w, clf.coef_, rtol=1e-5) diff --git a/skglm/gram_solver.py b/skglm/gram_solver.py index 4e64b746a..58c9415e5 100644 --- a/skglm/gram_solver.py +++ b/skglm/gram_solver.py @@ -3,10 +3,10 @@ from numpy.linalg import norm from celer import Lasso, GroupLasso from benchopt.datasets.simulated import make_correlated_data -from skglm.solvers.gram import gram_lasso, gram_group_lasso +from skglm.solvers.gram import gram_fista_lasso, gram_lasso, gram_group_lasso -n_samples, n_features = 1_000_000, 300 +n_samples, n_features = 100, 300 X, y, w_star = make_correlated_data( n_samples=n_samples, n_features=n_features, random_state=0) alpha_max = norm(X.T @ y, ord=np.inf) @@ -33,28 +33,33 @@ start = time() clf_sk.fit(X, y) celer_lasso_time = time() - start -np.testing.assert_allclose(w, clf_sk.coef_, rtol=1e-5) +start = time() +w_fista = gram_fista_lasso(X, y, alpha, max_iter, tol, weights=weights) +gram_fista_lasso_time = time() - start +np.testing.assert_allclose(w, clf_sk.coef_, rtol=1e-4) +np.testing.assert_allclose(w, w_fista, rtol=1e-4) print("\n") print("Celer: %.2f" % celer_lasso_time) print("Gram: %.2f" % gram_lasso_time) +print("FISTA Gram: %.2f" % gram_fista_lasso_time) print("\n") -# Group Lasso -print("#" * 15) -print("Group Lasso") -print("#" * 15) -start = time() -w = gram_group_lasso(X, y, alpha, group_size, max_iter, tol, weights=weights_grp) -gram_group_lasso_time = time() - start -clf_celer = GroupLasso(group_size, alpha, weights=weights_grp, tol=tol, - fit_intercept=False) -start = time() -clf_celer.fit(X, y) -celer_group_lasso_time = time() - start -np.testing.assert_allclose(w, clf_celer.coef_, rtol=1e-1) +# # Group Lasso +# print("#" * 15) +# print("Group Lasso") +# print("#" * 15) +# start = time() +# w = gram_group_lasso(X, y, alpha, group_size, max_iter, tol, weights=weights_grp) +# gram_group_lasso_time = time() - start +# clf_celer = 
GroupLasso(group_size, alpha, tol=tol, +# fit_intercept=False) +# start = time() +# clf_celer.fit(X, y) +# celer_group_lasso_time = time() - start +# np.testing.assert_allclose(w, clf_celer.coef_, rtol=1e-1) -print("\n") -print("Celer: %.2f" % celer_group_lasso_time) -print("Gram: %.2f" % gram_group_lasso_time) -print("\n") +# print("\n") +# print("Celer: %.2f" % celer_group_lasso_time) +# print("Gram: %.2f" % gram_group_lasso_time) +# print("\n") diff --git a/skglm/solvers/gram.py b/skglm/solvers/gram.py index fb3226caf..e48db07f1 100644 --- a/skglm/solvers/gram.py +++ b/skglm/solvers/gram.py @@ -88,29 +88,32 @@ def gram_lasso(X, y, alpha, max_iter, tol, w_init=None, weights=None, check_freq def gram_fista_lasso(X, y, alpha, max_iter, tol, w_init=None, weights=None, check_freq=10): n_samples, n_features = X.shape - p_obj_prev = np.inf + norm_y2 = y @ y t_new = 1 + w = w_init.copy() if w_init is not None else np.zeros(n_features) z = w_init.copy() if w_init is not None else np.zeros(n_features) weights = weights if weights is not None else np.ones(n_features) + G = X.T @ X Xty = X.T @ y L = np.linalg.norm(X, ord=2) ** 2 / n_samples + for n_iter in range(max_iter): t_old = t_new t_new = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2 w_old = w.copy() z -= (G @ z - Xty) / L / n_samples - w = ST_vec(z, alpha / L) + w = ST_vec(z, alpha / L * weights) z = w + (t_old - 1.) / t_new * (w - w_old) + if n_iter % check_freq == 0: - r = y - X @ w - p_obj = primal(alpha, r, w, weights) - if p_obj_prev - p_obj < tol: + p_obj, d_obj, d_gap = dual_gap(alpha, norm_y2, y, X, w, weights) + print(f"iter {n_iter} :: p_obj {p_obj:.5f} :: d_obj {d_obj:.5f} " + + f":: gap {d_gap:.5f}") + if d_gap < tol: print("Convergence reached!") break - print(f"iter {n_iter} :: p_obj {p_obj:.5f}") - p_obj_prev = p_obj return w From adbab981c71161fdb43b86bf0ad5ac1a20badbc3 Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Mon, 25 Apr 2022 00:15:24 +0200 Subject: [PATCH 11/21] format --- skglm/solvers/gram.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/skglm/solvers/gram.py b/skglm/solvers/gram.py index e48db07f1..66ac28389 100644 --- a/skglm/solvers/gram.py +++ b/skglm/solvers/gram.py @@ -1,4 +1,3 @@ -from re import L import numpy as np from numba import njit from numpy.linalg import norm @@ -12,6 +11,7 @@ def primal(alpha, r, w, weights): p_obj = (r @ r) / (2 * len(r)) return p_obj + alpha * np.sum(np.abs(w * weights)) + @njit def dual(alpha, norm_y2, theta, y): d_obj = - np.sum((y / (alpha * len(y)) - theta) ** 2) @@ -19,6 +19,7 @@ def dual(alpha, norm_y2, theta, y): d_obj += norm_y2 / (2 * len(y)) return d_obj + @njit def dnorm_l1(theta, X, weights): n_features = X.shape[1] @@ -28,6 +29,7 @@ def dnorm_l1(theta, X, weights): scal = max(scal, Xj_theta / weights[j]) return scal + @njit def create_dual_point(r, alpha, X, y, weights): theta = r / (alpha * len(y)) @@ -36,6 +38,7 @@ def create_dual_point(r, alpha, X, y, weights): theta /= scal return theta + @njit def dual_gap(alpha, norm_y2, y, X, w, weights): r = y - X @ w @@ -117,7 +120,6 @@ def gram_fista_lasso(X, y, alpha, max_iter, tol, w_init=None, weights=None, return w - def gram_group_lasso(X, y, alpha, groups, max_iter, tol, w_init=None, weights=None, check_freq=50): p_obj_prev = np.inf From b81d3485f66ac3474ec4771b70a8553c88b2703a Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Mon, 25 Apr 2022 00:50:16 +0200 Subject: [PATCH 12/21] WIP Gram FISTA BCD --- skglm/gram_solver.py | 64 ++++++++++++++++++++++++------------------- 
skglm/solvers/gram.py | 50 ++++++++++++++++++++++++++++++--- 2 files changed, 82 insertions(+), 32 deletions(-) diff --git a/skglm/gram_solver.py b/skglm/gram_solver.py index 58c9415e5..498dbae9f 100644 --- a/skglm/gram_solver.py +++ b/skglm/gram_solver.py @@ -3,7 +3,7 @@ from numpy.linalg import norm from celer import Lasso, GroupLasso from benchopt.datasets.simulated import make_correlated_data -from skglm.solvers.gram import gram_fista_lasso, gram_lasso, gram_group_lasso +from skglm.solvers.gram import gram_fista_group_lasso, gram_fista_lasso, gram_lasso, gram_group_lasso n_samples, n_features = 100, 300 @@ -23,43 +23,51 @@ weights_grp = np.random.normal(2, 0.4, n_features // group_size) # Lasso +# print("#" * 15) +# print("Lasso") +# print("#" * 15) +# start = time() +# w = gram_lasso(X, y, alpha, max_iter, tol, weights=weights) +# gram_lasso_time = time() - start +# clf_sk = Lasso(alpha, weights=weights, tol=tol, fit_intercept=False) +# start = time() +# clf_sk.fit(X, y) +# celer_lasso_time = time() - start +# start = time() +# w_fista = gram_fista_lasso(X, y, alpha, max_iter, tol, weights=weights) +# gram_fista_lasso_time = time() - start +# np.testing.assert_allclose(w, clf_sk.coef_, rtol=1e-4) +# np.testing.assert_allclose(w, w_fista, rtol=1e-4) + +# print("\n") +# print("Celer: %.2f" % celer_lasso_time) +# print("Gram: %.2f" % gram_lasso_time) +# print("FISTA Gram: %.2f" % gram_fista_lasso_time) +# print("\n") + +# Group Lasso print("#" * 15) -print("Lasso") +print("Group Lasso") print("#" * 15) start = time() -w = gram_lasso(X, y, alpha, max_iter, tol, weights=weights) -gram_lasso_time = time() - start -clf_sk = Lasso(alpha, weights=weights, tol=tol, fit_intercept=False) +w = gram_group_lasso(X, y, alpha, group_size, max_iter, tol, weights=weights_grp) +gram_group_lasso_time = time() - start start = time() -clf_sk.fit(X, y) -celer_lasso_time = time() - start -start = time() -w_fista = gram_fista_lasso(X, y, alpha, max_iter, tol, weights=weights) -gram_fista_lasso_time = time() - start -np.testing.assert_allclose(w, clf_sk.coef_, rtol=1e-4) -np.testing.assert_allclose(w, w_fista, rtol=1e-4) +w_fista = gram_fista_group_lasso(X, y, alpha, group_size, max_iter, tol, + weights=weights_grp) +gram_fista_group_lasso_time = time() - start -print("\n") -print("Celer: %.2f" % celer_lasso_time) -print("Gram: %.2f" % gram_lasso_time) -print("FISTA Gram: %.2f" % gram_fista_lasso_time) -print("\n") +np.testing.assert_allclose(w, w_fista, rtol=1e-4) -# # Group Lasso -# print("#" * 15) -# print("Group Lasso") -# print("#" * 15) -# start = time() -# w = gram_group_lasso(X, y, alpha, group_size, max_iter, tol, weights=weights_grp) -# gram_group_lasso_time = time() - start -# clf_celer = GroupLasso(group_size, alpha, tol=tol, +# clf_celer = GroupLasso(group_size, alpha, tol=tol, weights=weights_grp, # fit_intercept=False) # start = time() # clf_celer.fit(X, y) # celer_group_lasso_time = time() - start # np.testing.assert_allclose(w, clf_celer.coef_, rtol=1e-1) -# print("\n") +print("\n") # print("Celer: %.2f" % celer_group_lasso_time) -# print("Gram: %.2f" % gram_group_lasso_time) -# print("\n") +print("Gram: %.2f" % gram_group_lasso_time) +print("FISTA Gram: %.2f" % gram_fista_group_lasso_time) +print("\n") diff --git a/skglm/solvers/gram.py b/skglm/solvers/gram.py index 66ac28389..c40d7b063 100644 --- a/skglm/solvers/gram.py +++ b/skglm/solvers/gram.py @@ -3,7 +3,7 @@ from numpy.linalg import norm from celer.homotopy import _grp_converter -from skglm.utils import BST, ST, ST_vec +from skglm.utils 
import BST, ST, BST_vec, ST_vec @njit @@ -69,7 +69,7 @@ def compute_lipschitz(X, y): def gram_lasso(X, y, alpha, max_iter, tol, w_init=None, weights=None, check_freq=10): n_features = X.shape[1] - norm_y2 = y @ y + norm_y2 = y @ y grads = X.T @ y / len(y) G = X.T @ X lipschitz = compute_lipschitz(X, y) @@ -88,7 +88,7 @@ def gram_lasso(X, y, alpha, max_iter, tol, w_init=None, weights=None, check_freq return w -def gram_fista_lasso(X, y, alpha, max_iter, tol, w_init=None, weights=None, +def gram_fista_lasso(X, y, alpha, max_iter, tol, w_init=None, weights=None, check_freq=10): n_samples, n_features = X.shape norm_y2 = y @ y @@ -120,7 +120,49 @@ def gram_fista_lasso(X, y, alpha, max_iter, tol, w_init=None, weights=None, return w -def gram_group_lasso(X, y, alpha, groups, max_iter, tol, w_init=None, weights=None, +def gram_fista_group_lasso(X, y, alpha, groups, max_iter, tol, w_init=None, + weights=None, check_freq=50): + p_obj_prev = np.inf + n_features = X.shape[1] + + grp_ptr, grp_indices = _grp_converter(groups, X.shape[1]) + n_groups = len(grp_ptr) - 1 + grp_size = n_features // n_groups + + t_new = 1 + + w = w_init.copy() if w_init is not None else np.zeros(n_features) + z = w_init.copy() if w_init is not None else np.zeros(n_features) + weights = weights if weights is not None else np.ones(n_features) + tiled_weights = np.repeat(weights, grp_size) + + G = X.T @ X + Xty = X.T @ y + + lipschitz = np.zeros(n_groups, dtype=X.dtype) + for g in range(n_groups): + X_g = X[:, grp_indices[grp_ptr[g]:grp_ptr[g + 1]]] + lipschitz[g] = norm(X_g, ord=2) ** 2 / len(y) + tiled_lipschitz = np.repeat(lipschitz[np.newaxis, :], grp_size) + + for n_iter in range(max_iter): + t_old = t_new + t_new = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2 + w_old = w.copy() + z -= (G @ z - Xty) / tiled_lipschitz / len(y) + w = BST_vec(z, alpha / tiled_lipschitz * tiled_weights, n_features) + z = w + (t_old - 1.) / t_new * (w - w_old) + if n_iter % check_freq == 0: + p_obj = primal_grp(alpha, y, X, w, grp_ptr, grp_indices, weights) + print(f"iter {n_iter} :: p_obj {p_obj:.5f}") + if p_obj_prev - p_obj < tol: + print("Convergence reached!") + break + p_obj_prev = p_obj + return w + + +def gram_group_lasso(X, y, alpha, groups, max_iter, tol, w_init=None, weights=None, check_freq=50): p_obj_prev = np.inf n_features = X.shape[1] From f142317e8ed727263f363ad0bc60432d910505e7 Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Mon, 25 Apr 2022 08:50:02 +0200 Subject: [PATCH 13/21] duality gap for BCD --- skglm/solvers/gram.py | 65 +++++++++++++++++++++++++++++++++---------- 1 file changed, 50 insertions(+), 15 deletions(-) diff --git a/skglm/solvers/gram.py b/skglm/solvers/gram.py index c40d7b063..a969e35c8 100644 --- a/skglm/solvers/gram.py +++ b/skglm/solvers/gram.py @@ -12,6 +12,15 @@ def primal(alpha, r, w, weights): return p_obj + alpha * np.sum(np.abs(w * weights)) +@njit +def primal_grp(alpha, norm_r2, r, w, grp_ptr, grp_indices, weights): + p_obj = norm_r2 / (2 * len(r)) + for g in range(len(grp_ptr) - 1): + w_g = w[grp_indices[grp_ptr[g]:grp_ptr[g + 1]]] + p_obj += alpha * norm(w_g * weights[g], ord=2) + return p_obj + + @njit def dual(alpha, norm_y2, theta, y): d_obj = - np.sum((y / (alpha * len(y)) - theta) ** 2) @@ -30,6 +39,22 @@ def dnorm_l1(theta, X, weights): return scal +@njit +def dnorm_l21(theta, grp_ptr, grp_indices, X, weights): + scal = 0. + n_groups = len(grp_ptr) - 1 + for g in range(n_groups): + if weights[g] == np.inf: + continue + tmp = 0. 
+ for k in range(grp_ptr[g], grp_ptr[g + 1]): + j = grp_indices[k] + Xj_theta = X[:, j] @ theta + tmp += Xj_theta ** 2 + scal = max(scal, np.sqrt(tmp) / weights[g]) + return scal + + @njit def create_dual_point(r, alpha, X, y, weights): theta = r / (alpha * len(y)) @@ -39,6 +64,15 @@ def create_dual_point(r, alpha, X, y, weights): return theta +@njit +def create_dual_point_grp(r, alpha, y, X, grp_ptr, grp_indices, weights): + theta = r / (alpha * len(y)) + scal = dnorm_l21(theta, grp_ptr, grp_indices, X, weights) + if scal > 1.: + theta /= scal + return theta + + @njit def dual_gap(alpha, norm_y2, y, X, w, weights): r = y - X @ w @@ -49,13 +83,13 @@ def dual_gap(alpha, norm_y2, y, X, w, weights): @njit -def primal_grp(alpha, y, X, w, grp_ptr, grp_indices, weights): +def dual_gap_grp(y, X, w, alpha, norm_y2, grp_ptr, grp_indices, weights): r = y - X @ w - p_obj = (r @ r) / (2 * len(y)) - for g in range(len(grp_ptr) - 1): - w_g = w[grp_indices[grp_ptr[g]:grp_ptr[g + 1]]] - p_obj += alpha * norm(w_g * weights[g], ord=2) - return p_obj + norm_r2 = r @ r + p_obj = primal_grp(alpha, norm_r2, r, w, grp_ptr, grp_indices, weights) + theta = create_dual_point_grp(r, alpha, y, X, grp_ptr, grp_indices, weights) + d_obj = dual(alpha, norm_y2, theta, y) + return p_obj, d_obj, p_obj - d_obj @njit @@ -133,8 +167,8 @@ def gram_fista_group_lasso(X, y, alpha, groups, max_iter, tol, w_init=None, w = w_init.copy() if w_init is not None else np.zeros(n_features) z = w_init.copy() if w_init is not None else np.zeros(n_features) - weights = weights if weights is not None else np.ones(n_features) - tiled_weights = np.repeat(weights, grp_size) + weights = weights if weights is not None else np.ones(n_groups) + # tiled_weights = np.repeat(weights, grp_size) G = X.T @ X Xty = X.T @ y @@ -150,10 +184,10 @@ def gram_fista_group_lasso(X, y, alpha, groups, max_iter, tol, w_init=None, t_new = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2 w_old = w.copy() z -= (G @ z - Xty) / tiled_lipschitz / len(y) - w = BST_vec(z, alpha / tiled_lipschitz * tiled_weights, n_features) + w = BST_vec(z, alpha / lipschitz, grp_size) z = w + (t_old - 1.) 
/ t_new * (w - w_old) if n_iter % check_freq == 0: - p_obj = primal_grp(alpha, y, X, w, grp_ptr, grp_indices, weights) + p_obj = primal_grp(alpha, y, X, w, grp_ptr, grp_indices, np.ones(n_groups)) print(f"iter {n_iter} :: p_obj {p_obj:.5f}") if p_obj_prev - p_obj < tol: print("Convergence reached!") @@ -164,10 +198,10 @@ def gram_fista_group_lasso(X, y, alpha, groups, max_iter, tol, w_init=None, def gram_group_lasso(X, y, alpha, groups, max_iter, tol, w_init=None, weights=None, check_freq=50): - p_obj_prev = np.inf n_features = X.shape[1] grp_ptr, grp_indices = _grp_converter(groups, X.shape[1]) n_groups = len(grp_ptr) - 1 + norm_y2 = y @ y grads = X.T @ y / len(y) G = X.T @ X lipschitz = np.zeros(n_groups, dtype=X.dtype) @@ -180,12 +214,13 @@ def gram_group_lasso(X, y, alpha, groups, max_iter, tol, w_init=None, weights=No for n_iter in range(max_iter): bcd_epoch(X, G, grads, w, alpha, lipschitz, grp_indices, grp_ptr, weights) if n_iter % check_freq == 0: - p_obj = primal_grp(alpha, y, X, w, grp_ptr, grp_indices, weights) - if p_obj_prev - p_obj < tol: + p_obj, d_obj, d_gap = dual_gap_grp(y, X, w, alpha, norm_y2, grp_ptr, + grp_indices, weights) + print(f"iter {n_iter} :: p_obj {p_obj:.5f} :: d_obj {d_obj:.5f} " + + f":: gap {d_gap:.5f}") + if d_gap < tol: print("Convergence reached!") break - print(f"iter {n_iter} :: p_obj {p_obj}") - p_obj_prev = p_obj return w From 552485e2679a790f029f0191c3b4d4d8f72d5f45 Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Mon, 25 Apr 2022 09:17:42 +0200 Subject: [PATCH 14/21] working BCD FISTA gram with weights --- skglm/gram_solver.py | 6 +++--- skglm/solvers/gram.py | 23 ++++++++++------------- 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/skglm/gram_solver.py b/skglm/gram_solver.py index 498dbae9f..d22eb9aba 100644 --- a/skglm/gram_solver.py +++ b/skglm/gram_solver.py @@ -12,7 +12,7 @@ alpha_max = norm(X.T @ y, ord=np.inf) # Hyperparameters -max_iter = 1000 +max_iter = 10_000 tol = 1e-8 reg = 0.1 group_size = 3 @@ -53,13 +53,13 @@ w = gram_group_lasso(X, y, alpha, group_size, max_iter, tol, weights=weights_grp) gram_group_lasso_time = time() - start start = time() -w_fista = gram_fista_group_lasso(X, y, alpha, group_size, max_iter, tol, +w_fista = gram_fista_group_lasso(X, y, alpha, group_size, max_iter, tol, weights=weights_grp) gram_fista_group_lasso_time = time() - start np.testing.assert_allclose(w, w_fista, rtol=1e-4) -# clf_celer = GroupLasso(group_size, alpha, tol=tol, weights=weights_grp, +# clf_celer = GroupLasso(group_size, alpha, tol=tol, weights=weights_grp, # fit_intercept=False) # start = time() # clf_celer.fit(X, y) diff --git a/skglm/solvers/gram.py b/skglm/solvers/gram.py index a969e35c8..11cee868c 100644 --- a/skglm/solvers/gram.py +++ b/skglm/solvers/gram.py @@ -156,8 +156,8 @@ def gram_fista_lasso(X, y, alpha, max_iter, tol, w_init=None, weights=None, def gram_fista_group_lasso(X, y, alpha, groups, max_iter, tol, w_init=None, weights=None, check_freq=50): - p_obj_prev = np.inf n_features = X.shape[1] + norm_y2 = y @ y grp_ptr, grp_indices = _grp_converter(groups, X.shape[1]) n_groups = len(grp_ptr) - 1 @@ -168,31 +168,28 @@ def gram_fista_group_lasso(X, y, alpha, groups, max_iter, tol, w_init=None, w = w_init.copy() if w_init is not None else np.zeros(n_features) z = w_init.copy() if w_init is not None else np.zeros(n_features) weights = weights if weights is not None else np.ones(n_groups) - # tiled_weights = np.repeat(weights, grp_size) G = X.T @ X Xty = X.T @ y - lipschitz = np.zeros(n_groups, 
dtype=X.dtype) - for g in range(n_groups): - X_g = X[:, grp_indices[grp_ptr[g]:grp_ptr[g + 1]]] - lipschitz[g] = norm(X_g, ord=2) ** 2 / len(y) - tiled_lipschitz = np.repeat(lipschitz[np.newaxis, :], grp_size) + L = np.linalg.norm(X, ord=2) ** 2 / len(y) for n_iter in range(max_iter): t_old = t_new t_new = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2 w_old = w.copy() - z -= (G @ z - Xty) / tiled_lipschitz / len(y) - w = BST_vec(z, alpha / lipschitz, grp_size) + z -= (G @ z - Xty) / L / len(y) + w = BST_vec(z, alpha / L * weights, grp_size) z = w + (t_old - 1.) / t_new * (w - w_old) + if n_iter % check_freq == 0: - p_obj = primal_grp(alpha, y, X, w, grp_ptr, grp_indices, np.ones(n_groups)) - print(f"iter {n_iter} :: p_obj {p_obj:.5f}") - if p_obj_prev - p_obj < tol: + p_obj, d_obj, d_gap = dual_gap_grp(y, X, w, alpha, norm_y2, grp_ptr, + grp_indices, weights) + print(f"iter {n_iter} :: p_obj {p_obj:.5f} :: d_obj {d_obj:.5f} " + + f":: gap {d_gap:.5f}") + if d_gap < tol: print("Convergence reached!") break - p_obj_prev = p_obj return w From 1a6e6aacea527816ed3ea5a39c667bcff01a942b Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Mon, 25 Apr 2022 09:20:11 +0200 Subject: [PATCH 15/21] CLN --- skglm/gram_solver.py | 42 +++++++++++++++++++++--------------------- skglm/solvers/gram.py | 8 ++++---- 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/skglm/gram_solver.py b/skglm/gram_solver.py index d22eb9aba..6463dc68c 100644 --- a/skglm/gram_solver.py +++ b/skglm/gram_solver.py @@ -23,27 +23,27 @@ weights_grp = np.random.normal(2, 0.4, n_features // group_size) # Lasso -# print("#" * 15) -# print("Lasso") -# print("#" * 15) -# start = time() -# w = gram_lasso(X, y, alpha, max_iter, tol, weights=weights) -# gram_lasso_time = time() - start -# clf_sk = Lasso(alpha, weights=weights, tol=tol, fit_intercept=False) -# start = time() -# clf_sk.fit(X, y) -# celer_lasso_time = time() - start -# start = time() -# w_fista = gram_fista_lasso(X, y, alpha, max_iter, tol, weights=weights) -# gram_fista_lasso_time = time() - start -# np.testing.assert_allclose(w, clf_sk.coef_, rtol=1e-4) -# np.testing.assert_allclose(w, w_fista, rtol=1e-4) +print("#" * 15) +print("Lasso") +print("#" * 15) +start = time() +w = gram_lasso(X, y, alpha, max_iter, tol, weights=weights) +gram_lasso_time = time() - start +clf_sk = Lasso(alpha, weights=weights, tol=tol, fit_intercept=False) +start = time() +clf_sk.fit(X, y) +celer_lasso_time = time() - start +start = time() +w_fista = gram_fista_lasso(X, y, alpha, max_iter, tol, weights=weights) +gram_fista_lasso_time = time() - start +np.testing.assert_allclose(w, clf_sk.coef_, rtol=1e-4) +np.testing.assert_allclose(w, w_fista, rtol=1e-4) -# print("\n") -# print("Celer: %.2f" % celer_lasso_time) -# print("Gram: %.2f" % gram_lasso_time) -# print("FISTA Gram: %.2f" % gram_fista_lasso_time) -# print("\n") +print("\n") +print("Celer: %.2f" % celer_lasso_time) +print("CD Gram: %.2f" % gram_lasso_time) +print("FISTA Gram: %.2f" % gram_fista_lasso_time) +print("\n") # Group Lasso print("#" * 15) @@ -68,6 +68,6 @@ print("\n") # print("Celer: %.2f" % celer_group_lasso_time) -print("Gram: %.2f" % gram_group_lasso_time) +print("BCD Gram: %.2f" % gram_group_lasso_time) print("FISTA Gram: %.2f" % gram_fista_group_lasso_time) print("\n") diff --git a/skglm/solvers/gram.py b/skglm/solvers/gram.py index 11cee868c..c6ba533f6 100644 --- a/skglm/solvers/gram.py +++ b/skglm/solvers/gram.py @@ -101,7 +101,7 @@ def compute_lipschitz(X, y): return lipschitz -def gram_lasso(X, y, alpha, 
max_iter, tol, w_init=None, weights=None, check_freq=10): +def gram_lasso(X, y, alpha, max_iter, tol, w_init=None, weights=None, check_freq=100): n_features = X.shape[1] norm_y2 = y @ y grads = X.T @ y / len(y) @@ -123,7 +123,7 @@ def gram_lasso(X, y, alpha, max_iter, tol, w_init=None, weights=None, check_freq def gram_fista_lasso(X, y, alpha, max_iter, tol, w_init=None, weights=None, - check_freq=10): + check_freq=100): n_samples, n_features = X.shape norm_y2 = y @ y t_new = 1 @@ -155,7 +155,7 @@ def gram_fista_lasso(X, y, alpha, max_iter, tol, w_init=None, weights=None, def gram_fista_group_lasso(X, y, alpha, groups, max_iter, tol, w_init=None, - weights=None, check_freq=50): + weights=None, check_freq=100): n_features = X.shape[1] norm_y2 = y @ y @@ -194,7 +194,7 @@ def gram_fista_group_lasso(X, y, alpha, groups, max_iter, tol, w_init=None, def gram_group_lasso(X, y, alpha, groups, max_iter, tol, w_init=None, weights=None, - check_freq=50): + check_freq=100): n_features = X.shape[1] grp_ptr, grp_indices = _grp_converter(groups, X.shape[1]) n_groups = len(grp_ptr) - 1 From a2cf8a7999674629b838133640a902103a1295fe Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Mon, 25 Apr 2022 09:26:06 +0200 Subject: [PATCH 16/21] CLN --- skglm/solvers/gram.py | 63 +++++++++++++++++-------------------------- 1 file changed, 25 insertions(+), 38 deletions(-) diff --git a/skglm/solvers/gram.py b/skglm/solvers/gram.py index c6ba533f6..9a87e10c5 100644 --- a/skglm/solvers/gram.py +++ b/skglm/solvers/gram.py @@ -109,7 +109,6 @@ def gram_lasso(X, y, alpha, max_iter, tol, w_init=None, weights=None, check_freq lipschitz = compute_lipschitz(X, y) w = w_init.copy() if w_init is not None else np.zeros(n_features) weights = weights if weights is not None else np.ones(n_features) - # CD for n_iter in range(max_iter): cd_epoch(X, G, grads, w, alpha, lipschitz, weights) if n_iter % check_freq == 0: @@ -127,15 +126,12 @@ def gram_fista_lasso(X, y, alpha, max_iter, tol, w_init=None, weights=None, n_samples, n_features = X.shape norm_y2 = y @ y t_new = 1 - w = w_init.copy() if w_init is not None else np.zeros(n_features) z = w_init.copy() if w_init is not None else np.zeros(n_features) weights = weights if weights is not None else np.ones(n_features) - G = X.T @ X Xty = X.T @ y L = np.linalg.norm(X, ord=2) ** 2 / n_samples - for n_iter in range(max_iter): t_old = t_new t_new = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2 @@ -143,7 +139,6 @@ def gram_fista_lasso(X, y, alpha, max_iter, tol, w_init=None, weights=None, z -= (G @ z - Xty) / L / n_samples w = ST_vec(z, alpha / L * weights) z = w + (t_old - 1.) 
/ t_new * (w - w_old) - if n_iter % check_freq == 0: p_obj, d_obj, d_gap = dual_gap(alpha, norm_y2, y, X, w, weights) print(f"iter {n_iter} :: p_obj {p_obj:.5f} :: d_obj {d_obj:.5f} " + @@ -154,34 +149,22 @@ def gram_fista_lasso(X, y, alpha, max_iter, tol, w_init=None, weights=None, return w -def gram_fista_group_lasso(X, y, alpha, groups, max_iter, tol, w_init=None, - weights=None, check_freq=100): +def gram_group_lasso(X, y, alpha, groups, max_iter, tol, w_init=None, weights=None, + check_freq=100): n_features = X.shape[1] - norm_y2 = y @ y - grp_ptr, grp_indices = _grp_converter(groups, X.shape[1]) n_groups = len(grp_ptr) - 1 - grp_size = n_features // n_groups - - t_new = 1 - + norm_y2 = y @ y + grads = X.T @ y / len(y) + G = X.T @ X + lipschitz = np.zeros(n_groups, dtype=X.dtype) + for g in range(n_groups): + X_g = X[:, grp_indices[grp_ptr[g]:grp_ptr[g + 1]]] + lipschitz[g] = norm(X_g, ord=2) ** 2 / len(y) w = w_init.copy() if w_init is not None else np.zeros(n_features) - z = w_init.copy() if w_init is not None else np.zeros(n_features) weights = weights if weights is not None else np.ones(n_groups) - - G = X.T @ X - Xty = X.T @ y - - L = np.linalg.norm(X, ord=2) ** 2 / len(y) - for n_iter in range(max_iter): - t_old = t_new - t_new = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2 - w_old = w.copy() - z -= (G @ z - Xty) / L / len(y) - w = BST_vec(z, alpha / L * weights, grp_size) - z = w + (t_old - 1.) / t_new * (w - w_old) - + bcd_epoch(X, G, grads, w, alpha, lipschitz, grp_indices, grp_ptr, weights) if n_iter % check_freq == 0: p_obj, d_obj, d_gap = dual_gap_grp(y, X, w, alpha, norm_y2, grp_ptr, grp_indices, weights) @@ -193,23 +176,27 @@ def gram_fista_group_lasso(X, y, alpha, groups, max_iter, tol, w_init=None, return w -def gram_group_lasso(X, y, alpha, groups, max_iter, tol, w_init=None, weights=None, - check_freq=100): +def gram_fista_group_lasso(X, y, alpha, groups, max_iter, tol, w_init=None, + weights=None, check_freq=100): n_features = X.shape[1] + norm_y2 = y @ y grp_ptr, grp_indices = _grp_converter(groups, X.shape[1]) n_groups = len(grp_ptr) - 1 - norm_y2 = y @ y - grads = X.T @ y / len(y) - G = X.T @ X - lipschitz = np.zeros(n_groups, dtype=X.dtype) - for g in range(n_groups): - X_g = X[:, grp_indices[grp_ptr[g]:grp_ptr[g + 1]]] - lipschitz[g] = norm(X_g, ord=2) ** 2 / len(y) + grp_size = n_features // n_groups + t_new = 1 w = w_init.copy() if w_init is not None else np.zeros(n_features) + z = w_init.copy() if w_init is not None else np.zeros(n_features) weights = weights if weights is not None else np.ones(n_groups) - # BCD + G = X.T @ X + Xty = X.T @ y + L = np.linalg.norm(X, ord=2) ** 2 / len(y) for n_iter in range(max_iter): - bcd_epoch(X, G, grads, w, alpha, lipschitz, grp_indices, grp_ptr, weights) + t_old = t_new + t_new = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2 + w_old = w.copy() + z -= (G @ z - Xty) / L / len(y) + w = BST_vec(z, alpha / L * weights, grp_size) + z = w + (t_old - 1.) 
/ t_new * (w - w_old) if n_iter % check_freq == 0: p_obj, d_obj, d_gap = dual_gap_grp(y, X, w, alpha, norm_y2, grp_ptr, grp_indices, weights) From bfae08829c63e0dd63296f8dce57ba23b262b6ca Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Mon, 25 Apr 2022 10:24:51 +0200 Subject: [PATCH 17/21] fix primal comp --- skglm/solvers/gram.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/skglm/solvers/gram.py b/skglm/solvers/gram.py index 9a87e10c5..26d0b4b15 100644 --- a/skglm/solvers/gram.py +++ b/skglm/solvers/gram.py @@ -8,14 +8,22 @@ @njit def primal(alpha, r, w, weights): + n_features = len(weights) p_obj = (r @ r) / (2 * len(r)) - return p_obj + alpha * np.sum(np.abs(w * weights)) + pen = 0. + for j in range(n_features): + if weights[j] == np.inf: + continue + pen += np.abs(w[j] * weights[j]) + return p_obj + alpha * pen @njit def primal_grp(alpha, norm_r2, r, w, grp_ptr, grp_indices, weights): p_obj = norm_r2 / (2 * len(r)) for g in range(len(grp_ptr) - 1): + if weights[g] == np.inf: + continue w_g = w[grp_indices[grp_ptr[g]:grp_ptr[g + 1]]] p_obj += alpha * norm(w_g * weights[g], ord=2) return p_obj From f7bda722b3ba079b5915cf59a8e7b76317614916 Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Mon, 25 Apr 2022 10:39:34 +0200 Subject: [PATCH 18/21] FIX prox_L21 for variable size groups --- skglm/solvers/gram.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/skglm/solvers/gram.py b/skglm/solvers/gram.py index 26d0b4b15..79a5a8929 100644 --- a/skglm/solvers/gram.py +++ b/skglm/solvers/gram.py @@ -3,7 +3,7 @@ from numpy.linalg import norm from celer.homotopy import _grp_converter -from skglm.utils import BST, ST, BST_vec, ST_vec +from skglm.utils import BST, ST, ST_vec @njit @@ -109,6 +109,18 @@ def compute_lipschitz(X, y): return lipschitz +@njit +def prox_l21(w, u, weights, grp_ptr, grp_indices): + n_groups = len(grp_ptr) - 1 + out = w.copy() + for g in range(n_groups): + idx = grp_indices[grp_ptr[g]:grp_ptr[g + 1]] + grp_nrm = norm(w[idx], ord=2) + scaling = np.maximum(1 - u / grp_nrm * weights[g], 0) + out[idx] *= scaling + return out + + def gram_lasso(X, y, alpha, max_iter, tol, w_init=None, weights=None, check_freq=100): n_features = X.shape[1] norm_y2 = y @ y @@ -190,7 +202,6 @@ def gram_fista_group_lasso(X, y, alpha, groups, max_iter, tol, w_init=None, norm_y2 = y @ y grp_ptr, grp_indices = _grp_converter(groups, X.shape[1]) n_groups = len(grp_ptr) - 1 - grp_size = n_features // n_groups t_new = 1 w = w_init.copy() if w_init is not None else np.zeros(n_features) z = w_init.copy() if w_init is not None else np.zeros(n_features) @@ -203,7 +214,7 @@ def gram_fista_group_lasso(X, y, alpha, groups, max_iter, tol, w_init=None, t_new = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2 w_old = w.copy() z -= (G @ z - Xty) / L / len(y) - w = BST_vec(z, alpha / L * weights, grp_size) + w = prox_l21(z, alpha / L, weights, grp_ptr, grp_indices) z = w + (t_old - 1.) 
/ t_new * (w - w_old) if n_iter % check_freq == 0: p_obj, d_obj, d_gap = dual_gap_grp(y, X, w, alpha, norm_y2, grp_ptr, From f8e4cf051c1bdc0d2957f9da4615e311bdb83147 Mon Sep 17 00:00:00 2001 From: pabannier Date: Mon, 25 Apr 2022 10:48:34 +0200 Subject: [PATCH 19/21] better precision and working group lasso celer --- skglm/gram_solver.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/skglm/gram_solver.py b/skglm/gram_solver.py index 6463dc68c..57301a4f5 100644 --- a/skglm/gram_solver.py +++ b/skglm/gram_solver.py @@ -59,15 +59,15 @@ np.testing.assert_allclose(w, w_fista, rtol=1e-4) -# clf_celer = GroupLasso(group_size, alpha, tol=tol, weights=weights_grp, -# fit_intercept=False) -# start = time() -# clf_celer.fit(X, y) -# celer_group_lasso_time = time() - start -# np.testing.assert_allclose(w, clf_celer.coef_, rtol=1e-1) +clf_celer = GroupLasso(group_size, alpha, tol=tol, weights=weights_grp, + fit_intercept=False) +start = time() +clf_celer.fit(X, y) +celer_group_lasso_time = time() - start +np.testing.assert_allclose(w, clf_celer.coef_, rtol=1e-4) print("\n") -# print("Celer: %.2f" % celer_group_lasso_time) +print("Celer: %.2f" % celer_group_lasso_time) print("BCD Gram: %.2f" % gram_group_lasso_time) print("FISTA Gram: %.2f" % gram_fista_group_lasso_time) print("\n") From 5c0a33a8ea7f39152e7b9f47845b2e7c1c18b2b0 Mon Sep 17 00:00:00 2001 From: pabannier Date: Mon, 25 Apr 2022 10:54:36 +0200 Subject: [PATCH 20/21] example with weights --- gram_test.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/gram_test.py b/gram_test.py index e3d25e69b..57a4037c1 100644 --- a/gram_test.py +++ b/gram_test.py @@ -12,19 +12,19 @@ y = np.load("target.npy") groups = np.load("groups.npy") weights = np.load("weights.npy") -# grps = [list(np.where(groups == i)[0]) for i in range(1, 33)] +grps = [list(np.where(groups == i)[0]) for i in range(1, 33)] alpha_ratio = 1e-2 n_alphas = 10 - +tol = 1e-8 # Case 1: slower runtime for (very) small alphas # alpha_max = 0.003471727067743962 alpha_max = np.max(np.linalg.norm((X.T @ y).reshape(-1, 5), axis=1)) / len(y) alpha = alpha_max / 100 -clf = GroupLasso(fit_intercept=False, - groups=5, alpha=alpha, verbose=1) +clf = GroupLasso(fit_intercept=False, tol=tol, + groups=grps, weights=weights, alpha=alpha, verbose=1) t0 = time.time() clf.fit(X, y) @@ -32,12 +32,9 @@ print(f"Celer: {t1 - t0:.3f} s") -# beware: stopping criterion is not the same, tol here needs to be lower -# to get meaningful comparison t0 = time.time() -res = group_lasso(X, y, alpha, groups=5, tol=1e-10, max_iter=10_000, check_freq=10) +res = gram_group_lasso(X, y, alpha, groups=grps, tol=tol, weights=weights, max_iter=10_000, + check_freq=50) t1 = time.time() print(f"skglm gram: {t1 - t0:.3f} s") - -# TODO support weights in gram solver From ae92334d2b685e1189dc04fa1285e33b715e9894 Mon Sep 17 00:00:00 2001 From: pabannier Date: Mon, 25 Apr 2022 11:03:22 +0200 Subject: [PATCH 21/21] ENH fista to example --- gram_test.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/gram_test.py b/gram_test.py index 57a4037c1..1fcab4585 100644 --- a/gram_test.py +++ b/gram_test.py @@ -2,11 +2,9 @@ import time -from numpy.linalg import norm -import matplotlib.pyplot as plt import numpy as np from celer import GroupLasso -from skglm.solvers.gram import gram_group_lasso +from skglm.solvers.gram import gram_fista_group_lasso, gram_group_lasso X = np.load("design_matrix.npy") y = np.load("target.npy") @@ -19,7 +17,7 
@@ -19,7 +17,7 @@
 n_alphas = 10
 tol = 1e-8
 
-# Case 1: slower runtime for (very) small alphas
+# Case 1: slower runtime for small alphas
 # alpha_max = 0.003471727067743962
 alpha_max = np.max(np.linalg.norm((X.T @ y).reshape(-1, 5), axis=1)) / len(y)
 alpha = alpha_max / 100
@@ -38,3 +36,21 @@
 t1 = time.time()
 
 print(f"skglm gram: {t1 - t0:.3f} s")
+
+
+# FISTA Gram for very small alphas
+alpha = alpha_max / 1e4
+clf = GroupLasso(fit_intercept=False, tol=tol, groups=grps, weights=weights, alpha=alpha,
+                 verbose=1)
+
+t0 = time.time()
+clf.fit(X, y)
+t1 = time.time()
+
+print(f"Celer: {t1 - t0:.3f} s")
+
+t0 = time.time()
+res = gram_fista_group_lasso(X, y, alpha, groups=grps, tol=tol, weights=weights, max_iter=10_000,
+                             check_freq=50)
+t1 = time.time()
+print(f"skglm fista gram: {t1 - t0:.3f} s")
\ No newline at end of file
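
For reference, a minimal usage sketch of the Gram solvers introduced in this series,
assuming they end up importable from skglm.solvers.gram as in the patches above. The
synthetic data, the integer group size and the tolerance below are illustrative
choices, not part of the patches.

import numpy as np
from numpy.linalg import norm
from skglm.solvers.gram import gram_fista_group_lasso, gram_group_lasso

rng = np.random.default_rng(0)
n_samples, n_features, group_size = 500, 30, 3
X = rng.standard_normal((n_samples, n_features))
y = X @ rng.standard_normal(n_features) + 0.1 * rng.standard_normal(n_samples)

# group-wise alpha_max, computed as in gram_test.py (contiguous groups of size 3)
alpha_max = np.max(norm((X.T @ y).reshape(-1, group_size), axis=1)) / n_samples
alpha = alpha_max / 10
weights = np.ones(n_features // group_size)

# BCD on the Gram matrix, stopped on the duality gap
w_bcd = gram_group_lasso(X, y, alpha, group_size, 1000, 1e-8, weights=weights)
# FISTA on the Gram matrix, same objective and stopping criterion
w_fista = gram_fista_group_lasso(X, y, alpha, group_size, 1000, 1e-8, weights=weights)
print("max abs difference between BCD and FISTA solutions:",
      np.max(np.abs(w_bcd - w_fista)))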