"Fixes" for PG-in-Gibbs #2629
Conversation
```julia
# We want to increment num_produce for the VarInfo stored in the trace. The trace is
# mutable, so we create a new model with the incremented VarInfo and set it in the trace.
model = trace.model
model = Accessors.@set model.f.varinfo = DynamicPPL.increment_num_produce!!(
    model.f.varinfo
)
trace.model = model
```
Since we don't ever use num_produce, we don't need to modify it either.
Changes in this file are all due to the issue with Gibbs-conditioned variables now triggering resampling. Thus it is now forbidden to have a dynamic number of Gibbs-conditioned variables.
test/mcmc/gibbs.jl
Outdated
```julia
# The below test used to sample incorrectly before
# https://github.com/TuringLang/Turing.jl/pull/2328
@testset "dynamic model with ESS" begin
    @model function dynamic_model_for_ess()
        b ~ Bernoulli()
        x_length = b ? 1 : 2
        x = Vector{Float64}(undef, x_length)
        for i in 1:x_length
            x[i] ~ Normal(i, 1.0)
        end
    end

    m = dynamic_model_for_ess()
    chain = sample(m, Gibbs(:b => PG(10), :x => ESS()), 2000; discard_initial=100)
    means = Dict(:b => 0.5, "x[1]" => 1.0, "x[2]" => 2.0)
    stds = Dict(:b => 0.5, "x[1]" => 1.0, "x[2]" => 1.0)
    for vn in keys(means)
        @test isapprox(mean(skipmissing(chain[:, vn, 1])), means[vn]; atol=0.1)
        @test isapprox(std(skipmissing(chain[:, vn, 1])), stds[vn]; atol=0.1)
    end
end
```
In general, we can no longer handle a varying number of Gibbs-conditioned variables. However, in the specific case of these tests, we can change the model to something equivalent that has the same number of observations:

```julia
@model function dynamic_model_for_ess()
    b ~ Bernoulli()
    if b
        x ~ MvNormal([1.0], I)
    else
        x ~ MvNormal([1.0, 2.0], I)
    end
end
```
The one below this can similarly use `filldist`.

Do you think it's worth keeping these tests? That is, was the point to make sure that the sampling results are numerically correct, or was it to make sure that you could have different numbers of Gibbs-conditioned variables (in which case rewriting it would defeat the purpose)?
(If it's the former, I think we should include a likelihood term; being able to sample from the prior correctly is a bit of a low bar to clear for an inference algorithm implementation.)
It's a regression test, and I don't know what the original issue was that caused it to error. We don't have too many dynamic model tests, so I would keep it with the `MvNormal`, because it doesn't cost us much and might still catch something.

An alternative, though: any interest in turning this into a test that checks whether you can still do dynamic models with loops if you balance out the branches with `@addlogprob!(0.0)` calls? Hacky, but I kinda want to know that it works as a workaround if need be. So that would be something like
```julia
@model function dynamic_model_for_ess()
    b ~ Bernoulli()
    x_length = b ? 1 : 2
    max_x = 2
    x = Vector{Float64}(undef, x_length)
    for i in 1:x_length
        x[i] ~ Normal(i, 1.0)
    end
    for _ in (x_length + 1):max_x
        @addlogprob!(0.0)
    end
end
```
Okay, happy to do that; I'll add a comment explaining the intent too.
Currently unsure what to do with the tests.

```julia
@model function dynamic_bernoulli()
    b ~ Bernoulli()
    if b
        x ~ MvNormal([1.0], I)
    else
        x ~ MvNormal([1.0, 2.0], I)
    end
end

chain = sample(dynamic_bernoulli(), Gibbs(:b => PG(10), :x => ESS()), 2000; discard_initial=100)
```

This errors because if `x` in the global varinfo is a length-2 vector and PG selects `b = 1`, then `logpdf(MvNormal([1.0], I), x)` will fail.
```julia
@model function dynamic_bernoulli_2()
    b ~ Bernoulli()
    x_length = b ? 1 : 2
    max_x = 2
    x = Vector{Float64}(undef, x_length)
    for i in 1:x_length
        x[i] ~ Normal(i, 1.0)
    end
    for _ in (x_length + 1):max_x
        @addlogprob!(0.0)
    end
end

chain = sample(dynamic_bernoulli_2(), Gibbs(:b => PG(10), :x => ESS()), 2000; discard_initial=100)
```

This runs but gives a rather incorrect value of `b = 0.75` (ish; depends on number of iterations and such).
I think I actually understand why the mean is wrong now (at least when it runs; I don't know why it crashes). Suppose that `x[1] = 1.0` and `x[2] = 2.0` (on average this is true).

For particles with `b = true` (i.e. we only observe `x[1]`), `Libtask.produce` gets called twice: once with `logpdf(Normal(1.0), 1.0) = -0.919` and once with a manual logpdf of 0.

For particles with `b = false` (i.e. we observe `x[1]` and `x[2]`), `Libtask.produce` gets called twice: once with `logpdf(Normal(1.0), 1.0) = -0.919` and once with `logpdf(Normal(2.0), 2.0) = -0.919`.

When PG starts its thing, it just samples `b` from the prior, so there should be 50% `true` and 50% `false`.

On the first reweighting, the logpdfs (i.e. weights) are equal, so one would expect the resampling step to generate equal amounts of `true` and `false`.

On the second reweighting, however, the weights differ. The unnormalised weights are `exp(0) = 1` and `exp(-0.919) = 0.3989`, which is a ratio of about 2.5, so when resampling you're approximately 2.5x more likely to get a particle with `b = true` than with `b = false`.

So `mean(b)` would be expected to be `2.5 / (2.5 + 1) = 0.715`, which is sort of in the right region.

Furthermore, in general `x[2]` will not be equal to 2.0, meaning that `logpdf(Normal(2.0), x[2])` will be smaller and the ratio of weights larger than 2.5, so `mean(b)` will be even more biased towards 1. This is exactly what we observe, with `mean(b)` coming out at 0.772.
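As a standalone sanity check on the arithmetic (a Python sketch, not Turing code; `normal_logpdf` is a hypothetical helper defined here for illustration), the unnormalised weights, their ratio, and the implied `mean(b)` can be reproduced numerically:

```python
import math

def normal_logpdf(x, mu=0.0, sigma=1.0):
    # log density of Normal(mu, sigma) evaluated at x
    return -0.5 * math.log(2 * math.pi * sigma**2) - 0.5 * ((x - mu) / sigma) ** 2

# Weights at the second produce call for the two particle types:
w_true = math.exp(0.0)                        # manual logpdf of 0 from @addlogprob!(0.0)
w_false = math.exp(normal_logpdf(2.0, 2.0))   # logpdf(Normal(2.0), 2.0) ~ -0.919

ratio = w_true / w_false              # ~ 2.51: how much likelier b = true survives
expected_mean_b = ratio / (ratio + 1)  # ~ 0.715, matching the observed bias

print(round(normal_logpdf(2.0, 2.0), 3), round(ratio, 2), round(expected_mean_b, 3))
```

This reproduces the `-0.919` log density, the roughly 2.5:1 resampling ratio, and the predicted `mean(b) ≈ 0.715`.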
Indeed, if you make this change:

```diff
 for _ in (x_length+1):max_x
-    @addlogprob!(0.0)
+    # same as @addlogprob!(logpdf(Normal(2.0), 2.0))
+    @addlogprob!(logpdf(Normal(), 0.0))
 end
```

you get back a value of `mean(b) = 0.592`, which is closer to 0.5 (albeit still larger than 0.5, because of the last paragraph in the previous comment).
And to completely remove the bias, I think you'd have to add a term that is the expectation value of `logpdf(Normal(), x)` where `x ~ Normal()`. That's a slightly nasty integral, but we can approximate it with Monte Carlo (or alternatively use WolframAlpha):

```julia
julia> mean(logpdf.(Normal(), randn(10_000_000)))
-1.418849291536427
```

Calling `@addlogprob!` with this gives... `mean(b) = 0.5050`!
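For what it's worth, that expectation also has a closed form: with x ~ Normal(0, 1), E[logpdf(Normal(0,1), x)] = -(1/2)log(2π) - (1/2)E[x²] = -(1 + log 2π)/2 ≈ -1.4189, matching the Monte Carlo estimate above. A quick standalone check (Python sketch, not Turing code):

```python
import math
import random

# Closed form: E[log N(x; 0, 1)] = -0.5 * (1 + log(2*pi)) for x ~ N(0, 1)
analytic = -0.5 * (1 + math.log(2 * math.pi))

# Monte Carlo approximation of the same expectation
rng = random.Random(0)
mc = sum(
    -0.5 * math.log(2 * math.pi) - 0.5 * rng.gauss(0, 1) ** 2
    for _ in range(200_000)
) / 200_000

print(round(analytic, 4))
assert abs(mc - analytic) < 0.01
```

The analytic value, about -1.4189, agrees with the `-1.41885...` Monte Carlo estimate quoted above.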
Finally, none of this actually explains why it used to "work" with the previous non-tilde-observe Gibbs sampler. The current behaviour makes complete sense to me given our tilde-observe-Gibbs implementation, but it still seems to me that the exact same thing should have happened before this PR. To understand that I might have to take another long walk.
Thanks. That makes sense.

I think I know why it used to work: the problem was masked by another bug. Namely, we would accumulate the logp contributions from the Gibbs-conditioned `tilde_assume` statements into the VarInfo, and then add it to the score/log likelihood at the next produce/resampling, with the line `return score + DynamicPPL.getlogp(varinfo)` in `advance!`. But in the above model there is no next produce/resampling, because there is no likelihood term. Observe instead this, running on the latest release:

```julia
julia> module MWE
       using Turing
       @model function dynamic_bernoulli()
           b ~ Bernoulli()
           x_length = b ? 1 : 2
           x = Vector{Float64}(undef, x_length)
           for i in 1:x_length
               x[i] ~ Normal(i, 1.0)
           end
           0.0 ~ Normal()
       end
       chain = sample(dynamic_bernoulli(), Gibbs(:b => PG(10), :x => ESS()), 2000; discard_initial=100)
       describe(chain)
       end
```
```
Chains MCMC chain (2000×4×1 reshape(::Matrix{Union{Missing, Float64}}, 2000, 4, 1) with eltype Union{Missing, Float64}):

Iterations        = 101:1:2100
Number of chains  = 1
Samples per chain = 2000
Wall duration     = 4.23 seconds
Compute duration  = 4.23 seconds
parameters        = b, x[1], x[2]
internals         = lp

Summary Statistics
  parameters      mean       std      mcse    ess_bulk   ess_tail     rhat   ess_per_sec
      Symbol   Float64   Float64  Float64?    Float64?   Float64?  Float64?     Float64?

           b    0.7975    0.4020    0.0098   1692.0019        NaN    0.9995     399.7170
        x[1]    0.9588    1.0110    0.0232   1888.4870  1014.7637    0.9999     446.1344
        x[2]    2.0290    0.9868   missing     missing    missing   missing      missing

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

           b    0.0000    1.0000    1.0000    1.0000    1.0000
        x[1]   -1.0495    0.2851    0.9589    1.6325    2.9982
        x[2]    0.1122    1.2627    2.0212    2.6988    3.9777
```
On your earlier explanation of how PG works: that's helpful, and matches what I imagined it would be. In that light, I just didn't understand how this could come to pass:

> I actually wonder if this is a Libtask thing. Let's say that we have a particle with `b == false`, so `x_length = 2`. After it sees `x[1]`, it has to resample. Suppose it then resamples (transforms itself) into a particle with `b == true` and `x_length = 1`. What happens then: does it go on to sample `x[2]`, or does it exit the loop because the termination condition `x_length == 1` is now hit?
One comment/proposal, otherwise I think all good to go.
Pretty sure that the diff of this PR matches what you reviewed, so I'll merge. Dynamic models to be discussed at next Monday's team meeting...
This PR:

(1) Adds `ProduceLogLikelihoodAccumulator` on steps n ≥ 2 of PG; otherwise `Libtask.produce` is never called.

(2) Removes all uses of `set_retained_vns_del!`, to avoid errors with PG-within-Gibbs (the issue is that it would get a VariableOrderAccumulator from a different component sampler that didn't have the variables that PG cared about).

The role of `set_retained_vns_del!` was to mark specific variables for resampling by comparing `order` against `num_produce`. The idea was that we would only mark variables for resampling if they were 'in the future', i.e. we hadn't reached their tilde-statements in the model yet.

Instead of doing that, we now simply mark everything as to be resampled. This is fine because models can only be propagated in the forward direction, i.e. there is really no harm in marking previously seen variables for resampling, because we won't see them again! There are two cases where this could lead to different behaviour; in particular, a VarInfo (with `del` flags set) may later be re-used. To circumvent this problem, I introduce an inverse function which marks all variables as not being for resampling. It turns out that there is only one place where this inverse function has to be called: specifically, it has to be called before creating the reference particle on the next pMCMC iteration, to ensure that the trajectory of the reference particle is indeed fixed.

Collectively, this completely removes the need to track which variables have already been seen. With these proposed changes there is simply no need to track the `order` of a variable or the `num_produce` of the varinfo, meaning that we can remove VariableOrderAccumulator entirely (with concomitant positive implications for performance).

I have not exhaustively tested these changes on `breaking`, but I have tried making exactly the same changes to the `del` flag on `main`, and they lead to no changes in the sampled results if the PRNG seed is set.

Note

While we aren't quite there yet, this change also paves the way for eventually getting rid of the `del` flag entirely (TuringLang/DynamicPPL.jl#982).

Notice that the changes in this PR already remove the need for a per-variable `del` flag; instead, we can have a single boolean for the entire VarInfo which marks whether it should be resampled or not. This would already represent a substantial simplification.

However, ultimately we would like to shift this information outside of the VarInfo, because then VarInfo doesn't need to track information that only PG needs. The way I would propose to do this is with contexts: one which never resamples variables (which would correspond to all `del` flags unset), and one which always resamples variables (corresponding to all `del` flags set). Notice further that these roles are very similar to `DefaultContext` and the upcoming `InitContext`. There are probably a ton of subtleties (e.g., in that we have to use the traced RNG rather than just naively resampling), but I think there is light at the end of the tunnel.
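The reference-particle mechanics described above can be illustrated with a toy, hypothetical sketch in Python (not Turing's or Libtask's actual implementation; `csmc_sweep` and the toy state-space model are made up for illustration). Ordinary particles always draw their variables fresh going forward (the "everything marked for resampling" behaviour), while the reference particle replays its fixed trajectory (the "nothing marked for resampling" behaviour):

```python
import math
import random

def csmc_sweep(ys, n_particles, ref_traj, rng):
    """One conditional-SMC sweep for the toy model
    x_t ~ Normal(x_{t-1}, 1), y_t ~ Normal(x_t, 1), with x_0 = 0.

    Particle 0 is the reference particle: its trajectory is fixed in advance.
    All other particles always sample x_t fresh as they move forward.
    """
    T = len(ys)
    trajs = [[] for _ in range(n_particles)]
    for t in range(T):
        log_w = []
        for i in range(n_particles):
            prev = trajs[i][-1] if t > 0 else 0.0
            # reference replays its stored value; others resample forward
            x = ref_traj[t] if i == 0 else rng.gauss(prev, 1.0)
            trajs[i].append(x)
            log_w.append(-0.5 * (ys[t] - x) ** 2)  # Normal(x, 1) log-lik, up to const
        m = max(log_w)
        w = [math.exp(lw - m) for lw in log_w]
        s = sum(w)
        # resample ancestors; slot 0 always keeps the reference trajectory intact
        idx = [0] + rng.choices(
            range(n_particles), weights=[wi / s for wi in w], k=n_particles - 1
        )
        trajs = [list(trajs[a]) for a in idx]
    # return one surviving trajectory (simplified: uniform choice)
    return trajs[rng.randrange(n_particles)]

rng = random.Random(0)
new_ref = csmc_sweep([1.0, 1.5, 2.0], 5, [0.0, 0.0, 0.0], rng)
```

The key property mirrored here is that the reference trajectory must be fixed before the sweep starts, which is exactly why the inverse ("mark nothing for resampling") function only needs to be called once per pMCMC iteration, when the reference particle is created.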