MAPLS
This is the official implementation of the WACV 2024 paper "Label Shift Estimation for Class-Imbalance Problem: A Bayesian Approach".
If you find this repository useful or use this code in your research, please cite the following paper:
@InProceedings{Ye_2024_WACV,
author = {Ye, Changkun and Tsuchida, Russell and Petersson, Lars and Barnes, Nick},
title = {Label Shift Estimation for Class-Imbalance Problem: A Bayesian Approach},
booktitle = {Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (WACV)},
month = {January},
year = {2024},
pages = {1073-1082}
}
Requirements
The code is written in PyTorch. It is recommended to install the dependencies via conda:
conda install scipy
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
conda install -c conda-forge cvxpy
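After installation, a quick check that the packages are visible to the active environment can save a failed training run. The snippet below is a minimal sketch and not part of this repository; the helper name `missing_packages` is hypothetical, and the package list simply mirrors the conda commands above.

```python
from importlib.util import find_spec

# Import names of the packages installed by the conda commands above.
REQUIRED = ("scipy", "torch", "torchvision", "torchaudio", "cvxpy")


def missing_packages(names=REQUIRED):
    """Return the names that cannot be found in the current environment."""
    return [name for name in names if find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages()
    if missing:
        print(f"missing packages: {missing}")
    else:
        print("all required packages found")
```

The check only confirms that each package can be located; it does not verify versions or CUDA availability.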
When training the neural network classifier from scratch, the recommended hardware setup is as follows:
| Dataset | GPU | CPU |
|:-----------:|:--------------:|:-------------------:|
| CIFAR10/100 | ≥ 2 GB | > 4 + 1 threads |
| ImageNet | ≥ 4 x 12 GB | > 16 + 1 threads |
| Places | ≥ 6 x 12 GB | > 24 + 1 threads |
Dataset Details
Our code supports the CIFAR10/100, ImageNet 2012 and Places365 datasets.
- CIFAR10/100 dataset: Please download it with the built-in function of PyTorch.
- ImageNet dataset: Please download ImageNet 2012 from the official site https://image-net.org/.
- Places dataset: Please download it from the official site http://places2.csail.mit.edu/
For the Long-Tailed versions of ImageNet and Places, please download the splits provided by the paper "Large-Scale Long-Tailed Recognition in an Open World".
data
  |--CIFAR10
    |--cifar-10-batches-py
  |--CIFAR100
    |--cifar-100-python
  |--ImageNet
    |--train
    |--val
    |--ImageNet_LT_train.txt
    |--ImageNet_LT_test.txt
    |--ImageNet_LT_val.txt
  |--Places
    |--data_256
    |--val_256
    |--test_256
    |--Places_LT_train.txt
    |--Places_LT_test.txt
    |--Places_LT_val.txt
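Before training, it can help to confirm that the data folder matches the layout listed above. The following is a minimal sketch and not part of this repository: the helper `missing_entries` is hypothetical, and it assumes that every entry (including the `*_LT_*.txt` split files) sits directly inside its dataset folder.

```python
from pathlib import Path

# Entries expected inside each dataset folder, copied from the layout above.
# Assumption: the *_LT_*.txt split files live inside their dataset folder.
EXPECTED = {
    "CIFAR10": ["cifar-10-batches-py"],
    "CIFAR100": ["cifar-100-python"],
    "ImageNet": [
        "train", "val",
        "ImageNet_LT_train.txt", "ImageNet_LT_test.txt", "ImageNet_LT_val.txt",
    ],
    "Places": [
        "data_256", "val_256", "test_256",
        "Places_LT_train.txt", "Places_LT_test.txt", "Places_LT_val.txt",
    ],
}


def missing_entries(data_root, datasets=None):
    """List expected files or folders that are absent under data_root."""
    root = Path(data_root)
    names = list(EXPECTED) if datasets is None else list(datasets)
    return [
        str(Path(name) / entry)
        for name in names
        for entry in EXPECTED[name]
        if not (root / name / entry).exists()
    ]


if __name__ == "__main__":
    # "./data" is a placeholder; point it at your own data path.
    for path in missing_entries("./data"):
        print("missing:", path)
```

Passing a subset such as `datasets=["CIFAR10"]` restricts the check to the datasets you actually plan to use.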
P.S. It is recommended to prepare the ImageNet dataset as follows:
Extract the training data:
mkdir train && mv ILSVRC2012_img_train.tar train/ && cd train
tar -xvf ILSVRC2012_img_train.tar && rm -f ILSVRC2012_img_train.tar
find . -name "*.tar" | while read NAME ; do mkdir -p "${NAME%.tar}"; tar -xvf "${NAME}" -C "${NAME%.tar}"; rm -f "${NAME}"; done
cd ..
Extract the validation data and move images to subfolders:
mkdir val && mv ILSVRC2012_img_val.tar val/ && cd val && tar -xvf ILSVRC2012_img_val.tar
wget -qO- https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh | bash
This script was originally provided in the soumith/imagenetloader.torch repository.
Train the classifier
To train the classifier from scratch, adjust the GPU IDs "CUDA_VISIBLE_DEVICES", the dataset path "$data_path" and the config path "$cfg_path" in the bash script "./train_script.sh" and run:
./train_script.sh
Config examples are provided in "./config/".
Test Label Shift Estimation Model
To test an existing model's performance under label shift, adjust the dataset path "$data_path" and the checkpoint path "$ckpt_path" in the bash script "./test_script.sh" and run:
./test_script.sh
The "$cfg_path" in "./test_script.sh" determines the type of label shift, including:
- "./config/batch_imb_LT": Ordered Long-Tailed Shift
- "./config/batch_imb_shuffle": Shuffled Long-Tailed Shift
- "./config/batch_imb_dirichlet": Dirichlet Shift
- "./config/batch_imb_knockout": Knockout Shift
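To give an intuition for what these four shift types mean for the test-set class distribution, the sketch below builds one example class prior for each. It is an illustration only and does not reproduce the sampling code or the parameters in the config files of this repository; all function names and the parameters `imbalance_ratio`, `alpha` and `keep_fraction` are hypothetical.

```python
import random


def long_tailed_prior(num_classes, imbalance_ratio):
    """Exponentially decaying prior: class 0 is imbalance_ratio times
    more likely than the last class (ordered long-tailed shift)."""
    if num_classes == 1:
        return [1.0]
    weights = [imbalance_ratio ** (-k / (num_classes - 1)) for k in range(num_classes)]
    total = sum(weights)
    return [w / total for w in weights]


def shuffled_prior(prior, rng):
    """Same probabilities as `prior`, assigned to classes in random order."""
    shuffled = list(prior)
    rng.shuffle(shuffled)
    return shuffled


def dirichlet_prior(num_classes, alpha, rng):
    """Sample a prior from a symmetric Dirichlet(alpha) via normalised Gamma draws."""
    draws = [rng.gammavariate(alpha, 1.0) for _ in range(num_classes)]
    total = sum(draws)
    return [d / total for d in draws]


def knockout_prior(num_classes, knocked_class, keep_fraction):
    """Start from a uniform prior and keep only keep_fraction of one class."""
    weights = [1.0] * num_classes
    weights[knocked_class] = keep_fraction
    total = sum(weights)
    return [w / total for w in weights]


if __name__ == "__main__":
    rng = random.Random(0)  # fixed seed so the example is repeatable
    ordered = long_tailed_prior(10, imbalance_ratio=100)
    print("ordered LT :", [round(p, 3) for p in ordered])
    print("shuffled LT:", [round(p, 3) for p in shuffled_prior(ordered, rng)])
    print("dirichlet  :", [round(p, 3) for p in dirichlet_prior(10, 1.0, rng)])
    print("knockout   :", [round(p, 3) for p in knockout_prior(10, 0, 0.1)])
```

Each function returns a list of class probabilities that sums to one; a test set with that label distribution could then be drawn by sampling classes according to these probabilities.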
License
Please see the LICENSE file.
Questions?
Please raise an issue or contact the author at changkun.ye@anu.edu.au.
