
Basic rewrite of the package #25


Closed · wants to merge 29 commits into from
Changes from all commits (29 commits)
8449d29
First rewrite pass with addition of DSVI
theogf Feb 12, 2021
0e3b798
Better handling of Tracker
theogf Feb 12, 2021
33948e6
Correct entropy implementation
theogf Feb 12, 2021
06317ca
Moved back to ADVI definition and readded objectives
theogf Feb 17, 2021
fd657d5
More updates on the interface
theogf Feb 17, 2021
603b5ae
Corrections interface
theogf Feb 18, 2021
ebf5cfa
Remove manifest
theogf Feb 18, 2021
5f711e6
Corrected approach of BBVI
theogf Feb 18, 2021
62f9f26
Working version of BBVI
theogf Feb 18, 2021
b7390ac
Update CI
theogf Feb 18, 2021
1c7045e
Fixed test and issues
theogf Feb 18, 2021
8104a5c
Reorganized file structure and tests
theogf Feb 18, 2021
26f33e1
Adding more tests
theogf Feb 18, 2021
58fde04
Cleaned up tests more
theogf Feb 18, 2021
dfb4351
Fixing tests
theogf Feb 20, 2021
22b024b
Removing unneeded parts distributions
theogf Feb 20, 2021
b01cdce
Relaxing objectives test
theogf Feb 20, 2021
5343a2f
Adapted grad! for others AD
theogf Feb 20, 2021
59834d4
Correct inheritance
theogf Feb 20, 2021
f58970b
Fixing versioning
theogf Feb 22, 2021
e3ad352
Back to unintuitive ad backend naming
theogf Apr 12, 2021
633fa0e
Addressing comments
theogf Apr 12, 2021
094e30c
Removed XXt
theogf Jul 29, 2021
d290fa0
Merge branch 'tg/rework_advi' of https://github.com/TuringLang/Advanc…
theogf Jul 29, 2021
ec4f84e
Passing now a RNG through
theogf Jul 29, 2021
3a0fb4b
Formatted everything with bluestyle
theogf Jul 29, 2021
a638e12
Fixed makelogpi
theogf Jul 29, 2021
3718978
Remade the interface to take function plus arguments
theogf Aug 24, 2021
9202f24
Initialize clean up
theogf Aug 31, 2021
1 change: 1 addition & 0 deletions .JuliaFormatter.toml
@@ -0,0 +1 @@
style = "blue"
36 changes: 11 additions & 25 deletions .github/workflows/CI.yml
@@ -1,37 +1,27 @@
name: CI

on:
push:
branches:
- master
pull_request:

jobs:
test:
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }}
runs-on: ${{ matrix.os }}
continue-on-error: ${{ matrix.version == 'nightly' }}
strategy:
fail-fast: false
matrix:
version:
- '1'
- '1.3'
- 'nightly'
os:
- ubuntu-latest
- macOS-latest
- windows-latest
arch:
- x64
- x86
exclude:
- os: macOS-latest
arch: x86
include:
- version: '1.0'
os: ubuntu-latest
arch: x64
- os: ubuntu-latest
version: '1'
arch: x64
coverage: true
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
@@ -48,16 +38,12 @@ jobs:
${{ runner.os }}-test-${{ env.cache-name }}-
${{ runner.os }}-test-
${{ runner.os }}-
- uses: julia-actions/julia-buildpkg@latest
- uses: julia-actions/julia-runtest@latest
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
- uses: julia-actions/julia-processcoverage@v1
if: matrix.coverage
- uses: codecov/codecov-action@v1
if: matrix.coverage
- name: Send coverage
if: matrix.version == '1' && matrix.os == 'ubuntu-latest'
uses: coverallsapp/github-action@master
with:
file: lcov.info
- uses: coverallsapp/github-action@master
if: matrix.coverage
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
path-to-lcov: lcov.info
github-token: ${{ secrets.GITHUB_TOKEN }}
path-to-lcov: ./lcov.info
2 changes: 2 additions & 0 deletions .gitignore
@@ -0,0 +1,2 @@
test/Manifest.toml
.vscode
4 changes: 2 additions & 2 deletions Project.toml
@@ -7,14 +7,15 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"

[compat]
Bijectors = "0.4.0, 0.5, 0.6, 0.7, 0.8"
@@ -26,7 +27,6 @@ ProgressMeter = "1.0.0"
Requires = "0.5, 1.0"
StatsBase = "0.32, 0.33"
StatsFuns = "0.8, 0.9"
Tracker = "0.2.3"
julia = "1"

[extras]
224 changes: 28 additions & 196 deletions src/AdvancedVI.jl
@@ -1,238 +1,70 @@
module AdvancedVI

using Random: AbstractRNG
using Bijectors: GLOBAL_RNG
using Random: AbstractRNG, GLOBAL_RNG

using Distributions, DistributionsAD, Bijectors
using Bijectors
using Distributions
using DistributionsAD
using DocStringExtensions

using ProgressMeter, LinearAlgebra

using ForwardDiff
using Tracker
using Flux: Optimise # Temp before Optimisers.jl is registered
using Functors
using ProgressMeter, LinearAlgebra
using Random
using Requires

const PROGRESS = Ref(true)
function turnprogress(switch::Bool)
@info("[AdvancedVI]: global PROGRESS is set as $switch")
PROGRESS[] = switch
return PROGRESS[] = switch
end

const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_ADVANCEDVI", "0")))

include("ad.jl")

using Requires
function __init__()
@require Flux="587475ba-b771-5e3f-ad9e-33799f191a9c" begin
@require Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" begin
apply!(o, x, Δ) = Flux.Optimise.apply!(o, x, Δ)
Flux.Optimise.apply!(o::TruncatedADAGrad, x, Δ) = apply!(o, x, Δ)
Flux.Optimise.apply!(o::DecayedADAGrad, x, Δ) = apply!(o, x, Δ)
end
@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
include("compat/zygote.jl")
export ZygoteAD

function AdvancedVI.grad!(
vo,
alg::VariationalInference{<:AdvancedVI.ZygoteAD},
q,
model,
θ::AbstractVector{<:Real},
out::DiffResults.MutableDiffResult,
args...
)
f(θ) = if (q isa Distribution)
- vo(alg, update(q, θ), model, args...)
else
- vo(alg, q(θ), model, args...)
end
y, back = Zygote.pullback(f, θ)
dy = first(back(1.0))
DiffResults.value!(out, y)
DiffResults.gradient!(out, dy)
return out
end
include(joinpath("compat", "zygote.jl"))
end
@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
include("compat/reversediff.jl")
export ReverseDiffAD

function AdvancedVI.grad!(
vo,
alg::VariationalInference{<:AdvancedVI.ReverseDiffAD{false}},
q,
model,
θ::AbstractVector{<:Real},
out::DiffResults.MutableDiffResult,
args...
)
f(θ) = if (q isa Distribution)
- vo(alg, update(q, θ), model, args...)
else
- vo(alg, q(θ), model, args...)
end
tp = AdvancedVI.tape(f, θ)
ReverseDiff.gradient!(out, tp, θ)
return out
end
include(joinpath("compat", "reversediff.jl"))
end
@require Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin
include(joinpath("compat", "tracker.jl"))
end
end

export
vi,
ADVI,
ELBO,
elbo,
TruncatedADAGrad,
DecayedADAGrad,
VariationalInference
export vi, ADVI, BBVI, ELBO, TruncatedADAGrad, DecayedADAGrad, VariationalInference

abstract type VariationalInference{AD} end

getchunksize(::Type{<:VariationalInference{AD}}) where AD = getchunksize(AD)
getADtype(::VariationalInference{AD}) where AD = AD
getchunksize(::Type{<:VariationalInference{AD}}) where {AD} = getchunksize(AD)
getADtype(::VariationalInference{AD}) where {AD} = AD

abstract type VariationalObjective end

const VariationalPosterior = Distribution{Multivariate, Continuous}


"""
grad!(vo, alg::VariationalInference, q, model::Model, θ, out, args...)

Computes the gradients used in `optimize!`. A default implementation is provided for
`VariationalInference{AD}` where `AD` is either `ForwardDiffAD` or `TrackerAD`.
This implicitly also gives a default implementation of `optimize!`.

Variance reduction techniques, e.g. control variates, should be implemented in this function.
"""
function grad! end
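# Illustration only (not part of this diff): the contract is that each backend-specific
# `grad!` method fills `out` with the value and gradient of θ ↦ -vo(alg, q(θ), model, args...),
# which `optimize!` then reads back, e.g. (assuming `vo`, `alg`, `q`, `model`, `θ` are bound):
#
#     out = DiffResults.GradientResult(θ)
#     grad!(vo, alg, q, model, θ, out)
#     DiffResults.value(out)     # value of the negated objective at θ
#     DiffResults.gradient(out)  # gradient used for the parameter update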

"""
vi(model, alg::VariationalInference)
vi(model, alg::VariationalInference, q::VariationalPosterior)
vi(model, alg::VariationalInference, getq::Function, θ::AbstractArray)

Constructs the variational posterior from the `model` and performs the optimization
following the configuration of the given `VariationalInference` instance.

# Arguments
- `model`: `Turing.Model` or `Function` z ↦ log p(x, z) where `x` denotes the observations
- `alg`: the VI algorithm used
- `q`: a `VariationalPosterior` for which a specialized implementation of the variational objective is assumed to exist.
- `getq`: function taking parameters `θ` as input and returning a `VariationalPosterior`
- `θ`: only required if `getq` is used, in which case it provides the initial parameters for the variational posterior
"""
function vi end
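# Usage sketch (illustration only, not part of this diff). Assumes `ADVI(samples_per_step,
# max_iters)` as the algorithm and `TuringDiagMvNormal` from DistributionsAD as the
# variational family; `logπ`, `getq` and `d` are hypothetical names:
#
#     using AdvancedVI, DistributionsAD
#     d = 5
#     logπ(z) = -sum(abs2, z) / 2                                  # z ↦ log p(x, z) up to a constant
#     getq(θ) = TuringDiagMvNormal(θ[1:d], exp.(θ[(d + 1):end]))   # mean-field Gaussian
#     q = vi(logπ, ADVI(10, 1_000), getq, randn(2d))               # optimised variational posterior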

function update end

# default implementations
function grad!(
vo,
alg::VariationalInference{<:ForwardDiffAD},
q,
model,
θ::AbstractVector{<:Real},
out::DiffResults.MutableDiffResult,
args...
)
f(θ_) = if (q isa Distribution)
- vo(alg, update(q, θ_), model, args...)
else
- vo(alg, q(θ_), model, args...)
end

chunk_size = getchunksize(typeof(alg))
# Set chunk size and do ForwardMode.
chunk = ForwardDiff.Chunk(min(length(θ), chunk_size))
config = ForwardDiff.GradientConfig(f, θ, chunk)
ForwardDiff.gradient!(out, f, θ, config)
end
const VariationalPosterior = Distribution{Multivariate,Continuous}

function grad!(
vo,
alg::VariationalInference{<:TrackerAD},
q,
model,
θ::AbstractVector{<:Real},
out::DiffResults.MutableDiffResult,
args...
)
θ_tracked = Tracker.param(θ)
y = if (q isa Distribution)
- vo(alg, update(q, θ_tracked), model, args...)
else
- vo(alg, q(θ_tracked), model, args...)
end
Tracker.back!(y, 1.0)

DiffResults.value!(out, Tracker.data(y))
DiffResults.gradient!(out, Tracker.grad(θ_tracked))
end
# Custom distributions
include(joinpath("distributions", "distributions.jl"))


"""
optimize!(vo, alg::VariationalInference{AD}, q::VariationalPosterior, model::Model, θ; optimizer = TruncatedADAGrad())

Iteratively updates parameters by calling `grad!` and using the given `optimizer` to compute
the steps.
"""
function optimize!(
vo,
alg::VariationalInference,
q,
model,
θ::AbstractVector{<:Real};
optimizer = TruncatedADAGrad()
)
# TODO: should we always assume `samples_per_step` and `max_iters` for all algos?
alg_name = alg_str(alg)
samples_per_step = alg.samples_per_step
max_iters = alg.max_iters

num_params = length(θ)

# TODO: really need a better way to warn the user about potentially
# not using the correct accumulator
if (optimizer isa TruncatedADAGrad) && (θ ∉ keys(optimizer.acc))
# this message should only occur once in the optimization process
@info "[$alg_name] Should only be seen once: optimizer created for θ" objectid(θ)
end

diff_result = DiffResults.GradientResult(θ)

i = 0
prog = if PROGRESS[]
ProgressMeter.Progress(max_iters, 1, "[$alg_name] Optimizing...", 0)
else
0
end

# add criterion? A running mean maybe?
time_elapsed = @elapsed while (i < max_iters) # & converged
grad!(vo, alg, q, model, θ, diff_result, samples_per_step)

# apply update rule
Δ = DiffResults.gradient(diff_result)
Δ = apply!(optimizer, θ, Δ)
@. θ = θ - Δ

AdvancedVI.DEBUG && @debug "Step $i" Δ DiffResults.value(diff_result)
PROGRESS[] && (ProgressMeter.next!(prog))

i += 1
end

return θ
end
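# Usage sketch (illustration only, not part of this diff). `advi`, `getq`, `model` and
# `num_params` are hypothetical, and `ELBO()` is assumed to be the objective instance
# used by `vi`; `θ0` is updated in place and also returned:
#
#     θ0 = randn(num_params)
#     optimize!(ELBO(), advi, getq, model, θ0; optimizer = DecayedADAGrad())
#     q_opt = getq(θ0)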
include("utils.jl")

# objectives
include("objectives.jl")

# optimisers
include("gradients.jl")
include("interface.jl")
include("optimisers.jl")

# VI algorithms
include("advi.jl")
include(joinpath("algorithms", "advi.jl"))
include(joinpath("algorithms", "bbvi.jl"))

end # module