Skip to content

REF: Share Index comparison and arithmetic methods #43555

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions pandas/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,10 @@
remove_na_arraylike,
)

from pandas.core import algorithms
from pandas.core import (
algorithms,
ops,
)
from pandas.core.accessor import DirNamesMixin
from pandas.core.algorithms import (
duplicated,
Expand All @@ -61,7 +64,11 @@
)
from pandas.core.arraylike import OpsMixin
from pandas.core.arrays import ExtensionArray
from pandas.core.construction import create_series_with_explicit_dtype
from pandas.core.construction import (
create_series_with_explicit_dtype,
ensure_wrapped_if_datetimelike,
extract_array,
)
import pandas.core.nanops as nanops

if TYPE_CHECKING:
Expand Down Expand Up @@ -1238,3 +1245,23 @@ def _duplicated(
self, keep: Literal["first", "last", False] = "first"
) -> npt.NDArray[np.bool_]:
return duplicated(self._values, keep=keep)

def _arith_method(self, other, op):
res_name = ops.get_op_result_name(self, other)

lvalues = self._values
rvalues = extract_array(other, extract_numpy=True, extract_range=True)
rvalues = ops.maybe_prepare_scalar_for_op(rvalues, lvalues.shape)
rvalues = ensure_wrapped_if_datetimelike(rvalues)

with np.errstate(all="ignore"):
result = ops.arithmetic_op(lvalues, rvalues, op)

return self._construct_result(result, name=res_name)

def _construct_result(self, result, name):
"""
Construct an appropriately-wrapped result from the ArrayLike result
of an arithmetic-like operation.
"""
raise AbstractMethodError(self)
35 changes: 25 additions & 10 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6364,7 +6364,10 @@ def _cmp_method(self, other, op):
arr[self.isna()] = False
return arr
elif op in {operator.ne, operator.lt, operator.gt}:
return np.zeros(len(self), dtype=bool)
arr = np.zeros(len(self), dtype=bool)
if self._can_hold_na and not isinstance(self, ABCMultiIndex):
arr[self.isna()] = True
return arr

