Use NoCache to improve set_to_zero!! performance with Mooncake #975


Closed
wants to merge 11 commits

Conversation

sunxd3
Member

@sunxd3 sunxd3 commented Jul 8, 2025

Contributor

github-actions bot commented Jul 8, 2025

Benchmark Report for Commit 0576503

Computer Information

Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)

Benchmark Results

|                 Model | Dimension |  AD Backend |      VarInfo Type | Linked | Eval Time / Ref Time | AD Time / Eval Time |
|-----------------------|-----------|-------------|-------------------|--------|----------------------|---------------------|
| Simple assume observe |         1 | forwarddiff |             typed |  false |                 10.1 |                 1.5 |
|           Smorgasbord |       201 | forwarddiff |             typed |  false |                761.0 |                35.9 |
|           Smorgasbord |       201 | forwarddiff | simple_namedtuple |   true |                309.9 |                68.1 |
|           Smorgasbord |       201 | forwarddiff |           untyped |   true |               1268.1 |                28.4 |
|           Smorgasbord |       201 | forwarddiff |       simple_dict |   true |               3690.9 |                22.8 |
|           Smorgasbord |       201 | reversediff |             typed |   true |               1490.0 |                29.4 |
|           Smorgasbord |       201 |    mooncake |             typed |   true |                986.3 |                 5.2 |
|    Loop univariate 1k |      1000 |    mooncake |             typed |   true |               5803.5 |                 4.0 |
|       Multivariate 1k |      1000 |    mooncake |             typed |   true |               1050.4 |                 8.7 |
|   Loop univariate 10k |     10000 |    mooncake |             typed |   true |              64246.0 |                 3.7 |
|      Multivariate 10k |     10000 |    mooncake |             typed |   true |               9382.6 |                 9.5 |
|               Dynamic |        10 |    mooncake |             typed |   true |                133.5 |                13.2 |
|              Submodel |         1 |    mooncake |             typed |   true |                 14.7 |                 6.3 |
|                   LDA |        12 | reversediff |             typed |   true |                486.0 |                 5.4 |


codecov bot commented Jul 8, 2025

Codecov Report

❌ Patch coverage is 0% with 10 lines in your changes missing coverage. Please review.
✅ Project coverage is 81.31%. Comparing base (0b7213f) to head (0576503).
⚠️ Report is 4 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|------------------------------|---------|---------------|
| ext/DynamicPPLMooncakeExt.jl | 0.00% | 10 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #975      +/-   ##
==========================================
- Coverage   82.97%   81.31%   -1.67%     
==========================================
  Files          36       37       +1     
  Lines        3965     3965              
==========================================
- Hits         3290     3224      -66     
- Misses        675      741      +66     


@coveralls

coveralls commented Jul 8, 2025

Pull Request Test Coverage Report for Build 16785074264

Details

  • 0 of 10 (0.0%) changed or added relevant lines in 1 file are covered.
  • 85 unchanged lines in 8 files lost coverage.
  • Overall coverage decreased (-1.8%) to 81.311%

| Changes Missing Coverage | Covered Lines | Changed/Added Lines | % |
|------------------------------|---------------|---------------------|------|
| ext/DynamicPPLMooncakeExt.jl | 0 | 10 | 0.0% |
| Files with Coverage Reduction | New Missed Lines | % |
|-------------------------------|------------------|--------|
| src/bijector.jl | 1 | 93.33% |
| src/model.jl | 1 | 85.83% |
| src/experimental.jl | 3 | 0.0% |
| src/varinfo.jl | 7 | 84.46% |
| ext/DynamicPPLJETExt.jl | 16 | 0.0% |
| src/threadsafe.jl | 16 | 55.05% |
| src/logdensityfunction.jl | 20 | 54.35% |
| src/test_utils/ad.jl | 21 | 0.0% |
Totals Coverage Status
Change from base Build 16722554440: -1.8%
Covered Lines: 3224
Relevant Lines: 3965


Contributor

github-actions bot commented Jul 9, 2025

DynamicPPL.jl documentation for PR #975 is available at:
https://TuringLang.github.io/DynamicPPL.jl/previews/PR975/

@sunxd3 sunxd3 requested a review from penelopeysm July 21, 2025 10:53
@sunxd3
Member Author

sunxd3 commented Aug 6, 2025

I just realized that chalk-lab/Mooncake.jl#667 is not sufficient.

E.g., for type

struct T
    field_whose_tangent_type_needs_cache
    field_whose_tangent_type_does_not_need_cache
end

We cannot define `requires_cache(T) = Val{false}()`, because that disables the cache for the whole tangent type, so `field_whose_tangent_type_needs_cache` can get us into trouble.

But setting `requires_cache(T) = Val{true}()` will not bring any performance improvement either. The reason is that `requires_cache` is called by `set_to_zero!!`, but `set_to_zero_internal!!` doesn't call `set_to_zero!!`, so the flag is never consulted on the recursive path.
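The dilemma can be sketched with a toy model (hypothetical names; this is not Mooncake's actual implementation):

```julia
# Toy sketch of per-type cache opt-in, assuming a trait like `requires_cache`.
struct NeedsCache; xs::Vector{Float64}; end   # tangent may alias / self-reference
struct NoCache; x::Float64; end               # plain bits, cache is pure overhead

requires_cache(::Type{NeedsCache}) = Val{true}()
requires_cache(::Type{NoCache}) = Val{false}()

struct T
    a::NeedsCache   # field whose tangent type needs the cache
    b::NoCache      # field whose tangent type does not
end

# Declaring requires_cache(T) = Val{false}() would skip the cache for *all*
# of T's fields, including `a`, which is unsound if `a.xs` aliases anything.
# Declaring Val{true}() is sound but keeps the cache overhead everywhere,
# so set_to_zero!! sees no speedup for `b`-like fields.
```

The trait would need to be consulted field-by-field during the recursive traversal, not once at the top level, for a mixed type like `T` to benefit.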

@willtebbutt
Member

Which type in particular are you looking to do this for? I'm just wondering whether you're able to convince yourself in practice that this won't be a problem, or something like that.

@sunxd3
Member Author

sunxd3 commented Aug 7, 2025

My understanding is

"""
The `Metadata` struct stores some metadata about the parameters of the model. This helps
query certain information about a variable, such as its distribution, which samplers
sample this variable, its value and whether this value is transformed to real space or
not.
Let `md` be an instance of `Metadata`:
- `md.vns` is the vector of all `VarName` instances.
- `md.idcs` is the dictionary that maps each `VarName` instance to its index in
`md.vns`, `md.ranges` `md.dists`, `md.orders` and `md.flags`.
- `md.vns[md.idcs[vn]] == vn`.
- `md.dists[md.idcs[vn]]` is the distribution of `vn`.
- `md.orders[md.idcs[vn]]` is the number of `observe` statements before `vn` is sampled.
- `md.ranges[md.idcs[vn]]` is the index range of `vn` in `md.vals`.
- `md.vals[md.ranges[md.idcs[vn]]]` is the vector of values corresponding to `vn`.
- `md.flags` is a dictionary of true/false flags. `md.flags[flag][md.idcs[vn]]` is the
value of `flag` corresponding to `vn`.
To make `md::Metadata` type stable, all the `md.vns` must have the same symbol
and distribution type. However, one can have a Julia variable, say `x`, that is a
matrix or a hierarchical array sampled in partitions, e.g.
`x[1][:] ~ MvNormal(zeros(2), I); x[2][:] ~ MvNormal(ones(2), I)`, and is managed by
a single `md::Metadata` so long as all the distributions on the RHS of `~` are of the
same type. Type unstable `Metadata` will still work but will have inferior performance.
When sampling, the first iteration uses a type unstable `Metadata` for all the
variables then a specialized `Metadata` is used for each symbol along with a function
barrier to make the rest of the sampling type stable.
"""
struct Metadata{
    TIdcs<:Dict{<:VarName,Int},
    TDists<:AbstractVector{<:Distribution},
    TVN<:AbstractVector{<:VarName},
    TVal<:AbstractVector{<:Real},
}
    # Mapping from the `VarName` to its integer index in `vns`, `ranges` and `dists`
    idcs::TIdcs # Dict{<:VarName,Int}
    # Vector of identifiers for the random variables, where `vns[idcs[vn]] == vn`
    vns::TVN # AbstractVector{<:VarName}
    # Vector of index ranges in `vals` corresponding to `vns`
    # Each `VarName` `vn` has a single index or a set of contiguous indices in `vals`
    ranges::Vector{UnitRange{Int}}
    # Vector of values of all the univariate, multivariate and matrix variables
    # The value(s) of `vn` is/are `vals[ranges[idcs[vn]]]`
    vals::TVal # AbstractVector{<:Real}
    # Vector of distributions corresponding to `vns`
    dists::TDists # AbstractVector{<:Distribution}
    # Number of `observe` statements before each random variable is sampled
    orders::Vector{Int}
    # Each `flag` has a `BitVector` `flags[flag]`, where `flags[flag][i]` is the true/false flag value corresponding to `vns[i]`
    flags::Dict{String,BitVector}
end
(used in `VarInfo`) is the type for which we might want to skip the cache.

We do still need the cache for `DPPL.Model`, because it might capture a closure.
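As a hedged illustration (toy types only, not Mooncake's implementation) of why captured state forces a cache: a closure can hold a circularly-referencing object, and a recursive zeroing pass only terminates if it tracks which objects it has already visited.

```julia
mutable struct Node
    value::Float64
    next::Union{Node,Nothing}
end

n = Node(1.0, nothing)
n.next = n   # circular reference: a naive recursive walk would never terminate

# A cache of visited objects (keyed by identity, e.g. an IdDict) breaks the cycle:
function zero_values!(x::Node, seen=IdDict{Node,Nothing}())
    haskey(seen, x) && return x   # already visited: stop recursing
    seen[x] = nothing
    x.value = 0.0
    x.next === nothing || zero_values!(x.next, seen)
    return x
end

zero_values!(n)   # terminates; n.value is now 0.0
```

Skipping the cache is only safe when the type is known, by construction, to contain no such aliasing, which is what makes `Metadata` a candidate and `Model` not.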

@willtebbutt
Member

Ahh okay. I think I missed this context when we were reviewing chalk-lab/Mooncake.jl#667 .

My feeling on this is that the best way to handle this is to ensure the exclusion of aliasing and circular references at cache-construction time (in Mooncake), and then assume that there continues not to be any aliasing / circular referencing in the future. I'm going to DM you about this to see if we can chat in-person, as that might make life easier.

@sunxd3
Member Author

sunxd3 commented Aug 8, 2025

Closed in favor of chalk-lab/Mooncake.jl#680.

@sunxd3 sunxd3 closed this Aug 8, 2025
4 participants