diff --git a/HISTORY.md b/HISTORY.md index 55010c533..9c85674c3 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,17 @@ # DynamicPPL Changelog +## 0.37.4 + +An extension for MarginalLogDensities.jl has been added. + +Loading DynamicPPL and MarginalLogDensities now provides the `DynamicPPL.marginalize` function to marginalise out variables from a model. +This is useful for averaging out random effects or nuisance parameters while improving inference on fixed effects/parameters of interest. +The `marginalize` function returns a `MarginalLogDensities.MarginalLogDensity`, a function-like callable struct that returns the approximate log-density of a subset of the parameters after integrating out the rest of them. +By default, this uses the Laplace approximation and sparse AD, making the marginalisation computationally very efficient. +Note that the Laplace approximation relies on the model being differentiable with respect to the marginalised variables, and that their posteriors are unimodal and approximately Gaussian. + +Please see [the MarginalLogDensities documentation](https://eloceanografo.github.io/MarginalLogDensities.jl/stable) and the [new Marginalisation section of the DynamicPPL documentation](https://turinglang.org/DynamicPPL.jl/v0.37/api/#Marginalisation) for further information. + ## 0.37.3 Prevents inlining of `DynamicPPL.istrans` with Enzyme, which allows Enzyme to differentiate models where `VarName`s have the same symbol but different types. diff --git a/Project.toml b/Project.toml index 024aef5c3..51028b831 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.37.3" +version = "0.37.4" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -33,6 +33,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" +MarginalLogDensities = "f0c3360a-fb8d-11e9-1194-5521fd7ee392" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" [extensions] @@ -41,6 +42,7 @@ DynamicPPLEnzymeCoreExt = ["EnzymeCore"] DynamicPPLForwardDiffExt = ["ForwardDiff"] DynamicPPLJETExt = ["JET"] DynamicPPLMCMCChainsExt = ["MCMCChains"] +DynamicPPLMarginalLogDensitiesExt = ["MarginalLogDensities"] DynamicPPLMooncakeExt = ["Mooncake"] [compat] @@ -66,6 +68,7 @@ LinearAlgebra = "1.6" LogDensityProblems = "2" MCMCChains = "6, 7" MacroTools = "0.5.6" +MarginalLogDensities = "0.4.3" Mooncake = "0.4.147" OrderedCollections = "1" Printf = "1.10" diff --git a/docs/Project.toml b/docs/Project.toml index 1f01b11ef..cc0be339d 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -10,6 +10,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" +MarginalLogDensities = "f0c3360a-fb8d-11e9-1194-5521fd7ee392" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" [compat] diff --git a/docs/make.jl b/docs/make.jl index 9c59cb06b..828b20658 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -11,6 +11,10 @@ using Distributions using DocumenterMermaid # load MCMCChains package extension to make `predict` available using MCMCChains +using MarginalLogDensities: MarginalLogDensities + +# Need this to document a method which uses a type inside the extension... +DPPLMLDExt = Base.get_extension(DynamicPPL, :DynamicPPLMarginalLogDensitiesExt) # Doctest setup DocMeta.setdocmeta!( @@ -24,7 +28,11 @@ makedocs(; format=Documenter.HTML(; size_threshold=2^10 * 400, mathengine=Documenter.HTMLWriter.MathJax3() ), - modules=[DynamicPPL, Base.get_extension(DynamicPPL, :DynamicPPLMCMCChainsExt)], + modules=[ + DynamicPPL, + Base.get_extension(DynamicPPL, :DynamicPPLMCMCChainsExt), + Base.get_extension(DynamicPPL, :DynamicPPLMarginalLogDensitiesExt), + ], pages=[ "Home" => "index.md", "API" => "api.md", "Internals" => ["internals/varinfo.md"] ], diff --git a/docs/src/api.md b/docs/src/api.md index 9a1923b53..999bbe822 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -136,6 +136,22 @@ When using `predict` with `MCMCChains.Chains`, you can control which variables a - `include_all=false` (default): Include only newly predicted variables - `include_all=true`: Include both parameters from the original chain and predicted variables +## Marginalisation + +DynamicPPL provides the `marginalize` function to marginalise out variables from a model. +This requires `MarginalLogDensities.jl` to be loaded in your environment. + +```@docs +marginalize +``` + +A `MarginalLogDensity` object acts as a function which maps non-marginalised parameter values to a marginal log-probability. +To retrieve a VarInfo object from it, you can use: + +```@docs +VarInfo(::MarginalLogDensities.MarginalLogDensity{<:DPPLMLDExt.LogDensityFunctionWrapper}, ::Union{AbstractVector,Nothing}) +``` + ## Models within models One can include models and call another model inside the model function with `left ~ to_submodel(model)`. diff --git a/ext/DynamicPPLMarginalLogDensitiesExt.jl b/ext/DynamicPPLMarginalLogDensitiesExt.jl new file mode 100644 index 000000000..2155fa161 --- /dev/null +++ b/ext/DynamicPPLMarginalLogDensitiesExt.jl @@ -0,0 +1,204 @@ +module DynamicPPLMarginalLogDensitiesExt + +using DynamicPPL: DynamicPPL, LogDensityProblems, VarName +using MarginalLogDensities: MarginalLogDensities + +# A thin wrapper to adapt a DynamicPPL.LogDensityFunction to the interface expected by +# MarginalLogDensities. It's helpful to have a struct so that we can dispatch on its type +# below. +struct LogDensityFunctionWrapper{L<:DynamicPPL.LogDensityFunction} + logdensity::L +end +function (lw::LogDensityFunctionWrapper)(x, _) + return LogDensityProblems.logdensity(lw.logdensity, x) +end + +""" + marginalize( + model::DynamicPPL.Model, + marginalized_varnames::AbstractVector{<:VarName}; + varinfo::DynamicPPL.AbstractVarInfo=link(VarInfo(model), model), + getlogprob=DynamicPPL.getlogjoint, + method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox(); + kwargs..., + ) + +Construct a `MarginalLogDensities.MarginalLogDensity` object that represents the marginal +log-density of the given `model`, after marginalizing out the variables specified in +`varnames`. + +The resulting object can be called with a vector of parameter values to compute the marginal +log-density. + +## Keyword arguments + +- `varinfo`: The `varinfo` to use for the model. By default we use a linked `VarInfo`, + meaning that the resulting log-density function accepts parameters that have been + transformed to unconstrained space. + +- `getlogprob`: A function which specifies which kind of marginal log-density to compute. + Its default value is `DynamicPPL.getlogjoint` which returns the marginal log-joint + probability. + +- `method`: The marginalization method; defaults to a Laplace approximation. Please see [the + MarginalLogDensities.jl package](https://github.com/ElOceanografo/MarginalLogDensities.jl/) + for other options. + +- Other keyword arguments are passed to the `MarginalLogDensities.MarginalLogDensity` + constructor. + +## Example + +```jldoctest +julia> using DynamicPPL, Distributions, MarginalLogDensities + +julia> @model function demo() + x ~ Normal(1.0) + y ~ Normal(2.0) + end +demo (generic function with 2 methods) + +julia> marginalized = marginalize(demo(), [:x]); + +julia> # The resulting callable computes the marginal log-density of `y`. + marginalized([1.0]) +-1.4189385332046727 + +julia> logpdf(Normal(2.0), 1.0) +-1.4189385332046727 +``` + + +!!! warning + + The default usage of linked VarInfo means that, for example, optimization of the + marginal log-density can be performed in unconstrained space. However, care must be + taken if the model contains variables where the link transformation depends on a + marginalized variable. For example: + + ```julia + @model function f() + x ~ Normal() + y ~ truncated(Normal(); lower=x) + end + ``` + + Here, the support of `y`, and hence the link transformation used, depends on the value + of `x`. If we now marginalize over `x`, we obtain a function mapping linked values of + `y` to log-probabilities. However, it will not be possible to use DynamicPPL to + correctly retrieve _unlinked_ values of `y`. +""" +function DynamicPPL.marginalize( + model::DynamicPPL.Model, + marginalized_varnames::AbstractVector{<:VarName}; + varinfo::DynamicPPL.AbstractVarInfo=DynamicPPL.link(DynamicPPL.VarInfo(model), model), + getlogprob::Function=DynamicPPL.getlogjoint, + method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox(), + kwargs..., +) + # Determine the indices for the variables to marginalise out. + varindices = reduce(vcat, DynamicPPL.vector_getranges(varinfo, marginalized_varnames)) + # Construct the marginal log-density model. + f = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo) + mld = MarginalLogDensities.MarginalLogDensity( + LogDensityFunctionWrapper(f), varinfo[:], varindices, (), method; kwargs... + ) + return mld +end + +""" + VarInfo( + mld::MarginalLogDensities.MarginalLogDensity{<:LogDensityFunctionWrapper}, + unmarginalized_params::Union{AbstractVector,Nothing}=nothing + ) + +Retrieve the `VarInfo` object used in the marginalisation process. + +If a Laplace approximation was used for the marginalisation, the values of the marginalized +parameters are also set to their mode (note that this only happens if the `mld` object has +been used to compute the marginal log-density at least once, so that the mode has been +computed). + +If a vector of `unmarginalized_params` is specified, the values for the corresponding +parameters will also be updated in the returned VarInfo. This vector may be obtained e.g. by +performing an optimization of the marginal log-density. + +All other aspects of the VarInfo, such as link status, are preserved from the original +VarInfo used in the marginalisation. + +!!! note + + The other fields of the VarInfo, e.g. accumulated log-probabilities, will not be + updated. If you wish to have a fully consistent VarInfo, you should re-evaluate the + model with the returned VarInfo (e.g. using `vi = last(DynamicPPL.evaluate!!(model, + vi))`). + +## Example + +```jldoctest +julia> using DynamicPPL, Distributions, MarginalLogDensities + +julia> @model function demo() + x ~ Normal() + y ~ Beta(2, 2) + end +demo (generic function with 2 methods) + +julia> # Note that by default `marginalize` uses a linked VarInfo. + mld = marginalize(demo(), [@varname(x)]); + +julia> using MarginalLogDensities: Optimization, OptimizationOptimJL + +julia> # Find the mode of the marginal log-density of `y`, with an initial point of `y0`. + y0 = 2.0; opt_problem = Optimization.OptimizationProblem(mld, [y0]) +OptimizationProblem. In-place: true +u0: 1-element Vector{Float64}: + 2.0 + +julia> # This tells us the optimal (linked) value of `y` is around 0. + opt_solution = Optimization.solve(opt_problem, OptimizationOptimJL.NelderMead()) +retcode: Success +u: 1-element Vector{Float64}: + 4.88281250001733e-5 + +julia> # Get the VarInfo corresponding to the mode of `y`. + vi = VarInfo(mld, opt_solution.u); + +julia> # `x` is set to its mode (which for `Normal()` is zero). + vi[@varname(x)] +0.0 + +julia> # `y` is set to the optimal value we found above. + DynamicPPL.getindex_internal(vi, @varname(y)) +1-element Vector{Float64}: + 4.88281250001733e-5 + +julia> # To obtain values in the original constrained space, we can either + # use `getindex`: + vi[@varname(y)] +0.5000122070312476 + +julia> # Or invlink the entire VarInfo object using the model: + vi_unlinked = DynamicPPL.invlink(vi, demo()); vi_unlinked[:] +2-element Vector{Float64}: + 0.0 + 0.5000122070312476 +``` +""" +function DynamicPPL.VarInfo( + mld::MarginalLogDensities.MarginalLogDensity{<:LogDensityFunctionWrapper}, + unmarginalized_params::Union{AbstractVector,Nothing}=nothing, +) + # Extract the original VarInfo. Its contents will in general be junk. + original_vi = mld.logdensity.logdensity.varinfo + # Extract the stored parameters, which includes the modes for any marginalized + # parameters + full_params = MarginalLogDensities.cached_params(mld) + # We can then (if needed) set the values for any non-marginalized parameters + if unmarginalized_params !== nothing + full_params[MarginalLogDensities.ijoint(mld)] = unmarginalized_params + end + return DynamicPPL.unflatten(original_vi, full_params) +end + +end diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 5c8233915..bdc953a12 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -122,6 +122,7 @@ export AbstractVarInfo, fix, unfix, predict, + marginalize, prefix, returned, to_submodel, @@ -199,9 +200,9 @@ include("test_utils.jl") include("experimental.jl") include("deprecated.jl") -# Better error message if users forget to load JET if isdefined(Base.Experimental, :register_error_hint) function __init__() + # Better error message if users forget to load JET.jl Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, _ requires_jet = exc.f === DynamicPPL.Experimental._determine_varinfo_jet && @@ -222,6 +223,23 @@ if isdefined(Base.Experimental, :register_error_hint) end end + # Same for MarginalLogDensities.jl + Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, _ + requires_mld = + exc.f === DynamicPPL.marginalize && + length(argtypes) == 2 && + argtypes[1] <: Model && + argtypes[2] <: AbstractVector{<:Union{Symbol,<:VarName}} + if requires_mld + printstyled( + io, + "\n\n `$(exc.f)` requires MarginalLogDensities.jl to be loaded.\n Please run `using MarginalLogDensities` before calling `$(exc.f)`.\n"; + color=:cyan, + bold=true, + ) + end + end + Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, _ is_evaluate_three_arg = exc.f === AbstractPPL.evaluate!! && @@ -243,4 +261,7 @@ end # Ref: https://www.stochasticlifestyle.com/improved-forwarddiff-jl-stacktraces-with-package-tags/ struct DynamicPPLTag end +# Extended in MarginalLogDensitiesExt +function marginalize end + end # module diff --git a/test/Project.toml b/test/Project.toml index 537214464..589b150f4 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -18,6 +18,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +MarginalLogDensities = "f0c3360a-fb8d-11e9-1194-5521fd7ee392" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/test/ext/DynamicPPLMarginalLogDensitiesExt.jl b/test/ext/DynamicPPLMarginalLogDensitiesExt.jl new file mode 100644 index 000000000..32c4bb479 --- /dev/null +++ b/test/ext/DynamicPPLMarginalLogDensitiesExt.jl @@ -0,0 +1,104 @@ +module MarginalLogDensitiesExtTests + +using Bijectors: Bijectors +using DynamicPPL, Distributions, Test +using MarginalLogDensities +using ADTypes: AutoForwardDiff + +@testset "MarginalLogDensities" begin + @testset "Basic usage" begin + @model function demo() + x ~ MvNormal(zeros(2), [1, 1]) + return y ~ Normal(0, 1) + end + model = demo() + vi = VarInfo(model) + # Marginalize out `x`. + @testset for getlogprob in [DynamicPPL.getlogprior, DynamicPPL.getlogjoint] + marginalized = marginalize( + model, + [@varname(x)]; + varinfo=vi, + getlogprob=getlogprob, + hess_adtype=AutoForwardDiff(), + ) + for y in range(-5, 5; length=100) + @test marginalized([y]) ≈ logpdf(Normal(0, 1), y) atol = 1e-5 + end + end + end + + @testset "Respects linked status of VarInfo" begin + @model function f() + x ~ Normal() + return y ~ Beta(2, 2) + end + model = f() + vi_unlinked = VarInfo(model) + vi_linked = DynamicPPL.link(vi_unlinked, model) + + @testset "unlinked VarInfo" begin + mx = marginalize(model, [@varname(x)]; varinfo=vi_unlinked) + for x in range(0.01, 0.99; length=10) + @test mx([x]) ≈ logpdf(Beta(2, 2), x) + end + # generally when marginalising Beta it doesn't go to zero + # https://github.com/TuringLang/DynamicPPL.jl/pull/1036#discussion_r2349388067 + my = marginalize(model, [@varname(y)]; varinfo=vi_unlinked) + diff = my([0.0]) - logpdf(Normal(), 0.0) + for x in range(-5, 5; length=10) + @test my([x]) ≈ logpdf(Normal(), x) + diff + end + end + + @testset "linked VarInfo" begin + mx = marginalize(model, [@varname(x)]; varinfo=vi_linked) + binv = Bijectors.inverse(Bijectors.bijector(Beta(2, 2))) + for y_linked in range(-5, 5; length=10) + y_unlinked = binv(y_linked) + @test mx([y_linked]) ≈ logpdf(Beta(2, 2), y_unlinked) + end + # generally when marginalising Beta it doesn't go to zero + # https://github.com/TuringLang/DynamicPPL.jl/pull/1036#discussion_r2349388067 + my = marginalize(model, [@varname(y)]; varinfo=vi_linked) + diff = my([0.0]) - logpdf(Normal(), 0.0) + for x in range(-5, 5; length=10) + @test my([x]) ≈ logpdf(Normal(), x) + diff + end + end + end + + @testset "retrieving VarInfo from MLD" begin + @model function f() + x ~ Normal() + return y ~ Beta(2, 2) + end + model = f() + vi_unlinked = VarInfo(model) + vi_linked = DynamicPPL.link(vi_unlinked, model) + + @testset "unlinked VarInfo" begin + mx = marginalize(model, [@varname(x)]; varinfo=vi_unlinked) + mx([0.5]) # evaluate at some point to force calculation of Laplace approx + vi = VarInfo(mx) + @test vi[@varname(x)] ≈ mode(Normal()) + vi = VarInfo(mx, [0.5]) # this 0.5 is unlinked + @test vi[@varname(x)] ≈ mode(Normal()) + @test vi[@varname(y)] ≈ 0.5 + end + + @testset "linked VarInfo" begin + mx = marginalize(model, [@varname(x)]; varinfo=vi_linked) + mx([0.5]) # evaluate at some point to force calculation of Laplace approx + vi = VarInfo(mx) + @test vi[@varname(x)] ≈ mode(Normal()) + vi = VarInfo(mx, [0.5]) # this 0.5 is linked + binv = Bijectors.inverse(Bijectors.bijector(Beta(2, 2))) + @test vi[@varname(x)] ≈ mode(Normal()) + # when using getindex it always returns unlinked values + @test vi[@varname(y)] ≈ binv(0.5) + end + end +end + +end diff --git a/test/runtests.jl b/test/runtests.jl index c60c06786..40960884e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -80,6 +80,7 @@ include("test_util.jl") @testset "extensions" begin include("ext/DynamicPPLMCMCChainsExt.jl") include("ext/DynamicPPLJETExt.jl") + include("ext/DynamicPPLMarginalLogDensitiesExt.jl") end @testset "ad" begin include("ext/DynamicPPLForwardDiffExt.jl")