-
Notifications
You must be signed in to change notification settings - Fork 36
Clean up varinfo get/set functions #853
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -100,20 +100,14 @@ const TypedVarInfo = VarInfo{<:NamedTuple} | |
const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{ | ||
VarInfo{Tmeta},ThreadSafeVarInfo{<:VarInfo{Tmeta}} | ||
} | ||
# TODO: Remove this | ||
@deprecate VarInfo(vi::VarInfo, x::AbstractVector) unflatten(vi, x) | ||
|
||
# NOTE: This is kind of weird, but it effectively preserves the "old" | ||
# behavior where we're allowed to call `link!` on the same `VarInfo` | ||
# multiple times. | ||
transformation(::VarInfo) = DynamicTransformation() | ||
|
||
# TODO(mhauru) Isn't this the same as unflatten and/or replace_values? | ||
function VarInfo(old_vi::VarInfo, x::AbstractVector) | ||
md = replace_values(old_vi.metadata, x) | ||
return VarInfo( | ||
md, Base.RefValue{eltype(x)}(getlogp(old_vi)), Ref(get_num_produce(old_vi)) | ||
) | ||
end | ||
Comment on lines
-109
to
-115
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This constructor, as far as I can tell, wasn't being used anywhere. Unfortunately, because VarInfo is exported, this constructor is public and so this PR has to be a breaking change (all other changes are purely internal). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is fine IMO, looks like we'll touch it again when the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm that's a good point, VarInfo constructors will change when num_produce is changed - but that might be a while since Markus is off for a week and it's also a tricky piece of work. Maybe it'd be easier to leave this method in (maybe with a depwarn) and then release this PR as a patch - what do you think? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah, that works for me, seems this is the only breaking change? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yup, I think so. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok - I added this back in with a |
||
|
||
# No-op if we're already working with a `VarNamedVector`. | ||
metadata_to_varnamedvector(vnv::VarNamedVector) = vnv | ||
function metadata_to_varnamedvector(md::Metadata) | ||
|
@@ -243,9 +237,8 @@ end | |
return :($(exprs...),) | ||
end | ||
|
||
# For Metadata unflatten and replace_values are the same. For VarNamedVector they are not. | ||
function unflatten_metadata(md::Metadata, x::AbstractVector) | ||
return replace_values(md, x) | ||
return Metadata(md.idcs, md.vns, md.ranges, x, md.dists, md.orders, md.flags) | ||
end | ||
Comment on lines
240
to
242
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This was the only place that |
||
|
||
unflatten_metadata(vnv::VarNamedVector, x::AbstractVector) = unflatten(vnv, x) | ||
|
@@ -255,31 +248,6 @@ function VarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext | |
return VarInfo(rng, model, SampleFromPrior(), context) | ||
end | ||
|
||
function replace_values(metadata::Metadata, x) | ||
return Metadata( | ||
metadata.idcs, | ||
metadata.vns, | ||
metadata.ranges, | ||
x, | ||
metadata.dists, | ||
metadata.orders, | ||
metadata.flags, | ||
) | ||
end | ||
|
||
@generated function replace_values(metadata::NamedTuple{names}, x) where {names} | ||
exprs = [] | ||
offset = :(0) | ||
for f in names | ||
mdf = :(metadata.$f) | ||
len = :(sum(length, $mdf.ranges)) | ||
push!(exprs, :($f = replace_values($mdf, x[($offset + 1):($offset + $len)]))) | ||
offset = :($offset + $len) | ||
end | ||
length(exprs) == 0 && return :(NamedTuple()) | ||
return :($(exprs...),) | ||
end | ||
Comment on lines
-270
to
-281
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This method is not needed because it's exactly the same thing as |
||
|
||
#### | ||
#### Internal functions | ||
#### | ||
|
@@ -652,10 +620,20 @@ getindex_internal(vi::VarInfo, vn::VarName) = getindex_internal(getmetadata(vi, | |
# what a bijector would result in, even if the input is a view (`SubArray`). | ||
# TODO(torfjelde): An alternative is to implement `view` directly instead. | ||
getindex_internal(md::Metadata, vn::VarName) = getindex(md.vals, getrange(md, vn)) | ||
|
||
function getindex_internal(vi::VarInfo, vns::Vector{<:VarName}) | ||
return mapreduce(Base.Fix1(getindex_internal, vi), vcat, vns) | ||
end | ||
getindex_internal(vi::VarInfo, ::Colon) = getindex_internal(vi.metadata, Colon()) | ||
# NOTE: `mapreduce` over `NamedTuple` results in worse type-inference. | ||
# See for example https://github.com/JuliaLang/julia/pull/46381. | ||
function getindex_internal(vi::TypedVarInfo, ::Colon) | ||
return reduce(vcat, map(Base.Fix2(getindex_internal, Colon()), vi.metadata)) | ||
end | ||
function getindex_internal(md::Metadata, ::Colon) | ||
return mapreduce( | ||
Base.Fix1(getindex_internal, md), vcat, md.vns; init=similar(md.vals, 0) | ||
) | ||
end | ||
Comment on lines
+626
to
+636
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
(1) |
||
|
||
""" | ||
setval!(vi::VarInfo, val, vn::VarName) | ||
|
@@ -672,56 +650,6 @@ function setval!(md::Metadata, val, vn::VarName) | |
return md.vals[getrange(md, vn)] = tovec(val) | ||
end | ||
|
||
""" | ||
getall(vi::VarInfo) | ||
|
||
Return the values of all the variables in `vi`. | ||
|
||
The values may or may not be transformed to Euclidean space. | ||
""" | ||
getall(vi::VarInfo) = getall(vi.metadata) | ||
# NOTE: `mapreduce` over `NamedTuple` results in worse type-inference. | ||
# See for example https://github.com/JuliaLang/julia/pull/46381. | ||
getall(vi::TypedVarInfo) = reduce(vcat, map(getall, vi.metadata)) | ||
function getall(md::Metadata) | ||
return mapreduce( | ||
Base.Fix1(getindex_internal, md), vcat, md.vns; init=similar(md.vals, 0) | ||
) | ||
end | ||
getall(vnv::VarNamedVector) = getindex_internal(vnv, Colon()) | ||
|
||
""" | ||
setall!(vi::VarInfo, val) | ||
|
||
Set the values of all the variables in `vi` to `val`. | ||
|
||
The values may or may not be transformed to Euclidean space. | ||
""" | ||
setall!(vi::VarInfo, val) = _setall!(vi.metadata, val) | ||
|
||
function _setall!(metadata::Metadata, val) | ||
for r in metadata.ranges | ||
metadata.vals[r] .= val[r] | ||
end | ||
end | ||
function _setall!(vnv::VarNamedVector, val) | ||
# TODO(mhauru) Do something more efficient here. | ||
for i in 1:length_internal(vnv) | ||
setindex_internal!(vnv, val[i], i) | ||
end | ||
end | ||
@generated function _setall!(metadata::NamedTuple{names}, val) where {names} | ||
expr = Expr(:block) | ||
start = :(1) | ||
for f in names | ||
length = :(sum(length, metadata.$f.ranges)) | ||
finish = :($start + $length - 1) | ||
push!(expr.args, :(copyto!(metadata.$f.vals, 1, val, $start, $length))) | ||
start = :($start + $length) | ||
end | ||
return expr | ||
end | ||
Comment on lines
-693
to
-723
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
function settrans!!(vi::VarInfo, trans::Bool, vn::VarName) | ||
settrans!!(getmetadata(vi, vn), trans, vn) | ||
return vi | ||
|
@@ -2114,7 +2042,7 @@ function _setval_and_resample_kernel!( | |
end | ||
|
||
values_as(vi::VarInfo) = vi.metadata | ||
values_as(vi::VarInfo, ::Type{Vector}) = copy(getall(vi)) | ||
values_as(vi::VarInfo, ::Type{Vector}) = copy(getindex_internal(vi, Colon())) | ||
function values_as(vi::UntypedVarInfo, ::Type{NamedTuple}) | ||
iter = values_from_metadata(vi.metadata) | ||
return NamedTuple(map(p -> Symbol(p.first) => p.second, iter)) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function, originally called
set_values!!
, is very similar tounflatten
but does a couple of extra things. Consequently, I renamed the function and added a docstring to make it clear how it differs.