Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions Lib/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -1128,21 +1128,33 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
if order:
# Create and set the ordering methods.
flds = [f for f in field_list if f.compare]
self_tuple = _tuple_str('self', flds)
other_tuple = _tuple_str('other', flds)
for name, op in [('__lt__', '<'),
('__le__', '<='),
('__gt__', '>'),
('__ge__', '>='),
]:
# Create a comparison function. If the fields in the object are
# named 'x' and 'y', then self_tuple is the string
# '(self.x,self.y)' and other_tuple is the string
# '(other.x,other.y)'.
# named 'x' and 'y'.
# if self.x != other.x:
# return self.x {op} other.x
# if self.y != other.y:
# return self.y {op} other.y
# return {op.endswith("=")}
return_when_equal = f' return {op.endswith("=")}'
func_builder.add_fn(name,
('self', 'other'),
[ ' if other.__class__ is self.__class__:',
f' return {self_tuple}{op}{other_tuple}',
[ ' if self is other:',
# __eq__ has this self guard, add here for consistency
return_when_equal,
' if other.__class__ is self.__class__:',
*(
f' if self.{f.name} != other.{f.name}:\n'
# ? use "op[0]" here since gated by "!=", probably not worth confusion
f' return self.{f.name} {op} other.{f.name}'
for f in flds
),
# the instances are equal here, return constant
return_when_equal,
' return NotImplemented'],
overwrite_error='Consider using functools.total_ordering')

Expand Down
85 changes: 85 additions & 0 deletions Lib/test/test_dataclasses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,16 @@ class C1:
@dataclass(order=True)
class C:
pass

# Check "self" comparisons.
ref = C()
self.assertEqual(ref, ref)
self.assertLessEqual(ref, ref)
self.assertGreaterEqual(ref, ref)
self.assertFalse(ref != ref)
self.assertFalse(ref < ref)
self.assertFalse(ref > ref)

self.assertLessEqual(C(), C())
self.assertGreaterEqual(C(), C())

Expand Down Expand Up @@ -399,13 +409,58 @@ class C1:
@dataclass(order=True)
class C:
x: int

# Check "self" comparisons.
ref = C(0)
self.assertEqual(ref, ref)
self.assertLessEqual(ref, ref)
self.assertGreaterEqual(ref, ref)
self.assertFalse(ref != ref)
self.assertFalse(ref < ref)
self.assertFalse(ref > ref)

self.assertLess(C(0), C(1))
self.assertLessEqual(C(0), C(1))
self.assertLessEqual(C(1), C(1))
self.assertGreater(C(1), C(0))
self.assertGreaterEqual(C(1), C(0))
self.assertGreaterEqual(C(1), C(1))

@dataclass(order=True)
class CFloat:
x: float

nan = float("nan")

# Check "self" comparisons.
ref = CFloat(nan)
self.assertEqual(ref, ref)
self.assertLessEqual(ref, ref)
self.assertGreaterEqual(ref, ref)
self.assertFalse(ref != ref)
self.assertFalse(ref < ref)
self.assertFalse(ref > ref)

self.assertNotEqual(CFloat(0.0), CFloat(nan))
self.assertNotEqual(CFloat(nan), CFloat(0.0))
self.assertNotEqual(CFloat(nan), CFloat(nan))

for idx, fn in enumerate([lambda a, b: a < b,
lambda a, b: a <= b,
lambda a, b: a == b]):
with self.subTest(idx=idx):
self.assertFalse(fn(CFloat(0.0), CFloat(nan)))
self.assertFalse(fn(CFloat(nan), CFloat(0.0)))
self.assertFalse(fn(CFloat(nan), CFloat(nan)))

for idx, fn in enumerate([lambda a, b: a > b,
lambda a, b: a >= b,
lambda a, b: a == b]):
with self.subTest(idx=idx):
self.assertFalse(fn(CFloat(0.0), CFloat(nan)))
self.assertFalse(fn(CFloat(nan), CFloat(0.0)))
self.assertFalse(fn(CFloat(nan), CFloat(nan)))

def test_simple_compare(self):
# Ensure that order=False is the default.
@dataclass
Expand Down Expand Up @@ -460,6 +515,36 @@ class C:
self.assertTrue(fn(C(1, 0), C(0, 1)))
self.assertTrue(fn(C(1, 1), C(1, 0)))

@dataclass(order=True)
class CFloat:
x: float
y: float

nan = float("nan")

self.assertNotEqual(CFloat(0.0, nan), CFloat(nan, 0.0))
self.assertNotEqual(CFloat(0.0, 0.0), CFloat(nan, nan))
self.assertNotEqual(CFloat(0.0, nan), CFloat(nan, nan))
self.assertNotEqual(CFloat(nan, nan), CFloat(nan, nan))

for idx, fn in enumerate([lambda a, b: a < b,
lambda a, b: a <= b,
lambda a, b: a == b]):
with self.subTest(idx=idx):
self.assertFalse(fn(CFloat(0.0, nan), CFloat(nan, 0.0)))
self.assertFalse(fn(CFloat(0.0, 0.0), CFloat(nan, nan)))
self.assertFalse(fn(CFloat(0.0, nan), CFloat(nan, nan)))
self.assertFalse(fn(CFloat(nan, nan), CFloat(nan, nan)))

for idx, fn in enumerate([lambda a, b: a > b,
lambda a, b: a >= b,
lambda a, b: a == b]):
with self.subTest(idx=idx):
self.assertFalse(fn(CFloat(0.0, nan), CFloat(nan, 0.0)))
self.assertFalse(fn(CFloat(0.0, 0.0), CFloat(nan, nan)))
self.assertFalse(fn(CFloat(0.0, nan), CFloat(nan, nan)))
self.assertFalse(fn(CFloat(nan, nan), CFloat(nan, nan)))

def test_compare_subclasses(self):
# Comparisons fail for subclasses, even if no fields
# are added.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fix dataclass order method behaviors to align with the ``__eq__`` semantics
change introduced in 3.13.
Loading