Skip to content

Commit 45352da

Browse files
committed
fixup! Add case_when API * Used to support conditional assignment operation.
1 parent 9e0238f commit 45352da

File tree

2 files changed

+19
-16
lines changed

2 files changed

+19
-16
lines changed

pandas/core/case_when.py

+5-3
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@ def case_when(*args, default: Any = lib.no_default) -> Callable:
1818
Create a callable for assignment based on a condition or multiple conditions.
1919
2020
This is useful when you want to assign a column based on multiple conditions.
21+
Uses `np.select` to perform the assignment.
2122
2223
Parameters
2324
----------
@@ -26,8 +27,9 @@ def case_when(*args, default: Any = lib.no_default) -> Callable:
2627
`condition0`, `value0`, `condition1`, `value1`, ...
2728
`condition` can be a 1-D boolean array/series or a callable
2829
that evaluates to a 1-D boolean array/series.
29-
default : Any, default is `None`.
30-
The default value to be used if all conditions evaluate False.
30+
default : Any, default is `np.nan`.
31+
The default value to be used if all conditions evaluate to False. This value
32+
will be passed to the underlying `np.select` call.
3133
3234
Returns
3335
-------
@@ -72,7 +74,7 @@ def case_when(*args, default: Any = lib.no_default) -> Callable:
7274
)
7375

7476
if default is lib.no_default:
75-
default = None
77+
default = np.nan
7678

7779
def _eval(df: pd.DataFrame) -> np.ndarray:
7880
booleans = []

pandas/tests/test_case_when.py

+14-13
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
import numpy as np
12
import pytest # noqa
23

34
import pandas as pd
@@ -10,51 +11,51 @@ def test_case_when_multiple_conditions_callable(self):
1011
result = df.assign(
1112
new_column=pd.case_when(
1213
lambda x: x.a == 1,
13-
"first",
14+
1,
1415
lambda x: (x.a > 1) & (x.b == 5),
15-
"second",
16+
2,
1617
)
1718
)
18-
expected = df.assign(new_column=["first", "second", None])
19+
expected = df.assign(new_column=[1, 2, np.nan])
1920
tm.assert_frame_equal(result, expected)
2021

2122
def test_case_when_multiple_conditions_array_series(self):
2223
df = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6]))
2324
result = df.assign(
2425
new_column=pd.case_when(
2526
[True, False, False],
26-
"first",
27+
1,
2728
pd.Series([False, True, False]),
28-
"second",
29+
2,
2930
)
3031
)
31-
expected = df.assign(new_column=["first", "second", None])
32+
expected = df.assign(new_column=[1, 2, np.nan])
3233
tm.assert_frame_equal(result, expected)
3334

3435
def test_case_when_multiple_conditions_callable_default(self):
3536
df = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6]))
3637
result = df.assign(
3738
new_column=pd.case_when(
3839
lambda x: x.a == 1,
39-
"first",
40+
1,
4041
lambda x: (x.a > 1) & (x.b == 5),
41-
"second",
42-
default="default",
42+
2,
43+
default=-1,
4344
)
4445
)
45-
expected = df.assign(new_column=["first", "second", "default"])
46+
expected = df.assign(new_column=[1, 2, -1])
4647
tm.assert_frame_equal(result, expected)
4748

4849
def test_case_when_multiple_conditions_callable_default_series(self):
4950
df = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6]))
5051
result = df.assign(
5152
new_column=pd.case_when(
5253
lambda x: x.a == 1,
53-
"first",
54+
1,
5455
lambda x: (x.a > 1) & (x.b == 5),
55-
"second",
56+
2,
5657
default=df.b,
5758
)
5859
)
59-
expected = df.assign(new_column=["first", "second", "6"])
60+
expected = df.assign(new_column=[1, 2, 6])
6061
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)