|
32 | 32 | import pickle
|
33 | 33 | import platform
|
34 | 34 | import random
|
| 35 | +import re |
35 | 36 | import struct
|
36 | 37 | import time
|
37 | 38 | import unittest
|
@@ -73,7 +74,7 @@ def get_input(self, n):
|
class CharColumn(Column):
    """Column whose test input is a reproducible array of ``int8`` codes."""

    def get_input(self, n):
        """
        Return ``n`` pseudo-random ``int8`` values.

        A ``RandomState`` seeded with 42 is created on every call, so the
        same ``n`` always yields the same array.  Values are drawn with
        ``low=65`` and ``high=122``; NumPy documents ``high`` as exclusive,
        so the largest value produced is 121.
        """
        # NOTE(review): the 65..121 range is presumably chosen so the values
        # decode to readable ASCII in error-message tests -- confirm.
        generator = np.random.RandomState(42)
        values = generator.randint(low=65, high=122, size=n, dtype=np.int8)
        return values
77 | 78 |
|
78 | 79 |
|
79 | 80 | class DoubleColumn(Column):
|
@@ -1009,7 +1010,162 @@ def test_set_with_optional_properties(self, codec):
|
1009 | 1010 | assert md == row.metadata
|
1010 | 1011 |
|
1011 | 1012 |
|
1012 |
| -class TestIndividualTable(CommonTestsMixin, MetadataTestsMixin): |
class AssertEqualsMixin:
    """
    Tests of ``assert_equals`` shared by the per-table test classes.

    The class this is mixed into must supply ``table_class`` and
    ``make_transposed_input_data``; neither is defined here.  Methods whose
    names start with an underscore are helpers and are not collected as
    tests.
    """

    @pytest.fixture
    def test_rows(self):
        """
        Return ten rows of input data, each a dict of ``add_row`` keywords.

        The fixture is function-scoped.  It previously declared
        ``scope="session"`` as a parameter default of the fixture function;
        pytest only reads the scope from the ``pytest.fixture`` decorator, so
        that parameter did nothing and has been removed.
        """
        text_columns = ("timestamp", "record", "ancestral_state", "derived_state")
        rows = self.make_transposed_input_data(10)
        # Annoyingly we have to tweak some types as once added to a row and
        # then put in an error message things come out differently.
        for row in rows:
            for column_name in text_columns:
                if column_name in row:
                    row[column_name] = bytes(row[column_name]).decode("ascii")
        return rows

    @pytest.fixture
    def table1(self, test_rows):
        """Return a table holding the first five rows of ``test_rows``."""
        return self._build_table(test_rows[:5])

    def _build_table(self, rows):
        """Return a new ``table_class`` instance with ``rows`` added in order."""
        table = self.table_class()
        for row in rows:
            table.add_row(**row)
        return table

    def _table_with_changed_last_row(self, test_rows, *column_names):
        """
        Return a five-row table that differs from ``table1`` only in row 4.

        Rows 0-3 are taken unchanged from ``test_rows``; row 4 is
        ``test_rows[4]`` with each column in ``column_names`` replaced by the
        value from ``test_rows[5]``.
        """
        changed_row = dict(test_rows[4])
        for column_name in column_names:
            changed_row[column_name] = test_rows[5][column_name]
        return self._build_table([*test_rows[:4], changed_row])

    @staticmethod
    def _column_diff_pattern(test_rows, column_name):
        """Return the escaped ``self.<col>=... other.<col>=...`` message part."""
        return re.escape(
            f"self.{column_name}={test_rows[4][column_name]} "
            f"other.{column_name}={test_rows[5][column_name]}"
        )

    def _assert_single_column_differs(self, table1, test_rows, column_name):
        """
        Check that changing one column of row 4 is reported for that column.

        Returns the modified table so callers can go on to test the
        ``ignore_*`` options against it.
        """
        table2 = self._table_with_changed_last_row(test_rows, column_name)
        header = re.escape(f"{type(table1).__name__} row 4 differs:\n")
        with pytest.raises(
            AssertionError,
            match=header + self._column_diff_pattern(test_rows, column_name),
        ):
            table1.assert_equals(table2)
        return table2

    def test_equal(self, table1, test_rows):
        """Two tables built from the same rows must compare equal."""
        table2 = self._build_table(test_rows[:5])
        table1.assert_equals(table2)

    def test_type(self, table1):
        """Comparing against a non-table must report the differing types."""
        # The type repr contains dots, so escape it to match it literally.
        with pytest.raises(
            AssertionError,
            match=re.escape(
                f"Types differ: self={type(table1)} other=<class 'int'>"
            ),
        ):
            table1.assert_equals(42)

    def test_metadata_schema(self, table1):
        """A differing metadata schema is reported unless metadata is ignored."""
        if not hasattr(table1, "metadata_schema"):
            return
        table2 = table1.copy()
        table2.metadata_schema = tskit.MetadataSchema({"codec": "json"})
        # Without re.escape the "([(...)])" text is parsed as a group and a
        # character class, so only the prefix of the message was checked.
        with pytest.raises(
            AssertionError,
            match=re.escape(
                f"{type(table1).__name__} metadata schemas differ: self=None "
                "other=OrderedDict([('codec', 'json')])"
            ),
        ):
            table1.assert_equals(table2)
        table1.assert_equals(table2, ignore_metadata=True)

    def test_row_changes(self, table1, test_rows):
        """Every column, changed alone or with its neighbour, is reported."""
        column_names = list(test_rows[0].keys())

        for column_name in column_names:
            table2 = self._assert_single_column_differs(
                table1, test_rows, column_name
            )
            if column_name == "metadata":
                table1.assert_equals(table2, ignore_metadata=True)
            if column_name == "timestamp":
                table1.assert_equals(table2, ignore_timestamps=True)

        # Two columns differ, as we don't know the order in the error message
        # test for both independently
        for column_name, column_name2 in zip(column_names, column_names[1:]):
            table2 = self._table_with_changed_last_row(
                test_rows, column_name, column_name2
            )
            for changed_name in (column_name, column_name2):
                with pytest.raises(
                    AssertionError,
                    match=self._column_diff_pattern(test_rows, changed_name),
                ):
                    table1.assert_equals(table2)

    def test_num_rows(self, table1, test_rows):
        """Tables with different row counts report both counts."""
        table2 = self._build_table(test_rows[:4])
        with pytest.raises(
            AssertionError,
            match=re.escape(
                f"{type(table1).__name__} number of rows differ: self=5 other=4"
            ),
        ):
            table1.assert_equals(table2)

    def test_metadata(self, table1, test_rows):
        """A metadata-only change is reported and can be ignored."""
        if "metadata" not in test_rows[0]:
            return
        table2 = self._assert_single_column_differs(table1, test_rows, "metadata")
        table1.assert_equals(table2, ignore_metadata=True)

    def test_timestamp(self, table1, test_rows):
        """A timestamp-only change is reported and can be ignored."""
        if "timestamp" not in test_rows[0]:
            return
        table2 = self._assert_single_column_differs(table1, test_rows, "timestamp")
        table1.assert_equals(table2, ignore_timestamps=True)
