Skip to content

Commit c007dba

Browse files
committed
ENH: add expand kw to str.extract and str.get_dummies
1 parent 5bc191a commit c007dba

File tree

3 files changed

+210
-81
lines changed

3 files changed

+210
-81
lines changed

pandas/core/strings.py

+53-70
Original file line numberDiff line numberDiff line change
@@ -424,17 +424,21 @@ def str_extract(arr, pat, flags=0):
424424
Pattern or regular expression
425425
flags : int, default 0 (no flags)
426426
re module flags, e.g. re.IGNORECASE
427+
expand : None or bool, default None
428+
* If None, return Series/Index (one group) or DataFrame/MultiIndex (multiple groups)
429+
* If True, return DataFrame/MultiIndex expanding dimensionality.
430+
* If False, return Series/Index.
427431
428432
Returns
429433
-------
430-
extracted groups : Series (one group) or DataFrame (multiple groups)
434+
extracted groups : Series/Index or DataFrame/MultiIndex of objects
431435
Note that dtype of the result is always object, even when no match is
432436
found and the result is a Series or DataFrame containing only NaN
433437
values.
434438
435439
Examples
436440
--------
437-
A pattern with one group will return a Series. Non-matches will be NaN.
441+
A pattern with one group returns a Series. Non-matches will be NaN.
438442
439443
>>> Series(['a1', 'b2', 'c3']).str.extract('[ab](\d)')
440444
0 1
@@ -466,11 +470,14 @@ def str_extract(arr, pat, flags=0):
466470
1 b 2
467471
2 NaN NaN
468472
469-
"""
470-
from pandas.core.series import Series
471-
from pandas.core.frame import DataFrame
472-
from pandas.core.index import Index
473+
Or you can specify ``expand=False`` to return Series.
473474
475+
>>> pd.Series(['a1', 'b2', 'c3']).str.extract('([ab])?(\d)', expand=False)
476+
0 [a, 1]
477+
1 [b, 2]
478+
2 [nan, 3]
479+
Name: [0, 1], dtype: object
480+
"""
474481
regex = re.compile(pat, flags=flags)
475482
# just to be safe, check this
476483
if regex.groups == 0:
@@ -490,18 +497,9 @@ def f(x):
490497
result = np.array([f(val)[0] for val in arr], dtype=object)
491498
name = _get_single_group_name(regex)
492499
else:
493-
if isinstance(arr, Index):
494-
raise ValueError("only one regex group is supported with Index")
495-
name = None
496500
names = dict(zip(regex.groupindex.values(), regex.groupindex.keys()))
497-
columns = [names.get(1 + i, i) for i in range(regex.groups)]
498-
if arr.empty:
499-
result = DataFrame(columns=columns, dtype=object)
500-
else:
501-
result = DataFrame([f(val) for val in arr],
502-
columns=columns,
503-
index=arr.index,
504-
dtype=object)
501+
name = [names.get(1 + i, i) for i in range(regex.groups)]
502+
result = np.array([f(val) for val in arr], dtype=object)
505503
return result, name
506504

507505

