Skip to content

Commit a44da9a

Browse files
committed
add vectorised metadata, closes #1676
Tweaks for docs and an extra test
1 parent 936bf86 commit a44da9a

File tree

2 files changed

+204
-0
lines changed

2 files changed

+204
-0
lines changed

python/tests/test_tables.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1112,6 +1112,150 @@ def test_set_columns_metadata_schema(self):
11121112
table.set_columns(**table2.asdict())
11131113
assert table.metadata_schema == ms
11141114

1115+
def verify_metadata_vector(self, table, key, dtype, default_value=9999):
1116+
# this is just a hack for testing; the actual method
1117+
# does this more elegantly
1118+
has_default = default_value != 9999
1119+
if has_default:
1120+
md_vec = table.metadata_vector(
1121+
key, default_value=default_value, dtype=dtype
1122+
)
1123+
else:
1124+
md_vec = table.metadata_vector(key, dtype=dtype)
1125+
assert isinstance(md_vec, np.ndarray)
1126+
if dtype is not None:
1127+
assert md_vec.dtype == np.dtype(dtype)
1128+
assert len(md_vec) == table.num_rows
1129+
if not isinstance(key, list):
1130+
key = [key]
1131+
for x, row in zip(md_vec, table):
1132+
md = row.metadata
1133+
for k in key:
1134+
if k in md or not has_default:
1135+
md = md[k]
1136+
else:
1137+
md = default_value
1138+
break
1139+
assert np.all(np.cast[dtype](md) == x)
1140+
1141+
def test_metadata_vector_errors(self):
1142+
table = self.table_class()
1143+
ms = tskit.MetadataSchema({"codec": "json"})
1144+
table.metadata_schema = ms
1145+
table.add_row(
1146+
**{
1147+
**self.input_data_for_add_row(),
1148+
"metadata": None,
1149+
}
1150+
)
1151+
with pytest.raises(KeyError):
1152+
_ = table.metadata_vector("x")
1153+
metadata_list = [
1154+
{"a": 4, "u": [1, 2]},
1155+
{},
1156+
]
1157+
for md in metadata_list:
1158+
table.add_row(
1159+
**{
1160+
**self.input_data_for_add_row(),
1161+
"metadata": md,
1162+
}
1163+
)
1164+
with pytest.raises(KeyError):
1165+
_ = table.metadata_vector("x")
1166+
1167+
table.clear()
1168+
metadata_list = [
1169+
{"a": {"c": 5}, "u": [1, 2]},
1170+
{"a": {"b": 6}},
1171+
]
1172+
for md in metadata_list:
1173+
table.add_row(
1174+
**{
1175+
**self.input_data_for_add_row(),
1176+
"metadata": md,
1177+
}
1178+
)
1179+
with pytest.raises(KeyError):
1180+
_ = table.metadata_vector(["a", "x"])
1181+
1182+
def test_metadata_vector_nodefault(self):
1183+
table = self.table_class()
1184+
ms = tskit.MetadataSchema({"codec": "json"})
1185+
table.metadata_schema = ms
1186+
metadata_list = [
1187+
{"abc": 4, "u": [1, 2]},
1188+
{"abc": 10, "u": [3, 4]},
1189+
{"abc": -3, "b": {"c": 1}, "u": [5, 6]},
1190+
{"abc": 1},
1191+
]
1192+
for md in metadata_list:
1193+
table.add_row(
1194+
**{
1195+
**self.input_data_for_add_row(),
1196+
"metadata": md,
1197+
}
1198+
)
1199+
# first the totally obvious test
1200+
md_vec = table.metadata_vector("abc")
1201+
assert np.all(np.equal(md_vec, [d["abc"] for d in metadata_list]))
1202+
# now automated ones
1203+
for dtype in [None, "int", "float", "object"]:
1204+
self.verify_metadata_vector(
1205+
table, key="abc", dtype=dtype, default_value=9999
1206+
)
1207+
self.verify_metadata_vector(
1208+
table, key=["abc"], dtype=dtype, default_value=9999
1209+
)
1210+
1211+
def test_metadata_vector(self):
1212+
table = self.table_class()
1213+
ms = tskit.MetadataSchema({"codec": "json"})
1214+
table.metadata_schema = ms
1215+
metadata_list = [
1216+
{"abc": 4, "u": [1, 2]},
1217+
{"abc": 10, "u": [3, 4]},
1218+
{"abc": -3, "b": {"c": 1}, "u": [5, 6]},
1219+
{"b": {"c": 3.2}, "u": [7, 8]},
1220+
{"b": {"x": 8.2}},
1221+
{},
1222+
None,
1223+
]
1224+
for md in metadata_list:
1225+
table.add_row(
1226+
**{
1227+
**self.input_data_for_add_row(),
1228+
"metadata": md,
1229+
}
1230+
)
1231+
# first the totally obvious test
1232+
md_vec = table.metadata_vector("abc", default_value=0)
1233+
assert np.all(
1234+
np.equal(
1235+
md_vec,
1236+
[
1237+
d["abc"] if (d is not None and "abc" in d) else 0
1238+
for d in metadata_list
1239+
],
1240+
)
1241+
)
1242+
1243+
# now some automated ones
1244+
for dtype in [None, "int", "float", "object"]:
1245+
self.verify_metadata_vector(table, key="abc", dtype=dtype, default_value=-1)
1246+
self.verify_metadata_vector(
1247+
table, key=["abc"], dtype=dtype, default_value=-1
1248+
)
1249+
self.verify_metadata_vector(table, key=["x"], dtype=dtype, default_value=-1)
1250+
self.verify_metadata_vector(
1251+
table, key=["b", "c"], dtype=dtype, default_value=-1
1252+
)
1253+
self.verify_metadata_vector(table, key=["b"], dtype="object", default_value=-1)
1254+
self.verify_metadata_vector(table, key=["u"], dtype="int", default_value=[0, 0])
1255+
# and finally we should get rectangular arrays when it makes sense
1256+
md_vec = table.metadata_vector("u", default_value=[0, 0])
1257+
assert md_vec.shape == (table.num_rows, 2)
1258+
11151259

