Skip to content

Commit b29b38d

Browse files
authored
Merge pull request #21 from TuringLang/filter
Provide default progress loggers
2 parents c867a65 + 930708c commit b29b38d

File tree

4 files changed

+167
-69
lines changed

4 files changed

+167
-69
lines changed

Project.toml

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,28 @@ desc = "A lightweight interface for common MCMC methods."
66
version = "0.5.0"
77

88
[deps]
9+
ConsoleProgressMonitor = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b"
910
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
1011
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
12+
LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
1113
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
1214
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1315
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
14-
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
16+
TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed"
1517

1618
[compat]
19+
ConsoleProgressMonitor = "0.1"
20+
LoggingExtras = "0.4"
1721
ProgressLogging = "0.1"
1822
StatsBase = "0.32"
23+
TerminalLoggers = "0.1"
1924
julia = "1"
2025

2126
[extras]
27+
Atom = "c52e3926-4ff0-5f6e-af25-54175e0327b1"
28+
IJulia = "7073ff75-c697-5162-941a-fcdaad2a7d2a"
2229
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
23-
TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed"
2430
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2531

2632
[targets]
27-
test = ["Statistics", "Test", "TerminalLoggers"]
33+
test = ["Atom", "IJulia", "Statistics", "Test"]

src/AbstractMCMC.jl

Lines changed: 67 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,62 @@
11
module AbstractMCMC
22

3+
import ConsoleProgressMonitor
4+
import LoggingExtras
35
import ProgressLogging
46
import StatsBase
57
using StatsBase: sample
8+
import TerminalLoggers
69

710
import Distributed
811
import Logging
912
using Random: GLOBAL_RNG, AbstractRNG, seed!
10-
import UUIDs
13+
14+
# avoid creating a progress bar with @withprogress if progress logging is disabled
15+
# and add a custom progress logger if the current logger does not seem to be able to handle
16+
# progress logs
17+
macro ifwithprogresslogger(progress, exprs...)
18+
return quote
19+
if $progress
20+
if $hasprogresslevel($Logging.current_logger())
21+
$ProgressLogging.@withprogress $(exprs...)
22+
else
23+
$with_progresslogger($Logging.current_logger()) do
24+
$ProgressLogging.@withprogress $(exprs...)
25+
end
26+
end
27+
else
28+
$(exprs[end])
29+
end
30+
end |> esc
31+
end
32+
33+
# improved checks?
34+
function hasprogresslevel(logger)
35+
return Logging.min_enabled_level(logger) ProgressLogging.ProgressLevel
36+
end
37+
38+
# filter better, e.g., according to group?
39+
function with_progresslogger(f, logger)
40+
_module = @__MODULE__
41+
logger1 = LoggingExtras.EarlyFilteredLogger(progresslogger()) do log
42+
log._module === _module && log.level == ProgressLogging.ProgressLevel
43+
end
44+
logger2 = LoggingExtras.EarlyFilteredLogger(logger) do log
45+
log._module !== _module || log.level != ProgressLogging.ProgressLevel
46+
end
47+
48+
Logging.with_logger(f, LoggingExtras.TeeLogger(logger1, logger2))
49+
end
50+
51+
function progresslogger()
52+
# detect if code is running under IJulia since TerminalLogger does not work with IJulia
53+
# https://github.com/JuliaLang/IJulia.jl#detecting-that-code-is-running-under-ijulia
54+
if isdefined(Main, :IJulia) && Main.IJulia.inited
55+
return ConsoleProgressMonitor.ProgressLogger()
56+
else
57+
return TerminalLoggers.TerminalLogger()
58+
end
59+
end
1160

1261
"""
1362
AbstractChains
@@ -44,7 +93,7 @@ abstract type AbstractModel end
4493
4594
Return `N` samples from the MCMC `sampler` for the provided `model`.
4695
47-
If a callback function `f` with type signature
96+
If a callback function `f` with type signature
4897
```julia
4998
f(rng::AbstractRNG, model::AbstractModel, sampler::AbstractSampler, N::Integer,
5099
iteration::Integer, transition; kwargs...)
@@ -77,15 +126,7 @@ function StatsBase.sample(
77126
# Perform any necessary setup.
78127
sample_init!(rng, model, sampler, N; kwargs...)
79128

80-
# Create a progress bar.
81-
if progress
82-
progressid = UUIDs.uuid4()
83-
Logging.@logmsg(ProgressLogging.ProgressLevel, progressname, progress=NaN,
84-
_id=progressid)
85-
end
86-
87-
local transitions
88-
try
129+
@ifwithprogresslogger progress name=progressname begin
89130
# Obtain the initial transition.
90131
transition = step!(rng, model, sampler, N; iteration=1, kwargs...)
91132

@@ -97,10 +138,7 @@ function StatsBase.sample(
97138
transitions_save!(transitions, 1, transition, model, sampler, N; kwargs...)
98139

99140
# Update the progress bar.
100-
if progress
101-
Logging.@logmsg(ProgressLogging.ProgressLevel, progressname, progress=1/N,
102-
_id=progressid)
103-
end
141+
progress && ProgressLogging.@logprogress 1/N
104142

105143
# Step through the sampler.
106144
for i in 2:N
@@ -114,16 +152,7 @@ function StatsBase.sample(
114152
transitions_save!(transitions, i, transition, model, sampler, N; kwargs...)
115153

116154
# Update the progress bar.
117-
if progress
118-
Logging.@logmsg(ProgressLogging.ProgressLevel, progressname, progress=i/N,
119-
_id=progressid)
120-
end
121-
end
122-
finally
123-
# Close the progress bar.
124-
if progress
125-
Logging.@logmsg(ProgressLogging.ProgressLevel, progressname, progress="done",
126-
_id=progressid)
155+
progress && ProgressLogging.@logprogress i/N
127156
end
128157
end
129158

@@ -178,12 +207,12 @@ function sample_end!(
178207
end
179208

180209
function bundle_samples(
181-
::AbstractRNG,
182-
::AbstractModel,
183-
::AbstractSampler,
184-
::Integer,
210+
::AbstractRNG,
211+
::AbstractModel,
212+
::AbstractSampler,
213+
::Integer,
185214
transitions,
186-
::Type{Any};
215+
::Type{Any};
187216
kwargs...
188217
)
189218
return transitions
@@ -259,7 +288,7 @@ end
259288
Sample `nchains` chains using the available threads, and combine them into a single chain.
260289
261290
By default, the random number generator, the model and the samplers are deep copied for each
262-
thread to prevent contamination between threads.
291+
thread to prevent contamination between threads.
263292
"""
264293
function psample(
265294
model::AbstractModel,
@@ -292,24 +321,20 @@ function psample(
292321
# Set up a chains vector.
293322
chains = Vector{Any}(undef, nchains)
294323

295-
# Create a progress bar and a channel for progress logging.
296-
if progress
297-
progressid = UUIDs.uuid4()
298-
Logging.@logmsg(ProgressLogging.ProgressLevel, progressname, progress=NaN,
299-
_id=progressid)
300-
channel = Distributed.RemoteChannel(() -> Channel{Bool}(nchains), 1)
301-
end
324+
@ifwithprogresslogger progress name=progressname begin
325+
# Create a channel for progress logging.
326+
if progress
327+
channel = Distributed.RemoteChannel(() -> Channel{Bool}(nchains), 1)
328+
end
302329

303-
try
304330
Distributed.@sync begin
305331
if progress
306332
Distributed.@async begin
307333
# Update the progress bar.
308334
progresschains = 0
309335
while take!(channel)
310336
progresschains += 1
311-
Logging.@logmsg(ProgressLogging.ProgressLevel, progressname,
312-
progress=progresschains/nchains, _id=progressid)
337+
ProgressLogging.@logprogress progresschains/nchains
313338
end
314339
end
315340
end
@@ -322,7 +347,7 @@ function psample(
322347
# Seed the thread-specific random number generator with the pre-made seed.
323348
subrng = rngs[id]
324349
seed!(subrng, seeds[i])
325-
350+
326351
# Sample a chain and save it to the vector.
327352
chains[i] = sample(subrng, models[id], samplers[id], N;
328353
progress = false, kwargs...)
@@ -335,12 +360,6 @@ function psample(
335360
progress && put!(channel, false)
336361
end
337362
end
338-
finally
339-
# Close the progress bar.
340-
if progress
341-
Logging.@logmsg(ProgressLogging.ProgressLevel, progressname,
342-
progress="done", _id=progressid)
343-
end
344363
end
345364

346365
# Concatenate the chains together.

test/interface.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,13 @@ function AbstractMCMC.step!(
1919
N::Integer,
2020
transition::Union{Nothing,MyTransition};
2121
sleepy = false,
22+
loggers = false,
2223
kwargs...
2324
)
2425
a = rand(rng)
2526
b = randn(rng)
2627

28+
loggers && push!(LOGGERS, Logging.current_logger())
2729
sleepy && sleep(0.001)
2830

2931
return MyTransition(a, b)
@@ -50,4 +52,4 @@ function AbstractMCMC.bundle_samples(
5052
return MyChain(as, bs)
5153
end
5254

53-
AbstractMCMC.chainscat(chains::Union{MyChain,Vector{<:MyChain}}...) = vcat(chains...)
55+
AbstractMCMC.chainscat(chains::Union{MyChain,Vector{<:MyChain}}...) = vcat(chains...)

test/runtests.jl

Lines changed: 88 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,104 @@
11
using AbstractMCMC
22
using AbstractMCMC: sample, psample, steps!
3-
import TerminalLoggers
3+
using Atom.Progress: JunoProgressLogger
4+
using ConsoleProgressMonitor: ProgressLogger
5+
using IJulia
6+
using LoggingExtras: TeeLogger, EarlyFilteredLogger
7+
using TerminalLoggers: TerminalLogger
48

59
import Logging
610
using Random
711
using Statistics
812
using Test
913
using Test: collect_test_logs
1014

11-
# install progress logger
12-
Logging.global_logger(TerminalLoggers.TerminalLogger(right_justify=120))
15+
const LOGGERS = Set()
16+
const CURRENT_LOGGER = Logging.current_logger()
1317

1418
include("interface.jl")
1519

1620
@testset "AbstractMCMC" begin
1721
@testset "Basic sampling" begin
18-
Random.seed!(1234)
19-
N = 1_000
20-
chain = sample(MyModel(), MySampler(), N; sleepy = true)
21-
22-
# test output type and size
23-
@test chain isa Vector{MyTransition}
24-
@test length(chain) == N
25-
26-
# test some statistical properties
27-
@test mean(x.a for x in chain) 0.5 atol=6e-2
28-
@test var(x.a for x in chain) 1 / 12 atol=5e-3
29-
@test mean(x.b for x in chain) 0.0 atol=5e-2
30-
@test var(x.b for x in chain) 1 atol=6e-2
22+
@testset "REPL" begin
23+
empty!(LOGGERS)
24+
25+
Random.seed!(1234)
26+
N = 1_000
27+
chain = sample(MyModel(), MySampler(), N; sleepy = true, loggers = true)
28+
29+
@test length(LOGGERS) == 1
30+
logger = first(LOGGERS)
31+
@test logger isa TeeLogger
32+
@test logger.loggers[1].logger isa TerminalLogger
33+
@test logger.loggers[2].logger === CURRENT_LOGGER
34+
@test Logging.current_logger() === CURRENT_LOGGER
35+
36+
# test output type and size
37+
@test chain isa Vector{MyTransition}
38+
@test length(chain) == N
39+
40+
# test some statistical properties
41+
@test mean(x.a for x in chain) 0.5 atol=6e-2
42+
@test var(x.a for x in chain) 1 / 12 atol=5e-3
43+
@test mean(x.b for x in chain) 0.0 atol=5e-2
44+
@test var(x.b for x in chain) 1 atol=6e-2
45+
end
46+
47+
@testset "Juno" begin
48+
empty!(LOGGERS)
49+
50+
Random.seed!(1234)
51+
N = 10
52+
53+
logger = JunoProgressLogger()
54+
Logging.with_logger(logger) do
55+
sample(MyModel(), MySampler(), N; sleepy = true, loggers = true)
56+
end
57+
58+
@test length(LOGGERS) == 1
59+
@test first(LOGGERS) === logger
60+
@test Logging.current_logger() === CURRENT_LOGGER
61+
end
62+
63+
@testset "IJulia" begin
64+
# emulate running IJulia kernel
65+
@eval IJulia begin
66+
inited = true
67+
end
68+
69+
empty!(LOGGERS)
70+
71+
Random.seed!(1234)
72+
N = 10
73+
sample(MyModel(), MySampler(), N; sleepy = true, loggers = true)
74+
75+
@test length(LOGGERS) == 1
76+
logger = first(LOGGERS)
77+
@test logger isa TeeLogger
78+
@test logger.loggers[1].logger isa ProgressLogger
79+
@test logger.loggers[2].logger === CURRENT_LOGGER
80+
@test Logging.current_logger() === CURRENT_LOGGER
81+
82+
@eval IJulia begin
83+
inited = false
84+
end
85+
end
86+
87+
@testset "Custom logger" begin
88+
empty!(LOGGERS)
89+
90+
Random.seed!(1234)
91+
N = 10
92+
93+
logger = Logging.ConsoleLogger(stderr, Logging.LogLevel(-1))
94+
Logging.with_logger(logger) do
95+
sample(MyModel(), MySampler(), N; sleepy = true, loggers = true)
96+
end
97+
98+
@test length(LOGGERS) == 1
99+
@test first(LOGGERS) === logger
100+
@test Logging.current_logger() === CURRENT_LOGGER
101+
end
31102
end
32103

33104
if VERSION v"1.3"
@@ -104,4 +175,4 @@ include("interface.jl")
104175
@test Base.IteratorSize(iter) == Base.IsInfinite()
105176
@test Base.IteratorEltype(iter) == Base.EltypeUnknown()
106177
end
107-
end
178+
end

0 commit comments

Comments
 (0)