# KDEP

Official PyTorch implementation (CVPR 2022) of **Knowledge Distillation as Efficient Pre-training: Faster Convergence, Higher Data-efficiency, and Better Transferability**.
This repository contains the code and models necessary to replicate the results of our paper:
```bibtex
@inproceedings{he2022knowledge,
  title={Knowledge Distillation as Efficient Pre-training: Faster Convergence, Higher Data-efficiency, and Better Transferability},
  author={He, Ruifei and Sun, Shuyang and Yang, Jihan and Bai, Song and Qi, Xiaojuan},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  year={2022}
}
```
## Abstract
Large-scale pre-training has been proven crucial for various computer vision tasks. However, as the amount of pre-training data and the number of model architectures grow, and as much data remains private or inaccessible, it becomes inefficient or even impossible to pre-train every model architecture on large-scale datasets. In this work, we investigate an alternative strategy for pre-training, namely Knowledge Distillation as Efficient Pre-training (KDEP), which aims to efficiently transfer the learned feature representations of existing pre-trained models to new student models for future downstream tasks. We observe that existing Knowledge Distillation (KD) methods are unsuitable for pre-training since they normally distill the logits, which are discarded when transferring to downstream tasks. To resolve this problem, we propose a feature-based KD method with non-parametric feature dimension aligning. Notably, our method performs comparably with supervised pre-training counterparts on 3 downstream tasks and 9 downstream datasets while requiring 10x less data and 5x less pre-training time.
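The "non-parametric feature dimension aligning" mentioned above can be sketched with a truncated SVD: project the teacher's high-dimensional features down to the student's feature dimension without any learned parameters. This is only an illustrative sketch, not the paper's exact procedure (see `scripts/gen_svd_weights.sh` for that); all names here are hypothetical.

```python
import numpy as np

def svd_align(teacher_feats, student_dim):
    """Project teacher features (N, D_t) to (N, D_s) with no learned parameters."""
    # Center the features, then keep the top-D_s right singular vectors
    # as a non-parametric projection matrix.
    mean = teacher_feats.mean(axis=0, keepdims=True)
    centered = teacher_feats - mean
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    proj = vt[:student_dim].T              # (D_t, D_s) projection
    return centered @ proj                 # aligned features (N, D_s)

feats = np.random.randn(1024, 2048)        # e.g. pooled ResNet50 features
aligned = svd_align(feats, 512)            # match ResNet18's 512-d features
print(aligned.shape)                       # (1024, 512)
```

With the dimensions matched this way, a plain feature-regression loss between student and projected teacher features can be used instead of logit distillation.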

## Getting started

- Clone our repo:

  ```bash
  git clone https://github.com/CVMI-Lab/KDEP.git
  ```

- Install dependencies:

  ```bash
  conda create -n KDEP python=3.7
  conda activate KDEP
  pip install -r requirements.txt
  ```
## Data preparation
- ImageNet-1K (Download)
- Caltech256 (Download)
- Cifar100 (Automatically downloaded when you run the code)
- DTD (Download)
- CUB-200 (Download)
- Cityscapes (Download)
- VOC (segmentation and detection, Download)
- ADE20K (Download)
- COCO (Download)
For image classification datasets (except for Caltech256), the folder structure should follow ImageNet:

```
data_root
├── train/
│   ├── n01440764/
│   │   ├── n01440764_10026.JPEG
│   │   ├── n01440764_10027.JPEG
│   │   └── ...
│   └── ...
└── val/
    ├── n01440764/
    │   ├── ILSVRC2012_val_00000293.JPEG
    │   ├── ILSVRC2012_val_00002138.JPEG
    │   └── ...
    └── ...
```
For semantic segmentation datasets, please refer to PyTorch Semantic Segmentation.
For object detection datasets, please refer to Detectron2.
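Before pointing the training code at a dataset, it can help to sanity-check that the folder layout above is in place. The helper below is a hypothetical convenience, not part of the KDEP codebase:

```python
import os

def check_imagenet_layout(root):
    """Return True if root has train/ and val/ splits, each with >=1 class folder."""
    for split in ("train", "val"):
        split_dir = os.path.join(root, split)
        if not os.path.isdir(split_dir):
            return False                   # missing split directory
        classes = [d for d in os.listdir(split_dir)
                   if os.path.isdir(os.path.join(split_dir, d))]
        if not classes:                    # split has no class subfolders
            return False
    return True
```

Running this on the dataset root before editing `src/utils/constants.py` catches a mis-extracted archive early.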
## Pre-training with KDEP

- Download the teacher models (Download) and put them under `pretrained-models/`.
- Use the provided script `scripts/make-imgnet-subset.py` to create the 10% subset of ImageNet-1K.
- Update the path of the dataset for KDEP (10% or 100% of ImageNet-1K) in `src/utils/constants.py`.
- Prepare the SVD weights for the teacher models. You can download the weights we provide (Download) or generate them with `scripts/gen_svd_weights.sh`:

  ```bash
  sh scripts/gen_svd_weights.sh imgnet_128k ex_gen_svd 0
  ```

- Scripts for pre-training with KDEP are in `scripts/`. For example, you can use the teacher-student pair Microsoft ResNet50 -> ResNet18 with `scripts/KDEP_MS-R50_R18.sh`:

  ```bash
  sh scripts/KDEP_MS-R50_R18.sh imgnet_128k exp_name 90 30 5e-4 0,1,2,3
  ### imgnet_128k or imgnet_full selects 10% or 100% of ImageNet-1K data
  ### 90 is the number of epochs, 30 is the step-lr
  ### 5e-4 is the weight decay
  ### 0,1,2,3 are the GPU ids
  ```

  You can run KDEP with different data amounts and training schedules by changing the data name (imgnet_128k or imgnet_full), the number of epochs, the step-lr, and the weight decay.

Note that we do not generate the SVD weights for 100% ImageNet-1K data, but directly use the SVD weights generated from the 10% data.
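A class-balanced subset like the one `scripts/make-imgnet-subset.py` produces can be sketched as sampling a fixed fraction of images per class and symlinking them into a new root. This is a hypothetical illustration; the real script's logic and options may differ:

```python
import os
import random

def make_subset(src_train, dst_train, fraction=0.1, seed=0):
    """Symlink `fraction` of each class's images from src_train into dst_train."""
    rng = random.Random(seed)              # fixed seed for a reproducible subset
    for cls in sorted(os.listdir(src_train)):
        cls_dir = os.path.join(src_train, cls)
        if not os.path.isdir(cls_dir):
            continue
        images = sorted(os.listdir(cls_dir))
        keep = rng.sample(images, max(1, int(len(images) * fraction)))
        out_dir = os.path.join(dst_train, cls)
        os.makedirs(out_dir, exist_ok=True)
        for name in keep:                  # symlinks avoid duplicating image data
            os.symlink(os.path.abspath(os.path.join(cls_dir, name)),
                       os.path.join(out_dir, name))
```

Sampling per class (rather than globally) keeps the 10% subset's class distribution matched to the full dataset.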
## Transfer learning experiments

### Image classification

- We use four image classification tasks: CIFAR100, DTD, Caltech256, CUB-200.
- Scripts (`scripts/TL_img-cls_R18.sh` and `scripts/TL_img-cls_mnv2.sh`) are provided for running all four tasks twice for a distilled student (R18/mnv2):

  ```bash
  sh scripts/TL_img-cls_R18.sh exp_name
  # note the exp_name here should be identical to that of the distilled student
  ```
### Semantic segmentation

- We use three semantic segmentation tasks: Cityscapes, VOC2012, ADE20K.
- Transform the checkpoint into the segmentation code format with `src/transform_ckpt_custom2seg.py`:

  ```bash
  cd src
  python3 transform_ckpt_custom2seg.py exp_name
  # note the exp_name here should be identical to that of the distilled student
  ```

  Then move the transformed checkpoint to `semseg/initmodel/`.
- Scripts (`semseg/tool/TL_seg_R18.sh` and `semseg/tool/TL_seg_mnv2.sh`) are provided for running all three tasks twice for a distilled student (R18/mnv2):

  ```bash
  cd semseg
  sh tool/TL_seg_R18.sh ckpt_name
  # note the ckpt_name should be what you put into semseg/initmodel/ in step 1
  ```
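A checkpoint-format transform like `src/transform_ckpt_custom2seg.py` boils down to remapping state-dict keys from the training code's naming to the downstream framework's naming. The prefixes below (`module.model.`, `layer0.`) are illustrative placeholders, not KDEP's actual key mapping, and plain dicts stand in for a torch state dict so the sketch runs anywhere:

```python
def remap_state_dict(state_dict, strip="module.model.", prefix="layer0."):
    """Rename checkpoint keys: drop a training-time wrapper prefix, add a new one."""
    out = {}
    for key, value in state_dict.items():
        if key.startswith(strip):          # remove e.g. DataParallel/wrapper prefix
            key = key[len(strip):]
        out[prefix + key] = value          # add the downstream framework's prefix
    return out

ckpt = {"module.model.conv1.weight": 0, "module.model.bn1.bias": 1}
print(remap_state_dict(ckpt))
# {'layer0.conv1.weight': 0, 'layer0.bn1.bias': 1}
```

The detection transform (`src/transform_ckpt_custom2det.py`) follows the same idea with Detectron2's naming convention as the target.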
### Object detection

- We use two object detection tasks: COCO and VOC.
- Transform the checkpoint into Detectron2 format with `src/transform_ckpt_custom2det.py`:

  ```bash
  cd src
  python3 transform_ckpt_custom2det.py exp_name R18
  # note the exp_name here should be identical to that of the distilled student
  # R18 could be changed to mnv2
  ```

  Then move the transformed checkpoint to `detectron2/ckpts/`.
- Install Detectron2 and export the dataset path:

  ```bash
  python3 -m pip install -e detectron2
  export DETECTRON2_DATASETS='path/to/datasets'
  ```

- Scripts (`detectron2/tool/TL_det_R18.sh` and `detectron2/tool/TL_det_mnv2.sh`) are provided for running both tasks twice for a distilled student (R18/mnv2):

  ```bash
  cd detectron2/tool
  sh TL_det_R18.sh ckpt_name
  # note the ckpt_name should be what you put into detectron2/ckpts/ in step 1
  ```
## Distilled models of KDEP
We provide some distilled models of KDEP here.
- (Download) ResNet18, KDEP(SVD+PTS) from MS-R50 teacher on 100% ImageNet-1K data for 90 epochs.
- (Download) MobileNet-V2, KDEP(SVD+PTS) from MS-R50 teacher on 100% ImageNet-1K data for 90 epochs.
## Acknowledgement

Our code is mainly based on robust-models-transfer. We also thank the authors of the open-source code of PyTorch Semantic Segmentation and Detectron2.
