Skip to content

Commit 50a005c

Browse files
committed
ENH/WIP: Index[bool]
1 parent c2d0291 commit 50a005c

File tree

14 files changed

+51
-22
lines changed

14 files changed

+51
-22
lines changed

pandas/_libs/index.pyi

+1
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ class ObjectEngine(IndexEngine): ...
4141
class DatetimeEngine(Int64Engine): ...
4242
class TimedeltaEngine(DatetimeEngine): ...
4343
class PeriodEngine(Int64Engine): ...
44+
class BoolEngine(UInt8Engine): ...
4445

4546
class BaseMultiIndexCodesEngine:
4647
levels: list[np.ndarray]

pandas/_libs/index.pyx

+6
Original file line numberDiff line numberDiff line change
@@ -795,3 +795,9 @@ cdef class BaseMultiIndexCodesEngine:
795795

796796
# Generated from template.
797797
include "index_class_helper.pxi"
798+
799+
800+
cdef class BoolEngine(UInt8Engine):
801+
cdef _check_type(self, object val):
802+
if not util.is_bool_object(val):
803+
raise KeyError(val)

pandas/conftest.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -529,7 +529,8 @@ def _create_mi_with_dt64tz_level():
529529
"num_uint8": tm.makeNumericIndex(100, dtype="uint8"),
530530
"num_float64": tm.makeNumericIndex(100, dtype="float64"),
531531
"num_float32": tm.makeNumericIndex(100, dtype="float32"),
532-
"bool": tm.makeBoolIndex(10),
532+
"bool-object": tm.makeBoolIndex(10).astype(object),
533+
"bool-dtype": Index(np.random.randn(10) < 0),
533534
"categorical": tm.makeCategoricalIndex(100),
534535
"interval": tm.makeIntervalIndex(100),
535536
"empty": Index([]),

pandas/core/algorithms.py

