Skip to content

API: Treat 1D vectors as columns in some cases for arith/cmp ops #23306

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -4960,7 +4960,7 @@ def _combine_match_index(self, other, func, level=None):
copy=False)
assert left.index.equals(right.index)

if left._is_mixed_type or right._is_mixed_type:
if ops.should_series_dispatch(left, right, func):
# operate column-wise; avoid costly object-casting in `.values`
return ops.dispatch_to_series(left, right, func)
else:
Expand Down
40 changes: 34 additions & 6 deletions pandas/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -928,12 +928,14 @@ def should_series_dispatch(left, right, op):
if left._is_mixed_type or right._is_mixed_type:
return True

if not len(left.columns) or not len(right.columns):
if not len(left.columns) or (isinstance(right, ABCDataFrame) and
not len(right.columns)):
# ensure obj.dtypes[0] exists for each obj
return False

ldtype = left.dtypes.iloc[0]
rdtype = right.dtypes.iloc[0]
rdtype = (right.dtypes.iloc[0]
if isinstance(right, ABCDataFrame) else right.dtype)

if ((is_timedelta64_dtype(ldtype) and is_integer_dtype(rdtype)) or
(is_timedelta64_dtype(rdtype) and is_integer_dtype(ldtype))):
Expand Down Expand Up @@ -1767,6 +1769,23 @@ def _combine_series_frame(self, other, func, fill_value=None, axis=None,
return self._combine_match_columns(other, func, level=level)


def _treat_as_column(left, right, axis=None):
    """
    Decide whether a 1-dimensional `right` operand should be lined up with
    the rows of `left` (i.e. behave like a column) instead of with its
    columns.

    The answer is True only when the caller gave no explicit `axis`,
    `right` is exactly 1-dimensional, and the column interpretation is the
    sole one that fits:

    * for a Series, its index must equal ``left.index`` while
      ``left.index`` itself differs from ``left.columns``;
    * for any other 1-D object, its length must equal the number of rows
      of `left` and the number of rows must differ from the number of
      columns.

    Parameters
    ----------
    left : object exposing ``index``, ``columns`` and ``len()``
    right : object
        The other operand; anything ``np.ndim`` accepts.
    axis : optional
        Axis requested by the caller; any non-None value disables the
        column treatment.

    Returns
    -------
    bool
    """
    # An explicit axis always wins, and only 1-D operands are candidates.
    if axis is not None or np.ndim(right) != 1:
        return False

    if not isinstance(right, ABCSeries):
        # Length-based check: must fit the rows and must not also fit the
        # columns, otherwise the row-like reading is still possible.
        n_rows = len(left)
        return len(right) == n_rows and n_rows != len(left.columns)

    # Series are matched on labels; require an exact index match that
    # cannot be confused with a match against the columns.
    fits_index = right.index.equals(left.index)
    return fits_index and not left.index.equals(left.columns)


def _align_method_FRAME(left, right, axis):
""" convert rhs to meet lhs dims if input is list, tuple or np.ndarray """

Expand Down Expand Up @@ -1850,7 +1869,12 @@ def na_op(x, y):
doc = _arith_doc_FRAME % op_name

@Appender(doc)
def f(self, other, axis=default_axis, level=None, fill_value=None):
def f(self, other, axis=None, level=None, fill_value=None):

if _treat_as_column(self, other, axis):
axis = "index"
else:
axis = axis if axis is not None else default_axis

other = _align_method_FRAME(self, other, axis)

Expand All @@ -1861,7 +1885,7 @@ def f(self, other, axis=default_axis, level=None, fill_value=None):
elif isinstance(other, ABCSeries):
# For these values of `axis`, we end up dispatching to Series op,
# so do not want the masked op.
pass_op = op if axis in [0, "columns", None] else na_op
pass_op = op if axis in [0, "columns", "index", None] else na_op
return _combine_series_frame(self, other, pass_op,
fill_value=fill_value, axis=axis,
level=level)
Expand Down Expand Up @@ -1923,7 +1947,11 @@ def _comp_method_FRAME(cls, func, special):
@Appender('Wrapper for comparison method {name}'.format(name=op_name))
def f(self, other):

other = _align_method_FRAME(self, other, axis=None)
axis = None
if _treat_as_column(self, other, axis=axis):
axis = "index"

other = _align_method_FRAME(self, other, axis=axis)

if isinstance(other, ABCDataFrame):
# Another DataFrame
Expand All @@ -1934,7 +1962,7 @@ def f(self, other):

elif isinstance(other, ABCSeries):
return _combine_series_frame(self, other, func,
fill_value=None, axis=None,
fill_value=None, axis=axis,
level=None)
else:

Expand Down
16 changes: 0 additions & 16 deletions pandas/tests/arithmetic/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,19 +153,3 @@ def box_df_fail(request):
Fixture equivalent to `box` fixture but xfailing the DataFrame case.
"""
return request.param


@pytest.fixture(
    params=[
        pd.Index,
        pd.Series,
        pytest.param(
            pd.DataFrame,
            marks=pytest.mark.xfail(
                reason="Tries to broadcast incorrectly",
                strict=True,
                raises=ValueError,
            ),
        ),
    ],
    ids=lambda box_cls: box_cls.__name__,
)
def box_df_broadcast_failure(request):
    """
    Parametrized "box" fixture yielding pd.Index, pd.Series and
    pd.DataFrame in turn.

    The DataFrame parameter carries a strict xfail mark that expects a
    ValueError, covering the common case where the DataFrame operation
    tries to broadcast incorrectly.
    """
    return request.param
4 changes: 1 addition & 3 deletions pandas/tests/arithmetic/test_datetime64.py
Original file line number Diff line number Diff line change
Expand Up @@ -1365,10 +1365,8 @@ def test_sub_period(self, freq, box):
operator.sub, ops.rsub])
@pytest.mark.parametrize('pi_freq', ['D', 'W', 'Q', 'H'])
@pytest.mark.parametrize('dti_freq', [None, 'D'])
def test_dti_sub_pi(self, dti_freq, pi_freq, op, box_df_broadcast_failure):
def test_dti_sub_pi(self, dti_freq, pi_freq, op, box):
# GH#20049 subtracting PeriodIndex should raise TypeError
box = box_df_broadcast_failure

dti = pd.DatetimeIndex(['2011-01-01', '2011-01-02'], freq=dti_freq)
pi = dti.to_period(pi_freq)

Expand Down
8 changes: 2 additions & 6 deletions pandas/tests/arithmetic/test_period.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,9 +344,7 @@ class TestPeriodIndexArithmetic(object):
# PeriodIndex - other is defined for integers, timedelta-like others,
# and PeriodIndex (with matching freq)

def test_parr_add_iadd_parr_raises(self, box_df_broadcast_failure):
box = box_df_broadcast_failure

def test_parr_add_iadd_parr_raises(self, box):
rng = pd.period_range('1/1/2000', freq='D', periods=5)
other = pd.period_range('1/6/2000', freq='D', periods=5)
# TODO: parametrize over boxes for other?
Expand Down Expand Up @@ -388,9 +386,7 @@ def test_pi_sub_pi_with_nat(self):
expected = pd.Index([pd.NaT, 0 * off, 0 * off, 0 * off, 0 * off])
tm.assert_index_equal(result, expected)

def test_parr_sub_pi_mismatched_freq(self, box_df_broadcast_failure):
box = box_df_broadcast_failure

def test_parr_sub_pi_mismatched_freq(self, box):
rng = pd.period_range('1/1/2000', freq='D', periods=5)
other = pd.period_range('1/6/2000', freq='H', periods=5)
# TODO: parametrize over boxes for other?
Expand Down
Loading