From ef20a49c0a38c9cdac1ba7362334434025713299 Mon Sep 17 00:00:00 2001 From: Jason Pekos <78766845+JasonPekos@users.noreply.github.com> Date: Thu, 23 May 2024 19:59:08 -0400 Subject: [PATCH] Substantially Update GMM-01 --- .../01-gaussian-mixture-model/Manifest.toml | 8 +- .../01-gaussian-mixture-model/Project.toml | 15 +- tutorials/01-gaussian-mixture-model/index.qmd | 237 ++++++++++++++++-- 3 files changed, 235 insertions(+), 25 deletions(-) diff --git a/tutorials/01-gaussian-mixture-model/Manifest.toml b/tutorials/01-gaussian-mixture-model/Manifest.toml index 7d6912e68..fcef503b8 100755 --- a/tutorials/01-gaussian-mixture-model/Manifest.toml +++ b/tutorials/01-gaussian-mixture-model/Manifest.toml @@ -1,8 +1,8 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.10.0" +julia_version = "1.10.3" manifest_format = "2.0" -project_hash = "a0a3ca6aa1eb54296a719c0190365d414d855bc0" +project_hash = "eaea37cdde9f1d06a5695c8ca16ab89aa3a9a2b4" [[deps.ADTypes]] git-tree-sha1 = "daf26bbdec60d9ca1c0003b70f389d821ddb4224" @@ -383,7 +383,7 @@ weakdeps = ["Dates", "LinearAlgebra"] [[deps.CompilerSupportLibraries_jll]] deps = ["Artifacts", "Libdl"] uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "1.0.5+1" +version = "1.1.1+0" [[deps.CompositionsBase]] git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad" @@ -1238,7 +1238,7 @@ version = "1.3.5+1" [[deps.OpenBLAS_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" -version = "0.3.23+2" +version = "0.3.23+4" [[deps.OpenLibm_jll]] deps = ["Artifacts", "Libdl"] diff --git a/tutorials/01-gaussian-mixture-model/Project.toml b/tutorials/01-gaussian-mixture-model/Project.toml index c9ac8b8bc..b28425f20 100755 --- a/tutorials/01-gaussian-mixture-model/Project.toml +++ b/tutorials/01-gaussian-mixture-model/Project.toml @@ -1,7 +1,8 @@ -[deps] -Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" -FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" -LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd" -Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" +[deps] +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd" +Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" diff --git a/tutorials/01-gaussian-mixture-model/index.qmd b/tutorials/01-gaussian-mixture-model/index.qmd index 5c2ae526e..7e7e528f2 100755 --- a/tutorials/01-gaussian-mixture-model/index.qmd +++ b/tutorials/01-gaussian-mixture-model/index.qmd @@ -115,15 +115,17 @@ We generate multiple chains in parallel using multi-threading. 
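Note that `MCMCThreads()` only samples chains in parallel if Julia itself was started with multiple threads, for example with `julia --threads=4`; with a single thread the call still works, but the chains are sampled one after another. The number of threads available can be checked as follows:

```{julia}
# Number of threads available to Julia; start Julia with e.g. `julia --threads=4`
# to actually sample the chains in parallel.
Threads.nthreads()
```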
```{julia}
#| output: false
+#| echo: false
setprogress!(false)
```

```{julia}
#| output: false
sampler = Gibbs(PG(100, :k), HMC(0.05, 10, :μ, :w))
-nsamples = 100
-nchains = 3
-chains = sample(model, sampler, MCMCThreads(), nsamples, nchains);
+nsamples = 150
+nchains = 4
+burn = 10
+chains = sample(model, sampler, MCMCThreads(), nsamples, nchains, discard_initial = burn);
```

::: {.callout-warning collapse="true"}
@@ -152,32 +154,81 @@ After sampling we can visualize the trace and density of the parameters of interest.
We consider the samples of the location parameters $\mu_1$ and $\mu_2$ for the two clusters.

```{julia}
-plot(chains[["μ[1]", "μ[2]"]]; colordim=:parameter, legend=true)
+plot(chains[["μ[1]", "μ[2]"]]; legend=true)
```

It can happen that the modes of $\mu_1$ and $\mu_2$ switch between chains.
-For more information see the [Stan documentation](https://mc-stan.org/users/documentation/case-studies/identifying_mixture_models.html) for potential solutions.
+For more information see the [Stan documentation](https://mc-stan.org/users/documentation/case-studies/identifying_mixture_models.html). This happens because each model parameter $\mu_k$ can be matched to either of the true cluster means, and this assignment need not be consistent between chains.

-We also inspect the samples of the mixture weights $w$.
+That is, the posterior is fundamentally multimodal, and different chains can end up in different modes, complicating inference.
One solution here is to enforce an ordering on our $\mu$ vector, requiring $\mu_k > \mu_{k-1}$ for all $k$.
`Bijectors.jl` [provides](https://turinglang.org/Bijectors.jl/dev/transforms/#Bijectors.OrderedBijector) an easy transformation (`ordered()`) for this purpose:

```{julia}
@model function gaussian_mixture_model_ordered(x)
    # Draw the mean parameters for the K=2 clusters from an ordered standard normal prior,
    # so that μ[1] < μ[2] by construction.
    K = 2
    μ ~ Bijectors.ordered(MvNormal(Zeros(K), I))
    # Draw the weights for the K clusters from a Dirichlet distribution with parameters αₖ = 1.
    w ~ Dirichlet(K, 1.0)
    # Alternatively, one could use a fixed set of weights.
    # w = fill(1/K, K)
    # Construct categorical distribution of assignments.
    distribution_assignments = Categorical(w)
    # Construct multivariate normal distributions of each cluster.
    D, N = size(x)
    distribution_clusters = [MvNormal(Fill(μₖ, D), I) for μₖ in μ]
    # Draw assignments for each datum and generate it from the multivariate normal distribution.
    k = Vector{Int}(undef, N)
    for i in 1:N
        k[i] ~ distribution_assignments
        x[:, i] ~ distribution_clusters[k[i]]
    end
    return k
end

model = gaussian_mixture_model_ordered(x);
```

Now, re-running our model, we can see that the assigned means are consistent across chains:

```{julia}
#| output: false
chains = sample(model, sampler, MCMCThreads(), nsamples, nchains, discard_initial = burn);
```

```{julia}
#| echo: false
let
    # Verify that the output of the chain is as expected.
    for i in MCMCChains.chains(chains)
        # μ[1] and μ[2] can no longer switch places; check that they have found the true means.
        chain = Array(chains[:, ["μ[1]", "μ[2]"], i])
        μ_mean = vec(mean(chain; dims=1))
        @assert isapprox(sort(μ_mean), μ; rtol=0.4) "Difference between estimated mean of μ ($(sort(μ_mean))) and data-generating μ ($μ) unexpectedly large!"
    end
end
```
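One simple numerical check is to compare the per-chain posterior means of `μ[1]` and `μ[2]`; with the ordered prior these should now agree across chains rather than swapping:

```{julia}
# Posterior mean of (μ[1], μ[2]) within each chain; with the ordered prior
# these no longer swap between chains.
[vec(mean(Array(chains[:, ["μ[1]", "μ[2]"], c]); dims=1)) for c in MCMCChains.chains(chains)]
```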
```{julia}
-plot(chains[["w[1]", "w[2]"]]; colordim=:parameter, legend=true)
+plot(chains[["μ[1]", "μ[2]"]]; legend=true)
```

-In the following, we just use the first chain to ensure the validity of our inference.
+We also inspect the samples of the mixture weights $w$.

```{julia}
-chain = chains[:, :, 1];
+plot(chains[["w[1]", "w[2]"]]; legend=true)
```

As the distributions of the samples for the parameters $\mu_1$, $\mu_2$, $w_1$, and $w_2$ are unimodal, we can safely visualize the density region of our model using the average values.

```{julia}
# Model with mean of samples as parameters.
-μ_mean = [mean(chain, "μ[$i]") for i in 1:2]
-w_mean = [mean(chain, "w[$i]") for i in 1:2]
+μ_mean = [mean(chains, "μ[$i]") for i in 1:2]
+w_mean = [mean(chains, "w[$i]") for i in 1:2]
mixturemodel_mean = MixtureModel([MvNormal(Fill(μₖ, 2), I) for μₖ in μ_mean], w_mean)
-
contour(
    range(-7.5, 3; length=1_000),
    range(-6.5, 3; length=1_000),
@@ -188,12 +239,11 @@ scatter!(x[1, :], x[2, :]; legend=false, title="Synthetic Dataset")
```

## Inferred Assignments
-
Finally, we can inspect the assignments of the data points inferred using Turing.
As we can see, the dataset is partitioned into two distinct groups.

```{julia}
-assignments = [mean(chain, "k[$i]") for i in 1:N]
+assignments = [mean(chains, "k[$i]") for i in 1:N]
scatter(
    x[1, :],
    x[2, :];
    legend=false,
    title="Assignments on Synthetic Dataset",
    zcolor=assignments,
)
```

## Marginalizing Out The Assignments

We can write out the marginal posterior of the (continuous) parameters $w, \mu$ by summing the (discrete) assignments $z_i$ out of the likelihood:
$$
p(y \mid w, \mu) = \sum_{k=1}^K w_k p_k(y \mid \mu_k)
$$
In our case, this gives us:
$$
p(y \mid w, \mu) = \sum_{k=1}^K w_k \cdot \operatorname{MvNormal}(y \mid \mu_k, I)
$$

### Marginalizing By Hand

We could implement this marginalized model in Turing ourselves.
Turing works with log-probabilities, so the likelihood above must first be converted into log-space:
$$
\log p(y \mid w, \mu) = \operatorname{logsumexp}_{k=1}^{K} \left[\log w_k + \log \operatorname{MvNormal}(y \mid \mu_k, I) \right]
$$
where the sum over components is computed with `logsumexp` from the [`LogExpFunctions.jl` package](https://juliastats.org/LogExpFunctions.jl/stable/).
Each observation's log-likelihood can then be added to the model's log-probability manually with `Turing.@addlogprob!`, giving us the following model:

```{julia}
#| output: false
using LogExpFunctions

@model function gmm_marginalized(x)
    K = 2
    D, N = size(x)
    μ ~ Bijectors.ordered(MvNormal(Zeros(K), I))
    w ~ Dirichlet(K, 1.0)
    dists = [MvNormal(Fill(μₖ, D), I) for μₖ in μ]
    for i in 1:N
        lvec = Vector(undef, K)
        for k in 1:K
            # log(wₖ) + log p(xᵢ ∣ μₖ) for component k.
            lvec[k] = log(w[k]) + logpdf(dists[k], x[:, i])
        end
        Turing.@addlogprob! logsumexp(lvec)
    end
end
```

::: {.callout-warning collapse="false"}
## Manually Incrementing Probability

When possible, use of `Turing.@addlogprob!` should be avoided, as it exists outside the usual structure of a Turing model.
In most cases, a custom distribution should be used instead.

The next section demonstrates the preferred approach: using the `MixtureModel` distribution we have seen already to perform the marginalization automatically.
:::

### Marginalizing For Free With Distributions.jl's MixtureModel Implementation

We can use Turing's `~` syntax with any distribution for which `Distributions.jl` implements `logpdf` and `rand`.
The `MixtureModel` distribution it provides is evaluated via `logpdf(MixtureModel([Component_Distributions], weight_vector), Y)`, where `Y` can be either a single observation or a vector of observations.
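To see that the two formulations agree, here is a quick standalone check, using arbitrary illustrative values for the means (`μ_test`), the weights (`w_test`), and a single made-up observation `y`, that `MixtureModel`'s `logpdf` matches the hand-rolled `logsumexp` marginalization from the previous section:

```{julia}
# Two 2-D components with illustrative means and weights.
μ_test = [-3.5, 0.5]
w_test = [0.4, 0.6]
components = [MvNormal(Fill(m, 2), I) for m in μ_test]
y = [-3.0, -3.2]  # a single made-up observation

# Marginal log-likelihood by hand, as in gmm_marginalized above...
by_hand = logsumexp([log(w_test[k]) + logpdf(components[k], y) for k in 1:2])
# ...and the same quantity via MixtureModel's logpdf.
via_mixture = logpdf(MixtureModel(components, w_test), y)
by_hand ≈ via_mixture
```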
In fact, `Distributions.jl` provides [many convenient constructors](https://juliastats.org/Distributions.jl/stable/mixture/) for mixture models, allowing further simplification in common special cases.

For example, when the mixture components are of the same type, one can write `~ MixtureModel(Normal, [(μ1, σ1), (μ2, σ2)], w)`, and when the weight vector is known to allocate probability equally, it can be omitted altogether.

Since the `logpdf` implementation for a `MixtureModel` distribution is exactly the marginalization defined above, our model becomes simply:

```{julia}
#| output: false
@model function gmm_marginalized(x)
    K = 2
    D, _ = size(x)
    μ ~ Bijectors.ordered(MvNormal(Zeros(K), I))
    w ~ Dirichlet(K, 1.0)
    x ~ MixtureModel([MvNormal(Fill(μₖ, D), I) for μₖ in μ], w)
end
model = gmm_marginalized(x);
```

As we've summed out the discrete components, we can perform inference using `NUTS()` alone.

```{julia}
#| output: false
sampler = NUTS()
chains = sample(model, sampler, MCMCThreads(), nsamples, nchains; discard_initial = burn);
```

```{julia}
#| echo: false
let
    # Verify for the marginalized model that the output of the chain is as expected.
    for i in MCMCChains.chains(chains)
        # μ[1] and μ[2] can no longer switch places; check that they have found the true means.
        chain = Array(chains[:, ["μ[1]", "μ[2]"], i])
        μ_mean = vec(mean(chain; dims=1))
        @assert isapprox(sort(μ_mean), μ; rtol=0.4) "Difference between estimated mean of μ ($(sort(μ_mean))) and data-generating μ ($μ) unexpectedly large!"
    end
end
```

`NUTS()` significantly outperforms our compositional Gibbs sampler, in large part because our model is now Rao-Blackwellized thanks to the marginalization of our assignment parameter.

```{julia}
plot(chains[["μ[1]", "μ[2]"]]; legend=true)
```

## Inferred Assignments - Marginalized Model

As we've summed over possible assignments, the associated parameter is no longer available in our chain.
This is not a problem, however: given any fixed sample $(\mu, w)$, the assignment probability $p(z_i \mid y_i)$ can be recovered using Bayes' rule:
$$
p(z_i \mid y_i) = \frac{p(y_i \mid z_i) \, p(z_i)}{\sum_{k = 1}^K p(y_i \mid z_i = k) \, p(z_i = k)}
$$

Computing this quantity for every possible value of $z_i$ gives a probability vector, which is then used to sample posterior predictive assignments from a categorical distribution.
For details on the mathematics here, see [the Stan documentation on latent discrete parameters](https://mc-stan.org/docs/stan-users-guide/latent-discrete.html).

```{julia}
#| output: false
function sample_class(xi, dists, w)
    # Unnormalized log-probability of each component for the observation xi;
    # softmax normalizes these before a class label is drawn.
    lvec = [(logpdf(d, xi) + log(w[i])) for (i, d) in enumerate(dists)]
    rand(Categorical(softmax(lvec)))
end

@model function gmm_recover(x)
    K = 2
    D, N = size(x)
    μ ~ Bijectors.ordered(MvNormal(Zeros(K), I))
    w ~ Dirichlet(K, 1.0)
    dists = [MvNormal(Fill(μₖ, D), I) for μₖ in μ]
    x ~ MixtureModel(dists, w)
    # Return assignment draws for each datapoint.
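    # Each draw below conditions on the current posterior sample of (μ, w);
    # averaging the draws across the chain (via `generated_quantities` later)
    # gives the posterior mean assignment for each datapoint.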
    return [sample_class(x[:, i], dists, w) for i in 1:N]
end
```

We sample from this model as before:

```{julia}
#| output: false
model = gmm_recover(x)
chains = sample(model, sampler, MCMCThreads(), nsamples, nchains, discard_initial = burn);
```

Given the chain of posterior samples, the assignment draws for every sample can then be recovered and averaged per datapoint with:

```{julia}
assignments = mean(generated_quantities(gmm_recover(x), chains));
```

```{julia}
scatter(
    x[1, :],
    x[2, :];
    legend=false,
    title="Assignments on Synthetic Dataset - Recovered",
    zcolor=assignments,
)
```
\ No newline at end of file