From a29b95396fc63d688ffb82e31e11756cd6265416 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 21 Jul 2025 17:36:42 +0100 Subject: [PATCH 01/15] logjac accumulator --- src/abstract_varinfo.jl | 99 +++++++++++++++++++++++++++++----- src/accumulators.jl | 15 +++++- src/context_implementations.jl | 4 +- src/default_accumulators.jl | 84 ++++++++++++++++++++++++++++- src/pointwise_logdensities.jl | 2 + src/simple_varinfo.jl | 2 +- src/transforming.jl | 30 +++-------- 7 files changed, 195 insertions(+), 41 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 581ca829b..bd9dfb3ed 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -99,16 +99,34 @@ See also: [`getlogprior`](@ref), [`getloglikelihood`](@ref). """ getlogjoint(vi::AbstractVarInfo) = getlogprior(vi) + getloglikelihood(vi) +""" + getlogjoint_internal(vi::AbstractVarInfo) + +Return the log of the joint probability of the observed data and parameters as +they are stored internally in `vi`, including the log-Jacobian for any linked +parameters. + +In general, we have that: + +```julia +getlogjoint_internal(vi) == getlogjoint(vi) - getlogjac(vi) +``` +""" +getlogjoint_internal(vi::AbstractVarInfo) = + getlogprior(vi) + getloglikelihood(vi) - getlogjac(vi) + """ getlogp(vi::AbstractVarInfo) -Return a NamedTuple of the log prior and log likelihood probabilities. +Return a NamedTuple of the log prior, log Jacobian, and log likelihood probabilities. -The keys are called `logprior` and `loglikelihood`. If either one is not present in `vi` an -error will be thrown. +The keys are called `logprior`, `logjac`, and `loglikelihood`. If any of them +are not present in `vi` an error will be thrown. """ function getlogp(vi::AbstractVarInfo) - return (; logprior=getlogprior(vi), loglikelihood=getloglikelihood(vi)) + return (; + logprior=getlogprior(vi), logjac=getlogjac(vi), loglikelihood=getloglikelihood(vi) + ) end """ @@ -164,6 +182,30 @@ See also: [`getlogjoint`](@ref), [`getloglikelihood`](@ref), [`setlogprior!!`](@ """ getlogprior(vi::AbstractVarInfo) = getacc(vi, Val(:LogPrior)).logp +""" + getlogprior_internal(vi::AbstractVarInfo) + +Return the log of the prior probability of the parameters as stored internally +in `vi`. This includes the log-Jacobian for any linked parameters. + +In general, we have that: + +```julia +getlogprior_internal(vi) == getlogprior(vi) - getlogjac(vi) +``` +""" +getlogprior_internal(vi::AbstractVarInfo) = getlogprior(vi) - getlogjac(vi) + +""" + getlogjac(vi::AbstractVarInfo) + +Return the accumulated log-Jacobian term for any linked parameters in `vi`. The +Jacobian here is taken with respect to the forward (link) transform. + +See also: [`setlogjac!!`](@ref). +""" +getlogjac(vi::AbstractVarInfo) = getacc(vi, Val(:LogJacobian)).logJ + """ getloglikelihood(vi::AbstractVarInfo) @@ -196,6 +238,16 @@ See also: [`setloglikelihood!!`](@ref), [`setlogp!!`](@ref), [`getlogprior`](@re """ setlogprior!!(vi::AbstractVarInfo, logp) = setacc!!(vi, LogPriorAccumulator(logp)) +""" + setlogjac!!(vi::AbstractVarInfo, logJ) + +Set the accumulated log-Jacobian term for any linked parameters in `vi`. The +Jacobian here is taken with respect to the forward (link) transform. + +See also: [`getlogjac!!`](@ref). +""" +setlogjac!!(vi::AbstractVarInfo, logJ) = setacc!!(vi, LogJacobianAccumulator(logJ)) + """ setloglikelihood!!(vi::AbstractVarInfo, logp) @@ -215,10 +267,13 @@ Set both the log prior and the log likelihood probabilities in `vi`. See also: [`setlogprior!!`](@ref), [`setloglikelihood!!`](@ref), [`getlogp`](@ref). """ function setlogp!!(vi::AbstractVarInfo, logp::NamedTuple{names}) where {names} - if !(names == (:logprior, :loglikelihood) || names == (:loglikelihood, :logprior)) - error("logp must have the fields logprior and loglikelihood and no other fields.") + if Set(names) != Set([:logprior, :logjac, :loglikelihood]) + error( + "The second argument to `setlogp!!` must be a NamedTuple with the fields logprior, logjac, and loglikelihood.", + ) end vi = setlogprior!!(vi, logp.logprior) + vi = setlogjac!!(vi, logp.logjac) vi = setloglikelihood!!(vi, logp.loglikelihood) return vi end @@ -226,7 +281,7 @@ end function setlogp!!(vi::AbstractVarInfo, logp::Number) return error(""" `setlogp!!(vi::AbstractVarInfo, logp::Number)` is no longer supported. Use - `setloglikelihood!!` and/or `setlogprior!!` instead. + `setloglikelihood!!`, `setlogjac!!`, and/or `setlogprior!!` instead. """) end @@ -306,6 +361,19 @@ function acclogprior!!(vi::AbstractVarInfo, logp) return map_accumulator!!(acc -> acc + LogPriorAccumulator(logp), vi, Val(:LogPrior)) end +""" + acclogjac!!(vi::AbstractVarInfo, logJ) + +Add `logJ` to the value of the log Jacobian in `vi`. + +See also: [`getlogjac`](@ref), [`setlogjac!!`](@ref). +""" +function acclogjac!!(vi::AbstractVarInfo, logJ) + return map_accumulator!!( + acc -> acc + LogJacobianAccumulator(logJ), vi, Val(:LogJacobian) + ) +end + """ accloglikelihood!!(vi::AbstractVarInfo, logp) @@ -368,6 +436,9 @@ function resetlogp!!(vi::AbstractVarInfo) if hasacc(vi, Val(:LogPrior)) vi = map_accumulator!!(zero, vi, Val(:LogPrior)) end + if hasacc(vi, Val(:LogJacobian)) + vi = map_accumulator!!(zero, vi, Val(:LogJacobian)) + end if hasacc(vi, Val(:LogLikelihood)) vi = map_accumulator!!(zero, vi, Val(:LogLikelihood)) end @@ -836,8 +907,10 @@ function link!!( x = vi[:] y, logjac = with_logabsdet_jacobian(b, x) - lp_new = getlogprior(vi) - logjac - vi_new = setlogprior!!(unflatten(vi, y), lp_new) + # Set parameters + vi_new = unflatten(vi, y) + # Update logjac + vi_new = setlogjac!!(vi_new, logjac) return settrans!!(vi_new, t) end @@ -846,10 +919,12 @@ function invlink!!( ) b = t.bijector y = vi[:] - x, logjac = with_logabsdet_jacobian(b, y) + x = b(y) - lp_new = getlogprior(vi) + logjac - vi_new = setlogprior!!(unflatten(vi, x), lp_new) + # Set parameters + vi_new = unflatten(vi, x) + # Reset logjac to 0 + vi_new = setlogjac!!(vi_new, 0.0) return settrans!!(vi_new, NoTransformation()) end diff --git a/src/accumulators.jl b/src/accumulators.jl index 1e3e37e61..7d30b62f0 100644 --- a/src/accumulators.jl +++ b/src/accumulators.jl @@ -11,10 +11,21 @@ seen so far. An accumulator type `T <: AbstractAccumulator` must implement the following methods: - `accumulator_name(acc::T)` or `accumulator_name(::Type{T})` -- `accumulate_observe!!(acc::T, right, left, vn)` -- `accumulate_assume!!(acc::T, val, logjac, vn, right)` +- `accumulate_observe!!(acc::T, dist, val, vn)` +- `accumulate_assume!!(acc::T, val, logjac, vn, dist)` - `Base.copy(acc::T)` +In these functions: +- `val` is the new value of the random variable sampled from a new distribution (always + in the original unlinked space), or the value on the left-hand side of an observe + statement. +- `dist` is the distribution on the RHS of the tilde statement. +- `vn` is the `VarName` that is on the left-hand side of the tilde-statement. If the + tilde-statement is a literal observation like `0.0 ~ Normal()`, then `vn` is `nothing`. +- `logjac` is the log determinant of the Jacobian of the link transformation, _if_ the + variable is stored as a linked value in the VarInfo. If the variable is stored in its + original, unlinked form, then `logjac` is zero. + To be able to work with multi-threading, it should also implement: - `split(acc::T)` - `combine(acc::T, acc2::T)` diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 9e9a2d63d..66fcc83a9 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -124,7 +124,7 @@ function assume(dist::Distribution, vn::VarName, vi) y = getindex_internal(vi, vn) f = from_maybe_linked_internal_transform(vi, vn, dist) x, logjac = with_logabsdet_jacobian(f, y) - vi = accumulate_assume!!(vi, x, logjac, vn, dist) + vi = accumulate_assume!!(vi, x, -logjac, vn, dist) return x, vi end @@ -166,6 +166,6 @@ function assume( # HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct. logjac = logabsdetjac(istrans(vi, vn) ? link_transform(dist) : identity, r) - vi = accumulate_assume!!(vi, r, -logjac, vn, dist) + vi = accumulate_assume!!(vi, r, logjac, vn, dist) return r, vi end diff --git a/src/default_accumulators.jl b/src/default_accumulators.jl index 418362e8f..1a61d68f0 100644 --- a/src/default_accumulators.jl +++ b/src/default_accumulators.jl @@ -3,6 +3,10 @@ An accumulator that tracks the cumulative log prior during model execution. +Note that the log prior stored in here is always calculated based on unlinked +parameters, i.e., the value of `logp` is independent of whether tha VarInfo is +linked or not. + # Fields $(TYPEDFIELDS) """ @@ -19,6 +23,48 @@ Create a new `LogPriorAccumulator` accumulator with the log prior initialized to LogPriorAccumulator{T}() where {T<:Real} = LogPriorAccumulator(zero(T)) LogPriorAccumulator() = LogPriorAccumulator{LogProbType}() +""" + LogJacobianAccumulator{T<:Real} <: AbstractAccumulator + +An accumulator that tracks the cumulative log Jacobian (technically, +log(abs(det(J)))) during model execution. Specifically, J refers to the +Jacobian of the _link transform_, i.e., from the space of the original +distribution to unconstrained space. + +!!! note + This accumulator is only incremented if the variable is transformed by a + link function, i.e., if the VarInfo is linked (for this particular + variable). If the VarInfo is not linked, the log Jacobian term will be 0. + + In general, for the forward Jacobian `J` corresponding to the function `y = + f(x)`, + + ```math + \\log(q(\\mathbf{y})) = \\log(p(\\mathbf{x})) - \\log\\(|\\mathbf{J}|\\) + ``` + + and correspondingly: + + ```julia + getlogjoint_internal(vi) = getlogjoint(vi) - getlogjac(vi) + ``` + +# Fields +$(TYPEDFIELDS) +""" +struct LogJacobianAccumulator{T<:Real} <: AbstractAccumulator + "the logabsdet of the link transform Jacobian" + logJ::T +end + +""" + LogJacobianAccumulator{T}() + +Create a new `LogJacobianAccumulator` accumulator with the log Jacobian initialized to zero. +""" +LogJacobianAccumulator{T}() where {T<:Real} = LogJacobianAccumulator(zero(T)) +LogJacobianAccumulator() = LogJacobianAccumulator{LogProbType}() + """ LogLikelihoodAccumulator{T<:Real} <: AbstractAccumulator @@ -71,6 +117,7 @@ VariableOrderAccumulator(n) = VariableOrderAccumulator{typeof(n)}(n) VariableOrderAccumulator() = VariableOrderAccumulator{Int}() Base.copy(acc::LogPriorAccumulator) = acc +Base.copy(acc::LogJacobianAccumulator) = acc Base.copy(acc::LogLikelihoodAccumulator) = acc function Base.copy(acc::VariableOrderAccumulator) return VariableOrderAccumulator(acc.num_produce, copy(acc.order)) @@ -79,6 +126,9 @@ end function Base.show(io::IO, acc::LogPriorAccumulator) return print(io, "LogPriorAccumulator($(repr(acc.logp)))") end +function Base.show(io::IO, acc::LogJacobianAccumulator) + return print(io, "LogJacobianAccumulator($(repr(acc.logJ)))") +end function Base.show(io::IO, acc::LogLikelihoodAccumulator) return print(io, "LogLikelihoodAccumulator($(repr(acc.logp)))") end @@ -92,6 +142,9 @@ end # equality of hashes. Both of the below implementations are also different from the default # implementation for structs. Base.:(==)(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator) = acc1.logp == acc2.logp +function Base.:(==)(acc1::LogJacobianAccumulator, acc2::LogJacobianAccumulator) + return acc1.logJ == acc2.logJ +end function Base.:(==)(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator) return acc1.logp == acc2.logp end @@ -102,6 +155,9 @@ end function Base.isequal(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator) return isequal(acc1.logp, acc2.logp) end +function Base.isequal(acc1::LogJacobianAccumulator, acc2::LogJacobianAccumulator) + return isequal(acc1.logJ, acc2.logJ) +end function Base.isequal(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator) return isequal(acc1.logp, acc2.logp) end @@ -110,6 +166,9 @@ function Base.isequal(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumul end Base.hash(acc::LogPriorAccumulator, h::UInt) = hash((LogPriorAccumulator, acc.logp), h) +function Base.hash(acc::LogJacobianAccumulator, h::UInt) + return hash((LogJacobianAccumulator, acc.logJ), h) +end function Base.hash(acc::LogLikelihoodAccumulator, h::UInt) return hash((LogLikelihoodAccumulator, acc.logp), h) end @@ -118,16 +177,21 @@ function Base.hash(acc::VariableOrderAccumulator, h::UInt) end accumulator_name(::Type{<:LogPriorAccumulator}) = :LogPrior +accumulator_name(::Type{<:LogJacobianAccumulator}) = :LogJacobian accumulator_name(::Type{<:LogLikelihoodAccumulator}) = :LogLikelihood accumulator_name(::Type{<:VariableOrderAccumulator}) = :VariableOrder split(::LogPriorAccumulator{T}) where {T} = LogPriorAccumulator(zero(T)) +split(::LogJacobianAccumulator{T}) where {T} = LogJacobianAccumulator(zero(T)) split(::LogLikelihoodAccumulator{T}) where {T} = LogLikelihoodAccumulator(zero(T)) split(acc::VariableOrderAccumulator) = copy(acc) function combine(acc::LogPriorAccumulator, acc2::LogPriorAccumulator) return LogPriorAccumulator(acc.logp + acc2.logp) end +function combine(acc::LogJacobianAccumulator, acc2::LogJacobianAccumulator) + return LogJacobianAccumulator(acc.logJ + acc2.logJ) +end function combine(acc::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator) return LogLikelihoodAccumulator(acc.logp + acc2.logp) end @@ -142,6 +206,9 @@ end function Base.:+(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator) return LogPriorAccumulator(acc1.logp + acc2.logp) end +function Base.:+(acc1::LogJacobianAccumulator, acc2::LogJacobianAccumulator) + return LogJacobianAccumulator(acc1.logJ + acc2.logJ) +end function Base.:+(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator) return LogLikelihoodAccumulator(acc1.logp + acc2.logp) end @@ -150,13 +217,19 @@ function increment(acc::VariableOrderAccumulator) end Base.zero(acc::LogPriorAccumulator) = LogPriorAccumulator(zero(acc.logp)) +Base.zero(acc::LogJacobianAccumulator) = LogJacobianAccumulator(zero(acc.logJ)) Base.zero(acc::LogLikelihoodAccumulator) = LogLikelihoodAccumulator(zero(acc.logp)) function accumulate_assume!!(acc::LogPriorAccumulator, val, logjac, vn, right) - return acc + LogPriorAccumulator(logpdf(right, val) + logjac) + return acc + LogPriorAccumulator(logpdf(right, val)) end accumulate_observe!!(acc::LogPriorAccumulator, right, left, vn) = acc +function accumulate_assume!!(acc::LogJacobianAccumulator, val, logjac, vn, right) + return acc + LogJacobianAccumulator(logjac) +end +accumulate_observe!!(acc::LogJacobianAccumulator, right, left, vn) = acc + accumulate_assume!!(acc::LogLikelihoodAccumulator, val, logjac, vn, right) = acc function accumulate_observe!!(acc::LogLikelihoodAccumulator, right, left, vn) # Note that it's important to use the loglikelihood function here, not logpdf, because @@ -174,6 +247,11 @@ accumulate_observe!!(acc::VariableOrderAccumulator, right, left, vn) = increment function Base.convert(::Type{LogPriorAccumulator{T}}, acc::LogPriorAccumulator) where {T} return LogPriorAccumulator(convert(T, acc.logp)) end +function Base.convert( + ::Type{LogJacobianAccumulator{T}}, acc::LogJacobianAccumulator +) where {T} + return LogJacobianAccumulator(convert(T, acc.logJ)) +end function Base.convert( ::Type{LogLikelihoodAccumulator{T}}, acc::LogLikelihoodAccumulator ) where {T} @@ -197,6 +275,9 @@ end function convert_eltype(::Type{T}, acc::LogPriorAccumulator) where {T} return LogPriorAccumulator(convert(T, acc.logp)) end +function convert_eltype(::Type{T}, acc::LogJacobianAccumulator) where {T} + return LogJacobianAccumulator(convert(T, acc.logJ)) +end function convert_eltype(::Type{T}, acc::LogLikelihoodAccumulator) where {T} return LogLikelihoodAccumulator(convert(T, acc.logp)) end @@ -206,6 +287,7 @@ function default_accumulators( ) where {FloatT,IntT} return AccumulatorTuple( LogPriorAccumulator{FloatT}(), + LogJacobianAccumulator{FloatT}(), LogLikelihoodAccumulator{FloatT}(), VariableOrderAccumulator{IntT}(), ) diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index 44882f91e..14f8103e7 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -74,6 +74,8 @@ function accumulate_assume!!( # T is the element type of the vectors that are the values of `acc.logps`. Usually # it's LogProbType. T = eltype(last(fieldtypes(eltype(acc.logps)))) + # Note that accumulating LogPrior ignores logjac (since we want to + # return log densities that don't depend on the linking status of the VarInfo). subacc = accumulate_assume!!(LogPriorAccumulator{T}(), val, logjac, vn, right) push!(acc, vn, subacc.logp) end diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index abb93a0ab..24c358afc 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -476,7 +476,7 @@ function assume( f = to_maybe_linked_internal_transform(vi, vn, dist) value_raw, logjac = with_logabsdet_jacobian(f, value) vi = BangBang.push!!(vi, vn, value_raw, dist) - vi = accumulate_assume!!(vi, value, -logjac, vn, dist) + vi = accumulate_assume!!(vi, value, logjac, vn, dist) return value, vi end diff --git a/src/transforming.jl b/src/transforming.jl index e3da0ff29..56f861cff 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -15,8 +15,8 @@ NodeTrait(::DynamicTransformationContext) = IsLeaf() function tilde_assume( ::DynamicTransformationContext{isinverse}, right, vn, vi ) where {isinverse} - r = vi[vn, right] - lp = Bijectors.logpdf_with_trans(right, r, !isinverse) + # vi[vn, right] always provides the value in unlinked space. + x = vi[vn, right] if istrans(vi, vn) isinverse || @warn "Trying to link an already transformed variable ($vn)" @@ -24,13 +24,11 @@ function tilde_assume( isinverse && @warn "Trying to invlink a non-transformed variable ($vn)" end - # Only transform if `!isinverse` since `vi[vn, right]` - # already performs the inverse transformation if it's transformed. - r_transformed = isinverse ? r : link_transform(right)(r) - if hasacc(vi, Val(:LogPrior)) - vi = acclogprior!!(vi, lp) - end - return r, setindex!!(vi, r_transformed, vn) + transform = isinverse ? identity : link_transform(right) + y, logjac = with_logabsdet_jacobian(transform, x) + vi = accumulate_assume!!(vi, x, logjac, vn, right) + vi = setindex!!(vi, y, vn) + return x, vi end function tilde_observe!!(::DynamicTransformationContext, right, left, vn, vi) @@ -53,21 +51,7 @@ function _transform!!( ) # To transform using DynamicTransformationContext, we evaluate the model using that as the leaf context: model = contextualize(model, setleafcontext(model.context, ctx)) - # but we do not need to use any accumulators other than LogPriorAccumulator - # (which is affected by the Jacobian of the transformation). - accs = getaccs(vi) - has_logprior = haskey(accs, Val(:LogPrior)) - if has_logprior - old_logprior = getacc(accs, Val(:LogPrior)) - vi = setaccs!!(vi, (old_logprior,)) - end vi = settrans!!(last(evaluate!!(model, vi)), t) - # Restore the accumulators. - if has_logprior - new_logprior = getacc(vi, Val(:LogPrior)) - accs = setacc!!(accs, new_logprior) - end - vi = setaccs!!(vi, accs) return vi end From d6c9cfaafc8334f75b79e24a0748650c6cfdc374 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 21 Jul 2025 18:06:46 +0100 Subject: [PATCH 02/15] Fix tests --- benchmarks/src/DynamicPPLBenchmarks.jl | 4 +- docs/src/api.md | 5 ++ src/DynamicPPL.jl | 3 ++ src/logdensityfunction.jl | 59 ++++++++++++++++------ src/model.jl | 8 +++ src/simple_varinfo.jl | 4 +- src/test_utils/ad.jl | 5 +- test/ad.jl | 10 ++-- test/linking.jl | 17 ++++--- test/logdensityfunction.jl | 3 ++ test/model.jl | 7 +++ test/simple_varinfo.jl | 36 ++++++------- test/varinfo.jl | 70 +++++++++++--------------- 13 files changed, 142 insertions(+), 89 deletions(-) diff --git a/benchmarks/src/DynamicPPLBenchmarks.jl b/benchmarks/src/DynamicPPLBenchmarks.jl index 54a302a6f..8c5032ace 100644 --- a/benchmarks/src/DynamicPPLBenchmarks.jl +++ b/benchmarks/src/DynamicPPLBenchmarks.jl @@ -86,7 +86,9 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked:: vi = DynamicPPL.link(vi, model) end - f = DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint, vi; adtype=adbackend) + f = DynamicPPL.LogDensityFunction( + model, DynamicPPL.getlogjoint_internal, vi; adtype=adbackend + ) # The parameters at which we evaluate f. θ = vi[:] diff --git a/docs/src/api.md b/docs/src/api.md index 180e8dfd4..8d0b7b558 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -367,6 +367,7 @@ DynamicPPL provides the following default accumulators. ```@docs LogPriorAccumulator +LogJacobianAccumulator LogLikelihoodAccumulator VariableOrderAccumulator ``` @@ -380,7 +381,11 @@ getlogp setlogp!! acclogp!! getlogjoint +getlogjoint_internal +getlogjac +setlogjac!! getlogprior +getlogprior_internal setlogprior!! acclogprior!! getloglikelihood diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index c282939a2..462adbbb8 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -58,6 +58,9 @@ export AbstractVarInfo, getlogjoint, getlogprior, getloglikelihood, + getlogjac, + getlogjoint_internal, + getlogprior_internal, setlogp!!, setlogprior!!, setloglikelihood!!, diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 3c092c06b..9e801dd30 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -18,7 +18,7 @@ is_supported(::ADTypes.AutoReverseDiff) = true """ LogDensityFunction( model::Model, - getlogdensity::Function=getlogjoint, + getlogdensity::Function=getlogjoint_internal, varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity); adtype::Union{ADTypes.AbstractADType,Nothing}=nothing ) @@ -29,10 +29,37 @@ A struct which contains a model, along with all the information necessary to: - and if `adtype` is provided, calculate the gradient of the log density at that point. -At its most basic level, a LogDensityFunction wraps the model together with a -function that specifies how to extract the log density, and the type of -VarInfo to be used. These must be known in order to calculate the log density -(using [`DynamicPPL.evaluate!!`](@ref)). +This information can be extracted using the LogDensityProblems.jl interface, +specifically, using `LogDensityProblems.logdensity` and +`LogDensityProblems.logdensity_and_gradient`. If `adtype` is nothing, then only +`logdensity` is implemented. If `adtype` is a concrete AD backend type, then +`logdensity_and_gradient` is also implemented. + +There are several options for `getlogdensity` that are 'supported' out of the +box: + +- [`getlogjoint_internal`](@ref): calculate the log joint, including the + log-Jacobian term for any variables that have been linked in the provided + VarInfo. +- [`getlogprior_internal`](@ref): calculate the log prior, including the + log-Jacobian term for any variables that have been linked in the provided + VarInfo. +- [`getlogjoint`](@ref): calculate the log joint in the model space, ignoring + any effects of linking +- [`getlogprior`](@ref): calculate the log prior in the model space, ignoring + any effects of linking +- [`getloglikelihood`](@ref): calculate the log likelihood (this is unaffected + by linking, since transforms are only applied to random variables) + +!!! note + By default, `LogDensityFunction` uses `getlogjoint_internal`, i.e., the + result of `LogDensityProblems.logdensity(f, x)` will depend on whether the + `LogDensityFunction` was created with a linked or unlinked VarInfo. This + is done primarily to ease interoperability with MCMC samplers. + +If you provide one of these functions, a `VarInfo` will be automatically created +for you. If you provide a different function, you have to manually create a +VarInfo and pass it as the third argument. If the `adtype` keyword argument is provided, then this struct will also store the adtype along with other information for efficient calculation of the @@ -40,10 +67,6 @@ gradient of the log density. Note that preparing a `LogDensityFunction` with an AD type `AutoBackend()` requires the AD backend itself to have been loaded (e.g. with `import Backend`). -`DynamicPPL.LogDensityFunction` implements the LogDensityProblems.jl interface. -If `adtype` is nothing, then only `logdensity` is implemented. If `adtype` is a -concrete AD backend type, then `logdensity_and_gradient` is also implemented. - # Fields $(FIELDS) @@ -74,7 +97,7 @@ julia> LogDensityProblems.dimension(f) 1 julia> # By default it uses `VarInfo` under the hood, but this is not necessary. - f = LogDensityFunction(model, getlogjoint, SimpleVarInfo(model)); + f = LogDensityFunction(model, getlogjoint_internal, SimpleVarInfo(model)); julia> LogDensityProblems.logdensity(f, [0.0]) -2.3378770664093453 @@ -99,7 +122,7 @@ struct LogDensityFunction{ } <: AbstractModel "model used for evaluation" model::M - "function to be called on `varinfo` to extract the log density. By default `getlogjoint`." + "function to be called on `varinfo` to extract the log density. By default `getlogjoint_internal`." getlogdensity::F "varinfo used for evaluation. If not specified, generated with `ldf_default_varinfo`." varinfo::V @@ -110,7 +133,7 @@ struct LogDensityFunction{ function LogDensityFunction( model::Model, - getlogdensity::Function=getlogjoint, + getlogdensity::Function=getlogjoint_internal, varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity); adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, ) @@ -180,10 +203,18 @@ function ldf_default_varinfo(::Model, getlogdensity::Function) return error(msg) end -ldf_default_varinfo(model::Model, ::typeof(getlogjoint)) = VarInfo(model) +ldf_default_varinfo(model::Model, ::typeof(getlogjoint_internal)) = VarInfo(model) + +function ldf_default_varinfo(model::Model, ::typeof(getlogjoint)) + return setaccs!!(VarInfo(model), (LogPriorAccumulator(), LogLikelihoodAccumulator())) +end + +function ldf_default_varinfo(model::Model, ::typeof(getlogprior_internal)) + return setaccs!!(VarInfo(model), (LogPriorAccumulator(), LogJacobianAccumulator())) +end function ldf_default_varinfo(model::Model, ::typeof(getlogprior)) - return setaccs!!(VarInfo(model), (LogPriorAccumulator(),)) + return setaccs!!(VarInfo(model), (LogPriorAccumulator())) end function ldf_default_varinfo(model::Model, ::typeof(getloglikelihood)) diff --git a/src/model.jl b/src/model.jl index 93e77eaec..dbbe0b85b 100644 --- a/src/model.jl +++ b/src/model.jl @@ -995,6 +995,10 @@ Base.rand(model::Model) = rand(Random.default_rng(), NamedTuple, model) Return the log joint probability of variables `varinfo` for the probabilistic `model`. +Note that this probability always refers to the parameters in unlinked space, i.e., +the return value of `logjoint` does not depend on whether `VarInfo` has been linked +or not. + See [`logprior`](@ref) and [`loglikelihood`](@ref). """ function logjoint(model::Model, varinfo::AbstractVarInfo) @@ -1042,6 +1046,10 @@ end Return the log prior probability of variables `varinfo` for the probabilistic `model`. +Note that this probability always refers to the parameters in unlinked space, i.e., +the return value of `logprior` does not depend on whether `VarInfo` has been linked +or not. + See also [`logjoint`](@ref) and [`loglikelihood`](@ref). """ function logprior(model::Model, varinfo::AbstractVarInfo) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 24c358afc..e65a8e5bf 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -125,7 +125,7 @@ julia> vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), true) Transformed SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), VariableOrder = VariableOrderAccumulator(0, Dict{VarName, Int64}()))) julia> # (✓) Positive probability mass on negative numbers! - getlogjoint(last(DynamicPPL.evaluate!!(m, vi))) + getlogjoint_internal(last(DynamicPPL.evaluate!!(m, vi))) -1.3678794411714423 julia> # While if we forget to indicate that it's transformed: @@ -133,7 +133,7 @@ julia> # While if we forget to indicate that it's transformed: SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), VariableOrder = VariableOrderAccumulator(0, Dict{VarName, Int64}()))) julia> # (✓) No probability mass on negative numbers! - getlogjoint(last(DynamicPPL.evaluate!!(m, vi))) + getlogjoint_internal(last(DynamicPPL.evaluate!!(m, vi))) -Inf ``` diff --git a/src/test_utils/ad.jl b/src/test_utils/ad.jl index d4f6f9a1d..1ac33a481 100644 --- a/src/test_utils/ad.jl +++ b/src/test_utils/ad.jl @@ -4,7 +4,8 @@ using ADTypes: AbstractADType, AutoForwardDiff using Chairmarks: @be import DifferentiationInterface as DI using DocStringExtensions -using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo, getlogjoint, link +using DynamicPPL: + Model, LogDensityFunction, VarInfo, AbstractVarInfo, getlogjoint_internal, link using LogDensityProblems: logdensity, logdensity_and_gradient using Random: AbstractRNG, default_rng using Statistics: median @@ -224,7 +225,7 @@ function run_ad( benchmark::Bool=false, atol::AbstractFloat=100 * eps(), rtol::AbstractFloat=sqrt(eps()), - getlogdensity::Function=getlogjoint, + getlogdensity::Function=getlogjoint_internal, rng::AbstractRNG=default_rng(), varinfo::AbstractVarInfo=link(VarInfo(rng, model), model), params::Union{Nothing,Vector{<:AbstractFloat}}=nothing, diff --git a/test/ad.jl b/test/ad.jl index 308894ada..371e79b06 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -30,7 +30,7 @@ using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos linked_varinfo = DynamicPPL.link(varinfo, m) - f = LogDensityFunction(m, getlogjoint, linked_varinfo) + f = LogDensityFunction(m, getlogjoint_internal, linked_varinfo) x = DynamicPPL.getparams(f) # Calculate reference logp + gradient of logp using ForwardDiff @@ -52,17 +52,17 @@ using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest if is_mooncake && is_1_11 && is_svi_vnv # https://github.com/compintell/Mooncake.jl/issues/470 @test_throws ArgumentError DynamicPPL.LogDensityFunction( - m, getlogjoint, linked_varinfo; adtype=adtype + m, getlogjoint_internal, linked_varinfo; adtype=adtype ) elseif is_mooncake && is_1_10 && is_svi_vnv # TODO: report upstream @test_throws UndefRefError DynamicPPL.LogDensityFunction( - m, getlogjoint, linked_varinfo; adtype=adtype + m, getlogjoint_internal, linked_varinfo; adtype=adtype ) elseif is_mooncake && is_1_10 && is_svi_od # TODO: report upstream @test_throws Mooncake.MooncakeRuleCompilationError DynamicPPL.LogDensityFunction( - m, getlogjoint, linked_varinfo; adtype=adtype + m, getlogjoint_internal, linked_varinfo; adtype=adtype ) else @test run_ad( @@ -113,7 +113,7 @@ using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest spl = Sampler(MyEmptyAlg()) sampling_model = contextualize(model, SamplingContext(model.context)) ldf = LogDensityFunction( - sampling_model, getlogjoint; adtype=AutoReverseDiff(; compile=true) + sampling_model, getlogjoint_internal; adtype=AutoReverseDiff(; compile=true) ) x = ldf.varinfo[:] @test LogDensityProblems.logdensity_and_gradient(ldf, x) isa Any diff --git a/test/linking.jl b/test/linking.jl index b0c2dcb5c..12204e868 100644 --- a/test/linking.jl +++ b/test/linking.jl @@ -84,8 +84,11 @@ end else DynamicPPL.link(vi, model) end - # Difference should just be the log-absdet-jacobian "correction". - @test DynamicPPL.getlogjoint(vi) - DynamicPPL.getlogjoint(vi_linked) ≈ log(2) + # Difference between the internal logjoints should just be the log-absdet-jacobian "correction". + @test DynamicPPL.getlogjoint_internal(vi) - + DynamicPPL.getlogjoint_internal(vi_linked) ≈ log(2) + # The non-internal logjoint should be the same. + @test DynamicPPL.getlogjoint(vi) ≈ DynamicPPL.getlogjoint_internal(vi_linked) @test vi_linked[@varname(m), dist] == LowerTriangular(vi[@varname(m), dist]) # Linked one should be working with a lower-dimensional representation. @test length(vi_linked[:]) < length(vi[:]) @@ -99,6 +102,8 @@ end @test length(vi_invlinked[:]) == length(vi[:]) @test vi_invlinked[@varname(m), dist] ≈ LowerTriangular(vi[@varname(m), dist]) @test DynamicPPL.getlogjoint(vi_invlinked) ≈ DynamicPPL.getlogjoint(vi) + @test DynamicPPL.getlogjoint_internal(vi_invlinked) ≈ + DynamicPPL.getlogjoint_internal(vi) end end @@ -130,7 +135,7 @@ end end @test length(vi_linked[:]) == d * (d - 1) ÷ 2 # Should now include the log-absdet-jacobian correction. - @test !(getlogjoint(vi_linked) ≈ lp) + @test !(getlogjoint_internal(vi_linked) ≈ lp) # Invlinked. vi_invlinked = if mutable DynamicPPL.invlink!!(deepcopy(vi_linked), model) @@ -138,7 +143,7 @@ end DynamicPPL.invlink(vi_linked, model) end @test length(vi_invlinked[:]) == d^2 - @test getlogjoint(vi_invlinked) ≈ lp + @test getlogjoint_internal(vi_invlinked) ≈ lp end end end @@ -164,7 +169,7 @@ end end @test length(vi_linked[:]) == d - 1 # Should now include the log-absdet-jacobian correction. - @test !(getlogjoint(vi_linked) ≈ lp) + @test !(getlogjoint_internal(vi_linked) ≈ lp) # Invlinked. vi_invlinked = if mutable DynamicPPL.invlink!!(deepcopy(vi_linked), model) @@ -172,7 +177,7 @@ end DynamicPPL.invlink(vi_linked, model) end @test length(vi_invlinked[:]) == d - @test getlogjoint(vi_invlinked) ≈ lp + @test getlogjoint_internal(vi_invlinked) ≈ lp end end end diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index c4d0d6beb..fbd868f71 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -26,8 +26,11 @@ end loglikelihood(model, vi) @testset "$(varinfo)" for varinfo in varinfos + # Note use of `getlogjoint` rather than `getlogjoint_internal` here ... logdensity = DynamicPPL.LogDensityFunction(model, getlogjoint, varinfo) θ = varinfo[:] + # ... because it has to match with `logjoint(model, vi)`, which always returns + # the unlinked value @test LogDensityProblems.logdensity(logdensity, θ) ≈ logjoint(model, varinfo) @test LogDensityProblems.dimension(logdensity) == length(θ) end diff --git a/test/model.jl b/test/model.jl index daa3cc743..81f84e548 100644 --- a/test/model.jl +++ b/test/model.jl @@ -485,11 +485,18 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() DynamicPPL.untyped_simple_varinfo(model), ] @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos + logjoint = getlogjoint(varinfo) # unlinked space varinfo_linked = DynamicPPL.link(varinfo, model) varinfo_linked_result = last( DynamicPPL.evaluate!!(model, deepcopy(varinfo_linked)) ) + # getlogjoint should return the same result as before it was linked @test getlogjoint(varinfo_linked) ≈ getlogjoint(varinfo_linked_result) + @test getlogjoint(varinfo_linked) ≈ logjoint + # getlogjoint_internal shouldn't + @test getlogjoint_internal(varinfo_linked) ≈ + getlogjoint_internal(varinfo_linked_result) + @test !isapprox(getlogjoint_internal(varinfo_linked), logjoint) end end diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index e300c651e..a4c358702 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -100,27 +100,29 @@ end vi = last(DynamicPPL.evaluate!!(model, vi)) - # `link!!` - vi_linked = link!!(deepcopy(vi), model) - lp_linked = getlogjoint(vi_linked) - values_unconstrained, lp_linked_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( + # Calculate ground truth + lp_unlinked_true = DynamicPPL.TestUtils.logjoint_true( + models, values_constrained... + ) + _, lp_linked_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( model, values_constrained... ) - # Should result in the correct logjoint. + + # `link!!` + vi_linked = link!!(deepcopy(vi), model) + lp_unlinked = getlogjoint(vi_linked) + lp_linked = getlogjoint_internal(vi_linked) @test lp_linked ≈ lp_linked_true - # Should be approx. the same as the "lazy" transformation. - @test logjoint(model, vi_linked) ≈ lp_linked + @test lp_unlinked ≈ lp_unlinked_true + @test logjoint(model, vi_linked) ≈ lp_unlinked # `invlink!!` vi_invlinked = invlink!!(deepcopy(vi_linked), model) - lp_invlinked = getlogjoint(vi_invlinked) - lp_invlinked_true = DynamicPPL.TestUtils.logjoint_true( - model, values_constrained... - ) - # Should result in the correct logjoint. - @test lp_invlinked ≈ lp_invlinked_true - # Should be approx. the same as the "lazy" transformation. - @test logjoint(model, vi_invlinked) ≈ lp_invlinked + lp_unlinked = getlogjoint(vi_invlinked) + also_lp_unlinked = getlogjoint_internal(vi_invlinked) + @test lp_unlinked ≈ lp_unlinked_true + @test also_lp_unlinked ≈ lp_unlinked_true + @test logjoint(model, vi_invlinked) ≈ lp_unlinked # Should result in same values. @test all( @@ -250,7 +252,7 @@ end # `getlogp` should be equal to the logjoint with log-absdet-jac correction. - lp = getlogjoint(svi) + lp = getlogjoint_internal(svi) # needs higher atol because of https://github.com/TuringLang/Bijectors.jl/issues/375 @test lp ≈ lp_true atol = 1.2e-5 end @@ -304,7 +306,7 @@ DynamicPPL.tovec(retval_unconstrained.m) # The resulting varinfo should hold the correct logp. - lp = getlogjoint(vi_linked_result) + lp = getlogjoint_internal(vi_linked_result) @test lp ≈ lp_true end end diff --git a/test/varinfo.jl b/test/varinfo.jl index dad54f024..c68d5ca8f 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -167,8 +167,9 @@ end vi = last(DynamicPPL.evaluate!!(m, deepcopy(vi))) @test getlogprior(vi) == lp_a + lp_b + @test getlogjac(vi) == 0.0 @test getloglikelihood(vi) == lp_c + lp_d - @test getlogp(vi) == (; logprior=lp_a + lp_b, loglikelihood=lp_c + lp_d) + @test getlogp(vi) == (; logprior=lp_a + lp_b, logjac=0.0, loglikelihood=lp_c + lp_d) @test getlogjoint(vi) == lp_a + lp_b + lp_c + lp_d @test get_num_produce(vi) == 2 @test begin @@ -183,17 +184,21 @@ end vi = setlogprior!!(vi, -1.0) getlogprior(vi) == -1.0 end + @test begin + vi = setlogjac!!(vi, -1.0) + getlogjac(vi) == -1.0 + end @test begin vi = setloglikelihood!!(vi, -1.0) getloglikelihood(vi) == -1.0 end @test begin - vi = setlogp!!(vi, (logprior=-3.0, loglikelihood=-3.0)) - getlogp(vi) == (; logprior=-3.0, loglikelihood=-3.0) + vi = setlogp!!(vi, (logprior=-3.0, logjac=-3.0, loglikelihood=-3.0)) + getlogp(vi) == (; logprior=-3.0, logjac=-3.0, loglikelihood=-3.0) end @test begin vi = acclogp!!(vi, (logprior=1.0, loglikelihood=1.0)) - getlogp(vi) == (; logprior=-2.0, loglikelihood=-2.0) + getlogp(vi) == (; logprior=-2.0, logjac=-3.0, loglikelihood=-2.0) end @test getlogp(setlogp!!(vi, getlogp(vi))) == getlogp(vi) @@ -552,71 +557,52 @@ end end end - @testset "istrans" begin + @testset "logp evaluation on linked varinfo" begin @model demo_constrained() = x ~ truncated(Normal(); lower=0) model = demo_constrained() vn = @varname(x) dist = truncated(Normal(); lower=0) - ### `VarInfo` - # Need to run once since we can't specify that we want to _sample_ - # in the unconstrained space for `VarInfo` without having `vn` - # present in the `varinfo`. - - ## `untyped_varinfo` - vi = DynamicPPL.untyped_varinfo(model) + function test_linked_varinfo(model, vi) + # vn and dist are taken from the containing scope + vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) + f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) + x = f(DynamicPPL.getindex_internal(vi, vn)) + @test istrans(vi, vn) + @test getlogjoint_internal(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + @test getlogprior_internal(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + @test getloglikelihood(vi) == 0.0 + @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, false) + @test getlogprior(vi) ≈ Bijectors.logpdf_with_trans(dist, x, false) + end ## `untyped_varinfo` vi = DynamicPPL.untyped_varinfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) - # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) - f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) - x = f(DynamicPPL.getindex_internal(vi, vn)) - @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + test_linked_varinfo(model, vi, dist) ## `typed_varinfo` vi = DynamicPPL.typed_varinfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) - # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) - f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) - x = f(DynamicPPL.getindex_internal(vi, vn)) - @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + test_linked_varinfo(model, vi, dist) ## `typed_varinfo` vi = DynamicPPL.typed_varinfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) - # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) - f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) - x = f(DynamicPPL.getindex_internal(vi, vn)) - @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + test_linked_varinfo(model, vi, dist) ### `SimpleVarInfo` ## `SimpleVarInfo{<:NamedTuple}` vi = DynamicPPL.settrans!!(SimpleVarInfo(), true) - # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) - f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) - x = f(DynamicPPL.getindex_internal(vi, vn)) - @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + test_linked_varinfo(model, vi, dist) ## `SimpleVarInfo{<:Dict}` vi = DynamicPPL.settrans!!(SimpleVarInfo(Dict()), true) - # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) - f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) - x = f(DynamicPPL.getindex_internal(vi, vn)) - @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + test_linked_varinfo(model, vi, dist) ## `SimpleVarInfo{<:VarNamedVector}` vi = DynamicPPL.settrans!!(SimpleVarInfo(DynamicPPL.VarNamedVector()), true) - # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) - f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) - x = f(DynamicPPL.getindex_internal(vi, vn)) - @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + test_linked_varinfo(model, vi, dist) end @testset "values_as" begin From e671a56185078b7f0f2153e8e992e0b4f8075b42 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 22 Jul 2025 00:37:21 +0100 Subject: [PATCH 03/15] Fix a whole bunch of stuff --- src/DynamicPPL.jl | 2 ++ src/abstract_varinfo.jl | 9 +++++--- src/context_implementations.jl | 4 ++-- src/logdensityfunction.jl | 2 +- src/simple_varinfo.jl | 14 +++++++----- src/threadsafe.jl | 5 +++++ src/varinfo.jl | 24 ++++++++++---------- test/accumulators.jl | 6 ++++- test/simple_varinfo.jl | 41 +++++++++++++++++++--------------- test/varinfo.jl | 19 ++++++++-------- 10 files changed, 75 insertions(+), 51 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 462adbbb8..9663b6702 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -63,8 +63,10 @@ export AbstractVarInfo, getlogprior_internal, setlogp!!, setlogprior!!, + setlogjac!!, setloglikelihood!!, acclogp!!, + acclogjac!!, acclogprior!!, accloglikelihood!!, resetlogp!!, diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index bd9dfb3ed..9dd922163 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -909,7 +909,8 @@ function link!!( # Set parameters vi_new = unflatten(vi, y) - # Update logjac + # Update logjac. We can overwrite any old value since there is only + # a single logjac term to worry about. vi_new = setlogjac!!(vi_new, logjac) return settrans!!(vi_new, t) end @@ -923,8 +924,10 @@ function invlink!!( # Set parameters vi_new = unflatten(vi, x) - # Reset logjac to 0 - vi_new = setlogjac!!(vi_new, 0.0) + # Reset logjac to 0. + if hasacc(vi_new, Val(:LogJacobian)) + vi_new = map_accumulator!!(zero, vi_new, Val(:LogJacobian)) + end return settrans!!(vi_new, NoTransformation()) end diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 66fcc83a9..786d7c913 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -123,8 +123,8 @@ end function assume(dist::Distribution, vn::VarName, vi) y = getindex_internal(vi, vn) f = from_maybe_linked_internal_transform(vi, vn, dist) - x, logjac = with_logabsdet_jacobian(f, y) - vi = accumulate_assume!!(vi, x, -logjac, vn, dist) + x, inv_logjac = with_logabsdet_jacobian(f, y) + vi = accumulate_assume!!(vi, x, -inv_logjac, vn, dist) return x, vi end diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 9e801dd30..3b790576a 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -214,7 +214,7 @@ function ldf_default_varinfo(model::Model, ::typeof(getlogprior_internal)) end function ldf_default_varinfo(model::Model, ::typeof(getlogprior)) - return setaccs!!(VarInfo(model), (LogPriorAccumulator())) + return setaccs!!(VarInfo(model), (LogPriorAccumulator(),)) end function ldf_default_varinfo(model::Model, ::typeof(getloglikelihood)) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index e65a8e5bf..85d0e6066 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -494,6 +494,7 @@ end istrans(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation) istrans(vi::SimpleVarInfo, ::VarName) = istrans(vi) istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = istrans(vi.varinfo, vn) +istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}) = istrans(vi.varinfo) islinked(vi::SimpleVarInfo) = istrans(vi) @@ -619,8 +620,10 @@ function link!!( x = vi.values y, logjac = with_logabsdet_jacobian(b, x) vi_new = Accessors.@set(vi.values = y) - if hasacc(vi_new, Val(:LogPrior)) - vi_new = acclogprior!!(vi_new, -logjac) + # Since there's only a single transformation, we can overwrite any previous + # value in logjac. + if hasacc(vi_new, Val(:LogJacobian)) + vi_new = setlogjac!!(vi_new, logjac) end return settrans!!(vi_new, t) end @@ -632,10 +635,11 @@ function invlink!!( ) b = t.bijector y = vi.values - x, logjac = with_logabsdet_jacobian(b, y) + x = b(y) vi_new = Accessors.@set(vi.values = x) - if hasacc(vi_new, Val(:LogPrior)) - vi_new = acclogprior!!(vi_new, logjac) + # logjac should be zero for an unlinked VarInfo. + if hasacc(vi_new, Val(:LogJacobian)) + vi_new = map_accumulator!!(zero, vi_new, Val(:LogJacobian)) end return settrans!!(vi_new, NoTransformation()) end diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 9b82cd8b4..5f0a6d3e5 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -201,6 +201,11 @@ function resetlogp!!(vi::ThreadSafeVarInfo) zero, vi.accs_by_thread[i], Val(:LogPrior) ) end + if hasacc(vi, Val(:LogJacobian)) + vi.accs_by_thread[i] = map_accumulator( + zero, vi.accs_by_thread[i], Val(:LogJacobian) + ) + end if hasacc(vi, Val(:LogLikelihood)) vi.accs_by_thread[i] = map_accumulator( zero, vi.accs_by_thread[i], Val(:LogLikelihood) diff --git a/src/varinfo.jl b/src/varinfo.jl index d8233ae07..4cf5b0562 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1148,8 +1148,8 @@ function _inner_transform!(md::Metadata, vi::VarInfo, vn::VarName, f) setrange!(md, vn, start:(start + length(yvec) - 1)) # Set the new value. setval!(md, yvec, vn) - if hasacc(vi, Val(:LogPrior)) - vi = acclogprior!!(vi, -logjac) + if hasacc(vi, Val(:LogJacobian)) + vi = acclogjac!!(vi, logjac) end return vi end @@ -1187,8 +1187,8 @@ function _link(model::Model, varinfo::VarInfo, vns) varinfo = deepcopy(varinfo) md, logjac = _link_metadata!!(model, varinfo, varinfo.metadata, vns) new_varinfo = VarInfo(md, varinfo.accs) - if hasacc(new_varinfo, Val(:LogPrior)) - new_varinfo = acclogprior!!(new_varinfo, -logjac) + if hasacc(new_varinfo, Val(:LogJacobian)) + new_varinfo = acclogjac!!(new_varinfo, logjac) end return new_varinfo end @@ -1203,8 +1203,8 @@ function _link(model::Model, varinfo::NTVarInfo, vns::NamedTuple) varinfo = deepcopy(varinfo) md, logjac = _link_metadata!(model, varinfo, varinfo.metadata, vns) new_varinfo = VarInfo(md, varinfo.accs) - if hasacc(new_varinfo, Val(:LogPrior)) - new_varinfo = acclogprior!!(new_varinfo, -logjac) + if hasacc(new_varinfo, Val(:LogJacobian)) + new_varinfo = acclogjac!!(new_varinfo, logjac) end return new_varinfo end @@ -1353,8 +1353,8 @@ function _invlink(model::Model, varinfo::VarInfo, vns) varinfo = deepcopy(varinfo) md, logjac = _invlink_metadata!!(model, varinfo, varinfo.metadata, vns) new_varinfo = VarInfo(md, varinfo.accs) - if hasacc(new_varinfo, Val(:LogPrior)) - new_varinfo = acclogprior!!(new_varinfo, -logjac) + if hasacc(new_varinfo, Val(:LogJacobian)) + new_varinfo = acclogjac!!(new_varinfo, logjac) end return new_varinfo end @@ -1369,8 +1369,8 @@ function _invlink(model::Model, varinfo::NTVarInfo, vns::NamedTuple) varinfo = deepcopy(varinfo) md, logjac = _invlink_metadata!(model, varinfo, varinfo.metadata, vns) new_varinfo = VarInfo(md, varinfo.accs) - if hasacc(new_varinfo, Val(:LogPrior)) - new_varinfo = acclogprior!!(new_varinfo, -logjac) + if hasacc(new_varinfo, Val(:LogJacobian)) + new_varinfo = acclogjac!!(new_varinfo, logjac) end return new_varinfo end @@ -1430,11 +1430,11 @@ function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, targ y = getindex_internal(varinfo, vn) dist = getdist(varinfo, vn) f = from_linked_internal_transform(varinfo, vn, dist) - x, logjac = with_logabsdet_jacobian(f, y) + x, inv_logjac = with_logabsdet_jacobian(f, y) # Vectorize value. xvec = tovec(x) # Accumulate the log-abs-det jacobian correction. - cumulative_logjac += logjac + cumulative_logjac -= inv_logjac # Mark as no longer transformed. settrans!!(varinfo, false, vn) # Return the vectorized transformed value. diff --git a/test/accumulators.jl b/test/accumulators.jl index 5963ad8b5..506821c38 100644 --- a/test/accumulators.jl +++ b/test/accumulators.jl @@ -87,7 +87,9 @@ using DynamicPPL: vn = @varname(x) dist = Normal() @test accumulate_assume!!(LogPriorAccumulator(1.0), val, logjac, vn, dist) == - LogPriorAccumulator(1.0 + logjac + logpdf(dist, val)) + LogPriorAccumulator(1.0 + logpdf(dist, val)) + @test accumulate_assume!!(LogJacobianAccumulator(2.0), val, logjac, vn, dist) == + LogJacobianAccumulator(2.0 + logjac) @test accumulate_assume!!( LogLikelihoodAccumulator(1.0), val, logjac, vn, dist ) == LogLikelihoodAccumulator(1.0) @@ -101,6 +103,8 @@ using DynamicPPL: vn = @varname(x) @test accumulate_observe!!(LogPriorAccumulator(1.0), right, left, vn) == LogPriorAccumulator(1.0) + @test accumulate_observe!!(LogJacobianAccumulator(1.0), right, left, vn) == + LogJacobianAccumulator(1.0) @test accumulate_observe!!(LogLikelihoodAccumulator(1.0), right, left, vn) == LogLikelihoodAccumulator(1.0 + logpdf(right, left)) @test accumulate_observe!!(VariableOrderAccumulator(1), right, left, vn) == diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index a4c358702..3cca1b5dc 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -89,11 +89,11 @@ @testset "link!! & invlink!! on $(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS values_constrained = DynamicPPL.TestUtils.rand_prior_true(model) - @testset "$(typeof(vi))" for vi in ( - SimpleVarInfo(Dict()), - SimpleVarInfo(values_constrained), - SimpleVarInfo(DynamicPPL.VarNamedVector()), - DynamicPPL.typed_varinfo(model), + @testset "$name" for (name, vi) in ( + ("SVI{Dict}", SimpleVarInfo(Dict())), + ("SVI{NamedTuple}", SimpleVarInfo(values_constrained)), + ("SVI{VNV}", SimpleVarInfo(DynamicPPL.VarNamedVector())), + ("TypedVarInfo", DynamicPPL.typed_varinfo(model)), ) for vn in DynamicPPL.TestUtils.varnames(model) vi = DynamicPPL.setindex!!(vi, get(values_constrained, vn), vn) @@ -102,7 +102,7 @@ # Calculate ground truth lp_unlinked_true = DynamicPPL.TestUtils.logjoint_true( - models, values_constrained... + model, values_constrained... ) _, lp_linked_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( model, values_constrained... @@ -145,10 +145,10 @@ end svi_vnv = SimpleVarInfo(vnv) - @testset "$(nameof(typeof(DynamicPPL.values_as(svi))))" for svi in ( - svi_nt, - svi_dict, - svi_vnv, + @testset "$name" for (name, svi) in ( + ("NamedTuple", svi_nt), + ("Dict", svi_dict), + ("VarNamedVector", svi_vnv), # TODO(mhauru) Fix linked SimpleVarInfos to work with our test models. # DynamicPPL.settrans!!(deepcopy(svi_nt), true), # DynamicPPL.settrans!!(deepcopy(svi_dict), true), @@ -283,31 +283,36 @@ vi_linked = DynamicPPL.setindex!!(vi_linked, -rand(), vn) end - retval, vi_linked_result = DynamicPPL.evaluate!!(model, deepcopy(vi_linked)) + # NOTE: Evaluating a linked VarInfo, **specifically when the transformation + # is static**, will result in an invlinked VarInfo. This is because of + # `maybe_invlink_before_eval!`, which only invlinks if the transformation + # is static. (src/abstract_varinfo.jl) + retval, vi_unlinked_again = DynamicPPL.evaluate!!(model, deepcopy(vi_linked)) @test DynamicPPL.tovec(DynamicPPL.getindex_internal(vi_linked, @varname(s))) ≠ DynamicPPL.tovec(retval.s) # `s` is unconstrained in original @test DynamicPPL.tovec( - DynamicPPL.getindex_internal(vi_linked_result, @varname(s)) + DynamicPPL.getindex_internal(vi_unlinked_again, @varname(s)) ) == DynamicPPL.tovec(retval.s) # `s` is constrained in result # `m` should not be transformed. @test vi_linked[@varname(m)] == retval.m - @test vi_linked_result[@varname(m)] == retval.m + @test vi_unlinked_again[@varname(m)] == retval.m - # Compare to truth. - retval_unconstrained, lp_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( + # Get ground truths + retval_unconstrained, lp_linked_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( model, retval.s, retval.m ) + lp_unlinked_true = DynamicPPL.TestUtils.logjoint_true(model, retval.s, retval.m) @test DynamicPPL.tovec(DynamicPPL.getindex_internal(vi_linked, @varname(s))) ≈ DynamicPPL.tovec(retval_unconstrained.s) @test DynamicPPL.tovec(DynamicPPL.getindex_internal(vi_linked, @varname(m))) ≈ DynamicPPL.tovec(retval_unconstrained.m) - # The resulting varinfo should hold the correct logp. - lp = getlogjoint_internal(vi_linked_result) - @test lp ≈ lp_true + # The unlinked varinfo should hold the unlinked logp. + lp_unlinked = getlogjoint(vi_unlinked_again) + @test getlogjoint(vi_unlinked_again) ≈ lp_unlinked_true end end end diff --git a/test/varinfo.jl b/test/varinfo.jl index c68d5ca8f..16a9a857d 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -211,7 +211,7 @@ end # need regex because 1.11 and 1.12 throw different errors (in 1.12 the # missing field is surrounded by backticks) @test_throws r"has no field `?LogLikelihood" getloglikelihood(vi) - @test_throws r"has no field `?LogLikelihood" getlogp(vi) + @test_throws r"has no field `?LogJacobian" getlogp(vi) @test_throws r"has no field `?LogLikelihood" getlogjoint(vi) @test_throws r"has no field `?VariableOrder" get_num_produce(vi) @test begin @@ -579,30 +579,30 @@ end ## `untyped_varinfo` vi = DynamicPPL.untyped_varinfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) - test_linked_varinfo(model, vi, dist) + test_linked_varinfo(model, vi) ## `typed_varinfo` vi = DynamicPPL.typed_varinfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) - test_linked_varinfo(model, vi, dist) + test_linked_varinfo(model, vi) ## `typed_varinfo` vi = DynamicPPL.typed_varinfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) - test_linked_varinfo(model, vi, dist) + test_linked_varinfo(model, vi) ### `SimpleVarInfo` ## `SimpleVarInfo{<:NamedTuple}` vi = DynamicPPL.settrans!!(SimpleVarInfo(), true) - test_linked_varinfo(model, vi, dist) + test_linked_varinfo(model, vi) ## `SimpleVarInfo{<:Dict}` vi = DynamicPPL.settrans!!(SimpleVarInfo(Dict()), true) - test_linked_varinfo(model, vi, dist) + test_linked_varinfo(model, vi) ## `SimpleVarInfo{<:VarNamedVector}` vi = DynamicPPL.settrans!!(SimpleVarInfo(DynamicPPL.VarNamedVector()), true) - test_linked_varinfo(model, vi, dist) + test_linked_varinfo(model, vi) end @testset "values_as" begin @@ -705,8 +705,8 @@ end lp = logjoint(model, varinfo) @test lp ≈ lp_true @test getlogjoint(varinfo) ≈ lp_true - lp_linked = getlogjoint(varinfo_linked) - @test lp_linked ≈ lp_linked_true + lp_linked_internal = getlogjoint_internal(varinfo_linked) + @test lp_linked_internal ≈ lp_linked_true # TODO: Compare values once we are no longer working with `NamedTuple` for # the true values, e.g. `value_true`. @@ -718,6 +718,7 @@ end ) @test length(varinfo_invlinked[:]) == length(varinfo[:]) @test getlogjoint(varinfo_invlinked) ≈ lp_true + @test getlogjoint_internal(varinfo_invlinked) ≈ lp_true end end end From 5a4b01b8aaa2b0f49e6ca0b0998516e954f8311d Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 22 Jul 2025 01:05:16 +0100 Subject: [PATCH 04/15] Fix final tests --- src/varinfo.jl | 36 +++++++++++++++++++++--------------- test/linking.jl | 7 +++++-- 2 files changed, 26 insertions(+), 17 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 4cf5b0562..7b819c58f 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1351,10 +1351,13 @@ end function _invlink(model::Model, varinfo::VarInfo, vns) varinfo = deepcopy(varinfo) - md, logjac = _invlink_metadata!!(model, varinfo, varinfo.metadata, vns) + md, inv_logjac = _invlink_metadata!!(model, varinfo, varinfo.metadata, vns) new_varinfo = VarInfo(md, varinfo.accs) if hasacc(new_varinfo, Val(:LogJacobian)) - new_varinfo = acclogjac!!(new_varinfo, logjac) + # Mildly confusing: we need to _add_ the logjac of the inverse transform, + # because we are trying to remove the logjac of the forward transform + # that was previously accumulated when linking. + new_varinfo = acclogjac!!(new_varinfo, inv_logjac) end return new_varinfo end @@ -1367,10 +1370,13 @@ end function _invlink(model::Model, varinfo::NTVarInfo, vns::NamedTuple) varinfo = deepcopy(varinfo) - md, logjac = _invlink_metadata!(model, varinfo, varinfo.metadata, vns) + md, inv_logjac = _invlink_metadata!(model, varinfo, varinfo.metadata, vns) new_varinfo = VarInfo(md, varinfo.accs) if hasacc(new_varinfo, Val(:LogJacobian)) - new_varinfo = acclogjac!!(new_varinfo, logjac) + # Mildly confusing: we need to _add_ the logjac of the inverse transform, + # because we are trying to remove the logjac of the forward transform + # that was previously accumulated when linking. + new_varinfo = acclogjac!!(new_varinfo, inv_logjac) end return new_varinfo end @@ -1382,7 +1388,7 @@ end vns::NamedTuple{vns_names}, ) where {metadata_names,vns_names} expr = quote - cumulative_logjac = zero(LogProbType) + cumulative_inv_logjac = zero(LogProbType) end mds = Expr(:tuple) for f in metadata_names @@ -1391,10 +1397,10 @@ end mds.args, quote begin - md, logjac = _invlink_metadata!!( + md, inv_logjac = _invlink_metadata!!( model, varinfo, metadata.$f, vns.$f ) - cumulative_logjac += logjac + cumulative_inv_logjac += inv_logjac md end end, @@ -1407,7 +1413,7 @@ end push!( expr.args, quote - (NamedTuple{$metadata_names}($mds), cumulative_logjac) + (NamedTuple{$metadata_names}($mds), cumulative_inv_logjac) end, ) return expr @@ -1415,7 +1421,7 @@ end function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_vns) vns = metadata.vns - cumulative_logjac = zero(LogProbType) + cumulative_inv_logjac = zero(LogProbType) # Construct the new transformed values, and keep track of their lengths. vals_new = map(vns) do vn @@ -1434,7 +1440,7 @@ function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, targ # Vectorize value. xvec = tovec(x) # Accumulate the log-abs-det jacobian correction. - cumulative_logjac -= inv_logjac + cumulative_inv_logjac += inv_logjac # Mark as no longer transformed. settrans!!(varinfo, false, vn) # Return the vectorized transformed value. @@ -1459,25 +1465,25 @@ function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, targ metadata.dists, metadata.flags, ), - cumulative_logjac + cumulative_inv_logjac end function _invlink_metadata!!( ::Model, varinfo::VarInfo, metadata::VarNamedVector, target_vns ) vns = target_vns === nothing ? keys(metadata) : target_vns - cumulative_logjac = zero(LogProbType) + cumulative_inv_logjac = zero(LogProbType) for vn in vns transform = gettransform(metadata, vn) old_val = getindex_internal(metadata, vn) - new_val, logjac = with_logabsdet_jacobian(transform, old_val) + new_val, inv_logjac = with_logabsdet_jacobian(transform, old_val) # TODO(mhauru) We are calling a !! function but ignoring the return value. - cumulative_logjac += logjac + cumulative_inv_logjac += inv_logjac new_transform = from_vec_transform(new_val) metadata = setindex_internal!!(metadata, tovec(new_val), vn, new_transform) settrans!(metadata, false, vn) end - return metadata, cumulative_logjac + return metadata, cumulative_inv_logjac end # TODO(mhauru) The treatment of the case when some variables are linked and others are not diff --git a/test/linking.jl b/test/linking.jl index 12204e868..cae101c72 100644 --- a/test/linking.jl +++ b/test/linking.jl @@ -87,8 +87,8 @@ end # Difference between the internal logjoints should just be the log-absdet-jacobian "correction". @test DynamicPPL.getlogjoint_internal(vi) - DynamicPPL.getlogjoint_internal(vi_linked) ≈ log(2) - # The non-internal logjoint should be the same. - @test DynamicPPL.getlogjoint(vi) ≈ DynamicPPL.getlogjoint_internal(vi_linked) + # The non-internal logjoint should be the same since it doesn't depend on linking. + @test DynamicPPL.getlogjoint(vi) ≈ DynamicPPL.getlogjoint(vi_linked) @test vi_linked[@varname(m), dist] == LowerTriangular(vi[@varname(m), dist]) # Linked one should be working with a lower-dimensional representation. @test length(vi_linked[:]) < length(vi[:]) @@ -101,7 +101,10 @@ end end @test length(vi_invlinked[:]) == length(vi[:]) @test vi_invlinked[@varname(m), dist] ≈ LowerTriangular(vi[@varname(m), dist]) + # The non-internal logjoint should still be the same, again since + # it doesn't depend on linking. @test DynamicPPL.getlogjoint(vi_invlinked) ≈ DynamicPPL.getlogjoint(vi) + # The internal logjoint should also be the same as before the round-trip linking. @test DynamicPPL.getlogjoint_internal(vi_invlinked) ≈ DynamicPPL.getlogjoint_internal(vi) end From 974c2823c30f12af8c8d1cc6afa45232c1d97bfb Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 22 Jul 2025 01:06:21 +0100 Subject: [PATCH 05/15] Fix docs --- docs/src/api.md | 1 + src/DynamicPPL.jl | 1 + 2 files changed, 2 insertions(+) diff --git a/docs/src/api.md b/docs/src/api.md index 8d0b7b558..9237943c7 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -384,6 +384,7 @@ getlogjoint getlogjoint_internal getlogjac setlogjac!! +acclogjac!! getlogprior getlogprior_internal setlogprior!! diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 9663b6702..15d39014e 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -50,6 +50,7 @@ export AbstractVarInfo, AbstractAccumulator, LogLikelihoodAccumulator, LogPriorAccumulator, + LogJacobianAccumulator, VariableOrderAccumulator, push!!, empty!!, From 60b686324a859ceb70e2c361b0031d6775f85f75 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 22 Jul 2025 01:08:56 +0100 Subject: [PATCH 06/15] Fix docs/doctests --- src/abstract_varinfo.jl | 2 +- src/simple_varinfo.jl | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 9dd922163..4a256ce0c 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -244,7 +244,7 @@ setlogprior!!(vi::AbstractVarInfo, logp) = setacc!!(vi, LogPriorAccumulator(logp Set the accumulated log-Jacobian term for any linked parameters in `vi`. The Jacobian here is taken with respect to the forward (link) transform. -See also: [`getlogjac!!`](@ref). +See also: [`getlogjac`](@ref), [`acclogjac!!`](@ref). """ setlogjac!!(vi::AbstractVarInfo, logJ) = setacc!!(vi, LogJacobianAccumulator(logJ)) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 85d0e6066..9479e26f2 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -122,7 +122,7 @@ Evaluation in transformed space of course also works: ```jldoctest simplevarinfo-general julia> vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), true) -Transformed SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), VariableOrder = VariableOrderAccumulator(0, Dict{VarName, Int64}()))) +Transformed SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogJacobian = LogJacobianAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), VariableOrder = VariableOrderAccumulator(0, Dict{VarName, Int64}()))) julia> # (✓) Positive probability mass on negative numbers! getlogjoint_internal(last(DynamicPPL.evaluate!!(m, vi))) @@ -130,7 +130,7 @@ julia> # (✓) Positive probability mass on negative numbers! julia> # While if we forget to indicate that it's transformed: vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), false) -SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), VariableOrder = VariableOrderAccumulator(0, Dict{VarName, Int64}()))) +SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogJacobian = LogJacobianAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), VariableOrder = VariableOrderAccumulator(0, Dict{VarName, Int64}()))) julia> # (✓) No probability mass on negative numbers! getlogjoint_internal(last(DynamicPPL.evaluate!!(m, vi))) From 53a2f61c4ea85ec871f1f031b4e949a13b6f6a42 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 22 Jul 2025 01:18:24 +0100 Subject: [PATCH 07/15] Fix maths in LogJacobianAccumulator docstring --- docs/make.jl | 4 +++- src/default_accumulators.jl | 11 ++++++----- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index c69b72fb8..9c59cb06b 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -21,7 +21,9 @@ makedocs(; sitename="DynamicPPL", # The API index.html page is fairly large, and violates the default HTML page size # threshold of 200KiB, so we double that. - format=Documenter.HTML(; size_threshold=2^10 * 400), + format=Documenter.HTML(; + size_threshold=2^10 * 400, mathengine=Documenter.HTMLWriter.MathJax3() + ), modules=[DynamicPPL, Base.get_extension(DynamicPPL, :DynamicPPLMCMCChainsExt)], pages=[ "Home" => "index.md", "API" => "api.md", "Internals" => ["internals/varinfo.md"] diff --git a/src/default_accumulators.jl b/src/default_accumulators.jl index 1a61d68f0..bce973853 100644 --- a/src/default_accumulators.jl +++ b/src/default_accumulators.jl @@ -33,14 +33,15 @@ distribution to unconstrained space. !!! note This accumulator is only incremented if the variable is transformed by a - link function, i.e., if the VarInfo is linked (for this particular - variable). If the VarInfo is not linked, the log Jacobian term will be 0. + link function, i.e., if the VarInfo is linked (for the particular + variable that is currently being accumulated). If the variable is not + linked, the log Jacobian term will be 0. - In general, for the forward Jacobian `J` corresponding to the function `y = - f(x)`, + In general, for the forward Jacobian ``\\mathbf{J}`` corresponding to the + function ``\\mathbf{y} = f(\\mathbf{x})``, ```math - \\log(q(\\mathbf{y})) = \\log(p(\\mathbf{x})) - \\log\\(|\\mathbf{J}|\\) + \\log(q(\\mathbf{y})) = \\log(p(\\mathbf{x})) - \\log (|\\mathbf{J}|) ``` and correspondingly: From a47641c754367a01767f210cd1a1acb2871572f7 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 22 Jul 2025 01:21:00 +0100 Subject: [PATCH 08/15] Twiddle with a comment --- src/pointwise_logdensities.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index 14f8103e7..dea432022 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -74,8 +74,9 @@ function accumulate_assume!!( # T is the element type of the vectors that are the values of `acc.logps`. Usually # it's LogProbType. T = eltype(last(fieldtypes(eltype(acc.logps)))) - # Note that accumulating LogPrior ignores logjac (since we want to - # return log densities that don't depend on the linking status of the VarInfo). + # Note that in only accumulating LogPrior, we effectively ignore logjac + # (since we want to return log densities that don't depend on the + # linking status of the VarInfo). subacc = accumulate_assume!!(LogPriorAccumulator{T}(), val, logjac, vn, right) push!(acc, vn, subacc.logp) end From 10de51f893a2442904ec922a0842d3d9fe4a6767 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 22 Jul 2025 01:51:29 +0100 Subject: [PATCH 09/15] Add changelog --- HISTORY.md | 32 ++++++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index d367e9ad7..b59d8dd7f 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -32,20 +32,40 @@ Their semantics are the same as in Julia's `isapprox`; two values are equal if t You now need to explicitly pass a `VarInfo` argument to `check_model` and `check_model_and_trace`. Previously, these functions would generate a new VarInfo for you (using an optionally provided `rng`). -### Removal of `PriorContext` and `LikelihoodContext` - -A number of DynamicPPL's contexts have been removed, most notably `PriorContext` and `LikelihoodContext`. -Although these are not the only _exported_ contexts, we consider unlikely that anyone was using _other_ contexts manually: if you have a question about contexts _other_ than these, please continue reading the 'Internals' section below. +### Evaluating model log-probabilities in more detail Previously, during evaluation of a model, DynamicPPL only had the capability to store a _single_ log probability (`logp`) field. `DefaultContext`, `PriorContext`, and `LikelihoodContext` were used to control what this field represented: they would accumulate the log joint, log prior, or log likelihood, respectively. -Now, we have reworked DynamicPPL's `VarInfo` object such that it can track multiple log probabilities at once (see the 'Accumulators' section below). +In this version, we have overhauled this quite substantially. +The technical details of exactly _how_ this is done is covered in the 'Accumulators' section below, but the upshot is that the log prior, log likelihood, and log Jacobian terms (for any linked variables) are separately tracked. + +Specifically, you will want to use the following functions to access these log probabilities: + + - `getlogprior(varinfo)` to get the log prior. **Note:** This version introduces new, more consistent behaviour for this function, in that it always returns the log-prior of the values in the original, untransformed space, even if the `varinfo` has been linked. + - `getloglikelihood(varinfo)` to get the log likelihood. + - `getlogjoint(varinfo)` to get the log joint probability. **Note:** Similar to `getlogprior`, this function now always returns the log joint of the values in the original, untransformed space, even if the `varinfo` has been linked. + +If you are using linked VarInfos (e.g. if you are writing a sampler), you may find that you need to obtain the log probability of the variables in the transformed space. +To this end, you can use: + + - `getlogjac(varinfo)` to get the log Jacobian of the link transforms for any linked variables. + - `getlogprior_internal(varinfo)` to get the log prior of the variables in the transformed space. + - `getlogjoint_internal(varinfo)` to get the log joint probability of the variables in the transformed space. + +Since transformations only apply to random variables, the likelihood is unaffected by linking. + +### Removal of `PriorContext` and `LikelihoodContext` + +Following on from the above, a number of DynamicPPL's contexts have been removed, most notably `PriorContext` and `LikelihoodContext`. +Although these are not the only _exported_ contexts, we consider unlikely that anyone was using _other_ contexts manually: if you have a question about contexts _other_ than these, please continue reading the 'Internals' section below. + If you were evaluating a model with `PriorContext`, you can now just evaluate it with `DefaultContext`, and instead of calling `getlogp(varinfo)`, you can call `getlogprior(varinfo)` (and similarly for the likelihood). If you were constructing a `LogDensityFunction` with `PriorContext`, you can now stick to `DefaultContext`. `LogDensityFunction` now has an extra field, called `getlogdensity`, which represents a function that takes a `VarInfo` and returns the log density you want. -Thus, if you pass `getlogprior` as the value of this parameter, you will get the same behaviour as with `PriorContext`. +Thus, if you pass `getlogprior_internal` as the value of this parameter, you will get the same behaviour as with `PriorContext`. +(You should consider whether your use case needs the log prior in the transformed space, or the original space, and use (respectively) `getlogprior_internal` or `getlogprior` as needed.) The other case where one might use `PriorContext` was to use `@addlogprob!` to add to the log prior. Previously, this was accomplished by manually checking `__context__ isa DynamicPPL.PriorContext`. From e9bf50b0b6cf65539c6cd1e2a610698e2490477a Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 24 Jul 2025 08:35:22 +0100 Subject: [PATCH 10/15] Simplify accs with LogProbAccumulator --- src/default_accumulators.jl | 239 ++++++++++++++++-------------------- src/utils.jl | 7 ++ 2 files changed, 116 insertions(+), 130 deletions(-) diff --git a/src/default_accumulators.jl b/src/default_accumulators.jl index bce973853..7b7d8c8b5 100644 --- a/src/default_accumulators.jl +++ b/src/default_accumulators.jl @@ -1,5 +1,84 @@ """ - LogPriorAccumulator{T<:Real} <: AbstractAccumulator + LogProbAccumulator{T} <: AbstractAccumulator + +An abstract type for accumulators that hold a single scalar log probability value. + +Every subtype of `LogProbAccumulator` must implement +* A method for `logp` that returns the scalar log probability value that defines it. +* A single-argument constructor that takes a `logp` value. +* `accumulator_name`, `accumulate_assume!!`, and `accumulate_observe!!` methods like any + other accumulator. + +`LogProbAccumulator` provides implementations for other common functions, like convenience +constructors, `copy`, `show`, `==`, `isequal`, `hash`, `split`, and `combine`. + +This type has no great conceptual significance, it just reduces code duplication between +types like LogPriorAccumulator, LogJacobianAccumulator, and LogLikelihoodAccumulator. +""" +abstract type LogProbAccumulator{T<:Real} <: AbstractAccumulator end + +# The first of the below methods sets AccType{T}() = AccType(zero(T)) for any +# AccType <: LogProbAccumulator{T}. The second one sets LogProbType as the default eltype T +# when calling AccType(). +""" + LogProbAccumulator{T}() + +Create a new `LogProbAccumulator` accumulator with the log prior initialized to zero. +""" +(::Type{AccType})() where {T<:Real,AccType<:LogProbAccumulator{T}} = AccType(zero(T)) +(::Type{AccType})() where {AccType<:LogProbAccumulator} = AccType{LogProbType}() + +Base.copy(acc::LogProbAccumulator) = acc + +function Base.show(io::IO, acc::LogProbAccumulator) + return print(io, "$(repr(accumulator_name(acc)))($(repr(logp(acc)))))") +end + +# Note that == and isequal are different, and equality under the latter should imply +# equality of hashes. Both of the below implementations are also different from the default +# implementation for structs. +function Base.:(==)(acc1::LogProbAccumulator, acc2::LogProbAccumulator) + return accumulator_name(acc1) === accumulator_name(acc2) && logp(acc1) == logp(acc2) +end + +function Base.isequal(acc1::LogProbAccumulator, acc2::LogProbAccumulator) + return basetypeof(acc1) === basetypeof(acc2) && isequal(logp(acc1), logp(acc2)) +end + +Base.hash(acc::T, h::UInt) where {T<:LogProbAccumulator} = hash((T, logp(acc)), h) + +split(::AccType) where {T,AccType<:LogProbAccumulator{T}} = AccType(zero(T)) + +function combine(acc::LogProbAccumulator, acc2::LogProbAccumulator) + if basetypeof(acc) !== basetypeof(acc2) + msg = "Cannot combine accumulators of different types: $(basetypeof(acc)) and $(basetypeof(acc2))" + throw(ArgumentError(msg)) + end + return basetypeof(acc)(logp(acc) + logp(acc2)) +end + +function Base.:+(acc1::LogProbAccumulator, acc2::LogProbAccumulator) + if basetypeof(acc1) !== basetypeof(acc2) + msg = "Cannot add accumulators of different types: $(basetypeof(acc1)) and $(basetypeof(acc2))" + throw(ArgumentError(msg)) + end + return basetypeof(acc1)(logp(acc1) + logp(acc2)) +end + +Base.zero(acc::T) where {T<:LogProbAccumulator} = T(zero(logp(acc))) + +function Base.convert( + ::Type{AccType}, acc::LogProbAccumulator +) where {T,AccType<:LogProbAccumulator{T}} + return AccType(convert(T, logp(acc))) +end + +function convert_eltype(::Type{T}, acc::LogProbAccumulator) where {T} + return basetypeof(acc)(convert(T, logp(acc))) +end + +""" + LogPriorAccumulator{T<:Real} <: LogProbAccumulator{T} An accumulator that tracks the cumulative log prior during model execution. @@ -10,21 +89,22 @@ linked or not. # Fields $(TYPEDFIELDS) """ -struct LogPriorAccumulator{T<:Real} <: AbstractAccumulator +struct LogPriorAccumulator{T<:Real} <: LogProbAccumulator{T} "the scalar log prior value" logp::T end -""" - LogPriorAccumulator{T}() +logp(acc::LogPriorAccumulator) = acc.logp -Create a new `LogPriorAccumulator` accumulator with the log prior initialized to zero. -""" -LogPriorAccumulator{T}() where {T<:Real} = LogPriorAccumulator(zero(T)) -LogPriorAccumulator() = LogPriorAccumulator{LogProbType}() +accumulator_name(::Type{<:LogPriorAccumulator}) = :LogPrior + +function accumulate_assume!!(acc::LogPriorAccumulator, val, logjac, vn, right) + return acc + LogPriorAccumulator(logpdf(right, val)) +end +accumulate_observe!!(acc::LogPriorAccumulator, right, left, vn) = acc """ - LogJacobianAccumulator{T<:Real} <: AbstractAccumulator + LogJacobianAccumulator{T<:Real} <: LogProbAccumulator{T} An accumulator that tracks the cumulative log Jacobian (technically, log(abs(det(J)))) during model execution. Specifically, J refers to the @@ -53,39 +133,44 @@ distribution to unconstrained space. # Fields $(TYPEDFIELDS) """ -struct LogJacobianAccumulator{T<:Real} <: AbstractAccumulator +struct LogJacobianAccumulator{T<:Real} <: LogProbAccumulator{T} "the logabsdet of the link transform Jacobian" logJ::T end -""" - LogJacobianAccumulator{T}() +logp(acc::LogJacobianAccumulator) = acc.logJ -Create a new `LogJacobianAccumulator` accumulator with the log Jacobian initialized to zero. -""" -LogJacobianAccumulator{T}() where {T<:Real} = LogJacobianAccumulator(zero(T)) -LogJacobianAccumulator() = LogJacobianAccumulator{LogProbType}() +accumulator_name(::Type{<:LogJacobianAccumulator}) = :LogJacobian + +function accumulate_assume!!(acc::LogJacobianAccumulator, val, logjac, vn, right) + return acc + LogJacobianAccumulator(logjac) +end +accumulate_observe!!(acc::LogJacobianAccumulator, right, left, vn) = acc """ - LogLikelihoodAccumulator{T<:Real} <: AbstractAccumulator + LogLikelihoodAccumulator{T<:Real} <: LogProbAccumulator{T} An accumulator that tracks the cumulative log likelihood during model execution. # Fields $(TYPEDFIELDS) """ -struct LogLikelihoodAccumulator{T<:Real} <: AbstractAccumulator +struct LogLikelihoodAccumulator{T<:Real} <: LogProbAccumulator{T} "the scalar log likelihood value" logp::T end -""" - LogLikelihoodAccumulator{T}() +logp(acc::LogLikelihoodAccumulator) = acc.logp -Create a new `LogLikelihoodAccumulator` accumulator with the log likelihood initialized to zero. -""" -LogLikelihoodAccumulator{T}() where {T<:Real} = LogLikelihoodAccumulator(zero(T)) -LogLikelihoodAccumulator() = LogLikelihoodAccumulator{LogProbType}() +accumulator_name(::Type{<:LogLikelihoodAccumulator}) = :LogLikelihood + +accumulate_assume!!(acc::LogLikelihoodAccumulator, val, logjac, vn, right) = acc +function accumulate_observe!!(acc::LogLikelihoodAccumulator, right, left, vn) + # Note that it's important to use the loglikelihood function here, not logpdf, because + # they handle vectors differently: + # https://github.com/JuliaStats/Distributions.jl/issues/1972 + return acc + LogLikelihoodAccumulator(Distributions.loglikelihood(right, left)) +end """ VariableOrderAccumulator{T} <: AbstractAccumulator @@ -117,85 +202,32 @@ VariableOrderAccumulator{T}(n=zero(T)) where {T<:Integer} = VariableOrderAccumulator(n) = VariableOrderAccumulator{typeof(n)}(n) VariableOrderAccumulator() = VariableOrderAccumulator{Int}() -Base.copy(acc::LogPriorAccumulator) = acc -Base.copy(acc::LogJacobianAccumulator) = acc -Base.copy(acc::LogLikelihoodAccumulator) = acc function Base.copy(acc::VariableOrderAccumulator) return VariableOrderAccumulator(acc.num_produce, copy(acc.order)) end -function Base.show(io::IO, acc::LogPriorAccumulator) - return print(io, "LogPriorAccumulator($(repr(acc.logp)))") -end -function Base.show(io::IO, acc::LogJacobianAccumulator) - return print(io, "LogJacobianAccumulator($(repr(acc.logJ)))") -end -function Base.show(io::IO, acc::LogLikelihoodAccumulator) - return print(io, "LogLikelihoodAccumulator($(repr(acc.logp)))") -end function Base.show(io::IO, acc::VariableOrderAccumulator) return print( io, "VariableOrderAccumulator($(repr(acc.num_produce)), $(repr(acc.order)))" ) end -# Note that == and isequal are different, and equality under the latter should imply -# equality of hashes. Both of the below implementations are also different from the default -# implementation for structs. -Base.:(==)(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator) = acc1.logp == acc2.logp -function Base.:(==)(acc1::LogJacobianAccumulator, acc2::LogJacobianAccumulator) - return acc1.logJ == acc2.logJ -end -function Base.:(==)(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator) - return acc1.logp == acc2.logp -end function Base.:(==)(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumulator) return acc1.num_produce == acc2.num_produce && acc1.order == acc2.order end -function Base.isequal(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator) - return isequal(acc1.logp, acc2.logp) -end -function Base.isequal(acc1::LogJacobianAccumulator, acc2::LogJacobianAccumulator) - return isequal(acc1.logJ, acc2.logJ) -end -function Base.isequal(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator) - return isequal(acc1.logp, acc2.logp) -end function Base.isequal(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumulator) return isequal(acc1.num_produce, acc2.num_produce) && isequal(acc1.order, acc2.order) end -Base.hash(acc::LogPriorAccumulator, h::UInt) = hash((LogPriorAccumulator, acc.logp), h) -function Base.hash(acc::LogJacobianAccumulator, h::UInt) - return hash((LogJacobianAccumulator, acc.logJ), h) -end -function Base.hash(acc::LogLikelihoodAccumulator, h::UInt) - return hash((LogLikelihoodAccumulator, acc.logp), h) -end function Base.hash(acc::VariableOrderAccumulator, h::UInt) return hash((VariableOrderAccumulator, acc.num_produce, acc.order), h) end -accumulator_name(::Type{<:LogPriorAccumulator}) = :LogPrior -accumulator_name(::Type{<:LogJacobianAccumulator}) = :LogJacobian -accumulator_name(::Type{<:LogLikelihoodAccumulator}) = :LogLikelihood accumulator_name(::Type{<:VariableOrderAccumulator}) = :VariableOrder -split(::LogPriorAccumulator{T}) where {T} = LogPriorAccumulator(zero(T)) -split(::LogJacobianAccumulator{T}) where {T} = LogJacobianAccumulator(zero(T)) -split(::LogLikelihoodAccumulator{T}) where {T} = LogLikelihoodAccumulator(zero(T)) split(acc::VariableOrderAccumulator) = copy(acc) -function combine(acc::LogPriorAccumulator, acc2::LogPriorAccumulator) - return LogPriorAccumulator(acc.logp + acc2.logp) -end -function combine(acc::LogJacobianAccumulator, acc2::LogJacobianAccumulator) - return LogJacobianAccumulator(acc.logJ + acc2.logJ) -end -function combine(acc::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator) - return LogLikelihoodAccumulator(acc.logp + acc2.logp) -end function combine(acc::VariableOrderAccumulator, acc2::VariableOrderAccumulator) # Note that assumptions are not allowed in parallelised blocks, and thus the # dictionaries should be identical. @@ -204,60 +236,16 @@ function combine(acc::VariableOrderAccumulator, acc2::VariableOrderAccumulator) ) end -function Base.:+(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator) - return LogPriorAccumulator(acc1.logp + acc2.logp) -end -function Base.:+(acc1::LogJacobianAccumulator, acc2::LogJacobianAccumulator) - return LogJacobianAccumulator(acc1.logJ + acc2.logJ) -end -function Base.:+(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator) - return LogLikelihoodAccumulator(acc1.logp + acc2.logp) -end function increment(acc::VariableOrderAccumulator) return VariableOrderAccumulator(acc.num_produce + oneunit(acc.num_produce), acc.order) end -Base.zero(acc::LogPriorAccumulator) = LogPriorAccumulator(zero(acc.logp)) -Base.zero(acc::LogJacobianAccumulator) = LogJacobianAccumulator(zero(acc.logJ)) -Base.zero(acc::LogLikelihoodAccumulator) = LogLikelihoodAccumulator(zero(acc.logp)) - -function accumulate_assume!!(acc::LogPriorAccumulator, val, logjac, vn, right) - return acc + LogPriorAccumulator(logpdf(right, val)) -end -accumulate_observe!!(acc::LogPriorAccumulator, right, left, vn) = acc - -function accumulate_assume!!(acc::LogJacobianAccumulator, val, logjac, vn, right) - return acc + LogJacobianAccumulator(logjac) -end -accumulate_observe!!(acc::LogJacobianAccumulator, right, left, vn) = acc - -accumulate_assume!!(acc::LogLikelihoodAccumulator, val, logjac, vn, right) = acc -function accumulate_observe!!(acc::LogLikelihoodAccumulator, right, left, vn) - # Note that it's important to use the loglikelihood function here, not logpdf, because - # they handle vectors differently: - # https://github.com/JuliaStats/Distributions.jl/issues/1972 - return acc + LogLikelihoodAccumulator(Distributions.loglikelihood(right, left)) -end - function accumulate_assume!!(acc::VariableOrderAccumulator, val, logjac, vn, right) acc.order[vn] = acc.num_produce return acc end accumulate_observe!!(acc::VariableOrderAccumulator, right, left, vn) = increment(acc) -function Base.convert(::Type{LogPriorAccumulator{T}}, acc::LogPriorAccumulator) where {T} - return LogPriorAccumulator(convert(T, acc.logp)) -end -function Base.convert( - ::Type{LogJacobianAccumulator{T}}, acc::LogJacobianAccumulator -) where {T} - return LogJacobianAccumulator(convert(T, acc.logJ)) -end -function Base.convert( - ::Type{LogLikelihoodAccumulator{T}}, acc::LogLikelihoodAccumulator -) where {T} - return LogLikelihoodAccumulator(convert(T, acc.logp)) -end function Base.convert( ::Type{VariableOrderAccumulator{ElType,VnType}}, acc::VariableOrderAccumulator ) where {ElType,VnType} @@ -273,15 +261,6 @@ end # convert_eltype(::AbstractAccumulator, ::Type). This is because they are only used to # deal with dual number types of AD backends, which shouldn't concern VariableOrderAccumulator. This is # horribly hacky and should be fixed. See also comment in `unflatten` in `src/varinfo.jl`. -function convert_eltype(::Type{T}, acc::LogPriorAccumulator) where {T} - return LogPriorAccumulator(convert(T, acc.logp)) -end -function convert_eltype(::Type{T}, acc::LogJacobianAccumulator) where {T} - return LogJacobianAccumulator(convert(T, acc.logJ)) -end -function convert_eltype(::Type{T}, acc::LogLikelihoodAccumulator) where {T} - return LogLikelihoodAccumulator(convert(T, acc.logp)) -end function default_accumulators( ::Type{FloatT}=LogProbType, ::Type{IntT}=Int diff --git a/src/utils.jl b/src/utils.jl index 0f4d98b11..aa7fa50a8 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1332,3 +1332,10 @@ function group_varnames_by_symbol(vns::VarNameTuple) elements = map(collect, tuple((filter(vn -> getsym(vn) == s, vns) for s in syms)...)) return NamedTuple{syms}(elements) end + +""" + basetypeof(x) + +Return `typeof(x)` stripped of its type parameters. +""" +basetypeof(x::T) where {T} = Base.typename(T).wrapper From f983da5e7c9d54550a41bca8202bbc0d9773846a Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 24 Jul 2025 09:59:19 +0100 Subject: [PATCH 11/15] Replace + with accumulate for LogProbAccs --- src/abstract_varinfo.jl | 10 +++------- src/default_accumulators.jl | 16 +++++----------- test/accumulators.jl | 10 ++++------ 3 files changed, 12 insertions(+), 24 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 4a256ce0c..d74d26305 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -358,7 +358,7 @@ Add `logp` to the value of the log of the prior probability in `vi`. See also: [`accloglikelihood!!`](@ref), [`acclogp!!`](@ref), [`getlogprior`](@ref), [`setlogprior!!`](@ref). """ function acclogprior!!(vi::AbstractVarInfo, logp) - return map_accumulator!!(acc -> acc + LogPriorAccumulator(logp), vi, Val(:LogPrior)) + return map_accumulator!!(acc -> acclogp(acc, logp), vi, Val(:LogPrior)) end """ @@ -369,9 +369,7 @@ Add `logJ` to the value of the log Jacobian in `vi`. See also: [`getlogjac`](@ref), [`setlogjac!!`](@ref). """ function acclogjac!!(vi::AbstractVarInfo, logJ) - return map_accumulator!!( - acc -> acc + LogJacobianAccumulator(logJ), vi, Val(:LogJacobian) - ) + return map_accumulator!!(acc -> acclogp(acc, logJ), vi, Val(:LogJacobian)) end """ @@ -382,9 +380,7 @@ Add `logp` to the value of the log of the likelihood in `vi`. See also: [`accloglikelihood!!`](@ref), [`acclogp!!`](@ref), [`getloglikelihood`](@ref), [`setloglikelihood!!`](@ref). """ function accloglikelihood!!(vi::AbstractVarInfo, logp) - return map_accumulator!!( - acc -> acc + LogLikelihoodAccumulator(logp), vi, Val(:LogLikelihood) - ) + return map_accumulator!!(acc -> acclogp(acc, logp), vi, Val(:LogLikelihood)) end """ diff --git a/src/default_accumulators.jl b/src/default_accumulators.jl index 7b7d8c8b5..3c733ab5a 100644 --- a/src/default_accumulators.jl +++ b/src/default_accumulators.jl @@ -57,13 +57,7 @@ function combine(acc::LogProbAccumulator, acc2::LogProbAccumulator) return basetypeof(acc)(logp(acc) + logp(acc2)) end -function Base.:+(acc1::LogProbAccumulator, acc2::LogProbAccumulator) - if basetypeof(acc1) !== basetypeof(acc2) - msg = "Cannot add accumulators of different types: $(basetypeof(acc1)) and $(basetypeof(acc2))" - throw(ArgumentError(msg)) - end - return basetypeof(acc1)(logp(acc1) + logp(acc2)) -end +acclogp(acc::LogProbAccumulator, val) = basetypeof(acc)(logp(acc) + val) Base.zero(acc::T) where {T<:LogProbAccumulator} = T(zero(logp(acc))) @@ -99,7 +93,7 @@ logp(acc::LogPriorAccumulator) = acc.logp accumulator_name(::Type{<:LogPriorAccumulator}) = :LogPrior function accumulate_assume!!(acc::LogPriorAccumulator, val, logjac, vn, right) - return acc + LogPriorAccumulator(logpdf(right, val)) + return acclogp(acc, logpdf(right, val)) end accumulate_observe!!(acc::LogPriorAccumulator, right, left, vn) = acc @@ -143,7 +137,7 @@ logp(acc::LogJacobianAccumulator) = acc.logJ accumulator_name(::Type{<:LogJacobianAccumulator}) = :LogJacobian function accumulate_assume!!(acc::LogJacobianAccumulator, val, logjac, vn, right) - return acc + LogJacobianAccumulator(logjac) + return acclogp(acc, logjac) end accumulate_observe!!(acc::LogJacobianAccumulator, right, left, vn) = acc @@ -169,7 +163,7 @@ function accumulate_observe!!(acc::LogLikelihoodAccumulator, right, left, vn) # Note that it's important to use the loglikelihood function here, not logpdf, because # they handle vectors differently: # https://github.com/JuliaStats/Distributions.jl/issues/1972 - return acc + LogLikelihoodAccumulator(Distributions.loglikelihood(right, left)) + return acclogp(acc, Distributions.loglikelihood(right, left)) end """ @@ -208,7 +202,7 @@ end function Base.show(io::IO, acc::VariableOrderAccumulator) return print( - io, "VariableOrderAccumulator($(repr(acc.num_produce)), $(repr(acc.order)))" + io, "VariableOrderAccumulator($(string(acc.num_produce)), $(repr(acc.order)))" ) end diff --git a/test/accumulators.jl b/test/accumulators.jl index 506821c38..bbd9f772c 100644 --- a/test/accumulators.jl +++ b/test/accumulators.jl @@ -39,13 +39,11 @@ using DynamicPPL: end @testset "addition and incrementation" begin - @test LogPriorAccumulator(1.0f0) + LogPriorAccumulator(1.0f0) == - LogPriorAccumulator(2.0f0) - @test LogPriorAccumulator(1.0) + LogPriorAccumulator(1.0f0) == - LogPriorAccumulator(2.0) - @test LogLikelihoodAccumulator(1.0f0) + LogLikelihoodAccumulator(1.0f0) == + @test acclogp(LogPriorAccumulator(1.0f0), 1.0f0) == LogPriorAccumulator(2.0f0) + @test acclogp(LogPriorAccumulator(1.0), 1.0f0) == LogPriorAccumulator(2.0) + @test acclogp(LogLikelihoodAccumulator(1.0f0), 1.0f0) == LogLikelihoodAccumulator(2.0f0) - @test LogLikelihoodAccumulator(1.0) + LogLikelihoodAccumulator(1.0f0) == + @test acclogp(LogLikelihoodAccumulator(1.0), 1.0f0) == LogLikelihoodAccumulator(2.0) @test increment(VariableOrderAccumulator()) == VariableOrderAccumulator(1) @test increment(VariableOrderAccumulator{UInt8}()) == From bc51d622c17a6d6b3c9af3b6017c560ad3b8775b Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 24 Jul 2025 12:20:11 +0100 Subject: [PATCH 12/15] Introduce merge and subset for accs --- src/accumulators.jl | 71 +++++++++++++++++++++++++++++++++++++ src/default_accumulators.jl | 16 +++++++++ src/simple_varinfo.jl | 6 ++-- src/varinfo.jl | 5 +-- test/accumulators.jl | 24 +++++++++++++ 5 files changed, 118 insertions(+), 4 deletions(-) diff --git a/src/accumulators.jl b/src/accumulators.jl index 7d30b62f0..d3cfc7222 100644 --- a/src/accumulators.jl +++ b/src/accumulators.jl @@ -30,6 +30,13 @@ To be able to work with multi-threading, it should also implement: - `split(acc::T)` - `combine(acc::T, acc2::T)` +If two accumulators of the same type should be merged in some non-trivial way, other than +always keeping the second one over the first, `merge(acc1::T, acc2::T)` should be defined. + +If limiting the accumulator to a subset of `VarName`s is a meaningful operation and should +do something other than copy the original accumulator, then +`subset(acc::T, vns::AbstractVector{<:VarnName})` should be defined.` + See the documentation for each of these functions for more details. """ abstract type AbstractAccumulator end @@ -113,6 +120,24 @@ used by various AD backends, should implement a method for this function. """ convert_eltype(::Type, acc::AbstractAccumulator) = acc +""" + subset(acc::AbstractAccumulator, vns::AbstractVector{<:VarName}) + +Return a new accumulator that only contains the information for the `VarName`s in `vns`. + +By default returns a copy of `acc`. Subtypes should override this behaviour as needed. +""" +subset(acc::AbstractAccumulator, ::AbstractVector{<:VarName}) = copy(acc) + +""" + merge(acc1::AbstractAccumulator, acc2::AbstractAccumulator) + +Merge two accumulators of the same type. Returns a new accumulator of the same type. + +By default returns a copy of `acc2`. Subtypes should override this behaviour as needed. +""" +Base.merge(acc1::AbstractAccumulator, acc2::AbstractAccumulator) = copy(acc2) + """ AccumulatorTuple{N,T<:NamedTuple} @@ -158,6 +183,52 @@ function Base.convert(::Type{AccumulatorTuple{N,T}}, accs::AccumulatorTuple{N}) return AccumulatorTuple(convert(T, accs.nt)) end +""" + subset(at::AccumulatorTuple, vns::AbstractVector{<:VarName}) + +Replace each accumulator `acc` in `at` with `subset(acc, vns)`. +""" +function subset(at::AccumulatorTuple, vns::AbstractVector{<:VarName}) + return AccumulatorTuple(map(Base.Fix2(subset, vns), at.nt)) +end + +""" + _joint_keys(nt1::NamedTuple, nt2::NamedTuple) + +A helper function that returns three tuples of keys given two `NamedTuple`s: +The keys only in `nt1`, only in `nt2`, and in both, and in that order. + +Implemented as a generated function to enabled constant propagation of the result in `merge`. +""" +@generated function _joint_keys( + nt1::NamedTuple{names1}, nt2::NamedTuple{names2} +) where {names1,names2} + set_names1 = Set(names1) + set_names2 = Set(names2) + only_in_nt1 = tuple(setdiff(set_names1, set_names2)...) + only_in_nt2 = tuple(setdiff(set_names2, set_names1)...) + in_both = tuple(intersect(set_names1, set_names2)...) + return :($only_in_nt1, $only_in_nt2, $in_both) +end + +""" + merge(at1::AccumulatorTuple, at2::AccumulatorTuple) + +Merge two `AccumulatorTuple`s. + +For any `accumulator_name` that exists in both `at1` and `at2`, we call `merge` on the two +accumulators themselves. Other accumulators are copied. +""" +function Base.merge(at1::AccumulatorTuple, at2::AccumulatorTuple) + keys_in_at1, keys_in_at2, keys_in_both = _joint_keys(at1.nt, at2.nt) + accs_in_at1 = (getfield(at1.nt, key) for key in keys_in_at1) + accs_in_at2 = (getfield(at2.nt, key) for key in keys_in_at2) + accs_in_both = ( + merge(getfield(at1.nt, key), getfield(at2.nt, key)) for key in keys_in_both + ) + return AccumulatorTuple(accs_in_at1..., accs_in_at2..., accs_in_both...) +end + """ setacc!!(at::AccumulatorTuple, acc::AbstractAccumulator) diff --git a/src/default_accumulators.jl b/src/default_accumulators.jl index 3c733ab5a..0c7d1e9f9 100644 --- a/src/default_accumulators.jl +++ b/src/default_accumulators.jl @@ -266,3 +266,19 @@ function default_accumulators( VariableOrderAccumulator{IntT}(), ) end + +function subset(acc::VariableOrderAccumulator, vns::AbstractVector{<:VarName}) + order = filter(pair -> any(subsumes(vn, first(pair)) for vn in vns), acc.order) + return VariableOrderAccumulator(acc.num_produce, order) +end + +""" + merge(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumulator) + +Merge two `VariableOrderAccumulator` instances. + +The `num_produce` field of the return value is the `num_produce` of `acc2`. +""" +function Base.merge(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumulator) + return VariableOrderAccumulator(acc2.num_produce, merge(acc1.order, acc2.order)) +end diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 9479e26f2..e2f0d3ae4 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -417,7 +417,9 @@ Base.eltype(::SimpleOrThreadSafeSimple{<:Any,V}) where {V} = V # `subset` function subset(varinfo::SimpleVarInfo, vns::AbstractVector{<:VarName}) - return Accessors.@set varinfo.values = _subset(varinfo.values, vns) + return SimpleVarInfo( + _subset(varinfo.values, vns), subset(getaccs(varinfo), vns), varinfo.transformation + ) end function _subset(x::AbstractDict, vns::AbstractVector{VN}) where {VN<:VarName} @@ -454,7 +456,7 @@ _subset(x::VarNamedVector, vns) = subset(x, vns) # `merge` function Base.merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo) values = merge(varinfo_left.values, varinfo_right.values) - accs = copy(getaccs(varinfo_right)) + accs = merge(getaccs(varinfo_left), getaccs(varinfo_right)) transformation = merge_transformations( varinfo_left.transformation, varinfo_right.transformation ) diff --git a/src/varinfo.jl b/src/varinfo.jl index 7b819c58f..f3730426c 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -447,7 +447,7 @@ end function subset(varinfo::VarInfo, vns::AbstractVector{<:VarName}) metadata = subset(varinfo.metadata, vns) - return VarInfo(metadata, copy(varinfo.accs)) + return VarInfo(metadata, subset(getaccs(varinfo), vns)) end function subset(metadata::NamedTuple, vns::AbstractVector{<:VarName}) @@ -528,7 +528,8 @@ end function _merge(varinfo_left::VarInfo, varinfo_right::VarInfo) metadata = merge_metadata(varinfo_left.metadata, varinfo_right.metadata) - return VarInfo(metadata, copy(varinfo_right.accs)) + accs = merge(getaccs(varinfo_left), getaccs(varinfo_right)) + return VarInfo(metadata, accs) end function merge_metadata(vnv_left::VarNamedVector, vnv_right::VarNamedVector) diff --git a/test/accumulators.jl b/test/accumulators.jl index bbd9f772c..139992bc9 100644 --- a/test/accumulators.jl +++ b/test/accumulators.jl @@ -108,6 +108,30 @@ using DynamicPPL: @test accumulate_observe!!(VariableOrderAccumulator(1), right, left, vn) == VariableOrderAccumulator(2) end + + @testset "merge and subset" begin + @test merge(LogPriorAccumulator(1.0), LogPriorAccumulator(2.0)) == + LogPriorAccumulator(3.0) + @test merge(LogJacobianAccumulator(1.0), LogJacobianAccumulator(2.0)) == + LogJacobianAccumulator(3.0) + @test merge(LogLikelihoodAccumulator(1.0), LogLikelihoodAccumulator(2.0)) == + LogLikelihoodAccumulator(3.0) + + @test merge( + VariableOrderAccumulator(1, Dict{VarName,Int}()), + VariableOrderAccumulator(2, Dict{VarName,Int}()), + ) == VariableOrderAccumulator(2, Dict{VarName,Int}()) + @test merge( + VariableOrderAccumulator( + 2, Dict{VarName,Int}((@varname(a) => 1, @varname(b) => 2)) + ), + VariableOrderAccumulator( + 1, Dict{VarName,Int}((@varname(a) => 2, @varname(c) => 3)) + ), + ) == VariableOrderAccumulator( + 1, Dict{VarName,Int}((@varname(a) => 2, @varname(b) => 2, @varname(c) => 3)) + ) + end end @testset "accumulator tuples" begin From 2dc61a813ef848ca24fa14c52db8f86d7b04f293 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 24 Jul 2025 13:31:29 +0100 Subject: [PATCH 13/15] Improve acc tests --- src/DynamicPPL.jl | 1 + src/accumulators.jl | 10 ++-- test/accumulators.jl | 113 ++++++++++++++++++++++++++++++++++++++----- 3 files changed, 106 insertions(+), 18 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 15d39014e..c53681829 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -66,6 +66,7 @@ export AbstractVarInfo, setlogprior!!, setlogjac!!, setloglikelihood!!, + acclogp, acclogp!!, acclogjac!!, acclogprior!!, diff --git a/src/accumulators.jl b/src/accumulators.jl index d3cfc7222..fc911909e 100644 --- a/src/accumulators.jl +++ b/src/accumulators.jl @@ -203,11 +203,9 @@ Implemented as a generated function to enabled constant propagation of the resul @generated function _joint_keys( nt1::NamedTuple{names1}, nt2::NamedTuple{names2} ) where {names1,names2} - set_names1 = Set(names1) - set_names2 = Set(names2) - only_in_nt1 = tuple(setdiff(set_names1, set_names2)...) - only_in_nt2 = tuple(setdiff(set_names2, set_names1)...) - in_both = tuple(intersect(set_names1, set_names2)...) + only_in_nt1 = tuple(setdiff(names1, names2)...) + only_in_nt2 = tuple(setdiff(names2, names1)...) + in_both = tuple(intersect(names1, names2)...) return :($only_in_nt1, $only_in_nt2, $in_both) end @@ -226,7 +224,7 @@ function Base.merge(at1::AccumulatorTuple, at2::AccumulatorTuple) accs_in_both = ( merge(getfield(at1.nt, key), getfield(at2.nt, key)) for key in keys_in_both ) - return AccumulatorTuple(accs_in_at1..., accs_in_at2..., accs_in_both...) + return AccumulatorTuple(accs_in_at1..., accs_in_both..., accs_in_at2...) end """ diff --git a/test/accumulators.jl b/test/accumulators.jl index 139992bc9..d84fbf43d 100644 --- a/test/accumulators.jl +++ b/test/accumulators.jl @@ -109,13 +109,13 @@ using DynamicPPL: VariableOrderAccumulator(2) end - @testset "merge and subset" begin + @testset "merge" begin @test merge(LogPriorAccumulator(1.0), LogPriorAccumulator(2.0)) == - LogPriorAccumulator(3.0) + LogPriorAccumulator(2.0) @test merge(LogJacobianAccumulator(1.0), LogJacobianAccumulator(2.0)) == - LogJacobianAccumulator(3.0) + LogJacobianAccumulator(2.0) @test merge(LogLikelihoodAccumulator(1.0), LogLikelihoodAccumulator(2.0)) == - LogLikelihoodAccumulator(3.0) + LogLikelihoodAccumulator(2.0) @test merge( VariableOrderAccumulator(1, Dict{VarName,Int}()), @@ -132,6 +132,49 @@ using DynamicPPL: 1, Dict{VarName,Int}((@varname(a) => 2, @varname(b) => 2, @varname(c) => 3)) ) end + + @testset "subset" begin + @test subset(LogPriorAccumulator(1.0), VarName[]) == LogPriorAccumulator(1.0) + @test subset(LogJacobianAccumulator(1.0), VarName[]) == + LogJacobianAccumulator(1.0) + @test subset(LogLikelihoodAccumulator(1.0), VarName[]) == + LogLikelihoodAccumulator(1.0) + + @test subset( + VariableOrderAccumulator(1, Dict{VarName,Int}()), + VarName[@varname(a), @varname(b)], + ) == VariableOrderAccumulator(1, Dict{VarName,Int}()) + @test subset( + VariableOrderAccumulator( + 2, Dict{VarName,Int}((@varname(a) => 1, @varname(b) => 2)) + ), + VarName[@varname(a)], + ) == VariableOrderAccumulator(2, Dict{VarName,Int}((@varname(a) => 1))) + @test subset( + VariableOrderAccumulator( + 2, Dict{VarName,Int}((@varname(a) => 1, @varname(b) => 2)) + ), + VarName[], + ) == VariableOrderAccumulator(2, Dict{VarName,Int}()) + @test subset( + VariableOrderAccumulator( + 2, + Dict{VarName,Int}(( + @varname(a) => 1, + @varname(a.b.c) => 2, + @varname(a.b.c.d[1]) => 2, + @varname(b) => 3, + @varname(c[1]) => 4, + )), + ), + VarName[@varname(a.b), @varname(b)], + ) == VariableOrderAccumulator( + 2, + Dict{VarName,Int}(( + @varname(a.b.c) => 2, @varname(a.b.c.d[1]) => 2, @varname(b) => 3 + )), + ) + end end @testset "accumulator tuples" begin @@ -140,7 +183,7 @@ using DynamicPPL: lp_f32 = LogPriorAccumulator(1.0f0) ll_f64 = LogLikelihoodAccumulator(1.0) ll_f32 = LogLikelihoodAccumulator(1.0f0) - np_i64 = VariableOrderAccumulator(1) + vo_i64 = VariableOrderAccumulator(1) @testset "constructors" begin @test AccumulatorTuple(lp_f64, ll_f64) == AccumulatorTuple((lp_f64, ll_f64)) @@ -154,22 +197,22 @@ using DynamicPPL: end @testset "basic operations" begin - at_all64 = AccumulatorTuple(lp_f64, ll_f64, np_i64) + at_all64 = AccumulatorTuple(lp_f64, ll_f64, vo_i64) @test at_all64[:LogPrior] == lp_f64 @test at_all64[:LogLikelihood] == ll_f64 - @test at_all64[:VariableOrder] == np_i64 + @test at_all64[:VariableOrder] == vo_i64 - @test haskey(AccumulatorTuple(np_i64), Val(:VariableOrder)) - @test ~haskey(AccumulatorTuple(np_i64), Val(:LogPrior)) - @test length(AccumulatorTuple(lp_f64, ll_f64, np_i64)) == 3 + @test haskey(AccumulatorTuple(vo_i64), Val(:VariableOrder)) + @test ~haskey(AccumulatorTuple(vo_i64), Val(:LogPrior)) + @test length(AccumulatorTuple(lp_f64, ll_f64, vo_i64)) == 3 @test keys(at_all64) == (:LogPrior, :LogLikelihood, :VariableOrder) - @test collect(at_all64) == [lp_f64, ll_f64, np_i64] + @test collect(at_all64) == [lp_f64, ll_f64, vo_i64] # Replace the existing LogPriorAccumulator @test setacc!!(at_all64, lp_f32)[:LogPrior] == lp_f32 # Check that setacc!! didn't modify the original - @test at_all64 == AccumulatorTuple(lp_f64, ll_f64, np_i64) + @test at_all64 == AccumulatorTuple(lp_f64, ll_f64, vo_i64) # Add a new accumulator type. @test setacc!!(AccumulatorTuple(lp_f64), ll_f64) == AccumulatorTuple(lp_f64, ll_f64) @@ -197,6 +240,52 @@ using DynamicPPL: acc -> convert_eltype(Float64, acc), accs, Val(:LogLikelihood) ) == AccumulatorTuple(lp_f32, LogLikelihoodAccumulator(1.0)) end + + @testset "merge" begin + vo1 = VariableOrderAccumulator( + 1, Dict{VarName,Int}(@varname(a) => 1, @varname(b) => 1) + ) + vo2 = VariableOrderAccumulator( + 2, Dict{VarName,Int}(@varname(a) => 2, @varname(c) => 2) + ) + accs1 = AccumulatorTuple(lp_f64, ll_f64, vo1) + accs2 = AccumulatorTuple(lp_f32, vo2) + @test merge(accs1, accs2) == AccumulatorTuple( + ll_f64, + lp_f32, + VariableOrderAccumulator( + 2, + Dict{VarName,Int}(@varname(a) => 2, @varname(b) => 1, @varname(c) => 2), + ), + ) + @test merge(AccumulatorTuple(), accs1) == accs1 + @test merge(accs1, AccumulatorTuple()) == accs1 + @test merge(accs1, accs1) == accs1 + end + + @testset "subset" begin + accs = AccumulatorTuple( + lp_f64, + ll_f64, + VariableOrderAccumulator( + 1, + Dict{VarName,Int}( + @varname(a.b) => 1, @varname(a.b[1]) => 2, @varname(b) => 1 + ), + ), + ) + + @test subset(accs, VarName[]) == AccumulatorTuple( + lp_f64, ll_f64, VariableOrderAccumulator(1, Dict{VarName,Int}()) + ) + @test subset(accs, VarName[@varname(a)]) == AccumulatorTuple( + lp_f64, + ll_f64, + VariableOrderAccumulator( + 1, Dict{VarName,Int}(@varname(a.b) => 1, @varname(a.b[1]) => 2) + ), + ) + end end end From e852b5eef084057285f0eccaa2a74b68297a7f88 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 24 Jul 2025 14:53:18 +0100 Subject: [PATCH 14/15] Fix docstring typo. Co-authored-by: Penelope Yong --- src/accumulators.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/accumulators.jl b/src/accumulators.jl index fc911909e..d4fbdd88c 100644 --- a/src/accumulators.jl +++ b/src/accumulators.jl @@ -198,7 +198,7 @@ end A helper function that returns three tuples of keys given two `NamedTuple`s: The keys only in `nt1`, only in `nt2`, and in both, and in that order. -Implemented as a generated function to enabled constant propagation of the result in `merge`. +Implemented as a generated function to enable constant propagation of the result in `merge`. """ @generated function _joint_keys( nt1::NamedTuple{names1}, nt2::NamedTuple{names2} From 1aac709627db4abb700e8f10e6d2b564ee216516 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 24 Jul 2025 15:23:21 +0100 Subject: [PATCH 15/15] Fix merge --- src/abstract_varinfo.jl | 15 +-------------- src/default_accumulators.jl | 2 +- 2 files changed, 2 insertions(+), 15 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 0b4c887d0..caf6dc16c 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -361,17 +361,6 @@ function acclogprior!!(vi::AbstractVarInfo, logp) return map_accumulator!!(acc -> acclogp(acc, logp), vi, Val(:LogPrior)) end -""" - acclogjac!!(vi::AbstractVarInfo, logJ) - -Add `logJ` to the value of the log Jacobian in `vi`. - -See also: [`getlogjac`](@ref), [`setlogjac!!`](@ref). -""" -function acclogjac!!(vi::AbstractVarInfo, logJ) - return map_accumulator!!(acc -> acclogp(acc, logJ), vi, Val(:LogJacobian)) -end - """ acclogjac!!(vi::AbstractVarInfo, logjac) @@ -380,9 +369,7 @@ Add `logjac` to the value of the log Jacobian in `vi`. See also: [`getlogjac`](@ref), [`setlogjac!!`](@ref). """ function acclogjac!!(vi::AbstractVarInfo, logjac) - return map_accumulator!!( - acc -> acc + LogJacobianAccumulator(logjac), vi, Val(:LogJacobian) - ) + return map_accumulator!!(acc -> acclogp(acc, logjac), vi, Val(:LogJacobian)) end """ diff --git a/src/default_accumulators.jl b/src/default_accumulators.jl index 1b9400d65..8d51a8431 100644 --- a/src/default_accumulators.jl +++ b/src/default_accumulators.jl @@ -31,7 +31,7 @@ Create a new `LogProbAccumulator` accumulator with the log prior initialized to Base.copy(acc::LogProbAccumulator) = acc function Base.show(io::IO, acc::LogProbAccumulator) - return print(io, "$(repr(accumulator_name(acc)))($(repr(logp(acc)))))") + return print(io, "$(string(basetypeof(acc)))($(repr(logp(acc))))") end # Note that == and isequal are different, and equality under the latter should imply