Skip to content

Remove samplers from VarInfo - indexing #793

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 45 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
4dc2a72
Remove selector stuff from varinfo tests
mhauru Jan 16, 2025
9b492a3
Implement link and invlink for varnames rather than samplers
mhauru Jan 16, 2025
b508f08
Replace set_retained_vns_del_by_spl! with set_retained_vns_del!
mhauru Jan 16, 2025
b8880d1
Make linking tests more extensive
mhauru Jan 16, 2025
99a8490
Remove sampler indexing from link methods (but not invlink)
mhauru Jan 22, 2025
4a79b1f
Remove indexing by samplers from invlink
mhauru Jan 22, 2025
26a1901
Merge remote-tracking branch 'origin/master' into mhauru/remove-selec…
mhauru Jan 22, 2025
090608b
Work towards removing sampler indexing with StaticTransformation
mhauru Jan 22, 2025
4749853
Fix invlink/link for TypedVarInfo and StaticTransformation
mhauru Jan 23, 2025
e960679
Fix a test in models.jl
mhauru Jan 23, 2025
d507a53
Move some functions to utils.jl, add tests and docstrings
mhauru Jan 23, 2025
41150b5
Fix a docstring typo
mhauru Jan 23, 2025
836fb13
Merge branch 'release-0.35' into mhauru/remove-selectors-linking
mhauru Jan 23, 2025
45d1f13
Various simplification to link/invlink
mhauru Jan 23, 2025
98915c2
Improve a docstring
mhauru Jan 23, 2025
f05068d
Style improvements
mhauru Jan 23, 2025
bc4c420
Fix broken link/invlink dispatch cascade for VectorVarInfo
mhauru Jan 23, 2025
71980ba
Fix some more broken dispatch cascades
mhauru Jan 23, 2025
45562a9
Apply suggestions from code review
mhauru Jan 24, 2025
db5b835
Remove comments that messed with docstrings
mhauru Jan 24, 2025
f99effe
Apply suggestions from code review
mhauru Jan 28, 2025
56194cd
Fix issues surfaced in code review
mhauru Jan 28, 2025
c187c49
Simplify link/invlink arguments
mhauru Jan 28, 2025
86b25c5
Fix a bug in unflatten VarNamedVector
mhauru Jan 28, 2025
2a6c1bc
Rename VarNameCollection -> VarNameTuple
mhauru Jan 28, 2025
853f47e
Remove test of a removed varname_namedtuple method
mhauru Jan 28, 2025
ed80328
Apply suggestions from code review
mhauru Jan 29, 2025
d996d0c
Respond to review feedback
mhauru Jan 29, 2025
2083148
Remove _default_sampler and a dead argument of maybe_invlink_before_eval
mhauru Jan 29, 2025
39fa647
Fix a typo in a comment
mhauru Jan 29, 2025
9df364f
Merge remote-tracking branch 'origin/release-0.35' into mhauru/remove…
mhauru Jan 30, 2025
2c73de5
Add HISTORY entry, fix one set_retained_vns_del! method
mhauru Jan 30, 2025
49604e1
Merge remote-tracking branch 'origin/release-0.35' into mhauru/remove…
mhauru Jan 30, 2025
1c50d0c
Remove some VarInfo getindex with samplers stuff
mhauru Jan 23, 2025
b919fe4
Remove some index setting with samplers
mhauru Jan 24, 2025
fc0e064
Remove more sampler indexing
mhauru Jan 24, 2025
5fbe016
Remove unflatten with samplers
mhauru Jan 30, 2025
cb5c79d
Clean up some setindex stuff
mhauru Jan 30, 2025
fa696d3
Merge remote-tracking branch 'origin/release-0.35' into mhauru/remove…
mhauru Jan 30, 2025
414f58e
Remove a bunch of varinfo.jl internal functions that used samplers/sp…
mhauru Jan 30, 2025
800b91a
Fix HISTORY.md
mhauru Jan 30, 2025
e65777e
Miscalleanous small fixes
mhauru Jan 30, 2025
8fcc289
Fix a bug in VarInfo constructor
mhauru Jan 30, 2025
c59aafe
Fix getparams(::LogDensityFunction)
mhauru Jan 30, 2025
934fb79
Apply suggestions from code review
mhauru Feb 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ This release removes the feature of `VarInfo` where it kept track of which varia

