# TabM: Advancing Tabular Deep Learning With Parameter-Efficient Ensembling (ICLR 2025)<!-- omit in toc -->

:scroll: arXiv &nbsp; :books: Other tabular DL projects
This is the official repository of the paper "TabM: Advancing Tabular Deep Learning With Parameter-Efficient Ensembling". It consists of two parts:

- The Python package described in this document.
- The paper-related content (code, metrics, hyperparameters, etc.) described in `paper/README.md`.
- TabM was used in the winning solution in the competition by UM.
- TabM was used in the winning solution, as well as in the top-3, top-4, top-5 and many other solutions in the competition by CIBMTR. Later, it turned out that it was possible to achieve 25th place out of 3300+ with TabM alone, without ensembling it with other models.
TabReD is a benchmark based on real-world industrial datasets with time-related distribution drifts and hundreds of features, which makes it more challenging than traditional benchmarks. The figure below shows that TabM achieves higher performance on TabReD (plus one more real-world dataset) compared to prior tabular DL methods.
<img src="images/tabred-and-microsoft.png" width="35%" style="display: block; margin: auto;">

One dot represents a performance score on one dataset. For a given model, a diamond represents the mean value across the datasets.
</details>

<details>
<summary>Training and inference efficiency</summary>

TabM is a simple and reasonably efficient model, which makes it suitable for real-world applications, including large datasets. The biggest dataset used in the paper contains 13M objects, and we are aware of a successful training run on 100M+ objects, though training takes more time in such cases.
The figure below shows that TabM is slower than MLPs and GBDT, but faster than prior tabular DL methods. Note that (1) the inference throughput was measured on a single CPU thread and without any optimizations, in particular without the TabM-specific acceleration technique described later in this document; (2) the left plot uses a logarithmic scale.
<img src="images/efficiency.png" style="display: block; margin: auto;">

One dot represents a measurement on one dataset. For a given model, a diamond represents the mean value across the datasets. In the left plot, $\mathrm{TabM_{mini}^{\dagger*}}$ denotes $\mathrm{TabM_{mini}^{\dagger}}$ trained with mixed precision and `torch.compile`.

</details>
## TL;DR<!-- omit in toc -->

<img src="images/tabm.png" width="65%" style="display: block; margin: auto;">

TabM (**Tab**ular DL model that makes **M**ultiple predictions) is a simple and powerful tabular DL architecture that efficiently imitates an ensemble of MLPs. The two main differences of TabM compared to a regular ensemble of MLPs are:
- **Parallel training of the MLPs.** This allows monitoring the performance of the ensemble during training and stopping the training when it is optimal for the ensemble, not for the individual MLPs.
- **Weight sharing between the MLPs.** In fact, the whole TabM fits in just one MLP-like model. Not only does this significantly improve the runtime and memory efficiency, but it also turns out to be an effective regularization that leads to better task performance.
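The weight-sharing idea can be sketched as a BatchEnsemble-style linear layer: one weight matrix shared by all k ensemble members, plus cheap per-member rank-1 multiplicative adapters. This is a simplified illustration of the general technique, not the `tabm` package's actual implementation; the class and parameter names below are made up.

```python
import torch
import torch.nn as nn


class SharedWeightLinearEnsemble(nn.Module):
    """k linear layers that share one weight matrix and differ only in
    cheap per-member adapters (a BatchEnsemble-style sketch)."""

    def __init__(self, k: int, d_in: int, d_out: int) -> None:
        super().__init__()
        # One weight matrix shared by all k members.
        self.weight = nn.Parameter(torch.randn(d_out, d_in) / d_in**0.5)
        # Per-member input/output adapters and biases: only O(k * (d_in + d_out))
        # extra parameters instead of k full weight matrices.
        self.r = nn.Parameter(torch.ones(k, d_in))
        self.s = nn.Parameter(torch.ones(k, d_out))
        self.bias = nn.Parameter(torch.zeros(k, d_out))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, k, d_in) -> (B, k, d_out)
        return (x * self.r) @ self.weight.T * self.s + self.bias


layer = SharedWeightLinearEnsemble(k=8, d_in=16, d_out=32)
y = layer(torch.randn(4, 8, 16))
assert y.shape == (4, 8, 32)
```

All k members are evaluated in one batched computation, which is why an entire TabM fits in a single MLP-shaped forward pass.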
## Reproducing experiments and browsing results<!-- omit in toc -->

> [!IMPORTANT]
> To use TabM in practice and for future work, use the `tabm` package described below.

The paper-related content (code, metrics, hyperparameters, etc.) is located in the `paper/` directory and is described in `paper/README.md`.
## Python package<!-- omit in toc -->

`tabm` is a PyTorch-based Python package providing the TabM model, as well as layers and tools for building custom TabM-like architectures (i.e. efficient ensembles of MLP-like models).

### Installation

```
pip install tabm
```
### Basic usage

This section shows how to create a model in typical use cases, and gives high-level comments on training and inference.
#### Creating TabM

The example below showcases the basic version of TabM without feature embeddings. For better performance, `num_embeddings` should usually be passed as explained in the next section.

<!-- test main -->
> [!NOTE]
> `TabM.make(...)` used below adds default hyperparameters based on the provided arguments.
```python
import torch
from tabm import TabM

# >>> Common setup for all subsequent sections.
d_out = 1  # For example, one regression task.
batch_size = 256

# The dataset has 24 numerical (continuous) features.
n_num_features = 24

# The dataset has 2 categorical features.
# The first categorical feature has 3 unique categories.
# The second categorical feature has 7 unique categories.
cat_cardinalities = [3, 7]
# <<<

model = TabM.make(
    n_num_features=n_num_features,
    cat_cardinalities=cat_cardinalities,  # One-hot encoding will be used.
    d_out=d_out,
)

x_num = torch.randn(batch_size, n_num_features)
x_cat = torch.column_stack([
    # The i-th categorical feature must take values in range(0, cat_cardinalities[i]).
    torch.randint(0, c, (batch_size,))
    for c in cat_cardinalities
])

y_pred = model(x_num, x_cat)

# TabM represents an ensemble of k models, hence k predictions per object.
assert y_pred.shape == (batch_size, model.k, d_out)
```
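As for training and inference with the k predictions: in the paper, the task loss is applied to each of the k predictions independently, and at inference the k predictions are averaged (for classification, probabilities rather than logits should be averaged). Below is a minimal sketch of this pattern with a stand-in prediction tensor in place of a real TabM output; the regression loss is an illustrative assumption.

```python
import torch
import torch.nn.functional as F

batch_size, k, d_out = 256, 32, 1

# Stand-in for TabM's output: k predictions per object, shape (B, k, d_out).
y_pred = torch.randn(batch_size, k, d_out)
y_true = torch.randn(batch_size, d_out)

# Training: the target is broadcast over the ensemble dimension,
# so each of the k predictions receives its own loss signal.
loss = F.mse_loss(y_pred, y_true.unsqueeze(1).expand_as(y_pred))

# Inference: the final prediction is the mean over the k predictions.
y_final = y_pred.mean(dim=1)
assert y_final.shape == (batch_size, d_out)
```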
#### Creating TabM with feature embeddings

On typical tabular tasks, the best performance is usually achieved by passing feature embedding modules as `num_embeddings` (in the paper, TabM with embeddings is denoted as $\mathrm{TabM^\dagger}$). TabM supports several feature embedding modules from the `rtdl_num_embeddings` package. The example below showcases the simplest embedding module, `LinearReLUEmbeddings`.

<!-- test main _ -->
> [!TIP]
> The best performance is usually achieved with more advanced embeddings, such as `PiecewiseLinearEmbeddings` and `PeriodicEmbeddings`. Their usage is covered in the end-to-end usage example.
```python
from rtdl_num_embeddings import LinearReLUEmbeddings

model = TabM.make(
    n_num_features=n_num_features,
    num_embeddings=LinearReLUEmbeddings(n_num_features),
    d_out=d_out,
)

x_num = torch.randn(batch_size, n_num_features)
y_pred = model(x_num)
assert y_pred.shape == (batch_size, model.k, d_out)
```
#### Using TabM with custom inputs and input modules

> [!TIP]
> The implementation of `tabm.TabM` is a good example of defining inputs and input modules in TabM-based models.

Assume that you want to change what input TabM takes or how TabM handles the input, but you still want to use TabM as the backbone. Then, a typical usage looks as follows:
```python
import torch.nn as nn

from tabm import EnsembleView, LinearEnsemble, make_tabm_backbone


class Model(nn.Module):
    def __init__(self, ...):
        super().__init__()

        # >>> Create any custom modules.
        ...
        # <<<

        # Create the ensemble input module.
        self.ensemble_view = EnsembleView(...)

        # Create the backbone.
        self.backbone = make_tabm_backbone(...)

        # Create the prediction head.
        self.output = LinearEnsemble(...)

    def forward(self, arg1, arg2, ...):
        # Transform the input as needed to one tensor.
        # This step can include feature embeddings
        # and all other kinds of feature transformations.
        # `handle_input` is a hypothetical user-defined function.
        x = handle_input(arg1, arg2, ...)  # -> (B, D) or (B, k, D)

        # The only difference from conventional models is
        # the call of self.ensemble_view.
        x = self.ensemble_view(x)  # -> (B, k, D)

        x = self.backbone(x)
        x = self.output(x)
        return x  # -> (B, k, d_out)
```
> [!NOTE]
> Regarding the shape of `x` in the line `x = handle_input(...)`:
>
> - TabM can be used as a conventional MLP-like backbone, which corresponds to `x` having the standard shape `(B, D)` during both training and inference. Thi
