Transfusion.torch
PyTorch Implementation of Transfusion: Predict the Next Token and Diffuse Images with One Multi-Modal Model
Transfusion [Paper]
- Transfusion is a multi-modal transformer: it generates text like GPT-style models and images like diffusion models, with one model in a single pass rather than with separate models
- It can switch between text and image modalities during generation. There is nothing complicated about it: it is a single transformer with a few modality-specific components
- It can be extended to other modalities such as video and audio, but for now it only takes images and text as input
TODO: Train on a large Multi-Modal Dataset (something like tiny stories dataset with images in between illustrating the story...?)
import torch

from src import LLaMA, Transfussion

class config:
    ... # Fill in some parameters for the model | see src/configs.py for reference

model = Transfussion(
    model=LLaMA(config),
    config=config
)

# Each sample is a list of text token tensors and (image, timestep) tuples.
# You get "image" after passing the image to PatchOps.patchify() while preprocessing
patch_dim = config.patch_size**2 * config.in_channels
text_and_images = [
    [
        torch.randint(0, 10, (39,)), # text
        (torch.randn(345, patch_dim), torch.randint(0, config.num_timesteps, (1,))), # (image, timestep)
        torch.randint(0, 10, (14,)) # text
    ],
    [
        torch.randint(0, 10, (16,)), # text
        (torch.randn(359, patch_dim), torch.randint(0, config.num_timesteps, (1,))), # (image, timestep)
        torch.randint(0, 10, (5,)), # text
        (torch.randn(2, patch_dim), torch.randint(0, config.num_timesteps, (1,))), # (image, timestep)
        torch.randint(0, 10, (9,)) # text
    ]
]
# The second argument lists the modality of every element, sample by sample
output = model(text_and_images, [["text", "image", "text"], ["text", "image", "text", "image", "text"]])
Introduction
- Transfusion pretrains a transformer model on 50% text and 50% image data, using a different objective for each modality: next-token prediction for text and diffusion for images
- We apply causal attention for text tokens and bidirectional attention for image patches. For inference, we introduce a decoding algorithm that combines the standard practices of text generation from language models and image generation from diffusion models
- Intra-image bidirectional attention is important, and replacing it with causal attention hurts text-to-image generation
Language Modelling Utils and Loss
- Autoregressive Classification
- Usual Cross-Entropy Loss
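As a rough illustration of the two bullets above, here is a minimal pure-Python sketch of next-token cross-entropy. The function name `next_token_cross_entropy` and the list-based inputs are hypothetical and chosen for readability; the repository itself works on PyTorch tensors, and its exact loss code may differ.

```python
import math

def next_token_cross_entropy(logits, tokens):
    """Mean cross-entropy of predicting tokens[i + 1] from logits[i].

    logits: list of per-position score lists, one score per vocabulary entry.
    tokens: list of integer token ids, same length as logits.
    """
    total = 0.0
    count = 0
    # The last position has no next token to predict, so it is skipped.
    for position_logits, target in zip(logits[:-1], tokens[1:]):
        # Numerically stable log-sum-exp over the vocabulary.
        peak = max(position_logits)
        log_norm = peak + math.log(sum(math.exp(s - peak) for s in position_logits))
        # Negative log-probability of the correct next token.
        total += log_norm - position_logits[target]
        count += 1
    return total / count
```

With uniform logits over a vocabulary of size V, the loss equals log(V), which is a handy sanity check for an untrained model.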
Diffusion Utils and Loss
- Noise Schedule: cosine scheduler ("We found that while the linear noise schedule used in Ho et al. (2020) worked well for high-resolution images, it was sub-optimal for images of resolution 64 × 64 and 32 × 32")
- Loss: Mean Squared Error
- Latent Image Representation: Variational autoencoders (VAEs) [Kingma and Welling, 2013] can save compute by encoding images into a lower-dimensional latent space
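The noise schedule and loss above can be sketched as follows. This is a simplified pure-Python version, assuming the commonly used cosine schedule with offset `s = 0.008` and a noise-prediction target for the MSE loss; the helper names (`cosine_alpha_bar`, `add_noise`, `mse`) are hypothetical and not taken from the repository.

```python
import math
import random

def cosine_alpha_bar(t, num_timesteps, s=0.008):
    """Cumulative signal fraction alpha_bar_t under a cosine schedule."""
    def f(step):
        return math.cos(((step / num_timesteps + s) / (1 + s)) * math.pi / 2) ** 2
    return f(t) / f(0)

def add_noise(x0, t, num_timesteps, rng):
    """Forward diffusion: x_t = sqrt(alpha_bar) * x0 + sqrt(1 - alpha_bar) * eps."""
    alpha_bar = cosine_alpha_bar(t, num_timesteps)
    noise = [rng.gauss(0.0, 1.0) for _ in x0]
    xt = [math.sqrt(alpha_bar) * x + math.sqrt(1.0 - alpha_bar) * n
          for x, n in zip(x0, noise)]
    return xt, noise

def mse(prediction, target):
    """Mean squared error between predicted and true noise (assumed target)."""
    return sum((p - q) ** 2 for p, q in zip(prediction, target)) / len(prediction)

rng = random.Random(0)
x0 = [0.5, -1.0, 0.25]  # stand-in for a flattened latent image
xt, noise = add_noise(x0, t=500, num_timesteps=1000, rng=rng)
```

At `t = 0` the sample is returned unchanged, and as `t` approaches `num_timesteps` it becomes almost pure noise.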
Data Representation
- Discrete text and continuous images
- Each text string is tokenized into a sequence of discrete tokens from a fixed vocabulary, where each token is represented as an integer
Model Architecture
- The vast majority of the model’s parameters belong to a single transformer, which processes every sequence, regardless of modality (We follow Llama’s [Touvron et al., 2023a] flavour of the transformer block, which includes the SwiGLU activation function [Shazeer, 2020] and RoPE [Su et al., 2024])
- To convert our data into the vector space the transformer operates on, we use lightweight modality-specific components with unshared parameters
- For text, these are the embedding matrices
- For images, we experiment with two alternatives for compressing local windows of k × k patch vectors into a single transformer vector (and vice versa):
- a simple linear layer (We add an embedding of the timestep t to every patch vector before the linear layer)
- up and down blocks of a U-Net (We replace the U-Net’s AdaLayerNorm with regular layer norm in our implementation)
- Transfusion Attention: While text is naturally sequential, images are not, and are usually modelled with unrestricted (bidirectional) attention. Transfusion combines both attention patterns by applying causal attention to every element in the sequence, and bidirectional attention among the elements (patches) of each individual image
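The attention pattern described above can be sketched as a boolean mask. This is a minimal pure-Python illustration; `transfusion_attention_mask` and its `lengths` argument are hypothetical names, and the repository may build its mask differently (for example directly as a tensor).

```python
def transfusion_attention_mask(modalities, lengths):
    """Return mask[i][j] == True where position i may attend to position j.

    modalities: e.g. ["text", "image", "text"], one entry per segment.
    lengths: number of sequence elements (tokens or patches) per segment.
    """
    total = sum(lengths)
    # Start from a causal (lower-triangular) mask over the whole sequence.
    mask = [[j <= i for j in range(total)] for i in range(total)]
    start = 0
    for modality, length in zip(modalities, lengths):
        if modality == "image":
            # Patches of the same image see each other in both directions.
            for i in range(start, start + length):
                for j in range(start, start + length):
                    mask[i][j] = True
        start += length
    return mask

mask = transfusion_attention_mask(["text", "image", "text"], [2, 3, 1])
```

In this example, positions 2 to 4 form one image: patch 2 can attend to patch 4, but the text token at position 1 still cannot look ahead into the image.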
Training Objective
- LM loss is computed per token (When the input is a BOI token, we do not compute any loss), while diffusion loss is computed per image, which may span multiple elements (image patches) in the sequence
- Specifically, we add noise ϵ to each input latent image x0 according to the diffusion process to produce xt before patchification, and then compute the image-level diffusion loss
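Written out, the two losses are combined with a balancing coefficient λ. The diffusion term below is shown in its standard noise-prediction form, which is an assumption about the exact parameterisation used here; c denotes the conditioning context (the preceding sequence).

```latex
\mathcal{L}_{\text{Transfusion}} = \mathcal{L}_{\text{LM}} + \lambda \cdot \mathcal{L}_{\text{DDPM}},
\qquad
\mathcal{L}_{\text{DDPM}} = \mathbb{E}_{x_0, t, \epsilon}\left[ \lVert \epsilon - \epsilon_\theta(x_t, t, c) \rVert^2 \right],
\qquad
x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon
```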
Optimization
- AdamW => | betas=(0.9, 0.95) | eps=1e-8 | lr=3e-4 | warmup=4000 | min_lr=1.5e-5 | weight_decay=0.1 | clip_norm=1.0 |
- balancing_coeff (lambda in loss function) = 5
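The `lr`, `warmup` and `min_lr` values above suggest a warmup-then-decay schedule. The sketch below assumes linear warmup followed by cosine decay to `min_lr`; the decay shape, the `total_steps` value, and the function name `learning_rate` are assumptions for illustration, not confirmed details of the training script.

```python
import math

def learning_rate(step, total_steps, max_lr=3e-4, min_lr=1.5e-5, warmup=4000):
    """Linear warmup to max_lr, then cosine decay to min_lr (assumed shape)."""
    if step < warmup:
        # Ramp up linearly; step + 1 avoids a learning rate of exactly zero.
        return max_lr * (step + 1) / warmup
    # Fraction of the decay phase completed, clamped to [0, 1].
    progress = min(1.0, (step - warmup) / max(1, total_steps - warmup))
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + (max_lr - min_lr) * cosine

# Hypothetical run length, used only for this example.
schedule = [learning_rate(s, total_steps=100_000) for s in (0, 3999, 50_000, 100_000)]
```

The returned value would be written into the optimizer's parameter groups before each step, with gradient norms clipped to `clip_norm` beforehand.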
Inference
- 250 diffusion steps (but trained on 1000 timesteps)
- cfg_coeff = 5.0
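Both inference settings can be illustrated with a short sketch: sampling on an evenly strided subset of the 1000 training timesteps, and combining conditional and unconditional noise predictions with classifier-free guidance. The function names are hypothetical, and the guidance formula shown is one common convention; the repository may use a different but related form.

```python
def sampling_timesteps(num_train_timesteps=1000, num_sampling_steps=250):
    """Evenly strided timesteps, from the noisiest down to the cleanest."""
    stride = num_train_timesteps // num_sampling_steps
    return list(range(num_train_timesteps - 1, -1, -stride))

def guided_noise(eps_cond, eps_uncond, cfg_coeff=5.0):
    """Classifier-free guidance: eps_uncond + w * (eps_cond - eps_uncond)."""
    return [u + cfg_coeff * (c - u) for c, u in zip(eps_cond, eps_uncond)]

timesteps = sampling_timesteps()
eps = guided_noise([1.0, 2.0], [0.0, 2.0])
```

Under this convention, `cfg_coeff = 1.0` reduces to the conditional prediction alone, and larger values push the sample further toward the condition.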