if isinstance(other, (np.ndarray, Index, ABCSeries, ExtensionArray)) and len(
self
Expand All @@ -6381,6 +6384,9 @@ def _cmp_method(self, other, op):
with np.errstate(all="ignore"):
result = op(self._values, other)

elif isinstance(self._values, ExtensionArray):
result = op(self._values, other)

elif is_object_dtype(self.dtype) and not isinstance(self, ABCMultiIndex):
# don't pass MultiIndex
with np.errstate(all="ignore"):
Expand All @@ -6392,17 +6398,26 @@ def _cmp_method(self, other, op):

return result

def _arith_method(self, other, op):
"""
Wrapper used to dispatch arithmetic operations.
"""
def _construct_result(self, result, name):
if isinstance(result, tuple):
return (
Index._with_infer(result[0], name=name),
Index._with_infer(result[1], name=name),
)
return Index._with_infer(result, name=name)

from pandas import Series
def _arith_method(self, other, op):
if (
isinstance(other, Index)
and is_object_dtype(other.dtype)
and type(other) is not Index
):
# We return NotImplemented for object-dtype index *subclasses* so they have
# a chance to implement ops before we unwrap them.
# See https://github.com/pandas-dev/pandas/issues/31109
return NotImplemented

result = op(Series(self), other)
if isinstance(result, tuple):
return (Index._with_infer(result[0]), Index(result[1]))
return Index._with_infer(result)
return super()._arith_method(other, op)

@final
def _unary_method(self, op):
Expand Down
119 changes: 1 addition & 118 deletions pandas/core/indexes/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,9 @@

from pandas.core.dtypes.common import (
is_dtype_equal,
is_object_dtype,
pandas_dtype,
)
from pandas.core.dtypes.generic import (
ABCDataFrame,
ABCSeries,
)
from pandas.core.dtypes.generic import ABCDataFrame

from pandas.core.arrays import (
Categorical,
Expand All @@ -45,7 +41,6 @@
from pandas.core.arrays.base import ExtensionArray
from pandas.core.indexers import deprecate_ndim_indexing
from pandas.core.indexes.base import Index
from pandas.core.ops import get_op_result_name

if TYPE_CHECKING:

Expand Down Expand Up @@ -154,94 +149,6 @@ def wrapper(cls):
return wrapper


def _make_wrapped_comparison_op(opname: str):
"""
Create a comparison method that dispatches to ``._data``.
"""

def wrapper(self, other):
if isinstance(other, ABCSeries):
# the arrays defer to Series for comparison ops but the indexes
# don't, so we have to unwrap here.
other = other._values

other = _maybe_unwrap_index(other)

op = getattr(self._data, opname)
return op(other)

wrapper.__name__ = opname
return wrapper


def _make_wrapped_arith_op(opname: str):
def method(self, other):
if (
isinstance(other, Index)
and is_object_dtype(other.dtype)
and type(other) is not Index
):
# We return NotImplemented for object-dtype index *subclasses* so they have
# a chance to implement ops before we unwrap them.
# See https://github.com/pandas-dev/pandas/issues/31109
return NotImplemented

try:
meth = getattr(self._data, opname)
except AttributeError as err:
# e.g. Categorical, IntervalArray
cls = type(self).__name__
raise TypeError(
f"cannot perform {opname} with this index type: {cls}"
) from err

result = meth(_maybe_unwrap_index(other))
return _wrap_arithmetic_op(self, other, result)

method.__name__ = opname
return method


def _wrap_arithmetic_op(self, other, result):
if result is NotImplemented:
return NotImplemented

if isinstance(result, tuple):
# divmod, rdivmod
assert len(result) == 2
return (
_wrap_arithmetic_op(self, other, result[0]),
_wrap_arithmetic_op(self, other, result[1]),
)

if not isinstance(result, Index):
# Index.__new__ will choose appropriate subclass for dtype
result = Index(result)

res_name = get_op_result_name(self, other)
result.name = res_name
return result


def _maybe_unwrap_index(obj):
"""
If operating against another Index object, we need to unwrap the underlying
data before deferring to the DatetimeArray/TimedeltaArray/PeriodArray
implementation, otherwise we will incorrectly return NotImplemented.

Parameters
----------
obj : object

Returns
-------
unwrapped object
"""
if isinstance(obj, Index):
return obj._data
return obj


class ExtensionIndex(Index):
"""
Index subclass for indexes backed by ExtensionArray.
Expand Down Expand Up @@ -284,30 +191,6 @@ def _simple_new(
result._reset_identity()
return result

__eq__ = _make_wrapped_comparison_op("__eq__")
__ne__ = _make_wrapped_comparison_op("__ne__")
__lt__ = _make_wrapped_comparison_op("__lt__")
__gt__ = _make_wrapped_comparison_op("__gt__")
__le__ = _make_wrapped_comparison_op("__le__")
__ge__ = _make_wrapped_comparison_op("__ge__")

__add__ = _make_wrapped_arith_op("__add__")
__sub__ = _make_wrapped_arith_op("__sub__")
__radd__ = _make_wrapped_arith_op("__radd__")
__rsub__ = _make_wrapped_arith_op("__rsub__")
__pow__ = _make_wrapped_arith_op("__pow__")
__rpow__ = _make_wrapped_arith_op("__rpow__")
__mul__ = _make_wrapped_arith_op("__mul__")
__rmul__ = _make_wrapped_arith_op("__rmul__")
__floordiv__ = _make_wrapped_arith_op("__floordiv__")
__rfloordiv__ = _make_wrapped_arith_op("__rfloordiv__")
__mod__ = _make_wrapped_arith_op("__mod__")
__rmod__ = _make_wrapped_arith_op("__rmod__")
__divmod__ = _make_wrapped_arith_op("__divmod__")
__rdivmod__ = _make_wrapped_arith_op("__rdivmod__")
__truediv__ = _make_wrapped_arith_op("__truediv__")
__rtruediv__ = _make_wrapped_arith_op("__rtruediv__")

# ---------------------------------------------------------------------
# NDarray-Like Methods

Expand Down
13 changes: 1 addition & 12 deletions pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@
import pandas.core.common as com
from pandas.core.construction import (
create_series_with_explicit_dtype,
ensure_wrapped_if_datetimelike,
extract_array,
is_empty_data,
sanitize_array,
Expand Down Expand Up @@ -5515,18 +5514,8 @@ def _logical_method(self, other, op):
return self._construct_result(res_values, name=res_name)

def _arith_method(self, other, op):
res_name = ops.get_op_result_name(self, other)
self, other = ops.align_method_SERIES(self, other)

lvalues = self._values
rvalues = extract_array(other, extract_numpy=True, extract_range=True)
rvalues = ops.maybe_prepare_scalar_for_op(rvalues, lvalues.shape)
rvalues = ensure_wrapped_if_datetimelike(rvalues)

with np.errstate(all="ignore"):
result = ops.arithmetic_op(lvalues, rvalues, op)

return self._construct_result(result, name=res_name)
return base.IndexOpsMixin._arith_method(self, other, op)


Series._add_numeric_operations()
Expand Down
10 changes: 2 additions & 8 deletions pandas/tests/arithmetic/test_datetime64.py
Original file line number Diff line number Diff line change
Expand Up @@ -2146,7 +2146,7 @@ def test_dti_sub_tdi(self, tz_naive_fixture):
result = dti - tdi.values
tm.assert_index_equal(result, expected)

msg = "cannot subtract DatetimeArray from"
msg = "cannot subtract a datelike from a TimedeltaArray"
with pytest.raises(TypeError, match=msg):
tdi.values - dti

Expand All @@ -2172,13 +2172,7 @@ def test_dti_isub_tdi(self, tz_naive_fixture):
result -= tdi.values
tm.assert_index_equal(result, expected)

msg = "|".join(
[
"cannot perform __neg__ with this index type:",
"ufunc subtract cannot use operands with types",
"cannot subtract DatetimeArray from",
]
)
msg = "cannot subtract a datelike from a TimedeltaArray"
with pytest.raises(TypeError, match=msg):
tdi.values -= dti

Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/arithmetic/test_period.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,7 @@ def test_pi_add_sub_td64_array_non_tick_raises(self):

with pytest.raises(TypeError, match=msg):
rng - tdarr
msg = r"cannot subtract PeriodArray from timedelta64\[ns\]"
msg = r"cannot subtract period\[Q-DEC\]-dtype from TimedeltaArray"
with pytest.raises(TypeError, match=msg):
tdarr - rng

Expand Down