- `link` and `invlink`, and their `!!` versions, no longer accept a sampler as an argument to specify which variables to (inv)link. The `link(varinfo, model)` methods remain in place, and as a new addition one can give a `Tuple` of `VarName`s to (inv)link only select variables, as in `link(varinfo, varname_tuple, model)`.
- `set_retained_vns_del_by_spl!` has been replaced by `set_retained_vns_del!` which applies to all variables.
- `getindex`, `setindex!`, and `setindex!!` no longer accept samplers as arguments
- `unflatten` no longer accepts a sampler as an argument
- `eltype(::VarInfo)` no longer accepts a sampler as an argument
- `keys(::VarInfo)` no longer accepts a sampler as an argument
- `VarInfo(::VarInfo, ::Sampler, ::AbstactVector)` no longer accepts the sampler argument.

### Reverse prefixing order

Expand Down
32 changes: 8 additions & 24 deletions src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,6 @@

"""
getindex(vi::AbstractVarInfo, ::Colon)
getindex(vi::AbstractVarInfo, ::AbstractSampler)

Return the current value(s) of `vn` (`vns`) in `vi` in the support of its (their)
distribution(s) as a flattened `Vector`.
Expand All @@ -159,7 +158,6 @@
See also: [`getindex(vi::AbstractVarInfo, vn::VarName, dist::Distribution)`](@ref)
"""
Base.getindex(vi::AbstractVarInfo, ::Colon) = values_as(vi, Vector)
Base.getindex(vi::AbstractVarInfo, ::AbstractSampler) = vi[:]

