Gem
A PyTorch-based library for evaluating learning methods on small image classification datasets
:gem: GEM: Generalization-Efficient Methods for image classification with small datasets
GEM is a PyTorch-based library with the goal of providing a shared codebase for fast prototyping, training and reproducible evaluation of learning algorithms that generalize on small image datasets.
In particular, the repository contains all the tools to reproduce and possibly extend the experiments of the paper Image Classification with Small Datasets: Overview and Benchmark. It provides:
- [x] A (possibly extendable) benchmark of 5 datasets spanning various data domains and types
- [x] A realistic and fair experimental pipeline including hyper-parameter optimization and common training set-ups
- [x] A (possibly extendable) large pool of implementations for state-of-the-art methods
Given the "living" nature of our library, we plan to keep the repository updated with new approaches and datasets to drive further progress toward small-sample learning methods.
:bookmark_tabs: Table of Contents
- [Overview](#book-overview)
- [Usage](#gear-usage)
  - [Installation](#installation)
  - [Method Evaluation](#method-evaluation)
  - [Library Extension](#library-extension)
- [Results](#bar_chart-results)
- [Citation](#writing_hand-citation)
:book: Overview
Structure
More details soon!
Datasets
The datasets constituting our benchmark are the following:
| Dataset | Classes | Imgs/Class | Trainval | Test | Problem Domain | Data Type | Identifier |
|:------------------|--------:|-----------:|---------:|-------:|:---------------|:--------------|:---------------|
| [ciFAIR-10][1]* | 10 | 50 | 500 | 10,000 | Natural Images | RGB (32x32) | cifair10 |
| [CUB][2] | 200 | 30 | 5,994 | 5,794 | Fine-Grained | RGB | cub |
| [ISIC 2018][3]* | 7 | 80 | 560 | 1,944 | Medical | RGB | isic2018 |
| [EuroSAT][4]* | 10 | 50 | 500 | 19,500 | Remote Sensing | Multispectral | eurosat |
| [CLaMM][5]* | 12 | 50 | 600 | 2,000 | Handwriting | Grayscale | clamm |
* We use subsampled versions of the original datasets with fewer images per class.
For additional details on dataset statistics, splits, and how to download the data, see the respective page in the `datasets` folder.
This directory contains one sub-directory for each dataset in our benchmark. These sub-directories contain the split files specifying the subsets of data used in our experiments. The files `trainval{i}.txt` are simply the concatenation of `train{i}.txt` and `val{i}.txt` (with `i` in {0,1,2}). These subsets can be used for the final training before evaluating a method on the test set. Development and hyper-parameter optimization (HPO), however, should only be conducted using the training and validation sets.
The aforementioned files list all images contained in the respective subset, one per line, along with their class labels. Each line contains the filename of an image followed by a space and the numeric index of its label.
The only exception to this common format is ciFAIR-10, since its images do not have filenames. A description of its split can be found in the README of the respective directory.
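To illustrate the split-file format described above, the following minimal sketch parses a split file into `(filename, label)` pairs and checks that `trainval{i}.txt` contains exactly the entries of `train{i}.txt` and `val{i}.txt`. The helper names `parse_split` and `check_trainval` are hypothetical and not part of GEM, and the sketch does not cover ciFAIR-10, whose split format differs.

```python
from pathlib import Path
from typing import List, Tuple


def parse_split(path: Path) -> List[Tuple[str, int]]:
    """Parse a split file where each line is '<filename> <label index>'."""
    samples = []
    for line in Path(path).read_text().splitlines():
        line = line.strip()
        if not line:
            continue  # skip blank lines, e.g. a trailing newline
        # Split on the last space so the label is always the final field,
        # even if a filename happens to contain spaces.
        filename, label = line.rsplit(" ", 1)
        samples.append((filename, int(label)))
    return samples


def check_trainval(split_dir: Path, i: int) -> bool:
    """Return True if trainval{i}.txt holds exactly the entries of train{i}.txt and val{i}.txt."""
    split_dir = Path(split_dir)
    train = parse_split(split_dir / f"train{i}.txt")
    val = parse_split(split_dir / f"val{i}.txt")
    trainval = parse_split(split_dir / f"trainval{i}.txt")
    # Order-insensitive comparison of the concatenation.
    return sorted(trainval) == sorted(train + val)
```

Point `parse_split` at any split file inside a dataset sub-directory and pass that sub-directory to `check_trainval` to sanity-check a download.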
Methods
We currently provide the implementations of the following methods:
| Method | Original code | Our implementation | Identifier |
|:-------------------------------------------------|-------------------------:|------------------------------------------------:|-------------------:|
| Cross-Entropy Loss (baseline) | -- |[xent.py][xent.py] |xent |
| [Deep Hybrid Networks][scattering] |[link][scattering_code] |[scattering.py][scattering.py] |scattering |
| [OLÉ][ole] |[link][ole_code] |[ole.py][ole.py] |ole |
| [Grad-L2 Penalty][kernelregular] |[link][kernelregular_code]|[kernelregular.py][kernelregular.py] |gradl2 |
| [Cosine Loss (+ Cross-Entropy)][cosineloss] |-- |[cosineloss.py][cosineloss.py] |cosine |
| [Harmonic Networks][harmonic] |[link][harmonic_code] |[harmonic.py][harmonic.py] |harmonic |
| [Full Convolution][fconv] |[link][fconv_code] |[fconv.py][fconv.py] |fconv |
| [DSK Networks][dsknet] |-- |[dsk_classifier.py][dsk_classifier.py] | dsk |
| [Distilling Visual Priors][distill] |[link][distill_code] |[distill_pretraining.py][distill_pretraining.py]<br>[distill_classifier.py][distill_classifier.py]| dvp-pretrain<br>dvp-distill|
| [Auxiliary Learning][auxilearn] |[link][auxilearn_code] | [auxilearn.py][auxilearn.py] |auxilearn |
| [T-vMF Similarity][tvmf] |[link][tvmf_code] | [tvmf.py][tvmf.py] |tvmf |
:gear: Usage
Installation
To use the repository, clone it to your local machine:
```bash
git clone https://github.com/lorenzobrigato/gem.git
```
Then move into the repository folder and install the required packages with:
```bash
cd gem
python -m pip install -r requirements.txt
```
Note: GEM requires PyTorch with GPU support. For instructions on installing a PyTorch version compatible with your CUDA version, see pytorch.org.
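Since GEM requires a CUDA-enabled PyTorch build, one quick way to verify your installation is to query `torch.cuda.is_available()`. The `cuda_ready` helper below is a hypothetical convenience function, not part of GEM; it returns `False` instead of raising if PyTorch is not installed.

```python
def cuda_ready() -> bool:
    """Return True if PyTorch is installed and can see at least one CUDA GPU."""
    try:
        import torch  # imported lazily so the check also works without PyTorch
    except ImportError:
        return False
    return torch.cuda.is_available()


if __name__ == "__main__":
    if cuda_ready():
        import torch
        print(f"CUDA available: {torch.cuda.device_count()} GPU(s)")
    else:
        print("No CUDA-enabled PyTorch found; see pytorch.org for install instructions.")
```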
Method Evaluation
We provide a set of scripts located in the directories `scripts` and `bash_scripts` to reproduce the experimental pipeline presented in our paper. In particular, evaluating one method on the full benchmark consists of:
1. Finding hyper-parameters by training the approach on the `train{i}.txt` split while evaluating on the respective `val{i}.txt` split
2. Training 10 instances of the method with the found configuration on the full `trainval{i}.txt` split while evaluating on the test split
3. Repeating steps 1 and 2 independently for all values of `i`
For all datasets, our paper uses 3 training splits, hence `i` ranges over {0,1,2}. As for the test sets, some datasets have multiple splits, as for training, while others use a single `test0.txt` split. We performed multiple independent evaluations, varying dataset splits and optimization runs, to account for random variance, which is particularly significant in the small-sample regime.
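The protocol above can be summarized by the following sketch: for every split `i`, hyper-parameters are selected once on `train{i}`/`val{i}`, several runs are then trained on `trainval{i}` and evaluated on the test set, and all test accuracies are pooled. The callables `run_hpo` and `train_and_test` are hypothetical placeholders standing in for `hpo.py` and `train.py`; reporting the mean and standard deviation over all runs is an assumption made here for illustration, not necessarily how the paper aggregates its results.

```python
import statistics
from typing import Callable, Dict, Iterable, Tuple

Config = Dict[str, float]  # e.g. learning rate, weight decay, batch size


def evaluate_method(
    run_hpo: Callable[[int], Config],                      # split index -> best config (step 1)
    train_and_test: Callable[[int, Config, int], float],   # (split, config, seed) -> test accuracy (step 2)
    splits: Iterable[int] = (0, 1, 2),
    runs_per_split: int = 10,
) -> Tuple[float, float]:
    """Hypothetical driver for the evaluation protocol; returns (mean, stdev) of test accuracy."""
    accuracies = []
    for i in splits:                        # step 3: repeat independently for each split
        config = run_hpo(i)                 # HPO uses only train{i} / val{i}
        for seed in range(runs_per_split):  # several instances trained on trainval{i}
            accuracies.append(train_and_test(i, config, seed))
    return statistics.mean(accuracies), statistics.stdev(accuracies)
```

Keeping HPO strictly inside the loop over splits mirrors the requirement that the test set never influences hyper-parameter selection.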
To perform steps 1 and 2 separately, we provide `hpo.py` and `train.py` / `train_ray.py`, respectively. Alternatively, steps 1 and 2 can be run sequentially with `full_train.py`.
For step 3, refer to the bash scripts available in `bash_scripts`.
In the following, we describe the available scripts in more detail.
Hyper-Parameter Optimization (HPO)
For HPO, we employ an efficient and easy-to-use library, Tune, together with a state-of-the-art search algorithm, the Asynchronous Successive Halving Algorithm (ASHA).
The script `hpo.py` is dedicated to finding the hyper-parameters of a method.
For instance, to search for the default hyper-parameters (learning rate, weight decay, and batch size) of the cross-entropy baseline with a Wide ResNet-16-8 on the [ciFAIR-10][1] dataset and split 0 (the default), run:
```bash
python scripts/hpo.py cifair10 \
    --method xent \
    --architecture wrn-16-8 \
    --rand-shift 4 \
    --epochs 500 \
    --grace-period 50 \
    --num-trials 250 \
    --cpus-per-trial 8 \
    --gpus-per-trial 0.5
```
After completion, the script prints the hyper-parameters it found. Note that `--grace-period` and `--num-trials` are parameters of the search algorithm that have been fixed for each dataset and are hard-coded in the bash scripts of the `bash_scripts` folder.
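To give an intuition of what `--grace-period` controls, here is a simplified sketch of the early-stopping rule behind ASHA-style schedulers: trials are compared only at "rung" milestones spaced geometrically from the grace period, and a trial continues past a rung only if its validation accuracy ranks in the top `1/reduction_factor` of the results recorded at that rung so far. This is not Tune's implementation; the `SimpleASHA` class, its `on_result` interface, and the reduction factor of 4 are assumptions for illustration.

```python
import math
from typing import Dict, List


class SimpleASHA:
    """Simplified ASHA-style stopping rule (illustrative only, not Ray Tune's code)."""

    def __init__(self, grace_period: int, max_t: int, reduction_factor: int = 4):
        self.rf = reduction_factor
        # Rung milestones: grace_period * rf**k, strictly below max_t.
        self.milestones: List[int] = []
        t = grace_period
        while t < max_t:
            self.milestones.append(t)
            t *= reduction_factor
        # Accuracies recorded so far at each rung, across all trials.
        self.recorded: Dict[int, List[float]] = {m: [] for m in self.milestones}

    def on_result(self, epoch: int, accuracy: float) -> bool:
        """Return True if the trial should keep training, False to stop it."""
        if epoch not in self.recorded:
            return True  # only rung milestones trigger a decision
        rung = self.recorded[epoch]
        rung.append(accuracy)
        # Keep the trial only if it ranks in the top 1/rf of results at this rung.
        n_keep = math.ceil(len(rung) / self.rf)
        cutoff = sorted(rung, reverse=True)[n_keep - 1]
        return accuracy >= cutoff
```

Under these assumptions, `SimpleASHA(grace_period=50, max_t=500)` places rungs at epochs 50 and 200: no trial is stopped before epoch 50, and only trials that look promising at both rungs train for all 500 epochs. Because decisions are made as soon as a result arrives, no trial has to wait for others to finish, which is what makes the scheme asynchronous.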
To see all the arguments accepted by the script, check the help message of the parser by running:
```bash
python scripts/hpo.py -h
```
Note also that you can configure the hardware resources assigned to each trial. For example, with `--gpus-per-trial 0.5` the script runs two trials in parallel on each GPU.
Exploit parallelism to speed up the search, but keep in mind that the number of trials per GPU is bounded by the available GPU memory.
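As a rough rule of thumb, the number of trials that can run at the same time is limited by both the GPU and the CPU budget. The sketch below computes this bound under the simplifying assumption that each fractional-GPU trial must fit entirely on one GPU, and it ignores GPU memory, which, as noted above, may impose a tighter limit in practice. The function `max_concurrent_trials` is hypothetical and not part of GEM.

```python
from fractions import Fraction


def max_concurrent_trials(num_gpus: int, num_cpus: int,
                          gpus_per_trial: float, cpus_per_trial: int) -> int:
    """Upper bound on simultaneously running trials given per-trial resources."""
    g = Fraction(str(gpus_per_trial))  # exact arithmetic, e.g. 0.5 -> 1/2
    if g <= 1:
        # Fractional trials are packed per device: floor(1 / g) trials fit on each GPU.
        gpu_slots = num_gpus * (1 // g)
    else:
        # Multi-GPU trials: each one needs g GPUs.
        gpu_slots = num_gpus // g
    cpu_slots = num_cpus // cpus_per_trial
    return min(gpu_slots, cpu_slots)


# One GPU, --gpus-per-trial 0.5 and --cpus-per-trial 8 on a 16-core machine.
print(max_concurrent_trials(num_gpus=1, num_cpus=16, gpus_per_trial=0.5, cpus_per_trial=8))  # 2
```

With the flags from the example command above, a machine with 2 GPUs but only 8 CPU cores would still run a single trial at a time, since the CPU budget becomes the bottleneck.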
Final Evaluation
Once the hyper-parameters have been found, you can train a single model for the test evaluation with the script `train.py`. Alternatively, you can train multiple instances of the same model in parallel, again exploiting the Tune library, with the script [train_ray.py](scr
