diff --git a/HISTORY.md b/HISTORY.md index fcd005579..d367e9ad7 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -2,7 +2,15 @@ ## 0.37.0 -**Breaking changes** +DynamicPPL 0.37 comes with a substantial reworking of its internals. +Fundamentally, there is no change to the actual modelling syntax: if you are a Turing.jl user, for example, this release is unlikely to affect you much. +However, if you are a package developer or someone who uses DynamicPPL's functionality directly, you will notice a number of changes. + +To avoid overwhelming the reader, we begin by listing the most important, user-facing changes, before explaining the changes to the internals in more detail. + +Note that virtually all changes listed here are breaking. + +**Public-facing changes** ### Submodel macro @@ -19,6 +27,32 @@ There is now also an `rng` keyword argument to help seed parameter generation. Finally, instead of specifying `value_atol` and `grad_atol`, you can now specify `atol` and `rtol` which are used for both value and gradient. Their semantics are the same as in Julia's `isapprox`; two values are equal if they satisfy either `atol` or `rtol`. +### `DynamicPPL.DebugUtils.check_model` + +You now need to explicitly pass a `VarInfo` argument to `check_model` and `check_model_and_trace`. +Previously, these functions would generate a new VarInfo for you (using an optionally provided `rng`). + +### Removal of `PriorContext` and `LikelihoodContext` + +A number of DynamicPPL's contexts have been removed, most notably `PriorContext` and `LikelihoodContext`. +Although these are not the only _exported_ contexts, we consider it unlikely that anyone was using _other_ contexts manually: if you have a question about contexts _other_ than these, please continue reading the 'Internals' section below. + +Previously, during evaluation of a model, DynamicPPL only had the capability to store a _single_ log probability (`logp`) field. 
+`DefaultContext`, `PriorContext`, and `LikelihoodContext` were used to control what this field represented: they would accumulate the log joint, log prior, or log likelihood, respectively. + +Now, we have reworked DynamicPPL's `VarInfo` object such that it can track multiple log probabilities at once (see the 'Accumulators' section below). +If you were evaluating a model with `PriorContext`, you can now just evaluate it with `DefaultContext`, and instead of calling `getlogp(varinfo)`, you can call `getlogprior(varinfo)` (and similarly for the likelihood). + +If you were constructing a `LogDensityFunction` with `PriorContext`, you can now stick to `DefaultContext`. +`LogDensityFunction` now has an extra field, called `getlogdensity`, which represents a function that takes a `VarInfo` and returns the log density you want. +Thus, if you pass `getlogprior` as the value of this parameter, you will get the same behaviour as with `PriorContext`. + +The other case where one might use `PriorContext` was to use `@addlogprob!` to add to the log prior. +Previously, this was accomplished by manually checking `__context__ isa DynamicPPL.PriorContext`. +Now, you can write `@addlogprob! (; logprior=x, loglikelihood=y)` to add `x` to the log-prior and `y` to the log-likelihood. + +**Internals** + ### Accumulators This release overhauls how VarInfo objects track variables such as the log joint probability. The new approach is to use what we call accumulators: Objects that the VarInfo carries on it that may change their state at each `tilde_assume!!` and `tilde_observe!!` call based on the value of the variable in question. They replace both variables that were previously hard-coded in the `VarInfo` object (`logp` and `num_produce`) and some contexts. 
This brings with it a number of breaking changes: diff --git a/src/accumulators.jl b/src/accumulators.jl index 10a988ae5..595c45d3f 100644 --- a/src/accumulators.jl +++ b/src/accumulators.jl @@ -53,10 +53,11 @@ function accumulate_observe!! end Update `acc` in a `tilde_assume!!` call. Returns the updated `acc`. -`vn` is the name of the variable being assumed, `val` is the value of the variable, and -`right` is the distribution on the RHS of the tilde statement. `logjac` is the log -determinant of the Jacobian of the transformation that was done to convert the value of `vn` -as it was given (e.g. by sampler operating in linked space) to `val`. +`vn` is the name of the variable being assumed, `val` is the value of the variable (in the +original, unlinked space), and `right` is the distribution on the RHS of the tilde +statement. `logjac` is the log determinant of the Jacobian of the transformation that was +done to convert the value of `vn` as it was given to `val`: for example, if the sampler is +operating in linked (Euclidean) space, then `logjac` will be nonzero. `accumulate_assume!!` may mutate `acc`, but not any of the other arguments. @@ -71,7 +72,7 @@ Return a new accumulator like `acc` but empty. The precise meaning of "empty" is that the returned value should be such that `combine(acc, split(acc))` is equal to `acc`. This is used in the context of multi-threading -where different threads may accumulate independently and the results are the combined. +where different threads may accumulate independently and the results are then combined. See also: [`combine`](@ref) """ @@ -80,7 +81,8 @@ function split end """ combine(acc::AbstractAccumulator, acc2::AbstractAccumulator) -Combine two accumulators of the same type. Returns a new accumulator. +Combine two accumulators which have the same type (but may, in general, have different type +parameters). Returns a new accumulator of the same type. 
See also: [`split`](@ref) """ diff --git a/src/debug_utils.jl b/src/debug_utils.jl index 4343ce8ac..d1add6e00 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -49,7 +49,7 @@ end function show_right(io::IO, d::Distribution) pnames = fieldnames(typeof(d)) - uml, namevals = Distributions._use_multline_show(d, pnames) + _, namevals = Distributions._use_multline_show(d, pnames) return Distributions.show_oneline(io, d, namevals) end @@ -76,7 +76,6 @@ Base.@kwdef struct AssumeStmt <: Stmt varname right value - varinfo = nothing end function Base.show(io::IO, stmt::AssumeStmt) @@ -88,21 +87,30 @@ function Base.show(io::IO, stmt::AssumeStmt) print(io, " ") print(io, RESULT_SYMBOL) print(io, " ") - return print(io, stmt.value) + print(io, stmt.value) + return nothing end Base.@kwdef struct ObserveStmt <: Stmt - left + varname right - varinfo = nothing + value end function Base.show(io::IO, stmt::ObserveStmt) io = add_io_context(io) - print(io, "observe: ") - show_right(io, stmt.left) + print(io, " observe: ") + if stmt.varname === nothing + print(io, stmt.value) + else + show_varname(io, stmt.varname) + print(io, " (= ") + print(io, stmt.value) + print(io, ")") + end print(io, " ~ ") - return show_right(io, stmt.right) + show_right(io, stmt.right) + return nothing end # Some utility methods for extracting information from a trace. @@ -124,98 +132,88 @@ distributions_in_stmt(stmt::AssumeStmt) = [stmt.right] distributions_in_stmt(stmt::ObserveStmt) = [stmt.right] """ - DebugContext <: AbstractContext + DebugAccumulator <: AbstractAccumulator -A context used for checking validity of a model. +An accumulator which captures tilde-statements inside a model and attempts to catch +errors in the model. 
# Fields -$(FIELDS) +$(TYPEDFIELDS) """ -struct DebugContext{C<:AbstractContext} <: AbstractContext - "context used for running the model" - context::C +struct DebugAccumulator <: AbstractAccumulator "mapping from varnames to the number of times they have been seen" varnames_seen::OrderedDict{VarName,Int} "tilde statements that have been executed" statements::Vector{Stmt} - "whether to throw an error if we encounter warnings" + "whether to throw an error if we encounter errors in the model" error_on_failure::Bool - "whether to record the tilde statements" - record_statements::Bool - "whether to record the varinfo in every tilde statement" - record_varinfo::Bool -end - -function DebugContext( - context::AbstractContext=DefaultContext(); - varnames_seen=OrderedDict{VarName,Int}(), - statements=Vector{Stmt}(), - error_on_failure=false, - record_statements=true, - record_varinfo=false, -) - return DebugContext( - context, - varnames_seen, - statements, - error_on_failure, - record_statements, - record_varinfo, - ) end -DynamicPPL.NodeTrait(::DebugContext) = DynamicPPL.IsParent() -DynamicPPL.childcontext(context::DebugContext) = context.context -function DynamicPPL.setchildcontext(context::DebugContext, child) - Accessors.@set context.context = child +function DebugAccumulator(error_on_failure=false) + return DebugAccumulator(OrderedDict{VarName,Int}(), Vector{Stmt}(), error_on_failure) end -function record_varname!(context::DebugContext, varname::VarName, dist) - prefixed_varname = DynamicPPL.prefix(context, varname) - if haskey(context.varnames_seen, prefixed_varname) - if context.error_on_failure - error("varname $prefixed_varname used multiple times in model") +const _DEBUG_ACC_NAME = :Debug +DynamicPPL.accumulator_name(::Type{<:DebugAccumulator}) = _DEBUG_ACC_NAME + +function split(acc::DebugAccumulator) + return DebugAccumulator( + OrderedDict{VarName,Int}(), Vector{Stmt}(), acc.error_on_failure + ) +end +function combine(acc1::DebugAccumulator, 
acc2::DebugAccumulator) + return DebugAccumulator( + merge(acc1.varnames_seen, acc2.varnames_seen), + vcat(acc1.statements, acc2.statements), + acc1.error_on_failure || acc2.error_on_failure, + ) +end + +function record_varname!(acc::DebugAccumulator, varname::VarName, dist) + if haskey(acc.varnames_seen, varname) + if acc.error_on_failure + error("varname $varname used multiple times in model") else - @warn "varname $prefixed_varname used multiple times in model" + @warn "varname $varname used multiple times in model" end - context.varnames_seen[prefixed_varname] += 1 + acc.varnames_seen[varname] += 1 else # We need to check: # 1. Does this `varname` subsume any of the other keys. # 2. Does any of the other keys subsume `varname`. - vns = collect(keys(context.varnames_seen)) + vns = collect(keys(acc.varnames_seen)) # Is `varname` subsumed by any of the other keys? - idx_parent = findfirst(Base.Fix2(subsumes, prefixed_varname), vns) + idx_parent = findfirst(Base.Fix2(subsumes, varname), vns) if idx_parent !== nothing varname_parent = vns[idx_parent] - if context.error_on_failure + if acc.error_on_failure error( - "varname $(varname_parent) used multiple times in model (subsumes $prefixed_varname)", + "varname $(varname_parent) used multiple times in model (subsumes $varname)", ) else - @warn "varname $(varname_parent) used multiple times in model (subsumes $prefixed_varname)" + @warn "varname $(varname_parent) used multiple times in model (subsumes $varname)" end # Update count of parent. - context.varnames_seen[varname_parent] += 1 + acc.varnames_seen[varname_parent] += 1 else # Does `varname` subsume any of the other keys? 
- idx_child = findfirst(Base.Fix1(subsumes, prefixed_varname), vns) + idx_child = findfirst(Base.Fix1(subsumes, varname), vns) if idx_child !== nothing varname_child = vns[idx_child] - if context.error_on_failure + if acc.error_on_failure error( - "varname $(varname_child) used multiple times in model (subsumed by $prefixed_varname)", + "varname $(varname_child) used multiple times in model (subsumed by $varname)", ) else - @warn "varname $(varname_child) used multiple times in model (subsumed by $prefixed_varname)" + @warn "varname $(varname_child) used multiple times in model (subsumed by $varname)" end # Update count of child. - context.varnames_seen[varname_child] += 1 + acc.varnames_seen[varname_child] += 1 end end - context.varnames_seen[prefixed_varname] = 1 + acc.varnames_seen[varname] = 1 end end @@ -233,83 +231,56 @@ end _has_nans(x::NamedTuple) = any(_has_nans, x) _has_nans(x::AbstractArray) = any(_has_nans, x) _has_nans(x) = isnan(x) +_has_nans(::Missing) = false -# assume -function record_pre_tilde_assume!(context::DebugContext, vn, dist, varinfo) - record_varname!(context, vn, dist) - return nothing -end - -function record_post_tilde_assume!(context::DebugContext, vn, dist, value, varinfo) - stmt = AssumeStmt(; - varname=vn, - right=dist, - value=value, - varinfo=context.record_varinfo ? 
varinfo : nothing, - ) - if context.record_statements - push!(context.statements, stmt) - end - return nothing +function DynamicPPL.accumulate_assume!!( + acc::DebugAccumulator, val, _logjac, vn::VarName, right::Distribution +) + record_varname!(acc, vn, right) + stmt = AssumeStmt(; varname=vn, right=right, value=val) + push!(acc.statements, stmt) + return acc end -function DynamicPPL.tilde_assume(context::DebugContext, right, vn, vi) - record_pre_tilde_assume!(context, vn, right, vi) - value, vi = DynamicPPL.tilde_assume(childcontext(context), right, vn, vi) - record_post_tilde_assume!(context, vn, right, value, vi) - return value, vi -end -function DynamicPPL.tilde_assume( - rng::Random.AbstractRNG, context::DebugContext, sampler, right, vn, vi +function DynamicPPL.accumulate_observe!!( + acc::DebugAccumulator, right::Distribution, val, vn::Union{VarName,Nothing} ) - record_pre_tilde_assume!(context, vn, right, vi) - value, vi = DynamicPPL.tilde_assume(rng, childcontext(context), sampler, right, vn, vi) - record_post_tilde_assume!(context, vn, right, value, vi) - return value, vi -end - -# observe -function record_pre_tilde_observe!(context::DebugContext, left, dist, varinfo) - # Check for `missing`s; these should not end up here. - if _has_missings(left) - error( - "Encountered `missing` value(s) on the left-hand side" * - " of an observe statement. Using `missing` to de-condition" * - " a variable is only supported for univariate distributions," * - " not for $dist.", + if _has_missings(val) + # If `val` itself is a missing, that's a bug because that should cause + # us to go down the assume path. + val === missing && error( + "Encountered `missing` value on the left-hand side of an observe" * + " statement. This should not happen. Please open an issue at" * + " https://github.com/TuringLang/DynamicPPL.jl.", ) + # Otherwise it's an array with some missing values. 
+ msg = + "Encountered a container with one or more `missing` value(s) on the" * + " left-hand side of an observe statement. To treat the variable on" * + " the left-hand side as a random variable, you should specify a single" * + " `missing` rather than a vector of `missing`s. It is not possible to" * + " set part but not all of a distribution to be `missing`." + if acc.error_on_failure + error(msg) + else + @warn msg + end end # Check for NaN's as well - if _has_nans(left) - error( + if _has_nans(val) + msg = "Encountered a NaN value on the left-hand side of an" * " observe statement; this may indicate that your data" * - " contain NaN values.", - ) + " contain NaN values." + if acc.error_on_failure + error(msg) + else + @warn msg + end end -end - -function record_post_tilde_observe!(context::DebugContext, left, right, varinfo) - stmt = ObserveStmt(; - left=left, right=right, varinfo=context.record_varinfo ? varinfo : nothing - ) - if context.record_statements - push!(context.statements, stmt) - end - return nothing -end - -function DynamicPPL.tilde_observe!!(context::DebugContext, right, left, vn, vi) - record_pre_tilde_observe!(context, left, right, vi) - vi = DynamicPPL.tilde_observe!!(childcontext(context), right, left, vn, vi) - record_post_tilde_observe!(context, left, right, vi) - return vi -end -function DynamicPPL.tilde_observe!!(context::DebugContext, sampler, right, left, vn, vi) - record_pre_tilde_observe!(context, left, right, vi) - vi = DynamicPPL.tilde_observe!!(childcontext(context), sampler, right, left, vn, vi) - record_post_tilde_observe!(context, left, right, vi) - return vi + stmt = ObserveStmt(; varname=vn, right=right, value=val) + push!(acc.statements, stmt) + return acc end _conditioned_varnames(d::AbstractDict) = keys(d) @@ -357,26 +328,26 @@ function check_model_pre_evaluation(model::Model) return issuccess end -function check_model_post_evaluation(model::Model) - return check_varnames_seen(model.context.varnames_seen) +function 
check_model_post_evaluation(acc::DebugAccumulator) + return check_varnames_seen(acc.varnames_seen) end """ - check_model_and_trace([rng, ]model::Model; kwargs...) + check_model_and_trace(model::Model, varinfo::AbstractVarInfo; error_on_failure=false) -Check that `model` is valid, warning about any potential issues. +Check that evaluating `model` with the given `varinfo` is valid, warning about any potential +issues. This will check the model for the following issues: + 1. Repeated usage of the same varname in a model. -2. Incorrectly treating a variable as random rather than fixed, and vice versa. +2. `NaN` on the left-hand side of observe statements. # Arguments -- `rng::Random.AbstractRNG`: The random number generator to use when evaluating the model. - `model::Model`: The model to check. +- `varinfo::AbstractVarInfo`: The varinfo to use when evaluating the model. -# Keyword Arguments -- `varinfo::VarInfo`: The varinfo to use when evaluating the model. Default: `VarInfo(model)`. -- `context::AbstractContext`: The context to use when evaluating the model. Default: [`DefaultContext`](@ref). +# Keyword Argument - `error_on_failure::Bool`: Whether to throw an error if the model check fails. Default: `false`. # Returns @@ -394,7 +365,9 @@ julia> rng = StableRNG(42); julia> @model demo_correct() = x ~ Normal() demo_correct (generic function with 2 methods) -julia> issuccess, trace = check_model_and_trace(rng, demo_correct()); +julia> model = demo_correct(); varinfo = VarInfo(rng, model); + +julia> issuccess, trace = check_model_and_trace(model, varinfo); julia> issuccess true @@ -402,7 +375,9 @@ true julia> print(trace) assume: x ~ Normal{Float64}(μ=0.0, σ=1.0) ⟼ -0.670252 -julia> issuccess, trace = check_model_and_trace(rng, demo_correct() | (x = 1.0,)); +julia> cond_model = model | (x = 1.0,); + +julia> issuccess, trace = check_model_and_trace(cond_model, VarInfo(cond_model)); ┌ Warning: The model does not contain any parameters. 
└ @ DynamicPPL.DebugUtils DynamicPPL.jl/src/debug_utils.jl:342 @@ -410,7 +385,7 @@ julia> issuccess true julia> print(trace) -observe: 1.0 ~ Normal{Float64}(μ=0.0, σ=1.0) + observe: x (= 1.0) ~ Normal{Float64}(μ=0.0, σ=1.0) ``` ## Incorrect model @@ -423,58 +398,53 @@ julia> @model function demo_incorrect() end demo_incorrect (generic function with 2 methods) -julia> issuccess, trace = check_model_and_trace(rng, demo_incorrect(); error_on_failure=true); +julia> # Notice that VarInfo(model) evaluates the model, but doesn't actually + # alert us to the issue of `x` being sampled twice. + model = demo_incorrect(); varinfo = VarInfo(model); + +julia> issuccess, trace = check_model_and_trace(model, varinfo; error_on_failure=true); ERROR: varname x used multiple times in model ``` """ -function check_model_and_trace(model::Model; kwargs...) - return check_model_and_trace(Random.default_rng(), model; kwargs...) -end function check_model_and_trace( - rng::Random.AbstractRNG, - model::Model; - varinfo=VarInfo(), - error_on_failure=false, - kwargs..., + model::Model, varinfo::AbstractVarInfo; error_on_failure=false ) - # Execute the model with the debug context. - debug_context = DebugContext( - SamplingContext(rng, model.context); error_on_failure=error_on_failure, kwargs... + # Add debug accumulator to the VarInfo. + # Need a NumProduceAccumulator as well or else get_num_produce may throw + # TODO(mhauru) Remove this once VariableOrderAccumulator stuff is done. + varinfo = DynamicPPL.setaccs!!( + deepcopy(varinfo), (DebugAccumulator(error_on_failure), NumProduceAccumulator()) ) - debug_model = DynamicPPL.contextualize(model, debug_context) # Perform checks before evaluating the model. - issuccess = check_model_pre_evaluation(debug_model) + issuccess = check_model_pre_evaluation(model) # Force single-threaded execution. 
- DynamicPPL.evaluate_threadunsafe!!(debug_model, varinfo) + DynamicPPL.evaluate_threadunsafe!!(model, varinfo) # Perform checks after evaluating the model. - issuccess &= check_model_post_evaluation(debug_model) + debug_acc = DynamicPPL.getacc(varinfo, Val(_DEBUG_ACC_NAME)) + issuccess = issuccess && check_model_post_evaluation(debug_acc) if !issuccess && error_on_failure error("model check failed") end - trace = debug_context.statements + trace = debug_acc.statements return issuccess, trace end """ - check_model([rng, ]model::Model; kwargs...) - -Check that `model` is valid, warning about any potential issues. + check_model(model::Model, varinfo::AbstractVarInfo; error_on_failure=false) -See [`check_model_and_trace`](@ref) for more details on supported keyword arguments -and details of which types of checks are performed. +Check that `model` is valid, warning about any potential issues (or erroring if +`error_on_failure` is `true`). # Returns - `issuccess::Bool`: Whether the model check succeeded. """ -check_model(model::Model; kwargs...) = first(check_model_and_trace(model; kwargs...)) -function check_model(rng::Random.AbstractRNG, model::Model; kwargs...) - return first(check_model_and_trace(rng, model; kwargs...)) -end +check_model(model::Model, varinfo::AbstractVarInfo; error_on_failure=false) = + first(check_model_and_trace(model, varinfo; error_on_failure=error_on_failure)) # Convenience method used to check if all elements in a list are the same. function all_the_same(xs) @@ -490,7 +460,7 @@ function all_the_same(xs) end """ - has_static_constraints([rng, ]model::Model; num_evals=5, kwargs...) + has_static_constraints([rng, ]model::Model; num_evals=5, error_on_failure=false) Return `true` if the model has static constraints, `false` otherwise. @@ -503,19 +473,16 @@ and checking if the model is consistent across runs. # Keyword Arguments - `num_evals::Int`: The number of evaluations to perform. Default: `5`. 
-- `kwargs...`: Additional keyword arguments to pass to [`check_model_and_trace`](@ref). +- `error_on_failure::Bool`: Whether to throw an error if any of the `num_evals` model + checks fail. Default: `false`. """ -function has_static_constraints(model::Model; kwargs...) - return has_static_constraints(Random.default_rng(), model; kwargs...) -end function has_static_constraints( - rng::Random.AbstractRNG, model::Model; num_evals=5, kwargs... + rng::Random.AbstractRNG, model::Model; num_evals::Int=5, error_on_failure::Bool=false ) + new_model = DynamicPPL.contextualize(model, SamplingContext(rng, SampleFromPrior())) results = map(1:num_evals) do _ - check_model_and_trace(rng, model; kwargs...) + check_model_and_trace(new_model, VarInfo(); error_on_failure=error_on_failure) end - issuccess = all(first, results) - issuccess || throw(ArgumentError("model check failed")) # Extract the distributions and the corresponding bijectors for each run. traces = map(last, results) @@ -527,6 +494,13 @@ function has_static_constraints( # Check if the distributions are the same across all runs. 
return all_the_same(transforms) end +function has_static_constraints( + model::Model; num_evals::Int=5, error_on_failure::Bool=false +) + return has_static_constraints( + Random.default_rng(), model; num_evals=num_evals, error_on_failure=error_on_failure + ) +end """ gen_evaluator_call_with_types(model[, varinfo]) diff --git a/test/debug_utils.jl b/test/debug_utils.jl index 8279ac51a..5bf741ff3 100644 --- a/test/debug_utils.jl +++ b/test/debug_utils.jl @@ -1,13 +1,6 @@ @testset "check_model" begin - @testset "context interface" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - context = DynamicPPL.DebugUtils.DebugContext() - DynamicPPL.TestUtils.test_context(context, model) - end - end - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - issuccess, trace = check_model_and_trace(model) + issuccess, trace = check_model_and_trace(model, VarInfo(model)) # These models should all work. @test issuccess @@ -33,11 +26,14 @@ return y ~ Normal() end buggy_model = buggy_demo_model() + varinfo = VarInfo(buggy_model) - @test_logs (:warn,) (:warn,) check_model(buggy_model) - issuccess = check_model(buggy_model; record_varinfo=false) + @test_logs (:warn,) (:warn,) check_model(buggy_model, varinfo) + issuccess = check_model(buggy_model, varinfo) @test !issuccess - @test_throws ErrorException check_model(buggy_model; error_on_failure=true) + @test_throws ErrorException check_model( + buggy_model, varinfo; error_on_failure=true + ) end @testset "submodel" begin @@ -48,7 +44,10 @@ return x ~ Normal() end model = ModelOuterBroken() - @test_throws ErrorException check_model(model; error_on_failure=true) + varinfo = VarInfo(model) + @test_throws ErrorException check_model( + model, VarInfo(model); error_on_failure=true + ) @model function ModelOuterWorking() # With automatic prefixing => `x` is not duplicated. 
@@ -57,7 +56,7 @@ return z end model = ModelOuterWorking() - @test check_model(model; error_on_failure=true) + @test check_model(model, VarInfo(model); error_on_failure=true) # With manual prefixing, https://github.com/TuringLang/DynamicPPL.jl/issues/785 @model function ModelOuterWorking2() @@ -66,7 +65,7 @@ return (x1, x2) end model = ModelOuterWorking2() - @test check_model(model; error_on_failure=true) + @test check_model(model, VarInfo(model); error_on_failure=true) end @testset "subsumes (x then x[1])" begin @@ -77,11 +76,14 @@ return nothing end buggy_model = buggy_subsumes_demo_model() + varinfo = VarInfo(buggy_model) - @test_logs (:warn,) (:warn,) check_model(buggy_model) - issuccess = check_model(buggy_model; record_varinfo=false) + @test_logs (:warn,) (:warn,) check_model(buggy_model, varinfo) + issuccess = check_model(buggy_model, varinfo) @test !issuccess - @test_throws ErrorException check_model(buggy_model; error_on_failure=true) + @test_throws ErrorException check_model( + buggy_model, varinfo; error_on_failure=true + ) end @testset "subsumes (x[1] then x)" begin @@ -92,11 +94,14 @@ return nothing end buggy_model = buggy_subsumes_demo_model() + varinfo = VarInfo(buggy_model) - @test_logs (:warn,) (:warn,) check_model(buggy_model) - issuccess = check_model(buggy_model; record_varinfo=false) + @test_logs (:warn,) (:warn,) check_model(buggy_model, varinfo) + issuccess = check_model(buggy_model, varinfo) @test !issuccess - @test_throws ErrorException check_model(buggy_model; error_on_failure=true) + @test_throws ErrorException check_model( + buggy_model, varinfo; error_on_failure=true + ) end @testset "subsumes (x.a then x)" begin @@ -107,11 +112,14 @@ return nothing end buggy_model = buggy_subsumes_demo_model() + varinfo = VarInfo(buggy_model) - @test_logs (:warn,) (:warn,) check_model(buggy_model) - issuccess = check_model(buggy_model; record_varinfo=false) + @test_logs (:warn,) (:warn,) check_model(buggy_model, varinfo) + issuccess = 
check_model(buggy_model, varinfo) @test !issuccess - @test_throws ErrorException check_model(buggy_model; error_on_failure=true) + @test_throws ErrorException check_model( + buggy_model, varinfo; error_on_failure=true + ) end end @@ -123,14 +131,14 @@ end end m = demo_nan_in_data([1.0, NaN]) - @test_throws ErrorException check_model(m; error_on_failure=true) + @test_throws ErrorException check_model(m, VarInfo(m); error_on_failure=true) # Test NamedTuples with nested arrays, see #898 @model function demo_nan_complicated(nt) nt ~ product_distribution((x=Normal(), y=Dirichlet([2, 4]))) return x ~ Normal() end m = demo_nan_complicated((x=1.0, y=[NaN, 0.5])) - @test_throws ErrorException check_model(m; error_on_failure=true) + @test_throws ErrorException check_model(m, VarInfo(m); error_on_failure=true) end @testset "incorrect use of condition" begin @@ -139,7 +147,10 @@ return x ~ MvNormal(zeros(length(x)), I) end model = demo_missing_in_multivariate([1.0, missing]) - @test_throws ErrorException check_model(model) + # Have to run this check_model call with an empty varinfo, because actually + # instantiating the VarInfo would cause it to throw a MethodError. 
+ model = contextualize(model, SamplingContext()) + @test_throws ErrorException check_model(model, VarInfo(); error_on_failure=true) end @testset "condition both in args and context" begin @@ -153,8 +164,9 @@ OrderedDict(@varname(x[1]) => 2.0), ] conditioned_model = DynamicPPL.condition(model, vals) + varinfo = VarInfo(conditioned_model) @test_throws ErrorException check_model( - conditioned_model; error_on_failure=true + conditioned_model, varinfo; error_on_failure=true ) end end @@ -163,23 +175,26 @@ @testset "printing statements" begin @testset "assume" begin @model demo_assume() = x ~ Normal() - isuccess, trace = check_model_and_trace(demo_assume()) - @test isuccess + model = demo_assume() + issuccess, trace = check_model_and_trace(model, VarInfo(model)) + @test issuccess @test startswith(string(trace), " assume: x ~ Normal") end @testset "observe" begin @model demo_observe(x) = x ~ Normal() - isuccess, trace = check_model_and_trace(demo_observe(1.0)) - @test isuccess - @test occursin(r"observe: \d+\.\d+ ~ Normal", string(trace)) + model = demo_observe(1.0) + issuccess, trace = check_model_and_trace(model, VarInfo(model)) + @test issuccess + @test occursin(r"observe: x \(= \d+\.\d+\) ~ Normal", string(trace)) end end @testset "comparing multiple traces" begin + # Run the same model but with different VarInfos. model = DynamicPPL.TestUtils.demo_dynamic_constraint() - issuccess_1, trace_1 = check_model_and_trace(model) - issuccess_2, trace_2 = check_model_and_trace(model) + issuccess_1, trace_1 = check_model_and_trace(model, VarInfo(model)) + issuccess_2, trace_2 = check_model_and_trace(model, VarInfo(model)) @test issuccess_1 && issuccess_2 # Should have the same varnames present. @@ -204,7 +219,7 @@ end for ns in [(2,), (2, 2), (2, 2, 2)] model = demo_undef(ns...) - @test check_model(model; error_on_failure=true) + @test check_model(model, VarInfo(model); error_on_failure=true) end end