| 1166 | + |
| 1167 | + |
| 1168 | +class TestIndividualTable(CommonTestsMixin, MetadataTestsMixin, AssertEqualsMixin): |
1013 | 1169 | columns = [UInt32Column("flags")]
|
1014 | 1170 | ragged_list_columns = [
|
1015 | 1171 | (DoubleColumn("location"), UInt32Column("location_offset")),
|
@@ -1138,7 +1294,7 @@ def test_various_not_equals(self):
|
1138 | 1294 | assert a == b
|
1139 | 1295 |
|
1140 | 1296 |
|
1141 |
| -class TestNodeTable(CommonTestsMixin, MetadataTestsMixin): |
| 1297 | +class TestNodeTable(CommonTestsMixin, MetadataTestsMixin, AssertEqualsMixin): |
1142 | 1298 |
|
1143 | 1299 | columns = [
|
1144 | 1300 | UInt32Column("flags"),
|
@@ -1219,7 +1375,7 @@ def test_add_row_bad_data(self):
|
1219 | 1375 | t.add_row(metadata=123)
|
1220 | 1376 |
|
1221 | 1377 |
|
1222 |
| -class TestEdgeTable(CommonTestsMixin, MetadataTestsMixin): |
| 1378 | +class TestEdgeTable(CommonTestsMixin, MetadataTestsMixin, AssertEqualsMixin): |
1223 | 1379 |
|
1224 | 1380 | columns = [
|
1225 | 1381 | DoubleColumn("left"),
|
@@ -1267,7 +1423,7 @@ def test_add_row_bad_data(self):
|
1267 | 1423 | t.add_row(0, 0, 0, 0, metadata=123)
|
1268 | 1424 |
|
1269 | 1425 |
|
1270 |
| -class TestSiteTable(CommonTestsMixin, MetadataTestsMixin): |
| 1426 | +class TestSiteTable(CommonTestsMixin, MetadataTestsMixin, AssertEqualsMixin): |
1271 | 1427 | columns = [DoubleColumn("position")]
|
1272 | 1428 | ragged_list_columns = [
|
1273 | 1429 | (CharColumn("ancestral_state"), UInt32Column("ancestral_state_offset")),
|
@@ -1322,7 +1478,7 @@ def test_packset_ancestral_state(self):
|
1322 | 1478 | assert np.array_equal(table.ancestral_state_offset, ancestral_state_offset)
|
1323 | 1479 |
|
1324 | 1480 |
|
1325 |
| -class TestMutationTable(CommonTestsMixin, MetadataTestsMixin): |
| 1481 | +class TestMutationTable(CommonTestsMixin, MetadataTestsMixin, AssertEqualsMixin): |
1326 | 1482 | columns = [
|
1327 | 1483 | Int32Column("site"),
|
1328 | 1484 | Int32Column("node"),
|
@@ -1394,7 +1550,7 @@ def test_packset_derived_state(self):
|
1394 | 1550 | assert np.array_equal(table.derived_state_offset, derived_state_offset)
|
1395 | 1551 |
|
1396 | 1552 |
|
1397 |
| -class TestMigrationTable(CommonTestsMixin, MetadataTestsMixin): |
| 1553 | +class TestMigrationTable(CommonTestsMixin, MetadataTestsMixin, AssertEqualsMixin): |
1398 | 1554 | columns = [
|
1399 | 1555 | DoubleColumn("left"),
|
1400 | 1556 | DoubleColumn("right"),
|
@@ -1445,7 +1601,7 @@ def test_add_row_bad_data(self):
|
1445 | 1601 | t.add_row(0, 0, 0, 0, 0, 0, metadata=123)
|
1446 | 1602 |
|
1447 | 1603 |
|
1448 |
| -class TestProvenanceTable(CommonTestsMixin): |
| 1604 | +class TestProvenanceTable(CommonTestsMixin, AssertEqualsMixin): |
1449 | 1605 | columns = []
|
1450 | 1606 | ragged_list_columns = [
|
1451 | 1607 | (CharColumn("timestamp"), UInt32Column("timestamp_offset")),
|
@@ -1496,7 +1652,7 @@ def test_packset_record(self):
|
1496 | 1652 | assert t[1].record == "BBBB"
|
1497 | 1653 |
|
1498 | 1654 |
|
1499 |
| -class TestPopulationTable(CommonTestsMixin, MetadataTestsMixin): |
| 1655 | +class TestPopulationTable(CommonTestsMixin, MetadataTestsMixin, AssertEqualsMixin): |
1500 | 1656 | metadata_mandatory = True
|
1501 | 1657 | columns = []
|
1502 | 1658 | ragged_list_columns = [(CharColumn("metadata"), UInt32Column("metadata_offset"))]
|
@@ -3307,6 +3463,122 @@ def test_equals_population_metadata(self, ts_fixture):
|
3307 | 3463 | assert t1.equals(t2, ignore_metadata=True)
|
3308 | 3464 |
|
3309 | 3465 |
|
class TestTableCollectionAssertEquals:
    """
    Tests of ``assert_equals`` on the tables dumped from ``ts_fixture``.

    ``t1`` is always left untouched; each test modifies ``t2`` and checks the
    message raised when the two are compared.
    """

    @pytest.fixture
    def t1(self, ts_fixture):
        """Tables dumped from ``ts_fixture``, used as the reference side."""
        tables = ts_fixture.dump_tables()
        return tables

    @pytest.fixture
    def t2(self, ts_fixture):
        """A second, separately dumped set of tables for tests to modify."""
        tables = ts_fixture.dump_tables()
        return tables

    def test_equal(self, t1, t2):
        """Two separate dumps of the same fixture compare equal."""
        assert t1 is not t2
        t1.assert_equals(t2)

    def test_type(self, t1):
        """Comparing against an int reports the two types."""
        expected = re.escape(
            "Types differ: self=<class 'tskit.tables.TableCollection'> "
            "other=<class 'int'>"
        )
        with pytest.raises(AssertionError, match=expected):
            t1.assert_equals(42)

    def test_sequence_length(self, t1, t2):
        """A changed sequence length is reported with both values."""
        t2.sequence_length = 42
        expected = "Sequence Length differs: self=1.0 other=42.0"
        with pytest.raises(AssertionError, match=expected):
            t1.assert_equals(t2)

    def test_metadata_schema(self, t1, t2):
        """A changed top-level schema is reported unless metadata is ignored."""
        t2.metadata_schema = tskit.MetadataSchema(None)
        expected = re.escape(
            "Metadata schemas differ: self=OrderedDict([('codec', 'json')]) "
            "other=None"
        )
        with pytest.raises(AssertionError, match=expected):
            t1.assert_equals(t2)
        # Either ignore flag must suppress the schema difference.
        for ignore_flag in ("ignore_metadata", "ignore_ts_metadata"):
            t1.assert_equals(t2, **{ignore_flag: True})

    def test_metadata(self, t1, t2):
        """Changed top-level metadata is reported unless metadata is ignored."""
        t2.metadata = {"foo": "bar"}
        expected = re.escape(
            "Metadata differs: self=Test metadata other={'foo': 'bar'}"
        )
        with pytest.raises(AssertionError, match=expected):
            t1.assert_equals(t2)
        # Either ignore flag must suppress the metadata difference.
        for ignore_flag in ("ignore_metadata", "ignore_ts_metadata"):
            t1.assert_equals(t2, **{ignore_flag: True})

    @pytest.mark.parametrize("table_name", tskit.TableCollection(1).name_map)
    def test_tables(self, t1, t2, table_name):
        """Emptying any one table of ``t2`` is reported as a row-count change."""
        emptied = getattr(t2, table_name)
        emptied.truncate(0)
        original_rows = len(getattr(t1, table_name))
        expected = (
            f"{type(emptied).__name__} number of rows differ: "
            f"self={original_rows} other=0"
        )
        with pytest.raises(AssertionError, match=expected):
            t1.assert_equals(t2)

    @pytest.mark.parametrize("table_name", tskit.TableCollection(1).name_map)
    def test_ignore_metadata(self, t1, t2, table_name):
        """A per-table schema change is reported unless metadata is ignored."""
        target = getattr(t2, table_name)
        # Tables without a metadata schema have nothing to check here.
        if not hasattr(target, "metadata_schema"):
            return
        target.metadata_schema = tskit.MetadataSchema(None)
        expected = re.escape(
            f"{type(target).__name__} metadata schemas differ: "
            f"self=OrderedDict([('codec', 'json')]) other=None"
        )
        with pytest.raises(AssertionError, match=expected):
            t1.assert_equals(t2)
        t1.assert_equals(t2, ignore_metadata=True)

    def test_ignore_provenance(self, t1, t2):
        """Only ``ignore_provenance`` hides a differing provenance row count."""
        t2.provenances.truncate(0)
        expected = "ProvenanceTable number of rows differ: self=1 other=0"
        # Ignoring timestamps alone must still report the missing row.
        for options in ({}, {"ignore_timestamps": True}):
            with pytest.raises(AssertionError, match=expected):
                t1.assert_equals(t2, **options)

        t1.assert_equals(t2, ignore_provenance=True)

    def test_ignore_timestamps(self, t1, t2):
        """A changed timestamp is hidden by either provenance-related flag."""
        provenances = t2.provenances
        new_timestamp = provenances.timestamp
        # Overwrite the first timestamp character so row 0 differs from t1.
        new_timestamp[0] = ord("F")
        provenances.set_columns(
            timestamp=new_timestamp,
            timestamp_offset=provenances.timestamp_offset,
            record=provenances.record,
            record_offset=provenances.record_offset,
        )
        expected = (
            "ProvenanceTable row 0 differs:\n"
            "self.timestamp=.* other.timestamp=F.*"
        )
        with pytest.raises(AssertionError, match=expected):
            t1.assert_equals(t2)
        t1.assert_equals(t2, ignore_provenance=True)
        t1.assert_equals(t2, ignore_timestamps=True)
| 3580 | + |
| 3581 | + |
3310 | 3582 | class TestTableCollectionMethodSignatures:
|
3311 | 3583 | tc = msprime.simulate(10, random_seed=1234).dump_tables()
|
3312 | 3584 |
|
|
0 commit comments