"""
getindex_internal(vi::AbstractVarInfo, vn::VarName)
Expand Down Expand Up @@ -341,9 +339,9 @@
function values_as end

"""
eltype(vi::AbstractVarInfo, spl::Union{AbstractSampler,SampleFromPrior}
eltype(vi::AbstractVarInfo)

Determine the default `eltype` of the values returned by `vi[spl]`.
Return the `eltype` of the values returned by `vi[:]`.

!!! warning
This should generally not be called explicitly, as it's only used in
Expand All @@ -352,13 +350,13 @@

This method is considered legacy, and is likely to be deprecated in the future.
"""
function Base.eltype(vi::AbstractVarInfo, spl::Union{AbstractSampler,SampleFromPrior})
T = Base.promote_op(getindex, typeof(vi), typeof(spl))
function Base.eltype(vi::AbstractVarInfo)
T = Base.promote_op(getindex, typeof(vi), Colon)
if T === Union{}
# In this case `getindex(vi, spl)` errors
# In this case `getindex(vi, :)` errors
# Let us throw a more descriptive error message
# Ref https://github.com/TuringLang/Turing.jl/issues/2151
return eltype(vi[spl])
return eltype(vi[:])

Check warning on line 359 in src/abstract_varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/abstract_varinfo.jl#L359

Added line #L359 was not covered by tests
end
return eltype(T)
end
Expand Down Expand Up @@ -720,25 +718,11 @@

# Utilities
"""
unflatten(vi::AbstractVarInfo[, context::AbstractContext], x::AbstractVector)
unflatten(vi::AbstractVarInfo, x::AbstractVector)

Return a new instance of `vi` with the values of `x` assigned to the variables.

If `context` is provided, `x` is assumed to be realizations only for variables not
filtered out by `context`.
"""
function unflatten(varinfo::AbstractVarInfo, context::AbstractContext, θ)
if hassampler(context)
unflatten(getsampler(context), varinfo, context, θ)
else
DynamicPPL.unflatten(varinfo, θ)
end
end

# TODO: deprecate this once `sampler` is no longer the main way of filtering out variables.
function unflatten(sampler::AbstractSampler, varinfo::AbstractVarInfo, ::AbstractContext, θ)
return unflatten(varinfo, sampler, θ)
end
function unflatten end

"""
to_maybe_linked_internal(vi::AbstractVarInfo, vn::VarName, dist, val)
Expand Down
65 changes: 25 additions & 40 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""
need_concretize(expr)

Return `true` if `expr` needs to be concretized, i.e., if it contains a colon `:` or
Return `true` if `expr` needs to be concretized, i.e., if it contains a colon `:` or
requires a dynamic optic.

# Examples
Expand Down Expand Up @@ -730,19 +730,19 @@
return nothing
end

# TODO(mhauru) matchingvalue has methods that can accept both types and values. Why?
# TODO(mhauru) This function needs a more comprehensive docstring.
"""
matchingvalue(sampler, vi, value)
matchingvalue(context::AbstractContext, vi, value)
matchingvalue(vi, value)

Convert the `value` to the correct type for the `sampler` or `context` and the `vi` object.

For a `context` that is _not_ a `SamplingContext`, we fall back to
`matchingvalue(SampleFromPrior(), vi, value)`.
Convert the `value` to the correct type for the `vi` object.
"""
function matchingvalue(sampler, vi, value)
function matchingvalue(vi, value)
T = typeof(value)
if hasmissing(T)
_value = convert(get_matching_type(sampler, vi, T), value)
_value = convert(get_matching_type(vi, T), value)
# TODO(mhauru) Why do we make a deepcopy, even though in the !hasmissing branch we
# are happy to return `value` as-is?
if _value === value
return deepcopy(_value)
else
Expand All @@ -752,45 +752,30 @@
return value
end
end
# If we hit `Type` or `TypeWrap`, we immediately jump to `get_matching_type`.
function matchingvalue(sampler::AbstractSampler, vi, value::FloatOrArrayType)
return get_matching_type(sampler, vi, value)
end
function matchingvalue(sampler::AbstractSampler, vi, value::TypeWrap{T}) where {T}
return TypeWrap{get_matching_type(sampler, vi, T)}()
end

function matchingvalue(context::AbstractContext, vi, value)
return matchingvalue(NodeTrait(matchingvalue, context), context, vi, value)
function matchingvalue(vi, value::FloatOrArrayType)
return get_matching_type(vi, value)

Check warning on line 757 in src/compiler.jl

View check run for this annotation

Codecov / codecov/patch

src/compiler.jl#L756-L757

Added lines #L756 - L757 were not covered by tests
end
function matchingvalue(::IsLeaf, context::AbstractContext, vi, value)
return matchingvalue(SampleFromPrior(), vi, value)
end
function matchingvalue(::IsParent, context::AbstractContext, vi, value)
return matchingvalue(childcontext(context), vi, value)
end
function matchingvalue(context::SamplingContext, vi, value)
return matchingvalue(context.sampler, vi, value)
function matchingvalue(vi, ::TypeWrap{T}) where {T}
return TypeWrap{get_matching_type(vi, T)}()
end

# TODO(mhauru) This function needs a more comprehensive docstring. What is it for?
"""
get_matching_type(spl::AbstractSampler, vi, ::TypeWrap{T}) where {T}

Get the specialized version of type `T` for sampler `spl`.
get_matching_type(vi, ::TypeWrap{T}) where {T}

For example, if `T === Float64` and `spl::Hamiltonian`, the matching type is
`eltype(vi[spl])`.
Get the specialized version of type `T` for `vi`.
"""
get_matching_type(spl::AbstractSampler, vi, ::Type{T}) where {T} = T
function get_matching_type(spl::AbstractSampler, vi, ::Type{<:Union{Missing,AbstractFloat}})
return Union{Missing,float_type_with_fallback(eltype(vi, spl))}
get_matching_type(_, ::Type{T}) where {T} = T

Check warning on line 769 in src/compiler.jl

View check run for this annotation

Codecov / codecov/patch

src/compiler.jl#L769

Added line #L769 was not covered by tests
function get_matching_type(vi, ::Type{<:Union{Missing,AbstractFloat}})
return Union{Missing,float_type_with_fallback(eltype(vi))}
end
function get_matching_type(spl::AbstractSampler, vi, ::Type{<:AbstractFloat})
return float_type_with_fallback(eltype(vi, spl))
function get_matching_type(vi, ::Type{<:AbstractFloat})
return float_type_with_fallback(eltype(vi))
end
function get_matching_type(spl::AbstractSampler, vi, ::Type{<:Array{T,N}}) where {T,N}
return Array{get_matching_type(spl, vi, T),N}
function get_matching_type(vi, ::Type{<:Array{T,N}}) where {T,N}
return Array{get_matching_type(vi, T),N}
end
function get_matching_type(spl::AbstractSampler, vi, ::Type{<:Array{T}}) where {T}
return Array{get_matching_type(spl, vi, T)}
function get_matching_type(vi, ::Type{<:Array{T}}) where {T}
return Array{get_matching_type(vi, T)}
end
9 changes: 2 additions & 7 deletions src/logdensityfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,22 +121,17 @@ end
getsampler(f::LogDensityFunction) = getsampler(getcontext(f))
hassampler(f::LogDensityFunction) = hassampler(getcontext(f))

_get_indexer(ctx::AbstractContext) = _get_indexer(NodeTrait(ctx), ctx)
_get_indexer(ctx::SamplingContext) = ctx.sampler
_get_indexer(::IsParent, ctx::AbstractContext) = _get_indexer(childcontext(ctx))
_get_indexer(::IsLeaf, ctx::AbstractContext) = Colon()

"""
getparams(f::LogDensityFunction)

Return the parameters of the wrapped varinfo as a vector.
"""
getparams(f::LogDensityFunction) = f.varinfo[_get_indexer(getcontext(f))]
getparams(f::LogDensityFunction) = f.varinfo[:]

# LogDensityProblems interface
function LogDensityProblems.logdensity(f::LogDensityFunction, θ::AbstractVector)
context = getcontext(f)
vi_new = unflatten(f.varinfo, context, θ)
vi_new = unflatten(f.varinfo, θ)
return getlogp(last(evaluate!!(f.model, vi_new, context)))
end
function LogDensityProblems.capabilities(::Type{<:LogDensityFunction})
Expand Down
4 changes: 2 additions & 2 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -948,9 +948,9 @@ Return the arguments and keyword arguments to be passed to the evaluator of the
) where {_F,argnames}
unwrap_args = [
if is_splat_symbol(var)
:($matchingvalue(context_new, varinfo, model.args.$var)...)
:($matchingvalue(varinfo, model.args.$var)...)
else
:($matchingvalue(context_new, varinfo, model.args.$var))
:($matchingvalue(varinfo, model.args.$var))
end for var in argnames
]

