Skip to content

Commit e750c94

Browse files
authored
ENH: allow storing ExtensionArrays in Index (#43930)
1 parent 3c6a26e commit e750c94

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

57 files changed

+509
-133
lines changed

doc/source/whatsnew/v1.4.0.rst

+37
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,43 @@ be removed in the future, see :ref:`here <whatsnew_140.deprecations.int64_uint64
9090

9191
See :ref:`here <advanced.numericindex>` for more about :class:`NumericIndex`.
9292

93+
94+
.. _whatsnew_140.enhancements.ExtensionIndex:
95+
96+
Index can hold arbitrary ExtensionArrays
97+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
98+
99+
Until now, passing a custom :class:`ExtensionArray` to ``pd.Index`` would cast the
100+
array to ``object`` dtype. Now :class:`Index` can directly hold arbitrary ExtensionArrays (:issue:`43930`).
101+
102+
*Example*:
103+
104+
.. ipython:: python
105+
106+
arr = pd.array([1, 2, pd.NA])
107+
idx = pd.Index(arr)
108+
109+
In the old behavior, ``idx`` would be object-dtype:
110+
111+
*Previous behavior*:
112+
113+
.. code-block:: ipython
114+
115+
In [1]: idx
116+
Out[1]: Index([1, 2, <NA>], dtype='object')
117+
118+
With the new behavior, we keep the original dtype:
119+
120+
*New behavior*:
121+
122+
.. ipython:: python
123+
124+
idx
125+
126+
One exception to this is ``SparseArray``, which will continue to cast to numpy
127+
dtype until pandas 2.0. At that point it will retain its dtype like other
128+
ExtensionArrays.
129+
93130
.. _whatsnew_140.enhancements.styler:
94131

95132
Styler

pandas/_libs/index.pyx

+3-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@ from pandas._libs import (
3333
hashtable as _hash,
3434
)
3535

36+
from pandas._libs.lib cimport eq_NA_compat
3637
from pandas._libs.missing cimport (
38+
C_NA as NA,
3739
checknull,
3840
is_matching_na,
3941
)
@@ -62,7 +64,7 @@ cdef ndarray _get_bool_indexer(ndarray values, object val):
6264
if values.descr.type_num == cnp.NPY_OBJECT:
6365
# i.e. values.dtype == object
6466
if not checknull(val):
65-
indexer = values == val
67+
indexer = eq_NA_compat(values, val)
6668

6769
else:
6870
# We need to check for _matching_ NA values

pandas/_libs/lib.pxd

+5
Original file line numberDiff line numberDiff line change
@@ -1 +1,6 @@
1+
from numpy cimport ndarray
2+
3+
14
cdef bint c_is_list_like(object, bint) except -1
5+
6+
cpdef ndarray eq_NA_compat(ndarray[object] arr, object key)

pandas/_libs/lib.pyx

+21
Original file line numberDiff line numberDiff line change
@@ -3050,6 +3050,27 @@ def is_bool_list(obj: list) -> bool:
30503050
return True
30513051

30523052

3053+
cpdef ndarray eq_NA_compat(ndarray[object] arr, object key):
3054+
"""
3055+
Check for `arr == key`, treating all values as not-equal to pd.NA.
3056+
3057+
key is assumed to have `not isna(key)`
3058+
"""
3059+
cdef:
3060+
ndarray[uint8_t, cast=True] result = np.empty(len(arr), dtype=bool)
3061+
Py_ssize_t i
3062+
object item
3063+
3064+
for i in range(len(arr)):
3065+
item = arr[i]
3066+
if item is C_NA:
3067+
result[i] = False
3068+
else:
3069+
result[i] = item == key
3070+
3071+
return result
3072+
3073+
30533074
def dtypes_all_equal(list types not None) -> bool:
30543075
"""
30553076
Faster version for:

pandas/_testing/asserters.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -404,9 +404,9 @@ def _get_ilevel_values(index, level):
404404
# skip exact index checking when `check_categorical` is False
405405
if check_exact and check_categorical:
406406
if not left.equals(right):
407-
diff = (
408-
np.sum((left._values != right._values).astype(int)) * 100.0 / len(left)
409-
)
407+
mismatch = left._values != right._values
408+
409+
diff = np.sum(mismatch.astype(int)) * 100.0 / len(left)
410410
msg = f"{obj} values are different ({np.round(diff, 5)} %)"
411411
raise_assert_detail(obj, msg, left, right)
412412
else:

pandas/conftest.py

+16
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,14 @@
6767
MultiIndex,
6868
)
6969

70+
try:
71+
import pyarrow as pa
72+
except ImportError:
73+
has_pyarrow = False
74+
else:
75+
del pa
76+
has_pyarrow = True
77+
7078
# Until https://github.com/numpy/numpy/issues/19078 is sorted out, just suppress
7179
suppress_npdev_promotion_warning = pytest.mark.filterwarnings(
7280
"ignore:Promotion of numbers and bools:FutureWarning"
@@ -549,7 +557,15 @@ def _create_mi_with_dt64tz_level():
549557
"mi-with-dt64tz-level": _create_mi_with_dt64tz_level(),
550558
"multi": _create_multiindex(),
551559
"repeats": Index([0, 0, 1, 1, 2, 2]),
560+
"nullable_int": Index(np.arange(100), dtype="Int64"),
561+
"nullable_uint": Index(np.arange(100), dtype="UInt16"),
562+
"nullable_float": Index(np.arange(100), dtype="Float32"),
563+
"nullable_bool": Index(np.arange(100).astype(bool), dtype="boolean"),
564+
"string-python": Index(pd.array(tm.makeStringIndex(100), dtype="string[python]")),
552565
}
566+
if has_pyarrow:
567+
idx = Index(pd.array(tm.makeStringIndex(100), dtype="string[pyarrow]"))
568+
indices_dict["string-pyarrow"] = idx
553569

554570

555571
@pytest.fixture(params=indices_dict.keys())

pandas/core/arrays/masked.py

+4-8
Original file line numberDiff line numberDiff line change
@@ -714,10 +714,7 @@ def value_counts(self, dropna: bool = True) -> Series:
714714
data = self._data[~self._mask]
715715
value_counts = Index(data).value_counts()
716716

717-
# TODO(ExtensionIndex)
718-
# if we have allow Index to hold an ExtensionArray
719-
# this is easier
720-
index = value_counts.index._values.astype(object)
717+
index = value_counts.index
721718

722719
# if we want nans, count the mask
723720
if dropna:
@@ -727,10 +724,9 @@ def value_counts(self, dropna: bool = True) -> Series:
727724
counts[:-1] = value_counts
728725
counts[-1] = self._mask.sum()
729726

730-
index = Index(
731-
np.concatenate([index, np.array([self.dtype.na_value], dtype=object)]),
732-
dtype=object,
733-
)
727+
index = index.insert(len(index), self.dtype.na_value)
728+
729+
index = index.astype(self.dtype)
734730

735731
mask = np.zeros(len(counts), dtype="bool")
736732
counts = IntegerArray(counts, mask)

pandas/core/arrays/string_.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -470,7 +470,9 @@ def max(self, axis=None, skipna: bool = True, **kwargs) -> Scalar:
470470
def value_counts(self, dropna: bool = True):
471471
from pandas import value_counts
472472

473-
return value_counts(self._ndarray, dropna=dropna).astype("Int64")
473+
result = value_counts(self._ndarray, dropna=dropna).astype("Int64")
474+
result.index = result.index.astype(self.dtype)
475+
return result
474476

475477
def memory_usage(self, deep: bool = False) -> int:
476478
result = self._ndarray.nbytes

pandas/core/arrays/string_arrow.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,13 @@ def __getitem__(
313313
elif isinstance(item, tuple):
314314
item = unpack_tuple_and_ellipses(item)
315315

316+
# error: Non-overlapping identity check (left operand type:
317+
# "Union[Union[int, integer[Any]], Union[slice, List[int],
318+
# ndarray[Any, Any]]]", right operand type: "ellipsis")
319+
if item is Ellipsis: # type: ignore[comparison-overlap]
320+
# TODO: should be handled by pyarrow?
321+
item = slice(None)
322+
316323
if is_scalar(item) and not is_integer(item):
317324
# e.g. "foo" or 2.5
318325
# exception message copied from numpy
@@ -615,8 +622,7 @@ def value_counts(self, dropna: bool = True) -> Series:
615622
# No missing values so we can adhere to the interface and return a numpy array.
616623
counts = np.array(counts)
617624

618-
# Index cannot hold ExtensionArrays yet
619-
index = Index(type(self)(values)).astype(object)
625+
index = Index(type(self)(values))
620626

621627
return Series(counts, index=index).astype("Int64")
622628

pandas/core/dtypes/common.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -1325,12 +1325,8 @@ def is_bool_dtype(arr_or_dtype) -> bool:
13251325
# now we use the special definition for Index
13261326

13271327
if isinstance(arr_or_dtype, ABCIndex):
1328-
1329-
# TODO(jreback)
1330-
# we don't have a boolean Index class
1331-
# so its object, we need to infer to
1332-
# guess this
1333-
return arr_or_dtype.is_object() and arr_or_dtype.inferred_type == "boolean"
1328+
# Allow Index[object] that is all-bools or Index["boolean"]
1329+
return arr_or_dtype.inferred_type == "boolean"
13341330
elif isinstance(dtype, ExtensionDtype):
13351331
return getattr(dtype, "_is_boolean", False)
13361332

0 commit comments

Comments
 (0)