Skip to content

Commit a836e18

Browse files
authored
Try #316:
2 parents 80240bf + a558e2e commit a836e18

File tree

3 files changed

+54
-18
lines changed

3 files changed

+54
-18
lines changed

src/compiler.jl

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -470,10 +470,7 @@ Builds the output expression.
470470
"""
471471
function build_output(modelinfo, linenumbernode)
472472
## Build the anonymous evaluator from the user-provided model definition.
473-
474-
# Remove the name.
475473
evaluatordef = deepcopy(modelinfo[:modeldef])
476-
delete!(evaluatordef, :name)
477474

478475
# Add the internal arguments to the user-specified arguments (positional + keywords).
479476
evaluatordef[:args] = vcat(
@@ -489,7 +486,13 @@ function build_output(modelinfo, linenumbernode)
489486
evaluatordef[:kwargs] = []
490487

491488
# Replace the user-provided function body with the version created by DynamicPPL.
492-
evaluatordef[:body] = modelinfo[:body]
489+
# We use `MacroTools.@q begin ... end` instead of regular `quote ... end` to ensure
490+
# that no new `LineNumberNode`s are added apart from the reference `linenumbernode`
491+
# to the call site
492+
evaluatordef[:body] = MacroTools.@q begin
493+
$(linenumbernode)
494+
$(modelinfo[:body])
495+
end
493496

494497
## Build the model function.
495498

@@ -498,24 +501,24 @@ function build_output(modelinfo, linenumbernode)
498501
defaults_namedtuple = modelinfo[:defaults_namedtuple]
499502

500503
# Update the function body of the user-specified model.
501-
# We use a name for the anonymous evaluator that does not conflict with other variables.
502-
modeldef = modelinfo[:modeldef]
503-
@gensym evaluator
504504
# We use `MacroTools.@q begin ... end` instead of regular `quote ... end` to ensure
505505
# that no new `LineNumberNode`s are added apart from the reference `linenumbernode`
506506
# to the call site
507+
modeldef = modelinfo[:modeldef]
507508
modeldef[:body] = MacroTools.@q begin
508509
$(linenumbernode)
509-
$evaluator = $(MacroTools.combinedef(evaluatordef))
510510
return $(DynamicPPL.Model)(
511511
$(QuoteNode(modeldef[:name])),
512-
$evaluator,
512+
$(modeldef[:name]),
513513
$allargs_namedtuple,
514514
$defaults_namedtuple,
515515
)
516516
end
517517

518-
return :($(Base).@__doc__ $(MacroTools.combinedef(modeldef)))
518+
return MacroTools.@q begin
519+
$(MacroTools.combinedef(evaluatordef))
520+
$(Base).@__doc__ $(MacroTools.combinedef(modeldef))
521+
end
519522
end
520523

521524
function warn_empty(body)

src/model.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ julia> @model function demo()
121121
x ~ Normal(m, 1)
122122
return (; m=m, x=x)
123123
end
124-
demo (generic function with 1 method)
124+
demo (generic function with 2 methods)
125125
126126
julia> model = demo();
127127
@@ -161,7 +161,7 @@ julia> @model function demo_mv(::Type{TV}=Float64) where {TV}
161161
m[2] ~ Normal()
162162
return m
163163
end
164-
demo_mv (generic function with 2 methods)
164+
demo_mv (generic function with 3 methods)
165165
166166
julia> model = demo_mv();
167167
@@ -192,13 +192,13 @@ the use of [`@submodel`](@ref).
192192
193193
```jldoctest condition
194194
julia> @model demo_inner() = m ~ Normal()
195-
demo_inner (generic function with 1 method)
195+
demo_inner (generic function with 2 methods)
196196
197197
julia> @model function demo_outer()
198198
m = @submodel demo_inner()
199199
return m
200200
end
201-
demo_outer (generic function with 1 method)
201+
demo_outer (generic function with 2 methods)
202202
203203
julia> model = demo_outer();
204204
@@ -218,7 +218,7 @@ julia> @model function demo_outer_prefix()
218218
m = @submodel inner demo_inner()
219219
return m
220220
end
221-
demo_outer_prefix (generic function with 1 method)
221+
demo_outer_prefix (generic function with 2 methods)
222222
223223
julia> # This doesn't work now!
224224
conditioned_model = demo_outer_prefix() | (m = 1.0, );
@@ -279,7 +279,7 @@ julia> @model function demo()
279279
x ~ Normal(m, 1)
280280
return (; m=m, x=x)
281281
end
282-
demo (generic function with 1 method)
282+
demo (generic function with 2 methods)
283283
284284
julia> conditioned_model = condition(demo(), m = 1.0, x = 10.0);
285285
@@ -333,7 +333,7 @@ julia> @model function demo()
333333
m ~ Normal()
334334
x ~ Normal(m, 1)
335335
end
336-
demo (generic function with 1 method)
336+
demo (generic function with 2 methods)
337337
338338
julia> m = demo();
339339
@@ -613,7 +613,7 @@ julia> @model function demo(xs)
613613
end
614614
return (m, )
615615
end
616-
demo (generic function with 1 method)
616+
demo (generic function with 2 methods)
617617
618618
julia> model = demo(randn(10));
619619

test/compiler.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,15 @@ end
3939

4040
return x, y
4141
end
42+
@test length(methods(testmodel_comp)) == 2
4243
testmodel_comp(1.0, 1.2)
4344

4445
# check if drawing from the prior works
4546
@model function testmodel01(x=missing)
4647
x ~ Normal()
4748
return x
4849
end
50+
@test length(methods(testmodel01)) == 3
4951
f0_mm = testmodel01()
5052
@test mean(f0_mm() for _ in 1:1000) 0.0 atol = 0.1
5153

@@ -58,6 +60,7 @@ end
5860
x[2] ~ Normal()
5961
return x
6062
end
63+
@test length(methods(testmodel02)) == 3
6164
f0_mm = testmodel02()
6265
@test all(x -> isapprox(x, 0; atol=0.1), mean(f0_mm() for _ in 1:1000))
6366

@@ -66,6 +69,7 @@ end
6669
return x
6770
end
6871
f01_mm = testmodel03()
72+
@test length(methods(testmodel03)) == 3
6973
@test mean(f01_mm() for _ in 1:1000) 0.5 atol = 0.1
7074

7175
# test if we get the correct return values
@@ -78,6 +82,7 @@ end
7882

7983
return x1, x2
8084
end
85+
@test length(methods(testmodel1)) == 2
8186
f1_mm = testmodel1(1.0, 10.0)
8287
@test f1_mm() == (1, 10)
8388

@@ -95,6 +100,7 @@ end
95100

96101
return x1, x2
97102
end
103+
@test length(methods(testmodel2)) == 2
98104
f1_mm = testmodel2(; x1=1.0, x2=10.0)
99105
@test f1_mm() == (1, 10)
100106

@@ -461,4 +467,31 @@ end
461467
model = @model(x -> (x ~ Normal()))
462468
end
463469
end
470+
471+
@testset "dispatching with model" begin
472+
f(x) = false
473+
474+
@model demo() = x ~ Normal()
475+
@test !f(demo())
476+
f(::Model{typeof(demo)}) = true
477+
@test f(demo())
478+
479+
# Leads to re-definition of `demo` and trait is not affected.
480+
@test length(methods(demo)) == 2
481+
@model demo() = x ~ Normal()
482+
@test length(methods(demo)) == 2
483+
@test f(demo())
484+
485+
# Ensure we can specialize on arguments.
486+
@model demo(x) = x ~ Normal()
487+
length(methods(demo))
488+
@test f(demo(1.0))
489+
f(::Model{typeof(demo),(:x,)}) = false
490+
@test !f(demo(1.0))
491+
@test f(demo()) # should still be `true`
492+
493+
# Set it to `false` again.
494+
f(::Model{typeof(demo),()}) = false
495+
@test !f(demo())
496+
end
464497
end

0 commit comments

Comments
 (0)