Skip to content

Commit b9f541f

Browse files
benjefferymergify-bot
authored andcommitted
Add assert_equal methods
1 parent 4f53216 commit b9f541f

File tree

3 files changed

+457
-9
lines changed

3 files changed

+457
-9
lines changed

python/CHANGELOG.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@
3131
- Improve display of tables when ``print``ed, limiting lines set via
3232
``tskit.set_print_options`` (:user:`benjeffery`,:issue:`1270`, :pr:`1300`).
3333

34+
- Add ``Table.assert_equals`` and ``TableCollection.assert_equals`` which give an exact
35+
report of any differences. (:user:`benjeffery`,:issue:`1076`, :pr:`1328`)
36+
3437
**Fixes**
3538

3639
- Tree sequences were not properly init'd after unpickling

python/tests/test_tables.py

Lines changed: 281 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import pickle
3333
import platform
3434
import random
35+
import re
3536
import struct
3637
import time
3738
import unittest
@@ -73,7 +74,7 @@ def get_input(self, n):
7374
class CharColumn(Column):
7475
def get_input(self, n):
7576
rng = np.random.RandomState(42)
76-
return rng.randint(low=0, high=127, size=n, dtype=np.int8)
77+
return rng.randint(low=65, high=122, size=n, dtype=np.int8)
7778

7879

7980
class DoubleColumn(Column):
@@ -1009,7 +1010,162 @@ def test_set_with_optional_properties(self, codec):
10091010
assert md == row.metadata
10101011

10111012

