Skip to content

Commit 199c947

Browse files
committed
Add case_when API
* Used to support conditional assignment operation.
1 parent fa78ea8 commit 199c947

File tree

6 files changed

+177
-0
lines changed

6 files changed

+177
-0
lines changed

doc/source/whatsnew/v2.0.0.rst

+20
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,26 @@ including other versions of pandas.
1414
Enhancements
1515
~~~~~~~~~~~~
1616

17+
.. _whatsnew_200.enhancements.case_when:
18+
19+
Assignment based on multiple conditions
20+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
21+
22+
The ``pd.case_when`` API has now been added to support assignment based on multiple conditions.
23+
24+
.. ipython:: python
25+
26+
import pandas as pd
27+
28+
df = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6]))
29+
df.assign(
30+
new_column=pd.case_when(
31+
lambda x: x.a == 1, 'first',
32+
lambda x: (x.a > 1) & (x.b == 5), 'second',
33+
default='default',
34+
)
35+
)
36+
1737
.. _whatsnew_200.enhancements.optional_dependency_management_pip:
1838

1939
Installing optional dependencies with pip extras

pandas/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
notnull,
7373
# indexes
7474
Index,
75+
case_when,
7576
CategoricalIndex,
7677
RangeIndex,
7778
MultiIndex,
@@ -231,6 +232,7 @@
231232
__all__ = [
232233
"ArrowDtype",
233234
"BooleanDtype",
235+
"case_when",
234236
"Categorical",
235237
"CategoricalDtype",
236238
"CategoricalIndex",

pandas/core/api.py

+3
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
UInt64Dtype,
4343
)
4444
from pandas.core.arrays.string_ import StringDtype
45+
from pandas.core.case_when import case_when
4546
from pandas.core.construction import array
4647
from pandas.core.flags import Flags
4748
from pandas.core.groupby import (
@@ -84,11 +85,13 @@
8485
# DataFrame needs to be imported after NamedAgg to avoid a circular import
8586
from pandas.core.frame import DataFrame # isort:skip
8687

88+
8789
__all__ = [
8890
"array",
8991
"ArrowDtype",
9092
"bdate_range",
9193
"BooleanDtype",
94+
"case_when",
9295
"Categorical",
9396
"CategoricalDtype",
9497
"CategoricalIndex",

pandas/core/case_when.py

+91
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
from __future__ import annotations
2+
3+
from typing import (
4+
Any,
5+
Callable,
6+
)
7+
8+
import numpy as np
9+
10+
from pandas._libs import lib
11+
12+
import pandas as pd
13+
import pandas.core.common as com
14+
15+
16+
def case_when(*args, default: Any = lib.no_default) -> Callable:
17+
"""
18+
Create a callable for assignment based on a condition or multiple conditions.
19+
20+
This is useful when you want to assign a column based on multiple conditions.
21+
22+
Parameters
23+
----------
24+
args : Variable argument of conditions and expected values.
25+
Takes the form:
26+
`condition0`, `value0`, `condition1`, `value1`, ...
27+
`condition` can be a 1-D boolean array/series or a callable
28+
that evaluate to a 1-D boolean array/series.
29+
default : Any, default is `None`.
30+
The default value to be used if all conditions evaluate False.
31+
32+
Returns
33+
-------
34+
Callable
35+
The Callable returned in `case_when` can be used with `df.assign(...)`
36+
for multi-condition assignment. See examples below for more info.
37+
38+
See Also
39+
--------
40+
DataFrame.assign: Assign new columns to a DataFrame.
41+
42+
Examples
43+
--------
44+
>>> df = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6]))
45+
>>> df
46+
a b
47+
0 1 4
48+
1 2 5
49+
2 3 6
50+
51+
>>> df.assign(
52+
... new_column = pd.case_when(
53+
... lambda x: x.a == 1, 'first',
54+
... lambda x: (x.a > 1) & (x.b == 5), 'second',
55+
... default='default'
56+
... )
57+
... )
58+
a b new_column
59+
0 1 4 first
60+
1 2 5 second
61+
2 3 6 default
62+
"""
63+
len_args = len(args)
64+
65+
if len_args < 2:
66+
raise ValueError("At least two arguments are required for `case_when`")
67+
if len_args % 2:
68+
raise ValueError(
69+
"The number of conditions and values do not match. "
70+
f"There are {len_args - len_args//2} conditions "
71+
f"and {len_args//2} values."
72+
)
73+
74+
if default is lib.no_default:
75+
default = None
76+
77+
def _eval(df: pd.DataFrame) -> np.ndarray:
78+
booleans = []
79+
replacements = []
80+
81+
for index, value in enumerate(args):
82+
if not index % 2:
83+
if callable(value):
84+
value = com.apply_if_callable(value, df)
85+
booleans.append(value)
86+
else:
87+
replacements.append(value)
88+
89+
return np.select(booleans, replacements, default=default)
90+
91+
return lambda df: _eval(df)

pandas/tests/api/test_api.py

+1
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ class TestPDApi(Base):
9999
funcs = [
100100
"array",
101101
"bdate_range",
102+
"case_when",
102103
"concat",
103104
"crosstab",
104105
"cut",

pandas/tests/test_case_when.py

+60
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import pytest # noqa
2+
3+
import pandas as pd
4+
import pandas._testing as tm
5+
6+
7+
class TestCaseWhen:
8+
def test_case_when_multiple_conditions_callable(self):
9+
df = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6]))
10+
result = df.assign(
11+
new_column=pd.case_when(
12+
lambda x: x.a == 1,
13+
"first",
14+
lambda x: (x.a > 1) & (x.b == 5),
15+
"second",
16+
)
17+
)
18+
expected = df.assign(new_column=["first", "second", None])
19+
tm.assert_frame_equal(result, expected)
20+
21+
def test_case_when_multiple_conditions_array_series(self):
22+
df = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6]))
23+
result = df.assign(
24+
new_column=pd.case_when(
25+
[True, False, False],
26+
"first",
27+
pd.Series([False, True, False]),
28+
"second",
29+
)
30+
)
31+
expected = df.assign(new_column=["first", "second", None])
32+
tm.assert_frame_equal(result, expected)
33+
34+
def test_case_when_multiple_conditions_callable_default(self):
35+
df = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6]))
36+
result = df.assign(
37+
new_column=pd.case_when(
38+
lambda x: x.a == 1,
39+
"first",
40+
lambda x: (x.a > 1) & (x.b == 5),
41+
"second",
42+
default="default",
43+
)
44+
)
45+
expected = df.assign(new_column=["first", "second", "default"])
46+
tm.assert_frame_equal(result, expected)
47+
48+
def test_case_when_multiple_conditions_callable_default_series(self):
49+
df = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6]))
50+
result = df.assign(
51+
new_column=pd.case_when(
52+
lambda x: x.a == 1,
53+
"first",
54+
lambda x: (x.a > 1) & (x.b == 5),
55+
"second",
56+
default=df.b,
57+
)
58+
)
59+
expected = df.assign(new_column=["first", "second", "6"])
60+
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)