Skip to content

Commit ab480b5

Browse files
committed
Add 'out' keyword to argmin/argmax methods - allow numpy call signature
When np.argmin(da) is called, numpy passes an 'out' keyword argument to argmin/argmax. Need to allow this argument to avoid errors (but an exception is thrown if out is not None).
1 parent cb6742d commit ab480b5

File tree

2 files changed

+23
-6
lines changed

2 files changed

+23
-6
lines changed

xarray/core/dataarray.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3728,6 +3728,7 @@ def argmin(
37283728
axis: Union[int, None] = None,
37293729
keep_attrs: bool = None,
37303730
skipna: bool = None,
3731+
out=None,
37313732
) -> Union["DataArray", Dict[Hashable, "DataArray"]]:
37323733
"""Indices of the minimum of the DataArray over one or more dimensions. Result
37333734
returned as dict of DataArrays, which can be passed directly to isel().
@@ -3752,6 +3753,9 @@ def argmin(
37523753
skips missing values for float dtypes; other dtypes either do not
37533754
have a sentinel missing value (int) or skipna=True has not been
37543755
implemented (object, datetime64 or timedelta64).
3756+
out : None
3757+
'out' should not be passed - provided for compatibility with numpy function
3758+
signature
37553759
37563760
Returns
37573761
-------
@@ -3812,7 +3816,7 @@ def argmin(
38123816
array([ 1, -5, 1])
38133817
Dimensions without coordinates: y
38143818
"""
3815-
result = self.variable.argmin(dim, axis, keep_attrs, skipna)
3819+
result = self.variable.argmin(dim, axis, keep_attrs, skipna, out)
38163820
if isinstance(result, dict):
38173821
return {k: self._replace_maybe_drop_dims(v) for k, v in result.items()}
38183822
else:
@@ -3824,6 +3828,7 @@ def argmax(
38243828
axis: Union[int, None] = None,
38253829
keep_attrs: bool = None,
38263830
skipna: bool = None,
3831+
out=None,
38273832
) -> Union["DataArray", Dict[Hashable, "DataArray"]]:
38283833
"""Indices of the maximum of the DataArray over one or more dimensions. Result
38293834
returned as dict of DataArrays, which can be passed directly to isel().
@@ -3848,6 +3853,9 @@ def argmax(
38483853
skips missing values for float dtypes; other dtypes either do not
38493854
have a sentinel missing value (int) or skipna=True has not been
38503855
implemented (object, datetime64 or timedelta64).
3856+
out : None
3857+
'out' should not be passed - provided for compatibility with numpy function
3858+
signature
38513859
38523860
Returns
38533861
-------
@@ -3909,7 +3917,7 @@ def argmax(
39093917
array([3, 5, 3])
39103918
Dimensions without coordinates: y
39113919
"""
3912-
result = self.variable.argmax(dim, axis, keep_attrs, skipna)
3920+
result = self.variable.argmax(dim, axis, keep_attrs, skipna, out)
39133921
if isinstance(result, dict):
39143922
return {k: self._replace_maybe_drop_dims(v) for k, v in result.items()}
39153923
else:

xarray/core/variable.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2086,6 +2086,7 @@ def _unravel_argminmax(
20862086
axis: Union[int, None],
20872087
keep_attrs: Optional[bool],
20882088
skipna: Optional[bool],
2089+
out,
20892090
) -> Union["Variable", Dict[Hashable, "Variable"]]:
20902091
"""Apply argmin or argmax over one or more dimensions, returning the result as a
20912092
dict of DataArray that can be passed directly to isel.
@@ -2110,7 +2111,7 @@ def _unravel_argminmax(
21102111
# Return int index if single dimension is passed, and is not part of a
21112112
# sequence
21122113
return getattr(self, "_injected_" + str(argminmax))(
2113-
dim=dim, axis=axis, keep_attrs=keep_attrs, skipna=skipna
2114+
dim=dim, axis=axis, keep_attrs=keep_attrs, skipna=skipna, out=out
21142115
)
21152116

21162117
# Get a name for the new dimension that does not conflict with any existing
@@ -2127,7 +2128,7 @@ def _unravel_argminmax(
21272128
reduce_shape = tuple(self.sizes[d] for d in dim)
21282129

21292130
result_flat_indices = getattr(stacked, "_injected_" + str(argminmax))(
2130-
axis=-1, skipna=skipna
2131+
axis=-1, skipna=skipna, out=out
21312132
)
21322133

21332134
result_unravelled_indices = np.unravel_index(result_flat_indices, reduce_shape)
@@ -2151,6 +2152,7 @@ def argmin(
21512152
axis: Union[int, None] = None,
21522153
keep_attrs: bool = None,
21532154
skipna: bool = None,
2155+
out=None,
21542156
) -> Union["Variable", Dict[Hashable, "Variable"]]:
21552157
"""Indices of the minimum of the DataArray over one or more dimensions. Result
21562158
returned as dict of DataArrays, which can be passed directly to isel().
@@ -2175,6 +2177,9 @@ def argmin(
21752177
skips missing values for float dtypes; other dtypes either do not
21762178
have a sentinel missing value (int) or skipna=True has not been
21772179
implemented (object, datetime64 or timedelta64).
2180+
out : None
2181+
'out' should not be passed - provided for compatibility with numpy function
2182+
signature
21782183
21792184
Returns
21802185
-------
@@ -2184,14 +2189,15 @@ def argmin(
21842189
--------
21852190
DataArray.argmin, DataArray.idxmin
21862191
"""
2187-
return self._unravel_argminmax("argmin", dim, axis, keep_attrs, skipna)
2192+
return self._unravel_argminmax("argmin", dim, axis, keep_attrs, skipna, out)
21882193

21892194
def argmax(
21902195
self,
21912196
dim: Union[Hashable, Sequence[Hashable]] = None,
21922197
axis: Union[int, None] = None,
21932198
keep_attrs: bool = None,
21942199
skipna: bool = None,
2200+
out=None,
21952201
) -> Union["Variable", Dict[Hashable, "Variable"]]:
21962202
"""Indices of the maximum of the DataArray over one or more dimensions. Result
21972203
returned as dict of DataArrays, which can be passed directly to isel().
@@ -2216,6 +2222,9 @@ def argmax(
22162222
skips missing values for float dtypes; other dtypes either do not
22172223
have a sentinel missing value (int) or skipna=True has not been
22182224
implemented (object, datetime64 or timedelta64).
2225+
out : None
2226+
'out' should not be passed - provided for compatibility with numpy function
2227+
signature
22192228
22202229
Returns
22212230
-------
@@ -2225,7 +2234,7 @@ def argmax(
22252234
--------
22262235
DataArray.argmax, DataArray.idxmax
22272236
"""
2228-
return self._unravel_argminmax("argmax", dim, axis, keep_attrs, skipna)
2237+
return self._unravel_argminmax("argmax", dim, axis, keep_attrs, skipna, out)
22292238

22302239

22312240
ops.inject_all_ops_and_reduce_methods(Variable)

0 commit comments

Comments
 (0)