VLA-RL-Study: What Can RL Bring to VLA Generalization? An Empirical Study
Introduction
This repository contains the code for the paper What Can RL Bring to VLA Generalization? An Empirical Study. The pretrained checkpoints are available at HuggingFace.
Install
OpenVLA, ManiSkill, Training Pipeline
# create conda env: rlvla_env
conda create -n rlvla_env -y python=3.10
conda activate rlvla_env
# install dependencies
pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu121
cd openvla && pip install -e . && cd ..
pip install -U tyro
pip install datasets==3.3.2
# special install for flash attention
wget https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
pip install flash_attn-2.7.4.post1+cu12torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
rm flash_attn-2.7.4.post1+cu12torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
# install other dependencies
cd ManiSkill && pip install -e . && cd ..
cd SimplerEnv && pip install -e . && cd ..
# optional: for ubuntu 2204
# sudo apt-get install libglvnd-dev
RLDS Dataset Maker
Used for building the VLA warm-up dataset and the OpenVLA SFT datasets.
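An RLDS dataset stores each trajectory as a sequence of steps. As a rough orientation, here is a minimal sketch of a per-step dictionary in the common RLDS convention; the exact field names and shapes used by this repository's builder are assumptions, not read from its code:

```python
# Sketch of one RLDS-style step, assuming the common RLDS field
# convention (observation / action / language_instruction / episode flags).
import numpy as np

def make_step(image, state, action, instruction, is_first, is_last):
    """Pack one timestep into an RLDS-like step dictionary."""
    return {
        "observation": {
            "image": image,          # e.g. (H, W, 3) uint8 camera frame
            "state": state,          # proprioceptive state vector
        },
        "action": action,            # e.g. 7-DoF end-effector action
        "language_instruction": instruction,
        "is_first": is_first,        # True on the first step of an episode
        "is_last": is_last,          # True on the last step of an episode
    }

step = make_step(
    image=np.zeros((256, 256, 3), dtype=np.uint8),
    state=np.zeros(8, dtype=np.float32),
    action=np.zeros(7, dtype=np.float32),
    instruction="put the carrot on the plate",
    is_first=True,
    is_last=False,
)
print(sorted(step.keys()))
```

`tfds build` then serializes episodes of such steps into TFRecord shards under `~/tensorflow_datasets/`.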
# create conda env: rlds_env
cd openvla/rlds_dataset_builder
conda env create -f environment_ubuntu.yml
Octo Inference
Used for collecting data with Octo-Small when building the VLA warm-up dataset.
conda create -n octo_env -y python=3.10
conda activate octo_env
git clone https://github.com/octo-models/octo.git
cd ManiSkill && pip install -e . && cd ..
cd octo && pip install -e . && pip install -r requirements.txt && cd ..
pip install --upgrade "jax[cuda11_pip]==0.4.20" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 "nvidia-cudnn-cu11>=8.7,<9.0" --index-url https://download.pytorch.org/whl/cu118
pip install -U tyro
pip install scipy==1.12.0
cd SimplerEnv && pip install -e . && cd ..
Train
Warm-up OpenVLA
Collect Data with Octo-Small
Collect data with Octo-Small to build the warm-up dataset. The average Octo-Small success rate on this task is about 14%.
conda activate octo_env
cd SimplerEnv
cuda=0
# for OpenVLA warm-up (extra 5 trajectories for performance evaluation)
CUDA_VISIBLE_DEVICES=$cuda XLA_PYTHON_CLIENT_PREALLOCATE=false \
python simpler_env/eval_ms3_collect.py \
--env_id "PutCarrotOnPlateInScene-v1" \
--num-episodes 75 --num-envs 64 --seed 0
# try to increase `num-episodes` if not enough successful trajectories are collected
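Since only successful trajectories enter the warm-up dataset, the episode budget should scale with the inverse of the success rate. A quick back-of-the-envelope helper (the margin factor is an illustrative choice, not from the repository):

```python
import math

def episodes_needed(target_successes: int, success_rate: float,
                    margin: float = 1.2) -> int:
    """Estimate how many episodes to launch to collect `target_successes`
    successful trajectories, with a safety margin for variance."""
    return math.ceil(target_successes / success_rate * margin)

# e.g. 10 successful warm-up trajectories at Octo-Small's ~14% success rate
print(episodes_needed(10, 0.14))  # -> 86 episodes
```

If a run comes up short, raising `--num-episodes` by a similar margin is usually enough.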
Collect Data with Motion Planner
Collect data with motion planner to build the warm-up dataset and SFT dataset.
conda activate rlvla_env
cd ManiSkill
cuda=0
# for OpenVLA warm-up (extra 5 trajectories for performance evaluation)
CUDA_VISIBLE_DEVICES=$cuda \
python -m mani_skill.examples.motionplanning.widowx.collect_simpler \
-e "PutOnPlateInScene25Single-v1" \
--save_video --save_data --num_procs 1 --num_traj 75 --seed=0
# for SFT (extra 16 trajectories for performance evaluation)
CUDA_VISIBLE_DEVICES=$cuda \
python -m mani_skill.examples.motionplanning.widowx.collect_simpler \
-e "PutOnPlateInScene25Main-v3" \
--save_video --save_data --num_procs 16 --num_traj 16400 --seed=100
Build VLA Warm-up Dataset
conda activate rlds_env
cd openvla/rlds_dataset_builder/warmup_dataset
tfds build --overwrite
cd ../../../ # at the root dir of this project
mkdir -p datasets
mv -T ~/tensorflow_datasets/example_dataset datasets/warmup
Warm-up OpenVLA
conda activate rlvla_env
cd openvla
# 1. Train LoRA
cuda="0,1,2,3"
task_name="warmup"
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True CUDA_VISIBLE_DEVICES=$cuda \
torchrun --standalone --nnodes 1 --nproc-per-node 4 vla-scripts/finetune.py \
--vla_path "openvla/openvla-7b" \
--data_root_dir "../datasets" \
--dataset_name ${task_name} \
--run_root_dir checkpoints/${task_name} \
--lora_rank 32 \
--batch_size 8 \
--max_steps 2000 \
--eval_steps 50 \
--save_steps "0,500,1000,1500,2000" \
--grad_accumulation_steps 1 \
--learning_rate 5e-4 \
--image_aug True \
--unnorm_key="bridge_orig" \
--wandb_project "RLVLA_sft"
# for 80G GPU, max batch size is 20
# for 40G GPU, max batch size is 8
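When adjusting `--batch_size` for your GPU memory, keep the effective (global) batch size in mind: it is the product of the number of processes, the per-GPU batch size, and the gradient-accumulation steps. A one-line sanity check (values mirror the command above):

```python
def effective_batch(nproc: int, per_gpu_batch: int, grad_accum: int) -> int:
    """Global batch size seen by the optimizer per update:
    processes x per-GPU batch x gradient-accumulation steps."""
    return nproc * per_gpu_batch * grad_accum

print(effective_batch(4, 8, 1))  # -> 32 samples per optimizer update
```

To keep the effective batch size constant on smaller GPUs, lower `--batch_size` and raise `--grad_accumulation_steps` by the same factor.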
# 2. Merge LoRA
cuda="0"
task_name="warmup"
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True CUDA_VISIBLE_DEVICES=$cuda \
torchrun --standalone --nnodes 1 --nproc-per-node 1 vla-scripts/merge_lora.py \
--vla_path "openvla/openvla-7b" \
--run_path "checkpoints/${task_name}/steps_2000" \
--lora_name "lora_002000"
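Merging LoRA folds the low-rank update back into the base weight, W' = W + (alpha/r) * B @ A, so the merged checkpoint runs without adapter overhead. A numpy sketch of that operation, using the standard LoRA convention (shapes, zero-init of B, and scaling are assumptions, not read from `merge_lora.py`):

```python
import numpy as np

def merge_lora(W: np.ndarray, A: np.ndarray, B: np.ndarray,
               rank: int, alpha: float) -> np.ndarray:
    """Fold a LoRA adapter into the base weight: W' = W + (alpha/rank) * B @ A."""
    return W + (alpha / rank) * (B @ A)

rng = np.random.default_rng(0)
d_out, d_in, r = 16, 32, 4
W = rng.standard_normal((d_out, d_in))
A = rng.standard_normal((r, d_in))    # down-projection (r x d_in)
B = np.zeros((d_out, r))              # up-projection, zero-init => merge is a no-op
assert np.allclose(merge_lora(W, A, B, r, alpha=2 * r), W)
```

After training, B is no longer zero and the merge genuinely changes W; the merged model is then loaded like any full checkpoint.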
RL
conda activate rlvla_env
cd SimplerEnv
#cuda="0,1" # env on GPU-0, model on GPU-1 (for 40G GPU)
cuda="0" # env and model on the same GPU (for 80G GPU)
CUDA_VISIBLE_DEVICES=$cuda XLA_PYTHON_CLIENT_PREALLOCATE=false PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
python simpler_env/train_ms3_ppo.py \
--name="PPO-pc25m_v3-warmup" \
--env_id="PutOnPlateInScene25Main-v3" \
--vla_path="openvla/openvla-7b" --vla_unnorm_key="bridge_orig" \
--vla_load_path="../openvla/checkpoints/warmup/steps_2000/lora_002000" \
--seed=0
- GRPO: add `--alg_name="grpo"`
- GRPO (s): add `--alg_name="grpo"` and `--use_same_init`
- PPO from scratch: remove the `--vla_load_path` arg
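GRPO replaces PPO's learned value baseline with a group-relative one: each rollout's advantage is its reward minus the mean (divided by the std) over a group of rollouts sharing the same initial state, which is what `--use_same_init` enforces. A minimal numpy sketch of that normalization, illustrating the general GRPO rule rather than this repository's exact implementation:

```python
import numpy as np

def grpo_advantages(rewards: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Group-relative advantages: normalize each group's rewards by the
    group mean and std. `rewards` has shape [num_groups, group_size]."""
    mean = rewards.mean(axis=1, keepdims=True)
    std = rewards.std(axis=1, keepdims=True)
    return (rewards - mean) / (std + eps)

# two groups of 4 rollouts each, binary task success as reward
rewards = np.array([[1.0, 0.0, 0.0, 1.0],
                    [1.0, 1.0, 1.0, 0.0]])
adv = grpo_advantages(rewards)
print(adv.round(3))
```

Because the baseline comes from sibling rollouts rather than a critic, no value network is trained, which is the main practical difference from the PPO run above.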
SFT
Build OpenVLA SFT Dataset
conda activate rlds_env
# ulimit -n 17000 # avoid "too many open files" error
cd openvla/rlds_dataset_builder/sft_dataset
tfds build --overwrite
cd ../../../
mkdir -p datasets
mv -T ~/tensorflow_datasets/example_dataset datasets/sft
SFT Train
conda activate rlvla_env
cd openvla
cuda="0,1,2,3"
task_name="sft"
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True CUDA_VISIBLE_DEVICES=$cuda \
torchrun --standalone --nnodes 1 --nproc-per-node 4 ../openvla/vla-scripts/finetune.py \
--vla_path "../openvla/checkpoints/warmup/steps_2000/merged_002000" \
--data_root_dir "../datasets" \
--dataset_name ${task_name} \
--run_root_dir checkpoints/${task_name} \
--lora_rank 32 \
--batch_size 8 \
--max_steps 60000 \
--eval_steps 200 \
--save_steps "0,2500,5000,7500,10000,15000,20000,25000,30000,35000,40000,45000,50000,55000,60000" \
--grad_accumulation_steps 1 \
--learning_rate 5e-4 \
--image_aug False \
--wandb_project "RLVLA_sft"
Evaluate
Trained from scratch
conda activate rlvla_env
cd SimplerEnv
cuda=0
# Warm-up
ckpt_path="openvla/openvla-7b"
unnorm_key="bridge_orig"
vla_load_path="../openvla/checkpoints/warmup/steps_2000/lora_002000"
# RL
ckpt_path="openvla/openvla-7b"
unnorm_key="bridge_orig"
vla_load_path="../SimplerEnv/wandb/run-xxx-xxx/glob/steps_xxx" # replace with the actual path
# SFT
ckpt_path="../openvla/checkpoints/warmup/steps_2000/merged_002000"
unnorm_key="sft"
vla_load_path="../openvla/checkpoints/sft/steps_60000-no_aug/lora_060000"
# start evaluation
for seed in 0 1 2 ; do
for env_id in \
"PutOnPlateInScene25VisionImage-v1" "PutOnPlateInScene25VisionTexture03-v1" "PutOnPlateInScene25VisionTexture05-v1" \
"PutOnPlateInScene25VisionWhole03-v1" "PutOnPlateInScene25VisionWhole05-v1" \
"PutOnPlateInScene25Carrot-v1" "PutOnPlateInScene25Plate-v1" "PutOnPlateInScene25Instruct-v1" \
"PutOnPlateInScene25MultiCarrot-v1" "PutOnPlateInScene25MultiPlate-v1" \
"PutOnPlateInScene25Position-v1" "PutOnPlateInScene25EEPose-v1" "PutOnPlateInScene25PositionChangeTo-v1" ; \
do
CUDA_VISIBLE_DEVICES=$cuda XLA_PYTHON_CLIENT_PREALLOCATE=false \
python simpler_env/train_ms3_ppo.py \
--vla_path="${ckpt_path}" --vla_unnorm_key="${unnorm_key}" \
--vla_load_path="${vla_load_path}" \
--env_id="${env_id}" \
--seed=${seed} \
--buffer_inferbatch=64 \
--no_wandb --only_render
done
done
# for 40G GPU, set `--buffer_inferbatch=16` to avoid OOM
Pre-trained checkpoints
The pretrained checkpoints (warmed-up, RL, and SFT) are available at HuggingFace. Follow the evaluation scripts in the section above, replacing the shell variables with the pretrained checkpoint paths.
# Warm-up (pretrained)
ckpt_path="gen-robot/openvla-7b-rlvla-warmup"
unnorm_key="bridge_orig"
vla_load_path=""
# RL (pretrained)
ckpt_path="gen-robot/openvla-7b-rlvla-rl"
unnorm_key="bridge_orig"
vla_load_path=""
# SFT (pretrained)
ckpt_path="gen-robot/openvla-7b-rlvla-sft_16k"
unnorm_key="sft"
vla_load_path=""
Gather results
- Option 1: Manually check the results and visualization videos at `SimplerEnv/wandb/offline-run-xxx-xxx/glob/`
- Option 2: Calculate statistics: in `SimplerEnv/scripts`, run `python calc_statistics.py`, then check the results at `SimplerEnv/scripts/stats`
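At bottom, the statistics for one `env_id` are a mean and spread of success rates over the evaluation seeds (0, 1, 2 above). A self-contained sketch of that aggregation; the per-seed numbers are made up for illustration and `calc_statistics.py` may report different quantities:

```python
import statistics

def summarize(success_rates: list[float]) -> tuple[float, float]:
    """Mean and population std of per-seed success rates."""
    mean = statistics.fmean(success_rates)
    std = statistics.pstdev(success_rates)
    return mean, std

# hypothetical per-seed success rates for one env_id (seeds 0, 1, 2)
mean, std = summarize([0.62, 0.58, 0.60])
print(f"{mean:.3f} +/- {std:.3f}")  # -> 0.600 +/- 0.016
```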
Task definition:
- `PutOnPlateInScene25VisionImage-v1-test`: unseen table
- `PutOnPlateInScene25VisionTexture03-v1-test`: dynamic texture (weak)
- `PutOnPlateInScene25VisionTexture05-v1-test`: dynamic texture (strong)
- `PutOnPlateInScene25VisionWhole03-v1-test`
