Skip to content

Commit d8d87d2

Browse files
maaleskeshoyer
authored andcommitted
Allow passing of positional arguments in apply for Groupby objects (#2413)
* Allow passing of positional arguments to `func` in `XGroupBy.apply` * Fix the documentation call to `func` to show the actual call. * Add changes to whats-new.rst * Fix typo * Allow passing args to func from DatasetResample and DataArrayResample * Update whats-new with Resample changes * Add tests for func arguments in groupby * Add tests for apply func arguments in resample * flake8 fixes
1 parent 7fcb80f commit d8d87d2

File tree

7 files changed

+72
-10
lines changed

7 files changed

+72
-10
lines changed

doc/whats-new.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ Enhancements
6666
- :py:meth:`DataArray.resample` and :py:meth:`Dataset.resample` now supports the
6767
``loffset`` kwarg just like Pandas.
6868
By `Deepak Cherian <https://github.com/dcherian>`_
69+
- The `apply` methods for `DatasetGroupBy`, `DataArrayGroupBy`,
70+
`DatasetResample` and `DataArrayResample` can now pass positional arguments to
71+
the applied function.
72+
By `Matti Eskelinen <https://github.com/maaleske>`_.
6973
- 0d slices of ndarrays are now obtained directly through indexing, rather than
7074
extracting and wrapping a scalar, avoiding unnecessary copying. By `Daniel
7175
Wennberg <https://github.com/danielwe>`_.
@@ -260,13 +264,17 @@ Announcements of note:
260264
for more details.
261265
- We have a new :doc:`roadmap` that outlines our future development plans.
262266

267+
- `Dataset.apply` now properly documents the way `func` is called.
268+
By `Matti Eskelinen <https://github.com/maaleske>`_.
269+
263270
Enhancements
264271
~~~~~~~~~~~~
265272

266273
- :py:meth:`~xarray.DataArray.differentiate` and
267274
:py:meth:`~xarray.Dataset.differentiate` are newly added.
268275
(:issue:`1332`)
269276
By `Keisuke Fujii <https://github.com/fujiisoup>`_.
277+
270278
- Default colormap for sequential and divergent data can now be set via
271279
:py:func:`~xarray.set_options()`
272280
(:issue:`2394`)

xarray/core/dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2953,8 +2953,8 @@ def apply(self, func, keep_attrs=None, args=(), **kwargs):
29532953
Parameters
29542954
----------
29552955
func : function
2956-
Function which can be called in the form `f(x, **kwargs)` to
2957-
transform each DataArray `x` in this dataset into another
2956+
Function which can be called in the form `func(x, *args, **kwargs)`
2957+
to transform each DataArray `x` in this dataset into another
29582958
DataArray.
29592959
keep_attrs : bool, optional
29602960
If True, the dataset's attributes (`attrs`) will be copied from

xarray/core/groupby.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,7 @@ def lookup_order(dimension):
503503
new_order = sorted(stacked.dims, key=lookup_order)
504504
return stacked.transpose(*new_order)
505505

506-
def apply(self, func, shortcut=False, **kwargs):
506+
def apply(self, func, shortcut=False, args=(), **kwargs):
507507
"""Apply a function over each array in the group and concatenate them
508508
together into a new array.
509509
@@ -532,6 +532,8 @@ def apply(self, func, shortcut=False, **kwargs):
532532
If these conditions are satisfied `shortcut` provides significant
533533
speedup. This should be the case for many common groupby operations
534534
(e.g., applying numpy ufuncs).
535+
args : tuple, optional
536+
Positional arguments passed to `func`.
535537
**kwargs
536538
Used to call `func(ar, **kwargs)` for each array `ar`.
537539
@@ -544,7 +546,7 @@ def apply(self, func, shortcut=False, **kwargs):
544546
grouped = self._iter_grouped_shortcut()
545547
else:
546548
grouped = self._iter_grouped()
547-
applied = (maybe_wrap_array(arr, func(arr, **kwargs))
549+
applied = (maybe_wrap_array(arr, func(arr, *args, **kwargs))
548550
for arr in grouped)
549551
return self._combine(applied, shortcut=shortcut)
550552

@@ -642,7 +644,7 @@ def wrapped_func(self, dim=DEFAULT_DIMS, axis=None,
642644

643645

644646
class DatasetGroupBy(GroupBy, ImplementsDatasetReduce):
645-
def apply(self, func, **kwargs):
647+
def apply(self, func, args=(), **kwargs):
646648
"""Apply a function over each Dataset in the group and concatenate them
647649
together into a new Dataset.
648650
@@ -661,6 +663,8 @@ def apply(self, func, **kwargs):
661663
----------
662664
func : function
663665
Callable to apply to each sub-dataset.
666+
args : tuple, optional
667+
Positional arguments to pass to `func`.
664668
**kwargs
665669
Used to call `func(ds, **kwargs)` for each sub-dataset `ar`.
666670
@@ -670,7 +674,7 @@ def apply(self, func, **kwargs):
670674
The result of splitting, applying and combining this dataset.
671675
"""
672676
kwargs.pop('shortcut', None) # ignore shortcut if set (for now)
673-
applied = (func(ds, **kwargs) for ds in self._iter_grouped())
677+
applied = (func(ds, *args, **kwargs) for ds in self._iter_grouped())
674678
return self._combine(applied)
675679

676680
def _combine(self, applied):

xarray/core/resample.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def __init__(self, *args, **kwargs):
129129
"('{}')! ".format(self._resample_dim, self._dim))
130130
super(DataArrayResample, self).__init__(*args, **kwargs)
131131

132-
def apply(self, func, shortcut=False, **kwargs):
132+
def apply(self, func, shortcut=False, args=(), **kwargs):
133133
"""Apply a function over each array in the group and concatenate them
134134
together into a new array.
135135
@@ -158,6 +158,8 @@ def apply(self, func, shortcut=False, **kwargs):
158158
If these conditions are satisfied `shortcut` provides significant
159159
speedup. This should be the case for many common groupby operations
160160
(e.g., applying numpy ufuncs).
161+
args : tuple, optional
162+
Positional arguments passed on to `func`.
161163
**kwargs
162164
Used to call `func(ar, **kwargs)` for each array `ar`.
163165
@@ -167,7 +169,7 @@ def apply(self, func, shortcut=False, **kwargs):
167169
The result of splitting, applying and combining this array.
168170
"""
169171
combined = super(DataArrayResample, self).apply(
170-
func, shortcut=shortcut, **kwargs)
172+
func, shortcut=shortcut, args=args, **kwargs)
171173

172174
# If the aggregation function didn't drop the original resampling
173175
# dimension, then we need to do so before we can rename the proxy
@@ -240,7 +242,7 @@ def __init__(self, *args, **kwargs):
240242
"('{}')! ".format(self._resample_dim, self._dim))
241243
super(DatasetResample, self).__init__(*args, **kwargs)
242244

243-
def apply(self, func, **kwargs):
245+
def apply(self, func, args=(), **kwargs):
244246
"""Apply a function over each Dataset in the groups generated for
245247
resampling and concatenate them together into a new Dataset.
246248
@@ -259,6 +261,8 @@ def apply(self, func, **kwargs):
259261
----------
260262
func : function
261263
Callable to apply to each sub-dataset.
264+
args : tuple, optional
265+
Positional arguments passed on to `func`.
262266
**kwargs
263267
Used to call `func(ds, **kwargs)` for each sub-dataset `ar`.
264268
@@ -268,7 +272,7 @@ def apply(self, func, **kwargs):
268272
The result of splitting, applying and combining this dataset.
269273
"""
270274
kwargs.pop('shortcut', None) # ignore shortcut if set (for now)
271-
applied = (func(ds, **kwargs) for ds in self._iter_grouped())
275+
applied = (func(ds, *args, **kwargs) for ds in self._iter_grouped())
272276
combined = self._combine(applied)
273277

274278
return combined.rename({self._resample_dim: self._dim})

xarray/tests/test_dataarray.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2295,6 +2295,17 @@ def test_resample(self):
22952295
with raises_regex(ValueError, 'index must be monotonic'):
22962296
array[[2, 0, 1]].resample(time='1D')
22972297

2298+
def test_da_resample_func_args(self):
2299+
2300+
def func(arg1, arg2, arg3=0.):
2301+
return arg1.mean('time') + arg2 + arg3
2302+
2303+
times = pd.date_range('2000', periods=3, freq='D')
2304+
da = xr.DataArray([1., 1., 1.], coords=[times], dims=['time'])
2305+
expected = xr.DataArray([3., 3., 3.], coords=[times], dims=['time'])
2306+
actual = da.resample(time='D').apply(func, args=(1.,), arg3=1.)
2307+
assert_identical(actual, expected)
2308+
22982309
@requires_cftime
22992310
def test_resample_cftimeindex(self):
23002311
cftime = _import_cftime()

xarray/tests/test_dataset.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2886,6 +2886,19 @@ def test_resample_old_api(self):
28862886
with raises_regex(TypeError, r'resample\(\) no longer supports'):
28872887
ds.resample('1D', dim='time')
28882888

2889+
def test_ds_resample_apply_func_args(self):
2890+
2891+
def func(arg1, arg2, arg3=0.):
2892+
return arg1.mean('time') + arg2 + arg3
2893+
2894+
times = pd.date_range('2000', freq='D', periods=3)
2895+
ds = xr.Dataset({'foo': ('time', [1., 1., 1.]),
2896+
'time': times})
2897+
expected = xr.Dataset({'foo': ('time', [3., 3., 3.]),
2898+
'time': times})
2899+
actual = ds.resample(time='D').apply(func, args=(1.,), arg3=1.)
2900+
assert_identical(expected, actual)
2901+
28892902
def test_to_array(self):
28902903
ds = Dataset(OrderedDict([('a', 1), ('b', ('x', [1, 2, 3]))]),
28912904
coords={'c': 42}, attrs={'Conventions': 'None'})

xarray/tests/test_groupby.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,4 +85,26 @@ def test_groupby_input_mutation():
8585
assert_identical(array, array_copy) # should not modify inputs
8686

8787

88+
def test_da_groupby_apply_func_args():
89+
90+
def func(arg1, arg2, arg3=0):
91+
return arg1 + arg2 + arg3
92+
93+
array = xr.DataArray([1, 1, 1], [('x', [1, 2, 3])])
94+
expected = xr.DataArray([3, 3, 3], [('x', [1, 2, 3])])
95+
actual = array.groupby('x').apply(func, args=(1,), arg3=1)
96+
assert_identical(expected, actual)
97+
98+
99+
def test_ds_groupby_apply_func_args():
100+
101+
def func(arg1, arg2, arg3=0):
102+
return arg1 + arg2 + arg3
103+
104+
dataset = xr.Dataset({'foo': ('x', [1, 1, 1])}, {'x': [1, 2, 3]})
105+
expected = xr.Dataset({'foo': ('x', [3, 3, 3])}, {'x': [1, 2, 3]})
106+
actual = dataset.groupby('x').apply(func, args=(1,), arg3=1)
107+
assert_identical(expected, actual)
108+
109+
88110
# TODO: move other groupby tests from test_dataset and test_dataarray over here

0 commit comments

Comments
 (0)