From d589938bbdfbb99dec2803417bfbd710646abf85 Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 22 Dec 2020 21:32:03 -0800 Subject: [PATCH 1/5] BUG: MultiIndex, IntervalIndex intersection with Categorical --- pandas/core/dtypes/dtypes.py | 2 +- pandas/core/indexes/multi.py | 14 ++++++++------ pandas/core/indexes/period.py | 16 ++++++++++------ .../tests/dtypes/cast/test_find_common_type.py | 12 ++++++++++++ pandas/tests/indexes/multi/test_equivalence.py | 2 ++ pandas/tests/indexes/multi/test_setops.py | 14 ++++++++++++++ pandas/tests/indexes/test_setops.py | 17 +++++++++++++++++ 7 files changed, 64 insertions(+), 13 deletions(-) diff --git a/pandas/core/dtypes/dtypes.py b/pandas/core/dtypes/dtypes.py index 75f3b511bc57d..4cfba314c719c 100644 --- a/pandas/core/dtypes/dtypes.py +++ b/pandas/core/dtypes/dtypes.py @@ -1173,7 +1173,7 @@ def __from_arrow__( def _get_common_dtype(self, dtypes: List[DtypeObj]) -> Optional[DtypeObj]: # NB: this doesn't handle checking for closed match if not all(isinstance(x, IntervalDtype) for x in dtypes): - return np.dtype(object) + return None from pandas.core.dtypes.cast import find_common_type diff --git a/pandas/core/indexes/multi.py b/pandas/core/indexes/multi.py index 5312dfe84cfd8..be7f5b8a2534d 100644 --- a/pandas/core/indexes/multi.py +++ b/pandas/core/indexes/multi.py @@ -3449,8 +3449,8 @@ def equals(self, other: object) -> bool: if not isinstance(other, MultiIndex): # d-level MultiIndex can equal d-tuple Index - if not is_object_dtype(other.dtype): - # other cannot contain tuples, so cannot match self + if not self._should_compare(other): + # object Index or Categorical[object] may contain tuples return False return array_equivalent(self._values, other._values) @@ -3588,13 +3588,15 @@ def union(self, other, sort=None): def _union(self, other, sort): other, result_names = self._convert_can_do_setop(other) - if not is_object_dtype(other.dtype): + if not self._should_compare(other): raise NotImplementedError( "Can only union MultiIndex with MultiIndex or Index of tuples, " "try mi.to_flat_index().union(other) instead." ) - uniq_tuples = lib.fast_unique_multiple([self._values, other._values], sort=sort) + # We could get here with CategoricalIndex other + rvals = other._values.astype(object, copy=False) + uniq_tuples = lib.fast_unique_multiple([self._values, rvals], sort=sort) return MultiIndex.from_arrays( zip(*uniq_tuples), sortorder=0, names=result_names @@ -3666,12 +3668,12 @@ def intersection(self, other, sort=False): def _intersection(self, other, sort=False): other, result_names = self._convert_can_do_setop(other) - if not self._is_comparable_dtype(other.dtype): + if not self._should_compare(other): # The intersection is empty return self[:0].rename(result_names) lvals = self._values - rvals = other._values + rvals = other._values.astype(object, copy=False) uniq_tuples = None # flag whether _inner_indexer was successful if self.is_monotonic and other.is_monotonic: diff --git a/pandas/core/indexes/period.py b/pandas/core/indexes/period.py index f8a62c6a8e006..19544a2ed1bd3 100644 --- a/pandas/core/indexes/period.py +++ b/pandas/core/indexes/period.py @@ -11,6 +11,7 @@ from pandas.errors import InvalidIndexError from pandas.util._decorators import cache_readonly, doc +from pandas.core.dtypes.cast import find_common_type from pandas.core.dtypes.common import ( is_bool_dtype, is_datetime64_any_dtype, @@ -645,15 +646,18 @@ def intersection(self, other, sort=False): return self._intersection(other, sort=sort) def _intersection(self, other, sort=False): + other, result_name = self._convert_can_do_setop(other) - if is_object_dtype(other.dtype): - return self.astype("O").intersection(other, sort=sort) - - elif not self._is_comparable_dtype(other.dtype): + if not self._should_compare(other): # We can infer that the intersection is empty. # assert_can_do_setop ensures that this is not just a mismatched freq - this = self[:0].astype("O") - other = other[:0].astype("O") + return Index([], name=result_name) + + elif not is_dtype_equal(self.dtype, other.dtype): + # we can get here with Categorical or object + dtype = find_common_type([self.dtype, other.dtype]) + this = self.astype(dtype, copy=False) + other = other.astype(dtype, copy=False) return this.intersection(other, sort=sort) return self._setop(other, sort, opname="intersection") diff --git a/pandas/tests/dtypes/cast/test_find_common_type.py b/pandas/tests/dtypes/cast/test_find_common_type.py index 7b1aa12dc0cc4..6043deec573f8 100644 --- a/pandas/tests/dtypes/cast/test_find_common_type.py +++ b/pandas/tests/dtypes/cast/test_find_common_type.py @@ -9,6 +9,8 @@ PeriodDtype, ) +from pandas import Categorical, Index + @pytest.mark.parametrize( "source_dtypes,expected_common_dtype", @@ -156,3 +158,13 @@ def test_interval_dtype(left, right): else: assert result == object + + +@pytest.mark.parametrize("dtype", interval_dtypes) +def test_interval_dtype_with_categorical(dtype): + obj = Index([], dtype=dtype) + + cat = Categorical([], categories=obj) + + result = find_common_type([dtype, cat.dtype]) + assert result == dtype diff --git a/pandas/tests/indexes/multi/test_equivalence.py b/pandas/tests/indexes/multi/test_equivalence.py index bb34760e28d96..4ee235141ed76 100644 --- a/pandas/tests/indexes/multi/test_equivalence.py +++ b/pandas/tests/indexes/multi/test_equivalence.py @@ -10,6 +10,8 @@ def test_equals(idx): assert idx.equals(idx) assert idx.equals(idx.copy()) assert idx.equals(idx.astype(object)) + assert idx.equals(idx.to_flat_index()) + assert idx.equals(idx.to_flat_index().astype("category")) assert not idx.equals(list(idx)) assert not idx.equals(np.array(idx)) diff --git a/pandas/tests/indexes/multi/test_setops.py b/pandas/tests/indexes/multi/test_setops.py index a26eb793afe7e..9a7ff78bae3db 100644 --- a/pandas/tests/indexes/multi/test_setops.py +++ b/pandas/tests/indexes/multi/test_setops.py @@ -294,6 +294,20 @@ def test_intersection(idx, sort): # assert result.equals(tuples) +@pytest.mark.parametrize("method", ["intersection", "union"]) +def test_setop_with_categorical(idx, sort, method): + other = idx.to_flat_index().astype("category") + res_names = [None] * idx.nlevels + + result = getattr(idx, method)(other, sort=sort) + expected = getattr(idx, method)(idx, sort=sort).rename(res_names) + tm.assert_index_equal(result, expected) + + result = getattr(idx, method)(other[:5], sort=sort) + expected = getattr(idx, method)(idx[:5], sort=sort).rename(res_names) + tm.assert_index_equal(result, expected) + + def test_intersection_non_object(idx, sort): other = Index(range(3), name="foo") diff --git a/pandas/tests/indexes/test_setops.py b/pandas/tests/indexes/test_setops.py index 6f949960ce30b..538e937703de6 100644 --- a/pandas/tests/indexes/test_setops.py +++ b/pandas/tests/indexes/test_setops.py @@ -446,3 +446,20 @@ def test_intersection_difference_match_empty(self, index, sort): inter = index.intersection(index[:0]) diff = index.difference(index, sort=sort) tm.assert_index_equal(inter, diff, exact=True) + + +@pytest.mark.parametrize("method", ["intersection", "union"]) +def test_setop_with_categorical(index, sort, method): + if isinstance(index, MultiIndex): + # tested separately in tests.indexes.multi.test_setops + return + + other = index.astype("category") + + result = getattr(index, method)(other, sort=sort) + expected = getattr(index, method)(index, sort=sort) + tm.assert_index_equal(result, expected) + + result = getattr(index, method)(other[:5], sort=sort) + expected = getattr(index, method)(index[:5], sort=sort) + tm.assert_index_equal(result, expected) From ed4d27d8a4ab44961af27fdda1e1f34547be7287 Mon Sep 17 00:00:00 2001 From: Brock Date: Wed, 23 Dec 2020 07:33:31 -0800 Subject: [PATCH 2/5] standardize --- pandas/core/indexes/base.py | 7 +++++-- pandas/core/indexes/datetimelike.py | 16 +++++++++++----- pandas/core/indexes/multi.py | 20 ++++++++++++++------ pandas/core/indexes/period.py | 14 +++++++------- 4 files changed, 37 insertions(+), 20 deletions(-) diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index f5b9d0194833a..83479996feb66 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -2796,14 +2796,17 @@ def intersection(self, other, sort=False): """ self._validate_sort_keyword(sort) self._assert_can_do_setop(other) - other, _ = self._convert_can_do_setop(other) + other, result_name = self._convert_can_do_setop(other) if self.equals(other): if self.has_duplicates: return self.unique()._get_reconciled_name_object(other) return self._get_reconciled_name_object(other) - if not is_dtype_equal(self.dtype, other.dtype): + elif not self._should_compare(other): + return Index([], name=result_name) + + elif not is_dtype_equal(self.dtype, other.dtype): dtype = find_common_type([self.dtype, other.dtype]) this = self.astype(dtype, copy=False) other = other.astype(dtype, copy=False) diff --git a/pandas/core/indexes/datetimelike.py b/pandas/core/indexes/datetimelike.py index 220cd5363e78f..fe432dcd78bcd 100644 --- a/pandas/core/indexes/datetimelike.py +++ b/pandas/core/indexes/datetimelike.py @@ -12,6 +12,7 @@ from pandas.compat.numpy import function as nv from pandas.util._decorators import Appender, cache_readonly, doc +from pandas.core.dtypes.cast import find_common_type from pandas.core.dtypes.common import ( is_bool_dtype, is_categorical_dtype, @@ -683,13 +684,22 @@ def intersection(self, other, sort=False): """ self._validate_sort_keyword(sort) self._assert_can_do_setop(other) - other, _ = self._convert_can_do_setop(other) + other, result_name = self._convert_can_do_setop(other) if self.equals(other): if self.has_duplicates: return self.unique()._get_reconciled_name_object(other) return self._get_reconciled_name_object(other) + elif not self._should_compare(other): + return Index([], name=result_name) + + elif not is_dtype_equal(self.dtype, other.dtype): + dtype = find_common_type([self.dtype, other.dtype]) + this = self.astype(dtype, copy=False) + other = other.astype(dtype, copy=False) + return this.intersection(other, sort=sort) + return self._intersection(other, sort=sort) def _intersection(self, other: Index, sort=False) -> Index: @@ -701,10 +711,6 @@ def _intersection(self, other: Index, sort=False) -> Index: if len(other) == 0: return other.copy()._get_reconciled_name_object(self) - if not isinstance(other, type(self)): - result = Index.intersection(self, other, sort=sort) - return result - elif not self._can_fast_intersect(other): result = Index._intersection(self, other, sort=sort) # We need to invalidate the freq because Index._intersection diff --git a/pandas/core/indexes/multi.py b/pandas/core/indexes/multi.py index be7f5b8a2534d..10ba9e56817af 100644 --- a/pandas/core/indexes/multi.py +++ b/pandas/core/indexes/multi.py @@ -25,11 +25,12 @@ from pandas.errors import InvalidIndexError, PerformanceWarning, UnsortedIndexError from pandas.util._decorators import Appender, cache_readonly, doc -from pandas.core.dtypes.cast import coerce_indexer_dtype +from pandas.core.dtypes.cast import coerce_indexer_dtype, find_common_type from pandas.core.dtypes.common import ( ensure_int64, ensure_platform_int, is_categorical_dtype, + is_dtype_equal, is_hashable, is_integer, is_iterator, @@ -3656,22 +3657,29 @@ def intersection(self, other, sort=False): """ self._validate_sort_keyword(sort) self._assert_can_do_setop(other) - other, _ = self._convert_can_do_setop(other) + other, result_names = self._convert_can_do_setop(other) if self.equals(other): if self.has_duplicates: return self.unique()._get_reconciled_name_object(other) return self._get_reconciled_name_object(other) + elif not self._should_compare(other): + # The intersection is empty + return self[:0].rename(result_names) + + elif not is_dtype_equal(self.dtype, other.dtype): + # e.g. Categorical[object] + dtype = find_common_type([self.dtype, other.dtype]) + this = self.astype(dtype, copy=False) + other = other.astype(dtype, copy=False) + return this.intersection(other, sort=sort) + return self._intersection(other, sort=sort) def _intersection(self, other, sort=False): other, result_names = self._convert_can_do_setop(other) - if not self._should_compare(other): - # The intersection is empty - return self[:0].rename(result_names) - lvals = self._values rvals = other._values.astype(object, copy=False) diff --git a/pandas/core/indexes/period.py b/pandas/core/indexes/period.py index 19544a2ed1bd3..7fc15b6e295d9 100644 --- a/pandas/core/indexes/period.py +++ b/pandas/core/indexes/period.py @@ -636,19 +636,14 @@ def _setop(self, other, sort, opname: str): def intersection(self, other, sort=False): self._validate_sort_keyword(sort) self._assert_can_do_setop(other) - other, _ = self._convert_can_do_setop(other) + other, result_name = self._convert_can_do_setop(other) if self.equals(other): if self.has_duplicates: return self.unique()._get_reconciled_name_object(other) return self._get_reconciled_name_object(other) - return self._intersection(other, sort=sort) - - def _intersection(self, other, sort=False): - other, result_name = self._convert_can_do_setop(other) - - if not self._should_compare(other): + elif not self._should_compare(other): # We can infer that the intersection is empty. # assert_can_do_setop ensures that this is not just a mismatched freq return Index([], name=result_name) @@ -660,6 +655,11 @@ def _intersection(self, other, sort=False): other = other.astype(dtype, copy=False) return this.intersection(other, sort=sort) + return self._intersection(other, sort=sort) + + def _intersection(self, other, sort=False): + other, result_name = self._convert_can_do_setop(other) + return self._setop(other, sort, opname="intersection") def difference(self, other, sort=None): From c212bc1cf334461b8210aea7c1240f83e79bc949 Mon Sep 17 00:00:00 2001 From: Brock Date: Wed, 23 Dec 2020 08:53:46 -0800 Subject: [PATCH 3/5] Share intersection --- pandas/core/indexes/base.py | 2 ++ pandas/core/indexes/datetimelike.py | 49 ----------------------------- pandas/core/indexes/multi.py | 46 +-------------------------- pandas/core/indexes/period.py | 25 --------------- 4 files changed, 3 insertions(+), 119 deletions(-) diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index e18b83a22202e..8d48a6277d412 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -2814,6 +2814,8 @@ def intersection(self, other, sort=False): elif not self._should_compare(other): # We can infer that the intersection is empty. + if isinstance(self, ABCMultiIndex): + return self[:0].rename(result_name) return Index([], name=result_name) elif not is_dtype_equal(self.dtype, other.dtype): diff --git a/pandas/core/indexes/datetimelike.py b/pandas/core/indexes/datetimelike.py index 33fd637909c00..c2dfd97481aa6 100644 --- a/pandas/core/indexes/datetimelike.py +++ b/pandas/core/indexes/datetimelike.py @@ -12,7 +12,6 @@ from pandas.compat.numpy import function as nv from pandas.util._decorators import Appender, cache_readonly, doc -from pandas.core.dtypes.cast import find_common_type from pandas.core.dtypes.common import ( is_bool_dtype, is_categorical_dtype, @@ -655,54 +654,6 @@ def difference(self, other, sort=None): new_idx = super().difference(other, sort=sort)._with_freq(None) return new_idx - def intersection(self, other, sort=False): - """ - Specialized intersection for DatetimeIndex/TimedeltaIndex. - - May be much faster than Index.intersection - - Parameters - ---------- - other : Same type as self or array-like - sort : False or None, default False - Sort the resulting index if possible. - - .. versionadded:: 0.24.0 - - .. versionchanged:: 0.24.1 - - Changed the default to ``False`` to match the behaviour - from before 0.24.0. - - .. versionchanged:: 0.25.0 - - The `sort` keyword is added - - Returns - ------- - y : Index or same type as self - """ - self._validate_sort_keyword(sort) - self._assert_can_do_setop(other) - other, result_name = self._convert_can_do_setop(other) - - if self.equals(other): - if self.has_duplicates: - return self.unique()._get_reconciled_name_object(other) - return self._get_reconciled_name_object(other) - - elif not self._should_compare(other): - # We can infer that the intersection is empty. - return Index([], name=result_name) - - elif not is_dtype_equal(self.dtype, other.dtype): - dtype = find_common_type([self.dtype, other.dtype]) - this = self.astype(dtype, copy=False) - other = other.astype(dtype, copy=False) - return this.intersection(other, sort=sort) - - return self._intersection(other, sort=sort) - def _intersection(self, other: Index, sort=False) -> Index: """ intersection specialized to the case with matching dtypes. diff --git a/pandas/core/indexes/multi.py b/pandas/core/indexes/multi.py index 10ba9e56817af..c7be66b596246 100644 --- a/pandas/core/indexes/multi.py +++ b/pandas/core/indexes/multi.py @@ -25,12 +25,11 @@ from pandas.errors import InvalidIndexError, PerformanceWarning, UnsortedIndexError from pandas.util._decorators import Appender, cache_readonly, doc -from pandas.core.dtypes.cast import coerce_indexer_dtype, find_common_type +from pandas.core.dtypes.cast import coerce_indexer_dtype from pandas.core.dtypes.common import ( ensure_int64, ensure_platform_int, is_categorical_dtype, - is_dtype_equal, is_hashable, is_integer, is_iterator, @@ -3634,49 +3633,6 @@ def _maybe_match_names(self, other): names.append(None) return names - def intersection(self, other, sort=False): - """ - Form the intersection of two MultiIndex objects. - - Parameters - ---------- - other : MultiIndex or array / Index of tuples - sort : False or None, default False - Sort the resulting MultiIndex if possible - - .. versionadded:: 0.24.0 - - .. versionchanged:: 0.24.1 - - Changed the default from ``True`` to ``False``, to match - behaviour from before 0.24.0 - - Returns - ------- - Index - """ - self._validate_sort_keyword(sort) - self._assert_can_do_setop(other) - other, result_names = self._convert_can_do_setop(other) - - if self.equals(other): - if self.has_duplicates: - return self.unique()._get_reconciled_name_object(other) - return self._get_reconciled_name_object(other) - - elif not self._should_compare(other): - # The intersection is empty - return self[:0].rename(result_names) - - elif not is_dtype_equal(self.dtype, other.dtype): - # e.g. Categorical[object] - dtype = find_common_type([self.dtype, other.dtype]) - this = self.astype(dtype, copy=False) - other = other.astype(dtype, copy=False) - return this.intersection(other, sort=sort) - - return self._intersection(other, sort=sort) - def _intersection(self, other, sort=False): other, result_names = self._convert_can_do_setop(other) diff --git a/pandas/core/indexes/period.py b/pandas/core/indexes/period.py index a255fc6a542b5..4d48bc0d51912 100644 --- a/pandas/core/indexes/period.py +++ b/pandas/core/indexes/period.py @@ -11,7 +11,6 @@ from pandas.errors import InvalidIndexError from pandas.util._decorators import cache_readonly, doc -from pandas.core.dtypes.cast import find_common_type from pandas.core.dtypes.common import ( is_bool_dtype, is_datetime64_any_dtype, @@ -633,30 +632,6 @@ def _setop(self, other, sort, opname: str): result = type(self)._simple_new(parr, name=res_name) return result - def intersection(self, other, sort=False): - self._validate_sort_keyword(sort) - self._assert_can_do_setop(other) - other, result_name = self._convert_can_do_setop(other) - - if self.equals(other): - if self.has_duplicates: - return self.unique()._get_reconciled_name_object(other) - return self._get_reconciled_name_object(other) - - elif not self._should_compare(other): - # We can infer that the intersection is empty. - # assert_can_do_setop ensures that this is not just a mismatched freq - return Index([], name=result_name) - - elif not is_dtype_equal(self.dtype, other.dtype): - # we can get here with Categorical or object - dtype = find_common_type([self.dtype, other.dtype]) - this = self.astype(dtype, copy=False) - other = other.astype(dtype, copy=False) - return this.intersection(other, sort=sort) - - return self._intersection(other, sort=sort) - def _intersection(self, other, sort=False): return self._setop(other, sort, opname="intersection") From 4fdbe3eeef097e056e6a799c4962aec6916366b3 Mon Sep 17 00:00:00 2001 From: Brock Date: Wed, 23 Dec 2020 08:56:32 -0800 Subject: [PATCH 4/5] whatsnew --- doc/source/whatsnew/v1.3.0.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/doc/source/whatsnew/v1.3.0.rst b/doc/source/whatsnew/v1.3.0.rst index 65bfd8289fe3d..3b5bc5dbd6c83 100644 --- a/doc/source/whatsnew/v1.3.0.rst +++ b/doc/source/whatsnew/v1.3.0.rst @@ -212,7 +212,7 @@ Strings Interval ^^^^^^^^ - +- Bug in :meth:`IntervalIndex.intersection` always returning object-dtype when intersecting with :class:`CategoricalIndex` (:issue:`38653`) - - @@ -236,6 +236,7 @@ MultiIndex - Bug in :meth:`DataFrame.drop` raising ``TypeError`` when :class:`MultiIndex` is non-unique and no level is provided (:issue:`36293`) - Bug in :meth:`MultiIndex.equals` incorrectly returning ``True`` when :class:`MultiIndex` containing ``NaN`` even when they are differntly ordered (:issue:`38439`) +- Bug in :meth:`MultiIndex.intersection` always returning empty when intersecting with :class:`CategoricalIndex` (:issue:`38653`) I/O ^^^ From ce43d7aa6868420a17e7678b86a945ff0a2d82f7 Mon Sep 17 00:00:00 2001 From: Brock Date: Wed, 23 Dec 2020 09:48:32 -0800 Subject: [PATCH 5/5] mypy fixup --- pandas/core/indexes/datetimelike.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pandas/core/indexes/datetimelike.py b/pandas/core/indexes/datetimelike.py index c2dfd97481aa6..94c055e264e71 100644 --- a/pandas/core/indexes/datetimelike.py +++ b/pandas/core/indexes/datetimelike.py @@ -658,6 +658,7 @@ def _intersection(self, other: Index, sort=False) -> Index: """ intersection specialized to the case with matching dtypes. """ + other = cast("DatetimeTimedeltaMixin", other) if len(self) == 0: return self.copy()._get_reconciled_name_object(other) if len(other) == 0: