Skip to content

Commit ff41988

Browse files
Scott Wales, shoyer
Scott Wales authored and shoyer committed
ENH: keepdims=True for xarray reductions (#3033)
* ENH: keepdims=True for xarray reductions

  Addresses #2170

  Add new option `keepdims` to xarray reduce operations, following the behaviour of NumPy.

  `keepdims` may be passed to reductions on either Datasets or DataArrays, and will result in the reduced dimensions still being present in the output with size 1.

  Coordinates that depend on the reduced dimensions will be removed from the Dataset/DataArray.

* Set the default to be `False`

* Correct lint error

* Apply suggestions from code review

  Co-Authored-By: Maximilian Roos <[email protected]>

* Add test for dask and fix implementation

* Move 'keepdims' up to where 'dims' is set

* Fix lint, add test for scalar variable
1 parent 724ad83 commit ff41988

File tree

7 files changed

+134
-8
lines changed

7 files changed

+134
-8
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ Enhancements
2222
~~~~~~~~~~~~
2323

2424

25+
- Add ``keepdims`` argument for reduce operations (:issue:`2170`)
26+
By `Scott Wales <https://github.com/ScottWales>`_.
2527
- netCDF chunksizes are now only dropped when original_shape is different,
2628
not when it isn't found. (:issue:`2207`)
2729
By `Karel van de Plassche <https://github.com/Karel-van-de-Plassche>`_.

xarray/core/dataarray.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -259,8 +259,14 @@ def _replace(self, variable=None, coords=None, name=__default):
259259
return type(self)(variable, coords, name=name, fastpath=True)
260260

261261
def _replace_maybe_drop_dims(self, variable, name=__default):
262-
if variable.dims == self.dims:
262+
if variable.dims == self.dims and variable.shape == self.shape:
263263
coords = self._coords.copy()
264+
elif variable.dims == self.dims:
265+
# Shape has changed (e.g. from reduce(..., keepdims=True))
266+
new_sizes = dict(zip(self.dims, variable.shape))
267+
coords = OrderedDict((k, v) for k, v in self._coords.items()
268+
if v.shape == tuple(new_sizes[d]
269+
for d in v.dims))
264270
else:
265271
allowed_dims = set(variable.dims)
266272
coords = OrderedDict((k, v) for k, v in self._coords.items()
@@ -1642,7 +1648,8 @@ def combine_first(self, other):
16421648
"""
16431649
return ops.fillna(self, other, join="outer")
16441650

1645-
def reduce(self, func, dim=None, axis=None, keep_attrs=None, **kwargs):
1651+
def reduce(self, func, dim=None, axis=None, keep_attrs=None,
1652+
keepdims=False, **kwargs):
16461653
"""Reduce this array by applying `func` along some dimension(s).
16471654
16481655
Parameters
@@ -1662,6 +1669,10 @@ def reduce(self, func, dim=None, axis=None, keep_attrs=None, **kwargs):
16621669
If True, the variable's attributes (`attrs`) will be copied from
16631670
the original object to the new one. If False (default), the new
16641671
object will be returned without attributes.
1672+
keepdims : bool, default False
1673+
If True, the dimensions which are reduced are left in the result
1674+
as dimensions of size one. Coordinates that use these dimensions
1675+
are removed.
16651676
**kwargs : dict
16661677
Additional keyword arguments passed on to `func`.
16671678
@@ -1672,7 +1683,8 @@ def reduce(self, func, dim=None, axis=None, keep_attrs=None, **kwargs):
16721683
summarized data and the indicated dimension(s) removed.
16731684
"""
16741685

1675-
var = self.variable.reduce(func, dim, axis, keep_attrs, **kwargs)
1686+
var = self.variable.reduce(func, dim, axis, keep_attrs, keepdims,
1687+
**kwargs)
16761688
return self._replace_maybe_drop_dims(var)
16771689

16781690
def to_pandas(self):

xarray/core/dataset.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3152,8 +3152,8 @@ def combine_first(self, other):
31523152
out = ops.fillna(self, other, join="outer", dataset_join="outer")
31533153
return out
31543154

3155-
def reduce(self, func, dim=None, keep_attrs=None, numeric_only=False,
3156-
allow_lazy=False, **kwargs):
3155+
def reduce(self, func, dim=None, keep_attrs=None, keepdims=False,
3156+
numeric_only=False, allow_lazy=False, **kwargs):
31573157
"""Reduce this dataset by applying `func` along some dimension(s).
31583158
31593159
Parameters
@@ -3169,6 +3169,10 @@ def reduce(self, func, dim=None, keep_attrs=None, numeric_only=False,
31693169
If True, the dataset's attributes (`attrs`) will be copied from
31703170
the original object to the new one. If False (default), the new
31713171
object will be returned without attributes.
3172+
keepdims : bool, default False
3173+
If True, the dimensions which are reduced are left in the result
3174+
as dimensions of size one. Coordinates that use these dimensions
3175+
are removed.
31723176
numeric_only : bool, optional
31733177
If True, only apply ``func`` to variables with a numeric dtype.
31743178
**kwargs : dict
@@ -3218,6 +3222,7 @@ def reduce(self, func, dim=None, keep_attrs=None, numeric_only=False,
32183222
reduce_dims = None
32193223
variables[name] = var.reduce(func, dim=reduce_dims,
32203224
keep_attrs=keep_attrs,
3225+
keepdims=keepdims,
32213226
allow_lazy=allow_lazy,
32223227
**kwargs)
32233228

xarray/core/variable.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1334,7 +1334,7 @@ def where(self, cond, other=dtypes.NA):
13341334
return ops.where_method(self, cond, other)
13351335

13361336
def reduce(self, func, dim=None, axis=None,
1337-
keep_attrs=None, allow_lazy=False, **kwargs):
1337+
keep_attrs=None, keepdims=False, allow_lazy=False, **kwargs):
13381338
"""Reduce this array by applying `func` along some dimension(s).
13391339
13401340
Parameters
@@ -1354,6 +1354,9 @@ def reduce(self, func, dim=None, axis=None,
13541354
If True, the variable's attributes (`attrs`) will be copied from
13551355
the original object to the new one. If False (default), the new
13561356
object will be returned without attributes.
1357+
keepdims : bool, default False
1358+
If True, the dimensions which are reduced are left in the result
1359+
as dimensions of size one.
13571360
**kwargs : dict
13581361
Additional keyword arguments passed on to `func`.
13591362
@@ -1381,8 +1384,19 @@ def reduce(self, func, dim=None, axis=None,
13811384
else:
13821385
removed_axes = (range(self.ndim) if axis is None
13831386
else np.atleast_1d(axis) % self.ndim)
1384-
dims = [adim for n, adim in enumerate(self.dims)
1385-
if n not in removed_axes]
1387+
if keepdims:
1388+
# Insert np.newaxis for removed dims
1389+
slices = tuple(np.newaxis if i in removed_axes else
1390+
slice(None, None) for i in range(self.ndim))
1391+
if getattr(data, 'shape', None) is None:
1392+
# Reduce has produced a scalar value, not an array-like
1393+
data = np.asanyarray(data)[slices]
1394+
else:
1395+
data = data[slices]
1396+
dims = self.dims
1397+
else:
1398+
dims = [adim for n, adim in enumerate(self.dims)
1399+
if n not in removed_axes]
13861400

13871401
if keep_attrs is None:
13881402
keep_attrs = _get_keep_attrs(default=False)

xarray/tests/test_dataarray.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1991,6 +1991,44 @@ def test_reduce(self):
19911991
dims=['x', 'y']).mean('x')
19921992
assert_equal(actual, expected)
19931993

1994+
def test_reduce_keepdims(self):
1995+
coords = {'x': [-1, -2], 'y': ['ab', 'cd', 'ef'],
1996+
'lat': (['x', 'y'], [[1, 2, 3], [-1, -2, -3]]),
1997+
'c': -999}
1998+
orig = DataArray([[-1, 0, 1], [-3, 0, 3]], coords, dims=['x', 'y'])
1999+
2000+
# Mean on all axes loses non-constant coordinates
2001+
actual = orig.mean(keepdims=True)
2002+
expected = DataArray(orig.data.mean(keepdims=True), dims=orig.dims,
2003+
coords={k: v for k, v in coords.items()
2004+
if k in ['c']})
2005+
assert_equal(actual, expected)
2006+
2007+
assert actual.sizes['x'] == 1
2008+
assert actual.sizes['y'] == 1
2009+
2010+
# Mean on specific axes loses the coordinates that involve those axes
2011+
actual = orig.mean('y', keepdims=True)
2012+
expected = DataArray(orig.data.mean(axis=1, keepdims=True),
2013+
dims=orig.dims,
2014+
coords={k: v for k, v in coords.items()
2015+
if k not in ['y', 'lat']})
2016+
assert_equal(actual, expected)
2017+
2018+
@requires_bottleneck
2019+
def test_reduce_keepdims_bottleneck(self):
2020+
import bottleneck
2021+
2022+
coords = {'x': [-1, -2], 'y': ['ab', 'cd', 'ef'],
2023+
'lat': (['x', 'y'], [[1, 2, 3], [-1, -2, -3]]),
2024+
'c': -999}
2025+
orig = DataArray([[-1, 0, 1], [-3, 0, 3]], coords, dims=['x', 'y'])
2026+
2027+
# Bottleneck does not have its own keepdims implementation
2028+
actual = orig.reduce(bottleneck.nanmean, keepdims=True)
2029+
expected = orig.mean(keepdims=True)
2030+
assert_equal(actual, expected)
2031+
19942032
def test_reduce_dtype(self):
19952033
coords = {'x': [-1, -2], 'y': ['ab', 'cd', 'ef'],
19962034
'lat': (['x', 'y'], [[1, 2, 3], [-1, -2, -3]]),

xarray/tests/test_dataset.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3898,6 +3898,25 @@ def total_sum(x):
38983898
with raises_regex(TypeError, "unexpected keyword argument 'axis'"):
38993899
ds.reduce(total_sum, dim='x')
39003900

3901+
def test_reduce_keepdims(self):
3902+
ds = Dataset({'a': (['x', 'y'], [[0, 1, 2, 3, 4]])},
3903+
coords={'y': [0, 1, 2, 3, 4], 'x': [0],
3904+
'lat': (['x', 'y'], [[0, 1, 2, 3, 4]]),
3905+
'c': -999.0})
3906+
3907+
# Shape should match behaviour of numpy reductions with keepdims=True
3908+
# Coordinates involved in the reduction should be removed
3909+
actual = ds.mean(keepdims=True)
3910+
expected = Dataset({'a': (['x', 'y'], np.mean(ds.a, keepdims=True))},
3911+
coords={'c': ds.c})
3912+
assert_identical(expected, actual)
3913+
3914+
actual = ds.mean('x', keepdims=True)
3915+
expected = Dataset({'a': (['x', 'y'],
3916+
np.mean(ds.a, axis=0, keepdims=True))},
3917+
coords={'y': ds.y, 'c': ds.c})
3918+
assert_identical(expected, actual)
3919+
39013920
def test_quantile(self):
39023921

39033922
ds = create_test_data(seed=123)

xarray/tests/test_variable.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1540,6 +1540,42 @@ def test_reduce_funcs(self):
15401540
assert_identical(
15411541
v.max(), Variable([], pd.Timestamp('2000-01-03')))
15421542

1543+
def test_reduce_keepdims(self):
1544+
v = Variable(['x', 'y'], self.d)
1545+
1546+
assert_identical(v.mean(keepdims=True),
1547+
Variable(v.dims, np.mean(self.d, keepdims=True)))
1548+
assert_identical(v.mean(dim='x', keepdims=True),
1549+
Variable(v.dims, np.mean(self.d, axis=0,
1550+
keepdims=True)))
1551+
assert_identical(v.mean(dim='y', keepdims=True),
1552+
Variable(v.dims, np.mean(self.d, axis=1,
1553+
keepdims=True)))
1554+
assert_identical(v.mean(dim=['y', 'x'], keepdims=True),
1555+
Variable(v.dims, np.mean(self.d, axis=(1, 0),
1556+
keepdims=True)))
1557+
1558+
v = Variable([], 1.0)
1559+
assert_identical(v.mean(keepdims=True),
1560+
Variable([], np.mean(v.data, keepdims=True)))
1561+
1562+
@requires_dask
1563+
def test_reduce_keepdims_dask(self):
1564+
import dask.array
1565+
v = Variable(['x', 'y'], self.d).chunk()
1566+
1567+
actual = v.mean(keepdims=True)
1568+
assert isinstance(actual.data, dask.array.Array)
1569+
1570+
expected = Variable(v.dims, np.mean(self.d, keepdims=True))
1571+
assert_identical(actual, expected)
1572+
1573+
actual = v.mean(dim='y', keepdims=True)
1574+
assert isinstance(actual.data, dask.array.Array)
1575+
1576+
expected = Variable(v.dims, np.mean(self.d, axis=1, keepdims=True))
1577+
assert_identical(actual, expected)
1578+
15431579
def test_reduce_keep_attrs(self):
15441580
_attrs = {'units': 'test', 'long_name': 'testing'}
15451581

0 commit comments

Comments
 (0)