
Add model.logp_elemwiset #5245

Merged 10 commits on Dec 13, 2021
9 changes: 6 additions & 3 deletions pymc/backends/arviz.py
@@ -24,7 +24,6 @@
 import pymc

 from pymc.aesaraf import extract_obs_data
-from pymc.distributions import logpt
 from pymc.model import modelcontext
 from pymc.util import get_default_varnames
@@ -264,11 +263,15 @@ def _extract_log_likelihood(self, trace):
         if self.model is None:
             return None

+        # TODO: We no longer need one function per observed variable
         if self.log_likelihood is True:
-            cached = [(var, self.model.fn(logpt(var))) for var in self.model.observed_RVs]
+            cached = [
+                (var, self.model.fn(self.model.logp_elemwiset(var)[0]))
+                for var in self.model.observed_RVs
+            ]
         else:
             cached = [
-                (var, self.model.fn(logpt(var)))
+                (var, self.model.fn(self.model.logp_elemwiset(var)[0]))
                 for var in self.model.observed_RVs
                 if var.name in self.log_likelihood
             ]
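A minimal sketch of how the new call site above might be exercised. The model, variable names, and data below are hypothetical; only model.fn and model.logp_elemwiset come from the diff, and logp_elemwiset is assumed to return a list of elementwise logp graphs, hence the [0] indexing:

import numpy as np
import pymc as pm

# Hypothetical toy model for illustration only.
with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    y = pm.Normal("y", mu, 1.0, observed=np.array([0.1, -0.3, 0.7]))

# As in _extract_log_likelihood above: compile the elementwise logp graph
# of one observed variable into a point function.
elemwise_logp_fn = model.fn(model.logp_elemwiset(y)[0])

# Evaluating at a point yields one log-likelihood term per observation,
# which is what ArviZ stores in its log_likelihood group.
# (Assumes the point dict maps value-variable names to values.)
print(elemwise_logp_fn({"mu": 0.0}))  # expected shape: (3,)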
148 changes: 65 additions & 83 deletions pymc/distributions/logprob.py
@@ -15,7 +15,7 @@

 from collections.abc import Mapping
 from functools import singledispatch
-from typing import Dict, Optional, Union
+from typing import Dict, List, Optional, Union

 import aesara.tensor as at
 import numpy as np
@@ -24,10 +24,8 @@
 from aeppl.logprob import logcdf as logcdf_aeppl
 from aeppl.logprob import logprob as logp_aeppl
 from aeppl.transforms import TransformValuesOpt
-from aesara import config
 from aesara.graph.basic import graph_inputs, io_toposort
-from aesara.graph.op import Op, compute_test_value
-from aesara.tensor.random.op import RandomVariable
+from aesara.graph.op import Op
 from aesara.tensor.subtensor import (
     AdvancedIncSubtensor,
     AdvancedIncSubtensor1,
@@ -121,15 +119,15 @@ def _get_scaling(total_size, shape, ndim):


 def logpt(
-    var: TensorVariable,
+    var: Union[TensorVariable, List[TensorVariable]],
     rv_values: Optional[Union[TensorVariable, Dict[TensorVariable, TensorVariable]]] = None,
     *,
     jacobian: bool = True,
     scaling: bool = True,
     transformed: bool = True,
     sum: bool = True,
     **kwargs,
-) -> TensorVariable:
+) -> Union[TensorVariable, List[TensorVariable]]:
     """Create a measure-space (i.e. log-likelihood) graph for a random variable
     or a list of random variables at a given point.

@@ -156,108 +154,100 @@
     transformed
         Apply transforms.
     sum
-        Sum the log-likelihood.
+        Sum the log-likelihood or return each term as a separate list item.

     """
+    # TODO: In future when we drop support for tag.value_var most of the following
+    # logic can be removed and logpt can just be a wrapper function that calls aeppl's
+    # joint_logprob directly.

     # If var is not a list make it one.
-    if not isinstance(var, list):
+    if not isinstance(var, (list, tuple)):
         var = [var]

-    # If logpt isn't provided values and the variable (provided in var)
-    # is an RV, it is assumed that the tagged value var or observation is
-    # the value variable for that particular RV.
+    # If logpt isn't provided values it is assumed that the tagged value var or
+    # observation is the value variable for that particular RV.
     if rv_values is None:
         rv_values = {}
-        for _var in var:
-            if isinstance(_var.owner.op, RandomVariable):
-                rv_value_var = getattr(
-                    _var.tag, "observations", getattr(_var.tag, "value_var", _var)
-                )
-                rv_values = {_var: rv_value_var}
+        for rv in var:
+            value_var = getattr(rv.tag, "observations", getattr(rv.tag, "value_var", None))
+            if value_var is None:
+                raise ValueError(f"No value variable found for var {rv}")
+            rv_values[rv] = value_var
-    # Else we assume we were given a single rv and respective value
     elif not isinstance(rv_values, Mapping):
-        rv_values = (
-            {var[0]: at.as_tensor_variable(rv_values).astype(var[0].type)} if len(var) == 1 else {}
-        )
-
-    # Since the filtering of logp graph is based on value variables
-    # provided to this function
-    if not rv_values:
-        warnings.warn("No value variables provided the logp will be an empty graph")
+        # Else if we're given a single value and a single variable we assume a mapping among them.
+        if len(var) == 1:
+            rv_values = {var[0]: at.as_tensor_variable(rv_values).astype(var[0].type)}
+        else:
+            raise ValueError("rv_values must be a dict if more than one var is requested")

     if scaling:
         rv_scalings = {}
-        for _var in var:
-            rv_value_var = getattr(_var.tag, "observations", getattr(_var.tag, "value_var", _var))
-            rv_scalings[rv_value_var] = _get_scaling(
-                getattr(_var.tag, "total_size", None), rv_value_var.shape, rv_value_var.ndim
+        for rv, value_var in rv_values.items():
+            rv_scalings[value_var] = _get_scaling(
+                getattr(rv.tag, "total_size", None), value_var.shape, value_var.ndim
             )

+    # Aeppl needs all rv-values pairs, not just that of the requested var.
+    # Hence we iterate through the graph to collect them.
     tmp_rvs_to_values = rv_values.copy()
-    transform_map = {}
     for node in io_toposort(graph_inputs(var), var):
         try:
             curr_vars = [node.default_output()]
         except ValueError:
             curr_vars = node.outputs
         for curr_var in curr_vars:
-            rv_value_var = getattr(
+            if curr_var in tmp_rvs_to_values:
+                continue
+            # Check if variable has a value variable
+            value_var = getattr(
                 curr_var.tag, "observations", getattr(curr_var.tag, "value_var", None)
             )
-            if rv_value_var is None:
-                continue
-            rv_value = rv_values.get(curr_var, rv_value_var)
-            tmp_rvs_to_values[curr_var] = rv_value
-            # Along with value variables we also check for transforms if any.
-            if hasattr(rv_value_var.tag, "transform") and transformed:
-                transform_map[rv_value] = rv_value_var.tag.transform
+            if value_var is not None:
+                tmp_rvs_to_values[curr_var] = value_var

+    # After collecting all necessary rvs and values, we check for any value transforms
+    transform_map = {}
+    if transformed:
+        for rv, value_var in tmp_rvs_to_values.items():
+            if hasattr(value_var.tag, "transform"):
+                transform_map[value_var] = value_var.tag.transform
+            # If the provided value_variable does not have transform information, we
+            # check if the original `rv.tag.value_var` does.
+            # TODO: This logic should be replaced by an explicit dict of
+            # `{value_var: transform}` similar to `rv_values`.
+            else:
+                original_value_var = getattr(rv.tag, "value_var", None)
+                if original_value_var is not None and hasattr(original_value_var.tag, "transform"):
+                    transform_map[value_var] = original_value_var.tag.transform

     transform_opt = TransformValuesOpt(transform_map)
     temp_logp_var_dict = factorized_joint_logprob(
         tmp_rvs_to_values, extra_rewrites=transform_opt, use_jacobian=jacobian, **kwargs
     )

     # aeppl returns the logpt for every single value term we provided to it. This includes
-    # the extra values we plugged in above so we need to filter those out.
+    # the extra values we plugged in above, so we filter those we actually wanted in the
+    # same order they were given in.
     logp_var_dict = {}
-    for value_var, _logp in temp_logp_var_dict.items():
-        if value_var in rv_values.values():
-            logp_var_dict[value_var] = _logp
+    for value_var in rv_values.values():
+        logp_var_dict[value_var] = temp_logp_var_dict[value_var]

-    # If it's an empty dictionary the logp is None
-    if not logp_var_dict:
-        logp_var = None
-    else:
-        # Otherwise apply appropriate scalings and at.add and/or at.sum the
-        # graphs accordingly.
-        if scaling:
-            for _value in logp_var_dict.keys():
-                if _value in rv_scalings:
-                    logp_var_dict[_value] *= rv_scalings[_value]
-
-        if len(logp_var_dict) == 1:
-            logp_var_dict = tuple(logp_var_dict.values())[0]
-            if sum:
-                logp_var = at.sum(logp_var_dict)
-            else:
-                logp_var = logp_var_dict
-        else:
-            if sum:
-                logp_var = at.sum([at.sum(factor) for factor in logp_var_dict.values()])
-            else:
-                logp_var = at.add(*logp_var_dict.values())
+    if scaling:
+        for value_var in logp_var_dict.keys():
+            if value_var in rv_scalings:
+                logp_var_dict[value_var] *= rv_scalings[value_var]

-    # Recompute test values for the changes introduced by the replacements
-    # above.
-    if config.compute_test_value != "off":
-        for node in io_toposort(graph_inputs((logp_var,)), (logp_var,)):
-            compute_test_value(node)
+    if sum:
+        logp_var = at.sum([at.sum(factor) for factor in logp_var_dict.values()])
+    else:
+        logp_var = list(logp_var_dict.values())
+        # TODO: deprecate special behavior when only one variable is requested and
+        # always return a list. This is here for backwards compatibility as logpt
+        # started as a replacement to factor.logpt, but it should now be considered an
+        # internal function reached only via model.logp* methods.
+        if len(logp_var) == 1:
+            logp_var = logp_var[0]

     return logp_var
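
For reviewers, a sketch of the new list-in/list-out behavior of logpt. Variable names are illustrative; with rv_values=None the tagged value variables are picked up automatically, as the rewritten lookup above shows:

import pymc as pm
from pymc.distributions.logprob import logpt

# Hypothetical two-variable model for illustration only.
with pm.Model() as model:
    x = pm.Normal("x", 0.0, 1.0)
    y = pm.Normal("y", x, 1.0)

total = logpt([x, y])                # sum=True by default: one scalar graph
factors = logpt([x, y], sum=False)   # a list with one logp graph per variable

# Single-variable requests are still unwrapped for backwards compatibility,
# the special case flagged for deprecation in the TODO above.
x_logp = logpt(x, sum=False)         # a TensorVariable, not a 1-element list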

@@ -276,23 +266,15 @@ def logcdf(rv, value):
     return logcdf_aeppl(rv, value)


-@singledispatch
-def _logcdf(op, values, *args, **kwargs):
-    """Create a log-CDF graph.
-
-    This function dispatches on the type of `op`, which should be a subclass
-    of `RandomVariable`. If you want to implement new log-CDF graphs
-    for a `RandomVariable`, register a new function on this dispatcher.
-
-    """
-    raise NotImplementedError()
-
-
 def logpt_sum(*args, **kwargs):
     """Return the sum of the logp values for the given observations.

     Subclasses can use this to improve the speed of logp evaluations
     if only the sum of the logp values is needed.
     """
+    # TODO: Deprecate this
+    warnings.warn(
+        "logpt_sum has been deprecated, you can use logpt instead, which now defaults "
+        "to the same behavior of logpt_sum",
+        DeprecationWarning,
+    )
     return logpt(*args, sum=True, **kwargs)
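
Migration note, as a hedged sketch (names as in the diff; logpt sums by default, so per the warning text it is the drop-in replacement for logpt_sum):

import warnings

import pymc as pm
from pymc.distributions.logprob import logpt, logpt_sum

with pm.Model() as model:
    x = pm.Normal("x", 0.0, 1.0)

# Old helper: still works, but now emits a DeprecationWarning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    old = logpt_sum(x)
assert any(issubclass(w.category, DeprecationWarning) for w in caught)

# Replacement: logpt already defaults to sum=True.
new = logpt(x)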