Skip to content

Commit b7aa11e

Browse files
committed
ENH: add expand kw to str.extract and str.get_dummies
1 parent eafd22d commit b7aa11e

File tree

2 files changed

+224
-75
lines changed

2 files changed

+224
-75
lines changed

pandas/core/strings.py

+65-63
Original file line numberDiff line numberDiff line change
@@ -421,24 +421,39 @@ def str_extract(arr, pat, flags=0):
421421
Pattern or regular expression
422422
flags : int, default 0 (no flags)
423423
re module flags, e.g. re.IGNORECASE
424+
expand : bool, default True
425+
* If True, return DataFrame/MultiIndex expanding dimensionality.
426+
* If False, return Series/Index.
424427
425428
Returns
426429
-------
427-
extracted groups : Series (one group) or DataFrame (multiple groups)
430+
extracted groups : Deprecated: Series (one group) or DataFrame (multiple groups)
428431
Note that dtype of the result is always object, even when no match is
429432
found and the result is a Series or DataFrame containing only NaN
430433
values.
431434
435+
Being changed to return Series/Index or DataFrame/MultiIndex of objects
436+
specified by expand option in future version.
437+
432438
Examples
433439
--------
434-
A pattern with one group will return a Series. Non-matches will be NaN.
440+
Deprecated: A pattern with one group returns a Series. Non-matches will be NaN.
441+
Being changed to return DataFrame by default in future version.
435442
436443
>>> Series(['a1', 'b2', 'c3']).str.extract('[ab](\d)')
437444
0 1
438445
1 2
439446
2 NaN
440447
dtype: object
441448
449+
Specify ``expand=False`` to return Series.
450+
451+
>>> Series(['a1', 'b2', 'c3']).str.extract('[ab](\d)', expand=False)
452+
0 1
453+
1 2
454+
2 NaN
455+
dtype: object
456+
442457
A pattern with more than one group will return a DataFrame.
443458
444459
>>> Series(['a1', 'b2', 'c3']).str.extract('([ab])(\d)')
@@ -462,12 +477,7 @@ def str_extract(arr, pat, flags=0):
462477
0 a 1
463478
1 b 2
464479
2 NaN NaN
465-
466480
"""
467-
from pandas.core.series import Series
468-
from pandas.core.frame import DataFrame
469-
from pandas.core.index import Index
470-
471481
regex = re.compile(pat, flags=flags)
472482
# just to be safe, check this
473483
if regex.groups == 0:
@@ -487,18 +497,9 @@ def f(x):
487497
result = np.array([f(val)[0] for val in arr], dtype=object)
488498
name = _get_single_group_name(regex)
489499
else:
490-
if isinstance(arr, Index):
491-
raise ValueError("only one regex group is supported with Index")
492-
name = None
493500
names = dict(zip(regex.groupindex.values(), regex.groupindex.keys()))
494-
columns = [names.get(1 + i, i) for i in range(regex.groups)]
495-
if arr.empty:
496-
result = DataFrame(columns=columns, dtype=object)
497-
else:
498-
result = DataFrame([f(val) for val in arr],
499-
columns=columns,
500-
index=arr.index,
501-
dtype=object)
501+
name = [names.get(1 + i, i) for i in range(regex.groups)]
502+
result = np.array([f(val) for val in arr], dtype=object)
502503
return result, name
503504

504505

@@ -511,6 +512,9 @@ def str_get_dummies(arr, sep='|'):
511512
----------
512513
sep : string, default "|"
513514
String to split on.
515+
expand : bool, default False
516+
* If True, return DataFrame/MultiIndex expanding dimensionality.
517+
* If False, return Series/Index.
514518
515519
Returns
516520
-------
@@ -534,15 +538,15 @@ def str_get_dummies(arr, sep='|'):
534538
--------
535539
pandas.get_dummies
536540
"""
537-
from pandas.core.frame import DataFrame
538541
from pandas.core.index import Index
539-
540-
# GH9980, Index.str does not support get_dummies() as it returns a frame
542+
# TODO: Add fillna GH 10089
541543
if isinstance(arr, Index):
542-
raise TypeError("get_dummies is not supported for string methods on Index")
543-
544-
# TODO remove this hack?
545-
arr = arr.fillna('')
544+
# temp hack
545+
values = arr.values
546+
values[isnull(values)] = ''
547+
arr = Index(values)
548+
else:
549+
arr = arr.fillna('')
546550
try:
547551
arr = sep + arr + sep
548552
except TypeError:
@@ -558,7 +562,7 @@ def str_get_dummies(arr, sep='|'):
558562
for i, t in enumerate(tags):
559563
pat = sep + t + sep
560564
dummies[:, i] = lib.map_infer(arr.values, lambda x: pat in x)
561-
return DataFrame(dummies, arr.index, tags)
565+
return dummies, tags
562566

563567

564568
def str_join(arr, sep):
@@ -1043,40 +1047,19 @@ def __iter__(self):
10431047
i += 1
10441048
g = self.get(i)
10451049

1046-
def _wrap_result(self, result, **kwargs):
1047-
1048-
# leave as it is to keep extract and get_dummies results
1049-
# can be merged to _wrap_result_expand in v0.17
1050-
from pandas.core.series import Series
1051-
from pandas.core.frame import DataFrame
1052-
from pandas.core.index import Index
1053-
1054-
if not hasattr(result, 'ndim'):
1055-
return result
1056-
name = kwargs.get('name') or getattr(result, 'name', None) or self.series.name
1057-
1058-
if result.ndim == 1:
1059-
if isinstance(self.series, Index):
1060-
# if result is a boolean np.array, return the np.array
1061-
# instead of wrapping it into a boolean Index (GH 8875)
1062-
if is_bool_dtype(result):
1063-
return result
1064-
return Index(result, name=name)
1065-
return Series(result, index=self.series.index, name=name)
1066-
else:
1067-
assert result.ndim < 3
1068-
return DataFrame(result, index=self.series.index)
1050+
def _wrap_result(self, result, expand=False, name=None):
1051+
from pandas.core.index import Index, MultiIndex
10691052

1070-
def _wrap_result_expand(self, result, expand=False):
10711053
if not isinstance(expand, bool):
10721054
raise ValueError("expand must be True or False")
10731055

1074-
from pandas.core.index import Index, MultiIndex
1056+
if name is None:
1057+
name = getattr(result, 'name', None) or self.series.name
1058+
10751059
if not hasattr(result, 'ndim'):
10761060
return result
10771061

10781062
if isinstance(self.series, Index):
1079-
name = getattr(result, 'name', None)
10801063
# if result is a boolean np.array, return the np.array
10811064
# instead of wrapping it into a boolean Index (GH 8875)
10821065
if hasattr(result, 'dtype') and is_bool_dtype(result):
@@ -1092,10 +1075,12 @@ def _wrap_result_expand(self, result, expand=False):
10921075
if expand:
10931076
cons_row = self.series._constructor
10941077
cons = self.series._constructor_expanddim
1095-
data = [cons_row(x) for x in result]
1096-
return cons(data, index=index)
1078+
data = [cons_row(x, index=name) for x in result]
1079+
return cons(data, index=index, columns=name,
1080+
dtype=result.dtype)
10971081
else:
1098-
name = getattr(result, 'name', None)
1082+
if result.ndim > 1:
1083+
result = list(result)
10991084
cons = self.series._constructor
11001085
return cons(result, name=name, index=index)
11011086

@@ -1109,7 +1094,7 @@ def cat(self, others=None, sep=None, na_rep=None):
11091094
@copy(str_split)
11101095
def split(self, pat=None, n=-1, expand=False):
11111096
result = str_split(self.series, pat, n=n)
1112-
return self._wrap_result_expand(result, expand=expand)
1097+
return self._wrap_result(result, expand=expand)
11131098

11141099
_shared_docs['str_partition'] = ("""
11151100
Split the string at the %(side)s occurrence of `sep`, and return 3 elements
@@ -1160,15 +1145,15 @@ def split(self, pat=None, n=-1, expand=False):
11601145
def partition(self, pat=' ', expand=True):
11611146
f = lambda x: x.partition(pat)
11621147
result = _na_map(f, self.series)
1163-
return self._wrap_result_expand(result, expand=expand)
1148+
return self._wrap_result(result, expand=expand)
11641149

11651150
@Appender(_shared_docs['str_partition'] % {'side': 'last',
11661151
'return': '3 elements containing two empty strings, followed by the string itself',
11671152
'also': 'partition : Split the string at the first occurrence of `sep`'})
11681153
def rpartition(self, pat=' ', expand=True):
11691154
f = lambda x: x.rpartition(pat)
11701155
result = _na_map(f, self.series)
1171-
return self._wrap_result_expand(result, expand=expand)
1156+
return self._wrap_result(result, expand=expand)
11721157

11731158
@copy(str_get)
11741159
def get(self, i):
@@ -1309,9 +1294,9 @@ def wrap(self, width, **kwargs):
13091294
return self._wrap_result(result)
13101295

13111296
@copy(str_get_dummies)
1312-
def get_dummies(self, sep='|'):
1313-
result = str_get_dummies(self.series, sep)
1314-
return self._wrap_result(result)
1297+
def get_dummies(self, sep='|', expand=True):
1298+
result, name = str_get_dummies(self.series, sep)
1299+
return self._wrap_result(result, name=name, expand=expand)
13151300

13161301
@copy(str_translate)
13171302
def translate(self, table, deletechars=None):
@@ -1324,9 +1309,26 @@ def translate(self, table, deletechars=None):
13241309
findall = _pat_wrapper(str_findall, flags=True)
13251310

13261311
@copy(str_extract)
1327-
def extract(self, pat, flags=0):
1312+
def extract(self, pat, flags=0, expand=None):
13281313
result, name = str_extract(self.series, pat, flags=flags)
1329-
return self._wrap_result(result, name=name)
1314+
1315+
if expand is None and hasattr(result, 'ndim'):
1316+
# to be compat with previous behavior
1317+
msg = ("Extracting with single group returns DataFrame in future version. "
1318+
"Specify expand=False to return Series.")
1319+
if len(result) == 0:
1320+
# for empty input
1321+
if isinstance(name, list):
1322+
expand = True
1323+
else:
1324+
warnings.warn(msg, UserWarning)
1325+
expand = False
1326+
elif result.ndim > 1:
1327+
expand = True
1328+
else:
1329+
warnings.warn(msg, UserWarning)
1330+
expand = False
1331+
return self._wrap_result(result, name=name, expand=expand)
13301332

13311333
_shared_docs['find'] = ("""
13321334
Return %(side)s indexes in each strings in the Series/Index

0 commit comments

Comments
 (0)