# TorchMultimodal (Beta Release)

Models | Example scripts | Getting started | Code overview | Installation | Contributing | License

## Introduction
TorchMultimodal is a PyTorch library for training state-of-the-art multimodal multi-task models at scale, including both content understanding and generative models. TorchMultimodal contains:
- A repository of modular and composable building blocks (fusion layers, loss functions, datasets and utilities).
- A collection of common multimodal model classes built up from these building blocks, with pretrained weights for canonical configurations (see the one-line example after this list).
- A set of examples that show how to combine these building blocks with components and common infrastructure from across the PyTorch Ecosystem to replicate state-of-the-art models published in the literature. These examples should serve as baselines for ongoing research in the field, as well as a starting point for future work.
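As a quick illustration of the second point, a pretrained model class can be built in a single line via its builder function. This uses the `flava_model` builder that also appears in the Getting started section below:

```python
from torchmultimodal.models.flava.model import flava_model

# Build FLAVA in its canonical configuration and load pretrained weights
model = flava_model(pretrained=True)
```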
## Models

TorchMultimodal contains a number of models, including:
- ALBEF: model class, paper
- BLIP-2: model class, paper
- CLIP: model class, paper
- CoCa: model class, paper
- DALL-E 2: model, paper
- FLAVA: model class, paper
- MAE/Audio MAE: model class, MAE paper, Audio MAE paper
- MDETR: model class, paper
## Example scripts

In addition to the above models, we provide example scripts for training, fine-tuning, and evaluation of models on popular multimodal tasks. Examples can be found under `examples/` and include:
| Model | Supported Tasks |
| :---: | :---: |
| ALBEF | Retrieval <br/> Visual Question Answering |
| DDPM | Training and Inference (notebook) |
| FLAVA | Pretraining <br/> Fine-tuning <br/> Zero-shot |
| MDETR | Phrase grounding <br/> Visual Question Answering |
| MUGEN | Text-to-video retrieval <br/> Text-to-video generation |
| Omnivore | Pre-training <br/> Evaluation |
## Getting started
Below we give minimal examples of how you can write a simple training or zero-shot evaluation script using components from TorchMultimodal.
<details>
<summary>FLAVA zero-shot example</summary>

```python
import torch
from PIL import Image
from torchmultimodal.models.flava.model import flava_model
from torchmultimodal.transforms.bert_text_transform import BertTextTransform
from torchmultimodal.transforms.flava_transform import FLAVAImageTransform

# Define helper function for zero-shot prediction
def predict(zero_shot_model, image, labels):
    zero_shot_model.eval()
    with torch.no_grad():
        image = image_transform(image)["image"].unsqueeze(0)
        texts = text_transform(labels)
        _, image_features = zero_shot_model.encode_image(image, projection=True)
        _, text_features = zero_shot_model.encode_text(texts, projection=True)
        scores = image_features @ text_features.t()
        probs = torch.nn.Softmax(dim=-1)(scores)
        label = labels[torch.argmax(probs)]
        print(
            "Label probabilities: ",
            {labels[i]: probs[:, i] for i in range(len(labels))},
        )
        print(f"Predicted label: {label}")


image_transform = FLAVAImageTransform(is_train=False)
text_transform = BertTextTransform()
zero_shot_model = flava_model(pretrained=True)
img = Image.open("my_image.jpg")  # point to your own image
predict(zero_shot_model, img, ["dog", "cat", "house"])

# Example output:
# Label probabilities:  {'dog': tensor([0.80590]), 'cat': tensor([0.0971]), 'house': tensor([0.0970])}
# Predicted label: dog
```

</details>
<details>
<summary>MAE training example</summary>

```python
import torch
from torch.utils.data import DataLoader
from torchmultimodal.models.masked_auto_encoder.model import vit_l_16_image_mae
from torchmultimodal.models.masked_auto_encoder.utils import (
    CosineWithWarmupAndLRScaling,
)
from torchmultimodal.modules.losses.reconstruction_loss import ReconstructionLoss
from torchmultimodal.transforms.mae_transform import ImagePretrainTransform

mae_transform = ImagePretrainTransform()
dataset = MyDatasetClass(transforms=mae_transform)  # you should define this
dataloader = DataLoader(dataset, batch_size=8)

# Instantiate model and loss
mae_model = vit_l_16_image_mae()
mae_loss = ReconstructionLoss()

# Define optimizer and lr scheduler
optimizer = torch.optim.AdamW(mae_model.parameters())
lr_scheduler = CosineWithWarmupAndLRScaling(
    optimizer, max_iters=1000, warmup_iters=100  # you should set these
)

# Train one epoch
for batch in dataloader:
    optimizer.zero_grad()
    model_out = mae_model(batch["images"])
    loss = mae_loss(model_out.decoder_pred, model_out.label_patches, model_out.mask)
    loss.backward()
    optimizer.step()
    lr_scheduler.step()
```

</details>
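After training, a quick validation pass can reuse the same forward outputs and reconstruction loss. A minimal sketch, assuming a `val_dataloader` built the same way as the training dataloader above:

```python
# Evaluate average reconstruction loss on a validation set
mae_model.eval()
val_losses = []
with torch.no_grad():
    for batch in val_dataloader:  # assumed: built like `dataloader` above
        model_out = mae_model(batch["images"])
        val_losses.append(
            mae_loss(model_out.decoder_pred, model_out.label_patches, model_out.mask)
        )
print(torch.stack(val_losses).mean())
```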
## Code overview

### torchmultimodal/diffusion_labs

`diffusion_labs` contains components for building diffusion models. For more details on these components, see `diffusion_labs/README.md`.
### torchmultimodal/models

Look here for model classes as well as any other modeling code specific to a given architecture. For example, the directory `torchmultimodal/models/blip2` contains modeling components specific to BLIP-2.
### torchmultimodal/modules

Look here for common generic building blocks that can be stitched together to build a new architecture. This includes layers like codebooks, patch embeddings, or transformer encoder/decoders, losses like contrastive loss with temperature or reconstruction loss, and encoders like ViT.
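As one illustration, the contrastive loss with temperature can be used on its own with any pair of batched embeddings. A minimal sketch, assuming the module's default constructor and a forward that takes the two embedding tensors positionally (check the class for the exact signature):

```python
import torch

from torchmultimodal.modules.losses.contrastive_loss_with_temperature import (
    ContrastiveLossWithTemperature,
)

# Random paired image/text embeddings stand in for real encoder outputs
image_embeddings = torch.randn(8, 512, requires_grad=True)
text_embeddings = torch.randn(8, 512, requires_grad=True)

# CLIP-style contrastive loss with a learnable temperature parameter
loss_fn = ContrastiveLossWithTemperature()
loss = loss_fn(image_embeddings, text_embeddings)
loss.backward()  # gradients flow to the embeddings and the temperature
```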
