# HIPT
**Hierarchical Image Pyramid Transformer** (CVPR 2022, Oral)
Scaling Vision Transformers to Gigapixel Images via Hierarchical Self-Supervised Learning
<details>
<summary>
<b>Scaling Vision Transformers to Gigapixel Images via Hierarchical Self-Supervised Learning</b>, CVPR 2022. <a href="https://openaccess.thecvf.com/content/CVPR2022/html/Chen_Scaling_Vision_Transformers_to_Gigapixel_Images_via_Hierarchical_Self-Supervised_Learning_CVPR_2022_paper.html" target="blank">[HTML]</a> <a href="https://arxiv.org/abs/2206.02647" target="blank">[arXiv]</a> <a href="https://www.youtube.com/watch?v=cABkB1J-GTA" target="blank">[Oral]</a> <br><em><a href="http://richarizardd.me">Richard. J. Chen</a>, <a href="https://www.kuanchchen.com">Chengkuan Chen</a>, <a href="https://www.linkedin.com/in/yicong-jackson-li/">Yicong Li</a>, <a href="https://twitter.com/tiffanyytchen">Tiffany Y. Chen</a>, <a href="https://www.gatesfoundation.org/about/leadership/andrew-trister">Andrew D. Trister</a>, <a href="http://www.cs.toronto.edu/~rahulgk/index.html">Rahul G. Krishnan*</a>, <a href="https://faisal.ai/">Faisal Mahmood*</a></em></br>
</summary>

```
@inproceedings{chen2022scaling,
    author    = {Chen, Richard J. and Chen, Chengkuan and Li, Yicong and Chen, Tiffany Y. and Trister, Andrew D. and Krishnan, Rahul G. and Mahmood, Faisal},
    title     = {Scaling Vision Transformers to Gigapixel Images via Hierarchical Self-Supervised Learning},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2022},
    pages     = {16144-16155}
}
```
</details>
<div align="center">
<img width="100%" alt="HIPT Illustration" src=".github/HIPT Architecture.gif">
</div>
<details>
<summary>
<b>Key Ideas & Main Findings</b>
</summary>
- **Hierarchical Image Pyramid Transformer (HIPT) Architecture:** Three-stage hierarchical ViT that formulates gigapixel whole-slide images (WSIs) as a disjoint set of nested sequences. HIPT unrolls the WSI into non-overlapping [4096 × 4096] image regions, then unrolls each region into non-overlapping [256 × 256] image patches, and lastly each patch into non-overlapping [16 × 16] cell tokens. Our method is analogous to hierarchical attention networks in long-document modeling, in which word embeddings within sentences are aggregated into sentence-level embeddings and subsequently aggregated into document-level embeddings. Inference in HIPT is performed via bottom-up aggregation of [16 × 16] visual tokens in their respective [256 × 256] and [4096 × 4096] windows via Transformer attention to compute a slide-level representation.
- **Learning Context-Aware Token Dependencies in WSIs:** Transformer attention is computed only in local windows (instead of across the entire WSI), which makes learning long-range dependencies tractable. Though representation learning for [4096 × 4096] image regions may seem expensive, note that the patch size at this level is [256 × 256], so the complexity is similar to that of applying a ViT to a [256 × 256] image patch with [16 × 16] tokens.
- **Hierarchical Pretraining:** Since encoding [4096 × 4096] images is the same subproblem as encoding [256 × 256] images, we hypothesize that ViT pretraining techniques can generalize to higher resolutions with little modification. DINO is used not only to pretrain ViT-16 in HIPT, but also ViT-256 via [6 × 6] local and [14 × 14] global crops on a 2D grid-of-features (obtained by using ViT-16 as a patch tokenizer for ViT-256).
- **Self-Supervised Slide-Level Representation Learning:** HIPT is evaluated via pretraining + freezing the ViT-16 / ViT-256 stages, with the ViT-4K stage finetuned with slide-level labels, assessed on cancer subtyping and survival prediction tasks in TCGA. We also perform self-supervised KNN evaluation of HIPT embeddings by computing the mean of the [CLS]-4K tokens extracted from ViT-256, as a proxy for the slide-level embedding. On Renal Cell Carcinoma subtyping, we report that averaged, pretrained HIPT-4K embeddings without any labels perform as well as CLAM-SB.
</details>
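As a back-of-the-envelope check of the complexity argument above, each stage tokenizes its input into the same 16 × 16 grid, so the attention sequence length is identical at every level of the pyramid (a sketch; the numbers follow directly from the region / patch / token sizes stated above):

```python
def seq_len(input_size: int, token_size: int) -> int:
    """Number of non-overlapping tokens in the 2D grid for one input."""
    per_axis = input_size // token_size
    return per_axis * per_axis

# ViT-256 over a [256 x 256] patch with [16 x 16] tokens:
len_vit256 = seq_len(256, 16)      # 256 tokens
# ViT-4K over a [4096 x 4096] region with [256 x 256] tokens:
len_vit4k = seq_len(4096, 256)     # 256 tokens
# Identical sequence lengths -> similar attention cost per stage.
```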
## Updates / TODOs
Please follow this GitHub for more updates.
- [ ] Removing dead code in HIPT_4K library.
- [X] Better documentation on interpretability code example.
- [x] Add pretrained models + instructions for hierarchical visualization.
- [X] Add pre-extracted slide-level embeddings, and code for K-NN evaluation.
- [X] Add weakly-supervised results for Tensorboard.
## Pre-Reqs + Installation
This repository includes not only the code base for HIPT, but also saved HIPT checkpoints and pre-extracted HIPT slide embeddings with ~4.08 GiB of storage, which we version control via Git LFS.
To clone this repository without large files initially:
```bash
GIT_LFS_SKIP_SMUDGE=1 git clone https://github.com/mahmoodlab/HIPT.git   # Pulls just the codebase
cd HIPT                                                                  # The lfs pulls below run inside the repo
git lfs pull --include "*.pth"   # Pulls the pretrained checkpoints
git lfs pull --include "*.pt"    # Pulls pre-extracted slide embeddings
git lfs pull --include "*.pkl"   # Pulls pre-extracted patch embeddings
git lfs pull --include "*.png"   # Pulls demo images (required for 4K x 4K visualization)
```
To clone all files:
```bash
git clone https://github.com/mahmoodlab/HIPT.git
```
To install Python dependencies:
```bash
pip install -r requirements.txt
```
## HIPT Walkthrough
### How HIPT Works
Below is a snippet of a standalone two-stage HIPT model architecture that can load fully self-supervised weights for nested [16 x 16] and [256 x 256] token aggregation, defined in ./HIPT_4K/hipt_4k.py. Via a few einsum operations, you can put together multiple ViT encoders and have it scale to large resolutions. HIPT_4K was used for feature extraction of non-overlapping [4096 x 4096] image regions across the TCGA.
```python
import torch
from einops import rearrange, repeat
from HIPT_4K.hipt_model_utils import get_vit256, get_vit4k


class HIPT_4K(torch.nn.Module):
    """
    HIPT Model (ViT_4K-256) for encoding non-square images (with [256 x 256] patch tokens), with
    [256 x 256] patch tokens encoded via ViT_256-16 using [16 x 16] patch tokens.
    """
    def __init__(self,
                 model256_path: str = 'path/to/Checkpoints/vit256_small_dino.pth',
                 model4k_path: str = 'path/to/Checkpoints/vit4k_xs_dino.pth',
                 device256=torch.device('cuda:0'),
                 device4k=torch.device('cuda:1'),
                 patch_filter_params=None):
        super().__init__()
        self.model256 = get_vit256(pretrained_weights=model256_path).to(device256)
        self.model4k = get_vit4k(pretrained_weights=model4k_path).to(device4k)
        self.device256 = device256
        self.device4k = device4k
        self.patch_filter_params = patch_filter_params

    def forward(self, x):
        """
        Forward pass of HIPT (given an image tensor x), outputting the [CLS] token from ViT_4K.
        1. x is center-cropped such that the W / H is divisible by the patch token size in ViT_4K (e.g. - 256 x 256).
        2. x then gets unfolded into a "batch" of [256 x 256] images.
        3. A pretrained ViT_256-16 model extracts the [CLS] token from each [256 x 256] image in the batch.
        4. These batch-of-features are then reshaped into a 2D feature grid (of width "w_256" and height "h_256").
        5. This feature grid is then used as the input to ViT_4K-256, outputting [CLS]_4K.

        Args:
            - x (torch.Tensor): [1 x C x W' x H'] image tensor.

        Return:
            - features_cls4k (torch.Tensor): [1 x 192] cls token (d_4k = 192 by default).
        """
        batch_256, w_256, h_256 = self.prepare_img_tensor(x)                    # 1. [1 x 3 x W x H]
        batch_256 = batch_256.unfold(2, 256, 256).unfold(3, 256, 256)           # 2. [1 x 3 x w_256 x h_256 x 256 x 256]
        batch_256 = rearrange(batch_256, 'b c p1 p2 w h -> (b p1 p2) c w h')    # 2. [B x 3 x 256 x 256], where B = (1*w_256*h_256)

        features_cls256 = []
        for mini_bs in range(0, batch_256.shape[0], 256):                       # 3. B may be too large for ViT_256. We further take minibatches of 256.
            minibatch_256 = batch_256[mini_bs:mini_bs+256].to(self.device256, non_blocking=True)
            features_cls256.append(self.model256(minibatch_256).detach().cpu()) # 3. Extracting ViT_256 features from [256 x 3 x 256 x 256] image batches.

        features_cls256 = torch.vstack(features_cls256)                         # 3. [B x 384], where 384 == dim of ViT-256 [CLS] token.
        features_cls256 = features_cls256.reshape(w_256, h_256, 384).transpose(0, 1).transpose(0, 2).unsqueeze(dim=0)
        features_cls256 = features_cls256.to(self.device4k, non_blocking=True)  # 4. [1 x 384 x w_256 x h_256]
        features_cls4k = self.model4k.forward(features_cls256)                  # 5. [1 x 192], where 192 == dim of ViT_4K [CLS] token.
        return features_cls4k
```
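To see how step 2's `unfold` carves a region into a batch of patches, the resulting grid dimensions can be checked with plain arithmetic: `unfold(dim, size, step)` yields `(dim_size - size) // step + 1` windows along a dimension (a sketch of the shape bookkeeping only, no tensors involved):

```python
def num_windows(dim_size: int, size: int, step: int) -> int:
    """Windows produced by torch.Tensor.unfold along one dimension."""
    return (dim_size - size) // step + 1

# A [1 x 3 x 4096 x 4096] region unfolded twice with size=256, step=256:
w_256 = num_windows(4096, 256, 256)   # 16 windows along width
h_256 = num_windows(4096, 256, 256)   # 16 windows along height
batch = w_256 * h_256                 # 256 patches of [3 x 256 x 256], i.e. B in the snippet above
```

Since the windows are non-overlapping (`step == size`), no pixels are shared or dropped after the center-crop in step 1.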
### Using the HIPT_4K API
You can use the HIPT_4K model out-of-the-box and plug it into any of your downstream tasks (example below).
```python
from PIL import Image
from HIPT_4K.hipt_4k import HIPT_4K
from HIPT_4K.hipt_model_utils import eval_transforms

model = HIPT_4K()
model.eval()

region = Image.open('HIPT_4K/image_demo/image_4k.png')
x = eval_transforms()(region).unsqueeze(dim=0)
out = model.forward(x)
```
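Downstream, the slide-level embedding described earlier (the mean of a slide's [CLS]-4K features) can feed a simple K-NN evaluation. Below is a minimal cosine-similarity K-NN sketch in pure Python; the embeddings and subtype labels are toy values for illustration only, standing in for the real pre-extracted features:

```python
import math

def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm

def knn_predict(query, slides, k=1):
    """slides: list of (embedding, label); returns majority label among k nearest."""
    ranked = sorted(slides, key=lambda s: cosine(query, s[0]), reverse=True)
    top = [label for _, label in ranked[:k]]
    return max(set(top), key=top.count)

# Toy slide embeddings (in practice: mean of [CLS]-4K tokens per slide).
slides = [([1.0, 0.0], 'CCRCC'), ([0.9, 0.1], 'CCRCC'), ([0.0, 1.0], 'PRCC')]
print(knn_predict([0.95, 0.05], slides, k=1))   # → CCRCC
```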
## Hierarchical Interpretability
<div align="center">
<img width="100%" alt="DINO illustration" src=".github/HIPT_attention.jpg">
</div>

For hierarchical interpretability, please see the following notebook, which uses the functions in ./HIPT_4K/hipt_heatmap_utils.py.
## Downloading + Preprocessing + Orga