Use accumulators to fix all logp calculations when sampling #2630

Merged · 8 commits · Aug 1, 2025
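At a high level, every sampler's `Transition` now recomputes its log probabilities through DynamicPPL's accumulators instead of trusting whatever numbers the sampler left in the `VarInfo`, and the resulting chains report `logprior` and `loglikelihood` separately alongside `lp`. A minimal sketch of the user-visible effect; the toy model and sampler choice are illustrative assumptions, only the statistic names come from this PR:

```julia
# Sketch (not part of this diff): model and sampler are illustrative assumptions.
using Turing

@model function demo(x)
    m ~ Normal(0, 1)        # contributes to :logprior
    x ~ Normal(m, 1)        # contributes to :loglikelihood
end

chain = sample(demo(1.0), NUTS(), 100; progress=false)

# The chain statistics now carry the joint and its parts, with
# :lp == :logprior + :loglikelihood up to floating-point error.
chain[:lp]
chain[:logprior]
chain[:loglikelihood]
```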
18 changes: 5 additions & 13 deletions ext/TuringDynamicHMCExt.jl
@@ -63,7 +63,7 @@ function DynamicPPL.initialstep(

# Define log-density function.
ℓ = DynamicPPL.LogDensityFunction(
model, DynamicPPL.getlogjoint, vi; adtype=spl.alg.adtype
model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.alg.adtype
)

# Perform initial step.
@@ -73,14 +73,9 @@ steps = DynamicHMC.mcmc_steps(results.sampling_logdensity, results.final_warmup_state)
steps = DynamicHMC.mcmc_steps(results.sampling_logdensity, results.final_warmup_state)
Q, _ = DynamicHMC.mcmc_next_step(steps, results.final_warmup_state.Q)

# Update the variables.
vi = DynamicPPL.unflatten(vi, Q.q)
# TODO(DPPL0.37/penelopeysm): This is obviously incorrect. Fix this.
vi = DynamicPPL.setloglikelihood!!(vi, Q.ℓq)
vi = DynamicPPL.setlogprior!!(vi, 0.0)

# Create first sample and state.
sample = Turing.Inference.Transition(model, vi)
vi = DynamicPPL.unflatten(vi, Q.q)
sample = Turing.Inference.Transition(model, vi, nothing)
state = DynamicNUTSState(ℓ, vi, Q, steps.H.κ, steps.ϵ)

return sample, state
@@ -99,12 +94,9 @@ function AbstractMCMC.step(
steps = DynamicHMC.mcmc_steps(rng, spl.alg.sampler, state.metric, ℓ, state.stepsize)
Q, _ = DynamicHMC.mcmc_next_step(steps, state.cache)

# Update the variables.
vi = DynamicPPL.unflatten(vi, Q.q)
vi = DynamicPPL.setlogp!!(vi, Q.ℓq)

# Create next sample and state.
sample = Turing.Inference.Transition(model, vi)
vi = DynamicPPL.unflatten(vi, Q.q)
sample = Turing.Inference.Transition(model, vi, nothing)
newstate = DynamicNUTSState(ℓ, vi, Q, state.metric, state.stepsize)

return sample, newstate
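Both hunks in this file swap `DynamicPPL.getlogjoint` for `DynamicPPL.getlogjoint_internal` as the target of the `LogDensityFunction`. A rough sketch of the distinction, under my reading of the reworked DynamicPPL API — namely that the `_internal` variant evaluates the joint in the linked (unconstrained) parameterisation, log-abs-det-Jacobian included, which is the density a gradient-based sampler such as DynamicHMC should target:

```julia
# Sketch only; assumes the DynamicPPL >= 0.37 split between getlogjoint
# (constrained/model space) and getlogjoint_internal (linked space, with the
# Jacobian correction of the link transformation included).
using DynamicPPL, Distributions, LogDensityProblems

@model function constrained()
    σ ~ truncated(Normal(), 0, Inf)   # positive-constrained parameter
end

model = constrained()
vi = DynamicPPL.link!!(DynamicPPL.VarInfo(model), model)  # work in unconstrained space

# Density the sampler integrates: includes the Jacobian correction.
ℓ_sampler = DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint_internal, vi)

# Density in model (constrained) space: no Jacobian term.
ℓ_model = DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint, vi)

# Evaluated at the same unconstrained point, the two should differ by the Jacobian.
LogDensityProblems.logdensity(ℓ_sampler, [0.3])
LogDensityProblems.logdensity(ℓ_model, [0.3])
```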
6 changes: 3 additions & 3 deletions ext/TuringOptimExt.jl
@@ -102,7 +102,7 @@ function Optim.optimize(
options::Optim.Options=Optim.Options();
kwargs...,
)
f = Optimisation.OptimLogDensity(model, Optimisation.getlogjoint_without_jacobian)
f = Optimisation.OptimLogDensity(model, DynamicPPL.getlogjoint)
init_vals = DynamicPPL.getparams(f.ldf)
optimizer = Optim.LBFGS()
return _map_optimize(model, init_vals, optimizer, options; kwargs...)
@@ -124,7 +124,7 @@ function Optim.optimize(
options::Optim.Options=Optim.Options();
kwargs...,
)
f = Optimisation.OptimLogDensity(model, Optimisation.getlogjoint_without_jacobian)
f = Optimisation.OptimLogDensity(model, DynamicPPL.getlogjoint)
init_vals = DynamicPPL.getparams(f.ldf)
return _map_optimize(model, init_vals, optimizer, options; kwargs...)
end
@@ -140,7 +140,7 @@ function Optim.optimize(
end

function _map_optimize(model::DynamicPPL.Model, args...; kwargs...)
f = Optimisation.OptimLogDensity(model, Optimisation.getlogjoint_without_jacobian)
f = Optimisation.OptimLogDensity(model, DynamicPPL.getlogjoint)
return _optimize(f, args...; kwargs...)
end

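All three hunks make the same substitution: the bespoke `Optimisation.getlogjoint_without_jacobian` is retired in favour of plain `DynamicPPL.getlogjoint`, which — as I read the reworked API — already returns the joint in constrained space without a Jacobian term, exactly what MAP estimation wants. A usage sketch; the model and data are illustrative assumptions and the user-facing API is unchanged:

```julia
# Usage sketch; model and data are assumptions, only the internals above change.
using Turing, Optim

@model function gauss(x)
    μ ~ Normal(0, 10)
    σ ~ truncated(Normal(0, 5), 0, Inf)
    for i in eachindex(x)
        x[i] ~ Normal(μ, σ)
    end
end

model = gauss(randn(20) .+ 2)

# MAP estimation goes through TuringOptimExt; after this PR it maximises
# DynamicPPL.getlogjoint, i.e. the joint density in constrained space.
map_estimate = optimize(model, MAP())
```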
219 changes: 78 additions & 141 deletions src/mcmc/Inference.jl
@@ -17,8 +17,8 @@ using DynamicPPL:
setindex!!,
push!!,
setlogp!!,
getlogp,
getlogjoint,
getlogjoint_internal,
VarName,
getsym,
getdist,
@@ -123,71 +123,94 @@ end
######################
# Default Transition #
######################
# Default
getstats(t) = nothing
getstats(::Any) = NamedTuple()

# TODO(penelopeysm): Remove this abstract type by converting SGLDTransition,
# SMCTransition, and PGTransition to Turing.Inference.Transition instead.
abstract type AbstractTransition end

struct Transition{T,F<:AbstractFloat,S<:Union{NamedTuple,Nothing}} <: AbstractTransition
struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition
θ::T
lp::F # TODO: merge `lp` with `stat`
stat::S
end

Transition(θ, lp) = Transition(θ, lp, nothing)
function Transition(model::DynamicPPL.Model, vi::AbstractVarInfo, t)
θ = getparams(model, vi)
lp = getlogjoint(vi)
return Transition(θ, lp, getstats(t))
end
logprior::F
loglikelihood::F
stat::N

"""
Transition(model::Model, vi::AbstractVarInfo, sampler_transition)

Construct a new `Turing.Inference.Transition` object using the outputs of a
sampler step.

Here, `vi` represents a VarInfo _for which the appropriate parameters have
already been set_. However, the accumulators (e.g. logp) may in general
have junk contents. The role of this method is to re-evaluate `model` and
thus set the accumulators to the correct values.

`sampler_transition` is the transition object returned by the sampler
itself and is only used to extract statistics of interest.
"""
function Transition(model::DynamicPPL.Model, vi::AbstractVarInfo, sampler_transition)
vi = DynamicPPL.setaccs!!(
vi,
(
DynamicPPL.ValuesAsInModelAccumulator(true),
DynamicPPL.LogPriorAccumulator(),
DynamicPPL.LogLikelihoodAccumulator(),
),
)
_, vi = DynamicPPL.evaluate!!(model, vi)

# Extract all the information we need
vals_as_in_model = DynamicPPL.getacc(vi, Val(:ValuesAsInModel)).values
logprior = DynamicPPL.getlogprior(vi)
loglikelihood = DynamicPPL.getloglikelihood(vi)

# Get additional statistics
stats = getstats(sampler_transition)
return new{typeof(vals_as_in_model),typeof(logprior),typeof(stats)}(
vals_as_in_model, logprior, loglikelihood, stats
)
end

function metadata(t::Transition)
stat = t.stat
if stat === nothing
return (lp=t.lp,)
else
return merge((lp=t.lp,), stat)
function Transition(
model::DynamicPPL.Model,
untyped_vi::DynamicPPL.VarInfo{<:DynamicPPL.Metadata},
sampler_transition,
)
# Re-evaluating the model is unconscionably slow for untyped VarInfo. It's
# much faster to convert it to a typed varinfo first, hence this method.
# https://github.com/TuringLang/Turing.jl/issues/2604
return Transition(model, DynamicPPL.typed_varinfo(untyped_vi), sampler_transition)
end
end

DynamicPPL.getlogjoint(t::Transition) = t.lp

# Metadata of VarInfo object
metadata(vi::AbstractVarInfo) = (lp=getlogjoint(vi),)
function getstats_with_lp(t::Transition)
return merge(
t.stat,
(
lp=t.logprior + t.loglikelihood,
logprior=t.logprior,
loglikelihood=t.loglikelihood,
),
)
end
function getstats_with_lp(vi::AbstractVarInfo)
return (
lp=DynamicPPL.getlogjoint(vi),
logprior=DynamicPPL.getlogprior(vi),
loglikelihood=DynamicPPL.getloglikelihood(vi),
)
end

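The `Transition` inner constructor and `getstats_with_lp` above are where the PR title cashes out: rather than trusting the log probability a sampler happened to leave behind, the transition re-evaluates the model with a fresh set of accumulators and reads the parameters, log prior, and log likelihood back out of them. A condensed sketch of that flow, mirroring the constructor body with an assumed toy model:

```julia
# Sketch mirroring Transition's constructor; the model is an illustrative assumption.
using DynamicPPL, Distributions

@model function demo(x)
    m ~ Normal(0, 1)
    x ~ Normal(m, 1)
end

model = demo(1.5)
vi = DynamicPPL.VarInfo(model)   # parameters are set, but logp contents are untrusted

# Swap in fresh accumulators, then run the model once to fill them.
vi = DynamicPPL.setaccs!!(
    vi,
    (
        DynamicPPL.ValuesAsInModelAccumulator(true),
        DynamicPPL.LogPriorAccumulator(),
        DynamicPPL.LogLikelihoodAccumulator(),
    ),
)
_, vi = DynamicPPL.evaluate!!(model, vi)

# Everything a Transition needs now comes straight from the accumulators.
θ = DynamicPPL.getacc(vi, Val(:ValuesAsInModel)).values
logprior = DynamicPPL.getlogprior(vi)
loglikelihood = DynamicPPL.getloglikelihood(vi)
```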
##########################
# Chain making utilities #
##########################

"""
getparams(model, t)

Return a named tuple of parameters.
"""
getparams(model, t) = t.θ
function getparams(model::DynamicPPL.Model, vi::DynamicPPL.VarInfo)
# NOTE: In the past, `invlink(vi, model)` + `values_as(vi, OrderedDict)` was used.
# Unfortunately, using `invlink` can cause issues in scenarios where the constraints
# of the parameters change depending on the realizations. Hence we have to use
# `values_as_in_model`, which re-runs the model and extracts the parameters
# as they are seen in the model, i.e. in the constrained space. Moreover,
# this means that the code below will work both of linked and invlinked `vi`.
# Ref: https://github.com/TuringLang/Turing.jl/issues/2195
# NOTE: We need to `deepcopy` here to avoid modifying the original `vi`.
return DynamicPPL.values_as_in_model(model, true, deepcopy(vi))
end
function getparams(
model::DynamicPPL.Model, untyped_vi::DynamicPPL.VarInfo{<:DynamicPPL.Metadata}
)
# values_as_in_model is unconscionably slow for untyped VarInfo. It's
# much faster to convert it to a typed varinfo before calling getparams.
# https://github.com/TuringLang/Turing.jl/issues/2604
return getparams(model, DynamicPPL.typed_varinfo(untyped_vi))
getparams(::DynamicPPL.Model, t::AbstractTransition) = t.θ
function getparams(model::DynamicPPL.Model, vi::AbstractVarInfo)
t = Transition(model, vi, nothing)
return getparams(model, t)
end
function getparams(::DynamicPPL.Model, ::DynamicPPL.VarInfo{NamedTuple{(),Tuple{}}})
return Dict{VarName,Any}()
end

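The NOTE in the removed `getparams` method above explains why parameter extraction re-runs the model (`values_as_in_model`) instead of using `invlink`: variable constraints can depend on other realizations, so a fixed inverse transformation is not reliable. The new path keeps that property because `Transition` collects values through a `ValuesAsInModelAccumulator`. A sketch of the kind of model that motivates this, per issue #2195 referenced in the comment:

```julia
# Illustrative model whose constraint on `x` depends on another draw, so only
# re-evaluating the model recovers constrained values correctly.
using Turing

@model function varying_support()
    lower ~ Uniform(0, 1)
    x ~ Uniform(lower, lower + 1)   # support of x depends on the value of lower
end

model = varying_support()
vi = Turing.DynamicPPL.VarInfo(model)
vi_linked = Turing.DynamicPPL.link!!(vi, model)   # unconstrained representation

# Whether vi is linked or not, re-evaluation returns x in its constrained space.
# (Turing.Inference.getparams is internal; it is called here only to illustrate.)
Turing.Inference.getparams(model, vi_linked)
```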
function _params_to_array(model::DynamicPPL.Model, ts::Vector)
names_set = OrderedSet{VarName}()
# Extract the parameter names and values from each transition.
@@ -203,7 +226,6 @@ function _params_to_array(model::DynamicPPL.Model, ts::Vector)
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
mapreduce(collect, vcat, iters)
end

nms = map(first, nms_and_vs)
vs = map(last, nms_and_vs)
for nm in nms
@@ -218,14 +240,9 @@ function _params_to_array(model::DynamicPPL.Model, ts::Vector)
return names, vals
end

function get_transition_extras(ts::AbstractVector{<:VarInfo})
valmat = reshape([getlogjoint(t) for t in ts], :, 1)
return [:lp], valmat
end

function get_transition_extras(ts::AbstractVector)
# Extract all metadata.
extra_data = map(metadata, ts)
# Extract stats + log probabilities from each transition or VarInfo
extra_data = map(getstats_with_lp, ts)
return names_values(extra_data)
end

@@ -334,7 +351,7 @@ function AbstractMCMC.bundle_samples(
vals = map(values(sym_to_vns)) do vns
map(Base.Fix1(getindex, params), vns)
end
return merge(NamedTuple(zip(keys(sym_to_vns), vals)), metadata(t))
return merge(NamedTuple(zip(keys(sym_to_vns), vals)), getstats_with_lp(t))
end
end

@@ -396,84 +413,4 @@ function DynamicPPL.get_matching_type(
return Array{T,N}
end

##############
# Utilities #
##############

"""

transitions_from_chain(
[rng::AbstractRNG,]
model::Model,
chain::MCMCChains.Chains;
sampler = DynamicPPL.SampleFromPrior()
)

Execute `model` conditioned on each sample in `chain`, and return resulting transitions.

The returned transitions are represented in a `Vector{<:Turing.Inference.Transition}`.

# Details

In a bit more detail, the process is as follows:
1. For every `sample` in `chain`
1. For every `variable` in `sample`
1. Set `variable` in `model` to its value in `sample`
2. Execute `model` with variables fixed as above, sampling variables NOT present
in `chain` using `SampleFromPrior`
3. Return sampled variables and log-joint

# Example
```julia-repl
julia> using Turing

julia> @model function demo()
m ~ Normal(0, 1)
x ~ Normal(m, 1)
end;

julia> m = demo();

julia> chain = Chains(randn(2, 1, 1), ["m"]); # 2 samples of `m`

julia> transitions = Turing.Inference.transitions_from_chain(m, chain);

julia> [Turing.Inference.getlogjoint(t) for t in transitions] # extract the logjoints
2-element Array{Float64,1}:
-3.6294991938628374
-2.5697948166987845

julia> [first(t.θ.x) for t in transitions] # extract samples for `x`
2-element Array{Array{Float64,1},1}:
[-2.0844148956440796]
[-1.704630494695469]
```
"""
function transitions_from_chain(
model::DynamicPPL.Model, chain::MCMCChains.Chains; kwargs...
)
return transitions_from_chain(Random.default_rng(), model, chain; kwargs...)
end

function transitions_from_chain(
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
chain::MCMCChains.Chains;
sampler=DynamicPPL.SampleFromPrior(),
)
vi = Turing.VarInfo(model)

iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
transitions = map(iters) do (sample_idx, chain_idx)
# Set variables present in `chain` and mark those NOT present in chain to be resampled.
DynamicPPL.setval_and_resample!(vi, chain, sample_idx, chain_idx)
model(rng, vi, sampler)

# Convert `VarInfo` into `NamedTuple` and save.
Transition(model, vi)
end

return transitions
end

end # module
11 changes: 5 additions & 6 deletions src/mcmc/emcee.jl
@@ -65,14 +65,14 @@ function AbstractMCMC.step(
end

# Compute initial transition and states.
transition = map(Base.Fix1(Transition, model), vis)
transition = [Transition(model, vi, nothing) for vi in vis]

# TODO: Make compatible with immutable `AbstractVarInfo`.
state = EmceeState(
vis[1],
map(vis) do vi
vi = DynamicPPL.link!!(vi, model)
AMH.Transition(vi[:], DynamicPPL.getlogjoint(vi), false)
AMH.Transition(vi[:], DynamicPPL.getlogjoint_internal(vi), false)
end,
)

@@ -87,18 +87,17 @@
densitymodel = AMH.DensityModel(
Base.Fix1(
LogDensityProblems.logdensity,
DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint, vi),
DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint_internal, vi),
),
)

# Compute the next states.
states = last(AbstractMCMC.step(rng, densitymodel, spl.alg.ensemble, state.states))
t, states = AbstractMCMC.step(rng, densitymodel, spl.alg.ensemble, state.states)

# Compute the next transition and state.
transition = map(states) do _state
vi = DynamicPPL.unflatten(vi, _state.params)
t = Transition(getparams(model, vi), _state.lp)
return t
return Transition(model, vi, t)
end
newstate = EmceeState(vi, states)

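The step above now keeps the AdvancedMH transition `t` and passes it to `Transition(model, vi, t)`, so sampler-specific statistics reach the chain through `getstats` instead of being dropped. A sketch of how a sampler can opt in, using a hypothetical transition type (the `getstats` hook itself is what this PR's `Transition` constructor calls):

```julia
# Hypothetical sampler transition type; the getstats overload is the opt-in point.
using Turing

struct MyTransition
    accept_ratio::Float64
end

# Fields returned here are merged with lp/logprior/loglikelihood by
# getstats_with_lp and end up as extra statistics columns in the chain.
Turing.Inference.getstats(t::MyTransition) = (accept_ratio=t.accept_ratio,)

# Inside a sampler step one would then construct, e.g.:
#   Turing.Inference.Transition(model, vi, MyTransition(0.8))
```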
4 changes: 2 additions & 2 deletions src/mcmc/ess.jl
@@ -31,7 +31,7 @@ function DynamicPPL.initialstep(
EllipticalSliceSampling.isgaussian(typeof(dist)) ||
error("ESS only supports Gaussian prior distributions")
end
return Transition(model, vi), vi
return Transition(model, vi, nothing), vi
end

function AbstractMCMC.step(
@@ -56,7 +56,7 @@ function AbstractMCMC.step(
vi = DynamicPPL.unflatten(vi, sample)
vi = DynamicPPL.setloglikelihood!!(vi, state.loglikelihood)

return Transition(model, vi), vi
return Transition(model, vi, nothing), vi
end

# Prior distribution of considered random variable
Expand Down