├── .gitignore
├── data
│   ├── AIME24
│   │   └── test.parquet
│   ├── AIME25
│   │   └── test.parquet
│   ├── AMC23
│   │   └── test.parquet
│   ├── HMMT25
│   │   └── test.parquet
│   ├── BRUMO25
│   │   └── test.parquet
│   ├── CMIMC25
│   │   └── test.parquet
│   ├── MATH-500
│   │   └── test.parquet
│   ├── Minerva
│   │   └── test.parquet
│   └── Olympiad-Bench
│       └── test.parquet
├── assets
│   └── fig1_aime24_curves_added.png
├── evals
│   ├── gen_vllm.py
│   ├── grade.py
│   └── utils.py
└── README.md
/.gitignore:
--------------------------------------------------------------------------------
1 | evals/__pycache__/
2 | justrl_eval_outputs/
--------------------------------------------------------------------------------
/data/AIME24/test.parquet:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/JustRL/HEAD/data/AIME24/test.parquet
--------------------------------------------------------------------------------
/data/AIME25/test.parquet:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/JustRL/HEAD/data/AIME25/test.parquet
--------------------------------------------------------------------------------
/data/AMC23/test.parquet:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/JustRL/HEAD/data/AMC23/test.parquet
--------------------------------------------------------------------------------
/data/HMMT25/test.parquet:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/JustRL/HEAD/data/HMMT25/test.parquet
--------------------------------------------------------------------------------
/data/BRUMO25/test.parquet:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/JustRL/HEAD/data/BRUMO25/test.parquet
--------------------------------------------------------------------------------
/data/CMIMC25/test.parquet:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/JustRL/HEAD/data/CMIMC25/test.parquet
--------------------------------------------------------------------------------
/data/MATH-500/test.parquet:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/JustRL/HEAD/data/MATH-500/test.parquet
--------------------------------------------------------------------------------
/data/Minerva/test.parquet:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/JustRL/HEAD/data/Minerva/test.parquet
--------------------------------------------------------------------------------
/data/Olympiad-Bench/test.parquet:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/JustRL/HEAD/data/Olympiad-Bench/test.parquet
--------------------------------------------------------------------------------
/assets/fig1_aime24_curves_added.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/JustRL/HEAD/assets/fig1_aime24_curves_added.png
--------------------------------------------------------------------------------
/evals/gen_vllm.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import random
4 | import concurrent.futures
5 | from pathlib import Path
6 |
7 | import pandas as pd
8 | from tqdm import tqdm
9 | from vllm import LLM, SamplingParams
10 |
11 | # --------------------------------------------------------------------------- #
12 | # Global constants / variables #
13 | # --------------------------------------------------------------------------- #
14 | DATA_DIR = "data"
15 | TASKS = [
16 | {"name": "AIME24", "path": f"{DATA_DIR}/AIME24/test.parquet", "N": 32},
17 | {"name": "AIME25", "path": f"{DATA_DIR}/AIME25/test.parquet", "N": 32},
18 | {"name": "AMC23", "path": f"{DATA_DIR}/AMC23/test.parquet", "N": 32},
19 | {"name": "MATH-500", "path": f"{DATA_DIR}/MATH-500/test.parquet", "N": 4},
20 | {"name": "Minerva", "path": f"{DATA_DIR}/Minerva/test.parquet", "N": 4},
21 | {"name": "Olympiad-Bench", "path": f"{DATA_DIR}/Olympiad-Bench/test.parquet", "N": 4},
22 | {"name": "BRUMO25", "path": f"{DATA_DIR}/BRUMO25/test.parquet", "N": 32},
23 | {"name": "CMIMC25", "path": f"{DATA_DIR}/CMIMC25/test.parquet", "N": 32},
24 | {"name": "HMMT25", "path": f"{DATA_DIR}/HMMT25/test.parquet", "N": 32},
25 | ]
26 | PROMPT_TEMPLATE = """{problem} Please reason step by step, and put your final answer within \\boxed{{}}."""
27 | NAME = "hbx/JustRL-DeepSeek-1.5B" # "hbx/JustRL-Nemotron-1.5B"
28 | MAX_TOKENS = 31744
29 | TEMPERATURE = 0.7
30 | TOP_P = 0.9
31 | OUT_DIR = Path(f"justrl_eval_outputs/{NAME.split('/')[-1]}")
32 | OUT_DIR.mkdir(parents=True, exist_ok=True)
33 |
34 | # --------------------------------------------------------------------------- #
35 | # Helper functions #
36 | # --------------------------------------------------------------------------- #
37 | def load_samples(filepath: str):
38 | """Read parquet file and return a list of prompts (no duplication)."""
39 | df = pd.read_parquet(filepath)
40 | if "BRUMO25" in filepath or "CMIMC25" in filepath or "HMMT25" in filepath:
41 | samples = [
42 | {
43 | "example_id": i,
44 | "prompt": df.at[i, "problem"].strip(),
45 | "answer": df.at[i, "answer"].strip(),
46 | }
47 | for i in range(len(df))
48 | ]
49 | else:
50 | samples = [
51 | {
52 | "example_id": i,
53 | "prompt": df.at[i, "prompt"][0]["content"].strip(),
54 | "answer": df.at[i, "reward_model"]["ground_truth"].strip(),
55 | }
56 | for i in range(len(df))
57 | ]
58 | print(f"Total unique samples: {len(samples)}")
59 | return samples
60 |
61 |
62 | def split_seeds(seeds: list[int], num_workers: int):
63 | """Round-robin split of the seed list into num_workers chunks."""
64 | chunks = [[] for _ in range(num_workers)]
65 | for idx, s in enumerate(seeds):
66 | chunks[idx % num_workers].append(s)
67 | return chunks
68 |
69 |
70 | # --------------------------------------------------------------------------- #
71 | # Worker process (one GPU) #
72 | # --------------------------------------------------------------------------- #
73 | def worker_process(args_tuple):
74 | """
75 | Each worker runs on a single GPU:
76 |
77 | args_tuple = (samples, seed_list, gpu_id)
78 | """
79 | samples, seed_list, gpu_id = args_tuple
80 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
81 | print(f"[GPU {gpu_id}] seeds={seed_list} | loading model...", flush=True)
82 |
83 | llm = LLM(model=NAME, enforce_eager=True)
84 | results = []
85 |
86 | for seed in seed_list:
87 | sampling = SamplingParams(
88 | temperature=TEMPERATURE,
89 | top_p=TOP_P,
90 | max_tokens=MAX_TOKENS,
91 | seed=seed,
92 | )
93 | messages = [[{"role": "user", "content": s["prompt"]}] for s in samples]
94 | outputs = llm.chat(messages, sampling, use_tqdm=True)
95 | for sample, out in zip(samples, outputs):
96 | results.append(
97 | {
98 | "example_id": sample["example_id"],
99 | "prompt": sample["prompt"],
100 | "answer": sample["answer"],
101 | "seed": seed,
102 | "response": out.outputs[0].text,
103 | }
104 | )
105 | return results
106 |
107 |
108 | # --------------------------------------------------------------------------- #
109 | # main #
110 | # --------------------------------------------------------------------------- #
111 | def main():
112 | available_workers = [0,1,2,3,4,5,6,7]
113 | num_workers = len(available_workers)
114 | for task in TASKS:
115 | task_name = task["name"]
116 | task_path = task["path"]
117 | N = task["N"]
118 |
119 | print(f"Starting evaluation for task: {task_name} (N={N})")
120 |
121 | # Update output path for the current task
122 | out_path = OUT_DIR / f"{task_name.lower()}_t{TEMPERATURE}_p{TOP_P}_n{N}-MNT{MAX_TOKENS}.jsonl"
123 |
124 | # 1. Load original prompts
125 | samples = load_samples(task_path)
126 |
127 | # Append suffix prompt to each sample
128 | for sample in samples:
129 | sample["prompt"] = PROMPT_TEMPLATE.format(problem=sample["prompt"])
130 |
131 | # demo print
132 | print("Example prompt after formatting:")
133 | print(samples[0]["prompt"])
134 |
135 | # 2. Generate N distinct random seeds and split across GPUs
136 | random_seeds = random.sample(range(2**31 - 1), N) # unique & shuffled
137 | seed_chunks = split_seeds(random_seeds, num_workers)
138 |
139 | # 3. Launch workers
140 | all_results = []
141 | args_list = [(samples, seed_chunks[i], gid) for (i, gid) in enumerate(available_workers)]
142 | with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as ex:
143 | futures = [ex.submit(worker_process, tup) for tup in args_list]
144 | for fut in tqdm(concurrent.futures.as_completed(futures),
145 | total=len(futures), desc=f"GPU workers ({task_name})"):
146 | all_results.extend(fut.result())
147 |
148 | print(f"Total generations collected for {task_name}: {len(all_results)}") # len(samples) * N
149 |
150 | # 4. Save to disk
151 | with out_path.open("w", encoding="utf-8") as f:
152 | for item in all_results:
153 | f.write(json.dumps(item, ensure_ascii=False) + "\n")
154 | print(f"Saved results for {task_name} to {out_path}")
155 |
156 |
157 | if __name__ == "__main__":
158 | main()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # JustRL: Simplicity at Scale
3 |
4 | 🚀 Competitive RL Performance Without Complex Techniques 🌟
5 |
23 |
24 | ## 📰 Overview
25 |
26 | **JustRL** demonstrates that competitive reinforcement learning performance for small language models doesn't require complex multi-stage pipelines or dynamic schedules. Using a minimal recipe with single-stage training and fixed hyperparameters, we achieve state-of-the-art results on mathematical reasoning tasks. This repository contains a lightweight evaluation script to reproduce evaluation results for **JustRL** models on nine challenging math benchmarks.
27 |
28 | We release two models:
29 |
30 | - [**JustRL-DeepSeek-1.5B**](https://huggingface.co/hbx/JustRL-DeepSeek-1.5B): Trained from DeepSeek-R1-Distill-Qwen-1.5B
31 | - [**JustRL-Nemotron-1.5B**](https://huggingface.co/hbx/JustRL-Nemotron-1.5B): Trained from OpenMath-Nemotron-1.5B
32 |
33 | Both models use identical hyperparameters without per-model tuning, demonstrating the robustness of our approach.
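As a quick sanity check, either checkpoint can be loaded directly with vLLM. The snippet below is a minimal sketch (the toy question is ours, not from any benchmark); it reuses the model id, prompt suffix, and sampling settings from `evals/gen_vllm.py`:

```python
from vllm import LLM, SamplingParams

# Model id, prompt suffix, and sampling settings mirror evals/gen_vllm.py
llm = LLM(model="hbx/JustRL-DeepSeek-1.5B", enforce_eager=True)
params = SamplingParams(temperature=0.7, top_p=0.9, max_tokens=31744)
question = ("What is 1 + 1? "
            "Please reason step by step, and put your final answer within \\boxed{}.")
outputs = llm.chat([[{"role": "user", "content": question}]], params)
print(outputs[0].outputs[0].text)
```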
34 |
35 | ![AIME24 accuracy curves](assets/fig1_aime24_curves_added.png)
36 |
37 | ## 🎯 Key Highlights
38 |
39 | ✨ **Simplicity**: Single-stage training with fixed hyperparameters, without multi-stage pipelines or dynamic schedules
40 |
41 | 📈 **Stability**: Smooth, monotonic improvement over 4,000+ training steps without collapses or oscillations
42 |
43 | 🎯 **Performance**: State-of-the-art results at 1.5B scale, matching or exceeding more complex approaches
44 |
45 | 💰 **Efficiency**: Comparable or better performance with 2× less compute than multi-stage methods
46 |
47 | 🔓 **Open**: Complete evaluation scripts and model weights released
48 |
49 | ## 📁 Repository Structure
50 |
51 | ```
52 | JustRL/
53 | ├── evals/ # Evaluation scripts
54 | │ ├── gen_vllm.py # Generation script using vLLM
55 | │ ├── grade.py # Grading script with hybrid verification
56 | │ └── utils.py # Answer verification utilities
57 | ├── data/ # Benchmark datasets
58 | │ ├── AIME24/
59 | │ ├── AIME25/
60 | │ ├── AMC23/
61 | │ ├── MATH-500/
62 | │ ├── Minerva/
63 | │ ├── Olympiad-Bench/
64 | │ ├── BRUMO25/
65 | │ ├── CMIMC25/
66 | │ └── HMMT25/
67 | └── justrl_eval_outputs/ # Evaluation outputs (download from Google Drive)
68 | ├── JustRL-DeepSeek-1.5B/
69 | │ ├── *.jsonl # Generation outputs per benchmark
70 | │ └── grading_results.json
71 | └── JustRL-Nemotron-1.5B/
72 | ├── *.jsonl
73 | └── grading_results.json
74 | ```
75 |
76 | ## 🔧 Setup
77 |
78 | ### Environment Requirements
79 |
80 | We recommend using a conda environment with the following key dependencies:
81 |
82 | ```bash
83 | conda create -n justrl python=3.10
84 | conda activate justrl
85 | ```
86 |
87 | ### Key Dependencies
88 |
89 | - **PyTorch**: `2.6.0`
90 | - **vLLM**: `0.8.4`
91 | - **transformers**: `4.51.3`
92 | - **sympy**: `1.13.1`
93 | - **pylatexenc**: `2.10`
94 |
95 | ### Download Evaluation Outputs
96 |
97 | The evaluation outputs are large and hosted on Google Drive. Download them for reproduction:
98 |
99 | **📥 Download Link**: [Google Drive](https://drive.google.com/file/d/1G5oHTNYR8edbj6NLDgY8_6X3SB1MDngc/view?usp=sharing)
100 |
101 | After downloading, extract the `justrl_eval_outputs/` directory to the repository root directory.
102 |
103 | ## 🚀 Usage
104 |
105 | This evaluation script is based on [POLARIS](https://github.com/ChenxinAn-fdu/POLARIS), with one key modification: we add a model-based verifier ([CompassVerifier-3B](https://huggingface.co/opencompass/CompassVerifier-3B)) for more robust evaluation, complementing the rule-based verification system.
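Concretely, every response is first checked by the rule-based grader (`grade_answer_verl` in `evals/utils.py`); only responses that fail that check are re-judged by CompassVerifier-3B. The sketch below condenses this fallback logic from `evals/grade.py` (the real script batches the verifier calls and uses the full CompassVerifier prompt; the short prompt here is only a stand-in):

```python
from vllm import LLM, SamplingParams
from utils import grade_answer_verl  # rule-based verifier from evals/utils.py

# Stand-in for the full CompassVerifier prompt defined at the top of evals/grade.py
CV_PROMPT = ("Question:\n{question}\n\nGold answer:\n{gold_answer}\n\n"
             "Candidate answer:\n{llm_response}\n\n"
             "Reply with A (correct), B (incorrect), or C (invalid).")

verifier = LLM(model="opencompass/CompassVerifier-3B", tensor_parallel_size=1)
greedy = SamplingParams(temperature=0.0, max_tokens=2048)

def hybrid_grade(question: str, response: str, gold: str) -> bool:
    # 1) cheap rule-based check (boxed-answer extraction + math equivalence)
    if grade_answer_verl(response, gold):
        return True
    # 2) fall back to the model-based verifier; "A" means judged correct
    prompt = CV_PROMPT.format(question=question, gold_answer=gold, llm_response=response)
    out = verifier.chat([[{"role": "user", "content": prompt}]], greedy)
    return out[0].outputs[0].text.strip() == "A"
```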
106 |
107 | ### Generation (Optional)
108 |
109 | ```bash
110 | cd evals
111 | python gen_vllm.py
112 | ```
113 |
114 | Configure the model to evaluate by setting the `NAME` variable in `gen_vllm.py`, and set `available_workers` to the GPU ids available on your machine (see the example below).
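For example, the relevant settings look like this (values here are illustrative for a 4-GPU node; the script ships with eight workers):

```python
# In evals/gen_vllm.py (NAME at module level, available_workers inside main())
NAME = "hbx/JustRL-DeepSeek-1.5B"   # or "hbx/JustRL-Nemotron-1.5B"
available_workers = [0, 1, 2, 3]    # GPU ids; one vLLM worker is launched per id
```

Each benchmark then produces a file such as `justrl_eval_outputs/JustRL-DeepSeek-1.5B/aime24_t0.7_p0.9_n32-MNT31744.jsonl`.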
115 |
116 | ### Grading
117 |
118 | ```bash
119 | cd evals
120 | python grade.py
121 | ```
122 |
123 | The grading script processes all JSONL files in the output directory and generates `grading_results.json`.
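To inspect the results afterwards, a small loader like the one below is enough (field names follow the dictionaries written by `grade.py`; `mean_score` averages correctness over the N rollouts per problem):

```python
import json

with open("justrl_eval_outputs/JustRL-DeepSeek-1.5B/grading_results.json") as f:
    results = json.load(f)

for entry in results:
    task = entry["hyperparameters"]["task_name"]
    print(f"{task:16s} mean={entry['mean_score']:.4f} best={entry['best_score']:.4f}")
```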
124 |
125 | ## 📈 Performance
126 |
127 | ### JustRL-DeepSeek-1.5B (Based on DeepSeek-R1-Distill-Qwen-1.5B)
128 |
129 | | Model | AIME24 (@32) | AIME25 (@32) | AMC23 (@32) | MATH-500 (@4) | Minerva (@4) | OlympiadBench (@4) | HMMT25 (@32) | BRUMO25 (@32) | CMIMC25 (@32) | Avg |
130 | | ------------------------ | ------------ | ------------ | ----------- | ------------- | ------------ | ------------------ | ------------ | ------------- | ------------- | --------- |
131 | | DeepSeek-R1-Distill-1.5B | 29.90 | 22.40 | 63.82 | 84.90 | 34.65 | 45.95 | 13.44 | 30.94 | 12.89 | 37.65 |
132 | | DeepScaleR-1.5B-Preview | 40.21 | 28.65 | 73.83 | 89.30 | 39.34 | 52.79 | 18.96 | 40.00 | 21.00 | 44.88 |
133 | | ProRL-V2 | 51.87 | 35.73 | 88.75 | 92.00 | 49.03 | 67.84 | 19.38 | 47.29 | **25.86** | 53.08 |
134 | | BroRL | **57.50** | 36.88 | / | **92.14** | 49.08 | 61.54 | / | / | / | / |
135 | | JustRL-DeepSeek-1.5B | 52.60 | **38.75** | **91.02** | 91.65 | **51.47** | **67.99** | **21.98** | **52.71** | 25.63 | **54.87** |
136 |
137 | Beyond accuracy, the real question is whether this simplicity comes at a computational cost. It doesn't: we reach these results with roughly half of ProRL-V2's compute budget, using a single-stage recipe with fixed hyperparameters. BroRL requires about 4.9× more compute, increasing rollouts to 512 per example and essentially exhaustively exploring the solution space. Our approach achieves competitive performance without that overhead.
138 |
139 | ### JustRL-Nemotron-1.5B (Based on OpenMath-Nemotron-1.5B)
140 |
141 | | Model | AIME24 (@32) | AIME25 (@32) | AMC23 (@32) | MATH-500 (@4) | Minerva (@4) | OlympiadBench (@4) | HMMT25 (@32) | BRUMO25 (@32) | CMIMC25 (@32) | Avg |
142 | | ---------------------- | ------------ | ------------ | ----------- | ------------- | ------------ | ------------------ | ------------ | ------------- | ------------- | --------- |
143 | | OpenMath-Nemotron-1.5B | 58.75 | 48.44 | 90.55 | 92.40 | 26.93 | 71.70 | 30.10 | 61.67 | 30.08 | 56.74 |
144 | | QUESTA-Nemotron-1.5B | **71.56** | 62.08 | 93.44 | 92.95 | **32.08** | 72.28 | **40.94** | **67.50** | 41.48 | 63.81 |
145 | | JustRL-Nemotron-1.5B | 69.69 | **62.92** | **96.02** | **94.15** | 30.24 | **76.59** | 40.63 | 66.88 | **41.72** | **64.32** |
146 |
147 | We achieve a 64.32% average, slightly outperforming QuestA's 63.81% and leading on five of nine benchmarks. The gap is narrow, which makes sense: both approaches are pushing the boundaries of what is achievable at the 1.5B scale. The key difference is how we get there: we use about 2× less compute and reach a slightly better average, without designing a complex curriculum as QuestA does.
148 |
149 | ## 📖 Training Recipe
150 |
151 | Our approach is deliberately minimal:
152 |
153 | **Core Algorithm**: Standard GRPO with binary outcome rewards
154 |
155 | - **Reward**: Simple DAPO verifier (string-matching, no SymPy)
156 | - **Training**: Single-stage, no curriculum or stage transitions
157 | - **Hyperparameters**: Fixed throughout (no adaptive schedules)
158 | - **Data**: DAPO-Math-17k without filtering or dynamic sampling
159 | - **Length Control**: 16K context cap (no explicit penalties)
160 | - **Stabilization**: Only "clip higher" for gradient stability
161 |
162 | Detailed hyperparameters, and a comparison of training techniques with other methods, can be found in our [blog](https://relieved-cafe-fe1.notion.site/JustRL-Scaling-a-1-5B-LLM-with-a-Simple-RL-Recipe-24f6198b0b6b80e48e74f519bfdaf0a8); a minimal sketch of the clipped objective is given below.
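For intuition, here is a hedged sketch of what a GRPO-style clipped surrogate with an asymmetric ("clip higher") upper range looks like; the epsilon values and all implementation details below are illustrative assumptions, not the exact settings used to train JustRL:

```python
import torch

def clip_higher_surrogate(logp_new, logp_old, advantages,
                          eps_low: float = 0.2, eps_high: float = 0.28):
    """Clipped policy-gradient surrogate with a larger upper clip range.

    logp_new / logp_old: per-token log-probs under the current / rollout policy
    advantages:          per-token group-normalized advantages from binary rewards
    eps_high > eps_low:  lets low-probability tokens grow further before clipping
    """
    ratio = torch.exp(logp_new - logp_old)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - eps_low, 1.0 + eps_high) * advantages
    # Maximize the surrogate, i.e. minimize its negative mean
    return -torch.minimum(unclipped, clipped).mean()
```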
163 |
164 | **Training Data**: We train on [DAPO-Math-17k](https://huggingface.co/datasets/BytedTsinghua-SIA/DAPO-Math-17k), a curated dataset of mathematical problems. **No offline difficulty filtering or online dynamic sampling is used.**
165 |
166 | ## 🎈 Citation
167 |
168 | ```bibtex
169 | @misc{he2025justrlscaling15bllm,
170 | title={JustRL: Scaling a 1.5B LLM with a Simple RL Recipe},
171 | author={Bingxiang He and Zekai Qu and Zeyuan Liu and Yinghao Chen and Yuxin Zuo and Cheng Qian and Kaiyan Zhang and Weize Chen and Chaojun Xiao and Ganqu Cui and Ning Ding and Zhiyuan Liu},
172 | year={2025},
173 | eprint={2512.16649},
174 | archivePrefix={arXiv},
175 | primaryClass={cs.CL},
176 | url={https://arxiv.org/abs/2512.16649},
177 | }
178 | ```
179 |
--------------------------------------------------------------------------------
/evals/grade.py:
--------------------------------------------------------------------------------
1 |
2 | from utils import grade_answer_verl
3 | from transformers import AutoTokenizer
4 | import json
5 | import pandas as pd
6 | from pathlib import Path
7 | import re
8 | from vllm import LLM, SamplingParams
9 |
10 | CV_PROMPT = """
11 | Please as a grading expert, judge whether the final answers given by the candidates below are consistent with the standard answers, that is, whether the candidates answered correctly.
12 | Here are some evaluation criteria:
13 | 1. Please refer to the given standard answer. You don't need to re-generate the answer to the question because the standard answer has been given. You only need to judge whether the candidate's answer is consistent with the standard answer according to the form of the question. THE STANDARD ANSWER IS ALWAYS CORRECT AND THE QUESTION IS PERFECTLY VALID. NEVER QUESTION THEM.
14 | 2. ONLY compare the FINAL ANSWER - COMPLETELY IGNORE any potential errors in the REASONING PROCESSES.
15 | 3. Some answers may be expressed in different ways, such as some answers may be a mathematical expression, some answers may be a textual description, as long as the meaning expressed is the same. Before making a judgment, please understand the question and the standard answer first, and then judge whether the candidate's answer is correct.
16 | 4. Some answers may consist of multiple items, such as multiple-choice questions, multiple-select questions, fill-in-the-blank questions, etc. Regardless of the question type, the final answer will be considered correct as long as it matches the standard answer, regardless of whether the reasoning process is correct. For multiple-select questions and multi-blank fill-in-the-blank questions, all corresponding options or blanks must be answered correctly and match the standard answer exactly to be deemed correct.
17 | 5. If the prediction is given with \\boxed{{}}, please ignore the \\boxed{{}} and only judge whether the candidate's answer is consistent with the standard answer.
18 | 6. If the candidate's answer is invalid (e.g., incomplete (cut off mid-response), lots of unnormal repetitive content, or irrelevant to the question, saying it can't answer the question because some irresistible factors, like ethical issues, no enough information, etc.), select option C (INVALID).Please judge whether the following answers are consistent with the standard answer based on the above criteria. Grade the predicted answer of this new question as one of:
19 | A: CORRECT
20 | B: INCORRECT
21 | C: INVALID
22 | Just return the letters "A", "B", or "C", with no text around it.
23 | Here is your task. Simply reply with either CORRECT, INCORRECT, or INVALID. Don't apologize or correct yourself if there was a mistake; we are just trying to grade the answer.
24 | <Original Question Begin>:
25 | {question}
26 | <Original Question End>
27 | <Gold Target Begin>:
28 | {gold_answer}
29 | <Gold Target End>
30 | <Predicted Answer Begin>:
31 | {llm_response}
32 | <Predicted Answer End>
33 | Judging the correctness of the candidate's answer:
34 | """
35 |
36 | NAME = "JustRL-Nemotron-1.5B" # or "JustRL-DeepSeek-1.5B"
37 | EVAL_DIR = Path(f"justrl_eval_outputs/{NAME}")
38 | OUTPUT_FILE = EVAL_DIR / "grading_results.json"
39 |
40 | model_name = "opencompass/CompassVerifier-3B"
41 | model_tokenizer = AutoTokenizer.from_pretrained(model_name)
42 | vllm_model = LLM(
43 | model=model_name,
44 | tensor_parallel_size=1
45 | )
46 | sampling_params = SamplingParams(
47 | temperature=0.0,
48 | max_tokens=2048
49 | )
50 |
51 | length_tokenizer = None # optionally set to an AutoTokenizer to report token-level avg_output_length
52 |
53 | def get_len(seq):
54 | return len(length_tokenizer.encode(seq))
55 |
56 | def get_diverse_score(sequences, n=4):
57 | """
58 | Calculate the Distinct-n score.
59 |
60 | sequences: List[str], list of response strings
61 | n: int, n-gram size (default 4)
62 | """
63 | distinct_ngrams = set()
64 | total_ngrams = 0
65 |
66 | for seq in sequences:
67 | # more accurate n-gram
68 | # tokens = nltk.word_tokenize(seq)
69 | tokens = seq.split()
70 | for i in range(len(tokens) - n + 1):
71 | ngram = tuple(tokens[i:i + n])
72 | distinct_ngrams.add(ngram)
73 | total_ngrams += 1
74 |
75 | return len(distinct_ngrams) / total_ngrams if total_ngrams > 0 else 0
76 |
77 | def process_jsonl_file(file_name):
78 | """
79 | Process a JSONL file and dynamically handle the number of problems.
80 | """
81 | results = []
82 | with open(file_name) as f:
83 | for line in f:
84 | data = json.loads(line)
85 | id = int(data["example_id"])
86 | while len(results) <= id: # Ensure the list is large enough
87 | results.append({"gt": None, "question": None, "responses": []})
88 | results[id]["gt"] = data["answer"]
89 | # Keep the prompt so the model-based verifier sees the original question
90 | results[id]["question"] = data.get("prompt", "")
91 | results[id]["responses"].append(data["response"])
92 | return results
93 |
94 | def parse_hyperparameters_from_filename(filename):
95 | """
96 | Parse hyperparameters from the filename.
97 | Example filename format: {taskname}_t{temperature}_p{top_p}_n{n}-MNT{max_tokens}.jsonl
98 | """
99 | match = re.search(r"_t(?P<temperature>[\d.]+)_p(?P<top_p>[\d.]+)_n(?P<n>\d+)-MNT(?P<max_tokens>\d+)",
100 | filename)
101 | return match.groupdict() if match else {}
102 |
103 | def grade_file(file_path):
104 | """
105 | Grade a single file and return the results.
106 | """
107 | hyperparams = parse_hyperparameters_from_filename(file_path.name)
108 | if not hyperparams:
109 | print(f"Skipping file with unrecognized format: {file_path}")
110 | return None
111 |
112 | task_name = file_path.stem.split("_")[0]
113 | hyperparams["task_name"] = task_name
114 |
115 | if "parquet" in str(file_path):
116 | df = pd.read_parquet(file_path)
117 | num_pred = len(df["responses"][0])
118 | else:
119 | df = process_jsonl_file(file_path)
120 | num_pred = len(df[0]["responses"])
121 |
122 | results = {
123 | "hyperparameters": hyperparams,
124 | "mean_score": 0,
125 | "distinct_4gram": 0,
126 | "best_score": 0,
127 | "solve_none": 0,
128 | "solve_all": 0,
129 | "avg_output_length": 0,
130 | "format_error_rollouts": 0,
131 | }
132 |
133 | diverse = []
134 | avg_scores = []
135 | best = []
136 | solve_none = 0
137 | solve_all = 0
138 | without_boxed = 0
139 | response_lengths = []
140 | incorrect_data = [] # List to store incorrect responses and ground truths
141 |
142 | all_model_inputs = [] # Collect all prompts for batch processing
143 | all_responses = [] # Keep track of responses for mapping back
144 | all_questions = [] # Keep track of questions for mapping back
145 | all_ground_truths = [] # Keep track of ground truths for mapping back
146 | rule_based_scores = [] # Store rule-based scores for fallback logic
147 |
148 | for i in range(len(df)):
149 | if "jsonl" in str(file_path):
150 | responses = df[i]["responses"]
151 | gt = df[i]["gt"]
152 | question = df[i].get("question", "") # Assuming question is part of the data
153 | else:
154 | responses = df["responses"][i]
155 | gt = df["reward_model"][i]["ground_truth"]
156 | question = df["reward_model"][i].get("question", "")
157 |
158 | responses_list = [str(response) for response in responses]
159 | if length_tokenizer:
160 | response_lengths += [get_len(response) for response in responses_list]
161 | else:
162 | response_lengths = [0]
163 | not_formated = ["boxed" not in response for response in responses_list]
164 | without_boxed += sum(not_formated)
165 |
166 | # First, use the rule-based verifier
167 | for response in responses_list:
168 | rule_score = grade_answer_verl(response, gt)
169 | rule_based_scores.append(rule_score)
170 | if not rule_score: # If rule-based verifier fails, prepare for model-based verifier
171 | model_input = CV_PROMPT.format(
172 | question=question,
173 | gold_answer=gt,
174 | llm_response=response
175 | )
176 | all_model_inputs.append(model_input)
177 | all_responses.append(response)
178 | all_questions.append(question)
179 | all_ground_truths.append(gt)
180 |
181 | diverse.append(get_diverse_score(responses_list))
182 |
183 | # Batch process all model-based verifier inputs
184 | if all_model_inputs:
185 | model_inputs = [model_tokenizer.apply_chat_template(
186 | [{"role": "user", "content": input_text}],
187 | add_generation_prompt=True,
188 | tokenize=False
189 | ) for input_text in all_model_inputs]
190 | outputs = vllm_model.generate(model_inputs, sampling_params)
191 |
192 | # Map back the results to the corresponding responses
193 | model_based_scores = []
194 | for idx, output in enumerate(outputs):
195 | judgement = output.outputs[0].text.strip()
196 | model_score = "A" == judgement # True if "A" (correct), False otherwise
197 | model_based_scores.append(model_score)
198 |
199 | # Save incorrect responses and ground truths
200 | if not model_score:
201 | incorrect_data.append({
202 | "response": all_responses[idx][-300:], # Save last 300 characters
203 | "ground_truth": all_ground_truths[idx]
204 | })
205 |
206 | # Combine rule-based and model-based scores
207 | model_idx = 0
208 | final_scores = []
209 | for rule_score in rule_based_scores:
210 | if rule_score: # If rule-based verifier passed
211 | final_scores.append(rule_score)
212 | else: # Use model-based verifier score
213 | final_scores.append(model_based_scores[model_idx])
214 | model_idx += 1
215 | else:
216 | final_scores = rule_based_scores
217 |
218 | # Calculate metrics
219 | avg_scores = [sum(final_scores[i:i + num_pred]) / num_pred for i in range(0, len(final_scores), num_pred)]
220 | best = [max(final_scores[i:i + num_pred]) for i in range(0, len(final_scores), num_pred)]
221 |
222 | solve_none = sum(1 for avg_score in avg_scores if avg_score == 0)
223 | solve_all = sum(1 for avg_score in avg_scores if avg_score == 1)
224 |
225 | results["mean_score"] = sum(avg_scores) / len(avg_scores)
226 | results["distinct_4gram"] = sum(diverse) / len(diverse)
227 | results["best_score"] = sum(best) / len(best)
228 | results["solve_none"] = solve_none
229 | results["solve_all"] = solve_all
230 | results["avg_output_length"] = sum(response_lengths) / len(response_lengths)
231 | results["format_error_rollouts"] = without_boxed
232 |
233 | # Save incorrect responses and ground truths to a separate file
234 | # incorrect_file = EVAL_DIR / f"{file_path.stem}_incorrect_data.json"
235 | # with incorrect_file.open("w", encoding="utf-8") as f:
236 | # json.dump(incorrect_data, f, indent=4)
237 |
238 | return results
239 |
240 | def main():
241 | all_results = []
242 | for file_path in EVAL_DIR.glob("*.jsonl"):
243 | print(f"Processing file: {file_path}")
244 | file_result = grade_file(file_path)
245 | if file_result:
246 | all_results.append(file_result)
247 |
248 | # Save results to JSON
249 | with OUTPUT_FILE.open("w", encoding="utf-8") as f:
250 | json.dump(all_results, f, indent=4)
251 | print(f"Grading results saved to {OUTPUT_FILE}")
252 |
253 | if __name__ == "__main__":
254 | main()
255 |
256 |
--------------------------------------------------------------------------------
/evals/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Answer checker API that uses sympy to simplify expressions and check for equality.
3 |
4 | Call grade_answer(given_answer: str, ground_truth: str).
5 | """
6 | import re
7 | from pylatexenc import latex2text
8 | import sympy
9 | from sympy.parsing import sympy_parser
10 | from typing import Optional
11 |
12 |
13 | # Dan Hendrycks' code
14 | def mathd_normalize_answer(answer: Optional[str]) -> Optional[str]:
15 | if answer is None:
16 | return None
17 | answer = answer.strip()
18 | try:
19 | # Remove enclosing `\text{}`.
20 | m = re.search("^\\\\text\{(?P<text>.+?)\}$", answer)
21 | if m is not None:
22 | answer = m.group("text").strip()
23 | return _strip_string(answer)
24 | except:
25 | return answer
26 |
27 | def _strip_string(string):
28 | def _fix_fracs(string):
29 | substrs = string.split("\\frac")
30 | new_str = substrs[0]
31 | if len(substrs) > 1:
32 | substrs = substrs[1:]
33 | for substr in substrs:
34 | new_str += "\\frac"
35 | if substr[0] == "{":
36 | new_str += substr
37 | else:
38 | try:
39 | assert len(substr) >= 2
40 | except:
41 | return string
42 | a = substr[0]
43 | b = substr[1]
44 | if b != "{":
45 | if len(substr) > 2:
46 | post_substr = substr[2:]
47 | new_str += "{" + a + "}{" + b + "}" + post_substr
48 | else:
49 | new_str += "{" + a + "}{" + b + "}"
50 | else:
51 | if len(substr) > 2:
52 | post_substr = substr[2:]
53 | new_str += "{" + a + "}" + b + post_substr
54 | else:
55 | new_str += "{" + a + "}" + b
56 | string = new_str
57 | return string
58 |
59 |
60 | def _fix_a_slash_b(string):
61 | if len(string.split("/")) != 2:
62 | return string
63 | a = string.split("/")[0]
64 | b = string.split("/")[1]
65 | try:
66 | a = int(a)
67 | b = int(b)
68 | assert string == "{}/{}".format(a, b)
69 | new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
70 | return new_string
71 | except:
72 | return string
73 |
74 |
75 | def _remove_right_units(string):
76 | # "\\text{ " only ever occurs (at least in the val set) when describing units
77 | if "\\text{ " in string:
78 | splits = string.split("\\text{ ")
79 | assert len(splits) == 2
80 | return splits[0]
81 | else:
82 | return string
83 |
84 |
85 | def _fix_sqrt(string):
86 | if "\\sqrt" not in string:
87 | return string
88 | splits = string.split("\\sqrt")
89 | new_string = splits[0]
90 | for split in splits[1:]:
91 | if split[0] != "{":
92 | a = split[0]
93 | new_substr = "\\sqrt{" + a + "}" + split[1:]
94 | else:
95 | new_substr = "\\sqrt" + split
96 | new_string += new_substr
97 | return new_string
98 | # linebreaks
99 | string = string.replace("\n", "")
100 | # print(string)
101 |
102 | # remove inverse spaces
103 | string = string.replace("\\!", "")
104 | # print(string)
105 |
106 | # replace \\ with \
107 | string = string.replace("\\\\", "\\")
108 | # print(string)
109 |
110 | # replace tfrac and dfrac with frac
111 | string = string.replace("tfrac", "frac")
112 | string = string.replace("dfrac", "frac")
113 | # print(string)
114 |
115 | # remove \left and \right
116 | string = string.replace("\\left", "")
117 | string = string.replace("\\right", "")
118 | # print(string)
119 |
120 | # Remove circ (degrees)
121 | string = string.replace("^{\\circ}", "")
122 | string = string.replace("^\\circ", "")
123 |
124 | # remove dollar signs
125 | string = string.replace("\\$", "")
126 |
127 | # remove units (on the right)
128 | string = _remove_right_units(string)
129 |
130 | # remove percentage
131 | string = string.replace("\\%", "")
132 | string = string.replace("\%", "")
133 |
134 | # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
135 | string = string.replace(" .", " 0.")
136 | string = string.replace("{.", "{0.")
137 | # if empty, return empty string
138 | if len(string) == 0:
139 | return string
140 | if string[0] == ".":
141 | string = "0" + string
142 |
143 | # to consider: get rid of e.g. "k = " or "q = " at beginning
144 | if len(string.split("=")) == 2:
145 | if len(string.split("=")[0]) <= 2:
146 | string = string.split("=")[1]
147 |
148 | # fix sqrt3 --> sqrt{3}
149 | string = _fix_sqrt(string)
150 |
151 | # remove spaces
152 | string = string.replace(" ", "")
153 |
154 | # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
155 | string = _fix_fracs(string)
156 |
157 | # manually change 0.5 --> \frac{1}{2}
158 | if string == "0.5":
159 | string = "\\frac{1}{2}"
160 |
161 | # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
162 | string = _fix_a_slash_b(string)
163 |
164 | return string
165 |
166 |
167 | # sympy might hang -- we don't care about trying to be lenient in these cases
168 | BAD_SUBSTRINGS = ["^{", "^("]
169 | BAD_REGEXES = ["\^[0-9]+\^", "\^[0-9][0-9]+"]
170 | TUPLE_CHARS = "()[]"
171 |
172 |
173 | def _sympy_parse(expr: str):
174 | """Parses an expression with sympy."""
175 | py_expr = expr.replace("^", "**")
176 | return sympy_parser.parse_expr(
177 | py_expr,
178 | transformations=(
179 | sympy_parser.standard_transformations
180 | + (sympy_parser.implicit_multiplication_application,)
181 | ),
182 | )
183 |
184 |
185 | def _parse_latex(expr: str) -> str:
186 | """Attempts to parse latex to an expression sympy can read."""
187 | expr = expr.replace("\\tfrac", "\\frac")
188 | expr = expr.replace("\\dfrac", "\\frac")
189 | expr = expr.replace("\\frac", " \\frac") # Play nice with mixed numbers.
190 | expr = latex2text.LatexNodes2Text().latex_to_text(expr)
191 |
192 | # Replace the specific characters that this parser uses.
193 | expr = expr.replace("√", "sqrt")
194 | expr = expr.replace("π", "pi")
195 | expr = expr.replace("∞", "inf")
196 | expr = expr.replace("∪", "U")
197 | expr = expr.replace("·", "*")
198 | expr = expr.replace("×", "*")
199 |
200 | return expr.strip()
201 |
202 |
203 | def _is_float(num: str) -> bool:
204 | try:
205 | float(num)
206 | return True
207 | except ValueError:
208 | return False
209 |
210 |
211 | def _is_int(x: float) -> bool:
212 | try:
213 | return abs(x - int(round(x))) <= 1e-7
214 | except:
215 | return False
216 |
217 |
218 | def _is_frac(expr: str) -> bool:
219 | return bool(re.search(r"^-?[0-9]+.?/0*[1-9][0-9]*.?$", expr))
220 |
221 |
222 | def _str_is_int(x: str) -> bool:
223 | try:
224 | x = _strip_properly_formatted_commas(x)
225 | x = float(x)
226 | return abs(x - int(round(x))) <= 1e-7
227 | except:
228 | return False
229 |
230 |
231 | def _str_to_int(x: str) -> bool:
232 | x = x.replace(",", "")
233 | x = float(x)
234 | return int(x)
235 |
236 |
237 | def _inject_implicit_mixed_number(step: str):
238 | """
239 | Automatically make a mixed number evalable
240 | e.g. 7 3/4 => 7+3/4
241 | """
242 | p1 = re.compile("([0-9]) +([0-9])")
243 | step = p1.sub("\\1+\\2", step) ## implicit mults
244 | return step
245 |
246 |
247 | def _strip_properly_formatted_commas(expr: str):
248 | # We want to be careful because we don't want to strip tuple commas
249 | p1 = re.compile("(\d)(,)(\d\d\d)($|\D)")
250 | while True:
251 | next_expr = p1.sub("\\1\\3\\4", expr)
252 | if next_expr == expr:
253 | break
254 | expr = next_expr
255 | return next_expr
256 |
257 |
258 | def _normalize(expr: str) -> str:
259 | """Normalize answer expressions."""
260 | if expr is None:
261 | return None
262 |
263 | # Remove enclosing `\text{}`.
264 | m = re.search("^\\\\text\{(?P<text>.+?)\}$", expr)
265 | if m is not None:
266 | expr = m.group("text")
267 |
268 | expr = expr.replace("\\%", "%")
269 | expr = expr.replace("\\$", "$")
270 | expr = expr.replace("$", "")
271 | expr = expr.replace("%", "")
272 | expr = expr.replace(" or ", " , ")
273 | expr = expr.replace(" and ", " , ")
274 |
275 | expr = expr.replace("million", "*10^6")
276 | expr = expr.replace("billion", "*10^9")
277 | expr = expr.replace("trillion", "*10^12")
278 |
279 | for unit in [
280 | "degree",
281 | "cm",
282 | "centimeter",
283 | "meter",
284 | "mile",
285 | "second",
286 | "minute",
287 | "hour",
288 | "day",
289 | "week",
290 | "month",
291 | "year",
292 | "foot",
293 | "feet",
294 | "inch",
295 | "yard",
296 | ]:
297 | expr = re.sub(f"{unit}(es)?(s)? *(\^[0-9]+)?", "", expr)
298 | expr = re.sub(f"\^ *\\\\circ", "", expr)
299 |
300 | if len(expr) > 0 and expr[0] == "{" and expr[-1] == "}":
301 | expr = expr[1:-1]
302 |
303 | expr = re.sub(",\\\\! *", "", expr)
304 | if _is_float(expr) and _is_int(float(expr)):
305 | expr = str(int(round(float(expr))))
306 | if "\\" in expr:
307 | try:
308 | expr = _parse_latex(expr)
309 | except:
310 | pass
311 |
312 | # edge case with mixed numbers and negative signs
313 | expr = re.sub("- *", "-", expr)
314 |
315 | expr = _inject_implicit_mixed_number(expr)
316 | expr = expr.replace(" ", "")
317 |
318 | # if we somehow still have latex braces here, just drop them
319 | expr = expr.replace("{", "")
320 | expr = expr.replace("}", "")
321 |
322 | # don't be case sensitive for text answers
323 | expr = expr.lower()
324 |
325 | if _str_is_int(expr):
326 | expr = str(_str_to_int(expr))
327 |
328 | return expr
329 |
330 |
331 | def count_unknown_letters_in_expr(expr: str):
332 | expr = expr.replace("sqrt", "")
333 | expr = expr.replace("frac", "")
334 | letters_in_expr = set([x for x in expr if x.isalpha()])
335 | return len(letters_in_expr)
336 |
337 |
338 | def should_allow_eval(expr: str):
339 | # we don't want to try parsing unknown text or functions of more than two variables
340 | if count_unknown_letters_in_expr(expr) > 2:
341 | return False
342 |
343 | for bad_string in BAD_SUBSTRINGS:
344 | if bad_string in expr:
345 | return False
346 |
347 | for bad_regex in BAD_REGEXES:
348 | if re.search(bad_regex, expr) is not None:
349 | return False
350 |
351 | return True
352 |
353 |
354 | def are_equal_under_sympy(ground_truth_normalized: str, given_normalized: str):
355 | are_equal = False
356 | try:
357 | expr = f"({ground_truth_normalized})-({given_normalized})"
358 | if should_allow_eval(expr):
359 | sympy_diff = _sympy_parse(expr)
360 | simplified = sympy.simplify(sympy_diff)
361 | if simplified == 0:
362 | are_equal = True
363 | except:
364 | pass
365 | return are_equal
366 |
367 |
368 | def split_tuple(expr: str):
369 | """
370 | Split the elements in a tuple/interval, while handling well-formatted commas in large numbers
371 | """
372 | expr = _strip_properly_formatted_commas(expr)
373 | if len(expr) == 0:
374 | return []
375 | if (
376 | len(expr) > 2
377 | and expr[0] in TUPLE_CHARS
378 | and expr[-1] in TUPLE_CHARS
379 | and all([ch not in expr[1:-1] for ch in TUPLE_CHARS])
380 | ):
381 | elems = [elem.strip() for elem in expr[1:-1].split(",")]
382 | else:
383 | elems = [expr]
384 | return elems
385 |
386 |
387 | def last_boxed_only_string(string):
388 | idx = string.rfind("\\boxed")
389 | if idx < 0:
390 | idx = string.rfind("\\fbox")
391 | if idx < 0:
392 | return None
393 |
394 | i = idx
395 | right_brace_idx = None
396 | num_left_braces_open = 0
397 | while i < len(string):
398 | if string[i] == "{":
399 | num_left_braces_open += 1
400 | if string[i] == "}":
401 | num_left_braces_open -= 1
402 | if num_left_braces_open == 0:
403 | right_brace_idx = i
404 | break
405 | i += 1
406 |
407 | if right_brace_idx == None:
408 | retval = None
409 | else:
410 | retval = string[idx:right_brace_idx + 1]
411 |
412 | return retval
413 |
414 | def remove_boxed(s):
415 | left = "\\boxed{"
416 | try:
417 | assert s[:len(left)] == left
418 | assert s[-1] == "}"
419 | return s[len(left):-1]
420 | except:
421 | return None
422 |
423 |
424 | def extract_boxed_answer(solution: str) -> str:
425 | """Extract the answer from inside a LaTeX \\boxed{} command"""
426 | solution = last_boxed_only_string(solution)
427 | solution = remove_boxed(solution)
428 | return solution
429 |
430 | def grade_answer_sympy(given_answer: str, ground_truth: str) -> bool:
431 | ground_truth_normalized = _normalize(ground_truth)
432 | given_normalized = _normalize(given_answer)
433 |
434 | if ground_truth_normalized is None:
435 | return False
436 |
437 | if ground_truth_normalized == given_normalized:
438 | return True
439 |
440 | if len(given_normalized) == 0:
441 | return False
442 |
443 | ground_truth_elems = split_tuple(ground_truth_normalized)
444 | given_elems = split_tuple(given_normalized)
445 |
446 | if len(ground_truth_elems) > 1 and (
447 | ground_truth_normalized[0] != given_normalized[0]
448 | or ground_truth_normalized[-1] != given_normalized[-1]
449 | ):
450 | is_correct = False
451 | elif len(ground_truth_elems) != len(given_elems):
452 | is_correct = False
453 | else:
454 | for ground_truth_elem, given_elem in zip(ground_truth_elems, given_elems):
455 | if _is_frac(ground_truth_elem) and _is_frac(given_elem):
456 | # if fractions aren't reduced, then shouldn't be marked as correct
457 | # so, we don't want to allow sympy.simplify in this case
458 | is_correct = ground_truth_elem == given_elem
459 | elif _str_is_int(ground_truth_elem) != _str_is_int(given_elem):
460 | # if the ground truth answer is an integer, we require the given answer to be a strict match (no sympy.simplify)
461 | is_correct = False
462 | else:
463 | is_correct = are_equal_under_sympy(ground_truth_elem, given_elem)
464 | if not is_correct:
465 | break
466 |
467 | return is_correct
468 |
469 | def grade_answer_mathd(given_answer: str, ground_truth: str) -> bool:
470 | ground_truth_normalized_mathd = mathd_normalize_answer(ground_truth)
471 | given_answer_normalized_mathd = mathd_normalize_answer(given_answer)
472 |
473 | # be at least as lenient as mathd
474 | if ground_truth_normalized_mathd == given_answer_normalized_mathd:
475 | return True
476 | return False
477 |
478 |
479 |
480 | def extract_answer(passage: str) -> str:
481 | if "\\boxed" in passage:
482 | return extract_boxed_answer(passage)
483 | return None
484 |
485 | def grade_answer_verl(solution_str, ground_truth):
486 | if not ground_truth:
487 | return False
488 | if '\\boxed' in ground_truth:
489 | ground_truth = extract_answer(ground_truth)
490 | given_answer = extract_answer(solution_str)
491 | if given_answer is None:
492 | return False
493 | return grade_answer_mathd(given_answer, ground_truth) \
494 | or grade_answer_sympy(given_answer, ground_truth)
495 |
496 | def main():
497 | # Example usage
498 | ground_truth = "\\frac{e^{2 i t}}{(2+2 i)}"
499 | given_answer = "\\boxed{\\dfrac{e^{2i t}}{2 + 2i}}"
500 | is_correct = grade_answer_verl(given_answer, ground_truth)
501 | print(f"Is the given answer correct? {is_correct}")
502 |
503 | if __name__ == "__main__":
504 | main()
--------------------------------------------------------------------------------