Coloring every pixel in an image

We're going to stop working on image classification and instead try out a new task that still uses images: semantic segmentation. As always, the best way to describe things in deep learning is by specifying the exact inputs and outputs. So...

Source: https://www.researchgate.net/publication/326875064/figure/fig3/AS:659518916681730@1534252971987/Example-of-2D-semantic-segmentation-Top-input-image-Bottom-prediction.png

Essentially, for a given input image, we want to "color" or assign a class label to every pixel in it.
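
To make "a class label for every pixel" concrete, here is a minimal sketch in plain Python. The class names and scores are made up for illustration and the image is only 2 x 3 pixels. A real model produces the same kind of per-pixel scores, just as a large tensor.

```python
# Toy "model output": for each pixel, one score per class.
# Layout is height x width x n_classes (2 x 3 x 3 here); all values are made up.
CLASSES = ["Road", "Sky", "Tree"]  # hypothetical 3-class label set

scores = [
    [[0.1, 2.0, 0.3], [0.2, 1.5, 0.1], [0.0, 0.4, 1.9]],
    [[3.0, 0.1, 0.2], [2.2, 0.3, 0.5], [0.9, 0.2, 1.1]],
]

def argmax(values):
    """Index of the largest value in a list."""
    return max(range(len(values)), key=lambda i: values[i])

# Segmentation mask: one class index per pixel, same height and width as the input.
mask = [[argmax(pixel) for pixel in row] for row in scores]
named = [[CLASSES[i] for i in row] for row in mask]

print(mask)   # [[1, 1, 2], [0, 0, 2]]
print(named)  # [['Sky', 'Sky', 'Tree'], ['Road', 'Road', 'Tree']]
```

The output has exactly one integer per input pixel, which is the form the label masks take later on.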

...

from fastai.vision.all import *

Prepare Dataset

For segmentation, we're going to use a dataset called CAMVID, a collection of street scenes of the kind a self-driving car might expect to see. We won't reach level 5 autonomy just yet, but it's a step.

CAMVID is available directly in fastai, so we're going to load it in the usual way.

path = untar_data(URLs.CAMVID)

As an aside, the reason we're using all these common, standard, academic datasets is not that they're representative of the real world. It's that these datasets, like CIFAR10, ImageNet, and CAMVID, are the ones researchers have agreed to use, so we have a common standard for comparing and benchmarking models.

The CAMVID dataset is also formatted differently from the datasets we've used so far, so we'll have to do a little more work to get it into the format we want (a fastai DataLoaders object).

path.ls()
(#4) [Path('/home/iyaja/.fastai/data/camvid/valid.txt'),Path('/home/iyaja/.fastai/data/camvid/images'),Path('/home/iyaja/.fastai/data/camvid/labels'),Path('/home/iyaja/.fastai/data/camvid/codes.txt')]
valid_fnames = open(path/'valid.txt').read().split('\n')
type(valid_fnames)
list
valid_fnames[:10]
['0016E5_07959.png',
 '0016E5_07961.png',
 '0016E5_07963.png',
 '0016E5_07965.png',
 '0016E5_07967.png',
 '0016E5_07969.png',
 '0016E5_07971.png',
 '0016E5_07973.png',
 '0016E5_07975.png',
 '0016E5_07977.png']

It looks like CAMVID provides the validation images as an explicit list of file names, rather than separating the training and validation sets into separate folders.
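
A list-based split like this can be sketched in a few lines of plain Python. The helper `split_by_names` and the two `train_*.png` names are hypothetical. Only the two validation names come from the output above. The empty string at the end of `valid_names` mimics what `split('\n')` may leave behind if the file ends with a newline.

```python
from pathlib import Path

def split_by_names(items, valid_names):
    """Return (train_idxs, valid_idxs); items whose file name is listed go to validation."""
    valid_names = set(valid_names)
    train_idxs = [i for i, p in enumerate(items) if p.name not in valid_names]
    valid_idxs = [i for i, p in enumerate(items) if p.name in valid_names]
    return train_idxs, valid_idxs

# Hypothetical item list mixing two made-up training files with two real validation names.
items = [Path("images") / name for name in
         ["0016E5_07959.png", "train_a.png", "0016E5_07961.png", "train_b.png"]]
valid_names = ["0016E5_07959.png", "0016E5_07961.png", ""]

train_idxs, valid_idxs = split_by_names(items, valid_names)
print(train_idxs, valid_idxs)  # [1, 3] [0, 2]
```

Matching on the file name rather than the full path keeps the split independent of where the dataset was downloaded.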

path_images = path/'images'
path_labels = path/'labels'
images = get_image_files(path_images)
labels = get_image_files(path_labels)
img = PILImage.create(images[0])
img.show(figsize=(5,5))
<matplotlib.axes._subplots.AxesSubplot at 0x7f72bcc64b80>
lbl = PILMask.create(labels[0])
lbl.show(figsize=(5,5), alpha=1)
<matplotlib.axes._subplots.AxesSubplot at 0x7f72bcb8ca30>

One problem is that the lists of images and labels are not in the same order, so they don't correspond to one another. We need a function that, given an image, returns the path to its matching mask.

def get_mask(filename):
    return path/'labels'/f'{filename.stem}_P{filename.suffix}'
img = PILImage.create(images[0])
mask = PILMask.create(get_mask(images[0]))
img.show(figsize=(5,5))
mask.show(figsize=(5,5), alpha=1)
<matplotlib.axes._subplots.AxesSubplot at 0x7f72bcb0de80>
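
The naming rule that get_mask relies on can be checked without touching the dataset. This small sketch uses only pathlib. The helper name `mask_name_for` is made up, and the file name is one of those listed in valid.txt above.

```python
from pathlib import Path

def mask_name_for(image_path):
    """Build the label file name by inserting '_P' between the stem and the suffix."""
    return f"{image_path.stem}_P{image_path.suffix}"

image_path = Path("images/0016E5_07959.png")
print(image_path.stem, image_path.suffix)  # 0016E5_07959 .png
print(mask_name_for(image_path))           # 0016E5_07959_P.png
```

Because the mask path is derived from the image path, the order in which the two folders are listed no longer matters.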

Perfect! Now let's examine the mask a little more closely.

tensor(mask)
tensor([[ 4,  4,  4,  ...,  4,  4,  4],
        [ 4,  4,  4,  ...,  4,  4,  4],
        [ 4,  4,  4,  ...,  4,  4,  4],
        ...,
        [17, 17, 17,  ..., 17, 17, 17],
        [17, 17, 17,  ..., 17, 17, 17],
        [17, 17, 17,  ..., 17, 17, 17]], dtype=torch.uint8)
codes = np.loadtxt(path/'codes.txt', dtype=str); codes
array(['Animal', 'Archway', 'Bicyclist', 'Bridge', 'Building', 'Car',
       'CartLuggagePram', 'Child', 'Column_Pole', 'Fence', 'LaneMkgsDriv',
       'LaneMkgsNonDriv', 'Misc_Text', 'MotorcycleScooter', 'OtherMoving',
       'ParkingBlock', 'Pedestrian', 'Road', 'RoadShoulder', 'Sidewalk',
       'SignSymbol', 'Sky', 'SUVPickupTruck', 'TrafficCone',
       'TrafficLight', 'Train', 'Tree', 'Truck_Bus', 'Tunnel',
       'VegetationMisc', 'Void', 'Wall'], dtype='<U17')
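
Each integer in the mask is an index into this array. As a quick illustration, the sketch below copies the class names from the output above into a plain list and decodes a tiny made-up mask containing the two values visible in the tensor (4 and 17).

```python
from collections import Counter

# Class names copied from codes.txt as printed above; position = class index.
codes = ['Animal', 'Archway', 'Bicyclist', 'Bridge', 'Building', 'Car',
         'CartLuggagePram', 'Child', 'Column_Pole', 'Fence', 'LaneMkgsDriv',
         'LaneMkgsNonDriv', 'Misc_Text', 'MotorcycleScooter', 'OtherMoving',
         'ParkingBlock', 'Pedestrian', 'Road', 'RoadShoulder', 'Sidewalk',
         'SignSymbol', 'Sky', 'SUVPickupTruck', 'TrafficCone',
         'TrafficLight', 'Train', 'Tree', 'Truck_Bus', 'Tunnel',
         'VegetationMisc', 'Void', 'Wall']

# A made-up 3 x 4 mask, not taken from the dataset.
toy_mask = [
    [4, 4, 4, 4],
    [4, 4, 17, 17],
    [17, 17, 17, 17],
]

# Count how many pixels belong to each class and print the class names.
counts = Counter(value for row in toy_mask for value in row)
for class_id, n_pixels in sorted(counts.items()):
    print(class_id, codes[class_id], n_pixels)
# 4 Building 6
# 17 Road 6
```

So the 4s at the top of the real mask are Building pixels and the 17s at the bottom are Road pixels.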

Progressive Resizing

full_size = mask.shape
full_size
(720, 960)
half_size = (360, 480)
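
Progressive resizing usually means training on smaller images first and moving to larger ones later. As a rough sketch of how the two sizes relate, the hypothetical helper `scaled_size` below derives the half size from the full size instead of hard-coding it.

```python
def scaled_size(size, factor):
    """Scale a (height, width) pair by `factor`, rounding down to whole pixels."""
    return tuple(int(side * factor) for side in size)

full_size = (720, 960)
half_size = scaled_size(full_size, 0.5)

# Halving both sides leaves a quarter of the pixels per image.
pixel_ratio = (half_size[0] * half_size[1]) / (full_size[0] * full_size[1])

print(half_size)    # (360, 480)
print(pixel_ratio)  # 0.25
```

With a quarter of the pixels per image, each batch should need noticeably less memory and compute, which is why the smaller size is a convenient starting point.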

We now have everything we need to assemble our DataLoaders.

Remember, fastai classes and functions are generic and flexible, so we can combine the techniques we used previously (transfer learning and data augmentation) in the same way we did before!

data = DataBlock(blocks = (ImageBlock, MaskBlock(codes)),
                 get_items = get_image_files,
                 splitter = FileSplitter(path/'valid.txt'),
                 get_y = get_mask,
                 batch_tfms = [*aug_transforms(size=half_size), Normalize.from_stats(*imagenet_stats)]
                )
dls = data.dataloaders(path/'images', bs=8)
dls.show_batch(figsize=(20, 8))
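
The Normalize.from_stats(*imagenet_stats) step is what ties the data to the pretrained backbone: inputs are standardized with the same per-channel statistics used during ImageNet pretraining. A minimal sketch of that arithmetic for a single pixel, assuming channel values already scaled to [0, 1] (the helper `normalize_pixel` is made up for illustration):

```python
# Per-channel ImageNet statistics (RGB), the values fastai exposes as imagenet_stats.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

def normalize_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Standardize one RGB pixel: subtract the channel mean, divide by the channel std."""
    return [(c - m) / s for c, m, s in zip(rgb, mean, std)]

average_pixel = normalize_pixel([0.485, 0.456, 0.406])  # maps to zero in every channel
white_pixel = normalize_pixel([1.0, 1.0, 1.0])

print(average_pixel)
print([round(v, 3) for v in white_pixel])
```

Only the image is normalized this way. The mask holds class indices, so rescaling it would destroy the labels.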

There are a few more things we need to take care of to ensure that we get reasonable and interpretable outputs.

dls.vocab = codes
name2id = {v:k for k,v in enumerate(codes)}
void_code = name2id['Void']

Specifically, we want to make sure that we don't count the "unknown" or "void" class when we compute accuracy (this is how the CAMVID accuracy is defined, and you'll see that it makes a difference).

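def camvid_accuracy(pred, real):
    # Drop the channel dimension from the target, if there is one
    real = real.squeeze(1)
    # Only score pixels whose true label is not 'Void'
    mask = real != void_code
    return (pred.argmax(dim=1)[mask] == real[mask]).float().mean()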
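
To see why ignoring the void class matters, here is a small worked example in plain Python. The labels are made up and flattened into a list. `pixel_accuracy` is a hypothetical helper, not part of fastai, and 30 is the position of 'Void' in the codes printed earlier.

```python
VOID = 30  # index of 'Void' in codes.txt

def pixel_accuracy(pred, target, ignore=None):
    """Fraction of pixels where pred == target, optionally skipping one target label."""
    pairs = [(p, t) for p, t in zip(pred, target) if t != ignore]
    return sum(p == t for p, t in pairs) / len(pairs)

# Flattened toy labels: 17 = Road, 4 = Building, 30 = Void.
target = [17, 17, 4, 4, 30, 30, 30, 17]
pred   = [17, 17, 4, 17, 4, 17, 4, 17]

plain = pixel_accuracy(pred, target)                 # counts every pixel
masked = pixel_accuracy(pred, target, ignore=VOID)   # skips pixels labelled Void

print(plain)   # 0.5
print(masked)  # 0.8
```

The model is never expected to predict Void, so counting those three pixels drags the score from 0.8 down to 0.5 without saying anything about how well the real classes are segmented.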

UNet Learner & The Small Things That Make A Difference

config = unet_config(self_attention=True, act_cls=Mish)
learn = unet_learner(dls, resnet18, metrics=[camvid_accuracy], config=config, opt_func=ranger)

There's another handy function that tells us what's going on with our model. It can be useful when you're analyzing a particular architecture.

learn.summary()
DynamicUnet (Input shape: ['8 x 3 x 360 x 480'])
================================================================
Layer (type)         Output Shape         Param #    Trainable 
================================================================
Conv2d               8 x 64 x 180 x 240   9,408      False     
________________________________________________________________
BatchNorm2d          8 x 64 x 180 x 240   128        True      
________________________________________________________________
ReLU                 8 x 64 x 180 x 240   0          False     
________________________________________________________________
MaxPool2d            8 x 64 x 90 x 120    0          False     
________________________________________________________________
Conv2d               8 x 64 x 90 x 120    36,864     False     
________________________________________________________________
BatchNorm2d          8 x 64 x 90 x 120    128        True      
________________________________________________________________
ReLU                 8 x 64 x 90 x 120    0          False     
________________________________________________________________
Conv2d               8 x 64 x 90 x 120    36,864     False     
________________________________________________________________
BatchNorm2d          8 x 64 x 90 x 120    128        True      
________________________________________________________________
Conv2d               8 x 64 x 90 x 120    36,864     False     
________________________________________________________________
BatchNorm2d          8 x 64 x 90 x 120    128        True      
________________________________________________________________
ReLU                 8 x 64 x 90 x 120    0          False     
________________________________________________________________
Conv2d               8 x 64 x 90 x 120    36,864     False     
________________________________________________________________
BatchNorm2d          8 x 64 x 90 x 120    128        True      
________________________________________________________________
Conv2d               8 x 128 x 45 x 60    73,728     False     
________________________________________________________________
BatchNorm2d          8 x 128 x 45 x 60    256        True      
________________________________________________________________
ReLU                 8 x 128 x 45 x 60    0          False     
________________________________________________________________
Conv2d               8 x 128 x 45 x 60    147,456    False     
________________________________________________________________
BatchNorm2d          8 x 128 x 45 x 60    256        True      
________________________________________________________________
Conv2d               8 x 128 x 45 x 60    8,192      False     
________________________________________________________________
BatchNorm2d          8 x 128 x 45 x 60    256        True      
________________________________________________________________
Conv2d               8 x 128 x 45 x 60    147,456    False     
________________________________________________________________
BatchNorm2d          8 x 128 x 45 x 60    256        True      
________________________________________________________________
ReLU                 8 x 128 x 45 x 60    0          False     
________________________________________________________________
Conv2d               8 x 128 x 45 x 60    147,456    False     
________________________________________________________________
BatchNorm2d          8 x 128 x 45 x 60    256        True      
________________________________________________________________
Conv2d               8 x 256 x 23 x 30    294,912    False     
________________________________________________________________
BatchNorm2d          8 x 256 x 23 x 30    512        True      
________________________________________________________________
ReLU                 8 x 256 x 23 x 30    0          False     
________________________________________________________________
Conv2d               8 x 256 x 23 x 30    589,824    False     
________________________________________________________________
BatchNorm2d          8 x 256 x 23 x 30    512        True      
________________________________________________________________
Conv2d               8 x 256 x 23 x 30    32,768     False     
________________________________________________________________
BatchNorm2d          8 x 256 x 23 x 30    512        True      
________________________________________________________________
Conv2d               8 x 256 x 23 x 30    589,824    False     
________________________________________________________________
BatchNorm2d          8 x 256 x 23 x 30    512        True      
________________________________________________________________
ReLU                 8 x 256 x 23 x 30    0          False     
________________________________________________________________
Conv2d               8 x 256 x 23 x 30    589,824    False     
________________________________________________________________
BatchNorm2d          8 x 256 x 23 x 30    512        True      
________________________________________________________________
Conv2d               8 x 512 x 12 x 15    1,179,648  False     
________________________________________________________________
BatchNorm2d          8 x 512 x 12 x 15    1,024      True      
________________________________________________________________
ReLU                 8 x 512 x 12 x 15    0          False     
________________________________________________________________
Conv2d               8 x 512 x 12 x 15    2,359,296  False     
________________________________________________________________
BatchNorm2d          8 x 512 x 12 x 15    1,024      True      
________________________________________________________________
Conv2d               8 x 512 x 12 x 15    131,072    False     
________________________________________________________________
BatchNorm2d          8 x 512 x 12 x 15    1,024      True      
________________________________________________________________
Conv2d               8 x 512 x 12 x 15    2,359,296  False     
________________________________________________________________
BatchNorm2d          8 x 512 x 12 x 15    1,024      True      
________________________________________________________________
ReLU                 8 x 512 x 12 x 15    0          False     
________________________________________________________________
Conv2d               8 x 512 x 12 x 15    2,359,296  False     
________________________________________________________________
BatchNorm2d          8 x 512 x 12 x 15    1,024      True      
________________________________________________________________
BatchNorm2d          8 x 512 x 12 x 15    1,024      True      
________________________________________________________________
ReLU                 8 x 512 x 12 x 15    0          False     
________________________________________________________________
Conv2d               8 x 1024 x 12 x 15   4,719,616  True      
________________________________________________________________
Mish                 8 x 1024 x 12 x 15   0          False     
________________________________________________________________
Conv2d               8 x 512 x 12 x 15    4,719,104  True      
________________________________________________________________
Mish                 8 x 512 x 12 x 15    0          False     
________________________________________________________________
Conv2d               8 x 1024 x 12 x 15   525,312    True      
________________________________________________________________
Mish                 8 x 1024 x 12 x 15   0          False     
________________________________________________________________
PixelShuffle         8 x 256 x 24 x 30    0          False     
________________________________________________________________
BatchNorm2d          8 x 256 x 23 x 30    512        True      
________________________________________________________________
Conv2d               8 x 512 x 23 x 30    2,359,808  True      
________________________________________________________________
Mish                 8 x 512 x 23 x 30    0          False     
________________________________________________________________
Conv2d               8 x 512 x 23 x 30    2,359,808  True      
________________________________________________________________
Mish                 8 x 512 x 23 x 30    0          False     
________________________________________________________________
Mish                 8 x 512 x 23 x 30    0          False     
________________________________________________________________
Conv2d               8 x 1024 x 23 x 30   525,312    True      
________________________________________________________________
Mish                 8 x 1024 x 23 x 30   0          False     
________________________________________________________________
PixelShuffle         8 x 256 x 46 x 60    0          False     
________________________________________________________________
BatchNorm2d          8 x 128 x 45 x 60    256        True      
________________________________________________________________
Conv2d               8 x 384 x 45 x 60    1,327,488  True      
________________________________________________________________
Mish                 8 x 384 x 45 x 60    0          False     
________________________________________________________________
Conv2d               8 x 384 x 45 x 60    1,327,488  True      
________________________________________________________________
Mish                 8 x 384 x 45 x 60    0          False     
________________________________________________________________
Conv1d               8 x 48 x 2700        18,432     True      
________________________________________________________________
Conv1d               8 x 48 x 2700        18,432     True      
________________________________________________________________
Conv1d               8 x 384 x 2700       147,456    True      
________________________________________________________________
Mish                 8 x 384 x 45 x 60    0          False     
________________________________________________________________
Conv2d               8 x 768 x 45 x 60    295,680    True      
________________________________________________________________
Mish                 8 x 768 x 45 x 60    0          False     
________________________________________________________________
PixelShuffle         8 x 192 x 90 x 120   0          False     
________________________________________________________________
BatchNorm2d          8 x 64 x 90 x 120    128        True      
________________________________________________________________
Conv2d               8 x 256 x 90 x 120   590,080    True      
________________________________________________________________
Mish                 8 x 256 x 90 x 120   0          False     
________________________________________________________________
Conv2d               8 x 256 x 90 x 120   590,080    True      
________________________________________________________________
Mish                 8 x 256 x 90 x 120   0          False     
________________________________________________________________
Mish                 8 x 256 x 90 x 120   0          False     
________________________________________________________________
Conv2d               8 x 512 x 90 x 120   131,584    True      
________________________________________________________________
Mish                 8 x 512 x 90 x 120   0          False     
________________________________________________________________
PixelShuffle         8 x 128 x 180 x 240  0          False     
________________________________________________________________
BatchNorm2d          8 x 64 x 180 x 240   128        True      
________________________________________________________________
Conv2d               8 x 96 x 180 x 240   165,984    True      
________________________________________________________________
Mish                 8 x 96 x 180 x 240   0          False     
________________________________________________________________
Conv2d               8 x 96 x 180 x 240   83,040     True      
________________________________________________________________
Mish                 8 x 96 x 180 x 240   0          False     
________________________________________________________________
Mish                 8 x 192 x 180 x 240  0          False     
________________________________________________________________
Conv2d               8 x 384 x 180 x 240  37,248     True      
________________________________________________________________
Mish                 8 x 384 x 180 x 240  0          False     
________________________________________________________________
PixelShuffle         8 x 96 x 360 x 480   0          False     
________________________________________________________________
ResizeToOrig         8 x 96 x 360 x 480   0          False     
________________________________________________________________
MergeLayer           8 x 99 x 360 x 480   0          False     
________________________________________________________________
Conv2d               8 x 99 x 360 x 480   88,308     True      
________________________________________________________________
Mish                 8 x 99 x 360 x 480   0          False     
________________________________________________________________
Conv2d               8 x 99 x 360 x 480   88,308     True      
________________________________________________________________
Sequential           8 x 99 x 360 x 480   0          False     
________________________________________________________________
Mish                 8 x 99 x 360 x 480   0          False     
________________________________________________________________
Conv2d               8 x 32 x 360 x 480   3,200      True      
________________________________________________________________

Total params: 31,300,328
Total trainable params: 20,133,416
Total non-trainable params: 11,166,912

Optimizer used: <function ranger at 0x7f72b19b7af0>
Loss function: FlattenedLoss of CrossEntropyLoss()

Model frozen up to parameter group #2

Callbacks:
  - TrainEvalCallback
  - Recorder
  - ProgressCallback
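
The shapes in the summary can be reproduced by hand, which is a useful sanity check when reading an architecture. The sketch below uses the standard output-size formula for convolutions and pooling. The kernel, stride, and padding values are the usual ResNet ones and are an assumption here, since the summary does not print them. The helper names are made up.

```python
def conv_out(size, kernel, stride, padding):
    """Output length of a convolution or pooling layer along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1

def pixel_shuffle_shape(channels, height, width, scale=2):
    """PixelShuffle trades channels for resolution: C / scale^2 channels, scale x larger H and W."""
    return channels // scale ** 2, height * scale, width * scale

h, w = 360, 480
# Assumed stem: 7x7 conv (stride 2, padding 3), then 3x3 max-pool (stride 2, padding 1).
h, w = conv_out(h, 7, 2, 3), conv_out(w, 7, 2, 3)
stem = (h, w)
h, w = conv_out(h, 3, 2, 1), conv_out(w, 3, 2, 1)
sizes = [(h, w)]
# Assumed: each later encoder stage starts with a 3x3 stride-2 conv (padding 1).
for _ in range(3):
    h, w = conv_out(h, 3, 2, 1), conv_out(w, 3, 2, 1)
    sizes.append((h, w))

print(stem)   # (180, 240)
print(sizes)  # [(90, 120), (45, 60), (23, 30), (12, 15)]

# Decoder: the first and last PixelShuffle rows of the summary.
print(pixel_shuffle_shape(1024, 12, 15))   # (256, 24, 30)
print(pixel_shuffle_shape(384, 180, 240))  # (96, 360, 480)

# The parameter totals reported at the bottom add up.
print(20_133_416 + 11_166_912 == 31_300_328)  # True
```

Note the 24 x 30 after the first PixelShuffle against the 23 x 30 of the layers that follow it. Doubling 12 cannot give back the odd 23, so the upsampled map is presumably resized to match the skip connection before the two are combined.
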
learn.lr_find()
lr = 1e-3
# fine_tune expects a plain float for the base learning rate and builds the slices itself
learn.fine_tune(15, lr, freeze_epochs=10)
learn.save('resnet18_half')  # named after the resnet18 backbone used above; save() adds the .pth extension
learn.show_results(figsize=(20,8))