Skip to content

Commit db53e69

Browse files
authored
compare indiv table rows and test (#582)
1 parent ea37280 commit db53e69

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

python/tests/test_tables.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -580,16 +580,21 @@ def test_equality(self):
580580
t2.set_columns(**input_data_copy)
581581
self.assertEqual(t1, t2)
582582
self.assertFalse(t1 != t2)
583+
self.assertEqual(t1[0], t2[0])
583584
col_copy += 1
584585
t2.set_columns(**input_data_copy)
585586
self.assertNotEqual(t1, t2)
586587
self.assertNotEqual(t2, t1)
588+
self.assertNotEqual(t1[0], t2[0])
589+
self.assertTrue(t1[0] != t2[0])
590+
self.assertTrue(t1[0] != [])
587591
for list_col, offset_col in self.ragged_list_columns:
588592
value = list_col.get_input(num_rows)
589593
input_data_copy = dict(input_data)
590594
input_data_copy[list_col.name] = value + 1
591595
t2.set_columns(**input_data_copy)
592596
self.assertNotEqual(t1, t2)
597+
self.assertNotEqual(t1[0], t2[0])
593598
value = list_col.get_input(num_rows + 1)
594599
input_data_copy = dict(input_data)
595600
input_data_copy[list_col.name] = value
@@ -600,6 +605,7 @@ def test_equality(self):
600605
t2.set_columns(**input_data_copy)
601606
self.assertNotEqual(t1, t2)
602607
self.assertNotEqual(t2, t1)
608+
self.assertNotEqual(t1[-1], t2[-1])
603609
# Different types should always be unequal.
604610
self.assertNotEqual(t1, None)
605611
self.assertNotEqual(t1, [])

python/tskit/tables.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,28 @@
3939
attr_options = {"slots": True, "frozen": True, "auto_attribs": True}
4040

4141

42-
@attr.s(**attr_options)
42+
# note: soon attrs will deprecate cmp; we just need to change this argument to eq
43+
@attr.s(cmp=False, **attr_options)
4344
class IndividualTableRow:
4445
flags: int
4546
location: np.ndarray
4647
metadata: bytes
4748

49+
def __eq__(self, other):
50+
if not isinstance(other, type(self)):
51+
return False
52+
else:
53+
return all(
54+
(
55+
self.flags == other.flags,
56+
np.array_equal(self.location, other.location),
57+
self.metadata == other.metadata,
58+
)
59+
)
60+
61+
def __neq__(self, other):
62+
return not self.__eq__(other)
63+
4864

4965
@attr.s(**attr_options)
5066
class NodeTableRow:

0 commit comments

Comments
 (0)