# Learning to split
This is the official implementation for "Learning to Split for Automatic Bias Detection."
`ls` learns to automatically split any dataset such that models trained on the training set do not generalize to the testing set. The resulting splits are valuable because they help us debug the dataset. What are the minority groups? Are there any annotation errors? Analyzing these challenging splits can guide us towards a more robust model.
Consider the task of classifying Samoyed images vs. polar bear images. In this example, the learned splits help us identify the hidden bias: background color. While predictors can achieve perfect performance on the training split by exploiting the spurious heuristic that polar bears live in snowy habitats, they fail to generalize to the under-represented group (polar bears that appear on grass).
Key features:
- Automatically identifies challenging train/test splits for any `torch.utils.data.Dataset` object (see the sketch after this list).
- Supports all `torchvision` classification models with pre-trained weights.
- Supports custom models.
- Supports all optimizers and learning rate schedulers under `torch.optim`.
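As a rough sketch of the first point, handing a generic map-style dataset to `ls` could look like this (we assume here that `ls.learning_to_split` accepts any dataset whose items are `(input, label)` pairs; the toy data below is made up):

```python
import torch
from torch.utils.data import TensorDataset

import ls

# A toy dataset: 1,000 examples with 32 features and binary labels.
features = torch.randn(1000, 32)
labels = torch.randint(0, 2, (1000,))
data = TensorDataset(features, labels)

# Ask ls for a challenging train/test split of this dataset.
train_data, test_data = ls.learning_to_split(data, model={'name': 'mlp'})
```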
## Installation
```bash
# It is always a good practice to use a new conda/virtual env :)
conda create --name ls python=3.9
conda activate ls

pip install git+https://github.com/YujiaBao/ls.git#egg=ls
```
`ls` is tested on Python 3.9 with PyTorch 1.12.1.
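A quick import verifies the installation:

```bash
python -c "import ls; print(ls.__name__)"
```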
## Quickstart
You can directly use the `ls.learning_to_split()` interface to generate challenging splits for any PyTorch dataset object. Here is a quick example using the Tox21 dataset:
```python
>>> import ls

# Load the Tox21 dataset.
>>> data = ls.datasets.Tox21()

# Learning to split the Tox21 dataset.
# Here we use a simple mlp as our model backbone and roc_auc as the evaluation metric.
>>> train_data, test_data = ls.learning_to_split(data, model={'name': 'mlp'}, metric='roc_auc')

Best split:
ls outer loop 9 @ 23:51:42 2022/10/17
|  generalization gap 64.31 (val 98.97, test 34.65)
|  train count 72.7% (7440)
|  test count 27.3% (2800)
|  train label dist {0: 7218, 1: 222}
|  test label dist {0: 2627, 1: 173}
```
By default, `learning_to_split` prints the split status at each outer-loop iteration (see `tox21.log` for the full log). In this example, `ls` converged after 9 iterations and identified a very challenging train/test split (generalization gap = 64.31%).
In some cases, you may want to access the indices of the training/testing examples or examine the learned dataset splitter. You can tailor the output to your needs through the `return_order` argument:
```python
train_data, test_data, train_indices, test_indices, splitter = ls.learning_to_split(
    data, model={'name': 'mlp'},
    return_order=['train_data', 'test_data', 'train_indices', 'test_indices', 'splitter']
)

# splitter:                    the learned splitter (torch.nn.Module).
# train_data, test_data:       the training and testing datasets (torch.utils.data.Dataset).
# train_indices, test_indices: the indices of the training/testing examples in the original dataset (list[int]).
```
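For instance, the returned indices let you inspect the learned split directly. A minimal sketch, assuming each item of `data` is an `(input, label)` pair:

```python
from collections import Counter

# Compare label distributions across the learned train/test splits.
train_labels = Counter(int(data[i][1]) for i in train_indices)
test_labels = Counter(int(data[i][1]) for i in test_indices)
print('train label dist:', train_labels)
print('test label dist: ', test_labels)
```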
## Models
### Built-in models
```python
# List all built-in models.
>>> ls.models.builtin_models
['mnistcnn', 'mlp', 'bert', 'textcnn', 'alexnet', 'convnext_tiny', 'convnext_small', 'convnext_base', 'convnext_large', 'densenet121', 'densenet161', 'densenet169', 'densenet201', 'efficientnet_v2_s', 'efficientnet_v2_m', 'efficientnet_v2_l', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3', 'efficientnet_b4', 'efficientnet_b5', 'efficientnet_b6', 'efficientnet_b7', 'googlenet', 'inception_v3', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3', 'mobilenet_v2', 'mobilenet_v3_large', 'mobilenet_v3_small', 'regnet_y_400mf', 'regnet_y_800mf', 'regnet_y_1_6gf', 'regnet_y_3_2gf', 'regnet_y_8gf', 'regnet_y_16gf', 'regnet_y_32gf', 'regnet_y_128gf', 'regnet_x_400mf', 'regnet_x_800mf', 'regnet_x_1_6gf', 'regnet_x_3_2gf', 'regnet_x_8gf', 'regnet_x_16gf', 'regnet_x_32gf', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'resnext101_64x4d', 'wide_resnet50_2', 'wide_resnet101_2', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0', 'squeezenet1_0', 'squeezenet1_1', 'swin_t', 'swin_s', 'swin_b', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn', 'vit_b_16', 'vit_b_32', 'vit_l_16', 'vit_l_32', 'vit_h_14']
```
`ls` supports the following models (a usage example follows the list):

- `mlp`: a simple multi-layer perceptron for inputs with fixed-dimensional features.
- `textcnn`: a one-dimensional conv net for text classification with pre-trained word embeddings.
- `bert`: BERT for sequence classification.
- `mnistcnn`: a two-layer conv net for MNIST digit classification.
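These backbones are selected by name, exactly as in the quickstart. For example, to run `ls` with BERT on a sequence classification dataset:

```python
train_data, test_data = ls.learning_to_split(data, model={'name': 'bert'})
```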
Thanks to `torchvision`, `ls` also supports the following classification models with pre-trained weights:
- AlexNet: `alexnet`
- ConvNeXt: `convnext_tiny`, `convnext_small`, `convnext_base`, `convnext_large`
- DenseNet: `densenet121`, `densenet161`, `densenet169`, `densenet201`
- EfficientNet: `efficientnet_b0`, `efficientnet_b1`, `efficientnet_b2`, `efficientnet_b3`, `efficientnet_b4`, `efficientnet_b5`, `efficientnet_b6`, `efficientnet_b7`
- EfficientNet V2: `efficientnet_v2_s`, `efficientnet_v2_m`, `efficientnet_v2_l`
- GoogLeNet: `googlenet`
- Inception V3: `inception_v3`
- MNASNet: `mnasnet0_5`, `mnasnet0_75`, `mnasnet1_0`, `mnasnet1_3`
- MobileNet V2: `mobilenet_v2`
- MobileNet V3: `mobilenet_v3_small`, `mobilenet_v3_large`
- RegNet: `regnet_y_400mf`, `regnet_y_800mf`, `regnet_y_1_6gf`, `regnet_y_3_2gf`, `regnet_y_8gf`, `regnet_y_16gf`, `regnet_y_32gf`, `regnet_y_128gf`, `regnet_x_400mf`, `regnet_x_800mf`, `regnet_x_1_6gf`, `regnet_x_3_2gf`, `regnet_x_8gf`, `regnet_x_16gf`, `regnet_x_32gf`
- ResNet: `resnet18`, `resnet34`, `resnet50`, `resnet101`, `resnet152`
- ResNeXt: `resnext50_32x4d`, `resnext101_32x8d`, `resnext101_64x4d`
- ShuffleNet V2: `shufflenet_v2_x0_5`, `shufflenet_v2_x1_0`, `shufflenet_v2_x1_5`, `shufflenet_v2_x2_0`
- SqueezeNet: `squeezenet1_0`, `squeezenet1_1`
- SwinTransformer: `swin_t`, `swin_s`, `swin_b`
- VGG: `vgg11`, `vgg11_bn`, `vgg13`, `vgg13_bn`, `vgg16`, `vgg16_bn`, `vgg19`, `vgg19_bn`
- VisionTransformer: `vit_b_16`, `vit_b_32`, `vit_l_16`, `vit_l_32`, `vit_h_14`
- Wide ResNet: `wide_resnet50_2`, `wide_resnet101_2`
For example, to use EfficientNet V2-L, run:

```python
train_data, test_data = ls.learning_to_split(data, model={'name': 'efficientnet_v2_l'})
```
### Custom models
You can also plug custom models into `ls.learning_to_split()`. In the following example, we register a new model `custom` into the model factory. Check out `customize.py` for a more detailed example.
```python
import torch.nn as nn

from ls.models.build import ModelFactory

@ModelFactory.register('custom')
class custom(nn.Module):
    '''
    Define a fancy custom model class here.
    '''

# kwargs will be passed to the initializer of the custom class.
train_data, test_data = ls.learning_to_split(data, model={'name': 'custom', 'args': kwargs})
```
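As a fuller sketch, here is what a registered model might look like end to end. The architecture, registered name, and constructor arguments below are hypothetical; we assume registered models only need the standard `nn.Module` forward interface that returns class logits:

```python
import torch.nn as nn

import ls
from ls.models.build import ModelFactory

@ModelFactory.register('two_layer')  # hypothetical model name
class TwoLayerNet(nn.Module):
    def __init__(self, input_dim, hidden_dim=64, num_classes=2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x):
        # Return class logits for each example in the batch.
        return self.net(x)

# model['args'] is forwarded to TwoLayerNet.__init__ (dimensions here are made up).
train_data, test_data = ls.learning_to_split(
    data, model={'name': 'two_layer', 'args': {'input_dim': 801, 'hidden_dim': 64}})
```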
## Optimizers and learning rate schedulers
Why do we need to specify these? The learning-to-split algorithm has two components: a splitter and a predictor. At a high level, the splitter partitions the original dataset into a training set and a testing set at each iteration; the predictor then estimates the generalization gap of the current split, providing a learning signal for the splitter. Depending on the dataset and the model configuration, the optimization settings for the splitter and the predictor often differ.
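In pseudocode, one pass of this loop looks roughly like the following. This is an illustrative sketch, not the actual `ls` internals, and every helper name is hypothetical:

```python
# Illustrative pseudocode of the learning-to-split outer loop.
# All helpers (propose_split, train_predictor, score, update_splitter)
# are hypothetical placeholders.
for outer_iter in range(max_outer_iters):
    # The splitter proposes a train/test partition of the dataset.
    train_idx, test_idx = propose_split(splitter, data)

    # A fresh predictor is trained on the proposed training split.
    predictor = train_predictor(data, train_idx)

    # Its generalization gap estimates how challenging the split is ...
    gap = score(predictor, data, train_idx) - score(predictor, data, test_idx)

    # ... and serves as the learning signal for updating the splitter,
    # pushing it toward splits with larger gaps.
    update_splitter(splitter, gap)
```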
`ls.learning_to_split()` supports all optimization algorithms under `torch.optim` and all learning rate schedulers under `torch.optim.lr_scheduler`. The following example defines an optimizer and a learning rate scheduler for the algorithm:
```python
# Define the optimizer.
# optim['args'] specifies the keyword arguments for torch.optim.Adam().
optim = {
    'name': 'Adam',
    'args': {
        'lr': 1e-3,
        'weight_decay': 1e-4,
    },
}

# Define the learning rate scheduler.
# lr_scheduler['args'] specifies the keyword arguments for torch.optim.lr_scheduler.StepLR().
lr_scheduler = {
    'name': 'StepLR',
    'args': {
        'step_size': 30,
        'gamma': 0.1,
    },
}

# Run learning_to_split with the specified model, optimizer, and learning rate scheduler.
train_data, test_data = ls.learning_to_split(data, model={'name': 'mlp'}, optim=optim, lr_scheduler=lr_scheduler)
```
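For reference, `'name'` selects a class under `torch.optim` (or `torch.optim.lr_scheduler`) and `'args'` is passed through as keyword arguments, so the configs above correspond to the following plain PyTorch objects:

```python
import torch

# Plain-PyTorch equivalents of the optim / lr_scheduler configs above
# ('model' stands in for whichever module is being optimized).
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
```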
## Datasets
### Built-in datasets
`ls` contains the following built-in datasets:
- [BeerReviews](./ls/datasets/BeerRev