Rework `sample()` call stack to use `LogDensityFunction` (#2555)
Comments
Sounds reasonable to me, but I haven't thought about this call stack deeply, so my opinion isn't strong. For Gibbs, I don't immediately see a problem as long as we can call …

Great idea. Let's rename … FYI, there is an ongoing discussion between @sunxd3 and me about refactoring all modelling APIs like …
Current situation

Right now, if you call `sample(::Model, ::InferenceAlgorithm, ::Int)`, this first goes to `src/mcmc/Inference.jl`, where the `InferenceAlgorithm` gets wrapped in a `DynamicPPL.Sampler`; see e.g. Turing.jl/src/mcmc/Inference.jl, lines 268 to 278 at 5acc97f.
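For orientation, this first hop might be sketched like this (a simplified, hypothetical signature, not the actual Turing source):

```julia
# Sketch (simplified): the top-level entry point does nothing but wrap the
# algorithm in a DynamicPPL.Sampler and re-dispatch on the wrapped type.
function AbstractMCMC.sample(
    model::DynamicPPL.Model, alg::InferenceAlgorithm, N::Integer; kwargs...
)
    return AbstractMCMC.sample(model, DynamicPPL.Sampler(alg), N; kwargs...)
end
```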
This then goes to `src/mcmc/$sampler.jl`, which defines the methods `sample(::Model, ::Sampler{<:InferenceAlgorithm}, ::Int)`; see e.g. Turing.jl/src/mcmc/hmc.jl, lines 82 to 104 at 5acc97f.
This then goes to AbstractMCMC's `sample`: https://github.com/TuringLang/AbstractMCMC.jl/blob/fdaa0ebce22ce227b068e847415cd9ee0e15c004/src/sample.jl#L255-L259

That in turn calls `step(::AbstractRNG, ::Model, ::Sampler{<:InferenceAlgorithm})`, which is defined in DynamicPPL: https://github.com/TuringLang/DynamicPPL.jl/blob/072234d094d1d68064bf259d3c3e815a87c18c8e/src/sampler.jl#L108-L126
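That DynamicPPL fallback does roughly the following (a simplified sketch; the real method also handles initial parameters and keyword plumbing):

```julia
# Sketch (simplified): DynamicPPL's generic first step sets up a VarInfo
# for the model, then defers to the sampler-specific initialstep.
function AbstractMCMC.step(
    rng::Random.AbstractRNG, model::DynamicPPL.Model, spl::DynamicPPL.Sampler;
    kwargs...,
)
    vi = DynamicPPL.VarInfo(rng, model)
    return initialstep(rng, model, spl, vi; kwargs...)
end
```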
This then calls `initialstep`, which goes back to being defined in `src/mcmc/$sampler.jl`; see e.g. Turing.jl/src/mcmc/hmc.jl, lines 141 to 149 at 5acc97f. (This signature claims to work on any `AbstractModel`, but it only really works for `DynamicPPL.Model`.)

Inside here, we finally construct a `LogDensityFunction` from the model. So, there are very many steps between the time that `sample()` is called and the time when a `LogDensityFunction` is actually constructed.

Proposal
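To recap the situation described above, the current dispatch path can be summarised as follows (simplified; file locations per the links above):

```julia
# Current call chain (simplified):
#
# sample(model, alg, N)                        # Turing: src/mcmc/Inference.jl
#  └─ sample(model, Sampler(alg), N)           # Turing: src/mcmc/$sampler.jl
#      └─ mcmcsample(rng, model, spl, N, ...)  # AbstractMCMC: src/sample.jl
#          └─ step(rng, model, spl)            # DynamicPPL: src/sampler.jl
#              └─ initialstep(rng, model, spl, vi)   # Turing: src/mcmc/$sampler.jl
#                  └─ LogDensityFunction(model, ...) # finally constructed here
```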
Rework everything below the very first call to accept a `LogDensityFunction` rather than a `Model`. That is to say, the method `sample(::Model, ::InferenceAlgorithm, ::Int)` should construct the `LogDensityFunction` itself, right at the top, and pass that down the rest of the call stack.

This would require making several changes across DynamicPPL and Turing. It (thankfully) probably does not need to touch AbstractMCMC, as long as we make
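The code block from the original issue is not preserved here, but based on the surrounding description, the reworked top-level method might look roughly like this (a hypothetical sketch; the exact `LogDensityFunction` constructor arguments are assumptions):

```julia
# Hypothetical sketch of the reworked entry point: construct the
# LogDensityFunction once, up front, and hand it down the call stack.
function AbstractMCMC.sample(
    model::DynamicPPL.Model, alg::InferenceAlgorithm, N::Integer; kwargs...
)
    ldf = DynamicPPL.LogDensityFunction(model)
    return AbstractMCMC.sample(ldf, alg, N; kwargs...)
end
```

This assumes `LogDensityFunction` is (made) a subtype of `AbstractMCMC.AbstractModel`. A user could then also construct the `LogDensityFunction` themselves and call `sample(ldf, alg, N)` directly, skipping the convenience method.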
`LogDensityFunction` a subtype of `AbstractMCMC.AbstractModel` (so that `mcmcsample` can work). That should be fine, because `AbstractModel` has no interface.

Why?
For one, this is probably the best way to let people have greater control over their sampling process. For example: constructing a `LogDensityFunction` yourself and passing it directly to `sample()`. Right now, this is actually very difficult to do. (Just try it!) Note that this also provides a natural interface for opting into `ThreadSafeVarInfo` (cf. `ThreadSafeVarInfo` and `threadid`, DynamicPPL.jl#924).

More philosophically, it's IMO the first step that's necessary towards encapsulating Turing's "magic behaviour" at the very top level of the call stack. We know that a
`DynamicPPL.Model` on its own does not actually give enough information about how to evaluate it; it's only a `LogDensityFunction` that contains the necessary information. Thus, it shouldn't be the job of low-level functions like `step` to make this decision: they should just receive objects that are already complete.

I'm still not sure about:

- How this will work with Gibbs. I haven't looked at it deeply enough.