From 18586b5e7419b1ec64858afe0a3ed372cbfa6cbb Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 4 Mar 2020 22:34:44 +0100 Subject: [PATCH] Display warning message if no known progress logging frontend is loaded --- Project.toml | 2 ++ src/AbstractMCMC.jl | 59 +++++++++++++++++++++++++++++++++++++++------ 2 files changed, 53 insertions(+), 8 deletions(-) diff --git a/Project.toml b/Project.toml index 0f316315..7108beb1 100644 --- a/Project.toml +++ b/Project.toml @@ -10,11 +10,13 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Requires = "ae029012-a4dd-5104-9daa-d747884805df" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" [compat] ProgressLogging = "0.1" +Requires = "1" StatsBase = "0.32" julia = "1" diff --git a/src/AbstractMCMC.jl b/src/AbstractMCMC.jl index 4893579d..9a7f55c6 100644 --- a/src/AbstractMCMC.jl +++ b/src/AbstractMCMC.jl @@ -1,6 +1,7 @@ module AbstractMCMC import ProgressLogging +import Requires import StatsBase using StatsBase: sample @@ -9,6 +10,20 @@ import Logging using Random: GLOBAL_RNG, AbstractRNG, seed! import UUIDs +const PROGRESSLOGGERLOADED = Ref(false) + +function __init__() + Requires.@require Atom = "c52e3926-4ff0-5f6e-af25-54175e0327b1" begin + PROGRESSLOGGERLOADED[] = true + end + Requires.@require ConsoleProgressMonitor = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b" begin + PROGRESSLOGGERLOADED[] = true + end + Requires.@require TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed" begin + PROGRESSLOGGERLOADED[] = true + end +end + """ AbstractChains @@ -44,7 +59,7 @@ abstract type AbstractModel end Return `N` samples from the MCMC `sampler` for the provided `model`. -If a callback function `f` with type signature +If a callback function `f` with type signature ```julia f(rng::AbstractRNG, model::AbstractModel, sampler::AbstractSampler, N::Integer, iteration::Integer, transition; kwargs...) @@ -77,6 +92,9 @@ function StatsBase.sample( # Perform any necessary setup. sample_init!(rng, model, sampler, N; kwargs...) + # Check progress logger frontends. + progress && checkprogresslogger() + # Create a progress bar. if progress progressid = UUIDs.uuid4() @@ -178,12 +196,12 @@ function sample_end!( end function bundle_samples( - ::AbstractRNG, - ::AbstractModel, - ::AbstractSampler, - ::Integer, + ::AbstractRNG, + ::AbstractModel, + ::AbstractSampler, + ::Integer, transitions, - ::Type{Any}; + ::Type{Any}; kwargs... ) return transitions @@ -259,7 +277,7 @@ end Sample `nchains` chains using the available threads, and combine them into a single chain. By default, the random number generator, the model and the samplers are deep copied for each -thread to prevent contamination between threads. +thread to prevent contamination between threads. """ function psample( model::AbstractModel, @@ -292,6 +310,9 @@ function psample( # Set up a chains vector. chains = Vector{Any}(undef, nchains) + # Check progress logger frontends. + progress && checkprogresslogger() + # Create a progress bar and a channel for progress logging. if progress progressid = UUIDs.uuid4() @@ -322,7 +343,7 @@ function psample( # Seed the thread-specific random number generator with the pre-made seed. subrng = rngs[id] seed!(subrng, seeds[i]) - + # Sample a chain and save it to the vector. chains[i] = sample(subrng, models[id], samplers[id], N; progress = false, kwargs...) @@ -399,4 +420,26 @@ function steps!( return Stepper(rng, model, s, kwargs) end +const PROGRESSLOGGERWARNING = """ +It seems that progress bars can not be displayed properly since no known progress logging +frontend is loaded. + +Please install a progress logging frontend such as +* Juno, +* TerminalLoggers (for the REPL), or +* ConsoleProgressMonitor (for Jupyter notebooks). + +More information can be found in the documentation of ProgressLogging and the respective +frontends. + +If you have a working progress logging frontend, you can disable this warning message +manually by setting `AbstractMCMC.PROGRESSLOGGERLOADED[] = true`. +""" + +# Check if known progress logging frontends are loaded +function checkprogresslogger() + !PROGRESSLOGGERLOADED[] && @warn PROGRESSLOGGERWARNING + return +end + end # module AbstractMCMC