Add ruff to ci setup #82
New CI workflow (GitHub Actions):

@@ -0,0 +1,19 @@
+name: CI
+on: [push, pull_request]
+jobs:
+  check-ruff:
+    runs-on: ubuntu-latest
+    continue-on-error: true
+    steps:
+      - uses: actions/checkout@v3
+      - name: Install Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: "3.11"
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install ruff
+      # Update output format to enable automatic inline annotations.
+      - name: Run Ruff
+        run: ruff check --output-format=github --select F822,PLC0414,RUF022 --preview .
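For context (not part of the diff): F822 flags names listed in __all__ that the module never defines or imports, PLC0414 flags import aliases that do not rename the import, and RUF022 flags unsorted __all__ lists (a preview rule, hence --preview). A hypothetical module that trips all three:

# demo.py -- hypothetical; each marked line triggers one selected rule
import numpy as numpy  # PLC0414: alias does not rename the import

__all__ = [
    "b",
    "a",        # RUF022: __all__ is not sorted
    "missing",  # F822: listed in __all__ but never defined
]

a = 1
b = 2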
Top-level __init__.py:

@@ -19,4 +19,4 @@
 """
 __version__ = '1.4.1'

-from .common import *
+from .common import *  # noqa: F401, F403
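(F401 is the unused-import rule and F403 the star-import rule; the noqa marks this deliberate re-export so ruff does not flag it.)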
common/__init__.py:

@@ -1 +1,17 @@
-from ._helpers import *
+from ._helpers import (
+    array_namespace,
+    device,
+    get_namespace,
+    is_array_api_obj,
+    size,
+    to_device,
+)
+
+__all__ = [
+    "array_namespace",
+    "device",
+    "get_namespace",
+    "is_array_api_obj",
+    "size",
+    "to_device",
+]
cupy/__init__.py:

@@ -1,16 +1,153 @@
-from cupy import *
+import cupy as _cp
+from cupy import *  # noqa: F401, F403

 # from cupy import * doesn't overwrite these builtin names
 from cupy import abs, max, min, round

+from .._internal import _get_all_public_members
+from ..common._helpers import (
+    array_namespace,
+    device,
+    get_namespace,
+    is_array_api_obj,
+    size,
+    to_device,
+)
+
 # These imports may overwrite names from the import * above.
-from ._aliases import *
+from ._aliases import (
+    UniqueAllResult,
+    UniqueCountsResult,
+    UniqueInverseResult,
+    acos,
+    acosh,
+    arange,
+    argsort,
+    asarray,
+    asarray_cupy,
+    asin,
+    asinh,
+    astype,
+    atan,
+    atan2,
+    atanh,
+    bitwise_invert,
+    bitwise_left_shift,
+    bitwise_right_shift,
+    bool,
+    ceil,
+    concat,
+    empty,
+    empty_like,
+    eye,
+    floor,
+    full,
+    full_like,
+    isdtype,
+    linspace,
+    matmul,
+    matrix_transpose,
+    nonzero,
+    ones,
+    ones_like,
+    permute_dims,
+    pow,
+    prod,
+    reshape,
+    sort,
+    std,
+    sum,
+    tensordot,
+    trunc,
+    unique_all,
+    unique_counts,
+    unique_inverse,
+    unique_values,
+    var,
+    vecdot,
+    zeros,
+    zeros_like,
+)

-# See the comment in the numpy __init__.py
-__import__(__package__ + '.linalg')
+__all__ = []

+__all__ += _get_all_public_members(_cp)
+
+__all__ += [
+    "abs",
+    "max",
+    "min",
+    "round",
+]

-from .linalg import matrix_transpose, vecdot
+__all__ += [
+    "array_namespace",
+    "device",
+    "get_namespace",
+    "is_array_api_obj",
+    "size",
+    "to_device",
+]

-from ..common._helpers import *
+__all__ += [
+    "UniqueAllResult",
+    "UniqueCountsResult",
+    "UniqueInverseResult",
+    "acos",
+    "acosh",
+    "arange",
+    "argsort",
+    "asarray",
+    "asarray_cupy",
+    "asin",
+    "asinh",
+    "astype",
+    "atan",
+    "atan2",
+    "atanh",
+    "bitwise_invert",
+    "bitwise_left_shift",
+    "bitwise_right_shift",
+    "bool",
+    "ceil",
+    "concat",
+    "empty",
+    "empty_like",
+    "eye",
+    "floor",
+    "full",
+    "full_like",
+    "isdtype",
+    "linspace",
+    "matmul",
+    "matrix_transpose",
+    "nonzero",
+    "ones",
+    "ones_like",
+    "permute_dims",
+    "pow",
+    "prod",
+    "reshape",
+    "sort",
+    "std",
+    "sum",
+    "tensordot",
+    "trunc",
+    "unique_all",
+    "unique_counts",
+    "unique_inverse",
+    "unique_values",
+    "var",
+    "vecdot",
+    "zeros",
+    "zeros_like",
+]
+
+__all__ += [
+    "matrix_transpose",
+    "vecdot",
+]
+
+# See the comment in the numpy __init__.py
+__import__(__package__ + ".linalg")

-__array_api_version__ = '2022.12'
+__array_api_version__ = "2022.12"
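A quick usage sketch of what the wrapped namespace provides (hypothetical session; requires CuPy and a CUDA device):

from array_api_compat import cupy as xp  # the wrapper module above

x = xp.asarray([3, 1, 2])
print(xp.sort(x))                # array-API spelling of the sort function
print(xp.__array_api_version__)  # "2022.12"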
cupy/_typing.py:

@@ -1,9 +1,9 @@
 from __future__ import annotations

 __all__ = [
-    "ndarray",
     "Device",
     "Dtype",
+    "ndarray",
 ]

 import sys
cupy/linalg.py:

@@ -1,47 +1,62 @@
-from cupy.linalg import *
-# cupy.linalg doesn't have __all__. If it is added, replace this with
-#
-# from cupy.linalg import __all__ as linalg_all
-_n = {}
-exec('from cupy.linalg import *', _n)
-del _n['__builtins__']
-linalg_all = list(_n)
-del _n
+import cupy as _cp

-from ..common import _linalg
-from .._internal import get_xp
-from ._aliases import (matmul, matrix_transpose, tensordot, vecdot)
+from .._internal import _get_all_public_members

-import cupy as cp
+_cupy_linalg_all = _get_all_public_members(_cp.linalg)

-cross = get_xp(cp)(_linalg.cross)
-outer = get_xp(cp)(_linalg.outer)
-EighResult = _linalg.EighResult
-QRResult = _linalg.QRResult
-SlogdetResult = _linalg.SlogdetResult
-SVDResult = _linalg.SVDResult
-eigh = get_xp(cp)(_linalg.eigh)
-qr = get_xp(cp)(_linalg.qr)
-slogdet = get_xp(cp)(_linalg.slogdet)
-svd = get_xp(cp)(_linalg.svd)
-cholesky = get_xp(cp)(_linalg.cholesky)
-matrix_rank = get_xp(cp)(_linalg.matrix_rank)
-pinv = get_xp(cp)(_linalg.pinv)
-matrix_norm = get_xp(cp)(_linalg.matrix_norm)
-svdvals = get_xp(cp)(_linalg.svdvals)
-diagonal = get_xp(cp)(_linalg.diagonal)
-trace = get_xp(cp)(_linalg.trace)
+for _name in _cupy_linalg_all:
+    globals()[_name] = getattr(_cp.linalg, _name)

-# These functions are completely new here. If the library already has them
-# (i.e., numpy 2.0), use the library version instead of our wrapper.
-if hasattr(cp.linalg, 'vector_norm'):
-    vector_norm = cp.linalg.vector_norm
-else:
-    vector_norm = get_xp(cp)(_linalg.vector_norm)
+from ._aliases import (  # noqa: E402
+    EighResult,
+    QRResult,
+    SlogdetResult,
+    SVDResult,
+    cholesky,
+    cross,
+    diagonal,
+    eigh,
+    matmul,
+    matrix_norm,
+    matrix_rank,
+    matrix_transpose,
+    outer,
+    pinv,
+    qr,
+    slogdet,
+    svd,
+    svdvals,
+    tensordot,
+    trace,
+    vecdot,
+    vector_norm,
+)

-__all__ = linalg_all + _linalg.__all__
+__all__ = []

-del get_xp
-del cp
-del linalg_all
-del _linalg
+__all__ += _cupy_linalg_all
+
+__all__ += [
+    "EighResult",
+    "QRResult",
+    "SVDResult",
+    "SlogdetResult",
+    "cholesky",
+    "cross",
+    "diagonal",
+    "eigh",
+    "matmul",
+    "matrix_norm",
+    "matrix_rank",
+    "matrix_transpose",
+    "outer",
+    "pinv",
+    "qr",
+    "slogdet",
+    "svd",
+    "svdvals",
+    "tensordot",
+    "trace",
+    "vecdot",
+    "vector_norm",
+]
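The globals() loop is the lint-friendly replacement for the old exec-based star import; a standalone illustration of the pattern (hypothetical module, using math as the wrapped library):

# Hypothetical illustration: re-export every public member of a module
# without `import *`, so ruff can still verify __all__ (F822).
import math as _math

_math_all = [_name for _name in dir(_math) if not _name.startswith("_")]

for _name in _math_all:
    globals()[_name] = getattr(_math, _name)

__all__ = list(_math_all)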
dask/array/__init__.py:

@@ -1,8 +1,210 @@
-from dask.array import *
+import dask.array as _da
+from dask.array import *  # noqa: F401, F403
+from dask.array import (
+    # Element wise aliases
+    arccos as acos,
+)
+from dask.array import (
+    arccosh as acosh,
+)
+from dask.array import (
+    arcsin as asin,
+)
+from dask.array import (
+    arcsinh as asinh,
+)
+from dask.array import (
+    arctan as atan,
+)
+from dask.array import (
+    arctan2 as atan2,
+)
+from dask.array import (
+    arctanh as atanh,
+)
+from dask.array import (
+    bool_ as bool,
+)
+from dask.array import (
+    # Other
+    concatenate as concat,
+)
+from dask.array import (
+    invert as bitwise_invert,
+)
+from dask.array import (
+    left_shift as bitwise_left_shift,
+)
+from dask.array import (
+    power as pow,
+)
+from dask.array import (
+    right_shift as bitwise_right_shift,
+)

-# These imports may overwrite names from the import * above.
-from ._aliases import *
+from numpy import (
+    can_cast,
+    complex64,
+    complex128,
+    e,
+    finfo,
+    float32,
+    float64,
+    iinfo,
+    inf,
+    int8,
+    int16,
+    int32,
+    int64,
+    nan,
+    newaxis,
+    pi,
+    result_type,
+    uint8,
+    uint16,
+    uint32,
+    uint64,
+)

-__array_api_version__ = '2022.12'
+from ..common._helpers import (
+    array_namespace,
+    device,
+    get_namespace,
+    is_array_api_obj,
+    size,
+    to_device,
+)
+from .._internal import _get_all_public_members
+from ._aliases import (
+    UniqueAllResult,
+    UniqueCountsResult,
+    UniqueInverseResult,
+    arange,
+    asarray,
+    astype,
+    ceil,
+    empty,
+    empty_like,
+    eye,
+    floor,
+    full,
+    full_like,
+    isdtype,
+    linspace,
+    matmul,
+    matrix_transpose,
+    nonzero,
+    ones,
+    ones_like,
+    permute_dims,
+    prod,
+    reshape,
+    std,
+    sum,
+    tensordot,
+    trunc,
+    unique_all,
+    unique_counts,
+    unique_inverse,
+    unique_values,
+    var,
+    vecdot,
+    zeros,
+    zeros_like,
+)

-__import__(__package__ + '.linalg')
+__all__ = []

+__all__ += _get_all_public_members(_da)
+
+__all__ += [
+    "can_cast",
+    "complex64",
+    "complex128",
+    "e",
+    "finfo",
+    "float32",
+    "float64",
+    "iinfo",
+    "inf",
+    "int8",
+    "int16",
+    "int32",
+    "int64",
+    "nan",
+    "newaxis",
+    "pi",
+    "result_type",
+    "uint8",
+    "uint16",
+    "uint32",
+    "uint64",
+]
+
+__all__ += [
+    "array_namespace",
+    "device",
+    "get_namespace",
+    "is_array_api_obj",
+    "size",
+    "to_device",
+]
+
+# 'sort', 'argsort' are unsupported by dask.array
+
+__all__ += [
+    "UniqueAllResult",
+    "UniqueCountsResult",
+    "UniqueInverseResult",
+    "acos",
+    "acosh",
+    "arange",
+    "asarray",
+    "asin",
+    "asinh",
+    "astype",
+    "atan",
+    "atan2",
+    "atanh",
+    "bitwise_invert",
+    "bitwise_left_shift",
+    "bitwise_right_shift",
+    "bool",
+    "ceil",
+    "concat",
+    "empty",
+    "empty_like",
+    "eye",
+    "floor",
+    "full",
+    "full_like",
+    "isdtype",
+    "linspace",
+    "matmul",
+    "matrix_transpose",
+    "nonzero",
+    "ones",
+    "ones_like",
+    "permute_dims",
+    "pow",
+    "prod",
+    "reshape",
+    "std",
+    "sum",
+    "tensordot",
+    "trunc",
+    "unique_all",
+    "unique_counts",
+    "unique_inverse",
+    "unique_values",
+    "var",
+    "vecdot",
+    "zeros",
+    "zeros_like",
+]
+
+__array_api_version__ = "2022.12"
+
+__import__(__package__ + ".linalg")
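The scalars, dtypes, and casting helpers above come straight from numpy because dask arrays carry numpy dtypes; a quick illustration (assuming dask is installed):

import dask.array as da
import numpy as np

x = da.zeros((4,), dtype=np.float32)
print(x.dtype == np.float32)              # True -- dask arrays use numpy dtypes
print(np.result_type(np.int8, np.uint8))  # int16 -- promotion follows numpy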
dask/array/linalg.py:

@@ -1,48 +1,50 @@
-from __future__ import annotations
-
-from dask.array.linalg import *
-from ...common import _linalg
-from ..._internal import get_xp
-from dask.array import matmul, tensordot, trace, outer
-from ._aliases import matrix_transpose, vecdot
-
-import dask.array as da
-
-from typing import TYPE_CHECKING
-if TYPE_CHECKING:
-    from typing import Union, Tuple
-    from ...common._typing import ndarray
-
-# cupy.linalg doesn't have __all__. If it is added, replace this with
-#
-# from cupy.linalg import __all__ as linalg_all
-_n = {}
-exec('from dask.array.linalg import *', _n)
-del _n['__builtins__']
-linalg_all = list(_n)
-del _n
-
-EighResult = _linalg.EighResult
-QRResult = _linalg.QRResult
-SlogdetResult = _linalg.SlogdetResult
-SVDResult = _linalg.SVDResult
-qr = get_xp(da)(_linalg.qr)
-cholesky = get_xp(da)(_linalg.cholesky)
-matrix_rank = get_xp(da)(_linalg.matrix_rank)
-matrix_norm = get_xp(da)(_linalg.matrix_norm)
-
-def svdvals(x: ndarray) -> Union[ndarray, Tuple[ndarray, ...]]:
-    # TODO: can't avoid computing U or V for dask
-    _, s, _ = svd(x)
-    return s
-
-vector_norm = get_xp(da)(_linalg.vector_norm)
-diagonal = get_xp(da)(_linalg.diagonal)
-
-__all__ = linalg_all + ["EighResult", "QRResult", "SlogdetResult",
-                        "SVDResult", "qr", "cholesky", "matrix_rank", "matrix_norm",
-                        "svdvals", "vector_norm", "diagonal"]
-
-del get_xp
-del da
-del _linalg
+import dask.array as _da
+from dask.array import (
+    matmul,
+    outer,
+    tensordot,
+    trace,
+)
+from dask.array.linalg import *  # noqa: F401, F403
+
+from .._internal import _get_all_public_members
+from ._aliases import (
+    EighResult,
+    QRResult,
+    SlogdetResult,
+    SVDResult,
+    cholesky,
+    diagonal,
+    matrix_norm,
+    matrix_rank,
+    matrix_transpose,
+    qr,
+    svdvals,
+    vecdot,
+    vector_norm,
+)
+
+__all__ = [
+    "matmul",
+    "outer",
+    "tensordot",
+    "trace",
+]
+
+__all__ += _get_all_public_members(_da.linalg)
+
+__all__ += [
+    "EighResult",
+    "QRResult",
+    "SVDResult",
+    "SlogdetResult",
+    "cholesky",
+    "diagonal",
+    "matrix_norm",
+    "matrix_rank",
+    "matrix_transpose",
+    "qr",
+    "svdvals",
+    "vecdot",
+    "vector_norm",
+]
numpy/_typing.py:

@@ -1,9 +1,9 @@
 from __future__ import annotations

 __all__ = [
-    "ndarray",
     "Device",
     "Dtype",
+    "ndarray",
 ]

 import sys
numpy/linalg.py:

@@ -1,40 +1,63 @@
-from numpy.linalg import *
-from numpy.linalg import __all__ as linalg_all
-
-from ..common import _linalg
-from .._internal import get_xp
-from ._aliases import (matmul, matrix_transpose, tensordot, vecdot)
-
-import numpy as np
-
-cross = get_xp(np)(_linalg.cross)
-outer = get_xp(np)(_linalg.outer)
-EighResult = _linalg.EighResult
-QRResult = _linalg.QRResult
-SlogdetResult = _linalg.SlogdetResult
-SVDResult = _linalg.SVDResult
-eigh = get_xp(np)(_linalg.eigh)
-qr = get_xp(np)(_linalg.qr)
-slogdet = get_xp(np)(_linalg.slogdet)
-svd = get_xp(np)(_linalg.svd)
-cholesky = get_xp(np)(_linalg.cholesky)
-matrix_rank = get_xp(np)(_linalg.matrix_rank)
-pinv = get_xp(np)(_linalg.pinv)
-matrix_norm = get_xp(np)(_linalg.matrix_norm)
-svdvals = get_xp(np)(_linalg.svdvals)
-diagonal = get_xp(np)(_linalg.diagonal)
-trace = get_xp(np)(_linalg.trace)
-
-# These functions are completely new here. If the library already has them
-# (i.e., numpy 2.0), use the library version instead of our wrapper.
-if hasattr(np.linalg, 'vector_norm'):
-    vector_norm = np.linalg.vector_norm
-else:
-    vector_norm = get_xp(np)(_linalg.vector_norm)
-
-__all__ = linalg_all + _linalg.__all__
-
-del get_xp
-del np
-del linalg_all
-del _linalg
+import numpy as _np
+
+from .._internal import _get_all_public_members
+
+_numpy_linalg_all = _get_all_public_members(_np.linalg)
+
+for _name in _numpy_linalg_all:
+    globals()[_name] = getattr(_np.linalg, _name)
+
+
+from ._aliases import (  # noqa: E402
+    EighResult,
+    QRResult,
+    SlogdetResult,
+    SVDResult,
+    cholesky,
+    cross,
+    diagonal,
+    eigh,
+    matmul,
+    matrix_norm,
+    matrix_rank,
+    matrix_transpose,
+    outer,
+    pinv,
+    qr,
+    slogdet,
+    svd,
+    svdvals,
+    tensordot,
+    trace,
+    vecdot,
+    vector_norm,
+)
+
+__all__ = []
+
+__all__ += _numpy_linalg_all
+
+__all__ += [
+    "EighResult",
+    "QRResult",
+    "SVDResult",
+    "SlogdetResult",
+    "cholesky",
+    "cross",
+    "diagonal",
+    "eigh",
+    "matmul",
+    "matrix_norm",
+    "matrix_rank",
+    "matrix_transpose",
+    "outer",
+    "pinv",
+    "qr",
+    "slogdet",
+    "svd",
+    "svdvals",
+    "tensordot",
+    "trace",
+    "vecdot",
+    "vector_norm",
+]
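(Ruff's E402 rule flags module-level imports that do not sit at the top of the file; the noqa is needed because the ._aliases import must come after the globals() loop, so that the wrapped functions overwrite the plain numpy.linalg ones.)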
torch/__init__.py:

@@ -1,22 +1,189 @@
-from torch import *
-
-# Several names are not included in the above import *
-import torch
-for n in dir(torch):
-    if (n.startswith('_')
-        or n.endswith('_')
-        or 'cuda' in n
-        or 'cpu' in n
-        or 'backward' in n):
-        continue
-    exec(n + ' = torch.' + n)
+import torch as _torch
+from torch import *  # noqa: F401, F403
+
+from .._internal import _get_all_public_members
+
+
+def exclude(name):
+    if (
+        name.startswith("_")
+        or name.endswith("_")
+        or "cuda" in name
+        or "cpu" in name
+        or "backward" in name
+    ):
+        return True
+    return False
+
+
+_torch_all = _get_all_public_members(_torch, exclude=exclude, extend_all=True)
+
+for _name in _torch_all:
+    globals()[_name] = getattr(_torch, _name)
+
+
+from ..common._helpers import (  # noqa: E402
+    array_namespace,
+    device,
+    get_namespace,
+    is_array_api_obj,
+    size,
+    to_device,
+)

 # These imports may overwrite names from the import * above.
-from ._aliases import *
+from ._aliases import (  # noqa: E402
+    add,
+    all,
+    any,
+    arange,
+    astype,
+    atan2,
+    bitwise_and,
+    bitwise_invert,
+    bitwise_left_shift,
+    bitwise_or,
+    bitwise_right_shift,
+    bitwise_xor,
+    broadcast_arrays,
+    broadcast_to,
+    can_cast,
+    concat,
+    divide,
+    empty,
+    equal,
+    expand_dims,
+    eye,
+    flip,
+    floor_divide,
+    full,
+    greater,
+    greater_equal,
+    isdtype,
+    less,
+    less_equal,
+    linspace,
+    logaddexp,
+    matmul,
+    matrix_transpose,
+    max,
+    mean,
+    min,
+    multiply,
+    newaxis,
+    nonzero,
+    not_equal,
+    ones,
+    permute_dims,
+    pow,
+    prod,
+    remainder,
+    reshape,
+    result_type,
+    roll,
+    sort,
+    squeeze,
+    std,
+    subtract,
+    sum,
+    take,
+    tensordot,
+    tril,
+    triu,
+    unique_all,
+    unique_counts,
+    unique_inverse,
+    unique_values,
+    var,
+    vecdot,
+    where,
+    zeros,
+)

-# See the comment in the numpy __init__.py
-__import__(__package__ + '.linalg')
+__all__ = []

+__all__ += _torch_all
+
+__all__ += [
+    "array_namespace",
+    "device",
+    "get_namespace",
+    "is_array_api_obj",
+    "size",
+    "to_device",
+]

-from ..common._helpers import *
+__all__ += [
+    "add",
+    "all",
+    "any",
+    "arange",
+    "astype",
+    "atan2",
+    "bitwise_and",
+    "bitwise_invert",
+    "bitwise_left_shift",
+    "bitwise_or",
+    "bitwise_right_shift",
+    "bitwise_xor",
+    "broadcast_arrays",
+    "broadcast_to",
+    "can_cast",
+    "concat",
+    "divide",
+    "empty",
+    "equal",
+    "expand_dims",
+    "eye",
+    "flip",
+    "floor_divide",
+    "full",
+    "greater",
+    "greater_equal",
+    "isdtype",
+    "less",
+    "less_equal",
+    "linspace",
+    "logaddexp",
+    "matmul",
+    "matrix_transpose",
+    "max",
+    "mean",
+    "min",
+    "multiply",
+    "newaxis",
+    "nonzero",
+    "not_equal",
+    "ones",
+    "permute_dims",
+    "pow",
+    "prod",
+    "remainder",
+    "reshape",
+    "result_type",
+    "roll",
+    "sort",
+    "squeeze",
+    "std",
+    "subtract",
+    "sum",
+    "take",
+    "tensordot",
+    "tril",
+    "triu",
+    "unique_all",
+    "unique_counts",
+    "unique_inverse",
+    "unique_values",
+    "var",
+    "vecdot",
+    "where",
+    "zeros",
+]
+
+# See the comment in the numpy __init__.py
+__import__(__package__ + ".linalg")

-__array_api_version__ = '2022.12'
+__array_api_version__ = "2022.12"
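The exclude predicate drops private names and the device- and autograd-specific helpers before re-export; a quick way to see what it filters (assuming PyTorch is installed):

import torch

def exclude(name):
    return (
        name.startswith("_")
        or name.endswith("_")
        or "cuda" in name
        or "cpu" in name
        or "backward" in name
    )

kept = [n for n in dir(torch) if not exclude(n)]
print(len(kept), "names would be re-exported")
print("cuda" in kept, "float32" in kept)  # False True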
torch/_aliases.py:

@@ -1,22 +1,24 @@
 from __future__ import annotations

+from builtins import all as builtin_all
+from builtins import any as builtin_any
 from functools import wraps
-from builtins import all as builtin_all, any as builtin_any
+from typing import TYPE_CHECKING

-from ..common._aliases import (UniqueAllResult, UniqueCountsResult,
-                               UniqueInverseResult,
-                               matrix_transpose as _aliases_matrix_transpose,
-                               vecdot as _aliases_vecdot)
-from .._internal import get_xp
-
 import torch

-from typing import TYPE_CHECKING
+from .._internal import get_xp
+from ..common._aliases import UniqueAllResult, UniqueCountsResult, UniqueInverseResult
+from ..common._aliases import matrix_transpose as _aliases_matrix_transpose
+from ..common._aliases import vecdot as _aliases_vecdot
+
 if TYPE_CHECKING:
     from typing import List, Optional, Sequence, Tuple, Union
-    from ..common._typing import Device
+
     from torch import dtype as Dtype
+
+    from ..common._typing import Device

 array = torch.Tensor

 _int_dtypes = {

@@ -693,15 +695,42 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -> array:
         axis = 0
     return torch.index_select(x, axis, indices, **kwargs)

-__all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'newaxis',
-           'add', 'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or',
-           'bitwise_right_shift', 'bitwise_xor', 'divide', 'equal',
-           'floor_divide', 'greater', 'greater_equal', 'less', 'less_equal',
-           'logaddexp', 'multiply', 'not_equal', 'pow', 'remainder',
-           'subtract', 'max', 'min', 'sort', 'prod', 'sum', 'any', 'all',
-           'mean', 'std', 'var', 'concat', 'squeeze', 'broadcast_to', 'flip', 'roll',
-           'nonzero', 'where', 'reshape', 'arange', 'eye', 'linspace', 'full',
-           'ones', 'zeros', 'empty', 'tril', 'triu', 'expand_dims', 'astype',
-           'broadcast_arrays', 'unique_all', 'unique_counts',
-           'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose',
-           'vecdot', 'tensordot', 'isdtype', 'take']
+
+# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
+# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
+def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
+    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
+    return torch.linalg.cross(x1, x2, dim=axis)
+
+
+def vecdot_linalg(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
+    from ._aliases import isdtype
+
+    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
+
+    # torch.linalg.vecdot doesn't support integer dtypes
+    if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'):
+        if kwargs:
+            raise RuntimeError("vecdot kwargs not supported for integral dtypes")
+        ndim = max(x1.ndim, x2.ndim)
+        x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
+        x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
+        if x1_shape[axis] != x2_shape[axis]:
+            raise ValueError("x1 and x2 must have the same size along the given axis")
+
+        x1_, x2_ = torch.broadcast_tensors(x1, x2)
+        x1_ = torch.moveaxis(x1_, axis, -1)
+        x2_ = torch.moveaxis(x2_, axis, -1)
+
+        res = x1_[..., None, :] @ x2_[..., None]
+        return res[..., 0, 0]
+    return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs)
+
+
+def solve(x1: array, x2: array, /, **kwargs) -> array:
+    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
+    return torch.linalg.solve(x1, x2, **kwargs)
+
+
+# torch.trace doesn't support the offset argument and doesn't support stacking
+def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array:
+    # Use our wrapped sum to make sure it does upcasting correctly
+    return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)

Review thread on the functions moved here from linalg.py:

Why did you move these to this file? These were in linalg.py because they're linalg only functions.

Yes, I understand. It seemed conceptually simpler in the end to just use …
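The integer branch of vecdot_linalg emulates the dot product by broadcasting, moving the reduction axis last, and contracting with a batched matmul, since torch.linalg.vecdot rejects integer dtypes. A small sanity check of that path (hypothetical, using plain torch to mirror the wrapper's logic):

import torch

x1 = torch.tensor([[1, 2, 3], [4, 5, 6]])  # integer dtype, shape (2, 3)
x2 = torch.tensor([1, 0, 1])               # broadcasts to (2, 3)

# Mirror the fallback: broadcast, then contract the last axis via matmul.
x1_, x2_ = torch.broadcast_tensors(x1, x2)
res = (x1_[..., None, :] @ x2_[..., None])[..., 0, 0]
print(res)  # tensor([ 4, 10]) -- row-wise integer dot products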
torch/linalg.py:

@@ -1,62 +1,34 @@
-from __future__ import annotations
-
-from typing import TYPE_CHECKING
-if TYPE_CHECKING:
-    import torch
-    array = torch.Tensor
-    from torch import dtype as Dtype
-    from typing import Optional
-
-from torch.linalg import *
-
-# torch.linalg doesn't define __all__
-# from torch.linalg import __all__ as linalg_all
-from torch import linalg as torch_linalg
-linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')]
-
-# outer is implemented in torch but aren't in the linalg namespace
-from torch import outer
-from ._aliases import _fix_promotion, matrix_transpose, tensordot, sum
-
-# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
-# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
-def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
-    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
-    return torch_linalg.cross(x1, x2, dim=axis)
-
-def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
-    from ._aliases import isdtype
-
-    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
-
-    # torch.linalg.vecdot doesn't support integer dtypes
-    if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'):
-        if kwargs:
-            raise RuntimeError("vecdot kwargs not supported for integral dtypes")
-        ndim = max(x1.ndim, x2.ndim)
-        x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
-        x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
-        if x1_shape[axis] != x2_shape[axis]:
-            raise ValueError("x1 and x2 must have the same size along the given axis")
-
-        x1_, x2_ = torch.broadcast_tensors(x1, x2)
-        x1_ = torch.moveaxis(x1_, axis, -1)
-        x2_ = torch.moveaxis(x2_, axis, -1)
-
-        res = x1_[..., None, :] @ x2_[..., None]
-        return res[..., 0, 0]
-    return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs)
-
-def solve(x1: array, x2: array, /, **kwargs) -> array:
-    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
-    return torch.linalg.solve(x1, x2, **kwargs)
-
-# torch.trace doesn't support the offset argument and doesn't support stacking
-def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array:
-    # Use our wrapped sum to make sure it does upcasting correctly
-    return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)
-
-__all__ = linalg_all + ['outer', 'trace', 'matrix_transpose', 'tensordot',
-                        'vecdot', 'solve']
-
-del linalg_all
+import torch as _torch
+
+from .._internal import _get_all_public_members
+
+_torch_linalg_all = _get_all_public_members(_torch.linalg)
+
+for _name in _torch_linalg_all:
+    globals()[_name] = getattr(_torch.linalg, _name)
+
+outer = _torch.outer
+
+from ._aliases import (  # noqa: E402
+    matrix_transpose,
+    solve,
+    sum,
+    tensordot,
+    trace,
+    vecdot_linalg as vecdot,
+)
+
+__all__ = []
+
+__all__ += _torch_linalg_all
+
+__all__ += [
+    "matrix_transpose",
+    "outer",
+    "solve",
+    "sum",
+    "tensordot",
+    "trace",
+    "vecdot",
+]
tests/_helpers.py:

@@ -2,7 +2,8 @@

 import pytest

-def import_(library):
-    if 'cupy' in library:
+
+def import_or_skip_cupy(library):
+    if "cupy" in library:
         return pytest.importorskip(library)
     return import_module(library)

Review thread on the rename:

FYI, I'm probably going to rename this back, because I need to add some additional skipping logic for jax as well at #84

Sure. I think it would be good to keep the "or skip" in the name, because from looking at the places where it was used, it was not clear what it does.
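A hypothetical test using the renamed helper; cupy-backed cases are skipped on machines without CuPy instead of erroring:

import pytest

from ._helpers import import_or_skip_cupy  # test-suite-local helper

@pytest.mark.parametrize("library", ["numpy", "torch", "cupy"])
def test_has_asarray(library):
    xp = import_or_skip_cupy(f"array_api_compat.{library}")
    assert hasattr(xp, "asarray")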
Review thread on the __all__ handling:

This is monkeypatching torch.__all__ etc.? We don't want to do that.

Yes, but this just keeps the current behavior. Take a look at https://github.com/data-apis/array-api-compat/blob/main/array_api_compat/torch/__init__.py#L3
I have not checked whether this is still necessary, but probably we have to keep it this way?

I checked with the following code, and it shows that torch.__all__ is indeed missing multiple public members, most importantly the dtypes.
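(The snippet and its output were not captured in this export; a plausible reconstruction of the check, an assumption rather than the original code:)

import torch

public = sorted(name for name in dir(torch) if not name.startswith("_"))
listed = set(torch.__all__)
missing = [name for name in public if name not in listed]

print(len(missing))         # many public names are absent from torch.__all__
print("float32" in listed)  # False -- the dtypes are among the missing names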
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That code does not modify the
torch.__all__
list:Generally speaking, this package should not monkeypatch the underlying libraries.
Yes, that's a known issue. pytorch/pytorch#91908
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, now I understand. I did not mean to actually modify
torch.__all__
in place but copy and extend instead. I'll fix that behavior.
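The copy-and-extend fix is presumably what _get_all_public_members implements. A minimal sketch of such a helper, assuming the signature seen at the call sites in this diff (module, exclude=None, extend_all=False); the real helper in the package's _internal module may differ:

def _get_all_public_members(module, exclude=None, extend_all=False):
    """Return a list of public names from module without mutating module.__all__."""
    members = list(getattr(module, "__all__", []) or [])  # copy, never modify in place
    if not members or extend_all:
        # Extend with every public attribute, since __all__ may be absent
        # or incomplete (e.g. torch.__all__ omits the dtypes).
        members += [name for name in dir(module) if not name.startswith("_")]
    if exclude is not None:
        members = [name for name in members if not exclude(name)]
    # Deduplicate while preserving order.
    seen = set()
    return [name for name in members if not (name in seen or seen.add(name))]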