diff --git a/pandas/core/arrays/categorical.py b/pandas/core/arrays/categorical.py index cb8a08f5668ac..95d9409b265ce 100644 --- a/pandas/core/arrays/categorical.py +++ b/pandas/core/arrays/categorical.py @@ -2453,7 +2453,9 @@ def replace(self, to_replace, value, inplace: bool = False): # ------------------------------------------------------------------------ # String methods interface - def _str_map(self, f, na_value=np.nan, dtype=np.dtype("object")): + def _str_map( + self, f, na_value=np.nan, dtype=np.dtype("object"), convert: bool = True + ): # Optimization to apply the callable `f` to the categories once # and rebuild the result by `take`ing from the result with the codes. # Returns the same type as the object-dtype implementation though. diff --git a/pandas/core/arrays/string_.py b/pandas/core/arrays/string_.py index 74ca5130ca322..ab1dadf4d2dfa 100644 --- a/pandas/core/arrays/string_.py +++ b/pandas/core/arrays/string_.py @@ -410,7 +410,9 @@ def _cmp_method(self, other, op): # String methods interface _str_na_value = StringDtype.na_value - def _str_map(self, f, na_value=None, dtype: Dtype | None = None): + def _str_map( + self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True + ): from pandas.arrays import BooleanArray if dtype is None: diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py index 4370f3a4e15cf..454d8ebde989b 100644 --- a/pandas/core/arrays/string_arrow.py +++ b/pandas/core/arrays/string_arrow.py @@ -741,7 +741,9 @@ def value_counts(self, dropna: bool = True) -> Series: _str_na_value = ArrowStringDtype.na_value - def _str_map(self, f, na_value=None, dtype: Dtype | None = None): + def _str_map( + self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True + ): # TODO: de-duplicate with StringArray method. This method is moreless copy and # paste. diff --git a/pandas/core/strings/accessor.py b/pandas/core/strings/accessor.py index e1399968cb1c4..7643019ff8c55 100644 --- a/pandas/core/strings/accessor.py +++ b/pandas/core/strings/accessor.py @@ -13,10 +13,7 @@ import numpy as np import pandas._libs.lib as lib -from pandas._typing import ( - ArrayLike, - FrameOrSeriesUnion, -) +from pandas._typing import FrameOrSeriesUnion from pandas.util._decorators import Appender from pandas.core.dtypes.common import ( @@ -160,7 +157,6 @@ class StringMethods(NoNewAttributesMixin): # TODO: Dispatch all the methods # Currently the following are not dispatched to the array # * cat - # * extract # * extractall def __init__(self, data): @@ -243,7 +239,7 @@ def _wrap_result( self, result, name=None, - expand=None, + expand: bool | None = None, fill_value=np.nan, returns_string=True, ): @@ -2385,10 +2381,7 @@ def extract( 2 NaN dtype: object """ - from pandas import ( - DataFrame, - array as pd_array, - ) + from pandas import DataFrame if not isinstance(expand, bool): raise ValueError("expand must be True or False") @@ -2400,8 +2393,6 @@ def extract( if not expand and regex.groups > 1 and isinstance(self._data, ABCIndex): raise ValueError("only one regex group is supported with Index") - # TODO: dispatch - obj = self._data result_dtype = _result_dtype(obj) @@ -2415,8 +2406,8 @@ def extract( result = DataFrame(columns=columns, dtype=result_dtype) else: - result_list = _str_extract( - obj.array, pat, flags=flags, expand=returns_df + result_list = self._data.array._str_extract( + pat, flags=flags, expand=returns_df ) result_index: Index | None @@ -2431,9 +2422,7 @@ def extract( else: name = _get_single_group_name(regex) - result_arr = _str_extract(obj.array, pat, flags=flags, expand=returns_df) - # not dispatching, so we have to reconstruct here. - result = pd_array(result_arr, dtype=result_dtype) + result = self._data.array._str_extract(pat, flags=flags, expand=returns_df) return self._wrap_result(result, name=name) @forbid_nonstring_types(["bytes"]) @@ -3121,33 +3110,6 @@ def _get_group_names(regex: re.Pattern) -> list[Hashable]: return [names.get(1 + i, i) for i in range(regex.groups)] -def _str_extract(arr: ArrayLike, pat: str, flags=0, expand: bool = True): - """ - Find groups in each string in the array using passed regular expression. - - Returns - ------- - np.ndarray or list of lists is expand is True - """ - regex = re.compile(pat, flags=flags) - - empty_row = [np.nan] * regex.groups - - def f(x): - if not isinstance(x, str): - return empty_row - m = regex.search(x) - if m: - return [np.nan if item is None else item for item in m.groups()] - else: - return empty_row - - if expand: - return [f(val) for val in np.asarray(arr)] - - return np.array([f(val)[0] for val in np.asarray(arr)], dtype=object) - - def str_extractall(arr, pat, flags=0): regex = re.compile(pat, flags=flags) # the regex must contain capture groups. diff --git a/pandas/core/strings/base.py b/pandas/core/strings/base.py index 730870b448cb2..cd71844d3b527 100644 --- a/pandas/core/strings/base.py +++ b/pandas/core/strings/base.py @@ -230,3 +230,7 @@ def _str_split(self, pat=None, n=-1, expand=False): @abc.abstractmethod def _str_rsplit(self, pat=None, n=-1): pass + + @abc.abstractmethod + def _str_extract(self, pat: str, flags: int = 0, expand: bool = True): + pass diff --git a/pandas/core/strings/object_array.py b/pandas/core/strings/object_array.py index c214ada9c1ada..7ce4abe904f3b 100644 --- a/pandas/core/strings/object_array.py +++ b/pandas/core/strings/object_array.py @@ -32,7 +32,9 @@ def __len__(self): # For typing, _str_map relies on the object being sized. raise NotImplementedError - def _str_map(self, f, na_value=None, dtype: Dtype | None = None): + def _str_map( + self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True + ): """ Map a callable over valid element of the array. @@ -47,6 +49,8 @@ def _str_map(self, f, na_value=None, dtype: Dtype | None = None): for object-dtype and Categorical and ``pd.NA`` for StringArray. dtype : Dtype, optional The dtype of the result array. + convert : bool, default True + Whether to call `maybe_convert_objects` on the resulting ndarray """ if dtype is None: dtype = np.dtype("object") @@ -60,9 +64,9 @@ def _str_map(self, f, na_value=None, dtype: Dtype | None = None): arr = np.asarray(self, dtype=object) mask = isna(arr) - convert = not np.all(mask) + map_convert = convert and not np.all(mask) try: - result = lib.map_infer_mask(arr, f, mask.view(np.uint8), convert) + result = lib.map_infer_mask(arr, f, mask.view(np.uint8), map_convert) except (TypeError, AttributeError) as e: # Reraise the exception if callable `f` got wrong number of args. # The user may want to be warned by this, instead of getting NaN @@ -88,7 +92,7 @@ def g(x): return result if na_value is not np.nan: np.putmask(result, mask, na_value) - if result.dtype == object: + if convert and result.dtype == object: result = lib.maybe_convert_objects(result) return result @@ -410,3 +414,28 @@ def _str_lstrip(self, to_strip=None): def _str_rstrip(self, to_strip=None): return self._str_map(lambda x: x.rstrip(to_strip)) + + def _str_extract(self, pat: str, flags: int = 0, expand: bool = True): + regex = re.compile(pat, flags=flags) + na_value = self._str_na_value + + if not expand: + + def g(x): + m = regex.search(x) + return m.groups()[0] if m else na_value + + return self._str_map(g, convert=False) + + empty_row = [na_value] * regex.groups + + def f(x): + if not isinstance(x, str): + return empty_row + m = regex.search(x) + if m: + return [na_value if item is None else item for item in m.groups()] + else: + return empty_row + + return [f(val) for val in np.asarray(self)]