Skip to content

Commit b8e6ab7

Browse files
committed
DIms_type
1 parent 357a444 commit b8e6ab7

File tree

5 files changed

+37
-40
lines changed

5 files changed

+37
-40
lines changed

xarray/core/computation.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from xarray.core.parallelcompat import get_chunked_array_type
2525
from xarray.core.pycompat import is_chunked_array, is_duck_dask_array
2626
from xarray.core.types import Dims, T_DataArray
27-
from xarray.core.utils import is_dict_like, is_scalar
27+
from xarray.core.utils import is_dict_like, is_scalar, parse_dims
2828
from xarray.core.variable import Variable
2929
from xarray.util.deprecation_helpers import deprecate_dims
3030

@@ -1775,7 +1775,7 @@ def dot(
17751775
----------
17761776
*arrays : DataArray or Variable
17771777
Arrays to compute.
1778-
dim : str, iterable of hashable, "..." or None, optional
1778+
dim : Hashable, iterable of Hashable, ..., or None, optional
17791779
Which dimensions to sum over. Ellipsis ('...') sums over all dimensions.
17801780
If not specified, then all the common dimensions are summed over.
17811781
**kwargs : dict
@@ -1875,16 +1875,14 @@ def dot(
18751875
einsum_axes = "abcdefghijklmnopqrstuvwxyz"
18761876
dim_map = {d: einsum_axes[i] for i, d in enumerate(all_dims)}
18771877

1878-
if dim is ...:
1879-
dim = all_dims
1880-
elif isinstance(dim, str):
1881-
dim = (dim,)
1882-
elif dim is None:
1883-
# find dimensions that occur more than one times
1878+
if dim is None:
1879+
# find dimensions that occur more than once
18841880
dim_counts: Counter = Counter()
18851881
for arr in arrays:
18861882
dim_counts.update(arr.dims)
18871883
dim = tuple(d for d, c in dim_counts.items() if c > 1)
1884+
else:
1885+
dim = parse_dims(dim, all_dims=tuple(all_dims))
18881886

18891887
dot_dims: set[Hashable] = set(dim)
18901888

xarray/core/types.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import datetime
44
import sys
5-
from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence
5+
from collections.abc import Collection, Hashable, Iterator, Mapping, Sequence
66
from typing import (
77
TYPE_CHECKING,
88
Any,
@@ -182,8 +182,7 @@ def copy(
182182
DsCompatible = Union["Dataset", "DaCompatible"]
183183
GroupByCompatible = Union["Dataset", "DataArray"]
184184

185-
Dims = Union[str, Iterable[Hashable], "ellipsis", None]
186-
OrderedDims = Union[str, Sequence[Union[Hashable, "ellipsis"]], "ellipsis", None]
185+
Dims = Union[Hashable, Collection[Hashable], None] # Note: Hashable include ellipsis
187186

188187
# FYI in some cases we don't allow `None`, which this doesn't take account of.
189188
T_ChunkDim: TypeAlias = Union[int, Literal["auto"], None, tuple[int, ...]]

xarray/core/utils.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@
7676
import pandas as pd
7777

7878
if TYPE_CHECKING:
79-
from xarray.core.types import Dims, ErrorOptionsWithWarn, OrderedDims, T_DuckArray
79+
from xarray.core.types import Dims, ErrorOptionsWithWarn, T_DuckArray
8080

8181
K = TypeVar("K")
8282
V = TypeVar("V")
@@ -988,7 +988,7 @@ def drop_missing_dims(
988988

989989
@overload
990990
def parse_dims(
991-
dim: str | Iterable[Hashable] | T_None,
991+
dim: Dims,
992992
all_dims: tuple[Hashable, ...],
993993
*,
994994
check_exists: bool = True,
@@ -999,12 +999,12 @@ def parse_dims(
999999

10001000
@overload
10011001
def parse_dims(
1002-
dim: str | Iterable[Hashable] | T_None,
1002+
dim: Dims,
10031003
all_dims: tuple[Hashable, ...],
10041004
*,
10051005
check_exists: bool = True,
10061006
replace_none: Literal[False],
1007-
) -> tuple[Hashable, ...] | T_None:
1007+
) -> tuple[Hashable, ...] | None | ellipsis:
10081008
...
10091009

10101010

@@ -1042,16 +1042,19 @@ def parse_dims(
10421042
if replace_none:
10431043
return all_dims
10441044
return dim
1045-
if isinstance(dim, str):
1045+
if isinstance(dim, Collection) and not isinstance(dim, str):
1046+
dim = tuple(dim)
1047+
else:
10461048
dim = (dim,)
1049+
10471050
if check_exists:
10481051
_check_dims(set(dim), set(all_dims))
1049-
return tuple(dim)
1052+
return dim
10501053

10511054

10521055
@overload
10531056
def parse_ordered_dims(
1054-
dim: str | Sequence[Hashable | ellipsis] | T_None,
1057+
dim: Dims,
10551058
all_dims: tuple[Hashable, ...],
10561059
*,
10571060
check_exists: bool = True,
@@ -1062,7 +1065,7 @@ def parse_ordered_dims(
10621065

10631066
@overload
10641067
def parse_ordered_dims(
1065-
dim: str | Sequence[Hashable | ellipsis] | T_None,
1068+
dim: Dims,
10661069
all_dims: tuple[Hashable, ...],
10671070
*,
10681071
check_exists: bool = True,
@@ -1072,24 +1075,22 @@ def parse_ordered_dims(
10721075

10731076

10741077
def parse_ordered_dims(
1075-
dim: OrderedDims,
1078+
dim: Dims,
10761079
all_dims: tuple[Hashable, ...],
10771080
*,
10781081
check_exists: bool = True,
10791082
replace_none: bool = True,
10801083
) -> tuple[Hashable, ...] | None | ellipsis:
10811084
"""Parse one or more dimensions.
10821085
1083-
A single dimension must be always a str, multiple dimensions
1084-
can be Hashables. This supports e.g. using a tuple as a dimension.
10851086
An ellipsis ("...") in a sequence of dimensions will be
10861087
replaced with all remaining dimensions. This only makes sense when
10871088
the input is a sequence and not e.g. a set.
10881089
10891090
Parameters
10901091
----------
1091-
dim : str, Sequence of Hashable or "...", "..." or None
1092-
Dimension(s) to parse. If "..." appears in a Sequence
1092+
dim : Hashable or ... (or Sequence thereof), or None
1093+
Dimension(s) to parse. If ... appears in a Sequence
10931094
it always gets replaced with all remaining dims
10941095
all_dims : tuple of Hashable
10951096
All possible dimensions.
@@ -1103,7 +1104,7 @@ def parse_ordered_dims(
11031104
parsed_dims : tuple of Hashable
11041105
Input dimensions as a tuple.
11051106
"""
1106-
if dim is not None and dim is not ... and not isinstance(dim, str) and ... in dim:
1107+
if isinstance(dim, Sequence) and not isinstance(dim, str) and ... in dim:
11071108
dims_set: set[Hashable | ellipsis] = set(dim)
11081109
all_dims_set = set(all_dims)
11091110
if check_exists:
@@ -1126,9 +1127,9 @@ def parse_ordered_dims(
11261127
)
11271128

11281129

1129-
def _check_dims(dim: set[Hashable | ellipsis], all_dims: set[Hashable]) -> None:
1130-
wrong_dims = dim - all_dims
1131-
if wrong_dims and wrong_dims != {...}:
1130+
def _check_dims(dim: set[Hashable], all_dims: set[Hashable]) -> None:
1131+
wrong_dims = (dim - all_dims) - {...}
1132+
if wrong_dims:
11321133
wrong_dims_str = ", ".join(f"'{d!s}'" for d in wrong_dims)
11331134
raise ValueError(
11341135
f"Dimension(s) {wrong_dims_str} do not exist. Expected one or more of {all_dims}"

xarray/tests/test_interp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -838,8 +838,8 @@ def test_interpolate_chunk_1d(
838838
if chunked:
839839
dest[dim] = xr.DataArray(data=dest[dim], dims=[dim])
840840
dest[dim] = dest[dim].chunk(2)
841-
actual = da.interp(method=method, **dest, kwargs=kwargs) # type: ignore
842-
expected = da.compute().interp(method=method, **dest, kwargs=kwargs) # type: ignore
841+
actual = da.interp(method=method, **dest, kwargs=kwargs)
842+
expected = da.compute().interp(method=method, **dest, kwargs=kwargs)
843843

844844
assert_identical(actual, expected)
845845

xarray/tests/test_utils.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from collections.abc import Hashable, Iterable, Sequence
3+
from collections.abc import Hashable
44

55
import numpy as np
66
import pandas as pd
@@ -255,19 +255,21 @@ def test_infix_dims_errors(supplied, all_):
255255
["dim", "expected"],
256256
[
257257
pytest.param("a", ("a",), id="str"),
258+
pytest.param(1, (1,), id="non_str_hashable"),
258259
pytest.param(["a", "b"], ("a", "b"), id="list_of_str"),
259260
pytest.param(["a", 1], ("a", 1), id="list_mixed"),
261+
pytest.param(["a", ...], ("a", ...), id="list_with_ellipsis"),
260262
pytest.param(("a", "b"), ("a", "b"), id="tuple_of_str"),
261263
pytest.param(["a", ("b", "c")], ("a", ("b", "c")), id="list_with_tuple"),
262264
pytest.param((("b", "c"),), (("b", "c"),), id="tuple_of_tuple"),
265+
pytest.param({"a", 1}, tuple({"a", 1}), id="non_sequence_collection"),
266+
pytest.param((), (), id="empty_tuple"),
267+
pytest.param(set(), (), id="empty_collection"),
263268
pytest.param(None, None, id="None"),
264269
pytest.param(..., ..., id="ellipsis"),
265270
],
266271
)
267-
def test_parse_dims(
268-
dim: str | Iterable[Hashable] | None,
269-
expected: tuple[Hashable, ...],
270-
) -> None:
272+
def test_parse_dims(dim, expected):
271273
all_dims = ("a", "b", 1, ("b", "c")) # selection of different Hashables
272274
actual = utils.parse_dims(dim, all_dims, replace_none=False)
273275
assert actual == expected
@@ -297,7 +299,7 @@ def test_parse_dims_replace_none(dim: None | ellipsis) -> None:
297299
pytest.param(["x", 2], id="list_missing_all"),
298300
],
299301
)
300-
def test_parse_dims_raises(dim: str | Iterable[Hashable]) -> None:
302+
def test_parse_dims_raises(dim):
301303
all_dims = ("a", "b", 1, ("b", "c")) # selection of different Hashables
302304
with pytest.raises(ValueError, match="'x'"):
303305
utils.parse_dims(dim, all_dims, check_exists=True)
@@ -313,10 +315,7 @@ def test_parse_dims_raises(dim: str | Iterable[Hashable]) -> None:
313315
pytest.param(["a", ..., "b"], ("a", "c", "b"), id="list_with_middle_ellipsis"),
314316
],
315317
)
316-
def test_parse_ordered_dims(
317-
dim: str | Sequence[Hashable | ellipsis],
318-
expected: tuple[Hashable, ...],
319-
) -> None:
318+
def test_parse_ordered_dims(dim, expected):
320319
all_dims = ("a", "b", "c")
321320
actual = utils.parse_ordered_dims(dim, all_dims)
322321
assert actual == expected

0 commit comments

Comments
 (0)