Skip to content

VariableOrderAccumulator #940

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: breaking
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ get_num_produce
set_num_produce!!
increment_num_produce!!
reset_num_produce!!
setorder!
setorder!!
set_retained_vns_del!
```

Expand All @@ -358,7 +358,7 @@ DynamicPPL provides the following default accumulators.
```@docs
LogPriorAccumulator
LogLikelihoodAccumulator
NumProduceAccumulator
VariableOrderAccumulator
```

### Common API
Expand Down
4 changes: 2 additions & 2 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ export AbstractVarInfo,
AbstractAccumulator,
LogLikelihoodAccumulator,
LogPriorAccumulator,
NumProduceAccumulator,
VariableOrderAccumulator,
push!!,
empty!!,
subset,
Expand All @@ -73,7 +73,7 @@ export AbstractVarInfo,
is_flagged,
set_flag!,
unset_flag!,
setorder!,
setorder!!,
istrans,
link,
link!!,
Expand Down
34 changes: 30 additions & 4 deletions src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,24 @@ function resetlogp!!(vi::AbstractVarInfo)
return vi
end

"""
setorder!!(vi::AbstractVarInfo, vn::VarName, index::Integer)

Set the `order` of `vn` in `vi` to `index`, where `order` is the number of `observe`
statements run before sampling `vn`.
"""
function setorder!!(vi::AbstractVarInfo, vn::VarName, index::Integer)
    # Apply an in-place update to the `:VariableOrder` accumulator of `vi`: store `index`
    # under `vn` in its `order` mapping and hand the same accumulator back.
    return map_accumulator!!(vi, Val(:VariableOrder)) do acc
        acc.order[vn] = index
        return acc
    end
end

"""
getorder(vi::AbstractVarInfo, vn::VarName)

Get the `order` of `vn` in `vi`, where `order` is the number of `observe` statements
run before sampling `vn`.
"""
function getorder(vi::AbstractVarInfo, vn::VarName)
    # Fetch the `:VariableOrder` accumulator and index its `order` mapping with `vn`.
    order_acc = getacc(vi, Val(:VariableOrder))
    return order_acc.order[vn]
end

# Variables and their realizations.
@doc """
keys(vi::AbstractVarInfo)
Expand Down Expand Up @@ -972,29 +990,37 @@ end

Return the `num_produce` of `vi`.
"""
get_num_produce(vi::AbstractVarInfo) = getacc(vi, Val(:NumProduce)).num
get_num_produce(vi::AbstractVarInfo) = getacc(vi, Val(:VariableOrder)).num_produce

"""
set_num_produce!!(vi::AbstractVarInfo, n::Integer)

Set the `num_produce` field of `vi` to `n`.
"""
set_num_produce!!(vi::AbstractVarInfo, n::Int) = setacc!!(vi, NumProduceAccumulator(n))
function set_num_produce!!(vi::AbstractVarInfo, n::Integer)
    # Replace the `num_produce` counter stored in the `:VariableOrder` accumulator of `vi`
    # with `n`, keeping any existing `order` mapping, and return the result of `setacc!!`.
    if hasacc(vi, Val(:VariableOrder))
        old = getacc(vi, Val(:VariableOrder))
        # The accumulator's counter and the values of its `order` dictionary share one
        # integer type parameter, so `n` is converted to the existing counter type.
        # Passing e.g. an `Int32` alongside a `Dict{..., Int64}` would otherwise fail to
        # find a matching constructor.
        acc = VariableOrderAccumulator(oftype(old.num_produce, n), old.order)
    else
        # No accumulator yet: create a fresh one whose integer type follows `n`.
        acc = VariableOrderAccumulator(n)
    end
    return setacc!!(vi, acc)
end

"""
increment_num_produce!!(vi::AbstractVarInfo)

Add 1 to `num_produce` in `vi`.
"""
increment_num_produce!!(vi::AbstractVarInfo) =
map_accumulator!!(increment, vi, Val(:NumProduce))
map_accumulator!!(increment, vi, Val(:VariableOrder))

"""
reset_num_produce!!(vi::AbstractVarInfo)

Reset the value of `num_produce` in `vi` to 0.
"""
reset_num_produce!!(vi::AbstractVarInfo) = map_accumulator!!(zero, vi, Val(:NumProduce))
reset_num_produce!!(vi::AbstractVarInfo) = set_num_produce!!(vi, zero(get_num_produce(vi)))

"""
from_internal_transform(varinfo::AbstractVarInfo, vn::VarName[, dist])
Expand Down
4 changes: 4 additions & 0 deletions src/accumulators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ An accumulator type `T <: AbstractAccumulator` must implement the following meth
- `accumulator_name(acc::T)` or `accumulator_name(::Type{T})`
- `accumulate_observe!!(acc::T, right, left, vn)`
- `accumulate_assume!!(acc::T, val, logjac, vn, right)`
- `Base.copy(acc::T)`

To be able to work with multi-threading, it should also implement:
- `split(acc::T)`
Expand Down Expand Up @@ -136,6 +137,9 @@ function Base.haskey(at::AccumulatorTuple, ::Val{accname}) where {accname}
@inline return haskey(at.nt, accname)
end
Base.keys(at::AccumulatorTuple) = keys(at.nt)
# Two `AccumulatorTuple`s compare with `==` exactly as their wrapped `nt` fields do.
function Base.:(==)(at1::AccumulatorTuple, at2::AccumulatorTuple)
    return at1.nt == at2.nt
end

# Hash the type together with the wrapped `nt` field, mixed into the seed `h`.
function Base.hash(at::AccumulatorTuple, h::UInt)
    return Base.hash((AccumulatorTuple, at.nt), h)
end

# Build a new `AccumulatorTuple` from the result of calling `copy` on each entry of `nt`.
function Base.copy(at::AccumulatorTuple)
    return AccumulatorTuple(map(copy, at.nt))
end

function Base.convert(::Type{AccumulatorTuple{N,T}}, accs::AccumulatorTuple{N}) where {N,T}
return AccumulatorTuple(convert(T, accs.nt))
Expand Down
1 change: 0 additions & 1 deletion src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,6 @@ function assume(
f = to_maybe_linked_internal_transform(vi, vn, dist)
# TODO(mhauru) This should probably be a call to a function called setindex_internal!
vi = BangBang.setindex!!(vi, f(r), vn)
setorder!(vi, vn, get_num_produce(vi))
else
# Otherwise we just extract it.
r = vi[vn, dist]
Expand Down
106 changes: 82 additions & 24 deletions src/default_accumulators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,52 +41,102 @@ LogLikelihoodAccumulator{T}() where {T<:Real} = LogLikelihoodAccumulator(zero(T)
LogLikelihoodAccumulator() = LogLikelihoodAccumulator{LogProbType}()

"""
NumProduceAccumulator{T} <: AbstractAccumulator
VariableOrderAccumulator{T} <: AbstractAccumulator

An accumulator that tracks the number of observations during model execution.
An accumulator that tracks the order of variables in a `VarInfo`.

This doesn't track the full ordering, but rather how many observations have taken place
before the assume statement for each variable. This is needed for particle methods, where
the model is segmented into parts by each observation, and we need to know which part each
assume statement is in.

# Fields
$(TYPEDFIELDS)
"""
struct NumProduceAccumulator{T<:Integer} <: AbstractAccumulator
struct VariableOrderAccumulator{Eltype<:Integer,VNType<:VarName} <: AbstractAccumulator
"the number of observations"
num::T
num_produce::Eltype
"mapping of variable names to their order in the model"
order::Dict{VNType,Eltype}
end

"""
NumProduceAccumulator{T<:Integer}()
VariableOrderAccumulator{T<:Integer}(n=zero(T))

Create a new `NumProduceAccumulator` accumulator with the number of observations initialized to zero.
Create a new `VariableOrderAccumulator` accumulator with the number of observations set to `n`.
"""
NumProduceAccumulator{T}() where {T<:Integer} = NumProduceAccumulator(zero(T))
NumProduceAccumulator() = NumProduceAccumulator{Int}()
VariableOrderAccumulator{T}(n=zero(T)) where {T<:Integer} =
VariableOrderAccumulator(convert(T, n), Dict{VarName,T}())
VariableOrderAccumulator(n) = VariableOrderAccumulator{typeof(n)}(n)
VariableOrderAccumulator() = VariableOrderAccumulator{Int}()

Base.copy(acc::LogPriorAccumulator) = acc
Base.copy(acc::LogLikelihoodAccumulator) = acc
function Base.copy(acc::VariableOrderAccumulator)
    # The `order` dictionary is mutable, so the copy gets its own (shallow) copy of it;
    # the `num_produce` counter is passed through unchanged.
    order_copy = copy(acc.order)
    return VariableOrderAccumulator(acc.num_produce, order_copy)
end

function Base.show(io::IO, acc::LogPriorAccumulator)
return print(io, "LogPriorAccumulator($(repr(acc.logp)))")
end
function Base.show(io::IO, acc::LogLikelihoodAccumulator)
return print(io, "LogLikelihoodAccumulator($(repr(acc.logp)))")
end
function Base.show(io::IO, acc::NumProduceAccumulator)
return print(io, "NumProduceAccumulator($(repr(acc.num)))")
function Base.show(io::IO, acc::VariableOrderAccumulator)
return print(
io, "VariableOrderAccumulator($(repr(acc.num_produce)), $(repr(acc.order)))"
)
end

# Note that == and isequal are different, and equality under the latter should imply
# equality of hashes. Both of the below implementations are also different from the default
# implementation for structs.
Base.:(==)(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator) = acc1.logp == acc2.logp
function Base.:(==)(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator)
return acc1.logp == acc2.logp
end
function Base.:(==)(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumulator)
    # Compare the counters first; only when they match does the `order` comparison decide.
    acc1.num_produce == acc2.num_produce || return false
    return acc1.order == acc2.order
end

function Base.isequal(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator)
return isequal(acc1.logp, acc2.logp)
end
function Base.isequal(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator)
return isequal(acc1.logp, acc2.logp)
end
function Base.isequal(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumulator)
    # Field-wise `isequal`: counters first, then the `order` mappings.
    isequal(acc1.num_produce, acc2.num_produce) || return false
    return isequal(acc1.order, acc2.order)
end

Base.hash(acc::LogPriorAccumulator, h::UInt) = hash((LogPriorAccumulator, acc.logp), h)
function Base.hash(acc::LogLikelihoodAccumulator, h::UInt)
return hash((LogLikelihoodAccumulator, acc.logp), h)
end
function Base.hash(acc::VariableOrderAccumulator, h::UInt)
    # Hash the type together with both fields, mixed into the seed `h`.
    fingerprint = (VariableOrderAccumulator, acc.num_produce, acc.order)
    return hash(fingerprint, h)
end

accumulator_name(::Type{<:LogPriorAccumulator}) = :LogPrior
accumulator_name(::Type{<:LogLikelihoodAccumulator}) = :LogLikelihood
accumulator_name(::Type{<:NumProduceAccumulator}) = :NumProduce
accumulator_name(::Type{<:VariableOrderAccumulator}) = :VariableOrder

split(::LogPriorAccumulator{T}) where {T} = LogPriorAccumulator(zero(T))
split(::LogLikelihoodAccumulator{T}) where {T} = LogLikelihoodAccumulator(zero(T))
split(acc::NumProduceAccumulator) = acc
split(acc::VariableOrderAccumulator) = acc

function combine(acc::LogPriorAccumulator, acc2::LogPriorAccumulator)
return LogPriorAccumulator(acc.logp + acc2.logp)
end
function combine(acc::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator)
return LogLikelihoodAccumulator(acc.logp + acc2.logp)
end
function combine(acc::NumProduceAccumulator, acc2::NumProduceAccumulator)
return NumProduceAccumulator(max(acc.num, acc2.num))
function combine(acc::VariableOrderAccumulator, acc2::VariableOrderAccumulator)
    # Assumptions are not allowed within parallelised blocks, so the two `order`
    # dictionaries should be identical; merging them is therefore expected to be a no-op.
    num_produce = max(acc.num_produce, acc2.num_produce)
    order = merge(acc.order, acc2.order)
    return VariableOrderAccumulator(num_produce, order)
end

function Base.:+(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator)
Expand All @@ -95,11 +145,12 @@ end
function Base.:+(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator)
return LogLikelihoodAccumulator(acc1.logp + acc2.logp)
end
increment(acc::NumProduceAccumulator) = NumProduceAccumulator(acc.num + oneunit(acc.num))
function increment(acc::VariableOrderAccumulator)
    # Return a new accumulator whose counter is one unit larger; the `order` dictionary
    # is passed along as-is (not copied).
    bumped = acc.num_produce + oneunit(acc.num_produce)
    return VariableOrderAccumulator(bumped, acc.order)
end

Base.zero(acc::LogPriorAccumulator) = LogPriorAccumulator(zero(acc.logp))
Base.zero(acc::LogLikelihoodAccumulator) = LogLikelihoodAccumulator(zero(acc.logp))
Base.zero(acc::NumProduceAccumulator) = NumProduceAccumulator(zero(acc.num))

function accumulate_assume!!(acc::LogPriorAccumulator, val, logjac, vn, right)
return acc + LogPriorAccumulator(logpdf(right, val) + logjac)
Expand All @@ -114,8 +165,11 @@ function accumulate_observe!!(acc::LogLikelihoodAccumulator, right, left, vn)
return acc + LogLikelihoodAccumulator(Distributions.loglikelihood(right, left))
end

accumulate_assume!!(acc::NumProduceAccumulator, val, logjac, vn, right) = acc
accumulate_observe!!(acc::NumProduceAccumulator, right, left, vn) = increment(acc)
function accumulate_assume!!(acc::VariableOrderAccumulator, val, logjac, vn, right)
    # Record the current value of the counter as the order of `vn`. The dictionary is
    # updated in place and the same accumulator is returned; `val`, `logjac` and `right`
    # are not used.
    setindex!(acc.order, acc.num_produce, vn)
    return acc
end
accumulate_observe!!(acc::VariableOrderAccumulator, right, left, vn) = increment(acc)

function Base.convert(::Type{LogPriorAccumulator{T}}, acc::LogPriorAccumulator) where {T}
return LogPriorAccumulator(convert(T, acc.logp))
Expand All @@ -126,15 +180,19 @@ function Base.convert(
return LogLikelihoodAccumulator(convert(T, acc.logp))
end
function Base.convert(
::Type{NumProduceAccumulator{T}}, acc::NumProduceAccumulator
) where {T}
return NumProduceAccumulator(convert(T, acc.num))
::Type{VariableOrderAccumulator{ElType,VnType}}, acc::VariableOrderAccumulator
) where {ElType,VnType}
order = Dict{VnType,ElType}()
for (k, v) in acc.order
order[convert(VnType, k)] = convert(ElType, v)
end
return VariableOrderAccumulator(convert(ElType, acc.num_produce), order)
end

# TODO(mhauru)
# We ignore the convert_eltype calls for NumProduceAccumulator, by letting them fallback on
# We ignore the convert_eltype calls for VariableOrderAccumulator, by letting them fall back on
# convert_eltype(::AbstractAccumulator, ::Type). This is because they are only used to
# deal with dual number types of AD backends, which shouldn't concern NumProduceAccumulator. This is
# deal with dual number types of AD backends, which shouldn't concern VariableOrderAccumulator. This is
# horribly hacky and should be fixed. See also comment in `unflatten` in `src/varinfo.jl`.
function convert_eltype(::Type{T}, acc::LogPriorAccumulator) where {T}
return LogPriorAccumulator(convert(T, acc.logp))
Expand All @@ -149,6 +207,6 @@ function default_accumulators(
return AccumulatorTuple(
LogPriorAccumulator{FloatT}(),
LogLikelihoodAccumulator{FloatT}(),
NumProduceAccumulator{IntT}(),
VariableOrderAccumulator{IntT}(),
)
end
4 changes: 4 additions & 0 deletions src/extract_priors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ end

PriorDistributionAccumulator() = PriorDistributionAccumulator(OrderedDict{VarName,Any}())

function Base.copy(acc::PriorDistributionAccumulator)
    # Wrap a (shallow) copy of the `priors` container in a new accumulator.
    priors_copy = copy(acc.priors)
    return PriorDistributionAccumulator(priors_copy)
end

accumulator_name(::PriorDistributionAccumulator) = :PriorDistributionAccumulator

split(acc::PriorDistributionAccumulator) = PriorDistributionAccumulator(empty(acc.priors))
Expand Down
4 changes: 4 additions & 0 deletions src/pointwise_logdensities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ function PointwiseLogProbAccumulator{whichlogprob,KeyType}() where {whichlogprob
return PointwiseLogProbAccumulator{whichlogprob,KeyType,typeof(logps)}(logps)
end

function Base.copy(acc::PointwiseLogProbAccumulator{W}) where {W}
    # Keep the first type parameter of the accumulator and wrap a (shallow) copy of its
    # `logps` container.
    logps_copy = copy(acc.logps)
    return PointwiseLogProbAccumulator{W}(logps_copy)
end

function Base.push!(acc::PointwiseLogProbAccumulator, vn, logp)
logps = acc.logps
# The last(fieldtypes(eltype(...))) gets the type of the values, rather than the keys.
Expand Down
14 changes: 10 additions & 4 deletions src/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,15 +125,15 @@ Evaluation in transformed space of course also works:

```jldoctest simplevarinfo-general
julia> vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), true)
Transformed SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), NumProduce = NumProduceAccumulator(0)))
Transformed SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), VariableOrder = VariableOrderAccumulator(0, Dict{VarName, Int64}())))