@@ -514,10 +512,13 @@ def str_get_dummies(arr, sep='|'):
514512
----------
515513
sep : string, default "|"
516514
String to split on.
515+
expand : bool, default True
516+
* If True, return DataFrame/MultiIndex expanding dimensionality.
517+
* If False, return Series/Index.
517518
518519
Returns
519520
-------
520-
dummies : DataFrame
521+
dummies : Series/Index or DataFrame/MultiIndex of objects
521522
522523
Examples
523524
--------
@@ -537,14 +538,7 @@ def str_get_dummies(arr, sep='|'):
537538
--------
538539
pandas.get_dummies
539540
"""
540-
from pandas.core.frame import DataFrame
541541
from pandas.core.index import Index
542-
543-
# GH9980, Index.str does not support get_dummies() as it returns a frame
544-
if isinstance(arr, Index):
545-
raise TypeError("get_dummies is not supported for string methods on Index")
546-
547-
# TODO remove this hack?
548542
arr = arr.fillna('')
549543
try:
550544
arr = sep + arr + sep
@@ -561,7 +555,7 @@ def str_get_dummies(arr, sep='|'):
561555
for i, t in enumerate(tags):
562556
pat = sep + t + sep
563557
dummies[:, i] = lib.map_infer(arr.values, lambda x: pat in x)
564-
return DataFrame(dummies, arr.index, tags)
558+
return dummies, tags
565559

566560

567561
def str_join(arr, sep):
@@ -1081,7 +1075,10 @@ def __iter__(self):
10811075
i += 1
10821076
g = self.get(i)
10831077

1084-
def _wrap_result(self, result, use_codes=True, name=None):
1078+
def _wrap_result(self, result, use_codes=True, name=None, expand=False):
1079+
1080+
if not isinstance(expand, bool):
1081+
raise ValueError("expand must be True or False")
10851082

10861083
# for category, we do the stuff on the categories, so blow it up
10871084
# to the full series again
@@ -1095,39 +1092,11 @@ def _wrap_result(self, result, use_codes=True, name=None):
10951092
# can be merged to _wrap_result_expand in v0.17
10961093
from pandas.core.series import Series
10971094
from pandas.core.frame import DataFrame
1098-
from pandas.core.index import Index
1095+
from pandas.core.index import Index, MultiIndex
10991096

1100-
if not hasattr(result, 'ndim'):
1101-
return result
11021097
name = name or getattr(result, 'name', None) or self._orig.name
11031098

1104-
if result.ndim == 1:
1105-
if isinstance(self._orig, Index):
1106-
# if result is a boolean np.array, return the np.array
1107-
# instead of wrapping it into a boolean Index (GH 8875)
1108-
if is_bool_dtype(result):
1109-
return result
1110-
return Index(result, name=name)
1111-
return Series(result, index=self._orig.index, name=name)
1112-
else:
1113-
assert result.ndim < 3
1114-
return DataFrame(result, index=self._orig.index)
1115-
1116-
def _wrap_result_expand(self, result, expand=False):
1117-
if not isinstance(expand, bool):
1118-
raise ValueError("expand must be True or False")
1119-
1120-
# for category, we do the stuff on the categories, so blow it up
1121-
# to the full series again
1122-
if self._is_categorical:
1123-
result = take_1d(result, self._orig.cat.codes)
1124-
1125-
from pandas.core.index import Index, MultiIndex
1126-
if not hasattr(result, 'ndim'):
1127-
return result
1128-
11291099
if isinstance(self._orig, Index):
1130-
name = getattr(result, 'name', None)
11311100
# if result is a boolean np.array, return the np.array
11321101
# instead of wrapping it into a boolean Index (GH 8875)
11331102
if hasattr(result, 'dtype') and is_bool_dtype(result):
@@ -1137,7 +1106,7 @@ def _wrap_result_expand(self, result, expand=False):
11371106
result = list(result)
11381107
return MultiIndex.from_tuples(result, names=name)
11391108
else:
1140-
return Index(result, name=name)
1109+
return Index(result, name=name, tupleize_cols=False)
11411110
else:
11421111
index = self._orig.index
11431112
if expand:
@@ -1148,30 +1117,34 @@ def cons_row(x):
11481117
return [ x ]
11491118
cons = self._orig._constructor_expanddim
11501119
data = [cons_row(x) for x in result]
1151-
return cons(data, index=index)
1120+
return cons(data, index=index, columns=name,
1121+
dtype=result.dtype)
11521122
else:
1153-
name = getattr(result, 'name', None)
1123+
if result.ndim > 1:
1124+
result = list(result)
11541125
cons = self._orig._constructor
11551126
return cons(result, name=name, index=index)
11561127

11571128
@copy(str_cat)
11581129
def cat(self, others=None, sep=None, na_rep=None):
11591130
data = self._orig if self._is_categorical else self._data
11601131
result = str_cat(data, others=others, sep=sep, na_rep=na_rep)
1132+
if not hasattr(result, 'ndim'):
1133+
# str_cat may results in np.nan or str
1134+
return result
11611135
return self._wrap_result(result, use_codes=(not self._is_categorical))
11621136

1163-
11641137
@deprecate_kwarg('return_type', 'expand',
11651138
mapping={'series': False, 'frame': True})
11661139
@copy(str_split)
11671140
def split(self, pat=None, n=-1, expand=False):
11681141
result = str_split(self._data, pat, n=n)
1169-
return self._wrap_result_expand(result, expand=expand)
1142+
return self._wrap_result(result, expand=expand)
11701143

11711144
@copy(str_rsplit)
11721145
def rsplit(self, pat=None, n=-1, expand=False):
11731146
result = str_rsplit(self._data, pat, n=n)
1174-
return self._wrap_result_expand(result, expand=expand)
1147+
return self._wrap_result(result, expand=expand)
11751148

11761149
_shared_docs['str_partition'] = ("""
11771150
Split the string at the %(side)s occurrence of `sep`, and return 3 elements
@@ -1222,15 +1195,15 @@ def rsplit(self, pat=None, n=-1, expand=False):
12221195
def partition(self, pat=' ', expand=True):
12231196
f = lambda x: x.partition(pat)
12241197
result = _na_map(f, self._data)
1225-
return self._wrap_result_expand(result, expand=expand)
1198+
return self._wrap_result(result, expand=expand)
12261199

