diff --git a/doc/source/whatsnew/v2.0.0.rst b/doc/source/whatsnew/v2.0.0.rst index 7d26005315c33..0ba0e5bd32f77 100644 --- a/doc/source/whatsnew/v2.0.0.rst +++ b/doc/source/whatsnew/v2.0.0.rst @@ -14,6 +14,27 @@ including other versions of pandas. Enhancements ~~~~~~~~~~~~ +.. _whatsnew_200.enhancements.case_when: + +Assignment based on multiple conditions +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The ``pd.case_when`` API has now been added to support assignment based on multiple conditions. + +.. ipython:: python + + import pandas as pd + + df = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6])) + df.assign( + new_column=pd.case_when( + df, + lambda x: x.a == 1, 'first', + lambda x: (x.a > 1) & (x.b == 5), 'second', + default='default', + ) + ) + .. _whatsnew_200.enhancements.optional_dependency_management_pip: Installing optional dependencies with pip extras diff --git a/pandas/__init__.py b/pandas/__init__.py index d11a429987ac4..30d9f2b4b5a24 100644 --- a/pandas/__init__.py +++ b/pandas/__init__.py @@ -70,6 +70,7 @@ notnull, # indexes Index, + case_when, CategoricalIndex, RangeIndex, MultiIndex, @@ -238,6 +239,7 @@ __all__ = [ "ArrowDtype", "BooleanDtype", + "case_when", "Categorical", "CategoricalDtype", "CategoricalIndex", diff --git a/pandas/core/api.py b/pandas/core/api.py index 2cfe5ffc0170d..323ec234b9c7c 100644 --- a/pandas/core/api.py +++ b/pandas/core/api.py @@ -42,6 +42,7 @@ UInt64Dtype, ) from pandas.core.arrays.string_ import StringDtype +from pandas.core.case_when import case_when from pandas.core.construction import array from pandas.core.flags import Flags from pandas.core.groupby import ( @@ -80,11 +81,13 @@ # DataFrame needs to be imported after NamedAgg to avoid a circular import from pandas.core.frame import DataFrame # isort:skip + __all__ = [ "array", "ArrowDtype", "bdate_range", "BooleanDtype", + "case_when", "Categorical", "CategoricalDtype", "CategoricalIndex", diff --git a/pandas/core/case_when.py b/pandas/core/case_when.py new file mode 100644 
index 0000000000000..7394d05ffdaba
--- /dev/null
+++ b/pandas/core/case_when.py
@@ -0,0 +1,210 @@
+from __future__ import annotations
+
+from typing import Any
+import warnings
+
+from pandas.util._exceptions import find_stack_level
+
+import pandas as pd
+import pandas.core.common as com
+
+
+def warn_and_override_index(series, series_type, index):
+    """
+    Warn, then rebuild ``series`` positionally on ``index``.
+
+    The values keep their order and are paired one-to-one with ``index``;
+    the original labels of ``series`` are discarded.
+
+    NOTE(review): `case_when` below does not call this helper.
+
+    Parameters
+    ----------
+    series : Series
+        Series whose values should be re-labelled.
+    series_type : str
+        Role of ``series``; only interpolated into the warning message.
+    index : Index
+        Labels to attach to the values of ``series``.
+
+    Returns
+    -------
+    Series
+        New Series holding ``series.values`` indexed by ``index``.
+    """
+    warnings.warn(
+        f"Series {series_type} will be reindexed to match obj index.",
+        UserWarning,
+        stacklevel=find_stack_level(),
+    )
+    return pd.Series(series.values, index=index)
+
+
+def case_when(obj: pd.DataFrame | pd.Series, *args, default: Any) -> pd.Series:
+    """
+    Return a Series built from multiple condition/value pairs.
+
+    This is useful when you want to assign a column based on multiple conditions.
+    Uses `Series.mask` to perform the assignment.
+
+    If multiple conditions are met, the value of the first condition is taken.
+
+    The returned Series has the same index as `obj` and mainly follows the dtype
+    of `default`.
+
+    Parameters
+    ----------
+    obj : DataFrame or Series
+        Object on which the conditions will be applied.
+    *args : alternating conditions and values
+        Takes the form:
+        `condition0`, `value0`, `condition1`, `value1`, ...
+        `condition` can be a 1-D boolean array/series or a callable
+        that evaluates to a 1-D boolean array/series. Callables are called
+        with `obj`. See examples below.
+    default : Any
+        Keyword-only and required. The default value to be used if all
+        conditions evaluate False. This value will be used to create the
+        `Series` on which `Series.mask` will be called. If this value is not
+        already array-like (i.e. it is not of type `Series`, `np.array` or
+        `list`) it will be repeated `obj.shape[0]` times in order to create an
+        array-like object from it and then apply the `Series.mask`.
+
+    Returns
+    -------
+    Series
+        Series with the corresponding values based on the conditions, their values
+        and the default value.
+
+    Raises
+    ------
+    ValueError
+        If fewer than two positional arguments follow `obj`, or if the number
+        of conditions and values differ.
+
+    Examples
+    --------
+    >>> df = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6]))
+    >>> df
+       a  b
+    0  1  4
+    1  2  5
+    2  3  6
+
+    >>> pd.case_when(
+    ...     df,
+    ...     lambda x: x.a == 1,
+    ...     'first',
+    ...     lambda x: (x.a == 2) & (x.b == 5),
+    ...     'second',
+    ...     default='default',
+    ... )
+    0      first
+    1     second
+    2    default
+    dtype: object
+
+    >>> pd.case_when(
+    ...     df,
+    ...     lambda x: (x.a == 1) & (x.b == 4),
+    ...     df.b,
+    ...     default=0,
+    ... )
+    0    4
+    1    0
+    2    0
+    dtype: int64
+
+    >>> pd.case_when(
+    ...     df,
+    ...     lambda x: (x.a > 1) & (x.b > 1),
+    ...     -1,
+    ...     default=df.a,
+    ... )
+    0    1
+    1   -1
+    2   -1
+    Name: a, dtype: int64
+
+    >>> pd.case_when(
+    ...     df.a,
+    ...     lambda x: x == 1,
+    ...     -1,
+    ...     default=df.a,
+    ... )
+    0   -1
+    1    2
+    2    3
+    Name: a, dtype: int64
+
+    >>> pd.case_when(
+    ...     df.a,
+    ...     df.a > 1,
+    ...     -1,
+    ...     default=df.a,
+    ... )
+    0    1
+    1   -1
+    2   -1
+    Name: a, dtype: int64
+
+    The index will always follow that of `obj`. For example:
+
+    >>> df = pd.DataFrame(
+    ...     dict(a=[1, 2, 3], b=[4, 5, 6]),
+    ...     index=['index 1', 'index 2', 'index 3']
+    ... )
+    >>> df
+             a  b
+    index 1  1  4
+    index 2  2  5
+    index 3  3  6
+
+    >>> pd.case_when(
+    ...     df,
+    ...     lambda x: (x.a == 1) & (x.b == 4),
+    ...     df.b,
+    ...     default=0,
+    ... )
+    index 1    4
+    index 2    0
+    index 3    0
+    dtype: int64
+    """
+    len_args = len(args)
+
+    if len_args < 2:
+        raise ValueError("At least two arguments are required for `case_when`")
+    if len_args % 2:
+        raise ValueError(
+            "The number of conditions and values do not match. "
+            f"There are {len_args - len_args//2} conditions "
+            f"and {len_args//2} values."
+        )
+
+    # construct series on which we will apply `Series.mask`
+    series = pd.Series(default, index=obj.index)
+
+    # a series to keep track of which row got modified.
+    # we need this because if a row satisfied multiple conditions,
+    # we set the value of the first condition.
+    modified = pd.Series(False, index=obj.index)
+
+    # walk the (condition, replacement) pairs in the order they were given
+    for condition, replacement in zip(args[::2], args[1::2]):
+        # `apply_if_callable` returns non-callables untouched, so no
+        # explicit `callable` check is needed here
+        condition = com.apply_if_callable(condition, obj)
+
+        # wrap lists/arrays in a Series on `obj.index`: this validates the
+        # length up front and avoids Series-vs-list logical operations
+        if not isinstance(condition, pd.Series):
+            condition = pd.Series(condition, index=obj.index)
+
+        # only rows not claimed by an earlier condition may be replaced
+        series = series.mask(~modified & condition, replacement)
+
+        # keeping track of which row got modified
+        modified |= condition
+
+    return series
diff --git a/pandas/tests/api/test_api.py b/pandas/tests/api/test_api.py
index 60bcb97aaa364..ae4fe5d56ebf6 100644
--- a/pandas/tests/api/test_api.py
+++ b/pandas/tests/api/test_api.py
@@ -106,6 +106,7 @@ class TestPDApi(Base):
     funcs = [
         "array",
         "bdate_range",
+        "case_when",
         "concat",
         "crosstab",
         "cut",
diff --git a/pandas/tests/test_case_when.py b/pandas/tests/test_case_when.py
new file mode 100644
index 0000000000000..f8f54a8b83786
--- /dev/null
+++ b/pandas/tests/test_case_when.py
@@ -0,0 +1,78 @@
+import numpy as np
+import pytest
+
+import pandas as pd
+import pandas._testing as tm
+
+
+class TestCaseWhen:
+    def test_case_when_multiple_conditions_callable(self):
+        df = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6]))
+        result = df.assign(
+            new_column=pd.case_when(
+                df,
+                lambda x: x.a == 1,
+                1,
+                lambda x: (x.a > 1) & (x.b == 5),
+                2,
+                default=np.nan,
+            )
+        )
+        expected = df.assign(new_column=[1, 2, np.nan])
+        tm.assert_frame_equal(result, expected)
+
+    def test_case_when_multiple_conditions_array_series(self):
+        df = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6]))
+        result = df.assign(
+            new_column=pd.case_when(
+                df,
+                [True, False, False],
+                1,
+                pd.Series([False, True, False]),
+                2,
+                default=np.nan,
+            )
+        )
+        expected = df.assign(new_column=[1, 2, np.nan])
+        tm.assert_frame_equal(result, expected)
+
+    def test_case_when_multiple_conditions_callable_default(self):
+        df = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6]))
+        result = df.assign(
+            new_column=pd.case_when(
+                df,
+                lambda x: x.a == 1,
+                1,
+                lambda x: (x.a > 1) & (x.b == 5),
+                2,
+                default=-1,
+            )
+        )
+        expected = df.assign(new_column=[1, 2, -1])
+        tm.assert_frame_equal(result, expected)
+
+    def test_case_when_multiple_conditions_callable_default_series(self):
+        df = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6]))
+        result = df.assign(
+            new_column=pd.case_when(
+                df,
+                lambda x: x.a == 1,
+                1,
+                lambda x: (x.a > 1) & (x.b == 5),
+                2,
+                default=df.b,
+            )
+        )
+        expected = df.assign(new_column=[1, 2, 6])
+        tm.assert_frame_equal(result, expected)
+
+    def test_case_when_first_matching_condition_wins(self):
+        df = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6]))
+        result = pd.case_when(df, df.a >= 2, 1, df.a >= 1, 2, default=0)
+        expected = pd.Series([2, 1, 1])
+        tm.assert_series_equal(result, expected)
+
+    def test_case_when_unpaired_arguments_raise(self):
+        df = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6]))
+        with pytest.raises(ValueError, match="do not match"):
+            pd.case_when(df, df.a == 1, 1, df.a == 2, default=0)