diff --git a/python/tests/test_tables.py b/python/tests/test_tables.py index e4a3397f11..c88cfb425c 100644 --- a/python/tests/test_tables.py +++ b/python/tests/test_tables.py @@ -580,16 +580,21 @@ def test_equality(self): t2.set_columns(**input_data_copy) self.assertEqual(t1, t2) self.assertFalse(t1 != t2) + self.assertEqual(t1[0], t2[0]) col_copy += 1 t2.set_columns(**input_data_copy) self.assertNotEqual(t1, t2) self.assertNotEqual(t2, t1) + self.assertNotEqual(t1[0], t2[0]) + self.assertTrue(t1[0] != t2[0]) + self.assertTrue(t1[0] != []) for list_col, offset_col in self.ragged_list_columns: value = list_col.get_input(num_rows) input_data_copy = dict(input_data) input_data_copy[list_col.name] = value + 1 t2.set_columns(**input_data_copy) self.assertNotEqual(t1, t2) + self.assertNotEqual(t1[0], t2[0]) value = list_col.get_input(num_rows + 1) input_data_copy = dict(input_data) input_data_copy[list_col.name] = value @@ -600,6 +605,7 @@ def test_equality(self): t2.set_columns(**input_data_copy) self.assertNotEqual(t1, t2) self.assertNotEqual(t2, t1) + self.assertNotEqual(t1[-1], t2[-1]) # Different types should always be unequal. self.assertNotEqual(t1, None) self.assertNotEqual(t1, []) diff --git a/python/tskit/tables.py b/python/tskit/tables.py index 3268684c1b..3ab62f994c 100644 --- a/python/tskit/tables.py +++ b/python/tskit/tables.py @@ -39,12 +39,28 @@ attr_options = {"slots": True, "frozen": True, "auto_attribs": True} -@attr.s(**attr_options) +# note: soon attrs will deprecate cmp; we just need to change this argument to eq +@attr.s(cmp=False, **attr_options) class IndividualTableRow: flags: int location: np.ndarray metadata: bytes + def __eq__(self, other): + if not isinstance(other, type(self)): + return False + else: + return all( + ( + self.flags == other.flags, + np.array_equal(self.location, other.location), + self.metadata == other.metadata, + ) + ) + + def __neq__(self, other): + return not self.__eq__(other) + @attr.s(**attr_options) class NodeTableRow: