CAVIA: Fast Context Adaptation via Meta-Learning

Adapting to previously unseen tasks is a long-standing problem in machine learning. Ideally, we want to do this fast and with as little data as possible.

Consider the following example: You want to train an object classifier which can detect whether an image contains a meerkat or a cigar. However, you only have four training images per class (this is also called 2-way 4-shot classification):

Training a (deep) neural network from scratch on this dataset would not work at all: the model would overfit to the training data, and would not be able to generalise to an unseen image like the one on the right.

However, we might have access to a large collection of labelled images of different object categories:

We can use these to build 2-way 4-shot mini-datasets like the meerkat-cigar one, and learn how to learn quickly on such types of datasets.

One particular approach to these types of problems is meta-learning. For a fantastic overview of meta-learning settings and different approaches we recommend this blog post by Lilian Weng. In our work, we build on a method which solves this problem by learning a network initialisation as follows.

Background: MAML

Model-Agnostic Meta-Learning (MAML) is a powerful gradient-based approach to the problem of fast adaptation. MAML tries to learn a parameter initialisation \theta such that adapting to new tasks can be done within several gradient updates. This approach is model and task agnostic: it can be used with any gradient-based algorithm, and can be applied to regression, classification, and reinforcement learning tasks. After meta-training, the model is evaluated on a new task: given a small set of labelled data points (in supervised learning) or trajectories (in reinforcement learning), the learned initial parameters are adapted using just a few gradient steps.

As such, MAML adapts the entire model when learning a new task. However, this (a) is often not necessary, since many tasks and existing benchmarks do not require generalisation beyond task identification, and (b) can in fact be detrimental to performance, since it can lead to overfitting.

We propose an extension to MAML which addresses these points, and has the additional benefit of being interpretable and easier to parallelise. We call our algorithm Fast Context Adaptation via Meta-Learning (CAVIA), and show empirically that this results in equal or better performance compared to MAML on a range of tasks.

CAVIA

So, how does CAVIA work? Let's formalise the problem setting first. We describe the supervised learning setting here; it transfers readily to the reinforcement learning setup (see our paper for more details).

We are given a distribution over training tasks p_{train}(\mathcal{T}) and test tasks p_{test}(\mathcal{T}). The goal of the supervised learning algorithm is to learn a model f: x \rightarrow \hat{y} that maps input features x to predictions \hat{y} of the true labels y.

To understand CAVIA, it is easier to start with MAML. In the inner loop, MAML adapts all network parameters \theta to task \mathcal{T}_i using that task's training data \mathcal{D}^{train}_i:

\theta_i = \theta - \alpha \nabla_{\theta}\frac{1}{M^i_{train}}\sum_{(x,y) \in \mathcal{D}^{train}_i}{\mathcal{L}_{\mathcal{T}_i}(f_{\theta}(x), y)},

where M^i_{train} is the size of the training set \mathcal{D}^{train}_i and \alpha is the inner-loop learning rate. In the outer loop, MAML then updates the initialisation \theta by evaluating the adapted parameters \theta_i on each task's test set:

\theta = \theta - \beta \nabla_{\theta}\frac{1}{N}\sum_{\mathcal{T}_i \in \mathbf{T}}{\frac{1}{M^i_{test}}\sum_{(x,y) \in \mathcal{D}^{test}_i}{\mathcal{L}_{\mathcal{T}_i}(f_{\theta_i}(x), y)}},

where \beta is the outer-loop learning rate and N is the number of tasks in the batch \mathbf{T}. As we can see, in both updates MAML adapts \theta, i.e. all parameters of the network.
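To make the two updates concrete, here is a minimal, self-contained sketch of one MAML meta-training step on the sine-regression toy task discussed later in this post. The network size, learning rates, and helper function are illustrative choices, not the exact values from the paper.

```python
# A minimal sketch of one MAML meta-training step on the sine-regression toy
# task. All hyperparameters and helper names here are illustrative.
import torch
import torch.nn as nn

def sample_sine_task(n_points, amplitude, phase):
    """Sample (x, y) pairs from y = amplitude * sin(x + phase), x in [-5, 5]."""
    x = torch.rand(n_points, 1) * 10 - 5
    return x, amplitude * torch.sin(x + phase)

model = nn.Sequential(nn.Linear(1, 40), nn.ReLU(),
                      nn.Linear(40, 40), nn.ReLU(), nn.Linear(40, 1))
meta_optim = torch.optim.Adam(model.parameters(), lr=1e-3)  # outer-loop rate beta
alpha, loss_fn = 0.01, nn.MSELoss()                          # inner-loop rate alpha

meta_optim.zero_grad()
for _ in range(4):                                           # batch of tasks T_i
    amp, phase = torch.rand(1) * 4.9 + 0.1, torch.rand(1) * torch.pi
    x_train, y_train = sample_sine_task(10, amp, phase)
    x_test, y_test = sample_sine_task(10, amp, phase)

    # Inner loop: one gradient step on ALL parameters theta -> theta_i.
    params = list(model.parameters())
    grads = torch.autograd.grad(loss_fn(model(x_train), y_train),
                                params, create_graph=True)
    theta_i = [p - alpha * g for p, g in zip(params, grads)]

    # Outer loop: evaluate theta_i on the task's test set (manual forward pass
    # through the same 3-layer MLP, using the adapted weights).
    h = torch.relu(x_test @ theta_i[0].t() + theta_i[1])
    h = torch.relu(h @ theta_i[2].t() + theta_i[3])
    test_loss = loss_fn(h @ theta_i[4].t() + theta_i[5], y_test)
    test_loss.backward()              # accumulates d(test loss)/d(theta)
meta_optim.step()                     # one outer-loop update of theta
```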

CAVIA does a similar update. However, we split all the network parameters into two disjoint subsets: global parameters \theta and context parameters \phi.

Like MAML, CAVIA consists of an inner and an outer loop update, with the difference that we update only the context parameters \phi in the inner loop, and only the shared network parameters \theta in the outer loop.

In the inner loop, we update the context parameters for task \mathcal{T}_i, starting from a fixed initialisation \phi_0 that is shared across all tasks:

\phi_i = \phi_0 - \alpha \nabla_{\phi}\frac{1}{M^i_{train}}\sum_{(x,y) \in \mathcal{D}^{train}_i}{\mathcal{L}_{\mathcal{T}_i}(f_{\phi_0, \theta}(x), y)}

In the outer loop, we update the shared network parameters \theta, using the adapted context parameters \phi_i:

\theta = \theta - \beta \nabla_{\theta}\frac{1}{N}\sum_{\mathcal{T}_i \in \mathbf{T}}{\frac{1}{M^i_{test}}\sum_{(x,y) \in \mathcal{D}^{test}_i}{\mathcal{L}_{\mathcal{T}_i}(f_{\phi_i, \theta}(x), y)}}

Keeping a separate set of context parameters has two advantages. First, we can vary its size depending on the task at hand, incorporating prior knowledge about the task into the network structure. Second, it makes the method much easier to parallelise than MAML.
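For comparison, below is a sketch of one CAVIA meta-training step on the same toy task. It assumes the context parameters \phi condition the network simply by being concatenated to the input (one possible conditioning choice; see the paper for the details used in each experiment), and the hyperparameters are again illustrative. Note that the inner loop only touches \phi, while the outer loop only touches \theta.

```python
# A minimal sketch of one CAVIA meta-training step on the sine-regression toy
# task. Conditioning by input concatenation and all hyperparameters are
# illustrative choices, not necessarily those used in the paper.
import torch
import torch.nn as nn

def sample_sine_task(n_points, amplitude, phase):
    x = torch.rand(n_points, 1) * 10 - 5
    return x, amplitude * torch.sin(x + phase)

NUM_CONTEXT = 2                                              # size of phi
model = nn.Sequential(nn.Linear(1 + NUM_CONTEXT, 40), nn.ReLU(),
                      nn.Linear(40, 40), nn.ReLU(), nn.Linear(40, 1))
meta_optim = torch.optim.Adam(model.parameters(), lr=1e-3)   # updates theta only
alpha, loss_fn = 1.0, nn.MSELoss()

def forward(x, phi):
    """Condition the network on phi by concatenating it to every input."""
    return model(torch.cat([x, phi.expand(x.shape[0], -1)], dim=1))

meta_optim.zero_grad()
for _ in range(4):                                           # batch of tasks T_i
    amp, phase = torch.rand(1) * 4.9 + 0.1, torch.rand(1) * torch.pi
    x_train, y_train = sample_sine_task(10, amp, phase)
    x_test, y_test = sample_sine_task(10, amp, phase)

    # Inner loop: phi starts at phi_0 = 0 for every task and is the ONLY thing
    # that gets adapted; theta stays fixed.
    phi_0 = torch.zeros(1, NUM_CONTEXT, requires_grad=True)
    grad_phi, = torch.autograd.grad(loss_fn(forward(x_train, phi_0), y_train),
                                    phi_0, create_graph=True)
    phi_i = phi_0 - alpha * grad_phi

    # Outer loop: evaluate f_{phi_i, theta} on the task's test set; gradients
    # flow back into the shared parameters theta only.
    loss_fn(forward(x_test, phi_i), y_test).backward()
meta_optim.step()                                            # update theta
```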

Experimental Results

We evaluated CAVIA on a range of popular meta-learning benchmarks for regression, classification and reinforcement learning tasks. One of the motivations of CAVIA is that many tasks do not require generalisation beyond task identification – and this is also true for many current benchmarks.

To illustrate this, the figure below shows the number of parameters that are updated at test time on the benchmarks we tested, for MAML versus CAVIA (note the log scale on the y-axis):

This figure shows that the amount of adaptation on these benchmarks is relatively small. In the following, we look at those benchmarks in more detail.

Regression

Fitting sine curves

Let us start with a regression task in which we want to fit sine curves, as done in the Model-Agnostic Meta-Learning (MAML) paper. Amplitude and phase fully specify a task; we sample the amplitude from the range [0.1, 5.0] and the phase from [0, \pi].

The figures below show the curve fit before and after the gradient update. While MAML and CAVIA both succeed at the task, note that CAVIA adjusts just 2 context parameters, whereas MAML adjusts approximately 1500 weights, which makes it more prone to overfitting.

Before the update
After the update

For this example, we can easily visualise what the context parameters learn. Below is such a visualisation for a model with only two context parameters:

The x-axis shows the value of context parameter 1 after the update, and the y-axis shows the value of context parameter 2 after the update (each dot is a single task, and its position reflects the adapted context parameters). The colour shows the true task variable (amplitude on the left, phase on the right). As we can see, CAVIA learns an embedding that can be smoothly interpolated. The circular shape is probably due to the phase being periodic.
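A plot like the one above takes only a few lines of matplotlib. The arrays below (phi_after_update, amplitudes, phases) are hypothetical placeholders for the adapted context parameters and the true task variables collected across tasks.

```python
# Sketch of the context-parameter visualisation; the data arrays are
# placeholders standing in for values collected from real adaptation runs.
import matplotlib.pyplot as plt
import numpy as np

n_tasks = 500
phi_after_update = np.random.randn(n_tasks, 2)       # adapted (phi_1, phi_2) per task
amplitudes = np.random.uniform(0.1, 5.0, n_tasks)    # true task amplitude
phases = np.random.uniform(0.0, np.pi, n_tasks)      # true task phase

fig, axes = plt.subplots(1, 2, figsize=(10, 4))
for ax, colour, name in [(axes[0], amplitudes, "amplitude"),
                         (axes[1], phases, "phase")]:
    sc = ax.scatter(phi_after_update[:, 0], phi_after_update[:, 1], c=colour, s=8)
    ax.set_xlabel("context parameter 1")
    ax.set_ylabel("context parameter 2")
    ax.set_title(f"coloured by {name}")
    fig.colorbar(sc, ax=ax)
plt.show()
```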

Image Completion

Next, we tested CAVIA on a more challenging task: CelebA image completion, as proposed by Garnelo et al. (2018). The table below shows that CAVIA achieves the lowest pixel-wise MSE (columns give the number of observed pixels):

                 Random Pixels               Ordered Pixels
           10       100      1000       10       100      1000
CNP        0.039    0.016    0.009      0.057    0.047    0.021
MAML       0.040    0.017    0.006      0.055    0.047    0.007
CAVIA      0.037    0.014    0.006      0.053    0.047    0.006

As the next figure shows, CAVIA is able to restore a picture of a face from only ten pixels. In this experiment, we used 128 context parameters and five gradient steps for adaptation.

Few-Shot Classification

We also tested CAVIA on few-shot classification using the challenging Mini-Imagenet benchmark. This task requires large convolutional networks, which risk overfitting when updated on only a small number of datapoints. The question for us was whether CAVIA can scale to large networks without overfitting. In our experiments, we used 100 context parameters for CAVIA and increased the size of \theta by increasing the number of filters (numbers in brackets in the table). The table below shows that as the network size increases, the performance of MAML goes down, whereas the performance of CAVIA increases.

Method                        5-way accuracy
                              1-shot, %         5-shot, %
Matching Nets                 46.6              60.0
Meta LSTM                     43.44 ± 0.77      60.60 ± 0.71
Prototypical Networks         46.61 ± 0.78      65.77 ± 0.70
Meta-SGD                      50.47 ± 1.87      64.03 ± 0.94
REPTILE                       49.97 ± 0.32      65.99 ± 0.58
MT-NET                        51.70 ± 1.84      -
VERSA                         53.40 ± 1.82      67.37 ± 0.86
MAML (32)                     48.07 ± 1.75      63.15 ± 0.91
MAML (64)                     44.70 ± 1.69      61.87 ± 0.93
CAVIA (32)                    47.24 ± 0.65      59.05 ± 0.54
CAVIA (128)                   49.84 ± 0.68      64.63 ± 0.54
CAVIA (512)                   51.82 ± 0.65      65.85 ± 0.55
CAVIA (512, first order)      49.92 ± 0.68      63.59 ± 0.57
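One detail we glossed over is how a context vector of this size enters a large convolutional network. A generic option is FiLM-style conditioning (Perez et al., 2018), where a small linear layer maps \phi to per-channel scales and shifts. The sketch below only illustrates this general idea; the exact conditioning used in our experiments is described in the paper.

```python
# A generic FiLM-style conditioning block: phi is mapped to per-channel
# (gamma, beta) that modulate the conv features. Illustrative only; see the
# paper for the conditioning actually used in the Mini-Imagenet experiments.
import torch
import torch.nn as nn
import torch.nn.functional as F

class FiLMConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, num_context):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.film = nn.Linear(num_context, 2 * out_channels)  # phi -> (gamma, beta)

    def forward(self, x, phi):
        h = self.bn(self.conv(x))
        gamma, beta = self.film(phi).chunk(2, dim=-1)
        h = gamma.view(1, -1, 1, 1) * h + beta.view(1, -1, 1, 1)
        return F.relu(F.max_pool2d(h, 2))

# Usage: a 100-dimensional context vector, adapted in the inner loop, modulates
# the block's feature maps for every image of the current task.
block = FiLMConvBlock(3, 32, num_context=100)
phi = torch.zeros(100, requires_grad=True)
out = block(torch.randn(4, 3, 84, 84), phi)   # -> shape (4, 32, 42, 42)
```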

Reinforcement Learning

The final experiments use MuJoCo, a high-dimensional reinforcement learning benchmark. In the first task, the reward incentivises the HalfCheetah agent to run in a particular direction; in the second, it encourages the agent to run at a particular velocity. All methods were trained for up to 500 meta-iterations, and CAVIA used 50 context parameters. As the figures below show, CAVIA outperforms MAML after one gradient step on both tasks. On the second task, MAML catches up, achieving similar performance after three gradient updates.

Direction
Velocity

This shows that CAVIA can match MAML's performance while adapting only a context-parameter vector of size 50 at test time.

We also wanted to have a look at the learned policy, which is shown in the video below for the Forward/Backward task. To get a feeling for how good the task embedding is, we trained a binary classifier (logistic regression, fitted after meta-training) that predicts the task (forward/backward) from the context parameters. As the video shows, the classifier predicts roughly a 50/50 chance before any updates, and predicts the correct task (backwards) from the context parameters alone after three updates.
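The task probe itself is just a logistic regression fitted on adapted context parameters. The sketch below uses scikit-learn with placeholder arrays, since the real inputs are the context parameters collected from rollouts after meta-training.

```python
# Hedged sketch of the task probe: logistic regression from adapted context
# parameters to the task label. The arrays are hypothetical placeholders.
import numpy as np
from sklearn.linear_model import LogisticRegression

n_rollouts, num_context = 200, 50
phi_adapted = np.random.randn(n_rollouts, num_context)   # adapted phi per rollout
task_labels = np.random.randint(0, 2, size=n_rollouts)   # 0 = forward, 1 = backward

probe = LogisticRegression().fit(phi_adapted, task_labels)

# During a new rollout, read off P(task | phi) after each context update:
print(probe.predict_proba(phi_adapted[:1]))              # e.g. [[0.48, 0.52]]
```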

Below is another video, this time for the forward task. Notice that in this rollout, even though the task is inferred with high certainty after just one update, the forward policy is not as elegant as the backward one.

Discussion

CAVIA highlights that

  • Many interesting few-shot learning tasks require generalisation in the form of task identification, and we can use this to develop more robust algorithms.
  • Current benchmarks often require only this type of generalisation. We believe that more benchmarks are necessary to push beyond this (such as the Meta-Dataset).

We believe that CAVIA opens up exciting new research directions, including

  • More interpretable algorithms that allow an analysis of what the context parameters learn.
  • The reuse of the context parameters for downstream tasks, auxiliary tasks, or distributed machine learning systems.

Finally, we believe that for tasks that require adaptation beyond task identification, methods that do adapt more than just context parameters are necessary. A combination of CAVIA-like methods for task identification, and MAML-like adaptation on all parameters is a promising future direction.

Summary

In this blog post we covered CAVIA, a meta-learning method for fast adaptation to previously unseen tasks. We didn't go into many of the details and implementation peculiarities (e.g. how to condition the network on the context parameters, or how to initialise the context parameters), nor into connections to related work in meta-learning. If you want to know more, check out our paper, play with the code, or write us an email.
