Substantial updates to tutorial 01_gaussian-mixture-model #439
Closed

Changes from all commits (16 commits):

- 915f0f8 (JasonPekos): updates to 01_gmm tutorial
- b1069a0 (JasonPekos): remove files that shouldn't have gotten picked up by git?
- 5f0914c (JasonPekos): remove files that shouldn't be here
- 87ffabd (JasonPekos): Update tutorials/01-gaussian-mixture-model/01_gaussian-mixture-model.jmd
- 3952644 (JasonPekos): Update tutorials/01-gaussian-mixture-model/01_gaussian-mixture-model.jmd
- 835203f (JasonPekos): Update tutorials/01-gaussian-mixture-model/01_gaussian-mixture-model.jmd
- 99e7313 (JasonPekos): Update tutorials/01-gaussian-mixture-model/01_gaussian-mixture-model.jmd
- 17b3990 (JasonPekos): Update tutorials/01-gaussian-mixture-model/01_gaussian-mixture-model.jmd
- cf5469f (JasonPekos): Update tutorials/01-gaussian-mixture-model/01_gaussian-mixture-model.jmd
- d56a1a1 (JasonPekos): various small fixes to 01_gmm
- baeb774 (JasonPekos): pull class recovery function outside of model
- d937095 (JasonPekos): no longer manually import bijectors
- 5517704 (JasonPekos): reword clunky sentence
- 1c031f5 (JasonPekos): reword clunky sentence
- 5672a23 (JasonPekos): grammar
- 80e8ed5 (JasonPekos): Merge branch 'TuringLang:master' into master

@@ -112,18 +112,19 @@ We generate multiple chains in parallel using multi-threading.

```julia
sampler = Gibbs(PG(100, :k), HMC(0.05, 10, :μ, :w))
nsamples = 100
nchains = 4
burn = 10
chains = sample(model, sampler, MCMCThreads(), nsamples, nchains; discard_initial=burn);
```

```julia; echo=false
let
    # Verify that the output of the chain is as expected.
    for i in MCMCChains.chains(chains)
        # In this case, we *want* to see the degenerate behaviour,
        # so error if Rhat is *small*.
        rhat = MCMCChains.rhat(chains)
        @assert maximum(rhat[:, :rhat]) > 2 "Example intended to demonstrate multi-modality likely failed to find both modes!"
    end
end
```

@@ -135,30 +136,83 @@ After sampling we can visualize the trace and density of the parameters of interest.

We consider the samples of the location parameters $\mu_1$ and $\mu_2$ for the two clusters.

```julia
plot(chains[["μ[1]", "μ[2]"]]; legend=true)
```

It can happen that the modes of $\mu_1$ and $\mu_2$ switch between chains.
For more information see the [Stan documentation](https://mc-stan.org/users/documentation/case-studies/identifying_mixture_models.html).
This happens because each model parameter $\mu_k$ can end up fitting either of the true means, and the assignment need not be consistent between chains.

That is, the posterior is fundamentally multimodal, and different chains can end up in different modes, complicating inference.

One solution here is to enforce an ordering on our $\mu$ vector, requiring $\mu_k > \mu_{k-1}$ for all $k > 1$.

`Bijectors.jl` [provides](https://turinglang.org/Bijectors.jl/dev/transforms/#Bijectors.OrderedBijector) an easy transformation (`ordered()`) for this purpose:

```julia
@model function gaussian_mixture_model_ordered(x)
    # Draw the parameters for each of the K=2 clusters from a standard normal distribution.
    K = 2
    μ ~ Bijectors.ordered(MvNormal(Zeros(K), I))

    # Draw the weights for the K clusters from a Dirichlet distribution with parameters αₖ = 1.
    w ~ Dirichlet(K, 1.0)
    # Alternatively, one could use a fixed set of weights.
    # w = fill(1/K, K)

    # Construct categorical distribution of assignments.
    distribution_assignments = Categorical(w)

    # Construct multivariate normal distributions of each cluster.
    D, N = size(x)
    distribution_clusters = [MvNormal(Fill(μₖ, D), I) for μₖ in μ]

    # Draw assignments for each datum and generate it from the multivariate normal distribution.
    k = Vector{Int}(undef, N)
    for i in 1:N
        k[i] ~ distribution_assignments
        x[:, i] ~ distribution_clusters[k[i]]
    end

    return k
end

model = gaussian_mixture_model_ordered(x);
```

Now, re-running our model, we can see that the assigned means are consistent across chains:

```julia
chains = sample(model, sampler, MCMCThreads(), nsamples, nchains; discard_initial=burn);
```

```julia; echo=false
let
    # Verify that the output of the chain is as expected.
    for i in MCMCChains.chains(chains)
        # μ[1] and μ[2] can no longer switch places. Check that they've found the means.
        chain = Array(chains[:, ["μ[1]", "μ[2]"], i])
        μ_mean = vec(mean(chain; dims=1))
        @assert isapprox(sort(μ_mean), μ; rtol=0.4) "Difference between estimated mean of μ ($(sort(μ_mean))) and data-generating μ ($μ) unexpectedly large!"
    end
end
```

```julia
plot(chains[["μ[1]", "μ[2]"]]; legend=true)
```

We also inspect the samples of the mixture weights $w$.

```julia
plot(chains[["w[1]", "w[2]"]]; legend=true)
```

As the distributions of the samples for the parameters $\mu_1$, $\mu_2$, $w_1$, and $w_2$ are unimodal, we can safely visualize the density region of our model using the average values.

```julia
# Model with mean of samples as parameters.
μ_mean = [mean(chains, "μ[$i]") for i in 1:2]
w_mean = [mean(chains, "w[$i]") for i in 1:2]
mixturemodel_mean = MixtureModel([MvNormal(Fill(μₖ, 2), I) for μₖ in μ_mean], w_mean)

contour(

@@ -176,7 +230,7 @@ Finally, we can inspect the assignments of the data points inferred using Turing.

As we can see, the dataset is partitioned into two distinct groups.

```julia
assignments = [mean(chains, "k[$i]") for i in 1:N]
scatter(
    x[1, :],
    x[2, :];

@@ -186,7 +240,168 @@ scatter(
)
```

## Marginalizing Out The Assignments

We can write out the marginal posterior of the (continuous) parameters $w$ and $\mu$ by summing the (discrete) assignments $z_i$ out of the likelihood:

$$
p(y \mid w, \mu) = \sum_{k=1}^K w_k p_k(y \mid \mu_k)
$$

In our case, this gives us:

$$
p(y \mid w, \mu) = \sum_{k=1}^K w_k \cdot \operatorname{MvNormal}(y \mid \mu_k, I)
$$

### Marginalizing By Hand

We can implement the above version of the Gaussian mixture model in Turing as follows.

First, Turing uses log-probabilities, so the likelihood above must be converted into log-space:

$$
\log p(y \mid w, \mu) = \operatorname{logsumexp}_{k=1}^{K} \left[ \log w_k + \log \operatorname{MvNormal}(y \mid \mu_k, I) \right]
$$

Here the sum over components is carried out with `logsumexp` from the [`LogExpFunctions.jl` package](https://juliastats.org/LogExpFunctions.jl/stable/).
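
As a quick standalone sanity check of this identity (the weights, means, and observation below are arbitrary toy values, not the tutorial's data), the direct sum and the log-space `logsumexp` version agree:

```julia
using Distributions, FillArrays, LinearAlgebra, LogExpFunctions

# Arbitrary toy mixture parameters and a single 2-D observation.
w_toy = [0.3, 0.7]
μ_toy = [-1.0, 2.0]
y_toy = [0.5, 0.5]

dists_toy = [MvNormal(Fill(μₖ, 2), I) for μₖ in μ_toy]

# Direct sum: p(y | w, μ) = Σₖ wₖ · MvNormal(y | μₖ, I)
p_direct = sum(w_toy[k] * pdf(dists_toy[k], y_toy) for k in 1:2)

# Log-space version: logsumexp over log(wₖ) + log MvNormal(y | μₖ, I)
logp = logsumexp(log(w_toy[k]) + logpdf(dists_toy[k], y_toy) for k in 1:2)

exp(logp) ≈ p_direct  # true
```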

This per-observation log-likelihood can be added to the model's log-probability manually with `Turing.@addlogprob!`, giving us the following model:

```julia
using StatsFuns

@model function gmm_marginalized(x)
    K = 2
    D, N = size(x)
    μ ~ Bijectors.ordered(MvNormal(Zeros(K), I))
    w ~ Dirichlet(K, 1.0)
    dists = [MvNormal(Fill(μₖ, D), I) for μₖ in μ]

    for i in 1:N
        lvec = Vector(undef, K)
        for k in 1:K
            # log(wₖ) + log MvNormal(x[:, i] | μₖ, I), matching the formula above.
            lvec[k] = log(w[k]) + logpdf(dists[k], x[:, i])
        end
        Turing.@addlogprob! logsumexp(lvec)
    end
end

model = gmm_marginalized(x);
```

### Marginalizing For Free With Distributions.jl's MixtureModel Implementation

We can use Turing's `~` syntax with anything for which `Distributions.jl` provides `logpdf` and `rand` methods. Its `MixtureModel` distribution does exactly what we need: `logpdf(MixtureModel([component_distributions...], weight_vector), Y)` evaluates the marginalized density, where `Y` can be either a single observation or a vector of observations.

In fact, `Distributions.jl` provides [many convenient constructors](https://juliastats.org/Distributions.jl/stable/mixture/) for mixture models, allowing further simplification in common special cases.

For example, when the mixture components are of the same type, one can write `~ MixtureModel(Normal, [(μ1, σ1), (μ2, σ2)], w)`, and when the weight vector is known to allocate probability equally, it can be omitted.
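
As a small standalone illustration of these constructors (the parameter values are arbitrary), the explicit and shorthand forms define the same density:

```julia
using Distributions

# Explicit components and weights.
m1 = MixtureModel([Normal(-1.0, 1.0), Normal(2.0, 1.0)], [0.3, 0.7])

# Same-type components via the (type, parameter tuples, weights) shorthand.
m2 = MixtureModel(Normal, [(-1.0, 1.0), (2.0, 1.0)], [0.3, 0.7])

# Omitting the weight vector assigns equal weight to each component.
m3 = MixtureModel(Normal, [(-1.0, 1.0), (2.0, 1.0)])

logpdf(m1, 0.5) ≈ logpdf(m2, 0.5)  # true
```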

The `logpdf` implementation for a `MixtureModel` distribution is exactly the marginalization defined above, and so our model becomes simply:

```julia
@model function gmm_marginalized(x)
    K = 2
    D, _ = size(x)
    μ ~ Bijectors.ordered(MvNormal(Zeros(K), I))
    w ~ Dirichlet(K, 1.0)

    x ~ MixtureModel([MvNormal(Fill(μₖ, D), I) for μₖ in μ], w)
end

model = gmm_marginalized(x);
```

As we've summed out the discrete components, we can perform inference using `NUTS()` alone.

```julia
sampler = NUTS()
chains = sample(model, sampler, MCMCThreads(), nsamples, nchains; discard_initial=burn);
```

```julia; echo=false
let
    # Verify for the marginalized model that the output of the chain is as expected.
    for i in MCMCChains.chains(chains)
        # μ[1] and μ[2] can no longer switch places. Check that they've found the means.
        chain = Array(chains[:, ["μ[1]", "μ[2]"], i])
        μ_mean = vec(mean(chain; dims=1))
        @assert isapprox(sort(μ_mean), μ; rtol=0.4) "Difference between estimated mean of μ ($(sort(μ_mean))) and data-generating μ ($μ) unexpectedly large!"
    end
end
```

`NUTS()` significantly outperforms our compositional Gibbs sampler, in large part because our model is now Rao-Blackwellized thanks to the marginalization of our assignment parameter.

```julia
plot(chains[["μ[1]", "μ[2]"]]; legend=true)
```

## Inferred Assignments - Marginalized Model

As we've summed over the possible assignments, the associated parameters are no longer available in our chain.
This is not a problem, however: given any fixed sample $(\mu, w)$, the assignment probability $p(z_i \mid y_i)$ can be recovered using Bayes' rule:

$$
p(z_i \mid y_i) = \frac{p(y_i \mid z_i) \, p(z_i)}{\sum_{k=1}^K p(y_i \mid z_i = k) \, p(z_i = k)}
$$
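
In our mixture, $p(z_i = k) = w_k$ and $p(y_i \mid z_i = k) = \operatorname{MvNormal}(y_i \mid \mu_k, I)$, so this becomes

$$
p(z_i = k \mid y_i, \mu, w) = \frac{w_k \, \operatorname{MvNormal}(y_i \mid \mu_k, I)}{\sum_{j=1}^K w_j \, \operatorname{MvNormal}(y_i \mid \mu_j, I)},
$$

which is exactly what the `sample_class` function below evaluates, working in log space and normalizing with `softmax`.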

Computing this quantity for every possible assignment $z_i = k$ yields a probability vector, which is then used to sample posterior-predictive assignments from a categorical distribution.

For details on the mathematics here, see [the Stan documentation on latent discrete parameters](https://mc-stan.org/docs/stan-users-guide/latent-discrete.html).

```julia
function sample_class(xi, dists, w)
    # Unnormalized log assignment probabilities: log p(xi | μₖ) + log(wₖ).
    lvec = [logpdf(d, xi) + log(w[i]) for (i, d) in enumerate(dists)]
    # Normalize with softmax and draw one assignment.
    return rand(Categorical(softmax(lvec)))
end

@model function gmm_recover(x)
    K = 2
    D, N = size(x)
    μ ~ Bijectors.ordered(MvNormal(Zeros(K), I))
    w ~ Dirichlet(K, 1.0)

    dists = [MvNormal(Fill(μₖ, D), I) for μₖ in μ]

    x ~ MixtureModel(dists, w)

    # Return assignment draws for each datapoint.
    return [sample_class(x[:, i], dists, w) for i in 1:N]
end
```

We sample from the marginalized model as before:

```julia
chains = sample(model, NUTS(), MCMCThreads(), nsamples, nchains; discard_initial=burn);
```

Given a sample from the marginalized posterior, these assignments can be recovered with:

```julia
assignments = mean(generated_quantities(gmm_recover(x), chains))
```

Here `generated_quantities` returns one vector of assignment draws per posterior sample in each chain, so `mean` averages these vectors elementwise into a single vector of average assignments.

```julia
scatter(
    x[1, :],
    x[2, :];
    legend=false,
    title="Assignments on Synthetic Dataset - Recovered",
    zcolor=assignments,
)
```

```julia; echo=false, skip="notebook", tangle=false
if isdefined(Main, :TuringTutorials)
    Main.TuringTutorials.tutorial_footer(WEAVE_ARGS[:folder], WEAVE_ARGS[:file])
end
```

IMO we should not recommend the use of `Turing.@addlogprob!`, since it's so easy to misuse and to get (silently) wrong results because it operates completely outside of the `~` logic in Turing/DynamicPPL. Instead I think usually one should use `~` with a (possibly custom) distribution.

Sounds good! I initially wasn't going to include that section for basically the reasons you bring up, but I ended up including it (even though I don't actually sample from that model) to motivate what's going on with the MixtureModel lpdf.
I can replace it with a custom distribution (although this might be a little long for a model that's really just exposition), or omit it entirely.
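
For reference, a minimal sketch of the custom-distribution alternative discussed above (hypothetical code, not part of this PR; `MarginalizedGMM` is a made-up name): it wraps the marginalized per-observation likelihood in the `Distributions.jl` multivariate interface so it can be used with `~` directly, which is essentially what `MixtureModel` already provides.

```julia
using Distributions, LogExpFunctions, Random

# Hypothetical custom distribution for the per-observation marginalized likelihood
# (equivalent in spirit to MixtureModel; shown only to illustrate the suggestion).
struct MarginalizedGMM{D<:MultivariateDistribution,T<:Real} <: ContinuousMultivariateDistribution
    components::Vector{D}
    w::Vector{T}
end

Base.length(d::MarginalizedGMM) = length(first(d.components))

# log p(y | w, μ) = logsumexpₖ [ log wₖ + log pₖ(y | μₖ) ]
function Distributions._logpdf(d::MarginalizedGMM, y::AbstractVector{<:Real})
    return logsumexp(log(d.w[k]) + logpdf(d.components[k], y) for k in eachindex(d.w))
end

# Sampling: draw a component index, then draw from that component in place.
function Distributions._rand!(rng::Random.AbstractRNG, d::MarginalizedGMM, y::AbstractVector{<:Real})
    k = rand(rng, Categorical(d.w))
    return rand!(rng, d.components[k], y)
end
```

Inside a model one could then write `x[:, i] ~ MarginalizedGMM(dists, w)`; in practice, the `MixtureModel`-based version in the tutorial is the simpler route.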