print('Hello, world!')."
40 | print(check_format(data))
--------------------------------------------------------------------------------
/script/data_script/model_download.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | import os
7 | from requests.exceptions import HTTPError
8 | import sys
9 | from pathlib import Path
10 | from typing import Optional
11 |
12 |
13 | def hf_download(
14 | repo_id: Optional[str] = None,
15 | hf_token: Optional[str] = None,
16 | local_dir: Optional[str] = None,
17 | ) -> None:
18 | from huggingface_hub import snapshot_download
19 |
20 | local_dir = local_dir or "checkpoints"
21 |
22 | os.makedirs(f"{local_dir}/{repo_id}", exist_ok=True)
23 | try:
24 | snapshot_download(
25 | repo_id,
26 | local_dir=f"{local_dir}/{repo_id}",
27 | local_dir_use_symlinks=False,
28 | token=hf_token,
29 | )
30 | except HTTPError as e:
31 | if e.response.status_code == 401:
32 | print(
33 | "You need to pass a valid `--hf_token=...` to download private checkpoints."
34 | )
35 | else:
36 | raise e
37 |
38 |
39 | if __name__ == "__main__":
40 | import argparse
41 |
42 |     parser = argparse.ArgumentParser(description="Download a model snapshot from the Hugging Face Hub.")
43 | parser.add_argument(
44 | "--repo_id",
45 | type=str,
46 | default="checkpoints/meta-llama/llama-2-7b-chat-hf",
47 | help="Repository ID to download from.",
48 | )
49 | parser.add_argument(
50 | "--local_dir", type=str, default=None, help="Local directory to download to."
51 | )
52 | parser.add_argument(
53 | "--hf_token", type=str, default=None, help="HuggingFace API token."
54 | )
55 |
56 | args = parser.parse_args()
57 | hf_download(args.repo_id, args.hf_token, args.local_dir)
--------------------------------------------------------------------------------
/math_utils/deduplicate_problem.py:
--------------------------------------------------------------------------------
1 | from thefuzz import fuzz
2 | import json
3 | from utils import find_question
4 | from tqdm import tqdm
5 | import sys
6 |
7 | def check_duplicate(text1, text2, threshold=90):
8 | if fuzz.ratio(text1, text2) > threshold:
9 | return True
10 | return False
11 |
12 | #print(check_duplicate("What is the sum of the first 100 natural numbers?", "What is the sum of the first 100 natural numbers?"))
13 |
14 | training_set_path = "data/DeepMath-103K-big-number/DeepMath-103K-big-number_question_everything.jsonl"
15 |
16 | training_problem_list = []
17 | with open(training_set_path, "r") as f:
18 | for line in f:
19 | data = json.loads(line)
20 | training_problem_list.append(data)
21 |
22 | print("training set size: ", len(training_problem_list))
23 |
24 | #test_set_path = "data/combinatorics_test_ge10000.jsonl"
25 | test_set_path = sys.argv[1]
26 | threshold = 90
27 | test_problem_list = []
28 | with open(test_set_path, "r") as f:
29 | for line in f:
30 | data = json.loads(line)
31 | test_problem_list.append(data)
32 |
33 | print("test set size: ", len(test_problem_list))
34 |
35 | no_duplicate_count = 0
36 | no_duplicate_count_list = []
37 | for problem in tqdm(test_problem_list):
38 | for training_problem in training_problem_list:
39 | if check_duplicate(find_question(problem), find_question(training_problem), threshold=threshold):
40 | print("find duplicate, ratio: ", fuzz.ratio(find_question(problem), find_question(training_problem)))
41 | print(">>>>>>>>>>>>>>>>>>>>")
42 | print(find_question(problem))
43 | print("<<<<<<<<<<<<<<<<<<<<")
44 | print(find_question(training_problem))
45 | print(">>>>>>>>>>>>>>>>>>>>")
46 | break
47 |     else:  # for-else: runs only when no duplicate was found (no break)
48 | no_duplicate_count += 1
49 | no_duplicate_count_list.append(problem)
50 |
51 | print("test set no duplicate count: ", no_duplicate_count)
52 |
53 | with open(test_set_path.replace(".jsonl", "_no_duplicate_threshold_{}.jsonl".format(threshold)), "w") as f:
54 | for problem in no_duplicate_count_list:
55 | f.write(json.dumps(problem) + "\n")
56 |
--------------------------------------------------------------------------------
/script/data_script/processing_self_distillation_traj.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import sys
4 |
5 | # add the root directory to the python path
6 | sys.path.append(os.path.abspath("."))
7 |
8 | from math_utils.format_checking import check_format
9 | from transformers import AutoTokenizer
10 |
11 | tokenizer = AutoTokenizer.from_pretrained("models/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B")
12 | root_dir = "dataset/train/self_distillation"
13 | data_path = os.path.join(root_dir, "iteration_1_correct_replay_buffer.jsonl")
14 | wrong_data_path = os.path.join(root_dir, "iteration_1_incorrect_replay_buffer.jsonl")
15 | correct_record_path = os.path.join(root_dir, "iteration_1_accuracy.jsonl")
16 | text_reasoning_path = "dataset/train/whole_training_set_with_solution.jsonl"
17 |
18 | num_samples = 16
19 | correctness_bar = 0.9
20 |
21 | with open(text_reasoning_path, "r") as f:
22 | text_reasoning_data = [json.loads(line) for line in f]
23 |
24 | data = []
25 | count = 0
26 | with open(data_path, "r") as f:
27 | raw_data = [json.loads(line) for line in f]
28 | count = 0
29 | temp_add_data = None
30 | for i in range(len(raw_data)):
31 |         # If not all samples for an idx are correct, add one example to expert iteration
32 |         # Do not add a text-reasoning path, because that part was already added in the SFT stage
33 |         # One idx may have multiple samples; we add only one of them to expert iteration to ensure diversity
34 | if i == 0 or raw_data[i]["idx"] != raw_data[i-1]["idx"]:
35 | if count < num_samples and temp_add_data is not None:
36 | # > 0, < correctness_bar
37 | data.append(temp_add_data)
38 | temp_add_data = None
39 | count = 0
40 | count += 1
41 | if check_format(raw_data[i]["synthetic_data"]):
42 | attempt = raw_data[i]["synthetic_data"].split("")[0].strip()
43 | if temp_add_data is None:
44 | for j in range(len(text_reasoning_data)):
45 | if text_reasoning_data[j]["problem"] == raw_data[i]["problem"]:
46 | temp_add_data = attempt + "\n\nWait, we can also use text-reasoning as an alternative way to verify the solution.\n\n" + text_reasoning_data[j]["solution"]
47 | raw_data[i]["synthetic_data"] = temp_add_data
48 | temp_add_data = raw_data[i]
49 | count += 1
50 | break
51 |
52 | print("self-distillation part 1: ", len(data))
53 |
54 | wrong_data = []
55 |
56 | # Wrong data but correct format
57 | # Replace the last block with correct text reasoning
58 | # At most one wrong sample per idx
59 | # If a correct sample exists, we do not add a corrected wrong sample to expert iteration
60 |
61 | with open(wrong_data_path, "r") as f:
62 | raw_data = [json.loads(line) for line in f]
63 | with open(correct_record_path, "r") as ff:
64 | correct_record = [json.loads(line) for line in ff]
65 | max_attempt_len = 0
66 | for i in range(len(raw_data)):
67 | if i == 0 or raw_data[i]["idx"] != raw_data[i - 1]["idx"]:
68 | count = 0
69 | if correct_record[raw_data[i]["idx"] - 1]["accuracy"] >= correctness_bar:
70 | continue
71 | attempt = " ".join(raw_data[i]["synthetic_data"].split("")[:-1]) + "\n"
72 | if "Wait, the code is not correct, let's try text reasoning" in attempt:
73 | continue
74 | if check_format(attempt) and count < 1:
75 | max_attempt_len = max(max_attempt_len, len(tokenizer.encode(attempt)))
76 | # find corresponding correct text-reasoning
77 | for j in range(len(text_reasoning_data)):
78 | if text_reasoning_data[j]["problem"] == raw_data[i]["problem"]:
79 | new_attempt = attempt + "\n\nWait, the code is not correct, let's try text reasoning.\n\n" + text_reasoning_data[j]["solution"]
80 | if check_format(new_attempt):
81 | new_data = raw_data[i]
82 | new_data["synthetic_data"] = new_attempt
83 | wrong_data.append(new_data)
84 | count += 1
85 | break
86 |
87 | print("self-distillation part 2: ", len(wrong_data))
88 | data_root_dir = "dataset/train/self_distillation"
89 |
90 | with open(os.path.join(data_root_dir, f"iteration_1_correct_replay_buffer_deduplicated.jsonl"), "w") as f:
91 | for item in data:
92 | f.write(json.dumps(item) + "\n")
93 |
94 | with open(os.path.join(data_root_dir, f"iteration_1_incorrect_replay_buffer_revised_deduplicated_{correctness_bar}.jsonl"), "w") as f:
95 | for item in wrong_data:
96 | f.write(json.dumps(item) + "\n")
--------------------------------------------------------------------------------
/sft/self_distillation_sampler.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import argparse
4 | import sys
5 | from concurrent.futures import ThreadPoolExecutor
6 | parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
7 | sys.path.append(parent_dir)
8 |
9 | from math_utils.utils import read_json_or_jsonl, find_question, find_answer
10 | from sft.evaluate import server_inference
11 | from math_utils.utils import compute_score
12 | from tqdm import tqdm
13 |
14 | def self_distillation_sampler(server_url: str, model_name: str, model_path: str, data_path: str, num_samples: int, save_path: str, data_size: int, iteration: int=1):
15 | """
16 | Self-distillation sampler
17 | The model generates #num_samples samples for each question in the data_path
18 | """
19 | # check paths
20 | os.makedirs(save_path, exist_ok=True)
21 | # load the dataset
22 | data = read_json_or_jsonl(data_path)
23 | data = data[:data_size]
24 | # generate #num_samples samples for each question in the data_path
25 | correct_replay_buffer_path = os.path.join(save_path, f"iteration_{iteration}_correct_replay_buffer.jsonl")
26 | incorrect_replay_buffer_path = os.path.join(save_path, f"iteration_{iteration}_incorrect_replay_buffer.jsonl")
27 | if os.path.exists(correct_replay_buffer_path.replace(".jsonl", "_last_sample_id.txt")):
28 | with open(correct_replay_buffer_path.replace(".jsonl", "_last_sample_id.txt"), "r") as f:
29 | last_sample_id = int(f.read())
30 | else:
31 | last_sample_id = 0
32 | if os.path.exists(correct_replay_buffer_path.replace("correct_replay_buffer.jsonl", "accuracy.jsonl")):
33 | with open(correct_replay_buffer_path.replace("correct_replay_buffer.jsonl", "accuracy.jsonl"), "r") as f:
34 | accuracy_data = [json.loads(line) for line in f]
35 | sum_accuracy = sum([data["accuracy"] for data in accuracy_data])
36 | else:
37 | accuracy_data = []
38 | sum_accuracy = 0
39 | with tqdm(data[last_sample_id:], total=len(data) - last_sample_id, desc="Generating samples") as pbar:
40 | for i, item in enumerate(pbar, 1):
41 | # generate #num_samples samples
42 | correct = 0
43 | with ThreadPoolExecutor(max_workers=num_samples) as executor:
44 | futures = [executor.submit(server_inference, \
45 | model_base_url=server_url, \
46 | model_name=model_name, \
47 | tokenizer_path=model_path, \
48 | input=find_question(item), \
49 | code_mode=True, \
50 | max_tokens=4096, \
51 | is_ipython=False) for i in range(num_samples)]
52 | outputs = [future.result() for future in futures]
53 |
54 | correct = 0
55 | correct_data_to_save = []
56 | incorrect_data_to_save = []
57 | for output in outputs:
58 | # check if the sample is correct
59 | if compute_score(output.split("<|Assistant|>")[-1], find_answer(item)) == 1:
60 | # add the sample to the data
61 | correct += 1
62 | correct_data_to_save.append(output)
63 | else:
64 | incorrect_data_to_save.append(output)
65 |
66 | # save the samples
67 |             for sample in correct_data_to_save:
68 |                 with open(correct_replay_buffer_path, "a") as f:
69 |                     f.write(json.dumps({"idx": i + last_sample_id, "problem": find_question(item), "synthetic_data": sample.split("<|Assistant|>")[-1]}) + "\n")
70 |             for sample in incorrect_data_to_save:
71 |                 with open(incorrect_replay_buffer_path, "a") as f:
72 |                     f.write(json.dumps({"idx": i + last_sample_id, "problem": find_question(item), "synthetic_data": sample.split("<|Assistant|>")[-1]}) + "\n")
73 |
74 | sum_accuracy += correct / num_samples
75 | accuracy_data.append({"idx": i + last_sample_id, "accuracy": correct / num_samples, "correct": correct, "total": num_samples})
76 | pbar.set_postfix({"accuracy": f"{sum_accuracy / (i + last_sample_id):.2%}", "correct": correct, "total": num_samples})
77 | with open(correct_replay_buffer_path.replace(".jsonl", "_last_sample_id.txt"), "w") as f:
78 | f.write(str(i + last_sample_id))
79 | with open(correct_replay_buffer_path.replace("correct_replay_buffer.jsonl", "accuracy.jsonl"), "w") as f:
80 |                 for record in accuracy_data:
81 |                     f.write(json.dumps(record) + "\n")
82 |
83 | print(f"Iteration {iteration}, average accuracy: {sum_accuracy / len(data)}")
84 |
85 | if __name__ == "__main__":
86 | parser = argparse.ArgumentParser()
87 | parser.add_argument("--server_url", type=str, default="http://localhost:8123/v1", help="The server url")
88 | parser.add_argument("--model_name", type=str, default="DeepSeek-R1-Distill-Qwen-7B")
89 | parser.add_argument("--model_path", type=str, default="models/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B")
90 | parser.add_argument("--data_path", type=str, default="dataset/train/whole_training_set.jsonl")
91 | parser.add_argument("--num_samples", type=int, default=16)
92 | parser.add_argument("--save_path", type=str, help="The path to save the self-distillation trajectories", default="dataset/train/self_distillation")
93 | parser.add_argument("--data_size", type=int, default=-1, help="The number of data to sample, -1 means all data")
94 | args = parser.parse_args()
95 | print(args)
96 | self_distillation_sampler(server_url=args.server_url, \
97 | model_name=args.model_name, \
98 | model_path=args.model_path, \
99 | data_path=args.data_path, \
100 | num_samples=args.num_samples, \
101 | save_path=args.save_path, \
102 | data_size=args.data_size)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DualDistill
2 |
3 | [License: MIT](LICENSE)
4 | [arXiv:2507.05707](https://arxiv.org/abs/2507.05707)
5 |
6 | Official implementation of **DualDistill**: A trajectory-composition distillation method for integrating tool use into long-chain-of-thought reasoning.
7 |
8 | > **Weihua Du, Pranjal Aggarwal, Sean Welleck, & Yiming Yang**
9 | > ["Agentic-R1: Distilled Dual-Strategy Reasoning." (2025)](https://arxiv.org/abs/2507.05707)
10 |
11 | ## Key Features
12 |
13 | - **Efficient Training**: Integrates tool use into long-chain-of-thought (CoT) reasoning using only 4 × A6000 GPUs
14 | - **Unified Reasoning**: Fuses heterogeneous reasoning traces from multiple teacher models into a single student model
15 |
16 |
17 |
18 | *Figure: Overview of the DualDistill methodology.*
19 |
20 |
21 | ## Datasets
22 |
23 | | Dataset | Description | Link |
24 | |---------|-------------|------|
25 | | **Training Set** | Complete training dataset with teacher trajectories | [🤗 HuggingFace](https://huggingface.co/datasets/VanishD/DualDistill) |
26 | | **Test Set** | Evaluation benchmarks | `dataset/test/` |
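
The training set can also be pulled programmatically. A minimal sketch, assuming the repository loads with the generic `datasets` API (split and column names may differ from this guess):

```python
# Sketch: load the DualDistill training data from the Hugging Face Hub.
# Assumes the standard `datasets` loading path works for this repository.
from datasets import load_dataset

ds = load_dataset("VanishD/DualDistill")
print(ds)  # inspect the available splits and columns before using them
```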
27 |
28 | ## Results
29 |
30 |
31 |
32 |
33 |
34 | - **Agentic-R1** demonstrates significant performance gains on **DeepMath-L** and **Combinatorics300**, where both complex reasoning and tool use are crucial for success.
35 | - **Agentic-R1-SD** (Self-Distilled) further enhances performance through our self-distillation approach, consistently outperforming baseline models across nearly all evaluation tasks.
36 |
37 | ## Quick Start
38 |
39 | ### Installation
40 |
41 | 1. **Clone the repository**:
42 | ```bash
43 | git clone https://github.com/StigLidu/DualDistill.git
44 | cd DualDistill
45 | ```
46 |
47 | 2. **Create environment** (optional but recommended):
48 | ```bash
49 | conda create -n dualdistill python=3.11
50 | conda activate dualdistill
51 | ```
52 |
53 | 3. **Install dependencies**:
54 | ```bash
55 | pip install -r requirements.txt
56 | pip install flash-attn --no-build-isolation
57 | ```
58 |
59 | ## Training Pipeline
60 |
61 | ### Step 1: Model & Data Preparation
62 |
63 | **Download the base model**:
64 | ```bash
65 | python script/data_script/model_download.py \
66 | --repo_id deepseek-ai/DeepSeek-R1-Distill-Qwen-7B \
67 | --local_dir models
68 | ```
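
The script saves the snapshot under `<local_dir>/<repo_id>` (see `script/data_script/model_download.py` above), so the command should leave the checkpoint at `models/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B`, which is the path later scripts default to. An optional sanity check:

```python
# Sketch: confirm the snapshot landed where the training/eval scripts expect it.
from pathlib import Path

ckpt = Path("models/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B")
print(sorted(p.name for p in ckpt.iterdir()))  # expect config.json, tokenizer files, *.safetensors
```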
69 |
70 | **Prepare training data**:
71 | ```bash
72 | python script/data_script/teacher_data_download.py
73 | ```
74 |
75 | ### Step 2: Teacher Distillation
76 |
77 | Train the student model using teacher trajectories:
78 | ```bash
79 | bash script/sft_script/SFT.sh
80 | ```
81 |
82 | ### Step 3: Self-Distillation
83 |
84 | **Start inference server**:
85 | ```bash
86 | bash script/eval_script/start_inference_server.sh [model_path] [display_name] [port]
87 | ```
88 |
89 | **Sample self-distillation trajectories**:
90 | ```bash
91 | python sft/self_distillation_sampler.py \
92 |     --server_url http://localhost:[port]/v1 \
93 | --model_name [display_name] \
94 | --model_path [model_path] \
95 | --save_path [path_to_save_trajectories]
96 | ```
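
The sampler appends JSONL records with `idx`, `problem`, and `synthetic_data` fields to `iteration_1_correct_replay_buffer.jsonl` and `iteration_1_incorrect_replay_buffer.jsonl` under `--save_path`, plus an `iteration_1_accuracy.jsonl` summary (see `sft/self_distillation_sampler.py` above). A minimal sketch for inspecting the sampled trajectories, assuming the default save path:

```python
# Sketch: peek at the sampled self-distillation trajectories (default save path assumed).
import json

buffer_path = "dataset/train/self_distillation/iteration_1_correct_replay_buffer.jsonl"
with open(buffer_path) as f:
    records = [json.loads(line) for line in f]

print(f"{len(records)} correct trajectories sampled")
if records:
    print(records[0]["problem"][:200])         # the question text
    print(records[0]["synthetic_data"][:200])  # the model-generated trajectory
```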
97 |
98 | **Prepare self-distillation data**:
99 | ```bash
100 | # Extract teacher solutions
101 | python script/data_script/extract_training_solution.py
102 |
103 | # Construct training dataset
104 | python script/data_script/processing_self_distillation_traj.py
105 | ```
106 |
107 | **Fine-tune on self-distillation data**:
108 | ```bash
109 | bash script/sft_script/expert_iteration.sh [model_path] [data_path] [save_path]
110 | ```
111 |
112 | ## Model Evaluation
113 |
114 | ### Start Inference Server
115 | ```bash
116 | bash script/eval_script/start_inference_server.sh [model_path] [display_name] [port]
117 | ```
118 |
119 | ### Run Evaluation
120 | ```bash
121 | bash script/eval_script/eval_remote_server.sh \
122 | [url] [display_name] [data_path] [code_mode] [max_token]
123 | ```
124 |
125 | **Example**:
126 | ```bash
127 | bash script/eval_script/eval_remote_server.sh \
128 | "http://localhost:8080/v1" "agentic-r1" "dataset/test/math.json" "true" "4096"
129 | ```
130 |
131 | ## Trained Models
132 |
133 | | Model | Description | HuggingFace Link |
134 | |-------|-------------|------------------|
135 | | **Agentic-R1-7B** | Base model with teacher distillation | [🤗 Download](https://huggingface.co/VanishD/Agentic-R1) |
136 | | **Agentic-R1-7B-SD** | Enhanced model with self-distillation | [🤗 Download](https://huggingface.co/VanishD/Agentic-R1-SD) |
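
Both checkpoints are standard causal-LM repositories, so they should load with the usual `transformers` calls. A minimal sketch (dtype and device placement are illustrative, not prescribed by the repo):

```python
# Sketch: load a released checkpoint with transformers; settings are illustrative.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "VanishD/Agentic-R1"  # or "VanishD/Agentic-R1-SD"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)
```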
137 |
138 | ## ⚠️ Important Notes
139 |
140 | - **Code Execution Safety**: The evaluation scripts execute model-generated code locally; only run them with models you trust.
141 | - **Inference Config**: If you are using a recent version of vLLM and hit an error about the maximum context length, you may need to increase `model_max_length` in the model's `tokenizer_config.json` (see the sketch after this list).
142 | - **Self-Distillation Warning**: The self-distillation step requires sampling many trajectories and can be time-consuming.
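
For the inference-config note above, a minimal sketch of raising the limit; the checkpoint path and target length are examples, not values fixed by this repo:

```python
# Sketch: raise model_max_length in a checkpoint's tokenizer_config.json.
# The path and value below are examples; adjust them to your setup.
import json

config_path = "models/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B/tokenizer_config.json"
with open(config_path, "r") as f:
    config = json.load(f)

config["model_max_length"] = 16384  # choose the context length you intend to serve

with open(config_path, "w") as f:
    json.dump(config, f, indent=2)
```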
143 |
144 | ## License
145 |
146 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
147 |
148 | ## Acknowledgments
149 |
150 | We thank the following open-source projects for their foundational contributions:
151 |
152 | - [OpenHands](https://github.com/All-Hands-AI/OpenHands) - Agent framework
153 | - [DeepMath-103K](https://huggingface.co/datasets/zwhe99/DeepMath-103K) - Mathematical reasoning dataset
154 | - [vLLM](https://github.com/vllm-project/vllm) - High-performance inference engine
155 |
156 | ## Contact
157 |
158 | For questions or support, please contact:
159 |
160 | - **Weihua Du**: [weihuad@cs.cmu.edu](mailto:weihuad@cs.cmu.edu)
161 |
162 | ## Citation
163 |
164 | If you find our work useful, please consider citing:
165 |
166 | ```bibtex
167 | @article{du2025agentic,
168 | title={Agentic-R1: Distilled Dual-Strategy Reasoning},
169 | author={Du, Weihua and Aggarwal, Pranjal and Welleck, Sean and Yang, Yiming},
170 | journal={arXiv preprint arXiv:2507.05707},
171 | year={2025}
172 | }
173 | ```
174 |
175 | ---
176 |
177 |
178 | ⭐ Star us on GitHub if this project helped you!
179 |
180 |
--------------------------------------------------------------------------------
/sft/dataloader.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | from tqdm import tqdm
3 | import torch
4 | from torch.nn.utils.rnn import pad_sequence
5 | from transformers import AutoTokenizer
6 | import numpy as np
7 | import sys
8 | import os
9 | import json
10 |
11 | parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
12 | sys.path.append(parent_dir)
13 |
14 | from math_utils.utils import SYSTEM_PROMPT_TPL, CODE_INSTRUCTION
15 | from math_utils.format_checking import check_format
16 |
17 | def find_nth(haystack, needle, n):
18 | start = haystack.find(needle)
19 | while start >= 0 and n > 0:
20 | start = haystack.find(needle, start + len(needle))
21 | n -= 1
22 | return start
23 |
24 | class TrainData(Dataset):
25 | def __init__(self, data, tokenizer, code_instruction, max_data_count=None, data_seed=42, debug=False, max_length=16384):
26 | self.tokenizer = tokenizer
27 | self.items = []
28 | self.max_length = max_length
29 | self.total_loss_calculation_token_count = 0
30 |         #TODO: the model's max length seems to be 16384, because a warning is thrown when the input is longer than that
31 | self.debug = debug
32 | system_prompt_tpl = SYSTEM_PROMPT_TPL.format(code_instruction=CODE_INSTRUCTION)
33 |
34 | for sample in tqdm(data, desc="Processing data"):
35 | question = sample["problem"]
36 | answer = sample["synthetic_data"]
37 |
38 | if not check_format(answer):
39 | continue
40 |
41 | # only keep the … part, avoid extra content
42 | if " " in answer:
43 | answer = answer.split("")[0] + ""
44 |
45 | system_prompt = system_prompt_tpl
46 | messages = (
47 | system_prompt + "\n\n<|User|>" + question +
48 | "\n\n<|Assistant|>" + answer
49 | )
50 |
51 | input_ids = tokenizer.encode(messages)
52 | input_ids.append(tokenizer.eos_token_id)
53 | if len(input_ids) > self.max_length:
54 | # ignore the input_ids
55 | continue
56 |
57 | q_ids = tokenizer.encode(
58 | system_prompt + "\n\n<|User|>" + question + "\n\n<|Assistant|>"
59 | )
60 |
61 | labels = [-100] * len(q_ids) + input_ids[len(q_ids):]
62 |
63 | #TODO: For expert iteration, the number of and may mismatch, so we need to check the number of and
64 | errors = ["SyntaxError", "Traceback (most recent call last)", "Error: Code execution timed out."]
65 | code_block_count = sample["synthetic_data"].count("")
66 | code_block_flag = [False] * code_block_count
67 | for c_id in range(code_block_count):
68 | # find the c_id-th ") > decoded_str.count("") and code_block_flag[decoded_str.count("") - 1] == False:
85 | labels[i] = -100
86 |
87 | # do not calculate the loss for the executor feedback
88 | decoded_str = ""
89 | for i in range(len(q_ids), len(input_ids)):
90 | decoded_str += tokenizer.decode(input_ids[i])
91 | if decoded_str.count("") > decoded_str.count(" "):
92 | labels[i] = -100
93 |
94 | # do not calculate the loss before the turn-over words
95 | turn_over_words = ["Wait, use text reasoning is too tedious, let's try code reasoning.", \
96 | "\nWait, the code is not correct, let's try text reasoning.", \
97 | "\nWait, the code may be incorrect, let's try text reasoning."
98 | ]
99 | turn_over_flag = False
100 | last_occurrence = np.inf
101 | for word in turn_over_words:
102 | if word in answer:
103 | # find the first occurrence of the word
104 | last_o = answer.find(word)
105 | if last_o < last_occurrence:
106 | last_occurrence = last_o
107 | turn_over_flag = True
108 | turn_over_word = word
109 | turn_over_num_token = len(tokenizer.encode(word))
110 |
111 | if turn_over_flag:
112 | decoded_str = ""
113 | for i in range(len(q_ids), len(input_ids)):
114 | decoded_str += tokenizer.decode(input_ids[i])
115 | labels[max(0, i - turn_over_num_token - 1)] = -100 # TODO: maybe contains 1 offset error, but it's ok.
116 | if turn_over_word in decoded_str:
117 | break
118 |
119 | # cache the tensors; change to half / int16 if memory is limited
120 | self.items.append((
121 | torch.tensor(input_ids, dtype=torch.long),
122 | torch.tensor(labels, dtype=torch.long),
123 | ))
124 |
125 | # loss calculation token count
126 | loss_calculation_token_count = 0
127 | for i in range(len(q_ids), len(input_ids)):
128 | if labels[i] != -100:
129 | loss_calculation_token_count += 1
130 | self.total_loss_calculation_token_count += loss_calculation_token_count
131 |
132 | if self.debug:
133 |             # print the tokens that contribute to the loss
134 | for i in range(len(q_ids), len(input_ids)):
135 | if labels[i] != -100:
136 | print(tokenizer.decode(input_ids[i]), end="")
137 | print()
138 |
139 | if max_data_count is None:
140 | max_data_count = len(self.items)
141 | np.random.seed(data_seed)
142 | np.random.shuffle(self.items)
143 | self.items = self.items[:max_data_count]
144 | print(f"Shuffled data with seed {data_seed} and got {len(self.items)} samples")
145 | print(f"Total loss calculation token count: {self.total_loss_calculation_token_count}")
146 |
147 | def __getitem__(self, idx):
148 | return self.items[idx]
149 |
150 | def __len__(self):
151 | return len(self.items)
152 |
153 | @staticmethod
154 | def collate_fn(batch):
155 | input_ids, labels = zip(*batch)
156 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
157 | labels = pad_sequence(labels, batch_first=True, padding_value=-100)
158 | return {"input_ids": input_ids, "labels": labels}
159 |
160 | if __name__ == "__main__":
161 | data = [
162 | {"problem": "What is the sum of the first 100 natural numbers?", "synthetic_data": "The sum of the first 100 natural numbers is 5050."},
163 | {"problem": "What is the sum of the first 100 natural numbers?", "synthetic_data": "The sum of the first 100 natural numbers is 5050. The sum of the first 100 natural numbers is 5050. "},
164 | {"problem": "What is the sum of the first 100 natural numbers?", "synthetic_data": "The sum of the first 100 natural numbers is 5050. test 1 Traceback (most recent call last) "},
165 | {"problem": "What is the sum of the first 100 natural numbers?", "synthetic_data": "The sum of the first 100 natural numbers is 5050. test 2 SyntaxError: I think the code is correct? "},
166 | {"problem": "What is the sum of the first 100 natural numbers?", "synthetic_data": "The sum of the first 100 natural numbers is 5050. test 3 Error: Code execution timed out. I think the code is correct. "},
167 | {"problem": "What is the sum of the first 100 natural numbers?", "synthetic_data": "The sum of the first 100 natural numbers is 5050. test 4 successful "},
168 | {"problem": "What is the sum of the first 100 natural numbers?", "synthetic_data": "The sum of the first 100 natural numbers is 5050. hahahaha \nWait, the code is not correct, let's try text reasoning.\n I think the code is correct. "},
169 | {"problem": "What is the sum of the first 100 natural numbers?", "synthetic_data": "The sum of the first 100 natural numbers is 5050. hahahaha Wait, use text reasoning is too tedious, let's try code reasoning. hahaha here is code! "},
170 | ]
171 | tokenizer = AutoTokenizer.from_pretrained("models/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B")
172 | train_data = TrainData(data, tokenizer, CODE_INSTRUCTION, debug=True)
--------------------------------------------------------------------------------
/sft/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import sys
4 | import gc, torch
5 | from torch.utils.data import Dataset
6 | from torch.optim import AdamW
7 | from torch.nn.utils.rnn import pad_sequence
8 | from tqdm import tqdm
9 | from accelerate import Accelerator
10 | from transformers import AutoTokenizer, AutoModelForCausalLM
11 | import time
12 | from typing import List, Union
13 | import numpy as np
14 | import random
15 |
16 | # Optional: Weights & Biases will only be used if a project name is passed.
17 | try:
18 | import wandb
19 | except ImportError:
20 | wandb = None
21 |
22 | parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
23 | sys.path.append(parent_dir)
24 |
25 | from sft.dataloader import TrainData
26 | from math_utils.utils import read_json_or_jsonl
27 | CODE_INSTRUCTION = """Meanwhile, you can use Python code to help you reasoning. The code should be enclosed within tags. For example, code here .
28 | A executor will run the code and provide feedback immediately after the code. The executor feedback should be enclosed within tags.
29 | You can use the executor feedback to improve your reasoning.
30 | """
31 |
32 | def train(model_path: str,
33 | data_path: Union[List[str], List[dict]],
34 | epochs: int,
35 | save_path: str,
36 | wandb_run=None,
37 | resume=False,
38 | resume_path=None,
39 | save_interval=1,
40 | batch_size=1,
41 | code_mode=False,
42 | max_data_count=None,
43 | data_seed=42,
44 | max_length=16384,
45 | lr=1e-5,
46 | gradient_accumulation_steps=4):
47 | """Train the model with the given parameters.
48 |
49 | Args:
50 | model_path (str): Path to the model checkpoint or identifier.
51 | data_path (List[str] or List[dict]): Path to the JSONL training data or list of training data.
52 | epochs (int): Number of training epochs.
53 | save_path (str): Where to save the fine-tuned model.
54 | wandb_run (wandb.Run or None): An optional wandb run instance.
55 | """
56 | accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)
57 | os.environ['TOKENIZERS_PARALLELISM'] = 'false'
58 |
59 | model_path = model_path[:-1] if model_path.endswith('/') else model_path
60 | if resume:
61 | load_path = resume_path
62 | start_epoch = int(resume_path.split('_')[-1]) + 1
63 | else:
64 | load_path = model_path
65 | start_epoch = 0
66 | tokenizer = AutoTokenizer.from_pretrained(load_path)
67 | model = AutoModelForCausalLM.from_pretrained(
68 | load_path,
69 | torch_dtype=torch.bfloat16,
70 | attn_implementation="flash_attention_2",
71 | device_map="auto"
72 | )
73 | model.config.use_cache = False
74 | model.gradient_checkpointing_enable()
75 |
76 | data = []
77 | for p in data_path:
78 | if isinstance(p, dict):
79 | data.append(p)
80 | else:
81 | data.extend(read_json_or_jsonl(p))
82 |
83 | if code_mode:
84 | dataset = TrainData(data, tokenizer, CODE_INSTRUCTION, max_data_count=max_data_count, data_seed=data_seed, max_length=max_length)
85 | else:
86 | dataset = TrainData(data, tokenizer, "", max_data_count=max_data_count, data_seed=data_seed, max_length=max_length)
87 | data_loader = torch.utils.data.DataLoader(
88 | dataset, collate_fn=dataset.collate_fn, shuffle=True,
89 | batch_size=batch_size, num_workers=1
90 | )
91 | optimizer = AdamW(model.parameters(), lr=lr)
92 | model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)
93 |
94 | global_step = start_epoch * len(data_loader)
95 | for epoch in range(start_epoch, epochs):
96 | accelerator.print(f'Training epoch {epoch}')
97 | accelerator.wait_for_everyone()
98 | model.train()
99 |
100 | tk0 = tqdm(data_loader, total=len(data_loader), disable=not accelerator.is_main_process)
101 | loss_report = []
102 |
103 |
104 | grad_acc = accelerator.gradient_accumulation_steps
105 |
106 | for step, batch in enumerate(tk0):
107 | with accelerator.accumulate(model):
108 | outputs = model(**batch)
109 | loss = outputs.loss
110 | accelerator.backward(loss)
111 |
112 |                 # --- gather the loss at every micro-step for statistics ---
113 | loss_val = accelerator.gather(loss.detach()).mean().item()
114 | loss_report.append(loss_val)
115 |
116 | if accelerator.sync_gradients:
117 | accelerator.clip_grad_norm_(model.parameters(), 1.0)
118 | optimizer.step()
119 | optimizer.zero_grad()
120 | # --- WandB loss ---
121 | if wandb_run is not None:
122 | window = loss_report[-grad_acc:]
123 | wandb_run.log(
124 | {"train_loss": sum(window) / len(window),
125 | "epoch": epoch},
126 | step=global_step
127 | )
128 | global_step += 1
129 |
130 | # --- average loss ---
131 | tk0.set_postfix(loss=sum(loss_report[-100:]) / len(loss_report[-100:]))
132 |
133 | if (epoch + 1) % save_interval == 0 or epoch == epochs - 1:
134 | accelerator.wait_for_everyone()
135 | unwrapped_model = accelerator.unwrap_model(model)
136 | unwrapped_model.save_pretrained(
137 | f'{save_path}_{epoch}',
138 | is_main_process=accelerator.is_main_process,
139 | save_function=accelerator.save,
140 | )
141 | tokenizer.save_pretrained(f'{save_path}_{epoch}')
142 | # Clean up
143 | del model, optimizer, data_loader, dataset, tokenizer
144 |
145 | def main():
146 | import argparse
147 | parser = argparse.ArgumentParser()
148 | parser.add_argument("--model_path", type=str, default='models/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B')
149 | parser.add_argument("--data_path", type=str, nargs='+', default=["data/dataset/train/dual_distill_data.jsonl"])
150 | parser.add_argument("--epochs", type=int, default=10)
151 | parser.add_argument("--resume", action='store_true')
152 | parser.add_argument("--resume_path", type=str, default=None)
153 | parser.add_argument("--code_mode", action='store_true')
154 | parser.add_argument("--save_path", type=str, default=None)
155 | parser.add_argument("--save_interval", type=int, default=5)
156 | parser.add_argument("--batch_size", type=int, default=1)
157 | parser.add_argument("--max_data_count", type=int, default=None)
158 | parser.add_argument("--data_seed", type=int, default=42)
159 | parser.add_argument("--max_length", type=int, default=16384)
160 | parser.add_argument(
161 | "--gradient_accumulation_steps",
162 | type=int,
163 | default=1,
164 | help="Number of batches to accumulate before each optimizer.step()"
165 | )
166 | parser.add_argument("--lr", type=float, default=1e-5)
167 | # Added W&B arguments
168 | parser.add_argument("--use_wandb", action='store_true', default=False)
169 | parser.add_argument("--wandb_project", type=str, default="dualdistill",
170 | help="If set, will enable wandb logging to the given project.")
171 | parser.add_argument("--wandb_run_name", type=str, default=None,
172 | help="An optional run name for wandb.")
173 |
174 | args = parser.parse_args()
175 | print(args)
176 |
177 | # fix all seeds
178 | torch.manual_seed(args.data_seed)
179 | torch.cuda.manual_seed(args.data_seed)
180 | torch.cuda.manual_seed_all(args.data_seed)
181 | np.random.seed(args.data_seed)
182 | random.seed(args.data_seed)
183 |
184 | assert args.epochs % args.save_interval == 0, "epochs must be divisible by save_interval"
185 |     if args.model_path is None:
186 | raise ValueError("model_path is required for training")
187 | time_tag = time.strftime("%Y%m%d_%H%M%S", time.localtime())
188 | if args.resume:
189 | if args.resume_path is None:
190 | raise ValueError("resume_path is required for resume")
191 | time_tag = "_".join(args.resume_path.split("_")[-3:-1])
192 | print(f"Resuming from {args.resume_path} with time tag {time_tag}")
193 |
194 | if args.save_path is None:
195 | if len(args.data_path) == 1:
196 | args.save_path = args.model_path.strip("/") + "_" + args.data_path[0].strip("/").split("/")[-1].split(".")[0] + "_fine-tuned" + "_" + time_tag
197 | else:
198 | args.save_path = args.model_path.strip("/") + "_" + args.data_path[0].strip("/").split("/")[-2].split(".")[0] + "_mixed_data_" + "fine-tuned" + "_" + time_tag
199 | else:
200 | args.save_path = args.model_path.strip("/") + "_" + args.save_path.strip("/") + "_" + time_tag
201 |
202 | if args.wandb_run_name is None:
203 | args.wandb_run_name = (args.save_path.strip("/").split("/")[-1] +
204 | "_" + str(args.code_mode) +
205 | "_" + str(args.epochs) +
206 | "_" + time_tag)
207 |
208 | # Initialize wandb if user has set a project
209 | wandb_run = None
210 | if args.use_wandb and wandb is not None:
211 | config = vars(args)
212 | wandb_run = wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=config)
213 |
214 | train(
215 | model_path=args.model_path,
216 | data_path=args.data_path,
217 | epochs=args.epochs,
218 | save_path=args.save_path,
219 | wandb_run=wandb_run,
220 | resume=args.resume,
221 | resume_path=args.resume_path,
222 | save_interval=args.save_interval,
223 | batch_size=args.batch_size,
224 | code_mode=args.code_mode,
225 | max_data_count=args.max_data_count,
226 | data_seed=args.data_seed,
227 | max_length=args.max_length,
228 | lr=args.lr,
229 | gradient_accumulation_steps=args.gradient_accumulation_steps
230 | )
231 |
232 | if wandb_run is not None:
233 | wandb_run.finish()
234 |
235 | if __name__ == '__main__':
236 | main()
--------------------------------------------------------------------------------
/dataset/test/aime2025.jsonl:
--------------------------------------------------------------------------------
1 | {"data_source":"aime2025","prompt":[{"content":"Find the sum of all integer bases $b>9$ for which $17_{b}$ is a divisor of $97_{b}$.","role":"user"}],"ability":"math","reward_model":{"ground_truth":"70","num_tokens":-512,"style":"rule"},"extra_info":{"index":0,"split":"test"}}
2 | {"data_source":"aime2025","prompt":[{"content":"On $\\triangle ABC$ points $A,D,E$, and $B$ lie that order on side $\\overline{AB}$ with $AD=4, DE=16$, and $EB=8$. Points $A,F,G$, and $C$ lie in that order on side $\\overline{AC}$ with $AF=13, FG=52$, and $GC=26$. Let $M$ be the reflection of $D$ through $F$, and let $N$ be the reflection of $G$ through $E$. Quadrilateral $DEGF$ has area 288. Find the area of heptagon $AFNBCEM$.","role":"user"}],"ability":"math","reward_model":{"ground_truth":"588","num_tokens":-512,"style":"rule"},"extra_info":{"index":1,"split":"test"}}
3 | {"data_source":"aime2025","prompt":[{"content":"The 9 members of a baseball team went to an ice cream parlor after their game. Each player had a singlescoop cone of chocolate, vanilla, or strawberry ice cream. At least one player chose each flavor, and the number of players who chose chocolate was greater than the number of players who chose vanilla, which was greater than the number of players who chose strawberry. Let $N$ be the number of different assignments of flavors to players that meet these conditions. Find the remainder when $N$ is divided by 1000.","role":"user"}],"ability":"math","reward_model":{"ground_truth":"16","num_tokens":-512,"style":"rule"},"extra_info":{"index":2,"split":"test"}}
4 | {"data_source":"aime2025","prompt":[{"content":"Find the number of ordered pairs $(x,y)$, where both $x$ and $y$ are integers between $-100$ and $100$, inclusive, such that $12x^{2}-xy-6y^{2}=0$.","role":"user"}],"ability":"math","reward_model":{"ground_truth":"117","num_tokens":-512,"style":"rule"},"extra_info":{"index":3,"split":"test"}}
5 | {"data_source":"aime2025","prompt":[{"content":"There are $8!=40320$ eight-digit positive integers that use each of the digits $1,2,3,4,5,6,7,8$ exactly once. Let $N$ be the number of these integers that are divisible by 22. Find the difference between $N$ and 2025.","role":"user"}],"ability":"math","reward_model":{"ground_truth":"279","num_tokens":-512,"style":"rule"},"extra_info":{"index":4,"split":"test"}}
6 | {"data_source":"aime2025","prompt":[{"content":"An isosceles trapezoid has an inscribed circle tangent to each of its four sides. The radius of the circle is 3, and the area of the trapezoid is 72. Let the parallel sides of the trapezoid have lengths $r$ and $s$, with $r \\neq s$. Find $r^{2}+s^{2}$.","role":"user"}],"ability":"math","reward_model":{"ground_truth":"504","num_tokens":-512,"style":"rule"},"extra_info":{"index":5,"split":"test"}}
7 | {"data_source":"aime2025","prompt":[{"content":"The twelve letters $A,B,C,D,E,F,G,H,I,J,K$, and $L$ are randomly grouped into six pairs of letters. The two letters in each pair are placed next to each other in alphabetical order to form six two-letter words, and those six words are listed alphabetically. For example, a possible result is $AB,CJ,DG,EK,FL,HI$. The probability that the last word listed contains $G$ is $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. Find $m+n$.","role":"user"}],"ability":"math","reward_model":{"ground_truth":"821","num_tokens":-512,"style":"rule"},"extra_info":{"index":6,"split":"test"}}
8 | {"data_source":"aime2025","prompt":[{"content":"Let $k$ be real numbers such that the system $|25+20i-z|=5$ and $|z-4-k|=|z-3i-k|$ has exactly one complex solution $z$. The sum of all possible values of $k$ can be written as $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. Find $m+n$. Here $i=\\sqrt{-1}$.","role":"user"}],"ability":"math","reward_model":{"ground_truth":"77","num_tokens":-512,"style":"rule"},"extra_info":{"index":7,"split":"test"}}
9 | {"data_source":"aime2025","prompt":[{"content":"The parabola with equation $y=x^{2}-4$ is rotated $60^{\\circ}$ counterclockwise around the origin. The unique point in the fourth quadrant where the original parabola and its image intersect has $y$-coordinate $\\frac{a-\\sqrt{b}}{c}$, where $a$, $b$, and $c$ are positive integers, and $a$ and $c$ are relatively prime. Find $a+b+c$.","role":"user"}],"ability":"math","reward_model":{"ground_truth":"62","num_tokens":-512,"style":"rule"},"extra_info":{"index":8,"split":"test"}}
10 | {"data_source":"aime2025","prompt":[{"content":"The 27 cells of a $3\\times9$ grid are filled in using the numbers 1 through 9 so that each row contains 9 different numbers, and each of the three $3\\times3$ blocks heavily outlined in the example below contains 9 different numbers, as in the first three rows of a Sudoku puzzle. \n | 4 | 2 | 8 | 9 | 6 | 3 | 1 | 7 | 5 | \n | 3 | 7 | 9 | 5 | 2 | 1 | 6 | 8 | 4 | \n | 5 | 6 | 1 | 8 | 4 | 7 | 9 | 2 | 3 | \n The number of different ways to fill such a grid can be written as $p^a\\cdot q^b\\cdot r^c\\cdot s^d$, where $p,q,r,$ and $s$ are distinct prime numbers and $a,b,c,$ and $d$ are positive integers. Find $p\\cdot a+q\\cdot b+r\\cdot c+s\\cdot d$.","role":"user"}],"ability":"math","reward_model":{"ground_truth":"81","num_tokens":-512,"style":"rule"},"extra_info":{"index":9,"split":"test"}}
11 | {"data_source":"aime2025","prompt":[{"content":"A piecewise linear periodic function is defined by $f(x)=\\begin{cases}x&\\text{if }x\\in[-1,1)\\\\2-x&\\text{if }x\\in[1,3)\\end{cases}$ and $f(x+4)=f(x)$ for all real numbers $x$. The graph of $f(x)$ has the sawtooth pattern. The parabola $x=34y^2$ intersects the graph of $f(x)$ at finitely many points. The sum of the $y$-coordinates of these intersection points can be expressed in the form $\\frac{a+b\\sqrt{c}}{d}$, where $a,b,c,$ and $d$ are positive integers, $a,b,$ and $d$ have greatest common divisor equal to 1, and $c$ is not divisible by the square of any prime. Find $a+b+c+d$.","role":"user"}],"ability":"math","reward_model":{"ground_truth":"259","num_tokens":-512,"style":"rule"},"extra_info":{"index":10,"split":"test"}}
12 | {"data_source":"aime2025","prompt":[{"content":"The set of points in 3-dimensional coordinate space that lie in the plane $x+y+z=75$ whose coordinates satisfy the inequalities $x-yz>')\n" + raw_code
53 |
54 | # incremental code execution
55 | os.makedirs("workspace", exist_ok=True)
56 | with tempfile.NamedTemporaryFile("w", dir="workspace", suffix=".py", delete=False, encoding='utf-8') as tmp_file:
57 | tmp_file.write(raw_code)
58 | tmp_filename = tmp_file.name
59 | try:
60 | result = subprocess.run(
61 | ["python", tmp_filename],
62 | capture_output=True,
63 | text=True,
64 | timeout=3
65 | )
66 | os.remove(tmp_filename)
67 | output = result.stdout.split("<>\n")[-1]
68 | if result.stderr:
69 | return output, result.stderr, previous_code
70 | else:
71 | return output, result.stderr, previous_code + "\n" + raw_code
72 | except subprocess.TimeoutExpired:
73 | os.remove(tmp_filename)
74 | return "", "Error: Code execution timed out.", previous_code
75 |
76 | def code_block_with_io(raw_code: str):
77 | code = raw_code.replace("```python", "").replace("```", "")
78 |
79 | stdout_capture = io.StringIO()
80 | stderr_capture = io.StringIO()
81 |
82 | env = {}
83 |
84 | try:
85 | tree = ast.parse(code, mode='exec')
86 |
87 | with contextlib.redirect_stdout(stdout_capture), contextlib.redirect_stderr(stderr_capture):
88 | for node in tree.body:
89 | if isinstance(node, ast.Expr):
90 | expr_code = compile(ast.Expression(node.value), filename="", mode="eval")
91 | result = eval(expr_code, env)
92 | if result is not None:
93 | print(result)
94 | else:
95 | stmt_code = compile(ast.Module([node], []), filename="", mode="exec")
96 | exec(stmt_code, env)
97 | except Exception as e:
98 | stderr_capture.write(str(e))
99 |
100 | return stdout_capture.getvalue(), stderr_capture.getvalue()
101 |
102 | def parse_boxed(text, reverse=False):
103 | """
104 | Returns a list of all the contents inside \\boxed{...} in `text`,
105 | handling nested braces to a reasonable extent.
106 | """
107 | results = []
108 | search_start = 0
109 | marker = r'\boxed{'
110 |
111 | while True:
112 | # Look for the next occurrence of \boxed{
113 | start_index = text.find(marker, search_start)
114 | if start_index == -1:
115 | # No more \boxed{ found
116 | break
117 |
118 | # The position right after '\boxed{'
119 | brace_start = start_index + len(marker)
120 |
121 | # Use a stack to find the matching '}'
122 | brace_count = 1
123 | pos = brace_start
124 |
125 | while pos < len(text) and brace_count > 0:
126 | if text[pos] == '{':
127 | brace_count += 1
128 | elif text[pos] == '}':
129 | brace_count -= 1
130 | pos += 1
131 |
132 | # If brace_count == 0, 'pos-1' is where the matching '}' was found
133 | if brace_count == 0:
134 | content = text[brace_start : pos - 1]
135 | results.append(content)
136 | # Continue searching after this boxed content
137 | search_start = pos
138 | else:
139 | # We reached the end of the text without finding a matching brace
140 | break
141 | if len(results) == 0:
142 | return "No Answer"
143 | if not reverse:
144 | return results[0]
145 | else:
146 | return results[-1]
147 |
148 | def read_json_or_jsonl(file_path):
149 | with open(file_path, 'r') as f:
150 | if file_path.endswith('.json'):
151 | return json.load(f)
152 | elif file_path.endswith('.jsonl'):
153 | return [json.loads(line) for line in f]
154 |
155 | SYSTEM_PROMPT_TPL = (
156 | "A conversation between User and Assistant. The user asks a question, "
157 | "and the Assistant solves it.\n"
158 | "The assistant first thinks about the reasoning process in the mind and then "
159 | "provides the user with the answer. \n"
160 | "The reasoning process and answer are enclosed within and "
161 | " tags, respectively, i.e., reasoning process here "
162 | " answer here .\n\n"
163 | "The final answer should be enclosed within \\boxed tags, i.e., "
164 | "\\boxed{{answer here}}.\n\n"
165 | "{code_instruction}\n\n"
166 | "Do not write text outside the tags."
167 | )
168 |
169 | CODE_INSTRUCTION = """Meanwhile, you can use Python code to help you reason. The code should be enclosed within tags. For example, code here .
170 | An executor will run the code and provide feedback immediately after the code. The executor feedback should be enclosed within tags.
171 | You can use the executor feedback to improve your reasoning.
172 | """
173 |
174 | def add_comma_into_number(number_str):
175 | try:
176 | number = float(number_str)
177 | if number.is_integer():
178 | return "{:,}".format(int(number))
179 | else:
180 | return "{:,}".format(number)
181 | except (ValueError, TypeError):
182 | return number_str
183 |
184 | def compute_score(solution_str, ground_truth, able_to_use_original_solution=False) -> float:
185 | ground_truth = str(ground_truth)
186 |
187 |     # for non-finetuned models (i.e., models other than Agentic-R1), keep the original solution string for an additional parsing pass
188 | original_solution_str = solution_str
189 |
190 | # For the agentic trajectory, remove the redundant part
191 | if " " in solution_str:
192 | interim_strs = ["", "", "", ""]
193 | found = [(solution_str.rfind(s), i) for i, s in enumerate(interim_strs) if solution_str.rfind(s) != -1]
194 | if found:
195 | last_interim_str, last_interim_str_idx = max(found)
196 | solution_str = solution_str[:last_interim_str + len(interim_strs[last_interim_str_idx])]
197 |
198 | if " tags. For example, code here .
53 | A executor will run the code and provide feedback immediately after the code. The executor feedback should be enclosed within ")[-1].split("")[0]
90 | code_output, error, previous_code = code_block(raw_code, is_ipython=is_ipython, previous_code=previous_code)
91 | except Exception as e:
92 | error = str(e)
93 | code_output = "Error"
94 | if error:
95 | executor_feedback = f"\n tags. For example, code here .
156 | A executor will run the code and provide feedback immediately after the code. The executor feedback should be enclosed within ")[-1].split("")[0]
201 | code_output, error, previous_code = code_block(raw_code, is_ipython=is_ipython, previous_code=previous_code)
202 | except Exception as e:
203 | error = str(e)
204 | code_output = "Error"
205 |
206 | if error:
207 | executor_feedback = f"\n