Skip to content

Add ruff to ci setup #82

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 37 commits into from
Feb 8, 2024
Merged
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
42e2a8b
Add ruff to ci setup
adonath Dec 14, 2023
4d0ccb9
Fix ruff errors in common/
adonath Jan 26, 2024
fbe6bd4
Fix ruff errors in cupy/
adonath Jan 26, 2024
c7c27be
Fix ruff errors in numpy/
adonath Jan 26, 2024
d3d57b9
Fix ruff errors in torch/
adonath Jan 26, 2024
0d437cd
Fix ruff errors in tests/
adonath Jan 26, 2024
afccf29
Fix ruff errors in array_api_compat/__init__.py
adonath Jan 26, 2024
2395ea0
Implement _get_all_public_members
adonath Jan 26, 2024
645cef2
Move linalg aliases to _aliases
adonath Jan 26, 2024
5a6f411
Fix ruff errors in cupy/linalg
adonath Jan 26, 2024
0ff0836
Move linalg aliases to numpy/_aliases
adonath Jan 26, 2024
f4d78c7
Fix ruff errors in numpy/linalg
adonath Jan 26, 2024
31bbbfa
Hide helper variables in cupy/linalg.py
adonath Jan 26, 2024
5ecc7b5
Move linalg aliases to torch/_aliases
adonath Jan 26, 2024
8e4e9ca
Fix ruff errors in torch/linalg
adonath Jan 26, 2024
5c66efc
Fix final ruff errors in array_api_compat/torch/__init__.py
adonath Jan 26, 2024
bca606d
Expose public members from numpy and cupy in __all__ respectively
adonath Jan 26, 2024
890c497
Clean up
adonath Jan 26, 2024
b2f9557
Add importorskip torch
adonath Jan 26, 2024
52ef9ee
Use importorskip
adonath Jan 26, 2024
b069230
Add missing isdtype
adonath Jan 26, 2024
ff51015
Fix tests
adonath Jan 26, 2024
b0a323d
Rename import_ to import_or_skip_cupy
adonath Jan 27, 2024
0ec2d89
Add missing imports and sort __all__
adonath Jan 27, 2024
2baa4da
More cleanup
adonath Jan 27, 2024
efd745c
Remove redefinitions
adonath Jan 27, 2024
49f2b7a
Add ruff select F822 option [skip ci]
adonath Jan 29, 2024
a748bfa
Add PLC0414 error code as well
adonath Jan 29, 2024
6b4e92c
Avoid in place modification of __all__ in _get_all_public_members
adonath Feb 1, 2024
a92f640
Add sort check for __all__
adonath Feb 2, 2024
9b1110b
Sort __all__ lists
adonath Feb 2, 2024
68c788f
Use * import for array_api_compat/__init__.py
adonath Feb 2, 2024
1720fb6
Update array_api_compat/_internal.py
adonath Feb 7, 2024
5cd47df
Adapt dask
adonath Feb 7, 2024
c5d55ae
Fix ruff errors for dask/array/linalg
adonath Feb 7, 2024
49851b5
Fix __all__ order in dask linalg
adonath Feb 7, 2024
2db3d6a
Fix import of __all__ in dask/array/__init__.py
adonath Feb 7, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions .github/workflows/ruff.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
name: CI
on: [push, pull_request]
jobs:
check-ruff:
runs-on: ubuntu-latest
continue-on-error: true
steps:
- uses: actions/checkout@v3
- name: Install Python
uses: actions/setup-python@v4
with:
python-version: "3.11"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install ruff
# Update output format to enable automatic inline annotations.
- name: Run Ruff
run: ruff check --output-format=github --select F822,PLC0414,RUF022 --preview .
2 changes: 1 addition & 1 deletion array_api_compat/__init__.py
Original file line number Diff line number Diff line change
@@ -19,4 +19,4 @@
"""
__version__ = '1.4.1'

from .common import *
from .common import * # noqa: F401, F403
34 changes: 33 additions & 1 deletion array_api_compat/_internal.py
Original file line number Diff line number Diff line change
@@ -5,6 +5,7 @@
from functools import wraps
from inspect import signature


def get_xp(xp):
"""
Decorator to automatically replace xp with the corresponding array module.
@@ -21,13 +22,16 @@ def func(x, /, xp, kwarg=None):
arguments.
"""

def inner(f):
@wraps(f)
def wrapped_f(*args, **kwargs):
return f(*args, xp=xp, **kwargs)

sig = signature(f)
new_sig = sig.replace(parameters=[sig.parameters[i] for i in sig.parameters if i != 'xp'])
new_sig = sig.replace(
parameters=[sig.parameters[i] for i in sig.parameters if i != "xp"]
)

if wrapped_f.__doc__ is None:
wrapped_f.__doc__ = f"""\
@@ -41,3 +45,31 @@ def wrapped_f(*args, **kwargs):
return wrapped_f

return inner


def _get_all_public_members(module, exclude=None, extend_all=False):
"""Get all public members of a module.
Parameters
----------
module : module
The module to get members from.
exclude : callable, optional
A callable that takes a name and returns True if the name should be
excluded from the list of members.
extend_all : bool, optional
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is monkeypatching torch.__all__ etc.? We don't want to do that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but this just keeps the current behavior. Take a look at https://github.com/data-apis/array-api-compat/blob/main/array_api_compat/torch/__init__.py#L3

I have not checked whether this is still necessary, but probably we have to keep it this way?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked with the following code:

import torch

torch_all = set(torch.__all__)
public = set([name for name in dir(torch) if not name.startswith("_")])

print(torch_all.difference(public))
print(public.difference(torch_all))

And this gives:

set()
{'complex64', 'eig', 'special', ... , 'QInt8Storage', 'segment_reduce', 'ComplexDoubleStorage'}

So indeed __all__ does not contain multiple members and most importantly it does not contain the dtypes.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but this just keeps the current behavior. Take a look at main/array_api_compat/torch/__init__.py#L3

That code does not modify the torch.__all__ list:

>>> import torch
>>> torch_all = list(torch.__all__)
>>> import array_api_compat.torch
>>> torch_all2 = list(torch.__all__)
>>> torch_all == torch_all2
True

Generally speaking, this package should not monkeypatch the underlying libraries.

So indeed __all__ does not contain multiple members and most importantly it does not contain the dtypes.

Yes, that's a known issue. pytorch/pytorch#91908

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, now I understand. I did not mean to actually modify torch.__all__ in place but copy and extend instead. I'll fix that behavior.

If True, extend the module's __all__ attribute with the members of the
module derived from dir(module). To be used for libraries that do not have a complete __all__ list.
"""
members = getattr(module, "__all__", [])

