
# TabM: Advancing Tabular Deep Learning With Parameter-Efficient Ensembling (ICLR 2025) <!-- omit in toc -->

:scroll: arXiv   :books: Other tabular DL projects

This is the official repository of the paper "TabM: Advancing Tabular Deep Learning With Parameter-Efficient Ensembling". It consists of two parts: the paper-related content in the `paper/` directory and the practical `tabm` Python package, both described below.

<br> <details> <summary>TabM on <b>Kaggle</b> (as of June 2025)</summary>
  • TabM was used in the winning solution of the competition by UM.
  • TabM was used in the winning solution, as well as in the top-3, top-4, top-5, and many other solutions of the competition by CIBMTR. Later, it turned out that 25th place out of 3300+ teams could have been achieved with TabM alone, without ensembling it with other models.
</details> <details> <summary>TabM on <b>TabReD</b> (a challenging benchmark)</summary>

TabReD is a benchmark based on real-world industrial datasets with time-related distribution drifts and hundreds of features, which makes it more challenging than traditional benchmarks. The figure below shows that TabM achieves higher performance on TabReD (plus one more real-world dataset) compared to prior tabular DL methods.

<img src="images/tabred-and-microsoft.png" width="35%" style="display: block; margin: auto;">

One dot represents a performance score on one dataset. For a given model, a diamond represents the mean value across the datasets.

</details> <details> <summary>Training and inference efficiency</summary>

TabM is a simple and reasonably efficient model, which makes it suitable for real-world applications, including large datasets. The biggest dataset used in the paper contains 13M objects, and we are aware of a successful training run on 100M+ objects, though training takes more time in such cases.

The figure below shows that TabM is slower than plain MLPs and GBDT, but faster than prior tabular DL methods. Note that (1) the inference throughput was measured on a single CPU thread and without any optimizations, in particular without the TabM-specific acceleration technique described later in this document; (2) the left plot uses a logarithmic scale.

<img src="images/efficiency.png" style="display: block; margin: auto;">

One dot represents a measurement on one dataset. For a given model, a diamond represents the mean value across the datasets. In the left plot, $\mathrm{TabM_{mini}^{\dagger*}}$ denotes $\mathrm{TabM_{mini}^{\dagger}}$ trained with mixed precision and torch.compile.

</details>

## TL;DR <!-- omit in toc -->

<img src="images/tabm.png" width="65%" style="display: block; margin: auto;">

TabM (Tabular DL model that makes Multiple predictions) is a simple and powerful tabular DL architecture that efficiently imitates an ensemble of MLPs. The two main differences between TabM and a regular ensemble of MLPs are:

  • Parallel training of the MLPs. This makes it possible to monitor the performance of the ensemble during training and to stop training when it is optimal for the ensemble, not for the individual MLPs.
  • Weight sharing between the MLPs. In fact, the whole TabM fits in just one MLP-like model. Not only does this significantly improve runtime and memory efficiency, but it also turns out to be an effective regularizer, leading to better task performance.
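The weight-sharing idea can be made concrete with a minimal BatchEnsemble-style sketch in pure PyTorch. This is an illustrative toy, not TabM's actual implementation: the class name `MiniEnsembleMLP`, the adapter parameters `r`/`s`, and the shared prediction head are all assumptions of this sketch.

```python
import torch
import torch.nn as nn


class MiniEnsembleMLP(nn.Module):
    """Toy BatchEnsemble-style ensemble: all k members share one weight matrix;
    each member only owns cheap elementwise adapters, so the whole ensemble
    costs roughly as much as a single MLP."""

    def __init__(self, d_in: int, d_hidden: int, d_out: int, k: int) -> None:
        super().__init__()
        self.k = k
        self.shared = nn.Linear(d_in, d_hidden)          # shared by all members
        self.r = nn.Parameter(torch.randn(k, d_in))      # per-member input scaling
        self.s = nn.Parameter(torch.randn(k, d_hidden))  # per-member hidden scaling
        self.head = nn.Linear(d_hidden, d_out)           # shared head (a simplification)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (B, d_in)
        x = x[:, None, :] * self.r                       # (B, k, d_in): member-specific inputs
        h = torch.relu(self.shared(x) * self.s)          # shared weights, member-specific scales
        return self.head(h)                              # (B, k, d_out): k predictions per object


model = MiniEnsembleMLP(d_in=8, d_hidden=16, d_out=1, k=4)
y = model(torch.randn(32, 8))
```

Because all k members live in one module and run in one forward pass, the ensemble can be trained and early-stopped jointly, exactly as described above.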

## Reproducing experiments and browsing results <!-- omit in toc -->

> [!IMPORTANT]
> To use TabM in practice and for future work, use the `tabm` package described below.

The paper-related content (code, metrics, hyperparameters, etc.) is located in the paper/ directory and is described in paper/README.md.

## Python package <!-- omit in toc -->

tabm is a PyTorch-based Python package providing the TabM model, as well as layers and tools for building custom TabM-like architectures (i.e. efficient ensembles of MLP-like models).

### Installation

```
pip install tabm
```

### Basic usage

This section shows how to create a model in typical use cases, and gives high-level comments on training and inference.

#### Creating TabM

The example below showcases the basic version of TabM without feature embeddings. For better performance, `num_embeddings` should usually be passed, as explained in the next section.

> [!NOTE]
> `TabM.make(...)` used below adds default hyperparameters based on the provided arguments.

<!-- test main -->
```python
import torch
from tabm import TabM

# >>> Common setup for all subsequent sections.
d_out = 1  # For example, one regression task.
batch_size = 256

# The dataset has 24 numerical (continuous) features.
n_num_features = 24

# The dataset has 2 categorical features.
# The first categorical feature has 3 unique categories.
# The second categorical feature has 7 unique categories.
cat_cardinalities = [3, 7]
# <<<

model = TabM.make(
    n_num_features=n_num_features,
    cat_cardinalities=cat_cardinalities,  # One-hot encoding will be used.
    d_out=d_out,
)
x_num = torch.randn(batch_size, n_num_features)
x_cat = torch.column_stack([
    # The i-th categorical feature must take values in range(0, cat_cardinalities[i]).
    torch.randint(0, c, (batch_size,)) for c in cat_cardinalities
])
y_pred = model(x_num, x_cat)

# TabM represents an ensemble of k models, hence k predictions per object.
assert y_pred.shape == (batch_size, model.k, d_out)
```
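The k predictions per object are typically consumed as follows: during training, every member is fitted to the same target, and at inference the k predictions are averaged. The snippet below is a hedged sketch of that convention, not an excerpt from the package; the tensors are randomly generated stand-ins for a regression task.

```python
import torch
import torch.nn.functional as F

# Hypothetical TabM-shaped output for a regression task: k predictions per object.
batch_size, k = 256, 32
y_pred = torch.randn(batch_size, k, 1)
y_true = torch.randn(batch_size)

# Training: each of the k ensemble members is fitted to the same target,
# so the loss is averaged over the ensemble dimension as well.
loss = F.mse_loss(y_pred.squeeze(-1), y_true[:, None].expand(batch_size, k))

# Inference: the final prediction is the mean over the k members.
y_final = y_pred.mean(dim=1).squeeze(-1)
```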

#### Creating TabM with feature embeddings

On typical tabular tasks, the best performance is usually achieved by passing feature embedding modules as `num_embeddings` (in the paper, TabM with embeddings is denoted as $\mathrm{TabM^\dagger}$). TabM supports several feature embedding modules from the `rtdl_num_embeddings` package. The example below showcases the simplest embedding module, `LinearReLUEmbeddings`.

> [!TIP]
> The best performance is usually achieved with more advanced embeddings, such as `PiecewiseLinearEmbeddings` and `PeriodicEmbeddings`. Their usage is covered in the end-to-end usage example.

<!-- test main _ -->
```python
from rtdl_num_embeddings import LinearReLUEmbeddings

model = TabM.make(
    n_num_features=n_num_features,
    num_embeddings=LinearReLUEmbeddings(n_num_features),
    d_out=d_out,
)
x_num = torch.randn(batch_size, n_num_features)
y_pred = model(x_num)

assert y_pred.shape == (batch_size, model.k, d_out)
```
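To see what a feature embedding module actually computes, here is a hand-rolled stand-in for the linear-ReLU idea in pure PyTorch. `SimpleLinearReLUEmbeddings` is an illustration of the concept, not the `rtdl_num_embeddings` implementation: each scalar feature gets its own linear map to a `d_embedding`-dimensional vector, followed by ReLU.

```python
import torch
import torch.nn as nn


class SimpleLinearReLUEmbeddings(nn.Module):
    """Toy per-feature embedding: (B, n_features) -> (B, n_features, d_embedding).
    Feature i is embedded with its own weight row and bias row."""

    def __init__(self, n_features: int, d_embedding: int = 16) -> None:
        super().__init__()
        bound = n_features ** -0.5
        self.weight = nn.Parameter(
            torch.empty(n_features, d_embedding).uniform_(-bound, bound)
        )
        self.bias = nn.Parameter(
            torch.empty(n_features, d_embedding).uniform_(-bound, bound)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (B, n_features)
        # (B, n_features, 1) * (n_features, d_embedding) -> (B, n_features, d_embedding)
        return torch.relu(x[..., None] * self.weight + self.bias)


emb = SimpleLinearReLUEmbeddings(n_features=24)
out = emb(torch.randn(256, 24))
```

The resulting three-dimensional tensor of per-feature vectors is then flattened and consumed by the backbone, which is why embeddings plug into `TabM.make(...)` as a separate module.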

#### Using TabM with custom inputs and input modules

> [!TIP]
> The implementation of `tabm.TabM` is a good example of defining inputs and input modules in TabM-based models.

Assume that you want to change what input TabM takes or how TabM handles the input, but you still want to use TabM as the backbone. Then, a typical usage looks as follows:

```python
import torch.nn as nn

from tabm import EnsembleView, make_tabm_backbone, LinearEnsemble


class Model(nn.Module):
    def __init__(self, ...):
        # >>> Create any custom modules.
        ...
        # <<<

        # Create the ensemble input module.
        self.ensemble_view = EnsembleView(...)
        # Create the backbone.
        self.backbone = make_tabm_backbone(...)
        # Create the prediction head.
        self.output = LinearEnsemble(...)

    def forward(self, arg1, arg2, ...):
        # Transform the input as needed into one tensor.
        # This step can include feature embeddings
        # and all other kinds of feature transformations.
        # `handle_input` is a hypothetical user-defined function.
        x = handle_input(arg1, arg2, ...)  # -> (B, D) or (B, k, D)

        # The only difference from conventional models is
        # the call of self.ensemble_view.
        x = self.ensemble_view(x)  # -> (B, k, D)
        x = self.backbone(x)
        x = self.output(x)
        return x  # -> (B, k, d_out)
```

> [!NOTE]
> Regarding the shape of `x` in the line `x = handle_input(...)`:
>
>   • TabM can be used as a conventional MLP-like backbone, which corresponds to `x` having the standard shape `(B, D)` during both training and inference.
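The `(B, D)` vs `(B, k, D)` convention can be illustrated with a minimal stand-in for the ensemble-view step. `BroadcastEnsembleView` below is a hypothetical sketch, not the actual `tabm.EnsembleView` (whose API may differ): it simply gives the input an ensemble dimension when it is missing.

```python
import torch
import torch.nn as nn


class BroadcastEnsembleView(nn.Module):
    """Hypothetical stand-in for an ensemble-view module: (B, D) inputs are
    repeated for every one of the k ensemble members; (B, k, D) inputs, where
    each member already has its own representation, pass through unchanged."""

    def __init__(self, k: int) -> None:
        super().__init__()
        self.k = k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.ndim == 2:                           # (B, D): same input for all k members
            x = x[:, None, :].expand(-1, self.k, -1)
        return x                                  # (B, k, D)


view = BroadcastEnsembleView(k=8)
a = view(torch.randn(4, 10))     # conventional (B, D) input
b = view(torch.randn(4, 8, 10))  # already ensemble-shaped input
```

Either way, the backbone downstream always receives a `(B, k, D)` tensor, which is what allows the k members to run in a single forward pass.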