DeCLUTR: Deep Contrastive Learning for Unsupervised Textual Representations
The corresponding code for our paper: DeCLUTR: Deep Contrastive Learning for Unsupervised Textual Representations. Results on SentEval are presented below (as averaged scores on the downstream and probing task test sets), along with existing state-of-the-art methods.
| Model | Requires labelled data? | Parameters | Embed. dim. | Downstream (-SNLI) | Probing | Δ |
|--------------------------------------------------------|:---:|:----:|:----:|:-----:|:-----:|:-----:|
| InferSent V2                                           | Yes | 38M  | 4096 | 76.00 | 72.58 | -3.10 |
| Universal Sentence Encoder                             | Yes | 147M | 512  | 78.89 | 66.70 | -0.21 |
| Sentence Transformers ("roberta-base-nli-mean-tokens") | Yes | 125M | 768  | 77.19 | 63.22 | -1.91 |
| Transformer-small (DistilRoBERTa-base)                 | No  | 82M  | 768  | 72.58 | 74.57 | -6.52 |
| Transformer-base (RoBERTa-base)                        | No  | 125M | 768  | 72.70 | 74.19 | -6.40 |
| DeCLUTR-small (DistilRoBERTa-base)                     | No  | 82M  | 768  | 77.50 | 74.71 | -1.60 |
| DeCLUTR-base (RoBERTa-base)                            | No  | 125M | 768  | 79.10 | 74.65 | --    |
Transformer-* is the same underlying architecture and pretrained weights as DeCLUTR-* before continued pretraining with our contrastive objective. Both Transformer-* and DeCLUTR-* use mean pooling on their token-level embeddings to produce a fixed-length sentence representation. Downstream scores are computed without considering performance on SNLI (denoted "Downstream (-SNLI)") because InferSent, USE, and Sentence Transformers all train on SNLI. Δ: difference to DeCLUTR-base downstream score.
Notebooks
The easiest way to get started is to follow along with one of our notebooks:
- Training your own model
- Embedding text with a pretrained model
- Evaluating a model with SentEval
Installation
This repository requires Python 3.6.1 or later.
Setting up a virtual environment
Before installing, you should create and activate a Python virtual environment. See here for detailed instructions.
Installing the library and dependencies
If you don't plan on modifying the source code, install from git using pip
pip install git+https://github.com/JohnGiorgi/DeCLUTR.git
Otherwise, clone the repository locally and then install
git clone https://github.com/JohnGiorgi/DeCLUTR.git
cd DeCLUTR
pip install --editable .
Gotchas
- If you plan on training your own model, you should also install PyTorch with CUDA support by following the instructions for your system here.
Usage
Preparing a dataset
A dataset is simply a file containing one item of text (a document, a scientific paper, etc.) per line. For demonstration purposes, we have provided a script that will download the WikiText-103 dataset and apply our minimal preprocessing
python scripts/preprocess_wikitext_103.py path/to/output/wikitext-103/train.txt --min-length 2048
See scripts/preprocess_openwebtext.py for a script that can be used to recreate the (much larger) dataset used in our paper.
You can specify the train set path in the configs under "train_data_path".
Gotchas
- A training dataset should contain documents with a minimum of `num_anchors * max_span_len * 2` whitespace tokens. This is required to sample spans according to our sampling procedure. See the dataset reader and/or our paper for more details on these hyperparameters.
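To illustrate the rule above, here is a minimal sketch that filters a one-document-per-line corpus down to documents long enough for span sampling. The hyperparameter values are placeholders, not defaults from our configs; use the values from your own config.

```python
# Placeholder hyperparameters; substitute the values from your config.
num_anchors = 2
max_span_len = 512
min_tokens = num_anchors * max_span_len * 2  # minimum whitespace tokens per document


def long_enough(document: str, min_tokens: int = min_tokens) -> bool:
    """Return True if the document contains at least min_tokens whitespace tokens."""
    return len(document.split()) >= min_tokens


documents = ["this document is too short", " ".join(["token"] * 2048)]
# Keep only documents that satisfy the minimum-length requirement.
filtered = [doc for doc in documents if long_enough(doc)]
```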
Training
To train the model, use the allennlp train command with our declutr.jsonnet config. For example, to train DeCLUTR-small, run the following
# This can be (almost) any model from https://huggingface.co/ that supports masked language modelling.
TRANSFORMER_MODEL="distilroberta-base"
allennlp train "training_config/declutr.jsonnet" \
--serialization-dir "output" \
--overrides "{'train_data_path': 'path/to/your/dataset/train.txt'}" \
--include-package "declutr"
The --overrides flag allows you to override any field in the config with a JSON-formatted string, but you can equivalently update the config itself if you prefer. During training, models, vocabulary, configuration, and log files will be saved to the directory provided by --serialization-dir. This can be changed to any directory you like.
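If you build the overrides string programmatically, serializing a dict with `json.dumps` keeps the quoting straight. A small sketch (the path is a placeholder):

```python
import json

# Build a JSON-formatted string suitable for `allennlp train --overrides`.
overrides = json.dumps({"train_data_path": "path/to/your/dataset/train.txt"})
print(overrides)  # {"train_data_path": "path/to/your/dataset/train.txt"}
```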
Gotchas
- There was a small bug in the original implementation that caused gradients derived from the contrastive loss to be scaled by 1/N, where N is the number of GPUs used during training. This has been fixed. To reproduce results from the paper, set `model.scale_fix` to `False` in your config. Note that this has no effect unless you are using distributed training with more than one GPU.
Exporting a trained model to HuggingFace Transformers
We have provided a simple script to export a trained model so that it can be loaded with Hugging Face Transformers
wget -nc https://raw.githubusercontent.com/JohnGiorgi/DeCLUTR/master/scripts/save_pretrained_hf.py
python save_pretrained_hf.py --archive-file "output" --save-directory "output_transformers"
The model, saved to --save-directory, can then be loaded using the Hugging Face Transformers library (see Embedding for more details)
from transformers import AutoModel, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("output_transformers")
model = AutoModel.from_pretrained("output_transformers")
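As noted above, DeCLUTR produces a fixed-length representation by mean pooling the token-level embeddings, ignoring padding positions. A minimal NumPy sketch of that pooling step (the shapes and values below are illustrative dummies, not output from the library):

```python
import numpy as np


def masked_mean_pool(token_embeddings: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Average token embeddings over non-padding positions.

    token_embeddings: (batch, seq_len, hidden)
    attention_mask:   (batch, seq_len), 1 for real tokens, 0 for padding
    """
    mask = attention_mask[..., np.newaxis]  # (batch, seq_len, 1)
    summed = (token_embeddings * mask).sum(axis=1)
    counts = mask.sum(axis=1).clip(min=1)  # avoid division by zero on all-padding rows
    return summed / counts


# Toy example: one sequence of two tokens, where the second token is padding.
embeddings = np.array([[[1.0, 3.0], [5.0, 7.0]]])
mask = np.array([[1, 0]])
# masked_mean_pool(embeddings, mask) -> [[1.0, 3.0]]
```

In practice you would feed `token_embeddings` from the model's last hidden state and `attention_mask` from the tokenizer output into the same pooling function.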
If you would like to upload your model to the Hugging Face model repository, follow the instructions here.
Multi-GPU training
To train on more than one GPU, provide a list of CUDA devices in your call to allennlp train. For example, to train with four CUDA devices with IDs 0, 1, 2, 3
--overrides "{'distributed.cuda_devices': [0, 1, 2, 3]}"
Training with mixed-precision
If your GPU supports it, mixed-precision will be used automatically during training and inference.
Embedding
You can embed text with a trained model in one of four ways:
- Sentence Transformers: load our pretrained models with the SentenceTransformers library (recommended).
- Hugging Face Transformers: load our pretrained models with the Hugging Face Transformers library.
- From this repo: import and initialize an object from this repo which can be used to embed sentences/paragraphs.
- Bulk embed: embed all text in a given text file with a simple command-line interface.
The following pre-trained models are available:
SentenceTransformers
Ou
