├── .gitignore ├── data ├── AIME24 │ └── test.parquet ├── AIME25 │ └── test.parquet ├── AMC23 │ └── test.parquet ├── HMMT25 │ └── test.parquet ├── BRUMO25 │ └── test.parquet ├── CMIMC25 │ └── test.parquet ├── MATH-500 │ └── test.parquet ├── Minerva │ └── test.parquet └── Olympiad-Bench │ └── test.parquet ├── assets └── fig1_aime24_curves_added.png ├── evals ├── gen_vllm.py ├── grade.py └── utils.py └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | evals/__pycache__/ 2 | justrl_eval_outputs/ -------------------------------------------------------------------------------- /data/AIME24/test.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thunlp/JustRL/HEAD/data/AIME24/test.parquet -------------------------------------------------------------------------------- /data/AIME25/test.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thunlp/JustRL/HEAD/data/AIME25/test.parquet -------------------------------------------------------------------------------- /data/AMC23/test.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thunlp/JustRL/HEAD/data/AMC23/test.parquet -------------------------------------------------------------------------------- /data/HMMT25/test.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thunlp/JustRL/HEAD/data/HMMT25/test.parquet -------------------------------------------------------------------------------- /data/BRUMO25/test.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thunlp/JustRL/HEAD/data/BRUMO25/test.parquet -------------------------------------------------------------------------------- /data/CMIMC25/test.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thunlp/JustRL/HEAD/data/CMIMC25/test.parquet -------------------------------------------------------------------------------- /data/MATH-500/test.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thunlp/JustRL/HEAD/data/MATH-500/test.parquet -------------------------------------------------------------------------------- /data/Minerva/test.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thunlp/JustRL/HEAD/data/Minerva/test.parquet -------------------------------------------------------------------------------- /data/Olympiad-Bench/test.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thunlp/JustRL/HEAD/data/Olympiad-Bench/test.parquet -------------------------------------------------------------------------------- /assets/fig1_aime24_curves_added.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thunlp/JustRL/HEAD/assets/fig1_aime24_curves_added.png -------------------------------------------------------------------------------- /evals/gen_vllm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | import concurrent.futures 5 | from pathlib 
import Path 6 | 7 | import pandas as pd 8 | from tqdm import tqdm 9 | from vllm import LLM, SamplingParams 10 | 11 | # --------------------------------------------------------------------------- # 12 | # Global constants / variables # 13 | # --------------------------------------------------------------------------- # 14 | DATA_DIR = "data" 15 | TASKS = [ 16 | {"name": "AIME24", "path": f"{DATA_DIR}/AIME24/test.parquet", "N": 32}, 17 | {"name": "AIME25", "path": f"{DATA_DIR}/AIME25/test.parquet", "N": 32}, 18 | {"name": "AMC23", "path": f"{DATA_DIR}/AMC23/test.parquet", "N": 32}, 19 | {"name": "MATH-500", "path": f"{DATA_DIR}/MATH-500/test.parquet", "N": 4}, 20 | {"name": "Minerva", "path": f"{DATA_DIR}/Minerva/test.parquet", "N": 4}, 21 | {"name": "Olympiad-Bench", "path": f"{DATA_DIR}/Olympiad-Bench/test.parquet", "N": 4}, 22 | {"name": "BRUMO25", "path": f"{DATA_DIR}/BRUMO25/test.parquet", "N": 32}, 23 | {"name": "CMIMC25", "path": f"{DATA_DIR}/CMIMC25/test.parquet", "N": 32}, 24 | {"name": "HMMT25", "path": f"{DATA_DIR}/HMMT25/test.parquet", "N": 32}, 25 | ] 26 | PROMPT_TEMPLATE = """{problem} Please reason step by step, and put your final answer within \\boxed{{}}.""" 27 | NAME = "hbx/JustRL-DeepSeek-1.5B" # "hbx/JustRL-Nemotron-1.5B" 28 | MAX_TOKENS = 31744 29 | TEMPERATURE = 0.7 30 | TOP_P = 0.9 31 | OUT_DIR = Path(f"justrl_eval_outputs/{NAME.split('/')[-1]}") 32 | OUT_DIR.mkdir(parents=True, exist_ok=True) 33 | 34 | # --------------------------------------------------------------------------- # 35 | # Helper functions # 36 | # --------------------------------------------------------------------------- # 37 | def load_samples(filepath: str): 38 | """Read parquet file and return a list of prompts (no duplication).""" 39 | df = pd.read_parquet(filepath) 40 | if "BRUMO25" in filepath or "CMIMC25" in filepath or "HMMT25" in filepath: 41 | samples = [ 42 | { 43 | "example_id": i, 44 | "prompt": df.at[i, "problem"].strip(), 45 | "answer": df.at[i, "answer"].strip(), 46 | } 47 | for i in range(len(df)) 48 | ] 49 | else: 50 | samples = [ 51 | { 52 | "example_id": i, 53 | "prompt": df.at[i, "prompt"][0]["content"].strip(), 54 | "answer": df.at[i, "reward_model"]["ground_truth"].strip(), 55 | } 56 | for i in range(len(df)) 57 | ] 58 | print(f"Total unique samples: {len(samples)}") 59 | return samples 60 | 61 | 62 | def split_seeds(seeds: list[int], num_workers: int): 63 | """Round-robin split of the seed list into num_workers chunks.""" 64 | chunks = [[] for _ in range(num_workers)] 65 | for idx, s in enumerate(seeds): 66 | chunks[idx % num_workers].append(s) 67 | return chunks 68 | 69 | 70 | # --------------------------------------------------------------------------- # 71 | # Worker process (one GPU) # 72 | # --------------------------------------------------------------------------- # 73 | def worker_process(args_tuple): 74 | """ 75 | Each worker runs on a single GPU: 76 | 77 | args_tuple = (samples, seed_list, gpu_id) 78 | """ 79 | samples, seed_list, gpu_id = args_tuple 80 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) 81 | print(f"[GPU {gpu_id}] seeds={seed_list} | loading model...", flush=True) 82 | 83 | llm = LLM(model=NAME, enforce_eager=True) 84 | results = [] 85 | 86 | for seed in seed_list: 87 | sampling = SamplingParams( 88 | temperature=TEMPERATURE, 89 | top_p=TOP_P, 90 | max_tokens=MAX_TOKENS, 91 | seed=seed, 92 | ) 93 | messages = [[{"role": "user", "content": s["prompt"]}] for s in samples] 94 | outputs = llm.chat(messages, sampling, use_tqdm=True) 95 | for sample, out in 
zip(samples, outputs): 96 | results.append( 97 | { 98 | "example_id": sample["example_id"], 99 | "prompt": sample["prompt"], 100 | "answer": sample["answer"], 101 | "seed": seed, 102 | "response": out.outputs[0].text, 103 | } 104 | ) 105 | return results 106 | 107 | 108 | # --------------------------------------------------------------------------- # 109 | # main # 110 | # --------------------------------------------------------------------------- # 111 | def main(): 112 | available_workers = [0,1,2,3,4,5,6,7] 113 | num_workers = len(available_workers) 114 | for task in TASKS: 115 | task_name = task["name"] 116 | task_path = task["path"] 117 | N = task["N"] 118 | 119 | print(f"Starting evaluation for task: {task_name} (N={N})") 120 | 121 | # Update output path for the current task 122 | out_path = OUT_DIR / f"{task_name.lower()}_t{TEMPERATURE}_p{TOP_P}_n{N}-MNT{MAX_TOKENS}.jsonl" 123 | 124 | # 1. Load original prompts 125 | samples = load_samples(task_path) 126 | 127 | # Append suffix prompt to each sample 128 | for sample in samples: 129 | sample["prompt"] = PROMPT_TEMPLATE.format(problem=sample["prompt"]) 130 | 131 | # demo print 132 | print("Example prompt after formatting:") 133 | print(samples[0]["prompt"]) 134 | 135 | # 2. Generate N distinct random seeds and split across GPUs 136 | random_seeds = random.sample(range(2**31 - 1), N) # unique & shuffled 137 | seed_chunks = split_seeds(random_seeds, num_workers) 138 | 139 | # 3. Launch workers 140 | all_results = [] 141 | args_list = [(samples, seed_chunks[i], gid) for (i, gid) in enumerate(available_workers)] 142 | with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as ex: 143 | futures = [ex.submit(worker_process, tup) for tup in args_list] 144 | for fut in tqdm(concurrent.futures.as_completed(futures), 145 | total=len(futures), desc=f"GPU workers ({task_name})"): 146 | all_results.extend(fut.result()) 147 | 148 | print(f"Total generations collected for {task_name}: {len(all_results)}") # len(samples) * N 149 | 150 | # 4. Save to disk 151 | with out_path.open("w", encoding="utf-8") as f: 152 | for item in all_results: 153 | f.write(json.dumps(item, ensure_ascii=False) + "\n") 154 | print(f"Saved results for {task_name} to {out_path}") 155 | 156 | 157 | if __name__ == "__main__": 158 | main() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | # JustRL: Simplicity at Scale
3 |
4 | 🚀 Competitive RL Performance Without Complex Techniques 🌟
5 |
6 |
7 |
8 |
9 |
10 |
11 | Code
12 |
13 |
14 | Hugging Face
15 |
16 |
17 | Notion
18 |
19 |
20 | Paper
21 |
22 |
23 |
24 | ## 📰 Overview
25 |
26 | **JustRL** demonstrates that competitive reinforcement learning performance for small language models doesn't require complex multi-stage pipelines or dynamic schedules. Using a minimal recipe with single-stage training and fixed hyperparameters, we achieve state-of-the-art results on mathematical reasoning tasks. This repository contains a lightweight evaluation script to reproduce evaluation results for **JustRL** models on nine challenging math benchmarks.
27 |
28 | We release two models:
29 |
30 | - [**JustRL-DeepSeek-1.5B**](https://huggingface.co/hbx/JustRL-DeepSeek-1.5B): Trained from DeepSeek-R1-Distill-Qwen-1.5B
31 | - [**JustRL-Nemotron-1.5B**](https://huggingface.co/hbx/JustRL-Nemotron-1.5B): Trained from OpenMath-Nemotron-1.5B
32 |
33 | Both models use identical hyperparameters without per-model tuning, demonstrating the robustness of our approach.
34 |
35 | ![The AIME24 performance curves for scaling from a weak base (DeepSeek-R1-Distill-Qwen-1.5B) and a strong base (OpenMath-Nemotron-1.5B) over thousands of steps.](./assets/fig1_aime24_curves_added.png)
36 |
37 | ## 🎯 Key Highlights
38 |
39 | ✨ **Simplicity**: Single-stage training with fixed hyperparameters, without multi-stage pipelines or dynamic schedules
40 |
41 | 📈 **Stability**: Smooth, monotonic improvement over 4,000+ training steps without collapses or oscillations
42 |
43 | 🎯 **Performance**: State-of-the-art results at 1.5B scale, matching or exceeding more complex approaches
44 |
45 | 💰 **Efficiency**: Comparable or better performance with 2× less compute than multi-stage methods
46 |
47 | 🔓 **Open**: Complete evaluation scripts and model weights released
48 |
49 | ## 📁 Repository Structure
50 |
51 | ```
52 | JustRL/
53 | ├── evals/                    # Evaluation scripts
54 | │   ├── gen_vllm.py           # Generation script using vLLM
55 | │   ├── grade.py              # Grading script with hybrid verification
56 | │   └── utils.py              # Answer verification utilities
57 | ├── data/                     # Benchmark datasets
58 | │   ├── AIME24/
59 | │   ├── AIME25/
60 | │   ├── AMC23/
61 | │   ├── MATH-500/
62 | │   ├── Minerva/
63 | │   ├── Olympiad-Bench/
64 | │   ├── BRUMO25/
65 | │   ├── CMIMC25/
66 | │   └── HMMT25/
67 | └── justrl_eval_outputs/      # Evaluation outputs (download from Google Drive)
68 |     ├── JustRL-DeepSeek-1.5B/
69 |     │   ├── *.jsonl           # Generation outputs per benchmark
70 |     │   └── grading_results.json
71 |     └── JustRL-Nemotron-1.5B/
72 |         ├── *.jsonl
73 |         └── grading_results.json
74 | ```
75 |
76 | ## 🔧 Setup
77 |
78 | ### Environment Requirements
79 |
80 | We recommend using a conda environment with the following key dependencies:
81 |
82 | ```bash
83 | conda create -n justrl python=3.10
84 | conda activate justrl
85 | ```
86 |
87 | ### Key Dependencies
88 |
89 | - **PyTorch**: `2.6.0`
90 | - **vLLM**: `0.8.4`
91 | - **transformers**: `4.51.3`
92 | - **sympy**: `1.13.1`
93 | - **pylatexenc**: `2.10`
94 |
95 | ### Download Evaluation Outputs
96 |
97 | The evaluation outputs are large and hosted on Google Drive. Download them to reproduce the reported results:
98 |
99 | **📥 Download Link**: [Google Drive](https://drive.google.com/file/d/1G5oHTNYR8edbj6NLDgY8_6X3SB1MDngc/view?usp=sharing)
100 |
101 | After downloading, extract the `justrl_eval_outputs/` directory to the repository root.
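As an optional sanity check before generating anything, you can confirm that a benchmark file loads and exposes the columns `evals/gen_vllm.py` reads. The snippet below is illustrative only (it is not part of the released pipeline) and assumes it is run from the repository root with `pandas` installed:

```python
# Illustrative check: load one benchmark file and inspect the columns that
# evals/gen_vllm.py reads. Run from the repository root.
import pandas as pd

df = pd.read_parquet("data/AIME24/test.parquet")
print(len(df), "problems")
print(df.columns.tolist())
# For most benchmarks, gen_vllm.py reads df["prompt"][i][0]["content"] and
# df["reward_model"][i]["ground_truth"]; for BRUMO25, CMIMC25, and HMMT25 it
# reads the plain "problem" and "answer" columns instead.
```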
102 |
103 | ## 🚀 Usage
104 |
105 | This evaluation script is based on [POLARIS](https://github.com/ChenxinAn-fdu/POLARIS), with one key modification: we add a model-based verifier ([CompassVerifier-3B](https://huggingface.co/opencompass/CompassVerifier-3B)) for more robust evaluation, complementing the rule-based verification system. A simplified sketch of this hybrid check is shown after the performance tables below.
106 |
107 | ### Generation (Optional)
108 |
109 | ```bash
110 | cd evals
111 | python gen_vllm.py
112 | ```
113 |
114 | Configure the model name in `gen_vllm.py` by setting the `NAME` variable, and set `available_workers` to the GPU IDs available on your machine.
115 |
116 | ### Grading
117 |
118 | ```bash
119 | cd evals
120 | python grade.py
121 | ```
122 |
123 | The grading script processes all JSONL files in the output directory and generates `grading_results.json`.
124 |
125 | ## 📈 Performance
126 |
127 | ### JustRL-DeepSeek-1.5B (Based on DeepSeek-R1-Distill-Qwen-1.5B)
128 |
129 | | Model | AIME24 (@32) | AIME25 (@32) | AMC23 (@32) | MATH-500 (@4) | Minerva (@4) | OlympiadBench (@4) | HMMT25 (@32) | BRUMO25 (@32) | CMIMC25 (@32) | Avg |
130 | | ------------------------ | ------------ | ------------ | ----------- | ------------- | ------------ | ------------------ | ------------ | ------------- | ------------- | --------- |
131 | | DeepSeek-R1-Distill-1.5B | 29.90 | 22.40 | 63.82 | 84.90 | 34.65 | 45.95 | 13.44 | 30.94 | 12.89 | 37.65 |
132 | | DeepScaleR-1.5B-Preview | 40.21 | 28.65 | 73.83 | 89.30 | 39.34 | 52.79 | 18.96 | 40.00 | 21.00 | 44.88 |
133 | | ProRL-V2 | 51.87 | 35.73 | 88.75 | 92.00 | 49.03 | 67.84 | 19.38 | 47.29 | **25.86** | 53.08 |
134 | | BroRL | **57.50** | 36.88 | / | **92.14** | 49.08 | 61.54 | / | / | / | / |
135 | | JustRL-DeepSeek-1.5B | 52.60 | **38.75** | **91.02** | 91.65 | **51.47** | **67.99** | **21.98** | **52.71** | 25.63 | **54.87** |
136 |
137 | The real question is whether this simplicity comes at a computational cost. It doesn't. We match or exceed ProRL-V2 while using roughly half of its compute budget, with a single-stage recipe and fixed hyperparameters. BroRL requires 4.9× more compute, increasing rollouts to 512 per example and essentially exhaustively exploring the solution space. Our approach achieves competitive performance without this computational overhead.
138 |
139 | ### JustRL-Nemotron-1.5B (Based on OpenMath-Nemotron-1.5B)
140 |
141 | | Model | AIME24 (@32) | AIME25 (@32) | AMC23 (@32) | MATH-500 (@4) | Minerva (@4) | OlympiadBench (@4) | HMMT25 (@32) | BRUMO25 (@32) | CMIMC25 (@32) | Avg |
142 | | ---------------------- | ------------ | ------------ | ----------- | ------------- | ------------ | ------------------ | ------------ | ------------- | ------------- | --------- |
143 | | OpenMath-Nemotron-1.5B | 58.75 | 48.44 | 90.55 | 92.40 | 26.93 | 71.70 | 30.10 | 61.67 | 30.08 | 56.74 |
144 | | QUESTA-Nemotron-1.5B | **71.56** | 62.08 | 93.44 | 92.95 | **32.08** | 72.28 | **40.94** | **67.50** | 41.48 | 63.81 |
145 | | JustRL-Nemotron-1.5B | 69.69 | **62.92** | **96.02** | **94.15** | 30.24 | **76.59** | 40.63 | 66.88 | **41.72** | **64.32** |
146 |
147 | We achieve a 64.32% average, slightly outperforming QuestA's 63.81% and leading on five of nine benchmarks. The gap is narrow, which makes sense: both approaches are pushing the boundaries of what is achievable at 1.5B scale. The key difference is how we get there: we use 2× less compute while achieving slightly better average performance, without designing a complex curriculum as in QuestA.
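For readers who want the gist of the hybrid verification without reading the full script, the logic in `evals/grade.py` reduces to a rule-based check with a model-based fallback. The sketch below is a simplified, per-response version and is not the exact code: the real script collects all fallback cases and sends them to the verifier in one batched `generate` call, and `CV_PROMPT`, `model_tokenizer`, `vllm_model`, and `sampling_params` refer to the objects defined at the top of `grade.py`.

```python
# Simplified sketch of the hybrid grading in evals/grade.py (illustrative only):
# 1) try the rule-based verifier; 2) only if it fails, ask CompassVerifier-3B.
from utils import grade_answer_verl  # rule-based check from evals/utils.py

def hybrid_grade(question, response, gold):
    if grade_answer_verl(response, gold):          # boxed-answer extraction + string/sympy match
        return True
    # Fallback: format the CompassVerifier judging prompt and query the verifier model.
    prompt = CV_PROMPT.format(question=question, gold_answer=gold, llm_response=response)
    chat = model_tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}],
        add_generation_prompt=True,
        tokenize=False,
    )
    out = vllm_model.generate([chat], sampling_params)[0]
    return out.outputs[0].text.strip() == "A"      # "A" means CORRECT in the judging rubric
```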
148 |
149 | ## 📖 Training Recipe
150 |
151 | Our approach is deliberately minimal:
152 |
153 | **Core Algorithm**: Standard GRPO with binary outcome rewards
154 |
155 | - **Reward**: Simple DAPO verifier (string-matching, no SymPy)
156 | - **Training**: Single-stage, no curriculum or stage transitions
157 | - **Hyperparameters**: Fixed throughout (no adaptive schedules)
158 | - **Data**: DAPO-Math-17k without filtering or dynamic sampling
159 | - **Length Control**: 16K context cap (no explicit penalties)
160 | - **Stabilization**: Only "clip higher" for gradient stability
161 |
162 | Detailed hyperparameters, along with comparisons of training techniques against other methods, can be found in our [blog](https://relieved-cafe-fe1.notion.site/JustRL-Scaling-a-1-5B-LLM-with-a-Simple-RL-Recipe-24f6198b0b6b80e48e74f519bfdaf0a8).
163 |
164 | **Training Data**: We train on [DAPO-Math-17k](https://huggingface.co/datasets/BytedTsinghua-SIA/DAPO-Math-17k), a curated dataset of mathematical problems. **No offline difficulty filtering or online dynamic sampling is used.**
165 |
166 | ## 🎈 Citation
167 |
168 | ```bibtex
169 | @misc{he2025justrlscaling15bllm,
170 |       title={JustRL: Scaling a 1.5B LLM with a Simple RL Recipe},
171 |       author={Bingxiang He and Zekai Qu and Zeyuan Liu and Yinghao Chen and Yuxin Zuo and Cheng Qian and Kaiyan Zhang and Weize Chen and Chaojun Xiao and Ganqu Cui and Ning Ding and Zhiyuan Liu},
172 |       year={2025},
173 |       eprint={2512.16649},
174 |       archivePrefix={arXiv},
175 |       primaryClass={cs.CL},
176 |       url={https://arxiv.org/abs/2512.16649},
177 | }
178 | ```
179 |
--------------------------------------------------------------------------------
/evals/grade.py:
--------------------------------------------------------------------------------
1 |
2 | from utils import grade_answer_verl
3 | from transformers import AutoTokenizer
4 | import json
5 | import pandas as pd
6 | from pathlib import Path
7 | import re
8 | from vllm import LLM, SamplingParams
9 |
10 | CV_PROMPT = """
11 | Please as a grading expert, judge whether the final answers given by the candidates below are consistent with the standard answers, that is, whether the candidates answered correctly.
12 | Here are some evaluation criteria:
13 | 1. Please refer to the given standard answer. You don't need to re-generate the answer to the question because the standard answer has been given. You only need to judge whether the candidate's answer is consistent with the standard answer according to the form of the question. THE STANDARD ANSWER IS ALWAYS CORRECT AND THE QUESTION IS PERFECTLY VALID. NEVER QUESTION THEM.
14 | 2. ONLY compare the FINAL ANSWER - COMPLETELY IGNORE any potential errors in the REASONING PROCESSES.
15 | 3. Some answers may be expressed in different ways, such as some answers may be a mathematical expression, some answers may be a textual description, as long as the meaning expressed is the same. Before making a judgment, please understand the question and the standard answer first, and then judge whether the candidate's answer is correct.
16 | 4. Some answers may consist of multiple items, such as multiple-choice questions, multiple-select questions, fill-in-the-blank questions, etc. Regardless of the question type, the final answer will be considered correct as long as it matches the standard answer, regardless of whether the reasoning process is correct.
For multiple-select questions and multi-blank fill-in-the-blank questions, all corresponding options or blanks must be answered correctly and match the standard answer exactly to be deemed correct. 17 | 5. If the prediction is given with \\boxed{{}}, please ignore the \\boxed{{}} and only judge whether the candidate's answer is consistent with the standard answer. 18 | 6. If the candidate's answer is invalid (e.g., incomplete (cut off mid-response), lots of unnormal repetitive content, or irrelevant to the question, saying it can't answer the question because some irresistible factors, like ethical issues, no enough information, etc.), select option C (INVALID).Please judge whether the following answers are consistent with the standard answer based on the above criteria. Grade the predicted answer of this new question as one of: 19 | A: CORRECT 20 | B: INCORRECT 21 | C: INVALID 22 | Just return the letters "A", "B", or "C", with no text around it. 23 | Here is your task. Simply reply with either CORRECT, INCORRECT, or INVALID. Don't apologize or correct yourself if there was a mistake; we are just trying to grade the answer. 24 | : 25 | {question} 26 | 27 | : 28 | {gold_answer} 29 | 30 | : 31 | {llm_response} 32 | 33 | Judging the correctness of the candidate's answer: 34 | """ 35 | 36 | NAME = "JustRL-Nemotron-1.5B" # "JustRL-Nemotron-1.5B" 37 | EVAL_DIR = Path(f"justrl_eval_outputs/{NAME}") 38 | OUTPUT_FILE = EVAL_DIR / "grading_results.json" 39 | 40 | model_name = "opencompass/CompassVerifier-3B" 41 | model_tokenizer = AutoTokenizer.from_pretrained(model_name) 42 | vllm_model = LLM( 43 | model=model_name, 44 | tensor_parallel_size=1 45 | ) 46 | sampling_params = SamplingParams( 47 | temperature=0.0, 48 | max_tokens=2048 49 | ) 50 | 51 | length_tokenizer = None 52 | 53 | def get_len(seq): 54 | return len(length_tokenizer.encode(seq)) 55 | 56 | def get_diverse_score(sequences, n=4): 57 | """ 58 | calculate the Distinct-n score。 59 | 60 | sequences: List[str] response list 61 | n: int, n-gram default=4 62 | """ 63 | distinct_ngrams = set() 64 | total_ngrams = 0 65 | 66 | for seq in sequences: 67 | # more accurate n-gram 68 | # tokens = nltk.word_tokenize(seq) 69 | tokens = seq.split() 70 | for i in range(len(tokens) - n + 1): 71 | ngram = tuple(tokens[i:i + n]) 72 | distinct_ngrams.add(ngram) 73 | total_ngrams += 1 74 | 75 | return len(distinct_ngrams) / total_ngrams if total_ngrams > 0 else 0 76 | 77 | def process_jsonl_file(file_name): 78 | """ 79 | Process a JSONL file and dynamically handle the number of problems. 80 | """ 81 | results = [] 82 | with open(file_name) as f: 83 | for line in f: 84 | data = json.loads(line) 85 | id = int(data["example_id"]) 86 | while len(results) <= id: # Ensure the list is large enough 87 | results.append({"gt": None, "responses": []}) 88 | gt = data["answer"] 89 | response = data["response"] 90 | results[id]["gt"] = gt 91 | results[id]["responses"].append(response) 92 | return results 93 | 94 | def parse_hyperparameters_from_filename(filename): 95 | """ 96 | Parse hyperparameters from the filename. 97 | Example filename format: {taskname}_t{temperature}_p{top_p}_n{n}-MNT{max_tokens}.jsonl 98 | """ 99 | match = re.search(r"_t(?P[\d.]+)_p(?P[\d.]+)_n(?P\d+)-MNT(?P\d+)", 100 | filename) 101 | return match.groupdict() if match else {} 102 | 103 | def grade_file(file_path): 104 | """ 105 | Grade a single file and return the results. 
106 | """ 107 | hyperparams = parse_hyperparameters_from_filename(file_path.name) 108 | if not hyperparams: 109 | print(f"Skipping file with unrecognized format: {file_path}") 110 | return None 111 | 112 | task_name = file_path.stem.split("_")[0] 113 | hyperparams["task_name"] = task_name 114 | 115 | if "parquet" in str(file_path): 116 | df = pd.read_parquet(file_path) 117 | num_pred = len(df["responses"][0]) 118 | else: 119 | df = process_jsonl_file(file_path) 120 | num_pred = len(df[0]["responses"]) 121 | 122 | results = { 123 | "hyperparameters": hyperparams, 124 | "mean_score": 0, 125 | "distinct_4gram": 0, 126 | "best_score": 0, 127 | "solve_none": 0, 128 | "solve_all": 0, 129 | "avg_output_length": 0, 130 | "format_error_rollouts": 0, 131 | } 132 | 133 | diverse = [] 134 | avg_scores = [] 135 | best = [] 136 | solve_none = 0 137 | solve_all = 0 138 | without_boxed = 0 139 | response_lengths = [] 140 | incorrect_data = [] # List to store incorrect responses and ground truths 141 | 142 | all_model_inputs = [] # Collect all prompts for batch processing 143 | all_responses = [] # Keep track of responses for mapping back 144 | all_questions = [] # Keep track of questions for mapping back 145 | all_ground_truths = [] # Keep track of ground truths for mapping back 146 | rule_based_scores = [] # Store rule-based scores for fallback logic 147 | 148 | for i in range(len(df)): 149 | if "jsonl" in str(file_path): 150 | responses = df[i]["responses"] 151 | gt = df[i]["gt"] 152 | question = df[i].get("question", "") # Assuming question is part of the data 153 | else: 154 | responses = df["responses"][i] 155 | gt = df["reward_model"][i]["ground_truth"] 156 | question = df["reward_model"][i].get("question", "") 157 | 158 | responses_list = [str(response) for response in responses] 159 | if length_tokenizer: 160 | response_lengths += [get_len(response) for response in responses_list] 161 | else: 162 | response_lengths = [0] 163 | not_formated = ["boxed" not in response for response in responses_list] 164 | without_boxed += sum(not_formated) 165 | 166 | # First, use the rule-based verifier 167 | for response in responses_list: 168 | rule_score = grade_answer_verl(response, gt) 169 | rule_based_scores.append(rule_score) 170 | if not rule_score: # If rule-based verifier fails, prepare for model-based verifier 171 | model_input = CV_PROMPT.format( 172 | question=question, 173 | gold_answer=gt, 174 | llm_response=response 175 | ) 176 | all_model_inputs.append(model_input) 177 | all_responses.append(response) 178 | all_questions.append(question) 179 | all_ground_truths.append(gt) 180 | 181 | diverse.append(get_diverse_score(responses_list)) 182 | 183 | # Batch process all model-based verifier inputs 184 | if all_model_inputs: 185 | model_inputs = [model_tokenizer.apply_chat_template( 186 | [{"role": "user", "content": input_text}], 187 | add_generation_prompt=True, 188 | tokenize=False 189 | ) for input_text in all_model_inputs] 190 | outputs = vllm_model.generate(model_inputs, sampling_params) 191 | 192 | # Map back the results to the corresponding responses 193 | model_based_scores = [] 194 | for idx, output in enumerate(outputs): 195 | judgement = output.outputs[0].text.strip() 196 | model_score = "A" == judgement # True if "A" (correct), False otherwise 197 | model_based_scores.append(model_score) 198 | 199 | # Save incorrect responses and ground truths 200 | if not model_score: 201 | incorrect_data.append({ 202 | "response": all_responses[idx][-300:], # Save last 300 characters 203 | "ground_truth": 
all_ground_truths[idx] 204 | }) 205 | 206 | # Combine rule-based and model-based scores 207 | model_idx = 0 208 | final_scores = [] 209 | for rule_score in rule_based_scores: 210 | if rule_score: # If rule-based verifier passed 211 | final_scores.append(rule_score) 212 | else: # Use model-based verifier score 213 | final_scores.append(model_based_scores[model_idx]) 214 | model_idx += 1 215 | else: 216 | final_scores = rule_based_scores 217 | 218 | # Calculate metrics 219 | avg_scores = [sum(final_scores[i:i + num_pred]) / num_pred for i in range(0, len(final_scores), num_pred)] 220 | best = [max(final_scores[i:i + num_pred]) for i in range(0, len(final_scores), num_pred)] 221 | 222 | solve_none = sum(1 for avg_score in avg_scores if avg_score == 0) 223 | solve_all = sum(1 for avg_score in avg_scores if avg_score == 1) 224 | 225 | results["mean_score"] = sum(avg_scores) / len(avg_scores) 226 | results["distinct_4gram"] = sum(diverse) / len(diverse) 227 | results["best_score"] = sum(best) / len(best) 228 | results["solve_none"] = solve_none 229 | results["solve_all"] = solve_all 230 | results["avg_output_length"] = sum(response_lengths) / len(response_lengths) 231 | results["format_error_rollouts"] = without_boxed 232 | 233 | # Save incorrect responses and ground truths to a separate file 234 | # incorrect_file = EVAL_DIR / f"{file_path.stem}_incorrect_data.json" 235 | # with incorrect_file.open("w", encoding="utf-8") as f: 236 | # json.dump(incorrect_data, f, indent=4) 237 | 238 | return results 239 | 240 | def main(): 241 | all_results = [] 242 | for file_path in EVAL_DIR.glob("*.jsonl"): 243 | print(f"Processing file: {file_path}") 244 | file_result = grade_file(file_path) 245 | if file_result: 246 | all_results.append(file_result) 247 | 248 | # Save results to JSON 249 | with OUTPUT_FILE.open("w", encoding="utf-8") as f: 250 | json.dump(all_results, f, indent=4) 251 | print(f"Grading results saved to {OUTPUT_FILE}") 252 | 253 | if __name__ == "__main__": 254 | main() 255 | 256 | -------------------------------------------------------------------------------- /evals/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Answer checker API that uses sympy to simplify expressions and check for equality. 3 | 4 | Call grade_answer(given_answer: str, ground_truth: str). 5 | """ 6 | import re 7 | from pylatexenc import latex2text 8 | import sympy 9 | from sympy.parsing import sympy_parser 10 | from typing import Optional 11 | 12 | 13 | # Dan Hendrycks' code 14 | def mathd_normalize_answer(answer: Optional[str]) -> Optional[str]: 15 | if answer is None: 16 | return None 17 | answer = answer.strip() 18 | try: 19 | # Remove enclosing `\text{}`. 
20 | m = re.search("^\\\\text\{(?P.+?)\}$", answer) 21 | if m is not None: 22 | answer = m.group("text").strip() 23 | return _strip_string(answer) 24 | except: 25 | return answer 26 | 27 | def _strip_string(string): 28 | def _fix_fracs(string): 29 | substrs = string.split("\\frac") 30 | new_str = substrs[0] 31 | if len(substrs) > 1: 32 | substrs = substrs[1:] 33 | for substr in substrs: 34 | new_str += "\\frac" 35 | if substr[0] == "{": 36 | new_str += substr 37 | else: 38 | try: 39 | assert len(substr) >= 2 40 | except: 41 | return string 42 | a = substr[0] 43 | b = substr[1] 44 | if b != "{": 45 | if len(substr) > 2: 46 | post_substr = substr[2:] 47 | new_str += "{" + a + "}{" + b + "}" + post_substr 48 | else: 49 | new_str += "{" + a + "}{" + b + "}" 50 | else: 51 | if len(substr) > 2: 52 | post_substr = substr[2:] 53 | new_str += "{" + a + "}" + b + post_substr 54 | else: 55 | new_str += "{" + a + "}" + b 56 | string = new_str 57 | return string 58 | 59 | 60 | def _fix_a_slash_b(string): 61 | if len(string.split("/")) != 2: 62 | return string 63 | a = string.split("/")[0] 64 | b = string.split("/")[1] 65 | try: 66 | a = int(a) 67 | b = int(b) 68 | assert string == "{}/{}".format(a, b) 69 | new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" 70 | return new_string 71 | except: 72 | return string 73 | 74 | 75 | def _remove_right_units(string): 76 | # "\\text{ " only ever occurs (at least in the val set) when describing units 77 | if "\\text{ " in string: 78 | splits = string.split("\\text{ ") 79 | assert len(splits) == 2 80 | return splits[0] 81 | else: 82 | return string 83 | 84 | 85 | def _fix_sqrt(string): 86 | if "\\sqrt" not in string: 87 | return string 88 | splits = string.split("\\sqrt") 89 | new_string = splits[0] 90 | for split in splits[1:]: 91 | if split[0] != "{": 92 | a = split[0] 93 | new_substr = "\\sqrt{" + a + "}" + split[1:] 94 | else: 95 | new_substr = "\\sqrt" + split 96 | new_string += new_substr 97 | return new_string 98 | # linebreaks 99 | string = string.replace("\n", "") 100 | # print(string) 101 | 102 | # remove inverse spaces 103 | string = string.replace("\\!", "") 104 | # print(string) 105 | 106 | # replace \\ with \ 107 | string = string.replace("\\\\", "\\") 108 | # print(string) 109 | 110 | # replace tfrac and dfrac with frac 111 | string = string.replace("tfrac", "frac") 112 | string = string.replace("dfrac", "frac") 113 | # print(string) 114 | 115 | # remove \left and \right 116 | string = string.replace("\\left", "") 117 | string = string.replace("\\right", "") 118 | # print(string) 119 | 120 | # Remove circ (degrees) 121 | string = string.replace("^{\\circ}", "") 122 | string = string.replace("^\\circ", "") 123 | 124 | # remove dollar signs 125 | string = string.replace("\\$", "") 126 | 127 | # remove units (on the right) 128 | string = _remove_right_units(string) 129 | 130 | # remove percentage 131 | string = string.replace("\\%", "") 132 | string = string.replace("\%", "") 133 | 134 | # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string 135 | string = string.replace(" .", " 0.") 136 | string = string.replace("{.", "{0.") 137 | # if empty, return empty string 138 | if len(string) == 0: 139 | return string 140 | if string[0] == ".": 141 | string = "0" + string 142 | 143 | # to consider: get rid of e.g. 
"k = " or "q = " at beginning 144 | if len(string.split("=")) == 2: 145 | if len(string.split("=")[0]) <= 2: 146 | string = string.split("=")[1] 147 | 148 | # fix sqrt3 --> sqrt{3} 149 | string = _fix_sqrt(string) 150 | 151 | # remove spaces 152 | string = string.replace(" ", "") 153 | 154 | # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b} 155 | string = _fix_fracs(string) 156 | 157 | # manually change 0.5 --> \frac{1}{2} 158 | if string == "0.5": 159 | string = "\\frac{1}{2}" 160 | 161 | # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y 162 | string = _fix_a_slash_b(string) 163 | 164 | return string 165 | 166 | 167 | # sympy might hang -- we don't care about trying to be lenient in these cases 168 | BAD_SUBSTRINGS = ["^{", "^("] 169 | BAD_REGEXES = ["\^[0-9]+\^", "\^[0-9][0-9]+"] 170 | TUPLE_CHARS = "()[]" 171 | 172 | 173 | def _sympy_parse(expr: str): 174 | """Parses an expression with sympy.""" 175 | py_expr = expr.replace("^", "**") 176 | return sympy_parser.parse_expr( 177 | py_expr, 178 | transformations=( 179 | sympy_parser.standard_transformations 180 | + (sympy_parser.implicit_multiplication_application,) 181 | ), 182 | ) 183 | 184 | 185 | def _parse_latex(expr: str) -> str: 186 | """Attempts to parse latex to an expression sympy can read.""" 187 | expr = expr.replace("\\tfrac", "\\frac") 188 | expr = expr.replace("\\dfrac", "\\frac") 189 | expr = expr.replace("\\frac", " \\frac") # Play nice with mixed numbers. 190 | expr = latex2text.LatexNodes2Text().latex_to_text(expr) 191 | 192 | # Replace the specific characters that this parser uses. 193 | expr = expr.replace("√", "sqrt") 194 | expr = expr.replace("π", "pi") 195 | expr = expr.replace("∞", "inf") 196 | expr = expr.replace("∪", "U") 197 | expr = expr.replace("·", "*") 198 | expr = expr.replace("×", "*") 199 | 200 | return expr.strip() 201 | 202 | 203 | def _is_float(num: str) -> bool: 204 | try: 205 | float(num) 206 | return True 207 | except ValueError: 208 | return False 209 | 210 | 211 | def _is_int(x: float) -> bool: 212 | try: 213 | return abs(x - int(round(x))) <= 1e-7 214 | except: 215 | return False 216 | 217 | 218 | def _is_frac(expr: str) -> bool: 219 | return bool(re.search(r"^-?[0-9]+.?/0*[1-9][0-9]*.?$", expr)) 220 | 221 | 222 | def _str_is_int(x: str) -> bool: 223 | try: 224 | x = _strip_properly_formatted_commas(x) 225 | x = float(x) 226 | return abs(x - int(round(x))) <= 1e-7 227 | except: 228 | return False 229 | 230 | 231 | def _str_to_int(x: str) -> bool: 232 | x = x.replace(",", "") 233 | x = float(x) 234 | return int(x) 235 | 236 | 237 | def _inject_implicit_mixed_number(step: str): 238 | """ 239 | Automatically make a mixed number evalable 240 | e.g. 7 3/4 => 7+3/4 241 | """ 242 | p1 = re.compile("([0-9]) +([0-9])") 243 | step = p1.sub("\\1+\\2", step) ## implicit mults 244 | return step 245 | 246 | 247 | def _strip_properly_formatted_commas(expr: str): 248 | # We want to be careful because we don't want to strip tuple commas 249 | p1 = re.compile("(\d)(,)(\d\d\d)($|\D)") 250 | while True: 251 | next_expr = p1.sub("\\1\\3\\4", expr) 252 | if next_expr == expr: 253 | break 254 | expr = next_expr 255 | return next_expr 256 | 257 | 258 | def _normalize(expr: str) -> str: 259 | """Normalize answer expressions.""" 260 | if expr is None: 261 | return None 262 | 263 | # Remove enclosing `\text{}`. 
264 | m = re.search("^\\\\text\{(?P.+?)\}$", expr) 265 | if m is not None: 266 | expr = m.group("text") 267 | 268 | expr = expr.replace("\\%", "%") 269 | expr = expr.replace("\\$", "$") 270 | expr = expr.replace("$", "") 271 | expr = expr.replace("%", "") 272 | expr = expr.replace(" or ", " , ") 273 | expr = expr.replace(" and ", " , ") 274 | 275 | expr = expr.replace("million", "*10^6") 276 | expr = expr.replace("billion", "*10^9") 277 | expr = expr.replace("trillion", "*10^12") 278 | 279 | for unit in [ 280 | "degree", 281 | "cm", 282 | "centimeter", 283 | "meter", 284 | "mile", 285 | "second", 286 | "minute", 287 | "hour", 288 | "day", 289 | "week", 290 | "month", 291 | "year", 292 | "foot", 293 | "feet", 294 | "inch", 295 | "yard", 296 | ]: 297 | expr = re.sub(f"{unit}(es)?(s)? *(\^[0-9]+)?", "", expr) 298 | expr = re.sub(f"\^ *\\\\circ", "", expr) 299 | 300 | if len(expr) > 0 and expr[0] == "{" and expr[-1] == "}": 301 | expr = expr[1:-1] 302 | 303 | expr = re.sub(",\\\\! *", "", expr) 304 | if _is_float(expr) and _is_int(float(expr)): 305 | expr = str(int(round(float(expr)))) 306 | if "\\" in expr: 307 | try: 308 | expr = _parse_latex(expr) 309 | except: 310 | pass 311 | 312 | # edge case with mixed numbers and negative signs 313 | expr = re.sub("- *", "-", expr) 314 | 315 | expr = _inject_implicit_mixed_number(expr) 316 | expr = expr.replace(" ", "") 317 | 318 | # if we somehow still have latex braces here, just drop them 319 | expr = expr.replace("{", "") 320 | expr = expr.replace("}", "") 321 | 322 | # don't be case sensitive for text answers 323 | expr = expr.lower() 324 | 325 | if _str_is_int(expr): 326 | expr = str(_str_to_int(expr)) 327 | 328 | return expr 329 | 330 | 331 | def count_unknown_letters_in_expr(expr: str): 332 | expr = expr.replace("sqrt", "") 333 | expr = expr.replace("frac", "") 334 | letters_in_expr = set([x for x in expr if x.isalpha()]) 335 | return len(letters_in_expr) 336 | 337 | 338 | def should_allow_eval(expr: str): 339 | # we don't want to try parsing unknown text or functions of more than two variables 340 | if count_unknown_letters_in_expr(expr) > 2: 341 | return False 342 | 343 | for bad_string in BAD_SUBSTRINGS: 344 | if bad_string in expr: 345 | return False 346 | 347 | for bad_regex in BAD_REGEXES: 348 | if re.search(bad_regex, expr) is not None: 349 | return False 350 | 351 | return True 352 | 353 | 354 | def are_equal_under_sympy(ground_truth_normalized: str, given_normalized: str): 355 | are_equal = False 356 | try: 357 | expr = f"({ground_truth_normalized})-({given_normalized})" 358 | if should_allow_eval(expr): 359 | sympy_diff = _sympy_parse(expr) 360 | simplified = sympy.simplify(sympy_diff) 361 | if simplified == 0: 362 | are_equal = True 363 | except: 364 | pass 365 | return are_equal 366 | 367 | 368 | def split_tuple(expr: str): 369 | """ 370 | Split the elements in a tuple/interval, while handling well-formatted commas in large numbers 371 | """ 372 | expr = _strip_properly_formatted_commas(expr) 373 | if len(expr) == 0: 374 | return [] 375 | if ( 376 | len(expr) > 2 377 | and expr[0] in TUPLE_CHARS 378 | and expr[-1] in TUPLE_CHARS 379 | and all([ch not in expr[1:-1] for ch in TUPLE_CHARS]) 380 | ): 381 | elems = [elem.strip() for elem in expr[1:-1].split(",")] 382 | else: 383 | elems = [expr] 384 | return elems 385 | 386 | 387 | def last_boxed_only_string(string): 388 | idx = string.rfind("\\boxed") 389 | if idx < 0: 390 | idx = string.rfind("\\fbox") 391 | if idx < 0: 392 | return None 393 | 394 | i = idx 395 | right_brace_idx = None 
396 | num_left_braces_open = 0 397 | while i < len(string): 398 | if string[i] == "{": 399 | num_left_braces_open += 1 400 | if string[i] == "}": 401 | num_left_braces_open -= 1 402 | if num_left_braces_open == 0: 403 | right_brace_idx = i 404 | break 405 | i += 1 406 | 407 | if right_brace_idx == None: 408 | retval = None 409 | else: 410 | retval = string[idx:right_brace_idx + 1] 411 | 412 | return retval 413 | 414 | def remove_boxed(s): 415 | left = "\\boxed{" 416 | try: 417 | assert s[:len(left)] == left 418 | assert s[-1] == "}" 419 | return s[len(left):-1] 420 | except: 421 | return None 422 | 423 | 424 | def extract_boxed_answer(solution: str) -> str: 425 | """Extract the answer from inside a LaTeX \\boxed{} command""" 426 | solution = last_boxed_only_string(solution) 427 | solution = remove_boxed(solution) 428 | return solution 429 | 430 | def grade_answer_sympy(given_answer: str, ground_truth: str) -> bool: 431 | ground_truth_normalized = _normalize(ground_truth) 432 | given_normalized = _normalize(given_answer) 433 | 434 | if ground_truth_normalized is None: 435 | return False 436 | 437 | if ground_truth_normalized == given_normalized: 438 | return True 439 | 440 | if len(given_normalized) == 0: 441 | return False 442 | 443 | ground_truth_elems = split_tuple(ground_truth_normalized) 444 | given_elems = split_tuple(given_normalized) 445 | 446 | if len(ground_truth_elems) > 1 and ( 447 | ground_truth_normalized[0] != given_normalized[0] 448 | or ground_truth_normalized[-1] != given_normalized[-1] 449 | ): 450 | is_correct = False 451 | elif len(ground_truth_elems) != len(given_elems): 452 | is_correct = False 453 | else: 454 | for ground_truth_elem, given_elem in zip(ground_truth_elems, given_elems): 455 | if _is_frac(ground_truth_elem) and _is_frac(given_elem): 456 | # if fractions aren't reduced, then shouldn't be marked as correct 457 | # so, we don't want to allow sympy.simplify in this case 458 | is_correct = ground_truth_elem == given_elem 459 | elif _str_is_int(ground_truth_elem) != _str_is_int(given_elem): 460 | # if the ground truth answer is an integer, we require the given answer to be a strict match (no sympy.simplify) 461 | is_correct = False 462 | else: 463 | is_correct = are_equal_under_sympy(ground_truth_elem, given_elem) 464 | if not is_correct: 465 | break 466 | 467 | return is_correct 468 | 469 | def grade_answer_mathd(given_answer: str, ground_truth: str) -> bool: 470 | ground_truth_normalized_mathd = mathd_normalize_answer(ground_truth) 471 | given_answer_normalized_mathd = mathd_normalize_answer(given_answer) 472 | 473 | # be at least as lenient as mathd 474 | if ground_truth_normalized_mathd == given_answer_normalized_mathd: 475 | return True 476 | return False 477 | 478 | 479 | 480 | def extract_answer(passage: str) -> str: 481 | if "\\boxed" in passage: 482 | return extract_boxed_answer(passage) 483 | return None 484 | 485 | def grade_answer_verl(solution_str, ground_truth): 486 | if not ground_truth: 487 | return False 488 | if '\\boxed' in ground_truth: 489 | ground_truth = extract_answer(ground_truth) 490 | given_answer = extract_answer(solution_str) 491 | if given_answer is None: 492 | return False 493 | return grade_answer_mathd(given_answer, ground_truth) \ 494 | or grade_answer_sympy(given_answer, ground_truth) 495 | 496 | def main(): 497 | # Example usage 498 | ground_truth = "\\frac{e^{2 i t}}{(2+2 i)}" 499 | given_answer = "\\boxed{\\dfrac{e^{2i t}}{2 + 2i}}" 500 | is_correct = grade_answer_verl(given_answer, ground_truth) 501 | print(f"Is the 
given answer correct? {is_correct}") 502 | 503 | if __name__ == "__main__": 504 | main() --------------------------------------------------------------------------------