Skip to content

Accumulators stage 2 #925

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jul 18, 2025
Merged
2 changes: 1 addition & 1 deletion benchmarks/src/DynamicPPLBenchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::
vi = DynamicPPL.link(vi, model)
end

f = DynamicPPL.LogDensityFunction(model, vi; adtype=adbackend)
f = DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint, vi; adtype=adbackend)
# The parameters at which we evaluate f.
θ = vi[:]

Expand Down
141 changes: 91 additions & 50 deletions src/logdensityfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ is_supported(::ADTypes.AutoReverseDiff) = true
"""
LogDensityFunction(
model::Model,
varinfo::AbstractVarInfo=VarInfo(model);
getlogdensity::Function=getlogjoint,
varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity);
adtype::Union{ADTypes.AbstractADType,Nothing}=nothing
)

Expand All @@ -28,9 +29,10 @@ A struct which contains a model, along with all the information necessary to:
- and if `adtype` is provided, calculate the gradient of the log density at
that point.

At its most basic level, a LogDensityFunction wraps the model together with the
type of varinfo to be used. These must be known in order to calculate the log
density (using [`DynamicPPL.evaluate!!`](@ref)).
At its most basic level, a LogDensityFunction wraps the model together with a
function that specifies how to extract the log density, and the type of
VarInfo to be used. These must be known in order to calculate the log density
(using [`DynamicPPL.evaluate!!`](@ref)).

If the `adtype` keyword argument is provided, then this struct will also store
the adtype along with other information for efficient calculation of the
Expand Down Expand Up @@ -72,13 +74,13 @@ julia> LogDensityProblems.dimension(f)
1

julia> # By default it uses `VarInfo` under the hood, but this is not necessary.
f = LogDensityFunction(model, SimpleVarInfo(model));
f = LogDensityFunction(model, getlogjoint, SimpleVarInfo(model));

julia> LogDensityProblems.logdensity(f, [0.0])
-2.3378770664093453

julia> # LogDensityFunction respects the accumulators in VarInfo:
f_prior = LogDensityFunction(model, setaccs!!(VarInfo(model), (LogPriorAccumulator(),)));
julia> # One can also specify evaluating e.g. the log prior only:
f_prior = LogDensityFunction(model, getlogprior);

julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0)
true
Expand All @@ -93,11 +95,13 @@ julia> LogDensityProblems.logdensity_and_gradient(f, [0.0])
```
"""
struct LogDensityFunction{
M<:Model,V<:AbstractVarInfo,AD<:Union{Nothing,ADTypes.AbstractADType}
M<:Model,F<:Function,V<:AbstractVarInfo,AD<:Union{Nothing,ADTypes.AbstractADType}
} <: AbstractModel
"model used for evaluation"
model::M
"varinfo used for evaluation"
"function to be called on `varinfo` to extract the log density. By default `getlogjoint`."
getlogdensity::F
"varinfo used for evaluation. If not specified, generated with `ldf_default_varinfo`."
varinfo::V
"AD type used for evaluation of log density gradient. If `nothing`, no gradient can be calculated"
adtype::AD
Expand All @@ -106,7 +110,8 @@ struct LogDensityFunction{

function LogDensityFunction(
model::Model,
varinfo::AbstractVarInfo=VarInfo(model);
getlogdensity::Function=getlogjoint,
varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity);
adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
)
if adtype === nothing
Expand All @@ -120,15 +125,22 @@ struct LogDensityFunction{
# Get a set of dummy params to use for prep
x = map(identity, varinfo[:])
if use_closure(adtype)
prep = DI.prepare_gradient(LogDensityAt(model, varinfo), adtype, x)
prep = DI.prepare_gradient(
LogDensityAt(model, getlogdensity, varinfo), adtype, x
)
else
prep = DI.prepare_gradient(
logdensity_at, adtype, x, DI.Constant(model), DI.Constant(varinfo)
logdensity_at,
adtype,
x,
DI.Constant(model),
DI.Constant(getlogdensity),
DI.Constant(varinfo),
)
end
end
return new{typeof(model),typeof(varinfo),typeof(adtype)}(
model, varinfo, adtype, prep
return new{typeof(model),typeof(getlogdensity),typeof(varinfo),typeof(adtype)}(
model, getlogdensity, varinfo, adtype, prep
)
end
end
Expand All @@ -149,83 +161,112 @@ function LogDensityFunction(
return if adtype === f.adtype
f # Avoid recomputing prep if not needed
else
LogDensityFunction(f.model, f.varinfo; adtype=adtype)
LogDensityFunction(f.model, f.getlogdensity, f.varinfo; adtype=adtype)
end
end

"""
    ldf_default_varinfo(model::Model, getlogdensity::Function)

Create the default AbstractVarInfo that should be used for evaluating the log density.

Only the accumulators necessary for `getlogdensity` will be used.
"""
function ldf_default_varinfo(::Model, getlogdensity::Function)
    # Fallback method: for an arbitrary log-density extractor we cannot know which
    # accumulators the VarInfo needs, so fail loudly rather than risk evaluating
    # the wrong quantity. Specific methods below cover the known extractors.
    msg = """
        LogDensityFunction does not know what sort of VarInfo should be used when \
        `getlogdensity` is $getlogdensity. Please specify a VarInfo explicitly.
        """
    return error(msg)
end

# For `getlogjoint`, a freshly constructed `VarInfo` is used as-is (it carries the
# accumulators needed for both the log prior and the log likelihood).
function ldf_default_varinfo(model::Model, ::typeof(getlogjoint))
    return VarInfo(model)
end

# For `getlogprior`, replace the VarInfo's accumulators with only a
# `LogPriorAccumulator`, so that likelihood terms are never accumulated.
ldf_default_varinfo(model::Model, ::typeof(getlogprior)) =
    setaccs!!(VarInfo(model), (LogPriorAccumulator(),))

# For `getloglikelihood`, replace the VarInfo's accumulators with only a
# `LogLikelihoodAccumulator`, so that prior terms are never accumulated.
ldf_default_varinfo(model::Model, ::typeof(getloglikelihood)) =
    setaccs!!(VarInfo(model), (LogLikelihoodAccumulator(),))

"""
logdensity_at(
x::AbstractVector,
model::Model,
getlogdensity::Function,
varinfo::AbstractVarInfo,
)

Evaluate the log density of the given `model` at the given parameter values `x`,
using the given `varinfo`. Note that the `varinfo` argument is provided only
for its structure, in the sense that the parameters from the vector `x` are
inserted into it, and its own parameters are discarded. It does, however,
determine whether the log prior, likelihood, or joint is returned, based on
which accumulators are set in it.
Evaluate the log density of the given `model` at the given parameter values
`x`, using the given `varinfo`. Note that the `varinfo` argument is provided
only for its structure, in the sense that the parameters from the vector `x`
are inserted into it, and its own parameters are discarded. `getlogdensity` is
the function that extracts the log density from the evaluated varinfo.
"""
function logdensity_at(x::AbstractVector, model::Model, varinfo::AbstractVarInfo)
function logdensity_at(
x::AbstractVector, model::Model, getlogdensity::Function, varinfo::AbstractVarInfo
)
varinfo_new = unflatten(varinfo, x)
varinfo_eval = last(evaluate!!(model, varinfo_new))
has_prior = hasacc(varinfo_eval, Val(:LogPrior))
has_likelihood = hasacc(varinfo_eval, Val(:LogLikelihood))
if has_prior && has_likelihood
return getlogjoint(varinfo_eval)
elseif has_prior
return getlogprior(varinfo_eval)
elseif has_likelihood
return getloglikelihood(varinfo_eval)
else
error("LogDensityFunction: varinfo tracks neither log prior nor log likelihood")
end
return getlogdensity(varinfo_eval)
end

"""
LogDensityAt{M<:Model,V<:AbstractVarInfo}(
LogDensityAt{M<:Model,F<:Function,V<:AbstractVarInfo}(
model::M
getlogdensity::F,
varinfo::V
)

A callable struct that serves the same purpose as `x -> logdensity_at(x, model,
varinfo)`.
getlogdensity, varinfo)`.
"""
struct LogDensityAt{M<:Model,V<:AbstractVarInfo}
struct LogDensityAt{M<:Model,F<:Function,V<:AbstractVarInfo}
    "model used for evaluation"
    model::M
    "function called on the evaluated varinfo to extract the log density"
    getlogdensity::F
    "varinfo used for evaluation (provides structure; its parameter values are discarded)"
    varinfo::V
end
(ld::LogDensityAt)(x::AbstractVector) = logdensity_at(x, ld.model, ld.varinfo)
"""
    (ld::LogDensityAt)(x::AbstractVector)

Evaluate the wrapped model's log density at the parameter vector `x`, delegating
to [`logdensity_at`](@ref) with the stored model, extractor, and varinfo.
"""
function (density::LogDensityAt)(x::AbstractVector)
    return logdensity_at(x, density.model, density.getlogdensity, density.varinfo)
end

### LogDensityProblems interface

function LogDensityProblems.capabilities(
::Type{<:LogDensityFunction{M,V,Nothing}}
) where {M,V}
::Type{<:LogDensityFunction{M,F,V,Nothing}}
) where {M,F,V}
return LogDensityProblems.LogDensityOrder{0}()
end
function LogDensityProblems.capabilities(
::Type{<:LogDensityFunction{M,V,AD}}
) where {M,V,AD<:ADTypes.AbstractADType}
::Type{<:LogDensityFunction{M,F,V,AD}}
) where {M,F,V,AD<:ADTypes.AbstractADType}
return LogDensityProblems.LogDensityOrder{1}()
end
function LogDensityProblems.logdensity(f::LogDensityFunction, x::AbstractVector)
return logdensity_at(x, f.model, f.varinfo)
return logdensity_at(x, f.model, f.getlogdensity, f.varinfo)
end
function LogDensityProblems.logdensity_and_gradient(
f::LogDensityFunction{M,V,AD}, x::AbstractVector
) where {M,V,AD<:ADTypes.AbstractADType}
f::LogDensityFunction{M,F,V,AD}, x::AbstractVector
) where {M,F,V,AD<:ADTypes.AbstractADType}
f.prep === nothing &&
error("Gradient preparation not available; this should not happen")
x = map(identity, x) # Concretise type
# Make branching statically inferrable, i.e. type-stable (even if the two
# branches happen to return different types)
return if use_closure(f.adtype)
DI.value_and_gradient(LogDensityAt(f.model, f.varinfo), f.prep, f.adtype, x)
DI.value_and_gradient(
LogDensityAt(f.model, f.getlogdensity, f.varinfo), f.prep, f.adtype, x
)
else
DI.value_and_gradient(
logdensity_at, f.prep, f.adtype, x, DI.Constant(f.model), DI.Constant(f.varinfo)
logdensity_at,
f.prep,
f.adtype,
x,
DI.Constant(f.model),
DI.Constant(f.getlogdensity),
DI.Constant(f.varinfo),
)
end
end
Expand Down Expand Up @@ -264,9 +305,9 @@ There are two ways of dealing with this:

