Skip to content

InitContext, part 4 - Use init!! to replace evaluate_and_sample!!, predict, returned, and initialize_values #984

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: breaking
Choose a base branch
from

Conversation

penelopeysm
Copy link
Member

@penelopeysm penelopeysm commented Jul 10, 2025

Part 1: Adding hasvalue and getvalue to AbstractPPL
Part 2: Removing hasvalue and getvalue from DynamicPPL
Part 3: Introducing InitContext and init!!

This is part 4/N of #967.


In Part 3 we introduced InitContext. This PR makes use of the functionality in there to replace a bunch of code that no longer needs to exist:

  1. setval_and_resample! followed by model evaluation: This process was used for predict and returned, to manually store certain values in the VarInfo, which would be used in the subsequent model evaluation. We can now do this in a single step using InitFromParams.
  2. initialize_values!!: very similar to the above. It would manually set values inside the varinfo, and then it would trigger an extra model evaluation to update the logp field. Again, this is directly replaced with InitFromParams.
  3. evaluate_and_sample!!: direct one-to-one replacement with init!!.

There is one fairly major API change associated with point (2): the initial_params kwarg to Turing's sample must now be an AbstractInitStrategy.

It's still optional (it will default to init_strategy(spl), which is usually InitFromPrior, except for the HMC family which uses InitFromUniform). However, there are two implications:

  • initial_params cannot be a vector of parameters anymore. It must be InitFromParams(::NamedTuple) OR InitFromParams(::AbstractDict{VarName}).
  • Because InitFromParams expects values in unlinked space, initial_params must always be specified in unlinked space. Previously, initial_params would have to be specified in a way that matched the linking status of the underlying varinfo.

I consider both of these to be a major win for clarity. (One might argue that vectors are more convenient. But IMO anything that lets you extract a vector will also let you extract a NT or Dict, maybe with a bit more typing at worst).

Closes

Closes #774
Closes #797
Closes #983
Closes TuringLang/Turing.jl#2476
Closes TuringLang/Turing.jl#1775

Copy link
Contributor

github-actions bot commented Jul 10, 2025

Benchmark Report for Commit 3bb7ade

Computer Information

Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)

Benchmark Results

|                 Model | Dimension |  AD Backend |      VarInfo Type | Linked | Eval Time / Ref Time | AD Time / Eval Time |
|-----------------------|-----------|-------------|-------------------|--------|----------------------|---------------------|
| Simple assume observe |         1 | forwarddiff |             typed |  false |                  8.8 |                 1.6 |
|           Smorgasbord |       201 | forwarddiff |             typed |  false |                627.6 |                46.4 |
|           Smorgasbord |       201 | forwarddiff | simple_namedtuple |   true |                413.2 |                52.2 |
|           Smorgasbord |       201 | forwarddiff |           untyped |   true |               1148.1 |                29.4 |
|           Smorgasbord |       201 | forwarddiff |       simple_dict |   true |               6360.4 |                29.2 |
|           Smorgasbord |       201 | reversediff |             typed |   true |               1019.1 |                40.7 |
|           Smorgasbord |       201 |    mooncake |             typed |   true |                974.0 |                 4.7 |
|    Loop univariate 1k |      1000 |    mooncake |             typed |   true |               5702.1 |                 4.4 |
|       Multivariate 1k |      1000 |    mooncake |             typed |   true |                970.4 |                 8.9 |
|   Loop univariate 10k |     10000 |    mooncake |             typed |   true |              63372.6 |                 3.9 |
|      Multivariate 10k |     10000 |    mooncake |             typed |   true |               8159.8 |                10.2 |
|               Dynamic |        10 |    mooncake |             typed |   true |                130.1 |                12.3 |
|              Submodel |         1 |    mooncake |             typed |   true |                 12.8 |                 5.4 |
|                   LDA |        12 | reversediff |             typed |   true |               1076.4 |                 2.1 |

@penelopeysm penelopeysm changed the title Use init!! to replace evaluate_and_sample!!, predict, returned, and initialize_values InitContext, part 4 - Use init!! to replace evaluate_and_sample!!, predict, returned, and initialize_values Jul 10, 2025
@penelopeysm penelopeysm force-pushed the py/init-prior-uniform branch 2 times, most recently from 025aa8b to b55c1e1 Compare July 10, 2025 14:24
@penelopeysm penelopeysm force-pushed the py/actually-use-init branch 5 times, most recently from b72c3bf to 92d3542 Compare July 10, 2025 15:57
@penelopeysm penelopeysm mentioned this pull request Jul 10, 2025
22 tasks
@penelopeysm penelopeysm force-pushed the py/actually-use-init branch 4 times, most recently from 7438b23 to d55d378 Compare July 10, 2025 16:56
@penelopeysm penelopeysm force-pushed the py/actually-use-init branch 3 times, most recently from 12d93e5 to 7a8e7e3 Compare July 10, 2025 17:47
@penelopeysm penelopeysm force-pushed the py/actually-use-init branch 2 times, most recently from 1d8bceb to 2edcd10 Compare July 20, 2025 00:59
Copy link
Contributor

DynamicPPL.jl documentation for PR #984 is available at:
https://TuringLang.github.io/DynamicPPL.jl/previews/PR984/

Copy link

codecov bot commented Jul 20, 2025

Codecov Report

❌ Patch coverage is 86.95652% with 12 lines in your changes missing coverage. Please review.
✅ Project coverage is 80.69%. Comparing base (991e825) to head (3bb7ade).

Files with missing lines Patch % Lines
src/simple_varinfo.jl 40.00% 6 Missing ⚠️
src/test_utils/contexts.jl 83.33% 5 Missing ⚠️
src/test_utils/model_interface.jl 0.00% 1 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff              @@
##           breaking     #984      +/-   ##
============================================
- Coverage     82.53%   80.69%   -1.85%     
============================================
  Files            39       39              
  Lines          4008     3947      -61     
============================================
- Hits           3308     3185     -123     
- Misses          700      762      +62     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Comment on lines -1228 to +1201
"""
predict([rng::Random.AbstractRNG,] model::Model, chain::AbstractVector{<:AbstractVarInfo})

Generate samples from the posterior predictive distribution by evaluating `model` at each set
of parameter values provided in `chain`. The number of posterior predictive samples matches
the length of `chain`. The returned `AbstractVarInfo`s will contain both the posterior parameter values
and the predicted values.
"""
function predict(
rng::Random.AbstractRNG, model::Model, chain::AbstractArray{<:AbstractVarInfo}
)
varinfo = DynamicPPL.VarInfo(model)
return map(chain) do params_varinfo
vi = deepcopy(varinfo)
DynamicPPL.setval_and_resample!(vi, values_as(params_varinfo, NamedTuple))
model(rng, vi)
return vi
end
end
# Implemented & documented in DynamicPPLMCMCChainsExt
function predict end
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was discussed at one of the meetings and we decided we didn't care enough about the predict method on vectors of varinfos. It's currently bugged because varinfo is always unlinked, but params_varinfo might be linked, and if it is, it will give wrong results because it sets a linked value into an unlinked varinfo. See #983.

Base automatically changed from py/init-prior-uniform to breaking August 13, 2025 16:47
@penelopeysm penelopeysm force-pushed the py/actually-use-init branch from f6dd1d5 to d9292ad Compare August 13, 2025 16:51
Comment on lines +45 to +52
function chain_sample_to_varname_dict(c::MCMCChains.Chains, sample_idx, chain_idx)
_check_varname_indexing(c)
d = Dict{DynamicPPL.VarName,Any}()
for vn in DynamicPPL.varnames(c)
d[vn] = DynamicPPL.getindex_varname(c, sample_idx, vn, chain_idx)
end
return d
end
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that, if the chain does not store varnames inside its info field, chain_sample_to_varname_dict will fail.

I don't think this is a huge problem right now because every chain obtained via Turing's sample() will contain varnames:

