Skip to content

Commit 830f797

Browse files
committed
add parse_dims_as_set
1 parent 6f4973d commit 830f797

File tree

5 files changed

+59
-19
lines changed

5 files changed

+59
-19
lines changed

xarray/core/computation.py

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@
3131
from xarray.core.merge import merge_attrs, merge_coordinates_without_align
3232
from xarray.core.options import OPTIONS, _get_keep_attrs
3333
from xarray.core.types import Dims, T_DataArray
34-
from xarray.core.utils import is_dict_like, is_scalar, parse_dims
34+
from xarray.core.utils import is_dict_like, is_scalar, parse_dims_as_set
3535
from xarray.core.variable import Variable
3636
from xarray.namedarray.parallelcompat import get_chunked_array_type
3737
from xarray.namedarray.pycompat import is_chunked_array
@@ -1846,11 +1846,9 @@ def dot(
18461846
dim_counts: Counter = Counter()
18471847
for arr in arrays:
18481848
dim_counts.update(arr.dims)
1849-
dim = tuple(d for d, c in dim_counts.items() if c > 1)
1849+
dot_dims = {d for d, c in dim_counts.items() if c > 1}
18501850
else:
1851-
dim = parse_dims(dim, all_dims=tuple(all_dims))
1852-
1853-
dot_dims: set[Hashable] = set(dim)
1851+
dot_dims = parse_dims_as_set(dim, all_dims=set(all_dims))
18541852

18551853
# dimensions to be parallelized
18561854
broadcast_dims = common_dims - dot_dims

xarray/core/dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@
118118
is_duck_dask_array,
119119
is_scalar,
120120
maybe_wrap_array,
121-
parse_dims,
121+
parse_dims_as_set,
122122
)
123123
from xarray.core.variable import (
124124
IndexVariable,
@@ -6987,7 +6987,7 @@ def reduce(
69876987
" Please use 'dim' instead."
69886988
)
69896989

6990-
dims = set(parse_dims(dim, tuple(self.dims)))
6990+
dims = parse_dims_as_set(dim, self._dims.keys())
69916991

69926992
if keep_attrs is None:
69936993
keep_attrs = _get_keep_attrs(default=False)

xarray/core/datatree.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
drop_dims_from_indexers,
4444
either_dict_or_kwargs,
4545
maybe_wrap_array,
46-
parse_dims,
46+
parse_dims_as_set,
4747
)
4848
from xarray.core.variable import Variable
4949

@@ -1631,7 +1631,7 @@ def reduce(
16311631
**kwargs: Any,
16321632
) -> Self:
16331633
"""Reduce this tree by applying `func` along some dimension(s)."""
1634-
dims = set(parse_dims(dim, tuple(self._get_all_dims())))
1634+
dims = parse_dims_as_set(dim, self._get_all_dims())
16351635
result = {}
16361636
for node in self.subtree:
16371637
reduce_dims = [d for d in node._node_dims if d in dims]

