Skip to content

Commit fb70ca4

Browse files
[ArrowStringArray] REF: extract/extractall column names (#41417)
1 parent 57be60d commit fb70ca4

File tree

1 file changed

+27
-10
lines changed

1 file changed

+27
-10
lines changed

pandas/core/strings/accessor.py

+27 -10
Original file line number | Diff line number | Diff line change
@@ -3,8 +3,10 @@
33
import re
44
from typing import (
55
Dict,
6+
Hashable,
67
List,
78
Optional,
9+
Pattern,
810
)
911
import warnings
1012

@@ -3036,13 +3038,31 @@ def _result_dtype(arr):
30363038
return object
30373039

30383040

3039-
def _get_single_group_name(rx):
3040-
try:
3041-
return list(rx.groupindex.keys()).pop()
3042-
except IndexError:
3041+
def _get_single_group_name(regex: Pattern) -> Hashable:
3042+
if regex.groupindex:
3043+
return next(iter(regex.groupindex))
3044+
else:
30433045
return None
30443046

30453047

3048+
def _get_group_names(regex: Pattern) -> List[Hashable]:
3049+
"""
3050+
Get named groups from compiled regex.
3051+
3052+
Unnamed groups are numbered.
3053+
3054+
Parameters
3055+
----------
3056+
regex : compiled regex
3057+
3058+
Returns
3059+
-------
3060+
list of column labels
3061+
"""
3062+
names = {v: k for k, v in regex.groupindex.items()}
3063+
return [names.get(1 + i, i) for i in range(regex.groups)]
3064+
3065+
30463066
def _str_extract_noexpand(arr, pat, flags=0):
30473067
"""
30483068
Find groups in each string in the Series using passed regular
@@ -3069,8 +3089,7 @@ def _str_extract_noexpand(arr, pat, flags=0):
30693089
if isinstance(arr, ABCIndex):
30703090
raise ValueError("only one regex group is supported with Index")
30713091
name = None
3072-
names = dict(zip(regex.groupindex.values(), regex.groupindex.keys()))
3073-
columns = [names.get(1 + i, i) for i in range(regex.groups)]
3092+
columns = _get_group_names(regex)
30743093
if arr.size == 0:
30753094
# error: Incompatible types in assignment (expression has type
30763095
# "DataFrame", variable has type "ndarray")
@@ -3101,8 +3120,7 @@ def _str_extract_frame(arr, pat, flags=0):
31013120

31023121
regex = re.compile(pat, flags=flags)
31033122
groups_or_na = _groups_or_na_fun(regex)
3104-
names = dict(zip(regex.groupindex.values(), regex.groupindex.keys()))
3105-
columns = [names.get(1 + i, i) for i in range(regex.groups)]
3123+
columns = _get_group_names(regex)
31063124

31073125
if len(arr) == 0:
31083126
return DataFrame(columns=columns, dtype=object)
@@ -3139,8 +3157,7 @@ def str_extractall(arr, pat, flags=0):
31393157
if isinstance(arr, ABCIndex):
31403158
arr = arr.to_series().reset_index(drop=True)
31413159

3142-
names = dict(zip(regex.groupindex.values(), regex.groupindex.keys()))
3143-
columns = [names.get(1 + i, i) for i in range(regex.groups)]
3160+
columns = _get_group_names(regex)
31443161
match_list = []
31453162
index_list = []
31463163
is_mi = arr.index.nlevels > 1

0 commit comments

Comments
 (0)