Closed
Description
I've been having problems sampling models that use truncated normal distributions.
This minimal worked example from the discourse.julialang.org discussion shows that the initial vector passed in seems to be the problem. The first printout does not show the provided initial value. Instead it shows a value with a very small standard deviation of the measurement errors (`measerr`), so the model immediately runs into numerical issues and reports zero probability at the initial point.
```julia
using Pkg
Pkg.activate(".")
using Turing, DataFrames, DataFramesMeta, LazyArrays, Distributions, DistributionsAD
using ReverseDiff, Memoization

## Every few hours a random staff member takes a random patient outside
## to a garden through a door that has a scale, sometimes using a
## wheelchair and sometimes not. Knowing the total weight of the two people
## and the wheelchair plus some (scale) measurement error, infer the
## individual weights of all individuals and the weight of the wheelchair.
nstaff = 100
npat = 100
staffids = collect(1:nstaff)
patientids = collect(1:npat)
staffweights = rand(Normal(150, 30), length(staffids))
patientweights = rand(Normal(150, 30), length(patientids))
wheelchairwt = 15
nobs = 300
data = DataFrame(staff=rand(staffids, nobs), patient=rand(patientids, nobs))
data.usewch = rand(0:1, nobs)
data.totweights = [staffweights[data.staff[i]] + patientweights[data.patient[i]] for i in 1:nrow(data)] .+
                  data.usewch .* wheelchairwt .+ rand(Normal(0.0, 20.0), nrow(data))

Turing.setadbackend(:reversediff)
Turing.setrdcache(true)
Turing.emptyrdcache()

@model function estweights(nstaff, staffid, npatients, patientid, usewch, totweight)
    wcwt ~ Gamma(20.0, 15.0/19)
    staffweights ~ filldist(Normal(150, 30), nstaff)
    patientweights ~ filldist(Normal(150, 30), npatients)
    totweight ~ MvNormal(view(staffweights, staffid) .+ view(patientweights, patientid) .+ usewch .* wcwt, 20.0)
end

@model function estweights2(nstaff, staffid, npatients, patientid, usewch, totweight)
    wcwt ~ Gamma(20.0, 15.0/19)
    staffweights ~ filldist(truncated(Normal(150, 30), 90.0, Inf), nstaff)
    patientweights ~ filldist(truncated(Normal(150, 30), 90.0, Inf), npatients)
    totweight ~ arraydist([Gamma(15, (staffweights[staffid[i]] + patientweights[patientid[i]] + usewch[i] * wcwt)/14) for i in 1:length(totweight)])
end

@model function estweights3(nstaff, staffid, npatients, patientid, usewch, totweight)
    wcwt ~ Gamma(20.0, 15.0/19)
    measerr ~ Gamma(10.0, 20.0/9)
    staffweights ~ filldist(truncated(Normal(150, 30), 90.0, Inf), nstaff)
    patientweights ~ filldist(truncated(Normal(150, 30), 90.0, Inf), npatients)
    totweight ~ arraydist([truncated(Normal(staffweights[staffid[i]] + patientweights[patientid[i]] + usewch[i] * wcwt, measerr), 0.0, Inf) for i in 1:length(totweight)])
end

function truncatenormal(a, b)::UnivariateDistribution
    truncated(Normal(a, b), 0.0, Inf)
end

@model function estweights3lazy(nstaff, staffid, npatients, patientid, usewch, totweight)
    wcwt ~ Gamma(20.0, 15.0/19)
    measerr ~ Gamma(10.0, 20.0/9)
    staffweights ~ filldist(truncated(Normal(150, 30), 90.0, Inf), nstaff)
    patientweights ~ filldist(truncated(Normal(150, 30), 90.0, Inf), npatients)
    theta = LazyArray(@~ view(staffweights, staffid) .+ view(patientweights, patientid) .+ usewch .* wcwt)
    # Print the current parameter values each time the model is evaluated
    println("""Evaluating model...
    wcwt: $wcwt
    measerr: $measerr
    exstaffweights: $(staffweights[1:10])
    expatweights: $(patientweights[1:10])
    """)
    totweight ~ arraydist(LazyArray(@~ truncatenormal.(theta, measerr)))
end

@model function estweights4(nstaff, staffid, npatients, patientid, usewch, totweight)
    wcwt ~ Gamma(20.0, 15.0/19)
    measerr ~ Gamma(10.0, 20.0/9)
    staffweights ~ filldist(truncated(Normal(150, 30), 90.0, Inf), nstaff)
    patientweights ~ filldist(truncated(Normal(150, 30), 90.0, Inf), npatients)
    means = view(staffweights, staffid) .+ view(patientweights, patientid) .+ usewch .* wcwt
    totweight .~ Gamma.(12, means ./ 11)
end

@model function estweightslazygamma(nstaff, staffid, npatients, patientid, usewch, totweight)
    wcwt ~ Gamma(20.0, 15.0/19)
    measerr ~ Gamma(10.0, 20.0/9)
    staffweights ~ filldist(truncated(Normal(150, 30), 90.0, Inf), nstaff)
    patientweights ~ filldist(truncated(Normal(150, 30), 90.0, Inf), npatients)
    theta = LazyArray(@~ view(staffweights, staffid) .+ view(patientweights, patientid) .+ usewch .* wcwt)
    totweight ~ arraydist(LazyArray(@~ Gamma.(15, theta ./ 14)))
end

# ch1 = sample(estweights(nstaff,data.staff,npat,data.patient,data.usewch,data.totweights),NUTS(500,.75),1000)
# ch2 = sample(estweights2(nstaff,data.staff,npat,data.patient,data.usewch,data.totweights),NUTS(500,.75),1000)
# ch3 = sample(estweights3(nstaff,data.staff,npat,data.patient,data.usewch,data.totweights),NUTS(500,.75),1000)
ch3l = sample(estweights3lazy(nstaff, data.staff, npat, data.patient, data.usewch, data.totweights),
              NUTS(500, .75, init_ϵ=.002), 1000;
              init_theta = vcat([15.0, 20.0], staffweights, patientweights))
# ch4 = sample(estweights4(nstaff,data.staff,npat,data.patient,data.usewch,data.totweights),NUTS(500,.75),1000)
# ch5 = sample(estweightslazygamma(nstaff,data.staff,npat,data.patient,data.usewch,data.totweights),NUTS(500,.75),1000)
```
When running this version, the initial
devmotion commented on Oct 27, 2021
I think this issue is relevant here: #1588
Most importantly, it seems you use the keyword argument `init_theta`. However, as also discussed in the linked issue, initial parameter values have to be specified with `init_params`.
dlakelan commented on Oct 28, 2021
Aha. Well, that's disappointing: I've been running pre-optimizations and, following the docs, using their results as the starting points...
OK, so if I use `init_params` it does seem to evaluate at the initial vector, but then it fails: it first prints a few reals, then suddenly switches to `TrackedReal`s and complains about non-finite values. Specifically, it apparently thinks the parameter vector has become non-finite, and so has log(p), even though the printed values of the parameters I was checking are reasonable.
The title was changed from "Initial vector seemingly not used?" to "Difficulty sampling a model with truncated normal Likelihood".
noamsgl commented on Jun 1, 2023
I'm getting similar errors, but with the tutorial at https://turing.ml/dev/tutorials/10-bayesian-differential-equations/
🎈 sde_lv.jl — Pluto.jl.pdf
torfjelde commented on Jun 1, 2023
A few points on this issue:
- A warning about a non-finite `θ` is not necessarily because the parameter itself is infinite; it could also be that the gradient is infinite (the current message does not say which, which is unfortunate 😕).
noamsgl commented on Jun 3, 2023
Thanks for the response.
I copied and pasted the exact tutorial from the Turing.jl website, using the latest versions of all packages. Why are we getting this difference?
How can I get a more informative error message, or how should I go about debugging this in general?
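One way to tell apart the cases torfjelde mentions (non-finite parameters, non-finite log density, or a non-finite gradient) is to check each one separately at the failing point. The sketch below assumes you can already call the log density and its gradient as plain Julia functions; how to obtain them from a Turing model depends on the Turing/DynamicPPL version, so that wiring is deliberately not shown. `nonfinite_parts`, `toy_logp`, `bad_grad`, and `good_grad` are hypothetical names used only for illustration.

```julia
# Hypothetical diagnostic helper (not part of Turing): report which of
# the parameters, the log density, or the gradient is non-finite at θ.
function nonfinite_parts(logp, grad, θ::AbstractVector{<:Real})
    parts = Symbol[]
    all(isfinite, θ) || push!(parts, :parameters)
    isfinite(logp(θ)) || push!(parts, :logdensity)
    all(isfinite, grad(θ)) || push!(parts, :gradient)
    return parts
end

# Toy log density: finite for every finite θ.
toy_logp(θ) = -θ[1]^2 / 2

# A "gradient" with a hidden 0 * Inf term, which evaluates to NaN even
# though the log density itself is perfectly finite.
bad_grad(θ) = [-θ[1] + 0.0 * Inf]

# The correct gradient of toy_logp.
good_grad(θ) = [-θ[1]]

println(nonfinite_parts(toy_logp, bad_grad, [1.0]))
println(nonfinite_parts(toy_logp, good_grad, [1.0]))
```

With `bad_grad`, only `:gradient` is reported even though the parameter and the log density are fine, which is exactly the situation a generic "non-finite θ" message cannot distinguish.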
torfjelde commented on Jun 22, 2023
Sorry, this went under my radar!
Okay, that's weird. Will have a look.
Pfft... This is a bit difficult without touching internals. But you're right, we should have a good way of debugging these things. Let me think about it and I'll get back to you!
penelopeysm commented on Jun 6, 2025
I didn't run the full code, but I'm willing to bet (a small amount of) money that any remaining AD issues are caused by the use of `truncated(dist, x, Inf)`. Replacing this with `truncated(dist; lower=x)` should fix the NaN gradients. Please ping if it's still a problem, and I'll be happy to reopen.
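To illustrate how an infinite upper bound can produce NaN gradients, here is a hand-written sketch. It is not the Distributions.jl or Turing code. It computes ∂Z/∂σ for the normalizing constant Z of a normal truncated to [l, u], term by term as a naive chain-rule (AD) pass would. The sketch assumes that AD differentiates through the upper-bound term even when u = Inf; whether a given Distributions.jl version takes exactly this code path is an assumption. `stdnormpdf`, `dZ_dsigma_two_bounds`, and `dZ_dsigma_lower_only` are hypothetical names.

```julia
# Standard normal density; in floating point stdnormpdf(±Inf) == 0.0.
stdnormpdf(x) = exp(-x^2 / 2) / sqrt(2π)

# Two bounds: Z = Φ(β) - Φ(α) with α = (l - μ)/σ and β = (u - μ)/σ.
# Chain rule per bound: dΦ(β)/dσ = φ(β) * dβ/dσ = φ(β) * (-β/σ).
# For u = Inf this evaluates 0.0 * -Inf, which is NaN in IEEE arithmetic.
function dZ_dsigma_two_bounds(μ, σ, l, u)
    α = (l - μ) / σ
    β = (u - μ) / σ
    return stdnormpdf(β) * (-β / σ) - stdnormpdf(α) * (-α / σ)
end

# Lower bound only: Z = 1 - Φ(α), so no term ever involves an infinite bound.
function dZ_dsigma_lower_only(μ, σ, l)
    α = (l - μ) / σ
    return -stdnormpdf(α) * (-α / σ)
end

# Prior from the example model: Normal(150, 30) truncated below at 90.
println(dZ_dsigma_two_bounds(150.0, 30.0, 90.0, Inf))
println(dZ_dsigma_lower_only(150.0, 30.0, 90.0))
```

Mathematically the two functions agree (the upper-bound term tends to zero as u grows), but only the lower-only form stays finite when the bound is literally `Inf`. This is consistent with the suggestion to express one-sided truncation as `truncated(dist; lower=x)` instead of passing `Inf` as the upper bound.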