diff --git a/doc/source/whatsnew/v0.24.0.rst b/doc/source/whatsnew/v0.24.0.rst index 16f327a9006da..9ecd7db873a4c 100644 --- a/doc/source/whatsnew/v0.24.0.rst +++ b/doc/source/whatsnew/v0.24.0.rst @@ -996,6 +996,7 @@ update the ``ExtensionDtype._metadata`` tuple to match the signature of your - :meth:`~pandas.api.types.ExtensionArray.repeat` has been added (:issue:`24349`) - ``ExtensionDtype`` has gained the ability to instantiate from string dtypes, e.g. ``decimal`` would instantiate a registered ``DecimalDtype``; furthermore the ``ExtensionDtype`` has gained the method ``construct_array_type`` (:issue:`21185`) +- :meth:`~pandas.api.types.ExtensionArray.searchsorted` has been added (:issue:`24350`) - An ``ExtensionArray`` with a boolean dtype now works correctly as a boolean indexer. :meth:`pandas.api.types.is_bool_dtype` now properly considers them boolean (:issue:`22326`) - Added ``ExtensionDtype._is_numeric`` for controlling whether an extension dtype is considered numeric (:issue:`22290`). - The ``ExtensionArray`` constructor, ``_from_sequence`` now take the keyword arg ``copy=False`` (:issue:`21185`) diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index 739c8174c7d87..2d4f8ca9c2cee 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -68,6 +68,7 @@ class ExtensionArray(object): * unique * factorize / _values_for_factorize * argsort / _values_for_argsort + * searchsorted The remaining methods implemented on this class should be performant, as they only compose abstract methods. Still, a more efficient @@ -518,6 +519,54 @@ def unique(self): uniques = unique(self.astype(object)) return self._from_sequence(uniques, dtype=self.dtype) + def searchsorted(self, value, side="left", sorter=None): + """ + Find indices where elements should be inserted to maintain order. + + .. versionadded:: 0.24.0 + + Find the indices into a sorted array `self` (a) such that, if the + corresponding elements in `v` were inserted before the indices, the + order of `self` would be preserved. + + Assuming that `a` is sorted: + + ====== ============================ + `side` returned index `i` satisfies + ====== ============================ + left ``self[i-1] < v <= self[i]`` + right ``self[i-1] <= v < self[i]`` + ====== ============================ + + Parameters + ---------- + value : array_like + Values to insert into `self`. + side : {'left', 'right'}, optional + If 'left', the index of the first suitable location found is given. + If 'right', return the last such index. If there is no suitable + index, return either 0 or N (where N is the length of `self`). + sorter : 1-D array_like, optional + Optional array of integer indices that sort array a into ascending + order. They are typically the result of argsort. + + Returns + ------- + indices : array of ints + Array of insertion points with the same shape as `value`. + + See Also + -------- + numpy.searchsorted : Similar method from NumPy. + """ + # Note: the base tests provided by pandas only test the basics. + # We do not test + # 1. Values outside the range of the `data_for_sorting` fixture + # 2. Values between the values in the `data_for_sorting` fixture + # 3. Missing values. + arr = self.astype(object) + return arr.searchsorted(value, side=side, sorter=sorter) + def _values_for_factorize(self): # type: () -> Tuple[ndarray, Any] """ diff --git a/pandas/core/arrays/sparse.py b/pandas/core/arrays/sparse.py index e4a8c21bbb839..7c8f58c9a3203 100644 --- a/pandas/core/arrays/sparse.py +++ b/pandas/core/arrays/sparse.py @@ -1169,6 +1169,16 @@ def _take_without_fill(self, indices): return taken + def searchsorted(self, v, side="left", sorter=None): + msg = "searchsorted requires high memory usage." + warnings.warn(msg, PerformanceWarning, stacklevel=2) + if not is_scalar(v): + v = np.asarray(v) + v = np.asarray(v) + return np.asarray(self, dtype=self.dtype.subtype).searchsorted( + v, side, sorter + ) + def copy(self, deep=False): if deep: values = self.sp_values.copy() diff --git a/pandas/core/base.py b/pandas/core/base.py index 9dc5b94f68392..b8ee50765e070 100644 --- a/pandas/core/base.py +++ b/pandas/core/base.py @@ -1464,7 +1464,7 @@ def factorize(self, sort=False, na_sentinel=-1): @Appender(_shared_docs['searchsorted']) def searchsorted(self, value, side='left', sorter=None): # needs coercion on the key (DatetimeIndex does already) - return self.values.searchsorted(value, side=side, sorter=sorter) + return self._values.searchsorted(value, side=side, sorter=sorter) def drop_duplicates(self, keep='first', inplace=False): inplace = validate_bool_kwarg(inplace, 'inplace') diff --git a/pandas/tests/extension/base/methods.py b/pandas/tests/extension/base/methods.py index 14cf28f59d50f..2c04c4cd99801 100644 --- a/pandas/tests/extension/base/methods.py +++ b/pandas/tests/extension/base/methods.py @@ -242,6 +242,31 @@ def test_hash_pandas_object_works(self, data, as_frame): b = pd.util.hash_pandas_object(data) self.assert_equal(a, b) + @pytest.mark.parametrize("as_series", [True, False]) + def test_searchsorted(self, data_for_sorting, as_series): + b, c, a = data_for_sorting + arr = type(data_for_sorting)._from_sequence([a, b, c]) + + if as_series: + arr = pd.Series(arr) + assert arr.searchsorted(a) == 0 + assert arr.searchsorted(a, side="right") == 1 + + assert arr.searchsorted(b) == 1 + assert arr.searchsorted(b, side="right") == 2 + + assert arr.searchsorted(c) == 2 + assert arr.searchsorted(c, side="right") == 3 + + result = arr.searchsorted(arr.take([0, 2])) + expected = np.array([0, 2], dtype=np.intp) + + tm.assert_numpy_array_equal(result, expected) + + # sorter + sorter = np.array([1, 2, 0]) + assert data_for_sorting.searchsorted(a, sorter=sorter) == 0 + @pytest.mark.parametrize("as_frame", [True, False]) def test_where_series(self, data, na_value, as_frame): assert data[0] != data[1] diff --git a/pandas/tests/extension/json/test_json.py b/pandas/tests/extension/json/test_json.py index a35997b07fd83..9ee131950f19c 100644 --- a/pandas/tests/extension/json/test_json.py +++ b/pandas/tests/extension/json/test_json.py @@ -232,6 +232,10 @@ def test_where_series(self, data, na_value): # with shapes (4,) (4,) (0,) super().test_where_series(data, na_value) + @pytest.mark.skip(reason="Can't compare dicts.") + def test_searchsorted(self, data_for_sorting): + super(TestMethods, self).test_searchsorted(data_for_sorting) + class TestCasting(BaseJSON, base.BaseCastingTests): @pytest.mark.skip(reason="failing on np.array(self, dtype=str)") diff --git a/pandas/tests/extension/test_categorical.py b/pandas/tests/extension/test_categorical.py index 6106bc3d58620..c876db416470c 100644 --- a/pandas/tests/extension/test_categorical.py +++ b/pandas/tests/extension/test_categorical.py @@ -189,6 +189,10 @@ def test_combine_add(self, data_repeated): def test_fillna_length_mismatch(self, data_missing): super().test_fillna_length_mismatch(data_missing) + def test_searchsorted(self, data_for_sorting): + if not data_for_sorting.ordered: + raise pytest.skip(reason="searchsorted requires ordered data.") + class TestCasting(base.BaseCastingTests): pass diff --git a/pandas/tests/extension/test_sparse.py b/pandas/tests/extension/test_sparse.py index ea849a78cda12..257eb44cd94fe 100644 --- a/pandas/tests/extension/test_sparse.py +++ b/pandas/tests/extension/test_sparse.py @@ -286,6 +286,12 @@ def test_combine_first(self, data): pytest.skip("TODO(SparseArray.__setitem__ will preserve dtype.") super(TestMethods, self).test_combine_first(data) + @pytest.mark.parametrize("as_series", [True, False]) + def test_searchsorted(self, data_for_sorting, as_series): + with tm.assert_produces_warning(PerformanceWarning): + super(TestMethods, self).test_searchsorted(data_for_sorting, + as_series=as_series) + class TestCasting(BaseSparseTests, base.BaseCastingTests): pass