Problem
julia> using Turing
julia> @model function buggy_model()
           lb ~ Uniform(0, 0.1)
           ub ~ Uniform(0.11, 0.2)
           x ~ transformed(Normal(0, 1), inverse(Bijectors.Logit(lb, ub)))
       end
buggy_model (generic function with 2 methods)
julia> model = buggy_model();
julia> chain = sample(model, NUTS(), 1000);
┌ Info: Found initial step size
└ ϵ = 3.2
Sampling 100%|█████████████████████████████████████████████████████████████████████████████| Time: 0:00:01
julia> results = generated_quantities(model, chain); # (×) Breaks!
ERROR: DomainError with -0.05206647177072762:
log was called with a negative real argument but will only return a complex result if called with a complex argument. Try log(Complex(x)).
DomainError detected in the user `f` function. This occurs when the domain of a function is violated.
For example, `log(-1.0)` is undefined because `log` of a real number is defined to only output real
numbers, but `log` of a negative number is complex valued and therefore Julia throws a DomainError
by default. Cases to be aware of include:
* `log(x)`, `sqrt(x)`, `cbrt(x)`, etc. where `x<0`
* `x^y` for `x<0` floating point `y` (example: `(-1.0)^(1/2) == im`)
...
In contrast, if we use `Prior` to sample, we're good:
julia> chain_prior = sample(model, Prior(), 1000);
Sampling 100%|█████████████████████████████████████████████████████████████████████████████| Time: 0:00:00
julia> results_prior = generated_quantities(model, chain_prior); # (✓) Works because no linking needed
The issue is caused by the fact that we use `DynamicPPL.invlink!!(varinfo, model)` when constructing a transition (this is what ends up in the chain), rather than by the inference itself.
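To make the failure mode concrete, here is a minimal, self-contained sketch (with made-up numbers, not values from the chain above) of the mismatch that produces the `DomainError`: a value is invlinked with one set of bounds, and the model is then re-evaluated, as `generated_quantities` does, against a different set:

```julia
using Bijectors, Distributions

# Bounds from the realization under which the transition was invlinked:
lb, ub = 0.02, 0.12
y = -2.0                                  # unconstrained (linked) value for x
x = inverse(Bijectors.Logit(lb, ub))(y)   # ≈ 0.032, inside (lb, ub): fine

# Bounds from a different realization of (lb, ub):
lb2, ub2 = 0.08, 0.19                     # now x < lb2, i.e. x is outside (lb2, ub2)

# Re-evaluating `x ~ transformed(Normal(0, 1), inverse(Bijectors.Logit(lb2, ub2)))`
# at this x computes log((x - lb2) / (ub2 - x)) with a negative argument:
logpdf(transformed(Normal(0, 1), inverse(Bijectors.Logit(lb2, ub2))), x)  # DomainError
```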
To confirm that the inference itself is fine, we can use AdvancedHMC.jl directly:
julia> using AdvancedHMC: AdvancedHMC
julia> f = DynamicPPL.LogDensityFunction(model);
julia> DynamicPPL.link!!(f.varinfo, f.model);
julia> chain_ahmc = sample(f, AdvancedHMC.NUTS(0.8), 1000);
[ Info: Found initial step size 3.2
Sampling 100%|███████████████████████████████| Time: 0:00:00
iterations: 1000
ratio_divergent_transitions: 0.0
ratio_divergent_transitions_during_adaption: 0.0
n_steps: 7
is_accept: true
acceptance_rate: 0.7879658455930968
log_density: -5.038135476673508
hamiltonian_energy: 7.775565727543868
hamiltonian_energy_error: -0.11294798909710124
max_hamiltonian_energy_error: 0.5539216379943772
tree_depth: 3
numerical_error: false
step_size: 1.1685229504528063
nom_step_size: 1.1685229504528063
is_adapt: false
mass_matrix: DiagEuclideanMetric([1.0, 1.0, 1.0])
Since this chain lives in the linked (unconstrained) space, we map it back to the constrained space by hand:
julia> function to_constrained(θ)
           lb = inverse(Bijectors.Logit(0.0, 0.1))(θ[1])
           ub = inverse(Bijectors.Logit(0.11, 0.2))(θ[2])
           x = inverse(Bijectors.Logit(lb, ub))(θ[3])
           return [lb, ub, x]
       end
to_constrained (generic function with 1 method)
julia> chain_ahmc_constrained = mapreduce(hcat, chain_ahmc) do t
           to_constrained(t.z.θ)
       end;
julia> chain_ahmc = Chains(
           permutedims(chain_ahmc_constrained),
           [:lb, :ub, :x]
       );
Visualizing the densities of the resulting chains, we also see that the one from `Turing.NUTS` is incorrect (the blue line), while the other two (`Prior` and `AdvancedHMC.NUTS`) coincide:
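The figure itself isn't reproduced here, but here is a sketch of how such a density comparison could be drawn (assuming StatsPlots is available; `chain`, `chain_prior`, and `chain_ahmc` are the objects constructed above):

```julia
using StatsPlots

# Compare the marginal densities of x across the three chains:
density(vec(chain[:x]); label = "Turing.NUTS")        # incorrect
density!(vec(chain_prior[:x]); label = "Prior")
density!(vec(chain_ahmc[:x]); label = "AdvancedHMC.NUTS")
```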
Solution?
Fixing this I think will actually be quite annoying 😕 But I do think it's worth doing.
There are a few approaches:
- Re-evaluate the model for every transition we end up accepting, to get the distributions corresponding to that particular realization.
- Double the memory usage of `VarInfo` and always store both the linked and the invlinked realizations.
- Use a separate context to capture the invlinked realizations.
No matter how we do this, there is the issue that we can't support this properly for `externalsampler`, etc., which use the `LogDensityFunction`, without explicitly re-evaluating the model 😕 Though it seems it would still be worth adding proper support for this in the "internal" impls of the samplers.
It might also be worth providing an option to force re-evaluation, combined with, say, a warning if we notice that the supports change between two different realizations.
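For this particular model, such a check is easy to write out by hand; the sketch below is only a hypothetical illustration of what the warning could look for (it is not an existing Turing/DynamicPPL option), run against the `chain` from the failing `Turing.NUTS` run above:

```julia
# The support of x is (lb, ub), so the stored realizations are mutually consistent
# exactly when x lies inside the interval implied by the same sample's lb and ub.
is_consistent(lb, ub, x) = lb < x < ub

for (lb, ub, x) in zip(vec(chain[:lb]), vec(chain[:ub]), vec(chain[:x]))
    is_consistent(lb, ub, x) ||
        @warn "Support of x changed between realizations; re-evaluation would be needed" lb ub x
end
```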