From cea1f7d986c611ce84b8edb00529b987818c7593 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 15 May 2025 16:13:06 +0100 Subject: [PATCH 01/49] First efforts towards DPPL 0.37 compat, WIP --- Project.toml | 2 +- ext/TuringOptimExt.jl | 33 +++--- src/mcmc/Inference.jl | 3 - src/mcmc/ess.jl | 20 +--- src/mcmc/gibbs.jl | 12 +- src/mcmc/hmc.jl | 4 - src/mcmc/is.jl | 4 - src/mcmc/mh.jl | 4 - src/mcmc/particle_mcmc.jl | 22 ++-- src/optimisation/Optimisation.jl | 147 ++++++++++++------------ test/Project.toml | 2 +- test/mcmc/Inference.jl | 45 -------- test/mcmc/hmc.jl | 2 + test/mcmc/mh.jl | 2 + test/optimisation/Optimisation.jl | 67 +++-------- test/test_utils/ad_utils.jl | 185 ++++++++++++++++++++++++++++++ 16 files changed, 321 insertions(+), 233 deletions(-) create mode 100644 test/test_utils/ad_utils.jl diff --git a/Project.toml b/Project.toml index 3956c079fd..4ab392230e 100644 --- a/Project.toml +++ b/Project.toml @@ -64,7 +64,7 @@ Distributions = "0.25.77" DistributionsAD = "0.6" DocStringExtensions = "0.8, 0.9" DynamicHMC = "3.4" -DynamicPPL = "0.36.3" +DynamicPPL = "0.37" EllipticalSliceSampling = "0.5, 1, 2" ForwardDiff = "0.10.3, 1" Libtask = "0.8.8" diff --git a/ext/TuringOptimExt.jl b/ext/TuringOptimExt.jl index d6c253e2a2..9f5c51a2b4 100644 --- a/ext/TuringOptimExt.jl +++ b/ext/TuringOptimExt.jl @@ -34,8 +34,8 @@ function Optim.optimize( options::Optim.Options=Optim.Options(); kwargs..., ) - ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext()) - f = Optimisation.OptimLogDensity(model, ctx) + vi = DynamicPPL.setaccs!!(VarInfo(model), (DynamicPPL.LogLikelihoodAccumulator(),)) + f = Optimisation.OptimLogDensity(model, vi) init_vals = DynamicPPL.getparams(f.ldf) optimizer = Optim.LBFGS() return _mle_optimize(model, init_vals, optimizer, options; kwargs...) 
@@ -57,8 +57,8 @@ function Optim.optimize( options::Optim.Options=Optim.Options(); kwargs..., ) - ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext()) - f = Optimisation.OptimLogDensity(model, ctx) + vi = DynamicPPL.setaccs!!(VarInfo(model), (DynamicPPL.LogLikelihoodAccumulator(),)) + f = Optimisation.OptimLogDensity(model, vi) init_vals = DynamicPPL.getparams(f.ldf) return _mle_optimize(model, init_vals, optimizer, options; kwargs...) end @@ -74,8 +74,9 @@ function Optim.optimize( end function _mle_optimize(model::DynamicPPL.Model, args...; kwargs...) - ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext()) - return _optimize(Optimisation.OptimLogDensity(model, ctx), args...; kwargs...) + vi = DynamicPPL.setaccs!!(VarInfo(model), (DynamicPPL.LogLikelihoodAccumulator(),)) + f = Optimisation.OptimLogDensity(model, vi) + return _optimize(f, args...; kwargs...) end """ @@ -104,8 +105,8 @@ function Optim.optimize( options::Optim.Options=Optim.Options(); kwargs..., ) - ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext()) - f = Optimisation.OptimLogDensity(model, ctx) + vi = DynamicPPL.setaccs!!(VarInfo(model), (LogPriorWithoutJacobianAccumulator(), DynamicPPL.LogLikelihoodAccumulator(),)) + f = Optimisation.OptimLogDensity(model, vi) init_vals = DynamicPPL.getparams(f.ldf) optimizer = Optim.LBFGS() return _map_optimize(model, init_vals, optimizer, options; kwargs...) @@ -127,8 +128,8 @@ function Optim.optimize( options::Optim.Options=Optim.Options(); kwargs..., ) - ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext()) - f = Optimisation.OptimLogDensity(model, ctx) + vi = DynamicPPL.setaccs!!(VarInfo(model), (LogPriorWithoutJacobianAccumulator(), DynamicPPL.LogLikelihoodAccumulator(),)) + f = Optimisation.OptimLogDensity(model, vi) init_vals = DynamicPPL.getparams(f.ldf) return _map_optimize(model, init_vals, optimizer, options; kwargs...) 
end @@ -144,9 +145,11 @@ function Optim.optimize( end function _map_optimize(model::DynamicPPL.Model, args...; kwargs...) - ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext()) - return _optimize(Optimisation.OptimLogDensity(model, ctx), args...; kwargs...) + vi = DynamicPPL.setaccs!!(VarInfo(model), (LogPriorWithoutJacobianAccumulator(), DynamicPPL.LogLikelihoodAccumulator(),)) + f = Optimisation.OptimLogDensity(model, vi) + return _optimize(f, args...; kwargs...) end + """ _optimize(f::OptimLogDensity, optimizer=Optim.LBFGS(), args...; kwargs...) @@ -166,7 +169,7 @@ function _optimize( # whether initialisation is really necessary at all vi = DynamicPPL.unflatten(f.ldf.varinfo, init_vals) vi = DynamicPPL.link(vi, f.ldf.model) - f = Optimisation.OptimLogDensity(f.ldf.model, vi, f.ldf.context; adtype=f.ldf.adtype) + f = Optimisation.OptimLogDensity(f.ldf.model, vi; adtype=f.ldf.adtype) init_vals = DynamicPPL.getparams(f.ldf) # Optimize! @@ -183,9 +186,7 @@ function _optimize( # Get the optimum in unconstrained space. `getparams` does the invlinking. vi = f.ldf.varinfo vi_optimum = DynamicPPL.unflatten(vi, M.minimizer) - logdensity_optimum = Optimisation.OptimLogDensity( - f.ldf.model, vi_optimum, f.ldf.context - ) + logdensity_optimum = Optimisation.OptimLogDensity(f.ldf.model, vi_optimum; adtype=f.ldf.adtype) vns_vals_iter = Turing.Inference.getparams(f.ldf.model, vi_optimum) varnames = map(Symbol ∘ first, vns_vals_iter) vals = map(last, vns_vals_iter) diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index 0370e619a3..52d4277b0f 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -26,9 +26,6 @@ using DynamicPPL: SampleFromPrior, SampleFromUniform, DefaultContext, - PriorContext, - LikelihoodContext, - SamplingContext, set_flag!, unset_flag! 
using Distributions, Libtask, Bijectors diff --git a/src/mcmc/ess.jl b/src/mcmc/ess.jl index 5448173486..5205772032 100644 --- a/src/mcmc/ess.jl +++ b/src/mcmc/ess.jl @@ -49,7 +49,7 @@ function AbstractMCMC.step( rng, EllipticalSliceSampling.ESSModel( ESSPrior(model, spl, vi), - DynamicPPL.LogDensityFunction( + DynamicPPL.LogDensityFunction{:LogLikelihood}( model, vi, DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext()) ), ), @@ -59,7 +59,7 @@ function AbstractMCMC.step( # update sample and log-likelihood vi = DynamicPPL.unflatten(vi, sample) - vi = setlogp!!(vi, state.loglikelihood) + vi = setloglikelihood!!(vi, state.loglikelihood) return Transition(model, vi), vi end @@ -108,20 +108,12 @@ end # Mean of prior distribution Distributions.mean(p::ESSPrior) = p.μ -# Evaluate log-likelihood of proposals -const ESSLogLikelihood{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo} = - DynamicPPL.LogDensityFunction{M,V,<:DynamicPPL.SamplingContext{<:S},AD} where {AD} - -(ℓ::ESSLogLikelihood)(f::AbstractVector) = LogDensityProblems.logdensity(ℓ, f) - function DynamicPPL.tilde_assume( - rng::Random.AbstractRNG, ::DefaultContext, ::Sampler{<:ESS}, right, vn, vi + rng::Random.AbstractRNG, ctx::DefaultContext, ::Sampler{<:ESS}, right, vn, vi ) - return DynamicPPL.tilde_assume( - rng, LikelihoodContext(), SampleFromPrior(), right, vn, vi - ) + return DynamicPPL.tilde_assume(rng, ctx, SampleFromPrior(), right, vn, vi) end -function DynamicPPL.tilde_observe(ctx::DefaultContext, ::Sampler{<:ESS}, right, left, vi) - return DynamicPPL.tilde_observe(ctx, SampleFromPrior(), right, left, vi) +function DynamicPPL.tilde_observe!!(ctx::DefaultContext, ::Sampler{<:ESS}, right, left, vn, vi) + return DynamicPPL.tilde_observe!!(ctx, SampleFromPrior(), right, left, vn, vi) end diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index f36cb9c364..34d372a9e4 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -33,7 +33,7 @@ can_be_wrapped(ctx::DynamicPPL.PrefixContext) = 
can_be_wrapped(ctx.context) # # Purpose: avoid triggering resampling of variables we're conditioning on. # - Using standard `DynamicPPL.condition` results in conditioned variables being treated -# as observations in the truest sense, i.e. we hit `DynamicPPL.tilde_observe`. +# as observations in the truest sense, i.e. we hit `DynamicPPL.tilde_observe!!`. # - But `observe` is overloaded by some samplers, e.g. `CSMC`, which can lead to # undesirable behavior, e.g. `CSMC` triggering a resampling for every conditioned variable # rather than only for the "true" observations. @@ -178,16 +178,18 @@ function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) DynamicPPL.tilde_assume(child_context, right, vn, vi) elseif has_conditioned_gibbs(context, vn) # Short-circuit the tilde assume if `vn` is present in `context`. - value, lp, _ = DynamicPPL.tilde_assume( + # TODO(mhauru) Fix accumulation here. In this branch anything that gets + # accumulated just gets discarded with `_`. + value, _ = DynamicPPL.tilde_assume( child_context, right, vn, get_global_varinfo(context) ) - value, lp, vi + value, vi else # If the varname has not been conditioned on, nor is it a target variable, its # presumably a new variable that should be sampled from its prior. We need to add # this new variable to the global `varinfo` of the context, but not to the local one # being used by the current sampler. 
- value, lp, new_global_vi = DynamicPPL.tilde_assume( + value, new_global_vi = DynamicPPL.tilde_assume( child_context, DynamicPPL.SampleFromPrior(), right, @@ -195,7 +197,7 @@ function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) get_global_varinfo(context), ) set_global_varinfo!(context, new_global_vi) - value, lp, vi + value, vi end end diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl index b5f51587b1..5175a9831c 100644 --- a/src/mcmc/hmc.jl +++ b/src/mcmc/hmc.jl @@ -516,10 +516,6 @@ function DynamicPPL.assume( return DynamicPPL.assume(dist, vn, vi) end -function DynamicPPL.observe(::Sampler{<:Hamiltonian}, d::Distribution, value, vi) - return DynamicPPL.observe(d, value, vi) -end - #### #### Default HMC stepsize and mass matrix adaptor #### diff --git a/src/mcmc/is.jl b/src/mcmc/is.jl index d83abd173c..9ad0e1f82a 100644 --- a/src/mcmc/is.jl +++ b/src/mcmc/is.jl @@ -55,7 +55,3 @@ function DynamicPPL.assume(rng, ::Sampler{<:IS}, dist::Distribution, vn::VarName end return r, 0, vi end - -function DynamicPPL.observe(::Sampler{<:IS}, dist::Distribution, value, vi) - return logpdf(dist, value), vi -end diff --git a/src/mcmc/mh.jl b/src/mcmc/mh.jl index fb50c5f582..97f4209bec 100644 --- a/src/mcmc/mh.jl +++ b/src/mcmc/mh.jl @@ -392,7 +392,3 @@ function DynamicPPL.assume( retval = DynamicPPL.assume(rng, SampleFromPrior(), dist, vn, vi) return retval end - -function DynamicPPL.observe(spl::Sampler{<:MH}, d::Distribution, value, vi) - return DynamicPPL.observe(SampleFromPrior(), d, value, vi) -end diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index ffc1019519..a4d7ef1dc2 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -450,10 +450,11 @@ function DynamicPPL.assume( return r, lp, vi end -function DynamicPPL.observe(spl::Sampler{<:Union{PG,SMC}}, dist::Distribution, value, vi) - # NOTE: The `Libtask.produce` is now hit in `acclogp_observe!!`. 
- return logpdf(dist, value), trace_local_varinfo_maybe(vi) -end +# TODO(mhauru) Fix this. +# function DynamicPPL.observe(spl::Sampler{<:Union{PG,SMC}}, dist::Distribution, value, vi) +# # NOTE: The `Libtask.produce` is now hit in `acclogp_observe!!`. +# return logpdf(dist, value), trace_local_varinfo_maybe(vi) +# end function DynamicPPL.acclogp!!( context::SamplingContext{<:Sampler{<:Union{PG,SMC}}}, varinfo::AbstractVarInfo, logp @@ -462,12 +463,13 @@ function DynamicPPL.acclogp!!( return DynamicPPL.acclogp!!(DynamicPPL.childcontext(context), varinfo_trace, logp) end -function DynamicPPL.acclogp_observe!!( - context::SamplingContext{<:Sampler{<:Union{PG,SMC}}}, varinfo::AbstractVarInfo, logp -) - Libtask.produce(logp) - return trace_local_varinfo_maybe(varinfo) -end +# TODO(mhauru) Fix this. +# function DynamicPPL.acclogp_observe!!( +# context::SamplingContext{<:Sampler{<:Union{PG,SMC}}}, varinfo::AbstractVarInfo, logp +# ) +# Libtask.produce(logp) +# return trace_local_varinfo_maybe(varinfo) +# end # Convenient constructor function AdvancedPS.Trace( diff --git a/src/optimisation/Optimisation.jl b/src/optimisation/Optimisation.jl index ddcc27b876..23da8b08a6 100644 --- a/src/optimisation/Optimisation.jl +++ b/src/optimisation/Optimisation.jl @@ -43,75 +43,87 @@ Concrete type for maximum a posteriori estimation. Only used for the Optim.jl in """ struct MAP <: ModeEstimator end +# Most of these functions for LogPriorWithoutJacobianAccumulator are copied from +# LogPriorAccumulator. The only one that is different is the accumulate_assume!! one. """ - OptimizationContext{C<:AbstractContext} <: AbstractContext + LogPriorWithoutJacobianAccumulator{T} <: DynamicPPL.AbstractAccumulator -The `OptimizationContext` transforms variables to their constrained space, but -does not use the density with respect to the transformation. This context is -intended to allow an optimizer to sample in R^n freely. 
+Exactly like DynamicPPL.LogPriorAccumulator, but does not include the log determinant of the +Jacobian of any variable transformations. + +Used for MAP optimisation. """ -struct OptimizationContext{C<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext - context::C +struct LogPriorWithoutJacobianAccumulator{T} <: DynamicPPL.AbstractAccumulator + logp::T +end - function OptimizationContext{C}(context::C) where {C<:DynamicPPL.AbstractContext} - if !( - context isa Union{ - DynamicPPL.DefaultContext, - DynamicPPL.LikelihoodContext, - DynamicPPL.PriorContext, - } - ) - msg = """ - `OptimizationContext` supports only leaf contexts of type - `DynamicPPL.DefaultContext`, `DynamicPPL.LikelihoodContext`, - and `DynamicPPL.PriorContext` (given: `$(typeof(context)))` - """ - throw(ArgumentError(msg)) - end - return new{C}(context) - end +""" + LogPriorWithoutJacobianAccumulator{T}() + +Create a new `LogPriorWithoutJacobianAccumulator` accumulator with the log prior initialized to zero. +""" +LogPriorWithoutJacobianAccumulator{T}() where {T<:Real} = LogPriorWithoutJacobianAccumulator(zero(T)) +LogPriorWithoutJacobianAccumulator() = LogPriorWithoutJacobianAccumulator{DynamicPPL.LogProbType}() + +function Base.show(io::IO, acc::LogPriorWithoutJacobianAccumulator) + return print(io, "LogPriorWithoutJacobianAccumulator($(repr(acc.logp)))") end -OptimizationContext(ctx::DynamicPPL.AbstractContext) = OptimizationContext{typeof(ctx)}(ctx) +# We use the same name for LogPriorWithoutJacobianAccumulator as for LogPriorAccumulator. +# This has three effects: +# 1. You can't have a VarInfo with both accumulator types. +# 2. When you call functions like `getlogprior` on a VarInfo, it will return the one without +# the Jacobian term, as if that was the usual log prior. +# 3. This may cause a small number of invalidations in DynamicPPL. I haven't checked, but I +# suspect they will be negligible. +# TODO(mhauru) Not sure I like this solution. 
It's kinda glib, but might confuse a reader +# of the code who expects things like `getlogprior` to always get the LogPriorAccumulator +# contents. Another solution would be welcome, but would need to play nicely with how +# LogDenssityFunction works, since it calls `getlogprior` explictily. +DynamicPPL.accumulator_name(::Type{<:LogPriorWithoutJacobianAccumulator}) = :LogPrior + +DynamicPPL.split(::LogPriorWithoutJacobianAccumulator{T}) where {T} = LogPriorWithoutJacobianAccumulator(zero(T)) + +function DynamicPPL.combine(acc::LogPriorWithoutJacobianAccumulator, acc2::LogPriorWithoutJacobianAccumulator) + return LogPriorWithoutJacobianAccumulator(acc.logp + acc2.logp) +end -DynamicPPL.NodeTrait(::OptimizationContext) = DynamicPPL.IsLeaf() +function Base.:+(acc1::LogPriorWithoutJacobianAccumulator, acc2::LogPriorWithoutJacobianAccumulator) + return LogPriorWithoutJacobianAccumulator(acc1.logp + acc2.logp) +end -function DynamicPPL.tilde_assume(ctx::OptimizationContext, dist, vn, vi) - r = vi[vn, dist] - lp = if ctx.context isa Union{DynamicPPL.DefaultContext,DynamicPPL.PriorContext} - # MAP - Distributions.logpdf(dist, r) - else - # MLE - 0 - end - return r, lp, vi +Base.zero(acc::LogPriorWithoutJacobianAccumulator) = LogPriorWithoutJacobianAccumulator(zero(acc.logp)) + +function DynamicPPL.accumulate_assume!!(acc::LogPriorWithoutJacobianAccumulator, val, logjac, vn, right) + return acc + LogPriorWithoutJacobianAccumulator(Distributions.logpdf(right, val)) end +DynamicPPL.accumulate_observe!!(acc::LogPriorWithoutJacobianAccumulator, right, left, vn) = acc -function DynamicPPL.tilde_observe( - ctx::OptimizationContext{<:DynamicPPL.PriorContext}, args... -) - return DynamicPPL.tilde_observe(ctx.context, args...) 
+function Base.convert(::Type{LogPriorWithoutJacobianAccumulator{T}}, acc::LogPriorWithoutJacobianAccumulator) where {T} + return LogPriorWithoutJacobianAccumulator(convert(T, acc.logp)) +end + +function DynamicPPL.convert_eltype(::Type{T}, acc::LogPriorWithoutJacobianAccumulator) where {T} + return LogPriorWithoutJacobianAccumulator(convert(T, acc.logp)) end """ OptimLogDensity{ M<:DynamicPPL.Model, - V<:DynamicPPL.VarInfo, - C<:OptimizationContext, + V<:DynamicPPL.AbstractVarInfo, AD<:ADTypes.AbstractADType } A struct that wraps a single LogDensityFunction. Can be invoked either using ```julia -OptimLogDensity(model, varinfo, ctx; adtype=adtype) +OptimLogDensity(model, varinfo; adtype=adtype) ``` or ```julia -OptimLogDensity(model, ctx; adtype=adtype) +OptimLogDensity(model; adtype=adtype) ``` If not specified, `adtype` defaults to `AutoForwardDiff()`. @@ -129,37 +141,20 @@ the underlying LogDensityFunction at the point `z`. This is done to satisfy the Optim.jl interface. ```julia -optim_ld = OptimLogDensity(model, varinfo, ctx) +optim_ld = OptimLogDensity(model, varinfo) optim_ld(z) # returns -logp ``` """ struct OptimLogDensity{ M<:DynamicPPL.Model, - V<:DynamicPPL.VarInfo, - C<:OptimizationContext, + V<:DynamicPPL.AbstractVarInfo, AD<:ADTypes.AbstractADType, } - ldf::DynamicPPL.LogDensityFunction{M,V,C,AD} -end - -function OptimLogDensity( - model::DynamicPPL.Model, - vi::DynamicPPL.VarInfo, - ctx::OptimizationContext; - adtype::ADTypes.AbstractADType=AutoForwardDiff(), -) - return OptimLogDensity(DynamicPPL.LogDensityFunction(model, vi, ctx; adtype=adtype)) + ldf::DynamicPPL.LogDensityFunction{M,V,AD} end -# No varinfo -function OptimLogDensity( - model::DynamicPPL.Model, - ctx::OptimizationContext; - adtype::ADTypes.AbstractADType=AutoForwardDiff(), -) - return OptimLogDensity( - DynamicPPL.LogDensityFunction(model, DynamicPPL.VarInfo(model), ctx; adtype=adtype) - ) +function OptimLogDensity(model::DynamicPPL.Model, 
vi::DynamicPPL.AbstractVarInfo=DynamicPPL.VarInfo(model); adtype=AutoForwardDiff()) + return OptimLogDensity(DynamicPPL.LogDensityFunction(model, vi; adtype=adtype)) end """ @@ -325,10 +320,11 @@ function StatsBase.informationmatrix( # Convert the values to their unconstrained states to make sure the # Hessian is computed with respect to the untransformed parameters. - linked = DynamicPPL.istrans(m.f.ldf.varinfo) + old_ldf = m.f.ldf + linked = DynamicPPL.istrans(old_ldf.varinfo) if linked - new_vi = DynamicPPL.invlink!!(m.f.ldf.varinfo, m.f.ldf.model) - new_f = OptimLogDensity(m.f.ldf.model, new_vi, m.f.ldf.context) + new_vi = DynamicPPL.invlink!!(old_ldf.varinfo, old_ldf.model) + new_f = OptimLogDensity(old_ldf.model, new_vi; adtype=old_ldf.adtype) m = Accessors.@set m.f = new_f end @@ -339,8 +335,9 @@ function StatsBase.informationmatrix( # Link it back if we invlinked it. if linked - new_vi = DynamicPPL.link!!(m.f.ldf.varinfo, m.f.ldf.model) - new_f = OptimLogDensity(m.f.ldf.model, new_vi, m.f.ldf.context) + invlinked_ldf = m.f.ldf + new_vi = DynamicPPL.link!!(invlinked_ldf.varinfo, invlinked_ldf.model) + new_f = OptimLogDensity(invlinked_ldf.model, new_vi; adtype=invlinked_ldf.adtype) m = Accessors.@set m.f = new_f end @@ -560,12 +557,11 @@ function estimate_mode( # Create an OptimLogDensity object that can be used to evaluate the objective function, # i.e. the negative log density. - inner_context = if estimator isa MAP - DynamicPPL.DefaultContext() + accs = if estimator isa MAP + (LogPriorWithoutJacobianAccumulator(), DynamicPPL.LogLikelihoodAccumulator()) else - DynamicPPL.LikelihoodContext() + (DynamicPPL.LogLikelihoodAccumulator(),) end - ctx = OptimizationContext(inner_context) # Set its VarInfo to the initial parameters. # TODO(penelopeysm): Unclear if this is really needed? 
Any time that logp is calculated @@ -574,6 +570,7 @@ function estimate_mode( # directly on the fields of the LogDensityFunction vi = DynamicPPL.VarInfo(model) vi = DynamicPPL.unflatten(vi, initial_params) + vi = DynamicPPL.setaccs!!(vi, accs) # Link the varinfo if needed. # TODO(mhauru) We currently couple together the questions of whether the user specified @@ -585,7 +582,7 @@ function estimate_mode( vi = DynamicPPL.link(vi, model) end - log_density = OptimLogDensity(model, vi, ctx) + log_density = OptimLogDensity(model, vi) prob = Optimization.OptimizationProblem(log_density, adtype, constraints) solution = Optimization.solve(prob, solver; kwargs...) diff --git a/test/Project.toml b/test/Project.toml index 0048224d50..42f32936cb 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -53,7 +53,7 @@ Combinatorics = "1" Distributions = "0.25" DistributionsAD = "0.6.3" DynamicHMC = "2.1.6, 3.0" -DynamicPPL = "0.36.12" +DynamicPPL = "0.37" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12 - 0.10.32, 0.10, 1" HypothesisTests = "0.11" diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl index a0d4421869..38baa46fc2 100644 --- a/test/mcmc/Inference.jl +++ b/test/mcmc/Inference.jl @@ -113,36 +113,6 @@ using Turing check_gdemo(chn3_contd) end - @testset "Contexts" begin - # Test LikelihoodContext - @model function testmodel1(x) - a ~ Beta() - lp1 = getlogp(__varinfo__) - x[1] ~ Bernoulli(a) - return global loglike = getlogp(__varinfo__) - lp1 - end - model = testmodel1([1.0]) - varinfo = DynamicPPL.VarInfo(model) - model(varinfo, DynamicPPL.SampleFromPrior(), DynamicPPL.LikelihoodContext()) - @test getlogp(varinfo) == loglike - - # Test MiniBatchContext - @model function testmodel2(x) - a ~ Beta() - return x[1] ~ Bernoulli(a) - end - model = testmodel2([1.0]) - varinfo1 = DynamicPPL.VarInfo(model) - varinfo2 = deepcopy(varinfo1) - model(varinfo1, DynamicPPL.SampleFromPrior(), DynamicPPL.LikelihoodContext()) - model( - varinfo2, - 
DynamicPPL.SampleFromPrior(), - DynamicPPL.MiniBatchContext(DynamicPPL.LikelihoodContext(), 10), - ) - @test isapprox(getlogp(varinfo2) / getlogp(varinfo1), 10) - end - @testset "Prior" begin N = 10_000 @@ -174,21 +144,6 @@ using Turing @test mean(x[:s][1] for x in chains) ≈ 3 atol = 0.11 @test mean(x[:m][1] for x in chains) ≈ 0 atol = 0.1 end - - @testset "#2169" begin - # Not exactly the same as the issue, but similar. - @model function issue2169_model() - if DynamicPPL.leafcontext(__context__) isa DynamicPPL.PriorContext - x ~ Normal(0, 1) - else - x ~ Normal(1000, 1) - end - end - - model = issue2169_model() - chain = sample(StableRNG(seed), model, Prior(), 10) - @test all(mean(chain[:x]) .< 5) - end end @testset "chain ordering" begin diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index 8832e5fe7b..f78c7a0237 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -171,6 +171,8 @@ using Turing @test Array(res1) == Array(res2) == Array(res3) end + # TODO(mhauru) Do we give up being able to sample from only prior/likelihood like this, + # or do we implement some way to pass `whichlogprob=:LogPrior` through `sample`? @testset "prior" begin # NOTE: Used to use `InverseGamma(2, 3)` but this has infinite variance # which means that it's _very_ difficult to find a good tolerance in the test below:) diff --git a/test/mcmc/mh.jl b/test/mcmc/mh.jl index add2e7404a..3bbb83db5f 100644 --- a/test/mcmc/mh.jl +++ b/test/mcmc/mh.jl @@ -262,6 +262,8 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var)) @test !DynamicPPL.islinked(vi) end + # TODO(mhauru) Do we give up being able to sample from only prior/likelihood like this, + # or do we implement some way to pass `whichlogprob=:LogPrior` through `sample`? 
@testset "prior" begin alg = MH() gdemo_default_prior = DynamicPPL.contextualize( diff --git a/test/optimisation/Optimisation.jl b/test/optimisation/Optimisation.jl index 2acb7edc55..9850a3ed39 100644 --- a/test/optimisation/Optimisation.jl +++ b/test/optimisation/Optimisation.jl @@ -25,27 +25,6 @@ using Turing # Issue: https://discourse.julialang.org/t/two-equivalent-conditioning-syntaxes-giving-different-likelihood-values/100320 @testset "OptimizationContext" begin - # Used for testing how well it works with nested contexts. - struct OverrideContext{C,T1,T2} <: DynamicPPL.AbstractContext - context::C - logprior_weight::T1 - loglikelihood_weight::T2 - end - DynamicPPL.NodeTrait(::OverrideContext) = DynamicPPL.IsParent() - DynamicPPL.childcontext(parent::OverrideContext) = parent.context - DynamicPPL.setchildcontext(parent::OverrideContext, child) = - OverrideContext(child, parent.logprior_weight, parent.loglikelihood_weight) - - # Only implement what we need for the models above. - function DynamicPPL.tilde_assume(context::OverrideContext, right, vn, vi) - value, logp, vi = DynamicPPL.tilde_assume(context.context, right, vn, vi) - return value, context.logprior_weight, vi - end - function DynamicPPL.tilde_observe(context::OverrideContext, right, left, vi) - logp, vi = DynamicPPL.tilde_observe(context.context, right, left, vi) - return context.loglikelihood_weight, vi - end - @model function model1(x) μ ~ Uniform(0, 2) return x ~ LogNormal(μ, 1) @@ -62,48 +41,34 @@ using Turing @testset "With ConditionContext" begin m1 = model1(x) m2 = model2() | (x=x,) - ctx = Turing.Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext()) - @test Turing.Optimisation.OptimLogDensity(m1, ctx)(w) == - Turing.Optimisation.OptimLogDensity(m2, ctx)(w) + @test Turing.Optimisation.OptimLogDensity(m1)(w) == + Turing.Optimisation.OptimLogDensity(m2)(w) end @testset "With prefixes" begin vn = @varname(inner) m1 = prefix(model1(x), vn) m2 = prefix((model2() | (x=x,)), vn) - ctx = 
Turing.Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext()) - @test Turing.Optimisation.OptimLogDensity(m1, ctx)(w) == - Turing.Optimisation.OptimLogDensity(m2, ctx)(w) - end - - @testset "Weighted" begin - function override(model) - return DynamicPPL.contextualize( - model, OverrideContext(model.context, 100, 1) - ) - end - m1 = override(model1(x)) - m2 = override(model2() | (x=x,)) - ctx = Turing.Optimisation.OptimizationContext(DynamicPPL.DefaultContext()) - @test Turing.Optimisation.OptimLogDensity(m1, ctx)(w) == - Turing.Optimisation.OptimLogDensity(m2, ctx)(w) + @test Turing.Optimisation.OptimLogDensity(m1)(w) == + Turing.Optimisation.OptimLogDensity(m2)(w) end @testset "Default, Likelihood, Prior Contexts" begin m1 = model1(x) - defctx = Turing.Optimisation.OptimizationContext(DynamicPPL.DefaultContext()) - llhctx = Turing.Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext()) - prictx = Turing.Optimisation.OptimizationContext(DynamicPPL.PriorContext()) + vi = DynamicPPL.VarInfo(m1) + vi_joint = DynamicPPL.setaccs!!(deepcopy(vi), (LogPriorWithoutJacobianAccumulator(), LogLikelihoodAccumulator())) + vi_prior = DynamicPPL.setaccs!!(deepcopy(vi), (LogPriorWithoutJacobianAccumulator(),)) + vi_likelihood = DynamicPPL.setaccs!!(deepcopy(vi), (LogLikelihoodAccumulator(),)) a = [0.3] - @test Turing.Optimisation.OptimLogDensity(m1, defctx)(a) == - Turing.Optimisation.OptimLogDensity(m1, llhctx)(a) + - Turing.Optimisation.OptimLogDensity(m1, prictx)(a) + @test Turing.Optimisation.OptimLogDensity(m1, vi_joint)(a) == + Turing.Optimisation.OptimLogDensity(m1, vi_prior)(a) + + Turing.Optimisation.OptimLogDensity(m1, vi_likelihood)(a) - # test that PriorContext is calculating the right thing - @test Turing.Optimisation.OptimLogDensity(m1, prictx)([0.3]) ≈ + # test that the prior accumulator is calculating the right thing + @test Turing.Optimisation.OptimLogDensity(m1, vi_prior)([0.3]) ≈ -Distributions.logpdf(Uniform(0, 2), 0.3) - @test 
Turing.Optimisation.OptimLogDensity(m1, prictx)([-0.3]) ≈ + @test Turing.Optimisation.OptimLogDensity(m1, vi_prior)([-0.3]) ≈ -Distributions.logpdf(Uniform(0, 2), -0.3) end end @@ -651,8 +616,8 @@ using Turing return nothing end m = saddle_model() - ctx = Turing.Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext()) - optim_ld = Turing.Optimisation.OptimLogDensity(m, ctx) + vi = DynamicPPL.setaccs!!(DynamicPPL.VarInfo(m), (LogLikelihoodAccumulator(),)) + optim_ld = Turing.Optimisation.OptimLogDensity(m, vi) vals = Turing.Optimisation.NamedArrays.NamedArray([0.0, 0.0]) m = Turing.Optimisation.ModeResult(vals, nothing, 0.0, optim_ld) ct = coeftable(m) diff --git a/test/test_utils/ad_utils.jl b/test/test_utils/ad_utils.jl new file mode 100644 index 0000000000..36bd7a9f68 --- /dev/null +++ b/test/test_utils/ad_utils.jl @@ -0,0 +1,185 @@ +module ADUtils + +using ForwardDiff: ForwardDiff +using Pkg: Pkg +using Random: Random +using ReverseDiff: ReverseDiff +using Mooncake: Mooncake +using Test: Test +using Turing: Turing +using Turing: DynamicPPL + +export ADTypeCheckContext, adbackends + +# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # +# Stuff for checking that the right AD backend is being used. + +"""Element types that are always valid for a VarInfo regardless of ADType.""" +const always_valid_eltypes = (AbstractFloat, AbstractIrrational, Integer, Rational) + +"""A dictionary mapping ADTypes to the element types they use.""" +const eltypes_by_adtype = Dict( + Turing.AutoForwardDiff => (ForwardDiff.Dual,), + Turing.AutoReverseDiff => ( + ReverseDiff.TrackedArray, + ReverseDiff.TrackedMatrix, + ReverseDiff.TrackedReal, + ReverseDiff.TrackedStyle, + ReverseDiff.TrackedType, + ReverseDiff.TrackedVecOrMat, + ReverseDiff.TrackedVector, + ), + Turing.AutoMooncake => (Mooncake.CoDual,), +) + +""" + AbstractWrongADBackendError + +An abstract error thrown when we seem to be using a different AD backend than expected. 
+""" +abstract type AbstractWrongADBackendError <: Exception end + +""" + WrongADBackendError + +An error thrown when we seem to be using a different AD backend than expected. +""" +struct WrongADBackendError <: AbstractWrongADBackendError + actual_adtype::Type + expected_adtype::Type +end + +function Base.showerror(io::IO, e::WrongADBackendError) + return print( + io, "Expected to use $(e.expected_adtype), but using $(e.actual_adtype) instead." + ) +end + +""" + IncompatibleADTypeError + +An error thrown when an element type is encountered that is unexpected for the given ADType. +""" +struct IncompatibleADTypeError <: AbstractWrongADBackendError + valtype::Type + adtype::Type +end + +function Base.showerror(io::IO, e::IncompatibleADTypeError) + return print( + io, + "Incompatible ADType: Did not expect element of type $(e.valtype) with $(e.adtype)", + ) +end + +""" + ADTypeCheckContext{ADType,ChildContext} + +A context for checking that the expected ADType is being used. + +Evaluating a model with this context will check that the types of values in a `VarInfo` are +compatible with the ADType of the context. If the check fails, an `IncompatibleADTypeError` +is thrown. + +For instance, evaluating a model with +`ADTypeCheckContext(AutoForwardDiff(), child_context)` +would throw an error if within the model a type associated with e.g. ReverseDiff was +encountered. + +""" +struct ADTypeCheckContext{ADType,ChildContext<:DynamicPPL.AbstractContext} <: + DynamicPPL.AbstractContext + child::ChildContext + + function ADTypeCheckContext(adbackend, child) + adtype = adbackend isa Type ? 
adbackend : typeof(adbackend) + if !any(adtype <: k for k in keys(eltypes_by_adtype)) + throw(ArgumentError("Unsupported ADType: $adtype")) + end + return new{adtype,typeof(child)}(child) + end +end + +adtype(_::ADTypeCheckContext{ADType}) where {ADType} = ADType + +DynamicPPL.NodeTrait(::ADTypeCheckContext) = DynamicPPL.IsParent() +DynamicPPL.childcontext(c::ADTypeCheckContext) = c.child +function DynamicPPL.setchildcontext(c::ADTypeCheckContext, child) + return ADTypeCheckContext(adtype(c), child) +end + +""" + valid_eltypes(context::ADTypeCheckContext) + +Return the element types that are valid for the ADType of `context` as a tuple. +""" +function valid_eltypes(context::ADTypeCheckContext) + context_at = adtype(context) + for at in keys(eltypes_by_adtype) + if context_at <: at + return (eltypes_by_adtype[at]..., always_valid_eltypes...) + end + end + # This should never be reached due to the check in the inner constructor. + throw(ArgumentError("Unsupported ADType: $(adtype(context))")) +end + +""" + check_adtype(context::ADTypeCheckContext, vi::DynamicPPL.VarInfo) + +Check that the element types in `vi` are compatible with the ADType of `context`. + +Throw an `IncompatibleADTypeError` if an incompatible element type is encountered. +""" +function check_adtype(context::ADTypeCheckContext, vi::DynamicPPL.AbstractVarInfo) + valids = valid_eltypes(context) + for val in vi[:] + valtype = typeof(val) + if !any(valtype .<: valids) + throw(IncompatibleADTypeError(valtype, adtype(context))) + end + end + return nothing +end + +# A bunch of tilde_assume/tilde_observe methods that just call the same method on the child +# context, and then call check_adtype on the result before returning the results from the +# child context. 
+ +function DynamicPPL.tilde_assume(context::ADTypeCheckContext, right, vn, vi) + value, vi = DynamicPPL.tilde_assume( + DynamicPPL.childcontext(context), right, vn, vi + ) + check_adtype(context, vi) + return value, vi +end + +function DynamicPPL.tilde_assume( + rng::Random.AbstractRNG, context::ADTypeCheckContext, sampler, right, vn, vi +) + value, vi = DynamicPPL.tilde_assume( + rng, DynamicPPL.childcontext(context), sampler, right, vn, vi + ) + check_adtype(context, vi) + return value, vi +end + +function DynamicPPL.tilde_observe!!(context::ADTypeCheckContext, right, left, vn, vi) + left, vi = DynamicPPL.tilde_observe(DynamicPPL.childcontext(context), right, left, vn, vi) + check_adtype(context, vi) + return left, vi +end + +# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # +# List of AD backends to test. + +""" +All the ADTypes on which we want to run the tests. +""" +adbackends = [ + Turing.AutoForwardDiff(), + Turing.AutoReverseDiff(; compile=false), + Turing.AutoMooncake(; config=nothing), +] + +end From 5d860d9b0df22aea1a639c4da1f844da976edcec Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 20 May 2025 17:53:05 +0100 Subject: [PATCH 02/49] More DPPL 0.37 compat work, WIP --- ext/TuringOptimExt.jl | 26 +++---- src/essential/container.jl | 73 ++++++++++++++++++ src/mcmc/Inference.jl | 11 +-- src/mcmc/ess.jl | 4 +- src/mcmc/hmc.jl | 11 ++- src/mcmc/particle_mcmc.jl | 18 ++--- src/optimisation/Optimisation.jl | 124 +++++++++++++++++++++--------- test/optimisation/Optimisation.jl | 45 ++++++----- test/test_utils/ad_utils.jl | 8 +- 9 files changed, 229 insertions(+), 91 deletions(-) create mode 100644 src/essential/container.jl diff --git a/ext/TuringOptimExt.jl b/ext/TuringOptimExt.jl index 9f5c51a2b4..635eb89111 100644 --- a/ext/TuringOptimExt.jl +++ b/ext/TuringOptimExt.jl @@ -34,8 +34,7 @@ function Optim.optimize( options::Optim.Options=Optim.Options(); kwargs..., ) - vi = DynamicPPL.setaccs!!(VarInfo(model), 
(DynamicPPL.LogLikelihoodAccumulator(),)) - f = Optimisation.OptimLogDensity(model, vi) + f = Optimisation.OptimLogDensity(model, DynamicPPL.getloglikelihood) init_vals = DynamicPPL.getparams(f.ldf) optimizer = Optim.LBFGS() return _mle_optimize(model, init_vals, optimizer, options; kwargs...) @@ -57,8 +56,7 @@ function Optim.optimize( options::Optim.Options=Optim.Options(); kwargs..., ) - vi = DynamicPPL.setaccs!!(VarInfo(model), (DynamicPPL.LogLikelihoodAccumulator(),)) - f = Optimisation.OptimLogDensity(model, vi) + f = Optimisation.OptimLogDensity(model, DynamicPPL.getloglikelihood) init_vals = DynamicPPL.getparams(f.ldf) return _mle_optimize(model, init_vals, optimizer, options; kwargs...) end @@ -74,8 +72,7 @@ function Optim.optimize( end function _mle_optimize(model::DynamicPPL.Model, args...; kwargs...) - vi = DynamicPPL.setaccs!!(VarInfo(model), (DynamicPPL.LogLikelihoodAccumulator(),)) - f = Optimisation.OptimLogDensity(model, vi) + f = Optimisation.OptimLogDensity(model, DynamicPPL.getloglikelihood) return _optimize(f, args...; kwargs...) end @@ -105,8 +102,7 @@ function Optim.optimize( options::Optim.Options=Optim.Options(); kwargs..., ) - vi = DynamicPPL.setaccs!!(VarInfo(model), (LogPriorWithoutJacobianAccumulator(), DynamicPPL.LogLikelihoodAccumulator(),)) - f = Optimisation.OptimLogDensity(model, vi) + f = Optimisation.OptimLogDensity(model, Optimisation.getlogjoint_without_jacobian) init_vals = DynamicPPL.getparams(f.ldf) optimizer = Optim.LBFGS() return _map_optimize(model, init_vals, optimizer, options; kwargs...) 
@@ -128,8 +124,7 @@ function Optim.optimize( options::Optim.Options=Optim.Options(); kwargs..., ) - vi = DynamicPPL.setaccs!!(VarInfo(model), (LogPriorWithoutJacobianAccumulator(), DynamicPPL.LogLikelihoodAccumulator(),)) - f = Optimisation.OptimLogDensity(model, vi) + f = Optimisation.OptimLogDensity(model, Optimisation.getlogjoint_without_jacobian) init_vals = DynamicPPL.getparams(f.ldf) return _map_optimize(model, init_vals, optimizer, options; kwargs...) end @@ -145,8 +140,7 @@ function Optim.optimize( end function _map_optimize(model::DynamicPPL.Model, args...; kwargs...) - vi = DynamicPPL.setaccs!!(VarInfo(model), (LogPriorWithoutJacobianAccumulator(), DynamicPPL.LogLikelihoodAccumulator(),)) - f = Optimisation.OptimLogDensity(model, vi) + f = Optimisation.OptimLogDensity(model, Optimisation.getlogjoint_without_jacobian) return _optimize(f, args...; kwargs...) end @@ -169,7 +163,9 @@ function _optimize( # whether initialisation is really necessary at all vi = DynamicPPL.unflatten(f.ldf.varinfo, init_vals) vi = DynamicPPL.link(vi, f.ldf.model) - f = Optimisation.OptimLogDensity(f.ldf.model, vi; adtype=f.ldf.adtype) + f = Optimisation.OptimLogDensity( + f.ldf.model, f.ldf.getlogdensity, vi; adtype=f.ldf.adtype + ) init_vals = DynamicPPL.getparams(f.ldf) # Optimize! @@ -186,7 +182,9 @@ function _optimize( # Get the optimum in unconstrained space. `getparams` does the invlinking. 
vi = f.ldf.varinfo vi_optimum = DynamicPPL.unflatten(vi, M.minimizer) - logdensity_optimum = Optimisation.OptimLogDensity(f.ldf.model, vi_optimum; adtype=f.ldf.adtype) + logdensity_optimum = Optimisation.OptimLogDensity( + f.ldf.model, f.ldf.getlogdensity, vi_optimum; adtype=f.ldf.adtype + ) vns_vals_iter = Turing.Inference.getparams(f.ldf.model, vi_optimum) varnames = map(Symbol ∘ first, vns_vals_iter) vals = map(last, vns_vals_iter) diff --git a/src/essential/container.jl b/src/essential/container.jl new file mode 100644 index 0000000000..5c78e110fa --- /dev/null +++ b/src/essential/container.jl @@ -0,0 +1,73 @@ +struct TracedModel{S<:AbstractSampler,V<:AbstractVarInfo,M<:Model,E<:Tuple} <: + AdvancedPS.AbstractGenericModel + model::M + sampler::S + varinfo::V + evaluator::E +end + +function TracedModel( + model::Model, + sampler::AbstractSampler, + varinfo::AbstractVarInfo, + rng::Random.AbstractRNG, +) + context = SamplingContext(rng, sampler, DefaultContext()) + args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo, context) + if kwargs !== nothing && !isempty(kwargs) + error( + "Sampling with `$(sampler.alg)` does not support models with keyword arguments. See issue #2007 for more details.", + ) + end + return TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}( + model, sampler, varinfo, (model.f, args...) + ) +end + +function AdvancedPS.advance!( + trace::AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}}, isref::Bool=false +) + # Make sure we load/reset the rng in the new replaying mechanism + # TODO(mhauru) Stop ignoring the return value. + DynamicPPL.increment_num_produce!!(trace.model.f.varinfo) + isref ? 
AdvancedPS.load_state!(trace.rng) : AdvancedPS.save_state!(trace.rng) + score = consume(trace.model.ctask) + if score === nothing + return nothing + else + return score + DynamicPPL.getlogp(trace.model.f.varinfo) + end +end + +function AdvancedPS.delete_retained!(trace::TracedModel) + DynamicPPL.set_retained_vns_del!(trace.varinfo) + return trace +end + +function AdvancedPS.reset_model(trace::TracedModel) + new_vi = DynamicPPL.reset_num_produce!!(trace.varinfo) + trace = TracedModel(trace.model, trace.sampler, new_vi, trace.evaluator) + return trace +end + +function AdvancedPS.reset_logprob!(trace::TracedModel) + # TODO(mhauru) Stop ignoring the return value. + DynamicPPL.resetlogp!!(trace.model.varinfo) + return trace +end + +function AdvancedPS.update_rng!( + trace::AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}} +) + # Extract the `args`. + args = trace.model.ctask.args + # From `args`, extract the `SamplingContext`, which contains the RNG. + sampling_context = args[3] + rng = sampling_context.rng + trace.rng = rng + return trace +end + +function Libtask.TapedTask(model::TracedModel, ::Random.AbstractRNG, args...; kwargs...) # RNG ? + return Libtask.TapedTask(model.evaluator[1], model.evaluator[2:end]...; kwargs...) 
+end diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index 52d4277b0f..15efe2ad18 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -18,6 +18,7 @@ using DynamicPPL: push!!, setlogp!!, getlogp, + getlogjoint, VarName, getsym, getdist, @@ -136,7 +137,7 @@ end Transition(θ, lp) = Transition(θ, lp, nothing) function Transition(model::DynamicPPL.Model, vi::AbstractVarInfo, t) θ = getparams(model, vi) - lp = getlogp(vi) + lp = getlogjoint(vi) return Transition(θ, lp, getstats(t)) end @@ -149,10 +150,10 @@ function metadata(t::Transition) end end -DynamicPPL.getlogp(t::Transition) = t.lp +DynamicPPL.getlogjoint(t::Transition) = t.lp # Metadata of VarInfo object -metadata(vi::AbstractVarInfo) = (lp=getlogp(vi),) +metadata(vi::AbstractVarInfo) = (lp=getlogjoint(vi),) ########################## # Chain making utilities # @@ -215,7 +216,7 @@ function _params_to_array(model::DynamicPPL.Model, ts::Vector) end function get_transition_extras(ts::AbstractVector{<:VarInfo}) - valmat = reshape([getlogp(t) for t in ts], :, 1) + valmat = reshape([getlogjoint(t) for t in ts], :, 1) return [:lp], valmat end @@ -434,7 +435,7 @@ julia> chain = Chains(randn(2, 1, 1), ["m"]); # 2 samples of `m` julia> transitions = Turing.Inference.transitions_from_chain(m, chain); -julia> [Turing.Inference.getlogp(t) for t in transitions] # extract the logjoints +julia> [Turing.Inference.getlogjoint(t) for t in transitions] # extract the logjoints 2-element Array{Float64,1}: -3.6294991938628374 -2.5697948166987845 diff --git a/src/mcmc/ess.jl b/src/mcmc/ess.jl index 5205772032..aeafa13ad3 100644 --- a/src/mcmc/ess.jl +++ b/src/mcmc/ess.jl @@ -114,6 +114,8 @@ function DynamicPPL.tilde_assume( return DynamicPPL.tilde_assume(rng, ctx, SampleFromPrior(), right, vn, vi) end -function DynamicPPL.tilde_observe!!(ctx::DefaultContext, ::Sampler{<:ESS}, right, left, vn, vi) +function DynamicPPL.tilde_observe!!( + ctx::DefaultContext, ::Sampler{<:ESS}, right, left, vn, vi +) return 
DynamicPPL.tilde_observe!!(ctx, SampleFromPrior(), right, left, vn, vi) end diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl index 5175a9831c..792175420f 100644 --- a/src/mcmc/hmc.jl +++ b/src/mcmc/hmc.jl @@ -214,7 +214,7 @@ function DynamicPPL.initialstep( theta = vi[:] # Cache current log density. - log_density_old = getlogp(vi) + log_density_old = getloglikelihood(vi) # Find good eps if not provided one if iszero(spl.alg.ϵ) @@ -242,10 +242,12 @@ function DynamicPPL.initialstep( # Update `vi` based on acceptance if t.stat.is_accept vi = DynamicPPL.unflatten(vi, t.z.θ) - vi = setlogp!!(vi, t.stat.log_density) + # TODO(mhauru) Is setloglikelihood! the right thing here? + vi = setloglikelihood!!(vi, t.stat.log_density) else vi = DynamicPPL.unflatten(vi, theta) - vi = setlogp!!(vi, log_density_old) + # TODO(mhauru) Is setloglikelihood! the right thing here? + vi = setloglikelihood!!(vi, log_density_old) end transition = Transition(model, vi, t) @@ -290,7 +292,8 @@ function AbstractMCMC.step( vi = state.vi if t.stat.is_accept vi = DynamicPPL.unflatten(vi, t.z.θ) - vi = setlogp!!(vi, t.stat.log_density) + # TODO(mhauru) Is setloglikelihood! the right thing here? + vi = setloglikelihood!!(vi, t.stat.log_density) end # Compute next transition and state. diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index a4d7ef1dc2..77820edd51 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -193,10 +193,10 @@ function DynamicPPL.initialstep( kwargs..., ) # Reset the VarInfo. - DynamicPPL.reset_num_produce!(vi) - DynamicPPL.set_retained_vns_del!(vi) - DynamicPPL.resetlogp!!(vi) - DynamicPPL.empty!!(vi) + vi = DynamicPPL.reset_num_produce!!(vi) + set_retained_vns_del!(vi) + vi = DynamicPPL.resetlogp!!(vi) + vi = DynamicPPL.empty!!(vi) # Create a new set of particles. 
particles = AdvancedPS.ParticleContainer( @@ -327,9 +327,9 @@ function DynamicPPL.initialstep( kwargs..., ) # Reset the VarInfo before new sweep - DynamicPPL.reset_num_produce!(vi) + vi = DynamicPPL.reset_num_produce!!(vi) DynamicPPL.set_retained_vns_del!(vi) - DynamicPPL.resetlogp!!(vi) + vi = DynamicPPL.resetlogp!!(vi) # Create a new set of particles num_particles = spl.alg.nparticles @@ -359,8 +359,8 @@ function AbstractMCMC.step( ) # Reset the VarInfo before new sweep. vi = state.vi - DynamicPPL.reset_num_produce!(vi) - DynamicPPL.resetlogp!!(vi) + vi = DynamicPPL.reset_num_produce!!(vi) + vi = DynamicPPL.resetlogp!!(vi) # Create reference particle for which the samples will be retained. reference = AdvancedPS.forkr(AdvancedPS.Trace(model, spl, vi, state.rng)) @@ -479,7 +479,7 @@ function AdvancedPS.Trace( rng::AdvancedPS.TracedRNG, ) newvarinfo = deepcopy(varinfo) - DynamicPPL.reset_num_produce!(newvarinfo) + newvarinfo = DynamicPPL.reset_num_produce!!(newvarinfo) tmodel = TracedModel(model, sampler, newvarinfo, rng) newtrace = AdvancedPS.Trace(tmodel, rng) diff --git a/src/optimisation/Optimisation.jl b/src/optimisation/Optimisation.jl index 23da8b08a6..80582019ea 100644 --- a/src/optimisation/Optimisation.jl +++ b/src/optimisation/Optimisation.jl @@ -62,56 +62,99 @@ end Create a new `LogPriorWithoutJacobianAccumulator` accumulator with the log prior initialized to zero. 
""" -LogPriorWithoutJacobianAccumulator{T}() where {T<:Real} = LogPriorWithoutJacobianAccumulator(zero(T)) -LogPriorWithoutJacobianAccumulator() = LogPriorWithoutJacobianAccumulator{DynamicPPL.LogProbType}() +LogPriorWithoutJacobianAccumulator{T}() where {T<:Real} = + LogPriorWithoutJacobianAccumulator(zero(T)) +function LogPriorWithoutJacobianAccumulator() + return LogPriorWithoutJacobianAccumulator{DynamicPPL.LogProbType}() +end function Base.show(io::IO, acc::LogPriorWithoutJacobianAccumulator) return print(io, "LogPriorWithoutJacobianAccumulator($(repr(acc.logp)))") end -# We use the same name for LogPriorWithoutJacobianAccumulator as for LogPriorAccumulator. -# This has three effects: -# 1. You can't have a VarInfo with both accumulator types. -# 2. When you call functions like `getlogprior` on a VarInfo, it will return the one without -# the Jacobian term, as if that was the usual log prior. -# 3. This may cause a small number of invalidations in DynamicPPL. I haven't checked, but I -# suspect they will be negligible. -# TODO(mhauru) Not sure I like this solution. It's kinda glib, but might confuse a reader -# of the code who expects things like `getlogprior` to always get the LogPriorAccumulator -# contents. Another solution would be welcome, but would need to play nicely with how -# LogDenssityFunction works, since it calls `getlogprior` explictily. 
-DynamicPPL.accumulator_name(::Type{<:LogPriorWithoutJacobianAccumulator}) = :LogPrior +function DynamicPPL.accumulator_name(::Type{<:LogPriorWithoutJacobianAccumulator}) + return :LogPriorWithoutJacobian +end -DynamicPPL.split(::LogPriorWithoutJacobianAccumulator{T}) where {T} = LogPriorWithoutJacobianAccumulator(zero(T)) +function DynamicPPL.split(::LogPriorWithoutJacobianAccumulator{T}) where {T} + return LogPriorWithoutJacobianAccumulator(zero(T)) +end -function DynamicPPL.combine(acc::LogPriorWithoutJacobianAccumulator, acc2::LogPriorWithoutJacobianAccumulator) +function DynamicPPL.combine( + acc::LogPriorWithoutJacobianAccumulator, acc2::LogPriorWithoutJacobianAccumulator +) return LogPriorWithoutJacobianAccumulator(acc.logp + acc2.logp) end -function Base.:+(acc1::LogPriorWithoutJacobianAccumulator, acc2::LogPriorWithoutJacobianAccumulator) +function Base.:+( + acc1::LogPriorWithoutJacobianAccumulator, acc2::LogPriorWithoutJacobianAccumulator +) return LogPriorWithoutJacobianAccumulator(acc1.logp + acc2.logp) end -Base.zero(acc::LogPriorWithoutJacobianAccumulator) = LogPriorWithoutJacobianAccumulator(zero(acc.logp)) +function Base.zero(acc::LogPriorWithoutJacobianAccumulator) + return LogPriorWithoutJacobianAccumulator(zero(acc.logp)) +end -function DynamicPPL.accumulate_assume!!(acc::LogPriorWithoutJacobianAccumulator, val, logjac, vn, right) +function DynamicPPL.accumulate_assume!!( + acc::LogPriorWithoutJacobianAccumulator, val, logjac, vn, right +) return acc + LogPriorWithoutJacobianAccumulator(Distributions.logpdf(right, val)) end -DynamicPPL.accumulate_observe!!(acc::LogPriorWithoutJacobianAccumulator, right, left, vn) = acc +function DynamicPPL.accumulate_observe!!( + acc::LogPriorWithoutJacobianAccumulator, right, left, vn +) + return acc +end -function Base.convert(::Type{LogPriorWithoutJacobianAccumulator{T}}, acc::LogPriorWithoutJacobianAccumulator) where {T} +function Base.convert( + ::Type{LogPriorWithoutJacobianAccumulator{T}}, 
acc::LogPriorWithoutJacobianAccumulator +) where {T} return LogPriorWithoutJacobianAccumulator(convert(T, acc.logp)) end -function DynamicPPL.convert_eltype(::Type{T}, acc::LogPriorWithoutJacobianAccumulator) where {T} +function DynamicPPL.convert_eltype( + ::Type{T}, acc::LogPriorWithoutJacobianAccumulator +) where {T} return LogPriorWithoutJacobianAccumulator(convert(T, acc.logp)) end +function getlogprior_without_jacobian(vi::DynamicPPL.AbstractVarInfo) + acc = DynamicPPL.getacc(vi, Val(:LogPriorWithoutJacobian)) + return acc.logp +end + +function getlogjoint_without_jacobian(vi::DynamicPPL.AbstractVarInfo) + return getlogprior_without_jacobian(vi) + DynamicPPL.getloglikelihood(vi) +end + +# This is called when constructing a LogDensityFunction, and ensures the VarInfo has the +# right accumulators. +function DynamicPPL.ldf_default_varinfo( + model::DynamicPPL.Model, ::typeof(getlogprior_without_jacobian) +) + vi = DynamicPPL.VarInfo(model) + vi = DynamicPPL.setaccs!!(vi, (LogPriorWithoutJacobianAccumulator(),)) + return vi +end + +function DynamicPPL.ldf_default_varinfo( + model::DynamicPPL.Model, ::typeof(getlogjoint_without_jacobian) +) + vi = DynamicPPL.VarInfo(model) + vi = DynamicPPL.setaccs!!( + vi, (LogPriorWithoutJacobianAccumulator(), DynamicPPL.LogLikelihoodAccumulator()) + ) + return vi +end + """ OptimLogDensity{ M<:DynamicPPL.Model, + F<:Function, V<:DynamicPPL.AbstractVarInfo, - AD<:ADTypes.AbstractADType + C<:DynamicPPL.AbstractContext, + AD<:ADTypes.AbstractADType, } A struct that wraps a single LogDensityFunction. 
Can be invoked either using @@ -147,14 +190,23 @@ optim_ld(z) # returns -logp """ struct OptimLogDensity{ M<:DynamicPPL.Model, + F<:Function, V<:DynamicPPL.AbstractVarInfo, + C<:DynamicPPL.AbstractContext, AD<:ADTypes.AbstractADType, } - ldf::DynamicPPL.LogDensityFunction{M,V,AD} + ldf::DynamicPPL.LogDensityFunction{M,F,V,C,AD} end -function OptimLogDensity(model::DynamicPPL.Model, vi::DynamicPPL.AbstractVarInfo=DynamicPPL.VarInfo(model); adtype=AutoForwardDiff()) - return OptimLogDensity(DynamicPPL.LogDensityFunction(model, vi; adtype=adtype)) +function OptimLogDensity( + model::DynamicPPL.Model, + getlogdensity::Function, + vi::DynamicPPL.AbstractVarInfo=DynamicPPL.ldf_default_varinfo(model, getlogdensity); + adtype=AutoForwardDiff(), +) + return OptimLogDensity( + DynamicPPL.LogDensityFunction(model, getlogdensity, vi; adtype=adtype) + ) end """ @@ -324,7 +376,9 @@ function StatsBase.informationmatrix( linked = DynamicPPL.istrans(old_ldf.varinfo) if linked new_vi = DynamicPPL.invlink!!(old_ldf.varinfo, old_ldf.model) - new_f = OptimLogDensity(old_ldf.model, new_vi; adtype=old_ldf.adtype) + new_f = OptimLogDensity( + old_ldf.model, old_ldf.getlogdensity, new_vi; adtype=old_ldf.adtype + ) m = Accessors.@set m.f = new_f end @@ -337,7 +391,9 @@ function StatsBase.informationmatrix( if linked invlinked_ldf = m.f.ldf new_vi = DynamicPPL.link!!(invlinked_ldf.varinfo, invlinked_ldf.model) - new_f = OptimLogDensity(invlinked_ldf.model, new_vi; adtype=invlinked_ldf.adtype) + new_f = OptimLogDensity( + invlinked_ldf.model, old_ldf.getlogdensity, new_vi; adtype=invlinked_ldf.adtype + ) m = Accessors.@set m.f = new_f end @@ -557,20 +613,16 @@ function estimate_mode( # Create an OptimLogDensity object that can be used to evaluate the objective function, # i.e. the negative log density. 
- accs = if estimator isa MAP - (LogPriorWithoutJacobianAccumulator(), DynamicPPL.LogLikelihoodAccumulator()) - else - (DynamicPPL.LogLikelihoodAccumulator(),) - end + getlogdensity = + estimator isa MAP ? getlogjoint_without_jacobian : DynamicPPL.getloglikelihood # Set its VarInfo to the initial parameters. # TODO(penelopeysm): Unclear if this is really needed? Any time that logp is calculated # (using `LogDensityProblems.logdensity(ldf, x)`) the parameters in the # varinfo are completely ignored. The parameters only matter if you are calling evaluate!! # directly on the fields of the LogDensityFunction - vi = DynamicPPL.VarInfo(model) + vi = DynamicPPL.ldf_default_varinfo(model, getlogdensity) vi = DynamicPPL.unflatten(vi, initial_params) - vi = DynamicPPL.setaccs!!(vi, accs) # Link the varinfo if needed. # TODO(mhauru) We currently couple together the questions of whether the user specified @@ -582,7 +634,7 @@ function estimate_mode( vi = DynamicPPL.link(vi, model) end - log_density = OptimLogDensity(model, vi) + log_density = OptimLogDensity(model, getlogdensity, vi) prob = Optimization.OptimizationProblem(log_density, adtype, constraints) solution = Optimization.solve(prob, solver; kwargs...) 
diff --git a/test/optimisation/Optimisation.jl b/test/optimisation/Optimisation.jl index 9850a3ed39..9909ee149b 100644 --- a/test/optimisation/Optimisation.jl +++ b/test/optimisation/Optimisation.jl @@ -24,7 +24,7 @@ using Turing hasstats(result) = result.optim_result.stats !== nothing # Issue: https://discourse.julialang.org/t/two-equivalent-conditioning-syntaxes-giving-different-likelihood-values/100320 - @testset "OptimizationContext" begin + @testset "OptimLogDensity and contexts" begin @model function model1(x) μ ~ Uniform(0, 2) return x ~ LogNormal(μ, 1) @@ -41,34 +41,44 @@ using Turing @testset "With ConditionContext" begin m1 = model1(x) m2 = model2() | (x=x,) - @test Turing.Optimisation.OptimLogDensity(m1)(w) == - Turing.Optimisation.OptimLogDensity(m2)(w) + # Doesn't matter if we use getlogjoint or getlogjoint_without_jacobian since the + # VarInfo isn't linked. + ld1 = Turing.Optimisation.OptimLogDensity( + m1, Turing.Optimisation.getlogjoint_without_jacobian + ) + ld2 = Turing.Optimisation.OptimLogDensity(m2, DynamicPPL.getlogjoint) + @test ld1(w) == ld2(w) end @testset "With prefixes" begin vn = @varname(inner) m1 = prefix(model1(x), vn) m2 = prefix((model2() | (x=x,)), vn) - @test Turing.Optimisation.OptimLogDensity(m1)(w) == - Turing.Optimisation.OptimLogDensity(m2)(w) + ld1 = Turing.Optimisation.OptimLogDensity( + m1, Turing.Optimisation.getlogjoint_without_jacobian + ) + ld2 = Turing.Optimisation.OptimLogDensity(m2, DynamicPPL.getlogjoint) + @test ld1(w) == ld2(w) end - @testset "Default, Likelihood, Prior Contexts" begin + @testset "Joint, prior, and likelihood" begin m1 = model1(x) - vi = DynamicPPL.VarInfo(m1) - vi_joint = DynamicPPL.setaccs!!(deepcopy(vi), (LogPriorWithoutJacobianAccumulator(), LogLikelihoodAccumulator())) - vi_prior = DynamicPPL.setaccs!!(deepcopy(vi), (LogPriorWithoutJacobianAccumulator(),)) - vi_likelihood = DynamicPPL.setaccs!!(deepcopy(vi), (LogLikelihoodAccumulator(),)) a = [0.3] - - @test 
Turing.Optimisation.OptimLogDensity(m1, vi_joint)(a) == - Turing.Optimisation.OptimLogDensity(m1, vi_prior)(a) + - Turing.Optimisation.OptimLogDensity(m1, vi_likelihood)(a) + ld_joint = Turing.Optimisation.OptimLogDensity( + m1, Turing.Optimisation.getlogjoint_without_jacobian + ) + ld_prior = Turing.Optimisation.OptimLogDensity( + m1, Turing.Optimisation.getlogprior_without_jacobian + ) + ld_likelihood = Turing.Optimisation.OptimLogDensity( + m1, DynamicPPL.getloglikelihood + ) + @test ld_joint(a) == ld_prior(a) + ld_likelihood(a) # test that the prior accumulator is calculating the right thing - @test Turing.Optimisation.OptimLogDensity(m1, vi_prior)([0.3]) ≈ + @test Turing.Optimisation.OptimLogDensity(m1, DynamicPPL.getlogprior)([0.3]) ≈ -Distributions.logpdf(Uniform(0, 2), 0.3) - @test Turing.Optimisation.OptimLogDensity(m1, vi_prior)([-0.3]) ≈ + @test Turing.Optimisation.OptimLogDensity(m1, DynamicPPL.getlogprior)([-0.3]) ≈ -Distributions.logpdf(Uniform(0, 2), -0.3) end end @@ -616,8 +626,7 @@ using Turing return nothing end m = saddle_model() - vi = DynamicPPL.setaccs!!(DynamicPPL.VarInfo(m), (LogLikelihoodAccumulator(),)) - optim_ld = Turing.Optimisation.OptimLogDensity(m, vi) + optim_ld = Turing.Optimisation.OptimLogDensity(m, DynamicPPL.getloglikelihood) vals = Turing.Optimisation.NamedArrays.NamedArray([0.0, 0.0]) m = Turing.Optimisation.ModeResult(vals, nothing, 0.0, optim_ld) ct = coeftable(m) diff --git a/test/test_utils/ad_utils.jl b/test/test_utils/ad_utils.jl index 36bd7a9f68..f1b5b3d145 100644 --- a/test/test_utils/ad_utils.jl +++ b/test/test_utils/ad_utils.jl @@ -147,9 +147,7 @@ end # child context. 
function DynamicPPL.tilde_assume(context::ADTypeCheckContext, right, vn, vi) - value, vi = DynamicPPL.tilde_assume( - DynamicPPL.childcontext(context), right, vn, vi - ) + value, vi = DynamicPPL.tilde_assume(DynamicPPL.childcontext(context), right, vn, vi) check_adtype(context, vi) return value, vi end @@ -165,7 +163,9 @@ function DynamicPPL.tilde_assume( end function DynamicPPL.tilde_observe!!(context::ADTypeCheckContext, right, left, vn, vi) - left, vi = DynamicPPL.tilde_observe(DynamicPPL.childcontext(context), right, left, vn, vi) + left, vi = DynamicPPL.tilde_observe!!( + DynamicPPL.childcontext(context), right, left, vn, vi + ) check_adtype(context, vi) return left, vi end From c7c46385855dfab0a358c5b95a9470a55fd866d4 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 17 Jul 2025 14:42:45 +0100 Subject: [PATCH 03/49] Add [sources] for DynamicPPL@0.37 --- Project.toml | 3 +++ test/Project.toml | 3 +++ 2 files changed, 6 insertions(+) diff --git a/Project.toml b/Project.toml index 4ab392230e..cdf6826368 100644 --- a/Project.toml +++ b/Project.toml @@ -90,3 +90,6 @@ julia = "1.10.2" [extras] DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb" Optim = "429524aa-4258-5aef-a3af-852621145aeb" + +[sources] +DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "breaking"} diff --git a/test/Project.toml b/test/Project.toml index 42f32936cb..7e25379122 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -77,3 +77,6 @@ StatsBase = "0.33, 0.34" StatsFuns = "0.9.5, 1" TimerOutputs = "0.5" julia = "1.10" + +[sources] +DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "breaking"} From f16a5cf7151d150f1233ca4df4d7c8181904822d Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 19 Jul 2025 01:59:53 +0100 Subject: [PATCH 04/49] Remove context argument from `LogDensityFunction` --- ext/TuringDynamicHMCExt.jl | 9 +---- src/mcmc/ess.jl | 23 ++++++------ src/mcmc/external_sampler.jl | 4 +- src/mcmc/gibbs.jl | 2 +- 
src/mcmc/hmc.jl | 22 +---------- src/mcmc/mh.jl | 2 +- src/mcmc/particle_mcmc.jl | 3 ++ src/optimisation/Optimisation.jl | 49 +++++++++++++++++-------- src/variational/VariationalInference.jl | 10 +---- 9 files changed, 58 insertions(+), 66 deletions(-) diff --git a/ext/TuringDynamicHMCExt.jl b/ext/TuringDynamicHMCExt.jl index 5718e3855a..8cb917eb51 100644 --- a/ext/TuringDynamicHMCExt.jl +++ b/ext/TuringDynamicHMCExt.jl @@ -58,16 +58,11 @@ function DynamicPPL.initialstep( # Ensure that initial sample is in unconstrained space. if !DynamicPPL.islinked(vi) vi = DynamicPPL.link!!(vi, model) - vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl))) + vi = last(DynamicPPL.evaluate!!(model, vi)) end # Define log-density function. - ℓ = DynamicPPL.LogDensityFunction( - model, - vi, - DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext()); - adtype=spl.alg.adtype, - ) + ℓ = DynamicPPL.LogDensityFunction(model, vi; adtype=spl.alg.adtype) # Perform initial step. results = DynamicHMC.mcmc_keep_warmup( diff --git a/src/mcmc/ess.jl b/src/mcmc/ess.jl index aeafa13ad3..0f5a7f5cc9 100644 --- a/src/mcmc/ess.jl +++ b/src/mcmc/ess.jl @@ -49,8 +49,8 @@ function AbstractMCMC.step( rng, EllipticalSliceSampling.ESSModel( ESSPrior(model, spl, vi), - DynamicPPL.LogDensityFunction{:LogLikelihood}( - model, vi, DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext()) + ESSLikelihood( + DynamicPPL.LogDensityFunction(model, DynamicPPL.getloglikelihood, vi) ), ), EllipticalSliceSampling.ESS(), @@ -63,7 +63,7 @@ function AbstractMCMC.step( return Transition(model, vi), vi end - +f # Prior distribution of considered random variable struct ESSPrior{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo,T} model::M @@ -97,6 +97,10 @@ function Base.rand(rng::Random.AbstractRNG, p::ESSPrior) sampler = p.sampler varinfo = p.varinfo # TODO: Surely there's a better way of doing this now that we have `SamplingContext`? 
+ # TODO(DPPL0.37/penelopeysm): This can be replaced with `init!!(p.model, + # p.varinfo, PriorInit())` after TuringLang/DynamicPPL.jl#984. The reason + # why we had to use the 'del' flag before this was because + # SampleFromPrior() wouldn't overwrite existing variables. vns = keys(varinfo) for vn in vns set_flag!(varinfo, vn, "del") @@ -108,14 +112,9 @@ end # Mean of prior distribution Distributions.mean(p::ESSPrior) = p.μ -function DynamicPPL.tilde_assume( - rng::Random.AbstractRNG, ctx::DefaultContext, ::Sampler{<:ESS}, right, vn, vi -) - return DynamicPPL.tilde_assume(rng, ctx, SampleFromPrior(), right, vn, vi) +# Evaluate log-likelihood of proposals +struct ESSLogLikelihood{M<:Model,V<:AbstractVarInfo,AD<:ADTypes.AbstractADType} + ldf::DynamicPPL.LogDensityFunction{M,V,AD} end -function DynamicPPL.tilde_observe!!( - ctx::DefaultContext, ::Sampler{<:ESS}, right, left, vn, vi -) - return DynamicPPL.tilde_observe!!(ctx, SampleFromPrior(), right, left, vn, vi) -end +(ℓ::ESSLogLikelihood)(f::AbstractVector) = LogDensityProblems.logdensity(ℓ.ldf, f) diff --git a/src/mcmc/external_sampler.jl b/src/mcmc/external_sampler.jl index 7fa7692e4c..9524b2766d 100644 --- a/src/mcmc/external_sampler.jl +++ b/src/mcmc/external_sampler.jl @@ -96,12 +96,12 @@ getlogp_external(::Any, ::Any) = missing getlogp_external(mh::AdvancedMH.Transition, ::AdvancedMH.Transition) = mh.lp getlogp_external(hmc::AdvancedHMC.Transition, ::AdvancedHMC.HMCState) = hmc.stat.log_density -struct TuringState{S,V1<:AbstractVarInfo,M,V,C} +struct TuringState{S,V1<:AbstractVarInfo,M,V} state::S # Note that this varinfo has the correct parameters and logp obtained from # the state, whereas `ldf.varinfo` will in general have junk inside it. 
varinfo::V1 - ldf::DynamicPPL.LogDensityFunction{M,V,C} + ldf::DynamicPPL.LogDensityFunction{M,V} end varinfo(state::TuringState) = state.varinfo diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 34d372a9e4..0498ec2a8c 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -559,7 +559,7 @@ function setparams_varinfo!!( params::AbstractVarInfo, ) logdensity = DynamicPPL.LogDensityFunction( - model, state.ldf.varinfo, state.ldf.context; adtype=sampler.alg.adtype + model, state.ldf.varinfo; adtype=sampler.alg.adtype ) new_inner_state = setparams_varinfo!!( AbstractMCMC.LogDensityModel(logdensity), sampler, state.state, params diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl index 792175420f..9d19a03c54 100644 --- a/src/mcmc/hmc.jl +++ b/src/mcmc/hmc.jl @@ -190,16 +190,7 @@ function DynamicPPL.initialstep( # Create a Hamiltonian. metricT = getmetricT(spl.alg) metric = metricT(length(theta)) - ldf = DynamicPPL.LogDensityFunction( - model, - vi, - # TODO(penelopeysm): Can we just use leafcontext(model.context)? Do we - # need to pass in the sampler? (In fact LogDensityFunction defaults to - # using leafcontext(model.context) so could we just remove the argument - # entirely?) - DynamicPPL.SamplingContext(rng, spl, DynamicPPL.leafcontext(model.context)); - adtype=spl.alg.adtype, - ) + ldf = DynamicPPL.LogDensityFunction(model, vi; adtype=spl.alg.adtype) lp_func = Base.Fix1(LogDensityProblems.logdensity, ldf) lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ldf) hamiltonian = AHMC.Hamiltonian(metric, lp_func, lp_grad_func) @@ -305,16 +296,7 @@ end function get_hamiltonian(model, spl, vi, state, n) metric = gen_metric(n, spl, state) - ldf = DynamicPPL.LogDensityFunction( - model, - vi, - # TODO(penelopeysm): Can we just use leafcontext(model.context)? Do we - # need to pass in the sampler? (In fact LogDensityFunction defaults to - # using leafcontext(model.context) so could we just remove the argument - # entirely?) 
- DynamicPPL.SamplingContext(spl, DynamicPPL.leafcontext(model.context)); - adtype=spl.alg.adtype, - ) + ldf = DynamicPPL.LogDensityFunction(model, vi; adtype=spl.alg.adtype) lp_func = Base.Fix1(LogDensityProblems.logdensity, ldf) lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ldf) return AHMC.Hamiltonian(metric, lp_func, lp_grad_func) diff --git a/src/mcmc/mh.jl b/src/mcmc/mh.jl index 97f4209bec..53b2ceae12 100644 --- a/src/mcmc/mh.jl +++ b/src/mcmc/mh.jl @@ -189,7 +189,7 @@ A log density function for the MH sampler. This variant uses the `set_namedtuple!` function to update the `VarInfo`. """ const MHLogDensityFunction{M<:Model,S<:Sampler{<:MH},V<:AbstractVarInfo} = - DynamicPPL.LogDensityFunction{M,V,<:DynamicPPL.SamplingContext{<:S},AD} where {AD} + DynamicPPL.LogDensityFunction{M,V,AD} where {AD} function LogDensityProblems.logdensity(f::MHLogDensityFunction, x::NamedTuple) vi = deepcopy(f.varinfo) diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index 77820edd51..340ab3fb6b 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -206,7 +206,9 @@ function DynamicPPL.initialstep( ) # Perform particle sweep. + @info "Hello!" logevidence = AdvancedPS.sweep!(rng, particles, spl.alg.resampler, spl) + @info "Goodbye!" # Extract the first particle and its weight. particle = particles.vals[1] @@ -222,6 +224,7 @@ end function AbstractMCMC.step( ::AbstractRNG, model::AbstractModel, spl::Sampler{<:SMC}, state::SMCState; kwargs... ) + @info "helloooooo from step" # Extract the index of the current particle. index = state.particleindex diff --git a/src/optimisation/Optimisation.jl b/src/optimisation/Optimisation.jl index 80582019ea..d85ffe6838 100644 --- a/src/optimisation/Optimisation.jl +++ b/src/optimisation/Optimisation.jl @@ -169,6 +169,9 @@ or OptimLogDensity(model; adtype=adtype) ``` +Here, `ctx` must be a context that contains an `OptimizationContext` as its +leaf. 
+ If not specified, `adtype` defaults to `AutoForwardDiff()`. An OptimLogDensity does not, in itself, obey the LogDensityProblems interface. @@ -189,24 +192,40 @@ optim_ld(z) # returns -logp ``` """ struct OptimLogDensity{ - M<:DynamicPPL.Model, - F<:Function, - V<:DynamicPPL.AbstractVarInfo, - C<:DynamicPPL.AbstractContext, - AD<:ADTypes.AbstractADType, + M<:DynamicPPL.Model,F<:Function,V<:DynamicPPL.AbstractVarInfo,AD<:ADTypes.AbstractADType } - ldf::DynamicPPL.LogDensityFunction{M,F,V,C,AD} -end + ldf::DynamicPPL.LogDensityFunction{M,F,V,AD} -function OptimLogDensity( - model::DynamicPPL.Model, - getlogdensity::Function, - vi::DynamicPPL.AbstractVarInfo=DynamicPPL.ldf_default_varinfo(model, getlogdensity); - adtype=AutoForwardDiff(), -) - return OptimLogDensity( - DynamicPPL.LogDensityFunction(model, getlogdensity, vi; adtype=adtype) + # Inner constructors enforce that the model has an OptimizationContext as + # its leaf context. + function OptimLogDensity( + model::DynamicPPL.Model, + getlogdensity::Function, + vi::DynamicPPL.VarInfo, + ctx::OptimizationContext; + adtype::ADTypes.AbstractADType=Turing.DEFAULT_ADTYPE, ) + new_context = DynamicPPL.setleafcontext(model, ctx) + new_model = contextualize(model, new_context) + return new{typeof(new_model),typeof(getlogdensity),typeof(vi),typeof(adtype)}( + DynamicPPL.LogDensityFunction(new_model, getlogdensity, vi; adtype=adtype) + ) + end + function OptimLogDensity( + model::DynamicPPL.Model, + getlogdensity::Function, + ctx::OptimizationContext; + adtype::ADTypes.AbstractADType=Turing.DEFAULT_ADTYPE, + ) + # No varinfo + return OptimLogDensity( + model, + getlogdensity, + DynamicPPL.ldf_default_varinfo(model, getlogdensity), + ctx; + adtype=adtype, + ) + end end """ diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl index b9428af112..d516319684 100644 --- a/src/variational/VariationalInference.jl +++ b/src/variational/VariationalInference.jl @@ -17,12 +17,6 @@ export vi, 
q_locationscale, q_meanfield_gaussian, q_fullrank_gaussian include("deprecated.jl") -function make_logdensity(model::DynamicPPL.Model) - weight = 1.0 - ctx = DynamicPPL.MiniBatchContext(DynamicPPL.DefaultContext(), weight) - return DynamicPPL.LogDensityFunction(model, DynamicPPL.VarInfo(model), ctx) -end - """ q_initialize_scale( [rng::Random.AbstractRNG,] @@ -68,7 +62,7 @@ function q_initialize_scale( num_max_trials::Int=10, reduce_factor::Real=one(eltype(scale)) / 2, ) - prob = make_logdensity(model) + prob = LogDensityFunction(model) ℓπ = Base.Fix1(LogDensityProblems.logdensity, prob) varinfo = DynamicPPL.VarInfo(model) @@ -309,7 +303,7 @@ function vi( ) return AdvancedVI.optimize( rng, - make_logdensity(model), + LogDensityFunction(model), objective, q, n_iterations; From 98d5e7ae4c2a6eeb684967db4894236a6ef77453 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 19 Jul 2025 02:14:59 +0100 Subject: [PATCH 05/49] Fix MH --- src/mcmc/mh.jl | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/mcmc/mh.jl b/src/mcmc/mh.jl index 53b2ceae12..0c657dd65f 100644 --- a/src/mcmc/mh.jl +++ b/src/mcmc/mh.jl @@ -157,6 +157,8 @@ end # Utility functions # ##################### +# TODO(DPPL0.37/penelopeysm): This function should no longer be needed +# once InitContext is merged. """ set_namedtuple!(vi::VarInfo, nt::NamedTuple) @@ -181,17 +183,15 @@ function set_namedtuple!(vi::DynamicPPL.VarInfoOrThreadSafeVarInfo, nt::NamedTup end end -""" - MHLogDensityFunction - -A log density function for the MH sampler. - -This variant uses the `set_namedtuple!` function to update the `VarInfo`. -""" -const MHLogDensityFunction{M<:Model,S<:Sampler{<:MH},V<:AbstractVarInfo} = - DynamicPPL.LogDensityFunction{M,V,AD} where {AD} - -function LogDensityProblems.logdensity(f::MHLogDensityFunction, x::NamedTuple) +# NOTE(penelopeysm): MH does not conform to the usual LogDensityProblems +# interface in that it gets evaluated with a NamedTuple. 
Hence we need this +# method just to deal with MH. +# TODO(DPPL0.37/penelopeysm): Check the extent to which this method is actually +# needed. If it's still needed, replace this with `init!!(f.model, f.varinfo, +# ParamsInit(x))`. Much less hacky than `set_namedtuple!` (hopefully...). +# In general, we should much prefer to either (1) conform to the +# LogDensityProblems interface or (2) use VarNames anyway. +function LogDensityProblems.logdensity(f::LogDensityFunction, x::NamedTuple) vi = deepcopy(f.varinfo) set_namedtuple!(vi, x) vi_new = last(DynamicPPL.evaluate!!(f.model, vi, f.context)) From 73e127bcf4cfec823026cdbea9128b6f52f2ec84 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 19 Jul 2025 02:16:04 +0100 Subject: [PATCH 06/49] Remove spurious logging --- src/mcmc/particle_mcmc.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index 340ab3fb6b..77820edd51 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -206,9 +206,7 @@ function DynamicPPL.initialstep( ) # Perform particle sweep. - @info "Hello!" logevidence = AdvancedPS.sweep!(rng, particles, spl.alg.resampler, spl) - @info "Goodbye!" # Extract the first particle and its weight. particle = particles.vals[1] @@ -224,7 +222,6 @@ end function AbstractMCMC.step( ::AbstractRNG, model::AbstractModel, spl::Sampler{<:SMC}, state::SMCState; kwargs... ) - @info "helloooooo from step" # Extract the index of the current particle. 
index = state.particleindex From ce0c782230aec5bb2b3b34b0eeb790049836f82b Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 19 Jul 2025 02:20:57 +0100 Subject: [PATCH 07/49] Remove residual OptimizationContext --- src/optimisation/Optimisation.jl | 21 +++++---------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/src/optimisation/Optimisation.jl b/src/optimisation/Optimisation.jl index d85ffe6838..3885961adf 100644 --- a/src/optimisation/Optimisation.jl +++ b/src/optimisation/Optimisation.jl @@ -153,7 +153,6 @@ end M<:DynamicPPL.Model, F<:Function, V<:DynamicPPL.AbstractVarInfo, - C<:DynamicPPL.AbstractContext, AD<:ADTypes.AbstractADType, } @@ -169,9 +168,6 @@ or OptimLogDensity(model; adtype=adtype) ``` -Here, `ctx` must be a context that contains an `OptimizationContext` as its -leaf. - If not specified, `adtype` defaults to `AutoForwardDiff()`. An OptimLogDensity does not, in itself, obey the LogDensityProblems interface. @@ -196,33 +192,26 @@ struct OptimLogDensity{ } ldf::DynamicPPL.LogDensityFunction{M,F,V,AD} - # Inner constructors enforce that the model has an OptimizationContext as - # its leaf context. 
function OptimLogDensity( model::DynamicPPL.Model, getlogdensity::Function, - vi::DynamicPPL.VarInfo, - ctx::OptimizationContext; + vi::DynamicPPL.VarInfo; adtype::ADTypes.AbstractADType=Turing.DEFAULT_ADTYPE, ) - new_context = DynamicPPL.setleafcontext(model, ctx) - new_model = contextualize(model, new_context) - return new{typeof(new_model),typeof(getlogdensity),typeof(vi),typeof(adtype)}( - DynamicPPL.LogDensityFunction(new_model, getlogdensity, vi; adtype=adtype) + return new{typeof(model),typeof(getlogdensity),typeof(vi),typeof(adtype)}( + DynamicPPL.LogDensityFunction(model, getlogdensity, vi; adtype=adtype) ) end function OptimLogDensity( model::DynamicPPL.Model, - getlogdensity::Function, - ctx::OptimizationContext; + getlogdensity::Function; adtype::ADTypes.AbstractADType=Turing.DEFAULT_ADTYPE, ) # No varinfo return OptimLogDensity( model, getlogdensity, - DynamicPPL.ldf_default_varinfo(model, getlogdensity), - ctx; + DynamicPPL.ldf_default_varinfo(model, getlogdensity); adtype=adtype, ) end From 4d03c07dd6c21228bc8210c56bf8a7f71190e940 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 19 Jul 2025 02:50:52 +0100 Subject: [PATCH 08/49] Delete files that were removed in previous releases --- src/essential/container.jl | 73 -------------- test/test_utils/ad_utils.jl | 185 ------------------------------------ 2 files changed, 258 deletions(-) delete mode 100644 src/essential/container.jl delete mode 100644 test/test_utils/ad_utils.jl diff --git a/src/essential/container.jl b/src/essential/container.jl deleted file mode 100644 index 5c78e110fa..0000000000 --- a/src/essential/container.jl +++ /dev/null @@ -1,73 +0,0 @@ -struct TracedModel{S<:AbstractSampler,V<:AbstractVarInfo,M<:Model,E<:Tuple} <: - AdvancedPS.AbstractGenericModel - model::M - sampler::S - varinfo::V - evaluator::E -end - -function TracedModel( - model::Model, - sampler::AbstractSampler, - varinfo::AbstractVarInfo, - rng::Random.AbstractRNG, -) - context = SamplingContext(rng, sampler, 
DefaultContext()) - args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo, context) - if kwargs !== nothing && !isempty(kwargs) - error( - "Sampling with `$(sampler.alg)` does not support models with keyword arguments. See issue #2007 for more details.", - ) - end - return TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}( - model, sampler, varinfo, (model.f, args...) - ) -end - -function AdvancedPS.advance!( - trace::AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}}, isref::Bool=false -) - # Make sure we load/reset the rng in the new replaying mechanism - # TODO(mhauru) Stop ignoring the return value. - DynamicPPL.increment_num_produce!!(trace.model.f.varinfo) - isref ? AdvancedPS.load_state!(trace.rng) : AdvancedPS.save_state!(trace.rng) - score = consume(trace.model.ctask) - if score === nothing - return nothing - else - return score + DynamicPPL.getlogp(trace.model.f.varinfo) - end -end - -function AdvancedPS.delete_retained!(trace::TracedModel) - DynamicPPL.set_retained_vns_del!(trace.varinfo) - return trace -end - -function AdvancedPS.reset_model(trace::TracedModel) - new_vi = DynamicPPL.reset_num_produce!!(trace.varinfo) - trace = TracedModel(trace.model, trace.sampler, new_vi, trace.evaluator) - return trace -end - -function AdvancedPS.reset_logprob!(trace::TracedModel) - # TODO(mhauru) Stop ignoring the return value. - DynamicPPL.resetlogp!!(trace.model.varinfo) - return trace -end - -function AdvancedPS.update_rng!( - trace::AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}} -) - # Extract the `args`. - args = trace.model.ctask.args - # From `args`, extract the `SamplingContext`, which contains the RNG. - sampling_context = args[3] - rng = sampling_context.rng - trace.rng = rng - return trace -end - -function Libtask.TapedTask(model::TracedModel, ::Random.AbstractRNG, args...; kwargs...) # RNG ? - return Libtask.TapedTask(model.evaluator[1], model.evaluator[2:end]...; kwargs...) 
-end diff --git a/test/test_utils/ad_utils.jl b/test/test_utils/ad_utils.jl deleted file mode 100644 index f1b5b3d145..0000000000 --- a/test/test_utils/ad_utils.jl +++ /dev/null @@ -1,185 +0,0 @@ -module ADUtils - -using ForwardDiff: ForwardDiff -using Pkg: Pkg -using Random: Random -using ReverseDiff: ReverseDiff -using Mooncake: Mooncake -using Test: Test -using Turing: Turing -using Turing: DynamicPPL - -export ADTypeCheckContext, adbackends - -# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # -# Stuff for checking that the right AD backend is being used. - -"""Element types that are always valid for a VarInfo regardless of ADType.""" -const always_valid_eltypes = (AbstractFloat, AbstractIrrational, Integer, Rational) - -"""A dictionary mapping ADTypes to the element types they use.""" -const eltypes_by_adtype = Dict( - Turing.AutoForwardDiff => (ForwardDiff.Dual,), - Turing.AutoReverseDiff => ( - ReverseDiff.TrackedArray, - ReverseDiff.TrackedMatrix, - ReverseDiff.TrackedReal, - ReverseDiff.TrackedStyle, - ReverseDiff.TrackedType, - ReverseDiff.TrackedVecOrMat, - ReverseDiff.TrackedVector, - ), - Turing.AutoMooncake => (Mooncake.CoDual,), -) - -""" - AbstractWrongADBackendError - -An abstract error thrown when we seem to be using a different AD backend than expected. -""" -abstract type AbstractWrongADBackendError <: Exception end - -""" - WrongADBackendError - -An error thrown when we seem to be using a different AD backend than expected. -""" -struct WrongADBackendError <: AbstractWrongADBackendError - actual_adtype::Type - expected_adtype::Type -end - -function Base.showerror(io::IO, e::WrongADBackendError) - return print( - io, "Expected to use $(e.expected_adtype), but using $(e.actual_adtype) instead." - ) -end - -""" - IncompatibleADTypeError - -An error thrown when an element type is encountered that is unexpected for the given ADType. 
-""" -struct IncompatibleADTypeError <: AbstractWrongADBackendError - valtype::Type - adtype::Type -end - -function Base.showerror(io::IO, e::IncompatibleADTypeError) - return print( - io, - "Incompatible ADType: Did not expect element of type $(e.valtype) with $(e.adtype)", - ) -end - -""" - ADTypeCheckContext{ADType,ChildContext} - -A context for checking that the expected ADType is being used. - -Evaluating a model with this context will check that the types of values in a `VarInfo` are -compatible with the ADType of the context. If the check fails, an `IncompatibleADTypeError` -is thrown. - -For instance, evaluating a model with -`ADTypeCheckContext(AutoForwardDiff(), child_context)` -would throw an error if within the model a type associated with e.g. ReverseDiff was -encountered. - -""" -struct ADTypeCheckContext{ADType,ChildContext<:DynamicPPL.AbstractContext} <: - DynamicPPL.AbstractContext - child::ChildContext - - function ADTypeCheckContext(adbackend, child) - adtype = adbackend isa Type ? adbackend : typeof(adbackend) - if !any(adtype <: k for k in keys(eltypes_by_adtype)) - throw(ArgumentError("Unsupported ADType: $adtype")) - end - return new{adtype,typeof(child)}(child) - end -end - -adtype(_::ADTypeCheckContext{ADType}) where {ADType} = ADType - -DynamicPPL.NodeTrait(::ADTypeCheckContext) = DynamicPPL.IsParent() -DynamicPPL.childcontext(c::ADTypeCheckContext) = c.child -function DynamicPPL.setchildcontext(c::ADTypeCheckContext, child) - return ADTypeCheckContext(adtype(c), child) -end - -""" - valid_eltypes(context::ADTypeCheckContext) - -Return the element types that are valid for the ADType of `context` as a tuple. -""" -function valid_eltypes(context::ADTypeCheckContext) - context_at = adtype(context) - for at in keys(eltypes_by_adtype) - if context_at <: at - return (eltypes_by_adtype[at]..., always_valid_eltypes...) - end - end - # This should never be reached due to the check in the inner constructor. 
- throw(ArgumentError("Unsupported ADType: $(adtype(context))")) -end - -""" - check_adtype(context::ADTypeCheckContext, vi::DynamicPPL.VarInfo) - -Check that the element types in `vi` are compatible with the ADType of `context`. - -Throw an `IncompatibleADTypeError` if an incompatible element type is encountered. -""" -function check_adtype(context::ADTypeCheckContext, vi::DynamicPPL.AbstractVarInfo) - valids = valid_eltypes(context) - for val in vi[:] - valtype = typeof(val) - if !any(valtype .<: valids) - throw(IncompatibleADTypeError(valtype, adtype(context))) - end - end - return nothing -end - -# A bunch of tilde_assume/tilde_observe methods that just call the same method on the child -# context, and then call check_adtype on the result before returning the results from the -# child context. - -function DynamicPPL.tilde_assume(context::ADTypeCheckContext, right, vn, vi) - value, vi = DynamicPPL.tilde_assume(DynamicPPL.childcontext(context), right, vn, vi) - check_adtype(context, vi) - return value, vi -end - -function DynamicPPL.tilde_assume( - rng::Random.AbstractRNG, context::ADTypeCheckContext, sampler, right, vn, vi -) - value, vi = DynamicPPL.tilde_assume( - rng, DynamicPPL.childcontext(context), sampler, right, vn, vi - ) - check_adtype(context, vi) - return value, vi -end - -function DynamicPPL.tilde_observe!!(context::ADTypeCheckContext, right, left, vn, vi) - left, vi = DynamicPPL.tilde_observe!!( - DynamicPPL.childcontext(context), right, left, vn, vi - ) - check_adtype(context, vi) - return left, vi -end - -# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # -# List of AD backends to test. - -""" -All the ADTypes on which we want to run the tests. 
-""" -adbackends = [ - Turing.AutoForwardDiff(), - Turing.AutoReverseDiff(; compile=false), - Turing.AutoMooncake(; config=nothing), -] - -end From 06fec2d0cb11731046e4835b998e982a5b2d9931 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 19 Jul 2025 23:45:56 +0100 Subject: [PATCH 09/49] Fix typo --- src/mcmc/ess.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcmc/ess.jl b/src/mcmc/ess.jl index 0f5a7f5cc9..849a3ee305 100644 --- a/src/mcmc/ess.jl +++ b/src/mcmc/ess.jl @@ -63,7 +63,7 @@ function AbstractMCMC.step( return Transition(model, vi), vi end -f + # Prior distribution of considered random variable struct ESSPrior{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo,T} model::M From 0af87259c59c82b95c4c8bb8b40f53c9165ee403 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 19 Jul 2025 23:52:02 +0100 Subject: [PATCH 10/49] Simplify ESS --- src/mcmc/ess.jl | 43 +++++++++++++++++++------------------------ 1 file changed, 19 insertions(+), 24 deletions(-) diff --git a/src/mcmc/ess.jl b/src/mcmc/ess.jl index 849a3ee305..f669bdf7ea 100644 --- a/src/mcmc/ess.jl +++ b/src/mcmc/ess.jl @@ -24,7 +24,7 @@ struct ESS <: InferenceAlgorithm end # always accept in the first step function DynamicPPL.initialstep( - rng::AbstractRNG, model::Model, spl::Sampler{<:ESS}, vi::AbstractVarInfo; kwargs... + rng::AbstractRNG, model::Model, ::Sampler{<:ESS}, vi::AbstractVarInfo; kwargs... ) for vn in keys(vi) dist = getdist(vi, vn) @@ -35,7 +35,7 @@ function DynamicPPL.initialstep( end function AbstractMCMC.step( - rng::AbstractRNG, model::Model, spl::Sampler{<:ESS}, vi::AbstractVarInfo; kwargs... + rng::AbstractRNG, model::Model, ::Sampler{<:ESS}, vi::AbstractVarInfo; kwargs... 
) # obtain previous sample f = vi[:] @@ -47,12 +47,7 @@ function AbstractMCMC.step( # compute next state sample, state = AbstractMCMC.step( rng, - EllipticalSliceSampling.ESSModel( - ESSPrior(model, spl, vi), - ESSLikelihood( - DynamicPPL.LogDensityFunction(model, DynamicPPL.getloglikelihood, vi) - ), - ), + EllipticalSliceSampling.ESSModel(ESSPrior(model, vi), ESSLikelihood(model, vi)), EllipticalSliceSampling.ESS(), oldstate, ) @@ -65,15 +60,12 @@ function AbstractMCMC.step( end # Prior distribution of considered random variable -struct ESSPrior{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo,T} +struct ESSPrior{M<:Model,V<:AbstractVarInfo,T} model::M - sampler::S varinfo::V μ::T - function ESSPrior{M,S,V}( - model::M, sampler::S, varinfo::V - ) where {M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo} + function ESSPrior(model::Model, varinfo::AbstractVarInfo) vns = keys(varinfo) μ = mapreduce(vcat, vns) do vn dist = getdist(varinfo, vn) @@ -81,20 +73,15 @@ struct ESSPrior{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo,T} error("[ESS] only supports Gaussian prior distributions") DynamicPPL.tovec(mean(dist)) end - return new{M,S,V,typeof(μ)}(model, sampler, varinfo, μ) + return new{typeof(model),typeof(varinfo),typeof(μ)}(model, varinfo, μ) end end -function ESSPrior(model::Model, sampler::Sampler{<:ESS}, varinfo::AbstractVarInfo) - return ESSPrior{typeof(model),typeof(sampler),typeof(varinfo)}(model, sampler, varinfo) -end - # Ensure that the prior is a Gaussian distribution (checked in the constructor) EllipticalSliceSampling.isgaussian(::Type{<:ESSPrior}) = true # Only define out-of-place sampling function Base.rand(rng::Random.AbstractRNG, p::ESSPrior) - sampler = p.sampler varinfo = p.varinfo # TODO: Surely there's a better way of doing this now that we have `SamplingContext`? 
# TODO(DPPL0.37/penelopeysm): This can be replaced with `init!!(p.model, @@ -105,16 +92,24 @@ function Base.rand(rng::Random.AbstractRNG, p::ESSPrior) for vn in vns set_flag!(varinfo, vn, "del") end - p.model(rng, varinfo, sampler) + p.model(rng, varinfo) return varinfo[:] end # Mean of prior distribution Distributions.mean(p::ESSPrior) = p.μ -# Evaluate log-likelihood of proposals -struct ESSLogLikelihood{M<:Model,V<:AbstractVarInfo,AD<:ADTypes.AbstractADType} - ldf::DynamicPPL.LogDensityFunction{M,V,AD} +# Evaluate log-likelihood of proposals. We need this struct because +# EllipticalSliceSampling.jl expects a callable struct / a function as its +# likelihood. +struct ESSLikelihood{M<:Model,V<:AbstractVarInfo} + ldf::DynamicPPL.LogDensityFunction{M,V} + + # Force usage of `getloglikelihood` in inner constructor + function ESSLogLikelihood(model::Model, varinfo::AbstractVarInfo) + ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.getloglikelihood, varinfo) + return new{typeof(model),typeof(varinfo)}(ldf) + end end -(ℓ::ESSLogLikelihood)(f::AbstractVector) = LogDensityProblems.logdensity(ℓ.ldf, f) +(ℓ::ESSLikelihood)(f::AbstractVector) = LogDensityProblems.logdensity(ℓ.ldf, f) From 3d44c123640dde7fe29e0c8ed29c3ca298c53693 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 19 Jul 2025 23:54:04 +0100 Subject: [PATCH 11/49] Fix LDF --- ext/TuringDynamicHMCExt.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ext/TuringDynamicHMCExt.jl b/ext/TuringDynamicHMCExt.jl index 8cb917eb51..8a34d26498 100644 --- a/ext/TuringDynamicHMCExt.jl +++ b/ext/TuringDynamicHMCExt.jl @@ -62,7 +62,9 @@ function DynamicPPL.initialstep( end # Define log-density function. - ℓ = DynamicPPL.LogDensityFunction(model, vi; adtype=spl.alg.adtype) + ℓ = DynamicPPL.LogDensityFunction( + model, DynamicPPL.getlogjoint, vi; adtype=spl.alg.adtype + ) # Perform initial step. 
results = DynamicHMC.mcmc_keep_warmup( From a1837b5f07a7f357215e8f99b91eab5430de3b16 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 20 Jul 2025 00:31:02 +0100 Subject: [PATCH 12/49] Fix Prior(), fix a couple more imports --- src/mcmc/gibbs.jl | 5 +++-- src/mcmc/particle_mcmc.jl | 8 +++++--- src/mcmc/prior.jl | 11 +++++++---- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 0498ec2a8c..84021432d9 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -549,7 +549,8 @@ function setparams_varinfo!!( # update its logprob. To do this, we have to call evaluate!! with the sampler, rather # than just a context, because ESS is peculiar in how it uses LikelihoodContext for # some variables and DefaultContext for others. - return last(DynamicPPL.evaluate!!(model, params, SamplingContext(sampler))) + # TODO(penelopeysm): Is this still needed? + return last(DynamicPPL.evaluate!!(model, params, DynamicPPL.SamplingContext(sampler))) end function setparams_varinfo!!( @@ -559,7 +560,7 @@ function setparams_varinfo!!( params::AbstractVarInfo, ) logdensity = DynamicPPL.LogDensityFunction( - model, state.ldf.varinfo; adtype=sampler.alg.adtype + model, DynamicPPL.getlogjoint, state.ldf.varinfo; adtype=sampler.alg.adtype ) new_inner_state = setparams_varinfo!!( AbstractMCMC.LogDensityModel(logdensity), sampler, state.state, params diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index 77820edd51..b44753ee6c 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -18,7 +18,7 @@ function TracedModel( varinfo::AbstractVarInfo, rng::Random.AbstractRNG, ) - context = SamplingContext(rng, sampler, DefaultContext()) + context = DynamicPPL.SamplingContext(rng, sampler, DefaultContext()) args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo, context) if kwargs !== nothing && !isempty(kwargs) error( @@ -395,7 +395,7 @@ function AbstractMCMC.step( end function 
DynamicPPL.use_threadsafe_eval( - ::SamplingContext{<:Sampler{<:Union{PG,SMC}}}, ::AbstractVarInfo + ::DynamicPPL.SamplingContext{<:Sampler{<:Union{PG,SMC}}}, ::AbstractVarInfo ) return false end @@ -457,7 +457,9 @@ end # end function DynamicPPL.acclogp!!( - context::SamplingContext{<:Sampler{<:Union{PG,SMC}}}, varinfo::AbstractVarInfo, logp + context::DynamicPPL.SamplingContext{<:Sampler{<:Union{PG,SMC}}}, + varinfo::AbstractVarInfo, + logp, ) varinfo_trace = trace_local_varinfo_maybe(varinfo) return DynamicPPL.acclogp!!(DynamicPPL.childcontext(context), varinfo_trace, logp) diff --git a/src/mcmc/prior.jl b/src/mcmc/prior.jl index c7a5cc5737..1b301508a6 100644 --- a/src/mcmc/prior.jl +++ b/src/mcmc/prior.jl @@ -12,14 +12,17 @@ function AbstractMCMC.step( state=nothing; kwargs..., ) + # TODO(DPPL0.37/penelopeysm): replace with init!! instead vi = last( DynamicPPL.evaluate!!( - model, - VarInfo(), - SamplingContext(rng, DynamicPPL.SampleFromPrior(), DynamicPPL.PriorContext()), + model, VarInfo(), DynamicPPL.SamplingContext(rng, DynamicPPL.SampleFromPrior()) ), ) - return vi, nothing + # Need to manually construct the Transition here because we only + # want to use the prior probability. + xs = Turing.Inference.getparams(model, vi) + lp = DynamicPPL.getlogprior(vi) + return Transition(xs, lp, nothing) end DynamicPPL.default_chain_type(sampler::Prior) = MCMCChains.Chains From 17efb8c5136d7855b3d8a72ac733eae83f94baf3 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 20 Jul 2025 00:46:18 +0100 Subject: [PATCH 13/49] fixes --- src/mcmc/abstractmcmc.jl | 4 +++- test/ad.jl | 8 ++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/mcmc/abstractmcmc.jl b/src/mcmc/abstractmcmc.jl index fd4d441bdd..4d55d5c698 100644 --- a/src/mcmc/abstractmcmc.jl +++ b/src/mcmc/abstractmcmc.jl @@ -1,7 +1,9 @@ # TODO: Implement additional checks for certain samplers, e.g. # HMC not supporting discrete parameters. 
function _check_model(model::DynamicPPL.Model) - return DynamicPPL.check_model(model; error_on_failure=true) + # TODO(DPPL0.37/penelopeysm): use InitContext + spl_model = DynamicPPL.contextualize(model, DynamicPPL.SamplingContext(model.context)) + return DynamicPPL.check_model(spl_model, VarInfo(); error_on_failure=true) end function _check_model(model::DynamicPPL.Model, alg::InferenceAlgorithm) return _check_model(model) diff --git a/test/ad.jl b/test/ad.jl index 2f645fab5d..24e1270eab 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -172,14 +172,14 @@ function DynamicPPL.tilde_assume( return value, logp, vi end -function DynamicPPL.tilde_observe(context::ADTypeCheckContext, right, left, vi) - logp, vi = DynamicPPL.tilde_observe(DynamicPPL.childcontext(context), right, left, vi) +function DynamicPPL.tilde_observe!!(context::ADTypeCheckContext, right, left, vi) + logp, vi = DynamicPPL.tilde_observe!!(DynamicPPL.childcontext(context), right, left, vi) check_adtype(context, vi) return logp, vi end -function DynamicPPL.tilde_observe(context::ADTypeCheckContext, sampler, right, left, vi) - logp, vi = DynamicPPL.tilde_observe( +function DynamicPPL.tilde_observe!!(context::ADTypeCheckContext, sampler, right, left, vi) + logp, vi = DynamicPPL.tilde_observe!!( DynamicPPL.childcontext(context), sampler, right, left, vi ) check_adtype(context, vi) From d62ad820c58dc4d24fdb73530ce556f429a6e073 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 20 Jul 2025 00:50:00 +0100 Subject: [PATCH 14/49] actually fix prior --- src/mcmc/prior.jl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/mcmc/prior.jl b/src/mcmc/prior.jl index 1b301508a6..07d541ee33 100644 --- a/src/mcmc/prior.jl +++ b/src/mcmc/prior.jl @@ -12,12 +12,11 @@ function AbstractMCMC.step( state=nothing; kwargs..., ) - # TODO(DPPL0.37/penelopeysm): replace with init!! 
instead - vi = last( - DynamicPPL.evaluate!!( - model, VarInfo(), DynamicPPL.SamplingContext(rng, DynamicPPL.SampleFromPrior()) - ), + # TODO(DPPL0.37/penelopeysm): replace with init!! + sampling_model = DynamicPPL.contextualize( + model, DynamicPPL.SamplingContext(rng, DynamicPPL.SampleFromPrior(), model.context) ) + _, vi = DynamicPPL.evaluate!!(sampling_model, VarInfo()) # Need to manually construct the Transition here because we only # want to use the prior probability. xs = Turing.Inference.getparams(model, vi) From aac93f17dd940c93f3b61c45ab0a6a4c032c5bc6 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 20 Jul 2025 00:53:54 +0100 Subject: [PATCH 15/49] Remove extra return value from tilde_assume --- src/mcmc/gibbs.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 84021432d9..c2dc2a6df3 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -212,12 +212,12 @@ function DynamicPPL.tilde_assume( return if is_target_varname(context, vn) DynamicPPL.tilde_assume(rng, child_context, sampler, right, vn, vi) elseif has_conditioned_gibbs(context, vn) - value, lp, _ = DynamicPPL.tilde_assume( + value, _ = DynamicPPL.tilde_assume( child_context, right, vn, get_global_varinfo(context) ) - value, lp, vi + value, vi else - value, lp, new_global_vi = DynamicPPL.tilde_assume( + value, new_global_vi = DynamicPPL.tilde_assume( rng, child_context, DynamicPPL.SampleFromPrior(), @@ -226,7 +226,7 @@ function DynamicPPL.tilde_assume( get_global_varinfo(context), ) set_global_varinfo!(context, new_global_vi) - value, lp, vi + value, vi end end From e903d1c02b37572e206edd0fae772f2ad167222f Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 20 Jul 2025 00:55:13 +0100 Subject: [PATCH 16/49] fix ldf --- src/mcmc/hmc.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl index 9d19a03c54..1bfa774916 100644 --- a/src/mcmc/hmc.jl +++ b/src/mcmc/hmc.jl 
@@ -190,7 +190,9 @@ function DynamicPPL.initialstep( # Create a Hamiltonian. metricT = getmetricT(spl.alg) metric = metricT(length(theta)) - ldf = DynamicPPL.LogDensityFunction(model, vi; adtype=spl.alg.adtype) + ldf = DynamicPPL.LogDensityFunction( + model, DynamicPPL.getlogjoint, vi; adtype=spl.alg.adtype + ) lp_func = Base.Fix1(LogDensityProblems.logdensity, ldf) lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ldf) hamiltonian = AHMC.Hamiltonian(metric, lp_func, lp_grad_func) @@ -296,7 +298,9 @@ end function get_hamiltonian(model, spl, vi, state, n) metric = gen_metric(n, spl, state) - ldf = DynamicPPL.LogDensityFunction(model, vi; adtype=spl.alg.adtype) + ldf = DynamicPPL.LogDensityFunction( + model, DynamicPPL.getlogjoint, vi; adtype=spl.alg.adtype + ) lp_func = Base.Fix1(LogDensityProblems.logdensity, ldf) lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ldf) return AHMC.Hamiltonian(metric, lp_func, lp_grad_func) From fd5a815961a0993399ed903ae4e7f3c1b55cea1c Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 20 Jul 2025 00:56:26 +0100 Subject: [PATCH 17/49] actually fix prior --- src/mcmc/prior.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcmc/prior.jl b/src/mcmc/prior.jl index 07d541ee33..eadeaceb38 100644 --- a/src/mcmc/prior.jl +++ b/src/mcmc/prior.jl @@ -21,7 +21,7 @@ function AbstractMCMC.step( # want to use the prior probability. 
xs = Turing.Inference.getparams(model, vi) lp = DynamicPPL.getlogprior(vi) - return Transition(xs, lp, nothing) + return Transition(xs, lp, nothing), nothing end DynamicPPL.default_chain_type(sampler::Prior) = MCMCChains.Chains From 10a130acdaba52fd752f1740623b2aa633b2699f Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 20 Jul 2025 01:04:37 +0100 Subject: [PATCH 18/49] fix HMC log-density --- src/mcmc/hmc.jl | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl index 1bfa774916..1de0347b71 100644 --- a/src/mcmc/hmc.jl +++ b/src/mcmc/hmc.jl @@ -206,8 +206,8 @@ function DynamicPPL.initialstep( end theta = vi[:] - # Cache current log density. - log_density_old = getloglikelihood(vi) + # Cache current log density. We will reuse this if the transition is rejected. + logp_old = DynamicPPL.getlogp(vi) # Find good eps if not provided one if iszero(spl.alg.ϵ) @@ -232,15 +232,21 @@ function DynamicPPL.initialstep( ) end - # Update `vi` based on acceptance + # Update VarInfo based on acceptance if t.stat.is_accept vi = DynamicPPL.unflatten(vi, t.z.θ) - # TODO(mhauru) Is setloglikelihood! the right thing here? - vi = setloglikelihood!!(vi, t.stat.log_density) + # Re-evaluate to calculate log probability density. + # TODO(penelopeysm): This seems a little bit wasteful. The need for + # this stems from the fact that the HMC sampler doesn't keep track of + # prior and likelihood separately but rather a single log-joint, for + # which we have no way to decompose this back into prior and + # likelihood. I don't immediately see how to solve this without + # re-evaluating the model. + vi = DynamicPPL.evaluate!!(model, vi) else + # Reset VarInfo back to its original state. vi = DynamicPPL.unflatten(vi, theta) - # TODO(mhauru) Is setloglikelihood! the right thing here? 
- vi = setloglikelihood!!(vi, log_density_old) + vi = DynamicPPL.setlogp!!(vi, logp_old) end transition = Transition(model, vi, t) From c63072319a0ad0d8443ecda8d9700e3b5462bd9a Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 20 Jul 2025 01:05:40 +0100 Subject: [PATCH 19/49] fix ldf --- src/mcmc/external_sampler.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/mcmc/external_sampler.jl b/src/mcmc/external_sampler.jl index 9524b2766d..4f4203f5a9 100644 --- a/src/mcmc/external_sampler.jl +++ b/src/mcmc/external_sampler.jl @@ -156,7 +156,9 @@ function AbstractMCMC.step( end # Construct LogDensityFunction - f = DynamicPPL.LogDensityFunction(model, varinfo; adtype=alg.adtype) + f = DynamicPPL.LogDensityFunction( + model, DynamicPPL.getlogjoint, varinfo; adtype=alg.adtype + ) # Then just call `AbstractMCMC.step` with the right arguments. if initial_state === nothing From 9cbb2e9d1ed1997de4ca6fc569564317c06e15af Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 20 Jul 2025 01:07:02 +0100 Subject: [PATCH 20/49] fix make_evaluate_... --- src/mcmc/particle_mcmc.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index b44753ee6c..ce6ea9184f 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -18,15 +18,16 @@ function TracedModel( varinfo::AbstractVarInfo, rng::Random.AbstractRNG, ) - context = DynamicPPL.SamplingContext(rng, sampler, DefaultContext()) - args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo, context) + spl_context = DynamicPPL.SamplingContext(rng, sampler, model.context) + spl_model = DynamicPPL.contextualize(model, spl_context) + args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(spl_model, varinfo) if kwargs !== nothing && !isempty(kwargs) error( "Sampling with `$(sampler.alg)` does not support models with keyword arguments. 
See issue #2007 for more details.", ) end return TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}( - model, sampler, varinfo, (model.f, args...) + spl_model, sampler, varinfo, (spl_model.f, args...) ) end From 335cd2a8f4d400755aaaa319951800aa0324f9f5 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 20 Jul 2025 01:14:48 +0100 Subject: [PATCH 21/49] more fixes for evaluate!! --- src/mcmc/ess.jl | 2 +- src/mcmc/gibbs.jl | 13 ++++--------- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/src/mcmc/ess.jl b/src/mcmc/ess.jl index f669bdf7ea..86b92b28ee 100644 --- a/src/mcmc/ess.jl +++ b/src/mcmc/ess.jl @@ -42,7 +42,7 @@ function AbstractMCMC.step( # define previous sampler state # (do not use cache to avoid in-place sampling from prior) - oldstate = EllipticalSliceSampling.ESSState(f, getlogp(vi), nothing) + oldstate = EllipticalSliceSampling.ESSState(f, DynamicPPL.getloglikelihood(vi), nothing) # compute next state sample, state = AbstractMCMC.step( diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index c2dc2a6df3..81281389ec 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -349,7 +349,7 @@ function initial_varinfo(rng, model, spl, initial_params) # This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588 # and https://github.com/TuringLang/Turing.jl/issues/1563 # to avoid that existing variables are resampled - vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.DefaultContext())) + vi = last(DynamicPPL.evaluate!!(model, vi)) end return vi end @@ -534,9 +534,7 @@ function setparams_varinfo!!( ) # The state is already a VarInfo, so we can just return `params`, but first we need to # update its logprob. - # NOTE: Using `leafcontext(model.context)` here is a no-op, as it will be concatenated - # with `model.context` before hitting `model.f`. 
- return last(DynamicPPL.evaluate!!(model, params, DynamicPPL.leafcontext(model.context))) + return last(DynamicPPL.evaluate!!(model, params)) end function setparams_varinfo!!( @@ -546,11 +544,8 @@ function setparams_varinfo!!( params::AbstractVarInfo, ) # The state is already a VarInfo, so we can just return `params`, but first we need to - # update its logprob. To do this, we have to call evaluate!! with the sampler, rather - # than just a context, because ESS is peculiar in how it uses LikelihoodContext for - # some variables and DefaultContext for others. - # TODO(penelopeysm): Is this still needed? - return last(DynamicPPL.evaluate!!(model, params, DynamicPPL.SamplingContext(sampler))) + # update its logprob. + return last(DynamicPPL.evaluate!!(model, params)) end function setparams_varinfo!!( From c912fb94e143166eb341cd4d4cc0ad7d8d339e96 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 20 Jul 2025 01:18:35 +0100 Subject: [PATCH 22/49] fix hmc --- src/mcmc/hmc.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl index 1de0347b71..5b6977c611 100644 --- a/src/mcmc/hmc.jl +++ b/src/mcmc/hmc.jl @@ -242,7 +242,7 @@ function DynamicPPL.initialstep( # which we have no way to decompose this back into prior and # likelihood. I don't immediately see how to solve this without # re-evaluating the model. - vi = DynamicPPL.evaluate!!(model, vi) + _, vi = DynamicPPL.evaluate!!(model, vi) else # Reset VarInfo back to its original state. 
vi = DynamicPPL.unflatten(vi, theta) From 195f81929c20e720a15dea8cb834062556b8254f Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 20 Jul 2025 01:20:22 +0100 Subject: [PATCH 23/49] fix run_ad --- test/ad.jl | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/test/ad.jl b/test/ad.jl index 24e1270eab..91e09c276f 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -256,8 +256,10 @@ end @testset "model=$(model.f)" for model in DEMO_MODELS rng = StableRNG(123) - ctx = DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(alg)) - @test run_ad(model, adtype; context=ctx, test=true, benchmark=false) isa Any + spl_model = DynamicPPL.contextualize( + model, DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(alg)) + ) + @test run_ad(spl_model, adtype; test=true, benchmark=false) isa Any end end end @@ -283,8 +285,10 @@ end model, varnames, deepcopy(global_vi) ) rng = StableRNG(123) - ctx = DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(HMC(0.1, 10))) - @test run_ad(model, adtype; context=ctx, test=true, benchmark=false) isa Any + spl_model = DynamicPPL.contextualize( + model, DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(HMC(0.1, 10))) + ) + @test run_ad(spl_model, adtype; test=true, benchmark=false) isa Any end end end From cd52e9fefefb0a8e9abec68f0ae804e69c9ba328 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 20 Jul 2025 01:25:53 +0100 Subject: [PATCH 24/49] even more fixes (oh goodness when will this end) --- src/mcmc/emcee.jl | 8 ++++++-- src/mcmc/hmc.jl | 17 +++++++++-------- src/mcmc/mh.jl | 6 +++--- 3 files changed, 18 insertions(+), 13 deletions(-) diff --git a/src/mcmc/emcee.jl b/src/mcmc/emcee.jl index dfd1fc0d30..6f80dea114 100644 --- a/src/mcmc/emcee.jl +++ b/src/mcmc/emcee.jl @@ -53,10 +53,14 @@ function AbstractMCMC.step( length(initial_params) == n || throw(ArgumentError("initial parameters have to be specified for each walker")) vis = map(vis, initial_params) do vi, init + # TODO(DPPL0.37/penelopeysm) This whole thing 
can be replaced with init!! vi = DynamicPPL.initialize_parameters!!(vi, init, model) # Update log joint probability. - last(DynamicPPL.evaluate!!(model, rng, vi, SampleFromPrior())) + spl_model = DynamicPPL.contextualize( + model, DynamicPPL.SamplingContext(rng, SampleFromPrior(), model.context) + ) + last(DynamicPPL.evaluate!!(spl_model, vi)) end end @@ -68,7 +72,7 @@ function AbstractMCMC.step( vis[1], map(vis) do vi vi = DynamicPPL.link!!(vi, model) - AMH.Transition(vi[:], getlogp(vi), false) + AMH.Transition(vi[:], DynamicPPL.getlogjoint(vi), false) end, ) diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl index 5b6977c611..e19f023437 100644 --- a/src/mcmc/hmc.jl +++ b/src/mcmc/hmc.jl @@ -236,12 +236,12 @@ function DynamicPPL.initialstep( if t.stat.is_accept vi = DynamicPPL.unflatten(vi, t.z.θ) # Re-evaluate to calculate log probability density. - # TODO(penelopeysm): This seems a little bit wasteful. The need for - # this stems from the fact that the HMC sampler doesn't keep track of - # prior and likelihood separately but rather a single log-joint, for - # which we have no way to decompose this back into prior and - # likelihood. I don't immediately see how to solve this without - # re-evaluating the model. + # TODO(penelopeysm): This seems a little bit wasteful. Unfortunately, + # even though `t.stat.log_density` contains some kind of logp, this + # doesn't track prior and likelihood separately but rather a single + # log-joint (and in linked space), so which we have no way to decompose + # this back into prior and likelihood. I don't immediately see how to + # solve this without re-evaluating the model. _, vi = DynamicPPL.evaluate!!(model, vi) else # Reset VarInfo back to its original state. @@ -291,8 +291,9 @@ function AbstractMCMC.step( vi = state.vi if t.stat.is_accept vi = DynamicPPL.unflatten(vi, t.z.θ) - # TODO(mhauru) Is setloglikelihood! the right thing here? 
- vi = setloglikelihood!!(vi, t.stat.log_density) + # Re-evaluate to calculate log probability density. + # TODO(penelopeysm): This seems a little bit wasteful. See note above. + _, vi = DynamicPPL.evaluate!!(model, vi) end # Compute next transition and state. diff --git a/src/mcmc/mh.jl b/src/mcmc/mh.jl index 0c657dd65f..019af79391 100644 --- a/src/mcmc/mh.jl +++ b/src/mcmc/mh.jl @@ -195,7 +195,7 @@ function LogDensityProblems.logdensity(f::LogDensityFunction, x::NamedTuple) vi = deepcopy(f.varinfo) set_namedtuple!(vi, x) vi_new = last(DynamicPPL.evaluate!!(f.model, vi, f.context)) - lj = getlogp(vi_new) + lj = f.getlogdensity(vi_new) return lj end @@ -304,7 +304,7 @@ function propose!!( # Create a sampler and the previous transition. mh_sampler = AMH.MetropolisHastings(dt) - prev_trans = AMH.Transition(vt, getlogp(vi), false) + prev_trans = AMH.Transition(vt, DynamicPPL.getlogjoint(vi), false) # Make a new transition. densitymodel = AMH.DensityModel( @@ -339,7 +339,7 @@ function propose!!( # Create a sampler and the previous transition. mh_sampler = AMH.MetropolisHastings(spl.alg.proposals) - prev_trans = AMH.Transition(vals, getlogp(vi), false) + prev_trans = AMH.Transition(vals, DynamicPPL.getlogjoint(vi), false) # Make a new transition. 
densitymodel = AMH.DensityModel( From 9360f18b54e6d21531c827bfaaab9fb9a65afa39 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 20 Jul 2025 01:27:50 +0100 Subject: [PATCH 25/49] more fixes --- src/mcmc/particle_mcmc.jl | 8 ++++---- src/mcmc/sghmc.jl | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index ce6ea9184f..88f2b2204e 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -35,13 +35,13 @@ function AdvancedPS.advance!( trace::AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}}, isref::Bool=false ) # Make sure we load/reset the rng in the new replaying mechanism - DynamicPPL.increment_num_produce!(trace.model.f.varinfo) + trace.model.f.varinfo = DynamicPPL.increment_num_produce!!(trace.model.f.varinfo) isref ? AdvancedPS.load_state!(trace.rng) : AdvancedPS.save_state!(trace.rng) score = consume(trace.model.ctask) if score === nothing return nothing else - return score + DynamicPPL.getlogp(trace.model.f.varinfo) + return score + DynamicPPL.getlogjoint(trace.model.f.varinfo) end end @@ -128,7 +128,7 @@ function SMCTransition(model::DynamicPPL.Model, vi::AbstractVarInfo, weight) # This is pretty useless since we reset the log probability continuously in the # particle sweep. - lp = getlogp(vi) + lp = DynamicPPL.getlogjoint(vi) return SMCTransition(theta, lp, weight) end @@ -307,7 +307,7 @@ function PGTransition(model::DynamicPPL.Model, vi::AbstractVarInfo, logevidence) # This is pretty useless since we reset the log probability continuously in the # particle sweep. 
- lp = getlogp(vi) + lp = DynamicPPL.getlogjoint(vi) return PGTransition(theta, lp, logevidence) end diff --git a/src/mcmc/sghmc.jl b/src/mcmc/sghmc.jl index 0c322244eb..2d669cd908 100644 --- a/src/mcmc/sghmc.jl +++ b/src/mcmc/sghmc.jl @@ -200,7 +200,7 @@ end function SGLDTransition(model::DynamicPPL.Model, vi::AbstractVarInfo, stepsize) theta = getparams(model, vi) - lp = getlogp(vi) + lp = DynamicPPL.getlogjoint(vi) return SGLDTransition(theta, lp, stepsize) end From 64ebd9271c99ebfbd8a17abfcf43e671f6d0aff4 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 20 Jul 2025 01:31:44 +0100 Subject: [PATCH 26/49] fix --- src/mcmc/particle_mcmc.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index 88f2b2204e..ffae95b1c7 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -195,7 +195,7 @@ function DynamicPPL.initialstep( ) # Reset the VarInfo. vi = DynamicPPL.reset_num_produce!!(vi) - set_retained_vns_del!(vi) + DynamicPPL.set_retained_vns_del!(vi) vi = DynamicPPL.resetlogp!!(vi) vi = DynamicPPL.empty!!(vi) From 283d4dde819e57171e498cfec16f34e0787d1632 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 20 Jul 2025 01:44:36 +0100 Subject: [PATCH 27/49] more fix fix fix --- src/mcmc/is.jl | 13 ++++++++++--- src/mcmc/particle_mcmc.jl | 10 +++++----- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/src/mcmc/is.jl b/src/mcmc/is.jl index 9ad0e1f82a..fb4214cdb6 100644 --- a/src/mcmc/is.jl +++ b/src/mcmc/is.jl @@ -31,14 +31,20 @@ DynamicPPL.initialsampler(sampler::Sampler{<:IS}) = sampler function DynamicPPL.initialstep( rng::AbstractRNG, model::Model, spl::Sampler{<:IS}, vi::AbstractVarInfo; kwargs... ) - return Transition(model, vi), nothing + # Need to manually construct the Transition here because we only + # want to use the likelihood. 
+ xs = Turing.Inference.getparams(model, vi) + lp = DynamicPPL.getloglikelihood(vi) + return Transition(xs, lp, nothing), nothing end function AbstractMCMC.step( rng::Random.AbstractRNG, model::Model, spl::Sampler{<:IS}, ::Nothing; kwargs... ) vi = VarInfo(rng, model, spl) - return Transition(model, vi), nothing + xs = Turing.Inference.getparams(model, vi) + lp = DynamicPPL.getloglikelihood(vi) + return Transition(xs, lp, nothing), nothing end # Calculate evidence. @@ -53,5 +59,6 @@ function DynamicPPL.assume(rng, ::Sampler{<:IS}, dist::Distribution, vn::VarName r = rand(rng, dist) vi = push!!(vi, vn, r, dist) end - return r, 0, vi + vi = accumulate_assume!!(vi, r, 0.0, vn, dist) + return r, vi end diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index ffae95b1c7..8e03a1505b 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -35,7 +35,9 @@ function AdvancedPS.advance!( trace::AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}}, isref::Bool=false ) # Make sure we load/reset the rng in the new replaying mechanism - trace.model.f.varinfo = DynamicPPL.increment_num_produce!!(trace.model.f.varinfo) + trace = Accessors.@set trace.model.f.varinfo = DynamicPPL.increment_num_produce!!( + trace.model.f.varinfo + ) isref ? 
AdvancedPS.load_state!(trace.rng) : AdvancedPS.save_state!(trace.rng) score = consume(trace.model.ctask) if score === nothing @@ -51,13 +53,11 @@ function AdvancedPS.delete_retained!(trace::TracedModel) end function AdvancedPS.reset_model(trace::TracedModel) - DynamicPPL.reset_num_produce!(trace.varinfo) - return trace + return Accessors.@set trace.varinfo = DynamicPPL.reset_num_produce!!(trace.varinfo) end function AdvancedPS.reset_logprob!(trace::TracedModel) - DynamicPPL.resetlogp!!(trace.model.varinfo) - return trace + return Accessors.@set trace.model.varinfo = DynamicPPL.resetlogp!!(trace.model.varinfo) end function AdvancedPS.update_rng!( From b3461988e83f0f4f3202b2e59122194eb608b5e9 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 20 Jul 2025 01:46:30 +0100 Subject: [PATCH 28/49] fix return values of tilde pipeline --- test/ad.jl | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/test/ad.jl b/test/ad.jl index 91e09c276f..f53dd98358 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -155,35 +155,33 @@ end # child context. 
function DynamicPPL.tilde_assume(context::ADTypeCheckContext, right, vn, vi) - value, logp, vi = DynamicPPL.tilde_assume( - DynamicPPL.childcontext(context), right, vn, vi - ) + value, vi = DynamicPPL.tilde_assume(DynamicPPL.childcontext(context), right, vn, vi) check_adtype(context, vi) - return value, logp, vi + return value, vi end function DynamicPPL.tilde_assume( rng::Random.AbstractRNG, context::ADTypeCheckContext, sampler, right, vn, vi ) - value, logp, vi = DynamicPPL.tilde_assume( + value, vi = DynamicPPL.tilde_assume( rng, DynamicPPL.childcontext(context), sampler, right, vn, vi ) check_adtype(context, vi) - return value, logp, vi + return value, vi end function DynamicPPL.tilde_observe!!(context::ADTypeCheckContext, right, left, vi) - logp, vi = DynamicPPL.tilde_observe!!(DynamicPPL.childcontext(context), right, left, vi) + left, vi = DynamicPPL.tilde_observe!!(DynamicPPL.childcontext(context), right, left, vi) check_adtype(context, vi) - return logp, vi + return left, vi end function DynamicPPL.tilde_observe!!(context::ADTypeCheckContext, sampler, right, left, vi) - logp, vi = DynamicPPL.tilde_observe!!( + left, vi = DynamicPPL.tilde_observe!!( DynamicPPL.childcontext(context), sampler, right, left, vi ) check_adtype(context, vi) - return logp, vi + return left, vi end """ From 9012774493901433d24012879e41abd3f49f6e7c Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 20 Jul 2025 02:10:04 +0100 Subject: [PATCH 29/49] even more fixes --- src/mcmc/external_sampler.jl | 8 +++++++- src/mcmc/is.jl | 2 +- src/mcmc/particle_mcmc.jl | 9 +++++---- src/optimisation/Optimisation.jl | 7 ++++++- test/mcmc/external_sampler.jl | 4 +++- 5 files changed, 22 insertions(+), 8 deletions(-) diff --git a/src/mcmc/external_sampler.jl b/src/mcmc/external_sampler.jl index 4f4203f5a9..992a2fb2db 100644 --- a/src/mcmc/external_sampler.jl +++ b/src/mcmc/external_sampler.jl @@ -126,7 +126,13 @@ function make_updated_varinfo( return if ismissing(new_logp) 
last(DynamicPPL.evaluate!!(f.model, new_varinfo, f.context)) else - DynamicPPL.setlogp!!(new_varinfo, new_logp) + # TODO(DPPL0.37/penelopeysm) This is obviously wrong. Note that we + # have the same problem here as in HMC in that the sampler doesn't + # tell us about how logp is broken down into prior and likelihood. + # We should probably just re-evaluate unconditionally. A bit + # unfortunate. + DynamicPPL.setlogprior!!(new_varinfo, 0.0) + DynamicPPL.setloglikelihood!!(new_varinfo, new_logp) end end diff --git a/src/mcmc/is.jl b/src/mcmc/is.jl index fb4214cdb6..5f2f1627fb 100644 --- a/src/mcmc/is.jl +++ b/src/mcmc/is.jl @@ -59,6 +59,6 @@ function DynamicPPL.assume(rng, ::Sampler{<:IS}, dist::Distribution, vn::VarName r = rand(rng, dist) vi = push!!(vi, vn, r, dist) end - vi = accumulate_assume!!(vi, r, 0.0, vn, dist) + vi = DynamicPPL.accumulate_assume!!(vi, r, 0.0, vn, dist) return r, vi end diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index 8e03a1505b..549a4a02df 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -429,6 +429,8 @@ function trace_local_rng_maybe(rng::Random.AbstractRNG) end end +# TODO(DPPL0.37/penelopeysm) The whole tilde pipeline for particle MCMC needs to be +# thoroughly fixed. function DynamicPPL.assume( rng, ::Sampler{<:Union{PG,SMC}}, dist::Distribution, vn::VarName, _vi::AbstractVarInfo ) @@ -442,13 +444,12 @@ function DynamicPPL.assume( DynamicPPL.unset_flag!(vi, vn, "del") # Reference particle parent r = rand(trng, dist) vi[vn] = DynamicPPL.tovec(r) - DynamicPPL.setorder!(vi, vn, DynamicPPL.get_num_produce(vi)) + vi = DynamicPPL.setorder!!(vi, vn, DynamicPPL.get_num_produce(vi)) else r = vi[vn] end - # TODO: Should we make this `zero(promote_type(eltype(dist), eltype(r)))` or something? - lp = 0 - return r, lp, vi + # TODO: call accumulate_assume?! + return r, vi end # TODO(mhauru) Fix this. 
diff --git a/src/optimisation/Optimisation.jl b/src/optimisation/Optimisation.jl index 3885961adf..058514f60f 100644 --- a/src/optimisation/Optimisation.jl +++ b/src/optimisation/Optimisation.jl @@ -611,7 +611,12 @@ function estimate_mode( ub=nothing, kwargs..., ) - check_model && DynamicPPL.check_model(model; error_on_failure=true) + if check_model + spl_model = DynamicPPL.contextualize( + model, DynamicPPL.SamplingContext(model.context) + ) + DynamicPPL.check_model(spl_model, VarInfo(); error_on_failure=true) + end constraints = ModeEstimationConstraints(lb, ub, cons, lcons, ucons) initial_params = generate_initial_params(model, initial_params, constraints) diff --git a/test/mcmc/external_sampler.jl b/test/mcmc/external_sampler.jl index e2dc417d09..6a6aebddb0 100644 --- a/test/mcmc/external_sampler.jl +++ b/test/mcmc/external_sampler.jl @@ -20,7 +20,9 @@ function initialize_nuts(model::DynamicPPL.Model) linked_vi = DynamicPPL.link!!(vi, model) # Create a LogDensityFunction - f = DynamicPPL.LogDensityFunction(model, linked_vi; adtype=Turing.DEFAULT_ADTYPE) + f = DynamicPPL.LogDensityFunction( + model, DynamicPPL.getlogjoint, linked_vi; adtype=Turing.DEFAULT_ADTYPE + ) # Choose parameter dimensionality and initial parameter value D = LogDensityProblems.dimension(f) From e60058926e1e58dec425e5f22eb9e47e8ea70e0e Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 20 Jul 2025 12:42:52 +0100 Subject: [PATCH 30/49] Fix missing import --- src/optimisation/Optimisation.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/optimisation/Optimisation.jl b/src/optimisation/Optimisation.jl index 058514f60f..19652f9edf 100644 --- a/src/optimisation/Optimisation.jl +++ b/src/optimisation/Optimisation.jl @@ -615,7 +615,7 @@ function estimate_mode( spl_model = DynamicPPL.contextualize( model, DynamicPPL.SamplingContext(model.context) ) - DynamicPPL.check_model(spl_model, VarInfo(); error_on_failure=true) + DynamicPPL.check_model(spl_model, 
DynamicPPL.VarInfo(); error_on_failure=true) end constraints = ModeEstimationConstraints(lb, ub, cons, lcons, ucons) From 3d5072fd71a642f9c6799850e95f2b321b891e60 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 20 Jul 2025 12:49:13 +0100 Subject: [PATCH 31/49] More MH fixes --- src/mcmc/mh.jl | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/src/mcmc/mh.jl b/src/mcmc/mh.jl index 019af79391..78b0ab8809 100644 --- a/src/mcmc/mh.jl +++ b/src/mcmc/mh.jl @@ -307,22 +307,24 @@ function propose!!( prev_trans = AMH.Transition(vt, DynamicPPL.getlogjoint(vi), false) # Make a new transition. + spl_model = DynamicPPL.contextualize( + model, DynamicPPL.SamplingContext(rng, spl, model.context) + ) densitymodel = AMH.DensityModel( Base.Fix1( LogDensityProblems.logdensity, - DynamicPPL.LogDensityFunction( - model, - vi, - DynamicPPL.SamplingContext(rng, spl, DynamicPPL.leafcontext(model.context)), - ), + DynamicPPL.LogDensityFunction(spl_model, DynamicPPL.getlogjoint, vi), ), ) trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans) # TODO: Make this compatible with immutable `VarInfo`. # Update the values in the VarInfo. + # TODO(DPPL0.37/penelopeysm): This is obviously incorrect. We need to + # re-evaluate the model. set_namedtuple!(vi, trans.params) - return setlogp!!(vi, trans.lp) + vi = DynamicPPL.setloglikelihood!!(vi, trans.lp) + return DynamicPPL.setlogprior!!(vi, 0.0) end # Make a proposal if we DO have a covariance proposal matrix. @@ -342,19 +344,22 @@ function propose!!( prev_trans = AMH.Transition(vals, DynamicPPL.getlogjoint(vi), false) # Make a new transition. 
+ spl_model = DynamicPPL.contextualize( + model, DynamicPPL.SamplingContext(rng, spl, model.context) + ) densitymodel = AMH.DensityModel( Base.Fix1( LogDensityProblems.logdensity, - DynamicPPL.LogDensityFunction( - model, - vi, - DynamicPPL.SamplingContext(rng, spl, DynamicPPL.leafcontext(model.context)), - ), + DynamicPPL.LogDensityFunction(spl_model, DynamicPPL.getlogjoint, vi), ), ) trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans) - return setlogp!!(DynamicPPL.unflatten(vi, trans.params), trans.lp) + # TODO(DPPL0.37/penelopeysm): This is obviously incorrect. We need to + # re-evaluate the model. + vi = DynamicPPL.unflatten(vi, trans.params) + vi = DynamicPPL.setloglikelihood!!(vi, trans.lp) + return DynamicPPL.setlogprior!!(vi, 0.0) end function DynamicPPL.initialstep( From 37466ccd7976fa9330b20004e964c42c4fe93135 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 20 Jul 2025 16:02:53 +0100 Subject: [PATCH 32/49] Fix conversion --- src/optimisation/Optimisation.jl | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/optimisation/Optimisation.jl b/src/optimisation/Optimisation.jl index 19652f9edf..0276d50fd0 100644 --- a/src/optimisation/Optimisation.jl +++ b/src/optimisation/Optimisation.jl @@ -195,11 +195,16 @@ struct OptimLogDensity{ function OptimLogDensity( model::DynamicPPL.Model, getlogdensity::Function, - vi::DynamicPPL.VarInfo; + vi::DynamicPPL.AbstractVarInfo; adtype::ADTypes.AbstractADType=Turing.DEFAULT_ADTYPE, ) - return new{typeof(model),typeof(getlogdensity),typeof(vi),typeof(adtype)}( - DynamicPPL.LogDensityFunction(model, getlogdensity, vi; adtype=adtype) + # Note that typeof(adtype) != typeof(ldf.adtype) in general because of + # DynamicPPL's tweak_adtype + ldf = DynamicPPL.LogDensityFunction(model, getlogdensity, vi; adtype=adtype) + return new{ + typeof(ldf.model),typeof(ldf.getlogdensity),typeof(ldf.vi),typeof(ldf.adtype) + }( + ldf ) end function OptimLogDensity( From 
1b73e5a5b28a9362a3cd638c375621ecc4ad5146 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 20 Jul 2025 16:04:06 +0100 Subject: [PATCH 33/49] don't think it really needs those type params --- src/optimisation/Optimisation.jl | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/src/optimisation/Optimisation.jl b/src/optimisation/Optimisation.jl index 0276d50fd0..c0faa14e5f 100644 --- a/src/optimisation/Optimisation.jl +++ b/src/optimisation/Optimisation.jl @@ -187,10 +187,8 @@ optim_ld = OptimLogDensity(model, varinfo) optim_ld(z) # returns -logp ``` """ -struct OptimLogDensity{ - M<:DynamicPPL.Model,F<:Function,V<:DynamicPPL.AbstractVarInfo,AD<:ADTypes.AbstractADType -} - ldf::DynamicPPL.LogDensityFunction{M,F,V,AD} +struct OptimLogDensity{L<:DynamicPPL.LogDensityFunction} + ldf::L function OptimLogDensity( model::DynamicPPL.Model, @@ -198,14 +196,8 @@ struct OptimLogDensity{ vi::DynamicPPL.AbstractVarInfo; adtype::ADTypes.AbstractADType=Turing.DEFAULT_ADTYPE, ) - # Note that typeof(adtype) != typeof(ldf.adtype) in general because of - # DynamicPPL's tweak_adtype ldf = DynamicPPL.LogDensityFunction(model, getlogdensity, vi; adtype=adtype) - return new{ - typeof(ldf.model),typeof(ldf.getlogdensity),typeof(ldf.vi),typeof(ldf.adtype) - }( - ldf - ) + return new{typeof(ldf)}(ldf) end function OptimLogDensity( model::DynamicPPL.Model, From 66a854442b18837d8b129731df782988f50c5f50 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 20 Jul 2025 16:53:32 +0100 Subject: [PATCH 34/49] implement copy for LogPriorWithoutJacAcc --- src/optimisation/Optimisation.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/optimisation/Optimisation.jl b/src/optimisation/Optimisation.jl index c0faa14e5f..63318c27ae 100644 --- a/src/optimisation/Optimisation.jl +++ b/src/optimisation/Optimisation.jl @@ -76,6 +76,8 @@ function DynamicPPL.accumulator_name(::Type{<:LogPriorWithoutJacobianAccumulator return :LogPriorWithoutJacobian end 
+Base.copy(acc::LogPriorWithoutJacobianAccumulator) = acc + function DynamicPPL.split(::LogPriorWithoutJacobianAccumulator{T}) where {T} return LogPriorWithoutJacobianAccumulator(zero(T)) end From 98e70c27a285549e56670ebe352b14932fb2f921 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 20 Jul 2025 17:25:21 +0100 Subject: [PATCH 35/49] Even more fixes --- ext/TuringDynamicHMCExt.jl | 4 +++- src/mcmc/mh.jl | 2 +- test/mcmc/hmc.jl | 19 ------------------- test/mcmc/mh.jl | 20 -------------------- 4 files changed, 4 insertions(+), 41 deletions(-) diff --git a/ext/TuringDynamicHMCExt.jl b/ext/TuringDynamicHMCExt.jl index 8a34d26498..144f9e7008 100644 --- a/ext/TuringDynamicHMCExt.jl +++ b/ext/TuringDynamicHMCExt.jl @@ -75,7 +75,9 @@ function DynamicPPL.initialstep( # Update the variables. vi = DynamicPPL.unflatten(vi, Q.q) - vi = DynamicPPL.setlogp!!(vi, Q.ℓq) + # TODO(DPPL0.37/penelopeysm): This is obviously incorrect. Fix this. + vi = DynamicPPL.setloglikelihood!!(vi, Q.ℓq) + vi = DynamicPPL.setlogprior!!(vi, 0.0) # Create first sample and state. sample = Turing.Inference.Transition(model, vi) diff --git a/src/mcmc/mh.jl b/src/mcmc/mh.jl index 78b0ab8809..e0258fda45 100644 --- a/src/mcmc/mh.jl +++ b/src/mcmc/mh.jl @@ -194,7 +194,7 @@ end function LogDensityProblems.logdensity(f::LogDensityFunction, x::NamedTuple) vi = deepcopy(f.varinfo) set_namedtuple!(vi, x) - vi_new = last(DynamicPPL.evaluate!!(f.model, vi, f.context)) + vi_new = last(DynamicPPL.evaluate!!(f.model, vi)) lj = f.getlogdensity(vi_new) return lj end diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index f78c7a0237..839dffbbe5 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -171,25 +171,6 @@ using Turing @test Array(res1) == Array(res2) == Array(res3) end - # TODO(mhauru) Do we give up being able to sample from only prior/likelihood like this, - # or do we implement some way to pass `whichlogprob=:LogPrior` through `sample`? 
- @testset "prior" begin - # NOTE: Used to use `InverseGamma(2, 3)` but this has infinite variance - # which means that it's _very_ difficult to find a good tolerance in the test below:) - prior_dist = truncated(Normal(3, 1); lower=0) - - @model function demo_hmc_prior() - s ~ prior_dist - return m ~ Normal(0, sqrt(s)) - end - alg = NUTS(1000, 0.8) - gdemo_default_prior = DynamicPPL.contextualize( - demo_hmc_prior(), DynamicPPL.PriorContext() - ) - chain = sample(gdemo_default_prior, alg, 5_000; initial_params=[3.0, 0.0]) - check_numerical(chain, [:s, :m], [mean(prior_dist), 0]; atol=0.2) - end - @testset "warning for difficult init params" begin attempt = 0 @model function demo_warn_initial_params() diff --git a/test/mcmc/mh.jl b/test/mcmc/mh.jl index 3bbb83db5f..70810e1643 100644 --- a/test/mcmc/mh.jl +++ b/test/mcmc/mh.jl @@ -262,26 +262,6 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var)) @test !DynamicPPL.islinked(vi) end - # TODO(mhauru) Do we give up being able to sample from only prior/likelihood like this, - # or do we implement some way to pass `whichlogprob=:LogPrior` through `sample`? 
- @testset "prior" begin - alg = MH() - gdemo_default_prior = DynamicPPL.contextualize( - gdemo_default, DynamicPPL.PriorContext() - ) - burnin = 10_000 - n = 10_000 - chain = sample( - StableRNG(seed), - gdemo_default_prior, - alg, - n; - discard_initial=burnin, - thinning=10, - ) - check_numerical(chain, [:s, :m], [mean(InverseGamma(2, 3)), 0]; atol=0.3) - end - @testset "`filldist` proposal (issue #2180)" begin @model demo_filldist_issue2180() = x ~ MvNormal(zeros(3), I) chain = sample( From d2c1c92cd317368be875f76c6f1085d781f2f071 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 20 Jul 2025 19:38:06 +0100 Subject: [PATCH 36/49] More fixes; I think the remaining failures are pMCMC related --- src/Turing.jl | 2 -- src/mcmc/emcee.jl | 5 ++++- src/mcmc/ess.jl | 2 +- src/mcmc/sghmc.jl | 8 ++++---- 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/Turing.jl b/src/Turing.jl index 1ff2310174..0cdbe24586 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -71,7 +71,6 @@ using DynamicPPL: unfix, prefix, conditioned, - @submodel, to_submodel, LogDensityFunction, @addlogprob! @@ -81,7 +80,6 @@ using OrderedCollections: OrderedDict # Turing essentials - modelling macros and inference algorithms export # DEPRECATED - @submodel, generated_quantities, # Modelling - AbstractPPL and DynamicPPL @model, diff --git a/src/mcmc/emcee.jl b/src/mcmc/emcee.jl index 6f80dea114..cdbaf9fd02 100644 --- a/src/mcmc/emcee.jl +++ b/src/mcmc/emcee.jl @@ -85,7 +85,10 @@ function AbstractMCMC.step( # Generate a log joint function. vi = state.vi densitymodel = AMH.DensityModel( - Base.Fix1(LogDensityProblems.logdensity, DynamicPPL.LogDensityFunction(model, vi)) + Base.Fix1( + LogDensityProblems.logdensity, + DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint, vi), + ), ) # Compute the next states. 
diff --git a/src/mcmc/ess.jl b/src/mcmc/ess.jl index 86b92b28ee..c49b52d3a4 100644 --- a/src/mcmc/ess.jl +++ b/src/mcmc/ess.jl @@ -106,7 +106,7 @@ struct ESSLikelihood{M<:Model,V<:AbstractVarInfo} ldf::DynamicPPL.LogDensityFunction{M,V} # Force usage of `getloglikelihood` in inner constructor - function ESSLogLikelihood(model::Model, varinfo::AbstractVarInfo) + function ESSLikelihood(model::Model, varinfo::AbstractVarInfo) ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.getloglikelihood, varinfo) return new{typeof(model),typeof(varinfo)}(ldf) end diff --git a/src/mcmc/sghmc.jl b/src/mcmc/sghmc.jl index 2d669cd908..6eeddfefa0 100644 --- a/src/mcmc/sghmc.jl +++ b/src/mcmc/sghmc.jl @@ -61,7 +61,7 @@ function DynamicPPL.initialstep( # Transform the samples to unconstrained space and compute the joint log probability. if !DynamicPPL.islinked(vi) vi = DynamicPPL.link!!(vi, model) - vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl))) + vi = last(DynamicPPL.evaluate!!(model, vi)) end # Compute initial sample and state. @@ -100,7 +100,7 @@ function AbstractMCMC.step( # Save new variables and recompute log density. vi = DynamicPPL.unflatten(vi, θ) - vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl))) + vi = last(DynamicPPL.evaluate!!(model, vi)) # Compute next sample and state. sample = Transition(model, vi) @@ -224,7 +224,7 @@ function DynamicPPL.initialstep( # Transform the samples to unconstrained space and compute the joint log probability. if !DynamicPPL.islinked(vi) vi = DynamicPPL.link!!(vi, model) - vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl))) + vi = last(DynamicPPL.evaluate!!(model, vi)) end # Create first sample and state. @@ -254,7 +254,7 @@ function AbstractMCMC.step( # Save new variables and recompute log density. 
vi = DynamicPPL.unflatten(vi, θ) - vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl))) + vi = last(DynamicPPL.evaluate!!(model, vi)) # Compute next sample and state. sample = SGLDTransition(model, vi, stepsize) From 11a2a31492a4fb4bfed0230ac2a6e72f7875e504 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 21 Jul 2025 15:31:04 +0100 Subject: [PATCH 37/49] Fix merge --- src/mcmc/particle_mcmc.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index cca7fbc754..6cf8fc3152 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -483,7 +483,7 @@ end # called on `:invoke` expressions rather than `:call`s, but since those are implementation # details of the compiler, we set a bunch of methods as might_produce = true. We start with # `acclogp_observe!!` which is what calls `produce` and go up the call stack. -Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.acclogp_observe!!),Vararg}}) = true +# Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.acclogp_observe!!),Vararg}}) = true Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_observe!!),Vararg}}) = true Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.evaluate!!),Vararg}}) = true function Libtask.might_produce( From c062867aefe579f5a00cfde8ff40c0a2b39f2e75 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 31 Jul 2025 12:06:00 +0100 Subject: [PATCH 38/49] DPPL 0.37 compat for particle MCMC (#2625) * Progress in DPPL 0.37 compat for particle MCMC * WIP PMCMC work * Gibbs fixes for DPPL 0.37 (plus tiny bugfixes for ESS + HMC) (#2628) * Obviously this single commit will make Gibbs work * Fixes for ESS * Fix HMC call * improve some comments * Fixes to ProduceLogLikelihoodAccumulator * Use LogProbAccumulator for ProduceLogLikelihoodAccumulator * use get_conditioned_gibbs --------- Co-authored-by: Penelope Yong --- src/mcmc/ess.jl | 13 ++- src/mcmc/gibbs.jl | 36 +++++-- 
src/mcmc/hmc.jl | 4 +- src/mcmc/particle_mcmc.jl | 219 ++++++++++++++++++++++++++------------ 4 files changed, 188 insertions(+), 84 deletions(-) diff --git a/src/mcmc/ess.jl b/src/mcmc/ess.jl index c49b52d3a4..feb737a30d 100644 --- a/src/mcmc/ess.jl +++ b/src/mcmc/ess.jl @@ -54,7 +54,7 @@ function AbstractMCMC.step( # update sample and log-likelihood vi = DynamicPPL.unflatten(vi, sample) - vi = setloglikelihood!!(vi, state.loglikelihood) + vi = DynamicPPL.setloglikelihood!!(vi, state.loglikelihood) return Transition(model, vi), vi end @@ -88,6 +88,11 @@ function Base.rand(rng::Random.AbstractRNG, p::ESSPrior) # p.varinfo, PriorInit())` after TuringLang/DynamicPPL.jl#984. The reason # why we had to use the 'del' flag before this was because # SampleFromPrior() wouldn't overwrite existing variables. + # The main problem I'm rather unsure about is ESS-within-Gibbs. The + # current implementation I think makes sure to only resample the variables + # that 'belong' to the current ESS sampler. InitContext on the other hand + # would resample all variables in the model (??) Need to think about this + # carefully. vns = keys(varinfo) for vn in vns set_flag!(varinfo, vn, "del") @@ -102,13 +107,13 @@ Distributions.mean(p::ESSPrior) = p.μ # Evaluate log-likelihood of proposals. We need this struct because # EllipticalSliceSampling.jl expects a callable struct / a function as its # likelihood. 
-struct ESSLikelihood{M<:Model,V<:AbstractVarInfo} - ldf::DynamicPPL.LogDensityFunction{M,V} +struct ESSLikelihood{L<:DynamicPPL.LogDensityFunction} + ldf::L # Force usage of `getloglikelihood` in inner constructor function ESSLikelihood(model::Model, varinfo::AbstractVarInfo) ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.getloglikelihood, varinfo) - return new{typeof(model),typeof(varinfo)}(ldf) + return new{typeof(ldf)}(ldf) end end diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 81281389ec..58db29789d 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -177,13 +177,15 @@ function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) # Fall back to the default behavior. DynamicPPL.tilde_assume(child_context, right, vn, vi) elseif has_conditioned_gibbs(context, vn) - # Short-circuit the tilde assume if `vn` is present in `context`. - # TODO(mhauru) Fix accumulation here. In this branch anything that gets - # accumulated just gets discarded with `_`. - value, _ = DynamicPPL.tilde_assume( - child_context, right, vn, get_global_varinfo(context) - ) - value, vi + # This branch means that a different sampler is supposed to handle this + # variable. From the perspective of this sampler, this variable is + # conditioned on, so we can just treat it as an observation. + # The only catch is that the value that we need is to be obtained from + # the global VarInfo (since the local VarInfo has no knowledge of it). + # Note that tilde_observe!! will trigger resampling in particle methods + # for variables that are handled by other Gibbs component samplers. + val = get_conditioned_gibbs(context, vn) + DynamicPPL.tilde_observe!!(child_context, right, val, vn, vi) else # If the varname has not been conditioned on, nor is it a target variable, its # presumably a new variable that should be sampled from its prior. 
We need to add @@ -210,13 +212,25 @@ function DynamicPPL.tilde_assume( vn, child_context = DynamicPPL.prefix_and_strip_contexts(child_context, vn) return if is_target_varname(context, vn) + # This branch means that `sampler` is supposed to handle + # this variable. We can thus use its default behaviour, with + # the 'local' sampler-specific VarInfo. DynamicPPL.tilde_assume(rng, child_context, sampler, right, vn, vi) elseif has_conditioned_gibbs(context, vn) - value, _ = DynamicPPL.tilde_assume( - child_context, right, vn, get_global_varinfo(context) - ) - value, vi + # This branch means that a different sampler is supposed to handle this + # variable. From the perspective of this sampler, this variable is + # conditioned on, so we can just treat it as an observation. + # The only catch is that the value that we need is to be obtained from + # the global VarInfo (since the local VarInfo has no knowledge of it). + # Note that tilde_observe!! will trigger resampling in particle methods + # for variables that are handled by other Gibbs component samplers. + val = get_conditioned_gibbs(context, vn) + DynamicPPL.tilde_observe!!(child_context, right, val, vn, vi) else + # If the varname has not been conditioned on, nor is it a target variable, it's + # presumably a new variable that should be sampled from its prior. We need to add + # this new variable to the global `varinfo` of the context, but not to the local one + # being used by the current sampler. value, new_global_vi = DynamicPPL.tilde_assume( rng, child_context, diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl index e19f023437..18733f6a8d 100644 --- a/src/mcmc/hmc.jl +++ b/src/mcmc/hmc.jl @@ -162,7 +162,9 @@ function find_initial_params( # Resample and try again. 
# NOTE: varinfo has to be linked to make sure this samples in unconstrained space varinfo = last( - DynamicPPL.evaluate!!(model, rng, varinfo, DynamicPPL.SampleFromUniform()) + DynamicPPL.evaluate_and_sample!!( + rng, model, varinfo, DynamicPPL.SampleFromUniform() + ), ) end diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index 6cf8fc3152..17feb18c13 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -33,17 +33,17 @@ end function AdvancedPS.advance!( trace::AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}}, isref::Bool=false ) - # Make sure we load/reset the rng in the new replaying mechanism - trace = Accessors.@set trace.model.f.varinfo = DynamicPPL.increment_num_produce!!( - trace.model.f.varinfo + # We want to increment num produce for the VarInfo stored in the trace. The trace is + # mutable, so we create a new model with the incremented VarInfo and set it in the trace + model = trace.model + model = Accessors.@set model.f.varinfo = DynamicPPL.increment_num_produce!!( + model.f.varinfo ) + trace.model = model + # Make sure we load/reset the rng in the new replaying mechanism isref ? AdvancedPS.load_state!(trace.rng) : AdvancedPS.save_state!(trace.rng) score = consume(trace.model.ctask) - if score === nothing - return nothing - else - return score + DynamicPPL.getlogjoint(trace.model.f.varinfo) - end + return score end function AdvancedPS.delete_retained!(trace::TracedModel) @@ -55,10 +55,6 @@ function AdvancedPS.reset_model(trace::TracedModel) return Accessors.@set trace.varinfo = DynamicPPL.reset_num_produce!!(trace.varinfo) end -function AdvancedPS.reset_logprob!(trace::TracedModel) - return Accessors.@set trace.model.varinfo = DynamicPPL.resetlogp!!(trace.model.varinfo) -end - function Libtask.TapedTask(taped_globals, model::TracedModel; kwargs...) return Libtask.TapedTask( taped_globals, model.evaluator[1], model.evaluator[2:end]...; kwargs... 
@@ -114,11 +110,7 @@ end function SMCTransition(model::DynamicPPL.Model, vi::AbstractVarInfo, weight) theta = getparams(model, vi) - - # This is pretty useless since we reset the log probability continuously in the - # particle sweep. lp = DynamicPPL.getlogjoint(vi) - return SMCTransition(theta, lp, weight) end @@ -183,6 +175,7 @@ function DynamicPPL.initialstep( kwargs..., ) # Reset the VarInfo. + vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator()) vi = DynamicPPL.reset_num_produce!!(vi) DynamicPPL.set_retained_vns_del!(vi) vi = DynamicPPL.resetlogp!!(vi) @@ -293,11 +286,7 @@ varinfo(state::PGState) = state.vi function PGTransition(model::DynamicPPL.Model, vi::AbstractVarInfo, logevidence) theta = getparams(model, vi) - - # This is pretty useless since we reset the log probability continuously in the - # particle sweep. lp = DynamicPPL.getlogjoint(vi) - return PGTransition(theta, lp, logevidence) end @@ -316,6 +305,7 @@ function DynamicPPL.initialstep( vi::AbstractVarInfo; kwargs..., ) + vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator()) # Reset the VarInfo before new sweep vi = DynamicPPL.reset_num_produce!!(vi) DynamicPPL.set_retained_vns_del!(vi) @@ -390,78 +380,115 @@ function DynamicPPL.use_threadsafe_eval( return false end -function trace_local_varinfo_maybe(varinfo) - try - trace = Libtask.get_taped_globals(Any).other - return (trace === nothing ? varinfo : trace.model.f.varinfo)::AbstractVarInfo +""" + get_trace_local_varinfo_maybe(vi::AbstractVarInfo) + +Get the `Trace` local varinfo if one exists. + +If executed within a `TapedTask`, return the `varinfo` stored in the "taped globals" of the +task, otherwise return `vi`. +""" +function get_trace_local_varinfo_maybe(varinfo::AbstractVarInfo) + trace = try + Libtask.get_taped_globals(Any).other catch e - # NOTE: this heuristic allows Libtask evaluating a model outside a `Trace`. 
- if e == KeyError(:task_variable) - return varinfo - else - rethrow(e) - end + e == KeyError(:task_variable) ? nothing : rethrow(e) end + return (trace === nothing ? varinfo : trace.model.f.varinfo)::AbstractVarInfo end -function trace_local_rng_maybe(rng::Random.AbstractRNG) - try - return Libtask.get_taped_globals(Any).rng +""" + get_trace_local_rng_maybe(rng::Random.AbstractRNG) + +Get the `Trace` local rng if one exists. + +If executed within a `TapedTask`, return the `rng` stored in the "taped globals" of the +task, otherwise return `rng`. +""" +function get_trace_local_rng_maybe(rng::Random.AbstractRNG) + return try + Libtask.get_taped_globals(Any).rng catch e - # NOTE: this heuristic allows Libtask evaluating a model outside a `Trace`. - if e == KeyError(:task_variable) - return rng - else - rethrow(e) - end + e == KeyError(:task_variable) ? rng : rethrow(e) + end +end + +""" + set_trace_local_varinfo_maybe(vi::AbstractVarInfo) + +Set the `Trace` local varinfo if executing within a `Trace`. Return `nothing`. + +If executed within a `TapedTask`, set the `varinfo` stored in the "taped globals" of the +task. Otherwise do nothing. +""" +function set_trace_local_varinfo_maybe(vi::AbstractVarInfo) + # TODO(mhauru) This should be done in a try-catch block, as in the commented out code. + # However, Libtask currently can't handle this block. + trace = #try + Libtask.get_taped_globals(Any).other + # catch e + # e == KeyError(:task_variable) ? nothing : rethrow(e) + # end + if trace !== nothing + model = trace.model + model = Accessors.@set model.f.varinfo = vi + trace.model = model end + return nothing end -# TODO(DPPL0.37/penelopeysm) The whole tilde pipeline for particle MCMC needs to be -# thoroughly fixed. 
function DynamicPPL.assume( - rng, ::Sampler{<:Union{PG,SMC}}, dist::Distribution, vn::VarName, _vi::AbstractVarInfo + rng, ::Sampler{<:Union{PG,SMC}}, dist::Distribution, vn::VarName, vi::AbstractVarInfo ) - vi = trace_local_varinfo_maybe(_vi) - trng = trace_local_rng_maybe(rng) + arg_vi_id = objectid(vi) + vi = get_trace_local_varinfo_maybe(vi) + using_local_vi = objectid(vi) == arg_vi_id + + trng = get_trace_local_rng_maybe(rng) if ~haskey(vi, vn) r = rand(trng, dist) - push!!(vi, vn, r, dist) + vi = push!!(vi, vn, r, dist) elseif DynamicPPL.is_flagged(vi, vn, "del") DynamicPPL.unset_flag!(vi, vn, "del") # Reference particle parent r = rand(trng, dist) vi[vn] = DynamicPPL.tovec(r) + # TODO(mhauru): + # The below is the only line that differs from assume called on SampleFromPrior. + # Could we just call assume on SampleFromPrior and then `setorder!!` after that? vi = DynamicPPL.setorder!!(vi, vn, DynamicPPL.get_num_produce(vi)) else r = vi[vn] end - # TODO: call accumulate_assume?! + + vi = DynamicPPL.accumulate_assume!!(vi, r, 0, vn, dist) + + # TODO(mhauru) Rather than this if-block, we should use try-catch within + # `set_trace_local_varinfo_maybe`. However, currently Libtask can't handle such a block, + # hence this. + if !using_local_vi + set_trace_local_varinfo_maybe(vi) + end return r, vi end -# TODO(mhauru) Fix this. -# function DynamicPPL.observe(spl::Sampler{<:Union{PG,SMC}}, dist::Distribution, value, vi) -# # NOTE: The `Libtask.produce` is now hit in `acclogp_observe!!`. 
-# return logpdf(dist, value), trace_local_varinfo_maybe(vi) -# end - -function DynamicPPL.acclogp!!( - context::DynamicPPL.SamplingContext{<:Sampler{<:Union{PG,SMC}}}, - varinfo::AbstractVarInfo, - logp, +function DynamicPPL.tilde_observe!!( + ctx::DynamicPPL.SamplingContext{<:Sampler{<:Union{PG,SMC}}}, right, left, vn, vi ) - varinfo_trace = trace_local_varinfo_maybe(varinfo) - return DynamicPPL.acclogp!!(DynamicPPL.childcontext(context), varinfo_trace, logp) -end + arg_vi_id = objectid(vi) + vi = get_trace_local_varinfo_maybe(vi) + using_local_vi = objectid(vi) == arg_vi_id -# TODO(mhauru) Fix this. -# function DynamicPPL.acclogp_observe!!( -# context::SamplingContext{<:Sampler{<:Union{PG,SMC}}}, varinfo::AbstractVarInfo, logp -# ) -# Libtask.produce(logp) -# return trace_local_varinfo_maybe(varinfo) -# end + left, vi = DynamicPPL.tilde_observe!!(ctx.context, right, left, vn, vi) + + # TODO(mhauru) Rather than this if-block, we should use try-catch within + # `set_trace_local_varinfo_maybe`. However, currently Libtask can't handle such a block, + # hence this. + if !using_local_vi + set_trace_local_varinfo_maybe(vi) + end + return left, vi +end # Convenient constructor function AdvancedPS.Trace( @@ -478,13 +505,69 @@ function AdvancedPS.Trace( return newtrace end +""" + ProduceLogLikelihoodAccumulator{T<:Real} <: AbstractAccumulator + +Exactly like `LogLikelihoodAccumulator`, but calls `Libtask.produce` on change of value. + +# Fields +$(TYPEDFIELDS) +""" +struct ProduceLogLikelihoodAccumulator{T<:Real} <: DynamicPPL.LogProbAccumulator{T} + "the scalar log likelihood value" + logp::T +end + +# Note that this uses the same name as `LogLikelihoodAccumulator`. Thus only one of the two +# can be used in a given VarInfo. 
+DynamicPPL.accumulator_name(::Type{<:ProduceLogLikelihoodAccumulator}) = :LogLikelihood +DynamicPPL.logp(acc::ProduceLogLikelihoodAccumulator) = acc.logp + +function DynamicPPL.acclogp(acc1::ProduceLogLikelihoodAccumulator, val) + # The below line is the only difference from `LogLikelihoodAccumulator`. + Libtask.produce(val) + return ProduceLogLikelihoodAccumulator(acc1.logp + val) +end + +function DynamicPPL.accumulate_assume!!( + acc::ProduceLogLikelihoodAccumulator, val, logjac, vn, right +) + return acc +end +function DynamicPPL.accumulate_observe!!( + acc::ProduceLogLikelihoodAccumulator, right, left, vn +) + return DynamicPPL.acclogp(acc, Distributions.loglikelihood(right, left)) +end + # We need to tell Libtask which calls may have `produce` calls within them. In practice most # of these won't be needed, because of inlining and the fact that `might_produce` is only # called on `:invoke` expressions rather than `:call`s, but since those are implementation # details of the compiler, we set a bunch of methods as might_produce = true. We start with -# `acclogp_observe!!` which is what calls `produce` and go up the call stack. -# Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.acclogp_observe!!),Vararg}}) = true +# adding to ProduceLogLikelihoodAccumulator, which is what calls `produce`, and go up the +# call stack. +Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.accloglikelihood!!),Vararg}}) = true +function Libtask.might_produce( + ::Type{ + <:Tuple{ + typeof(Base.:+), + ProduceLogLikelihoodAccumulator, + DynamicPPL.LogLikelihoodAccumulator, + }, + }, +) + return true +end +function Libtask.might_produce( + ::Type{<:Tuple{typeof(DynamicPPL.accumulate_observe!!),Vararg}} +) + return true +end Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_observe!!),Vararg}}) = true +# Could the next two have tighter type bounds on the arguments, namely a GibbsContext? 
+# That's the only thing that makes tilde_assume calls result in tilde_observe calls. +Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_assume!!),Vararg}}) = true +Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_assume),Vararg}}) = true Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.evaluate!!),Vararg}}) = true function Libtask.might_produce( ::Type{<:Tuple{typeof(DynamicPPL.evaluate_threadsafe!!),Vararg}} From 7124864a726070f3ae0e85e93b8c6508d8590bf9 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 31 Jul 2025 12:18:53 +0100 Subject: [PATCH 39/49] "Fixes" for PG-in-Gibbs (#2629) * WIP PMCMC work * Fixes to ProduceLogLikelihoodAccumulator * inline definition of `set_retained_vns_del!` * Fix ProduceLogLikelihoodAcc * Remove all uses of `set_retained_vns_del!` * Use nice functions * Remove PG tests with dynamic number of Gibbs-conditioned-observations * Fix essential/container tests * Update pMCMC implementation as per discussion * remove extra printing statements * revert unneeded changes * Add back (some kind of) dynamic model test * fix rebase * Add a todo comment for dynamic model tests --------- Co-authored-by: Markus Hauru --- src/mcmc/particle_mcmc.jl | 74 +++++++++++++++------- test/essential/container.jl | 2 + test/mcmc/gibbs.jl | 120 ++++++++++++++++++++++++++---------- 3 files changed, 144 insertions(+), 52 deletions(-) diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index 17feb18c13..f7e40f4326 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -4,6 +4,38 @@ ### AdvancedPS models and interface +""" + set_all_del!(vi::AbstractVarInfo) + +Set the "del" flag for all variables in the VarInfo `vi`, thus marking them for +resampling. +""" +function set_all_del!(vi::AbstractVarInfo) + # TODO(penelopeysm): Instead of being a 'del' flag on the VarInfo, we + # could either: + # - keep a boolean 'resample' flag on the trace, or + # - modify the model context appropriately. 
+ # However, this refactoring will have to wait until InitContext is + # merged into DPPL. + for vn in keys(vi) + DynamicPPL.set_flag!(vi, vn, "del") + end + return nothing +end + +""" + unset_all_del!(vi::AbstractVarInfo) + +Unset the "del" flag for all variables in the VarInfo `vi`, thus preventing +them from being resampled. +""" +function unset_all_del!(vi::AbstractVarInfo) + for vn in keys(vi) + DynamicPPL.unset_flag!(vi, vn, "del") + end + return nothing +end + struct TracedModel{S<:AbstractSampler,V<:AbstractVarInfo,M<:Model,E<:Tuple} <: AdvancedPS.AbstractGenericModel model::M @@ -33,13 +65,6 @@ end function AdvancedPS.advance!( trace::AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}}, isref::Bool=false ) - # We want to increment num produce for the VarInfo stored in the trace. The trace is - # mutable, so we create a new model with the incremented VarInfo and set it in the trace - model = trace.model - model = Accessors.@set model.f.varinfo = DynamicPPL.increment_num_produce!!( - model.f.varinfo - ) - trace.model = model # Make sure we load/reset the rng in the new replaying mechanism isref ? AdvancedPS.load_state!(trace.rng) : AdvancedPS.save_state!(trace.rng) score = consume(trace.model.ctask) @@ -47,12 +72,23 @@ function AdvancedPS.advance!( end function AdvancedPS.delete_retained!(trace::TracedModel) - DynamicPPL.set_retained_vns_del!(trace.varinfo) + # This method is called if, during a CSMC update, we perform a resampling + # and choose the reference particle as the trajectory to carry on from. + # In such a case, we need to ensure that when we continue sampling (i.e. + # the next time we hit tilde_assume), we don't use the values in the + # reference particle but rather sample new values. + # + # Here, we indiscriminately set the 'del' flag for all variables in the + # VarInfo. This is slightly overkill: it is not necessary to set the 'del' + # flag for variables that were already sampled. 
However, it allows us to + # avoid keeping track of which variables were sampled, which leads to many + # simplifications in the VarInfo data structure. + set_all_del!(trace.varinfo) return trace end function AdvancedPS.reset_model(trace::TracedModel) - return Accessors.@set trace.varinfo = DynamicPPL.reset_num_produce!!(trace.varinfo) + return trace end function Libtask.TapedTask(taped_globals, model::TracedModel; kwargs...) @@ -176,8 +212,7 @@ function DynamicPPL.initialstep( ) # Reset the VarInfo. vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator()) - vi = DynamicPPL.reset_num_produce!!(vi) - DynamicPPL.set_retained_vns_del!(vi) + set_all_del!(vi) vi = DynamicPPL.resetlogp!!(vi) vi = DynamicPPL.empty!!(vi) @@ -307,8 +342,7 @@ function DynamicPPL.initialstep( ) vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator()) # Reset the VarInfo before new sweep - vi = DynamicPPL.reset_num_produce!!(vi) - DynamicPPL.set_retained_vns_del!(vi) + set_all_del!(vi) vi = DynamicPPL.resetlogp!!(vi) # Create a new set of particles @@ -339,14 +373,15 @@ function AbstractMCMC.step( ) # Reset the VarInfo before new sweep. vi = state.vi - vi = DynamicPPL.reset_num_produce!!(vi) + vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator()) vi = DynamicPPL.resetlogp!!(vi) # Create reference particle for which the samples will be retained. + unset_all_del!(vi) reference = AdvancedPS.forkr(AdvancedPS.Trace(model, spl, vi, state.rng)) # For all other particles, do not retain the variables but resample them. - DynamicPPL.set_retained_vns_del!(vi) + set_all_del!(vi) # Create a new set of particles. num_particles = spl.alg.nparticles @@ -451,12 +486,11 @@ function DynamicPPL.assume( vi = push!!(vi, vn, r, dist) elseif DynamicPPL.is_flagged(vi, vn, "del") DynamicPPL.unset_flag!(vi, vn, "del") # Reference particle parent - r = rand(trng, dist) - vi[vn] = DynamicPPL.tovec(r) # TODO(mhauru): # The below is the only line that differs from assume called on SampleFromPrior. 
- # Could we just call assume on SampleFromPrior and then `setorder!!` after that? - vi = DynamicPPL.setorder!!(vi, vn, DynamicPPL.get_num_produce(vi)) + # Could we just call assume on SampleFromPrior with a specific rng? + r = rand(trng, dist) + vi[vn] = DynamicPPL.tovec(r) else r = vi[vn] end @@ -498,8 +532,6 @@ function AdvancedPS.Trace( rng::AdvancedPS.TracedRNG, ) newvarinfo = deepcopy(varinfo) - newvarinfo = DynamicPPL.reset_num_produce!!(newvarinfo) - tmodel = TracedModel(model, sampler, newvarinfo, rng) newtrace = AdvancedPS.Trace(tmodel, rng) return newtrace diff --git a/test/essential/container.jl b/test/essential/container.jl index cbd7a6fe2b..8ebce270d4 100644 --- a/test/essential/container.jl +++ b/test/essential/container.jl @@ -19,6 +19,7 @@ using Turing @testset "constructor" begin vi = DynamicPPL.VarInfo() + vi = DynamicPPL.setacc!!(vi, Turing.Inference.ProduceLogLikelihoodAccumulator()) sampler = Sampler(PG(10)) model = test() trace = AdvancedPS.Trace(model, sampler, vi, AdvancedPS.TracedRNG()) @@ -46,6 +47,7 @@ using Turing return a, b end vi = DynamicPPL.VarInfo() + vi = DynamicPPL.setacc!!(vi, Turing.Inference.ProduceLogLikelihoodAccumulator()) sampler = Sampler(PG(10)) model = normal() diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index f44a9fefc6..092cc71f7e 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -207,8 +207,8 @@ end val ~ Normal(s, 1) 1.0 ~ Normal(s + m, 1) - n := m + 1 - xs = M(undef, n) + n := m + xs = M(undef, 5) for i in eachindex(xs) xs[i] ~ Beta(0.5, 0.5) end @@ -565,40 +565,98 @@ end end end - # The below test used to sample incorrectly before - # https://github.com/TuringLang/Turing.jl/pull/2328 - @testset "dynamic model with ESS" begin - @model function dynamic_model_for_ess() - b ~ Bernoulli() - x_length = b ? 
1 : 2 - x = Vector{Float64}(undef, x_length) - for i in 1:x_length - x[i] ~ Normal(i, 1.0) + @testset "PG with variable number of observations" begin + # When sampling from a model with Particle Gibbs, it is mandatory for + # the number of observations to be the same in all particles, since the + # observations trigger particle resampling. + # + # Up until Turing v0.39, `x ~ dist` statements where `x` was the + # responsibility of a different (non-PG) Gibbs subsampler used to not + # count as an observation. Instead, the log-probability `logpdf(dist, x)` + # would be manually added to the VarInfo's `logp` field and included in the + # weighting for the _following_ observation. + # + # In Turing v0.40, this is now changed: `x ~ dist` uses tilde_observe!! + # which thus triggers resampling. Thus, for example, the following model + # does not work any more: + # + # @model function f() + # a ~ Poisson(2.0) + # x = Vector{Float64}(undef, a) + # for i in eachindex(x) + # x[i] ~ Normal() + # end + # end + # sample(f(), Gibbs(:a => PG(10), :x => MH()), 1000) + # + # because the number of observations in each particle depends on the value + # of `a`. + # + # This testset checks ways of working around such a situation. 
+ + function test_dynamic_bernoulli(chain) + means = Dict(:b => 0.5, "x[1]" => 1.0, "x[2]" => 2.0) + stds = Dict(:b => 0.5, "x[1]" => 1.0, "x[2]" => 1.0) + for vn in keys(means) + @test isapprox(mean(skipmissing(chain[:, vn, 1])), means[vn]; atol=0.1) + @test isapprox(std(skipmissing(chain[:, vn, 1])), stds[vn]; atol=0.1) end end - m = dynamic_model_for_ess() - chain = sample(m, Gibbs(:b => PG(10), :x => ESS()), 2000; discard_initial=100) - means = Dict(:b => 0.5, "x[1]" => 1.0, "x[2]" => 2.0) - stds = Dict(:b => 0.5, "x[1]" => 1.0, "x[2]" => 1.0) - for vn in keys(means) - @test isapprox(mean(skipmissing(chain[:, vn, 1])), means[vn]; atol=0.1) - @test isapprox(std(skipmissing(chain[:, vn, 1])), stds[vn]; atol=0.1) + # TODO(DPPL0.37/penelopeysm): decide what to do with these tests + @testset "Coalescing multiple observations into one" begin + # Instead of observing x[1] and x[2] separately, we lump them into a + # single distribution. + @model function dynamic_bernoulli() + b ~ Bernoulli() + if b + dists = [Normal(1.0)] + else + dists = [Normal(1.0), Normal(2.0)] + end + return x ~ product_distribution(dists) + end + model = dynamic_bernoulli() + # This currently fails because if the global varinfo has `x` with length 2, + # and the particle sampler has `b = true`, it attempts to calculate the + # log-likelihood of a length-2 vector with respect to a length-1 + # distribution. 
+ @test_throws DimensionMismatch chain = sample( + StableRNG(468), + model, + Gibbs(:b => PG(10), :x => ESS()), + 2000; + discard_initial=100, + ) + # test_dynamic_bernoulli(chain) end - end - @testset "dynamic model with dot tilde" begin - @model function dynamic_model_with_dot_tilde( - num_zs=10, (::Type{M})=Vector{Float64} - ) where {M} - z = Vector{Int}(undef, num_zs) - z .~ Poisson(1.0) - num_ms = sum(z) - m = M(undef, num_ms) - return m .~ Normal(1.0, 1.0) - end - model = dynamic_model_with_dot_tilde() - sample(model, Gibbs(:z => PG(10), :m => HMC(0.01, 4)), 100) + @testset "Inserting @addlogprob!" begin + # On top of observing x[i], we also add in extra 'observations' + @model function dynamic_bernoulli_2() + b ~ Bernoulli() + x_length = b ? 1 : 2 + x = Vector{Float64}(undef, x_length) + for i in 1:x_length + x[i] ~ Normal(i, 1.0) + end + if length(x) == 1 + # This value is the expectation value of logpdf(Normal(), x) where x ~ Normal(). + # See discussion in + # https://github.com/TuringLang/Turing.jl/pull/2629#discussion_r2237323817 + @addlogprob!(-1.418849) + end + end + model = dynamic_bernoulli_2() + chain = sample( + StableRNG(468), + model, + Gibbs(:b => PG(10), :x => ESS()), + 2000; + discard_initial=100, + ) + test_dynamic_bernoulli(chain) + end end @testset "Demo model" begin From 8fdecc0f764270ccebe8f8ed723b7eebea52bae1 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 1 Aug 2025 16:24:16 +0100 Subject: [PATCH 40/49] Use accumulators to fix all logp calculations when sampling (#2630) * Use new `getlogjoint` for optimisation * Change getlogjoint -> getlogjoint_internal where needed * Enforce re-evaluation when constructing `Transition` * fix tests * Remove extra evaluations from SGLD and SGHMC * Remove dead `transitions_from_chain` method (used to be part of `predict`) * metadata -> getstats_with_lp * Clean up some stray getlogp --- ext/TuringDynamicHMCExt.jl | 18 +-- ext/TuringOptimExt.jl | 6 +- src/mcmc/Inference.jl | 219 
+++++++++++------------------- src/mcmc/emcee.jl | 11 +- src/mcmc/ess.jl | 4 +- src/mcmc/external_sampler.jl | 69 +++------- src/mcmc/gibbs.jl | 35 +++-- src/mcmc/hmc.jl | 31 ++--- src/mcmc/is.jl | 12 +- src/mcmc/mh.jl | 67 ++++++--- src/mcmc/particle_mcmc.jl | 14 +- src/mcmc/prior.jl | 6 +- src/mcmc/sghmc.jl | 32 ++--- src/optimisation/Optimisation.jl | 113 +-------------- test/essential/container.jl | 3 - test/mcmc/Inference.jl | 8 +- test/mcmc/ess.jl | 2 +- test/mcmc/external_sampler.jl | 2 +- test/mcmc/gibbs.jl | 22 +-- test/mcmc/is.jl | 4 +- test/optimisation/Optimisation.jl | 22 +-- 21 files changed, 228 insertions(+), 472 deletions(-) diff --git a/ext/TuringDynamicHMCExt.jl b/ext/TuringDynamicHMCExt.jl index 144f9e7008..2c4bd08980 100644 --- a/ext/TuringDynamicHMCExt.jl +++ b/ext/TuringDynamicHMCExt.jl @@ -63,7 +63,7 @@ function DynamicPPL.initialstep( # Define log-density function. ℓ = DynamicPPL.LogDensityFunction( - model, DynamicPPL.getlogjoint, vi; adtype=spl.alg.adtype + model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.alg.adtype ) # Perform initial step. @@ -73,14 +73,9 @@ function DynamicPPL.initialstep( steps = DynamicHMC.mcmc_steps(results.sampling_logdensity, results.final_warmup_state) Q, _ = DynamicHMC.mcmc_next_step(steps, results.final_warmup_state.Q) - # Update the variables. - vi = DynamicPPL.unflatten(vi, Q.q) - # TODO(DPPL0.37/penelopeysm): This is obviously incorrect. Fix this. - vi = DynamicPPL.setloglikelihood!!(vi, Q.ℓq) - vi = DynamicPPL.setlogprior!!(vi, 0.0) - # Create first sample and state. - sample = Turing.Inference.Transition(model, vi) + vi = DynamicPPL.unflatten(vi, Q.q) + sample = Turing.Inference.Transition(model, vi, nothing) state = DynamicNUTSState(ℓ, vi, Q, steps.H.κ, steps.ϵ) return sample, state @@ -99,12 +94,9 @@ function AbstractMCMC.step( steps = DynamicHMC.mcmc_steps(rng, spl.alg.sampler, state.metric, ℓ, state.stepsize) Q, _ = DynamicHMC.mcmc_next_step(steps, state.cache) - # Update the variables. 
- vi = DynamicPPL.unflatten(vi, Q.q) - vi = DynamicPPL.setlogp!!(vi, Q.ℓq) - # Create next sample and state. - sample = Turing.Inference.Transition(model, vi) + vi = DynamicPPL.unflatten(vi, Q.q) + sample = Turing.Inference.Transition(model, vi, nothing) newstate = DynamicNUTSState(ℓ, vi, Q, state.metric, state.stepsize) return sample, newstate diff --git a/ext/TuringOptimExt.jl b/ext/TuringOptimExt.jl index 41a89d3b5c..0f755988ef 100644 --- a/ext/TuringOptimExt.jl +++ b/ext/TuringOptimExt.jl @@ -102,7 +102,7 @@ function Optim.optimize( options::Optim.Options=Optim.Options(); kwargs..., ) - f = Optimisation.OptimLogDensity(model, Optimisation.getlogjoint_without_jacobian) + f = Optimisation.OptimLogDensity(model, DynamicPPL.getlogjoint) init_vals = DynamicPPL.getparams(f.ldf) optimizer = Optim.LBFGS() return _map_optimize(model, init_vals, optimizer, options; kwargs...) @@ -124,7 +124,7 @@ function Optim.optimize( options::Optim.Options=Optim.Options(); kwargs..., ) - f = Optimisation.OptimLogDensity(model, Optimisation.getlogjoint_without_jacobian) + f = Optimisation.OptimLogDensity(model, DynamicPPL.getlogjoint) init_vals = DynamicPPL.getparams(f.ldf) return _map_optimize(model, init_vals, optimizer, options; kwargs...) end @@ -140,7 +140,7 @@ function Optim.optimize( end function _map_optimize(model::DynamicPPL.Model, args...; kwargs...) - f = Optimisation.OptimLogDensity(model, Optimisation.getlogjoint_without_jacobian) + f = Optimisation.OptimLogDensity(model, DynamicPPL.getlogjoint) return _optimize(f, args...; kwargs...) 
end diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index 299670a0d4..3275072520 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -17,8 +17,8 @@ using DynamicPPL: setindex!!, push!!, setlogp!!, - getlogp, getlogjoint, + getlogjoint_internal, VarName, getsym, getdist, @@ -123,71 +123,94 @@ end ###################### # Default Transition # ###################### -# Default -getstats(t) = nothing +getstats(::Any) = NamedTuple() +# TODO(penelopeysm): Remove this abstract type by converting SGLDTransition, +# SMCTransition, and PGTransition to Turing.Inference.Transition instead. abstract type AbstractTransition end -struct Transition{T,F<:AbstractFloat,S<:Union{NamedTuple,Nothing}} <: AbstractTransition +struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition θ::T - lp::F # TODO: merge `lp` with `stat` - stat::S -end - -Transition(θ, lp) = Transition(θ, lp, nothing) -function Transition(model::DynamicPPL.Model, vi::AbstractVarInfo, t) - θ = getparams(model, vi) - lp = getlogjoint(vi) - return Transition(θ, lp, getstats(t)) -end + logprior::F + loglikelihood::F + stat::N + + """ + Transition(model::Model, vi::AbstractVarInfo, sampler_transition) + + Construct a new `Turing.Inference.Transition` object using the outputs of a + sampler step. + + Here, `vi` represents a VarInfo _for which the appropriate parameters have + already been set_. However, the accumulators (e.g. logp) may in general + have junk contents. The role of this method is to re-evaluate `model` and + thus set the accumulators to the correct values. + + `sampler_transition` is the transition object returned by the sampler + itself and is only used to extract statistics of interest. 
+ """ + function Transition(model::DynamicPPL.Model, vi::AbstractVarInfo, sampler_transition) + vi = DynamicPPL.setaccs!!( + vi, + ( + DynamicPPL.ValuesAsInModelAccumulator(true), + DynamicPPL.LogPriorAccumulator(), + DynamicPPL.LogLikelihoodAccumulator(), + ), + ) + _, vi = DynamicPPL.evaluate!!(model, vi) + + # Extract all the information we need + vals_as_in_model = DynamicPPL.getacc(vi, Val(:ValuesAsInModel)).values + logprior = DynamicPPL.getlogprior(vi) + loglikelihood = DynamicPPL.getloglikelihood(vi) + + # Get additional statistics + stats = getstats(sampler_transition) + return new{typeof(vals_as_in_model),typeof(logprior),typeof(stats)}( + vals_as_in_model, logprior, loglikelihood, stats + ) + end -function metadata(t::Transition) - stat = t.stat - if stat === nothing - return (lp=t.lp,) - else - return merge((lp=t.lp,), stat) + function Transition( + model::DynamicPPL.Model, + untyped_vi::DynamicPPL.VarInfo{<:DynamicPPL.Metadata}, + sampler_transition, + ) + # Re-evaluating the model is unconscionably slow for untyped VarInfo. It's + # much faster to convert it to a typed varinfo first, hence this method. + # https://github.com/TuringLang/Turing.jl/issues/2604 + return Transition(model, DynamicPPL.typed_varinfo(untyped_vi), sampler_transition) end end -DynamicPPL.getlogjoint(t::Transition) = t.lp - -# Metadata of VarInfo object -metadata(vi::AbstractVarInfo) = (lp=getlogjoint(vi),) +function getstats_with_lp(t::Transition) + return merge( + t.stat, + ( + lp=t.logprior + t.loglikelihood, + logprior=t.logprior, + loglikelihood=t.loglikelihood, + ), + ) +end +function getstats_with_lp(vi::AbstractVarInfo) + return ( + lp=DynamicPPL.getlogjoint(vi), + logprior=DynamicPPL.getlogprior(vi), + loglikelihood=DynamicPPL.getloglikelihood(vi), + ) +end ########################## # Chain making utilities # ########################## -""" - getparams(model, t) - -Return a named tuple of parameters. 
-""" -getparams(model, t) = t.θ -function getparams(model::DynamicPPL.Model, vi::DynamicPPL.VarInfo) - # NOTE: In the past, `invlink(vi, model)` + `values_as(vi, OrderedDict)` was used. - # Unfortunately, using `invlink` can cause issues in scenarios where the constraints - # of the parameters change depending on the realizations. Hence we have to use - # `values_as_in_model`, which re-runs the model and extracts the parameters - # as they are seen in the model, i.e. in the constrained space. Moreover, - # this means that the code below will work both of linked and invlinked `vi`. - # Ref: https://github.com/TuringLang/Turing.jl/issues/2195 - # NOTE: We need to `deepcopy` here to avoid modifying the original `vi`. - return DynamicPPL.values_as_in_model(model, true, deepcopy(vi)) -end -function getparams( - model::DynamicPPL.Model, untyped_vi::DynamicPPL.VarInfo{<:DynamicPPL.Metadata} -) - # values_as_in_model is unconscionably slow for untyped VarInfo. It's - # much faster to convert it to a typed varinfo before calling getparams. - # https://github.com/TuringLang/Turing.jl/issues/2604 - return getparams(model, DynamicPPL.typed_varinfo(untyped_vi)) +getparams(::DynamicPPL.Model, t::AbstractTransition) = t.θ +function getparams(model::DynamicPPL.Model, vi::AbstractVarInfo) + t = Transition(model, vi, nothing) + return getparams(model, t) end -function getparams(::DynamicPPL.Model, ::DynamicPPL.VarInfo{NamedTuple{(),Tuple{}}}) - return Dict{VarName,Any}() -end - function _params_to_array(model::DynamicPPL.Model, ts::Vector) names_set = OrderedSet{VarName}() # Extract the parameter names and values from each transition. 
@@ -203,7 +226,6 @@ function _params_to_array(model::DynamicPPL.Model, ts::Vector) iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)) mapreduce(collect, vcat, iters) end - nms = map(first, nms_and_vs) vs = map(last, nms_and_vs) for nm in nms @@ -218,14 +240,9 @@ function _params_to_array(model::DynamicPPL.Model, ts::Vector) return names, vals end -function get_transition_extras(ts::AbstractVector{<:VarInfo}) - valmat = reshape([getlogjoint(t) for t in ts], :, 1) - return [:lp], valmat -end - function get_transition_extras(ts::AbstractVector) - # Extract all metadata. - extra_data = map(metadata, ts) + # Extract stats + log probabilities from each transition or VarInfo + extra_data = map(getstats_with_lp, ts) return names_values(extra_data) end @@ -334,7 +351,7 @@ function AbstractMCMC.bundle_samples( vals = map(values(sym_to_vns)) do vns map(Base.Fix1(getindex, params), vns) end - return merge(NamedTuple(zip(keys(sym_to_vns), vals)), metadata(t)) + return merge(NamedTuple(zip(keys(sym_to_vns), vals)), getstats_with_lp(t)) end end @@ -396,84 +413,4 @@ function DynamicPPL.get_matching_type( return Array{T,N} end -############## -# Utilities # -############## - -""" - - transitions_from_chain( - [rng::AbstractRNG,] - model::Model, - chain::MCMCChains.Chains; - sampler = DynamicPPL.SampleFromPrior() - ) - -Execute `model` conditioned on each sample in `chain`, and return resulting transitions. - -The returned transitions are represented in a `Vector{<:Turing.Inference.Transition}`. - -# Details - -In a bit more detail, the process is as follows: -1. For every `sample` in `chain` - 1. For every `variable` in `sample` - 1. Set `variable` in `model` to its value in `sample` - 2. Execute `model` with variables fixed as above, sampling variables NOT present - in `chain` using `SampleFromPrior` - 3. 
Return sampled variables and log-joint - -# Example -```julia-repl -julia> using Turing - -julia> @model function demo() - m ~ Normal(0, 1) - x ~ Normal(m, 1) - end; - -julia> m = demo(); - -julia> chain = Chains(randn(2, 1, 1), ["m"]); # 2 samples of `m` - -julia> transitions = Turing.Inference.transitions_from_chain(m, chain); - -julia> [Turing.Inference.getlogjoint(t) for t in transitions] # extract the logjoints -2-element Array{Float64,1}: - -3.6294991938628374 - -2.5697948166987845 - -julia> [first(t.θ.x) for t in transitions] # extract samples for `x` -2-element Array{Array{Float64,1},1}: - [-2.0844148956440796] - [-1.704630494695469] -``` -""" -function transitions_from_chain( - model::DynamicPPL.Model, chain::MCMCChains.Chains; kwargs... -) - return transitions_from_chain(Random.default_rng(), model, chain; kwargs...) -end - -function transitions_from_chain( - rng::Random.AbstractRNG, - model::DynamicPPL.Model, - chain::MCMCChains.Chains; - sampler=DynamicPPL.SampleFromPrior(), -) - vi = Turing.VarInfo(model) - - iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) - transitions = map(iters) do (sample_idx, chain_idx) - # Set variables present in `chain` and mark those NOT present in chain to be resampled. - DynamicPPL.setval_and_resample!(vi, chain, sample_idx, chain_idx) - model(rng, vi, sampler) - - # Convert `VarInfo` into `NamedTuple` and save. - Transition(model, vi) - end - - return transitions -end - end # module diff --git a/src/mcmc/emcee.jl b/src/mcmc/emcee.jl index cdbaf9fd02..076e61d7e8 100644 --- a/src/mcmc/emcee.jl +++ b/src/mcmc/emcee.jl @@ -65,14 +65,14 @@ function AbstractMCMC.step( end # Compute initial transition and states. - transition = map(Base.Fix1(Transition, model), vis) + transition = [Transition(model, vi, nothing) for vi in vis] # TODO: Make compatible with immutable `AbstractVarInfo`. 
state = EmceeState( vis[1], map(vis) do vi vi = DynamicPPL.link!!(vi, model) - AMH.Transition(vi[:], DynamicPPL.getlogjoint(vi), false) + AMH.Transition(vi[:], DynamicPPL.getlogjoint_internal(vi), false) end, ) @@ -87,18 +87,17 @@ function AbstractMCMC.step( densitymodel = AMH.DensityModel( Base.Fix1( LogDensityProblems.logdensity, - DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint, vi), + DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint_internal, vi), ), ) # Compute the next states. - states = last(AbstractMCMC.step(rng, densitymodel, spl.alg.ensemble, state.states)) + t, states = AbstractMCMC.step(rng, densitymodel, spl.alg.ensemble, state.states) # Compute the next transition and state. transition = map(states) do _state vi = DynamicPPL.unflatten(vi, _state.params) - t = Transition(getparams(model, vi), _state.lp) - return t + return Transition(model, vi, t) end newstate = EmceeState(vi, states) diff --git a/src/mcmc/ess.jl b/src/mcmc/ess.jl index feb737a30d..bbf900657c 100644 --- a/src/mcmc/ess.jl +++ b/src/mcmc/ess.jl @@ -31,7 +31,7 @@ function DynamicPPL.initialstep( EllipticalSliceSampling.isgaussian(typeof(dist)) || error("ESS only supports Gaussian prior distributions") end - return Transition(model, vi), vi + return Transition(model, vi, nothing), vi end function AbstractMCMC.step( @@ -56,7 +56,7 @@ function AbstractMCMC.step( vi = DynamicPPL.unflatten(vi, sample) vi = DynamicPPL.setloglikelihood!!(vi, state.loglikelihood) - return Transition(model, vi), vi + return Transition(model, vi, nothing), vi end # Prior distribution of considered random variable diff --git a/src/mcmc/external_sampler.jl b/src/mcmc/external_sampler.jl index 992a2fb2db..200c59293a 100644 --- a/src/mcmc/external_sampler.jl +++ b/src/mcmc/external_sampler.jl @@ -22,8 +22,6 @@ There are a few more optional functions which you can implement to improve the i - `Turing.Inference.isgibbscomponent(::MySampler)`: If you want your sampler to function as a component in 
Turing's Gibbs sampler, you should make this evaluate to `true`. - `Turing.Inference.requires_unconstrained_space(::MySampler)`: If your sampler requires unconstrained space, you should return `true`. This tells Turing to perform linking on the VarInfo before evaluation, and ensures that the parameter values passed to your sampler will always be in unconstrained (Euclidean) space. - -- `Turing.Inference.getlogp_external(external_transition, external_state)`: Tell Turing how to extract the log probability density associated with this transition (and state). If you do not specify these, Turing will simply re-evaluate the model with the parameters obtained from `getparams`, which can be inefficient. It is therefore recommended to store the log probability density in either the transition or the state (or both) and override this method. """ struct ExternalSampler{S<:AbstractSampler,AD<:ADTypes.AbstractADType,Unconstrained} <: InferenceAlgorithm @@ -85,27 +83,21 @@ function externalsampler( return ExternalSampler(sampler, adtype, Val(unconstrained)) end -""" - getlogp_external(external_transition, external_state) - -Get the log probability density associated with the external sampler's -transition and state. Returns `missing` by default; in this case, an extra -model evaluation will be needed to calculate the correct log density. -""" -getlogp_external(::Any, ::Any) = missing -getlogp_external(mh::AdvancedMH.Transition, ::AdvancedMH.Transition) = mh.lp -getlogp_external(hmc::AdvancedHMC.Transition, ::AdvancedHMC.HMCState) = hmc.stat.log_density - -struct TuringState{S,V1<:AbstractVarInfo,M,V} +# TODO(penelopeysm): Can't we clean this up somehow? +struct TuringState{S,V1,M,V} state::S - # Note that this varinfo has the correct parameters and logp obtained from - # the state, whereas `ldf.varinfo` will in general have junk inside it. 
+ # Note that this varinfo must have the correct parameters set; but logp + # does not matter as it will be re-evaluated varinfo::V1 + # Note that in general the VarInfo inside this LogDensityFunction will have + # junk parameters and logp. It only exists to provide structure ldf::DynamicPPL.LogDensityFunction{M,V} end -varinfo(state::TuringState) = state.varinfo -varinfo(state::AbstractVarInfo) = state +# get_varinfo should return something from which the correct parameters can be +# obtained, hence we use state.varinfo rather than state.ldf.varinfo +get_varinfo(state::TuringState) = state.varinfo +get_varinfo(state::AbstractVarInfo) = state getparams(::DynamicPPL.Model, transition::AdvancedHMC.Transition) = transition.z.θ function getparams(model::DynamicPPL.Model, state::AdvancedHMC.HMCState) @@ -115,27 +107,6 @@ getstats(transition::AdvancedHMC.Transition) = transition.stat getparams(::DynamicPPL.Model, transition::AdvancedMH.Transition) = transition.params -function make_updated_varinfo( - f::DynamicPPL.LogDensityFunction, external_transition, external_state -) - # Set the parameters. - new_parameters = getparams(f.model, external_state) - new_varinfo = DynamicPPL.unflatten(f.varinfo, new_parameters) - # Set (or recalculate, if needed) the log density. - new_logp = getlogp_external(external_transition, external_state) - return if ismissing(new_logp) - last(DynamicPPL.evaluate!!(f.model, new_varinfo, f.context)) - else - # TODO(DPPL0.37/penelopeysm) This is obviously wrong. Note that we - # have the same problem here as in HMC in that the sampler doesn't - # tell us about how logp is broken down into prior and likelihood. - # We should probably just re-evaluate unconditionally. A bit - # unfortunate. - DynamicPPL.setlogprior!!(new_varinfo, 0.0) - DynamicPPL.setloglikelihood!!(new_varinfo, new_logp) - end -end - # TODO: Do we also support `resume`, etc? 
function AbstractMCMC.step( rng::Random.AbstractRNG, @@ -163,7 +134,7 @@ function AbstractMCMC.step( # Construct LogDensityFunction f = DynamicPPL.LogDensityFunction( - model, DynamicPPL.getlogjoint, varinfo; adtype=alg.adtype + model, DynamicPPL.getlogjoint_internal, varinfo; adtype=alg.adtype ) # Then just call `AbstractMCMC.step` with the right arguments. @@ -182,13 +153,10 @@ function AbstractMCMC.step( ) end - # Get the parameters and log density, and set them in the varinfo. - new_varinfo = make_updated_varinfo(f, transition_inner, state_inner) - - # Update the `state` + new_parameters = getparams(f.model, state_inner) + new_vi = DynamicPPL.unflatten(f.varinfo, new_parameters) return ( - Transition(f.model, new_varinfo, transition_inner), - TuringState(state_inner, new_varinfo, f), + Transition(f.model, new_vi, transition_inner), TuringState(state_inner, new_vi, f) ) end @@ -207,12 +175,9 @@ function AbstractMCMC.step( rng, AbstractMCMC.LogDensityModel(f), sampler, state.state; kwargs... ) - # Get the parameters and log density, and set them in the varinfo. - new_varinfo = make_updated_varinfo(f, transition_inner, state_inner) - - # Update the `state` + new_parameters = getparams(f.model, state_inner) + new_vi = DynamicPPL.unflatten(f.varinfo, new_parameters) return ( - Transition(f.model, new_varinfo, transition_inner), - TuringState(state_inner, new_varinfo, f), + Transition(f.model, new_vi, transition_inner), TuringState(state_inner, new_vi, f) ) end diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 58db29789d..6927487679 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -343,7 +343,7 @@ struct GibbsState{V<:DynamicPPL.AbstractVarInfo,S} states::S end -varinfo(state::GibbsState) = state.vi +get_varinfo(state::GibbsState) = state.vi """ Initialise a VarInfo for the Gibbs sampler. 
@@ -390,7 +390,7 @@ function AbstractMCMC.step( initial_params=initial_params, kwargs..., ) - return Transition(model, vi), GibbsState(vi, states) + return Transition(model, vi, nothing), GibbsState(vi, states) end function AbstractMCMC.step_warmup( @@ -415,7 +415,7 @@ function AbstractMCMC.step_warmup( initial_params=initial_params, kwargs..., ) - return Transition(model, vi), GibbsState(vi, states) + return Transition(model, vi, nothing), GibbsState(vi, states) end """ @@ -465,7 +465,7 @@ function gibbs_initialstep_recursive( initial_params=initial_params_local, kwargs..., ) - new_vi_local = varinfo(new_state) + new_vi_local = get_varinfo(new_state) # Merge in any new variables that were introduced during the step, but that # were not in the domain of the current sampler. vi = merge(vi, get_global_varinfo(context)) @@ -493,7 +493,7 @@ function AbstractMCMC.step( state::GibbsState; kwargs..., ) - vi = varinfo(state) + vi = get_varinfo(state) alg = spl.alg varnames = alg.varnames samplers = alg.samplers @@ -503,7 +503,7 @@ function AbstractMCMC.step( vi, states = gibbs_step_recursive( rng, model, AbstractMCMC.step, varnames, samplers, states, vi; kwargs... ) - return Transition(model, vi), GibbsState(vi, states) + return Transition(model, vi, nothing), GibbsState(vi, states) end function AbstractMCMC.step_warmup( @@ -513,7 +513,7 @@ function AbstractMCMC.step_warmup( state::GibbsState; kwargs..., ) - vi = varinfo(state) + vi = get_varinfo(state) alg = spl.alg varnames = alg.varnames samplers = alg.samplers @@ -523,7 +523,7 @@ function AbstractMCMC.step_warmup( vi, states = gibbs_step_recursive( rng, model, AbstractMCMC.step_warmup, varnames, samplers, states, vi; kwargs... 
) - return Transition(model, vi), GibbsState(vi, states) + return Transition(model, vi, nothing), GibbsState(vi, states) end """ @@ -541,14 +541,11 @@ function setparams_varinfo!!(model, ::Sampler, state, params::AbstractVarInfo) end function setparams_varinfo!!( - model::DynamicPPL.Model, - sampler::Sampler{<:MH}, - state::AbstractVarInfo, - params::AbstractVarInfo, + model::DynamicPPL.Model, sampler::Sampler{<:MH}, state::MHState, params::AbstractVarInfo ) - # The state is already a VarInfo, so we can just return `params`, but first we need to - # update its logprob. - return last(DynamicPPL.evaluate!!(model, params)) + # Re-evaluate to update the logprob. + new_vi = last(DynamicPPL.evaluate!!(model, params)) + return MHState(new_vi, DynamicPPL.getlogjoint_internal(new_vi)) end function setparams_varinfo!!( @@ -569,7 +566,7 @@ function setparams_varinfo!!( params::AbstractVarInfo, ) logdensity = DynamicPPL.LogDensityFunction( - model, DynamicPPL.getlogjoint, state.ldf.varinfo; adtype=sampler.alg.adtype + model, DynamicPPL.getlogjoint_internal, state.ldf.varinfo; adtype=sampler.alg.adtype ) new_inner_state = setparams_varinfo!!( AbstractMCMC.LogDensityModel(logdensity), sampler, state.state, params @@ -608,7 +605,7 @@ state for this sampler. This is relevant when multilple samplers are sampling th variables, and one might need it to be linked while the other doesn't. """ function match_linking!!(varinfo_local, prev_state_local, model) - prev_varinfo_local = varinfo(prev_state_local) + prev_varinfo_local = get_varinfo(prev_state_local) was_linked = DynamicPPL.istrans(prev_varinfo_local) is_linked = DynamicPPL.istrans(varinfo_local) if was_linked && !is_linked @@ -690,10 +687,10 @@ function gibbs_step_recursive( # Take a step with the local sampler. 
new_state = last(step_function(rng, conditioned_model, sampler, state; kwargs...)) - new_vi_local = varinfo(new_state) + new_vi_local = get_varinfo(new_state) # Merge the latest values for all the variables in the current sampler. new_global_vi = merge(get_global_varinfo(context), new_vi_local) - new_global_vi = setlogp!!(new_global_vi, getlogp(new_vi_local)) + new_global_vi = DynamicPPL.setlogp!!(new_global_vi, DynamicPPL.getlogp(new_vi_local)) new_states = (new_states..., new_state) return gibbs_step_recursive( diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl index 18733f6a8d..d80502f7e1 100644 --- a/src/mcmc/hmc.jl +++ b/src/mcmc/hmc.jl @@ -25,7 +25,7 @@ end ### Hamiltonian Monte Carlo samplers. ### -varinfo(state::HMCState) = state.vi +get_varinfo(state::HMCState) = state.vi """ HMC(ϵ::Float64, n_leapfrog::Int; adtype::ADTypes.AbstractADType = AutoForwardDiff()) @@ -193,7 +193,7 @@ function DynamicPPL.initialstep( metricT = getmetricT(spl.alg) metric = metricT(length(theta)) ldf = DynamicPPL.LogDensityFunction( - model, DynamicPPL.getlogjoint, vi; adtype=spl.alg.adtype + model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.alg.adtype ) lp_func = Base.Fix1(LogDensityProblems.logdensity, ldf) lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ldf) @@ -208,9 +208,6 @@ function DynamicPPL.initialstep( end theta = vi[:] - # Cache current log density. We will reuse this if the transition is rejected. - logp_old = DynamicPPL.getlogp(vi) - # Find good eps if not provided one if iszero(spl.alg.ϵ) ϵ = AHMC.find_good_stepsize(rng, hamiltonian, theta) @@ -234,22 +231,13 @@ function DynamicPPL.initialstep( ) end - # Update VarInfo based on acceptance - if t.stat.is_accept - vi = DynamicPPL.unflatten(vi, t.z.θ) - # Re-evaluate to calculate log probability density. - # TODO(penelopeysm): This seems a little bit wasteful. 
Unfortunately, - # even though `t.stat.log_density` contains some kind of logp, this - # doesn't track prior and likelihood separately but rather a single - # log-joint (and in linked space), so which we have no way to decompose - # this back into prior and likelihood. I don't immediately see how to - # solve this without re-evaluating the model. - _, vi = DynamicPPL.evaluate!!(model, vi) + # Update VarInfo parameters based on acceptance + new_params = if t.stat.is_accept + t.z.θ else - # Reset VarInfo back to its original state. - vi = DynamicPPL.unflatten(vi, theta) - vi = DynamicPPL.setlogp!!(vi, logp_old) + theta end + vi = DynamicPPL.unflatten(vi, new_params) transition = Transition(model, vi, t) state = HMCState(vi, 1, kernel, hamiltonian, t.z, adaptor) @@ -293,9 +281,6 @@ function AbstractMCMC.step( vi = state.vi if t.stat.is_accept vi = DynamicPPL.unflatten(vi, t.z.θ) - # Re-evaluate to calculate log probability density. - # TODO(penelopeysm): This seems a little bit wasteful. See note above. - _, vi = DynamicPPL.evaluate!!(model, vi) end # Compute next transition and state. @@ -308,7 +293,7 @@ end function get_hamiltonian(model, spl, vi, state, n) metric = gen_metric(n, spl, state) ldf = DynamicPPL.LogDensityFunction( - model, DynamicPPL.getlogjoint, vi; adtype=spl.alg.adtype + model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.alg.adtype ) lp_func = Base.Fix1(LogDensityProblems.logdensity, ldf) lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ldf) diff --git a/src/mcmc/is.jl b/src/mcmc/is.jl index 5f2f1627fb..319e424fcb 100644 --- a/src/mcmc/is.jl +++ b/src/mcmc/is.jl @@ -31,25 +31,19 @@ DynamicPPL.initialsampler(sampler::Sampler{<:IS}) = sampler function DynamicPPL.initialstep( rng::AbstractRNG, model::Model, spl::Sampler{<:IS}, vi::AbstractVarInfo; kwargs... ) - # Need to manually construct the Transition here because we only - # want to use the likelihood. 
- xs = Turing.Inference.getparams(model, vi) - lp = DynamicPPL.getloglikelihood(vi) - return Transition(xs, lp, nothing), nothing + return Transition(model, vi, nothing), nothing end function AbstractMCMC.step( rng::Random.AbstractRNG, model::Model, spl::Sampler{<:IS}, ::Nothing; kwargs... ) vi = VarInfo(rng, model, spl) - xs = Turing.Inference.getparams(model, vi) - lp = DynamicPPL.getloglikelihood(vi) - return Transition(xs, lp, nothing), nothing + return Transition(model, vi, nothing), nothing end # Calculate evidence. function getlogevidence(samples::Vector{<:Transition}, ::Sampler{<:IS}, state) - return logsumexp(map(x -> x.lp, samples)) - log(length(samples)) + return logsumexp(map(x -> x.loglikelihood, samples)) - log(length(samples)) end function DynamicPPL.assume(rng, ::Sampler{<:IS}, dist::Distribution, vn::VarName, vi) diff --git a/src/mcmc/mh.jl b/src/mcmc/mh.jl index e0258fda45..eb5b3aa3ee 100644 --- a/src/mcmc/mh.jl +++ b/src/mcmc/mh.jl @@ -153,6 +153,27 @@ function MH(model::Model; proposal_type=AMH.StaticProposal) return AMH.MetropolisHastings(priors) end +""" + MHState(varinfo::AbstractVarInfo, logjoint_internal::Real) + +State for Metropolis-Hastings sampling. + +`varinfo` must have the correct parameters set inside it, but its other fields +(e.g. accumulators, which track logp) can in general be missing or incorrect. + +`logjoint_internal` is the log joint probability of the model, evaluated using +the parameters and linking status of `varinfo`. It should be equal to +`DynamicPPL.getlogjoint_internal(varinfo)`. This information is returned by the +MH sampler so we store this here to avoid re-evaluating the model +unnecessarily. +""" +struct MHState{V<:AbstractVarInfo,L<:Real} + varinfo::V + logjoint_internal::L +end + +get_varinfo(s::MHState) = s.varinfo + ##################### # Utility functions # ##################### @@ -297,14 +318,15 @@ end # Make a proposal if we don't have a covariance proposal matrix (the default). 
function propose!!( - rng::AbstractRNG, vi::AbstractVarInfo, model::Model, spl::Sampler{<:MH}, proposal + rng::AbstractRNG, prev_state::MHState, model::Model, spl::Sampler{<:MH}, proposal ) + vi = prev_state.varinfo # Retrieve distribution and value NamedTuples. dt, vt = dist_val_tuple(spl, vi) # Create a sampler and the previous transition. mh_sampler = AMH.MetropolisHastings(dt) - prev_trans = AMH.Transition(vt, DynamicPPL.getlogjoint(vi), false) + prev_trans = AMH.Transition(vt, prev_state.logjoint_internal, false) # Make a new transition. spl_model = DynamicPPL.contextualize( @@ -313,35 +335,35 @@ function propose!!( densitymodel = AMH.DensityModel( Base.Fix1( LogDensityProblems.logdensity, - DynamicPPL.LogDensityFunction(spl_model, DynamicPPL.getlogjoint, vi), + DynamicPPL.LogDensityFunction(spl_model, DynamicPPL.getlogjoint_internal, vi), ), ) trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans) - - # TODO: Make this compatible with immutable `VarInfo`. - # Update the values in the VarInfo. - # TODO(DPPL0.37/penelopeysm): This is obviously incorrect. We need to - # re-evaluate the model. + # trans.params isa NamedTuple set_namedtuple!(vi, trans.params) - vi = DynamicPPL.setloglikelihood!!(vi, trans.lp) - return DynamicPPL.setlogprior!!(vi, 0.0) + # Here, `trans.lp` is equal to `getlogjoint_internal(vi)`. We don't know + # how to set this back inside vi (without re-evaluating). However, the next + # MH step will require this information to calculate the acceptance + # probability, so we return it together with vi. + return MHState(vi, trans.lp) end # Make a proposal if we DO have a covariance proposal matrix. function propose!!( rng::AbstractRNG, - vi::AbstractVarInfo, + prev_state::MHState, model::Model, spl::Sampler{<:MH}, proposal::AdvancedMH.RandomWalkProposal, ) + vi = prev_state.varinfo # If this is the case, we can just draw directly from the proposal # matrix. vals = vi[:] # Create a sampler and the previous transition. 
mh_sampler = AMH.MetropolisHastings(spl.alg.proposals) - prev_trans = AMH.Transition(vals, DynamicPPL.getlogjoint(vi), false) + prev_trans = AMH.Transition(vals, prev_state.logjoint_internal, false) # Make a new transition. spl_model = DynamicPPL.contextualize( @@ -350,16 +372,17 @@ function propose!!( densitymodel = AMH.DensityModel( Base.Fix1( LogDensityProblems.logdensity, - DynamicPPL.LogDensityFunction(spl_model, DynamicPPL.getlogjoint, vi), + DynamicPPL.LogDensityFunction(spl_model, DynamicPPL.getlogjoint_internal, vi), ), ) trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans) - - # TODO(DPPL0.37/penelopeysm): This is obviously incorrect. We need to - # re-evaluate the model. + # trans.params isa AbstractVector vi = DynamicPPL.unflatten(vi, trans.params) - vi = DynamicPPL.setloglikelihood!!(vi, trans.lp) - return DynamicPPL.setlogprior!!(vi, 0.0) + # Here, `trans.lp` is equal to `getlogjoint_internal(vi)`. We don't know + # how to set this back inside vi (without re-evaluating). However, the next + # MH step will require this information to calculate the acceptance + # probability, so we return it together with vi. + return MHState(vi, trans.lp) end function DynamicPPL.initialstep( @@ -373,18 +396,18 @@ function DynamicPPL.initialstep( # just link everything before sampling. vi = maybe_link!!(vi, spl, spl.alg.proposals, model) - return Transition(model, vi), vi + return Transition(model, vi, nothing), MHState(vi, DynamicPPL.getlogjoint_internal(vi)) end function AbstractMCMC.step( - rng::AbstractRNG, model::Model, spl::Sampler{<:MH}, vi::AbstractVarInfo; kwargs... + rng::AbstractRNG, model::Model, spl::Sampler{<:MH}, state::MHState; kwargs... ) # Cases: # 1. A covariance proposal matrix # 2. 
A bunch of NamedTuples that specify the proposal space - vi = propose!!(rng, vi, model, spl, spl.alg.proposals) + new_state = propose!!(rng, state, model, spl, spl.alg.proposals) - return Transition(model, vi), vi + return Transition(model, new_state.varinfo, nothing), new_state end #### diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index f7e40f4326..6959e22ccd 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -146,13 +146,11 @@ end function SMCTransition(model::DynamicPPL.Model, vi::AbstractVarInfo, weight) theta = getparams(model, vi) - lp = DynamicPPL.getlogjoint(vi) + lp = DynamicPPL.getlogjoint_internal(vi) return SMCTransition(theta, lp, weight) end -metadata(t::SMCTransition) = (lp=t.lp, weight=t.weight) - -DynamicPPL.getlogp(t::SMCTransition) = t.lp +getstats_with_lp(t::SMCTransition) = (lp=t.lp, weight=t.weight) struct SMCState{P,F<:AbstractFloat} particles::P @@ -317,17 +315,15 @@ struct PGState rng::Random.AbstractRNG end -varinfo(state::PGState) = state.vi +get_varinfo(state::PGState) = state.vi function PGTransition(model::DynamicPPL.Model, vi::AbstractVarInfo, logevidence) theta = getparams(model, vi) - lp = DynamicPPL.getlogjoint(vi) + lp = DynamicPPL.getlogjoint_internal(vi) return PGTransition(theta, lp, logevidence) end -metadata(t::PGTransition) = (lp=t.lp, logevidence=t.logevidence) - -DynamicPPL.getlogp(t::PGTransition) = t.lp +getstats_with_lp(t::PGTransition) = (lp=t.lp, logevidence=t.logevidence) function getlogevidence(samples, sampler::Sampler{<:PG}, state::PGState) return mean(x.logevidence for x in samples) diff --git a/src/mcmc/prior.jl b/src/mcmc/prior.jl index eadeaceb38..6d7463c2f9 100644 --- a/src/mcmc/prior.jl +++ b/src/mcmc/prior.jl @@ -17,11 +17,7 @@ function AbstractMCMC.step( model, DynamicPPL.SamplingContext(rng, DynamicPPL.SampleFromPrior(), model.context) ) _, vi = DynamicPPL.evaluate!!(sampling_model, VarInfo()) - # Need to manually construct the Transition here because we only - 
# want to use the prior probability. - xs = Turing.Inference.getparams(model, vi) - lp = DynamicPPL.getlogprior(vi) - return Transition(xs, lp, nothing), nothing + return Transition(model, vi, nothing), nothing end DynamicPPL.default_chain_type(sampler::Prior) = MCMCChains.Chains diff --git a/src/mcmc/sghmc.jl b/src/mcmc/sghmc.jl index 6eeddfefa0..5ca351643e 100644 --- a/src/mcmc/sghmc.jl +++ b/src/mcmc/sghmc.jl @@ -58,19 +58,15 @@ function DynamicPPL.initialstep( vi::AbstractVarInfo; kwargs..., ) - # Transform the samples to unconstrained space and compute the joint log probability. + # Transform the samples to unconstrained space. if !DynamicPPL.islinked(vi) vi = DynamicPPL.link!!(vi, model) - vi = last(DynamicPPL.evaluate!!(model, vi)) end # Compute initial sample and state. - sample = Transition(model, vi) + sample = Transition(model, vi, nothing) ℓ = DynamicPPL.LogDensityFunction( - model, - vi, - DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext()); - adtype=spl.alg.adtype, + model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.alg.adtype ) state = SGHMCState(ℓ, vi, zero(vi[:])) @@ -98,12 +94,11 @@ function AbstractMCMC.step( α = spl.alg.momentum_decay newv = (1 - α) .* v .+ η .* grad .+ sqrt(2 * η * α) .* randn(rng, eltype(v), length(v)) - # Save new variables and recompute log density. + # Save new variables. vi = DynamicPPL.unflatten(vi, θ) - vi = last(DynamicPPL.evaluate!!(model, vi)) # Compute next sample and state. 
- sample = Transition(model, vi) + sample = Transition(model, vi, nothing) newstate = SGHMCState(ℓ, vi, newv) return sample, newstate @@ -200,13 +195,11 @@ end function SGLDTransition(model::DynamicPPL.Model, vi::AbstractVarInfo, stepsize) theta = getparams(model, vi) - lp = DynamicPPL.getlogjoint(vi) + lp = DynamicPPL.getlogjoint_internal(vi) return SGLDTransition(theta, lp, stepsize) end -metadata(t::SGLDTransition) = (lp=t.lp, SGLD_stepsize=t.stepsize) - -DynamicPPL.getlogp(t::SGLDTransition) = t.lp +getstats_with_lp(t::SGLDTransition) = (lp=t.lp, SGLD_stepsize=t.stepsize) struct SGLDState{L,V<:AbstractVarInfo} logdensity::L @@ -221,19 +214,15 @@ function DynamicPPL.initialstep( vi::AbstractVarInfo; kwargs..., ) - # Transform the samples to unconstrained space and compute the joint log probability. + # Transform the samples to unconstrained space. if !DynamicPPL.islinked(vi) vi = DynamicPPL.link!!(vi, model) - vi = last(DynamicPPL.evaluate!!(model, vi)) end # Create first sample and state. sample = SGLDTransition(model, vi, zero(spl.alg.stepsize(0))) ℓ = DynamicPPL.LogDensityFunction( - model, - vi, - DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext()); - adtype=spl.alg.adtype, + model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.alg.adtype ) state = SGLDState(ℓ, vi, 1) @@ -252,9 +241,8 @@ function AbstractMCMC.step( stepsize = spl.alg.stepsize(step) θ .+= (stepsize / 2) .* grad .+ sqrt(stepsize) .* randn(rng, eltype(θ), length(θ)) - # Save new variables and recompute log density. + # Save new variables. vi = DynamicPPL.unflatten(vi, θ) - vi = last(DynamicPPL.evaluate!!(model, vi)) # Compute next sample and state. sample = SGLDTransition(model, vi, stepsize) diff --git a/src/optimisation/Optimisation.jl b/src/optimisation/Optimisation.jl index 29ea067263..19c52c381b 100644 --- a/src/optimisation/Optimisation.jl +++ b/src/optimisation/Optimisation.jl @@ -43,113 +43,6 @@ Concrete type for maximum a posteriori estimation. 
Only used for the Optim.jl in """ struct MAP <: ModeEstimator end -# Most of these functions for LogPriorWithoutJacobianAccumulator are copied from -# LogPriorAccumulator. The only one that is different is the accumulate_assume!! one. -""" - LogPriorWithoutJacobianAccumulator{T} <: DynamicPPL.AbstractAccumulator - -Exactly like DynamicPPL.LogPriorAccumulator, but does not include the log determinant of the -Jacobian of any variable transformations. - -Used for MAP optimisation. -""" -struct LogPriorWithoutJacobianAccumulator{T} <: DynamicPPL.AbstractAccumulator - logp::T -end - -""" - LogPriorWithoutJacobianAccumulator{T}() - -Create a new `LogPriorWithoutJacobianAccumulator` accumulator with the log prior initialized to zero. -""" -LogPriorWithoutJacobianAccumulator{T}() where {T<:Real} = - LogPriorWithoutJacobianAccumulator(zero(T)) -function LogPriorWithoutJacobianAccumulator() - return LogPriorWithoutJacobianAccumulator{DynamicPPL.LogProbType}() -end - -function Base.show(io::IO, acc::LogPriorWithoutJacobianAccumulator) - return print(io, "LogPriorWithoutJacobianAccumulator($(repr(acc.logp)))") -end - -function DynamicPPL.accumulator_name(::Type{<:LogPriorWithoutJacobianAccumulator}) - return :LogPriorWithoutJacobian -end - -Base.copy(acc::LogPriorWithoutJacobianAccumulator) = acc - -function DynamicPPL.split(::LogPriorWithoutJacobianAccumulator{T}) where {T} - return LogPriorWithoutJacobianAccumulator(zero(T)) -end - -function DynamicPPL.combine( - acc::LogPriorWithoutJacobianAccumulator, acc2::LogPriorWithoutJacobianAccumulator -) - return LogPriorWithoutJacobianAccumulator(acc.logp + acc2.logp) -end - -function Base.:+( - acc1::LogPriorWithoutJacobianAccumulator, acc2::LogPriorWithoutJacobianAccumulator -) - return LogPriorWithoutJacobianAccumulator(acc1.logp + acc2.logp) -end - -function Base.zero(acc::LogPriorWithoutJacobianAccumulator) - return LogPriorWithoutJacobianAccumulator(zero(acc.logp)) -end - -function DynamicPPL.accumulate_assume!!( - 
acc::LogPriorWithoutJacobianAccumulator, val, logjac, vn, right -) - return acc + LogPriorWithoutJacobianAccumulator(Distributions.logpdf(right, val)) -end -function DynamicPPL.accumulate_observe!!( - acc::LogPriorWithoutJacobianAccumulator, right, left, vn -) - return acc -end - -function Base.convert( - ::Type{LogPriorWithoutJacobianAccumulator{T}}, acc::LogPriorWithoutJacobianAccumulator -) where {T} - return LogPriorWithoutJacobianAccumulator(convert(T, acc.logp)) -end - -function DynamicPPL.convert_eltype( - ::Type{T}, acc::LogPriorWithoutJacobianAccumulator -) where {T} - return LogPriorWithoutJacobianAccumulator(convert(T, acc.logp)) -end - -function getlogprior_without_jacobian(vi::DynamicPPL.AbstractVarInfo) - acc = DynamicPPL.getacc(vi, Val(:LogPriorWithoutJacobian)) - return acc.logp -end - -function getlogjoint_without_jacobian(vi::DynamicPPL.AbstractVarInfo) - return getlogprior_without_jacobian(vi) + DynamicPPL.getloglikelihood(vi) -end - -# This is called when constructing a LogDensityFunction, and ensures the VarInfo has the -# right accumulators. -function DynamicPPL.ldf_default_varinfo( - model::DynamicPPL.Model, ::typeof(getlogprior_without_jacobian) -) - vi = DynamicPPL.VarInfo(model) - vi = DynamicPPL.setaccs!!(vi, (LogPriorWithoutJacobianAccumulator(),)) - return vi -end - -function DynamicPPL.ldf_default_varinfo( - model::DynamicPPL.Model, ::typeof(getlogjoint_without_jacobian) -) - vi = DynamicPPL.VarInfo(model) - vi = DynamicPPL.setaccs!!( - vi, (LogPriorWithoutJacobianAccumulator(), DynamicPPL.LogLikelihoodAccumulator()) - ) - return vi -end - """ OptimLogDensity{ M<:DynamicPPL.Model, @@ -628,8 +521,10 @@ function estimate_mode( # Create an OptimLogDensity object that can be used to evaluate the objective function, # i.e. the negative log density. - getlogdensity = - estimator isa MAP ? 
getlogjoint_without_jacobian : DynamicPPL.getloglikelihood + # Note that we use `getlogjoint` rather than `getlogjoint_internal`: this + # is intentional, because even though the VarInfo may be linked, the + # optimisation target should not take the Jacobian term into account. + getlogdensity = estimator isa MAP ? DynamicPPL.getlogjoint : DynamicPPL.getloglikelihood # Set its VarInfo to the initial parameters. # TODO(penelopeysm): Unclear if this is really needed? Any time that logp is calculated diff --git a/test/essential/container.jl b/test/essential/container.jl index 8ebce270d4..124637aab7 100644 --- a/test/essential/container.jl +++ b/test/essential/container.jl @@ -28,14 +28,11 @@ using Turing @test trace.model.ctask.taped_globals.other === trace res = AdvancedPS.advance!(trace, false) - @test DynamicPPL.get_num_produce(trace.model.f.varinfo) == 1 @test res ≈ -log(2) # Catch broken copy, espetially for RNG / VarInfo newtrace = AdvancedPS.fork(trace) res2 = AdvancedPS.advance!(trace) - @test DynamicPPL.get_num_produce(trace.model.f.varinfo) == 2 - @test DynamicPPL.get_num_produce(newtrace.model.f.varinfo) == 1 end @testset "fork" begin diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl index 5d3d265c75..8c26e2a227 100644 --- a/test/mcmc/Inference.jl +++ b/test/mcmc/Inference.jl @@ -5,7 +5,7 @@ using ..NumericalTests: check_gdemo, check_numerical using Distributions: Bernoulli, Beta, InverseGamma, Normal using Distributions: sample import DynamicPPL -using DynamicPPL: Sampler, getlogp +using DynamicPPL: Sampler import ForwardDiff using LinearAlgebra: I import MCMCChains @@ -116,11 +116,9 @@ using Turing @testset "Prior" begin N = 10_000 - # Note that all chains contain 3 values per sample: 2 variables + log probability @testset "Single-threaded vanilla" begin chains = sample(StableRNG(seed), gdemo_d(), Prior(), N) @test chains isa MCMCChains.Chains - @test size(chains) == (N, 3, 1) @test mean(chains, :s) ≈ 3 atol = 0.11 @test mean(chains, :m) ≈ 0 
atol = 0.1 end @@ -128,7 +126,6 @@ using Turing @testset "Multi-threaded" begin chains = sample(StableRNG(seed), gdemo_d(), Prior(), MCMCThreads(), N, 4) @test chains isa MCMCChains.Chains - @test size(chains) == (N, 3, 4) @test mean(chains, :s) ≈ 3 atol = 0.11 @test mean(chains, :m) ≈ 0 atol = 0.1 end @@ -139,8 +136,9 @@ using Turing ) @test chains isa Vector{<:NamedTuple} @test length(chains) == N - @test all(length(x) == 3 for x in chains) @test all(haskey(x, :lp) for x in chains) + @test all(haskey(x, :logprior) for x in chains) + @test all(haskey(x, :loglikelihood) for x in chains) @test mean(x[:s][1] for x in chains) ≈ 3 atol = 0.11 @test mean(x[:m][1] for x in chains) ≈ 0 atol = 0.1 end diff --git a/test/mcmc/ess.jl b/test/mcmc/ess.jl index e918b3a512..1e1be9b45f 100644 --- a/test/mcmc/ess.jl +++ b/test/mcmc/ess.jl @@ -54,7 +54,7 @@ using Turing @testset "gdemo with CSMC + ESS" begin alg = Gibbs(:s => CSMC(15), :m => ESS()) - chain = sample(StableRNG(seed), gdemo(1.5, 2.0), alg, 2000) + chain = sample(StableRNG(seed), gdemo(1.5, 2.0), alg, 3_000) check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) end diff --git a/test/mcmc/external_sampler.jl b/test/mcmc/external_sampler.jl index 6a6aebddb0..5127c628ee 100644 --- a/test/mcmc/external_sampler.jl +++ b/test/mcmc/external_sampler.jl @@ -21,7 +21,7 @@ function initialize_nuts(model::DynamicPPL.Model) # Create a LogDensityFunction f = DynamicPPL.LogDensityFunction( - model, DynamicPPL.getlogjoint, linked_vi; adtype=Turing.DEFAULT_ADTYPE + model, DynamicPPL.getlogjoint_internal, linked_vi; adtype=Turing.DEFAULT_ADTYPE ) # Choose parameter dimensionality and initial parameter value diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index 092cc71f7e..0fd76be3ab 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -295,7 +295,7 @@ end vi::T end - Turing.Inference.varinfo(state::VarInfoState) = state.vi + Turing.Inference.get_varinfo(state::VarInfoState) = state.vi function 
Turing.Inference.setparams_varinfo!!( ::DynamicPPL.Model, ::DynamicPPL.Sampler, @@ -312,8 +312,8 @@ end kwargs..., ) spl.alg.non_warmup_init_count += 1 - return Turing.Inference.Transition(nothing, 0.0), - VarInfoState(DynamicPPL.VarInfo(model)) + vi = DynamicPPL.VarInfo(model) + return (Turing.Inference.Transition(model, vi, nothing), VarInfoState(vi)) end function AbstractMCMC.step_warmup( @@ -323,30 +323,30 @@ end kwargs..., ) spl.alg.warmup_init_count += 1 - return Turing.Inference.Transition(nothing, 0.0), - VarInfoState(DynamicPPL.VarInfo(model)) + vi = DynamicPPL.VarInfo(model) + return (Turing.Inference.Transition(model, vi, nothing), VarInfoState(vi)) end function AbstractMCMC.step( ::Random.AbstractRNG, - ::DynamicPPL.Model, + model::DynamicPPL.Model, spl::DynamicPPL.Sampler{<:WarmupCounter}, s::VarInfoState; kwargs..., ) spl.alg.non_warmup_count += 1 - return Turing.Inference.Transition(nothing, 0.0), s + return Turing.Inference.Transition(model, s.vi, nothing), s end function AbstractMCMC.step_warmup( ::Random.AbstractRNG, - ::DynamicPPL.Model, + model::DynamicPPL.Model, spl::DynamicPPL.Sampler{<:WarmupCounter}, s::VarInfoState; kwargs..., ) spl.alg.warmup_count += 1 - return Turing.Inference.Transition(nothing, 0.0), s + return Turing.Inference.Transition(model, s.vi, nothing), s end @model f() = x ~ Normal() @@ -886,7 +886,9 @@ end function check_logp_correct(sampler) @testset "logp is set correctly" begin @model logp_check() = x ~ Normal() - chn = sample(logp_check(), Gibbs(@varname(x) => sampler), 100) + chn = sample( + logp_check(), Gibbs(@varname(x) => sampler), 100; progress=false + ) @test isapprox(logpdf.(Normal(), chn[:x]), chn[:lp]) end end diff --git a/test/mcmc/is.jl b/test/mcmc/is.jl index 44fbe92014..2811e9c866 100644 --- a/test/mcmc/is.jl +++ b/test/mcmc/is.jl @@ -47,11 +47,11 @@ using Turing Random.seed!(seed) chain = sample(model, alg, n; check_model=false) - sampled = get(chain, [:a, :b, :lp]) + sampled = get(chain, [:a, :b, 
:loglikelihood]) @test vec(sampled.a) == ref.as @test vec(sampled.b) == ref.bs - @test vec(sampled.lp) == ref.logps + @test vec(sampled.loglikelihood) == ref.logps @test chain.logevidence == ref.logevidence end diff --git a/test/optimisation/Optimisation.jl b/test/optimisation/Optimisation.jl index 9909ee149b..6b93e7629e 100644 --- a/test/optimisation/Optimisation.jl +++ b/test/optimisation/Optimisation.jl @@ -41,12 +41,10 @@ using Turing @testset "With ConditionContext" begin m1 = model1(x) m2 = model2() | (x=x,) - # Doesn't matter if we use getlogjoint or getlogjoint_without_jacobian since the + # Doesn't matter if we use getlogjoint or getlogjoint_internal since the # VarInfo isn't linked. - ld1 = Turing.Optimisation.OptimLogDensity( - m1, Turing.Optimisation.getlogjoint_without_jacobian - ) - ld2 = Turing.Optimisation.OptimLogDensity(m2, DynamicPPL.getlogjoint) + ld1 = Turing.Optimisation.OptimLogDensity(m1, DynamicPPL.getlogjoint) + ld2 = Turing.Optimisation.OptimLogDensity(m2, DynamicPPL.getlogjoint_internal) @test ld1(w) == ld2(w) end @@ -54,22 +52,16 @@ using Turing vn = @varname(inner) m1 = prefix(model1(x), vn) m2 = prefix((model2() | (x=x,)), vn) - ld1 = Turing.Optimisation.OptimLogDensity( - m1, Turing.Optimisation.getlogjoint_without_jacobian - ) - ld2 = Turing.Optimisation.OptimLogDensity(m2, DynamicPPL.getlogjoint) + ld1 = Turing.Optimisation.OptimLogDensity(m1, DynamicPPL.getlogjoint) + ld2 = Turing.Optimisation.OptimLogDensity(m2, DynamicPPL.getlogjoint_internal) @test ld1(w) == ld2(w) end @testset "Joint, prior, and likelihood" begin m1 = model1(x) a = [0.3] - ld_joint = Turing.Optimisation.OptimLogDensity( - m1, Turing.Optimisation.getlogjoint_without_jacobian - ) - ld_prior = Turing.Optimisation.OptimLogDensity( - m1, Turing.Optimisation.getlogprior_without_jacobian - ) + ld_joint = Turing.Optimisation.OptimLogDensity(m1, DynamicPPL.getlogjoint) + ld_prior = Turing.Optimisation.OptimLogDensity(m1, DynamicPPL.getlogprior) ld_likelihood = 
Turing.Optimisation.OptimLogDensity( m1, DynamicPPL.getloglikelihood ) From 119c81884589ab934e764090e7da9c3a6f62276c Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 1 Aug 2025 16:49:52 +0100 Subject: [PATCH 41/49] InitContext isn't for 0.37, update comments --- src/mcmc/abstractmcmc.jl | 2 +- src/mcmc/emcee.jl | 2 +- src/mcmc/ess.jl | 2 +- src/mcmc/mh.jl | 4 ++-- src/mcmc/prior.jl | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/mcmc/abstractmcmc.jl b/src/mcmc/abstractmcmc.jl index 4d55d5c698..4522875b4b 100644 --- a/src/mcmc/abstractmcmc.jl +++ b/src/mcmc/abstractmcmc.jl @@ -1,7 +1,7 @@ # TODO: Implement additional checks for certain samplers, e.g. # HMC not supporting discrete parameters. function _check_model(model::DynamicPPL.Model) - # TODO(DPPL0.37/penelopeysm): use InitContext + # TODO(DPPL0.38/penelopeysm): use InitContext spl_model = DynamicPPL.contextualize(model, DynamicPPL.SamplingContext(model.context)) return DynamicPPL.check_model(spl_model, VarInfo(); error_on_failure=true) end diff --git a/src/mcmc/emcee.jl b/src/mcmc/emcee.jl index 076e61d7e8..98ed20b40e 100644 --- a/src/mcmc/emcee.jl +++ b/src/mcmc/emcee.jl @@ -53,7 +53,7 @@ function AbstractMCMC.step( length(initial_params) == n || throw(ArgumentError("initial parameters have to be specified for each walker")) vis = map(vis, initial_params) do vi, init - # TODO(DPPL0.37/penelopeysm) This whole thing can be replaced with init!! + # TODO(DPPL0.38/penelopeysm) This whole thing can be replaced with init!! vi = DynamicPPL.initialize_parameters!!(vi, init, model) # Update log joint probability. diff --git a/src/mcmc/ess.jl b/src/mcmc/ess.jl index bbf900657c..3afd91607c 100644 --- a/src/mcmc/ess.jl +++ b/src/mcmc/ess.jl @@ -84,7 +84,7 @@ EllipticalSliceSampling.isgaussian(::Type{<:ESSPrior}) = true function Base.rand(rng::Random.AbstractRNG, p::ESSPrior) varinfo = p.varinfo # TODO: Surely there's a better way of doing this now that we have `SamplingContext`? 
- # TODO(DPPL0.37/penelopeysm): This can be replaced with `init!!(p.model, + # TODO(DPPL0.38/penelopeysm): This can be replaced with `init!!(p.model, # p.varinfo, PriorInit())` after TuringLang/DynamicPPL.jl#984. The reason # why we had to use the 'del' flag before this was because # SampleFromPrior() wouldn't overwrite existing variables. diff --git a/src/mcmc/mh.jl b/src/mcmc/mh.jl index eb5b3aa3ee..863db559ce 100644 --- a/src/mcmc/mh.jl +++ b/src/mcmc/mh.jl @@ -178,7 +178,7 @@ get_varinfo(s::MHState) = s.varinfo # Utility functions # ##################### -# TODO(DPPL0.37/penelopeysm): This function should no longer be needed +# TODO(DPPL0.38/penelopeysm): This function should no longer be needed # once InitContext is merged. """ set_namedtuple!(vi::VarInfo, nt::NamedTuple) @@ -207,7 +207,7 @@ end # NOTE(penelopeysm): MH does not conform to the usual LogDensityProblems # interface in that it gets evaluated with a NamedTuple. Hence we need this # method just to deal with MH. -# TODO(DPPL0.37/penelopeysm): Check the extent to which this method is actually +# TODO(DPPL0.38/penelopeysm): Check the extent to which this method is actually # needed. If it's still needed, replace this with `init!!(f.model, f.varinfo, # ParamsInit(x))`. Much less hacky than `set_namedtuple!` (hopefully...). # In general, we should much prefer to either (1) conform to the diff --git a/src/mcmc/prior.jl b/src/mcmc/prior.jl index 6d7463c2f9..4bf16e0f93 100644 --- a/src/mcmc/prior.jl +++ b/src/mcmc/prior.jl @@ -12,7 +12,7 @@ function AbstractMCMC.step( state=nothing; kwargs..., ) - # TODO(DPPL0.37/penelopeysm): replace with init!! + # TODO(DPPL0.38/penelopeysm): replace with init!! 
sampling_model = DynamicPPL.contextualize( model, DynamicPPL.SamplingContext(rng, DynamicPPL.SampleFromPrior(), model.context) ) From b41a4b19c725b9050348ffa612a02d888d663f3c Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 1 Aug 2025 17:21:05 +0100 Subject: [PATCH 42/49] Fix merge --- src/mcmc/external_sampler.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/mcmc/external_sampler.jl b/src/mcmc/external_sampler.jl index 0c339fe4ab..af31e0243f 100644 --- a/src/mcmc/external_sampler.jl +++ b/src/mcmc/external_sampler.jl @@ -182,7 +182,10 @@ function AbstractMCMC.step( rng, AbstractMCMC.LogDensityModel(f), sampler, state.state; kwargs... ) - new_parameters = getparams(f.model, state_inner) + # NOTE: This is Turing.Inference.getparams, not AbstractMCMC.getparams (!!!!!) + # The latter uses the state rather than the transition. + # TODO(penelopeysm): Make this use AbstractMCMC.getparams instead + new_parameters = Turing.Inference.getparams(f.model, transition_inner) new_vi = DynamicPPL.unflatten(f.varinfo, new_parameters) return ( Transition(f.model, new_vi, transition_inner), TuringState(state_inner, new_vi, f) From d92fd56cd10382f2ea3699d6e4595922653fcec9 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 5 Aug 2025 17:39:51 +0100 Subject: [PATCH 43/49] Do not re-evaluate model for Prior (#2644) * Allow Prior to skip model re-evaluation * remove unneeded `default_chain_type` method * add a test * add a likelihood term too * why not test correctness while we're at it --- src/mcmc/Inference.jl | 53 +++++++++++++++++++++++++++++++----------- src/mcmc/prior.jl | 15 ++++++++---- test/mcmc/Inference.jl | 23 ++++++++++++++++++ 3 files changed, 74 insertions(+), 17 deletions(-) diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index 86fa683665..07c8311b45 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -136,7 +136,7 @@ struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition stat::N """ - 
Transition(model::Model, vi::AbstractVarInfo, sampler_transition) + Transition(model::Model, vi::AbstractVarInfo, sampler_transition; reevaluate=true) Construct a new `Turing.Inference.Transition` object using the outputs of a sampler step. @@ -148,17 +148,38 @@ struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition `sampler_transition` is the transition object returned by the sampler itself and is only used to extract statistics of interest. + + By default, the model is re-evaluated in order to obtain values of: + - the values of the parameters as per user parameterisation (`vals_as_in_model`) + - the various components of the log joint probability (`logprior`, `loglikelihood`) + that are guaranteed to be correct. + + If you **know** for a fact that the VarInfo `vi` already contains this information, + then you can set `reevaluate=false` to skip the re-evaluation step. + + !!! warning + Note that in general this is unsafe and may lead to wrong results. + + If `reevaluate` is set to `false`, it is the caller's responsibility to ensure that + the `VarInfo` passed in has `ValuesAsInModelAccumulator`, `LogPriorAccumulator`, + and `LogLikelihoodAccumulator` set up with the correct values. Note that the + `ValuesAsInModelAccumulator` must also have `include_colon_eq == true`, i.e. it + must be set up to track `x := y` statements. 
""" - function Transition(model::DynamicPPL.Model, vi::AbstractVarInfo, sampler_transition) - vi = DynamicPPL.setaccs!!( - vi, - ( - DynamicPPL.ValuesAsInModelAccumulator(true), - DynamicPPL.LogPriorAccumulator(), - DynamicPPL.LogLikelihoodAccumulator(), - ), - ) - _, vi = DynamicPPL.evaluate!!(model, vi) + function Transition( + model::DynamicPPL.Model, vi::AbstractVarInfo, sampler_transition; reevaluate=true + ) + if reevaluate + vi = DynamicPPL.setaccs!!( + vi, + ( + DynamicPPL.ValuesAsInModelAccumulator(true), + DynamicPPL.LogPriorAccumulator(), + DynamicPPL.LogLikelihoodAccumulator(), + ), + ) + _, vi = DynamicPPL.evaluate!!(model, vi) + end # Extract all the information we need vals_as_in_model = DynamicPPL.getacc(vi, Val(:ValuesAsInModel)).values @@ -175,12 +196,18 @@ struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition function Transition( model::DynamicPPL.Model, untyped_vi::DynamicPPL.VarInfo{<:DynamicPPL.Metadata}, - sampler_transition, + sampler_transition; + reevaluate=true, ) # Re-evaluating the model is unconscionably slow for untyped VarInfo. It's # much faster to convert it to a typed varinfo first, hence this method. 
# https://github.com/TuringLang/Turing.jl/issues/2604 - return Transition(model, DynamicPPL.typed_varinfo(untyped_vi), sampler_transition) + return Transition( + model, + DynamicPPL.typed_varinfo(untyped_vi), + sampler_transition; + reevaluate=reevaluate, + ) end end diff --git a/src/mcmc/prior.jl b/src/mcmc/prior.jl index 4bf16e0f93..2ead40cedf 100644 --- a/src/mcmc/prior.jl +++ b/src/mcmc/prior.jl @@ -16,8 +16,15 @@ function AbstractMCMC.step( sampling_model = DynamicPPL.contextualize( model, DynamicPPL.SamplingContext(rng, DynamicPPL.SampleFromPrior(), model.context) ) - _, vi = DynamicPPL.evaluate!!(sampling_model, VarInfo()) - return Transition(model, vi, nothing), nothing + vi = VarInfo() + vi = DynamicPPL.setaccs!!( + vi, + ( + DynamicPPL.ValuesAsInModelAccumulator(true), + DynamicPPL.LogPriorAccumulator(), + DynamicPPL.LogLikelihoodAccumulator(), + ), + ) + _, vi = DynamicPPL.evaluate!!(sampling_model, vi) + return Transition(model, vi, nothing; reevaluate=false), nothing end - -DynamicPPL.default_chain_type(sampler::Prior) = MCMCChains.Chains diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl index 8c26e2a227..2cc7e4bc06 100644 --- a/test/mcmc/Inference.jl +++ b/test/mcmc/Inference.jl @@ -142,6 +142,29 @@ using Turing @test mean(x[:s][1] for x in chains) ≈ 3 atol = 0.11 @test mean(x[:m][1] for x in chains) ≈ 0 atol = 0.1 end + + @testset "accumulators are set correctly" begin + # Prior() uses `reevaluate=false` when constructing a + # `Turing.Inference.Transition`, so we had better make sure that it + # does capture colon-eq statements, as we can't rely on the default + # `Transition` constructor to do this for us. + @model function coloneq() + x ~ Normal() + 10.0 ~ Normal(x) + z := 1.0 + return nothing + end + chain = sample(coloneq(), Prior(), N) + @test chain isa MCMCChains.Chains + @test all(x -> x == 1.0, chain[:z]) + # And for the same reason we should also make sure that the logp + # components are correctly calculated. 
+ @test isapprox(chain[:logprior], logpdf.(Normal(), chain[:x])) + @test isapprox(chain[:loglikelihood], logpdf.(Normal.(chain[:x]), 10.0)) + @test isapprox(chain[:lp], chain[:logprior] .+ chain[:loglikelihood]) + # And that the outcome is not influenced by the likelihood + @test mean(chain, :x) ≈ 0.0 atol = 0.1 + end end @testset "chain ordering" begin From 806c82db5e4d2a218d2f3c14afaaeec316d7a987 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 5 Aug 2025 17:40:14 +0100 Subject: [PATCH 44/49] No need to test AD for SamplingContext{<:HMC} (#2645) --- test/ad.jl | 33 ++------------------------------- 1 file changed, 2 insertions(+), 31 deletions(-) diff --git a/test/ad.jl b/test/ad.jl index f53dd98358..dcfe4ef46c 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -237,34 +237,8 @@ end end end -@testset verbose = true "AD / SamplingContext" begin - # AD tests for gradient-based samplers need to be run with SamplingContext - # because samplers can potentially use this to define custom behaviour in - # the tilde-pipeline and thus change the code executed during model - # evaluation. - @testset "adtype=$adtype" for adtype in ADTYPES - @testset "alg=$alg" for alg in [ - HMC(0.1, 10; adtype=adtype), - HMCDA(0.8, 0.75; adtype=adtype), - NUTS(1000, 0.8; adtype=adtype), - SGHMC(; learning_rate=0.02, momentum_decay=0.5, adtype=adtype), - SGLD(; stepsize=PolynomialStepsize(0.25), adtype=adtype), - ] - @info "Testing AD for $alg" - - @testset "model=$(model.f)" for model in DEMO_MODELS - rng = StableRNG(123) - spl_model = DynamicPPL.contextualize( - model, DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(alg)) - ) - @test run_ad(spl_model, adtype; test=true, benchmark=false) isa Any - end - end - end -end - @testset verbose = true "AD / GibbsContext" begin - # Gibbs sampling also needs extra AD testing because the models are + # Gibbs sampling needs some extra AD testing because the models are # executed with GibbsContext and a subsetted varinfo. (see e.g. 
# `gibbs_initialstep_recursive` and `gibbs_step_recursive` in # src/mcmc/gibbs.jl -- the code here mimics what happens in those @@ -283,10 +257,7 @@ end model, varnames, deepcopy(global_vi) ) rng = StableRNG(123) - spl_model = DynamicPPL.contextualize( - model, DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(HMC(0.1, 10))) - ) - @test run_ad(spl_model, adtype; test=true, benchmark=false) isa Any + @test run_ad(model, adtype; test=true, benchmark=false) isa Any end end end From 5743ff77c3acd7a2e692ede10ae98bfeecbd9c5d Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 7 Aug 2025 17:06:18 +0100 Subject: [PATCH 45/49] change breaking -> main --- Project.toml | 2 +- test/Project.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 4fe0a0bf41..05c5f6380b 100644 --- a/Project.toml +++ b/Project.toml @@ -92,4 +92,4 @@ DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb" Optim = "429524aa-4258-5aef-a3af-852621145aeb" [sources] -DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "breaking"} +DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "main"} diff --git a/test/Project.toml b/test/Project.toml index a7d4e75f57..149c7336bd 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -79,4 +79,4 @@ TimerOutputs = "0.5" julia = "1.10" [sources] -DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "breaking"} +DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "main"} From 57e6f9c89559840447f974bc9c1dcecf16499587 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 11 Aug 2025 14:30:15 +0100 Subject: [PATCH 46/49] Remove calls to resetlogp!! & add changelog (#2650) * Remove calls to resetlogp!! 
* Add a changelog for 0.40 * Update HISTORY.md Co-authored-by: Markus Hauru --------- Co-authored-by: Markus Hauru --- HISTORY.md | 53 ++++++++++++++++++++++++++++++++++++++- src/mcmc/particle_mcmc.jl | 3 --- 2 files changed, 52 insertions(+), 4 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 0a673decc8..3848b03127 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,6 +1,57 @@ # 0.40.0 -[...] +## Breaking changes + +**DynamicPPL 0.37** + +Turing.jl v0.40 updates DynamicPPL compatibility to 0.37. +The summary of the changes provided here is intended for end-users of Turing. +If you are a package developer, or would otherwise like to understand these changes in-depth, please see [the DynamicPPL changelog](https://github.com/TuringLang/DynamicPPL.jl/blob/main/HISTORY.md#0370). + + - **`@submodel`** is now completely removed; please use `to_submodel`. + + - **Prior and likelihood calculations** are now completely separated in Turing. Previously, the log-density used to be accumulated in a single field and thus there was no clear way to separate prior and likelihood components. + + + **`@addlogprob! f`**, where `f` is a float, now adds to the likelihood by default. + + You can instead use **`@addlogprob! (; logprior=x, loglikelihood=y)`** to control which log-density component to add to. + + This means that usage of `PriorContext` and `LikelihoodContext` is no longer needed, and these have now been removed. + - The special **`__context__`** variable has been removed. If you still need to access the evaluation context, it is now available as `__model__.context`. + +**Log-density in chains** + +When sampling from a Turing model, the resulting `MCMCChains.Chains` object now contains not only the log-joint (accessible via `chain[:lp]`) but also the log-prior and log-likelihood (`chain[:logprior]` and `chain[:loglikelihood]` respectively). 
+ +These values now correspond to the log density of the sampled variables exactly as per the model definition / user parameterisation and thus will ignore any linking (transformation to unconstrained space). +For example, if the model is `@model f() = x ~ LogNormal()`, `chain[:lp]` would always contain the value of `logpdf(LogNormal(), x)` for each sampled value of `x`. +Previously these values could be incorrect if linking had occurred: some samplers would return `logpdf(Normal(), log(x))` i.e. the log-density with respect to the transformed distribution. + +**Gibbs sampler** + +When using Turing's Gibbs sampler, e.g. `Gibbs(:x => MH(), :y => HMC(0.1, 20))`, the conditioned variables (for example `y` during the MH step, or `x` during the HMC step) are treated as true observations. +Thus the log-density associated with them is added to the likelihood. +Previously these would effectively be added to the prior (in the sense that if `LikelihoodContext` was used they would be ignored). +This is unlikely to affect users but we mention it here to be explicit. +This change only affects the log probabilities as the Gibbs component samplers see them; the resulting chain will include the usual log prior, likelihood, and joint, as described above. + +**Particle Gibbs** + +Previously, only 'true' observations (i.e., `x ~ dist` where `x` is a model argument or conditioned upon) would trigger resampling of particles. +Specifically, there were two cases where resampling would not be triggered: + + - Calls to `@addlogprob!` + - Gibbs-conditioned variables: e.g. `y` in `Gibbs(:x => PG(20), :y => MH())` + +Turing 0.40 changes this such that both of the above cause resampling. +(The second case follows from the changes to the Gibbs sampler, see above.) + +This release also fixes a bug where, if the model ended with one of these statements, their contribution to the particle weight would be ignored, leading to incorrect results. 
+ +## Other changes + + - Sampling using `Prior()` should now be about twice as fast because we now avoid evaluating the model twice on every iteration. + - `Turing.Inference.Transition` now has different fields. + If `t isa Turing.Inference.Transition`, `t.stat` is always a NamedTuple, not `nothing` (if it genuinely has no information then it's an empty NamedTuple). + Furthermore, `t.lp` has now been split up into `t.logprior` and `t.loglikelihood` (see also 'Log-density in chains' section above). # 0.39.9 diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index 6959e22ccd..b500d3a46e 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -211,7 +211,6 @@ function DynamicPPL.initialstep( # Reset the VarInfo. vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator()) set_all_del!(vi) - vi = DynamicPPL.resetlogp!!(vi) vi = DynamicPPL.empty!!(vi) # Create a new set of particles. @@ -339,7 +338,6 @@ function DynamicPPL.initialstep( vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator()) # Reset the VarInfo before new sweep set_all_del!(vi) - vi = DynamicPPL.resetlogp!!(vi) # Create a new set of particles num_particles = spl.alg.nparticles @@ -370,7 +368,6 @@ function AbstractMCMC.step( # Reset the VarInfo before new sweep. vi = state.vi vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator()) - vi = DynamicPPL.resetlogp!!(vi) # Create reference particle for which the samples will be retained. 
unset_all_del!(vi) From bb21e1eeec83fafc5c828e823ee76fbb89a1571c Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 11 Aug 2025 14:34:52 +0100 Subject: [PATCH 47/49] Remove `[sources]` --- Project.toml | 3 --- test/Project.toml | 3 --- 2 files changed, 6 deletions(-) diff --git a/Project.toml b/Project.toml index 05c5f6380b..b0504e3678 100644 --- a/Project.toml +++ b/Project.toml @@ -90,6 +90,3 @@ julia = "1.10.8" [extras] DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb" Optim = "429524aa-4258-5aef-a3af-852621145aeb" - -[sources] -DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "main"} diff --git a/test/Project.toml b/test/Project.toml index 149c7336bd..b10be01404 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -77,6 +77,3 @@ StatsBase = "0.33, 0.34" StatsFuns = "0.9.5, 1" TimerOutputs = "0.5" julia = "1.10" - -[sources] -DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "main"} From 1bc2fbf468f78cdade83ecf1cfe2a561c9077304 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 12 Aug 2025 12:27:23 +0100 Subject: [PATCH 48/49] Unify Turing `Transition`s, fix some tests (#2651) * Unify `Transition` methods * Add tests * Add same test for SGLD/SGHMC * Refactor so that it's nice and organised * Fix failing test on 1.10 * just increase the atol * Make addlogprob test more robust * Remove stray `@show` Co-authored-by: Markus Hauru --------- Co-authored-by: Markus Hauru --- src/mcmc/Inference.jl | 33 ++++++++++----------- src/mcmc/particle_mcmc.jl | 61 +++++++++++++------------------------- src/mcmc/sghmc.jl | 25 +++------------- test/mcmc/gibbs.jl | 6 ++-- test/mcmc/particle_mcmc.jl | 33 ++++++++++++++++++++- test/mcmc/sghmc.jl | 10 +++++++ test/runtests.jl | 1 + test/test_utils/sampler.jl | 27 +++++++++++++++++ 8 files changed, 114 insertions(+), 82 deletions(-) create mode 100644 test/test_utils/sampler.jl diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index 07c8311b45..d6e9afcbbc 100644 
--- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -124,19 +124,16 @@ end # Default Transition # ###################### getstats(::Any) = NamedTuple() +getstats(nt::NamedTuple) = nt -# TODO(penelopeysm): Remove this abstract type by converting SGLDTransition, -# SMCTransition, and PGTransition to Turing.Inference.Transition instead. -abstract type AbstractTransition end - -struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition +struct Transition{T,F<:AbstractFloat,N<:NamedTuple} θ::T logprior::F loglikelihood::F stat::N """ - Transition(model::Model, vi::AbstractVarInfo, sampler_transition; reevaluate=true) + Transition(model::Model, vi::AbstractVarInfo, stats; reevaluate=true) Construct a new `Turing.Inference.Transition` object using the outputs of a sampler step. @@ -146,8 +143,10 @@ struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition have junk contents. The role of this method is to re-evaluate `model` and thus set the accumulators to the correct values. - `sampler_transition` is the transition object returned by the sampler - itself and is only used to extract statistics of interest. + `stats` is any object on which `Turing.Inference.getstats` can be called to + return a NamedTuple of statistics. This could be, for example, the transition + returned by an (unwrapped) external sampler. Or alternatively, it could + simply be a NamedTuple itself (for which `getstats` acts as the identity). By default, the model is re-evaluated in order to obtain values of: - the values of the parameters as per user parameterisation (`vals_as_in_model`) @@ -167,8 +166,11 @@ struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition must be set up to track `x := y` statements. """ function Transition( - model::DynamicPPL.Model, vi::AbstractVarInfo, sampler_transition; reevaluate=true + model::DynamicPPL.Model, vi::AbstractVarInfo, stats; reevaluate=true ) + # Avoid mutating vi as it may be used later e.g. 
when constructing + # sampler states. + vi = deepcopy(vi) if reevaluate vi = DynamicPPL.setaccs!!( vi, @@ -187,7 +189,7 @@ struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition loglikelihood = DynamicPPL.getloglikelihood(vi) # Get additional statistics - stats = getstats(sampler_transition) + stats = getstats(stats) return new{typeof(vals_as_in_model),typeof(logprior),typeof(stats)}( vals_as_in_model, logprior, loglikelihood, stats ) @@ -196,17 +198,14 @@ struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition function Transition( model::DynamicPPL.Model, untyped_vi::DynamicPPL.VarInfo{<:DynamicPPL.Metadata}, - sampler_transition; + stats; reevaluate=true, ) # Re-evaluating the model is unconscionably slow for untyped VarInfo. It's # much faster to convert it to a typed varinfo first, hence this method. # https://github.com/TuringLang/Turing.jl/issues/2604 return Transition( - model, - DynamicPPL.typed_varinfo(untyped_vi), - sampler_transition; - reevaluate=reevaluate, + model, DynamicPPL.typed_varinfo(untyped_vi), stats; reevaluate=reevaluate ) end end @@ -318,7 +317,7 @@ getlogevidence(transitions, sampler, state) = missing # Default MCMCChains.Chains constructor. # This is type piracy (at least for SampleFromPrior). function AbstractMCMC.bundle_samples( - ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}}, + ts::Vector{<:Union{Transition,AbstractVarInfo}}, model::AbstractModel, spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior,RepeatSampler}, state, @@ -381,7 +380,7 @@ end # This is type piracy (for SampleFromPrior). 
function AbstractMCMC.bundle_samples( - ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}}, + ts::Vector{<:Union{Transition,AbstractVarInfo}}, model::AbstractModel, spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior,RepeatSampler}, state, diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index b500d3a46e..ab2add975c 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -135,23 +135,6 @@ function SMC(threshold::Real) return SMC(AdvancedPS.resample_systematic, threshold) end -struct SMCTransition{T,F<:AbstractFloat} <: AbstractTransition - "The parameters for any given sample." - θ::T - "The joint log probability of the sample (NOTE: does not work, always set to zero)." - lp::F - "The weight of the particle the sample was retrieved from." - weight::F -end - -function SMCTransition(model::DynamicPPL.Model, vi::AbstractVarInfo, weight) - theta = getparams(model, vi) - lp = DynamicPPL.getlogjoint_internal(vi) - return SMCTransition(theta, lp, weight) -end - -getstats_with_lp(t::SMCTransition) = (lp=t.lp, weight=t.weight) - struct SMCState{P,F<:AbstractFloat} particles::P particleindex::Int @@ -228,7 +211,8 @@ function DynamicPPL.initialstep( weight = AdvancedPS.getweight(particles, 1) # Compute the first transition and the first state. - transition = SMCTransition(model, particle.model.f.varinfo, weight) + stats = (; weight=weight, logevidence=logevidence) + transition = Transition(model, particle.model.f.varinfo, stats) state = SMCState(particles, 2, logevidence) return transition, state @@ -246,7 +230,8 @@ function AbstractMCMC.step( weight = AdvancedPS.getweight(particles, index) # Compute the transition and the next state. 
- transition = SMCTransition(model, particle.model.f.varinfo, weight) + stats = (; weight=weight, logevidence=state.average_logevidence) + transition = Transition(model, particle.model.f.varinfo, stats) nextstate = SMCState(state.particles, index + 1, state.average_logevidence) return transition, nextstate @@ -300,15 +285,6 @@ Equivalent to [`PG`](@ref). """ const CSMC = PG # type alias of PG as Conditional SMC -struct PGTransition{T,F<:AbstractFloat} <: AbstractTransition - "The parameters for any given sample." - θ::T - "The joint log probability of the sample (NOTE: does not work, always set to zero)." - lp::F - "The log evidence of the sample." - logevidence::F -end - struct PGState vi::AbstractVarInfo rng::Random.AbstractRNG @@ -316,16 +292,21 @@ end get_varinfo(state::PGState) = state.vi -function PGTransition(model::DynamicPPL.Model, vi::AbstractVarInfo, logevidence) - theta = getparams(model, vi) - lp = DynamicPPL.getlogjoint_internal(vi) - return PGTransition(theta, lp, logevidence) -end - -getstats_with_lp(t::PGTransition) = (lp=t.lp, logevidence=t.logevidence) - -function getlogevidence(samples, sampler::Sampler{<:PG}, state::PGState) - return mean(x.logevidence for x in samples) +function getlogevidence( + transitions::AbstractVector{<:Turing.Inference.Transition}, + sampler::Sampler{<:PG}, + state::PGState, +) + logevidences = map(transitions) do t + if haskey(t.stat, :logevidence) + return t.stat.logevidence + else + # This should not really happen, but if it does we can handle it + # gracefully + return missing + end + end + return mean(logevidences) end function DynamicPPL.initialstep( @@ -357,7 +338,7 @@ function DynamicPPL.initialstep( # Compute the first transition. _vi = reference.model.f.varinfo - transition = PGTransition(model, _vi, logevidence) + transition = Transition(model, _vi, (; logevidence=logevidence)) return transition, PGState(_vi, reference.rng) end @@ -397,7 +378,7 @@ function AbstractMCMC.step( # Compute the transition. 
_vi = newreference.model.f.varinfo - transition = PGTransition(model, _vi, logevidence) + transition = Transition(model, _vi, (; logevidence=logevidence)) return transition, PGState(_vi, newreference.rng) end diff --git a/src/mcmc/sghmc.jl b/src/mcmc/sghmc.jl index 5ca351643e..34d7cf9d8d 100644 --- a/src/mcmc/sghmc.jl +++ b/src/mcmc/sghmc.jl @@ -184,23 +184,6 @@ function SGLD(; return SGLD(stepsize, adtype) end -struct SGLDTransition{T,F<:Real} <: AbstractTransition - "The parameters for any given sample." - θ::T - "The joint log probability of the sample." - lp::F - "The stepsize that was used to obtain the sample." - stepsize::F -end - -function SGLDTransition(model::DynamicPPL.Model, vi::AbstractVarInfo, stepsize) - theta = getparams(model, vi) - lp = DynamicPPL.getlogjoint_internal(vi) - return SGLDTransition(theta, lp, stepsize) -end - -getstats_with_lp(t::SGLDTransition) = (lp=t.lp, SGLD_stepsize=t.stepsize) - struct SGLDState{L,V<:AbstractVarInfo} logdensity::L vi::V @@ -220,13 +203,13 @@ function DynamicPPL.initialstep( end # Create first sample and state. - sample = SGLDTransition(model, vi, zero(spl.alg.stepsize(0))) + transition = Transition(model, vi, (; SGLD_stepsize=zero(spl.alg.stepsize(0)))) ℓ = DynamicPPL.LogDensityFunction( model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.alg.adtype ) state = SGLDState(ℓ, vi, 1) - return sample, state + return transition, state end function AbstractMCMC.step( @@ -245,8 +228,8 @@ function AbstractMCMC.step( vi = DynamicPPL.unflatten(vi, θ) # Compute next sample and state. 
- sample = SGLDTransition(model, vi, stepsize) + transition = Transition(model, vi, (; SGLD_stepsize=stepsize)) newstate = SGLDState(ℓ, vi, state.step + 1) - return sample, newstate + return transition, newstate end diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index 0fd76be3ab..dc8cd42d0d 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -598,8 +598,8 @@ end means = Dict(:b => 0.5, "x[1]" => 1.0, "x[2]" => 2.0) stds = Dict(:b => 0.5, "x[1]" => 1.0, "x[2]" => 1.0) for vn in keys(means) - @test isapprox(mean(skipmissing(chain[:, vn, 1])), means[vn]; atol=0.1) - @test isapprox(std(skipmissing(chain[:, vn, 1])), stds[vn]; atol=0.1) + @test isapprox(mean(skipmissing(chain[:, vn, 1])), means[vn]; atol=0.15) + @test isapprox(std(skipmissing(chain[:, vn, 1])), stds[vn]; atol=0.15) end end @@ -651,7 +651,7 @@ end chain = sample( StableRNG(468), model, - Gibbs(:b => PG(10), :x => ESS()), + Gibbs(:b => PG(20), :x => ESS()), 2000; discard_initial=100, ) diff --git a/test/mcmc/particle_mcmc.jl b/test/mcmc/particle_mcmc.jl index 7a2f5fe1c7..ad7373b855 100644 --- a/test/mcmc/particle_mcmc.jl +++ b/test/mcmc/particle_mcmc.jl @@ -1,10 +1,11 @@ module ParticleMCMCTests using ..Models: gdemo_default -#using ..Models: MoGtest, MoGtest_default +using ..SamplerTestUtils: test_chain_logp_metadata using AdvancedPS: ResampleWithESSThreshold, resample_systematic, resample_multinomial using Distributions: Bernoulli, Beta, Gamma, Normal, sample using Random: Random +using StableRNGs: StableRNG using Test: @test, @test_throws, @testset using Turing @@ -49,6 +50,10 @@ using Turing @test_throws ErrorException sample(fail_smc(), SMC(), 100) end + @testset "chain log-density metadata" begin + test_chain_logp_metadata(SMC()) + end + @testset "logevidence" begin Random.seed!(100) @@ -65,7 +70,10 @@ using Turing chains_smc = sample(test(), SMC(), 100) @test all(isone, chains_smc[:x]) + # the chain itself has a logevidence field @test chains_smc.logevidence ≈ -2 * log(2) + # but each 
transition also contains the logevidence + @test chains_smc[:logevidence] ≈ fill(chains_smc.logevidence, 100) end end @@ -88,6 +96,10 @@ end @test s.resampler === resample_systematic end + @testset "chain log-density metadata" begin + test_chain_logp_metadata(PG(10)) + end + @testset "logevidence" begin Random.seed!(100) @@ -105,6 +117,7 @@ end @test all(isone, chains_pg[:x]) @test chains_pg.logevidence ≈ -2 * log(2) atol = 0.01 + @test chains_pg[:logevidence] ≈ fill(chains_pg.logevidence, 100) end # https://github.com/TuringLang/Turing.jl/issues/1598 @@ -114,6 +127,24 @@ end @test length(unique(c[:s])) == 1 end + @testset "addlogprob leads to reweighting" begin + # Make sure that PG takes @addlogprob! into account. It didn't use to: + # https://github.com/TuringLang/Turing.jl/issues/1996 + @model function addlogprob_demo() + x ~ Normal(0, 1) + if x < 0 + @addlogprob! -10.0 + else + # Need a balanced number of addlogprobs in all branches, or + # else PG will error + @addlogprob! 0.0 + end + end + c = sample(StableRNG(468), addlogprob_demo(), PG(10), 100) + # Result should be biased towards x > 0. 
+ @test mean(c[:x]) > 0.7 + end + # https://github.com/TuringLang/Turing.jl/issues/2007 @testset "keyword arguments not supported" begin @model kwarg_demo(; x=2) = return x diff --git a/test/mcmc/sghmc.jl b/test/mcmc/sghmc.jl index 1671362ed4..ee943270cd 100644 --- a/test/mcmc/sghmc.jl +++ b/test/mcmc/sghmc.jl @@ -2,6 +2,7 @@ module SGHMCTests using ..Models: gdemo_default using ..NumericalTests: check_gdemo +using ..SamplerTestUtils: test_chain_logp_metadata using DynamicPPL.TestUtils.AD: run_ad using DynamicPPL.TestUtils: DEMO_MODELS using DynamicPPL: DynamicPPL @@ -32,6 +33,10 @@ using Turing chain = sample(rng, gdemo_default, alg, 10_000) check_gdemo(chain; atol=0.1) end + + @testset "chain log-density metadata" begin + test_chain_logp_metadata(SGHMC(; learning_rate=0.02, momentum_decay=0.5)) + end end @testset "Testing sgld.jl" begin @@ -46,6 +51,7 @@ end sampler = DynamicPPL.Sampler(alg) @test sampler isa DynamicPPL.Sampler{<:SGLD} end + @testset "sgld inference" begin rng = StableRNG(1) @@ -59,6 +65,10 @@ end @test s_weighted ≈ 49 / 24 atol = 0.2 @test m_weighted ≈ 7 / 6 atol = 0.2 end + + @testset "chain log-density metadata" begin + test_chain_logp_metadata(SGLD(; stepsize=PolynomialStepsize(0.25))) + end end end diff --git a/test/runtests.jl b/test/runtests.jl index 9fec2f737e..5fb6b21411 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -11,6 +11,7 @@ seed!(23) include("test_utils/models.jl") include("test_utils/numerical_tests.jl") +include("test_utils/sampler.jl") Turing.setprogress!(false) included_paths, excluded_paths = parse_args(ARGS) diff --git a/test/test_utils/sampler.jl b/test/test_utils/sampler.jl new file mode 100644 index 0000000000..32a3647f98 --- /dev/null +++ b/test/test_utils/sampler.jl @@ -0,0 +1,27 @@ +module SamplerTestUtils + +using Turing +using Test + +""" +Check that when sampling with `spl`, the resulting chain contains log-density +metadata that is correct. 
+""" +function test_chain_logp_metadata(spl) + @model function f() + # some prior term (but importantly, one that is constrained, i.e., can + # be linked with non-identity transform) + x ~ LogNormal() + # some likelihood term + return 1.0 ~ Normal(x) + end + chn = sample(f(), spl, 100) + # Check that the log-prior term is calculated in unlinked space. + @test chn[:logprior] ≈ logpdf.(LogNormal(), chn[:x]) + @test chn[:loglikelihood] ≈ logpdf.(Normal.(chn[:x]), 1.0) + # This should always be true, but it also indirectly checks that the + # log-joint is also calculated in unlinked space. + @test chn[:lp] ≈ chn[:logprior] + chn[:loglikelihood] +end + +end From 247aee9c350ea86b27e5fa35742fac1a9d4fb733 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 12 Aug 2025 14:57:13 +0100 Subject: [PATCH 49/49] Update changelog for PG in Gibbs --- HISTORY.md | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/HISTORY.md b/HISTORY.md index 3848b03127..0188c4fce3 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -46,6 +46,32 @@ Turing 0.40 changes this such that both of the above cause resampling. This release also fixes a bug where, if the model ended with one of these statements, their contribution to the particle weight would be ignored, leading to incorrect results. +The changes above also mean that certain models that previously worked with PG-within-Gibbs may now error. +Specifically this is likely to happen when the dimension of the model is variable. +For example: + +```julia +@model function f() + x ~ Bernoulli() + if x + y1 ~ Normal() + else + y1 ~ Normal() + y2 ~ Normal() + end + # (some likelihood term...) +end +sample(f(), Gibbs(:x => PG(20), (:y1, :y2) => MH()), 100) +``` + +This sampler now cannot be used for this model because depending on which branch is taken, the number of observations will be different. +To use PG-within-Gibbs, the number of observations that the PG component sampler sees must be constant. 
+Thus, for example, this will still work if `x`, `y1`, and `y2` are grouped together under the PG component sampler. + +If you absolutely require the old behaviour, we recommend using Turing.jl v0.39, but also thinking very carefully about what the expected behaviour of the model is, and checking that Turing is sampling from it correctly (note that the behaviour on v0.39 may in general be incorrect because Gibbs-conditioned variables did not trigger resampling). +We would also welcome any GitHub issues highlighting such problems. +Our support for dynamic models is incomplete and is liable to undergo further changes. + ## Other changes - Sampling using `Prior()` should now be about twice as fast because we now avoid evaluating the model twice on every iteration.