1. Construct a closure over the model, i.e. let g = Base.Fix1(logdensity, f)

2. Use a constant context. This lets us pass a two-argument function to
DifferentiationInterface, as long as we also give it the 'inactive argument'
(i.e. the model) wrapped in `DI.Constant`.
2. Use a constant DI.Context. This lets us pass a two-argument function to DI,
as long as we also give it the 'inactive argument' (i.e. the model) wrapped
in `DI.Constant`.

The relative performance of the two approaches, however, depends on the AD
backend used. Some benchmarks are provided here:
Expand All @@ -292,7 +333,7 @@ getmodel(f::DynamicPPL.LogDensityFunction) = f.model
Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.
"""
function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
return LogDensityFunction(model, f.varinfo; adtype=f.adtype)
return LogDensityFunction(model, f.getlogdensity, f.varinfo; adtype=f.adtype)
end

"""
Expand Down
8 changes: 6 additions & 2 deletions src/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,9 @@ function link!!(
x = vi.values
y, logjac = with_logabsdet_jacobian(b, x)
vi_new = Accessors.@set(vi.values = y)
vi_new = acclogprior!!(vi_new, -logjac)
if hasacc(vi_new, Val(:LogPrior))
vi_new = acclogprior!!(vi_new, -logjac)
end
return settrans!!(vi_new, t)
end

Expand All @@ -632,7 +634,9 @@ function invlink!!(
y = vi.values
x, logjac = with_logabsdet_jacobian(b, y)
vi_new = Accessors.@set(vi.values = x)
vi_new = acclogprior!!(vi_new, logjac)
if hasacc(vi_new, Val(:LogPrior))
vi_new = acclogprior!!(vi_new, logjac)
end
return settrans!!(vi_new, NoTransformation())
end

Expand Down
13 changes: 10 additions & 3 deletions src/test_utils/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using ADTypes: AbstractADType, AutoForwardDiff
using Chairmarks: @be
import DifferentiationInterface as DI
using DocStringExtensions
using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo, link
using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo, getlogjoint, link
using LogDensityProblems: logdensity, logdensity_and_gradient
using Random: AbstractRNG, default_rng
using Statistics: median
Expand Down Expand Up @@ -88,6 +88,8 @@ $(TYPEDFIELDS)
struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat,Ttol<:AbstractFloat}
"The DynamicPPL model that was tested"
model::Model
"The function used to extract the log density from the model"
getlogdensity::Function
"The VarInfo that was used"
varinfo::AbstractVarInfo
"The values at which the model was evaluated"
Expand Down Expand Up @@ -222,6 +224,7 @@ function run_ad(
benchmark::Bool=false,
atol::AbstractFloat=100 * eps(),
rtol::AbstractFloat=sqrt(eps()),
getlogdensity::Function=getlogjoint,
rng::AbstractRNG=default_rng(),
varinfo::AbstractVarInfo=link(VarInfo(rng, model), model),
params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
Expand All @@ -241,7 +244,8 @@ function run_ad(
# Calculate log-density and gradient with the backend of interest
verbose && @info "Running AD on $(model.f) with $(adtype)\n"
verbose && println(" params : $(params)")
ldf = LogDensityFunction(model, varinfo; adtype=adtype)
ldf = LogDensityFunction(model, getlogdensity, varinfo; adtype=adtype)

value, grad = logdensity_and_gradient(ldf, params)
# collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754
grad = collect(grad)
Expand All @@ -257,7 +261,9 @@ function run_ad(
value_true = test.value
grad_true = test.grad
elseif test isa WithBackend
ldf_reference = LogDensityFunction(model, varinfo; adtype=test.adtype)
ldf_reference = LogDensityFunction(
model, getlogdensity, varinfo; adtype=test.adtype
)
value_true, grad_true = logdensity_and_gradient(ldf_reference, params)
# collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754
grad_true = collect(grad_true)
Expand All @@ -282,6 +288,7 @@ function run_ad(

return ADResult(
model,
getlogdensity,
varinfo,
params,
adtype,
Expand Down
Loading
Loading