12271200
@Appender(_shared_docs['str_partition'] % {'side': 'last',
12281201
'return': '3 elements containing two empty strings, followed by the string itself',
12291202
'also': 'partition : Split the string at the first occurrence of `sep`'})
12301203
def rpartition(self, pat=' ', expand=True):
12311204
f = lambda x: x.rpartition(pat)
12321205
result = _na_map(f, self._data)
1233-
return self._wrap_result_expand(result, expand=expand)
1206+
return self._wrap_result(result, expand=expand)
12341207

12351208
@copy(str_get)
12361209
def get(self, i):
@@ -1371,12 +1344,13 @@ def wrap(self, width, **kwargs):
13711344
return self._wrap_result(result)
13721345

13731346
@copy(str_get_dummies)
1374-
def get_dummies(self, sep='|'):
1347+
def get_dummies(self, sep='|', expand=True):
13751348
# we need to cast to Series of strings as only that has all
13761349
# methods available for making the dummies...
13771350
data = self._orig.astype(str) if self._is_categorical else self._data
1378-
result = str_get_dummies(data, sep)
1379-
return self._wrap_result(result, use_codes=(not self._is_categorical))
1351+
result, name = str_get_dummies(data, sep)
1352+
return self._wrap_result(result, use_codes=(not self._is_categorical),
1353+
name=name, expand=expand)
13801354

13811355
@copy(str_translate)
13821356
def translate(self, table, deletechars=None):
@@ -1389,9 +1363,18 @@ def translate(self, table, deletechars=None):
13891363
findall = _pat_wrapper(str_findall, flags=True)
13901364

13911365
@copy(str_extract)
1392-
def extract(self, pat, flags=0):
1393-
result, name = str_extract(self._data, pat, flags=flags)
1394-
return self._wrap_result(result, name=name)
1366+
def extract(self, pat, flags=0, expand=None):
1367+
result, name = str_extract(self._orig, pat, flags=flags)
1368+
if expand is None and hasattr(result, 'ndim'):
1369+
# to be compat with previous behavior
1370+
if len(result) == 0:
1371+
# for empty input
1372+
expand = True if isinstance(name, list) else False
1373+
elif result.ndim > 1:
1374+
expand = True
1375+
else:
1376+
expand = False
1377+
return self._wrap_result(result, name=name, use_codes=False, expand=expand)
13951378

13961379
_shared_docs['find'] = ("""
13971380
Return %(side)s indexes in each strings in the Series/Index

pandas/tests/test_categorical.py

+1
Original file line numberDiff line numberDiff line change
@@ -3714,6 +3714,7 @@ def test_str_accessor_api_for_categorical(self):
37143714

37153715

37163716
for func, args, kwargs in func_defs:
3717+
print(func, args, kwargs, c)
37173718
res = getattr(c.str, func)(*args, **kwargs)
37183719
exp = getattr(s.str, func)(*args, **kwargs)
37193720

0 commit comments

Comments
 (0)