Skip to content

Commit f56399c

Browse files
committed
more tests
1 parent ccba416 commit f56399c

File tree

2 files changed

+56
-13
lines changed

2 files changed

+56
-13
lines changed

python/tests/test_tables.py

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1112,15 +1112,34 @@ def test_set_columns_metadata_schema(self):
11121112
table.set_columns(**table2.asdict())
11131113
assert table.metadata_schema == ms
11141114

1115+
def verify_metadata_vector(self, table, key, dtype, default_value):
    """
    Check ``table.metadata_vector`` against a row-by-row metadata lookup.

    Calls ``table.metadata_vector(key, default_value=..., dtype=...)`` and
    asserts that the result is a numpy array with one entry per table row,
    that it has the requested dtype (skipped when ``dtype`` is None), and
    that each entry equals the value found by walking that row's metadata
    along ``key``.

    :param table: Object providing ``metadata_vector``, ``num_rows`` and
        iteration over rows that have a ``metadata`` attribute.
    :param key: A single key, or a list of keys describing a nested lookup.
    :param dtype: dtype passed through to ``metadata_vector``; may be None.
    :param default_value: Value expected wherever the key path is absent.
    :raises AssertionError: If any of the checks fail.
    """
    import collections.abc

    md_vec = table.metadata_vector(key, default_value=default_value, dtype=dtype)
    assert isinstance(md_vec, np.ndarray)
    if dtype is not None:
        assert md_vec.dtype == np.dtype(dtype)
    assert len(md_vec) == table.num_rows
    if not isinstance(key, list):
        key = [key]
    for x, row in zip(md_vec, table):
        md = row.metadata
        for k in key:
            # Only mappings can be descended into; a scalar or list part-way
            # along the key path counts as "missing" rather than raising
            # TypeError from the ``in`` test.
            if isinstance(md, collections.abc.Mapping) and k in md:
                md = md[k]
            else:
                md = default_value
                break
        # np.cast was removed in NumPy 2.0; np.asarray does the conversion.
        assert np.all(np.asarray(md, dtype=dtype) == x)
1132+
11151133
def test_metadata_vector(self):
11161134
table = self.table_class()
11171135
ms = tskit.MetadataSchema({"codec": "json"})
11181136
table.metadata_schema = ms
11191137
metadata_list = [
1120-
{"a": 4},
1121-
{"a": 10},
1122-
{"a": -3, "b": {"c": 1}},
1123-
{"b": {"c": 3.2}},
1138+
{"a": 4, "u": [1, 2]},
1139+
{"a": 10, "u": [3, 4]},
1140+
{"a": -3, "b": {"c": 1}, "u": [5, 6]},
1141+
{"b": {"c": 3.2}, "u": [7, 8]},
1142+
{"b": {"x": 8.2}},
11241143
{},
11251144
]
11261145
for md in metadata_list:
@@ -1130,12 +1149,18 @@ def test_metadata_vector(self):
11301149
"metadata": md,
11311150
}
11321151
)
1133-
default_value = -1
1134-
for key in ["a", ["b", "c"]]:
1135-
assert np.equal(
1136-
[md.get(key, default_value) for md in metadata_list],
1137-
table.metadata_vector(key, default_value=default_value),
1152+
for dtype in [None, "int", "float", "object"]:
1153+
self.verify_metadata_vector(table, key="a", dtype=dtype, default_value=-1)
1154+
self.verify_metadata_vector(table, key=["a"], dtype=dtype, default_value=-1)
1155+
self.verify_metadata_vector(table, key=["x"], dtype=dtype, default_value=-1)
1156+
self.verify_metadata_vector(
1157+
table, key=["b", "c"], dtype=dtype, default_value=-1
11381158
)
1159+
self.verify_metadata_vector(table, key=["b"], dtype="object", default_value=-1)
1160+
self.verify_metadata_vector(table, key=["u"], dtype="int", default_value=[0, 0])
1161+
md_vec = table.metadata_vector("u", default_value=[0, 0], dtype="int")
1162+
# and finally we should get rectangular arrays when it makes sense
1163+
assert md_vec.shape == (table.num_rows, 2)
11391164

11401165

11411166
class AssertEqualsMixin:

python/tskit/tables.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import collections.abc
2828
import dataclasses
2929
import datetime
30+
import functools
3031
import itertools
3132
import json
3233
import numbers
@@ -599,10 +600,27 @@ def metadata_vector(self, key, *, dtype=None, default_value=None):
599600
is not present. Note that for numeric columns, a default value of None
600601
will result in a non-numeric array.
601602
"""
602-
return np.fromiter(
603-
[row.metadata.get(key, default_value) for row in self],
604-
dtype=dtype,
605-
)
603+
604+
if isinstance(key, list):
605+
out = np.array(
606+
[
607+
functools.reduce(
608+
lambda d, k: d.get(k, default_value)
609+
if isinstance(d, collections.abc.Mapping)
610+
else default_value,
611+
key,
612+
row.metadata,
613+
)
614+
for row in self
615+
],
616+
dtype=dtype,
617+
)
618+
else:
619+
out = np.array(
620+
[row.metadata.get(key, default_value) for row in self],
621+
dtype=dtype,
622+
)
623+
return out
606624

607625

608626
class IndividualTable(BaseTable, MetadataMixin):

0 commit comments

Comments
 (0)