diff --git a/HISTORY.md b/HISTORY.md index 03c564b64..6b7247c8d 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -10,6 +10,11 @@ This release removes the feature of `VarInfo` where it kept track of which varia - `link` and `invlink`, and their `!!` versions, no longer accept a sampler as an argument to specify which variables to (inv)link. The `link(varinfo, model)` methods remain in place, and as a new addition one can give a `Tuple` of `VarName`s to (inv)link only select variables, as in `link(varinfo, varname_tuple, model)`. - `set_retained_vns_del_by_spl!` has been replaced by `set_retained_vns_del!` which applies to all variables. + - `getindex`, `setindex!`, and `setindex!!` no longer accept samplers as arguments + - `unflatten` no longer accepts a sampler as an argument + - `eltype(::VarInfo)` no longer accepts a sampler as an argument + - `keys(::VarInfo)` no longer accepts a sampler as an argument + - `VarInfo(::VarInfo, ::Sampler, ::AbstractVector)` no longer accepts the sampler argument. ### Reverse prefixing order diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 26c4268d8..4e9e5c554 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -149,7 +149,6 @@ If `dist` is specified, the value(s) will be massaged into the representation ex """ getindex(vi::AbstractVarInfo, ::Colon) - getindex(vi::AbstractVarInfo, ::AbstractSampler) Return the current value(s) of `vn` (`vns`) in `vi` in the support of its (their) distribution(s) as a flattened `Vector`. 
@@ -159,7 +158,6 @@ The default implementation is to call [`values_as`](@ref) with `Vector` as the t See also: [`getindex(vi::AbstractVarInfo, vn::VarName, dist::Distribution)`](@ref) """ Base.getindex(vi::AbstractVarInfo, ::Colon) = values_as(vi, Vector) -Base.getindex(vi::AbstractVarInfo, ::AbstractSampler) = vi[:] """ getindex_internal(vi::AbstractVarInfo, vn::VarName) @@ -341,9 +339,9 @@ julia> values_as(vi, Vector) function values_as end """ - eltype(vi::AbstractVarInfo, spl::Union{AbstractSampler,SampleFromPrior} + eltype(vi::AbstractVarInfo) -Determine the default `eltype` of the values returned by `vi[spl]`. +Return the `eltype` of the values returned by `vi[:]`. !!! warning This should generally not be called explicitly, as it's only used in @@ -352,13 +350,13 @@ Determine the default `eltype` of the values returned by `vi[spl]`. This method is considered legacy, and is likely to be deprecated in the future. """ -function Base.eltype(vi::AbstractVarInfo, spl::Union{AbstractSampler,SampleFromPrior}) - T = Base.promote_op(getindex, typeof(vi), typeof(spl)) +function Base.eltype(vi::AbstractVarInfo) + T = Base.promote_op(getindex, typeof(vi), Colon) if T === Union{} - # In this case `getindex(vi, spl)` errors + # In this case `getindex(vi, :)` errors # Let us throw a more descriptive error message # Ref https://github.com/TuringLang/Turing.jl/issues/2151 - return eltype(vi[spl]) + return eltype(vi[:]) end return eltype(T) end @@ -720,25 +718,11 @@ end # Utilities """ - unflatten(vi::AbstractVarInfo[, context::AbstractContext], x::AbstractVector) + unflatten(vi::AbstractVarInfo, x::AbstractVector) Return a new instance of `vi` with the values of `x` assigned to the variables. - -If `context` is provided, `x` is assumed to be realizations only for variables not -filtered out by `context`. 
""" -function unflatten(varinfo::AbstractVarInfo, context::AbstractContext, θ) - if hassampler(context) - unflatten(getsampler(context), varinfo, context, θ) - else - DynamicPPL.unflatten(varinfo, θ) - end -end - -# TODO: deprecate this once `sampler` is no longer the main way of filtering out variables. -function unflatten(sampler::AbstractSampler, varinfo::AbstractVarInfo, ::AbstractContext, θ) - return unflatten(varinfo, sampler, θ) -end +function unflatten end """ to_maybe_linked_internal(vi::AbstractVarInfo, vn::VarName, dist, val) diff --git a/src/compiler.jl b/src/compiler.jl index c67da6f95..8743641af 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -3,7 +3,7 @@ const INTERNALNAMES = (:__model__, :__context__, :__varinfo__) """ need_concretize(expr) -Return `true` if `expr` needs to be concretized, i.e., if it contains a colon `:` or +Return `true` if `expr` needs to be concretized, i.e., if it contains a colon `:` or requires a dynamic optic. # Examples @@ -730,19 +730,19 @@ function warn_empty(body) return nothing end +# TODO(mhauru) matchingvalue has methods that can accept both types and values. Why? +# TODO(mhauru) This function needs a more comprehensive docstring. """ - matchingvalue(sampler, vi, value) - matchingvalue(context::AbstractContext, vi, value) + matchingvalue(vi, value) -Convert the `value` to the correct type for the `sampler` or `context` and the `vi` object. - -For a `context` that is _not_ a `SamplingContext`, we fall back to -`matchingvalue(SampleFromPrior(), vi, value)`. +Convert the `value` to the correct type for the `vi` object. """ -function matchingvalue(sampler, vi, value) +function matchingvalue(vi, value) T = typeof(value) if hasmissing(T) - _value = convert(get_matching_type(sampler, vi, T), value) + _value = convert(get_matching_type(vi, T), value) + # TODO(mhauru) Why do we make a deepcopy, even though in the !hasmissing branch we + # are happy to return `value` as-is? 
if _value === value return deepcopy(_value) else @@ -752,45 +752,30 @@ function matchingvalue(sampler, vi, value) return value end end -# If we hit `Type` or `TypeWrap`, we immediately jump to `get_matching_type`. -function matchingvalue(sampler::AbstractSampler, vi, value::FloatOrArrayType) - return get_matching_type(sampler, vi, value) -end -function matchingvalue(sampler::AbstractSampler, vi, value::TypeWrap{T}) where {T} - return TypeWrap{get_matching_type(sampler, vi, T)}() -end -function matchingvalue(context::AbstractContext, vi, value) - return matchingvalue(NodeTrait(matchingvalue, context), context, vi, value) +function matchingvalue(vi, value::FloatOrArrayType) + return get_matching_type(vi, value) end -function matchingvalue(::IsLeaf, context::AbstractContext, vi, value) - return matchingvalue(SampleFromPrior(), vi, value) -end -function matchingvalue(::IsParent, context::AbstractContext, vi, value) - return matchingvalue(childcontext(context), vi, value) -end -function matchingvalue(context::SamplingContext, vi, value) - return matchingvalue(context.sampler, vi, value) +function matchingvalue(vi, ::TypeWrap{T}) where {T} + return TypeWrap{get_matching_type(vi, T)}() end +# TODO(mhauru) This function needs a more comprehensive docstring. What is it for? """ - get_matching_type(spl::AbstractSampler, vi, ::TypeWrap{T}) where {T} - -Get the specialized version of type `T` for sampler `spl`. + get_matching_type(vi, ::TypeWrap{T}) where {T} -For example, if `T === Float64` and `spl::Hamiltonian`, the matching type is -`eltype(vi[spl])`. +Get the specialized version of type `T` for `vi`. 
""" -get_matching_type(spl::AbstractSampler, vi, ::Type{T}) where {T} = T -function get_matching_type(spl::AbstractSampler, vi, ::Type{<:Union{Missing,AbstractFloat}}) - return Union{Missing,float_type_with_fallback(eltype(vi, spl))} +get_matching_type(_, ::Type{T}) where {T} = T +function get_matching_type(vi, ::Type{<:Union{Missing,AbstractFloat}}) + return Union{Missing,float_type_with_fallback(eltype(vi))} end -function get_matching_type(spl::AbstractSampler, vi, ::Type{<:AbstractFloat}) - return float_type_with_fallback(eltype(vi, spl)) +function get_matching_type(vi, ::Type{<:AbstractFloat}) + return float_type_with_fallback(eltype(vi)) end -function get_matching_type(spl::AbstractSampler, vi, ::Type{<:Array{T,N}}) where {T,N} - return Array{get_matching_type(spl, vi, T),N} +function get_matching_type(vi, ::Type{<:Array{T,N}}) where {T,N} + return Array{get_matching_type(vi, T),N} end -function get_matching_type(spl::AbstractSampler, vi, ::Type{<:Array{T}}) where {T} - return Array{get_matching_type(spl, vi, T)} +function get_matching_type(vi, ::Type{<:Array{T}}) where {T} + return Array{get_matching_type(vi, T)} end diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 214369ab0..29f591cc3 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -121,22 +121,17 @@ end getsampler(f::LogDensityFunction) = getsampler(getcontext(f)) hassampler(f::LogDensityFunction) = hassampler(getcontext(f)) -_get_indexer(ctx::AbstractContext) = _get_indexer(NodeTrait(ctx), ctx) -_get_indexer(ctx::SamplingContext) = ctx.sampler -_get_indexer(::IsParent, ctx::AbstractContext) = _get_indexer(childcontext(ctx)) -_get_indexer(::IsLeaf, ctx::AbstractContext) = Colon() - """ getparams(f::LogDensityFunction) Return the parameters of the wrapped varinfo as a vector. 
""" -getparams(f::LogDensityFunction) = f.varinfo[_get_indexer(getcontext(f))] +getparams(f::LogDensityFunction) = f.varinfo[:] # LogDensityProblems interface function LogDensityProblems.logdensity(f::LogDensityFunction, θ::AbstractVector) context = getcontext(f) - vi_new = unflatten(f.varinfo, context, θ) + vi_new = unflatten(f.varinfo, θ) return getlogp(last(evaluate!!(f.model, vi_new, context))) end function LogDensityProblems.capabilities(::Type{<:LogDensityFunction}) diff --git a/src/model.jl b/src/model.jl index 462db7397..3601d77fd 100644 --- a/src/model.jl +++ b/src/model.jl @@ -948,9 +948,9 @@ Return the arguments and keyword arguments to be passed to the evaluator of the ) where {_F,argnames} unwrap_args = [ if is_splat_symbol(var) - :($matchingvalue(context_new, varinfo, model.args.$var)...) + :($matchingvalue(varinfo, model.args.$var)...) else - :($matchingvalue(context_new, varinfo, model.args.$var)) + :($matchingvalue(varinfo, model.args.$var)) end for var in argnames ] diff --git a/src/sampler.jl b/src/sampler.jl index 974828e8b..56cd8404e 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -118,7 +118,7 @@ function AbstractMCMC.step( # Update the parameters if provided. if initial_params !== nothing - vi = initialize_parameters!!(vi, initial_params, spl, model) + vi = initialize_parameters!!(vi, initial_params, model) # Update joint log probability. # This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588 @@ -156,9 +156,7 @@ By default, it returns an instance of [`SampleFromPrior`](@ref). """ initialsampler(spl::Sampler) = SampleFromPrior() -function set_values!!( - varinfo::AbstractVarInfo, initial_params::AbstractVector, spl::AbstractSampler -) +function set_values!!(varinfo::AbstractVarInfo, initial_params::AbstractVector) throw( ArgumentError( "`initial_params` must be a vector of type `Union{Real,Missing}`. 
" * @@ -168,11 +166,9 @@ function set_values!!( end function set_values!!( - varinfo::AbstractVarInfo, - initial_params::AbstractVector{<:Union{Real,Missing}}, - spl::AbstractSampler, + varinfo::AbstractVarInfo, initial_params::AbstractVector{<:Union{Real,Missing}} ) - flattened_param_vals = varinfo[spl] + flattened_param_vals = varinfo[:] length(flattened_param_vals) == length(initial_params) || throw( DimensionMismatch( "Provided initial value size ($(length(initial_params))) doesn't match " * @@ -189,12 +185,11 @@ function set_values!!( end # Update in `varinfo`. - return setindex!!(varinfo, flattened_param_vals, spl) + setall!(varinfo, flattened_param_vals) + return varinfo end -function set_values!!( - varinfo::AbstractVarInfo, initial_params::NamedTuple, spl::AbstractSampler -) +function set_values!!(varinfo::AbstractVarInfo, initial_params::NamedTuple) vars_in_varinfo = keys(varinfo) for v in keys(initial_params) vn = VarName{v}() @@ -219,23 +214,21 @@ function set_values!!( ) end -function initialize_parameters!!( - vi::AbstractVarInfo, initial_params, spl::AbstractSampler, model::Model -) +function initialize_parameters!!(vi::AbstractVarInfo, initial_params, model::Model) @debug "Using passed-in initial variable values" initial_params # `link` the varinfo if needed. - linked = islinked(vi, spl) + linked = islinked(vi) if linked - vi = invlink!!(vi, spl, model) + vi = invlink!!(vi, model) end # Set the values in `vi`. - vi = set_values!!(vi, initial_params, spl) + vi = set_values!!(vi, initial_params) # `invlink` if needed. 
if linked - vi = link!!(vi, spl, model) + vi = link!!(vi, model) end return vi diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 57b167077..07296c3f7 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -258,7 +258,6 @@ function typed_simple_varinfo(model::Model) return last(evaluate!!(model, varinfo, SamplingContext())) end -unflatten(svi::SimpleVarInfo, spl::AbstractSampler, x::AbstractVector) = unflatten(svi, x) function unflatten(svi::SimpleVarInfo, x::AbstractVector) logp = getlogp(svi) vals = unflatten(svi.values, x) @@ -342,10 +341,6 @@ function BangBang.setindex!!(vi::SimpleVarInfo, val, vn::VarName) return Accessors.@set vi.values = set!!(vi.values, vn, val) end -function BangBang.setindex!!(vi::SimpleVarInfo, val, spl::AbstractSampler) - return unflatten(vi, spl, val) -end - # TODO: Specialize to handle certain cases, e.g. a collection of `VarName` with # same symbol and same type of, say, `IndexLens`, for improved `.~` performance. function BangBang.setindex!!(vi::SimpleVarInfo, vals, vns::AbstractVector{<:VarName}) @@ -428,11 +423,7 @@ const SimpleOrThreadSafeSimple{T,V,C} = Union{ } # Necessary for `matchingvalue` to work properly. 
-function Base.eltype( - vi::SimpleOrThreadSafeSimple{<:Any,V}, spl::Union{AbstractSampler,SampleFromPrior} -) where {V} - return V -end +Base.eltype(::SimpleOrThreadSafeSimple{<:Any,V}) where {V} = V # `subset` function subset(varinfo::SimpleVarInfo, vns::AbstractVector{<:VarName}) @@ -562,7 +553,7 @@ istrans(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation) istrans(vi::SimpleVarInfo, vn::VarName) = istrans(vi) istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = istrans(vi.varinfo, vn) -islinked(vi::SimpleVarInfo, ::Union{Sampler,SampleFromPrior}) = istrans(vi) +islinked(vi::SimpleVarInfo) = istrans(vi) values_as(vi::SimpleVarInfo) = vi.values values_as(vi::SimpleVarInfo{<:T}, ::Type{T}) where {T} = vi.values diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 69be5dcb1..4367ff06d 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -79,7 +79,7 @@ setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn) keys(vi::ThreadSafeVarInfo) = keys(vi.varinfo) haskey(vi::ThreadSafeVarInfo, vn::VarName) = haskey(vi.varinfo, vn) -islinked(vi::ThreadSafeVarInfo, spl::AbstractSampler) = islinked(vi.varinfo, spl) +islinked(vi::ThreadSafeVarInfo) = islinked(vi.varinfo) function link!!(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...) return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, args...) 
@@ -138,17 +138,6 @@ end function getindex(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}, dist::Distribution) return getindex(vi.varinfo, vns, dist) end -getindex(vi::ThreadSafeVarInfo, spl::AbstractSampler) = getindex(vi.varinfo, spl) - -function BangBang.setindex!!(vi::ThreadSafeVarInfo, val, spl::AbstractSampler) - return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, val, spl) -end -function BangBang.setindex!!(vi::ThreadSafeVarInfo, val, spl::SampleFromPrior) - return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, val, spl) -end -function BangBang.setindex!!(vi::ThreadSafeVarInfo, val, spl::SampleFromUniform) - return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, val, spl) -end function BangBang.setindex!!(vi::ThreadSafeVarInfo, vals, vn::VarName) return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, vals, vn) @@ -184,13 +173,9 @@ function is_flagged(vi::ThreadSafeVarInfo, vn::VarName, flag::String) return is_flagged(vi.varinfo, vn, flag) end -# Transformations. 
function settrans!!(vi::ThreadSafeVarInfo, trans::Bool, vn::VarName) return Accessors.@set vi.varinfo = settrans!!(vi.varinfo, trans, vn) end -function settrans!!(vi::ThreadSafeVarInfo, spl::AbstractSampler, dist::Distribution) - return Accessors.@set vi.varinfo = settrans!!(vi.varinfo, spl, dist) -end istrans(vi::ThreadSafeVarInfo, vn::VarName) = istrans(vi.varinfo, vn) istrans(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) = istrans(vi.varinfo, vns) @@ -200,9 +185,6 @@ getindex_internal(vi::ThreadSafeVarInfo, vn::VarName) = getindex_internal(vi.var function unflatten(vi::ThreadSafeVarInfo, x::AbstractVector) return Accessors.@set vi.varinfo = unflatten(vi.varinfo, x) end -function unflatten(vi::ThreadSafeVarInfo, spl::AbstractSampler, x::AbstractVector) - return Accessors.@set vi.varinfo = unflatten(vi.varinfo, spl, x) -end function subset(varinfo::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) return Accessors.@set varinfo.varinfo = subset(varinfo.varinfo, vns) diff --git a/src/utils.jl b/src/utils.jl index 2539b7179..d64f6dc66 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -942,9 +942,9 @@ function update_values!!(vi::AbstractVarInfo, vals::NamedTuple, vns) end """ - float_type_with_fallback(x) + float_type_with_fallback(T::Type) -Return type corresponding to `float(typeof(x))` if possible; otherwise return `float(Real)`. +Return `float(T)` if possible; otherwise return `float(Real)`. """ float_type_with_fallback(::Type) = float(Real) float_type_with_fallback(::Type{Union{}}) = float(Real) diff --git a/src/varinfo.jl b/src/varinfo.jl index 09f5960c1..8f7f7b6c1 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -111,10 +111,11 @@ const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{ # NOTE: This is kind of weird, but it effectively preserves the "old" # behavior where we're allowed to call `link!` on the same `VarInfo` # multiple times. 
-transformation(vi::VarInfo) = DynamicTransformation() +transformation(::VarInfo) = DynamicTransformation() -function VarInfo(old_vi::VarInfo, spl, x::AbstractVector) - md = replace_values(old_vi.metadata, Val(getspace(spl)), x) +# TODO(mhauru) Isn't this the same as unflatten and/or replace_values? +function VarInfo(old_vi::VarInfo, x::AbstractVector) + md = replace_values(old_vi.metadata, x) return VarInfo( md, Base.RefValue{eltype(x)}(getlogp(old_vi)), Ref(get_num_produce(old_vi)) ) @@ -217,53 +218,42 @@ vector_length(varinfo::VarInfo) = length(varinfo.metadata) vector_length(varinfo::TypedVarInfo) = sum(length, varinfo.metadata) vector_length(md::Metadata) = sum(length, md.ranges) -unflatten(vi::VarInfo, x::AbstractVector) = unflatten(vi, SampleFromPrior(), x) - -# TODO: deprecate. -function unflatten(vi::VarInfo, spl::AbstractSampler, x::AbstractVector) - md = unflatten(vi.metadata, spl, x) +function unflatten(vi::VarInfo, x::AbstractVector) + md = unflatten_metadata(vi.metadata, x) + # Note that use of RefValue{eltype(x)} rather than Ref is necessary to deal with cases + # where e.g. x is a type gradient of some AD backend. return VarInfo(md, Base.RefValue{eltype(x)}(getlogp(vi)), Ref(get_num_produce(vi))) end -# The Val(getspace(spl)) is used to dispatch into the below generated function. -function unflatten(metadata::NamedTuple, spl::AbstractSampler, x::AbstractVector) - return unflatten(metadata, Val(getspace(spl)), x) -end - -@generated function unflatten( - metadata::NamedTuple{names}, ::Val{space}, x -) where {names,space} +# We would call this `unflatten` if not for `unflatten` having a method for NamedTuples in +# utils.jl. 
+@generated function unflatten_metadata( + metadata::NamedTuple{names}, x::AbstractVector +) where {names} exprs = [] offset = :(0) for f in names mdf = :(metadata.$f) - if inspace(f, space) || length(space) == 0 - len = :(sum(length, $mdf.ranges)) - push!(exprs, :($f = unflatten($mdf, x[($offset + 1):($offset + $len)]))) - offset = :($offset + $len) - else - push!(exprs, :($f = $mdf)) - end + len = :(sum(length, $mdf.ranges)) + push!(exprs, :($f = unflatten_metadata($mdf, x[($offset + 1):($offset + $len)]))) + offset = :($offset + $len) end length(exprs) == 0 && return :(NamedTuple()) return :($(exprs...),) end # For Metadata unflatten and replace_values are the same. For VarNamedVector they are not. -function unflatten(md::Metadata, x::AbstractVector) +function unflatten_metadata(md::Metadata, x::AbstractVector) return replace_values(md, x) end -function unflatten(md::Metadata, spl::AbstractSampler, x::AbstractVector) - return replace_values(md, spl, x) -end + +unflatten_metadata(vnv::VarNamedVector, x::AbstractVector) = unflatten(vnv, x) # without AbstractSampler function VarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) return VarInfo(rng, model, SampleFromPrior(), context) end -# TODO: Remove `space` argument when no longer needed. 
Ref: https://github.com/TuringLang/DynamicPPL.jl/issues/573 -replace_values(metadata::Metadata, space, x) = replace_values(metadata, x) function replace_values(metadata::Metadata, x) return Metadata( metadata.idcs, @@ -277,20 +267,14 @@ function replace_values(metadata::Metadata, x) ) end -@generated function replace_values( - metadata::NamedTuple{names}, ::Val{space}, x -) where {names,space} +@generated function replace_values(metadata::NamedTuple{names}, x) where {names} exprs = [] offset = :(0) for f in names mdf = :(metadata.$f) - if inspace(f, space) || length(space) == 0 - len = :(sum(length, $mdf.ranges)) - push!(exprs, :($f = replace_values($mdf, x[($offset + 1):($offset + $len)]))) - offset = :($offset + $len) - else - push!(exprs, :($f = $mdf)) - end + len = :(sum(length, $mdf.ranges)) + push!(exprs, :($f = replace_values($mdf, x[($offset + 1):($offset + $len)]))) + offset = :($offset + $len) end length(exprs) == 0 && return :(NamedTuple()) return :($(exprs...),) @@ -786,7 +770,7 @@ settrans!!(vi::VarInfo, trans::AbstractTransformation) = settrans!!(vi, true) """ syms(vi::VarInfo) -Returns a tuple of the unique symbols of random variables sampled in `vi`. +Returns a tuple of the unique symbols of random variables in `vi`. 
""" syms(vi::UntypedVarInfo) = Tuple(unique!(map(getsym, vi.metadata.vns))) # get all symbols syms(vi::TypedVarInfo) = keys(vi.metadata) @@ -794,16 +778,6 @@ syms(vi::TypedVarInfo) = keys(vi.metadata) _getidcs(vi::UntypedVarInfo) = 1:length(vi.metadata.idcs) _getidcs(vi::TypedVarInfo) = _getidcs(vi.metadata) -# Get all indices of variables belonging to SampleFromPrior: -# if the gid/selector of a var is an empty Set, then that var is assumed to be assigned to -# the SampleFromPrior sampler -@inline function _getidcs(vi::UntypedVarInfo, ::SampleFromPrior) - return filter(i -> isempty(vi.metadata.gids[i]), 1:length(vi.metadata.gids)) -end -# Get a NamedTuple of all the indices belonging to SampleFromPrior, one for each symbol -@inline function _getidcs(vi::TypedVarInfo, ::SampleFromPrior) - return _getidcs(vi.metadata) -end @generated function _getidcs(metadata::NamedTuple{names}) where {names} exprs = [] for f in names @@ -813,93 +787,15 @@ end return :($(exprs...),) end -# Get all indices of variables belonging to a given sampler -@inline function _getidcs(vi::VarInfo, spl::Sampler) - # NOTE: 0b00 is the sanity flag for - # |\____ getidcs (mask = 0b10) - # \_____ getranges (mask = 0b01) - #if ~haskey(spl.info, :cache_updated) spl.info[:cache_updated] = CACHERESET end - # Checks if cache is valid, i.e. 
no new pushes were made, to return the cached idcs - # Otherwise, it recomputes the idcs and caches it - #if haskey(spl.info, :idcs) && (spl.info[:cache_updated] & CACHEIDCS) > 0 - # spl.info[:idcs] - #else - #spl.info[:cache_updated] = spl.info[:cache_updated] | CACHEIDCS - idcs = _getidcs(vi, spl.selector, Val(getspace(spl))) - #spl.info[:idcs] = idcs - #end - return idcs -end -@inline _getidcs(vi::UntypedVarInfo, s::Selector, space) = findinds(vi.metadata, s, space) -@inline _getidcs(vi::TypedVarInfo, s::Selector, space) = _getidcs(vi.metadata, s, space) -# Get a NamedTuple for all the indices belonging to a given selector for each symbol -@generated function _getidcs( - metadata::NamedTuple{names}, s::Selector, ::Val{space} -) where {names,space} - exprs = [] - # Iterate through each varname in metadata. - for f in names - # If the varname is in the sampler space - # or the sample space is empty (all variables) - # then return the indices for that variable. - if inspace(f, space) || length(space) == 0 - push!(exprs, :($f = findinds(metadata.$f, s, Val($space)))) - end - end - length(exprs) == 0 && return :(NamedTuple()) - return :($(exprs...),) -end -@inline function findinds(f_meta::Metadata, s, ::Val{space}) where {space} - # Get all the idcs of the vns in `space` and that belong to the selector `s` - return filter( - (i) -> - (s in f_meta.gids[i] || isempty(f_meta.gids[i])) && - (isempty(space) || inspace(f_meta.vns[i], space)), - 1:length(f_meta.gids), - ) -end @inline function findinds(f_meta::Metadata) # Get all the idcs of the vns return filter((i) -> isempty(f_meta.gids[i]), 1:length(f_meta.gids)) end -function findinds(vnv::VarNamedVector, ::Selector, ::Val{space}) where {space} - # New Metadata objects are created with an empty list of gids, which is intrepreted as - # all Selectors applying to all variables. We assume the same behavior for - # VarNamedVector, and thus ignore the Selector argument. 
- if space !== () - msg = "VarNamedVector does not support selecting variables based on samplers" - throw(ErrorException(msg)) - else - return findinds(vnv) - end -end - function findinds(vnv::VarNamedVector) return 1:length(vnv.varnames) end -# Get all vns of variables belonging to spl -_getvns(vi::VarInfo, spl::Sampler) = _getvns(vi, spl.selector, Val(getspace(spl))) -function _getvns(vi::VarInfo, spl::Union{SampleFromPrior,SampleFromUniform}) - return _getvns(vi, Selector(), Val(())) -end -function _getvns(vi::UntypedVarInfo, s::Selector, space) - return view(vi.metadata.vns, _getidcs(vi, s, space)) -end -function _getvns(vi::TypedVarInfo, s::Selector, space) - return _getvns(vi.metadata, _getidcs(vi, s, space)) -end -# Get a NamedTuple for all the `vns` of indices `idcs`, one entry for each symbol -@generated function _getvns(metadata, idcs::NamedTuple{names}) where {names} - exprs = [] - for f in names - push!(exprs, :($f = Base.keys(metadata.$f)[idcs.$f])) - end - length(exprs) == 0 && return :(NamedTuple()) - return :($(exprs...),) -end - """ all_varnames_grouped_by_symbol(vi::TypedVarInfo) @@ -916,47 +812,6 @@ all_varnames_grouped_by_symbol(vi::TypedVarInfo) = return expr end -# Get the index (in vals) ranges of all the vns of variables belonging to spl -@inline function _getranges(vi::VarInfo, spl::Sampler) - ## Uncomment the spl.info stuff when it is concretely typed, not Dict{Symbol, Any} - #if ~haskey(spl.info, :cache_updated) spl.info[:cache_updated] = CACHERESET end - #if haskey(spl.info, :ranges) && (spl.info[:cache_updated] & CACHERANGES) > 0 - # spl.info[:ranges] - #else - #spl.info[:cache_updated] = spl.info[:cache_updated] | CACHERANGES - ranges = _getranges(vi, spl.selector, Val(getspace(spl))) - #spl.info[:ranges] = ranges - return ranges - #end -end -# Get the index (in vals) ranges of all the vns of variables belonging to selector `s` in `space` -@inline function _getranges(vi::VarInfo, s::Selector, space) - return _getranges(vi, _getidcs(vi, 
s, space)) -end -@inline function _getranges(vi::VarInfo, idcs::Vector{Int}) - return mapreduce(i -> vi.metadata.ranges[i], vcat, idcs; init=Int[]) -end -@inline _getranges(vi::TypedVarInfo, idcs::NamedTuple) = _getranges(vi.metadata, idcs) - -@generated function _getranges(metadata::NamedTuple, idcs::NamedTuple{names}) where {names} - exprs = [] - for f in names - push!(exprs, :($f = findranges(metadata.$f.ranges, idcs.$f))) - end - length(exprs) == 0 && return :(NamedTuple()) - return :($(exprs...),) -end - -@inline function findranges(f_ranges, f_idcs) - # Old implementation was using `mapreduce` but turned out - # to be type-unstable. - results = Int[] - for i in f_idcs - append!(results, f_ranges[i]) - end - return results -end - # TODO(mhauru) These set_flag! methods return the VarInfo. They should probably be called # set_flag!!. """ @@ -1096,12 +951,6 @@ Base.keys(vi::TypedVarInfo{<:NamedTuple{()}}) = VarName[] return expr end -# FIXME(torfjelde): Don't use `_getvns`. -Base.keys(vi::UntypedVarInfo, spl::AbstractSampler) = _getvns(vi, spl) -function Base.keys(vi::TypedVarInfo, spl::AbstractSampler) - return mapreduce(values, vcat, _getvns(vi, spl)) -end - """ setgid!(vi::VarInfo, gid::Selector, vn::VarName) @@ -1191,7 +1040,6 @@ function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, mode return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, model) end -# X -> R for all variables associated with given sampler function link!!(::DynamicTransformation, vi::VarInfo, vns::VarNameTuple, model::Model) # If we're working with a `VarNamedVector`, we always use immutable. 
has_varnamedvector(vi) && return _link(model, vi, vns) @@ -1297,7 +1145,6 @@ function invlink!!(t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, m return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(t, vi.varinfo, model) end -# R -> X for all variables associated with given sampler function invlink!!(::DynamicTransformation, vi::VarInfo, vns::VarNameTuple, model::Model) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _invlink(model, vi, vns) @@ -1394,16 +1241,6 @@ function _inner_transform!(md::Metadata, vi::VarInfo, vn::VarName, f) return vi end -# HACK: We need `SampleFromPrior` to result in ALL values which are in need -# of a transformation to be transformed. `_getvns` will by default return -# an empty iterable for `SampleFromPrior`, so we need to override it here. -# This is quite hacky, but seems safer than changing the behavior of `_getvns`. -_getvns_link(varinfo::VarInfo, spl::AbstractSampler) = _getvns(varinfo, spl) -_getvns_link(varinfo::VarInfo, spl::SampleFromPrior) = nothing -function _getvns_link(varinfo::TypedVarInfo, spl::SampleFromPrior) - return map(Returns(nothing), varinfo.metadata) -end - function link(::DynamicTransformation, vi::TypedVarInfo, model::Model) return _link(model, vi, all_varnames_grouped_by_symbol(vi)) end @@ -1617,7 +1454,7 @@ function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, targ # Construct the new transformed values, and keep track of their lengths. vals_new = map(vns) do vn # Return early if we're already in constrained space OR if we're not - # supposed to touch this `vn`, e.g. when `vn` does not belong to the current sampler. + # supposed to touch this `vn`. # HACK: if `target_vns` is `nothing`, we ignore the `target_vns` check. 
if !istrans(varinfo, vn) || (target_vns !== nothing && vn ∉ target_vns) return metadata.vals[getrange(metadata, vn)] @@ -1677,30 +1514,26 @@ function _invlink_metadata!!( return metadata end +# TODO(mhauru) The treatment of the case when some variables are linked and others are not +# should be revised. It used to be the case that for UntypedVarInfo `islinked` returned +# whether the first variable was linked. For TypedVarInfo we did an OR over the first +# variables under each symbol. We now more consistently use OR, but I'm not convinced this +# is really the right thing to do. """ - islinked(vi::VarInfo, spl::Union{Sampler, SampleFromPrior}) + islinked(vi::VarInfo) -Check whether `vi` is in the transformed space for a particular sampler `spl`. +Check whether `vi` is in the transformed space. Turing's Hamiltonian samplers use the `link` and `invlink` functions from [Bijectors.jl](https://github.com/TuringLang/Bijectors.jl) to map a constrained variable (for example, one bounded to the space `[0, 1]`) from its constrained space to the set of real numbers. `islinked` checks if the number is in the constrained space or the real space. + +If some, but not all, of the variables in `vi` are linked, this function will return `true`. +This behavior will likely change in the future. """ -function islinked(vi::UntypedVarInfo, spl::Union{Sampler,SampleFromPrior}) - vns = _getvns(vi, spl) - return istrans(vi, vns[1]) -end -function islinked(vi::TypedVarInfo, spl::Union{Sampler,SampleFromPrior}) - vns = _getvns(vi, spl) - return _islinked(vi, vns) -end -@generated function _islinked(vi, vns::NamedTuple{names}) where {names} - out = [] - for f in names - push!(out, :(isempty(vns.$f) ? false : istrans(vi, vns.$f[1]))) - end - return Expr(:||, false, out...) 
+function islinked(vi::VarInfo) + return any(istrans(vi, vn) for vn in keys(vi)) end function nested_setindex_maybe!(vi::UntypedVarInfo, val, vn::VarName) @@ -1788,22 +1621,6 @@ function getindex(vi::VarInfo, vns::Vector{<:VarName}, dist::Distribution) return recombine(dist, vals_linked, length(vns)) end -""" - getindex(vi::VarInfo, spl::Union{SampleFromPrior, Sampler}) - -Return the current value(s) of the random variables sampled by `spl` in `vi`. - -The value(s) may or may not be transformed to Euclidean space. -""" -getindex(vi::UntypedVarInfo, spl::Sampler) = - copy(getindex(vi.metadata.vals, _getranges(vi, spl))) -getindex(vi::VarInfo, spl::Sampler) = copy(getindex_internal(vi, _getranges(vi, spl))) -function getindex(vi::TypedVarInfo, spl::Sampler) - # Gets the ranges as a NamedTuple - ranges = _getranges(vi, spl) - # Calling getfield(ranges, f) gives all the indices in `vals` of the `vn`s with symbol `f` sampled by `spl` in `vi` - return reduce(vcat, _getindex(vi.metadata, ranges)) -end # Recursively builds a tuple of the `vals` of all the symbols @generated function _getindex(metadata, ranges::NamedTuple{names}) where {names} expr = Expr(:tuple) @@ -1828,43 +1645,6 @@ function BangBang.setindex!!(vi::VarInfo, val, vn::VarName) return vi end -""" - setindex!(vi::VarInfo, val, spl::Union{SampleFromPrior, Sampler}) - -Set the current value(s) of the random variables sampled by `spl` in `vi` to `val`. - -The value(s) may or may not be transformed to Euclidean space. 
-""" -setindex!(vi::VarInfo, val, spl::SampleFromPrior) = setall!(vi, val) -setindex!(vi::UntypedVarInfo, val, spl::Sampler) = setval!(vi, val, _getranges(vi, spl)) -function setindex!(vi::TypedVarInfo, val, spl::Sampler) - # Gets a `NamedTuple` mapping each symbol to the indices in the symbol's `vals` field sampled from the sampler `spl` - ranges = _getranges(vi, spl) - _setindex!(vi.metadata, val, ranges) - return nothing -end - -function BangBang.setindex!!(vi::VarInfo, val, spl::AbstractSampler) - setindex!(vi, val, spl) - return vi -end - -# Recursively writes the entries of `val` to the `vals` fields of all the symbols as if they were a contiguous vector. -@generated function _setindex!(metadata, val, ranges::NamedTuple{names}) where {names} - expr = Expr(:block) - offset = :(0) - for f in names - f_vals = :(metadata.$f.vals) - f_range = :(ranges.$f) - start = :($offset + 1) - len = :(length($f_range)) - finish = :($offset + $len) - push!(expr.args, :(@views $f_vals[$f_range] .= val[($start):($finish)])) - offset = :($offset + $len) - end - return expr -end - @inline function findvns(vi, f_vns) if length(f_vns) == 0 throw("Unidentified error, please report this error in an issue.") @@ -1877,7 +1657,7 @@ Base.haskey(metadata::Metadata, vn::VarName) = haskey(metadata.idcs, vn) """ haskey(vi::VarInfo, vn::VarName) -Check whether `vn` has been sampled in `vi`. +Check whether `vn` has a value in `vi`. """ Base.haskey(vi::VarInfo, vn::VarName) = haskey(getmetadata(vi, vn), vn) function Base.haskey(vi::TypedVarInfo, vn::VarName) diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index 7da126321..3b3f0ce42 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -510,12 +510,6 @@ function getindex_internal(vnv::VarNamedVector, ::Colon) end end -# TODO(mhauru): Remove this as soon as possible. Only needed because of the old Gibbs -# sampler. 
-function Base.getindex(vnv::VarNamedVector, spl::AbstractSampler) - throw(ErrorException("Cannot index a VarNamedVector with a sampler.")) -end - function Base.setindex!(vnv::VarNamedVector, val, vn::VarName) if haskey(vnv, vn) return update!(vnv, val, vn) @@ -1077,15 +1071,6 @@ function unflatten(vnv::VarNamedVector, vals::AbstractVector) ) end -# TODO(mhauru) To be removed once the old Gibbs sampler is removed. -function unflatten(vnv::VarNamedVector, spl::AbstractSampler, vals::AbstractVector) - if length(getspace(spl)) > 0 - msg = "Selecting values in a VarNamedVector with a space is not supported." - throw(ArgumentError(msg)) - end - return unflatten(vnv, vals) -end - function Base.merge(left_vnv::VarNamedVector, right_vnv::VarNamedVector) # Return early if possible. isempty(left_vnv) && return deepcopy(right_vnv) diff --git a/test/model.jl b/test/model.jl index e91de4bd2..256ada0ad 100644 --- a/test/model.jl +++ b/test/model.jl @@ -230,8 +230,8 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() for i in 1:10 # Sample with large variations. - r_raw = randn(length(vi[spl])) * 10 - vi[spl] = r_raw + r_raw = randn(length(vi[:])) * 10 + DynamicPPL.setall!(vi, r_raw) @test vi[@varname(m)] == r_raw[1] @test vi[@varname(x)] != r_raw[2] model(vi) diff --git a/test/sampler.jl b/test/sampler.jl index 3b5424671..50111b1fd 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -196,11 +196,11 @@ vi = VarInfo(model) @test_throws ArgumentError DynamicPPL.initialize_parameters!!( - vi, [initial_z, initial_x], DynamicPPL.SampleFromPrior(), model + vi, [initial_z, initial_x], model ) @test_throws ArgumentError DynamicPPL.initialize_parameters!!( - vi, (X=initial_x, Z=initial_z), DynamicPPL.SampleFromPrior(), model + vi, (X=initial_x, Z=initial_z), model ) end end