diff --git a/torch_np/__init__.py b/torch_np/__init__.py index d2b3f539..2706d9d4 100644 --- a/torch_np/__init__.py +++ b/torch_np/__init__.py @@ -1,4 +1,4 @@ -from . import random +from . import linalg, random from ._binary_ufuncs import * from ._detail._util import AxisError, UFuncTypeError from ._dtypes import * diff --git a/torch_np/_funcs.py b/torch_np/_funcs.py index e230f319..7d1db54e 100644 --- a/torch_np/_funcs.py +++ b/torch_np/_funcs.py @@ -911,10 +911,6 @@ def array_equiv(a1: ArrayLike, a2: ArrayLike): return _tensor_equal(a1_t, a2_t) -def common_type(): - raise NotImplementedError - - def mintypecode(): raise NotImplementedError @@ -927,6 +923,10 @@ def asfarray(): raise NotImplementedError +def block(*args, **kwds): + raise NotImplementedError + + # ### put/take_along_axis ### @@ -1358,8 +1358,12 @@ def reshape(a: ArrayLike, newshape, order="C"): @normalizer def transpose(a: ArrayLike, axes=None): # numpy allows both .tranpose(sh) and .transpose(*sh) + # also older code uses axes being a list if axes in [(), None, (None,)]: axes = tuple(range(a.ndim))[::-1] + elif len(axes) == 1: + axes = axes[0] + try: result = a.permute(axes) except RuntimeError: @@ -1908,3 +1912,45 @@ def blackman(M): def bartlett(M): dtype = _dtypes_impl.default_float_dtype return torch.bartlett_window(M, periodic=False, dtype=dtype) + + +# ### Dtype routines ### + +# vendored from https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/type_check.py#L666 + + +array_type = [ + [torch.float16, torch.float32, torch.float64], + [None, torch.complex64, torch.complex128], +] +array_precision = { + torch.float16: 0, + torch.float32: 1, + torch.float64: 2, + torch.complex64: 1, + torch.complex128: 2, +} + + +@normalizer +def common_type(*tensors: ArrayLike): + + import builtins + + is_complex = False + precision = 0 + for a in tensors: + t = a.dtype + if iscomplexobj(a): + is_complex = True + if not (t.is_floating_point or t.is_complex): + p = 2 # array_precision[_nx.double] + else: + p = array_precision.get(t, None) + if p is None: + raise TypeError("can't get common type for non-numeric array") + precision = builtins.max(precision, p) + if is_complex: + return array_type[1][precision] + else: + return array_type[0][precision] diff --git a/torch_np/_ndarray.py b/torch_np/_ndarray.py index 08b0ed18..bf1a58ed 100644 --- a/torch_np/_ndarray.py +++ b/torch_np/_ndarray.py @@ -152,6 +152,14 @@ def copy(self, order="C"): tensor = self.tensor.clone() return ndarray(tensor) + def view(self, dtype): + torch_dtype = _dtypes.dtype(dtype).torch_dtype + tview = self.tensor.view(torch_dtype) + return ndarray(tview) + + def fill(self, value): + self.tensor.fill_(value) + def tolist(self): return self.tensor.tolist() diff --git a/torch_np/linalg.py b/torch_np/linalg.py new file mode 100644 index 00000000..17792970 --- /dev/null +++ b/torch_np/linalg.py @@ -0,0 +1,240 @@ +import functools +import math +from typing import Sequence + +import torch + +from ._detail import _dtypes_impl, _util +from ._normalizations import ArrayLike, normalizer + + +class LinAlgError(Exception): + pass + + +def _atleast_float_1(a): + if not (a.dtype.is_floating_point or a.dtype.is_complex): + a = a.to(_dtypes_impl.default_float_dtype) + return a + + +def _atleast_float_2(a, b): + dtyp = _dtypes_impl.result_type_impl((a.dtype, b.dtype)) + if not (dtyp.is_floating_point or dtyp.is_complex): + dtyp = _dtypes_impl.default_float_dtype + + a = _util.cast_if_needed(a, dtyp) + b = _util.cast_if_needed(b, dtyp) + return a, b + + +def linalg_errors(func): + @functools.wraps(func) + def wrapped(*args, **kwds): + try: + return func(*args, **kwds) + except torch._C._LinAlgError as e: + raise LinAlgError(*e.args) + + return wrapped + + +# ### Matrix and vector products ### + + +@normalizer +@linalg_errors +def matrix_power(a: ArrayLike, n): + a = _atleat_float_1(a) + return torch.linalg.matrix_power(a, n) + + +@normalizer +@linalg_errors +def multi_dot(inputs: Sequence[ArrayLike], *, out=None): + return torch.linalg.multi_dot(inputs) + + +# ### Solving equations and inverting matrices ### + + +@normalizer +@linalg_errors +def solve(a: ArrayLike, b: ArrayLike): + a, b = _atleast_float_2(a, b) + return torch.linalg.solve(a, b) + + +@normalizer +@linalg_errors +def lstsq(a: ArrayLike, b: ArrayLike, rcond=None): + a, b = _atleast_float_2(a, b) + # NumPy is using gelsd: https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/umath_linalg.cpp#L3991 + # on CUDA, only `gels` is available though, so use it instead + driver = "gels" if a.is_cuda or b.is_cuda else "gelsd" + return torch.linalg.lstsq(a, b, rcond=rcond, driver=driver) + + +@normalizer +@linalg_errors +def inv(a: ArrayLike): + a = _atleast_float_1(a) + result = torch.linalg.inv(a) + return result + + +@normalizer +@linalg_errors +def pinv(a: ArrayLike, rcond=1e-15, hermitian=False): + a = _atleast_float_1(a) + return torch.linalg.pinv(a, rtol=rcond, hermitian=hermitian) + + +@normalizer +@linalg_errors +def tensorsolve(a: ArrayLike, b: ArrayLike, axes=None): + a, b = _atleast_float_2(a, b) + return torch.linalg.tensorsolve(a, b, dims=axes) + + +@normalizer +@linalg_errors +def tensorinv(a: ArrayLike, ind=2): + a = _atleast_float_1(a) + return torch.linalg.tensorinv(a, ind=ind) + + +# ### Norms and other numbers ### + + +@normalizer +@linalg_errors +def det(a: ArrayLike): + a = _atleast_float_1(a) + return torch.linalg.det(a) + + +@normalizer +@linalg_errors +def slogdet(a: ArrayLike): + a = _atleast_float_1(a) + return torch.linalg.slogdet(a) + + +@normalizer +@linalg_errors +def cond(x: ArrayLike, p=None): + x = _atleast_float_1(x) + + # check if empty + # cf: https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/linalg.py#L1744 + if x.numel() == 0 and math.prod(x.shape[-2:]) == 0: + raise LinAlgError("cond is not defined on empty arrays") + + result = torch.linalg.cond(x, p=p) + + # Convert nans to infs (numpy does it in a data-dependent way, depending on + # whether the input array has nans or not) + # XXX: NumPy does this: https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/linalg.py#L1744 + return torch.where(torch.isnan(result), float("inf"), result) + + +@normalizer +@linalg_errors +def matrix_rank(a: ArrayLike, tol=None, hermitian=False): + a = _atleast_float_1(a) + + if a.ndim < 2: + return int((a != 0).any()) + + if tol is None: + # follow https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/linalg.py#L1885 + atol = 0 + rtol = max(a.shape[-2:]) * torch.finfo(a.dtype).eps + else: + atol, rtol = tol, 0 + return torch.linalg.matrix_rank(a, atol=atol, rtol=rtol, hermitian=hermitian) + + +@normalizer +@linalg_errors +def norm(x: ArrayLike, ord=None, axis=None, keepdims=False): + x = _atleast_float_1(x) + result = torch.linalg.norm(x, ord=ord, dim=axis) + if keepdims: + result = _util.apply_keepdims(result, axis, x.ndim) + return result + + +# ### Decompositions ### + + +@normalizer +@linalg_errors +def cholesky(a: ArrayLike): + a = _atleast_float_1(a) + return torch.linalg.cholesky(a) + + +@normalizer +@linalg_errors +def qr(a: ArrayLike, mode="reduced"): + a = _atleast_float_1(a) + result = torch.linalg.qr(a, mode=mode) + if mode == "r": + # match NumPy + result = result.R + return result + + +@normalizer +@linalg_errors +def svd(a: ArrayLike, full_matrices=True, compute_uv=True, hermitian=False): + a = _atleast_float_1(a) + if not compute_uv: + return torch.linalg.svdvals(a) + + # NB: ignore the hermitian= argument (no pytorch equivalent) + result = torch.linalg.svd(a, full_matrices=full_matrices) + return result + + +# ### Eigenvalues and eigenvectors ### + + +@normalizer +@linalg_errors +def eig(a: ArrayLike): + a = _atleast_float_1(a) + w, vt = torch.linalg.eig(a) + + if not a.is_complex(): + if w.is_complex() and (w.imag == 0).all(): + w = w.real + vt = vt.real + return w, vt + + +@normalizer +@linalg_errors +def eigh(a: ArrayLike, UPLO="L"): + a = _atleast_float_1(a) + return torch.linalg.eigh(a, UPLO=UPLO) + + +@normalizer +@linalg_errors +def eigvals(a: ArrayLike): + a = _atleast_float_1(a) + result = torch.linalg.eigvals(a) + if not a.is_complex(): + if result.is_complex() and (result.imag == 0).all(): + result = result.real + return result + + +@normalizer +@linalg_errors +def eigvalsh(a: ArrayLike, UPLO="L"): + a = _atleast_float_1(a) + return torch.linalg.eigvalsh(a, UPLO=UPLO) diff --git a/torch_np/tests/numpy_tests/linalg/test_linalg.py b/torch_np/tests/numpy_tests/linalg/test_linalg.py index 83595bb7..a5b51608 100644 --- a/torch_np/tests/numpy_tests/linalg/test_linalg.py +++ b/torch_np/tests/numpy_tests/linalg/test_linalg.py @@ -320,10 +320,12 @@ def test_empty_nonsq_cases(self): class HermitianTestCase(LinalgTestCase): + @pytest.mark.xfail(reason="zero-sized arrays") def test_herm_cases(self): self.check_cases(require={'hermitian'}, exclude={'generalized', 'size-0'}) + @pytest.mark.xfail(reason="zero-sized arrays") def test_empty_herm_cases(self): self.check_cases(require={'hermitian', 'size-0'}, exclude={'generalized'}) @@ -336,6 +338,7 @@ def test_generalized_sq_cases(self): self.check_cases(require={'generalized', 'square'}, exclude={'size-0'}) + @pytest.mark.xfail(reason="zero-size arrays") @pytest.mark.slow def test_generalized_empty_sq_cases(self): self.check_cases(require={'generalized', 'square', 'size-0'}) @@ -355,11 +358,13 @@ def test_generalized_empty_nonsq_cases(self): class HermitianGeneralizedTestCase(LinalgTestCase): + @pytest.mark.xfail(reason="sort complex") @pytest.mark.slow def test_generalized_herm_cases(self): self.check_cases(require={'generalized', 'hermitian'}, exclude={'size-0'}) + @pytest.mark.xfail(reason="zero-size arrays") @pytest.mark.slow def test_generalized_empty_herm_cases(self): self.check_cases(require={'generalized', 'hermitian', 'size-0'}, @@ -403,13 +408,13 @@ def do(self, a, b, tags): assert_(consistent_subclass(x, b)) -@pytest.mark.xfail(reason='TODO') class TestSolve(SolveCases): @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble]) def test_types(self, dtype): x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) assert_equal(linalg.solve(x, x).dtype, dtype) + @pytest.mark.xfail(reason="zero-sized arrays") def test_0_size(self): class ArraySubclass(np.ndarray): pass @@ -443,6 +448,7 @@ class ArraySubclass(np.ndarray): assert_raises(ValueError, linalg.solve, a[0:0], b[0:0]) assert_raises(ValueError, linalg.solve, a[:, 0:0, 0:0], b) + @pytest.mark.xfail(reason="zero-sized arrays") def test_0_size_k(self): # test zero multiple equation (K=0) case. class ArraySubclass(np.ndarray): @@ -471,13 +477,13 @@ def do(self, a, b, tags): assert_(consistent_subclass(a_inv, a)) -@pytest.mark.xfail(reason='TODO') class TestInv(InvCases): @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble]) def test_types(self, dtype): x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) assert_equal(linalg.inv(x).dtype, dtype) + @pytest.mark.xfail(reason="zero-sized arrays") def test_0_size(self): # Check that all kinds of 0-sized arrays work class ArraySubclass(np.ndarray): @@ -503,7 +509,7 @@ def do(self, a, b, tags): assert_almost_equal(ev, evalues) -@pytest.mark.xfail(reason='TODO') + class TestEigvals(EigvalsCases): @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble]) def test_types(self, dtype): @@ -512,6 +518,7 @@ def test_types(self, dtype): x = np.array([[1, 0.5], [-1, 1]], dtype=dtype) assert_equal(linalg.eigvals(x).dtype, get_complex_dtype(dtype)) + @pytest.mark.xfail(reason="zero-sized arrays") def test_0_size(self): # Check that all kinds of 0-sized arrays work class ArraySubclass(np.ndarray): @@ -541,7 +548,6 @@ def do(self, a, b, tags): assert_(consistent_subclass(evectors, a)) -@pytest.mark.xfail(reason='TODO') class TestEig(EigCases): @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble]) def test_types(self, dtype): @@ -555,6 +561,7 @@ def test_types(self, dtype): assert_equal(w.dtype, get_complex_dtype(dtype)) assert_equal(v.dtype, get_complex_dtype(dtype)) + @pytest.mark.xfail(reason="zero-sized arrays") def test_0_size(self): # Check that all kinds of 0-sized arrays work class ArraySubclass(np.ndarray): @@ -603,7 +610,6 @@ def do(self, a, b, tags): assert_(consistent_subclass(vt, a)) -@pytest.mark.xfail(reason='TODO') class TestSVD(SVDCases, SVDBaseTests): def test_empty_identity(self): """ Empty input should put an identity matrix in u or vh """ @@ -634,12 +640,11 @@ def hermitian(mat): assert_almost_equal(np.matmul(u, hermitian(u)), np.broadcast_to(np.eye(u.shape[-1]), u.shape)) assert_almost_equal(np.matmul(vt, hermitian(vt)), np.broadcast_to(np.eye(vt.shape[-1]), vt.shape)) - assert_equal(np.sort(s)[..., ::-1], s) + assert_equal(np.sort(s), np.flip(s, -1)) assert_(consistent_subclass(u, a)) assert_(consistent_subclass(vt, a)) -@pytest.mark.xfail(reason='TODO') class TestSVDHermitian(SVDHermitianCases, SVDBaseTests): hermitian = True @@ -690,7 +695,6 @@ def do(self, a, b, tags): single_decimal=5, double_decimal=11) -@pytest.mark.xfail(reason='TODO') class TestCond(CondCases): def test_basic_nonsvd(self): # Smoketest the non-svd norms @@ -771,7 +775,6 @@ def do(self, a, b, tags): assert_(consistent_subclass(a_ginv, a)) -@pytest.mark.xfail(reason='TODO') class TestPinv(PinvCases): pass @@ -786,7 +789,6 @@ def do(self, a, b, tags): assert_(consistent_subclass(a_ginv, a)) -@pytest.mark.xfail(reason='TODO') class TestPinvHermitian(PinvHermitianCases): pass @@ -801,8 +803,8 @@ def do(self, a, b, tags): else: ad = asarray(a).astype(cdouble) ev = linalg.eigvals(ad) - assert_almost_equal(d, multiply.reduce(ev, axis=-1)) - assert_almost_equal(s * np.exp(ld), multiply.reduce(ev, axis=-1)) + assert_almost_equal(d, np.prod(ev, axis=-1)) + assert_almost_equal(s * np.exp(ld), np.prod(ev, axis=-1)) s = np.atleast_1d(s) ld = np.atleast_1d(ld) @@ -811,20 +813,20 @@ def do(self, a, b, tags): assert_equal(ld[~m], -inf) -@pytest.mark.xfail(reason='TODO') class TestDet(DetCases): def test_zero(self): + # NB: comment out tests of type(det) == double : we return zero-dim arrays assert_equal(linalg.det([[0.0]]), 0.0) - assert_equal(type(linalg.det([[0.0]])), double) + # assert_equal(type(linalg.det([[0.0]])), double) assert_equal(linalg.det([[0.0j]]), 0.0) - assert_equal(type(linalg.det([[0.0j]])), cdouble) + # assert_equal(type(linalg.det([[0.0j]])), cdouble) assert_equal(linalg.slogdet([[0.0]]), (0.0, -inf)) - assert_equal(type(linalg.slogdet([[0.0]])[0]), double) - assert_equal(type(linalg.slogdet([[0.0]])[1]), double) + # assert_equal(type(linalg.slogdet([[0.0]])[0]), double) + # assert_equal(type(linalg.slogdet([[0.0]])[1]), double) assert_equal(linalg.slogdet([[0.0j]]), (0.0j, -inf)) - assert_equal(type(linalg.slogdet([[0.0j]])[0]), cdouble) - assert_equal(type(linalg.slogdet([[0.0j]])[1]), double) + # assert_equal(type(linalg.slogdet([[0.0j]])[0]), cdouble) + # assert_equal(type(linalg.slogdet([[0.0j]])[1]), double) @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble]) def test_types(self, dtype): @@ -864,28 +866,29 @@ def do(self, a, b, tags): if m == 0: assert_((x == 0).all()) if m <= n: - assert_almost_equal(b, dot(a, x)) + assert_almost_equal(b, dot(a, x), single_decimal=5) assert_equal(rank, m) else: assert_equal(rank, n) - assert_almost_equal(sv, sv.__array_wrap__(s)) + # assert_almost_equal(sv, sv.__array_wrap__(s)) if rank == n and m > n: expect_resids = ( np.asarray(abs(np.dot(a, x) - b)) ** 2).sum(axis=0) expect_resids = np.asarray(expect_resids) if np.asarray(b).ndim == 1: - expect_resids.shape = (1,) + expect_resids = expect_resids.reshape(1,) assert_equal(residuals.shape, expect_resids.shape) else: - expect_resids = np.array([]).view(type(x)) - assert_almost_equal(residuals, expect_resids) + expect_resids = np.array([]) #.view(type(x)) + assert_almost_equal(residuals, expect_resids, single_decimal=5) assert_(np.issubdtype(residuals.dtype, np.floating)) assert_(consistent_subclass(x, b)) assert_(consistent_subclass(residuals, b)) -@pytest.mark.xfail(reason='TODO') class TestLstsq(LstsqCases): + + @pytest.mark.xfail(reason="Lstsq: we use the future default =None") def test_future_rcond(self): a = np.array([[0., 1., 0., 1., 2., 0.], [0., 2., 0., 0., 1., 0.], @@ -910,7 +913,7 @@ def test_future_rcond(self): (0, 4, 2), (4, 0, 1), (4, 0, 2), - (4, 2, 0), + # (4, 2, 0), # Intel MKL ERROR: Parameter 4 was incorrect on entry to DLALSD. (0, 0, 0) ]) def test_empty_a_b(self, m, n, n_rhs): @@ -934,7 +937,7 @@ def test_incompatible_dims(self): y = np.array([-1, 0.2, 0.9, 2.1, 3.3]) A = np.vstack([x, np.ones(len(x))]).T # with assert_raises_regex(LinAlgError, "Incompatible dimensions"): - with assert_raises(LinAlgError): + with assert_raises((RuntimeError, LinAlgError)): linalg.lstsq(A, y, rcond=None) @@ -1028,7 +1031,6 @@ def test_exceptions_not_invertible(self, dt): assert_raises(LinAlgError, matrix_power, mat, -1) -@pytest.mark.xfail(reason='TODO') class TestEigvalshCases(HermitianTestCase, HermitianGeneralizedTestCase): def do(self, a, b, tags): @@ -1043,7 +1045,6 @@ def do(self, a, b, tags): assert_allclose(ev2, evalues, rtol=get_rtol(ev.dtype)) -@pytest.mark.xfail(reason='TODO') class TestEigvalsh: @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble]) def test_types(self, dtype): @@ -1053,9 +1054,9 @@ def test_types(self, dtype): def test_invalid(self): x = np.array([[1, 0.5], [0.5, 1]], dtype=np.float32) - assert_raises(ValueError, np.linalg.eigvalsh, x, UPLO="lrong") - assert_raises(ValueError, np.linalg.eigvalsh, x, "lower") - assert_raises(ValueError, np.linalg.eigvalsh, x, "upper") + assert_raises((RuntimeError, ValueError), np.linalg.eigvalsh, x, UPLO="lrong") + assert_raises((RuntimeError, ValueError), np.linalg.eigvalsh, x, "lower") + assert_raises((RuntimeError, ValueError), np.linalg.eigvalsh, x, "upper") def test_UPLO(self): Klo = np.array([[0, 0], [1, 0]], dtype=np.double) @@ -1081,16 +1082,16 @@ def test_UPLO(self): def test_0_size(self): # Check that all kinds of 0-sized arrays work - class ArraySubclass(np.ndarray): - pass - a = np.zeros((0, 1, 1), dtype=np.int_).view(ArraySubclass) + # class ArraySubclass(np.ndarray): + # pass + a = np.zeros((0, 1, 1), dtype=np.int_) #.view(ArraySubclass) res = linalg.eigvalsh(a) assert_(res.dtype.type is np.float64) assert_equal((0, 1), res.shape) # This is just for documentation, it might make sense to change: assert_(isinstance(res, np.ndarray)) - a = np.zeros((0, 0), dtype=np.complex64).view(ArraySubclass) + a = np.zeros((0, 0), dtype=np.complex64) #.view(ArraySubclass) res = linalg.eigvalsh(a) assert_(res.dtype.type is np.float32) assert_equal((0,), res.shape) @@ -1098,7 +1099,6 @@ class ArraySubclass(np.ndarray): assert_(isinstance(res, np.ndarray)) -@pytest.mark.xfail(reason='TODO') class TestEighCases(HermitianTestCase, HermitianGeneralizedTestCase): def do(self, a, b, tags): @@ -1121,7 +1121,6 @@ def do(self, a, b, tags): rtol=get_rtol(ev.dtype), err_msg=repr(a)) -@pytest.mark.xfail(reason='TODO') class TestEigh: @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble]) def test_types(self, dtype): @@ -1132,9 +1131,9 @@ def test_types(self, dtype): def test_invalid(self): x = np.array([[1, 0.5], [0.5, 1]], dtype=np.float32) - assert_raises(ValueError, np.linalg.eigh, x, UPLO="lrong") - assert_raises(ValueError, np.linalg.eigh, x, "lower") - assert_raises(ValueError, np.linalg.eigh, x, "upper") + assert_raises((RuntimeError, ValueError), np.linalg.eigh, x, UPLO="lrong") + assert_raises((RuntimeError, ValueError), np.linalg.eigh, x, "lower") + assert_raises((RuntimeError, ValueError), np.linalg.eigh, x, "upper") def test_UPLO(self): Klo = np.array([[0, 0], [1, 0]], dtype=np.double) @@ -1160,9 +1159,9 @@ def test_UPLO(self): def test_0_size(self): # Check that all kinds of 0-sized arrays work - class ArraySubclass(np.ndarray): - pass - a = np.zeros((0, 1, 1), dtype=np.int_).view(ArraySubclass) +# class ArraySubclass(np.ndarray): +# pass + a = np.zeros((0, 1, 1), dtype=np.int_) #.view(ArraySubclass) res, res_v = linalg.eigh(a) assert_(res_v.dtype.type is np.float64) assert_(res.dtype.type is np.float64) @@ -1171,7 +1170,7 @@ class ArraySubclass(np.ndarray): # This is just for documentation, it might make sense to change: assert_(isinstance(a, np.ndarray)) - a = np.zeros((0, 0), dtype=np.complex64).view(ArraySubclass) + a = np.zeros((0, 0), dtype=np.complex64) #.view(ArraySubclass) res, res_v = linalg.eigh(a) assert_(res_v.dtype.type is np.complex64) assert_(res.dtype.type is np.float32) @@ -1212,6 +1211,9 @@ def test_vector_return_type(self): for each_type in all_types: at = a.astype(each_type) + if each_type == np.dtype('float16'): + pytest.xfail('float16**float64 => float64 (?)') + an = norm(at, -np.inf) self.check_dtype(at, an) assert_almost_equal(an, 0.0) @@ -1277,7 +1279,7 @@ def test_axis(self): # Compare the use of `axis` with computing the norm of each row # or column separately. A = array([[1, 2, 3], [4, 5, 6]], dtype=self.dt) - for order in [None, -1, 0, 1, 2, 3, np.Inf, -np.Inf]: + for order in [None, -1, 0, 1, 2, 3, np.inf, -np.inf]: expected0 = [norm(A[:, k], ord=order) for k in range(A.shape[1])] assert_almost_equal(norm(A, ord=order, axis=0), expected0) expected1 = [norm(A[k, :], ord=order) for k in range(A.shape[0])] @@ -1286,7 +1288,7 @@ def test_axis(self): # Matrix norms. B = np.arange(1, 25, dtype=self.dt).reshape(2, 3, 4) nd = B.ndim - for order in [None, -2, 2, -1, 1, np.Inf, -np.Inf, 'fro']: + for order in [None, -2, 2, -1, 1, np.inf, -np.inf, 'fro']: for axis in itertools.combinations(range(-nd, nd), 2): row_axis, col_axis = axis if row_axis < 0: @@ -1294,7 +1296,7 @@ def test_axis(self): if col_axis < 0: col_axis += nd if row_axis == col_axis: - assert_raises(ValueError, norm, B, ord=order, axis=axis) + assert_raises((RuntimeError, ValueError), norm, B, ord=order, axis=axis) else: n = norm(B, ord=order, axis=axis) @@ -1325,7 +1327,7 @@ def test_keepdims(self): shape_err.format(found.shape, expected_shape, None, None)) # Vector norms. - for order in [None, -1, 0, 1, 2, 3, np.Inf, -np.Inf]: + for order in [None, -1, 0, 1, 2, 3, np.inf, -np.inf]: for k in range(A.ndim): expected = norm(A, ord=order, axis=k) found = norm(A, ord=order, axis=k, keepdims=True) @@ -1338,7 +1340,7 @@ def test_keepdims(self): shape_err.format(found.shape, expected_shape, order, k)) # Matrix norms. - for order in [None, -2, 2, -1, 1, np.Inf, -np.Inf, 'fro', 'nuc']: + for order in [None, -2, 2, -1, 1, np.inf, -np.inf, 'fro', 'nuc']: for k in itertools.permutations(range(A.ndim), 2): expected = norm(A, ord=order, axis=k) found = norm(A, ord=order, axis=k, keepdims=True) @@ -1355,13 +1357,12 @@ def test_keepdims(self): class _TestNorm2D(_TestNormBase): # Define the part for 2d arrays separately, so we can subclass this # and run the tests using np.matrix in matrixlib.tests.test_matrix_linalg. - array = np.array def test_matrix_empty(self): - assert_equal(norm(self.array([[]], dtype=self.dt)), 0.0) + assert_equal(norm(np.array([[]], dtype=self.dt)), 0.0) def test_matrix_return_type(self): - a = self.array([[1, 0, 1], [0, 1, 1]]) + a = np.array([[1, 0, 1], [0, 1, 1]]) exact_types = np.typecodes['AllInteger'] @@ -1412,7 +1413,7 @@ def test_matrix_return_type(self): np.testing.assert_almost_equal(an, 2.7320508075688772, decimal=6) def test_matrix_2x2(self): - A = self.array([[1, 3], [5, 7]], dtype=self.dt) + A = np.array([[1, 3], [5, 7]], dtype=self.dt) assert_almost_equal(norm(A), 84 ** 0.5) assert_almost_equal(norm(A, 'fro'), 84 ** 0.5) assert_almost_equal(norm(A, 'nuc'), 10.0) @@ -1423,9 +1424,9 @@ def test_matrix_2x2(self): assert_almost_equal(norm(A, 2), 9.1231056256176615) assert_almost_equal(norm(A, -2), 0.87689437438234041) - assert_raises(ValueError, norm, A, 'nofro') - assert_raises(ValueError, norm, A, -3) - assert_raises(ValueError, norm, A, 0) + assert_raises((RuntimeError, ValueError), norm, A, 'nofro') + assert_raises((RuntimeError, ValueError), norm, A, -3) + assert_raises((RuntimeError, ValueError), norm, A, 0) def test_matrix_3x3(self): # This test has been added because the 2x2 example @@ -1433,7 +1434,7 @@ def test_matrix_3x3(self): # The 1/10 scaling factor accommodates the absolute tolerance # used in assert_almost_equal. A = (1 / 10) * \ - self.array([[1, 2, 3], [6, 0, 5], [3, 2, 1]], dtype=self.dt) + np.array([[1, 2, 3], [6, 0, 5], [3, 2, 1]], dtype=self.dt) assert_almost_equal(norm(A), (1 / 10) * 89 ** 0.5) assert_almost_equal(norm(A, 'fro'), (1 / 10) * 89 ** 0.5) assert_almost_equal(norm(A, 'nuc'), 1.3366836911774836) @@ -1447,62 +1448,43 @@ def test_matrix_3x3(self): def test_bad_args(self): # Check that bad arguments raise the appropriate exceptions. - A = self.array([[1, 2, 3], [4, 5, 6]], dtype=self.dt) + A = np.array([[1, 2, 3], [4, 5, 6]], dtype=self.dt) B = np.arange(1, 25, dtype=self.dt).reshape(2, 3, 4) # Using `axis=` or passing in a 1-D array implies vector # norms are being computed, so also using `ord='fro'` # or `ord='nuc'` or any other string raises a ValueError. - assert_raises(ValueError, norm, A, 'fro', 0) - assert_raises(ValueError, norm, A, 'nuc', 0) - assert_raises(ValueError, norm, [3, 4], 'fro', None) - assert_raises(ValueError, norm, [3, 4], 'nuc', None) - assert_raises(ValueError, norm, [3, 4], 'test', None) + assert_raises((RuntimeError, ValueError), norm, A, 'fro', 0) + assert_raises((RuntimeError, ValueError), norm, A, 'nuc', 0) + assert_raises((RuntimeError, ValueError), norm, [3, 4], 'fro', None) + assert_raises((RuntimeError, ValueError), norm, [3, 4], 'nuc', None) + assert_raises((RuntimeError, ValueError), norm, [3, 4], 'test', None) # Similarly, norm should raise an exception when ord is any finite # number other than 1, 2, -1 or -2 when computing matrix norms. for order in [0, 3]: - assert_raises(ValueError, norm, A, order, None) - assert_raises(ValueError, norm, A, order, (0, 1)) - assert_raises(ValueError, norm, B, order, (1, 2)) + assert_raises((RuntimeError, ValueError), norm, A, order, None) + assert_raises((RuntimeError, ValueError), norm, A, order, (0, 1)) + assert_raises((RuntimeError, ValueError), norm, B, order, (1, 2)) # Invalid axis - assert_raises(np.AxisError, norm, B, None, 3) - assert_raises(np.AxisError, norm, B, None, (2, 3)) - assert_raises(ValueError, norm, B, None, (0, 1, 2)) + assert_raises((IndexError, np.AxisError), norm, B, None, 3) + assert_raises((IndexError, np.AxisError), norm, B, None, (2, 3)) + assert_raises((RuntimeError, ValueError), norm, B, None, (0, 1, 2)) class _TestNorm(_TestNorm2D, _TestNormGeneral): pass -@pytest.mark.xfail(reason='TODO') class TestNorm_NonSystematic: - def test_longdouble_norm(self): - # Non-regression test: p-norm of longdouble would previously raise - # UnboundLocalError. - x = np.arange(10, dtype=np.longdouble) - old_assert_almost_equal(norm(x, ord=3), 12.65, decimal=2) - def test_intmin(self): # Non-regression test: p-norm of signed integer would previously do # float cast and abs in the wrong order. x = np.array([-2 ** 31], dtype=np.int32) old_assert_almost_equal(norm(x, ord=3), 2 ** 31, decimal=5) - def test_complex_high_ord(self): - # gh-4156 - d = np.empty((2,), dtype=np.clongdouble) - d[0] = 6 + 7j - d[1] = -6 + 7j - res = 11.615898132184 - old_assert_almost_equal(np.linalg.norm(d, ord=3), res, decimal=10) - d = d.astype(np.complex128) - old_assert_almost_equal(np.linalg.norm(d, ord=3), res, decimal=9) - d = d.astype(np.complex64) - old_assert_almost_equal(np.linalg.norm(d, ord=3), res, decimal=5) - # Separate definitions so we can use them for matrix tests. class _TestNormDoubleBase(_TestNormBase): @@ -1520,22 +1502,18 @@ class _TestNormInt64Base(_TestNormBase): dec = 12 -@pytest.mark.xfail(reason='TODO') class TestNormDouble(_TestNorm, _TestNormDoubleBase): pass -@pytest.mark.xfail(reason='TODO') class TestNormSingle(_TestNorm, _TestNormSingleBase): pass -@pytest.mark.xfail(reason='TODO') class TestNormInt64(_TestNorm, _TestNormInt64Base): pass -@pytest.mark.xfail(reason='TODO') class TestMatrixRank: def test_matrix_rank(self): @@ -1572,13 +1550,13 @@ def test_symmetric_rank(self): assert_equal(3, matrix_rank(I, hermitian=True, tol=1.01e-8)) -@pytest.mark.xfail(reason='TODO') def test_reduced_rank(): # Test matrices with reduced rank - rng = np.random.RandomState(20120714) + # rng = np.random.RandomState(20120714) + np.random.seed(20120714) for i in range(100): # Make a rank deficient matrix - X = rng.normal(size=(40, 10)) + X = np.random.normal(size=(40, 10)) X[:, 0] = X[:, 1] + X[:, 2] # Assert that matrix_rank detected deficiency assert_equal(matrix_rank(X), 9) @@ -1586,10 +1564,7 @@ def test_reduced_rank(): assert_equal(matrix_rank(X), 8) -@pytest.mark.xfail(reason='TODO') class TestQR: - # Define the array class here, so run this on matrices elsewhere. - array = np.array def check_qr(self, a): # This test expects the argument `a` to be an ndarray or @@ -1607,7 +1582,7 @@ def check_qr(self, a): assert_(isinstance(r, a_type)) assert_(q.shape == (m, m)) assert_(r.shape == (m, n)) - assert_almost_equal(dot(q, r), a) + assert_almost_equal(dot(q, r), a, single_decimal=5) assert_almost_equal(dot(q.T.conj(), q), np.eye(m)) assert_almost_equal(np.triu(r), r) @@ -1619,7 +1594,7 @@ def check_qr(self, a): assert_(isinstance(r1, a_type)) assert_(q1.shape == (m, k)) assert_(r1.shape == (k, n)) - assert_almost_equal(dot(q1, r1), a) + assert_almost_equal(dot(q1, r1), a, single_decimal=5) assert_almost_equal(dot(q1.T.conj(), q1), np.eye(k)) assert_almost_equal(np.triu(r1), r1) @@ -1630,6 +1605,7 @@ def check_qr(self, a): assert_almost_equal(r2, r1) + @pytest.mark.xfail(reason="torch does not allow qr(..., mode='raw'") @pytest.mark.parametrize(["m", "n"], [ (3, 0), (0, 3), @@ -1647,6 +1623,7 @@ def test_qr_empty(self, m, n): assert_equal(h.shape, (n, m)) assert_equal(tau.shape, (k,)) + @pytest.mark.xfail(reason="torch does not allow qr(..., mode='raw'") def test_mode_raw(self): # The factorization is not unique and varies between libraries, # so it is not possible to check against known values. Functional @@ -1654,7 +1631,7 @@ def test_mode_raw(self): # of the functions in lapack_lite. Consequently, this test is # very limited in scope. Note that the results are in FORTRAN # order, hence the h arrays are transposed. - a = self.array([[1, 2], [3, 4], [5, 6]], dtype=np.double) + a = np.array([[1, 2], [3, 4], [5, 6]], dtype=np.double) # Test double h, tau = linalg.qr(a, mode='raw') @@ -1670,8 +1647,8 @@ def test_mode_raw(self): assert_(tau.shape == (2,)) def test_mode_all_but_economic(self): - a = self.array([[1, 2], [3, 4]]) - b = self.array([[1, 2], [3, 4], [5, 6]]) + a = np.array([[1, 2], [3, 4]]) + b = np.array([[1, 2], [3, 4], [5, 6]]) for dt in "fd": m1 = a.astype(dt) m2 = b.astype(dt) @@ -1747,7 +1724,6 @@ def test_stacked_inputs(self, outer_size, size, dt): self.check_qr_stacked(A + 1.j*B) -@pytest.mark.xfail(reason='TODO') class TestCholesky: # TODO: are there no other tests for cholesky? @@ -1773,28 +1749,27 @@ def test_basic_property(self, shape, dtype): c = np.linalg.cholesky(a) b = np.matmul(c, c.transpose(t).conj()) - with np._no_nep50_warning(): - atol = 500 * a.shape[0] * np.finfo(dtype).eps + atol = 500 * a.shape[0] * np.finfo(dtype).eps assert_allclose(b, a, atol=atol, err_msg=f'{shape} {dtype}\n{a}\n{c}') def test_0_size(self): - class ArraySubclass(np.ndarray): - pass - a = np.zeros((0, 1, 1), dtype=np.int_).view(ArraySubclass) + # class ArraySubclass(np.ndarray): + # pass + a = np.zeros((0, 1, 1), dtype=np.int_) #.view(ArraySubclass) res = linalg.cholesky(a) assert_equal(a.shape, res.shape) assert_(res.dtype.type is np.float64) # for documentation purpose: assert_(isinstance(res, np.ndarray)) - a = np.zeros((1, 0, 0), dtype=np.complex64).view(ArraySubclass) + a = np.zeros((1, 0, 0), dtype=np.complex64) #.view(ArraySubclass) res = linalg.cholesky(a) assert_equal(a.shape, res.shape) assert_(res.dtype.type is np.complex64) assert_(isinstance(res, np.ndarray)) -@pytest.mark.xfail(reason='TODO') +@pytest.mark.xfail(reason='endianness') def test_byteorder_check(): # Byte order check should pass for native order if sys.byteorder == 'little': @@ -1816,7 +1791,6 @@ def test_byteorder_check(): assert_array_equal(res, routine(sw_arr)) -@pytest.mark.xfail(reason='TODO') @pytest.mark.skipif(IS_WASM, reason="fp errors don't work in wasm") def test_generalized_raise_multiloop(): # It should raise an error even if the error doesn't occur in the @@ -1832,7 +1806,6 @@ def test_generalized_raise_multiloop(): assert_raises(np.linalg.LinAlgError, np.linalg.inv, x) -@pytest.mark.xfail(reason='TODO') def test_xerbla_override(): # Check that our xerbla has been successfully linked in. If it is not, # the default xerbla routine is called, which prints a message to stdout @@ -1882,7 +1855,6 @@ def test_xerbla_override(): pytest.skip('Numpy xerbla not linked in.') -@pytest.mark.xfail(reason='TODO') @pytest.mark.skipif(IS_WASM, reason="Cannot start subprocess") @pytest.mark.slow def test_sdot_bug_8577(): @@ -1919,7 +1891,6 @@ def test_sdot_bug_8577(): subprocess.check_call([sys.executable, "-c", code]) -@pytest.mark.xfail(reason='TODO') class TestMultiDot: def test_basic_function_with_three_arguments(self): @@ -2045,11 +2016,10 @@ def test_dynamic_programming_logic(self): assert_almost_equal(np.triu(m), np.triu(m_expected)) def test_too_few_input_arrays(self): - assert_raises(ValueError, multi_dot, []) - assert_raises(ValueError, multi_dot, [np.random.random((3, 3))]) + assert_raises((RuntimeError, ValueError), multi_dot, []) + assert_raises((RuntimeError, ValueError), multi_dot, [np.random.random((3, 3))]) -@pytest.mark.xfail(reason='TODO') class TestTensorinv: @pytest.mark.parametrize("arr, ind", [ @@ -2057,7 +2027,7 @@ class TestTensorinv: (np.ones((3, 3, 2)), 1), ]) def test_non_square_handling(self, arr, ind): - with assert_raises(LinAlgError): + with assert_raises((LinAlgError, RuntimeError)): linalg.tensorinv(arr, ind=ind) @pytest.mark.parametrize("shape, ind", [ @@ -2066,8 +2036,7 @@ def test_non_square_handling(self, arr, ind): ((24, 8, 3), 1), ]) def test_tensorinv_shape(self, shape, ind): - a = np.eye(24) - a.shape = shape + a = np.eye(24).reshape(shape) ainv = linalg.tensorinv(a=a, ind=ind) expected = a.shape[ind:] + a.shape[:ind] actual = ainv.shape @@ -2077,21 +2046,18 @@ def test_tensorinv_shape(self, shape, ind): 0, -2, ]) def test_tensorinv_ind_limit(self, ind): - a = np.eye(24) - a.shape = (4, 6, 8, 3) - with assert_raises(ValueError): + a = np.eye(24).reshape(4, 6, 8, 3) + with assert_raises((ValueError, RuntimeError)): linalg.tensorinv(a=a, ind=ind) def test_tensorinv_result(self): # mimic a docstring example - a = np.eye(24) - a.shape = (24, 8, 3) + a = np.eye(24).reshape(24, 8, 3) ainv = linalg.tensorinv(a, ind=1) b = np.ones(24) assert_allclose(np.tensordot(ainv, b, 1), np.linalg.tensorsolve(a, b)) -@pytest.mark.xfail(reason='TODO') class TestTensorsolve: @pytest.mark.parametrize("a, axes", [ @@ -2099,7 +2065,7 @@ class TestTensorsolve: (np.ones((3, 3, 2)), (0, 2)), ]) def test_non_square_handling(self, a, axes): - with assert_raises(LinAlgError): + with assert_raises((LinAlgError, RuntimeError)): b = np.ones(a.shape[:2]) linalg.tensorsolve(a, b, axes=axes) @@ -2137,7 +2103,7 @@ def test_blas64_dot(): assert_equal(c[0,-1], 1) -@pytest.mark.xfail(reason='TODO') +@pytest.mark.skip(reason='lapack-lite specific') @pytest.mark.xfail(not HAS_LAPACK64, reason="Numpy not compiled with 64-bit BLAS/LAPACK") def test_blas64_geqrf_lwork_smoketest():