
Issues with constrained parameters depending on each other #2195

@torfjelde

Description


Problem

julia> using Turing

julia> @model function buggy_model()
           lb ~ Uniform(0, 0.1)
           ub ~ Uniform(0.11, 0.2)
           x ~ transformed(Normal(0, 1), inverse(Bijectors.Logit(lb, ub)))
       end

buggy_model (generic function with 2 methods)

julia> model = buggy_model();

julia> chain = sample(model, NUTS(), 1000);
┌ Info: Found initial step size
└   ϵ = 3.2
Sampling 100%|█████████████████████████████████████████████████████████████████████████████| Time: 0:00:01

julia> results = generated_quantities(model, chain); # (×) Breaks!
ERROR: DomainError with -0.05206647177072762:
log was called with a negative real argument but will only return a complex result if called with a complex argument. Try log(Complex(x)).
DomainError detected in the user `f` function. This occurs when the domain of a function is violated.
For example, `log(-1.0)` is undefined because `log` of a real number is defined to only output real
numbers, but `log` of a negative number is complex valued and therefore Julia throws a DomainError
by default. Cases to be aware of include:

* `log(x)`, `sqrt(x)`, `cbrt(x)`, etc. where `x<0`
* `x^y` for `x<0` floating point `y` (example: `(-1.0)^(1/2) == im`)
...

In contrast, if we use Prior to sample, we're good:

julia> chain_prior = sample(model, Prior(), 1000);

Sampling 100%|█████████████████████████████████████████████████████████████████████████████| Time: 0:00:00

julia> results_prior = generated_quantities(model, chain_prior); # (✓) Works because no linking needed

The issue is caused by the fact that we use DynamicPPL.invlink!!(varinfo, model) when constructing a transition (whose values are what end up in the chain), rather than by a problem with the inference itself.
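To see why this breaks downstream, here is a minimal pure-Julia sketch (the `logit`/`invlogit` helpers are illustrative stand-ins for Bijectors.Logit and its inverse; no Turing dependency): if the stored x was invlinked under one realization of (lb, ub) but later paired with a different realization, re-linking it hits log of a negative number, which is exactly the DomainError above.

```julia
# Minimal reproduction of the failure mode, with illustrative stand-ins for
# Bijectors.Logit and its inverse (no Turing dependency):
logit(x, lb, ub)    = log((x - lb) / (ub - x))         # link transform
invlogit(y, lb, ub) = lb + (ub - lb) / (1 + exp(-y))   # invlink transform

y = -1.0                        # unconstrained draw for x
x = invlogit(y, 0.05, 0.15)     # invlinked with one realization of (lb, ub)

# If the stored chain pairs this x with a *different* realization of the
# bounds, x falls outside (lb, ub), and re-linking takes log of a negative:
lb2, ub2 = 0.09, 0.18
@assert x < lb2                 # x ≈ 0.0769, outside the new support
# logit(x, lb2, ub2)            # would throw DomainError, as in the report
```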

For example, if we use AdvancedHMC.jl directly:

julia> using AdvancedHMC: AdvancedHMC

julia> f = DynamicPPL.LogDensityFunction(model);

julia> DynamicPPL.link!!(f.varinfo, f.model);

julia> chain_ahmc = sample(f, AdvancedHMC.NUTS(0.8), 1000);
[ Info: Found initial step size 3.2
Sampling 100%|███████████████████████████████| Time: 0:00:00
  iterations:                                   1000
  ratio_divergent_transitions:                  0.0
  ratio_divergent_transitions_during_adaption:  0.0
  n_steps:                                      7
  is_accept:                                    true
  acceptance_rate:                              0.7879658455930968
  log_density:                                  -5.038135476673508
  hamiltonian_energy:                           7.775565727543868
  hamiltonian_energy_error:                     -0.11294798909710124
  max_hamiltonian_energy_error:                 0.5539216379943772
  tree_depth:                                   3
  numerical_error:                              false
  step_size:                                    1.1685229504528063
  nom_step_size:                                1.1685229504528063
  is_adapt:                                     false
  mass_matrix:                                  DiagEuclideanMetric([1.0, 1.0, 1.0])

julia> function to_constrained(θ)
           lb = inverse(Bijectors.Logit(0.0, 0.1))(θ[1])
           ub = inverse(Bijectors.Logit(0.11, 0.2))(θ[2])
           x = inverse(Bijectors.Logit(lb, ub))(θ[3])
           return [lb, ub, x]
       end
to_constrained (generic function with 1 method)

julia> chain_ahmc_constrained = mapreduce(hcat, chain_ahmc) do t
           to_constrained(t.z.θ)
       end;

julia> chain_ahmc = Chains(
           permutedims(chain_ahmc_constrained),
           [:lb, :ub, :x]
       );

Visualizing the densities of the resulting chains, we also see that the one from Turing.NUTS is incorrect (the blue line), while the other two (Prior and AdvancedHMC.NUTS) coincide:

[image: overlaid density plots of lb, ub, and x for the three chains]

Solution?

Fixing this I think will actually be quite annoying 😕 But I do think it's worth doing.

There are a few approaches:

  1. Re-evaluate the model for every transition we end up accepting to get the distributions corresponding to that particular realization.
  2. Double the memory usage of VarInfo and always store both the linked and the invlinked realizations.
  3. Use a separate context to capture the invlinked realizations.
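For concreteness, approach 1 might look roughly like the following dependency-free sketch (all names here are mine, invlogit is a stand-in for inverse(Bijectors.Logit(lb, ub)), and a real implementation would live inside DynamicPPL's model evaluation): the transforms are rebuilt front to back for each accepted draw, so x's bounds always come from the same realization.

```julia
# Hypothetical sketch of approach 1 (names illustrative, not DynamicPPL
# internals): re-run the "model" front to back for each accepted draw so
# every transform is built from this realization's own values.
invlogit(y, lb, ub) = lb + (ub - lb) / (1 + exp(-y))  # stand-in invlink

function reevaluate_constrained(make_bounds_in_order, θ)
    vals = Float64[]
    for (i, make_bounds) in enumerate(make_bounds_in_order)
        lb, ub = make_bounds(vals)         # bounds may depend on earlier values
        push!(vals, invlogit(θ[i], lb, ub))
    end
    return vals
end

# For buggy_model: lb has support (0, 0.1), ub has (0.11, 0.2), and x is
# constrained to the (lb, ub) of the *same* realization:
bounds = [v -> (0.0, 0.1), v -> (0.11, 0.2), v -> (v[1], v[2])]
vals = reevaluate_constrained(bounds, [-1.0, 0.5, 2.0])
@assert vals[1] < vals[3] < vals[2]        # x always lies inside (lb, ub)
```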

No matter how we do this, there is the issue that we can't properly support this for externalsampler, etc., which use the LogDensityFunction, without explicitly re-evaluating the model 😕 Still, it seems worth adding proper support for this in the "internal" implementations of the samplers.

It might be worth providing an option to force re-evaluation, combined with, say, a warning if we notice that supports change between two different realizations.
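That warning could take a shape like the following sketch (supports_changed and the tuple representation of supports are illustrative assumptions, not existing API):

```julia
# Hypothetical sketch of the proposed warning (names illustrative): detect a
# change of support between consecutive realizations of (lb, ub).
supports_changed(bounds) = any(bounds[i] != bounds[i - 1] for i in 2:length(bounds))

# Supports of x recorded at each realization; the last one moved:
realized_bounds = [(0.05, 0.15), (0.05, 0.15), (0.09, 0.18)]
if supports_changed(realized_bounds)
    @warn "Support of x changed between realizations; forcing re-evaluation."
end
```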

@yebai @devmotion @sunxd3
