6 changes: 6 additions & 0 deletions HISTORY.md
@@ -1,3 +1,9 @@
# 0.39.8

MCMCChains.jl doesn't understand vector- or matrix-valued variables, and in Turing we split up such values into their individual components.
This patch carries out some internal refactoring to avoid splitting up VarNames until absolutely necessary.
There are no user-facing changes in this patch.

# 0.39.7

Update compatibility to AdvancedPS 0.7 and Libtask 0.9.
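To make the splitting behaviour described in the HISTORY.md entry above concrete, here is a minimal sketch (not part of the diff; it assumes only that DynamicPPL is loaded, and uses `@varname` and `varname_and_value_leaves` as they are used elsewhere in this patch) of how a vector-valued variable is broken into scalar leaves before being handed to MCMCChains.jl:

using DynamicPPL

# A whole vector-valued variable and its value, as the model sees it.
vn = @varname(x)
val = [1.0, 2.0, 3.0]

# `varname_and_value_leaves` yields one (VarName, value) pair per scalar leaf,
# roughly (x[1], 1.0), (x[2], 2.0), (x[3], 3.0) for this input.
for (leaf_vn, leaf_val) in DynamicPPL.varname_and_value_leaves(vn, val)
    println(leaf_vn, " => ", leaf_val)
end

MCMCChains.jl can then treat each leaf as its own column; the refactoring in this patch simply delays this flattening until the point where a chain (or a NamedArray of optimisation results) is actually constructed.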
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.39.7"
version = "0.39.8"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
4 changes: 3 additions & 1 deletion ext/TuringOptimExt.jl
@@ -186,7 +186,9 @@ function _optimize(
logdensity_optimum = Optimisation.OptimLogDensity(
f.ldf.model, vi_optimum, f.ldf.context
)
vns_vals_iter = Turing.Inference.getparams(f.ldf.model, vi_optimum)
vals_dict = Turing.Inference.getparams(f.ldf.model, vi_optimum)
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals_dict), values(vals_dict))
vns_vals_iter = mapreduce(collect, vcat, iters)
varnames = map(Symbol ∘ first, vns_vals_iter)
vals = map(last, vns_vals_iter)
vmat = NamedArrays.NamedArray(vals, varnames)
27 changes: 15 additions & 12 deletions src/mcmc/Inference.jl
@@ -176,13 +176,7 @@ function getparams(model::DynamicPPL.Model, vi::DynamicPPL.VarInfo)
# this means that the code below will work for both linked and invlinked `vi`.
# Ref: https://github.com/TuringLang/Turing.jl/issues/2195
# NOTE: We need to `deepcopy` here to avoid modifying the original `vi`.
vals = DynamicPPL.values_as_in_model(model, true, deepcopy(vi))

# Obtain an iterator over the flattened parameter names and values.
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))

# Materialize the iterators and concatenate.
return mapreduce(collect, vcat, iters)
return DynamicPPL.values_as_in_model(model, true, deepcopy(vi))
end
function getparams(
model::DynamicPPL.Model, untyped_vi::DynamicPPL.VarInfo{<:DynamicPPL.Metadata}
@@ -193,14 +187,25 @@ function getparams(
return getparams(model, DynamicPPL.typed_varinfo(untyped_vi))
end
function getparams(::DynamicPPL.Model, ::DynamicPPL.VarInfo{NamedTuple{(),Tuple{}}})
return float(Real)[]
return Dict{VarName,Any}()
end

function _params_to_array(model::DynamicPPL.Model, ts::Vector)
names_set = OrderedSet{VarName}()
# Extract the parameter names and values from each transition.
dicts = map(ts) do t
nms_and_vs = getparams(model, t)
# In general getparams returns a dict of VarName => values. We need to also
# split it up into constituent elements using
# `DynamicPPL.varname_and_value_leaves` because otherwise MCMCChains.jl
# won't understand it.
vals = getparams(model, t)
nms_and_vs = if isempty(vals)
Tuple{VarName,Any}[]
else
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
mapreduce(collect, vcat, iters)
end

nms = map(first, nms_and_vs)
vs = map(last, nms_and_vs)
for nm in nms
@@ -210,9 +215,7 @@ function _params_to_array(model::DynamicPPL.Model, ts::Vector)
return OrderedDict(zip(nms, vs))
end
names = collect(names_set)
vals = [
get(dicts[i], key, missing) for i in eachindex(dicts), (j, key) in enumerate(names)
]
vals = [get(dicts[i], key, missing) for i in eachindex(dicts), key in names]

return names, vals
end
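The same leaf-flattening pattern now appears at every call site that consumes `getparams` (in `_params_to_array` above, and in the optimisation code below). A small self-contained sketch of that pattern, using a hypothetical dict of the kind the new `getparams` returns (the variable names and values are made up for illustration, and OrderedCollections stands in for whatever ordered container the real code uses):

using DynamicPPL
using OrderedCollections

# Hypothetical output of the new `getparams`: whole values keyed by VarName.
vals = OrderedDict{VarName,Any}(@varname(m) => 0.5, @varname(x) => [1.0, 2.0])

# Split every value into scalar leaves, then concatenate into one flat vector
# of (VarName, value) pairs -- the shape MCMCChains.jl and NamedArrays expect.
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
vns_and_vals = mapreduce(collect, vcat, iters)
# vns_and_vals is now roughly [(m, 0.5), (x[1], 1.0), (x[2], 2.0)]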
11 changes: 7 additions & 4 deletions src/optimisation/Optimisation.jl
@@ -366,9 +366,10 @@ function Base.get(m::ModeResult, var_symbols::AbstractVector{Symbol})
# Get all the variable names in the model. This is the same as the list of keys in
# m.values, but they are more convenient to filter when they are VarNames rather than
# Symbols.
varnames = collect(
map(first, Turing.Inference.getparams(log_density.model, log_density.varinfo))
)
vals_dict = Turing.Inference.getparams(log_density.model, log_density.varinfo)
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals_dict), values(vals_dict))
vns_and_vals = mapreduce(collect, vcat, iters)
varnames = collect(map(first, vns_and_vals))
# For each symbol s in var_symbols, pick all the values from m.values for which the
# variable name has that symbol.
et = eltype(m.values)
@@ -396,7 +397,9 @@ parameter space in case the optimization was done in a transformed space.
function ModeResult(log_density::OptimLogDensity, solution::SciMLBase.OptimizationSolution)
varinfo_new = DynamicPPL.unflatten(log_density.ldf.varinfo, solution.u)
# `getparams` performs invlinking if needed
vns_vals_iter = Turing.Inference.getparams(log_density.ldf.model, varinfo_new)
vals = Turing.Inference.getparams(log_density.ldf.model, varinfo_new)
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
vns_vals_iter = mapreduce(collect, vcat, iters)
syms = map(Symbol ∘ first, vns_vals_iter)
vals = map(last, vns_vals_iter)
return ModeResult(
8 changes: 5 additions & 3 deletions test/mcmc/Inference.jl
@@ -630,16 +630,18 @@ using Turing
return x ~ Normal()
end
fvi = DynamicPPL.VarInfo(f())
@test only(Turing.Inference.getparams(f(), fvi)) == (@varname(x), fvi[@varname(x)])
fparams = Turing.Inference.getparams(f(), fvi)
@test fparams[@varname(x)] == fvi[@varname(x)]
@test length(fparams) == 1

@model function g()
x ~ Normal()
return y ~ Poisson()
end
gvi = DynamicPPL.VarInfo(g())
gparams = Turing.Inference.getparams(g(), gvi)
@test gparams[1] == (@varname(x), gvi[@varname(x)])
@test gparams[2] == (@varname(y), gvi[@varname(y)])
@test gparams[@varname(x)] == gvi[@varname(x)]
@test gparams[@varname(y)] == gvi[@varname(y)]
@test length(gparams) == 2
end
