Skip to content
65 changes: 48 additions & 17 deletions src/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,17 @@ function AbstractMCMC.step(
model::Model,
sampler::Union{SampleFromUniform,SampleFromPrior},
state=nothing;
trace_type=VarInfo,
kwargs...,
)
vi = VarInfo()
model(rng, vi, sampler)
if trace_type === VarInfo
vi = VarInfo()
model(rng, vi, sampler)
elseif trace_type === SimpleVarInfo
vi = last(evaluate!!(model, rng, SimpleVarInfo{Float64}(OrderedDict()), sampler))
else
throw(ArgumentError("Unknown trace type: $trace_type"))
end
return vi, nothing
end

Expand Down Expand Up @@ -97,23 +104,47 @@ end

# Initial step: general interface for starting (or resuming) sampling, optionally from user-provided `initial_params`.
function AbstractMCMC.step(
rng::Random.AbstractRNG, model::Model, spl::Sampler; initial_params=nothing, kwargs...
rng::Random.AbstractRNG,
model::Model,
spl::Sampler;
initial_params=nothing,
trace_type=VarInfo,
kwargs...,
)
# Sample initial values.
vi = default_varinfo(rng, model, spl)

# Update the parameters if provided.
if initial_params !== nothing
vi = initialize_parameters!!(vi, initial_params, spl, model)

# Update joint log probability.
# This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588
# and https://github.com/TuringLang/Turing.jl/issues/1563
# so that existing variables are not resampled.
vi = last(evaluate!!(model, vi, DefaultContext()))
end
if trace_type === VarInfo
# Sample initial values.
vi = default_varinfo(rng, model, spl)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't quite make sense, because `default_varinfo` can also return a `SimpleVarInfo`, so the `trace_type === VarInfo` branch is not guaranteed to actually produce a `VarInfo`.


# Update the parameters if provided.
if initial_params !== nothing
vi = initialize_parameters!!(vi, initial_params, spl, model)

# Update joint log probability.
# This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588
# and https://github.com/TuringLang/Turing.jl/issues/1563
# so that existing variables are not resampled.
vi = last(evaluate!!(model, vi, DefaultContext()))
end

return initialstep(rng, model, spl, vi; initial_params, kwargs...)
elseif trace_type === SimpleVarInfo
vi = last(
DynamicPPL.evaluate!!(
model,
SimpleVarInfo{Float64}(OrderedDict()),
SamplingContext(rng, SampleFromPrior(), DefaultContext()),
),
)

if initial_params !== nothing
vi = initialize_parameters!!(vi, initial_params, spl, model)
vi = last(evaluate!!(model, vi, DefaultContext()))
end

return initialstep(rng, model, spl, vi; initial_params, kwargs...)
return initialstep(rng, model, spl, vi; initial_params, kwargs...)
else
throw(ArgumentError("Unknown trace type: $trace_type"))
end
end

"""
Expand Down
2 changes: 2 additions & 0 deletions src/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,8 @@ function dot_assume(
return value, lp, vi
end

"""
    updategid!(vi::SimpleOrThreadSafeSimple, vn::VarName, spl::Sampler)

No-op method for simple varinfos: ignores `vn` and `spl`, leaves `vi` untouched, and
returns `nothing`.

