diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index c4e7e6fba..4a13c9878 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -66,6 +66,7 @@ export AbstractVarInfo, setlogprior!!, setlogjac!!, setloglikelihood!!, + acclogp, acclogp!!, acclogjac!!, acclogprior!!, diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index cf5ce5706..caf6dc16c 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -358,7 +358,7 @@ Add `logp` to the value of the log of the prior probability in `vi`. See also: [`accloglikelihood!!`](@ref), [`acclogp!!`](@ref), [`getlogprior`](@ref), [`setlogprior!!`](@ref). """ function acclogprior!!(vi::AbstractVarInfo, logp) - return map_accumulator!!(acc -> acc + LogPriorAccumulator(logp), vi, Val(:LogPrior)) + return map_accumulator!!(acc -> acclogp(acc, logp), vi, Val(:LogPrior)) end """ @@ -369,9 +369,7 @@ Add `logjac` to the value of the log Jacobian in `vi`. See also: [`getlogjac`](@ref), [`setlogjac!!`](@ref). """ function acclogjac!!(vi::AbstractVarInfo, logjac) - return map_accumulator!!( - acc -> acc + LogJacobianAccumulator(logjac), vi, Val(:LogJacobian) - ) + return map_accumulator!!(acc -> acclogp(acc, logjac), vi, Val(:LogJacobian)) end """ @@ -382,9 +380,7 @@ Add `logp` to the value of the log of the likelihood in `vi`. See also: [`accloglikelihood!!`](@ref), [`acclogp!!`](@ref), [`getloglikelihood`](@ref), [`setloglikelihood!!`](@ref). """ function accloglikelihood!!(vi::AbstractVarInfo, logp) - return map_accumulator!!( - acc -> acc + LogLikelihoodAccumulator(logp), vi, Val(:LogLikelihood) - ) + return map_accumulator!!(acc -> acclogp(acc, logp), vi, Val(:LogLikelihood)) end """ diff --git a/src/accumulators.jl b/src/accumulators.jl index 0dcf9c7cf..b560307b7 100644 --- a/src/accumulators.jl +++ b/src/accumulators.jl @@ -30,6 +30,13 @@ To be able to work with multi-threading, it should also implement: - `split(acc::T)` - `combine(acc::T, acc2::T)` +If two accumulators of the same type should be merged in some non-trivial way, other than +always keeping the second one over the first, `merge(acc1::T, acc2::T)` should be defined. + +If limiting the accumulator to a subset of `VarName`s is a meaningful operation and should +do something other than copy the original accumulator, then +`subset(acc::T, vns::AbstractVector{<:VarnName})` should be defined.` + See the documentation for each of these functions for more details. """ abstract type AbstractAccumulator end @@ -113,6 +120,24 @@ used by various AD backends, should implement a method for this function. """ convert_eltype(::Type, acc::AbstractAccumulator) = acc +""" + subset(acc::AbstractAccumulator, vns::AbstractVector{<:VarName}) + +Return a new accumulator that only contains the information for the `VarName`s in `vns`. + +By default returns a copy of `acc`. Subtypes should override this behaviour as needed. +""" +subset(acc::AbstractAccumulator, ::AbstractVector{<:VarName}) = copy(acc) + +""" + merge(acc1::AbstractAccumulator, acc2::AbstractAccumulator) + +Merge two accumulators of the same type. Returns a new accumulator of the same type. + +By default returns a copy of `acc2`. Subtypes should override this behaviour as needed. +""" +Base.merge(acc1::AbstractAccumulator, acc2::AbstractAccumulator) = copy(acc2) + """ AccumulatorTuple{N,T<:NamedTuple} @@ -158,6 +183,50 @@ function Base.convert(::Type{AccumulatorTuple{N,T}}, accs::AccumulatorTuple{N}) return AccumulatorTuple(convert(T, accs.nt)) end +""" + subset(at::AccumulatorTuple, vns::AbstractVector{<:VarName}) + +Replace each accumulator `acc` in `at` with `subset(acc, vns)`. +""" +function subset(at::AccumulatorTuple, vns::AbstractVector{<:VarName}) + return AccumulatorTuple(map(Base.Fix2(subset, vns), at.nt)) +end + +""" + _joint_keys(nt1::NamedTuple, nt2::NamedTuple) + +A helper function that returns three tuples of keys given two `NamedTuple`s: +The keys only in `nt1`, only in `nt2`, and in both, and in that order. + +Implemented as a generated function to enable constant propagation of the result in `merge`. +""" +@generated function _joint_keys( + nt1::NamedTuple{names1}, nt2::NamedTuple{names2} +) where {names1,names2} + only_in_nt1 = tuple(setdiff(names1, names2)...) + only_in_nt2 = tuple(setdiff(names2, names1)...) + in_both = tuple(intersect(names1, names2)...) + return :($only_in_nt1, $only_in_nt2, $in_both) +end + +""" + merge(at1::AccumulatorTuple, at2::AccumulatorTuple) + +Merge two `AccumulatorTuple`s. + +For any `accumulator_name` that exists in both `at1` and `at2`, we call `merge` on the two +accumulators themselves. Other accumulators are copied. +""" +function Base.merge(at1::AccumulatorTuple, at2::AccumulatorTuple) + keys_in_at1, keys_in_at2, keys_in_both = _joint_keys(at1.nt, at2.nt) + accs_in_at1 = (getfield(at1.nt, key) for key in keys_in_at1) + accs_in_at2 = (getfield(at2.nt, key) for key in keys_in_at2) + accs_in_both = ( + merge(getfield(at1.nt, key), getfield(at2.nt, key)) for key in keys_in_both + ) + return AccumulatorTuple(accs_in_at1..., accs_in_both..., accs_in_at2...) +end + """ setacc!!(at::AccumulatorTuple, acc::AbstractAccumulator) diff --git a/src/default_accumulators.jl b/src/default_accumulators.jl index d503b3e64..8d51a8431 100644 --- a/src/default_accumulators.jl +++ b/src/default_accumulators.jl @@ -1,5 +1,78 @@ """ - LogPriorAccumulator{T<:Real} <: AbstractAccumulator + LogProbAccumulator{T} <: AbstractAccumulator + +An abstract type for accumulators that hold a single scalar log probability value. + +Every subtype of `LogProbAccumulator` must implement +* A method for `logp` that returns the scalar log probability value that defines it. +* A single-argument constructor that takes a `logp` value. +* `accumulator_name`, `accumulate_assume!!`, and `accumulate_observe!!` methods like any + other accumulator. + +`LogProbAccumulator` provides implementations for other common functions, like convenience +constructors, `copy`, `show`, `==`, `isequal`, `hash`, `split`, and `combine`. + +This type has no great conceptual significance, it just reduces code duplication between +types like LogPriorAccumulator, LogJacobianAccumulator, and LogLikelihoodAccumulator. +""" +abstract type LogProbAccumulator{T<:Real} <: AbstractAccumulator end + +# The first of the below methods sets AccType{T}() = AccType(zero(T)) for any +# AccType <: LogProbAccumulator{T}. The second one sets LogProbType as the default eltype T +# when calling AccType(). +""" + LogProbAccumulator{T}() + +Create a new `LogProbAccumulator` accumulator with the log prior initialized to zero. +""" +(::Type{AccType})() where {T<:Real,AccType<:LogProbAccumulator{T}} = AccType(zero(T)) +(::Type{AccType})() where {AccType<:LogProbAccumulator} = AccType{LogProbType}() + +Base.copy(acc::LogProbAccumulator) = acc + +function Base.show(io::IO, acc::LogProbAccumulator) + return print(io, "$(string(basetypeof(acc)))($(repr(logp(acc))))") +end + +# Note that == and isequal are different, and equality under the latter should imply +# equality of hashes. Both of the below implementations are also different from the default +# implementation for structs. +function Base.:(==)(acc1::LogProbAccumulator, acc2::LogProbAccumulator) + return accumulator_name(acc1) === accumulator_name(acc2) && logp(acc1) == logp(acc2) +end + +function Base.isequal(acc1::LogProbAccumulator, acc2::LogProbAccumulator) + return basetypeof(acc1) === basetypeof(acc2) && isequal(logp(acc1), logp(acc2)) +end + +Base.hash(acc::T, h::UInt) where {T<:LogProbAccumulator} = hash((T, logp(acc)), h) + +split(::AccType) where {T,AccType<:LogProbAccumulator{T}} = AccType(zero(T)) + +function combine(acc::LogProbAccumulator, acc2::LogProbAccumulator) + if basetypeof(acc) !== basetypeof(acc2) + msg = "Cannot combine accumulators of different types: $(basetypeof(acc)) and $(basetypeof(acc2))" + throw(ArgumentError(msg)) + end + return basetypeof(acc)(logp(acc) + logp(acc2)) +end + +acclogp(acc::LogProbAccumulator, val) = basetypeof(acc)(logp(acc) + val) + +Base.zero(acc::T) where {T<:LogProbAccumulator} = T(zero(logp(acc))) + +function Base.convert( + ::Type{AccType}, acc::LogProbAccumulator +) where {T,AccType<:LogProbAccumulator{T}} + return AccType(convert(T, logp(acc))) +end + +function convert_eltype(::Type{T}, acc::LogProbAccumulator) where {T} + return basetypeof(acc)(convert(T, logp(acc))) +end + +""" + LogPriorAccumulator{T<:Real} <: LogProbAccumulator{T} An accumulator that tracks the cumulative log prior during model execution. @@ -10,21 +83,22 @@ linked or not. # Fields $(TYPEDFIELDS) """ -struct LogPriorAccumulator{T<:Real} <: AbstractAccumulator +struct LogPriorAccumulator{T<:Real} <: LogProbAccumulator{T} "the scalar log prior value" logp::T end -""" - LogPriorAccumulator{T}() +logp(acc::LogPriorAccumulator) = acc.logp -Create a new `LogPriorAccumulator` accumulator with the log prior initialized to zero. -""" -LogPriorAccumulator{T}() where {T<:Real} = LogPriorAccumulator(zero(T)) -LogPriorAccumulator() = LogPriorAccumulator{LogProbType}() +accumulator_name(::Type{<:LogPriorAccumulator}) = :LogPrior + +function accumulate_assume!!(acc::LogPriorAccumulator, val, logjac, vn, right) + return acclogp(acc, logpdf(right, val)) +end +accumulate_observe!!(acc::LogPriorAccumulator, right, left, vn) = acc """ - LogJacobianAccumulator{T<:Real} <: AbstractAccumulator + LogJacobianAccumulator{T<:Real} <: LogProbAccumulator{T} An accumulator that tracks the cumulative log Jacobian (technically, log(abs(det(J)))) during model execution. Specifically, J refers to the @@ -53,39 +127,44 @@ distribution to unconstrained space. # Fields $(TYPEDFIELDS) """ -struct LogJacobianAccumulator{T<:Real} <: AbstractAccumulator +struct LogJacobianAccumulator{T<:Real} <: LogProbAccumulator{T} "the logabsdet of the link transform Jacobian" logjac::T end -""" - LogJacobianAccumulator{T}() +logp(acc::LogJacobianAccumulator) = acc.logjac -Create a new `LogJacobianAccumulator` accumulator with the log Jacobian initialized to zero. -""" -LogJacobianAccumulator{T}() where {T<:Real} = LogJacobianAccumulator(zero(T)) -LogJacobianAccumulator() = LogJacobianAccumulator{LogProbType}() +accumulator_name(::Type{<:LogJacobianAccumulator}) = :LogJacobian + +function accumulate_assume!!(acc::LogJacobianAccumulator, val, logjac, vn, right) + return acclogp(acc, logjac) +end +accumulate_observe!!(acc::LogJacobianAccumulator, right, left, vn) = acc """ - LogLikelihoodAccumulator{T<:Real} <: AbstractAccumulator + LogLikelihoodAccumulator{T<:Real} <: LogProbAccumulator{T} An accumulator that tracks the cumulative log likelihood during model execution. # Fields $(TYPEDFIELDS) """ -struct LogLikelihoodAccumulator{T<:Real} <: AbstractAccumulator +struct LogLikelihoodAccumulator{T<:Real} <: LogProbAccumulator{T} "the scalar log likelihood value" logp::T end -""" - LogLikelihoodAccumulator{T}() +logp(acc::LogLikelihoodAccumulator) = acc.logp -Create a new `LogLikelihoodAccumulator` accumulator with the log likelihood initialized to zero. -""" -LogLikelihoodAccumulator{T}() where {T<:Real} = LogLikelihoodAccumulator(zero(T)) -LogLikelihoodAccumulator() = LogLikelihoodAccumulator{LogProbType}() +accumulator_name(::Type{<:LogLikelihoodAccumulator}) = :LogLikelihood + +accumulate_assume!!(acc::LogLikelihoodAccumulator, val, logjac, vn, right) = acc +function accumulate_observe!!(acc::LogLikelihoodAccumulator, right, left, vn) + # Note that it's important to use the loglikelihood function here, not logpdf, because + # they handle vectors differently: + # https://github.com/JuliaStats/Distributions.jl/issues/1972 + return acclogp(acc, Distributions.loglikelihood(right, left)) +end """ VariableOrderAccumulator{T} <: AbstractAccumulator @@ -117,85 +196,32 @@ VariableOrderAccumulator{T}(n=zero(T)) where {T<:Integer} = VariableOrderAccumulator(n) = VariableOrderAccumulator{typeof(n)}(n) VariableOrderAccumulator() = VariableOrderAccumulator{Int}() -Base.copy(acc::LogPriorAccumulator) = acc -Base.copy(acc::LogJacobianAccumulator) = acc -Base.copy(acc::LogLikelihoodAccumulator) = acc function Base.copy(acc::VariableOrderAccumulator) return VariableOrderAccumulator(acc.num_produce, copy(acc.order)) end -function Base.show(io::IO, acc::LogPriorAccumulator) - return print(io, "LogPriorAccumulator($(repr(acc.logp)))") -end -function Base.show(io::IO, acc::LogJacobianAccumulator) - return print(io, "LogJacobianAccumulator($(repr(acc.logjac)))") -end -function Base.show(io::IO, acc::LogLikelihoodAccumulator) - return print(io, "LogLikelihoodAccumulator($(repr(acc.logp)))") -end function Base.show(io::IO, acc::VariableOrderAccumulator) return print( - io, "VariableOrderAccumulator($(repr(acc.num_produce)), $(repr(acc.order)))" + io, "VariableOrderAccumulator($(string(acc.num_produce)), $(repr(acc.order)))" ) end -# Note that == and isequal are different, and equality under the latter should imply -# equality of hashes. Both of the below implementations are also different from the default -# implementation for structs. -Base.:(==)(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator) = acc1.logp == acc2.logp -function Base.:(==)(acc1::LogJacobianAccumulator, acc2::LogJacobianAccumulator) - return acc1.logjac == acc2.logjac -end -function Base.:(==)(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator) - return acc1.logp == acc2.logp -end function Base.:(==)(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumulator) return acc1.num_produce == acc2.num_produce && acc1.order == acc2.order end -function Base.isequal(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator) - return isequal(acc1.logp, acc2.logp) -end -function Base.isequal(acc1::LogJacobianAccumulator, acc2::LogJacobianAccumulator) - return isequal(acc1.logjac, acc2.logjac) -end -function Base.isequal(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator) - return isequal(acc1.logp, acc2.logp) -end function Base.isequal(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumulator) return isequal(acc1.num_produce, acc2.num_produce) && isequal(acc1.order, acc2.order) end -Base.hash(acc::LogPriorAccumulator, h::UInt) = hash((LogPriorAccumulator, acc.logp), h) -function Base.hash(acc::LogJacobianAccumulator, h::UInt) - return hash((LogJacobianAccumulator, acc.logjac), h) -end -function Base.hash(acc::LogLikelihoodAccumulator, h::UInt) - return hash((LogLikelihoodAccumulator, acc.logp), h) -end function Base.hash(acc::VariableOrderAccumulator, h::UInt) return hash((VariableOrderAccumulator, acc.num_produce, acc.order), h) end -accumulator_name(::Type{<:LogPriorAccumulator}) = :LogPrior -accumulator_name(::Type{<:LogJacobianAccumulator}) = :LogJacobian -accumulator_name(::Type{<:LogLikelihoodAccumulator}) = :LogLikelihood accumulator_name(::Type{<:VariableOrderAccumulator}) = :VariableOrder -split(::LogPriorAccumulator{T}) where {T} = LogPriorAccumulator(zero(T)) -split(::LogJacobianAccumulator{T}) where {T} = LogJacobianAccumulator(zero(T)) -split(::LogLikelihoodAccumulator{T}) where {T} = LogLikelihoodAccumulator(zero(T)) split(acc::VariableOrderAccumulator) = copy(acc) -function combine(acc::LogPriorAccumulator, acc2::LogPriorAccumulator) - return LogPriorAccumulator(acc.logp + acc2.logp) -end -function combine(acc::LogJacobianAccumulator, acc2::LogJacobianAccumulator) - return LogJacobianAccumulator(acc.logjac + acc2.logjac) -end -function combine(acc::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator) - return LogLikelihoodAccumulator(acc.logp + acc2.logp) -end function combine(acc::VariableOrderAccumulator, acc2::VariableOrderAccumulator) # Note that assumptions are not allowed in parallelised blocks, and thus the # dictionaries should be identical. @@ -204,60 +230,16 @@ function combine(acc::VariableOrderAccumulator, acc2::VariableOrderAccumulator) ) end -function Base.:+(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator) - return LogPriorAccumulator(acc1.logp + acc2.logp) -end -function Base.:+(acc1::LogJacobianAccumulator, acc2::LogJacobianAccumulator) - return LogJacobianAccumulator(acc1.logjac + acc2.logjac) -end -function Base.:+(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator) - return LogLikelihoodAccumulator(acc1.logp + acc2.logp) -end function increment(acc::VariableOrderAccumulator) return VariableOrderAccumulator(acc.num_produce + oneunit(acc.num_produce), acc.order) end -Base.zero(acc::LogPriorAccumulator) = LogPriorAccumulator(zero(acc.logp)) -Base.zero(acc::LogJacobianAccumulator) = LogJacobianAccumulator(zero(acc.logjac)) -Base.zero(acc::LogLikelihoodAccumulator) = LogLikelihoodAccumulator(zero(acc.logp)) - -function accumulate_assume!!(acc::LogPriorAccumulator, val, logjac, vn, right) - return acc + LogPriorAccumulator(logpdf(right, val)) -end -accumulate_observe!!(acc::LogPriorAccumulator, right, left, vn) = acc - -function accumulate_assume!!(acc::LogJacobianAccumulator, val, logjac, vn, right) - return acc + LogJacobianAccumulator(logjac) -end -accumulate_observe!!(acc::LogJacobianAccumulator, right, left, vn) = acc - -accumulate_assume!!(acc::LogLikelihoodAccumulator, val, logjac, vn, right) = acc -function accumulate_observe!!(acc::LogLikelihoodAccumulator, right, left, vn) - # Note that it's important to use the loglikelihood function here, not logpdf, because - # they handle vectors differently: - # https://github.com/JuliaStats/Distributions.jl/issues/1972 - return acc + LogLikelihoodAccumulator(Distributions.loglikelihood(right, left)) -end - function accumulate_assume!!(acc::VariableOrderAccumulator, val, logjac, vn, right) acc.order[vn] = acc.num_produce return acc end accumulate_observe!!(acc::VariableOrderAccumulator, right, left, vn) = increment(acc) -function Base.convert(::Type{LogPriorAccumulator{T}}, acc::LogPriorAccumulator) where {T} - return LogPriorAccumulator(convert(T, acc.logp)) -end -function Base.convert( - ::Type{LogJacobianAccumulator{T}}, acc::LogJacobianAccumulator -) where {T} - return LogJacobianAccumulator(convert(T, acc.logjac)) -end -function Base.convert( - ::Type{LogLikelihoodAccumulator{T}}, acc::LogLikelihoodAccumulator -) where {T} - return LogLikelihoodAccumulator(convert(T, acc.logp)) -end function Base.convert( ::Type{VariableOrderAccumulator{ElType,VnType}}, acc::VariableOrderAccumulator ) where {ElType,VnType} @@ -273,15 +255,6 @@ end # convert_eltype(::AbstractAccumulator, ::Type). This is because they are only used to # deal with dual number types of AD backends, which shouldn't concern VariableOrderAccumulator. This is # horribly hacky and should be fixed. See also comment in `unflatten` in `src/varinfo.jl`. -function convert_eltype(::Type{T}, acc::LogPriorAccumulator) where {T} - return LogPriorAccumulator(convert(T, acc.logp)) -end -function convert_eltype(::Type{T}, acc::LogJacobianAccumulator) where {T} - return LogJacobianAccumulator(convert(T, acc.logjac)) -end -function convert_eltype(::Type{T}, acc::LogLikelihoodAccumulator) where {T} - return LogLikelihoodAccumulator(convert(T, acc.logp)) -end function default_accumulators( ::Type{FloatT}=LogProbType, ::Type{IntT}=Int @@ -293,3 +266,19 @@ function default_accumulators( VariableOrderAccumulator{IntT}(), ) end + +function subset(acc::VariableOrderAccumulator, vns::AbstractVector{<:VarName}) + order = filter(pair -> any(subsumes(vn, first(pair)) for vn in vns), acc.order) + return VariableOrderAccumulator(acc.num_produce, order) +end + +""" + merge(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumulator) + +Merge two `VariableOrderAccumulator` instances. + +The `num_produce` field of the return value is the `num_produce` of `acc2`. +""" +function Base.merge(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumulator) + return VariableOrderAccumulator(acc2.num_produce, merge(acc1.order, acc2.order)) +end diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 1538428fd..4997b4b8d 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -417,7 +417,9 @@ Base.eltype(::SimpleOrThreadSafeSimple{<:Any,V}) where {V} = V # `subset` function subset(varinfo::SimpleVarInfo, vns::AbstractVector{<:VarName}) - return Accessors.@set varinfo.values = _subset(varinfo.values, vns) + return SimpleVarInfo( + _subset(varinfo.values, vns), subset(getaccs(varinfo), vns), varinfo.transformation + ) end function _subset(x::AbstractDict, vns::AbstractVector{VN}) where {VN<:VarName} @@ -454,7 +456,7 @@ _subset(x::VarNamedVector, vns) = subset(x, vns) # `merge` function Base.merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo) values = merge(varinfo_left.values, varinfo_right.values) - accs = copy(getaccs(varinfo_right)) + accs = merge(getaccs(varinfo_left), getaccs(varinfo_right)) transformation = merge_transformations( varinfo_left.transformation, varinfo_right.transformation ) diff --git a/src/utils.jl b/src/utils.jl index af2891a2b..d3371271f 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1139,3 +1139,10 @@ function group_varnames_by_symbol(vns::VarNameTuple) elements = map(collect, tuple((filter(vn -> getsym(vn) == s, vns) for s in syms)...)) return NamedTuple{syms}(elements) end + +""" + basetypeof(x) + +Return `typeof(x)` stripped of its type parameters. +""" +basetypeof(x::T) where {T} = Base.typename(T).wrapper diff --git a/src/varinfo.jl b/src/varinfo.jl index 101eb6d50..b364f5bcc 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -447,7 +447,7 @@ end function subset(varinfo::VarInfo, vns::AbstractVector{<:VarName}) metadata = subset(varinfo.metadata, vns) - return VarInfo(metadata, copy(varinfo.accs)) + return VarInfo(metadata, subset(getaccs(varinfo), vns)) end function subset(metadata::NamedTuple, vns::AbstractVector{<:VarName}) @@ -528,7 +528,8 @@ end function _merge(varinfo_left::VarInfo, varinfo_right::VarInfo) metadata = merge_metadata(varinfo_left.metadata, varinfo_right.metadata) - return VarInfo(metadata, copy(varinfo_right.accs)) + accs = merge(getaccs(varinfo_left), getaccs(varinfo_right)) + return VarInfo(metadata, accs) end function merge_metadata(vnv_left::VarNamedVector, vnv_right::VarNamedVector) diff --git a/test/accumulators.jl b/test/accumulators.jl index 506821c38..d84fbf43d 100644 --- a/test/accumulators.jl +++ b/test/accumulators.jl @@ -39,13 +39,11 @@ using DynamicPPL: end @testset "addition and incrementation" begin - @test LogPriorAccumulator(1.0f0) + LogPriorAccumulator(1.0f0) == - LogPriorAccumulator(2.0f0) - @test LogPriorAccumulator(1.0) + LogPriorAccumulator(1.0f0) == - LogPriorAccumulator(2.0) - @test LogLikelihoodAccumulator(1.0f0) + LogLikelihoodAccumulator(1.0f0) == + @test acclogp(LogPriorAccumulator(1.0f0), 1.0f0) == LogPriorAccumulator(2.0f0) + @test acclogp(LogPriorAccumulator(1.0), 1.0f0) == LogPriorAccumulator(2.0) + @test acclogp(LogLikelihoodAccumulator(1.0f0), 1.0f0) == LogLikelihoodAccumulator(2.0f0) - @test LogLikelihoodAccumulator(1.0) + LogLikelihoodAccumulator(1.0f0) == + @test acclogp(LogLikelihoodAccumulator(1.0), 1.0f0) == LogLikelihoodAccumulator(2.0) @test increment(VariableOrderAccumulator()) == VariableOrderAccumulator(1) @test increment(VariableOrderAccumulator{UInt8}()) == @@ -110,6 +108,73 @@ using DynamicPPL: @test accumulate_observe!!(VariableOrderAccumulator(1), right, left, vn) == VariableOrderAccumulator(2) end + + @testset "merge" begin + @test merge(LogPriorAccumulator(1.0), LogPriorAccumulator(2.0)) == + LogPriorAccumulator(2.0) + @test merge(LogJacobianAccumulator(1.0), LogJacobianAccumulator(2.0)) == + LogJacobianAccumulator(2.0) + @test merge(LogLikelihoodAccumulator(1.0), LogLikelihoodAccumulator(2.0)) == + LogLikelihoodAccumulator(2.0) + + @test merge( + VariableOrderAccumulator(1, Dict{VarName,Int}()), + VariableOrderAccumulator(2, Dict{VarName,Int}()), + ) == VariableOrderAccumulator(2, Dict{VarName,Int}()) + @test merge( + VariableOrderAccumulator( + 2, Dict{VarName,Int}((@varname(a) => 1, @varname(b) => 2)) + ), + VariableOrderAccumulator( + 1, Dict{VarName,Int}((@varname(a) => 2, @varname(c) => 3)) + ), + ) == VariableOrderAccumulator( + 1, Dict{VarName,Int}((@varname(a) => 2, @varname(b) => 2, @varname(c) => 3)) + ) + end + + @testset "subset" begin + @test subset(LogPriorAccumulator(1.0), VarName[]) == LogPriorAccumulator(1.0) + @test subset(LogJacobianAccumulator(1.0), VarName[]) == + LogJacobianAccumulator(1.0) + @test subset(LogLikelihoodAccumulator(1.0), VarName[]) == + LogLikelihoodAccumulator(1.0) + + @test subset( + VariableOrderAccumulator(1, Dict{VarName,Int}()), + VarName[@varname(a), @varname(b)], + ) == VariableOrderAccumulator(1, Dict{VarName,Int}()) + @test subset( + VariableOrderAccumulator( + 2, Dict{VarName,Int}((@varname(a) => 1, @varname(b) => 2)) + ), + VarName[@varname(a)], + ) == VariableOrderAccumulator(2, Dict{VarName,Int}((@varname(a) => 1))) + @test subset( + VariableOrderAccumulator( + 2, Dict{VarName,Int}((@varname(a) => 1, @varname(b) => 2)) + ), + VarName[], + ) == VariableOrderAccumulator(2, Dict{VarName,Int}()) + @test subset( + VariableOrderAccumulator( + 2, + Dict{VarName,Int}(( + @varname(a) => 1, + @varname(a.b.c) => 2, + @varname(a.b.c.d[1]) => 2, + @varname(b) => 3, + @varname(c[1]) => 4, + )), + ), + VarName[@varname(a.b), @varname(b)], + ) == VariableOrderAccumulator( + 2, + Dict{VarName,Int}(( + @varname(a.b.c) => 2, @varname(a.b.c.d[1]) => 2, @varname(b) => 3 + )), + ) + end end @testset "accumulator tuples" begin @@ -118,7 +183,7 @@ using DynamicPPL: lp_f32 = LogPriorAccumulator(1.0f0) ll_f64 = LogLikelihoodAccumulator(1.0) ll_f32 = LogLikelihoodAccumulator(1.0f0) - np_i64 = VariableOrderAccumulator(1) + vo_i64 = VariableOrderAccumulator(1) @testset "constructors" begin @test AccumulatorTuple(lp_f64, ll_f64) == AccumulatorTuple((lp_f64, ll_f64)) @@ -132,22 +197,22 @@ using DynamicPPL: end @testset "basic operations" begin - at_all64 = AccumulatorTuple(lp_f64, ll_f64, np_i64) + at_all64 = AccumulatorTuple(lp_f64, ll_f64, vo_i64) @test at_all64[:LogPrior] == lp_f64 @test at_all64[:LogLikelihood] == ll_f64 - @test at_all64[:VariableOrder] == np_i64 + @test at_all64[:VariableOrder] == vo_i64 - @test haskey(AccumulatorTuple(np_i64), Val(:VariableOrder)) - @test ~haskey(AccumulatorTuple(np_i64), Val(:LogPrior)) - @test length(AccumulatorTuple(lp_f64, ll_f64, np_i64)) == 3 + @test haskey(AccumulatorTuple(vo_i64), Val(:VariableOrder)) + @test ~haskey(AccumulatorTuple(vo_i64), Val(:LogPrior)) + @test length(AccumulatorTuple(lp_f64, ll_f64, vo_i64)) == 3 @test keys(at_all64) == (:LogPrior, :LogLikelihood, :VariableOrder) - @test collect(at_all64) == [lp_f64, ll_f64, np_i64] + @test collect(at_all64) == [lp_f64, ll_f64, vo_i64] # Replace the existing LogPriorAccumulator @test setacc!!(at_all64, lp_f32)[:LogPrior] == lp_f32 # Check that setacc!! didn't modify the original - @test at_all64 == AccumulatorTuple(lp_f64, ll_f64, np_i64) + @test at_all64 == AccumulatorTuple(lp_f64, ll_f64, vo_i64) # Add a new accumulator type. @test setacc!!(AccumulatorTuple(lp_f64), ll_f64) == AccumulatorTuple(lp_f64, ll_f64) @@ -175,6 +240,52 @@ using DynamicPPL: acc -> convert_eltype(Float64, acc), accs, Val(:LogLikelihood) ) == AccumulatorTuple(lp_f32, LogLikelihoodAccumulator(1.0)) end + + @testset "merge" begin + vo1 = VariableOrderAccumulator( + 1, Dict{VarName,Int}(@varname(a) => 1, @varname(b) => 1) + ) + vo2 = VariableOrderAccumulator( + 2, Dict{VarName,Int}(@varname(a) => 2, @varname(c) => 2) + ) + accs1 = AccumulatorTuple(lp_f64, ll_f64, vo1) + accs2 = AccumulatorTuple(lp_f32, vo2) + @test merge(accs1, accs2) == AccumulatorTuple( + ll_f64, + lp_f32, + VariableOrderAccumulator( + 2, + Dict{VarName,Int}(@varname(a) => 2, @varname(b) => 1, @varname(c) => 2), + ), + ) + @test merge(AccumulatorTuple(), accs1) == accs1 + @test merge(accs1, AccumulatorTuple()) == accs1 + @test merge(accs1, accs1) == accs1 + end + + @testset "subset" begin + accs = AccumulatorTuple( + lp_f64, + ll_f64, + VariableOrderAccumulator( + 1, + Dict{VarName,Int}( + @varname(a.b) => 1, @varname(a.b[1]) => 2, @varname(b) => 1 + ), + ), + ) + + @test subset(accs, VarName[]) == AccumulatorTuple( + lp_f64, ll_f64, VariableOrderAccumulator(1, Dict{VarName,Int}()) + ) + @test subset(accs, VarName[@varname(a)]) == AccumulatorTuple( + lp_f64, + ll_f64, + VariableOrderAccumulator( + 1, Dict{VarName,Int}(@varname(a.b) => 1, @varname(a.b[1]) => 2) + ), + ) + end end end