Faster evaluation: SimpleVarInfo #267

Merged
merged 28 commits into from
Aug 14, 2021
Commits
208b62c
Merge branch 'tor/immutable-varinfo-support' into tor/simple-varinfo-v2
torfjelde Jun 30, 2021
a8e55bd
updated SimpleVarInfo impl
torfjelde Jun 30, 2021
8ea80d7
Merge branch 'tor/immutable-varinfo-support' into tor/simple-varinfo-v2
torfjelde Jun 30, 2021
d317bd8
Merge branch 'tor/immutable-varinfo-support' into tor/simple-varinfo-v2
torfjelde Jun 30, 2021
a88f8ea
Merge branch 'tor/immutable-varinfo-support' into tor/simple-varinfo-v2
torfjelde Jun 30, 2021
bfd7c78
added eltype impl for SimpleVarInfo
torfjelde Jul 2, 2021
acb15eb
formatting
torfjelde Jul 2, 2021
4828aab
fixed eltype for SimpleVarInfo
torfjelde Jul 6, 2021
b56024e
Merge branch 'tor/immutable-varinfo-support' into tor/simple-varinfo-v2
torfjelde Jul 9, 2021
e4f0ad2
formatting
torfjelde Jul 9, 2021
ccfd112
initial work on allowing sampling using SimpleVarInfo
torfjelde Jul 9, 2021
d660433
formatting
torfjelde Jul 9, 2021
c925b07
Merge branch 'master' into tor/simple-varinfo-v2
torfjelde Jul 16, 2021
3ec72c6
Merge branch 'tor/simple-varinfo-v2' of github.com:TuringLang/Dynamic…
torfjelde Jul 16, 2021
90cf754
add constructor for SimpleVarInfo using model
torfjelde Jul 16, 2021
0ab9d8b
improved leftover to_namedtuple_expr, fixing a bug when used with Zygote
torfjelde Jul 16, 2021
42ad552
bumped patch version
torfjelde Jul 16, 2021
975184d
Merge branch 'tor/allargs-construction-improvement' into tor/simple-v…
torfjelde Jul 16, 2021
a0cd0c4
Merge branch 'master' into tor/simple-varinfo-v2
torfjelde Jul 19, 2021
744a032
Merge branch 'tor/immutable-varinfo-support' into tor/simple-varinfo-v2
torfjelde Jul 19, 2021
76daca6
Merge branch 'tor/immutable-varinfo-support' into tor/simple-varinfo-v2
torfjelde Jul 20, 2021
6f947f7
Merge branch 'tor/immutable-varinfo-support' into tor/simple-varinfo-v2
torfjelde Jul 22, 2021
4076f63
Merge branch 'tor/immutable-varinfo-support' into tor/simple-varinfo-v2
torfjelde Jul 23, 2021
d0a08f6
fixed some issues and added support for usage of Dict in SimpleVarInfo
torfjelde Aug 5, 2021
4002318
Merge branch 'tor/immutable-varinfo-support' into tor/simple-varinfo-v2
torfjelde Aug 5, 2021
ff75ddc
added docstring and improved indexing behvaior for SimpleVarInfo
torfjelde Aug 5, 2021
d29dd8f
formatting
torfjelde Aug 5, 2021
a72594f
dont allow sampling with indexing when using SimpleVarInfo with Named…
torfjelde Aug 5, 2021
1 change: 1 addition & 0 deletions Project.toml
@@ -11,6 +11,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
2 changes: 2 additions & 0 deletions src/DynamicPPL.jl
@@ -34,6 +34,7 @@ export AbstractVarInfo,
VarInfo,
UntypedVarInfo,
TypedVarInfo,
SimpleVarInfo,
push!!,
empty!!,
getlogp,
@@ -135,6 +136,7 @@ include("varname.jl")
include("distribution_wrappers.jl")
include("contexts.jl")
include("varinfo.jl")
include("simple_varinfo.jl")
include("threadsafe.jl")
include("context_implementations.jl")
include("compiler.jl")
278 changes: 278 additions & 0 deletions src/simple_varinfo.jl
@@ -0,0 +1,278 @@
using Setfield

"""
    SimpleVarInfo{NT,T} <: AbstractVarInfo

A simple wrapper of the parameters with a `logp` field for
accumulation of the logdensity.

Currently only implemented for `NT<:NamedTuple` and `NT<:Dict`.

# Notes
The major differences between this and `TypedVarInfo` are:
1. `SimpleVarInfo` does not require linearization.
2. `SimpleVarInfo` can use more efficient bijectors.
3. `SimpleVarInfo` is only type-stable if `NT<:NamedTuple` and either
   a) no indexing is used in tilde-statements, or
   b) the values have been specified with the correct shapes.

# Examples
```jldoctest; setup=:(using Distributions)
julia> using StableRNGs

julia> @model function demo()
           m ~ Normal()
           x = Vector{Float64}(undef, 2)
           for i in eachindex(x)
               x[i] ~ Normal()
           end
           return x
       end
demo (generic function with 1 method)

julia> m = demo();

julia> rng = StableRNG(42);

julia> ### Sampling ###
       ctx = SamplingContext(rng, SampleFromPrior(), DefaultContext());

julia> # In the `NamedTuple` version we need to provide the place-holder values for
       # the variables which are using "containers", e.g. `Array`.
       # In this case, this means that we need to specify `x` but not `m`.
       _, vi = DynamicPPL.evaluate(m, SimpleVarInfo((x = ones(2), )), ctx); vi
SimpleVarInfo{NamedTuple{(:x, :m), Tuple{Vector{Float64}, Float64}}, Float64}((x = [1.6642061055583879, 1.796319600944139], m = -0.16796295277202952), -5.769094411622931)

julia> # (✓) Vroom, vroom! FAST!!!
       DynamicPPL.getval(vi, @varname(x[1]))
1.6642061055583879

julia> # We can also access arbitrary varnames pointing to `x`, e.g.
       DynamicPPL.getval(vi, @varname(x))
2-element Vector{Float64}:
 1.6642061055583879
 1.796319600944139

julia> DynamicPPL.getval(vi, @varname(x[1:2]))
2-element view(::Vector{Float64}, 1:2) with eltype Float64:
 1.6642061055583879
 1.796319600944139

julia> # (×) If we don't provide the container...
       _, vi = DynamicPPL.evaluate(m, SimpleVarInfo(), ctx); vi
ERROR: type NamedTuple has no field x
[...]

julia> # If one does not know the varnames, we can use a `Dict` instead.
       _, vi = DynamicPPL.evaluate(m, SimpleVarInfo{Float64}(Dict()), ctx); vi
SimpleVarInfo{Dict{Any, Any}, Float64}(Dict{Any, Any}(x[1] => 1.192696983568277, x[2] => 0.4914514300738121, m => 0.25572200616753643), -3.6215377732004237)

julia> # (✓) Sort of fast, but only possible at runtime.
       DynamicPPL.getval(vi, @varname(x[1]))
1.192696983568277

julia> # In addition, we can only access varnames as they appear in the model!
       DynamicPPL.getval(vi, @varname(x))
ERROR: KeyError: key x not found
[...]

julia> DynamicPPL.getval(vi, @varname(x[1:2]))
ERROR: KeyError: key x[1:2] not found
[...]
```
"""
struct SimpleVarInfo{NT,T} <: AbstractVarInfo
    θ::NT
    logp::T
end

SimpleVarInfo{T}(θ) where {T<:Real} = SimpleVarInfo{typeof(θ),T}(θ, zero(T))
SimpleVarInfo(θ) = SimpleVarInfo{eltype(first(θ))}(θ)
SimpleVarInfo{T}() where {T<:Real} = SimpleVarInfo{T}(NamedTuple())
SimpleVarInfo() = SimpleVarInfo{Float64}()

# Constructor from `Model`.
SimpleVarInfo(model::Model, args...) = SimpleVarInfo{Float64}(model, args...)
function SimpleVarInfo{T}(model::Model, args...) where {T<:Real}
    _, svi = DynamicPPL.evaluate(model, SimpleVarInfo{T}(), args...)
    return svi
end

# Constructor from `VarInfo`.
function SimpleVarInfo(vi::TypedVarInfo, ::Type{D}=NamedTuple; kwargs...) where {D}
    return SimpleVarInfo{eltype(getlogp(vi))}(vi, D; kwargs...)
end
function SimpleVarInfo{T}(
    vi::VarInfo{<:NamedTuple{names}}, ::Type{D}
) where {T<:Real,names,D}
    values = values_as(vi, D)
    return SimpleVarInfo{T}(values)
end

getlogp(vi::SimpleVarInfo) = vi.logp
setlogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.θ, logp)
acclogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.θ, getlogp(vi) + logp)

function setlogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp)
    vi.logp[] = logp
    return vi
end

function acclogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp)
    vi.logp[] += logp
    return vi
end

function _getvalue(nt::NamedTuple, ::Val{sym}, inds=()) where {sym}
    # Use `getproperty` instead of `getfield`
    value = getproperty(nt, sym)
    # Note that this will return a `view`, even if the resulting value is 0-dim.
    # This makes it possible to call `setindex!` on the result later to update
    # in place even in the case where we are retrieving a single element, e.g. `x[1]`.
    return _getindex(value, inds)
end

# `NamedTuple`
function getval(vi::SimpleVarInfo{<:NamedTuple}, vn::VarName{sym}) where {sym}
    return maybe_unwrap_view(_getvalue(vi.θ, Val{sym}(), vn.indexing))
end

# `Dict`
function getval(vi::SimpleVarInfo{<:Dict}, vn::VarName)
    return vi.θ[vn]
end

# `SimpleVarInfo` doesn't necessarily vectorize, so we can have arrays other than
# just `Vector`.
getval(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) = map(vn -> getval(vi, vn), vns)
# To disambiguate.
getval(vi::SimpleVarInfo, vns::Vector{<:VarName}) = map(vn -> getval(vi, vn), vns)

haskey(vi::SimpleVarInfo, vn) = haskey(vi.θ, getsym(vn))

istrans(::SimpleVarInfo, vn::VarName) = false

getindex(vi::SimpleVarInfo, spl::SampleFromPrior) = vi.θ
getindex(vi::SimpleVarInfo, spl::SampleFromUniform) = vi.θ
# TODO: Should we do better?
getindex(vi::SimpleVarInfo, spl::Sampler) = vi.θ
getindex(vi::SimpleVarInfo, vn::VarName) = getval(vi, vn)
getindex(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) = getval(vi, vns)
# HACK: Need to disambiguate.
getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}) = getval(vi, vns)

# Necessary for `matchingvalue` to work properly.
function Base.eltype(
    vi::SimpleVarInfo{<:Any,T}, spl::Union{AbstractSampler,SampleFromPrior}
) where {T}
    return T
end

# `NamedTuple`
function push!!(
    vi::SimpleVarInfo{<:NamedTuple},
    vn::VarName{sym,Tuple{}},
    value,
    dist::Distribution,
    gidset::Set{Selector},
) where {sym}
    @set vi.θ = merge(vi.θ, NamedTuple{(sym,)}((value,)))
end
function push!!(
    vi::SimpleVarInfo{<:NamedTuple},
    vn::VarName{sym},
    value,
    dist::Distribution,
    gidset::Set{Selector},
) where {sym}
    # We update in place.
    # We need a view into the array, hence we call `_getvalue` directly
    # rather than `getval`.
    current = _getvalue(vi.θ, Val{sym}(), vn.indexing)
    current .= value
    return vi
end

# `Dict`
function push!!(
    vi::SimpleVarInfo{<:Dict}, vn::VarName, r, dist::Distribution, gidset::Set{Selector}
)
    vi.θ[vn] = r
    return vi
end

# Context implementations
function tilde_assume!!(context, right, vn, inds, vi::SimpleVarInfo)
    value, logp, vi_new = tilde_assume(context, right, vn, inds, vi)

Inline review comment from the PR author:

Notice how here we also take the vi_new.

Ideally this is what we should be doing overall if we're going to support immutable varinfos. But if we force this, we'll end up breaking a lot of downstream samplers since it requires changing assume statements to also return vi.

As an intermediate step, it might be worth just overloading tilde_assume!! as I have done here + assume as I have done below. This also brings up another annoyance though: vi should really be at the beginning of the arguments of all the tilde-statements (and ideally assume too, but we can delay this) to minimize method ambiguities but allow us to have some special behavior for the different impls for AbstractVarInfo.
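
To make the pattern concrete, here is a minimal, self-contained sketch of threading an updated varinfo through an assume-like call. It does not use DynamicPPL: `ToyVarInfo`, `toy_acclogp`, `toy_assume`, `toy_tilde_assume` and `std_normal_logpdf` are hypothetical names made up for illustration, and the standard-normal density is hard-coded to avoid a dependency on Distributions.jl.

```julia
# Hypothetical stand-in for an immutable varinfo; not DynamicPPL's actual type.
struct ToyVarInfo{NT<:NamedTuple,T<:Real}
    θ::NT
    logp::T
end

# Non-mutating accumulation: returns a new instance instead of updating in place.
toy_acclogp(vi::ToyVarInfo, logp) = ToyVarInfo(vi.θ, vi.logp + logp)

# Standard-normal log-density, written out by hand.
std_normal_logpdf(x) = -(x^2 + log(2π)) / 2

# `assume`-like step that returns the (possibly new) varinfo as a third value.
function toy_assume(vi::ToyVarInfo, ::Val{sym}, value) where {sym}
    vi_new = ToyVarInfo(merge(vi.θ, NamedTuple{(sym,)}((value,))), vi.logp)
    return value, std_normal_logpdf(value), vi_new
end

# `tilde_assume!!`-like wrapper: it accumulates into `vi_new`, not the old `vi`.
function toy_tilde_assume(vi::ToyVarInfo, v::Val, value)
    value, logp, vi_new = toy_assume(vi, v, value)
    return value, toy_acclogp(vi_new, logp)
end

vi0 = ToyVarInfo(NamedTuple(), 0.0)
_, vi1 = toy_tilde_assume(vi0, Val(:m), 0.0)
```

In this sketch `vi0` is left untouched while `vi1` carries both the new value and the accumulated log-density, which is why the caller has to keep the returned varinfo rather than the one it passed in.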

    return value, acclogp!!(vi_new, logp)
end

function assume(dist::Distribution, vn::VarName, vi::SimpleVarInfo)
    left = vi[vn]
    return left, Distributions.loglikelihood(dist, left), vi
end

function assume(
    rng::Random.AbstractRNG,
    sampler::SampleFromPrior,
    dist::Distribution,
    vn::VarName,
    vi::SimpleVarInfo,
)
    value = init(rng, dist, sampler)
    vi = push!!(vi, vn, value, dist, sampler)
    vi = settrans!!(vi, false, vn)
    return value, Distributions.loglikelihood(dist, value), vi
end

# function dot_tilde_assume!!(context, right, left, vn, inds, vi::SimpleVarInfo)
#     throw(MethodError(dot_tilde_assume!!, (context, right, left, vn, inds, vi)))
# end

function dot_tilde_assume!!(context, right, left, vn, inds, vi::SimpleVarInfo)
    value, logp, vi_new = dot_tilde_assume(context, right, left, vn, inds, vi)
    # Mutation of `value` no longer occurs in main body, so we do it here.
    left .= value
    return value, acclogp!!(vi_new, logp)
end

function dot_assume(
    dist::MultivariateDistribution,
    var::AbstractMatrix,
    vns::AbstractVector{<:VarName},
    vi::SimpleVarInfo,
)
    @assert length(dist) == size(var, 1)
    # NOTE: We cannot work with `var` here because we might have a model of the form
    #
    #     m = Vector{Float64}(undef, n)
    #     m .~ Normal()
    #
    # in which case `var` will have `undef` elements, even if `m` is present in `vi`.
    value = vi[vns]
    # `zip` yields tuples, so the `do`-block argument must be destructured.
    lp = sum(zip(vns, eachcol(value))) do (vn, val)
        return Distributions.logpdf(dist, val)
    end
    return value, lp, vi
end

function dot_assume(
    dists::Union{Distribution,AbstractArray{<:Distribution}},
    var::AbstractArray,
    vns::AbstractArray{<:VarName},
    vi::SimpleVarInfo{<:NamedTuple},
)
    # NOTE: We cannot work with `var` here because we might have a model of the form
    #
    #     m = Vector{Float64}(undef, n)
    #     m .~ Normal()
    #
    # in which case `var` will have `undef` elements, even if `m` is present in `vi`.
    value = vi[vns]
    lp = sum(Distributions.logpdf.(dists, value))
    return value, lp, vi
end

# HACK: Allows us to re-use the implementation of `dot_tilde`, etc. for literals.
increment_num_produce!(::SimpleVarInfo) = nothing
settrans!!(vi::SimpleVarInfo, trans::Bool, vn::VarName) = vi
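
The docstring and the indexed `push!!` method above describe a trade-off between the two storage backends. The following sketch illustrates it with plain Julia containers only. It is an illustration under assumptions, not DynamicPPL code: string keys stand in for `VarName` keys, and the variable names are made up.

```julia
# NamedTuple-backed storage: `x` is stored as a whole container.
nt = (x = [1.0, 2.0], m = 0.5)

# Any index into the container can be served, here as a view.
x_all = view(nt.x, 1:2)

# A 0-dimensional view of a single element still supports in-place updates,
# which is the idea behind updating through a view in the indexed `push!!`.
x1 = view(nt.x, 1)
x1 .= 3.0  # writes through to `nt.x[1]`

# Dict-backed storage: one entry per name exactly as it appeared in the model.
# Plain strings stand in for `VarName` keys here (an assumption for brevity).
d = Dict("x[1]" => 1.0, "x[2]" => 2.0, "m" => 0.5)

has_element = haskey(d, "x[1]")  # stored under this exact key
has_container = haskey(d, "x")   # `x` itself was never stored
```

The `NamedTuple` stays immutable throughout; only the array it holds is mutated, which is why the container has to be supplied up front with the right shape.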
23 changes: 23 additions & 0 deletions src/varinfo.jl
@@ -1493,3 +1493,26 @@ function _setval_and_resample_kernel!(vi::AbstractVarInfo, vn::VarName, values,

return indices
end

"""
    values_as(vi::TypedVarInfo, ::Type{NamedTuple})
    values_as(vi::TypedVarInfo, ::Type{Dict})

Return values in `vi` as the specified type, e.g. a `NamedTuple` is returned if
`NamedTuple` is specified.
"""
function values_as(vi::VarInfo{<:NamedTuple{names}}, ::Type{NamedTuple}) where {names}
    iter = Iterators.flatten(values_from_metadata(getfield(vi.metadata, n)) for n in names)
    return NamedTuple(map(p -> Symbol(p.first) => p.second, iter))
end

function values_as(vi::VarInfo{<:NamedTuple{names}}, ::Type{Dict}) where {names}
    iter = Iterators.flatten(values_from_metadata(getfield(vi.metadata, n)) for n in names)
    return Dict(iter)
end

function values_from_metadata(md::Metadata)
    return (
        vn => reconstruct(md.dists[md.idcs[vn]], md.vals[md.ranges[md.idcs[vn]]]) for
        vn in md.vns
    )
end