diff --git a/pymc/logprob/censoring.py b/pymc/logprob/censoring.py
index b9221e08db..b15e338841 100644
--- a/pymc/logprob/censoring.py
+++ b/pymc/logprob/censoring.py
@@ -37,20 +37,42 @@
 from typing import Optional
 
 import numpy as np
+import pytensor
 import pytensor.tensor as pt
 
-from pytensor.graph.basic import Node
+from pytensor.graph import Op
+from pytensor.graph.basic import Apply, Node
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.rewriting.basic import node_rewriter
-from pytensor.scalar.basic import Ceil, Clip, Floor, RoundHalfToEven
+from pytensor.raise_op import Assert
+from pytensor.scalar.basic import (
+    GE,
+    GT,
+    LE,
+    LT,
+    Ceil,
+    Clip,
+    Floor,
+    RoundHalfToEven,
+    Switch,
+)
 from pytensor.scalar.basic import clip as scalar_clip
-from pytensor.tensor import TensorVariable
+from pytensor.tensor import TensorType, TensorVariable
+from pytensor.tensor.basic import switch as switch
 from pytensor.tensor.math import ceil, clip, floor, round_half_to_even
+from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.variable import TensorConstant
 
-from pymc.logprob.abstract import MeasurableElemwise, _logcdf, _logprob
+from pymc.logprob.abstract import (
+    MeasurableElemwise,
+    MeasurableVariable,
+    _logcdf,
+    _logcdf_helper,
+    _logprob,
+    _logprob_helper,
+)
 from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
-from pymc.logprob.utils import CheckParameterValue
+from pymc.logprob.utils import CheckParameterValue, check_potential_measurability
 
 
 class MeasurableClip(MeasurableElemwise):
@@ -238,3 +260,282 @@ def round_logprob(op, values, base_rv, **kwargs):
     from pymc.math import logdiffexp
 
     return logdiffexp(logcdf_upper, logcdf_lower)
+
+
+class FlatSwitches(Op):
+    """Placeholder Op for a flattened switch graph; only used for logprob derivation."""
+
+    __props__ = ("out_dtype", "rv_idx")
+
+    def __init__(self, *args, out_dtype, rv_idx, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.out_dtype = out_dtype
+        self.rv_idx = rv_idx
+
+    def make_node(self, *inputs):
+        return Apply(
+            self, list(inputs), [TensorType(dtype=self.out_dtype, shape=inputs[0].type.shape)()]
+        )
+
+    def perform(self, *args, **kwargs):
+        raise NotImplementedError("This Op should not be evaluated")
+
+
+MeasurableVariable.register(FlatSwitches)
+
+
+def get_intervals(binary_node, valued_rvs):
+    """
+    Return the (true-branch, false-branch) intervals implied by a binary comparison node.
+
+    Handles both "x > 1" and "1 < x" expressions.
+    """
+
+    measurable_inputs = [
+        inp for inp in binary_node.inputs if check_potential_measurability([inp], valued_rvs)
+    ]
+
+    if len(measurable_inputs) != 1:
+        return None
+
+    measurable_var = measurable_inputs[0]
+    measurable_var_idx = binary_node.inputs.index(measurable_var)
+
+    const = binary_node.inputs[(measurable_var_idx + 1) % 2]
+
+    # Whether `const` is a lower or an upper bound depends on measurable_var_idx and the binary Op.
+    is_gt_or_ge = isinstance(binary_node.op.scalar_op, (GT, GE))
+    is_lt_or_le = isinstance(binary_node.op.scalar_op, (LT, LE))
+
+    if not is_lt_or_le and not is_gt_or_ge:
+        # The switch condition was not defined using binary comparison Ops
+        return None
+
+    intervals = [(-np.inf, const), (const, np.inf)]
+
+    # interval_true corresponds to the true branch of the 'Switch'; interval_false to the false branch
+    if measurable_var_idx == 0:
+        # e.g. "x < 1" and "x > 1"
+        interval_true, interval_false = intervals if is_lt_or_le else intervals[::-1]
+    else:
+        # e.g. "1 > x" and "1 < x"
+        interval_true, interval_false = intervals[::-1] if is_gt_or_ge else intervals
+
+    return [interval_true, interval_false]
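+
+# Illustrative sketch (not part of the original patch): for a condition graph
+# such as `(x < 1).owner`, `get_intervals` returns `[(-np.inf, 1), (1, np.inf)]`;
+# the true branch of the enclosing switch covers `(-np.inf, 1)` and the false
+# branch covers `(1, np.inf)`. For the flipped form `(1 < x).owner`, the two
+# intervals swap places.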
"1 > x" and "1 < x" + interval_true, interval_false = intervals[::-1] if is_gt_or_ge else intervals + + return [interval_true, interval_false] + + +def adjust_intervals(intervals, outer_interval): + for i in range(2): + current = intervals[i] + lower = pt.maximum(current[0], outer_interval[0]) + upper = pt.minimum(current[1], outer_interval[1]) + + intervals[i] = (lower, upper) + + return intervals + + +def flat_switch_helper(node, valued_rvs, encoding_list, outer_interval, base_rv): + """ + Carries out the main recursion through the branches to fetch the encodings, their respective + intervals and adjust any overlaps. It also performs several checks on the switch condition and measurable + components. + """ + from pymc.distributions.distribution import SymbolicRandomVariable + + switch_cond, *components = node.inputs + + # deny broadcasting of the switch condition + if switch_cond.type.broadcastable != node.outputs[0].type.broadcastable: + return None + + measurable_var_switch = [ + var for var in switch_cond.owner.inputs if check_potential_measurability([var], valued_rvs) + ] + + if len(measurable_var_switch) != 1: + return None + + current_base_var = measurable_var_switch[0] + # deny cases where base_var is some function of 'x', e.g. pt.exp(x), and measurable var in the current switch is x + # also check if all the sources of measurability are the same RV. e.g. x1 and x2 + if current_base_var is not base_rv: + return None + + measurable_var_idx = [] + switch_comp_idx = [] + + # get the indices for the switch and other measurable components for further recursion + for idx, component in enumerate(components): + if check_potential_measurability([component], valued_rvs): + if isinstance( + component.owner.op, (RandomVariable, SymbolicRandomVariable) + ) or not isinstance(component.owner.op.scalar_op, Switch): + measurable_var_idx.append(idx) + else: + switch_comp_idx.append(idx) + + # Check if measurability source and the component itself are the same for all measurable components + if any(components[i] is not base_rv for i in measurable_var_idx): + return None + + # Get intervals for true and false components from the condition + intervals = get_intervals(switch_cond.owner, valued_rvs) + adjusted_intervals = adjust_intervals(intervals, outer_interval) + + # Base condition for recursion - when there is no more switch in either of the components + if not switch_comp_idx: + # Insert the two components and their respective intervals into encoding_list + for i in range(2): + switch_dict = { + "lower": adjusted_intervals[i][0], + "upper": adjusted_intervals[i][1], + "encoding": components[i], + } + encoding_list.append(switch_dict) + + return encoding_list + + else: + for i in range(2): + if i in switch_comp_idx: + # Recurse through the switch component(es) + encoding_list = flat_switch_helper( + components[i].owner, valued_rvs, encoding_list, adjusted_intervals[i], base_rv + ) + + else: + switch_dict = { + "lower": adjusted_intervals[i][0], + "upper": adjusted_intervals[i][1], + "encoding": components[i], + } + encoding_list.append(switch_dict) + + return encoding_list + + +@node_rewriter(tracks=[switch]) +def find_measurable_flat_switch_encoding(fgraph: FunctionGraph, node: Node): + rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) + + if rv_map_feature is None: + return None # pragma: no cover + + valued_rvs = rv_map_feature.rv_values.keys() + switch_cond = node.inputs[0] + + encoding_list = [] + initial_interval = (-np.inf, np.inf) + + # fetch 
+
+
+@node_rewriter(tracks=[switch])
+def find_measurable_flat_switch_encoding(fgraph: FunctionGraph, node: Node):
+    rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
+
+    if rv_map_feature is None:
+        return None  # pragma: no cover
+
+    valued_rvs = rv_map_feature.rv_values.keys()
+    switch_cond = node.inputs[0]
+
+    encoding_list = []
+    initial_interval = (-np.inf, np.inf)
+
+    # fetch base_rv as the only measurable input to the logical Op in the switch condition
+    measurable_switch_inp = [
+        component
+        for component in switch_cond.owner.inputs
+        if check_potential_measurability([component], valued_rvs)
+    ]
+
+    if len(measurable_switch_inp) != 1:
+        return None
+
+    base_rv = measurable_switch_inp[0]
+
+    # We do not allow discrete RVs yet
+    if base_rv.dtype.startswith("int"):
+        return None
+
+    # Since the source of measurability is verified to be the same for the switch condition
+    # and all measurable components, denying broadcastability of base_rv is enough
+    if base_rv.type.broadcastable != node.outputs[0].type.broadcastable:
+        return None
+
+    encoding_list = flat_switch_helper(node, valued_rvs, encoding_list, initial_interval, base_rv)
+    if encoding_list is None:
+        return None
+
+    encodings, intervals = [], []
+    rv_idx = ()
+
+    # TODO: Find some cleaner alternative way to do this
+    for idx, item in enumerate(encoding_list):
+        encoding = item["encoding"]
+        # indices of the intervals that have base_rv as their "encoding"
+        if encoding == base_rv:
+            rv_idx += (idx,)
+
+        encodings.append(encoding)
+        intervals.extend((item["lower"], item["upper"]))
+
+    flat_switch_op = FlatSwitches(out_dtype=node.outputs[0].dtype, rv_idx=rv_idx)
+
+    new_outs = flat_switch_op.make_node(base_rv, *intervals, *encodings).default_output()
+    return [new_outs]
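+
+# Illustrative sketch (not part of the original patch): this rewrite replaces
+# `pt.switch(x < 1, x, 1)` with a `FlatSwitches` node whose inputs are laid out
+# as `(base_rv, lower1, upper1, lower2, upper2, encoding1, encoding2)` and whose
+# `rv_idx` records which encodings are the base RV itself; here `encoding1` is
+# `x`, so `rv_idx == (0,)`.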
+
+
+@_logprob.register(FlatSwitches)
+def flat_switches_logprob(op, values, base_rv, *inputs, **kwargs):
+    from pymc.math import logdiffexp
+
+    (value,) = values
+
+    encodings_count = len(inputs) // 3
+    # 'inputs' is of the form (lower1, upper1, lower2, upper2, encoding1, encoding2)
+    # Possible TODO:
+    # encodings = op.get_encodings_from_inputs(inputs)
+    encodings = inputs[2 * encodings_count : 3 * encodings_count]
+    encodings = pt.broadcast_arrays(*encodings)
+
+    encodings = Assert(msg="all encodings should be unique")(
+        encodings, pt.eq(pt.unique(encodings, axis=0).shape[0], len(encodings))
+    )
+
+    # TODO: We do not support the encoding graphs of discrete RVs yet
+
+    interval_bounds = pt.broadcast_arrays(*inputs[0 : 2 * encodings_count])
+    lower_interval_bounds = interval_bounds[::2]
+    upper_interval_bounds = interval_bounds[1::2]
+
+    lower_interval_bounds = pt.concatenate([i[None] for i in lower_interval_bounds], axis=0)
+    upper_interval_bounds = pt.concatenate([j[None] for j in upper_interval_bounds], axis=0)
+
+    interval_bounds = pt.concatenate(
+        [lower_interval_bounds[None], upper_interval_bounds[None]], axis=0
+    )
+
+    # define the logcdf on a scalar and use vectorize to compute it over the 2D interval bounds
+    scalar_interval_bound = pt.scalar("scalar_interval_bound", dtype=base_rv.dtype)
+    logcdf_scalar_interval_bound = _logcdf_helper(base_rv, scalar_interval_bound, **kwargs)
+    logcdf_interval_bounds = pytensor.graph.replace.vectorize(
+        logcdf_scalar_interval_bound, replace={scalar_interval_bound: interval_bounds}
+    )
+    logcdf_intervals = logdiffexp(
+        logcdf_interval_bounds[1, ...], logcdf_interval_bounds[0, ...]
+    )  # shape: (encoding, *base_rv.shape)
+
+    # the default logprob is -inf if there is no RV among the branches
+    if op.rv_idx:
+        logprob = _logprob_helper(base_rv, value, **kwargs)
+
+        # Add the RV branch(es), checking that the value falls inside the respective interval
+        for i in op.rv_idx:
+            logprob = pt.where(
+                pt.and_(value <= upper_interval_bounds[i], value >= lower_interval_bounds[i]),
+                logprob,
+                -np.inf,
+            )
+    else:
+        logprob = -np.inf
+
+    for i in range(encodings_count):
+        # if the encoding covers the interval (lower, upper), then Prob = CDF(upper) - CDF(lower)
+        logprob = pt.where(
+            pt.eq(value, encodings[i]),
+            logcdf_intervals[i],
+            logprob,
+        )
+
+    return logprob
+
+
+measurable_ir_rewrites_db.register(
+    "find_measurable_flat_switch_encoding",
+    find_measurable_flat_switch_encoding,
+    "basic",
+    "censoring",
+)
diff --git a/tests/logprob/test_censoring.py b/tests/logprob/test_censoring.py
index de407fd579..9b7f9c7d2b 100644
--- a/tests/logprob/test_censoring.py
+++ b/tests/logprob/test_censoring.py
@@ -262,3 +262,159 @@ def test_rounding(rounding_op):
         logprob.eval({xr_vv: test_value}),
         expected_logp,
     )
+
+
+def test_switch_encoding_no_branch_measurable():
+    x_rv = pt.random.normal(0.5, 1, size=2)
+
+    y_rv = pt.switch(x_rv < 1, [2.0, 2.5], 3.0)
+
+    y_vv1 = y_rv.clone()
+
+    logprob1 = logp(y_rv, y_vv1)
+
+    logp_fn1 = pytensor.function([y_vv1], logprob1)
+
+    ref_scipy = st.norm(0.5, 1)
+
+    np.testing.assert_allclose(
+        logp_fn1([2.0, 1.5]),
+        np.array([ref_scipy.logcdf(1), -np.inf]),
+    )
+
+
+def test_switch_encoding_one_branch_measurable():
+    x_rv = pt.random.normal(0.5, 1, size=3)
+
+    y_rv1 = pt.switch(x_rv < 1, x_rv, 1)
+    y_rv2 = pt.switch(x_rv < 1, 1, x_rv)
+
+    y_vv1 = y_rv1.clone()
+    y_vv2 = y_rv2.clone()
+
+    logprob1 = logp(y_rv1, y_vv1)
+    logprob2 = logp(y_rv2, y_vv2)
+
+    logp_fn1 = pytensor.function([y_vv1], logprob1)
+    logp_fn2 = pytensor.function([y_vv2], logprob2)
+
+    ref_scipy = st.norm(0.5, 1)
+
+    np.testing.assert_allclose(
+        logp_fn1([1.5, 1, 0.9]),
+        np.array([-np.inf, ref_scipy.logsf(1), ref_scipy.logpdf(0.9)]),
+    )
+
+    np.testing.assert_allclose(
+        logp_fn2([0.9, 1, 1.5]), np.array([-np.inf, ref_scipy.logcdf(1), ref_scipy.logpdf(1.5)])
+    )
+
+
+def test_switch_encoding_two_branches():
+    x_rv = pt.random.normal(0.5, 1, size=4)
+
+    y_rv = pt.switch(x_rv < -1, -1, pt.switch(x_rv < 1, x_rv, 1))
+    clip_rv = pt.clip(x_rv, -1, 1)
+
+    y_vv = y_rv.clone()
+    clip_vv = y_vv.clone()
+
+    logp_switch = logp(y_rv, y_vv)
+    logp_clip = logp(clip_rv, clip_vv)
+
+    logp_fn_switch = pytensor.function([y_vv], logp_switch)
+    logp_fn_clip = pytensor.function([clip_vv], logp_clip)
+
+    test_values = [-1, 0, 1, 1.5]
+    np.testing.assert_allclose(logp_fn_switch(test_values), logp_fn_clip(test_values))
+
+
+def test_switch_encoding_nested_branches():
+    x_rv = pt.random.normal(0.5, 1, size=3)
+    y_rv = pt.switch(x_rv < -1, -1, pt.switch(x_rv < 2, x_rv, pt.switch(x_rv >= 2.5, 2, 1)))
+    # -inf to -1: -1
+    # -1 to 2: x
+    # 2 to 2.5: 1
+    # 2.5 to inf: 2
+    y_vv = y_rv.clone()
+
+    logp_switch = logp(y_rv, y_vv)
+    logp_fn_switch = pytensor.function([y_vv], logp_switch)
+
+    ref_scipy = st.norm(0.5, 1)
+
+    np.testing.assert_allclose(
+        logp_fn_switch([-2, -1.0, 0]), [-np.inf, ref_scipy.logcdf(-1), ref_scipy.logpdf(0)]
+    )
+    np.testing.assert_allclose(
+        logp_fn_switch([1.5, 2, 2.5]), [ref_scipy.logpdf(1.5), ref_scipy.logsf(2.5), -np.inf]
+    )
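+
+# Illustrative note (not part of the original patch): an encoded value `e` with
+# interval `(lo, hi)` receives the probability mass `CDF(hi) - CDF(lo)`; e.g. in
+# the nested test above, the encoding 2 covers `(2.5, inf)`, so its logp equals
+# `st.norm(0.5, 1).logsf(2.5)`.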
+
+
+def test_switch_encoding_broadcastability():
+    """Test that measurable branches and switch conditions are not allowed to be broadcast."""
+    x_rv = pt.random.normal(0.5, 1, size=2)
+
+    y_rv_valid = pt.switch(x_rv < [0.3, 0.3], pt.switch(x_rv > -0.5, x_rv, [0.1, 0.2]), 1.0)
+
+    y_rv_invalid1 = pt.switch(x_rv < [[0.3, 0.3], [0.1, 0.1]], 1.0, [0.0, 0.5])
+    y_rv_invalid2 = pt.switch(x_rv < [0.3, 0.3], x_rv, [[0.0, 0.5], x_rv])
+
+    y_vv_valid = y_rv_valid.clone()
+    y_vv_invalid1 = y_rv_invalid1.clone()
+    y_vv_invalid2 = y_rv_invalid2.clone()
+
+    y_test = [0.1, 0.2]
+    ref_scipy = st.norm(0.5, 1)
+    np.testing.assert_allclose(
+        logp(y_rv_valid, y_vv_valid).eval({y_vv_valid: y_test}),
+        np.array([ref_scipy.logcdf(-0.5)] * 2),
+    )
+
+    with pytest.raises(
+        NotImplementedError,
+        match="Logprob method not implemented",
+    ):
+        logp(y_rv_invalid1, y_vv_invalid1).eval({y_vv_invalid1: y_test})
+
+    with pytest.raises(
+        NotImplementedError,
+        match="Logprob method not implemented",
+    ):
+        logp(y_rv_invalid2, y_vv_invalid2).eval({y_vv_invalid2: y_test})
+
+
+def test_switch_measurability_source():
+    """Test failure when more than one source of measurability is present."""
+    x_rv1 = pt.random.normal(0.5, 1)
+    x_rv2 = pt.random.halfnormal(0.5, 1)
+
+    y_rv = pt.switch(x_rv1 > 1, x_rv2, 2)
+    y_vv = y_rv.clone()
+    y_vv.name = "cens_x"
+
+    x_vv1 = x_rv1.clone()
+
+    with pytest.raises(
+        NotImplementedError,
+        match="Logprob method not implemented",
+    ):
+        logp(y_rv, y_vv)
+
+    with pytest.raises(RuntimeError, match="could not be derived: {cens_x}"):
+        conditional_logp({y_rv: y_vv, x_rv1: x_vv1})
+
+
+def test_switch_discrete_fail():
+    """Test failure when discrete RVs are used in the graph."""
+    x_rv = pt.random.poisson(2)
+    y_rv = pt.switch(x_rv > 3, x_rv, 1)
+
+    y_vv = y_rv.clone()
+    y_vv_test = 1
+
+    with pytest.raises(
+        NotImplementedError,
+        match="Logprob method not implemented",
+    ):
+        logp(y_rv, y_vv).eval({y_vv: y_vv_test})
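+
+
+def test_switch_encoding_scalar_sketch():
+    # Illustrative sketch (not part of the original patch): a minimal end-to-end
+    # check that an encoded branch receives the probability mass of its interval
+    # while the measurable branch keeps the density.
+    x_rv = pt.random.normal(0, 1)
+    y_rv = pt.switch(x_rv < 0, -1.0, x_rv)
+
+    y_vv = y_rv.clone()
+    logp_fn = pytensor.function([y_vv], logp(y_rv, y_vv))
+
+    ref = st.norm(0, 1)
+    # y == -1 encodes the interval (-inf, 0), so its logp is logcdf(0)
+    np.testing.assert_allclose(logp_fn(-1.0), ref.logcdf(0))
+    # y == 0.5 falls in the measurable branch's interval (0, inf): plain density
+    np.testing.assert_allclose(logp_fn(0.5), ref.logpdf(0.5))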