xarray/core/utils.py

Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
Mapping,
5959
MutableMapping,
6060
MutableSet,
61+
Set,
6162
Sequence,
6263
ValuesView,
6364
)
@@ -831,7 +832,7 @@ def drop_dims_from_indexers(
831832

832833

833834
@overload
834-
def parse_dims(
835+
def parse_dims_as_tuple(
835836
dim: Dims,
836837
all_dims: tuple[Hashable, ...],
837838
*,
@@ -841,7 +842,7 @@ def parse_dims(
841842

842843

843844
@overload
844-
def parse_dims(
845+
def parse_dims_as_tuple(
845846
dim: Dims,
846847
all_dims: tuple[Hashable, ...],
847848
*,
@@ -850,7 +851,7 @@ def parse_dims(
850851
) -> tuple[Hashable, ...] | None | EllipsisType: ...
851852

852853

853-
def parse_dims(
854+
def parse_dims_as_tuple(
854855
dim: Dims,
855856
all_dims: tuple[Hashable, ...],
856857
*,
@@ -891,6 +892,47 @@ def parse_dims(
891892
return tuple(dim)
892893

893894

895+
@overload
896+
def parse_dims_as_set(
897+
dim: Dims,
898+
all_dims: Set[Hashable],
899+
*,
900+
check_exists: bool = True,
901+
replace_none: Literal[True] = True,
902+
) -> Set[Hashable]: ...
903+
904+
905+
@overload
906+
def parse_dims_as_set(
907+
dim: Dims,
908+
all_dims: Set[Hashable],
909+
*,
910+
check_exists: bool = True,
911+
replace_none: Literal[False],
912+
) -> Set[Hashable] | None | EllipsisType: ...
913+
914+
915+
def parse_dims_as_set(
916+
dim: Dims,
917+
all_dims: Set[Hashable],
918+
*,
919+
check_exists: bool = True,
920+
replace_none: bool = True,
921+
) -> Set[Hashable] | None | EllipsisType:
922+
"""Like parse_dims_as_tuple, but returning a set instead of a tuple."""
923+
# TODO: Consider removing parse_dims_as_tuple?
924+
if dim is None or dim is ...:
925+
if replace_none:
926+
return all_dims
927+
return dim
928+
if isinstance(dim, str):
929+
dim = {dim}
930+
dim = set(dim)
931+
if check_exists:
932+
_check_dims(dim, all_dims)
933+
return dim
934+
935+
894936
@overload
895937
def parse_ordered_dims(
896938
dim: Dims,
@@ -958,15 +1000,15 @@ def parse_ordered_dims(
9581000
return dims[:idx] + other_dims + dims[idx + 1 :]
9591001
else:
9601002
# mypy cannot resolve that the sequence cannot contain "..."
961-
return parse_dims( # type: ignore[call-overload]
1003+
return parse_dims_as_tuple( # type: ignore[call-overload]
9621004
dim=dim,
9631005
all_dims=all_dims,
9641006
check_exists=check_exists,
9651007
replace_none=replace_none,
9661008
)
9671009

9681010

969-
def _check_dims(dim: set[Hashable], all_dims: set[Hashable]) -> None:
1011+
def _check_dims(dim: Set[Hashable], all_dims: Set[Hashable]) -> None:
9701012
wrong_dims = (dim - all_dims) - {...}
9711013
if wrong_dims:
9721014
wrong_dims_str = ", ".join(f"'{d!s}'" for d in wrong_dims)

xarray/tests/test_utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -283,16 +283,16 @@ def test_infix_dims_errors(supplied, all_):
283283
pytest.param(..., ..., id="ellipsis"),
284284
],
285285
)
286-
def test_parse_dims(dim, expected) -> None:
286+
def test_parse_dims_as_tuple(dim, expected) -> None:
287287
all_dims = ("a", "b", 1, ("b", "c")) # selection of different Hashables
288-
actual = utils.parse_dims(dim, all_dims, replace_none=False)
288+
actual = utils.parse_dims_as_tuple(dim, all_dims, replace_none=False)
289289
assert actual == expected
290290

291291

292292
def test_parse_dims_set() -> None:
293293
all_dims = ("a", "b", 1, ("b", "c")) # selection of different Hashables
294294
dim = {"a", 1}
295-
actual = utils.parse_dims(dim, all_dims)
295+
actual = utils.parse_dims_as_tuple(dim, all_dims)
296296
assert set(actual) == dim
297297

298298

@@ -301,7 +301,7 @@ def test_parse_dims_set() -> None:
301301
)
302302
def test_parse_dims_replace_none(dim: None | EllipsisType) -> None:
303303
all_dims = ("a", "b", 1, ("b", "c")) # selection of different Hashables
304-
actual = utils.parse_dims(dim, all_dims, replace_none=True)
304+
actual = utils.parse_dims_as_tuple(dim, all_dims, replace_none=True)
305305
assert actual == all_dims
306306

307307

@@ -316,7 +316,7 @@ def test_parse_dims_replace_none(dim: None | EllipsisType) -> None:
316316
def test_parse_dims_raises(dim) -> None:
317317
all_dims = ("a", "b", 1, ("b", "c")) # selection of different Hashables
318318
with pytest.raises(ValueError, match="'x'"):
319-
utils.parse_dims(dim, all_dims, check_exists=True)
319+
utils.parse_dims_as_tuple(dim, all_dims, check_exists=True)
320320

321321

322322
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)