I have recently been working on minibatch Markov chain Monte Carlo (MCMC) methods for Bayesian posterior inference. In this post, I’d like to give a brief summary of what that means and mention two ICML papers (from 2011 and 2014) that have substantially influenced my thinking.
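When we say we do “MCMC for Bayesian posterior inference,” what this typically means is that we have some dataset $\{x_1, \ldots, x_N\}$ and a parameter of interest $\theta \in \mathbb{R}^k$ for some $k$. The posterior distribution we wish to estimate (via sampling) is
$$p(\theta \mid x_1, \ldots, x_N) = \frac{p(\theta) \prod_{i=1}^N p(x_i \mid \theta)}{p(x_1, \ldots, x_N)}$$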
You’ll notice that we assume the data is conditionally independent given the parameter. In many cases, this is unrealistic, but most of the literature assumes this for simplicity.
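After initializing $\theta_0$, the procedure for sampling the posterior on step $t+1$ is:
- Draw a candidate $\theta' \sim q(\theta' \mid \theta_t)$ from a proposal distribution $q$.
- Compute the acceptance probability (note that the denominators of $p(\theta' \mid x_1, \ldots, x_N)$ and $p(\theta_t \mid x_1, \ldots, x_N)$ cancel out):
$$P_a = \min\left\{1, \frac{p(\theta') \prod_{i=1}^N p(x_i \mid \theta') \, q(\theta_t \mid \theta')}{p(\theta_t) \prod_{i=1}^N p(x_i \mid \theta_t) \, q(\theta' \mid \theta_t)}\right\}$$
- Draw $u \sim \mathrm{Unif}[0, 1]$, and accept if $u < P_a$. This means setting $\theta_{t+1} = \theta'$. Otherwise, we set $\theta_{t+1} = \theta_t$. Either way, we then repeat this loop. This tripped me up earlier. Just to be clear, we generate a sample every iteration, even if it is a repeat of the previous one.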
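The loop above can be sketched in a few lines of Python. This is only an illustrative sketch, not the code used for the experiments later in this post: the toy model (a Gaussian prior on a scalar $\theta$ with a Gaussian likelihood), the function names, and the step size are all assumptions made for the example. The proposal is a symmetric random walk, so the $q$ terms cancel, and the test is done in log space.

```python
import math
import random


def log_posterior(theta, data, prior_var=10.0, noise_var=1.0):
    """Unnormalized log posterior for a toy model (an assumption for this sketch):
    theta ~ N(0, prior_var) and x_i ~ N(theta, noise_var)."""
    log_prior = -0.5 * theta ** 2 / prior_var
    log_lik = sum(-0.5 * (x - theta) ** 2 / noise_var for x in data)
    return log_prior + log_lik


def metropolis_hastings(data, num_samples, step=0.5, seed=0):
    rng = random.Random(seed)
    theta = 0.0  # theta_0
    current_lp = log_posterior(theta, data)
    samples = []
    for _ in range(num_samples):
        # Symmetric random-walk proposal, so the q terms cancel in the ratio.
        candidate = theta + rng.gauss(0.0, step)
        candidate_lp = log_posterior(candidate, data)
        # Accept if u < P_a, checked in log space. 1 - random() lies in (0, 1].
        log_u = math.log(1.0 - rng.random())
        if log_u < candidate_lp - current_lp:
            theta, current_lp = candidate, candidate_lp
        # A sample is recorded every iteration, even after a rejection.
        samples.append(theta)
    return samples


data_rng = random.Random(1)
data = [data_rng.gauss(2.0, 1.0) for _ in range(50)]
samples = metropolis_hastings(data, num_samples=5000)
kept = samples[1000:]  # discard a burn-in period
posterior_mean_estimate = sum(kept) / len(kept)
exact_posterior_mean = sum(data) / (len(data) + 0.1)  # closed form for this toy model
```

For this conjugate toy model the posterior mean is available in closed form, which gives a quick sanity check on the sampler.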
This satisfies “detailed balance,” which, roughly speaking, means that if one samples long enough, one will arrive at a stationary distribution matching the posterior, though in practice one often discards a burn-in period and/or only keeps samples at regular intervals. The resulting collection of (correlated!) samples $\{\theta_1, \ldots, \theta_T\}$ for large $T$ can be used to estimate $p(\theta \mid x_1, \ldots, x_N)$. Let’s consider a very simple example. Say $\theta$ can take on three values: $A$, $B$, and $C$. If our sampled set is $\{A, B, A, C\}$, then since $A$ appears two times out of four, we have $p(\theta = A \mid x_1, \ldots, x_N) = 2/4 = 1/2$. The other two values would naturally have probability $1/4$ each according to the samples.
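As a tiny worked version of that counting argument, here is a sketch that turns a list of samples into estimated probabilities. The labels `"A"`, `"B"`, and `"C"` and the sample list are hypothetical placeholders for the three values.

```python
from collections import Counter

# Hypothetical sampled set: three possible values, four samples.
samples = ["A", "B", "A", "C"]

counts = Counter(samples)
# Estimated probability of each value = fraction of samples equal to it.
estimate = {value: count / len(samples) for value, count in counts.items()}
```

The estimated probabilities sum to one by construction, since every sample is counted exactly once.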
One issue with the standard procedure above is that in today’s big data world, with $N$ on the order of billions, it is ridiculously expensive to compute the acceptance probability $P_a$ because that involves evaluating all $N$ of the likelihood factors. Remember, this has to be done every iteration! Doing all this just to get one bit of information (whether to accept or not) is not a good tradeoff. Hence, there has been substantial research on how to perform minibatch MCMC on large datasets. In this case, rather than use all $N$ data points, we just use a subset of $n \ll N$ points each iteration. This approximates the target distribution. The downside? It no longer satisfies detailed balance. (I don’t know the details on why, and it probably involves some complicated convergence studies, but I am willing to believe it.)
Just to be clear, we are focusing on getting a distribution, not a point estimate. That’s the whole purpose of Bayesian estimation! A distribution means we need a full probability function that is non-negative and sums (or, for continuous $\theta$, integrates) to one; if $\theta$ is one or two dimensional, we can easily plot the posterior estimate (I provide an example of this later in this post). A point estimate means finding one value of $\theta$, usually the “best”, which we commonly express as the maximum likelihood estimate $\hat{\theta}_{\mathrm{MLE}}$.
All right, now let’s briefly discuss two papers that tackle the problem of minibatch MCMC.
Bayesian Learning via Stochastic Gradient Langevin Dynamics
This paper appeared in ICML 2011 and proposes using minibatch update methods for Bayesian posterior inference, with a concept known as Langevin Dynamics to inject the correct amount of noise into parameter updates so that the set of sampled parameters converges to the posterior, and not just a mode. To make the distinction clear, let’s see how we can use minibatch updates to converge to a mode, or more specifically, the maximum a posteriori estimate. The function we are trying to optimize is the posterior $p(\theta \mid x_1, \ldots, x_N)$ listed above. So … we just use stochastic gradient ascent! (We are ascending, not descending, because $p$ is a posterior probability, and we want to make its values higher.) This means the update rule is as follows: $\theta_{t+1} = \theta_t + \frac{\epsilon_t}{2} \nabla \log p(\theta_t \mid x_1, \ldots, x_N)$. Plugging in $p$ above, we get
$$\theta_{t+1} = \theta_t + \frac{\epsilon_t}{2} \underbrace{\left(\nabla \log p(\theta_t) + \frac{N}{n} \sum_{i=1}^n \nabla \log p(x_{ti} \mid \theta_t)\right)}_{\approx \nabla \log p(\theta_t \mid x_1, \ldots, x_N)}$$
where $\epsilon_t$ is a sequence of step sizes and $x_{t1}, \ldots, x_{tn}$ is the minibatch drawn at step $t$.
The above is a stochastic gradient ascent update because we use $n \ll N$ terms to approximate the gradient value, which is why I inserted an approximation symbol ($\approx$) in the underbrace. Because we only use $n$ terms, however, we must multiply the summation by $N/n$ to rescale the value appropriately. Intuitively, the exact gradient sums $N$ per-point terms, so a sum over only $n$ of them is on average a factor of $n/N$ too small. Rescaling by $N/n$ puts the minibatch sum on the same order of magnitude as the full sum and makes it an unbiased estimate of it.
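To make the rescaling concrete, here is a minimal sketch of stochastic gradient ascent towards the MAP estimate. The toy model (scalar $\theta$ with prior $\mathcal{N}(0, 10)$ and likelihood $x_i \sim \mathcal{N}(\theta, 1)$), the helper names, and the step-size schedule are assumptions chosen so that the answer can be checked against a closed form.

```python
import random


def grad_log_prior(theta, prior_var=10.0):
    return -theta / prior_var


def grad_log_lik(x, theta, noise_var=1.0):
    return (x - theta) / noise_var


def full_gradient(theta, data):
    """Exact gradient of the log posterior, using all N points."""
    return grad_log_prior(theta) + sum(grad_log_lik(x, theta) for x in data)


def minibatch_gradient(theta, data, n, rng):
    """Estimate of the same gradient from n points, rescaled by N / n."""
    batch = rng.sample(data, n)
    scale = len(data) / n
    return grad_log_prior(theta) + scale * sum(grad_log_lik(x, theta) for x in batch)


rng = random.Random(0)
data = [rng.gauss(2.0, 1.0) for _ in range(1000)]

theta = 0.0
for t in range(1, 2001):
    step = 0.001 / t ** 0.6  # hypothetical decreasing step-size schedule
    theta += 0.5 * step * minibatch_gradient(theta, data, n=20, rng=rng)

theta_map = sum(data) / (len(data) + 0.1)  # closed-form MAP for this toy model
```

Note that the iterate settles near a single point, `theta_map`, which is exactly why this procedure on its own does not give a posterior distribution.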
What’s the problem with the above for Bayesian posterior inference? It doesn’t actually do Bayesian posterior inference. The above will mean $\theta_t$ converges to a single value. We instead want a distribution. So what can we do? We can use Langevin Dynamics, meaning that (for the full batch case) our updates are:
$$\theta_{t+1} = \theta_t + \frac{\epsilon}{2} \left(\nabla \log p(\theta_t) + \sum_{i=1}^N \nabla \log p(x_i \mid \theta_t)\right) + \eta_t, \qquad \eta_t \sim \mathcal{N}(0, \epsilon)$$
A couple of things are worth noting:
- We use all $N$ terms in the summation, so the gradient is in fact exact.
- The injected noise $\eta_t$ means the $\theta_t$ values will “bounce around” to approximate a distribution and not converge to a single point.
- The $\epsilon$ is now constant instead of decreasing, and is balanced so that the variance of the injected noise matches that of the posterior.
- We use $x_i$ instead of $x_{ti}$ as we did earlier. The only difference is that the $t$ indicates the randomness in the minibatch. It does not mean that $x_{ti}$ is one scalar element of a vector. In other words, both $x_i$ and $x_{ti}$ are in $\mathbb{R}^d$ for some $d$.
For simplicity, the above assumes that $\theta \in \mathbb{R}$. In the general case, the noise $\eta_t$ should be a multivariate Gaussian, with covariance $\epsilon I$.
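The problem with this, of course, is the need to use all $N$ points. So let’s use $n$ points, and we have the following update:
$$\theta_{t+1} = \theta_t + \frac{\epsilon_t}{2} \left(\nabla \log p(\theta_t) + \frac{N}{n} \sum_{i=1}^n \nabla \log p(x_{ti} \mid \theta_t)\right) + \eta_t, \qquad \eta_t \sim \mathcal{N}(0, \epsilon_t)$$
where now, we need $\epsilon_t$ to vary and decrease towards zero. The reason for this is that as the step size goes to zero, the corresponding (expensive!) Metropolis-Hastings test has rejection rates that decrease to zero, effectively meaning we can omit it.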
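Here is a minimal sketch of the minibatch Langevin update for a scalar parameter. It is not the implementation used for the plots in this post. The toy model (prior $\mathcal{N}(0, 10)$, likelihood $x_i \sim \mathcal{N}(\theta, 1)$) and the constants `a`, `b`, and `gamma` are assumptions for illustration; the polynomially decaying form $\epsilon_t = a(b + t)^{-\gamma}$ is the kind of schedule suggested in the SGLD paper.

```python
import math
import random


def sgld_gaussian_mean(data, num_steps, n, a=0.01, b=1.0, gamma=0.55, seed=0):
    """SGLD sketch for theta ~ N(0, 10) and x_i ~ N(theta, 1) (toy assumptions)."""
    rng = random.Random(seed)
    N = len(data)
    theta = 0.0
    samples = []
    for t in range(1, num_steps + 1):
        eps = a * (b + t) ** (-gamma)  # step size decreasing towards zero
        batch = rng.sample(data, n)    # the minibatch for step t
        grad_log_prior = -theta / 10.0
        grad_log_lik = (N / n) * sum(x - theta for x in batch)
        noise = rng.gauss(0.0, math.sqrt(eps))  # injected noise with variance eps
        theta += 0.5 * eps * (grad_log_prior + grad_log_lik) + noise
        samples.append(theta)  # no Metropolis-Hastings test is performed
    return samples


data_rng = random.Random(1)
data = [data_rng.gauss(2.0, 1.0) for _ in range(100)]
samples = sgld_gaussian_mean(data, num_steps=10000, n=10)
kept = samples[2000:]  # discard early iterations
sgld_mean = sum(kept) / len(kept)
exact_mean = sum(data) / (len(data) + 0.1)  # closed-form posterior mean
```

Setting `n` equal to `len(data)` and holding `eps` constant would give the full-batch update, again without the Metropolis-Hastings correction.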
Austerity in MCMC Land: Cutting the Metropolis-Hastings Budget
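This paper appeared in ICML 2014, and is also about minibatch MCMC. Here, instead of relying on simulating the physics of the system (as was the case with Stochastic Gradient Langevin Dynamics), they propose reformulating the standard MCMC method with the standard MH test into a sequential hypothesis test. To frame this, they take the log of both sides of the acceptance inequality:
$$\begin{aligned} u < P_a \;&\iff\; \log u < \log \frac{p(\theta') \prod_{i=1}^N p(x_i \mid \theta') \, q(\theta_t \mid \theta')}{p(\theta_t) \prod_{i=1}^N p(x_i \mid \theta_t) \, q(\theta' \mid \theta_t)} \\ &\iff\; \log \left[u \, \frac{p(\theta_t) \, q(\theta' \mid \theta_t)}{p(\theta') \, q(\theta_t \mid \theta')}\right] < \sum_{i=1}^N \log \frac{p(x_i \mid \theta')}{p(x_i \mid \theta_t)} \\ &\iff\; \frac{1}{N} \log \left[u \, \frac{p(\theta_t) \, q(\theta' \mid \theta_t)}{p(\theta') \, q(\theta_t \mid \theta')}\right] < \frac{1}{N} \sum_{i=1}^N \log \frac{p(x_i \mid \theta')}{p(x_i \mid \theta_t)} \end{aligned}$$
In the first step we also dropped the initial “min” because if the “1” case applies, we will always accept. In the last step we divide both sides by $N$. What is the purpose of this? The above is equivalent to the original MH test. But the right hand side depends on all $N$ data points, so what happens if we compute the right hand side using $n$ points?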
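The equivalence can be checked numerically. The sketch below compares the standard test with the rewritten one on random inputs. The toy model (scalar $\theta$, Gaussian prior and likelihood), the symmetric proposal (so the $q$ terms drop out), and the function names are assumptions for illustration; the names `mu_0` and `mu` for the two sides of the final inequality follow my reading of the paper's notation.

```python
import math
import random


def log_prior(theta):
    return -0.5 * theta ** 2 / 10.0


def log_lik(x, theta):
    return -0.5 * (x - theta) ** 2


def standard_accept(u, theta, cand, data):
    """Standard test: u < min(1, posterior ratio), done in log space."""
    log_ratio = (log_prior(cand) - log_prior(theta)
                 + sum(log_lik(x, cand) - log_lik(x, theta) for x in data))
    return math.log(u) < min(0.0, log_ratio)


def rewritten_accept(u, theta, cand, data):
    """Same decision written as mu_0 < mu, with both sides divided by N."""
    N = len(data)
    mu_0 = (math.log(u) + log_prior(theta) - log_prior(cand)) / N
    mu = sum(log_lik(x, cand) - log_lik(x, theta) for x in data) / N
    return mu_0 < mu


rng = random.Random(0)
data = [rng.gauss(2.0, 1.0) for _ in range(200)]
agreements = []
for _ in range(300):
    theta = rng.gauss(2.0, 0.2)
    cand = theta + rng.gauss(0.0, 0.2)
    u = rng.uniform(1e-12, 1.0 - 1e-12)  # keep u strictly inside (0, 1)
    agreements.append(standard_accept(u, theta, cand, data)
                      == rewritten_accept(u, theta, cand, data))
```

Only `mu` touches the data, which is what makes it a natural quantity to estimate from a minibatch.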
This is the heart of their test. They start out by using a small fraction of the points and compute the right hand side. If the proposed element $\theta'$ is so out of whack, then even with just $n$ points, we should already know to reject it. (And a similar case holds if $\theta'$ is really good.) If we cannot tell whether to accept or reject with some specified confidence threshold, then we increase the minibatch size and test again. Their acceptance test relies on the Central Limit Theorem and the Student-t distribution. The details are in the paper, but the main idea is straightforward: increasing the number of samples increases our certainty as to whether we accept or reject, and we can generally make these decisions with far fewer than $N$ samples.
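A rough sketch of such a sequential test is below. It is a simplification under stated assumptions, not the paper's exact procedure: to stay within the Python standard library, it uses a normal tail probability in place of the Student-t tail, and the function name, `batch_size`, and `epsilon` are hypothetical. The input `diffs` holds the per-point log-likelihood differences between the candidate and the current parameter, and `mu_0` is the threshold that their mean is compared against.

```python
import math
import random
import statistics


def sequential_mh_decision(diffs, mu_0, batch_size=10, epsilon=0.05, seed=0):
    """Return (accept, points_used), growing the minibatch until confident."""
    if batch_size < 2:
        raise ValueError("batch_size must be at least 2 to estimate a variance")
    rng = random.Random(seed)
    N = len(diffs)
    order = list(range(N))
    rng.shuffle(order)  # sample data points without replacement
    n = 0
    while True:
        n = min(n + batch_size, N)
        batch = [diffs[i] for i in order[:n]]
        mean = statistics.fmean(batch)
        if n == N:
            return mean > mu_0, n  # all data used, so the decision is exact
        # Standard error of the mean with a finite-population correction.
        std_err = (statistics.stdev(batch) / math.sqrt(n)
                   * math.sqrt(1.0 - (n - 1) / (N - 1)))
        if std_err > 0.0:
            z = abs(mean - mu_0) / std_err
            # Normal tail probability, standing in for the Student-t tail.
            delta = 0.5 * math.erfc(z / math.sqrt(2.0))
            if delta < epsilon:
                return mean > mu_0, n


rng = random.Random(2)
clearly_good = [rng.gauss(1.0, 0.1) for _ in range(1000)]
clearly_bad = [-d for d in clearly_good]
accept_good, used_good = sequential_mh_decision(clearly_good, mu_0=0.0)
accept_bad, used_bad = sequential_mh_decision(clearly_bad, mu_0=0.0)
```

When the minibatch mean sits far from `mu_0`, as in these two cases, the first small batch already settles the decision. When the two are close, the loop keeps adding points and may end up using all of them.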
What’s the downside of their algorithm? In the worst case, we might need the entire data in one iteration. This may or may not be a problem, depending on the particular circumstances.
Their philosophy runs deeper than what the above says. Here’s a key quote:
We advocate MCMC algorithms with a “bias-knob”, allowing one to dial down the bias at a rate that optimally balances error due to bias and variance.
One other algorithm that adheres to this strategy? Stochastic Gradient Langevin Dynamics! (Not coincidentally, Professor Max Welling is a co-author on both papers.) Side note: SGLD is biased because it omits the Metropolis-Hastings test. This algorithm (Adaptive Sampling) is biased because it makes decisions based on a fraction of the data. So both are biased, but for slightly different reasons.
Putting it All Together: A Code Example
In order to better understand the two papers described above, I wrote some Python code to run SGLD and adaptive sampling. I also implemented the standard (full-batch) MCMC method for a baseline.
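I tested using the experiment in Section 5.1 of the SGLD paper. The parameter is 2-D, $\theta = (\theta_1, \theta_2)$, and the parameter/data generation process is
$$\theta_1 \sim \mathcal{N}(0, \sigma_1^2), \qquad \theta_2 \sim \mathcal{N}(0, \sigma_2^2), \qquad x_i \sim \frac{1}{2} \mathcal{N}(\theta_1, \sigma_x^2) + \frac{1}{2} \mathcal{N}(\theta_1 + \theta_2, \sigma_x^2)$$
where the paper uses $\sigma_1^2 = 10$, $\sigma_2^2 = 1$, and $\sigma_x^2 = 2$.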
I’ve pasted all my Python code here. If you put this together in a Jupyter notebook, it should work correctly. If your time is limited, feel free to skip the code and go straight to the output. The code is here mainly for the benefit of current and future researchers who might want a direct implementation of the above algorithms.
The first part of my code imports the necessary libraries and generates the data according to my interpretation of their problem. I am generating 500 points here, whereas the SGLD paper only used 100 points.
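For readers who want something concrete, here is a small stand-in sketch of the data generation step using only the standard library. It is not my original code: the function name, the seed, and the default values (true parameter $(0, 1)$ and noise variance 2, which is my reading of Section 5.1 of the SGLD paper) are assumptions.

```python
import math
import random


def generate_data(num_points, theta=(0.0, 1.0), var_x=2.0, seed=0):
    """Draw points from an equal-weight mixture of N(theta_1, var_x)
    and N(theta_1 + theta_2, var_x)."""
    rng = random.Random(seed)
    std_x = math.sqrt(var_x)
    data = []
    for _ in range(num_points):
        # Pick one of the two mixture components with probability 1/2 each.
        mean = theta[0] if rng.random() < 0.5 else theta[0] + theta[1]
        data.append(rng.gauss(mean, std_x))
    return data


data = generate_data(500)
sample_mean = sum(data) / len(data)
```

With the defaults above, the mixture has mean $0.5$, so `sample_mean` should land in that neighborhood.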
Next, I define a bunch of functions that I will need for the future. The most important one is the `log_f` function, which returns the log of the posterior. (Implementing this correctly requires filling in some missing mathematical details that follow from the formulation of multivariate
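As a rough idea of what a log posterior for this mixture model might compute, here is a sketch. It is not the `log_f` from my notebook; the name `log_posterior_sketch`, the helper `log_normal_pdf`, and the default variances (10, 1, and 2, my reading of Section 5.1 of the SGLD paper) are assumptions. The mixture likelihood is evaluated with a log-sum-exp step for numerical stability.

```python
import math
import random


def log_normal_pdf(x, mean, var):
    return -0.5 * math.log(2.0 * math.pi * var) - 0.5 * (x - mean) ** 2 / var


def log_posterior_sketch(theta, data, var_1=10.0, var_2=1.0, var_x=2.0):
    """Log prior plus log likelihood, up to the normalizing constant p(data)."""
    theta_1, theta_2 = theta
    total = log_normal_pdf(theta_1, 0.0, var_1) + log_normal_pdf(theta_2, 0.0, var_2)
    for x in data:
        # Log of each weighted mixture component.
        a = math.log(0.5) + log_normal_pdf(x, theta_1, var_x)
        b = math.log(0.5) + log_normal_pdf(x, theta_1 + theta_2, var_x)
        # log(exp(a) + exp(b)) computed stably.
        m = max(a, b)
        total += m + math.log(math.exp(a - m) + math.exp(b - m))
    return total


rng = random.Random(0)
toy_data = [rng.gauss(0.5, 1.5) for _ in range(50)]
lp_first = log_posterior_sketch((0.0, 1.0), toy_data)
lp_second = log_posterior_sketch((1.0, -1.0), toy_data)
```

The likelihood takes the same value at $(\theta_1, \theta_2)$ and $(\theta_1 + \theta_2, -\theta_2)$, so `lp_first` and `lp_second` differ only through the prior. That symmetry is why this posterior has two modes.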
Next, I run three methods to estimate the posterior: standard full-batch MCMC, Stochastic Gradient Langevin Dynamics, and the Adaptive Sampling method. Code comments clearly separate and indicate these methods in the following code section. Note the following:
- The standard MCMC and Adaptive Sampling methods use a random walk proposal.
- I used 10000 samples for the methods, except that for SGLD, I found that I needed to increase the number of samples (here it’s 30000) since the algorithm occasionally got stuck at a mode.
- The minibatch size for SGLD is 30, and for the Adaptive Sampling method, it starts at 10 (and can increase by 10 up to 500).
With the sampled values in `all_3`, I can now plot them using my favorite Python library, matplotlib. I also create a contour plot of what the log posterior should really look like.
By the way, it was only recently that I found out I could put LaTeX directly in matplotlib text. That’s pretty cool!
Here are the results in a scatter plot where darker spots indicate more points:
It looks like all three methods roughly obtain the same posterior form.