RECAP
Retrieval-Enhanced Context-Aware Prefix Encoder for Personalized Dialogue Response Generation
The official repository for the ACL 2023 paper "RECAP: Retrieval-Enhanced Context-Aware Prefix Encoder for Personalized Dialogue Response Generation".
Installation
Commands for environment setup with conda:
conda create --name recap python=3.8
conda activate recap
pip install -U pip
pip install -r requirements.txt
Data
The data is extracted from the pushshift.io Reddit dump. To preserve each user's persona and personal writing style as much as possible, we did not filter out conversations with unethical content. You can download the raw data from the link provided by the authors.
Pre-processing
Pre-process the raw data into the formats used for retrieval and generation.
Retrieval Data
Encode text representations
python src/preprocess/encode_comments.py -d <raw_data_path> -o <output_path>
Retrieval
python src/preprocess/retrieval.py -d <raw_data_path> -o <output_path>
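The two retrieval-data steps follow a familiar dense-retrieval pattern: encode every history comment once, then rank a user's history by similarity to a query representation. The README does not specify the scoring function, so the sketch below simply assumes cosine similarity over precomputed vectors; `cosine` and `top_k_history` are hypothetical names, not functions from `encode_comments.py` or `retrieval.py`.

```python
import math
from typing import List, Sequence, Tuple


def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine similarity of two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        return 0.0
    return dot / (norm_u * norm_v)


def top_k_history(
    query: Sequence[float], history: List[Sequence[float]], k: int
) -> List[Tuple[int, float]]:
    """Return (index, score) for the k history vectors most similar to query, best first.

    Assumption: history vectors were produced by the same encoder as the query.
    """
    scored = [(i, cosine(query, h)) for i, h in enumerate(history)]
    scored.sort(key=lambda pair: pair[1], reverse=True)  # stable: ties keep original order
    return scored[:k]


if __name__ == "__main__":
    # Toy 2-d "representations": the second history comment is closest to the query.
    print(top_k_history([1.0, 0.0], [[0.0, 1.0], [1.0, 0.1], [0.7, 0.7]], k=2))
```

In practice the vectors would be large and stored as matrices, where a batched matrix product replaces the Python loop; the ranking logic stays the same.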
Generation Data
Most recent history responses
python src/preprocess/recent.py -d <raw_data_path> -o <output_path>
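The "most recent" variant keeps a user's latest history responses instead of retrieved ones. As a minimal sketch, assuming each comment is a dict with a pushshift-style Unix timestamp under `created_utc` and its text under `body` (the raw data's actual fields may differ), and that only comments written before the target response are eligible, the selection could look like this; `most_recent_history` is a hypothetical helper.

```python
from typing import Dict, List


def most_recent_history(comments: List[Dict], target_time: int, k: int) -> List[str]:
    """Return the bodies of the k latest comments posted strictly before target_time.

    Assumed (hypothetical) comment format: {"created_utc": int, "body": str}.
    """
    # Exclude the target response itself and anything written later, to avoid leaking future text.
    earlier = [c for c in comments if c["created_utc"] < target_time]
    # Newest first.
    earlier.sort(key=lambda c: c["created_utc"], reverse=True)
    return [c["body"] for c in earlier[:k]]


if __name__ == "__main__":
    history = [
        {"created_utc": 10, "body": "oldest"},
        {"created_utc": 30, "body": "newest before target"},
        {"created_utc": 50, "body": "written after target"},
    ]
    print(most_recent_history(history, target_time=40, k=2))
```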
Retrieved by the hierarchical transformer
This requires the retriever output in <retrieved_path>. See the Retriever subsection under Training and the Retrieve subsection under Inference for how to train the hierarchical transformer retriever and retrieve with it.
python src/preprocess/retrieved.py -d <raw_data_path> -r <retrieved_path> -o <output_path>
Training
The commands below train the retriever and the generator on a single GPU. The code also works with multiple GPUs, but --batch_size is the per-device batch size, so adjust it accordingly when training on more than one GPU.
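Assuming standard data-parallel training, where each GPU processes `--batch_size` examples per step and gradients are accumulated over `--grad_accumulation` steps before each optimizer update, the number of examples per update is the product of the three factors. The helper names below are hypothetical; they only make the arithmetic concrete for the batch sizes used in the commands that follow.

```python
def effective_batch_size(per_device: int, num_gpus: int = 1, grad_accumulation: int = 1) -> int:
    """Examples contributing to one optimizer update (assumes data-parallel training)."""
    return per_device * num_gpus * grad_accumulation


def per_device_for(target: int, num_gpus: int, grad_accumulation: int = 1) -> int:
    """Per-device --batch_size that reproduces `target` examples per update on num_gpus devices."""
    denom = num_gpus * grad_accumulation
    if target % denom != 0:
        raise ValueError(f"{target} is not divisible by num_gpus * grad_accumulation = {denom}")
    return target // denom


if __name__ == "__main__":
    retriever_target = effective_batch_size(4, num_gpus=1, grad_accumulation=8)
    print(retriever_target)                      # examples per update on one GPU
    print(per_device_for(retriever_target, 2, 8))  # --batch_size for the retriever on 2 GPUs
    print(per_device_for(128, 4))                # --batch_size for the generator on 4 GPUs
```

For example, the retriever command uses 4 × 8 = 32 examples per update on one GPU, so on 2 GPUs `--batch_size 2` keeps that product. For the generator this assumes gradient accumulation defaults to 1, which the README does not state. Lowering `--grad_accumulation` instead of `--batch_size` is an equally valid way to keep the product constant.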
Retriever
python src/train_retriever.py \
--data_path <data_path> \
--raw_data_path <raw_data_path> \
--reps_path <representations_path> \
--save_path <save_path> \
--ref_type <style OR semantic> \
--lr 5e-5 \
--batch_size 4 \
--grad_accumulation 8 \
--warmup 6250 \
--nhead 12
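`--warmup 6250` most likely sets the number of learning-rate warmup steps. A common choice is a linear ramp from 0 to `--lr`; the sketch below assumes exactly that, counts steps as optimizer updates, and holds the rate constant afterwards because the README does not say whether or how it decays. `warmup_lr` is a hypothetical helper, not a function from this repository.

```python
def warmup_lr(step: int, base_lr: float = 5e-5, warmup_steps: int = 6250) -> float:
    """Learning rate at an optimizer step under linear warmup.

    Assumption: the rate stays at base_lr after warmup; the real scripts may decay it.
    """
    if warmup_steps <= 0:
        return base_lr
    return base_lr * min(1.0, step / warmup_steps)


if __name__ == "__main__":
    for s in (0, 3125, 6250, 10000):
        print(s, warmup_lr(s))
```

Under these assumptions the retriever reaches half its peak rate (2.5e-5) at step 3125 and the full 5e-5 at step 6250; the generator's `--warmup 10000` stretches the same ramp over 10000 steps.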
Generator
python src/train_generator.py \
--data_path <data_path> \
--save_path <save_path> \
--injection_mode <(optional) concat OR context-prefix> \
--ref_type <(optional) style OR semantic> \
--lr 5e-5 \
--batch_size 128 \
--warmup 10000
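With `--injection_mode concat`, the history responses are presumably fed to the generator as extra input text next to the dialogue context, whereas `context-prefix` routes them through the prefix encoder instead. The sketch below illustrates only the concatenation idea; the separator, the ordering (references before context), the character budget, and the name `build_concat_input` are all hypothetical choices, not the repository's actual input format.

```python
from typing import List


def build_concat_input(
    references: List[str], context: str, sep: str = " | ", max_chars: int = 1000
) -> str:
    """Join history references and the dialogue context into one source string.

    References are added in order while the total stays within max_chars;
    the context is always kept, even if it alone exceeds the budget.
    """
    kept: List[str] = []
    budget = max_chars - len(context)
    for ref in references:
        cost = len(ref) + len(sep)  # each kept reference also adds one separator
        if cost > budget:
            break
        kept.append(ref)
        budget -= cost
    return sep.join(kept + [context])


if __name__ == "__main__":
    print(build_concat_input(["lol same", "not again"], "Anyone else stuck in traffic?"))
```

A real pipeline would truncate by tokens rather than characters, since the model's limit is measured in tokens; the character budget just keeps the sketch dependency-free.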
Inference
Retrieve and generate with trained models.
Retrieve
python src/retrieve.py \
--data_path <data_path> \
--model_path <retriever_model_path> \
--save_path <save_path> \
--ref_type <style OR semantic>
Generate
python src/generated.py \
--data_path <data_path> \
--model_path <generator_model_path> \
--save_path <save_path> \
--injection_mode <(optional) concat OR context-prefix> \
--ref_type <(optional) style OR semantic>
Evaluate
Before running the evaluation, download the BLEURT checkpoint BLEURT-20-D3 from the official BLEURT checkpoints page.
python src/eval.py \
--generated_path <generated_responses_path> \
--dataset_path <data_path> \
--cache_dir <eval_cache_dir> \
--cav_samples <eval_cav_samples_file>