|
30 | 30 | from pytensor.tensor.random.basic import normal, uniform
|
31 | 31 | from pytensor.tensor.random.op import RandomVariable
|
32 | 32 | from pytensor.tensor.random.var import RandomStateSharedVariable
|
| 33 | +from pytensor.tensor.slinalg import Cholesky |
33 | 34 | from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
|
34 | 35 | from pytensor.tensor.var import TensorVariable
|
35 | 36 |
|
@@ -878,3 +879,34 @@ def replacement_fn(var, replacements):
|
878 | 879 | [new_x], _ = _replace_vars_in_graphs([x], replacement_fn=replacement_fn)
|
879 | 880 |
|
880 | 881 | assert new_x.eval() > 50
|
| 882 | + |
| 883 | + |
def test_mvnormal_no_cholesky_op():
    """
    Test that the MvNormal logp, when parameterized directly by a Cholesky
    factor, does not unnecessarily recompute the Cholesky factorization.

    Regression test for #6717.
    """
    with pm.Model() as m:
        n = 3
        sd_dist = pm.HalfNormal.dist(shape=n)
        chol, corr, sigmas = pm.LKJCholeskyCov("cov", n=n, eta=1, sd_dist=sd_dist)
        mu = np.zeros(n)
        data = np.ones((10, n))
        pm.MvNormal("y", mu=mu, chol=chol, observed=data)

    def contains_cholesky_op(fgraph):
        # True if any apply node in the compiled graph is a Cholesky decomposition.
        return any(isinstance(node.op, Cholesky) for node in fgraph.apply_nodes)

    # None of the compiled functions should re-factorize the covariance:
    # the model was already given the Cholesky factor directly.
    logp = m.compile_logp()
    assert not contains_cholesky_op(logp.f.maker.fgraph)

    dlogp = m.compile_dlogp()
    assert not contains_cholesky_op(dlogp.f.maker.fgraph)

    d2logp = m.compile_d2logp()
    assert not contains_cholesky_op(d2logp.f.maker.fgraph)

    logp_dlogp = m.logp_dlogp_function()
    assert not contains_cholesky_op(logp_dlogp._pytensor_function.maker.fgraph)
0 commit comments