Skip to content

add torch_np.linalg module #100

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Apr 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion torch_np/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from . import random
from . import linalg, random
from ._binary_ufuncs import *
from ._detail._util import AxisError, UFuncTypeError
from ._dtypes import *
Expand Down
54 changes: 50 additions & 4 deletions torch_np/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -911,10 +911,6 @@ def array_equiv(a1: ArrayLike, a2: ArrayLike):
return _tensor_equal(a1_t, a2_t)


def common_type():
    """Stub for ``np.common_type``; always raises NotImplementedError."""
    raise NotImplementedError


def mintypecode():
    """Stub for ``np.mintypecode``; always raises NotImplementedError."""
    raise NotImplementedError

Expand All @@ -927,6 +923,10 @@ def asfarray():
raise NotImplementedError


def block(*args, **kwds):
    """Stub for ``np.block``; accepts any arguments but always raises."""
    raise NotImplementedError


# ### put/take_along_axis ###


Expand Down Expand Up @@ -1358,8 +1358,12 @@ def reshape(a: ArrayLike, newshape, order="C"):
@normalizer
def transpose(a: ArrayLike, axes=None):
# numpy allows both .tranpose(sh) and .transpose(*sh)
# also older code uses axes being a list
if axes in [(), None, (None,)]:
axes = tuple(range(a.ndim))[::-1]
elif len(axes) == 1:
axes = axes[0]

try:
result = a.permute(axes)
except RuntimeError:
Expand Down Expand Up @@ -1908,3 +1912,45 @@ def blackman(M):
def bartlett(M):
    """Return a symmetric M-point Bartlett (triangular) window."""
    return torch.bartlett_window(
        M, periodic=False, dtype=_dtypes_impl.default_float_dtype
    )


# ### Dtype routines ###

# vendored from https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/type_check.py#L666


# Lookup tables for `common_type` (vendored from numpy.lib.type_check):
# array_type[is_complex][precision] maps a complex-ness flag and a precision
# level (0 = half, 1 = single, 2 = double) to a torch dtype.  There is no
# half-precision complex entry, hence the leading None.
array_type = [
    [torch.float16, torch.float32, torch.float64],
    [None, torch.complex64, torch.complex128],
]
# Precision level of each supported floating/complex dtype; dtypes missing
# from this table make `common_type` raise TypeError.
array_precision = {
    torch.float16: 0,
    torch.float32: 1,
    torch.float64: 2,
    torch.complex64: 1,
    torch.complex128: 2,
}


@normalizer
def common_type(*tensors: ArrayLike):
    """Return a scalar dtype common to the input arrays.

    Vendored from numpy.lib.type_check: integer/bool inputs count as double
    precision; otherwise the highest float/complex precision among the
    inputs wins, and any complex input makes the result complex.
    """

    import builtins

    is_complex = False
    precision = 0
    for t in tensors:
        dt = t.dtype
        if iscomplexobj(t):
            is_complex = True
        if dt.is_floating_point or dt.is_complex:
            p = array_precision.get(dt, None)
            if p is None:
                raise TypeError("can't get common type for non-numeric array")
        else:
            # non-float, non-complex (integer/bool) promotes to double
            p = 2  # array_precision[_nx.double]
        precision = builtins.max(precision, p)

    return array_type[1 if is_complex else 0][precision]
8 changes: 8 additions & 0 deletions torch_np/_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,14 @@ def copy(self, order="C"):
tensor = self.tensor.clone()
return ndarray(tensor)

def view(self, dtype):
    """Reinterpret the underlying tensor's data as `dtype` (zero-copy)."""
    target = _dtypes.dtype(dtype).torch_dtype
    return ndarray(self.tensor.view(target))

def fill(self, value):
    """Fill the whole array with the scalar `value`, in place (returns None)."""
    self.tensor.fill_(value)

def tolist(self):
return self.tensor.tolist()

Expand Down
240 changes: 240 additions & 0 deletions torch_np/linalg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
import functools
import math
from typing import Sequence

import torch

from ._detail import _dtypes_impl, _util
from ._normalizations import ArrayLike, normalizer


class LinAlgError(Exception):
    """Raised on linear-algebra failures, mirroring numpy.linalg.LinAlgError."""

    pass


