diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index faec5ded04e..262c023059a 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -224,14 +224,17 @@ def empty_like(a, **kwargs): return xp.empty_like(a, **kwargs) -def astype(data, dtype, **kwargs): - if hasattr(data, "__array_namespace__"): +def astype(data, dtype, *, xp=None, **kwargs): + if not hasattr(data, "__array_namespace__") and xp is None: + return data.astype(dtype, **kwargs) + + if xp is None: xp = get_array_namespace(data) - if xp == np: - # numpy currently doesn't have a astype: - return data.astype(dtype, **kwargs) - return xp.astype(data, dtype, **kwargs) - return data.astype(dtype, **kwargs) + + if xp == np: + # numpy currently doesn't have a astype: + return data.astype(dtype, **kwargs) + return xp.astype(data, dtype, **kwargs) def asarray(data, xp=np, dtype=None): @@ -373,6 +376,13 @@ def sum_where(data, axis=None, dtype=None, where=None): def where(condition, x, y): """Three argument where() with better dtype promotion rules.""" xp = get_array_namespace(condition, x, y) + + dtype = xp.bool_ if hasattr(xp, "bool_") else xp.bool + if not is_duck_array(condition): + condition = asarray(condition, dtype=dtype, xp=xp) + else: + condition = astype(condition, dtype=dtype, xp=xp) + return xp.where(condition, *as_shared_dtype([x, y], xp=xp)) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index b28ba390a9f..6a3ce156ce6 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -524,7 +524,7 @@ def factorize(self) -> EncodedGroups: # Restore these after the raveling broadcasted_masks = broadcast(*masks) mask = functools.reduce(np.logical_or, broadcasted_masks) # type: ignore[arg-type] - _flatcodes = where(mask, -1, _flatcodes) + _flatcodes = where(mask.data, -1, _flatcodes) full_index = pd.MultiIndex.from_product( (grouper.full_index.values for grouper in groupers), diff --git a/xarray/tests/test_array_api.py b/xarray/tests/test_array_api.py index c273260d7dd..022d2e3750e 100644 --- a/xarray/tests/test_array_api.py +++ b/xarray/tests/test_array_api.py @@ -139,7 +139,6 @@ def test_unstack(arrays: tuple[xr.DataArray, xr.DataArray]) -> None: assert_equal(actual, expected) -@pytest.mark.skip def test_where() -> None: np_arr = xr.DataArray(np.array([1, 0]), dims="x") xp_arr = xr.DataArray(xp.asarray([1, 0]), dims="x")