-3
Original file line numberDiff line numberDiff line change
@@ -219,9 +219,6 @@ def _reconstruct_data(
219219
elif is_bool_dtype(dtype):
220220
values = values.astype(dtype, copy=False)
221221

222-
# we only support object dtypes bool Index
223-
if isinstance(original, ABCIndex):
224-
values = values.astype(object, copy=False)
225222
elif dtype is not None:
226223
if is_datetime64_dtype(dtype):
227224
dtype = np.dtype("datetime64[ns]")

pandas/core/dtypes/common.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1330,7 +1330,7 @@ def is_bool_dtype(arr_or_dtype) -> bool:
13301330
# we don't have a boolean Index class
13311331
# so it's object, we need to infer to
13321332
# guess this
1333-
return arr_or_dtype.is_object() and arr_or_dtype.inferred_type == "boolean"
1333+
return arr_or_dtype.inferred_type == "boolean"
13341334
elif isinstance(dtype, ExtensionDtype):
13351335
return getattr(dtype, "_is_boolean", False)
13361336

pandas/core/indexes/base.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,10 @@ def __new__(
481481
if data.dtype.kind in ["i", "u", "f"]:
482482
# maybe coerce to a sub-class
483483
arr = data
484+
elif data.dtype.kind == "b":
485+
# No special subclass, and Index._ensure_array won't do this
486+
# for us.
487+
arr = np.asarray(data)
484488
else:
485489
arr = com.asarray_tuplesafe(data, dtype=np.dtype("object"))
486490

@@ -672,7 +676,7 @@ def _with_infer(cls, *args, **kwargs):
672676
# "Union[ExtensionArray, ndarray[Any, Any]]"; expected
673677
# "ndarray[Any, Any]"
674678
values = lib.maybe_convert_objects(result._values) # type: ignore[arg-type]
675-
if values.dtype.kind in ["i", "u", "f"]:
679+
if values.dtype.kind in ["i", "u", "f", "b"]:
676680
return Index(values, name=result.name)
677681

678682
return result
@@ -837,6 +841,8 @@ def _engine(self) -> libindex.IndexEngine:
837841
# to avoid a reference cycle, bind `target_values` to a local variable, so
838842
# `self` is not passed into the lambda.
839843
target_values = self._get_engine_target()
844+
if target_values.dtype == bool:
845+
return libindex.BoolEngine(target_values)
840846
return self._engine_type(target_values)
841847

842848
@final
@@ -2548,6 +2554,8 @@ def _is_all_dates(self) -> bool:
25482554
"""
25492555
Whether or not the index values only consist of dates.
25502556
"""
2557+
if self.dtype.kind == "b":
2558+
return False
25512559
return is_datetime_array(ensure_object(self._values))
25522560

25532561
@cache_readonly
@@ -7048,7 +7056,7 @@ def _maybe_cast_data_without_dtype(
70487056
FutureWarning,
70497057
stacklevel=3,
70507058
)
7051-
if result.dtype.kind in ["b", "c"]:
7059+
if result.dtype.kind in ["c"]:
70527060
return subarr
70537061
result = ensure_wrapped_if_datetimelike(result)
70547062
return result

pandas/core/tools/datetimes.py

+2
Original file line numberDiff line numberDiff line change
@@ -925,6 +925,8 @@ def to_datetime(
925925
result = convert_listlike(arg, format)
926926
else:
927927
result = convert_listlike(np.array([arg]), format)[0]
928+
if isinstance(arg, bool) and isinstance(result, np.bool_):
929+
result = bool(result) # TODO: avoid this kludge.
928930

929931
# error: Incompatible return value type (got "Union[Timestamp, NaTType,
930932
# Series, Index]", expected "Union[DatetimeIndex, Series, float, str,

pandas/core/util/hashing.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ def _hash_ndarray(
311311

312312
# First, turn whatever array this is into unsigned 64-bit ints, if we can
313313
# manage it.
314-
elif isinstance(dtype, bool):
314+
elif dtype == bool:
315315
vals = vals.astype("u8")
316316
elif issubclass(dtype.type, (np.datetime64, np.timedelta64)):
317317
vals = vals.view("i8").astype("u8", copy=False)

pandas/tests/indexes/common.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,11 @@ def test_fillna(self, index):
492492
# GH 11343
493493
if len(index) == 0:
494494
return
495-
elif isinstance(index, NumericIndex) and is_integer_dtype(index.dtype):
495+
elif (
496+
isinstance(index, NumericIndex)
497+
and is_integer_dtype(index.dtype)
498+
or index.dtype == bool
499+
):
496500
return
497501
elif isinstance(index, MultiIndex):
498502
idx = index.copy(deep=True)

pandas/tests/indexes/test_base.py

+16-7
Original file line numberDiff line numberDiff line change
@@ -320,15 +320,21 @@ def test_view_with_args(self, index):
320320
"unicode",
321321
"string",
322322
pytest.param("categorical", marks=pytest.mark.xfail(reason="gh-25464")),
323-
"bool",
323+
"bool-object",
324+
"bool-dtype",
324325
"empty",
325326
],
326327
indirect=True,
327328
)
328329
def test_view_with_args_object_array_raises(self, index):
329-
msg = "Cannot change data-type for object array"
330-
with pytest.raises(TypeError, match=msg):
331-
index.view("i8")
330+
if index.dtype == bool:
331+
msg = "When changing to a larger dtype"
332+
with pytest.raises(ValueError, match=msg):
333+
index.view("i8")
334+
else:
335+
msg = "Cannot change data-type for object array"
336+
with pytest.raises(TypeError, match=msg):
337+
index.view("i8")
332338

333339
@pytest.mark.parametrize("index", ["int", "range"], indirect=True)
334340
def test_astype(self, index):
@@ -587,7 +593,8 @@ def test_append_empty_preserve_name(self, name, expected):
587593
"index, expected",
588594
[
589595
("string", False),
590-
("bool", False),
596+
("bool-object", False),
597+
("bool-dtype", False),
591598
("categorical", False),
592599
("int", True),
593600
("datetime", False),
@@ -602,7 +609,8 @@ def test_is_numeric(self, index, expected):
602609
"index, expected",
603610
[
604611
("string", True),
605-
("bool", True),
612+
("bool-object", True),
613+
("bool-dtype", False),
606614
("categorical", False),
607615
("int", False),
608616
("datetime", False),
@@ -617,7 +625,8 @@ def test_is_object(self, index, expected):
617625
"index, expected",
618626
[
619627
("string", False),
620-
("bool", False),
628+
("bool-object", False),
629+
("bool-dtype", False),
621630
("categorical", False),
622631
("int", False),
623632
("datetime", True),

pandas/tests/indexes/test_index_new.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def test_constructor_dtypes_to_object(self, cast_index, vals):
7373
index = Index(vals)
7474

7575
assert type(index) is Index
76-
assert index.dtype == object
76+
assert index.dtype == bool
7777

7878
def test_constructor_categorical_to_object(self):
7979
# GH#32167 Categorical data and dtype=object should return object-dtype

pandas/tests/indexes/test_numpy_compat.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def test_numpy_ufuncs_basic(index, func):
5151
with tm.external_error_raised((TypeError, AttributeError)):
5252
with np.errstate(all="ignore"):
5353
func(index)
54-
elif isinstance(index, NumericIndex):
54+
elif isinstance(index, NumericIndex) or index.dtype == bool:
5555
# coerces to float (e.g. np.sin)
5656
with np.errstate(all="ignore"):
5757
result = func(index)
@@ -89,7 +89,7 @@ def test_numpy_ufuncs_other(index, func, request):
8989
with tm.external_error_raised(TypeError):
9090
func(index)
9191

92-
elif isinstance(index, NumericIndex):
92+
elif isinstance(index, NumericIndex) or index.dtype == bool:
9393
# Results in bool array
9494
result = func(index)
9595
assert isinstance(result, np.ndarray)

pandas/tests/series/indexing/test_setitem.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -696,7 +696,7 @@ def test_index_where(self, obj, key, expected, val, request):
696696
mask[key] = True
697697

698698
res = Index(obj).where(~mask, val)
699-
tm.assert_index_equal(res, Index(expected))
699+
tm.assert_index_equal(res, Index(expected, dtype=expected.dtype))
700700

701701
def test_index_putmask(self, obj, key, expected, val):
702702
if Index(obj).dtype != obj.dtype:
@@ -707,7 +707,7 @@ def test_index_putmask(self, obj, key, expected, val):
707707
mask[key] = True
708708

709709
res = Index(obj).putmask(mask, val)
710-
tm.assert_index_equal(res, Index(expected))
710+
tm.assert_index_equal(res, Index(expected, dtype=expected.dtype))
711711

712712

713713
@pytest.mark.parametrize(

pandas/tests/series/methods/test_drop.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ def test_drop_with_ignore_errors():
5454

5555
# GH 8522
5656
s = Series([2, 3], index=[True, False])
57-
assert s.index.is_object()
57+
assert not s.index.is_object()
58+
assert s.index.dtype == bool
5859
result = s.drop(True)
5960
expected = Series([3], index=[False])
6061
tm.assert_series_equal(result, expected)

0 commit comments

Comments
 (0)