Skip to content

Revert __all__ related changes from #82 #95

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Feb 26, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ruff.yml
Original file line number Diff line number Diff line change
@@ -16,4 +16,4 @@ jobs:
pip install ruff
# Update output format to enable automatic inline annotations.
- name: Run Ruff
run: ruff check --output-format=github --select F822,PLC0414,RUF022 --preview .
run: ruff check --output-format=github .
29 changes: 0 additions & 29 deletions array_api_compat/_internal.py
Original file line number Diff line number Diff line change
@@ -5,7 +5,6 @@
from functools import wraps
from inspect import signature


def get_xp(xp):
"""
Decorator to automatically replace xp with the corresponding array module.
@@ -45,31 +44,3 @@ def wrapped_f(*args, **kwargs):
return wrapped_f

return inner


def _get_all_public_members(module, exclude=None, extend_all=False):
"""Get all public members of a module.

Parameters
----------
module : module
The module to get members from.
exclude : callable, optional
A callable that takes a name and returns True if the name should be
excluded from the list of members.
extend_all : bool, optional
If True, extend the module's __all__ attribute with the members of the
module derived from dir(module). To be used for libraries that do not have a complete __all__ list.
"""
members = getattr(module, "__all__", [])

if members and not extend_all:
return members

if exclude is None:
exclude = lambda name: name.startswith("_") # noqa: E731

members = members + [_ for _ in dir(module) if not exclude(_)]

# remove duplicates
return list(set(members))
28 changes: 1 addition & 27 deletions array_api_compat/common/__init__.py
Original file line number Diff line number Diff line change
@@ -1,27 +1 @@
from ._helpers import (
array_namespace,
device,
get_namespace,
is_array_api_obj,
is_cupy_array,
is_dask_array,
is_jax_array,
is_numpy_array,
is_torch_array,
size,
to_device,
)

__all__ = [
"array_namespace",
"device",
"get_namespace",
"is_array_api_obj",
"is_cupy_array",
"is_dask_array",
"is_jax_array",
"is_numpy_array",
"is_torch_array",
"size",
"to_device",
]
from ._helpers import * # noqa: F403
11 changes: 11 additions & 0 deletions array_api_compat/common/_aliases.py
Original file line number Diff line number Diff line change
@@ -146,6 +146,9 @@ def zeros_like(

# The functions here return namedtuples (np.unique() returns a normal
# tuple).

# Note that these named tuples aren't actually part of the standard namespace,
# but I don't see any issue with exporting the names here regardless.
class UniqueAllResult(NamedTuple):
values: ndarray
indices: ndarray
@@ -545,3 +548,11 @@ def isdtype(
# more strict here to match the type annotation? Note that the
# array_api_strict implementation will be very strict.
return dtype == kind

__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like',
'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
'astype', 'std', 'var', 'permute_dims', 'reshape', 'argsort',
'sort', 'nonzero', 'sum', 'prod', 'ceil', 'floor', 'trunc',
'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype']
16 changes: 16 additions & 0 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
@@ -288,3 +288,19 @@ def size(x):
if None in x.shape:
return None
return math.prod(x.shape)

__all__ = [
"array_namespace",
"device",
"get_namespace",
"is_array_api_obj",
"is_cupy_array",
"is_dask_array",
"is_jax_array",
"is_numpy_array",
"is_torch_array",
"size",
"to_device",
]

_all_ignore = ['sys', 'math', 'inspect']
10 changes: 8 additions & 2 deletions array_api_compat/common/_linalg.py
Original file line number Diff line number Diff line change
@@ -11,7 +11,7 @@
else:
from numpy.core.numeric import normalize_axis_tuple

from ._aliases import matrix_transpose, isdtype
from ._aliases import matmul, matrix_transpose, tensordot, vecdot, isdtype
from .._internal import get_xp

# These are in the main NumPy namespace but not in numpy.linalg
@@ -149,4 +149,10 @@ def trace(x: ndarray, /, xp, *, offset: int = 0, dtype=None, **kwargs) -> ndarra
dtype = xp.float64
elif x.dtype == xp.complex64:
dtype = xp.complex128
return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs))
return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs))

__all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult',
'QRResult', 'SlogdetResult', 'SVDResult', 'eigh', 'qr', 'slogdet',
'svd', 'cholesky', 'matrix_rank', 'pinv', 'matrix_norm',
'matrix_transpose', 'svdvals', 'vecdot', 'vector_norm', 'diagonal',
'trace']
2 changes: 1 addition & 1 deletion array_api_compat/common/_typing.py
Original file line number Diff line number Diff line change
@@ -20,4 +20,4 @@ def __len__(self, /) -> int: ...
SupportsBufferProtocol = Any

Array = Any
Device = Any
Device = Any
153 changes: 7 additions & 146 deletions array_api_compat/cupy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,153 +1,14 @@
import cupy as _cp
from cupy import * # noqa: F401, F403
from cupy import * # noqa: F403

# from cupy import * doesn't overwrite these builtin names
from cupy import abs, max, min, round

from .._internal import _get_all_public_members
from ..common._helpers import (
array_namespace,
device,
get_namespace,
is_array_api_obj,
size,
to_device,
)
from cupy import abs, max, min, round # noqa: F401

# These imports may overwrite names from the import * above.
from ._aliases import (
UniqueAllResult,
UniqueCountsResult,
UniqueInverseResult,
acos,
acosh,
arange,
argsort,
asarray,
asarray_cupy,
asin,
asinh,
astype,
atan,
atan2,
atanh,
bitwise_invert,
bitwise_left_shift,
bitwise_right_shift,
bool,
ceil,
concat,
empty,
empty_like,
eye,
floor,
full,
full_like,
isdtype,
linspace,
matmul,
matrix_transpose,
nonzero,
ones,
ones_like,
permute_dims,
pow,
prod,
reshape,
sort,
std,
sum,
tensordot,
trunc,
unique_all,
unique_counts,
unique_inverse,
unique_values,
var,
vecdot,
zeros,
zeros_like,
)

__all__ = []

__all__ += _get_all_public_members(_cp)

__all__ += [
"abs",
"max",
"min",
"round",
]

__all__ += [
"array_namespace",
"device",
"get_namespace",
"is_array_api_obj",
"size",
"to_device",
]

__all__ += [
"UniqueAllResult",
"UniqueCountsResult",
"UniqueInverseResult",
"acos",
"acosh",
"arange",
"argsort",
"asarray",
"asarray_cupy",
"asin",
"asinh",
"astype",
"atan",
"atan2",
"atanh",
"bitwise_invert",
"bitwise_left_shift",
"bitwise_right_shift",
"bool",
"ceil",
"concat",
"empty",
"empty_like",
"eye",
"floor",
"full",
"full_like",
"isdtype",
"linspace",
"matmul",
"matrix_transpose",
"nonzero",
"ones",
"ones_like",
"permute_dims",
"pow",
"prod",
"reshape",
"sort",
"std",
"sum",
"tensordot",
"trunc",
"unique_all",
"unique_counts",
"unique_inverse",
"unique_values",
"var",
"zeros",
"zeros_like",
]

__all__ += [
"matrix_transpose",
"vecdot",
]
from ._aliases import * # noqa: F403

# See the comment in the numpy __init__.py
__import__(__package__ + ".linalg")
__import__(__package__ + '.linalg')

from ..common._helpers import * # noqa: F401,F403

__array_api_version__ = "2022.12"
__array_api_version__ = '2022.12'
32 changes: 5 additions & 27 deletions array_api_compat/cupy/_aliases.py
Original file line number Diff line number Diff line change
@@ -5,12 +5,11 @@
import cupy as cp

from ..common import _aliases
from ..common import _linalg

from .._internal import get_xp

asarray = asarray_cupy = partial(_aliases._asarray, namespace='cupy')
asarray.__doc__ = _aliases._asarray.__doc__
del partial

bool = cp.bool_

@@ -74,28 +73,7 @@
else:
isdtype = get_xp(cp)(_aliases.isdtype)


cross = get_xp(cp)(_linalg.cross)
outer = get_xp(cp)(_linalg.outer)
EighResult = _linalg.EighResult
QRResult = _linalg.QRResult
SlogdetResult = _linalg.SlogdetResult
SVDResult = _linalg.SVDResult
eigh = get_xp(cp)(_linalg.eigh)
qr = get_xp(cp)(_linalg.qr)
slogdet = get_xp(cp)(_linalg.slogdet)
svd = get_xp(cp)(_linalg.svd)
cholesky = get_xp(cp)(_linalg.cholesky)
matrix_rank = get_xp(cp)(_linalg.matrix_rank)
pinv = get_xp(cp)(_linalg.pinv)
matrix_norm = get_xp(cp)(_linalg.matrix_norm)
svdvals = get_xp(cp)(_linalg.svdvals)
diagonal = get_xp(cp)(_linalg.diagonal)
trace = get_xp(cp)(_linalg.trace)

# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
if hasattr(cp.linalg, 'vector_norm'):
vector_norm = cp.linalg.vector_norm
else:
vector_norm = get_xp(cp)(_linalg.vector_norm)
__all__ = _aliases.__all__ + ['asarray', 'asarray_cupy', 'bool', 'acos',
'acosh', 'asin', 'asinh', 'atan', 'atan2',
'atanh', 'bitwise_left_shift', 'bitwise_invert',
'bitwise_right_shift', 'concat', 'pow']
2 changes: 1 addition & 1 deletion array_api_compat/cupy/_typing.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from __future__ import annotations

__all__ = [
"ndarray",
"Device",
"Dtype",
"ndarray",
]

import sys
97 changes: 42 additions & 55 deletions array_api_compat/cupy/linalg.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,49 @@
import cupy as _cp
from cupy.linalg import * # noqa: F403
# cupy.linalg doesn't have __all__. If it is added, replace this with
#
# from cupy.linalg import __all__ as linalg_all
_n = {}
exec('from cupy.linalg import *', _n)
del _n['__builtins__']
linalg_all = list(_n)
del _n

from .._internal import _get_all_public_members
from ..common import _linalg
from .._internal import get_xp

_cupy_linalg_all = _get_all_public_members(_cp.linalg)
import cupy as cp

for _name in _cupy_linalg_all:
globals()[_name] = getattr(_cp.linalg, _name)
# These functions are in both the main and linalg namespaces
from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401

from ._aliases import ( # noqa: E402
EighResult,
QRResult,
SlogdetResult,
SVDResult,
cholesky,
cross,
diagonal,
eigh,
matmul,
matrix_norm,
matrix_rank,
matrix_transpose,
outer,
pinv,
qr,
slogdet,
svd,
svdvals,
tensordot,
trace,
vecdot,
vector_norm,
)
cross = get_xp(cp)(_linalg.cross)
outer = get_xp(cp)(_linalg.outer)
EighResult = _linalg.EighResult
QRResult = _linalg.QRResult
SlogdetResult = _linalg.SlogdetResult
SVDResult = _linalg.SVDResult
eigh = get_xp(cp)(_linalg.eigh)
qr = get_xp(cp)(_linalg.qr)
slogdet = get_xp(cp)(_linalg.slogdet)
svd = get_xp(cp)(_linalg.svd)
cholesky = get_xp(cp)(_linalg.cholesky)
matrix_rank = get_xp(cp)(_linalg.matrix_rank)
pinv = get_xp(cp)(_linalg.pinv)
matrix_norm = get_xp(cp)(_linalg.matrix_norm)
svdvals = get_xp(cp)(_linalg.svdvals)
diagonal = get_xp(cp)(_linalg.diagonal)
trace = get_xp(cp)(_linalg.trace)

__all__ = []
# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
if hasattr(cp.linalg, 'vector_norm'):
vector_norm = cp.linalg.vector_norm
else:
vector_norm = get_xp(cp)(_linalg.vector_norm)

__all__ += _cupy_linalg_all
__all__ = linalg_all + _linalg.__all__

__all__ += [
"EighResult",
"QRResult",
"SVDResult",
"SlogdetResult",
"cholesky",
"cross",
"diagonal",
"eigh",
"matmul",
"matrix_norm",
"matrix_rank",
"matrix_transpose",
"outer",
"pinv",
"qr",
"slogdet",
"svd",
"svdvals",
"tensordot",
"trace",
"vecdot",
"vector_norm",
]
del get_xp
del cp
del linalg_all
del _linalg
210 changes: 4 additions & 206 deletions array_api_compat/dask/array/__init__.py
Original file line number Diff line number Diff line change
@@ -1,210 +1,8 @@
import dask.array as _da
from dask.array import * # noqa: F401, F403
from dask.array import (
# Element wise aliases
arccos as acos,
)
from dask.array import (
arccosh as acosh,
)
from dask.array import (
arcsin as asin,
)
from dask.array import (
arcsinh as asinh,
)
from dask.array import (
arctan as atan,
)
from dask.array import (
arctan2 as atan2,
)
from dask.array import (
arctanh as atanh,
)
from dask.array import (
bool_ as bool,
)
from dask.array import (
# Other
concatenate as concat,
)
from dask.array import (
invert as bitwise_invert,
)
from dask.array import (
left_shift as bitwise_left_shift,
)
from dask.array import (
power as pow,
)
from dask.array import (
right_shift as bitwise_right_shift,
)
from dask.array import * # noqa: F403

# These imports may overwrite names from the import * above.
from numpy import (
can_cast,
complex64,
complex128,
e,
finfo,
float32,
float64,
iinfo,
inf,
int8,
int16,
int32,
int64,
nan,
newaxis,
pi,
result_type,
uint8,
uint16,
uint32,
uint64,
)
from ._aliases import * # noqa: F403

from ..common._helpers import (
array_namespace,
device,
get_namespace,
is_array_api_obj,
size,
to_device,
)
from ..internal import _get_all_public_members
from ._aliases import (
UniqueAllResult,
UniqueCountsResult,
UniqueInverseResult,
arange,
asarray,
astype,
ceil,
empty,
empty_like,
eye,
floor,
full,
full_like,
isdtype,
linspace,
matmul,
matrix_transpose,
nonzero,
ones,
ones_like,
permute_dims,
prod,
reshape,
std,
sum,
tensordot,
trunc,
unique_all,
unique_counts,
unique_inverse,
unique_values,
var,
vecdot,
zeros,
zeros_like,
)
__array_api_version__ = '2022.12'

__all__ = []

__all__ += _get_all_public_members(_da)

__all__ += [
"can_cast",
"complex64",
"complex128",
"e",
"finfo",
"float32",
"float64",
"iinfo",
"inf",
"int8",
"int16",
"int32",
"int64",
"nan",
"newaxis",
"pi",
"result_type",
"uint8",
"uint16",
"uint32",
"uint64",
]

__all__ += [
"array_namespace",
"device",
"get_namespace",
"is_array_api_obj",
"size",
"to_device",
]

# 'sort', 'argsort' are unsupported by dask.array

__all__ += [
"UniqueAllResult",
"UniqueCountsResult",
"UniqueInverseResult",
"acos",
"acosh",
"arange",
"asarray",
"asin",
"asinh",
"astype",
"atan",
"atan2",
"atanh",
"bitwise_invert",
"bitwise_left_shift",
"bitwise_right_shift",
"bool",
"ceil",
"concat",
"empty",
"empty_like",
"eye",
"floor",
"full",
"full_like",
"isdtype",
"linspace",
"matmul",
"matrix_transpose",
"nonzero",
"ones",
"ones_like",
"permute_dims",
"pow",
"prod",
"reshape",
"std",
"sum",
"tensordot",
"trunc",
"unique_all",
"unique_counts",
"unique_inverse",
"unique_values",
"var",
"vecdot",
"zeros",
"zeros_like",
]


__array_api_version__ = "2022.12"

__import__(__package__ + ".linalg")
__import__(__package__ + '.linalg')
101 changes: 68 additions & 33 deletions array_api_compat/dask/array/_aliases.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,42 @@
from __future__ import annotations

from functools import partial
from typing import TYPE_CHECKING

import numpy as np
from ...common import _aliases
from ...common._helpers import _check_device

from ..._internal import get_xp
from ...common import _aliases, _linalg
from ...common._helpers import _check_device

if TYPE_CHECKING:
from typing import Optional, Tuple, Union
import numpy as np
from numpy import (
# Constants
e,
inf,
nan,
pi,
newaxis,
# Dtypes
bool_ as bool,
float32,
float64,
int8,
int16,
int32,
int64,
uint8,
uint16,
uint32,
uint64,
complex64,
complex128,
iinfo,
finfo,
can_cast,
result_type,
)

from ...common._typing import Device, Dtype, ndarray
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Optional, Union
from ...common._typing import ndarray, Device, Dtype

import dask.array as da

@@ -25,9 +49,8 @@
# not pass stop/step as keyword arguments, which will cause
# an error with dask


# TODO: delete the xp stuff, it shouldn't be necessary
def dask_arange(
def _dask_arange(
start: Union[int, float],
/,
stop: Optional[Union[int, float]] = None,
@@ -49,11 +72,11 @@ def dask_arange(
args.append(step)
return xp.arange(*args, dtype=dtype, **kwargs)


arange = get_xp(da)(dask_arange)
arange = get_xp(da)(_dask_arange)
eye = get_xp(da)(_aliases.eye)

asarray = partial(_aliases._asarray, namespace="dask.array")
from functools import partial
asarray = partial(_aliases._asarray, namespace='dask.array')
asarray.__doc__ = _aliases._asarray.__doc__

linspace = get_xp(da)(_aliases.linspace)
@@ -89,22 +112,34 @@ def dask_arange(
matmul = get_xp(np)(_aliases.matmul)
tensordot = get_xp(np)(_aliases.tensordot)


EighResult = _linalg.EighResult
QRResult = _linalg.QRResult
SlogdetResult = _linalg.SlogdetResult
SVDResult = _linalg.SVDResult
qr = get_xp(da)(_linalg.qr)
cholesky = get_xp(da)(_linalg.cholesky)
matrix_rank = get_xp(da)(_linalg.matrix_rank)
matrix_norm = get_xp(da)(_linalg.matrix_norm)


def svdvals(x: ndarray) -> Union[ndarray, Tuple[ndarray, ...]]:
# TODO: can't avoid computing U or V for dask
_, s, _ = da.linalg.svd(x)
return s


vector_norm = get_xp(da)(_linalg.vector_norm)
diagonal = get_xp(da)(_linalg.diagonal)
from dask.array import (
# Element wise aliases
arccos as acos,
arccosh as acosh,
arcsin as asin,
arcsinh as asinh,
arctan as atan,
arctan2 as atan2,
arctanh as atanh,
left_shift as bitwise_left_shift,
right_shift as bitwise_right_shift,
invert as bitwise_invert,
power as pow,
# Other
concatenate as concat,
)

# exclude these from all since
_da_unsupported = ['sort', 'argsort']

common_aliases = [alias for alias in _aliases.__all__ if alias not in _da_unsupported]

__all__ = common_aliases + ['asarray', 'bool', 'acos',
'acosh', 'asin', 'asinh', 'atan', 'atan2',
'atanh', 'bitwise_left_shift', 'bitwise_invert',
'bitwise_right_shift', 'concat', 'pow',
'e', 'inf', 'nan', 'pi', 'newaxis', 'float32', 'float64', 'int8',
'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64',
'complex64', 'complex128', 'iinfo', 'finfo', 'can_cast', 'result_type']

_all_ignore = ['get_xp', 'da', 'partial', 'common_aliases', 'np']
106 changes: 56 additions & 50 deletions array_api_compat/dask/array/linalg.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,56 @@
import dask.array as _da
from dask.array import (
matmul,
outer,
tensordot,
trace,
)
from dask.array.linalg import * # noqa: F401, F403

from .._internal import _get_all_public_members
from ._aliases import (
EighResult,
QRResult,
SlogdetResult,
SVDResult,
cholesky,
diagonal,
matrix_norm,
matrix_rank,
matrix_transpose,
qr,
svdvals,
vecdot,
vector_norm,
)

__all__ = [
"matmul",
"outer",
"tensordot",
"trace",
]

__all__ += _get_all_public_members(_da.linalg)

__all__ += [
"EighResult",
"QRResult",
"SVDResult",
"SlogdetResult",
"cholesky",
"diagonal",
"matrix_norm",
"matrix_rank",
"matrix_transpose",
"qr",
"svdvals",
"vecdot",
"vector_norm",
]
from __future__ import annotations

from dask.array.linalg import svd
from ...common import _linalg
from ..._internal import get_xp

# Exports
from dask.array.linalg import * # noqa: F403
from dask.array import trace, outer

# These functions are in both the main and linalg namespaces
from dask.array import matmul, tensordot
from ._aliases import matrix_transpose, vecdot

import dask.array as da

from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Union, Tuple
from ...common._typing import ndarray

# cupy.linalg doesn't have __all__. If it is added, replace this with
#
# from cupy.linalg import __all__ as linalg_all
_n = {}
exec('from dask.array.linalg import *', _n)
del _n['__builtins__']
if 'annotations' in _n:
del _n['annotations']
linalg_all = list(_n)
del _n

EighResult = _linalg.EighResult
QRResult = _linalg.QRResult
SlogdetResult = _linalg.SlogdetResult
SVDResult = _linalg.SVDResult
qr = get_xp(da)(_linalg.qr)
cholesky = get_xp(da)(_linalg.cholesky)
matrix_rank = get_xp(da)(_linalg.matrix_rank)
matrix_norm = get_xp(da)(_linalg.matrix_norm)

def svdvals(x: ndarray) -> Union[ndarray, Tuple[ndarray, ...]]:
# TODO: can't avoid computing U or V for dask
_, s, _ = svd(x)
return s

vector_norm = get_xp(da)(_linalg.vector_norm)
diagonal = get_xp(da)(_linalg.diagonal)

__all__ = linalg_all + ["trace", "outer", "matmul", "tensordot",
"matrix_transpose", "vecdot", "EighResult",
"QRResult", "SlogdetResult", "SVDResult", "qr",
"cholesky", "matrix_rank", "matrix_norm", "svdvals",
"vector_norm", "diagonal"]

_all_ignore = ['get_xp', 'da', 'linalg_all']
154 changes: 9 additions & 145 deletions array_api_compat/numpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,150 +1,10 @@
from numpy import * # noqa: F401, F403
from numpy import __all__ as _numpy_all
from numpy import * # noqa: F403

# from numpy import * doesn't overwrite these builtin names
from numpy import abs, max, min, round

from ..common._helpers import (
array_namespace,
device,
get_namespace,
is_array_api_obj,
size,
to_device,
)
from numpy import abs, max, min, round # noqa: F401

# These imports may overwrite names from the import * above.
from ._aliases import (
UniqueAllResult,
UniqueCountsResult,
UniqueInverseResult,
acos,
acosh,
arange,
argsort,
asarray,
asarray_numpy,
asin,
asinh,
astype,
atan,
atan2,
atanh,
bitwise_invert,
bitwise_left_shift,
bitwise_right_shift,
bool,
ceil,
concat,
empty,
empty_like,
eye,
floor,
full,
full_like,
isdtype,
linspace,
matmul,
matrix_transpose,
nonzero,
ones,
ones_like,
permute_dims,
pow,
prod,
reshape,
sort,
std,
sum,
tensordot,
trunc,
unique_all,
unique_counts,
unique_inverse,
unique_values,
var,
vecdot,
zeros,
zeros_like,
)

__all__ = []

__all__ += _numpy_all

__all__ += [
"abs",
"max",
"min",
"round",
]

__all__ += [
"array_namespace",
"device",
"get_namespace",
"is_array_api_obj",
"size",
"to_device",
]

__all__ += [
"UniqueAllResult",
"UniqueCountsResult",
"UniqueInverseResult",
"acos",
"acosh",
"arange",
"argsort",
"asarray",
"asarray_numpy",
"asin",
"asinh",
"astype",
"atan",
"atan2",
"atanh",
"bitwise_invert",
"bitwise_left_shift",
"bitwise_right_shift",
"bool",
"ceil",
"concat",
"empty",
"empty_like",
"eye",
"floor",
"full",
"full_like",
"isdtype",
"linspace",
"matmul",
"matrix_transpose",
"nonzero",
"ones",
"ones_like",
"permute_dims",
"pow",
"prod",
"reshape",
"sort",
"std",
"sum",
"tensordot",
"trunc",
"unique_all",
"unique_counts",
"unique_inverse",
"unique_values",
"var",
"zeros",
"zeros_like",
]

__all__ += [
"matrix_transpose",
"vecdot",
]
from ._aliases import * # noqa: F403

# Don't know why, but we have to do an absolute import to import linalg. If we
# instead do
@@ -153,6 +13,10 @@
#
# It doesn't overwrite np.linalg from above. The import is generated
# dynamically so that the library can be vendored.
__import__(__package__ + ".linalg")
__import__(__package__ + '.linalg')

from .linalg import matrix_transpose, vecdot # noqa: F401

from ..common._helpers import * # noqa: F403

__array_api_version__ = "2022.12"
__array_api_version__ = '2022.12'
40 changes: 11 additions & 29 deletions array_api_compat/numpy/_aliases.py
Original file line number Diff line number Diff line change
@@ -2,14 +2,15 @@

from functools import partial

import numpy as np
from ..common import _aliases

from .._internal import get_xp
from ..common import _aliases, _linalg

asarray = asarray_numpy = partial(_aliases._asarray, namespace="numpy")
asarray = asarray_numpy = partial(_aliases._asarray, namespace='numpy')
asarray.__doc__ = _aliases._asarray.__doc__
del partial

import numpy as np
bool = np.bool_

# Basic renames
@@ -63,37 +64,18 @@

# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
if hasattr(np, "vecdot"):
if hasattr(np, 'vecdot'):
vecdot = np.vecdot
else:
vecdot = get_xp(np)(_aliases.vecdot)
if hasattr(np, "isdtype"):
if hasattr(np, 'isdtype'):
isdtype = np.isdtype
else:
isdtype = get_xp(np)(_aliases.isdtype)

__all__ = _aliases.__all__ + ['asarray', 'asarray_numpy', 'bool', 'acos',
'acosh', 'asin', 'asinh', 'atan', 'atan2',
'atanh', 'bitwise_left_shift', 'bitwise_invert',
'bitwise_right_shift', 'concat', 'pow']

cross = get_xp(np)(_linalg.cross)
outer = get_xp(np)(_linalg.outer)
EighResult = _linalg.EighResult
QRResult = _linalg.QRResult
SlogdetResult = _linalg.SlogdetResult
SVDResult = _linalg.SVDResult
eigh = get_xp(np)(_linalg.eigh)
qr = get_xp(np)(_linalg.qr)
slogdet = get_xp(np)(_linalg.slogdet)
svd = get_xp(np)(_linalg.svd)
cholesky = get_xp(np)(_linalg.cholesky)
matrix_rank = get_xp(np)(_linalg.matrix_rank)
pinv = get_xp(np)(_linalg.pinv)
matrix_norm = get_xp(np)(_linalg.matrix_norm)
svdvals = get_xp(np)(_linalg.svdvals)
diagonal = get_xp(np)(_linalg.diagonal)
trace = get_xp(np)(_linalg.trace)

# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
if hasattr(np.linalg, "vector_norm"):
vector_norm = np.linalg.vector_norm
else:
vector_norm = get_xp(np)(_linalg.vector_norm)
_all_ignore = ['np', 'get_xp']
2 changes: 1 addition & 1 deletion array_api_compat/numpy/_typing.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from __future__ import annotations

__all__ = [
"ndarray",
"Device",
"Dtype",
"ndarray",
]

import sys
105 changes: 42 additions & 63 deletions array_api_compat/numpy/linalg.py
Original file line number Diff line number Diff line change
@@ -1,63 +1,42 @@
import numpy as _np

from .._internal import _get_all_public_members

_numpy_linalg_all = _get_all_public_members(_np.linalg)

for _name in _numpy_linalg_all:
globals()[_name] = getattr(_np.linalg, _name)


from ._aliases import ( # noqa: E402
EighResult,
QRResult,
SlogdetResult,
SVDResult,
cholesky,
cross,
diagonal,
eigh,
matmul,
matrix_norm,
matrix_rank,
matrix_transpose,
outer,
pinv,
qr,
slogdet,
svd,
svdvals,
tensordot,
trace,
vecdot,
vector_norm,
)

__all__ = []

__all__ += _numpy_linalg_all

__all__ += [
"EighResult",
"QRResult",
"SVDResult",
"SlogdetResult",
"cholesky",
"cross",
"diagonal",
"eigh",
"matmul",
"matrix_norm",
"matrix_rank",
"matrix_transpose",
"outer",
"pinv",
"qr",
"slogdet",
"svd",
"svdvals",
"tensordot",
"trace",
"vecdot",
"vector_norm",
]
from numpy.linalg import * # noqa: F403
from numpy.linalg import __all__ as linalg_all

from ..common import _linalg
from .._internal import get_xp

# These functions are in both the main and linalg namespaces
from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401

import numpy as np

cross = get_xp(np)(_linalg.cross)
outer = get_xp(np)(_linalg.outer)
EighResult = _linalg.EighResult
QRResult = _linalg.QRResult
SlogdetResult = _linalg.SlogdetResult
SVDResult = _linalg.SVDResult
eigh = get_xp(np)(_linalg.eigh)
qr = get_xp(np)(_linalg.qr)
slogdet = get_xp(np)(_linalg.slogdet)
svd = get_xp(np)(_linalg.svd)
cholesky = get_xp(np)(_linalg.cholesky)
matrix_rank = get_xp(np)(_linalg.matrix_rank)
pinv = get_xp(np)(_linalg.pinv)
matrix_norm = get_xp(np)(_linalg.matrix_norm)
svdvals = get_xp(np)(_linalg.svdvals)
diagonal = get_xp(np)(_linalg.diagonal)
trace = get_xp(np)(_linalg.trace)

# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
if hasattr(np.linalg, 'vector_norm'):
vector_norm = np.linalg.vector_norm
else:
vector_norm = get_xp(np)(_linalg.vector_norm)

__all__ = linalg_all + _linalg.__all__

del get_xp
del np
del linalg_all
del _linalg
199 changes: 16 additions & 183 deletions array_api_compat/torch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,189 +1,22 @@
# Several names are not included in the above import *
import torch as _torch
from torch import * # noqa: F401, F403

from .._internal import _get_all_public_members


def exlcude(name):
if (
name.startswith("_")
or name.endswith("_")
or "cuda" in name
or "cpu" in name
or "backward" in name
):
return True
return False


_torch_all = _get_all_public_members(_torch, exclude=exlcude, extend_all=True)

for _name in _torch_all:
globals()[_name] = getattr(_torch, _name)

from torch import * # noqa: F403

from ..common._helpers import ( # noqa: E402
array_namespace,
device,
get_namespace,
is_array_api_obj,
size,
to_device,
)
# Several names are not included in the above import *
import torch
for n in dir(torch):
if (n.startswith('_')
or n.endswith('_')
or 'cuda' in n
or 'cpu' in n
or 'backward' in n):
continue
exec(n + ' = torch.' + n)

# These imports may overwrite names from the import * above.
from ._aliases import ( # noqa: E402
add,
all,
any,
arange,
astype,
atan2,
bitwise_and,
bitwise_invert,
bitwise_left_shift,
bitwise_or,
bitwise_right_shift,
bitwise_xor,
broadcast_arrays,
broadcast_to,
can_cast,
concat,
divide,
empty,
equal,
expand_dims,
eye,
flip,
floor_divide,
full,
greater,
greater_equal,
isdtype,
less,
less_equal,
linspace,
logaddexp,
matmul,
matrix_transpose,
max,
mean,
min,
multiply,
newaxis,
nonzero,
not_equal,
ones,
permute_dims,
pow,
prod,
remainder,
reshape,
result_type,
roll,
sort,
squeeze,
std,
subtract,
sum,
take,
tensordot,
tril,
triu,
unique_all,
unique_counts,
unique_inverse,
unique_values,
var,
vecdot,
where,
zeros,
)

__all__ = []

__all__ += _torch_all

__all__ += [
"array_namespace",
"device",
"get_namespace",
"is_array_api_obj",
"size",
"to_device",
]

__all__ += [
"add",
"all",
"any",
"arange",
"astype",
"atan2",
"bitwise_and",
"bitwise_invert",
"bitwise_left_shift",
"bitwise_or",
"bitwise_right_shift",
"bitwise_xor",
"broadcast_arrays",
"broadcast_to",
"can_cast",
"concat",
"divide",
"empty",
"equal",
"expand_dims",
"eye",
"flip",
"floor_divide",
"full",
"greater",
"greater_equal",
"isdtype",
"less",
"less_equal",
"linspace",
"logaddexp",
"matmul",
"matrix_transpose",
"max",
"mean",
"min",
"multiply",
"newaxis",
"nonzero",
"not_equal",
"ones",
"permute_dims",
"pow",
"prod",
"remainder",
"reshape",
"result_type",
"roll",
"sort",
"squeeze",
"std",
"subtract",
"sum",
"take",
"tensordot",
"tril",
"triu",
"unique_all",
"unique_counts",
"unique_inverse",
"unique_values",
"var",
"vecdot",
"where",
"zeros",
]

from ._aliases import * # noqa: F403

# See the comment in the numpy __init__.py
__import__(__package__ + ".linalg")
__import__(__package__ + '.linalg')

from ..common._helpers import * # noqa: F403

__array_api_version__ = "2022.12"
__array_api_version__ = '2022.12'
86 changes: 32 additions & 54 deletions array_api_compat/torch/_aliases.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,19 @@
from __future__ import annotations

from builtins import all as builtin_all
from builtins import any as builtin_any
from functools import wraps
from typing import TYPE_CHECKING

import torch
from functools import wraps as _wraps
from builtins import all as _builtin_all, any as _builtin_any

from ..common._aliases import (matrix_transpose as _aliases_matrix_transpose,
vecdot as _aliases_vecdot)
from .._internal import get_xp
from ..common._aliases import UniqueAllResult, UniqueCountsResult, UniqueInverseResult
from ..common._aliases import matrix_transpose as _aliases_matrix_transpose
from ..common._aliases import vecdot as _aliases_vecdot

import torch

from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import List, Optional, Sequence, Tuple, Union

from torch import dtype as Dtype

from ..common._typing import Device
from torch import dtype as Dtype

array = torch.Tensor

@@ -88,7 +84,7 @@


def _two_arg(f):
@wraps(f)
@_wraps(f)
def _f(x1, x2, /, **kwargs):
x1, x2 = _fix_promotion(x1, x2)
return f(x1, x2, **kwargs)
@@ -511,7 +507,7 @@ def arange(start: Union[int, float],
start, stop = 0, start
if step > 0 and stop <= start or step < 0 and stop >= start:
if dtype is None:
if builtin_all(isinstance(i, int) for i in [start, stop, step]):
if _builtin_all(isinstance(i, int) for i in [start, stop, step]):
dtype = torch.int64
else:
dtype = torch.float32
@@ -603,6 +599,11 @@ def broadcast_arrays(*arrays: array) -> List[array]:
shape = torch.broadcast_shapes(*[a.shape for a in arrays])
return [torch.broadcast_to(a, shape) for a in arrays]

# Note that these named tuples aren't actually part of the standard namespace,
# but I don't see any issue with exporting the names here regardless.
from ..common._aliases import (UniqueAllResult, UniqueCountsResult,
UniqueInverseResult)

# https://github.com/pytorch/pytorch/issues/70920
def unique_all(x: array) -> UniqueAllResult:
# torch.unique doesn't support returning indices.
@@ -667,7 +668,7 @@ def isdtype(
for more details
"""
if isinstance(kind, tuple) and _tuple:
return builtin_any(isdtype(dtype, k, _tuple=False) for k in kind)
return _builtin_any(isdtype(dtype, k, _tuple=False) for k in kind)
elif isinstance(kind, str):
if kind == 'bool':
return dtype == torch.bool
@@ -695,42 +696,19 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
axis = 0
return torch.index_select(x, axis, indices, **kwargs)



# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
return torch.linalg.cross(x1, x2, dim=axis)

def vecdot_linalg(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
from ._aliases import isdtype

x1, x2 = _fix_promotion(x1, x2, only_scalar=False)

# torch.linalg.vecdot doesn't support integer dtypes
if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'):
if kwargs:
raise RuntimeError("vecdot kwargs not supported for integral dtypes")
ndim = max(x1.ndim, x2.ndim)
x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
if x1_shape[axis] != x2_shape[axis]:
raise ValueError("x1 and x2 must have the same size along the given axis")

x1_, x2_ = torch.broadcast_tensors(x1, x2)
x1_ = torch.moveaxis(x1_, axis, -1)
x2_ = torch.moveaxis(x2_, axis, -1)

res = x1_[..., None, :] @ x2_[..., None]
return res[..., 0, 0]
return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs)

def solve(x1: array, x2: array, /, **kwargs) -> array:
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
return torch.linalg.solve(x1, x2, **kwargs)

# torch.trace doesn't support the offset argument and doesn't support stacking
def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array:
# Use our wrapped sum to make sure it does upcasting correctly
return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)
__all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert',
'newaxis', 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift',
'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'divide',
'equal', 'floor_divide', 'greater', 'greater_equal', 'less',
'less_equal', 'logaddexp', 'multiply', 'not_equal', 'pow',
'remainder', 'subtract', 'max', 'min', 'sort', 'prod', 'sum',
'any', 'all', 'mean', 'std', 'var', 'concat', 'squeeze',
'broadcast_to', 'flip', 'roll', 'nonzero', 'where', 'reshape',
'arange', 'eye', 'linspace', 'full', 'ones', 'zeros', 'empty',
'tril', 'triu', 'expand_dims', 'astype', 'broadcast_arrays',
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype',
'take']

