Multivariate Gaussian Mixture Model done properly


Michael Betancourt recently wrote a nice case study describing the problems often encountered with Gaussian mixture models, specifically the estimation of the mixture parameters and identifiability, i.e. the problem of labelling the mixture components (http://mc-stan.org/documentation/case-studies/identifying_mixture_models.html). There have also been suggestions that GMMs can't easily be done in Stan. I've found various examples online of simple 2D Gaussian mixtures, and one (wrong) example of a multivariate GMM. I wanted to demonstrate that Stan can in fact fit multivariate GMMs, and very quickly! But, as Mike has already discussed, the identifiability problems are still inherent in the model.
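To make the labelling problem concrete, the mixture density is

```latex
p(y \mid \theta, \mu, \Sigma) \;=\; \sum_{k=1}^{K} \theta_k \, \mathcal{N}\!\left(y \mid \mu_k, \Sigma_k\right),
```

which is invariant under any permutation of the component labels $k$: swapping $(\theta_1,\mu_1,\Sigma_1)$ with $(\theta_2,\mu_2,\Sigma_2)$ gives exactly the same likelihood. The posterior therefore has $K!$ mirror-image modes, and nothing in the data tells the sampler which labelling to prefer.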

For this I will use R, but of course Stan also has interfaces for Python, Ruby and other languages. First, let's load the required libraries:

library(MASS)  #for mvrnorm
library(rstan)

Then we need to generate some toy data. Working in a 4-dimensional parameter space, I want to create three Gaussian components at different locations:

#first cluster
mu1=c(0,0,0,0)
sigma1=diag(0.1, 4) #diagonal 4x4 covariance with variance 0.1
norm1=mvrnorm(30, mu1, sigma1)

#second cluster
mu2=c(7,7,7,7)
sigma2=sigma1
norm2=mvrnorm(30, mu2, sigma2)

#third cluster
mu3=c(3,3,3,3)
sigma3=sigma1
norm3=mvrnorm(30, mu3, sigma3)

norms=rbind(norm1,norm2,norm3) #combine the 3 mixtures together
N=90 #total number of data points
Dim=4 #number of dimensions
y=array(as.vector(norms), dim=c(N,Dim))
mixture_data=list(N=N, D=Dim, K=3, y=y)

The model itself takes only a few lines of code:

mixture_model<-'
data {
 int D; //number of dimensions
 int K; //number of gaussians
 int N; //number of data points
 vector[D] y[N]; //data
}

parameters {
 simplex[K] theta; //mixing proportions
 ordered[D] mu[K]; //mixture component means
 cholesky_factor_corr[D] L[K]; //cholesky factor of correlation matrix
}

model {
 real ps[K];

 for(k in 1:K){
  mu[k] ~ normal(0,3);
  L[k] ~ lkj_corr_cholesky(4);
 }

 for (n in 1:N){
  for (k in 1:K){
   //log probability of y[n] under component k, weighted by its mixing proportion
   ps[k] = log(theta[k]) + multi_normal_cholesky_lpdf(y[n] | mu[k], L[k]);
  }
  target += log_sum_exp(ps); //marginalise over the mixture components
 }
}
'

Running the model in R takes only one line too. Here I use 11,000 iterations, 1,000 of which are warmup (for adaptation of the NUTS sampler's parameters). I'll use only one chain for speed:

fit=stan(model_code=mixture_model, data=mixture_data, iter=11000, warmup=1000, chains=1)
SAMPLING FOR MODEL '16de4bc17f41669412586868e09d4c65' NOW (CHAIN 1).

Chain 1, Iteration:     1 / 11000 [  0%]  (Warmup)
Chain 1, Iteration:  1001 / 11000 [  9%]  (Sampling)
Chain 1, Iteration:  2100 / 11000 [ 19%]  (Sampling)
Chain 1, Iteration:  3200 / 11000 [ 29%]  (Sampling)
Chain 1, Iteration:  4300 / 11000 [ 39%]  (Sampling)
Chain 1, Iteration:  5400 / 11000 [ 49%]  (Sampling)
Chain 1, Iteration:  6500 / 11000 [ 59%]  (Sampling)
Chain 1, Iteration:  7600 / 11000 [ 69%]  (Sampling)
Chain 1, Iteration:  8700 / 11000 [ 79%]  (Sampling)
Chain 1, Iteration:  9800 / 11000 [ 89%]  (Sampling)
Chain 1, Iteration: 10900 / 11000 [ 99%]  (Sampling)
Chain 1, Iteration: 11000 / 11000 [100%]  (Sampling)
 Elapsed Time: 13.0271 seconds (Warm-up)
               99.2967 seconds (Sampling)
               112.324 seconds (Total)

From the results we can see that convergence is good:

print(fit)
Inference for Stan model: 16de4bc17f41669412586868e09d4c65.
1 chains, each with iter=11000; warmup=1000; thin=1; 
post-warmup draws per chain=10000, total post-warmup draws=10000.

            mean se_mean   sd    2.5%     25%     50%     75%   97.5% n_eff
theta[1]    0.33    0.00 0.05    0.24    0.30    0.33    0.37    0.43 10000
theta[2]    0.33    0.00 0.05    0.24    0.30    0.33    0.37    0.43 10000
theta[3]    0.33    0.00 0.05    0.24    0.30    0.33    0.36    0.43 10000
mu[1,1]    -0.09    0.01 0.17   -0.41   -0.19   -0.10    0.01    0.28   751
mu[1,2]    -0.01    0.01 0.16   -0.33   -0.11   -0.02    0.08    0.35   952
mu[1,3]     0.13    0.00 0.16   -0.20    0.04    0.13    0.22    0.45  3100
mu[1,4]     0.19    0.00 0.16   -0.15    0.10    0.19    0.29    0.51  2086
mu[2,1]     6.85    0.01 0.12    6.57    6.78    6.87    6.93    7.06   133
mu[2,2]     6.90    0.01 0.12    6.63    6.84    6.92    6.98    7.11   129
mu[2,3]     6.95    0.01 0.11    6.69    6.89    6.96    7.01    7.15   333
mu[2,4]     7.03    0.00 0.11    6.79    6.97    7.03    7.09    7.27   579
mu[3,1]     2.78    0.00 0.13    2.50    2.71    2.80    2.87    2.99  1704
mu[3,2]     2.84    0.00 0.12    2.56    2.77    2.86    2.92    3.05  1005
mu[3,3]     2.91    0.01 0.13    2.62    2.84    2.93    2.99    3.17   179
mu[3,4]     3.12    0.01 0.14    2.85    3.04    3.11    3.21    3.42   451
L[1,1,1]    1.00    0.00 0.00    1.00    1.00    1.00    1.00    1.00 10000
L[1,1,2]    0.00    0.00 0.00    0.00    0.00    0.00    0.00    0.00 10000
L[1,1,3]    0.00    0.00 0.00    0.00    0.00    0.00    0.00    0.00 10000
L[1,1,4]    0.00    0.00 0.00    0.00    0.00    0.00    0.00    0.00 10000
L[1,2,1]    0.73    0.02 0.41   -0.78    0.80    0.85    0.89    0.93   263
L[1,2,2]    0.54    0.00 0.11    0.38    0.46    0.52    0.60    0.83  1141
L[1,2,3]    0.00    0.00 0.00    0.00    0.00    0.00    0.00    0.00 10000
L[1,2,4]    0.00    0.00 0.00    0.00    0.00    0.00    0.00    0.00 10000
L[1,3,1]    0.16    0.08 0.77   -0.89   -0.78    0.70    0.82    0.89    97
L[1,3,2]    0.04    0.02 0.24   -0.44   -0.13    0.07    0.21    0.49   205
L[1,3,3]    0.56    0.00 0.11    0.40    0.49    0.55    0.62    0.81  7240
L[1,3,4]    0.00    0.00 0.00    0.00    0.00    0.00    0.00    0.00 10000
L[1,4,1]    0.15    0.08 0.75   -0.89   -0.76    0.66    0.80    0.88    89
L[1,4,2]    0.20    0.01 0.23   -0.27    0.05    0.23    0.36    0.65   351
L[1,4,3]    0.31    0.00 0.15    0.03    0.23    0.31    0.40    0.63  2280
L[1,4,4]    0.44    0.00 0.08    0.31    0.38    0.43    0.49    0.64  1879
L[2,1,1]    1.00    0.00 0.00    1.00    1.00    1.00    1.00    1.00 10000
L[2,1,2]    0.00    0.00 0.00    0.00    0.00    0.00    0.00    0.00 10000
L[2,1,3]    0.00    0.00 0.00    0.00    0.00    0.00    0.00    0.00 10000
L[2,1,4]    0.00    0.00 0.00    0.00    0.00    0.00    0.00    0.00 10000
L[2,2,1]    0.03    0.11 0.71   -0.87   -0.75    0.36    0.73    0.85    43
L[2,2,2]    0.69    0.00 0.13    0.49    0.60    0.68    0.77    0.98  1978
L[2,2,3]    0.00    0.00 0.00    0.00    0.00    0.00    0.00    0.00 10000
L[2,2,4]    0.00    0.00 0.00    0.00    0.00    0.00    0.00    0.00 10000
L[2,3,1]    0.45    0.05 0.59   -0.78    0.36    0.76    0.83    0.90   130
L[2,3,2]    0.00    0.04 0.39   -0.70   -0.33    0.05    0.33    0.69    89
L[2,3,3]    0.53    0.00 0.11    0.37    0.46    0.52    0.59    0.79   935
L[2,3,4]    0.00    0.00 0.00    0.00    0.00    0.00    0.00    0.00 10000
L[2,4,1]    0.25    0.07 0.68   -0.80   -0.57    0.70    0.84    0.91    89
L[2,4,2]   -0.03    0.07 0.42   -0.74   -0.38   -0.07    0.33    0.72    40
L[2,4,3]   -0.15    0.02 0.22   -0.55   -0.30   -0.17    0.01    0.27   122
L[2,4,4]    0.46    0.00 0.09    0.31    0.40    0.45    0.51    0.69   883
L[3,1,1]    1.00    0.00 0.00    1.00    1.00    1.00    1.00    1.00 10000
L[3,1,2]    0.00    0.00 0.00    0.00    0.00    0.00    0.00    0.00 10000
L[3,1,3]    0.00    0.00 0.00    0.00    0.00    0.00    0.00    0.00 10000
L[3,1,4]    0.00    0.00 0.00    0.00    0.00    0.00    0.00    0.00 10000
L[3,2,1]    0.45    0.04 0.59   -0.83    0.49    0.75    0.82    0.89   178
L[3,2,2]    0.66    0.00 0.12    0.46    0.57    0.64    0.73    0.96  1936
L[3,2,3]    0.00    0.00 0.00    0.00    0.00    0.00    0.00    0.00 10000
L[3,2,4]    0.00    0.00 0.00    0.00    0.00    0.00    0.00    0.00 10000
L[3,3,1]    0.58    0.08 0.55   -0.81    0.71    0.82    0.87    0.92    48
L[3,3,2]    0.01    0.03 0.28   -0.57   -0.19    0.07    0.21    0.48   121
L[3,3,3]    0.53    0.00 0.10    0.37    0.45    0.51    0.58    0.78   580
L[3,3,4]    0.00    0.00 0.00    0.00    0.00    0.00    0.00    0.00 10000
L[3,4,1]   -0.44    0.06 0.64   -0.90   -0.84   -0.77   -0.51    0.87   104
L[3,4,2]   -0.13    0.02 0.29   -0.59   -0.33   -0.20    0.07    0.49   236
L[3,4,3]   -0.10    0.02 0.22   -0.48   -0.25   -0.12    0.05    0.35   101
L[3,4,4]    0.48    0.00 0.09    0.35    0.42    0.47    0.53    0.69  2954
lp__     -443.66    0.15 5.00 -454.49 -446.87 -443.28 -440.12 -435.03  1053
         Rhat
theta[1] 1.00
theta[2] 1.00
theta[3] 1.00
mu[1,1]  1.00
mu[1,2]  1.00
mu[1,3]  1.00
mu[1,4]  1.00
mu[2,1]  1.02
mu[2,2]  1.02
mu[2,3]  1.01
mu[2,4]  1.00
mu[3,1]  1.00
mu[3,2]  1.00
mu[3,3]  1.00
mu[3,4]  1.00
L[1,1,1]  NaN
L[1,1,2]  NaN
L[1,1,3]  NaN
L[1,1,4]  NaN
L[1,2,1] 1.00
L[1,2,2] 1.00
L[1,2,3]  NaN
L[1,2,4]  NaN
L[1,3,1] 1.00
L[1,3,2] 1.00
L[1,3,3] 1.00
L[1,3,4]  NaN
L[1,4,1] 1.00
L[1,4,2] 1.00
L[1,4,3] 1.00
L[1,4,4] 1.00
L[2,1,1]  NaN
L[2,1,2]  NaN
L[2,1,3]  NaN
L[2,1,4]  NaN
L[2,2,1] 1.02
L[2,2,2] 1.00
L[2,2,3]  NaN
L[2,2,4]  NaN
L[2,3,1] 1.00
L[2,3,2] 1.04
L[2,3,3] 1.00
L[2,3,4]  NaN
L[2,4,1] 1.02
L[2,4,2] 1.02
L[2,4,3] 1.01
L[2,4,4] 1.00
L[3,1,1]  NaN
L[3,1,2]  NaN
L[3,1,3]  NaN
L[3,1,4]  NaN
L[3,2,1] 1.01
L[3,2,2] 1.00
L[3,2,3]  NaN
L[3,2,4]  NaN
L[3,3,1] 1.01
L[3,3,2] 1.01
L[3,3,3] 1.00
L[3,3,4]  NaN
L[3,4,1] 1.00
L[3,4,2] 1.00
L[3,4,3] 1.02
L[3,4,4] 1.00
lp__     1.00

Samples were drawn using NUTS(diag_e) at Tue Mar 21 10:26:33 2017.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1)

As you can see, the mixing proportions and means have good Rhat values and effective sample sizes, and the run time is reasonable (some entries of the Cholesky factors mix less well, with low n_eff, a symptom of the identifiability issues discussed above). We also recover the input parameters well:

params=extract(fit)
#density plots of the posteriors of the mixture means
par(mfrow=c(2,2))
plot(density(params$mu[,1,1]), ylab='', xlab='mu[1]', main='')
lines(density(params$mu[,1,2]), col=rgb(0,0,0,0.7))
lines(density(params$mu[,1,3]), col=rgb(0,0,0,0.4))
lines(density(params$mu[,1,4]), col=rgb(0,0,0,0.1))
abline(v=c(0), lty='dotted', col='red',lwd=2)

plot(density(params$mu[,2,1]), ylab='', xlab='mu[2]', main='')
lines(density(params$mu[,2,2]), col=rgb(0,0,0,0.7))
lines(density(params$mu[,2,3]), col=rgb(0,0,0,0.4))
lines(density(params$mu[,2,4]), col=rgb(0,0,0,0.1))
abline(v=c(7), lty='dotted', col='red',lwd=2)

plot(density(params$mu[,3,1]), ylab='', xlab='mu[3]', main='')
lines(density(params$mu[,3,2]), col=rgb(0,0,0,0.7))
lines(density(params$mu[,3,3]), col=rgb(0,0,0,0.4))
lines(density(params$mu[,3,4]), col=rgb(0,0,0,0.1))
abline(v=c(3), lty='dotted', col='red',lwd=2)
Figure: marginalised 1D posteriors of the three Gaussian mixture means. The red dotted line marks the true value.