├── sft
│   ├── __init__.py
│   ├── self_distillation_sampler.py
│   ├── dataloader.py
│   ├── train.py
│   └── evaluate.py
├── fig
│   ├── overview.png
│   └── result.png
├── math_utils
│   ├── parquet_to_jsonl.py
│   ├── format_checking.py
│   ├── deduplicate_problem.py
│   └── utils.py
├── requirements.txt
├── script
│   ├── sft_script
│   │   ├── SFT.sh
│   │   └── expert_iteration.sh
│   ├── data_script
│   │   ├── teacher_data_download.py
│   │   ├── extract_training_solution.py
│   │   ├── model_download.py
│   │   └── processing_self_distillation_traj.py
│   └── eval_script
│       ├── start_inference_server.sh
│       └── eval_remote_server.sh
├── .gitignore
├── dataset
│   ├── train
│   │   └── README.md
│   └── test
│       ├── aime2025.jsonl
│       └── amc.jsonl
├── LICENSE
└── README.md

--------------------------------------------------------------------------------
/sft/__init__.py:
--------------------------------------------------------------------------------
1 | 

--------------------------------------------------------------------------------
/fig/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StigLidu/DualDistill/HEAD/fig/overview.png

--------------------------------------------------------------------------------
/fig/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StigLidu/DualDistill/HEAD/fig/result.png

--------------------------------------------------------------------------------
/math_utils/parquet_to_jsonl.py:
--------------------------------------------------------------------------------
1 | # convert a Parquet file to a JSONL file
2 | 
3 | import pandas as pd
4 | import sys
5 | 
6 | file_path = sys.argv[1]
7 | 
8 | df = pd.read_parquet(file_path)
9 | 
10 | df.to_json(file_path.replace(".parquet", ".jsonl"), orient='records', lines=True)

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==1.8.1
2 | datasets==3.6.0
3 | huggingface_hub==0.33.1
4 | math_verify==0.8.0
5 | matplotlib==3.10.3
6 | openai==1.93.1
7 | pandas==2.3.1
8 | Requests==2.32.4
9 | thefuzz==0.22.1
10 | torch==2.7.0
11 | tqdm==4.66.4
12 | transformers==4.53.1
13 | vllm==0.9.1
14 | wandb==0.21.0
15 | numpy

--------------------------------------------------------------------------------
/script/sft_script/SFT.sh:
--------------------------------------------------------------------------------
1 | python sft/train.py \
2 |     --model_path models/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B \
3 |     --data_path dataset/train/dual_distill_data.jsonl \
4 |     --epochs 4 \
5 |     --code_mode \
6 |     --batch_size 1 \
7 |     --save_interval 2 \
8 |     --data_seed 42 \
9 |     --save_path agentic_R1_Qwen-7B

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | /models
2 | /result
3 | /outputs
4 | /slurm_logs
5 | **/__pycache__/
6 | test.py
7 | *.parquet
8 | *.jsonl.bk
9 | *.json
10 | *.jsonl
11 | *.log
12 | wandb/
13 | wandb
14 | _backup/
15 | workspace
16 | *.txt
17 | torchinductor_weihuad
18 | plot
19 | !dataset/test/*
20 | !dataset/train/whole_training_set.jsonl
21 | !requirements.txt

--------------------------------------------------------------------------------
/script/data_script/teacher_data_download.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 | import json
3 | # Login using e.g. `huggingface-cli login` to access this dataset
4 | ds = load_dataset("VanishD/DualDistill")
5 | 
6 | # Save the dataset to a local directory in jsonl format
7 | save_path = "dataset/train/dual_distill_data.jsonl"
8 | with open(save_path, "w") as f:
9 |     for item in ds["train"]:
10 |         f.write(json.dumps(item) + "\n")

--------------------------------------------------------------------------------
/script/eval_script/start_inference_server.sh:
--------------------------------------------------------------------------------
1 | model_path=$1 # models/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B
2 | display_name=$2 # DeepSeek-R1-Distill-Qwen-7B
3 | port=$3
4 | 
5 | # default port is 8123
6 | if [ -z "$port" ]; then
7 |     port=8123
8 | fi
9 | 
10 | vllm serve $model_path \
11 |     -pp 2 \
12 |     -tp 2 \
13 |     --gpu-memory-utilization 0.8 \
14 |     --host 0.0.0.0 --port $port \
15 |     --served-model-name $display_name \
16 |     --enable-prefix-caching

--------------------------------------------------------------------------------
/dataset/train/README.md:
--------------------------------------------------------------------------------
1 | ## Training Data Preparation
2 | 
3 | Our training problems are selected from [DeepMath](https://huggingface.co/datasets/zwhe99/DeepMath-103K).
4 | 
5 | ### For the Two Teachers
6 | 
7 | * **Text-based Reasoning Teacher:** [DeepMath](https://huggingface.co/datasets/zwhe99/DeepMath-103K) provides solution trajectories from DeepSeek-R1.
8 | * **Agentic Tool Use Teacher:** Generated from [OpenHands](https://github.com/All-Hands-AI/OpenHands) using Claude-3.5-Sonnet as the underlying model.
9 | 
10 | The distilled trajectory-composition data (teacher distillation) is available at [DualDistill](https://huggingface.co/datasets/VanishD/DualDistill).
11 | 
12 | ### For Self-Distillation
13 | 
14 | The complete list of training problems is in `whole_training_set.jsonl`; a quick way to inspect it is shown below.
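
The following snippet is a minimal sketch for sanity-checking the problem list (it assumes each line carries at least the `problem` and `answer` fields, as used by the scripts in this repository):

```python
import json

# Load the self-distillation problem list.
with open("dataset/train/whole_training_set.jsonl") as f:
    problems = [json.loads(line) for line in f]

print(f"{len(problems)} problems loaded")
print(problems[0]["problem"], "->", problems[0]["answer"])
```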
--------------------------------------------------------------------------------
/script/eval_script/eval_remote_server.sh:
--------------------------------------------------------------------------------
1 | model_path=$1 # http://localhost:8125/v1
2 | model_name=$2 # DeepSeek-R1-Distill-Qwen-7B
3 | data_path=$3 # dataset/test/DeepMath-Large.jsonl
4 | code_mode=$4 # true
5 | max_tokens=$5 # 4096
6 | retrieve_path=$6 # retrieve_path
7 | 
8 | if [ "$code_mode" = "true" ]; then
9 |     extra_args1="--code_mode"
10 | else
11 |     extra_args1=""
12 | fi
13 | 
14 | # if max_tokens is not set, set it to 4096
15 | if [ -z "$max_tokens" ]; then
16 |     max_tokens=4096
17 | fi
18 | 
19 | if [ "$retrieve_path" != "" ]; then
20 |     extra_args2="--generation_save_path $retrieve_path"
21 | else
22 |     extra_args2=""
23 | fi
24 | 
25 | python sft/evaluate.py \
26 |     --use_server_inference \
27 |     --num_samples 5 \
28 |     --model_path $model_path \
29 |     --model_name $model_name \
30 |     --data_path $data_path \
31 |     --max_tokens $max_tokens \
32 |     $extra_args1 \
33 |     $extra_args2

--------------------------------------------------------------------------------
/script/sft_script/expert_iteration.sh:
--------------------------------------------------------------------------------
1 | model_path=$1
2 | save_path=$2
3 | 
4 | python sft/train.py \
5 |     --model_path $model_path \
6 |     --data_path dataset/train/self_distillation/iteration_1_correct_replay_buffer_deduplicated.jsonl \
7 |                 dataset/train/self_distillation/iteration_1_incorrect_replay_buffer_revised_deduplicated_0.9.jsonl \
8 |     --save_path $save_path \
9 |     --epochs 4 \
10 |     --code_mode \
11 |     --batch_size 1 \
12 |     --save_interval 2 \
13 |     --data_seed 42 \
14 |     --max_length 8192 \
15 |     --lr 1e-5 \
16 |     --resume \
17 |     --resume_path models/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B_e0_latest_20250510_222416_3_agentic_r1_sd_20250707_234942_3
18 | # keep lr * total_steps (i.e., max_data_count * epochs) roughly constant
19 | # ideally we would keep lr * total_tokens constant, but total_tokens is hard to control
20 | # be careful: the LLM will overfit if total_tokens is too large

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2025 Weihua Du
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 

--------------------------------------------------------------------------------
/script/data_script/extract_training_solution.py:
--------------------------------------------------------------------------------
1 | import json
2 | from datasets import load_dataset
3 | import random
4 | 
5 | problem_path = "dataset/train/whole_training_set.jsonl"
6 | ds = load_dataset("zwhe99/DeepMath-103K")['train']
7 | 
8 | # find a solution for each problem
9 | with open(problem_path, "r") as f:
10 |     all_problems = sorted([json.loads(line) for line in f], key=lambda x: x["problem"])
11 | 
12 | print("Number of problems: ", len(all_problems))
13 | ds = list(ds)
14 | ds.sort(key=lambda x: x["question"])
15 | 
16 | last_idx = 0
17 | data = []
18 | for i in range(len(all_problems)):
19 |     while ds[last_idx]["question"] != all_problems[i]["problem"]:
20 |         last_idx += 1
21 |     text_reasoning = ds[last_idx][f"r1_solution_{random.randint(1, 3)}"]
22 |     text_reasoning = text_reasoning.split("</think>")[0] + "\n" + "</think>" + text_reasoning.split("</think>")[1] + "\n"
23 |     data.append({
24 |         "problem": all_problems[i]["problem"],
25 |         "solution": text_reasoning,
26 |         "answer": all_problems[i]["answer"]
27 |     })
28 | 
29 | with open("dataset/train/whole_training_set_with_solution.jsonl", "w") as f:
30 |     for item in data:
31 |         f.write(json.dumps(item) + "\n")

--------------------------------------------------------------------------------
/math_utils/format_checking.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import sys
4 | 
5 | def check_format(data : str, interim_list : list[str] = ["think", "code", "executor"], end_token : str = "answer", allow_no_end_token : bool = True):
6 |     data = data.strip()
7 |     while not data.startswith("<" + end_token + ">"):
8 |         next_block_exists = False
9 |         for interim in interim_list:
10 |             if data.startswith("<" + interim + ">"):
11 |                 next_block_exists = True
12 |                 if "</" + interim + ">" not in data:
13 |                     return False
14 |                 try:
15 |                     block_content = ("<" + interim + ">").join(data.split("</" + interim + ">")[0].strip().split("<" + interim + ">")[1:]).strip()
16 |                     for other in interim_list:
17 |                         if "<" + other + ">" in block_content:
18 |                             return False
19 |                     data = ("</" + interim + ">").join(data.split("</" + interim + ">")[1:]).strip()
20 |                 except:
21 |                     return False
22 |         if not next_block_exists:
23 |             # if allow_no_end_token is True, check if there is any interim block after the last interim block
24 |             if allow_no_end_token:
25 |                 for interim in interim_list:
26 |                     if "<" + interim + ">" in data:
27 |                         return False
28 |                 return True
29 |             else:
30 |                 return False
31 |     if "</" + end_token + ">" not in data:
32 |         return False
33 |     for interim in interim_list:
34 |         if "<" + interim + ">" in data:
35 |             return False
36 |     return True
37 | 
38 | if __name__ == "__main__":
39 |     data = "<code>print('Hello, world!')</code>."
40 |     print(check_format(data))

--------------------------------------------------------------------------------
/script/data_script/model_download.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
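# This script downloads a HuggingFace Hub snapshot (a model or dataset repo)
# into a local directory; see hf_download() below.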
6 | import os 7 | from requests.exceptions import HTTPError 8 | import sys 9 | from pathlib import Path 10 | from typing import Optional 11 | 12 | 13 | def hf_download( 14 | repo_id: Optional[str] = None, 15 | hf_token: Optional[str] = None, 16 | local_dir: Optional[str] = None, 17 | ) -> None: 18 | from huggingface_hub import snapshot_download 19 | 20 | local_dir = local_dir or "checkpoints" 21 | 22 | os.makedirs(f"{local_dir}/{repo_id}", exist_ok=True) 23 | try: 24 | snapshot_download( 25 | repo_id, 26 | local_dir=f"{local_dir}/{repo_id}", 27 | local_dir_use_symlinks=False, 28 | token=hf_token, 29 | ) 30 | except HTTPError as e: 31 | if e.response.status_code == 401: 32 | print( 33 | "You need to pass a valid `--hf_token=...` to download private checkpoints." 34 | ) 35 | else: 36 | raise e 37 | 38 | 39 | if __name__ == "__main__": 40 | import argparse 41 | 42 | parser = argparse.ArgumentParser(description="Download data from HuggingFace Hub.") 43 | parser.add_argument( 44 | "--repo_id", 45 | type=str, 46 | default="checkpoints/meta-llama/llama-2-7b-chat-hf", 47 | help="Repository ID to download from.", 48 | ) 49 | parser.add_argument( 50 | "--local_dir", type=str, default=None, help="Local directory to download to." 51 | ) 52 | parser.add_argument( 53 | "--hf_token", type=str, default=None, help="HuggingFace API token." 54 | ) 55 | 56 | args = parser.parse_args() 57 | hf_download(args.repo_id, args.hf_token, args.local_dir) -------------------------------------------------------------------------------- /math_utils/deduplicate_problem.py: -------------------------------------------------------------------------------- 1 | from thefuzz import fuzz 2 | import json 3 | from utils import find_question 4 | from tqdm import tqdm 5 | import sys 6 | 7 | def check_duplicate(text1, text2, threshold=90): 8 | if fuzz.ratio(text1, text2) > threshold: 9 | return True 10 | return False 11 | 12 | #print(check_duplicate("What is the sum of the first 100 natural numbers?", "What is the sum of the first 100 natural numbers?")) 13 | 14 | training_set_path = "data/DeepMath-103K-big-number/DeepMath-103K-big-number_question_everything.jsonl" 15 | 16 | training_problem_list = [] 17 | with open(training_set_path, "r") as f: 18 | for line in f: 19 | data = json.loads(line) 20 | training_problem_list.append(data) 21 | 22 | print("training set size: ", len(training_problem_list)) 23 | 24 | #test_set_path = "data/combinatorics_test_ge10000.jsonl" 25 | test_set_path = sys.argv[1] 26 | threshold = 90 27 | test_problem_list = [] 28 | with open(test_set_path, "r") as f: 29 | for line in f: 30 | data = json.loads(line) 31 | test_problem_list.append(data) 32 | 33 | print("test set size: ", len(test_problem_list)) 34 | 35 | no_duplicate_count = 0 36 | no_duplicate_count_list = [] 37 | for problem in tqdm(test_problem_list): 38 | for training_problem in training_problem_list: 39 | if check_duplicate(find_question(problem), find_question(training_problem), threshold=threshold): 40 | print("find duplicate, ratio: ", fuzz.ratio(find_question(problem), find_question(training_problem))) 41 | print(">>>>>>>>>>>>>>>>>>>>") 42 | print(find_question(problem)) 43 | print("<<<<<<<<<<<<<<<<<<<<") 44 | print(find_question(training_problem)) 45 | print(">>>>>>>>>>>>>>>>>>>>") 46 | break 47 | else: 48 | no_duplicate_count += 1 49 | no_duplicate_count_list.append(problem) 50 | 51 | print("test set no duplicate count: ", no_duplicate_count) 52 | 53 | with open(test_set_path.replace(".jsonl", 
"_no_duplicate_threshold_{}.jsonl".format(threshold)), "w") as f: 54 | for problem in no_duplicate_count_list: 55 | f.write(json.dumps(problem) + "\n") 56 | -------------------------------------------------------------------------------- /script/data_script/processing_self_distillation_traj.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | 5 | # add the root directory to the python path 6 | sys.path.append(os.path.abspath(".")) 7 | 8 | from math_utils.format_checking import check_format 9 | from transformers import AutoTokenizer 10 | 11 | tokenizer = AutoTokenizer.from_pretrained("models/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B") 12 | root_dir = "dataset/train/self_distillation" 13 | data_path = os.path.join(root_dir, "iteration_1_correct_replay_buffer.jsonl") 14 | wrong_data_path = os.path.join(root_dir, "iteration_1_incorrect_replay_buffer.jsonl") 15 | correct_record_path = os.path.join(root_dir, "iteration_1_accuracy.jsonl") 16 | text_reasoning_path = "dataset/train/whole_training_set_with_solution.jsonl" 17 | 18 | num_samples = 16 19 | correctness_bar = 0.9 20 | 21 | with open(text_reasoning_path, "r") as f: 22 | text_reasoning_data = [json.loads(line) for line in f] 23 | 24 | data = [] 25 | count = 0 26 | with open(data_path, "r") as f: 27 | raw_data = [json.loads(line) for line in f] 28 | count = 0 29 | temp_add_data = None 30 | for i in range(len(raw_data)): 31 | # If not all the data are correct, we add an example into expert iteration 32 | # Do not add a text reasoning path because this part has been added in SFT stage 33 | # One idx may have multiple data, we only add one data piece into expert iteration to ensure diversity 34 | if i == 0 or raw_data[i]["idx"] != raw_data[i-1]["idx"]: 35 | if count < num_samples and temp_add_data is not None: 36 | # > 0, < correctness_bar 37 | data.append(temp_add_data) 38 | temp_add_data = None 39 | count = 0 40 | count += 1 41 | if check_format(raw_data[i]["synthetic_data"]): 42 | attempt = raw_data[i]["synthetic_data"].split("")[0].strip() 43 | if temp_add_data is None: 44 | for j in range(len(text_reasoning_data)): 45 | if text_reasoning_data[j]["problem"] == raw_data[i]["problem"]: 46 | temp_add_data = attempt + "\n\nWait, we can also use text-reasoning as an alternative way to verify the solution.\n\n" + text_reasoning_data[j]["solution"] 47 | raw_data[i]["synthetic_data"] = temp_add_data 48 | temp_add_data = raw_data[i] 49 | count += 1 50 | break 51 | 52 | print("self-distillation part 1: ", len(data)) 53 | 54 | wrong_data = [] 55 | 56 | # Wrong data but format correct 57 | # Replace the last block with correct text reasoning 58 | # At most one wrong data per idx 59 | # If exist a correct sample, we do not add a corrected wrong sample into expert iteration 60 | 61 | with open(wrong_data_path, "r") as f: 62 | raw_data = [json.loads(line) for line in f] 63 | with open(correct_record_path, "r") as ff: 64 | correct_record = [json.loads(line) for line in ff] 65 | max_attempt_len = 0 66 | for i in range(len(raw_data)): 67 | if i == 0 or raw_data[i]["idx"] != raw_data[i - 1]["idx"]: 68 | count = 0 69 | if correct_record[raw_data[i]["idx"] - 1]["accuracy"] >= correctness_bar: 70 | continue 71 | attempt = "".join(raw_data[i]["synthetic_data"].split("")[:-1]) + "\n" 72 | if "Wait, the code is not correct, let's try text reasoning" in attempt: 73 | continue 74 | if check_format(attempt) and count < 1: 75 | max_attempt_len = max(max_attempt_len, 
len(tokenizer.encode(attempt))) 76 | # find corresponding correct text-reasoning 77 | for j in range(len(text_reasoning_data)): 78 | if text_reasoning_data[j]["problem"] == raw_data[i]["problem"]: 79 | new_attempt = attempt + "\n\nWait, the code is not correct, let's try text reasoning.\n\n" + text_reasoning_data[j]["solution"] 80 | if check_format(new_attempt): 81 | new_data = raw_data[i] 82 | new_data["synthetic_data"] = new_attempt 83 | wrong_data.append(new_data) 84 | count += 1 85 | break 86 | 87 | print("self-distillation part 2: ", len(wrong_data)) 88 | data_root_dir = "dataset/train/self_distillation" 89 | 90 | with open(os.path.join(data_root_dir, f"iteration_1_correct_replay_buffer_deduplicated.jsonl"), "w") as f: 91 | for item in data: 92 | f.write(json.dumps(item) + "\n") 93 | 94 | with open(os.path.join(data_root_dir, f"iteration_1_incorrect_replay_buffer_revised_deduplicated_{correctness_bar}.jsonl"), "w") as f: 95 | for item in wrong_data: 96 | f.write(json.dumps(item) + "\n") -------------------------------------------------------------------------------- /sft/self_distillation_sampler.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import argparse 4 | import sys 5 | from concurrent.futures import ThreadPoolExecutor 6 | parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) 7 | sys.path.append(parent_dir) 8 | 9 | from math_utils.utils import read_json_or_jsonl, find_question, find_answer 10 | from sft.evaluate import server_inference 11 | from math_utils.utils import compute_score 12 | from tqdm import tqdm 13 | 14 | def self_distillation_sampler(server_url: str, model_name: str, model_path: str, data_path: str, num_samples: int, save_path: str, data_size: int, iteration: int=1): 15 | """ 16 | Self-distillation sampler 17 | The model generates #num_samples samples for each question in the data_path 18 | """ 19 | # check paths 20 | os.makedirs(save_path, exist_ok=True) 21 | # load the dataset 22 | data = read_json_or_jsonl(data_path) 23 | data = data[:data_size] 24 | # generate #num_samples samples for each question in the data_path 25 | correct_replay_buffer_path = os.path.join(save_path, f"iteration_{iteration}_correct_replay_buffer.jsonl") 26 | incorrect_replay_buffer_path = os.path.join(save_path, f"iteration_{iteration}_incorrect_replay_buffer.jsonl") 27 | if os.path.exists(correct_replay_buffer_path.replace(".jsonl", "_last_sample_id.txt")): 28 | with open(correct_replay_buffer_path.replace(".jsonl", "_last_sample_id.txt"), "r") as f: 29 | last_sample_id = int(f.read()) 30 | else: 31 | last_sample_id = 0 32 | if os.path.exists(correct_replay_buffer_path.replace("correct_replay_buffer.jsonl", "accuracy.jsonl")): 33 | with open(correct_replay_buffer_path.replace("correct_replay_buffer.jsonl", "accuracy.jsonl"), "r") as f: 34 | accuracy_data = [json.loads(line) for line in f] 35 | sum_accuracy = sum([data["accuracy"] for data in accuracy_data]) 36 | else: 37 | accuracy_data = [] 38 | sum_accuracy = 0 39 | with tqdm(data[last_sample_id:], total=len(data) - last_sample_id, desc="Generating samples") as pbar: 40 | for i, item in enumerate(pbar, 1): 41 | # generate #num_samples samples 42 | correct = 0 43 | with ThreadPoolExecutor(max_workers=num_samples) as executor: 44 | futures = [executor.submit(server_inference, \ 45 | model_base_url=server_url, \ 46 | model_name=model_name, \ 47 | tokenizer_path=model_path, \ 48 | input=find_question(item), \ 49 | code_mode=True, \ 50 | 
max_tokens=4096, \ 51 | is_ipython=False) for i in range(num_samples)] 52 | outputs = [future.result() for future in futures] 53 | 54 | correct = 0 55 | correct_data_to_save = [] 56 | incorrect_data_to_save = [] 57 | for output in outputs: 58 | # check if the sample is correct 59 | if compute_score(output.split("<|Assistant|>")[-1], find_answer(item)) == 1: 60 | # add the sample to the data 61 | correct += 1 62 | correct_data_to_save.append(output) 63 | else: 64 | incorrect_data_to_save.append(output) 65 | 66 | # save the samples 67 | for data in correct_data_to_save: 68 | with open(correct_replay_buffer_path, "a") as f: 69 | f.write(json.dumps({"idx": i + last_sample_id, "problem": find_question(item), "synthetic_data": data.split("<|Assistant|>")[-1]}) + "\n") 70 | for data in incorrect_data_to_save: 71 | with open(incorrect_replay_buffer_path, "a") as f: 72 | f.write(json.dumps({"idx": i + last_sample_id, "problem": find_question(item), "synthetic_data": data.split("<|Assistant|>")[-1]}) + "\n") 73 | 74 | sum_accuracy += correct / num_samples 75 | accuracy_data.append({"idx": i + last_sample_id, "accuracy": correct / num_samples, "correct": correct, "total": num_samples}) 76 | pbar.set_postfix({"accuracy": f"{sum_accuracy / (i + last_sample_id):.2%}", "correct": correct, "total": num_samples}) 77 | with open(correct_replay_buffer_path.replace(".jsonl", "_last_sample_id.txt"), "w") as f: 78 | f.write(str(i + last_sample_id)) 79 | with open(correct_replay_buffer_path.replace("correct_replay_buffer.jsonl", "accuracy.jsonl"), "w") as f: 80 | for data in accuracy_data: 81 | f.write(json.dumps(data) + "\n") 82 | 83 | print(f"Iteration {iteration}, average accuracy: {sum_accuracy / len(data)}") 84 | 85 | if __name__ == "__main__": 86 | parser = argparse.ArgumentParser() 87 | parser.add_argument("--server_url", type=str, default="http://localhost:8123/v1", help="The server url") 88 | parser.add_argument("--model_name", type=str, default="DeepSeek-R1-Distill-Qwen-7B") 89 | parser.add_argument("--model_path", type=str, default="models/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B") 90 | parser.add_argument("--data_path", type=str, default="dataset/train/whole_training_set.jsonl") 91 | parser.add_argument("--num_samples", type=int, default=16) 92 | parser.add_argument("--save_path", type=str, help="The path to save the self-distillation trajectories", default="dataset/train/self_distillation") 93 | parser.add_argument("--data_size", type=int, default=-1, help="The number of data to sample, -1 means all data") 94 | args = parser.parse_args() 95 | print(args) 96 | self_distillation_sampler(server_url=args.server_url, \ 97 | model_name=args.model_name, \ 98 | model_path=args.model_path, \ 99 | data_path=args.data_path, \ 100 | num_samples=args.num_samples, \ 101 | save_path=args.save_path, \ 102 | data_size=args.data_size) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DualDistill 2 | 3 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](LICENSE) 4 | [![arXiv](https://img.shields.io/badge/arXiv-2507.05707-b31b1b.svg)](https://arxiv.org/abs/2507.05707) 5 | 6 | Official implementation of **DualDistill**: A trajectory-composition distillation method for integrating tool use into long-chain-of-thought reasoning. 7 | 8 | > **Weihua Du, Pranjal Aggarwal, Sean Welleck, & Yiming Yang** 9 | > ["Agentic-R1: Distilled Dual-Strategy Reasoning." 
(2025)](https://arxiv.org/abs/2507.05707)
10 | 
11 | ## Key Features
12 | 
13 | - **Efficient Training**: Integrates tool use into long-chain-of-thought (CoT) reasoning using only 4 × A6000 GPUs
14 | - **Unified Reasoning**: Fuses heterogeneous reasoning traces from multiple teacher models into a single student model
15 | 
16 | <p align="center">
17 |   <img src="fig/overview.png" alt="Overview of DualDistill">
18 |   <em>Overview of DualDistill methodology</em>
19 | </p>
20 | 
21 | ## Datasets
22 | 
23 | | Dataset | Description | Link |
24 | |---------|-------------|------|
25 | | **Training Set** | Complete training dataset with teacher trajectories | [🤗 HuggingFace](https://huggingface.co/datasets/VanishD/DualDistill) |
26 | | **Test Set** | Evaluation benchmarks | `dataset/test/` |
27 | 
28 | ## Results
29 | 
30 | <p align="center">
31 |   <img src="fig/result.png" alt="Performance comparison of Agentic-R1 models">
32 | </p>
33 | 
34 | - **Agentic-R1** demonstrates significant performance gains on **DeepMath-L** and **Combinatorics300**, where both complex reasoning and tool use are crucial for success.
35 | - **Agentic-R1-SD** (Self-Distilled) further enhances performance through our self-distillation approach, consistently outperforming baseline models across nearly all evaluation tasks.
36 | 
37 | ## Quick Start
38 | 
39 | ### Installation
40 | 
41 | 1. **Clone the repository**:
42 | ```bash
43 | git clone https://github.com/StigLidu/DualDistill.git
44 | cd DualDistill
45 | ```
46 | 
47 | 2. **Create environment** (optional but recommended):
48 | ```bash
49 | conda create -n dualdistill python=3.11
50 | conda activate dualdistill
51 | ```
52 | 
53 | 3. **Install dependencies**:
54 | ```bash
55 | pip install -r requirements.txt
56 | pip install flash-attn --no-build-isolation
57 | ```
58 | 
59 | ## Training Pipeline
60 | 
61 | ### Step 1: Model & Data Preparation
62 | 
63 | **Download the base model**:
64 | ```bash
65 | python script/data_script/model_download.py \
66 |     --repo_id deepseek-ai/DeepSeek-R1-Distill-Qwen-7B \
67 |     --local_dir models
68 | ```
69 | 
70 | **Prepare training data**:
71 | ```bash
72 | python script/data_script/teacher_data_download.py
73 | ```
74 | 
75 | ### Step 2: Teacher Distillation
76 | 
77 | Train the student model using teacher trajectories:
78 | ```bash
79 | bash script/sft_script/SFT.sh
80 | ```
81 | 
82 | ### Step 3: Self-Distillation
83 | 
84 | **Start inference server**:
85 | ```bash
86 | bash script/eval_script/start_inference_server.sh [model_path] [display_name] [port]
87 | ```
88 | 
89 | **Sample self-distillation trajectories**:
90 | ```bash
91 | python sft/self_distillation_sampler.py \
92 |     --server_url http://localhost:$port/v1 \
93 |     --model_name [display_name] \
94 |     --model_path [model_path] \
95 |     --save_path [path_to_save_trajectories]
96 | ```
97 | 
98 | **Prepare self-distillation data**:
99 | ```bash
100 | # Extract teacher solutions
101 | python script/data_script/extract_training_solution.py
102 | 
103 | # Construct training dataset
104 | python script/data_script/processing_self_distillation_traj.py
105 | ```
106 | 
107 | **Fine-tune on self-distillation data** (the data paths are set inside the script):
108 | ```bash
109 | bash script/sft_script/expert_iteration.sh [model_path] [save_path]
110 | ```
111 | 
112 | ## Model Evaluation
113 | 
114 | ### Start Inference Server
115 | ```bash
116 | bash script/eval_script/start_inference_server.sh [model_path] [display_name] [port]
117 | ```
118 | 
119 | ### Run Evaluation
120 | ```bash
121 | bash script/eval_script/eval_remote_server.sh \
122 |     [url] [display_name] [data_path] [code_mode] [max_tokens]
123 | ```
124 | 
125 | **Example**:
126 | ```bash
127 | bash script/eval_script/eval_remote_server.sh \
128 |     "http://localhost:8080/v1" "agentic-r1" "dataset/test/math.json" "true" "4096"
129 | ```
130 | 
131 | ## Trained Models
132 | 
133 | | Model | Description | HuggingFace Link |
134 | |-------|-------------|------------------|
135 | | **Agentic-R1-7B** | Base model with teacher distillation | [🤗 Download](https://huggingface.co/VanishD/Agentic-R1) |
136 | | **Agentic-R1-7B-SD** | Enhanced model with self-distillation | [🤗 Download](https://huggingface.co/VanishD/Agentic-R1-SD) |
137 | 
138 | ## ⚠️ Important Notes
139 | 
140 | - **Code Execution Safety**: The evaluation scripts execute model-generated code locally. Only evaluate models that you trust.
141 | - **Inference Config**: If you are using a recent version of vLLM and encounter an error regarding the maximum context length, you may need to modify `model_max_length` in `tokenizer_config.json`.
142 | - **Self-Distillation Warning**: The self-distillation step requires sampling many trajectories and can be time-consuming.
143 | 
144 | ## License
145 | 
146 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
147 | 
148 | ## Acknowledgments
149 | 
150 | We thank the following open-source projects for their foundational contributions:
151 | 
152 | - [OpenHands](https://github.com/All-Hands-AI/OpenHands) - Agent framework
153 | - [DeepMath-103K](https://huggingface.co/datasets/zwhe99/DeepMath-103K) - Mathematical reasoning dataset
154 | - [vLLM](https://github.com/vllm-project/vllm) - High-performance inference engine
155 | 
156 | ## Contact
157 | 
158 | For questions or support, please contact:
159 | 
160 | - **Weihua Du**: [weihuad@cs.cmu.edu](mailto:weihuad@cs.cmu.edu)
161 | 
162 | ## Citation
163 | 
164 | If you find our work useful, please consider citing:
165 | 
166 | ```bibtex
167 | @article{du2025agentic,
168 |   title={Agentic-R1: Distilled Dual-Strategy Reasoning},
169 |   author={Du, Weihua and Aggarwal, Pranjal and Welleck, Sean and Yang, Yiming},
170 |   journal={arXiv preprint arXiv:2507.05707},
171 |   year={2025}
172 | }
173 | ```
174 | 
175 | ---
176 | 
177 | 
178 | <p align="center">⭐ Star us on GitHub if this project helped you!</p>
179 | 
180 | 

--------------------------------------------------------------------------------
/sft/dataloader.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | from tqdm import tqdm
3 | import torch
4 | from torch.nn.utils.rnn import pad_sequence
5 | from transformers import AutoTokenizer
6 | import numpy as np
7 | import sys
8 | import os
9 | import json
10 | 
11 | parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
12 | sys.path.append(parent_dir)
13 | 
14 | from math_utils.utils import SYSTEM_PROMPT_TPL, CODE_INSTRUCTION
15 | from math_utils.format_checking import check_format
16 | 
17 | def find_nth(haystack, needle, n):
18 |     start = haystack.find(needle)
19 |     while start >= 0 and n > 0:
20 |         start = haystack.find(needle, start + len(needle))
21 |         n -= 1
22 |     return start
23 | 
24 | class TrainData(Dataset):
25 |     def __init__(self, data, tokenizer, code_instruction, max_data_count=None, data_seed=42, debug=False, max_length=16384):
26 |         self.tokenizer = tokenizer
27 |         self.items = []
28 |         self.max_length = max_length
29 |         self.total_loss_calculation_token_count = 0
30 |         #TODO: seems like the max length of the model is 16384 because it throws a warning when the length is larger than 16384
31 |         self.debug = debug
32 |         system_prompt_tpl = SYSTEM_PROMPT_TPL.format(code_instruction=code_instruction)
33 | 
34 |         for sample in tqdm(data, desc="Processing data"):
35 |             question = sample["problem"]
36 |             answer = sample["synthetic_data"]
37 | 
38 |             if not check_format(answer):
39 |                 continue
40 | 
41 |             # only keep the content up to </answer>, avoid extra content
42 |             if "</answer>" in answer:
" in answer: 43 | answer = answer.split("")[0] + "" 44 | 45 | system_prompt = system_prompt_tpl 46 | messages = ( 47 | system_prompt + "\n\n<|User|>" + question + 48 | "\n\n<|Assistant|>" + answer 49 | ) 50 | 51 | input_ids = tokenizer.encode(messages) 52 | input_ids.append(tokenizer.eos_token_id) 53 | if len(input_ids) > self.max_length: 54 | # ignore the input_ids 55 | continue 56 | 57 | q_ids = tokenizer.encode( 58 | system_prompt + "\n\n<|User|>" + question + "\n\n<|Assistant|>" 59 | ) 60 | 61 | labels = [-100] * len(q_ids) + input_ids[len(q_ids):] 62 | 63 | #TODO: For expert iteration, the number of and may mismatch, so we need to check the number of and 64 | errors = ["SyntaxError", "Traceback (most recent call last)", "Error: Code execution timed out."] 65 | code_block_count = sample["synthetic_data"].count("
") 66 | code_block_flag = [False] * code_block_count 67 | for c_id in range(code_block_count): 68 | # find the c_id-th and 69 | executor_start = find_nth(sample["synthetic_data"], "", c_id) 70 | executor_end = sample["synthetic_data"].find("", executor_start) 71 | if executor_start == -1 or executor_end == -1: 72 | code_block_flag[c_id] = False 73 | continue 74 | code_block_flag[c_id] = True 75 | for e in errors: 76 | if e in sample["synthetic_data"][executor_start:executor_end]: 77 | code_block_flag[c_id] = False 78 | break 79 | 80 | # do not calculate the loss for the code blocks with errors 81 | decoded_str = "" 82 | for i in range(len(q_ids), len(input_ids)): 83 | decoded_str += tokenizer.decode(input_ids[i]) 84 | if decoded_str.count("") > decoded_str.count("") and code_block_flag[decoded_str.count("") - 1] == False: 85 | labels[i] = -100 86 | 87 | # do not calculate the loss for the executor feedback 88 | decoded_str = "" 89 | for i in range(len(q_ids), len(input_ids)): 90 | decoded_str += tokenizer.decode(input_ids[i]) 91 | if decoded_str.count("") > decoded_str.count(""): 92 | labels[i] = -100 93 | 94 | # do not calculate the loss before the turn-over words 95 | turn_over_words = ["Wait, use text reasoning is too tedious, let's try code reasoning.", \ 96 | "\nWait, the code is not correct, let's try text reasoning.", \ 97 | "\nWait, the code may be incorrect, let's try text reasoning." 98 | ] 99 | turn_over_flag = False 100 | last_occurrence = np.inf 101 | for word in turn_over_words: 102 | if word in answer: 103 | # find the first occurrence of the word 104 | last_o = answer.find(word) 105 | if last_o < last_occurrence: 106 | last_occurrence = last_o 107 | turn_over_flag = True 108 | turn_over_word = word 109 | turn_over_num_token = len(tokenizer.encode(word)) 110 | 111 | if turn_over_flag: 112 | decoded_str = "" 113 | for i in range(len(q_ids), len(input_ids)): 114 | decoded_str += tokenizer.decode(input_ids[i]) 115 | labels[max(0, i - turn_over_num_token - 1)] = -100 # TODO: maybe contains 1 offset error, but it's ok. 
116 | if turn_over_word in decoded_str: 117 | break 118 | 119 | # cache the tensors; change to half / int16 if memory is limited 120 | self.items.append(( 121 | torch.tensor(input_ids, dtype=torch.long), 122 | torch.tensor(labels, dtype=torch.long), 123 | )) 124 | 125 | # loss calculation token count 126 | loss_calculation_token_count = 0 127 | for i in range(len(q_ids), len(input_ids)): 128 | if labels[i] != -100: 129 | loss_calculation_token_count += 1 130 | self.total_loss_calculation_token_count += loss_calculation_token_count 131 | 132 | if self.debug: 133 | # print the str that calculate the loss 134 | for i in range(len(q_ids), len(input_ids)): 135 | if labels[i] != -100: 136 | print(tokenizer.decode(input_ids[i]), end="") 137 | print() 138 | 139 | if max_data_count is None: 140 | max_data_count = len(self.items) 141 | np.random.seed(data_seed) 142 | np.random.shuffle(self.items) 143 | self.items = self.items[:max_data_count] 144 | print(f"Shuffled data with seed {data_seed} and got {len(self.items)} samples") 145 | print(f"Total loss calculation token count: {self.total_loss_calculation_token_count}") 146 | 147 | def __getitem__(self, idx): 148 | return self.items[idx] 149 | 150 | def __len__(self): 151 | return len(self.items) 152 | 153 | @staticmethod 154 | def collate_fn(batch): 155 | input_ids, labels = zip(*batch) 156 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) 157 | labels = pad_sequence(labels, batch_first=True, padding_value=-100) 158 | return {"input_ids": input_ids, "labels": labels} 159 | 160 | if __name__ == "__main__": 161 | data = [ 162 | {"problem": "What is the sum of the first 100 natural numbers?", "synthetic_data": "The sum of the first 100 natural numbers is 5050."}, 163 | {"problem": "What is the sum of the first 100 natural numbers?", "synthetic_data": "The sum of the first 100 natural numbers is 5050. The sum of the first 100 natural numbers is 5050. "}, 164 | {"problem": "What is the sum of the first 100 natural numbers?", "synthetic_data": "The sum of the first 100 natural numbers is 5050. test 1 Traceback (most recent call last) "}, 165 | {"problem": "What is the sum of the first 100 natural numbers?", "synthetic_data": "The sum of the first 100 natural numbers is 5050. test 2 SyntaxError: I think the code is correct? "}, 166 | {"problem": "What is the sum of the first 100 natural numbers?", "synthetic_data": "The sum of the first 100 natural numbers is 5050. test 3 Error: Code execution timed out. I think the code is correct. "}, 167 | {"problem": "What is the sum of the first 100 natural numbers?", "synthetic_data": "The sum of the first 100 natural numbers is 5050. test 4 successful "}, 168 | {"problem": "What is the sum of the first 100 natural numbers?", "synthetic_data": "The sum of the first 100 natural numbers is 5050. hahahaha \nWait, the code is not correct, let's try text reasoning.\n I think the code is correct. "}, 169 | {"problem": "What is the sum of the first 100 natural numbers?", "synthetic_data": "The sum of the first 100 natural numbers is 5050. hahahaha Wait, use text reasoning is too tedious, let's try code reasoning. hahaha here is code! 
"}, 170 | ] 171 | tokenizer = AutoTokenizer.from_pretrained("models/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B") 172 | train_data = TrainData(data, tokenizer, CODE_INSTRUCTION, debug=True) -------------------------------------------------------------------------------- /sft/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import sys 4 | import gc, torch 5 | from torch.utils.data import Dataset 6 | from torch.optim import AdamW 7 | from torch.nn.utils.rnn import pad_sequence 8 | from tqdm import tqdm 9 | from accelerate import Accelerator 10 | from transformers import AutoTokenizer, AutoModelForCausalLM 11 | import time 12 | from typing import List, Union 13 | import numpy as np 14 | import random 15 | 16 | # Optional: Weights & Biases will only be used if a project name is passed. 17 | try: 18 | import wandb 19 | except ImportError: 20 | wandb = None 21 | 22 | parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) 23 | sys.path.append(parent_dir) 24 | 25 | from sft.dataloader import TrainData 26 | from math_utils.utils import read_json_or_jsonl 27 | CODE_INSTRUCTION = """Meanwhile, you can use Python code to help you reasoning. The code should be enclosed within tags. For example, code here . 28 | A executor will run the code and provide feedback immediately after the code. The executor feedback should be enclosed within tags. 29 | You can use the executor feedback to improve your reasoning. 30 | """ 31 | 32 | def train(model_path: str, 33 | data_path: Union[List[str], List[dict]], 34 | epochs: int, 35 | save_path: str, 36 | wandb_run=None, 37 | resume=False, 38 | resume_path=None, 39 | save_interval=1, 40 | batch_size=1, 41 | code_mode=False, 42 | max_data_count=None, 43 | data_seed=42, 44 | max_length=16384, 45 | lr=1e-5, 46 | gradient_accumulation_steps=4): 47 | """Train the model with the given parameters. 48 | 49 | Args: 50 | model_path (str): Path to the model checkpoint or identifier. 51 | data_path (List[str] or List[dict]): Path to the JSONL training data or list of training data. 52 | epochs (int): Number of training epochs. 53 | save_path (str): Where to save the fine-tuned model. 54 | wandb_run (wandb.Run or None): An optional wandb run instance. 
55 | """ 56 | accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps) 57 | os.environ['TOKENIZERS_PARALLELISM'] = 'false' 58 | 59 | model_path = model_path[:-1] if model_path.endswith('/') else model_path 60 | if resume: 61 | load_path = resume_path 62 | start_epoch = int(resume_path.split('_')[-1]) + 1 63 | else: 64 | load_path = model_path 65 | start_epoch = 0 66 | tokenizer = AutoTokenizer.from_pretrained(load_path) 67 | model = AutoModelForCausalLM.from_pretrained( 68 | load_path, 69 | torch_dtype=torch.bfloat16, 70 | attn_implementation="flash_attention_2", 71 | device_map="auto" 72 | ) 73 | model.config.use_cache = False 74 | model.gradient_checkpointing_enable() 75 | 76 | data = [] 77 | for p in data_path: 78 | if isinstance(p, dict): 79 | data.append(p) 80 | else: 81 | data.extend(read_json_or_jsonl(p)) 82 | 83 | if code_mode: 84 | dataset = TrainData(data, tokenizer, CODE_INSTRUCTION, max_data_count=max_data_count, data_seed=data_seed, max_length=max_length) 85 | else: 86 | dataset = TrainData(data, tokenizer, "", max_data_count=max_data_count, data_seed=data_seed, max_length=max_length) 87 | data_loader = torch.utils.data.DataLoader( 88 | dataset, collate_fn=dataset.collate_fn, shuffle=True, 89 | batch_size=batch_size, num_workers=1 90 | ) 91 | optimizer = AdamW(model.parameters(), lr=lr) 92 | model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader) 93 | 94 | global_step = start_epoch * len(data_loader) 95 | for epoch in range(start_epoch, epochs): 96 | accelerator.print(f'Training epoch {epoch}') 97 | accelerator.wait_for_everyone() 98 | model.train() 99 | 100 | tk0 = tqdm(data_loader, total=len(data_loader), disable=not accelerator.is_main_process) 101 | loss_report = [] 102 | 103 | loss_report = [] 104 | grad_acc = accelerator.gradient_accumulation_steps 105 | 106 | for step, batch in enumerate(tk0): 107 | with accelerator.accumulate(model): 108 | outputs = model(**batch) 109 | loss = outputs.loss 110 | accelerator.backward(loss) 111 | 112 | # --- each micro-step gather the loss for statistics --- 113 | loss_val = accelerator.gather(loss.detach()).mean().item() 114 | loss_report.append(loss_val) 115 | 116 | if accelerator.sync_gradients: 117 | accelerator.clip_grad_norm_(model.parameters(), 1.0) 118 | optimizer.step() 119 | optimizer.zero_grad() 120 | # --- WandB loss --- 121 | if wandb_run is not None: 122 | window = loss_report[-grad_acc:] 123 | wandb_run.log( 124 | {"train_loss": sum(window) / len(window), 125 | "epoch": epoch}, 126 | step=global_step 127 | ) 128 | global_step += 1 129 | 130 | # --- average loss --- 131 | tk0.set_postfix(loss=sum(loss_report[-100:]) / len(loss_report[-100:])) 132 | 133 | if (epoch + 1) % save_interval == 0 or epoch == epochs - 1: 134 | accelerator.wait_for_everyone() 135 | unwrapped_model = accelerator.unwrap_model(model) 136 | unwrapped_model.save_pretrained( 137 | f'{save_path}_{epoch}', 138 | is_main_process=accelerator.is_main_process, 139 | save_function=accelerator.save, 140 | ) 141 | tokenizer.save_pretrained(f'{save_path}_{epoch}') 142 | # Clean up 143 | del model, optimizer, data_loader, dataset, tokenizer 144 | 145 | def main(): 146 | import argparse 147 | parser = argparse.ArgumentParser() 148 | parser.add_argument("--model_path", type=str, default='models/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B') 149 | parser.add_argument("--data_path", type=str, nargs='+', default=["data/dataset/train/dual_distill_data.jsonl"]) 150 | parser.add_argument("--epochs", type=int, default=10) 151 | 
parser.add_argument("--resume", action='store_true') 152 | parser.add_argument("--resume_path", type=str, default=None) 153 | parser.add_argument("--code_mode", action='store_true') 154 | parser.add_argument("--save_path", type=str, default=None) 155 | parser.add_argument("--save_interval", type=int, default=5) 156 | parser.add_argument("--batch_size", type=int, default=1) 157 | parser.add_argument("--max_data_count", type=int, default=None) 158 | parser.add_argument("--data_seed", type=int, default=42) 159 | parser.add_argument("--max_length", type=int, default=16384) 160 | parser.add_argument( 161 | "--gradient_accumulation_steps", 162 | type=int, 163 | default=1, 164 | help="Number of batches to accumulate before each optimizer.step()" 165 | ) 166 | parser.add_argument("--lr", type=float, default=1e-5) 167 | # Added W&B arguments 168 | parser.add_argument("--use_wandb", action='store_true', default=False) 169 | parser.add_argument("--wandb_project", type=str, default="dualdistill", 170 | help="If set, will enable wandb logging to the given project.") 171 | parser.add_argument("--wandb_run_name", type=str, default=None, 172 | help="An optional run name for wandb.") 173 | 174 | args = parser.parse_args() 175 | print(args) 176 | 177 | # fix all seeds 178 | torch.manual_seed(args.data_seed) 179 | torch.cuda.manual_seed(args.data_seed) 180 | torch.cuda.manual_seed_all(args.data_seed) 181 | np.random.seed(args.data_seed) 182 | random.seed(args.data_seed) 183 | 184 | assert args.epochs % args.save_interval == 0, "epochs must be divisible by save_interval" 185 | if args.model_path is None and not args.eval_only: 186 | raise ValueError("model_path is required for training") 187 | time_tag = time.strftime("%Y%m%d_%H%M%S", time.localtime()) 188 | if args.resume: 189 | if args.resume_path is None: 190 | raise ValueError("resume_path is required for resume") 191 | time_tag = "_".join(args.resume_path.split("_")[-3:-1]) 192 | print(f"Resuming from {args.resume_path} with time tag {time_tag}") 193 | 194 | if args.save_path is None: 195 | if len(args.data_path) == 1: 196 | args.save_path = args.model_path.strip("/") + "_" + args.data_path[0].strip("/").split("/")[-1].split(".")[0] + "_fine-tuned" + "_" + time_tag 197 | else: 198 | args.save_path = args.model_path.strip("/") + "_" + args.data_path[0].strip("/").split("/")[-2].split(".")[0] + "_mixed_data_" + "fine-tuned" + "_" + time_tag 199 | else: 200 | args.save_path = args.model_path.strip("/") + "_" + args.save_path.strip("/") + "_" + time_tag 201 | 202 | if args.wandb_run_name is None: 203 | args.wandb_run_name = (args.save_path.strip("/").split("/")[-1] + 204 | "_" + str(args.code_mode) + 205 | "_" + str(args.epochs) + 206 | "_" + time_tag) 207 | 208 | # Initialize wandb if user has set a project 209 | wandb_run = None 210 | if args.use_wandb and wandb is not None: 211 | config = vars(args) 212 | wandb_run = wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=config) 213 | 214 | train( 215 | model_path=args.model_path, 216 | data_path=args.data_path, 217 | epochs=args.epochs, 218 | save_path=args.save_path, 219 | wandb_run=wandb_run, 220 | resume=args.resume, 221 | resume_path=args.resume_path, 222 | save_interval=args.save_interval, 223 | batch_size=args.batch_size, 224 | code_mode=args.code_mode, 225 | max_data_count=args.max_data_count, 226 | data_seed=args.data_seed, 227 | max_length=args.max_length, 228 | lr=args.lr, 229 | gradient_accumulation_steps=args.gradient_accumulation_steps 230 | ) 231 | 232 | if wandb_run is not 
None: 233 | wandb_run.finish() 234 | 235 | if __name__ == '__main__': 236 | main() -------------------------------------------------------------------------------- /dataset/test/aime2025.jsonl: -------------------------------------------------------------------------------- 1 | {"data_source":"aime2025","prompt":[{"content":"Find the sum of all integer bases $b>9$ for which $17_{b}$ is a divisor of $97_{b}$.","role":"user"}],"ability":"math","reward_model":{"ground_truth":"70","num_tokens":-512,"style":"rule"},"extra_info":{"index":0,"split":"test"}} 2 | {"data_source":"aime2025","prompt":[{"content":"On $\\triangle ABC$ points $A,D,E$, and $B$ lie that order on side $\\overline{AB}$ with $AD=4, DE=16$, and $EB=8$. Points $A,F,G$, and $C$ lie in that order on side $\\overline{AC}$ with $AF=13, FG=52$, and $GC=26$. Let $M$ be the reflection of $D$ through $F$, and let $N$ be the reflection of $G$ through $E$. Quadrilateral $DEGF$ has area 288. Find the area of heptagon $AFNBCEM$.","role":"user"}],"ability":"math","reward_model":{"ground_truth":"588","num_tokens":-512,"style":"rule"},"extra_info":{"index":1,"split":"test"}} 3 | {"data_source":"aime2025","prompt":[{"content":"The 9 members of a baseball team went to an ice cream parlor after their game. Each player had a singlescoop cone of chocolate, vanilla, or strawberry ice cream. At least one player chose each flavor, and the number of players who chose chocolate was greater than the number of players who chose vanilla, which was greater than the number of players who chose strawberry. Let $N$ be the number of different assignments of flavors to players that meet these conditions. Find the remainder when $N$ is divided by 1000.","role":"user"}],"ability":"math","reward_model":{"ground_truth":"16","num_tokens":-512,"style":"rule"},"extra_info":{"index":2,"split":"test"}} 4 | {"data_source":"aime2025","prompt":[{"content":"Find the number of ordered pairs $(x,y)$, where both $x$ and $y$ are integers between $-100$ and $100$, inclusive, such that $12x^{2}-xy-6y^{2}=0$.","role":"user"}],"ability":"math","reward_model":{"ground_truth":"117","num_tokens":-512,"style":"rule"},"extra_info":{"index":3,"split":"test"}} 5 | {"data_source":"aime2025","prompt":[{"content":"There are $8!=40320$ eight-digit positive integers that use each of the digits $1,2,3,4,5,6,7,8$ exactly once. Let $N$ be the number of these integers that are divisible by 22. Find the difference between $N$ and 2025.","role":"user"}],"ability":"math","reward_model":{"ground_truth":"279","num_tokens":-512,"style":"rule"},"extra_info":{"index":4,"split":"test"}} 6 | {"data_source":"aime2025","prompt":[{"content":"An isosceles trapezoid has an inscribed circle tangent to each of its four sides. The radius of the circle is 3, and the area of the trapezoid is 72. Let the parallel sides of the trapezoid have lengths $r$ and $s$, with $r \\neq s$. Find $r^{2}+s^{2}$.","role":"user"}],"ability":"math","reward_model":{"ground_truth":"504","num_tokens":-512,"style":"rule"},"extra_info":{"index":5,"split":"test"}} 7 | {"data_source":"aime2025","prompt":[{"content":"The twelve letters $A,B,C,D,E,F,G,H,I,J,K$, and $L$ are randomly grouped into six pairs of letters. The two letters in each pair are placed next to each other in alphabetical order to form six two-letter words, and those six words are listed alphabetically. For example, a possible result is $AB,CJ,DG,EK,FL,HI$. 
The probability that the last word listed contains $G$ is $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. Find $m+n$.","role":"user"}],"ability":"math","reward_model":{"ground_truth":"821","num_tokens":-512,"style":"rule"},"extra_info":{"index":6,"split":"test"}} 8 | {"data_source":"aime2025","prompt":[{"content":"Let $k$ be real numbers such that the system $|25+20i-z|=5$ and $|z-4-k|=|z-3i-k|$ has exactly one complex solution $z$. The sum of all possible values of $k$ can be written as $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. Find $m+n$. Here $i=\\sqrt{-1}$.","role":"user"}],"ability":"math","reward_model":{"ground_truth":"77","num_tokens":-512,"style":"rule"},"extra_info":{"index":7,"split":"test"}} 9 | {"data_source":"aime2025","prompt":[{"content":"The parabola with equation $y=x^{2}-4$ is rotated $60^{\\circ}$ counterclockwise around the origin. The unique point in the fourth quadrant where the original parabola and its image intersect has $y$-coordinate $\\frac{a-\\sqrt{b}}{c}$, where $a$, $b$, and $c$ are positive integers, and $a$ and $c$ are relatively prime. Find $a+b+c$.","role":"user"}],"ability":"math","reward_model":{"ground_truth":"62","num_tokens":-512,"style":"rule"},"extra_info":{"index":8,"split":"test"}} 10 | {"data_source":"aime2025","prompt":[{"content":"The 27 cells of a $3\\times9$ grid are filled in using the numbers 1 through 9 so that each row contains 9 different numbers, and each of the three $3\\times3$ blocks heavily outlined in the example below contains 9 different numbers, as in the first three rows of a Sudoku puzzle. \n | 4 | 2 | 8 | 9 | 6 | 3 | 1 | 7 | 5 | \n | 3 | 7 | 9 | 5 | 2 | 1 | 6 | 8 | 4 | \n | 5 | 6 | 1 | 8 | 4 | 7 | 9 | 2 | 3 | \n The number of different ways to fill such a grid can be written as $p^a\\cdot q^b\\cdot r^c\\cdot s^d$, where $p,q,r,$ and $s$ are distinct prime numbers and $a,b,c,$ and $d$ are positive integers. Find $p\\cdot a+q\\cdot b+r\\cdot c+s\\cdot d$.","role":"user"}],"ability":"math","reward_model":{"ground_truth":"81","num_tokens":-512,"style":"rule"},"extra_info":{"index":9,"split":"test"}} 11 | {"data_source":"aime2025","prompt":[{"content":"A piecewise linear periodic function is defined by $f(x)=\\begin{cases}x&\\text{if }x\\in[-1,1)\\\\2-x&\\text{if }x\\in[1,3)\\end{cases}$ and $f(x+4)=f(x)$ for all real numbers $x$. The graph of $f(x)$ has the sawtooth pattern. The parabola $x=34y^2$ intersects the graph of $f(x)$ at finitely many points. The sum of the $y$-coordinates of these intersection points can be expressed in the form $\\frac{a+b\\sqrt{c}}{d}$, where $a,b,c,$ and $d$ are positive integers, $a,b,$ and $d$ have greatest common divisor equal to 1, and $c$ is not divisible by the square of any prime. 
Find $a+b+c+d$.","role":"user"}],"ability":"math","reward_model":{"ground_truth":"259","num_tokens":-512,"style":"rule"},"extra_info":{"index":10,"split":"test"}} 12 | {"data_source":"aime2025","prompt":[{"content":"The set of points in 3-dimensional coordinate space that lie in the plane $x+y+z=75$ whose coordinates satisfy the inequalities $x-yz>')\n" + raw_code 53 | 54 | # incremental code execution 55 | os.makedirs("workspace", exist_ok=True) 56 | with tempfile.NamedTemporaryFile("w", dir="workspace", suffix=".py", delete=False, encoding='utf-8') as tmp_file: 57 | tmp_file.write(raw_code) 58 | tmp_filename = tmp_file.name 59 | try: 60 | result = subprocess.run( 61 | ["python", tmp_filename], 62 | capture_output=True, 63 | text=True, 64 | timeout=3 65 | ) 66 | os.remove(tmp_filename) 67 | output = result.stdout.split("<>\n")[-1] 68 | if result.stderr: 69 | return output, result.stderr, previous_code 70 | else: 71 | return output, result.stderr, previous_code + "\n" + raw_code 72 | except subprocess.TimeoutExpired: 73 | os.remove(tmp_filename) 74 | return "", "Error: Code execution timed out.", previous_code 75 | 76 | def code_block_with_io(raw_code: str): 77 | code = raw_code.replace("```python", "").replace("```", "") 78 | 79 | stdout_capture = io.StringIO() 80 | stderr_capture = io.StringIO() 81 | 82 | env = {} 83 | 84 | try: 85 | tree = ast.parse(code, mode='exec') 86 | 87 | with contextlib.redirect_stdout(stdout_capture), contextlib.redirect_stderr(stderr_capture): 88 | for node in tree.body: 89 | if isinstance(node, ast.Expr): 90 | expr_code = compile(ast.Expression(node.value), filename="", mode="eval") 91 | result = eval(expr_code, env) 92 | if result is not None: 93 | print(result) 94 | else: 95 | stmt_code = compile(ast.Module([node], []), filename="", mode="exec") 96 | exec(stmt_code, env) 97 | except Exception as e: 98 | stderr_capture.write(str(e)) 99 | 100 | return stdout_capture.getvalue(), stderr_capture.getvalue() 101 | 102 | def parse_boxed(text, reverse=False): 103 | """ 104 | Returns a list of all the contents inside \\boxed{...} in `text`, 105 | handling nested braces to a reasonable extent. 
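    A single match is returned: the first by default, or the last when
    reverse=True. The string "No Answer" is returned if no \\boxed{...} is
    found. For example, parse_boxed("thus \\boxed{42}") returns "42".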
106 | """ 107 | results = [] 108 | search_start = 0 109 | marker = r'\boxed{' 110 | 111 | while True: 112 | # Look for the next occurrence of \boxed{ 113 | start_index = text.find(marker, search_start) 114 | if start_index == -1: 115 | # No more \boxed{ found 116 | break 117 | 118 | # The position right after '\boxed{' 119 | brace_start = start_index + len(marker) 120 | 121 | # Use a stack to find the matching '}' 122 | brace_count = 1 123 | pos = brace_start 124 | 125 | while pos < len(text) and brace_count > 0: 126 | if text[pos] == '{': 127 | brace_count += 1 128 | elif text[pos] == '}': 129 | brace_count -= 1 130 | pos += 1 131 | 132 | # If brace_count == 0, 'pos-1' is where the matching '}' was found 133 | if brace_count == 0: 134 | content = text[brace_start : pos - 1] 135 | results.append(content) 136 | # Continue searching after this boxed content 137 | search_start = pos 138 | else: 139 | # We reached the end of the text without finding a matching brace 140 | break 141 | if len(results) == 0: 142 | return "No Answer" 143 | if not reverse: 144 | return results[0] 145 | else: 146 | return results[-1] 147 | 148 | def read_json_or_jsonl(file_path): 149 | with open(file_path, 'r') as f: 150 | if file_path.endswith('.json'): 151 | return json.load(f) 152 | elif file_path.endswith('.jsonl'): 153 | return [json.loads(line) for line in f] 154 | 155 | SYSTEM_PROMPT_TPL = ( 156 | "A conversation between User and Assistant. The user asks a question, " 157 | "and the Assistant solves it.\n" 158 | "The assistant first thinks about the reasoning process in the mind and then " 159 | "provides the user with the answer. \n" 160 | "The reasoning process and answer are enclosed within and " 161 | " tags, respectively, i.e., reasoning process here " 162 | " answer here .\n\n" 163 | "The final answer should be enclosed within \\boxed tags, i.e., " 164 | "\\boxed{{answer here}}.\n\n" 165 | "{code_instruction}\n\n" 166 | "Do not write text outside the tags." 167 | ) 168 | 169 | CODE_INSTRUCTION = """Meanwhile, you can use Python code to help you reason. The code should be enclosed within tags. For example, code here . 170 | An executor will run the code and provide feedback immediately after the code. The executor feedback should be enclosed within tags. 171 | You can use the executor feedback to improve your reasoning. 172 | """ 173 | 174 | def add_comma_into_number(number_str): 175 | try: 176 | number = float(number_str) 177 | if number.is_integer(): 178 | return "{:,}".format(int(number)) 179 | else: 180 | return "{:,}".format(number) 181 | except (ValueError, TypeError): 182 | return number_str 183 | 184 | def compute_score(solution_str, ground_truth, able_to_use_original_solution=False) -> float: 185 | ground_truth = str(ground_truth) 186 | 187 | # for non-finetuned model (i.e., except Agentic-R1), use additional pass by the original solution 188 | original_solution_str = solution_str 189 | 190 | # For the agentic trajectory, remove the redundant part 191 | if "" in solution_str: 192 | interim_strs = ["", "", "", ""] 193 | found = [(solution_str.rfind(s), i) for i, s in enumerate(interim_strs) if solution_str.rfind(s) != -1] 194 | if found: 195 | last_interim_str, last_interim_str_idx = max(found) 196 | solution_str = solution_str[:last_interim_str + len(interim_strs[last_interim_str_idx])] 197 | 198 | if "" in solution_str and "" in solution_str: 199 | # for the case that the model output a final answer within ... 
200 | solution_str_parsed = solution_str.split("<answer>")[1].split("</answer>")[0] 201 | if ground_truth in solution_str_parsed or add_comma_into_number(ground_truth) in solution_str_parsed: 202 | return 1. 203 | 204 | elif "<interpreter>" in solution_str and "</interpreter>" in solution_str: 205 | # otherwise, try to parse the last output from the executor 206 | solution_str_parsed = solution_str.split("<interpreter>")[-1].split("</interpreter>")[0] 207 | if ground_truth in solution_str_parsed or add_comma_into_number(ground_truth) in solution_str_parsed: 208 | return 1. 209 | 210 | elif "</think>" in solution_str and "<answer>" not in solution_str: 211 | # for the case that the model outputs <think> ... </think> followed by the answer, i.e., the text-reasoning model 212 | solution_str_parsed = solution_str.split("</think>")[-1] 213 | if ground_truth in solution_str_parsed or add_comma_into_number(ground_truth) in solution_str_parsed: 214 | return 1. 215 | 216 | # fuzzy match the ground truth and the solution 217 | retval = 0. 218 | if '\\boxed' not in ground_truth: ground_truth = '\\boxed{' + str(ground_truth) + '}' 219 | correct = verify(parse(ground_truth), parse(solution_str)) or (able_to_use_original_solution and verify(parse(ground_truth), parse(original_solution_str))) 220 | if correct: 221 | retval = 1. 222 | 223 | return retval 224 | 225 | # string normalization from https://github.com/EleutherAI/lm-evaluation-harness/blob/master/lm_eval/tasks/hendrycks_math.py 226 | def is_equiv(str1, str2, verbose=False): 227 | if str1 is None and str2 is None: 228 | print("WARNING: Both None") 229 | return True 230 | if str1 is None or str2 is None: 231 | return False 232 | 233 | try: 234 | ss1 = strip_string(str1) 235 | ss2 = strip_string(str2) 236 | if verbose: 237 | print(ss1, ss2) 238 | return ss1 == ss2 239 | except Exception: 240 | return str1 == str2 241 | 242 | 243 | def remove_boxed(s): 244 | if "\\boxed " in s: 245 | left = "\\boxed " 246 | assert s[:len(left)] == left 247 | return s[len(left):] 248 | 249 | left = "\\boxed{" 250 | 251 | assert s[:len(left)] == left 252 | assert s[-1] == "}" 253 | 254 | return s[len(left):-1] 255 | 256 | 257 | def last_boxed_only_string(string): 258 | idx = string.rfind("\\boxed") 259 | if "\\boxed " in string: 260 | return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0] 261 | if idx < 0: 262 | idx = string.rfind("\\fbox") 263 | if idx < 0: 264 | return None 265 | 266 | i = idx 267 | right_brace_idx = None 268 | num_left_braces_open = 0 269 | while i < len(string): 270 | if string[i] == "{": 271 | num_left_braces_open += 1 272 | if string[i] == "}": 273 | num_left_braces_open -= 1 274 | if num_left_braces_open == 0: 275 | right_brace_idx = i 276 | break 277 | i += 1 278 | 279 | if right_brace_idx is None: 280 | retval = None 281 | else: 282 | retval = string[idx:right_brace_idx + 1] 283 | 284 | return retval 285 | 286 | 287 | def fix_fracs(string): 288 | substrs = string.split("\\frac") 289 | new_str = substrs[0] 290 | if len(substrs) > 1: 291 | substrs = substrs[1:] 292 | for substr in substrs: 293 | new_str += "\\frac" 294 | if substr[0] == "{": 295 | new_str += substr 296 | else: 297 | try: 298 | assert len(substr) >= 2 299 | except AssertionError: 300 | return string 301 | a = substr[0] 302 | b = substr[1] 303 | if b != "{": 304 | if len(substr) > 2: 305 | post_substr = substr[2:] 306 | new_str += "{" + a + "}{" + b + "}" + post_substr 307 | else: 308 | new_str += "{" + a + "}{" + b + "}" 309 | else: 310 | if len(substr) > 2: 311 | post_substr = substr[2:] 312 | new_str += "{" + a + "}" + b + post_substr 313 | else: 314 | new_str += "{" + a + "}"
+ b 315 | string = new_str 316 | return string 317 | 318 | 319 | def fix_a_slash_b(string): 320 | if len(string.split("/")) != 2: 321 | return string 322 | a = string.split("/")[0] 323 | b = string.split("/")[1] 324 | try: 325 | a = int(a) 326 | b = int(b) 327 | assert string == "{}/{}".format(a, b) 328 | new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" 329 | return new_string 330 | except (ValueError, AssertionError): 331 | return string 332 | 333 | 334 | def remove_right_units(string): 335 | # "\\text{ " only ever occurs (at least in the val set) when describing units 336 | if "\\text{ " in string: 337 | splits = string.split("\\text{ ") 338 | assert len(splits) == 2 339 | return splits[0] 340 | else: 341 | return string 342 | 343 | 344 | def fix_sqrt(string): 345 | if "\\sqrt" not in string: 346 | return string 347 | splits = string.split("\\sqrt") 348 | new_string = splits[0] 349 | for split in splits[1:]: 350 | if split[0] != "{": 351 | a = split[0] 352 | new_substr = "\\sqrt{" + a + "}" + split[1:] 353 | else: 354 | new_substr = "\\sqrt" + split 355 | new_string += new_substr 356 | return new_string 357 | 358 | 359 | def strip_string(string): 360 | # linebreaks 361 | string = string.replace("\n", "") 362 | 363 | # remove inverse spaces 364 | string = string.replace("\\!", "") 365 | 366 | # replace \\ with \ 367 | string = string.replace("\\\\", "\\") 368 | 369 | # replace tfrac and dfrac with frac 370 | string = string.replace("tfrac", "frac") 371 | string = string.replace("dfrac", "frac") 372 | 373 | # remove \left and \right 374 | string = string.replace("\\left", "") 375 | string = string.replace("\\right", "") 376 | 377 | # Remove circ (degrees) 378 | string = string.replace("^{\\circ}", "") 379 | string = string.replace("^\\circ", "") 380 | 381 | # remove dollar signs 382 | string = string.replace("\\$", "") 383 | 384 | # remove units (on the right) 385 | string = remove_right_units(string) 386 | 387 | # remove percentage 388 | string = string.replace("\\%", "") 389 | string = string.replace("\%", "") # noqa: W605 390 | 391 | # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string 392 | string = string.replace(" .", " 0.") 393 | string = string.replace("{.", "{0.") 394 | # if empty, return empty string 395 | if len(string) == 0: 396 | return string 397 | if string[0] == ".": 398 | string = "0" + string 399 | 400 | # to consider: get rid of e.g. "k = " or "q = " at beginning 401 | if len(string.split("=")) == 2: 402 | if len(string.split("=")[0]) <= 2: 403 | string = string.split("=")[1] 404 | 405 | # fix sqrt3 --> sqrt{3} 406 | string = fix_sqrt(string) 407 | 408 | # remove spaces 409 | string = string.replace(" ", "") 410 | 411 | # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1).
Also does a/b --> \\frac{a}{b} 412 | string = fix_fracs(string) 413 | 414 | # manually change 0.5 --> \frac{1}{2} 415 | if string == "0.5": 416 | string = "\\frac{1}{2}" 417 | 418 | # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y 419 | string = fix_a_slash_b(string) 420 | 421 | return string 422 | 423 | if __name__ == "__main__": 424 | code_output, error, previous_code = code_block("a=5\nprint(a*2)", is_ipython=False, previous_code="") 425 | print(code_output, error) 426 | code_output, error, previous_code = code_block("b=a*2\nprint(b + 1)", is_ipython=False, previous_code=previous_code) 427 | print(code_output, error) 428 | code_output, error, previous_code = code_block("import time\ntime.sleep(5)\nprint('ok')", is_ipython=False, previous_code=previous_code) 429 | print(code_output, error) 430 | code_output, error, previous_code = code_block("print(b * 2)", is_ipython=False, previous_code=previous_code) 431 | print(code_output, error) 432 | code_output, error, previous_code = code_block("print(c * 2)", is_ipython=False, previous_code=previous_code) 433 | print(code_output, error) 434 | code_output, error, previous_code = code_block("print(b * 2)", is_ipython=False, previous_code=previous_code) 435 | print(code_output, error) 436 | -------------------------------------------------------------------------------- /sft/evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import json 5 | from vllm import LLM, SamplingParams 6 | import sys 7 | from tqdm import tqdm 8 | from openai import OpenAI 9 | from transformers import AutoTokenizer 10 | 11 | parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) 12 | sys.path.append(parent_dir) 13 | 14 | from math_utils.utils import compute_score 15 | 16 | from math_utils.utils import read_json_or_jsonl, code_block, find_question, find_answer 17 | from concurrent.futures import ThreadPoolExecutor 18 | def chunks_to_text(chunks): 19 | count = 0 20 | text = "" 21 | for chunk in chunks: 22 | text += chunk.choices[0].text 23 | count += 1 24 | return text, count 25 | 26 | def get_token_length(text, tokenizer): 27 | return len(tokenizer.encode(text, add_special_tokens=False)) 28 | 29 | def server_inference(model_base_url: str, model_name: str, tokenizer_path: str, input: str, code_mode=False, max_tokens=4096, is_ipython=False): 30 | """Perform inference through a vLLM API server. 31 | 32 | Args: 33 | model_base_url (str): The base URL of the model. 34 | model_name (str): The name of the model. 35 | tokenizer_path (str): The path to the tokenizer. 36 | input (str): The question to solve. 37 | code_mode (bool): Whether to run the special code-block logic. 38 | max_tokens (int): Maximum tokens to generate. 39 | is_ipython (bool): Whether to run the special ipython logic. 40 | Returns: 41 | str: The generated text. 42 | """ 43 | 44 | # Modify OpenAI's API key and API base to use vLLM's API server. 45 | openai_api_key = "EMPTY" # TODO: remove this or set it to the correct value 46 | openai_api_base = model_base_url 47 | client = OpenAI( 48 | base_url=openai_api_base, 49 | api_key=openai_api_key, 50 | ) 51 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) 52 | CODE_INSTRUCTION = """Meanwhile, you can use Python code to help you reason. The code should be enclosed within <code> </code> tags. For example, <code> code here </code>. 53 | An executor will run the code and provide feedback immediately after the code.
The executor feedback should be enclosed within <interpreter> </interpreter> tags. 54 | You can use the executor feedback to improve your reasoning. 55 | """ 56 | 57 | input_format = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. 58 | The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. 59 | The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>. 60 | The final answer should be enclosed within \\boxed tags, i.e., \\boxed{{answer here}}. 61 | {code_instruction} 62 | 63 | <|User|>{problem} 64 | 65 | <|Assistant|> 66 | """ 67 | input_prompt = input_format.format( 68 | problem=input, 69 | code_instruction=CODE_INSTRUCTION if code_mode else "" 70 | ) 71 | raw_input_prompt = input_prompt 72 | 73 | output = client.completions.create( 74 | model=model_name, 75 | prompt=input_prompt, 76 | max_tokens=max_tokens, 77 | temperature=0.6, 78 | extra_body={"stop": ["</code>"], "include_stop_str_in_output": True, "skip_special_tokens": False}, 79 | stream=True, 80 | ) 81 | llm_output, token_length = chunks_to_text(output) 82 | total_inference_length = token_length 83 | previous_code = "" 84 | # Keep going when we see '</code>' in the last chunk 85 | code_use_count = 0 86 | while "</code>" in llm_output and total_inference_length < max_tokens and "</answer>" not in llm_output and code_use_count < 10: 87 | code_use_count += 1 88 | try: 89 | raw_code = llm_output.split("<code>")[-1].split("</code>")[0] 90 | code_output, error, previous_code = code_block(raw_code, is_ipython=is_ipython, previous_code=previous_code) 91 | except Exception as e: 92 | error = str(e) 93 | code_output = "Error" 94 | if error: 95 | executor_feedback = f"\n<interpreter>\n{error}\n</interpreter>\n" 96 | else: 97 | executor_feedback = f"\n<interpreter>\n{code_output}\n</interpreter>\n" 98 | 99 | llm_output = ( 100 | llm_output.split("</code>")[0] 101 | + "</code>" 102 | + executor_feedback 103 | ) 104 | total_inference_length += len(tokenizer.encode(executor_feedback, add_special_tokens=False)) 105 | if total_inference_length >= max_tokens: 106 | break 107 | 108 | input_prompt = input_prompt + llm_output 109 | output = client.completions.create( 110 | model=model_name, 111 | prompt=input_prompt, 112 | max_tokens=max_tokens - total_inference_length, 113 | temperature=0.6, 114 | extra_body={"stop": ["</code>"], "include_stop_str_in_output": True, "skip_special_tokens": False}, 115 | stream=True, 116 | ) 117 | llm_output, token_length = chunks_to_text(output) 118 | total_inference_length += token_length 119 | 120 | if code_use_count == 10 and total_inference_length < max_tokens: 121 | # exceeded the max number of code uses 122 | input_prompt = input_prompt + llm_output + "\n<interpreter>\nThe code use count has exceeded the limit.
Please stop using code.\n</interpreter>\n" 123 | output = client.completions.create( 124 | model=model_name, 125 | prompt=input_prompt, 126 | max_tokens=max_tokens - total_inference_length, 127 | temperature=0.6, 128 | extra_body={"stop": ["</code>"], "include_stop_str_in_output": True, "skip_special_tokens": False}, 129 | stream=True, 130 | ) 131 | llm_output, token_length = chunks_to_text(output) 132 | total_inference_length += token_length 133 | 134 | whole_output = (input_prompt + llm_output)[len(raw_input_prompt):] 135 | 136 | # truncate the output to max_tokens 137 | whole_output_tokens = tokenizer.encode(whole_output, add_special_tokens=False) 138 | whole_output_tokens = whole_output_tokens[:max_tokens] 139 | whole_output = tokenizer.decode(whole_output_tokens, skip_special_tokens=True) 140 | return raw_input_prompt + whole_output 141 | 142 | def inference(model: LLM, input: str, code_mode=False, max_tokens=4096, is_ipython=False, model_name = None, tokenizer_path = None): 143 | """Perform inference with vLLM. 144 | 145 | Args: 146 | model (LLM): The vLLM instance. 147 | input (str): The question to solve. 148 | code_mode (bool): Whether to run the special code-block logic. 149 | max_tokens (int): Maximum tokens to generate. 150 | is_ipython (bool): Whether to run the special ipython logic. 151 | Returns: 152 | str: The generated text. 153 | """ 154 | 155 | CODE_INSTRUCTION = """Meanwhile, you can use Python code to help you reason. The code should be enclosed within <code> </code> tags. For example, <code> code here </code>. 156 | An executor will run the code and provide feedback immediately after the code. The executor feedback should be enclosed within <interpreter> </interpreter> tags. 157 | You can use the executor feedback to improve your reasoning. 158 | """ 159 | 160 | input_format = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. 161 | The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. 162 | The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>. 163 | The final answer should be enclosed within \\boxed tags, i.e., \\boxed{{answer here}}.
164 | {code_instruction} 165 | 166 | <|User|>{problem} 167 | 168 | <|Assistant|> 169 | """ 170 | input_prompt = input_format.format( 171 | problem=input, 172 | code_instruction=CODE_INSTRUCTION if code_mode else "" 173 | ) 174 | input_prefix_tokens = model.get_tokenizer().encode(input_prompt) 175 | prefix_length = len(input_prefix_tokens) 176 | 177 | if code_mode: 178 | total_inference_length = 0 179 | sampling_params = SamplingParams( 180 | max_tokens=max_tokens, 181 | temperature=0.6, 182 | stop=["</code>"], 183 | include_stop_str_in_output=True, 184 | skip_special_tokens=False 185 | ) 186 | output = model.generate( 187 | prompts=[{"prompt_token_ids": input_prefix_tokens}], 188 | sampling_params=sampling_params, 189 | use_tqdm=False 190 | ) 191 | llm_output = output[0].outputs[0].text 192 | total_inference_length += len(output[0].outputs[0].token_ids) 193 | 194 | previous_code = "" 195 | # Keep going when we see '</code>' in the last chunk 196 | code_use_count = 0 197 | while "</code>" in llm_output and total_inference_length <= max_tokens and "</answer>" not in llm_output and code_use_count < 10: 198 | code_use_count += 1 199 | try: 200 | raw_code = llm_output.split("<code>")[-1].split("</code>")[0] 201 | code_output, error, previous_code = code_block(raw_code, is_ipython=is_ipython, previous_code=previous_code) 202 | except Exception as e: 203 | error = str(e) 204 | code_output = "Error" 205 | 206 | if error: 207 | executor_feedback = f"\n<interpreter>\n{error}\n</interpreter>\n" 208 | else: 209 | executor_feedback = f"\n<interpreter>\n{code_output}\n</interpreter>\n" 210 | 211 | llm_output = ( 212 | llm_output.split("</code>")[0] 213 | + "</code>" 214 | + executor_feedback 215 | ) 216 | total_inference_length += len( 217 | model.get_tokenizer().encode(executor_feedback, add_special_tokens=False) 218 | ) 219 | if total_inference_length > max_tokens: 220 | break 221 | input_prefix_tokens.extend( 222 | model.get_tokenizer().encode(llm_output, add_special_tokens=False) 223 | ) 224 | output = model.generate( 225 | prompts=[{"prompt_token_ids": input_prefix_tokens}], 226 | sampling_params=sampling_params, 227 | use_tqdm=False 228 | ) 229 | llm_output = output[0].outputs[0].text 230 | total_inference_length += len(output[0].outputs[0].token_ids) 231 | 232 | if total_inference_length > max_tokens: 233 | # TODO: check if the output is truncated 234 | # If we reached the token limit, force close the reasoning 235 | input_prefix_tokens.extend( 236 | model.get_tokenizer().encode(llm_output, add_special_tokens=False) 237 | ) 238 | input_prefix_tokens = input_prefix_tokens[: max_tokens + prefix_length] 239 | llm_output = "" 240 | else: 241 | sampling_params = SamplingParams(max_tokens=max_tokens, temperature=0.6) 242 | output = model.generate( 243 | prompts=[{"prompt_token_ids": input_prefix_tokens}], 244 | sampling_params=sampling_params, 245 | use_tqdm=False 246 | ) 247 | llm_output = output[0].outputs[0].text 248 | 249 | whole_output = ( 250 | model.get_tokenizer().decode(input_prefix_tokens, skip_special_tokens=True) 251 | + llm_output 252 | ) 253 | return whole_output 254 | 255 | def llm_eval(model_path: str, 256 | data_path: str = None, 257 | prompt: str = None, 258 | code_mode: bool = False, 259 | max_tokens: int = 4096, 260 | generation_save_path: str = "result", 261 | is_ipython: bool = False, 262 | use_server_inference: bool = False, 263 | model_name = "DeepSeek-R1-Distill-Qwen-7B_finetuned", 264 | tokenizer_path = "models/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", 265 | num_samples: int = 1): 266 | """Evaluate a model on a dataset or a single prompt.
267 | 268 | Args: 269 | model_path (str): Path to the model, or the server base URL when use_server_inference is True. 270 | data_path (str, optional): JSON data file path for batch evaluation. Defaults to None. 271 | prompt (str, optional): A single prompt for inference. Defaults to None. 272 | code_mode (bool): Whether to parse code blocks in the response. 273 | max_tokens (int): Maximum tokens to generate. 274 | generation_save_path (str): Folder to store output text or results. 275 | num_samples (int): Number of samples to draw per problem (server inference only). 276 | """ 277 | # derive a display name for the output folder 278 | if not use_server_inference: 279 | display_name = model_path.strip("/").split("/")[-1] 280 | else: 281 | display_name = model_name 282 | 283 | if "/" not in generation_save_path: 284 | generation_save_path = os.path.join( 285 | generation_save_path, 286 | time.strftime("%Y%m%d_%H%M%S", time.localtime()) 287 | + ("_" + display_name) 288 | + ("_" + data_path.strip("/").split("/")[-1] if data_path else "") 289 | + ("_" + str(code_mode)) 290 | + ("_" + str(num_samples)) 291 | + ("_" + str(max_tokens)) 292 | ) 293 | if not os.path.exists(generation_save_path): 294 | os.makedirs(generation_save_path) 295 | if use_server_inference: 296 | resume_id = [] 297 | else: 298 | resume_id = 0 299 | else: 300 | # resume generation 301 | file_list = os.listdir(generation_save_path) 302 | file_list = [f for f in file_list if f.endswith(".txt")] 303 | # find the biggest number 304 | ids = [int(f.split(".")[0]) for f in file_list] 305 | if len(ids) > 0: 306 | if use_server_inference: 307 | resume_id = ids 308 | else: 309 | resume_id = (max(ids) + 1) // num_samples 310 | else: 311 | if use_server_inference: 312 | resume_id = [] 313 | else: 314 | resume_id = 0 315 | 316 | if not use_server_inference: 317 | assert num_samples == 1, "num_samples must be 1 when use_server_inference is False" 318 | model = LLM(model=model_path, tokenizer=model_path, swap_space=8, tensor_parallel_size=1, gpu_memory_utilization=0.8) 319 | # it is set to 1 to avoid the fork or spawn error 320 | 321 | # TODO: the format is for Qwen model, we need to change it to the format for other models 322 | assert data_path is not None or prompt is not None, "Must provide either data_path or prompt."
323 | 324 | if prompt is not None: 325 | result = inference( 326 | model, 327 | prompt, 328 | code_mode=code_mode, 329 | max_tokens=max_tokens, 330 | is_ipython=is_ipython 331 | ) 332 | print(result) 333 | 334 | if data_path is None: 335 | del model 336 | torch.cuda.empty_cache() 337 | 338 | if data_path is not None: 339 | data = read_json_or_jsonl(data_path) 340 | if os.path.exists(os.path.join(generation_save_path, "result.json")): 341 | with open(os.path.join(generation_save_path, "result.json"), "r") as f: 342 | result = json.load(f) 343 | total = result["total"] 344 | correct = result["correct"] 345 | assert total == resume_id, f"The resume_id is not consistent with the result.json file, {total} != {resume_id}" 346 | else: 347 | total = 0 348 | correct = 0 349 | pbar = tqdm(data[resume_id:], desc="Evaluating", total=len(data) - resume_id) 350 | for item in pbar: 351 | output = inference( 352 | model, 353 | find_question(item), 354 | code_mode=code_mode, 355 | max_tokens=max_tokens, 356 | is_ipython=is_ipython 357 | ) 358 | answer = item['answer'] if 'answer' in item else item['reward_model']['ground_truth'] if 'reward_model' in item else item["final_answer"] 359 | try: 360 | score = compute_score(output.split("<|Assistant|>")[-1], answer) 361 | if score == 1: 362 | correct += 1 363 | except Exception: 364 | score = 0 365 | total += 1 366 | accuracy = correct / total 367 | pbar.set_postfix(correct=correct, accuracy=f"{accuracy:.2%}") 368 | 369 | text_validation = ( 370 | "\n########################################################\n" 371 | f"correct answer: {answer}\n" 372 | f"score: {score}\n" 373 | "########################################################" 374 | ) 375 | with open(os.path.join(generation_save_path, f"{total - 1}.txt"), "w", encoding="utf-8") as f: 376 | f.write(output + text_validation) 377 | with open(os.path.join(generation_save_path, "result.json"), "w") as f: 378 | f.write(json.dumps({"total": total, "correct": correct, "accuracy": accuracy})) 379 | 380 | del model 381 | torch.cuda.empty_cache() 382 | return correct / total 383 | else: 384 | from concurrent.futures import ThreadPoolExecutor, as_completed 385 | import threading 386 | 387 | def run_one(item_idx: int, sample_idx: int, item): 388 | """ 389 | One call of `server_inference` ⇒ returns (item_idx, sample_idx, output, answer) 390 | """ 391 | answer = find_answer(item) 392 | output = server_inference(model_base_url=model_path, 393 | model_name=model_name, 394 | tokenizer_path=tokenizer_path, 395 | input=find_question(item), 396 | code_mode=code_mode, 397 | max_tokens=max_tokens, 398 | is_ipython=is_ipython) 399 | return item_idx, sample_idx, output, answer 400 | 401 | data = read_json_or_jsonl(data_path) 402 | n_calls = [i for i in range(len(data) * num_samples) if i not in resume_id] 403 | 404 | tot_lock = threading.Lock() 405 | total = 0 406 | correct = 0 407 | 408 | pbar = tqdm(total=len(n_calls), desc="Evaluating") 409 | 410 | with ThreadPoolExecutor(max_workers=num_samples) as pool: 411 | futures = [] 412 | for i in n_calls: 413 | futures.append(pool.submit(run_one, i // num_samples, i % num_samples, data[i // num_samples])) 414 | 415 | for fut in as_completed(futures): 416 | item_idx, sample_idx, output, answer = fut.result() 417 | score = compute_score(output.split("<|Assistant|>")[-1], answer) 418 | 419 | # --- global bookkeeping (thread-safe) --- 420 | with tot_lock: 421 | total += 1 422 | correct += (score == 1) 423 | 424 | # --- write per-sample result --- 425 | fname =
os.path.join(generation_save_path, 426 | f"{item_idx*num_samples + sample_idx}.txt") 427 | text_validation = ( 428 | "\n########################################################\n" 429 | f"correct answer: {answer}\n" 430 | f"score: {score}\n" 431 | "########################################################" 432 | ) 433 | with open(fname, "w", encoding="utf-8") as f: 434 | f.write(output + text_validation) 435 | 436 | # --- write running summary (overwrite) --- 437 | with open(os.path.join(generation_save_path, "result.json"), "w") as f: 438 | json.dump({"total": total, 439 | "correct": correct, 440 | "accuracy": correct / total}, 441 | f, ensure_ascii=False, indent=2) 442 | 443 | pbar.set_postfix(correct=correct, accuracy=f"{correct/total:.2%}") 444 | pbar.update() 445 | 446 | pbar.close() 447 | return correct / total 448 | 449 | if __name__ == "__main__": 450 | import argparse 451 | parser = argparse.ArgumentParser() 452 | parser.add_argument("--model_path", type=str, default='models/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B') 453 | parser.add_argument("--model_name", type=str, default="DeepSeek-R1-Distill-Qwen-7B") 454 | parser.add_argument("--data_path", type=str, default="synthetic_data_7b_all_2.jsonl") 455 | parser.add_argument("--code_mode", action='store_true') 456 | parser.add_argument("--max_tokens", type=int, default=4096) 457 | parser.add_argument("--generation_save_path", type=str, default="result") 458 | parser.add_argument("--is_ipython", action='store_true') 459 | parser.add_argument("--use_server_inference", action='store_true') 460 | parser.add_argument("--num_samples", type=int, default=1) 461 | args = parser.parse_args() 462 | print(args) 463 | llm_eval( 464 | model_path=args.model_path, 465 | model_name=args.model_name, 466 | data_path=args.data_path, 467 | code_mode=args.code_mode, 468 | max_tokens=args.max_tokens, 469 | generation_save_path=args.generation_save_path, 470 | is_ipython=args.is_ipython, 471 | use_server_inference=args.use_server_inference, 472 | num_samples=args.num_samples 473 | ) -------------------------------------------------------------------------------- /dataset/test/amc.jsonl: -------------------------------------------------------------------------------- 1 | {"data_source":"","prompt":[{"content":"$\\frac{m}{n}$ is the Irreducible fraction value of \\[3+\\frac{1}{3+\\frac{1}{3+\\frac13}}\\], what is the value of $m+n$?","role":"user"}],"ability":"math","reward_model":{"ground_truth":142.0,"num_tokens":-512,"style":"rule"},"extra_info":{"index":0,"split":"test"}} 2 | {"data_source":"","prompt":[{"content":"How many ways are there to split the integers $1$ through $14$ into $7$ pairs such that in each pair, the greater number is at least $2$ times the lesser number?","role":"user"}],"ability":"math","reward_model":{"ground_truth":144.0,"num_tokens":-512,"style":"rule"},"extra_info":{"index":1,"split":"test"}} 3 | {"data_source":"","prompt":[{"content":"What is the product of all real numbers $x$ such that the distance on the number line between $\\log_6x$ and $\\log_69$ is twice the distance on the number line between $\\log_610$ and $1$?","role":"user"}],"ability":"math","reward_model":{"ground_truth":81.0,"num_tokens":-512,"style":"rule"},"extra_info":{"index":2,"split":"test"}} 4 | {"data_source":"","prompt":[{"content":"Let $M$ be the midpoint of $\\overline{AB}$ in regular tetrahedron $ABCD$. 
$\\frac{p}{q}=\\cos(\\angle CMD)$ is irreducible fraction, what is the value of $p+q$?","role":"user"}],"ability":"math","reward_model":{"ground_truth":4.0,"num_tokens":-512,"style":"rule"},"extra_info":{"index":3,"split":"test"}} 5 | {"data_source":"","prompt":[{"content":"Let $\\mathcal{R}$ be the region in the complex plane consisting of all complex numbers $z$ that can be written as the sum of complex numbers $z_1$ and $z_2$, where $z_1$ lies on the segment with endpoints $3$ and $4i$, and $z_2$ has magnitude at most $1$. What integer is closest to the area of $\\mathcal{R}$?","role":"user"}],"ability":"math","reward_model":{"ground_truth":13.0,"num_tokens":-512,"style":"rule"},"extra_info":{"index":4,"split":"test"}} 6 | {"data_source":"","prompt":[{"content":"What is the value of \\[(\\log 5)^{3}+(\\log 20)^{3}+(\\log 8)(\\log 0.25)\\] where $\\log$ denotes the base-ten logarithm?","role":"user"}],"ability":"math","reward_model":{"ground_truth":2.0,"num_tokens":-512,"style":"rule"},"extra_info":{"index":5,"split":"test"}} 7 | {"data_source":"","prompt":[{"content":"The roots of the polynomial $10x^3 - 39x^2 + 29x - 6$ are the height, length, and width of a rectangular box (right rectangular prism). A new rectangular box is formed by lengthening each edge of the original box by $2$\nunits. What is the volume of the new box?","role":"user"}],"ability":"math","reward_model":{"ground_truth":30.0,"num_tokens":-512,"style":"rule"},"extra_info":{"index":6,"split":"test"}} 8 | {"data_source":"","prompt":[{"content":"A $\\emph{triangular number}$ is a positive integer that can be expressed in the form $t_n = 1+2+3+\\cdots+n$, for some positive integer $n$. The three smallest triangular numbers that are also perfect squares are\n$t_1 = 1 = 1^2$, $t_8 = 36 = 6^2$, and $t_{49} = 1225 = 35^2$. What is the sum of the digits of the fourth smallest triangular number that is also a perfect square?","role":"user"}],"ability":"math","reward_model":{"ground_truth":18.0,"num_tokens":-512,"style":"rule"},"extra_info":{"index":7,"split":"test"}} 9 | {"data_source":"","prompt":[{"content":"Suppose $a$ is a real number such that the equation \\[a\\cdot(\\sin{x}+\\sin{(2x)}) = \\sin{(3x)}\\]\nhas more than one solution in the interval $(0, \\pi)$. The set of all such $a$ that can be written\nin the form \\[(p,q) \\cup (q,r),\\]\nwhere $p, q,$ and $r$ are real numbers with $p < q< r$. What is $p+q+r$?","role":"user"}],"ability":"math","reward_model":{"ground_truth":-4.0,"num_tokens":-512,"style":"rule"},"extra_info":{"index":8,"split":"test"}} 10 | {"data_source":"","prompt":[{"content":"Let $T_k$ be the transformation of the coordinate plane that first rotates the plane $k$ degrees counterclockwise around the origin and then reflects the plane across the $y$-axis. What is the least positive\ninteger $n$ such that performing the sequence of transformations $T_1, T_2, T_3, \\cdots, T_n$ returns the point $(1,0)$ back to itself?","role":"user"}],"ability":"math","reward_model":{"ground_truth":359.0,"num_tokens":-512,"style":"rule"},"extra_info":{"index":9,"split":"test"}} 11 | {"data_source":"","prompt":[{"content":"Suppose that $13$ cards numbered $1, 2, 3, \\ldots, 13$ are arranged in a row. The task is to pick them up in numerically increasing order, working repeatedly from left to right. In the example below, cards $1, 2, 3$ are picked up on the first pass, $4$ and $5$ on the second pass, $6$ on the third pass, $7, 8, 9, 10$ on the fourth pass, and $11, 12, 13$ on the fifth pass. 
For how many of the $13!$ possible orderings of the cards will the $13$ cards be picked up in exactly two passes?","role":"user"}],"ability":"math","reward_model":{"ground_truth":8178.0,"num_tokens":-512,"style":"rule"},"extra_info":{"index":10,"split":"test"}} 12 | {"data_source":"","prompt":[{"content":"The sum of three numbers is $96.$ The first number is $6$ times the third number, and the third number is $40$ less than the second number. What is the absolute value of the difference between the first and second numbers?","role":"user"}],"ability":"math","reward_model":{"ground_truth":5.0,"num_tokens":-512,"style":"rule"},"extra_info":{"index":11,"split":"test"}} 13 | {"data_source":"","prompt":[{"content":"Isosceles trapezoid $ABCD$ has parallel sides $\\overline{AD}$ and $\\overline{BC},$ with $BC < AD$ and $AB = CD.$ There is a point $P$ in the plane such that $PA=1, PB=2, PC=3,$ and $PD=4.$ Let $\\frac{r}{s}=frac{BC}{AD}$ is irreducible fraction, what is the value of $r+s$?","role":"user"}],"ability":"math","reward_model":{"ground_truth":4.0,"num_tokens":-512,"style":"rule"},"extra_info":{"index":12,"split":"test"}} 14 | {"data_source":"","prompt":[{"content":"Let $c$ be a real number, and let $z_1$ and $z_2$ be the two complex numbers satisfying the equation\n$z^2 - cz + 10 = 0$. Points $z_1$, $z_2$, $\\frac{1}{z_1}$, and $\\frac{1}{z_2}$ are the vertices of (convex) quadrilateral $\\mathcal{Q}$ in the complex plane. When the area of $\\mathcal{Q}$ obtains its maximum possible value, let $c=\\sqrt{m}$. what is the value of m","role":"user"}],"ability":"math","reward_model":{"ground_truth":20.0,"num_tokens":-512,"style":"rule"},"extra_info":{"index":13,"split":"test"}} 15 | {"data_source":"","prompt":[{"content":"Let $h_n$ and $k_n$ be the unique relatively prime positive integers such that \\[\\frac{1}{1}+\\frac{1}{2}+\\frac{1}{3}+\\cdots+\\frac{1}{n}=\\frac{h_n}{k_n}.\\] Let $L_n$ denote the least common multiple of the numbers $1, 2, 3, \\ldots, n$. For how many integers with $1\\le{n}\\le{22}$ is $k_n d(R, S)$. The final answer can be written in the form $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. What is $m+n$?","role":"user"}],"ability":"math","reward_model":{"ground_truth":29.0,"num_tokens":-512,"style":"rule"},"extra_info":{"index":54,"split":"test"}} 56 | {"data_source":"","prompt":[{"content":"Let $f$ be the unique function defined on the positive integers such that \\[\\sum_{d\\mid n}d\\cdot f\\left(\\frac{n}{d}\\right)=1\\] for all positive integers $n$. What is $f(2023)$?","role":"user"}],"ability":"math","reward_model":{"ground_truth":96.0,"num_tokens":-512,"style":"rule"},"extra_info":{"index":55,"split":"test"}} 57 | {"data_source":"","prompt":[{"content":"How many ordered pairs of positive real numbers $(a,b)$ satisfy the equation\n\\[(1+2a)(2+2b)(2a+b) = 32ab?\\]","role":"user"}],"ability":"math","reward_model":{"ground_truth":1.0,"num_tokens":-512,"style":"rule"},"extra_info":{"index":56,"split":"test"}} 58 | {"data_source":"","prompt":[{"content":"Let $K$ be the number of sequences $A_1$, $A_2$, $\\dots$, $A_n$ such that $n$ is a positive integer less than or equal to $10$, each $A_i$ is a subset of $\\{1, 2, 3, \\dots, 10\\}$, and $A_{i-1}$ is a subset of $A_i$ for each $i$ between $2$ and $n$, inclusive. 
For example, $\\{\\}$, $\\{5, 7\\}$, $\\{2, 5, 7\\}$, $\\{2, 5, 7\\}$, $\\{2, 5, 6, 7, 9\\}$ is one such sequence, with $n = 5$.What is the remainder when $K$ is divided by $10$?","role":"user"}],"ability":"math","reward_model":{"ground_truth":5.0,"num_tokens":-512,"style":"rule"},"extra_info":{"index":57,"split":"test"}} 59 | {"data_source":"","prompt":[{"content":"There is a unique sequence of integers $a_1, a_2, \\cdots a_{2023}$ such that\n\\[\\tan2023x = \\frac{a_1 \\tan x + a_3 \\tan^3 x + a_5 \\tan^5 x + \\cdots + a_{2023} \\tan^{2023} x}{1 + a_2 \\tan^2 x + a_4 \\tan^4 x \\cdots + a_{2022} \\tan^{2022} x}\\]whenever $\\tan 2023x$ is defined. What is $a_{2023}?$","role":"user"}],"ability":"math","reward_model":{"ground_truth":-1.0,"num_tokens":-512,"style":"rule"},"extra_info":{"index":58,"split":"test"}} 60 | {"data_source":"","prompt":[{"content":"How many positive perfect squares less than $2023$ are divisible by $5$?","role":"user"}],"ability":"math","reward_model":{"ground_truth":8.0,"num_tokens":-512,"style":"rule"},"extra_info":{"index":59,"split":"test"}} 61 | {"data_source":"","prompt":[{"content":"How many digits are in the base-ten representation of $8^5 \\cdot 5^{10} \\cdot 15^5$?","role":"user"}],"ability":"math","reward_model":{"ground_truth":18.0,"num_tokens":-512,"style":"rule"},"extra_info":{"index":60,"split":"test"}} 62 | {"data_source":"","prompt":[{"content":"Janet rolls a standard $6$-sided die $4$ times and keeps a running total of the numbers she rolls. What is the probability that at some point, her running total will equal $3$? The final answer can be written in the form $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. What is $m+n$?","role":"user"}],"ability":"math","reward_model":{"ground_truth":265.0,"num_tokens":-512,"style":"rule"},"extra_info":{"index":61,"split":"test"}} 63 | {"data_source":"","prompt":[{"content":"Points $A$ and $B$ lie on the graph of $y=\\log_{2}x$. The midpoint of $\\overline{AB}$ is $(6, 2)$. What is the positive difference between the $x$-coordinates of $A$ and $B$? The final answer can be written in the form $m \\sqrt{n}$, where $m$ and $n$ are relatively prime positive integers. What is $m+n$?","role":"user"}],"ability":"math","reward_model":{"ground_truth":9.0,"num_tokens":-512,"style":"rule"},"extra_info":{"index":62,"split":"test"}} 64 | {"data_source":"","prompt":[{"content":"A digital display shows the current date as an $8$-digit integer consisting of a $4$-digit year, followed by a $2$-digit month, followed by a $2$-digit date within the month. For example, Arbor Day this year is displayed as 20230428. For how many dates in $2023$ will each digit appear an even number of times in the 8-digital display for that date?","role":"user"}],"ability":"math","reward_model":{"ground_truth":9.0,"num_tokens":-512,"style":"rule"},"extra_info":{"index":63,"split":"test"}} 65 | {"data_source":"","prompt":[{"content":"Maureen is keeping track of the mean of her quiz scores this semester. If Maureen scores an $11$ on the next quiz, her mean will increase by $1$. If she scores an $11$ on each of the next three quizzes, her mean will increase by $2$. What is the mean of her quiz scores currently?","role":"user"}],"ability":"math","reward_model":{"ground_truth":7.0,"num_tokens":-512,"style":"rule"},"extra_info":{"index":64,"split":"test"}} 66 | {"data_source":"","prompt":[{"content":"Mrs. Jones is pouring orange juice into four identical glasses for her four sons. 
She fills the first three glasses completely but runs out of juice when the fourth glass is only $\\frac{1}{3}$ full. What fraction of a glass must Mrs. Jones pour from each of the first three glasses into the fourth glass so that all four glasses will have the same amount of juice? The final answer can be written in the form $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. What is $m+n$?","role":"user"}],"ability":"math","reward_model":{"ground_truth":7.0,"num_tokens":-512,"style":"rule"},"extra_info":{"index":65,"split":"test"}} 67 | {"data_source":"","prompt":[{"content":"In the $xy$-plane, a circle of radius $4$ with center on the positive $x$-axis is tangent to the $y$-axis at the origin, and a circle with radius $10$ with center on the positive $y$-axis is tangent to the $x$-axis at the origin. What is the slope of the line passing through the two points at which these circles intersect? The final answer can be written in the form $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. What is $m+n$?","role":"user"}],"ability":"math","reward_model":{"ground_truth":7.0,"num_tokens":-512,"style":"rule"},"extra_info":{"index":66,"split":"test"}} 68 | {"data_source":"","prompt":[{"content":"Calculate the maximum area of an isosceles trapezoid that has legs of length $1$ and one base twice as long as the other. The final answer can be written in the form $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. What is $m^2+n^2$?","role":"user"}],"ability":"math","reward_model":{"ground_truth":13.0,"num_tokens":-512,"style":"rule"},"extra_info":{"index":67,"split":"test"}} 69 | {"data_source":"","prompt":[{"content":"For complex number $u = a+bi$ and $v = c+di$ (where $i=\\sqrt{-1}$), define the binary operation\n$u \\otimes v = ac + bdi$\nSuppose $z$ is a complex number such that $z\\otimes z = z^{2}+40$. What is $|z|^2$?","role":"user"}],"ability":"math","reward_model":{"ground_truth":50.0,"num_tokens":-512,"style":"rule"},"extra_info":{"index":68,"split":"test"}} 70 | {"data_source":"","prompt":[{"content":"A rectangular box $P$ has distinct edge lengths $a$, $b$, and $c$. The sum of the lengths of all $12$ edges of $P$ is $13$, the areas of all $6$ faces of $P$ is $\\frac{11}{2}$, and the volume of $P$ is $\\frac{1}{2}$. Find the length of the longest interior diagonal connecting two vertices of $P$. The final answer can be written in the form $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. What is $m+n$?","role":"user"}],"ability":"math","reward_model":{"ground_truth":13.0,"num_tokens":-512,"style":"rule"},"extra_info":{"index":69,"split":"test"}} 71 | {"data_source":"","prompt":[{"content":"For how many ordered pairs $(a,b)$ of integers does the polynomial $x^3+ax^2+bx+6$ have $3$ distinct integer roots?","role":"user"}],"ability":"math","reward_model":{"ground_truth":5.0,"num_tokens":-512,"style":"rule"},"extra_info":{"index":70,"split":"test"}} 72 | {"data_source":"","prompt":[{"content":"In the state of Coinland, coins have values $6,10,$ and $15$ cents. Suppose $x$ is the value in cents of the most expensive item in Coinland that cannot be purchased using these coins with exact change. 
What is the sum of the digits of $x?$","role":"user"}],"ability":"math","reward_model":{"ground_truth":11.0,"num_tokens":-512,"style":"rule"},"extra_info":{"index":71,"split":"test"}} 73 | {"data_source":"","prompt":[{"content":"Triangle $ABC$ has side lengths in arithmetic progression, and the smallest side has length $6.$ If the triangle has an angle of $120^\\circ,$ Find the area of $ABC$. The final answer can be simplified in the form $m \\sqrt{n}$, where $m$ and $n$ are positive integers and $n$ without square factore. What is $m+n$?","role":"user"}],"ability":"math","reward_model":{"ground_truth":18.0,"num_tokens":-512,"style":"rule"},"extra_info":{"index":72,"split":"test"}} 74 | {"data_source":"","prompt":[{"content":"Carlos went to a sports store to buy running shoes. Running shoes were on sale, with prices reduced by $20\\%$ on every pair of shoes. Carlos also knew that he had to pay a $7.5\\%$ sales tax on the discounted price. He had $$43$ dollars. What is the original (before discount) price of the most expensive shoes he could afford to buy?","role":"user"}],"ability":"math","reward_model":{"ground_truth":50.0,"num_tokens":-512,"style":"rule"},"extra_info":{"index":73,"split":"test"}} 75 | {"data_source":"","prompt":[{"content":"When $n$ standard six-sided dice are rolled, the product of the numbers rolled can be any of $936$ possible values. What is $n$?","role":"user"}],"ability":"math","reward_model":{"ground_truth":11.0,"num_tokens":-512,"style":"rule"},"extra_info":{"index":74,"split":"test"}} 76 | {"data_source":"","prompt":[{"content":"Suppose that $a$, $b$, $c$ and $d$ are positive integers satisfying all of the following relations.\n\\[abcd=2^6\\cdot 3^9\\cdot 5^7\\]\n\\[\\text{lcm}(a,b)=2^3\\cdot 3^2\\cdot 5^3\\]\n\\[\\text{lcm}(a,c)=2^3\\cdot 3^3\\cdot 5^3\\]\n\\[\\text{lcm}(a,d)=2^3\\cdot 3^3\\cdot 5^3\\]\n\\[\\text{lcm}(b,c)=2^1\\cdot 3^3\\cdot 5^2\\]\n\\[\\text{lcm}(b,d)=2^2\\cdot 3^3\\cdot 5^2\\]\n\\[\\text{lcm}(c,d)=2^2\\cdot 3^3\\cdot 5^2\\]\nWhat is $\\text{gcd}(a,b,c,d)$?","role":"user"}],"ability":"math","reward_model":{"ground_truth":3.0,"num_tokens":-512,"style":"rule"},"extra_info":{"index":75,"split":"test"}} 77 | {"data_source":"","prompt":[{"content":"A $3-4-5$ right triangle is inscribed in circle $A$, and a $5-12-13$ right triangle is inscribed in circle $B$. Find the ratio of the area of circle $A$ to the area of circle $B$. The final answer can be written in the form $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. What is $m+n$?","role":"user"}],"ability":"math","reward_model":{"ground_truth":194.0,"num_tokens":-512,"style":"rule"},"extra_info":{"index":76,"split":"test"}} 78 | {"data_source":"","prompt":[{"content":"Jackson's paintbrush makes a narrow strip with a width of $6.5$ millimeters. Jackson has enough paint to make a strip $25$ meters long. How many square centimeters of paper could Jackson cover with paint?","role":"user"}],"ability":"math","reward_model":{"ground_truth":1625.0,"num_tokens":-512,"style":"rule"},"extra_info":{"index":77,"split":"test"}} 79 | {"data_source":"","prompt":[{"content":"You are playing a game. A $2 \\times 1$ rectangle covers two adjacent squares (oriented either horizontally or vertically) of a $3 \\times 3$ grid of squares, but you are not told which two squares are covered. Your goal is to find at least one square that is covered by the rectangle. A \"turn\" consists of you guessing a square, after which you are told whether that square is covered by the hidden rectangle. 
What is the minimum number of turns you need to ensure that at least one of your guessed squares is covered by the rectangle?","role":"user"}],"ability":"math","reward_model":{"ground_truth":4.0,"num_tokens":-512,"style":"rule"},"extra_info":{"index":78,"split":"test"}} 80 | {"data_source":"","prompt":[{"content":"When the roots of the polynomial \n\\[P(x) = (x-1)^1 (x-2)^2 (x-3)^3 \\cdot \\cdot \\cdot (x-10)^{10}\\]\nare removed from the number line, what remains is the union of $11$ disjoint open intervals. On how many of these intervals is $P(x)$ positive?","role":"user"}],"ability":"math","reward_model":{"ground_truth":6.0,"num_tokens":-512,"style":"rule"},"extra_info":{"index":79,"split":"test"}} 81 | {"data_source":"","prompt":[{"content":"For how many integers $n$ does the expression\\[\\sqrt{\\frac{\\log (n^2) - (\\log n)^2}{\\log n - 3}}\\]represent a real number, where log denotes the base $10$ logarithm?","role":"user"}],"ability":"math","reward_model":{"ground_truth":901.0,"num_tokens":-512,"style":"rule"},"extra_info":{"index":80,"split":"test"}} 82 | {"data_source":"","prompt":[{"content":"How many nonempty subsets $B$ of ${0, 1, 2, 3, \\cdots, 12}$ have the property that the number of elements in $B$ is equal to the least element of $B$? For example, $B = {4, 6, 8, 11}$ satisfies the condition.","role":"user"}],"ability":"math","reward_model":{"ground_truth":144.0,"num_tokens":-512,"style":"rule"},"extra_info":{"index":81,"split":"test"}} 83 | {"data_source":"","prompt":[{"content":"What is the area of the region in the coordinate plane defined by\n$| | x | - 1 | + | | y | - 1 | \\le 1$?","role":"user"}],"ability":"math","reward_model":{"ground_truth":8.0,"num_tokens":-512,"style":"rule"},"extra_info":{"index":82,"split":"test"}} 84 | --------------------------------------------------------------------------------
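Usage sketch (not part of the repository): a minimal, hedged example of how the answer-scoring helpers in math_utils/utils.py behave, assuming the repository root is on sys.path and the <think>/<answer> tag format shown above. The exact outcome of the fuzzy-match path depends on the installed math_verify version.

from math_utils.utils import compute_score, parse_boxed

# Agentic-style trajectory: the <answer> block is checked first, so this should score 1.0.
traj = "<think>2 + 2 = 4</think> <answer>The result is \\boxed{4}</answer>"
print(compute_score(traj, 4))   # expected: 1.0 (ground truth found inside <answer> ... </answer>)
print(parse_boxed(traj))        # expected: '4' (first \boxed{...}; reverse=True takes the last)

# Untagged text-reasoning output: falls through to the math_verify fuzzy match.
print(compute_score("The answer is \\boxed{\\frac{1}{2}}", "\\frac{1}{2}"))  # expected: 1.0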
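Protocol sketch (not part of the repository): the generate / execute / splice-feedback loop that sft/evaluate.py drives against a live model, reproduced offline with a hypothetical fake_model stub. It assumes the <code>/<interpreter> tags reconstructed above and that math_utils.utils.code_block behaves as listed (subprocess execution with a 3-second timeout and cross-block state via previous_code).

from math_utils.utils import code_block

def fake_model(prompt):
    # Hypothetical stand-in for a completions call with stop=["</code>"].
    if "<interpreter>" not in prompt:
        return "<think>Let me compute it with code.<code>\nprint(21 * 2)\n</code>"
    return "So the product is 42.</think> <answer>\\boxed{42}</answer>"

prompt = "What is 21 * 2?"
previous_code = ""
llm_output = fake_model(prompt)
while "</code>" in llm_output and "</answer>" not in llm_output:
    raw_code = llm_output.split("<code>")[-1].split("</code>")[0]
    stdout, err, previous_code = code_block(raw_code, is_ipython=False, previous_code=previous_code)
    # Splice the executor feedback right after the closing </code>, as the real loop does.
    llm_output = llm_output.split("</code>")[0] + "</code>" + f"\n<interpreter>\n{err or stdout}\n</interpreter>\n"
    prompt += llm_output
    llm_output = fake_model(prompt)
print(prompt + llm_output)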
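Driver sketch (not part of the repository): one plausible way to run the batch evaluation in sft/evaluate.py against a vLLM server started with script/eval_script/start_inference_server.sh (default port 8123). The URL, served-model name, and sample count below are illustrative assumptions, not values fixed by the repo.

from sft.evaluate import llm_eval

# model_path doubles as the OpenAI-compatible base URL when use_server_inference=True.
accuracy = llm_eval(
    model_path="http://localhost:8123/v1",
    model_name="DeepSeek-R1-Distill-Qwen-7B",
    tokenizer_path="models/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
    data_path="dataset/test/aime2025.jsonl",
    code_mode=True,
    max_tokens=4096,
    use_server_inference=True,
    num_samples=4,   # also sets the thread-pool width inside llm_eval
)
print(f"accuracy over 4 samples per problem: {accuracy:.2%}")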