diff --git a/array_api_compat/_internal.py b/array_api_compat/_internal.py
index 170a1ff9..cd8d939f 100644
--- a/array_api_compat/_internal.py
+++ b/array_api_compat/_internal.py
@@ -2,10 +2,16 @@
 Internal helpers
 """
 
+from collections.abc import Callable
 from functools import wraps
 from inspect import signature
+from types import ModuleType
+from typing import TypeVar
 
-def get_xp(xp):
+_T = TypeVar("_T")
+
+
+def get_xp(xp: ModuleType) -> Callable[[Callable[..., _T]], Callable[..., _T]]:
     """
     Decorator to automatically replace xp with the corresponding array module.
 
@@ -22,14 +28,14 @@ def func(x, /, xp, kwarg=None):
 
     """
 
-    def inner(f):
+    def inner(f: Callable[..., _T], /) -> Callable[..., _T]:
         @wraps(f)
-        def wrapped_f(*args, **kwargs):
+        def wrapped_f(*args: object, **kwargs: object) -> object:
             return f(*args, xp=xp, **kwargs)
 
         sig = signature(f)
         new_sig = sig.replace(
-            parameters=[sig.parameters[i] for i in sig.parameters if i != "xp"]
+            parameters=[par for i, par in sig.parameters.items() if i != "xp"]
         )
 
         if wrapped_f.__doc__ is None:
@@ -40,7 +46,14 @@ def wrapped_f(*args, **kwargs):
 specification for more details.
 
 """
-        wrapped_f.__signature__ = new_sig
-        return wrapped_f
+        wrapped_f.__signature__ = new_sig  # pyright: ignore[reportAttributeAccessIssue]
+        return wrapped_f  # pyright: ignore[reportReturnType]
 
     return inner
+
+
+__all__ = ["get_xp"]
+
+
+def __dir__() -> list[str]:  # PEP 562 hook: dir() on this module lists only __all__
+    return __all__
diff --git a/array_api_compat/common/__init__.py b/array_api_compat/common/__init__.py
index 91ab1c40..82360807 100644
--- a/array_api_compat/common/__init__.py
+++ b/array_api_compat/common/__init__.py
@@ -1 +1 @@
-from ._helpers import * # noqa: F403
+from ._helpers import *  # noqa: F403
diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py
index 351b5bd6..8ea9162a 100644
--- a/array_api_compat/common/_aliases.py
+++ b/array_api_compat/common/_aliases.py
@@ -5,158 +5,170 @@
 from __future__ import annotations
 
 import inspect
-from typing import Any, NamedTuple, Optional, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Sequence, cast
 
+from ._helpers import _check_device, array_namespace
+from ._helpers import device as _get_device
+from ._helpers import is_cupy_namespace as _is_cupy_namespace
 from ._typing import Array, Device, DType, Namespace
-from ._helpers import (
-    array_namespace,
-    _check_device,
-    device as _get_device,
-    is_cupy_namespace as _is_cupy_namespace
-)
 
+if TYPE_CHECKING:
+    # TODO: import from typing (requires Python >=3.13)
+    from typing_extensions import TypeIs
 
 # These functions are modified from the NumPy versions.
 
 # Creation functions add the device keyword (which does nothing for NumPy and Dask)
 
+
 def arange(
-    start: Union[int, float],
+    start: float,
     /,
-    stop: Optional[Union[int, float]] = None,
-    step: Union[int, float] = 1,
+    stop: float | None = None,
+    step: float = 1,
     *,
     xp: Namespace,
-    dtype: Optional[DType] = None,
-    device: Optional[Device] = None,
-    **kwargs,
+    dtype: DType | None = None,
+    device: Device | None = None,
+    **kwargs: object,
 ) -> Array:
     _check_device(xp, device)
     return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs)
 
+
 def empty(
-    shape: Union[int, Tuple[int, ...]],
+    shape: int | tuple[int, ...],
     xp: Namespace,
     *,
-    dtype: Optional[DType] = None,
-    device: Optional[Device] = None,
-    **kwargs,
+    dtype: DType | None = None,
+    device: Device | None = None,
+    **kwargs: object,
 ) -> Array:
     _check_device(xp, device)
     return xp.empty(shape, dtype=dtype, **kwargs)
 
+
 def empty_like(
     x: Array,
     /,
     xp: Namespace,
     *,
-    dtype: Optional[DType] = None, 
-    device: Optional[Device] = None,
-    **kwargs,
+    dtype: DType | None = None,
+    device: Device | None = None,
+    **kwargs: object,
 ) -> Array:
     _check_device(xp, device)
     return xp.empty_like(x, dtype=dtype, **kwargs)
 
+
 def eye(
     n_rows: int,
-    n_cols: Optional[int] = None,
+    n_cols: int | None = None,
     /,
     *,
     xp: Namespace,
     k: int = 0,
-    dtype: Optional[DType] = None,
-    device: Optional[Device] = None,
-    **kwargs,
+    dtype: DType | None = None,
+    device: Device | None = None,
+    **kwargs: object,
 ) -> Array:
     _check_device(xp, device)
     return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs)
 
+
 def full(
-    shape: Union[int, Tuple[int, ...]],
-    fill_value: bool | int | float | complex,
+    shape: int | tuple[int, ...],
+    fill_value: complex,
     xp: Namespace,
     *,
-    dtype: Optional[DType] = None,
-    device: Optional[Device] = None,
-    **kwargs,
+    dtype: DType | None = None,
+    device: Device | None = None,
+    **kwargs: object,
 ) -> Array:
     _check_device(xp, device)
     return xp.full(shape, fill_value, dtype=dtype, **kwargs)
 
+
 def full_like(
     x: Array,
     /,
-    fill_value: bool | int | float | complex,
+    fill_value: complex,
     *,
     xp: Namespace,
-    dtype: Optional[DType] = None,
-    device: Optional[Device] = None,
-    **kwargs,
+    dtype: DType | None = None,
+    device: Device | None = None,
+    **kwargs: object,
 ) -> Array:
     _check_device(xp, device)
     return xp.full_like(x, fill_value, dtype=dtype, **kwargs)
 
+
 def linspace(
-    start: Union[int, float],
-    stop: Union[int, float],
+    start: float,
+    stop: float,
     /,
     num: int,
     *,
     xp: Namespace,
-    dtype: Optional[DType] = None,
-    device: Optional[Device] = None,
+    dtype: DType | None = None,
+    device: Device | None = None,
     endpoint: bool = True,
-    **kwargs,
+    **kwargs: object,
 ) -> Array:
     _check_device(xp, device)
     return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs)
 
+
 def ones(
-    shape: Union[int, Tuple[int, ...]],
+    shape: int | tuple[int, ...],
     xp: Namespace,
     *,
-    dtype: Optional[DType] = None,
-    device: Optional[Device] = None,
-    **kwargs,
+    dtype: DType | None = None,
+    device: Device | None = None,
+    **kwargs: object,
 ) -> Array:
     _check_device(xp, device)
     return xp.ones(shape, dtype=dtype, **kwargs)
 
+
 def ones_like(
     x: Array,
     /,
     xp: Namespace,
     *,
-    dtype: Optional[DType] = None,
-    device: Optional[Device] = None,
-    **kwargs,
+    dtype: DType | None = None,
+    device: Device | None = None,
+    **kwargs: object,
 ) -> Array:
     _check_device(xp, device)
     return xp.ones_like(x, dtype=dtype, **kwargs)
 
+
 def zeros(
-    shape: Union[int, Tuple[int, ...]],
+    shape: int | tuple[int, ...],
     xp: Namespace,
     *,
-    dtype: Optional[DType] = None,
-    device: Optional[Device] = None,
-    **kwargs,
+    dtype: DType | None = None,
+    device: Device | None = None,
+    **kwargs: object,
 ) -> Array:
     _check_device(xp, device)
     return xp.zeros(shape, dtype=dtype, **kwargs)
 
+
 def zeros_like(
     x: Array,
     /,
     xp: Namespace,
     *,
-    dtype: Optional[DType] = None,
-    device: Optional[Device] = None,
-    **kwargs,
+    dtype: DType | None = None,
+    device: Device | None = None,
+    **kwargs: object,
 ) -> Array:
     _check_device(xp, device)
     return xp.zeros_like(x, dtype=dtype, **kwargs)
 
+
 # np.unique() is split into four functions in the array API:
 # unique_all, unique_counts, unique_inverse, and unique_values (this is done
 # to remove polymorphic return types).
@@ -164,6 +176,7 @@ def zeros_like(
 # The functions here return namedtuples (np.unique() returns a normal
 # tuple).
 
+
 # Note that these named tuples aren't actually part of the standard namespace,
 # but I don't see any issue with exporting the names here regardless.
 class UniqueAllResult(NamedTuple):
@@ -188,10 +201,11 @@ def _unique_kwargs(xp: Namespace) -> dict[str, bool]:
     # trying to parse version numbers, just check if equal_nan is in the
     # signature.
     s = inspect.signature(xp.unique)
-    if 'equal_nan' in s.parameters:
-        return {'equal_nan': False}
+    if "equal_nan" in s.parameters:
+        return {"equal_nan": False}
     return {}
 
+
 def unique_all(x: Array, /, xp: Namespace) -> UniqueAllResult:
     kwargs = _unique_kwargs(xp)
     values, indices, inverse_indices, counts = xp.unique(
@@ -215,11 +229,7 @@ def unique_all(x: Array, /, xp: Namespace) -> UniqueAllResult:
 def unique_counts(x: Array, /, xp: Namespace) -> UniqueCountsResult:
     kwargs = _unique_kwargs(xp)
     res = xp.unique(
-        x,
-        return_counts=True,
-        return_index=False,
-        return_inverse=False,
-        **kwargs
+        x, return_counts=True, return_index=False, return_inverse=False, **kwargs
     )
 
     return UniqueCountsResult(*res)
@@ -250,51 +260,58 @@ def unique_values(x: Array, /, xp: Namespace) -> Array:
         **kwargs,
     )
 
+
 # These functions have different keyword argument names
 
+
 def std(
     x: Array,
     /,
     xp: Namespace,
     *,
-    axis: Optional[Union[int, Tuple[int, ...]]] = None,
-    correction: Union[int, float] = 0.0,  # correction instead of ddof
+    axis: int | tuple[int, ...] | None = None,
+    correction: float = 0.0,  # correction instead of ddof
     keepdims: bool = False,
-    **kwargs,
+    **kwargs: object,
 ) -> Array:
     return xp.std(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs)
 
+
 def var(
     x: Array,
     /,
     xp: Namespace,
     *,
-    axis: Optional[Union[int, Tuple[int, ...]]] = None,
-    correction: Union[int, float] = 0.0,  # correction instead of ddof
+    axis: int | tuple[int, ...] | None = None,
+    correction: float = 0.0,  # correction instead of ddof
     keepdims: bool = False,
-    **kwargs,
+    **kwargs: object,
 ) -> Array:
     return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs)
 
+
 # cumulative_sum is renamed from cumsum, and adds the include_initial keyword
 # argument
 
+
 def cumulative_sum(
     x: Array,
     /,
     xp: Namespace,
     *,
-    axis: Optional[int] = None,
-    dtype: Optional[DType] = None,
+    axis: int | None = None,
+    dtype: DType | None = None,
     include_initial: bool = False,
-    **kwargs,
+    **kwargs: object,
 ) -> Array:
     wrapped_xp = array_namespace(x)
 
     # TODO: The standard is not clear about what should happen when x.ndim == 0.
     if axis is None:
         if x.ndim > 1:
-            raise ValueError("axis must be specified in cumulative_sum for more than one dimension")
+            raise ValueError(
+                "axis must be specified in cumulative_sum for more than one dimension"
+            )
         axis = 0
 
     res = xp.cumsum(x, axis=axis, dtype=dtype, **kwargs)
