
Atlas

Code repository supporting the paper "Atlas: Few-shot Learning with Retrieval Augmented Language Models" (https://arxiv.org/abs/2208.03299)


Atlas: Few-shot Learning with Retrieval Augmented Language Models

REPO NO LONGER MAINTAINED, RESEARCH CODE PROVIDED AS IS

This repository contains pre-trained models, corpora, indices, and code for pre-training, finetuning, retrieval and evaluation for the paper Atlas: Few-shot Learning with Retrieval Augmented Language Models.

Read our Atlas blog post for a quick overview of the project and for instructions on running the code with torchrun (a Slurm-free option).

We jointly pretrain a retrieval-augmented seq2seq language model consisting of a passage-based dense retriever and an encoder-decoder language model. We evaluate on a wide range of tasks, including MMLU, KILT and NaturalQuestions, and study the impact of the content of the document index, showing that it can easily be updated. Notably, Atlas reaches over 45% accuracy on Natural Questions using only 64 examples when supplied with a Wikipedia index from 2018, outperforming a 540B-parameter model by 6% despite having 50x fewer parameters. Atlas also works very well when finetuned on larger datasets: when finetuned on the full Natural Questions data, Atlas sets a new state of the art of 64%, 8 points higher than the previous state of the art.

This repository supports pretraining and finetuning on both large and small datasets, with the following features:

  • Training large fusion-in-decoder seq2seq models, tested up to 11B parameters
  • Distilling relevance signals from fusion-in-decoder models into dense retrieval models using a variety of different distillation approaches.
  • Performing end-to-end retrieval-augmented training over a user-supplied corpus of passages (tested with up to 400M passages, ~40B words) with retrieval-in-the-training-loop
  • Support for training on masked language modelling, prefix language modelling, Wikipedia section generation, open-domain question answering, multiple-choice question answering, fact checking, and KILT (arbitrary seq2seq tasks can also be supported)
  • A fast, parallel distributed GPU-based exact and approximate maximum inner product search for dense vector retrieval
  • Support for fast in-place index refreshes
  • Various memory optimizations and methods for maintaining fast and accurate retrieval while training retrievers in-the-loop.
  • Plus more: see the command-line arguments or the README for additional features.
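
The following NumPy sketch makes the distillation idea concrete. It models a perplexity-distillation-style objective on the paper's description and is not taken from this codebase. A KL divergence pulls the retriever's softmax over its top-K passage scores toward a target distribution. That target is built from the reader's log-likelihood of the answer given each passage. The function names and the temperature value are illustrative assumptions. Atlas's own loss, which the example command later in this README selects with --gold_score_mode pdist, may differ in its details.

```python
import numpy as np

# Hypothetical helpers modelled on the paper's description, not Atlas's code.


def log_softmax(x, temperature=1.0):
    """Numerically stable log-softmax over the last axis."""
    z = np.asarray(x, dtype=np.float64) / temperature
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def perplexity_distillation_loss(retriever_scores, reader_log_likelihoods, temperature=0.1):
    """KL(target || retriever) over the top-K passages retrieved for one query.

    retriever_scores: inner products s(d_k, q), shape (K,)
    reader_log_likelihoods: log p_LM(answer | d_k, q) from the reader, shape (K,)
    temperature: retriever softmax temperature (illustrative value, an assumption)
    """
    log_p_retriever = log_softmax(retriever_scores, temperature)
    # Target: passages that make the answer more likely receive more probability mass.
    log_p_target = log_softmax(reader_log_likelihoods)
    p_target = np.exp(log_p_target)
    # In an autodiff framework the target would be treated as a constant (no gradient).
    return float(np.sum(p_target * (log_p_target - log_p_retriever)))


reader_ll = np.array([-0.5, -2.0, -4.0])
print(perplexity_distillation_loss([12.0, 11.5, 9.0], reader_ll))  # positive: distributions differ
print(perplexity_distillation_loss(0.1 * reader_ll, reader_ll))    # ~0: retriever already matches target
```

The loss is zero only when the retriever's distribution equals the target. In a real training loop only the retriever scores would receive gradients.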


Installation

The Atlas codebase uses the following dependencies:

  • python 3 (tested with 3.8)
  • fairscale (tested with 0.4.6)
  • transformers (tested with 4.18.0)
  • numpy (tested with 1.22.4)
  • faiss (tested with 1.7.2)

We recommend installing using conda. The following will install all dependencies:

git clone https://github.com/facebookresearch/atlas.git
cd atlas
conda create --name atlas-env python=3.8
conda activate atlas-env
conda install pytorch==1.11.0 cudatoolkit=11.3 -c pytorch
conda install -c pytorch faiss-gpu=1.7.2 cudatoolkit=11.3
pip install -r requirements.txt
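
After installation, a small script can compare your environment against the tested versions listed above. It uses the standard library's importlib.metadata, which is available from Python 3.8. The distribution names are assumptions. In particular, a faiss build installed through conda may not register package metadata, so it can be reported as missing even when import faiss works.

```python
from importlib import metadata

# "Tested with" versions from the dependency list above. The distribution names
# are assumptions; e.g. a conda-installed faiss may not expose package metadata.
TESTED_VERSIONS = {
    "fairscale": "0.4.6",
    "transformers": "4.18.0",
    "numpy": "1.22.4",
    "faiss-gpu": "1.7.2",
}


def check_versions(tested=TESTED_VERSIONS):
    """Map each package to (installed version or None, tested version)."""
    report = {}
    for package, tested_version in tested.items():
        try:
            installed = metadata.version(package)
        except metadata.PackageNotFoundError:
            installed = None
        report[package] = (installed, tested_version)
    return report


if __name__ == "__main__":
    for package, (installed, tested) in check_versions().items():
        if installed is None:
            status = "not found"
        elif installed == tested:
            status = "matches tested version"
        else:
            status = "differs from tested version"
        print(f"{package}: installed={installed}, tested={tested} ({status})")
```

A mismatch is not necessarily an error. It only means that combination was not among the versions the authors report testing.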

Getting Started and Codebase at a Glance

The Atlas repository provides functionality for training and evaluating retrieval-augmented generation models, which consist of an encoder-decoder language model and a dense-vector retriever.


We currently support T5 architectures for the encoder-decoder language model and Contriever architectures for the retriever (support for other architectures is not currently planned, but PRs are welcome). Atlas models consist of a Contriever retriever and a fusion-in-decoder (FID) reader, which uses T5. You can learn more about Contriever and FID from their original repositories if desired, but all required functionality has been reimplemented in this codebase.
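
For readers unfamiliar with fusion-in-decoder, the sketch below illustrates the core idea using a toy stand-in for the T5 encoder. Each (question, passage) pair is encoded independently. The resulting encoder outputs are then concatenated along the sequence axis, so the decoder can attend over all passages at once. The input template mirrors the style of the original FiD work. Like the toy_encode helper, it is a hypothetical illustration rather than Atlas's exact input format.

```python
import numpy as np


def build_fid_inputs(question, passages):
    """One encoder input per retrieved passage (hypothetical template)."""
    return [
        f"question: {question} title: {p['title']} context: {p['text']}"
        for p in passages
    ]


def toy_encode(text, dim=4):
    """Toy stand-in for a T5 encoder: one dim-sized vector per whitespace token."""
    return np.array([[float(len(tok))] * dim for tok in text.split()])


def fusion_in_decoder_states(question, passages):
    # Encode each (question, passage) pair separately: the encoder never sees two
    # passages together, so encoder cost grows linearly with the number of passages.
    encoded = [toy_encode(x) for x in build_fid_inputs(question, passages)]
    # Concatenate along the sequence axis; the decoder cross-attends over this
    # single long sequence, fusing evidence from all passages.
    return np.concatenate(encoded, axis=0)


passages = [
    {"title": "Paris", "text": "Paris is the capital of France."},
    {"title": "France", "text": "France is a country in Europe."},
]
states = fusion_in_decoder_states("what is the capital of france", passages)
print(states.shape)
```

Fusion happens only in the decoder, so adding passages scales the encoder work linearly rather than quadratically in the combined input length.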

The biggest difference from most standard NLP training codebases is that Atlas performs retrieval on the fly and can refresh its retrieval embedding index in place. This is achieved with a custom-designed distributed GPU index, which automatically handles fast and scalable retrieval.
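
One way to picture an in-place index refresh is to allocate the passage embedding matrix once. After the retriever is updated, passages are re-embedded in chunks and written back into the same buffer, so a second full copy of the index never has to be held in memory. This is a generic NumPy sketch under that assumption, not Atlas's implementation. The embed_fn parameter and toy_retriever are hypothetical stand-ins for the current retriever.

```python
import numpy as np

# Generic sketch (an assumption), not Atlas's implementation.


def refresh_index_in_place(embeddings, passages, embed_fn, batch_size=1024):
    """Re-embed every passage with the current retriever, overwriting `embeddings`.

    embeddings: preallocated array of shape (num_passages, dim), modified in place
    embed_fn: callable mapping a list of passages to an array of shape (len(batch), dim)
    """
    for start in range(0, len(passages), batch_size):
        batch = passages[start:start + batch_size]
        # Slice assignment writes into the existing buffer, so no second
        # full-size index is allocated (only one batch of outputs at a time).
        embeddings[start:start + len(batch)] = embed_fn(batch)
    return embeddings


def toy_retriever(batch):
    # Hypothetical stand-in for the retriever's passage encoder.
    return np.array([[len(p["text"]), len(p["title"])] for p in batch], dtype=np.float32)


passages = [
    {"title": "A", "text": "alpha"},
    {"title": "Bee", "text": "b"},
    {"title": "C", "text": "gamma ray"},
]
index = np.zeros((len(passages), 2), dtype=np.float32)
refresh_index_in_place(index, passages, toy_retriever, batch_size=2)
print(index)
```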

A note on how retrieval is accomplished: when launching a training or evaluation run, the codebase first loads the pretrained models. Each GPU worker then loads a shard of the supplied passages to retrieve from; with N GPUs, each worker loads 1/N of the passages. Each worker embeds its shard of the passages using the retriever embedder and keeps the resulting embedding shard in GPU memory (optionally building a FAISS index). At this point, the passage and embedding shards (referred to as "the index") can optionally be saved to disk to avoid recomputing them for every run. Retrieval is performed in parallel, with each GPU worker performing an exact maximum inner product search over its shard for all queries. More details on retrieval are given in the Retrieval and Index Details section. All of the above is handled automatically by the codebase, so users should not need to know or worry much about how embedding, index refresh or retrieval is accomplished, other than

  1. noting that they can easily retrieve from any set of passages they like by passing in paths to suitably formatted passages on disk (or any saved index);
  2. noting that embedding, index refresh and retrieval get faster with more GPU workers;
  3. noting that, depending on how many GPUs and how much CPU memory are available, Atlas can support training models with 11B+ parameters and indices of 400M+ vectors, or ~40 billion words (assuming ~100 words per passage).
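
The sharded exact search described above can be simulated on a single machine. In this NumPy sketch, each "worker" owns a contiguous 1/N slice of the passage embeddings. It computes exact inner products for all queries against its slice and returns its local top-k with global passage ids. The per-shard results are then merged into a global top-k. The merge step is our assumption about how shard results are combined. The shards also run sequentially here rather than on separate GPUs.

```python
import numpy as np


def search_shard(queries, shard, offset, k):
    """Exact top-k inner-product search over one worker's shard.

    Returns (scores, global_ids), each of shape (num_queries, min(k, shard_size)).
    """
    scores = queries @ shard.T  # (num_queries, shard_size) exact inner products
    k = min(k, shard.shape[0])
    local_top = np.argsort(-scores, axis=1)[:, :k]
    # Shift local row indices by the shard's offset to obtain global passage ids.
    return np.take_along_axis(scores, local_top, axis=1), local_top + offset


def sharded_exact_mips(queries, embeddings, num_workers, k):
    """Simulate N workers that each hold ~1/N of the passage embeddings."""
    shards = np.array_split(embeddings, num_workers)
    all_scores, all_ids = [], []
    offset = 0
    for shard in shards:  # sequential here; one GPU per shard in a real run
        scores, ids = search_shard(queries, shard, offset, k)
        all_scores.append(scores)
        all_ids.append(ids)
        offset += shard.shape[0]
    # Merge (assumed): gather every shard's candidates and keep the global top-k.
    scores = np.concatenate(all_scores, axis=1)
    ids = np.concatenate(all_ids, axis=1)
    best = np.argsort(-scores, axis=1)[:, :k]
    return np.take_along_axis(scores, best, axis=1), np.take_along_axis(ids, best, axis=1)


rng = np.random.default_rng(0)
passage_embeddings = rng.normal(size=(1000, 16))
query_embeddings = rng.normal(size=(3, 16))
top_scores, top_ids = sharded_exact_mips(query_embeddings, passage_embeddings, num_workers=4, k=5)
```

Each shard's exact top-k is guaranteed to contain every global top-k passage that lives on that shard. For that reason, merging the local top-k lists yields the same result as a brute-force search over all passages.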

Training and evaluation use a data-parallel model: with N GPU workers, each processes 1/N of the total mini-batch. To save memory at training time, optimizer state and gradients can be sharded using fairscale's ShardedDataParallel.
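
Splitting a mini-batch across N workers does not change the update, as this example shows. When each worker computes the mean-loss gradient on an equal-sized 1/N slice, averaging those gradients recovers the full-batch gradient. In a real multi-GPU run this average is an all-reduce. The linear-regression model is purely illustrative. The sketch does not model ShardedDataParallel's additional sharding of optimizer state and gradients.

```python
import numpy as np


def mse_grad(w, x, y):
    """Gradient of the mean squared error of an illustrative linear model x @ w."""
    return 2.0 * x.T @ (x @ w - y) / len(y)


def data_parallel_grad(w, x, y, num_workers):
    """Each worker takes 1/N of the mini-batch; local gradients are then averaged."""
    local_grads = [
        mse_grad(w, x_part, y_part)
        for x_part, y_part in zip(np.array_split(x, num_workers), np.array_split(y, num_workers))
    ]
    # In a multi-GPU run this average would be computed with an all-reduce.
    return np.mean(local_grads, axis=0)


rng = np.random.default_rng(0)
x, y, w = rng.normal(size=(32, 3)), rng.normal(size=32), rng.normal(size=3)
# With equal shard sizes (32 examples / 4 workers) this matches the full-batch gradient.
print(np.allclose(data_parallel_grad(w, x, y, num_workers=4), mse_grad(w, x, y)))
```

The equality relies on equal shard sizes. With uneven shards, a plain average of per-shard means would weight examples unequally.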

All data files (retrieval passages and train/dev/test data) should be supplied as JSON Lines ("jsonl") files. Passage files should contain one JSON-serialized object per line, each with text and title fields. Example passage files are available for Wikipedia (see corpora). Train/dev/test data files should likewise contain one JSON-serialized object per line, one instance per line. Field names are task-dependent (covered in detail in Tasks); for NaturalQuestions, for example, the required fields are question (a question string) and answers (a list of reference answer strings).
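
The sketch below writes and reads these files with Python's json module and checks each record for required fields. The field sets come from the description above: text and title for passages, and question and answers for NaturalQuestions. The helper names and the example records are illustrative assumptions.

```python
import json
import os
import tempfile

# Required fields from the description above; other tasks use other fields.
PASSAGE_FIELDS = {"title", "text"}
NQ_FIELDS = {"question", "answers"}


def write_jsonl(path, records):
    """Write one JSON-serialized object per line."""
    with open(path, "w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")


def read_jsonl(path, required_fields=frozenset()):
    """Read a jsonl file, raising ValueError if a record lacks a required field."""
    records = []
    with open(path, encoding="utf-8") as f:
        for line_number, line in enumerate(f, start=1):
            if not line.strip():
                continue  # skip blank lines
            record = json.loads(line)
            missing = set(required_fields) - set(record)
            if missing:
                raise ValueError(f"{path}:{line_number}: missing fields {sorted(missing)}")
            records.append(record)
    return records


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        passages_path = os.path.join(tmp, "passages.jsonl")
        train_path = os.path.join(tmp, "train.jsonl")
        # Illustrative records only.
        write_jsonl(passages_path, [{"title": "Paris", "text": "Paris is the capital of France."}])
        write_jsonl(train_path, [{"question": "what is the capital of france", "answers": ["Paris"]}])
        print(read_jsonl(passages_path, PASSAGE_FIELDS))
        print(read_jsonl(train_path, NQ_FIELDS))
```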

The codebase has two entrypoint scripts: train.py for training, and evaluate.py for test-time evaluation (and stand-alone retrieval, if you want). You can list the full Atlas functionality by printing the command-line flags with python train.py -h.

The easiest way to illustrate the codebase is with an example:

The following example shows a typical use case: few-shot finetuning and evaluation on NaturalQuestions with Atlas-large, retrieving from a 2018 Wikipedia dump of about 30M passages. The same workflow is also available as runnable sbatch scripts in example_scripts/nq/.

# assumes 4 nodes, each with 8 GPUs
DATA_DIR=./atlas_data
SIZE=large # let's use large (slower than base but still quite fast and accessible; less accurate than xl or xxl)

# download the NQ data
python preprocessing/prepare_qa.py --output_directory ${DATA_DIR}/data/
# download the Wikipedia 2018 corpus
python preprocessing/download_corpus.py --corpus corpora/wiki/enwiki-dec2018 --output_directory ${DATA_DIR} 
# downloads pretrained Atlas-large
python preprocessing/download_model.py --model models/atlas/${SIZE} --output_directory ${DATA_DIR}  

port=$(shuf -i 15000-16000 -n 1)
TRAIN_FILE="${DATA_DIR}/data/nq_data/train.64-shot.jsonl"
EVAL_FILES="${DATA_DIR}/data/nq_data/dev.jsonl"
SAVE_DIR=${DATA_DIR}/experiments/
EXPERIMENT_NAME=my-nq-64-shot-example
TRAIN_STEPS=30

srun python train.py \
    --shuffle \
    --train_retriever \
    --gold_score_mode pdist \ # loss function for r