Scaling Gaussian Mixture Models to Massive Datasets

Using algorithms from manifold optimization, it is possible to fit Gaussian Mixture Models (GMMs) efficiently through gradient descent, which makes them more scalable and suitable for streaming data. I implemented this algorithm in a Scala package that can run locally and on Apache Spark.
GMMs are among the most popular clustering algorithms due to their simplicity and the fact that they endow the data with an underlying probabilistic model that is easy to interpret. However, most implementations out there fit the model with the EM algorithm, and since EM works in full batches (it processes the whole dataset over and over), scaling the model to datasets with millions of observations becomes costly.
When I was looking for interesting projects for my MSc dissertation I stumbled upon the paper by Hosseini & Sra (2017), in which, using methods from manifold optimization, they devise an algorithm that makes GMMs suitable for fitting by stochastic gradient descent (SGD). To me this was a brilliant result: it gives GMMs the scalability that we see in deep learning models and makes them naturally suited to streaming data, which EM is not designed to handle. At that moment there was no implementation available online, so I decided to write one in Scala that could run on Apache Spark, and to add a couple of tweaks to make it faster and more robust.
The package's quickstart guide can be found in the repository, and some time ago I wrote a blog post comparing it to Spark's own implementation; finally, my dissertation describes everything I did in painful detail, so in this post I'll limit the discussion to how the underlying mathematics work.
Manifold optimization
I'll try to give a very brief explanation of how manifold optimization works. The main problem with gradient descent in the context of GMMs is that fitting a GMM is really a constrained optimization problem, because we need the covariance matrices to stay symmetric positive definite (SPD); plain gradient descent does not care about this detail, so we can end up with an invalid covariance matrix. Luckily for us, the set of SPD matrices is a Riemannian manifold, and this means we can use techniques from manifold optimization to find our model's optimal covariance matrices.
Simply put, in manifold optimization we perform a usual gradient descent step, but then map the resulting point back onto the manifold with a mathematical transformation that we call a retraction.
The retraction operation depends on the Riemannian metric of our manifold, which is basically the dot product that we define between two elements of the tangent space (the blue plane in the figure above) at a point of the manifold (i.e., a covariance matrix in our case); the metric of choice in Hosseini & Sra (2017) is the affine-invariant metric $\langle \xi, \eta \rangle_{\Sigma} = \operatorname{tr}\!\left(\Sigma^{-1}\xi\,\Sigma^{-1}\eta\right)$ for tangent directions $\xi, \eta$ at an SPD matrix $\Sigma$.
The retraction operation is in general not unique, but there is a computationally cheap choice, called the Euclidean retraction in the mentioned paper. Finally, for whatever loss function we have on regular Euclidean space (say, the model's log-likelihood), when computing the gradient we have to take into account that we are moving on a manifold and not on regular Euclidean space, so the usual (Euclidean) gradient has to be transformed into a Riemannian gradient that is compatible with the metric above.
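To make this concrete, these are the expressions I have in mind: a cheap retraction on the SPD manifold that keeps the iterate positive definite (my reading of the paper's Euclidean retraction) and the Riemannian gradient induced by the metric above,

$$
R_{\Sigma}(\xi) \;=\; \Sigma + \xi + \tfrac{1}{2}\,\xi\,\Sigma^{-1}\xi,
\qquad
\operatorname{grad} f(\Sigma) \;=\; \Sigma\,\nabla f(\Sigma)\,\Sigma,
$$

where $\nabla f$ denotes the usual Euclidean gradient (symmetrized if necessary).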
What we have so far is enough to implement a gradient descent algorithm for the covariance matrices of a GMM, taking the usual log-likelihood as our loss function. It turns out, though, that the model's log-likelihood is not ideal for this approach because it is not, in a suitable sense, concave on the manifold; Hosseini & Sra (2017) proposed a reformulation that preserves the solutions of the original problem, makes it concave on the manifold, and wraps the mean and covariance parameters into a single, extended SPD matrix.
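Before moving on to the reformulation, here is a minimal sketch in Scala (using Breeze for the linear algebra; the names are mine, not the package's API) of a single stochastic ascent step for one SPD matrix, assuming the metric and retraction above:

```scala
import breeze.linalg.{DenseMatrix, inv}

object RiemannianStep {
  // Symmetrize a matrix (numerical safety; the Euclidean gradient should already be symmetric).
  def sym(a: DenseMatrix[Double]): DenseMatrix[Double] = (a + a.t) * 0.5

  /** One stochastic gradient ascent step on the SPD manifold.
    *
    * @param s        current SPD matrix (e.g. a covariance matrix)
    * @param eucGrad  Euclidean gradient of the loss at `s`
    * @param stepSize learning rate
    */
  def step(s: DenseMatrix[Double],
           eucGrad: DenseMatrix[Double],
           stepSize: Double): DenseMatrix[Double] = {
    // Riemannian gradient under the affine-invariant metric: S * grad * S
    val rGrad = s * sym(eucGrad) * s
    // Scaled tangent direction
    val xi = rGrad * stepSize
    // Second-order retraction that keeps the result SPD: S + xi + 0.5 * xi * S^{-1} * xi
    s + xi + (xi * inv(s) * xi) * 0.5
  }
}
```

Roughly speaking, this is the kind of step that gets applied to each component on every mini-batch.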
Reformulated problem (Hosseini & Sra, 2017)
To show what the reformulation looks like, assume we have samples $x_1, \dots, x_n$ from a GMM with $K$ components, mean parameters $\mu_j$, covariance parameters $\Sigma_j$, and weight parameters $\pi_j$, for $j = 1, \dots, K$. The regular log-likelihood maximization problem looks like

$$
\max_{\{\mu_j\},\, \{\Sigma_j \succ 0\},\, \{\pi_j\}} \;\sum_{i=1}^{n}\log\left(\sum_{j=1}^{K}\pi_j\,\mathcal{N}(x_i;\,\mu_j,\Sigma_j)\right),
\qquad \pi_j \ge 0,\;\; \sum_{j=1}^{K}\pi_j = 1,
$$
where $\mathcal{N}(x_i;\,\mu_j,\Sigma_j)$ is the density of a multivariate normal distribution with parameters $(\mu_j, \Sigma_j)$ evaluated at $x_i$. For the reformulation, let us define the following expanded observations and wrapper matrices:

$$
y_i = \begin{pmatrix} x_i \\ 1 \end{pmatrix},
\qquad
S_j = \begin{pmatrix} \Sigma_j + \mu_j\mu_j^\top & \mu_j \\ \mu_j^\top & 1 \end{pmatrix}.
$$
We'll also use a slightly tweaked normal density: a zero-mean multivariate normal evaluated on the expanded observations, rescaled by a constant chosen so that it reproduces the original component density whenever $S_j$ has the wrapper structure above.
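Concretely, with the definitions above the constant works out as follows (the exact way of writing it depends on conventions, but this version is consistent with the wrapper matrices as I defined them). Taking

$$
q(y_i;\,S_j) \;:=\; \sqrt{2\pi}\;e^{1/2}\;\mathcal{N}(y_i;\,0,\,S_j),
$$

and using that $\det S_j = \det \Sigma_j$ and $y_i^\top S_j^{-1} y_i = (x_i-\mu_j)^\top\Sigma_j^{-1}(x_i-\mu_j) + 1$ for the wrapper structure, one gets $q(y_i;\,S_j) = \mathcal{N}(x_i;\,\mu_j,\,\Sigma_j)$, so nothing is lost by working in the augmented space.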
Finally, in order to also fit the weights $\pi_j$ by gradient descent, we remove their constraint (i.e., $\pi_j \ge 0$ and $\sum_j \pi_j = 1$) by writing them as softmax functions of unconstrained real-valued auxiliary variables $\eta_j$, that is, $\pi_j = \exp(\eta_j)\big/\sum_{k=1}^{K}\exp(\eta_k)$.
Putting everything together, the reformulated loss function is given by

$$
\hat{\mathcal{L}}\big(\{S_j\},\{\eta_j\}\big) \;=\; \sum_{i=1}^{n}\log\left(\sum_{j=1}^{K}\frac{\exp(\eta_j)}{\sum_{k=1}^{K}\exp(\eta_k)}\; q(y_i;\,S_j)\right),
$$
and it can be proven that it preserves the solutions of the original problem and is geodesically concave (g-concave). Since the wrapper matrices $S_j$ are SPD as well, we can use the results from the last section as they stand; it also turns out that the parameter gradients have remarkably simple expressions, and for the $j$-th component's parameters we have

$$
\operatorname{grad}_{S_j}\hat{\mathcal{L}} \;=\; \frac{1}{2}\sum_{i=1}^{n} p_{ij}\left(y_i y_i^\top - S_j\right),
\qquad
\frac{\partial \hat{\mathcal{L}}}{\partial \eta_j} \;=\; \sum_{i=1}^{n}\left(p_{ij} - \pi_j\right),
$$
where the first expression is the Riemannian gradient with respect to $S_j$, and $p_{ij}$ is the posterior probability (under the current parameter estimates) of observation $i$ coming from component $j$:

$$
p_{ij} \;=\; \frac{\pi_j\, q(y_i;\,S_j)}{\sum_{k=1}^{K}\pi_k\, q(y_i;\,S_k)}.
$$
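To see how light each iteration is, here is a small Breeze sketch (illustrative only; the function and parameter names are mine and not the package's API) of the responsibilities and gradients for one component on a mini-batch:

```scala
import breeze.linalg.{DenseMatrix, DenseVector, det, inv}

object ReformulatedGradients {

  // Tweaked density q(y; S): a zero-mean Gaussian on the augmented observation,
  // rescaled so that it matches the original component density when S has the
  // wrapper structure (see the check above).
  def q(y: DenseVector[Double], s: DenseMatrix[Double]): Double = {
    val d = y.length
    val quad = y dot (inv(s) * y) // y^T S^{-1} y
    math.sqrt(2 * math.Pi) * math.exp(0.5) *
      math.exp(-0.5 * quad) / math.sqrt(math.pow(2 * math.Pi, d) * det(s))
  }

  // Mini-batch gradients for the j-th component: the Riemannian gradient w.r.t. S_j
  // and the gradient w.r.t. the auxiliary weight variable eta_j.
  def componentGradients(batch: Seq[DenseVector[Double]],     // augmented observations y_i
                         ss: IndexedSeq[DenseMatrix[Double]], // wrapper matrices S_1..S_K
                         weights: IndexedSeq[Double],         // mixture weights pi_1..pi_K
                         j: Int): (DenseMatrix[Double], Double) = {
    val dim = ss(j).rows
    val gradS = DenseMatrix.zeros[Double](dim, dim)
    var gradEta = 0.0
    for (y <- batch) {
      val weighted = ss.indices.map(k => weights(k) * q(y, ss(k)))
      val pij = weighted(j) / weighted.sum         // responsibility p_ij
      gradS += (y * y.t - ss(j)) * (0.5 * pij)     // 0.5 * p_ij * (y y^T - S_j)
      gradEta += pij - weights(j)                  // p_ij - pi_j
    }
    (gradS, gradEta)
  }
}
```

Note that the constant in `q` cancels in the responsibilities; only the determinant and the quadratic form matter, and both are needed to evaluate the density anyway.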
Upgrading to accelerated gradient descent
Being able to use SGD on GMMs also opens the door to accelerated stochastic gradient descent (ASGD), a family of techniques that has given good results in other machine learning models, particularly deep learning. The first of my tweaks was to implement two flavors of ASGD: descent with momentum and Nesterov's correction. To briefly explain this, assume that at time $t$, for a loss function $\mathcal{L}$ and step size $\gamma_t$, we have a current parameter estimate $\theta_t$; then to get the next estimate in standard SGD (or rather ascent, in this case) we would do

$$
\theta_{t+1} \;=\; \theta_t + \gamma_t\,\nabla\mathcal{L}(\theta_t).
$$
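The accelerated variants add an inertia term to this update; in their standard Euclidean, ascent form they look roughly like

$$
\text{momentum:}\qquad v_{t+1} = \beta\,v_t + \gamma_t\,\nabla\mathcal{L}(\theta_t),
\qquad \theta_{t+1} = \theta_t + v_{t+1},
$$

$$
\text{Nesterov:}\qquad v_{t+1} = \beta\,v_t + \gamma_t\,\nabla\mathcal{L}(\theta_t + \beta\,v_t),
\qquad \theta_{t+1} = \theta_t + v_{t+1},
$$

where $\beta \in [0, 1)$ controls the amount of inertia; on the manifold, the resulting step is mapped back with the retraction as before. These are the textbook forms, written here just to fix ideas.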
In the figure below, the left image corresponds to Nesterov's correction and the right one to descent with momentum. The y-axis is the mean log-likelihood on the validation set after each iteration, the dotted line at the top is the true mean log-likelihood, and the black trend represents standard SGD. As we can see, ASGD gets to the solution between 2 and 4 times faster than standard SGD, even when there is an excess of inertia!
In addition to this, numerical results also suggest that the right amount of inertia from ASGD can help the algorithm get out of saddle points where standard SGD and EM can get stuck:
According to the figure above, having too little inertia (standard SGD, in black, is a descent algorithm without inertia) or too much (the pink trends in both images) gets the algorithm stuck in saddle points, which does not happen for intermediate amounts of inertia (controlled by the corresponding inertia parameter in each variant); the bottom line is that ASGD makes for a potentially faster and more effective fitting algorithm for GMMs.
Comparison to EM
When comparing the accelerated models to Spark's EM implementation I tested different dimensionalities, numbers of components in the true underlying model, and data separations; the last is a measure of how close clusters are to one another, so the closer they are the harder they are to distinguish, and a separation of 1 or less should start to be hard for the algorithms. ASGD didn't disappoint, and ended up speeding up the process by up to a hundred times depending on the data configuration; though this varied widely, ASGD was at least 5 times as fast in all of the experiments. Even local ASGD performed much better than distributed EM on 32 cores! There is a price to pay for this speedup, though, which is some additional estimation error, although for most applications this is probably acceptable. To put some numbers on all this, I share a fragment of the experimental results below:
Separation | Dimension | # Components | ASGD time (s) | ASGD progress | EM time (s) | EM progress |
---|---|---|---|---|---|---|
0.2 | 30 | 4 | 1.20 | 98.37% | 211.70 | 99.75% |
0.2 | 30 | 8 | 11.56 | 94.73% | 547.70 | 99.68% |
0.2 | 30 | 12 | 2.99 | 90.43% | 1116.44 | 99.86% |
What I call progress in the table is how far each model got towards the true data log-likelihood, which is the maximum achievable value, expressed as a percentage: 100% means reaching the true data log-likelihood exactly, while 0% corresponds to the starting point's log-likelihood. I use this measure because raw log-likelihood values don't mean much on their own; both algorithms had exactly the same starting point in all experiments.
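In symbols, writing $\ell_0$ for the starting log-likelihood, $\ell$ for the model's current log-likelihood, and $\ell^{*}$ for the true data log-likelihood,

$$
\text{progress} \;=\; \frac{\ell - \ell_0}{\ell^{*} - \ell_0}\times 100\%.
$$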
There were many more parameter configurations, but the qualitative nature of the results was consistent across them. Some of the additional estimation error can be reduced considerably simply by doing more iterations or using larger mini-batches, since the time surplus allows us to do that comfortably. As a final note, some EM runs got stuck very early on and ended up performing badly, while ASGD never stopped at less than 80% progress, which again might be partly due to acceleration.
Avoiding mean collapse by using logarithmic barriers
Mean collapse is a phenomenon that can happen in GMMs when one of the components collapses onto a single observation, with its mean sitting exactly on that point and its covariance matrix shrinking towards the zero matrix. Hosseini & Sra (2017) suggest fixing this with a conjugate prior distribution over the parameters, but numerical experiments suggested that for small batch sizes of a few tens of observations the influence of the prior is too strong and prevents the model from making rapid progress towards the true parameters, settling instead at a sort of middle point between them and the prior's expectations.
The second of my tweaks was to instead use a good old logarithmic barrier regularization on the determinant of each wrapper matrix $S_j$, which favors inflating the covariance matrix a bit but does not affect the mean or weight parameters. Another small advantage is that we no longer need extra matrix products at each fitting iteration to compute the prior's regularization value: the determinant of $S_j$ has to be computed anyway at each iteration to evaluate the density, so the logarithmic barrier term is effectively free. It can be proven that this regularization preserves the solution of adding a logarithmic barrier to the original GMM formulation, and its gradient is also quite simple.
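Schematically, the regularized objective has the following flavour (a sketch rather than the exact expression from the dissertation; $\omega > 0$ is a small regularization weight I introduce here for illustration):

$$
\hat{\mathcal{L}}_{\text{reg}}\big(\{S_j\},\{\eta_j\}\big) \;=\; \hat{\mathcal{L}}\big(\{S_j\},\{\eta_j\}\big) \;+\; \omega\sum_{j=1}^{K}\log\det(S_j).
$$

The barrier goes to $-\infty$ as $\det(S_j) \to 0$, which is exactly what keeps a component from collapsing, and its contribution to the Euclidean gradient with respect to $S_j$ is just $\omega\,S_j^{-1}$, i.e. $\omega\,S_j$ once mapped to the Riemannian gradient.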
The figure below shows the algorithm's trajectories for both types of regularization and a batch size of 100 in a problem with three components, although only one of them is shown; the black ellipse represents the true distribution's contour line. The left image shows the trajectories after 10 iterations and the right image after 50 iterations; it is clear that the logarithmic barrier allows for faster progress, and that after a certain amount of progress the conjugate prior regularization gets the algorithm stuck (also, the pull from the prior's mean parameter is clearly visible in the pink trajectory).
To be fair, this last phenomenon would disappear as we increase the batch size, but that also goes against the philosophy of SGD and ASGD, where taking noisy mini-batch steps is the price to pay to scale to large datasets.
Going further
After implementing everything in Spark, I realized that even though SGD can in fact make the processing much faster than a batch algorithm such as EM (local SGD beat distributed EM sometimes!), when working in a distributed environment the latency and communication costs make that advantage disappear, and Spark's own GMM-EM implementation ends up being roughly as fast as SGD/ASGD. This is due to synchronization costs more than anything else, and as far as I know Spark does not allow asynchronous updates; this means that in order to exploit the big gains in performance that we got from ASGD, it would have to be implemented in PyTorch, TensorFlow or any other framework that supports distributed asynchronous updates. It would definitely be interesting to see how GMMs perform when fitted with asynchronous ASGD, and if this were coupled with a way to dynamically update the number of components depending on the newest data, it would make GMMs a powerful model for handling streaming data.