Pytabkit

ML models + benchmark for tabular data classification and regression

PyTabKit: Tabular ML models and benchmarking (NeurIPS 2024)

Paper | Documentation | RealMLP-TD-S standalone implementation | Grinsztajn et al. benchmark code | Data archive

PyTabKit provides scikit-learn interfaces for the modern tabular classification and regression methods benchmarked in our paper (see below). It also contains the code we used to benchmark these methods.

Meta-test benchmark results

When (not) to use pytabkit

  • To get the best possible results:
    • In general, we recommend AutoGluon, though it does not include all the models from pytabkit. AutoGluon 1.4 includes RealMLP (though not in a default configuration) and TabM (in the "extreme" preset for <= 30K samples).
    • Within pytabkit, we recommend using Ensemble_HPO_Classifier(n_cv=8, use_full_caruana_ensembling=True, use_tabarena_spaces=True, n_hpo_steps=50) with a val_metric_name corresponding to your target metric (e.g., class_error, cross_entropy, brier, 1-auc_ovr), or the corresponding Regressor. (Fitting can take a very long time.)
    • For only a single model, we recommend using RealMLP_HPO_Classifier(n_cv=8, hpo_space_name='tabarena-new', use_caruana_ensembling=True, n_hyperopt_steps=50), also with val_metric_name as above, or the corresponding Regressor.
  • Models: TabArena also includes some newer models like RealMLP and TabM with more general preprocessing (missing numericals, text, etc.), as well as very good boosted tree implementations. pytabkit is currently still easier to use and supports vectorized cross-validation for RealMLP, which can significantly speed up the training.
  • Benchmarking: While pytabkit can be good for quick benchmarking for development, for method evaluation we recommend TabArena.
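
The two recommended configurations above can be written out as follows. This is a minimal sketch: the keyword arguments are copied from the recommendations, the top-level import of Ensemble_HPO_Classifier and RealMLP_HPO_Classifier from pytabkit is an assumption based on the other examples in this README, and build_recommended_model is a hypothetical helper that is not part of pytabkit.

```python
# Keyword arguments copied from the recommendations above.
ENSEMBLE_KWARGS = dict(n_cv=8, use_full_caruana_ensembling=True,
                       use_tabarena_spaces=True, n_hpo_steps=50)
REALMLP_HPO_KWARGS = dict(n_cv=8, hpo_space_name='tabarena-new',
                          use_caruana_ensembling=True, n_hyperopt_steps=50)


def build_recommended_model(single_model=False, val_metric_name='cross_entropy'):
    """Hypothetical helper: build one of the recommended classifiers.

    pytabkit is imported lazily so this module loads even if it is missing.
    Assumes both classes are importable from the top-level package.
    """
    import pytabkit
    if single_model:
        return pytabkit.RealMLP_HPO_Classifier(
            val_metric_name=val_metric_name, **REALMLP_HPO_KWARGS)
    return pytabkit.Ensemble_HPO_Classifier(
        val_metric_name=val_metric_name, **ENSEMBLE_KWARGS)
```

Calling build_recommended_model(val_metric_name='brier') would then return an unfitted estimator to use with fit() and predict_proba() as in the examples below.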

Installation (new in 1.4.0: optional model dependencies)

pip install pytabkit[models]
  • RealMLP (and TabM) can be used without the [models] part.
  • For xRFM on GPU, faster kernels will be used if you install kermac[cu12] or kermac[cu11] (depending on your CUDA version).
  • If you want to use TabR, you have to manually install faiss, which is only available on conda.
  • Please install torch separately if you want to control the version (CPU/GPU, etc.).
  • Use pytabkit[models,autogluon,extra,hpo,bench,dev] to install additional dependencies for the other models, AutoGluon models, extra preprocessing, hyperparameter optimization methods beyond random search (hyperopt/SMAC), the benchmarking part, and testing/documentation. For the hpo part, you might need to install swig (e.g. via pip) if the build of pyrfr fails. See also the documentation. To run the data download for the meta-train benchmark, you need one of rar, unrar, or 7-zip to be installed on the system.
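
Since several dependencies are optional, it can help to check which ones are present before picking a model. The following sketch uses only the standard library. The module names in OPTIONAL_MODULES are assumptions about the import names of the packages mentioned above, and available_optional_modules is a hypothetical helper.

```python
import importlib.util

# Assumed import names of the optional packages mentioned above.
OPTIONAL_MODULES = ['torch', 'faiss', 'kermac', 'autogluon', 'hyperopt']


def available_optional_modules(names=OPTIONAL_MODULES):
    """Map each module name to True if it can be found, without importing it."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == '__main__':
    for name, found in available_optional_modules().items():
        print(f"{name}: {'installed' if found else 'missing'}")
```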

Using the ML models

Most of our machine learning models are directly available via scikit-learn interfaces. For example, you can use RealMLP-TD for classification as follows:

from pytabkit import RealMLP_TD_Classifier

model = RealMLP_TD_Classifier()  # or TabR_S_D_Classifier, CatBoost_TD_Classifier, etc.
model.fit(X_train, y_train)
model.predict(X_test)

The code above will automatically select a GPU if available, try to detect categorical columns in dataframes, preprocess numerical variables and regression targets (no standardization required), and use a training-validation split for early stopping. All of this (and much more) can be configured through the constructor and the parameters of the fit() method. For example, bagging (ensembling the models from 5-fold cross-validation) only requires passing n_cv=5 to the constructor. Here is an example of some of the parameters that can be set explicitly:

from pytabkit import RealMLP_TD_Classifier

model = RealMLP_TD_Classifier(device='cpu', random_state=0, n_cv=1, n_refit=0,
                              n_epochs=256, batch_size=256, hidden_sizes=[256] * 3,
                              val_metric_name='cross_entropy',
                              use_ls=False,  # for metrics like AUC / log-loss
                              lr=0.04, verbosity=2)
model.fit(X_train, y_train, X_val, y_val, cat_col_names=['Education'])
model.predict_proba(X_test)

See this notebook for more examples. Missing numerical values are currently not allowed and need to be imputed beforehand.
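
One simple way to impute beforehand is to fill each numerical column with its training-set median. The sketch below uses pandas only. impute_numerical is a hypothetical helper and the toy data is made up, and median imputation is just one possible choice, not a recommendation from the paper.

```python
import numpy as np
import pandas as pd


def impute_numerical(X_train, X_test):
    """Fill missing numerical values with each column's training-set median.

    Non-numerical (e.g. categorical) columns are left unchanged, and the
    input dataframes are not modified.
    """
    num_cols = X_train.select_dtypes(include='number').columns
    medians = X_train[num_cols].median()  # NaN values are skipped by default
    X_train, X_test = X_train.copy(), X_test.copy()
    X_train[num_cols] = X_train[num_cols].fillna(medians)
    X_test[num_cols] = X_test[num_cols].fillna(medians)
    return X_train, X_test


# Toy data for illustration only.
X_train = pd.DataFrame({'Age': [25.0, np.nan, 40.0],
                        'Education': ['BSc', 'MSc', 'PhD']})
X_test = pd.DataFrame({'Age': [np.nan], 'Education': ['BSc']})
X_train_imp, X_test_imp = impute_numerical(X_train, X_test)
```

The imputed dataframes can then be passed to fit() and predict() as in the examples above. Computing the medians on the training data only avoids leaking test-set statistics into preprocessing.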

Available ML models

Our ML models are available in up to three variants, all with best-epoch selection:

  • library defaults (D)
  • our tuned defaults (TD)
  • random search hyperparameter optimization (HPO), sometimes also Tree-structured Parzen Estimator (HPO-TPE) or weighted ensembling (Ensemble)

We provide the following ML models:

Post-hoc calibration and refinement stopping

For using post-hoc temperature scaling and refinement stopping from our paper Rethinking Early Stopping: Refine, Then Calibrate, you can pass the following parameters to the scikit-learn interfaces:

from pytabkit import RealMLP_TD_Classifier
clf = RealMLP_TD_Classifier(
    val_metric_name='ref-ll-ts',  # short for 'refinement_logloss_ts-mix_all'
    calibration_method='ts-mix',  # temperature scaling with Laplace smoothing
    use_ls=False  # recommended for cross-entropy loss
)

Other calibration methods and validation metrics from probmetrics can be used as well.
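
For readers unfamiliar with temperature scaling, the following NumPy sketch illustrates the basic idea: divide the logits by a scalar T chosen to minimize the validation log-loss. It is not the 'ts-mix' implementation used by pytabkit or probmetrics, and it omits the Laplace smoothing. The grid search, the function names, and the toy logits are all assumptions made for illustration.

```python
import numpy as np


def softmax(logits):
    z = logits - logits.max(axis=1, keepdims=True)  # for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def log_loss(probs, y):
    # Clip to avoid log(0) for extremely confident predictions.
    p_true = np.clip(probs[np.arange(len(y)), y], 1e-12, 1.0)
    return float(-np.mean(np.log(p_true)))


def fit_temperature(val_logits, y_val, grid=np.logspace(-2, 2, 401)):
    """Pick the temperature on a grid that minimizes validation log-loss."""
    losses = [log_loss(softmax(val_logits / t), y_val) for t in grid]
    return float(grid[int(np.argmin(losses))])


# Overconfident toy logits: the third sample is confidently wrong.
val_logits = np.array([[8.0, 0.0], [6.0, 0.0], [7.0, 0.0], [0.0, 9.0]])
y_val = np.array([0, 0, 1, 1])
T = fit_temperature(val_logits, y_val)
calibrated = softmax(val_logits / T)
```

Because dividing by a positive T does not change the argmax, the predicted classes stay the same and only the confidence of the probabilities is adjusted.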

For reproducing the results from this paper, we refer to the documentation.

Benchmarking code

Our benchmarking code has functionality for

  • dataset download
  • running methods with a high degree of parallelism on single-node/multi-node/multi-GPU hardware, with automatic scheduling that tries to respect RAM constraints
  • analyzing/plotting results

For more details, we refer to the documentation.

Preprocessing code

While many preprocessing methods are implemented in this repository, a standalone version of our robust scaling + smooth clipping can be found here.
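
As a rough illustration of robust scaling followed by smooth clipping, here is a NumPy sketch. The formula (center by the median, scale by the inverse interquartile range with a fallback to the min-max range, then apply x / sqrt(1 + (x/3)^2)) follows our reading of the paper and should be treated as an assumption. The function names are hypothetical, and the standalone version remains the reference.

```python
import numpy as np


def fit_robust_scale(X):
    """Compute the per-column median and scale factor from training data."""
    q0, q25, q50, q75, q100 = np.quantile(X, [0.0, 0.25, 0.5, 0.75, 1.0], axis=0)
    iqr, full_range = q75 - q25, q100 - q0
    scale = np.zeros_like(q50)  # constant columns keep scale 0
    use_iqr = iqr > 0
    use_range = (~use_iqr) & (full_range > 0)  # fallback if the IQR is zero
    scale[use_iqr] = 1.0 / iqr[use_iqr]
    scale[use_range] = 2.0 / full_range[use_range]
    return q50, scale


def robust_scale_smooth_clip(X, median, scale):
    """Scale robustly, then smoothly clip values into the interval (-3, 3)."""
    z = scale * (X - median)
    return z / np.sqrt(1.0 + (z / 3.0) ** 2)


# Toy data: one column with an outlier, one constant column.
X = np.array([[0.0, 5.0], [1.0, 5.0], [2.0, 5.0], [3.0, 5.0], [1000.0, 5.0]])
median, scale = fit_robust_scale(X)
X_proc = robust_scale_smooth_clip(X, median, scale)
```

In this toy example, the outlier 1000 is mapped to a value just below 3 instead of dominating the feature range, and the constant column is mapped to zero.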

Citation

If you use this repository for research purposes, please cite our paper:

@inproceedings{holzmuller2024better,
  title={Better by default: {S}trong pre-tuned {MLPs} and boosted trees on tabular data},
  author={Holzm{\"u}ller, David and Grinsztajn, Leo and Steinwart, Ingo},
  booktitle = {Neural {Information} {Processing} {Systems}},
  year={2024}
}

Contributors

  • David Holzmüller (main developer)
  • Léo Grinsztajn (deep learning baselines, plotting)
  • Ingo Steinwart (UCI dataset download)
  • Katharina Strecker (PyTorch-Lightning interface)
  • Daniel Beaglehole (part of the xRFM implementation)
  • Lennart Purucker (some features/fixes)
  • Jérôme Dockès (deployment, continuous integration)

Acknowledgements

Code fro
