diff --git a/doc/source/whatsnew/v2.0.0.rst b/doc/source/whatsnew/v2.0.0.rst index d1b965e64e43b..01aad7eae4d79 100644 --- a/doc/source/whatsnew/v2.0.0.rst +++ b/doc/source/whatsnew/v2.0.0.rst @@ -314,7 +314,7 @@ Other enhancements - Added new argument ``engine`` to :func:`read_json` to support parsing JSON with pyarrow by specifying ``engine="pyarrow"`` (:issue:`48893`) - Added support for SQLAlchemy 2.0 (:issue:`40686`) - :class:`Index` set operations :meth:`Index.union`, :meth:`Index.intersection`, :meth:`Index.difference`, and :meth:`Index.symmetric_difference` now support ``sort=True``, which will always return a sorted result, unlike the default ``sort=None`` which does not sort in some cases (:issue:`25151`) -- +- Added :meth:`Index.filter` that allows filtering on :class:`Index` to create a new :class:`Index` object .. --------------------------------------------------------------------------- .. _whatsnew_200.notable_bug_fixes: diff --git a/pandas/core/generic.py b/pandas/core/generic.py index d76648558bc6e..0938dcb97b7ff 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -5473,6 +5473,8 @@ def filter( -------- DataFrame.loc : Access a group of rows and columns by label(s) or a boolean array. + Index.filter : Create a subset of the index according to the specified index + labels. Notes ----- diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index 8adb7cca45323..742c8952bd957 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -4,6 +4,7 @@ import functools from itertools import zip_longest import operator +import re from typing import ( TYPE_CHECKING, Any, @@ -88,6 +89,7 @@ ensure_int64, ensure_object, ensure_platform_int, + ensure_str, is_any_real_numeric_dtype, is_bool_dtype, is_categorical_dtype, @@ -155,6 +157,7 @@ PandasObject, ) import pandas.core.common as com +from pandas.core.common import count_not_none from pandas.core.construction import ( ensure_wrapped_if_datetimelike, extract_array, @@ -6150,6 +6153,76 @@ def map(self, mapper, na_action=None): return Index._with_infer(new_values, dtype=dtype, copy=False, name=self.name) + def filter( + self: _IndexT, + items: Axes | None = None, + like: str_t | None = None, + regex: str_t | None = None, + ) -> _IndexT: + """ + Create a subset of the index according to the specified index labels. + + Parameters + ---------- + items : list-like + Keep labels from index which are in items. + like : str + Keep labels from index for which "like in label == True". + regex : str (regular expression) + Keep labels from index for which re.search(regex, label) == True. + + Returns + ------- + same type as input object + + See Also + -------- + DataFrame.filter : Subset the dataframe rows or columns according to the + specified index labels. + + Notes + ----- + The ``items``, ``like``, and ``regex`` parameters are + enforced to be mutually exclusive. + + Examples + -------- + >>> idx = pd.Index(["cat", "dog", "bat", "bird"]) + >>> idx + Index(['cat', 'dog', 'bat', 'bird'], dtype='object') + + >>> # select index values by name + >>> idx.filter(items=['cat', 'dog']) + Index(['cat', 'dog'], dtype='object') + + >>> # select index values by regular expression + >>> idx.filter(regex=r'b.*') + Index(['bat', 'bird'], dtype='object') + + >>> # select index values containing 'at' + >>> idx.filter(like='at') + Index(['cat', 'bat'], dtype='object') + """ + nkw = count_not_none(items, like, regex) + if nkw > 1: + raise TypeError( + "Keyword arguments `items`, `like`, or `regex` " + "are mutually exclusive" + ) + + if items is not None: + mask = [r in items for r in self] + elif like: + mask = [like in ensure_str(r) for r in self] + elif regex: + matcher = re.compile(regex) + mask = [matcher.search(ensure_str(r)) is not None for r in self] + else: + raise TypeError("Must pass either `items`, `like`, or `regex`") + boolmask = np.array(mask, dtype=np.bool_) + res_values = self._data[boolmask] + return type(self)._simple_new(res_values, name=self._name) + # TODO: De-duplicate with map, xref GH#32349 @final def _transform_index(self, func, *, level=None) -> Index: diff --git a/pandas/tests/indexes/test_base.py b/pandas/tests/indexes/test_base.py index 783cf76403059..ae1f51efcf723 100644 --- a/pandas/tests/indexes/test_base.py +++ b/pandas/tests/indexes/test_base.py @@ -1285,6 +1285,34 @@ def test_sortlevel(self): result = index.sortlevel(ascending=False) tm.assert_index_equal(result[0], expected) + @pytest.mark.parametrize( + "dtype", + ["object", pd.StringDtype()], + ) + def test_filter_string(self, dtype): + idx = Index(["cat", "dog", "bat", "bird"], dtype=dtype) + result = idx.filter(["cat", "dog"]) + expected = Index(["cat", "dog"], dtype=dtype) + tm.assert_index_equal(result, expected) + + result = idx.filter(like="at") + expected = Index(["cat", "bat"], dtype=dtype) + tm.assert_index_equal(result, expected) + + result = idx.filter(regex=r"b.*") + expected = Index(["bat", "bird"], dtype=dtype) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "dtype", + ["int", pd.Int64Dtype()], + ) + def test_filter_int(self, dtype): + idx = Index([1, 2, 3, 4, 5], dtype=dtype) + result = idx.filter(range(2, 4)) + expected = Index([2, 3], dtype=dtype) + tm.assert_index_equal(result, expected) + class TestMixedIntIndex(Base): # Mostly the tests from common.py for which the results differ