
Learning to split

This is the official implementation for "Learning to Split for Automatic Bias Detection."

ls can learn to automatically split any dataset such that models trained on the training set fail to generalize to the testing set. The resulting splits are valuable because they help us debug the dataset: What are the minority groups? Are there any annotation errors? Analyzing these challenging splits can guide us toward a more robust model.

<p align="center"> <img src="assets/intro.png" width=90% /> </p>

Consider the task of classifying Samoyed images vs. polar bear images. In this example, the learned splits help us identify the hidden bias: background color. While predictors can achieve perfect performance on the training split by exploiting the spurious heuristic that polar bears live in snowy habitats, they fail to generalize to the under-represented group (polar bears that appear on grass).

Key features:

  • Automatically identifies challenging train/test splits for any torch.utils.data.Dataset object.
  • Supports all torchvision classification models with pre-trained weights.
  • Supports custom models.
  • Supports all optimizers and learning rate schedulers under torch.optim.

Installation

# It is always a good practice to use a new conda/virtual env :)
conda create --name ls python=3.9
conda activate ls
pip install git+https://github.com/YujiaBao/ls.git#egg=ls

ls is tested on Python 3.9 with PyTorch 1.12.1.

Quickstart

You can directly use the ls.learning_to_split() interface to generate challenging splits on any PyTorch dataset object. Here is a quick example using the Tox21 dataset:

>>> import ls

# Load the Tox21 dataset.
>>> data = ls.datasets.Tox21()

# Learning to split the Tox21 dataset.
# Here we use a simple mlp as our model backbone and use roc_auc as the evaluation metric.
>>> train_data, test_data = ls.learning_to_split(data, model={'name': 'mlp'}, metric='roc_auc')

Best split:
ls outer loop 9 @ 23:51:42 2022/10/17
| generalization gap 64.31 (val 98.97, test 34.65)
| train count 72.7% (7440)
| test  count 27.3% (2800)
| train label dist {0: 7218, 1: 222}
| test  label dist {0: 2627, 1: 173}

By default, learning_to_split will output the split status for each outer loop iteration (see tox21.log for the full log). In this example, we see that ls converged after 9 iterations. It identified a very challenging train/test split (generalization gap = 64.31%).

In some cases, one may want to access the indices of the training/testing data or examine the learned dataset splitter. Users can tailor the output to their own needs through the return_order argument.

train_data, test_data, train_indices, test_indices, splitter = ls.learning_to_split(
    data, model={'name': 'mlp'},
    return_order=['train_data', 'test_data', 'train_indices', 'test_indices', 'splitter']
)
# splitter:                    The learned splitter (torch.nn.Module)
# train_data, test_data:       The training and testing dataset (torch.utils.data.Dataset).
# train_indices, test_indices: The indices of the training/testing examples in the original dataset (list[int])
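The returned indices make it easy to inspect the splits yourself, e.g. to reproduce the per-split label distributions shown in the log above. A minimal sketch with a hypothetical label list standing in for a real dataset's labels (the helper name label_dist is illustrative, not part of the ls API):

```python
from collections import Counter

# Hypothetical labels and split indices, standing in for a real dataset's
# labels and the indices returned by ls.learning_to_split().
labels = [0, 0, 0, 1, 0, 1, 0, 0, 1, 0]
train_indices = [0, 1, 2, 3, 4, 5, 6]
test_indices = [7, 8, 9]

def label_dist(labels, indices):
    """Count how often each label occurs among the selected indices."""
    return dict(Counter(labels[i] for i in indices))

print('train label dist', label_dist(labels, train_indices))  # {0: 5, 1: 2}
print('test  label dist', label_dist(labels, test_indices))   # {0: 2, 1: 1}
```

A skewed label distribution across the two splits, as in the Tox21 log above, is often the first hint of where the hidden bias lies.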

Models

Built-in models

# List all built-in models.
>>> ls.models.builtin_models
['mnistcnn', 'mlp', 'bert', 'textcnn', 'alexnet', 'convnext_tiny', 'convnext_small', 'convnext_base', 'convnext_large', 'densenet121', 'densenet161', 'densenet169', 'densenet201', 'efficientnet_v2_s', 'efficientnet_v2_m', 'efficientnet_v2_l', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3', 'efficientnet_b4', 'efficientnet_b5', 'efficientnet_b6', 'efficientnet_b7', 'googlenet', 'inception_v3', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3', 'mobilenet_v2', 'mobilenet_v3_large', 'mobilenet_v3_small', 'regnet_y_400mf', 'regnet_y_800mf', 'regnet_y_1_6gf', 'regnet_y_3_2gf', 'regnet_y_8gf', 'regnet_y_16gf', 'regnet_y_32gf', 'regnet_y_128gf', 'regnet_x_400mf', 'regnet_x_800mf', 'regnet_x_1_6gf', 'regnet_x_3_2gf', 'regnet_x_8gf', 'regnet_x_16gf', 'regnet_x_32gf', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'resnext101_64x4d', 'wide_resnet50_2', 'wide_resnet101_2', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0', 'squeezenet1_0', 'squeezenet1_1', 'swin_t', 'swin_s', 'swin_b', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn', 'vit_b_16', 'vit_b_32', 'vit_l_16', 'vit_l_32', 'vit_h_14']

ls supports the following models:

  • mlp: a simple multi-layer perceptron for inputs with fixed-dimensional features.
  • textcnn: a one-dimensional conv net for text classification with pre-trained word embeddings.
  • bert: BERT for sequence classification.
  • mnistcnn: a two-layer conv net for MNIST digit classification.

Thanks to torchvision, ls also supports the following classification models with pre-trained weights:

  • AlexNet: alexnet
  • ConvNeXt: convnext_tiny, convnext_small, convnext_base, convnext_large
  • DenseNet: densenet121, densenet161, densenet169, densenet201
  • EfficientNet: efficientnet_b0, efficientnet_b1, efficientnet_b2, efficientnet_b3, efficientnet_b4, efficientnet_b5, efficientnet_b6, efficientnet_b7
  • EfficientNet V2: efficientnet_v2_s, efficientnet_v2_m, efficientnet_v2_l
  • GoogLeNet: googlenet
  • Inception V3: inception_v3
  • MNASNet: mnasnet0_5, mnasnet0_75, mnasnet1_0, mnasnet1_3
  • MobileNet V2: mobilenet_v2
  • MobileNet V3: mobilenet_v3_small, mobilenet_v3_large
  • RegNet: regnet_y_400mf, regnet_y_800mf, regnet_y_1_6gf, regnet_y_3_2gf, regnet_y_8gf, regnet_y_16gf, regnet_y_32gf, regnet_y_128gf, regnet_x_400mf, regnet_x_800mf, regnet_x_1_6gf, regnet_x_3_2gf, regnet_x_8gf, regnet_x_16gf, regnet_x_32gf
  • ResNet: resnet18, resnet34, resnet50, resnet101, resnet152
  • ResNeXt: resnext50_32x4d, resnext101_32x8d, resnext101_64x4d
  • ShuffleNet V2: shufflenet_v2_x0_5, shufflenet_v2_x1_0, shufflenet_v2_x1_5, shufflenet_v2_x2_0
  • SqueezeNet: squeezenet1_0, squeezenet1_1
  • SwinTransformer: swin_t, swin_s, swin_b
  • VGG: vgg11, vgg11_bn, vgg13, vgg13_bn, vgg16, vgg16_bn, vgg19, vgg19_bn
  • VisionTransformer: vit_b_16, vit_b_32, vit_l_16, vit_l_32, vit_h_14
  • Wide ResNet: wide_resnet50_2, wide_resnet101_2

For example, to use EfficientNetV2-L, one can run:

train_data, test_data = ls.learning_to_split(data, model={'name': 'efficientnet_v2_l'})

Custom models

Users can also plug their own models into ls.learning_to_split(). In the following example, we register a new model custom into the model factory. Check out customize.py for a more detailed example.

import torch.nn as nn

from ls.models.build import ModelFactory

@ModelFactory.register('custom')
class custom(nn.Module):
    '''
       Define a fancy custom class here.
    '''

# kwargs will be passed to the initializer of the custom class.
train_data, test_data = ls.learning_to_split(data, model={'name': 'custom', 'args': kwargs})

Optimizers and learning rate schedulers

Why do we need to specify this? The learning-to-split algorithm has two components: a splitter and a predictor. At a high level, the splitter partitions the original dataset into a training set and a testing set at each iteration. The predictor then estimates the generalization gap of the current split, providing a learning signal for the splitter. Depending on the dataset and the model configuration, the optimization settings for the splitter and the predictor are often different.
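The splitter/predictor alternation described above can be sketched in a few lines. This is a toy stand-in, not the actual ls implementation: the "splitter" here proposes random partitions instead of learning them, and the "predictor" is a simple majority-vote lookup, but the loop structure (propose a split, fit a predictor, score the split by its generalization gap) is the same:

```python
import random

random.seed(0)

# Toy dataset: each example is (feature, label); the feature is a noisy
# copy of the label, so a simple predictor can exploit it.
ys = [random.randint(0, 1) for _ in range(200)]
data = [(y if random.random() < 0.8 else 1 - y, y) for y in ys]

def train_predictor(train):
    """Toy predictor: remember the majority label for each feature value."""
    votes = {}
    for x, y in train:
        votes.setdefault(x, []).append(y)
    return {x: max(set(labels), key=labels.count) for x, labels in votes.items()}

def accuracy(predictor, split):
    return sum(predictor.get(x, 0) == y for x, y in split) / len(split)

best_gap, best_split = -1.0, None
for outer_loop in range(10):
    # Toy splitter: a random 70/30 partition. The real splitter *learns*
    # the assignment from the predictor's feedback.
    assign = [random.random() < 0.7 for _ in data]
    train = [ex for ex, a in zip(data, assign) if a]
    test = [ex for ex, a in zip(data, assign) if not a]
    if not train or not test:
        continue
    predictor = train_predictor(train)
    # The generalization gap of this split is the learning signal.
    gap = accuracy(predictor, train) - accuracy(predictor, test)
    if gap > best_gap:
        best_gap, best_split = gap, (train, test)

print(f'best generalization gap: {best_gap:.2f}')
```

In ls, the random assignment above is replaced by a trained splitter network, and the predictor is the model you pass in via the model argument; both therefore need their own optimizer settings.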

ls.learning_to_split() supports all optimization algorithms under torch.optim and all learning rate schedulers under torch.optim.lr_scheduler. In the following example, we will define an optimizer and a learning rate scheduler for the algorithm.

# Define the optimizer.
# optim['args'] specifies the keyword arguments for torch.optim.Adam()
optim = {
    'name': 'Adam',
    'args': {
        'lr': 1e-3,
        'weight_decay': 1e-4
    }
}

# Define the learning rate scheduler.
# lr_scheduler['args'] specifies the keyword arguments for torch.optim.lr_scheduler.StepLR()
lr_scheduler = {
    'name': 'StepLR',
    'args': {
        'step_size': 30,
        'gamma': 0.1
    }
}

# Run learning_to_split with the specified model, optimizer, and learning rate scheduler.
train_data, test_data = ls.learning_to_split(data, model={'name': 'mlp'}, optim=optim, lr_scheduler=lr_scheduler)

Datasets

Built-in datasets

ls contains the following built-in datasets:

  • [BeerReviews](./ls/datasets/BeerRev
