From 28e5ba4174994040a954d35f8c43db965bf614da Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 8 Jul 2025 23:20:09 +0100 Subject: [PATCH 01/14] DebugContext -> DebugAccumulator --- src/accumulators.jl | 21 ++-- src/contexts.jl | 2 +- src/debug_utils.jl | 277 ++++++++++++++++---------------------------- test/debug_utils.jl | 86 +++++++------- 4 files changed, 161 insertions(+), 225 deletions(-) diff --git a/src/accumulators.jl b/src/accumulators.jl index 10a988ae5..e0bea91b9 100644 --- a/src/accumulators.jl +++ b/src/accumulators.jl @@ -53,10 +53,11 @@ function accumulate_observe!! end Update `acc` in a `tilde_assume!!` call. Returns the updated `acc`. -`vn` is the name of the variable being assumed, `val` is the value of the variable, and -`right` is the distribution on the RHS of the tilde statement. `logjac` is the log -determinant of the Jacobian of the transformation that was done to convert the value of `vn` -as it was given (e.g. by sampler operating in linked space) to `val`. +`vn` is the name of the variable being assumed, `val` is the value of the variable (in the +original, unlinked space), and `right` is the distribution on the RHS of the tilde +statement. `logjac` is the log determinant of the Jacobian of the transformation that was +done to convert the value of `vn` as it was given to `val`: for example, if the sampler is +operating in linked (Euclidean) space, then logjac will be nonzero. `accumulate_assume!!` may mutate `acc`, but not any of the other arguments. @@ -71,16 +72,16 @@ Return a new accumulator like `acc` but empty. The precise meaning of "empty" is that that the returned value should be such that `combine(acc, split(acc))` is equal to `acc`. This is used in the context of multi-threading -where different threads may accumulate independently and the results are the combined. +where different threads may accumulate independently and the results are then combined. See also: [`combine`](@ref) """ function split end """ - combine(acc::AbstractAccumulator, acc2::AbstractAccumulator) + combine(acc::TAcc, acc2::TAcc) where {TAcc<:AbstractAccumulator} -Combine two accumulators of the same type. Returns a new accumulator. +Combine two accumulators of the same type. Returns a new accumulator of the same type. See also: [`split`](@ref) """ @@ -126,8 +127,10 @@ end AccumulatorTuple(accs::Vararg{AbstractAccumulator}) = AccumulatorTuple(accs) AccumulatorTuple(nt::NamedTuple) = AccumulatorTuple(tuple(nt...)) -# When showing with text/plain, leave out information about the wrapper AccumulatorTuple. -Base.show(io::IO, mime::MIME"text/plain", at::AccumulatorTuple) = show(io, mime, at.nt) +# When showing with text/plain, leave out type information about the wrapper AccumulatorTuple. +function Base.show(io::IO, mime::MIME"text/plain", at::AccumulatorTuple) + return "AccumulatorTuple(" * show(io, mime, at.nt) * ")" +end Base.getindex(at::AccumulatorTuple, idx) = at.nt[idx] Base.length(::AccumulatorTuple{N}) where {N} = N Base.iterate(at::AccumulatorTuple, args...) = iterate(at.nt, args...) diff --git a/src/contexts.jl b/src/contexts.jl index addadfa1a..bafa3ed26 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -472,7 +472,7 @@ end """ conditioned(context::AbstractContext) -Return `NamedTuple` of values that are conditioned on under context`. +Return a `Dict{VarName,Any}` of the values that are conditioned on under `context`. Note that this will recursively traverse the context stack and return a merged version of the condition values. diff --git a/src/debug_utils.jl b/src/debug_utils.jl index 4343ce8ac..d731e497c 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -49,7 +49,7 @@ end function show_right(io::IO, d::Distribution) pnames = fieldnames(typeof(d)) - uml, namevals = Distributions._use_multline_show(d, pnames) + _, namevals = Distributions._use_multline_show(d, pnames) return Distributions.show_oneline(io, d, namevals) end @@ -76,7 +76,6 @@ Base.@kwdef struct AssumeStmt <: Stmt varname right value - varinfo = nothing end function Base.show(io::IO, stmt::AssumeStmt) @@ -88,21 +87,30 @@ function Base.show(io::IO, stmt::AssumeStmt) print(io, " ") print(io, RESULT_SYMBOL) print(io, " ") - return print(io, stmt.value) + print(io, stmt.value) + return nothing end Base.@kwdef struct ObserveStmt <: Stmt - left + varname right - varinfo = nothing + value end function Base.show(io::IO, stmt::ObserveStmt) io = add_io_context(io) - print(io, "observe: ") - show_right(io, stmt.left) + print(io, " observe: ") + if stmt.varname === nothing + print(io, stmt.value) + else + show_varname(io, stmt.varname) + print(io, " (= ") + print(io, stmt.value) + print(io, ")") + end print(io, " ~ ") - return show_right(io, stmt.right) + show_right(io, stmt.right) + return nothing end # Some utility methods for extracting information from a trace. @@ -124,192 +132,118 @@ distributions_in_stmt(stmt::AssumeStmt) = [stmt.right] distributions_in_stmt(stmt::ObserveStmt) = [stmt.right] """ - DebugContext <: AbstractContext + DebugAccumulator <: AbstractAccumulator -A context used for checking validity of a model. +An accumulator which captures tilde-statements inside a model and attempts to catch +errors in the model. # Fields -$(FIELDS) +$(TYPEDFIELDS) """ -struct DebugContext{C<:AbstractContext} <: AbstractContext - "context used for running the model" - context::C +struct DebugAccumulator <: AbstractAccumulator "mapping from varnames to the number of times they have been seen" varnames_seen::OrderedDict{VarName,Int} "tilde statements that have been executed" statements::Vector{Stmt} - "whether to throw an error if we encounter warnings" + "whether to throw an error if we encounter errors in the model" error_on_failure::Bool - "whether to record the tilde statements" - record_statements::Bool - "whether to record the varinfo in every tilde statement" - record_varinfo::Bool end -function DebugContext( - context::AbstractContext=DefaultContext(); - varnames_seen=OrderedDict{VarName,Int}(), - statements=Vector{Stmt}(), - error_on_failure=false, - record_statements=true, - record_varinfo=false, -) - return DebugContext( - context, - varnames_seen, - statements, - error_on_failure, - record_statements, - record_varinfo, - ) +function DebugAccumulator(error_on_failure=false) + return DebugAccumulator(OrderedDict{VarName,Int}(), Vector{Stmt}(), error_on_failure) end -DynamicPPL.NodeTrait(::DebugContext) = DynamicPPL.IsParent() -DynamicPPL.childcontext(context::DebugContext) = context.context -function DynamicPPL.setchildcontext(context::DebugContext, child) - Accessors.@set context.context = child +const _DEBUG_ACC_NAME = :Debug +DynamicPPL.accumulator_name(::Type{<:DebugAccumulator}) = _DEBUG_ACC_NAME + +function split(acc::DebugAccumulator) + return DebugAccumulator( + OrderedDict{VarName,Int}(), Vector{Stmt}(), acc.error_on_failure + ) +end +function combine(acc1::DebugAccumulator, acc2::DebugAccumulator) + return DebugAccumulator( + merge(acc1.varnames_seen, acc2.varnames_seen), + vcat(acc1.statements, acc2.statements), + acc1.error_on_failure || acc2.error_on_failure, + ) end -function record_varname!(context::DebugContext, varname::VarName, dist) - prefixed_varname = DynamicPPL.prefix(context, varname) - if haskey(context.varnames_seen, prefixed_varname) - if context.error_on_failure - error("varname $prefixed_varname used multiple times in model") +function record_varname!(acc::DebugAccumulator, varname::VarName, dist) + if haskey(acc.varnames_seen, varname) + if acc.error_on_failure + error("varname $varname used multiple times in model") else - @warn "varname $prefixed_varname used multiple times in model" + @warn "varname $varname used multiple times in model" end - context.varnames_seen[prefixed_varname] += 1 + acc.varnames_seen[varname] += 1 else # We need to check: # 1. Does this `varname` subsume any of the other keys. # 2. Does any of the other keys subsume `varname`. - vns = collect(keys(context.varnames_seen)) + vns = collect(keys(acc.varnames_seen)) # Is `varname` subsumed by any of the other keys? - idx_parent = findfirst(Base.Fix2(subsumes, prefixed_varname), vns) + idx_parent = findfirst(Base.Fix2(subsumes, varname), vns) if idx_parent !== nothing varname_parent = vns[idx_parent] - if context.error_on_failure + if acc.error_on_failure error( - "varname $(varname_parent) used multiple times in model (subsumes $prefixed_varname)", + "varname $(varname_parent) used multiple times in model (subsumes $varname)", ) else - @warn "varname $(varname_parent) used multiple times in model (subsumes $prefixed_varname)" + @warn "varname $(varname_parent) used multiple times in model (subsumes $varname)" end # Update count of parent. - context.varnames_seen[varname_parent] += 1 + acc.varnames_seen[varname_parent] += 1 else # Does `varname` subsume any of the other keys? - idx_child = findfirst(Base.Fix1(subsumes, prefixed_varname), vns) + idx_child = findfirst(Base.Fix1(subsumes, varname), vns) if idx_child !== nothing varname_child = vns[idx_child] - if context.error_on_failure + if acc.error_on_failure error( - "varname $(varname_child) used multiple times in model (subsumed by $prefixed_varname)", + "varname $(varname_child) used multiple times in model (subsumed by $varname)", ) else - @warn "varname $(varname_child) used multiple times in model (subsumed by $prefixed_varname)" + @warn "varname $(varname_child) used multiple times in model (subsumed by $varname)" end # Update count of child. - context.varnames_seen[varname_child] += 1 + acc.varnames_seen[varname_child] += 1 end end - context.varnames_seen[prefixed_varname] = 1 + acc.varnames_seen[varname] = 1 end end -_has_missings(x) = ismissing(x) -function _has_missings(x::AbstractArray) - # Can't just use `any` because `x` might contain `undef`. - for i in eachindex(x) - if isassigned(x, i) && _has_missings(x[i]) - return true - end - end - return false -end - _has_nans(x::NamedTuple) = any(_has_nans, x) _has_nans(x::AbstractArray) = any(_has_nans, x) _has_nans(x) = isnan(x) -# assume -function record_pre_tilde_assume!(context::DebugContext, vn, dist, varinfo) - record_varname!(context, vn, dist) - return nothing -end - -function record_post_tilde_assume!(context::DebugContext, vn, dist, value, varinfo) - stmt = AssumeStmt(; - varname=vn, - right=dist, - value=value, - varinfo=context.record_varinfo ? varinfo : nothing, - ) - if context.record_statements - push!(context.statements, stmt) - end - return nothing -end - -function DynamicPPL.tilde_assume(context::DebugContext, right, vn, vi) - record_pre_tilde_assume!(context, vn, right, vi) - value, vi = DynamicPPL.tilde_assume(childcontext(context), right, vn, vi) - record_post_tilde_assume!(context, vn, right, value, vi) - return value, vi -end -function DynamicPPL.tilde_assume( - rng::Random.AbstractRNG, context::DebugContext, sampler, right, vn, vi +function DynamicPPL.accumulate_assume!!( + acc::DebugAccumulator, val, _logjac, vn::VarName, right::Distribution ) - record_pre_tilde_assume!(context, vn, right, vi) - value, vi = DynamicPPL.tilde_assume(rng, childcontext(context), sampler, right, vn, vi) - record_post_tilde_assume!(context, vn, right, value, vi) - return value, vi + record_varname!(acc, vn, right) + stmt = AssumeStmt(; varname=vn, right=right, value=val) + push!(acc.statements, stmt) + return acc end -# observe -function record_pre_tilde_observe!(context::DebugContext, left, dist, varinfo) - # Check for `missing`s; these should not end up here. - if _has_missings(left) - error( - "Encountered `missing` value(s) on the left-hand side" * - " of an observe statement. Using `missing` to de-condition" * - " a variable is only supported for univariate distributions," * - " not for $dist.", - ) - end +function DynamicPPL.accumulate_observe!!( + acc::DebugAccumulator, right::Distribution, val, vn::Union{VarName,Nothing} +) # Check for NaN's as well - if _has_nans(left) + if _has_nans(val) error( "Encountered a NaN value on the left-hand side of an" * " observe statement; this may indicate that your data" * " contain NaN values.", ) end -end - -function record_post_tilde_observe!(context::DebugContext, left, right, varinfo) - stmt = ObserveStmt(; - left=left, right=right, varinfo=context.record_varinfo ? varinfo : nothing - ) - if context.record_statements - push!(context.statements, stmt) - end - return nothing -end - -function DynamicPPL.tilde_observe!!(context::DebugContext, right, left, vn, vi) - record_pre_tilde_observe!(context, left, right, vi) - vi = DynamicPPL.tilde_observe!!(childcontext(context), right, left, vn, vi) - record_post_tilde_observe!(context, left, right, vi) - return vi -end -function DynamicPPL.tilde_observe!!(context::DebugContext, sampler, right, left, vn, vi) - record_pre_tilde_observe!(context, left, right, vi) - vi = DynamicPPL.tilde_observe!!(childcontext(context), sampler, right, left, vn, vi) - record_post_tilde_observe!(context, left, right, vi) - return vi + stmt = ObserveStmt(; varname=vn, right=right, value=val) + push!(acc.statements, stmt) + return acc end _conditioned_varnames(d::AbstractDict) = keys(d) @@ -357,26 +291,26 @@ function check_model_pre_evaluation(model::Model) return issuccess end -function check_model_post_evaluation(model::Model) - return check_varnames_seen(model.context.varnames_seen) +function check_model_post_evaluation(acc::DebugAccumulator) + return check_varnames_seen(acc.varnames_seen) end """ - check_model_and_trace([rng, ]model::Model; kwargs...) + check_model_and_trace(model::Model, varinfo::AbstractVarInfo; error_on_failure=false) -Check that `model` is valid, warning about any potential issues. +Check that evaluating `model` with the given `varinfo` is valid, warning about any potential +issues. This will check the model for the following issues: + 1. Repeated usage of the same varname in a model. -2. Incorrectly treating a variable as random rather than fixed, and vice versa. +2. `NaN` on the left-hand side of observe statements. # Arguments -- `rng::Random.AbstractRNG`: The random number generator to use when evaluating the model. - `model::Model`: The model to check. +- `varinfo::AbstractVarInfo`: The varinfo to use when evaluating the model. -# Keyword Arguments -- `varinfo::VarInfo`: The varinfo to use when evaluating the model. Default: `VarInfo(model)`. -- `context::AbstractContext`: The context to use when evaluating the model. Default: [`DefaultContext`](@ref). +# Keyword Argument - `error_on_failure::Bool`: Whether to throw an error if the model check fails. Default: `false`. # Returns @@ -394,7 +328,9 @@ julia> rng = StableRNG(42); julia> @model demo_correct() = x ~ Normal() demo_correct (generic function with 2 methods) -julia> issuccess, trace = check_model_and_trace(rng, demo_correct()); +julia> model = demo_correct(); varinfo = VarInfo(rng, model); + +julia> issuccess, trace = check_model_and_trace(model, varinfo); julia> issuccess true @@ -427,36 +363,27 @@ julia> issuccess, trace = check_model_and_trace(rng, demo_incorrect(); error_on_ ERROR: varname x used multiple times in model ``` """ -function check_model_and_trace(model::Model; kwargs...) - return check_model_and_trace(Random.default_rng(), model; kwargs...) -end function check_model_and_trace( - rng::Random.AbstractRNG, - model::Model; - varinfo=VarInfo(), - error_on_failure=false, - kwargs..., + model::Model, varinfo::AbstractVarInfo; error_on_failure=false ) - # Execute the model with the debug context. - debug_context = DebugContext( - SamplingContext(rng, model.context); error_on_failure=error_on_failure, kwargs... - ) - debug_model = DynamicPPL.contextualize(model, debug_context) + # Add debug accumulator to the VarInfo. + varinfo = DynamicPPL.setacc!!(deepcopy(varinfo), DebugAccumulator(error_on_failure)) # Perform checks before evaluating the model. - issuccess = check_model_pre_evaluation(debug_model) + issuccess = check_model_pre_evaluation(model) # Force single-threaded execution. - DynamicPPL.evaluate_threadunsafe!!(debug_model, varinfo) + DynamicPPL.evaluate_threadunsafe!!(model, varinfo) # Perform checks after evaluating the model. - issuccess &= check_model_post_evaluation(debug_model) + debug_acc = DynamicPPL.getacc(varinfo, Val(_DEBUG_ACC_NAME)) + issuccess = issuccess && check_model_post_evaluation(debug_acc) if !issuccess && error_on_failure error("model check failed") end - trace = debug_context.statements + trace = debug_acc.statements return issuccess, trace end @@ -471,10 +398,8 @@ and details of which types of checks are performed. # Returns - `issuccess::Bool`: Whether the model check succeeded. """ -check_model(model::Model; kwargs...) = first(check_model_and_trace(model; kwargs...)) -function check_model(rng::Random.AbstractRNG, model::Model; kwargs...) - return first(check_model_and_trace(rng, model; kwargs...)) -end +check_model(model::Model, varinfo::AbstractVarInfo=VarInfo(model); error_on_failure=false) = + first(check_model_and_trace(model, varinfo; error_on_failure=error_on_failure)) # Convenience method used to check if all elements in a list are the same. function all_the_same(xs) @@ -503,19 +428,16 @@ and checking if the model is consistent across runs. # Keyword Arguments - `num_evals::Int`: The number of evaluations to perform. Default: `5`. -- `kwargs...`: Additional keyword arguments to pass to [`check_model_and_trace`](@ref). +- `error_on_failure::Bool`: Whether to throw an error if any of the `num_evals` model + checks fail. Default: `false`. """ -function has_static_constraints(model::Model; kwargs...) - return has_static_constraints(Random.default_rng(), model; kwargs...) -end function has_static_constraints( - rng::Random.AbstractRNG, model::Model; num_evals=5, kwargs... + rng::Random.AbstractRNG, model::Model; num_evals::Int=5, error_on_failure::Bool=false ) + new_model = DynamicPPL.contextualize(model, SamplingContext(rng, SampleFromPrior())) results = map(1:num_evals) do _ - check_model_and_trace(rng, model; kwargs...) + check_model_and_trace(new_model, VarInfo(); error_on_failure=error_on_failure) end - issuccess = all(first, results) - issuccess || throw(ArgumentError("model check failed")) # Extract the distributions and the corresponding bijectors for each run. traces = map(last, results) @@ -527,6 +449,13 @@ function has_static_constraints( # Check if the distributions are the same across all runs. return all_the_same(transforms) end +function has_static_constraints( + model::Model; num_evals::Int=5, error_on_failure::Bool=false +) + return has_static_constraints( + Random.default_rng(), model; num_evals=num_evals, error_on_failure=error_on_failure + ) +end """ gen_evaluator_call_with_types(model[, varinfo]) diff --git a/test/debug_utils.jl b/test/debug_utils.jl index 8279ac51a..e30283fc5 100644 --- a/test/debug_utils.jl +++ b/test/debug_utils.jl @@ -1,13 +1,6 @@ @testset "check_model" begin - @testset "context interface" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - context = DynamicPPL.DebugUtils.DebugContext() - DynamicPPL.TestUtils.test_context(context, model) - end - end - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - issuccess, trace = check_model_and_trace(model) + issuccess, trace = check_model_and_trace(model, VarInfo(model)) # These models should all work. @test issuccess @@ -33,11 +26,14 @@ return y ~ Normal() end buggy_model = buggy_demo_model() + varinfo = VarInfo(buggy_model) @test_logs (:warn,) (:warn,) check_model(buggy_model) - issuccess = check_model(buggy_model; record_varinfo=false) + issuccess = check_model(buggy_model, varinfo) @test !issuccess - @test_throws ErrorException check_model(buggy_model; error_on_failure=true) + @test_throws ErrorException check_model( + buggy_model, varinfo; error_on_failure=true + ) end @testset "submodel" begin @@ -48,7 +44,10 @@ return x ~ Normal() end model = ModelOuterBroken() - @test_throws ErrorException check_model(model; error_on_failure=true) + varinfo = VarInfo(model) + @test_throws ErrorException check_model( + model, VarInfo(model); error_on_failure=true + ) @model function ModelOuterWorking() # With automatic prefixing => `x` is not duplicated. @@ -57,7 +56,7 @@ return z end model = ModelOuterWorking() - @test check_model(model; error_on_failure=true) + @test check_model(model, VarInfo(model); error_on_failure=true) # With manual prefixing, https://github.com/TuringLang/DynamicPPL.jl/issues/785 @model function ModelOuterWorking2() @@ -66,7 +65,7 @@ return (x1, x2) end model = ModelOuterWorking2() - @test check_model(model; error_on_failure=true) + @test check_model(model, VarInfo(model); error_on_failure=true) end @testset "subsumes (x then x[1])" begin @@ -77,11 +76,14 @@ return nothing end buggy_model = buggy_subsumes_demo_model() + varinfo = VarInfo(buggy_model) - @test_logs (:warn,) (:warn,) check_model(buggy_model) - issuccess = check_model(buggy_model; record_varinfo=false) + @test_logs (:warn,) (:warn,) check_model(buggy_model, varinfo) + issuccess = check_model(buggy_model, varinfo) @test !issuccess - @test_throws ErrorException check_model(buggy_model; error_on_failure=true) + @test_throws ErrorException check_model( + buggy_model, varinfo; error_on_failure=true + ) end @testset "subsumes (x[1] then x)" begin @@ -92,11 +94,14 @@ return nothing end buggy_model = buggy_subsumes_demo_model() + varinfo = VarInfo(buggy_model) - @test_logs (:warn,) (:warn,) check_model(buggy_model) - issuccess = check_model(buggy_model; record_varinfo=false) + @test_logs (:warn,) (:warn,) check_model(buggy_model, varinfo) + issuccess = check_model(buggy_model, varinfo) @test !issuccess - @test_throws ErrorException check_model(buggy_model; error_on_failure=true) + @test_throws ErrorException check_model( + buggy_model, varinfo; error_on_failure=true + ) end @testset "subsumes (x.a then x)" begin @@ -107,11 +112,14 @@ return nothing end buggy_model = buggy_subsumes_demo_model() + varinfo = VarInfo(buggy_model) - @test_logs (:warn,) (:warn,) check_model(buggy_model) - issuccess = check_model(buggy_model; record_varinfo=false) + @test_logs (:warn,) (:warn,) check_model(buggy_model, varinfo) + issuccess = check_model(buggy_model, varinfo) @test !issuccess - @test_throws ErrorException check_model(buggy_model; error_on_failure=true) + @test_throws ErrorException check_model( + buggy_model, varinfo; error_on_failure=true + ) end end @@ -123,25 +131,17 @@ end end m = demo_nan_in_data([1.0, NaN]) - @test_throws ErrorException check_model(m; error_on_failure=true) + @test_throws ErrorException check_model(m, VarInfo(m); error_on_failure=true) # Test NamedTuples with nested arrays, see #898 @model function demo_nan_complicated(nt) nt ~ product_distribution((x=Normal(), y=Dirichlet([2, 4]))) return x ~ Normal() end m = demo_nan_complicated((x=1.0, y=[NaN, 0.5])) - @test_throws ErrorException check_model(m; error_on_failure=true) + @test_throws ErrorException check_model(m, VarInfo(m); error_on_failure=true) end @testset "incorrect use of condition" begin - @testset "missing in multivariate" begin - @model function demo_missing_in_multivariate(x) - return x ~ MvNormal(zeros(length(x)), I) - end - model = demo_missing_in_multivariate([1.0, missing]) - @test_throws ErrorException check_model(model) - end - @testset "condition both in args and context" begin @model function demo_condition_both_in_args_and_context(x) return x ~ Normal() @@ -153,8 +153,9 @@ OrderedDict(@varname(x[1]) => 2.0), ] conditioned_model = DynamicPPL.condition(model, vals) + varinfo = VarInfo(conditioned_model) @test_throws ErrorException check_model( - conditioned_model; error_on_failure=true + conditioned_model, varinfo; error_on_failure=true ) end end @@ -163,23 +164,26 @@ @testset "printing statements" begin @testset "assume" begin @model demo_assume() = x ~ Normal() - isuccess, trace = check_model_and_trace(demo_assume()) - @test isuccess + model = demo_assume() + issuccess, trace = check_model_and_trace(model, VarInfo(model)) + @test issuccess @test startswith(string(trace), " assume: x ~ Normal") end @testset "observe" begin @model demo_observe(x) = x ~ Normal() - isuccess, trace = check_model_and_trace(demo_observe(1.0)) - @test isuccess - @test occursin(r"observe: \d+\.\d+ ~ Normal", string(trace)) + model = demo_observe(1.0) + issuccess, trace = check_model_and_trace(model, VarInfo(model)) + @test issuccess + @test occursin(r"observe: x \(= \d+\.\d+\) ~ Normal", string(trace)) end end @testset "comparing multiple traces" begin + # Run the same model but with different VarInfos. model = DynamicPPL.TestUtils.demo_dynamic_constraint() - issuccess_1, trace_1 = check_model_and_trace(model) - issuccess_2, trace_2 = check_model_and_trace(model) + issuccess_1, trace_1 = check_model_and_trace(model, VarInfo(model)) + issuccess_2, trace_2 = check_model_and_trace(model, VarInfo(model)) @test issuccess_1 && issuccess_2 # Should have the same varnames present. @@ -204,7 +208,7 @@ end for ns in [(2,), (2, 2), (2, 2, 2)] model = demo_undef(ns...) - @test check_model(model; error_on_failure=true) + @test check_model(model, VarInfo(model); error_on_failure=true) end end From 88da7bd5e347f3d34a702f0862897f4f19a618ca Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 8 Jul 2025 23:35:52 +0100 Subject: [PATCH 02/14] Changelog --- HISTORY.md | 36 +++++++++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/HISTORY.md b/HISTORY.md index 955a28963..1748dca66 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -2,7 +2,15 @@ ## 0.37.0 -**Breaking changes** +DynamicPPL 0.37 comes with a substantial reworking of its internals. +Fundamentally, there is no change to the actual modelling syntax: if you are a Turing.jl user, for example, this release is unlikely to affect you much. +However, if you are a package developer or someone who uses DynamicPPL's functionality directly, you will notice a number of changes. + +To avoid overwhelming the reader, we begin by listing the most important, user-facing changes, before explaining the changes to the internals in more detail. + +Note that virtually all changes listed here are breaking. + +**Public-facing changes** ### Submodel macro @@ -19,6 +27,32 @@ There is now also an `rng` keyword argument to help seed parameter generation. Finally, instead of specifying `value_atol` and `grad_atol`, you can now specify `atol` and `rtol` which are used for both value and gradient. Their semantics are the same as in Julia's `isapprox`; two values are equal if they satisfy either `atol` or `rtol`. +### `DynamicPPL.TestUtils.check_model` + +You now need to explicitly pass a `VarInfo` argument to `check_model` and `check_model_and_trace`. +Previously, these functions would generate a new VarInfo for you (using an optionally provided `rng`). + +### Removal of `PriorContext` and `LikelihoodContext` + +A number of DynamicPPL's contexts have been removed, most notably `PriorContext` and `LikelihoodContext`. +Although these are not the only _exported_ contexts, we consider unlikely that anyone was using _other_ contexts manually: if you have a question about contexts _other_ than these, please continue reading the 'Internals' section below. + +Previously, during evaluation of a model, DynamicPPL only had the capability to store a _single_ log probability (`logp`) field. +`DefaultContext`, `PriorContext`, and `LikelihoodContext` were used to control what this field represented: they would accumulate the log joint, log prior, or log likelihood, respectively. + +Now, we have reworked DynamicPPL's `VarInfo` object such that it can track multiple log probabilities at once (see the 'Accumulators' section below). +If you were evaluating a model with `PriorContext`, you can now just evaluate it with `DefaultContext`, and instead of calling `getlogp(varinfo)`, you can call `getlogprior(varinfo)` (and similarly for the likelihood). + +If you were constructing a `LogDensityFunction` with `PriorContext`, you can now stick to `DefaultContext`. +`LogDensityFunction` now has an extra field, called `getlogdensity`, which represents a function that takes a `VarInfo` and returns the log density you want. +Thus, if you pass `getlogprior` as the value of this parameter, you will get the same behaviour as with `PriorContext`. + +The other case where one might use `PriorContext` was to use `@addlogprob!` to add to the log prior. +Previously, this was accomplished by manually checking `__context__ isa DynamicPPL.PriorContext`. +Now, you can write `@addlogprob (; logprior=x, loglikelihood=y)` to add `x` to the log-prior and `y` to the log-likelihood. + +**Internals** + ### Accumulators This release overhauls how VarInfo objects track variables such as the log joint probability. The new approach is to use what we call accumulators: Objects that the VarInfo carries on it that may change their state at each `tilde_assume!!` and `tilde_observe!!` call based on the value of the variable in question. They replace both variables that were previously hard-coded in the `VarInfo` object (`logp` and `num_produce`) and some contexts. This brings with it a number of breaking changes: From 8c7aff92d83c8ab8bdf1ace91d1d333ed4d7ca91 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 8 Jul 2025 23:45:48 +0100 Subject: [PATCH 03/14] Force `conditioned` to return a dict --- src/contexts.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index bafa3ed26..e0a36352f 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -480,14 +480,14 @@ a merged version of the condition values. function conditioned(context::AbstractContext) return conditioned(NodeTrait(conditioned, context), context) end -conditioned(::IsLeaf, context) = NamedTuple() +conditioned(::IsLeaf, context) = Dict{VarName,Any}() conditioned(::IsParent, context) = conditioned(childcontext(context)) function conditioned(context::ConditionContext) # Note the order of arguments to `merge`. The behavior of the rest of DPPL # is that the outermost `context` takes precendence, hence when resolving # the `conditioned` variables we need to ensure that `context.values` takes # precedence over decendants of `context`. - return _merge(context.values, conditioned(childcontext(context))) + return _merge(to_varname_dict(context.values), conditioned(childcontext(context))) end function conditioned(context::PrefixContext) return conditioned(collapse_prefix_stack(context)) From a85f28d03b750c854639ca09545f4bd618b8df2e Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 8 Jul 2025 23:48:30 +0100 Subject: [PATCH 04/14] fix conditioned implementation --- src/contexts.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/contexts.jl b/src/contexts.jl index e0a36352f..41bed5f00 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -487,7 +487,7 @@ function conditioned(context::ConditionContext) # is that the outermost `context` takes precendence, hence when resolving # the `conditioned` variables we need to ensure that `context.values` takes # precedence over decendants of `context`. - return _merge(to_varname_dict(context.values), conditioned(childcontext(context))) + return _merge(conditioned(childcontext(context)), to_varname_dict(context.values)) end function conditioned(context::PrefixContext) return conditioned(collapse_prefix_stack(context)) From 919cb253c485899458e82a1e1be362de1c423928 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 8 Jul 2025 23:50:01 +0100 Subject: [PATCH 05/14] revert `conditioned` bugfix (will merge this to main instead) --- src/contexts.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index 41bed5f00..addadfa1a 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -472,7 +472,7 @@ end """ conditioned(context::AbstractContext) -Return a `Dict{VarName,Any}` of the values that are conditioned on under `context`. +Return `NamedTuple` of values that are conditioned on under context`. Note that this will recursively traverse the context stack and return a merged version of the condition values. @@ -480,14 +480,14 @@ a merged version of the condition values. function conditioned(context::AbstractContext) return conditioned(NodeTrait(conditioned, context), context) end -conditioned(::IsLeaf, context) = Dict{VarName,Any}() +conditioned(::IsLeaf, context) = NamedTuple() conditioned(::IsParent, context) = conditioned(childcontext(context)) function conditioned(context::ConditionContext) # Note the order of arguments to `merge`. The behavior of the rest of DPPL # is that the outermost `context` takes precendence, hence when resolving # the `conditioned` variables we need to ensure that `context.values` takes # precedence over decendants of `context`. - return _merge(conditioned(childcontext(context)), to_varname_dict(context.values)) + return _merge(context.values, conditioned(childcontext(context))) end function conditioned(context::PrefixContext) return conditioned(collapse_prefix_stack(context)) From 06499721d92734b8ea47655ca54177c8646f9b4b Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 9 Jul 2025 00:48:01 +0100 Subject: [PATCH 06/14] fix show --- src/accumulators.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/accumulators.jl b/src/accumulators.jl index e0bea91b9..ca394c21c 100644 --- a/src/accumulators.jl +++ b/src/accumulators.jl @@ -129,7 +129,10 @@ AccumulatorTuple(nt::NamedTuple) = AccumulatorTuple(tuple(nt...)) # When showing with text/plain, leave out type information about the wrapper AccumulatorTuple. function Base.show(io::IO, mime::MIME"text/plain", at::AccumulatorTuple) - return "AccumulatorTuple(" * show(io, mime, at.nt) * ")" + print(io, "AccumulatorTuple(") + show(io, mime, at.nt) + print(io, ")") + return nothing end Base.getindex(at::AccumulatorTuple, idx) = at.nt[idx] Base.length(::AccumulatorTuple{N}) where {N} = N From 0ebb56e5b21a20db8fae61599eb37fd1a4bdc345 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 9 Jul 2025 01:46:41 +0100 Subject: [PATCH 07/14] Fix doctests --- src/debug_utils.jl | 12 +++++++++--- src/simple_varinfo.jl | 4 ++-- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/debug_utils.jl b/src/debug_utils.jl index d731e497c..e888326d0 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -338,7 +338,9 @@ true julia> print(trace) assume: x ~ Normal{Float64}(μ=0.0, σ=1.0) ⟼ -0.670252 -julia> issuccess, trace = check_model_and_trace(rng, demo_correct() | (x = 1.0,)); +julia> cond_model = model | (x = 1.0,); + +julia> issuccess, trace = check_model_and_trace(cond_model, VarInfo(cond_model)); ┌ Warning: The model does not contain any parameters. └ @ DynamicPPL.DebugUtils DynamicPPL.jl/src/debug_utils.jl:342 @@ -346,7 +348,7 @@ julia> issuccess true julia> print(trace) -observe: 1.0 ~ Normal{Float64}(μ=0.0, σ=1.0) +observe: x (= 1.0) ~ Normal{Float64}(μ=0.0, σ=1.0) ``` ## Incorrect model @@ -359,7 +361,11 @@ julia> @model function demo_incorrect() end demo_incorrect (generic function with 2 methods) -julia> issuccess, trace = check_model_and_trace(rng, demo_incorrect(); error_on_failure=true); +julia> # Notice that VarInfo(model_incorrect) evaluates the model, but doesn't actually + # alert us to the issue of `x` being sampled twice. + model = demo_incorrect(); varinfo = VarInfo(model); + +julia> issuccess, trace = check_model_and_trace(model, varinfo; error_on_failure=true); ERROR: varname x used multiple times in model ``` """ diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index ddc3275ae..665ce7172 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -122,7 +122,7 @@ Evaluation in transformed space of course also works: ```jldoctest simplevarinfo-general julia> vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), true) -Transformed SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), NumProduce = NumProduceAccumulator(0))) +Transformed SimpleVarInfo((x = -1.0,), AccumulatorTuple((LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), NumProduce = NumProduceAccumulator(0)))) julia> # (✓) Positive probability mass on negative numbers! getlogjoint(last(DynamicPPL.evaluate!!(m, vi))) @@ -130,7 +130,7 @@ julia> # (✓) Positive probability mass on negative numbers! julia> # While if we forget to indicate that it's transformed: vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), false) -SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), NumProduce = NumProduceAccumulator(0))) +SimpleVarInfo((x = -1.0,), AccumulatorTuple((LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), NumProduce = NumProduceAccumulator(0)))) julia> # (✓) No probability mass on negative numbers! getlogjoint(last(DynamicPPL.evaluate!!(m, vi))) From e534434dc5d55b782f720608892f5f33256f5350 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 9 Jul 2025 01:50:58 +0100 Subject: [PATCH 08/14] fix doctests 2 --- src/debug_utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/debug_utils.jl b/src/debug_utils.jl index e888326d0..aaef3a170 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -348,7 +348,7 @@ julia> issuccess true julia> print(trace) -observe: x (= 1.0) ~ Normal{Float64}(μ=0.0, σ=1.0) + observe: x (= 1.0) ~ Normal{Float64}(μ=0.0, σ=1.0) ``` ## Incorrect model From d73bb1400d081b2a700ed020d4b8218a59d2d698 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 9 Jul 2025 17:56:28 +0100 Subject: [PATCH 09/14] Make VarInfo actually mandatory in check_model --- src/debug_utils.jl | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/debug_utils.jl b/src/debug_utils.jl index aaef3a170..79b48c0bd 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -394,17 +394,15 @@ function check_model_and_trace( end """ - check_model([rng, ]model::Model; kwargs...) + check_model(model::Model, varinfo::AbstractVarInfo; error_on_failure=false) -Check that `model` is valid, warning about any potential issues. - -See [`check_model_and_trace`](@ref) for more details on supported keyword arguments -and details of which types of checks are performed. +Check that `model` is valid, warning about any potential issues (or erroring if +`error_on_failure` is `true`). # Returns - `issuccess::Bool`: Whether the model check succeeded. """ -check_model(model::Model, varinfo::AbstractVarInfo=VarInfo(model); error_on_failure=false) = +check_model(model::Model, varinfo::AbstractVarInfo; error_on_failure=false) = first(check_model_and_trace(model, varinfo; error_on_failure=error_on_failure)) # Convenience method used to check if all elements in a list are the same. @@ -421,7 +419,7 @@ function all_the_same(xs) end """ - has_static_constraints([rng, ]model::Model; num_evals=5, kwargs...) + has_static_constraints([rng, ]model::Model; num_evals=5, error_on_failure=false) Return `true` if the model has static constraints, `false` otherwise. From cd2f96999fae0d224494bc2e50bc68c89fe76cce Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 9 Jul 2025 18:21:27 +0100 Subject: [PATCH 10/14] Re-implement `missing` check --- src/debug_utils.jl | 48 +++++++++++++++++++++++++++++++++++++++++---- test/debug_utils.jl | 13 +++++++++++- 2 files changed, 56 insertions(+), 5 deletions(-) diff --git a/src/debug_utils.jl b/src/debug_utils.jl index 79b48c0bd..5154f0c34 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -217,9 +217,21 @@ function record_varname!(acc::DebugAccumulator, varname::VarName, dist) end end +_has_missings(x) = ismissing(x) +function _has_missings(x::AbstractArray) + # Can't just use `any` because `x` might contain `undef`. + for i in eachindex(x) + if isassigned(x, i) && _has_missings(x[i]) + return true + end + end + return false +end + _has_nans(x::NamedTuple) = any(_has_nans, x) _has_nans(x::AbstractArray) = any(_has_nans, x) _has_nans(x) = isnan(x) +_has_nans(::Missing) = false function DynamicPPL.accumulate_assume!!( acc::DebugAccumulator, val, _logjac, vn::VarName, right::Distribution @@ -233,13 +245,38 @@ end function DynamicPPL.accumulate_observe!!( acc::DebugAccumulator, right::Distribution, val, vn::Union{VarName,Nothing} ) + if _has_missings(val) + # If `val` itself is a missing, that's a bug because that should cause + # us to go down the assume path. + val === missing && error( + "Encountered `missing` value on the left-hand side of an observe" * + " statement. This should not happen. Please open an issue at" * + " https://github.com/TuringLang/DynamicPPL.jl.", + ) + # Otherwise it's an array with some missing values. + msg = + "Encountered a container with one or more `missing` value(s) on the" * + " left-hand side of an observe statement. To treat the variable on" * + " the left-hand side as a random variable, you should specify a single" * + " `missing` rather than a vector of `missing`s. It is not possible to" * + " set part but not all of a distribution to be `missing`." + if acc.error_on_failure + error(msg) + else + @warn msg + end + end # Check for NaN's as well if _has_nans(val) - error( + msg = "Encountered a NaN value on the left-hand side of an" * " observe statement; this may indicate that your data" * - " contain NaN values.", - ) + " contain NaN values." + if acc.error_on_failure + error(msg) + else + @warn msg + end end stmt = ObserveStmt(; varname=vn, right=right, value=val) push!(acc.statements, stmt) @@ -373,7 +410,10 @@ function check_model_and_trace( model::Model, varinfo::AbstractVarInfo; error_on_failure=false ) # Add debug accumulator to the VarInfo. - varinfo = DynamicPPL.setacc!!(deepcopy(varinfo), DebugAccumulator(error_on_failure)) + # Need a NumProduceAccumulator as well or else get_num_produce may throw + varinfo = DynamicPPL.setaccs!!( + deepcopy(varinfo), (DebugAccumulator(error_on_failure), NumProduceAccumulator()) + ) # Perform checks before evaluating the model. issuccess = check_model_pre_evaluation(model) diff --git a/test/debug_utils.jl b/test/debug_utils.jl index e30283fc5..5bf741ff3 100644 --- a/test/debug_utils.jl +++ b/test/debug_utils.jl @@ -28,7 +28,7 @@ buggy_model = buggy_demo_model() varinfo = VarInfo(buggy_model) - @test_logs (:warn,) (:warn,) check_model(buggy_model) + @test_logs (:warn,) (:warn,) check_model(buggy_model, varinfo) issuccess = check_model(buggy_model, varinfo) @test !issuccess @test_throws ErrorException check_model( @@ -142,6 +142,17 @@ end @testset "incorrect use of condition" begin + @testset "missing in multivariate" begin + @model function demo_missing_in_multivariate(x) + return x ~ MvNormal(zeros(length(x)), I) + end + model = demo_missing_in_multivariate([1.0, missing]) + # Have to run this check_model call with an empty varinfo, because actually + # instantiating the VarInfo would cause it to throw a MethodError. + model = contextualize(model, SamplingContext()) + @test_throws ErrorException check_model(model, VarInfo(); error_on_failure=true) + end + @testset "condition both in args and context" begin @model function demo_condition_both_in_args_and_context(x) return x ~ Normal() From 1f10b181ddcdfa78f9974dbdf07c24560a1f6929 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 9 Jul 2025 18:34:30 +0100 Subject: [PATCH 11/14] Revert `combine` signature in docstring --- src/accumulators.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/accumulators.jl b/src/accumulators.jl index ca394c21c..4a6de7a04 100644 --- a/src/accumulators.jl +++ b/src/accumulators.jl @@ -79,9 +79,10 @@ See also: [`combine`](@ref) function split end """ - combine(acc::TAcc, acc2::TAcc) where {TAcc<:AbstractAccumulator} + combine(acc::AbstractAccumulator, acc2::AbstractAccumulator) -Combine two accumulators of the same type. Returns a new accumulator of the same type. +Combine two accumulators which have the same type (but may, in general, have different type +parameters). Returns a new accumulator of the same type. See also: [`split`](@ref) """ From 063560721ba95ba757de27e1aa80c1a2c879b2ee Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 15 Jul 2025 19:01:52 +0100 Subject: [PATCH 12/14] Revert changes to `Base.show` on AccumulatorTuple --- src/accumulators.jl | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/src/accumulators.jl b/src/accumulators.jl index 4a6de7a04..595c45d3f 100644 --- a/src/accumulators.jl +++ b/src/accumulators.jl @@ -128,13 +128,8 @@ end AccumulatorTuple(accs::Vararg{AbstractAccumulator}) = AccumulatorTuple(accs) AccumulatorTuple(nt::NamedTuple) = AccumulatorTuple(tuple(nt...)) -# When showing with text/plain, leave out type information about the wrapper AccumulatorTuple. -function Base.show(io::IO, mime::MIME"text/plain", at::AccumulatorTuple) - print(io, "AccumulatorTuple(") - show(io, mime, at.nt) - print(io, ")") - return nothing -end +# When showing with text/plain, leave out information about the wrapper AccumulatorTuple. +Base.show(io::IO, mime::MIME"text/plain", at::AccumulatorTuple) = show(io, mime, at.nt) Base.getindex(at::AccumulatorTuple, idx) = at.nt[idx] Base.length(::AccumulatorTuple{N}) where {N} = N Base.iterate(at::AccumulatorTuple, args...) = iterate(at.nt, args...) From 40eddde2c839ad4e4b513e7339928a218ee9ac90 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 15 Jul 2025 19:02:33 +0100 Subject: [PATCH 13/14] Add TODO comment about VariableOrderAccumulator Co-authored-by: Markus Hauru --- src/debug_utils.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/debug_utils.jl b/src/debug_utils.jl index 5154f0c34..d1add6e00 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -411,6 +411,7 @@ function check_model_and_trace( ) # Add debug accumulator to the VarInfo. # Need a NumProduceAccumulator as well or else get_num_produce may throw + # TODO(mhauru) Remove this once VariableOrderAccumulator stuff is done. varinfo = DynamicPPL.setaccs!!( deepcopy(varinfo), (DebugAccumulator(error_on_failure), NumProduceAccumulator()) ) From 408b9511103752979ddfe29ed6075a397f6bb993 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 16 Jul 2025 17:13:51 +0100 Subject: [PATCH 14/14] Fix doctests --- src/simple_varinfo.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 665ce7172..ddc3275ae 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -122,7 +122,7 @@ Evaluation in transformed space of course also works: ```jldoctest simplevarinfo-general julia> vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), true) -Transformed SimpleVarInfo((x = -1.0,), AccumulatorTuple((LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), NumProduce = NumProduceAccumulator(0)))) +Transformed SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), NumProduce = NumProduceAccumulator(0))) julia> # (✓) Positive probability mass on negative numbers! getlogjoint(last(DynamicPPL.evaluate!!(m, vi))) @@ -130,7 +130,7 @@ julia> # (✓) Positive probability mass on negative numbers! julia> # While if we forget to indicate that it's transformed: vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), false) -SimpleVarInfo((x = -1.0,), AccumulatorTuple((LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), NumProduce = NumProduceAccumulator(0)))) +SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), NumProduce = NumProduceAccumulator(0))) julia> # (✓) No probability mass on negative numbers! getlogjoint(last(DynamicPPL.evaluate!!(m, vi)))