CSI
CSI: Novelty Detection via Contrastive Learning on Distributionally Shifted Instances (NeurIPS 2020)
Official PyTorch implementation of "CSI: Novelty Detection via Contrastive Learning on Distributionally Shifted Instances" (NeurIPS 2020) by Jihoon Tack*, Sangwoo Mo*, Jongheon Jeong, and Jinwoo Shin.
<p align="center"> <img src=figures/shifting_transformations.png width="900"> </p>
1. Requirements
Environments
The code currently requires the following packages:
- python 3.6+
- torch 1.4+
- torchvision 0.5+
- CUDA 10.1+
- scikit-learn 0.22+
- tensorboard 2.0+
- torchlars == 0.1.2
- pytorch-gradual-warmup-lr packages
- apex == 0.1
- diffdist == 0.1
Datasets
For CIFAR, please download the following datasets to ~/data.
For ImageNet-30, please download the following datasets to ~/data.
- ImageNet-30-train, ImageNet-30-test
- CUB-200, Stanford Dogs, Oxford Pets, Oxford flowers, Food-101, Places-365, Caltech-256, DTD
For Food-101, remove the hotdog class to avoid overlap with the in-distribution classes.
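One way to drop a class from an image-folder-style dataset is to move its folder out of the dataset root. The sketch below is only an illustration: the helper `exclude_class`, the one-folder-per-class layout, and the folder name `hot_dog` are assumptions, not something this repository prescribes. The demo runs on a temporary directory so that nothing under ~/data is touched.

```python
import shutil
import tempfile
from pathlib import Path


def exclude_class(dataset_root, class_name):
    """Move one class folder out of a dataset laid out as <root>/<class>/.

    Hypothetical helper. Returns the new location of the folder,
    or None if no such class folder exists.
    """
    root = Path(dataset_root)
    src = root / class_name
    if not src.is_dir():
        return None
    # Park the folder next to the dataset root instead of deleting it,
    # so the step is easy to undo.
    dst_parent = root.parent / (root.name + "_excluded")
    dst_parent.mkdir(parents=True, exist_ok=True)
    dst = dst_parent / class_name
    if dst.exists():
        raise FileExistsError(f"{dst} already exists")
    shutil.move(str(src), str(dst))
    return dst


# Demo on a throwaway directory (folder names are assumptions).
demo_root = Path(tempfile.mkdtemp()) / "food101"
for name in ("hot_dog", "pizza"):
    (demo_root / name).mkdir(parents=True)

moved_to = exclude_class(demo_root, "hot_dog")
remaining = sorted(p.name for p in demo_root.iterdir())
print(moved_to, remaining)
```

Moving rather than deleting keeps the original images available if the class is needed again later.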
2. Training
All commands below assume a distributed launch on 4 GPUs.
To run the code on a single GPU, remove -m torch.distributed.launch --nproc_per_node=4 from the command.
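The difference between the two launch modes can be sketched with a small command builder. This is only an illustration: `build_train_command` is a hypothetical helper, the values `cifar10` and `resnet18` are placeholders for `<DATASET>` and `<NETWORK>`, and omitting `--one_class_idx` for multi-class training assumes that the option defaults to None.

```python
def build_train_command(num_gpus, dataset, model, mode="simclr_CSI",
                        batch_size=32, one_class_idx=None):
    """Assemble the train.py command line as a list of arguments (hypothetical helper)."""
    cmd = ["python"]
    if num_gpus > 1:
        # Multi-GPU runs go through the torch.distributed.launch module.
        cmd += ["-m", "torch.distributed.launch",
                "--nproc_per_node={}".format(num_gpus)]
    cmd += ["train.py",
            "--dataset", dataset,
            "--model", model,
            "--mode", mode,
            "--shift_trans_type", "rotation",
            "--batch_size", str(batch_size)]
    if one_class_idx is not None:
        # Assumption: leaving the flag out means multi-class training.
        cmd += ["--one_class_idx", str(one_class_idx)]
    return cmd


# Placeholder dataset and network names, chosen for illustration only.
multi_gpu = build_train_command(4, "cifar10", "resnet18", one_class_idx=0)
single_gpu = build_train_command(1, "cifar10", "resnet18", one_class_idx=0)

print(" ".join(multi_gpu))
print(" ".join(single_gpu))
```

The single-GPU command differs only in the missing launcher prefix; all train.py options stay the same.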
Unlabeled one-class & multi-class
To train the unlabeled one-class & multi-class models from the paper, run this command:
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 train.py --dataset <DATASET> --model <NETWORK> --mode simclr_CSI --shift_trans_type rotation --batch_size 32 --one_class_idx <One-Class-Index>
The --one_class_idx option selects the in-distribution class for one-class training. For multi-class training, set --one_class_idx to None. To run SimCLR, change --mode to simclr. The total batch size should be 512 = 4 (GPUs) * 32 (--batch_size option) * 4 (cardinality of the shifted transformation set).
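The batch size arithmetic can be checked with a few lines of Python. This is a minimal sketch: both helper functions are hypothetical, and the value computed for a single GPU only follows from the same arithmetic; it is not a setting documented in this repository.

```python
def effective_batch_size(num_gpus, per_gpu_batch_size, num_shifts):
    """Total number of samples per step across GPUs and shifted copies."""
    return num_gpus * per_gpu_batch_size * num_shifts


def per_gpu_batch_size_for(total, num_gpus, num_shifts):
    """Value of --batch_size needed to reach a given total batch size."""
    quotient, remainder = divmod(total, num_gpus * num_shifts)
    if remainder:
        raise ValueError("total is not divisible by num_gpus * num_shifts")
    return quotient


# 4 GPUs, --batch_size 32, 4 shifting transformations (rotation).
total = effective_batch_size(num_gpus=4, per_gpu_batch_size=32, num_shifts=4)
print(total)

# Same total on a single GPU, by the same arithmetic.
single_gpu_batch = per_gpu_batch_size_for(total, num_gpus=1, num_shifts=4)
print(single_gpu_batch)
```

Whether a larger per-GPU batch fits in memory depends on the GPU and the network.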
Labeled multi-class
To train the labeled multi-class model (confidence-calibrated classifier) from the paper, run these commands:
# Representation training
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 train.py --dataset <DATASET> --model <NETWORK> --mode sup_simclr_CSI --shift_trans_type rotation --batch_size 32 --epoch 700
# Linear layer training
python train.py --mode sup_CSI_linear --dataset <DATASET> --model <NETWORK> --batch_size 32 --epoch 100 --shift_trans_type rotation --load_path <MODEL_PATH>
To run SupCLR, change --mode to sup_simclr for representation training and to sup_linear for linear layer training. The total batch size should be the same as above. Currently, only rotation is supported as the shifting transformation.
3. Evaluation
We provide checkpoints of the pre-trained CSI models. Download the checkpoints from the following links:
- One-class CIFAR-10: ResNet-18
- Unlabeled (multi-class) CIFAR-10: ResNet-18
- Unlabeled (multi-class) ImageNet-30: ResNet-18
- Labeled (multi-class) CIFAR-10: ResNet-18
Unlabeled one-class & multi-class
To evaluate a model in the unlabeled one-class & multi-class out-of-distribution (OOD) detection setting, run this command:
python eval.py --mode ood_pre --dataset <DATASET> --model <NETWORK> --ood_score CSI --shift_trans_type rotation --print_score --ood_samples 10 --resize_factor 0.54 --resize_fix --one_class_idx <One-Class-Index> --load_path <MODEL_PATH>
The --one_class_idx option selects the in-distribution class for one-class evaluation. For multi-class evaluation, set --one_class_idx to None. The --resize_factor and --resize_fix options fix the crop size of RandomResizedCrop(). For SimCLR evaluation, change --ood_score to simclr.
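The effect of the two resize options can be sketched as follows. This is an assumption about how the options map onto the `scale` argument of torchvision's RandomResizedCrop (the lower and upper bound of the cropped area fraction), not a copy of the repository's code: `--resize_factor` is taken as the lower bound, and `--resize_fix` is taken to collapse the range to that single value. `crop_scale_range` is a hypothetical helper.

```python
import math


def crop_scale_range(resize_factor, resize_fix):
    """Return a (min, max) area fraction for RandomResizedCrop's `scale`.

    Assumed mapping: without resize_fix the crop area varies between
    resize_factor and the full image; with resize_fix it is constant.
    """
    if resize_fix:
        return (resize_factor, resize_factor)
    return (resize_factor, 1.0)


fixed = crop_scale_range(0.54, resize_fix=True)
free = crop_scale_range(0.54, resize_fix=False)
print(fixed, free)

# For a 32x32 image and a square crop, a fixed area fraction of 0.54
# corresponds to a crop side of sqrt(0.54) * 32 pixels before resizing.
crop_side = math.sqrt(fixed[0]) * 32
print(round(crop_side, 1))
```

With a fixed scale, the remaining randomness of the crop comes from its position and aspect ratio.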
Labeled multi-class
To evaluate a model on labeled multi-class accuracy, ECE, and OOD detection, run these commands:
# OOD AUROC
python eval.py --mode ood --ood_score baseline_marginalized --print_score --dataset <DATASET> --model <NETWORK> --shift_trans_type rotation --load_path <MODEL_PATH>
# Accuracy & ECE
python eval.py --mode test_marginalized_acc --dataset <DATASET> --model <NETWORK> --shift_trans_type rotation --load_path <MODEL_PATH>
The commands above use marginalized inference. For single inference (also used for SupCLR), change --ood_score to baseline in the first command and --mode to test_acc in the second command.
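For readers unfamiliar with the two metrics, the following plain-Python sketch shows their standard definitions. It is not the repository's evaluation code: the function names, the convention that a higher score means "more in-distribution", and the choice of 10 equal-width confidence bins for ECE are assumptions.

```python
def auroc(in_scores, ood_scores):
    """Probability that a random in-distribution sample scores higher
    than a random OOD sample (ties count as one half)."""
    wins = 0.0
    for s_in in in_scores:
        for s_out in ood_scores:
            if s_in > s_out:
                wins += 1.0
            elif s_in == s_out:
                wins += 0.5
    return wins / (len(in_scores) * len(ood_scores))


def expected_calibration_error(confidences, correct, num_bins=10):
    """Weighted average gap between confidence and accuracy per bin."""
    bins = [[] for _ in range(num_bins)]
    for conf, ok in zip(confidences, correct):
        # Equal-width bins over [0, 1]; confidence 1.0 goes to the last bin.
        index = min(int(conf * num_bins), num_bins - 1)
        bins[index].append((conf, ok))
    total = len(confidences)
    ece = 0.0
    for members in bins:
        if not members:
            continue
        avg_conf = sum(conf for conf, _ in members) / len(members)
        accuracy = sum(1 for _, ok in members if ok) / len(members)
        ece += len(members) / total * abs(avg_conf - accuracy)
    return ece


# Toy values, made up purely for illustration.
toy_auroc = auroc([0.9, 0.8, 0.7], [0.1, 0.75, 0.2])
toy_ece = expected_calibration_error([0.95, 0.95, 0.55, 0.55],
                                     [True, True, True, False])
print(toy_auroc, toy_ece)
```

The pairwise AUROC loop is quadratic in the number of samples; for full test sets a rank-based implementation such as the one in scikit-learn is the practical choice.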
4. Results
Our model achieves the following performance on:
One-Class Out-of-Distribution Detection
| Method | Dataset | AUROC (Mean) |
| --------------|------------------ | --------------|
| SimCLR | CIFAR-10-OC | 87.9% |
| Rot+Trans | CIFAR-10-OC | 90.0% |
| CSI (ours) | CIFAR-10-OC | 94.3% |
We only show the CIFAR-10 one-class results in this repo. For other settings, please see our paper.
Unlabeled Multi-Class Out-of-Distribution Detection
| Method | Dataset | OOD Dataset | AUROC (Mean) |
| --------------|------------------ |---------------|--------------|
| Rot+Trans | CIFAR-10 | CIFAR-100 | 82.5% |
| CSI (ours) | CIFAR-10 | CIFAR-100 | 89.3% |
We only show the CIFAR-10 to CIFAR-100 OOD detection results in this repo. For other OOD dataset results, see our paper.
Labeled Multi-Class Result
| Method | Dataset | OOD Dataset | Acc | ECE | AUROC (Mean) |
| ---------------- |------------------ |---------------|-------|-------|--------------|
| SupCLR | CIFAR-10 | CIFAR-100 | 93.9% | 5.54% | 88.3% |
| CSI (ours) | CIFAR-10 | CIFAR-100 | 94.8% | 4.24% | 90.6% |
| CSI-ensem (ours) | CIFAR-10 | CIFAR-100 | 96.0% | 3.64% | 92.3% |
We only show CIFAR-10 with CIFAR-100 as OOD in this repo. For other dataset results, please see our paper.
5. New OOD dataset
<p align="center"> <img src=figures/fixed_ood_benchmarks.png width="600"> </p>
We find that current benchmark datasets for OOD detection are visually far from in-distribution datasets (e.g. CIFAR).
To address this issue, we provide new datasets for OOD detection evaluation: LSUN_fix and ImageNet_fix. See the figure above for a visual comparison of the current benchmarks and our datasets.
To generate the OOD datasets, run the following scripts inside the ./datasets folder:
# ImageNet FIX generation code
python imagenet_fix_preprocess.py
# LSUN FIX generation code
python lsun_fix_preprocess.py
Citation
@inproceedings{tack2020csi,
title={CSI: Novelty Detection via Contrastive Learning on Distributionally Shifted Instances},
author={Jihoon Tack and Sangwoo Mo and Jongheon Jeong and Jinwoo Shin},
booktitle={Advances in Neural Information Processing Systems},
year={2020}
}
