Gaussian Mixture Models

The following tutorial illustrates the use of Turing for an unsupervised task, namely, clustering data using a Bayesian mixture model. The aim of this task is to infer a latent grouping (hidden structure) from unlabelled data.

Synthetic Data

We generate a synthetic dataset of \(N = 60\) two-dimensional points \(x_i \in \mathbb{R}^2\) drawn from a Gaussian mixture model. For simplicity, we use \(K = 2\) clusters with

  • equal weights, i.e., we use mixture weights \(w = [0.5, 0.5]\), and
  • isotropic Gaussian distributions of the points in each cluster.

More concretely, we use the Gaussian distributions \(\mathcal{N}([\mu_k, \mu_k]^\mathsf{T}, I)\) with parameters \(\mu_1 = -3.5\) and \(\mu_2 = 0.5\).

using Distributions
using FillArrays
using StatsPlots

using LinearAlgebra
using Random

# Set a random seed.
Random.seed!(3)

# Define Gaussian mixture model.
w = [0.5, 0.5]
μ = [-3.5, 0.5]
mixturemodel = MixtureModel([MvNormal(Fill(μₖ, 2), I) for μₖ in μ], w)

# We draw the data points.
N = 60
x = rand(mixturemodel, N);

The following plot shows the dataset.

scatter(x[1, :], x[2, :]; legend=false, title="Synthetic Dataset")

Gaussian Mixture Model in Turing

We are interested in recovering the grouping from the dataset. More precisely, we want to infer the mixture weights, the parameters \(\mu_1\) and \(\mu_2\), and the assignment of each datum to a cluster for the generative Gaussian mixture model.

In a Bayesian Gaussian mixture model with \(K\) components each data point \(x_i\) (\(i = 1,\ldots,N\)) is generated according to the following generative process. First we draw the model parameters, i.e., in our example we draw parameters \(\mu_k\) for the mean of the isotropic normal distributions and the mixture weights \(w\) of the \(K\) clusters. We use standard normal distributions as priors for \(\mu_k\) and a Dirichlet distribution with parameters \(\alpha_1 = \cdots = \alpha_K = 1\) as prior for \(w\):

\[ \begin{aligned} \mu_k &\sim \mathcal{N}(0, 1) \qquad (k = 1,\ldots,K)\\ w &\sim \operatorname{Dirichlet}(\alpha_1, \ldots, \alpha_K) \end{aligned} \]

After having constructed all the necessary model parameters, we can generate an observation by first selecting one of the clusters

\[ z_i \sim \operatorname{Categorical}(w) \qquad (i = 1,\ldots,N), \]

and then drawing the datum accordingly, i.e., in our example drawing

\[ x_i \sim \mathcal{N}([\mu_{z_i}, \mu_{z_i}]^\mathsf{T}, I) \qquad (i=1,\ldots,N). \]

For more details on Gaussian mixture models, refer to Chapter 9 of Christopher M. Bishop, Pattern Recognition and Machine Learning.
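The generative process can also be read as a recipe for simulating data. The following is a minimal sketch using only Distributions.jl, outside of Turing; the helper name simulate_gmm and its keyword arguments are hypothetical and chosen for illustration.

```julia
using Distributions
using LinearAlgebra
using Random

# Hypothetical helper (not part of the tutorial): one ancestral draw from the
# generative process described above, for D-dimensional points.
function simulate_gmm(rng::AbstractRNG, N::Int; K::Int=2, D::Int=2)
    μ = rand(rng, Normal(0, 1), K)      # μₖ ~ N(0, 1)
    w = rand(rng, Dirichlet(K, 1.0))    # w ~ Dirichlet(1, …, 1)
    z = rand(rng, Categorical(w), N)    # zᵢ ~ Categorical(w)
    x = Matrix{Float64}(undef, D, N)
    for i in 1:N
        # xᵢ ~ N([μ_{zᵢ}, …, μ_{zᵢ}]ᵀ, I)
        x[:, i] = rand(rng, MvNormal(fill(μ[z[i]], D), I))
    end
    return (; μ, w, z, x)
end

draw = simulate_gmm(MersenneTwister(1), 60)
```

Each call returns the sampled parameters together with the assignments and the data, so the latent structure of a simulated dataset can be inspected directly.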

We specify the model in Turing:

using Turing

@model function gaussian_mixture_model(x)
    # Draw the parameters for each of the K=2 clusters from a standard normal distribution.
    K = 2
    μ ~ MvNormal(Zeros(K), I)

    # Draw the weights for the K clusters from a Dirichlet distribution with parameters αₖ = 1.
    w ~ Dirichlet(K, 1.0)
    # Alternatively, one could use a fixed set of weights.
    # w = fill(1/K, K)

    # Construct categorical distribution of assignments.
    distribution_assignments = Categorical(w)

    # Construct multivariate normal distributions of each cluster.
    D, N = size(x)
    distribution_clusters = [MvNormal(Fill(μₖ, D), I) for μₖ in μ]

    # Draw assignments for each datum and generate it from the multivariate normal distribution.
    k = Vector{Int}(undef, N)
    for i in 1:N
        k[i] ~ distribution_assignments
        x[:, i] ~ distribution_clusters[k[i]]
    end

    return k
end

model = gaussian_mixture_model(x);

We run an MCMC simulation to obtain an approximation of the posterior distribution of the parameters \(\mu\) and \(w\) and assignments \(k\). We use a Gibbs sampler that combines a particle Gibbs sampler for the discrete parameters (assignments \(k\)) and a Hamiltonian Monte Carlo sampler for the continuous parameters (\(\mu\) and \(w\)). We generate multiple chains in parallel using multi-threading.

sampler = Gibbs(:k => PG(100), (:μ, :w) => HMC(0.05, 10))
nsamples = 150
nchains = 4
burn = 10
chains = sample(model, sampler, MCMCThreads(), nsamples, nchains, discard_initial = burn);

Sampling With Multiple Threads

The sample() call above assumes that you have at least two threads available in your Julia instance. If you do not, the multiple chains will run sequentially, and you may notice a warning. For more information, see the Turing documentation on sampling multiple chains.
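To check how many threads the current session has, one can query Threads.nthreads(). The snippet below is a small sketch; the warning text is illustrative and not something Turing itself prints. The thread count is fixed when Julia starts, for example with julia --threads 4 or by setting the JULIA_NUM_THREADS environment variable.

```julia
# Number of threads available to this Julia session.
nthreads = Threads.nthreads()

if nthreads < 2
    # Illustrative message only; Turing may emit its own warning when sampling.
    @warn "Only one thread is available; chains will not run concurrently." nthreads
end
```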

Inferred Mixture Model

After sampling we can visualize the trace and density of the parameters of interest.

We consider the samples of the location parameters \(\mu_1\) and \(\mu_2\) for the two clusters.

plot(chains[["μ[1]", "μ[2]"]]; legend=true)

From the plots above, we can see that the chains have converged to seemingly different values for the parameters \(\mu_1\) and \(\mu_2\). However, these actually represent the same solution: it does not matter whether we assign \(\mu_1\) to the first cluster and \(\mu_2\) to the second, or vice versa, since the mixture density, a weighted sum over the components, is the same either way. (In principle it is also possible for the parameters to swap places within a single chain, although this does not happen in this example.) For more information see the Stan documentation, or Bishop’s book, where the concept of identifiability is discussed.
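This invariance is easy to confirm numerically. The sketch below uses illustrative values for the means and weights (not posterior estimates) and compares the log-density of a mixture with that of the same mixture after swapping the two cluster labels.

```julia
using Distributions
using LinearAlgebra

# Illustrative values, not posterior estimates.
μ = [-3.5, 0.5]
w = [0.3, 0.7]

# The same mixture written with the two cluster labels swapped.
original = MixtureModel([MvNormal(fill(μₖ, 2), I) for μₖ in μ], w)
swapped = MixtureModel([MvNormal(fill(μₖ, 2), I) for μₖ in reverse(μ)], reverse(w))

# Both labelings assign the same density to any point.
point = [-1.0, 0.25]
same_density = logpdf(original, point) ≈ logpdf(swapped, point)
```

Because the likelihood cannot distinguish the two labelings, the posterior has one mode per permutation of the labels, and different chains may settle in different modes.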

Having \(\mu_1\) and \(\mu_2\) swap can complicate the interpretation of the results, especially when different chains converge to different assignments. One solution here is to enforce an ordering on our \(\mu\) vector, requiring \(\mu_k \geq \mu_{k-1}\) for all \(k\). Bijectors.jl provides a convenient function, ordered(), which can be applied to a (continuous multivariate) distribution to enforce this:

using Bijectors: ordered

@model function gaussian_mixture_model_ordered(x)
    # Draw the parameters for each of the K=2 clusters from a standard normal distribution.
    K = 2
    μ ~ ordered(MvNormal(Zeros(K), I))
    # Draw the weights for the K clusters from a Dirichlet distribution with parameters αₖ = 1.
    w ~ Dirichlet(K, 1.0)
    # Alternatively, one could use a fixed set of weights.
    # w = fill(1/K, K)
    # Construct categorical distribution of assignments.
    distribution_assignments = Categorical(w)
    # Construct multivariate normal distributions of each cluster.
    D, N = size(x)
    distribution_clusters = [MvNormal(Fill(μₖ, D), I) for μₖ in μ]
    # Draw assignments for each datum and generate it from the multivariate normal distribution.
    k = Vector{Int}(undef, N)
    for i in 1:N
        k[i] ~ distribution_assignments
        x[:, i] ~ distribution_clusters[k[i]]
    end
    return k
end

model = gaussian_mixture_model_ordered(x);

Now, re-running our model, we can see that the assigned means are consistent between chains:

chains = sample(model, sampler, MCMCThreads(), nsamples, nchains, discard_initial = burn);
plot(chains[["μ[1]", "μ[2]"]]; legend=true)

We also inspect the samples of the mixture weights \(w\).

plot(chains[["w[1]", "w[2]"]]; legend=true)

As the distributions of the samples for the parameters \(\mu_1\), \(\mu_2\), \(w_1\), and \(w_2\) are unimodal, we can safely visualize the density region of our model using the average values.

# Model with mean of samples as parameters.
μ_mean = [mean(chains, "μ[$i]") for i in 1:2]
w_mean = [mean(chains, "w[$i]") for i in 1:2]
mixturemodel_mean = MixtureModel([MvNormal(Fill(μₖ, 2), I) for μₖ in μ_mean], w_mean)
contour(
    range(-7.5, 3; length=1_000),
    range(-6.5, 3; length=1_000),
    (x, y) -> logpdf(mixturemodel_mean, [x, y]);
    widen=false,
)
scatter!(x[1, :], x[2, :]; legend=false, title="Synthetic Dataset")

Inferred Assignments

Finally, we can inspect the assignments of the data points inferred using Turing. As we can see, the dataset is partitioned into two distinct groups.

assignments = [mean(chains, "k[$i]") for i in 1:N]
scatter(
    x[1, :],
    x[2, :];
    legend=false,
    title="Assignments on Synthetic Dataset",
    zcolor=assignments,
)

Marginalizing Out The Assignments

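We can obtain the marginal posterior of the (continuous) parameters \(w, \mu\) by summing the (discrete) assignments \(z_i\) out of the likelihood. For a single observation \(y\) (a data point \(x_i\) in the notation above), the marginalized likelihood is: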

\[p(y \mid w, \mu ) = \sum_{k=1}^K w_k p_k(y \mid \mu_k)\]

In our case, this gives us:

\[p(y \mid w, \mu) = \sum_{k=1}^K w_k \cdot \operatorname{MvNormal}(y \mid \mu_k, I)\]
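As a quick numerical check of this sum, the following sketch evaluates it by hand for one illustrative point and compares the result with the density of the corresponding MixtureModel from Distributions.jl. The parameter values and the point y are assumptions chosen for illustration.

```julia
using Distributions
using LinearAlgebra

# Illustrative parameter values and observation.
w = [0.5, 0.5]
μ = [-3.5, 0.5]
components = [MvNormal(fill(μₖ, 2), I) for μₖ in μ]
y = [0.0, 1.0]

# Sum over the K possible assignments, weighting each component density by wₖ.
marginal = sum(w[k] * pdf(components[k], y) for k in eachindex(w))

# The mixture distribution computes the same quantity.
reference = pdf(MixtureModel(components, w), y)
```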

Marginalizing By Hand

We could implement the above version of the Gaussian mixture model in Turing as follows.

First, Turing uses log-probabilities, so the likelihood above must be converted into log-space:

\[\log \left( p(y \mid w, \mu) \right) = \text{logsumexp} \left[\log (w_k) + \log(\operatorname{MvNormal}(y \mid \mu_k, I)) \right]\]

Here the components are combined with logsumexp from the LogExpFunctions.jl package, which computes the logarithm of a sum of exponentials in a numerically stable way. The resulting log-likelihood can be added to the model's log-probability with @addlogprob!, giving us the following model:

using LogExpFunctions

@model function gmm_marginalized(x)
    K = 2
    D, N = size(x)
    μ ~ ordered(MvNormal(Zeros(K), I))
    w ~ Dirichlet(K, 1.0)
    dists = [MvNormal(Fill(μₖ, D), I) for μₖ in μ]
    for i in 1:N
        # Log of each term in the mixture sum: log(wₖ) + log N(xᵢ | μₖ, I).
        lvec = [log(w[k]) + logpdf(dists[k], x[:, i]) for k in 1:K]
        # Add the marginal log-likelihood of xᵢ to the model's log-probability.
        @addlogprob! logsumexp(lvec)
    end
end
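
The log-space formula can be checked outside of Turing. The sketch below, with illustrative parameter values, compares the hand-computed log-likelihood of one point against logpdf of a MixtureModel, and also shows why logsumexp is preferable to the naive log(sum(exp.(...))) when the log-densities are very negative.

```julia
using Distributions
using LinearAlgebra
using LogExpFunctions

# Illustrative parameter values and observation.
w = [0.5, 0.5]
μ = [-3.5, 0.5]
dists = [MvNormal(fill(μₖ, 2), I) for μₖ in μ]
y = [0.0, 1.0]

# log p(y | w, μ) = logsumexp over k of log(wₖ) + log N(y | μₖ, I)
lvec = [log(w[k]) + logpdf(dists[k], y) for k in eachindex(w)]
loglik = logsumexp(lvec)
reference = logpdf(MixtureModel(dists, w), y)

# Very negative log-densities: exp underflows to zero in the naive formula.
tiny = [-1000.0, -1001.0]
naive = log(sum(exp.(tiny)))   # -Inf because both exponentials underflow
stable = logsumexp(tiny)       # finite
```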

Manually Incrementing Probability

When possible, use of @addlogprob! should be avoided, as it exists outside the usual structure of a Turing model. In most cases, a custom distribution should be used instead.

The next section demonstrates the preferred method: using the MixtureModel distribution we have seen already to perform the marginalization automatically.

Marginalizing For Free With Distributions.jl’s MixtureModel Implementation

We can use Turing’s ~ syntax with any distribution for which Distributions.jl provides logpdf and rand methods. The MixtureModel distribution is one of these: its logpdf method has the form logpdf(MixtureModel([Component_Distributions], weight_vector), Y), where Y can be either a single observation or a vector of observations.

In fact, Distributions.jl provides many convenient constructors for mixture models, allowing further simplification in common special cases.

For example, when the mixture components are of the same type, one can write: ~ MixtureModel(Normal, [(μ1, σ1), (μ2, σ2)], w). When the weight vector is known to allocate probability equally, it can be omitted.
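
A short sketch of these two constructors, using a univariate mixture with illustrative parameter values (the variable names m_weighted and m_equal are chosen here for illustration):

```julia
using Distributions

# Components of the same type: pass the type and one parameter tuple per component.
m_weighted = MixtureModel(Normal, [(-2.0, 1.0), (2.0, 0.5)], [0.3, 0.7])

# Omitting the weight vector gives every component the same weight.
m_equal = MixtureModel(Normal, [(-2.0, 1.0), (2.0, 0.5)])

equal_weights = probs(m_equal)       # mixture weights of the second model
lp = logpdf(m_weighted, 0.0)         # marginal log-density at a single observation
```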

The logpdf implementation for a MixtureModel distribution is exactly the marginalization defined above, and so our model can be simplified to:

@model function gmm_marginalized(x)
    K = 2
    D, _ = size(x)
    μ ~ ordered(MvNormal(Zeros(K), I))
    w ~ Dirichlet(K, 1.0)
    x ~ MixtureModel([MvNormal(Fill(μₖ, D), I) for μₖ in μ], w)
end
model = gmm_marginalized(x);

As we have summed out the discrete components, we can perform inference using NUTS() alone.

sampler = NUTS()
chains = sample(model, sampler, MCMCThreads(), nsamples, nchains; discard_initial = burn);

NUTS() significantly outperforms our compositional Gibbs sampler, in large part because our model is now Rao-Blackwellized thanks to the marginalization of our assignment parameter.

plot(chains[["μ[1]", "μ[2]"]], legend=true)

Inferred Assignments With The Marginalized Model

As we have summed over possible assignments, the latent parameter representing the assignments is no longer available in our chain. This is not a problem, however, as given any fixed sample \((\mu, w)\), the assignment probability \(p(z_i = k \mid y_i)\) can be recovered using Bayes’ theorem:

\[p(z_i = k \mid y_i) = \frac{p(y_i \mid z_i = k) \, p(z_i = k)}{\sum_{j = 1}^K p(y_i \mid z_i = j) \, p(z_i = j)}\]

Here \(p(z_i = k) = w_k\), and \(p(y_i \mid z_i = k)\) is the density of the \(k\)-th mixture component. Computing this quantity for every \(k = 1, \ldots, K\) yields a probability vector, which is then used to sample posterior predictive assignments from a categorical distribution. For details on the mathematics here, see the Stan documentation on latent discrete parameters.

function sample_class(xi, dists, w)
    lvec = [(logpdf(d, xi) + log(w[i])) for (i, d) in enumerate(dists)]
    rand(Categorical(softmax(lvec)))
end

@model function gmm_recover(x)
    K = 2
    D, N = size(x)
    μ ~ ordered(MvNormal(Zeros(K), I))
    w ~ Dirichlet(K, 1.0)
    dists = [MvNormal(Fill(μₖ, D), I) for μₖ in μ]
    x ~ MixtureModel(dists, w)
    # Return assignment draws for each datapoint.
    return [sample_class(x[:, i], dists, w) for i in 1:N]
end
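
The probability vector behind each assignment draw can also be inspected directly. The following sketch defines a hypothetical helper, assignment_probs, that returns the normalized probabilities instead of sampling from them; the parameter values are illustrative.

```julia
using Distributions
using LinearAlgebra
using LogExpFunctions

# Hypothetical helper: posterior probability of each cluster for one data point,
# given fixed component distributions and weights.
function assignment_probs(xi, dists, w)
    lvec = [logpdf(d, xi) + log(w[k]) for (k, d) in enumerate(dists)]
    return softmax(lvec)  # exponentiate and normalize in a stable way
end

# Illustrative parameter values.
w = [0.5, 0.5]
dists = [MvNormal(fill(μₖ, 2), I) for μₖ in [-3.5, 0.5]]

# A point at the mean of the first component is assigned to it almost surely.
probs_near_first = assignment_probs([-3.5, -3.5], dists, w)
```

Averaging such probability vectors over posterior samples of \((\mu, w)\) would give a smoother summary than averaging sampled labels, at the cost of storing a vector per data point.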

We sample from this model as before:

model = gmm_recover(x)
chains = sample(model, sampler, MCMCThreads(), nsamples, nchains, discard_initial = burn);

Given a sample from the marginalized posterior, these assignments can be recovered with:

assignments = mean(returned(gmm_recover(x), chains));
scatter(
    x[1, :],
    x[2, :];
    legend=false,
    title="Assignments on Synthetic Dataset - Recovered",
    zcolor=assignments,
)