Michael Betancourt recently wrote a nice case study describing the problems often encountered with Gaussian mixture models, specifically the estimation of the mixture parameters and identifiability, i.e. the problem of labelling the mixture components (http://mc-stan.org/documentation/case-studies/identifying_mixture_models.html). There have also been suggestions that GMMs can't easily be done in Stan. I've found various examples online of simple 2D Gaussian mixtures, and one (wrong) example of a multivariate GMM. I wanted to demonstrate that Stan can in fact fit multivariate GMMs, and do so quickly, although, as Mike has already discussed, the identifiability problems remain inherent in the model.
For this I will use R, but Stan also has interfaces for Python, Ruby and other languages. First, let's load the required libraries:
library(MASS)
require(rstan)
Then we need to generate some toy data. Working in a 4-dimensional parameter space, I want to create 3 Gaussian components at different locations:
#first cluster
mu1=c(0,0,0,0)
sigma1=matrix(c(0.1,0,0,0,
                0,0.1,0,0,
                0,0,0.1,0,
                0,0,0,0.1), ncol=4, nrow=4, byrow=TRUE)
norm1=mvrnorm(30, mu1, sigma1)

#second cluster
mu2=c(7,7,7,7)
sigma2=sigma1
norm2=mvrnorm(30, mu2, sigma2)

#third cluster
mu3=c(3,3,3,3)
sigma3=sigma1
norm3=mvrnorm(30, mu3, sigma3)

norms=rbind(norm1,norm2,norm3) #combine the 3 mixtures together

N=90  #total number of data points
Dim=4 #number of dimensions
y=array(as.vector(norms), dim=c(N,Dim))

mixture_data=list(N=N, D=4, K=3, y=y)
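Since the mvrnorm draws are random, the exact numbers shown later will differ between runs. A couple of optional lines (the seed value here is an arbitrary choice of mine) make the simulation reproducible and give a quick sanity check:

set.seed(123)   #fix the RNG seed before the mvrnorm calls for reproducibility
colMeans(norm2) #quick check: the second cluster's sample means should be close to 7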
The model only takes a few lines of code:
mixture_model<-'
data {
 int D; //number of dimensions
 int K; //number of gaussians
 int N; //number of data points
 vector[D] y[N]; //data
}
parameters {
 simplex[K] theta; //mixing proportions
 ordered[D] mu[K]; //mixture component means
 cholesky_factor_corr[D] L[K]; //cholesky factor of the correlation matrix
}
model {
 real ps[K];

 for(k in 1:K){
  mu[k] ~ normal(0,3);
  L[k] ~ lkj_corr_cholesky(4);
 }

 for (n in 1:N){
  for (k in 1:K){
   //log mixing weight plus log density of component k
   ps[k] = log(theta[k]) + multi_normal_cholesky_lpdf(y[n] | mu[k], L[k]);
  }
  target += log_sum_exp(ps);
 }
}
'
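The key line is the target += log_sum_exp(ps) statement. Because Stan cannot sample discrete component labels, the label is marginalised out of the likelihood analytically; on the log scale, for each data point $y_n$ this is

$$
\log p(y_n \mid \theta, \mu, L) \;=\; \log \sum_{k=1}^{K} \theta_k \, \mathcal{N}\!\left(y_n \mid \mu_k, L_k L_k^{\top}\right)
\;=\; \operatorname{log\_sum\_exp}_{k}\Bigl(\log \theta_k + \log \mathcal{N}\!\left(y_n \mid \mu_k, L_k L_k^{\top}\right)\Bigr),
$$

which is exactly what ps[k] and log_sum_exp compute in a numerically stable way.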
Running the model in R takes only one line too. Here I use 11,000 iterations, 1,000 of which are warmup (used to adapt the NUTS sampler parameters). For speed I'll use only one chain:
fit=stan(model_code=mixture_model, data=mixture_data, iter=11000, warmup=1000, chains=1)
SAMPLING FOR MODEL '16de4bc17f41669412586868e09d4c65' NOW (CHAIN 1).

Chain 1, Iteration:     1 / 11000 [  0%]  (Warmup)
Chain 1, Iteration:  1001 / 11000 [  9%]  (Sampling)
Chain 1, Iteration:  2100 / 11000 [ 19%]  (Sampling)
Chain 1, Iteration:  3200 / 11000 [ 29%]  (Sampling)
Chain 1, Iteration:  4300 / 11000 [ 39%]  (Sampling)
Chain 1, Iteration:  5400 / 11000 [ 49%]  (Sampling)
Chain 1, Iteration:  6500 / 11000 [ 59%]  (Sampling)
Chain 1, Iteration:  7600 / 11000 [ 69%]  (Sampling)
Chain 1, Iteration:  8700 / 11000 [ 79%]  (Sampling)
Chain 1, Iteration:  9800 / 11000 [ 89%]  (Sampling)
Chain 1, Iteration: 10900 / 11000 [ 99%]  (Sampling)
Chain 1, Iteration: 11000 / 11000 [100%]  (Sampling)

 Elapsed Time: 13.0271 seconds (Warm-up)
               99.2967 seconds (Sampling)
               112.324 seconds (Total)
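For a more thorough convergence check you would normally run several chains, which also gives more reliable split-chain Rhat values; in rstan this is easy to parallelise. A minimal sketch (the chain count and use of all cores are just illustrative choices; the rest of this post sticks with the single-chain fit above):

options(mc.cores = parallel::detectCores()) #run chains in parallel across available cores
fit_multi <- stan(model_code=mixture_model, data=mixture_data,
                  iter=11000, warmup=1000, chains=4)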
From the results we can see we get good convergence:
print(fit)
Inference for Stan model: 16de4bc17f41669412586868e09d4c65.
1 chains, each with iter=11000; warmup=1000; thin=1;
post-warmup draws per chain=10000, total post-warmup draws=10000.

             mean se_mean   sd    2.5%     25%     50%     75%   97.5% n_eff Rhat
theta[1]     0.33    0.00 0.05    0.24    0.30    0.33    0.37    0.43 10000 1.00
theta[2]     0.33    0.00 0.05    0.24    0.30    0.33    0.37    0.43 10000 1.00
theta[3]     0.33    0.00 0.05    0.24    0.30    0.33    0.36    0.43 10000 1.00
mu[1,1]     -0.09    0.01 0.17   -0.41   -0.19   -0.10    0.01    0.28   751 1.00
mu[1,2]     -0.01    0.01 0.16   -0.33   -0.11   -0.02    0.08    0.35   952 1.00
mu[1,3]      0.13    0.00 0.16   -0.20    0.04    0.13    0.22    0.45  3100 1.00
mu[1,4]      0.19    0.00 0.16   -0.15    0.10    0.19    0.29    0.51  2086 1.00
mu[2,1]      6.85    0.01 0.12    6.57    6.78    6.87    6.93    7.06   133 1.02
mu[2,2]      6.90    0.01 0.12    6.63    6.84    6.92    6.98    7.11   129 1.02
mu[2,3]      6.95    0.01 0.11    6.69    6.89    6.96    7.01    7.15   333 1.01
mu[2,4]      7.03    0.00 0.11    6.79    6.97    7.03    7.09    7.27   579 1.00
mu[3,1]      2.78    0.00 0.13    2.50    2.71    2.80    2.87    2.99  1704 1.00
mu[3,2]      2.84    0.00 0.12    2.56    2.77    2.86    2.92    3.05  1005 1.00
mu[3,3]      2.91    0.01 0.13    2.62    2.84    2.93    2.99    3.17   179 1.00
mu[3,4]      3.12    0.01 0.14    2.85    3.04    3.11    3.21    3.42   451 1.00
L[1,1,1]     1.00    0.00 0.00    1.00    1.00    1.00    1.00    1.00 10000  NaN
L[1,1,2]     0.00    0.00 0.00    0.00    0.00    0.00    0.00    0.00 10000  NaN
L[1,1,3]     0.00    0.00 0.00    0.00    0.00    0.00    0.00    0.00 10000  NaN
L[1,1,4]     0.00    0.00 0.00    0.00    0.00    0.00    0.00    0.00 10000  NaN
L[1,2,1]     0.73    0.02 0.41   -0.78    0.80    0.85    0.89    0.93   263 1.00
L[1,2,2]     0.54    0.00 0.11    0.38    0.46    0.52    0.60    0.83  1141 1.00
L[1,2,3]     0.00    0.00 0.00    0.00    0.00    0.00    0.00    0.00 10000  NaN
L[1,2,4]     0.00    0.00 0.00    0.00    0.00    0.00    0.00    0.00 10000  NaN
L[1,3,1]     0.16    0.08 0.77   -0.89   -0.78    0.70    0.82    0.89    97 1.00
L[1,3,2]     0.04    0.02 0.24   -0.44   -0.13    0.07    0.21    0.49   205 1.00
L[1,3,3]     0.56    0.00 0.11    0.40    0.49    0.55    0.62    0.81  7240 1.00
L[1,3,4]     0.00    0.00 0.00    0.00    0.00    0.00    0.00    0.00 10000  NaN
L[1,4,1]     0.15    0.08 0.75   -0.89   -0.76    0.66    0.80    0.88    89 1.00
L[1,4,2]     0.20    0.01 0.23   -0.27    0.05    0.23    0.36    0.65   351 1.00
L[1,4,3]     0.31    0.00 0.15    0.03    0.23    0.31    0.40    0.63  2280 1.00
L[1,4,4]     0.44    0.00 0.08    0.31    0.38    0.43    0.49    0.64  1879 1.00
L[2,1,1]     1.00    0.00 0.00    1.00    1.00    1.00    1.00    1.00 10000  NaN
L[2,1,2]     0.00    0.00 0.00    0.00    0.00    0.00    0.00    0.00 10000  NaN
L[2,1,3]     0.00    0.00 0.00    0.00    0.00    0.00    0.00    0.00 10000  NaN
L[2,1,4]     0.00    0.00 0.00    0.00    0.00    0.00    0.00    0.00 10000  NaN
L[2,2,1]     0.03    0.11 0.71   -0.87   -0.75    0.36    0.73    0.85    43 1.02
L[2,2,2]     0.69    0.00 0.13    0.49    0.60    0.68    0.77    0.98  1978 1.00
L[2,2,3]     0.00    0.00 0.00    0.00    0.00    0.00    0.00    0.00 10000  NaN
L[2,2,4]     0.00    0.00 0.00    0.00    0.00    0.00    0.00    0.00 10000  NaN
L[2,3,1]     0.45    0.05 0.59   -0.78    0.36    0.76    0.83    0.90   130 1.00
L[2,3,2]     0.00    0.04 0.39   -0.70   -0.33    0.05    0.33    0.69    89 1.04
L[2,3,3]     0.53    0.00 0.11    0.37    0.46    0.52    0.59    0.79   935 1.00
L[2,3,4]     0.00    0.00 0.00    0.00    0.00    0.00    0.00    0.00 10000  NaN
L[2,4,1]     0.25    0.07 0.68   -0.80   -0.57    0.70    0.84    0.91    89 1.02
L[2,4,2]    -0.03    0.07 0.42   -0.74   -0.38   -0.07    0.33    0.72    40 1.02
L[2,4,3]    -0.15    0.02 0.22   -0.55   -0.30   -0.17    0.01    0.27   122 1.01
L[2,4,4]     0.46    0.00 0.09    0.31    0.40    0.45    0.51    0.69   883 1.00
L[3,1,1]     1.00    0.00 0.00    1.00    1.00    1.00    1.00    1.00 10000  NaN
L[3,1,2]     0.00    0.00 0.00    0.00    0.00    0.00    0.00    0.00 10000  NaN
L[3,1,3]     0.00    0.00 0.00    0.00    0.00    0.00    0.00    0.00 10000  NaN
L[3,1,4]     0.00    0.00 0.00    0.00    0.00    0.00    0.00    0.00 10000  NaN
L[3,2,1]     0.45    0.04 0.59   -0.83    0.49    0.75    0.82    0.89   178 1.01
L[3,2,2]     0.66    0.00 0.12    0.46    0.57    0.64    0.73    0.96  1936 1.00
L[3,2,3]     0.00    0.00 0.00    0.00    0.00    0.00    0.00    0.00 10000  NaN
L[3,2,4]     0.00    0.00 0.00    0.00    0.00    0.00    0.00    0.00 10000  NaN
L[3,3,1]     0.58    0.08 0.55   -0.81    0.71    0.82    0.87    0.92    48 1.01
L[3,3,2]     0.01    0.03 0.28   -0.57   -0.19    0.07    0.21    0.48   121 1.01
L[3,3,3]     0.53    0.00 0.10    0.37    0.45    0.51    0.58    0.78   580 1.00
L[3,3,4]     0.00    0.00 0.00    0.00    0.00    0.00    0.00    0.00 10000  NaN
L[3,4,1]    -0.44    0.06 0.64   -0.90   -0.84   -0.77   -0.51    0.87   104 1.00
L[3,4,2]    -0.13    0.02 0.29   -0.59   -0.33   -0.20    0.07    0.49   236 1.00
L[3,4,3]    -0.10    0.02 0.22   -0.48   -0.25   -0.12    0.05    0.35   101 1.02
L[3,4,4]     0.48    0.00 0.09    0.35    0.42    0.47    0.53    0.69  2954 1.00
lp__      -443.66    0.15 5.00 -454.49 -446.87 -443.28 -440.12 -435.03  1053 1.00

Samples were drawn using NUTS(diag_e) at Tue Mar 21 10:26:33 2017.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at convergence, Rhat=1).
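Rather than scanning the whole table by eye, you can also pull the diagnostics out programmatically; a small sketch using rstan's summary method:

fit_summary <- summary(fit)$summary          #matrix with one row per parameter
max(fit_summary[, "Rhat"], na.rm = TRUE)     #worst-case Rhat across all parameters
min(fit_summary[, "n_eff"], na.rm = TRUE)    #smallest effective sample size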
As you can see we get good Rhat values and reasonable effective sample sizes, and the run time is modest. We also recover the input parameters well:
params=extract(fit)

#density plots of the posteriors of the mixture means
par(mfrow=c(2,2))

plot(density(params$mu[,1,1]), ylab='', xlab='mu[1]', main='')
lines(density(params$mu[,1,2]), col=rgb(0,0,0,0.7))
lines(density(params$mu[,1,3]), col=rgb(0,0,0,0.4))
lines(density(params$mu[,1,4]), col=rgb(0,0,0,0.1))
abline(v=c(0), lty='dotted', col='red', lwd=2)

plot(density(params$mu[,2,1]), ylab='', xlab='mu[2]', main='')
lines(density(params$mu[,2,2]), col=rgb(0,0,0,0.7))
lines(density(params$mu[,2,3]), col=rgb(0,0,0,0.4))
lines(density(params$mu[,2,4]), col=rgb(0,0,0,0.1))
abline(v=c(7), lty='dotted', col='red', lwd=2)

plot(density(params$mu[,3,1]), ylab='', xlab='mu[3]', main='')
lines(density(params$mu[,3,2]), col=rgb(0,0,0,0.7))
lines(density(params$mu[,3,3]), col=rgb(0,0,0,0.4))
lines(density(params$mu[,3,4]), col=rgb(0,0,0,0.1))
abline(v=c(3), lty='dotted', col='red', lwd=2)

Marginalised 1D posteriors of the three Gaussian mixture means. The red dotted lines mark the true values.
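If you want to check for the label-switching behaviour discussed in Mike's case study, looking at the traces of the component means is a quick diagnostic; a minimal sketch using rstan's built-in plotting:

traceplot(fit, pars='mu') #chains jumping between well-separated levels would indicate label switching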