POLARIS
Scaling RL on advanced reasoning models
<div> 🌠 A <strong>PO</strong>st-training recipe for scaling R<strong>L</strong> on <strong>A</strong>dvanced <strong>R</strong>eason<strong>I</strong>ng model<strong>S</strong> 🚀 </div>
Overview
Polaris is an open-source post-training recipe that uses reinforcement learning (RL) scaling to further optimize models that already have strong reasoning capabilities. Our work shows that even state-of-the-art models like Qwen3-4B achieve substantial gains on complex reasoning tasks when trained with Polaris. Using only open-source data and academic-grade resources, Polaris raises the performance of open-recipe reasoning models to a new level. In benchmark evaluations, our approach outperforms leading commercial systems such as Claude-4-Opus, Grok-3-Beta, and o3-mini-high (2025/01/03).
This work was done as part of the HKU NLP Group and ByteDance Seed. Our training and evaluation codebase is built on Verl. To foster progress in scaling RL on advanced reasoning models, we are open-sourcing our complete dataset, code, and training details for the research community.
RL training for the 4B model requires 10 days on 32 H800 GPUs (~0.33 hours per step), using a batch size of 128 and a rollout size of 8.
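As a rough sanity check on the figures above, the training budget can be worked out directly. This is back-of-the-envelope arithmetic only; it assumes that "batch size" means prompts per step and "rollout size" means sampled responses per prompt, which the text does not spell out.

```python
# Figures quoted in the paragraph above.
days = 10
gpus = 32
hours_per_step = 0.33   # approximate, as stated
batch_size = 128        # assumed: prompts per training step
rollout_size = 8        # assumed: sampled responses per prompt

total_hours = days * 24
gpu_hours = total_hours * gpus
approx_steps = round(total_hours / hours_per_step)
rollouts_per_step = batch_size * rollout_size

print(f"wall-clock hours:  {total_hours}")
print(f"GPU-hours:         {gpu_hours}")
print(f"approx. RL steps:  {approx_steps}")
print(f"rollouts per step: {rollouts_per_step}")
```

Under these assumptions the run amounts to roughly 700 optimizer steps, each consuming about a thousand sampled responses.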
<div align="center"> <img src="figs/aime25.png" width="80%" /> </div>
🔥Releases
<strong>[2025/07/10]</strong>
- 🤗 Polaris-1.7B-Preview, fine-tuned from Qwen3-1.7B for 500 steps with our open-source codebase.
  - AIME24: 66.9 (+18.6) & AIME25: 53.0 (+16.2)
  - Training scripts: scripts/train/qwen3-1.7b
  - Data: polaris-data-53K.parquet
  - Training logs: wandb.
- ⌨️ Polaris-Coder is coming soon. Stay tuned!
<strong>[2025/06/20]</strong>
- 🧾 Blog posts detailing our training recipe: Notion and Blog
- 🤗 Model weights: Polaris-4B-Preview and Polaris-7B-Preview. Polaris-4B-Preview is fine-tuned from Qwen3-4B, and Polaris-7B-Preview is fine-tuned from DeepSeek-R1-Distill-Qwen-7B.
- 📚 The filtered training dataset with difficulty distribution: Polaris-Dataset-53K
Running environment
cd POLARIS
pip install -e ./verl
pip install -e ./
pip install transformers==4.51.0
pip install vllm==0.8.4
pip install tensordict==0.6.2
# do not use xformers backend
unset VLLM_ATTENTION_BACKEND
Demo
import torch
from transformers import AutoTokenizer
from vllm import SamplingParams, LLM

example = {
    "question": "Find the largest possible real part of \\[(75+117i)z+\\frac{96+144i}{z}\\]where $z$ is a complex number with $|z|=4$.\nLet's think step by step and output the final answer within \\boxed{}.",
    "answer": "540"
}

model = "/path/to/Polaris-4B-Preview"
tokenizer = AutoTokenizer.from_pretrained(model)

llm = LLM(
    model=model,
    dtype=torch.bfloat16,
    tensor_parallel_size=1,
    gpu_memory_utilization=0.9
)

# Decode at a high temperature with a long token budget (see Evaluation).
sampling_params = SamplingParams(
    temperature=1.4,
    top_p=1.0,
    max_tokens=90000
)

question = example["question"]
answer = example["answer"]

# Render the chat template to a prompt string, then generate.
output = llm.generate(
    prompts=tokenizer.apply_chat_template(conversation=[{"content": question, "role": "user"}],
                                          add_generation_prompt=True,
                                          tokenize=False),
    sampling_params=sampling_params
)
print(f"***QUESTION***:\n{question}\n***GROUND TRUTH***:\n{answer}\n***MODEL OUTPUT***:\n{output[0].outputs[0].text}\n")
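The demo prints the raw model output next to the ground truth. Since the prompt asks for the final answer inside \boxed{}, a small helper can pull that answer out for comparison. The sketch below is an illustration, not the repository's grading code; `extract_last_boxed` is a hypothetical helper, and it compares answers only as exact strings.

```python
from typing import Optional


def extract_last_boxed(text: str) -> Optional[str]:
    """Return the contents of the last \\boxed{...} in text, or None."""
    marker = "\\boxed{"
    start = text.rfind(marker)
    if start == -1:
        return None
    i = start + len(marker)
    depth = 1  # we are already inside the opening brace of \boxed{
    chars = []
    while i < len(text):
        ch = text[i]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return "".join(chars).strip()
        chars.append(ch)
        i += 1
    return None  # unbalanced braces, e.g. a truncated response


# Stand-in for output[0].outputs[0].text from the demo above.
sample_output = "... so the maximum real part is \\boxed{540}."
predicted = extract_last_boxed(sample_output)
print(predicted == "540")
```

Brace counting is used instead of a regular expression so that nested LaTeX such as \boxed{\frac{1}{2}} is captured whole.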
Evaluation
We recommend decoding at a higher temperature than the one suggested for Qwen3 (0.6 → 1.4). However, it is not advisable to exceed the temperature used during training. POLARIS also needs a longer response length (> 64K) to avoid performance degradation from truncation, which could otherwise cause its performance to fall below that of Qwen3. All other settings remain the same.
##### Testing with vllm (faster); Output: jsonl file #####
python scripts/eval/eval_vllm_aime24.py --model /path/to/model --n 32 --max_length 90000 --k 20 --t 1.4
python scripts/eval/eval_vllm_aime25.py --model /path/to/model --n 32 --max_length 90000 --k 20 --t 1.4  # or --t 1.45
##### Testing with VeRL; Output: parquet file #####
./scripts/eval/eval_model_aime24.sh --model /path/to/model --n 32 --max_length 90000 --k 20 --t 1.4
./scripts/eval/eval_model_aime25.sh --model /path/to/model --n 32 --max_length 90000 --k 20 --t 1.4  # or --t 1.45
Grade the outputs📊:
python evaluation/grade.py --file_name evaluation/results/aime24-reproduced.parquet  # replace with your output file (parquet or jsonl)
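Because the evaluation commands sample several responses per problem (--n 32), one natural summary is accuracy averaged over samples. The sketch below shows that computation on toy data. It is not the logic of evaluation/grade.py, and the (problem_id, is_correct) record layout is a hypothetical stand-in for the real output schema.

```python
from collections import defaultdict
from statistics import mean


def average_accuracy(records):
    """Mean per-problem accuracy over all sampled responses.

    `records` is an iterable of (problem_id, is_correct) pairs. This layout
    is hypothetical and is not the schema of the repository's output files.
    """
    per_problem = defaultdict(list)
    for problem_id, is_correct in records:
        per_problem[problem_id].append(1.0 if is_correct else 0.0)
    # Average within each problem first, then across problems, so that a
    # problem with more samples does not get more weight.
    return mean(mean(scores) for scores in per_problem.values())


# Toy data: two problems, four sampled responses each.
records = [("p1", True), ("p1", False), ("p1", True), ("p1", True),
           ("p2", False), ("p2", False), ("p2", True), ("p2", False)]
print(f"avg accuracy: {average_accuracy(records):.3f}")
```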
Training
Data preparation
We provide the parquet data for training Qwen3-4B.
The training data used in this work is filtered from DeepScaleR-dataset-40K and AReaL-dataset-106K. To process your json or jsonl data, use the following command to convert it into Parquet format:
# Generate parquet files for parquet_data/polaris-data-53K.parquet
python scripts/data/jsonl2parquet.py --jsonl_file data/jsonl_data/polaris-data-53K.jsonl
Multi-stage training on single node
The training scripts for Qwen3-1.7B, Qwen3-4B, and DeepSeek-R1-Distill-Qwen-7B are available here.
Please set "max_position_embeddings": 131072 in the model's config.json before training.
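If you prefer not to edit config.json by hand, the change can be scripted. The following is a minimal sketch using only the standard library; `set_max_position_embeddings` is a hypothetical helper and the path in the usage comment is a placeholder.

```python
import json
from pathlib import Path


def set_max_position_embeddings(config_path, value=131072):
    """Rewrite config.json in place with the given max_position_embeddings."""
    path = Path(config_path)
    config = json.loads(path.read_text(encoding="utf-8"))
    config["max_position_embeddings"] = value
    path.write_text(json.dumps(config, indent=2) + "\n", encoding="utf-8")
    return config


# Example (the path is a placeholder):
# set_max_position_embeddings("/path/to/qwen3-4b/config.json")
```

All other keys in the file are preserved; only the one field is overwritten.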
You can run the scripts on a single node by:
###### stage1 ######
# stage1 training script
./scripts/train/qwen3-4b/stage1.sh --model /path/to/qwen3-4b --data_path parquet/stage1/qwen3-4b-s1.parquet --experiment_name qwen3-4b-stage1  # use a unique experiment id
###### stage2 ######
# convert the checkpoint after stage1-training to hf model
python verl/scripts/model_merger.py --local_dir /path/to/checkpoints/global_step_xxx/actor --target_dir checkpoints_hf/ckpt-4b-stage1
# Then find the temperature that yields a diversity score similar to that of stage 1
# You can reuse our temperature setting, but searching again for the optimal temperature for `checkpoints_hf/ckpt-4b-stage1` is a better approach.
python search_optimal_temperature.py --start 1.4 --end 1.6 --step 0.05 --model /path/to/model
# You can use our provided data, but dropping the easy samples based on your own training run is a better approach.
python drop_easy_data.py --data_path parquet/stage1/qwen3-4b-s1.parquet --experiment_name qwen3-4b-stage1 --output parquet/stage2/qwen3-4b-s2.parquet
# stage2 training script
./scripts/train/qwen3-4b/stage2.sh --model checkpoints_hf/ckpt-4b-stage1 --data_path parquet/stage2/qwen3-4b-s2.parquet --experiment_name qwen3-4b-stage2
###### stage3 ######
# As in stage 2: convert the checkpoint after stage2-training to hf model, search for the optimal temperature, and remove the easy samples
# stage3 training script
./scripts/train/qwen3-4b/stage3.sh --model checkpoints_hf/ckpt-4b-stage2 --data_path parquet/stage3/qwen3-4b-s3.parquet --experiment_name qwen3-4b-stage3
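The two steps between stages (matching the rollout diversity of the previous stage, and dropping easy samples) can be pictured with the sketch below. It is an illustration under stated assumptions, not the code in search_optimal_temperature.py or drop_easy_data.py. The diversity score is assumed here to be a distinct n-gram ratio over whitespace tokens, the 0.9 pass-rate cutoff is a hypothetical threshold, and all three function names are made up for this example.

```python
def distinct_ngram_ratio(responses, n=4):
    """Unique n-grams divided by total n-grams over whitespace tokens."""
    total, unique = 0, set()
    for text in responses:
        tokens = text.split()
        for i in range(len(tokens) - n + 1):
            unique.add(tuple(tokens[i:i + n]))
            total += 1
    return len(unique) / total if total else 0.0


def closest_temperature(scores_by_temperature, target_score):
    """Pick the temperature whose diversity score is nearest the target."""
    return min(scores_by_temperature,
               key=lambda t: abs(scores_by_temperature[t] - target_score))


def drop_easy(pass_rates, max_pass_rate=0.9):
    """Keep prompt ids whose rollout pass rate is at most max_pass_rate."""
    return [pid for pid, rate in pass_rates.items() if rate <= max_pass_rate]


# Toy numbers, for illustration only.
scores = {1.40: 0.58, 1.45: 0.61, 1.50: 0.65}
print(closest_temperature(scores, target_score=0.60))
print(drop_easy({"a": 1.0, "b": 0.5, "c": 0.0}))
```

In practice the scores would come from sampling the converted checkpoint at each candidate temperature, and the pass rates from the rollouts logged during the previous stage.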
Debug
Pdb is not supported in Ray. In this codebase you can set trainer.debug=True and insert breakpoint() (instead of pdb.set_trace()) to debug.
...
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
breakpoint()
batch = batch.union(gen_batch_output)
...
# open a new terminal and run:
ray debug
Multi-node training
To accelerate the training process, we recommend using at least 4 nodes.
Our multi-node training is based on Ray.
You can run ray start --head on the head node and ray start --address=[RAY_ADDRESS] on other nodes to start the Ray cluster.
After starting the cluster, run the training script on the head node. We also provide a script that starts training without manually initializing Ray:
