Skip to content

Commit dc7e449

Browse files
authored
[MRG] Fix test failures due to updated packages: deprecated pytest.warns(None) syntax + GLasso update in sklearn (#357)
* Fix GLasso import for SDML for newer sklearn versions * fix import and argument issue * also fix deprecated pytest.warns(None) syntex * fix flake8
1 parent 8fb6872 commit dc7e449

File tree

5 files changed

+32
-22
lines changed

5 files changed

+32
-22
lines changed

metric_learn/sdml.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,13 @@
66
import numpy as np
77
from sklearn.base import TransformerMixin
88
from scipy.linalg import pinvh
9-
from sklearn.covariance import graphical_lasso
9+
try:
10+
from sklearn.covariance._graph_lasso import (
11+
_graphical_lasso as graphical_lasso
12+
)
13+
except ImportError:
14+
from sklearn.covariance import graphical_lasso
15+
1016
from sklearn.exceptions import ConvergenceWarning
1117

1218
from .base_metric import MahalanobisMixin, _PairsClassifierMixin
@@ -79,9 +85,9 @@ def _fit(self, pairs, y):
7985
msg=self.verbose,
8086
Theta0=theta0, Sigma0=sigma0)
8187
else:
82-
_, M = graphical_lasso(emp_cov, alpha=self.sparsity_param,
83-
verbose=self.verbose,
84-
cov_init=sigma0)
88+
_, M, *_ = graphical_lasso(emp_cov, alpha=self.sparsity_param,
89+
verbose=self.verbose,
90+
cov_init=sigma0)
8591
raised_error = None
8692
w_mahalanobis, _ = np.linalg.eigh(M)
8793
not_spd = any(w_mahalanobis < 0.)

test/metric_learn_test.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
import unittest
23
import re
34
import pytest
@@ -734,12 +735,12 @@ def test_raises_no_warning_installed_skggm(self):
734735
pairs = np.array([[[-10., 0.], [10., 0.]], [[0., -55.], [0., -60]]])
735736
y_pairs = [1, -1]
736737
X, y = make_classification(random_state=42)
737-
with pytest.warns(None) as records:
738+
with warnings.catch_warnings(record=True) as records:
738739
sdml = SDML(prior='covariance')
739740
sdml.fit(pairs, y_pairs)
740741
for record in records:
741742
assert record.category is not ConvergenceWarning
742-
with pytest.warns(None) as records:
743+
with warnings.catch_warnings(record=True) as records:
743744
sdml_supervised = SDML_Supervised(prior='identity', balance_param=1e-5)
744745
sdml_supervised.fit(X, y)
745746
for record in records:
@@ -999,7 +1000,7 @@ def test_rank_deficient_returns_warning(self):
9991000
'for instance using `sklearn.decomposition.PCA` as a '
10001001
'preprocessing step.')
10011002

1002-
with pytest.warns(None) as raised_warnings:
1003+
with warnings.catch_warnings(record=True) as raised_warnings:
10031004
rca.fit(X, y)
10041005
assert any(str(w.message) == msg for w in raised_warnings)
10051006

@@ -1034,7 +1035,7 @@ def test_bad_parameters(self):
10341035
'Increase the number or size of the chunks to correct '
10351036
'this problem.'
10361037
)
1037-
with pytest.warns(None) as raised_warning:
1038+
with warnings.catch_warnings(record=True) as raised_warning:
10381039
rca.fit(X, y)
10391040
assert any(str(w.message) == msg for w in raised_warning)
10401041

test/test_base_metric.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from numpy.core.numeric import array_equal
2+
import warnings
23
import pytest
34
import re
45
import unittest
@@ -226,15 +227,15 @@ def test_get_metric_works_does_not_raise(estimator, build_dataset):
226227
(X[0][None], X[1][None])]
227228

228229
for u, v in list_test_get_metric_doesnt_raise:
229-
with pytest.warns(None) as record:
230+
with warnings.catch_warnings(record=True) as record:
230231
metric(u, v)
231232
assert len(record) == 0
232233

233234
# Test that the scalar case works
234235
model.components_ = np.array([3.1])
235236
metric = model.get_metric()
236237
for u, v in [(5, 6.7), ([5], [6.7]), ([[5]], [[6.7]])]:
237-
with pytest.warns(None) as record:
238+
with warnings.catch_warnings(record=True) as record:
238239
metric(u, v)
239240
assert len(record) == 0
240241

test/test_pairs_classifiers.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from functools import partial
22

3+
import warnings
34
import pytest
45
from numpy.testing import assert_array_equal
56
from scipy.spatial.distance import euclidean
@@ -136,7 +137,7 @@ def test_threshold_different_scores_is_finite(estimator, build_dataset,
136137
estimator.set_params(preprocessor=preprocessor)
137138
set_random_state(estimator)
138139
estimator.fit(input_data, labels)
139-
with pytest.warns(None) as record:
140+
with warnings.catch_warnings(record=True) as record:
140141
estimator.calibrate_threshold(input_data, labels, **kwargs)
141142
assert len(record) == 0
142143

@@ -383,7 +384,7 @@ def test_calibrate_threshold_valid_parameters(valid_args):
383384
pairs, y = rng.randn(20, 2, 5), rng.choice([-1, 1], size=20)
384385
pairs_learner = IdentityPairsClassifier()
385386
pairs_learner.fit(pairs, y)
386-
with pytest.warns(None) as record:
387+
with warnings.catch_warnings(record=True) as record:
387388
pairs_learner.calibrate_threshold(pairs, y, **valid_args)
388389
assert len(record) == 0
389390

@@ -518,7 +519,7 @@ def test_validate_calibration_params_valid_parameters(
518519
# test that no warning message is returned if valid arguments are given to
519520
# _validate_calibration_params for all pairs metric learners, as well as
520521
# a mocking example, and the class itself
521-
with pytest.warns(None) as record:
522+
with warnings.catch_warnings(record=True) as record:
522523
estimator._validate_calibration_params(**valid_args)
523524
assert len(record) == 0
524525

test/test_utils.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
import pytest
23
from scipy.linalg import eigh, pinvh
34
from collections import namedtuple
@@ -353,7 +354,7 @@ def test_check_tuples_valid_tuple_size(tuple_size):
353354
checks that checking the number of tuples (pairs, quadruplets, etc) raises
354355
no warning if there is the right number of points in a tuple.
355356
"""
356-
with pytest.warns(None) as record:
357+
with warnings.catch_warnings(record=True) as record:
357358
check_input(tuples_prep(), type_of_inputs='tuples',
358359
preprocessor=mock_preprocessor, tuple_size=tuple_size)
359360
check_input(tuples_no_prep(), type_of_inputs='tuples', preprocessor=None,
@@ -378,7 +379,7 @@ def test_check_tuples_valid_tuple_size(tuple_size):
378379
[[2.6, 2.3], [3.4, 5.0]]])])
379380
def test_check_tuples_valid_with_preprocessor(tuples):
380381
"""Test that valid inputs when using a preprocessor raises no warning"""
381-
with pytest.warns(None) as record:
382+
with warnings.catch_warnings(record=True) as record:
382383
check_input(tuples, type_of_inputs='tuples',
383384
preprocessor=mock_preprocessor)
384385
assert len(record) == 0
@@ -399,7 +400,7 @@ def test_check_tuples_valid_with_preprocessor(tuples):
399400
((3, 1), (4, 4), (29, 4)))])
400401
def test_check_tuples_valid_without_preprocessor(tuples):
401402
"""Test that valid inputs when using no preprocessor raises no warning"""
402-
with pytest.warns(None) as record:
403+
with warnings.catch_warnings(record=True) as record:
403404
check_input(tuples, type_of_inputs='tuples', preprocessor=None)
404405
assert len(record) == 0
405406

@@ -408,12 +409,12 @@ def test_check_tuples_behaviour_auto_dtype():
408409
"""Checks that check_tuples allows by default every type if using a
409410
preprocessor, and numeric types if using no preprocessor"""
410411
tuples_prep = [['img1.png', 'img2.png'], ['img3.png', 'img5.png']]
411-
with pytest.warns(None) as record:
412+
with warnings.catch_warnings(record=True) as record:
412413
check_input(tuples_prep, type_of_inputs='tuples',
413414
preprocessor=mock_preprocessor)
414415
assert len(record) == 0
415416

416-
with pytest.warns(None) as record:
417+
with warnings.catch_warnings(record=True) as record:
417418
check_input(tuples_no_prep(), type_of_inputs='tuples') # numeric type
418419
assert len(record) == 0
419420

@@ -549,7 +550,7 @@ def test_check_classic_invalid_dtype_not_convertible(preprocessor, points):
549550
[2.6, 2.3]])])
550551
def test_check_classic_valid_with_preprocessor(points):
551552
"""Test that valid inputs when using a preprocessor raises no warning"""
552-
with pytest.warns(None) as record:
553+
with warnings.catch_warnings(record=True) as record:
553554
check_input(points, type_of_inputs='classic',
554555
preprocessor=mock_preprocessor)
555556
assert len(record) == 0
@@ -570,7 +571,7 @@ def test_check_classic_valid_with_preprocessor(points):
570571
(3, 1, 4, 4, 29, 4))])
571572
def test_check_classic_valid_without_preprocessor(points):
572573
"""Test that valid inputs when using no preprocessor raises no warning"""
573-
with pytest.warns(None) as record:
574+
with warnings.catch_warnings(record=True) as record:
574575
check_input(points, type_of_inputs='classic', preprocessor=None)
575576
assert len(record) == 0
576577

@@ -585,12 +586,12 @@ def test_check_classic_behaviour_auto_dtype():
585586
"""Checks that check_input (for points) allows by default every type if
586587
using a preprocessor, and numeric types if using no preprocessor"""
587588
points_prep = ['img1.png', 'img2.png', 'img3.png', 'img5.png']
588-
with pytest.warns(None) as record:
589+
with warnings.catch_warnings(record=True) as record:
589590
check_input(points_prep, type_of_inputs='classic',
590591
preprocessor=mock_preprocessor)
591592
assert len(record) == 0
592593

593-
with pytest.warns(None) as record:
594+
with warnings.catch_warnings(record=True) as record:
594595
check_input(points_no_prep(), type_of_inputs='classic') # numeric type
595596
assert len(record) == 0
596597

0 commit comments

Comments
 (0)