Skip to content

Commit 917b734

Browse files
authored
ENH: cast instead of raise for IntervalIndex setops with different closed (#39267)
1 parent 8d523dd commit 917b734

File tree

6 files changed

+34
-43
lines changed

6 files changed

+34
-43
lines changed

doc/source/whatsnew/v1.3.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ Interval
270270
^^^^^^^^
271271
- Bug in :meth:`IntervalIndex.intersection` and :meth:`IntervalIndex.symmetric_difference` always returning object-dtype when operating with :class:`CategoricalIndex` (:issue:`38653`, :issue:`38741`)
272272
- Bug in :meth:`IntervalIndex.intersection` returning duplicates when at least one of the Indexes has duplicates that are also present in the other (:issue:`38743`)
273-
-
273+
- :meth:`IntervalIndex.union`, :meth:`IntervalIndex.intersection`, :meth:`IntervalIndex.difference`, and :meth:`IntervalIndex.symmetric_difference` now cast to the appropriate dtype instead of raising ``TypeError`` when operating with another :class:`IntervalIndex` with incompatible dtype (:issue:`39267`)
274274

275275
Indexing
276276
^^^^^^^^

pandas/core/indexes/base.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -3007,6 +3007,7 @@ def _intersection(self, other, sort=False):
30073007

30083008
return result
30093009

3010+
@final
30103011
def difference(self, other, sort=None):
30113012
"""
30123013
Return a new Index with elements of index not in `other`.
@@ -3131,7 +3132,7 @@ def symmetric_difference(self, other, result_name=None, sort=None):
31313132
result_name = result_name_update
31323133

31333134
if not self._should_compare(other):
3134-
return self.union(other).rename(result_name)
3135+
return self.union(other, sort=sort).rename(result_name)
31353136
elif not is_dtype_equal(self.dtype, other.dtype):
31363137
dtype = find_common_type([self.dtype, other.dtype])
31373138
this = self.astype(dtype, copy=False)
@@ -6240,7 +6241,7 @@ def _maybe_cast_data_without_dtype(subarr):
62406241
try:
62416242
data = IntervalArray._from_sequence(subarr, copy=False)
62426243
return data
6243-
except ValueError:
6244+
except (ValueError, TypeError):
62446245
# GH27172: mixed closed Intervals --> object dtype
62456246
pass
62466247
elif inferred == "boolean":

pandas/core/indexes/datetimelike.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -660,9 +660,8 @@ def is_type_compatible(self, kind: str) -> bool:
660660
# --------------------------------------------------------------------
661661
# Set Operation Methods
662662

663-
@Appender(Index.difference.__doc__)
664-
def difference(self, other, sort=None):
665-
new_idx = super().difference(other, sort=sort)._with_freq(None)
663+
def _difference(self, other, sort=None):
664+
new_idx = super()._difference(other, sort=sort)._with_freq(None)
666665
return new_idx
667666

668667
def _intersection(self, other: Index, sort=False) -> Index:

pandas/core/indexes/interval.py

+11-18
Original file line numberDiff line numberDiff line change
@@ -127,13 +127,17 @@ def setop_check(method):
127127
def wrapped(self, other, sort=False):
128128
self._validate_sort_keyword(sort)
129129
self._assert_can_do_setop(other)
130-
other, _ = self._convert_can_do_setop(other)
130+
other, result_name = self._convert_can_do_setop(other)
131131

132-
if not isinstance(other, IntervalIndex):
133-
result = getattr(self.astype(object), op_name)(other)
134-
if op_name in ("difference",):
135-
result = result.astype(self.dtype)
136-
return result
132+
if op_name == "difference":
133+
if not isinstance(other, IntervalIndex):
134+
result = getattr(self.astype(object), op_name)(other, sort=sort)
135+
return result.astype(self.dtype)
136+
137+
elif not self._should_compare(other):
138+
# GH#19016: ensure set op will not return a prohibited dtype
139+
result = getattr(self.astype(object), op_name)(other, sort=sort)
140+
return result.astype(self.dtype)
137141

138142
return method(self, other, sort)
139143

@@ -912,17 +916,6 @@ def _format_space(self) -> str:
912916
# --------------------------------------------------------------------
913917
# Set Operations
914918

915-
def _assert_can_do_setop(self, other):
916-
super()._assert_can_do_setop(other)
917-
918-
if isinstance(other, IntervalIndex) and not self._should_compare(other):
919-
# GH#19016: ensure set op will not return a prohibited dtype
920-
raise TypeError(
921-
"can only do set operations between two IntervalIndex "
922-
"objects that are closed on the same side "
923-
"and have compatible dtypes"
924-
)
925-
926919
def _intersection(self, other, sort):
927920
"""
928921
intersection specialized to the case with matching dtypes.
@@ -1014,7 +1007,7 @@ def func(self, other, sort=sort):
10141007
return setop_check(func)
10151008

10161009
_union = _setop("union")
1017-
difference = _setop("difference")
1010+
_difference = _setop("difference")
10181011

10191012
# --------------------------------------------------------------------
10201013

pandas/core/indexes/range.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -632,14 +632,14 @@ def _union(self, other, sort):
632632
return type(self)(start_r, end_r + step_o, step_o)
633633
return self._int64index._union(other, sort=sort)
634634

635-
def difference(self, other, sort=None):
635+
def _difference(self, other, sort=None):
636636
# optimized set operation if we have another RangeIndex
637637
self._validate_sort_keyword(sort)
638638
self._assert_can_do_setop(other)
639639
other, result_name = self._convert_can_do_setop(other)
640640

641641
if not isinstance(other, RangeIndex):
642-
return super().difference(other, sort=sort)
642+
return super()._difference(other, sort=sort)
643643

644644
res_name = ops.get_op_result_name(self, other)
645645

@@ -654,11 +654,11 @@ def difference(self, other, sort=None):
654654
return self[:0].rename(res_name)
655655
if not isinstance(overlap, RangeIndex):
656656
# We won't end up with RangeIndex, so fall back
657-
return super().difference(other, sort=sort)
657+
return super()._difference(other, sort=sort)
658658
if overlap.step != first.step:
659659
# In some cases we might be able to get a RangeIndex back,
660660
# but not worth the effort.
661-
return super().difference(other, sort=sort)
661+
return super()._difference(other, sort=sort)
662662

663663
if overlap[0] == first.start:
664664
# The difference is everything after the intersection
@@ -668,7 +668,7 @@ def difference(self, other, sort=None):
668668
new_rng = range(first.start, overlap[0], first.step)
669669
else:
670670
# The difference is not range-like
671-
return super().difference(other, sort=sort)
671+
return super()._difference(other, sort=sort)
672672

673673
new_index = type(self)._simple_new(new_rng, name=res_name)
674674
if first is not self._range:

pandas/tests/indexes/interval/test_setops.py

+12-14
Original file line numberDiff line numberDiff line change
@@ -178,21 +178,19 @@ def test_set_incompatible_types(self, closed, op_name, sort):
178178
result = set_op(Index([1, 2, 3]), sort=sort)
179179
tm.assert_index_equal(result, expected)
180180

181-
# mixed closed
182-
msg = (
183-
"can only do set operations between two IntervalIndex objects "
184-
"that are closed on the same side and have compatible dtypes"
185-
)
181+
# mixed closed -> cast to object
186182
for other_closed in {"right", "left", "both", "neither"} - {closed}:
187183
other = monotonic_index(0, 11, closed=other_closed)
188-
with pytest.raises(TypeError, match=msg):
189-
set_op(other, sort=sort)
184+
expected = getattr(index.astype(object), op_name)(other, sort=sort)
185+
if op_name == "difference":
186+
expected = index
187+
result = set_op(other, sort=sort)
188+
tm.assert_index_equal(result, expected)
190189

191-
# GH 19016: incompatible dtypes
190+
# GH 19016: incompatible dtypes -> cast to object
192191
other = interval_range(Timestamp("20180101"), periods=9, closed=closed)
193-
msg = (
194-
"can only do set operations between two IntervalIndex objects "
195-
"that are closed on the same side and have compatible dtypes"
196-
)
197-
with pytest.raises(TypeError, match=msg):
198-
set_op(other, sort=sort)
192+
expected = getattr(index.astype(object), op_name)(other, sort=sort)
193+
if op_name == "difference":
194+
expected = index
195+
result = set_op(other, sort=sort)
196+
tm.assert_index_equal(result, expected)

0 commit comments

Comments
 (0)