Skip to content

Commit 4ec80ea

Browse files
committed
add vectorised metadata, closes #1676
1 parent 9cd4473 commit 4ec80ea

File tree

2 files changed

+180
-0
lines changed

2 files changed

+180
-0
lines changed

python/tests/test_tables.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1112,6 +1112,135 @@ def test_set_columns_metadata_schema(self):
11121112
table.set_columns(**table2.asdict())
11131113
assert table.metadata_schema == ms
11141114

1115+
def verify_metadata_vector(self, table, key, dtype, default_value=9999):
1116+
# this is just a hack for testing; the actual method
1117+
# does this more elegantly
1118+
has_default = default_value != 9999
1119+
if has_default:
1120+
md_vec = table.metadata_vector(
1121+
key, default_value=default_value, dtype=dtype
1122+
)
1123+
else:
1124+
md_vec = table.metadata_vector(key, dtype=dtype)
1125+
assert isinstance(md_vec, np.ndarray)
1126+
if dtype is not None:
1127+
assert md_vec.dtype == np.dtype(dtype)
1128+
assert len(md_vec) == table.num_rows
1129+
if not isinstance(key, list):
1130+
key = [key]
1131+
for x, row in zip(md_vec, table):
1132+
md = row.metadata
1133+
for k in key:
1134+
if k in md or not has_default:
1135+
md = md[k]
1136+
else:
1137+
md = default_value
1138+
break
1139+
assert np.all(np.cast[dtype](md) == x)
1140+
1141+
def test_metadata_vector_errors(self):
1142+
table = self.table_class()
1143+
ms = tskit.MetadataSchema({"codec": "json"})
1144+
table.metadata_schema = ms
1145+
table.add_row(
1146+
**{
1147+
**self.input_data_for_add_row(),
1148+
"metadata": None,
1149+
}
1150+
)
1151+
with pytest.raises(KeyError):
1152+
_ = table.metadata_vector("x")
1153+
metadata_list = [
1154+
{"a": 4, "u": [1, 2]},
1155+
{},
1156+
]
1157+
for md in metadata_list:
1158+
table.add_row(
1159+
**{
1160+
**self.input_data_for_add_row(),
1161+
"metadata": md,
1162+
}
1163+
)
1164+
with pytest.raises(KeyError):
1165+
_ = table.metadata_vector("x")
1166+
1167+
def test_metadata_vector_nodefault(self):
1168+
table = self.table_class()
1169+
ms = tskit.MetadataSchema({"codec": "json"})
1170+
table.metadata_schema = ms
1171+
metadata_list = [
1172+
{"abc": 4, "u": [1, 2]},
1173+
{"abc": 10, "u": [3, 4]},
1174+
{"abc": -3, "b": {"c": 1}, "u": [5, 6]},
1175+
{"abc": 1},
1176+
]
1177+
for md in metadata_list:
1178+
table.add_row(
1179+
**{
1180+
**self.input_data_for_add_row(),
1181+
"metadata": md,
1182+
}
1183+
)
1184+
# first the totally obvious test
1185+
md_vec = table.metadata_vector("abc")
1186+
assert np.all(np.equal(md_vec, [d["abc"] for d in metadata_list]))
1187+
# now automated ones
1188+
for dtype in [None, "int", "float", "object"]:
1189+
self.verify_metadata_vector(
1190+
table, key="abc", dtype=dtype, default_value=9999
1191+
)
1192+
self.verify_metadata_vector(
1193+
table, key=["abc"], dtype=dtype, default_value=9999
1194+
)
1195+
1196+
def test_metadata_vector(self):
1197+
table = self.table_class()
1198+
ms = tskit.MetadataSchema({"codec": "json"})
1199+
table.metadata_schema = ms
1200+
metadata_list = [
1201+
{"abc": 4, "u": [1, 2]},
1202+
{"abc": 10, "u": [3, 4]},
1203+
{"abc": -3, "b": {"c": 1}, "u": [5, 6]},
1204+
{"b": {"c": 3.2}, "u": [7, 8]},
1205+
{"b": {"x": 8.2}},
1206+
{},
1207+
None,
1208+
]
1209+
for md in metadata_list:
1210+
table.add_row(
1211+
**{
1212+
**self.input_data_for_add_row(),
1213+
"metadata": md,
1214+
}
1215+
)
1216+
# first the totally obvious test
1217+
md_vec = table.metadata_vector("abc", default_value=0)
1218+
assert np.all(
1219+
np.equal(
1220+
md_vec,
1221+
[
1222+
d["abc"] if (d is not None and "abc" in d) else 0
1223+
for d in metadata_list
1224+
],
1225+
)
1226+
)
1227+
1228+
# now some automated ones
1229+
for dtype in [None, "int", "float", "object"]:
1230+
self.verify_metadata_vector(table, key="abc", dtype=dtype, default_value=-1)
1231+
self.verify_metadata_vector(
1232+
table, key=["abc"], dtype=dtype, default_value=-1
1233+
)
1234+
self.verify_metadata_vector(table, key=["x"], dtype=dtype, default_value=-1)
1235+
self.verify_metadata_vector(
1236+
table, key=["b", "c"], dtype=dtype, default_value=-1
1237+
)
1238+
self.verify_metadata_vector(table, key=["b"], dtype="object", default_value=-1)
1239+
self.verify_metadata_vector(table, key=["u"], dtype="int", default_value=[0, 0])
1240+
# and finally we should get rectangular arrays when it makes sense
1241+
md_vec = table.metadata_vector("u", default_value=[0, 0])
1242+
assert md_vec.shape == (table.num_rows, 2)
1243+
11151244

