
Difficulty sampling a model with truncated normal Likelihood #1722


Description

@dlakelan
Contributor

I've been having problems sampling models that use truncated normal distributions.

This minimal working example from the discourse.julialang.org discussion shows that the initial vector passed in does not seem to be used. The first printout shows not the provided initial values, but values drawn with a very small error standard deviation, so the sampler immediately runs into numerical issues and reports zero probability at the initial point.

https://discourse.julialang.org/t/making-turing-fast-with-large-numbers-of-parameters/69072/99?u=dlakelan

using Pkg
Pkg.activate(".")
using Turing, DataFrames, DataFramesMeta, Distributions, DistributionsAD
using LazyArrays, ReverseDiff, Memoization

## every few hours a random staff member comes and gets a random
## patient to bring them outside to a garden through a door that has a
## scale. Sometimes using a wheelchair, sometimes not. knowing the
## total weight of the two people and the wheelchair plus some errors
## (from the scale measurements), infer the individual weights of all
## individuals and the weight of the wheelchair.

nstaff = 100
npat = 100
staffids = collect(1:nstaff)
patientids = collect(1:npat)
staffweights = rand(Normal(150,30),length(staffids))
patientweights = rand(Normal(150,30),length(patientids))
wheelchairwt = 15
nobs = 300

data = DataFrame(staff=rand(staffids,nobs),patient=rand(patientids,nobs))
data.usewch = rand(0:1,nobs)
data.totweights = [staffweights[data.staff[i]] + patientweights[data.patient[i]] for i in 1:nrow(data)] .+ data.usewch .* wheelchairwt .+ rand(Normal(0.0,20.0),nrow(data))


Turing.setadbackend(:reversediff)
Turing.setrdcache(true)
Turing.emptyrdcache()



@model function estweights(nstaff,staffid,npatients,patientid,usewch,totweight)
    wcwt ~ Gamma(20.0,15.0/19)
    staffweights ~ filldist(Normal(150,30),nstaff)
    patientweights ~ filldist(Normal(150,30),npatients)
    
    totweight ~ MvNormal(view(staffweights,staffid) .+ view(patientweights,patientid) .+ usewch .* wcwt,20.0)
end



@model function estweights2(nstaff,staffid,npatients,patientid,usewch,totweight)
    wcwt ~ Gamma(20.0,15.0/19)
    staffweights ~ filldist(truncated(Normal(150,30),90.0,Inf),nstaff)
    patientweights ~ filldist(truncated(Normal(150,30),90.0,Inf),npatients)
    
    totweight ~ arraydist([Gamma(15,(staffweights[staffid[i]] + patientweights[patientid[i]] + usewch[i] * wcwt)/14) for i in 1:length(totweight)])
end


@model function estweights3(nstaff,staffid,npatients,patientid,usewch,totweight)
    wcwt ~ Gamma(20.0,15.0/19)
    measerr ~ Gamma(10.0,20.0/9)
    staffweights ~ filldist(truncated(Normal(150,30),90.0,Inf),nstaff)
    patientweights ~ filldist(truncated(Normal(150,30),90.0,Inf),npatients)
    
    totweight ~ arraydist([truncated(Normal(staffweights[staffid[i]] + patientweights[patientid[i]] + usewch[i] * wcwt, measerr),0.0,Inf) for i in 1:length(totweight)])
end

function truncatenormal(a,b)::UnivariateDistribution
    truncated(Normal(a,b),0.0,Inf)
end


@model function estweights3lazy(nstaff,staffid,npatients,patientid,usewch,totweight)

    wcwt ~ Gamma(20.0,15.0/19)
    measerr ~ Gamma(10.0,20.0/9)
    staffweights ~ filldist(truncated(Normal(150,30),90.0,Inf),nstaff)
    patientweights ~ filldist(truncated(Normal(150,30),90.0,Inf),npatients)
    theta = LazyArray(@~ view(staffweights,staffid) .+ view(patientweights,patientid) .+ usewch .* wcwt)
    println("""Evaluating model... 
wcwt: $wcwt
measerr: $measerr
exstaffweights: $(staffweights[1:10])
expatweights: $(patientweights[1:10])
""")
    
    totweight ~ arraydist(LazyArray(@~ truncatenormal.(theta,measerr)))
end



@model function estweights4(nstaff,staffid,npatients,patientid,usewch,totweight)
    wcwt ~ Gamma(20.0,15.0/19)
    measerr ~ Gamma(10.0,20.0/9)
    staffweights ~ filldist(truncated(Normal(150,30),90.0,Inf),nstaff)
    patientweights ~ filldist(truncated(Normal(150,30),90.0,Inf),npatients)
    means = view(staffweights,staffid) .+ view(patientweights,patientid) .+ usewch .* wcwt
    totweight .~ Gamma.(12,means./11)
end




@model function estweightslazygamma(nstaff,staffid,npatients,patientid,usewch,totweight)
    wcwt ~ Gamma(20.0,15.0/19)
    measerr ~ Gamma(10.0,20.0/9)
    staffweights ~ filldist(truncated(Normal(150,30),90.0,Inf),nstaff)
    patientweights ~ filldist(truncated(Normal(150,30),90.0,Inf),npatients)
    theta = LazyArray(@~ view(staffweights,staffid) .+ view(patientweights,patientid) .+ usewch .* wcwt)
    totweight ~ arraydist(LazyArray(@~ Gamma.(15, theta ./ 14)))
end




# ch1 = sample(estweights(nstaff,data.staff,npat,data.patient,data.usewch,data.totweights),NUTS(500,.75),1000)


# ch2 = sample(estweights2(nstaff,data.staff,npat,data.patient,data.usewch,data.totweights),NUTS(500,.75),1000)

# ch3 = sample(estweights3(nstaff,data.staff,npat,data.patient,data.usewch,data.totweights),NUTS(500,.75),1000)

ch3l = sample(estweights3lazy(nstaff,data.staff,npat,data.patient,data.usewch,data.totweights),NUTS(500,.75,init_ϵ=.002),1000;
              init_theta = vcat([15.0,20.0],staffweights,patientweights))

# ch4 = sample(estweights4(nstaff,data.staff,npat,data.patient,data.usewch,data.totweights),NUTS(500,.75),1000)


#ch5 = sample(estweightslazygamma(nstaff,data.staff,npat,data.patient,data.usewch,data.totweights),NUTS(500,.75),1000)

When running this version, the initial vector does not appear to be used.

Activity

devmotion commented on Oct 27, 2021
Member

I think this issue is relevant here: #1588

Most importantly, it seems you use the keyword argument init_theta. However, as also discussed in the linked issue, initial parameter values have to be specified with init_params.
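For reference, a minimal sketch of passing a starting point through the `init_params` keyword, using a hypothetical toy model rather than the full example above (newer Turing releases spell this keyword `initial_params`):

```julia
using Turing

# Hypothetical toy model, for illustration only:
@model function demo(y)
    μ ~ Normal(0, 1)
    y ~ Normal(μ, 1)
end

# Supply the initial point via `init_params` (not `init_theta`):
chain = sample(demo(1.0), NUTS(100, 0.65), 200; init_params = [0.5], progress = false)
```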

dlakelan commented on Oct 28, 2021
Contributor, Author

Aha. Well, that's disappointing; I've been running pre-optimizations and following the docs to use those optimization results as the starting points...

Ok, so if I use init_params it does seem to evaluate at the initial vector, but then it complains as follows: it first prints a few evaluations with plain Reals, then suddenly switches to TrackedReals and complains about non-finite values. Specifically, it apparently thinks the parameter vector has gone non-finite, and so has the log(p), even though the printed parameter values I was checking are reasonable.

Evaluating model... 
wcwt: 0.37882044075836946
measerr: 1.2209697179924754
exstaffweights: [144.59651665320027, 139.26250118653655, 150.90239119517722, 171.0364350885989, 128.94644940913446, 116.41230791013996, 153.99758061107676, 158.27233225657514, 123.29445199120855, 130.95171040790126]
expatweights: [153.84183383480226, 117.09050452304707, 182.46653358442495, 144.2100653423684, 167.11682736617547, 121.57197068693307, 138.87488688049697, 167.95514160373503, 140.8483579031441, 168.57633784427617]

Evaluating model... 
wcwt: 15.0
measerr: 20.0
exstaffweights: [128.5120210513865, 188.86912870512919, 166.7190188565341, 144.91910704233473, 158.5787954759802, 193.11692512992863, 133.37311167397962, 111.20409894257753, 122.8398842784177, 137.08070752897862]
expatweights: [121.38295137055916, 118.59580388554596, 212.68126054123962, 157.15016895391787, 133.42273212231717, 189.0564931535782, 132.53391016542082, 158.18721536473618, 139.25611445597923, 143.00152296009045]

Evaluating model... 
wcwt: 15.0
measerr: 19.999999999999996
exstaffweights: [128.5120210513865, 188.86912870512919, 166.7190188565341, 144.9191070423347, 158.5787954759802, 193.11692512992857, 133.37311167397962, 111.20409894257753, 122.83988427841771, 137.08070752897862]
expatweights: [121.38295137055916, 118.59580388554596, 212.68126054123962, 157.15016895391787, 133.42273212231717, 189.05649315357823, 132.53391016542082, 158.1872153647362, 139.25611445597923, 143.00152296009045]

Evaluating model... 
wcwt: TrackedReal<ETu>(15.0, 0.0, D0C, ---)
measerr: TrackedReal<D5S>(19.999999999999996, 0.0, D0C, ---)
exstaffweights: ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}[TrackedReal<7Gz>(128.5120210513865, 0.0, D0C, 1, Cer), TrackedReal<1bK>(188.86912870512919, 0.0, D0C, 2, Cer), TrackedReal<DWX>(166.7190188565341, 0.0, D0C, 3, Cer), TrackedReal<Bs3>(144.9191070423347, 0.0, D0C, 4, Cer), TrackedReal<71A>(158.5787954759802, 0.0, D0C, 5, Cer), TrackedReal<FR7>(193.11692512992857, 0.0, D0C, 6, Cer), TrackedReal<1Af>(133.37311167397962, 0.0, D0C, 7, Cer), TrackedReal<LFq>(111.20409894257753, 0.0, D0C, 8, Cer), TrackedReal<3FK>(122.83988427841771, 0.0, D0C, 9, Cer), TrackedReal<FNM>(137.08070752897862, 0.0, D0C, 10, Cer)]
expatweights: ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}[TrackedReal<40n>(121.38295137055916, 0.0, D0C, 1, 7Ev), TrackedReal<66Y>(118.59580388554596, 0.0, D0C, 2, 7Ev), TrackedReal<C3o>(212.68126054123962, 0.0, D0C, 3, 7Ev), TrackedReal<FFa>(157.15016895391787, 0.0, D0C, 4, 7Ev), TrackedReal<IFo>(133.42273212231717, 0.0, D0C, 5, 7Ev), TrackedReal<1Dd>(189.05649315357823, 0.0, D0C, 6, 7Ev), TrackedReal<BJt>(132.53391016542082, 0.0, D0C, 7, 7Ev), TrackedReal<7kv>(158.1872153647362, 0.0, D0C, 8, 7Ev), TrackedReal<Kbf>(139.25611445597923, 0.0, D0C, 9, 7Ev), TrackedReal<PAj>(143.00152296009045, 0.0, D0C, 10, 7Ev)]

┌ Warning: The current proposal will be rejected due to numerical error(s).
│   isfinite.((θ, r, ℓπ, ℓκ)) = (false, true, false, true)
└ @ AdvancedHMC ~/.julia/packages/AdvancedHMC/HQHnm/src/hamiltonian.jl:47
┌ Warning: The current proposal will be rejected due to numerical error(s).
│   isfinite.((θ, r, ℓπ, ℓκ)) = (false, true, false, true)
└ @ AdvancedHMC ~/.julia/packages/AdvancedHMC/HQHnm/src/hamiltonian.jl:47
┌ Warning: The current proposal will be rejected due to numerical error(s).
│   isfinite.((θ, r, ℓπ, ℓκ)) = (false, false, false, false)
└ @ AdvancedHMC ~/.julia/packages/AdvancedHMC/HQHnm/src/hamiltonian.jl:47
┌ Warning: Incorrect ϵ = NaN; ϵ_previous = 0.2 is used instead.
└ @ AdvancedHMC.Adaptation ~/.julia/packages/AdvancedHMC/HQHnm/src/adaptation/stepsize.jl:125
┌ Warning: The current proposal will be rejected due to numerical error(s).
│   isfinite.((θ, r, ℓπ, ℓκ)) = (false, true, false, true)
└ @ AdvancedHMC ~/.julia/packages/AdvancedHMC/HQHnm/src/hamiltonian.jl:47
┌ Warning: The current proposal will be rejected due to numerical error(s).
│   isfinite.((θ, r, ℓπ, ℓκ)) = (false, false, false, false)
└ @ AdvancedHMC ~/.julia/packages/AdvancedHMC/HQHnm/src/hamiltonian.jl:47
┌ Warning: Incorrect ϵ = NaN; ϵ_previous = 0.2 is used instead.
└ @ AdvancedHMC.Adaptation ~/.julia/packages/AdvancedHMC/HQHnm/src/adaptation/stepsize.jl:125
┌ Warning: The current proposal will be rejected due to numerical error(s).
│   isfinite.((θ, r, ℓπ, ℓκ)) = (false, true, false, true)
└ @ AdvancedHMC ~/.julia/packages/AdvancedHMC/HQHnm/src/hamiltonian.jl:47
dlakelan changed the title from "Initial vector seemingly not used?" to "Difficulty sampling a model with truncated normal Likelihood" on Oct 28, 2021
noamsgl commented on Jun 1, 2023
torfjelde commented on Jun 1, 2023
Member

A few points on this issue:

  1. When using samplers such as NUTS, which have an initial "adaptation phase", the initial point is the first point used when adaptation starts, but once we start sampling proper, we might no longer be at this initial point.
  2. The complaints about θ do not necessarily mean the parameter itself is infinite; it could also be that the gradient is infinite (the current message does not distinguish the two, which is unfortunate 😕).
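Since the warning does not say whether it was the parameters or the gradient that went non-finite, one sanity check is to evaluate the joint log density at the candidate initial point before sampling. A minimal sketch, assuming DynamicPPL's `logjoint(model, ::NamedTuple)` (re-exported by Turing) and a hypothetical toy model:

```julia
using Turing  # brings DynamicPPL's logjoint into scope

# Hypothetical toy model, for illustration only:
@model function tiny(y)
    μ ~ Normal(0, 1)
    y ~ Normal(μ, 1)
end

m = tiny(2.0)
# Evaluate the joint log density at a candidate initial point; a -Inf or NaN
# here points at the model density itself rather than its gradient.
lp = logjoint(m, (μ = 0.5,))
```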
noamsgl commented on Jun 3, 2023

Thanks for the response.

  1. I copied and pasted the exact tutorial from the Turing.jl website, using all of the latest packages. How come we are getting this difference?

  2. In which way can I get a more informative error message, or go about debugging this in general?

torfjelde commented on Jun 22, 2023
Member

Sorry, this went under my radar!

I copied and pasted the exact tutorial from the Turing.jl website, using all of the latest packages. How come we are getting this difference?

Okay, that's weird. Will have a look.

In which way can I get a more informative error message, or go about debugging this in general?

Pfft.. This is a bit difficult without touching internals. But you're right, we should have a good way of debugging these things. Let me think and I'll get back to you!

penelopeysm commented on Jun 6, 2025
Member

I didn't run the full code, but I'm willing to bet (a small amount of) money that any remaining AD issues are caused by the use of truncated(dist, x, Inf). Replacing this with truncated(dist; lower=x) should fix the NaN gradients. Please ping if it's still a problem, and I'll be happy to reopen.
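A minimal sketch of the suggested replacement, using only Distributions.jl; the two forms define the same distribution, and only the AD behaviour of the explicit-`Inf` form was problematic:

```julia
using Distributions

# Two-sided call with an explicit Inf upper bound (the problematic form):
d_inf = truncated(Normal(150, 30), 90.0, Inf)

# One-sided truncation via the keyword form suggested above:
d_lower = truncated(Normal(150, 30); lower = 90.0)

# The log densities agree:
logpdf(d_inf, 120.0) ≈ logpdf(d_lower, 120.0)
```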

Difficulty sampling a model with truncated normal Likelihood · Issue #1722 · TuringLang/Turing.jl