Skip to content

Clean up varinfo get/set functions #853

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@
# DynamicPPL Changelog

## 0.35.5

Several internal methods have been removed:

- `DynamicPPL.getall(vi::AbstractVarInfo)` has been removed. You can directly replace this with `getindex_internal(vi, Colon())`.
- `DynamicPPL.setall!(vi::AbstractVarInfo, values)` has been removed. Rewrite the calling function to not assume mutation and use `unflatten(vi, values)` instead.
- `DynamicPPL.replace_values(md::Metadata, values)` and `DynamicPPL.replace_values(nt::NamedTuple, values)` (where the `nt` is a NamedTuple of Metadatas) have been removed. Use `DynamicPPL.unflatten_metadata` as a direct replacement.
- `DynamicPPL.set_values!!(vi::AbstractVarInfo, values)` has been renamed to `DynamicPPL.set_initial_values(vi::AbstractVarInfo, values)`; it also no longer mutates the varinfo argument.

The **exported** method `VarInfo(vi::VarInfo, values)` has been deprecated, and will be removed in the next minor version. You can replace this directly with `unflatten(vi, values)` instead.

## 0.35.4

Fixed a type instability in an implementation of `with_logabsdet_jacobian`, which resulted in the log-jacobian returned being an Int in some cases and a Float in others.
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.35.4"
version = "0.35.5"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
10 changes: 9 additions & 1 deletion src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,16 @@ Base.getindex(vi::AbstractVarInfo, ::Colon) = values_as(vi, Vector)
"""
getindex_internal(vi::AbstractVarInfo, vn::VarName)
getindex_internal(vi::AbstractVarInfo, vns::Vector{<:VarName})
getindex_internal(vi::AbstractVarInfo, ::Colon)

Return the current value(s) of `vn` (`vns`) in `vi` as represented internally in `vi`.
Return the internal value of the varname `vn`, varnames `vns`, or all varnames
in `vi` respectively. The internal value is the value of the variables that is
stored in the varinfo object; this may be the actual realisation of the random
variable (i.e. the value sampled from the distribution), or it may have been
transformed to Euclidean space, depending on whether the varinfo was linked.

See https://turinglang.org/docs/developers/transforms/dynamicppl/ for more
information on how transformed variables are stored in DynamicPPL.

See also: [`getindex(vi::AbstractVarInfo, vn::VarName, dist::Distribution)`](@ref)
"""
Expand Down
27 changes: 21 additions & 6 deletions src/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,21 @@ By default, it returns an instance of [`SampleFromPrior`](@ref).
"""
initialsampler(spl::Sampler) = SampleFromPrior()

function set_values!!(varinfo::AbstractVarInfo, initial_params::AbstractVector)
"""
set_initial_values(varinfo::AbstractVarInfo, initial_params::AbstractVector)
set_initial_values(varinfo::AbstractVarInfo, initial_params::NamedTuple)

Take the values inside `initial_params`, replace the corresponding values in
the given VarInfo object, and return a new VarInfo object with the updated values.

This differs from `DynamicPPL.unflatten` in two ways:

1. It works with `NamedTuple` arguments.
2. For the `AbstractVector` method, if any of the elements are missing, it will not
overwrite the original value in the VarInfo (it will just use the original
value instead).
"""
function set_initial_values(varinfo::AbstractVarInfo, initial_params::AbstractVector)
throw(
ArgumentError(
"`initial_params` must be a vector of type `Union{Real,Missing}`. " *
Expand All @@ -160,7 +174,7 @@ function set_values!!(varinfo::AbstractVarInfo, initial_params::AbstractVector)
)
end

function set_values!!(
function set_initial_values(
varinfo::AbstractVarInfo, initial_params::AbstractVector{<:Union{Real,Missing}}
Comment on lines -163 to 178
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function, originally called set_values!!, is very similar to unflatten but does a couple of extra things. Consequently, I renamed the function and added a docstring to make it clear how it differs.

)
flattened_param_vals = varinfo[:]
Expand All @@ -180,11 +194,12 @@ function set_values!!(
end

# Update in `varinfo`.
setall!(varinfo, flattened_param_vals)
return varinfo
new_varinfo = unflatten(varinfo, flattened_param_vals)
return new_varinfo
end

function set_values!!(varinfo::AbstractVarInfo, initial_params::NamedTuple)
function set_initial_values(varinfo::AbstractVarInfo, initial_params::NamedTuple)
varinfo = deepcopy(varinfo)
vars_in_varinfo = keys(varinfo)
for v in keys(initial_params)
vn = VarName{v}()
Expand Down Expand Up @@ -219,7 +234,7 @@ function initialize_parameters!!(vi::AbstractVarInfo, initial_params, model::Mod
end

# Set the values in `vi`.
vi = set_values!!(vi, initial_params)
vi = set_initial_values(vi, initial_params)

# `invlink` if needed.
if linked
Expand Down
102 changes: 15 additions & 87 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,20 +100,14 @@ const TypedVarInfo = VarInfo{<:NamedTuple}
const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{
VarInfo{Tmeta},ThreadSafeVarInfo{<:VarInfo{Tmeta}}
}
# TODO: Remove this
@deprecate VarInfo(vi::VarInfo, x::AbstractVector) unflatten(vi, x)

# NOTE: This is kind of weird, but it effectively preserves the "old"
# behavior where we're allowed to call `link!` on the same `VarInfo`
# multiple times.
transformation(::VarInfo) = DynamicTransformation()

# TODO(mhauru) Isn't this the same as unflatten and/or replace_values?
function VarInfo(old_vi::VarInfo, x::AbstractVector)
md = replace_values(old_vi.metadata, x)
return VarInfo(
md, Base.RefValue{eltype(x)}(getlogp(old_vi)), Ref(get_num_produce(old_vi))
)
end
Comment on lines -109 to -115
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This constructor, as far as I can tell, wasn't being used anywhere. Unfortunately, because VarInfo is exported, this constructor is public and so this PR has to be a breaking change (all other changes are purely internal).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fine IMO, looks like we'll touch it again when the num_produce is removed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm that's a good point, VarInfo constructors will change when num_produce is changed - but that might be a while since Markus is off for a week and it's also a tricky piece of work. Maybe it'd be easier to leave this method in (maybe with a depwarn) and then release this PR as a patch - what do you think?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, that works for me, seems this is the only breaking change?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, I think so.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok - I added this back in with a @deprecate, and we can remove it in 0.36.0.


# No-op if we're already working with a `VarNamedVector`.
metadata_to_varnamedvector(vnv::VarNamedVector) = vnv
function metadata_to_varnamedvector(md::Metadata)
Expand Down Expand Up @@ -243,9 +237,8 @@ end
return :($(exprs...),)
end

# For Metadata unflatten and replace_values are the same. For VarNamedVector they are not.
function unflatten_metadata(md::Metadata, x::AbstractVector)
return replace_values(md, x)
return Metadata(md.idcs, md.vns, md.ranges, x, md.dists, md.orders, md.flags)
end
Comment on lines 240 to 242
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was the only place that replace_values was being called, so I just inlined its definition


unflatten_metadata(vnv::VarNamedVector, x::AbstractVector) = unflatten(vnv, x)
Expand All @@ -255,31 +248,6 @@ function VarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext
return VarInfo(rng, model, SampleFromPrior(), context)
end

function replace_values(metadata::Metadata, x)
return Metadata(
metadata.idcs,
metadata.vns,
metadata.ranges,
x,
metadata.dists,
metadata.orders,
metadata.flags,
)
end

@generated function replace_values(metadata::NamedTuple{names}, x) where {names}
exprs = []
offset = :(0)
for f in names
mdf = :(metadata.$f)
len = :(sum(length, $mdf.ranges))
push!(exprs, :($f = replace_values($mdf, x[($offset + 1):($offset + $len)])))
offset = :($offset + $len)
end
length(exprs) == 0 && return :(NamedTuple())
return :($(exprs...),)
end
Comment on lines -270 to -281
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method is not needed because it's exactly the same thing as unflatten_metadata


####
#### Internal functions
####
Expand Down Expand Up @@ -652,10 +620,20 @@ getindex_internal(vi::VarInfo, vn::VarName) = getindex_internal(getmetadata(vi,
# what a bijector would result in, even if the input is a view (`SubArray`).
# TODO(torfjelde): An alternative is to implement `view` directly instead.
getindex_internal(md::Metadata, vn::VarName) = getindex(md.vals, getrange(md, vn))

function getindex_internal(vi::VarInfo, vns::Vector{<:VarName})
return mapreduce(Base.Fix1(getindex_internal, vi), vcat, vns)
end
getindex_internal(vi::VarInfo, ::Colon) = getindex_internal(vi.metadata, Colon())
# NOTE: `mapreduce` over `NamedTuple` results in worse type-inference.
# See for example https://github.com/JuliaLang/julia/pull/46381.
function getindex_internal(vi::TypedVarInfo, ::Colon)
return reduce(vcat, map(Base.Fix2(getindex_internal, Colon()), vi.metadata))
end
function getindex_internal(md::Metadata, ::Colon)
return mapreduce(
Base.Fix1(getindex_internal, md), vcat, md.vns; init=similar(md.vals, 0)
)
end
Comment on lines +626 to +636
Copy link
Member Author

@penelopeysm penelopeysm Mar 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

getall(vi) is replaced with getindex_internal(vi, :). This is for two reasons:

(1) getall(::VarNamedVector) already defers to getindex_internal
(2) IMO, the internal part of the name makes it clearer that one is accessing internal values.


"""
setval!(vi::VarInfo, val, vn::VarName)
Expand All @@ -672,56 +650,6 @@ function setval!(md::Metadata, val, vn::VarName)
return md.vals[getrange(md, vn)] = tovec(val)
end

