Model-Agnostic Meta-Learning
One of the recent landmark papers in the area of meta-learning is MAML: Model-Agnostic Meta-Learning. The idea is simple yet surprisingly effective: train neural network parameters \(\theta\) on a distribution of tasks so that, when faced with a new task, they can be rapidly adjusted through just a few gradient steps. In this post, I’ll briefly go over the notation and problem formulation for MAML, and for meta-learning more generally.
Here’s the notation and setup, mostly following the paper:
- The overall model \(f_\theta\) is what MAML is optimizing, with parameters \(\theta\). We denote \(\theta_i'\) as weights that have been adapted to the \(i\)-th task through one or more gradient steps. Since MAML can be applied to classification, regression, reinforcement learning, and imitation learning (plus even more stuff!), we generically refer to \(f_\theta\) as a mapping from inputs \(x_t\) to outputs \(a_t\).
- A task \(\mathcal{T}_i\) is defined as a tuple \((T_i, q_i, \mathcal{L}_{\mathcal{T}_i})\), where:
  - \(T_i\) is the time horizon. For (IID) supervised learning problems like classification, \(T_i=1\). For reinforcement learning and imitation learning, it’s whatever the environment dictates.
  - \(q_i\) is the transition distribution, defining a prior over initial observations \(q_i(x_1)\) and the transitions \(q_i(x_{t+1}\mid x_{t},a_t)\). Again, we can generally ignore this for simple supervised learning. Also, for imitation learning, this reduces to the distribution over expert trajectories.
  - \(\mathcal{L}_{\mathcal{T}_i}\) is a loss function that maps the sequence of network inputs \(x_{1:T}\) and outputs \(a_{1:T}\) to a scalar value indicating the quality of the model. For supervised learning tasks, this is almost always the cross entropy or squared error loss.
- Tasks are drawn from some distribution \(p(\mathcal{T})\). For example, we can have a distribution over the abstract concept of doing well at “block stacking tasks”. One task could be about stacking blue blocks. Another could be about stacking red blocks. Yet another could be stacking blocks that are numbered and need to be ordered consecutively. Clearly, the performance of meta-learning (or any alternative algorithm, for that matter) on optimizing \(f_\theta\) depends on \(p(\mathcal{T})\). The more diverse the distribution’s tasks, the harder it is for \(f_\theta\) to quickly learn new tasks.
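To make the notation concrete, here is a small Python sketch of a task distribution for supervised learning, so \(T_i = 1\). It assumes a hypothetical family of 1-D linear regression tasks \(y = \text{slope} \cdot x + \text{intercept}\), with slopes, intercepts, and inputs drawn uniformly from ranges I picked arbitrarily; `Task`, `sample_task`, and those ranges are illustrative names and values, not anything from the paper. `Task.sample` stands in for \(q_i\), `Task.loss` stands in for \(\mathcal{L}_{\mathcal{T}_i}\) (squared error), and `sample_task` stands in for \(p(\mathcal{T})\).

```python
import random
from dataclasses import dataclass
from typing import Callable, List, Tuple


@dataclass
class Task:
    """One hypothetical regression task T_i = (T_i, q_i, L_Ti) with horizon 1."""
    slope: float
    intercept: float
    horizon: int = 1  # IID supervised learning, so T_i = 1

    def sample(self, n: int, rng: random.Random) -> Tuple[List[float], List[float]]:
        """Stands in for q_i: draw n (x, y) pairs from this task."""
        xs = [rng.uniform(-5.0, 5.0) for _ in range(n)]
        ys = [self.slope * x + self.intercept for x in xs]
        return xs, ys

    def loss(self, f: Callable[[float], float], xs: List[float], ys: List[float]) -> float:
        """Stands in for L_Ti: mean squared error of the model f on (xs, ys)."""
        return sum((f(x) - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def sample_task(rng: random.Random) -> Task:
    """Stands in for p(T): uniform over (assumed) slope and intercept ranges."""
    return Task(slope=rng.uniform(-2.0, 2.0), intercept=rng.uniform(-1.0, 1.0))


# Usage: draw a task, draw a few examples from it, and score a trivial model.
rng = random.Random(0)
task = sample_task(rng)
xs, ys = task.sample(5, rng)
zero_model_loss = task.loss(lambda x: 0.0, xs, ys)
```

Widening the slope and intercept ranges, or swapping in nonlinear functions, is one simple way to make this toy \(p(\mathcal{T})\) more diverse.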
The MAML algorithm specifically finds a set of weights \(\theta\) that are easily fine-tuned to new, held-out tasks (for testing) by optimizing the following:
\[{\rm minimize}_\theta \sum_{\mathcal{T}_i\sim p(\mathcal{T})} \mathcal{L}_{\mathcal{T}_i} (f_{\theta_i'}) = \sum_{\mathcal{T}_i\sim p(\mathcal{T})} \mathcal{L}_{\mathcal{T}_i} \Big(f_{\theta - \alpha \nabla_\theta \mathcal{L}_{\mathcal{T}_i}(f_\theta)}\Big)\]This assumes that \(\theta_i' = \theta - \alpha \nabla_\theta \mathcal{L}_{\mathcal{T}_i}(f_\theta)\). It is also possible to take multiple gradient steps, not just one. Note that in \(K\)-shot learning, \(K\) refers to the number of training examples available per task, not the number of gradient steps; the two are chosen independently. A single step is easier to write, though, so we’ll stick with that.
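Here is a minimal sketch of the inner update that produces \(\theta_i'\), assuming a toy linear model \(f_\theta(x) = w x + b\) with \(\theta = (w, b)\) and a squared-error loss, so that the gradient can be written by hand. The function name `adapt` and its `steps` argument are my own illustrative choices: `steps=1` gives the single-step \(\theta_i'\) from the equation above, larger values give multiple gradient steps, and the number of examples passed in plays the role of \(K\).

```python
from typing import List, Tuple

Theta = Tuple[float, float]  # (w, b) for the toy model f_theta(x) = w*x + b


def loss(theta: Theta, xs: List[float], ys: List[float]) -> float:
    """Mean squared error of f_theta on (xs, ys)."""
    w, b = theta
    return sum((w * x + b - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def grad(theta: Theta, xs: List[float], ys: List[float]) -> Theta:
    """Hand-derived gradient of `loss` with respect to (w, b)."""
    w, b = theta
    r = [w * x + b - y for x, y in zip(xs, ys)]  # residuals
    n = len(xs)
    return sum(2 * ri * x for ri, x in zip(r, xs)) / n, sum(2 * ri for ri in r) / n


def adapt(theta: Theta, xs: List[float], ys: List[float],
          alpha: float = 0.01, steps: int = 1) -> Theta:
    """Compute theta_i' with `steps` plain gradient-descent updates on one task's data."""
    w, b = theta
    for _ in range(steps):
        gw, gb = grad((w, b), xs, ys)
        w, b = w - alpha * gw, b - alpha * gb
    return w, b


# Usage: K = 3 examples from the task y = 2x, starting from theta = (0, 0).
xs, ys = [1.0, 2.0, 3.0], [2.0, 4.0, 6.0]
theta = (0.0, 0.0)
theta_1 = adapt(theta, xs, ys, alpha=0.05, steps=1)  # one-step theta_i'
theta_5 = adapt(theta, xs, ys, alpha=0.05, steps=5)  # five-step variant
```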
Let’s look at the loss function above. We are optimizing over a sum of loss functions across several tasks. But we are evaluating the (outer-most) loss functions while assuming we made gradient updates to our weights \(\theta\). What if the loss function were like this:
\[{\rm minimize}_\theta \sum_{\mathcal{T}_i\sim p(\mathcal{T})} \mathcal{L}_{\mathcal{T}_i} (f_{\theta})\]This means \(f_\theta\) would be capable of learning how to perform well across all these tasks. But there’s no guarantee that this will work on held-out tasks, and generally speaking, unless the tasks are very closely related, it is unlikely to. (I’ve tried doing some similar stuff in the past with the Atari 2600 benchmark, where a “task” was “doing well on game X”, and trained networks that optimized across several games, but generalization was not possible without fine-tuning.) Also, even if we were allowed to fine-tune, it’s very unlikely that one or a few gradient steps would lead to solid performance. MAML should do better precisely because it optimizes \(\theta\) so that it can adapt to new tasks with just a few gradient steps.
MAML is an effective algorithm for meta-learning, and one of its advantages over other algorithms such as \({\rm RL}^2\) is that it is parameter-efficient: the gradient updates above do not introduce extra parameters. Furthermore, the actual optimization over the full model \(\theta\) is also done via SGD:
\[\theta = \theta - \beta \left( \nabla_\theta \sum_{\mathcal{T}_i\sim p(\mathcal{T})} \mathcal{L}_{\mathcal{T}_i} \Big(f_{\theta - \alpha \nabla_\theta \mathcal{L}_{\mathcal{T}_i}(f_\theta)}\Big) \right)\]again introducing no new parameters. (In the paper, the meta-update actually uses Adam for supervised learning and TRPO for RL, but plain SGD is the simplest to write down. Also, even though the outer update may be more elaborate, I believe the inner update, where we have \(f_{\theta - \alpha \nabla_\theta \mathcal{L}_{\mathcal{T}_i}(f_\theta)}\), is always vanilla gradient descent, but I could be wrong.)
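For a toy linear model the meta-gradient in this update can be computed in closed form, which makes its second-order term visible. Since \(\theta_i' = \theta - \alpha \nabla_\theta \mathcal{L}_{\mathcal{T}_i}(f_\theta)\), the chain rule gives \(\nabla_\theta \mathcal{L}_{\mathcal{T}_i}(f_{\theta_i'}) = (I - \alpha H_i)^\top \nabla \mathcal{L}_{\mathcal{T}_i}(f_{\theta_i'})\), where \(H_i\) is the (symmetric) Hessian of the inner loss at \(\theta\). The sketch below assumes \(f_\theta(x) = w x + b\) with squared error, for which \(H_i\) does not depend on \(\theta\), and plain Python with no autodiff library; each task is given as a (training data, testing data) pair, and all helper names are hypothetical. A finite-difference check confirms the hand-derived gradient.

```python
from typing import List, Tuple

Theta = Tuple[float, float]             # (w, b) for the toy model f_theta(x) = w*x + b
Data = Tuple[List[float], List[float]]  # (xs, ys)


def loss(theta: Theta, xs: List[float], ys: List[float]) -> float:
    w, b = theta
    return sum((w * x + b - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def grad(theta: Theta, xs: List[float], ys: List[float]) -> Theta:
    w, b = theta
    r = [w * x + b - y for x, y in zip(xs, ys)]
    n = len(xs)
    return sum(2 * ri * x for ri, x in zip(r, xs)) / n, sum(2 * ri for ri in r) / n


def hessian(xs: List[float]) -> List[List[float]]:
    """Hessian of the MSE for the linear model; constant in theta."""
    n = len(xs)
    sxx, sx = sum(x * x for x in xs), sum(xs)
    return [[2 * sxx / n, 2 * sx / n], [2 * sx / n, 2.0]]


def meta_loss(theta: Theta, tasks: List[Tuple[Data, Data]], alpha: float) -> float:
    """sum_i L_Ti(f_{theta_i'}): inner step on training data, loss on testing data."""
    total = 0.0
    for (xtr, ytr), (xte, yte) in tasks:
        gw, gb = grad(theta, xtr, ytr)
        total += loss((theta[0] - alpha * gw, theta[1] - alpha * gb), xte, yte)
    return total


def meta_grad(theta: Theta, tasks: List[Tuple[Data, Data]], alpha: float) -> Theta:
    """Exact gradient of meta_loss, including the second-order (Hessian) term."""
    mw = mb = 0.0
    for (xtr, ytr), (xte, yte) in tasks:
        gw, gb = grad(theta, xtr, ytr)
        adapted = (theta[0] - alpha * gw, theta[1] - alpha * gb)  # theta_i'
        ow, ob = grad(adapted, xte, yte)  # gradient of the outer loss at theta_i'
        h = hessian(xtr)
        # Chain rule: multiply by (I - alpha * H)^T, where H is symmetric.
        mw += (1 - alpha * h[0][0]) * ow - alpha * h[0][1] * ob
        mb += -alpha * h[1][0] * ow + (1 - alpha * h[1][1]) * ob
    return mw, mb


def maml_sgd_step(theta: Theta, tasks, alpha: float, beta: float) -> Theta:
    """theta <- theta - beta * grad_theta sum_i L_Ti(f_{theta_i'})."""
    mw, mb = meta_grad(theta, tasks, alpha)
    return theta[0] - beta * mw, theta[1] - beta * mb


# Usage: two toy tasks, each ((train xs, train ys), (test xs, test ys)).
tasks = [
    (([1.0, 2.0], [2.0, 4.0]), ([3.0, -1.0], [6.0, -2.0])),  # y = 2x
    (([0.5, -1.5], [0.5, 2.5]), ([2.0, 1.0], [-1.0, 0.0])),  # y = -x + 1
]
theta = (0.3, -0.2)
mw, mb = meta_grad(theta, tasks, alpha=0.05)
eps = 1e-5  # central finite differences as a sanity check
fd_w = (meta_loss((theta[0] + eps, theta[1]), tasks, 0.05)
        - meta_loss((theta[0] - eps, theta[1]), tasks, 0.05)) / (2 * eps)
fd_b = (meta_loss((theta[0], theta[1] + eps), tasks, 0.05)
        - meta_loss((theta[0], theta[1] - eps), tasks, 0.05)) / (2 * eps)
new_theta = maml_sgd_step(theta, tasks, alpha=0.05, beta=0.001)
```

Dropping the \(\alpha H_i\) terms (using only the outer gradient at \(\theta_i'\)) gives the cheaper first-order approximation; here the exact version is affordable because \(H_i\) is a 2×2 matrix.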
I’d like to emphasize a key point: the above update involves two instances of \(\mathcal{L}_{\mathcal{T}_i}\). The one in the subscript, used to compute \(\theta_i'\), should involve the \(K\) training instances from the task \(\mathcal{T}_i\) (or more specifically, samples drawn from \(q_i\)). The outer-most loss function should be computed on separate testing instances, also from task \(\mathcal{T}_i\). This is important because we want our ultimate evaluation to be done on testing instances.
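A minimal end-to-end sketch of meta-training that keeps these two splits separate is below. It assumes a hypothetical family of linear-regression tasks (slope in \([-2, 2]\), intercept in \([-1, 1]\), inputs in \([-5, 5]\)), a toy model \(f_\theta(x) = w x + b\), and plain SGD for both loops. Within each sampled task, the first \(K\) draws are used only for the inner update and the remaining ones only for the outer loss. The function names, meta-batch size, and step sizes are illustrative assumptions, not values from the paper.

```python
import random
from typing import List, Tuple

Theta = Tuple[float, float]             # (w, b) for the toy model f_theta(x) = w*x + b
Data = Tuple[List[float], List[float]]  # (xs, ys)


def grad(theta: Theta, xs: List[float], ys: List[float]) -> Theta:
    """Gradient of the mean squared error of f_theta on (xs, ys)."""
    w, b = theta
    r = [w * x + b - y for x, y in zip(xs, ys)]
    n = len(xs)
    return sum(2 * ri * x for ri, x in zip(r, xs)) / n, sum(2 * ri for ri in r) / n


def hessian(xs: List[float]) -> List[List[float]]:
    """Hessian of that loss; for a linear model it does not depend on theta."""
    n = len(xs)
    sxx, sx = sum(x * x for x in xs), sum(xs)
    return [[2 * sxx / n, 2 * sx / n], [2 * sx / n, 2.0]]


def sample_episode(rng: random.Random, k: int, q: int) -> Tuple[Data, Data]:
    """Draw one task from a hypothetical p(T); split its data into K training, q testing."""
    slope, intercept = rng.uniform(-2.0, 2.0), rng.uniform(-1.0, 1.0)
    xs = [rng.uniform(-5.0, 5.0) for _ in range(k + q)]
    ys = [slope * x + intercept for x in xs]
    return (xs[:k], ys[:k]), (xs[k:], ys[k:])


def meta_train(theta: Theta, iterations: int = 2000, meta_batch: int = 4,
               k: int = 5, q: int = 10, alpha: float = 0.01, beta: float = 0.001,
               seed: int = 0) -> Theta:
    rng = random.Random(seed)
    for _ in range(iterations):
        mw = mb = 0.0
        for _ in range(meta_batch):
            (xtr, ytr), (xte, yte) = sample_episode(rng, k, q)
            gw, gb = grad(theta, xtr, ytr)  # inner loss: K training instances only
            adapted = (theta[0] - alpha * gw, theta[1] - alpha * gb)  # theta_i'
            ow, ob = grad(adapted, xte, yte)  # outer loss: testing instances only
            h = hessian(xtr)
            # Accumulate (I - alpha * H)^T @ outer gradient, summed over the meta-batch.
            mw += (1 - alpha * h[0][0]) * ow - alpha * h[0][1] * ob
            mb += -alpha * h[1][0] * ow + (1 - alpha * h[1][1]) * ob
        theta = (theta[0] - beta * mw, theta[1] - beta * mb)
    return theta


# Usage: inspect one episode, then meta-train from a deliberately poor initialization.
support, query = sample_episode(random.Random(1), k=5, q=10)
learned = meta_train((3.0, 2.0))
```

Because this toy task distribution is symmetric around zero slope and intercept, the meta-learned initialization should drift toward roughly \((0, 0)\), the point from which one inner step reaches any task in the family equally well on average.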
Another important point is that those per-task “testing instances” are part of meta-training, so we do not use them for evaluating meta-learning algorithms, as that would be cheating. For testing, one takes an entirely held-out set of test tasks, adapts \(\theta\) on each test task’s \(K\) training examples for however many gradient steps are allowed, and then evaluates according to whatever metric is appropriate for the task distribution.
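The evaluation protocol can be sketched the same way. This assumes a toy linear model \(f_\theta(x) = w x + b\), held-out tasks given as hypothetical (slope, intercept) pairs that never appear during meta-training, and mean squared error as the metric; `meta_test` and its defaults are illustrative. Each held-out task gets fresh data: \(\theta\) is adapted only on the \(K\) training examples, and only the post-adaptation error on that task’s testing examples is reported.

```python
import random
from typing import List, Tuple

Theta = Tuple[float, float]  # (w, b) for the toy model f_theta(x) = w*x + b


def loss(theta: Theta, xs: List[float], ys: List[float]) -> float:
    w, b = theta
    return sum((w * x + b - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def adapt(theta: Theta, xs: List[float], ys: List[float], alpha: float, steps: int) -> Theta:
    """Plain gradient descent on one task's K training examples."""
    w, b = theta
    n = len(xs)
    for _ in range(steps):
        r = [w * x + b - y for x, y in zip(xs, ys)]
        gw = sum(2 * ri * x for ri, x in zip(r, xs)) / n
        gb = sum(2 * ri for ri in r) / n
        w, b = w - alpha * gw, b - alpha * gb
    return w, b


def meta_test(theta: Theta, held_out_tasks: List[Tuple[float, float]], k: int = 5,
              q: int = 50, alpha: float = 0.01, steps: int = 1, seed: int = 123) -> float:
    """Average post-adaptation test MSE over held-out tasks."""
    rng = random.Random(seed)
    scores = []
    for slope, intercept in held_out_tasks:
        xs = [rng.uniform(-5.0, 5.0) for _ in range(k + q)]
        ys = [slope * x + intercept for x in xs]
        adapted = adapt(theta, xs[:k], ys[:k], alpha, steps)  # K training examples only
        scores.append(loss(adapted, xs[k:], ys[k:]))          # scored on testing examples
    return sum(scores) / len(scores)


# Usage: score an initialization with one allowed gradient step vs. none.
held_out = [(1.5, -0.5), (-0.7, 0.2), (0.1, 0.9)]  # hypothetical test tasks
theta = (0.0, 0.0)                                   # e.g., a meta-learned initialization
one_step = meta_test(theta, held_out, steps=1)
no_adapt = meta_test(theta, held_out, steps=0)
```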
In a subsequent post, I will further investigate several MAML extensions.