diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py
index 903f013abe..56280fd302 100644
--- a/pymc/logprob/transforms.py
+++ b/pymc/logprob/transforms.py
@@ -42,6 +42,7 @@
 import numpy as np
 import pytensor.tensor as pt
 
+from pytensor import scan
 from pytensor.gradient import DisconnectedType, jacobian
 from pytensor.graph.basic import Apply, Node, Variable
 from pytensor.graph.features import AlreadyThere, Feature
@@ -49,21 +50,42 @@
 from pytensor.graph.op import Op
 from pytensor.graph.replace import clone_replace
 from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
-from pytensor.scalar import Abs, Add, Exp, Log, Mul, Pow, Sqr, Sqrt
+from pytensor.scalar import (
+    Abs,
+    Add,
+    Cosh,
+    Erf,
+    Erfc,
+    Erfcx,
+    Exp,
+    Log,
+    Mul,
+    Pow,
+    Sinh,
+    Sqr,
+    Sqrt,
+    Tanh,
+)
 from pytensor.scan.op import Scan
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import (
     abs,
     add,
+    cosh,
+    erf,
+    erfc,
+    erfcx,
     exp,
     log,
     mul,
     neg,
     pow,
     reciprocal,
+    sinh,
     sqr,
     sqrt,
     sub,
+    tanh,
     true_div,
 )
 from pytensor.tensor.rewriting.basic import (
@@ -122,6 +144,8 @@ def remove_TransformedVariables(fgraph, node):
 
 
 class RVTransform(abc.ABC):
+    ndim_supp = None
+
     @abc.abstractmethod
     def forward(self, value: TensorVariable, *inputs: Variable) -> TensorVariable:
         """Apply the transformation."""
@@ -135,12 +159,16 @@ def backward(
         self, value: TensorVariable, *inputs
     ) -> TensorVariable:
 
     def log_jac_det(self, value: TensorVariable, *inputs) -> TensorVariable:
         """Construct the log of the absolute value of the Jacobian determinant."""
-        # jac = pt.reshape(
-        #     gradient(pt.sum(self.backward(value, *inputs)), [value]), value.shape
-        # )
-        # return pt.log(pt.abs(jac))
-        phi_inv = self.backward(value, *inputs)
-        return pt.log(pt.abs(pt.nlinalg.det(pt.atleast_2d(jacobian(phi_inv, [value])[0]))))
+        if self.ndim_supp not in (0, 1):
+            raise NotImplementedError(
+                f"RVTransform default log_jac_det only implemented for ndim_supp in (0, 1), got {self.ndim_supp=}"
+            )
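+        # For ndim_supp == 0 the transform acts elementwise: the Jacobian is
+        # diagonal, so its log-abs-determinant reduces to the elementwise
+        # log-gradient of the inverse map. For ndim_supp == 1, fall back to the
+        # dense Jacobian determinant of the vector-valued inverse.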
+        if self.ndim_supp == 0:
+            jac = pt.reshape(pt.grad(pt.sum(self.backward(value, *inputs)), [value]), value.shape)
+            return pt.log(pt.abs(jac))
+        else:
+            phi_inv = self.backward(value, *inputs)
+            return pt.log(pt.abs(pt.nlinalg.det(pt.atleast_2d(jacobian(phi_inv, [value])[0]))))
 
 
 @node_rewriter(tracks=None)
@@ -340,7 +368,7 @@ def apply(self, fgraph: FunctionGraph):
 class MeasurableTransform(MeasurableElemwise):
     """A placeholder used to specify a log-likelihood for a transformed measurable variable"""
 
-    valid_scalar_types = (Exp, Log, Add, Mul, Pow, Abs)
+    valid_scalar_types = (Exp, Log, Add, Mul, Pow, Abs, Sinh, Cosh, Tanh, Erf, Erfc, Erfcx)
 
     # Cannot use `transform` as name because it would clash with the property added by
     # the `TransformValuesRewrite`
@@ -540,7 +568,7 @@ def measurable_sub_to_neg(fgraph, node):
     return [pt.add(minuend, pt.neg(subtrahend))]
 
 
-@node_rewriter([exp, log, add, mul, pow, abs])
+@node_rewriter([exp, log, add, mul, pow, abs, sinh, cosh, tanh, erf, erfc, erfcx])
 def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
     """Find measurable transformations from Elemwise operators."""
 
@@ -596,13 +624,20 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
     measurable_input_idx = 0
     transform_inputs: Tuple[TensorVariable, ...] = (measurable_input,)
     transform: RVTransform
-    if isinstance(scalar_op, Exp):
-        transform = ExpTransform()
-    elif isinstance(scalar_op, Log):
-        transform = LogTransform()
-    elif isinstance(scalar_op, Abs):
-        transform = AbsTransform()
-    elif isinstance(scalar_op, Pow):
+
+    transform_dict = {
+        Exp: ExpTransform(),
+        Log: LogTransform(),
+        Abs: AbsTransform(),
+        Sinh: SinhTransform(),
+        Cosh: CoshTransform(),
+        Tanh: TanhTransform(),
+        Erf: ErfTransform(),
+        Erfc: ErfcTransform(),
+        Erfcx: ErfcxTransform(),
+    }
+    transform = transform_dict.get(type(scalar_op), None)
+    if isinstance(scalar_op, Pow):
         # We only allow for the base to be measurable
         if measurable_input_idx != 0:
             return None
@@ -619,7 +654,7 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
         transform = LocTransform(
             transform_args_fn=lambda *inputs: inputs[-1],
         )
-    else:
+    elif transform is None:
         transform_inputs = (measurable_input, pt.mul(*other_inputs))
         transform = ScaleTransform(
             transform_args_fn=lambda *inputs: inputs[-1],
@@ -682,6 +717,87 @@
 
 
+class SinhTransform(RVTransform):
+    name = "sinh"
+    ndim_supp = 0
+
+    def forward(self, value, *inputs):
+        return pt.sinh(value)
+
+    def backward(self, value, *inputs):
+        return pt.arcsinh(value)
+
+
+class CoshTransform(RVTransform):
+    name = "cosh"
+    ndim_supp = 0
+
+    def forward(self, value, *inputs):
+        return pt.cosh(value)
+
+    def backward(self, value, *inputs):
+        return pt.arccosh(value)
+
+
+class TanhTransform(RVTransform):
+    name = "tanh"
+    ndim_supp = 0
+
+    def forward(self, value, *inputs):
+        return pt.tanh(value)
+
+    def backward(self, value, *inputs):
+        return pt.arctanh(value)
+
+
+class ErfTransform(RVTransform):
+    name = "erf"
+    ndim_supp = 0
+
+    def forward(self, value, *inputs):
+        return pt.erf(value)
+
+    def backward(self, value, *inputs):
+        return pt.erfinv(value)
+
+
+class ErfcTransform(RVTransform):
+    name = "erfc"
+    ndim_supp = 0
+
+    def forward(self, value, *inputs):
+        return pt.erfc(value)
+
+    def backward(self, value, *inputs):
+        return pt.erfcinv(value)
+
+
+class ErfcxTransform(RVTransform):
+    name = "erfcx"
+    ndim_supp = 0
+
+    def forward(self, value, *inputs):
+        return pt.erfcx(value)
+
+    def backward(self, value, *inputs):
+        # computes the inverse of erfcx; adapted from
+        # https://tinyurl.com/4mxfd3cz
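+        # erfcx has no closed-form inverse, so we refine a starting point with
+        # Newton-Raphson on f(x) = erfcx(x) - value, whose derivative is
+        # f'(x) = 2 * x * erfcx(x) - 2 / sqrt(pi). The switch below picks a
+        # branch-aware initial guess: erfcx(x) ~ 1 / (x * sqrt(pi)) for large
+        # positive x, while erfcx(x) grows like exp(x**2) for negative x. Note
+        # that scan passes the recurrent output first and non-sequences last.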
+        x = pt.switch(value <= 1, 1.0 / (value * pt.sqrt(np.pi)), -pt.sqrt(pt.log(value)))
+
+        def calc_delta_x(prior_result, value):
+            return prior_result - (pt.erfcx(prior_result) - value) / (
+                2 * prior_result * pt.erfcx(prior_result) - 2 / pt.sqrt(np.pi)
+            )
+
+        result, updates = scan(
+            fn=calc_delta_x,
+            outputs_info=x,
+            non_sequences=value,
+            n_steps=10,
+        )
+        return result[-1]
+
+
 class LocTransform(RVTransform):
     name = "loc"
 
diff --git a/tests/distributions/test_transform.py b/tests/distributions/test_transform.py
index a29ab16679..4acf463bfd 100644
--- a/tests/distributions/test_transform.py
+++ b/tests/distributions/test_transform.py
@@ -112,7 +112,7 @@ def check_jacobian_det(
     )
 
     for yval in domain.vals:
-        close_to(actual_ljd(yval), computed_ljd(yval), tol)
+        np.testing.assert_allclose(actual_ljd(yval), computed_ljd(yval), rtol=tol)
 
 
 def test_simplex():
diff --git a/tests/logprob/test_transforms.py b/tests/logprob/test_transforms.py
index ba24fadee1..3484b70ba4 100644
--- a/tests/logprob/test_transforms.py
+++ b/tests/logprob/test_transforms.py
@@ -49,9 +49,13 @@
 from pymc.distributions.transforms import _default_transform, log, logodds
 from pymc.logprob.abstract import MeasurableVariable, _get_measurable_outputs, _logprob
-from pymc.logprob.basic import factorized_joint_logprob
+from pymc.logprob.basic import factorized_joint_logprob, logp
 from pymc.logprob.transforms import (
     ChainedTransform,
+    CoshTransform,
+    ErfcTransform,
+    ErfcxTransform,
+    ErfTransform,
     ExpTransform,
     IntervalTransform,
     LocTransform,
@@ -59,6 +63,8 @@
     LogTransform,
     RVTransform,
     ScaleTransform,
+    SinhTransform,
+    TanhTransform,
     TransformValuesMapping,
     TransformValuesRewrite,
     transformed_variable,
@@ -327,6 +333,7 @@ def test_fallback_log_jac_det(ndim):
     class SquareTransform(RVTransform):
         name = "square"
+        ndim_supp = ndim
 
         def forward(self, value, *inputs):
             return pt.power(value, 2)
@@ -336,13 +343,31 @@ def backward(self, value, *inputs):
         return pt.sqrt(value)
 
     square_tr = SquareTransform()
 
-    value = pt.TensorType("float64", (None,) * ndim)("value")
+    value = pt.vector("value")
     value_tr = square_tr.forward(value)
     log_jac_det = square_tr.log_jac_det(value_tr)
 
-    test_value = np.full((2,) * ndim, 3)
-    expected_log_jac_det = -np.log(6) * test_value.size
-    assert np.isclose(log_jac_det.eval({value: test_value}), expected_log_jac_det)
+    test_value = np.r_[3, 4]
+    expected_log_jac_det = -np.log(2 * test_value)
+    if ndim == 1:
+        expected_log_jac_det = expected_log_jac_det.sum()
+    np.testing.assert_array_equal(log_jac_det.eval({value: test_value}), expected_log_jac_det)
+
+
+@pytest.mark.parametrize("ndim", (None, 2))
+def test_fallback_log_jac_det_undefined_ndim(ndim):
+    class SquareTransform(RVTransform):
+        name = "square"
+        ndim_supp = ndim
+
+        def forward(self, value, *inputs):
+            return pt.power(value, 2)
+
+        def backward(self, value, *inputs):
+            return pt.sqrt(value)
+
+    with pytest.raises(NotImplementedError, match=r"only implemented for ndim_supp in \(0, 1\)"):
+        SquareTransform().log_jac_det(0)
 
 
 def test_hierarchical_uniform_transform():
@@ -989,3 +1014,57 @@ def test_multivariate_transform(shift, scale):
             scale_mat @ cov @ scale_mat.T,
         ),
     )
+
+
+from pymc.testing import Rplusbig, Vector
+from tests.distributions.test_transform import check_jacobian_det
+
+
+@pytest.mark.parametrize(
+    "pt_transform, transform",
+    [
+        (pt.erf, ErfTransform()),
+        (pt.erfc, ErfcTransform()),
+        (pt.erfcx, ErfcxTransform()),
+        (pt.sinh, SinhTransform()),
+        (pt.cosh, CoshTransform()),
+        (pt.tanh, TanhTransform()),
+    ],
+)
+def test_measurable_transform_logp(pt_transform, transform):
+    base_rv = pt.random.normal(
+        0.5, 1, name="base_rv"
+    )  # Something not centered around 0 is usually better
+    rv = pt_transform(base_rv)
+
+    vv = rv.clone()
+    rv_logp = logp(rv, vv)
+
+    expected_logp = logp(base_rv, transform.backward(vv)) + transform.log_jac_det(vv)
+
+    vv_test = np.array(0.25)  # Arbitrary test value
+    np.testing.assert_almost_equal(
+        rv_logp.eval({vv: vv_test}), np.nan_to_num(expected_logp.eval({vv: vv_test}), nan=-np.inf)
+    )
+
+
+@pytest.mark.parametrize(
+    "transform",
+    [
+        ErfTransform(),
+        ErfcTransform(),
+        ErfcxTransform(),
+        SinhTransform(),
+        CoshTransform(),
+        TanhTransform(),
+    ],
+)
+def test_check_jac_det(transform):
+    check_jacobian_det(
+        transform,
+        Vector(Rplusbig, 2),
+        pt.dvector,
+        [0.1, 0.1],
+        elemwise=True,
+        rv_var=pt.random.normal(0.5, 1, name="base_rv"),
+    )
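+
+
+# Round-trip sanity check (an illustrative addition, not part of the original
+# change set): backward(forward(x)) should recover x at a point inside each
+# transform's domain; for erfcx this also exercises the Newton iteration.
+@pytest.mark.parametrize(
+    "transform",
+    [
+        ErfTransform(),
+        ErfcTransform(),
+        ErfcxTransform(),
+        SinhTransform(),
+        CoshTransform(),
+        TanhTransform(),
+    ],
+)
+def test_forward_backward_round_trip(transform):
+    x = np.array(0.75)  # arbitrary point in the domain of all six transforms
+    x_rt = transform.backward(transform.forward(pt.constant(x))).eval()
+    np.testing.assert_allclose(x_rt, x, rtol=1e-5)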