Expand Down
31 changes: 12 additions & 19 deletions src/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@

# Update the parameters if provided.
if initial_params !== nothing
vi = initialize_parameters!!(vi, initial_params, spl, model)
vi = initialize_parameters!!(vi, initial_params, model)

# Update joint log probability.
# This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588
Expand Down Expand Up @@ -156,9 +156,7 @@
"""
initialsampler(spl::Sampler) = SampleFromPrior()

function set_values!!(
varinfo::AbstractVarInfo, initial_params::AbstractVector, spl::AbstractSampler
)
function set_values!!(varinfo::AbstractVarInfo, initial_params::AbstractVector)
throw(
ArgumentError(
"`initial_params` must be a vector of type `Union{Real,Missing}`. " *
Expand All @@ -168,11 +166,9 @@
end

function set_values!!(
varinfo::AbstractVarInfo,
initial_params::AbstractVector{<:Union{Real,Missing}},
spl::AbstractSampler,
varinfo::AbstractVarInfo, initial_params::AbstractVector{<:Union{Real,Missing}}
)
flattened_param_vals = varinfo[spl]
flattened_param_vals = varinfo[:]
length(flattened_param_vals) == length(initial_params) || throw(
DimensionMismatch(
"Provided initial value size ($(length(initial_params))) doesn't match " *
Expand All @@ -189,12 +185,11 @@
end

# Update in `varinfo`.
return setindex!!(varinfo, flattened_param_vals, spl)
setall!(varinfo, flattened_param_vals)
return varinfo
end

function set_values!!(
varinfo::AbstractVarInfo, initial_params::NamedTuple, spl::AbstractSampler
)
function set_values!!(varinfo::AbstractVarInfo, initial_params::NamedTuple)
vars_in_varinfo = keys(varinfo)
for v in keys(initial_params)
vn = VarName{v}()
Expand All @@ -219,23 +214,21 @@
)
end

function initialize_parameters!!(
vi::AbstractVarInfo, initial_params, spl::AbstractSampler, model::Model
)
function initialize_parameters!!(vi::AbstractVarInfo, initial_params, model::Model)
@debug "Using passed-in initial variable values" initial_params

# `link` the varinfo if needed.
linked = islinked(vi, spl)
linked = islinked(vi)
if linked
vi = invlink!!(vi, spl, model)
vi = invlink!!(vi, model)

Check warning on line 223 in src/sampler.jl

View check run for this annotation

Codecov / codecov/patch

src/sampler.jl#L223

Added line #L223 was not covered by tests
end

# Set the values in `vi`.
vi = set_values!!(vi, initial_params, spl)
vi = set_values!!(vi, initial_params)

# `invlink` if needed.
if linked
vi = link!!(vi, spl, model)
vi = link!!(vi, model)

Check warning on line 231 in src/sampler.jl

View check run for this annotation

Codecov / codecov/patch

src/sampler.jl#L231

Added line #L231 was not covered by tests
end

return vi
Expand Down
13 changes: 2 additions & 11 deletions src/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,6 @@
return last(evaluate!!(model, varinfo, SamplingContext()))
end

unflatten(svi::SimpleVarInfo, spl::AbstractSampler, x::AbstractVector) = unflatten(svi, x)
function unflatten(svi::SimpleVarInfo, x::AbstractVector)
logp = getlogp(svi)
vals = unflatten(svi.values, x)
Expand Down Expand Up @@ -342,10 +341,6 @@
return Accessors.@set vi.values = set!!(vi.values, vn, val)
end

function BangBang.setindex!!(vi::SimpleVarInfo, val, spl::AbstractSampler)
return unflatten(vi, spl, val)
end

# TODO: Specialize to handle certain cases, e.g. a collection of `VarName` with
# same symbol and same type of, say, `IndexLens`, for improved `.~` performance.
function BangBang.setindex!!(vi::SimpleVarInfo, vals, vns::AbstractVector{<:VarName})
Expand Down Expand Up @@ -428,11 +423,7 @@
}

# Necessary for `matchingvalue` to work properly.
function Base.eltype(
vi::SimpleOrThreadSafeSimple{<:Any,V}, spl::Union{AbstractSampler,SampleFromPrior}
) where {V}
return V
end
Base.eltype(::SimpleOrThreadSafeSimple{<:Any,V}) where {V} = V

# `subset`
function subset(varinfo::SimpleVarInfo, vns::AbstractVector{<:VarName})
Expand Down Expand Up @@ -562,7 +553,7 @@
istrans(vi::SimpleVarInfo, vn::VarName) = istrans(vi)
istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = istrans(vi.varinfo, vn)

islinked(vi::SimpleVarInfo, ::Union{Sampler,SampleFromPrior}) = istrans(vi)
islinked(vi::SimpleVarInfo) = istrans(vi)

Check warning on line 556 in src/simple_varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/simple_varinfo.jl#L556

Added line #L556 was not covered by tests

values_as(vi::SimpleVarInfo) = vi.values
values_as(vi::SimpleVarInfo{<:T}, ::Type{T}) where {T} = vi.values
Expand Down
20 changes: 1 addition & 19 deletions src/threadsafe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
keys(vi::ThreadSafeVarInfo) = keys(vi.varinfo)
haskey(vi::ThreadSafeVarInfo, vn::VarName) = haskey(vi.varinfo, vn)

islinked(vi::ThreadSafeVarInfo, spl::AbstractSampler) = islinked(vi.varinfo, spl)
islinked(vi::ThreadSafeVarInfo) = islinked(vi.varinfo)

Check warning on line 82 in src/threadsafe.jl

View check run for this annotation

Codecov / codecov/patch

src/threadsafe.jl#L82

Added line #L82 was not covered by tests

function link!!(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...)
return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, args...)
Expand Down Expand Up @@ -138,17 +138,6 @@
function getindex(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}, dist::Distribution)
return getindex(vi.varinfo, vns, dist)
end
getindex(vi::ThreadSafeVarInfo, spl::AbstractSampler) = getindex(vi.varinfo, spl)

function BangBang.setindex!!(vi::ThreadSafeVarInfo, val, spl::AbstractSampler)
return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, val, spl)
end
function BangBang.setindex!!(vi::ThreadSafeVarInfo, val, spl::SampleFromPrior)
return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, val, spl)
end
function BangBang.setindex!!(vi::ThreadSafeVarInfo, val, spl::SampleFromUniform)
return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, val, spl)
end

function BangBang.setindex!!(vi::ThreadSafeVarInfo, vals, vn::VarName)
return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, vals, vn)
Expand Down Expand Up @@ -184,13 +173,9 @@
return is_flagged(vi.varinfo, vn, flag)
end

# Transformations.
function settrans!!(vi::ThreadSafeVarInfo, trans::Bool, vn::VarName)
return Accessors.@set vi.varinfo = settrans!!(vi.varinfo, trans, vn)
end
function settrans!!(vi::ThreadSafeVarInfo, spl::AbstractSampler, dist::Distribution)
return Accessors.@set vi.varinfo = settrans!!(vi.varinfo, spl, dist)
end

istrans(vi::ThreadSafeVarInfo, vn::VarName) = istrans(vi.varinfo, vn)
istrans(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) = istrans(vi.varinfo, vns)
Expand All @@ -200,9 +185,6 @@
function unflatten(vi::ThreadSafeVarInfo, x::AbstractVector)
return Accessors.@set vi.varinfo = unflatten(vi.varinfo, x)
end
function unflatten(vi::ThreadSafeVarInfo, spl::AbstractSampler, x::AbstractVector)
return Accessors.@set vi.varinfo = unflatten(vi.varinfo, spl, x)
end

function subset(varinfo::ThreadSafeVarInfo, vns::AbstractVector{<:VarName})
return Accessors.@set varinfo.varinfo = subset(varinfo.varinfo, vns)
Expand Down
Loading