Skip to content

Commit ee1657b

Browse files
committed
Add lower triangular tags to allow chol rewrites
Since pymc-devs/pytensor#303, `cholesky(L.dot(L.T))` will be rewritten to `L` if `L.tag.lower_triangular=True`. This change adds these tags where appropriate. Fixes #6717.
1 parent 93ba789 commit ee1657b

File tree

3 files changed

+43
-2
lines changed

3 files changed

+43
-2
lines changed

pymc/distributions/multivariate.py

+3
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,9 @@ def quaddist_matrix(cov=None, chol=None, tau=None, lower=True, *args, **kwargs):
132132
chol = pt.as_tensor_variable(chol)
133133
if chol.ndim != 2:
134134
raise ValueError("chol must be two dimensional.")
135+
136+
# tag as lower triangular to enable pytensor rewrites of chol(l.l') -> l
137+
chol.tag.lower_triangular = True
135138
cov = chol.dot(chol.T)
136139

137140
return cov

pymc/math.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -443,11 +443,17 @@ def expand_packed_triangular(n, packed, lower=True, diagonal_only=False):
443443
elif lower:
444444
out = pt.zeros((n, n), dtype=pytensor.config.floatX)
445445
idxs = np.tril_indices(n)
446-
return pt.set_subtensor(out[idxs], packed)
446+
# tag as lower triangular to enable pytensor rewrites
447+
out = pt.set_subtensor(out[idxs], packed)
448+
out.tag.lower_triangular = True
449+
return out
447450
elif not lower:
448451
out = pt.zeros((n, n), dtype=pytensor.config.floatX)
449452
idxs = np.triu_indices(n)
450-
return pt.set_subtensor(out[idxs], packed)
453+
# tag as upper triangular to enable pytensor rewrites
454+
out = pt.set_subtensor(out[idxs], packed)
455+
out.tag.upper_triangular = True
456+
return out
451457

452458

453459
class BatchedDiag(Op):

tests/test_pytensorf.py

+32
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from pytensor.tensor.random.basic import normal, uniform
3131
from pytensor.tensor.random.op import RandomVariable
3232
from pytensor.tensor.random.var import RandomStateSharedVariable
33+
from pytensor.tensor.slinalg import Cholesky
3334
from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
3435
from pytensor.tensor.var import TensorVariable
3536

@@ -878,3 +879,34 @@ def replacement_fn(var, replacements):
878879
[new_x], _ = _replace_vars_in_graphs([x], replacement_fn=replacement_fn)
879880

880881
assert new_x.eval() > 50
882+
883+
884+
def test_mvnormal_no_cholesky_op():
885+
"""
886+
Test MvNormal likelihood when using Cholesky factor parameterization does not unnecessarily
887+
recompute the cholesky factorization
888+
Reversion test of #6717
889+
"""
890+
with pm.Model() as m:
891+
n = 3
892+
sd_dist = pm.HalfNormal.dist(shape=n)
893+
chol, corr, sigmas = pm.LKJCholeskyCov("cov", n=n, eta=1, sd_dist=sd_dist)
894+
mu = np.zeros(n)
895+
data = np.ones((10, n))
896+
pm.MvNormal("y", mu=mu, chol=chol, observed=data)
897+
898+
contains_cholesky_op = lambda fgraph: any(
899+
isinstance(node.op, Cholesky) for node in fgraph.apply_nodes
900+
)
901+
902+
logp = m.compile_logp()
903+
assert not contains_cholesky_op(logp.f.maker.fgraph)
904+
905+
dlogp = m.compile_dlogp()
906+
assert not contains_cholesky_op(dlogp.f.maker.fgraph)
907+
908+
d2logp = m.compile_d2logp()
909+
assert not contains_cholesky_op(d2logp.f.maker.fgraph)
910+
911+
logp_dlogp = m.logp_dlogp_function()
912+
assert not contains_cholesky_op(logp_dlogp._pytensor_function.maker.fgraph)

0 commit comments

Comments
 (0)