|
4 | 4 | import warnings
|
5 | 5 | from collections.abc import Callable, Sequence
|
6 | 6 | from types import ModuleType, NoneType
|
7 |
| -from typing import cast, overload |
| 7 | +from typing import Literal, cast, overload |
8 | 8 |
|
9 | 9 | from ._at import at
|
10 | 10 | from ._utils import _compat, _helpers
|
|
16 | 16 | meta_namespace,
|
17 | 17 | ndindex,
|
18 | 18 | )
|
19 |
| -from ._utils._typing import Array |
| 19 | +from ._utils._typing import Array, Device, DType |
20 | 20 |
|
21 | 21 | __all__ = [
|
22 | 22 | "apply_where",
|
@@ -438,6 +438,44 @@ def create_diagonal(
|
438 | 438 | return xp.reshape(diag, (*batch_dims, n, n))
|
439 | 439 |
|
440 | 440 |
|
| 441 | +def default_dtype( |
| 442 | + xp: ModuleType, |
| 443 | + kind: Literal[ |
| 444 | + "real floating", "complex floating", "integral", "indexing" |
| 445 | + ] = "real floating", |
| 446 | + *, |
| 447 | + device: Device | None = None, |
| 448 | +) -> DType: |
| 449 | + """ |
| 450 | + Return the default dtype for the given namespace and device. |
| 451 | +
|
| 452 | + This is a convenience shorthand for |
| 453 | + ``xp.__array_namespace_info__().default_dtypes(device=device)[kind]``. |
| 454 | +
|
| 455 | + Parameters |
| 456 | + ---------- |
| 457 | + xp : array_namespace |
| 458 | + The standard-compatible namespace for which to get the default dtype. |
| 459 | + kind : {'real floating', 'complex floating', 'integral', 'indexing'}, optional |
| 460 | + The kind of dtype to return. Default is 'real floating'. |
| 461 | + device : Device, optional |
| 462 | + The device for which to get the default dtype. Default: current device. |
| 463 | +
|
| 464 | + Returns |
| 465 | + ------- |
| 466 | + dtype |
| 467 | + The default dtype for the given namespace, kind, and device. |
| 468 | + """ |
| 469 | + dtypes = xp.__array_namespace_info__().default_dtypes(device=device) |
| 470 | + try: |
| 471 | + return dtypes[kind] |
| 472 | + except KeyError as e: |
| 473 | + domain = ("real floating", "complex floating", "integral", "indexing") |
| 474 | + assert set(dtypes) == set(domain), f"Non-compliant namespace: {dtypes}" |
| 475 | + msg = f"Unknown kind '{kind}'. Expected one of {domain}." |
| 476 | + raise ValueError(msg) from e |
| 477 | + |
| 478 | + |
441 | 479 | def expand_dims(
|
442 | 480 | a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType | None = None
|
443 | 481 | ) -> Array:
|
@@ -728,9 +766,7 @@ def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array:
|
728 | 766 | x = xp.reshape(x, (-1,))
|
729 | 767 | x = xp.sort(x)
|
730 | 768 | mask = x != xp.roll(x, -1)
|
731 |
| - default_int = xp.__array_namespace_info__().default_dtypes( |
732 |
| - device=_compat.device(x) |
733 |
| - )["integral"] |
| 769 | + default_int = default_dtype(xp, "integral", device=_compat.device(x)) |
734 | 770 | return xp.maximum(
|
735 | 771 | # Special cases:
|
736 | 772 | # - array is size 0
|
|
0 commit comments