https://github.com/TuringLang/Turing.jl/blob/1aa95ac91a115569c742bab74f7b751ed1450309/src/mcmc/Inference.jl#L288-L290

So this is only a problem if one manually constructs a chain and tries to call predict on it, which I think is a highly unlikely workflow (and I'm happy to wait for people to complain if it fails). There are a few places in DynamicPPL's test suite where this does actually happen. I fixed them all by manually adding the varname dictionary.

However, it's obviously ugly. The only good way around this is to rework MCMCChains.jl :( (See here for the implementation of the corresponding functionality in FlexiChains.)

@penelopeysm penelopeysm marked this pull request as ready for review August 13, 2025 16:57
@penelopeysm penelopeysm force-pushed the py/actually-use-init branch from 89bc0ea to 726d486 Compare August 13, 2025 17:15
@penelopeysm penelopeysm force-pushed the py/actually-use-init branch from 726d486 to bc04355 Compare August 13, 2025 17:16
Comment on lines 20 to 22
function DynamicPPL.Experimental._determine_varinfo_jet(
model::DynamicPPL.Model; only_ddpl::Bool=true
)
Copy link
Member Author

@penelopeysm penelopeysm Aug 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit confused by the comments in this function because as far as I can tell it only ever tested sampling, not both sampling and evaluation. (That was also true going further back e.g. in v0.36)

This PR thus also changes the implementation of this function to test both evaluation and sampling (i.e. initialisation) and if either fails, it will return the untyped varinfo.

Sorry I had to make this change in this PR. There were a few unholy tests where one would end up evaluating a model with a SamplingContext{<:InitContext}, which would error unless I introduced special code to handle it, and I didn't really want to do that. JETExt was one of those unholy scenarios.

Comment on lines -61 to +63
DynamicPPL.evaluate_and_sample!!(rng, model, vi, sampler)
return vi, nothing
strategy = sampler isa SampleFromPrior ? InitFromPrior() : InitFromUniform()
_, new_vi = DynamicPPL.init!!(rng, model, vi, strategy)
return new_vi, nothing
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit weird, but it's really just to tide us over until we delete SampleFromUniform/SampleFromPrior properly.

Define the initialisation strategy used for generating initial values when
sampling with `sampler`. Defaults to `InitFromPrior()`, but can be overridden.
"""
init_strategy(::Sampler) = InitFromPrior()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Eventually the aim would be to use ::AbstractSampler. But that will have to wait for cleanup in Turing. DynamicPPL itself doesn't use Sampler at all and if you only look at DPPL it looks like a meaningless empty wrapper, but Turing relies on these methods a fair bit

Comment on lines -158 to -174
@testset "rng" begin
model = GDEMO_DEFAULT

for sampler in (SampleFromPrior(), SampleFromUniform())
for i in 1:10
Random.seed!(100 + i)
vi = VarInfo()
DynamicPPL.evaluate_and_sample!!(Random.default_rng(), model, vi, sampler)
vals = vi[:]

Random.seed!(100 + i)
vi = VarInfo()
DynamicPPL.evaluate_and_sample!!(Random.default_rng(), model, vi, sampler)
@test vi[:] == vals
end
end
end
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is adequately tested in the InitContext tests (test_rng_respected)

Comment on lines -65 to -66
varinfo_untyped = DynamicPPL.VarInfo()
model_with_spl = contextualize(model, SamplingContext(context))
Copy link
Member Author

@penelopeysm penelopeysm Aug 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also had to rework these tests because of the SamplingContext{<:InitContext} case.

Comment on lines -631 to -634
## `typed_varinfo`
vi = DynamicPPL.typed_varinfo(model)
vi = DynamicPPL.settrans!!(vi, true, vn)
test_linked_varinfo(model, vi)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this was a duplicate of the 4 lines above

@@ -1012,45 +959,6 @@ end
@test merge(vi_double, vi_single)[vn] == 1.0
end

@testset "sampling from linked varinfo" begin
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these tests are also covered in InitContext now (test_link_status_respected)!

@penelopeysm penelopeysm requested a review from mhauru August 13, 2025 23:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant