Skip to content

Commit 9b90016

Browse files
sinhrksjreback
authored andcommitted
API: Index.take inconsistently handle fill_value
closes #12631 Author: sinhrks <[email protected]> Closes #12676 from sinhrks/index_take and squashes the following commits: 3c19920 [sinhrks] API: Index.take inconsistently handle fill_value
1 parent 0d58446 commit 9b90016

File tree

14 files changed

+476
-47
lines changed

14 files changed

+476
-47
lines changed

doc/source/whatsnew/v0.18.1.txt

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ New features
2929
Enhancements
3030
~~~~~~~~~~~~
3131

32-
3332
.. _whatsnew_0181.partial_string_indexing:
3433

3534
Partial string indexing on ``DateTimeIndex`` when part of a ``MultiIndex``
@@ -59,6 +58,14 @@ Other Enhancements
5958
- ``pd.read_csv()`` now supports opening ZIP files that contains a single CSV, via extension inference or explict ``compression='zip'`` (:issue:`12175`)
6059
- ``pd.read_csv()`` now supports opening files using xz compression, via extension inference or explicit ``compression='xz'`` is specified; ``xz`` compressions is also supported by ``DataFrame.to_csv`` in the same way (:issue:`11852`)
6160
- ``pd.read_msgpack()`` now always gives writeable ndarrays even when compression is used (:issue:`12359`).
61+
- ``Index.take`` now handles ``allow_fill`` and ``fill_value`` consistently (:issue:`12631`)
62+
63+
.. ipython:: python
64+
65+
idx = pd.Index([1., 2., 3., 4.], dtype='float')
66+
idx.take([2, -1]) # default, allow_fill=True, fill_value=None
67+
idx.take([2, -1], fill_value=True)
68+
6269

6370
.. _whatsnew_0181.api:
6471

pandas/indexes/base.py

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1329,24 +1329,60 @@ def _ensure_compat_concat(indexes):
13291329

13301330
return indexes
13311331

1332-
def take(self, indices, axis=0, allow_fill=True, fill_value=None):
1333-
"""
1334-
return a new Index of the values selected by the indexer
1332+
_index_shared_docs['take'] = """
1333+
return a new Index of the values selected by the indices
13351334
13361335
For internal compatibility with numpy arrays.
13371336
1338-
# filling must always be None/nan here
1339-
# but is passed thru internally
1337+
Parameters
1338+
----------
1339+
indices : list
1340+
Indices to be taken
1341+
axis : int, optional
1342+
The axis over which to select values, always 0.
1343+
allow_fill : bool, default True
1344+
fill_value : bool, default None
1345+
If allow_fill=True and fill_value is not None, indices specified by
1346+
-1 is regarded as NA. If Index doesn't hold NA, raise ValueError
13401347
13411348
See also
13421349
--------
13431350
numpy.ndarray.take
13441351
"""
1345-
1352+
@Appender(_index_shared_docs['take'])
1353+
def take(self, indices, axis=0, allow_fill=True, fill_value=None):
13461354
indices = com._ensure_platform_int(indices)
1347-
taken = self.values.take(indices)
1355+
if self._can_hold_na:
1356+
taken = self._assert_take_fillable(self.values, indices,
1357+
allow_fill=allow_fill,
1358+
fill_value=fill_value,
1359+
na_value=self._na_value)
1360+
else:
1361+
if allow_fill and fill_value is not None:
1362+
msg = 'Unable to fill values because {0} cannot contain NA'
1363+
raise ValueError(msg.format(self.__class__.__name__))
1364+
taken = self.values.take(indices)
13481365
return self._shallow_copy(taken)
13491366

1367+
def _assert_take_fillable(self, values, indices, allow_fill=True,
1368+
fill_value=None, na_value=np.nan):
1369+
""" Internal method to handle NA filling of take """
1370+
indices = com._ensure_platform_int(indices)
1371+
1372+
# only fill if we are passing a non-None fill_value
1373+
if allow_fill and fill_value is not None:
1374+
if (indices < -1).any():
1375+
msg = ('When allow_fill=True and fill_value is not None, '
1376+
'all indices must be >= -1')
1377+
raise ValueError(msg)
1378+
taken = values.take(indices)
1379+
mask = indices == -1
1380+
if mask.any():
1381+
taken[mask] = na_value
1382+
else:
1383+
taken = values.take(indices)
1384+
return taken
1385+
13501386
@cache_readonly
13511387
def _isnan(self):
13521388
""" return if each value is nan"""

pandas/indexes/category.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -459,21 +459,13 @@ def _convert_list_indexer(self, keyarr, kind=None):
459459

460460
return None
461461

462-
def take(self, indexer, axis=0, allow_fill=True, fill_value=None):
463-
"""
464-
For internal compatibility with numpy arrays.
465-
466-
# filling must always be None/nan here
467-
# but is passed thru internally
468-
assert isnull(fill_value)
469-
470-
See also
471-
--------
472-
numpy.ndarray.take
473-
"""
474-
475-
indexer = com._ensure_platform_int(indexer)
476-
taken = self.codes.take(indexer)
462+
@Appender(_index_shared_docs['take'])
463+
def take(self, indices, axis=0, allow_fill=True, fill_value=None):
464+
indices = com._ensure_platform_int(indices)
465+
taken = self._assert_take_fillable(self.codes, indices,
466+
allow_fill=allow_fill,
467+
fill_value=fill_value,
468+
na_value=-1)
477469
return self._create_from_codes(taken)
478470

479471
def delete(self, loc):

pandas/indexes/multi.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
12
# pylint: disable=E1101,E1103,W0232
23
import datetime
34
import warnings
@@ -11,7 +12,7 @@
1112

1213
from pandas.compat import range, zip, lrange, lzip, map
1314
from pandas import compat
14-
from pandas.core.base import FrozenList
15+
from pandas.core.base import FrozenList, FrozenNDArray
1516
import pandas.core.base as base
1617
from pandas.util.decorators import (Appender, cache_readonly,
1718
deprecate, deprecate_kwarg)
@@ -1003,12 +1004,38 @@ def __getitem__(self, key):
10031004
names=self.names, sortorder=sortorder,
10041005
verify_integrity=False)
10051006

1006-
def take(self, indexer, axis=None):
1007-
indexer = com._ensure_platform_int(indexer)
1008-
new_labels = [lab.take(indexer) for lab in self.labels]
1009-
return MultiIndex(levels=self.levels, labels=new_labels,
1007+
@Appender(_index_shared_docs['take'])
1008+
def take(self, indices, axis=0, allow_fill=True, fill_value=None):
1009+
indices = com._ensure_platform_int(indices)
1010+
taken = self._assert_take_fillable(self.labels, indices,
1011+
allow_fill=allow_fill,
1012+
fill_value=fill_value,
1013+
na_value=-1)
1014+
return MultiIndex(levels=self.levels, labels=taken,
10101015
names=self.names, verify_integrity=False)
10111016

1017+
def _assert_take_fillable(self, values, indices, allow_fill=True,
1018+
fill_value=None, na_value=None):
1019+
""" Internal method to handle NA filling of take """
1020+
# only fill if we are passing a non-None fill_value
1021+
if allow_fill and fill_value is not None:
1022+
if (indices < -1).any():
1023+
msg = ('When allow_fill=True and fill_value is not None, '
1024+
'all indices must be >= -1')
1025+
raise ValueError(msg)
1026+
taken = [lab.take(indices) for lab in self.labels]
1027+
mask = indices == -1
1028+
if mask.any():
1029+
masked = []
1030+
for new_label in taken:
1031+
label_values = new_label.values()
1032+
label_values[mask] = na_value
1033+
masked.append(base.FrozenNDArray(label_values))
1034+
taken = masked
1035+
else:
1036+
taken = [lab.take(indices) for lab in self.labels]
1037+
return taken
1038+
10121039
def append(self, other):
10131040
"""
10141041
Append a collection of Index options together

pandas/tests/indexes/test_base.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1240,6 +1240,34 @@ def test_nan_first_take_datetime(self):
12401240
exp = Index([idx[-1], idx[0], idx[1]])
12411241
tm.assert_index_equal(res, exp)
12421242

