From 93ba789c8fb2741bc491a7f8658c0e70adcdc938 Mon Sep 17 00:00:00 2001
From: David Horsley
Date: Wed, 14 Jun 2023 17:44:29 +1000
Subject: [PATCH 1/2] Apply rewrites to logp before applying grad

This adds safe rewrites to the logp graph before the grad operator is
applied.

This is motivated by #6717, where expensive `cholesky(L.dot(L.T))`
operations are removed by these rewrites. If they remain in the logp
graph when the grad is taken, the resulting dlogp graph contains
unnecessary operations. More generally, this may improve the stability
and performance of the gradient of logp in other situations.
---
 pymc/model.py     | 5 +++++
 pymc/pytensorf.py | 7 +++++++
 2 files changed, 12 insertions(+)

diff --git a/pymc/model.py b/pymc/model.py
index ec94772192..9aeecf5e3d 100644
--- a/pymc/model.py
+++ b/pymc/model.py
@@ -75,6 +75,7 @@
     hessian,
     inputvars,
     replace_rvs_by_values,
+    rewrite_pregrad,
 )
 from pymc.util import (
     UNSET,
@@ -381,6 +382,8 @@ def __init__(
             self._extra_vars_shared[var.name] = shared
             givens.append((var, shared))
 
+        cost = rewrite_pregrad(cost)
+
         if compute_grads:
             grads = pytensor.grad(cost, grad_vars, disconnected_inputs="ignore")
             for grad_wrt, var in zip(grads, grad_vars):
@@ -824,6 +827,7 @@ def dlogp(
             )
 
         cost = self.logp(jacobian=jacobian)
+        cost = rewrite_pregrad(cost)
         return gradient(cost, value_vars)
 
     def d2logp(
@@ -862,6 +866,7 @@ def d2logp(
             )
 
         cost = self.logp(jacobian=jacobian)
+        cost = rewrite_pregrad(cost)
         return hessian(cost, value_vars)
 
     @property
diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py
index 81e26400e1..04ed1bdad2 100644
--- a/pymc/pytensorf.py
+++ b/pymc/pytensorf.py
@@ -1228,3 +1228,10 @@ def constant_fold(
     return tuple(
         folded_x.data if isinstance(folded_x, Constant) else folded_x for folded_x in folded_xs
     )
+
+
+def rewrite_pregrad(graph):
+    """Apply simplifying or stabilizing rewrites to the graph that are safe to use
+    pre-grad.
+    """
+    return rewrite_graph(graph, include=("canonicalize", "stabilize"))

From ee1657be26bbe4403cbf4f4faa114b348bb542f5 Mon Sep 17 00:00:00 2001
From: David Horsley
Date: Wed, 14 Jun 2023 17:46:29 +1000
Subject: [PATCH 2/2] Add lower triangular tags to allow chol rewrites

Since pymc-devs/pytensor#303, `cholesky(L.dot(L.T))` will be rewritten
to `L` if `L.tag.lower_triangular=True`. This change adds these tags
where appropriate.

Fixes #6717.
---
 pymc/distributions/multivariate.py |  3 +++
 pymc/math.py                       | 10 ++++++++--
 tests/test_pytensorf.py            | 32 ++++++++++++++++++++++++++++++
 3 files changed, 43 insertions(+), 2 deletions(-)

diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py
index ba613967c1..476fa14c2d 100644
--- a/pymc/distributions/multivariate.py
+++ b/pymc/distributions/multivariate.py
@@ -132,6 +132,9 @@ def quaddist_matrix(cov=None, chol=None, tau=None, lower=True, *args, **kwargs):
         chol = pt.as_tensor_variable(chol)
         if chol.ndim != 2:
             raise ValueError("chol must be two dimensional.")
+
+        # tag as lower triangular to enable pytensor rewrites of chol(l.l') -> l
+        chol.tag.lower_triangular = True
         cov = chol.dot(chol.T)
         return cov
 
diff --git a/pymc/math.py b/pymc/math.py
index 688186c51c..2f8520b0ec 100644
--- a/pymc/math.py
+++ b/pymc/math.py
@@ -443,11 +443,17 @@ def expand_packed_triangular(n, packed, lower=True, diagonal_only=False):
     elif lower:
         out = pt.zeros((n, n), dtype=pytensor.config.floatX)
         idxs = np.tril_indices(n)
-        return pt.set_subtensor(out[idxs], packed)
+        # tag as lower triangular to enable pytensor rewrites
+        out = pt.set_subtensor(out[idxs], packed)
+        out.tag.lower_triangular = True
+        return out
     elif not lower:
         out = pt.zeros((n, n), dtype=pytensor.config.floatX)
         idxs = np.triu_indices(n)
-        return pt.set_subtensor(out[idxs], packed)
+        # tag as upper triangular to enable pytensor rewrites
+        out = pt.set_subtensor(out[idxs], packed)
+        out.tag.upper_triangular = True
+        return out
 
 
 class BatchedDiag(Op):
diff --git a/tests/test_pytensorf.py b/tests/test_pytensorf.py
index 865047a8d7..94cf75dc7d 100644
--- a/tests/test_pytensorf.py
+++ b/tests/test_pytensorf.py
@@ -30,6 +30,7 @@
 from pytensor.tensor.random.basic import normal, uniform
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.random.var import RandomStateSharedVariable
+from pytensor.tensor.slinalg import Cholesky
 from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
 from pytensor.tensor.var import TensorVariable
 
@@ -878,3 +879,34 @@ def replacement_fn(var, replacements):
 
     [new_x], _ = _replace_vars_in_graphs([x], replacement_fn=replacement_fn)
     assert new_x.eval() > 50
+
+
+def test_mvnormal_no_cholesky_op():
+    """
+    Test that the MvNormal likelihood, parameterized with a Cholesky factor, does not
+    unnecessarily recompute the Cholesky factorization.
+    Regression test for #6717.
+    """
+    with pm.Model() as m:
+        n = 3
+        sd_dist = pm.HalfNormal.dist(shape=n)
+        chol, corr, sigmas = pm.LKJCholeskyCov("cov", n=n, eta=1, sd_dist=sd_dist)
+        mu = np.zeros(n)
+        data = np.ones((10, n))
+        pm.MvNormal("y", mu=mu, chol=chol, observed=data)
+
+    contains_cholesky_op = lambda fgraph: any(
+        isinstance(node.op, Cholesky) for node in fgraph.apply_nodes
+    )
+
+    logp = m.compile_logp()
+    assert not contains_cholesky_op(logp.f.maker.fgraph)
+
+    dlogp = m.compile_dlogp()
+    assert not contains_cholesky_op(dlogp.f.maker.fgraph)
+
+    d2logp = m.compile_d2logp()
+    assert not contains_cholesky_op(d2logp.f.maker.fgraph)
+
+    logp_dlogp = m.logp_dlogp_function()
+    assert not contains_cholesky_op(logp_dlogp._pytensor_function.maker.fgraph)
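
The sketch below illustrates the behaviour the two patches combine to enable:
once a matrix carries the lower_triangular tag, the canonicalize/stabilize
rewrites applied by rewrite_pregrad can drop the redundant Cholesky before the
gradient is taken. This is a minimal, hypothetical example rather than part of
the patches: it assumes the tag-based cholesky(L.dot(L.T)) -> L rewrite from
pymc-devs/pytensor#303 is available and fires for a manually tagged matrix, and
it reuses the rewrite_pregrad helper added in PATCH 1/2.

    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.slinalg import Cholesky, cholesky

    from pymc.pytensorf import rewrite_pregrad  # helper added in PATCH 1/2

    # Stand-in for the matrices tagged by quaddist_matrix / expand_packed_triangular.
    L = pt.matrix("L")
    L.tag.lower_triangular = True

    # A cost containing the redundant cholesky(L @ L.T) pattern.
    cost = cholesky(L.dot(L.T)).sum()

    # Rewriting before grad removes the Cholesky op, so the compiled gradient
    # graph should not recompute the factorization (mirrors the new test).
    dcost = pytensor.grad(rewrite_pregrad(cost), L)
    fn = pytensor.function([L], dcost)
    assert not any(isinstance(node.op, Cholesky) for node in fn.maker.fgraph.apply_nodes)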