"""
getall(vi::VarInfo)

Return the values of all the variables in `vi`.

The values may or may not be transformed to Euclidean space.
"""
getall(vi::VarInfo) = getall(vi.metadata)
# NOTE: `mapreduce` over `NamedTuple` results in worse type-inference.
# See for example https://github.com/JuliaLang/julia/pull/46381.
getall(vi::TypedVarInfo) = reduce(vcat, map(getall, vi.metadata))
function getall(md::Metadata)
return mapreduce(
Base.Fix1(getindex_internal, md), vcat, md.vns; init=similar(md.vals, 0)
)
end
getall(vnv::VarNamedVector) = getindex_internal(vnv, Colon())

"""
setall!(vi::VarInfo, val)

Set the values of all the variables in `vi` to `val`.

The values may or may not be transformed to Euclidean space.
"""
setall!(vi::VarInfo, val) = _setall!(vi.metadata, val)

function _setall!(metadata::Metadata, val)
for r in metadata.ranges
metadata.vals[r] .= val[r]
end
end
function _setall!(vnv::VarNamedVector, val)
# TODO(mhauru) Do something more efficient here.
for i in 1:length_internal(vnv)
setindex_internal!(vnv, val[i], i)
end
end
@generated function _setall!(metadata::NamedTuple{names}, val) where {names}
expr = Expr(:block)
start = :(1)
for f in names
length = :(sum(length, metadata.$f.ranges))
finish = :($start + $length - 1)
push!(expr.args, :(copyto!(metadata.$f.vals, 1, val, $start, $length)))
start = :($start + $length)
end
return expr
end
Comment on lines -693 to -723
Copy link
Member Author

@penelopeysm penelopeysm Mar 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

setall! (and _setall!) was just a mutating version of unflatten (and unflatten_metadata). The solution to this code duplication is to remove the mutating version and make people use unflatten. There weren't any parts of the codebase where this could conceivably cause performance problems.


function settrans!!(vi::VarInfo, trans::Bool, vn::VarName)
settrans!!(getmetadata(vi, vn), trans, vn)
return vi
Expand Down Expand Up @@ -2114,7 +2042,7 @@ function _setval_and_resample_kernel!(
end

values_as(vi::VarInfo) = vi.metadata
values_as(vi::VarInfo, ::Type{Vector}) = copy(getall(vi))
values_as(vi::VarInfo, ::Type{Vector}) = copy(getindex_internal(vi, Colon()))
function values_as(vi::UntypedVarInfo, ::Type{NamedTuple})
iter = values_from_metadata(vi.metadata)
return NamedTuple(map(p -> Symbol(p.first) => p.second, iter))
Expand Down
6 changes: 3 additions & 3 deletions test/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,12 +163,12 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
Random.seed!(100 + i)
vi = VarInfo()
model(Random.default_rng(), vi, sampler)
vals = DynamicPPL.getall(vi)
vals = vi[:]

Random.seed!(100 + i)
vi = VarInfo()
model(Random.default_rng(), vi, sampler)
@test DynamicPPL.getall(vi) == vals
@test vi[:] == vals
end
end
end
Expand Down Expand Up @@ -240,7 +240,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
for i in 1:10
# Sample with large variations.
r_raw = randn(length(vi[:])) * 10
DynamicPPL.setall!(vi, r_raw)
vi = DynamicPPL.unflatten(vi, r_raw)
@test vi[@varname(m)] == r_raw[1]
@test vi[@varname(x)] != r_raw[2]
model(vi)
Expand Down
3 changes: 2 additions & 1 deletion test/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,8 @@

# `getlogp` should be equal to the logjoint with log-absdet-jac correction.
lp = getlogp(svi)
@test lp ≈ lp_true
# needs higher atol because of https://github.com/TuringLang/Bijectors.jl/issues/375
@test lp ≈ lp_true atol = 1.2e-5
end
end
end
Expand Down
3 changes: 2 additions & 1 deletion test/test_util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
end
const gdemo_default = gdemo_d()

# TODO(penelopeysm): Remove this (and also test/compat/ad.jl)
function test_model_ad(model, logp_manual)
vi = VarInfo(model)
x = DynamicPPL.getall(vi)
x = vi[:]

# Log probabilities using the model.
ℓ = DynamicPPL.LogDensityFunction(model, vi)
Expand Down