Skip to content

Display warning message if no known progress logging frontend is loaded #20

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

[compat]
ProgressLogging = "0.1"
Requires = "1"
StatsBase = "0.32"
julia = "1"

Expand Down
59 changes: 51 additions & 8 deletions src/AbstractMCMC.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module AbstractMCMC

import ProgressLogging
import Requires
import StatsBase
using StatsBase: sample

Expand All @@ -9,6 +10,20 @@ import Logging
using Random: GLOBAL_RNG, AbstractRNG, seed!
import UUIDs

const PROGRESSLOGGERLOADED = Ref(false)

function __init__()
Requires.@require Atom = "c52e3926-4ff0-5f6e-af25-54175e0327b1" begin
PROGRESSLOGGERLOADED[] = true
end
Requires.@require ConsoleProgressMonitor = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b" begin
PROGRESSLOGGERLOADED[] = true
end
Requires.@require TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed" begin
PROGRESSLOGGERLOADED[] = true
end
end

"""
AbstractChains

Expand Down Expand Up @@ -44,7 +59,7 @@ abstract type AbstractModel end

Return `N` samples from the MCMC `sampler` for the provided `model`.

If a callback function `f` with type signature
If a callback function `f` with type signature
```julia
f(rng::AbstractRNG, model::AbstractModel, sampler::AbstractSampler, N::Integer,
iteration::Integer, transition; kwargs...)
Expand Down Expand Up @@ -77,6 +92,9 @@ function StatsBase.sample(
# Perform any necessary setup.
sample_init!(rng, model, sampler, N; kwargs...)

# Check progress logger frontends.
progress && checkprogresslogger()

# Create a progress bar.
if progress
progressid = UUIDs.uuid4()
Expand Down Expand Up @@ -178,12 +196,12 @@ function sample_end!(
end

function bundle_samples(
::AbstractRNG,
::AbstractModel,
::AbstractSampler,
::Integer,
::AbstractRNG,
::AbstractModel,
::AbstractSampler,
::Integer,
transitions,
::Type{Any};
::Type{Any};
kwargs...
)
return transitions
Expand Down Expand Up @@ -259,7 +277,7 @@ end
Sample `nchains` chains using the available threads, and combine them into a single chain.

By default, the random number generator, the model and the samplers are deep copied for each
thread to prevent contamination between threads.
thread to prevent contamination between threads.
"""
function psample(
model::AbstractModel,
Expand Down Expand Up @@ -292,6 +310,9 @@ function psample(
# Set up a chains vector.
chains = Vector{Any}(undef, nchains)

# Check progress logger frontends.
progress && checkprogresslogger()

# Create a progress bar and a channel for progress logging.
if progress
progressid = UUIDs.uuid4()
Expand Down Expand Up @@ -322,7 +343,7 @@ function psample(
# Seed the thread-specific random number generator with the pre-made seed.
subrng = rngs[id]
seed!(subrng, seeds[i])

# Sample a chain and save it to the vector.
chains[i] = sample(subrng, models[id], samplers[id], N;
progress = false, kwargs...)
Expand Down Expand Up @@ -399,4 +420,26 @@ function steps!(
return Stepper(rng, model, s, kwargs)
end

const PROGRESSLOGGERWARNING = """
It seems that progress bars can not be displayed properly since no known progress logging
frontend is loaded.

Please install a progress logging frontend such as
* Juno,
* TerminalLoggers (for the REPL), or
* ConsoleProgressMonitor (for Jupyter notebooks).

More information can be found in the documentation of ProgressLogging and the respective
frontends.

If you have a working progress logging frontend, you can disable this warning message
manually by setting `AbstractMCMC.PROGRESSLOGGERLOADED[] = true`.
"""

# Check if known progress logging frontends are loaded
function checkprogresslogger()
!PROGRESSLOGGERLOADED[] && @warn PROGRESSLOGGERWARNING
return
end

end # module AbstractMCMC