Skip to content

[Merged by Bors] - **BREAKING** Make evaluator a method of the model function #316

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.14.2"
version = "0.15.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
23 changes: 13 additions & 10 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -470,10 +470,7 @@ Builds the output expression.
"""
function build_output(modelinfo, linenumbernode)
## Build the anonymous evaluator from the user-provided model definition.

# Remove the name.
evaluatordef = deepcopy(modelinfo[:modeldef])
delete!(evaluatordef, :name)

# Add the internal arguments to the user-specified arguments (positional + keywords).
evaluatordef[:args] = vcat(
Expand All @@ -489,7 +486,13 @@ function build_output(modelinfo, linenumbernode)
evaluatordef[:kwargs] = []

# Replace the user-provided function body with the version created by DynamicPPL.
evaluatordef[:body] = modelinfo[:body]
# We use `MacroTools.@q begin ... end` instead of regular `quote ... end` to ensure
# that no new `LineNumberNode`s are added apart from the reference `linenumbernode`
# to the call site
evaluatordef[:body] = MacroTools.@q begin
$(linenumbernode)
$(modelinfo[:body])
end

## Build the model function.

Expand All @@ -498,24 +501,24 @@ function build_output(modelinfo, linenumbernode)
defaults_namedtuple = modelinfo[:defaults_namedtuple]

# Update the function body of the user-specified model.
# We use a name for the anonymous evaluator that does not conflict with other variables.
modeldef = modelinfo[:modeldef]
@gensym evaluator
# We use `MacroTools.@q begin ... end` instead of regular `quote ... end` to ensure
# that no new `LineNumberNode`s are added apart from the reference `linenumbernode`
# to the call site
modeldef = modelinfo[:modeldef]
modeldef[:body] = MacroTools.@q begin
$(linenumbernode)
$evaluator = $(MacroTools.combinedef(evaluatordef))
return $(DynamicPPL.Model)(
$(QuoteNode(modeldef[:name])),
$evaluator,
$(modeldef[:name]),
$allargs_namedtuple,
$defaults_namedtuple,
)
end

return :($(Base).@__doc__ $(MacroTools.combinedef(modeldef)))
return MacroTools.@q begin
$(MacroTools.combinedef(evaluatordef))
$(Base).@__doc__ $(MacroTools.combinedef(modeldef))
end
end

function warn_empty(body)
Expand Down
16 changes: 8 additions & 8 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ julia> @model function demo()
x ~ Normal(m, 1)
return (; m=m, x=x)
end
demo (generic function with 1 method)
demo (generic function with 2 methods)

julia> model = demo();

Expand Down Expand Up @@ -161,7 +161,7 @@ julia> @model function demo_mv(::Type{TV}=Float64) where {TV}
m[2] ~ Normal()
return m
end
demo_mv (generic function with 2 methods)
demo_mv (generic function with 3 methods)

julia> model = demo_mv();

Expand Down Expand Up @@ -192,13 +192,13 @@ the use of [`@submodel`](@ref).

```jldoctest condition
julia> @model demo_inner() = m ~ Normal()
demo_inner (generic function with 1 method)
demo_inner (generic function with 2 methods)

julia> @model function demo_outer()
m = @submodel demo_inner()
return m
end
demo_outer (generic function with 1 method)
demo_outer (generic function with 2 methods)

julia> model = demo_outer();

Expand All @@ -218,7 +218,7 @@ julia> @model function demo_outer_prefix()
m = @submodel inner demo_inner()
return m
end
demo_outer_prefix (generic function with 1 method)
demo_outer_prefix (generic function with 2 methods)

julia> # This doesn't work now!
conditioned_model = demo_outer_prefix() | (m = 1.0, );
Expand Down Expand Up @@ -279,7 +279,7 @@ julia> @model function demo()
x ~ Normal(m, 1)
return (; m=m, x=x)
end
demo (generic function with 1 method)
demo (generic function with 2 methods)

julia> conditioned_model = condition(demo(), m = 1.0, x = 10.0);

Expand Down Expand Up @@ -333,7 +333,7 @@ julia> @model function demo()
m ~ Normal()
x ~ Normal(m, 1)
end
demo (generic function with 1 method)
demo (generic function with 2 methods)

julia> m = demo();

Expand Down Expand Up @@ -613,7 +613,7 @@ julia> @model function demo(xs)
end
return (m, )
end
demo (generic function with 1 method)
demo (generic function with 2 methods)

julia> model = demo(randn(10));

Expand Down
33 changes: 33 additions & 0 deletions test/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,15 @@ end

return x, y
end
@test length(methods(testmodel_comp)) == 2
testmodel_comp(1.0, 1.2)

# check if drawing from the prior works
@model function testmodel01(x=missing)
x ~ Normal()
return x
end
@test length(methods(testmodel01)) == 3
f0_mm = testmodel01()
@test mean(f0_mm() for _ in 1:1000) ≈ 0.0 atol = 0.1

Expand All @@ -58,6 +60,7 @@ end
x[2] ~ Normal()
return x
end
@test length(methods(testmodel02)) == 3
f0_mm = testmodel02()
@test all(x -> isapprox(x, 0; atol=0.1), mean(f0_mm() for _ in 1:1000))

Expand All @@ -66,6 +69,7 @@ end
return x
end
f01_mm = testmodel03()
@test length(methods(testmodel03)) == 3
@test mean(f01_mm() for _ in 1:1000) ≈ 0.5 atol = 0.1

# test if we get the correct return values
Expand All @@ -78,6 +82,7 @@ end

return x1, x2
end
@test length(methods(testmodel1)) == 2
f1_mm = testmodel1(1.0, 10.0)
@test f1_mm() == (1, 10)

Expand All @@ -95,6 +100,7 @@ end

return x1, x2
end
@test length(methods(testmodel2)) == 2
f1_mm = testmodel2(; x1=1.0, x2=10.0)
@test f1_mm() == (1, 10)

Expand Down Expand Up @@ -461,4 +467,31 @@ end
model = @model(x -> (x ~ Normal()))
end
end

@testset "dispatching with model" begin
f(x) = false

@model demo() = x ~ Normal()
@test !f(demo())
f(::Model{typeof(demo)}) = true
@test f(demo())

# Leads to re-definition of `demo` and trait is not affected.
@test length(methods(demo)) == 2
@model demo() = x ~ Normal()
@test length(methods(demo)) == 2
@test f(demo())

# Ensure we can specialize on arguments.
@model demo(x) = x ~ Normal()
length(methods(demo))
@test f(demo(1.0))
f(::Model{typeof(demo),(:x,)}) = false
@test !f(demo(1.0))
@test f(demo()) # should still be `true`

# Set it to `false` again.
f(::Model{typeof(demo),()}) = false
@test !f(demo())
end
end
2 changes: 1 addition & 1 deletion test/turing/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[compat]
DynamicPPL = "0.14"
DynamicPPL = "0.15"
Turing = "0.17"
julia = "1.3"