NOTE(review): presumably this exists because simple varinfos do not track sampler
ids per variable — confirm against the `VarInfo` implementation.
"""
function updategid!(vi::SimpleOrThreadSafeSimple, vn::VarName, spl::Sampler)
    return nothing
end

# NOTE: We don't implement `settrans!!(vi, trans, vn)`.
function settrans!!(vi::SimpleVarInfo, trans)
return settrans!!(vi, trans ? DynamicTransformation() : NoTransformation())
Expand Down
151 changes: 140 additions & 11 deletions test/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,44 +11,63 @@
N = 1_000

chains = sample(model, SampleFromPrior(), N; progress=false)
chains_svi = sample(
model, SampleFromPrior(), N; progress=false, trace_type=SimpleVarInfo
)
@test chains isa Vector{<:VarInfo}
@test length(chains) == N
@test chains_svi isa Vector{<:SimpleVarInfo}
@test length(chains_svi) == N

# Expected value of ``X`` where ``X ~ N(2, ...)`` is 2.
@test mean(vi[@varname(m)] for vi in chains) ≈ 2 atol = 0.15
@test mean(vi[@varname(m)] for vi in chains_svi) ≈ 2 atol = 0.15

# Expected value of ``X`` where ``X ~ IG(2, 3)`` is 3.
@test mean(vi[@varname(s)] for vi in chains) ≈ 3 atol = 0.2
@test mean(vi[@varname(s)] for vi in chains_svi) ≈ 3 atol = 0.2

chains = sample(model, SampleFromUniform(), N; progress=false)
chains_svi = sample(
model, SampleFromUniform(), N; progress=false, trace_type=SimpleVarInfo
)
@test chains isa Vector{<:VarInfo}
@test length(chains) == N
@test chains_svi isa Vector{<:SimpleVarInfo}
@test length(chains_svi) == N

# `m` is Gaussian, i.e. no transformation is used, so it
# should have a mean equal to its prior, i.e. 2.
@test mean(vi[@varname(m)] for vi in chains) ≈ 2 atol = 0.1
@test mean(vi[@varname(m)] for vi in chains_svi) ≈ 2 atol = 0.1

# Expected value of ``exp(X)`` where ``X ~ U[-2, 2]`` is ≈ 1.8.
@test mean(vi[@varname(s)] for vi in chains) ≈ 1.8 atol = 0.1
@test mean(vi[@varname(s)] for vi in chains_svi) ≈ 1.8 atol = 0.1
end

@testset "init" begin
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
N = 1000
chain_init = sample(model, SampleFromUniform(), N; progress=false)
chain_init_svi = sample(
model, SampleFromUniform(), N; progress=false, trace_type=SimpleVarInfo
)

for vn in keys(first(chain_init))
if AbstractPPL.subsumes(@varname(s), vn)
# `s ~ InverseGamma(2, 3)` and its unconstrained value will be sampled from Unif[-2,2].
dist = InverseGamma(2, 3)
b = DynamicPPL.link_transform(dist)
@test mean(mean(b(vi[vn])) for vi in chain_init) ≈ 0 atol = 0.11
elseif AbstractPPL.subsumes(@varname(m), vn)
# `m ~ Normal(0, sqrt(s))` and its constrained value is the same.
@test mean(mean(vi[vn]) for vi in chain_init) ≈ 0 atol = 0.11
else
error("Unknown variable name: $vn")
for chain in (chain_init, chain_init_svi)
for vn in keys(first(chain))
if AbstractPPL.subsumes(@varname(s), vn)
# `s ~ InverseGamma(2, 3)` and its unconstrained value will be sampled from Unif[-2,2].
dist = InverseGamma(2, 3)
b = DynamicPPL.link_transform(dist)
@test mean(mean(b(vi[vn])) for vi in chain) ≈ 0 atol = 0.11
elseif AbstractPPL.subsumes(@varname(m), vn)
# `m ~ Normal(0, sqrt(s))` and its constrained value is the same.
@test mean(mean(vi[vn]) for vi in chain) ≈ 0 atol = 0.11
else
error("Unknown variable name: $vn")
end
end
end
end
Expand Down Expand Up @@ -85,8 +104,18 @@
sampler = Sampler(alg)
lptrue = logpdf(Binomial(25, 0.2), 10)
chain = sample(model, sampler, 1; initial_params=0.2, progress=false)
chain_svi = sample(
model,
sampler,
1;
initial_params=0.2,
progress=false,
trace_type=SimpleVarInfo,
)
@test chain[1].metadata.p.vals == [0.2]
@test getlogp(chain[1]) == lptrue
@test chain_svi[1][@varname(p)] == 0.2
@test getlogp(chain_svi[1]) == lptrue

# parallel sampling
chains = sample(
Expand All @@ -103,6 +132,21 @@
@test getlogp(c[1]) == lptrue
end

chains_svi = sample(
model,
sampler,
MCMCThreads(),
1,
10;
initial_params=fill(0.2, 10),
progress=false,
trace_type=SimpleVarInfo,
)
for c in chains_svi
@test c[1][@varname(p)] == 0.2
@test getlogp(c[1]) == lptrue
end

# model with two variables: initialization s = 4, m = -1
@model function twovars()
s ~ InverseGamma(2, 3)
Expand All @@ -114,6 +158,17 @@
@test chain[1].metadata.s.vals == [4]
@test chain[1].metadata.m.vals == [-1]
@test getlogp(chain[1]) == lptrue
chain_svi = sample(
model,
sampler,
1;
initial_params=[4, -1],
progress=false,
trace_type=SimpleVarInfo,
)
@test chain_svi[1][@varname(s)] == 4
@test chain_svi[1][@varname(m)] == -1
@test getlogp(chain_svi[1]) == lptrue

# parallel sampling
chains = sample(
Expand All @@ -131,10 +186,36 @@
@test getlogp(c[1]) == lptrue
end

chains_svi = sample(
model,
sampler,
MCMCThreads(),
1,
10;
initial_params=fill([4, -1], 10),
progress=false,
trace_type=SimpleVarInfo,
)
for c in chains_svi
@test c[1][@varname(s)] == 4
@test c[1][@varname(m)] == -1
@test getlogp(c[1]) == lptrue
end

# set only m = -1
chain = sample(model, sampler, 1; initial_params=[missing, -1], progress=false)
@test !ismissing(chain[1].metadata.s.vals[1])
@test chain[1].metadata.m.vals == [-1]
chain_svi = sample(
model,
sampler,
1;
initial_params=[missing, -1],
progress=false,
trace_type=SimpleVarInfo,
)
@test !ismissing(chain_svi[1][@varname(s)])
@test chain_svi[1][@varname(m)] == -1

# parallel sampling
chains = sample(
Expand All @@ -150,26 +231,74 @@
@test !ismissing(c[1].metadata.s.vals[1])
@test c[1].metadata.m.vals == [-1]
end
chains_svi = sample(
model,
sampler,
MCMCThreads(),
1,
10;
initial_params=fill([missing, -1], 10),
progress=false,
trace_type=SimpleVarInfo,
)
for c in chains_svi
@test !ismissing(c[1][@varname(s)])
@test c[1][@varname(m)] == -1
end

# specify `initial_params=nothing`
Random.seed!(1234)
chain1 = sample(model, sampler, 1; progress=false)
chain1_svi = sample(model, sampler, 1; progress=false, trace_type=SimpleVarInfo)
Random.seed!(1234)
chain2 = sample(model, sampler, 1; initial_params=nothing, progress=false)
chain2_svi = sample(
model,
sampler,
1;
initial_params=nothing,
progress=false,
trace_type=SimpleVarInfo,
)
@test chain1[1].metadata.m.vals == chain2[1].metadata.m.vals
@test chain1[1].metadata.s.vals == chain2[1].metadata.s.vals
@test chain1_svi[1][@varname(m)] == chain2_svi[1][@varname(m)]
@test chain1_svi[1][@varname(s)] == chain2_svi[1][@varname(s)]

# parallel sampling
Random.seed!(1234)
chains1 = sample(model, sampler, MCMCThreads(), 1, 10; progress=false)
chains1_svi = sample(
model,
sampler,
MCMCThreads(),
1,
10;
progress=false,
trace_type=SimpleVarInfo,
)
Random.seed!(1234)
chains2 = sample(
model, sampler, MCMCThreads(), 1, 10; initial_params=nothing, progress=false
)
chains2_svi = sample(
model,
sampler,
MCMCThreads(),
1,
10;
initial_params=nothing,
progress=false,
trace_type=SimpleVarInfo,
)
for (c1, c2) in zip(chains1, chains2)
@test c1[1].metadata.m.vals == c2[1].metadata.m.vals
@test c1[1].metadata.s.vals == c2[1].metadata.s.vals
end
for (c1, c2) in zip(chains1_svi, chains2_svi)
@test c1[1][@varname(m)] == c2[1][@varname(m)]
@test c1[1][@varname(s)] == c2[1][@varname(s)]
end
end
end
end