CPSD
PyTorch implementation of Boosting Multi-Label Image Classification with Complementary Parallel Self-Distillation, IJCAI 2022.
Introduction
Abstract
Multi-Label Image Classification (MLIC) approaches usually exploit label correlations to achieve good performance. However, emphasizing correlation like co-occurrence may overlook discriminative features of the target itself and lead to overfitting, thus undermining the performance. In this study, we propose a generic framework named Parallel Self-Distillation (PSD) for boosting MLIC models. PSD decomposes the original MLIC task into several simpler MLIC sub-tasks via two elaborated complementary task decomposition strategies named Co-occurrence Graph Partition (CGP) and Dis-occurrence Graph Partition (DGP). Then, the MLIC models of fewer categories are trained with these sub-tasks in parallel for respectively learning the joint patterns and the category-specific patterns of labels. Finally, knowledge distillation is leveraged to learn a compact global ensemble of full categories for reconciling the label correlation exploitation and model overfitting. Extensive results on MS-COCO and NUS-WIDE datasets demonstrate that our framework can be easily plugged into many MLIC approaches and improve performances of recent state-of-the-art approaches. The explainable visual study also further validates that our method is able to learn both the category-specific and co-occurring features.
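As a rough illustration of the decomposition idea (this is not code from this repository): CGP groups frequently co-occurring labels into the same sub-task, starting from a label co-occurrence graph built from the training annotations. The function name and toy data below are hypothetical.

```python
# Illustrative sketch only: building the label co-occurrence graph
# that a partition strategy like CGP would operate on.
import numpy as np

def cooccurrence_graph(labels: np.ndarray) -> np.ndarray:
    """labels: (num_images, num_classes) multi-hot annotation matrix.
    Returns a symmetric co-occurrence count matrix with zero diagonal."""
    co = labels.T @ labels      # co[i, j] = number of images containing both i and j
    np.fill_diagonal(co, 0)     # ignore self co-occurrence
    return co

# Toy example: 4 images, 3 classes.
y = np.array([[1, 1, 0],
              [1, 1, 0],
              [0, 0, 1],
              [1, 0, 1]])
G = cooccurrence_graph(y)       # classes 0 and 1 co-occur twice; 0 and 2 once
```

CGP would then partition this graph so that strongly connected labels land in the same sub-task, while DGP does the complementary split, separating labels that rarely co-occur.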
Results on MS-COCO:

Results on NUS-WIDE:

Requirements
The following packages are recommended:
- numpy
- torch-1.9.0
- torchvision-0.10.0
- tqdm
- environments for different models (Q2L, TResNet, etc.)
Quick start
For a detailed description of the parameters, use `python3 file.py -h`.
- Partition: we provide the files after partition in `./data/`, and the partition code in `./partition/`.
- Train teachers, e.g.

  ```
  python3 train_B_part.py --data ./data/coco --dataset coco --part 0 --num-classes 80 --subnum 1 --typ cluster --model 101 --metric mse
  ```
- Train compact student, e.g.

  ```
  python3 dis_B_cpsd.py --data ./data/coco --dataset coco -t q2l -s q2l --num-classes 80 --subnum 5 --model-root ./checkpoint/ --metric mse
  ```
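A minimal sketch of the distillation objective, assuming the simplest reading of `--metric mse`: the student's logits are regressed onto the teachers' logits with a mean-squared-error loss (alongside the usual classification loss on ground truth, omitted here). The function name, toy values, and weighting are illustrative, not taken from this repository.

```python
# Illustrative sketch only: MSE-based logit distillation, where the
# target is the concatenation/ensemble of the parallel teachers' logits.
import numpy as np

def mse_distill_loss(student_logits: np.ndarray, teacher_logits: np.ndarray) -> float:
    """Mean squared error between student and teacher logits."""
    return float(np.mean((student_logits - teacher_logits) ** 2))

# Toy example over 3 classes: only the last logit differs (by 1).
s = np.array([0.5, -1.0, 2.0])
t = np.array([0.5, -1.0, 1.0])
loss = mse_distill_loss(s, t)
```

In the full framework this target would be assembled from the sub-task teachers produced by both the CGP and DGP partitions, so the student receives both co-occurrence-aware and category-specific supervision.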
Pretrained Models
Download pretrained models.
| Modelname | mAP | link(Google drive) |
| ---- | ---- | ---- |
| CPSD-R101-448-COCO | 83.1 | link |
| CPSD-Q2L-448-COCO | 84.9 | link |
| CPSD-R101TF-448-COCO | 85.2 | link |
| CPSD-R101TF-576-COCO | 86.7 | link |
| CPSD-TResNet-448-COCO | 87.3 | link |
| CPSD-ResNeXt-448-COCO | 87.7 | link |
| CPSD-R101TF-448-NUS | 65.8 | link |
| CPSD-TResNet-448-NUS | 66.5 | link |
Test pretrained models

e.g.

```
python3 validate.py --model 101tf --dataset nuswide --resume ./checkpoint/dis_nus_combine_101tf2101tf_partition0_cl5_mse_55-65.80161317536648-42.ckpt
```
Acknowledgement
We thank the authors of Q2L, ASL, TResNet, ML-GCN, and C-Tran for their great work and code.