Understanding Prioritized Experience Replay
Prioritized Experience Replay (PER) is one of the most important and conceptually straightforward improvements for the vanilla Deep Q-Network (DQN) algorithm. It is built on top of experience replay buffers, which allow a reinforcement learning (RL) agent to store experiences in the form of transition tuples, usually denoted as $(s_t,a_t,r_{t},s_{t+1})$ with states, actions, rewards, and successor states at some time index $t$. In contrast to consuming samples online and discarding them thereafter, sampling from the stored experiences means they are less heavily “correlated” and can be re-used for learning.
Uniform sampling from a replay buffer is a good default strategy, and probably the first one to attempt. But prioritized sampling, as the name implies, will weigh the samples so that “important” ones are drawn more frequently for training. In this post, I review Prioritized Experience Replay, with an emphasis on relevant ideas or concepts that are often hidden under the hood or implicitly assumed.
I assume that PER is applied with the DQN framework because that is what the original paper used, but PER can, in theory, be applied to any algorithm which samples from a database of items. As most Artificial Intelligence students and practitioners probably know, the DQN algorithm attempts to find a policy \(\pi\) which maps a given state $s_t$ to an action $a_t$ such that it maximizes the expected cumulative reward of the agent \(\mathbb{E}_{\pi}\Big[ \sum_{t=0}^\infty r_t \Big]\) from some starting state $s_0$. DQN obtains \(\pi\) implicitly by calculating a state-action value function $Q_\theta(s,a)$ parameterized by $\theta$, which measures the goodness of the given state-action pair with respect to some behavioral policy. (This is a critical point that’s often missed: state-action values, or state-values for that matter, don’t make sense unless they are also attached to some policy.)
To find an appropriate $\theta$, which then determines the final policy $\pi$, DQN performs the following optimization:
\[{\rm minimize}_{\theta} \;\; \mathbb{E}_{(s_t,a_t,r_t,s_{t+1})\sim D} \left[ \Big(r_t + \gamma \max_{a \in \mathcal{A}} Q_{\theta^-}(s_{t+1},a) - Q_\theta(s_t,a_t)\Big)^2 \right]\]where $(s_t,a_t,r_t,s_{t+1})$ are batches of samples from the replay buffer $D$, which is designed to store the past $N$ samples (usually $N=1,000,000$ for Atari 2600 benchmarks). In addition, $\mathcal{A}$ represents the set of discrete actions, $\theta$ denotes the parameters of the current (online) network, and $\theta^-$ denotes the parameters of the target network. Both networks use the same architecture, and we use $Q_\theta(s,a)$ or $Q_{\theta^-}(s,a)$ to denote which of the two is being applied to evaluate $(s,a)$.
The target network starts off by getting matched to the current network, but remains frozen (usually for thousands of steps) before getting updated again to match the current network. The process repeats throughout training, with the goal of increasing the stability of the targets $r_t + \gamma \max_{a \in \mathcal{A}} Q_{\theta^-}(s_{t+1},a)$.
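The freeze-and-copy schedule can be sketched in a few lines. This is only an illustration, assuming the parameters live in a plain dictionary; the names `sync_target` and `sync_every` are hypothetical and do not come from any particular DQN codebase.

```python
import copy


def sync_target(online_params, target_params, step, sync_every=1000):
    """Return the target parameters to use at this training step.

    Every `sync_every` steps the target becomes a fresh copy of the online
    parameters; otherwise the old (frozen) target is kept unchanged.
    """
    if step % sync_every == 0:
        # Deep copy so later updates to the online network do not leak
        # into the frozen target.
        return copy.deepcopy(online_params)
    return target_params


# Usage: the target only moves at multiples of sync_every.
online = {"w": [1.0]}
target = sync_target(online, None, step=0)
online["w"][0] = 2.0                      # a gradient step changes the online net
target = sync_target(online, target, step=1)
```

After the last call, `target["w"]` still holds the old value because the step is not a multiple of `sync_every`.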
I have an older blog post here if you would like an intuitive perspective on DQN. For more background on reinforcement learning, I refer you to the standard textbook in the field by Sutton and Barto. It is freely available (God bless the authors) and updated to the second edition for 2018. Woo hoo! Expect future blog posts here about the more technical concepts from the book.
Now, let us get started on PER. The intuition of the algorithm is clear, and the Prioritized Experience Replay paper (presented at ICLR 2016) is surprisingly readable. They say:
In particular, we propose to more frequently replay transitions with high expected learning progress, as measured by the magnitude of their temporal-difference (TD) error. This prioritization can lead to a loss of diversity, which we alleviate with stochastic prioritization, and introduce bias, which we correct with importance sampling. Our resulting algorithms are robust and scalable, which we demonstrate on the Atari 2600 benchmark suite, where we obtain faster learning and state-of-the-art performance.
The paper was written in 2015 and submitted to ICLR 2016, so straight-up PER with DQN is definitely not state of the art performance. For example, the Rainbow DQN algorithm is superior. Everything else is correct, though. The PER idea reminds me of “hard negative mining” in the supervised learning setting. The magnitude of the TD error (squared) is what we want to minimize in the Bellman equation. Hence, pick the samples with the largest error so that our neural network can minimize it!
To clarify a somewhat implied point (for those who did not read the paper), and to play some devil’s advocate, why do we prioritize by the magnitude of the TD error? Ideally we would sample with respect to some mysterious function $f( (s_t,a_t,r_t,s_{t+1}) )$ that exactly tells us the “usefulness” of sample $(s_t,a_t,r_t,s_{t+1})$ for fastest learning to get maximum reward. But since this magical function $f$ is unknown, we use absolute TD error because it appears to be a reasonable approximation to it. There are other options, and I encourage you to read the discussion in Appendix A. I am not sure how many alternatives to TD error magnitude have been implemented in the literature. Since I have not seen any (besides a KL-based one in Rainbow DQN), it suggests that DeepMind’s choice of absolute TD error was the right one. The TD error for vanilla DQN is:
\[\delta_i = r_t + \gamma \max_{a \in \mathcal{A}} Q_{\theta^-}(s_{t+1},a) - Q_\theta(s_t,a_t)\]and for Double DQN, it would be:
\[\delta_i = r_t + \gamma Q_{\theta^-}(s_{t+1},{\rm argmax}_{a \in \mathcal{A}} Q_\theta(s_{t+1},a)) - Q_\theta(s_t,a_t)\]and either way, we use $| \delta_i |$ as the magnitude of the TD error. Negative versus positive TD errors are combined into one case here, but in principle we could consider them as separate cases and add a bonus to whichever one we feel is more important to address.
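To make the difference between the two TD errors concrete, here is a minimal sketch for a single transition. It assumes the Q-values for the successor state are already available as plain Python lists (one entry per action); the function names are my own and purely illustrative.

```python
def td_error_dqn(r, gamma, q_target_next, q_online_sa):
    """Vanilla DQN: the target network both selects and evaluates the action."""
    return r + gamma * max(q_target_next) - q_online_sa


def td_error_double_dqn(r, gamma, q_online_next, q_target_next, q_online_sa):
    """Double DQN: the online network selects, the target network evaluates."""
    best_action = max(range(len(q_online_next)), key=lambda a: q_online_next[a])
    return r + gamma * q_target_next[best_action] - q_online_sa


# Example with two actions where the two networks disagree on the best action.
q_online_next = [2.0, 1.0]   # online network prefers action 0
q_target_next = [1.0, 3.0]   # target network prefers action 1
delta_dqn = td_error_dqn(1.0, 0.9, q_target_next, q_online_sa=2.0)
delta_ddqn = td_error_double_dqn(1.0, 0.9, q_online_next, q_target_next, q_online_sa=2.0)
priority_dqn, priority_ddqn = abs(delta_dqn), abs(delta_ddqn)
```

In this example the vanilla DQN error is about $1.7$ while the Double DQN error is about $-0.1$, so the same transition would receive quite different priorities under the two variants.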
This provides the absolute TD error, but how do we incorporate this into an RL algorithm?
First, we can immediately try to assign the priorities ($| \delta_i |$) as components to add to the samples. That means our replay buffer samples are now $(s_{t},a_{t},r_{t},s_{t+1}, | \delta_t |)$. (Strictly speaking, they should also have a “done” flag $d_t$ which tells us if we should use the bootstrapped estimate of our target, but we often omit this notation since it is implicitly assumed. This is yet another minor detail that is not clear until one implements DQN.)
But then here’s a problem: how is it possible to keep the TD error magnitudes of every sample up to date? Replay buffers might have a million elements in them. Each time we update the neural network, do we really need to update each and every $\delta_i$ term, which would involve a forward pass through $Q_\theta$ (and possibly $Q_{\theta^-}$ if it was changed) for each item in the buffer? DeepMind proposes a far more computationally efficient alternative of only updating the $\delta_i$ terms for items that are actually sampled during the minibatch gradient updates. Since we have to compute $\delta_i$ anyway to get the loss, we might as well use those to change the priorities. For a minibatch size of 32, each gradient update will change the priorities of 32 samples in the replay buffer, but leave the (many) remaining items alone.
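The "only refresh what you sampled" idea is simple enough to sketch directly. I assume here that the stored magnitudes sit in a plain list indexed by buffer position; `update_sampled_priorities` is a hypothetical helper name.

```python
def update_sampled_priorities(stored_magnitudes, batch_indices, batch_td_errors):
    """Overwrite |delta| only for the transitions in the current minibatch.

    All other entries keep their (possibly stale) values, so the cost is
    proportional to the minibatch size rather than the buffer size.
    """
    for idx, delta in zip(batch_indices, batch_td_errors):
        stored_magnitudes[idx] = abs(delta)


# Usage: a buffer of six items, of which two were sampled in this update.
magnitudes = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
update_sampled_priorities(magnitudes, batch_indices=[1, 4], batch_td_errors=[-0.5, 2.0])
```

Only positions 1 and 4 change; the stale values elsewhere are the price paid for avoiding a full pass over the buffer.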
That makes sense. Next, given the absolute TD terms, how do we get a probability distribution for sampling? DeepMind proposes two ways of getting priorities, denoted as $p_i$:
- A rank based method: $p_i = 1 / {\rm rank}(i)$ which sorts the items according to $| \delta_i |$ to get the rank.
- A proportional variant: $p_i = | \delta_i | + \epsilon$, where $\epsilon$ is a small constant ensuring that the sample has some non-zero probability of being drawn.
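Both priority definitions can be written out in a few lines. This is a sketch under my own assumptions: rank 1 goes to the item with the largest $| \delta_i |$, and the default `eps=1e-6` is an arbitrary illustrative value, not one taken from the paper.

```python
def rank_based_priorities(td_errors):
    """p_i = 1 / rank(i), where rank 1 is the largest |delta_i|."""
    order = sorted(range(len(td_errors)), key=lambda i: abs(td_errors[i]), reverse=True)
    priorities = [0.0] * len(td_errors)
    for rank, i in enumerate(order, start=1):
        priorities[i] = 1.0 / rank
    return priorities


def proportional_priorities(td_errors, eps=1e-6):
    """p_i = |delta_i| + eps, so even a zero-error item can still be drawn."""
    return [abs(d) + eps for d in td_errors]


td_errors = [0.5, -2.0, 0.0, 1.0]
by_rank = rank_based_priorities(td_errors)      # [1/3, 1.0, 0.25, 0.5]
by_value = proportional_priorities(td_errors)   # roughly [0.5, 2.0, 1e-6, 1.0]
```

Note how the rank-based priorities depend only on the ordering, which makes them insensitive to outliers in $| \delta_i |$, while the proportional ones track the actual magnitudes.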
During exploration, the $p_i$ terms are not known for brand-new samples because those have not been evaluated with the networks to get a TD error term. To get around this, PER initializes $p_i$ for a new sample to the maximum priority of any sample seen thus far, which favors new samples during later sampling so that each one is likely to be replayed at least once.
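A minimal sketch of this initialization, assuming the transitions and their priorities are kept in two parallel lists. The fallback value of `1.0` for an empty buffer is my own assumption, and `add_transition` is a hypothetical name.

```python
def add_transition(transitions, priorities, transition, default_priority=1.0):
    """Append a new transition with the largest priority currently stored.

    Note: max() over the whole list is O(N); a real implementation would
    track the running maximum instead. This is kept simple on purpose.
    """
    max_priority = max(priorities, default=default_priority)
    transitions.append(transition)
    priorities.append(max_priority)


transitions, priorities = [], []
add_transition(transitions, priorities, ("s0", 0, 1.0, "s1"))   # gets the default
priorities[0] = 3.5                                              # later refreshed by training
add_transition(transitions, priorities, ("s1", 1, 0.0, "s2"))   # inherits 3.5
```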
From either of these, we can easily get a probability distribution:
\[P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}\]where $\alpha$ determines the level of prioritization. If $\alpha \to 0$, then there is no prioritization, because all $p_i^\alpha = 1$ and sampling becomes uniform. If $\alpha \to 1$, then we get to, in some sense, “full” prioritization, where sampling data points is more heavily dependent on the actual $\delta_i$ values. Now that I think about it, we could increase $\alpha$ above one, but that would likely cause dramatic problems with over-fitting as the distribution could become heavily “pointy” with low entropy.
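The effect of $\alpha$ is easy to see numerically. Here is a small sketch that turns a list of priorities into sampling probabilities; the function name is my own.

```python
def sampling_probabilities(priorities, alpha):
    """P(i) = p_i^alpha / sum_k p_k^alpha."""
    scaled = [p ** alpha for p in priorities]
    total = sum(scaled)
    return [s / total for s in scaled]


priorities = [1.0, 2.0, 4.0]
uniform = sampling_probabilities(priorities, alpha=0.0)   # [1/3, 1/3, 1/3]
softened = sampling_probabilities(priorities, alpha=0.5)  # between the two extremes
full = sampling_probabilities(priorities, alpha=1.0)      # [1/7, 2/7, 4/7]
```

With $\alpha = 0$ every item has probability $1/3$; with $\alpha = 1$ the highest-priority item is drawn four times as often as the lowest; intermediate values interpolate between these two regimes.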
We finally have our actual probability $P(i)$ of sampling the $i$-th data point for a given minibatch, which would be (again) $(s_t,a_t,r_t,s_{t+1},| \delta_t |)$. During training, we can draw these simply by weighting all samples in the $N$-sized replay buffer by $P(i)$.
Since the buffer size $N$ can be quite large (e.g., one million), DeepMind uses special data structures to reduce the time complexity of certain operations. For the proportional-based variant, which is what OpenAI implements, a sum-tree data structure is used to make both updating and sampling $O(\log N)$ operations.
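To give a feel for the data structure, here is a compact sum-tree sketch. It is my own simplified version (a binary tree stored in a flat list, with leaves holding $p_i^\alpha$) and is not the OpenAI code; the class and method names are hypothetical.

```python
import random


class SumTree:
    """Leaves hold priorities; each internal node holds the sum of its children."""

    def __init__(self, capacity):
        self.capacity = capacity
        # Index 1 is the root; leaves live at positions [capacity, 2 * capacity).
        self.tree = [0.0] * (2 * capacity)

    def update(self, index, priority):
        """Set the priority of item `index` and fix the sums above it: O(log N)."""
        pos = index + self.capacity
        self.tree[pos] = priority
        pos //= 2
        while pos >= 1:
            self.tree[pos] = self.tree[2 * pos] + self.tree[2 * pos + 1]
            pos //= 2

    def total(self):
        return self.tree[1]

    def find(self, mass):
        """Walk down from the root to the leaf that 'contains' `mass`: O(log N)."""
        pos = 1
        while pos < self.capacity:
            left = 2 * pos
            if mass < self.tree[left]:
                pos = left
            else:
                mass -= self.tree[left]
                pos = left + 1
        return pos - self.capacity

    def sample(self):
        """Draw an index with probability proportional to its priority."""
        return self.find(random.random() * self.total())


tree = SumTree(capacity=4)
for i, p in enumerate([1.0, 2.0, 3.0, 4.0]):
    tree.update(i, p)
index = tree.sample()   # index 3 is drawn with probability 4/10
```

The key design point is that neither operation ever touches more than one root-to-leaf path, which is what keeps both at $O(\log N)$ instead of the $O(N)$ cost of renormalizing a flat probability array.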
Is that it? Well, not quite. There are a few technical details to resolve, but probably the most important one (pun intended) is an importance sampling correction. DeepMind describes why:
The estimation of the expected value with stochastic updates relies on those updates corresponding to the same distribution as its expectation. Prioritized replay introduces bias because it changes this distribution in an uncontrolled fashion, and therefore changes the solution that the estimates will converge to (even if the policy and state distribution are fixed). We can correct this bias by using importance-sampling (IS) weights.
This makes sense. Here is my intuition, which I hope is useful. I think the distribution DeepMind is talking about (“same distribution as its expectation”) above is the distribution of samples that are obtained when sampling uniformly at random from the replay buffer. Recall the expectation I wrote above, which I repeat again for convenience:
\[{\rm minimize}_{\theta} \;\; \mathbb{E}_{(s_t,a_t,r_t,s_{t+1})\sim D} \left[ \Big(r_t + \gamma \max_{a \in \mathcal{A}} Q_{\theta^-}(s_{t+1},a) - Q_\theta(s_t,a_t)\Big)^2 \right]\]Here, the “true distribution” for the expectation is indicated with this notation under the expectation:
\[(s_t,a_t,r_{t},s_{t+1})\sim D\]which means we uniformly sample from the replay buffer. Since prioritization means we are not doing that, then the distribution of samples we get is different from the “true” distribution using uniform sampling. In particular, PER over-samples those with high priority, so the importance sampling correction should down-weight the impact of the sampled term, which it does by scaling the gradient term so that the gradient has “less impact” on the parameters.
To add yet more confusion, I don’t even think the uniform sampling is the “true” distribution we want, in the sense that it is the distribution under the expectation for the Q-learning loss. What I think we want is the actual set of samples that are induced by the agent’s current policy, so that we really use:
\[(s_t,a_t,r_{t},s_{t+1})\sim \pi\]where $\pi$ is a policy induced from the agent’s current Q-values. Perhaps it is greedy for simplicity. So what effectively happens is that, due to uniform sampling, there is extra bias and over-sampling towards the older samples in the replay buffer. Despite this, we should be OK because Q-learning is off-policy, so it shouldn’t matter in theory where the samples come from. Thus it’s unclear what a “true distribution of samples” should be like, if any exists. Incidentally, the off-policy aspect of Q-learning and why it does not take expectations “over the policy” appears to be the reason why importance sampling is not needed in vanilla DQN. (When we add an ingredient like importance sampling to PER, it is worth thinking about why we had to use it in this case, and not in others.) Things might change when we talk about $n$-step returns, but that raises the complexity to a new level … or we might just ignore importance sampling corrections, as this StackExchange answer suggests.
This all makes sense intuitively, but there has to be a nice, rigorous way to formalize it. The “TL;DR” is that the importance sampling in PER is to correct the over-sampling with respect to the uniform distribution.
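As a small sanity check on that TL;DR, here is a sketch that computes expectations exactly (no random sampling) for a toy buffer. It assumes each item is reweighted by the standard importance ratio $(1/N) / P(i)$; the function names and the toy numbers are mine.

```python
def uniform_mean(values):
    """Expectation of the per-item quantity under uniform sampling."""
    return sum(values) / len(values)


def prioritized_mean(values, probs):
    """Expectation under prioritized sampling, with no correction."""
    return sum(p * v for p, v in zip(probs, values))


def corrected_mean(values, probs):
    """Expectation under prioritized sampling, reweighted by (1/N) / P(i)."""
    n = len(values)
    return sum(p * (1.0 / (n * p)) * v for p, v in zip(probs, values))


values = [1.0, 5.0, -2.0, 4.0]     # e.g. per-item loss terms
probs = [0.1, 0.2, 0.3, 0.4]       # a non-uniform P(i)
biased = prioritized_mean(values, probs)     # about 2.1
unbiased = corrected_mean(values, probs)     # matches uniform_mean(values) = 2.0
```

The uncorrected prioritized expectation drifts away from the uniform one, while the reweighted version recovers it, which is exactly the sense in which the correction undoes the over-sampling.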
Hopefully this is clear. Feel free to refer back to an earlier blog post about importance sampling more generally; I was hoping to follow it up right away with this current post, but my blogging plans never go according to plan.
How do we apply importance sampling? We use the following weights:
\[w_i = \left( \frac{1}{N} \cdot \frac{1}{P(i)} \right)^\beta\]which are then further scaled in each minibatch so that $\max_i w_i = 1$ for stability reasons; generally, we don’t want weights to be wildly large.
Let’s dissect this term. The $N$ in the $1/N$ factor is the current number of items in the experience replay buffer. To clarify: this is NOT the same as the capacity of the buffer, and it only becomes equivalent to it once we hit the capacity and have to start overwriting samples. The $P(i)$ represents the probability of sampling data point $i$ according to priorities. It is this key term that scales the weights, in inverse proportion to $P(i)$. As $P(i) \to 1$ (which really should never happen) the weight gets smaller, with an extreme down-weighting of the sample’s impact. As $P(i) \to 0$, the weight gets larger. If $P(i) = 1/N$ for all $i$, then we get uniform sampling with the $1/N$ term canceling out $1/(1/N)$.
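Here is a short sketch of the weight computation for one minibatch, including the rescaling so that the largest weight is one. I assume `probs` holds $P(i)$ for every item currently in the buffer (so its length is $N$); the function name is hypothetical.

```python
def importance_weights(probs, batch_indices, beta):
    """w_i = (1 / (N * P(i)))^beta, rescaled so the max weight in the batch is 1."""
    n = len(probs)   # current number of stored items, not the capacity
    raw = [(1.0 / (n * probs[i])) ** beta for i in batch_indices]
    max_w = max(raw)
    return [w / max_w for w in raw]


probs = [0.1, 0.2, 0.3, 0.4]
weights = importance_weights(probs, batch_indices=[0, 3], beta=1.0)
# Raw weights are 2.5 and 0.625, so after rescaling we get about [1.0, 0.25].
```

The rarely sampled item (index 0) keeps a weight of one, while the frequently sampled item (index 3) is scaled down, so the rescaling only ever shrinks updates.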
Don’t forget the $\beta$ term in the exponent, which controls how much of the importance sampling correction to apply: $\beta = 0$ means no correction at all, and $\beta = 1$ fully compensates for the non-uniform sampling probabilities. They argue that training is highly unstable at the beginning, and that importance sampling corrections matter more near the end of training. Thus, $\beta$ starts small (values of 0.4 to 0.6 are commonly used) and anneals towards one.
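A schedule for $\beta$ might look like the following sketch. I am assuming a linear ramp over a fixed number of steps; the function name and the default `beta_start=0.4` are illustrative choices within the range mentioned above.

```python
def beta_schedule(step, total_steps, beta_start=0.4, beta_end=1.0):
    """Linearly interpolate beta from beta_start to beta_end, then hold."""
    fraction = min(max(step / total_steps, 0.0), 1.0)
    return beta_start + fraction * (beta_end - beta_start)


betas = [beta_schedule(s, total_steps=100) for s in (0, 50, 100, 200)]
# Roughly [0.4, 0.7, 1.0, 1.0]: beta reaches one at the end and stays there.
```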
We finally “fold” this weight together with the $\delta_i$ TD error term during training, with $w_i \delta_i$, because the $\delta_i$ is multiplied with the gradient $\nabla_\theta Q_\theta(s_t,a_t)$ following the chain rule.
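The chain-rule claim can be checked on a single scalar. In this sketch I treat the target as a constant and use the per-sample loss $\frac{1}{2} w_i \delta_i^2$ (the factor of one half is my own convention to keep the gradient tidy); differentiating with respect to $Q_\theta(s_t,a_t)$ should give $-w_i \delta_i$.

```python
def weighted_loss(w, target, q):
    """Per-sample loss 0.5 * w * delta^2 with delta = target - q."""
    delta = target - q
    return 0.5 * w * delta * delta


def grad_wrt_q(w, target, q):
    """Analytic derivative of weighted_loss with respect to q: -w * delta."""
    return -w * (target - q)


# Compare the analytic gradient against a central finite difference.
w, target, q, h = 0.25, 3.0, 1.0, 1e-6
numeric = (weighted_loss(w, target, q + h) - weighted_loss(w, target, q - h)) / (2 * h)
analytic = grad_wrt_q(w, target, q)   # -0.25 * 2.0 = -0.5
```

Backpropagating further into the network multiplies this scalar by $\nabla_\theta Q_\theta(s_t,a_t)$, which is why the weight and the TD error always appear together as the product $w_i \delta_i$.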
The PER paper shows that PER+(D)DQN outperforms uniform sampling on 41 out of 49 Atari 2600 games, though exactly which 8 games it did not improve on is unclear. From looking at Figure 3 (which uses Double DQN, not DQN), perhaps Robotank, Defender, Tutankham, Boxing, Bowling, BankHeist, Centipede, and Yars’ Revenge? I wouldn’t get too bogged down with the details; the benefits of PER are abundantly clear.
As a testament to the importance of prioritization, the Rainbow DQN paper showed that prioritization was perhaps the most essential extension for obtaining high scores on Atari games. Granted, their prioritization was based not on absolute TD error but on a Kullback-Leibler loss because of their use of distributional DQNs, but the main logic might still apply to TD error.
Prioritization can be applied to other applications of experience replay. For example, suppose we wanted to add extra samples to the buffer from some “demonstrator” as in Deep Q-Learning from Demonstrations (blog post here). We can keep the same replay buffer code as earlier, but reserve the first $k$ items in the list for demonstrator samples. Then our indexing for overwriting older samples from the current agent must skip over the first $k$ items. It might be simplest to record this by adding a flag $f_t$ to the sample indicating whether it is a demonstrator or current agent sample. You can probably see why researchers prefer to write $(s_t,a_t,r_t,s_{t+1})$ without all the annoying flags and extra terms! To apply prioritization, one can adjust the raw values $p_i$ to increase those from the demonstrator.
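The reserved-slots idea can be sketched as a small circular buffer whose write pointer never enters the first $k$ positions. This is my own toy version, not the Deep Q-Learning from Demonstrations implementation; `DemoReplayBuffer` and its attributes are hypothetical names.

```python
class DemoReplayBuffer:
    """Circular buffer whose first k slots are permanent demonstrator samples."""

    def __init__(self, demo_transitions, capacity):
        self.k = len(demo_transitions)
        assert capacity > self.k, "need room for agent samples after the demos"
        self.capacity = capacity
        # Each entry is (transition, is_demo); the flag plays the role of f_t.
        self.storage = [(t, True) for t in demo_transitions]
        self.next_slot = self.k

    def add(self, transition):
        item = (transition, False)
        if len(self.storage) < self.capacity:
            self.storage.append(item)
        else:
            self.storage[self.next_slot] = item   # overwrite the oldest agent sample
        self.next_slot += 1
        if self.next_slot >= self.capacity:
            self.next_slot = self.k               # wrap around, skipping the demos


buffer = DemoReplayBuffer(["d0", "d1"], capacity=4)
for t in ["a", "b", "c", "d", "e"]:
    buffer.add(t)
# The demos survive; only agent samples were overwritten: d0, d1, e, d.
```

With the flag stored next to each transition, boosting demonstrator priorities is then just a matter of adding a bonus to $p_i$ whenever the flag is set.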
I hope this was an illuminating overview of prioritized experience replay. For details, I refer you (again) to the paper and to an open-source implementation from OpenAI. Happy reading!