diff --git a/HISTORY.md b/HISTORY.md
index 70a5e5efef..3de9ac8244 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -1,3 +1,9 @@
+# 0.39.8
+
+MCMCChains.jl doesn't understand vector- or matrix-valued variables, and in Turing we split up such values into their individual components.
+This patch carries out some internal refactoring to avoid splitting up VarNames until absolutely necessary.
+There are no user-facing changes in this patch.
+
 # 0.39.7
 
 Update compatibility to AdvancedPS 0.7 and Libtask 0.9.
diff --git a/Project.toml b/Project.toml
index 1d20c21e7f..8fc00d4f40 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "Turing"
 uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
-version = "0.39.7"
+version = "0.39.8"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
diff --git a/ext/TuringOptimExt.jl b/ext/TuringOptimExt.jl
index d6c253e2a2..ad8fdad44b 100644
--- a/ext/TuringOptimExt.jl
+++ b/ext/TuringOptimExt.jl
@@ -186,7 +186,9 @@ function _optimize(
     logdensity_optimum = Optimisation.OptimLogDensity(
         f.ldf.model, vi_optimum, f.ldf.context
     )
-    vns_vals_iter = Turing.Inference.getparams(f.ldf.model, vi_optimum)
+    vals_dict = Turing.Inference.getparams(f.ldf.model, vi_optimum)
+    iters = map(DynamicPPL.varname_and_value_leaves, keys(vals_dict), values(vals_dict))
+    vns_vals_iter = mapreduce(collect, vcat, iters)
     varnames = map(Symbol ∘ first, vns_vals_iter)
     vals = map(last, vns_vals_iter)
     vmat = NamedArrays.NamedArray(vals, varnames)
diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl
index 0370e619a3..24ff4c86d9 100644
--- a/src/mcmc/Inference.jl
+++ b/src/mcmc/Inference.jl
@@ -176,13 +176,7 @@ function getparams(model::DynamicPPL.Model, vi::DynamicPPL.VarInfo)
     # this means that the code below will work both of linked and invlinked `vi`.
     # Ref: https://github.com/TuringLang/Turing.jl/issues/2195
     # NOTE: We need to `deepcopy` here to avoid modifying the original `vi`.
-    vals = DynamicPPL.values_as_in_model(model, true, deepcopy(vi))
-
-    # Obtain an iterator over the flattened parameter names and values.
-    iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
-
-    # Materialize the iterators and concatenate.
-    return mapreduce(collect, vcat, iters)
+    return DynamicPPL.values_as_in_model(model, true, deepcopy(vi))
 end
 function getparams(
     model::DynamicPPL.Model, untyped_vi::DynamicPPL.VarInfo{<:DynamicPPL.Metadata}
@@ -193,14 +187,25 @@ function getparams(
     return getparams(model, DynamicPPL.typed_varinfo(untyped_vi))
 end
 function getparams(::DynamicPPL.Model, ::DynamicPPL.VarInfo{NamedTuple{(),Tuple{}}})
-    return float(Real)[]
+    return Dict{VarName,Any}()
 end
 
 function _params_to_array(model::DynamicPPL.Model, ts::Vector)
     names_set = OrderedSet{VarName}()
     # Extract the parameter names and values from each transition.
     dicts = map(ts) do t
-        nms_and_vs = getparams(model, t)
+        # In general getparams returns a dict of VarName => values. We need to also
+        # split it up into constituent elements using
+        # `DynamicPPL.varname_and_value_leaves` because otherwise MCMCChains.jl
+        # won't understand it.
+        vals = getparams(model, t)
+        nms_and_vs = if isempty(vals)
+            Tuple{VarName,Any}[]
+        else
+            iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
+            mapreduce(collect, vcat, iters)
+        end
+
         nms = map(first, nms_and_vs)
         vs = map(last, nms_and_vs)
         for nm in nms
@@ -210,9 +215,7 @@ function _params_to_array(model::DynamicPPL.Model, ts::Vector)
         return OrderedDict(zip(nms, vs))
     end
     names = collect(names_set)
-    vals = [
-        get(dicts[i], key, missing) for i in eachindex(dicts), (j, key) in enumerate(names)
-    ]
+    vals = [get(dicts[i], key, missing) for i in eachindex(dicts), key in names]
 
     return names, vals
 end
diff --git a/src/optimisation/Optimisation.jl b/src/optimisation/Optimisation.jl
index ddcc27b876..fedc2510d4 100644
--- a/src/optimisation/Optimisation.jl
+++ b/src/optimisation/Optimisation.jl
@@ -366,9 +366,10 @@ function Base.get(m::ModeResult, var_symbols::AbstractVector{Symbol})
     # Get all the variable names in the model. This is the same as the list of keys in
     # m.values, but they are more convenient to filter when they are VarNames rather than
     # Symbols.
-    varnames = collect(
-        map(first, Turing.Inference.getparams(log_density.model, log_density.varinfo))
-    )
+    vals_dict = Turing.Inference.getparams(log_density.model, log_density.varinfo)
+    iters = map(DynamicPPL.varname_and_value_leaves, keys(vals_dict), values(vals_dict))
+    vns_and_vals = mapreduce(collect, vcat, iters)
+    varnames = collect(map(first, vns_and_vals))
     # For each symbol s in var_symbols, pick all the values from m.values for which the
     # variable name has that symbol.
     et = eltype(m.values)
@@ -396,7 +397,9 @@ parameter space in case the optimization was done in a transformed space.
 function ModeResult(log_density::OptimLogDensity, solution::SciMLBase.OptimizationSolution)
     varinfo_new = DynamicPPL.unflatten(log_density.ldf.varinfo, solution.u)
     # `getparams` performs invlinking if needed
-    vns_vals_iter = Turing.Inference.getparams(log_density.ldf.model, varinfo_new)
+    vals = Turing.Inference.getparams(log_density.ldf.model, varinfo_new)
+    iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
+    vns_vals_iter = mapreduce(collect, vcat, iters)
     syms = map(Symbol ∘ first, vns_vals_iter)
     vals = map(last, vns_vals_iter)
     return ModeResult(
diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl
index cf528ce517..c3ac571cb7 100644
--- a/test/mcmc/Inference.jl
+++ b/test/mcmc/Inference.jl
@@ -630,7 +630,9 @@ using Turing
             return x ~ Normal()
         end
         fvi = DynamicPPL.VarInfo(f())
-        @test only(Turing.Inference.getparams(f(), fvi)) == (@varname(x), fvi[@varname(x)])
+        fparams = Turing.Inference.getparams(f(), fvi)
+        @test fparams[@varname(x)] == fvi[@varname(x)]
+        @test length(fparams) == 1
 
         @model function g()
             x ~ Normal()
@@ -638,8 +640,8 @@ using Turing
         end
         gvi = DynamicPPL.VarInfo(g())
         gparams = Turing.Inference.getparams(g(), gvi)
-        @test gparams[1] == (@varname(x), gvi[@varname(x)])
-        @test gparams[2] == (@varname(y), gvi[@varname(y)])
+        @test gparams[@varname(x)] == gvi[@varname(x)]
+        @test gparams[@varname(y)] == gvi[@varname(y)]
         @test length(gparams) == 2
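
To illustrate the pattern the call sites above now follow, here is a minimal sketch (not part of the diff): `getparams` keeps vector-valued variables intact as a VarName-keyed dictionary, and `DynamicPPL.varname_and_value_leaves` is applied only where per-component pairs are needed (e.g. for MCMCChains.jl). The `demo` model and the values shown in comments are hypothetical.

```julia
using Turing, DynamicPPL, LinearAlgebra

# Hypothetical model with one vector-valued and one scalar variable.
@model function demo()
    x ~ MvNormal(zeros(2), I)
    y ~ Normal()
end

model = demo()
vi = DynamicPPL.VarInfo(model)

# After this patch, getparams returns a dict-like VarName => value mapping,
# keeping x as a whole vector rather than splitting it into x[1], x[2].
vals = Turing.Inference.getparams(model, vi)
# e.g. x => [0.3, -1.2], y => 0.7

# Splitting into per-component leaves happens only at call sites that need a
# flat list, using the same idiom as in the diff above.
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
vns_and_vals = mapreduce(collect, vcat, iters)
# e.g. [(x[1], 0.3), (x[2], -1.2), (y, 0.7)]
```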