11161260
class AssertEqualsMixin:
11171261
def test_equal(self, table_5row, test_rows):

python/tskit/tables.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@
3131
import json
3232
import numbers
3333
import warnings
34+
from collections.abc import Mapping
3435
from dataclasses import dataclass
36+
from functools import reduce
3537
from typing import Any
3638
from typing import Optional
3739
from typing import Union
@@ -48,6 +50,17 @@
4850
dataclass_options = {"frozen": True}
4951

5052

53+
# Needed for cases where `None` can be an appropriate kwarg value,
54+
# we override the meta so that it looks good in the docs.
55+
class NotSetMeta(type):
56+
def __repr__(cls):
57+
return "Not set"
58+
59+
60+
class NOTSET(metaclass=NotSetMeta):
61+
pass
62+
63+
5164
@metadata.lazy_decode
5265
@dataclass(**dataclass_options)
5366
class IndividualTableRow(util.Dataclass):
@@ -585,6 +598,53 @@ def _update_metadata_schema_cache_from_ll(self) -> None:
585598
self.ll_table.metadata_schema
586599
)
587600

601+
def metadata_vector(self, key, *, dtype=None, default_value=NOTSET):
602+
"""
603+
Returns a numpy array of metadata values obtained by extracting ``key``
604+
from each metadata entry, and using ``default_value`` if the key is
605+
not present. ``key`` may be a list, in which case nested values are returned.
606+
For instance, ``key = ["a", "x"]`` will return an array of
607+
``row.metadata["a"]["x"]`` values, iterated over rows in this table.
608+
609+
:param str key: The name, or a list of names, of metadata entries.
610+
:param str dtype: The dtype of the result (can usually be omitted).
611+
:param object default_value: The value to be inserted if the metadata key
612+
is not present. Note that for numeric columns, a default value of None
613+
will result in a non-numeric array. The default behaviour is to raise
614+
``KeyError`` on missing entries.
615+
"""
616+
617+
if default_value == NOTSET:
618+
619+
def getter(d, k):
620+
return d[k]
621+
622+
else:
623+
624+
def getter(d, k):
625+
return (
626+
d.get(k, default_value) if isinstance(d, Mapping) else default_value
627+
)
628+
629+
if isinstance(key, list):
630+
out = np.array(
631+
[
632+
reduce(
633+
getter,
634+
key,
635+
row.metadata,
636+
)
637+
for row in self
638+
],
639+
dtype=dtype,
640+
)
641+
else:
642+
out = np.array(
643+
[getter(row.metadata, key) for row in self],
644+
dtype=dtype,
645+
)
646+
return out
647+
588648

589649
class IndividualTable(BaseTable, MetadataMixin):
590650
"""

0 commit comments

Comments
 (0)