Making our models work better by getting more data for free

Last week, we talked about image classification: the task where the model looks at an image and tries to guess which of a predefined set of categories it belongs to (i.e. which "class" the image is a member of).

Our model did fairly well, but it won't always. Often, one of the biggest restrictions you face is training data. Unless you're working with images that someone else has already checked, labelled, and prepared for you, it's not easy to go out and collect millions of useful images. And collecting images that are representative of the real-world distribution is even harder.

On a side note, this is why the ImageNet dataset was so revolutionary for deep learning. It provided a large, standardized set of labelled images that researchers could use, which freed people developing algorithms from worrying about data collection.

Today, we're going to be looking at a training technique called data augmentation, which will partially alleviate this issue.

In short, data augmentation means transforming your inputs (and, when necessary, your outputs) in a way you know preserves the label. For example, I can rotate an image of a dog by a few degrees and it's still a dog, but the same transform might not be label-preserving in another domain, such as satellite imagery. Figuring out which augmentations are valid for your problem and applying them can improve your model's performance considerably.
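
To make this concrete, here's a minimal sketch (dog.jpg is a hypothetical local file): the rotated image pairs with the exact same label as the original, so we've manufactured an extra training example for free.

from fastai.vision.all import PILImage
img = PILImage.create('dog.jpg')  # hypothetical file whose label is 'dog'
rotated = img.rotate(5)           # still clearly a dog
# both ('dog.jpg' -> 'dog') and (rotated -> 'dog') are valid training pairs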

So let's see how we can use data augmentation in fastai, and whether it improves the performance of the classifier we built last time. We set up the dataloaders in the usual way. This time, we're using the Imagenette dataset, because we'll be training a few models and we want to iterate fast.

from fastai.vision.all import *
path = untar_data(URLs.IMAGENETTE)  # download and extract Imagenette
data = DataBlock(blocks = (ImageBlock, CategoryBlock),   # image in, category out
                 get_items = get_image_files,            # gather every image file
                 get_y = parent_label,                    # label = parent folder name
                 splitter = GrandparentSplitter(valid_name='val'),  # split by grandparent folder
                 item_tfms = RandomResizedCrop(128, min_scale=0.35),
                 batch_tfms = None
                )
dls = data.dataloaders(path)
dls.show_batch()

Let's recreate the baseline classifier. I'm turning off pretraining for now so that we can see the improvements more clearly, but as always, you should use a pretrained model in most cases.

learn = cnn_learner(dls, resnet18, pretrained=False, metrics=accuracy)
learn.fit_one_cycle(5)

Ok, so that's that. Time to start looking at some augmentations.

Resizing

One of the first things you should do is resize your images to a standardized size. Because we're batching up images for faster training, every image in a batch needs to have the same dimensions.

data = data.new(item_tfms = Resize(128, ResizeMethod.Squish), batch_tfms = None)
dls = data.dataloaders(path)
dls.valid.show_batch(max_n=5, nrows=1)

But squishing, which is how people usually resize images, doesn't make much sense in deep learning: it distorts the aspect ratio, and the relative height and width of objects may carry information. So we generally prefer padding.

data = data.new(item_tfms=Resize(128, ResizeMethod.Pad, pad_mode='zeros'), batch_tfms = None)
dls = data.dataloaders(path)
dls.valid.show_batch(max_n=5, nrows=1)

There's more than one way to pad though...

data = data.new(item_tfms=Resize(128, ResizeMethod.Pad, pad_mode='reflection'), batch_tfms = None)
dls = data.dataloaders(path)
dls.valid.show_batch(max_n=5, nrows=1)

Padding is generally ok, but those empty (or reflected) borders don't add any information, so we can do even better with cropping.

data = data.new(item_tfms = RandomResizedCrop(224, min_scale=0.3), batch_tfms = None)
dls = data.dataloaders(path)
dls.train.show_batch(max_n=5, nrows=1, unique=True)

In fact, there are multiple cropping methods as well! What tends to be used in research is center cropping, so that everyone evaluates ImageNet with the same standardized set of transforms, but random cropping usually works better in practice.
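
If you want that standardized setup, here's a sketch: fastai's plain Resize with ResizeMethod.Crop picks a random crop position on the training set and (if I remember the defaults right) always a center crop on the validation set.

data = data.new(item_tfms = Resize(128, ResizeMethod.Crop), batch_tfms = None)
dls = data.dataloaders(path)
dls.valid.show_batch(max_n=5, nrows=1)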

Batched GPU Transforms

If you thought we went overboard with resizing, wait till you see the real transforms. There are a lot of possible data augmentation techniques, and it would be impossible to cover them all. Luckily, fastai provides a really nice way to apply a whole suite of transforms at once, and batch_tfms run on the GPU, a batch at a time, so they're fast!

data = data.new(item_tfms = CropPad(500), batch_tfms = aug_transforms(mult=1))
dls = data.dataloaders(path)
dls.train.show_batch(max_n=5, nrows=1, unique=True)
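
aug_transforms is a convenience wrapper around a set of individual transforms (flips, rotation, zoom, warping, lighting changes), and you can tune each one through its keyword arguments. A quick sketch with a few of them; the exact values here are arbitrary:

tfms = aug_transforms(max_rotate=15., max_zoom=1.2, max_lighting=0.3, flip_vert=False)
data = data.new(item_tfms = CropPad(500), batch_tfms = tfms)
dls = data.dataloaders(path)
dls.train.show_batch(max_n=5, nrows=1, unique=True)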

(Slightly More) Advanced Transforms

There's also been a lot of interest in newer data augmentation techniques recently, since they provide a regularizing effect (i.e. they help prevent overfitting, or "memorizing" the training set).

Arguably, the one that started this trend is MixUp, which, again, is already implemented in fastai. However, it's implemented a little differently: MixUp (like its relative CutMix, which we'll also try) combines multiple images, so it doesn't really fit the fastai definition of a transform. Instead, we use a callback, which is just a function that's triggered on an event (like button clicks in JavaScript).
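
Concretely, MixUp draws a blending weight from a Beta distribution and takes a weighted average of two images and of their one-hot labels. A minimal sketch of the idea, using hypothetical batch tensors x1, x2, y1, y2:

import torch
t = torch.distributions.Beta(0.5, 0.5).sample()  # alpha=0.5, matching MixUp(0.5) below
x = t * x1 + (1 - t) * x2  # blended image
y = t * y1 + (1 - t) * y2  # blended (soft) label

In practice, we just pass the callback to the learner: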

from fastai.callback.mixup import *  # provides MixUp (and, in recent fastai, CutMix)
mixup = MixUp(0.5)
learn = cnn_learner(dls, resnet18, pretrained=False, metrics=accuracy, cbs=mixup)
learn.fit_one_cycle(5)

cutmix = CutMix(alpha=1.)
learn = cnn_learner(dls, resnet18, pretrained=False, metrics=accuracy, cbs=cutmix)
learn.fit_one_cycle(5)