1243+
def test_take_fill_value(self):
1244+
# GH 12631
1245+
idx = pd.Index(list('ABC'), name='xxx')
1246+
result = idx.take(np.array([1, 0, -1]))
1247+
expected = pd.Index(list('BAC'), name='xxx')
1248+
tm.assert_index_equal(result, expected)
1249+
1250+
# fill_value
1251+
result = idx.take(np.array([1, 0, -1]), fill_value=True)
1252+
expected = pd.Index(['B', 'A', np.nan], name='xxx')
1253+
tm.assert_index_equal(result, expected)
1254+
1255+
# allow_fill=False
1256+
result = idx.take(np.array([1, 0, -1]), allow_fill=False,
1257+
fill_value=True)
1258+
expected = pd.Index(['B', 'A', 'C'], name='xxx')
1259+
tm.assert_index_equal(result, expected)
1260+
1261+
msg = ('When allow_fill=True and fill_value is not None, '
1262+
'all indices must be >= -1')
1263+
with tm.assertRaisesRegexp(ValueError, msg):
1264+
idx.take(np.array([1, 0, -2]), fill_value=True)
1265+
with tm.assertRaisesRegexp(ValueError, msg):
1266+
idx.take(np.array([1, 0, -5]), fill_value=True)
1267+
1268+
with tm.assertRaises(IndexError):
1269+
idx.take(np.array([1, -5]))
1270+
12431271
def test_reindex_preserves_name_if_target_is_list_or_ndarray(self):
12441272
# GH6552
12451273
idx = pd.Index([0, 1, 2])

pandas/tests/indexes/test_category.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -708,3 +708,100 @@ def test_fillna_categorical(self):
708708
with tm.assertRaisesRegexp(ValueError,
709709
'fill value must be in categories'):
710710
idx.fillna(2.0)
711+
712+
def test_take_fill_value(self):
713+
# GH 12631
714+
715+
# numeric category
716+
idx = pd.CategoricalIndex([1, 2, 3], name='xxx')
717+
result = idx.take(np.array([1, 0, -1]))
718+
expected = pd.CategoricalIndex([2, 1, 3], name='xxx')
719+
tm.assert_index_equal(result, expected)
720+
tm.assert_categorical_equal(result.values, expected.values)
721+
722+
# fill_value
723+
result = idx.take(np.array([1, 0, -1]), fill_value=True)
724+
expected = pd.CategoricalIndex([2, 1, np.nan], categories=[1, 2, 3],
725+
name='xxx')
726+
tm.assert_index_equal(result, expected)
727+
tm.assert_categorical_equal(result.values, expected.values)
728+
729+
# allow_fill=False
730+
result = idx.take(np.array([1, 0, -1]), allow_fill=False,
731+
fill_value=True)
732+
expected = pd.CategoricalIndex([2, 1, 3], name='xxx')
733+
tm.assert_index_equal(result, expected)
734+
tm.assert_categorical_equal(result.values, expected.values)
735+
736+
# object category
737+
idx = pd.CategoricalIndex(list('CBA'), categories=list('ABC'),
738+
ordered=True, name='xxx')
739+
result = idx.take(np.array([1, 0, -1]))
740+
expected = pd.CategoricalIndex(list('BCA'), categories=list('ABC'),
741+
ordered=True, name='xxx')
742+
tm.assert_index_equal(result, expected)
743+
tm.assert_categorical_equal(result.values, expected.values)
744+
745+
# fill_value
746+
result = idx.take(np.array([1, 0, -1]), fill_value=True)
747+
expected = pd.CategoricalIndex(['B', 'C', np.nan],
748+
categories=list('ABC'), ordered=True,
749+
name='xxx')
750+
tm.assert_index_equal(result, expected)
751+
tm.assert_categorical_equal(result.values, expected.values)
752+
753+
# allow_fill=False
754+
result = idx.take(np.array([1, 0, -1]), allow_fill=False,
755+
fill_value=True)
756+
expected = pd.CategoricalIndex(list('BCA'), categories=list('ABC'),
757+
ordered=True, name='xxx')
758+
tm.assert_index_equal(result, expected)
759+
tm.assert_categorical_equal(result.values, expected.values)
760+
761+
msg = ('When allow_fill=True and fill_value is not None, '
762+
'all indices must be >= -1')
763+
with tm.assertRaisesRegexp(ValueError, msg):
764+
idx.take(np.array([1, 0, -2]), fill_value=True)
765+
with tm.assertRaisesRegexp(ValueError, msg):
766+
idx.take(np.array([1, 0, -5]), fill_value=True)
767+
768+
with tm.assertRaises(IndexError):
769+
idx.take(np.array([1, -5]))
770+
771+
def test_take_fill_value_datetime(self):
772+
773+
# datetime category
774+
idx = pd.DatetimeIndex(['2011-01-01', '2011-02-01', '2011-03-01'],
775+
name='xxx')
776+
idx = pd.CategoricalIndex(idx)
777+
result = idx.take(np.array([1, 0, -1]))
778+
expected = pd.DatetimeIndex(['2011-02-01', '2011-01-01', '2011-03-01'],
779+
name='xxx')
780+
expected = pd.CategoricalIndex(expected)
781+
tm.assert_index_equal(result, expected)
782+
783+
# fill_value
784+
result = idx.take(np.array([1, 0, -1]), fill_value=True)
785+
expected = pd.DatetimeIndex(['2011-02-01', '2011-01-01', 'NaT'],
786+
name='xxx')
787+
exp_cats = pd.DatetimeIndex(['2011-01-01', '2011-02-01', '2011-03-01'])
788+
expected = pd.CategoricalIndex(expected, categories=exp_cats)
789+
tm.assert_index_equal(result, expected)
790+
791+
# allow_fill=False
792+
result = idx.take(np.array([1, 0, -1]), allow_fill=False,
793+
fill_value=True)
794+
expected = pd.DatetimeIndex(['2011-02-01', '2011-01-01', '2011-03-01'],
795+
name='xxx')
796+
expected = pd.CategoricalIndex(expected)
797+
tm.assert_index_equal(result, expected)
798+
799+
msg = ('When allow_fill=True and fill_value is not None, '
800+
'all indices must be >= -1')
801+
with tm.assertRaisesRegexp(ValueError, msg):
802+
idx.take(np.array([1, 0, -2]), fill_value=True)
803+
with tm.assertRaisesRegexp(ValueError, msg):
804+
idx.take(np.array([1, 0, -5]), fill_value=True)
805+
806+
with tm.assertRaises(IndexError):
807+
idx.take(np.array([1, -5]))