_all_ignore = ['torch', 'get_xp']
91 changes: 62 additions & 29 deletions array_api_compat/torch/linalg.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,67 @@
import torch as _torch
from __future__ import annotations

from .._internal import _get_all_public_members
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import torch
array = torch.Tensor
from torch import dtype as Dtype
from typing import Optional

_torch_linalg_all = _get_all_public_members(_torch.linalg)
from ._aliases import _fix_promotion, sum

for _name in _torch_linalg_all:
globals()[_name] = getattr(_torch.linalg, _name)
from torch.linalg import * # noqa: F403

# torch.linalg doesn't define __all__
# from torch.linalg import __all__ as linalg_all
from torch import linalg as torch_linalg
linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')]

# outer is implemented in torch but aren't in the linalg namespace
outer = _torch.outer

from ._aliases import ( # noqa: E402
matrix_transpose,
solve,
sum,
tensordot,
trace,
vecdot_linalg as vecdot,
)

__all__ = []

__all__ += _torch_linalg_all

__all__ += [
"matrix_transpose",
"outer",
"solve",
"sum",
"tensordot",
"trace",
"vecdot",
]
from torch import outer
# These functions are in both the main and linalg namespaces
from ._aliases import matmul, matrix_transpose, tensordot

# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
return torch_linalg.cross(x1, x2, dim=axis)

