diff --git a/pandas/core/strings/accessor.py b/pandas/core/strings/accessor.py index 15fc2d9e6d3c5..5606380908f38 100644 --- a/pandas/core/strings/accessor.py +++ b/pandas/core/strings/accessor.py @@ -3086,52 +3086,38 @@ def _get_group_names(regex: Pattern) -> List[Hashable]: def _str_extract_noexpand(arr, pat, flags=0): """ - Find groups in each string in the Series using passed regular - expression. This function is called from - str_extract(expand=False), and can return Series, DataFrame, or - Index. + Find groups in each string in the Series/Index using passed regular expression. + This function is called from str_extract(expand=False) when there is a single group + in the regex. + + Returns + ------- + np.ndarray """ - from pandas import ( - DataFrame, - array as pd_array, - ) + from pandas import array as pd_array regex = re.compile(pat, flags=flags) groups_or_na = _groups_or_na_fun(regex) result_dtype = _result_dtype(arr) - if regex.groups == 1: - result = np.array([groups_or_na(val)[0] for val in arr], dtype=object) - name = _get_single_group_name(regex) - # not dispatching, so we have to reconstruct here. - result = pd_array(result, dtype=result_dtype) - else: - name = None - columns = _get_group_names(regex) - if arr.size == 0: - # error: Incompatible types in assignment (expression has type - # "DataFrame", variable has type "ndarray") - result = DataFrame( # type: ignore[assignment] - columns=columns, dtype=result_dtype - ) - else: - # error: Incompatible types in assignment (expression has type - # "DataFrame", variable has type "ndarray") - result = DataFrame( # type:ignore[assignment] - [groups_or_na(val) for val in arr], - columns=columns, - index=arr.index, - dtype=result_dtype, - ) - return result, name + result = np.array([groups_or_na(val)[0] for val in arr], dtype=object) + # not dispatching, so we have to reconstruct here. + result = pd_array(result, dtype=result_dtype) + return result def _str_extract_frame(arr, pat, flags=0): """ - For each subject string in the Series, extract groups from the - first match of regular expression pat. This function is called from - str_extract(expand=True), and always returns a DataFrame. + Find groups in each string in the Series/Index using passed regular expression. + + For each subject string in the Series/Index, extract groups from the first match of + regular expression pat. This function is called from str_extract(expand=True) or + str_extract(expand=False) when there is more than one group in the regex. + + Returns + ------- + DataFrame """ from pandas import DataFrame @@ -3141,11 +3127,13 @@ def _str_extract_frame(arr, pat, flags=0): columns = _get_group_names(regex) result_dtype = _result_dtype(arr) - if len(arr) == 0: + if arr.size == 0: return DataFrame(columns=columns, dtype=result_dtype) - try: + + result_index: Optional["Index"] + if isinstance(arr, ABCSeries): result_index = arr.index - except AttributeError: + else: result_index = None return DataFrame( [groups_or_na(val) for val in arr], @@ -3156,12 +3144,16 @@ def _str_extract_frame(arr, pat, flags=0): def str_extract(arr, pat, flags=0, expand=True): - if expand: + regex = re.compile(pat, flags=flags) + returns_df = regex.groups > 1 or expand + + if returns_df: + name = None result = _str_extract_frame(arr._orig, pat, flags=flags) - return result.__finalize__(arr._orig, method="str_extract") else: - result, name = _str_extract_noexpand(arr._orig, pat, flags=flags) - return arr._wrap_result(result, name=name, expand=expand) + name = _get_single_group_name(regex) + result = _str_extract_noexpand(arr._orig, pat, flags=flags) + return arr._wrap_result(result, name=name) def str_extractall(arr, pat, flags=0):