BranchGRPO
BranchGRPO: Stable and Efficient GRPO with Structured Branching in Diffusion Models
BranchGRPO: Stable and Efficient GRPO with Structured Branching in Diffusion Models (WIP)
BranchGRPO is a novel approach that restructures the rollout process into a branching tree, where shared prefixes amortize computation and pruning removes low-value paths and redundant depths.
📄 Paper: arXiv:2509.06040
🌐 Project Page: https://fredreic1849.github.io/BranchGRPO-Webpage/
💻 Code: GitHub Repository
Abstract
Recent progress in aligning image and video generative models with Group Relative Policy Optimization (GRPO) has improved human preference alignment, yet existing approaches still suffer from high computational cost due to sequential rollouts and large numbers of SDE sampling steps, as well as training instability caused by sparse rewards. In this paper, we present BranchGRPO, a method that restructures the rollout process into a branching tree, where shared prefixes amortize computation and pruning removes low-value paths and redundant depths.
Key Features
BranchGRPO introduces three main contributions:
- Branch Sampling Scheme: Reduces rollout cost by reusing common segments
- Tree-based Advantage Estimator: Converts sparse terminal rewards into dense, step-level signals
- Pruning Strategies: Accelerate convergence while preserving exploration
Performance
- 16% improvement in alignment scores over strong baselines on HPDv2.1 image alignment
- 55% reduction in per-iteration training time
- Higher Video-Align scores, with sharper and more temporally consistent frames, on WanX-1.3B video generation
Getting Started
Prerequisites
- Python 3.8+
- PyTorch 2.0+
- CUDA 11.8+
- 8+ GPUs (H800/A100 recommended)
Installation
# Clone the repository
git clone https://github.com/your-username/BranchGRPO.git
cd BranchGRPO
# Set up environment
./env_setup.sh branchgrpo
# Install dependencies
pip install -r requirements.txt
Download Checkpoints
- FLUX checkpoints: Download from here to ./data/flux
- HPS-v2.1 checkpoint: Download from here to ./hps_ckpt
- CLIP H-14 checkpoint: Download from here to ./hps_ckpt
Quick Start
Multi-GPU Training
# Preprocess embeddings (8 GPUs)
bash scripts/preprocess/preprocess_flux_rl_embeddings.sh
# Train with BranchGRPO (8 GPUs)
bash scripts/finetune/finetune_flux_branchgrpo_8gpus.sh
Note: For multi-node training, please configure the launcher (e.g., Slurm, torchrun, MPI) according to your own cluster environment.
Configuration
Key parameters for BranchGRPO:
- --tree_split_points: Comma-separated split points (e.g., "0,3,6,9")
- --tree_split_noise_scale: Noise scale for tree splits (default: 4.0)
- --depth_pruning: Depths to prune from training (e.g., "15,16,17")
- --width_pruning_mode: Width pruning strategy (0=none, 1=best per branch, 2=global best/worst)
- --mix_ode_sde_tree: Enable mixed ODE/SDE rollout
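To make the effect of the split points concrete, here is a minimal Python sketch that parses a comma-separated flag value and counts denoiser evaluations for a branching rollout versus fully independent rollouts. It is an illustration, not the repository's code: `parse_int_list`, `tree_rollout_cost`, the `branch_factor` argument, the 20-step schedule, and the rule that children are created just before the split step is evaluated are all assumptions, and the repository's own accounting may differ.

```python
from typing import Iterable, List, Tuple


def parse_int_list(value: str) -> List[int]:
    """Parse a comma-separated value such as "0,3,6,9" into sorted ints."""
    return sorted(int(tok) for tok in value.split(",") if tok.strip())


def tree_rollout_cost(
    num_steps: int, split_points: Iterable[int], branch_factor: int = 2
) -> Tuple[int, int, int]:
    """Return (num_leaves, tree_evals, independent_evals).

    Assumption: at each split step every live trajectory is replaced by
    `branch_factor` children, and each child then evaluates that step.
    """
    splits = set(split_points)
    width = 1          # number of live trajectories at the current step
    tree_evals = 0     # denoiser calls when prefixes are shared
    for step in range(num_steps):
        if step in splits:
            width *= branch_factor
        tree_evals += width
    # Without sharing, every leaf would run all steps on its own.
    independent_evals = width * num_steps
    return width, tree_evals, independent_evals


split_points = parse_int_list("0,3,6,9")
leaves, tree_evals, flat_evals = tree_rollout_cost(20, split_points)
print(f"leaves={leaves} tree_evals={tree_evals} independent_evals={flat_evals}")
print(f"evals per leaf: {tree_evals / leaves:.2f} vs {flat_evals / leaves:.2f}")
```

Earlier split points share less of the trajectory, so moving them later lowers the per-leaf cost under this counting rule, at the price of less diverse leaves.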
Method Overview
BranchGRPO restructures sequential GRPO rollouts into a branching tree:
- Branching Rollouts: At selected denoising steps, trajectories split into multiple children that share early prefixes
- Reward Fusion: Leaf rewards are fused upward using path-probability weighting
- Depth-wise Normalization: Fused rewards are normalized within each depth to obtain dense, step-wise advantages
- Pruning: Lightweight width and depth pruning limit backpropagation to selected nodes
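The reward fusion and depth-wise normalization steps above can be sketched in plain Python. This is a simplified illustration under stated assumptions, not the authors' implementation: the `Node` class, the `weight` field (a stand-in for the path-probability weight of a child relative to its siblings, uniform here), and the helper names `fuse_rewards` and `depthwise_advantages` are all hypothetical.

```python
from dataclasses import dataclass, field
from statistics import mean, pstdev
from typing import Dict, List, Optional


@dataclass
class Node:
    depth: int
    reward: Optional[float] = None  # terminal reward, set on leaves only
    weight: float = 1.0             # assumed unnormalized path weight
    children: List["Node"] = field(default_factory=list)
    value: float = 0.0              # fused reward
    advantage: float = 0.0          # depth-normalized value


def fuse_rewards(node: Node) -> float:
    """Propagate leaf rewards upward as a weighted average of child values."""
    if not node.children:
        node.value = float(node.reward)
        return node.value
    total = sum(child.weight for child in node.children)
    node.value = sum(
        child.weight / total * fuse_rewards(child) for child in node.children
    )
    return node.value


def depthwise_advantages(root: Node, eps: float = 1e-8) -> None:
    """Standardize fused values among all nodes that share a depth."""
    by_depth: Dict[int, List[Node]] = {}
    stack = [root]
    while stack:
        node = stack.pop()
        by_depth.setdefault(node.depth, []).append(node)
        stack.extend(node.children)
    for nodes in by_depth.values():
        values = [n.value for n in nodes]
        mu, sigma = mean(values), pstdev(values)
        for n in nodes:
            n.advantage = (n.value - mu) / (sigma + eps)


# Toy tree: one root, two branches, four leaves with terminal rewards.
a = Node(depth=1, children=[Node(depth=2, reward=1.0), Node(depth=2, reward=3.0)])
b = Node(depth=1, children=[Node(depth=2, reward=5.0), Node(depth=2, reward=7.0)])
root = Node(depth=0, children=[a, b])

fuse_rewards(root)
depthwise_advantages(root)
print(root.value, a.value, b.value, a.advantage, b.advantage)
```

In this toy tree every internal node receives a signal derived from its descendants' rewards, which is how a single terminal reward per leaf becomes a per-depth training signal.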
Results
Efficiency-Quality Comparison
| Method              | NFE π_θ_old | NFE π_θ | Iteration Time (s)↓ | HPS-v2.1↑ | Pick Score↑ | Image Reward↑ |
| ------------------- | ----------- | ------- | ------------------- | --------- | ----------- | ------------- |
| FLUX                | -           | -       | -                   | 0.313     | 0.227       | 1.112         |
| DanceGRPO (tf=1.0)  | 20          | 20      | 698                 | 0.360     | 0.229       | 1.189         |
| DanceGRPO (tf=0.6)  | 20          | 12      | 469                 | 0.353     | 0.228       | 1.219         |
| MixGRPO (20,5)      | 20          | 5       | 289                 | 0.359     | 0.228       | 1.211         |
| BranchGRPO          | 13.68       | 13.68   | 493                 | 0.363     | 0.229       | 1.233         |
| BranchGRPO-WidPru   | 13.68       | 8.625   | 314                 | 0.364     | 0.230       | 1.300         |
| BranchGRPO-DepPru   | 13.68       | 8.625   | 314                 | 0.369     | 0.231       | 1.319         |
| BranchGRPO-Mix      | 13.68       | 4.25    | 148                 | 0.363     | 0.230       | 1.290         |
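
The relative time savings implied by the iteration times in the table can be checked with a few lines of arithmetic. The snippet below is only a convenience sketch: the dictionary copies values from the table, the choice of DanceGRPO (tf=1.0) as the reference row is an assumption made for illustration, and `reduction_pct` is a hypothetical helper.

```python
# Iteration times in seconds, copied from the table above.
iteration_time_s = {
    "DanceGRPO (tf=1.0)": 698,
    "BranchGRPO": 493,
    "BranchGRPO-WidPru": 314,
    "BranchGRPO-DepPru": 314,
    "BranchGRPO-Mix": 148,
}


def reduction_pct(baseline: float, other: float) -> float:
    """Percentage reduction of `other` relative to `baseline`."""
    return 100.0 * (baseline - other) / baseline


baseline = iteration_time_s["DanceGRPO (tf=1.0)"]
for method, seconds in iteration_time_s.items():
    if method == "DanceGRPO (tf=1.0)":
        continue  # skip the reference row itself
    print(f"{method}: {reduction_pct(baseline, seconds):.1f}% less time per iteration")
```

Against that reference row, the two pruned variants at 314 s correspond to a reduction of about 55%.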
Contributing
We welcome contributions! Please see our contributing guidelines for details.
Acknowledgments
This work builds upon:
Citation
If you use BranchGRPO in your research, please cite our paper:
@article{li2025branchgrpo,
title={BranchGRPO: Stable and Efficient GRPO with Structured Branching in Diffusion Models},
author={Li, Yuming and Wang, Yikai and Zhu, Yuying and Zhao, Zhongyu and Lu, Ming and She, Qi and Zhang, Shanghang},
journal={arXiv preprint arXiv:2509.06040},
year={2025}
}
License
This project is licensed under the Apache License 2.0 - see the LICENSE file for details.
