diff --git a/docs/examples/magic_cut.rst b/docs/examples/magic_cut.rst index 8575c5d92..b70b7ae54 100644 --- a/docs/examples/magic_cut.rst +++ b/docs/examples/magic_cut.rst @@ -118,15 +118,15 @@ into existence. >>> C = (2,) >>> AB = (0, 1) -The cut applied to the subsystem severs the connections going to |C| from -either |A| or |B|. In this circumstance, knowing the state of |A| or |B| does -not tell us anything about the state of |C|; only the previous state of |C| can -tell us about the next state of |C|. ``C_node.tpm_on`` gives us the probability -of |C| being ON in the next state, while ``C_node.tpm_off`` would give us the +The cut applied to the subsystem severs the connections going to |C| from either +|A| or |B|. In this circumstance, knowing the state of |A| or |B| does not tell +us anything about the state of |C|; only the previous state of |C| can tell us +about the next state of |C|. ``C_node.tpm[..., 1]`` gives us the probability of +|C| being ON in the next state, while ``C_node.tpm[..., 0]`` would give us the probability of |C| being OFF. >>> C_node = cut_subsystem.indices2nodes(C)[0] - >>> C_node.tpm_on.flatten() + >>> C_node.tpm[..., 1].flatten() array([0.5 , 0.75]) This states that |C| has a 50% chance of being ON in the next state if it diff --git a/pyphi/__init__.py b/pyphi/__init__.py index fe13d10a8..4cd2935ea 100644 --- a/pyphi/__init__.py +++ b/pyphi/__init__.py @@ -75,7 +75,6 @@ from .direction import Direction from .network import Network from .subsystem import Subsystem -from .tpm import ExplicitTPM # Skip modules that require optional dependencies _skip_import = ["visualize", "graphs"] diff --git a/pyphi/cache/__init__.py b/pyphi/cache/__init__.py index 5a37ed503..97fede79c 100644 --- a/pyphi/cache/__init__.py +++ b/pyphi/cache/__init__.py @@ -103,10 +103,10 @@ class DictCache: Intended to be used as an object-level cache of method results. """ - def __init__(self): - self.cache = {} - self.hits = 0 - self.misses = 0 + def __init__(self, cache=None, hits=0, misses=0): + self.cache = dict() if cache is None else cache + self.hits = hits + self.misses = misses def clear(self): self.cache = {} @@ -148,6 +148,11 @@ def key(self, *args, _prefix=None, **kwargs): if kwargs: raise NotImplementedError("kwarg cache keys not implemented") return (_prefix,) + tuple(args) + + def __repr__(self): + return "{}(cache={}, hits={}, misses={})".format( + type(self).__name__, self.cache, self.hits, self.misses + ) def validate_parent_cache(parent_cache): diff --git a/pyphi/convert.py b/pyphi/convert.py old mode 100644 new mode 100755 diff --git a/pyphi/data_structures/array_like.py b/pyphi/data_structures/array_like.py index 12f75b71f..775f2bc93 100644 --- a/pyphi/data_structures/array_like.py +++ b/pyphi/data_structures/array_like.py @@ -12,10 +12,14 @@ class ArrayLike(NDArrayOperatorsMixin): # TODO(tpm) populate this list _TYPE_CLOSED_FUNCTIONS = ( + np.all, + np.any, np.concatenate, + np.expand_dims, np.stack, - np.all, np.sum, + np.result_type, + np.broadcast_to, ) # Holds the underlying array diff --git a/pyphi/distribution.py b/pyphi/distribution.py old mode 100644 new mode 100755 index 7768f26b6..e4547134f --- a/pyphi/distribution.py +++ b/pyphi/distribution.py @@ -4,6 +4,7 @@ import numpy as np from .cache import cache +from .utils import eq def normalize(a): diff --git a/pyphi/examples.py b/pyphi/examples.py old mode 100644 new mode 100755 index 8a1e09c99..42b84bd80 --- a/pyphi/examples.py +++ b/pyphi/examples.py @@ -1512,3 +1512,28 @@ def get_net( print(transition) account = actual.account(transition) print(account) + +@register_example +def functionally_equivalent(): + """The 2nd deterministic system from Figure 8 of the IIT 4.0 paper: + Functionally equivalent networks with different Φ-structures. + """ + node_labels = ("A", "B", "C") + # fmt: off + cm = np.array([ + [1, 1, 0,], + [0, 1, 1,], + [1, 1, 1,], + ]) + tpm = np.array([ + [1, 0, 0], + [0, 1, 0], + [1, 1, 1], + [0, 1, 1], + [0, 0, 0], + [1, 1, 0], + [0, 0, 1], + [1, 0, 1], + ]) + # fmt: on + return Network(tpm, cm=cm, node_labels=node_labels) diff --git a/pyphi/macro.py b/pyphi/macro.py index 4472f44d4..5a010809f 100644 --- a/pyphi/macro.py +++ b/pyphi/macro.py @@ -14,7 +14,10 @@ from .labels import NodeLabels from .network import irreducible_purviews from .node import expand_node_tpm, generate_nodes +from .state_space import build_state_space from .subsystem import Subsystem + +# TODO(tpm) use ImplicitTPM type consistently throughout module from .tpm import ExplicitTPM # Create a logger for this module. @@ -43,25 +46,6 @@ def rebuild_system_tpm(node_tpms): return ExplicitTPM(tpm, validate=True) -# TODO This should be a method of the TPM class in tpm.py -def remove_singleton_dimensions(tpm): - """Remove singleton dimensions from the TPM. - - Singleton dimensions are created by conditioning on a set of elements. - This removes those elements from the TPM, leaving a TPM that only - describes the non-conditioned elements. - - Note that indices used in the original TPM must be reindexed for the - smaller TPM. - """ - # Don't squeeze out the final dimension (which contains the probability) - # for networks with one element. - if tpm.ndim <= 2: - return tpm - - return tpm.squeeze()[..., tpm.tpm_indices()] - - def run_tpm(system, steps, blackbox): """Iterate the TPM for the given number of timesteps. @@ -73,7 +57,8 @@ def run_tpm(system, steps, blackbox): # boxes. node_tpms = [] for node in system.nodes: - node_tpm = node.tpm_on + # TODO: nonbinary nodes. + node_tpm = node.tpm[..., 1] for input_node in node.inputs: if not blackbox.in_same_box(node.index, input_node): if input_node in blackbox.output_indices: @@ -92,7 +77,11 @@ def run_tpm(system, steps, blackbox): return ExplicitTPM(convert.state_by_state2state_by_node(tpm), validate=True) -class SystemAttrs(namedtuple("SystemAttrs", ["tpm", "cm", "node_indices", "state"])): +class SystemAttrs( + namedtuple( + "SystemAttrs", ["tpm", "cm", "node_indices", "state"] + ) +): """An immutable container that holds all the attributes of a subsystem. Versions of this object are passed down the steps of the micro-to-macro @@ -106,15 +95,34 @@ def node_labels(self): labels = list("m{}".format(i) for i in self.node_indices) return NodeLabels(labels, self.node_indices) + @property + def state_space(self): + state_space, _ = build_state_space( + self.node_labels, + self.tpm.shape[:-1], + node_states=None, + ) + return state_space + @property def nodes(self): return generate_nodes( - self.tpm, self.cm, self.state, self.node_indices, self.node_labels + self.tpm, + self.cm, + self.state_space, + self.node_indices, + self.node_labels, + network_state=self.state, ) @staticmethod def pack(system): - return SystemAttrs(system.tpm, system.cm, system.node_indices, system.state) + return SystemAttrs( + system.tpm, + system.cm, + system.node_indices, + system.state, + ) def apply(self, system): system.tpm = self.tpm @@ -214,11 +222,12 @@ def _squeeze(system): Reindexes the subsystem so that the nodes are ``0..n`` where ``n`` is the number of internal indices in the system. """ - assert system.node_indices == system.tpm.tpm_indices() + assert system.node_indices == system.tpm.tpm_indices(reconstituted=True) - internal_indices = system.tpm.tpm_indices() + internal_indices = system.tpm.tpm_indices(reconstituted=True) + tpm = system.tpm.remove_singleton_dimensions() - tpm = remove_singleton_dimensions(system.tpm) + # TODO(tpm): deduplicate commonalities with tpm.ImplicitTPM.squeeze. # The connectivity matrix is the network's connectivity matrix, with # cut applied, with all connections to/from external nodes severed, @@ -229,10 +238,24 @@ def _squeeze(system): # Re-index the subsystem nodes with the external nodes removed node_indices = reindex(internal_indices) - nodes = generate_nodes(tpm, cm, state, node_indices) + node_labels = NodeLabels(None, node_indices) + state_space, _ = build_state_space( + node_labels, + tpm.shape[:-1], + ) + + nodes = generate_nodes( + tpm, + cm, + state_space, + node_indices, + node_labels, + network_state=state + ) # Re-calcuate the tpm based on the results of the cut - tpm = rebuild_system_tpm(node.tpm_on for node in nodes) + # TODO: nonbinary nodes. + tpm = rebuild_system_tpm(node.tpm[..., 1] for node in nodes) return SystemAttrs(tpm, cm, node_indices, state) @@ -242,7 +265,8 @@ def _blackbox_partial_noise(blackbox, system): # Noise inputs from non-output elements hidden in other boxes node_tpms = [] for node in system.nodes: - node_tpm = node.tpm_on + # TODO: nonbinary nodes. + node_tpm = node.tpm[..., 1] for input_node in node.inputs: if blackbox.hidden_from(input_node, node.index): node_tpm = node_tpm.marginalize_out([input_node]) @@ -264,7 +288,12 @@ def _blackbox_time(time_scale, blackbox, system): n = len(system.node_indices) cm = np.ones((n, n)) - return SystemAttrs(tpm, cm, system.node_indices, system.state) + return SystemAttrs( + tpm, + cm, + system.node_indices, + system.state, + ) def _blackbox_space(self, blackbox, system): """Blackbox the TPM and CM in space. @@ -282,7 +311,8 @@ def _blackbox_space(self, blackbox, system): assert blackbox.output_indices == tpm.tpm_indices() - tpm = remove_singleton_dimensions(tpm) + new_tpm = tpm.remove_singleton_dimensions() + n = len(blackbox) cm = np.zeros((n, n)) for i, j in itertools.product(range(n), repeat=2): @@ -294,8 +324,12 @@ def _blackbox_space(self, blackbox, system): state = blackbox.macro_state(system.state) node_indices = blackbox.macro_indices + state_space, _ = build_state_space( + NodeLabels(None, node_indices), + tpm.shape[:-1] + ) - return SystemAttrs(tpm, cm, node_indices, state) + return SystemAttrs(new_tpm, cm, node_indices, state) @staticmethod def _coarsegrain_space(coarse_grain, is_cut, system): diff --git a/pyphi/network.py b/pyphi/network.py old mode 100644 new mode 100755 index 2230fd044..8fbad61a5 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -5,11 +5,14 @@ |big_phi| computation. """ +from typing import Iterable import numpy as np from . import cache, connectivity, jsonify, utils, validate from .labels import NodeLabels -from .tpm import ExplicitTPM +from .node import generate_nodes, generate_node +from .tpm import ExplicitTPM, ImplicitTPM +from .state_space import build_state_space class Network: @@ -19,25 +22,7 @@ class Network: Args: tpm (np.ndarray): The transition probability matrix of the network. - - The TPM can be provided in any of three forms: **state-by-state**, - **state-by-node**, or **multidimensional state-by-node** form. - In the state-by-node forms, row indices must follow the - little-endian convention (see :ref:`little-endian-convention`). In - state-by-state form, column indices must also follow the - little-endian convention. - - If the TPM is given in state-by-node form, it can be either - 2-dimensional, so that ``tpm[i]`` gives the probabilities of each - node being ON if the previous state is encoded by |i| according to - the little-endian convention, or in multidimensional form, so that - ``tpm[(0, 0, 1)]`` gives the probabilities of each node being ON if - the previous state is |N_0 = 0, N_1 = 0, N_2 = 1|. - - The shape of the 2-dimensional form of a state-by-node TPM must be - ``(s, n)``, and the shape of the multidimensional form of the TPM - must be ``[2] * n + [n]``, where ``s`` is the number of states and - ``n`` is the number of nodes in the network. + See :func:`pyphi.tpm.ExplicitTPM`. Keyword Args: cm (np.ndarray): A square binary adjacency matrix indicating the @@ -47,29 +32,110 @@ class Network: is connected to every node (including itself)**. node_labels (tuple[str] or |NodeLabels|): Human-readable labels for each node in the network. - - Example: - In a 3-node network, ``the_network.tpm[(0, 0, 1)]`` gives the - transition probabilities for each node at |t| given that state at |t-1| - was |N_0 = 0, N_1 = 0, N_2 = 1|. + state_space (Optional[tuple[tuple[Union[int, str]]]]): + Labels for the state space of each node in the network. If ``None``, + states will be automatically labeled using a zero-based integer + index per node. """ - # TODO make tpm also optional when implementing logical network definition - def __init__(self, tpm, cm=None, node_labels=None, purview_cache=None): + def __init__( + self, + tpm, + cm=None, + node_labels=None, + state_space=None, + purview_cache=None + ): # Initialize _tpm according to argument type. - if isinstance(tpm, ExplicitTPM): + + if isinstance(tpm, (np.ndarray, ExplicitTPM)): + # Validate TPM and convert to state-by-node multidimensional format. + tpm = ExplicitTPM(tpm, validate=True) + + self._cm, self._cm_hash = self._build_cm(cm, tpm) + + self._node_indices = tuple(range(self.size)) + self._node_labels = NodeLabels(node_labels, self._node_indices) + + self._state_space, _ = build_state_space( + self._node_labels, + tpm.shape[:-1], + state_space + ) + + self._tpm = ImplicitTPM( + generate_nodes( + tpm, + self._cm, + self._state_space, + self._node_indices, + self._node_labels + ) + ) + + elif isinstance(tpm, Iterable): + invalid = [ + i for i in tpm if not isinstance(i, (np.ndarray, ExplicitTPM)) + ] + + if invalid: + raise TypeError("Invalid set of nodes containing {}.".format( + ', '.join(str(i) for i in invalid) + )) + + tpm = tuple( + ExplicitTPM(node_tpm, validate=False) for node_tpm in tpm + ) + + shapes = [node.shape for node in tpm] + + self._cm, self._cm_hash = self._build_cm(cm, tpm, shapes) + + self._node_indices = tuple(range(self.size)) + self._node_labels = NodeLabels(node_labels, self._node_indices) + + network_tpm_shape = ImplicitTPM._node_shapes_to_shape(shapes) + self._state_space, _ = build_state_space( + self._node_labels, + network_tpm_shape[:-1], + state_space + ) + + self._tpm = ImplicitTPM( + tuple( + generate_node( + node_tpm, + self._cm, + self._state_space, + index, + node_labels=self._node_labels + ) + for index, node_tpm in zip(self._node_indices, tpm) + ) + ) + + elif isinstance(tpm, ImplicitTPM): self._tpm = tpm - elif isinstance(tpm, np.ndarray): - self._tpm = ExplicitTPM(tpm, validate=True) + self._cm, self._cm_hash = self._build_cm(cm, self._tpm) + self._node_indices = tuple(range(self.size)) + self._node_labels = NodeLabels(node_labels, self._node_indices) + self._state_space, _ = build_state_space( + self._node_labels, + self._tpm.shape[:-1], + state_space + ) + + # FIXME(TPM) initialization from JSON elif isinstance(tpm, dict): # From JSON. - self._tpm = ExplicitTPM(tpm["_tpm"], validate=True) + self._tpm = ImplicitTPM(tpm["_tpm"]) + self._cm, self._cm_hash = self._build_cm(cm, tpm) + self._node_indices = tuple(range(self.size)) + self._node_labels = NodeLabels(node_labels, self._node_indices) + else: - raise TypeError(f"Invalid tpm of type {type(tpm)}.") + raise TypeError(f"Invalid TPM of type {type(tpm)}.") - self._cm, self._cm_hash = self._build_cm(cm) - self._node_indices = tuple(range(self.size)) - self._node_labels = NodeLabels(node_labels, self._node_indices) self.purview_cache = purview_cache or cache.PurviewCache() validate.network(self) @@ -91,18 +157,43 @@ def cm(self): """ return self._cm - def _build_cm(self, cm): + def _build_cm(self, cm, tpm, shapes=None): """Convert the passed CM to the proper format, or construct the - unitary CM if none was provided. + unitary CM if none was provided (explicit TPM), or infer from node TPMs. """ if cm is None: - # Assume all are connected. - cm = np.ones((self.size, self.size)) - else: - cm = np.array(cm) + if hasattr(tpm, "shape"): + network_size = tpm.shape[-1] + else: + network_size = len(tpm) + + # Explicit TPM without connectivity matrix: assume all are connected. + if shapes is None: + cm = np.ones((network_size, network_size), dtype=int) + utils.np_immutable(cm) + return (cm, utils.np_hash(cm)) + + # ImplicitTPM without connectivity matrix: infer from node TPMs. + cm = np.zeros((network_size, network_size), dtype=int) + + for i, shape in enumerate(shapes): + for j in range(len(shapes)): + if shape[j] != 1: + cm[j][i] = 1 + + utils.np_immutable(cm) + return (cm, utils.np_hash(cm)) + cm = np.array(cm) utils.np_immutable(cm) + # Explicit TPM with connectivity matrix: return. + if shapes is None: + return (cm, utils.np_hash(cm)) + + # ImplicitTPM with connectivity matrix: validate against node shapes. + validate.shapes(shapes, cm) + return (cm, utils.np_hash(cm)) @property @@ -120,11 +211,18 @@ def size(self): """int: The number of nodes in the network.""" return len(self) - # TODO extend to nonbinary nodes + @property + def state_space(self): + """tuple[tuple[Union[int, str]]]: Labels for the state space of each node. + """ + return self._state_space + @property def num_states(self): """int: The number of possible states of the network.""" - return 2**self.size + return np.prod( + [len(node_states) for node_states in self._state_space] + ) @property def node_indices(self): @@ -159,13 +257,14 @@ def potential_purviews(self, direction, mechanism): def __len__(self): """int: The number of nodes in the network.""" - return self.tpm.shape[-1] + return self._cm.shape[0] def __repr__(self): - return "Network({}, cm={})".format(self.tpm, self.cm) - - def __str__(self): - return self.__repr__() + # TODO implement a cleaner repr, similar to analyses objects, + # distinctions, etc. + return "Network(\n{},\ncm={},\nnode_labels={},\nstate_space={}\n)".format( + self.tpm, self.cm, self.node_labels, self.state_space._dict + ) def __eq__(self, other): """Return whether this network equals the other object. @@ -181,6 +280,7 @@ def __eq__(self, other): def __ne__(self, other): return not self.__eq__(other) + # TODO(tpm): Immutability in xarray. def __hash__(self): return hash((hash(self.tpm), self._cm_hash)) @@ -191,6 +291,7 @@ def to_json(self): "cm": self.cm, "size": self.size, "node_labels": self.node_labels, + "state_space": self.state_space, } @classmethod diff --git a/pyphi/node.py b/pyphi/node.py old mode 100644 new mode 100755 index 1274da2a4..8ae111d1b --- a/pyphi/node.py +++ b/pyphi/node.py @@ -3,140 +3,233 @@ import functools +from typing import Iterable, Mapping, Optional, Tuple, Union + import numpy as np +import xarray as xr + +# TODO rework circular dependency between node.py and tpm.py, instead +# of importing all of pyphi.tpm and relying on late binding of pyphi.tpm. +# to avoid the circular import error. +import pyphi.tpm -from . import utils from .connectivity import get_inputs_from_cm, get_outputs_from_cm -from .labels import NodeLabels -from .tpm import ExplicitTPM +from .state_space import ( + dimension_labels, + build_state_space, + SINGLETON_COORDINATE, +) +from .utils import state_of -# TODO extend to nonbinary nodes +@xr.register_dataarray_accessor("node") @functools.total_ordering class Node: """A node in a subsystem. Args: - cause_tpm (ExplicitTPM): The cause (backward) TPM of the subsystem. - effect_tpm (ExplicitTPM): The effect (forward) TPM of the subsystem. - cm (np.ndarray): The CM of the subsystem. - index (int): The node's index in the network. - state (int): The state of this node. - node_labels (|NodeLabels|): Labels for these nodes. + effect_dataarray (xr.DataArray): the xarray DataArray for the effect TPM. + + Keyword Args: + cause_dataarray (xr.DataArray): the xarray DataArray for the cause TPM. Attributes: - cause_tpm (ExplicitTPM), - effect_tpm (ExplicitTPM): The node TPM is an array with shape ``(2,)*(n + 1)``, - where ``n`` is the size of the |Network|. The first ``n`` - dimensions correspond to each node in the system. Dimensions - corresponding to nodes that provide input to this node are of size - 2, while those that do not correspond to inputs are of size 1, so - that the TPM has |2^m x 2| elements where |m| is the number of - inputs. The last dimension corresponds to the state of the node in - the next timestep, so that ``node.tpm[..., 0]`` gives probabilities - that the node will be 'OFF' and ``node.tpm[..., 1]`` gives - probabilities that the node will be 'ON'. + index (int): The node's index in the network. + label (str): The textual label for this node. + node_labels (Tuple[str]): The textual labels for the nodes in the network. + cause_dataarray (xr.DataArray): the xarray DataArray for the cause TPM. + effect_dataarray (xr.DataArray): the xarray DataArray for the effect TPM. + cause_tpm (|ExplicitTPM|), + effect_tpm (|ExplicitTPM|): The node TPM is an array with |n + 1| dimensions, + where ``n`` is the size of the |Network|. The first ``n`` dimensions + correspond to each node in the system. Dimensions corresponding to + nodes that provide input to this node are of size > 1, while those + that do not correspond to inputs are of size 1. The last dimension + encodes the state of the node in the next timestep, so that + ``node.tpm[..., 0]`` gives probabilities that the node will be 'OFF' + and ``node.tpm[..., 1]`` gives probabilities that the node will be + 'ON'. + inputs (frozenset): The set of nodes which send connections to this node. + outputs (frozenset): The set of nodes this node sends connections to. + state_space (Tuple[Union[int, str]]): The space of states this node can + inhabit. + state (Optional[Union[int, str]]): The current state of this node. """ - def __init__(self, cause_tpm, effect_tpm, cm, index, state, node_labels): - # This node's index in the list of nodes. - self.index = index - - # State of this node. - self.state = state + def __init__( + self, + effect_dataarray: xr.DataArray, + cause_dataarray: Optional[xr.DataArray] = None, + ): + self._index = effect_dataarray.attrs["index"] # Node labels used in the system - self.node_labels = node_labels - - # Get indices of the inputs. - self._inputs = frozenset(get_inputs_from_cm(self.index, cm)) - self._outputs = frozenset(get_outputs_from_cm(self.index, cm)) - - # Generate the node's TPMs. - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # We begin by getting the part of the subsystem's TPM that gives just - # the state of this node. This part is still indexed by network state, - # but its last dimension will be gone, since now there's just a single - # scalar value (this node's state) rather than a state-vector for all - # the network nodes. - cause_tpm_on = cause_tpm[..., self.index] - effect_tpm_on = effect_tpm[..., self.index] + self._node_labels = effect_dataarray.attrs["node_labels"] - # TODO extend to nonbinary nodes - # Marginalize out non-input nodes that are in the subsystem, since the - # external nodes have already been dealt with as boundary conditions in - # the subsystem's TPM. + self._inputs = effect_dataarray.attrs["inputs"] + self._outputs = effect_dataarray.attrs["outputs"] - # TODO use names rather than indices - cause_non_inputs = set(cause_tpm.tpm_indices()) - self._inputs - cause_tpm_on = cause_tpm_on.marginalize_out(cause_non_inputs).tpm + if cause_dataarray is None: + self._cause_dataarray = None + self._cause_tpm = None + else: + self._cause_dataarray = cause_dataarray + self._cause_tpm = cause_dataarray.data - effect_non_inputs = set(effect_tpm.tpm_indices()) - self._inputs - effect_tpm_on = effect_tpm_on.marginalize_out(effect_non_inputs).tpm + self._effect_dataarray = effect_dataarray + self._effect_tpm = self._effect_dataarray.data - # Get the TPM that gives the probability of the node being off, rather - # than on. - cause_tpm_off = 1 - cause_tpm_on - effect_tpm_off = 1 - effect_tpm_on + self.state_space = effect_dataarray.attrs["state_space"] - # Combine the on- and off-TPM so that the first dimension is indexed by - # the state of the node's inputs at t, and the last dimension is - # indexed by the node's state at t+1. This representation makes it easy - # to condition on the node state. - self.cause_tpm = ExplicitTPM( - np.stack([cause_tpm_off, cause_tpm_on], axis=-1), - ) - self.effect_tpm = ExplicitTPM( - np.stack([effect_tpm_off, effect_tpm_on], axis=-1), - ) - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # (Optional) current state of this node. + self.state = effect_dataarray.attrs["state"] # Only compute the hash once. self._hash = hash( ( - index, - hash(self.cause_tpm), - hash(self.effect_tpm), - self.state, + self.index, + hash(pyphi.tpm.ExplicitTPM(self.cause_tpm)), + hash(pyphi.tpm.ExplicitTPM(self.effect_tpm)), self._inputs, self._outputs, + self.state_space, + self.state ) ) @property - def cause_tpm_off(self): - """The cause (backward) TPM of this node containing only the 'OFF' probabilities.""" - return self.cause_tpm[..., 0] + def index(self): + """int: The node's index in the network.""" + return self._index + + @property + def label(self): + """str: The textual label for this node.""" + return self._node_labels[self.index] + + @property + def cause_dataarray(self): + """|xr.DataArray|: The cause xarray DataArray for this node.""" + return self._cause_dataarray @property - def effect_tpm_off(self): - """The effect (forward) TPM of this node containing only the 'OFF' probabilities.""" - return self.effect_tpm[..., 0] + def effect_dataarray(self): + """|xr.DataArray|: The effect xarray DataArray for this node.""" + return self._effect_dataarray @property - def cause_tpm_on(self): - """The cause (backward) TPM of this node containing only the 'ON' probabilities.""" - return self.cause_tpm[..., 1] + def cause_tpm(self): + """|ExplicitTPM|: The TPM of this node.""" + return self._cause_tpm @property - def effect_tpm_on(self): - """The effect (forward) TPM of this node containing only the 'ON' probabilities.""" - return self.effect_tpm[..., 1] + def effect_tpm(self): + """|ExplicitTPM|: The TPM of this node.""" + return self._effect_tpm @property def inputs(self): - """The set of nodes with connections to this node.""" + """frozenset: The set of nodes with connections to this node.""" return self._inputs @property def outputs(self): - """The set of nodes this node has connections to.""" + """frozenset: The set of nodes this node has connections to.""" return self._outputs @property - def label(self): - """The textual label for this node.""" - return self.node_labels[self.index] + def state_space(self): + """Tuple[Union[int, str]]: The space of states this node can inhabit.""" + return self._state_space + + @state_space.setter + def state_space(self, value): + _state_space = tuple(value) + + if len(set(_state_space)) < len(_state_space): + raise ValueError( + "Invalid node state space tuple. Repeated states are ambiguous." + ) + + if len(_state_space) < 2: + raise ValueError( + "Invalid node state space with less than 2 states." + ) + + self._state_space = _state_space + + @property + def state(self): + """Optional[Union[int, str]]: The current state of this node.""" + return self._state + + @state.setter + def state(self, value): + if value not in (*self.state_space, None): + raise ValueError( + f"Invalid node state. Possible states are {self.state_space}." + ) + + self._state = value + + def project_index(self, index, preserve_singletons=False): + """Convert absolute TPM index to a valid index relative to this node.""" + + # Supported index coordinates (in the right dimension order) + # respective to this node, to be used like an AND mask, with + # `singleton_coordinate` acting like 0. + dimensions = self._effect_dataarray.dims + coordinates = self._effect_dataarray.coords + + support = {dim: tuple(coordinates[dim].values) for dim in dimensions} + + if isinstance(index, dict): + singleton_coordinate = ( + [SINGLETON_COORDINATE] if preserve_singletons + else SINGLETON_COORDINATE + ) + + try: + # Convert potential int dimension indices to common currency of + # string dimension labels. + keys = [ + k if isinstance(k, str) else dimensions[k] + for k in index.keys() + ] + + projected_index = { + key: value if support[key] != (SINGLETON_COORDINATE,) + else singleton_coordinate + for key, value in zip(keys, index.values()) + } + + except KeyError as e: + raise ValueError( + "Dimension {} does not exist. Expected one or more of: " + "{}.".format(e, dimensions) + ) from e + + return projected_index + + # Assume regular index otherwise. + + if not isinstance(index, tuple): + # Index is a single int, slice, ellipsis, etc. Make it + # amenable to zip(). + index = (index,) + + index_support_map = zip(index, support.values()) + singleton_coordinate = [0] if preserve_singletons else 0 + projected_index = tuple( + i if support != (SINGLETON_COORDINATE,) + else singleton_coordinate + for i, support in index_support_map + ) + + return projected_index + + # def __getitem__(self, index): + # return self._dataarray[index].node def __repr__(self): return self.label @@ -147,20 +240,21 @@ def __str__(self): def __eq__(self, other): """Return whether this node equals the other object. - Two nodes are equal if they belong to the same subsystem and have the - same index (their TPMs must be the same in that case, so this method - doesn't need to check TPM equality). + Two nodes are equal if they have the same index, the same + inputs and outputs, the same TPMs, the same state_space and the + same state. Labels are for display only, so two equal nodes may have different labels. """ return ( - self.index == other.index - and self.cause_tpm.array_equal(other.cause_tpm) - and self.effect_tpm.array_equal(other.effect_tpm) - and self.state == other.state - and self.inputs == other.inputs - and self.outputs == other.outputs + self.index == other.index and + self.cause_tpm.array_equal(other.tpm) and + self.effect_tpm.array_equal(other.tpm) and + self.inputs == other.inputs and + self.outputs == other.outputs and + self.state_space == other.state_space and + self.state == other.state ) def __ne__(self, other): @@ -178,42 +272,181 @@ def to_json(self): return self.index -def generate_nodes(cause_tpm, effect_tpm, cm, network_state, indices, node_labels=None): - """Generate |Node| objects for a subsystem. +def generate_node( + effect_tpm: pyphi.tpm.ExplicitTPM, + cm: np.ndarray, + network_state_space: Mapping[str, Tuple[Union[int, str]]], + index: int, + node_labels: Iterable[str], + cause_tpm: Optional[pyphi.tpm.ExplicitTPM] = None, + state: Optional[Union[int, str]] = None, +) -> xr.DataArray: + """ + Instantiate a node TPM DataArray. Args: - cause_tpm (ExplicitTPM): The system's cause (backward) TPM - effect_tpm (ExplicitTPM): The system's effect (forward) TPM - cm (np.ndarray): The corresponding CM. - network_state (tuple): The state of the network. - indices (tuple[int]): Indices to generate nodes for. + effect_tpm (ExplicitTPM): The effect TPM of this node. + cm (np.ndarray): The CM of the network. + network_state_space (Mapping[str, Tuple[Union[int, str]]]): + Labels for the state space of each node in the network. + index (int): The node's index in the network. + node_labels (Iterable[str]): Textual labels for each node in the network. Keyword Args: - node_labels (|NodeLabels|): Textual labels for each node. + cause_tpm (ExplicitTPM): The cause TPM of this node. + state (Optional[Union[int, str]]): The state of this node. Returns: - tuple[Node]: The nodes of the system. + xr.DataArray: The node in question. """ - if node_labels is None: - node_labels = NodeLabels(None, indices) - - node_state = utils.state_of(indices, network_state) + # Get indices of the inputs and outputs. + inputs = frozenset(get_inputs_from_cm(index, cm)) + outputs = frozenset(get_outputs_from_cm(index, cm)) + + # Marginalize out non-input nodes. + effect_non_inputs = set(effect_tpm.tpm_indices()) - inputs + effect_tpm = effect_tpm.marginalize_out(effect_non_inputs) + + if cause_tpm is not None: + cause_non_inputs = set(cause_tpm.tpm_indices()) - inputs + cause_tpm = cause_tpm.marginalize_out(cause_non_inputs) + + # Dimensions are the names of this node's parents (whose state this node's + # TPM can be conditioned on), plus the last dimension with the probability + # for each possible state of this node in the next timestep. + dimensions = dimension_labels(node_labels) + + # Compute the relevant state labels (coordinates in xarray terminology) from + # the perspective of this node and its direct inputs. + node_states = [network_state_space[dim] for dim in dimensions[:-1]] + input_coordinates, _ = build_state_space( + node_labels, + effect_tpm.shape[:-1], + node_states, + singleton_state_space=(SINGLETON_COORDINATE,), + ) - return tuple( - Node(cause_tpm, effect_tpm, cm, index, state, node_labels) - for index, state in zip(indices, node_state) + node_state_space = network_state_space[dimensions[index]] + + coordinates = {**input_coordinates, dimensions[-1]: node_state_space} + + cause_dataarray = xr.DataArray( + name=node_labels[index], + data=cause_tpm, + dims=dimensions, + coords=coordinates, + attrs={ + "index": index, + "node_labels": node_labels, + "cm": cm, + "inputs": inputs, + "outputs": outputs, + "state_space": tuple(node_state_space), + "state": state, + "network_state_space": network_state_space + } + ) if cause_tpm is not None else None + + effect_dataarray = xr.DataArray( + name=node_labels[index], + data=effect_tpm, + dims=dimensions, + coords=coordinates, + attrs={ + "index": index, + "node_labels": node_labels, + "cm": cm, + "inputs": inputs, + "outputs": outputs, + "state_space": tuple(node_state_space), + "state": state, + "network_state_space": network_state_space + } ) + return Node(effect_dataarray, cause_dataarray) + + +def generate_nodes( + network_tpm, + cm: np.ndarray, + state_space: Mapping[str, Tuple[Union[int, str]]], + indices: Tuple[int], + node_labels: Tuple[str], + network_state: Optional[Tuple[Union[int, str]]] = None, +) -> Tuple[xr.DataArray]: + """Generate |Node| objects out of a binary network |TPM|. + + Args: + network_tpm (|ExplicitTPM, ImplicitTPM|): The system's TPM. + cm (np.ndarray): The CM of the network. + state_space (Mapping[str, Tuple[Union[int, str]]]): Labels + for the state space of each node in the network. + indices (Tuple[int]): Indices to generate nodes for. + node_labels (Optional[Tuple[str]]): Textual labels for each node. + + Keyword Args: + network_state (Optional[Tuple[Union[int, str]]]): The state of + the network. + + Returns: + Tuple[xr.DataArray]: The nodes of the system. + """ + if isinstance(network_tpm, pyphi.tpm.ImplicitTPM): + network_tpm = pyphi.tpm.reconstitute_tpm(network_tpm) + + if network_state is None: + network_state = (None,) * cm.shape[0] + + node_state = state_of(indices, network_state) + + nodes = [] + + for index, state in zip(indices, node_state): + # We begin by getting the part of the subsystem's TPM that gives just + # the state of this node. This part is still indexed by network state, + # but its last dimension will be gone, since now there's just a single + # scalar value (this node's state) rather than a state-vector for all + # the network nodes. + tpm_on = network_tpm[..., index] + + # Get the TPM that gives the probability of the node being off, rather + # than on. + tpm_off = 1 - tpm_on + + # Combine the on- and off-TPM so that the first dimension is indexed by + # the state of the node's inputs at t, and the last dimension is + # indexed by the node's state at t+1. This representation makes it easy + # to condition on the node state. + node_tpm = pyphi.tpm.ExplicitTPM( + np.stack([np.asarray(tpm_off), np.asarray(tpm_on)], axis=-1) + ) + + nodes.append( + generate_node( + node_tpm, + cm, + state_space, + index, + node_labels, + cause_tpm=None, + state=state, + ) + ) + + return tuple(nodes) + +# TODO: nonbinary nodes def expand_node_tpm(tpm): """Broadcast a node TPM over the full network. Args: - tpm (ExplicitTPM): The node TPM to expand. + tpm (|ExplicitTPM|): The node TPM to expand. This is different from broadcasting the TPM of a full system since the last dimension (containing the state of the node) contains only the probability of *this* node being on, rather than the probabilities for each node. """ - uc = ExplicitTPM(np.ones([2 for node in tpm.shape])) + uc = pyphi.tpm.ExplicitTPM(np.ones([2 for node in tpm.shape])) return uc * tpm diff --git a/pyphi/state_space.py b/pyphi/state_space.py new file mode 100644 index 000000000..88add190f --- /dev/null +++ b/pyphi/state_space.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# state_space.py + +""" +Constants and utility functions for dealing with the state space of a |Network|. +""" + +from typing import Iterable, List, Optional, Union, Tuple + +from .data_structures import FrozenMap + + +INPUT_DIMENSION_PREFIX = "" +PROBABILITY_DIMENSION = "Pr" +SINGLETON_COORDINATE = "_" + + +def input_dimension_label(node_label: str) -> str: + """Generate label for an input dimension in the |ImplicitTPM|. + + data_vars (xr.DataArray node names) and dimension names share the + same dictionary-like namespace in xr.Dataset. Prepend constant + string to avoid the conflict. + + Args: + node_label (str): Textual label for a node in the network. + + Returns: + str: Textual label for the same dimension in the multidimensional TPM. + """ + return INPUT_DIMENSION_PREFIX + str(node_label) + +def dimension_labels(node_labels: Iterable[str]) -> List[str]: + """Generate labels for each dimension in the |ImplicitTPM|. + + Args: + node_labels (Iterable[str]): Textual labels for each node in the network. + + Returns: + List[str]: Textual labels for each dimension in the multidimensional TPM. + """ + return ( + list(map(input_dimension_label, node_labels)) + + [PROBABILITY_DIMENSION] + ) + + +def build_state_space( + node_labels: Iterable[str], + nodes_shape: Iterable[int], + node_states: Optional[Iterable[Iterable[Union[int, str]]]] = None, + singleton_state_space: Optional[Iterable[Union[int, str]]] = None, +) -> Tuple[FrozenMap[str, Tuple[Union[int, str]]], int]: + """Format the passed state space labels or construct defaults if none. + + Args: + node_labels (Iterable[str]): Textual labels for each node in the network. + nodes_shape (Iterable[int]): The first |n| components in the shape of + a multidimensional |ExplicitTPM|, where |n| is the number of nodes + in the network. + + Keyword Args: + node_states (Optional[Iterable[Iterable[Union[int, str]]]]): The + network's state space labels as provided by the user. + singleton_state_space (Optional[Iterable[Union[int, str]]]): The label + to be used for singleton dimensions. If ``None``, singleton + dimensions will be discarded. + + Returns: + Tuple[FrozenMap[str, Tuple[Union[int, str]]], int]: State space for the network + of interest and its hash. + """ + if node_states is None: + node_states = [tuple(range(dim)) for dim in nodes_shape] + else: + node_states = [tuple(n) for n in node_states] + + # labels-to-states map. + state_space = zip(dimension_labels(node_labels), node_states) + + # Filter out states of singleton dimensions. + shape_state_map = zip(nodes_shape, state_space) + + if singleton_state_space is None: + state_space = { + node_states + for dim, node_states in shape_state_map + if dim > 1 + } + + else: + state_space = { + node_states if dim > 1 else (node_states[0], singleton_state_space) + for dim, node_states in shape_state_map + } + + state_space = FrozenMap(state_space) + state_space_hash = hash(state_space) + state_space = FrozenMap({k: list(v) for k,v in state_space.items()}) + + return (state_space, state_space_hash) diff --git a/pyphi/subsystem.py b/pyphi/subsystem.py old mode 100644 new mode 100755 index 08ecba06f..7788ad4c4 --- a/pyphi/subsystem.py +++ b/pyphi/subsystem.py @@ -27,11 +27,10 @@ _null_ria, CauseEffectStructure, ) +from .node import generate_node from .models.mechanism import ShortCircuitConditions, StateSpecification from .network import irreducible_purviews -from .node import generate_nodes from .partition import mip_partitions -from .tpm import backward_tpm as _backward_tpm from .utils import state_of log = logging.getLogger(__name__) @@ -79,7 +78,6 @@ def __init__( ): # The network this subsystem belongs to. validate.is_network(network) - network._tpm = network.tpm self.network = network self.node_labels = network.node_labels @@ -87,7 +85,7 @@ def __init__( # (for JSON serialization). self.node_indices = self.node_labels.coerce_to_indices(nodes) - validate.state_length(state, self.network.size) + validate.state(state, self.network.size, self.network.tpm.shape[:-1]) # The state of the network. self.state = tuple(state) @@ -104,12 +102,12 @@ def __init__( # Get the TPMs conditioned on the state of the external nodes. external_state = utils.state_of(self.external_indices, self.state) background_conditions = dict(zip(self.external_indices, external_state)) - self.cause_tpm = _backward_tpm(self.network.tpm, state, self.node_indices) + self.cause_tpm = self.network.tpm.backward_tpm(state, self.node_indices) self.effect_tpm = self.network.tpm.condition_tpm(background_conditions) # The TPMs for just the nodes in the subsystem. - self.proper_effect_tpm = self.effect_tpm.squeeze()[..., list(self.node_indices)] - self.proper_cause_tpm = self.cause_tpm.squeeze()[..., list(self.node_indices)] + self.proper_effect_tpm = self.effect_tpm.squeeze() + self.proper_cause_tpm = self.cause_tpm.squeeze() # The unidirectional cut applied for phi evaluation self.cut = ( @@ -133,13 +131,26 @@ def __init__( unconstrained_forward_repertoire_cache or cache.DictCache() ) - self.nodes = generate_nodes( - self.cause_tpm, - self.effect_tpm, - self.cm, - self.state, - self.node_indices, - self.node_labels, + # Set the state of the |Node|s. + nodes_zip = zip(self.effect_tpm.nodes, self.cause_tpm.nodes, self.state) + for effect_node, cause_node, node_state in nodes_zip: + effect_node.state = node_state + cause_node.state = node_state + + # Generate |Node|s for this subsystem and this particular cut to the cm. + nodes_enumerate = enumerate(zip(self.cause_tpm.nodes, self.effect_tpm.nodes)) + self.nodes = tuple( + generate_node( + node[Direction.EFFECT].effect_tpm, + self.cm, + self.network.state_space, + i, + self.node_labels, + cause_tpm=node[Direction.CAUSE].effect_tpm, + state=node[Direction.EFFECT].state, + ) + for i, node in nodes_enumerate + if i in self.node_indices ) validate.subsystem(self) @@ -215,6 +226,10 @@ def tpm_size(self): raise ValueError("cause and effect TPM sizes should be the same") return self.effect_tpm.shape[-1] + @property + def state_space(self): + return self.network.state_space + def cache_info(self): """Report repertoire cache statistics.""" return { @@ -1088,7 +1103,8 @@ def find_mice(self, direction, mechanism, purviews=None, **kwargs): Returns: MaximallyIrreducibleCauseOrEffect: The |MIC| or |MIE|. """ - purviews = self.potential_purviews(direction, mechanism, purviews) + if purviews is None: + purviews = self.potential_purviews(direction, mechanism, purviews) if direction == Direction.CAUSE: mice_class = MaximallyIrreducibleCause diff --git a/pyphi/tpm.py b/pyphi/tpm.py old mode 100644 new mode 100755 index 7f481cf86..6a4bb2569 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -2,160 +2,201 @@ """Provides classes for representing TPMs.""" import math +import functools from itertools import chain -from typing import Mapping, Set, Iterable +from typing import Iterable, Mapping, Optional, Set, Tuple import numpy as np -from . import convert, data_structures, exceptions +from . import convert, distribution, data_structures, exceptions +from .connectivity import subadjacency from .conf import config from .constants import OFF, ON from .data_structures import FrozenMap -from .utils import all_states, np_hash, np_immutable +import pyphi.node +from .utils import all_states, eq, np_hash, np_immutable +class TPM: + """TPM interface for derived classes.""" -# TODO(tpm) remove pending ArrayLike refactor -class ProxyMetaclass(type): - """A metaclass to create wrappers for the TPM array's special attributes. - - The CPython interpreter resolves double-underscore attributes (e.g., the - method definitions of mathematical operators) by looking up in the class' - static methods, not in the instance methods. This makes it impossible to - intercept calls to them when an instance's __getattr__() is implicitly - invoked, which in turn means there are only two options to wrap the special - methods of the array inside our custom objects (in order to perform - arithmetic operations with the TPM while also casting the result to our - custom class type): - - 1. Manually "overload" all the necessary methods. - 2. Use this metaclass to introspect the underlying array - and automatically overload methods in our custom TPM class definition. - """ + _ERROR_MSG_PROBABILITY_IMAGE = ( + "Invalid TPM: probabilities must be in the interval [0, 1]." + ) + + _ERROR_MSG_PROBABILITY_SUM = "Invalid TPM: probabilities must sum to 1." + + def validate(self, check_independence=True): + raise NotImplementedError + + def to_multidimensional_state_by_node(self): + raise NotImplementedError + + def conditionally_independent(self): + raise NotImplementedError + + def condition_tpm(self, condition): + raise NotImplementedError + + def marginalize_out(self, node_indices): + raise NotImplementedError - def __init__(cls, type_name, bases, dct): - # Casting semantics: values belonging to our custom TPM class should - # remain closed under the following methods: - __closures__ = frozenset( - { - # 1-ary - "__abs__", - "__copy__", - "__invert__", - "__neg__", - "__pos__", - # 2-ary - "__add__", - "__iadd__", - "__radd__", - "__sub__", - "__isub__", - "__rsub__", - "__mul__", - "__imul__", - "__rmul__", - "__matmul__", - "__imatmul__", - "__rmatmul__", - "__truediv__", - "__itruediv__", - "__rtruediv__", - "__floordiv__", - "__ifloordiv__", - "__rfloordiv__", - "__mod__", - "__imod__", - "__rmod__", - "__and__", - "__iand__", - "__rand__", - "__lshift__", - "__ilshift__", - "__irshift__", - "__rlshift__", - "__rrshift__", - "__rshift__", - "__ior__", - "__or__", - "__ror__", - "__xor__", - "__ixor__", - "__rxor__", - "__eq__", - "__ne__", - "__ge__", - "__gt__", - "__lt__", - "__le__", - "__deepcopy__", - # 3-ary - "__pow__", - "__ipow__", - "__rpow__", - # 2-ary, 2-valued - "__divmod__", - "__rdivmod__", - } + def is_deterministic(self): + raise NotImplementedError + + def is_state_by_state(self): + raise NotImplementedError + + def remove_singleton_dimensions(self): + raise NotImplementedError + + def expand_tpm(self): + raise NotImplementedError + + def subtpm(self, fixed_nodes, state): + """Return the TPM for a subset of nodes, conditioned on other nodes. + + Arguments: + fixed_nodes (tuple[int]): The nodes to select. + state (tuple[int]): The state of the fixed nodes. + + Returns: + ExplicitTPM: The TPM of just the subsystem of the free nodes. + + Examples: + >>> from pyphi import examples + >>> # Get the TPM for nodes only 1 and 2, conditioned on node 0 = OFF + >>> reconstitute_tpm(examples.grid3_network().tpm).subtpm((0,), (0,)) + ExplicitTPM( + [[[[0.02931223 0.04742587] + [0.07585818 0.88079708]] + + [[0.81757448 0.11920292] + [0.92414182 0.95257413]]]] + ) + """ + N = self.shape[-1] + free_nodes = sorted(set(range(N)) - set(fixed_nodes)) + condition = FrozenMap(zip(fixed_nodes, state)) + conditioned_tpm = self.condition_tpm(condition) + + if isinstance(self, ExplicitTPM): + return conditioned_tpm[..., free_nodes] + + return type(self)( + tuple( + node for node in conditioned_tpm.nodes + if node.index in free_nodes + ) ) - def make_proxy(name): - """Returns a function that acts as a proxy for the given method name. + def infer_edge(self, a, b, contexts): + """Infer the presence or absence of an edge from node A to node B. + + Let |S| be the set of all nodes in a network. Let |A' = S - {A}|. We + call the state of |A'| the context |C| of |A|. There is an edge from |A| + to |B| if there exists any context |C(A)| such that + |Pr(B | C(A), A=0) != Pr(B | C(A), A=1)|. - Args: - name (str): The name of the method to introspect in self._tpm. + Args: + a (int): The index of the putative source node. + b (int): The index of the putative sink node. + contexts (tuple[tuple[int]]): The tuple of states of ``a`` + Returns: + bool: ``True`` if the edge |A -> B| exists, ``False`` otherwise. + """ - Returns: - function: The wrapping function. + def a_in_context(context): + """Given a context C(A), return the states of the full system with A + OFF and ON, respectively. """ + a_off = context[:a] + OFF + context[a:] + a_on = context[:a] + ON + context[a:] + return (a_off, a_on) - def proxy(self): - return _new_attribute(name, __closures__, self._tpm) - - return proxy + def a_affects_b_in_context(tpm, context): + """Return ``True`` if A has an effect on B, given a context.""" + a_off, a_on = a_in_context(context) + return tpm[a_off][b] != tpm[a_on][b] - type.__init__(cls, type_name, bases, dct) + tpm = self.to_multidimensional_state_by_node() + return any(a_affects_b_in_context(tpm, context) for context in contexts) - if not cls.__wraps__: - return + def infer_cm(self): + """Infer the connectivity matrix associated with a state-by-node TPM in + multidimensional form. + """ + tpm = self.to_multidimensional_state_by_node() + network_size = tpm.shape[-1] + all_contexts = tuple(all_states(network_size - 1)) + cm = np.empty((network_size, network_size), dtype=int) + for a, b in np.ndindex(cm.shape): + cm[a][b] = self.infer_edge(a, b, all_contexts) + return cm - ignore = cls.__ignore__ + def tpm_indices(self, reconstituted=False): + """Return the indices of nodes in the TPM.""" + shape = self._reconstituted_shape if reconstituted else self.shape + return tuple(np.where(np.array(shape[:-1]) != 1)[0]) - # Go through all the attribute strings in the wrapped array type. - for name in dir(cls.__wraps__): - # Filter special attributes, rest will be handled by `__getattr__()` - if any([not name.startswith("__"), name in ignore, name in dct]): - continue + def print(self): + raise NotImplementedError - # Create function for `name` and bind to future instances of `cls`. - setattr(cls, name, property(make_proxy(name))) + def permute_nodes(self, permutation): + raise NotImplementedError + def backward_tpm(self, current_state, system_indices): + raise NotImplementedError -class Wrapper(metaclass=ProxyMetaclass): - """Proxy to the array inside PyPhi's custom TPM class.""" + def __str__(self): + raise NotImplementedError - __wraps__ = None + def __repr__(self): + raise NotImplementedError - __ignore__ = frozenset( - { - "__class__", - "__mro__", - "__new__", - "__init__", - "__setattr__", - "__getattr__", - "__getattribute__", - } - ) + def __hash__(self): + raise NotImplementedError - def __init__(self): - if self.__wraps__ is None: - raise TypeError("Base class Wrapper may not be instantiated.") - if not isinstance(self._tpm, self.__wraps__): - raise ValueError(f"Wrapped object must be of type {self.__wraps__}") +class ExplicitTPM(data_structures.ArrayLike, TPM): + """An explicit network TPM in multidimensional form. -class ExplicitTPM(data_structures.ArrayLike): - """An explicit network TPM in multidimensional form.""" + Args: + tpm (np.array): The transition probability matrix of the |Network|. + + The TPM can be provided in any of three forms: **state-by-state**, + **state-by-node**, or **multidimensional state-by-node** form. + In the state-by-node forms, row indices must follow the + little-endian convention (see :ref:`little-endian-convention`). In + state-by-state form, column indices must also follow the + little-endian convention. + + If the TPM is given in state-by-node form, it can be either + 2-dimensional, so that ``tpm[i]`` gives the probabilities of each + node being ON if the previous state is encoded by |i| according to + the little-endian convention, or in multidimensional form, so that + ``tpm[(0, 0, 1)]`` gives the probabilities of each node being ON if + the previous state is |N_0 = 0, N_1 = 0, N_2 = 1|. + + The shape of the 2-dimensional form of a state-by-node TPM must be + ``(s, n)``, and the shape of the multidimensional form of the TPM + must be ``[2] * n + [n]``, where ``s`` is the number of states and + ``n`` is the number of nodes in the network. + + Keyword Args: + validate (bool): Whether to check the shape and content of the input + array for correctness. + + Example: + In a 3-node network, ``tpm[(0, 0, 1)]`` gives the + transition probabilities for each node at |t| given that state at |t-1| + was |N_0 = 0, N_1 = 0, N_2 = 1|. + + Attributes: + _VALUE_ATTR (str): The key of the attribute holding the TPM array value. + __wraps__ (type): The class of the array referenced by ``_VALUE_ATTR``. + __closures__ (frozenset): np.ndarray method names proxied by this class. + """ _VALUE_ATTR = "_tpm" @@ -165,9 +206,6 @@ class ExplicitTPM(data_structures.ArrayLike): # TODO(tpm) remove pending ArrayLike refactor # Casting semantics: values belonging to our custom TPM class should # remain closed under the following methods: - - # TODO attributes data, real and imag return arrays that should also be - # cast, even though they are not callable. __closures__ = frozenset( { "argpartition", @@ -225,10 +263,12 @@ def __len__(self): def __init__(self, tpm, validate=False): self._tpm = np.array(tpm) - super().__init__() if validate: - self.validate(check_independence=config.VALIDATE_CONDITIONAL_INDEPENDENCE) + self.validate( + check_independence=config.VALIDATE_CONDITIONAL_INDEPENDENCE, + network_tpm=True + ) self._tpm = self.to_multidimensional_state_by_node() self._tpm = np_immutable(self._tpm) @@ -236,27 +276,56 @@ def __init__(self, tpm, validate=False): @property def tpm(self): - """Return the underlying `tpm` object.""" + """np.ndarray: The underlying `tpm` object.""" return self._tpm - def validate(self, check_independence=True): + @property + def number_of_units(self): + if self.is_state_by_state(): + # Assumes binary nodes + return int(math.log2(self._tpm.shape[1])) + return self._tpm.shape[-1] + + def validate(self, check_independence=True, network_tpm=False): """Validate this TPM.""" - return self._validate_probabilities() and self._validate_shape( + return self._validate_probabilities(network_tpm) and self._validate_shape( check_independence ) - def _validate_probabilities(self): + def _validate_probabilities(self, network_tpm=False): """Check that the probabilities in a TPM are valid.""" + # Validate TPM image is within [0, 1] (first axiom of probability). if (self._tpm < 0.0).any() or (self._tpm > 1.0).any(): - raise ValueError( - "Invalid TPM: probabilities must be in the interval [0, 1]." - ) - if self.is_state_by_state() and not np.all( - np.isclose(np.sum(self._tpm, axis=1), 1.0, atol=1e-15) - ): - raise ValueError("Invalid TPM: probabilities must sum to 1.") + raise ValueError(self._ERROR_MSG_PROBABILITY_IMAGE) + + # Validate that probabilities sum to 1. + if not self.is_unitary(network_tpm): + raise ValueError(self._ERROR_MSG_PROBABILITY_SUM) + return True + def is_unitary(self, network_tpm=False): + """Whether the TPM satisfies the second axiom of probability theory. + + A TPM is unitary if and only if for every current state of the system, + the probability distribution over next states conditioned on the current + state sums to 1 (up to |config.PRECISION|). + + Keyword Args: + network_tpm (bool): Whether ``self`` is an old-style system TPM + instead of a node TPM. + + Returns: + bool: + """ + tpm = self + if network_tpm and not tpm.is_state_by_state(): + tpm = convert.state_by_node2state_by_state(self) + + # Marginalize last dimension, then check that all integrals are close to 1. + measures_over_current_states = tpm.sum(axis=-1).ravel() + return all(eq(p, 1.0) for p in measures_over_current_states) + def _validate_shape(self, check_independence=True): """Validate this TPM's shape. @@ -302,13 +371,6 @@ def _validate_shape(self, check_independence=True): ) return True - @property - def number_of_units(self): - if self.is_state_by_state(): - # Assumes binary nodes - return int(math.log2(self._tpm.shape[1])) - return self._tpm.shape[-1] - def to_multidimensional_state_by_node(self): """Return the current TPM re-represented in multidimensional state-by-node form. @@ -374,12 +436,7 @@ def condition_tpm(self, condition: Mapping[int, int]): conditioning_indices = tuple(chain.from_iterable(conditioning_indices)) # Obtain the actual conditioned TPM by indexing with the conditioning # indices. - tpm = self._tpm[conditioning_indices] - # Create new TPM object of the same type as self. - # self.tpm has already been validated and converted to multidimensional - # state-by-node form. Further validation would be problematic for - # singleton dimensions. - return type(self)(tpm) + return self[conditioning_indices] def marginalize_out(self, node_indices): """Marginalize out nodes from this TPM. @@ -391,13 +448,12 @@ def marginalize_out(self, node_indices): ExplicitTPM: A TPM with the same number of dimensions, with the nodes marginalized out. """ - tpm = self._tpm.sum(tuple(node_indices), keepdims=True) / ( + tpm = self.sum(tuple(node_indices), keepdims=True) / ( np.array(self.shape)[list(node_indices)].prod() ) - # Return new TPM object of the same type as self. - # self._tpm has already been validated and converted to multidimensional - # state-by-node form. Further validation would be problematic for - # singleton dimensions. + # Return new TPM object of the same type as self. Assume self had + # already been validated and converted formatted. Further validation + # would be problematic for singleton dimensions. return type(self)(tpm) def is_deterministic(self): @@ -410,31 +466,22 @@ def is_state_by_state(self): """ return self.ndim == 2 and self.shape[0] == self.shape[1] - def subtpm(self, fixed_nodes, state): - """Return the TPM for a subset of nodes, conditioned on other nodes. - - Arguments: - fixed_nodes (tuple[int]): The nodes to select. - state (tuple[int]): The state of the fixed nodes. + def remove_singleton_dimensions(self): + """Remove singleton dimensions from the TPM. - Returns: - ExplicitTPM: The TPM of just the subsystem of the free nodes. + Singleton dimensions are created by conditioning on a set of elements. + This removes those elements from the TPM, leaving a TPM that only + describes the non-conditioned elements. - Examples: - >>> from pyphi import examples - >>> # Get the TPM for nodes only 1 and 2, conditioned on node 0 = OFF - >>> examples.grid3_network().tpm.subtpm((0,), (0,)) - ExplicitTPM([[[[0.02931223 0.04742587] - [0.07585818 0.88079708]] - - [[0.81757448 0.11920292] - [0.92414182 0.95257413]]]]) + Note that indices used in the original TPM must be reindexed for the + smaller TPM. """ - free_nodes = sorted(set(range(self.number_of_units)) - set(fixed_nodes)) - condition = FrozenMap(zip(fixed_nodes, state)) - conditioned = self.condition_tpm(condition) - # TODO test indicing behavior on xr.DataArray - return conditioned[..., free_nodes] + # Don't squeeze out the final dimension (which contains the probability) + # for networks with one element. + if self.ndim <= 2: + return self + + return self.squeeze()[..., self.tpm_indices()] def expand_tpm(self): """Broadcast a state-by-node TPM so that singleton dimensions are expanded @@ -443,55 +490,6 @@ def expand_tpm(self): unconstrained = np.ones([2] * (self._tpm.ndim - 1) + [self._tpm.shape[-1]]) return type(self)(self._tpm * unconstrained) - def infer_edge(self, a, b, contexts): - """Infer the presence or absence of an edge from node A to node B. - - Let |S| be the set of all nodes in a network. Let |A' = S - {A}|. We - call the state of |A'| the context |C| of |A|. There is an edge from |A| - to |B| if there exists any context |C(A)| such that - |Pr(B | C(A), A=0) != Pr(B | C(A), A=1)|. - - Args: - a (int): The index of the putative source node. - b (int): The index of the putative sink node. - contexts (tuple[tuple[int]]): The tuple of states of ``a`` - Returns: - bool: ``True`` if the edge |A -> B| exists, ``False`` otherwise. - """ - - def a_in_context(context): - """Given a context C(A), return the states of the full system with A - OFF and ON, respectively. - """ - a_off = context[:a] + OFF + context[a:] - a_on = context[:a] + ON + context[a:] - return (a_off, a_on) - - def a_affects_b_in_context(tpm, context): - """Return ``True`` if A has an effect on B, given a context.""" - a_off, a_on = a_in_context(context) - return tpm[a_off][b] != tpm[a_on][b] - - tpm = self.to_multidimensional_state_by_node() - return any(a_affects_b_in_context(tpm, context) for context in contexts) - - def infer_cm(self): - """Infer the connectivity matrix associated with a state-by-node TPM in - multidimensional form. - """ - tpm = self.to_multidimensional_state_by_node() - network_size = tpm.shape[-1] - all_contexts = tuple(all_states(network_size - 1)) - cm = np.empty((network_size, network_size), dtype=int) - for a, b in np.ndindex(cm.shape): - cm[a][b] = self.infer_edge(a, b, all_contexts) - return cm - - def tpm_indices(self): - """Return the indices of nodes in the TPM.""" - # TODO This currently assumes binary elements (2) - return tuple(np.where(np.array(self.shape[:-1]) == 2)[0]) - def print(self): tpm = convert.to_multidimensional(self._tpm) for state in all_states(tpm.shape[-1]): @@ -509,11 +507,58 @@ def permute_nodes(self, permutation): self._tpm.transpose(dimension_permutation)[..., list(permutation)], ) - def __getitem__(self, i): - item = self._tpm[i] - if isinstance(item, type(self._tpm)): - item = type(self)(item) - return item + def probability_of_current_state(self, current_state): + """Return the probability of the current state as a distribution over previous states. + + Arguments: + current_state (tuple[int]): The current state. + """ + state_probabilities = np.empty(self.shape) + if not len(current_state) == self.shape[-1]: + raise ValueError( + f"current_state must have length {self.shape[-1]} " + f"for state-by-node TPM of shape {self.shape}" + ) + for i in range(self.shape[-1]): + # TODO extend to nonbinary nodes + state_probabilities[..., i] = ( + self[..., i] if current_state[i] else (1 - self[..., i]) + ) + return state_probabilities.prod(axis=-1, keepdims=True) + + def backward_tpm( + self, + current_state: tuple[int], + system_indices: Iterable[int], + remove_background: bool = False, + ): + """Compute the backward TPM for a given network state.""" + all_indices = tuple(range(self.number_of_units)) + system_indices = tuple(sorted(system_indices)) + background_indices = tuple(sorted(set(all_indices) - set(system_indices))) + if not set(system_indices).issubset(set(all_indices)): + raise ValueError( + "system_indices must be a subset of `range(self.number_of_units))`" + ) + + # p(u_t | s_{t–1}, w_{t–1}) + pr_current_state = self.probability_of_current_state(current_state) + # Σ_{s_{t–1}} p(u_t | s_{t–1}, w_{t–1}) + pr_current_state_given_only_background = pr_current_state.sum( + axis=tuple(system_indices), keepdims=True + ) + # Σ_{u'_{t–1}} p(u_t | u'_{t–1}) + normalization = np.sum(pr_current_state) + # Σ_{s_{t–1}} p(u_t | s_{t–1}, w_{t–1}) + # Σ_{w_{t–1}} p(s_{i,t} | s_{t–1}, w_{t–1}) ——————————————————————————————————————— + # Σ_{u'_{t–1}} p(u_t | u'_{t–1}) + backward_tpm = ( + self * pr_current_state_given_only_background / normalization + ).sum(axis=background_indices, keepdims=True) + if remove_background: + # Remove background units from last dimension of the state-by-node TPM + backward_tpm = backward_tpm[..., list(system_indices)] + return ExplicitTPM(backward_tpm) def array_equal(self, o: object): """Return whether this TPM equals the other object. @@ -523,41 +568,472 @@ def array_equal(self, o: object): """ return isinstance(o, type(self)) and np.array_equal(self._tpm, o._tpm) + def __getitem__(self, i): + item = self._tpm[i] + if isinstance(item, type(self._tpm)): + item = type(self)(item) + return item + def __str__(self): return self.__repr__() def __repr__(self): - return "ExplicitTPM({})".format(self._tpm) + return "ExplicitTPM(\n{}\n)".format(self._tpm) def __hash__(self): return self._hash +class ImplicitTPM(TPM): + + """An implicit network TPM containing |Node| TPMs in multidimensional form. + + Args: + dataset (xr.Dataset): + + Attributes: + """ + + def __init__(self, nodes): + """Args: + nodes (pyphi.node.Node) + """ + self._nodes = tuple(nodes) + + @property + def nodes(self): + """Tuple[xr.DataArray]: The node TPMs in this ImplicitTPM""" + return self._nodes + + @property + def tpm(self): + """Tuple[np.ndarray]: Verbose representation of all node TPMs.""" + return tuple(node.effect_tpm for node in self._nodes) + + @property + def number_of_units(self): + return len(self.nodes) + + @property + def ndim(self): + """int: The number of dimensions of the TPM.""" + return len(self.shape) + + @property + def shape(self): + """Tuple[int]: The size or number of coordinates in each dimension.""" + shapes = self.shapes + return self._node_shapes_to_shape(shapes) + + @property + def _reconstituted_shape(self): + shapes = self.shapes + return self._node_shapes_to_shape(shapes, reconstituted=True) + + @property + def shapes(self): + """Tuple[Tuple[int]]: The shapes of each node TPM in this TPM.""" + return [node.effect_tpm.shape for node in self._nodes] + + @staticmethod + def _node_shapes_to_shape( + shapes: Iterable[Iterable[int]], + reconstituted: Optional[bool] = None + ) -> Tuple[int]: + """Infer the shape of the equivalent multidimensional |ExplicitTPM|. + + Args: + shapes (Iterable[Iterable[int]]): The shapes of the individual node + TPMs in the network, ordered by node index. + + Returns: + Tuple[int]: The inferred shape of the equivalent TPM. + """ + # This should recompute the network TPM shape from individual node + # shapes, as opposed to measuring the size of the state space. + + if not all(len(shape) == len(shapes[0]) for shape in shapes): + raise ValueError( + "The provided shapes contain varying number of dimensions." + ) + + N = len(shapes) + if reconstituted: + states_per_node = tuple(max(dim) for dim in zip(*shapes))[:-1] + else: + states_per_node = tuple(shape[-1] for shape in shapes) + + # Check consistency of shapes across nodes. + + dimensions_from_shapes = tuple( + set(shape[node_index] for shape in shapes) + for node_index in range(N) + ) + + for node_index in range(N): + # Valid state cardinalities along a dimension can be either: + # {1, s_i}, s_i != 1 iff node provides input to only some nodes, + # {s_i}, s_i != 1 iff node provides input to all nodes. + valid_cardinalities = ( + {max(dimensions_from_shapes[node_index]), 1}, + {max(dimensions_from_shapes[node_index])} + ) + if not any( + dimensions_from_shapes[node_index] == cardinality + for cardinality in valid_cardinalities + ): + raise ValueError( + "The provided shapes disagree on the number of states of " + "node {}.".format(node_index) + ) + + return states_per_node + (N,) + + def validate(self, check_independence=True): + """Validate this TPM.""" + return self._validate_probabilities() and self._validate_shape() + + def _validate_probabilities(self): + """Check that the probabilities in a TPM are valid.""" + # An implicit TPM contains valid probabilities if and only if + # individual node TPMs contain valid probabilities, for every node. + if all( + node.effect_tpm._validate_probabilities() + for node in self._nodes + ): + return True + + def is_unitary(self): + """Whether the TPM satisfies the second axiom of probability theory. + + A TPM is unitary if and only if for every current state of the system, + the probability distribution over next states conditioned on the current + state sums to 1 (up to |config.PRECISION|). + """ + return all(node.effect_tpm.is_unitary() for node in self._nodes) + + def _validate_shape(self): + """Validate this TPM's shape. + + The inferred shape of the implicit network TPM must be in + multidimensional state-by-node form, nonbinary and heterogeneous units + supported. + """ + N = len(self.nodes) + if N + 1 != self.ndim: + raise ValueError( + "Invalid TPM shape: {} nodes were provided, but their shapes" + "suggest a {}-node network.".format(N, self.ndim - 1) + ) + + return True + + def to_multidimensional_state_by_node(self): + """Return the current TPM re-represented in multidimensional + state-by-node form. + + See the PyPhi documentation on :ref:`tpm-conventions` for more + information. + + Returns: + np.ndarray: The TPM in multidimensional state-by-node format. + """ + return reconstitute_tpm(self) + + # TODO(tpm) accept node labels and state labels in the map. + def condition_tpm(self, condition: Mapping[int, int]): + """Return a TPM conditioned on the given fixed node indices, whose + states are fixed according to the given state-tuple. + + The dimensions of the new TPM that correspond to the fixed nodes are + collapsed onto their state, making those dimensions singletons suitable + for broadcasting. The number of dimensions of the conditioned TPM will + be the same as the unconditioned TPM. + + Args: + condition (dict[int, int]): A mapping from node indices to the state + to condition on for that node. + + Returns: + TPM: A conditioned TPM with the same number of dimensions, with + singleton dimensions for nodes in a fixed state. + """ + # Wrapping index elements in a list is the xarray equivalent + # of inserting a numpy.newaxis, which preserves the singleton even + # after selection of a single state. + conditioning_indices = { + i: (state_i if isinstance(state_i, list) else [state_i]) + for i, state_i in condition.items() + } + + return self.__getitem__(conditioning_indices, preserve_singletons=True) + + def marginalize_out(self, node_indices): + """Marginalize out nodes from this TPM. + + Args: + node_indices (list[int]): The indices of nodes to be marginalized out. + + Returns: + ImplicitTPM: A TPM with the same number of dimensions, with the nodes + marginalized out. + """ + # Leverage ExplicitTPM.marginalize_out() to distribute operation to + # individual nodes, then assemble into a new ImplicitTPM. + return type(self)( + tuple( + pyphi.node.generate_node( + node.effect_tpm.marginalize_out(node_indices), + node.effect_dataarray.attrs["cm"], + node.effect_dataarray.attrs["network_state_space"], + node.index, + node.effect_dataarray.attrs["node_labels"], + ) + for node in self.nodes + ) + ) + + def is_state_by_state(self): + """Return ``True`` if ``tpm`` is in state-by-state form, otherwise + ``False``. + """ + return False + + def remove_singleton_dimensions(self): + """Remove singleton dimensions from the TPM. + + Singleton dimensions are created by conditioning on a set of elements. + This removes those elements from the TPM, leaving a TPM that only + describes the non-conditioned elements. + + Note that indices used in the original TPM must be reindexed for the + smaller TPM. + """ + # Don't squeeze out the final dimension (which contains the probability) + # for networks with one element. + if self.ndim <= 2: + return self + + # Find the set of singleton dimensions for this TPM. + shape = self._reconstituted_shape + singletons = set(np.where(np.array(shape) == 1)[0]) + + # Squeeze out singleton dimensions and return a new TPM with + # the surviving nodes. + return type(self)( + tuple(node for node in self.squeeze().nodes) + ) + + def probability_of_current_state( + self, + current_state: tuple[int] + ) -> tuple[ExplicitTPM]: + """Return probability of current state as distribution over previous states. + + Output format is similar to an |ImplicitTPM|, however the last dimension + only contains the probability for the current state. + + Arguments: + current_state (tuple[int]): The current state. + Returns: + tuple[ExplicitTPM]: Node-marginal distributions of the current state. + """ + if not len(current_state) == self.number_of_units: + raise ValueError( + f"current_state must have length {self.number_of_units} " + f"for state-by-node TPM of shape {self.shape}" + ) + nodes = [] + for node in self.nodes: + i = node.index + state = current_state[i] + # DataArray indexing: keep last dimension by wrapping index in list. + pr_current_state = node.effect_dataarray[..., [state]].data + normalization = np.sum(pr_current_state) + nodes.append(pr_current_state / normalization) + return tuple(nodes) + + def backward_tpm( + self, + current_state: tuple[int], + system_indices: Iterable[int], + ): + """Compute the backward TPM for a given network state.""" + all_indices = tuple(range(self.number_of_units)) + system_indices = tuple(sorted(system_indices)) + background_indices = tuple(sorted(set(all_indices) - set(system_indices))) + if not set(system_indices).issubset(set(all_indices)): + raise ValueError( + "system_indices must be a subset of `range(self.number_of_units))`" + ) + # p(u_t | s_{t–1}, w_{t–1}) + pr_current_state_nodes = self.probability_of_current_state(current_state) + # TODO Avoid computing the full joint probability. Find uninformative + # dimensions after each product and propagate their dismissal. + pr_current_state = functools.reduce(np.multiply, pr_current_state_nodes) + # Σ_{s_{t–1}} p(u_t | s_{t–1}, w_{t–1}) + pr_current_state_given_only_background = pr_current_state.sum( + axis=tuple(system_indices), keepdims=True + ) + # Σ_{s_{t–1}} p(u_t | s_{t–1}, w_{t–1}) + # ————————————————————————————————————— + # Σ_{u'_{t–1}} p(u_t | u'_{t–1}) + pr_current_state_given_only_background_normalized = ( + pr_current_state_given_only_background / np.sum(pr_current_state) + ) + # Σ_{s_{t–1}} p(u_t | s_{t–1}, w_{t–1}) + # Σ_{w_{t–1}} p(s_{i,t} | s_{t–1}, w_{t–1}) ————————————————————————————————————— + # Σ_{u'_{t–1}} p(u_t | u'_{t–1}) + backward_tpm = tuple( + (node_tpm * pr_current_state_given_only_background_normalized).sum( + axis=background_indices, keepdims=True + ) + for node_tpm in self.tpm + ) + reference_node = self.nodes[0].effect_dataarray + return ImplicitTPM( + tuple( + pyphi.node.generate_node( + backward_node_tpm, + reference_node.attrs["cm"], + reference_node.attrs["network_state_space"], + i, + reference_node.attrs["node_labels"], + ) + for i, backward_node_tpm in enumerate(backward_tpm) + ) + ) + + def equals(self, o: object): + """Return whether this TPM equals the other object. + + Two TPMs are equal if they are instances of the same class + and their tuple of node TPMs are equal. + """ + return isinstance(o, type(self)) and self.nodes == o.nodes + + def array_equal(self, o: object): + return self.equals(o) + + def squeeze(self, axis=None): + """Wrapper around numpy.squeeze.""" + # If axis is None, all axis should be considered. + if axis is None: + axis = set(range(len(self))) + else: + axis = set(axis) if isinstance(axis, Iterable) else set([axis]) + + # Subtract non-singleton dimensions from `axis`, including fake + # singletons (dimensions that are singletons only for a proper subset of + # the nodes), since those should not be squeezed, not even within + # individual node TPMs. + shape = self._reconstituted_shape + nonsingletons = tuple(np.where(np.array(shape) > 1)[0]) + axis = tuple(axis - set(nonsingletons)) + + # From now on, we will only care about the first n-1 dimensions (parents). + if shape[-1] > 1: + nonsingletons = nonsingletons[:-1] + + # Recompute connectivity matrix and subset of node labels. + # TODO(tpm) deduplicate commonalities with macro.MacroSubsystem._squeeze. + some_node = self.nodes[0] + + new_cm = subadjacency(some_node.effect_dataarray.attrs["cm"], nonsingletons) + + new_node_indices = iter(range(len(nonsingletons))) + new_node_labels = tuple(some_node._node_labels[n] for n in nonsingletons) + + state_space = some_node.effect_dataarray.attrs["network_state_space"] + new_state_space = {n: state_space[n] for n in new_node_labels} + + # Leverage ExplicitTPM.squeeze to distribute squeezing to every node. + return type(self)( + tuple( + pyphi.node.generate_node( + node.effect_tpm.squeeze(axis=axis), + new_cm, + new_state_space, + next(new_node_indices), + new_node_labels, + ) + for node in self.nodes if node.index in nonsingletons + ) + ) + + def __getitem__(self, index, **kwargs): + if isinstance(index, (int, slice, type(...), tuple)): + return type(self)( + tuple( + # The nodes in an ImplicitTPM only have "effect" + # node TPMs, even if ImplicitTPM is a cause TPM. + node.effect_dataarray[node.project_index(index, **kwargs)].node + for node in self.nodes + ) + ) + if isinstance(index, dict): + return type(self)( + tuple( + # The nodes in an ImplicitTPM only have "effect" + # node TPMs, even if ImplicitTPM is a cause TPM. + node.effect_dataarray.loc[node.project_index(index, **kwargs)].node + for node in self.nodes + ) + ) + raise TypeError(f"Invalid index {index} of type {type(index)}.") + + def __len__(self): + """int: The number of nodes in the TPM.""" + return len(self._nodes) + + def __str__(self): + return self.__repr__() + + def __repr__(self): + return "ImplicitTPM({})".format(self.nodes) + + def __hash__(self): + return hash(tuple(hash(node) for node in self.nodes)) + + + def reconstitute_tpm(subsystem): - """Reconstitute the TPM of a subsystem using the individual node TPMs.""" + """Reconstitute the ExplicitTPM of a subsystem using individual node TPMs.""" # The last axis of the node TPMs correponds to ON or OFF probabilities # (used in the conditioning step when calculating the repertoires); we want # ON probabilities. - node_tpms = [node.tpm.tpm[..., 1] for node in subsystem.nodes] + + # TODO nonbinary nodes + node_tpms = [np.asarray(node.effect_tpm)[..., 1] for node in subsystem.nodes] + + external_indices = () + if hasattr(subsystem, "external_indices"): + external_indices = subsystem.external_indices + # Remove the singleton dimensions corresponding to external nodes - node_tpms = [tpm.squeeze(axis=subsystem.external_indices) for tpm in node_tpms] + node_tpms = [tpm.squeeze(axis=external_indices) for tpm in node_tpms] # We add a new singleton axis at the end so that we can use # pyphi.tpm.expand_tpm, which expects a state-by-node TPM (where the last # axis corresponds to nodes.) node_tpms = [np.expand_dims(tpm, -1) for tpm in node_tpms] # Now we expand the node TPMs to the full state space, so we can combine # them all (this uses the maximum entropy distribution). + shapes = tuple(tpm.shape[:-1] for tpm in node_tpms) + network_shape = tuple(max(dim) for dim in zip(*shapes)) node_tpms = [ - tpm * np.ones([2] * (tpm.ndim - 1) + [tpm.shape[-1]]) for tpm in node_tpms + tpm * np.ones(network_shape + (1,)) for tpm in node_tpms ] # We concatenate the node TPMs along a new axis to get a multidimensional # state-by-node TPM (where the last axis corresponds to nodes). - return np.concatenate(node_tpms, axis=-1) + return ExplicitTPM(np.concatenate(node_tpms, axis=-1)) # TODO(tpm) remove pending ArrayLike refactor def _new_attribute( - name: str, closures: Set[str], tpm: ExplicitTPM.__wraps__, cls=ExplicitTPM + name: str, + closures: Set[str], + tpm: np.ndarray, + cls=ExplicitTPM ) -> object: """Helper function to return adequate proxy attributes for TPM arrays. @@ -589,7 +1065,7 @@ def overriding_attribute(*args, **kwargs): # Test type of result and cast (or not) accordingly. # Array. - if isinstance(result, cls.__wraps__): + if isinstance(result, np.ndarray): return cls(result) # Multivalued "functions" returning a tuple (__divmod__()). @@ -607,58 +1083,3 @@ def overriding_attribute(*args, **kwargs): return overriding_attribute - -def probability_of_current_state(sbn_tpm, current_state): - """Return the probability of the current state as a distribution over previous states. - - Arguments: - sbn_tpm (ExplicitTPM): State-by-node TPM. - current_state (tuple[int]): The current state. - """ - state_probabilities = np.empty(sbn_tpm.shape) - if not len(current_state) == sbn_tpm.shape[-1]: - raise ValueError( - f"current_state must have length {sbn_tpm.shape[-1]}" - f"for state-by-node TPM of shape {sbn_tpm.shape}" - ) - for i in range(sbn_tpm.shape[-1]): - # TODO extend to nonbinary nodes - state_probabilities[..., i] = ( - sbn_tpm[..., i] if current_state[i] else (1 - sbn_tpm[..., i]) - ) - return state_probabilities.prod(axis=-1, keepdims=True) - - -def backward_tpm( - forward_tpm: ExplicitTPM, - current_state: tuple[int], - system_indices: Iterable[int], - remove_background: bool = False, -) -> ExplicitTPM: - """Compute the backward TPM for a given network state.""" - all_indices = tuple(range(forward_tpm.number_of_units)) - system_indices = tuple(sorted(system_indices)) - background_indices = tuple(sorted(set(all_indices) - set(system_indices))) - if not set(system_indices).issubset(set(all_indices)): - raise ValueError( - "system_indices must be a subset of `range(forward_tpm.number_of_units))`" - ) - - # p(u_t | s_{t–1}, w_{t–1}) - pr_current_state = probability_of_current_state(forward_tpm, current_state) - # Σ_{s_{t–1}} p(u_t | s_{t–1}, w_{t–1}) - pr_current_state_given_only_background = pr_current_state.sum( - axis=tuple(system_indices), keepdims=True - ) - # Σ_{u'_{t–1}} p(u_t | u'_{t–1}) - normalization = np.sum(pr_current_state) - # Σ_{s_{t–1}} p(u_t | s_{t–1}, w_{t–1}) - # Σ_{w_{t–1}} p(s_{i,t} | s_{t–1}, w_{t–1}) ——————————————————————————————————————— - # Σ_{u'_{t–1}} p(u_t | u'_{t–1}) - backward_tpm = ( - forward_tpm * pr_current_state_given_only_background / normalization - ).sum(axis=background_indices, keepdims=True) - if remove_background: - # Remove background units from last dimension of the state-by-node TPM - backward_tpm = backward_tpm[..., list(system_indices)] - return ExplicitTPM(backward_tpm) diff --git a/pyphi/utils.py b/pyphi/utils.py index 59ed65a3a..7fb6fcfe0 100644 --- a/pyphi/utils.py +++ b/pyphi/utils.py @@ -37,6 +37,7 @@ def state_of_subsystem_nodes(node_indices, nodes, subsystem_state): return state_of([node_indices.index(n) for n in nodes], subsystem_state) +# TODO: nonbinary states def all_states(n, big_endian=False): """Return all binary states for a system. @@ -59,6 +60,43 @@ def all_states(n, big_endian=False): yield state[::-1] # Convert to little-endian ordering +def equivalent_states(state, mask, state_space_shape): + """Generate the equivalence class of some state, given irrelevant dimensions. + + Arguments: + state (Iterable[int]): Some state in the equivalence class. + mask (Iterable[int]): State mask with 1's representing irrelevant dimensions. + state_space_shape (Iterable[int]): The cardinalities of each dimension + in the state space. + + Yields: + Iterable[tuple[int]]: A generator for the equivalence class of states. + + Examples: + >>> state = (1, 1, 1, 1) + >>> mask = (2, 1, 1, 2) + >>> state_space_shape = (2, 2, 3, 3) + >>> list(equivalent_states(state, mask, state_space_shape)) + [(1, 0, 0, 1), (1, 0, 1, 1), (1, 0, 2, 1), (1, 1, 0, 1), (1, 1, 1, 1), (1, 1, 2, 1)] + """ + n = len(state) + if any(n != len(arg) for arg in [mask, state_space_shape]): + raise ValueError(f"Expected mask and state_space_shape of size {n}.") + + indices_needing_expansion = { + i: state_space_shape[i] for i, mask in enumerate(mask) + if mask == 1 + } + locally_expanded_states = product( + *[range(states) for i, states in indices_needing_expansion.items()] + ) + expanded_indices = list(indices_needing_expansion.keys()) + state = np.array(state) + for s in locally_expanded_states: + state[expanded_indices] = s + yield tuple(state) + + def np_immutable(a): """Make a NumPy array immutable.""" a.flags.writeable = False diff --git a/pyphi/validate.py b/pyphi/validate.py index c224243cc..45e75303b 100644 --- a/pyphi/validate.py +++ b/pyphi/validate.py @@ -1,12 +1,14 @@ # validate.py """Methods for validating user input.""" +from itertools import product import numpy as np -from . import exceptions +from . import conf, exceptions from .conf import config from .direction import Direction + # pylint: disable=redefined-outer-name @@ -34,20 +36,6 @@ def direction(direction, allow_bi=False): return True -def connectivity_matrix(cm): - """Validate the given connectivity matrix.""" - # Special case for empty matrices. - if cm.size == 0: - return True - if cm.ndim != 2: - raise ValueError("Connectivity matrix must be 2-dimensional.") - if cm.shape[0] != cm.shape[1]: - raise ValueError("Connectivity matrix must be square.") - if not np.all(np.logical_or(cm == 1, cm == 0)): - raise ValueError("Connectivity matrix must contain only binary " "values.") - return True - - def node_labels(node_labels, node_indices): """Validate that there is a label for each node.""" if len(node_labels) != len(node_indices): @@ -58,7 +46,6 @@ def node_labels(node_labels, node_indices): if len(node_labels) != len(set(node_labels)): raise ValueError("Labels {0} must be unique.".format(node_labels)) - def network(n): """Validate a |Network|. @@ -66,6 +53,7 @@ def network(n): """ n.tpm.validate() connectivity_matrix(n.cm) + shapes(n.tpm.shapes, n.cm) if n.cm.shape[0] != n.size: raise ValueError( "Connectivity matrix must be NxN, where N is the " @@ -74,22 +62,42 @@ def network(n): return True +def connectivity_matrix(cm): + """Validate the given connectivity matrix.""" + # Special case for empty matrices. + if cm.size == 0: + return True + if cm.ndim != 2: + raise ValueError("Connectivity matrix must be 2-dimensional.") + if cm.shape[0] != cm.shape[1]: + raise ValueError("Connectivity matrix must be square.") + if not np.all(np.logical_or(cm == 1, cm == 0)): + raise ValueError("Connectivity matrix must contain only binary " "values.") + return True + + +def shapes(shapes, cm): + """Validate consistency between node TPM shapes and a user-provided cm.""" + for i, shape in enumerate(shapes): + for j, con in enumerate(cm[..., i]): + if (con == 0 and shape[j] != 1) or (con != 0 and shape[j] == 1): + raise ValueError( + "Node TPM {} of shape {} does not match the connectivity " + "matrix.".format(i, shape) + ) + return True + + def is_network(network): """Validate that the argument is a |Network|.""" from . import Network if not isinstance(network, Network): raise ValueError( - "Input must be a Network (perhaps you passed a Subsystem instead?" + "Input must be a Network (perhaps you passed a Subsystem instead?)" ) -def node_states(state): - """Check that the state contains only zeros and ones.""" - if not all(n in (0, 1) for n in state): - raise ValueError("Invalid state: states must consist of only zeros and ones.") - - def state_length(state, size): """Check that the state is the given size.""" if len(state) != size: @@ -101,18 +109,131 @@ def state_length(state, size): return True +def state_type(state): + """Check that the state only contains integers.""" + if any(not isinstance(s, int) for s in state): + raise TypeError( + f"Invalid state {state}: each entry must be of int type." + ) + return True + + +def state_value(state, shape): + """Check that each entry in the state falls within the right range.""" + if any( + s not in range(cardinality) + for s, cardinality in zip(state, shape) + ): + raise ValueError( + f"Invalid state {state}: entries must be within zero and " + f"{tuple((np.array(shape) - 1).tolist())}." + ) + return True + + +def state(state, size, shape): + """Check that the state is of the correct length, type and value.""" + return ( + state_length(state, size) and + state_type(state) and + state_value(state, shape) + ) + +def _past_states(p_node): + """Find set of states which could have led to the current state of a node. + + The state of irrelevant dimensions, nodes which don't output to this + node, is represented with -1 to encode a whole equivalence class. + + Arguments: + p_node (np.ndarray): Node TPM conditioned on the current subsystem state. + See also :func:`pyphi.tpm.ImplicitTPM.probability_of_current_state`. + + Returns: + set: Set of past states with nonzero probability of transitioning. + """ + # Find s_{t-1} such that p_node > 0. + states = list(np.argwhere(np.asarray(p_node) > 0)) + # Remove last dimension (probability of current state). + states = [state[:-1] for state in states] + # If node TPM shape at certain parent contains a 1, then + # there's no dependency on that parent. Substitute '0' state + # with placeholder -1 to encode equivalent states. + states = [ + tuple(-1 if p_node.shape[i] == 1 else s for i, s in enumerate(state)) + for state in states + ] + return set(states) + + +def _states_intersection(states1, states2): + """Efficient symbolic intersection between two sets of states. + + Arguments: + states1 (set[tuple[int]]): First set of states or equivalence classes. + states2 (set[tuple[int]]): Second set of states or equivalence classes. + + Returns: + set[tuple[int]]: The intersection between the two sets. + + Examples: + >>> states1 = {(1, 0, -1), (1, 1, 1)} + >>> states2 = {(1, 0, 0), (1, 1, 1), (0, 0, 0)} + >>> sorted(list(_states_intersection(states1, states2))) + [(1, 0, 0), (1, 1, 1)] + + >>> states1 = {(1, -1, -1)} + >>> states2 = {(1, 0, -1), (1, 1, -1)} + >>> sorted(list(_states_intersection(states1, states2))) + [(1, 0, -1), (1, 1, -1)] + """ + def find_intersection(state_pair): + # For each unordered pair |{state1, state2}| in the Cartesian product of + # the two sets, check if |state1| and |state2| have a non-empty + # (sub)class in common. If so, that is a member of the intersection. + subclass = [] + for i, j in zip(*state_pair): + if i == j: + subclass.append(i) + elif i == -1: + subclass.append(j) + elif j == -1: + subclass.append(i) + else: + return None + return tuple(subclass) + + # Lazy generator of the Cartesian product. + state_pairs = product(states1, states2) + # Find 2-ary intersections, filter out None's on the fly and return that set. + return set( + intersection for pair in state_pairs + if (intersection := find_intersection(pair)) + ) + + def state_reachable(subsystem): - """Return whether a state can be reached according to the network's TPM.""" - # If there is a row `r` in the TPM such that all entries of `r - state` are - # between -1 and 1, then the given state has a nonzero probability of being - # reached from some state. - # First we take the submatrix of the conditioned TPM that corresponds to - # the nodes that are actually in the subsystem... - tpm = subsystem.effect_tpm.tpm[..., subsystem.node_indices] - # Then we do the subtraction and test. - test = tpm - np.array(subsystem.proper_state) - if not np.any(np.logical_and(-1 < test, test < 1).all(-1)): - raise exceptions.StateUnreachableError(subsystem.state) + """Raise exception if state cannot be reached according to subsystem's TPM.""" + # A state s is reachable by Subsystem S if and only if there is at least + # one state s_{t-1} with nonzero probability of transitioning to s: + # ∃ s_{t-1} : p(s | s_{t-1}, w_{t-1}) > 0 + + # Obtain p(s | w_{t-1}) as node marginals (i.e. implicitly). + p = subsystem.proper_effect_tpm.probability_of_current_state(subsystem.proper_state) + + # Avoid computing the joint distribution. For each node n, find the set of + # coordinates s_{t-1} for which p_n > 0. The intersection of all such sets + # is the set of previous states leading to the current state. + + # Initial value. + intersection = _past_states(p[0]) + + for p_node in p[1:]: + intersection = _states_intersection(intersection, _past_states(p_node)) + # Shortcircuit evaluation of intersection as soon as a + # 2-ary intersection is empty. + if not intersection: + raise exceptions.StateUnreachableError(subsystem.state) def cut(cut, node_indices): @@ -128,7 +249,6 @@ def subsystem(s): Checks its state and cut. """ - node_states(s.state) # cut(s.cut, s.cut_indices) if config.VALIDATE_SUBSYSTEM_STATES: state_reachable(s) diff --git a/test/test_macro.py b/test/test_macro.py index 49c3c211b..12b736f4d 100644 --- a/test/test_macro.py +++ b/test/test_macro.py @@ -1,8 +1,9 @@ import numpy as np import pytest -from pyphi import convert, macro, ExplicitTPM +from pyphi import convert, macro from pyphi.exceptions import ConditionallyDependentError +from pyphi.tpm import ExplicitTPM # flake8: noqa @@ -298,7 +299,7 @@ def test_rebuild_system_tpm(s): # fmt: on assert macro.rebuild_system_tpm(node_tpms).array_equal(answer) - node_tpms = [node.tpm_on for node in s.nodes] + node_tpms = [node.tpm[..., 1] for node in s.nodes] assert macro.rebuild_system_tpm(node_tpms).array_equal(s.tpm) @@ -313,7 +314,7 @@ def test_remove_singleton_dimensions(): ) # fmt: on assert tpm.tpm_indices() == (0,) - assert macro.remove_singleton_dimensions(tpm).array_equal(tpm) + assert tpm.remove_singleton_dimensions().array_equal(tpm) # fmt: off tpm = ExplicitTPM( @@ -330,7 +331,7 @@ def test_remove_singleton_dimensions(): ) # fmt: on assert tpm.tpm_indices() == (1,) - assert macro.remove_singleton_dimensions(tpm).array_equal(answer) + assert tpm.remove_singleton_dimensions().array_equal(answer) # fmt: off tpm = ExplicitTPM( @@ -351,12 +352,12 @@ def test_remove_singleton_dimensions(): ) # fmt: on assert tpm.tpm_indices() == (0, 2) - assert macro.remove_singleton_dimensions(tpm).array_equal(answer) + assert tpm.remove_singleton_dimensions().array_equal(answer) def test_pack_attrs(s): attrs = macro.SystemAttrs.pack(s) - assert attrs.tpm.array_equal(s.tpm) + assert attrs.tpm == s.tpm assert np.array_equal(attrs.cm, s.cm) assert attrs.node_indices == s.node_indices assert attrs.state == s.state diff --git a/test/test_macro_blackbox.py b/test/test_macro_blackbox.py index b82b96782..cfe57436a 100644 --- a/test/test_macro_blackbox.py +++ b/test/test_macro_blackbox.py @@ -1,7 +1,8 @@ import numpy as np import pytest -from pyphi import Network, compute, config, convert, ExplicitTPM, macro, models, utils +from pyphi import Network, compute, config, convert, macro, models, utils +from pyphi.tpm import ExplicitTPM # TODO: move these to examples.py diff --git a/test/test_macro_subsystem.py b/test/test_macro_subsystem.py index a70247707..a3aacb46a 100644 --- a/test/test_macro_subsystem.py +++ b/test/test_macro_subsystem.py @@ -3,7 +3,6 @@ import pyphi from pyphi import convert, macro, models, timescale, config -from pyphi.tpm import ExplicitTPM from pyphi.convert import state_by_node2state_by_state as sbn2sbs from pyphi.convert import state_by_state2state_by_node as sbs2sbn @@ -278,7 +277,7 @@ def test_blackbox(s): ms = macro.MacroSubsystem( s.network, s.state, s.node_indices, blackbox=macro.Blackbox(((0, 1, 2),), (1,)) ) - assert np.array_equal(ms.tpm.tpm, np.array([[0.5], [0.5]])) + assert np.array_equal(np.asarray(ms.tpm), np.array([[0.5], [0.5]])) assert np.array_equal(ms.cm, np.array([[1]])) assert ms.node_indices == (0,) assert ms.state == (0,) @@ -289,7 +288,7 @@ def test_blackbox_external(s): ms = macro.MacroSubsystem( s.network, s.state, (1, 2), blackbox=macro.Blackbox(((1, 2),), (1,)) ) - assert np.array_equal(ms.tpm.tpm, np.array([[0.5], [0.5]])) + assert np.array_equal(np.asarray(ms.tpm), np.array([[0.5], [0.5]])) assert np.array_equal(ms.cm, np.array([[1]])) assert ms.node_indices == (0,) assert ms.state == (0,) diff --git a/test/test_network.py b/test/test_network.py index 18972aa90..39063a932 100644 --- a/test/test_network.py +++ b/test/test_network.py @@ -1,8 +1,10 @@ import numpy as np +import xarray as xr import pytest from pyphi import Direction, config, exceptions from pyphi.network import Network +from pyphi.tpm import ExplicitTPM, ImplicitTPM @pytest.fixture() @@ -57,15 +59,15 @@ def test_potential_purviews(s): def test_node_labels(standard): labels = ("A", "B", "C") - network = Network(standard.tpm.tpm, node_labels=labels) + network = Network(standard.tpm, node_labels=labels) assert network.node_labels.labels == labels labels = ("A", "B") # Too few labels with pytest.raises(ValueError): - Network(standard.tpm.tpm, node_labels=labels) + Network(standard.tpm, node_labels=labels) # Auto-generated labels - network = Network(standard.tpm.tpm, node_labels=None) + network = Network(standard.tpm, node_labels=None) assert network.node_labels.labels == ("n0", "n1", "n2") @@ -87,3 +89,186 @@ def test_len(standard): def test_size(standard): assert standard.size == 3 + + +def test_network_init_with_explicit_tpm(): + tpm = ExplicitTPM([ + [0, 0, 0], + [0, 0, 1], + [1, 0, 1], + [1, 0, 0], + [1, 1, 0], + [1, 1, 1], + [1, 1, 1], + [1, 1, 0] + ], validate=True) + + network = Network(tpm) + + assert type(network.tpm) == ImplicitTPM + + expected_nodes = ( + xr.DataArray([ + [ + [ + [1., 0.], + [0., 1.] + ], + [ + [0., 1.], + [0., 1.] + ] + ], + [ + [ + [1., 0.], + [0., 1.] + ], + [ + [0., 1.], + [0., 1.] + ] + ] + ]), + xr.DataArray([ + [ + [ + [1., 0.], + [0., 1.] + ], + [ + [1., 0.], + [0., 1.] + ] + ], + [ + [ + [1., 0.], + [0., 1.] + ], + [ + [1., 0.], + [0., 1.] + ] + ] + ]), + xr.DataArray([ + [ + [ + [1., 0.], + [1., 0.] + ], + [ + [0., 1.], + [0., 1.] + ] + ], + [ + [ + [0., 1.], + [0., 1.] + ], + [ + [1., 0.], + [1., 0.] + ] + ] + ]) + ) + + for i, node in enumerate(network.tpm.nodes): + assert (node.dataarray.values == expected_nodes[i].values).all() + + +def test_build_cm(): + # ExplicitTPM, no CM + tpm = np.array([ + [0, 0, 0], + [0, 0, 1], + [1, 0, 1], + [1, 0, 0], + [1, 1, 0], + [1, 1, 1], + [1, 1, 1], + [1, 1, 0] + ]) + cm = np.ones((3, 3), dtype=int) + network = Network(tpm) + assert((network.cm == cm).all()) + # ExplicitTPM, provided CM + cm = np.array([ + [0, 1, 1], + [1, 1, 0], + [1, 1, 1] + ]) + network = Network(tpm, cm) + assert((network.cm == cm).all()) + # ImplicitTPM, no CM + tpm = [ + np.array([ + [ + [ + [0., 1.], + [1., 0.] + ] + ], + [ + [ + [1., 0.], + [0., 1.] + ] + ] + ]), + np.array([ + [ + [ + [1., 0.], + [1., 0.] + ], + [ + [0., 1.], + [1., 0.] + ] + ], + [ + [ + [0., 1.], + [0., 1.] + ], + [ + [0., 1.], + [1., 0.] + ] + ] + ]), + np.array([ + [ + [ + [1., 0.], + [1., 0.] + ], + [ + [0., 1.], + [0., 1.] + ] + ] + ]) + ] + cm = np.array([ + [1, 1, 0], + [0, 1, 1], + [1, 1, 1] + ]) + network = Network(tpm) + assert((network.cm == cm).all()) + # ImplicitTPM, correct CM + network = Network(tpm, cm) + assert((network.cm == cm).all()) + # ImplicitTPM, incorrect CM + cm = np.array([ + [1, 0, 0], + [1, 1, 0], + [1, 1, 1] + ]) + with pytest.raises(ValueError): + network = Network(tpm, cm) diff --git a/test/test_node.py b/test/test_node.py index e7a1905a5..428c8d6b3 100644 --- a/test/test_node.py +++ b/test/test_node.py @@ -1,8 +1,8 @@ import numpy as np -from pyphi.node import Node, expand_node_tpm, generate_nodes +from pyphi.node import expand_node_tpm, generate_node, generate_nodes from pyphi.subsystem import Subsystem -from pyphi.tpm import ExplicitTPM +from pyphi.tpm import ExplicitTPM, reconstitute_tpm def test_node_init_tpm(s): @@ -38,16 +38,17 @@ def test_node_init_inputs(s): def test_node_eq(s): - assert s.nodes[1] == Node(s.tpm, s.cm, 1, 0, "B") + expected = generate_node(s.tpm, s.cm, s.state_space, 1, 0, "B") + assert s.nodes[1] == expected def test_node_neq_by_index(s): - assert s.nodes[0] != Node(s.tpm, s.cm, 1, 0, "B") + assert s.nodes[0] != generate_node(s.tpm, s.cm, s.state_space, 1, 0, "B") def test_node_neq_by_state(s): other_s = Subsystem(s.network, (1, 1, 1), s.node_indices) - assert other_s.nodes[1] != Node(s.tpm, s.cm, 1, 0, "B") + assert other_s.nodes[1] != generate_node(s.tpm, s.cm, s.state_space, 1, 0, "B") def test_repr(s): @@ -78,7 +79,14 @@ def test_expand_tpm(): def test_generate_nodes(s): - nodes = generate_nodes(s.tpm, s.cm, s.state, s.node_indices, s.node_labels) + nodes = generate_nodes( + s.tpm, + s.cm, + s.state_space, + s.node_indices, + network_state=s.state, + node_labels=s.node_labels + ) # fmt: off node0_tpm = ExplicitTPM( @@ -125,5 +133,13 @@ def test_generate_nodes(s): def test_generate_nodes_default_labels(s): - nodes = generate_nodes(s.tpm, s.cm, s.state, s.node_indices) - assert [n.label for n in nodes] == ["n0", "n1", "n2"] + nodes = generate_nodes( + s.tpm, + s.cm, + s.state_space, + s.node_indices, + network_state=s.state, + node_labels=s.node_labels + ) + + assert [n.label for n in nodes] == ["A", "B", "C"] diff --git a/test/test_subsystem.py b/test/test_subsystem.py index 51a5f9d9d..91e9a3aa9 100644 --- a/test/test_subsystem.py +++ b/test/test_subsystem.py @@ -121,8 +121,8 @@ def test_apply_cut(s): assert s.network == cut_s.network assert s.state == cut_s.state assert s.node_indices == cut_s.node_indices - assert np.array_equal(cut_s.tpm.tpm, s.tpm.tpm) - assert np.array_equal(cut_s.cm, cut.apply_cut(s.cm)) + assert s.tpm.array_equal(cut_s.tpm) + assert np.array_equal(cut.apply_cut(s.cm), cut_s.cm) def test_cut_indices(s, subsys_n1n2): @@ -144,7 +144,13 @@ def test_cut_node_labels(s): def test_specify_elements_with_labels(standard): - network = Network(standard.tpm.tpm, node_labels=("A", "B", "C")) + cm = np.array([ + [0, 0, 1], + [1, 0, 1], + [1, 1, 0] + ]) + print(standard.tpm) + network = Network(standard.tpm, cm, node_labels=("A", "B", "C")) subsystem = Subsystem(network, (0, 0, 0), ("B", "C")) assert subsystem.node_indices == (1, 2) assert tuple(node.label for node in subsystem.nodes) == ("B", "C") diff --git a/test/test_tpm.py b/test/test_tpm.py old mode 100644 new mode 100755 index e81d7ef86..d72abb671 --- a/test/test_tpm.py +++ b/test/test_tpm.py @@ -1,9 +1,61 @@ import numpy as np import pickle import pytest +import random + +from pyphi import examples, Network, Subsystem +from pyphi.convert import to_md +from pyphi.distribution import normalize +from pyphi.tpm import ExplicitTPM, reconstitute_tpm + + +@pytest.fixture() +def implicit_tpm(size, degree, node_states, seed=1337, deterministic=False): + if degree > size: + raise ValueError( + f"The number of parents of each node (degree={degree}) cannot be" + f"smaller than the size of the network ({size})." + ) + if node_states < 2: + raise ValueError("Nodes must have at least 2 node_states.") + + rng = random.Random(seed) + + def random_deterministic_repertoire(): + """Assign all probability to a single purview state at random.""" + repertoire = rng.sample([1] + (node_states - 1) * [0], node_states) + return repertoire + + def random_repertoire(deterministic): + if deterministic: + return random_deterministic_repertoire() + + repertoire = np.array([rng.uniform(0, 1) for s in range(node_states)]) + # Normalize using L1 metric. + return normalize(repertoire) + + tpm = [] + + for node_index in range(size): + # Generate |node_states| repertoires for each combination of parent + # states at t - 1. + node_tpm = [ + random_repertoire(deterministic) + for j in range(node_states ** degree) + ] + + # Select |degree| nodes at random as parents to this node, then reshape + # node TPM to multidimensional form. + node_shape = np.ones(size, dtype=int) + parents = rng.sample(range(size), degree) + node_shape[parents] = node_states + + node_tpm = np.array(node_tpm).reshape(tuple(node_shape) + (node_states,)) + + tpm.append(node_tpm) + + return tpm -from pyphi import Subsystem, ExplicitTPM -from pyphi.tpm import reconstitute_tpm @pytest.mark.parametrize( @@ -112,38 +164,81 @@ def test_expand_tpm(): def test_marginalize_out(s): marginalized_distribution = s.tpm.marginalize_out([0]) # fmt: off - answer = ExplicitTPM( - np.array([ + answer = np.array([ [[[0.0, 0.0, 0.5], [1.0, 1.0, 0.5]], [[1.0, 0.0, 0.5], [1.0, 1.0, 0.5]]], ]) - ) # fmt: on - assert marginalized_distribution.array_equal(answer) + assert np.array_equal( + np.asarray(reconstitute_tpm(marginalized_distribution)), answer + ) marginalized_distribution = s.tpm.marginalize_out([0, 1]) # fmt: off - answer = ExplicitTPM( - np.array([ + answer = np.array([ [[[0.5, 0.0, 0.5], [1.0, 1.0, 0.5]]], ]) - ) # fmt: on - assert marginalized_distribution.array_equal(answer) + assert np.array_equal( + np.asarray(reconstitute_tpm(marginalized_distribution)), answer + ) def test_infer_cm(rule152): assert np.array_equal(rule152.tpm.infer_cm(), rule152.cm) +def test_backward_tpm(): + network = examples.functionally_equivalent() + implicit_tpm = network.tpm + explicit_tpm = reconstitute_tpm(network.tpm) + + state = (1, 0, 0) + + # Backward TPM of full network must equal forward TPM. + subsystem_indices = (0, 1, 2) + + backward = explicit_tpm.backward_tpm(state, subsystem_indices) + assert backward.array_equal(explicit_tpm) + + backward = reconstitute_tpm( + implicit_tpm.backward_tpm(state, subsystem_indices) + ) + assert backward.array_equal(explicit_tpm) + + # Backward TPM of proper subsystem. + # fmt: off + answer = ExplicitTPM( + np.array( + [[[[1, 0, 0,]], + [[1, 1, 1,]]], + [[[0, 1, 0,]], + [[0, 1, 1,]]]], + ) + ) + # fmt: on + subsystem_indices = (0, 1) + + backward = explicit_tpm.backward_tpm(state, subsystem_indices) + assert backward.array_equal(answer) + + backward = reconstitute_tpm( + implicit_tpm.backward_tpm(state, subsystem_indices) + ) + assert backward.array_equal(answer) + + def test_reconstitute_tpm(standard, s_complete, rule152, noised): # Check subsystem and network TPM are the same when the subsystem is the # whole network - assert np.array_equal(reconstitute_tpm(s_complete), standard.tpm.tpm) + assert np.array_equal( + np.asarray(reconstitute_tpm(s_complete)), + np.asarray(reconstitute_tpm(standard.tpm)) + ) # Regression tests # fmt: off @@ -159,7 +254,7 @@ def test_reconstitute_tpm(standard, s_complete, rule152, noised): ]) # fmt: on subsystem = Subsystem(rule152, (0,) * 5, (0, 1, 2)) - assert np.array_equal(answer, reconstitute_tpm(subsystem)) + assert np.array_equal(answer, np.asarray(reconstitute_tpm(subsystem))) subsystem = Subsystem(noised, (0, 0, 0), (0, 1)) # fmt: off @@ -170,4 +265,4 @@ def test_reconstitute_tpm(standard, s_complete, rule152, noised): [1. , 0. ]], ]) # fmt: on - assert np.array_equal(answer, reconstitute_tpm(subsystem)) + assert np.array_equal(answer, np.asarray(reconstitute_tpm(subsystem))) diff --git a/test/test_validate.py b/test/test_validate.py index e09b86575..41c0da167 100644 --- a/test/test_validate.py +++ b/test/test_validate.py @@ -80,7 +80,7 @@ def test_validate_connectivity_matrix_not_binary(): def test_validate_network_wrong_cm_size(s): with pytest.raises(ValueError): - Network(s.network.tpm.tpm, np.ones(16).reshape(4, 4)) + Network(s.network.tpm, np.ones(16).reshape(4, 4)) def test_validate_is_network(s):