if members and not extend_all:
return members

if exclude is None:
exclude = lambda name: name.startswith("_") # noqa: E731

members = members + [_ for _ in dir(module) if not exclude(_)]

# remove duplicates
return list(set(members))
18 changes: 17 additions & 1 deletion array_api_compat/common/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,17 @@
from ._helpers import *
from ._helpers import (
array_namespace,
device,
get_namespace,
is_array_api_obj,
size,
to_device,
)

__all__ = [
"array_namespace",
"device",
"get_namespace",
"is_array_api_obj",
"size",
"to_device",
]
11 changes: 2 additions & 9 deletions array_api_compat/common/_aliases.py
Original file line number Diff line number Diff line change
@@ -6,7 +6,8 @@

from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Optional, Sequence, Tuple, Union, List
import numpy as np
from typing import Optional, Sequence, Tuple, Union
from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol

from typing import NamedTuple
@@ -544,11 +545,3 @@ def isdtype(
# more strict here to match the type annotation? Note that the
# numpy.array_api implementation will be very strict.
return dtype == kind

__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like',
'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
'astype', 'std', 'var', 'permute_dims', 'reshape', 'argsort',
'sort', 'nonzero', 'sum', 'prod', 'ceil', 'floor', 'trunc',
'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype']
12 changes: 8 additions & 4 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
@@ -7,6 +7,12 @@
"""
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from typing import Optional, Union, Any
from ._typing import Array, Device

import sys
import math

@@ -142,7 +148,7 @@ def _check_device(xp, device):
# wrapping or subclassing them. These helper functions can be used instead of
# the wrapper functions for libraries that need to support both NumPy/CuPy and
# other libraries that use devices.
def device(x: "Array", /) -> "Device":
def device(x: Array, /) -> Device:
"""
Hardware device the array data resides on.
@@ -204,7 +210,7 @@ def _torch_to_device(x, device, /, stream=None):
raise NotImplementedError
return x.to(device)

