Skip to content

Commit 04002ca

Browse files
committed
Warn on NaN's
1 parent 019e41b commit 04002ca

File tree

3 files changed

+15
-10
lines changed

3 files changed

+15
-10
lines changed

HISTORY.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# DynamicPPL Changelog
22

3+
## 0.35.7
4+
5+
`check_model_and_trace` now prints a warning if any NaN's are encountered when evaluating the model.
6+
37
## 0.35.6
48

59
Fixed the implementation of `.~`, such that running a model with it no longer requires DynamicPPL itself to be loaded.

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.35.6"
3+
version = "0.35.7"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/debug_utils.jl

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -231,20 +231,17 @@ function record_varname!(context::DebugContext, varname::VarName, dist)
231231
end
232232
end
233233

234-
# tilde
235-
_isassigned(x::AbstractArray, i) = isassigned(x, i)
236-
# HACK(torfjelde): Julia v1.7 only supports `isassigned(::AbstractArray, ::Int...)`.
237-
# TODO(torfjelde): Determine exactly in which version this change was introduced.
238-
if VERSION < v"v1.9.0-alpha1"
239-
_isassigned(x::AbstractArray, inds::Tuple) = isassigned(x, inds...)
240-
_isassigned(x::AbstractArray, idx::CartesianIndex) = _isassigned(x, Tuple(idx))
234+
function _check_nan(logp)
235+
if isnan(logp)
236+
@warn "Encountered a NaN logp value; this may indicate that your data contains NaN values"
237+
end
241238
end
242239

243240
_has_missings(x) = ismissing(x)
244241
function _has_missings(x::AbstractArray)
245242
# Can't just use `any` because `x` might contain `undef`.
246243
for i in eachindex(x)
247-
if _isassigned(x, i) && _has_missings(x[i])
244+
if isassigned(x, i) && _has_missings(x[i])
248245
return true
249246
end
250247
end
@@ -274,6 +271,7 @@ end
274271
function DynamicPPL.tilde_assume(context::DebugContext, right, vn, vi)
275272
record_pre_tilde_assume!(context, vn, right, vi)
276273
value, logp, vi = DynamicPPL.tilde_assume(childcontext(context), right, vn, vi)
274+
_check_nan(logp)
277275
record_post_tilde_assume!(context, vn, right, value, logp, vi)
278276
return value, logp, vi
279277
end
@@ -284,6 +282,7 @@ function DynamicPPL.tilde_assume(
284282
value, logp, vi = DynamicPPL.tilde_assume(
285283
rng, childcontext(context), sampler, right, vn, vi
286284
)
285+
_check_nan(logp)
287286
record_post_tilde_assume!(context, vn, right, value, logp, vi)
288287
return value, logp, vi
289288
end
@@ -316,12 +315,14 @@ end
316315
function DynamicPPL.tilde_observe(context::DebugContext, right, left, vi)
317316
record_pre_tilde_observe!(context, left, right, vi)
318317
logp, vi = DynamicPPL.tilde_observe(childcontext(context), right, left, vi)
318+
_check_nan(logp)
319319
record_post_tilde_observe!(context, left, right, logp, vi)
320320
return logp, vi
321321
end
322322
function DynamicPPL.tilde_observe(context::DebugContext, sampler, right, left, vi)
323323
record_pre_tilde_observe!(context, left, right, vi)
324324
logp, vi = DynamicPPL.tilde_observe(childcontext(context), sampler, right, left, vi)
325+
_check_nan(logp)
325326
record_post_tilde_observe!(context, left, right, logp, vi)
326327
return logp, vi
327328
end
@@ -474,7 +475,7 @@ end
474475
475476
Check that `model` is valid, warning about any potential issues.
476477
477-
See [`check_model_and_trace`](@ref) for more details on supported keword arguments
478+
See [`check_model_and_trace`](@ref) for more details on supported keyword arguments
478479
and details of which types of checks are performed.
479480
480481
# Returns

0 commit comments

Comments
 (0)