Skip to content

Make metadata lazy for table row classes #1261

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ def handle_item(fieldarg, content):
("c:type", "uint32_t"),
("c:type", "bool"),
("py:class", "tskit.metadata.AbstractMetadataCodec"),
("py:class", "tskit.trees.Site"),
# TODO these have been triaged here to make the docs compile, but we should
# sort them out properly. https://github.com/tskit-dev/tskit/issues/336
("py:class", "array_like"),
Expand Down
9 changes: 9 additions & 0 deletions python/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,20 @@

**Breaking changes**

- `Mutation.position` and `Mutation.index` which were deprecated in 0.2.2 (Sep '19) have
been removed.

**Features**

- SVG visualization of a single tree allows all mutations on an edge to be plotted
via the ``all_edge_mutations`` param (:user:`hyanwong`,:issue:`1253`, :pr:`1258`).

- Entity classes such as `Mutation`, `Node` are now python dataclasses
(:user:`benjeffery`, :pr:`1261`).

- Metadata decoding for table row access is now lazy (:user:`benjeffery`, :pr:`1261`).


**Fixes**

--------------------
Expand Down
8 changes: 4 additions & 4 deletions python/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,13 +234,13 @@ def __init__(self, tree_sequence, breakpoints=None):
def make_mutation(id_):
site, node, derived_state, parent, metadata, time = ll_ts.get_mutation(id_)
return tskit.Mutation(
id_=id_,
id=id_,
site=site,
node=node,
time=time,
derived_state=derived_state,
parent=parent,
encoded_metadata=metadata,
metadata=metadata,
metadata_decoder=tskit.metadata.parse_metadata_schema(
ll_ts.get_table_metadata_schemas().mutation
).decode_row,
Expand All @@ -250,11 +250,11 @@ def make_mutation(id_):
pos, ancestral_state, ll_mutations, id_, metadata = ll_ts.get_site(j)
self._sites.append(
tskit.Site(
id_=id_,
id=id_,
position=pos,
ancestral_state=ancestral_state,
mutations=[make_mutation(ll_mut) for ll_mut in ll_mutations],
encoded_metadata=metadata,
metadata=metadata,
metadata_decoder=tskit.metadata.parse_metadata_schema(
ll_ts.get_table_metadata_schemas().site
).decode_row,
Expand Down
72 changes: 27 additions & 45 deletions python/tests/test_highlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1241,17 +1241,7 @@ def verify_mutations(self, ts):
assert ts.num_mutations == len(other_mutations)
assert ts.num_mutations == len(mutations)
for mut, other_mut in zip(mutations, other_mutations):
# We cannot compare these directly as the mutations obtained
# from the mutations iterator will have extra deprecated
# attributes.
assert mut.id == other_mut.id
assert mut.site == other_mut.site
assert mut.parent == other_mut.parent
assert mut.node == other_mut.node
assert mut.metadata == other_mut.metadata
# Check the deprecated attrs.
assert mut.position == ts.site(mut.site).position
assert mut.index == mut.site
assert mut == other_mut

def test_sites_mutations(self):
# Check that the mutations iterator returns the correct values.
Expand Down Expand Up @@ -2103,17 +2093,7 @@ def verify_mutations(self, tree):
assert tree.num_mutations == len(other_mutations)
assert tree.num_mutations == len(mutations)
for mut, other_mut in zip(mutations, other_mutations):
# We cannot compare these directly as the mutations obtained
# from the mutations iterator will have extra deprecated
# attributes.
assert mut.id == other_mut.id
assert mut.site == other_mut.site
assert mut.parent == other_mut.parent
assert mut.node == other_mut.node
assert mut.metadata == other_mut.metadata
# Check the deprecated attrs.
assert mut.position == tree.tree_sequence.site(mut.site).position
assert mut.index == mut.site
assert mut == other_mut

def test_simple_mutations(self):
tree = self.get_tree()
Expand Down Expand Up @@ -2991,17 +2971,18 @@ def test_metadata(self):
(inst,) = self.get_instances(1)
(inst2,) = self.get_instances(1)
assert inst == inst2
inst._metadata_decoder = lambda m: "different decoder"
inst.metadata
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does this statement do?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought it was forcing decode of the metadata, but that is done by the equality operator above anyway, removed.

assert inst == inst2
inst._encoded_metadata = b"different"
assert not (inst == inst2)
inst._metadata = "different"
assert inst != inst2

def test_decoder_run_once(self):
# For a given instance, the decoded metadata should be cached, with the decoder
# called once
(inst,) = self.get_instances(1)
times_run = 0

# Hack in a tracing decoder
def decoder(m):
nonlocal times_run
times_run += 1
Expand All @@ -3019,12 +3000,12 @@ class TestIndividualContainer(SimpleContainersMixin, SimpleContainersWithMetadat
def get_instances(self, n):
return [
tskit.Individual(
id_=j,
id=j,
flags=j,
location=[j],
parents=[j],
nodes=[j],
encoded_metadata=b"x" * j,
metadata=b"x" * j,
metadata_decoder=lambda m: m.decode() + "decoded",
)
for j in range(n)
Expand All @@ -3035,12 +3016,12 @@ class TestNodeContainer(SimpleContainersMixin, SimpleContainersWithMetadataMixin
def get_instances(self, n):
return [
tskit.Node(
id_=j,
id=j,
flags=j,
time=j,
population=j,
individual=j,
encoded_metadata=b"x" * j,
metadata=b"x" * j,
metadata_decoder=lambda m: m.decode() + "decoded",
)
for j in range(n)
Expand All @@ -3055,9 +3036,9 @@ def get_instances(self, n):
right=j,
parent=j,
child=j,
encoded_metadata=b"x" * j,
metadata=b"x" * j,
metadata_decoder=lambda m: m.decode() + "decoded",
id_=j,
id=j,
)
for j in range(n)
]
Expand All @@ -3067,11 +3048,11 @@ class TestSiteContainer(SimpleContainersMixin, SimpleContainersWithMetadataMixin
def get_instances(self, n):
return [
tskit.Site(
id_=j,
id=j,
position=j,
ancestral_state="A" * j,
mutations=TestMutationContainer().get_instances(j),
encoded_metadata=b"x" * j,
metadata=b"x" * j,
metadata_decoder=lambda m: m.decode() + "decoded",
)
for j in range(n)
Expand All @@ -3082,46 +3063,46 @@ class TestMutationContainer(SimpleContainersMixin, SimpleContainersWithMetadataM
def get_instances(self, n):
return [
tskit.Mutation(
id_=j,
id=j,
site=j,
node=j,
time=j,
derived_state="A" * j,
parent=j,
encoded_metadata=b"x" * j,
metadata=b"x" * j,
metadata_decoder=lambda m: m.decode() + "decoded",
)
for j in range(n)
]

def test_nan_equality(self):
a = tskit.Mutation(
id_=42,
id=42,
site=42,
node=42,
time=UNKNOWN_TIME,
derived_state="A" * 42,
parent=42,
encoded_metadata=b"x" * 42,
metadata=b"x" * 42,
metadata_decoder=lambda m: m.decode() + "decoded",
)
b = tskit.Mutation(
id_=42,
id=42,
site=42,
node=42,
derived_state="A" * 42,
parent=42,
encoded_metadata=b"x" * 42,
metadata=b"x" * 42,
metadata_decoder=lambda m: m.decode() + "decoded",
)
c = tskit.Mutation(
id_=42,
id=42,
site=42,
node=42,
time=math.nan,
derived_state="A" * 42,
parent=42,
encoded_metadata=b"x" * 42,
metadata=b"x" * 42,
metadata_decoder=lambda m: m.decode() + "decoded",
)
assert a == a
Expand All @@ -3139,13 +3120,14 @@ class TestMigrationContainer(SimpleContainersMixin, SimpleContainersWithMetadata
def get_instances(self, n):
return [
tskit.Migration(
id=j,
left=j,
right=j,
node=j,
source=j,
dest=j,
time=j,
encoded_metadata=b"x" * j,
metadata=b"x" * j,
metadata_decoder=lambda m: m.decode() + "decoded",
)
for j in range(n)
Expand All @@ -3156,8 +3138,8 @@ class TestPopulationContainer(SimpleContainersMixin, SimpleContainersWithMetadat
def get_instances(self, n):
return [
tskit.Population(
id_=j,
encoded_metadata=b"x" * j,
id=j,
metadata=b"x" * j,
metadata_decoder=lambda m: m.decode() + "decoded",
)
for j in range(n)
Expand All @@ -3167,7 +3149,7 @@ def get_instances(self, n):
class TestProvenanceContainer(SimpleContainersMixin):
def get_instances(self, n):
return [
tskit.Provenance(id_=j, timestamp="x" * j, record="y" * j) for j in range(n)
tskit.Provenance(id=j, timestamp="x" * j, record="y" * j) for j in range(n)
]


Expand Down
10 changes: 8 additions & 2 deletions python/tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,17 @@ def verify_max_distance(self, ts):
A = ldc.get_r2_matrix()
j = len(mutations) // 2
for k in range(j):
x = mutations[j + k].position - mutations[j].position
x = (
ts.site(mutations[j + k].site).position
- ts.site(mutations[j].site).position
)
a = ldc.get_r2_array(j, max_distance=x)
assert a.shape[0] == k
assert np.allclose(A[j, j + 1 : j + 1 + k], a)
x = mutations[j].position - mutations[j - k].position
x = (
ts.site(mutations[j].site).position
- ts.site(mutations[j - k].site).position
)
a = ldc.get_r2_array(j, max_distance=x, direction=tskit.REVERSE)
assert a.shape[0] == k
assert np.allclose(A[j, j - k : j], a[::-1])
Expand Down
57 changes: 57 additions & 0 deletions python/tskit/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@
Classes for metadata decoding, encoding and validation
"""
import abc
import builtins
import collections
import copy
import json
import pprint
import struct
import types
from itertools import islice
from typing import Any
from typing import Mapping
Expand All @@ -39,6 +41,8 @@
import tskit
import tskit.exceptions as exceptions

__builtins__object__setattr__ = builtins.object.__setattr__


def replace_root_refs(obj):
if type(obj) == list:
Expand Down Expand Up @@ -656,3 +660,56 @@ def parse_metadata_schema(encoded_schema: str) -> MetadataSchema:
except json.decoder.JSONDecodeError:
raise ValueError(f"Metadata schema is not JSON, found {encoded_schema}")
return MetadataSchema(decoded)


class _CachedMetadata:
"""
Descriptor for lazy decoding of metadata on attribute access.
"""

def __get__(self, row, owner):
if row._metadata_decoder is not None:
# Some classes that use this are frozen so we need to directly setattr.
__builtins__object__setattr__(
row, "_metadata", row._metadata_decoder(row._metadata)
)
# Decoder being None indicates that metadata is decoded
__builtins__object__setattr__(row, "_metadata_decoder", None)
return row._metadata

def __set__(self, row, value):
__builtins__object__setattr__(row, "_metadata", value)


def lazy_decode(cls):
"""
Modifies a dataclass such that it lazily decodes metadata, if it is encoded.
If the metadata passed to the constructor is encoded a `metadata_decoder` parameter
must be also be passed.
"""
wrapped_init = cls.__init__

# Intercept the init to record the decoder
def new_init(self, *args, metadata_decoder=None, **kwargs):
__builtins__object__setattr__(self, "_metadata_decoder", metadata_decoder)
wrapped_init(self, *args, **kwargs)

cls.__init__ = new_init

# Add a descriptor to the class to decode and cache metadata
cls.metadata = _CachedMetadata()

# Add slots needed to the class
slots = cls.__slots__
slots.extend(["_metadata", "_metadata_decoder"])
dict_ = dict()
sloted_members = dict()
for k, v in cls.__dict__.items():
if k not in slots:
dict_[k] = v
elif not isinstance(v, types.MemberDescriptorType):
sloted_members[k] = v
new_cls = type(cls.__name__, cls.__bases__, dict_)
for k, v in sloted_members.items():
setattr(new_cls, k, v)
return new_cls
Loading