Skip to content

Commit 953d7c0

Browse files
committed
TYP: __dir__
1 parent 0dd925f commit 953d7c0

File tree

11 files changed

+62
-19
lines changed

11 files changed

+62
-19
lines changed

array_api_compat/_internal.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88
from types import ModuleType
99
from typing import TypeVar
1010

11-
__all__ = ["get_xp"]
12-
1311
_T = TypeVar("_T")
1412

1513

@@ -52,3 +50,10 @@ def wrapped_f(*args: object, **kwargs: object) -> object:
5250
return wrapped_f # pyright: ignore[reportReturnType]
5351

5452
return inner
53+
54+
55+
__all__ = ["get_xp"]
56+
57+
58+
def __dir__() -> list[str]:
59+
return __all__

array_api_compat/common/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from ._helpers import * # noqa: F403
1+
from ._helpers import * # noqa: F403

array_api_compat/common/_aliases.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -720,5 +720,8 @@ def iinfo(type_: DType | Array, /, xp: Namespace) -> Any:
720720
"finfo",
721721
"iinfo",
722722
]
723-
724723
_all_ignore = ["inspect", "array_namespace", "NamedTuple"]
724+
725+
726+
def __dir__() -> list[str]:
727+
return __all__

array_api_compat/common/_fft.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,3 +208,6 @@ def ifftshift(
208208
"fftshift",
209209
"ifftshift",
210210
]
211+
212+
def __dir__() -> list[str]:
213+
return __all__

array_api_compat/common/_helpers.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1044,3 +1044,6 @@ def is_lazy_array(x: object) -> bool:
10441044
]
10451045

10461046
_all_ignore = ["sys", "math", "inspect", "warnings"]
1047+
1048+
def __dir__() -> list[str]:
1049+
return __all__

array_api_compat/common/_linalg.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,3 +226,7 @@ def trace(
226226
'trace']
227227

228228
_all_ignore = ['math', 'normalize_axis_tuple', 'get_xp', 'np', 'isdtype']
229+
230+
231+
def __dir__() -> list[str]:
232+
return __all__

array_api_compat/common/_typing.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,6 @@
33
from types import ModuleType as Namespace
44
from typing import Any, Protocol, TypeAlias, TypeVar
55

6-
__all__ = [
7-
"Array",
8-
"SupportsArrayNamespace",
9-
"DType",
10-
"Device",
11-
"HasShape",
12-
"Namespace",
13-
"NestedSequence",
14-
"SupportsBufferProtocol",
15-
]
16-
176
_T_co = TypeVar("_T_co", covariant=True)
187

198
class NestedSequence(Protocol[_T_co]):
@@ -34,3 +23,19 @@ def shape(self, /) -> _T_co: ...
3423
Array: TypeAlias = Any
3524
Device: TypeAlias = Any
3625
DType: TypeAlias = Any
26+
27+
28+
__all__ = [
29+
"Array",
30+
"SupportsArrayNamespace",
31+
"DType",
32+
"Device",
33+
"HasShape",
34+
"Namespace",
35+
"NestedSequence",
36+
"SupportsBufferProtocol",
37+
]
38+
39+
40+
def __dir__() -> list[str]:
41+
return __all__

array_api_compat/numpy/_aliases.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,5 +173,8 @@ def count_nonzero(
173173
"pow",
174174
]
175175
__all__ += _aliases.__all__
176-
177176
_all_ignore = ["np", "get_xp"]
177+
178+
179+
def __dir__() -> list[str]:
180+
return __all__

array_api_compat/numpy/_info.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,3 +357,10 @@ def devices(self) -> list[Device]:
357357
358358
"""
359359
return ["cpu"]
360+
361+
362+
__all__ = ["__array_namespace_info__"]
363+
364+
365+
def __dir__() -> list[str]:
366+
return __all__

array_api_compat/numpy/_typing.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
from __future__ import annotations
22

3-
__all__ = ["Array", "DType", "Device"]
4-
_all_ignore = ["np"]
5-
63
from typing import TYPE_CHECKING, Any, Literal, TypeAlias
74

85
import numpy as np
@@ -24,3 +21,10 @@
2421
else:
2522
DType: TypeAlias = np.dtype
2623
Array: TypeAlias = np.ndarray
24+
25+
__all__ = ["Array", "DType", "Device"]
26+
_all_ignore = ["np"]
27+
28+
29+
def __dir__() -> list[str]:
30+
return __all__

array_api_compat/numpy/fft.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,15 @@
2020
fftshift = get_xp(np)(_fft.fftshift)
2121
ifftshift = get_xp(np)(_fft.ifftshift)
2222

23+
2324
__all__ = ["rfft2", "irfft2", "fft2", "ifft2"]
2425
__all__ += _fft.__all__
2526

27+
28+
def __dir__() -> list[str]:
29+
return __all__
30+
31+
2632
del get_xp
2733
del np
2834
del fft_all

0 commit comments

Comments
 (0)