diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 71cbc1a08ee..92048e02837 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -38,6 +38,13 @@ Breaking changes Deprecations ~~~~~~~~~~~~ +- As part of an effort to standardize the API, we're renaming the ``dims`` + keyword arg to ``dim`` for the minority of functions which current use + ``dims``. This started with :py:func:`xarray.dot` & :py:meth:`DataArray.dot` + and we'll gradually roll this out across all functions. The warnings are + currently ``PendingDeprecationWarning``, which are silenced by default. We'll + convert these to ``DeprecationWarning`` in a future release. + By `Maximilian Roos `_. Bug fixes ~~~~~~~~~ diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index 732ec5d3ea6..041fe63a9f3 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -324,7 +324,7 @@ def assert_no_index_conflict(self) -> None: "- they may be used to reindex data along common dimensions" ) - def _need_reindex(self, dims, cmp_indexes) -> bool: + def _need_reindex(self, dim, cmp_indexes) -> bool: """Whether or not we need to reindex variables for a set of matching indexes. @@ -340,14 +340,14 @@ def _need_reindex(self, dims, cmp_indexes) -> bool: return True unindexed_dims_sizes = {} - for dim in dims: - if dim in self.unindexed_dim_sizes: - sizes = self.unindexed_dim_sizes[dim] + for d in dim: + if d in self.unindexed_dim_sizes: + sizes = self.unindexed_dim_sizes[d] if len(sizes) > 1: # reindex if different sizes are found for unindexed dims return True else: - unindexed_dims_sizes[dim] = next(iter(sizes)) + unindexed_dims_sizes[d] = next(iter(sizes)) if unindexed_dims_sizes: indexed_dims_sizes = {} @@ -356,8 +356,8 @@ def _need_reindex(self, dims, cmp_indexes) -> bool: for var in index_vars.values(): indexed_dims_sizes.update(var.sizes) - for dim, size in unindexed_dims_sizes.items(): - if indexed_dims_sizes.get(dim, -1) != size: + for d, size in unindexed_dims_sizes.items(): + if indexed_dims_sizes.get(d, -1) != size: # reindex if unindexed dimension size doesn't match return True diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 0c5c9d6d5cb..ed2c733d4ca 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -26,6 +26,7 @@ from xarray.core.types import Dims, T_DataArray from xarray.core.utils import is_dict_like, is_scalar from xarray.core.variable import Variable +from xarray.util.deprecation_helpers import deprecate_dims if TYPE_CHECKING: from xarray.core.coordinates import Coordinates @@ -1691,9 +1692,10 @@ def cross( return c +@deprecate_dims def dot( *arrays, - dims: Dims = None, + dim: Dims = None, **kwargs: Any, ): """Generalized dot product for xarray objects. Like ``np.einsum``, but @@ -1703,7 +1705,7 @@ def dot( ---------- *arrays : DataArray or Variable Arrays to compute. - dims : str, iterable of hashable, "..." or None, optional + dim : str, iterable of hashable, "..." or None, optional Which dimensions to sum over. Ellipsis ('...') sums over all dimensions. If not specified, then all the common dimensions are summed over. **kwargs : dict @@ -1756,18 +1758,18 @@ def dot( [3, 4, 5]]) Dimensions without coordinates: c, d - >>> xr.dot(da_a, da_b, dims=["a", "b"]) + >>> xr.dot(da_a, da_b, dim=["a", "b"]) array([110, 125]) Dimensions without coordinates: c - >>> xr.dot(da_a, da_b, dims=["a"]) + >>> xr.dot(da_a, da_b, dim=["a"]) array([[40, 46], [70, 79]]) Dimensions without coordinates: b, c - >>> xr.dot(da_a, da_b, da_c, dims=["b", "c"]) + >>> xr.dot(da_a, da_b, da_c, dim=["b", "c"]) array([[ 9, 14, 19], [ 93, 150, 207], @@ -1779,7 +1781,7 @@ def dot( array([110, 125]) Dimensions without coordinates: c - >>> xr.dot(da_a, da_b, dims=...) + >>> xr.dot(da_a, da_b, dim=...) array(235) """ @@ -1803,18 +1805,18 @@ def dot( einsum_axes = "abcdefghijklmnopqrstuvwxyz" dim_map = {d: einsum_axes[i] for i, d in enumerate(all_dims)} - if dims is ...: - dims = all_dims - elif isinstance(dims, str): - dims = (dims,) - elif dims is None: + if dim is ...: + dim = all_dims + elif isinstance(dim, str): + dim = (dim,) + elif dim is None: # find dimensions that occur more than one times dim_counts: Counter = Counter() for arr in arrays: dim_counts.update(arr.dims) - dims = tuple(d for d, c in dim_counts.items() if c > 1) + dim = tuple(d for d, c in dim_counts.items() if c > 1) - dot_dims: set[Hashable] = set(dims) + dot_dims: set[Hashable] = set(dim) # dimensions to be parallelized broadcast_dims = common_dims - dot_dims diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index b417470fdc0..47708cfb581 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -65,7 +65,7 @@ ) from xarray.plot.accessor import DataArrayPlotAccessor from xarray.plot.utils import _get_units_from_attrs -from xarray.util.deprecation_helpers import _deprecate_positional_args +from xarray.util.deprecation_helpers import _deprecate_positional_args, deprecate_dims if TYPE_CHECKING: from typing import TypeVar, Union @@ -115,14 +115,14 @@ T_XarrayOther = TypeVar("T_XarrayOther", bound=Union["DataArray", Dataset]) -def _check_coords_dims(shape, coords, dims): - sizes = dict(zip(dims, shape)) +def _check_coords_dims(shape, coords, dim): + sizes = dict(zip(dim, shape)) for k, v in coords.items(): - if any(d not in dims for d in v.dims): + if any(d not in dim for d in v.dims): raise ValueError( f"coordinate {k} has dimensions {v.dims}, but these " "are not a subset of the DataArray " - f"dimensions {dims}" + f"dimensions {dim}" ) for d, s in v.sizes.items(): @@ -4895,10 +4895,11 @@ def imag(self) -> Self: """ return self._replace(self.variable.imag) + @deprecate_dims def dot( self, other: T_Xarray, - dims: Dims = None, + dim: Dims = None, ) -> T_Xarray: """Perform dot product of two DataArrays along their shared dims. @@ -4908,7 +4909,7 @@ def dot( ---------- other : DataArray The other array with which the dot product is performed. - dims : ..., str, Iterable of Hashable or None, optional + dim : ..., str, Iterable of Hashable or None, optional Which dimensions to sum over. Ellipsis (`...`) sums over all dimensions. If not specified, then all the common dimensions are summed over. @@ -4947,7 +4948,7 @@ def dot( if not isinstance(other, DataArray): raise TypeError("dot only operates on DataArrays.") - return computation.dot(self, other, dims=dims) + return computation.dot(self, other, dim=dim) def sortby( self, diff --git a/xarray/core/variable.py b/xarray/core/variable.py index c2133d55aeb..39a947e6264 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1541,15 +1541,15 @@ def stack(self, dimensions=None, **dimensions_kwargs): result = result._stack_once(dims, new_dim) return result - def _unstack_once_full(self, dims: Mapping[Any, int], old_dim: Hashable) -> Self: + def _unstack_once_full(self, dim: Mapping[Any, int], old_dim: Hashable) -> Self: """ Unstacks the variable without needing an index. Unlike `_unstack_once`, this function requires the existing dimension to contain the full product of the new dimensions. """ - new_dim_names = tuple(dims.keys()) - new_dim_sizes = tuple(dims.values()) + new_dim_names = tuple(dim.keys()) + new_dim_sizes = tuple(dim.values()) if old_dim not in self.dims: raise ValueError(f"invalid existing dimension: {old_dim}") diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index 28740a99020..53ff6db5f28 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -228,7 +228,7 @@ def _reduce( # `dot` does not broadcast arrays, so this avoids creating a large # DataArray (if `weights` has additional dimensions) - return dot(da, weights, dims=dim) + return dot(da, weights, dim=dim) def _sum_of_weights(self, da: DataArray, dim: Dims = None) -> DataArray: """Calculate the sum of weights, accounting for missing values""" diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index 425673dc40f..396507652c6 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -1936,7 +1936,7 @@ def test_dot(use_dask: bool) -> None: da_a = da_a.chunk({"a": 3}) da_b = da_b.chunk({"a": 3}) da_c = da_c.chunk({"c": 3}) - actual = xr.dot(da_a, da_b, dims=["a", "b"]) + actual = xr.dot(da_a, da_b, dim=["a", "b"]) assert actual.dims == ("c",) assert (actual.data == np.einsum("ij,ijk->k", a, b)).all() assert isinstance(actual.variable.data, type(da_a.variable.data)) @@ -1960,33 +1960,33 @@ def test_dot(use_dask: bool) -> None: if use_dask: da_a = da_a.chunk({"a": 3}) da_b = da_b.chunk({"a": 3}) - actual = xr.dot(da_a, da_b, dims=["b"]) + actual = xr.dot(da_a, da_b, dim=["b"]) assert actual.dims == ("a", "c") assert (actual.data == np.einsum("ij,ijk->ik", a, b)).all() assert isinstance(actual.variable.data, type(da_a.variable.data)) - actual = xr.dot(da_a, da_b, dims=["b"]) + actual = xr.dot(da_a, da_b, dim=["b"]) assert actual.dims == ("a", "c") assert (actual.data == np.einsum("ij,ijk->ik", a, b)).all() - actual = xr.dot(da_a, da_b, dims="b") + actual = xr.dot(da_a, da_b, dim="b") assert actual.dims == ("a", "c") assert (actual.data == np.einsum("ij,ijk->ik", a, b)).all() - actual = xr.dot(da_a, da_b, dims="a") + actual = xr.dot(da_a, da_b, dim="a") assert actual.dims == ("b", "c") assert (actual.data == np.einsum("ij,ijk->jk", a, b)).all() - actual = xr.dot(da_a, da_b, dims="c") + actual = xr.dot(da_a, da_b, dim="c") assert actual.dims == ("a", "b") assert (actual.data == np.einsum("ij,ijk->ij", a, b)).all() - actual = xr.dot(da_a, da_b, da_c, dims=["a", "b"]) + actual = xr.dot(da_a, da_b, da_c, dim=["a", "b"]) assert actual.dims == ("c", "e") assert (actual.data == np.einsum("ij,ijk,kl->kl ", a, b, c)).all() # should work with tuple - actual = xr.dot(da_a, da_b, dims=("c",)) + actual = xr.dot(da_a, da_b, dim=("c",)) assert actual.dims == ("a", "b") assert (actual.data == np.einsum("ij,ijk->ij", a, b)).all() @@ -1996,47 +1996,47 @@ def test_dot(use_dask: bool) -> None: assert (actual.data == np.einsum("ij,ijk,kl->l ", a, b, c)).all() # 1 array summation - actual = xr.dot(da_a, dims="a") + actual = xr.dot(da_a, dim="a") assert actual.dims == ("b",) assert (actual.data == np.einsum("ij->j ", a)).all() # empty dim - actual = xr.dot(da_a.sel(a=[]), da_a.sel(a=[]), dims="a") + actual = xr.dot(da_a.sel(a=[]), da_a.sel(a=[]), dim="a") assert actual.dims == ("b",) assert (actual.data == np.zeros(actual.shape)).all() # Ellipsis (...) sums over all dimensions - actual = xr.dot(da_a, da_b, dims=...) + actual = xr.dot(da_a, da_b, dim=...) assert actual.dims == () assert (actual.data == np.einsum("ij,ijk->", a, b)).all() - actual = xr.dot(da_a, da_b, da_c, dims=...) + actual = xr.dot(da_a, da_b, da_c, dim=...) assert actual.dims == () assert (actual.data == np.einsum("ij,ijk,kl-> ", a, b, c)).all() - actual = xr.dot(da_a, dims=...) + actual = xr.dot(da_a, dim=...) assert actual.dims == () assert (actual.data == np.einsum("ij-> ", a)).all() - actual = xr.dot(da_a.sel(a=[]), da_a.sel(a=[]), dims=...) + actual = xr.dot(da_a.sel(a=[]), da_a.sel(a=[]), dim=...) assert actual.dims == () assert (actual.data == np.zeros(actual.shape)).all() # Invalid cases if not use_dask: with pytest.raises(TypeError): - xr.dot(da_a, dims="a", invalid=None) + xr.dot(da_a, dim="a", invalid=None) with pytest.raises(TypeError): - xr.dot(da_a.to_dataset(name="da"), dims="a") + xr.dot(da_a.to_dataset(name="da"), dim="a") with pytest.raises(TypeError): - xr.dot(dims="a") + xr.dot(dim="a") # einsum parameters - actual = xr.dot(da_a, da_b, dims=["b"], order="C") + actual = xr.dot(da_a, da_b, dim=["b"], order="C") assert (actual.data == np.einsum("ij,ijk->ik", a, b)).all() assert actual.values.flags["C_CONTIGUOUS"] assert not actual.values.flags["F_CONTIGUOUS"] - actual = xr.dot(da_a, da_b, dims=["b"], order="F") + actual = xr.dot(da_a, da_b, dim=["b"], order="F") assert (actual.data == np.einsum("ij,ijk->ik", a, b)).all() # dask converts Fortran arrays to C order when merging the final array if not use_dask: @@ -2078,7 +2078,7 @@ def test_dot_align_coords(use_dask: bool) -> None: expected = (da_a * da_b).sum(["a", "b"]) xr.testing.assert_allclose(expected, actual) - actual = xr.dot(da_a, da_b, dims=...) + actual = xr.dot(da_a, da_b, dim=...) expected = (da_a * da_b).sum() xr.testing.assert_allclose(expected, actual) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 44b9790f0b7..f9547f3afa2 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -3964,13 +3964,13 @@ def test_dot(self) -> None: assert_equal(expected3, actual3) # Ellipsis: all dims are shared - actual4 = da.dot(da, dims=...) + actual4 = da.dot(da, dim=...) expected4 = da.dot(da) assert_equal(expected4, actual4) # Ellipsis: not all dims are shared - actual5 = da.dot(dm3, dims=...) - expected5 = da.dot(dm3, dims=("j", "x", "y", "z")) + actual5 = da.dot(dm3, dim=...) + expected5 = da.dot(dm3, dim=("j", "x", "y", "z")) assert_equal(expected5, actual5) with pytest.raises(NotImplementedError): diff --git a/xarray/util/deprecation_helpers.py b/xarray/util/deprecation_helpers.py index 7b4cf901aa1..c620e45574e 100644 --- a/xarray/util/deprecation_helpers.py +++ b/xarray/util/deprecation_helpers.py @@ -36,6 +36,8 @@ from functools import wraps from typing import Callable, TypeVar +from xarray.core.utils import emit_user_level_warning + T = TypeVar("T", bound=Callable) POSITIONAL_OR_KEYWORD = inspect.Parameter.POSITIONAL_OR_KEYWORD @@ -115,3 +117,28 @@ def inner(*args, **kwargs): return inner return _decorator + + +def deprecate_dims(func: T) -> T: + """ + For functions that previously took `dims` as a kwarg, and have now transitioned to + `dim`. This decorator will issue a warning if `dims` is passed while forwarding it + to `dim`. + """ + + @wraps(func) + def wrapper(*args, **kwargs): + if "dims" in kwargs: + emit_user_level_warning( + "The `dims` argument has been renamed to `dim`, and will be removed " + "in the future. This renaming is taking place throughout xarray over the " + "next few releases.", + # Upgrade to `DeprecationWarning` in the future, when the renaming is complete. + PendingDeprecationWarning, + ) + kwargs["dim"] = kwargs.pop("dims") + return func(*args, **kwargs) + + # We're quite confident we're just returning `T` from this function, so it's fine to ignore typing + # within the function. + return wrapper # type: ignore