
Rewrite logp graph before taking the gradient #6736


Merged: 2 commits merged into pymc-devs:main on Jun 29, 2023

Conversation

@dehorsley (Contributor) commented on May 25, 2023

With pymc-devs/pytensor#303 merged, cholesky(L.dot(L.T)) will be rewritten to L if L.tag.lower_triangular is set. This change adds these tags where appropriate. This is important for #6717; however, more work is likely required to improve the gradient in such cases.
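
As a minimal illustration (not part of this PR), here is roughly what the tag-driven rewrite does at the PyTensor level; this sketch assumes a PyTensor release that ships the cholesky_ldotlt rewrite:

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.slinalg import cholesky

L = pt.matrix("L")
L.tag.lower_triangular = True      # the kind of tag this change adds where appropriate

expr = cholesky(L.dot(L.T))        # wasteful: recovers L from L @ L.T
fn = pytensor.function([L], expr)  # graph rewrites run during compilation
pytensor.dprint(fn)                # the Cholesky op should be rewritten away, leaving L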

Some rough benchmarks of computing logp and its grad at the initial point of the following model:

n=1000
with pm.Model() as m:
    chol, corr, sigmas = pm.LKJCholeskyCov('cov', n=n, eta=1, sd_dist=pm.HalfNormal.dist())
    pm.MvNormal('y', mu=np.zeros(n), chol=chol, observed=np.ones((1000, n)))

|        | C backend        | JAX backend      |
|--------|------------------|------------------|
| before | 292 ms ± 28.6 ms | 285 ms ± 28.9 ms |
| after  | 260 ms ± 19.9 ms | 107 ms ± 969 µs  |

The major difference in the "after" between C and JAX backends is that JAX is computing the grad after the rewrite is applied. As mentioned in #6717, this is probably a good motivator for performing some kind of rewrite before computing the gradient.

Maintenance

  • Improve performance of MvNormal logp with a Cholesky factor.

📚 Documentation preview 📚: https://pymc--6736.org.readthedocs.build/en/6736/

@codecov bot commented on May 25, 2023

Codecov Report

Merging #6736 (ee1657b) into main (7b08fc1) will decrease coverage by 0.02%.
The diff coverage is 100.00%.

Additional details and impacted files


@@            Coverage Diff             @@
##             main    #6736      +/-   ##
==========================================
- Coverage   91.92%   91.91%   -0.02%     
==========================================
  Files          95       95              
  Lines       16197    16207      +10     
==========================================
+ Hits        14889    14896       +7     
- Misses       1308     1311       +3     
| Impacted Files                     | Coverage Δ                    |
|------------------------------------|-------------------------------|
| pymc/distributions/multivariate.py | 92.20% <100.00%> (+<0.01%) ⬆️ |
| pymc/math.py                       | 70.61% <100.00%> (+0.56%) ⬆️  |
| pymc/model.py                      | 89.95% <100.00%> (+0.03%) ⬆️  |
| pymc/pytensorf.py                  | 92.75% <100.00%> (+0.03%) ⬆️  |

... and 1 file with indirect coverage changes

@ricardoV94 (Member) commented on May 25, 2023

This is a bit brute-force but can you benchmark after replacing these lines:

pymc/pymc/logprob/basic.py, lines 386 to 387 (at 8c93bb5):

_check_no_rvs(list(logp_terms.values()))
return list(logp_terms.values())

By:

    logp_terms = list(logp_terms.values())
    _check_no_rvs(logp_terms)

    from pytensor.compile import optdb
    from pytensor.graph import FunctionGraph
    
    rewrite = optdb.query("+canonicalize")
    fg = FunctionGraph(outputs=logp_terms, clone=False)
    rewrite.rewrite(fg)

    return logp_terms

That should remove the useless cholesky before the gradient is generated.

@dehorsley (Contributor, Author) commented:
Updated benchmarks:

|                          | C backend         | JAX backend      |
|--------------------------|-------------------|------------------|
| before                   | 245 ms ± 4.88 ms  | 265 ms ± 7.38 ms |
| after                    | 231 ms ± 9.16 ms  | 106 ms ± 3.15 ms |
| after + pre-grad rewrite | 90.3 ms ± 2.15 ms | 109 ms ± 4.06 ms |

@ricardoV94, do you think it's worth making the rewrite optional via Model.logp? I can imagine it'd be useful to see the "raw" logp graph while debugging.

@ricardoV94 (Member) commented on May 28, 2023

Awesome results. About the rewrites, we shouldn't introduce them here. Users should always look at the compiled function when they want to investigate the graph for performance concerns.

We should introduce the cholesky rewrite (and a couple of others, like log1mexp and softmax, that have more stable gradients) when calling model.dlogp.
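
A rough sketch of that idea (not the merged implementation; the exact set of rewrite tags to include is the open question here), reusing the optdb query pattern from the snippet above:

import pytensor
from pytensor.compile import optdb
from pytensor.graph import FunctionGraph

def grad_of_rewritten_logp(logp, wrt):
    # Apply a "safe" subset of rewrites to the logp graph first, so that
    # pytensor.grad differentiates the simplified graph.
    fg = FunctionGraph(outputs=[logp], clone=False)
    optdb.query("+canonicalize", "+stabilize").rewrite(fg)
    return pytensor.grad(fg.outputs[0], wrt=wrt)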

@ricardoV94 (Member) commented on May 28, 2023

By the way, can you share the benchmark script so I can replicate locally?

We should add a test for the compiled logp/dlogp to make sure the useless cholesky is removed (and avoid a regression in the future).
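
Something like the following could work for such a test (a sketch only, not the test added in this PR; it reuses the _pytensor_function attribute access from the benchmark script below and assumes a model parametrized with a Cholesky factor, e.g. the MvNormal model above):

from pytensor.tensor.slinalg import Cholesky

def assert_no_useless_cholesky(model):
    # Compile logp and dlogp together and check that no Cholesky op
    # survived the rewrites in the final graph.
    fn = model.logp_dlogp_function()
    fgraph = fn._pytensor_function.maker.fgraph
    assert not any(isinstance(node.op, Cholesky) for node in fgraph.toposort())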

@dehorsley (Contributor, Author) commented:
@ricardoV94, is this the kind of thing you were thinking of? I've made a new rewrite database in case we want to be more precise about which rewrites to perform. I'm not sure that's the best place to put them, though. I'll add tests in the next few days.

This is the script I was using for benchmarks (it's an ipynb, but I can't attach it here):

# %%
import pymc as pm
import numpy as np
import pytensor
from pymc.blocking import DictToArrayBijection

# %%
print(pm.__version__)
print(pytensor.__version__)

# %%
n=1000
with pm.Model() as m:
    chol, corr, sigmas = pm.LKJCholeskyCov('cov', n=n, eta=1, sd_dist=pm.HalfNormal.dist())
    pm.MvNormal('y', mu=np.zeros(n), chol=chol, observed=np.ones((1000, n)))


# %%
logp = m.logp_dlogp_function()
logp.set_extra_values([])
pt = DictToArrayBijection.map(m.initial_point())
logp(pt)

# %%
%timeit -n 10 logp(pt)

# %%
pytensor.dprint(logp._pytensor_function.maker.fgraph.outputs[0])

# %%
# Re-run the C-backend timing with the cholesky_ldotlt rewrite disabled
pytensor.config.mode = pytensor.compile.get_default_mode().excluding("cholesky_ldotlt")

# %%
logp = m.logp_dlogp_function()
logp.set_extra_values([])
pt = DictToArrayBijection.map(m.initial_point())
logp(pt)

# %%
pytensor.dprint(logp._pytensor_function.maker.fgraph.outputs[0])

# %%
%timeit -n 10 logp(pt)

# %%
import jax 
import jax.numpy as jnp
from pymc.sampling.jax import get_jaxified_logp

# %%
logp_jax = get_jaxified_logp(m)
rvs = [rv.name for rv in m.value_vars]
init_position_dict = m.initial_point()
init_position = [jnp.array(init_position_dict[rv]) for rv in rvs]
grad_and_logp_jax = jax.jit(jax.value_and_grad(logp_jax))
grad_and_logp_jax(init_position)

# %%
%timeit -n 10 grad_and_logp_jax(init_position)

# %%
# Exclude the cholesky_ldotlt rewrite for the JAX backend as well, then re-time
pytensor.compile.mode.JAX = pytensor.compile.mode.JAX.excluding("cholesky_ldotlt")

# %%
logp_jax = get_jaxified_logp(m)
rvs = [rv.name for rv in m.value_vars]
init_position_dict = m.initial_point()
init_position = [jnp.array(init_position_dict[rv]) for rv in rvs]
grad_and_logp_jax = jax.jit(jax.value_and_grad(logp_jax))
grad_and_logp_jax(init_position)

# %%
%timeit -n 10 grad_and_logp_jax(init_position)

# %%
m.optdb.print_summary()

@ricardoV94 (Member) left a review comment:

Looks pretty good; just wondering whether we should be more targeted in the rewrites we include.

@dehorsley force-pushed the fix-6717 branch 2 times, most recently from e750986 to b27b828 on June 14, 2023
@dehorsley (Contributor, Author) commented:
The new test will fail until the PyTensor version is bumped.

@dehorsley requested a review from ricardoV94 on June 16, 2023
@ricardoV94 (Member) left a review comment:

Some small typos.

Does the latest PyTensor already include the rewrites?

@dehorsley (Contributor, Author) commented:
Some rough benchmarks for compiling dlogp, with the C backend:

| model                                   | pre-grad rewrites (PR #6736), uncached | pre-grad rewrites (PR #6736), cached | dev (7b08fc1), uncached | dev (7b08fc1), cached |
|-----------------------------------------|----------------------------------------|--------------------------------------|-------------------------|-----------------------|
| Large MvNormal (m1)                     | 135 s                                  | 1.0 s                                | 151 s                   | 1.2 s                 |
| Large no. of normal vars (m2)           | 47 s                                   | 2.4 s                                | 56 s                    | 3.0 s                 |
| 8 schools, smallish everyday model (m3) | 33 s                                   | 0.3 s                                | 33 s                    | 0.3 s                 |

Generally, including the pre-grad rewrites is faster overall, or at worst about the same. I'm guessing the time saved by compiling the simpler dlogp makes up for the extra time spent optimising. Admittedly these are not very interesting models; let me know if you have others to test!

Methodology

For each model below, I ran pytensor-cache purge, created the model, and then timed

%%time
m.compile_dlogp()

for the uncached case, followed by

%timeit m.compile_dlogp()

for the cached case.

Models:

n=1000
with pm.Model() as m1:
    chol, corr, sigmas = pm.LKJCholeskyCov('cov', n=n, eta=1, sd_dist=pm.HalfNormal.dist())
    pm.MvNormal('y', mu=np.zeros(n), chol=chol, observed=np.ones((1000, n)))

with pm.Model() as m2:
    sigma = pm.HalfCauchy('sigma', 1)
    for i in range(1,30):
        pm.Normal(f'm{i}', 0, i*sigma)


y = np.array([28, 8, -3, 7, -1, 1, 18, 12])
sigma = np.array([15, 10, 16, 11, 9, 11, 10, 18])
J = len(y)

with pm.Model() as m3:
    eta = pm.Normal("eta", 0, 1, shape=J)
    # Hierarchical mean and SD
    mu = pm.Normal("mu", 0, sigma=10)
    tau = pm.HalfNormal("tau", 10)

    # Non-centered parameterization of random effect
    theta = pm.Deterministic("theta", mu + tau * eta)

    obs = pm.Normal("obs", theta, sigma=sigma, observed=y)

I had to decrease the number of variables in model 2 because it was hitting the compiler's maximum nesting depth, but only in the case without the pre-grad rewrites. (Of course there are other ways to get around this, and this is an intentionally pathological way to write the model!)

@dehorsley force-pushed the fix-6717 branch 2 times, most recently from 86a2b70 to 34ab51d on June 28, 2023
@dehorsley (Contributor, Author) commented on Jun 28, 2023

> Some small typos.

Thanks, fixed.

> Does the latest PyTensor already include the rewrites?

Yes, included since 2.12.2. I'm not sure why the tests are failing; it looks like the test environment has v2.12.3.

EDIT: ah, it looks like a dimension issue, as well as a programmer error. Passing locally now.

@ricardoV94 (Member) commented:
Benchmarks look good. I'll mark this PR as requiring a major release, and I'll ask other devs to keep an eye on whether their runtimes are significantly affected.

@ricardoV94 added the enhancements and major (Include in major changes release notes section) labels on Jun 28, 2023
@ricardoV94 changed the title from "add lower triangular tags to allow chol rewrite" to "Rewrite logp graph before taking the gradient" on Jun 28, 2023
@ricardoV94 (Member) left a review comment:

Just a last-minute doubt about the helper. WDYT?

This adds safe rewrites to logp before the grad operator is applied.
This is motivated by pymc-devs#6717, where expensive `cholesky(L.dot(L.T))`
operations are removed. If these remain in the logp graph when the grad
is taken, the resulting dlogp graph contains unnecessary operations.
However, this may also improve the stability and performance of grad logp
in other situations.

Since pymc-devs/pytensor#303, `cholesky(L.dot(L.T))` will be rewritten to L
if `L.tag.lower_triangular=True`. This change adds these tags where
appropriate.

Fixes pymc-devs#6717.
@dehorsley (Contributor, Author) commented:
@ricardoV94, this should be good to go, assuming all tests pass.

@ricardoV94 merged commit 4847914 into pymc-devs:main on Jun 29, 2023
@ricardoV94 (Member) commented:
Thanks a lot @dehorsley!
