diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml new file mode 100644 index 000000000..c7439503e --- /dev/null +++ b/.JuliaFormatter.toml @@ -0,0 +1 @@ +style = "blue" \ No newline at end of file diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 0ac7c3267..d9d9ad787 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -1,37 +1,27 @@ name: CI - on: push: branches: - master pull_request: - jobs: test: + name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} runs-on: ${{ matrix.os }} continue-on-error: ${{ matrix.version == 'nightly' }} strategy: + fail-fast: false matrix: version: - '1' + - '1.3' + - 'nightly' os: - ubuntu-latest - macOS-latest - windows-latest arch: - x64 - - x86 - exclude: - - os: macOS-latest - arch: x86 - include: - - version: '1.0' - os: ubuntu-latest - arch: x64 - - os: ubuntu-latest - version: '1' - arch: x64 - coverage: true steps: - uses: actions/checkout@v2 - uses: julia-actions/setup-julia@v1 @@ -48,16 +38,12 @@ jobs: ${{ runner.os }}-test-${{ env.cache-name }}- ${{ runner.os }}-test- ${{ runner.os }}- - - uses: julia-actions/julia-buildpkg@latest - - uses: julia-actions/julia-runtest@latest + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-runtest@v1 - uses: julia-actions/julia-processcoverage@v1 - if: matrix.coverage - - uses: codecov/codecov-action@v1 - if: matrix.coverage + - name: Send coverage + if: matrix.version == '1' && matrix.os == 'ubuntu-latest' + uses: coverallsapp/github-action@master with: - file: lcov.info - - uses: coverallsapp/github-action@master - if: matrix.coverage - with: - github-token: ${{ secrets.GITHUB_TOKEN }} - path-to-lcov: lcov.info \ No newline at end of file + github-token: ${{ secrets.GITHUB_TOKEN }} + path-to-lcov: ./lcov.info diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..7d493193e --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +test/Manifest.toml +.vscode \ No newline at end of file diff --git a/Project.toml b/Project.toml index e3b7ec0dc..3e82ed5a4 100644 --- a/Project.toml +++ b/Project.toml @@ -7,14 +7,15 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" -Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" [compat] Bijectors = "0.4.0, 0.5, 0.6, 0.7, 0.8" @@ -26,7 +27,6 @@ ProgressMeter = "1.0.0" Requires = "0.5, 1.0" StatsBase = "0.32, 0.33" StatsFuns = "0.8, 0.9" -Tracker = "0.2.3" julia = "1" [extras] diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 28906f9ad..4fced3df5 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -1,238 +1,70 @@ module AdvancedVI -using Random: AbstractRNG +using Bijectors: GLOBAL_RNG +using Random: AbstractRNG, GLOBAL_RNG -using Distributions, DistributionsAD, Bijectors +using Bijectors +using Distributions +using DistributionsAD using DocStringExtensions - -using ProgressMeter, LinearAlgebra - using ForwardDiff -using Tracker +using Flux: Optimise # Temp before Optimisers.jl is registered +using Functors +using ProgressMeter, LinearAlgebra +using Random +using Requires const PROGRESS = Ref(true) function turnprogress(switch::Bool) @info("[AdvancedVI]: global PROGRESS is set as $switch") - PROGRESS[] = switch + return PROGRESS[] = switch end const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_ADVANCEDVI", "0"))) include("ad.jl") -using Requires function __init__() - @require Flux="587475ba-b771-5e3f-ad9e-33799f191a9c" begin + @require Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" begin apply!(o, x, Δ) = Flux.Optimise.apply!(o, x, Δ) Flux.Optimise.apply!(o::TruncatedADAGrad, x, Δ) = apply!(o, x, Δ) Flux.Optimise.apply!(o::DecayedADAGrad, x, Δ) = apply!(o, x, Δ) end @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin - include("compat/zygote.jl") - export ZygoteAD - - function AdvancedVI.grad!( - vo, - alg::VariationalInference{<:AdvancedVI.ZygoteAD}, - q, - model, - θ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult, - args... - ) - f(θ) = if (q isa Distribution) - - vo(alg, update(q, θ), model, args...) - else - - vo(alg, q(θ), model, args...) - end - y, back = Zygote.pullback(f, θ) - dy = first(back(1.0)) - DiffResults.value!(out, y) - DiffResults.gradient!(out, dy) - return out - end + include(joinpath("compat", "zygote.jl")) end @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin - include("compat/reversediff.jl") - export ReverseDiffAD - - function AdvancedVI.grad!( - vo, - alg::VariationalInference{<:AdvancedVI.ReverseDiffAD{false}}, - q, - model, - θ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult, - args... - ) - f(θ) = if (q isa Distribution) - - vo(alg, update(q, θ), model, args...) - else - - vo(alg, q(θ), model, args...) - end - tp = AdvancedVI.tape(f, θ) - ReverseDiff.gradient!(out, tp, θ) - return out - end + include(joinpath("compat", "reversediff.jl")) + end + @require Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin + include(joinpath("compat", "tracker.jl")) end end -export - vi, - ADVI, - ELBO, - elbo, - TruncatedADAGrad, - DecayedADAGrad, - VariationalInference +export vi, ADVI, BBVI, ELBO, TruncatedADAGrad, DecayedADAGrad, VariationalInference abstract type VariationalInference{AD} end -getchunksize(::Type{<:VariationalInference{AD}}) where AD = getchunksize(AD) -getADtype(::VariationalInference{AD}) where AD = AD +getchunksize(::Type{<:VariationalInference{AD}}) where {AD} = getchunksize(AD) +getADtype(::VariationalInference{AD}) where {AD} = AD abstract type VariationalObjective end -const VariationalPosterior = Distribution{Multivariate, Continuous} - - -""" - grad!(vo, alg::VariationalInference, q, model::Model, θ, out, args...) - -Computes the gradients used in `optimize!`. Default implementation is provided for -`VariationalInference{AD}` where `AD` is either `ForwardDiffAD` or `TrackerAD`. -This implicitly also gives a default implementation of `optimize!`. - -Variance reduction techniques, e.g. control variates, should be implemented in this function. -""" -function grad! end - -""" - vi(model, alg::VariationalInference) - vi(model, alg::VariationalInference, q::VariationalPosterior) - vi(model, alg::VariationalInference, getq::Function, θ::AbstractArray) - -Constructs the variational posterior from the `model` and performs the optimization -following the configuration of the given `VariationalInference` instance. - -# Arguments -- `model`: `Turing.Model` or `Function` z ↦ log p(x, z) where `x` denotes the observations -- `alg`: the VI algorithm used -- `q`: a `VariationalPosterior` for which it is assumed a specialized implementation of the variational objective used exists. -- `getq`: function taking parameters `θ` as input and returns a `VariationalPosterior` -- `θ`: only required if `getq` is used, in which case it is the initial parameters for the variational posterior -""" -function vi end - -function update end - -# default implementations -function grad!( - vo, - alg::VariationalInference{<:ForwardDiffAD}, - q, - model, - θ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult, - args... -) - f(θ_) = if (q isa Distribution) - - vo(alg, update(q, θ_), model, args...) - else - - vo(alg, q(θ_), model, args...) - end - - chunk_size = getchunksize(typeof(alg)) - # Set chunk size and do ForwardMode. - chunk = ForwardDiff.Chunk(min(length(θ), chunk_size)) - config = ForwardDiff.GradientConfig(f, θ, chunk) - ForwardDiff.gradient!(out, f, θ, config) -end +const VariationalPosterior = Distribution{Multivariate,Continuous} -function grad!( - vo, - alg::VariationalInference{<:TrackerAD}, - q, - model, - θ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult, - args... -) - θ_tracked = Tracker.param(θ) - y = if (q isa Distribution) - - vo(alg, update(q, θ_tracked), model, args...) - else - - vo(alg, q(θ_tracked), model, args...) - end - Tracker.back!(y, 1.0) - - DiffResults.value!(out, Tracker.data(y)) - DiffResults.gradient!(out, Tracker.grad(θ_tracked)) -end +# Custom distributions +include(joinpath("distributions", "distributions.jl")) - -""" - optimize!(vo, alg::VariationalInference{AD}, q::VariationalPosterior, model::Model, θ; optimizer = TruncatedADAGrad()) - -Iteratively updates parameters by calling `grad!` and using the given `optimizer` to compute -the steps. -""" -function optimize!( - vo, - alg::VariationalInference, - q, - model, - θ::AbstractVector{<:Real}; - optimizer = TruncatedADAGrad() -) - # TODO: should we always assume `samples_per_step` and `max_iters` for all algos? - alg_name = alg_str(alg) - samples_per_step = alg.samples_per_step - max_iters = alg.max_iters - - num_params = length(θ) - - # TODO: really need a better way to warn the user about potentially - # not using the correct accumulator - if (optimizer isa TruncatedADAGrad) && (θ ∉ keys(optimizer.acc)) - # this message should only occurr once in the optimization process - @info "[$alg_name] Should only be seen once: optimizer created for θ" objectid(θ) - end - - diff_result = DiffResults.GradientResult(θ) - - i = 0 - prog = if PROGRESS[] - ProgressMeter.Progress(max_iters, 1, "[$alg_name] Optimizing...", 0) - else - 0 - end - - # add criterion? A running mean maybe? - time_elapsed = @elapsed while (i < max_iters) # & converged - grad!(vo, alg, q, model, θ, diff_result, samples_per_step) - - # apply update rule - Δ = DiffResults.gradient(diff_result) - Δ = apply!(optimizer, θ, Δ) - @. θ = θ - Δ - - AdvancedVI.DEBUG && @debug "Step $i" Δ DiffResults.value(diff_result) - PROGRESS[] && (ProgressMeter.next!(prog)) - - i += 1 - end - - return θ -end +include("utils.jl") # objectives include("objectives.jl") - -# optimisers +include("gradients.jl") +include("interface.jl") include("optimisers.jl") # VI algorithms -include("advi.jl") +include(joinpath("algorithms", "advi.jl")) +include(joinpath("algorithms", "bbvi.jl")) end # module diff --git a/src/ad.jl b/src/ad.jl index 59c69cb0e..d38c32e8d 100644 --- a/src/ad.jl +++ b/src/ad.jl @@ -3,27 +3,16 @@ ############################## const ADBACKEND = Ref(:forwarddiff) setadbackend(backend_sym::Symbol) = setadbackend(Val(backend_sym)) -function setadbackend(::Val{:forward_diff}) - Base.depwarn("`AdvancedVI.setadbackend(:forward_diff)` is deprecated. Please use `AdvancedVI.setadbackend(:forwarddiff)` to use `ForwardDiff`.", :setadbackend) - setadbackend(Val(:forwarddiff)) -end + function setadbackend(::Val{:forwarddiff}) CHUNKSIZE[] == 0 && setchunksize(40) - ADBACKEND[] = :forwarddiff -end - -function setadbackend(::Val{:reverse_diff}) - Base.depwarn("`AdvancedVI.setadbackend(:reverse_diff)` is deprecated. Please use `AdvancedVI.setadbackend(:tracker)` to use `Tracker` or `AdvancedVI.setadbackend(:reversediff)` to use `ReverseDiff`. To use `ReverseDiff`, please make sure it is loaded separately with `using ReverseDiff`.", :setadbackend) - setadbackend(Val(:tracker)) -end -function setadbackend(::Val{:tracker}) - ADBACKEND[] = :tracker + return ADBACKEND[] = :forwarddiff end const ADSAFE = Ref(false) function setadsafe(switch::Bool) @info("[AdvancedVI]: global ADSAFE is set as $switch") - ADSAFE[] = switch + return ADSAFE[] = switch end const CHUNKSIZE = Ref(40) # default chunksize used by AD @@ -37,13 +26,14 @@ end abstract type ADBackend end struct ForwardDiffAD{chunk} <: ADBackend end -getchunksize(::Type{<:ForwardDiffAD{chunk}}) where chunk = chunk - -struct TrackerAD <: ADBackend end +getchunksize(::Type{<:ForwardDiffAD{chunk}}) where {chunk} = chunk ADBackend() = ADBackend(ADBACKEND[]) ADBackend(T::Symbol) = ADBackend(Val(T)) -ADBackend(::Val{:forwarddiff}) = ForwardDiffAD{CHUNKSIZE[]} -ADBackend(::Val{:tracker}) = TrackerAD -ADBackend(::Val) = error("The requested AD backend is not available. Make sure to load all required packages.") +ADBackend(::Val{:ForwardDiff}) = ForwardDiffAD{CHUNKSIZE[]} +function ADBackend(::Val) + return error( + "The requested AD backend is not available. Make sure to load all required packages.", + ) +end diff --git a/src/advi.jl b/src/advi.jl deleted file mode 100644 index a5f880cee..000000000 --- a/src/advi.jl +++ /dev/null @@ -1,100 +0,0 @@ -using StatsFuns -using DistributionsAD -using Bijectors -using Bijectors: TransformedDistribution -using Random: AbstractRNG, GLOBAL_RNG - - -""" -$(TYPEDEF) - -Automatic Differentiation Variational Inference (ADVI) with automatic differentiation -backend `AD`. - -# Fields - -$(TYPEDFIELDS) -""" -struct ADVI{AD} <: VariationalInference{AD} - "Number of samples used to estimate the ELBO in each optimization step." - samples_per_step::Int - "Maximum number of gradient steps." - max_iters::Int -end - -function ADVI(samples_per_step::Int=1, max_iters::Int=1000) - return ADVI{ADBackend()}(samples_per_step, max_iters) -end - -alg_str(::ADVI) = "ADVI" - -function vi(model, alg::ADVI, q, θ_init; optimizer = TruncatedADAGrad()) - θ = copy(θ_init) - optimize!(elbo, alg, q, model, θ; optimizer = optimizer) - - # If `q` is a mean-field approx we use the specialized `update` function - if q isa Distribution - return update(q, θ) - else - # Otherwise we assume it's a mapping θ → q - return q(θ) - end -end - - -function optimize(elbo::ELBO, alg::ADVI, q, model, θ_init; optimizer = TruncatedADAGrad()) - θ = copy(θ_init) - - # `model` assumed to be callable z ↦ p(x, z) - optimize!(elbo, alg, q, model, θ; optimizer = optimizer) - - return θ -end - -# WITHOUT updating parameters inside ELBO -function (elbo::ELBO)( - rng::AbstractRNG, - alg::ADVI, - q::VariationalPosterior, - logπ::Function, - num_samples -) - # 𝔼_q(z)[log p(xᵢ, z)] - # = ∫ log p(xᵢ, z) q(z) dz - # = ∫ log p(xᵢ, f(ϕ)) q(f(ϕ)) |det J_f(ϕ)| dϕ (since change of variables) - # = ∫ log p(xᵢ, f(ϕ)) q̃(ϕ) dϕ (since q(f(ϕ)) |det J_f(ϕ)| = q̃(ϕ)) - # = 𝔼_q̃(ϕ)[log p(xᵢ, z)] - - # 𝔼_q(z)[log q(z)] - # = ∫ q(f(ϕ)) log (q(f(ϕ))) |det J_f(ϕ)| dϕ (since q(f(ϕ)) |det J_f(ϕ)| = q̃(ϕ)) - # = 𝔼_q̃(ϕ) [log q(f(ϕ))] - # = 𝔼_q̃(ϕ) [log q̃(ϕ) - log |det J_f(ϕ)|] - # = 𝔼_q̃(ϕ) [log q̃(ϕ)] - 𝔼_q̃(ϕ) [log |det J_f(ϕ)|] - # = - ℍ(q̃(ϕ)) - 𝔼_q̃(ϕ) [log |det J_f(ϕ)|] - - # Finally, the ELBO is given by - # ELBO = 𝔼_q(z)[log p(xᵢ, z)] - 𝔼_q(z)[log q(z)] - # = 𝔼_q̃(ϕ)[log p(xᵢ, z)] + 𝔼_q̃(ϕ) [log |det J_f(ϕ)|] + ℍ(q̃(ϕ)) - - # If f: supp(p(z | x)) → ℝ then - # ELBO = 𝔼[log p(x, z) - log q(z)] - # = 𝔼[log p(x, f⁻¹(z̃)) + logabsdet(J(f⁻¹(z̃)))] + ℍ(q̃(z̃)) - # = 𝔼[log p(x, z) - logabsdetjac(J(f(z)))] + ℍ(q̃(z̃)) - - # But our `forward(q)` is using f⁻¹: ℝ → supp(p(z | x)) going forward → `+ logjac` - _, z, logjac, _ = forward(rng, q) - res = (logπ(z) + logjac) / num_samples - - if q isa TransformedDistribution - res += entropy(q.dist) - else - res += entropy(q) - end - - for i = 2:num_samples - _, z, logjac, _ = forward(rng, q) - res += (logπ(z) + logjac) / num_samples - end - - return res -end diff --git a/src/algorithms/advi.jl b/src/algorithms/advi.jl new file mode 100644 index 000000000..4b12473a3 --- /dev/null +++ b/src/algorithms/advi.jl @@ -0,0 +1,85 @@ +""" +$(TYPEDEF) + +"Automatic Differentiation Variational Inference" (ADVI) with automatic differentiation +backend `AD`. + +As described in [^ADVI16] + +# Fields + +$(TYPEDFIELDS) + +[^ADVI16]: Alp Kucukelbir, Dustin Tran, Rajesh Ranganath, Andrew Gelman, David M. Blei (2016), [Automatic Differentiation Variational Inference](https://arxiv.org/abs/1603.00788) +""" +struct ADVI{AD} <: VariationalInference{AD} + "Number of samples used to estimate the ELBO in each optimization step." + samples_per_step::Int + "Maximum number of gradient steps." + max_iters::Int +end + +function ADVI(samples_per_step::Int=1, max_iters::Int=1000) + return ADVI{ADBackend()}(samples_per_step, max_iters) +end + +alg_str(::ADVI) = "ADVI" +samples_per_step(alg::ADVI) = alg.samples_per_step +maxiters(alg::ADVI) = alg.max_iters + +function init(rng, alg::ADVI, q, θ, opt) # This is where the optimizer can be correctly initiated as well + n_samples_per_step = samples_per_step(alg) + x₀ = rand(rng, q(θ), n_samples_per_step) # Preallocating x₀ + x = similar(x₀) # Preallocating x + diff_result = DiffResults.GradientResult(x) + opt_state = Optimisers.init(opt, θ) + return (x₀=x₀, x=x, diff_result=diff_result, opt_state=opt_state) +end + +function step!(rng, ::ELBO, alg::ADVI, q, θ, logπ, state, opt) + randn!(rng, state.x₀) # Get initial samples from x₀ + reparametrize!(state.x, q(θ), state.x₀) + f(X) = + sum(eachcol(X)) do x + return evaluate(logπ, q, x) + end + grad!(state.diff_result, f, state.x, alg) + θ, state.opt_state = Optimisers.update!(opt, opt_state, θ, Δ) + return update!(alg, q, state, opt) +end + + +function update!(alg::ADVI, q, state, opt) + Δ = DiffResults.gradient(state.diff_result) + update_mean!(q, vec(mean(Δ; dims=2)), opt) + update_cov!(alg, q, Δ, state, opt) + return q +end + +function update_cov!(alg::ADVI, q::Bijectors.TransformedDistribution, Δ, state, opt) + return update_cov!(alg, q.dist, Δ, state, opt) +end + +if VERSION < v"1.5.0" + function update_cov!(::ADVI, q::CholMvNormal, Δ, state, opt) + return q.Γ .+= LowerTriangular( + Optimise.apply!( + opt, q.Γ.data, Δ * state.x₀' / size(state.x₀, 2) + inv(Diagonal(q.Γ.data)) + ), + ) + end +else + function update_cov!(::ADVI, q::CholMvNormal, Δ, state, opt) + return q.Γ .+= LowerTriangular( + Optimise.apply!( + opt, q.Γ.data, Δ * state.x₀' / size(state.x₀, 2) + inv(Diagonal(q.Γ)) + ), + ) + end +end + +function update_cov!(::ADVI, q::DiagMvNormal, Δ, state, opt) + return q.Γ .+= Optimise.apply!(opt, q.Γ, vec(mean(Δ .* state.x₀; dims=2)) + inv.(q.Γ)) +end + +Distributions.entropy(::ADVI, q) = Distributions.entropy(q) diff --git a/src/algorithms/bbvi.jl b/src/algorithms/bbvi.jl new file mode 100644 index 000000000..d28fa7af9 --- /dev/null +++ b/src/algorithms/bbvi.jl @@ -0,0 +1,59 @@ +""" +$(TYPEDEF) + +Black-Box Variational Inference (BBVI) with automatic differentiation +backend `AD`. + +# Fields + +$(TYPEDFIELDS) +""" +struct BBVI{AD} <: VariationalInference{AD} + "Number of samples used to estimate the ELBO in each optimization step." + samples_per_step::Int + "Maximum number of gradient steps." + max_iters::Int +end + +function BBVI(samples_per_step::Int=1, max_iters::Int=1000) + return BBVI{ADBackend()}(samples_per_step, max_iters) +end + +alg_str(::BBVI) = "BBVI" +nsamples(alg::BBVI) = alg.samples_per_step +niters(alg::BBVI) = alg.max_iters + +function compats(::BBVI) + return Union{ + CholMvNormal, + # Bijectors.TransformDistribution{<:CholMvNormal}, + DiagMvNormal, + # Bijectors.TransformedDistribution{<:DiagMvNormal}, + } +end + +function init(rng::AbstractRNG, alg::BBVI, q, opt) + samples_per_step = nsamples(alg) + x = rand(rng, q, samples_per_step) # Preallocating x + θ = to_vec(q) + diff_result = DiffResults.GradientResult(zeros(length(θ))) + return (x=x, θ=θ, diff_result=diff_result) +end + +function step!(rng::AbstractRNG, ::ELBO, alg::BBVI, q, logπ, state, opt) + q̂ = to_dist(q, state.θ) + rand!(rng, q̂, state.x) # Get initial samples from x₀ + Δlog = evaluate.(logπ, Ref(q̂), eachcol(state.x)) .- logpdf(q̂, state.x) + f(θ) = dot(logpdf(to_dist(q, θ), state.x), Δlog) / nsamples(alg) + grad!(state.diff_result, f, state.θ, alg) + return update!(alg, q, state, opt) +end + +function update!(::BBVI, q, state, opt) + state.θ .+= Optimise.apply!(opt, state.θ, DiffResults.gradient(state.diff_result)) + return state +end + +function final_dist(::BBVI, q, state) + return to_dist(q, state.θ) +end diff --git a/src/compat/reversediff.jl b/src/compat/reversediff.jl index 721d03618..b278fb2e4 100644 --- a/src/compat/reversediff.jl +++ b/src/compat/reversediff.jl @@ -7,10 +7,23 @@ setcache(b::Bool) = RDCache[] = b getcache() = RDCache[] ADBackend(::Val{:reversediff}) = ReverseDiffAD{getcache()} function setadbackend(::Val{:reversediff}) - ADBACKEND[] = :reversediff + return ADBACKEND[] = :reversediff end tape(f, x) = GradientTape(f, x) function taperesult(f, x) return tape(f, x), GradientResult(x) end + +export ReverseDiffAD + +function AdvancedVI.grad!( + out::DiffResults.MutableDiffResult, + f, + x, + ::VariationalInference{<:AdvancedVI.ReverseDiffAD{false}}, +) + tp = AdvancedVI.tape(f, x) + ReverseDiff.gradient!(out, tp, x) + return out +end diff --git a/src/compat/tracker.jl b/src/compat/tracker.jl new file mode 100644 index 000000000..c706d9db5 --- /dev/null +++ b/src/compat/tracker.jl @@ -0,0 +1,21 @@ +using .Tracker + +ADBackend(::Val{:tracker}) = ReverseDiffAD{getcache()} +function setadbackend(::Val{:tracker}) + return ADBACKEND[] = :tracker +end + +struct TrackerAD <: ADBackend end + +ADBackend(::Val{:tracker}) = TrackerAD + +function grad!( + out::DiffResults.MutableDiffResult, f, x, ::VariationalInference{<:TrackerAD} +) + x_tracked = Tracker.param(x) + y = f(x_tracked) + Tracker.back!(y, one(eltype(y))) + + DiffResults.value!(out, Tracker.data(y)) + return DiffResults.gradient!(out, Tracker.grad(x_tracked)) +end diff --git a/src/compat/zygote.jl b/src/compat/zygote.jl index 40022e215..35508d54e 100644 --- a/src/compat/zygote.jl +++ b/src/compat/zygote.jl @@ -1,5 +1,18 @@ +using .Zygote + struct ZygoteAD <: ADBackend end ADBackend(::Val{:zygote}) = ZygoteAD function setadbackend(::Val{:zygote}) - ADBACKEND[] = :zygote + return ADBACKEND[] = :zygote +end +export ZygoteAD + +function AdvancedVI.grad!( + out::DiffResults.MutableDiffResult, f, x, ::VariationalInference{<:AdvancedVI.ZygoteAD} +) + y, back = Zygote.pullback(f, x) + dy = first(back(one(eltype(y)))) + DiffResults.value!(out, y) + DiffResults.gradient!(out, dy) + return out end diff --git a/src/distributions/cholmvnormal.jl b/src/distributions/cholmvnormal.jl new file mode 100644 index 000000000..3d5e5ef47 --- /dev/null +++ b/src/distributions/cholmvnormal.jl @@ -0,0 +1,39 @@ +## Traditional Cholesky representation where Γ is Lower Triangular +struct CholMvNormal{T,Tμ<:AbstractVector{T},TΓ<:LowerTriangular{T}} <: + AbstractPosteriorMvNormal{T} + dim::Int + μ::Tμ + Γ::TΓ + function CholMvNormal(μ::AbstractVector{T}, Γ::LowerTriangular{T}) where {T} + length(μ) == size(Γ, 1) || + throw(DimensionMismatch("μ and Γ have incompatible sizes")) + return new{T,typeof(μ),typeof(Γ)}(length(μ), μ, Γ) + end + function CholMvNormal( + dim::Int, μ::Tμ, Γ::TΓ + ) where {T,Tμ<:AbstractVector{T},TΓ<:LowerTriangular{T}} + length(μ) == size(Γ, 1) || + throw(DimensionMismatch("μ and Γ have incompatible sizes")) + return new{T,Tμ,TΓ}(dim, μ, Γ) + end +end + +Distributions.cov(d::CholMvNormal) = d.Γ * d.Γ' +Distributions.logdetcov(d::CholMvNormal) = 2 * logdet(d.Γ) + +@functor CholMvNormal + +function reparametrize!(x, q::CholMvNormal, z) + return x .= q.μ .+ q.Γ * z +end + +function to_vec(q::CholMvNormal) + return vcat(q.μ, vec(q.Γ)) +end + +function to_dist(q::CholMvNormal, θ::AbstractVector) + return CholMvNormal( + θ[1:length(q)], + LowerTriangular(reshape(θ[(length(q) + 1):end], length(q), length(q))), + ) +end diff --git a/src/distributions/diagmvnormal.jl b/src/distributions/diagmvnormal.jl new file mode 100644 index 000000000..ce6a0fa27 --- /dev/null +++ b/src/distributions/diagmvnormal.jl @@ -0,0 +1,47 @@ + +struct DiagMvNormal{T,Tμ<:AbstractVector{T},TΓ<:AbstractVector{T}} <: + AbstractPosteriorMvNormal{T} + dim::Int + μ::Tμ + Γ::TΓ + function DiagMvNormal(μ::AbstractVector{T}, Γ::AbstractVector{T}) where {T} + return new{T,typeof(μ),typeof(Γ)}(length(μ), μ, Γ) + end + function DiagMvNormal( + dim::Int, μ::Tμ, Γ::TΓ + ) where {T,Tμ<:AbstractVector{T},TΓ<:AbstractVector{T}} + return new{T,Tμ,TΓ}(dim, μ, Γ) + end +end + +function Distributions._rand!( + rng::AbstractRNG, d::DiagMvNormal{T}, x::AbstractVector +) where {T} + nDim = length(x) + nDim == d.dim || error("Wrong dimensions") + return x .= d.μ + d.Γ .* randn(rng, T, nDim) +end + +function Distributions._rand!( + rng::AbstractRNG, d::DiagMvNormal{T}, x::AbstractMatrix +) where {T} + nDim, nPoints = size(x) + nDim == d.dim || error("Wrong dimensions") + return x .= d.μ .+ d.Γ .* randn(rng, T, nDim, nPoints) +end + +Distributions.cov(d::DiagMvNormal) = Diagonal(abs2.(d.Γ)) + +@functor DiagMvNormal + +function reparametrize!(x, q::DiagMvNormal, z) + return x .= q.μ .+ q.Γ .* z +end + +function to_vec(q::DiagMvNormal) + return vcat(q.μ, q.Γ) +end + +function to_dist(q::DiagMvNormal, θ::AbstractVector) + return DiagMvNormal(θ[1:length(q)], θ[(length(q) + 1):end]) +end diff --git a/src/distributions/distributions.jl b/src/distributions/distributions.jl new file mode 100644 index 000000000..c37f521ae --- /dev/null +++ b/src/distributions/distributions.jl @@ -0,0 +1,58 @@ +## Series of variation of the MvNormal distribution, different methods need different parametrizations ## +abstract type AbstractPosteriorMvNormal{T} <: Distributions.AbstractMvNormal end + +Base.length(d::AbstractPosteriorMvNormal) = d.dim +Distributions.dim(d::AbstractPosteriorMvNormal) = d.dim +Distributions.mean(d::AbstractPosteriorMvNormal) = d.μ +rank(d::AbstractPosteriorMvNormal) = d.dim +function eval_entropy(::VariationalInference, d::AbstractPosteriorMvNormal) + return Distributions.entropy(d) +end +Distributions.logdetcov(d::AbstractPosteriorMvNormal) = logdet(cov(d)) +Distributions.invcov(d::AbstractPosteriorMvNormal) = inv(cov(d)) +function Distributions.entropy(d::AbstractPosteriorMvNormal) + return 0.5 * (logdet(cov(d)) + length(d) * log2π) +end + +function Distributions._logpdf(d::AbstractPosteriorMvNormal, x::AbstractArray) + return Distributions._logpdf(MvNormal(d), x) +end + +function Distributions._rand!( + rng::AbstractRNG, d::AbstractPosteriorMvNormal{T}, x::AbstractVector +) where {T} + return Distributions._rand!(rng, MvNormal(d), x) +end + +function Distributions._rand!( + rng::AbstractRNG, d::AbstractPosteriorMvNormal{T}, x::AbstractMatrix +) where {T} + return Distributions._rand!(rng, MvNormal(d), x) +end + +function Distributions.MvNormal(d::AbstractPosteriorMvNormal) + return Distributions.MvNormal(mean(d), cov(d)) +end + +## Update methods + +function update_mean!(q::Bijectors.TransformedDistribution, Δ, opt) + return update_mean!(q.dist, Δ, opt) +end + +function update_mean!(q::AbstractPosteriorMvNormal, Δ, opt) + return q.μ .+= Optimise.apply!(opt, q.μ, Δ) +end + +## Flattening and reconstruction methods + +function to_vec(q::Bijectors.TransformedDistribution) + return to_vec(q.dist) +end + +function to_dist(q::Bijectors.TransformedDistribution, θ::AbstractVector) + return transformed(to_dist(q.dist, θ), q.transform) +end + +include("cholmvnormal.jl") +include("diagmvnormal.jl") diff --git a/src/gradients.jl b/src/gradients.jl new file mode 100644 index 000000000..c2b1ebdcc --- /dev/null +++ b/src/gradients.jl @@ -0,0 +1,14 @@ + +## Implace implementation of gradient for ForwardDiff +function grad!( + diff_result::DiffResults.MutableDiffResult, + f, + x::AbstractArray, + alg::VariationalInference{<:ForwardDiffAD}, +) + chunk_size = getchunksize(typeof(alg)) + # Set chunk size and do ForwardMode. + chunk = ForwardDiff.Chunk(min(length(x), chunk_size)) + config = ForwardDiff.GradientConfig(f, x, chunk) + return ForwardDiff.gradient!(diff_result, f, x, config) +end diff --git a/src/interface.jl b/src/interface.jl new file mode 100644 index 000000000..de0c36550 --- /dev/null +++ b/src/interface.jl @@ -0,0 +1,208 @@ +""" + vi([rng::AbstractRNG, [vo::VariationalObjective]], model, alg::VariationalInference, q::VariationalPosterior; opt, hyperparams, opt_hyperparams)::VariationalPosterior + +Constructs the variational posterior from the `model` and performs the optimization +following the configuration of the given `VariationalInference` instance. +## Arguments +- `vo` : `VariationalObjective`, `ELBO()` by default +- `model`: `Turing.Model` or `Function` z ↦ log p(x, z) where `x` denotes the observations +- `alg`: the VI algorithm used +- `q`: a `VariationalPosterior` for which it is assumed a specialized implementation of the variational objective used exists. +## Keyword Arguments +- `opt` : Optimiser (from `Flux.Optimise`) used to update the variational parameters +- `hyperparams` : Hyperparameters, if different than nothing, `model(hyperparams)` will be called to obatin the logjoint +- `opt_hyperparams` : Optimiser for the Hyperparameters + + vi([rng::AbstractRNG, [vo::VariationalObjective]], model, alg::VariationalInference, q::Function, θ::AbstractVector; opt, hyperparams, opt_hyperparams)::AbstractVector + +Constructs the variational posterior from the `model` and performs the optimization +following the configuration of the given `VariationalInference` instance. +## Arguments +- `vo` : `VariationalObjective`, `ELBO()` by default +- `model`: `Turing.Model` or `Function` z ↦ log p(x, z) where `x` denotes the observations +- `alg`: the VI algorithm used +- `q`: a function creating a distribution from the parameters `θ` +- `θ`: the variational parameters +## Keyword Arguments +- `opt` : Optimiser (from `Flux.Optimise`) used to update the variational parameters +- `hyperparams` : Hyperparameters, if different than nothing, `model(hyperparams)` will be called to obatin the logjoint +- `opt_hyperparams` : Optimiser for the Hyperparameters + +""" +function vi( + rng::AbstractRNG, + vo::VariationalObjective, + model, + alg::VariationalInference, + q; + opt=TruncatedADAGrad(), + hyperparams=nothing, + opt_hyperparams=nothing, +) + θ, to_dist = flatten(q) + θ = vi( + rng, + vo, + alg, + to_dist, + θ, + model; + opt=opt, + hyperparams=hyperparams, + opt_hyperparams=opt_hyperparams, + ) + return to_dist(θ) +end + +function vi( + vo::VariationalObjective, + model, + alg::VariationalInference, + q; + opt=TruncatedADAGrad(), + hyperparams=nothing, + opt_hyperparams=nothing, +) + return vi( + GLOBAL_RNG, + vo, + model, + alg, + q; + opt=opt, + hyperparams=hyperparams, + opt_hyperparams=opt_hyperparams, + ) +end + +function vi( + model, + alg::VariationalInference, + q; + opt=TruncatedADAGrad(), + hyperparams=nothing, + opt_hyperparams=nothing, +) + return vi( + ELBO(), + model, + alg, + q; + opt=opt, + hyperparams=hyperparams, + opt_hyperparams=opt_hyperparams, + ) +end + +function vi( + rng::AbstractRNG, + vo::VariationalObjective, + model, + alg::VariationalInference, + q, + θ::AbstractVector; + opt=TruncatedADAGrad(), + hyperparams=nothing, + opt_hyperparams=nothing, +) + return optimize!( + rng, + vo, + alg, + q, + θ, + model; + opt=opt, + hyperparams=hyperparams, + opt_hyperparams=opt_hyperparams, + ) +end + +function vi( + vo::VariationalObjective, + model, + alg::VariationalInference, + q, + θ::AbstractVector; + opt=TruncatedADAGrad(), + hyperparams=nothing, + opt_hyperparams=nothing, +) + return vi( + GLOBAL_RNG, + vo, + model, + alg, + q, + θ; + opt=opt, + hyperparams=hyperparams, + opt_hyperparams=opt_hyperparams, + ) +end + +function vi( + model, + alg::VariationalInference, + q, + θ; + opt=TruncatedADAGrad(), + hyperparams=nothing, + opt_hyperparams=nothing, +) + return vi( + ELBO(), + model, + alg, + q, + θ; + opt=opt, + hyperparams=hyperparams, + opt_hyperparams=opt_hyperparams, + ) +end + + + +""" + optimize!([vo::VariationalObjective, [alg::VariationalInference{AD}, q::VariationalPosterior, model::Model], θ], ]; optimizer = TruncatedADAGrad()) + +Iteratively updates parameters by calling `grad!` and using the given `optimizer` to compute +the steps. +""" +function optimize!( + rng::AbstractRNG, + vo::VariationalObjective, + alg::VariationalInference, + to_dist, + θ, + model; + opt=TruncatedADAGrad(), + hyperparams=nothing, + opt_hyperparams=nothing, +) + max_iters = maxiters(alg) + + state = init(rng, alg, to_dist, θ, opt) # opt is there to be used in the future + + i = 0 + prog = if PROGRESS[] + ProgressMeter.Progress(max_iters, 1, "[$(alg_str(alg))] Optimizing...", 0) + else + 0 + end + + # add criterion? A running mean maybe? + time_elapsed = @elapsed while (i < max_iters) # & converged + logπ = makelogπ(model, hyperparams) + step!(rng, vo, alg, to_dist, θ, logπ, state, opt) + + # For debugging this would need to be updated somehow + # AdvancedVI.DEBUG && @debug "Step $i" Δ DiffResults.value(diff_result) + PROGRESS[] && (ProgressMeter.next!(prog)) + + i += 1 + end + + return θ +end \ No newline at end of file diff --git a/src/objectives.jl b/src/objectives.jl index db0046205..cc6f522de 100644 --- a/src/objectives.jl +++ b/src/objectives.jl @@ -1,9 +1,14 @@ -using Random: GLOBAL_RNG - struct ELBO <: VariationalObjective end -function (elbo::ELBO)(alg, q, logπ, num_samples; kwargs...) - return elbo(GLOBAL_RNG, alg, q, logπ, num_samples; kwargs...) +const FreeEnergy = ELBO + +## Generic evaluation of the free energy +function evaluate(::ELBO, alg, q, logπ) + return expec_logπ(alg, q, logπ) - entropy(alg, q) +end + +function elbo(alg, q, logπ) + return evaluate(ELBO(), alg, q, logπ) end -const elbo = ELBO() +elbo(alg, q, θ, logπ) = elbo(alg, q(θ), logπ) diff --git a/src/optimisers.jl b/src/optimisers.jl index 8077f98cb..816beb572 100644 --- a/src/optimisers.jl +++ b/src/optimisers.jl @@ -20,26 +20,24 @@ mutable struct TruncatedADAGrad eta::Float64 tau::Float64 n::Int - + iters::IdDict acc::IdDict end -function TruncatedADAGrad(η = 0.1, τ = 1.0, n = 100) - TruncatedADAGrad(η, τ, n, IdDict(), IdDict()) +function TruncatedADAGrad(η=0.1, τ=1.0, n=100) + return TruncatedADAGrad(η, τ, n, IdDict(), IdDict()) end function apply!(o::TruncatedADAGrad, x, Δ) T = eltype(Tracker.data(Δ)) - + η = o.eta τ = o.tau g² = get!( - o.acc, - x, - [zeros(T, size(x)) for j = 1:o.n] - )::Array{typeof(Tracker.data(Δ)), 1} + o.acc, x, [zeros(T, size(x)) for j in 1:(o.n)] + )::Array{typeof(Tracker.data(Δ)),1} i = get!(o.iters, x, 1)::Int # Example: suppose i = 12 and o.n = 10 @@ -50,10 +48,10 @@ function apply!(o::TruncatedADAGrad, x, Δ) # TODO: make more efficient and stable s = sum(g²) - + # increment o.iters[x] += 1 - + # TODO: increment (but "truncate") # o.iters[x] = i > o.n ? o.n + mod(i, o.n) : i + 1 @@ -82,11 +80,11 @@ mutable struct DecayedADAGrad acc::IdDict end -DecayedADAGrad(η = 0.1, pre = 1.0, post = 0.9) = DecayedADAGrad(η, pre, post, IdDict()) +DecayedADAGrad(η=0.1, pre=1.0, post=0.9) = DecayedADAGrad(η, pre, post, IdDict()) function apply!(o::DecayedADAGrad, x, Δ) T = eltype(Tracker.data(Δ)) - + η = o.eta acc = get!(o.acc, x, fill(T(ϵ), size(x)))::typeof(Tracker.data(x)) @. acc = o.post * acc + o.pre * Δ^2 diff --git a/src/utils.jl b/src/utils.jl new file mode 100644 index 000000000..e9fd21c22 --- /dev/null +++ b/src/utils.jl @@ -0,0 +1,16 @@ +makelogπ(logπ, ::Nothing) = logπ +makelogπ(model, hp) = model(hp) + +## Generic evaluation of the expectation +function expec_logπ(alg, q, logπ) + return mean(logπ, eachcol(rand(q, samples_per_step(alg)))) +end + +function evaluate(logπ, q::Bijectors.TransformedDistribution, x::AbstractVector) + z, logjac = forward(q.transform, x) + return logπ(z) + logjac +end + +function evaluate(logπ, ::Any, x::AbstractVector) + return logπ(x) +end diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 000000000..3d95b890c --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,9 @@ +[deps] +AdvancedVI = "b5ca4192-6429-45e5-a2d9-87aec30a685c" +Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/algorithms/advi.jl b/test/algorithms/advi.jl new file mode 100644 index 000000000..84f4e34a4 --- /dev/null +++ b/test/algorithms/advi.jl @@ -0,0 +1,19 @@ +@testset "advi" begin + using AdvancedVI: CholMvNormal, DiagMvNormal + ## Testing no transform + target = MvNormal(ones(2)) + xs = rand(target, 10) + logπ(z) = logpdf(target, z) + qs = [ + CholMvNormal(randn(2), LowerTriangular(randn(2, 2))), + DiagMvNormal(randn(2), randn(2)), + ] + advi = ADVI(10, 1000) + for q in qs + q = vi(logπ, advi, q; opt=ADAM(0.01)) + @test mean(abs2, logpdf(q, xs) - logpdf(target, xs)) ≤ 0.05 + end + + ## Testing with transform # TODO + +end diff --git a/test/algorithms/bbvi.jl b/test/algorithms/bbvi.jl new file mode 100644 index 000000000..abaaa3636 --- /dev/null +++ b/test/algorithms/bbvi.jl @@ -0,0 +1,19 @@ +@testset "bbvi" begin + using AdvancedVI: CholMvNormal, DiagMvNormal + ## Testing no transform + target = MvNormal(ones(2)) + xs = rand(target, 10) + logπ(z) = logpdf(target, z) + qs = [ + CholMvNormal(randn(2), LowerTriangular(randn(2, 2))), + DiagMvNormal(randn(2), randn(2)), + ] + advi = BBVI(10, 1000) + for q in qs + q = vi(logπ, advi, q; opt=ADAM(0.01)) + @test mean(abs2, logpdf(q, xs) - logpdf(target, xs)) ≤ 0.05 + end + + ## Testing with transform # TODO + +end diff --git a/test/distributions/cholmvnormal.jl b/test/distributions/cholmvnormal.jl new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/test/distributions/cholmvnormal.jl @@ -0,0 +1 @@ + diff --git a/test/distributions/diagmvnormal.jl b/test/distributions/diagmvnormal.jl new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/test/distributions/diagmvnormal.jl @@ -0,0 +1 @@ + diff --git a/test/distributions/distributions.jl b/test/distributions/distributions.jl new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/test/distributions/distributions.jl @@ -0,0 +1 @@ + diff --git a/test/gradients.jl b/test/gradients.jl new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/test/gradients.jl @@ -0,0 +1 @@ + diff --git a/test/interface.jl b/test/interface.jl new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/test/interface.jl @@ -0,0 +1 @@ + diff --git a/test/objectives.jl b/test/objectives.jl new file mode 100644 index 000000000..0d5ff1fdd --- /dev/null +++ b/test/objectives.jl @@ -0,0 +1,13 @@ +@testset "objectives" begin + using AdvancedVI: ELBO, FreeEnergy, VariationalObjective, evaluate + using AdvancedVI: elbo, entropy, expec_logπ + L = ELBO() + @test L isa VariationalObjective + @test L isa FreeEnergy + alg = ADVI(1000, 1) + q = AdvancedVI.CholMvNormal(zeros(2), LowerTriangular(diagm(ones(2)))) + logπ(x) = logpdf(MvNormal(ones(2)), x) + @test evaluate(L, alg, q, logπ) ≈ (expec_logπ(alg, q, logπ) - entropy(alg, q)) atol = + 1e0 + @test elbo(alg, q, logπ) ≈ evaluate(L, alg, q, logπ) atol = 1e0 +end diff --git a/test/optimisers.jl b/test/optimisers.jl index fae652ed0..f1bee4e5b 100644 --- a/test/optimisers.jl +++ b/test/optimisers.jl @@ -1,11 +1,9 @@ -using Random, Test, LinearAlgebra, ForwardDiff -using AdvancedVI: TruncatedADAGrad, DecayedADAGrad, apply! - -θ = randn(10, 10) @testset for opt in [TruncatedADAGrad(), DecayedADAGrad(1e-2)] + using AdvancedVI: TruncatedADAGrad, DecayedADAGrad, apply! + θ = randn(10, 10) θ_fit = randn(10, 10) - loss(x, θ_) = mean(sum(abs2, θ*x - θ_*x; dims = 1)) - for t = 1:10^4 + loss(x, θ_) = mean(sum(abs2, θ * x - θ_ * x; dims=1)) + for t in 1:(10^4) x = rand(10) Δ = ForwardDiff.gradient(θ_ -> loss(x, θ_), θ_fit) Δ = apply!(opt, θ_fit, Δ) @@ -14,4 +12,3 @@ using AdvancedVI: TruncatedADAGrad, DecayedADAGrad, apply! @test loss(rand(10, 100), θ_fit) < 0.01 @test length(opt.acc) == 1 end - diff --git a/test/runtests.jl b/test/runtests.jl index a305c25e5..b07464987 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,28 +1,27 @@ -using Test -using Distributions, DistributionsAD using AdvancedVI +using Bijectors +using Distributions +using Flux +using ForwardDiff +using LinearAlgebra +using Random +using Test -include("optimisers.jl") - -target = MvNormal(ones(2)) -logπ(z) = logpdf(target, z) -advi = ADVI(10, 1000) - -# Using a function z ↦ q(⋅∣z) -getq(θ) = TuringDiagMvNormal(θ[1:2], exp.(θ[3:4])) -q = vi(logπ, advi, getq, randn(4)) - -xs = rand(target, 10) -@test mean(abs2, logpdf(q, xs) - logpdf(target, xs)) ≤ 0.05 - -# OR: implement `update` and pass a `Distribution` -function AdvancedVI.update(d::TuringDiagMvNormal, θ::AbstractArray{<:Real}) - return TuringDiagMvNormal(θ[1:length(q)], exp.(θ[length(q) + 1:end])) +# include("optimisers.jl") + +@testset "AdvancedVI" begin + @testset "algorithms" begin + include(joinpath("algorithms", "advi.jl")) + include(joinpath("algorithms", "bbvi.jl")) + end + @testset "distributions" begin + include(joinpath("distributions", "distributions.jl")) + include(joinpath("distributions", "diagmvnormal.jl")) + include(joinpath("distributions", "cholmvnormal.jl")) + end + include("gradients.jl") + include("interface.jl") + # include("optimisers.jl") # Relying on Tracker... + include("objectives.jl") + include("utils.jl") end - -q0 = TuringDiagMvNormal(zeros(2), ones(2)) -q = vi(logπ, advi, q0, randn(4)) - -xs = rand(target, 10) -@test mean(abs2, logpdf(q, xs) - logpdf(target, xs)) ≤ 0.05 - diff --git a/test/utils.jl b/test/utils.jl new file mode 100644 index 000000000..475b61188 --- /dev/null +++ b/test/utils.jl @@ -0,0 +1,16 @@ +@testset "utils" begin + using AdvancedVI: makelogπ, evaluate + + f(x) = 2x + make_f(h) = f + @test makelogπ(f, nothing) == f + @test makelogπ(make_f, []) == f + + x = rand(2) + q = MvNormal(ones(2)) + q̂ = transformed(q, Bijectors.RadialLayer(2)) + logπ(x) = logpdf(q, x) + z, logj = forward(q̂.transform, x) + @test evaluate(logπ, q, x) == logπ(x) + @test evaluate(logπ, q̂, x) == logπ(z) + logj +end