@@ -304,7 +321,12 @@ def cumulative_sum(
         initial_shape = list(x.shape)
         initial_shape[axis] = 1
         res = xp.concatenate(
-            [wrapped_xp.zeros(shape=initial_shape, dtype=res.dtype, device=_get_device(res)), res],
+            [
+                wrapped_xp.zeros(
+                    shape=initial_shape, dtype=res.dtype, device=_get_device(res)
+                ),
+                res,
+            ],
             axis=axis,
         )
     return res
@@ -315,16 +337,18 @@ def cumulative_prod(
     /,
     xp: Namespace,
     *,
-    axis: Optional[int] = None,
-    dtype: Optional[DType] = None,
+    axis: int | None = None,
+    dtype: DType | None = None,
     include_initial: bool = False,
-    **kwargs,
+    **kwargs: object,
 ) -> Array:
     wrapped_xp = array_namespace(x)
 
     if axis is None:
         if x.ndim > 1:
-            raise ValueError("axis must be specified in cumulative_prod for more than one dimension")
+            raise ValueError(
+                "axis must be specified in cumulative_prod for more than one dimension"
+            )
         axis = 0
 
     res = xp.cumprod(x, axis=axis, dtype=dtype, **kwargs)
@@ -334,24 +358,30 @@ def cumulative_prod(
         initial_shape = list(x.shape)
         initial_shape[axis] = 1
         res = xp.concatenate(
-            [wrapped_xp.ones(shape=initial_shape, dtype=res.dtype, device=_get_device(res)), res],
+            [
+                wrapped_xp.ones(
+                    shape=initial_shape, dtype=res.dtype, device=_get_device(res)
+                ),
+                res,
+            ],
             axis=axis,
         )
     return res
 
+
 # The min and max argument names in clip are different and not optional in numpy, and type
 # promotion behavior is different.
 def clip(
     x: Array,
     /,
-    min: Optional[Union[int, float, Array]] = None,
-    max: Optional[Union[int, float, Array]] = None,
+    min: float | Array | None = None,
+    max: float | Array | None = None,
     *,
     xp: Namespace,
     # TODO: np.clip has other ufunc kwargs
-    out: Optional[Array] = None,
+    out: Array | None = None,
 ) -> Array:
-    def _isscalar(a):
+    def _isscalar(a: object) -> TypeIs[int | float | None]:
         return isinstance(a, (int, float, type(None)))
 
     min_shape = () if _isscalar(min) else min.shape
@@ -378,7 +408,6 @@ def _isscalar(a):
     # but an answer of 0 might be preferred. See
     # https://github.com/numpy/numpy/issues/24976 for more discussion on this issue.
 
-
     # At least handle the case of Python integers correctly (see
     # https://github.com/numpy/numpy/pull/26892).
     if wrapped_xp.isdtype(x.dtype, "integral"):
@@ -390,6 +419,7 @@ def _isscalar(a):
     dev = _get_device(x)
     if out is None:
         out = wrapped_xp.empty(result_shape, dtype=x.dtype, device=dev)
+    assert out is not None  # workaround for a type-narrowing issue in pyright
     out[()] = x
 
     if min is not None:
@@ -407,19 +437,21 @@ def _isscalar(a):
     # Return a scalar for 0-D
     return out[()]
 
+
 # Unlike transpose(), the axes argument to permute_dims() is required.
-def permute_dims(x: Array, /, axes: Tuple[int, ...], xp: Namespace) -> Array:
+def permute_dims(x: Array, /, axes: tuple[int, ...], xp: Namespace) -> Array:
     return xp.transpose(x, axes)
 
+
 # np.reshape calls the keyword argument 'newshape' instead of 'shape'
 def reshape(
     x: Array,
     /,
-    shape: Tuple[int, ...],
+    shape: tuple[int, ...],
     xp: Namespace,
     *,
     copy: Optional[bool] = None,
-    **kwargs,
+    **kwargs: object,
 ) -> Array:
     if copy is True:
         x = x.copy()
@@ -429,6 +461,7 @@ def reshape(
         return y
     return xp.reshape(x, shape, **kwargs)
 
+
 # The descending keyword is new in sort and argsort, and 'kind' replaced with
 # 'stable'
 def argsort(
@@ -439,13 +472,13 @@ def argsort(
     axis: int = -1,
     descending: bool = False,
     stable: bool = True,
-    **kwargs,
+    **kwargs: object,
 ) -> Array:
     # Note: this keyword argument is different, and the default is different.
     # We set it in kwargs like this because numpy.sort uses kind='quicksort'
     # as the default whereas cupy.sort uses kind=None.
     if stable:
-        kwargs['kind'] = "stable"
+        kwargs["kind"] = "stable"
     if not descending:
         res = xp.argsort(x, axis=axis, **kwargs)
     else:
@@ -462,6 +495,7 @@ def argsort(
         res = max_i - res
     return res
 
+
 def sort(
     x: Array,
     /,
@@ -470,68 +504,78 @@ def sort(
     axis: int = -1,
     descending: bool = False,
     stable: bool = True,
-    **kwargs,
+    **kwargs: object,
 ) -> Array:
     # Note: this keyword argument is different, and the default is different.
     # We set it in kwargs like this because numpy.sort uses kind='quicksort'
     # as the default whereas cupy.sort uses kind=None.
     if stable:
-        kwargs['kind'] = "stable"
+        kwargs["kind"] = "stable"
     res = xp.sort(x, axis=axis, **kwargs)
     if descending:
         res = xp.flip(res, axis=axis)
     return res
 
+
 # nonzero should error for zero-dimensional arrays
-def nonzero(x: Array, /, xp: Namespace, **kwargs) -> Tuple[Array, ...]:
+def nonzero(x: Array, /, xp: Namespace, **kwargs: object) -> tuple[Array, ...]:
     if x.ndim == 0:
         raise ValueError("nonzero() does not support zero-dimensional arrays")
     return xp.nonzero(x, **kwargs)
 
+
 # ceil, floor, and trunc return integers for integer inputs
 
-def ceil(x: Array, /, xp: Namespace, **kwargs) -> Array:
+
+def ceil(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
     if xp.issubdtype(x.dtype, xp.integer):
         return x
     return xp.ceil(x, **kwargs)
 
-def floor(x: Array, /, xp: Namespace, **kwargs) -> Array:
+
+def floor(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
     if xp.issubdtype(x.dtype, xp.integer):
         return x
     return xp.floor(x, **kwargs)
 
-def trunc(x: Array, /, xp: Namespace, **kwargs) -> Array:
+
+def trunc(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
     if xp.issubdtype(x.dtype, xp.integer):
         return x
     return xp.trunc(x, **kwargs)
 
+
 # linear algebra functions
 
-def matmul(x1: Array, x2: Array, /, xp: Namespace, **kwargs) -> Array:
+
+def matmul(x1: Array, x2: Array, /, xp: Namespace, **kwargs: object) -> Array:
     return xp.matmul(x1, x2, **kwargs)
 
+
 # Unlike transpose, matrix_transpose only transposes the last two axes.
 def matrix_transpose(x: Array, /, xp: Namespace) -> Array:
     if x.ndim < 2:
         raise ValueError("x must be at least 2-dimensional for matrix_transpose")
     return xp.swapaxes(x, -1, -2)
 
+
 def tensordot(
     x1: Array,
     x2: Array,
     /,
     xp: Namespace,
     *,
-    axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2,
-    **kwargs,
+    axes: int | tuple[Sequence[int], Sequence[int]] = 2,
+    **kwargs: object,
 ) -> Array:
     return xp.tensordot(x1, x2, axes=axes, **kwargs)
 
+
 def vecdot(x1: Array, x2: Array, /, xp: Namespace, *, axis: int = -1) -> Array:
     if x1.shape[axis] != x2.shape[axis]:
         raise ValueError("x1 and x2 must have the same size along the given axis")
 
-    if hasattr(xp, 'broadcast_tensors'):
+    if hasattr(xp, "broadcast_tensors"):
         _broadcast = xp.broadcast_tensors
     else:
         _broadcast = xp.broadcast_arrays
@@ -543,14 +587,16 @@ def vecdot(x1: Array, x2: Array, /, xp: Namespace, *, axis: int = -1) -> Array:
     res = xp.conj(x1_[..., None, :]) @ x2_[..., None]
     return res[..., 0, 0]
 
+
 # isdtype is a new function in the 2022.12 array API specification.
 
+
 def isdtype(
     dtype: DType,
-    kind: Union[DType, str, Tuple[Union[DType, str], ...]],
+    kind: DType | str | tuple[DType | str, ...],
     xp: Namespace,
     *,
-    _tuple: bool = True, # Disallow nested tuples
+    _tuple: bool = True,  # Disallow nested tuples
 ) -> bool:
     """
     Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``.
@@ -563,21 +609,24 @@ def isdtype(
     for more details
     """
     if isinstance(kind, tuple) and _tuple:
-        return any(isdtype(dtype, k, xp, _tuple=False) for k in kind)
+        return any(
+            isdtype(dtype, k, xp, _tuple=False)
+            for k in cast("tuple[DType | str, ...]", kind)
+        )
     elif isinstance(kind, str):
-        if kind == 'bool':
+        if kind == "bool":
             return dtype == xp.bool_
-        elif kind == 'signed integer':
+        elif kind == "signed integer":
             return xp.issubdtype(dtype, xp.signedinteger)
-        elif kind == 'unsigned integer':
+        elif kind == "unsigned integer":
             return xp.issubdtype(dtype, xp.unsignedinteger)
-        elif kind == 'integral':
+        elif kind == "integral":
             return xp.issubdtype(dtype, xp.integer)
-        elif kind == 'real floating':
+        elif kind == "real floating":
             return xp.issubdtype(dtype, xp.floating)
-        elif kind == 'complex floating':
+        elif kind == "complex floating":
             return xp.issubdtype(dtype, xp.complexfloating)
-        elif kind == 'numeric':
+        elif kind == "numeric":
             return xp.issubdtype(dtype, xp.number)
         else:
             raise ValueError(f"Unrecognized data type kind: {kind!r}")
@@ -588,24 +637,27 @@ def isdtype(
         # array_api_strict implementation will be very strict.
         return dtype == kind
 
+
 # unstack is a new function in the 2023.12 array API standard
-def unstack(x: Array, /, xp: Namespace, *, axis: int = 0) -> Tuple[Array, ...]:
+def unstack(x: Array, /, xp: Namespace, *, axis: int = 0) -> tuple[Array, ...]:
     if x.ndim == 0:
         raise ValueError("Input array must be at least 1-d.")
     return tuple(xp.moveaxis(x, axis, 0))
 
+
 # numpy 1.26 does not use the standard definition for sign on complex numbers
 
-def sign(x: Array, /, xp: Namespace, **kwargs) -> Array:
-    if isdtype(x.dtype, 'complex floating', xp=xp):
-        out = (x/xp.abs(x, **kwargs))[...]
+
+def sign(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
+    if isdtype(x.dtype, "complex floating", xp=xp):
+        out = (x / xp.abs(x, **kwargs))[...]
         # sign(0) = 0 but the above formula would give nan
-        out[x == 0+0j] = 0+0j
+        out[x == 0j] = 0j
     else:
         out = xp.sign(x, **kwargs)
     # CuPy sign() does not propagate nans. See
     # https://github.com/data-apis/array-api-compat/issues/136
-    if _is_cupy_namespace(xp) and isdtype(x.dtype, 'real floating', xp=xp):
+    if _is_cupy_namespace(xp) and isdtype(x.dtype, "real floating", xp=xp):
         out[xp.isnan(x)] = xp.nan
     return out[()]
 
@@ -626,13 +678,50 @@ def iinfo(type_: DType | Array, /, xp: Namespace) -> Any:
         return xp.iinfo(type_.dtype)
 
 
-__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like',
-           'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
-           'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
-           'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
-           'std', 'var', 'cumulative_sum', 'cumulative_prod','clip', 'permute_dims',
-           'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc',
-           'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype',
-           'unstack', 'sign', 'finfo', 'iinfo']
-
-_all_ignore = ['inspect', 'array_namespace', 'NamedTuple']
+__all__ = [
+    "arange",
+    "empty",
+    "empty_like",
+    "eye",
+    "full",
+    "full_like",
+    "linspace",
+    "ones",
+    "ones_like",
+    "zeros",
+    "zeros_like",
+    "UniqueAllResult",
+    "UniqueCountsResult",
+    "UniqueInverseResult",
+    "unique_all",
+    "unique_counts",
+    "unique_inverse",
+    "unique_values",
+    "std",
+    "var",
+    "cumulative_sum",
+    "cumulative_prod",
+    "clip",
+    "permute_dims",
+    "reshape",
+    "argsort",
+    "sort",
+    "nonzero",
+    "ceil",
+    "floor",
+    "trunc",
+    "matmul",
+    "matrix_transpose",
+    "tensordot",
+    "vecdot",
+    "isdtype",
+    "unstack",
+    "sign",
+    "finfo",
+    "iinfo",
+]
+_all_ignore = ["inspect", "array_namespace", "NamedTuple"]
+
+
+def __dir__() -> list[str]:  # PEP 562 hook: dir() on this module lists only __all__
+    return __all__
diff --git a/array_api_compat/common/_fft.py b/array_api_compat/common/_fft.py
index bd2a4e1a..18839d37 100644
--- a/array_api_compat/common/_fft.py
+++ b/array_api_compat/common/_fft.py
@@ -1,9 +1,11 @@
 from __future__ import annotations
 
 from collections.abc import Sequence
-from typing import Union, Optional, Literal
+from typing import Literal, TypeAlias
 
-from ._typing import Device, Array, DType, Namespace
+from ._typing import Array, Device, DType, Namespace
+
+_Norm: TypeAlias = Literal["backward", "ortho", "forward"]
 
 # Note: NumPy fft functions improperly upcast float32 and complex64 to
 # complex128, which is why we require wrapping them all here.
@@ -13,9 +15,9 @@ def fft(
     /,
     xp: Namespace,
     *,
-    n: Optional[int] = None,
+    n: int | None = None,
     axis: int = -1,
-    norm: Literal["backward", "ortho", "forward"] = "backward",
+    norm: _Norm = "backward",
 ) -> Array:
     res = xp.fft.fft(x, n=n, axis=axis, norm=norm)
     if x.dtype in [xp.float32, xp.complex64]:
@@ -27,9 +29,9 @@ def ifft(
     /,
     xp: Namespace,
     *,
-    n: Optional[int] = None,
+    n: int | None = None,
     axis: int = -1,
-    norm: Literal["backward", "ortho", "forward"] = "backward",
+    norm: _Norm = "backward",
 ) -> Array:
     res = xp.fft.ifft(x, n=n, axis=axis, norm=norm)
     if x.dtype in [xp.float32, xp.complex64]:
@@ -41,9 +43,9 @@ def fftn(
     /,
     xp: Namespace,
     *,
-    s: Sequence[int] = None,
-    axes: Sequence[int] = None,
-    norm: Literal["backward", "ortho", "forward"] = "backward",
+    s: Sequence[int] | None = None,
+    axes: Sequence[int] | None = None,
+    norm: _Norm = "backward",
 ) -> Array:
     res = xp.fft.fftn(x, s=s, axes=axes, norm=norm)
     if x.dtype in [xp.float32, xp.complex64]:
@@ -55,9 +57,9 @@ def ifftn(
     /,
     xp: Namespace,
     *,
-    s: Sequence[int] = None,
-    axes: Sequence[int] = None,
-    norm: Literal["backward", "ortho", "forward"] = "backward",
+    s: Sequence[int] | None = None,
+    axes: Sequence[int] | None = None,
+    norm: _Norm = "backward",
 ) -> Array:
     res = xp.fft.ifftn(x, s=s, axes=axes, norm=norm)
     if x.dtype in [xp.float32, xp.complex64]:
@@ -69,9 +71,9 @@ def rfft(
     /,
     xp: Namespace,
     *,
-    n: Optional[int] = None,
+    n: int | None = None,
     axis: int = -1,
-    norm: Literal["backward", "ortho", "forward"] = "backward",
+    norm: _Norm = "backward",
 ) -> Array:
     res = xp.fft.rfft(x, n=n, axis=axis, norm=norm)
     if x.dtype == xp.float32:
@@ -83,9 +85,9 @@ def irfft(
     /,
     xp: Namespace,
     *,
-    n: Optional[int] = None,
+    n: int | None = None,
     axis: int = -1,
-    norm: Literal["backward", "ortho", "forward"] = "backward",
+    norm: _Norm = "backward",
 ) -> Array:
     res = xp.fft.irfft(x, n=n, axis=axis, norm=norm)
     if x.dtype == xp.complex64:
@@ -97,9 +99,9 @@ def rfftn(
     /,
     xp: Namespace,
     *,
-    s: Sequence[int] = None,
-    axes: Sequence[int] = None,
-    norm: Literal["backward", "ortho", "forward"] = "backward",
+    s: Sequence[int] | None = None,
+    axes: Sequence[int] | None = None,
+    norm: _Norm = "backward",
 ) -> Array:
     res = xp.fft.rfftn(x, s=s, axes=axes, norm=norm)
     if x.dtype == xp.float32:
@@ -111,9 +113,9 @@ def irfftn(
     /,
     xp: Namespace,
     *,
-    s: Sequence[int] = None,
-    axes: Sequence[int] = None,
-    norm: Literal["backward", "ortho", "forward"] = "backward",
+    s: Sequence[int] | None = None,
+    axes: Sequence[int] | None = None,
+    norm: _Norm = "backward",
 ) -> Array:
     res = xp.fft.irfftn(x, s=s, axes=axes, norm=norm)
     if x.dtype == xp.complex64:
@@ -125,9 +127,9 @@ def hfft(
     /,
     xp: Namespace,
     *,
-    n: Optional[int] = None,
+    n: int | None = None,
     axis: int = -1,
-    norm: Literal["backward", "ortho", "forward"] = "backward",
+    norm: _Norm = "backward",
 ) -> Array:
     res = xp.fft.hfft(x, n=n, axis=axis, norm=norm)
     if x.dtype in [xp.float32, xp.complex64]:
@@ -139,9 +141,9 @@ def ihfft(
     /,
     xp: Namespace,
     *,
-    n: Optional[int] = None,
+    n: int | None = None,
     axis: int = -1,
-    norm: Literal["backward", "ortho", "forward"] = "backward",
+    norm: _Norm = "backward",
 ) -> Array:
     res = xp.fft.ihfft(x, n=n, axis=axis, norm=norm)
     if x.dtype in [xp.float32, xp.complex64]:
@@ -154,8 +156,8 @@ def fftfreq(
     xp: Namespace,
     *,
     d: float = 1.0,
-    dtype: Optional[DType] = None,
-    device: Optional[Device] = None,
+    dtype: DType | None = None,
+    device: Device | None = None,
 ) -> Array:
     if device not in ["cpu", None]:
         raise ValueError(f"Unsupported device {device!r}")
@@ -170,8 +172,8 @@ def rfftfreq(
     xp: Namespace,
     *,
     d: float = 1.0,
-    dtype: Optional[DType] = None,
-    device: Optional[Device] = None,
+    dtype: DType | None = None,
+    device: Device | None = None,
 ) -> Array:
     if device not in ["cpu", None]:
         raise ValueError(f"Unsupported device {device!r}")
@@ -181,12 +183,12 @@ def rfftfreq(
     return res
 
 def fftshift(
-    x: Array, /, xp: Namespace, *, axes: Union[int, Sequence[int]] = None
+    x: Array, /, xp: Namespace, *, axes: int | Sequence[int] | None = None
 ) -> Array:
     return xp.fft.fftshift(x, axes=axes)
 
 def ifftshift(
-    x: Array, /, xp: Namespace, *, axes: Union[int, Sequence[int]] = None
+    x: Array, /, xp: Namespace, *, axes: int | Sequence[int] | None = None
 ) -> Array:
     return xp.fft.ifftshift(x, axes=axes)
 
@@ -206,3 +208,6 @@ def ifftshift(
     "fftshift",
     "ifftshift",
 ]
+
+def __dir__() -> list[str]:  # PEP 562 hook: dir() on this module lists only __all__
+    return __all__
diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py
index 67c619b8..db3e4cd7 100644
--- a/array_api_compat/common/_helpers.py
+++ b/array_api_compat/common/_helpers.py
@@ -5,33 +5,82 @@
 that are in __all__ are intended as additional helper functions for use by end
 users of the compat library.
 """
+
 from __future__ import annotations
 
-import sys
-import math
 import inspect
+import math
+import sys
 import warnings
-from typing import Optional, Union, Any
+from collections.abc import Collection
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Final,
+    Literal,
+    SupportsIndex,
+    TypeAlias,
+    TypeGuard,
+    TypeVar,
+    cast,
+    overload,
+)
+
+from ._typing import Array, Device, HasShape, Namespace, SupportsArrayNamespace
+
+if TYPE_CHECKING:
+
+    import dask.array as da
+    import jax
+    import ndonnx as ndx
+    import numpy as np
+    import numpy.typing as npt
+    import sparse  # pyright: ignore[reportMissingTypeStubs]
+    import torch
+
+    # TODO: import from typing (requires Python >=3.13)
+    from typing_extensions import TypeIs, TypeVar
 
-from ._typing import Array, Device, Namespace
+    _SizeT = TypeVar("_SizeT", bound=int | None)
 
+    _ZeroGradientArray: TypeAlias = npt.NDArray[np.void]
+    _CupyArray: TypeAlias = Any  # cupy has no py.typed
 
-def _is_jax_zero_gradient_array(x: object) -> bool:
+    _ArrayApiObj: TypeAlias = (
+        npt.NDArray[Any]
+        | da.Array
+        | jax.Array
+        | ndx.Array
+        | sparse.SparseArray
+        | torch.Tensor
+        | SupportsArrayNamespace[Any]
+        | _CupyArray
+    )
+
+_API_VERSIONS_OLD: Final = frozenset({"2021.12", "2022.12", "2023.12"})
+_API_VERSIONS: Final = _API_VERSIONS_OLD | frozenset({"2024.12"})
+
+
+def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]:
     """Return True if `x` is a zero-gradient array.
 
     These arrays are a design quirk of Jax that may one day be removed.
     See https://github.com/google/jax/issues/20620.
     """
-    if 'numpy' not in sys.modules or 'jax' not in sys.modules:
+    if "numpy" not in sys.modules or "jax" not in sys.modules:
         return False
 
-    import numpy as np
     import jax
+    import numpy as np
 
-    return isinstance(x, np.ndarray) and x.dtype == jax.float0
+    jax_float0 = cast("np.dtype[np.void]", jax.float0)
+    return (
+        isinstance(x, np.ndarray)
+        and cast("npt.NDArray[np.void]", x).dtype == jax_float0
+    )
 
 
-def is_numpy_array(x: object) -> bool:
+def is_numpy_array(x: object) -> TypeGuard[npt.NDArray[Any]]:
     """
     Return True if `x` is a NumPy array.
 
@@ -53,14 +102,14 @@ def is_numpy_array(x: object) -> bool:
     is_pydata_sparse_array
     """
     # Avoid importing NumPy if it isn't already
-    if 'numpy' not in sys.modules:
+    if "numpy" not in sys.modules:
         return False
 
     import numpy as np
 
     # TODO: Should we reject ndarray subclasses?
     return (isinstance(x, (np.ndarray, np.generic))
-            and not _is_jax_zero_gradient_array(x))
+            and not _is_jax_zero_gradient_array(x))  # pyright: ignore[reportUnknownArgumentType]  # fmt: skip
 
 
 def is_cupy_array(x: object) -> bool:
@@ -85,16 +134,16 @@ def is_cupy_array(x: object) -> bool:
     is_pydata_sparse_array
     """
     # Avoid importing CuPy if it isn't already
-    if 'cupy' not in sys.modules:
+    if "cupy" not in sys.modules:
         return False
 
-    import cupy as cp
+    import cupy as cp  # pyright: ignore[reportMissingTypeStubs]
 
     # TODO: Should we reject ndarray subclasses?
-    return isinstance(x, cp.ndarray)
+    return isinstance(x, cp.ndarray)  # pyright: ignore[reportUnknownMemberType]
 
 
-def is_torch_array(x: object) -> bool:
+def is_torch_array(x: object) -> TypeIs[torch.Tensor]:
     """
     Return True if `x` is a PyTorch tensor.
 
@@ -113,7 +162,7 @@ def is_torch_array(x: object) -> bool:
     is_pydata_sparse_array
     """
     # Avoid importing torch if it isn't already
-    if 'torch' not in sys.modules:
+    if "torch" not in sys.modules:
         return False
 
     import torch
@@ -122,7 +171,7 @@ def is_torch_array(x: object) -> bool:
     return isinstance(x, torch.Tensor)
 
 
-def is_ndonnx_array(x: object) -> bool:
+def is_ndonnx_array(x: object) -> TypeIs[ndx.Array]:
     """
     Return True if `x` is a ndonnx Array.
 
@@ -142,7 +191,7 @@ def is_ndonnx_array(x: object) -> bool:
     is_pydata_sparse_array
     """
     # Avoid importing torch if it isn't already
-    if 'ndonnx' not in sys.modules:
+    if "ndonnx" not in sys.modules:
         return False
 
     import ndonnx as ndx
@@ -150,7 +199,7 @@ def is_ndonnx_array(x: object) -> bool:
     return isinstance(x, ndx.Array)
 
 
-def is_dask_array(x: object) -> bool:
+def is_dask_array(x: object) -> TypeIs[da.Array]:
     """
     Return True if `x` is a dask.array Array.
 
@@ -170,7 +219,7 @@ def is_dask_array(x: object) -> bool:
     is_pydata_sparse_array
     """
     # Avoid importing dask if it isn't already
-    if 'dask.array' not in sys.modules:
+    if "dask.array" not in sys.modules:
         return False
 
     import dask.array
@@ -178,7 +227,7 @@ def is_dask_array(x: object) -> bool:
     return isinstance(x, dask.array.Array)
 
 
-def is_jax_array(x: object) -> bool:
+def is_jax_array(x: object) -> TypeIs[jax.Array]:
     """
     Return True if `x` is a JAX array.
 
@@ -199,7 +248,7 @@ def is_jax_array(x: object) -> bool:
     is_pydata_sparse_array
     """
     # Avoid importing jax if it isn't already
-    if 'jax' not in sys.modules:
+    if "jax" not in sys.modules:
         return False
 
     import jax
@@ -207,7 +256,7 @@ def is_jax_array(x: object) -> bool:
     return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x)
 
 
-def is_pydata_sparse_array(x) -> bool:
+def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]:
     """
     Return True if `x` is an array from the `sparse` package.
 
@@ -228,16 +277,16 @@ def is_pydata_sparse_array(x) -> bool:
     is_jax_array
     """
     # Avoid importing jax if it isn't already
-    if 'sparse' not in sys.modules:
+    if "sparse" not in sys.modules:
         return False
 
-    import sparse
+    import sparse  # pyright: ignore[reportMissingTypeStubs]
 
     # TODO: Account for other backends.
     return isinstance(x, sparse.SparseArray)
 
 
-def is_array_api_obj(x: object) -> bool:
+def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]:  # pyright: ignore[reportUnknownParameterType]
     """
     Return True if `x` is an array API compatible array object.
 
@@ -252,18 +301,20 @@ def is_array_api_obj(x: object) -> bool:
     is_dask_array
     is_jax_array
     """
-    return is_numpy_array(x) \
-        or is_cupy_array(x) \
-        or is_torch_array(x) \
-        or is_dask_array(x) \
-        or is_jax_array(x) \
-        or is_pydata_sparse_array(x) \
-        or hasattr(x, '__array_namespace__')
+    return (
+        is_numpy_array(x)
+        or is_cupy_array(x)
+        or is_torch_array(x)
+        or is_dask_array(x)
+        or is_jax_array(x)
+        or is_pydata_sparse_array(x)
+        or hasattr(x, "__array_namespace__")
+    )
 
 
 def _compat_module_name() -> str:
-    assert __name__.endswith('.common._helpers')
-    return __name__.removesuffix('.common._helpers')
+    assert __name__.endswith(".common._helpers")
+    return __name__.removesuffix(".common._helpers")
 
 
 def is_numpy_namespace(xp: Namespace) -> bool:
@@ -284,7 +335,7 @@ def is_numpy_namespace(xp: Namespace) -> bool:
     is_pydata_sparse_namespace
     is_array_api_strict_namespace
     """
-    return xp.__name__ in {'numpy', _compat_module_name() + '.numpy'}
+    return xp.__name__ in {"numpy", _compat_module_name() + ".numpy"}
 
 
 def is_cupy_namespace(xp: Namespace) -> bool:
@@ -305,7 +356,7 @@ def is_cupy_namespace(xp: Namespace) -> bool:
     is_pydata_sparse_namespace
     is_array_api_strict_namespace
     """
-    return xp.__name__ in {'cupy', _compat_module_name() + '.cupy'}
+    return xp.__name__ in {"cupy", _compat_module_name() + ".cupy"}
 
 
 def is_torch_namespace(xp: Namespace) -> bool:
@@ -326,7 +377,7 @@ def is_torch_namespace(xp: Namespace) -> bool:
     is_pydata_sparse_namespace
     is_array_api_strict_namespace
     """
-    return xp.__name__ in {'torch', _compat_module_name() + '.torch'}
+    return xp.__name__ in {"torch", _compat_module_name() + ".torch"}
 
 
 def is_ndonnx_namespace(xp: Namespace) -> bool:
@@ -345,7 +396,7 @@ def is_ndonnx_namespace(xp: Namespace) -> bool:
     is_pydata_sparse_namespace
     is_array_api_strict_namespace
     """
-    return xp.__name__ == 'ndonnx'
+    return xp.__name__ == "ndonnx"
 
 
 def is_dask_namespace(xp: Namespace) -> bool:
@@ -366,7 +417,7 @@ def is_dask_namespace(xp: Namespace) -> bool:
     is_pydata_sparse_namespace
     is_array_api_strict_namespace
     """
-    return xp.__name__ in {'dask.array', _compat_module_name() + '.dask.array'}
+    return xp.__name__ in {"dask.array", _compat_module_name() + ".dask.array"}
 
 
 def is_jax_namespace(xp: Namespace) -> bool:
@@ -388,7 +439,7 @@ def is_jax_namespace(xp: Namespace) -> bool:
     is_pydata_sparse_namespace
     is_array_api_strict_namespace
     """
-    return xp.__name__ in {'jax.numpy', 'jax.experimental.array_api'}
+    return xp.__name__ in {"jax.numpy", "jax.experimental.array_api"}
 
 
 def is_pydata_sparse_namespace(xp: Namespace) -> bool:
@@ -407,7 +458,7 @@ def is_pydata_sparse_namespace(xp: Namespace) -> bool:
     is_jax_namespace
     is_array_api_strict_namespace
     """
-    return xp.__name__ == 'sparse'
+    return xp.__name__ == "sparse"
 
 
 def is_array_api_strict_namespace(xp: Namespace) -> bool:
@@ -426,21 +477,24 @@ def is_array_api_strict_namespace(xp: Namespace) -> bool:
     is_jax_namespace
     is_pydata_sparse_namespace
     """
-    return xp.__name__ == 'array_api_strict'
+    return xp.__name__ == "array_api_strict"
 
 
-def _check_api_version(api_version: str) -> None:
-    if api_version in ['2021.12', '2022.12', '2023.12']:
-        warnings.warn(f"The {api_version} version of the array API specification was requested but the returned namespace is actually version 2024.12")
-    elif api_version is not None and api_version not in ['2021.12', '2022.12',
-                                                         '2023.12', '2024.12']:
-        raise ValueError("Only the 2024.12 version of the array API specification is currently supported")
+def _check_api_version(api_version: str | None) -> None:
+    if api_version in _API_VERSIONS_OLD:
+        warnings.warn(
+            f"The {api_version} version of the array API specification was requested but the returned namespace is actually version 2024.12"
+        )
+    elif api_version is not None and api_version not in _API_VERSIONS:
+        raise ValueError(
+            "Only the 2024.12 version of the array API specification is currently supported"
+        )
 
 
 def array_namespace(
-    *xs: Union[Array, bool, int, float, complex, None],
-    api_version: Optional[str] = None,
-    use_compat: Optional[bool] = None,
+    *xs: Array | complex | None,
+    api_version: str | None = None,
+    use_compat: bool | None = None,
 ) -> Namespace:
     """
     Get the array API compatible namespace for the arrays `xs`.
@@ -510,11 +564,13 @@ def your_function(x, y):
 
     _use_compat = use_compat in [None, True]
 
-    namespaces = set()
+    namespaces: set[Namespace] = set()
     for x in xs:
         if is_numpy_array(x):
-            from .. import numpy as numpy_namespace
             import numpy as np
+
+            from .. import numpy as numpy_namespace
+
             if use_compat is True:
                 _check_api_version(api_version)
                 namespaces.add(numpy_namespace)
@@ -528,25 +584,31 @@ def your_function(x, y):
             if _use_compat:
                 _check_api_version(api_version)
                 from .. import cupy as cupy_namespace
+
                 namespaces.add(cupy_namespace)
             else:
-                import cupy as cp
+                import cupy as cp  # pyright: ignore[reportMissingTypeStubs]
+
                 namespaces.add(cp)
         elif is_torch_array(x):
             if _use_compat:
                 _check_api_version(api_version)
                 from .. import torch as torch_namespace
+
                 namespaces.add(torch_namespace)
             else:
                 import torch
+
                 namespaces.add(torch)
         elif is_dask_array(x):
             if _use_compat:
                 _check_api_version(api_version)
                 from ..dask import array as dask_namespace
+
                 namespaces.add(dask_namespace)
             else:
                 import dask.array as da
+
                 namespaces.add(da)
         elif is_jax_array(x):
             if use_compat is True:
@@ -558,23 +620,27 @@ def your_function(x, y):
                 # JAX v0.4.32 and newer implements the array API directly in jax.numpy.
                 # For older JAX versions, it is available via jax.experimental.array_api.
                 import jax.numpy
+
                 if hasattr(jax.numpy, "__array_api_version__"):
                     jnp = jax.numpy
                 else:
-                    import jax.experimental.array_api as jnp
+                    import jax.experimental.array_api as jnp  # pyright: ignore[reportMissingImports]
             namespaces.add(jnp)
         elif is_pydata_sparse_array(x):
             if use_compat is True:
                 _check_api_version(api_version)
                 raise ValueError("`sparse` does not have an array-api-compat wrapper")
             else:
-                import sparse
+                import sparse  # pyright: ignore[reportMissingTypeStubs]
             # `sparse` is already an array namespace. We do not have a wrapper
             # submodule for it.
             namespaces.add(sparse)
-        elif hasattr(x, '__array_namespace__'):
+        elif hasattr(x, "__array_namespace__"):
             if use_compat is True:
-                raise ValueError("The given array does not have an array-api-compat wrapper")
+                raise ValueError(
+                    "The given array does not have an array-api-compat wrapper"
+                )
+            x = cast("SupportsArrayNamespace[Any]", x)
             namespaces.add(x.__array_namespace__(api_version=api_version))
         elif isinstance(x, (bool, int, float, complex, type(None))):
             continue
@@ -588,15 +654,16 @@ def your_function(x, y):
     if len(namespaces) != 1:
         raise TypeError(f"Multiple namespaces for array inputs: {namespaces}")
 
-    xp, = namespaces
+    (xp,) = namespaces
 
     return xp
 
+
 # backwards compatibility alias
 get_namespace = array_namespace
 
 
-def _check_device(bare_xp, device):
+def _check_device(bare_xp: Namespace, device: Device) -> None:  # pyright: ignore[reportUnusedFunction]
     """
     Validate dummy device on device-less array backends.
 
@@ -609,11 +676,11 @@ def _check_device(bare_xp, device):
 
     https://github.com/data-apis/array-api-compat/pull/293
     """
-    if bare_xp is sys.modules.get('numpy'):
+    if bare_xp is sys.modules.get("numpy"):
         if device not in ("cpu", None):
             raise ValueError(f"Unsupported device for NumPy: {device!r}")
 
-    elif bare_xp is sys.modules.get('dask.array'):
+    elif bare_xp is sys.modules.get("dask.array"):
         if device not in ("cpu", _DASK_DEVICE, None):
             raise ValueError(f"Unsupported device for Dask: {device!r}")
 
@@ -622,18 +689,20 @@ def _check_device(bare_xp, device):
 # when the array backend is not the CPU.
 # (since it is not easy to tell which device a dask array is on)
 class _dask_device:
-    def __repr__(self):
+    def __repr__(self) -> Literal["DASK_DEVICE"]:
         return "DASK_DEVICE"
 
+
 _DASK_DEVICE = _dask_device()
 
+
 # device() is not on numpy.ndarray or dask.array and to_device() is not on numpy.ndarray
 # or cupy.ndarray. They are not included in array objects of this library
 # because this library just reuses the respective ndarray classes without
 # wrapping or subclassing them. These helper functions can be used instead of
 # the wrapper functions for libraries that need to support both NumPy/CuPy and
 # other libraries that use devices.
-def device(x: Array, /) -> Device:
+def device(x: _ArrayApiObj, /) -> Device:
     """
     Hardware device the array data resides on.
 
@@ -669,7 +738,7 @@ def device(x: Array, /) -> Device:
         return "cpu"
     elif is_dask_array(x):
         # Peek at the metadata of the Dask array to determine type
-        if is_numpy_array(x._meta):
+        if is_numpy_array(x._meta):  # pyright: ignore
             # Must be on CPU since backed by numpy
             return "cpu"
         return _DASK_DEVICE
@@ -679,7 +748,7 @@ def device(x: Array, /) -> Device:
         #       Return None in this case. Note that this workaround breaks
         #       the standard and will result in new arrays being created on the
         #       default device instead of the same device as the input array(s).
-        x_device = getattr(x, 'device', None)
+        x_device = getattr(x, "device", None)
         # Older JAX releases had .device() as a method, which has been replaced
         # with a property in accordance with the standard.
         if inspect.ismethod(x_device):
@@ -688,27 +757,34 @@ def device(x: Array, /) -> Device:
             return x_device
     elif is_pydata_sparse_array(x):
         # `sparse` will gain `.device`, so check for this first.
-        x_device = getattr(x, 'device', None)
+        x_device = getattr(x, "device", None)
         if x_device is not None:
             return x_device
         # Everything but DOK has this attr.
         try:
-            inner = x.data
+            inner = x.data  # pyright: ignore
         except AttributeError:
             return "cpu"
         # Return the device of the constituent array
-        return device(inner)
-    return x.device
+        return device(inner)  # pyright: ignore
+    return x.device  # pyright: ignore
+
 
 # Prevent shadowing, used below
 _device = device
 
+
 # Based on cupy.array_api.Array.to_device
-def _cupy_to_device(x, device, /, stream=None):
-    import cupy as cp
-    from cupy.cuda import Device as _Device
-    from cupy.cuda import stream as stream_module
-    from cupy_backends.cuda.api import runtime
+def _cupy_to_device(
+    x: _CupyArray,
+    device: Device,
+    /,
+    stream: int | Any | None = None,
+) -> _CupyArray:
+    import cupy as cp  # pyright: ignore[reportMissingTypeStubs]
+    from cupy.cuda import Device as _Device  # pyright: ignore
+    from cupy.cuda import stream as stream_module  # pyright: ignore
+    from cupy_backends.cuda.api import runtime  # pyright: ignore
 
     if device == x.device:
         return x
@@ -721,33 +797,40 @@ def _cupy_to_device(x, device, /, stream=None):
         raise ValueError(f"Unsupported device {device!r}")
     else:
         # see cupy/cupy#5985 for the reason how we handle device/stream here
-        prev_device = runtime.getDevice()
-        prev_stream: stream_module.Stream = None
+        prev_device: Any = runtime.getDevice()  # pyright: ignore[reportUnknownMemberType]
+        prev_stream = None
         if stream is not None:
-            prev_stream = stream_module.get_current_stream()
+            prev_stream: Any = stream_module.get_current_stream()  # pyright: ignore
             # stream can be an int as specified in __dlpack__, or a CuPy stream
             if isinstance(stream, int):
-                stream = cp.cuda.ExternalStream(stream)
-            elif isinstance(stream, cp.cuda.Stream):
+                stream = cp.cuda.ExternalStream(stream)  # pyright: ignore
+            elif isinstance(stream, cp.cuda.Stream):  # pyright: ignore[reportUnknownMemberType]
                 pass
             else:
-                raise ValueError('the input stream is not recognized')
-            stream.use()
+                raise ValueError("the input stream is not recognized")
+            stream.use()  # pyright: ignore[reportUnknownMemberType]
         try:
-            runtime.setDevice(device.id)
+            runtime.setDevice(device.id)  # pyright: ignore[reportUnknownMemberType]
             arr = x.copy()
         finally:
-            runtime.setDevice(prev_device)
+            runtime.setDevice(prev_device)  # pyright: ignore[reportUnknownMemberType]
             if stream is not None:
                 prev_stream.use()
         return arr
 
-def _torch_to_device(x, device, /, stream=None):
+
+def _torch_to_device(
+    x: torch.Tensor,
+    device: torch.device | str | int,
+    /,
+    stream: None = None,
+) -> torch.Tensor:
     if stream is not None:
         raise NotImplementedError
     return x.to(device)
 
-def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] = None) -> Array:
+
+def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) -> Array:
     """
     Copy the array from the device on which it currently resides to the specified ``device``.
 
@@ -767,7 +850,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
         a ``device`` object (see the `Device Support <https://data-apis.org/array-api/latest/design_topics/device_support.html>`__
         section of the array API specification).
 
-    stream: Optional[Union[int, Any]]
+    stream: int | Any | None
         stream object to use during copy. In addition to the types supported
         in ``array.__dlpack__``, implementations may choose to support any
         library-specific stream object with the caveat that any code using
@@ -799,25 +882,26 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
     if is_numpy_array(x):
         if stream is not None:
             raise ValueError("The stream argument to to_device() is not supported")
-        if device == 'cpu':
+        if device == "cpu":
             return x
         raise ValueError(f"Unsupported device {device!r}")
     elif is_cupy_array(x):
         # cupy does not yet have to_device
         return _cupy_to_device(x, device, stream=stream)
     elif is_torch_array(x):
-        return _torch_to_device(x, device, stream=stream)
+        return _torch_to_device(x, device, stream=stream)  # pyright: ignore[reportArgumentType]
     elif is_dask_array(x):
         if stream is not None:
             raise ValueError("The stream argument to to_device() is not supported")
         # TODO: What if our array is on the GPU already?
-        if device == 'cpu':
+        if device == "cpu":
             return x
         raise ValueError(f"Unsupported device {device!r}")
     elif is_jax_array(x):
         if not hasattr(x, "__array_namespace__"):
             # In JAX v0.4.31 and older, this import adds to_device method to x...
-            import jax.experimental.array_api # noqa: F401
+            import jax.experimental.array_api  # noqa: F401  # pyright: ignore
+
             # ... but only on eager JAX. It won't work inside jax.jit.
             if not hasattr(x, "to_device"):
                 return x
@@ -826,10 +910,16 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
         # Perform trivial check to return the same array if
         # device is same instead of err-ing.
         return x
-    return x.to_device(device, stream=stream)
+    return x.to_device(device, stream=stream)  # pyright: ignore
 
 
-def size(x: Array) -> int | None:
+@overload
+def size(x: HasShape[Collection[SupportsIndex]]) -> int: ...
+@overload
+def size(x: HasShape[Collection[None]]) -> None: ...
+@overload
+def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: ...
+def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None:
     """
     Return the total number of elements of x.
 
@@ -844,7 +934,7 @@ def size(x: Array) -> int | None:
     # Lazy API compliant arrays, such as ndonnx, can contain None in their shape
     if None in x.shape:
         return None
-    out = math.prod(x.shape)
+    out = math.prod(cast("Collection[SupportsIndex]", x.shape))
     # dask.array.Array.shape can contain NaN
     return None if math.isnan(out) else out
 
@@ -907,7 +997,7 @@ def is_lazy_array(x: object) -> bool:
     # on __bool__ (dask is one such example, which however is special-cased above).
 
     # Select a single point of the array
-    s = size(x)
+    s = size(cast("HasShape[Collection[SupportsIndex | None]]", x))
     if s is None:
         return True
     xp = array_namespace(x)
@@ -952,4 +1042,7 @@ def is_lazy_array(x: object) -> bool:
     "to_device",
 ]
 
-_all_ignore = ['sys', 'math', 'inspect', 'warnings']
+_all_ignore = ["sys", "math", "inspect", "warnings"]
+
+def __dir__() -> list[str]:
+    return __all__
diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py
index d1e7ebd8..7e002aed 100644
--- a/array_api_compat/common/_linalg.py
+++ b/array_api_compat/common/_linalg.py
@@ -1,23 +1,33 @@
 from __future__ import annotations
 
 import math
-from typing import Literal, NamedTuple, Optional, Tuple, Union
+from typing import Literal, NamedTuple, cast
 
 import numpy as np
+
 if np.__version__[0] == "2":
     from numpy.lib.array_utils import normalize_axis_tuple
 else:
     from numpy.core.numeric import normalize_axis_tuple
 
-from ._aliases import matmul, matrix_transpose, tensordot, vecdot, isdtype
 from .._internal import get_xp
-from ._typing import Array, Namespace
+from ._aliases import isdtype, matmul, matrix_transpose, tensordot, vecdot
+from ._typing import Array, DType, Namespace
+
 
 # These are in the main NumPy namespace but not in numpy.linalg
-def cross(x1: Array, x2: Array, /, xp: Namespace, *, axis: int = -1, **kwargs) -> Array:
+def cross(
+    x1: Array,
+    x2: Array,
+    /,
+    xp: Namespace,
+    *,
+    axis: int = -1,
+    **kwargs: object,
+) -> Array:
     return xp.cross(x1, x2, axis=axis, **kwargs)
 
-def outer(x1: Array, x2: Array, /, xp: Namespace, **kwargs) -> Array:
+def outer(x1: Array, x2: Array, /, xp: Namespace, **kwargs: object) -> Array:
     return xp.outer(x1, x2, **kwargs)
 
 class EighResult(NamedTuple):
@@ -39,46 +49,66 @@ class SVDResult(NamedTuple):
 
 # These functions are the same as their NumPy counterparts except they return
 # a namedtuple.
-def eigh(x: Array, /, xp: Namespace, **kwargs) -> EighResult:
+def eigh(x: Array, /, xp: Namespace, **kwargs: object) -> EighResult:
     return EighResult(*xp.linalg.eigh(x, **kwargs))
 
-def qr(x: Array, /, xp: Namespace, *, mode: Literal['reduced', 'complete'] = 'reduced',
-       **kwargs) -> QRResult:
+def qr(
+    x: Array,
+    /,
+    xp: Namespace,
+    *,
+    mode: Literal["reduced", "complete"] = "reduced",
+    **kwargs: object,
+) -> QRResult:
     return QRResult(*xp.linalg.qr(x, mode=mode, **kwargs))
 
-def slogdet(x: Array, /, xp: Namespace, **kwargs) -> SlogdetResult:
+def slogdet(x: Array, /, xp: Namespace, **kwargs: object) -> SlogdetResult:
     return SlogdetResult(*xp.linalg.slogdet(x, **kwargs))
 
 def svd(
-    x: Array, /, xp: Namespace, *, full_matrices: bool = True, **kwargs
+    x: Array,
+    /,
+    xp: Namespace,
+    *,
+    full_matrices: bool = True,
+    **kwargs: object,
 ) -> SVDResult:
     return SVDResult(*xp.linalg.svd(x, full_matrices=full_matrices, **kwargs))
 
 # These functions have additional keyword arguments
 
 # The upper keyword argument is new from NumPy
-def cholesky(x: Array, /, xp: Namespace, *, upper: bool = False, **kwargs) -> Array:
+def cholesky(
+    x: Array,
+    /,
+    xp: Namespace,
+    *,
+    upper: bool = False,
+    **kwargs: object,
+) -> Array:
     L = xp.linalg.cholesky(x, **kwargs)
     if upper:
         U = get_xp(xp)(matrix_transpose)(L)
         if get_xp(xp)(isdtype)(U.dtype, 'complex floating'):
-            U = xp.conj(U)
+            U = xp.conj(U)  # pyright: ignore[reportConstantRedefinition]
         return U
     return L
 
 # The rtol keyword argument of matrix_rank() and pinv() is new from NumPy.
 # Note that it has a different semantic meaning from tol and rcond.
-def matrix_rank(x: Array,
-                /,
-                xp: Namespace,
-                *,
-                rtol: Optional[Union[float, Array]] = None,
-                **kwargs) -> Array:
+def matrix_rank(
+    x: Array,
+    /,
+    xp: Namespace,
+    *,
+    rtol: float | Array | None = None,
+    **kwargs: object,
+) -> Array:
     # this is different from xp.linalg.matrix_rank, which supports 1
     # dimensional arrays.
     if x.ndim < 2:
         raise xp.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional")
-    S = get_xp(xp)(svdvals)(x, **kwargs)
+    S: Array = get_xp(xp)(svdvals)(x, **kwargs)
     if rtol is None:
         tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * xp.finfo(S.dtype).eps
     else:
@@ -88,7 +118,12 @@ def matrix_rank(x: Array,
     return xp.count_nonzero(S > tol, axis=-1)
 
 def pinv(
-    x: Array, /, xp: Namespace, *, rtol: Optional[Union[float, Array]] = None, **kwargs
+    x: Array,
+    /,
+    xp: Namespace,
+    *,
+    rtol: float | Array | None = None,
+    **kwargs: object,
 ) -> Array:
     # this is different from xp.linalg.pinv, which does not multiply the
     # default tolerance by max(M, N).
@@ -104,13 +139,13 @@ def matrix_norm(
     xp: Namespace,
     *,
     keepdims: bool = False,
-    ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro',
+    ord: float | Literal["fro", "nuc"] | None = "fro",
 ) -> Array:
     return xp.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord)
 
 # svdvals is not in NumPy (but it is in SciPy). It is equivalent to
 # xp.linalg.svd(compute_uv=False).
-def svdvals(x: Array, /, xp: Namespace) -> Union[Array, Tuple[Array, ...]]:
+def svdvals(x: Array, /, xp: Namespace) -> Array | tuple[Array, ...]:
     return xp.linalg.svd(x, compute_uv=False)
 
 def vector_norm(
@@ -118,9 +153,9 @@ def vector_norm(
     /,
     xp: Namespace,
     *,
-    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    axis: int | tuple[int, ...] | None = None,
     keepdims: bool = False,
-    ord: Optional[Union[int, float]] = 2,
+    ord: float = 2,
 ) -> Array:
     # xp.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or
     # when axis=None and the input is 2-D, so to force a vector norm, we make
@@ -133,7 +168,10 @@ def vector_norm(
     elif isinstance(axis, tuple):
         # Note: The axis argument supports any number of axes, whereas
         # xp.linalg.norm() only supports a single axis for vector norm.
-        normalized_axis = normalize_axis_tuple(axis, x.ndim)
+        normalized_axis = cast(
+            "tuple[int, ...]",
+            normalize_axis_tuple(axis, x.ndim),  # pyright: ignore[reportCallIssue]
+        )
         rest = tuple(i for i in range(x.ndim) if i not in normalized_axis)
         newshape = axis + rest
         _x = xp.transpose(x, newshape).reshape(
@@ -149,7 +187,13 @@ def vector_norm(
         # We can't reuse xp.linalg.norm(keepdims) because of the reshape hacks
         # above to avoid matrix norm logic.
         shape = list(x.shape)
-        _axis = normalize_axis_tuple(range(x.ndim) if axis is None else axis, x.ndim)
+        _axis = cast(
+            "tuple[int, ...]",
+            normalize_axis_tuple(  # pyright: ignore[reportCallIssue]
+                range(x.ndim) if axis is None else axis,
+                x.ndim,
+            ),
+        )
         for i in _axis:
             shape[i] = 1
         res = xp.reshape(res, tuple(shape))
@@ -159,11 +203,17 @@ def vector_norm(
 # xp.diagonal and xp.trace operate on the first two axes whereas these
 # operates on the last two
 
-def diagonal(x: Array, /, xp: Namespace, *, offset: int = 0, **kwargs) -> Array:
+def diagonal(x: Array, /, xp: Namespace, *, offset: int = 0, **kwargs: object) -> Array:
     return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs)
 
 def trace(
-    x: Array, /, xp: Namespace, *, offset: int = 0, dtype=None, **kwargs
+    x: Array,
+    /,
+    xp: Namespace,
+    *,
+    offset: int = 0,
+    dtype: DType | None = None,
+    **kwargs: object,
 ) -> Array:
     return xp.asarray(
         xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs)
@@ -176,3 +226,7 @@ def trace(
            'trace']
 
 _all_ignore = ['math', 'normalize_axis_tuple', 'get_xp', 'np', 'isdtype']
+
+
+def __dir__() -> list[str]:
+    return __all__
diff --git a/array_api_compat/common/_typing.py b/array_api_compat/common/_typing.py
index 4c3b356b..d7deade1 100644
--- a/array_api_compat/common/_typing.py
+++ b/array_api_compat/common/_typing.py
@@ -1,24 +1,150 @@
 from __future__ import annotations
+
+from collections.abc import Mapping
 from types import ModuleType as Namespace
-from typing import Any, TypeVar, Protocol
+from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, TypedDict, TypeVar
+
+if TYPE_CHECKING:
+    from _typeshed import Incomplete
+
+    SupportsBufferProtocol: TypeAlias = Incomplete
+    Array: TypeAlias = Incomplete
+    Device: TypeAlias = Incomplete
+    DType: TypeAlias = Incomplete
+else:
+    SupportsBufferProtocol = object
+    Array = object
+    Device = object
+    DType = object
+
+
+_T_co = TypeVar("_T_co", covariant=True)
+
+
+class NestedSequence(Protocol[_T_co]):
+    def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
+    def __len__(self, /) -> int: ...
+
+
+class SupportsArrayNamespace(Protocol[_T_co]):
+    def __array_namespace__(self, /, *, api_version: str | None) -> _T_co: ...
+
+
+class HasShape(Protocol[_T_co]):
+    @property
+    def shape(self, /) -> _T_co: ...
+
+
+# Return type of `__array_namespace_info__.capabilities`
+Capabilities = TypedDict(
+    "Capabilities",
+    {
+        "boolean indexing": bool,
+        "data-dependent shapes": bool,
+        "max dimensions": int,
+    },
+)
+
+# Return type of `__array_namespace_info__.default_dtypes`
+DefaultDTypes = TypedDict(
+    "DefaultDTypes",
+    {
+        "real floating": DType,
+        "complex floating": DType,
+        "integral": DType,
+        "indexing": DType,
+    },
+)
+
+
+_DTypeKind: TypeAlias = Literal[
+    "bool",
+    "signed integer",
+    "unsigned integer",
+    "integral",
+    "real floating",
+    "complex floating",
+    "numeric",
+]
+# Type of the `kind` parameter in `__array_namespace_info__.dtypes`
+DTypeKind: TypeAlias = _DTypeKind | tuple[_DTypeKind, ...]
+
+
+# `__array_namespace_info__.dtypes(kind="bool")`
+class DTypesBool(TypedDict):
+    bool: DType
+
+
+# `__array_namespace_info__.dtypes(kind="signed integer")`
+class DTypesSigned(TypedDict):
+    int8: DType
+    int16: DType
+    int32: DType
+    int64: DType
+
+
+# `__array_namespace_info__.dtypes(kind="unsigned integer")`
+class DTypesUnsigned(TypedDict):
+    uint8: DType
+    uint16: DType
+    uint32: DType
+    uint64: DType
+
+
+# `__array_namespace_info__.dtypes(kind="integral")`
+class DTypesIntegral(DTypesSigned, DTypesUnsigned):
+    pass
+
+
+# `__array_namespace_info__.dtypes(kind="real floating")`
+class DTypesReal(TypedDict):
+    float32: DType
+    float64: DType
+
+
+# `__array_namespace_info__.dtypes(kind="complex floating")`
+class DTypesComplex(TypedDict):
+    complex64: DType
+    complex128: DType
+
+
+# `__array_namespace_info__.dtypes(kind="numeric")`
+class DTypesNumeric(DTypesIntegral, DTypesReal, DTypesComplex):
+    pass
+
+
+# `__array_namespace_info__.dtypes(kind=None)` (default)
+class DTypesAll(DTypesBool, DTypesNumeric):
+    pass
+
+
+# `__array_namespace_info__.dtypes(kind=?)` (fallback)
+DTypesAny: TypeAlias = Mapping[str, DType]
+
 
 __all__ = [
     "Array",
+    "Capabilities",
     "DType",
+    "DTypeKind",
+    "DTypesAny",
+    "DTypesAll",
+    "DTypesBool",
+    "DTypesNumeric",
+    "DTypesIntegral",
+    "DTypesSigned",
+    "DTypesUnsigned",
+    "DTypesReal",
+    "DTypesComplex",
+    "DefaultDTypes",
     "Device",
+    "HasShape",
     "Namespace",
     "NestedSequence",
+    "SupportsArrayNamespace",
     "SupportsBufferProtocol",
 ]
 
-_T_co = TypeVar("_T_co", covariant=True)
-
-class NestedSequence(Protocol[_T_co]):
-    def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
-    def __len__(self, /) -> int: ...
-
 
-SupportsBufferProtocol = Any
-Array = Any
-Device = Any
-DType = Any
+def __dir__() -> list[str]:
+    return __all__
diff --git a/array_api_compat/dask/array/__init__.py b/array_api_compat/dask/array/__init__.py
index bb649306..1e47b960 100644
--- a/array_api_compat/dask/array/__init__.py
+++ b/array_api_compat/dask/array/__init__.py
@@ -1,9 +1,11 @@
-from dask.array import * # noqa: F403
+from typing import Final
+
+from dask.array import *  # noqa: F403
 
 # These imports may overwrite names from the import * above.
-from ._aliases import * # noqa: F403
+from ._aliases import *  # noqa: F403
 
-__array_api_version__ = '2024.12'
+__array_api_version__: Final = "2024.12"
 
 # See the comment in the numpy __init__.py
 __import__(__package__ + '.linalg')
diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py
index e7ddde78..9687a9cd 100644
--- a/array_api_compat/dask/array/_aliases.py
+++ b/array_api_compat/dask/array/_aliases.py
@@ -1,28 +1,38 @@
+# pyright: reportPrivateUsage=false
+# pyright: reportUnknownArgumentType=false
+# pyright: reportUnknownMemberType=false
+# pyright: reportUnknownVariableType=false
+
 from __future__ import annotations
 
-from typing import Callable, Optional, Union
+from builtins import bool as py_bool
+from collections.abc import Callable
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+    from typing_extensions import TypeIs
 
+import dask.array as da
 import numpy as np
+from numpy import bool_ as bool
 from numpy import (
-    # dtypes
-    bool_ as bool,
+    can_cast,
+    complex64,
+    complex128,
     float32,
     float64,
     int8,
     int16,
     int32,
     int64,
+    result_type,
     uint8,
     uint16,
     uint32,
     uint64,
-    complex64,
-    complex128,
-    can_cast,
-    result_type,
 )
-import dask.array as da
 
+from ..._internal import get_xp
 from ...common import _aliases, _helpers, array_namespace
 from ...common._typing import (
     Array,
@@ -31,7 +41,6 @@
     NestedSequence,
     SupportsBufferProtocol,
 )
-from ..._internal import get_xp
 from ._info import __array_namespace_info__
 
 isdtype = get_xp(np)(_aliases.isdtype)
@@ -44,8 +53,8 @@ def astype(
     dtype: DType,
     /,
     *,
-    copy: bool = True,
-    device: Optional[Device] = None,
+    copy: py_bool = True,
+    device: Device | None = None,
 ) -> Array:
     """
     Array API compatibility wrapper for astype().
@@ -69,14 +78,14 @@ def astype(
 # not pass stop/step as keyword arguments, which will cause
 # an error with dask
 def arange(
-    start: Union[int, float],
+    start: float,
     /,
-    stop: Optional[Union[int, float]] = None,
-    step: Union[int, float] = 1,
+    stop: float | None = None,
+    step: float = 1,
     *,
-    dtype: Optional[DType] = None,
-    device: Optional[Device] = None,
-    **kwargs,
+    dtype: DType | None = None,
+    device: Device | None = None,
+    **kwargs: object,
 ) -> Array:
     """
     Array API compatibility wrapper for arange().
@@ -87,7 +96,7 @@ def arange(
     # TODO: respect device keyword?
     _helpers._check_device(da, device)
 
-    args = [start]
+    args: list[Any] = [start]
     if stop is not None:
         args.append(stop)
     else:
@@ -137,18 +146,13 @@ def arange(
 
 # asarray also adds the copy keyword, which is not present in numpy 1.0.
 def asarray(
-    obj: (
-        Array 
-        | bool | int | float | complex 
-        | NestedSequence[bool | int | float | complex] 
-        | SupportsBufferProtocol
-    ),
+    obj: complex | NestedSequence[complex] | Array | SupportsBufferProtocol,
     /,
     *,
-    dtype: Optional[DType] = None,
-    device: Optional[Device] = None,
-    copy: Optional[bool] = None,
-    **kwargs,
+    dtype: DType | None = None,
+    device: Device | None = None,
+    copy: py_bool | None = None,
+    **kwargs: object,
 ) -> Array:
     """
     Array API compatibility wrapper for asarray().
@@ -164,7 +168,7 @@ def asarray(
             if copy is False:
                 raise ValueError("Unable to avoid copy when changing dtype")
             obj = obj.astype(dtype)
-        return obj.copy() if copy else obj
+        return obj.copy() if copy else obj  # pyright: ignore[reportAttributeAccessIssue]
 
     if copy is False:
         raise NotImplementedError(
@@ -177,22 +181,21 @@ def asarray(
     return da.from_array(obj)
 
 
-from dask.array import (
-    # Element wise aliases
-    arccos as acos,
-    arccosh as acosh,
-    arcsin as asin,
-    arcsinh as asinh,
-    arctan as atan,
-    arctan2 as atan2,
-    arctanh as atanh,
-    left_shift as bitwise_left_shift,
-    right_shift as bitwise_right_shift,
-    invert as bitwise_invert,
-    power as pow,
-    # Other
-    concatenate as concat,
-)
+# Element wise aliases
+from dask.array import arccos as acos
+from dask.array import arccosh as acosh
+from dask.array import arcsin as asin
+from dask.array import arcsinh as asinh
+from dask.array import arctan as atan
+from dask.array import arctan2 as atan2
+from dask.array import arctanh as atanh
+
+# Other
+from dask.array import concatenate as concat
+from dask.array import invert as bitwise_invert
+from dask.array import left_shift as bitwise_left_shift
+from dask.array import power as pow
+from dask.array import right_shift as bitwise_right_shift
 
 
 # dask.array.clip does not work unless all three arguments are provided.
@@ -202,8 +205,8 @@ def asarray(
 def clip(
     x: Array,
     /,
-    min: Optional[Union[int, float, Array]] = None,
-    max: Optional[Union[int, float, Array]] = None,
+    min: float | Array | None = None,
+    max: float | Array | None = None,
 ) -> Array:
     """
     Array API compatibility wrapper for clip().
@@ -212,8 +215,8 @@ def clip(
     specification for more details.
     """
 
-    def _isscalar(a):
-        return isinstance(a, (int, float, type(None)))
+    def _isscalar(a: float | Array | None, /) -> TypeIs[float | None]:
+        return a is None or isinstance(a, (int, float))
 
     min_shape = () if _isscalar(min) else min.shape
     max_shape = () if _isscalar(max) else max.shape
@@ -266,7 +269,12 @@ def _ensure_single_chunk(x: Array, axis: int) -> tuple[Array, Callable[[Array],
 
 
 def sort(
-    x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True
+    x: Array,
+    /,
+    *,
+    axis: int = -1,
+    descending: py_bool = False,
+    stable: py_bool = True,
 ) -> Array:
     """
     Array API compatibility layer around the lack of sort() in Dask.
@@ -296,7 +304,12 @@ def sort(
 
 
 def argsort(
-    x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True
+    x: Array,
+    /,
+    *,
+    axis: int = -1,
+    descending: py_bool = False,
+    stable: py_bool = True,
 ) -> Array:
     """
     Array API compatibility layer around the lack of argsort() in Dask.
@@ -330,25 +343,34 @@ def argsort(
 # dask.array.count_nonzero does not have keepdims
 def count_nonzero(
     x: Array,
-    axis=None,
-    keepdims=False
+    axis: int | tuple[int, ...] | None = None,
+    keepdims: py_bool = False,
 ) -> Array:
-   result = da.count_nonzero(x, axis)
-   if keepdims:
-       if axis is None:
-            return da.reshape(result, [1]*x.ndim)
-       return da.expand_dims(result, axis)
-   return result
-
-
+    result = da.count_nonzero(x, axis)
+    if keepdims:
+        if axis is None:
+            return da.reshape(result, [1] * x.ndim)
+        return da.expand_dims(result, axis)
+    return result
+
+
+__all__ = [
+    "__array_namespace_info__",
+    "count_nonzero",
+    "bool",
+    "int8", "int16", "int32", "int64",
+    "uint8", "uint16", "uint32", "uint64",
+    "float32", "float64",
+    "complex64", "complex128",
+    "asarray", "astype", "can_cast", "result_type",
+    "pow",
+    "concat",
+    "acos", "acosh", "asin", "asinh", "atan", "atan2", "atanh",
+    "bitwise_left_shift", "bitwise_right_shift", "bitwise_invert",
+]  # fmt: skip
+__all__ += _aliases.__all__
+_all_ignore = ["array_namespace", "get_xp", "da", "np"]
 
-__all__ = _aliases.__all__ + [
-                    '__array_namespace_info__', 'asarray', 'astype', 'acos',
-                    'acosh', 'asin', 'asinh', 'atan', 'atan2',
-                    'atanh', 'bitwise_left_shift', 'bitwise_invert',
-                    'bitwise_right_shift', 'concat', 'pow', 'can_cast',
-                    'result_type', 'bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64',
-                    'uint8', 'uint16', 'uint32', 'uint64', 'complex64', 'complex128',
-                    'can_cast', 'count_nonzero', 'result_type']
 
-_all_ignore = ["array_namespace", "get_xp", "da", "np"]
+def __dir__() -> list[str]:
+    return __all__
diff --git a/array_api_compat/dask/array/_info.py b/array_api_compat/dask/array/_info.py
index 614f43d9..9e4d736f 100644
--- a/array_api_compat/dask/array/_info.py
+++ b/array_api_compat/dask/array/_info.py
@@ -7,25 +7,51 @@
 more details.
 
 """
+
+# pyright: reportPrivateUsage=false
+
+from __future__ import annotations
+
+from typing import Literal as L
+from typing import TypeAlias, overload
+
+from numpy import bool_ as bool
 from numpy import (
+    complex64,
+    complex128,
     dtype,
-    bool_ as bool,
-    intp,
+    float32,
+    float64,
     int8,
     int16,
     int32,
     int64,
+    intp,
     uint8,
     uint16,
     uint32,
     uint64,
-    float32,
-    float64,
-    complex64,
-    complex128,
 )
 
-from ...common._helpers import _DASK_DEVICE
+from ...common._helpers import _DASK_DEVICE, _dask_device
+from ...common._typing import (
+    Capabilities,
+    DefaultDTypes,
+    DType,
+    DTypeKind,
+    DTypesAll,
+    DTypesAny,
+    DTypesBool,
+    DTypesComplex,
+    DTypesIntegral,
+    DTypesNumeric,
+    DTypesReal,
+    DTypesSigned,
+    DTypesUnsigned,
+)
+
+_Device: TypeAlias = L["cpu"] | _dask_device
+
 
 class __array_namespace_info__:
     """
@@ -59,9 +85,9 @@ class __array_namespace_info__:
 
     """
 
-    __module__ = 'dask.array'
+    __module__ = "dask.array"
 
-    def capabilities(self):
+    def capabilities(self) -> Capabilities:
         """
         Return a dictionary of array API library capabilities.
 
@@ -116,7 +142,7 @@ def capabilities(self):
             "max dimensions": 64,
         }
 
-    def default_device(self):
+    def default_device(self) -> L["cpu"]:
         """
         The default device used for new Dask arrays.
 
@@ -143,7 +169,7 @@ def default_device(self):
         """
         return "cpu"
 
-    def default_dtypes(self, *, device=None):
+    def default_dtypes(self, /, *, device: _Device | None = None) -> DefaultDTypes:
         """
         The default data types used for new Dask arrays.
 
@@ -184,8 +210,8 @@ def default_dtypes(self, *, device=None):
         """
         if device not in ["cpu", _DASK_DEVICE, None]:
             raise ValueError(
-                'Device not understood. Only "cpu" or _DASK_DEVICE is allowed, but received:'
-                f' {device}'
+                'Device not understood. Only "cpu" or _DASK_DEVICE is allowed, '
+                f"but received: {device!r}"
             )
         return {
             "real floating": dtype(float64),
@@ -194,7 +220,41 @@ def default_dtypes(self, *, device=None):
             "indexing": dtype(intp),
         }
 
-    def dtypes(self, *, device=None, kind=None):
+    @overload
+    def dtypes(
+        self, /, *, device: _Device | None = None, kind: None = None
+    ) -> DTypesAll: ...
+    @overload
+    def dtypes(
+        self, /, *, device: _Device | None = None, kind: L["bool"]
+    ) -> DTypesBool: ...
+    @overload
+    def dtypes(
+        self, /, *, device: _Device | None = None, kind: L["signed integer"]
+    ) -> DTypesSigned: ...
+    @overload
+    def dtypes(
+        self, /, *, device: _Device | None = None, kind: L["unsigned integer"]
+    ) -> DTypesUnsigned: ...
+    @overload
+    def dtypes(
+        self, /, *, device: _Device | None = None, kind: L["integral"]
+    ) -> DTypesIntegral: ...
+    @overload
+    def dtypes(
+        self, /, *, device: _Device | None = None, kind: L["real floating"]
+    ) -> DTypesReal: ...
+    @overload
+    def dtypes(
+        self, /, *, device: _Device | None = None, kind: L["complex floating"]
+    ) -> DTypesComplex: ...
+    @overload
+    def dtypes(
+        self, /, *, device: _Device | None = None, kind: L["numeric"]
+    ) -> DTypesNumeric: ...
+    def dtypes(
+        self, /, *, device: _Device | None = None, kind: DTypeKind | None = None
+    ) -> DTypesAny:
         """
         The array API data types supported by Dask.
 
@@ -251,7 +311,7 @@ def dtypes(self, *, device=None, kind=None):
         if device not in ["cpu", _DASK_DEVICE, None]:
             raise ValueError(
                 'Device not understood. Only "cpu" or _DASK_DEVICE is allowed, but received:'
-                f' {device}'
+                f" {device}"
             )
         if kind is None:
             return {
@@ -321,14 +381,14 @@ def dtypes(self, *, device=None, kind=None):
                 "complex64": dtype(complex64),
                 "complex128": dtype(complex128),
             }
-        if isinstance(kind, tuple):
-            res = {}
+        if isinstance(kind, tuple):  # pyright: ignore[reportUnnecessaryIsinstanceCall]
+            res: dict[str, DType] = {}
             for k in kind:
                 res.update(self.dtypes(kind=k))
             return res
         raise ValueError(f"unsupported kind: {kind!r}")
 
-    def devices(self):
+    def devices(self) -> list[_Device]:
         """
         The devices supported by Dask.
 
diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py
index bd53f0df..0825386e 100644
--- a/array_api_compat/dask/array/linalg.py
+++ b/array_api_compat/dask/array/linalg.py
@@ -3,15 +3,16 @@
 from typing import Literal
 
 import dask.array as da
+
+# The `matmul` and `tensordot` functions are in both the main and linalg namespaces
+from dask.array import matmul, outer, tensordot
+
 # Exports
-from dask.array.linalg import * # noqa: F403
-from dask.array import outer
-# These functions are in both the main and linalg namespaces
-from dask.array import matmul, tensordot
+from dask.array.linalg import *  # noqa: F403
 
 from ..._internal import get_xp
 from ...common import _linalg
-from ...common._typing import Array
+from ...common._typing import Array as _Array
 from ._aliases import matrix_transpose, vecdot
 
 # dask.array.linalg doesn't have __all__. If it is added, replace this with
@@ -32,8 +33,11 @@
 # supports the mode keyword on QR
 # https://github.com/dask/dask/issues/10388
 #qr = get_xp(da)(_linalg.qr)
-def qr(x: Array, mode: Literal['reduced', 'complete'] = 'reduced',
-       **kwargs) -> QRResult:
+def qr(
+    x: _Array,
+    mode: Literal["reduced", "complete"] = "reduced",
+    **kwargs: object,
+) -> QRResult:
     if mode != "reduced":
         raise ValueError("dask arrays only support using mode='reduced'")
     return QRResult(*da.linalg.qr(x, **kwargs))
@@ -46,12 +50,12 @@ def qr(x: Array, mode: Literal['reduced', 'complete'] = 'reduced',
 # Wrap the svd functions to not pass full_matrices to dask
 # when full_matrices=False (as that is the default behavior for dask),
 # and dask doesn't have the full_matrices keyword
-def svd(x: Array, full_matrices: bool = True, **kwargs) -> SVDResult:
+def svd(x: _Array, full_matrices: bool = True, **kwargs: object) -> SVDResult:
     if full_matrices:
         raise ValueError("full_matrics=True is not supported by dask.")
     return da.linalg.svd(x, coerce_signs=False, **kwargs)
 
-def svdvals(x: Array) -> Array:
+def svdvals(x: _Array) -> _Array:
     # TODO: can't avoid computing U or V for dask
     _, s, _ =  svd(x)
     return s
diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py
index 6a5d9867..f7b558ba 100644
--- a/array_api_compat/numpy/__init__.py
+++ b/array_api_compat/numpy/__init__.py
@@ -1,10 +1,16 @@
-from numpy import * # noqa: F403
+# ruff: noqa: PLC0414
+from typing import Final
+
+from numpy import *  # noqa: F403  # pyright: ignore[reportWildcardImportFromLibrary]
 
 # from numpy import * doesn't overwrite these builtin names
-from numpy import abs, max, min, round # noqa: F401
+from numpy import abs as abs
+from numpy import max as max
+from numpy import min as min
+from numpy import round as round
 
 # These imports may overwrite names from the import * above.
-from ._aliases import * # noqa: F403
+from ._aliases import *  # noqa: F403
 
 # Don't know why, but we have to do an absolute import to import linalg. If we
 # instead do
@@ -13,9 +19,17 @@
 #
 # It doesn't overwrite np.linalg from above. The import is generated
 # dynamically so that the library can be vendored.
-__import__(__package__ + '.linalg')
-__import__(__package__ + '.fft')
+__import__(__package__ + ".linalg")
+
+__import__(__package__ + ".fft")
+
+from ..common._helpers import *  # noqa: F403
+from .linalg import matrix_transpose, vecdot  # noqa: F401
 
-from .linalg import matrix_transpose, vecdot # noqa: F401
+try:
+    # Used in asarray(). Not present in older versions.
+    from numpy import _CopyMode  # noqa: F401
+except ImportError:
+    pass
 
-__array_api_version__ = '2024.12'
+__array_api_version__: Final = "2024.12"
diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py
index d1fd46a1..d8792611 100644
--- a/array_api_compat/numpy/_aliases.py
+++ b/array_api_compat/numpy/_aliases.py
@@ -1,6 +1,10 @@
+# pyright: reportPrivateUsage=false
 from __future__ import annotations
 
-from typing import Optional, Union
+from builtins import bool as py_bool
+from typing import TYPE_CHECKING, Any, Literal, TypeAlias, cast
+
+import numpy as np
 
 from .._internal import get_xp
 from ..common import _aliases, _helpers
@@ -8,7 +12,12 @@
 from ._info import __array_namespace_info__
 from ._typing import Array, Device, DType
 
-import numpy as np
+if TYPE_CHECKING:
+    from typing_extensions import Buffer, TypeIs
+
+# The values of the `_CopyMode` enum can be either `False`, `True`, or `2`:
+# https://github.com/numpy/numpy/blob/5a8a6a79d9c2fff8f07dcab5d41e14f8508d673f/numpy/_globals.pyi#L7-L10
+_Copy: TypeAlias = py_bool | Literal[2] | np._CopyMode
 
 bool = np.bool_
 
@@ -65,9 +74,9 @@
 iinfo = get_xp(np)(_aliases.iinfo)
 
 
-def _supports_buffer_protocol(obj):
+def _supports_buffer_protocol(obj: object) -> TypeIs[Buffer]:  # pyright: ignore[reportUnusedFunction]
     try:
-        memoryview(obj)
+        memoryview(obj)  # pyright: ignore[reportArgumentType]
     except TypeError:
         return False
     return True
@@ -78,18 +87,13 @@ def _supports_buffer_protocol(obj):
 # complicated enough that it's easier to define it separately for each module
 # rather than trying to combine everything into one function in common/
 def asarray(
-    obj: (
-        Array 
-        | bool | int | float | complex 
-        | NestedSequence[bool | int | float | complex] 
-        | SupportsBufferProtocol
-    ),
+    obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol,
     /,
     *,
-    dtype: Optional[DType] = None,
-    device: Optional[Device] = None,
-    copy: Optional[Union[bool, np._CopyMode]] = None,
-    **kwargs,
+    dtype: DType | None = None,
+    device: Device | None = None,
+    copy: _Copy | None = None,
+    **kwargs: Any,
 ) -> Array:
     """
     Array API compatibility wrapper for asarray().
@@ -106,7 +110,7 @@ def asarray(
     elif copy is True:
         copy = np._CopyMode.ALWAYS
 
-    return np.array(obj, copy=copy, dtype=dtype, **kwargs)
+    return np.array(obj, copy=copy, dtype=dtype, **kwargs)  # pyright: ignore
 
 
 def astype(
@@ -114,8 +118,8 @@ def astype(
     dtype: DType,
     /,
     *,
-    copy: bool = True,
-    device: Optional[Device] = None,
+    copy: py_bool = True,
+    device: Device | None = None,
 ) -> Array:
     _helpers._check_device(np, device)
     return x.astype(dtype=dtype, copy=copy)
@@ -123,8 +127,14 @@ def astype(
 
 # count_nonzero returns a python int for axis=None and keepdims=False
 # https://github.com/numpy/numpy/issues/17562
-def count_nonzero(x: Array, axis=None, keepdims=False) -> Array:
-    result = np.count_nonzero(x, axis=axis, keepdims=keepdims)
+def count_nonzero(
+    x: Array,
+    axis: int | tuple[int, ...] | None = None,
+    keepdims: py_bool = False,
+) -> Array:
+    # NOTE: this is currently incorrectly typed in numpy, but will be fixed in
+    # numpy 2.2.5 and 2.3.0: https://github.com/numpy/numpy/pull/28750
+    result = cast("Any", np.count_nonzero(x, axis=axis, keepdims=keepdims))  # pyright: ignore[reportArgumentType, reportCallIssue]
     if axis is None and not keepdims:
         return np.asarray(result)
     return result
@@ -132,25 +142,43 @@ def count_nonzero(x: Array, axis=None, keepdims=False) -> Array:
 
 # These functions are completely new here. If the library already has them
 # (i.e., numpy 2.0), use the library version instead of our wrapper.
-if hasattr(np, 'vecdot'):
+if hasattr(np, "vecdot"):
     vecdot = np.vecdot
 else:
     vecdot = get_xp(np)(_aliases.vecdot)
 
-if hasattr(np, 'isdtype'):
+if hasattr(np, "isdtype"):
     isdtype = np.isdtype
 else:
     isdtype = get_xp(np)(_aliases.isdtype)
 
-if hasattr(np, 'unstack'):
+if hasattr(np, "unstack"):
     unstack = np.unstack
 else:
     unstack = get_xp(np)(_aliases.unstack)
 
-__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'astype',
-                              'acos', 'acosh', 'asin', 'asinh', 'atan',
-                              'atan2', 'atanh', 'bitwise_left_shift',
-                              'bitwise_invert', 'bitwise_right_shift',
-                              'bool', 'concat', 'count_nonzero', 'pow']
-
-_all_ignore = ['np', 'get_xp']
+__all__ = [
+    "__array_namespace_info__",
+    "asarray",
+    "astype",
+    "acos",
+    "acosh",
+    "asin",
+    "asinh",
+    "atan",
+    "atan2",
+    "atanh",
+    "bitwise_left_shift",
+    "bitwise_invert",
+    "bitwise_right_shift",
+    "bool",
+    "concat",
+    "count_nonzero",
+    "pow",
+]
+__all__ += _aliases.__all__
+_all_ignore = ["np", "get_xp"]
+
+
+def __dir__() -> list[str]:
+    return __all__
diff --git a/array_api_compat/numpy/_info.py b/array_api_compat/numpy/_info.py
index 365855b8..f307f62c 100644
--- a/array_api_compat/numpy/_info.py
+++ b/array_api_compat/numpy/_info.py
@@ -7,24 +7,28 @@
 more details.
 
 """
+from __future__ import annotations
+
+from numpy import bool_ as bool
 from numpy import (
+    complex64,
+    complex128,
     dtype,
-    bool_ as bool,
-    intp,
+    float32,
+    float64,
     int8,
     int16,
     int32,
     int64,
+    intp,
     uint8,
     uint16,
     uint32,
     uint64,
-    float32,
-    float64,
-    complex64,
-    complex128,
 )
 
+from ._typing import Device, DType
+
 
 class __array_namespace_info__:
     """
@@ -131,7 +135,11 @@ def default_device(self):
         """
         return "cpu"
 
-    def default_dtypes(self, *, device=None):
+    def default_dtypes(
+        self,
+        *,
+        device: Device | None = None,
+    ) -> dict[str, dtype[intp | float64 | complex128]]:
         """
         The default data types used for new NumPy arrays.
 
@@ -183,7 +191,12 @@ def default_dtypes(self, *, device=None):
             "indexing": dtype(intp),
         }
 
-    def dtypes(self, *, device=None, kind=None):
+    def dtypes(
+        self,
+        *,
+        device: Device | None = None,
+        kind: str | tuple[str, ...] | None = None,
+    ) -> dict[str, DType]:
         """
         The array API data types supported by NumPy.
 
@@ -260,7 +273,7 @@ def dtypes(self, *, device=None, kind=None):
                 "complex128": dtype(complex128),
             }
         if kind == "bool":
-            return {"bool": bool}
+            return {"bool": dtype(bool)}
         if kind == "signed integer":
             return {
                 "int8": dtype(int8),
@@ -312,13 +325,13 @@ def dtypes(self, *, device=None, kind=None):
                 "complex128": dtype(complex128),
             }
         if isinstance(kind, tuple):
-            res = {}
+            res: dict[str, DType] = {}
             for k in kind:
                 res.update(self.dtypes(kind=k))
             return res
         raise ValueError(f"unsupported kind: {kind!r}")
 
-    def devices(self):
+    def devices(self) -> list[Device]:
         """
         The devices supported by NumPy.
 
@@ -344,3 +357,10 @@ def devices(self):
 
         """
         return ["cpu"]
+
+
+__all__ = ["__array_namespace_info__"]
+
+
+def __dir__() -> list[str]:
+    return __all__
diff --git a/array_api_compat/numpy/_typing.py b/array_api_compat/numpy/_typing.py
index a6c96924..e771c788 100644
--- a/array_api_compat/numpy/_typing.py
+++ b/array_api_compat/numpy/_typing.py
@@ -1,31 +1,30 @@
 from __future__ import annotations
 
-__all__ = ["Array", "DType", "Device"]
-_all_ignore = ["np"]
-
-from typing import Literal, TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Literal, TypeAlias
 
 import numpy as np
-from numpy import ndarray as Array
 
-Device = Literal["cpu"]
+Device: TypeAlias = Literal["cpu"]
+
 if TYPE_CHECKING:
+
     # NumPy 1.x on Python 3.10 fails to parse np.dtype[]
-    DType = np.dtype[
-        np.intp
-        | np.int8
-        | np.int16
-        | np.int32
-        | np.int64
-        | np.uint8
-        | np.uint16
-        | np.uint32
-        | np.uint64
+    DType: TypeAlias = np.dtype[
+        np.bool_
+        | np.integer[Any]
         | np.float32
         | np.float64
         | np.complex64
         | np.complex128
-        | np.bool
     ]
+    Array: TypeAlias = np.ndarray[Any, DType]
 else:
-    DType = np.dtype
+    DType: TypeAlias = np.dtype
+    Array: TypeAlias = np.ndarray
+
+__all__ = ["Array", "DType", "Device"]
+_all_ignore = ["np"]
+
+
+def __dir__() -> list[str]:
+    return __all__
diff --git a/array_api_compat/numpy/fft.py b/array_api_compat/numpy/fft.py
index 28667594..06875f00 100644
--- a/array_api_compat/numpy/fft.py
+++ b/array_api_compat/numpy/fft.py
@@ -1,10 +1,9 @@
-from numpy.fft import * # noqa: F403
+import numpy as np
 from numpy.fft import __all__ as fft_all
+from numpy.fft import fft2, ifft2, irfft2, rfft2
 
-from ..common import _fft
 from .._internal import get_xp
-
-import numpy as np
+from ..common import _fft
 
 fft = get_xp(np)(_fft.fft)
 ifft = get_xp(np)(_fft.ifft)
@@ -21,7 +20,14 @@
 fftshift = get_xp(np)(_fft.fftshift)
 ifftshift = get_xp(np)(_fft.ifftshift)
 
-__all__ = fft_all + _fft.__all__
+
+__all__ = ["rfft2", "irfft2", "fft2", "ifft2"]
+__all__ += _fft.__all__
+
+
+def __dir__() -> list[str]:
+    return __all__
+
 
 del get_xp
 del np
diff --git a/array_api_compat/numpy/linalg.py b/array_api_compat/numpy/linalg.py
index 8f01593b..2d3e731d 100644
--- a/array_api_compat/numpy/linalg.py
+++ b/array_api_compat/numpy/linalg.py
@@ -1,14 +1,35 @@
-from numpy.linalg import * # noqa: F403
-from numpy.linalg import __all__ as linalg_all
-import numpy as _np
+# pyright: reportAttributeAccessIssue=false
+# pyright: reportUnknownArgumentType=false
+# pyright: reportUnknownMemberType=false
+# pyright: reportUnknownVariableType=false
+
+from __future__ import annotations
+
+import numpy as np
+
+# intersection of `np.linalg.__all__` on numpy 1.22 and 2.2, minus `_linalg.__all__`
+from numpy.linalg import (
+    LinAlgError,
+    cond,
+    det,
+    eig,
+    eigvals,
+    eigvalsh,
+    inv,
+    lstsq,
+    matrix_power,
+    multi_dot,
+    norm,
+    tensorinv,
+    tensorsolve,
+)
 
-from ..common import _linalg
 from .._internal import get_xp
+from ..common import _linalg
 
 # These functions are in both the main and linalg namespaces
-from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401
-
-import numpy as np
+from ._aliases import matmul, matrix_transpose, tensordot, vecdot  # noqa: F401
+from ._typing import Array
 
 cross = get_xp(np)(_linalg.cross)
 outer = get_xp(np)(_linalg.outer)
@@ -38,19 +59,28 @@
 # To workaround this, the below is the code from np.linalg.solve except
 # only calling solve1 in the exactly 1D case.
 
+
 # This code is here instead of in common because it is numpy specific. Also
 # note that CuPy's solve() does not currently support broadcasting (see
 # https://github.com/cupy/cupy/blob/main/cupy/cublas.py#L43).
-def solve(x1: _np.ndarray, x2: _np.ndarray, /) -> _np.ndarray:
+def solve(x1: Array, x2: Array, /) -> Array:
     try:
         from numpy.linalg._linalg import (
-        _makearray, _assert_stacked_2d, _assert_stacked_square,
-        _commonType, isComplexType, _raise_linalgerror_singular
+            _assert_stacked_2d,
+            _assert_stacked_square,
+            _commonType,
+            _makearray,
+            _raise_linalgerror_singular,
+            isComplexType,
         )
     except ImportError:
         from numpy.linalg.linalg import (
-        _makearray, _assert_stacked_2d, _assert_stacked_square,
-        _commonType, isComplexType, _raise_linalgerror_singular
+            _assert_stacked_2d,
+            _assert_stacked_square,
+            _commonType,
+            _makearray,
+            _raise_linalgerror_singular,
+            isComplexType,
         )
     from numpy.linalg import _umath_linalg
 
@@ -61,6 +91,7 @@ def solve(x1: _np.ndarray, x2: _np.ndarray, /) -> _np.ndarray:
     t, result_t = _commonType(x1, x2)
 
     # This part is different from np.linalg.solve
+    gufunc: np.ufunc
     if x2.ndim == 1:
         gufunc = _umath_linalg.solve1
     else:
@@ -68,23 +99,45 @@ def solve(x1: _np.ndarray, x2: _np.ndarray, /) -> _np.ndarray:
 
     # This does nothing currently but is left in because it will be relevant
     # when complex dtype support is added to the spec in 2022.
-    signature = 'DD->D' if isComplexType(t) else 'dd->d'
-    with _np.errstate(call=_raise_linalgerror_singular, invalid='call',
-                      over='ignore', divide='ignore', under='ignore'):
-        r = gufunc(x1, x2, signature=signature)
+    signature = "DD->D" if isComplexType(t) else "dd->d"
+    with np.errstate(
+        call=_raise_linalgerror_singular,
+        invalid="call",
+        over="ignore",
+        divide="ignore",
+        under="ignore",
+    ):
+        r: Array = gufunc(x1, x2, signature=signature)
 
     return wrap(r.astype(result_t, copy=False))
 
+
 # These functions are completely new here. If the library already has them
 # (i.e., numpy 2.0), use the library version instead of our wrapper.
-if hasattr(np.linalg, 'vector_norm'):
+if hasattr(np.linalg, "vector_norm"):
     vector_norm = np.linalg.vector_norm
 else:
     vector_norm = get_xp(np)(_linalg.vector_norm)
 
-__all__ = linalg_all + _linalg.__all__ + ['solve']
 
-del get_xp
-del np
-del linalg_all
-del _linalg
+__all__ = [
+    "LinAlgError",
+    "cond",
+    "det",
+    "eig",
+    "eigvals",
+    "eigvalsh",
+    "inv",
+    "lstsq",
+    "matrix_power",
+    "multi_dot",
+    "norm",
+    "tensorinv",
+    "tensorsolve",
+]
+__all__ += _linalg.__all__
+__all__ += ["solve", "vector_norm"]
+
+
+def __dir__() -> list[str]:
+    return __all__