RRHF
[NIPS2023] RRHF & Wombat
Wombat 🐻❄️: from RLHF to RRHF, Aligning Human Preferences in a 'Right' Way
<center> <a href="https://en.wikipedia.org/wiki/Wombat" target="_blank"><img style="border-radius: 0.3125em; box-shadow: 0 2px 4px 0 rgba(34,36,38,.12),0 2px 10px 0 rgba(34,36,38,.08);" src="./assets/wombat.png"></a> <br> <div style="color:orange; border-bottom: 1px solid #d9d9d9; display: inline-block; color: #999; padding: 2px;">Wombats are adorable little creatures native to Australia. The first three pictures are generated from Stable Diffusion.</div> </center>

License Notices: The dataset is CC BY NC 4.0 (allowing only non-commercial use), and models trained using the dataset should not be used outside of research purposes.
Update:
- 2023/4/13 We have released the weights of Wombat - LLaMA (delta weights) on Hugging Face. The Wombat weights can be recovered from them.
- 2023/4/15 We added a comparison with Alpaca-7B and ChatGPT based on the Vicuna test set.
- 2023/5/23 We updated our paper with more discussions and experiments.
- 2023/9/22 Paper accepted by NIPS 2023!
Overview
This is the repository for RRHF (Rank Response to align Human Feedback) and the open-source language models Wombat. RRHF makes it easier to align large language models with human preferences.
Reinforcement Learning from Human Feedback (RLHF) enables the alignment of large language models with human preferences, improving the quality of interactions between humans and language models. Recent RLHF practice uses PPO to optimize large language models for such alignment. However, implementing PPO is non-trivial (the training procedure requires interaction among the policy, behavior policy, reward, and value models), and tuning its many hyperparameters is tedious. Our motivation is to simplify the alignment of language models with human preferences, and our proposed paradigm RRHF can achieve such alignment as easily as conventional fine-tuning. It is simpler than PPO in terms of coding, number of models, and hyperparameters.
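To make the ranking idea more concrete, here is a minimal sketch of a ranking-style objective along the lines described in the RRHF paper: each response is scored by its length-normalized log-probability under the model, pairs that are ordered against the rewards are penalized, and a standard fine-tuning term is added for the best-rewarded response. The function names and the plain-list inputs are illustrative assumptions, not the repository's training code.

```python
from typing import List


def length_normalized_logprob(token_logprobs: List[float]) -> float:
    """Average per-token log-probability of one response under the model."""
    return sum(token_logprobs) / len(token_logprobs)


def rrhf_style_loss(token_logprobs: List[List[float]], rewards: List[float]) -> float:
    """Ranking term plus a fine-tuning term on the best-rewarded response.

    token_logprobs[k] holds the per-token log-probabilities of response k;
    rewards[k] is the reward score assigned to response k.
    """
    scores = [length_normalized_logprob(lp) for lp in token_logprobs]

    # Ranking term: whenever response i has a lower reward than response j,
    # penalize the model if it scores i above j.
    rank_loss = 0.0
    for i in range(len(scores)):
        for j in range(len(scores)):
            if rewards[i] < rewards[j]:
                rank_loss += max(0.0, scores[i] - scores[j])

    # Fine-tuning term: negative log-likelihood of the highest-reward response.
    best = max(range(len(rewards)), key=lambda k: rewards[k])
    sft_loss = -sum(token_logprobs[best])

    return rank_loss + sft_loss


# Two candidate responses for one query (hypothetical numbers).
logprobs = [[-0.5, -0.5], [-2.0, -2.0]]
loss_when_model_agrees = rrhf_style_loss(logprobs, rewards=[-1.0, -2.0])
loss_when_model_disagrees = rrhf_style_loss(logprobs, rewards=[-2.0, -1.0])
```

When the model already prefers the better-rewarded response, the ranking term vanishes and only the fine-tuning term remains. No separate value model or behavior policy is needed in this sketch, which is the sense in which the objective resembles conventional fine-tuning.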
<center> <a target="_blank"><img style="border-radius: 0.3125em; box-shadow: 0 2px 4px 0 rgba(34,36,38,.12),0 2px 10px 0 rgba(34,36,38,.08);" src="./assets/comparison_of_workflow.png"></a> <br> <div style="color:orange; border-bottom: 1px solid #d9d9d9; display: inline-block; color: #999; padding: 2px;">Overview of workflow comparison between PPO and RRHF.</div> </center>

In our preliminary experiments, we compare RRHF and PPO using 7B LLaMA [1] and Alpaca [2] models on Anthropic’s Helpful and Harmless (HH) [3] dataset. We evaluate the results by perplexity (PPL) and reward model scores (Reward). With a much simpler training paradigm, we found that RRHF performs comparably to PPO in terms of generation fluency (PPL) and alignment (Reward).
| Models | Setting | PPL   | Reward |
| ------ | ------- | ----- | ------ |
| LLaMA  | PPO     | 42.53 | -1.62  |
| Alpaca | PPO     | 13.84 | -1.03  |
| LLaMA  | RRHF    | 67.12 | -1.34  |
| Alpaca | RRHF    | 14.75 | -1.02  |
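For readers unfamiliar with the PPL column, the sketch below shows the standard definition of perplexity as the exponential of the negative mean token log-probability. Which model and tokenization were used to obtain the numbers in the table is not stated here, so the helper name and inputs are illustrative assumptions only.

```python
import math
from typing import List


def perplexity(token_logprobs: List[float]) -> float:
    """Perplexity of a sequence given natural-log probabilities of its tokens."""
    mean_neg_logprob = -sum(token_logprobs) / len(token_logprobs)
    return math.exp(mean_neg_logprob)


# A sequence whose tokens each have probability 0.5 has a perplexity close to 2.
example_ppl = perplexity([math.log(0.5)] * 4)
```

Lower perplexity means the evaluated text is more probable under the scoring model, which is why it is used above as a proxy for generation fluency.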
For details, please refer to our paper on arXiv. RRHF is still a work in progress, and this preliminary study has limitations. Due to the large cost of human evaluation, we experiment on the HH dataset and use the reward model Dahoas/gptj-rm-static trained by Dahoas. The reward model serves as synthetic human feedback, and the experiments are a proof of concept for RRHF. We are open to any suggestions and discussions; feel free to contact us at yuanzheng.yuanzhen@alibaba-inc.com, yuanhy20@mails.tsinghua.edu.cn or chuanqi.tcq@alibaba-inc.com.
Setting Up Environment
Use the following commands to set up Python 3.8 and the PyTorch requirements:
conda create -n rrhf python=3.8
conda activate rrhf
pip install torch==1.13.0+cu116 torchvision==0.14.0+cu116 torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu116
Then install Hugging Face's transformers from the GitHub repository to use LLaMA models.
git clone https://github.com/huggingface/transformers.git
pip install -e ./transformers
Install other packages:
pip install -r requirements.txt
Train Alpaca with RRHF on Helpful and Harmless dataset
We use the Helpful and Harmless dataset to compare PPO and RRHF. We use the trained reward model Dahoas/gptj-rm-static.
| Models         | Initial Checkpoint | Sampling Models                      | Reward Score          |
| -------------- | ------------------ | ------------------------------------ | --------------------- |
| Alpaca-7B-RRHF | Alpaca-7B          | Alpaca-7B, responses from HH dataset | Dahoas/gptj-rm-static |
Data Generation
RRHF first samples responses for each query in the training data from the initial models, and then scores each response (including the 'chosen' and 'rejected' responses in the original HH labels) using the reward model.
The scripts for data generation are in ./data_generation. You can run them from the command line:
cd ./data_generation/
bash response_gen.sh <path_to_alpaca/hf_llama_directory> <path_to_data_directory>
We also release our generated data through this link to make RRHF training easier to reproduce. After downloading it, place it in <path_to_data_directory>.
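The sample-then-score step described in this section can be pictured with the short sketch below. It is not the code in ./data_generation: the record field names (`query`, `responses`, `scores`), the `build_record` helper, and the toy reward function are hypothetical and only illustrate how sampled responses and the original HH 'chosen'/'rejected' responses end up scored side by side.

```python
from typing import Callable, Dict, List


def build_record(
    query: str,
    sampled: List[str],
    chosen: str,
    rejected: str,
    reward_fn: Callable[[str, str], float],
) -> Dict[str, object]:
    """Collect all candidate responses for a query and score each one."""
    # Candidates = model samples plus the two responses from the HH labels.
    responses = list(sampled) + [chosen, rejected]
    scores = [reward_fn(query, response) for response in responses]
    return {"query": query, "responses": responses, "scores": scores}


def toy_reward(query: str, response: str) -> float:
    # Placeholder only: a real run would call a reward model here.
    return -float(len(response))


record = build_record(
    query="How do I boil an egg?",
    sampled=["Put it in boiling water for about ten minutes."],
    chosen="Boil water, add the egg, and wait roughly ten minutes.",
    rejected="I don't know.",
    reward_fn=toy_reward,
)
```

Each record then carries everything a ranking-based training step needs: one query, several candidate responses, and one score per response.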
Training with RRHF
You can train your own model with the generated or released datasets using the script train.sh. Note that the training process requires 8 A100 80GB GPUs, bf16, and FSDP. In the future, we will try efficient training methods such as LoRA, prefix-tuning, or adapters to lower the computational resource requirements.
bash ./train.sh <path_to_alpaca_directory> <save_path_directory> <path_to_data_json>
If you only have one A100, try the following FSDP setting, which enables CPU offloading:
--fsdp "full_shard auto_wrap offload"
Wombat: build your own chatbot
Introduction
To produce a more general-purpose chatbot language model, we introduce Wombat to the model zoo of open-source language models.
| Models         | Initial Checkpoint | Sampling Models         | Reward Score | Delta Weights                   |
| -------------- | ------------------ | ----------------------- | ------------ | ------------------------------- |
| Wombat-7B      | Alpaca-7B          | ChatGPT, LLaMA, Alpaca  | ChatGPT      | GanjinZero/wombat-7b-delta      |
| Wombat-7B-GPT4 | Alpaca-7B          | GPT-4, GPT-3.5, OPT-IML | GPT-4        | GanjinZero/wombat-7b-gpt4-delta |
Comparison based on Vicuna test set
| Model A           | Score A | Score B | Model B   |
| ----------------- | ------- | ------- | --------- |
| Alpaca-7B         | 567     | 616     | Wombat-7B |
| Alpaca-7B-ChatGPT | 574     | 612     | Wombat-7B |
| ChatGPT           | 669     | 548     | Wombat-7B |
Alpaca-7B-ChatGPT is initialized from LLaMA and trained using prompts from Alpaca and responses from ChatGPT.
Math and programming skills are weak for all LLaMA-7B-based models.
Weights
You should first obtain the LLaMA weights by following this link. You can then use our provided script recover_wombat_7b.sh to recover the original Wombat weights.
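Since the released weights are described as "Wombat - LLaMA", recovery presumably amounts to adding the delta back onto the base LLaMA weights, parameter by parameter. The sketch below illustrates that arithmetic on plain Python lists. It is an assumption about what recover_wombat_7b.sh does rather than its actual contents, and `apply_delta` is a hypothetical helper name.

```python
from typing import Dict, List


def apply_delta(
    base: Dict[str, List[float]], delta: Dict[str, List[float]]
) -> Dict[str, List[float]]:
    """Return base + delta for every named parameter (element-wise)."""
    if base.keys() != delta.keys():
        raise ValueError("base and delta must contain the same parameter names")
    recovered = {}
    for name in base:
        if len(base[name]) != len(delta[name]):
            raise ValueError(f"shape mismatch for parameter {name!r}")
        recovered[name] = [b + d for b, d in zip(base[name], delta[name])]
    return recovered


# Toy example with one two-element "parameter" (hypothetical values).
llama = {"layer.weight": [1.0, 2.0]}
wombat_delta = {"layer.weight": [0.5, -1.0]}
wombat = apply_delta(llama, wombat_delta)
```

Distributing only the delta means the original LLaMA weights still have to be obtained separately, as noted above.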
Data and Training
- We reuse the queries from the Alpaca training data and sample responses from Alpaca, LLaMA, ChatGPT, and text-davinci-003. We acquire quality assessments of the responses from ChatGPT and train Alpaca with RRHF to become Wombat-7B. You can acquire the data with rewards for Wombat-7B from this link and start training your own "Wombat". Use the following command line:
bash ./train_wombat7b.sh <path_to_alpaca_directory> <save_path_directory> <path_to_data_json>
- You can acquire the data with rewards for Wombat-7B-GPT4 from GPT-4-LLM; we directly use the data they used to train their reward model as our training data. To use it, first convert the data format with clean_gpt4_compare.py. Use the following command line to train Wombat-7B-GPT4:
bash ./train_wombat7b_gpt4.sh <path_to_alpaca_directory> <save_path_directory> <path_to_data_json>
Responses generated by Wombat family
| Query | Wombat-7B