1012-
class TestIndividualTable(CommonTestsMixin, MetadataTestsMixin):
1013+
class AssertEqualsMixin:
1014+
@pytest.fixture
1015+
def test_rows(self, scope="session"):
1016+
test_rows = self.make_transposed_input_data(10)
1017+
# Annoyingly we have to tweak some types as once added to a row and then put in
1018+
# an error message things come out differently
1019+
for n in range(10):
1020+
for col in test_rows[n].keys():
1021+
if col in ["timestamp", "record", "ancestral_state", "derived_state"]:
1022+
test_rows[n][col] = bytes(test_rows[n][col]).decode("ascii")
1023+
return test_rows
1024+
1025+
@pytest.fixture
1026+
def table1(self, test_rows):
1027+
table1 = self.table_class()
1028+
for row in test_rows[:5]:
1029+
table1.add_row(**row)
1030+
return table1
1031+
1032+
def test_equal(self, table1, test_rows):
1033+
table2 = self.table_class()
1034+
for row in test_rows[:5]:
1035+
table2.add_row(**row)
1036+
table1.assert_equals(table2)
1037+
1038+
def test_type(self, table1):
1039+
with pytest.raises(
1040+
AssertionError,
1041+
match=f"Types differ: self={type(table1)} other=<class 'int'>",
1042+
):
1043+
table1.assert_equals(42)
1044+
1045+
def test_metadata_schema(self, table1):
1046+
if hasattr(table1, "metadata_schema"):
1047+
table2 = table1.copy()
1048+
table2.metadata_schema = tskit.MetadataSchema({"codec": "json"})
1049+
with pytest.raises(
1050+
AssertionError,
1051+
match=f"{type(table1).__name__} metadata schemas differ: self=None "
1052+
f"other=OrderedDict([('codec', "
1053+
"'json')])",
1054+
):
1055+
table1.assert_equals(table2)
1056+
table1.assert_equals(table2, ignore_metadata=True)
1057+
1058+
def test_row_changes(self, table1, test_rows):
1059+
for column_name in test_rows[0].keys():
1060+
table2 = self.table_class()
1061+
for row in test_rows[:4]:
1062+
table2.add_row(**row)
1063+
modified_row = {
1064+
**test_rows[4],
1065+
**{column_name: test_rows[5][column_name]},
1066+
}
1067+
table2.add_row(**modified_row)
1068+
with pytest.raises(
1069+
AssertionError,
1070+
match=re.escape(
1071+
f"{type(table1).__name__} row 4 differs:\n"
1072+
f"self.{column_name}={test_rows[4][column_name]} "
1073+
f"other.{column_name}={test_rows[5][column_name]}"
1074+
),
1075+
):
1076+
table1.assert_equals(table2)
1077+
if column_name == "metadata":
1078+
table1.assert_equals(table2, ignore_metadata=True)
1079+
if column_name == "timestamp":
1080+
table1.assert_equals(table2, ignore_timestamps=True)
1081+
1082+
# Two columns differ, as we don't know the order in the error message
1083+
# test for both independently
1084+
for column_name, column_name2 in zip(
1085+
list(test_rows[0].keys())[:-1], list(test_rows[0].keys())[1:]
1086+
):
1087+
table2 = self.table_class()
1088+
for row in test_rows[:4]:
1089+
table2.add_row(**row)
1090+
modified_row = {
1091+
**test_rows[4],
1092+
**{
1093+
column_name: test_rows[5][column_name],
1094+
column_name2: test_rows[5][column_name2],
1095+
},
1096+
}
1097+
table2.add_row(**modified_row)
1098+
with pytest.raises(
1099+
AssertionError,
1100+
match=re.escape(
1101+
f"self.{column_name}={test_rows[4][column_name]} "
1102+
f"other.{column_name}={test_rows[5][column_name]}"
1103+
),
1104+
):
1105+
table1.assert_equals(table2)
1106+
with pytest.raises(
1107+
AssertionError,
1108+
match=re.escape(
1109+
f"self.{column_name2}={test_rows[4][column_name2]} "
1110+
f"other.{column_name2}={test_rows[5][column_name2]}"
1111+
),
1112+
):
1113+
table1.assert_equals(table2)
1114+
1115+
def test_num_rows(self, table1, test_rows):
1116+
table2 = self.table_class()
1117+
for row in test_rows[:4]:
1118+
table2.add_row(**row)
1119+
with pytest.raises(
1120+
AssertionError,
1121+
match=f"{type(table1).__name__} number of rows differ: self=5 other=4",
1122+
):
1123+
table1.assert_equals(table2)
1124+
1125+
def test_metadata(self, table1, test_rows):
1126+
if "metadata" in test_rows[0].keys():
1127+
table2 = self.table_class()
1128+
for row in test_rows[:4]:
1129+
table2.add_row(**row)
1130+
modified_row = {
1131+
**test_rows[4],
1132+
**{"metadata": test_rows[5]["metadata"]},
1133+
}
1134+
table2.add_row(**modified_row)
1135+
with pytest.raises(
1136+
AssertionError,
1137+
match=re.escape(
1138+
f"{type(table1).__name__} row 4 differs:\n"
1139+
f"self.metadata={test_rows[4]['metadata']} "
1140+
f"other.metadata={test_rows[5]['metadata']}"
1141+
),
1142+
):
1143+
table1.assert_equals(table2)
1144+
table1.assert_equals(table2, ignore_metadata=True)
1145+
1146+
def test_timestamp(self, table1, test_rows):
1147+
if "timestamp" in test_rows[0].keys():
1148+
table2 = self.table_class()
1149+
for row in test_rows[:4]:
1150+
table2.add_row(**row)
1151+
modified_row = {
1152+
**test_rows[4],
1153+
**{"timestamp": test_rows[5]["timestamp"]},
1154+
}
1155+
table2.add_row(**modified_row)
1156+
with pytest.raises(
1157+
AssertionError,
1158+
match=re.escape(
1159+
f"{type(table1).__name__} row 4 differs:\n"
1160+
f"self.timestamp={test_rows[4]['timestamp']} "
1161+
f"other.timestamp={test_rows[5]['timestamp']}"
1162+
),
1163+
):
1164+
table1.assert_equals(table2)
1165+
table1.assert_equals(table2, ignore_timestamps=True)
1166+
1167+
1168+
class TestIndividualTable(CommonTestsMixin, MetadataTestsMixin, AssertEqualsMixin):
10131169
columns = [UInt32Column("flags")]
10141170
ragged_list_columns = [
10151171
(DoubleColumn("location"), UInt32Column("location_offset")),
@@ -1138,7 +1294,7 @@ def test_various_not_equals(self):
11381294
assert a == b
11391295

11401296

1141-
class TestNodeTable(CommonTestsMixin, MetadataTestsMixin):
1297+
class TestNodeTable(CommonTestsMixin, MetadataTestsMixin, AssertEqualsMixin):
11421298

11431299
columns = [
11441300
UInt32Column("flags"),
@@ -1219,7 +1375,7 @@ def test_add_row_bad_data(self):
12191375
t.add_row(metadata=123)
12201376

12211377

1222-
class TestEdgeTable(CommonTestsMixin, MetadataTestsMixin):
1378+
class TestEdgeTable(CommonTestsMixin, MetadataTestsMixin, AssertEqualsMixin):
12231379

12241380
columns = [
12251381
DoubleColumn("left"),
@@ -1267,7 +1423,7 @@ def test_add_row_bad_data(self):
12671423
t.add_row(0, 0, 0, 0, metadata=123)
12681424

12691425

1270-
class TestSiteTable(CommonTestsMixin, MetadataTestsMixin):
1426+
class TestSiteTable(CommonTestsMixin, MetadataTestsMixin, AssertEqualsMixin):
12711427
columns = [DoubleColumn("position")]
12721428
ragged_list_columns = [
12731429
(CharColumn("ancestral_state"), UInt32Column("ancestral_state_offset")),
@@ -1322,7 +1478,7 @@ def test_packset_ancestral_state(self):
13221478
assert np.array_equal(table.ancestral_state_offset, ancestral_state_offset)
13231479

13241480

1325-
class TestMutationTable(CommonTestsMixin, MetadataTestsMixin):
1481+
class TestMutationTable(CommonTestsMixin, MetadataTestsMixin, AssertEqualsMixin):
13261482
columns = [
13271483
Int32Column("site"),
13281484
Int32Column("node"),
@@ -1394,7 +1550,7 @@ def test_packset_derived_state(self):
13941550
assert np.array_equal(table.derived_state_offset, derived_state_offset)
13951551

13961552

1397-
class TestMigrationTable(CommonTestsMixin, MetadataTestsMixin):
1553+
class TestMigrationTable(CommonTestsMixin, MetadataTestsMixin, AssertEqualsMixin):
13981554
columns = [
13991555
DoubleColumn("left"),
14001556
DoubleColumn("right"),
@@ -1445,7 +1601,7 @@ def test_add_row_bad_data(self):
14451601
t.add_row(0, 0, 0, 0, 0, 0, metadata=123)
14461602

14471603

1448-
class TestProvenanceTable(CommonTestsMixin):
1604+
class TestProvenanceTable(CommonTestsMixin, AssertEqualsMixin):
14491605
columns = []
14501606
ragged_list_columns = [
14511607
(CharColumn("timestamp"), UInt32Column("timestamp_offset")),
@@ -1496,7 +1652,7 @@ def test_packset_record(self):
14961652
assert t[1].record == "BBBB"
14971653

14981654

1499-
class TestPopulationTable(CommonTestsMixin, MetadataTestsMixin):
1655+
class TestPopulationTable(CommonTestsMixin, MetadataTestsMixin, AssertEqualsMixin):
15001656
metadata_mandatory = True
15011657
columns = []
15021658
ragged_list_columns = [(CharColumn("metadata"), UInt32Column("metadata_offset"))]
@@ -3307,6 +3463,122 @@ def test_equals_population_metadata(self, ts_fixture):
33073463
assert t1.equals(t2, ignore_metadata=True)
33083464

33093465

3466+
class TestTableCollectionAssertEquals:
3467+
@pytest.fixture
3468+
def t1(self, ts_fixture):
3469+
return ts_fixture.dump_tables()
3470+
3471+
@pytest.fixture
3472+
def t2(self, ts_fixture):
3473+
return ts_fixture.dump_tables()
3474+
3475+
def test_equal(self, t1, t2):
3476+
assert t1 is not t2
3477+
t1.assert_equals(t2)
3478+
3479+
def test_type(self, t1):
3480+
with pytest.raises(
3481+
AssertionError,
3482+
match=re.escape(
3483+
"Types differ: self=<class 'tskit.tables.TableCollection'> "
3484+
"other=<class 'int'>"
3485+
),
3486+
):
3487+
t1.assert_equals(42)
3488+
3489+
def test_sequence_length(self, t1, t2):
3490+
t2.sequence_length = 42
3491+
with pytest.raises(
3492+
AssertionError, match="Sequence Length differs: self=1.0 other=42.0"
3493+
):
3494+
t1.assert_equals(t2)
3495+
3496+
def test_metadata_schema(self, t1, t2):
3497+
t2.metadata_schema = tskit.MetadataSchema(None)
3498+
with pytest.raises(
3499+
AssertionError,
3500+
match=re.escape(
3501+
"Metadata schemas differ: self=OrderedDict([('codec', 'json')]) "
3502+
"other=None"
3503+
),
3504+
):
3505+
t1.assert_equals(t2)
3506+
t1.assert_equals(t2, ignore_metadata=True)
3507+
t1.assert_equals(t2, ignore_ts_metadata=True)
3508+
3509+
def test_metadata(self, t1, t2):
3510+
t2.metadata = {"foo": "bar"}
3511+
with pytest.raises(
3512+
AssertionError,
3513+
match=re.escape(
3514+
"Metadata differs: self=Test metadata other={'foo': 'bar'}"
3515+
),
3516+
):
3517+
t1.assert_equals(t2)
3518+
t1.assert_equals(t2, ignore_metadata=True)
3519+
t1.assert_equals(t2, ignore_ts_metadata=True)
3520+
3521+
@pytest.mark.parametrize("table_name", tskit.TableCollection(1).name_map)
3522+
def test_tables(self, t1, t2, table_name):
3523+
table = getattr(t2, table_name)
3524+
table.truncate(0)
3525+
with pytest.raises(
3526+
AssertionError,
3527+
match=f"{type(table).__name__} number of rows differ: "
3528+
f"self={len(getattr(t1, table_name))} other=0",
3529+
):
3530+
t1.assert_equals(t2)
3531+
3532+
@pytest.mark.parametrize("table_name", tskit.TableCollection(1).name_map)
3533+
def test_ignore_metadata(self, t1, t2, table_name):
3534+
table = getattr(t2, table_name)
3535+
if hasattr(table, "metadata_schema"):
3536+
table.metadata_schema = tskit.MetadataSchema(None)
3537+
with pytest.raises(
3538+
AssertionError,
3539+
match=re.escape(
3540+
f"{type(table).__name__} metadata schemas differ: "
3541+
f"self=OrderedDict([('codec', 'json')]) other=None"
3542+
),
3543+
):
3544+
t1.assert_equals(t2)
3545+
t1.assert_equals(t2, ignore_metadata=True)
3546+
3547+
def test_ignore_provenance(self, t1, t2):
3548+
t2.provenances.truncate(0)
3549+
with pytest.raises(
3550+
AssertionError,
3551+
match="ProvenanceTable number of rows differ: self=1 other=0",
3552+
):
3553+
t1.assert_equals(t2)
3554+
with pytest.raises(
3555+
AssertionError,
3556+
match="ProvenanceTable number of rows differ: self=1 other=0",
3557+
):
3558+
t1.assert_equals(t2, ignore_timestamps=True)
3559+
3560+
t1.assert_equals(t2, ignore_provenance=True)
3561+
3562+
def test_ignore_timestamps(self, t1, t2):
3563+
table = t2.provenances
3564+
timestamp = table.timestamp
3565+
timestamp[0] = ord("F")
3566+
table.set_columns(
3567+
timestamp=timestamp,
3568+
timestamp_offset=table.timestamp_offset,
3569+
record=table.record,
3570+
record_offset=table.record_offset,
3571+
)
3572+
with pytest.raises(
3573+
AssertionError,
3574+
match="ProvenanceTable row 0 differs:\n"
3575+
"self.timestamp=.* other.timestamp=F.*",
3576+
):
3577+
t1.assert_equals(t2)
3578+
t1.assert_equals(t2, ignore_provenance=True)
3579+
t1.assert_equals(t2, ignore_timestamps=True)
3580+
3581+
33103582
class TestTableCollectionMethodSignatures:
33113583
tc = msprime.simulate(10, random_seed=1234).dump_tables()
33123584

0 commit comments

Comments
 (0)