Allow a function in .sortby method #8273

Merged
merged 6 commits into from
Oct 6, 2023
5 changes: 4 additions & 1 deletion doc/whats-new.rst
@@ -23,9 +23,12 @@ New Features
~~~~~~~~~~~~

- :py:meth:`DataArray.where` & :py:meth:`Dataset.where` accept a callable for
the ``other`` parameter, passing the object as the first argument. Previously,
the ``other`` parameter, passing the object as the only argument. Previously,
this was only valid for the ``cond`` parameter. (:issue:`8255`)
By `Maximilian Roos <https://github.com/max-sixty>`_.
- :py:meth:`DataArray.sortby` & :py:meth:`Dataset.sortby` accept a callable for
the ``variables`` parameter, passing the object as the only argument.
By `Maximilian Roos <https://github.com/max-sixty>`_.
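
The entry above can be illustrated with a short usage sketch. It assumes an xarray release that includes this change; the names `da`, `ds`, and the lambda parameters `obj` and `d` are illustrative and do not come from the PR.

```python
import numpy as np
import xarray as xr

# Illustrative data; not taken from the PR.
da = xr.DataArray(
    np.array([3, 1, 2]),
    coords={"x": ["c", "a", "b"]},
    dims=["x"],
)

# The callable receives the object itself and returns the sort key.
by_values = da.sortby(lambda obj: obj)

# Useful in a method chain, where the intermediate object has no name.
descending = (da * 10).sortby(lambda obj: -obj)

# Dataset.sortby accepts a callable in the same way.
ds = xr.Dataset({"A": ("x", [1, 2, 3])}, coords={"x": [10, 20, 30]})
ds_desc = ds.sortby(lambda d: -d["A"])
```

The main benefit is in method chains: the key can be derived from an intermediate result without first assigning that result to a variable.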

Breaking changes
~~~~~~~~~~~~~~~~
3 changes: 2 additions & 1 deletion xarray/core/common.py
@@ -1073,7 +1073,8 @@ def where(self, cond: Any, other: Any = dtypes.NA, drop: bool = False) -> Self:
----------
cond : DataArray, Dataset, or callable
Locations at which to preserve this object's values. dtype must be `bool`.
If a callable, it must expect this object as its only parameter.
If a callable, the callable is passed this object, and the result is used as
the value for cond.
other : scalar, DataArray, Dataset, or callable, optional
Value to use for locations in this object where ``cond`` is False.
By default, these locations are filled with NA. If a callable, it must
31 changes: 23 additions & 8 deletions xarray/core/dataarray.py
@@ -4921,7 +4921,10 @@ def dot(

def sortby(
self,
variables: Hashable | DataArray | Sequence[Hashable | DataArray],
variables: Hashable
| DataArray
| Sequence[Hashable | DataArray]
| Callable[[Self], Hashable | DataArray | Sequence[Hashable | DataArray]],
ascending: bool = True,
) -> Self:
"""Sort object by labels or values (along an axis).
@@ -4942,9 +4945,10 @@ def sortby(

Parameters
----------
variables : Hashable, DataArray, or sequence of Hashable or DataArray
1D DataArray objects or name(s) of 1D variable(s) in
coords whose values are used to sort this array.
variables : Hashable, DataArray, sequence of Hashable or DataArray, or Callable
1D DataArray objects or name(s) of 1D variable(s) in coords whose values are
used to sort this array. If a callable, the callable is passed this object,
and the result is used as the value for ``variables``.
ascending : bool, default: True
Whether to sort by ascending or descending order.

@@ -4964,22 +4968,33 @@ def sortby(
Examples
--------
>>> da = xr.DataArray(
... np.random.rand(5),
... np.arange(5, 0, -1),
... coords=[pd.date_range("1/1/2000", periods=5)],
... dims="time",
... )
>>> da
<xarray.DataArray (time: 5)>
array([0.5488135 , 0.71518937, 0.60276338, 0.54488318, 0.4236548 ])
array([5, 4, 3, 2, 1])
Coordinates:
* time (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2000-01-05

>>> da.sortby(da)
<xarray.DataArray (time: 5)>
array([0.4236548 , 0.54488318, 0.5488135 , 0.60276338, 0.71518937])
array([1, 2, 3, 4, 5])
Coordinates:
* time (time) datetime64[ns] 2000-01-05 2000-01-04 ... 2000-01-02
* time (time) datetime64[ns] 2000-01-05 2000-01-04 ... 2000-01-01

>>> da.sortby(lambda x: x)
<xarray.DataArray (time: 5)>
array([1, 2, 3, 4, 5])
Coordinates:
* time (time) datetime64[ns] 2000-01-05 2000-01-04 ... 2000-01-01
"""
# We need to convert the callable here rather than pass it through to the
# dataset method, since otherwise the dataset method would try to call the
# callable with the dataset as the object
if callable(variables):
variables = variables(self)
ds = self._to_temp_dataset().sortby(variables, ascending=ascending)
return self._from_temp_dataset(ds)
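
The comment in the hunk above explains why the callable is resolved before delegating to the dataset method. A minimal sketch of that reasoning follows, using hypothetical stand-in classes (`ToyDataset` and `ToyArray` are not xarray classes, and their `sortby` methods only return the resolved key instead of sorting):

```python
from typing import Any, Callable, Union


class ToyDataset:
    """Hypothetical stand-in for a dataset; not xarray's class."""

    def sortby(self, key: Union[str, Callable[["ToyDataset"], Any]]) -> Any:
        # The dataset-level method resolves a callable against itself.
        if callable(key):
            key = key(self)
        return key  # a real implementation would sort by this key


class ToyArray:
    """Hypothetical stand-in for an array that delegates to a temporary dataset."""

    def _to_temp_dataset(self) -> ToyDataset:
        return ToyDataset()

    def sortby_resolving_first(self, key: Any) -> Any:
        # Resolve the callable here so that it sees the array.
        if callable(key):
            key = key(self)
        return self._to_temp_dataset().sortby(key)

    def sortby_passing_through(self, key: Any) -> Any:
        # Without the early resolution, the callable sees the temporary dataset.
        return self._to_temp_dataset().sortby(key)


def describe(obj: Any) -> str:
    """Report which kind of object the callable was handed."""
    return type(obj).__name__


arr = ToyArray()
resolved = arr.sortby_resolving_first(describe)  # callable received the array
passed = arr.sortby_passing_through(describe)    # callable received the dataset
```

In this sketch `resolved` is `"ToyArray"` and `passed` is `"ToyDataset"`, which is the mismatch the early `callable(variables)` check avoids.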

26 changes: 20 additions & 6 deletions xarray/core/dataset.py
@@ -7824,7 +7824,10 @@ def roll(

def sortby(
self,
variables: Hashable | DataArray | list[Hashable | DataArray],
variables: Hashable
| DataArray
| Sequence[Hashable | DataArray]
| Callable[[Self], Hashable | DataArray | list[Hashable | DataArray]],
ascending: bool = True,
) -> Self:
"""
@@ -7846,9 +7849,10 @@ def sortby(

Parameters
----------
variables : Hashable, DataArray, or list of hashable or DataArray
1D DataArray objects or name(s) of 1D variable(s) in
coords/data_vars whose values are used to sort the dataset.
variables : Hashable, DataArray, sequence of Hashable or DataArray, or Callable
1D DataArray objects or name(s) of 1D variable(s) in coords/data_vars whose
values are used to sort this dataset. If a callable, the callable is passed
this object, and the result is used as the value for ``variables``.
ascending : bool, default: True
Whether to sort by ascending or descending order.

@@ -7874,8 +7878,7 @@ def sortby(
... },
... coords={"x": ["b", "a"], "y": [1, 0]},
... )
>>> ds = ds.sortby("x")
>>> ds
>>> ds.sortby("x")
<xarray.Dataset>
Dimensions: (x: 2, y: 2)
Coordinates:
@@ -7884,9 +7887,20 @@ def sortby(
Data variables:
A (x, y) int64 3 4 1 2
B (x, y) int64 7 8 5 6
>>> ds.sortby(lambda x: -x["y"])
<xarray.Dataset>
Dimensions: (x: 2, y: 2)
Coordinates:
* x (x) <U1 'b' 'a'
* y (y) int64 1 0
Data variables:
A (x, y) int64 1 2 3 4
B (x, y) int64 5 6 7 8
"""
from xarray.core.dataarray import DataArray

if callable(variables):
variables = variables(self)
if not isinstance(variables, list):
variables = [variables]
else: