1
1
module AbstractMCMC
2
2
3
+ import ConsoleProgressMonitor
4
+ import LoggingExtras
3
5
import ProgressLogging
4
6
import StatsBase
5
7
using StatsBase: sample
8
+ import TerminalLoggers
6
9
7
10
import Distributed
8
11
import Logging
9
12
using Random: GLOBAL_RNG, AbstractRNG, seed!
10
- import UUIDs
13
+
14
+ # avoid creating a progress bar with @withprogress if progress logging is disabled
15
+ # and add a custom progress logger if the current logger does not seem to be able to handle
16
+ # progress logs
17
+ macro ifwithprogresslogger (progress, exprs... )
18
+ return quote
19
+ if $ progress
20
+ if $ hasprogresslevel ($ Logging. current_logger ())
21
+ $ ProgressLogging. @withprogress $ (exprs... )
22
+ else
23
+ $ with_progresslogger ($ Logging. current_logger ()) do
24
+ $ ProgressLogging. @withprogress $ (exprs... )
25
+ end
26
+ end
27
+ else
28
+ $ (exprs[end ])
29
+ end
30
+ end |> esc
31
+ end
32
+
33
+ # improved checks?
34
+ function hasprogresslevel (logger)
35
+ return Logging. min_enabled_level (logger) ≤ ProgressLogging. ProgressLevel
36
+ end
37
+
38
+ # filter better, e.g., according to group?
39
+ function with_progresslogger (f, logger)
40
+ _module = @__MODULE__
41
+ logger1 = LoggingExtras. EarlyFilteredLogger (progresslogger ()) do log
42
+ log. _module === _module && log. level == ProgressLogging. ProgressLevel
43
+ end
44
+ logger2 = LoggingExtras. EarlyFilteredLogger (logger) do log
45
+ log. _module != = _module || log. level != ProgressLogging. ProgressLevel
46
+ end
47
+
48
+ Logging. with_logger (f, LoggingExtras. TeeLogger (logger1, logger2))
49
+ end
50
+
51
+ function progresslogger ()
52
+ # detect if code is running under IJulia since TerminalLogger does not work with IJulia
53
+ # https://github.com/JuliaLang/IJulia.jl#detecting-that-code-is-running-under-ijulia
54
+ if isdefined (Main, :IJulia ) && Main. IJulia. inited
55
+ return ConsoleProgressMonitor. ProgressLogger ()
56
+ else
57
+ return TerminalLoggers. TerminalLogger ()
58
+ end
59
+ end
11
60
12
61
"""
13
62
AbstractChains
@@ -44,7 +93,7 @@ abstract type AbstractModel end
44
93
45
94
Return `N` samples from the MCMC `sampler` for the provided `model`.
46
95
47
- If a callback function `f` with type signature
96
+ If a callback function `f` with type signature
48
97
```julia
49
98
f(rng::AbstractRNG, model::AbstractModel, sampler::AbstractSampler, N::Integer,
50
99
iteration::Integer, transition; kwargs...)
@@ -77,15 +126,7 @@ function StatsBase.sample(
77
126
# Perform any necessary setup.
78
127
sample_init! (rng, model, sampler, N; kwargs... )
79
128
80
- # Create a progress bar.
81
- if progress
82
- progressid = UUIDs. uuid4 ()
83
- Logging. @logmsg (ProgressLogging. ProgressLevel, progressname, progress= NaN ,
84
- _id= progressid)
85
- end
86
-
87
- local transitions
88
- try
129
+ @ifwithprogresslogger progress name= progressname begin
89
130
# Obtain the initial transition.
90
131
transition = step! (rng, model, sampler, N; iteration= 1 , kwargs... )
91
132
@@ -97,10 +138,7 @@ function StatsBase.sample(
97
138
transitions_save! (transitions, 1 , transition, model, sampler, N; kwargs... )
98
139
99
140
# Update the progress bar.
100
- if progress
101
- Logging. @logmsg (ProgressLogging. ProgressLevel, progressname, progress= 1 / N,
102
- _id= progressid)
103
- end
141
+ progress && ProgressLogging. @logprogress 1 / N
104
142
105
143
# Step through the sampler.
106
144
for i in 2 : N
@@ -114,16 +152,7 @@ function StatsBase.sample(
114
152
transitions_save! (transitions, i, transition, model, sampler, N; kwargs... )
115
153
116
154
# Update the progress bar.
117
- if progress
118
- Logging. @logmsg (ProgressLogging. ProgressLevel, progressname, progress= i/ N,
119
- _id= progressid)
120
- end
121
- end
122
- finally
123
- # Close the progress bar.
124
- if progress
125
- Logging. @logmsg (ProgressLogging. ProgressLevel, progressname, progress= " done" ,
126
- _id= progressid)
155
+ progress && ProgressLogging. @logprogress i/ N
127
156
end
128
157
end
129
158
@@ -178,12 +207,12 @@ function sample_end!(
178
207
end
179
208
180
209
function bundle_samples (
181
- :: AbstractRNG ,
182
- :: AbstractModel ,
183
- :: AbstractSampler ,
184
- :: Integer ,
210
+ :: AbstractRNG ,
211
+ :: AbstractModel ,
212
+ :: AbstractSampler ,
213
+ :: Integer ,
185
214
transitions,
186
- :: Type{Any} ;
215
+ :: Type{Any} ;
187
216
kwargs...
188
217
)
189
218
return transitions
259
288
Sample `nchains` chains using the available threads, and combine them into a single chain.
260
289
261
290
By default, the random number generator, the model and the samplers are deep copied for each
262
- thread to prevent contamination between threads.
291
+ thread to prevent contamination between threads.
263
292
"""
264
293
function psample (
265
294
model:: AbstractModel ,
@@ -292,24 +321,20 @@ function psample(
292
321
# Set up a chains vector.
293
322
chains = Vector {Any} (undef, nchains)
294
323
295
- # Create a progress bar and a channel for progress logging.
296
- if progress
297
- progressid = UUIDs. uuid4 ()
298
- Logging. @logmsg (ProgressLogging. ProgressLevel, progressname, progress= NaN ,
299
- _id= progressid)
300
- channel = Distributed. RemoteChannel (() -> Channel {Bool} (nchains), 1 )
301
- end
324
+ @ifwithprogresslogger progress name= progressname begin
325
+ # Create a channel for progress logging.
326
+ if progress
327
+ channel = Distributed. RemoteChannel (() -> Channel {Bool} (nchains), 1 )
328
+ end
302
329
303
- try
304
330
Distributed. @sync begin
305
331
if progress
306
332
Distributed. @async begin
307
333
# Update the progress bar.
308
334
progresschains = 0
309
335
while take! (channel)
310
336
progresschains += 1
311
- Logging. @logmsg (ProgressLogging. ProgressLevel, progressname,
312
- progress= progresschains/ nchains, _id= progressid)
337
+ ProgressLogging. @logprogress progresschains/ nchains
313
338
end
314
339
end
315
340
end
@@ -322,7 +347,7 @@ function psample(
322
347
# Seed the thread-specific random number generator with the pre-made seed.
323
348
subrng = rngs[id]
324
349
seed! (subrng, seeds[i])
325
-
350
+
326
351
# Sample a chain and save it to the vector.
327
352
chains[i] = sample (subrng, models[id], samplers[id], N;
328
353
progress = false , kwargs... )
@@ -335,12 +360,6 @@ function psample(
335
360
progress && put! (channel, false )
336
361
end
337
362
end
338
- finally
339
- # Close the progress bar.
340
- if progress
341
- Logging. @logmsg (ProgressLogging. ProgressLevel, progressname,
342
- progress= " done" , _id= progressid)
343
- end
344
363
end
345
364
346
365
# Concatenate the chains together.
0 commit comments