RL4LMs

A modular RL library to fine-tune language models to human preferences

We provide easily customizable building blocks for training language models, including implementations of on-policy algorithms, reward functions, metrics, datasets, and LM-based actor-critic policies.

Paper Link: https://arxiv.org/abs/2210.01241

Website Link: https://rl4lms.apps.allenai.org/

Thoroughly tested and benchmarked with over 2000 experiments :fire: (GRUE benchmark :trophy:) on a comprehensive set of:

  • 7 different Natural Language Processing (NLP) Tasks:
    • Summarization
    • Generative Commonsense Reasoning
    • IMDB Sentiment-based Text Continuation
    • Table-to-text generation
    • Abstractive Question Answering
    • Machine Translation
    • Dialogue Generation
  • 20+ NLG metrics of different types, which can be used as reward functions:
    • Lexical metrics (e.g. ROUGE, BLEU, SacreBLEU, METEOR)
    • Semantic metrics (e.g. BERTScore, BLEURT)
    • Task-specific metrics (e.g. PARENT, CIDEr, SPICE)
    • Scores from pre-trained classifiers (e.g. sentiment scores)
  • On-policy algorithms: PPO, A2C, TRPO, and the novel NLPO (Natural Language Policy Optimization)
  • Actor-critic policies supporting causal LMs (e.g. GPT-2/3) and seq2seq LMs (e.g. T5, BART)

All of these building blocks are customizable, allowing users to train transformer-based LMs to optimize any arbitrary reward function on any dataset of their choice.

Recent updates (v0.2.0) on 23-Nov-22

  • Added daily dialog task
  • Fixed compatibility issues with some seq2seq models such as BART, BlenderBot, etc.
  • Implemented data parallel support
  • Refactored policy classes

Recent updates (v0.2.1)

  • Minor logging updates

Install

Local Installation

git clone https://github.com/allenai/RL4LMs.git
cd RL4LMs
pip install -e .

Docker

We also provide a Dockerfile for development inside Docker containers with all dependencies installed.

docker build . -t rl4lms

Additional dependencies

CoreNLP libraries are required only for certain metric computations (e.g. SPICE); they can optionally be downloaded with:

cd rl4lms/envs/text_generation/caption_metrics/spice && bash get_stanford_models.sh


Quick Start - Train PPO/NLPO using pre-defined YAML configs

We provide a simple training API, invoked via a train script, that can train PPO, NLPO, or a supervised model from a YAML config file.

For example, to train T5-base on CNN/DM summarization with PPO using ROUGE-1 as the reward function, you can run:

python scripts/training/train_text_generation.py --config_path scripts/training/task_configs/summarization/t5_ppo.yml

Config files for all tasks can be found in scripts/training/task_configs.
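This README refers to config fields by slash-separated paths such as alg/kl_div/coeff. The snippet below is a minimal sketch of reading and overriding such paths in a config that has already been loaded into a Python dict, for example to try ROUGE-2 instead of ROUGE-1 as the reward. get_by_path and set_by_path are hypothetical helpers, not part of RL4LMs; the dict is a hand-written stand-in for the output of yaml.safe_load, and "rouge2" is assumed to be an accepted rouge_type.

```python
from copy import deepcopy
from typing import Any, Dict


def get_by_path(cfg: Dict[str, Any], path: str) -> Any:
    """Return the value at a slash-separated path such as "alg/kl_div/coeff"."""
    node: Any = cfg
    for key in path.split("/"):
        node = node[key]
    return node


def set_by_path(cfg: Dict[str, Any], path: str, value: Any) -> Dict[str, Any]:
    """Return a copy of cfg with the value at `path` replaced; the input is not mutated."""
    new_cfg = deepcopy(cfg)
    *parents, leaf = path.split("/")
    node = new_cfg
    for key in parents:
        # Create missing intermediate sections so new keys can be added.
        node = node.setdefault(key, {})
    node[leaf] = value
    return new_cfg


# Hand-written fragment mirroring the summarization config (stand-in for yaml.safe_load output).
base_cfg = {
    "reward_fn": {"id": "rouge", "args": {"rouge_type": "rouge1"}},
    "alg": {"id": "ppo", "kl_div": {"coeff": 0.001, "target_kl": 0.2}},
}

rouge2_cfg = set_by_path(base_cfg, "reward_fn/args/rouge_type", "rouge2")
print(get_by_path(rouge2_cfg, "reward_fn/args/rouge_type"))  # rouge2
print(get_by_path(base_cfg, "reward_fn/args/rouge_type"))    # rouge1 (original untouched)
```

The modified dict could then be written to a new file (e.g. with PyYAML's yaml.safe_dump) and passed to train_text_generation.py via --config_path. Returning a copy instead of mutating in place keeps the shipped config intact.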

YAML file schema - Configuring building blocks

The config file specifies hyper-parameter settings for the building blocks described below:

  • Dataset/Task: A dataset containing samples with input prompts and reference sentences. Available datasets are listed in the DataPoolRegistry class in registry. (See Adding dataset below for how to create your own dataset.)

    datapool:
      id: cnn_daily_mail
      args:
        prompt_prefix: "Summarize: "
    
  • Tokenizer: A pre-trained tokenizer used to (de)tokenize input and output sequences, with settings for padding and truncation.

    tokenizer:
      model_name: t5-base
      padding_side: left
      truncation_side: left
      pad_token_as_eos_token: False
    
  • Reward Function: A reward function that computes token-level scores at each time step of the MDP. Available reward functions can be found in the RewardFunctionRegistry class. (See Adding reward function below for how to create your own reward function.)

    reward_fn:
      id: rouge
      args:
        rouge_type: "rouge1"
    
  • Environment: Configures a gym-style text generation environment that simulates MDP episodes. Rollouts are generated using train samples from the dataset, consisting of input and reference texts. Further, we wrap our env with SubprocVecEnv from stable-baselines3, which runs n_envs episodes in parallel using multi-processing to compute step-wise rewards.
    Further configuration settings include:

    • max_episode_length: maximum length of the episode
    • max_prompt_length: maximum length of the input text to consider
    • terminate_on_eos: whether to terminate the episode as soon as the EOS action is performed
    • prompt_truncation_side: truncation side for the prompt text
    • context_start_token: id of the context token (corresponds to the initial token given to the decoder in encoder-decoder models)

    env:
      n_envs: 10
      args:
        max_prompt_length: 512
        max_episode_length: 100
        terminate_on_eos: True
        prompt_truncation_side: "right"
        context_start_token: 0
    
  • On-policy alg: We provide implementations of 4 on-policy algorithms (PPO, NLPO, A2C, and TRPO) adapted from stable-baselines3 and tailored to NLP tasks; they can be used out of the box with either a causal LM policy or a seq2seq LM policy. (See how to create your own on-policy algorithm or policy.)

    • We also provide a supervised trainer for benchmarking purposes. Supervised warm-start models are already uploaded to the Hugging Face Hub and specified in the respective config files.

    • Hyper-parameters for the algorithm can be specified at alg/args.

    • Further, all RL algorithms use an adaptive KL controller to keep the LM close to the original LM, configured by the initial KL coefficient (alg/kl_div/coeff) and the target KL (alg/kl_div/target_kl).

    • We support two types of LM policy: causal LM policy (for decoder-only models) and seq2seq LM policy (for encoder-decoder models). For NLPO, we also provide maskable variants of these. A policy is attached to an algorithm by specifying alg/policy/id and alg/policy/args.

      alg:
        id: ppo
        args: 
          n_steps: 512
          batch_size: 64
          verbose: 1
          learning_rate: 0.000002
          n_epochs: 5
          ent_coef: 0.0
        kl_div:
          coeff: 0.001
          target_kl: 0.2
        policy:
          id: seq2seq_lm_actor_critic_policy
          args:
            model_name: t5-base
            apply_model_parallel: True
            prompt_truncation_side: "right"
            generation_kwargs:
              do_sample: True
              top_k: 50
              min_length: 50
              max_new_tokens: 100          
      
  • Trainer Config: We provide an on-policy trainer, a feature-complete wrapper that instantiates building blocks from their corresponding configs and provides an outer training loop of train and eval iterations (train_evaluation/n_iters).

    • Each iteration performs updates of the chosen algorithm over alg/args/n_steps x env/n_envs timesteps.
    • Every eval_every iterations, the LM is evaluated on the validation split using the metrics listed in train_evaluation/metrics, with the generation kwargs provided in train_evaluation/generation_kwargs (these override the rollout kwargs in alg/policy/generation_kwargs, for inference only).

    # train and evaluation
    train_evaluation:
      eval_batch_size: 100
      n_iters: 100
      eval_every: 10
      save_every: 1
      metrics:
        - id: meteor
          args: {}
        - id: rouge
        - id: bleu
          args: {}
        - id: bert_score
          args:
            language: en
        - id: diversity
          args: {}
      generation_kwargs: 
        do_sample: True
        top_k: 0
        temperature: 0.7
        min_length: 50
        max_new_tokens: 100
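The reward_fn block above selects ROUGE-1 as the reward. One common way to turn such a sequence-level metric into the per-step rewards an MDP needs is to give 0 at every intermediate step and the metric score at the final step. The sketch below assumes that sparse scheme, uses a simplified whitespace-token ROUGE-1 F1 (no stemming, unlike typical ROUGE implementations), and takes the best score over the references; none of this is RL4LMs's actual reward code.

```python
from collections import Counter
from typing import List


def rouge1_f1(candidate: str, reference: str) -> float:
    """Simplified ROUGE-1 F1: unigram overlap on lowercased whitespace tokens."""
    cand = Counter(candidate.lower().split())
    ref = Counter(reference.lower().split())
    overlap = sum((cand & ref).values())  # clipped unigram matches
    if overlap == 0:
        return 0.0
    precision = overlap / sum(cand.values())
    recall = overlap / sum(ref.values())
    return 2 * precision * recall / (precision + recall)


def step_reward(generated_so_far: str, references: List[str], done: bool) -> float:
    """Sparse reward (assumed scheme): 0 mid-episode, best ROUGE-1 F1 over references at the end."""
    if not done:
        return 0.0
    return max(rouge1_f1(generated_so_far, ref) for ref in references)


# Per-step rewards for a three-token episode.
tokens = ["the", "cat", "sat"]
rewards = [
    step_reward(" ".join(tokens[: t + 1]), ["the cat sat down"], done=(t == len(tokens) - 1))
    for t in range(len(tokens))
]
print(rewards)  # [0.0, 0.0, ~0.857]
```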
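The env settings above can be read as two rules: the prompt is clipped to max_prompt_length tokens (assumed unit) from prompt_truncation_side, and an episode ends when the EOS action is taken (if terminate_on_eos is set) or once max_episode_length tokens have been generated. The following is a simplified, hypothetical model of those rules on lists of token ids; the real environment also builds observations, handles context_start_token, and runs n_envs copies in parallel.

```python
from typing import List, Tuple


def truncate_prompt(prompt_ids: List[int], max_prompt_length: int, side: str = "right") -> List[int]:
    """Keep at most max_prompt_length ids, dropping the excess from the given side."""
    if len(prompt_ids) <= max_prompt_length:
        return list(prompt_ids)
    if side == "right":
        return prompt_ids[:max_prompt_length]   # drop the tail
    if side == "left":
        return prompt_ids[-max_prompt_length:]  # drop the head
    raise ValueError(f"unknown truncation side: {side!r}")


def run_episode(actions: List[int], eos_id: int, max_episode_length: int,
                terminate_on_eos: bool = True) -> Tuple[List[int], int]:
    """Append actions until EOS (if enabled) or the length cap; return (generated ids, steps)."""
    generated: List[int] = []
    for action in actions:
        generated.append(action)
        hit_eos = terminate_on_eos and action == eos_id
        if hit_eos or len(generated) >= max_episode_length:
            break
    return generated, len(generated)


prompt = truncate_prompt(list(range(10)), max_prompt_length=4, side="right")
print(prompt)  # [0, 1, 2, 3]
generated, n_steps = run_episode([5, 7, 1, 9], eos_id=1, max_episode_length=100)
print(generated, n_steps)  # [5, 7, 1] 3
```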
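The kl_div settings (coeff, target_kl) in the alg block drive an adaptive KL controller. A minimal sketch of one common form follows: each token's reward is the task reward minus coeff times the log-probability ratio between the current policy and the original (reference) LM, and after each rollout the coefficient is nudged up when the observed KL exceeds target_kl and down when it falls below. The proportional update, the 0.2 clip range, and the 0.1 gain are illustrative assumptions; RL4LMs's exact controller may differ.

```python
def clip(x: float, lo: float, hi: float) -> float:
    return max(lo, min(hi, x))


class AdaptiveKLController:
    """Hypothetical proportional KL controller (clip range and gain are assumptions)."""

    def __init__(self, coeff: float, target_kl: float, gain: float = 0.1, clip_range: float = 0.2):
        self.coeff = coeff
        self.target_kl = target_kl
        self.gain = gain
        self.clip_range = clip_range

    def update(self, observed_kl: float) -> float:
        # Relative error w.r.t. the target, clipped so a single rollout cannot swing coeff wildly.
        error = clip((observed_kl - self.target_kl) / self.target_kl, -self.clip_range, self.clip_range)
        self.coeff *= 1.0 + self.gain * error
        return self.coeff


def kl_penalized_reward(task_reward: float, logp_policy: float, logp_ref: float, coeff: float) -> float:
    """Per-token reward: task reward minus coeff * (log pi(a|s) - log pi_ref(a|s))."""
    return task_reward - coeff * (logp_policy - logp_ref)


# Initial values taken from alg/kl_div in the example config.
ctrl = AdaptiveKLController(coeff=0.001, target_kl=0.2)
after_high = ctrl.update(observed_kl=0.5)   # KL above target: coefficient grows (to about 0.00102)
after_low = ctrl.update(observed_kl=0.05)   # KL below target: coefficient shrinks again
r = kl_penalized_reward(task_reward=0.0, logp_policy=-1.0, logp_ref=-1.5, coeff=after_low)
```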
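With the values in the example configs (alg/args/n_steps: 512, env/n_envs: 10, train_evaluation/n_iters: 100, eval_every: 10, save_every: 1), each iteration covers 512 x 10 = 5,120 timesteps, so training covers 512,000 timesteps in total, with 10 evaluations and, assuming save_every counts iterations, a checkpoint after every iteration. The sketch below reproduces that arithmetic as a hypothetical schedule; it does not model whether the real trainer also evaluates before the first or after the last iteration.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class IterationPlan:
    iteration: int        # 1-based iteration index
    total_timesteps: int  # env timesteps collected up to and including this iteration
    evaluate: bool        # run validation metrics after this iteration?
    save: bool            # write a checkpoint after this iteration?


def plan_training(n_iters: int, n_steps: int, n_envs: int,
                  eval_every: int, save_every: int) -> List[IterationPlan]:
    steps_per_iter = n_steps * n_envs  # alg/args/n_steps x env/n_envs
    return [
        IterationPlan(it, it * steps_per_iter, it % eval_every == 0, it % save_every == 0)
        for it in range(1, n_iters + 1)
    ]


plan = plan_training(n_iters=100, n_steps=512, n_envs=10, eval_every=10, save_every=1)
print(plan[0].total_timesteps)        # 5120
print(plan[-1].total_timesteps)       # 512000
print(sum(p.evaluate for p in plan))  # 10
```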
    

Custom Building Blocks :wrench:

RL4LMs provides complete customizability with respect to adding new tasks/datasets, reward functions, evaluation metrics, on-policy algorithms, and actor-critic policies.

Adding dataset

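Users can create their own datasets by sub-classing TextGenPool and overriding only the prepare(cls, split: str, **args) -> 'TextGenPool' class method so that it returns an instance of TextGenPool. An example is shown below (load_my_raw_data is a hypothetical placeholder for your own loading code):

from rl4lms.data_pools.text_generation_pool import Sample, TextGenPool

class MyDataPool(TextGenPool):
    @classmethod
    def prepare(cls, split: str, **args) -> "TextGenPool":
        # load_my_raw_data is hypothetical: replace it with your own loader that
        # yields dicts with "document" and "target" keys for the given split.
        raw_items = load_my_raw_data(split)
        samples = []
        for ix, item in enumerate(raw_items):
            # One Sample per example: a unique id, the input prompt and its reference(s).
            sample = Sample(id=f"{split}_{ix}",
                            prompt_or_input_text=item["document"],
                            references=[item["target"]])
            samples.append(sample)
        return cls(samples)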
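For a datapool/id in YAML to resolve to a custom class such as MyDataPool, the class has to be known to the registry that maps ids to classes (DataPoolRegistry in RL4LMs). The snippet below is a self-contained, hypothetical mini-registry that only illustrates that lookup pattern: PoolRegistry and ToyPool are made-up names, ToyPool stands in for a TextGenPool subclass, and the add/get methods are not DataPoolRegistry's actual API.

```python
from typing import Dict, List


class PoolRegistry:
    """Hypothetical registry mapping a datapool id (as written in YAML) to a pool class."""

    _registry: Dict[str, type] = {}

    @classmethod
    def add(cls, pool_id: str, pool_cls: type) -> None:
        # Refuse silent overwrites so two pools cannot share an id.
        if pool_id in cls._registry:
            raise KeyError(f"datapool id already registered: {pool_id!r}")
        cls._registry[pool_id] = pool_cls

    @classmethod
    def get(cls, pool_id: str, split: str, **args):
        # Resolve the id, then build the requested split, forwarding datapool/args.
        return cls._registry[pool_id].prepare(split, **args)


class ToyPool:
    """Stand-in for a TextGenPool subclass: holds a list of prompt strings."""

    def __init__(self, samples: List[str]):
        self.samples = samples

    @classmethod
    def prepare(cls, split: str, prompt_prefix: str = "") -> "ToyPool":
        # Tiny in-memory data keyed by split name.
        raw = {"train": ["first document", "second document"], "val": ["held-out document"]}
        return cls([prompt_prefix + doc for doc in raw[split]])


# Mirrors a YAML block like: datapool: {id: my_pool, args: {prompt_prefix: "Summarize: "}}
PoolRegistry.add("my_pool", ToyPool)
pool = PoolRegistry.get("my_pool", split="train", prompt_prefix="Summarize: ")
print(pool.samples[0])  # Summarize: first document
```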

Adding reward function

Custom reward functions can be implemented easily by sub-classing RewardFunction (a callable) which takes observation ($s$
