Faster evaluation: SimpleVarInfo
#267
Merged
torfjelde
merged 28 commits into
tor/immutable-varinfo-support
from
tor/simple-varinfo-v2
Aug 14, 2021
208b62c
Merge branch 'tor/immutable-varinfo-support' into tor/simple-varinfo-v2
torfjelde a8e55bd
updated SimpleVarInfo impl
torfjelde 8ea80d7
Merge branch 'tor/immutable-varinfo-support' into tor/simple-varinfo-v2
torfjelde d317bd8
Merge branch 'tor/immutable-varinfo-support' into tor/simple-varinfo-v2
torfjelde a88f8ea
Merge branch 'tor/immutable-varinfo-support' into tor/simple-varinfo-v2
torfjelde bfd7c78
added eltype impl for SimpleVarInfo
torfjelde acb15eb
formatting
torfjelde 4828aab
fixed eltype for SimpleVarInfo
torfjelde b56024e
Merge branch 'tor/immutable-varinfo-support' into tor/simple-varinfo-v2
torfjelde e4f0ad2
formatting
torfjelde ccfd112
initial work on allowing sampling using SimpleVarInfo
torfjelde d660433
formatting
torfjelde c925b07
Merge branch 'master' into tor/simple-varinfo-v2
torfjelde 3ec72c6
Merge branch 'tor/simple-varinfo-v2' of github.com:TuringLang/Dynamic…
torfjelde 90cf754
add constructor for SimpleVarInfo using model
torfjelde 0ab9d8b
improved leftover to_namedtuple_expr, fixing a bug when used with Zygote
torfjelde 42ad552
bumped patch version
torfjelde 975184d
Merge branch 'tor/allargs-construction-improvement' into tor/simple-v…
torfjelde a0cd0c4
Merge branch 'master' into tor/simple-varinfo-v2
torfjelde 744a032
Merge branch 'tor/immutable-varinfo-support' into tor/simple-varinfo-v2
torfjelde 76daca6
Merge branch 'tor/immutable-varinfo-support' into tor/simple-varinfo-v2
torfjelde 6f947f7
Merge branch 'tor/immutable-varinfo-support' into tor/simple-varinfo-v2
torfjelde 4076f63
Merge branch 'tor/immutable-varinfo-support' into tor/simple-varinfo-v2
torfjelde d0a08f6
fixed some issues and added support for usage of Dict in SimpleVarInfo
torfjelde 4002318
Merge branch 'tor/immutable-varinfo-support' into tor/simple-varinfo-v2
torfjelde ff75ddc
added docstring and improved indexing behvaior for SimpleVarInfo
torfjelde d29dd8f
formatting
torfjelde a72594f
dont allow sampling with indexing when using SimpleVarInfo with Named…
torfjelde
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,278 @@ | ||
using Setfield | ||
|
||
""" | ||
SimpleVarInfo{NT,T} <: AbstractVarInfo | ||
|
||
A simple wrapper of the parameters with a `logp` field for | ||
accumulation of the logdensity. | ||
|
||
Currently only implemented for `NT<:NamedTuple` and `NT<:Dict`. | ||
|
||
# Notes | ||
The major differences between this and `TypedVarInfo` are: | ||
1. `SimpleVarInfo` does not require linearization. | ||
2. `SimpleVarInfo` can use more efficient bijectors. | ||
3. `SimpleVarInfo` is only type-stable if `NT<:NamedTuple` and either | ||
a) no indexing is used in tilde-statements, or | ||
b) the values have been specified with the correct shapes. | ||
|
||
# Examples | ||
```jldoctest; setup=:(using Distributions) | ||
julia> using StableRNGs | ||
|
||
julia> @model function demo() | ||
m ~ Normal() | ||
x = Vector{Float64}(undef, 2) | ||
for i in eachindex(x) | ||
x[i] ~ Normal() | ||
end | ||
return x | ||
end | ||
demo (generic function with 1 method) | ||
|
||
julia> m = demo(); | ||
|
||
julia> rng = StableRNG(42); | ||
|
||
julia> ### Sampling ### | ||
ctx = SamplingContext(rng, SampleFromPrior(), DefaultContext()); | ||
|
||
julia> # In the `NamedTuple` version we need to provide the place-holder values for | ||
# the variables which are using "containers", e.g. `Array`. | ||
# In this case, this means that we need to specify `x` but not `m`. | ||
_, vi = DynamicPPL.evaluate(m, SimpleVarInfo((x = ones(2), )), ctx); vi | ||
SimpleVarInfo{NamedTuple{(:x, :m), Tuple{Vector{Float64}, Float64}}, Float64}((x = [1.6642061055583879, 1.796319600944139], m = -0.16796295277202952), -5.769094411622931) | ||
|
||
julia> # (✓) Vroom, vroom! FAST!!! | ||
DynamicPPL.getval(vi, @varname(x[1])) | ||
1.6642061055583879 | ||
|
||
julia> # We can also access arbitrary varnames pointing to `x`, e.g. | ||
DynamicPPL.getval(vi, @varname(x)) | ||
2-element Vector{Float64}: | ||
1.6642061055583879 | ||
1.796319600944139 | ||
|
||
julia> DynamicPPL.getval(vi, @varname(x[1:2])) | ||
2-element view(::Vector{Float64}, 1:2) with eltype Float64: | ||
1.6642061055583879 | ||
1.796319600944139 | ||
|
||
julia> # (×) If we don't provide the container... | ||
_, vi = DynamicPPL.evaluate(m, SimpleVarInfo(), ctx); vi | ||
ERROR: type NamedTuple has no field x | ||
[...] | ||
|
||
julia> # If one does not know the varnames, we can use a `Dict` instead. | ||
_, vi = DynamicPPL.evaluate(m, SimpleVarInfo{Float64}(Dict()), ctx); vi | ||
SimpleVarInfo{Dict{Any, Any}, Float64}(Dict{Any, Any}(x[1] => 1.192696983568277, x[2] => 0.4914514300738121, m => 0.25572200616753643), -3.6215377732004237) | ||
|
||
julia> # (✓) Sort of fast, but only possible at runtime. | ||
DynamicPPL.getval(vi, @varname(x[1])) | ||
1.192696983568277 | ||
|
||
julia> # In addition, we can only access varnames as they appear in the model! | ||
DynamicPPL.getval(vi, @varname(x)) | ||
ERROR: KeyError: key x not found | ||
[...] | ||
|
||
julia> DynamicPPL.getval(vi, @varname(x[1:2])) | ||
ERROR: KeyError: key x[1:2] not found | ||
[...] | ||
``` | ||
""" | ||
struct SimpleVarInfo{NT,T} <: AbstractVarInfo | ||
θ::NT | ||
logp::T | ||
end | ||
|
||
SimpleVarInfo{T}(θ) where {T<:Real} = SimpleVarInfo{typeof(θ),T}(θ, zero(T)) | ||
SimpleVarInfo(θ) = SimpleVarInfo{eltype(first(θ))}(θ) | ||
SimpleVarInfo{T}() where {T<:Real} = SimpleVarInfo{T}(NamedTuple()) | ||
SimpleVarInfo() = SimpleVarInfo{Float64}() | ||
|
||
# Constructor from `Model`. | ||
SimpleVarInfo(model::Model, args...) = SimpleVarInfo{Float64}(model, args...) | ||
function SimpleVarInfo{T}(model::Model, args...) where {T<:Real} | ||
_, svi = DynamicPPL.evaluate(model, SimpleVarInfo{T}(), args...) | ||
return svi | ||
end | ||
|
||
# Constructor from `VarInfo`. | ||
function SimpleVarInfo(vi::TypedVarInfo, ::Type{D}=NamedTuple; kwargs...) where {D} | ||
return SimpleVarInfo{eltype(getlogp(vi))}(vi, D; kwargs...) | ||
end | ||
function SimpleVarInfo{T}( | ||
vi::VarInfo{<:NamedTuple{names}}, ::Type{D} | ||
) where {T<:Real,names,D} | ||
values = values_as(vi, D) | ||
return SimpleVarInfo{T}(values) | ||
end | ||
|
||
getlogp(vi::SimpleVarInfo) = vi.logp | ||
setlogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.θ, logp) | ||
acclogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.θ, getlogp(vi) + logp) | ||
|
||
function setlogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) | ||
vi.logp[] = logp | ||
return vi | ||
end | ||
|
||
function acclogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) | ||
vi.logp[] += logp | ||
return vi | ||
end | ||
|
||
function _getvalue(nt::NamedTuple, ::Val{sym}, inds=()) where {sym} | ||
# Use `getproperty` instead of `getfield` | ||
value = getproperty(nt, sym) | ||
# Note that this will return a `view`, even if the resulting value is 0-dim. | ||
# This makes it possible to call `setindex!` on the result later to update | ||
# in place even in the case where we are retrieving a single element, e.g. `x[1]`. | ||
return _getindex(value, inds) | ||
end | ||
|
||
# `NamedTuple` | ||
function getval(vi::SimpleVarInfo{<:NamedTuple}, vn::VarName{sym}) where {sym} | ||
return maybe_unwrap_view(_getvalue(vi.θ, Val{sym}(), vn.indexing)) | ||
end | ||
|
||
# `Dict` | ||
function getval(vi::SimpleVarInfo{<:Dict}, vn::VarName) | ||
return vi.θ[vn] | ||
end | ||
|
||
# `SimpleVarInfo` doesn't necessarily vectorize, so we can have arrays other than | ||
# just `Vector`. | ||
getval(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) = map(vn -> getval(vi, vn), vns) | ||
# To disambiguate. | ||
getval(vi::SimpleVarInfo, vns::Vector{<:VarName}) = map(vn -> getval(vi, vn), vns) | ||
|
||
haskey(vi::SimpleVarInfo, vn) = haskey(vi.θ, getsym(vn)) | ||
|
||
istrans(::SimpleVarInfo, vn::VarName) = false | ||
|
||
getindex(vi::SimpleVarInfo, spl::SampleFromPrior) = vi.θ | ||
getindex(vi::SimpleVarInfo, spl::SampleFromUniform) = vi.θ | ||
# TODO: Should we do better? | ||
getindex(vi::SimpleVarInfo, spl::Sampler) = vi.θ | ||
getindex(vi::SimpleVarInfo, vn::VarName) = getval(vi, vn) | ||
getindex(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) = getval(vi, vns) | ||
# HACK: Need to disambiguate. | ||
getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}) = getval(vi, vns) | ||
|
||
# Necessary for `matchingvalue` to work properly. | ||
function Base.eltype( | ||
vi::SimpleVarInfo{<:Any,T}, spl::Union{AbstractSampler,SampleFromPrior} | ||
) where {T} | ||
return T | ||
end | ||
|
||
# `NamedTuple` | ||
function push!!( | ||
vi::SimpleVarInfo{<:NamedTuple}, | ||
vn::VarName{sym,Tuple{}}, | ||
value, | ||
dist::Distribution, | ||
gidset::Set{Selector}, | ||
) where {sym} | ||
@set vi.θ = merge(vi.θ, NamedTuple{(sym,)}((value,))) | ||
end | ||
function push!!( | ||
vi::SimpleVarInfo{<:NamedTuple}, | ||
vn::VarName{sym}, | ||
value, | ||
dist::Distribution, | ||
gidset::Set{Selector}, | ||
) where {sym} | ||
# We update in place. | ||
# We need a view into the array, hence we call `_getvalue` directly | ||
# rather than `getval`. | ||
current = _getvalue(vi.θ, Val{sym}(), vn.indexing) | ||
current .= value | ||
return vi | ||
end | ||
|
||
# `Dict` | ||
function push!!( | ||
vi::SimpleVarInfo{<:Dict}, vn::VarName, r, dist::Distribution, gidset::Set{Selector} | ||
) | ||
vi.θ[vn] = r | ||
return vi | ||
end | ||
|
||
# Context implementations | ||
function tilde_assume!!(context, right, vn, inds, vi::SimpleVarInfo) | ||
value, logp, vi_new = tilde_assume(context, right, vn, inds, vi) | ||
return value, acclogp!!(vi_new, logp) | ||
end | ||
|
||
function assume(dist::Distribution, vn::VarName, vi::SimpleVarInfo) | ||
left = vi[vn] | ||
return left, Distributions.loglikelihood(dist, left), vi | ||
end | ||
|
||
function assume( | ||
rng::Random.AbstractRNG, | ||
sampler::SampleFromPrior, | ||
dist::Distribution, | ||
vn::VarName, | ||
vi::SimpleVarInfo, | ||
) | ||
value = init(rng, dist, sampler) | ||
vi = push!!(vi, vn, value, dist, sampler) | ||
vi = settrans!!(vi, false, vn) | ||
return value, Distributions.loglikelihood(dist, value), vi | ||
end | ||
|
||
# function dot_tilde_assume!!(context, right, left, vn, inds, vi::SimpleVarInfo) | ||
# throw(MethodError(dot_tilde_assume!!, (context, right, left, vn, inds, vi))) | ||
# end | ||
|
||
function dot_tilde_assume!!(context, right, left, vn, inds, vi::SimpleVarInfo) | ||
value, logp, vi_new = dot_tilde_assume(context, right, left, vn, inds, vi) | ||
# Mutation of `value` no longer occurs in main body, so we do it here. | ||
left .= value | ||
return value, acclogp!!(vi_new, logp) | ||
end | ||
|
||
function dot_assume( | ||
dist::MultivariateDistribution, | ||
var::AbstractMatrix, | ||
vns::AbstractVector{<:VarName}, | ||
vi::SimpleVarInfo, | ||
) | ||
@assert length(dist) == size(var, 1) | ||
# NOTE: We cannot work with `var` here because we might have a model of the form | ||
# | ||
# m = Vector{Float64}(undef, n) | ||
# m .~ Normal() | ||
# | ||
# in which case `var` will have `undef` elements, even if `m` is present in `vi`. | ||
value = vi[vns] | ||
lp = sum(zip(vns, eachcol(value))) do (vn, val) | ||
return Distributions.logpdf(dist, val) | ||
end | ||
return value, lp, vi | ||
end | ||
|
||
function dot_assume( | ||
dists::Union{Distribution,AbstractArray{<:Distribution}}, | ||
var::AbstractArray, | ||
vns::AbstractArray{<:VarName}, | ||
vi::SimpleVarInfo{<:NamedTuple}, | ||
) | ||
# NOTE: We cannot work with `var` here because we might have a model of the form | ||
# | ||
# m = Vector{Float64}(undef, n) | ||
# m .~ Normal() | ||
# | ||
# in which case `var` will have `undef` elements, even if `m` is present in `vi`. | ||
value = vi[vns] | ||
lp = sum(Distributions.logpdf.(dists, value)) | ||
return value, lp, vi | ||
end | ||
|
||
# HACK: Allows us to re-use the implementation of `dot_tilde`, etc. for literals. | ||
increment_num_produce!(::SimpleVarInfo) = nothing | ||
settrans!!(vi::SimpleVarInfo, trans::Bool, vn::VarName) = vi |
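The comment in `_getvalue` about returning a view even for a single element can be illustrated with plain `Base` arrays. The following is only a sketch of that idea: it does not call DynamicPPL's `_getindex`, and the names `θ`, `slot`, `pair`, and `after_scalar` are chosen for illustration.

```julia
# A NamedTuple is immutable, but the array it holds is not.
θ = (x = zeros(2),)

# A 0-dimensional view of `x[1]`: broadcasted assignment writes through to `θ.x`.
slot = view(θ.x, 1)
slot .= 1.5
after_scalar = copy(θ.x)

# The same mechanism works for a range of indices.
pair = view(θ.x, 1:2)
pair .= [3.0, 4.0]
```

This is why the indexed `push!!` method for `NamedTuple` can return the same `vi`: the container stored in `θ` is mutated, while the wrapper itself is never rebuilt.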
Notice how here we also take the `vi_new`. Ideally this is what we should be doing overall if we're going to support immutable varinfos. But if we force this, we'll end up breaking a lot of downstream samplers since it requires changing `assume` statements to also return `vi`.

As an intermediate step, it might be worth just overloading `tilde_assume!!` as I have done here + `assume` as I have done below. This also brings up another annoyance though: `vi` should really be at the beginning of the arguments of all the tilde-statements (and ideally `assume` too, but we can delay this) to minimize method ambiguities but allow us to have some special behavior for the different impls for `AbstractVarInfo`.
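For readers less familiar with the `!!` convention discussed here, a minimal, self-contained sketch may help. It mirrors the two `acclogp!!` methods in the diff but uses hypothetical stand-ins (`ToyVarInfo`, `toy_getlogp`, `toy_acclogp!!`, `accumulate`) that are not part of DynamicPPL. The point is that callers always rebind to the returned value, so the same calling code works for both the immutable and the `Ref`-backed variant.

```julia
# Hypothetical stand-in for `SimpleVarInfo`; not DynamicPPL API.
struct ToyVarInfo{NT,T}
    θ::NT
    logp::T
end

toy_getlogp(vi::ToyVarInfo) = vi.logp
toy_getlogp(vi::ToyVarInfo{<:Any,<:Ref}) = vi.logp[]

# Immutable case: build and return a new object.
toy_acclogp!!(vi::ToyVarInfo, logp) = ToyVarInfo(vi.θ, vi.logp + logp)

# Mutable case: `logp` is a `Ref`, so update in place and return the same object.
function toy_acclogp!!(vi::ToyVarInfo{<:Any,<:Ref}, logp)
    vi.logp[] += logp
    return vi
end

# Callers rebind `vi` to whatever comes back, so either variant is handled.
function accumulate(vi, increments)
    for lp in increments
        vi = toy_acclogp!!(vi, lp)
    end
    return vi
end

vi_immutable = accumulate(ToyVarInfo((m = 0.0,), 0.0), (-1.0, -2.5))
vi_mutable = accumulate(ToyVarInfo((m = 0.0,), Ref(0.0)), (-1.0, -2.5))
```

A sampler that discards the return value would silently lose the update in the immutable case, which is the downstream breakage the comment above is concerned about.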