GRIT
This is an official implementation for "GRIT: Graph Inductive Biases in Transformers without Message Passing".
This repo is the official implementation of Graph Inductive Biases in Transformers without Message Passing (Ma et al., ICML 2023) [PMLR] [arXiv]
The implementation is based on GraphGPS (Rampasek et al., 2022).
Correction of Typos in Paper
There is a typo on $\mathbf{W}_\text{V}$ in the sentence following Eq. (2). The corrected version is as follows:
``where $\sigma$ is a non-linear activation (ReLU by default); $\mathbf{W}_\text{Q}, \mathbf{W}_\text{K}, \mathbf{W}_\text{Ew}, \mathbf{W}_\text{Eb} \in \mathbb{R}^{d' \times d}$, $\mathbf{W}_\text{A} \in \mathbb{R}^{1 \times d'}$, $\mathbf{W}_\text{V} \in \mathbb{R}^{d \times d}$ and $\mathbf{W}_\text{Ev} \in \mathbb{R}^{d \times d'}$ are learnable weight matrices; ......''
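As a sanity check on the corrected dimensions, here is a minimal NumPy sketch; the values of `d`, `d_prime`, and `n` are hypothetical, and the value path shown only mirrors the matrix shapes stated above, not the full attention of Eq. (2):

```python
import numpy as np

# Hypothetical sizes, for illustration only
d, d_prime = 64, 32
n = 10  # number of nodes

W_Q = np.random.randn(d_prime, d)   # W_Q in R^{d' x d}
W_K = np.random.randn(d_prime, d)   # W_K in R^{d' x d}
W_A = np.random.randn(1, d_prime)   # W_A in R^{1 x d'}
W_V = np.random.randn(d, d)         # corrected: W_V in R^{d x d}
W_Ev = np.random.randn(d, d_prime)  # W_Ev in R^{d x d'}

x = np.random.randn(n, d)           # node features
e = np.random.randn(n, n, d_prime)  # (already-projected) edge features

# Value path: W_V x_j + W_Ev e_ij stays in R^d, so the attention
# output matches the node-feature dimension d.
v = x @ W_V.T + e @ W_Ev.T
assert v.shape == (n, n, d)
```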
Python environment setup with Conda
```bash
conda create -n grit python=3.9
conda activate grit

# please change the cuda/device version as you need
pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113 --trusted-host download.pytorch.org
pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric==2.2.0 -f https://data.pyg.org/whl/torch-1.12.1+cu113.html --trusted-host data.pyg.org

# RDKit is required for OGB-LSC PCQM4Mv2 and datasets derived from it.
## conda install openbabel fsspec rdkit -c conda-forge
pip install rdkit
pip install torchmetrics==0.9.1
pip install ogb
pip install tensorboardX
pip install yacs
pip install opt_einsum
pip install graphgym
pip install pytorch-lightning  # required by graphgym
pip install setuptools==59.5.0
# distutils conflicts with pytorch under the latest versions of setuptools

# ---- experiment management tools --------
# pip install wandb  # wandb is used in GraphGPS but not in GRIT (ours); please verify it works before enabling it.
# pip install mlflow
### mlflow server --backend-store-uri mlruns --port 5000
```
Running GRIT
```bash
# Run
python main.py --cfg configs/GRIT/zinc-GRIT.yaml wandb.use False accelerator "cuda:0" optim.max_epoch 2000 seed 41 dataset.dir 'xx/xx/data'
# replace 'cuda:0' with the device to use
# replace 'xx/xx/data' with your data directory (default: './datasets')
# replace 'configs/GRIT/zinc-GRIT.yaml' with the config of any experiment to run
```
Configurations and Scripts
- Configurations are available under `./configs/GRIT/xxxxx.yaml`.
- Scripts to execute are available under `./scripts/xxx.sh`; each will run 4 trials of the experiment in parallel on GPUs 0, 1, 2, and 3.
Intro to the Code Structure
Our code is based on GraphGym, which relies heavily on module registration: modules are combined by name rather than by explicit imports.
This makes the code challenging to trace starting from main.py, so we provide hints on the overall code architecture below.
You can write your own customized modules and register them to build new models under this framework.
The overall architecture of the code: ([x] indicates 'x' is a folder in the code)
- model
  - utils
    - [act] (activation functions, called by other modules)
    - [pooling] (global pooling functions, called in the output head for graph-level tasks)
  - [network] (the macro model architecture: stem -> backbone -> output head)
  - [encoder] (feature/PE encoders, the stem, bridging inputs to the backbone)
  - [layer] (backbone layers)
  - [head] (task-dependent output heads)
- training pipeline
  - data
    - [loader] (data loaders)
    - [transform] (pre-computed transforms: PE and other preprocessing)
  - [train] (training pipeline: logging, visualization, early stopping, checkpointing, etc.)
  - [optimizer] (optimizers and LR schedulers)
  - [loss] (loss functions)
- [config] (the default configurations)
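The registration mechanism can be illustrated with a minimal stand-in registry; the names below are hypothetical, and GraphGym's real decorators live in the `graphgym` / `torch_geometric.graphgym` packages:

```python
# Minimal stand-in for GraphGym-style module registration (illustrative
# only; not GraphGym's actual API).
layer_registry = {}

def register_layer(name):
    """Decorator that records a class under a string name."""
    def wrap(cls):
        layer_registry[name] = cls
        return cls
    return wrap

@register_layer("my_custom_layer")
class MyCustomLayer:
    def forward(self, x):
        return x  # identity, as a placeholder

# A config entry such as layer_type: my_custom_layer can then be
# resolved to the registered class purely by name:
layer_cls = layer_registry["my_custom_layer"]
```

This is why the model is assembled from config strings rather than imports: the config names are looked up in registries populated at import time.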
Notes on RRWP
Storing all RRWP values for large graphs can be memory-intensive, as torch_geometric loads the entire dataset into memory by default.
Alternatively, you can customize the PyG dataset class or compute RRWP on the fly within the dataloader. Owing to the simplicity of RRWP computation, performing it on the fly only marginally slows down training when multiple dataloader workers are used (for graphs with fewer than 500 nodes).
Example config can be found in cifar10-GRIT-RRWP.yaml (line 5 and line 14).
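For reference, a minimal NumPy sketch of RRWP as defined in the paper (the identity plus successive powers of the degree-normalized adjacency $\mathbf{M} = \mathbf{D}^{-1}\mathbf{A}$); the function name and the choice of `K` here are illustrative, and wiring this into a PyG transform or dataloader is left to you:

```python
import numpy as np

def rrwp(adj, K=4):
    """Stack [I, M, M^2, ..., M^(K-1)] with M = D^{-1} A."""
    deg = adj.sum(axis=1, keepdims=True)
    M = adj / np.maximum(deg, 1)           # degree-normalized adjacency
    out = [np.eye(adj.shape[0])]
    for _ in range(K - 1):
        out.append(out[-1] @ M)            # next random-walk power
    return np.stack(out, axis=-1)          # shape (n, n, K)

# Tiny path graph 0 - 1 - 2 as an example
A = np.array([[0, 1, 0],
              [1, 0, 1],
              [0, 1, 0]], dtype=float)
P = rrwp(A, K=3)                           # P[:, :, 0] is the identity
```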
Citation
If you find this work useful, please consider citing:
```bibtex
@inproceedings{ma2023GraphInductiveBiases,
  title     = {Graph {Inductive} {Biases} in {Transformers} without {Message} {Passing}},
  booktitle = {Proc. {Int}. {Conf}. {Mach}. {Learn}.},
  author    = {Ma, Liheng and Lin, Chen and Lim, Derek and Romero-Soriano, Adriana and Dokania, Puneet K. and Coates, Mark and Torr, Philip H. S. and Lim, Ser-Nam},
  year      = {2023},
}
```