
Swav

PyTorch implementation of SwAV: https://arxiv.org/abs/2006.09882

Unsupervised Learning of Visual Features by Contrasting Cluster Assignments

This code provides a PyTorch implementation and pretrained models for SwAV (Swapping Assignments between Views), as described in the paper Unsupervised Learning of Visual Features by Contrasting Cluster Assignments.

<div align="center"> <img width="100%" alt="SwAV Illustration" src="https://dl.fbaipublicfiles.com/deepcluster/animated.gif"> </div>

SwAV is an efficient and simple method for pre-training convnets without using annotations. Like contrastive approaches, SwAV learns representations by comparing transformations of an image, but unlike contrastive methods, it does not require computing pairwise feature comparisons. This makes our framework more efficient, since it does not need a large memory bank or an auxiliary momentum network. Specifically, our method simultaneously clusters the data while enforcing consistency between the cluster assignments produced for different augmentations (or “views”) of the same image, instead of comparing features directly. Simply put, we use a “swapped” prediction mechanism in which we predict the cluster assignment of one view from the representation of another view. Our method can be trained with large and small batches and can scale to unlimited amounts of data.
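The swapped prediction idea can be sketched in a few lines. The example below is a minimal illustration in plain Python lists so that it runs without PyTorch; it is not the repository's implementation, which operates on GPU tensors. The function names (`sinkhorn`, `swapped_loss`), the toy scores, and the values `epsilon=0.05`, `temperature=0.1`, and `n_iters=3` are assumptions chosen for illustration. Each row of a score matrix is assumed to hold the similarities between one view's feature and the prototypes (cluster centres).

```python
import math


def softmax(row, temperature):
    """Softmax over prototype scores, sharpened by a temperature."""
    top = max(row)
    exps = [math.exp((v - top) / temperature) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def sinkhorn(scores, epsilon=0.05, n_iters=3):
    """Turn B x K prototype scores into soft codes.

    Each row (sample) of the result sums to 1, while the alternating
    normalisation pushes every prototype towards equal total mass.
    """
    n_samples, n_protos = len(scores), len(scores[0])
    top = max(max(row) for row in scores)  # subtract the max for stability
    q = [[math.exp((v - top) / epsilon) for v in row] for row in scores]
    for _ in range(n_iters):
        # Give each prototype (column) a total mass of 1/K.
        for k in range(n_protos):
            col = sum(q[i][k] for i in range(n_samples))
            for i in range(n_samples):
                q[i][k] /= col * n_protos
        # Give each sample (row) a total mass of 1/B.
        for i in range(n_samples):
            row = sum(q[i])
            q[i] = [v / (row * n_samples) for v in q[i]]
    return [[v * n_samples for v in row] for row in q]


def swapped_loss(scores_a, scores_b, temperature=0.1):
    """Cross-entropy between the code of one view and the prediction of the other."""
    codes_a, codes_b = sinkhorn(scores_a), sinkhorn(scores_b)
    total = 0.0
    for i in range(len(scores_a)):
        p_a = softmax(scores_a[i], temperature)
        p_b = softmax(scores_b[i], temperature)
        total -= sum(q * math.log(p) for q, p in zip(codes_a[i], p_b))
        total -= sum(q * math.log(p) for q, p in zip(codes_b[i], p_a))
    return total / (2 * len(scores_a))


# Toy scores: 3 images, 3 prototypes, two views per image.
view_a = [[0.9, 0.1, -0.2], [0.0, 0.8, 0.1], [-0.1, 0.2, 0.7]]
view_b = [[0.8, 0.0, -0.1], [0.1, 0.9, 0.0], [0.0, 0.1, 0.8]]
mismatched_b = [view_b[1], view_b[2], view_b[0]]  # views paired with the wrong image

print(swapped_loss(view_a, view_b))        # views agree: small loss
print(swapped_loss(view_a, mismatched_b))  # views disagree: much larger loss
```

When the two views of an image score highest on the same prototype, the loss is small; pairing a view with the wrong image makes the predicted distribution disagree with the code and the loss grows. No pairwise comparison between features is needed, only comparisons against the prototypes.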

Model Zoo

We release several models pre-trained with SwAV in the hope that other researchers might also benefit from replacing the ImageNet supervised network with a SwAV backbone. To load our best SwAV pre-trained ResNet-50 model, simply do:

import torch
model = torch.hub.load('facebookresearch/swav:main', 'resnet50')

We provide several baseline SwAV pre-trained models with the ResNet-50 architecture in torchvision format. We also provide models pre-trained with DeepCluster-v2 and SeLa-v2, obtained by applying improvements from the self-supervised community to DeepCluster and SeLa (see details in the appendix of our paper).

| method | epochs | batch-size | multi-crop | ImageNet top-1 acc. (%) | url | args |
|-------------------|-------------------|---------------------|--------------------|--------------------|--------------------|--------------------|
| SwAV | 800 | 4096 | 2x224 + 6x96 | 75.3 | model | script |
| SwAV | 400 | 4096 | 2x224 + 6x96 | 74.6 | model | script |
| SwAV | 200 | 4096 | 2x224 + 6x96 | 73.9 | model | script |
| SwAV | 100 | 4096 | 2x224 + 6x96 | 72.1 | model | script |
| SwAV | 200 | 256 | 2x224 + 6x96 | 72.7 | model | script |
| SwAV | 400 | 256 | 2x224 + 6x96 | 74.3 | model | script |
| SwAV | 400 | 4096 | 2x224 | 70.1 | model | script |
| DeepCluster-v2 | 800 | 4096 | 2x224 + 6x96 | 75.2 | model | script |
| DeepCluster-v2 | 400 | 4096 | 2x160 + 4x96 | 74.3 | model | script |
| DeepCluster-v2 | 400 | 4096 | 2x224 | 70.2 | model | script |
| SeLa-v2 | 400 | 4096 | 2x160 + 4x96 | 71.8 | model | - |
| SeLa-v2 | 400 | 4096 | 2x224 | 67.2 | model | - |
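In the multi-crop column, an entry such as 2x224 + 6x96 means two 224×224 crops plus six 96×96 crops per image. As a rough illustration of why the small crops are cheap, the sketch below counts the pixels processed per source image for each setting. Pixel count is only a crude proxy for compute, and the helper name `pixels_per_image` is hypothetical.

```python
def pixels_per_image(crops):
    """Total pixels processed per source image for a list of (count, side) crops."""
    return sum(count * side * side for count, side in crops)


two_crops = pixels_per_image([(2, 224)])
multi_crop = pixels_per_image([(2, 224), (6, 96)])
reduced = pixels_per_image([(2, 160), (4, 96)])

print(two_crops)                         # 100352
print(multi_crop)                        # 155648
print(reduced)                           # 88064
print(round(multi_crop / two_crops, 2))  # 1.55
print(round(reduced / two_crops, 2))     # 0.88
```

Under this proxy, going from two views to eight views costs about 1.55 times the pixels of the 2x224 setting, and the 2x160 + 4x96 setting processes fewer pixels than 2x224 alone.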

Larger architectures

We provide SwAV models with ResNet-50 networks whose width is multiplied by a factor of 2, 4, or 5. To load the corresponding backbone, you can use:

import torch
rn50w2 = torch.hub.load('facebookresearch/swav:main', 'resnet50w2')
rn50w4 = torch.hub.load('facebookresearch/swav:main', 'resnet50w4')
rn50w5 = torch.hub.load('facebookresearch/swav:main', 'resnet50w5')

| network | parameters | epochs | ImageNet top-1 acc. (%) | url | args |
|-------------------|---------------------|--------------------|--------------------|--------------------|--------------------|
| RN50-w2 | 94M | 400 | 77.3 | model | script |
| RN50-w4 | 375M | 400 | 77.9 | model | script |
| RN50-w5 | 586M | 400 | 78.5 | model | - |
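The parameter counts in the table grow roughly with the square of the width multiplier, which is what one would expect when both the input and output channels of each convolution are scaled. The short check below is an observation on the reported numbers, not a statement about how the counts were computed; `quadratic_estimate` is a hypothetical helper that extrapolates from the RN50-w2 row.

```python
# Width multiplier -> parameters in millions, copied from the table.
reported = {2: 94, 4: 375, 5: 586}


def quadratic_estimate(width, base_width=2):
    """Extrapolate from the base row assuming parameters scale with width squared."""
    return reported[base_width] * (width / base_width) ** 2


for width, params in reported.items():
    print(f"x{width}: reported {params}M, quadratic estimate {quadratic_estimate(width):.1f}M")
```

The estimates (376.0M for ×4 and 587.5M for ×5) land within about 2M of the reported values.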

Running times

We provide the running times for some of our runs:

| method | batch-size | multi-crop | scripts | time per epoch |
|---------------------|--------------------|--------------------|--------------------|--------------------|
| SwAV | 4096 | 2x224 + 6x96 | * * * * | 3min40s |
| SwAV | 256 | 2x224 + 6x96 | * * | 52min10s |
| DeepCluster-v2 | 4096 | 2x160 + 4x96 | * | 3min13s |
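To get a feel for the total cost of a run, the per-epoch times can be multiplied by an epoch count. The sketch below pairs each timing with an epoch count taken from the Model Zoo table (800, 400, and 400 epochs); that pairing is an assumption, and the estimate ignores any overhead outside the epochs themselves, such as checkpointing or job start-up. The helper `total_hours` is hypothetical.

```python
def total_hours(epochs, minutes, seconds):
    """Wall-clock hours for a run, given the time per epoch."""
    return epochs * (minutes * 60 + seconds) / 3600


print(round(total_hours(800, 3, 40), 1))   # SwAV, batch 4096: 48.9
print(round(total_hours(400, 52, 10), 1))  # SwAV, batch 256: 347.8
print(round(total_hours(400, 3, 13), 1))   # DeepCluster-v2, batch 4096: 21.4
```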

Running SwAV unsupervised training

Requirements

Single-node training

SwAV is very simple to implement and experiment with. Our implementation consists of a main_swav.py file, which imports the dataset definition from src/multicropdataset.py, the model architecture from src/resnet50.py, and some miscellaneous training utilities from src/utils.py.

For example, to train the SwAV baseline on a single node with 8 GPUs for 400 epochs, run:

python -m torch.distributed.launch --nproc_per_node=8 main_swav.py \
--data_path /path/to/imagenet/train \
--epochs 400 \
--base_lr 0.6 \
--final_lr 0.0006 \
--warmup_epochs 0 \
--batch_size 32 \
--size_crops 224 96 \
--nmb_crops 2 6 \
--min_scale_crops 0.14 0.05 \
--max_scale_crops 1. 0.14 \
--use_fp16 true \
--freeze_prototypes_niters 5005 \
--queue_length 3840 \
--epoch_queue_starts 15
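As a rough sanity check on the flags above, the arithmetic below relates them to one another. It assumes that `--batch_size` is the per-GPU batch size (which matches the batch-size 256 rows of the Model Zoo table) and that training uses the standard ImageNet-1k training split of 1,281,167 images; both are assumptions, not statements taken from the script.

```python
import math

IMAGENET_TRAIN_IMAGES = 1_281_167  # assumption: standard ImageNet-1k training split

gpus, per_gpu_batch = 8, 32        # --nproc_per_node and --batch_size
queue_length = 3840                # --queue_length

global_batch = gpus * per_gpu_batch
iters_per_epoch = math.ceil(IMAGENET_TRAIN_IMAGES / global_batch)
queue_batches = queue_length // global_batch

print(global_batch)     # 256
print(iters_per_epoch)  # 5005
print(queue_batches)    # 15
```

Under these assumptions, `--freeze_prototypes_niters 5005` corresponds to about one epoch (one iteration fewer if the last partial batch is dropped), and the queue of 3840 features holds exactly 15 global batches.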

Multi-node training

Distributed training is available via Slurm. We provide several SBATCH scripts to reproduce our SwAV models. For example, to train SwAV on 8 nodes and 64 GPUs with a batch size of 4096 for 800 epochs, run:

sbatch ./scripts/swav_800ep_pretrain.sh

Note that you might need to remove the copyright header from the sbatch file to launch it.

Set up the dist_url parameter: we refer the user to the PyTorch distributed documentation (env, file, or tcp) for setting the distributed initialization method (parameter dist_url) correctly. In the provided sbatch files, we use the tcp init method.
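For reference, the three initialization methods mentioned above correspond to three URL forms accepted by PyTorch's distributed package. The helper below, `make_dist_url`, is hypothetical and only builds the string; the host, port, and path values are placeholders, not values used by the provided scripts.

```python
def make_dist_url(method, host=None, port=None, path=None):
    """Build an init-method URL of the form used by torch.distributed."""
    if method == "env":
        # Rank, world size, and master address are read from environment variables.
        return "env://"
    if method == "tcp":
        # All processes connect to the process with rank 0 at host:port.
        return f"tcp://{host}:{port}"
    if method == "file":
        # All processes coordinate through a file on a shared filesystem.
        return f"file://{path}"
    raise ValueError(f"unknown init method: {method}")


print(make_dist_url("env"))                               # env://
print(make_dist_url("tcp", host="10.1.1.20", port=23456)) # tcp://10.1.1.20:23456
print(make_dist_url("file", path="/mnt/nfs/sharedfile"))  # file:///mnt/nfs/sharedfile
```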

Evaluating models

Evaluate models: Linear classification on ImageNet

To train a supervised linear classifier on frozen features/weights on a single node with 8 GPUs, run:

python -m torch.distributed.launch --nproc_per_node=8 eval_linear.py \
--data_path /path/to/imagenet \
--pretrained /path/to/checkpoints/swav_800ep_pretrain.pth.tar

The resulting linear classifier can be downloaded here.

Evaluate models: Semi-supervised learning on ImageNet

To reproduce our results and fine-tune a network with 1% or 10% of ImageNet labels on a single node with 8 GPUs, run:

  • 10% labels
python -m torch.distributed.launch --nproc_per_node=8 eval_semisup.py \
--data_path /path/to/imagenet \
--pretrained /path/to/checkpoints/swav_800ep_pretrain.pth.tar \
--labels_perc "10" \
--lr 0.01 \
--lr_last_layer 0.2
  • 1% labels
python -m torch.distributed.launch --nproc_per_node=8 eval_semisup.py \
--data_path /path/to/imagenet \
--pretrained /path/to/checkpoints/swav_800ep_pretrain.pth.tar \
--labels_perc "1" \
--lr 0.02 \
--lr_last_layer 5
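The two commands above use one learning rate for the network (`--lr`) and a larger one for the last layer (`--lr_last_layer`). One common way to express this in PyTorch is with per-parameter-group options. The sketch below is not taken from eval_semisup.py: `build_param_groups` is a hypothetical helper, and it assumes the classifier's parameter names start with `fc.`, as in torchvision ResNets. Strings stand in for parameter tensors so that the example runs without PyTorch.

```python
def build_param_groups(named_params, lr, lr_last_layer, head_prefix="fc."):
    """Split (name, parameter) pairs into trunk and last-layer groups with their own lr."""
    trunk, head = [], []
    for name, param in named_params:
        (head if name.startswith(head_prefix) else trunk).append(param)
    return [
        {"params": trunk, "lr": lr},
        {"params": head, "lr": lr_last_layer},
    ]


# Stand-in parameters; with a real model, pass model.named_parameters() instead.
demo_params = [
    ("conv1.weight", "w0"),
    ("layer4.2.bn3.bias", "w1"),
    ("fc.weight", "w2"),
    ("fc.bias", "w3"),
]
groups = build_param_groups(demo_params, lr=0.01, lr_last_layer=0.2)
print([(len(g["params"]), g["lr"]) for g in groups])  # [(2, 0.01), (2, 0.2)]
```

With a real model, the returned list can be passed directly to an optimizer, for example `torch.optim.SGD(groups, lr=0.01, momentum=0.9)`, where the per-group `lr` values override the default.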

Evaluate models: Transferring to Detection with DETR

DETR is a recent object detection framework that reaches competitive performance with Faster R-CNN while being conceptually simpler and trainable end-to-end. We evaluate our SwAV ResNet-50 backbone on object detection on COCO dataset using DETR framework with full fine-tu
