
CRATE

Code for CRATE (Coding RAte reduction TransformEr).


CRATE (Coding RAte reduction TransformEr)

This repository is the official PyTorch implementation of the papers:

We have also released a longer, journal-length overview paper on this line of research, which contains a superset of all the results presented above as well as additional results in NLP and vision self-supervised learning (SSL).


Theoretical Background: What is CRATE?

CRATE (Coding RAte reduction TransformEr) is a white-box (mathematically interpretable) transformer architecture, where each layer performs a single step of an alternating minimization algorithm to optimize the sparse rate reduction objective:

<p align="center"> <img src="figs/fig_objective.png" width="400"/> </p>

where $R$ and $R^{c}$ are different coding rates for the input representations w.r.t. different codebooks, and the $\ell^{0}$-norm promotes the sparsity of the final token representations $\boldsymbol{Z} = f(\boldsymbol{X})$. The function $f$ is defined as $$f=f^{L} \circ f^{L-1} \circ \cdots \circ f^{1} \circ f^{\mathrm{pre}},$$ where $f^{\mathrm{pre}}$ is the pre-processing mapping and $f^{\ell}$ is the $\ell$-th layer forward mapping, which transforms the token distribution to optimize the above sparse rate reduction objective incrementally. More specifically, $f^{\ell}$ transforms the $\ell$-th layer token representations $\boldsymbol{Z}^{\ell}$ into $\boldsymbol{Z}^{\ell+1}$ via the $\texttt{MSSA}$ (Multi-Head Subspace Self-Attention) block followed by the $\texttt{ISTA}$ (Iterative Shrinkage-Thresholding Algorithm) block, i.e., $$\boldsymbol{Z}^{\ell+1/2} = \boldsymbol{Z}^{\ell} + \texttt{MSSA}(\boldsymbol{Z}^{\ell}), \qquad \boldsymbol{Z}^{\ell+1} = f^{\ell}(\boldsymbol{Z}^{\ell}) = \texttt{ISTA}(\boldsymbol{Z}^{\ell+1/2}).$$
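A minimal NumPy sketch of the terms in this objective follows. It assumes the definitions used in the CRATE papers, which this README does not spell out: for $N$ tokens stored as the columns of $\boldsymbol{Z} \in \mathbb{R}^{d \times N}$ and a precision $\epsilon$, $R(\boldsymbol{Z}) = \frac{1}{2}\log\det\big(\boldsymbol{I} + \frac{d}{N\epsilon^{2}}\boldsymbol{Z}^{\top}\boldsymbol{Z}\big)$, and $R^{c}(\boldsymbol{Z}; \boldsymbol{U}_{[K]}) = \sum_{k=1}^{K} R(\boldsymbol{U}_k^{\top}\boldsymbol{Z})$ for subspace bases $\boldsymbol{U}_k \in \mathbb{R}^{d \times p}$ (so each subspace term uses $p$ in place of $d$). The function names and the values $\epsilon = 0.5$, $\lambda = 0.1$ are illustrative choices, not taken from this repository.

```python
import numpy as np

def coding_rate(Z, eps=0.5):
    """R(Z) = 1/2 * logdet(I + d / (N * eps^2) * Z^T Z) for tokens stored as columns of Z (d, N)."""
    d, N = Z.shape
    _, logdet = np.linalg.slogdet(np.eye(N) + (d / (N * eps**2)) * (Z.T @ Z))
    return 0.5 * logdet

def compression_term(Z, U, eps=0.5):
    """R^c(Z; U_[K]): sum over k of the coding rate of the projection U_k^T Z, shape (p, N)."""
    return sum(coding_rate(U_k.T @ Z, eps) for U_k in U)

def sparse_rate_reduction(Z, U, lam=0.1, eps=0.5):
    """R(Z) - R^c(Z; U_[K]) - lam * ||Z||_0, the quantity maximized in the objective above."""
    return coding_rate(Z, eps) - compression_term(Z, U, eps) - lam * np.count_nonzero(Z)

# Toy example: 16 tokens in 8 dimensions, two orthogonal 4-dimensional subspaces.
rng = np.random.default_rng(0)
d, N, K, p = 8, 16, 2, 4
Z = rng.standard_normal((d, N))
U = [np.eye(d)[:, k * p:(k + 1) * p] for k in range(K)]
print(sparse_rate_reduction(Z, U))
```

Because `coding_rate` reads the row dimension from its argument, the compression term can reuse it directly on each projected block.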

1. CRATE Architecture overview

The following figure presents an overview of the pipeline for our proposed CRATE architecture:

<p align="center"> <img src="figs/fig_pipeline.png" width="900"/> </p>

2. One layer/block of CRATE

The following figure shows the overall architecture of one layer of CRATE as the composition of $\texttt{MSSA}$ and $\texttt{ISTA}$ blocks.

<p align="center"> <img src="figs/fig_arch.png" width="900"/> </p>
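To make this composition concrete, here is a NumPy sketch of one layer following the update equations above. It assumes the block forms given in the CRATE papers: each head $k$ projects the tokens onto a subspace $\boldsymbol{U}_k$ and uses that single projection as query, key and value, and the $\texttt{ISTA}$ block takes one non-negative proximal-gradient step toward a sparse code of $\boldsymbol{Z}^{\ell+1/2}$ under a dictionary $\boldsymbol{D}$. Layer normalization and dropout are omitted, and the values of $\eta$, $\lambda$ and $\epsilon$ are illustrative. This is a sketch under those assumptions, not the repository's model code.

```python
import numpy as np

def softmax_columns(G):
    """Column-wise softmax (each column sums to 1), shifted for numerical stability."""
    E = np.exp(G - G.max(axis=0, keepdims=True))
    return E / E.sum(axis=0, keepdims=True)

def mssa(Z, U, eps=0.5):
    """Multi-head subspace self-attention for tokens stored as columns of Z (d, N).

    Head k projects onto U_k (d, p), uses the same projection as query, key and
    value, and maps its output back to d dimensions with U_k.
    """
    N = Z.shape[1]
    out = np.zeros_like(Z)
    for U_k in U:
        p = U_k.shape[1]
        W = U_k.T @ Z                      # (p, N) coordinates in subspace k
        A = softmax_columns(W.T @ W)       # (N, N) token-to-token similarities
        out += (p / (N * eps**2)) * (U_k @ (W @ A))
    return out

def ista(Z, D, eta=0.1, lam=0.1):
    """One step on min_{X >= 0} lam*||X||_1 + 1/2*||Z - D X||_F^2, starting from X = Z."""
    return np.maximum(Z + eta * (D.T @ (Z - D @ Z)) - eta * lam, 0.0)

def crate_layer(Z, U, D, eta=0.1, lam=0.1, eps=0.5):
    """Z^{l+1} = ISTA(Z^l + MSSA(Z^l))."""
    Z_half = Z + mssa(Z, U, eps)           # compression step, gives Z^{l+1/2}
    return ista(Z_half, D, eta, lam)       # sparsification step, gives Z^{l+1}

rng = np.random.default_rng(0)
d, N, K, p = 8, 10, 2, 4
Z = rng.standard_normal((d, N))
U = [np.linalg.qr(rng.standard_normal((d, p)))[0] for _ in range(K)]  # orthonormal bases
D = np.linalg.qr(rng.standard_normal((d, d)))[0]                      # orthogonal dictionary
Z_next = crate_layer(Z, U, D)
print(Z_next.shape, float(np.mean(Z_next == 0)))
```

The `np.maximum(..., 0.0)` in `ista` combines the soft-threshold of the $\ell^1$ penalty with the non-negativity constraint, which is why the sparsification step looks like a ReLU.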

3. Per-layer optimization in CRATE

In the following figure, we measure the compression term $R^{c}(\boldsymbol{Z}^{\ell+1/2})$ and the sparsity term $\|\boldsymbol{Z}^{\ell+1}\|_0$ of the sparse rate reduction objective at each layer, and find that each layer of CRATE indeed optimizes its targeted objective, showing that our white-box theoretical design is predictive of practice.

<p align="center"> <img src="figs/fig_layerwise.png" width="900"/> </p>
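The quantities plotted here can be computed directly from intermediate representations. The sketch below assumes you have already collected, for every layer $\ell$, the arrays $\boldsymbol{Z}^{\ell+1/2}$ and $\boldsymbol{Z}^{\ell+1}$ (tokens as columns) together with that layer's subspaces $\boldsymbol{U}^{\ell}_{[K]}$, for instance via forward hooks (not shown). It uses the coding-rate definition from the CRATE papers and reports $\|\boldsymbol{Z}^{\ell+1}\|_0$ as a fraction of all entries; the helper names and the zero tolerance are illustrative.

```python
import numpy as np

def compression_term(Z, U, eps=0.5):
    """R^c(Z; U_[K]) = sum_k 1/2 * logdet(I + p / (N * eps^2) * (U_k^T Z)^T (U_k^T Z))."""
    N = Z.shape[1]
    total = 0.0
    for U_k in U:
        W = U_k.T @ Z                      # (p, N) projection onto subspace k
        p = W.shape[0]
        total += 0.5 * np.linalg.slogdet(np.eye(N) + (p / (N * eps**2)) * (W.T @ W))[1]
    return total

def normalized_l0(Z, tol=1e-8):
    """||Z||_0 divided by the number of entries; entries with |z| <= tol count as zero."""
    return np.count_nonzero(np.abs(Z) > tol) / Z.size

def layerwise_stats(Z_halves, Z_outs, U_per_layer, eps=0.5):
    """One (R^c(Z^{l+1/2}), normalized ||Z^{l+1}||_0) pair per layer, in layer order."""
    return [(compression_term(Zh, U, eps), normalized_l0(Zo))
            for Zh, Zo, U in zip(Z_halves, Z_outs, U_per_layer)]
```

Plotting the first element of each pair against the layer index gives a compression curve, and the second a sparsity curve; both should decrease if each layer optimizes its targeted term.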

4. Segmentation visualization of CRATE

In the following figure, we visualize self-attention maps from a supervised CRATE model with 8×8 patches (similar to those shown in DINO).

<p align="center"> <img src="figs/fig_seg.png" width="900"/> </p>

We also discover a surprising empirical phenomenon where each attention head in CRATE retains its own semantics.

<p align="center"> <img src="figs/fig_seg_headwise.png" width="900"/> </p>
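Maps like these are attention weights reshaped onto the image's patch grid. A common recipe, following DINO, takes the weights from the [CLS] token to every patch token in the last layer, giving one map per head. The NumPy sketch below assumes an attention tensor of shape (heads, 1 + n_patches, 1 + n_patches) with the [CLS] token at index 0 and patches flattened in row-major order; how that tensor is read out of the model is not shown and depends on the implementation. With 224×224 inputs and 8×8 patches the grid is 28×28.

```python
import numpy as np

def cls_attention_maps(attn, grid_size):
    """attn: (heads, T, T) with T = 1 + grid_size**2 and the [CLS] token at index 0.

    Returns (heads, grid_size, grid_size) maps of [CLS]-to-patch attention,
    each rescaled to [0, 1] for display.
    """
    heads, T, _ = attn.shape
    if T != 1 + grid_size * grid_size:
        raise ValueError("token count does not match the patch grid")
    maps = attn[:, 0, 1:].reshape(heads, grid_size, grid_size)  # drop [CLS]->[CLS]
    lo = maps.min(axis=(1, 2), keepdims=True)
    hi = maps.max(axis=(1, 2), keepdims=True)
    return (maps - lo) / np.maximum(hi - lo, 1e-12)

# Toy input: random row-stochastic attention for 6 heads over a 28x28 grid (+ [CLS]).
rng = np.random.default_rng(0)
grid = 224 // 8
logits = rng.standard_normal((6, 1 + grid * grid, 1 + grid * grid))
attn = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
maps = cls_attention_maps(attn, grid)
print(maps.shape)  # (6, 28, 28)
```

Keeping the heads separate, rather than averaging them, is what allows per-head comparisons like the one in the figure above.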

Autoencoding

We can also use our theory to build a principled autoencoder, which has the following architecture.

<p align="center"> <img src="figs/fig_arch_autoencoder.png" width="900"/> </p>

It shares many of the base CRATE model's empirical properties, such as segmented attention maps and amenability to layer-wise analysis. We train it on the masked autoencoding task (calling this model CRATE-MAE), and it achieves linear-probing performance and reconstruction quality comparable to those of the base ViT-MAE.

<p align="center"> <img src="figs/fig_masked_reconstruction.png" width="900"/> </p>
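In the masked autoencoding task, the encoder sees only a random subset of patches and the decoder must reconstruct the rest. The NumPy sketch below shows per-image random masking in the style of the original MAE recipe; the 75% mask ratio is MAE's default and an assumption here, not a value read from this repository.

```python
import numpy as np

def random_masking(tokens, mask_ratio=0.75, rng=None):
    """tokens: (B, N, D) patch tokens. Returns kept tokens, a binary mask and restore indices.

    mask[b, i] == 1 means patch i of image b is hidden from the encoder and must be reconstructed.
    """
    rng = np.random.default_rng() if rng is None else rng
    B, N, D = tokens.shape
    n_keep = int(N * (1 - mask_ratio))
    ids_shuffle = np.argsort(rng.random((B, N)), axis=1)  # random permutation per image
    ids_restore = np.argsort(ids_shuffle, axis=1)         # inverse permutation
    ids_keep = ids_shuffle[:, :n_keep]
    kept = np.take_along_axis(tokens, ids_keep[:, :, None], axis=1)  # (B, n_keep, D)
    mask = np.ones((B, N))
    mask[:, :n_keep] = 0                                   # 0 = visible, in shuffled order
    mask = np.take_along_axis(mask, ids_restore, axis=1)   # back to original patch order
    return kept, mask, ids_restore

rng = np.random.default_rng(0)
B, N, D = 2, 196, 8                          # 196 = 14x14 patches of a 224x224 image
tokens = np.broadcast_to(np.arange(N, dtype=float)[None, :, None], (B, N, D)).copy()
kept, mask, ids_restore = random_masking(tokens, mask_ratio=0.75, rng=rng)
print(kept.shape, mask.sum(axis=1))          # (2, 49, 8) [147. 147.]
```

`ids_restore` lets a decoder put reconstructed and mask tokens back into the original patch order before computing a reconstruction loss on the masked patches.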

Implementation and Experiments

Constructing a CRATE model

A CRATE model can be defined with the following code (the parameters below are those of CRATE-Tiny):

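```python
from model.crate import CRATE

# CRATE-Tiny hyperparameters
dim = 384      # token embedding dimension
n_heads = 6    # number of attention heads (one subspace per head)
depth = 12     # number of CRATE layers (MSSA + ISTA blocks)

model = CRATE(image_size=224,      # input resolution (224x224)
              patch_size=16,       # 16x16 patches
              num_classes=1000,    # ImageNet-1K classes
              dim=dim,
              depth=depth,
              heads=n_heads,
              dim_head=dim // n_heads)  # 64 dimensions per head
```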
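For reference, the arithmetic implied by these settings: a 224×224 image cut into 16×16 patches gives a 14×14 grid of 196 patch tokens, and each of the 6 heads works in a `dim // n_heads` = 64-dimensional subspace. The helper below is hypothetical (it is not part of `model.crate`), and its sequence length assumes the model prepends one class token, as ViT-style implementations do.

```python
def token_layout(image_size=224, patch_size=16, dim=384, heads=6):
    """Hypothetical helper: (num_patches, seq_len_with_cls, dim_head) for a ViT-style config."""
    if image_size % patch_size:
        raise ValueError("image_size must be divisible by patch_size")
    if dim % heads:
        raise ValueError("dim must be divisible by heads")
    grid = image_size // patch_size
    num_patches = grid * grid
    return num_patches, num_patches + 1, dim // heads  # +1 assumes a [CLS] token

print(token_layout())  # (196, 197, 64)
```

With `patch_size=8`, as in the segmentation figures above, the same image yields 784 patch tokens, so attention cost grows roughly 16-fold.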

Pre-trained Checkpoints (ImageNet-1K)

| model | dim | n_heads | depth | pre-trained checkpoint |
| -------- | -------- | -------- | -------- | -------- |
| CRATE-T(iny) | 384 | 6 | 12 | TODO |
| CRATE-S(mall) | 576 | 12 | 12 | download link |
| CRATE-B(ase) | 768 | 12 | 12 | TODO |
| CRATE-L(arge) | 1024 | 16 | 24 | TODO |
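The table can also be read as a set of constructor arguments. The dictionary below is a sketch that assumes, as in the CRATE-Tiny example above, that `dim_head = dim // n_heads` for every variant; the names `CRATE_CONFIGS` and `crate_kwargs` are illustrative, and image size, patch size and class count default to the ImageNet-1K values used above.

```python
# Hypothetical variant table mirroring the checkpoint table (dim, heads, depth).
CRATE_CONFIGS = {
    "tiny":  {"dim": 384,  "heads": 6,  "depth": 12},
    "small": {"dim": 576,  "heads": 12, "depth": 12},
    "base":  {"dim": 768,  "heads": 12, "depth": 12},
    "large": {"dim": 1024, "heads": 16, "depth": 24},
}

def crate_kwargs(variant, image_size=224, patch_size=16, num_classes=1000):
    """Keyword arguments in the style of the CRATE(...) call above; dim_head is assumed to be dim // heads."""
    cfg = CRATE_CONFIGS[variant]
    return dict(image_size=image_size, patch_size=patch_size, num_classes=num_classes,
                dim=cfg["dim"], depth=cfg["depth"], heads=cfg["heads"],
                dim_head=cfg["dim"] // cfg["heads"])

# With the repository's model package on the path, one could then write:
# from model.crate import CRATE
# model = CRATE(**crate_kwargs("small"))
```

Under this assumption CRATE-S uses 48-dimensional heads, while the other three variants use 64.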

Training CRATE on ImageNet

To train a CRATE model on ImageNet-1K, run main.py. For example, we use the following command to train CRATE-Tiny on ImageNet-1K:

```bash
python main.py \
  --arch CRATE_tiny \
  --batch-size 512 \
  --epochs 200 \
  --optimizer Lion \
  --lr 0.0002 \
  --weight-decay 0.05 \
  --print-freq 25 \
  --data DATA_DIR
```

Replace DATA_DIR with the path to the ImageNet folder that contains the train and val subfolders.
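The command above selects the Lion optimizer. As a reference for what `--lr` and `--weight-decay` control, here is a NumPy sketch of one step of the standard Lion update with decoupled weight decay. The β values are the defaults proposed with Lion and are assumptions; the training script's own optimizer implementation and settings may differ.

```python
import numpy as np

def lion_step(param, grad, momentum, lr=2e-4, weight_decay=0.05, beta1=0.9, beta2=0.99):
    """One Lion update. Returns (new_param, new_momentum); inputs are not modified."""
    # Update direction: sign of an interpolation between momentum and the current gradient.
    update = np.sign(beta1 * momentum + (1 - beta1) * grad)
    # Decoupled weight decay, scaled by the learning rate as in AdamW.
    new_param = param - lr * (update + weight_decay * param)
    # Momentum is tracked with a separate coefficient beta2.
    new_momentum = beta2 * momentum + (1 - beta2) * grad
    return new_param, new_momentum
```

Because the update is a sign, every coordinate moves by the full learning rate (plus weight decay), which is why Lion is usually paired with a smaller learning rate than AdamW.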

Fine-tuning a pre-trained CRATE / training a randomly initialized CRATE on CIFAR-10

```bash
python finetune.py \
  --bs 256 \
  --net CRATE_tiny \
  --opt adamW \
  --lr 5e-5 \
  --n_epochs 200 \
  --randomaug 1 \
  --data cifar10 \
  --ckpt_dir CKPT_DIR \
  --data_dir DATA_DIR
```

Replace CKPT_DIR with the path to the pre-trained CRATE weights and DATA_DIR with the path to the CIFAR-10 dataset. If CKPT_DIR is None, the script instead trains CRATE from random initialization on CIFAR-10.

Demo: Emergent segmentation in CRATE

CRATE models exhibit emergent segmentation in their self-attention maps solely through supervised training. We provide a Colab Jupyter notebook that visualizes the emergent segmentations of a supervised CRATE model; its visualizations match the segmentation figures above.

Link: crate-emergence.ipynb (in colab)
