Mtl
Multi-Task Learning project
Unofficial implementation of:<br> Kendall, Alex, Yarin Gal, and Roberto Cipolla. "Multi-task learning using uncertainty to weigh losses for scene geometry and semantics." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2018. [arXiv].
Abstract
Numerous deep learning applications benefit from multi-task learning with multiple regression and classification objectives. In this paper we make the observation that the performance of such systems is strongly dependent on the relative weighting between each task’s loss. Tuning these weights by hand is a difficult and expensive process, making multi-task learning prohibitive in practice. We propose a principled approach to multi-task deep learning which weighs multiple loss functions by considering the homoscedastic uncertainty of each task. This allows us to simultaneously learn various quantities with different units or scales in both classification and regression settings. We demonstrate our model learning per-pixel depth regression, semantic and instance segmentation from a monocular input image. Perhaps surprisingly, we show our model can learn multi-task weightings and outperform separate models trained individually on each task.
Multi Task Learning with Homoscedastic Uncertainty
The naive approach to combining multi-objective losses is to simply take a weighted linear sum of the losses of the individual tasks:<br> <img src='images/naive_loss.PNG'><br>
The paper suggests that homoscedastic uncertainty can serve as a basis for weighting the losses in a multi-task learning problem, and that it produces superior results to the naive approach.
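As a minimal sketch (the function name `naive_multi_task_loss` is hypothetical, not from this repo), the naive approach looks like this: the total loss is a fixed linear combination with hand-tuned weights.

```python
# Hypothetical sketch of the naive baseline: total loss is a fixed
# weighted linear sum of the per-task losses, with weights chosen by hand.
def naive_multi_task_loss(task_losses, weights):
    """task_losses, weights: equal-length sequences of floats."""
    assert len(task_losses) == len(weights)
    return sum(w * l for w, l in zip(task_losses, weights))

# Example: three tasks with hand-tuned weights.
total = naive_multi_task_loss([0.7, 1.2, 0.4], [1.0, 0.5, 2.0])  # -> 2.1
```

Tuning those weights is exactly the expensive manual search the paper replaces with learned uncertainty.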
Mathematical Formulation
First the paper defines multi-task likelihoods:<br>
- For regression tasks, likelihood is defined as a Gaussian with mean given by the model output with an observation noise scalar σ:<br> <img src='images/reg_likelihood.PNG'><br>
- For classification, likelihood is defined as:<br> <img src='images/class_likelihood_1.PNG'><br> where:<br> <img src='images/class_likelihood_0.PNG'><br>
In maximum likelihood inference, we maximise the log likelihood of the model. In regression for example:<br> <img src='images/reg_loglikelihood.PNG'><br> σ is the model’s observation noise parameter - capturing how much noise we have in the outputs. We then maximise the log likelihood with respect to the model parameters W and observation noise parameter σ.<br>
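For readers viewing this without the images, the regression log-likelihood shown above can be written in LaTeX (using the paper's notation) as:

```latex
\log p\left(y \mid f^{W}(x)\right)
  \propto -\frac{1}{2\sigma^{2}} \left\| y - f^{W}(x) \right\|^{2} - \log \sigma
```

Maximising this in both $W$ and $\sigma$ trades data fit against the noise level the model attributes to the task.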
Assuming two tasks that follow Gaussian distributions:<br> <img src='images/two_task.PNG'><br> The loss will be:<br> <img src='images/total_loss_h.PNG'><br> <img src='images/loss7.PNG'><br> This means that both W and σ are learned parameters of the network: W are the weights of the network, while the σ values determine the weight of each task's loss and also regularize that weight.
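The resulting two-task objective from the paper, written out in LaTeX:

```latex
\mathcal{L}(W, \sigma_1, \sigma_2)
  = \frac{1}{2\sigma_1^{2}} \mathcal{L}_1(W)
  + \frac{1}{2\sigma_2^{2}} \mathcal{L}_2(W)
  + \log \sigma_1 + \log \sigma_2,
\qquad
\mathcal{L}_i(W) = \left\| y_i - f^{W}(x) \right\|^{2}
```

Each task's loss is down-weighted by its learned noise $\sigma_i$, while the $\log \sigma_i$ terms keep the noise from growing without bound.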
Architecture
Overview
The network consists of an encoder that produces a shared representation, followed by three task-specific decoders:
- Semantic segmentation decoder.
- Instance segmentation decoder.
- Depth estimation decoder.
Encoder
The encoder consists of a fine-tuned, pre-trained ResNet-101 v1 with the following changes:
- Dropped the final fully connected layer.
- Resized the last layer to 128×256.
- Used a dilated convolutional approach (atrous convolution).
Atrous convolution
Given an image, we assume that we first have a downsampling operation that reduces the resolution by a factor of 2, and then perform a convolution with a kernel (in the example beneath: the vertical Gaussian derivative). If one implants the resulting feature map in the original image coordinates, we realize that we have obtained responses at only 1/4 of the image positions. Instead, we can compute responses at all image positions if we convolve the full resolution image with a filter ‘with holes’, in which we upsample the original filter by a factor of 2, and introduce zeros in between filter values. Although the effective filter size increases, we only need to take into account the non-zero filter values, hence both the number of filter parameters and the number of operations per position stay constant. The resulting scheme allows us to easily and explicitly control the spatial resolution of neural network feature responses.
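The 'filter with holes' idea above can be demonstrated with a tiny illustrative 1-D sketch (not the repo's code, which operates on 2-D feature maps): upsampling a kernel by rate r is equivalent to sampling the input with stride r under each tap, so the receptive field grows while the number of non-zero parameters and the cost per position stay constant.

```python
import numpy as np

# Illustrative 1-D "convolution with holes": only the k non-zero taps are
# visited, so the cost per output position is independent of the rate.
def atrous_conv1d(signal, kernel, rate):
    k = len(kernel)
    span = (k - 1) * rate  # effective (dilated) kernel extent
    out = np.zeros(len(signal) - span)
    for i in range(len(out)):
        out[i] = sum(kernel[j] * signal[i + j * rate] for j in range(k))
    return out

x = np.arange(8, dtype=float)       # [0, 1, ..., 7]
k = np.array([1.0, -1.0])           # finite-difference filter, 2 parameters
r1 = atrous_conv1d(x, k, rate=1)    # standard convolution
r2 = atrous_conv1d(x, k, rate=2)    # same 2 parameters, doubled receptive field
```

With `rate=2` the same two-tap filter spans three input positions, mirroring how the network keeps feature-map resolution without extra parameters.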
<img src='images/atrous_convolution.png'>
Decoders
Each decoder consists of three convolution layers:
- 3×3 Conv + ReLU (512 kernels).
- 1×1 Conv + ReLU (512 kernels).
- 1×1 Conv + ReLU (as many kernels as the task requires).
Semantic segmentation decoder: last layer has 34 channels.<br> <img src='images/semantic_segmantation.png' height="100px">
Instance segmentation decoder: last layer has 2 channels.<br> <img src='images/instance_segmantation.png' height="100px">
Depth estimation decoder: last layer has 1 channel.<br> <img src='images/depth_estimation.png' height="100px">
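The head structure above can be sanity-checked with a small parameter-count sketch. The function names and the 2048-channel encoder output are assumptions (2048 is the usual ResNet-101 final feature depth, not stated in this README); biases are ignored for simplicity.

```python
# Hypothetical parameter-count sketch for one task-specific decoder head:
# 3x3 conv (512) -> 1x1 conv (512) -> 1x1 conv (task channels), no biases.
def conv_params(in_ch, out_ch, k):
    return in_ch * out_ch * k * k

def decoder_params(in_ch, task_channels, hidden=512):
    return (conv_params(in_ch, hidden, 3)            # 3x3 Conv + ReLU
            + conv_params(hidden, hidden, 1)         # 1x1 Conv + ReLU
            + conv_params(hidden, task_channels, 1)) # task-specific output

# Output channels per decoder, as listed above; 2048 assumed from ResNet-101.
heads = {'semantic': 34, 'instance': 2, 'depth': 1}
counts = {name: decoder_params(2048, ch) for name, ch in heads.items()}
```

Almost all the head's parameters sit in the first 3×3 layer, so adding a task is cheap relative to the shared encoder.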
Losses
Specific losses
- Semantic segmentation loss (<img src='images/l_label.PNG' height="20px">): per-pixel cross entropy on the softmax output (computed only on valid pixels).
- Instance segmentation loss (<img src='images/l_instance.PNG' height="20px">): centroid regression using a masked L1. For each instance in the ground truth we compute a mask of valid pixels, and for each pixel in the mask the offset (in pixels) from the mask centre, in x and in y; this serves as the instance segmentation ground truth. Then, over all valid pixels, we compute the L1 distance between the network output and this ground truth.
- Depth estimation loss (<img src='images/l_disp.PNG' height="20px">): L1 (computed only on valid pixels).
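The valid-pixel masking used by the instance and depth losses can be sketched as follows (an illustrative numpy version, not the repo's code; `masked_l1` is a hypothetical name):

```python
import numpy as np

# Illustrative masked L1: the loss is averaged only over valid pixels,
# as described for the instance and depth tasks above.
def masked_l1(pred, target, mask):
    """pred, target: (H, W) float arrays; mask: boolean (H, W) of valid pixels."""
    valid = mask.sum()
    if valid == 0:
        return 0.0
    return np.abs(pred - target)[mask].sum() / valid

pred   = np.array([[1.0, 2.0], [3.0, 4.0]])
target = np.array([[1.5, 2.0], [0.0, 4.0]])
mask   = np.array([[True, True], [False, True]])  # one invalid pixel ignored
loss = masked_l1(pred, target, mask)  # (0.5 + 0.0 + 0.0) / 3 valid pixels
```

Averaging over valid pixels only keeps the loss scale independent of how much of the image is labelled.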
Multi loss
<img src='images/multi_loss.PNG'>
Notice that <img src='images/sigmas.PNG' height="20px"> are learnable.
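A hedged sketch of how the learnable σ values combine the task losses (not the repo's exact code; following the paper's numerically stable parameterisation, each task learns s = log σ² rather than σ directly):

```python
import math

# Hedged sketch of learned-uncertainty weighting. With s_i = log(sigma_i^2),
# the combined loss is  sum_i 0.5 * exp(-s_i) * L_i + 0.5 * s_i :
# high-uncertainty tasks are down-weighted, and the +0.5*s_i term stops
# sigma_i from growing without bound.
def uncertainty_weighted_loss(task_losses, log_vars):
    total = 0.0
    for loss, s in zip(task_losses, log_vars):
        total += 0.5 * math.exp(-s) * loss + 0.5 * s
    return total

# With s_i = 0 (sigma_i = 1) this reduces to half the plain sum of losses.
print(uncertainty_weighted_loss([2.0, 4.0], [0.0, 0.0]))  # -> 3.0
```

In training, the `log_vars` would be registered as trainable parameters and updated by backpropagation alongside the network weights.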
Instance segmentation explained
The instance segmentation decoder produces two channels, so that each pixel holds a vector pointing to its instance centre. Using the semantic segmentation result, we compute a mask of the valid instance segmentation pixels. We then combine the mask with the vectors produced by the instance segmentation decoder, and cluster the vectors into different instances using the OPTICS clustering algorithm. OPTICS is an efficient density-based clustering algorithm that can identify an unknown number of multi-scale clusters of varying density in a given set of samples. OPTICS is used for two reasons: it does not assume knowledge of the number of clusters, unlike algorithms such as k-means, and it does not assume a canonical instance size or density, unlike discretised binning approaches.
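The clustering step can be illustrated with scikit-learn's OPTICS on toy data (an assumption for illustration: this README does not say which OPTICS implementation the repo uses). Each valid pixel's predicted centre vector becomes a sample, and pixels pointing at the same centre fall into one cluster.

```python
import numpy as np
from sklearn.cluster import OPTICS

# Two toy "instances": pixels whose predicted centres lie near (0, 0)
# and near (10, 10) respectively.
rng = np.random.default_rng(0)
centres = np.vstack([
    rng.normal(loc=0.0, scale=0.1, size=(20, 2)),   # instance A
    rng.normal(loc=10.0, scale=0.1, size=(20, 2)),  # instance B
])

# OPTICS needs no preset cluster count, unlike k-means. Here we extract
# DBSCAN-style clusters at eps=1.0 from the OPTICS reachability ordering.
labels = OPTICS(min_samples=5, cluster_method="dbscan", eps=1.0).fit(centres).labels_
n_instances = len(set(labels) - {-1})  # -1 marks noise points
```

In the actual pipeline the samples are per-pixel centre votes rather than toy points, but the key properties are the same: the number of instances is discovered from density, not assumed.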
<img src='images/instance_pipline_legand2.png'>
Results
Examples
| Input | Label <br>segmentation | Instance <br>segmentation | Depth |
|:-------------------:|:-------------------:|:-------------------:|:-------------------:|
|<img width="200px" src='inputs/Pedestrian_crossing_0.png'>|<img src='results/resNet_label_instance_disp/label_Pedestrian_crossing_0.png' width="200px">|<img src='results/resNet_label_instance_disp/instance_Pedestrian_crossing_0.png' width="200px">|<img src='results/resNet_label_instance_disp/disp_Pedestrian_crossing_0.png' width="200px">|
|<img width="200px" src='inputs/Pedestrian_crossing_1.png'>|<img src='results/resNet_label_instance_disp/label_Pedestrian_crossing_1.png' width="200px">|<img src='results/resNet_label_instance_disp/instance_Pedestrian_crossing_1.png' width="200px">|<img src='results/resNet_label_instance_disp/disp_Pedestrian_crossing_1.png' width="200px">|
|<img width="200px" src='inputs/bicycle_0.png'>|<img src='results/resNet_label_instance_disp/label_bicycle_0.png' width="200px">|<img src='results/resNet_label_instance_disp/instance_bicycle_0.png' width="200px">|<img src='results/resNet_label_instance_disp/disp_bicycle_0.png' width="200px">|
|<img width="200px" src='inputs/bicycle_1.png'>|<img src='results/resNet_label_instance_disp/label_bicycle_1.png' width="200px">|<img src='results/resNet_label_instance_disp/instance_bicycle_1.png' width="200px">|<img src='results/resNet_label_instance_disp/disp_bicycle_1.png' width="200px">|
|<img width="200px" src='inputs/bus_0.png'>|<img src='results/resNet_label_instance_disp/label_bus_0.png' width="200px">|<img src='results/resNet_label_instance_disp/instance_bus_0.png' width="200px">|<img src='results/resNet_label_instance_disp/disp_bus_0.png' width="200px">|
|<img width="200px" src='inputs/bus_1.png'>|<img src='results/resNet_label_instance_disp/label_bus_1.png' width="200px">|<img src='results/resNet_label_instance_disp/instance_bus_1.png' width="200px">|<img src='results/resNet_label_instance_disp/disp_bus_1.png' width="200px">|
|<img width="200px" src='inputs/parking_0.png'>|<img src='results/resNet_label_instance_disp/label_parking_0.png' width="200px">|<img src='results/resNet_label_instance_disp/instance_parking_0.png' width="200px">|<img src='results/resNet_label_instance_disp/disp_parking_0.png' width="200px">|
|<img width="200px" src='inputs/parking_1.png'>|<img src='results/resNet_label_instance_disp/label_parking_1.png' width="200px">|<img src='results/resNet_label_instance_disp/instance_parking_1.png' width="200px">|<img src='results/resNet_label_instance_disp/disp_parking_1.png' width="200px">|
|<img width="200px" src='inputs/truck_0.png'>|<img src='results/resNet_label_instance_disp/label_truck_0.png' width="200px">|<img src='results/resNet_label_instance_disp/instance_truck_0.png' width="200px">|<img src='results/resNet_label_instance_disp/disp_truck_0.png' width="200px">|
|<img width="200px" src='inputs/truck_1.png'>|<img src='results/resNet_label_instance_disp/label_truck_1.png' width="200px">|<img src='results/resNet_label_instance_disp/instance_truck_1.png' width="200px">|<img src='results/resNet_label_instance_disp/disp_truck_1.png' width="200px">|
Single vs. Dual vs. All
**Task quantitative result per
