Skip to content

Get rid of repeated construction of varname lenses #310

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 38 additions & 37 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ const INTERNALNAMES = (:__model__, :__sampler__, :__context__, :__varinfo__, :__
const DEPRECATED_INTERNALNAMES = (:_model, :_sampler, :_context, :_varinfo, :_rng)

"""
isassumption(expr)
isassumption(expr, vn)

Return an expression that can be evaluated to check if `expr` is an assumption in the
model.
Expand All @@ -15,38 +15,38 @@ Let `expr` be `:(x[1])`. It is an assumption in the following cases:

When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`.
"""
function isassumption(expr::Union{Symbol,Expr})
vn = gensym(:vn)

function isassumption(expr::Union{Expr,Symbol}, vn=AbstractPPL.drop_escape(varname(expr)))
return quote
let $vn = $(AbstractPPL.drop_escape(varname(expr)))
if $(DynamicPPL.contextual_isassumption)(__context__, $vn)
# Considered an assumption by `__context__` which means either:
# 1. We hit the default implementation, e.g. using `DefaultContext`,
# which in turn means that we haven't considered if it's one of
# the model arguments, hence we need to check this.
# 2. We are working with a `ConditionContext` _and_ it's NOT in the model arguments,
# i.e. we're trying to condition one of the latent variables.
# In this case, the below will return `true` since the first branch
# will be hit.
# 3. We are working with a `ConditionContext` _and_ it's in the model arguments,
# i.e. we're trying to override the value. This is currently NOT supported.
# TODO: Support by adding context to model, and use `model.args`
# as the default conditioning. Then we no longer need to check `inargnames`
# since it will all be handled by `contextual_isassumption`.
if !($(DynamicPPL.inargnames)($vn, __model__)) ||
$(DynamicPPL.inmissings)($vn, __model__)
true
else
$(maybe_view(expr)) === missing
end
if $(DynamicPPL.contextual_isassumption)(__context__, $vn)
# Considered an assumption by `__context__` which means either:
# 1. We hit the default implementation, e.g. using `DefaultContext`,
# which in turn means that we haven't considered if it's one of
# the model arguments, hence we need to check this.
# 2. We are working with a `ConditionContext` _and_ it's NOT in the model arguments,
# i.e. we're trying to condition one of the latent variables.
# In this case, the below will return `true` since the first branch
# will be hit.
# 3. We are working with a `ConditionContext` _and_ it's in the model arguments,
# i.e. we're trying to override the value. This is currently NOT supported.
# TODO: Support by adding context to model, and use `model.args`
# as the default conditioning. Then we no longer need to check `inargnames`
# since it will all be handled by `contextual_isassumption`.
if !($(DynamicPPL.inargnames)($vn, __model__)) ||
$(DynamicPPL.inmissings)($vn, __model__)
true
else
false
$(maybe_view(expr)) === missing
end
else
false
end
end
end

# failsafe: a literal is never an assumption
isassumption(expr, vn) = :(false)
isassumption(expr) = :(false)

"""
contextual_isassumption(context, vn)

Expand Down Expand Up @@ -80,9 +80,6 @@ function contextual_isassumption(context::PrefixContext, vn)
return contextual_isassumption(childcontext(context), prefix(context, vn))
end

# failsafe: a literal is never an assumption
isassumption(expr) = :(false)

# If we're working with, say, a `Symbol`, then we're not going to `view`.
maybe_view(x) = x
maybe_view(x::Expr) = :($(DynamicPPL.maybe_unwrap_view)(@views($x)))
Expand Down Expand Up @@ -396,7 +393,7 @@ function generate_tilde(left, right)
# more selective with our escape. Until that's the case, we remove them all.
return quote
$vn = $(AbstractPPL.drop_escape(varname(left)))
$isassumption = $(DynamicPPL.isassumption(left))
$isassumption = $(DynamicPPL.isassumption(left, vn))
if $isassumption
$(generate_tilde_assume(left, right, vn))
else
Expand All @@ -417,20 +414,24 @@ function generate_tilde(left, right)
end

function generate_tilde_assume(left, right, vn)
expr = :(
$left = $(DynamicPPL.tilde_assume!)(
tilde = :(
$(DynamicPPL.tilde_assume!)(
__context__,
$(DynamicPPL.unwrap_right_vn)($(DynamicPPL.check_tilde_rhs)($right), $vn)...,
__varinfo__,
)
)

return if left isa Expr
AbstractPPL.drop_escape(
Setfield.setmacro(BangBang.prefermutation, expr; overwrite=true)
)
# `x[i] = ...` needs to become `x = set(x, @lens(_[i]), ...)`
@gensym lens
vn_name = AbstractPPL.vsym(left)
quote
$lens = $(BangBang.prefermutation)($(DynamicPPL.getindexing)($vn))
$vn_name = $(Setfield.set)($vn_name, $lens, $tilde)
end
else
return expr
return :($left = $tilde)
end
end

Expand All @@ -447,7 +448,7 @@ function generate_dot_tilde(left, right)
@gensym vn isassumption
return quote
$vn = $(AbstractPPL.drop_escape(varname(left)))
$isassumption = $(DynamicPPL.isassumption(left))
$isassumption = $(DynamicPPL.isassumption(left, vn))
if $isassumption
$(generate_dot_tilde_assume(left, right, vn))
else
Expand Down