Commit e2ba322

Substantially Update GMM-01 (#447)
1 parent 4468838 commit e2ba322

3 files changed (+235, -25 lines)


tutorials/01-gaussian-mixture-model/Manifest.toml

Lines changed: 4 additions & 4 deletions
@@ -1,8 +1,8 @@
 # This file is machine-generated - editing it directly is not advised
 
-julia_version = "1.10.0"
+julia_version = "1.10.3"
 manifest_format = "2.0"
-project_hash = "a0a3ca6aa1eb54296a719c0190365d414d855bc0"
+project_hash = "eaea37cdde9f1d06a5695c8ca16ab89aa3a9a2b4"
 
 [[deps.ADTypes]]
 git-tree-sha1 = "daf26bbdec60d9ca1c0003b70f389d821ddb4224"
@@ -383,7 +383,7 @@ weakdeps = ["Dates", "LinearAlgebra"]
 [[deps.CompilerSupportLibraries_jll]]
 deps = ["Artifacts", "Libdl"]
 uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
-version = "1.0.5+1"
+version = "1.1.1+0"
 
 [[deps.CompositionsBase]]
 git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad"
@@ -1238,7 +1238,7 @@ version = "1.3.5+1"
 [[deps.OpenBLAS_jll]]
 deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
 uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"
-version = "0.3.23+2"
+version = "0.3.23+4"
 
 [[deps.OpenLibm_jll]]
 deps = ["Artifacts", "Libdl"]

tutorials/01-gaussian-mixture-model/Project.toml

Lines changed: 8 additions & 7 deletions
@@ -1,7 +1,8 @@
-[deps]
-Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
-FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
-LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
-Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
-StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
-Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
+[deps]
+Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
+FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
+Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

tutorials/01-gaussian-mixture-model/index.qmd

Lines changed: 223 additions & 14 deletions
@@ -115,15 +115,17 @@ We generate multiple chains in parallel using multi-threading.
 
 ```{julia}
 #| output: false
+#| echo: false
 setprogress!(false)
 ```
 
 ```{julia}
 #| output: false
 sampler = Gibbs(PG(100, :k), HMC(0.05, 10, :μ, :w))
-nsamples = 100
-nchains = 3
-chains = sample(model, sampler, MCMCThreads(), nsamples, nchains);
+nsamples = 150
+nchains = 4
+burn = 10
+chains = sample(model, sampler, MCMCThreads(), nsamples, nchains, discard_initial = burn);
 ```
 
 ::: {.callout-warning collapse="true"}
@@ -152,32 +154,81 @@ After sampling we can visualize the trace and density of the parameters of interest.
 We consider the samples of the location parameters $\mu_1$ and $\mu_2$ for the two clusters.
 
 ```{julia}
-plot(chains[["μ[1]", "μ[2]"]]; colordim=:parameter, legend=true)
+plot(chains[["μ[1]", "μ[2]"]]; legend=true)
 ```
 
 It can happen that the modes of $\mu_1$ and $\mu_2$ switch between chains.
-For more information see the [Stan documentation](https://mc-stan.org/users/documentation/case-studies/identifying_mixture_models.html) for potential solutions.
+For more information see the [Stan documentation](https://mc-stan.org/users/documentation/case-studies/identifying_mixture_models.html). This is because it's possible for either model parameter $\mu_k$ to be assigned to either of the corresponding true means, and this assignment need not be consistent between chains.
 
-We also inspect the samples of the mixture weights $w$.
+That is, the posterior is fundamentally multimodal, and different chains can end up in different modes, complicating inference.
+One solution here is to enforce an ordering on our $\mu$ vector, requiring $\mu_k > \mu_{k-1}$ for all $k$.
+`Bijectors.jl` [provides](https://turinglang.org/Bijectors.jl/dev/transforms/#Bijectors.OrderedBijector) an easy transformation (`ordered()`) for this purpose:
+
+```{julia}
+@model function gaussian_mixture_model_ordered(x)
+    # Draw the parameters for each of the K=2 clusters from a standard normal distribution.
+    K = 2
+    μ ~ Bijectors.ordered(MvNormal(Zeros(K), I))
+    # Draw the weights for the K clusters from a Dirichlet distribution with parameters αₖ = 1.
+    w ~ Dirichlet(K, 1.0)
+    # Alternatively, one could use a fixed set of weights.
+    # w = fill(1/K, K)
+    # Construct categorical distribution of assignments.
+    distribution_assignments = Categorical(w)
+    # Construct multivariate normal distributions of each cluster.
+    D, N = size(x)
+    distribution_clusters = [MvNormal(Fill(μₖ, D), I) for μₖ in μ]
+    # Draw assignments for each datum and generate it from the multivariate normal distribution.
+    k = Vector{Int}(undef, N)
+    for i in 1:N
+        k[i] ~ distribution_assignments
+        x[:, i] ~ distribution_clusters[k[i]]
+    end
+    return k
+end
+
+model = gaussian_mixture_model_ordered(x);
+```
+
+
+Now, re-running our model, we can see that the assigned means are consistent across chains:
+
+```{julia}
+#| output: false
+chains = sample(model, sampler, MCMCThreads(), nsamples, nchains, discard_initial = burn);
+```
+
+
+```{julia}
+#| echo: false
+let
+    # Verify that the output of the chain is as expected
+    for i in MCMCChains.chains(chains)
+        # μ[1] and μ[2] can no longer switch places. Check that they've found the mean
+        chain = Array(chains[:, ["μ[1]", "μ[2]"], i])
+        μ_mean = vec(mean(chain; dims=1))
+        @assert isapprox(sort(μ_mean), μ; rtol=0.4) "Difference between estimated mean of μ ($(sort(μ_mean))) and data-generating μ ($μ) unexpectedly large!"
+    end
+end
+```
 
 ```{julia}
-plot(chains[["w[1]", "w[2]"]]; colordim=:parameter, legend=true)
+plot(chains[["μ[1]", "μ[2]"]]; legend=true)
 ```
 
-In the following, we just use the first chain to ensure the validity of our inference.
+We also inspect the samples of the mixture weights $w$.
 
 ```{julia}
-chain = chains[:, :, 1];
+plot(chains[["w[1]", "w[2]"]]; legend=true)
 ```
 
 As the distributions of the samples for the parameters $\mu_1$, $\mu_2$, $w_1$, and $w_2$ are unimodal, we can safely visualize the density region of our model using the average values.
 
 ```{julia}
 # Model with mean of samples as parameters.
-μ_mean = [mean(chain, "μ[$i]") for i in 1:2]
-w_mean = [mean(chain, "w[$i]") for i in 1:2]
+μ_mean = [mean(chains, "μ[$i]") for i in 1:2]
+w_mean = [mean(chains, "w[$i]") for i in 1:2]
 mixturemodel_mean = MixtureModel([MvNormal(Fill(μₖ, 2), I) for μₖ in μ_mean], w_mean)
-
 contour(
     range(-7.5, 3; length=1_000),
     range(-6.5, 3; length=1_000),
@@ -188,12 +239,11 @@ scatter!(x[1, :], x[2, :]; legend=false, title="Synthetic Dataset")
 ```
 
 ## Inferred Assignments
-
 Finally, we can inspect the assignments of the data points inferred using Turing.
 As we can see, the dataset is partitioned into two distinct groups.
 
 ```{julia}
-assignments = [mean(chain, "k[$i]") for i in 1:N]
+assignments = [mean(chains, "k[$i]") for i in 1:N]
 scatter(
     x[1, :],
     x[2, :];
@@ -202,3 +252,162 @@ scatter(
     zcolor=assignments,
 )
 ```
+
+
+## Marginalizing Out The Assignments
+We can write out the marginal posterior of the (continuous) parameters $w, \mu$ by summing the (discrete) assignments $z_i$ out of
+the likelihood:
+$$
+p(y \mid w, \mu) = \sum_{k=1}^K w_k p_k(y \mid \mu_k)
+$$
+In our case, this gives us:
+$$
+p(y \mid w, \mu) = \sum_{k=1}^K w_k \cdot \operatorname{MvNormal}(y \mid \mu_k, I)
+$$
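A small sketch, with made-up weights, means, and test point, that numerically checks this identity against the mixture density provided by `Distributions.jl`:

```julia
using Distributions, FillArrays, LinearAlgebra

# Hypothetical values, purely for illustration.
w = [0.4, 0.6]
μ = [-3.5, 0.5]
y = [0.2, -0.3]

dists = [MvNormal(Fill(μₖ, 2), I) for μₖ in μ]

# Left: the mixture log-density as implemented by Distributions.jl.
# Right: the explicit sum Σₖ wₖ pₖ(y ∣ μₖ) from the formula above.
logpdf(MixtureModel(dists, w), y) ≈ log(sum(w[k] * pdf(dists[k], y) for k in 1:2))  # true
```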
+
+
+### Marginalizing By Hand
+We could implement the marginalized version of the Gaussian mixture model in Turing as follows.
+First, since Turing works with log-probabilities, the likelihood above must be converted into log-space:
+$$
+\log p(y \mid w, \mu) = \operatorname{logsumexp}_{k=1}^{K} \left[ \log(w_k) + \log(\operatorname{MvNormal}(y \mid \mu_k, I)) \right]
+$$
+
+where the sum over components is computed with `logsumexp` from the [`LogExpFunctions.jl` package](https://juliastats.org/LogExpFunctions.jl/stable/).
+The manually incremented likelihood can be added to the log-probability with `Turing.@addlogprob!`, giving us the following model:
+
+```{julia}
+#| output: false
+using LogExpFunctions
+
+@model function gmm_marginalized(x)
+    K = 2
+    D, N = size(x)
+    μ ~ Bijectors.ordered(MvNormal(Zeros(K), I))
+    w ~ Dirichlet(K, 1.0)
+    dists = [MvNormal(Fill(μₖ, D), I) for μₖ in μ]
+    for i in 1:N
+        lvec = Vector(undef, K)
+        for k in 1:K
+            lvec[k] = log(w[k]) + logpdf(dists[k], x[:, i])
+        end
+        Turing.@addlogprob! logsumexp(lvec)
+    end
+end
+```
+
+::: {.callout-warning collapse="false"}
+## Manually Incrementing Probability
+
+When possible, use of `Turing.@addlogprob!` should be avoided, as it exists outside the
+usual structure of a Turing model. In most cases, a custom distribution should be used instead.
+
+Here, the next section demonstrates the preferred method: using the `MixtureModel` distribution we have seen already to
+perform the marginalization automatically.
+:::
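For concreteness, a minimal sketch of what such a custom distribution could look like for a univariate two-component mixture; the `TwoNormalMixture` type, its unit variances, and its field names are hypothetical, for illustration only:

```julia
using Distributions, LogExpFunctions, Random

# Hypothetical sketch: a distribution type that encapsulates the marginalized
# two-component mixture density, so a model can write `y ~ TwoNormalMixture(μ1, μ2, w)`
# instead of incrementing the log-probability by hand.
struct TwoNormalMixture{T<:Real} <: ContinuousUnivariateDistribution
    μ1::T
    μ2::T
    w::T  # weight of the first component
end

# log p(y) = logsumexp(log(w) + log N(y | μ1, 1), log(1 - w) + log N(y | μ2, 1))
function Distributions.logpdf(d::TwoNormalMixture, y::Real)
    return logaddexp(log(d.w) + logpdf(Normal(d.μ1, 1), y),
                     log1p(-d.w) + logpdf(Normal(d.μ2, 1), y))
end

# Draw a component according to w, then draw from that component.
function Base.rand(rng::Random.AbstractRNG, d::TwoNormalMixture)
    return rand(rng) < d.w ? rand(rng, Normal(d.μ1, 1)) : rand(rng, Normal(d.μ2, 1))
end
```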
+
+
+### Marginalizing For Free With Distributions.jl's MixtureModel Implementation
+
+We can use Turing's `~` syntax with anything for which `Distributions.jl` provides `logpdf` and `rand` methods. The
+`MixtureModel` distribution it provides has a `logpdf` method of the form `logpdf(MixtureModel([component_distributions], weight_vector), Y)`, where `Y` can be either a single observation or a vector of observations.
+
+In fact, `Distributions.jl` provides [many convenient constructors](https://juliastats.org/Distributions.jl/stable/mixture/) for mixture models, allowing further simplification in common special cases.
+
+For example, when the mixture components are of the same type, one can write `~ MixtureModel(Normal, [(μ1, σ1), (μ2, σ2)], w)`, and when the weight vector is known to allocate probability equally, it can be omitted.
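A brief sketch of these constructors; the component parameters here are made up for illustration:

```julia
using Distributions

w = [0.3, 0.7]

# Components of a common type can be given as the type plus parameter tuples.
mix = MixtureModel(Normal, [(-3.5, 1.0), (0.5, 1.0)], w)

# Omitting the weight vector allocates probability equally across components.
mix_equal = MixtureModel(Normal, [(-3.5, 1.0), (0.5, 1.0)])

# `logpdf` evaluates the marginal (summed-out) density directly.
logpdf(mix, 0.0)
```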
+
+The `logpdf` implementation for a `MixtureModel` distribution is exactly the marginalization defined above, and so our model becomes simply:
+
+```{julia}
+#| output: false
+@model function gmm_marginalized(x)
+    K = 2
+    D, _ = size(x)
+    μ ~ Bijectors.ordered(MvNormal(Zeros(K), I))
+    w ~ Dirichlet(K, 1.0)
+    x ~ MixtureModel([MvNormal(Fill(μₖ, D), I) for μₖ in μ], w)
+end
+model = gmm_marginalized(x);
+```
+
+As we've summed out the discrete components, we can perform inference using `NUTS()` alone.
+
+```{julia}
+#| output: false
+sampler = NUTS()
+chains = sample(model, sampler, MCMCThreads(), nsamples, nchains; discard_initial = burn);
+```
+
+
+```{julia}
+#| echo: false
+let
+    # Verify for marginalized model that the output of the chain is as expected
+    for i in MCMCChains.chains(chains)
+        # μ[1] and μ[2] can no longer switch places. Check that they've found the mean
+        chain = Array(chains[:, ["μ[1]", "μ[2]"], i])
+        μ_mean = vec(mean(chain; dims=1))
+        @assert isapprox(sort(μ_mean), μ; rtol=0.4) "Difference between estimated mean of μ ($(sort(μ_mean))) and data-generating μ ($μ) unexpectedly large!"
+    end
+end
+```
+
+`NUTS()` significantly outperforms our compositional Gibbs sampler, in large part because our model is now Rao-Blackwellized thanks to
+the marginalization of our assignment parameter.
+
+```{julia}
+plot(chains[["μ[1]", "μ[2]"]], legend=true)
+```
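One hedged way to check this efficiency claim is to compare effective sample sizes between the Gibbs chains and the NUTS chains; a sketch using MCMCChains' summary statistics (the parameter subset is just an example):

```julia
using MCMCChains

# Effective sample sizes (and other diagnostics) for the location and weight
# parameters; running this on the Gibbs chains and on the NUTS chains makes
# the efficiency difference visible in the ESS columns.
summarystats(chains[["μ[1]", "μ[2]", "w[1]", "w[2]"]])
```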
+
+## Inferred Assignments - Marginalized Model
+As we've summed over possible assignments, the associated parameter is no longer available in our chain.
+This is not a problem, however, as given any fixed sample $(\mu, w)$, the assignment probability $p(z_i \mid y_i)$ can be recovered using Bayes' rule:
+$$
+p(z_i = k \mid y_i) = \frac{p(y_i \mid z_i = k) \, p(z_i = k)}{\sum_{j=1}^K p(y_i \mid z_i = j) \, p(z_i = j)}
+$$
+
+This quantity can be computed for every assignment $k$, resulting in a probability vector, which is then used to sample
+posterior predictive assignments from a categorical distribution.
+For details on the mathematics here, see [the Stan documentation on latent discrete parameters](https://mc-stan.org/docs/stan-users-guide/latent-discrete.html).
+```{julia}
+#| output: false
+function sample_class(xi, dists, w)
+    lvec = [(logpdf(d, xi) + log(w[i])) for (i, d) in enumerate(dists)]
+    rand(Categorical(softmax(lvec)))
+end
+
+@model function gmm_recover(x)
+    K = 2
+    D, N = size(x)
+    μ ~ Bijectors.ordered(MvNormal(Zeros(K), I))
+    w ~ Dirichlet(K, 1.0)
+    dists = [MvNormal(Fill(μₖ, D), I) for μₖ in μ]
+    x ~ MixtureModel(dists, w)
+    # Return assignment draws for each datapoint.
+    return [sample_class(x[:, i], dists, w) for i in 1:N]
+end
+```
+
+We sample from this model as before:
+
+```{julia}
+#| output: false
+model = gmm_recover(x)
+chains = sample(model, sampler, MCMCThreads(), nsamples, nchains, discard_initial = burn);
+```
+
+Given a sample from the marginalized posterior, these assignments can be recovered with:
+
+```{julia}
+assignments = mean(generated_quantities(gmm_recover(x), chains));
+```
+
+```{julia}
+scatter(
+    x[1, :],
+    x[2, :];
+    legend=false,
+    title="Assignments on Synthetic Dataset - Recovered",
+    zcolor=assignments,
+)
+```
