
BUG: MvNormal logp recomputes Cholesky factorization #6717


Closed
dehorsley opened this issue May 15, 2023 · 3 comments

@dehorsley
Contributor

dehorsley commented May 15, 2023

Describe the issue:

When an observed MvNormal is given a Cholesky factor L, e.g.

pm.MvNormal('y', mu=np.zeros(2), chol=chol, observed=np.zeros(2))

its logp function unnecessarily computes the matrix product L.dot(L.T) and then recomputes the Cholesky factor of that product.

This is expensive for large matrices, since both the matrix product and the Cholesky factorization scale as O(n^3) for an n×n matrix.
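A quick NumPy check (a sketch, not PyMC code) of why the round trip is redundant: for a lower-triangular `L` with a strictly positive diagonal, the Cholesky factor of `L @ L.T` is `L` itself, so the graph pays for a product and a factorization only to get its input back.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 4

# Build a random lower-triangular L with a strictly positive diagonal,
# i.e. a valid Cholesky factor (the kind LKJCholeskyCov produces).
L = np.tril(rng.normal(size=(n, n)))
L[np.diag_indices(n)] = np.abs(L[np.diag_indices(n)]) + 0.5

cov = L @ L.T                      # what the logp graph builds from chol
L_again = np.linalg.cholesky(cov)  # ...and then factorizes again

print(np.allclose(L_again, L))     # True: the factorization just recovers L
```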

This originates in the PyMC 4 changes, where both the Cholesky and the precision-matrix parameterizations were changed to be transformed into a covariance-matrix parameterization. I suspect there are some performance improvements to be had for the precision matrix as well.
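For the precision parameterization, a sketch of how a logp could avoid inverting `tau` at all, assuming a dense precision matrix: factor it once as `tau = U @ U.T`, then use `log|Sigma| = -log|tau| = -2 * sum(log(diag(U)))` and the quadratic form `(x - mu) @ tau @ (x - mu) = ||U.T @ (x - mu)||^2`. The function names are illustrative; the check compares against the usual covariance form.

```python
import numpy as np

LOG_2PI = np.log(2 * np.pi)

def mvnormal_logp_cov(x, mu, cov):
    """Reference logp from a covariance matrix."""
    d = x - mu
    _, logdet = np.linalg.slogdet(cov)
    return -0.5 * (len(x) * LOG_2PI + logdet + d @ np.linalg.solve(cov, d))

def mvnormal_logp_tau(x, mu, tau):
    """logp from a precision matrix, without ever forming inv(tau)."""
    d = x - mu
    U = np.linalg.cholesky(tau)  # tau = U @ U.T
    z = U.T @ d                  # d @ tau @ d == z @ z
    # 0.5 * log|tau| = sum(log(diag(U)))
    return -0.5 * len(x) * LOG_2PI + np.sum(np.log(np.diag(U))) - 0.5 * z @ z

rng = np.random.default_rng(1)
A = rng.normal(size=(3, 3))
cov = A @ A.T + 3 * np.eye(3)
x, mu = rng.normal(size=3), np.zeros(3)

print(np.isclose(mvnormal_logp_tau(x, mu, np.linalg.inv(cov)),
                 mvnormal_logp_cov(x, mu, cov)))  # True
```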

Reproducible code example:

```python
import numpy as np
import pymc as pm
import pytensor

with pm.Model() as m:
    chol, corr, sigmas = pm.LKJCholeskyCov('cov', n=2, eta=1, sd_dist=pm.HalfNormal.dist())
    pm.MvNormal('y', mu=np.zeros(2), chol=chol, observed=np.zeros(2))

pytensor.dprint(m.logp())
```

Error message:

The region of interest from the logp dprint is excerpted here; the full logp graph is below. The graph computes L.dot(L.T) and then recomputes cholesky(L.dot(L.T)).

       | |     |   |   | |     | |Cholesky{lower=True, destructive=False, on_error='nan'} [id GL]
       | |     |   |   | |     |   |dot [id GM]
       | |     |   |   | |     |     |AdvancedIncSubtensor{inplace=False,  set_instead_of_inc=True} [id GN]
       | |     |   |   | |     |     | |Alloc [id GO]
       | |     |   |   | |     |     | | |TensorConstant{0.0} [id GP]
       | |     |   |   | |     |     | | |TensorConstant{2} [id GQ]
       | |     |   |   | |     |     | | |TensorConstant{2} [id GQ]
       | |     |   |   | |     |     | |TransformedVariable [id CB] 'cov_cholesky-cov-packed___cholesky-cov-packed'
       | |     |   |   | |     |     | |TensorConstant{[0 1 1]} [id GR]
       | |     |   |   | |     |     | |TensorConstant{[0 0 1]} [id GS]
       | |     |   |   | |     |     |InplaceDimShuffle{1,0} [id GT]
       | |     |   |   | |     |       |AdvancedIncSubtensor{inplace=False,  set_instead_of_inc=True} [id GN]

Sum{acc_dtype=float64} [id A] '__logp'
 |MakeVector{dtype='float64'} [id B]
   |Sum{acc_dtype=float64} [id C]
   | |Elemwise{add,no_inplace} [id D] 'cov_cholesky-cov-packed___logprob'
   |   |Elemwise{add,no_inplace} [id E]
   |   | |Elemwise{add,no_inplace} [id F]
   |   | | |Elemwise{add,no_inplace} [id G]
   |   | | | |Elemwise{add,no_inplace} [id H]
   |   | | | | |Sum{acc_dtype=float64} [id I]
   |   | | | | | |Elemwise{gammaln,no_inplace} [id J]
   |   | | | | |   |Elemwise{mul,no_inplace} [id K]
   |   | | | | |     |InplaceDimShuffle{x} [id L]
   |   | | | | |     | |TensorConstant{2.0} [id M]
   |   | | | | |     |ARange{dtype='int64'} [id N]
   |   | | | | |       |TensorConstant{1} [id O]
   |   | | | | |       |TensorConstant{1} [id P]
   |   | | | | |       |TensorConstant{1} [id Q]
   |   | | | | |Elemwise{sub,no_inplace} [id R]
   |   | | | |   |Elemwise{add,no_inplace} [id S]
   |   | | | |   | |Elemwise{add,no_inplace} [id T]
   |   | | | |   | | |Elemwise{mul,no_inplace} [id U]
   |   | | | |   | | | |TensorConstant{0.0} [id V]
   |   | | | |   | | | |Elemwise{log,no_inplace} [id W]
   |   | | | |   | | |   |TensorConstant{3.141592653589793} [id X]
   |   | | | |   | | |Elemwise{mul,no_inplace} [id Y]
   |   | | | |   | |   |TensorConstant{1.0} [id Z]
   |   | | | |   | |   |Elemwise{log,no_inplace} [id BA]
   |   | | | |   | |     |TensorConstant{2.0} [id BB]
   |   | | | |   | |Elemwise{mul,no_inplace} [id BC]
   |   | | | |   |   |TensorConstant{2} [id BD]
   |   | | | |   |   |Elemwise{gammaln,no_inplace} [id BE]
   |   | | | |   |     |TensorConstant{1.0} [id BF]
   |   | | | |   |Elemwise{mul,no_inplace} [id BG]
   |   | | | |     |TensorConstant{1} [id BH]
   |   | | | |     |Elemwise{gammaln,no_inplace} [id BI]
   |   | | | |       |TensorConstant{2} [id BJ]
   |   | | | |Sum{acc_dtype=float64} [id BK]
   |   | | |   |Elemwise{mul,no_inplace} [id BL]
   |   | | |     |Elemwise{sub,no_inplace} [id BM]
   |   | | |     | |InplaceDimShuffle{x} [id BN]
   |   | | |     | | |Elemwise{add,no_inplace} [id BO]
   |   | | |     | |   |Elemwise{sub,no_inplace} [id BP]
   |   | | |     | |   | |Elemwise{mul,no_inplace} [id BQ]
   |   | | |     | |   | | |TensorConstant{2} [id BR]
   |   | | |     | |   | | |TensorConstant{1.0} [id BS]
   |   | | |     | |   | |TensorConstant{3} [id BT]
   |   | | |     | |   |TensorConstant{2} [id BU]
   |   | | |     | |ARange{dtype='int64'} [id BV]
   |   | | |     |   |TensorConstant{0} [id BW]
   |   | | |     |   |TensorConstant{2} [id BU]
   |   | | |     |   |TensorConstant{1} [id BX]
   |   | | |     |Elemwise{log,no_inplace} [id BY]
   |   | | |       |Elemwise{true_div,no_inplace} [id BZ]
   |   | | |         |AdvancedSubtensor [id CA]
   |   | | |         | |TransformedVariable [id CB] 'cov_cholesky-cov-packed___cholesky-cov-packed'
   |   | | |         | | |AdvancedIncSubtensor{inplace=False,  set_instead_of_inc=True} [id CC]
   |   | | |         | | | |cov_cholesky-cov-packed__ [id CD]
   |   | | |         | | | |Elemwise{exp,no_inplace} [id CE]
   |   | | |         | | | | |AdvancedSubtensor [id CF]
   |   | | |         | | | |   |cov_cholesky-cov-packed__ [id CD]
   |   | | |         | | | |   |Elemwise{sub,no_inplace} [id CG]
   |   | | |         | | | |     |CumOp{None, add} [id CH]
   |   | | |         | | | |     | |ARange{dtype='int64'} [id CI]
   |   | | |         | | | |     |   |TensorConstant{1} [id CJ]
   |   | | |         | | | |     |   |Elemwise{add,no_inplace} [id CK]
   |   | | |         | | | |     |   | |TensorConstant{2} [id BU]
   |   | | |         | | | |     |   | |TensorConstant{1} [id CL]
   |   | | |         | | | |     |   |TensorConstant{1} [id CM]
   |   | | |         | | | |     |InplaceDimShuffle{x} [id CN]
   |   | | |         | | | |       |TensorConstant{1} [id CO]
   |   | | |         | | | |Elemwise{sub,no_inplace} [id CG]
   |   | | |         | | |cov_cholesky-cov-packed__ [id CD]
   |   | | |         | |Elemwise{sub,no_inplace} [id CP]
   |   | | |         |   |CumOp{None, add} [id CQ]
   |   | | |         |   | |ARange{dtype='int64'} [id CR]
   |   | | |         |   |   |TensorConstant{1} [id CS]
   |   | | |         |   |   |Elemwise{add,no_inplace} [id CT]
   |   | | |         |   |   | |TensorConstant{2} [id BU]
   |   | | |         |   |   | |TensorConstant{1} [id CU]
   |   | | |         |   |   |TensorConstant{1} [id CV]
   |   | | |         |   |InplaceDimShuffle{x} [id CW]
   |   | | |         |     |TensorConstant{1} [id CX]
   |   | | |         |Elemwise{sqrt,no_inplace} [id CY]
   |   | | |           |IncSubtensor{Inc;int64::} [id CZ]
   |   | | |             |IncSubtensor{Inc;int64} [id DA]
   |   | | |             | |Alloc [id DB]
   |   | | |             | | |TensorConstant{0.0} [id DC]
   |   | | |             | | |Subtensor{int64} [id DD]
   |   | | |             | |   |InplaceDimShuffle{x} [id DE]
   |   | | |             | |   | |TensorConstant{2} [id BU]
   |   | | |             | |   |ScalarConstant{0} [id DF]
   |   | | |             | |Elemwise{pow,no_inplace} [id DG]
   |   | | |             | | |Subtensor{int64} [id DH]
   |   | | |             | | | |TransformedVariable [id CB] 'cov_cholesky-cov-packed___cholesky-cov-packed'
   |   | | |             | | | |ScalarConstant{0} [id DI]
   |   | | |             | | |TensorConstant{2} [id DJ]
   |   | | |             | |ScalarConstant{0} [id DK]
   |   | | |             |Elemwise{sub,no_inplace} [id DL]
   |   | | |             | |AdvancedSubtensor [id DM]
   |   | | |             | | |CumOp{None, add} [id DN]
   |   | | |             | | | |Elemwise{pow,no_inplace} [id DO]
   |   | | |             | | |   |TransformedVariable [id CB] 'cov_cholesky-cov-packed___cholesky-cov-packed'
   |   | | |             | | |   |InplaceDimShuffle{x} [id DP]
   |   | | |             | | |     |TensorConstant{2} [id DQ]
   |   | | |             | | |Subtensor{int64::} [id DR]
   |   | | |             | |   |Elemwise{sub,no_inplace} [id CP]
   |   | | |             | |   |ScalarConstant{1} [id DS]
   |   | | |             | |AdvancedSubtensor [id DT]
   |   | | |             |   |CumOp{None, add} [id DN]
   |   | | |             |   |Subtensor{:int64:} [id DU]
   |   | | |             |     |Elemwise{sub,no_inplace} [id CP]
   |   | | |             |     |ScalarConstant{-1} [id DV]
   |   | | |             |ScalarConstant{1} [id DW]
   |   | | |Sum{acc_dtype=float64} [id DX]
   |   | |   |Check{sigma > 0} [id DY]
   |   | |     |Elemwise{switch,no_inplace} [id DZ]
   |   | |     | |Elemwise{ge,no_inplace} [id EA]
   |   | |     | | |Elemwise{sqrt,no_inplace} [id CY]
   |   | |     | | |InplaceDimShuffle{x} [id EB]
   |   | |     | |   |TensorConstant{0.0} [id EC]
   |   | |     | |Elemwise{sub,no_inplace} [id ED]
   |   | |     | | |Elemwise{add,no_inplace} [id EE]
   |   | |     | | | |Elemwise{mul,no_inplace} [id EF]
   |   | |     | | | | |InplaceDimShuffle{x} [id EG]
   |   | |     | | | | | |TensorConstant{-0.5} [id EH]
   |   | |     | | | | |Elemwise{pow,no_inplace} [id EI]
   |   | |     | | | |   |Elemwise{true_div,no_inplace} [id EJ]
   |   | |     | | | |   | |Elemwise{sub,no_inplace} [id EK]
   |   | |     | | | |   | | |Elemwise{sqrt,no_inplace} [id CY]
   |   | |     | | | |   | | |InplaceDimShuffle{x} [id EL]
   |   | |     | | | |   | |   |TensorConstant{0.0} [id EC]
   |   | |     | | | |   | |InplaceDimShuffle{x} [id EM]
   |   | |     | | | |   |   |TensorConstant{1.0} [id BS]
   |   | |     | | | |   |InplaceDimShuffle{x} [id EN]
   |   | |     | | | |     |TensorConstant{2} [id EO]
   |   | |     | | | |InplaceDimShuffle{x} [id EP]
   |   | |     | | |   |Elemwise{log,no_inplace} [id EQ]
   |   | |     | | |     |Elemwise{sqrt,no_inplace} [id ER]
   |   | |     | | |       |TensorConstant{0.6366197723675814} [id ES]
   |   | |     | | |InplaceDimShuffle{x} [id ET]
   |   | |     | |   |Elemwise{log,no_inplace} [id EU]
   |   | |     | |     |TensorConstant{1.0} [id BS]
   |   | |     | |InplaceDimShuffle{x} [id EV]
   |   | |     |   |TensorConstant{-inf} [id EW]
   |   | |     |All [id EX]
   |   | |       |MakeVector{dtype='bool'} [id EY]
   |   | |         |All [id EZ]
   |   | |           |Elemwise{gt,no_inplace} [id FA]
   |   | |             |TensorConstant{1.0} [id BS]
   |   | |             |TensorConstant{0} [id FB]
   |   | |Sum{acc_dtype=float64} [id FC]
   |   |   |Elemwise{sub,no_inplace} [id FD]
   |   |     |Elemwise{log,no_inplace} [id FE]
   |   |     | |Elemwise{true_div,no_inplace} [id BZ]
   |   |     |Elemwise{mul,no_inplace} [id FF]
   |   |       |ARange{dtype='int64'} [id FG]
   |   |       | |TensorConstant{0} [id FH]
   |   |       | |TensorConstant{2} [id BU]
   |   |       | |TensorConstant{1} [id FI]
   |   |       |Elemwise{log,no_inplace} [id FJ]
   |   |         |Elemwise{sqrt,no_inplace} [id CY]
   |   |Elemwise{identity} [id FK] 'cov_cholesky-cov-packed___cholesky-cov-packed_jacobian'
   |     |Sum{axis=[0], acc_dtype=float64} [id FL]
   |       |AdvancedSubtensor [id FM]
   |         |cov_cholesky-cov-packed__ [id CD]
   |         |Elemwise{sub,no_inplace} [id CG]
   |Sum{acc_dtype=float64} [id FN]
     |Check{posdef} [id FO] 'y_logprob'
       |Elemwise{sub,no_inplace} [id FP]
       | |Elemwise{sub,no_inplace} [id FQ]
       | | |Elemwise{mul,no_inplace} [id FR]
       | | | |Elemwise{mul,no_inplace} [id FS]
       | | | | |TensorConstant{-0.5} [id FT]
       | | | | |Elemwise{Cast{float64}} [id FU]
       | | | |   |Subtensor{int64} [id FV]
       | | | |     |TensorConstant{(1,) of 2} [id FW]
       | | | |     |ScalarConstant{-1} [id FX]
       | | | |TensorConstant{1.8378770664093453} [id FY]
       | | |Elemwise{mul,no_inplace} [id FZ]
       | |   |TensorConstant{0.5} [id GA]
       | |   |Subtensor{int64} [id GB]
       | |     |Sum{axis=[1], acc_dtype=float64} [id GC]
       | |     | |Elemwise{pow,no_inplace} [id GD]
       | |     |   |InplaceDimShuffle{1,0} [id GE]
       | |     |   | |SolveTriangular{lower=True, trans=0, unit_diagonal=False, check_finite=True} [id GF]
       | |     |   |   |Elemwise{switch,no_inplace} [id GG]
       | |     |   |   | |InplaceDimShuffle{x,x} [id GH]
       | |     |   |   | | |All [id GI]
       | |     |   |   | |   |Elemwise{gt,no_inplace} [id GJ]
       | |     |   |   | |     |ExtractDiag{offset=0, axis1=0, axis2=1, view=False} [id GK]
       | |     |   |   | |     | |Cholesky{lower=True, destructive=False, on_error='nan'} [id GL]
       | |     |   |   | |     |   |dot [id GM]
       | |     |   |   | |     |     |AdvancedIncSubtensor{inplace=False,  set_instead_of_inc=True} [id GN]
       | |     |   |   | |     |     | |Alloc [id GO]
       | |     |   |   | |     |     | | |TensorConstant{0.0} [id GP]
       | |     |   |   | |     |     | | |TensorConstant{2} [id GQ]
       | |     |   |   | |     |     | | |TensorConstant{2} [id GQ]
       | |     |   |   | |     |     | |TransformedVariable [id CB] 'cov_cholesky-cov-packed___cholesky-cov-packed'
       | |     |   |   | |     |     | |TensorConstant{[0 1 1]} [id GR]
       | |     |   |   | |     |     | |TensorConstant{[0 0 1]} [id GS]
       | |     |   |   | |     |     |InplaceDimShuffle{1,0} [id GT]
       | |     |   |   | |     |       |AdvancedIncSubtensor{inplace=False,  set_instead_of_inc=True} [id GN]
       | |     |   |   | |     |InplaceDimShuffle{x} [id GU]
       | |     |   |   | |       |TensorConstant{0} [id GV]
       | |     |   |   | |Cholesky{lower=True, destructive=False, on_error='nan'} [id GL]
       | |     |   |   | |InplaceDimShuffle{x,x} [id GW]
       | |     |   |   |   |TensorConstant{1} [id GX]
       | |     |   |   |InplaceDimShuffle{1,0} [id GY]
       | |     |   |     |Elemwise{sub,no_inplace} [id GZ]
       | |     |   |       |InplaceDimShuffle{x,0} [id HA]
       | |     |   |       | |y{(2,) of 0.0} [id HB]
       | |     |   |       |InplaceDimShuffle{x,0} [id HC]
       | |     |   |         |TensorConstant{(2,) of 0.0} [id HD]
       | |     |   |InplaceDimShuffle{x,x} [id HE]
       | |     |     |TensorConstant{2} [id HF]
       | |     |ScalarConstant{0} [id HG]
       | |Sum{acc_dtype=float64} [id HH]
       |   |Elemwise{log,no_inplace} [id HI]
       |     |ExtractDiag{offset=0, axis1=0, axis2=1, view=False} [id GK]
       |All [id HJ]
         |MakeVector{dtype='bool'} [id HK]
           |All [id HL]
             |All [id GI]
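Reading the `y_logprob` branch above, the logp evaluates `-0.5 * k * log(2*pi) - 0.5 * sum(solve_triangular(L, x - mu)**2) - sum(log(diag(L)))`, where `L` is the recomputed `cholesky(dot(chol, chol.T))` and `1.8378770664093453` is `log(2*pi)`. A NumPy sketch of the same formula fed the user's factor directly, so no product or refactorization is needed. `np.linalg.solve` stands in for the triangular solve for simplicity; `scipy.linalg.solve_triangular` would exploit the structure.

```python
import numpy as np

LOG_2PI = np.log(2 * np.pi)  # 1.8378770664093453, the constant in the graph

def mvnormal_logp_chol(x, mu, L):
    """MvNormal logp computed straight from a lower Cholesky factor L."""
    # Solve L @ z = x - mu (a general solver here; a triangular one is cheaper).
    z = np.linalg.solve(L, x - mu)
    # log|Sigma| = 2 * sum(log(diag(L))), so half of it is sum(log(diag(L))).
    return -0.5 * len(x) * LOG_2PI - 0.5 * z @ z - np.sum(np.log(np.diag(L)))

def mvnormal_logp_cov(x, mu, cov):
    """Reference: the same logp from the covariance matrix."""
    d = x - mu
    _, logdet = np.linalg.slogdet(cov)
    return -0.5 * (len(x) * LOG_2PI + logdet + d @ np.linalg.solve(cov, d))

L = np.array([[1.5, 0.0], [0.3, 0.8]])
x, mu = np.array([0.2, -1.0]), np.zeros(2)
print(np.isclose(mvnormal_logp_chol(x, mu, L),
                 mvnormal_logp_cov(x, mu, L @ L.T)))  # True
```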

PyMC version information:

pymc 5.3.1

Context for the issue:

Performance regression for any model involving an observed MvNormal.

dehorsley added the bug label May 15, 2023
@ricardoV94
Member

ricardoV94 commented May 15, 2023

Yeah, I think you're right. We always convert the Cholesky factor into a covariance internally, because that's what the MvNormal RandomVariable needs. It would be easy to add a rewrite that replaces cholesky(dot(L, L.T)) -> L, but the dlogp would probably still contain the redundant operations, because the gradient is taken before that rewrite has a chance to run.

We had planned to introduce an optimization on the logp graph before we create the dlogp, and this is yet another good reason for it.

A more immediate solution is to create new RandomVariables that extend MvNormal and MvStudentT by carrying the Cholesky factor along as an additional parameter, so that it can be used directly in the logp graph. In random graphs, where that extra parameter isn't used, we would replace them with the base MvNormal and MvStudentT RandomVariables.
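A minimal sketch of the "carry the Cholesky factor along" idea in plain NumPy, outside PyTensor. `MvNormalParams` is a hypothetical container, not a PyMC class: the covariance is always kept (as the base RandomVariable needs it), and logp reuses the user's factor when one was supplied instead of refactorizing.

```python
from dataclasses import dataclass
from typing import Optional

import numpy as np

@dataclass
class MvNormalParams:
    """Hypothetical parameter bundle: cov is always present; chol is
    carried along only when the user supplied a Cholesky factor."""
    mu: np.ndarray
    cov: np.ndarray
    chol: Optional[np.ndarray] = None

    def factor(self) -> np.ndarray:
        # Reuse the user's factor when available; otherwise factorize cov.
        return self.chol if self.chol is not None else np.linalg.cholesky(self.cov)

    def logp(self, x: np.ndarray) -> float:
        L = self.factor()
        z = np.linalg.solve(L, x - self.mu)
        return float(-0.5 * len(x) * np.log(2 * np.pi) - 0.5 * z @ z
                     - np.sum(np.log(np.diag(L))))

L = np.array([[2.0, 0.0], [0.5, 1.0]])
with_chol = MvNormalParams(mu=np.zeros(2), cov=L @ L.T, chol=L)
cov_only = MvNormalParams(mu=np.zeros(2), cov=L @ L.T)
x = np.array([0.3, -0.4])
print(np.isclose(with_chol.logp(x), cov_only.logp(x)))  # True
```

Here the choice between the two paths happens at evaluation time; in PyMC the equivalent decision would be made when building the logp graph.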

@dehorsley
Contributor Author

I looked into adding a graph rewrite, but I couldn't see a way of proving that the L in dot(L, L.T) is lower triangular without some more infrastructure. Adding a node for triu_indices might do it, but that doesn't catch everything, e.g. a constant Cholesky matrix. So it seems easier to fix this on the PyMC side.
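A small NumPy illustration of why the rewrite needs a guarantee about L: `cholesky(L @ L.T) == L` holds only when L is lower triangular with a strictly positive diagonal, so applying the rewrite to an arbitrary L would change the result. The helper `is_safe_to_rewrite` is hypothetical and inspects numeric values, which a symbolic rewrite cannot do for general inputs.

```python
import numpy as np

def is_safe_to_rewrite(L, tol=1e-12):
    """True if L is lower triangular with a strictly positive diagonal."""
    return np.allclose(np.triu(L, k=1), 0.0, atol=tol) and np.all(np.diag(L) > 0)

candidates = {
    "valid factor": np.array([[2.0, 0.0], [1.0, 3.0]]),
    "upper triangular": np.array([[2.0, 1.0], [0.0, 3.0]]),
    "negative diagonal": np.array([[-2.0, 0.0], [1.0, 3.0]]),
}

for name, L in candidates.items():
    recovered = np.allclose(np.linalg.cholesky(L @ L.T), L)
    print(f"{name}: safe={is_safe_to_rewrite(L)}, cholesky(L @ L.T) == L: {recovered}")
```

Only the first candidate round-trips; the other two produce the same `L @ L.T` as some valid factor, so the Cholesky routine returns that factor instead of the original matrix.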

That said, being able to handle this in PyTensor would be nice, as I think it would open the door to some other optimisations around Cholesky factorisation, such as rank-n updates/downdates.
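For context on the update/downdate idea: a rank-1 update computes the Cholesky factor of `A + x xᵀ` from the existing factor of `A` in O(n²) rather than refactorizing in O(n³); a rank-n update applies it n times, and a downdate (`A - x xᵀ`) follows the same pattern with sign changes but can fail if the result is not positive definite. Below is a NumPy sketch of the standard rank-1 update; `chol_update` is an illustrative name, not an existing PyTensor op.

```python
import numpy as np

def chol_update(L, x):
    """Rank-1 update: given lower-triangular L with A = L @ L.T, return the
    lower Cholesky factor of A + outer(x, x) in O(n^2) operations."""
    L = np.array(L, dtype=float)  # work on copies, leave inputs untouched
    x = np.array(x, dtype=float)
    n = len(x)
    for k in range(n):
        r = np.hypot(L[k, k], x[k])  # new diagonal entry sqrt(L_kk^2 + x_k^2)
        c = r / L[k, k]
        s = x[k] / L[k, k]
        L[k, k] = r
        if k + 1 < n:
            # Update the rest of column k, then fold it back into x.
            L[k + 1:, k] = (L[k + 1:, k] + s * x[k + 1:]) / c
            x[k + 1:] = c * x[k + 1:] - s * L[k + 1:, k]
    return L

rng = np.random.default_rng(2)
B = rng.normal(size=(5, 5))
A = B @ B.T + 5 * np.eye(5)
v = rng.normal(size=5)

L_updated = chol_update(np.linalg.cholesky(A), v)
print(np.allclose(L_updated, np.linalg.cholesky(A + np.outer(v, v))))  # True
```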

@ricardoV94
Member

ricardoV94 commented May 15, 2023

We can add a flag to the L variable's tag "promising" that it is lower triangular. We already do this to convey positive semidefiniteness or symmetry in some PyTensor rewrites.
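A toy illustration of a tag-gated rewrite. It uses hypothetical `Var`/`Apply` node classes, not PyTensor's actual graph or rewriter API: the pattern `cholesky(dot(L, transpose(L)))` collapses to `L` only when `L.tag` carries a `lower_triangular` promise. The flag name follows the later commit message in this thread; here the promise is assumed to also cover a positive diagonal.

```python
from dataclasses import dataclass, field
from typing import Tuple, Union

@dataclass
class Var:
    """Hypothetical leaf variable with a free-form tag dict."""
    name: str
    tag: dict = field(default_factory=dict)

@dataclass
class Apply:
    """Hypothetical op application: an op name plus input nodes."""
    op: str
    inputs: Tuple["Node", ...]

Node = Union[Var, Apply]

def rewrite_chol_of_gram(node: Node) -> Node:
    """Bottom-up: replace cholesky(dot(L, transpose(L))) with L, but only
    if L promises (via its tag) to be a valid lower Cholesky factor."""
    if isinstance(node, Var):
        return node
    node = Apply(node.op, tuple(rewrite_chol_of_gram(i) for i in node.inputs))
    if node.op == "cholesky":
        (inner,) = node.inputs
        if isinstance(inner, Apply) and inner.op == "dot":
            a, b = inner.inputs
            if (isinstance(a, Var) and isinstance(b, Apply) and b.op == "transpose"
                    and b.inputs[0] is a and a.tag.get("lower_triangular", False)):
                return a
    return node

L = Var("L", tag={"lower_triangular": True})
M = Var("M")  # no promise: the rewrite must leave it alone
graph_L = Apply("cholesky", (Apply("dot", (L, Apply("transpose", (L,)))),))
graph_M = Apply("cholesky", (Apply("dot", (M, Apply("transpose", (M,)))),))
print(rewrite_chol_of_gram(graph_L) is L)  # True
print(rewrite_chol_of_gram(graph_M).op)    # cholesky
```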

More generally, if you are interested, I think there's a lot of low-hanging fruit at the rewrite level in PyTensor for these kinds of operations.

I think this is all we have so far: https://github.com/pymc-devs/pytensor/blob/main/pytensor/sandbox/linalg/ops.py

dehorsley added a commit to dehorsley/pymc that referenced this issue Jun 29, 2023
This adds safe rewrites to logp before the grad operator is applied.
This is motivated by pymc-devs#6717, where expensive `cholesky(L.dot(L.T))`
operations are removed. If these remain in the logp graph when the grad
is taken, the resulting dlogp graph contains unnecessary operations.
However, this may improve the stability and performance of grad logp
in other situations.
dehorsley added a commit to dehorsley/pymc that referenced this issue Jun 29, 2023
Since pymc-devs/pytensor#303, `cholesky(L.dot(L.T))` will be rewritten to L
if `L.tag.lower_triangular=True`. This change sets this tag where
appropriate.

Fixes pymc-devs#6717.
ricardoV94 pushed a commit that referenced this issue Jun 29, 2023