Reinstate log-likelihood transforms #4521

Merged
merged 16 commits on Mar 15, 2021
Changes from all commits
13 changes: 12 additions & 1 deletion pymc3/backends/base.py
@@ -61,7 +61,18 @@ def __init__(self, name, model=None, vars=None, test_point=None):
model = modelcontext(model)
self.model = model
if vars is None:
vars = [v.tag.value_var for v in model.unobserved_RVs]
vars = []
for v in model.unobserved_RVs:
var = getattr(v.tag, "value_var", v)
transform = getattr(var.tag, "transform", None)
if transform:
# We need to create and add an un-transformed version of
# each transformed variable
untrans_var = transform.backward(var)
untrans_var.name = v.name
vars.append(untrans_var)
vars.append(var)

self.vars = vars
self.varnames = [var.name for var in vars]
self.fn = model.fastfn(vars)
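
The effect of this new default is easiest to see with a toy model. A minimal sketch, assuming the v4 development branch this PR targets (the HalfNormal and its default log transform are only illustrative):

import pymc3 as pm

with pm.Model() as model:
    x = pm.HalfNormal("x", 1.0)

# Mirror the defaulting logic above: every transformed value variable is
# traced alongside a backward-transformed copy that reuses the RV's name.
vars = []
for v in model.unobserved_RVs:
    var = getattr(v.tag, "value_var", v)
    transform = getattr(var.tag, "transform", None)
    if transform:
        untrans_var = transform.backward(var)
        untrans_var.name = v.name
        vars.append(untrans_var)
    vars.append(var)

print([var.name for var in vars])
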
204 changes: 118 additions & 86 deletions pymc3/distributions/__init__.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import singledispatch
from itertools import chain
from typing import Dict, Generator, List, Optional, Tuple, Union

import aesara.tensor as aet
@@ -31,6 +32,11 @@
]


@singledispatch
def logp_transform(op, inputs):
return None


def _get_scaling(total_size, shape, ndim):
"""
Gets scaling constant for logp
@@ -135,7 +141,6 @@ def change_rv_size(

def rv_log_likelihood_args(
rv_var: TensorVariable,
rv_value: Optional[TensorVariable] = None,
transformed: Optional[bool] = True,
) -> Tuple[TensorVariable, TensorVariable]:
"""Get a `RandomVariable` and its corresponding log-likelihood `TensorVariable` value.
@@ -146,38 +151,24 @@ def rv_log_likelihood_args(
A variable corresponding to a `RandomVariable`, whether directly or
indirectly (e.g. an observed variable that's the output of an
`Observed` `Op`).
rv_value
The measure-space input `TensorVariable` (i.e. "input" to a
log-likelihood).
transformed
When ``True``, return the transformed value var.

Returns
=======
The first value in the tuple is the `RandomVariable`, and the second is the
measure-space variable that corresponds with the latter. The first is used
to determine the log likelihood graph and the second is the "input"
parameter to that graph. In the case of an observed `RandomVariable`, the
"input" is actual data; in all other cases, it's just another
`TensorVariable`.
measure-space variable that corresponds to the former (i.e. the "value"
variable).

"""

if rv_value is None:
if rv_var.owner and isinstance(rv_var.owner.op, Observed):
rv_var, rv_value = rv_var.owner.inputs
elif hasattr(rv_var.tag, "value_var"):
rv_value = rv_var.tag.value_var
else:
return rv_var, None

rv_value = aet.as_tensor_variable(rv_value)

transform = getattr(rv_value.tag, "transform", None)
if transformed and transform:
rv_value = transform.forward(rv_value)

return rv_var, rv_value
if rv_var.owner and isinstance(rv_var.owner.op, Observed):
return tuple(rv_var.owner.inputs)
elif hasattr(rv_var.tag, "value_var"):
rv_value = rv_var.tag.value_var
return rv_var, rv_value
else:
return rv_var, None
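
A quick usage sketch, assuming this PR's API and the v4-style model context that tags each random variable with a ``value_var``:

import pymc3 as pm
from pymc3.distributions import rv_log_likelihood_args

with pm.Model():
    x = pm.Normal("x", 0.0, 1.0)

rv, value = rv_log_likelihood_args(x)
# `rv` is the RandomVariable output itself; `value` is `x.tag.value_var`,
# the measure-space "input" to the log-likelihood graph.
assert value is x.tag.value_var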


def rv_ancestors(graphs: List[TensorVariable]) -> Generator[TensorVariable, None, None]:
@@ -197,22 +188,52 @@ def strip_observed(x: TensorVariable) -> TensorVariable:
return x


def sample_to_measure_vars(graphs: List[TensorVariable]) -> List[TensorVariable]:
"""Replace `RandomVariable` terms in graphs with their measure-space counterparts."""
def sample_to_measure_vars(
graphs: List[TensorVariable],
) -> Tuple[List[TensorVariable], Dict[TensorVariable, TensorVariable]]:
"""Replace sample-space variables in graphs with their measure-space counterparts.

Sample-space variables are `TensorVariable` outputs of `RandomVariable`
`Op`s. Measure-space variables are `TensorVariable`s that correspond to
the value of a sample-space variable in a likelihood function (e.g. ``x``
in ``p(X = x)``, where ``X`` is the corresponding sample-space variable).
(``x`` is also the variable found in ``rv_var.tag.value_var``, so this
function could also be called ``sample_to_value_vars``.)

Parameters
==========
graphs
The graphs in which random variables are to be replaced by their
measure variables.

Returns
=======
Tuple containing the transformed graphs and a ``dict`` of the replacements
that were made.
"""
replace = {}
for anc in rv_ancestors(graphs):
measure_var = getattr(anc.tag, "value_var", None)
if measure_var is not None:
replace[anc] = measure_var
for anc in chain(rv_ancestors(graphs), graphs):

if not (anc.owner and isinstance(anc.owner.op, RandomVariable)):
continue

_, value_var = rv_log_likelihood_args(anc)

if value_var is not None:
replace[anc] = value_var

if replace:
measure_graphs = clone_replace(graphs, replace=replace)
else:
measure_graphs = graphs

dist_params = clone_replace(graphs, replace=replace)
return dist_params
return measure_graphs, replace
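
A rough sketch of the substitution, under the same assumptions as above; the ``rng, size, dtype`` unpacking follows the convention used in ``logpt`` below:

import pymc3 as pm
from pymc3.distributions import sample_to_measure_vars

with pm.Model():
    x = pm.Normal("x", 0.0, 1.0)
    y = pm.Normal("y", x + 1.0, 1.0)

rng, size, dtype, mu, sigma = y.owner.inputs
(measure_mu,), replacements = sample_to_measure_vars([mu])
# `replacements` maps the sample-space `x` to `x.tag.value_var`;
# `measure_mu` is a clone of the `x + 1.0` graph with that substitution.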


def logpt(
rv_var: TensorVariable,
rv_value: Optional[TensorVariable] = None,
jacobian: bool = True,
jacobian: Optional[bool] = True,
scaling: Optional[bool] = True,
**kwargs,
) -> TensorVariable:
@@ -228,29 +249,41 @@ def logpt(
rv_var
The `RandomVariable` output that determines the log-likelihood graph.
rv_value
The input variable for the log-likelihood graph.
The input variable for the log-likelihood graph. If the
corresponding value variable carries a transform, the transform's
backward mapping is applied to `rv_value`. If no value is provided,
`rv_var.tag.value_var` will be checked and, when available, used.
jacobian
Whether or not to include the Jacobian term.
scaling
Whether or not to apply the total-size scaling to the generated log-likelihood graph.

"""

rv_var, rv_value = rv_log_likelihood_args(rv_var, rv_value)
rv_var, rv_value_var = rv_log_likelihood_args(rv_var)

if rv_value is None:
rv_value = rv_value_var
else:
rv_value = aet.as_tensor(rv_value)

if rv_value_var is None:
rv_value_var = rv_value

rv_node = rv_var.owner

if not rv_node:
raise TypeError("rv_var must be the output of a RandomVariable Op")

if not isinstance(rv_node.op, RandomVariable):

# This will probably need another generic function...
if isinstance(rv_node.op, (Subtensor, AdvancedSubtensor, AdvancedSubtensor1)):

raise NotImplementedError("Missing value support is incomplete")

# "Flatten" and sum an array of indexed RVs' log-likelihoods
rv_var, missing_values = rv_node.inputs
rv_value = rv_var.tag.value_var

missing_values = missing_values.data
logp_var = aet.sum(
@@ -268,28 +301,36 @@

return aet.zeros_like(rv_var)

# This case should be reached when `rv_var` is either the result of an
# `Observed` or a `RandomVariable` `Op`
rng, size, dtype, *dist_params = rv_node.inputs

dist_params = sample_to_measure_vars(dist_params)
dist_params, replacements = sample_to_measure_vars(dist_params)

if jacobian:
logp_var = _logp(rv_node.op, rv_value, *dist_params, **kwargs)
else:
logp_var = _logp_nojac(rv_node.op, rv_value, *dist_params, **kwargs)
transform = getattr(rv_value_var.tag, "transform", None)

# Replace `RandomVariable` ancestors with their corresponding
# log-likelihood input variables
lik_replacements = [
(v, v.tag.value_var)
for v in ancestors([logp_var])
if v.owner and isinstance(v.owner.op, RandomVariable) and getattr(v.tag, "value_var", None)
]
# If any of the measure-space variables are transformed (signified by
# having a `transform` value in their tags), then we apply their
# transforms and add their Jacobians (when enabled)
if transform:
logp_var = _logp(rv_node.op, transform.backward(rv_value), *dist_params, **kwargs)
logp_var = transform_logp(
logp_var,
tuple(replacements.values()),
)

(logp_var,) = clone_replace([logp_var], replace=lik_replacements)
if jacobian:
transformed_jacobian = transform.jacobian_det(rv_value)
if transformed_jacobian:
if logp_var.ndim > transformed_jacobian.ndim:
logp_var = logp_var.sum(axis=-1)
logp_var += transformed_jacobian
else:
logp_var = _logp(rv_node.op, rv_value, *dist_params, **kwargs)

if scaling:
logp_var *= _get_scaling(
getattr(rv_var.tag, "total_size", None), rv_value.shape, rv_value.ndim
getattr(rv_var.tag, "total_size", None), rv_value_var.shape, rv_value_var.ndim
)

if rv_var.name is not None:
@@ -298,6 +339,25 @@
return logp_var
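
Putting the pieces together, a hedged end-to-end sketch (assuming the v4 development branch plus this PR): a half-normal's value variable carries a log transform, so the density is evaluated at ``transform.backward(rv_value)`` and the log-Jacobian determinant is added when ``jacobian=True``:

import pymc3 as pm
from pymc3.distributions import logpt

with pm.Model():
    x = pm.HalfNormal("x", 1.0)

value = x.tag.value_var                 # log-space input variable
with_jac = logpt(x)                     # density at backward(value) + jacobian_det
without_jac = logpt(x, jacobian=False)  # density at backward(value) only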


def transform_logp(logp_var: TensorVariable, inputs: List[TensorVariable]) -> TensorVariable:
"""Transform the inputs of a log-likelihood graph."""
trans_replacements = {}
for measure_var in inputs:

transform = getattr(measure_var.tag, "transform", None)

if transform is None:
continue

trans_rv_value = transform.backward(measure_var)
trans_replacements[measure_var] = trans_rv_value

if trans_replacements:
(logp_var,) = clone_replace([logp_var], trans_replacements)

return logp_var
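
A toy illustration of the substitution this helper performs; the ``log`` transform instance from ``pymc3.distributions.transforms`` is an assumption here, as is tagging the variable by hand:

import aesara.tensor as aet
from pymc3.distributions import transform_logp
from pymc3.distributions.transforms import log  # assumed transform instance

value = aet.scalar("x_value")
value.tag.transform = log
logp = -aet.exp(value)  # a toy graph over the measure-space variable

# The transformed input is replaced by its backward transform, so the
# result is equivalent to `-aet.exp(log.backward(value))`.
new_logp = transform_logp(logp, (value,))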


@singledispatch
def _logp(op, value, *dist_params, **kwargs):
"""Create a log-likelihood graph.
@@ -310,20 +370,24 @@ def _logp(op, value, *dist_params, **kwargs):
return aet.zeros_like(value)
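
Distributions hook into this dispatcher by registering on their `RandomVariable` `Op` class. A hypothetical registration (the real pymc3 registrations live in the distribution modules, not in this diff; the parameter order follows the ``rng, size, dtype, *dist_params`` unpacking in ``logpt``):

import numpy as np
import aesara.tensor as aet
from aesara.tensor.random.basic import NormalRV
from pymc3.distributions import _logp

@_logp.register(NormalRV)
def normal_logp(op, value, mu, sigma, **kwargs):
    # Standard normal log-density; dispatched whenever `logpt` meets a
    # `NormalRV` node.
    return -0.5 * ((value - mu) / sigma) ** 2 - aet.log(sigma) - 0.5 * aet.log(2 * np.pi)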


def logcdf(rv_var, rv_value, **kwargs):
def logcdf(rv_var, rv_value, jacobian=True, **kwargs):
"""Create a log-CDF graph."""

rv_var, rv_value = rv_log_likelihood_args(rv_var, rv_value)
rv_var, _ = rv_log_likelihood_args(rv_var)
rv_node = rv_var.owner

if not rv_node:
raise TypeError("rv_var must be the output of a RandomVariable Op")

rv_value = aet.as_tensor(rv_value)

rng, size, dtype, *dist_params = rv_node.inputs

dist_params = sample_to_measure_vars(dist_params)
dist_params, replacements = sample_to_measure_vars(dist_params)

logp_var = _logcdf(rv_node.op, rv_value, *dist_params, **kwargs)

return _logcdf(rv_node.op, rv_value, *dist_params, **kwargs)
return logp_var
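
Usage sketch under the same assumptions as the ``logpt`` example above; the evaluation point is now passed explicitly rather than read from the variable's tag:

import pymc3 as pm
from pymc3.distributions import logcdf

with pm.Model():
    x = pm.Normal("x", 0.0, 1.0)

cdf_at_zero = logcdf(x, 0.0)  # log-CDF graph evaluated at 0.0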


@singledispatch
@@ -338,38 +402,6 @@ def _logcdf(op, value, *args, **kwargs):
raise NotImplementedError()


def logp_nojac(rv_var, rv_value=None, **kwargs):
"""Create a graph of the log-likelihood that doesn't include the Jacobian."""

rv_var, rv_value = rv_log_likelihood_args(rv_var, rv_value)
rv_node = rv_var.owner

if not rv_node:
raise TypeError()

rng, size, dtype, *dist_params = rv_node.inputs

dist_params = sample_to_measure_vars(dist_params)

return _logp_nojac(rv_node.op, rv_value, **kwargs)


@singledispatch
def _logp_nojac(op, value, *args, **kwargs):
"""Return the logp, but do not include a jacobian term for transforms.

If we use different parametrizations for the same distribution, we
need to add the determinant of the jacobian of the transformation
to make sure the densities still describe the same distribution.
However, MAP estimates are not invariant with respect to the
parameterization, so we need to exclude the Jacobian terms in this case.

This function should be overwritten in base classes for transformed
distributions.
"""
return logpt(op, value, *args, **kwargs)


def logpt_sum(rv_var: TensorVariable, rv_value: Optional[TensorVariable] = None, **kwargs):
"""Return the sum of the logp values for the given observations.
