diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py
index ba613967c1..476fa14c2d 100644
--- a/pymc/distributions/multivariate.py
+++ b/pymc/distributions/multivariate.py
@@ -132,6 +132,9 @@ def quaddist_matrix(cov=None, chol=None, tau=None, lower=True, *args, **kwargs):
         chol = pt.as_tensor_variable(chol)
         if chol.ndim != 2:
             raise ValueError("chol must be two dimensional.")
+
+        # tag as lower triangular to enable pytensor rewrites of chol(l.l') -> l
+        chol.tag.lower_triangular = True
         cov = chol.dot(chol.T)
 
     return cov
diff --git a/pymc/math.py b/pymc/math.py
index 688186c51c..2f8520b0ec 100644
--- a/pymc/math.py
+++ b/pymc/math.py
@@ -443,11 +443,17 @@ def expand_packed_triangular(n, packed, lower=True, diagonal_only=False):
     elif lower:
         out = pt.zeros((n, n), dtype=pytensor.config.floatX)
         idxs = np.tril_indices(n)
-        return pt.set_subtensor(out[idxs], packed)
+        # tag as lower triangular to enable pytensor rewrites
+        out = pt.set_subtensor(out[idxs], packed)
+        out.tag.lower_triangular = True
+        return out
     elif not lower:
         out = pt.zeros((n, n), dtype=pytensor.config.floatX)
         idxs = np.triu_indices(n)
-        return pt.set_subtensor(out[idxs], packed)
+        # tag as upper triangular to enable pytensor rewrites
+        out = pt.set_subtensor(out[idxs], packed)
+        out.tag.upper_triangular = True
+        return out
 
 
 class BatchedDiag(Op):
diff --git a/pymc/model.py b/pymc/model.py
index ec94772192..9aeecf5e3d 100644
--- a/pymc/model.py
+++ b/pymc/model.py
@@ -75,6 +75,7 @@
     hessian,
     inputvars,
     replace_rvs_by_values,
+    rewrite_pregrad,
 )
 from pymc.util import (
     UNSET,
@@ -381,6 +382,8 @@ def __init__(
             self._extra_vars_shared[var.name] = shared
             givens.append((var, shared))
 
+        cost = rewrite_pregrad(cost)
+
         if compute_grads:
             grads = pytensor.grad(cost, grad_vars, disconnected_inputs="ignore")
             for grad_wrt, var in zip(grads, grad_vars):
@@ -824,6 +827,7 @@ def dlogp(
                     )
 
         cost = self.logp(jacobian=jacobian)
+        cost = rewrite_pregrad(cost)
         return gradient(cost, value_vars)
 
     def d2logp(
@@ -862,6 +866,7 @@ def d2logp(
                     )
 
         cost = self.logp(jacobian=jacobian)
+        cost = rewrite_pregrad(cost)
         return hessian(cost, value_vars)
 
     @property
diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py
index 81e26400e1..04ed1bdad2 100644
--- a/pymc/pytensorf.py
+++ b/pymc/pytensorf.py
@@ -1228,3 +1228,10 @@ def constant_fold(
     return tuple(
         folded_x.data if isinstance(folded_x, Constant) else folded_x for folded_x in folded_xs
     )
+
+
+def rewrite_pregrad(graph):
+    """Apply simplifying or stabilizing rewrites to graph that are safe to use
+    pre-grad.
+    """
+    return rewrite_graph(graph, include=("canonicalize", "stabilize"))
diff --git a/tests/test_pytensorf.py b/tests/test_pytensorf.py
index 865047a8d7..94cf75dc7d 100644
--- a/tests/test_pytensorf.py
+++ b/tests/test_pytensorf.py
@@ -30,6 +30,7 @@
 from pytensor.tensor.random.basic import normal, uniform
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.random.var import RandomStateSharedVariable
+from pytensor.tensor.slinalg import Cholesky
 from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
 from pytensor.tensor.var import TensorVariable
 
@@ -878,3 +879,34 @@ def replacement_fn(var, replacements):
 
     [new_x], _ = _replace_vars_in_graphs([x], replacement_fn=replacement_fn)
     assert new_x.eval() > 50
+
+
+def test_mvnormal_no_cholesky_op():
+    """
+    Test MvNormal likelihood when using Cholesky factor parameterization does not unnecessarily
+    recompute the cholesky factorization
+    Reversion test of #6717
+    """
+    with pm.Model() as m:
+        n = 3
+        sd_dist = pm.HalfNormal.dist(shape=n)
+        chol, corr, sigmas = pm.LKJCholeskyCov("cov", n=n, eta=1, sd_dist=sd_dist)
+        mu = np.zeros(n)
+        data = np.ones((10, n))
+        pm.MvNormal("y", mu=mu, chol=chol, observed=data)
+
+    contains_cholesky_op = lambda fgraph: any(
+        isinstance(node.op, Cholesky) for node in fgraph.apply_nodes
+    )
+
+    logp = m.compile_logp()
+    assert not contains_cholesky_op(logp.f.maker.fgraph)
+
+    dlogp = m.compile_dlogp()
+    assert not contains_cholesky_op(dlogp.f.maker.fgraph)
+
+    d2logp = m.compile_d2logp()
+    assert not contains_cholesky_op(d2logp.f.maker.fgraph)
+
+    logp_dlogp = m.logp_dlogp_function()
+    assert not contains_cholesky_op(logp_dlogp._pytensor_function.maker.fgraph)