RL4LMs
A modular RL library to fine-tune language models to human preferences
We provide easily customizable building blocks for training language models, including implementations of on-policy algorithms, reward functions, metrics, datasets and LM-based actor-critic policies.
Paper Link: https://arxiv.org/abs/2210.01241
Website Link: https://rl4lms.apps.allenai.org/
Thoroughly tested and benchmarked with over 2000 experiments :fire: (GRUE benchmark :trophy:) on a comprehensive set of:
- 7 different Natural Language Processing (NLP) tasks:
  - Summarization
  - Generative Commonsense Reasoning
  - IMDB Sentiment-based Text Continuation
  - Table-to-text Generation
  - Abstractive Question Answering
  - Machine Translation
  - Dialogue Generation
- 20+ NLG metrics of different types, which can be used as reward functions:
  - Lexical metrics (e.g. ROUGE, BLEU, SacreBLEU, METEOR)
  - Semantic metrics (e.g. BERTScore, BLEURT)
  - Task-specific metrics (e.g. PARENT, CIDEr, SPICE)
  - Scores from pre-trained classifiers (e.g. sentiment scores)
- On-policy algorithms: PPO, A2C, TRPO and the novel NLPO (Natural Language Policy Optimization)
- Actor-critic policies supporting causal LMs (e.g. GPT-2/3) and seq2seq LMs (e.g. T5, BART)
All of these building blocks are fully customizable, allowing users to train transformer-based LMs to optimize any arbitrary reward function on any dataset of their choice.
Recent updates (v0.2.0) on 23-Nov-22
- Added daily dialog task
- Fixed compatibility issues with some seq2seq models such as BART and Blenderbot
- Implemented data parallel support
- Refactored policy classes
Recent updates (v0.2.1)
- Minor logging updates
Install
Local Installation
```bash
git clone https://github.com/allenai/RL4LMs.git
cd RL4LMs
pip install -e .
```
Docker
We also provide a Dockerfile for development using Docker containers, which includes all the dependencies.

```bash
docker build . -t rl4lms
```
Additional dependencies
Optionally, CoreNLP libraries are required for certain metric computations (e.g. SPICE). They can be downloaded by running:

```bash
cd rl4lms/envs/text_generation/caption_metrics/spice && bash get_stanford_models.sh
```
Quick Start - Train PPO/NLPO using pre-defined YAML configs
We provide a simple training API that can be invoked via a training script, which allows you to train PPO, NLPO or a supervised model using a YAML config file.
For example, to train T5-base on CNN/DM summarization with PPO using ROUGE-1 as the reward function, you can run:

```bash
python scripts/training/train_text_generation.py --config_path scripts/training/task_configs/summarization/t5_ppo.yml
```

Config files for all tasks can be found in the scripts/training/task_configs directory.
YAML file schema - Configuring building blocks
The config file contains hyper-parameter settings for the building blocks, which are described below:

- **Dataset/Task**: Dataset containing samples with input prompts and reference sentences. Available datasets are found in the class `DataPoolRegistry` in the registry. (See how to create your own dataset below.)

```yaml
datapool:
  id: cnn_daily_mail
  args:
    prompt_prefix: "Summarize: "
```
- **Tokenizer**: A pre-trained tokenizer that is used to (de)tokenize input and output sequences, with settings for padding and truncation.

```yaml
tokenizer:
  model_name: t5-base
  padding_side: left
  truncation_side: left
  pad_token_as_eos_token: False
```
- **Reward Function**: Reward function which computes token-level scores at each time step of the MDP. Available reward functions can be found in the class `RewardFunctionRegistry`. (See how to create your own reward function below.)

```yaml
reward_fn:
  id: rouge
  args:
    rouge_type: "rouge1"
```
- **Environment**: Configures a gym-style text generation environment which simulates MDP episodes. Rollouts are generated using train samples from the dataset, consisting of input and reference texts. Further, we wrap our env with `SubProcVecEnv` from stable-baselines, which processes `n_envs` episodes in parallel using multi-processing to compute step-wise rewards. Further configuration settings include:
  - `max_episode_length`: max length of the episode
  - `max_prompt_length`: maximum length of the input text to consider
  - `terminate_on_eos`: whether to terminate the episode as soon as the EOS action is performed
  - `prompt_truncation_side`: truncation side for the prompt text
  - `context_start_token`: id of the context token (corresponds to the initial token given to the decoder in encoder-decoder models)

```yaml
env:
  n_envs: 10
  args:
    max_prompt_length: 512
    max_episode_length: 100
    terminate_on_eos: True
    prompt_truncation_side: "right"
    context_start_token: 0
```
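Conceptually, each episode is a standard gym-style rollout: the initial observation is the (truncated) input prompt, every action appends one token id, and the episode ends on EOS or after `max_episode_length` steps. The toy class below only illustrates that MDP structure and is not the library's environment API (names such as `ToyTextGenEnv` and the token ids are made up):

```python
# Toy illustration of the text-generation MDP described above; the real
# environment is constructed by RL4LMs from the YAML config.
from dataclasses import dataclass, field
from typing import Dict, List, Tuple

EOS = 0  # hypothetical EOS token id


@dataclass
class ToyTextGenEnv:
    prompt_ids: List[int]
    max_episode_length: int = 5
    generated: List[int] = field(default_factory=list)

    def reset(self) -> List[int]:
        self.generated = []
        return list(self.prompt_ids)

    def step(self, action: int) -> Tuple[List[int], float, bool, Dict]:
        self.generated.append(action)
        done = action == EOS or len(self.generated) >= self.max_episode_length
        reward = 1.0 if done else 0.0  # a metric such as ROUGE in the real setup
        return self.prompt_ids + self.generated, reward, done, {}


env = ToyTextGenEnv(prompt_ids=[101, 102, 103])
obs, done = env.reset(), False
while not done:  # a policy would pick the next token id here
    obs, reward, done, _ = env.step(action=42 if len(obs) < 6 else EOS)
```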
- **On-policy alg**: We provide implementations of 4 on-policy algorithms (PPO, NLPO, A2C and TRPO), adapted from stable-baselines3 and tailored to work with NLP tasks; they can be used out-of-the-box with either a causal LM policy or a seq2seq LM policy. (See how to create your own on-policy algorithm or policy.)
  - We also provide a supervised trainer for benchmarking purposes. Supervised warm-start models are already uploaded to the Hugging Face Hub and specified in the respective config files.
  - Hyper-parameters for the algorithm can be specified at `alg/args`.
  - Further, all RL algorithms use an adaptive KL controller to keep the LM close to the original LM, configured via the initial KL coefficient (`alg/kl_div/coeff`) and the target KL (`alg/kl_div/target_kl`).
  - We support two types of LM policy: a causal LM policy (for decoder-only models) and a seq2seq LM policy (for encoder-decoder models). Further, for NLPO we also provide maskable variants of these. Policy implementations can be found in the policy module and can be attached to algorithms by specifying `alg/policy/id` and `alg/policy/args`.

```yaml
alg:
  id: ppo
  args:
    n_steps: 512
    batch_size: 64
    verbose: 1
    learning_rate: 0.000002
    n_epochs: 5
    ent_coef: 0.0
  kl_div:
    coeff: 0.001
    target_kl: 0.2
  policy:
    id: seq2seq_lm_actor_critic_policy
    args:
      model_name: t5-base
      apply_model_parallel: True
      prompt_truncation_side: "right"
      generation_kwargs:
        do_sample: True
        top_k: 50
        min_length: 50
        max_new_tokens: 100
```
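Adaptive KL controllers of this kind typically follow the update rule from Ziegler et al. (2019): the coefficient is nudged up when the measured KL to the original LM exceeds the target and down otherwise. The sketch below reflects that general scheme rather than the library's exact implementation, and the `horizon` smoothing constant is a hypothetical default:

```python
def adaptive_kl_update(coeff: float, observed_kl: float, target_kl: float,
                       n_steps: int, horizon: int = 10_000) -> float:
    """One adaptive-KL step: adjust `coeff` (initialised from alg/kl_div/coeff)
    so that the measured KL to the original LM tracks alg/kl_div/target_kl."""
    proportional_error = max(-0.2, min(0.2, observed_kl / target_kl - 1.0))
    return coeff * (1.0 + proportional_error * n_steps / horizon)


# Measured KL above target -> the coefficient (and hence the KL penalty
# added to the reward) grows slightly for the next rollout.
new_coeff = adaptive_kl_update(coeff=0.001, observed_kl=0.3,
                               target_kl=0.2, n_steps=512)
```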
- **Trainer Config**: We provide an on-policy trainer - a feature-complete wrapper that instantiates the building blocks from their corresponding configs and provides an outer training loop of `train_evaluation/n_iters` train and eval iterations.
  - Each iteration performs updates on `alg/args/n_steps` x `env/n_envs` transitions collected with the chosen algorithm (e.g. with `n_steps: 512` and `n_envs: 10`, each iteration gathers 5120 transitions).
  - Every `eval_every` iterations, the LM is evaluated on the validation split using the metrics listed in `train_evaluation/metrics`, with the generation kwargs provided in `train_evaluation/generation_kwargs` (these override the rollout `alg/policy/generation_kwargs` for inference purposes only).

```yaml
# train and evaluation
train_evaluation:
  eval_batch_size: 100
  n_iters: 100
  eval_every: 10
  save_every: 1
  metrics:
    - id: meteor
      args: {}
    - id: rouge
    - id: bleu
      args: {}
    - id: bert_score
      args:
        language: en
    - id: diversity
      args: {}
  generation_kwargs:
    do_sample: True
    top_k: 0
    temperature: 0.7
    min_length: 50
    max_new_tokens: 100
```
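Since every task config is plain YAML, a quick way to see how these blocks fit together is to load one and inspect its top-level sections, which mirror the building blocks described above. A minimal sketch using PyYAML, with the summarization config from the Quick Start as the example path:

```python
import yaml

# Load the PPO summarization config used in the Quick Start section.
with open("scripts/training/task_configs/summarization/t5_ppo.yml") as f:
    config = yaml.safe_load(f)

# Top-level keys correspond to the building blocks described above,
# e.g. datapool, tokenizer, reward_fn, env, alg and train_evaluation.
print(sorted(config.keys()))
print(config["reward_fn"]["id"])   # e.g. "rouge"
print(config["alg"]["kl_div"])     # e.g. {"coeff": 0.001, "target_kl": 0.2}
```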
Custom Building Blocks :wrench:
RL4LMs provides full customizability with respect to adding new tasks/datasets, reward functions, evaluation metrics, on-policy algorithms and actor-critic policies.
Adding dataset
Users can create their own datasets by sub-classing `TextGenPool` and overriding the `prepare(cls, split: str, **args) -> 'TextGenPool'` method to return an instance of `TextGenPool`. An example is shown below:
```python
from rl4lms.data_pools.text_generation_pool import Sample, TextGenPool


class MyDataPool(TextGenPool):
    @classmethod
    def prepare(cls, split: str):
        ...  # load the raw items for the requested split here
        samples = []
        for ix, item in enumerate(...):
            # Wrap each raw item into a Sample with a unique id, the input
            # prompt and its reference text(s)
            sample = Sample(id=f"{split}_{ix}",
                            prompt_or_input_text=item["document"],
                            references=[item["target"]]
                            )
            samples.append(sample)
        pool_instance = cls(samples)
        return pool_instance
```
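To use a custom pool from a YAML config, it also has to be registered under an id in the `DataPoolRegistry` mentioned above. The exact registration helper should be checked against the registry module, so the call below is an assumption about its shape rather than a documented API:

```python
# Hypothetical registration sketch -- verify the actual method name in
# rl4lms/envs/text_generation/registry.py before relying on it.
from rl4lms.envs.text_generation.registry import DataPoolRegistry

DataPoolRegistry.add("my_datapool", MyDataPool)
# The config could then refer to it via `datapool: id: my_datapool`.
```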
Adding reward function
Custom reward functions can be implemented easily by sub-classing `RewardFunction` (a callable), which takes the current observation ($s$), the action ($a$), the next observation ($s'$) and a flag indicating whether the episode is done, and returns a scalar reward for that step.
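A minimal sketch is shown below; the argument names and the `Observation` import reflect our reading of the reward module and should be verified against `rl4lms/envs/text_generation/reward.py`:

```python
from typing import Any, Dict

from rl4lms.envs.text_generation.observation import Observation
from rl4lms.envs.text_generation.reward import RewardFunction


class MyRewardFunction(RewardFunction):
    def __call__(self,
                 prev_observation: Observation,
                 action: int,
                 current_observation: Observation,
                 done: bool,
                 meta_info: Dict[str, Any] = None) -> float:
        if done:
            # Score the finished generation here, e.g. by comparing the
            # generated text in `current_observation` against its references
            # (attribute names to be checked in the Observation dataclass).
            reward = 1.0  # placeholder score
            return reward
        return 0.0  # intermediate steps get zero reward in this sketch
```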
