
Substantially Update GMM-01 #447


Merged
merged 1 commit on May 24, 2024
8 changes: 4 additions & 4 deletions tutorials/01-gaussian-mixture-model/Manifest.toml
@@ -1,8 +1,8 @@
# This file is machine-generated - editing it directly is not advised

julia_version = "1.10.0"
julia_version = "1.10.3"
manifest_format = "2.0"
project_hash = "a0a3ca6aa1eb54296a719c0190365d414d855bc0"
project_hash = "eaea37cdde9f1d06a5695c8ca16ab89aa3a9a2b4"

[[deps.ADTypes]]
git-tree-sha1 = "daf26bbdec60d9ca1c0003b70f389d821ddb4224"
@@ -383,7 +383,7 @@ weakdeps = ["Dates", "LinearAlgebra"]
[[deps.CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
version = "1.0.5+1"
version = "1.1.1+0"

[[deps.CompositionsBase]]
git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad"
@@ -1238,7 +1238,7 @@ version = "1.3.5+1"
[[deps.OpenBLAS_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"
version = "0.3.23+2"
version = "0.3.23+4"

[[deps.OpenLibm_jll]]
deps = ["Artifacts", "Libdl"]
15 changes: 8 additions & 7 deletions tutorials/01-gaussian-mixture-model/Project.toml
@@ -1,7 +1,8 @@
[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
237 changes: 223 additions & 14 deletions tutorials/01-gaussian-mixture-model/index.qmd
@@ -115,15 +115,17 @@ We generate multiple chains in parallel using multi-threading.

```{julia}
#| output: false
#| echo: false
setprogress!(false)
```

```{julia}
#| output: false
sampler = Gibbs(PG(100, :k), HMC(0.05, 10, :μ, :w))
nsamples = 100
nchains = 3
chains = sample(model, sampler, MCMCThreads(), nsamples, nchains);
nsamples = 150
nchains = 4
burn = 10
chains = sample(model, sampler, MCMCThreads(), nsamples, nchains, discard_initial = burn);
```

::: {.callout-warning collapse="true"}
@@ -152,32 +154,81 @@ After sampling we can visualize the trace and density of the parameters of interest.
We consider the samples of the location parameters $\mu_1$ and $\mu_2$ for the two clusters.

```{julia}
plot(chains[["μ[1]", "μ[2]"]]; colordim=:parameter, legend=true)
plot(chains[["μ[1]", "μ[2]"]]; legend=true)
```

It can happen that the modes of $\mu_1$ and $\mu_2$ switch between chains.
For more information see the [Stan documentation](https://mc-stan.org/users/documentation/case-studies/identifying_mixture_models.html) for potential solutions.
This is because it is possible for either model parameter $\mu_k$ to align with either of the true cluster means, and the assignment need not be consistent between chains; for more information see the [Stan documentation](https://mc-stan.org/users/documentation/case-studies/identifying_mixture_models.html).

We also inspect the samples of the mixture weights $w$.
That is, the posterior is fundamentally multimodal, and different chains can end up in different modes, complicating inference.
One solution here is to enforce an ordering on our $\mu$ vector, requiring $\mu_k > \mu_{k-1}$ for all $k > 1$.
`Bijectors.jl` [provides](https://turinglang.org/Bijectors.jl/dev/transforms/#Bijectors.OrderedBijector) an easy transformation (`ordered()`) for this purpose:

```{julia}
@model function gaussian_mixture_model_ordered(x)
    # Draw the parameters for each of the K=2 clusters from a standard normal distribution.
    K = 2
    μ ~ Bijectors.ordered(MvNormal(Zeros(K), I))
    # Draw the weights for the K clusters from a Dirichlet distribution with parameters αₖ = 1.
    w ~ Dirichlet(K, 1.0)
    # Alternatively, one could use a fixed set of weights.
    # w = fill(1/K, K)
    # Construct categorical distribution of assignments.
    distribution_assignments = Categorical(w)
    # Construct multivariate normal distributions of each cluster.
    D, N = size(x)
    distribution_clusters = [MvNormal(Fill(μₖ, D), I) for μₖ in μ]
    # Draw assignments for each datum and generate it from the multivariate normal distribution.
    k = Vector{Int}(undef, N)
    for i in 1:N
        k[i] ~ distribution_assignments
        x[:, i] ~ distribution_clusters[k[i]]
    end
    return k
end

model = gaussian_mixture_model_ordered(x);
```
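
As a quick standalone check of what `ordered` gives us (illustrative only, not part of the model): every draw from the transformed prior comes out sorted, so `μ[1] ≤ μ[2]` holds by construction.

```{julia}
# Each draw from the ordered prior is a sorted vector.
ordered_prior = Bijectors.ordered(MvNormal(Zeros(2), I))
all(issorted(rand(ordered_prior)) for _ in 1:5)
```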


Now, re-running our model, we can see that the assigned means are consistent across chains:

```{julia}
#| output: false
chains = sample(model, sampler, MCMCThreads(), nsamples, nchains, discard_initial = burn);
```


```{julia}
#| echo: false
let
    # Verify that the output of the chain is as expected.
    for i in MCMCChains.chains(chains)
        # μ[1] and μ[2] can no longer switch places. Check that they've found the mean.
        chain = Array(chains[:, ["μ[1]", "μ[2]"], i])
        μ_mean = vec(mean(chain; dims=1))
        @assert isapprox(sort(μ_mean), μ; rtol=0.4) "Difference between estimated mean of μ ($(sort(μ_mean))) and data-generating μ ($μ) unexpectedly large!"
    end
end
```

```{julia}
plot(chains[["w[1]", "w[2]"]]; colordim=:parameter, legend=true)
plot(chains[["μ[1]", "μ[2]"]]; legend=true)
```

In the following, we just use the first chain to ensure the validity of our inference.
We also inspect the samples of the mixture weights $w$.

```{julia}
chain = chains[:, :, 1];
plot(chains[["w[1]", "w[2]"]]; legend=true)
```

As the distributions of the samples for the parameters $\mu_1$, $\mu_2$, $w_1$, and $w_2$ are unimodal, we can safely visualize the density region of our model using the average values.

```{julia}
# Model with mean of samples as parameters.
μ_mean = [mean(chain, "μ[$i]") for i in 1:2]
w_mean = [mean(chain, "w[$i]") for i in 1:2]
μ_mean = [mean(chains, "μ[$i]") for i in 1:2]
w_mean = [mean(chains, "w[$i]") for i in 1:2]
mixturemodel_mean = MixtureModel([MvNormal(Fill(μₖ, 2), I) for μₖ in μ_mean], w_mean)

contour(
    range(-7.5, 3; length=1_000),
    range(-6.5, 3; length=1_000),
@@ -188,12 +239,11 @@ scatter!(x[1, :], x[2, :]; legend=false, title="Synthetic Dataset")
```

## Inferred Assignments

Finally, we can inspect the assignments of the data points inferred using Turing.
As we can see, the dataset is partitioned into two distinct groups.

```{julia}
assignments = [mean(chain, "k[$i]") for i in 1:N]
assignments = [mean(chains, "k[$i]") for i in 1:N]
scatter(
    x[1, :],
    x[2, :];
@@ -202,3 +252,162 @@ scatter(
    zcolor=assignments,
)
```


## Marginalizing Out The Assignments
We can write out the marginal posterior of the continuous parameters $w$ and $\mu$ by summing the discrete assignments $z_i$ out of the likelihood:
$$
p(y \mid w, \mu ) = \sum_{k=1}^K w_k p_k(y \mid \mu_k)
$$
In our case, this gives us:
$$
p(y \mid w, \mu) = \sum_{k=1}^K w_k \cdot \operatorname{MvNormal}(y \mid \mu_k, I)
$$


### Marginalizing By Hand
We could implement this marginalized model in Turing by hand as follows.
First, Turing works with log-probabilities, so the likelihood above must be converted into log-space:
$$
\log p(y \mid w, \mu) = \operatorname{logsumexp}_{k=1}^{K} \left[ \log w_k + \log \operatorname{MvNormal}(y \mid \mu_k, I) \right]
$$

Here the sum over components is carried out with `logsumexp` from the [`LogExpFunctions.jl` package](https://juliastats.org/LogExpFunctions.jl/stable/); a small numerical illustration of why follows the model below.
The resulting per-observation log-likelihood can then be added to the model's log-probability with `Turing.@addlogprob!`, giving us the following model:

```{julia}
#| output: false
using LogExpFunctions

@model function gmm_marginalized(x)
    K = 2
    D, N = size(x)
    μ ~ Bijectors.ordered(MvNormal(Zeros(K), I))
    w ~ Dirichlet(K, 1.0)
    dists = [MvNormal(Fill(μₖ, D), I) for μₖ in μ]
    for i in 1:N
        lvec = Vector(undef, K)
        for k in 1:K
            # log(wₖ) + log MvNormal(xᵢ | μₖ, I), matching the log-space likelihood above.
            lvec[k] = log(w[k]) + logpdf(dists[k], x[:, i])
        end
        # Add the marginal log-likelihood of observation i to the model's log-probability.
        Turing.@addlogprob! logsumexp(lvec)
    end
end
```
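
As a small standalone aside (with made-up numbers purely for illustration), `logsumexp` matters here because exponentiating strongly negative log-densities underflows, whereas `logsumexp` evaluates the same quantity stably:

```{julia}
lp = [-1000.0, -1001.0]
log(sum(exp.(lp)))   # -Inf: exp underflows to zero
logsumexp(lp)        # ≈ -999.69: the numerically stable result
```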

::: {.callout-warning collapse="false"}
## Manually Incrementing Probability

When possible, use of `Turing.@addlogprob!` should be avoided, as it exists outside the
usual structure of a Turing model. In most cases, a custom distribution should be used instead.

The next section demonstrates the preferred method: using the `MixtureModel` distribution we have seen already to perform the marginalization automatically.
:::


### Marginalizing For Free With Distributions.jl's MixtureModel Implementation

We can use Turing's `~` syntax with anything for which `Distributions.jl` provides `logpdf` and `rand` methods. The `MixtureModel` distribution it provides qualifies: its `logpdf` method has the form `logpdf(MixtureModel([Component_Distributions], weight_vector), Y)`, where `Y` can be either a single observation or a vector of observations.

In fact, `Distributions.jl` provides [many convenient constructors](https://juliastats.org/Distributions.jl/stable/mixture/) for mixture models, allowing further simplification in common special cases.

For example, when the component distributions are of the same type, one can write `~ MixtureModel(Normal, [(μ1, σ1), (μ2, σ2)], w)`, and when the weight vector is known to allocate probability equally it can be omitted entirely.
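
As a short standalone sketch of these constructors (the means, standard deviations, and weights below are made up purely for illustration):

```{julia}
using Distributions

# Two-component univariate Gaussian mixture with explicit weights.
mix = MixtureModel(Normal, [(-3.5, 0.5), (0.5, 1.0)], [0.3, 0.7])
logpdf(mix, 0.0)   # marginal log-density of a single observation
rand(mix, 5)       # five draws from the mixture

# With the weight vector omitted, the components are weighted equally.
mix_uniform = MixtureModel(Normal, [(-3.5, 0.5), (0.5, 1.0)])
```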

The `logpdf` implementation for a `MixtureModel` distribution is exactly the marginalization defined above, and so our model becomes simply:

```{julia}
#| output: false
@model function gmm_marginalized(x)
    K = 2
    D, _ = size(x)
    μ ~ Bijectors.ordered(MvNormal(Zeros(K), I))
    w ~ Dirichlet(K, 1.0)
    x ~ MixtureModel([MvNormal(Fill(μₖ, D), I) for μₖ in μ], w)
end
model = gmm_marginalized(x);
```

As we've summed out the discrete components, we can perform inference using `NUTS()` alone.

```{julia}
#| output: false
sampler = NUTS()
chains = sample(model, sampler, MCMCThreads(), nsamples, nchains; discard_initial = burn);
```


```{julia}
#| echo: false
let
    # Verify for the marginalized model that the output of the chain is as expected.
    for i in MCMCChains.chains(chains)
        # μ[1] and μ[2] can no longer switch places. Check that they've found the mean.
        chain = Array(chains[:, ["μ[1]", "μ[2]"], i])
        μ_mean = vec(mean(chain; dims=1))
        @assert isapprox(sort(μ_mean), μ; rtol=0.4) "Difference between estimated mean of μ ($(sort(μ_mean))) and data-generating μ ($μ) unexpectedly large!"
    end
end
```

`NUTS()` significantly outperforms our compositional Gibbs sampler, in large part because our model is now Rao-Blackwellized thanks to
the marginalization of our assignment parameter.
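
One way to make this concrete is to compare effective sample sizes between the two runs. This is only a sketch: `chains_gibbs` below is a hypothetical name for the chains produced earlier by the `Gibbs(PG, HMC)` sampler, assuming they were kept around rather than overwritten.

```{julia}
# Effective sample sizes for NUTS on the marginalized model.
MCMCChains.summarystats(chains)
# Hypothetical comparison: ESS for the Gibbs-sampled model with explicit
# assignments is typically far lower for the same number of samples.
# MCMCChains.summarystats(chains_gibbs)
```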

```{julia}
plot(chains[["μ[1]", "μ[2]"]], legend=true)
```

## Inferred Assignments - Marginalized Model
As we've summed over possible assignments, the associated parameter is no longer available in our chain.
This is not a problem, however: given any fixed sample $(\mu, w)$, the assignment probability $p(z_i = k \mid y_i)$ can be recovered using Bayes' rule:
$$
p(z_i = k \mid y_i) = \frac{p(y_i \mid z_i = k) \, p(z_i = k)}{\sum_{k' = 1}^K p(y_i \mid z_i = k') \, p(z_i = k')}
$$

This quantity can be computed for each component $k$, resulting in a probability vector over components, which is then used to sample posterior predictive assignments from a categorical distribution.
For details on the mathematics here, see [the Stan documentation on latent discrete parameters](https://mc-stan.org/docs/stan-users-guide/latent-discrete.html).
```{julia}
#| output: false
function sample_class(xi, dists, w)
    lvec = [(logpdf(d, xi) + log(w[i])) for (i, d) in enumerate(dists)]
    return rand(Categorical(softmax(lvec)))
end

@model function gmm_recover(x)
    K = 2
    D, N = size(x)
    μ ~ Bijectors.ordered(MvNormal(Zeros(K), I))
    w ~ Dirichlet(K, 1.0)
    dists = [MvNormal(Fill(μₖ, D), I) for μₖ in μ]
    x ~ MixtureModel(dists, w)
    # Return assignment draws for each datapoint.
    return [sample_class(x[:, i], dists, w) for i in 1:N]
end
```

We sample from this model as before:

```{julia}
#| output: false
model = gmm_recover(x)
chains = sample(model, sampler, MCMCThreads(), nsamples, nchains, discard_initial = burn);
```

Given a sample from the marginalized posterior, these assignments can be recovered with:

```{julia}
assignments = mean(generated_quantities(gmm_recover(x), chains));
```

```{julia}
scatter(
    x[1, :],
    x[2, :];
    legend=false,
    title="Assignments on Synthetic Dataset - Recovered",
    zcolor=assignments,
)
```