def _atleast_float_1(a):
    """Cast `a` to the default float dtype unless it is already float/complex."""
    if a.dtype.is_floating_point or a.dtype.is_complex:
        return a
    return a.to(_dtypes_impl.default_float_dtype)


def _atleast_float_2(a, b):
    """Promote `a` and `b` to their common dtype, at least the default float."""
    common = _dtypes_impl.result_type_impl((a.dtype, b.dtype))
    if not (common.is_floating_point or common.is_complex):
        common = _dtypes_impl.default_float_dtype
    return _util.cast_if_needed(a, common), _util.cast_if_needed(b, common)


def linalg_errors(func):
    """Decorator translating torch's linalg failures into ``LinAlgError``.

    Improvement: chain with ``from e`` so the original torch exception is
    preserved as ``__cause__`` and tracebacks stay informative, while
    callers still catch the numpy-compatible exception type.
    """

    @functools.wraps(func)
    def wrapped(*args, **kwds):
        try:
            return func(*args, **kwds)
        except torch._C._LinAlgError as e:
            raise LinAlgError(*e.args) from e

    return wrapped


# ### Matrix and vector products ###


@normalizer
@linalg_errors
def matrix_power(a: ArrayLike, n):
    """Raise a square matrix to the integer power `n` (negative n inverts).

    Bug fix: the original called the misspelled ``_atleat_float_1``, which
    would raise NameError on every invocation.
    """
    a = _atleast_float_1(a)
    return torch.linalg.matrix_power(a, n)


@normalizer
@linalg_errors
def multi_dot(inputs: Sequence[ArrayLike], *, out=None):
    """Chain-multiply the given matrices in an optimized order.

    NOTE(review): `out=` is accepted for numpy signature compatibility but
    is silently ignored here.
    """
    return torch.linalg.multi_dot(inputs)


# ### Solving equations and inverting matrices ###


@normalizer
@linalg_errors
def solve(a: ArrayLike, b: ArrayLike):
    """Solve the linear system ``a @ x = b`` for ``x``."""
    lhs, rhs = _atleast_float_2(a, b)
    return torch.linalg.solve(lhs, rhs)


@normalizer
@linalg_errors
def lstsq(a: ArrayLike, b: ArrayLike, rcond=None):
    """Least-squares solution of ``a @ x = b``.

    NumPy is using gelsd: https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/umath_linalg.cpp#L3991
    On CUDA only the `gels` driver is available, so fall back to it there.
    """
    a, b = _atleast_float_2(a, b)
    on_cuda = a.is_cuda or b.is_cuda
    driver = "gels" if on_cuda else "gelsd"
    return torch.linalg.lstsq(a, b, rcond=rcond, driver=driver)


@normalizer
@linalg_errors
def inv(a: ArrayLike):
    """Multiplicative inverse of a (batched) square matrix."""
    return torch.linalg.inv(_atleast_float_1(a))


@normalizer
@linalg_errors
def pinv(a: ArrayLike, rcond=1e-15, hermitian=False):
    """Moore-Penrose pseudo-inverse; numpy's `rcond` maps to torch's `rtol`."""
    return torch.linalg.pinv(_atleast_float_1(a), rtol=rcond, hermitian=hermitian)


@normalizer
@linalg_errors
def tensorsolve(a: ArrayLike, b: ArrayLike, axes=None):
    """Solve the tensor equation ``a x = b``; numpy's `axes` maps to torch `dims`."""
    lhs, rhs = _atleast_float_2(a, b)
    return torch.linalg.tensorsolve(lhs, rhs, dims=axes)


@normalizer
@linalg_errors
def tensorinv(a: ArrayLike, ind=2):
    """Inverse of an N-d array relative to tensordot over the first `ind` axes."""
    return torch.linalg.tensorinv(_atleast_float_1(a), ind=ind)


# ### Norms and other numbers ###


@normalizer
@linalg_errors
def det(a: ArrayLike):
    """Determinant of a (batched) square matrix."""
    return torch.linalg.det(_atleast_float_1(a))


@normalizer
@linalg_errors
def slogdet(a: ArrayLike):
    """Sign and natural logarithm of the absolute determinant."""
    return torch.linalg.slogdet(_atleast_float_1(a))


