Rewrite logp graph before taking the gradient #6736
Conversation
Codecov Report
Additional details and impacted files:

```diff
@@            Coverage Diff             @@
##             main    #6736      +/-   ##
==========================================
- Coverage   91.92%   91.91%    -0.02%
==========================================
  Files          95       95
  Lines       16197    16207       +10
==========================================
+ Hits        14889    14896        +7
- Misses       1308     1311        +3
```
This is a bit brute-force, but can you benchmark after replacing these lines (Lines 386 to 387 in 8c93bb5) by:

```python
logp_terms = list(logp_terms.values())
_check_no_rvs(logp_terms)

from pytensor.compile import optdb
from pytensor.graph import FunctionGraph

rewrite = optdb.query("+canonicalize")
fg = FunctionGraph(outputs=logp_terms, clone=False)
rewrite.rewrite(fg)

return logp_terms
```

That should remove the useless cholesky before the gradient is generated.
Updated benchmarks:
@ricardoV94, do you think it's worth making the rewrite optional via …?
Awesome results. About the rewrites, we shouldn't introduce them here. Users should always look at the compiled function when they want to investigate the graph for performance concerns. We should introduce the cholesky rewrite (and a couple of others like …).
By the way, can you share the benchmark script so I can replicate locally? We should add a test for the compiled logp/dlogp to make sure the useless cholesky is removed (and avoid a regression in the future).
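A regression test along these lines could guard against that (a sketch only, not necessarily the test this PR ends up adding; it assumes PyMC's `Model.compile_dlogp` and the `Cholesky` op from `pytensor.tensor.slinalg`, and the exact graph inspection may need adjusting across versions):

```python
import numpy as np
import pymc as pm
from pytensor.tensor.slinalg import Cholesky


def test_dlogp_contains_no_cholesky():
    # Model whose naive logp graph contains cholesky(chol @ chol.T).
    with pm.Model() as m:
        chol, _, _ = pm.LKJCholeskyCov("cov", n=3, eta=1, sd_dist=pm.HalfNormal.dist())
        pm.MvNormal("y", mu=np.zeros(3), chol=chol, observed=np.eye(3))

    # Inspect the compiled gradient function's graph for leftover Cholesky ops.
    dlogp = m.compile_dlogp()
    ops = {node.op for node in dlogp.f.maker.fgraph.apply_nodes}
    assert not any(isinstance(op, Cholesky) for op in ops)
```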
@ricardoV94, is this the kind of thing you were thinking of? I've made a new rewrite db in case we want to be more precise about which rewrites we perform. Not sure if that's the best place to put them. I'll add tests in the next few days. This is the script I was using for benchmarks (it's an ipynb, which I can't attach here):
Looks pretty good, just wondering whether we should be more targeted in the rewrites we include
(force-pushed from e750986 to b27b828)
New test will fail until the pytensor version is bumped.
Some small typos.
Does the latest PyTensor release already include the rewrites?
Some rough benchmarks for compiling logp and dlogp.

Generally, including the pregrad rewrites is faster overall, if not the same. I'm guessing the time saved by compiling the simpler dlogp makes up for the extra time spent optimising. Admittedly these are not very interesting models. Let me know if you have some others to test!

Methodology

For each of the following, I executed the compilation first for the uncached case, then for the cached case.
Models:

```python
import numpy as np
import pymc as pm

n = 1000

# Model 1: MvNormal with an LKJ Cholesky prior (exercises the useless-cholesky case)
with pm.Model() as m1:
    chol, corr, sigmas = pm.LKJCholeskyCov("cov", n=n, eta=1, sd_dist=pm.HalfNormal.dist())
    pm.MvNormal("y", mu=np.zeros(n), chol=chol, observed=np.ones((1000, n)))

# Model 2: a deliberately pathological chain of scaled Normals
with pm.Model() as m2:
    sigma = pm.HalfCauchy("sigma", 1)
    for i in range(1, 30):
        pm.Normal(f"m{i}", 0, i * sigma)

# Model 3: eight schools, non-centered
y = np.array([28, 8, -3, 7, -1, 1, 18, 12])
sigma = np.array([15, 10, 16, 11, 9, 11, 10, 18])
J = len(y)

with pm.Model() as m3:
    eta = pm.Normal("eta", 0, 1, shape=J)
    # Hierarchical mean and SD
    mu = pm.Normal("mu", 0, sigma=10)
    tau = pm.HalfNormal("tau", 10)
    # Non-centered parameterization of random effect
    theta = pm.Deterministic("theta", mu + tau * eta)
    obs = pm.Normal("obs", theta, sigma=sigma, observed=y)
```

I had to decrease the number of variables in model 2 as it was hitting the compiler's maximum nesting depth, but only in the case without the pregrad rewrites. (Of course there are other ways to get around this, and this is an intentionally pathological way to write this model!)
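For reference, a minimal timing harness along these lines can reproduce the compile-time comparison (a sketch, not the author's notebook; it reuses the `m1`/`m2`/`m3` models above and PyMC's public `compile_logp`/`compile_dlogp` API):

```python
import time


def time_compile(name, model):
    """Time how long compiling logp and dlogp takes for a model."""
    start = time.perf_counter()
    model.compile_logp()
    model.compile_dlogp()
    print(f"{name}: compiled logp + dlogp in {time.perf_counter() - start:.1f}s")


# Run once with an empty PyTensor compile cache (e.g. a fresh compiledir) for the
# "uncached" numbers, then again in the same environment for the "cached" numbers.
for name, model in [("m1", m1), ("m2", m2), ("m3", m3)]:
    time_compile(name, model)
```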
(force-pushed from 86a2b70 to 34ab51d)
Thanks, fixed.
Yes, included since 2.12.2. EDIT: ah, it looks like a dimension issue, as well as a programmer error. Passing locally now.
Benchmarks look good. I'll mark this PR as requiring a major release, and I'll ask the other devs to keep an eye out for whether their runtimes are significantly affected.
Just a last-minute doubt about the helper. WDYT?
This adds safe rewrites to the logp graph before the grad operator is applied. This is motivated by pymc-devs#6717, where expensive `cholesky(L.dot(L.T))` operations are removed. If these remain in the logp graph when the grad is taken, the resulting dlogp graph contains unnecessary operations. However, this may also improve the stability and performance of the logp gradient in other situations.
Since pymc-devs/pytensor#303, `cholesky(L.dot(L.T))` will be rewritten to `L` if `L.tag.lower_triangular=True`. This change adds these tags where appropriate. Fixes pymc-devs#6717.
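For context, a minimal sketch of the tag-driven simplification this relies on (assuming PyTensor ≥ 2.12.2, which includes pymc-devs/pytensor#303; the snippet is illustrative, not code from this PR):

```python
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.slinalg import cholesky

L = pt.matrix("L")
L.tag.lower_triangular = True  # mark L as an (already lower-triangular) Cholesky factor

# With the tag set, compilation should rewrite cholesky(L @ L.T) back to L,
# so no Cholesky op survives in the compiled graph.
fn = pytensor.function([L], cholesky(L @ L.T))
pytensor.dprint(fn)
```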
@ricardoV94, should be good to go assuming all tests pass.
Thanks a lot @dehorsley!
With pymc-devs/pytensor#303 merged, `cholesky(L.dot(L.T))` will be rewritten to `L` if `L.tag.lower_triangular` is set. This change adds these tags where appropriate. This is important for #6717; however, more work is likely required to improve the gradient in such cases.

Some rough benchmarks of computing logp and its grad on the initial point of the following model.
The major difference in the "after" results between the C and JAX backends is that JAX computes the grad after the rewrite is applied. As mentioned in #6717, this is probably a good motivation for performing some kind of rewrite before computing the gradient.
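To make the ordering point concrete, here is a small sketch (a toy logp expression, plus the `+canonicalize` query suggested earlier in the thread, which is assumed to remove the tagged `cholesky(L @ L.T)` pattern):

```python
import pytensor
import pytensor.tensor as pt
from pytensor.compile import optdb
from pytensor.graph import FunctionGraph
from pytensor.tensor.slinalg import cholesky

L = pt.matrix("L")
L.tag.lower_triangular = True
logp = pt.log(pt.diag(cholesky(L @ L.T))).sum()  # toy logp with a useless cholesky

# Taking the gradient of the raw graph differentiates through the Cholesky op.
dlogp_raw = pytensor.grad(logp, L)

# Rewriting first removes cholesky(L @ L.T) -> L, so the gradient never sees it.
fg = FunctionGraph(outputs=[logp], clone=False)
optdb.query("+canonicalize").rewrite(fg)
dlogp_rewritten = pytensor.grad(fg.outputs[0], L)
```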
Maintenance
📚 Documentation preview 📚: https://pymc--6736.org.readthedocs.build/en/6736/