Skip to content

Implement values_as_in_model using an accumulator #908

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 24 additions & 16 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,18 @@
end
end

"""
make_varname_expression(expr)

Return a `VarName` based on `expr`, concretizing it if necessary.
"""
function make_varname_expression(expr)
# HACK: Usage of `drop_escape` is unfortunate. It's a consequence of the fact
# that in DynamicPPL we the entire function body. Instead we should be
# more selective with our escape. Until that's the case, we remove them all.
return AbstractPPL.drop_escape(varname(expr, need_concretize(expr)))
end

"""
isassumption(expr[, vn])

Expand All @@ -48,10 +60,7 @@
If `vn` is not specified, `AbstractPPL.varname(expr, need_concretize(expr))` will be
used in its place.
"""
function isassumption(
expr::Union{Expr,Symbol},
vn=AbstractPPL.drop_escape(varname(expr, need_concretize(expr))),
)
function isassumption(expr::Union{Expr,Symbol}, vn=make_varname_expression(expr))
return quote
if $(DynamicPPL.contextual_isassumption)(
__context__, $(DynamicPPL.prefix)(__context__, $vn)
Expand Down Expand Up @@ -402,14 +411,18 @@
end

function generate_assign(left, right)
right_expr = :($(TrackedValue)($right))
tilde_expr = generate_tilde(left, right_expr)
# A statement `x := y` reduces to `x = y`, but if __varinfo__ has an accumulator for
# ValuesAsInModel then in addition we push! the pair of `x` and `y` to the accumulator.
@gensym acc right_val vn
return quote
if $(is_extracting_values)(__context__)
$tilde_expr
else
$left = $right
$right_val = $right
if $(DynamicPPL.is_extracting_values)(__varinfo__)
$vn = $(DynamicPPL.prefix)(__context__, $(make_varname_expression(left)))
__varinfo__ = $(map_accumulator!!)(
$acc -> push!($acc, $vn, $right_val), __varinfo__, Val(:ValuesAsInModel)

Check warning on line 422 in src/compiler.jl

View check run for this annotation

Codecov / codecov/patch

src/compiler.jl#L418-L422

Added lines #L418 - L422 were not covered by tests
)
end
$left = $right_val

Check warning on line 425 in src/compiler.jl

View check run for this annotation

Codecov / codecov/patch

src/compiler.jl#L425

Added line #L425 was not covered by tests
end
end

Expand Down Expand Up @@ -437,14 +450,9 @@
# if the LHS represents an observation
@gensym vn isassumption value dist

# HACK: Usage of `drop_escape` is unfortunate. It's a consequence of the fact
# that in DynamicPPL we the entire function body. Instead we should be
# more selective with our escape. Until that's the case, we remove them all.
return quote
$dist = $right
$vn = $(DynamicPPL.resolve_varnames)(
$(AbstractPPL.drop_escape(varname(left, need_concretize(left)))), $dist
)
$vn = $(DynamicPPL.resolve_varnames)($(make_varname_expression(left)), $dist)
$isassumption = $(DynamicPPL.isassumption(left, vn))
if $(DynamicPPL.isfixed(left, vn))
$left = $(DynamicPPL.getfixed_nested)(
Expand Down
95 changes: 31 additions & 64 deletions src/values_as_in_model.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,7 @@
struct TrackedValue{T}
value::T
end

is_tracked_value(::TrackedValue) = true
is_tracked_value(::Any) = false

check_tilde_rhs(x::TrackedValue) = x

"""
ValuesAsInModelContext
ValuesAsInModelAccumulator <: AbstractAccumulator

A context that is used by [`values_as_in_model`](@ref) to obtain values
An accumulator that is used by [`values_as_in_model`](@ref) to obtain values
of the model parameters as they are in the model.

This is particularly useful when working in unconstrained space, but one
Expand All @@ -19,72 +10,47 @@
# Fields
$(TYPEDFIELDS)
"""
struct ValuesAsInModelContext{C<:AbstractContext} <: AbstractContext
struct ValuesAsInModelAccumulator <: AbstractAccumulator
"values that are extracted from the model"
values::OrderedDict
"whether to extract variables on the LHS of :="
include_colon_eq::Bool
"child context"
context::C
end
function ValuesAsInModelContext(include_colon_eq, context::AbstractContext)
return ValuesAsInModelContext(OrderedDict(), include_colon_eq, context)
function ValuesAsInModelAccumulator(include_colon_eq)
return ValuesAsInModelAccumulator(OrderedDict(), include_colon_eq)
end

NodeTrait(::ValuesAsInModelContext) = IsParent()
childcontext(context::ValuesAsInModelContext) = context.context
function setchildcontext(context::ValuesAsInModelContext, child)
return ValuesAsInModelContext(context.values, context.include_colon_eq, child)
end
accumulator_name(::Type{<:ValuesAsInModelAccumulator}) = :ValuesAsInModel

is_extracting_values(context::ValuesAsInModelContext) = context.include_colon_eq
function is_extracting_values(context::AbstractContext)
return is_extracting_values(NodeTrait(context), context)
function split(acc::ValuesAsInModelAccumulator)
return ValuesAsInModelAccumulator(empty(acc.values), acc.include_colon_eq)
end
is_extracting_values(::IsParent, ::AbstractContext) = false
is_extracting_values(::IsLeaf, ::AbstractContext) = false

function Base.push!(context::ValuesAsInModelContext, vn::VarName, value)
return setindex!(context.values, copy(value), prefix(context, vn))
function combine(acc1::ValuesAsInModelAccumulator, acc2::ValuesAsInModelAccumulator)
if acc1.include_colon_eq != acc2.include_colon_eq
msg = "Cannot combine accumulators with different include_colon_eq values."
throw(ArgumentError(msg))

Check warning on line 31 in src/values_as_in_model.jl

View check run for this annotation

Codecov / codecov/patch

src/values_as_in_model.jl#L30-L31

Added lines #L30 - L31 were not covered by tests
end
return ValuesAsInModelAccumulator(
merge(acc1.values, acc2.values), acc1.include_colon_eq
)
end

function broadcast_push!(context::ValuesAsInModelContext, vns, values)
return push!.((context,), vns, values)
function Base.push!(acc::ValuesAsInModelAccumulator, vn::VarName, val)
setindex!(acc.values, deepcopy(val), vn)
return acc
end

# This will be hit if we're broadcasting an `AbstractMatrix` over a `MultivariateDistribution`.
function broadcast_push!(
context::ValuesAsInModelContext, vns::AbstractVector, values::AbstractMatrix
)
for (vn, col) in zip(vns, eachcol(values))
push!(context, vn, col)
end
function is_extracting_values(vi::AbstractVarInfo)
return hasacc(vi, Val(:ValuesAsInModel)) &&
getacc(vi, Val(:ValuesAsInModel)).include_colon_eq
end

# `tilde_asssume`
function tilde_assume(context::ValuesAsInModelContext, right, vn, vi)
if is_tracked_value(right)
value = right.value
else
value, vi = tilde_assume(childcontext(context), right, vn, vi)
end
push!(context, vn, value)
return value, vi
end
function tilde_assume(
rng::Random.AbstractRNG, context::ValuesAsInModelContext, sampler, right, vn, vi
)
if is_tracked_value(right)
value = right.value
else
value, vi = tilde_assume(rng, childcontext(context), sampler, right, vn, vi)
end
# Save the value.
push!(context, vn, value)
# Pass on.
return value, vi
function accumulate_assume!!(acc::ValuesAsInModelAccumulator, val, logjac, vn, right)
return push!(acc, vn, val)
end

accumulate_observe!!(acc::ValuesAsInModelAccumulator, right, left, vn) = acc

"""
values_as_in_model(model::Model, include_colon_eq::Bool, varinfo::AbstractVarInfo[, context::AbstractContext])

Expand All @@ -103,7 +69,7 @@
- `model::Model`: model to extract realizations from.
- `include_colon_eq::Bool`: whether to also include variables on the LHS of `:=`.
- `varinfo::AbstractVarInfo`: variable information to use for the extraction.
- `context::AbstractContext`: base context to use for the extraction. Defaults
- `context::AbstractContext`: evaluation context to use in the extraction. Defaults
to `DynamicPPL.DefaultContext()`.

# Examples
Expand Down Expand Up @@ -164,7 +130,8 @@
varinfo::AbstractVarInfo,
context::AbstractContext=DefaultContext(),
)
context = ValuesAsInModelContext(include_colon_eq, context)
evaluate!!(model, varinfo, context)
return context.values
accs = getaccs(varinfo)
varinfo = setaccs!!(deepcopy(varinfo), (ValuesAsInModelAccumulator(include_colon_eq),))
varinfo = last(evaluate!!(model, varinfo, context))
return getacc(varinfo, Val(:ValuesAsInModel)).values
end
31 changes: 29 additions & 2 deletions test/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -732,10 +732,10 @@ module Issue537 end
y := 100 + x
return (; x, y)
end
@model function demo_tracked_submodel()
@model function demo_tracked_submodel_no_prefix()
return vals ~ to_submodel(demo_tracked(), false)
end
for model in [demo_tracked(), demo_tracked_submodel()]
for model in [demo_tracked(), demo_tracked_submodel_no_prefix()]
# Make sure it's runnable and `y` is present in the return-value.
@test model() isa NamedTuple{(:x, :y)}

Expand All @@ -756,6 +756,33 @@ module Issue537 end
@test haskey(values, @varname(x))
@test !haskey(values, @varname(y))
end

@model function demo_tracked_return_x()
x ~ Normal()
y := 100 + x
return x
end
@model function demo_tracked_submodel_prefix()
return a ~ to_submodel(demo_tracked_return_x())
end
@model function demo_tracked_subsubmodel_prefix()
return b ~ to_submodel(demo_tracked_submodel_prefix())
end
# As above, but the variables should now have their names prefixed with `b.a`.
model = demo_tracked_subsubmodel_prefix()
varinfo = VarInfo(model)
@test haskey(varinfo, @varname(b.a.x))
@test length(keys(varinfo)) == 1

values = values_as_in_model(model, true, deepcopy(varinfo))
@test haskey(values, @varname(b.a.x))
@test haskey(values, @varname(b.a.y))

# And if include_colon_eq is set to `false`, then `values` should
# only contain `x`.
values = values_as_in_model(model, false, deepcopy(varinfo))
@test haskey(values, @varname(b.a.x))
@test length(keys(varinfo)) == 1
end

@testset "signature parsing + TypeWrap" begin
Expand Down
2 changes: 1 addition & 1 deletion test/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
@test DynamicPPL.prefix(ctx2, vn) == @varname(a.x[1])
ctx3 = PrefixContext(@varname(b), ctx2)
@test DynamicPPL.prefix(ctx3, vn) == @varname(b.a.x[1])
ctx4 = DynamicPPL.ValuesAsInModelContext(OrderedDict(), false, ctx3)
ctx4 = DynamicPPL.SamplingContext(ctx3)
@test DynamicPPL.prefix(ctx4, vn) == @varname(b.a.x[1])
end

Expand Down
Loading