[NeurIPS 2025] Official PyTorch implementation for the paper AutoJudge: Judge Decoding Without Manual Annotation

<div align="center"> <h1>AutoJudge: Judge Decoding Without Manual Annotation</h1> [<a href="https://arxiv.org/abs/2504.20039">📚 Paper</a>] | [<a href="https://huggingface.co/datasets/mightyneighbor/Autojudge">🤗 Datasets</a>] </div>

 


🚀 Running the code

Our approach introduces an algorithm for automatically identifying important token mismatches in model generations. We extract hidden states for these tokens, train a lightweight classifier to detect them, and employ it during inference.

To reproduce our results, follow these steps:

  1. Run the dataset mining script
  2. Calculate hidden states
  3. Train the classifier
  4. Run evaluations
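As a rough illustration of step 1 (not the repo's actual API; the function and its correctness oracle are hypothetical), a draft/target token mismatch is labeled important when splicing the draft token into the target's generation changes the final answer:

```python
# Hypothetical sketch of the mining idea; not the repo's implementation.
def find_important_mismatches(draft_tokens, target_tokens, answer_still_correct):
    """answer_still_correct(prefix) -> bool: re-generates from the patched
    prefix with the target model and checks the final answer (stubbed here)."""
    important = []
    for i, (d, t) in enumerate(zip(draft_tokens, target_tokens)):
        if d == t:
            continue  # agreeing tokens are never candidates
        patched_prefix = target_tokens[:i] + [d]  # force the draft token in
        if not answer_still_correct(patched_prefix):
            important.append(i)  # this mismatch flips the final answer
    return important
```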

🤗 Mined datasets available: preprocessed GSM8K & LiveCodeBench artifacts for the two most compute-intensive stages (Dataset Mining and Hidden State Calculation) are available on Hugging Face at mightyneighbor/Autojudge, so you can skip the first two steps for these setups!

🛠️ Getting started

Install packages from requirements.txt:

pip install -r requirements.txt

⛏️ Dataset mining 💎

Below are small snippets showing how to run dataset mining for GSM8K and LiveCodeBench. For detailed instructions, including multi-GPU runs, please refer to the find_important_tokens_gsm8k.sh and find_important_tokens_lcb.sh scripts.

📐 GSM8K 🔢


export MODEL0="meta-llama/Llama-3.2-1B-Instruct"
export MODEL1="meta-llama/Llama-3.1-8B-Instruct"
export TORCH_DTYPE=auto
export GSM8K_TRAIN=data/train_small.jsonl # replace with data/train.jsonl for a full run
export RANDOM_SEED=42
export MAX_NEW_TOKENS=2048
export OUTPUT_FOLDER=output
export OUTPUT_FILE=important_tokens
export DUMP_FREQ=64

mkdir -p $OUTPUT_FOLDER

# one-gpu run

CUDA_VISIBLE_DEVICES=0 python3 src/find_important_tokens.py \
    --draft_model $MODEL0 \
    --target_model $MODEL1 \
    --torch_dtype $TORCH_DTYPE \
    --gsm8k_train_path $GSM8K_TRAIN \
    --random_seed $RANDOM_SEED \
    --max_new_tokens $MAX_NEW_TOKENS \
    --output_folder $OUTPUT_FOLDER \
    --output_file $OUTPUT_FILE \
    --dump_freq $DUMP_FREQ

rm output/done*

💻 LiveCodeBench 📄

export MODEL0="meta-llama/Llama-3.2-1B-Instruct"
export MODEL1="meta-llama/Llama-3.1-8B-Instruct"
export TORCH_DTYPE=auto
export RANDOM_SEED=42
export MAX_NEW_TOKENS=2048
export OUTPUT_FOLDER=output
export OUTPUT_FILE=important_tokens_lcb
export DUMP_FREQ=64
export NUM_PROCESS_EVALUATE=64
export N_TASKS=2 # uses 2 tasks for a short demo; set 880 for the full LCB release_v5 dataset
export TOTAL_GPUS=1

mkdir -p $OUTPUT_FOLDER

# one-gpu run

CUDA_VISIBLE_DEVICES=0 python3 src/find_important_tokens_lcb.py \
    --draft_model $MODEL0 \
    --target_model $MODEL1 \
    --torch_dtype $TORCH_DTYPE \
    --random_seed $RANDOM_SEED \
    --max_new_tokens $MAX_NEW_TOKENS \
    --output_folder $OUTPUT_FOLDER \
    --output_file $OUTPUT_FILE \
    --dump_freq $DUMP_FREQ \
    --n_tasks $N_TASKS \
    --num_process_evaluate $NUM_PROCESS_EVALUATE \
    --total_gpus $TOTAL_GPUS

🧮 Calculating hidden states ⚙️

For the full script, including multi-GPU runs, please refer to calc_hiddens.sh.

export MODEL0="meta-llama/Llama-3.2-1B-Instruct"
export MODEL1="meta-llama/Llama-3.1-8B-Instruct"
export TORCH_DTYPE=auto
export BATCH_SIZE=8
export DATA_FILE=output/important_tokens.pt
export OUTPUT_PATH=output/important_tokens_with_hiddens
export SAVE_FREQ=128
export N_PROCESSES=1

# single gpu run
CUDA_VISIBLE_DEVICES=0 python src/calc_hiddens.py \
    --draft_model $MODEL0 \
    --target_model $MODEL1 \
    --torch_dtype $TORCH_DTYPE \
    --batch_size $BATCH_SIZE \
    --data_file $DATA_FILE \
    --output_path $OUTPUT_PATH \
    --save_freq $SAVE_FREQ \
    --n_processes $N_PROCESSES \
    --process_id 0 
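The conversion script further below slices off the last target_hidden_size entries of the scaler and weights, which suggests the per-token features are the draft and target hidden states concatenated. A minimal sketch of that layout (hypothetical names, NumPy only):

```python
import numpy as np

# Hypothetical sketch of the per-token feature layout: [draft_hidden; target_hidden].
def gather_features(draft_hiddens, target_hiddens, positions):
    """draft_hiddens: (seq_len, d_draft), target_hiddens: (seq_len, d_target);
    positions: indices of the mined tokens."""
    feats = [np.concatenate([draft_hiddens[p], target_hiddens[p]]) for p in positions]
    return np.stack(feats)  # shape: (num_tokens, d_draft + d_target)
```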

🧠 Training a classifier 🎯

export MODEL0="meta-llama/Llama-3.2-1B-Instruct"
export MODEL1="meta-llama/Llama-3.1-8B-Instruct"


python src/train_head_gsm8k_nirvana.py \
    --random_seed 52 \
    --train_size 0.9 \
    --data_path output/important_tokens_with_hiddens.pt  \
    --checkpoint_path output/trained_head.pkl \
    --target_model $MODEL1 \
    --draft_model $MODEL0 \
    --setup DD-DT \
    --train_on_all_data
    # add --convert_to_vllm to get a head that can be used with our vLLM patch
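The conversion script below reads a checkpoint holding a 'scaler' (with mean_/scale_) and a 'model' (with coef_/intercept_), which matches a scikit-learn StandardScaler plus LogisticRegression. A hedged, synthetic-data sketch of training such a head (the real features come from the hidden-states step):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

# Synthetic stand-in for the mined hidden-state features and importance labels;
# the real pipeline loads them from important_tokens_with_hiddens.pt.
rng = np.random.default_rng(0)
X = rng.normal(size=(256, 16))
y = (X[:, 0] > 0).astype(int)

scaler = StandardScaler().fit(X)
head = LogisticRegression(max_iter=1000).fit(scaler.transform(X), y)

# Same checkpoint keys the conversion script expects.
checkpoint = {"model": head, "scaler": scaler}
```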

To convert a trained head to our vLLM format, you can run the following Python script:

import pandas as pd
import pickle

checkpoint = pd.read_pickle('output/trained_head.pkl')

head = checkpoint['model']
scaler = checkpoint['scaler']

target_hidden_size = 4096  # 4096 for Llama-3.1-8B-Instruct, 8192 for Llama-3.1-70B-Instruct
head_dict = dict(
    mean=scaler.mean_[-target_hidden_size:],
    scale=scaler.scale_[-target_hidden_size:],
    weights=head.coef_[0][-target_hidden_size:],
    bias=head.intercept_[-target_hidden_size:],
    thr=0.25,  # decision threshold stored with the head
)

vllm_checkpoint_path = 'vllm_compatible_head.pkl'

with open(vllm_checkpoint_path, 'wb') as f:
    pickle.dump(head_dict, f)
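For intuition, a head exported in this format could be applied to a target hidden state roughly as follows (a sketch under assumptions; the exact accept/reject convention inside the actual vLLM patch may differ):

```python
import numpy as np

def judge_probability(h, head_dict):
    """Standardize the hidden state, apply the linear head, and return a
    sigmoid probability to compare against head_dict['thr'] (hypothetical usage)."""
    z = (np.asarray(h) - head_dict["mean"]) / head_dict["scale"]
    logit = float(z @ head_dict["weights"]) + float(np.ravel(head_dict["bias"])[0])
    return 1.0 / (1.0 + np.exp(-logit))
```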

More scripts to be uploaded later.

📊 Evaluations 📝

Accuracy vs Average Accepted Tokens, Sections 4.1 and 4.2

Here we provide an evaluation example for GSM8K; similar scripts were used to obtain the main results on LiveCodeBench. To run those, please refer to eval/run_lcb_folds.py and eval/run_lcb_topk.py. There, for each threshold (ours) and K (baseline) value, we also vary FOLD_ID, since we use an out-of-fold technique.
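The "average accepted tokens" metric can be pictured as the length of the draft prefix kept before the first important mismatch; a toy sketch (hypothetical, ignoring windowing and rollbacks):

```python
# Toy model of the acceptance count within one speculation window.
def accepted_prefix_len(draft_tokens, target_tokens, is_important_mismatch):
    n = 0
    for d, t in zip(draft_tokens, target_tokens):
        if d != t and is_important_mismatch(d, t):
            break  # the first important mismatch ends acceptance
        n += 1  # matches and unimportant mismatches are both accepted
    return n
```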

AutoJudge Eval

export START_IDX=0
export END_IDX=10 # set 1319 for full eval
export THR_ID=0 # vary this from 0 to 25, thresholds for inference are selected automatically in train scripts
export MODEL0="meta-llama/Llama-3.2-1B-Instruct"
export MODEL1="meta-llama/Llama-3.1-8B-Instruct"
export NUM_SHOTS=0 # set 8 for 8 shot setup
export MAX_NEW_TOKENS=1024

# Running eval on 2 gpus
CUDA_VISIBLE_DEVICES=0,1 python3 eval/gpu_parallel.py --gpus_per_script 1 --start $START_IDX --end $END_IDX --use_queue --script eval/run_inference_task.py --extra_args "--save_folder output/eval_$THR_ID --gsm8k_test_path data/gsm8k_test.json --torch_dtype auto --window_size 64 --head_path output/trained_head.pkl --setup DD-DT --max_new_tokens $MAX_NEW_TOKENS --num_shots $NUM_SHOTS --head_threshold_idx $THR_ID --draft_model $MODEL0 --target_model $MODEL1"

Top-K Baseline Eval

export START_IDX=0
export END_IDX=10 # set 1319 for full eval
export K=2048 # to be varied, we considered the following values [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 128256]
export MODEL0="meta-llama/Llama-3.2-1B-Instruct"
export MODEL1="meta-llama/Llama-3.1-8B-Instruct"
export NUM_SHOTS=0 # set 8 for 8 shot setup
export MAX_NEW_TOKENS=1024

# Running eval on 2 gpus
CUDA_VISIBLE_DEVICES=0,1 python3 eval/gpu_parallel.py --gpus_per_script 1 --start $START_IDX --end $END_IDX --use_queue --script eval/run_inference_topk_baseline_task.py --extra_args "--save_folder output/eval_baseline_$K --gsm8k_test_path data/gsm8k_test.json --torch_dtype auto --window_size 64 --head_path output/trained_head.pkl --setup DD-DT --max_new_tokens $MAX_NEW_TOKENS --num_shots $NUM_SHOTS --K $K --draft_model $MODEL0 --target_model $MODEL1"
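For reference, the top-K baseline accepts a mismatching draft token when it lies within the target model's K most probable next tokens; a minimal sketch (hypothetical helper, not the repo's code):

```python
import numpy as np

def topk_accepts(draft_token, target_logits, k):
    """Accept the draft token iff it is among the K highest target logits."""
    topk_ids = np.argsort(target_logits)[-k:]
    return int(draft_token) in topk_ids
```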

To build a final report from the evaluation outputs, you can use the following snippet:

import pandas as pd
import numpy as np
import os

def make_pareto_curve_df(data, group_by_col='thr'):
    """Aggregate per-task eval records into one Pareto-curve row per threshold/K."""
    df = pd.DataFrame(data)

    # mean number of accepted draft tokens across all generations in the group
    mean_accept = pd.DataFrame(
        df.groupby(group_by_col).apply(lambda x: np.concatenate(x['raw_accepts'].tolist()).mean()),
        columns=['mean_accept']
    ).reset_index()

    # task accuracy (fraction of correct answers) in the group
    gsm_acc = pd.DataFrame(
        df.groupby(group_by_col).apply(lambda x: np.mean(x['tp'])),
        columns=['gsm8k_acc']
    ).reset_index()

    pareto_curve_df = pd.merge(left=mean_accept, right=gsm_acc, on=group_by_col).sort_values(by=[group_by_col])

    return pareto_curve_df

AJ_DIRS = ['output/eval_0', 'output/eval_1'] # output/eval_2, ... output/eval_25
aj_data = []
for DIR in AJ_DIRS:
    files = os.listdir(DIR)
    aj_data.extend([pd.read_pickle(os.path.join(DIR, f)) for f in files])

autojudge_df = make_pareto_curve_df(aj_data)
print(autojudge_df)

BASELINE_DIRS = ['output/eval_baseline_0', 'output/eval_baseline_1'] # output/eval_baseline_2, ... output/eval_baseline_17
baseline_data = []
for DIR in BASELINE_DIRS:
    files = os.listdir(DIR)
    baseline_data.extend([pd.read_pickle(os.path.join(DIR, f)) for f in files])

baseline_df = make_pareto_curve_df(baseline_data, group_by_col='k')
print(baseline_df)

VLLM, Section 4.3

Clone the vLLM repository and check out commit a83a0f92b56b71855dc38e8e3d9809619e58bcd1. Copy our patch file into the vLLM repo and apply it with git apply vllm_patch.patch. Then install vLLM with VLLM_USE_PRECOMPILED=1 pip install -e path/to/vllm/folder.
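The setup above can be sketched as one shell sequence (the clone URL and the patch path are assumptions based on the standard vLLM repo layout; adjust paths to your checkout):

```shell
git clone https://github.com/vllm-project/vllm.git
cd vllm
git checkout a83a0f92b56b71855dc38e8e3d9809619e58bcd1
# copy the patch shipped in this repo, then apply it
git apply /path/to/Autojudge/vllm_patch.patch
VLLM_USE_PRECOMPILED=1 pip install -e .
```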

Evaluating GSM8K

Run the following commands

python vllm_gsm8k.py -o $RESULTS_FILE --target_model meta-llama/Llama-3.1-8B-Instruct\
    --draft_model meta-llama/Llama-3.2-1B-Instruct --judge_path vllm_heads/head_${SHOTS}shot_8b.pkl --judge_threshold $THRESHOLD --shots $SHOTS
python vllm_gsm8k.py -o $RESULTS_FILE --target_model meta-llama/Llama-3.1-70B-Instruct\
    --draft_model meta-llama/Llama-3.1-8B-Instruct --judge_path vllm_heads/head_${SHOTS}shot_70b.pkl --judge_threshold $THRESHOLD --shots $SHOTS

to run evaluations on the GSM8K dataset. SHOTS can be either 0 or 8.

For example, to reproduce results for the 0-shot 70B/8B setup, run:

for threshold in 0.03719609313336198 0.07084856680433153 0.09208237305325259 0.13549077699786996 0.2209569