11161245
class AssertEqualsMixin:
11171246
def test_equal(self, table_5row, test_rows):

python/tskit/tables.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@
3131
import json
3232
import numbers
3333
import warnings
34+
from collections.abc import Mapping
3435
from dataclasses import dataclass
36+
from functools import reduce
3537
from typing import Any
3638
from typing import Optional
3739
from typing import Union
@@ -47,6 +49,9 @@
4749

4850
dataclass_options = {"frozen": True}
4951

52+
# we'll need this because python can't tell if kwargs are passed in not
53+
DEFAULT = object()
54+
5055

5156
@metadata.lazy_decode
5257
@dataclass(**dataclass_options)
@@ -585,6 +590,52 @@ def _update_metadata_schema_cache_from_ll(self) -> None:
585590
self.ll_table.metadata_schema
586591
)
587592

593+
def metadata_vector(self, key, *, dtype=None, default_value=DEFAULT):
594+
"""
595+
Returns a numpy array of metadata values obtained by extracting ``key``
596+
from each metadata entry, and using ``default_value`` if the key is
597+
not present. ``key`` may be a list, in which case nested values are returned.
598+
For instance, ``key = ["a", "x"]`` will return an array of
599+
``row.metadata["a"]["x"]`` values, iterated over rows in this table.
600+
601+
:param str key: The name, or a list of names, of metadata entries.
602+
:param str dtype: The dtype of the result (can usually be omitted).
603+
:param object default_value: The value to be inserted if the metadata key
604+
is not present. Note that for numeric columns, a default value of None
605+
will result in a non-numeric array.
606+
"""
607+
608+
if default_value == DEFAULT:
609+
610+
def getter(d, k):
611+
return d[k]
612+
613+
else:
614+
615+
def getter(d, k):
616+
return (
617+
d.get(k, default_value) if isinstance(d, Mapping) else default_value
618+
)
619+
620+
if isinstance(key, list):
621+
out = np.array(
622+
[
623+
reduce(
624+
getter,
625+
key,
626+
row.metadata,
627+
)
628+
for row in self
629+
],
630+
dtype=dtype,
631+
)
632+
else:
633+
out = np.array(
634+
[getter(row.metadata, key) for row in self],
635+
dtype=dtype,
636+
)
637+
return out
638+
588639

589640
class IndividualTable(BaseTable, MetadataMixin):
590641
"""

0 commit comments

Comments
 (0)