SimCLR - A Simple Framework for Contrastive Learning of Visual Representations

SimCLRv2 - Big Self-Supervised Models are Strong Semi-Supervised Learners
<span style="color: red"><strong>News! </strong></span> We have released a TF2 implementation of SimCLR (along with converted checkpoints in TF2), they are in <a href="tf2/">tf2/ folder</a>.
<span style="color: red"><strong>News! </strong></span> Colabs for <a href="https://arxiv.org/abs/2011.02803">Intriguing Properties of Contrastive Losses</a> are added, see <a href="colabs/intriguing_properties/">here</a>.
<div align="center"> <img width="50%" alt="SimCLR Illustration" src="https://1.bp.blogspot.com/--vH4PKpE9Yo/Xo4a2BYervI/AAAAAAAAFpM/vaFDwPXOyAokAC8Xh852DzOgEs22NhbXwCLcBGAsYHQ/s1600/image4.gif"> </div> <div align="center"> An illustration of SimCLR (from <a href="https://ai.googleblog.com/2020/04/advancing-self-supervised-and-semi.html">our blog here</a>). </div>Pre-trained models for SimCLRv2
<a href="colabs/finetuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
We have open-sourced a total of 65 pretrained models here, corresponding to those in Table 1 of the <a href="https://arxiv.org/abs/2006.10029">SimCLRv2</a> paper:
| Depth | Width | SK | Param (M) | F-T (1%) | F-T (10%) | F-T (100%) | Linear eval | Supervised |
|------:|------:|:-----|----------:|---------:|----------:|-----------:|------------:|-----------:|
| 50 | 1X | False | 24 | 57.9 | 68.4 | 76.3 | 71.7 | 76.6 |
| 50 | 1X | True | 35 | 64.5 | 72.1 | 78.7 | 74.6 | 78.5 |
| 50 | 2X | False | 94 | 66.3 | 73.9 | 79.1 | 75.6 | 77.8 |
| 50 | 2X | True | 140 | 70.6 | 77.0 | 81.3 | 77.7 | 79.3 |
| 101 | 1X | False | 43 | 62.1 | 71.4 | 78.2 | 73.6 | 78.0 |
| 101 | 1X | True | 65 | 68.3 | 75.1 | 80.6 | 76.3 | 79.6 |
| 101 | 2X | False | 170 | 69.1 | 75.8 | 80.7 | 77.0 | 78.9 |
| 101 | 2X | True | 257 | 73.2 | 78.8 | 82.4 | 79.0 | 80.1 |
| 152 | 1X | False | 58 | 64.0 | 73.0 | 79.3 | 74.5 | 78.3 |
| 152 | 1X | True | 89 | 70.0 | 76.5 | 81.3 | 77.2 | 79.9 |
| 152 | 2X | False | 233 | 70.2 | 76.6 | 81.1 | 77.4 | 79.1 |
| 152 | 2X | True | 354 | 74.2 | 79.4 | 82.9 | 79.4 | 80.4 |
| 152 | 3X | True | 795 | 74.9 | 80.1 | 83.1 | 79.8 | 80.5 |
These checkpoints are stored in Google Cloud Storage:
- Pretrained SimCLRv2 models (with linear eval head): gs://simclr-checkpoints/simclrv2/pretrained
- Fine-tuned SimCLRv2 models on 1% of labels: gs://simclr-checkpoints/simclrv2/finetuned_1pct
- Fine-tuned SimCLRv2 models on 10% of labels: gs://simclr-checkpoints/simclrv2/finetuned_10pct
- Fine-tuned SimCLRv2 models on 100% of labels: gs://simclr-checkpoints/simclrv2/finetuned_100pct
- Supervised models with the same architectures: gs://simclr-checkpoints/simclrv2/supervised
- Distilled / self-trained models (after fine-tuning) are also provided.
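As an illustration, here is a small hypothetical helper that constructs the GCS path for a row of the table above, assuming the r{depth}_{width}x_sk{0|1} subdirectory naming used in the colabs (the helper name and interface are ours, not part of the repo):

```python
def checkpoint_path(depth, width, sk, variant="pretrained"):
    """Build the GCS path for a SimCLRv2 checkpoint.

    depth:   ResNet depth (50, 101, 152)
    width:   width multiplier (1, 2, 3)
    sk:      whether selective kernels are used (True/False)
    variant: "pretrained", "finetuned_1pct", "finetuned_10pct",
             "finetuned_100pct", or "supervised"
    """
    bucket = "gs://simclr-checkpoints/simclrv2"
    return f"{bucket}/{variant}/r{depth}_{width}x_sk{int(sk)}"
```

For example, gsutil ls on checkpoint_path(50, 1, False) would list the files for the first row of the table.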
We also provide examples of how to use the checkpoints in the colabs/ folder.
Pre-trained models for SimCLRv1
The pre-trained models (base network with linear classifier layer) can be found below. Note that for these SimCLRv1 checkpoints, the projection head is not available.
| Model checkpoint and hub-module | ImageNet Top-1 | |-----------------------------------------------------------------------------------------|------------------------| |ResNet50 (1x) | 69.1 | |ResNet50 (2x) | 74.2 | |ResNet50 (4x) | 76.6 |
Additional SimCLRv1 checkpoints are available: gs://simclr-checkpoints/simclrv1.
A note on the signatures of the TensorFlow Hub module: default is the representation output of the base network; logits_sup is the supervised classification logits for ImageNet 1000 categories. Others (e.g. initial_max_pool, block_group1) are middle layers of ResNet; refer to resnet.py for the specifics. See this tutorial for additional information regarding use of TensorFlow Hub modules.
Environment setup
Our models are trained with TPUs. It is recommended to run distributed training with TPUs when using our code for pretraining.
Our code can also run on a single GPU. It does not support multi-GPU training, due to features such as global BatchNorm and the contrastive loss being computed across cores.
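For intuition on why cross-device statistics matter, here is a small NumPy sketch (the helper names are ours, purely illustrative) contrasting per-device and global batch statistics. With global BatchNorm, normalization statistics are computed over the full batch rather than each device's shard:

```python
import numpy as np

def per_device_means(batch, num_devices):
    # Each device normalizes with statistics from its own shard only.
    shards = np.split(batch, num_devices)
    return [s.mean() for s in shards]

def global_mean(batch):
    # Global BatchNorm uses statistics computed across all devices.
    return batch.mean()
```

When the shards see different data, the per-device statistics diverge from the global ones, which is one reason naive data-parallel training changes the behavior of this code.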
The code is compatible with both TensorFlow v1 and v2. See requirements.txt for all prerequisites; you can install them with the following command.
pip install -r requirements.txt
Pretraining
To pretrain the model on CIFAR-10 with a single GPU, try the following command:
python run.py --train_mode=pretrain \
--train_batch_size=512 --train_epochs=1000 \
--learning_rate=1.0 --weight_decay=1e-4 --temperature=0.5 \
--dataset=cifar10 --image_size=32 --eval_split=test --resnet_depth=18 \
--use_blur=False --color_jitter_strength=0.5 \
--model_dir=/tmp/simclr_test --use_tpu=False
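For intuition, the objective this pretraining command optimizes is the NT-Xent (normalized temperature-scaled cross-entropy) contrastive loss. Below is a minimal NumPy sketch with illustrative names and shapes; it is not the repo's actual TPU implementation:

```python
import numpy as np

def nt_xent_loss(z1, z2, temperature=0.5):
    """z1, z2: (N, D) projections of two augmented views of the same N images."""
    z = np.concatenate([z1, z2], axis=0)              # (2N, D)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)  # L2-normalize
    sim = z @ z.T / temperature                       # (2N, 2N) scaled cosine similarities
    np.fill_diagonal(sim, -np.inf)                    # mask self-similarity
    n = z1.shape[0]
    # The positive for row i is its other view: i+N (first half) or i-N (second half).
    pos = np.concatenate([np.arange(n) + n, np.arange(n)])
    # Cross-entropy over each row, with the positive pair as the target.
    logsumexp = np.log(np.exp(sim).sum(axis=1))
    loss = -(sim[np.arange(2 * n), pos] - logsumexp)
    return loss.mean()
```

Each example's other augmented view is the positive; all other 2N-2 examples in the batch act as negatives, which is why large batch sizes (and cross-core loss computation) help.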
To pretrain the model on ImageNet with Cloud TPUs, first check out the Google Cloud TPU tutorial for basic information on how to use Google Cloud TPUs.
Once you have created a virtual machine with Cloud TPUs and pre-downloaded the ImageNet data for tensorflow_datasets, please set the following environment variables:
TPU_NAME=<tpu-name>
STORAGE_BUCKET=gs://<storage-bucket>
DATA_DIR=$STORAGE_BUCKET/<path-to-tensorflow-dataset>
MODEL_DIR=$STORAGE_BUCKET/<path-to-store-checkpoints>
The following command can be used to pretrain a ResNet-50 on ImageNet (which reflects the default hyperparameters in our paper):
python run.py --train_mode=pretrain \
--train_batch_size=4096 --train_epochs=100 --temperature=0.1 \
--learning_rate=0.075 --learning_rate_scaling=sqrt --weight_decay=1e-4 \
--dataset=imagenet2012 --image_size=224 --eval_split=validation \
--data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
--use_tpu=True --tpu_name=$TPU_NAME --train_summary_steps=0
A batch size of 4096 requires at least 32 TPUs. 100 epochs takes around 6 hours with 32 TPU v3s. Note that a learning rate of 0.3 with learning_rate_scaling=linear is equivalent to one of 0.075 with learning_rate_scaling=sqrt when the batch size is 4096. However, sqrt scaling tends to train better when a smaller batch size is used.
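The two scaling rules can be sketched as follows (a hypothetical helper mirroring the formulas in the paper: linear scaling uses base_lr × batch_size / 256, sqrt scaling uses base_lr × √batch_size):

```python
import math

def scaled_lr(base_lr, batch_size, scaling):
    """Effective learning rate under the two scaling rules from the paper."""
    if scaling == "linear":
        return base_lr * batch_size / 256.0
    if scaling == "sqrt":
        return base_lr * math.sqrt(batch_size)
    raise ValueError(f"unknown scaling: {scaling}")
```

With base rates 0.3 (linear) and 0.075 (sqrt), both rules give an effective rate of 4.8 at batch size 4096, matching the equivalence noted above; at smaller batch sizes the two rules diverge.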
Finetuning the linear head (linear eval)
To fine-tune a linear head (with a single GPU), try the following command:
python run.py --mode=train_then_eval --train_mode=finetune \
--fine_tune_after_block=4 --zero_init_logits_layer=True \
--variable_schema='(?!global_step|(?:.*/|^)Momentum|head)' \
--global_bn=False --optimizer=momentum --learning_rate=0.1 --weight_decay=0.0 \
--train_epochs=100 --train_batch_size=512 --warmup_epochs=0 \
--dataset=cifar10 --image_size=32 --eval_split=test --resnet_depth=18 \
--checkpoint=/tmp/simclr_test --model_dir=/tmp/simclr_test_ft --use_tpu=False
You can check the results using TensorBoard, e.g.
python -m tensorboard.main --logdir=/tmp/simclr_test
As a reference, the above runs on CIFAR-10 should give you around 91% accuracy, though it can be further optimized.
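One motivation for the zero_init_logits_layer=True flag used above: with a zero-initialized linear head, the initial softmax over classes is uniform, so fine-tuning starts from an unbiased classifier regardless of the frozen features. A NumPy sketch (illustrative helper, not the repo's code):

```python
import numpy as np

def head_probs(features, W, b):
    """Softmax probabilities of a linear classification head."""
    logits = features @ W + b
    e = np.exp(logits - logits.max(axis=1, keepdims=True))  # stable softmax
    return e / e.sum(axis=1, keepdims=True)
```

With W and b initialized to zeros, every input maps to logits of zero and hence to uniform probabilities of 1/num_classes.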
To fine-tune a linear head on ImageNet using Cloud TPUs, first set CHKPT_DIR to the pretrained model directory and set a new MODEL_DIR, then use the following command:
python run.py --mode=train_then_eval --train_mode=finetune \
--fine_tune_after_block=4 --zero_init_logits_layer=True \
--variable_schema='(?!global_step|(?:.*/|^)Momentum|head)' \
--global_bn=False --optimizer=momentum --learning_rate=0.1 --weight_decay=1e-6 \
--train_epochs=90 --train_batch_size=4096 --warmup_epochs=0 \
--dataset=imagenet2012 --image_size=224 --eval_split=validation \
--data_dir=$DATA_DIR --model_dir=$MODEL_DIR --checkpoint=$CHKPT_DIR \
--use_tpu=True --tpu_name=$TPU_NAME --train_summary_steps=0
As a reference, the above runs on ImageNet should give you around 64.5% accuracy.
Semi-supervised learning and fine-tuning the whole network
You can access the 1% and 10% ImageNet subsets used for semi-supervised learning via tensorflow_datasets: simply set dataset=imagenet2012_subset/1pct or dataset=imagenet2012_subset/10pct on the command line when fine-tuning.