julia> # (✓) Positive probability mass on negative numbers!
getlogjoint(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext())))
-1.3678794411714423

julia> # While if we forget to indicate that it's transformed:
vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), false)
SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), NumProduce = NumProduceAccumulator(0)))
SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), VariableOrder = VariableOrderAccumulator(0, Dict{VarName, Int64}())))

julia> # (✓) No probability mass on negative numbers!
getlogjoint(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext())))
Expand Down Expand Up @@ -198,6 +198,12 @@ struct SimpleVarInfo{NT,Accs<:AccumulatorTuple where {N},C<:AbstractTransformati
transformation::C
end

function Base.:(==)(vi1::SimpleVarInfo, vi2::SimpleVarInfo)
    # Field-by-field comparison, short-circuiting on the first mismatch: values, then
    # accumulators, then the transformation.
    # NOTE(review): no matching `Base.hash` method for `SimpleVarInfo` is visible in this
    # hunk — confirm whether one exists elsewhere.
    vi1.values == vi2.values || return false
    vi1.accs == vi2.accs || return false
    return vi1.transformation == vi2.transformation
end

transformation(vi::SimpleVarInfo) = vi.transformation

function SimpleVarInfo(values, accs)
Expand Down Expand Up @@ -242,7 +248,7 @@ end
# Constructor from `VarInfo`.
function SimpleVarInfo(vi::NTVarInfo, ::Type{D}) where {D}
values = values_as(vi, D)
return SimpleVarInfo(values, deepcopy(getaccs(vi)))
return SimpleVarInfo(values, copy(getaccs(vi)))
end
function SimpleVarInfo{T}(vi::NTVarInfo, ::Type{D}) where {T<:Real,D}
values = values_as(vi, D)
Expand Down Expand Up @@ -441,7 +447,7 @@ _subset(x::VarNamedVector, vns) = subset(x, vns)
# `merge`
function Base.merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo)
values = merge(varinfo_left.values, varinfo_right.values)
accs = deepcopy(getaccs(varinfo_right))
accs = copy(getaccs(varinfo_right))
transformation = merge_transformations(
varinfo_left.transformation, varinfo_right.transformation
)
Expand Down
4 changes: 3 additions & 1 deletion src/threadsafe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,9 @@ end

syms(vi::ThreadSafeVarInfo) = syms(vi.varinfo)

setorder!(vi::ThreadSafeVarInfo, vn::VarName, index::Int) = setorder!(vi.varinfo, vn, index)
function setorder!!(vi::ThreadSafeVarInfo, vn::VarName, index::Integer)
    # Forward to the wrapped `varinfo` and rewrap the result together with the unchanged
    # `accs_by_thread`. `index` is typed `Integer` to match the generic `AbstractVarInfo`
    # method; with `Int`, any other integer type would bypass this wrapper and dispatch
    # to the generic method instead.
    inner = setorder!!(vi.varinfo, vn, index)
    return ThreadSafeVarInfo(inner, vi.accs_by_thread)
end
setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn)

keys(vi::ThreadSafeVarInfo) = keys(vi.varinfo)
Expand Down
Loading
Loading