It can happen that the modes of $\mu_1$ and $\mu_2$ switch between chains.
For more information see the [Stan documentation](https://mc-stan.org/users/documentation/case-studies/identifying_mixture_models.html). This is because it's possible for either model parameter $\mu_k$ to be assigned to either of the corresponding true means, and this assignment need not be consistent between chains.

That is, the posterior is fundamentally multimodal, and different chains can end up in different modes, complicating inference.
One solution here is to enforce an ordering on our $\mu$ vector, requiring $\mu_k > \mu_{k-1}$ for all $k$.
`Bijectors.jl` [provides](https://turinglang.org/Bijectors.jl/dev/transforms/#Bijectors.OrderedBijector) an easy transformation (`ordered()`) for this purpose:

```{julia}
@model function gaussian_mixture_model_ordered(x)
    # Draw the parameters for each of the K=2 clusters from a standard normal distribution.
    K = 2
    μ ~ Bijectors.ordered(MvNormal(Zeros(K), I))

    # Draw the weights for the K clusters from a Dirichlet distribution with parameters αₖ = 1.
    w ~ Dirichlet(K, 1.0)
    # Alternatively, one could use a fixed set of weights.
    # w = fill(1/K, K)

    # Construct categorical distribution of assignments.
    distribution_assignments = Categorical(w)

    # Construct multivariate normal distributions of each cluster.
    D, N = size(x)
    distribution_clusters = [MvNormal(Fill(μₖ, D), I) for μₖ in μ]

    # Draw assignments for each datum and generate it from the multivariate normal distribution.
    k = Vector{Int}(undef, N)
    for i in 1:N
        k[i] ~ distribution_assignments
        x[:, i] ~ distribution_clusters[k[i]]
    end
    return k
end

model = gaussian_mixture_model_ordered(x);
```

Now, re-running our model, we can see that the assigned means are consistent across chains.
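A minimal sketch of this step, assuming the `sampler`, `nsamples`, and `nchains` defined earlier in the tutorial:

```{julia}
# Hypothetical sketch: re-fit the ordered model with the sampler settings used before.
chains = sample(model, sampler, MCMCThreads(), nsamples, nchains);

# μ[1] and μ[2] should now agree across chains.
plot(chains[["μ[1]", "μ[2]"]]; legend=true)
```
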
We also inspect the samples of the mixture weights $w$.

```{julia}
plot(chains[["w[1]", "w[2]"]]; legend=true)
```

As the distributions of the samples for the parameters $\mu_1$, $\mu_2$, $w_1$, and $w_2$ are unimodal, we can safely visualize the density region of our model using the average values.

```{julia}
# Model with mean of samples as parameters.
μ_mean = [mean(chains, "μ[$i]") for i in 1:2]
w_mean = [mean(chains, "w[$i]") for i in 1:2]
mixturemodel_mean = MixtureModel([MvNormal(Fill(μₖ, 2), I) for μₖ in μ_mean], w_mean)
```
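
In our earlier models the per-datum assignments $k_i$ were sampled explicitly. For each datum $x_i$, they can instead be summed out of the likelihood, leaving the marginal log-density

$$
\log p(x_i \mid \mu, w)
  = \log \sum_{k=1}^{K} w_k \, p(x_i \mid \mu_k)
  = \log \sum_{k=1}^{K} \exp\big(\log w_k + \log p(x_i \mid \mu_k)\big).
$$
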
Where we sum the components with `logsumexp` from the [`LogExpFunctions.jl` package](https://juliastats.org/LogExpFunctions.jl/stable/).

The manually incremented likelihood can be added to the log-probability with `Turing.@addlogprob!`, giving us the following model:

```{julia}
#| output: false
using LogExpFunctions

@model function gmm_marginalized(x)
    K = 2
    D, N = size(x)
    μ ~ Bijectors.ordered(MvNormal(Zeros(K), I))
    w ~ Dirichlet(K, 1.0)
    dists = [MvNormal(Fill(μₖ, D), I) for μₖ in μ]
    for i in 1:N
        lvec = Vector(undef, K)
        for k in 1:K
            # Work on the log scale: log(wₖ) + log p(xᵢ | μₖ).
            lvec[k] = log(w[k]) + logpdf(dists[k], x[:, i])
        end
        Turing.@addlogprob! logsumexp(lvec)
    end
end
```

::: {.callout-warning collapse="false"}
## Manually Incrementing Probability

When possible, use of `Turing.@addlogprob!` should be avoided, as it exists outside the usual structure of a Turing model.
In most cases, a custom distribution should be used instead.

The next section demonstrates the preferred method: using the `MixtureModel` distribution we have seen already to perform the marginalization automatically.
:::

### Marginalizing For Free With Distributions.jl's MixtureModel Implementation

We can use Turing's `~` syntax with anything for which `Distributions.jl` provides `logpdf` and `rand` methods.
It turns out that the `MixtureModel` distribution it provides has, as its `logpdf` method, `logpdf(MixtureModel([Component_Distributions], weight_vector), Y)`, where `Y` can be either a single observation or a vector of observations.

In fact, `Distributions.jl` provides [many convenient constructors](https://juliastats.org/Distributions.jl/stable/mixture/) for mixture models, allowing further simplification in common special cases.
For example, when the mixture's component distributions are all of the same type, one can write `~ MixtureModel(Normal, [(μ1, σ1), (μ2, σ2)], w)`, and when the weight vector is known to allocate probability equally, it can be omitted entirely.
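
As an illustrative sketch (with made-up component parameters, not taken from this tutorial), such a shorthand mixture behaves like any other distribution:

```{julia}
# Hypothetical example: two univariate Normal components with means -3.5 and 0.5
# and unit standard deviation; omitting the weight vector gives equal weights.
mix = MixtureModel(Normal, [(-3.5, 1.0), (0.5, 1.0)])
logpdf(mix, 0.1)
```
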
The `logpdf` implementation for a `MixtureModel` distribution is exactly the marginalization defined above, and so our model becomes simply:

```{julia}
#| output: false
@model function gmm_marginalized(x)
    K = 2
    D, _ = size(x)
    μ ~ Bijectors.ordered(MvNormal(Zeros(K), I))
    w ~ Dirichlet(K, 1.0)
    x ~ MixtureModel([MvNormal(Fill(μₖ, D), I) for μₖ in μ], w)
end
model = gmm_marginalized(x);
```

As we've summed out the discrete components, we can perform inference using `NUTS()` alone.
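
A minimal sketch of the sampling call, assuming the `nsamples` and `nchains` values used earlier in the tutorial:

```{julia}
#| output: false
# Hypothetical sketch: sample the marginalized model with NUTS alone.
chains = sample(model, NUTS(), MCMCThreads(), nsamples, nchains);
```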

```{julia}
let
    # Verify for marginalized model that the output of the chain is as expected.
    for i in MCMCChains.chains(chains)
        # μ[1] and μ[2] can no longer switch places. Check that they've found the mean.
        chain = Array(chains[:, ["μ[1]", "μ[2]"], i])
        μ_mean = vec(mean(chain; dims=1))
        @assert isapprox(sort(μ_mean), μ; rtol=0.4) "Difference between estimated mean of μ ($(sort(μ_mean))) and data-generating μ ($μ) unexpectedly large!"
    end
end
```

`NUTS()` significantly outperforms our compositional Gibbs sampler, in large part because our model is now Rao-Blackwellized thanks to the marginalization of our assignment parameter.

```{julia}
plot(chains[["μ[1]", "μ[2]"]]; legend=true)
```

## Inferred Assignments - Marginalized Model

As we've summed over possible assignments, the associated parameter is no longer available in our chain.
This is not a problem, however, as given any fixed sample $(\mu, w)$, the assignment probability $p(z_i \mid y_i)$ can be recovered using Bayes rule:
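
$$
p(z_i = k \mid y_i) = \frac{w_k \, p(y_i \mid \mu_k)}{\sum_{k'=1}^{K} w_{k'} \, p(y_i \mid \mu_{k'})}
$$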

This quantity can be computed for every assignment $z_i$, resulting in a probability vector, which is then used to sample posterior predictive assignments from a categorical distribution.
For details on the mathematics here, see [the Stan documentation on latent discrete parameters](https://mc-stan.org/docs/stan-users-guide/latent-discrete.html).

```{julia}
#| output: false
function sample_class(xi, dists, w)
    lvec = [(logpdf(d, xi) + log(w[i])) for (i, d) in enumerate(dists)]
    return rand(Categorical(softmax(lvec)))
end

@model function gmm_recover(x)
    K = 2
    D, N = size(x)
    μ ~ Bijectors.ordered(MvNormal(Zeros(K), I))
    w ~ Dirichlet(K, 1.0)
    dists = [MvNormal(Fill(μₖ, D), I) for μₖ in μ]
    x ~ MixtureModel(dists, w)
    # Return assignment draws for each datapoint.
    return [sample_class(x[:, i], dists, w) for i in 1:N]
end
```
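
One way to then extract these posterior predictive assignments is sketched below, reusing the `nsamples` and `nchains` settings assumed earlier: `generated_quantities` collects the model's return value for each posterior sample, and averaging these per datapoint yields a soft assignment for every observation.

```{julia}
# Hypothetical sketch: sample the recovery model, then average, per datapoint,
# the assignment draws returned by the model across all posterior samples.
model_recover = gmm_recover(x)
chains_recover = sample(model_recover, NUTS(), MCMCThreads(), nsamples, nchains)
assignments = mean(generated_quantities(model_recover, chains_recover));
```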