@normalizer
@linalg_errors
def cond(x: ArrayLike, p=None):
    """Condition number of a matrix in the norm selected by `p`.

    Raises LinAlgError on empty inputs; NaN results are mapped to +inf
    unconditionally (NumPy does this data-dependently).
    """
    x = _atleast_float_1(x)

    # check if empty
    # cf: https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/linalg.py#L1744
    if x.numel() == 0 and math.prod(x.shape[-2:]) == 0:
        raise LinAlgError("cond is not defined on empty arrays")

    result = torch.linalg.cond(x, p=p)

    # Convert nans to infs (numpy does it in a data-dependent way, depending on
    # whether the input array has nans or not)
    # XXX: NumPy does this: https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/linalg.py#L1744
    return torch.where(torch.isnan(result), float("inf"), result)


@normalizer
@linalg_errors
def matrix_rank(a: ArrayLike, tol=None, hermitian=False):
    """Rank of a matrix via SVD; numpy's `tol` is an *absolute* tolerance."""
    a = _atleast_float_1(a)

    # numpy special-cases scalars/vectors: rank is 1 iff any entry is nonzero
    if a.ndim < 2:
        return int((a != 0).any())

    if tol is None:
        # default tolerance, following
        # https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/linalg.py#L1885
        atol = 0
        rtol = max(a.shape[-2:]) * torch.finfo(a.dtype).eps
    else:
        atol, rtol = tol, 0

    return torch.linalg.matrix_rank(a, atol=atol, rtol=rtol, hermitian=hermitian)


@normalizer
@linalg_errors
def norm(x: ArrayLike, ord=None, axis=None, keepdims=False):
    """Matrix or vector norm; `keepdims` is emulated after the torch call."""
    x = _atleast_float_1(x)
    result = torch.linalg.norm(x, ord=ord, dim=axis)
    if not keepdims:
        return result
    return _util.apply_keepdims(result, axis, x.ndim)


# ### Decompositions ###


@normalizer
@linalg_errors
def cholesky(a: ArrayLike):
    """Lower-triangular Cholesky factor of a positive-definite matrix."""
    return torch.linalg.cholesky(_atleast_float_1(a))


@normalizer
@linalg_errors
def qr(a: ArrayLike, mode="reduced"):
    """QR decomposition; for mode="r" numpy returns just R, not a named tuple."""
    result = torch.linalg.qr(_atleast_float_1(a), mode=mode)
    # match NumPy's return convention for the "r" mode
    return result.R if mode == "r" else result


@normalizer
@linalg_errors
def svd(a: ArrayLike, full_matrices=True, compute_uv=True, hermitian=False):
    """Singular value decomposition.

    NOTE: `hermitian=` is accepted for numpy compatibility but ignored
    (no pytorch equivalent).
    """
    a = _atleast_float_1(a)
    if compute_uv:
        return torch.linalg.svd(a, full_matrices=full_matrices)
    return torch.linalg.svdvals(a)


# ### Eigenvalues and eigenvectors ###


@normalizer
@linalg_errors
def eig(a: ArrayLike):
    """Eigenvalues and right eigenvectors of a square matrix.

    Matches numpy: when a real input yields a purely real spectrum, the
    complex results from torch are downcast back to real dtype.
    """
    a = _atleast_float_1(a)
    w, vt = torch.linalg.eig(a)

    real_input = not a.is_complex()
    if real_input and w.is_complex() and (w.imag == 0).all():
        w, vt = w.real, vt.real
    return w, vt


@normalizer
@linalg_errors
def eigh(a: ArrayLike, UPLO="L"):
    """Eigendecomposition of a Hermitian/symmetric matrix."""
    return torch.linalg.eigh(_atleast_float_1(a), UPLO=UPLO)


@normalizer
@linalg_errors
def eigvals(a: ArrayLike):
    """Eigenvalues of a square matrix.

    Matches numpy: a real input with an all-real spectrum returns real dtype.
    """
    a = _atleast_float_1(a)
    result = torch.linalg.eigvals(a)
    if not a.is_complex() and result.is_complex() and (result.imag == 0).all():
        result = result.real
    return result


@normalizer
@linalg_errors
def eigvalsh(a: ArrayLike, UPLO="L"):
    """Eigenvalues of a Hermitian/symmetric matrix."""
    return torch.linalg.eigvalsh(_atleast_float_1(a), UPLO=UPLO)
Loading