diff --git a/docs/conf.py b/docs/conf.py index 3ef40cf9a6..31ee5ddaa6 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -307,6 +307,7 @@ def handle_item(fieldarg, content): ("c:type", "uint32_t"), ("c:type", "bool"), ("py:class", "tskit.metadata.AbstractMetadataCodec"), + ("py:class", "tskit.trees.Site"), # TODO these have been triaged here to make the docs compile, but we should # sort them out properly. https://github.com/tskit-dev/tskit/issues/336 ("py:class", "array_like"), diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index 425f4e3147..db5a7b85b7 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -4,11 +4,20 @@ **Breaking changes** +- `Mutation.position` and `Mutation.index` which were deprecated in 0.2.2 (Sep '19) have + been removed. + **Features** - SVG visualization of a single tree allows all mutations on an edge to be plotted via the ``all_edge_mutations`` param (:user:`hyanwong`,:issue:`1253`, :pr:`1258`). +- Entity classes such as `Mutation`, `Node` are now python dataclasses + (:user:`benjeffery`, :pr:`1261`). + +- Metadata decoding for table row access is now lazy (:user:`benjeffery`, :pr:`1261`). + + **Fixes** -------------------- diff --git a/python/tests/__init__.py b/python/tests/__init__.py index aab77205b5..e49da94aaf 100644 --- a/python/tests/__init__.py +++ b/python/tests/__init__.py @@ -234,13 +234,13 @@ def __init__(self, tree_sequence, breakpoints=None): def make_mutation(id_): site, node, derived_state, parent, metadata, time = ll_ts.get_mutation(id_) return tskit.Mutation( - id_=id_, + id=id_, site=site, node=node, time=time, derived_state=derived_state, parent=parent, - encoded_metadata=metadata, + metadata=metadata, metadata_decoder=tskit.metadata.parse_metadata_schema( ll_ts.get_table_metadata_schemas().mutation ).decode_row, @@ -250,11 +250,11 @@ def make_mutation(id_): pos, ancestral_state, ll_mutations, id_, metadata = ll_ts.get_site(j) self._sites.append( tskit.Site( - id_=id_, + id=id_, position=pos, ancestral_state=ancestral_state, mutations=[make_mutation(ll_mut) for ll_mut in ll_mutations], - encoded_metadata=metadata, + metadata=metadata, metadata_decoder=tskit.metadata.parse_metadata_schema( ll_ts.get_table_metadata_schemas().site ).decode_row, diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py index de4590e5f3..ab19a2eeda 100644 --- a/python/tests/test_highlevel.py +++ b/python/tests/test_highlevel.py @@ -1241,17 +1241,7 @@ def verify_mutations(self, ts): assert ts.num_mutations == len(other_mutations) assert ts.num_mutations == len(mutations) for mut, other_mut in zip(mutations, other_mutations): - # We cannot compare these directly as the mutations obtained - # from the mutations iterator will have extra deprecated - # attributes. - assert mut.id == other_mut.id - assert mut.site == other_mut.site - assert mut.parent == other_mut.parent - assert mut.node == other_mut.node - assert mut.metadata == other_mut.metadata - # Check the deprecated attrs. - assert mut.position == ts.site(mut.site).position - assert mut.index == mut.site + assert mut == other_mut def test_sites_mutations(self): # Check that the mutations iterator returns the correct values. @@ -2103,17 +2093,7 @@ def verify_mutations(self, tree): assert tree.num_mutations == len(other_mutations) assert tree.num_mutations == len(mutations) for mut, other_mut in zip(mutations, other_mutations): - # We cannot compare these directly as the mutations obtained - # from the mutations iterator will have extra deprecated - # attributes. - assert mut.id == other_mut.id - assert mut.site == other_mut.site - assert mut.parent == other_mut.parent - assert mut.node == other_mut.node - assert mut.metadata == other_mut.metadata - # Check the deprecated attrs. - assert mut.position == tree.tree_sequence.site(mut.site).position - assert mut.index == mut.site + assert mut == other_mut def test_simple_mutations(self): tree = self.get_tree() @@ -2991,10 +2971,10 @@ def test_metadata(self): (inst,) = self.get_instances(1) (inst2,) = self.get_instances(1) assert inst == inst2 - inst._metadata_decoder = lambda m: "different decoder" + inst.metadata assert inst == inst2 - inst._encoded_metadata = b"different" - assert not (inst == inst2) + inst._metadata = "different" + assert inst != inst2 def test_decoder_run_once(self): # For a given instance, the decoded metadata should be cached, with the decoder @@ -3002,6 +2982,7 @@ def test_decoder_run_once(self): (inst,) = self.get_instances(1) times_run = 0 + # Hack in a tracing decoder def decoder(m): nonlocal times_run times_run += 1 @@ -3019,12 +3000,12 @@ class TestIndividualContainer(SimpleContainersMixin, SimpleContainersWithMetadat def get_instances(self, n): return [ tskit.Individual( - id_=j, + id=j, flags=j, location=[j], parents=[j], nodes=[j], - encoded_metadata=b"x" * j, + metadata=b"x" * j, metadata_decoder=lambda m: m.decode() + "decoded", ) for j in range(n) @@ -3035,12 +3016,12 @@ class TestNodeContainer(SimpleContainersMixin, SimpleContainersWithMetadataMixin def get_instances(self, n): return [ tskit.Node( - id_=j, + id=j, flags=j, time=j, population=j, individual=j, - encoded_metadata=b"x" * j, + metadata=b"x" * j, metadata_decoder=lambda m: m.decode() + "decoded", ) for j in range(n) @@ -3055,9 +3036,9 @@ def get_instances(self, n): right=j, parent=j, child=j, - encoded_metadata=b"x" * j, + metadata=b"x" * j, metadata_decoder=lambda m: m.decode() + "decoded", - id_=j, + id=j, ) for j in range(n) ] @@ -3067,11 +3048,11 @@ class TestSiteContainer(SimpleContainersMixin, SimpleContainersWithMetadataMixin def get_instances(self, n): return [ tskit.Site( - id_=j, + id=j, position=j, ancestral_state="A" * j, mutations=TestMutationContainer().get_instances(j), - encoded_metadata=b"x" * j, + metadata=b"x" * j, metadata_decoder=lambda m: m.decode() + "decoded", ) for j in range(n) @@ -3082,13 +3063,13 @@ class TestMutationContainer(SimpleContainersMixin, SimpleContainersWithMetadataM def get_instances(self, n): return [ tskit.Mutation( - id_=j, + id=j, site=j, node=j, time=j, derived_state="A" * j, parent=j, - encoded_metadata=b"x" * j, + metadata=b"x" * j, metadata_decoder=lambda m: m.decode() + "decoded", ) for j in range(n) @@ -3096,32 +3077,32 @@ def get_instances(self, n): def test_nan_equality(self): a = tskit.Mutation( - id_=42, + id=42, site=42, node=42, time=UNKNOWN_TIME, derived_state="A" * 42, parent=42, - encoded_metadata=b"x" * 42, + metadata=b"x" * 42, metadata_decoder=lambda m: m.decode() + "decoded", ) b = tskit.Mutation( - id_=42, + id=42, site=42, node=42, derived_state="A" * 42, parent=42, - encoded_metadata=b"x" * 42, + metadata=b"x" * 42, metadata_decoder=lambda m: m.decode() + "decoded", ) c = tskit.Mutation( - id_=42, + id=42, site=42, node=42, time=math.nan, derived_state="A" * 42, parent=42, - encoded_metadata=b"x" * 42, + metadata=b"x" * 42, metadata_decoder=lambda m: m.decode() + "decoded", ) assert a == a @@ -3139,13 +3120,14 @@ class TestMigrationContainer(SimpleContainersMixin, SimpleContainersWithMetadata def get_instances(self, n): return [ tskit.Migration( + id=j, left=j, right=j, node=j, source=j, dest=j, time=j, - encoded_metadata=b"x" * j, + metadata=b"x" * j, metadata_decoder=lambda m: m.decode() + "decoded", ) for j in range(n) @@ -3156,8 +3138,8 @@ class TestPopulationContainer(SimpleContainersMixin, SimpleContainersWithMetadat def get_instances(self, n): return [ tskit.Population( - id_=j, - encoded_metadata=b"x" * j, + id=j, + metadata=b"x" * j, metadata_decoder=lambda m: m.decode() + "decoded", ) for j in range(n) @@ -3167,7 +3149,7 @@ def get_instances(self, n): class TestProvenanceContainer(SimpleContainersMixin): def get_instances(self, n): return [ - tskit.Provenance(id_=j, timestamp="x" * j, record="y" * j) for j in range(n) + tskit.Provenance(id=j, timestamp="x" * j, record="y" * j) for j in range(n) ] diff --git a/python/tests/test_stats.py b/python/tests/test_stats.py index e8c1d38392..350e96cf5d 100644 --- a/python/tests/test_stats.py +++ b/python/tests/test_stats.py @@ -108,11 +108,17 @@ def verify_max_distance(self, ts): A = ldc.get_r2_matrix() j = len(mutations) // 2 for k in range(j): - x = mutations[j + k].position - mutations[j].position + x = ( + ts.site(mutations[j + k].site).position + - ts.site(mutations[j].site).position + ) a = ldc.get_r2_array(j, max_distance=x) assert a.shape[0] == k assert np.allclose(A[j, j + 1 : j + 1 + k], a) - x = mutations[j].position - mutations[j - k].position + x = ( + ts.site(mutations[j].site).position + - ts.site(mutations[j - k].site).position + ) a = ldc.get_r2_array(j, max_distance=x, direction=tskit.REVERSE) assert a.shape[0] == k assert np.allclose(A[j, j - k : j], a[::-1]) diff --git a/python/tskit/metadata.py b/python/tskit/metadata.py index 6ccf2878f9..29acb45bef 100644 --- a/python/tskit/metadata.py +++ b/python/tskit/metadata.py @@ -23,11 +23,13 @@ Classes for metadata decoding, encoding and validation """ import abc +import builtins import collections import copy import json import pprint import struct +import types from itertools import islice from typing import Any from typing import Mapping @@ -39,6 +41,8 @@ import tskit import tskit.exceptions as exceptions +__builtins__object__setattr__ = builtins.object.__setattr__ + def replace_root_refs(obj): if type(obj) == list: @@ -656,3 +660,56 @@ def parse_metadata_schema(encoded_schema: str) -> MetadataSchema: except json.decoder.JSONDecodeError: raise ValueError(f"Metadata schema is not JSON, found {encoded_schema}") return MetadataSchema(decoded) + + +class _CachedMetadata: + """ + Descriptor for lazy decoding of metadata on attribute access. + """ + + def __get__(self, row, owner): + if row._metadata_decoder is not None: + # Some classes that use this are frozen so we need to directly setattr. + __builtins__object__setattr__( + row, "_metadata", row._metadata_decoder(row._metadata) + ) + # Decoder being None indicates that metadata is decoded + __builtins__object__setattr__(row, "_metadata_decoder", None) + return row._metadata + + def __set__(self, row, value): + __builtins__object__setattr__(row, "_metadata", value) + + +def lazy_decode(cls): + """ + Modifies a dataclass such that it lazily decodes metadata, if it is encoded. + If the metadata passed to the constructor is encoded a `metadata_decoder` parameter + must be also be passed. + """ + wrapped_init = cls.__init__ + + # Intercept the init to record the decoder + def new_init(self, *args, metadata_decoder=None, **kwargs): + __builtins__object__setattr__(self, "_metadata_decoder", metadata_decoder) + wrapped_init(self, *args, **kwargs) + + cls.__init__ = new_init + + # Add a descriptor to the class to decode and cache metadata + cls.metadata = _CachedMetadata() + + # Add slots needed to the class + slots = cls.__slots__ + slots.extend(["_metadata", "_metadata_decoder"]) + dict_ = dict() + sloted_members = dict() + for k, v in cls.__dict__.items(): + if k not in slots: + dict_[k] = v + elif not isinstance(v, types.MemberDescriptorType): + sloted_members[k] = v + new_cls = type(cls.__name__, cls.__bases__, dict_) + for k, v in sloted_members.items(): + setattr(new_cls, k, v) + return new_cls diff --git a/python/tskit/tables.py b/python/tskit/tables.py index bc969bf492..3816a8ad23 100644 --- a/python/tskit/tables.py +++ b/python/tskit/tables.py @@ -32,7 +32,6 @@ import warnings from dataclasses import dataclass from typing import Any -from typing import Tuple import numpy as np @@ -46,31 +45,27 @@ dataclass_options = {"frozen": True} -@dataclass(eq=False, **dataclass_options) +@metadata.lazy_decode +@dataclass(**dataclass_options) class IndividualTableRow: __slots__ = ["flags", "location", "parents", "metadata"] flags: int location: np.ndarray parents: np.ndarray - metadata: bytes + metadata: Any + # We need a custom eq for the numpy arrays def __eq__(self, other): - if not isinstance(other, type(self)): - return False - else: - return all( - ( - self.flags == other.flags, - np.array_equal(self.location, other.location), - np.array_equal(self.parents, other.parents), - self.metadata == other.metadata, - ) - ) - - def __neq__(self, other): - return not self.__eq__(other) + return ( + isinstance(other, IndividualTableRow) + and self.flags == other.flags + and np.array_equal(self.location, other.location) + and np.array_equal(self.parents, other.parents) + and self.metadata == other.metadata + ) +@metadata.lazy_decode @dataclass(**dataclass_options) class NodeTableRow: __slots__ = ["flags", "time", "population", "individual", "metadata"] @@ -78,9 +73,10 @@ class NodeTableRow: time: float population: int individual: int - metadata: bytes + metadata: Any +@metadata.lazy_decode @dataclass(**dataclass_options) class EdgeTableRow: __slots__ = ["left", "right", "parent", "child", "metadata"] @@ -88,9 +84,10 @@ class EdgeTableRow: right: float parent: int child: int - metadata: bytes + metadata: Any +@metadata.lazy_decode @dataclass(**dataclass_options) class MigrationTableRow: __slots__ = ["left", "right", "node", "source", "dest", "time", "metadata"] @@ -100,27 +97,30 @@ class MigrationTableRow: source: int dest: int time: float - metadata: bytes + metadata: Any +@metadata.lazy_decode @dataclass(**dataclass_options) class SiteTableRow: __slots__ = ["position", "ancestral_state", "metadata"] position: float ancestral_state: str - metadata: bytes + metadata: Any -@dataclass(eq=False, **dataclass_options) +@metadata.lazy_decode +@dataclass(**dataclass_options) class MutationTableRow: __slots__ = ["site", "node", "derived_state", "parent", "metadata", "time"] site: int node: int derived_state: str parent: int - metadata: bytes + metadata: Any time: float + # We need a custom eq here as we have unknown times (nans) to check def __eq__(self, other): return ( isinstance(other, MutationTableRow) @@ -138,10 +138,11 @@ def __eq__(self, other): ) +@metadata.lazy_decode @dataclass(**dataclass_options) class PopulationTableRow: __slots__ = ["metadata"] - metadata: bytes + metadata: Any @dataclass(**dataclass_options) @@ -277,7 +278,7 @@ def __setattr__(self, name, value): def __getitem__(self, index): """ - Return the specifed row of this table, decoding metadata if it is present. + Return the specified row of this table, decoding metadata if it is present. Supports negative indexing, e.g. ``table[-5]``. :param int index: the zero-index of the desired row @@ -286,13 +287,7 @@ def __getitem__(self, index): index += len(self) if index < 0 or index >= len(self): raise IndexError("Index out of bounds") - row = self.ll_table.get_row(index) - try: - row = self.decode_row(row) - except AttributeError: - # This means the class returns the low-level row unchanged. - pass - return self.row_class(*row) + return self.row_class(*self.ll_table.get_row(index)) def clear(self): """ @@ -395,9 +390,14 @@ class MetadataMixin: """ def __init__(self): - self.metadata_column_index = [ - field.name for field in dataclasses.fields(self.row_class) - ].index("metadata") + base_row_class = self.row_class + + def row_class(*args, **kwargs): + return base_row_class( + *args, **kwargs, metadata_decoder=self.metadata_schema.decode_row + ) + + self.row_class = row_class self._update_metadata_schema_cache_from_ll() def packset_metadata(self, metadatas): @@ -431,13 +431,6 @@ def metadata_schema(self, schema: metadata.MetadataSchema) -> None: self.ll_table.metadata_schema = repr(schema) self._update_metadata_schema_cache_from_ll() - def decode_row(self, row: Tuple[Any]) -> Tuple: - return ( - row[: self.metadata_column_index] - + (self._metadata_schema_cache.decode_row(row[self.metadata_column_index]),) - + row[self.metadata_column_index + 1 :] - ) - def _update_metadata_schema_cache_from_ll(self) -> None: self._metadata_schema_cache = metadata.parse_metadata_schema( self.ll_table.metadata_schema diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 0412ca836c..bf63d81425 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -27,15 +27,14 @@ import base64 import collections import concurrent.futures -import copy import functools import itertools import math import textwrap import warnings +from dataclasses import dataclass from typing import Any -import attr import numpy as np import _tskit @@ -78,58 +77,9 @@ def span(self): return self.right - self.left -# TODO this interface is rubbish. Should have much better printing options. -# TODO we should be use __slots__ here probably. -class SimpleContainer: - def __eq__(self, other): - return self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - def __repr__(self): - return repr(self.__dict__) - - -class SimpleContainerWithMetadata(SimpleContainer): - """ - This class allows metadata to be lazily decoded and cached - """ - - class CachedMetadata: - """ - If we had python>=3.8 we could just use @functools.cached_property here. We - don't so we implement it similarly using a descriptor - """ - - def __get__(self, container: "SimpleContainerWithMetadata", owner: type): - decoded = container._metadata_decoder(container._encoded_metadata) - container.__dict__["metadata"] = decoded - return decoded - - metadata: Any = CachedMetadata() - - def __eq__(self, other: SimpleContainer) -> bool: - # We need to remove metadata and the decoder so we are just comparing - # the encoded metadata, along with the other attributes - other = {**other.__dict__} - other["metadata"] = None - other["_metadata_decoder"] = None - self_ = {**self.__dict__} - self_["metadata"] = None - self_["_metadata_decoder"] = None - return self_ == other - - def __repr__(self) -> str: - # Make sure we have a decoded metadata - _ = self.metadata - out = {**self.__dict__} - del out["_encoded_metadata"] - del out["_metadata_decoder"] - return repr(out) - - -class Individual(SimpleContainerWithMetadata): +@metadata_module.lazy_decode +@dataclass +class Individual: """ An :ref:`individual ` in a tree sequence. Since nodes correspond to genomes, individuals are associated with a collection @@ -138,57 +88,54 @@ class Individual(SimpleContainerWithMetadata): Modifying the attributes in this class will have **no effect** on the underlying tree sequence data. - - :ivar id: The integer ID of this individual. Varies from 0 to - :attr:`TreeSequence.num_individuals` - 1. - :vartype id: int - :ivar flags: The bitwise flags for this individual. - :vartype flags: int - :ivar location: The spatial location of this individual as a numpy array. The - location is an empty array if no spatial location is defined. - :vartype location: numpy.ndarray - :ivar parents: The parent individual ids of this individual as a numpy array. The - parents is an empty array if no parents are defined. - :vartype parents: numpy.ndarray - :ivar nodes: The IDs of the nodes that are associated with this individual as - a numpy array (dtype=np.int32). If no nodes are associated with the - individual this array will be empty. - :vartype nodes: numpy.ndarray - :ivar metadata: The decoded :ref:`metadata ` - for this individual. - :vartype metadata: object """ - def __init__( - self, - id_=None, - flags=0, - location=None, - parents=None, - nodes=None, - encoded_metadata=b"", - metadata_decoder=lambda metadata: metadata, - ): - self.id = id_ - self.flags = flags - self.location = location - self.parents = parents - self._encoded_metadata = encoded_metadata - self._metadata_decoder = metadata_decoder - self.nodes = nodes + __slots__ = ["id", "flags", "location", "parents", "nodes", "metadata"] + id: int # noqa A003 + """ + The integer ID of this individual. Varies from 0 to + :attr:`TreeSequence.num_individuals` - 1.""" + flags: int + """ + The bitwise flags for this individual. + """ + location: np.ndarray + """ + The spatial location of this individual as a numpy array. The location is an empty + array if no spatial location is defined. + """ + parents: np.ndarray + """ + The parent individual ids of this individual as a numpy array. The parents is an + empty array if no parents are defined. + """ + nodes: np.ndarray + """ + The IDs of the nodes that are associated with this individual as + a numpy array (dtype=np.int32). If no nodes are associated with the + individual this array will be empty. + """ + metadata: Any + """ + The :ref:`metadata ` + for this individual, decoded if a schema applies. + """ + # Custom eq for the numpy arrays def __eq__(self, other): return ( self.id == other.id and self.flags == other.flags - and self._encoded_metadata == other._encoded_metadata - and np.array_equal(self.nodes, other.nodes) and np.array_equal(self.location, other.location) and np.array_equal(self.parents, other.parents) + and np.array_equal(self.nodes, other.nodes) + and self.metadata == other.metadata ) -class Node(SimpleContainerWithMetadata): +@metadata_module.lazy_decode +@dataclass +class Node: """ A :ref:`node ` in a tree sequence, corresponding to a single genome. The ``time`` and ``population`` are attributes of the @@ -197,39 +144,34 @@ class Node(SimpleContainerWithMetadata): Modifying the attributes in this class will have **no effect** on the underlying tree sequence data. - - :ivar id: The integer ID of this node. Varies from 0 to - :attr:`TreeSequence.num_nodes` - 1. - :vartype id: int - :ivar flags: The bitwise flags for this node. - :vartype flags: int - :ivar time: The birth time of this node. - :vartype time: float - :ivar population: The integer ID of the population that this node was born in. - :vartype population: int - :ivar individual: The integer ID of the individual that this node was a part of. - :vartype individual: int - :ivar metadata: The decoded :ref:`metadata ` for this node. - :vartype metadata: object """ - def __init__( - self, - id_=None, - flags=0, - time=0, - population=NULL, - individual=NULL, - encoded_metadata=b"", - metadata_decoder=lambda metadata: metadata, - ): - self.id = id_ - self.time = time - self.population = population - self.individual = individual - self._encoded_metadata = encoded_metadata - self._metadata_decoder = metadata_decoder - self.flags = flags + __slots__ = ["id", "flags", "time", "population", "individual", "metadata"] + id: int # noqa A003 + """ + The integer ID of this node. Varies from 0 to :attr:`TreeSequence.num_nodes` - 1. + """ + flags: int + """ + The bitwise flags for this node. + """ + time: float + """ + The birth time of this node. + """ + population: int + """ + The integer ID of the population that this node was born in. + """ + individual: int + """ + The integer ID of the individual that this node was a part of. + """ + metadata: Any + """ + The :ref:`metadata ` for this node, decoded if a schema + applies. + """ def is_sample(self): """ @@ -241,57 +183,56 @@ def is_sample(self): return self.flags & NODE_IS_SAMPLE -class Edge(SimpleContainerWithMetadata): +@metadata_module.lazy_decode +@dataclass +class Edge: """ An :ref:`edge ` in a tree sequence. Modifying the attributes in this class will have **no effect** on the underlying tree sequence data. + """ - :ivar left: The left coordinate of this edge. - :vartype left: float - :ivar right: The right coordinate of this edge. - :vartype right: float - :ivar parent: The integer ID of the parent node for this edge. - To obtain further information about a node with a given ID, use - :meth:`TreeSequence.node`. - :vartype parent: int - :ivar child: The integer ID of the child node for this edge. - To obtain further information about a node with a given ID, use - :meth:`TreeSequence.node`. - :vartype child: int - :ivar id: The integer ID of this edge. Varies from 0 to - :attr:`TreeSequence.num_edges` - 1. - :vartype id: int - :ivar metadata: The decoded :ref:`metadata ` for this edge. - :vartype metadata: object + __slots__ = ["left", "right", "parent", "child", "metadata", "id"] + left: float + """ + The left coordinate of this edge. + """ + right: float + """ + The right coordinate of this edge. + """ + parent: int + """ + The integer ID of the parent node for this edge. + To obtain further information about a node with a given ID, use + :meth:`TreeSequence.node`. + """ + child: int + """ + The integer ID of the child node for this edge. + To obtain further information about a node with a given ID, use + :meth:`TreeSequence.node`. + """ + metadata: Any + """ + The :ref:`metadata ` for this edge, decoded if a schema + applies. + """ + id: int # noqa A003 + """ + The integer ID of this edge. Varies from 0 to + :attr:`TreeSequence.num_edges` - 1. """ - def __init__( - self, - left, - right, - parent, - child, - encoded_metadata=b"", - id_=None, - metadata_decoder=lambda metadata: metadata, - ): - self.id = id_ + # Custom init to define default values with slots + def __init__(self, left, right, parent, child, metadata=b"", id=None): # noqa A003 + self.id = id self.left = left self.right = right self.parent = parent self.child = child - self._encoded_metadata = encoded_metadata - self._metadata_decoder = metadata_decoder - - def __repr__(self): - return ( - "{{left={:.3f}, right={:.3f}, parent={}, child={}, id={}, " - "metadata={}}}".format( - self.left, self.right, self.parent, self.child, self.id, self.metadata - ) - ) + self.metadata = metadata @property def span(self): @@ -304,213 +245,223 @@ def span(self): return self.right - self.left -class Site(SimpleContainerWithMetadata): +@metadata_module.lazy_decode +@dataclass +class Site: """ A :ref:`site ` in a tree sequence. Modifying the attributes in this class will have **no effect** on the underlying tree sequence data. + """ - :ivar id: The integer ID of this site. Varies from 0 to - :attr:`TreeSequence.num_sites` - 1. - :vartype id: int - :ivar position: The floating point location of this site in genome coordinates. - Ranges from 0 (inclusive) to :attr:`TreeSequence.sequence_length` - (exclusive). - :vartype position: float - :ivar ancestral_state: The ancestral state at this site (i.e., the state - inherited by nodes, unless mutations occur). - :vartype ancestral_state: str - :ivar metadata: The decoded :ref:`metadata ` for this site. - :vartype metadata: object - :ivar mutations: The list of mutations at this site. Mutations - within a site are returned in the order they are specified in the - underlying :class:`MutationTable`. - :vartype mutations: list[:class:`Mutation`] + __slots__ = ["id", "position", "ancestral_state", "mutations", "metadata"] + id: int # noqa A003 + """ + The integer ID of this site. Varies from 0 to :attr:`TreeSequence.num_sites` - 1. + """ + position: float + """ + The floating point location of this site in genome coordinates. + Ranges from 0 (inclusive) to :attr:`TreeSequence.sequence_length` (exclusive). + """ + ancestral_state: str + """ + The ancestral state at this site (i.e., the state inherited by nodes, unless + mutations occur). + """ + mutations: np.ndarray + """ + The list of mutations at this site. Mutations within a site are returned in the + order they are specified in the underlying :class:`MutationTable`. + """ + metadata: Any + """ + The :ref:`metadata ` for this site, decoded if a schema + applies. """ - def __init__( - self, - id_, - position, - ancestral_state, - mutations, - encoded_metadata=b"", - metadata_decoder=lambda metadata: metadata, - ): - self.id = id_ - self.position = position - self.ancestral_state = ancestral_state - self.mutations = mutations - self._encoded_metadata = encoded_metadata - self._metadata_decoder = metadata_decoder + # We need a custom eq for the numpy arrays + def __eq__(self, other): + return ( + isinstance(other, Site) + and self.id == other.id + and self.position == other.position + and self.ancestral_state == other.ancestral_state + and np.array_equal(self.mutations, other.mutations) + and self.metadata == other.metadata + ) -class Mutation(SimpleContainerWithMetadata): +@metadata_module.lazy_decode +@dataclass +class Mutation: """ A :ref:`mutation ` in a tree sequence. Modifying the attributes in this class will have **no effect** on the underlying tree sequence data. + """ - :ivar id: The integer ID of this mutation. Varies from 0 to - :attr:`TreeSequence.num_mutations` - 1. - :vartype id: int - :ivar site: The integer ID of the site that this mutation occurs at. To obtain - further information about a site with a given ID use - :meth:`TreeSequence.site`. - :vartype site: int - :ivar node: The integer ID of the first node that inherits this mutation. - To obtain further information about a node with a given ID, use - :meth:`TreeSequence.node`. - :vartype node: int - :ivar time: The occurrence time of this mutation. - :vartype time: float - :ivar derived_state: The derived state for this mutation. This is the state - inherited by nodes in the subtree rooted at this mutation's node, unless - another mutation occurs. - :vartype derived_state: str - :ivar parent: The integer ID of this mutation's parent mutation. When multiple - mutations occur at a site along a path in the tree, mutations must - record the mutation that is immediately above them. If the mutation does - not have a parent, this is equal to the :data:`NULL` (-1). - To obtain further information about a mutation with a given ID, use - :meth:`TreeSequence.mutation`. - :vartype parent: int - :ivar metadata: The decoded :ref:`metadata ` for this - mutation. - :vartype metadata: object + __slots__ = ["id", "site", "node", "derived_state", "parent", "metadata", "time"] + id: int # noqa A003 """ + The integer ID of this mutation. Varies from 0 to + :attr:`TreeSequence.num_mutations` - 1. + Modifying the attributes in this class will have **no effect** on the + underlying tree sequence data. + """ + site: int + """ + The integer ID of the site that this mutation occurs at. To obtain + further information about a site with a given ID use :meth:`TreeSequence.site`. + """ + node: int + """ + The integer ID of the first node that inherits this mutation. + To obtain further information about a node with a given ID, use + :meth:`TreeSequence.node`. + """ + derived_state: str + """ + The derived state for this mutation. This is the state + inherited by nodes in the subtree rooted at this mutation's node, unless + another mutation occurs. + """ + parent: int + """ + The integer ID of this mutation's parent mutation. When multiple + mutations occur at a site along a path in the tree, mutations must + record the mutation that is immediately above them. If the mutation does + not have a parent, this is equal to the :data:`NULL` (-1). + To obtain further information about a mutation with a given ID, use + :meth:`TreeSequence.mutation`. + """ + metadata: Any + """ + The :ref:`metadata ` for this mutation, decoded if a schema + applies. + """ + time: float + """ + The occurrence time of this mutation. + """ + + # To get default values on slots we define a custom init def __init__( self, - id_=NULL, + id=NULL, # noqa A003 site=NULL, node=NULL, time=UNKNOWN_TIME, derived_state=None, parent=NULL, - encoded_metadata=b"", - metadata_decoder=lambda metadata: metadata, + metadata=b"", ): - self.id = id_ + self.id = id self.site = site self.node = node self.time = time self.derived_state = derived_state self.parent = parent - self._encoded_metadata = encoded_metadata - self._metadata_decoder = metadata_decoder + self.metadata = metadata + # We need a custom eq to compare unknown times. def __eq__(self, other): - # We need to remove metadata and the decoder so we are just comparing - # the encoded metadata, along with the other attributes. - # We also need to remove time as we have to compare to unknown time. - other_ = copy.copy(other.__dict__) - other_["metadata"] = None - other_["_metadata_decoder"] = None - other_["time"] = None - self_ = copy.copy(self.__dict__) - self_["metadata"] = None - self_["_metadata_decoder"] = None - self_["time"] = None - return self_ == other_ and ( - self.time == other.time - # We need to special case unknown times as they are a nan value. - or (util.is_unknown_time(self.time) and util.is_unknown_time(other.time)) + return ( + isinstance(other, Mutation) + and self.id == other.id + and self.site == other.site + and self.node == other.node + and self.derived_state == other.derived_state + and self.parent == other.parent + and self.metadata == other.metadata + and ( + self.time == other.time + or ( + util.is_unknown_time(self.time) and util.is_unknown_time(other.time) + ) + ) ) -class Migration(SimpleContainerWithMetadata): +@metadata_module.lazy_decode +@dataclass +class Migration: """ A :ref:`migration ` in a tree sequence. Modifying the attributes in this class will have **no effect** on the underlying tree sequence data. - - :ivar left: The left end of the genomic interval covered by this - migration (inclusive). - :vartype left: float - :ivar right: The right end of the genomic interval covered by this migration - (exclusive). - :vartype right: float - :ivar node: The integer ID of the node involved in this migration event. - To obtain further information about a node with a given ID, use - :meth:`TreeSequence.node`. - :vartype node: int - :ivar source: The source population ID. - :vartype source: int - :ivar dest: The destination population ID. - :vartype dest: int - :ivar time: The time at which this migration occured at. - :vartype time: float - :ivar metadata: The decoded :ref:`metadata ` for this - migration. - :vartype metadata: object """ - def __init__( - self, - left, - right, - node, - source, - dest, - time, - encoded_metadata=b"", - metadata_decoder=lambda metadata: metadata, - id_=None, - ): - self.id = id_ - self.left = left - self.right = right - self.node = node - self.source = source - self.dest = dest - self.time = time - self._encoded_metadata = encoded_metadata - self._metadata_decoder = metadata_decoder - - def __repr__(self): - return ( - "{{left={:.3f}, right={:.3f}, node={}, source={}, dest={} time={:.3f}" - " id={}, metadata={}}}".format( - self.left, - self.right, - self.node, - self.source, - self.dest, - self.time, - self.id, - self.metadata, - ) - ) + __slots__ = ["left", "right", "node", "source", "dest", "time", "metadata", "id"] + left: float + """ + The left end of the genomic interval covered by this + migration (inclusive). + """ + right: float + """ + The right end of the genomic interval covered by this migration + (exclusive). + """ + node: int + """ + The integer ID of the node involved in this migration event. + To obtain further information about a node with a given ID, use + :meth:`TreeSequence.node`. + """ + source: int + """ + The source population ID. + """ + dest: int + """ + The destination population ID. + """ + time: float + """ + The time at which this migration occurred at. + """ + metadata: Any + """ + The :ref:`metadata ` for this migration, decoded if a schema + applies. + """ + id: int # noqa A003 + """ + The integer ID of this mutation. Varies from 0 to + :attr:`TreeSequence.num_mutations` - 1. + """ -class Population(SimpleContainerWithMetadata): +@metadata_module.lazy_decode +@dataclass +class Population: """ A :ref:`population ` in a tree sequence. Modifying the attributes in this class will have **no effect** on the underlying tree sequence data. - - :ivar id: The integer ID of this population. Varies from 0 to - :attr:`TreeSequence.num_populations` - 1. - :vartype id: int - :ivar metadata: The decoded :ref:`metadata ` - for this population. - :vartype metadata: object """ - def __init__( - self, id_, encoded_metadata=b"", metadata_decoder=lambda metadata: metadata - ): - self.id = id_ - self._encoded_metadata = encoded_metadata - self._metadata_decoder = metadata_decoder + __slots__ = ["id", "metadata"] + id: int # noqa A003 + """ + The integer ID of this population. Varies from 0 to + :attr:`TreeSequence.num_populations` - 1. + """ + metadata: Any + """ + The :ref:`metadata ` for this population, decoded if a + schema applies. + """ -class Variant(SimpleContainer): +@dataclass +class Variant: """ A variant represents the observed variation among samples for a given site. A variant consists (a) of a reference to the @@ -559,72 +510,87 @@ class Variant(SimpleContainer): Modifying the attributes in this class will have **no effect** on the underlying tree sequence data. + """ - :ivar site: The site object for this variant. - :vartype site: :class:`Site` - :ivar alleles: A tuple of the allelic values that may be observed at the - samples at the current site. The first element of this tuple is always - the site's ancestral state. - :vartype alleles: tuple(str) - :ivar genotypes: An array of indexes into the list ``alleles``, giving the - state of each sample at the current site. - :ivar has_missing_data: True if there is missing data for any of the + __slots__ = ["site", "alleles", "genotypes"] + site: Site + """ + The site object for this variant. + """ + alleles: tuple + """ + A tuple of the allelic values that may be observed at the + samples at the current site. The first element of this tuple is always + the site's ancestral state. + """ + genotypes: np.ndarray + """ + An array of indexes into the list ``alleles``, giving the + state of each sample at the current site. + """ + + @property + def has_missing_data(self): + """ + True if there is missing data for any of the samples at the current site. - :vartype has_missing_data: bool - :ivar num_alleles: The number of distinct alleles at this site. Note that + """ + return self.alleles[-1] is None + + @property + def num_alleles(self): + """ + The number of distinct alleles at this site. Note that this may be greater than the number of distinct values in the genotypes array. - :vartype num_alleles: int - :vartype genotypes: numpy.ndarray - """ + """ + return len(self.alleles) - self.has_missing_data - def __init__(self, site, alleles, genotypes): - self.site = site - self.alleles = alleles - self.has_missing_data = alleles[-1] is None - self.num_alleles = len(alleles) - self.has_missing_data - self.genotypes = genotypes - # Deprecated aliases to avoid breaking existing code. - self.position = site.position - self.index = site.id + # Deprecated alias to avoid breaking existing code. + @property + def position(self): + return self.site.position + + # Deprecated alias to avoid breaking existing code. + @property + def index(self): + return self.site.id + # We need a custom eq for the numpy array def __eq__(self, other): return ( - self.site == other.site + isinstance(other, Variant) + and self.site == other.site and self.alleles == other.alleles and np.array_equal(self.genotypes, other.genotypes) ) -class Edgeset(SimpleContainer): - def __init__(self, left, right, parent, children): - self.left = left - self.right = right - self.parent = parent - self.children = children +@dataclass +class Edgeset: + __slots__ = ["left", "right", "parent", "children"] + left: int + right: int + parent: int + children: np.ndarray - def __repr__(self): - return "{{left={:.3f}, right={:.3f}, parent={}, children={}}}".format( - self.left, self.right, self.parent, self.children + # We need a custom eq for the numpy array + def __eq__(self, other): + return ( + isinstance(other, Edgeset) + and self.left == other.left + and self.right == other.right + and self.parent == other.parent + and np.array_equal(self.children, other.children) ) -class Provenance(SimpleContainer): - def __init__(self, id_=None, timestamp=None, record=None): - self.id = id_ - self.timestamp = timestamp - self.record = record - - -def add_deprecated_mutation_attrs(site, mutation): - """ - Add in attributes for the older deprecated way of defining - mutations. These attributes will be removed in future releases - and are deliberately undocumented in tskit v0.2.2. - """ - mutation.position = site.position - mutation.index = site.id - return mutation +@dataclass +class Provenance: + __slots__ = ["id", "timestamp", "record"] + id: int # noqa A003 + timestamp: str + record: str class Tree: @@ -1861,8 +1827,7 @@ def mutations(self): :rtype: iter(:class:`Mutation`) """ for site in self.sites(): - for mutation in site.mutations: - yield add_deprecated_mutation_attrs(site, mutation) + yield from site.mutations def get_leaves(self, u): # Deprecated alias for samples. See the discussion in the get_num_leaves @@ -2422,7 +2387,12 @@ def map_mutations(self, genotypes, alleles): # Translate back into string alleles ancestral_state = alleles[ancestral_state] mutations = [ - Mutation(node=node, derived_state=alleles[derived_state], parent=parent) + Mutation( + node=node, + derived_state=alleles[derived_state], + parent=parent, + metadata=self.tree_sequence.table_metadata_schemas.mutation.empty_value, + ) for node, parent, derived_state in transitions ] return ancestral_state, mutations @@ -3277,19 +3247,19 @@ class TreeSequence: the :meth:`.variants` method iterates over all sites and their genotypes. """ - @attr.s(slots=True, frozen=True, kw_only=True, auto_attribs=True) + @dataclass(frozen=True) class _TableMetadataSchemas: """ Convenience class for returning schemas """ - node: metadata_module.MetadataSchema - edge: metadata_module.MetadataSchema - site: metadata_module.MetadataSchema - mutation: metadata_module.MetadataSchema - migration: metadata_module.MetadataSchema - individual: metadata_module.MetadataSchema - population: metadata_module.MetadataSchema + node: metadata_module.MetadataSchema = None + edge: metadata_module.MetadataSchema = None + site: metadata_module.MetadataSchema = None + mutation: metadata_module.MetadataSchema = None + migration: metadata_module.MetadataSchema = None + individual: metadata_module.MetadataSchema = None + population: metadata_module.MetadataSchema = None def __init__(self, ll_tree_sequence): self._ll_tree_sequence = ll_tree_sequence @@ -3974,8 +3944,12 @@ def edge_diffs(self, include_terminal=False): iterator = _tskit.TreeDiffIterator(self._ll_tree_sequence, include_terminal) metadata_decoder = self.table_metadata_schemas.edge.decode_row for interval, edge_tuples_out, edge_tuples_in in iterator: - edges_out = [Edge(*(e + (metadata_decoder,))) for e in edge_tuples_out] - edges_in = [Edge(*(e + (metadata_decoder,))) for e in edge_tuples_in] + edges_out = [ + Edge(*e, metadata_decoder=metadata_decoder) for e in edge_tuples_out + ] + edges_in = [ + Edge(*e, metadata_decoder=metadata_decoder) for e in edge_tuples_in + ] yield Interval(*interval), edges_out, edges_in def sites(self): @@ -4009,8 +3983,7 @@ def mutations(self): :rtype: iter(:class:`Mutation`) """ for site in self.sites(): - for mutation in site.mutations: - yield add_deprecated_mutation_attrs(site, mutation) + yield from site.mutations def populations(self): """ @@ -4507,13 +4480,13 @@ def individual(self, id_): nodes, ) = self._ll_tree_sequence.get_individual(id_) return Individual( - id_=id_, + id=id_, flags=flags, location=location, parents=parents, - encoded_metadata=metadata, - metadata_decoder=self.table_metadata_schemas.individual.decode_row, + metadata=metadata, nodes=nodes, + metadata_decoder=self.table_metadata_schemas.individual.decode_row, ) def node(self, id_): @@ -4531,12 +4504,12 @@ def node(self, id_): metadata, ) = self._ll_tree_sequence.get_node(id_) return Node( - id_=id_, + id=id_, flags=flags, time=time, population=population, individual=individual, - encoded_metadata=metadata, + metadata=metadata, metadata_decoder=self.table_metadata_schemas.node.decode_row, ) @@ -4549,12 +4522,12 @@ def edge(self, id_): """ left, right, parent, child, metadata = self._ll_tree_sequence.get_edge(id_) return Edge( - id_=id_, + id=id_, left=left, right=right, parent=parent, child=child, - encoded_metadata=metadata, + metadata=metadata, metadata_decoder=self.table_metadata_schemas.edge.decode_row, ) @@ -4589,14 +4562,14 @@ def migration(self, id_): metadata, ) = self._ll_tree_sequence.get_migration(id_) return Migration( - id_=id_, + id=id_, left=left, right=right, node=node, source=source, dest=dest, time=time, - encoded_metadata=metadata, + metadata=metadata, metadata_decoder=self.table_metadata_schemas.migration.decode_row, ) @@ -4616,14 +4589,14 @@ def mutation(self, id_): time, ) = self._ll_tree_sequence.get_mutation(id_) return Mutation( - id_=id_, + id=id_, site=site, node=node, derived_state=derived_state, parent=parent, - encoded_metadata=metadata, - metadata_decoder=self.table_metadata_schemas.mutation.decode_row, + metadata=metadata, time=time, + metadata_decoder=self.table_metadata_schemas.mutation.decode_row, ) def site(self, id_): @@ -4637,11 +4610,11 @@ def site(self, id_): pos, ancestral_state, ll_mutations, _, metadata = ll_site mutations = [self.mutation(mut_id) for mut_id in ll_mutations] return Site( - id_=id_, + id=id_, position=pos, ancestral_state=ancestral_state, mutations=mutations, - encoded_metadata=metadata, + metadata=metadata, metadata_decoder=self.table_metadata_schemas.site.decode_row, ) @@ -4654,14 +4627,14 @@ def population(self, id_): """ (metadata,) = self._ll_tree_sequence.get_population(id_) return Population( - id_=id_, - encoded_metadata=metadata, + id=id_, + metadata=metadata, metadata_decoder=self.table_metadata_schemas.population.decode_row, ) def provenance(self, id_): timestamp, record = self._ll_tree_sequence.get_provenance(id_) - return Provenance(id_=id_, timestamp=timestamp, record=record) + return Provenance(id=id_, timestamp=timestamp, record=record) def get_samples(self, population_id=None): # Deprecated alias for samples() diff --git a/python/tskit/vcf.py b/python/tskit/vcf.py index 578a73e5b6..ee0bc5786e 100644 --- a/python/tskit/vcf.py +++ b/python/tskit/vcf.py @@ -203,8 +203,7 @@ def write(self, output): end="\t", file=output, ) - variant.genotypes += ord("0") - gt_array[indexes] = variant.genotypes + gt_array[indexes] = variant.genotypes + ord("0") g_bytes = memoryview(gt_array).tobytes() g_str = g_bytes.decode() print(g_str, end="", file=output)