
Scaling Gaussian Mixture Models to Massive Datasets

Nestor Sanchez

Posted in unsupervised learning

3-cluster synthetic data from a Gaussian Mixtures model

Using algorithms from manifold optimization, Gaussian Mixture Models (GMMs) can be fitted efficiently through gradient descent, which makes them more scalable and suitable for streaming data. This algorithm was implemented in a Scala package, which can work locally and on Apache Spark.

GMMs are among the most popular clustering algorithms due to their simplicity and the fact that they endow the data with an underlying probabilistic model that is easy to interpret. However, most implementations use the EM algorithm to fit the model, and since EM works in batches (i.e., it processes the whole dataset over and over), it is costly to scale the model to datasets with millions of observations.

When I was looking for interesting projects for my MSc dissertation I stumbled upon the paper of Hosseini & Sra (2017), in which they use methods from manifold optimization to devise an algorithm that makes GMMs suitable for fitting by stochastic gradient descent (SGD). For me this was a brilliant result, since it gives GMMs the scalability that we see in deep learning models and makes them naturally suitable for streaming data, which EM is not designed to handle. At that moment there was no implementation available online, so I decided to make one in Scala that could run on Apache Spark, and at the same time add a couple of tweaks to make it faster and more robust.

The package's quickstart guide can be found on the repository, and some time ago I wrote a blog post comparing it to Spark's own implementation. Finally, my dissertation describes everything I did in painful detail, so in this post I'll limit the discussion to how the underlying mathematics works.

Manifold optimization

I'll try to give a very brief explanation of how manifold optimization works. The main problem with gradient descent in the context of GMMs is that fitting a GMM is really a constrained optimization problem: the covariance matrices must stay symmetric positive definite (SPD). Plain gradient descent ignores this constraint, so we can end up with an invalid covariance matrix. Luckily for us, the set of SPD matrices is a Riemannian manifold, which means we can use techniques from manifold optimization to find our model's optimal covariance matrices.

Simply put, in manifold optimization we perform a usual gradient descent step, but then map the resulting point back onto the manifold with a transformation called a retraction, denoted by $R_\Sigma(\cdot)$.

Graphical representation of the concept of retraction and vector transport within the framework of Riemannian optimization techniques.

The retraction operation depends on the Riemannian metric of our manifold, which is basically the dot product operator that we define between two elements $A, B$ in the tangent space (the blue plane in the figure above) of a manifold element $\Sigma$ (i.e., a covariance matrix in our case); the metric of choice in Hosseini & Sra (2017) is

$$g_{\Sigma}(A, B)=\operatorname{tr}\left(\Sigma^{-1} A \Sigma^{-1} B\right)$$

The retraction is in general not unique, but a computationally cheap one is

$$R_\Sigma(A) = \Sigma+A$$

which is called the Euclidean retraction in that paper. Finally, for whatever loss function $f_E(\Sigma)$ we have on regular Euclidean space (say, the model's log-likelihood), when computing the gradient we have to take into account that we are moving on a manifold and not on a regular Euclidean space, so the gradient of $f(\Sigma)$ on the manifold is given by

$$\nabla f(\Sigma)=\frac{1}{2} \Sigma\left(\nabla_{E} f(\Sigma)+\left[\nabla_{E} f(\Sigma)\right]^{T}\right) \Sigma$$

where $\nabla_{E}$ denotes the usual (i.e., Euclidean) gradient. What we have so far is enough to implement a gradient descent algorithm for the covariance matrices of a GMM, taking the usual log-likelihood as our loss function. It turns out, however, that the model's log-likelihood is not ideal for this approach because it is not, in some sense, concave on the manifold. Hosseini & Sra (2017) proposed a reformulation which preserves the solutions of the original problem, makes the problem concave on the manifold, and wraps the mean and covariance parameters in a single, extended SPD matrix.
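To make the update concrete, here is a minimal Python sketch of Riemannian gradient ascent with the Euclidean retraction on 2×2 SPD matrices. The objective is an assumption chosen only for illustration: the zero-mean Gaussian log-likelihood $f(\Sigma) = -\frac{1}{2}\ln\det\Sigma - \frac{1}{2}\operatorname{tr}(\Sigma^{-1}C)$ for a fixed sample covariance $C$, whose maximizer is $\Sigma = C$. All helper names are hypothetical and are not taken from the Scala package.

```python
# Riemannian gradient ascent on 2x2 SPD matrices (illustrative sketch).
# Assumed objective: f(Sigma) = -0.5*ln det(Sigma) - 0.5*tr(Sigma^{-1} C).

def matmul(a, b):
    return [[sum(a[i][m] * b[m][j] for m in range(2)) for j in range(2)]
            for i in range(2)]

def transpose(a):
    return [[a[j][i] for j in range(2)] for i in range(2)]

def scale(a, s):
    return [[s * a[i][j] for j in range(2)] for i in range(2)]

def add(a, b):
    return [[a[i][j] + b[i][j] for j in range(2)] for i in range(2)]

def inverse(a):
    det = a[0][0] * a[1][1] - a[0][1] * a[1][0]
    return [[a[1][1] / det, -a[0][1] / det],
            [-a[1][0] / det, a[0][0] / det]]

def euclidean_grad(sigma, c):
    # Euclidean gradient: 0.5 * (Sigma^{-1} C Sigma^{-1} - Sigma^{-1})
    inv = inverse(sigma)
    return scale(add(matmul(matmul(inv, c), inv), scale(inv, -1.0)), 0.5)

def riemannian_grad(sigma, grad_e):
    # Gradient on the manifold: 0.5 * Sigma (grad_E + grad_E^T) Sigma
    sym = scale(add(grad_e, transpose(grad_e)), 0.5)
    return matmul(matmul(sigma, sym), sigma)

def is_spd(a, tol=1e-9):
    det = a[0][0] * a[1][1] - a[0][1] * a[1][0]
    return abs(a[0][1] - a[1][0]) < tol and a[0][0] > 0 and det > 0

def fit(c, sigma, alpha=0.5, iterations=200):
    for _ in range(iterations):
        grad = riemannian_grad(sigma, euclidean_grad(sigma, c))
        sigma = add(sigma, scale(grad, alpha))  # Euclidean retraction: Sigma + A
    return sigma

target = [[2.0, 0.6], [0.6, 1.0]]                   # assumed sample covariance C
estimate = fit(target, [[1.0, 0.0], [0.0, 1.0]])    # start from the identity
```

For this particular objective the manifold gradient simplifies to $(C - \Sigma)/2$, so every step with a step size between 0 and 2 is a convex combination of two SPD matrices and stays SPD. In general the Euclidean retraction gives no such guarantee for large steps, so a real implementation would need to keep the step size in check.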

Reformulated problem (Hosseini & Sra (2017))

To show what the reformulation looks like, assume we have $n$ samples $x_1,...,x_n \in \mathbb{R}^d$ from a GMM with $k$ components, mean parameters $\mu_1,..., \mu_k$, covariance parameters $\Sigma_1,...,\Sigma_k$, and weight parameters $\pi_1,...,\pi_k$. The regular log-likelihood maximization problem looks like

$$\max_{\{\pi_j\}_{j=1}^k, \{\mu_j\}_{j=1}^k, \{\Sigma_j\}_{j=1}^k} \sum_{i=1}^n \log \left( \sum_{j=1}^k \pi_j \mathcal{N}(x_i;\mu_j,\Sigma_j)\right)$$

where $\mathcal{N}(x;\mu,\Sigma)$ is the density of a multivariate normal distribution with parameters $\mu, \Sigma$ evaluated at $x$. For the reformulation, let us define the following expanded observations and wrapper matrices

$$y_{i}=\left(\begin{array}{ll} {x_{i}^{T}} & {1} \end{array}\right)^{T}, \quad i=1,...,n$$

$$S_{j}=\left[\begin{array}{cc}{\Sigma_{j}+ \mu_{j} \mu_{j}^{T}} & { \mu_{j}} \\{ \mu_{j}^{T}} & {1}\end{array}\right] \in \mathbb{R}^{(d+1) \times(d+1)}, \quad j=1,...,k$$

We'll also use a slightly tweaked normal density, defined as

$$q_{\mathcal{N}}\left(y_{i} ; S_{j}\right)=\sqrt{2 \pi} \exp \left(\frac{1}{2}\right) \mathcal{N}\left(y_{i} ; 0, S_{j}\right)$$
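A quick numerical check helps to see why this tweak is harmless. Since $\det(S_j) = \det(\Sigma_j)$ and $y_i^T S_j^{-1} y_i = (x_i-\mu_j)^T \Sigma_j^{-1} (x_i-\mu_j) + 1$, the tweaked density of the expanded observation equals the ordinary normal density of the original one. The sketch below verifies this for the one-dimensional case ($d=1$, so $S_j$ is 2×2); the function names and the test values are made up for the example.

```python
import math

def wrapper_matrix(mu, var):
    # S = [[var + mu^2, mu], [mu, 1]] for a one-dimensional component
    return [[var + mu * mu, mu], [mu, 1.0]]

def q_normal(y, s):
    # sqrt(2*pi) * exp(1/2) * N(y; 0, S) for a 2x2 symmetric matrix S
    det = s[0][0] * s[1][1] - s[0][1] * s[1][0]
    # Quadratic form y^T S^{-1} y, written out with the 2x2 inverse
    quad = (s[1][1] * y[0] ** 2 - 2.0 * s[0][1] * y[0] * y[1]
            + s[0][0] * y[1] ** 2) / det
    zero_mean_density = math.exp(-0.5 * quad) / (2.0 * math.pi * math.sqrt(det))
    return math.sqrt(2.0 * math.pi) * math.exp(0.5) * zero_mean_density

def normal_pdf(x, mu, var):
    return math.exp(-0.5 * (x - mu) ** 2 / var) / math.sqrt(2.0 * math.pi * var)

x, mu, var = 0.7, -1.2, 2.5      # arbitrary illustrative values
y = [x, 1.0]                     # expanded observation
s = wrapper_matrix(mu, var)      # wrapper matrix
difference = abs(q_normal(y, s) - normal_pdf(x, mu, var))
```

The two densities agree up to floating-point error, which is what lets the reformulated problem keep the solutions of the original one.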

In order to also use gradient descent to fit the weights $\pi_1,...,\pi_k$, we remove their constraint (i.e., $\sum_{j=1}^k \pi_j = 1$) by writing them as softmax functions of unconstrained real-valued auxiliary variables $\omega_1,...,\omega_k \in \mathbb{R}$:

$$\pi_{j}=\frac{\exp \left(\omega_{j}\right)}{\sum_{m=1}^{k} \exp \left(\omega_{m}\right)}, \quad j=1, \ldots, k$$
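As a small illustration of this reparameterization, the sketch below maps arbitrary real values to valid mixture weights. Subtracting the maximum before exponentiating is a standard numerical safeguard that I am adding here; it does not change the result. The function name is hypothetical.

```python
import math

def softmax(omega):
    shift = max(omega)  # shifting by the max avoids overflow in exp
    exps = [math.exp(w - shift) for w in omega]
    total = sum(exps)
    return [e / total for e in exps]

weights = softmax([0.0, 1.0, -2.0])
# Adding the same constant to every omega leaves the weights unchanged
shifted = softmax([100.0, 101.0, 98.0])
```

Any real-valued vector gives positive weights that sum to one, so a plain gradient step on the auxiliary variables can never produce invalid weights.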

Putting these pieces together, the reformulated objective is given by

$$\max_{\left\{S_{j}\right\},\left\{\omega_{j}\right\}} \hat{\mathcal{L}}\left(Y ;\left\{S_{j}\right\},\left\{\omega_{j}\right\}\right)=\max_{\left\{S_{j}\right\},\left\{\omega_{j}\right\}} \frac{1}{n} \sum_{i=1}^{n} \ln \left(\sum_{j=1}^{k} \frac{\exp \left(\omega_{j}\right)}{\sum_{m=1}^{k} \exp \left(\omega_{m}\right)} q_{\mathcal{N}}\left(y_{i} ; S_{j}\right)\right)$$

and it can be proven that it preserves the solutions of the original problem and is g-concave. Since the wrapper matrices $S_j$ are SPD as well, we can use the results from the last section as they stand; it also turns out that the parameter gradients have remarkably simple expressions, and for the $j$-th component's parameters we have:

$$\nabla_{S_{j}} \hat{\mathcal{L}}\left(Y ;\left\{S_{j}\right\},\left\{\omega_{j}\right\}\right) = \frac{1}{2 n} \sum_{i=1}^{n} p_{i}\left(y_{i} y_{i}^{T}-S_{j}\right)$$

$$\nabla_{\omega_{j}} \hat{\mathcal{L}}\left(Y ;\left\{S_{j}\right\},\left\{\omega_{j}\right\}\right) = \left(\frac{1}{n} \sum_{i=1}^{n} p_{i}\right)-\pi_{j}$$

where $p_i$ is the posterior probability (under the current parameter estimates) of observation $i$ coming from component $j$:

$$p_{i}=\frac{\pi_{j} q_{\mathcal{N}}\left(y_{i} ; S_{j}\right)}{\sum_{m=1}^{k} \pi_{m} q_{\mathcal{N}}\left(y_{i} ; S_{m}\right)}$$
After all this, the only thing left to do is getting all of it into a computer program.
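As a rough idea of what that looks like, here is a minimal Python sketch of the posterior probabilities and both gradients for one-dimensional data ($d=1$, so each $S_j$ is 2×2). It is not the package's code: the function names, the toy data and the starting parameters are all assumptions made for the example.

```python
import math

def q_normal(y, s):
    # Tweaked density q_N(y; S) for d = 1, with y = (x, 1) and S a 2x2 matrix
    det = s[0][0] * s[1][1] - s[0][1] * s[1][0]
    quad = (s[1][1] * y[0] ** 2 - 2.0 * s[0][1] * y[0] * y[1]
            + s[0][0] * y[1] ** 2) / det
    return math.exp(0.5 - 0.5 * quad) / math.sqrt(2.0 * math.pi * det)

def softmax(omega):
    shift = max(omega)
    exps = [math.exp(w - shift) for w in omega]
    total = sum(exps)
    return [e / total for e in exps]

def posteriors(ys, s_list, omega):
    # post[i][j]: posterior probability that observation i came from component j
    pis = softmax(omega)
    rows = []
    for y in ys:
        joint = [pi * q_normal(y, s) for pi, s in zip(pis, s_list)]
        total = sum(joint)
        rows.append([value / total for value in joint])
    return rows

def objective(ys, s_list, omega):
    # Reformulated mean log-likelihood
    pis = softmax(omega)
    return sum(math.log(sum(pi * q_normal(y, s) for pi, s in zip(pis, s_list)))
               for y in ys) / len(ys)

def gradients(ys, s_list, omega):
    n = len(ys)
    pis = softmax(omega)
    post = posteriors(ys, s_list, omega)
    grad_s, grad_omega = [], []
    for j, s in enumerate(s_list):
        # (1 / 2n) * sum_i p_i * (y_i y_i^T - S_j)
        g = [[sum(post[i][j] * (ys[i][a] * ys[i][b] - s[a][b])
                  for i in range(n)) / (2.0 * n)
              for b in range(2)] for a in range(2)]
        grad_s.append(g)
        # (1 / n) * sum_i p_i - pi_j
        grad_omega.append(sum(post[i][j] for i in range(n)) / n - pis[j])
    return grad_s, grad_omega

xs = [-2.1, -1.9, 0.1, 2.0, 2.2]                 # toy observations
ys = [[x, 1.0] for x in xs]                      # expanded observations
s_list = [[[1.0 + 1.0, -1.0], [-1.0, 1.0]],      # mu = -1.0, variance = 1.0
          [[0.5 + 2.25, 1.5], [1.5, 1.0]]]       # mu = 1.5, variance = 0.5
omega = [0.2, -0.3]
grad_s, grad_omega = gradients(ys, s_list, omega)
```

Two details are worth noticing. The bottom-right entry of each $S_j$ gradient is exactly zero, so the constant 1 in the wrapper matrix is never moved, and the weight gradients sum to zero across components, as expected from the softmax parameterization.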

Upgrading to accelerated gradient descent

Being able to use SGD on GMMs also opens the door to accelerated stochastic gradient descent (ASGD), a family of techniques that has given good results on other machine learning models, particularly in deep learning. The first of my tweaks was to implement two flavors of ASGD: descent with momentum and Nesterov's correction. To briefly explain this, assume that at time $t$, for a loss function $f(\cdot;\theta)$ and step size $\alpha$, we have a current parameter estimate $\theta_t$. To get the next estimate in standard SGD (or rather ascent in this case) we would do

$$\theta_{t+1} = \theta_t + \alpha \nabla_{\theta} f(\cdot;\theta_t)$$

In ascent with momentum we would instead do

$$\begin{aligned} v_{t+1} & = \beta v_t + \nabla_{\theta} f(\cdot;\theta_t), \quad \beta \in (0,1), \; v_0 = 0\\ \theta_{t+1} & = \theta_t + \alpha v_{t+1} \end{aligned}$$
and in the case of Nesterov's correction, written in one standard form:

$$\begin{aligned} v_{t+1} & = \theta_{t}+\alpha \nabla_{\theta} f\left(\cdot;\theta_t\right), \quad v_0=\theta_0\\ \theta_{t+1} & = v_{t+1}+\gamma\left(v_{t+1}-v_{t}\right), \quad \gamma \in (0,1) \end{aligned}$$
We can see both ASGD flavors as adding some kind of inertia, as in a ball going down a slope, but in Nesterov's case this inertia also acts as a correction in direction, and in fact the method was proven optimal among first-order methods for smooth convex functions. It turns out that in the case of GMMs both improve performance considerably. All the experiments were run on synthetic data with varying degrees of difficulty (the closer the model's components are, the harder it is for an algorithm to find the true parameters) using mini-batches in each iteration, and the results were similar in all of them.
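To get a feel for the effect of inertia, the sketch below runs the three update rules on a toy problem. Everything in it is an assumption made for illustration: the objective is a poorly conditioned concave quadratic rather than a GMM likelihood, the gradients are exact rather than stochastic, and Nesterov's correction is written in a common look-ahead form ($v_{t+1} = \theta_t + \alpha \nabla f(\theta_t)$, $\theta_{t+1} = v_{t+1} + \gamma (v_{t+1} - v_t)$, $v_0 = \theta_0$), which may differ from the exact variant used in the package.

```python
def gradient(theta, curvature):
    # Gradient of the toy objective f(theta) = -0.5 * sum(c_i * theta_i^2)
    return [-c * t for c, t in zip(curvature, theta)]

def plain_ascent(theta, curvature, alpha, steps):
    for _ in range(steps):
        g = gradient(theta, curvature)
        theta = [t + alpha * gi for t, gi in zip(theta, g)]
    return theta

def momentum_ascent(theta, curvature, alpha, beta, steps):
    v = [0.0] * len(theta)  # v_0 = 0
    for _ in range(steps):
        g = gradient(theta, curvature)
        v = [beta * vi + gi for vi, gi in zip(v, g)]
        theta = [t + alpha * vi for t, vi in zip(theta, v)]
    return theta

def nesterov_ascent(theta, curvature, alpha, gamma, steps):
    v = list(theta)  # v_0 = theta_0
    for _ in range(steps):
        g = gradient(theta, curvature)
        v_next = [t + alpha * gi for t, gi in zip(theta, g)]
        theta = [vn + gamma * (vn - vi) for vn, vi in zip(v_next, v)]
        v = v_next
    return theta

def distance_to_optimum(theta):
    # The maximizer of the toy objective is the origin
    return sum(t * t for t in theta) ** 0.5

curvature = [1.0, 0.02]  # the second coordinate is nearly flat, so it moves slowly
start = [1.0, 1.0]
err_plain = distance_to_optimum(plain_ascent(start, curvature, 0.5, 300))
err_momentum = distance_to_optimum(momentum_ascent(start, curvature, 0.5, 0.9, 300))
err_nesterov = distance_to_optimum(nesterov_ascent(start, curvature, 0.5, 0.9, 300))
```

With the same step size and the same number of iterations, both accelerated variants end up several orders of magnitude closer to the optimum than plain ascent, because the accumulated velocity speeds up progress along the nearly flat direction.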

In the figure below, the left image corresponds to Nesterov's correction and the right one to descent with momentum. The Y axis is the mean log-likelihood on the validation set after each iteration, the dotted line at the top is the true mean log-likelihood and the black trend represents standard SGD. As we can see, ASGD gets to the solution between 2 and 4 times faster than standard SGD, even when there is an excess of inertia!


In addition to this, numerical results also suggest that the right amount of inertia from ASGD can help the algorithm get out of saddle points where standard SGD and EM can get stuck:


According to the figure above, having too little inertia (standard SGD, in black, is a descent algorithm without inertia) or too much (the pink trends in both images) gets the algorithm stuck in saddle points, which does not happen for intermediate amounts of inertia (which in turn is controlled by the $\gamma$ parameter in Nesterov's descent and by $\beta$ in descent with momentum). The bottom line is that ASGD makes for a potentially faster and more effective fitting algorithm for GMMs.

Comparison to EM

When comparing the accelerated models to Spark's EM implementation I tested different dimensionalities, numbers of components in the true underlying model, and data separations. This last one is a measure of how close clusters are to one another: the closer they are, the harder it is to distinguish them, and a value of 1 or less should start to be hard for the algorithms. ASGD didn't disappoint, and ended up speeding up the process up to a hundred times depending on the data configuration; though this varied widely, ASGD was at least 5 times as fast in all of the experiments. Even local ASGD performed much better than distributed EM on 32 cores! There is a price to pay for this speedup though, which is some additional estimation error, although for most applications this is probably reasonable. To put some numbers on all this I share a fragment of the experimental results below:

| Separation | Dimension | # Components | ASGD time (s) | ASGD progress | EM time (s) | EM progress |
|---|---|---|---|---|---|---|
| 0.2 | 30 | 4 | 1.20 | 98.37% | 211.70 | 99.75% |
| 0.2 | 30 | 8 | 11.56 | 94.73% | 547.70 | 99.68% |
| 0.2 | 30 | 12 | 2.99 | 90.43% | 1116.44 | 99.86% |

What I call progress in the table is the progress of the models towards the true data log-likelihood, which is the maximum achievable value, expressed as a percentage: 100% would mean reaching the exact true data log-likelihood, while 0% would be the starting point's log-likelihood. I do this because raw log-likelihood values don't mean much. Both algorithms had the very same starting point in all experiments.
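One way to compute such a progress figure is a linear rescaling between the two reference values. The linear form is my assumption about how the definition above translates into a formula, and the function name and numbers below are made up for the example.

```python
def progress_percent(log_lik, start_log_lik, true_log_lik):
    # 0% at the starting point's log-likelihood, 100% at the true data log-likelihood
    return 100.0 * (log_lik - start_log_lik) / (true_log_lik - start_log_lik)

# Hypothetical values: start at -5.0, true value -3.0, fitted model reaches -3.1
example = progress_percent(-3.1, -5.0, -3.0)
```

With these made-up numbers the model has covered about 95% of the distance between the starting point and the true log-likelihood.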

There were many more parameter configurations, but the qualitative nature of the results was consistent across them. Some of the additional estimation error can be made much smaller simply by doing more iterations or using larger mini-batches, since the time surplus allows us to do that comfortably. As a final note, some EM runs got stuck very early on and ended up performing badly, while ASGD never stopped at less than 80% of progress, which might again be partly due to acceleration.

Avoiding mean collapse by using logarithmic barriers

Mean collapse is a phenomenon that can happen in GMMs when one of the components ends up fitting a single point, with its mean exactly on it, and the component's covariance shrinks to the zero matrix. Hosseini & Sra (2017) suggest fixing this with a conjugate prior distribution over the parameters, but numerical experiments suggested that for small batch sizes of a few tens of observations, the influence of the prior is too strong and prevents the model from making rapid progress towards the true parameters, settling instead on a sort of middle point between them and the prior's expectations.

The second of my tweaks was to instead use a good old logarithmic barrier regularization on the determinant of $S_j$, which favors inflating the covariance matrix a bit but does not affect the mean or weight parameters. Another small advantage is that we no longer need to do $k$ matrix products at each fitting iteration to compute the prior's regularization value, and since we need to calculate $\det(S_j)$ anyway at each iteration, the logarithmic barrier term is effectively free. The final form of this regularization is

$$\underset{\left\{S_{j}\right\},\left\{\omega_{j}\right\}}{\operatorname{argmax}} \hat{\mathcal{L}}\left(Y ;\left\{S_{j}\right\},\left\{\omega_{j}\right\}\right)+\sum_{j=1}^{k}\left(\ln \operatorname{det}\left(S_{j}\right)-1\right)$$

which can be shown to preserve the solution of adding a logarithmic barrier to the original GMM formulation. The logarithmic barrier's gradient is also quite simple and given by

$$\nabla_{S_j}[\ln \det (S_j)] = \frac{1}{2 n}\left(S_{j}-\hat{\mu}_{j} \hat{\mu}_{j}^{T}\right), \quad \hat{\mu}_{j}=\left(\begin{array}{ll} \mu_j^{T} & {1} \end{array}\right)^{T}$$
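Two facts make this barrier behave the way described, and both are easy to check numerically. First, $\det(S_j) = \det(\Sigma_j)$, so penalizing the determinant of the wrapper matrix is the same as penalizing that of the covariance matrix. Second, $S_j - \hat{\mu}_j \hat{\mu}_j^T$ is just $\Sigma_j$ padded with a row and a column of zeros, so the gradient above only touches the covariance block. The sketch below checks both for $d = 2$; the function names and parameter values are made up for the example.

```python
def wrapper_matrix(mu, sigma):
    # S = [[Sigma + mu mu^T, mu], [mu^T, 1]] for d = 2, i.e. a 3x3 matrix
    top = [[sigma[a][b] + mu[a] * mu[b] for b in range(2)] + [mu[a]]
           for a in range(2)]
    return top + [[mu[0], mu[1], 1.0]]

def det2(m):
    return m[0][0] * m[1][1] - m[0][1] * m[1][0]

def det3(m):
    # Cofactor expansion along the first row
    return (m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1])
            - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0])
            + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0]))

mu = [1.5, -0.5]                      # illustrative mean
sigma = [[2.0, 0.3], [0.3, 0.5]]      # illustrative covariance
s = wrapper_matrix(mu, sigma)
mu_hat = mu + [1.0]                   # expanded mean (mu, 1)

# S - mu_hat mu_hat^T: Sigma in the top-left block, zeros elsewhere
residual = [[s[a][b] - mu_hat[a] * mu_hat[b] for b in range(3)]
            for a in range(3)]
```

Because the last row and column of the residual are zero, a step along this direction leaves the mean entries and the constant 1 in $S_j$ untouched, which is consistent with the barrier inflating only the covariance.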

The figure below shows the algorithm's trajectories for both types of regularization and a batch size of 100 in a problem with three components, although only one of them is shown; the black ellipse represents the true distribution's contour line. The left image shows the trajectories after 10 iterations and the right image shows them after 50 iterations. It is clear that the logarithmic barrier allows for faster progress, and that after a certain amount of progress the conjugate prior regularization gets the algorithm stuck (the pull from the prior mean parameter $(0,0)$ is also clearly visible in the pink trajectory).


To be fair, this last phenomenon would disappear as we increase the batch size, but that goes against the philosophy of SGD and ASGD, where taking noisy mini-batch steps is the price that has to be paid to scale to large datasets.

Going further

After implementing everything in Spark, I realized that even though SGD can in fact make the processing much faster than a batch algorithm such as EM (local SGD beat distributed EM sometimes!), in a distributed environment the latency and communication costs make that advantage disappear, and Spark's own GMM-EM implementation ends up being roughly as fast as SGD/ASGD. This is because of the synchronization costs more than anything else, and as far as I know, Spark does not allow asynchronous updates. This means that in order to exploit the big gains in performance that we got from ASGD, it would have to be implemented in a framework like PyTorch, TensorFlow or any other distributed asynchronous framework. It would definitely be interesting to see how GMMs perform when fitted under asynchronous ASGD, and if this were coupled with a way to dynamically update the number of components depending on the newest data, it would make GMMs a powerful model for handling streaming data.