diff --git a/doc/user-guide/computation.rst b/doc/user-guide/computation.rst index d1f1274c7a1..f8141f40321 100644 --- a/doc/user-guide/computation.rst +++ b/doc/user-guide/computation.rst @@ -63,6 +63,10 @@ Data arrays also implement many :py:class:`numpy.ndarray` methods: arr.round(2) arr.T + intarr = xr.DataArray([0, 1, 2, 3, 4, 5]) + intarr << 2 # only supported for int types + intarr >> 1 + .. _missing_values: Missing values diff --git a/doc/whats-new.rst b/doc/whats-new.rst index bfc040eb271..f54e9ad9676 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -22,6 +22,9 @@ v2023.05.0 (unreleased) New Features ~~~~~~~~~~~~ +- Add support for lshift and rshift binary operators (``<<``, ``>>``) on + :py:class:`xr.DataArray` of type :py:class:`int` (:issue:`7727`, :pull:`7741`). + By `Alan Brammer <https://github.com/abrammer>`_. Breaking changes diff --git a/xarray/core/_typed_ops.py b/xarray/core/_typed_ops.py index a6e6fdbfaec..d3a783be45d 100644 --- a/xarray/core/_typed_ops.py +++ b/xarray/core/_typed_ops.py @@ -42,6 +42,12 @@ def __xor__(self, other): def __or__(self, other): return self._binary_op(other, operator.or_) + def __lshift__(self, other): + return self._binary_op(other, operator.lshift) + + def __rshift__(self, other): + return self._binary_op(other, operator.rshift) + def __lt__(self, other): return self._binary_op(other, operator.lt) @@ -123,6 +129,12 @@ def __ixor__(self, other): def __ior__(self, other): return self._inplace_binary_op(other, operator.ior) + def __ilshift__(self, other): + return self._inplace_binary_op(other, operator.ilshift) + + def __irshift__(self, other): + return self._inplace_binary_op(other, operator.irshift) + def _unary_op(self, f, *args, **kwargs): raise NotImplementedError @@ -160,6 +172,8 @@ def conjugate(self, *args, **kwargs): __and__.__doc__ = operator.and_.__doc__ __xor__.__doc__ = operator.xor.__doc__ __or__.__doc__ = operator.or_.__doc__ + __lshift__.__doc__ = operator.lshift.__doc__ + __rshift__.__doc__ = operator.rshift.__doc__ 
__lt__.__doc__ = operator.lt.__doc__ __le__.__doc__ = operator.le.__doc__ __gt__.__doc__ = operator.gt.__doc__ @@ -186,6 +200,8 @@ def conjugate(self, *args, **kwargs): __iand__.__doc__ = operator.iand.__doc__ __ixor__.__doc__ = operator.ixor.__doc__ __ior__.__doc__ = operator.ior.__doc__ + __ilshift__.__doc__ = operator.ilshift.__doc__ + __irshift__.__doc__ = operator.irshift.__doc__ __neg__.__doc__ = operator.neg.__doc__ __pos__.__doc__ = operator.pos.__doc__ __abs__.__doc__ = operator.abs.__doc__ @@ -232,6 +248,12 @@ def __xor__(self, other): def __or__(self, other): return self._binary_op(other, operator.or_) + def __lshift__(self, other): + return self._binary_op(other, operator.lshift) + + def __rshift__(self, other): + return self._binary_op(other, operator.rshift) + def __lt__(self, other): return self._binary_op(other, operator.lt) @@ -313,6 +335,12 @@ def __ixor__(self, other): def __ior__(self, other): return self._inplace_binary_op(other, operator.ior) + def __ilshift__(self, other): + return self._inplace_binary_op(other, operator.ilshift) + + def __irshift__(self, other): + return self._inplace_binary_op(other, operator.irshift) + def _unary_op(self, f, *args, **kwargs): raise NotImplementedError @@ -350,6 +378,8 @@ def conjugate(self, *args, **kwargs): __and__.__doc__ = operator.and_.__doc__ __xor__.__doc__ = operator.xor.__doc__ __or__.__doc__ = operator.or_.__doc__ + __lshift__.__doc__ = operator.lshift.__doc__ + __rshift__.__doc__ = operator.rshift.__doc__ __lt__.__doc__ = operator.lt.__doc__ __le__.__doc__ = operator.le.__doc__ __gt__.__doc__ = operator.gt.__doc__ @@ -376,6 +406,8 @@ def conjugate(self, *args, **kwargs): __iand__.__doc__ = operator.iand.__doc__ __ixor__.__doc__ = operator.ixor.__doc__ __ior__.__doc__ = operator.ior.__doc__ + __ilshift__.__doc__ = operator.ilshift.__doc__ + __irshift__.__doc__ = operator.irshift.__doc__ __neg__.__doc__ = operator.neg.__doc__ __pos__.__doc__ = operator.pos.__doc__ __abs__.__doc__ = 
operator.abs.__doc__ @@ -422,6 +454,12 @@ def __xor__(self, other): def __or__(self, other): return self._binary_op(other, operator.or_) + def __lshift__(self, other): + return self._binary_op(other, operator.lshift) + + def __rshift__(self, other): + return self._binary_op(other, operator.rshift) + def __lt__(self, other): return self._binary_op(other, operator.lt) @@ -503,6 +541,12 @@ def __ixor__(self, other): def __ior__(self, other): return self._inplace_binary_op(other, operator.ior) + def __ilshift__(self, other): + return self._inplace_binary_op(other, operator.ilshift) + + def __irshift__(self, other): + return self._inplace_binary_op(other, operator.irshift) + def _unary_op(self, f, *args, **kwargs): raise NotImplementedError @@ -540,6 +584,8 @@ def conjugate(self, *args, **kwargs): __and__.__doc__ = operator.and_.__doc__ __xor__.__doc__ = operator.xor.__doc__ __or__.__doc__ = operator.or_.__doc__ + __lshift__.__doc__ = operator.lshift.__doc__ + __rshift__.__doc__ = operator.rshift.__doc__ __lt__.__doc__ = operator.lt.__doc__ __le__.__doc__ = operator.le.__doc__ __gt__.__doc__ = operator.gt.__doc__ @@ -566,6 +612,8 @@ def conjugate(self, *args, **kwargs): __iand__.__doc__ = operator.iand.__doc__ __ixor__.__doc__ = operator.ixor.__doc__ __ior__.__doc__ = operator.ior.__doc__ + __ilshift__.__doc__ = operator.ilshift.__doc__ + __irshift__.__doc__ = operator.irshift.__doc__ __neg__.__doc__ = operator.neg.__doc__ __pos__.__doc__ = operator.pos.__doc__ __abs__.__doc__ = operator.abs.__doc__ @@ -612,6 +660,12 @@ def __xor__(self, other): def __or__(self, other): return self._binary_op(other, operator.or_) + def __lshift__(self, other): + return self._binary_op(other, operator.lshift) + + def __rshift__(self, other): + return self._binary_op(other, operator.rshift) + def __lt__(self, other): return self._binary_op(other, operator.lt) @@ -670,6 +724,8 @@ def __ror__(self, other): __and__.__doc__ = operator.and_.__doc__ __xor__.__doc__ = operator.xor.__doc__ 
__or__.__doc__ = operator.or_.__doc__ + __lshift__.__doc__ = operator.lshift.__doc__ + __rshift__.__doc__ = operator.rshift.__doc__ __lt__.__doc__ = operator.lt.__doc__ __le__.__doc__ = operator.le.__doc__ __gt__.__doc__ = operator.gt.__doc__ @@ -724,6 +780,12 @@ def __xor__(self, other): def __or__(self, other): return self._binary_op(other, operator.or_) + def __lshift__(self, other): + return self._binary_op(other, operator.lshift) + + def __rshift__(self, other): + return self._binary_op(other, operator.rshift) + def __lt__(self, other): return self._binary_op(other, operator.lt) @@ -782,6 +844,8 @@ def __ror__(self, other): __and__.__doc__ = operator.and_.__doc__ __xor__.__doc__ = operator.xor.__doc__ __or__.__doc__ = operator.or_.__doc__ + __lshift__.__doc__ = operator.lshift.__doc__ + __rshift__.__doc__ = operator.rshift.__doc__ __lt__.__doc__ = operator.lt.__doc__ __le__.__doc__ = operator.le.__doc__ __gt__.__doc__ = operator.gt.__doc__ diff --git a/xarray/core/_typed_ops.pyi b/xarray/core/_typed_ops.pyi index 98a17a47cd5..9e2ba2d3a06 100644 --- a/xarray/core/_typed_ops.pyi +++ b/xarray/core/_typed_ops.pyi @@ -44,6 +44,8 @@ class DatasetOpsMixin: def __and__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... def __xor__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... def __or__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... + def __lshift__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... + def __rshift__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... def __lt__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... def __le__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... def __gt__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... @@ -135,6 +137,18 @@ class DataArrayOpsMixin: @overload def __or__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... @overload + def __lshift__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __lshift__(self, other: "DatasetGroupBy") -> "Dataset": ... 
+ @overload + def __lshift__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... + @overload + def __rshift__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __rshift__(self, other: "DatasetGroupBy") -> "Dataset": ... + @overload + def __rshift__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... + @overload def __lt__(self, other: T_Dataset) -> T_Dataset: ... @overload def __lt__(self, other: "DatasetGroupBy") -> "Dataset": ... @@ -305,6 +319,18 @@ class VariableOpsMixin: @overload def __or__(self: T_Variable, other: VarCompatible) -> T_Variable: ... @overload + def __lshift__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __lshift__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __lshift__(self: T_Variable, other: VarCompatible) -> T_Variable: ... + @overload + def __rshift__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __rshift__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __rshift__(self: T_Variable, other: VarCompatible) -> T_Variable: ... + @overload def __lt__(self, other: T_Dataset) -> T_Dataset: ... @overload def __lt__(self, other: T_DataArray) -> T_DataArray: ... @@ -475,6 +501,18 @@ class DatasetGroupByOpsMixin: @overload def __or__(self, other: GroupByIncompatible) -> NoReturn: ... @overload + def __lshift__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __lshift__(self, other: "DataArray") -> "Dataset": ... + @overload + def __lshift__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __rshift__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __rshift__(self, other: "DataArray") -> "Dataset": ... + @overload + def __rshift__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload def __lt__(self, other: T_Dataset) -> T_Dataset: ... @overload def __lt__(self, other: "DataArray") -> "Dataset": ... 
@@ -635,6 +673,18 @@ class DataArrayGroupByOpsMixin: @overload def __or__(self, other: GroupByIncompatible) -> NoReturn: ... @overload + def __lshift__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __lshift__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __lshift__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __rshift__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __rshift__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __rshift__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload def __lt__(self, other: T_Dataset) -> T_Dataset: ... @overload def __lt__(self, other: T_DataArray) -> T_DataArray: ... diff --git a/xarray/core/ops.py b/xarray/core/ops.py index 009616c5f12..e1c3573841a 100644 --- a/xarray/core/ops.py +++ b/xarray/core/ops.py @@ -33,6 +33,8 @@ "and", "xor", "or", + "lshift", + "rshift", ] # methods which pass on the numpy return value unchanged diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index b37399d6ef8..1171464a962 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -178,6 +178,19 @@ def test_binary_op(self): self.assertLazyAndIdentical(u + u, v + v) self.assertLazyAndIdentical(u[0] + u, v[0] + v) + def test_binary_op_bitshift(self) -> None: + # bit shifts only work on ints so we need to generate + # new eager and lazy vars + rng = np.random.default_rng(0) + values = rng.integers(low=-10000, high=10000, size=(4, 6)) + data = da.from_array(values, chunks=(2, 2)) + u = Variable(("x", "y"), values) + v = Variable(("x", "y"), data) + self.assertLazyAndIdentical(u << 2, v << 2) + self.assertLazyAndIdentical(u << 5, v << 5) + self.assertLazyAndIdentical(u >> 2, v >> 2) + self.assertLazyAndIdentical(u >> 5, v >> 5) + def test_repr(self): expected = dedent( """\ diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index dcbfd42c9f1..ece061607bd 100644 --- a/xarray/tests/test_dataarray.py +++ 
b/xarray/tests/test_dataarray.py @@ -3926,6 +3926,11 @@ def test_binary_op_propagate_indexes(self) -> None: actual = (self.dv > 10).xindexes["x"] assert expected is actual + # use mda for bitshift test as it's type int + actual = (self.mda << 2).xindexes["x"] + expected = self.mda.xindexes["x"] + assert expected is actual + def test_binary_op_join_setting(self) -> None: dim = "x" align_type: Final = "outer" diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 4488bda3eca..48a9c17d2a8 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -808,6 +808,34 @@ def test_groupby_math_more() -> None: ds + ds.groupby("time.month") +def test_groupby_math_bitshift() -> None: + # create new dataset of int's only + ds = Dataset( + { + "x": ("index", np.ones(4, dtype=int)), + "y": ("index", np.ones(4, dtype=int) * -1), + "level": ("index", [1, 1, 2, 2]), + "index": [0, 1, 2, 3], + } + ) + shift = DataArray([1, 2, 1], [("level", [1, 2, 8])]) + + left_expected = Dataset( + { + "x": ("index", [2, 2, 4, 4]), + "y": ("index", [-2, -2, -4, -4]), + "level": ("index", [2, 2, 8, 8]), + "index": [0, 1, 2, 3], + } + ) + + left_actual = (ds.groupby("level") << shift).reset_coords(names="level") + assert_equal(left_expected, left_actual) + + right_actual = (left_expected.groupby("level") >> shift).reset_coords(names="level") + assert_equal(ds, right_actual) + + @pytest.mark.parametrize("use_flox", [True, False]) def test_groupby_bins_cut_kwargs(use_flox: bool) -> None: da = xr.DataArray(np.arange(12).reshape(6, 2), dims=("x", "y")) diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index b92db16e34b..bef5efc15cc 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -344,9 +344,10 @@ def test_pandas_period_index(self): assert v[0] == pd.Period("2000", freq="B") assert "Period('2000-01-03', 'B')" in repr(v) - def test_1d_math(self): - x = 1.0 * np.arange(5) - y = np.ones(5) + 
@pytest.mark.parametrize("dtype", [float, int]) + def test_1d_math(self, dtype: np.typing.DTypeLike) -> None: + x = np.arange(5, dtype=dtype) + y = np.ones(5, dtype=dtype) # should we need `.to_base_variable()`? # probably a break that `+v` changes type? @@ -360,11 +361,18 @@ def test_1d_math(self): assert_identical(base_v, v + 0) assert_identical(base_v, 0 + v) assert_identical(base_v, v * 1) + if dtype is int: + assert_identical(base_v, v << 0) + assert_array_equal(v << 3, x << 3) + assert_array_equal(v >> 2, x >> 2) # binary ops with numpy arrays assert_array_equal((v * x).values, x**2) - assert_array_equal((x * v).values, x**2) + assert_array_equal((x * v).values, x**2) # type: ignore[attr-defined] # TODO: Fix mypy thinking numpy takes priority, GH7780 assert_array_equal(v - y, v - 1) assert_array_equal(y - v, 1 - v) + if dtype is int: + assert_array_equal(v << x, x << x) + assert_array_equal(v >> x, x >> x) # verify attributes are dropped v2 = self.cls(["x"], x, {"units": "meters"}) with set_options(keep_attrs=False): @@ -378,10 +386,10 @@ def test_1d_math(self): # something complicated assert_array_equal((v**2 * w - 1 + x).values, x**2 * y - 1 + x) # make sure dtype is preserved (for Index objects) - assert float == (+v).dtype - assert float == (+v).values.dtype - assert float == (0 + v).dtype - assert float == (0 + v).values.dtype + assert dtype == (+v).dtype + assert dtype == (+v).values.dtype + assert dtype == (0 + v).dtype + assert dtype == (0 + v).values.dtype # check types of returned data assert isinstance(+v, Variable) assert not isinstance(+v, IndexVariable) diff --git a/xarray/util/generate_ops.py b/xarray/util/generate_ops.py index 02a3725f475..cf0673e7cca 100644 --- a/xarray/util/generate_ops.py +++ b/xarray/util/generate_ops.py @@ -30,6 +30,8 @@ ("__and__", "operator.and_"), ("__xor__", "operator.xor"), ("__or__", "operator.or_"), + ("__lshift__", "operator.lshift"), + ("__rshift__", "operator.rshift"), ) BINOPS_REFLEXIVE = ( ("__radd__", 
"operator.add"), @@ -54,6 +56,8 @@ ("__iand__", "operator.iand"), ("__ixor__", "operator.ixor"), ("__ior__", "operator.ior"), + ("__ilshift__", "operator.ilshift"), + ("__irshift__", "operator.irshift"), ) UNARY_OPS = ( ("__neg__", "operator.neg"),