SkillAgentSearch skills...

GatedTabTransformer

A deep learning tabular classification architecture inspired by TabTransformer with integrated gated multilayer perceptron.

Install / Use

/learn @radi-cho/GatedTabTransformer
About this skill

Quality Score

0/100

Supported Platforms

Universal

README

The GatedTabTransformer.

A deep learning tabular classification architecture inspired by TabTransformer with integrated gated multilayer perceptron. Check out our paper on arXiv. Applications and usage demonstrations are available here.

<img alt="Architecture" src="./paper/media/GatedTabTransformer-architecture.png" width="350px"></img>

Usage

import torch
import torch.nn as nn
from gated_tab_transformer import GatedTabTransformer

model = GatedTabTransformer(
    categories = (10, 5, 6, 5, 8),      # tuple containing the number of unique values within each category
    num_continuous = 10,                # number of continuous values
    transformer_dim = 32,               # dimension, paper set at 32
    dim_out = 1,                        # binary prediction, but could be anything
    transformer_depth = 6,              # depth, paper recommended 6
    transformer_heads = 8,              # heads, paper recommends 8
    attn_dropout = 0.1,                 # post-attention dropout
    ff_dropout = 0.1,                   # feed forward dropout
    mlp_act = nn.LeakyReLU(0),          # activation for final mlp, defaults to relu, but could be anything else (selu, etc.)
    mlp_depth=4,                        # mlp hidden layers depth
    mlp_dimension=32,                   # dimension of mlp layers
    gmlp_enabled=True                   # gmlp or standard mlp
)

x_categ = torch.randint(0, 5, (1, 5))   # category values, from 0 - max number of categories, in the order as passed into the constructor above
x_cont = torch.randn(1, 10)             # assume continuous values are already normalized individually

pred = model(x_categ, x_cont)
print(pred)

Citation

@misc{cholakov2022gatedtabtransformer,
      title={The GatedTabTransformer. An enhanced deep learning architecture for tabular modeling}, 
      author={Radostin Cholakov and Todor Kolev},
      year={2022},
      eprint={2201.00199},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
View on GitHub
GitHub Stars101
CategoryEducation
Updated6mo ago
Forks7

Languages

Jupyter Notebook

Security Score

92/100

Audited on Sep 12, 2025

No findings