def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
from ._aliases import isdtype

x1, x2 = _fix_promotion(x1, x2, only_scalar=False)

# torch.linalg.vecdot doesn't support integer dtypes
if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'):
if kwargs:
raise RuntimeError("vecdot kwargs not supported for integral dtypes")
ndim = max(x1.ndim, x2.ndim)
x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
if x1_shape[axis] != x2_shape[axis]:
raise ValueError("x1 and x2 must have the same size along the given axis")

x1_, x2_ = torch.broadcast_tensors(x1, x2)
x1_ = torch.moveaxis(x1_, axis, -1)
x2_ = torch.moveaxis(x2_, axis, -1)

res = x1_[..., None, :] @ x2_[..., None]
return res[..., 0, 0]
return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs)

def solve(x1: array, x2: array, /, **kwargs) -> array:
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
return torch.linalg.solve(x1, x2, **kwargs)

# torch.trace doesn't support the offset argument and doesn't support stacking
def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array:
# Use our wrapped sum to make sure it does upcasting correctly
return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)

__all__ = linalg_all + ['outer', 'trace', 'matmul', 'matrix_transpose', 'tensordot',
'vecdot', 'solve']

_all_ignore = ['torch_linalg', 'sum']

del linalg_all
13 changes: 13 additions & 0 deletions ruff.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
[lint]
preview = true
select = [
# Defaults
"E4", "E7", "E9", "F",
# Undefined export
"F822",
# Useless import alias
"PLC0414"
]

