@@ -424,17 +424,21 @@ def str_extract(arr, pat, flags=0):
         Pattern or regular expression
     flags : int, default 0 (no flags)
         re module flags, e.g. re.IGNORECASE
+    expand : None or bool, default None
+        * If None, return Series/Index (one group) or DataFrame/MultiIndex (multiple groups)
+        * If True, return DataFrame/MultiIndex expanding dimensionality.
+        * If False, return Series/Index.

     Returns
     -------
-    extracted groups : Series (one group) or DataFrame (multiple groups)
+    extracted groups : Series/Index or DataFrame/MultiIndex of objects
         Note that dtype of the result is always object, even when no match is
         found and the result is a Series or DataFrame containing only NaN
         values.

     Examples
     --------
-    A pattern with one group will return a Series. Non-matches will be NaN.
+    A pattern with one group returns a Series. Non-matches will be NaN.

     >>> Series(['a1', 'b2', 'c3']).str.extract('[ab](\d)')
     0      1
@@ -466,11 +470,14 @@ def str_extract(arr, pat, flags=0):
     1    b    2
     2  NaN  NaN

-    """
-    from pandas.core.series import Series
-    from pandas.core.frame import DataFrame
-    from pandas.core.index import Index
+    Or you can specify ``expand=False`` to return a Series.

+    >>> pd.Series(['a1', 'b2', 'c3']).str.extract('([ab])?(\d)', expand=False)
+    0    [a, 1]
+    1    [b, 2]
+    2  [nan, 3]
+    Name: [0, 1], dtype: object
+    """
     regex = re.compile(pat, flags=flags)
     # just to be safe, check this
     if regex.groups == 0:
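The case the new docstring does not spell out is a single group combined with ``expand=True``, which is the main reason for the keyword: it forces a one-column DataFrame instead of a Series. A small usage sketch of that behaviour (assuming this branch; comments describe the return shape, not exact rendering):

import pandas as pd

s = pd.Series(['a1', 'b2', 'c3'])
s.str.extract(r'[ab](\d)')               # default: one group -> Series
s.str.extract(r'[ab](\d)', expand=True)  # forced: one-column DataFrame (column 0)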
@@ -490,18 +497,9 @@ def f(x):
         result = np.array([f(val)[0] for val in arr], dtype=object)
         name = _get_single_group_name(regex)
     else:
-        if isinstance(arr, Index):
-            raise ValueError("only one regex group is supported with Index")
-        name = None
         names = dict(zip(regex.groupindex.values(), regex.groupindex.keys()))
-        columns = [names.get(1 + i, i) for i in range(regex.groups)]
-        if arr.empty:
-            result = DataFrame(columns=columns, dtype=object)
-        else:
-            result = DataFrame([f(val) for val in arr],
-                               columns=columns,
-                               index=arr.index,
-                               dtype=object)
+        name = [names.get(1 + i, i) for i in range(regex.groups)]
+        result = np.array([f(val) for val in arr], dtype=object)
     return result, name


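With the DataFrame construction removed, str_extract's only job is to produce a plain object ndarray of per-row group values plus the column names, leaving the wrapping to _wrap_result. A rough standalone sketch of that contract (the helper name below is illustrative, not pandas code):

import re
import numpy as np

regex = re.compile(r'([ab])(\d)')
empty_row = [np.nan] * regex.groups

def groups_or_nan(val):
    # NaN-fill rows that do not match, like the nested f() above
    m = regex.search(val)
    return [np.nan if g is None else g for g in m.groups()] if m else empty_row

result = np.array([groups_or_nan(v) for v in ['a1', 'b2', 'c3']], dtype=object)
names = dict(zip(regex.groupindex.values(), regex.groupindex.keys()))
name = [names.get(1 + i, i) for i in range(regex.groups)]   # [0, 1] for unnamed groups
# result is a 2-D object ndarray; _wrap_result turns (result, name) into a Series or DataFrame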
@@ -514,10 +512,13 @@ def str_get_dummies(arr, sep='|'):
     ----------
     sep : string, default "|"
         String to split on.
+    expand : bool, default True
+        * If True, return DataFrame/MultiIndex expanding dimensionality.
+        * If False, return Series/Index.

     Returns
     -------
-    dummies : DataFrame
+    dummies : Series/Index or DataFrame/MultiIndex of objects

     Examples
     --------
@@ -537,14 +538,7 @@ def str_get_dummies(arr, sep='|'):
     --------
     pandas.get_dummies
     """
-    from pandas.core.frame import DataFrame
     from pandas.core.index import Index
-
-    # GH9980, Index.str does not support get_dummies() as it returns a frame
-    if isinstance(arr, Index):
-        raise TypeError("get_dummies is not supported for string methods on Index")
-
-    # TODO remove this hack?
     arr = arr.fillna('')
     try:
         arr = sep + arr + sep
@@ -561,7 +555,7 @@ def str_get_dummies(arr, sep='|'):
     for i, t in enumerate(tags):
         pat = sep + t + sep
         dummies[:, i] = lib.map_infer(arr.values, lambda x: pat in x)
-    return DataFrame(dummies, arr.index, tags)
+    return dummies, tags


 def str_join(arr, sep):
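The helper now hands back the raw indicator matrix together with the sorted tag list instead of a finished DataFrame; roughly, the returned pair looks like this (a plain-numpy illustration, not the pandas internals):

import numpy as np

values = ['a|b', 'a', 'a|c']
tags = sorted(set(t for v in values for t in v.split('|')))      # ['a', 'b', 'c']
dummies = np.empty((len(values), len(tags)), dtype=np.int64)
for i, t in enumerate(tags):
    # mark rows whose delimited string contains the delimited tag
    dummies[:, i] = ['|%s|' % t in '|%s|' % v for v in values]
# dummies -> [[1, 1, 0], [1, 0, 0], [1, 0, 1]]; the caller wraps (dummies, tags)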
@@ -1081,7 +1075,10 @@ def __iter__(self):
             i += 1
             g = self.get(i)

-    def _wrap_result(self, result, use_codes=True, name=None):
+    def _wrap_result(self, result, use_codes=True, name=None, expand=False):
+
+        if not isinstance(expand, bool):
+            raise ValueError("expand must be True or False")

         # for category, we do the stuff on the categories, so blow it up
         # to the full series again
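Because the methods that take ``expand`` now all funnel through _wrap_result, the up-front type check makes a non-boolean value fail fast and consistently. A hedged sketch of the behaviour added above:

import pandas as pd

s = pd.Series(['a_b_c'])
s.str.split('_', expand=True)       # fine: DataFrame
# s.str.split('_', expand='yes')    # would raise ValueError: expand must be True or False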
@@ -1095,39 +1092,11 @@ def _wrap_result(self, result, use_codes=True, name=None):
         # can be merged to _wrap_result_expand in v0.17
         from pandas.core.series import Series
         from pandas.core.frame import DataFrame
-        from pandas.core.index import Index
+        from pandas.core.index import Index, MultiIndex

-        if not hasattr(result, 'ndim'):
-            return result
         name = name or getattr(result, 'name', None) or self._orig.name

-        if result.ndim == 1:
-            if isinstance(self._orig, Index):
-                # if result is a boolean np.array, return the np.array
-                # instead of wrapping it into a boolean Index (GH 8875)
-                if is_bool_dtype(result):
-                    return result
-                return Index(result, name=name)
-            return Series(result, index=self._orig.index, name=name)
-        else:
-            assert result.ndim < 3
-            return DataFrame(result, index=self._orig.index)
-
-    def _wrap_result_expand(self, result, expand=False):
-        if not isinstance(expand, bool):
-            raise ValueError("expand must be True or False")
-
-        # for category, we do the stuff on the categories, so blow it up
-        # to the full series again
-        if self._is_categorical:
-            result = take_1d(result, self._orig.cat.codes)
-
-        from pandas.core.index import Index, MultiIndex
-        if not hasattr(result, 'ndim'):
-            return result
-
         if isinstance(self._orig, Index):
-            name = getattr(result, 'name', None)
             # if result is a boolean np.array, return the np.array
             # instead of wrapping it into a boolean Index (GH 8875)
             if hasattr(result, 'dtype') and is_bool_dtype(result):
@@ -1137,7 +1106,7 @@ def _wrap_result_expand(self, result, expand=False):
                 result = list(result)
                 return MultiIndex.from_tuples(result, names=name)
             else:
-                return Index(result, name=name)
+                return Index(result, name=name, tupleize_cols=False)
         else:
             index = self._orig.index
             if expand:
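On an Index, ``expand=True`` builds a MultiIndex from the row tuples, while ``expand=False`` keeps a flat Index; the added ``tupleize_cols=False`` stops tuple rows (e.g. from partition) from being silently promoted to a MultiIndex. A usage sketch consistent with this branch:

import pandas as pd

idx = pd.Index(['a_b', 'c_d'])
idx.str.split('_', expand=True)    # MultiIndex: [('a', 'b'), ('c', 'd')]
idx.str.split('_', expand=False)   # flat Index of lists: [['a', 'b'], ['c', 'd']]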
@@ -1148,30 +1117,34 @@ def cons_row(x):
                         return [ x ]
                 cons = self._orig._constructor_expanddim
                 data = [cons_row(x) for x in result]
-                return cons(data, index=index)
+                return cons(data, index=index, columns=name,
+                            dtype=result.dtype)
             else:
-                name = getattr(result, 'name', None)
+                if result.ndim > 1:
+                    result = list(result)
                 cons = self._orig._constructor
                 return cons(result, name=name, index=index)

     @copy(str_cat)
     def cat(self, others=None, sep=None, na_rep=None):
         data = self._orig if self._is_categorical else self._data
         result = str_cat(data, others=others, sep=sep, na_rep=na_rep)
+        if not hasattr(result, 'ndim'):
+            # str_cat may result in np.nan or str
+            return result
         return self._wrap_result(result, use_codes=(not self._is_categorical))

-
     @deprecate_kwarg('return_type', 'expand',
                      mapping={'series': False, 'frame': True})
     @copy(str_split)
     def split(self, pat=None, n=-1, expand=False):
         result = str_split(self._data, pat, n=n)
-        return self._wrap_result_expand(result, expand=expand)
+        return self._wrap_result(result, expand=expand)

     @copy(str_rsplit)
     def rsplit(self, pat=None, n=-1, expand=False):
         result = str_rsplit(self._data, pat, n=n)
-        return self._wrap_result_expand(result, expand=expand)
+        return self._wrap_result(result, expand=expand)

     _shared_docs['str_partition'] = ("""
     Split the string at the %(side)s occurrence of `sep`, and return 3 elements
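For a Series the re-routing is behaviour-preserving: ``expand=False`` (the default) keeps a Series of lists and ``expand=True`` spreads the pieces over columns, just as _wrap_result_expand did before. For example:

import pandas as pd

s = pd.Series(['a_b_c', 'c_d_e'])
s.str.split('_')                 # Series of lists: [a, b, c], [c, d, e]
s.str.split('_', expand=True)    # DataFrame with columns 0, 1, 2
s.str.rsplit('_', n=1)           # split from the right: [a_b, c], [c_d, e]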
@@ -1222,15 +1195,15 @@ def rsplit(self, pat=None, n=-1, expand=False):
     def partition(self, pat=' ', expand=True):
         f = lambda x: x.partition(pat)
         result = _na_map(f, self._data)
-        return self._wrap_result_expand(result, expand=expand)
+        return self._wrap_result(result, expand=expand)

     @Appender(_shared_docs['str_partition'] % {'side': 'last',
         'return': '3 elements containing two empty strings, followed by the string itself',
         'also': 'partition : Split the string at the first occurrence of `sep`'})
     def rpartition(self, pat=' ', expand=True):
         f = lambda x: x.rpartition(pat)
         result = _na_map(f, self._data)
-        return self._wrap_result_expand(result, expand=expand)
+        return self._wrap_result(result, expand=expand)

     @copy(str_get)
     def get(self, i):
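``partition`` and ``rpartition`` keep their default of ``expand=True``, so they come back as a three-column DataFrame (before, separator, after); ``expand=False`` gives a Series of tuples instead. For example:

import pandas as pd

s = pd.Series(['Linda_van_der_Berg'])
s.str.partition('_')                 # DataFrame: 'Linda', '_', 'van_der_Berg'
s.str.rpartition('_')                # DataFrame: 'Linda_van_der', '_', 'Berg'
s.str.partition('_', expand=False)   # Series of ('Linda', '_', 'van_der_Berg') tuples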
@@ -1371,12 +1344,13 @@ def wrap(self, width, **kwargs):
         return self._wrap_result(result)

     @copy(str_get_dummies)
-    def get_dummies(self, sep='|'):
+    def get_dummies(self, sep='|', expand=True):
         # we need to cast to Series of strings as only that has all
         # methods available for making the dummies...
         data = self._orig.astype(str) if self._is_categorical else self._data
-        result = str_get_dummies(data, sep)
-        return self._wrap_result(result, use_codes=(not self._is_categorical))
+        result, name = str_get_dummies(data, sep)
+        return self._wrap_result(result, use_codes=(not self._is_categorical),
+                                 name=name, expand=expand)

     @copy(str_translate)
     def translate(self, table, deletechars=None):
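With the wrapping delegated to _wrap_result, ``get_dummies`` keeps its familiar Series behaviour, and lifting the GH9980 guard means it can now work on an Index as well. A short usage sketch (assuming this branch):

import pandas as pd

s = pd.Series(['a|b', 'a', 'a|c'])
s.str.get_dummies(sep='|')                               # DataFrame with 0/1 columns a, b, c
pd.Index(['a|b', 'a', 'a|c']).str.get_dummies(sep='|')   # MultiIndex of 0/1 indicator tuples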
@@ -1389,9 +1363,18 @@ def translate(self, table, deletechars=None):
     findall = _pat_wrapper(str_findall, flags=True)

     @copy(str_extract)
-    def extract(self, pat, flags=0):
-        result, name = str_extract(self._data, pat, flags=flags)
-        return self._wrap_result(result, name=name)
+    def extract(self, pat, flags=0, expand=None):
+        result, name = str_extract(self._orig, pat, flags=flags)
+        if expand is None and hasattr(result, 'ndim'):
+            # to be compat with previous behavior
+            if len(result) == 0:
+                # for empty input
+                expand = True if isinstance(name, list) else False
+            elif result.ndim > 1:
+                expand = True
+            else:
+                expand = False
+        return self._wrap_result(result, name=name, use_codes=False, expand=expand)

     _shared_docs['find'] = ("""
     Return %(side)s indexes in each strings in the Series/Index
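Put together, ``extract`` now infers ``expand`` from the result when it is left as None, and an explicit True or False overrides that inference. A usage sketch consistent with the docstring changes above:

import pandas as pd

s = pd.Series(['a1', 'b2', 'c3'])
s.str.extract(r'[ab](\d)')                  # one group, expand left as None -> Series
s.str.extract(r'([ab])(\d)')                # several groups, expand left as None -> DataFrame
s.str.extract(r'([ab])(\d)', expand=True)   # explicit: always a DataFrame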