def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, Any]]" = None) -> "Array":
def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] = None) -> Array:
"""
Copy the array from the device on which it currently resides to the specified ``device``.
@@ -252,5 +258,3 @@ def size(x):
if None in x.shape:
return None
return math.prod(x.shape)

__all__ = ['is_array_api_obj', 'array_namespace', 'get_namespace', 'device', 'to_device', 'size']
12 changes: 3 additions & 9 deletions array_api_compat/common/_linalg.py
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@

from typing import TYPE_CHECKING, NamedTuple
if TYPE_CHECKING:
from typing import Literal, Optional, Sequence, Tuple, Union
from typing import Literal, Optional, Tuple, Union
from ._typing import ndarray

import numpy as np
@@ -11,7 +11,7 @@
else:
from numpy.core.numeric import normalize_axis_tuple

from ._aliases import matmul, matrix_transpose, tensordot, vecdot, isdtype
from ._aliases import matrix_transpose, isdtype
from .._internal import get_xp

# These are in the main NumPy namespace but not in numpy.linalg
@@ -149,10 +149,4 @@ def trace(x: ndarray, /, xp, *, offset: int = 0, dtype=None, **kwargs) -> ndarra
dtype = xp.float64
elif x.dtype == xp.complex64:
dtype = xp.complex128
return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs))

__all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult',
'QRResult', 'SlogdetResult', 'SVDResult', 'eigh', 'qr', 'slogdet',
'svd', 'cholesky', 'matrix_rank', 'pinv', 'matrix_norm',
'matrix_transpose', 'svdvals', 'vecdot', 'vector_norm', 'diagonal',
'trace']
return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs))
3 changes: 3 additions & 0 deletions array_api_compat/common/_typing.py
Original file line number Diff line number Diff line change
@@ -18,3 +18,6 @@ def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
def __len__(self, /) -> int: ...

SupportsBufferProtocol = Any

Array = Any
Device = Any
151 changes: 144 additions & 7 deletions array_api_compat/cupy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,153 @@
from cupy import *
import cupy as _cp
from cupy import * # noqa: F401, F403

# from cupy import * doesn't overwrite these builtin names
from cupy import abs, max, min, round

from .._internal import _get_all_public_members
from ..common._helpers import (
array_namespace,
device,
get_namespace,
is_array_api_obj,
size,
to_device,
)

# These imports may overwrite names from the import * above.
from ._aliases import *
from ._aliases import (
UniqueAllResult,
UniqueCountsResult,
UniqueInverseResult,
acos,
acosh,
arange,
argsort,
asarray,
asarray_cupy,
asin,
asinh,
astype,
atan,
atan2,
atanh,
bitwise_invert,
bitwise_left_shift,
bitwise_right_shift,
bool,
ceil,
concat,
empty,
empty_like,
eye,
floor,
full,
full_like,
isdtype,
linspace,
matmul,
matrix_transpose,
nonzero,
ones,
ones_like,
permute_dims,
pow,
prod,
reshape,
sort,
std,
sum,
tensordot,
trunc,
unique_all,
unique_counts,
unique_inverse,
unique_values,
var,
vecdot,
zeros,
zeros_like,
)

# See the comment in the numpy __init__.py
__import__(__package__ + '.linalg')
__all__ = []

__all__ += _get_all_public_members(_cp)

__all__ += [
"abs",
"max",
"min",
"round",
]

from .linalg import matrix_transpose, vecdot
__all__ += [
"array_namespace",
"device",
"get_namespace",
"is_array_api_obj",
"size",
"to_device",
]

from ..common._helpers import *
__all__ += [
"UniqueAllResult",
"UniqueCountsResult",
"UniqueInverseResult",
"acos",
"acosh",
"arange",
"argsort",
"asarray",
"asarray_cupy",
"asin",
"asinh",
"astype",
"atan",
"atan2",
"atanh",
"bitwise_invert",
"bitwise_left_shift",
"bitwise_right_shift",
"bool",
"ceil",
"concat",
"empty",
"empty_like",
"eye",
"floor",
"full",
"full_like",
"isdtype",
"linspace",
"matmul",
"matrix_transpose",
"nonzero",
"ones",
"ones_like",
"permute_dims",
"pow",
"prod",
"reshape",
"sort",
"std",
"sum",
"tensordot",
"trunc",
"unique_all",
"unique_counts",
"unique_inverse",
"unique_values",
"var",
"zeros",
"zeros_like",
]

__all__ += [
"matrix_transpose",
"vecdot",
]

# See the comment in the numpy __init__.py
__import__(__package__ + ".linalg")

__array_api_version__ = '2022.12'
__array_api_version__ = "2022.12"
34 changes: 28 additions & 6 deletions array_api_compat/cupy/_aliases.py
Original file line number Diff line number Diff line change
@@ -2,15 +2,16 @@

from functools import partial

import cupy as cp

from ..common import _aliases
from ..common import _linalg

from .._internal import get_xp

asarray = asarray_cupy = partial(_aliases._asarray, namespace='cupy')
asarray.__doc__ = _aliases._asarray.__doc__
del partial

import cupy as cp
bool = cp.bool_

# Basic renames
@@ -73,7 +74,28 @@
else:
isdtype = get_xp(cp)(_aliases.isdtype)

__all__ = _aliases.__all__ + ['asarray', 'asarray_cupy', 'bool', 'acos',
'acosh', 'asin', 'asinh', 'atan', 'atan2',
'atanh', 'bitwise_left_shift', 'bitwise_invert',
'bitwise_right_shift', 'concat', 'pow']

cross = get_xp(cp)(_linalg.cross)
outer = get_xp(cp)(_linalg.outer)
EighResult = _linalg.EighResult
QRResult = _linalg.QRResult
SlogdetResult = _linalg.SlogdetResult
SVDResult = _linalg.SVDResult
eigh = get_xp(cp)(_linalg.eigh)
qr = get_xp(cp)(_linalg.qr)
slogdet = get_xp(cp)(_linalg.slogdet)
svd = get_xp(cp)(_linalg.svd)
cholesky = get_xp(cp)(_linalg.cholesky)
matrix_rank = get_xp(cp)(_linalg.matrix_rank)
pinv = get_xp(cp)(_linalg.pinv)
matrix_norm = get_xp(cp)(_linalg.matrix_norm)
svdvals = get_xp(cp)(_linalg.svdvals)
diagonal = get_xp(cp)(_linalg.diagonal)
trace = get_xp(cp)(_linalg.trace)

# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
if hasattr(cp.linalg, 'vector_norm'):
vector_norm = cp.linalg.vector_norm
else:
vector_norm = get_xp(cp)(_linalg.vector_norm)
2 changes: 1 addition & 1 deletion array_api_compat/cupy/_typing.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from __future__ import annotations

__all__ = [
"ndarray",
"Device",
"Dtype",
"ndarray",
]

import sys
97 changes: 56 additions & 41 deletions array_api_compat/cupy/linalg.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,62 @@
from cupy.linalg import *
# cupy.linalg doesn't have __all__. If it is added, replace this with
#
# from cupy.linalg import __all__ as linalg_all
_n = {}
exec('from cupy.linalg import *', _n)
del _n['__builtins__']
linalg_all = list(_n)
del _n
import cupy as _cp

from ..common import _linalg
from .._internal import get_xp
from ._aliases import (matmul, matrix_transpose, tensordot, vecdot)
from .._internal import _get_all_public_members

import cupy as cp
_cupy_linalg_all = _get_all_public_members(_cp.linalg)

cross = get_xp(cp)(_linalg.cross)
outer = get_xp(cp)(_linalg.outer)
EighResult = _linalg.EighResult
QRResult = _linalg.QRResult
SlogdetResult = _linalg.SlogdetResult
SVDResult = _linalg.SVDResult
eigh = get_xp(cp)(_linalg.eigh)
qr = get_xp(cp)(_linalg.qr)
slogdet = get_xp(cp)(_linalg.slogdet)
svd = get_xp(cp)(_linalg.svd)
cholesky = get_xp(cp)(_linalg.cholesky)
matrix_rank = get_xp(cp)(_linalg.matrix_rank)
pinv = get_xp(cp)(_linalg.pinv)
matrix_norm = get_xp(cp)(_linalg.matrix_norm)
svdvals = get_xp(cp)(_linalg.svdvals)
diagonal = get_xp(cp)(_linalg.diagonal)
trace = get_xp(cp)(_linalg.trace)
for _name in _cupy_linalg_all:
globals()[_name] = getattr(_cp.linalg, _name)

# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
if hasattr(cp.linalg, 'vector_norm'):
vector_norm = cp.linalg.vector_norm
else:
vector_norm = get_xp(cp)(_linalg.vector_norm)
from ._aliases import ( # noqa: E402
EighResult,
QRResult,
SlogdetResult,
SVDResult,
cholesky,
cross,
diagonal,
eigh,
matmul,
matrix_norm,
matrix_rank,
matrix_transpose,
outer,
pinv,
qr,
slogdet,
svd,
svdvals,
tensordot,
trace,
vecdot,
vector_norm,
)

__all__ = linalg_all + _linalg.__all__
__all__ = []

del get_xp
del cp
del linalg_all
del _linalg
__all__ += _cupy_linalg_all

__all__ += [
"EighResult",
"QRResult",
"SVDResult",
"SlogdetResult",
"cholesky",
"cross",
"diagonal",
"eigh",
"matmul",
"matrix_norm",
"matrix_rank",
"matrix_transpose",
"outer",
"pinv",
"qr",
"slogdet",
"svd",
"svdvals",
"tensordot",
"trace",
"vecdot",
"vector_norm",
]
210 changes: 206 additions & 4 deletions array_api_compat/dask/array/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,210 @@
from dask.array import *
import dask.array as _da
from dask.array import * # noqa: F401, F403
from dask.array import (
# Element wise aliases
arccos as acos,
)
from dask.array import (
arccosh as acosh,
)
from dask.array import (
arcsin as asin,
)
from dask.array import (
arcsinh as asinh,
)
from dask.array import (
arctan as atan,
)
from dask.array import (
arctan2 as atan2,
)
from dask.array import (
arctanh as atanh,
)
from dask.array import (
bool_ as bool,
)
from dask.array import (
# Other
concatenate as concat,
)
from dask.array import (
invert as bitwise_invert,
)
from dask.array import (
left_shift as bitwise_left_shift,
)
from dask.array import (
power as pow,
)
from dask.array import (
right_shift as bitwise_right_shift,
)

# These imports may overwrite names from the import * above.
from ._aliases import *
from numpy import (
can_cast,
complex64,
complex128,
e,
finfo,
float32,
float64,
iinfo,
inf,
int8,
int16,
int32,
int64,
nan,
newaxis,
pi,
result_type,
uint8,
uint16,
uint32,
uint64,
)

__array_api_version__ = '2022.12'
from ..common._helpers import (
array_namespace,
device,
get_namespace,
is_array_api_obj,
size,
to_device,
)
from ..internal import _get_all_public_members
from ._aliases import (
UniqueAllResult,
UniqueCountsResult,
UniqueInverseResult,
arange,
asarray,
astype,
ceil,
empty,
empty_like,
eye,
floor,
full,
full_like,
isdtype,
linspace,
matmul,
matrix_transpose,
nonzero,
ones,
ones_like,
permute_dims,
prod,
reshape,
std,
sum,
tensordot,
trunc,
unique_all,
unique_counts,
unique_inverse,
unique_values,
var,
vecdot,
zeros,
zeros_like,
)

__import__(__package__ + '.linalg')
__all__ = []

__all__ += _get_all_public_members(_da)

__all__ += [
"can_cast",
"complex64",
"complex128",
"e",
"finfo",
"float32",
"float64",
"iinfo",
"inf",
"int8",
"int16",
"int32",
"int64",
"nan",
"newaxis",
"pi",
"result_type",
"uint8",
"uint16",
"uint32",
"uint64",
]

__all__ += [
"array_namespace",
"device",
"get_namespace",
"is_array_api_obj",
"size",
"to_device",
]

# 'sort', 'argsort' are unsupported by dask.array

__all__ += [
"UniqueAllResult",
"UniqueCountsResult",
"UniqueInverseResult",
"acos",
"acosh",
"arange",
"asarray",
"asin",
"asinh",
"astype",
"atan",
"atan2",
"atanh",
"bitwise_invert",
"bitwise_left_shift",
"bitwise_right_shift",
"bool",
"ceil",
"concat",
"empty",
"empty_like",
"eye",
"floor",
"full",
"full_like",
"isdtype",
"linspace",
"matmul",
"matrix_transpose",
"nonzero",
"ones",
"ones_like",
"permute_dims",
"pow",
"prod",
"reshape",
"std",
"sum",
"tensordot",
"trunc",
"unique_all",
"unique_counts",
"unique_inverse",
"unique_values",
"var",
"vecdot",
"zeros",
"zeros_like",
]


__array_api_version__ = "2022.12"

__import__(__package__ + ".linalg")
99 changes: 32 additions & 67 deletions array_api_compat/dask/array/_aliases.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,18 @@
from __future__ import annotations

from ...common import _aliases
from ...common._helpers import _check_device

from ..._internal import get_xp
from functools import partial
from typing import TYPE_CHECKING

import numpy as np
from numpy import (
# Constants
e,
inf,
nan,
pi,
newaxis,
# Dtypes
bool_ as bool,
float32,
float64,
int8,
int16,
int32,
int64,
uint8,
uint16,
uint32,
uint64,
complex64,
complex128,
iinfo,
finfo,
can_cast,
result_type,
)

from typing import TYPE_CHECKING
from ..._internal import get_xp
from ...common import _aliases, _linalg
from ...common._helpers import _check_device

if TYPE_CHECKING:
from typing import Optional, Union
from ...common._typing import ndarray, Device, Dtype
from typing import Optional, Tuple, Union

from ...common._typing import Device, Dtype, ndarray

import dask.array as da

@@ -49,6 +25,7 @@
# not pass stop/step as keyword arguments, which will cause
# an error with dask


# TODO: delete the xp stuff, it shouldn't be necessary
def dask_arange(
start: Union[int, float],
@@ -59,7 +36,7 @@ def dask_arange(
xp,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
**kwargs
**kwargs,
) -> ndarray:
_check_device(xp, device)
args = [start]
@@ -72,11 +49,11 @@ def dask_arange(
args.append(step)
return xp.arange(*args, dtype=dtype, **kwargs)


arange = get_xp(da)(dask_arange)
eye = get_xp(da)(_aliases.eye)

from functools import partial
asarray = partial(_aliases._asarray, namespace='dask.array')
asarray = partial(_aliases._asarray, namespace="dask.array")
asarray.__doc__ = _aliases._asarray.__doc__

linspace = get_xp(da)(_aliases.linspace)
@@ -112,34 +89,22 @@ def dask_arange(
matmul = get_xp(np)(_aliases.matmul)
tensordot = get_xp(np)(_aliases.tensordot)

from dask.array import (
# Element wise aliases
arccos as acos,
arccosh as acosh,
arcsin as asin,
arcsinh as asinh,
arctan as atan,
arctan2 as atan2,
arctanh as atanh,
left_shift as bitwise_left_shift,
right_shift as bitwise_right_shift,
invert as bitwise_invert,
power as pow,
# Other
concatenate as concat,
)

# exclude these from all since
_da_unsupported = ['sort', 'argsort']

common_aliases = [alias for alias in _aliases.__all__ if alias not in _da_unsupported]

__all__ = common_aliases + ['asarray', 'bool', 'acos',
'acosh', 'asin', 'asinh', 'atan', 'atan2',
'atanh', 'bitwise_left_shift', 'bitwise_invert',
'bitwise_right_shift', 'concat', 'pow',
'e', 'inf', 'nan', 'pi', 'newaxis', 'float32', 'float64', 'int8',
'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64',
'complex64', 'complex128', 'iinfo', 'finfo', 'can_cast', 'result_type']

del da, partial, common_aliases, _da_unsupported,

EighResult = _linalg.EighResult
QRResult = _linalg.QRResult
SlogdetResult = _linalg.SlogdetResult
SVDResult = _linalg.SVDResult
qr = get_xp(da)(_linalg.qr)
cholesky = get_xp(da)(_linalg.cholesky)
matrix_rank = get_xp(da)(_linalg.matrix_rank)
matrix_norm = get_xp(da)(_linalg.matrix_norm)


def svdvals(x: ndarray) -> Union[ndarray, Tuple[ndarray, ...]]:
# TODO: can't avoid computing U or V for dask
_, s, _ = da.linalg.svd(x)
return s


vector_norm = get_xp(da)(_linalg.vector_norm)
diagonal = get_xp(da)(_linalg.diagonal)
98 changes: 50 additions & 48 deletions array_api_compat/dask/array/linalg.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,50 @@
from __future__ import annotations

from dask.array.linalg import *
from ...common import _linalg
from ..._internal import get_xp
from dask.array import matmul, tensordot, trace, outer
from ._aliases import matrix_transpose, vecdot

import dask.array as da

from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Union, Tuple
from ...common._typing import ndarray

# cupy.linalg doesn't have __all__. If it is added, replace this with
#
# from cupy.linalg import __all__ as linalg_all
_n = {}
exec('from dask.array.linalg import *', _n)
del _n['__builtins__']
linalg_all = list(_n)
del _n

EighResult = _linalg.EighResult
QRResult = _linalg.QRResult
SlogdetResult = _linalg.SlogdetResult
SVDResult = _linalg.SVDResult
qr = get_xp(da)(_linalg.qr)
cholesky = get_xp(da)(_linalg.cholesky)
matrix_rank = get_xp(da)(_linalg.matrix_rank)
matrix_norm = get_xp(da)(_linalg.matrix_norm)

def svdvals(x: ndarray) -> Union[ndarray, Tuple[ndarray, ...]]:
# TODO: can't avoid computing U or V for dask
_, s, _ = svd(x)
return s

vector_norm = get_xp(da)(_linalg.vector_norm)
diagonal = get_xp(da)(_linalg.diagonal)

__all__ = linalg_all + ["EighResult", "QRResult", "SlogdetResult",
"SVDResult", "qr", "cholesky", "matrix_rank", "matrix_norm",
"svdvals", "vector_norm", "diagonal"]

del get_xp
del da
del _linalg
import dask.array as _da
from dask.array import (
matmul,
outer,
tensordot,
trace,
)
from dask.array.linalg import * # noqa: F401, F403

from .._internal import _get_all_public_members
from ._aliases import (
EighResult,
QRResult,
SlogdetResult,
SVDResult,
cholesky,
diagonal,
matrix_norm,
matrix_rank,
matrix_transpose,
qr,
svdvals,
vecdot,
vector_norm,
)

__all__ = [
"matmul",
"outer",
"tensordot",
"trace",
]

__all__ += _get_all_public_members(_da.linalg)

__all__ += [
"EighResult",
"QRResult",
"SVDResult",
"SlogdetResult",
"cholesky",
"diagonal",
"matrix_norm",
"matrix_rank",
"matrix_transpose",
"qr",
"svdvals",
"vecdot",
"vector_norm",
]
152 changes: 144 additions & 8 deletions array_api_compat/numpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,150 @@
from numpy import *
from numpy import * # noqa: F401, F403
from numpy import __all__ as _numpy_all

# from numpy import * doesn't overwrite these builtin names
from numpy import abs, max, min, round

from ..common._helpers import (
array_namespace,
device,
get_namespace,
is_array_api_obj,
size,
to_device,
)

# These imports may overwrite names from the import * above.
from ._aliases import *
from ._aliases import (
UniqueAllResult,
UniqueCountsResult,
UniqueInverseResult,
acos,
acosh,
arange,
argsort,
asarray,
asarray_numpy,
asin,
asinh,
astype,
atan,
atan2,
atanh,
bitwise_invert,
bitwise_left_shift,
bitwise_right_shift,
bool,
ceil,
concat,
empty,
empty_like,
eye,
floor,
full,
full_like,
isdtype,
linspace,
matmul,
matrix_transpose,
nonzero,
ones,
ones_like,
permute_dims,
pow,
prod,
reshape,
sort,
std,
sum,
tensordot,
trunc,
unique_all,
unique_counts,
unique_inverse,
unique_values,
var,
vecdot,
zeros,
zeros_like,
)

__all__ = []

__all__ += _numpy_all

__all__ += [
"abs",
"max",
"min",
"round",
]

__all__ += [
"array_namespace",
"device",
"get_namespace",
"is_array_api_obj",
"size",
"to_device",
]

__all__ += [
"UniqueAllResult",
"UniqueCountsResult",
"UniqueInverseResult",
"acos",
"acosh",
"arange",
"argsort",
"asarray",
"asarray_numpy",
"asin",
"asinh",
"astype",
"atan",
"atan2",
"atanh",
"bitwise_invert",
"bitwise_left_shift",
"bitwise_right_shift",
"bool",
"ceil",
"concat",
"empty",
"empty_like",
"eye",
"floor",
"full",
"full_like",
"isdtype",
"linspace",
"matmul",
"matrix_transpose",
"nonzero",
"ones",
"ones_like",
"permute_dims",
"pow",
"prod",
"reshape",
"sort",
"std",
"sum",
"tensordot",
"trunc",
"unique_all",
"unique_counts",
"unique_inverse",
"unique_values",
"var",
"zeros",
"zeros_like",
]

__all__ += [
"matrix_transpose",
"vecdot",
]

# Don't know why, but we have to do an absolute import to import linalg. If we
# instead do
@@ -13,10 +153,6 @@
#
# It doesn't overwrite np.linalg from above. The import is generated
# dynamically so that the library can be vendored.
__import__(__package__ + '.linalg')

from .linalg import matrix_transpose, vecdot

from ..common._helpers import *
__import__(__package__ + ".linalg")

__array_api_version__ = '2022.12'
__array_api_version__ = "2022.12"
40 changes: 30 additions & 10 deletions array_api_compat/numpy/_aliases.py
Original file line number Diff line number Diff line change
@@ -2,15 +2,14 @@

from functools import partial

from ..common import _aliases
import numpy as np

from .._internal import get_xp
from ..common import _aliases, _linalg

asarray = asarray_numpy = partial(_aliases._asarray, namespace='numpy')
asarray = asarray_numpy = partial(_aliases._asarray, namespace="numpy")
asarray.__doc__ = _aliases._asarray.__doc__
del partial

import numpy as np
bool = np.bool_

# Basic renames
@@ -64,16 +63,37 @@

# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
if hasattr(np, 'vecdot'):
if hasattr(np, "vecdot"):
vecdot = np.vecdot
else:
vecdot = get_xp(np)(_aliases.vecdot)
if hasattr(np, 'isdtype'):
if hasattr(np, "isdtype"):
isdtype = np.isdtype
else:
isdtype = get_xp(np)(_aliases.isdtype)

__all__ = _aliases.__all__ + ['asarray', 'asarray_numpy', 'bool', 'acos',
'acosh', 'asin', 'asinh', 'atan', 'atan2',
'atanh', 'bitwise_left_shift', 'bitwise_invert',
'bitwise_right_shift', 'concat', 'pow']

cross = get_xp(np)(_linalg.cross)
outer = get_xp(np)(_linalg.outer)
EighResult = _linalg.EighResult
QRResult = _linalg.QRResult
SlogdetResult = _linalg.SlogdetResult
SVDResult = _linalg.SVDResult
eigh = get_xp(np)(_linalg.eigh)
qr = get_xp(np)(_linalg.qr)
slogdet = get_xp(np)(_linalg.slogdet)
svd = get_xp(np)(_linalg.svd)
cholesky = get_xp(np)(_linalg.cholesky)
matrix_rank = get_xp(np)(_linalg.matrix_rank)
pinv = get_xp(np)(_linalg.pinv)
matrix_norm = get_xp(np)(_linalg.matrix_norm)
svdvals = get_xp(np)(_linalg.svdvals)
diagonal = get_xp(np)(_linalg.diagonal)
trace = get_xp(np)(_linalg.trace)

# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
if hasattr(np.linalg, "vector_norm"):
vector_norm = np.linalg.vector_norm
else:
vector_norm = get_xp(np)(_linalg.vector_norm)
2 changes: 1 addition & 1 deletion array_api_compat/numpy/_typing.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from __future__ import annotations

__all__ = [
"ndarray",
"Device",
"Dtype",
"ndarray",
]

import sys
103 changes: 63 additions & 40 deletions array_api_compat/numpy/linalg.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,63 @@
from numpy.linalg import *
from numpy.linalg import __all__ as linalg_all

from ..common import _linalg
from .._internal import get_xp
from ._aliases import (matmul, matrix_transpose, tensordot, vecdot)

import numpy as np

cross = get_xp(np)(_linalg.cross)
outer = get_xp(np)(_linalg.outer)
EighResult = _linalg.EighResult
QRResult = _linalg.QRResult
SlogdetResult = _linalg.SlogdetResult
SVDResult = _linalg.SVDResult
eigh = get_xp(np)(_linalg.eigh)
qr = get_xp(np)(_linalg.qr)
slogdet = get_xp(np)(_linalg.slogdet)
svd = get_xp(np)(_linalg.svd)
cholesky = get_xp(np)(_linalg.cholesky)
matrix_rank = get_xp(np)(_linalg.matrix_rank)
pinv = get_xp(np)(_linalg.pinv)
matrix_norm = get_xp(np)(_linalg.matrix_norm)
svdvals = get_xp(np)(_linalg.svdvals)
diagonal = get_xp(np)(_linalg.diagonal)
trace = get_xp(np)(_linalg.trace)

# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
if hasattr(np.linalg, 'vector_norm'):
vector_norm = np.linalg.vector_norm
else:
vector_norm = get_xp(np)(_linalg.vector_norm)

__all__ = linalg_all + _linalg.__all__

del get_xp
del np
del linalg_all
del _linalg
import numpy as _np

from .._internal import _get_all_public_members

_numpy_linalg_all = _get_all_public_members(_np.linalg)

for _name in _numpy_linalg_all:
globals()[_name] = getattr(_np.linalg, _name)


from ._aliases import ( # noqa: E402
EighResult,
QRResult,
SlogdetResult,
SVDResult,
cholesky,
cross,
diagonal,
eigh,
matmul,
matrix_norm,
matrix_rank,
matrix_transpose,
outer,
pinv,
qr,
slogdet,
svd,
svdvals,
tensordot,
trace,
vecdot,
vector_norm,
)

__all__ = []

__all__ += _numpy_linalg_all

__all__ += [
"EighResult",
"QRResult",
"SVDResult",
"SlogdetResult",
"cholesky",
"cross",
"diagonal",
"eigh",
"matmul",
"matrix_norm",
"matrix_rank",
"matrix_transpose",
"outer",
"pinv",
"qr",
"slogdet",
"svd",
"svdvals",
"tensordot",
"trace",
"vecdot",
"vector_norm",
]
199 changes: 183 additions & 16 deletions array_api_compat/torch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,189 @@
from torch import *

# Several names are not included in the above import *
import torch
for n in dir(torch):
if (n.startswith('_')
or n.endswith('_')
or 'cuda' in n
or 'cpu' in n
or 'backward' in n):
continue
exec(n + ' = torch.' + n)
import torch as _torch
from torch import * # noqa: F401, F403

from .._internal import _get_all_public_members


def exclude(name):
    """Return True for torch attribute names that should not be re-exported.

    Filters out private/trailing-underscore names and anything mentioning
    "cuda", "cpu", or "backward", which are device/autograd helpers rather
    than array-API functions.
    """
    return (
        name.startswith("_")
        or name.endswith("_")
        or "cuda" in name
        or "cpu" in name
        or "backward" in name
    )


# Backwards-compatible alias for the original (misspelled) name, so existing
# references to `exlcude` keep working.
exlcude = exclude


_torch_all = _get_all_public_members(_torch, exclude=exlcude, extend_all=True)

for _name in _torch_all:
globals()[_name] = getattr(_torch, _name)


from ..common._helpers import ( # noqa: E402
array_namespace,
device,
get_namespace,
is_array_api_obj,
size,
to_device,
)

# These imports may overwrite names from the import * above.
from ._aliases import *
from ._aliases import ( # noqa: E402
add,
all,
any,
arange,
astype,
atan2,
bitwise_and,
bitwise_invert,
bitwise_left_shift,
bitwise_or,
bitwise_right_shift,
bitwise_xor,
broadcast_arrays,
broadcast_to,
can_cast,
concat,
divide,
empty,
equal,
expand_dims,
eye,
flip,
floor_divide,
full,
greater,
greater_equal,
isdtype,
less,
less_equal,
linspace,
logaddexp,
matmul,
matrix_transpose,
max,
mean,
min,
multiply,
newaxis,
nonzero,
not_equal,
ones,
permute_dims,
pow,
prod,
remainder,
reshape,
result_type,
roll,
sort,
squeeze,
std,
subtract,
sum,
take,
tensordot,
tril,
triu,
unique_all,
unique_counts,
unique_inverse,
unique_values,
var,
vecdot,
where,
zeros,
)

# See the comment in the numpy __init__.py
__import__(__package__ + '.linalg')
__all__ = []

__all__ += _torch_all

__all__ += [
"array_namespace",
"device",
"get_namespace",
"is_array_api_obj",
"size",
"to_device",
]

from ..common._helpers import *
__all__ += [
"add",
"all",
"any",
"arange",
"astype",
"atan2",
"bitwise_and",
"bitwise_invert",
"bitwise_left_shift",
"bitwise_or",
"bitwise_right_shift",
"bitwise_xor",
"broadcast_arrays",
"broadcast_to",
"can_cast",
"concat",
"divide",
"empty",
"equal",
"expand_dims",
"eye",
"flip",
"floor_divide",
"full",
"greater",
"greater_equal",
"isdtype",
"less",
"less_equal",
"linspace",
"logaddexp",
"matmul",
"matrix_transpose",
"max",
"mean",
"min",
"multiply",
"newaxis",
"nonzero",
"not_equal",
"ones",
"permute_dims",
"pow",
"prod",
"remainder",
"reshape",
"result_type",
"roll",
"sort",
"squeeze",
"std",
"subtract",
"sum",
"take",
"tensordot",
"tril",
"triu",
"unique_all",
"unique_counts",
"unique_inverse",
"unique_values",
"var",
"vecdot",
"where",
"zeros",
]


# See the comment in the numpy __init__.py
__import__(__package__ + ".linalg")

__array_api_version__ = '2022.12'
__array_api_version__ = "2022.12"
71 changes: 50 additions & 21 deletions array_api_compat/torch/_aliases.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
from __future__ import annotations

from builtins import all as builtin_all
from builtins import any as builtin_any
from functools import wraps
from builtins import all as builtin_all, any as builtin_any

from ..common._aliases import (UniqueAllResult, UniqueCountsResult,
UniqueInverseResult,
matrix_transpose as _aliases_matrix_transpose,
vecdot as _aliases_vecdot)
from .._internal import get_xp
from typing import TYPE_CHECKING

import torch

from typing import TYPE_CHECKING
from .._internal import get_xp
from ..common._aliases import UniqueAllResult, UniqueCountsResult, UniqueInverseResult
from ..common._aliases import matrix_transpose as _aliases_matrix_transpose
from ..common._aliases import vecdot as _aliases_vecdot

if TYPE_CHECKING:
from typing import List, Optional, Sequence, Tuple, Union
from ..common._typing import Device

from torch import dtype as Dtype

from ..common._typing import Device

array = torch.Tensor

_int_dtypes = {
@@ -693,15 +695,42 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
axis = 0
return torch.index_select(x, axis, indices, **kwargs)

__all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'newaxis',
'add', 'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or',
'bitwise_right_shift', 'bitwise_xor', 'divide', 'equal',
'floor_divide', 'greater', 'greater_equal', 'less', 'less_equal',
'logaddexp', 'multiply', 'not_equal', 'pow', 'remainder',
'subtract', 'max', 'min', 'sort', 'prod', 'sum', 'any', 'all',
'mean', 'std', 'var', 'concat', 'squeeze', 'broadcast_to', 'flip', 'roll',
'nonzero', 'where', 'reshape', 'arange', 'eye', 'linspace', 'full',
'ones', 'zeros', 'empty', 'tril', 'triu', 'expand_dims', 'astype',
'broadcast_arrays', 'unique_all', 'unique_counts',
'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose',
'vecdot', 'tensordot', 'isdtype', 'take']


# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did you move these to this file? These were in linalg.py because they're linalg only functions.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I understand. It seemed conceptually simpler in the end to just use linalg.py the same way as __init__.py. So it only contains the explicit imports and the __all__ declarations. This avoids the need for del statements and if something is missing from __all__ it errors as an unused import during the style check. I hope this makes sense. Otherwise I could maybe just introduce _aliases_linalg.py or similar.

x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
return torch.linalg.cross(x1, x2, dim=axis)

def vecdot_linalg(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
    """Vector dot product along ``axis``, wrapping ``torch.linalg.vecdot``.

    Adds a manual fallback for integer dtypes, which
    ``torch.linalg.vecdot`` does not support.

    Raises:
        RuntimeError: if kwargs are passed with integral-dtype inputs.
        ValueError: if the operands' sizes along ``axis`` differ.
    """
    # Note: `isdtype` is defined at module level in this file, so the
    # lazy self-import (`from ._aliases import isdtype`) that was needed
    # when this code lived in linalg.py is no longer required.
    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)

    # torch.linalg.vecdot doesn't support integer dtypes; emulate it by
    # broadcasting, moving the reduction axis last, and contracting via
    # a batched matmul.
    if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'):
        if kwargs:
            raise RuntimeError("vecdot kwargs not supported for integral dtypes")
        ndim = max(x1.ndim, x2.ndim)
        x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
        x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
        if x1_shape[axis] != x2_shape[axis]:
            raise ValueError("x1 and x2 must have the same size along the given axis")

        x1_, x2_ = torch.broadcast_tensors(x1, x2)
        x1_ = torch.moveaxis(x1_, axis, -1)
        x2_ = torch.moveaxis(x2_, axis, -1)

        res = x1_[..., None, :] @ x2_[..., None]
        return res[..., 0, 0]
    return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs)

def solve(x1: array, x2: array, /, **kwargs) -> array:
    """Solve the linear system ``x1 @ X == x2`` via ``torch.linalg.solve``.

    ``_fix_promotion`` first reconciles the operands' dtypes, matching the
    pattern used by the other torch wrappers in this module.
    """
    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
    return torch.linalg.solve(x1, x2, **kwargs)

# torch.trace doesn't support the offset argument and doesn't support stacking
def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array:
    """Return the sum of the ``offset``-th diagonal of (stacks of) matrices.

    Implemented as ``torch.diagonal`` over the last two axes followed by a
    reduction, rather than ``torch.trace`` (see comment above).
    """
    # Use our wrapped sum to make sure it does upcasting correctly
    return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)
86 changes: 29 additions & 57 deletions array_api_compat/torch/linalg.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,34 @@
from __future__ import annotations
import torch as _torch

from typing import TYPE_CHECKING
if TYPE_CHECKING:
import torch
array = torch.Tensor
from torch import dtype as Dtype
from typing import Optional
from .._internal import _get_all_public_members

from torch.linalg import *
_torch_linalg_all = _get_all_public_members(_torch.linalg)

# torch.linalg doesn't define __all__
# from torch.linalg import __all__ as linalg_all
from torch import linalg as torch_linalg
linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')]
for _name in _torch_linalg_all:
globals()[_name] = getattr(_torch.linalg, _name)

# outer is implemented in torch but isn't in the linalg namespace
from torch import outer
from ._aliases import _fix_promotion, matrix_transpose, tensordot, sum

# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
return torch_linalg.cross(x1, x2, dim=axis)

def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
from ._aliases import isdtype

x1, x2 = _fix_promotion(x1, x2, only_scalar=False)

# torch.linalg.vecdot doesn't support integer dtypes
if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'):
if kwargs:
raise RuntimeError("vecdot kwargs not supported for integral dtypes")
ndim = max(x1.ndim, x2.ndim)
x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
if x1_shape[axis] != x2_shape[axis]:
raise ValueError("x1 and x2 must have the same size along the given axis")

x1_, x2_ = torch.broadcast_tensors(x1, x2)
x1_ = torch.moveaxis(x1_, axis, -1)
x2_ = torch.moveaxis(x2_, axis, -1)

res = x1_[..., None, :] @ x2_[..., None]
return res[..., 0, 0]
return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs)

def solve(x1: array, x2: array, /, **kwargs) -> array:
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
return torch.linalg.solve(x1, x2, **kwargs)

# torch.trace doesn't support the offset argument and doesn't support stacking
def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array:
# Use our wrapped sum to make sure it does upcasting correctly
return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)

__all__ = linalg_all + ['outer', 'trace', 'matrix_transpose', 'tensordot',
'vecdot', 'solve']

del linalg_all
outer = _torch.outer

from ._aliases import ( # noqa: E402
matrix_transpose,
solve,
sum,
tensordot,
trace,
vecdot_linalg as vecdot,
)

__all__ = []

__all__ += _torch_linalg_all

__all__ += [
"matrix_transpose",
"outer",
"solve",
"sum",
"tensordot",
"trace",
"vecdot",
]
5 changes: 3 additions & 2 deletions tests/_helpers.py
Original file line number Diff line number Diff line change
@@ -2,7 +2,8 @@

import pytest

def import_(library):
if 'cupy' in library:

def import_or_skip_cupy(library):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI, I'm probably going to rename this back, because I need to add some additional skipping logic for jax as well at #84

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. I think it would be good to keep the "or skip" in the name, because from looking at the places where it was used, it was not clear what it does.

if "cupy" in library:
return pytest.importorskip(library)
return import_module(library)
27 changes: 14 additions & 13 deletions tests/test_array_namespace.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
import numpy as np
import pytest
import torch

import array_api_compat
from array_api_compat import array_namespace

from ._helpers import import_
from ._helpers import import_or_skip_cupy

import pytest

@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"])
@pytest.mark.parametrize("api_version", [None, '2021.12'])
@pytest.mark.parametrize("api_version", [None, "2021.12"])
def test_array_namespace(library, api_version):
lib = import_(library)
xp = import_or_skip_cupy(library)

array = lib.asarray([1.0, 2.0, 3.0])
array = xp.asarray([1.0, 2.0, 3.0])
namespace = array_api_compat.array_namespace(array, api_version=api_version)

if 'array_api' in library:
assert namespace == lib
if "array_api" in library:
assert namespace == xp
else:
if library == "dask.array":
assert namespace == array_api_compat.dask.array
@@ -26,18 +29,16 @@ def test_array_namespace_errors():
pytest.raises(TypeError, lambda: array_namespace([1]))
pytest.raises(TypeError, lambda: array_namespace())

import numpy as np
x = np.asarray([1, 2])

pytest.raises(TypeError, lambda: array_namespace((x, x)))
pytest.raises(TypeError, lambda: array_namespace(x, (x, x)))

import torch
y = torch.asarray([1, 2])

def test_array_namespace_errors_torch():
y = torch.asarray([1, 2])
x = np.asarray([1, 2])
pytest.raises(TypeError, lambda: array_namespace(x, y))

pytest.raises(ValueError, lambda: array_namespace(x, api_version='2022.12'))
pytest.raises(ValueError, lambda: array_namespace(x, api_version="2022.12"))


def test_get_namespace():
13 changes: 8 additions & 5 deletions tests/test_common.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
from ._helpers import import_
from array_api_compat import to_device, device

import pytest
import numpy as np
import pytest
from numpy.testing import assert_allclose

from array_api_compat import to_device

from ._helpers import import_or_skip_cupy


@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"])
def test_to_device_host(library):
# different libraries have different semantics
# for DtoH transfers; ensure that we support a portable
# shim for common array libs
# see: https://github.com/scipy/scipy/issues/18286#issuecomment-1527552919
xp = import_('array_api_compat.' + library)
xp = import_or_skip_cupy("array_api_compat." + library)

expected = np.array([1, 2, 3])
x = xp.asarray([1, 2, 3])
x = to_device(x, "cpu")
10 changes: 5 additions & 5 deletions tests/test_isdtype.py
Original file line number Diff line number Diff line change
@@ -3,10 +3,10 @@
non-spec dtypes
"""

