TokenFormer

[ICLR2025 Spotlight🔥] Official Implementation of TokenFormer: Rethinking Transformer Scaling with Tokenized Model Parameters

TokenFormer: a fully attention-based neural network with tokenized model parameters. Maximizing the flexibility of Transformer by Tokenizing Anything.


This repo is the official implementation of our paper: TokenFormer: Rethinking Transformer Scaling with Tokenized Model Parameters as well as the follow-ups. Our TokenFormer is a natively scalable architecture that leverages the attention mechanism not only for computations among input tokens but also for interactions between tokens and model parameters, thereby enhancing architectural flexibility. We have made every effort to ensure that the codebase is clean, concise, easily readable, state-of-the-art, and relies only on minimal dependencies.

TokenFormer: Rethinking Transformer Scaling with Tokenized Model Parameters

Haiyang Wang*, Yue Fan*, Muhammad Ferjad Naeem, Yongqin Xian, Jan Eric Lenssen, Liwei Wang, Federico Tombari, Bernt Schiele

  • Primary contact: Haiyang Wang ~~(haiwang@mpi-inf.mpg.de)~~(wanghaiyang6@stu.pku.edu.cn), Bernt Schiele (schiele@mpi-inf.mpg.de)
<div align="center"> <img src="assets/Figure1.png" width="800"/> </div>

📣 News

  • [25-02-11] 🔥 TokenFormer is accepted as a spotlight presentation.
  • [25-01-22] 🔥 TokenFormer is accepted by ICLR2025.
  • [25-01-12] JAX code for TPU (GCP-Cloud) is released, please see here.
  • [24-11-08] 🚀 Training code in PyTorch is released.
  • [24-11-02] Please feel free to email me if I've missed any relevant papers. I will do my best to include all related papers in future versions.
  • [24-10-31] 🚀 Inference code in PyTorch is released.
  • [24-10-31] 👀 TokenFormer is released on arXiv.

🔥 Some Thoughts

  • We aim to offer a new perspective on models, one that could apply to any computation graph in the future. In theory, data tokens, parameter tokens, and memory tokens, combined through dot-product interactions, make it possible to construct any network flexibly. There are many design possibilities here. For example, introducing memory tokens can build RNN-like networks similar to Mamba. Merging parameter tokens with memory tokens creates something akin to a TTT network. Parameter tokens can also attend to input data in reverse, making the network parameters dynamically data-dependent and updated layer by layer.

Overview

💫 What do we want to do?

We introduce TokenFormer, a <font color="red">fully attention-based</font> architecture that unifies the computation of token-token and token-parameter interactions by relying entirely on the attention mechanism, which <font color="red">maximizes the flexibility of the neural network</font>. As a result, it can handle a variable number of parameters, which inherently improves the model's scalability and enables progressively efficient scaling.

<font color="red">We tokenize not only data but also model parameters, replacing the model concept with interaction flows between data and parameter tokens, further advancing the network architecture towards unification.</font>

We hope that this architecture offers greater flexibility than traditional Transformers and will further contribute to the development of foundation models, sparse inference (MoE), parameter-efficient tuning, device-cloud collaboration, vision-language models, model interpretability, and so on.

```python
# Pattention Implementations with given inputs
query, key, value = inputs, key_param_tokens, value_param_tokens

attn_weight = query @ key.transpose(-2, -1) * scale_factor
attn_weight *= attn_masks

# modified softmax, softmax is equal to exp + L1 norm
attn_weight = nonlinear_norm_func(attn_weight, self.norm_activation_type, dim=-1)

output = attn_weight @ value
```
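The excerpt above depends on names defined elsewhere in the codebase (`nonlinear_norm_func`, `attn_masks`, `self.norm_activation_type`). As a rough, self-contained illustration of the same token-parameter attention, here is a minimal NumPy sketch (NumPy stands in for PyTorch only to keep the example dependency-light). The helper names `gelu`, `nonlinear_norm`, and `pattention` are hypothetical, and the `l2_norm_gelu` branch (L2-normalize, rescale by the square root of the number of parameter tokens, then GeLU) is an assumption about what a "modified softmax" could look like, not a copy of the repository's function. Masking is omitted.

```python
import numpy as np


def gelu(x):
    # Tanh approximation of GeLU.
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def nonlinear_norm(scores, kind="l2_norm_gelu"):
    """Normalize token-parameter scores along the parameter-token axis."""
    if kind == "softmax":
        # Standard softmax: exp followed by L1 normalization.
        e = np.exp(scores - scores.max(axis=-1, keepdims=True))
        return e / e.sum(axis=-1, keepdims=True)
    if kind == "l2_norm_gelu":
        # Assumed variant: L2-normalize, rescale by sqrt(n), then apply GeLU.
        n = scores.shape[-1]
        l2 = np.linalg.norm(scores, axis=-1, keepdims=True)
        return gelu(scores / (l2 + 1e-12) * np.sqrt(n))
    raise ValueError(f"unknown normalization: {kind}")


def pattention(inputs, key_param_tokens, value_param_tokens,
               kind="l2_norm_gelu", scale=1.0):
    # inputs: (T, d_in); key_param_tokens: (n, d_in); value_param_tokens: (n, d_out)
    scores = inputs @ key_param_tokens.T * scale
    weights = nonlinear_norm(scores, kind)
    return weights @ value_param_tokens  # (T, d_out)


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))      # 4 input tokens of width 8
K = rng.normal(size=(16, 8))     # 16 key parameter tokens
V = rng.normal(size=(16, 32))    # 16 value parameter tokens
out = pattention(x, K, V)
print(out.shape)  # (4, 32)
```

The output width is set by the value parameter tokens, and the number of parameter tokens (16 here) can change without altering the input or output shapes.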

🚀 Main results

Incremental model scaling

<div align="center"> <img src="assets/Figure2.png" width="800"/> </div>

Traditionally, large transformer architectures are trained from scratch without reusing previously trained smaller models. In this paper, we propose a novel fully attention-based architecture that allows a model to be scaled incrementally, greatly reducing the overall cost of training large transformer architectures.
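One way to picture incremental scaling is to append new key-value parameter tokens to an existing layer. The NumPy sketch below is a toy illustration under stated assumptions, not the repository's training procedure: the helpers `pattention` and `grow` are hypothetical, new key tokens are zero-initialized, and the scale `tau` is held fixed at its pre-growth value (the actual code may tie the scale to the token count). Under those assumptions, the new tokens receive a score of zero, GeLU maps zero to zero, and the layer output is unchanged right after growth.

```python
import numpy as np


def gelu(x):
    # Tanh approximation of GeLU; gelu(0) == 0.
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def pattention(x, keys, values, tau):
    # L2-normalize the scores, rescale by a fixed tau, then apply GeLU.
    scores = x @ keys.T
    l2 = np.linalg.norm(scores, axis=-1, keepdims=True)
    return gelu(scores / (l2 + 1e-12) * tau) @ values


def grow(keys, values, n_new, rng):
    # New key tokens start at zero; new value tokens may be initialized freely.
    new_keys = np.zeros((n_new, keys.shape[1]))
    new_values = rng.normal(size=(n_new, values.shape[1]))
    return np.vstack([keys, new_keys]), np.vstack([values, new_values])


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))
K, V = rng.normal(size=(16, 8)), rng.normal(size=(16, 32))
tau = np.sqrt(K.shape[0])  # assumption: fixed at the original token count

before = pattention(x, K, V, tau)
K_big, V_big = grow(K, V, n_new=8, rng=rng)
after = pattention(x, K_big, V_big, tau)

# The grown layer has 24 parameter tokens and reproduces the original output.
print(K_big.shape, np.allclose(before, after))
```

Training can then continue from the grown layer, with the new tokens contributing once their keys move away from zero.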

Language modeling on Pile dataset with zero-shot evaluation

(Zero-shot Evaluations.) The best performance for each model size is highlighted in bold. Our comparisons are made with publicly available transformer-based LMs with various tokenizers. Following Pythia, our model is trained for up to 300B tokens on the Pile dataset.

<div align="center"> <img src="assets/Figure3.png" width="800"/> </div>

Visual modeling on ImageNet-1k classification

(Image Classification.) Comparison with the standard Vision Transformer on ImageNet-1K.

<div align="center"> <img src="assets/Figure4.png" width="1000"/> </div>

📘 Model Zoo

Language Modeling Benchmark (Pile)

Pretrained models are available on Hugging Face: TokenFormer-150M, TokenFormer-450M, TokenFormer-900M and TokenFormer-1-5B, each trained on 300B tokens from the Pile.

These models were trained on the Pile, follow the standard Transformer model dimensions, and were evaluated on the standard zero-shot benchmarks described by Mamba:

| Model | Params | Layers | Model dim. | ckpt | config | log |
|---------|---------|---------|--------|--------|---------|---------|
| TokenFormer-150M | 150M | 12 | 768 | ckpt | config | log |
| TokenFormer-450M | 450M | 24 | 1024 | ckpt | config | log |
| TokenFormer-900M | 900M | 32 | 1280 | ckpt | config | log |
| TokenFormer-1-5B | 1.5B | 40 | 1536 | ckpt | config | log |

Note: these are base models trained for only 300B tokens, without any form of downstream modification (instruction tuning, etc.). Performance is expected to be comparable to or better than other architectures trained on similar data, but not to match larger or fine-tuned models.

Visual Modeling Benchmark (DataComp-1B on CLIP approach)

Will be released later.

🛠️ Quick Start

Installation

First make sure you are in an environment with Python 3.8 and CUDA 12, with an appropriate version of PyTorch (1.8 or later) installed. Note: TokenFormer is based on GPT-NeoX, and some of the libraries that GPT-NeoX depends on have not been updated to be compatible with Python 3.10+. Python 3.9 appears to work, but this codebase has been developed and tested for Python 3.8.
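The version guidance above can be checked before installing anything. The snippet below is a small convenience sketch using only the standard library; `python_version_status` is a hypothetical helper, not part of the repository, and its return strings simply restate the note above.

```python
import sys


def python_version_status(version_info=sys.version_info):
    """Classify a Python version according to the compatibility note above."""
    major, minor = version_info[0], version_info[1]
    if (major, minor) == (3, 8):
        return "tested"
    if (major, minor) == (3, 9):
        return "appears to work"
    if major == 3 and minor >= 10:
        return "may break some GPT-NeoX dependencies"
    return "unknown"


print(f"Python {sys.version_info[0]}.{sys.version_info[1]}: {python_version_status()}")
```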

To install the remaining basic dependencies, run:

```bash
conda create -n TokenFormer python=3.8
# activate the environment before installing packages
conda activate TokenFormer

git clone https://github.com/Haiyang-W/TokenFormer.git
# the requirements paths below are relative to the repository root
cd TokenFormer

pip install torch==2.2.1 torchvision==0.17.1 torchaudio==2.2.1 --index-url https://download.pytorch.org/whl/cu121

### raven module load gcc/10

### If you hit a cargo error when running pip install -r requirements/requirements.txt, run the commands below
# curl https://sh.rustup.rs -sSf | sh
# export PATH="$HOME/.cargo/bin:$PATH"
# source ~/.profile
# source ~/.cargo/env

### If you hit an mpi4py error when running pip install -r requirements/requirements.txt, run:
# conda install -c conda-forge mpi4py=3.0.3

pip install -r requirements/requirements.txt

pip install -r requirements/requirements-flashattention.txt # need gcc > 9
pip install -r requirements/requirements-wandb.txt # optional, if logging using WandB
pip install -r requirements/requirements-tensorboard.
```