# Ignore module import not at top of file
ignore = ["E402"]
2 changes: 1 addition & 1 deletion tests/_helpers.py
Original file line number Diff line number Diff line change
@@ -8,7 +8,7 @@
def import_(library, wrapper=False):
if library == 'cupy':
return pytest.importorskip(library)
if 'jax' in library and sys.version_info <= (3, 8):
if 'jax' in library and sys.version_info < (3, 9):
pytest.skip('JAX array API support does not support Python 3.8')

if wrapper:
42 changes: 42 additions & 0 deletions tests/test_all.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""
Test that files that define __all__ aren't missing any exports.
You can add names that shouldn't be exported to _all_ignore, like
_all_ignore = ['sys']
This is preferable to del-ing the names as this will break any name that is
used inside of a function. Note that names starting with an underscore are automatically ignored.
"""


import sys

from ._helpers import import_

import pytest

@pytest.mark.parametrize("library", ["common", "cupy", "numpy", "torch", "dask.array"])
def test_all(library):
import_(library, wrapper=True)

for mod_name in sys.modules:
if 'array_api_compat.' + library not in mod_name:
continue

module = sys.modules[mod_name]

# TODO: We should define __all__ in the __init__.py files and test it
# there too.
if not hasattr(module, '__all__'):
continue

dir_names = [n for n in dir(module) if not n.startswith('_')]
ignore_all_names = getattr(module, '_all_ignore', [])
ignore_all_names += ['annotations', 'TYPE_CHECKING']
dir_names = set(dir_names) - set(ignore_all_names)
all_names = module.__all__

if set(dir_names) != set(all_names):
assert set(dir_names) - set(all_names) == set(), f"Some dir() names not included in __all__ for {mod_name}"
assert set(all_names) - set(dir_names) == set(), f"Some __all__ names not in dir() for {mod_name}"