from ._helpers import import_

import pytest

from ._helpers import import_or_skip_cupy

# Check the known dtypes by their string names

def _spec_dtypes(library):
@@ -61,12 +61,12 @@ def isdtype_(dtype_, kind):
res = dtype_categories[kind](dtype_)
else:
res = dtype_ == kind
assert type(res) is bool
assert type(res) is bool # noqa: E721
return res

@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"])
def test_isdtype_spec_dtypes(library):
xp = import_('array_api_compat.' + library)
xp = import_or_skip_cupy('array_api_compat.' + library)

isdtype = xp.isdtype

@@ -101,7 +101,7 @@ def test_isdtype_spec_dtypes(library):
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"])
@pytest.mark.parametrize("dtype_", additional_dtypes)
def test_isdtype_additional_dtypes(library, dtype_):
xp = import_('array_api_compat.' + library)
xp = import_or_skip_cupy('array_api_compat.' + library)

isdtype = xp.isdtype

12 changes: 7 additions & 5 deletions tests/test_vendoring.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
from pytest import skip
import pytest


def test_vendoring_numpy():
    """Run the vendor_test package's numpy self-test."""
    from vendor_test import uses_numpy

    uses_numpy._test_numpy()


def test_vendoring_cupy():
try:
import cupy
except ImportError:
skip("CuPy is not installed")
pytest.importorskip("cupy")

from vendor_test import uses_cupy

uses_cupy._test_cupy()


def test_vendoring_torch():
    """Run the vendor_test package's torch self-test."""
    from vendor_test import uses_torch

    uses_torch._test_torch()

def test_vendoring_dask():