diff --git a/pandas/core/strings/accessor.py b/pandas/core/strings/accessor.py index e85da9f4b574d..1461c52d5cb65 100644 --- a/pandas/core/strings/accessor.py +++ b/pandas/core/strings/accessor.py @@ -15,7 +15,10 @@ import numpy as np import pandas._libs.lib as lib -from pandas._typing import FrameOrSeriesUnion +from pandas._typing import ( + ArrayLike, + FrameOrSeriesUnion, +) from pandas.util._decorators import Appender from pandas.core.dtypes.common import ( @@ -3084,9 +3087,9 @@ def _get_group_names(regex: Pattern) -> List[Hashable]: return [names.get(1 + i, i) for i in range(regex.groups)] -def _str_extract_noexpand(arr, pat, flags=0): +def _str_extract_noexpand(arr: ArrayLike, pat: str, flags=0): """ - Find groups in each string in the Series/Index using passed regular expression. + Find groups in each string in the array using passed regular expression. This function is called from str_extract(expand=False) when there is a single group in the regex. @@ -3095,65 +3098,69 @@ def _str_extract_noexpand(arr, pat, flags=0): ------- np.ndarray """ - from pandas import array as pd_array - regex = re.compile(pat, flags=flags) groups_or_na = _groups_or_na_fun(regex) - result_dtype = _result_dtype(arr) result = np.array([groups_or_na(val)[0] for val in np.asarray(arr)], dtype=object) - # not dispatching, so we have to reconstruct here. - result = pd_array(result, dtype=result_dtype) return result -def _str_extract_frame(arr, pat, flags=0): +def _str_extract_expand(arr: ArrayLike, pat: str, flags: int = 0) -> List[List]: """ - Find groups in each string in the Series/Index using passed regular expression. + Find groups in each string in the array using passed regular expression. - For each subject string in the Series/Index, extract groups from the first match of + For each subject string in the array, extract groups from the first match of regular expression pat. This function is called from str_extract(expand=True) or str_extract(expand=False) when there is more than one group in the regex. Returns ------- - DataFrame + list of lists """ - from pandas import DataFrame - regex = re.compile(pat, flags=flags) groups_or_na = _groups_or_na_fun(regex) - columns = _get_group_names(regex) - result_dtype = _result_dtype(arr) - if arr.size == 0: - return DataFrame(columns=columns, dtype=result_dtype) + return [groups_or_na(val) for val in np.asarray(arr)] - result_index: Optional["Index"] - if isinstance(arr, ABCSeries): - result_index = arr.index - else: - result_index = None - return DataFrame( - [groups_or_na(val) for val in np.asarray(arr)], - columns=columns, - index=result_index, - dtype=result_dtype, - ) +def str_extract(accessor: StringMethods, pat: str, flags: int = 0, expand: bool = True): + from pandas import ( + DataFrame, + array as pd_array, + ) -def str_extract(arr, pat, flags=0, expand=True): + obj = accessor._data + result_dtype = _result_dtype(obj) regex = re.compile(pat, flags=flags) returns_df = regex.groups > 1 or expand if returns_df: name = None - result = _str_extract_frame(arr._orig, pat, flags=flags) + columns = _get_group_names(regex) + + if obj.array.size == 0: + result = DataFrame(columns=columns, dtype=result_dtype) + + else: + result_list = _str_extract_expand(obj.array, pat, flags=flags) + + result_index: Optional["Index"] + if isinstance(obj, ABCSeries): + result_index = obj.index + else: + result_index = None + + result = DataFrame( + result_list, columns=columns, index=result_index, dtype=result_dtype + ) + else: name = _get_single_group_name(regex) - result = _str_extract_noexpand(arr._orig, pat, flags=flags) - return arr._wrap_result(result, name=name) + result_arr = _str_extract_noexpand(obj.array, pat, flags=flags) + # not dispatching, so we have to reconstruct here. + result = pd_array(result_arr, dtype=result_dtype) + return accessor._wrap_result(result, name=name) def str_extractall(arr, pat, flags=0):