In many applications it is desirable to allow the model to adjust its complexity to the amount of data. Consider, for example, the task of assigning objects to clusters or groups. This task often involves specifying the number of groups in advance. However, it is often not known beforehand how many groups exist. Moreover, in some applications, e.g. modelling topics in text documents or grouping species, the number of examples per group is heavy-tailed. This makes it impossible to predefine the number of groups and requires the model to form new groups when data points from previously unseen groups are observed.
A natural approach for such applications is the use of non-parametric models. This tutorial introduces how to use the Dirichlet process to build a mixture of infinitely many Gaussians in Turing. For further information on Bayesian nonparametrics and the Dirichlet process, we refer to the introduction by Zoubin Ghahramani and the book “Fundamentals of Nonparametric Bayesian Inference” by Subhashis Ghosal and Aad van der Vaart.
using Turing
Mixture Model
Before introducing infinite mixture models in Turing, we will briefly review the construction of finite mixture models. Subsequently, we will show how to use the Chinese restaurant process construction of a Dirichlet process for non-parametric clustering.
Two-Component Model
First, consider the simple case of a mixture model with two Gaussian components with fixed covariance. The generative process of such a model can be written as:
\[
\begin{align}
\pi_1 &\sim \mathrm{Beta}(a, b) \\
\pi_2 &= 1 - \pi_1 \\
\mu_1, \mu_2 &\sim \mathrm{Normal}(\mu_0, \Sigma_0) \\
z_i &\sim \mathrm{Categorical}(\pi_1, \pi_2) \\
x_i &\sim \mathrm{Normal}(\mu_{z_i}, \Sigma)
\end{align}
\]
where \(\pi_1, \pi_2\) are the mixing weights of the mixture model, i.e. \(\pi_1 + \pi_2 = 1\), and \(z_i\) is a latent assignment of the observation \(x_i\) to a component (Gaussian).
We can implement this model in Turing for 1D data as follows:
@model function two_model(x)
    # Hyper-parameters
    μ0 = 0.0
    σ0 = 1.0

    # Draw weights.
    π1 ~ Beta(1, 1)
    π2 = 1 - π1

    # Draw locations of the components.
    μ1 ~ Normal(μ0, σ0)
    μ2 ~ Normal(μ0, σ0)

    # Draw latent assignment.
    z ~ Categorical([π1, π2])

    # Draw observation from selected component.
    if z == 1
        x ~ Normal(μ1, 1.0)
    else
        x ~ Normal(μ2, 1.0)
    end
end
two_model (generic function with 2 methods)
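As a quick usage check, we can condition the model on a single observation and sample from the posterior. Since the assignment \(z\) is discrete, a particle-based sampler is needed; the observation value, sampler choice, and sample count below are illustrative assumptions, not part of the original tutorial:
# Hypothetical usage sketch: condition on one observation and sample.
# SMC can handle the discrete assignment z; all settings here are assumptions.
chn = sample(two_model(1.5), SMC(), 1000)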
Finite Mixture Model
If we have more than two components, this model can elegantly be extended using a Dirichlet distribution as prior for the mixing weights \(\pi_1, \dots, \pi_K\). Note that the Dirichlet distribution is the multivariate generalization of the beta distribution. The resulting model can be written as:
\[
\begin{align}
(\pi_1, \dots, \pi_K) &\sim \mathrm{Dirichlet}(K, \alpha) \\
\mu_k &\sim \mathrm{Normal}(\mu_0, \Sigma_0), \;\; \forall k \\
z &\sim \mathrm{Categorical}(\pi_1, \dots, \pi_K) \\
x &\sim \mathrm{Normal}(\mu_z, \Sigma)
\end{align}
\]
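Before taking the limit, it may help to see this finite model in code. The following is a minimal sketch under the assumptions of a symmetric Dirichlet prior and unit observation variance; the function name finite_mixture is ours, not from the tutorial:
# Minimal sketch of the K-component model above (function name and unit
# observation variance are assumptions).
@model function finite_mixture(x, K)
    # Hyper-parameters.
    α = 1.0
    μ0 = 0.0
    σ0 = 1.0

    # Draw mixing weights (π_1, …, π_K) from a symmetric Dirichlet prior.
    w ~ Dirichlet(K, α)

    # Draw locations of the components.
    μ ~ filldist(Normal(μ0, σ0), K)

    # Draw latent assignment and observation.
    z ~ Categorical(w)
    x ~ Normal(μ[z], 1.0)
end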
The question now arises: is there a generalization of the Dirichlet distribution for which the dimensionality \(K\) is infinite, i.e. \(K = \infty\)?
To implement an infinite Gaussian mixture model in Turing, we first need to load the Turing.RandomMeasures module. RandomMeasures contains a variety of tools useful in nonparametrics.
using Turing.RandomMeasures
We will now use the fact that one can integrate out the mixing weights in a Gaussian mixture model, allowing us to arrive at the Chinese restaurant process construction. See Carl E. Rasmussen: The Infinite Gaussian Mixture Model, NIPS (2000) for details.
In fact, if the mixing weights are integrated out, the conditional prior for the latent variable \(z\) is given by:
\[
p(z_i = k \mid z_{\not i}, \alpha) = \frac{n_k + \alpha/K}{N - 1 + \alpha}
\]
where \(z_{\not i}\) are the latent assignments of all observations except observation \(i\). Note that we use \(n_k\) to denote the number of observations at component \(k\) excluding observation \(i\). The parameter \(\alpha\) is the concentration parameter of the Dirichlet distribution used as prior over the mixing weights.
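For concreteness: with \(K = 2\), \(\alpha = 1\), \(N = 4\), and counts \(n_1 = 2\) and \(n_2 = 1\) excluding observation \(i\), we get \(p(z_i = 1 \mid z_{\not i}, \alpha) = (2 + 1/2)/(4 - 1 + 1) = 0.625\) and \(p(z_i = 2 \mid z_{\not i}, \alpha) = (1 + 1/2)/4 = 0.375\).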
Chinese Restaurant Process
To obtain the Chinese restaurant process construction, we can now derive the conditional prior in the limit \(K \rightarrow \infty\):
\[
\begin{align}
p(z_i = k \mid z_{\not i}, \alpha) &= \frac{n_k}{N - 1 + \alpha} \quad \text{for an existing component } k \\
p(z_i = K + 1 \mid z_{\not i}, \alpha) &= \frac{\alpha}{N - 1 + \alpha} \quad \text{for a new component}
\end{align}
\]
These equations show that the conditional prior probability of assigning an observation to an existing component is proportional to the number of observations already assigned to it, meaning that the Chinese restaurant process has a “rich get richer” property.
To get a better understanding of this property, we can plot the cluster chosen for each new observation drawn from the conditional prior.
# Concentration parameter.
α = 10.0

# Random measure, e.g. Dirichlet process.
rpm = DirichletProcess(α)

# Cluster assignments for each observation.
z = Vector{Int}()

# Maximum number of observations we observe.
Nmax = 500

for i in 1:Nmax
    # Number of observations per cluster.
    K = isempty(z) ? 0 : maximum(z)
    nk = Vector{Int}(map(k -> sum(z .== k), 1:K))

    # Draw new assignment.
    push!(z, rand(ChineseRestaurantProcess(rpm, nk)))
end
using Plots

# Plot the cluster assignments over time.
@gif for i in 1:Nmax
    scatter(
        collect(1:i),
        z[1:i];
        markersize=2,
        xlabel="observation (i)",
        ylabel="cluster (k)",
        legend=false,
    )
end
Further, we can see that the number of clusters grows only logarithmically with the number of observations. This is a side-effect of the “rich-get-richer” property: we expect a few large clusters, so the number of clusters grows much more slowly than the number of observations.
We can also see from the equations above that the concentration parameter \(\alpha\) allows us to control the number of clusters formed a priori.
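To make the logarithmic growth concrete, we can compare the simulated number of clusters with the known prior expectation for the Chinese restaurant process, \(\mathbb{E}[K_N] = \sum_{i=1}^{N} \frac{\alpha}{\alpha + i - 1} \approx \alpha \log(1 + N/\alpha)\). The following is a small sketch that reuses α, z, and Nmax from the simulation above:
# Sketch: compare the simulated number of clusters with the CRP prior
# expectation (reuses α, z, and Nmax from the simulation above).
K_observed = maximum(z)
K_expected = sum(α / (α + i - 1) for i in 1:Nmax)  # ≈ α * log(1 + Nmax / α)
println("observed clusters: ", K_observed, ", expected: ", round(K_expected; digits=1))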
In Turing we can implement an infinite Gaussian mixture model using the Chinese restaurant process construction of a Dirichlet process as follows:
@model function infiniteGMM(x)
    # Hyper-parameters, i.e. concentration parameter and parameters of H.
    α = 1.0
    μ0 = 0.0
    σ0 = 1.0

    # Define random measure, e.g. Dirichlet process.
    rpm = DirichletProcess(α)

    # Define the base distribution, i.e. expected value of the Dirichlet process.
    H = Normal(μ0, σ0)

    # Latent assignment.
    z = zeros(Int, length(x))

    # Locations of the infinitely many clusters.
    μ = zeros(Float64, 0)

    for i in 1:length(x)
        # Number of clusters.
        K = maximum(z)
        nk = Vector{Int}(map(k -> sum(z .== k), 1:K))

        # Draw the latent assignment.
        z[i] ~ ChineseRestaurantProcess(rpm, nk)

        # Create a new cluster?
        if z[i] > K
            push!(μ, 0.0)

            # Draw location of new cluster.
            μ[z[i]] ~ H
        end

        # Draw observation.
        x[i] ~ Normal(μ[z[i]], 1.0)
    end
end
infiniteGMM (generic function with 2 methods)
We can now use Turing to infer the assignments of some data points. First, we will create some random data that comes from three clusters, with means of 0, -5, and 10.
using Plots, Random

# Generate some test data.
Random.seed!(1)
data = vcat(randn(10), randn(10) .- 5, randn(10) .+ 10)
data .-= mean(data)
data /= std(data);
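The plotting code below refers to a chain of posterior samples, chain, and a sample count, iterations, which have not been defined yet. A minimal sketch to produce them is given here; the choice of the SMC sampler, the seed, and the number of iterations are assumptions:
# Sketch: run inference to obtain `chain` and `iterations` used below.
# Sampler choice (SMC), seed, and iteration count are assumptions.
Random.seed!(2)
iterations = 1000
model_fun = infiniteGMM(data)
chain = sample(model_fun, SMC(), iterations)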
Finally, we can plot the number of clusters in each sample.
# Extract the number of clusters for each sample of the Markov chain.
k = map(
    t -> length(unique(vec(chain[t, MCMCChains.namesingroup(chain, :z), :].value))),
    1:iterations,
);

# Visualize the number of clusters.
plot(k; xlabel="Iteration", ylabel="Number of clusters", label="Chain 1")
If we visualize the histogram of the number of clusters sampled from our posterior, we observe that the model seems to prefer 3 clusters, which is the true number of clusters. Note that the number of clusters in a Dirichlet process mixture model is not limited a priori and will grow to infinity with probability one as the number of observations grows. However, when conditioned on data, the posterior concentrates on a finite number of clusters, so the resulting model has a finite number of clusters in practice. It is, however, not guaranteed that the posterior of a Dirichlet process Gaussian mixture model converges to the true number of clusters, even when the data come from a finite mixture model. See Jeffrey Miller and Matthew Harrison: A simple example of Dirichlet process mixture inconsistency for the number of components for details.
histogram(k; xlabel="Number of clusters", legend=false)
One issue with the Chinese restaurant process construction is that the number of latent parameters we need to sample scales with the number of observations. It may be desirable to use alternative constructions in certain cases. Alternative methods of constructing a Dirichlet process can be employed via the following representations: