Skip to content

Commit 3c6a26e

Browse files
authored
TST: parameterize more tests (#45131)
1 parent 9512393 commit 3c6a26e

File tree

4 files changed

+170
-185
lines changed

4 files changed

+170
-185
lines changed

pandas/tests/arithmetic/test_datetime64.py

+41-45
Original file line numberDiff line numberDiff line change
@@ -1810,55 +1810,51 @@ def test_dt64ser_sub_datetime_dtype(self):
18101810
# TODO: This next block of tests came from tests.series.test_operators,
18111811
# needs to be de-duplicated and parametrized over `box` classes
18121812

1813-
def test_operators_datetimelike_invalid(self, all_arithmetic_operators):
1814-
# these are all TypeEror ops
1813+
@pytest.mark.parametrize(
1814+
"left, right, op_fail",
1815+
[
1816+
[
1817+
[Timestamp("20111230"), Timestamp("20120101"), NaT],
1818+
[Timestamp("20111231"), Timestamp("20120102"), Timestamp("20120104")],
1819+
["__sub__", "__rsub__"],
1820+
],
1821+
[
1822+
[Timestamp("20111230"), Timestamp("20120101"), NaT],
1823+
[timedelta(minutes=5, seconds=3), timedelta(minutes=5, seconds=3), NaT],
1824+
["__add__", "__radd__", "__sub__"],
1825+
],
1826+
[
1827+
[
1828+
Timestamp("20111230", tz="US/Eastern"),
1829+
Timestamp("20111230", tz="US/Eastern"),
1830+
NaT,
1831+
],
1832+
[timedelta(minutes=5, seconds=3), NaT, timedelta(minutes=5, seconds=3)],
1833+
["__add__", "__radd__", "__sub__"],
1834+
],
1835+
],
1836+
)
1837+
def test_operators_datetimelike_invalid(
1838+
self, left, right, op_fail, all_arithmetic_operators
1839+
):
1840+
# these are all TypeError ops
18151841
op_str = all_arithmetic_operators
1816-
1817-
def check(get_ser, test_ser):
1818-
1819-
# check that we are getting a TypeError
1820-
# with 'operate' (from core/ops.py) for the ops that are not
1821-
# defined
1822-
op = getattr(get_ser, op_str, None)
1823-
# Previously, _validate_for_numeric_binop in core/indexes/base.py
1824-
# did this for us.
1842+
arg1 = Series(left)
1843+
arg2 = Series(right)
1844+
# check that we are getting a TypeError
1845+
# with 'operate' (from core/ops.py) for the ops that are not
1846+
# defined
1847+
op = getattr(arg1, op_str, None)
1848+
# Previously, _validate_for_numeric_binop in core/indexes/base.py
1849+
# did this for us.
1850+
if op_str not in op_fail:
18251851
with pytest.raises(
18261852
TypeError, match="operate|[cC]annot|unsupported operand"
18271853
):
1828-
op(test_ser)
1829-
1830-
# ## timedelta64 ###
1831-
td1 = Series([timedelta(minutes=5, seconds=3)] * 3)
1832-
td1.iloc[2] = np.nan
1833-
1834-
# ## datetime64 ###
1835-
dt1 = Series(
1836-
[Timestamp("20111230"), Timestamp("20120101"), Timestamp("20120103")]
1837-
)
1838-
dt1.iloc[2] = np.nan
1839-
dt2 = Series(
1840-
[Timestamp("20111231"), Timestamp("20120102"), Timestamp("20120104")]
1841-
)
1842-
if op_str not in ["__sub__", "__rsub__"]:
1843-
check(dt1, dt2)
1844-
1845-
# ## datetime64 with timetimedelta ###
1846-
# TODO(jreback) __rsub__ should raise?
1847-
if op_str not in ["__add__", "__radd__", "__sub__"]:
1848-
check(dt1, td1)
1849-
1850-
# 8260, 10763
1851-
# datetime64 with tz
1852-
tz = "US/Eastern"
1853-
dt1 = Series(date_range("2000-01-01 09:00:00", periods=5, tz=tz), name="foo")
1854-
dt2 = dt1.copy()
1855-
dt2.iloc[2] = np.nan
1856-
td1 = Series(pd.timedelta_range("1 days 1 min", periods=5, freq="H"))
1857-
td2 = td1.copy()
1858-
td2.iloc[1] = np.nan
1859-
1860-
if op_str not in ["__add__", "__radd__", "__sub__", "__rsub__"]:
1861-
check(dt2, td2)
1854+
op(arg2)
1855+
else:
1856+
# Smoke test
1857+
op(arg2)
18621858

18631859
def test_sub_single_tz(self):
18641860
# GH#12290

pandas/tests/arithmetic/test_numeric.py

+55-67
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,17 @@ def adjust_negative_zero(zero, expected):
5858
return expected
5959

6060

61+
def compare_op(series, other, op):
62+
left = np.abs(series) if op in (ops.rpow, operator.pow) else series
63+
right = np.abs(other) if op in (ops.rpow, operator.pow) else other
64+
65+
cython_or_numpy = op(left, right)
66+
python = left.combine(right, op)
67+
if isinstance(other, Series) and not other.index.equals(series.index):
68+
python.index = python.index._with_freq(None)
69+
tm.assert_series_equal(cython_or_numpy, python)
70+
71+
6172
# TODO: remove this kludge once mypy stops giving false positives here
6273
# List comprehension has incompatible type List[PandasObject]; expected List[RangeIndex]
6374
# See GH#29725
@@ -959,77 +970,54 @@ def test_frame_operators(self, float_frame):
959970
assert (df + df).equals(df)
960971
tm.assert_frame_equal(df + df, df)
961972

962-
# TODO: taken from tests.series.test_operators; needs cleanup
963-
def test_series_operators(self):
964-
def _check_op(series, other, op, pos_only=False):
965-
left = np.abs(series) if pos_only else series
966-
right = np.abs(other) if pos_only else other
967-
968-
cython_or_numpy = op(left, right)
969-
python = left.combine(right, op)
970-
if isinstance(other, Series) and not other.index.equals(series.index):
971-
python.index = python.index._with_freq(None)
972-
tm.assert_series_equal(cython_or_numpy, python)
973-
974-
def check(series, other):
975-
simple_ops = ["add", "sub", "mul", "truediv", "floordiv", "mod"]
976-
977-
for opname in simple_ops:
978-
_check_op(series, other, getattr(operator, opname))
973+
@pytest.mark.parametrize(
974+
"func",
975+
[lambda x: x * 2, lambda x: x[::2], lambda x: 5],
976+
ids=["multiply", "slice", "constant"],
977+
)
978+
def test_series_operators_arithmetic(self, all_arithmetic_functions, func):
979+
op = all_arithmetic_functions
980+
series = tm.makeTimeSeries().rename("ts")
981+
other = func(series)
982+
compare_op(series, other, op)
979983

980-
_check_op(series, other, operator.pow, pos_only=True)
984+
@pytest.mark.parametrize(
985+
"func", [lambda x: x + 1, lambda x: 5], ids=["add", "constant"]
986+
)
987+
def test_series_operators_compare(self, comparison_op, func):
988+
op = comparison_op
989+
series = tm.makeTimeSeries().rename("ts")
990+
other = func(series)
991+
compare_op(series, other, op)
981992

982-
_check_op(series, other, ops.radd)
983-
_check_op(series, other, ops.rsub)
984-
_check_op(series, other, ops.rtruediv)
985-
_check_op(series, other, ops.rfloordiv)
986-
_check_op(series, other, ops.rmul)
987-
_check_op(series, other, ops.rpow, pos_only=True)
988-
_check_op(series, other, ops.rmod)
993+
@pytest.mark.parametrize(
994+
"func",
995+
[lambda x: x * 2, lambda x: x[::2], lambda x: 5],
996+
ids=["multiply", "slice", "constant"],
997+
)
998+
def test_divmod(self, func):
999+
series = tm.makeTimeSeries().rename("ts")
1000+
other = func(series)
1001+
results = divmod(series, other)
1002+
if isinstance(other, abc.Iterable) and len(series) != len(other):
1003+
# if the lengths don't match, this is the test where we use
1004+
# `tser[::2]`. Pad every other value in `other_np` with nan.
1005+
other_np = []
1006+
for n in other:
1007+
other_np.append(n)
1008+
other_np.append(np.nan)
1009+
else:
1010+
other_np = other
1011+
other_np = np.asarray(other_np)
1012+
with np.errstate(all="ignore"):
1013+
expecteds = divmod(series.values, np.asarray(other_np))
9891014

990-
tser = tm.makeTimeSeries().rename("ts")
991-
check(tser, tser * 2)
992-
check(tser, tser[::2])
993-
check(tser, 5)
994-
995-
def check_comparators(series, other):
996-
_check_op(series, other, operator.gt)
997-
_check_op(series, other, operator.ge)
998-
_check_op(series, other, operator.eq)
999-
_check_op(series, other, operator.lt)
1000-
_check_op(series, other, operator.le)
1001-
1002-
check_comparators(tser, 5)
1003-
check_comparators(tser, tser + 1)
1004-
1005-
# TODO: taken from tests.series.test_operators; needs cleanup
1006-
def test_divmod(self):
1007-
def check(series, other):
1008-
results = divmod(series, other)
1009-
if isinstance(other, abc.Iterable) and len(series) != len(other):
1010-
# if the lengths don't match, this is the test where we use
1011-
# `tser[::2]`. Pad every other value in `other_np` with nan.
1012-
other_np = []
1013-
for n in other:
1014-
other_np.append(n)
1015-
other_np.append(np.nan)
1016-
else:
1017-
other_np = other
1018-
other_np = np.asarray(other_np)
1019-
with np.errstate(all="ignore"):
1020-
expecteds = divmod(series.values, np.asarray(other_np))
1021-
1022-
for result, expected in zip(results, expecteds):
1023-
# check the values, name, and index separately
1024-
tm.assert_almost_equal(np.asarray(result), expected)
1025-
1026-
assert result.name == series.name
1027-
tm.assert_index_equal(result.index, series.index._with_freq(None))
1015+
for result, expected in zip(results, expecteds):
1016+
# check the values, name, and index separately
1017+
tm.assert_almost_equal(np.asarray(result), expected)
10281018

1029-
tser = tm.makeTimeSeries().rename("ts")
1030-
check(tser, tser * 2)
1031-
check(tser, tser[::2])
1032-
check(tser, 5)
1019+
assert result.name == series.name
1020+
tm.assert_index_equal(result.index, series.index._with_freq(None))
10331021

10341022
def test_series_divmod_zero(self):
10351023
# Check that divmod uses pandas convention for division by zero,

pandas/tests/indexing/test_scalar.py

+29-21
Original file line numberDiff line numberDiff line change
@@ -47,31 +47,39 @@ def _check(f, func, values=False):
4747
_check(f, "at")
4848

4949
@pytest.mark.parametrize("kind", ["series", "frame"])
50-
def test_at_and_iat_set(self, kind):
51-
def _check(f, func, values=False):
50+
@pytest.mark.parametrize("col", ["ints", "uints"])
51+
def test_iat_set_ints(self, kind, col):
52+
f = getattr(self, kind)[col]
53+
if f is not None:
54+
indices = self.generate_indices(f, True)
55+
for i in indices:
56+
f.iat[i] = 1
57+
expected = self.get_value("iat", f, i, True)
58+
tm.assert_almost_equal(expected, 1)
5259

53-
if f is not None:
54-
indices = self.generate_indices(f, values)
60+
@pytest.mark.parametrize("kind", ["series", "frame"])
61+
@pytest.mark.parametrize("col", ["labels", "ts", "floats"])
62+
def test_iat_set_other(self, kind, col):
63+
f = getattr(self, kind)[col]
64+
if f is not None:
65+
msg = "iAt based indexing can only have integer indexers"
66+
with pytest.raises(ValueError, match=msg):
67+
indices = self.generate_indices(f, False)
5568
for i in indices:
56-
getattr(f, func)[i] = 1
57-
expected = self.get_value(func, f, i, values)
69+
f.iat[i] = 1
70+
expected = self.get_value("iat", f, i, False)
5871
tm.assert_almost_equal(expected, 1)
5972

60-
d = getattr(self, kind)
61-
62-
# iat
63-
for f in [d["ints"], d["uints"]]:
64-
_check(f, "iat", values=True)
65-
66-
for f in [d["labels"], d["ts"], d["floats"]]:
67-
if f is not None:
68-
msg = "iAt based indexing can only have integer indexers"
69-
with pytest.raises(ValueError, match=msg):
70-
_check(f, "iat")
71-
72-
# at
73-
for f in [d["ints"], d["uints"], d["labels"], d["ts"], d["floats"]]:
74-
_check(f, "at")
73+
@pytest.mark.parametrize("kind", ["series", "frame"])
74+
@pytest.mark.parametrize("col", ["ints", "uints", "labels", "ts", "floats"])
75+
def test_at_set_ints_other(self, kind, col):
76+
f = getattr(self, kind)[col]
77+
if f is not None:
78+
indices = self.generate_indices(f, False)
79+
for i in indices:
80+
f.at[i] = 1
81+
expected = self.get_value("at", f, i, False)
82+
tm.assert_almost_equal(expected, 1)
7583

7684

7785
class TestAtAndiAT:

0 commit comments

Comments
 (0)