pandas/tests/indexes/test_multi.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1514,6 +1514,46 @@ def test_take_preserve_name(self):
15141514
taken = self.index.take([3, 0, 1])
15151515
self.assertEqual(taken.names, self.index.names)
15161516

1517+
def test_take_fill_value(self):
1518+
# GH 12631
1519+
vals = [['A', 'B'],
1520+
[pd.Timestamp('2011-01-01'), pd.Timestamp('2011-01-02')]]
1521+
idx = pd.MultiIndex.from_product(vals, names=['str', 'dt'])
1522+
1523+
result = idx.take(np.array([1, 0, -1]))
1524+
exp_vals = [('A', pd.Timestamp('2011-01-02')),
1525+
('A', pd.Timestamp('2011-01-01')),
1526+
('B', pd.Timestamp('2011-01-02'))]
1527+
expected = pd.MultiIndex.from_tuples(exp_vals, names=['str', 'dt'])
1528+
tm.assert_index_equal(result, expected)
1529+
1530+
# fill_value
1531+
result = idx.take(np.array([1, 0, -1]), fill_value=True)
1532+
exp_vals = [('A', pd.Timestamp('2011-01-02')),
1533+
('A', pd.Timestamp('2011-01-01')),
1534+
(np.nan, pd.NaT)]
1535+
expected = pd.MultiIndex.from_tuples(exp_vals, names=['str', 'dt'])
1536+
tm.assert_index_equal(result, expected)
1537+
1538+
# allow_fill=False
1539+
result = idx.take(np.array([1, 0, -1]), allow_fill=False,
1540+
fill_value=True)
1541+
exp_vals = [('A', pd.Timestamp('2011-01-02')),
1542+
('A', pd.Timestamp('2011-01-01')),
1543+
('B', pd.Timestamp('2011-01-02'))]
1544+
expected = pd.MultiIndex.from_tuples(exp_vals, names=['str', 'dt'])
1545+
tm.assert_index_equal(result, expected)
1546+
1547+
msg = ('When allow_fill=True and fill_value is not None, '
1548+
'all indices must be >= -1')
1549+
with tm.assertRaisesRegexp(ValueError, msg):
1550+
idx.take(np.array([1, 0, -2]), fill_value=True)
1551+
with tm.assertRaisesRegexp(ValueError, msg):
1552+
idx.take(np.array([1, 0, -5]), fill_value=True)
1553+
1554+
with tm.assertRaises(IndexError):
1555+
idx.take(np.array([1, -5]))
1556+
15171557
def test_join_level(self):
15181558
def _check_how(other, how):
15191559
join_index, lidx, ridx = other.join(self.index, how=how,

0 commit comments

Comments
 (0)