From 0f14a0f537b97a13f9076dc1d6a6ec7fa4ebb162 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 5 Dec 2022 14:34:53 +0100 Subject: [PATCH 1/2] Implement reciprocal measurable transform Adds rewrite that converts divisions with measurable variables to product with reciprocals, making the reciprocal measurable transform more widely applicable. --- pymc/logprob/transforms.py | 61 +++++++++++++++++++++++++-- pymc/tests/logprob/test_transforms.py | 32 +++++++++++--- 2 files changed, 83 insertions(+), 10 deletions(-) diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index 8ae8b43240..d6a2621960 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -47,9 +47,10 @@ from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import Op from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter -from pytensor.scalar import Add, Exp, Log, Mul +from pytensor.scalar import Add, Exp, Log, Mul, Reciprocal from pytensor.scan.op import Scan -from pytensor.tensor.math import add, exp, log, mul +from pytensor.tensor.exceptions import NotScalarConstantError +from pytensor.tensor.math import add, exp, log, mul, reciprocal, true_div from pytensor.tensor.rewriting.basic import ( register_specialize, register_stabilize, @@ -318,7 +319,7 @@ def apply(self, fgraph: FunctionGraph): class MeasurableTransform(MeasurableElemwise): """A placeholder used to specify a log-likelihood for a transformed measurable variable""" - valid_scalar_types = (Exp, Log, Add, Mul) + valid_scalar_types = (Exp, Log, Add, Mul, Reciprocal) # Cannot use `transform` as name because it would clash with the property added by # the `TransformValuesRewrite` @@ -354,7 +355,36 @@ def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwa return input_logprob + jacobian -@node_rewriter([exp, log, add, mul]) +@node_rewriter([true_div]) +def measurable_div_to_reciprocal_product(fgraph, node): + """Convert divisions involving `MeasurableVariable`s to product with reciprocal.""" + + measurable_vars = [ + var for var in node.inputs if (var.owner and isinstance(var.owner.op, MeasurableVariable)) + ] + if not measurable_vars: + return None # pragma: no cover + + rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) + if rv_map_feature is None: + return None # pragma: no cover + + # Only apply this rewrite if there is one unvalued MeasurableVariable involved + if all(measurable_var in rv_map_feature.rv_values for measurable_var in measurable_vars): + return None # pragma: no cover + + numerator, denominator = node.inputs + + # Check if numerator is 1 + try: + if at.get_scalar_constant_value(numerator) == 1: + return [at.reciprocal(denominator)] + except NotScalarConstantError: + pass + return [at.mul(numerator, at.reciprocal(denominator))] + + +@node_rewriter([exp, log, add, mul, reciprocal]) def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]: """Find measurable transformations from Elemwise operators.""" @@ -414,6 +444,8 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li transform = ExpTransform() elif isinstance(scalar_op, Log): transform = LogTransform() + elif isinstance(scalar_op, Reciprocal): + transform = ReciprocalTransform() elif isinstance(scalar_op, Add): transform_inputs = (measurable_input, at.add(*other_inputs)) transform = LocTransform( @@ -436,6 +468,14 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li return [transform_out] +measurable_ir_rewrites_db.register( + "measurable_div_to_reciprocal_product", + measurable_div_to_reciprocal_product, + "basic", + "transform", +) + + measurable_ir_rewrites_db.register( "find_measurable_transforms", find_measurable_transforms, @@ -507,6 +547,19 @@ def log_jac_det(self, value, *inputs): return -at.log(value) +class ReciprocalTransform(RVTransform): + name = "reciprocal" + + def forward(self, value, *inputs): + return at.reciprocal(value) + + def backward(self, value, *inputs): + return at.reciprocal(value) + + def log_jac_det(self, value, *inputs): + return -2 * at.log(value) + + class IntervalTransform(RVTransform): name = "interval" diff --git a/pymc/tests/logprob/test_transforms.py b/pymc/tests/logprob/test_transforms.py index 8467f3ddba..07438c7b46 100644 --- a/pymc/tests/logprob/test_transforms.py +++ b/pymc/tests/logprob/test_transforms.py @@ -692,17 +692,20 @@ def test_loc_transform_rv(rv_size, loc_type): @pytest.mark.parametrize( - "rv_size, scale_type", + "rv_size, scale_type, product", [ - (None, at.scalar), - (1, at.TensorType("floatX", (True,))), - ((2, 3), at.matrix), + (None, at.scalar, True), + (1, at.TensorType("floatX", (True,)), True), + ((2, 3), at.matrix, False), ], ) -def test_scale_transform_rv(rv_size, scale_type): +def test_scale_transform_rv(rv_size, scale_type, product): scale = scale_type("scale") - y_rv = at.random.normal(0, 1, size=rv_size, name="base_rv") * scale + if product: + y_rv = at.random.normal(0, 1, size=rv_size, name="base_rv") * scale + else: + y_rv = at.random.normal(0, 1, size=rv_size, name="base_rv") / at.reciprocal(scale) y_rv.name = "y" y_vv = y_rv.clone() @@ -784,6 +787,23 @@ def test_invalid_broadcasted_transform_rv_fails(): assert False, "Should have failed before" +@pytest.mark.parametrize("numerator", (1.0, 2.0)) +def test_reciprocal_rv_transform(numerator): + shape = 3 + scale = 5 + x_rv = numerator / at.random.gamma(shape, scale) + x_rv.name = "x" + + x_vv = x_rv.clone() + x_logp_fn = pytensor.function([x_vv], joint_logprob({x_rv: x_vv})) + + x_test_val = 1.5 + assert np.isclose( + x_logp_fn(x_test_val), + sp.stats.invgamma(shape, scale=scale * numerator).logpdf(x_test_val), + ) + + def test_scan_transform(): """Test that Scan valued variables can be transformed""" From a819fcfcaee549534a65b258a4c236c8b904f06d Mon Sep 17 00:00:00 2001 From: Ricardo Date: Mon, 2 May 2022 08:46:54 +0200 Subject: [PATCH 2/2] Add rewrites for measurable negation and subtraction --- pymc/logprob/transforms.py | 55 ++++++++++++++++++++++++++- pymc/tests/logprob/test_transforms.py | 36 +++++++++++++++--- 2 files changed, 84 insertions(+), 7 deletions(-) diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index d6a2621960..54f5deb7dc 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -50,7 +50,7 @@ from pytensor.scalar import Add, Exp, Log, Mul, Reciprocal from pytensor.scan.op import Scan from pytensor.tensor.exceptions import NotScalarConstantError -from pytensor.tensor.math import add, exp, log, mul, reciprocal, true_div +from pytensor.tensor.math import add, exp, log, mul, neg, reciprocal, sub, true_div from pytensor.tensor.rewriting.basic import ( register_specialize, register_stabilize, @@ -384,6 +384,46 @@ def measurable_div_to_reciprocal_product(fgraph, node): return [at.mul(numerator, at.reciprocal(denominator))] +@node_rewriter([neg]) +def measurable_neg_to_product(fgraph, node): + """Convert negation of `MeasurableVariable`s to product with `-1`.""" + + inp = node.inputs[0] + if not (inp.owner and isinstance(inp.owner.op, MeasurableVariable)): + return None + + rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) + if rv_map_feature is None: + return None # pragma: no cover + + # Only apply this rewrite if the variable is unvalued + if inp in rv_map_feature.rv_values: + return None # pragma: no cover + + return [at.mul(inp, -1.0)] + + +@node_rewriter([sub]) +def measurable_sub_to_neg(fgraph, node): + """Convert subtraction involving `MeasurableVariable`s to addition with neg""" + measurable_vars = [ + var for var in node.inputs if (var.owner and isinstance(var.owner.op, MeasurableVariable)) + ] + if not measurable_vars: + return None # pragma: no cover + + rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) + if rv_map_feature is None: + return None # pragma: no cover + + # Only apply this rewrite if there is one unvalued MeasurableVariable involved + if all(measurable_var in rv_map_feature.rv_values for measurable_var in measurable_vars): + return None # pragma: no cover + + minuend, subtrahend = node.inputs + return [at.add(minuend, at.neg(subtrahend))] + + @node_rewriter([exp, log, add, mul, reciprocal]) def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]: """Find measurable transformations from Elemwise operators.""" @@ -475,6 +515,19 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li "transform", ) +measurable_ir_rewrites_db.register( + "measurable_neg_to_product", + measurable_neg_to_product, + "basic", + "transform", +) + +measurable_ir_rewrites_db.register( + "measurable_sub_to_neg", + measurable_sub_to_neg, + "basic", + "transform", +) measurable_ir_rewrites_db.register( "find_measurable_transforms", diff --git a/pymc/tests/logprob/test_transforms.py b/pymc/tests/logprob/test_transforms.py index 07438c7b46..cf5b4eb2f1 100644 --- a/pymc/tests/logprob/test_transforms.py +++ b/pymc/tests/logprob/test_transforms.py @@ -664,17 +664,20 @@ def test_log_transform_rv(): @pytest.mark.parametrize( - "rv_size, loc_type", + "rv_size, loc_type, addition", [ - (None, at.scalar), - (2, at.vector), - ((2, 1), at.col), + (None, at.scalar, True), + (2, at.vector, False), + ((2, 1), at.col, True), ], ) -def test_loc_transform_rv(rv_size, loc_type): +def test_loc_transform_rv(rv_size, loc_type, addition): loc = loc_type("loc") - y_rv = loc + at.random.normal(0, 1, size=rv_size, name="base_rv") + if addition: + y_rv = loc + at.random.normal(0, 1, size=rv_size, name="base_rv") + else: + y_rv = at.random.normal(0, 1, size=rv_size, name="base_rv") - at.neg(loc) y_rv.name = "y" y_vv = y_rv.clone() @@ -804,6 +807,27 @@ def test_reciprocal_rv_transform(numerator): ) +def test_negated_rv_transform(): + x_rv = -at.random.halfnormal() + x_rv.name = "x" + + x_vv = x_rv.clone() + x_logp_fn = pytensor.function([x_vv], joint_logprob({x_rv: x_vv})) + + assert np.isclose(x_logp_fn(-1.5), sp.stats.halfnorm.logpdf(1.5)) + + +def test_subtracted_rv_transform(): + # Choose base RV that is assymetric around zero + x_rv = 5.0 - at.random.normal(1.0) + x_rv.name = "x" + + x_vv = x_rv.clone() + x_logp_fn = pytensor.function([x_vv], joint_logprob({x_rv: x_vv})) + + assert np.isclose(x_logp_fn(7.3), sp.stats.norm.logpdf(5.0 - 7.3, 1.0)) + + def test_scan_transform(): """Test that Scan valued variables can be transformed"""