├── README.md
└── train_code
    ├── .DS_Store
    ├── arc_data
    │   ├── arc_7k_train_data.jsonl
    │   ├── arc_7k_val_data.jsonl
    │   ├── arc_challenge_train_val
    │   │   ├── train_data.jsonl
    │   │   └── val_data.jsonl
    │   ├── arc_small_200_query_data.jsonl
    │   ├── arc_test_data.jsonl
    │   └── arc_train_data.jsonl
    ├── arc_prompt
    │   ├── 0-shot-prompt.txt
    │   └── 1-shot-prompt.txt
    ├── arguments.py
    ├── cal_metric_vllm.py
    ├── checkpoint_utils.py
    ├── convert4reward_auto_ground_sample.py
    ├── convert_auto.py
    ├── convert_checkpoint_to_hf.py
    ├── data_utils
    │   ├── __pycache__
    │   │   ├── data_utils_ppo.cpython-310.pyc
    │   │   ├── data_utils_ppo.cpython-311.pyc
    │   │   └── data_utils_rm_pointwise.cpython-310.pyc
    │   ├── data_utils_dpo.py
    │   ├── data_utils_ppo.py
    │   ├── data_utils_rm_pairwise.py
    │   ├── data_utils_rm_pointwise.py
    │   └── data_utils_sft.py
    ├── determine_hyper.py
    ├── eval_arc_save_metrics.py
    ├── gsm8k_test.py
    ├── hf_argparser.py
    ├── infer_arc.sh
    ├── infer_gsm8k1.sh
    ├── inference_reward_llama3.py
    ├── math_utils
    │   ├── README.md
    │   ├── __pycache__
    │   │   ├── grader.cpython-310.pyc
    │   │   ├── grader.cpython-311.pyc
    │   │   ├── math_normalize.cpython-310.pyc
    │   │   ├── math_normalize.cpython-311.pyc
    │   │   ├── math_rl_utils.cpython-310.pyc
    │   │   └── math_rl_utils.cpython-311.pyc
    │   ├── grader.py
    │   ├── math_normalize.py
    │   └── math_rl_utils.py
    ├── metric_modiacc_auto.py
    ├── models
    │   ├── __pycache__
    │   │   ├── frozen_layers.cpython-310.pyc
    │   │   ├── frozen_layers.cpython-311.pyc
    │   │   ├── frozen_layers.cpython-39.pyc
    │   │   ├── model.cpython-310.pyc
    │   │   ├── model.cpython-311.pyc
    │   │   ├── model.cpython-39.pyc
    │   │   ├── quantize.cpython-310.pyc
    │   │   ├── quantize.cpython-311.pyc
    │   │   ├── reward_model.cpython-310.pyc
    │   │   ├── reward_model.cpython-311.pyc
    │   │   ├── rl_model.cpython-310.pyc
    │   │   ├── rl_model.cpython-311.pyc
    │   │   ├── tokenizer_utils.cpython-310.pyc
    │   │   ├── tokenizer_utils.cpython-311.pyc
    │   │   ├── tokenizer_utils.cpython-38.pyc
    │   │   ├── tp.cpython-310.pyc
    │   │   └── tp.cpython-311.pyc
    │   ├── frozen_layers.py
    │   ├── model.py
    │   ├── quantize.py
    │   ├── reward_model.py
    │   ├── rl_model.py
    │   ├── tokenizer_utils.py
    │   └── tp.py
    ├── scripts
    │   ├── convert.sh
    │   ├── convert_checkpoint_to_hf.py
    │   ├── convert_hf_checkpoint.py
    │   ├── convert_hf_checkpoint_llama3.py
    │   ├── download.py
    │   ├── prepare_ds_math_7b.sh
    │   ├── prepare_llemma_34b.sh
    │   └── prepare_llemma_7b.sh
    ├── test.jsonl
    ├── test_ppo.json
    ├── train_bstar.sh
    ├── train_reward.sh
    ├── train_rm_pointwise.py
    ├── train_sft.py
    ├── train_sft.sh
    ├── train_sft_step.py
    ├── trainers
    │   ├── __pycache__
    │   │   ├── common_utils.cpython-310.pyc
    │   │   ├── common_utils.cpython-311.pyc
    │   │   ├── ppo_trainer.cpython-310.pyc
    │   │   ├── ppo_trainer.cpython-311.pyc
    │   │   ├── rl_trainer.cpython-310.pyc
    │   │   └── rl_trainer.cpython-311.pyc
    │   ├── common_utils.py
    │   ├── ppo_trainer.py
    │   └── rl_trainer.py
    ├── training_utils
    │   ├── __pycache__
    │   │   ├── fsdp_utils.cpython-310.pyc
    │   │   ├── fsdp_utils.cpython-311.pyc
    │   │   ├── memory_efficient_adam.cpython-310.pyc
    │   │   ├── memory_efficient_adam.cpython-311.pyc
    │   │   ├── trainer_utils.cpython-310.pyc
    │   │   └── trainer_utils.cpython-311.pyc
    │   ├── fsdp_utils.py
    │   ├── memory_efficient_adam.py
    │   └── trainer_utils.py
    ├── vllm_infer.py
    ├── vllm_infer_arc.py
    └── vllm_infer_auto.py
/README.md:
--------------------------------------------------------------------------------
1 | # B-STAR: Monitoring and Balancing Exploration and Exploitation in Self-Taught Reasoners
2 |
3 |
4 |
5 | 📄 [Paper](https://arxiv.org/abs/2412.17256)
6 |
8 |
9 |
10 | B-STAR (Balanced Self-Taught Reasoner) is a framework designed to improve the self-improvement process of reasoning models by dynamically balancing exploration and exploitation throughout training. This approach is particularly effective in enhancing performance in tasks requiring complex reasoning, such as mathematical problem-solving, coding, and commonsense reasoning.
11 |
12 |
13 | 
14 |
15 |
16 | ## Overview
17 |
18 | Self-improvement in reasoning models involves iterative training in which models generate their own training data from their outputs. However, existing methods often stagnate after a few iterations due to imbalances between two critical factors:
19 |
20 | 1. **Exploration**: The model's ability to generate diverse and high-quality responses.
21 | 2. **Exploitation**: The effectiveness of external rewards in distinguishing and leveraging high-quality responses.
22 |
23 | 
24 |
25 | B-STAR introduces an adaptive mechanism to monitor and balance these factors dynamically, ensuring consistent performance improvements over multiple training iterations.
26 |
27 |
28 | ## Key Features
29 |
30 | - **Dynamic Configuration Adjustments**: Automatically tunes exploration and exploitation configurations (e.g., sampling temperature, reward thresholds) to optimize the self-improvement process.
31 | - **Balance Score Metric**: Quantifies the interplay between exploration and exploitation and guides the dynamic adjustments (see the illustrative sketch below).
32 | - **Generalization Across Tasks**: Demonstrates effectiveness in mathematical reasoning, coding challenges, and commonsense reasoning tasks.
33 |
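The exact balance score and configuration-update rule are specified in the paper; the sketch below is only a minimal illustration of the idea, and every function name, formula, and constant in it is an assumption made for this example rather than the released implementation:

```python
def balance_score(n_solved_queries, n_queries, n_correct_selected, n_selected):
    """Illustrative only: combine an exploration term (fraction of queries with at
    least one correct response) with an exploitation term (fraction of the
    reward-selected responses that are actually correct)."""
    exploration = n_solved_queries / max(n_queries, 1)
    exploitation = n_correct_selected / max(n_selected, 1)
    return exploration * exploitation


def adjust_configs(temperature, reward_threshold, score, prev_score):
    """Toy update rule: when the balance score stops improving, favor exploration
    (higher sampling temperature, looser reward threshold); otherwise tighten."""
    if score <= prev_score:
        temperature = min(temperature + 0.1, 1.2)
        reward_threshold = max(reward_threshold - 0.05, 0.0)
    else:
        temperature = max(temperature - 0.05, 0.7)
        reward_threshold = min(reward_threshold + 0.05, 1.0)
    return temperature, reward_threshold
```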
34 |
35 | ## Results
36 |
37 | B-STAR achieves state-of-the-art performance across various benchmarks:
38 |
39 | - Significant improvements compared to previous self-improvement methods.
40 | 
41 |
42 |
43 | - Sustained performance growth across multiple iterations, outperforming existing methods that stagnate after a few iterations.
44 | 
45 |
46 | ## Reproduction
47 |
48 | Our code builds upon [easy-to-hard](https://github.com/Edward-Sun/easy-to-hard/tree/main) and [gpt-accelera](https://github.com/Edward-Sun/gpt-accelera). Please refer to gpt-accelera for environment setup and model weight conversion instructions.
49 |
50 | ### 1. Prepare Model
51 |
52 | We first need to prepare the model checkpoint in the gpt-fast format.
53 |
54 | ```shell
55 | export DATA_DIR=/path/to/your/data/directory
56 | export MODEL_REPO=mistralai/Mistral-7B-v0.1
57 |
58 | python scripts/download.py \
59 | --repo_id $MODEL_REPO \
60 | --local_dir $DATA_DIR/checkpoints
61 |
62 | python scripts/convert_hf_checkpoint.py \
63 | --checkpoint_dir $DATA_DIR/checkpoints/$MODEL_REPO \
64 | --target_precision bf16
65 | ```
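The conversion produces a gpt-fast-format checkpoint (`model.pth`) inside the converted checkpoint directory; the training commands below load it via `--checkpoint_path $MODEL_REPO/model.pth`.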
66 |
67 | ### 2. Train SFT Model
68 |
69 | ```shell
70 | export DATA_DIR=/path/to/your/data/directory
71 | export MODEL_REPO=$DATA_DIR/checkpoints/Mistral-7B-v0.1
72 |
73 | export OMP_NUM_THREADS=8
74 |
75 |
76 | SFT_TRAIN_DATA=https://huggingface.co/datasets/AndrewZeng/math-trn-format/blob/main/math_format.json
77 |
78 | # Please download this dataset to local folder
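# One possible way to do that (assumes the `huggingface-cli` tool from huggingface_hub is installed;
# the local target path below is just an example):
#   huggingface-cli download AndrewZeng/math-trn-format math_format.json \
#     --repo-type dataset --local-dir $DATA_DIR/sft_data
#   SFT_TRAIN_DATA=$DATA_DIR/sft_data/math_format.json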
79 | SFT_MODEL_SAVE_NAME=math_format_11k_mistral
80 |
81 | torchrun --standalone --nproc_per_node=8 \
82 | train_sft.py \
83 | --do_train \
84 | --checkpoint_path $MODEL_REPO/model.pth \
85 | --source_max_len 768 \
86 | --target_max_len 768 \
87 | --total_max_len 1024 \
88 | --per_device_train_batch_size 16 \
89 | --micro_train_batch_size 4 \
90 | --learning_rate 5e-6 \
91 | --lr_eta_min 2e-7 \
92 | --num_train_epochs 3 \
93 | --dataset "$SFT_TRAIN_DATA" \
94 | --dataset_format "metamath" \
95 | --add_eos_to_marked_target \
96 | --save_strategy "steps" \
97 | --save_steps 25 \
98 | --optim_dtype bf16 \
99 | --save_total_limit 40 \
100 | --tensor_parallel_size 1 \
101 | --save_dir $DATA_DIR/checkpoints/$SFT_MODEL_SAVE_NAME \
102 | --resume_from_checkpoint
103 | ```
104 |
105 | ### 3. Train PRM Model
106 |
107 | We constructed the [PRM training data](https://huggingface.co/datasets/AndrewZeng/prm-reward-data) using the [math-shepherd](https://arxiv.org/abs/2312.08935) approach and trained the reward model using a pointwise objective.
108 |
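For intuition, a pointwise process-reward objective scores each solution step independently and fits it to a 0/1 correctness label with binary cross-entropy. The snippet below is only a schematic sketch of such a loss; the tensor names and shapes are illustrative and are not taken from `train_rm_pointwise.py`:

```python
import torch.nn.functional as F

def pointwise_prm_loss(step_logits, step_labels, step_mask):
    # step_logits, step_labels, step_mask: (batch, num_steps) tensors; labels are 1.0
    # for steps judged correct (math-shepherd style), and the mask zeroes out padding.
    per_step = F.binary_cross_entropy_with_logits(step_logits, step_labels, reduction="none")
    return (per_step * step_mask).sum() / step_mask.sum().clamp(min=1.0)
```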
109 | ```shell
110 | export DATA_DIR=/path/to/your/data/directory
111 |
112 | export MODEL_REPO=$DATA_DIR/checkpoints/Mistral-7B-v0.1
113 | export OMP_NUM_THREADS=4
114 |
115 |
116 | RM_DATA=train_prm_math_shepherd_mistral.json
117 | RM_MODEL_SAVE_NAME=prm_model_mistral_sample_complete
118 |
119 | torchrun --standalone --nproc_per_node=8 \
120 | train_rm_pointwise.py \
121 | --do_train \
122 | --checkpoint_path $MODEL_REPO/model.pth \
123 | --source_max_len 768 \
124 | --target_max_len 768 \
125 | --total_max_len 1024 \
126 | --per_device_train_batch_size 32 \
127 | --micro_train_batch_size 32 \
128 | --learning_rate 2e-6 \
129 | --lr_eta_min 2e-7 \
130 | --num_train_epochs 2 \
131 | --dataset "$RM_DATA" \
132 | --dataset_format "prm-v4" \
133 | --save_strategy epoch \
134 | --save_total_limit 5 \
135 | --train_on_every_token \
136 | --tensor_parallel_size 1 \
137 | --save_only_model True \
138 | --optim_dtype bf16 \
139 | --save_dir $DATA_DIR/checkpoints/$RM_MODEL_SAVE_NAME \
140 | --resume_from_checkpoint
141 | ```
142 |
143 | ### 4. Train B-STaR
144 |
145 | ```shell
146 | ## This is our initial code release.
147 | ## We are working on cleaning it up to make it clearer and more readable.
148 | cd train_code
149 | bash train_bstar.sh
150 | ```
151 |
152 | ### 5. Evaluation
153 |
154 | Coming soon!
155 |
156 | ## Citation
157 |
158 | If you find B-STaR useful, please cite our paper:
159 |
160 | ```
161 | @article{zeng2024bstar,
162 | title={B-STAR: Monitoring and Balancing Exploration and Exploitation in Self-Taught Reasoners},
163 | author={Weihao Zeng and Yuzhen Huang and Lulu Zhao and Yijun Wang and Zifei Shan and Junxian He},
164 | journal={arXiv preprint arXiv:2412.17256},
165 | year={2024},
166 | url={https://arxiv.org/abs/2412.17256}
167 | }
168 | ```
169 |
170 |
171 |
--------------------------------------------------------------------------------
/train_code/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/.DS_Store
--------------------------------------------------------------------------------
/train_code/arc_prompt/0-shot-prompt.txt:
--------------------------------------------------------------------------------
1 | Your task is to answer the question below. Give step by step reasoning before you answer.
2 |
3 | Question: {{question}}
4 | Solution:
5 |
--------------------------------------------------------------------------------
/train_code/arc_prompt/1-shot-prompt.txt:
--------------------------------------------------------------------------------
1 | Your task is to answer the question below. Give step by step reasoning before you answer.
2 |
3 | Question: George wants to warm his hands quickly by rubbing them. Which skin surface will produce the most heat?
4 | (A) dry palms
5 | (B) wet palms
6 | (C) palms covered with oil
7 | (D) palms covered with lotion
8 | Solution:
9 | To determine which skin surface will produce the most heat when George rubs his hands, let's analyze each option step by step.
10 | 1. Heat Production and Friction: When two surfaces are rubbed together, the heat generated is a result of friction between those surfaces. The higher the friction, the more heat is produced.
11 | 2. Effect of Different Surface Conditions:
12 | Dry Palms (Option A): Dry palms have a relatively high friction coefficient, as they lack any lubricants or moisture that could reduce friction. This means rubbing dry palms together would generate significant friction and, thus, more heat.
13 | Wet Palms (Option B): Wet palms reduce friction due to the presence of water. Water acts as a lubricant, decreasing the friction between the palms, resulting in less heat production.
14 | Palms Covered with Oil (Option C): Oil is an effective lubricant. It reduces friction significantly, which would decrease the heat produced when rubbing the palms.
15 | Palms Covered with Lotion (Option D): Lotion, similar to oil, acts as a lubricant and reduces friction between the palms, leading to lower heat production compared to dry palms.
16 | 3. Conclusion: Since friction is highest with dry palms, they will produce the most heat when rubbed together, as no lubricants are present to reduce friction.
17 | Final answer: (A)
18 |
19 | Question: {{question}}
20 | Solution:
21 |
--------------------------------------------------------------------------------
/train_code/cal_metric_vllm.py:
--------------------------------------------------------------------------------
1 | import json
2 | from models.tokenizer_utils import AcceleraTokenizer
3 | #from math_utils.math_rl_utils import post_process_math_rollouts
4 |
5 | from typing import Dict, List, Optional, Tuple
6 | import torch
7 | import trainers.common_utils as common_utils
8 | from itertools import chain
9 | from math_utils import grader
10 | import torch.distributed as dist
11 | from pathlib import Path
12 | def last_boxed_only_string(string):
13 | idx = string.rfind("\\boxed")
14 | if idx < 0:
15 | return None
16 |
17 | i = idx
18 | right_brace_idx = None
19 | num_left_braces_open = 0
20 | while i < len(string):
21 | if string[i] == "{":
22 | num_left_braces_open += 1
23 | if string[i] == "}":
24 | num_left_braces_open -= 1
25 | if num_left_braces_open == 0:
26 | right_brace_idx = i
27 | break
28 | i += 1
29 |
30 | if right_brace_idx == None:
31 | retval = None
32 | else:
33 | retval = string[idx : right_brace_idx + 1]
34 |
35 | return retval
36 |
37 |
38 | def remove_boxed(s):
39 | left = "\\boxed{"
40 | try:
41 | assert s[: len(left)] == left
42 | assert s[-1] == "}"
43 | return s[len(left) : -1]
44 | except:
45 | return None
46 | def _calculate_outcome_accuracy(
47 | predicted_answers: List[str],
48 | gt_answers: List[str],
49 | answers: List[str],
50 | levels: List[int],
51 | outcome_reward: bool,
52 | easy_outcome_reward: bool,
53 | device: torch.device,
54 | ):
55 | assert len(predicted_answers) == len(answers)
56 |
57 | assert not (
58 | outcome_reward and easy_outcome_reward
59 | ), "Cannot use both outcome_reward and easy_outcome_reward."
60 |
61 | with common_utils.DisableLogger():
62 | outcome_accuracy = [
63 | 1.0 if grader.grade_answer(predicted_answer, gt_answer) else 0.0
64 | for predicted_answer, gt_answer in zip(predicted_answers, gt_answers)
65 | ]
66 |
67 | # TODO (zhiqings): 0.25 is a magic number.
68 | unavailable_reward = 0.25
69 | if outcome_reward:
70 | symbolic_rewards = outcome_accuracy
71 | elif easy_outcome_reward:
72 | symbolic_rewards = []
73 | for predicted_answer, answer in zip(predicted_answers, answers):
74 | if answer == "Unavailable":
75 | score = unavailable_reward
76 | elif grader.grade_answer(predicted_answer, answer):
77 | score = 1.0
78 | else:
79 | score = 0.0
80 | symbolic_rewards.append(score)
81 | else:
82 | symbolic_rewards = [
83 | unavailable_reward for _ in range(len(predicted_answers))
84 | ]
85 |
86 | assert len(symbolic_rewards) == len(predicted_answers)
87 |
88 | per_level_counts = {}
89 | per_level_accuracy = {}
90 |
91 | all_unique_levels = list(range(1, 6))
92 |
93 | for level in all_unique_levels:
94 | per_level_counts[level] = []
95 | per_level_accuracy[level] = []
96 |
97 | for level, accuracy in zip(levels, outcome_accuracy):
98 | for unique_level in all_unique_levels:
99 | if level == unique_level:
100 | per_level_counts[unique_level].append(1.0)
101 | per_level_accuracy[unique_level].append(accuracy)
102 | else:
103 | per_level_counts[unique_level].append(0.0)
104 | per_level_accuracy[unique_level].append(0.0)
105 |
106 | for level in all_unique_levels:
107 | assert len(per_level_counts[level]) == len(outcome_accuracy)
108 | assert len(per_level_accuracy[level]) == len(outcome_accuracy)
109 | per_level_counts[level] = torch.tensor(per_level_counts[level], device=device)
110 | per_level_accuracy[level] = torch.tensor(
111 | per_level_accuracy[level], device=device
112 | )
113 |
114 | original_symbolic_rewards = symbolic_rewards
115 |
116 | symbolic_rewards = torch.tensor(symbolic_rewards, device=device)
117 | outcome_accuracy = torch.tensor(outcome_accuracy, device=device)
118 |
119 | ret_dict = {
120 | "symbolic_rewards": symbolic_rewards,
121 | "outcome_accuracy": outcome_accuracy,
122 | }
123 |
124 | for level in sorted(list(all_unique_levels)):
125 | ret_dict[f"level_{level}_counts"] = per_level_counts[level]
126 | ret_dict[f"level_{level}_accuracy"] = per_level_accuracy[level]
127 |
128 | return ret_dict, original_symbolic_rewards
129 | def merge_fn(tensor_or_list):
130 | if isinstance(tensor_or_list[0], list):
131 | return list(chain(*tensor_or_list))
132 | else:
133 | return torch.cat(tensor_or_list, dim=0)
134 | def post_process_math_rollouts(
135 | text_responses: List[str],
136 | answers: List[str],
137 | gt_answers: List[str],
138 | levels: List[str],
139 | tokenizer: AcceleraTokenizer,
140 | stop_token: Optional[str],
141 | outcome_reward: bool,
142 | easy_outcome_reward: bool,
143 | device: torch.device,
144 | ):
145 | if stop_token is not None:
146 | parsed_stop_token = stop_token
147 | parsed_stop_token = parsed_stop_token.replace(r"\n", "\n")
148 | parsed_stop_token = parsed_stop_token.replace(r"\\", "\\")
149 | else:
150 | parsed_stop_token = tokenizer.eos_token
151 |
152 | predicted_answers = []
153 | for text_response in text_responses:
154 | predicted_answer = "No answer found."
155 | if "\n\n" in parsed_stop_token:
156 | if parsed_stop_token in text_response:
157 | predicted_answer = text_response.split(parsed_stop_token)[1]
158 | predicted_answer = predicted_answer.split(tokenizer.eos_token)[0]
159 | elif "\\boxed{}" == parsed_stop_token:
160 | boxed_predicted_answer = text_response.split(tokenizer.eos_token)[0]
161 | boxed_predicted_answer = remove_boxed(
162 | last_boxed_only_string(boxed_predicted_answer)
163 | )
164 | if boxed_predicted_answer is not None:
165 | predicted_answer = boxed_predicted_answer
166 | else:
167 | raise ValueError(f"Unknown stop token: {parsed_stop_token}")
168 | predicted_answers.append(predicted_answer)
169 |
170 | # text_answers_gt_levels = tokenizer.batch_decode(
171 | # answer_gt_levels,
172 | # skip_special_tokens=True,
173 | # )
174 |
175 | #answers, gt_answers, levels = [], [], []
176 | # for text_answers_gt_level in text_answers_gt_levels:
177 | # assert len(text_answers_gt_level.split(";;;")) == 3, text_answers_gt_level
178 | # answer, gt_answer, level = text_answers_gt_level.split(";;;")
179 | # answers.append(answer.strip())
180 | # gt_answers.append(gt_answer.strip())
181 | # levels.append(int(level.strip()))
182 |
183 | outcome_metrics, symbolic_rewards = _calculate_outcome_accuracy(
184 | predicted_answers,
185 | gt_answers,
186 | answers,
187 | levels,
188 | outcome_reward,
189 | easy_outcome_reward,
190 | device,
191 | )
192 | return (
193 | predicted_answers,
194 | gt_answers,
195 | levels,
196 | symbolic_rewards,
197 | outcome_metrics,
198 | )
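# NOTE: the default paths below are machine-specific; when this file is run as a script, the required
# --tokenizer_path / --answer_file / --output_file arguments parsed in __main__ override them.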
199 | def main(
200 | tokenizer_path: Path = Path(
201 | "/ssddata/weihao00/model_zoo/llemma_7b/checkpoints/EleutherAI/llemma_7b/tokenizer.model"
202 | ),
203 |
204 | answer_file: Path = Path(
205 | "/ssddata/weihao00/easy2hard/easy-to-hard-main/data/test_ppo.json"
206 | ),
207 | output_file: Path = Path(
208 | "/ssddata/weihao00/easy2hard/save_file/test_ppo_infer_metric.json"
209 | ),
210 |
211 |
212 |
213 | ):
214 |
215 | tokenizer = AcceleraTokenizer(tokenizer_path)
216 | tokenizer.pad_id = tokenizer.unk_id
217 |
218 | with open(answer_file, "r") as r:
219 | test_ppo = json.load(r)
220 |
221 |
222 | # with open(prompt_file , "r") as r:
223 | # data_lines = r.readlines()
224 |
225 | # data_json = [json.loads(l) for l in data_lines]
226 | #test_ppo = data_json
227 | # for idx, item in enumerate(test_ppo):
228 | # item["idx"] = idx
229 |
230 | # for i in data_json:
231 | # if i["idx"] == idx:
232 | # item["prompt"] = i["prompt"]
233 |
234 | # item["output"] = i["output"]
235 |
236 |
237 | test_queries = []
238 | test_responses = []
239 | answers_list = []
240 | gt_answers_list = []
241 | levels_list = []
242 | for item in test_ppo:
243 | test_queries.append(item["input"])
244 | test_responses.append(item["output0"])
245 | answers_list.append(item["answer"])
246 | gt_answers_list.append(item["gt_answer"])
247 | levels_list.append(item["level"])
248 |
249 | eval_rollouts_batch = {}
250 | eval_rollouts_batch["text_queries"] = test_queries
251 | eval_rollouts_batch["text_responses"] = test_responses
252 |
253 | outcome_metrics = post_process_math_rollouts(test_responses, answers_list, gt_answers_list, levels_list, tokenizer, "\n\n# Answer\n\n", False, False, torch.device('cpu'))
254 |
255 | eval_rollouts_batch.update(outcome_metrics[-1])
256 | cpu_eval_rollouts = []
257 |
258 | cpu_eval_rollouts.append(
259 | {
260 | key: value.cpu() if torch.is_tensor(value) else value
261 | for key, value in eval_rollouts_batch.items()
262 | }
263 | )
264 | eval_rollouts = cpu_eval_rollouts
265 |
266 | eval_rollouts = common_utils.merge_dict(eval_rollouts, merge_fn=merge_fn)
267 |
268 | eval_stats = {}
269 | overall_counts = 0.0
270 | overall_accuracy = 0.0
271 | for level in range(9):
272 | if f"level_{level}_counts" in eval_rollouts:
273 | level_counts = eval_rollouts[f"level_{level}_counts"].sum()
274 | level_accuracy = eval_rollouts[f"level_{level}_accuracy"].sum()
275 | overall_counts += level_counts
276 | overall_accuracy += level_accuracy
277 |
278 | eval_stats[f"accuracy_level_{level}"] = level_accuracy / level_counts
279 |
280 | eval_stats[f"counts_level_{level}"] = level_counts
281 |
282 | eval_stats[f"accuracy_overall"] = overall_accuracy.view(1) / (
283 | overall_counts.view(1)
284 | )
285 | eval_stats[f"counts_overall"] = overall_counts
286 | eval_stats = {
287 | key: value.item() if torch.is_tensor(value) else value
288 | for key, value in eval_stats.items()
289 | }
290 | print(eval_stats)
291 | with open(output_file, "w") as w:
292 | json.dump(eval_stats, w)
293 |
294 | if __name__ == "__main__":
295 |
296 | import argparse
297 |
298 | parser = argparse.ArgumentParser(description="Compute accuracy metrics from saved model outputs.")
299 |
300 | parser.add_argument(
301 | "--tokenizer_path",
302 | type=Path,
303 | required=True,
304 | help="Path to the tokenizer model file.",
305 | )
306 | parser.add_argument(
307 | "--answer_file",
308 | type=Path,
309 | required=True,
310 | help="JSON file with model outputs and reference answers.",
311 | )
312 | parser.add_argument(
313 | "--output_file",
314 | type=Path,
315 | required=True,
316 | help="File to write generated samples to.",
317 | )
318 | args = parser.parse_args()
319 | main(
320 | args.tokenizer_path,
321 | args.answer_file,
322 | args.output_file,
323 |
324 | )
325 |
326 |
327 | # tokenizer = AcceleraTokenizer("/ssddata/weihao00/model_zoo/llemma_7b/checkpoints/EleutherAI/llemma_7b/tokenizer.model")
328 | # tokenizer.pad_id = tokenizer.unk_id
329 | # with open("/ssddata/weihao00/easy2hard/easy-to-hard-main/data/test_ppo.json", "r") as r:
330 | # test_ppo = json.load(r)
331 |
332 | # with open("/ssddata/weihao00/easy2hard/save_file/test_ppo_infer.json" , "r") as r:
333 | # data_lines = r.readlines()
334 |
335 |
336 | # data_set = set()
337 |
338 | # data_json = [json.loads(l) for l in data_lines]
339 |
340 | # # for item in data_json:
341 | # # data_set.add(item)
342 |
343 |
344 | # for idx, item in enumerate(test_ppo):
345 | # item["idx"] = idx
346 |
347 | # for i in data_json:
348 | # if i["idx"] == idx:
349 | # item["prompt"] = i["prompt"]
350 |
351 | # item["output"] = i["output"]
352 |
353 |
354 | # test_queries = []
355 | # test_responses = []
356 | # answers_list = []
357 | # gt_answers_list = []
358 | # levels_list = []
359 | # for item in test_ppo:
360 | # test_queries.append(item["input"])
361 | # test_responses.append(item["output"])
362 | # answers_list.append(item["answer"])
363 | # gt_answers_list.append(item["gt_answer"])
364 | # levels_list.append(item["level"])
365 | # eval_rollouts_batch = {}
366 | # eval_rollouts_batch["text_queries"] = test_queries
367 | # eval_rollouts_batch["text_responses"] = test_responses
368 |
369 | # outcome_metrics = post_process_math_rollouts(test_responses, answers_list, gt_answers_list, levels_list, tokenizer, "\n\n# Answer\n\n", False, False, "cuda:6")
370 |
371 | # eval_rollouts_batch.update(outcome_metrics[-1])
372 | # cpu_eval_rollouts = []
373 |
374 | # cpu_eval_rollouts.append(
375 | # {
376 | # key: value.cpu() if torch.is_tensor(value) else value
377 | # for key, value in eval_rollouts_batch.items()
378 | # }
379 | # )
380 | # eval_rollouts = cpu_eval_rollouts
381 |
382 | # eval_rollouts = common_utils.merge_dict(eval_rollouts, merge_fn=merge_fn)
383 | # # filtered_eval_rollouts = {}
384 |
385 | # # for key, value in eval_rollouts.items():
386 | # # filtered_eval_rollouts[key] = value[:eval_data_size]
387 | # # eval_rollouts = filtered_eval_rollouts
388 |
389 | # eval_stats = {}
390 | # overall_counts = 0.0
391 | # overall_accuracy = 0.0
392 | # for level in range(9):
393 | # if f"level_{level}_counts" in eval_rollouts:
394 | # level_counts = eval_rollouts[f"level_{level}_counts"].sum()
395 | # level_accuracy = eval_rollouts[f"level_{level}_accuracy"].sum()
396 | # overall_counts += level_counts
397 | # overall_accuracy += level_accuracy
398 |
399 | # eval_stats[f"accuracy_level_{level}"] = level_accuracy / level_counts
400 |
401 | # eval_stats[f"counts_level_{level}"] = level_counts
402 |
403 | # eval_stats[f"accuracy_overall"] = overall_accuracy.view(1) / (
404 | # overall_counts.view(1)
405 | # )
406 | # eval_stats[f"counts_overall"] = overall_counts
407 | # eval_stats = {
408 | # key: value.item() if torch.is_tensor(value) else value
409 | # for key, value in eval_stats.items()
410 | # }
411 | # print("bupt")
412 |
413 |
414 |
415 |
416 |
417 |
418 |
419 |
420 |
421 |
422 |
423 |
424 |
425 | # print("bupt")
--------------------------------------------------------------------------------
/train_code/convert4reward_auto_ground_sample.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 | import os
4 |
5 | def process_files(base_path, output_path, num_files, sample_num):
6 | data_list = []
7 |
8 | if num_files == -1:
9 | with open(base_path, "r") as r:
10 | data_list = json.load(r)
11 | else:
12 | for i in range(1, num_files + 1):
13 | file_path = base_path.format(i)
14 | if not os.path.exists(file_path):
15 | print(f"File {file_path} does not exist, skipping.")
16 | continue
17 | with open(file_path, "r") as r:
18 | data = json.load(r)
19 | data_list.extend(data)
20 |
21 | trn_json = []
22 | sample_json = []
23 |
24 | for item in data_list:
25 | trn_json.append(item)
26 |
27 | data_json = trn_json
28 | format_data_json = []
29 |
30 | for idx, item in enumerate(data_json):
31 | for i in range(sample_num):
32 | temp_json = {
33 | "idx": idx,
34 | "sample_idx": i,
35 | "prompt": item["prompt"],
36 | "response": item["output"],
37 | "output": item["output" + str(i)]
38 | }
39 |
40 | if "\n\n# Answer\n\n" in temp_json["output"]:
41 | format_data_json.append(temp_json)
42 |
43 | with open(output_path, "w") as w:
44 | for item in format_data_json:
45 | w.write(json.dumps(item))
46 | w.write("\n")
47 |
48 | if __name__ == "__main__":
49 | parser = argparse.ArgumentParser(description="Process JSON files and output formatted data.")
50 | parser.add_argument("--input_path", type=str, required=True, help="Base path for input files, use {} for file number placeholder")
51 | parser.add_argument("--output_path", type=str, required=True, help="Path for output file")
52 | parser.add_argument("--num_files", type=int, required=True, help="Number of input files to process")
53 | parser.add_argument("--sample_num", type=int, required=True, help="Number of samples to process")
54 |
55 | args = parser.parse_args()
56 |
57 | process_files(args.input_path, args.output_path, args.num_files, args.sample_num)
58 |
--------------------------------------------------------------------------------
/train_code/convert_auto.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 |
4 | def convert_checkpoints(checkpoint_dir, pretrain_name, tokenizer_name):
5 | # Define the checkpoint directory path and other parameters
6 | last_checkpoint_file = os.path.join(checkpoint_dir, "last_checkpoint")
7 |
8 | # Create the directory that will hold the converted checkpoints
9 | converted_dir = os.path.join(checkpoint_dir, "converted")
10 | os.makedirs(converted_dir, exist_ok=True)
11 |
12 | max_step = -1
13 | latest_ckpt_file = None
14 |
15 | # Iterate over all checkpoint files in the directory
16 | for file_name in os.listdir(checkpoint_dir):
17 | if file_name.endswith(".pt"):
18 | # Update the contents of the last_checkpoint file
19 | with open(last_checkpoint_file, 'w') as file:
20 | file.write(file_name)
21 | # Extract the step number
22 | step_number = int(file_name.split('_')[3])
23 |
24 | # Keep track of the latest checkpoint file
25 | if step_number > max_step:
26 | max_step = step_number
27 | latest_ckpt_file = file_name
28 |
29 | save_name_hf = os.path.join(converted_dir, f"ckpt{step_number}")
30 |
31 | # Skip if the converted checkpoint already exists
32 | if os.path.exists(save_name_hf):
33 | print(f"{save_name_hf} already exists. Skipping...")
34 | continue
35 |
36 | # Build and run the conversion command
37 | command = f"""
38 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python convert_checkpoint_to_hf.py \\
39 | --tp_ckpt_name {checkpoint_dir} \\
40 | --pretrain_name {pretrain_name} \\
41 | --tokenizer_name {tokenizer_name} \\
42 | --save_name_hf {save_name_hf}
43 | """
44 | print(f"Executing command: {command}")
45 | os.system(command)
46 |
47 | # After all conversions are done, point last_checkpoint back to the latest checkpoint
48 | if latest_ckpt_file:
49 | with open(last_checkpoint_file, 'w') as file:
50 | file.write(latest_ckpt_file)
51 | print(f"Updated {last_checkpoint_file} with {latest_ckpt_file}")
52 |
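# Example invocation (paths and model names here are illustrative):
#   python convert_auto.py --checkpoint_dir /path/to/checkpoints \
#       --pretrain_name mistralai/Mistral-7B-v0.1 --tokenizer_name mistralai/Mistral-7B-v0.1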
53 | if __name__ == "__main__":
54 | parser = argparse.ArgumentParser(description='Convert checkpoints to Hugging Face format')
55 | parser.add_argument('--checkpoint_dir', type=str, required=True, help='Path to the checkpoint directory')
56 | parser.add_argument('--pretrain_name', type=str, required=True, help='Name of the pre-trained model')
57 | parser.add_argument('--tokenizer_name', type=str, required=True, help='Name of the tokenizer')
58 |
59 | args = parser.parse_args()
60 |
61 | convert_checkpoints(args.checkpoint_dir, args.pretrain_name, args.tokenizer_name)
62 |
--------------------------------------------------------------------------------
/train_code/convert_checkpoint_to_hf.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
2 | from tqdm import tqdm
3 | import torch
4 | import re
5 | import argparse
6 | import os
7 | import glob
8 |
9 | # we need to check that we have login the HF account
10 | # !huggingface-cli whoami
11 | # !huggingface-cli login
12 |
13 |
14 | def load_and_merge_models(
15 | tp_ckpt_name, pretrain_name, tokenizer_name, save_name_hf, push_to_hf_hub_name
16 | ):
17 | assert (
18 | save_name_hf or push_to_hf_hub_name
19 | ), "Please provide a save path or push to HF hub name"
20 |
21 | tp_model_list = []
22 |
23 | last_checkpoint_file = os.path.join(tp_ckpt_name, "last_checkpoint")
24 | with open(last_checkpoint_file, "r") as f:
25 | last_checkpoint_file = f.readline().strip()
26 |
27 | last_checkpoint_file = last_checkpoint_file.split("/")[-1]
28 | last_checkpoint_file = os.path.join(tp_ckpt_name, last_checkpoint_file)
29 |
30 | print("Loading checkpoint files:", last_checkpoint_file)
31 | for file in sorted(glob.glob(last_checkpoint_file)):
32 | tp_model_list.append(
33 | torch.load(
34 | file,
35 | mmap=True,
36 | )["model"]
37 | )
38 |
39 | print("Loading HF model...")
40 | tokenizer = AutoTokenizer.from_pretrained(
41 | tokenizer_name,
42 | )
43 |
44 | model = AutoModelForCausalLM.from_pretrained(
45 | pretrain_name,
46 | # device_map="cpu",
47 | load_in_8bit=False,
48 | torch_dtype=torch.bfloat16,
49 | )
50 | cpu_state_dict = model.cpu().state_dict()
51 |
52 | replaced_keys = set()
53 |
54 | print("Convert to HF model...")
55 | num_tp = len(tp_model_list)
56 |
57 | state_dict = {}
58 |
59 | for key in tp_model_list[0].keys():
60 | if "wo" in key or "w2" in key:
61 | state_dict[key] = torch.cat(
62 | [tp_model_list[i][key].cpu() for i in range(num_tp)], dim=1
63 | )
64 | elif "wqkv" in key:
65 | state_dict[key] = torch.stack(
66 | [tp_model_list[i][key].cpu() for i in range(num_tp)], dim=0
67 | )
68 | elif "output" in key:
69 | state_dict[key] = torch.cat(
70 | [tp_model_list[i][key].cpu() for i in range(num_tp)], dim=1
71 | )
72 | else:
73 | state_dict[key] = torch.cat(
74 | [tp_model_list[i][key].cpu() for i in range(num_tp)], dim=0
75 | )
76 |
77 | pattern = r"layers\.(\d+)\."
78 |
79 | for key in state_dict.keys():
80 | layer = None
81 | match = re.search(pattern, key)
82 | # layer number except for:
83 | # lm_head.weight
84 | if match:
85 | layer = match.group(1)
86 | elif "output.weight" in key:
87 | name = f"lm_head.weight"
88 | print(cpu_state_dict[name].size(), state_dict[key].size())
89 | # repeat on dim 0 to match the size
90 | repeat_size = cpu_state_dict[name].size(0) // state_dict[key].size(0)
91 | new_state_dict = state_dict[key].repeat(repeat_size, 1)
92 | cpu_state_dict[name] = 0.0 * cpu_state_dict[name] + new_state_dict
93 | replaced_keys.add(name)
94 | else:
95 | raise ValueError(f"Invalid key: {key}")
96 |
97 | print("Converting layer", key)
98 | if "wqkv" in key:
99 | merged_q, merged_k, merged_v = [], [], []
100 | reconstruct_q, reconstruct_k = [], []
101 |
102 | print("state_dict[key].size(2)", state_dict[key].size(2))
103 |
104 |
105 |
106 | if state_dict[key].size(2) == 4096:
107 | if "Mistral-7B-v0.1" in pretrain_name:
108 | n_heads, n_local_heads = 32, 8
109 | print("Mistral-7B-v0.1")
110 |
111 | elif "Mistral-7B" in pretrain_name:
112 | n_heads, n_local_heads = 32, 8
113 | print("Mistral-7B")
114 |
115 | elif "Meta-Llama-3-8B" in pretrain_name:
116 | n_heads, n_local_heads = 32, 8
117 | print("Meta-Llama-3-8B")
118 |
119 | elif "Llama-3.1-8B" in pretrain_name:
120 | n_heads, n_local_heads = 32, 8
121 | print("Llama-3.1-8B")
122 | elif "Meta-Llama-3-70B" in pretrain_name:
123 | n_heads, n_local_heads = 64, 8
124 | print("Meta-Llama-3-70B")
125 | else:
126 |
127 | n_heads, n_local_heads = 32, 32
128 | elif state_dict[key].size(2) == 5120:
129 | n_heads, n_local_heads = 40, 40
130 | elif state_dict[key].size(2) == 6656:
131 | n_heads, n_local_heads = 52, 52
132 | elif state_dict[key].size(2) == 8192:
133 | n_heads, n_local_heads = 64, 8
134 | else:
135 | raise ValueError(f"Invalid size for {key}: {state_dict[key].size()}")
136 |
137 | head_dim = state_dict[key].size(1) // (n_heads + n_local_heads * 2)
138 |
139 | weight_splits = [
140 | head_dim * n_heads,
141 | head_dim * n_local_heads,
142 | head_dim * n_local_heads,
143 | ]
144 |
145 | for split_idx in range(state_dict[key].size(0)):
146 | chunk = state_dict[key][split_idx]
147 | q, k, v = chunk.split(weight_splits, dim=0)
148 | merged_q.append(q)
149 | merged_k.append(k)
150 | merged_v.append(v)
151 | merged_q = torch.cat(merged_q, dim=0)
152 | merged_k = torch.cat(merged_k, dim=0)
153 | merged_v = torch.cat(merged_v, dim=0)
154 |
155 | #### qk need reconstruction ####
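# Reorder the rows of each 128-row q/k head: this undoes the rotary-embedding permutation
# applied when the HF weights were originally converted to the gpt-fast layout, restoring
# the ordering the Hugging Face checkpoint expects.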
156 | split_qs = torch.split(merged_q, split_size_or_sections=128, dim=0)
157 | split_ks = torch.split(merged_k, split_size_or_sections=128, dim=0)
158 | for split in split_qs:
159 | matrix0 = split[::2, :]
160 | matrix1 = split[1::2, :]
161 | reconstruct_q.append(matrix0)
162 | reconstruct_q.append(matrix1)
163 | reconstruct_q = torch.cat(reconstruct_q, dim=0)
164 | for split in split_ks:
165 | matrix0 = split[::2, :]
166 | matrix1 = split[1::2, :]
167 | reconstruct_k.append(matrix0)
168 | reconstruct_k.append(matrix1)
169 | reconstruct_k = torch.cat(reconstruct_k, dim=0)
170 | #### qk need reconstruction ####
171 |
172 | name = f"model.layers.{layer}.self_attn.q_proj.weight"
173 | cpu_state_dict[name] = reconstruct_q
174 | replaced_keys.add(name)
175 |
176 | name = f"model.layers.{layer}.self_attn.k_proj.weight"
177 | cpu_state_dict[name] = reconstruct_k
178 | replaced_keys.add(name)
179 |
180 | name = f"model.layers.{layer}.self_attn.v_proj.weight"
181 | cpu_state_dict[name] = merged_v
182 | replaced_keys.add(name)
183 |
184 | if "wo" in key:
185 | name = f"model.layers.{layer}.self_attn.o_proj.weight"
186 | cpu_state_dict[name] = state_dict[key]
187 | replaced_keys.add(name)
188 | if "w1" in key:
189 | name = f"model.layers.{layer}.mlp.gate_proj.weight"
190 | cpu_state_dict[name] = state_dict[key]
191 | replaced_keys.add(name)
192 | if "w3" in key:
193 | name = f"model.layers.{layer}.mlp.up_proj.weight"
194 | cpu_state_dict[name] = state_dict[key]
195 | replaced_keys.add(name)
196 | if "w2" in key:
197 | name = f"model.layers.{layer}.mlp.down_proj.weight"
198 | cpu_state_dict[name] = state_dict[key]
199 | replaced_keys.add(name)
200 |
201 | unreplaced_keys = set(cpu_state_dict.keys()) - replaced_keys
202 | print("Unreplaced keys:", unreplaced_keys)
203 |
204 | print("Loading state dict...")
205 |
206 | model.load_state_dict(cpu_state_dict, strict=False)
207 |
208 | print("Saving HF model...")
209 |
210 | if save_name_hf is not None:
211 | model.save_pretrained(save_name_hf)
212 | config = AutoConfig.from_pretrained(pretrain_name)
213 | tokenizer.save_pretrained(save_name_hf)
214 | config.save_pretrained(save_name_hf)
215 | else:
216 | model.push_to_hub(push_to_hf_hub_name, private=True, safe_serialization=False)
217 |
218 |
219 | if __name__ == "__main__":
220 | parser = argparse.ArgumentParser(description="Convert a tensor-parallel gpt-fast checkpoint to Hugging Face format.")
221 | parser.add_argument(
222 | "--tp_ckpt_name", type=str, help="Path to the TP checkpoint name", required=True
223 | )
224 | parser.add_argument(
225 | "--tokenizer_name", type=str, help="Path to the tokenizer name", required=True
226 | )
227 | parser.add_argument(
228 | "--pretrain_name", type=str, help="Path to the pretrain name", required=True
229 | )
230 | parser.add_argument(
231 | "--save_name_hf", type=str, default=None, help="Path to save the HF model"
232 | )
233 | parser.add_argument(
234 | "--push_to_hf_hub_name", type=str, default=None, help="Push to HF hub"
235 | )
236 |
237 | args = parser.parse_args()
238 | load_and_merge_models(
239 | args.tp_ckpt_name,
240 | args.pretrain_name,
241 | args.tokenizer_name,
242 | args.save_name_hf,
243 | args.push_to_hf_hub_name,
244 | )
245 |
--------------------------------------------------------------------------------
/train_code/data_utils/__pycache__/data_utils_ppo.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/data_utils/__pycache__/data_utils_ppo.cpython-310.pyc
--------------------------------------------------------------------------------
/train_code/data_utils/__pycache__/data_utils_ppo.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/data_utils/__pycache__/data_utils_ppo.cpython-311.pyc
--------------------------------------------------------------------------------
/train_code/data_utils/__pycache__/data_utils_rm_pointwise.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/data_utils/__pycache__/data_utils_rm_pointwise.cpython-310.pyc
--------------------------------------------------------------------------------
/train_code/data_utils/data_utils_dpo.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The GPT-Accelera Team
2 | # Copyright 2023 The Alpaca Team
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from typing import Dict, Sequence, Union
17 |
18 | import numpy as np
19 | import torch
20 | from torch.utils.data import Dataset
21 | from datasets import Dataset as HFDataset
22 |
23 | from arguments import Arguments
24 | import trainers.common_utils as utils
25 | from models.tokenizer_utils import AcceleraTokenizer
26 | from data_utils.data_utils_sft import preprocess_for_sft, extract_alpaca_dataset
27 |
28 |
29 | class DPODataset(Dataset):
30 | def __init__(
31 | self,
32 | args: Arguments,
33 | dataset: HFDataset,
34 | tokenizer: AcceleraTokenizer,
35 | ):
36 | super(DPODataset, self).__init__()
37 | self.tensors = preprocess_for_dpo(
38 | args=args,
39 | dataset=dataset,
40 | tokenizer=tokenizer,
41 | )
42 |
43 | def __len__(self):
44 | return len(next(iter(self.tensors.values())))
45 |
46 | def __getitem__(self, i) -> Dict[str, torch.Tensor]:
47 | return {key: value[i] for key, value in self.tensors.items()}
48 |
49 |
50 | def preprocess_for_dpo(
51 | args: Arguments,
52 | dataset: HFDataset,
53 | tokenizer: AcceleraTokenizer,
54 | reorder_wl: bool = True,
55 | ) -> dict[str, Union[torch.Tensor, Sequence[torch.Tensor]]]:
56 | df = dataset.to_pandas()
57 | output_1, output_2, preference = df["output_1"], df["output_2"], df["preference"]
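# preference == 1 means output_1 is the preferred (winning) response; preference == 2 means output_2 is preferred.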
58 |
59 | assign_w_kwargs = dict(
60 | output=np.where(preference == 1, output_1, output_2),
61 | )
62 | assign_l_kwargs = dict(
63 | output=np.where(preference == 2, output_1, output_2),
64 | )
65 | assign_keys = ["instruction", "input", "output"]
66 |
67 | if "is_eos_1" in df.columns:
68 | is_eos_1, is_eos_2 = df["is_eos_1"], df["is_eos_2"]
69 | assign_w_kwargs.update(
70 | is_eos=np.where(preference == 1, is_eos_1, is_eos_2),
71 | )
72 | assign_l_kwargs.update(
73 | is_eos=np.where(preference == 2, is_eos_1, is_eos_2),
74 | )
75 | assign_keys.extend(["is_eos"])
76 |
77 | if "win_rate_1" in df.columns:
78 | win_rate_1, win_rate_2 = df["win_rate_1"], df["win_rate_2"]
79 | assign_w_kwargs.update(
80 | win_rate=np.where(preference == 1, win_rate_1, win_rate_2),
81 | )
82 | assign_l_kwargs.update(
83 | win_rate=np.where(preference == 2, win_rate_1, win_rate_2),
84 | )
85 | assign_keys.extend(["win_rate"])
86 |
87 | if reorder_wl:
88 | df_w = df.assign(**assign_w_kwargs)[assign_keys]
89 | df_l = df.assign(**assign_l_kwargs)[assign_keys]
90 | else:
91 | df_w = df.assign(output=output_1)[assign_w_kwargs]
92 | df_l = df.assign(output=output_2)[assign_l_kwargs]
93 |
94 | df_w_list = df_w.to_dict("records")
95 | df_l_list = df_l.to_dict("records")
96 |
97 | assert len(df_w_list) == len(df_l_list)
98 |
99 | if args.dataset_format == "alpaca":
100 | for i in range(len(df_w_list)):
101 | df_w_list[i].update(extract_alpaca_dataset(df_w_list[i]))
102 | df_l_list[i].update(extract_alpaca_dataset(df_l_list[i]))
103 | elif args.dataset_format is None:
104 | pass
105 | else:
106 | raise ValueError(f"Unknown dataset format: {args.dataset_format}")
107 |
108 | tensors_w = preprocess_for_sft(
109 | instances=df_w_list,
110 | tokenizer=tokenizer,
111 | source_max_len=args.source_max_len,
112 | target_max_len=args.target_max_len,
113 | total_max_len=args.total_max_len,
114 | train_on_source=args.train_on_source,
115 | add_eos_to_target=args.add_eos_to_target,
116 | add_eos_to_marked_target=args.add_eos_to_marked_target,
117 | return_win_rate=True,
118 | )
119 | tensors_l = preprocess_for_sft(
120 | instances=df_l_list,
121 | tokenizer=tokenizer,
122 | source_max_len=args.source_max_len,
123 | target_max_len=args.target_max_len,
124 | total_max_len=args.total_max_len,
125 | train_on_source=args.train_on_source,
126 | add_eos_to_target=args.add_eos_to_target,
127 | add_eos_to_marked_target=args.add_eos_to_marked_target,
128 | return_win_rate=True,
129 | )
130 | return dict(
131 | input_ids_w=tensors_w["input_ids"],
132 | labels_w=tensors_w["labels"],
133 | win_rate_w=tensors_w["win_rate"],
134 | input_ids_l=tensors_l["input_ids"],
135 | labels_l=tensors_l["labels"],
136 | win_rate_l=tensors_l["win_rate"],
137 | )
138 |
139 |
140 | def make_dpo_data_module(
141 | tokenizer: AcceleraTokenizer,
142 | args: Arguments,
143 | ) -> dict:
144 | preference_dataset = utils.local_dataset(args.dataset)
145 | train_preference = preference_dataset["train"]
146 |
147 | train_dataset = DPODataset(
148 | args=args,
149 | dataset=train_preference,
150 | tokenizer=tokenizer,
151 | )
152 |
153 | eval_dataset = None
154 | if args.eval_size > 0:
155 | train_dataset, eval_dataset = utils.split_train_into_train_and_eval(
156 | train_dataset=train_dataset,
157 | eval_size=args.eval_size,
158 | seed=args.seed,
159 | )
160 | data_collator = utils.DataCollatorForStackableDataset()
161 | return dict(
162 | train_dataset=train_dataset,
163 | eval_dataset=eval_dataset,
164 | data_collator=data_collator,
165 | )
166 |
--------------------------------------------------------------------------------
/train_code/data_utils/data_utils_ppo.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The GPT-Accelera Team
2 | # Copyright 2023 The Alpaca Team
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import os
17 | from typing import Dict
18 | import logging
19 |
20 | import torch
21 | from torch.utils.data import Dataset
22 | from datasets import Dataset as HFDataset
23 |
24 | from arguments import Arguments
25 | import trainers.common_utils as utils
26 | from models.tokenizer_utils import AcceleraTokenizer
27 |
28 | logger = logging.getLogger(__name__)
29 |
30 |
31 | class QueryDataset(Dataset):
32 | """Dataset that emits tokenized left-padded queries."""
33 |
34 | def __init__(
35 | self,
36 | dataset: HFDataset,
37 | tokenizer: AcceleraTokenizer,
38 | query_len: int,
39 | ):
40 | super(QueryDataset, self).__init__()
41 |
42 | list_dict_data = dataset.to_pandas().to_dict("records")
43 |
44 | # prompts are strings; queries are tensors.
45 | queries = [dict_data["input"] for dict_data in list_dict_data]
46 | answers = [
47 | f"{dict_data['answer']} ;;; {dict_data['gt_answer']} ;;; {dict_data['level']}"
48 | for dict_data in list_dict_data
49 | ]
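# Each answer string packs (answer, gt_answer, level) joined by " ;;; " so the three fields
# can be split apart again after the tensor is decoded.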
50 |
51 | logger.warning(f"Debugging: {answers[:10]}")
52 | queries = [
53 | tokenizer(query, return_tensors="pt", truncation=False).input_ids.squeeze(
54 | dim=0
55 | )
56 | for query in queries
57 | ]
58 |
59 | answers = [
60 | tokenizer(answer, return_tensors="pt", truncation=False).input_ids.squeeze(
61 | dim=0
62 | )
63 | for answer in answers
64 | ]
65 |
66 | filtered_queries = []
67 | filtered_answers = []
68 |
69 | for query, answer in zip(queries, answers):
70 | if len(query) <= query_len:
71 | filtered_queries.append(query)
72 | filtered_answers.append(answer)
73 |
74 | logger.warning(
75 | f"Filtered out {len(queries) - len(filtered_queries)} instances out of {len(queries)} that "
76 | f"exceed length limit. These examples are not used for training, but will still be used in evaluation. "
77 | )
78 |
79 | queries = torch.stack(
80 | [
81 | utils.left_pad(query, target_size=(query_len,), value=tokenizer.pad_id)
82 | for query in filtered_queries
83 | ]
84 | )
85 |
86 | max_answer_len = max([len(answer) for answer in filtered_answers])
87 | answers = torch.stack(
88 | [
89 | utils.left_pad(
90 | answer,
91 | target_size=(max_answer_len,),
92 | value=tokenizer.pad_id,
93 | )
94 | for answer in filtered_answers
95 | ]
96 | )
97 |
98 | assert queries.shape[0] == answers.shape[0]
99 |
100 | self.queries = queries
101 | self.query_attn_masks = queries.ne(tokenizer.pad_id).long()
102 | self.answers = answers
103 | # Auxiliary data.
104 | self.list_dict_data = list_dict_data
105 |
106 | def __getitem__(self, i):
107 | return dict(
108 | queries=self.queries[i],
109 | query_attn_masks=self.query_attn_masks[i],
110 | answers=self.answers[i],
111 | )
112 |
113 | def __len__(self):
114 | return len(self.queries)
115 |
116 |
117 | def make_rl_data_module(
118 | tokenizer: AcceleraTokenizer,
119 | args: Arguments,
120 | ) -> Dict:
121 | """
122 | Make dataset and collator for RL (PPO-style) training.
123 | Datasets are expected to have the following columns: { `input`, `answer`, `gt_answer`, `level` }
124 | """
125 |
126 | def load_data(dataset_name):
127 | if os.path.exists(dataset_name):
128 | try:
129 | full_dataset = utils.local_dataset(dataset_name)
130 | return full_dataset
131 | except:
132 | raise ValueError(f"Error loading dataset from {dataset_name}")
133 | else:
134 | raise NotImplementedError(f"Dataset {dataset_name} not implemented yet.")
135 |
136 | def format_dataset(dataset):
137 | # Remove unused columns.
138 | dataset = dataset.remove_columns(
139 | [
140 | col
141 | for col in dataset.column_names["train"]
142 | if col not in ["input", "answer", "gt_answer", "level"]
143 | ]
144 | )
145 | return dataset
146 |
147 | # Load dataset.
148 | dataset = load_data(args.dataset)
149 | dataset = format_dataset(dataset)
150 |
151 | # Split train/eval, reduce size
152 | eval_dataset = None
153 | if args.do_eval:
154 | if args.eval_dataset is not None:
155 | eval_dataset = load_data(args.eval_dataset)
156 | eval_dataset = format_dataset(eval_dataset)
157 | eval_dataset = eval_dataset["train"]
158 | else:
159 | print(
160 | "Splitting train dataset in train and validation according to `eval_dataset_size`"
161 | )
162 | dataset = dataset["train"].train_test_split(
163 | test_size=args.eval_dataset_size, shuffle=True, seed=42
164 | )
165 | eval_dataset = dataset["test"]
166 | if (
167 | args.max_eval_samples is not None
168 | and len(eval_dataset) > args.max_eval_samples
169 | ):
170 | eval_dataset = eval_dataset.select(range(args.max_eval_samples))
171 |
172 | test_dataset = None
173 | if args.do_test:
174 | if args.test_dataset is not None:
175 | test_dataset = load_data(args.test_dataset)
176 | test_dataset = format_dataset(test_dataset)
177 | test_dataset = test_dataset["train"]
178 | else:
179 | raise NotImplementedError("Must specify test dataset if `do_test` is True.")
180 |
181 | train_dataset = dataset["train"]
182 | if (
183 | args.max_train_samples is not None
184 | and len(train_dataset) > args.max_train_samples
185 | ):
186 | train_dataset = train_dataset.select(range(args.max_train_samples))
187 |
188 | train_dataset = QueryDataset(
189 | dataset=train_dataset,
190 | tokenizer=tokenizer,
191 | query_len=args.source_max_len,
192 | )
193 |
194 | if eval_dataset is not None:
195 | eval_dataset = QueryDataset(
196 | dataset=eval_dataset,
197 | tokenizer=tokenizer,
198 | query_len=args.source_max_len,
199 | )
200 |
201 | if test_dataset is not None:
202 | test_dataset = QueryDataset(
203 | dataset=test_dataset,
204 | tokenizer=tokenizer,
205 | query_len=args.source_max_len,
206 | )
207 |
208 | data_collator = utils.DataCollatorForStackableDataset()
209 | return dict(
210 | train_dataset=train_dataset,
211 | eval_dataset=eval_dataset,
212 | test_dataset=test_dataset,
213 | data_collator=data_collator,
214 | )
215 |
--------------------------------------------------------------------------------
/train_code/data_utils/data_utils_rm_pairwise.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The GPT-Accelera Team
2 | # Copyright 2023 The Alpaca Team
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import logging
17 | from typing import Optional, Dict, Sequence
18 |
19 | import torch
20 | from torch.utils.data import Dataset
21 |
22 | from datasets import Dataset as HFDataset
23 |
24 | from arguments import Arguments
25 | import trainers.common_utils as utils
26 | from models.tokenizer_utils import AcceleraTokenizer
27 |
28 | logger = logging.getLogger(__name__)
29 |
30 |
31 | DROMEDARY_PROMPT_DICT = {
32 | "prompt_input": (
33 | "{meta_prompt}\n" "{instruction}\n\n" "{input}\n\n" "### Dromedary"
34 | ),
35 | "prompt_no_input": ("{meta_prompt}\n" "{instruction}\n\n" "### Dromedary"),
36 | }
37 |
38 |
39 | ALPACA_PROMPT_DICT = {
40 | "prompt_input": (
41 | "Below is an instruction that describes a task, paired with an input that provides further context. "
42 | "Write a response that appropriately completes the request.\n\n"
43 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
44 | ),
45 | "prompt_no_input": (
46 | "Below is an instruction that describes a task. "
47 | "Write a response that appropriately completes the request.\n\n"
48 | "### Instruction:\n{instruction}\n\n### Response:\n"
49 | ),
50 | }
51 |
52 |
53 | def format_prompt(
54 | example: Dict[str, str],
55 | prompt_dict: Dict[str, str],
56 | ) -> str:
57 | if prompt_dict is not None:
58 | assert (
59 | "instruction" in example
60 | ), "Internal error: example missing required keys."
61 |
62 | if example.get("input", "") != "":
63 | prompt_format = prompt_dict["prompt_input"]
64 | else:
65 | prompt_format = prompt_dict["prompt_no_input"]
66 | else:
67 | prompt_format = "{input}"
68 |
69 | format_prompt = prompt_format.format(**example)
70 | return format_prompt
71 |
72 |
73 | def format_output(
74 | example: dict,
75 | output_key="output",
76 | ) -> str:
77 | return example[output_key]
78 |
79 |
80 | def _tokenize_fn(
81 | strings: Sequence[str],
82 | tokenizer: AcceleraTokenizer,
83 | max_length: int,
84 | end_sequence_with_eos: bool,
85 | use_data_frame: bool = False,
86 | ) -> torch.Tensor:
87 | """Tokenize a list of strings."""
88 | if use_data_frame:
89 | raise NotImplementedError
90 | strings_ds = strings
91 |
92 | tokenized_strings = tokenizer(
93 | strings_ds,
94 | max_length=max_length,
95 | padding="max_length",
96 | truncation=True,
97 | add_bos=True,
98 | add_eos=True if end_sequence_with_eos else False,
99 | padding_side="right",
100 | truncation_side="right",
101 | )
102 |
103 | input_ids = torch.stack(
104 | [torch.tensor(tokenized) for tokenized in tokenized_strings["input_ids"]],
105 | dim=0,
106 | )
107 |
108 | return input_ids
109 |
110 |
111 | def preprocess_for_reward_modeling(
112 | data: HFDataset,
113 | tokenizer: AcceleraTokenizer,
114 | end_sequence_with_eos: bool = False,
115 | max_length: Optional[int] = None,
116 | query_len: Optional[int] = None,
117 | response_len: Optional[int] = None,
118 | prompt_dict: Optional[Dict[str, str]] = None,
119 | ) -> Dict[str, torch.Tensor]:
120 | list_dict_data = data.to_pandas().to_dict("records")
121 |
122 | def _get_numeric_preference(example: dict):
123 | # 1 vs 2 is stored in table, but for modeling we use 0 vs 1; remap here.
124 | return {1: 0, 2: 1}[example["preference"]]
125 |
126 | choice = torch.tensor(
127 | [[_get_numeric_preference(dict_data)] for dict_data in list_dict_data]
128 | )
129 |
130 | def _get_text(example: dict, output_key: str):
131 | full_prompt = format_prompt(example, prompt_dict) + format_output(
132 | example, output_key
133 | )
134 | return full_prompt
135 |
136 | text_list_0, text_list_1 = tuple(
137 | [_get_text(dict_data, key) for dict_data in list_dict_data]
138 | for key in ("output_1", "output_2")
139 | )
140 |
141 | if max_length is None:
142 | max_length = query_len + response_len
143 |
144 | logger.warning(f"Tokenizing {len(list_dict_data)} pairs...")
145 | tokenized_0, tokenized_1 = tuple(
146 | _tokenize_fn(text_list, tokenizer, max_length, end_sequence_with_eos)
147 | for text_list in (text_list_0, text_list_1)
148 | )
149 | # "size" (bsz, 2, seq_len)
150 | input_ids = torch.stack(
151 | [tokenized_0, tokenized_1],
152 | dim=1,
153 | )
154 |
155 | packaged_data = dict(
156 | input_ids=input_ids,
157 | choice=choice,
158 | metadata=dict(mean_choice=choice.float().mean().item()),
159 | )
160 |
161 | return packaged_data
162 |
163 |
164 | class PairwiseRewardModelingDataset(Dataset):
165 | def __init__(
166 | self,
167 | data: HFDataset,
168 | tokenizer: AcceleraTokenizer,
169 | end_sequence_with_eos: bool = False,
170 | max_length: Optional[int] = None,
171 | query_len: Optional[int] = None,
172 | response_len: Optional[int] = None,
173 | prompt_dict: Optional[Dict[str, str]] = None,
174 | ):
175 | super(PairwiseRewardModelingDataset, self).__init__()
176 | data_dict = preprocess_for_reward_modeling(
177 | data=data,
178 | tokenizer=tokenizer,
179 | end_sequence_with_eos=end_sequence_with_eos,
180 | max_length=max_length,
181 | query_len=query_len,
182 | response_len=response_len,
183 | prompt_dict=prompt_dict,
184 | )
185 | self.input_ids = data_dict["input_ids"]
186 | self.choice = data_dict["choice"]
187 | self.metadata = data_dict["metadata"]
188 |
189 | def __len__(self):
190 | return len(self.input_ids)
191 |
192 | def __getitem__(self, i) -> Dict[str, torch.Tensor]:
193 | return dict(
194 | input_ids=self.input_ids[i],
195 | choice=self.choice[i],
196 | )
197 |
198 |
199 | def make_pairwise_reward_modeling_data_module(
200 | tokenizer: AcceleraTokenizer,
201 | args: Arguments,
202 | ):
203 | preference_dataset = utils.local_dataset(args.dataset)
204 | train_preference = preference_dataset["train"]
205 |
206 | if args.dataset_format == "alpaca":
207 | prompt_dict = ALPACA_PROMPT_DICT
208 | elif args.dataset_format is None:
209 | prompt_dict = None
210 | else:
211 | raise ValueError(
212 | f"Unsupported dataset_format: {args.dataset_format}."
213 | "Only alpaca and None are supported."
214 | )
215 |
216 | train_dataset = PairwiseRewardModelingDataset(
217 | data=train_preference,
218 | tokenizer=tokenizer,
219 | end_sequence_with_eos=args.add_eos_to_target,
220 | max_length=args.total_max_len,
221 | query_len=args.source_max_len,
222 | response_len=args.target_max_len,
223 | prompt_dict=prompt_dict,
224 | )
225 |
226 | eval_dataset = None
227 | if args.eval_size > 0:
228 | train_dataset, eval_dataset = utils.split_train_into_train_and_eval(
229 | train_dataset=train_dataset,
230 | eval_size=args.eval_size,
231 | seed=args.seed,
232 | )
233 |
234 | data_collator = utils.DataCollatorForStackableDataset()
235 | return dict(
236 | train_dataset=train_dataset,
237 | eval_dataset=eval_dataset,
238 | data_collator=data_collator,
239 | )
240 |
--------------------------------------------------------------------------------
/train_code/data_utils/data_utils_sft.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The GPT-Accelera Team
2 | # Copyright 2023 The Alpaca Team
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import os
17 | from dataclasses import dataclass
18 | import logging
19 | from typing import Dict, Sequence, Union
20 |
21 | import torch
22 |
23 | from datasets import load_dataset
24 |
25 | from arguments import Arguments
26 | import trainers.common_utils as utils
27 | from models.tokenizer_utils import AcceleraTokenizer
28 |
29 | logger = logging.getLogger(__name__)
30 |
31 | DROMEDARY_PROMPT_DICT = {
32 | "prompt_input": (
33 | "{meta_prompt}\n" "{instruction}\n\n" "{input}\n\n" "### Dromedary"
34 | ),
35 | "prompt_no_input": ("{meta_prompt}\n" "{instruction}\n\n" "### Dromedary"),
36 | }
37 |
38 | ALPACA_PROMPT_DICT = {
39 | "prompt_input": (
40 | "Below is an instruction that describes a task, paired with an input that provides further context. "
41 | "Write a response that appropriately completes the request.\n\n"
42 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
43 | ),
44 | "prompt_no_input": (
45 | "Below is an instruction that describes a task. "
46 | "Write a response that appropriately completes the request.\n\n"
47 | "### Instruction:\n{instruction}\n\n### Response:\n"
48 | ),
49 | }
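# For illustration: "prompt_no_input" rendered with a hypothetical example
# {"instruction": "Add 2 and 3."} produces
#   "Below is an instruction that describes a task. Write a response that
#    appropriately completes the request.\n\n### Instruction:\nAdd 2 and 3.\n\n### Response:\n"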
50 |
51 |
52 | def preprocess_for_sft(
53 | instances: Sequence[Dict],
54 | tokenizer: AcceleraTokenizer,
55 | source_max_len: int,
56 | target_max_len: int,
57 | total_max_len: int,
58 | train_on_source: bool,
59 | add_eos_to_target: bool,
60 | add_eos_to_marked_target: bool,
61 | return_win_rate: bool = False,
62 | ) -> Dict[str, Union[torch.Tensor, Sequence[torch.Tensor]]]:
63 | # Extract elements
64 | sources = [example["input"] for example in instances]
65 | targets = [f"\n{example['output']}" for example in instances]
66 |
67 | begin_padding_len = tokenizer(
68 | ["\n"], return_tensors="pt", add_bos=False, add_eos=False
69 | ).input_ids.shape[1]
70 |
71 | # Tokenize
72 | tokenized_sources_with_prompt = tokenizer(
73 | sources,
74 | max_length=source_max_len,
75 | padding="max_length",
76 | truncation=True,
77 | add_bos=True,
78 | add_eos=False,
79 | padding_side="left",
80 | truncation_side="left",
81 | )
82 |
83 | marked_eos = None
84 | if "is_eos" in instances[0] and add_eos_to_marked_target:
85 | marked_eos = [example["is_eos"] for example in instances]
86 |
87 | win_rate = None
88 | if return_win_rate:
89 | if "win_rate" in instances[0]:
90 | win_rate = [example["win_rate"] for example in instances]
91 | else:
92 | win_rate = [0.5 for _ in instances]
93 |
94 | # logger.warning(f"Tokenizing {len(targets)} pairs...")
95 | tokenized_targets = tokenizer(
96 | targets,
97 | max_length=target_max_len + begin_padding_len,
98 | padding="max_length",
99 | truncation=True,
100 | add_bos=False,
101 | add_eos=add_eos_to_target,
102 | marked_eos=marked_eos,
103 | padding_side="right",
104 | truncation_side="right",
105 | )
106 | # Build the input and labels for causal LM
107 | input_ids = []
108 | labels = []
109 | for source_length, tokenized_source, tokenized_target in zip(
110 | tokenized_sources_with_prompt["length"],
111 | tokenized_sources_with_prompt["input_ids"],
112 | tokenized_targets["input_ids"],
113 | ):
114 | tokenized_target = tokenized_target[begin_padding_len:]
115 | full_seq = tokenized_source + tokenized_target
116 |
117 | # move the beginning padding to the end of the full_seq
118 | num_begin_padding = len(tokenized_source) - source_length
119 | full_seq = full_seq[num_begin_padding:] + full_seq[:num_begin_padding]
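        # e.g. with two left-padding tokens: [PAD, PAD, BOS, src..., tgt...]
        # becomes [BOS, src..., tgt..., PAD, PAD], moving all padding to the right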
120 |
121 | if total_max_len is not None:
122 | full_seq = full_seq[:total_max_len]
123 |
124 | # input_ids.append(torch.tensor(full_seq))
125 | input_ids.append(full_seq)
126 | if not train_on_source:
127 | full_seq_label = (
128 | [tokenizer.pad_id for _ in range(source_length)]
129 | + tokenized_target
130 | + [tokenizer.pad_id for _ in range(num_begin_padding)]
131 | )
132 | if total_max_len is not None:
133 | full_seq_label = full_seq_label[:total_max_len]
134 | # labels.append(torch.tensor(full_seq_label))
135 | labels.append(full_seq_label)
136 | else:
137 | # labels.append(torch.tensor(copy.deepcopy(full_seq)))
138 | labels.append(full_seq)
139 | # Apply padding
140 | # input_ids = pad_sequence(
141 | # input_ids, batch_first=True, padding_value=tokenizer.pad_id
142 | # )
143 | # labels = pad_sequence(labels, batch_first=True, padding_value=tokenizer.pad_id)
144 | input_ids = torch.tensor(input_ids)
145 | labels = torch.tensor(labels)
146 | data_dict = {
147 | "input_ids": input_ids,
148 | "attention_mask": input_ids.ne(tokenizer.pad_id),
149 | }
150 | if labels is not None:
151 | data_dict["labels"] = labels
152 | if return_win_rate:
153 | data_dict["win_rate"] = torch.tensor(win_rate).view(-1, 1)
154 | return data_dict
155 |
156 |
157 | @dataclass
158 | class DataCollatorForCausalLM(object):
159 | tokenizer: AcceleraTokenizer
160 | source_max_len: int
161 | target_max_len: int
162 | total_max_len: int
163 | train_on_source: bool
164 | add_eos_to_target: bool
165 | add_eos_to_marked_target: bool
166 |
167 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
168 | return preprocess_for_sft(
169 | instances=instances,
170 | tokenizer=self.tokenizer,
171 | source_max_len=self.source_max_len,
172 | target_max_len=self.target_max_len,
173 | total_max_len=self.total_max_len,
174 | train_on_source=self.train_on_source,
175 | add_eos_to_target=self.add_eos_to_target,
176 | add_eos_to_marked_target=self.add_eos_to_marked_target,
177 | )
178 |
179 |
180 | def extract_alpaca_dataset(example):
181 | if example.get("input", "") != "":
182 | prompt_format = ALPACA_PROMPT_DICT["prompt_input"]
183 | else:
184 | prompt_format = ALPACA_PROMPT_DICT["prompt_no_input"]
185 | return {"input": prompt_format.format(**example)}
186 |
187 |
188 | def extract_dromedary_dataset(example, meta_prompts):
189 | assert "example_id" in example
190 | total_meta_prompt = len(meta_prompts)
191 | meta_prompt = meta_prompts[int(example["example_id"]) % total_meta_prompt]
192 |
193 | if example.get("input", "") != "":
194 | prompt_format = DROMEDARY_PROMPT_DICT["prompt_input"]
195 | else:
196 | prompt_format = DROMEDARY_PROMPT_DICT["prompt_no_input"]
197 |
198 | return {
199 | "input": prompt_format.format(meta_prompt=meta_prompt, **example),
200 | "output": "\n" + example["output"],
201 | }
202 |
203 |
204 | def extract_prm_dataset(example):
205 | if example["output_prefix"] == "":
206 | ret = {
207 | "input": "Question: " + example["input"],
208 | "output": "\n\nAnswer: " + example["output"],
209 | }
210 | else:
211 | ret = {
212 | "input": "Question: "
213 | + example["input"]
214 | + "\n\nAnswer: "
215 | + example["output_prefix"],
216 | "output": example["output"],
217 | }
218 |
219 | if "is_eos" in example:
220 | ret["is_eos"] = example["is_eos"]
221 |
222 | return ret
223 |
224 |
225 | def extract_prm_v2_dataset(example):
226 | if example["output_prefix"] == "":
227 | ret = {
228 | "input": "# Question\n\n" + example["input"] + "\n\n# Solution",
229 | "output": "\n\n" + example["output"],
230 | }
231 | else:
232 | ret = {
233 | "input": "# Question\n\n"
234 | + example["input"]
235 | + "\n\n# Solution\n\n"
236 | + example["output_prefix"],
237 | "output": example["output"],
238 | }
239 |
240 | if "is_eos" in example:
241 | ret["is_eos"] = example["is_eos"]
242 |
243 | return ret
244 |
245 |
246 | def extract_metamath_dataset(example):
247 | ret = {
248 | "input": "# Question\n\n" + example["query"] + "\n\n# Solution",
249 | "output": "\n\n" + example["output"],
250 | "is_eos": True,
251 | }
252 |
253 | return ret
254 |
255 | def extract_arc_dataset(example):
256 | ret = {
257 | "input": example["input"],
258 | "output": example["output"],
259 | "is_eos": True,
260 | }
261 | return ret
262 |
263 | def extract_apps_dataset(example):
264 | ret = {
265 | "input": example["input"],
266 | "output": example["output"],
267 | "is_eos": True,
268 | }
269 |
270 | return ret
271 |
272 | def make_sft_data_module(
273 | tokenizer: AcceleraTokenizer,
274 | args: Arguments,
275 | ) -> Dict:
276 | """
277 | Make dataset and collator for supervised fine-tuning.
278 | Datasets are expected to have the following columns: { `input`, `output` }
279 | """
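    # A hypothetical record in the expected format (values illustrative only):
    #   {"input": "# Question\n\nWhat is 2 + 3?", "output": "\n\n2 + 3 = 5\n\n# Answer\n\n5", "is_eos": true}
    # where the optional "is_eos" flag marks targets that should end with an EOS token.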
280 |
281 | def load_data(dataset_name):
282 | if dataset_name == "alpaca":
283 | return load_dataset("tatsu-lab/alpaca")
284 | elif dataset_name == "alpaca-clean":
285 | return load_dataset("yahma/alpaca-cleaned")
286 | elif dataset_name == "chip2":
287 | return load_dataset("laion/OIG", data_files="unified_chip2.jsonl")
288 | elif dataset_name == "self-instruct":
289 | return load_dataset("yizhongw/self_instruct", name="self_instruct")
290 | elif dataset_name == "hh-rlhf":
291 | return load_dataset("Anthropic/hh-rlhf")
292 | elif dataset_name == "longform":
293 | return load_dataset("akoksal/LongForm")
294 | elif dataset_name == "oasst1":
295 | return load_dataset("timdettmers/openassistant-guanaco")
296 | elif dataset_name == "vicuna":
297 | raise NotImplementedError("Vicuna data was not released.")
298 | else:
299 | if os.path.exists(dataset_name):
300 | try:
301 | args.dataset_format = (
302 | args.dataset_format if args.dataset_format else "alpaca"
303 | )
304 | full_dataset = utils.local_dataset(dataset_name)
305 | return full_dataset
306 | except:
307 | raise ValueError(f"Error loading dataset from {dataset_name}")
308 | else:
309 | raise NotImplementedError(
310 | f"Dataset {dataset_name} not implemented yet."
311 | )
312 |
313 | def format_dataset(dataset, dataset_format):
314 | if (
315 | dataset_format == "alpaca"
316 | or dataset_format == "alpaca-clean"
317 | or (dataset_format is None and args.dataset in ["alpaca", "alpaca-clean"])
318 | ):
319 | dataset = dataset.map(
320 | extract_alpaca_dataset, remove_columns=["instruction"]
321 | )
322 | elif dataset_format == "hh-rlhf" or (
323 | dataset_format is None and args.dataset == "hh-rlhf"
324 | ):
325 | dataset = dataset.map(lambda x: {"input": "", "output": x["chosen"]})
326 | elif dataset_format == "prm":
327 | dataset = dataset.map(extract_prm_dataset)
328 | elif dataset_format == "prm-v2":
329 | dataset = dataset.map(extract_prm_v2_dataset)
330 | elif dataset_format == "arc":
331 | dataset = dataset.map(extract_arc_dataset)
332 | elif dataset_format == "apps":
333 | dataset = dataset.map(extract_apps_dataset)
334 | elif dataset_format == "metamath":
335 | dataset = dataset.map(extract_metamath_dataset)
336 | elif dataset_format == "mapped":
337 | dataset = dataset
338 | else:
339 | raise ValueError(f"Unsupported dataset format: {dataset_format}")
340 |
341 | # Remove unused columns.
342 | dataset = dataset.remove_columns(
343 | [
344 | col
345 | for col in dataset.column_names["train"]
346 | if col not in ["input", "output", "is_eos"]
347 | ]
348 | )
349 | return dataset
350 |
351 | # Load dataset.
352 | dataset = load_data(args.dataset)
353 | dataset = format_dataset(dataset, args.dataset_format)
354 |
355 | # Split train/eval, reduce size
356 | if args.do_eval:
357 | if "eval" in dataset:
358 | eval_dataset = dataset["eval"]
359 | else:
360 | print(
361 | "Splitting train dataset in train and validation according to `eval_dataset_size`"
362 | )
363 | dataset = dataset["train"].train_test_split(
364 | test_size=args.eval_dataset_size, shuffle=True, seed=42
365 | )
366 | eval_dataset = dataset["test"]
367 | if (
368 | args.max_eval_samples is not None
369 | and len(eval_dataset) > args.max_eval_samples
370 | ):
371 | eval_dataset = eval_dataset.select(range(args.max_eval_samples))
372 |
373 | if args.do_train:
374 | train_dataset = dataset["train"]
375 | if (
376 | args.max_train_samples is not None
377 | and len(train_dataset) > args.max_train_samples
378 | ):
379 | train_dataset = train_dataset.select(range(args.max_train_samples))
380 |
381 | data_collator = DataCollatorForCausalLM(
382 | tokenizer=tokenizer,
383 | source_max_len=args.source_max_len,
384 | target_max_len=args.target_max_len,
385 | total_max_len=args.total_max_len,
386 | train_on_source=args.train_on_source,
387 | add_eos_to_target=args.add_eos_to_target,
388 | add_eos_to_marked_target=args.add_eos_to_marked_target,
389 | )
390 | return dict(
391 | train_dataset=train_dataset if args.do_train else None,
392 | eval_dataset=eval_dataset if args.do_eval else None,
393 | data_collator=data_collator,
394 | )
395 |
--------------------------------------------------------------------------------
/train_code/determine_hyper.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | import random
4 | import pandas as pd
5 | import numpy as np
6 | from collections import Counter
7 | import re
8 | import argparse
9 | import os
10 |
11 | from itertools import product
12 |
13 |
14 |
15 | def process_data(input_file, ref_file):
16 | if not os.path.exists(input_file) or not os.path.exists(ref_file):
17 | return []
18 |
19 | with open(ref_file, "r") as r:
20 | data_lines = r.readlines()
21 |
22 | infer4reward_data = [json.loads(l) for l in data_lines]
23 |
24 | infer_dict = {}
25 |     for item in infer4reward_data:  # collect the official (reference) answers
26 | infer_dict[item["prompt"]] = item["response"]
27 |
28 | with open(input_file, "r") as r:
29 | data_lines = r.readlines()
30 |
31 | data_json = []
32 | for item in data_lines:
33 | try:
34 | data_json.append(json.loads(item))
35 | except json.JSONDecodeError:
36 | continue
37 |
38 | random.shuffle(data_json)
39 |
40 | trn_data = {}
41 | for item in data_json:
42 | if item["idx"] in trn_data:
43 | trn_data[item["idx"]]["sample_list"].append({"output": item["output"], "reward": item["reward"], "response": infer_dict[item["prompt"]]})
44 | else:
45 | trn_data[item["idx"]] = {
46 | "prompt": item["prompt"],
47 | "sample_list": [{"output": item["output"], "reward": item["reward"], "response": infer_dict[item["prompt"]]}]
48 | }
49 |
50 | trn_list = []
51 | for item in trn_data:
52 | trn_list.append(trn_data[item])
53 |
54 | return trn_list
55 |
56 | def get_unique_trn_json(max_samples, trn_data, score_threshold):
57 | trn_json = []
58 | for item in trn_data:
59 | solutions = []
60 | for sample in item["sample_list"]:
61 | if "\n\n# Answer\n\n" in sample["output"]:
62 | final_answer = sample["output"].split("\n\n# Answer\n\n")[-1]
63 | prm_score = min(sample["reward"])
64 |
65 | orm_score = 0.0
66 | if sample["output"].split("\n\n# Answer\n\n")[-1] == sample["response"].split("\n\n# Answer\n\n")[-1]:
67 | orm_score = 1.0
68 |
69 | final_score = prm_score / 5.0 + orm_score
70 |                 solutions.append({'final_answer': final_answer, 'prm_score': prm_score, 'output': sample["output"], "score": final_score})  # weighted combination of prm_score and orm_score -- the final reward score used to filter the data
71 |
72 |                 solutions.append({'final_answer': final_answer, 'prm_score': 0.0, 'output': sample["response"], "score": 1.0})  # add the reference (official) solution
73 | if len(solutions) == 0:
74 | continue
75 |
76 | solutions_sorted = sorted(solutions, key=lambda x: x['score'], reverse=True)
77 | idx = 0
78 | temp_input = item["prompt"].split("\n\n# Solution\n\n")[0]
79 | temp_input = temp_input.split("# Question\n\n")[-1]
80 |
81 | for solu in solutions_sorted:
82 | if solu["score"] > score_threshold:
83 | trn_json.append({"query": temp_input, "output": solu["output"], "response": sample["response"], "reward": solu["prm_score"]})
84 | idx += 1
85 | if idx >= max_samples:
86 | break
87 |     # deduplicate (query, output) pairs
88 | unique_trn_json = []
89 | seen = set()
90 | for item in trn_json:
91 | identifier = (item["query"], item["output"])
92 | if identifier not in seen:
93 | seen.add(identifier)
94 | unique_trn_json.append(item)
95 |
96 | return unique_trn_json
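# Worked example of the score above: a sampled solution whose minimum step reward
# (prm_score) is 0.5 and whose final answer matches the reference gets
# final_score = 0.5 / 5.0 + 1.0 = 1.1, while the same prm_score with a wrong
# answer only gets 0.1, so score_threshold effectively filters on correctness first.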
97 |
98 |
99 |
100 | def cal_modi_acc_soft(data_json):
101 | unique_set = set()
102 | for item in data_json:
103 | unique_set.add(item["query"])
104 | unique_dict = {}
105 |
106 | for item in unique_set:
107 | unique_dict[item] = []
108 |
109 | #trn_json = []
110 | for item in data_json:
111 | unique_dict[item["query"]].append({"output": item["output"], "response": item["response"]})
112 |
113 | correct_num = []
114 |
115 | actual_num = []
116 | correct_ratio = []
117 |
118 | for item in unique_dict:
119 | temp_count = 0
120 | for output in unique_dict[item]:
121 | if output["output"].split("\n\n# Answer\n\n")[-1] == output["response"].split("\n\n# Answer\n\n")[-1]:
122 | temp_count = temp_count + 1
123 |
124 | correct_num.append(temp_count)
125 | actual_num.append(len(unique_dict[item]))
126 | correct_ratio.append(temp_count/len(unique_dict[item]))
127 | modi_acc_list = []
128 | for num, ratio in zip(correct_num, correct_ratio):
129 | if num >= 8:
130 | modi_acc_list.append(ratio)
131 |
132 | else:
133 | modi_acc_list.append((num * ratio)/ 8)
134 |
135 | return np.mean(modi_acc_list)
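# Example of the modified accuracy above: a query with 20 selected samples, 6 of
# which match the reference, has ratio 0.3; since 6 < 8 it contributes
# (6 * 0.3) / 8 = 0.225, whereas a query with at least 8 correct samples
# contributes its raw correct ratio.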
136 |
137 |
138 |
139 | def find_rewardbase(ana_list, reward_base, target_size):
140 | max_samples = 1
141 | max_iterations = 64
142 | iteration = 0
143 | while iteration < max_iterations:
144 | unique_trn_json = get_unique_trn_json(max_samples, ana_list, reward_base)
145 | if len(unique_trn_json) >= target_size:
146 | break
147 | max_samples += 1
148 | iteration += 1
149 |
150 | modi_acc = cal_modi_acc_soft(unique_trn_json)
151 | random.shuffle(unique_trn_json)
152 |
153 | return modi_acc
154 |
155 |
156 |
157 | def find_strategy(input, ref, target_size):
158 |     ana_list = process_data(input, ref)  # merge input and ref, attaching each sample's response and reward
159 |
160 | if not ana_list:
161 | return float('-inf')
162 | hyper_param = -1.0
163 |
164 | max_modi_acc = float('-inf')
165 |
166 | while True:
167 | modi_acc = find_rewardbase(ana_list, hyper_param, target_size)
168 | if modi_acc > max_modi_acc:
169 | max_modi_acc = modi_acc
170 |
171 | hyper_param += 0.01
172 |
173 | if hyper_param > 1.0:
174 | break
175 |
176 | return max_modi_acc
177 |
178 |
179 | def find_dataset(input_list, ref_list, target_size):
180 | modi_acc_list = []
181 |
182 |
183 | for input_item, ref_item in zip(input_list, ref_list):
184 | modi_acc = find_strategy(input_item, ref_item, target_size)
185 |
186 |
187 | modi_acc_list.append(modi_acc)
188 |
189 | return modi_acc_list
190 |
191 |
192 |
193 | def parse_args():
194 | parser = argparse.ArgumentParser(description="Generate input and reference file paths based on provided temps and sample numbers.")
195 | parser.add_argument('--temps', type=float, nargs='+', required=True, help="List of temperatures.")
196 | parser.add_argument('--sample_nums', type=int, nargs='+', required=True, help="List of sample numbers.")
197 | parser.add_argument('--input_path_template', type=str, required=True, help="Template for the input path.")
198 | parser.add_argument('--ref_path_template', type=str, required=True, help="Template for the reference path.")
199 | parser.add_argument('--iter', type=int, required=True, help="Iteration number to be included in the file path.")
200 | parser.add_argument('--valid_sample_size', type=int, required=True, help="Valid sample size.")
201 | return parser.parse_args()
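# The two path templates are filled in with str.format; a hypothetical example:
#   --input_path_template ".../iter{iter}/infer_t{temp}_n{sample_num}.jsonl"
#   --ref_path_template   ".../iter{iter}/ref_t{temp}_n{sample_num}.jsonl"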
202 |
203 | def main():
204 | args = parse_args()
205 |
206 | input_list = []
207 | ref_list = []
208 |     combinations = list(product(args.temps, args.sample_nums))  # enumerate all (temp, sample_num) combinations
209 |
210 |
211 | for temp, sample_num in combinations:
212 |         # build the input and reference paths, substituting iter into the templates as well
213 | input_path = args.input_path_template.format(temp=temp, sample_num=sample_num, iter=args.iter)
214 | ref_path = args.ref_path_template.format(temp=temp, sample_num=sample_num, iter=args.iter)
215 |
216 | input_list.append(input_path)
217 | ref_list.append(ref_path)
218 |
219 |     # find_dataset returns a list of effectiveness scores, one per combination
220 | effect_values = find_dataset(input_list, ref_list, args.valid_sample_size)
221 |
222 |     # pick the combination with the best score
223 | max_effect_index = effect_values.index(max(effect_values))
224 | best_temp, best_sample_num = combinations[max_effect_index]
225 |
226 |     # print the best combination as "<best temp> <best sample_num>"
227 | print(f"{best_temp} {best_sample_num}")
228 |
229 | if __name__ == "__main__":
230 | main()
231 |
--------------------------------------------------------------------------------
/train_code/eval_arc_save_metrics.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 | import os
4 | def main():
5 | parser = argparse.ArgumentParser()
6 | parser.add_argument("--input_file", type=str, required=True)
7 | parser.add_argument("--output_file", type=str, required=True)
8 | args = parser.parse_args()
9 |
10 | print(args)
11 | data = json.load(open(args.input_file, "r"))
12 | correct_count = 0
13 | for item in data.values():
14 | assert "output1" not in item, "Pass@k is not allowed"
15 | score = item["output0"]["score"]
16 |         if abs(float(score) - 1.0) < 1e-5:
17 | correct_count += 1
18 | print(f"Accuracy: {correct_count / len(data)}")
19 | metrics = {"accuracy": correct_count / len(data)}
20 | os.makedirs(os.path.dirname(args.output_file), exist_ok=True)
21 | with open(args.output_file, "w") as f:
22 | json.dump(metrics, f)
23 |
24 |
25 |
26 | if __name__ == "__main__":
27 | main()
--------------------------------------------------------------------------------
/train_code/gsm8k_test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import pdb
4 | import re
5 |
6 | import jsonlines
7 | from fraction import Fraction
8 |
9 | from vllm import LLM, SamplingParams
10 | #from extract_answer_use_chatgpt import _extract_answer_chatgpt
11 | import sys
12 |
13 | MAX_INT = sys.maxsize  # used as an effectively unbounded end index
14 |
15 | ####CUDA_VISIBLE_DEVICES=0 python gsm8k_test_the_answer_is_original_batch.py --model xxxxx --start 0 --end 1400 --batch_size 80 --tensor_parallel_size 1
16 |
17 | def is_number(s):
18 |     try:  # if float(s) succeeds, s is a numeric string
19 |         float(s)
20 |         return True
21 |     except ValueError:  # raised when s is not a valid float literal
22 |         pass  # fall through and try the unicode numeric parser
23 |     try:
24 |         import unicodedata  # unicode character database utilities
25 |         unicodedata.numeric(s)  # converts a numeric unicode character to a float
26 | return True
27 | except (TypeError, ValueError):
28 | pass
29 | return False
30 |
31 | def extract_answer_number(completion):
32 | text = completion.split('\n\n# Answer\n\n')
33 | if len(text) > 1:
34 | extract_ans = text[-1].strip()
35 | match = re.search(r'[\-+]?\d*[\.,/]?\d+', extract_ans)
36 |         if match:  ###### GSM8K answers are all integers
37 | if '/' in match.group():
38 |                 denominator = match.group().split('/')[1]  # denominator
39 |                 numerator = match.group().split('/')[0]  # numerator
40 |                 if is_number(denominator) and is_number(numerator):
41 |                     if denominator == '0':  ## denominator is zero
42 |
43 |                         print('denominator is 0 ====:', match.group())
44 | return round(float(numerator.replace(',', '')))
45 |                     else:  ## non-zero denominator
46 | frac = Fraction(match.group().replace(',', ''))
47 | num_numerator = frac.numerator
48 | num_denominator = frac.denominator
49 |                         return round(float(num_numerator / num_denominator))  # fraction: round to the nearest integer
50 | else:
51 | return None
52 | else:
53 | if float(match.group().replace(',', '')) == float('inf'):
54 | return None
55 |                 return round(float(match.group().replace(',', '')))  ### decimals and comma-grouped numbers: round to the nearest integer
56 | else:
57 | return None
58 | else:
59 | return None
60 |
61 | def batch_data(data_list, batch_size=1):
62 | n = len(data_list) // batch_size
63 | batch_data = []
64 | for i in range(n-1):
65 | start = i * batch_size
66 | end = (i+1)*batch_size
67 | batch_data.append(data_list[start:end])
68 |
69 | last_start = (n-1) * batch_size
70 | last_end = MAX_INT
71 | batch_data.append(data_list[last_start:last_end])
72 | return batch_data
73 |
74 |
75 | def gsm8k_test(args, model, start=0, end=MAX_INT, batch_size=1, tensor_parallel_size=1):
76 | INVALID_ANS = "[invalid]"
77 | gsm8k_ins = []
78 | gsm8k_answers = []
79 |
80 |
81 | #problem_prompt = ("A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {instruction} ASSISTANT: Let's think step by step.")
82 | #problem_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response: Let's think step by step."
83 |
84 | problem_prompt = "# Question\n\n{instruction}\n\n# Solution\n\n"
85 | #problem_prompt = ("A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant's response is from ChatGPT3.5. USER: {instruction} ASSISTANT: Let's think step by step.")
86 | # problem_prompt = (
87 | # "You are a math assistant, skilled at solving various mathematical problems. USER: {instruction} ASSISTANT: Let's think step by step.")
88 |     print('prompt =====', problem_prompt)
89 | with open('./test.jsonl',"r+", encoding="utf8") as f:
90 | for idx, item in enumerate(jsonlines.Reader(f)):
91 | temp_instr = problem_prompt.format(instruction=item["question"])
92 | gsm8k_ins.append(temp_instr)
93 | temp_ans = item['answer'].split('#### ')[1]
94 | temp_ans = int(temp_ans.replace(',', ''))
95 | gsm8k_answers.append(temp_ans)
96 |
97 | gsm8k_ins = gsm8k_ins[start:end]
98 | gsm8k_answers = gsm8k_answers[start:end]
99 |     print('length ====', len(gsm8k_ins))
100 | batch_gsm8k_ins = batch_data(gsm8k_ins, batch_size=batch_size)
101 |
102 |
103 |
104 |
105 | stop_tokens = ["Question:", "Question", "USER:", "USER", "ASSISTANT:", "ASSISTANT", "Instruction:", "Instruction", "Response:", "Response"]
106 | sampling_params = SamplingParams(temperature=0, top_p=1, max_tokens=args.max_tokens)
107 |     print('sampling =====', sampling_params)
108 | llm = LLM(model=model,tensor_parallel_size=tensor_parallel_size)
109 | result = []
110 | res_completions = []
111 | for idx, (prompt, prompt_answer) in enumerate(zip(batch_gsm8k_ins, gsm8k_answers)):
112 | print('llm idx ====', idx)
113 | if isinstance(prompt, list):
114 | pass
115 | else:
116 | prompt = [prompt]
117 |
118 | completions = llm.generate(prompt, sampling_params)
119 | for output in completions:
120 | prompt = output.prompt
121 | generated_text = output.outputs[0].text
122 | res_completions.append(generated_text)
123 |
124 | invalid_outputs = []
125 | all_outputs = []
126 | for idx, (prompt, completion, prompt_answer) in enumerate(zip(gsm8k_ins, res_completions, gsm8k_answers)):
127 | print('chatgpt idx =====', idx)
128 | doc = {'question': prompt}
129 | y_pred = extract_answer_number(completion)
130 |         if y_pred is not None:
131 | result.append(float(y_pred) == float(prompt_answer))
132 | temp = {'question': prompt, 'output': completion, 'answer': prompt_answer}
133 | all_outputs.append(temp)
134 | else:
135 | result.append(False)
136 | temp = {'question': prompt, 'output': completion, 'answer': prompt_answer}
137 | invalid_outputs.append(temp)
138 | all_outputs.append(temp)
139 | # pdb.set_trace()
140 | acc = sum(result) / len(result)
141 | print('start===', start, ', end====', end)
142 | print('length====', len(result), ', acc====', acc)
143 | #print('len invalid outputs ====', len(invalid_outputs), ', valid_outputs===', invalid_outputs)
144 | model_name = '_'.join(model.split('/')[-2:])
145 |
146 | OUTPUT_FILE_SUFFIX = ".json"
147 | import os
148 | output_data_path = os.path.join(args.output_path, args.model_id + OUTPUT_FILE_SUFFIX)
149 |
150 | output_dir = os.path.dirname(output_data_path)
151 | if not os.path.exists(output_dir):
152 | os.makedirs(output_dir)
153 |
154 | with open(output_data_path, 'w') as f:
155 | json.dump({"length====": len(result), "acc====": acc}, f)
156 |
157 |
158 | #pdb.set_trace()
159 |
160 |
161 | def parse_args():
162 | parser = argparse.ArgumentParser()
163 | parser.add_argument("--model", type=str) # start index
164 | parser.add_argument("--model_id", type=str) # start index
165 | parser.add_argument("--output_path", type=str) # start index
166 | parser.add_argument("--start", type=int, default=0) #start index
167 | parser.add_argument("--end", type=int, default=MAX_INT) # start index
168 | parser.add_argument("--batch_size", type=int, default=1) # batch_size
169 | parser.add_argument("--tensor_parallel_size", type=int, default=1) # tensor_parallel_size
170 | parser.add_argument("--max_tokens", type=int, default=768)
171 | return parser.parse_args()
172 |
173 | if __name__ == "__main__":
174 | args = parse_args()
175 | gsm8k_test(args=args, model=args.model, start=args.start, end=args.end, batch_size=args.batch_size, tensor_parallel_size=args.tensor_parallel_size)
176 | #pdb.set_trace()
177 |
--------------------------------------------------------------------------------
/train_code/infer_arc.sh:
--------------------------------------------------------------------------------
1 | RUN_NAME=arc_train_online_rft_mistral_temp0.4_7kdata
2 | LOG_PATH=logs/${RUN_NAME}
3 | QUERY_FILE=/mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/zengweihao02/easy2hard/share/project/weihao/easy-to-hard-main-share/dynamic_ana/arc_data/arc_test_data.jsonl
4 | PROMPT_DIR="/mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/zengweihao02/easy2hard/share/project/weihao/easy-to-hard-main-share/dynamic_ana/arc_prompt"
5 | mkdir -p $LOG_PATH
6 |
7 | # for i in {2500..4500..500}
8 | # do
9 | # python vllm_infer_arc.py \
10 | # --input_file $QUERY_FILE \
11 | # --prompt_template_dir $PROMPT_DIR \
12 | # --shot_num 0 \
13 | # --model /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/zengweihao02/checkpoints/$RUN_NAME/converted/ckpt$i \
14 | # --sample_num 1 \
15 | # --temperature 0 \
16 | # --top_k -1 \
17 | # --max_tokens 1024 \
18 | # --split test \
19 | # --save /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/zengweihao02/easy2hard/share/project/weihao/easy-to-hard-main-share/dynamic_ana/arc/results/${RUN_NAME}/ckpt$i/gen.json \
20 | # --tensor-parallel-size 1 \
21 | # --cuda_ids 0 \
22 | # --cache_dir /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/zengweihao02/easy2hard/share/project/weihao/easy-to-hard-main-share/dynamic_ana/ARC/.cache \
23 | # 2>&1 | tee -a $LOG_PATH/ckpt$i.log
24 |
25 | # python eval_arc_save_metrics.py \
26 | # --input_file /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/zengweihao02/easy2hard/share/project/weihao/easy-to-hard-main-share/dynamic_ana/arc/results/${RUN_NAME}/ckpt$i/gen.json \
27 | # --output_file /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/zengweihao02/easy2hard/share/project/weihao/easy-to-hard-main-share/dynamic_ana/arc/results/${RUN_NAME}/ckpt$i/metrics.json \
28 | # 2>&1 | tee -a $LOG_PATH/ckpt$i.log
29 | # done
30 |
31 | RUN_NAME=Mistral-7B-Instruct-v0.1
32 | for i in {2500..2500..500}
33 | do
34 | python vllm_infer_arc.py \
35 | --input_file $QUERY_FILE \
36 | --prompt_template_dir $PROMPT_DIR \
37 | --shot_num 1 \
38 | --model /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/zengweihao02/modelzoo/Mistral-7B-Instruct-v0.1 \
39 | --sample_num 1 \
40 | --temperature 0 \
41 | --top_k -1 \
42 | --max_tokens 1024 \
43 | --split test \
44 | --save /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/zengweihao02/easy2hard/share/project/weihao/easy-to-hard-main-share/dynamic_ana/arc/results/${RUN_NAME}/ckpt$i/gen.json \
45 | --tensor-parallel-size 1 \
46 | --cuda_ids 0 \
47 | --cache_dir /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/zengweihao02/easy2hard/share/project/weihao/easy-to-hard-main-share/dynamic_ana/ARC/.cache \
48 | 2>&1 | tee -a $LOG_PATH/ckpt$i.log
49 |
50 | python eval_arc_save_metrics.py \
51 | --input_file /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/zengweihao02/easy2hard/share/project/weihao/easy-to-hard-main-share/dynamic_ana/arc/results/${RUN_NAME}/ckpt$i/gen.json \
52 | --output_file /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/zengweihao02/easy2hard/share/project/weihao/easy-to-hard-main-share/dynamic_ana/arc/results/${RUN_NAME}/ckpt$i/metrics.json \
53 | 2>&1 | tee -a $LOG_PATH/ckpt$i.log
54 | done
--------------------------------------------------------------------------------
/train_code/infer_gsm8k1.sh:
--------------------------------------------------------------------------------
1 | #export HF_HOME="/ssddata/model_hub"
2 | # export CUDA_HOME=/usr/local/cuda-11.7 # set the CUDA root directory
3 | # export PATH=$PATH:/usr/local/cuda-11.7/bin # the bin folder under the CUDA install, containing nvcc and other binaries
4 | # export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-11.7/lib64 ## the lib64 folder under the CUDA install, containing the CUDA libraries
5 |
6 |
7 | # for i in {100..500..100}
8 | # do
9 | # CUDA_VISIBLE_DEVICES=1 python gsm8k_test.py \
10 | # --model /share/project/weihao/save_dir/checkpoints/math_format_infer_iter9_fix_rest_64k/iter7/converted/ckpt$i \
11 | # --output_path /share/project/weihao/save_dir/checkpoints/math_format_infer_iter9_fix_rest_64k/iter7/converted/ckpt$i \
12 | # --model_id gsm8k \
13 | # --start 0 \
14 | # --end 1400 \
15 | # --batch_size 800 \
16 | # --tensor_parallel_size 1 \
17 | # --max_tokens 768
18 | # done
19 |
20 |
21 | # for iter in {2..6}
22 | # do
23 | # for i in {100..500..100}
24 | # do
25 | # CUDA_VISIBLE_DEVICES=0 python gsm8k_test.py \
26 | # --model /share/project/weihao/save_dir/checkpoints/math_format_infer_iter9_fix_rewardneg0.0_pre_iterrft_64k/iter$iter/converted/ckpt$i \
27 | # --output_path /share/project/weihao/save_dir/checkpoints/math_format_infer_iter9_fix_rewardneg0.0_pre_iterrft_64k/iter$iter/converted/ckpt$i \
28 | # --model_id gsm8k \
29 | # --start 0 \
30 | # --end 1400 \
31 | # --batch_size 800 \
32 | # --tensor_parallel_size 1 \
33 | # --max_tokens 768
34 | # done
35 | # done
36 |
37 |
38 |
39 | # for i in {1750..2250..250}
40 | # do
41 | # CUDA_VISIBLE_DEVICES=0 python gsm8k_test.py \
42 | # --model /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/checkpoints/math_format_b_star_deepseek_math_16k_fix_bsz_fine_stand_fix_moretemp/converted/ckpt$i \
43 | # --output_path /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/checkpoints/math_format_b_star_deepseek_math_16k_fix_bsz_fine_stand_fix_moretemp/converted/ckpt$i \
44 | # --model_id gsm8k \
45 | # --start 0 \
46 | # --end 1400 \
47 | # --batch_size 800 \
48 | # --tensor_parallel_size 1 \
49 | # --max_tokens 768
50 | # done
51 |
52 | # for i in {1750..2250..250}
53 | # do
54 | # CUDA_VISIBLE_DEVICES=0 python vllm_infer.py \
55 | # --input_data /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/easy2hard/share/project/weihao/easy-to-hard-main-share/easy-to-hard-main-share/data/test_ppo.json \
56 | # --model_dir /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/checkpoints/math_format_b_star_deepseek_math_16k_fix_bsz_fine_stand_fix_moretemp/converted/ckpt$i \
57 | # --sample_num 1 \
58 | # --output_file /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/checkpoints/math_format_b_star_deepseek_math_16k_fix_bsz_fine_stand_fix_moretemp/converted/ckpt$i/test_ppo_1to5_infer_greedy.json \
59 | # --tensor_parallel_size 1 \
60 | # --temperature 0.0 \
61 | # --top_k -1 \
62 | # --max_tokens 768
63 | # done
64 |
65 |
66 | # # python cal_metric_vllm.py \
67 | # # --tokenizer_path /share/project/weihao/model_zoo/llemma_7b/tokenizer.model \
68 | # # --answer_file /share/project/weihao/save_dir/checkpoints/math_format_online_rft_deepseek_math_48k_fix_bsz_upreward_update/converted/ckpt4000/test_ppo_1to5_infer_greedy.json \
69 | # # --output_file /share/project/weihao/save_dir/checkpoints/math_format_online_rft_deepseek_math_48k_fix_bsz_upreward_update/converted/ckpt4000/test_ppo_1to5_infer_greedy_metric.json \
70 |
71 |
72 |
73 | # for i in {1750..2250..250}
74 | # do
75 | # python cal_metric_vllm.py \
76 | # --tokenizer_path /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/modelzoo/llemma_7b/tokenizer.model \
77 | # --answer_file /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/checkpoints/math_format_b_star_deepseek_math_16k_fix_bsz_fine_stand_fix_moretemp/converted/ckpt$i/test_ppo_1to5_infer_greedy.json \
78 | # --output_file /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/checkpoints/math_format_b_star_deepseek_math_16k_fix_bsz_fine_stand_fix_moretemp/converted/ckpt$i/test_ppo_1to5_infer_greedy_metric.json
79 | # done
80 |
81 |
82 | for iter in {2..10}
83 | do
84 | for i in {250..250..250}
85 | do
86 | CUDA_VISIBLE_DEVICES=0 python gsm8k_test.py \
87 | --model /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/checkpoints/math_format_rest_llama3.1_math_32k_lr2e_5/iter$iter/converted/ckpt$i \
88 | --output_path /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/checkpoints/math_format_rest_llama3.1_math_32k_lr2e_5/iter$iter/converted/ckpt$i \
89 | --model_id gsm8k \
90 | --start 0 \
91 | --end 1400 \
92 | --batch_size 800 \
93 | --tensor_parallel_size 1 \
94 | --max_tokens 768
95 | done
96 | done
97 |
98 |
99 | for iter in {2..10}
100 | do
101 | for i in {250..250..250}
102 | do
103 | CUDA_VISIBLE_DEVICES=0 python vllm_infer.py \
104 | --input_data /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/easy2hard/share/project/weihao/easy-to-hard-main-share/easy-to-hard-main-share/data/test_ppo.json \
105 | --model_dir /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/checkpoints/math_format_rest_llama3.1_math_32k_lr2e_5/iter$iter/converted/ckpt$i \
106 | --sample_num 1 \
107 | --output_file /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/checkpoints/math_format_rest_llama3.1_math_32k_lr2e_5/iter$iter/converted/ckpt$i/test_ppo_1to5_infer_greedy.json \
108 | --tensor_parallel_size 1 \
109 | --temperature 0.0 \
110 | --top_k -1 \
111 | --max_tokens 768
112 | done
113 | done
114 |
115 |
116 |
117 | for iter in {2..10}
118 | do
119 | for i in {250..250..250}
120 | do
121 | python cal_metric_vllm.py \
122 | --tokenizer_path /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/modelzoo/llemma_7b/tokenizer.model \
123 | --answer_file /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/checkpoints/math_format_rest_llama3.1_math_32k_lr2e_5/iter$iter/converted/ckpt$i/test_ppo_1to5_infer_greedy.json \
124 | --output_file /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/checkpoints/math_format_rest_llama3.1_math_32k_lr2e_5/iter$iter/converted/ckpt$i/test_ppo_1to5_infer_greedy_metric.json
125 | done
126 | done
127 |
128 |
129 | # for t in 0.5 0.7 0.9 1.1
130 | # do
131 | # for i in {500..4500..500}
132 | # do
133 | # CUDA_VISIBLE_DEVICES=0 python gsm8k_test.py \
134 | # --model /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/checkpoints/math_format_online_rft_mistral_math_64k_${t}t_neg0.4r/converted/ckpt$i \
135 | # --output_path /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/checkpoints/math_format_online_rft_mistral_math_64k_${t}t_neg0.4r/converted/ckpt$i \
136 | # --model_id gsm8k \
137 | # --start 0 \
138 | # --end 1400 \
139 | # --batch_size 800 \
140 | # --tensor_parallel_size 1 \
141 | # --max_tokens 768
142 | # done
143 |
144 | # for i in {500..4500..500}
145 | # do
146 | # CUDA_VISIBLE_DEVICES=0 python vllm_infer.py \
147 | # --input_data /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/easy2hard/share/project/weihao/easy-to-hard-main-share/easy-to-hard-main-share/data/test_ppo.json \
148 | # --model_dir /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/checkpoints/math_format_online_rft_mistral_math_64k_${t}t_neg0.4r/converted/ckpt$i \
149 | # --sample_num 1 \
150 | # --output_file /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/checkpoints/math_format_online_rft_mistral_math_64k_${t}t_neg0.4r/converted/ckpt$i/test_ppo_1to5_infer_greedy.json \
151 | # --tensor_parallel_size 1 \
152 | # --temperature 0.0 \
153 | # --top_k -1 \
154 | # --max_tokens 768
155 | # done
156 |
157 | # for i in {500..4500..500}
158 | # do
159 | # python cal_metric_vllm.py \
160 | # --tokenizer_path /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/modelzoo/llemma_7b/tokenizer.model \
161 | # --answer_file /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/checkpoints/math_format_online_rft_mistral_math_64k_${t}t_neg0.4r/converted/ckpt$i/test_ppo_1to5_infer_greedy.json \
162 | # --output_file /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/checkpoints/math_format_online_rft_mistral_math_64k_${t}t_neg0.4r/converted/ckpt$i/test_ppo_1to5_infer_greedy_metric.json
163 | # done
164 | # done
165 |
--------------------------------------------------------------------------------
/train_code/math_utils/README.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/math_utils/README.md
--------------------------------------------------------------------------------
/train_code/math_utils/__pycache__/grader.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/math_utils/__pycache__/grader.cpython-310.pyc
--------------------------------------------------------------------------------
/train_code/math_utils/__pycache__/grader.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/math_utils/__pycache__/grader.cpython-311.pyc
--------------------------------------------------------------------------------
/train_code/math_utils/__pycache__/math_normalize.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/math_utils/__pycache__/math_normalize.cpython-310.pyc
--------------------------------------------------------------------------------
/train_code/math_utils/__pycache__/math_normalize.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/math_utils/__pycache__/math_normalize.cpython-311.pyc
--------------------------------------------------------------------------------
/train_code/math_utils/__pycache__/math_rl_utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/math_utils/__pycache__/math_rl_utils.cpython-310.pyc
--------------------------------------------------------------------------------
/train_code/math_utils/__pycache__/math_rl_utils.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/math_utils/__pycache__/math_rl_utils.cpython-311.pyc
--------------------------------------------------------------------------------
/train_code/math_utils/grader.py:
--------------------------------------------------------------------------------
1 | """
2 | Answer checker API that uses sympy to simplify expressions and check for equality.
3 |
4 | Call grade_answer(given_answer: str, ground_truth: str).
5 | """
6 |
7 | import re
8 | import sympy
9 | from pylatexenc import latex2text
10 | from sympy.parsing import sympy_parser
11 |
12 | from math_utils import math_normalize
13 |
14 |
15 | # sympy might hang -- we don't care about trying to be lenient in these cases
16 | BAD_SUBSTRINGS = ["^{", "^("]
17 | BAD_REGEXES = ["\^[0-9]+\^", "\^[0-9][0-9]+"]
18 | TUPLE_CHARS = "()[]"
19 |
20 |
21 | def _sympy_parse(expr: str):
22 | """Parses an expression with sympy."""
23 | py_expr = expr.replace("^", "**")
24 | return sympy_parser.parse_expr(
25 | py_expr,
26 | transformations=(
27 | sympy_parser.standard_transformations
28 | + (sympy_parser.implicit_multiplication_application,)
29 | ),
30 | )
31 |
32 |
33 | def _parse_latex(expr: str) -> str:
34 | """Attempts to parse latex to an expression sympy can read."""
35 | expr = expr.replace("\\tfrac", "\\frac")
36 | expr = expr.replace("\\dfrac", "\\frac")
37 | expr = expr.replace("\\frac", " \\frac") # Play nice with mixed numbers.
38 | expr = latex2text.LatexNodes2Text().latex_to_text(expr)
39 |
40 | # Replace the specific characters that this parser uses.
41 | expr = expr.replace("√", "sqrt")
42 | expr = expr.replace("π", "pi")
43 | expr = expr.replace("∞", "inf")
44 | expr = expr.replace("∪", "U")
45 | expr = expr.replace("·", "*")
46 | expr = expr.replace("×", "*")
47 |
48 | return expr.strip()
49 |
50 |
51 | def _is_float(num: str) -> bool:
52 | try:
53 | float(num)
54 | return True
55 | except ValueError:
56 | return False
57 |
58 |
59 | def _is_int(x: float) -> bool:
60 | try:
61 | return abs(x - int(round(x))) <= 1e-7
62 | except:
63 | return False
64 |
65 |
66 | def _is_frac(expr: str) -> bool:
67 | return bool(re.search(r"^-?[0-9]+.?/0*[1-9][0-9]*.?$", expr))
68 |
69 |
70 | def _str_is_int(x: str) -> bool:
71 | try:
72 | x = _strip_properly_formatted_commas(x)
73 | x = float(x)
74 | return abs(x - int(round(x))) <= 1e-7
75 | except:
76 | return False
77 |
78 |
79 | def _str_to_int(x: str) -> bool:
80 | x = x.replace(",", "")
81 | x = float(x)
82 | return int(x)
83 |
84 |
85 | def _inject_implicit_mixed_number(step: str):
86 | """
87 | Automatically make a mixed number evalable
88 | e.g. 7 3/4 => 7+3/4
89 | """
90 | p1 = re.compile("([0-9]) +([0-9])")
91 | step = p1.sub("\\1+\\2", step) ## implicit mults
92 | return step
93 |
94 |
95 | def _strip_properly_formatted_commas(expr: str):
96 | # We want to be careful because we don't want to strip tuple commas
97 | p1 = re.compile("(\d)(,)(\d\d\d)($|\D)")
98 | while True:
99 | next_expr = p1.sub("\\1\\3\\4", expr)
100 | if next_expr == expr:
101 | break
102 | expr = next_expr
103 | return next_expr
104 |
105 |
106 | def _normalize(expr: str) -> str:
107 | """Normalize answer expressions."""
108 | if expr is None:
109 | return None
110 |
111 | # Remove enclosing `\text{}`.
112 |     m = re.search("^\\\\text\{(?P<text>.+?)\}$", expr)
113 | if m is not None:
114 | expr = m.group("text")
115 |
116 | expr = expr.replace("\\%", "%")
117 | expr = expr.replace("\\$", "$")
118 | expr = expr.replace("$", "")
119 | expr = expr.replace("%", "")
120 | expr = expr.replace(" or ", " , ")
121 | expr = expr.replace(" and ", " , ")
122 |
123 | expr = expr.replace("million", "*10^6")
124 | expr = expr.replace("billion", "*10^9")
125 | expr = expr.replace("trillion", "*10^12")
126 |
127 | for unit in [
128 | "degree",
129 | "cm",
130 | "centimeter",
131 | "meter",
132 | "mile",
133 | "second",
134 | "minute",
135 | "hour",
136 | "day",
137 | "week",
138 | "month",
139 | "year",
140 | "foot",
141 | "feet",
142 | "inch",
143 | "yard",
144 | ]:
145 | expr = re.sub(f"{unit}(es)?(s)? *(\^[0-9]+)?", "", expr)
146 | expr = re.sub(f"\^ *\\\\circ", "", expr)
147 |
148 | if len(expr) > 0 and expr[0] == "{" and expr[-1] == "}":
149 | expr = expr[1:-1]
150 |
151 | expr = re.sub(",\\\\! *", "", expr)
152 | if _is_float(expr) and _is_int(float(expr)):
153 | expr = str(int(round(float(expr))))
154 | if "\\" in expr:
155 | try:
156 | expr = _parse_latex(expr)
157 | except:
158 | pass
159 |
160 | # edge case with mixed numbers and negative signs
161 | expr = re.sub("- *", "-", expr)
162 |
163 | expr = _inject_implicit_mixed_number(expr)
164 | expr = expr.replace(" ", "")
165 |
166 | # if we somehow still have latex braces here, just drop them
167 | expr = expr.replace("{", "")
168 | expr = expr.replace("}", "")
169 |
170 | # don't be case sensitive for text answers
171 | expr = expr.lower()
172 |
173 | if _str_is_int(expr):
174 | expr = str(_str_to_int(expr))
175 |
176 | return expr
177 |
178 |
179 | def count_unknown_letters_in_expr(expr: str):
180 | expr = expr.replace("sqrt", "")
181 | expr = expr.replace("frac", "")
182 | letters_in_expr = set([x for x in expr if x.isalpha()])
183 | return len(letters_in_expr)
184 |
185 |
186 | def should_allow_eval(expr: str):
187 | # we don't want to try parsing unknown text or functions of more than two variables
188 | if count_unknown_letters_in_expr(expr) > 2:
189 | return False
190 |
191 | for bad_string in BAD_SUBSTRINGS:
192 | if bad_string in expr:
193 | return False
194 |
195 | for bad_regex in BAD_REGEXES:
196 | if re.search(bad_regex, expr) is not None:
197 | return False
198 |
199 | return True
200 |
201 |
202 | def are_equal_under_sympy(ground_truth_normalized: str, given_normalized: str):
203 | are_equal = False
204 | try:
205 | expr = f"({ground_truth_normalized})-({given_normalized})"
206 | if should_allow_eval(expr):
207 | sympy_diff = _sympy_parse(expr)
208 | simplified = sympy.simplify(sympy_diff)
209 | if simplified == 0:
210 | are_equal = True
211 | except:
212 | pass
213 | return are_equal
214 |
215 |
216 | def split_tuple(expr: str):
217 | """
218 | Split the elements in a tuple/interval, while handling well-formatted commas in large numbers
219 | """
220 | expr = _strip_properly_formatted_commas(expr)
221 | if len(expr) == 0:
222 | return []
223 | if (
224 | len(expr) > 2
225 | and expr[0] in TUPLE_CHARS
226 | and expr[-1] in TUPLE_CHARS
227 | and all([ch not in expr[1:-1] for ch in TUPLE_CHARS])
228 | ):
229 | elems = [elem.strip() for elem in expr[1:-1].split(",")]
230 | else:
231 | elems = [expr]
232 | return elems
233 |
234 |
235 | def grade_answer(given_answer: str, ground_truth: str) -> bool:
236 | """
237 | The answer will be considered correct if:
238 | (a) it normalizes to the same string as the ground truth answer
239 | OR
240 | (b) sympy can simplify the difference between the expressions to 0
241 | """
242 | if given_answer is None:
243 | return False
244 |
245 | ground_truth_normalized_mathd = math_normalize.normalize_answer(ground_truth)
246 | given_answer_normalized_mathd = math_normalize.normalize_answer(given_answer)
247 |
248 | # be at least as lenient as mathd
249 | if ground_truth_normalized_mathd == given_answer_normalized_mathd:
250 | return True
251 |
252 | ground_truth_normalized = _normalize(ground_truth)
253 | given_normalized = _normalize(given_answer)
254 |
255 | if ground_truth_normalized is None:
256 | return False
257 |
258 | if ground_truth_normalized == given_normalized:
259 | return True
260 |
261 | if len(given_normalized) == 0:
262 | return False
263 |
264 | ground_truth_elems = split_tuple(ground_truth_normalized)
265 | given_elems = split_tuple(given_normalized)
266 |
267 | if len(ground_truth_elems) > 1 and (
268 | ground_truth_normalized[0] != given_normalized[0]
269 | or ground_truth_normalized[-1] != given_normalized[-1]
270 | ):
271 | is_correct = False
272 | elif len(ground_truth_elems) != len(given_elems):
273 | is_correct = False
274 | else:
275 | for ground_truth_elem, given_elem in zip(ground_truth_elems, given_elems):
276 | if _is_frac(ground_truth_elem) and _is_frac(given_elem):
277 | # if fractions aren't reduced, then shouldn't be marked as correct
278 | # so, we don't want to allow sympy.simplify in this case
279 | is_correct = ground_truth_elem == given_elem
280 | elif _str_is_int(ground_truth_elem) != _str_is_int(given_elem):
281 | # if the ground truth answer is an integer, we require the given answer to be a strict match (no sympy.simplify)
282 | is_correct = False
283 | else:
284 | # is_correct = are_equal_under_sympy(ground_truth_elem, given_elem)
285 | is_correct = False
286 | if not is_correct:
287 | break
288 |
289 | return is_correct
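# For example, grade_answer("0.5", "\\frac{1}{2}") returns True because both
# normalize to "\frac{1}{2}", while an unreduced fraction such as
# grade_answer("2/4", "1/2") is rejected by the strict fraction comparison.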
290 |
--------------------------------------------------------------------------------
/train_code/math_utils/math_normalize.py:
--------------------------------------------------------------------------------
1 | """
2 | This logic is largely copied from the Hendrycks' MATH release (math_equivalence).
3 | """
4 | import re
5 | from typing import Optional
6 |
7 |
8 | def normalize_answer(answer: Optional[str]) -> Optional[str]:
9 | if answer is None:
10 | return None
11 | answer = answer.strip()
12 | try:
13 | # Remove enclosing `\text{}`.
14 |         m = re.search("^\\\\text\{(?P<text>.+?)\}$", answer)
15 | if m is not None:
16 | answer = m.group("text").strip()
17 | return _strip_string(answer)
18 | except:
19 | return answer
20 |
21 |
22 | def _fix_fracs(string):
23 | substrs = string.split("\\frac")
24 | new_str = substrs[0]
25 | if len(substrs) > 1:
26 | substrs = substrs[1:]
27 | for substr in substrs:
28 | new_str += "\\frac"
29 | if substr[0] == "{":
30 | new_str += substr
31 | else:
32 | try:
33 | assert len(substr) >= 2
34 | except:
35 | return string
36 | a = substr[0]
37 | b = substr[1]
38 | if b != "{":
39 | if len(substr) > 2:
40 | post_substr = substr[2:]
41 | new_str += "{" + a + "}{" + b + "}" + post_substr
42 | else:
43 | new_str += "{" + a + "}{" + b + "}"
44 | else:
45 | if len(substr) > 2:
46 | post_substr = substr[2:]
47 | new_str += "{" + a + "}" + b + post_substr
48 | else:
49 | new_str += "{" + a + "}" + b
50 | string = new_str
51 | return string
52 |
53 |
54 | def _fix_a_slash_b(string):
55 | if len(string.split("/")) != 2:
56 | return string
57 | a = string.split("/")[0]
58 | b = string.split("/")[1]
59 | try:
60 | a = int(a)
61 | b = int(b)
62 | assert string == "{}/{}".format(a, b)
63 | new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
64 | return new_string
65 | except:
66 | return string
67 |
68 |
69 | def _remove_right_units(string):
70 | # "\\text{ " only ever occurs (at least in the val set) when describing units
71 | if "\\text{ " in string:
72 | splits = string.split("\\text{ ")
73 | assert len(splits) == 2
74 | return splits[0]
75 | else:
76 | return string
77 |
78 |
79 | def _fix_sqrt(string):
80 | if "\\sqrt" not in string:
81 | return string
82 | splits = string.split("\\sqrt")
83 | new_string = splits[0]
84 | for split in splits[1:]:
85 | if split[0] != "{":
86 | a = split[0]
87 | new_substr = "\\sqrt{" + a + "}" + split[1:]
88 | else:
89 | new_substr = "\\sqrt" + split
90 | new_string += new_substr
91 | return new_string
92 |
93 |
94 | def _strip_string(string):
95 | # linebreaks
96 | string = string.replace("\n", "")
97 | # print(string)
98 |
99 | # remove inverse spaces
100 | string = string.replace("\\!", "")
101 | # print(string)
102 |
103 | # replace \\ with \
104 | string = string.replace("\\\\", "\\")
105 | # print(string)
106 |
107 | # replace tfrac and dfrac with frac
108 | string = string.replace("tfrac", "frac")
109 | string = string.replace("dfrac", "frac")
110 | # print(string)
111 |
112 | # remove \left and \right
113 | string = string.replace("\\left", "")
114 | string = string.replace("\\right", "")
115 | # print(string)
116 |
117 | # Remove circ (degrees)
118 | string = string.replace("^{\\circ}", "")
119 | string = string.replace("^\\circ", "")
120 |
121 | # remove dollar signs
122 | string = string.replace("\\$", "")
123 |
124 | # remove units (on the right)
125 | string = _remove_right_units(string)
126 |
127 | # remove percentage
128 | string = string.replace("\\%", "")
129 | string = string.replace("\%", "")
130 |
131 | # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
132 | string = string.replace(" .", " 0.")
133 | string = string.replace("{.", "{0.")
134 | # if empty, return empty string
135 | if len(string) == 0:
136 | return string
137 | if string[0] == ".":
138 | string = "0" + string
139 |
140 | # to consider: get rid of e.g. "k = " or "q = " at beginning
141 | if len(string.split("=")) == 2:
142 | if len(string.split("=")[0]) <= 2:
143 | string = string.split("=")[1]
144 |
145 | # fix sqrt3 --> sqrt{3}
146 | string = _fix_sqrt(string)
147 |
148 | # remove spaces
149 | string = string.replace(" ", "")
150 |
151 | # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
152 | string = _fix_fracs(string)
153 |
154 | # manually change 0.5 --> \frac{1}{2}
155 | if string == "0.5":
156 | string = "\\frac{1}{2}"
157 |
158 | # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
159 | string = _fix_a_slash_b(string)
160 |
161 | return string
162 |
--------------------------------------------------------------------------------
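Usage sketch for the normalizer above (not part of the repo); the import path assumes the interpreter is started from train_code/ so that math_utils is importable.

# Hedged example: two surface forms of the same answer collapse to one canonical string.
from math_utils.math_normalize import normalize_answer

a = normalize_answer("\\dfrac{1}{2}")  # \dfrac is rewritten to \frac by _strip_string
b = normalize_answer("0.5")            # "0.5" is special-cased to \frac{1}{2}
print(a, b, a == b)                    # expected: \frac{1}{2} \frac{1}{2} True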
/train_code/models/__pycache__/frozen_layers.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/models/__pycache__/frozen_layers.cpython-310.pyc
--------------------------------------------------------------------------------
/train_code/models/__pycache__/frozen_layers.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/models/__pycache__/frozen_layers.cpython-311.pyc
--------------------------------------------------------------------------------
/train_code/models/__pycache__/frozen_layers.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/models/__pycache__/frozen_layers.cpython-39.pyc
--------------------------------------------------------------------------------
/train_code/models/__pycache__/model.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/models/__pycache__/model.cpython-310.pyc
--------------------------------------------------------------------------------
/train_code/models/__pycache__/model.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/models/__pycache__/model.cpython-311.pyc
--------------------------------------------------------------------------------
/train_code/models/__pycache__/model.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/models/__pycache__/model.cpython-39.pyc
--------------------------------------------------------------------------------
/train_code/models/__pycache__/quantize.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/models/__pycache__/quantize.cpython-310.pyc
--------------------------------------------------------------------------------
/train_code/models/__pycache__/quantize.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/models/__pycache__/quantize.cpython-311.pyc
--------------------------------------------------------------------------------
/train_code/models/__pycache__/reward_model.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/models/__pycache__/reward_model.cpython-310.pyc
--------------------------------------------------------------------------------
/train_code/models/__pycache__/reward_model.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/models/__pycache__/reward_model.cpython-311.pyc
--------------------------------------------------------------------------------
/train_code/models/__pycache__/rl_model.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/models/__pycache__/rl_model.cpython-310.pyc
--------------------------------------------------------------------------------
/train_code/models/__pycache__/rl_model.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/models/__pycache__/rl_model.cpython-311.pyc
--------------------------------------------------------------------------------
/train_code/models/__pycache__/tokenizer_utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/models/__pycache__/tokenizer_utils.cpython-310.pyc
--------------------------------------------------------------------------------
/train_code/models/__pycache__/tokenizer_utils.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/models/__pycache__/tokenizer_utils.cpython-311.pyc
--------------------------------------------------------------------------------
/train_code/models/__pycache__/tokenizer_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/models/__pycache__/tokenizer_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/train_code/models/__pycache__/tp.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/models/__pycache__/tp.cpython-310.pyc
--------------------------------------------------------------------------------
/train_code/models/__pycache__/tp.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/models/__pycache__/tp.cpython-311.pyc
--------------------------------------------------------------------------------
/train_code/models/frozen_layers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The GPT-Accelera Team
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Optional
8 |
9 | import torch
10 | import torch.nn as nn
11 | from torch.nn import functional as F
12 | from torch import Tensor
13 |
14 |
15 | try:
16 | from apex.normalization.fused_layer_norm import FusedRMSNormFunction
17 |
18 | # print(
19 | # "`apex` is installed. You can use fused RMSNorm by set_global_compile_mode(False)."
20 | # )
21 | except ImportError as e:
22 | FusedRMSNormFunction = None
23 | # print("`apex` is not installed. Reverting to non-fused RMSNorm.")
24 |
25 | # whether to use fused RMSNorm or not (default: no)
26 | _GLOBAL_IN_COMPILE_MODE = True
27 |
28 |
29 | def find_multiple(n: int, k: int) -> int:
30 | if n % k == 0:
31 | return n
32 | return n + k - (n % k)
33 |
34 |
35 | class FrozenEmbedding(nn.Module):
36 | __constants__ = [
37 | "num_embeddings",
38 | "embedding_dim",
39 | "padding_idx",
40 | "max_norm",
41 | "norm_type",
42 | "scale_grad_by_freq",
43 | "sparse",
44 | ]
45 |
46 | num_embeddings: int
47 | embedding_dim: int
48 | padding_idx: Optional[int]
49 | max_norm: Optional[float]
50 | norm_type: float
51 | scale_grad_by_freq: bool
52 | weight: Tensor
53 | freeze: bool
54 | sparse: bool
55 |
56 | def __init__(
57 | self,
58 | num_embeddings: int,
59 | embedding_dim: int,
60 | device=None,
61 | dtype=None,
62 | ) -> None:
63 | factory_kwargs = {"device": device, "dtype": dtype}
64 | super().__init__()
65 | self.num_embeddings = num_embeddings
66 | self.embedding_dim = embedding_dim
67 | self.padding_idx = None
68 | self.max_norm = None
69 | self.norm_type = 2.0
70 | self.scale_grad_by_freq = False
71 | self.sparse = False
72 | self.vocab_start_index = None
73 | self.vocab_end_index = None
74 | self.num_embeddings_per_partition = None
75 | self.register_buffer(
76 | "weight", torch.empty((num_embeddings, embedding_dim), **factory_kwargs)
77 | )
78 |
79 | def forward(self, input: Tensor) -> Tensor:
80 | if self.num_embeddings_per_partition is None:
81 | return F.embedding(
82 | input,
83 | self.weight,
84 | self.padding_idx,
85 | self.max_norm,
86 | self.norm_type,
87 | self.scale_grad_by_freq,
88 | self.sparse,
89 | )
90 | else:
91 | # Build the mask.
92 | print("vocab_start_index", self.vocab_start_index)
93 | print("vocab_end_index", self.vocab_end_index)
94 | input_mask = (input < self.vocab_start_index) | (
95 | input >= self.vocab_end_index
96 | )
97 | # Mask the input.
98 | masked_input = input.clone() - self.vocab_start_index
99 | masked_input[input_mask] = 0
100 | # Get the embeddings.
101 | output_parallel = F.embedding(
102 | masked_input,
103 | self.weight,
104 | self.padding_idx,
105 | self.max_norm,
106 | self.norm_type,
107 | self.scale_grad_by_freq,
108 | self.sparse,
109 | )
110 | # Mask the output embedding.
111 | output_parallel[input_mask, :] = 0.0
112 | return output_parallel
113 |
114 | def extra_repr(self) -> str:
115 | s = "{num_embeddings}, {embedding_dim}"
116 | if self.padding_idx is not None:
117 | s += ", padding_idx={padding_idx}"
118 | if self.max_norm is not None:
119 | s += ", max_norm={max_norm}"
120 | if self.norm_type != 2.0:
121 | s += ", norm_type={norm_type}"
122 | if self.scale_grad_by_freq is not False:
123 | s += ", scale_grad_by_freq={scale_grad_by_freq}"
124 | if self.sparse is not False:
125 | s += ", sparse=True"
126 | return s.format(**self.__dict__)
127 |
128 |
129 | class FrozenRMSNorm(nn.Module):
130 | def __init__(self, dim: int, eps: float = 1e-5):
131 | super().__init__()
132 | self.eps = eps
133 | self.register_buffer("weight", torch.ones(dim))
134 |
135 | global _GLOBAL_IN_COMPILE_MODE
136 | self.in_compile_mode = _GLOBAL_IN_COMPILE_MODE
137 |
138 | def _norm(self, x):
139 | return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
140 |
141 | def forward(self, x: Tensor) -> Tensor:
142 | if self.in_compile_mode or FusedRMSNormFunction is None:
143 | with torch.autocast(device_type="cuda", enabled=False):
144 | output = self._norm(x.float()).to(dtype=x.dtype)
145 | return output * self.weight
146 | else:
147 | with torch.autocast(device_type="cuda", enabled=False):
148 | output = FusedRMSNormFunction.apply(
149 | x,
150 | self.weight.size(),
151 | self.eps,
152 | False,
153 | )
154 | return output * self.weight
155 |
156 |
157 | class FrozenLinear(nn.Module):
158 | __constants__ = ["in_features", "out_features"]
159 | in_features: int
160 | out_features: int
161 | weight: Tensor
162 |
163 | def __init__(
164 | self,
165 | in_features: int,
166 | out_features: int,
167 | bias: bool = True,
168 | device=None,
169 | dtype=None,
170 | ) -> None:
171 | factory_kwargs = {"device": device, "dtype": dtype}
172 | super().__init__()
173 | self.in_features = in_features
174 | self.out_features = out_features
175 | self.register_buffer(
176 | "weight", torch.empty((out_features, in_features), **factory_kwargs)
177 | )
178 | if bias:
179 | self.register_buffer("bias", torch.empty((out_features,), **factory_kwargs))
180 | else:
181 | self.register_buffer("bias", None)
182 |
183 | def forward(self, input: Tensor) -> Tensor:
184 | return F.linear(input, self.weight, self.bias)
185 |
186 | def extra_repr(self) -> str:
187 | return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}"
188 |
--------------------------------------------------------------------------------
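A small sketch (assumption, not repo code) of what the Frozen* modules above buy you: weights are registered as buffers rather than nn.Parameter, so they carry no gradients and are skipped by any optimizer built from model.parameters(). The import path assumes running from train_code/.

import torch
from models.frozen_layers import FrozenLinear

layer = FrozenLinear(4, 3, bias=False, dtype=torch.float32)
print(list(layer.parameters()))                      # [] -- the weight is a buffer, not a Parameter
print(layer.weight.requires_grad)                    # False
print([name for name, _ in layer.named_buffers()])   # ['weight']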
/train_code/models/reward_model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The GPT-Accelera Team
2 | # Copyright 2023 The Self-Align Team
3 | # Copyright 2023 The Alpaca Team
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | from dataclasses import dataclass
18 | import math
19 | from typing import Optional, Dict, Sequence, Union
20 |
21 | import einops
22 | import torch
23 | from torch import Tensor, nn
24 | import torch.nn.functional as F
25 |
26 | from models.model import ModelArgs, Transformer
27 |
28 |
29 | def unpack_dict(
30 | d: Dict, keys: Sequence[str], return_type: type = tuple
31 | ) -> Union[Sequence, Dict]:
32 | if return_type in (tuple, list):
33 | return return_type(d[key] for key in keys)
34 | elif return_type == dict:
35 | return {key: d[key] for key in keys}
36 | else:
37 | raise ValueError(f"Unknown return_type: {return_type}")
38 |
39 |
40 | def batch_select(input: Tensor, index: Tensor):
41 | """Select elements from a batched tensor with a batched index tensor.
42 |
43 | Example:
44 | input = torch.tensor([
45 | [0, 1, 2],
46 | [3, 0, 9],
47 | [6, 7, 8],
48 | ])
49 | index = torch.tensor([[0, 1], [1, 0], [0, 0]])
50 | batch_select(input, index) = tensor([
51 | [0, 1],
52 | [0, 3],
53 | [6, 6]
54 | ])
55 | """
56 | dummy_index = torch.arange(input.size(0), device=input.device).unsqueeze(-1)
57 | return input[dummy_index, index]
58 |
59 |
60 | @dataclass
61 | class RewardArgs:
62 | backbone_args: ModelArgs
63 |
64 | @classmethod
65 | def from_name(cls, name: str):
66 | return cls(backbone_args=ModelArgs.from_name(name))
67 |
68 |
69 | class RewardModel(nn.Module):
70 | def __init__(self, config: RewardArgs, **kwargs) -> None:
71 | super().__init__()
72 | self.config = config
73 | self.backbone_model = Transformer(config.backbone_args, **kwargs)
74 |
75 | def forward(
76 | self,
77 | idx: Tensor,
78 | eos_pos: Optional[Tensor] = None,
79 | ) -> Tensor:
80 | input_pos = torch.arange(0, idx.size(-1), device=idx.device)
81 | rewards = self.backbone_model(idx, input_pos=input_pos, fully_causal=True)
82 | rewards = rewards.mean(dim=-1)
83 |
84 | if eos_pos is not None:
85 | eos_pos = eos_pos.unsqueeze(-1)
86 | rewards = batch_select(rewards, eos_pos).squeeze(-1)
87 |
88 | return rewards
89 |
90 | @classmethod
91 | def from_name(cls, name: str, **kwargs):
92 | return cls(RewardArgs.from_name(name), **kwargs)
93 |
94 |
95 | def apply_reward_modeling_head(
96 | transformer: Transformer, requires_grad=False, init_sceheme="zeros"
97 | ):
98 | output_module = transformer.output
99 | # Linear's weight matrix is transposed, and is of shape
100 | # (linear.out_features, linear.in_features)
101 |
102 | # Temp fix due to https://github.com/pytorch/pytorch/issues/106951
103 | reward_head_weight = torch.zeros_like(output_module.weight)[:2, :]
104 | if init_sceheme == "zeros":
105 | output_module.weight = nn.Parameter(
106 | reward_head_weight,
107 | requires_grad=requires_grad,
108 | )
109 | elif init_sceheme == "semantic":
110 | # ['### Preferred Output is '] [835, 4721, 14373, 10604, 338, 29871]
111 | # ['### Preferred Output is 1.'] [835, 4721, 14373, 10604, 338, 29871, 29896, 29889]
112 | # ['### Preferred Output is 2.'] [835, 4721, 14373, 10604, 338, 29871, 29906, 29889]
113 | token_1_id = 29896
114 | token_2_id = 29906
115 | reward_head_weight[0, :] = output_module.weight[token_2_id, :]
116 | reward_head_weight[1, :] = -output_module.weight[token_1_id, :]
117 | output_module.weight = nn.Parameter(
118 | reward_head_weight,
119 | requires_grad=requires_grad,
120 | )
121 | elif init_sceheme == "random":
122 | generator = torch.Generator(device=reward_head_weight.device)
123 | generator.manual_seed(42)
124 | nn.init.kaiming_uniform_(
125 | reward_head_weight, a=math.sqrt(5), generator=generator
126 | )
127 | output_module.weight = nn.Parameter(
128 | reward_head_weight * math.sqrt(2.0),
129 | requires_grad=requires_grad,
130 | )
131 | else:
132 | raise ValueError(f"Unknown init_scheme: {init_sceheme}")
133 | setattr(output_module, "out_features", 2)
134 |
135 |
136 | def compute_pairwise_reward_modeling_loss(model, inputs, return_outputs=False):
137 | # input_ids, attention_mask each of size (bsz, num_candidates, seq_len).
138 | # index_0, index_1 each of size (bsz, num_pairs); indexes into input_ids.
139 | # choice of size (bsz, num_pairs); 1 if index_1's seq is chosen, 0 otherwise.
140 | input_ids, eos_pos, index_0, index_1, choice = unpack_dict(
141 | inputs, keys=("input_ids", "eos_pos", "index_0", "index_1", "choice")
142 | )
143 | num_candidates, num_pairs = input_ids.size(1), choice.size(1)
144 | input_ids_flat = einops.rearrange(input_ids, "b c l -> (b c) l")
145 | eos_pos_flat = einops.rearrange(eos_pos, "b c -> (b c)")
146 | input_pos_flat = torch.arange(
147 | 0, input_ids_flat.size(-1), device=input_ids_flat.device
148 | )
149 | outputs = model(
150 | input_ids=input_ids_flat,
151 | input_pos=input_pos_flat,
152 | eos_pos=eos_pos_flat,
153 | )
154 | rewards_flat = outputs.rewards
155 | rewards = einops.rearrange(
156 | rewards_flat, "(b c) -> b c", c=num_candidates
157 | ) # Size: (bsz, num_candidates).
158 |
159 | rewards_0, rewards_1 = tuple(
160 | batch_select(rewards, index) for index in (index_0, index_1)
161 | ) # Size: (bsz, num_pairs).
162 | logits = rewards_1 - rewards_0 # Size: (bsz, num_pairs).
163 | # Type casting of `choice` is due to amp.autocast context manager.
164 | loss = F.binary_cross_entropy_with_logits(
165 | logits, choice.to(logits.dtype), reduction="mean"
166 | )
167 | return (loss, dict(logits=logits)) if return_outputs else loss
168 |
169 |
170 | def compute_pairwise_reward_modeling_metrics(
171 | predictions: torch.Tensor, label_ids: torch.Tensor
172 | ) -> Dict:
173 | # eval_prediction.label_ids is a tuple that matches up with `training_args.label_names`.
174 | logits = torch.tensor(predictions).squeeze(-1)
175 | labels = torch.tensor(label_ids[-1]).squeeze(-1)
176 | predictions = (logits >= 0.0).long()
177 | accuracy = predictions.eq(labels).float().mean().item()
178 | label_positive_rate = (labels == 1).float().mean().item()
179 | return dict(
180 | accuracy=accuracy,
181 | label_positive_rate=label_positive_rate,
182 | )
183 |
--------------------------------------------------------------------------------
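A toy illustration (not repo code) of the pairwise objective computed in compute_pairwise_reward_modeling_loss above: the logit is the reward difference r(candidate_1) - r(candidate_0), and choice = 1 means candidate_1 is the preferred output.

import torch
import torch.nn.functional as F

rewards_0 = torch.tensor([[0.2], [1.5]])  # reward of the candidate picked by index_0, per pair
rewards_1 = torch.tensor([[1.0], [0.3]])  # reward of the candidate picked by index_1
choice = torch.tensor([[1.0], [0.0]])     # pair 1 prefers candidate_1, pair 2 prefers candidate_0
logits = rewards_1 - rewards_0
loss = F.binary_cross_entropy_with_logits(logits, choice, reduction="mean")
print(loss.item())  # small here, because both preferences agree with the sign of the reward gap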
/train_code/scripts/convert.sh:
--------------------------------------------------------------------------------
1 | # export DATA_DIR=/cfs/hadoop-aipnlp/zengweihao02/modelzoo/Llama-2-70b-hf/Llama-2-70b-hf
2 | # #export MODEL_REPO=ScalableMath/llemma-7b-sft-metamath-level-1to3-hf
3 |
4 | # python convert_hf_checkpoint.py \
5 | # --checkpoint_dir $DATA_DIR \
6 | # --target_precision bf16
7 |
8 |
9 | export DATA_DIR=/cfs/hadoop-aipnlp/zengweihao02/modelzoo/Llama-3.1-8B-Instruct
10 | #export MODEL_REPO=ScalableMath/llemma-7b-sft-metamath-level-1to3-hf
11 |
12 |
13 | #export DATA_DIR=/mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/zengweihao02/modelzoo/deepseek-math-7b-base/deepseek-math-7b-base
14 | # python convert_hf_checkpoint.py \
15 | # --checkpoint_dir $DATA_DIR \
16 | # --target_precision bf16
17 | #/cfs/hadoop-aipnlp/zengweihao02/b-star/easy-to-hard-main-share/scripts/convert_hf_checkpoint_llama3.py
18 |
19 | python /cfs/hadoop-aipnlp/zengweihao02/b-star/easy-to-hard-main-share/scripts/convert_hf_checkpoint_llama3.py \
20 | --checkpoint_dir $DATA_DIR \
21 | --target_precision bf16
22 |
23 | # python /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/zengweihao02/easy2hard/share/project/weihao/easy-to-hard-main-share/easy-to-hard-main-share/scripts/convert_hf_checkpoint.py \
24 | # --checkpoint_dir $DATA_DIR \
25 | # --target_precision bf16
--------------------------------------------------------------------------------
/train_code/scripts/convert_checkpoint_to_hf.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
2 | from tqdm import tqdm
3 | import torch
4 | import re
5 | import argparse
6 | import os
7 | import glob
8 |
9 | # we need to check that we have login the HF account
10 | # !huggingface-cli whoami
11 | # !huggingface-cli login
12 |
13 |
14 | def load_and_merge_models(
15 | tp_ckpt_name, pretrain_name, tokenizer_name, save_name_hf, push_to_hf_hub_name
16 | ):
17 | assert (
18 | save_name_hf or push_to_hf_hub_name
19 | ), "Please provide a save path or push to HF hub name"
20 |
21 | tp_model_list = []
22 |
23 | last_checkpoint_file = os.path.join(tp_ckpt_name, "last_checkpoint")
24 | with open(last_checkpoint_file, "r") as f:
25 | last_checkpoint_file = f.readline().strip()
26 |
27 | last_checkpoint_file = last_checkpoint_file.split("/")[-1]
28 | last_checkpoint_file = os.path.join(tp_ckpt_name, last_checkpoint_file)
29 |
30 | print("Loading checkpoint files:", last_checkpoint_file)
31 | for file in sorted(glob.glob(last_checkpoint_file)):
32 | tp_model_list.append(
33 | torch.load(
34 | file,
35 | mmap=True,
36 | )["model"]
37 | )
38 |
39 | print("Loading HF model...")
40 | tokenizer = AutoTokenizer.from_pretrained(
41 | tokenizer_name,
42 | )
43 |
44 | model = AutoModelForCausalLM.from_pretrained(
45 | pretrain_name,
46 | # device_map="cpu",
47 | load_in_8bit=False,
48 | torch_dtype=torch.bfloat16,
49 | )
50 | cpu_state_dict = model.cpu().state_dict()
51 |
52 | replaced_keys = set()
53 |
54 | print("Convert to HF model...")
55 | num_tp = len(tp_model_list)
56 |
57 | state_dict = {}
58 |
59 | for key in tp_model_list[0].keys():
60 | if "wo" in key or "w2" in key:
61 | state_dict[key] = torch.cat(
62 | [tp_model_list[i][key].cpu() for i in range(num_tp)], dim=1
63 | )
64 | elif "wqkv" in key:
65 | state_dict[key] = torch.stack(
66 | [tp_model_list[i][key].cpu() for i in range(num_tp)], dim=0
67 | )
68 | elif "output" in key:
69 | state_dict[key] = torch.cat(
70 | [tp_model_list[i][key].cpu() for i in range(num_tp)], dim=1
71 | )
72 | else:
73 | state_dict[key] = torch.cat(
74 | [tp_model_list[i][key].cpu() for i in range(num_tp)], dim=0
75 | )
76 |
77 | pattern = r"layers\.(\d+)\."
78 |
79 | for key in state_dict.keys():
80 | layer = None
81 | match = re.search(pattern, key)
82 | # layer number except for:
83 | # lm_head.weight
84 | if match:
85 | layer = match.group(1)
86 | elif "output.weight" in key:
87 | name = f"lm_head.weight"
88 | print(cpu_state_dict[name].size(), state_dict[key].size())
89 | # repeat on dim 0 to match the size
90 | repeat_size = cpu_state_dict[name].size(0) // state_dict[key].size(0)
91 | new_state_dict = state_dict[key].repeat(repeat_size, 1)
92 | cpu_state_dict[name] = 0.0 * cpu_state_dict[name] + new_state_dict
93 | replaced_keys.add(name)
94 | else:
95 | raise ValueError(f"Invalid key: {key}")
96 |
97 | print("Converting layer", key)
98 | if "wqkv" in key:
99 | merged_q, merged_k, merged_v = [], [], []
100 | reconstruct_q, reconstruct_k = [], []
101 |
102 | if state_dict[key].size(2) == 4096:
103 | n_heads, n_local_heads = 32, 32
104 | elif state_dict[key].size(2) == 5120:
105 | n_heads, n_local_heads = 40, 40
106 | elif state_dict[key].size(2) == 6656:
107 | n_heads, n_local_heads = 52, 52
108 | elif state_dict[key].size(2) == 8192:
109 | n_heads, n_local_heads = 64, 8
110 | else:
111 | raise ValueError(f"Invalid size for {key}: {state_dict[key].size()}")
112 |
113 | head_dim = state_dict[key].size(1) // (n_heads + n_local_heads * 2)
114 |
115 | weight_splits = [
116 | head_dim * n_heads,
117 | head_dim * n_local_heads,
118 | head_dim * n_local_heads,
119 | ]
120 |
121 | for split_idx in range(state_dict[key].size(0)):
122 | chunk = state_dict[key][split_idx]
123 | q, k, v = chunk.split(weight_splits, dim=0)
124 | merged_q.append(q)
125 | merged_k.append(k)
126 | merged_v.append(v)
127 | merged_q = torch.cat(merged_q, dim=0)
128 | merged_k = torch.cat(merged_k, dim=0)
129 | merged_v = torch.cat(merged_v, dim=0)
130 |
131 | #### qk need reconstruction ####
132 | split_qs = torch.split(merged_q, split_size_or_sections=128, dim=0)
133 | split_ks = torch.split(merged_k, split_size_or_sections=128, dim=0)
134 | for split in split_qs:
135 | matrix0 = split[::2, :]
136 | matrix1 = split[1::2, :]
137 | reconstruct_q.append(matrix0)
138 | reconstruct_q.append(matrix1)
139 | reconstruct_q = torch.cat(reconstruct_q, dim=0)
140 | for split in split_ks:
141 | matrix0 = split[::2, :]
142 | matrix1 = split[1::2, :]
143 | reconstruct_k.append(matrix0)
144 | reconstruct_k.append(matrix1)
145 | reconstruct_k = torch.cat(reconstruct_k, dim=0)
146 | #### qk need reconstruction ####
147 |
148 | name = f"model.layers.{layer}.self_attn.q_proj.weight"
149 | cpu_state_dict[name] = reconstruct_q
150 | replaced_keys.add(name)
151 |
152 | name = f"model.layers.{layer}.self_attn.k_proj.weight"
153 | cpu_state_dict[name] = reconstruct_k
154 | replaced_keys.add(name)
155 |
156 | name = f"model.layers.{layer}.self_attn.v_proj.weight"
157 | cpu_state_dict[name] = merged_v
158 | replaced_keys.add(name)
159 |
160 | if "wo" in key:
161 | name = f"model.layers.{layer}.self_attn.o_proj.weight"
162 | cpu_state_dict[name] = state_dict[key]
163 | replaced_keys.add(name)
164 | if "w1" in key:
165 | name = f"model.layers.{layer}.mlp.gate_proj.weight"
166 | cpu_state_dict[name] = state_dict[key]
167 | replaced_keys.add(name)
168 | if "w3" in key:
169 | name = f"model.layers.{layer}.mlp.up_proj.weight"
170 | cpu_state_dict[name] = state_dict[key]
171 | replaced_keys.add(name)
172 | if "w2" in key:
173 | name = f"model.layers.{layer}.mlp.down_proj.weight"
174 | cpu_state_dict[name] = state_dict[key]
175 | replaced_keys.add(name)
176 |
177 | unreplaced_keys = set(cpu_state_dict.keys()) - replaced_keys
178 | print("Unreplaced keys:", unreplaced_keys)
179 |
180 | print("Loading state dict...")
181 |
182 | model.load_state_dict(cpu_state_dict, strict=False)
183 |
184 | print("Saving HF model...")
185 |
186 | if save_name_hf is not None:
187 | model.save_pretrained(save_name_hf)
188 | config = AutoConfig.from_pretrained(pretrain_name)
189 | tokenizer.save_pretrained(save_name_hf)
190 | config.save_pretrained(save_name_hf)
191 | else:
192 | model.push_to_hub(push_to_hf_hub_name, private=True, safe_serialization=False)
193 |
194 |
195 | if __name__ == "__main__":
196 | parser = argparse.ArgumentParser(description="Process some integers.")
197 | parser.add_argument(
198 | "--tp_ckpt_name", type=str, help="Path to the TP checkpoint name", required=True
199 | )
200 | parser.add_argument(
201 | "--tokenizer_name", type=str, help="Path to the tokenizer name", required=True
202 | )
203 | parser.add_argument(
204 | "--pretrain_name", type=str, help="Path to the pretrain name", required=True
205 | )
206 | parser.add_argument(
207 | "--save_name_hf", type=str, default=None, help="Path to save the HF model"
208 | )
209 | parser.add_argument(
210 | "--push_to_hf_hub_name", type=str, default=None, help="Push to HF hub"
211 | )
212 |
213 | args = parser.parse_args()
214 | load_and_merge_models(
215 | args.tp_ckpt_name,
216 | args.pretrain_name,
217 | args.tokenizer_name,
218 | args.save_name_hf,
219 | args.push_to_hf_hub_name,
220 | )
221 |
--------------------------------------------------------------------------------
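A hedged sketch (not repo code) of driving the converter above programmatically instead of via argparse; the keyword names match load_and_merge_models, the paths are placeholders, and the import assumes running from train_code/ with scripts/ importable.

from scripts.convert_checkpoint_to_hf import load_and_merge_models

load_and_merge_models(
    tp_ckpt_name="/path/to/tp_checkpoints",   # directory containing a "last_checkpoint" file
    pretrain_name="/path/to/hf_base_model",   # HF model providing the target state_dict layout
    tokenizer_name="/path/to/hf_base_model",
    save_name_hf="/path/to/output_hf_model",  # at least one of save_name_hf / push_to_hf_hub_name must be set
    push_to_hf_hub_name=None,
)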
/train_code/scripts/convert_hf_checkpoint.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | import json
7 | import sys
8 | from pathlib import Path
9 | from typing import Optional
10 |
11 | import torch
12 | import re
13 |
14 | # support running without installing as a package
15 | wd = Path(__file__).parent.parent.resolve()
16 | sys.path.append(str(wd))
17 |
18 | from models.model import ModelArgs
19 |
20 |
21 | @torch.inference_mode()
22 | def convert_hf_checkpoint(
23 | *,
24 | checkpoint_dir: Path = Path(
25 | "checkpoints/meta-Transformer/Transformer-2-7b-chat-hf"
26 | ),
27 | model_name: Optional[str] = None,
28 | target_precision: str = "fp32",
29 | ) -> None:
30 | if model_name is None:
31 | model_name = checkpoint_dir.name
32 |
33 | config = ModelArgs.from_name(model_name)
34 | print(f"Model config {config.__dict__}")
35 |
36 | # Load the json file containing weight mapping
37 | model_map_json = checkpoint_dir / "pytorch_model.bin.index.json"
38 |
39 | assert model_map_json.is_file()
40 |
41 | with open(model_map_json) as json_map:
42 | bin_index = json.load(json_map)
43 |
44 | weight_map = {
45 | "model.embed_tokens.weight": "tok_embeddings.weight",
46 | "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
47 | "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
48 | "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
49 | "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
50 | "model.layers.{}.self_attn.rotary_emb.inv_freq": None,
51 | "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
52 | "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
53 | "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
54 | "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
55 | "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
56 | "model.norm.weight": "norm.weight",
57 | "lm_head.weight": "output.weight",
58 | }
59 | bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
60 |
61 | def permute(w, n_head):
62 | dim = config.dim
63 | return (
64 | w.view(n_head, 2, config.head_dim // 2, dim)
65 | .transpose(1, 2)
66 | .reshape(config.head_dim * n_head, dim)
67 | )
68 |
69 | merged_result = {}
70 | for file in sorted(bin_files):
71 | state_dict = torch.load(
72 | str(file), map_location="cpu", mmap=True, weights_only=True
73 | )
74 |
75 | if target_precision == "fp16":
76 | for key in tuple(state_dict.keys()):
77 | state_dict[key] = state_dict[key].half()
78 | elif target_precision == "bf16":
79 | for key in tuple(state_dict.keys()):
80 | state_dict[key] = state_dict[key].bfloat16()
81 | elif target_precision == "fp32":
82 | pass
83 | else:
84 | raise ValueError(f"Unsupported target_precision {target_precision}")
85 | merged_result.update(state_dict)
86 | final_result = {}
87 | for key, value in merged_result.items():
88 | if "layers" in key:
89 | abstract_key = re.sub(r"(\d+)", "{}", key)
90 | layer_num = re.search(r"\d+", key).group(0)
91 | new_key = weight_map[abstract_key]
92 | if new_key is None:
93 | continue
94 | new_key = new_key.format(layer_num)
95 | else:
96 | new_key = weight_map[key]
97 |
98 | if len(value.shape) == 2 and value.size(1) == 32016:
99 | value = value[:, :32000]
100 | if value.size(0) == 32016:
101 | value = value[:32000, :]
102 |
103 | final_result[new_key] = value
104 |
105 | for key in tuple(final_result.keys()):
106 | if "wq" in key:
107 | q = final_result[key]
108 | k = final_result[key.replace("wq", "wk")]
109 | v = final_result[key.replace("wq", "wv")]
110 | q = permute(q, config.n_head)
111 | k = permute(k, config.n_local_heads)
112 | final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
113 | del final_result[key]
114 | del final_result[key.replace("wq", "wk")]
115 | del final_result[key.replace("wq", "wv")]
116 | print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
117 | torch.save(final_result, checkpoint_dir / "model.pth")
118 |
119 |
120 | if __name__ == "__main__":
121 | import argparse
122 |
123 | parser = argparse.ArgumentParser(description="Convert HuggingFace checkpoint.")
124 | parser.add_argument(
125 | "--checkpoint_dir",
126 | type=Path,
127 | default=Path("checkpoints/meta-llama/llama-2-7b-chat-hf"),
128 | )
129 | parser.add_argument("--model_name", type=str, default=None)
130 | parser.add_argument("--target_precision", type=str, default="fp32")
131 |
132 | args = parser.parse_args()
133 | convert_hf_checkpoint(
134 | checkpoint_dir=args.checkpoint_dir,
135 | model_name=args.model_name,
136 | target_precision=args.target_precision,
137 | )
138 |
--------------------------------------------------------------------------------
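A tiny numeric check (assumption, not repo code) of the permute() helper defined above: it regroups an HF q/k projection so that the first and second halves of each head's rows become interleaved, which is the rotary layout this repo's attention expects.

import torch

n_head, head_dim, dim = 1, 4, 4
w = torch.arange(head_dim * dim, dtype=torch.float32).reshape(head_dim, dim)  # rows 0..3

permuted = (
    w.view(n_head, 2, head_dim // 2, dim)
    .transpose(1, 2)
    .reshape(head_dim * n_head, dim)
)
print(permuted[:, 0])  # tensor([ 0.,  8.,  4., 12.]) -> rows reordered as [0, 2, 1, 3]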
/train_code/scripts/convert_hf_checkpoint_llama3.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | import json
7 | import sys
8 | from pathlib import Path
9 | from typing import Optional
10 | # from safetensors import load_file
11 | from safetensors.torch import load_file
12 | import torch
13 | import re
14 |
15 | # support running without installing as a package
16 | wd = Path(__file__).parent.parent.resolve()
17 | sys.path.append(str(wd))
18 |
19 | from models.model import ModelArgs
20 |
21 |
22 | @torch.inference_mode()
23 | def convert_hf_checkpoint(
24 | *,
25 | checkpoint_dir: Path = Path(
26 | "checkpoints/meta-Transformer/Transformer-2-7b-chat-hf"
27 | ),
28 | model_name: Optional[str] = None,
29 | target_precision: str = "fp32",
30 | ) -> None:
31 | if model_name is None:
32 | model_name = checkpoint_dir.name
33 | print(f"Model Name {model_name}")
34 | config = ModelArgs.from_name(model_name)
35 | print(f"Model config {config.__dict__}")
36 |
37 | # Load the json file containing weight mapping
38 | #model_map_json = checkpoint_dir / "pytorch_model.bin.index.json"
39 |
40 | model_map_json = checkpoint_dir / "model.safetensors.index.json"
41 |
42 | assert model_map_json.is_file()
43 |
44 | with open(model_map_json) as json_map:
45 | bin_index = json.load(json_map)
46 |
47 | weight_map = {
48 | "model.embed_tokens.weight": "tok_embeddings.weight",
49 | "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
50 | "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
51 | "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
52 | "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
53 | "model.layers.{}.self_attn.rotary_emb.inv_freq": None,
54 | "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
55 | "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
56 | "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
57 | "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
58 | "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
59 | "model.norm.weight": "norm.weight",
60 | "lm_head.weight": "output.weight",
61 | }
62 |
63 |
64 | # if model_name == "Meta-Llama-3-8B":
65 | # weight_map = {
66 | # "embed_tokens.weight": "tok_embeddings.weight",
67 | # "layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
68 | # "layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
69 | # "layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
70 | # "layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
71 | # "layers.{}.self_attn.rotary_emb.inv_freq": None,
72 | # "layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
73 | # "layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
74 | # "layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
75 | # "layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
76 | # "layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
77 | # "norm.weight": "norm.weight",
78 | # "lm_head.weight": "output.weight",
79 | # }
80 | bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
81 |
82 | def permute(w, n_head):
83 | dim = config.dim
84 | return (
85 | w.view(n_head, 2, config.head_dim // 2, dim)
86 | .transpose(1, 2)
87 | .reshape(config.head_dim * n_head, dim)
88 | )
89 |
90 | merged_result = {}
91 |
92 | print(bin_files)
93 | # for file in sorted(bin_files):
94 | # state_dict = torch.load(
95 | # str(file), map_location="cpu", mmap=True, weights_only=True
96 | # )
97 |
98 | # if target_precision == "fp16":
99 | # for key in tuple(state_dict.keys()):
100 | # state_dict[key] = state_dict[key].half()
101 | # elif target_precision == "bf16":
102 | # for key in tuple(state_dict.keys()):
103 | # state_dict[key] = state_dict[key].bfloat16()
104 | # elif target_precision == "fp32":
105 | # pass
106 | # else:
107 | # raise ValueError(f"Unsupported target_precision {target_precision}")
108 | # merged_result.update(state_dict)
109 |
110 | # bin_files is assumed to be a list of paths to .safetensors files
111 | for file in sorted(bin_files):
112 | # Load the weights with safetensors' load_file
113 | state_dict = load_file(str(file))
114 |
115 | # Cast to the target precision
116 | if target_precision == "fp16":
117 | for key in tuple(state_dict.keys()):
118 | state_dict[key] = state_dict[key].half()
119 | elif target_precision == "bf16":
120 | for key in tuple(state_dict.keys()):
121 | state_dict[key] = state_dict[key].bfloat16()
122 | elif target_precision == "fp32":
123 | pass
124 | else:
125 | raise ValueError(f"Unsupported target_precision {target_precision}")
126 |
127 | # Merge the loaded weights into the final merged_result dict
128 | merged_result.update(state_dict)
129 | final_result = {}
130 | for key, value in merged_result.items():
131 | print("key", key)
132 | if "layers" in key:
133 | abstract_key = re.sub(r"(\d+)", "{}", key)
134 | layer_num = re.search(r"\d+", key).group(0)
135 | new_key = weight_map[abstract_key]
136 | if new_key is None:
137 | continue
138 | new_key = new_key.format(layer_num)
139 | else:
140 | new_key = weight_map[key]
141 |
142 | if len(value.shape) == 2 and value.size(1) == 32016:
143 | value = value[:, :32000]
144 | if value.size(0) == 32016:
145 | value = value[:32000, :]
146 |
147 | final_result[new_key] = value
148 |
149 | for key in tuple(final_result.keys()):
150 | if "wq" in key:
151 | q = final_result[key]
152 | k = final_result[key.replace("wq", "wk")]
153 | v = final_result[key.replace("wq", "wv")]
154 | q = permute(q, config.n_head)
155 | k = permute(k, config.n_local_heads)
156 | final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
157 | del final_result[key]
158 | del final_result[key.replace("wq", "wk")]
159 | del final_result[key.replace("wq", "wv")]
160 | print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
161 | torch.save(final_result, checkpoint_dir / "model.pth")
162 |
163 |
164 | if __name__ == "__main__":
165 | import argparse
166 |
167 | parser = argparse.ArgumentParser(description="Convert HuggingFace checkpoint.")
168 | parser.add_argument(
169 | "--checkpoint_dir",
170 | type=Path,
171 | default=Path("checkpoints/meta-llama/llama-2-7b-chat-hf"),
172 | )
173 | parser.add_argument("--model_name", type=str, default=None)
174 | parser.add_argument("--target_precision", type=str, default="fp32")
175 |
176 | args = parser.parse_args()
177 | convert_hf_checkpoint(
178 | checkpoint_dir=args.checkpoint_dir,
179 | model_name=args.model_name,
180 | target_precision=args.target_precision,
181 | )
182 |
--------------------------------------------------------------------------------
/train_code/scripts/download.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | import os
7 | from requests.exceptions import HTTPError
8 | import sys
9 | from pathlib import Path
10 | from typing import Optional
11 |
12 |
13 | def hf_download(
14 | repo_id: Optional[str] = None,
15 | hf_token: Optional[str] = None,
16 | local_dir: Optional[str] = None,
17 | ) -> None:
18 | from huggingface_hub import snapshot_download
19 |
20 | local_dir = local_dir or "checkpoints"
21 |
22 | os.makedirs(f"{local_dir}/{repo_id}", exist_ok=True)
23 | try:
24 | snapshot_download(
25 | repo_id,
26 | local_dir=f"{local_dir}/{repo_id}",
27 | local_dir_use_symlinks=False,
28 | token=hf_token,
29 | )
30 | except HTTPError as e:
31 | if e.response.status_code == 401:
32 | print(
33 | "You need to pass a valid `--hf_token=...` to download private checkpoints."
34 | )
35 | else:
36 | raise e
37 |
38 |
39 | if __name__ == "__main__":
40 | import argparse
41 |
42 | parser = argparse.ArgumentParser(description="Download data from HuggingFace Hub.")
43 | parser.add_argument(
44 | "--repo_id",
45 | type=str,
46 | default="checkpoints/meta-llama/llama-2-7b-chat-hf",
47 | help="Repository ID to download from.",
48 | )
49 | parser.add_argument(
50 | "--local_dir", type=str, default=None, help="Local directory to download to."
51 | )
52 | parser.add_argument(
53 | "--hf_token", type=str, default=None, help="HuggingFace API token."
54 | )
55 |
56 | args = parser.parse_args()
57 | hf_download(args.repo_id, args.hf_token, args.local_dir)
58 |
--------------------------------------------------------------------------------
/train_code/scripts/prepare_ds_math_7b.sh:
--------------------------------------------------------------------------------
1 | set -e
2 | set -x
3 |
4 | export DATA_DIR=/path/to/your/data/directory
5 | export MODEL_REPO=deepseek-ai/deepseek-math-7b-base
6 |
7 | python scripts/download.py \
8 | --repo_id $MODEL_REPO \
9 | --local_dir $DATA_DIR/checkpoints
10 |
11 | python scripts/convert_hf_checkpoint.py \
12 | --checkpoint_dir $DATA_DIR/checkpoints/$MODEL_REPO \
13 | --target_precision bf16
14 |
--------------------------------------------------------------------------------
/train_code/scripts/prepare_llemma_34b.sh:
--------------------------------------------------------------------------------
1 | set -e
2 | set -x
3 |
4 | export DATA_DIR=/path/to/your/data/directory
5 | export MODEL_REPO=EleutherAI/llemma_34b
6 |
7 | python scripts/download.py \
8 | --repo_id $MODEL_REPO \
9 | --local_dir $DATA_DIR/checkpoints
10 |
11 | python scripts/convert_hf_checkpoint.py \
12 | --checkpoint_dir $DATA_DIR/checkpoints/$MODEL_REPO \
13 | --target_precision bf16
14 |
--------------------------------------------------------------------------------
/train_code/scripts/prepare_llemma_7b.sh:
--------------------------------------------------------------------------------
1 | set -e
2 | set -x
3 |
4 | export DATA_DIR=/path/to/your/data/directory
5 | export MODEL_REPO=EleutherAI/llemma_7b
6 |
7 | python scripts/download.py \
8 | --repo_id $MODEL_REPO \
9 | --local_dir $DATA_DIR/checkpoints
10 |
11 | python scripts/convert_hf_checkpoint.py \
12 | --checkpoint_dir $DATA_DIR/checkpoints/$MODEL_REPO \
13 | --target_precision bf16
14 |
--------------------------------------------------------------------------------
/train_code/train_bstar.sh:
--------------------------------------------------------------------------------
1 | export MODEL_REPO=$DATA_DIR/checkpoints/Mistral-7B-v0.1
2 | export OMP_NUM_THREADS=8
3 |
4 | SFT_Midlle_Dir=/yourfolder/bstar_math
5 | SFT_MODEL_SAVE_NAME=math_format_b_star_mistral
6 | NUM_EPOCHS=9
7 | NUM_ITERATIONS=9
8 | NUM_GPUS=8
9 |
10 | BASE_STEP=500
11 | BASE_TARGET_SIZE=135000
12 | GSM8K_TARGET_SIZE=81000
13 |
14 | MATH_TARGET_SIZE=67500
15 |
16 |
17 |
18 |
19 | # Check if the directory exists, if not, create it
20 | if [ ! -d "$SFT_Midlle_Dir" ]; then
21 | mkdir -p "$SFT_Midlle_Dir"
22 | echo "Directory $SFT_Midlle_Dir created."
23 | else
24 | echo "Directory $SFT_Midlle_Dir already exists."
25 | fi
26 |
27 | LOG_DIR=/yourfolder/logs
28 | mkdir -p $LOG_DIR
29 | LOG_FILE=$LOG_DIR/train_bstar_log.txt
30 |
31 |
32 | for ((iter=1; iter<=NUM_ITERATIONS; iter++))
33 | do
34 |
35 | GEN_STEP=$(((iter-1) * BASE_STEP))
36 |
37 | sample_num=64
38 |
39 | # You should download following data
40 |
41 | input_data="https://huggingface.co/datasets/AndrewZeng/bstar-math-dev/blob/main/dynamic_ana_1k_withans_math.json"
42 | model_dir="$DATA_DIR/checkpoints/$SFT_MODEL_SAVE_NAME/converted/ckpt$GEN_STEP"
43 | output_dir="$SFT_Midlle_Dir"
44 | tensor_parallel_size=1
45 | top_k=-1
46 | max_tokens=768
47 | wandb_project="vllm_gen"
48 |
49 | # select temperature
50 | start_temp=0.50
51 | end_temp=0.85
52 | temp_step=0.05
53 |
54 | gpu=0
55 |
56 | for temp in $(seq $start_temp $temp_step $end_temp); do
57 | output_file="$output_dir/dynamic_ana_1k_withans_iter${iter}_temp${temp}_sample${sample_num}_math.json"
58 |
59 | # For the first step, the model should be the first epoch checkpoint of SFT model
60 |
61 | # We will determine some configurations on small dev set
62 | CUDA_VISIBLE_DEVICES=$gpu python vllm_infer_auto.py \
63 | --input_data $input_data \
64 | --model_dir $model_dir \
65 | --sample_num $sample_num \
66 | --output_file $output_file \
67 | --tensor_parallel_size $tensor_parallel_size \
68 | --temperature $temp \
69 | --top_k $top_k \
70 | --max_tokens $max_tokens
71 |
72 | gpu=$(( (gpu + 1) % NUM_GPUS ))  # cycle over the available GPUs
73 | done
74 |
75 | wait
76 |
77 |
78 | start_temp=0.9
79 | end_temp=1.2
80 | temp_step=0.05
81 |
82 |
83 | gpu=0
84 |
85 |
86 | for temp in $(seq $start_temp $temp_step $end_temp); do
87 | output_file="$output_dir/dynamic_ana_1k_withans_iter${iter}_temp${temp}_sample${sample_num}_math.json"
88 |
89 | CUDA_VISIBLE_DEVICES=$gpu python vllm_infer_auto.py \
90 | --input_data $input_data \
91 | --model_dir $model_dir \
92 | --sample_num $sample_num \
93 | --output_file $output_file \
94 | --tensor_parallel_size $tensor_parallel_size \
95 | --temperature $temp \
96 | --top_k $top_k \
97 | --max_tokens $max_tokens
98 |
99 |
100 | gpu=$(( (gpu + 1) % NUM_GPUS ))  # cycle over the available GPUs
101 | done
102 |
103 | wait
104 |
105 |
106 | temps=(0.50 0.55 0.60 0.65 0.70 0.75 0.80 0.85 0.90 0.95 1.00 1.05 1.10 1.15 1.20)
107 | sample_nums=(64)
108 |
109 | # We will then use reward model to give reward
110 |
111 |
112 |
113 | for temp in "${temps[@]}"; do
114 | for sample_num in "${sample_nums[@]}"; do
115 |
116 | input_path="$output_dir/dynamic_ana_1k_withans_iter${iter}_temp${temp}_sample${sample_num}_math.json"
117 | output_path="$output_dir/dynamic_ana_1k_withans_infer4reward_iter${iter}_temp${temp}_sample${sample_num}_math.json"
118 |
119 | python convert4reward_auto_ground_sample.py \
120 | --input_path "$input_path" \
121 | --output_path "$output_path" \
122 | --sample_num $sample_num \
123 | --num_files -1
124 |
125 | prompt_file="$output_path"
126 | reward_output_file="$output_dir/dynamic_ana_1k_withans_allreward_iter${iter}_temp${temp}_sample${sample_num}_math.json"
127 |
128 | torchrun --standalone --nproc_per_node=8 \
129 | inference_reward_llama3.py \
130 | --prompt_file "$prompt_file" \
131 | --output_file "$reward_output_file" \
132 | --batch_size 200 \
133 | --process_reward_with_answer \
134 | --tensor_parallel_size 1 \
135 | --checkpoint_path $DATA_DIR/checkpoints/Mistral-7B-v0.1/model.pth \
136 | --finetune_checkpoint_path $DATA_DIR/checkpoints/prm_model_mistral_sample_complete
137 |
138 | done
139 | done
140 |
141 |
142 |
143 | best_combination=$(python determine_hyper.py --temps 0.50 0.55 0.60 0.65 0.70 0.75 0.80 0.85 0.90 0.95 1.00 1.05 1.10 1.15 1.20 --sample_nums 64 \
144 | --input_path_template "$output_dir/dynamic_ana_1k_withans_allreward_iter${iter}_temp{temp}_sample{sample_num}_math.json" \
145 | --ref_path_template "$output_dir/dynamic_ana_1k_withans_infer4reward_iter${iter}_temp{temp}_sample{sample_num}_math.json" \
146 | --valid_sample_size 3550 \
147 | --iter $iter)
148 |
149 |
150 | temp=$(echo $best_combination | cut -d' ' -f1)
151 | sample_num=$(echo $best_combination | cut -d' ' -f2)
152 |
153 |
154 | export BEST_MATH_TEMP=$temp
155 | export BEST_MATH_SAMPLE_NUM=$sample_num
156 |
157 |
158 | echo "Best math temp: $BEST_MATH_TEMP" | tee -a $LOG_FILE
159 | echo "Best math sample_num: $BEST_MATH_SAMPLE_NUM" | tee -a $LOG_FILE
160 |
161 |
162 | for ((i=0; i&1 | tee -a $LOG_FILE
183 |
184 |
185 | torchrun --standalone --nproc_per_node=8 \
186 | inference_reward_llama3.py \
187 | --prompt_file $SFT_Midlle_Dir/gsm8k_math_format_infer_part_infer4reward_sample32_iter$((iter))_math.json \
188 | --output_file $SFT_Midlle_Dir/gsm8k_math_format_infer_part_allreward_sample32_iter$((iter))_math.json \
189 | --batch_size 200 \
190 | --process_reward_with_answer \
191 | --tensor_parallel_size 1 \
192 | --checkpoint_path $DATA_DIR/checkpoints/Mistral-7B-v0.1/model.pth \
193 | --finetune_checkpoint_path $DATA_DIR/checkpoints/prm_1to5_model_mistral_sample_complete \
194 | 2>&1 | tee -a $LOG_FILE
195 |
196 | python metric_modiacc_auto.py \
197 | --input $SFT_Midlle_Dir/gsm8k_math_format_infer_part_allreward_sample32_iter$((iter))_math.json \
198 | --output $SFT_Midlle_Dir/gsm8k_math_format_infer_iter$((iter))_135k_math.json \
199 | --ref $SFT_Midlle_Dir/gsm8k_math_format_infer_part_infer4reward_sample32_iter$((iter))_math.json \
200 | --target_size $MATH_TARGET_SIZE \
201 | --correct_num 4 \
202 | --correct_ratio 0.9 \
203 | 2>&1 | tee -a $LOG_FILE
204 |
205 | SFT_TRAIN_DATA=$SFT_Midlle_Dir/gsm8k_math_format_infer_iter$((iter))_135k_math.json
206 |
207 | END_STEP=$((iter * BASE_STEP))
208 |
209 | if [ $iter -eq 1 ]; then
210 | torchrun --standalone --nproc_per_node=$NUM_GPUS \
211 | train_sft_step.py \
212 | --do_train \
213 | --checkpoint_path $MODEL_REPO/model.pth \
214 | --sft_checkpoint_path $DATA_DIR/checkpoints/$SFT_MODEL_SAVE_NAME \
215 | --source_max_len 768 \
216 | --target_max_len 768 \
217 | --total_max_len 1024 \
218 | --per_device_train_batch_size 16 \
219 | --micro_train_batch_size 8 \
220 | --learning_rate 4.12e-6 \
221 | --lr_eta_min 2e-7 \
222 | --num_train_epochs $NUM_EPOCHS \
223 | --dataset "$SFT_TRAIN_DATA" \
224 | --dataset_format "metamath" \
225 | --add_eos_to_marked_target \
226 | --save_strategy "steps" \
227 | --save_steps 500 \
228 | --optim_dtype bf16 \
229 | --save_total_limit 40 \
230 | --tensor_parallel_size 1 \
231 | --end_step $END_STEP \
232 | --save_dir $DATA_DIR/checkpoints/$SFT_MODEL_SAVE_NAME \
233 | 2>&1 | tee -a $LOG_FILE
234 | else
235 | torchrun --standalone --nproc_per_node=$NUM_GPUS \
236 | train_sft_step.py \
237 | --do_train \
238 | --checkpoint_path $MODEL_REPO/model.pth \
239 | --sft_checkpoint_path $DATA_DIR/checkpoints/$SFT_MODEL_SAVE_NAME \
240 | --source_max_len 768 \
241 | --target_max_len 768 \
242 | --total_max_len 1024 \
243 | --per_device_train_batch_size 16 \
244 | --micro_train_batch_size 8 \
245 | --learning_rate 4.12e-6 \
246 | --lr_eta_min 2e-7 \
247 | --num_train_epochs $NUM_EPOCHS \
248 | --dataset "$SFT_TRAIN_DATA" \
249 | --dataset_format "metamath" \
250 | --add_eos_to_marked_target \
251 | --save_strategy "steps" \
252 | --save_steps 500 \
253 | --optim_dtype bf16 \
254 | --save_total_limit 40 \
255 | --tensor_parallel_size 1 \
256 | --end_step $END_STEP \
257 | --save_dir $DATA_DIR/checkpoints/$SFT_MODEL_SAVE_NAME \
258 | --resume_from_checkpoint \
259 | 2>&1 | tee -a $LOG_FILE
260 | fi
261 |
262 | python convert_auto.py \
263 | --checkpoint_dir $DATA_DIR/checkpoints/$SFT_MODEL_SAVE_NAME \
264 | --pretrain_name $MODEL_REPO \
265 | --tokenizer_name $MODEL_REPO \
266 | 2>&1 | tee -a $LOG_FILE
267 | done
268 |
269 |
--------------------------------------------------------------------------------
/train_code/train_reward.sh:
--------------------------------------------------------------------------------
1 | export DATA_DIR=/path/to/your/data/directory
2 |
3 | export MODEL_REPO=$DATA_DIR/checkpoints/Mistral-7B-v0.1
4 | export OMP_NUM_THREADS=4
5 |
6 |
7 | RM_DATA=train_prm_math_shepherd_mistral.json
8 | RM_MODEL_SAVE_NAME=prm_model_mistral_sample_complete
9 |
10 | torchrun --standalone --nproc_per_node=8 \
11 | train_rm_pointwise.py \
12 | --do_train \
13 | --checkpoint_path $MODEL_REPO/model.pth \
14 | --source_max_len 768 \
15 | --target_max_len 768 \
16 | --total_max_len 1024 \
17 | --per_device_train_batch_size 32 \
18 | --micro_train_batch_size 32 \
19 | --learning_rate 2e-6 \
20 | --lr_eta_min 2e-7 \
21 | --num_train_epochs 2 \
22 | --dataset "$RM_DATA" \
23 | --dataset_format "prm-v4" \
24 | --save_strategy epoch \
25 | --save_total_limit 5 \
26 | --train_on_every_token \
27 | --tensor_parallel_size 1 \
28 | --save_only_model True \
29 | --optim_dtype bf16 \
30 | --save_dir $DATA_DIR/checkpoints/$RM_MODEL_SAVE_NAME \
31 | --resume_from_checkpoint
--------------------------------------------------------------------------------
/train_code/train_sft.sh:
--------------------------------------------------------------------------------
1 | export DATA_DIR=/path/to/your/data/directory
2 | export MODEL_REPO=$DATA_DIR/checkpoints/Mistral-7B-v0.1
3 |
4 | export OMP_NUM_THREADS=8
5 |
6 |
7 | SFT_TRAIN_DATA=https://huggingface.co/datasets/AndrewZeng/math-trn-format/blob/main/math_format.json
8 |
9 | # Please download this dataset to local folder
10 | SFT_MODEL_SAVE_NAME=math_format_11k_mistral
11 |
12 | torchrun --standalone --nproc_per_node=8 \
13 | train_sft.py \
14 | --do_train \
15 | --checkpoint_path $MODEL_REPO/model.pth \
16 | --source_max_len 768 \
17 | --target_max_len 768 \
18 | --total_max_len 1024 \
19 | --per_device_train_batch_size 16 \
20 | --micro_train_batch_size 4 \
21 | --learning_rate 5e-6 \
22 | --lr_eta_min 2e-7 \
23 | --num_train_epochs 3 \
24 | --dataset "$SFT_TRAIN_DATA" \
25 | --dataset_format "metamath" \
26 | --add_eos_to_marked_target \
27 | --save_strategy "steps" \
28 | --save_steps 25 \
29 | --optim_dtype bf16 \
30 | --save_total_limit 40 \
31 | --tensor_parallel_size 1 \
32 | --save_dir $DATA_DIR/checkpoints/$SFT_MODEL_SAVE_NAME \
33 | --resume_from_checkpoint
--------------------------------------------------------------------------------
/train_code/trainers/__pycache__/common_utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/trainers/__pycache__/common_utils.cpython-310.pyc
--------------------------------------------------------------------------------
/train_code/trainers/__pycache__/common_utils.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/trainers/__pycache__/common_utils.cpython-311.pyc
--------------------------------------------------------------------------------
/train_code/trainers/__pycache__/ppo_trainer.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/trainers/__pycache__/ppo_trainer.cpython-310.pyc
--------------------------------------------------------------------------------
/train_code/trainers/__pycache__/ppo_trainer.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/trainers/__pycache__/ppo_trainer.cpython-311.pyc
--------------------------------------------------------------------------------
/train_code/trainers/__pycache__/rl_trainer.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/trainers/__pycache__/rl_trainer.cpython-310.pyc
--------------------------------------------------------------------------------
/train_code/trainers/__pycache__/rl_trainer.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/trainers/__pycache__/rl_trainer.cpython-311.pyc
--------------------------------------------------------------------------------
/train_code/trainers/common_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The GPT-Accelera Team
2 | # Copyright 2023 The Alpaca Team
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import argparse
17 | from dataclasses import dataclass
18 | import os
19 | import tempfile
20 | import random
21 | from typing import Callable, Dict, Optional, Sequence, Union, Mapping, Any, Tuple
22 | import logging
23 |
24 | import numpy as np
25 | import torch
26 | import torch.nn.functional as F
27 | import torch.distributed as dist
28 |
29 | from torch.utils.data import Dataset
30 | from torch.utils.data import DataLoader
31 | from torch.utils.data import random_split
32 |
33 | from datasets import load_dataset
34 |
35 | Numeric = Union[int, float]
36 |
37 |
38 | def zip_(*args: Sequence):
39 | """Assert sequences of same length before zipping."""
40 | if len(args) == 0:
41 | return []
42 | assert alleq(args, lambda x, y: len(x) == len(y))
43 | return zip(*args)
44 |
45 |
46 | def mean(*seqs: Sequence[Numeric]) -> Union[Numeric, Sequence[Numeric]]:
47 | singleton = len(seqs) == 1
48 | means = [float(np.mean(seq)) for seq in seqs]
49 | return means[0] if singleton else means
50 |
51 |
52 | def alleq(l: Sequence, f: Optional[Callable] = lambda x, y: x == y):
53 | """Check all arguments in a sequence are equal according to a given criterion.
54 |
55 | Args:
56 | f: A bi-variate boolean function.
57 | l: A list/tuple.
58 |
59 | Returns:
60 | True if everything is equal; otherwise False.
61 | """
62 | return all(f(l[0], li) for li in l[1:])
63 |
64 |
65 | def flatten_dict(nested, sep=".", postprocess_fn=lambda *args: args):
66 | def rec(nest, prefix, into):
67 | for k, v in nest.items():
68 | if sep in k:
69 | raise ValueError(f"separator '{sep}' not allowed to be in key '{k}'")
70 | if isinstance(v, dict): # collections.Mapping fails in py3.10.
71 | rec(v, prefix + k + sep, into)
72 | else:
73 | v = postprocess_fn(v)
74 | into[prefix + k] = v
75 |
76 | flat = {}
77 | rec(nested, "", flat)
78 | return flat
79 |
80 |
81 | def unpack_dict(
82 | d: Dict, keys: Sequence[str], return_type: type = tuple
83 | ) -> Union[Sequence, Dict]:
84 | if return_type in (tuple, list):
85 | return return_type(d[key] for key in keys)
86 | elif return_type == dict:
87 | return {key: d[key] for key in keys}
88 | else:
89 | raise ValueError(f"Unknown return_type: {return_type}")
90 |
91 |
92 | def merge_dict(dicts: Sequence[dict], merge_fn: Callable = lambda *args: args) -> dict:
93 | """Merge a sequence of dicts (with the same set of keys) into a single dict."""
94 | if len(dicts) == 0:
95 | return dict()
96 | return {key: merge_fn([dict_[key] for dict_ in dicts]) for key in dicts[0].keys()}
97 |
98 |
99 | def prepare_inputs(
100 | data: Union[torch.Tensor, Any], device: Union[str, int, torch.device]
101 | ) -> Union[torch.Tensor, Any]:
102 | if isinstance(data, Mapping):
103 | return type(data)(
104 | {k: prepare_inputs(v, device) for k, v in data.items()}
105 | ) # noqa
106 | elif isinstance(data, (tuple, list)):
107 | return type(data)(prepare_inputs(v, device) for v in data)
108 | elif isinstance(data, torch.Tensor):
109 | return data.to(device) # This can break with deepspeed.
110 | return data
111 |
112 |
113 | def pad_inputs_on_batch(
114 | data: Sequence[torch.Tensor], per_device_batch_size: int
115 | ) -> Sequence[torch.Tensor]:
116 | batch_size = None
117 | output_tensors = []
118 | for tensor in data:
119 | if batch_size is None:
120 | batch_size = tensor.size(0)
121 | assert tensor.size(0) == batch_size
122 |
123 | if batch_size % per_device_batch_size != 0:
124 | filled_size = per_device_batch_size - (batch_size % per_device_batch_size)
125 | tensor = torch.cat(
126 | [
127 | tensor,
128 | tensor[0:1].expand(filled_size, *tensor.size()[1:]),
129 | ],
130 | dim=0,
131 | )
132 | output_tensors.append(tensor)
133 | return output_tensors
134 |
135 |
136 | def pad(
137 | inputs: torch.Tensor,
138 | target_size: Union[torch.Size, Sequence[int]],
139 | value=0.0,
140 | left=True,
141 | ):
142 | current_size = inputs.size()
143 | diffs = tuple(ti - ci for ti, ci in zip_(target_size, current_size))
144 | pad_params = []
145 | for diff in diffs:
146 | pad_params = ([diff, 0] if left else [0, diff]) + pad_params
147 | res = F.pad(inputs, pad=pad_params, value=value)
148 | return res
149 |
150 |
151 | def left_pad(
152 | inputs: torch.Tensor, target_size: Union[torch.Size, Sequence[int]], value=0.0
153 | ):
154 | return pad(inputs=inputs, target_size=target_size, value=value, left=True)
155 |
156 |
157 | def right_pad(
158 | inputs: torch.Tensor, target_size: Union[torch.Size, Sequence[int]], value=0.0
159 | ):
160 | return pad(inputs=inputs, target_size=target_size, value=value, left=False)
161 |
162 |
163 | def manual_seed(args_or_seed: Union[int, argparse.Namespace], fix_cudnn=False):
164 | if hasattr(args_or_seed, "seed"):
165 | args_or_seed = args_or_seed.seed
166 | random.seed(args_or_seed)
167 | np.random.seed(args_or_seed)
168 | torch.manual_seed(args_or_seed)
169 | torch.cuda.manual_seed_all(args_or_seed)
170 | os.environ["PYTHONHASHSEED"] = str(args_or_seed)
171 | if fix_cudnn:
172 | torch.backends.cudnn.deterministic = True # noqa
173 | torch.backends.cudnn.benchmark = False # noqa
174 |
175 |
176 | class InfiniteLoader(object):
177 | """Wraps an existing DataLoader so that it outputs stuff indefinitely; useful for semi-supervised learning and DDP."""
178 |
179 | def __init__(self, loader: DataLoader):
180 | super(InfiniteLoader, self).__init__()
181 | self.loader = loader
182 | self.data_iterator = iter(loader)
183 | self.epoch = 0
184 |
185 | def __iter__(self):
186 | return self
187 |
188 | def __next__(self):
189 | try:
190 | data = next(self.data_iterator)
191 | except StopIteration:
192 | # Increment the epoch count
193 | self.epoch += 1
194 |
195 | # If using Distributed Data Parallel, set the epoch for the sampler
196 | if dist.is_initialized():
197 | self.loader.sampler.set_epoch(self.epoch)
198 |
199 | # Create a new iterator for the next epoch
200 | self.data_iterator = iter(self.loader)
201 | data = next(self.data_iterator)
202 |
203 | return data
204 |
205 |
206 | class DisableLogger:
207 | def __enter__(self):
208 | logging.disable(logging.CRITICAL)
209 |
210 | def __exit__(self, exit_type, exit_value, exit_traceback):
211 | logging.disable(logging.NOTSET)
212 |
213 |
214 | @dataclass
215 | class DataCollatorForStackableDataset(object):
216 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
217 | return {
218 | key: torch.stack([instance[key] for instance in instances])
219 | for key in instances[0].keys()
220 | }
221 |
222 |
223 | def local_dataset(dataset_name):
224 | if dataset_name.endswith(".json"):
225 | full_dataset = load_dataset(
226 | "json",
227 | data_files=dataset_name,
228 | cache_dir=os.path.join(
229 | tempfile.gettempdir(), f"{os.getuid()}_cache", "huggingface", "datasets"
230 | ),
231 | )
232 | else:
233 | raise ValueError(f"Unsupported dataset format: {dataset_name}")
234 |
235 | return full_dataset
236 |
237 |
238 | def _get_generator(seed: int) -> torch.Generator:
239 | rng = torch.Generator()
240 | rng.manual_seed(seed)
241 | return rng
242 |
243 |
244 | def split_train_into_train_and_eval(
245 | train_dataset: Dataset, eval_size: int, seed: int
246 | ) -> Tuple[Dataset, Dataset]:
247 | assert eval_size < len(
248 | train_dataset # noqa
249 | ), "Requested eval_size cannot be equal/larger than original train data size."
250 | new_train_size = len(train_dataset) - eval_size # noqa
251 | train_dataset, eval_dataset = random_split(
252 | train_dataset, [new_train_size, eval_size], generator=_get_generator(seed)
253 | )
254 | return train_dataset, eval_dataset
255 |
--------------------------------------------------------------------------------
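
The `pad`/`left_pad`/`right_pad` helpers above compute per-dimension size differences against a target shape and hand them to `F.pad`. A minimal usage sketch, assuming the snippet is run from `train_code/` so that `trainers.common_utils` is importable; the tensor values are illustrative only.

    import torch
    from trainers.common_utils import left_pad, right_pad

    # A 2 x 3 batch of token ids, grown to 2 x 5 by padding the last dimension.
    ids = torch.tensor([[11, 12, 13],
                        [21, 22, 23]])

    # left_pad inserts the pad value before the existing columns ...
    print(left_pad(ids, target_size=(2, 5), value=0))
    # tensor([[ 0,  0, 11, 12, 13],
    #         [ 0,  0, 21, 22, 23]])

    # ... while right_pad appends it after them.
    print(right_pad(ids, target_size=(2, 5), value=0))
    # tensor([[11, 12, 13,  0,  0],
    #         [21, 22, 23,  0,  0]])
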
/train_code/training_utils/__pycache__/fsdp_utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/training_utils/__pycache__/fsdp_utils.cpython-310.pyc
--------------------------------------------------------------------------------
/train_code/training_utils/__pycache__/fsdp_utils.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/training_utils/__pycache__/fsdp_utils.cpython-311.pyc
--------------------------------------------------------------------------------
/train_code/training_utils/__pycache__/memory_efficient_adam.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/training_utils/__pycache__/memory_efficient_adam.cpython-310.pyc
--------------------------------------------------------------------------------
/train_code/training_utils/__pycache__/memory_efficient_adam.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/training_utils/__pycache__/memory_efficient_adam.cpython-311.pyc
--------------------------------------------------------------------------------
/train_code/training_utils/__pycache__/trainer_utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/training_utils/__pycache__/trainer_utils.cpython-310.pyc
--------------------------------------------------------------------------------
/train_code/training_utils/__pycache__/trainer_utils.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/training_utils/__pycache__/trainer_utils.cpython-311.pyc
--------------------------------------------------------------------------------
/train_code/training_utils/memory_efficient_adam.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The GPT-Accelera Team
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import math
8 |
9 | from collections import defaultdict
10 | from copy import deepcopy
11 | from itertools import chain
12 | from typing import Any, DefaultDict, Dict, Iterable
13 |
14 | import torch
15 | from torch.optim import Optimizer
16 | from torch.optim.optimizer import StateDict
17 |
18 |
19 | class MemoryEfficientAdamW(Optimizer):
20 | """
21 | Arguments:
22 | model_params (iterable): iterable of parameters of dicts defining
23 | parameter groups.
24 | lr (float, optional): learning rate. (default: 1e-3)
25 | betas (Tuple[float, float], optional): coefficients used for computing
26 | running averages of gradient and its square. (default: (0.9, 0.999))
27 | eps (float, optional): term added to the denominator to improve
28 | numerical stability. (default: 1e-8)
29 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
30 |         adamw_mode (boolean, optional): whether to use decoupled weight decay
31 |             (i.e. AdamW) instead of L2 regularization. (default: True)
32 |
33 | .. _Adam\: A Method for Stochastic Optimization:
34 | https://arxiv.org/abs/1412.6980
35 | .. _On the Convergence of Adam and Beyond:
36 | https://openreview.net/forum?id=ryQu7f-RZ
37 | """
38 |
39 | def __init__(
40 | self,
41 | model_params,
42 | lr=1e-3,
43 | betas=(0.9, 0.999),
44 | eps=1e-8,
45 | weight_decay=0,
46 | adamw_mode=True,
47 | optim_dtype=torch.bfloat16,
48 | optim_device=torch.device("cpu"),
49 | ):
50 | default_args = dict(
51 | lr=lr,
52 | betas=betas,
53 | eps=eps,
54 | weight_decay=weight_decay,
55 | )
56 | super(MemoryEfficientAdamW, self).__init__(model_params, default_args)
57 | self.adamw_mode = adamw_mode
58 | self.optim_dtype = optim_dtype
59 | self.optim_device = optim_device
60 |
61 | def torch_adam_update_cpu(
62 | self,
63 | data,
64 | grad,
65 | exp_avg,
66 | exp_avg_sq,
67 | lr,
68 | beta1,
69 | beta2,
70 | eps,
71 | weight_decay,
72 | bias_correction1,
73 | bias_correction2,
74 | use_adamw=False,
75 | ):
76 | assert data.dtype == grad.dtype
77 | if weight_decay != 0:
78 | if use_adamw:
79 | data.mul_(1 - lr * weight_decay)
80 | else:
81 | grad = grad.add(data, alpha=weight_decay)
82 |
83 | non_blocking = self.optim_device.type == "cpu"
84 |
85 | exp_avg_cuda, exp_avg_sq_cuda = (
86 | exp_avg.to(data.device, non_blocking=non_blocking),
87 | exp_avg_sq.to(data.device, non_blocking=non_blocking),
88 | )
89 |
90 | dtype_grad = grad.to(dtype=self.optim_dtype)
91 | exp_avg_cuda.mul_(beta1).add_(dtype_grad, alpha=1 - beta1)
92 | exp_avg_sq_cuda.mul_(beta2).addcmul_(dtype_grad, dtype_grad, value=1 - beta2)
93 | denom_cuda = (exp_avg_sq_cuda.sqrt() / math.sqrt(bias_correction2)).add_(eps)
94 |
95 | step_size = lr / bias_correction1
96 | data.addcdiv_(
97 | exp_avg_cuda.to(dtype=data.dtype),
98 | denom_cuda.to(dtype=data.dtype),
99 | value=-step_size,
100 | )
101 |
102 | # Write back to cpu
103 | exp_avg.copy_(exp_avg_cuda, non_blocking=non_blocking)
104 | exp_avg_sq.copy_(exp_avg_sq_cuda, non_blocking=non_blocking)
105 |
106 | @torch.no_grad()
107 | def step(self, closure=None):
108 | loss = None
109 | if closure is not None:
110 | with torch.enable_grad():
111 | loss = closure()
112 |
113 | for _, group in enumerate(self.param_groups):
114 | for _, p in enumerate(group["params"]):
115 | if p.grad is None:
116 | continue
117 |
118 | state = self.state[p]
119 | assert (
120 | p.device.type == "cuda"
121 | ), f"PinMemoryCPUAdam assume all parameters are on cuda"
122 | if len(state) == 0:
123 | state["step"] = 0
124 | # gradient momentums
125 | state["exp_avg"] = torch.zeros_like(
126 | p,
127 | device=self.optim_device,
128 | dtype=self.optim_dtype,
129 | )
130 | # gradient variances
131 | state["exp_avg_sq"] = torch.zeros_like(
132 | p,
133 | device=self.optim_device,
134 | dtype=self.optim_dtype,
135 | )
136 | if self.optim_device.type == "cpu":
137 | state["exp_avg"] = state["exp_avg"].pin_memory()
138 | state["exp_avg_sq"] = state["exp_avg_sq"].pin_memory()
139 |
140 | state["step"] += 1
141 | beta1, beta2 = group["betas"]
142 |
143 | assert (
144 | p.data.numel() == p.grad.data.numel()
145 | ), "parameter and gradient should have the same size"
146 | assert (
147 | state["exp_avg"].device.type == self.optim_device.type
148 | ), f"exp_avg should stay on {self.optim_device.type}"
149 | assert (
150 | state["exp_avg_sq"].device.type == self.optim_device.type
151 | ), f"exp_avg should stay on {self.optim_device.type}"
152 | bias_correction1 = 1 - beta1 ** state["step"]
153 | bias_correction2 = 1 - beta2 ** state["step"]
154 | self.torch_adam_update_cpu(
155 | p.data,
156 | p.grad.data,
157 | state["exp_avg"],
158 | state["exp_avg_sq"],
159 | group["lr"],
160 | beta1,
161 | beta2,
162 | group["eps"],
163 | group["weight_decay"],
164 | bias_correction1,
165 | bias_correction2,
166 | self.adamw_mode,
167 | )
168 | return loss
169 |
170 | @torch._disable_dynamo
171 | def load_state_dict(self, state_dict: StateDict) -> None:
172 | r"""Loads the optimizer state.
173 |
174 | Args:
175 | state_dict (dict): optimizer state. Should be an object returned
176 | from a call to :meth:`state_dict`.
177 | """
178 | # shallow copy, to be consistent with module API
179 | state_dict = state_dict.copy()
180 |
181 | for pre_hook in self._optimizer_load_state_dict_pre_hooks.values():
182 | hook_result = pre_hook(self, state_dict)
183 | if hook_result is not None:
184 | state_dict = hook_result
185 |
186 | # Validate the state_dict
187 | groups = self.param_groups
188 |
189 | # Deepcopy as we write into saved_groups later to update state
190 | saved_groups = deepcopy(state_dict["param_groups"])
191 |
192 | if len(groups) != len(saved_groups):
193 | raise ValueError(
194 | "loaded state dict has a different number of " "parameter groups"
195 | )
196 | param_lens = (len(g["params"]) for g in groups)
197 | saved_lens = (len(g["params"]) for g in saved_groups)
198 | if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
199 | raise ValueError(
200 | "loaded state dict contains a parameter group "
201 | "that doesn't match the size of optimizer's group"
202 | )
203 |
204 | # Update the state
205 | id_map = dict(
206 | zip(
207 | chain.from_iterable(g["params"] for g in saved_groups),
208 | chain.from_iterable(g["params"] for g in groups),
209 | )
210 | )
211 |
212 | def _cast(param, value, param_id=None, param_groups=None, key=None):
213 | r"""Make a deep copy of value, casting all tensors to device of param."""
214 | if isinstance(value, torch.Tensor):
215 | if param.is_floating_point():
216 | casted_value = value.to(
217 | dtype=self.optim_dtype, device=self.optim_device
218 | )
219 | if self.optim_device.type == "cpu":
220 | casted_value = casted_value.pin_memory()
221 | else:
222 | casted_value = Optimizer._process_value_according_to_param_policy(
223 | param, value, param_id, param_groups, key
224 | )
225 | return casted_value
226 | elif isinstance(value, dict):
227 | return {
228 | k: _cast(
229 | param, v, param_id=param_id, param_groups=param_groups, key=k
230 | )
231 | for k, v in value.items()
232 | }
233 | elif isinstance(value, Iterable):
234 | return type(value)(_cast(param, v, param_id=param_id, param_groups=param_groups) for v in value) # type: ignore[call-arg]
235 | else:
236 | return value
237 |
238 | # Copy state assigned to params (and cast tensors to appropriate types).
239 | # State that is not assigned to params is copied as is (needed for
240 | # backward compatibility).
241 | state: DefaultDict[torch.Tensor, Dict[Any, Any]] = defaultdict(dict)
242 | for k, v in state_dict["state"].items():
243 | if k in id_map:
244 | param = id_map[k]
245 | state[param] = _cast(
246 | param, v, param_id=k, param_groups=state_dict["param_groups"]
247 | )
248 | else:
249 | state[k] = v
250 |
251 | # Update parameter groups, setting their 'params' value
252 | def update_group(
253 | group: Dict[str, Any], new_group: Dict[str, Any]
254 | ) -> Dict[str, Any]:
255 | new_group["params"] = group["params"]
256 | return new_group
257 |
258 | param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
259 | self.__setstate__({"state": state, "param_groups": param_groups})
260 |
261 | for post_hook in self._optimizer_load_state_dict_post_hooks.values():
262 | post_hook(self)
263 |
--------------------------------------------------------------------------------
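
`MemoryEfficientAdamW` keeps the `exp_avg`/`exp_avg_sq` states in reduced precision and, when `optim_device` is CPU, in pinned host memory, streaming them to the parameter's GPU only for the duration of each update. Below is a minimal construction sketch, assuming a CUDA device is available; the toy model and hyperparameter values are illustrative only, not those used by the training scripts.

    import torch
    import torch.nn as nn
    from training_utils.memory_efficient_adam import MemoryEfficientAdamW

    # Toy model; the optimizer asserts that parameters live on CUDA.
    model = nn.Linear(1024, 1024).cuda().to(torch.bfloat16)

    # Optimizer states are created in bf16 in (pinned) CPU memory.
    optimizer = MemoryEfficientAdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=2e-6,
        betas=(0.9, 0.999),
        weight_decay=0.0,
        optim_dtype=torch.bfloat16,
        optim_device=torch.device("cpu"),
    )

    loss = model(torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
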
/train_code/training_utils/trainer_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The GPT-Accelera Team
2 | # Copyright 2023 The Alpaca Team
3 | # Copyright 2022 The HuggingFace Team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import math
18 | from functools import partial
19 |
20 | import torch
21 | import torch.nn as nn
22 |
23 | import torch.optim as optim
24 | from torch.optim.lr_scheduler import LambdaLR
25 |
26 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
27 | from torch.distributed.fsdp import MixedPrecision
28 | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
29 |
30 | from training_utils.memory_efficient_adam import MemoryEfficientAdamW
31 | from arguments import Arguments
32 |
33 | from models.model import TransformerBlock
34 | from models.tp import get_data_parallel_group, get_data_parallel_world_size
35 |
36 |
37 | def create_optimizer(
38 | args: Arguments,
39 | model: nn.Module,
40 | optimizer_cpu_offload: bool = False,
41 | model_cpu_offload: bool = False,
42 | ) -> optim.Optimizer:
43 | if not model_cpu_offload:
44 | model_device = next(iter(model.parameters())).device
45 |
46 | optimizer = MemoryEfficientAdamW(
47 | [p for p in model.parameters() if p.requires_grad],
48 | lr=args.learning_rate,
49 | betas=(args.adam_beta1, args.adam_beta2),
50 | eps=args.adam_eps,
51 | weight_decay=args.weight_decay,
52 | optim_dtype=args.optim_dtype,
53 | optim_device=(
54 | torch.device("cpu") if optimizer_cpu_offload else model_device
55 | ),
56 | )
57 | else:
58 | optimizer = torch.optim.AdamW(
59 | [p for p in model.parameters() if p.requires_grad],
60 | lr=args.learning_rate,
61 | betas=(args.adam_beta1, args.adam_beta2),
62 | eps=args.adam_eps,
63 | weight_decay=args.weight_decay,
64 | fused=True,
65 | )
66 |
67 | return optimizer
68 |
69 |
70 | def create_fsdp_model_for_finetune(
71 | args: Arguments,
72 | model: nn.Module,
73 | bf16_all_reduce_upper_bound: int = 16,
74 | ) -> FSDP:
75 | model = FSDP(
76 | module=model,
77 | process_group=get_data_parallel_group(),
78 | auto_wrap_policy=partial(
79 | transformer_auto_wrap_policy,
80 | transformer_layer_cls={
81 | TransformerBlock,
82 | },
83 | ),
84 | mixed_precision=MixedPrecision(
85 | param_dtype=args.compute_dtype,
86 | reduce_dtype=(
87 | torch.float32
88 | if get_data_parallel_world_size() >= bf16_all_reduce_upper_bound
89 | else args.compute_dtype
90 | ),
91 | keep_low_precision_grads=(args.optim_dtype != torch.float32),
92 | buffer_dtype=args.compute_dtype,
93 | ),
94 | cpu_offload=False,
95 | use_orig_params=False,
96 | forward_prefetch=True,
97 | limit_all_gathers=True,
98 | )
99 | return model
100 |
101 |
102 | # https://github.com/huggingface/transformers/blob/976189a6df796a2ff442dd81b022626c840d8c27/src/transformers/optimization.py
103 | def _get_cosine_schedule_with_warmup_lr_lambda(
104 | current_step: int,
105 | *,
106 | num_warmup_steps: int,
107 | num_training_steps: int,
108 | warmup_start_ratio: float,
109 | eta_min_ratio: float,
110 | ):
111 | if current_step < num_warmup_steps:
112 | return warmup_start_ratio + (1.0 - warmup_start_ratio) * float(
113 | current_step
114 | ) / float(max(1, num_warmup_steps))
115 |
116 | progress = float(current_step - num_warmup_steps) / float(
117 | max(1, num_training_steps - num_warmup_steps)
118 | )
119 | return eta_min_ratio + (1.0 - eta_min_ratio) * max(
120 | 0.0, 0.5 * (1.0 + math.cos(math.pi * progress))
121 | )
122 |
123 |
124 | def get_cosine_schedule_with_warmup(
125 | optimizer: optim.Optimizer,
126 | warmup_epochs: int,
127 | max_epochs: int,
128 | warmup_start_ratio: float = 0.0,
129 | eta_min_ratio: float = 0.0,
130 | last_epoch: int = -1,
131 | ):
132 | """
133 | Create a schedule with a learning rate that decreases following the values of the cosine function between the
134 | initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
135 | initial lr set in the optimizer.
136 |
137 | Return:
138 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
139 | """
140 |
141 | assert 0.0 <= warmup_start_ratio <= 1.0, "warmup_start_ratio should be in [0, 1]"
142 | assert 0.0 <= eta_min_ratio <= 1.0, "eta_min_ratio should be in [0, 1]"
143 |
144 | lr_lambda = partial(
145 | _get_cosine_schedule_with_warmup_lr_lambda,
146 | num_warmup_steps=warmup_epochs,
147 | num_training_steps=max_epochs,
148 | warmup_start_ratio=warmup_start_ratio,
149 | eta_min_ratio=eta_min_ratio,
150 | )
151 | return LambdaLR(optimizer, lr_lambda, last_epoch)
152 |
153 |
154 | def _get_constant_cosine_schedule_lr_lambda(
155 | current_step: int,
156 | *,
157 | num_training_steps: int,
158 | constant_lr: float,
159 | eta_min_ratio: float,
160 | constant_ratio: float,
161 | ):
162 | """
163 | Implements a learning rate schedule with a constant learning rate for the specified proportion
164 | of the training steps and a cosine annealing schedule for the remaining steps.
165 |
166 | Args:
167 | current_step (int): The current training step.
168 | num_training_steps (int): The total number of training steps.
169 | constant_lr (float): The constant learning rate for the specified proportion of the training steps.
170 | eta_min_ratio (float): The minimum learning rate ratio at the end of the cosine annealing schedule.
171 | constant_ratio (float): The proportion of the training steps to keep the learning rate constant.
172 |
173 | Returns:
174 | float: The learning rate multiplier.
175 | """
176 | constant_steps = int(num_training_steps * constant_ratio)
177 |
178 |
179 | if current_step < constant_steps:
180 | return constant_lr
181 |
182 | #progress = float(current_step - constant_steps) / float(num_training_steps - constant_steps)
183 | progress = float(current_step - constant_steps) / float(
184 | max(1, num_training_steps - constant_steps)
185 | )
186 | return eta_min_ratio + (constant_lr - eta_min_ratio) * 0.5 * (1.0 + math.cos(math.pi * progress))
187 |
188 | def get_constant_cosine_schedule_with_warmup(
189 | optimizer: torch.optim.Optimizer,
190 | max_epochs: int,
191 | constant_lr: float = 1.0,
192 | eta_min_ratio: float = 0.0,
193 | constant_ratio: float = 0.67,
194 | last_epoch: int = -1,
195 | ):
196 | """
197 | Create a schedule with a learning rate that is constant for a specified proportion of the training steps,
198 | then decreases following the values of the cosine function for the remaining steps.
199 |
200 | Args:
201 | optimizer (torch.optim.Optimizer): The optimizer for which to schedule the learning rate.
202 | max_epochs (int): The total number of training epochs.
203 | constant_lr (float, optional): The constant learning rate for the specified proportion of the training steps. Default is 1.0.
204 | eta_min_ratio (float, optional): The minimum learning rate ratio at the end of the cosine annealing schedule. Default is 0.0.
205 | constant_ratio (float, optional): The proportion of the training steps to keep the learning rate constant. Default is 0.67.
206 | last_epoch (int, optional): The index of the last epoch when resuming training. Default is -1.
207 |
208 | Returns:
209 | torch.optim.lr_scheduler.LambdaLR: The learning rate scheduler.
210 | """
211 | lr_lambda = partial(
212 | _get_constant_cosine_schedule_lr_lambda,
213 | num_training_steps=max_epochs,
214 | constant_lr=constant_lr,
215 | eta_min_ratio=eta_min_ratio,
216 | constant_ratio=constant_ratio,
217 | )
218 | return LambdaLR(optimizer, lr_lambda, last_epoch)
--------------------------------------------------------------------------------
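
The constant-then-cosine schedule above returns a multiplier of `constant_lr` for the first `constant_ratio` fraction of training and then anneals to `eta_min_ratio`. The following self-contained sketch re-derives `_get_constant_cosine_schedule_lr_lambda` with made-up numbers, just to show the shape of the curve.

    import math

    # Illustrative settings, not values taken from the training scripts.
    num_training_steps, constant_lr, eta_min_ratio, constant_ratio = 300, 1.0, 0.1, 0.67
    constant_steps = int(num_training_steps * constant_ratio)  # 201

    def multiplier(step: int) -> float:
        if step < constant_steps:
            return constant_lr
        progress = (step - constant_steps) / max(1, num_training_steps - constant_steps)
        return eta_min_ratio + (constant_lr - eta_min_ratio) * 0.5 * (1.0 + math.cos(math.pi * progress))

    for step in (0, 100, 201, 250, 299):
        print(step, round(multiplier(step), 3))
    # Steps 0 and 100 stay at 1.0 (constant phase); step 201 is still 1.0 at the
    # start of the cosine phase; step 250 is ~0.557; step 299 approaches
    # eta_min_ratio at ~0.1.
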
/train_code/vllm_infer.py:
--------------------------------------------------------------------------------
1 | # Run self-consistency sampling over the data selected by diversity filtering
2 |
3 | from vllm import LLM, SamplingParams
4 |
5 | from transformers import AutoModelForCausalLM, AutoTokenizer
6 | import json
7 | import random
8 |
9 | import re
10 | import argparse
11 |
12 | def sample_resp(args):
13 | with open(args.input_data, "r") as r:
14 | data_json = json.load(r)
15 | sample_data = data_json
16 |
17 |
18 | data_dict = {}
19 |
20 |
21 | for item in sample_data:
22 | data_dict[item["input"]] = item
23 |
24 | prompt_template = '''
25 |
26 | "Below is an instruction that describes a task. " "Write a response that appropriately completes the request.\n\n" "### Instruction:\n{instruction}\n\n### Response: Let's think step by step."
27 |
28 | '''
29 | llm = LLM(model=args.model_dir, tensor_parallel_size=args.tensor_parallel_size)
30 | stop_tokens = ["\nQUESTION:\n", "Question:", "Question", "USER:", "USER", "ASSISTANT:", "ASSISTANT", "Instruction:", "Instruction", "Response:", "Response", "# your code here", "QUESTION", "# Your code goes here", "# Write your code", "\n\n\n\n", "<|end_of_text|>", "\n\nSolved"]
31 | sampling_params = SamplingParams(max_tokens=args.max_tokens, temperature=args.temperature, top_k=args.top_k, n=args.sample_num, stop=stop_tokens)
32 | all_input = []
33 |
34 | for item in sample_data:
35 | #user_input = prompt_template.replace("{instruction}", item)
36 | all_input.append(item["input"])
37 | outputs = llm.generate(all_input, sampling_params)
38 |
39 | all_output = []
40 | all_prompt = []
41 |
42 | all_json = []
43 |
44 | for output in outputs:
45 |
46 | temp_json = data_dict[output.prompt]
47 | all_prompt.append(output.prompt)
48 |
49 | temp_json["prompt"] = output.prompt
50 |
51 | for i in range(args.sample_num):
52 | temp_json["output"+str(i)] = output.outputs[i].text
53 |
54 | all_json.append(temp_json)
55 |
56 | with open(args.output_file, "w") as w:
57 | json.dump(all_json, w)
58 |
59 | def parse_args():
60 | parser = argparse.ArgumentParser()
61 | parser.add_argument("--input_data", type=str) # start index
62 | parser.add_argument("--model_dir", type=str) # start index
63 | parser.add_argument("--sample_num", type=int, default=10) #start index
64 | parser.add_argument("--temperature", type=float, default=1.0) #start index
65 | parser.add_argument("--top_k", type=int, default=20) #start index
66 | parser.add_argument("--max_tokens", type=int, default=768) #start index
67 | parser.add_argument("--output_file", type=str) # start index
68 | parser.add_argument("--tensor_parallel_size", type=int, default=1) # tensor_parallel_size
69 | return parser.parse_args()
70 |
71 | if __name__ == "__main__":
72 | args = parse_args()
73 |
74 | sample_resp(args=args)
75 |
76 |
--------------------------------------------------------------------------------
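
For downstream scripts it helps to know the record shape `sample_resp` writes: each input item is kept as-is and gains a `prompt` field plus `output0` ... `output{n-1}` completions. A hypothetical record with `--sample_num 3` (all field contents made up):

    import json

    record = {
        "input": "Question: ...",                    # original field from --input_data
        "prompt": "Question: ...",                   # echoed from the vLLM RequestOutput
        "output0": "Step 1 ... The answer is 42.",
        "output1": "Step 1 ... The answer is 41.",
        "output2": "Step 1 ... The answer is 42.",
    }
    print(json.dumps(record, indent=2))
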
/train_code/vllm_infer_arc.py:
--------------------------------------------------------------------------------
1 | import io
2 | import json
3 | import logging
4 | import math
5 | import random
6 | import numpy as np
7 | import os
8 | import pprint
9 | import sys
10 | import time
11 | import transformers
12 | import torch
13 |
14 | from datasets import load_dataset
15 |
16 | from datetime import datetime, date
17 | from tqdm import tqdm
18 | from vllm import LLM, SamplingParams
19 | from ray_vllm import LLMRayWrapper
20 | from string import ascii_uppercase
21 | import re
22 |
23 | def load_jsonl(file_path):
24 | with open(file_path, "r") as f:
25 | return [json.loads(line) for line in f]
26 |
27 | def construct_prompt(template, data):
28 | choices = data['choices']["text"]
29 | question = data['question'] + "\n"
30 | i = 0
31 | for choice in choices:
32 | question += f"({ascii_uppercase[i]}) {choice}\n"
33 | i += 1
34 | question = question.strip()
35 | return template.replace("{{question}}", question)
36 |
37 |
38 | def extract_answer(response):
39 | pattern_list = [r"Final answer:\s*[\[\(]?([A-Za-z])[\]\)]?"]
40 | for pattern in pattern_list:
41 | match = re.search(pattern, response)
42 | if match:
43 | return match.group(1)
44 | return None
45 |
46 |
47 | def main(args):
48 |
49 | argsdict = vars(args)
50 | print(pprint.pformat(argsdict))
51 |
52 | if args.input_file is not None:
53 | problems = load_jsonl(args.input_file)
54 | problems = problems[args.start:args.end]
55 | print("Loading problems from ", min(len(problems),args.end))
56 | print("Number of problems: ", len(problems))
57 | else:
58 | problems = load_dataset("allenai/ai2_arc", "ARC-Challenge", split=args.split, cache_dir=args.cache_dir)
59 | # problems = [{"id": idx , **item} for idx, item in enumerate(problems)]
60 | # random.seed(42)
61 | problems = problems.select(range(args.start, min(len(problems),args.end)))
62 |
63 | if not os.path.exists(os.path.dirname(args.save)):
64 | os.makedirs(os.path.dirname(args.save), exist_ok=True)
65 |
66 |
67 | llm = LLMRayWrapper(model_path=args.model, tensor_parallel_size=args.tensor_parallel_size, max_model_len=4096, cuda_ids=args.cuda_ids, swap_space = 20) # Adjust tensor_parallel_size as needed
68 | stop_tokens = ["\nQUESTION:\n", "Question:", "Question", "USER:", "USER", "ASSISTANT:", "ASSISTANT", "Instruction:", "Instruction", "Response:", "Response", "\n\n"]
69 | sampling_params = SamplingParams(
70 | n=args.sample_num,
71 | temperature=args.temperature,
72 | top_k=args.top_k,
73 | max_tokens=args.max_tokens,
74 | stop=stop_tokens
75 | )
76 | prompt = open(args.prompt_template_dir+"/"+str(args.shot_num)+"-shot-prompt.txt", "r").read()
77 | zero_shot_prompt = open(args.prompt_template_dir+"/0-shot-prompt.txt", "r").read()
78 | # tokenizer = AutoTokenizer.from_pretrained(args.model)
79 |
80 | # Filter too long problems with tokenizer
81 | # problems = problems.filter(lambda x: len(tokenizer(x["question"])["input_ids"]) < 2048)
82 | if not os.path.exists(os.path.dirname(args.save)):
83 | os.makedirs(os.path.dirname(args.save), exist_ok=True)
84 | # ans_file = open(args.save, "w")
85 |
86 | prompt_list = []
87 | for item in problems:
88 | prompt_list.append(construct_prompt(prompt, item))
89 | outputs = llm.generate(prompts = prompt_list, sampling_params = sampling_params)
90 | correct = 0
91 | gpt_codes = {}
92 | for i, (p, output) in enumerate(zip(problems, outputs)):
93 | question = p["question"]
94 | output_list = {
95 | "id": p["id"],
96 | "question": question,
97 | "input": construct_prompt(zero_shot_prompt, p),
98 | "gt": p["answerKey"]
99 | }
100 | for q in range(args.sample_num):
101 | raw_response = output.outputs[q].text
102 | answer = extract_answer(raw_response)
103 | gt = p["answerKey"]
104 | if answer and answer.lower() == gt.lower():
105 | correct += 1
106 | score = 1.0
107 | else:
108 | score = 0.0
109 | output_list["output"+str(q)] ={
110 | "text": raw_response,
111 | "score": score,
112 | "finish_reason": output.outputs[q].finish_reason,
113 | "extracted_answer": answer
114 | }
115 | gpt_codes[p["id"]] = output_list
116 |
117 | # ans_file.write(json.dumps(output_list) + "\n")
118 | # ans_file.flush()
119 |
120 |
121 |
122 | if args.debug:
123 | print("Prompt: ", "-" * 100)
124 | print(output.prompt)
125 | print("Completion: ", "-" * 100)
126 | print(output_list['output0']['text'])
127 | print("Ground Truth: ", gt)
128 | print("Score: ", output_list['output0']['score'])
129 | # ans_file.close()
130 |
131 | with open(args.save, "w") as f:
132 | json.dump(gpt_codes, f, indent=2)
133 |
134 | print(f"Accuracy: {correct / (len(problems) * args.sample_num)}")
135 |
136 |
137 |
138 |
139 | if __name__ == "__main__":
140 | import argparse
141 |
142 |     parser = argparse.ArgumentParser(description="Run a trained model to generate ARC-Challenge responses.")
143 | parser.add_argument("--model", default="gpt2")
144 | parser.add_argument("--prompt_template_dir", default=None, type=str)
145 | parser.add_argument("--shot_num", default=None, type=int)
146 | parser.add_argument("--input_file", default=None, type=str)
147 | # parser.add_argument("-t","--test_loc", default="~/apps/data_split/test.json", type=str, help="path to the test folder.")
148 | # parser.add_argument("-r","--root", default="../", type=str, help="where the data is stored.")
149 | # parser.add_argument("-l","--load", default="", type=str)
150 | # parser.add_argument("--peeking", default=0.0, type=float)
151 | parser.add_argument("--sample_num", type=int, default=10) # 采样数量
152 | parser.add_argument("--temperature", type=float, default=1.0) # 温度
153 | parser.add_argument("--top_k", type=int, default=20) # top_k
154 | parser.add_argument("--max_tokens", type=int, default=768) # 最大token数
155 | # parser.add_argument("--difficulty", default="introductory", type=str)
156 | # parser.add_argument("--num-beams", default=5, type=int)
157 | parser.add_argument("-s","--start", default=0, type=int)
158 | parser.add_argument("-e","--end", default=10000000, type=int)
159 | # parser.add_argument("-i", "--index", default=None, type=int)
160 | parser.add_argument("--cuda_ids", type=str, default="0")
161 | parser.add_argument("-d", "--debug", action="store_true")
162 | parser.add_argument("--split", type=str, default="train", help="What split to use.")
163 | parser.add_argument("--save", type=str, default="./results")
164 | parser.add_argument("--tensor-parallel-size", type=int, default=1, help="Tensor parallel size for vllm")
165 | parser.add_argument("--cache_dir", type=str, default=None, help="Cache directory for datasets")
166 | args = parser.parse_args()
167 |
168 | main(args)
169 |
--------------------------------------------------------------------------------
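
`extract_answer` above accepts answers written as `Final answer: C`, `Final answer: (C)`, or `Final answer: [C]` and returns `None` otherwise. A quick illustration using the same regular expression; the example responses are made up.

    import re

    pattern = r"Final answer:\s*[\[\(]?([A-Za-z])[\]\)]?"

    for response in (
        "The object is denser than water, so it sinks.\nFinal answer: (C)",
        "Final answer: B",
        "The answer is probably D.",
    ):
        match = re.search(pattern, response)
        print(match.group(1) if match else None)
    # C
    # B
    # None  (no "Final answer:" marker, so extract_answer would return None)
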
/train_code/vllm_infer_auto.py:
--------------------------------------------------------------------------------
1 | from vllm import LLM, SamplingParams
2 | from transformers import AutoModelForCausalLM, AutoTokenizer
3 | import json
4 | import random
5 | import re
6 | import argparse
7 | # import wandb
8 | # from huggingface_hub import HfApi
9 | from huggingface_hub import upload_file  # needed by upload_to_huggingface below
10 |
11 | # def upload_to_huggingface(output_file, repo_id, token, commit_message="Upload output file"):
12 | # # Initialize HfApi with token
13 | # api = HfApi()
14 |
15 | # # Upload the file
16 | # api.upload_file(
17 | # path_or_fileobj=output_file,
18 | # path_in_repo=output_file, # You can change the path in the repo if needed
19 | # repo_id=repo_id,
20 | # token=token,
21 | # commit_message=commit_message
22 | # )
23 |
24 | def upload_to_huggingface(output_file, repo_id, token, repo_type='dataset', commit_message="Upload output file"):
25 | # Upload the file to the specified repository
26 | upload_file(
27 | path_or_fileobj=output_file,
28 | path_in_repo=output_file, # You can change the path in the repo if needed
29 | repo_id=repo_id,
30 | token=token,
31 | repo_type=repo_type,
32 | commit_message=commit_message
33 | )
34 | print(f"File {output_file} uploaded successfully to {repo_id}.")
35 |
36 | def sample_resp(args):
37 | # wandb.init(project=args.wandb_project, config={
38 | # "input_data": args.input_data,
39 | # "model_dir": args.model_dir,
40 | # "sample_num": args.sample_num,
41 | # "temperature": args.temperature,
42 | # "top_k": args.top_k,
43 | # "max_tokens": args.max_tokens,
44 | # "output_file": args.output_file,
45 | # "tensor_parallel_size": args.tensor_parallel_size,
46 | # })
47 |
48 | with open(args.input_data, "r") as r:
49 | data_json = json.load(r)
50 | sample_data = data_json
51 |
52 | data_dict = {}
53 |
54 | for item in sample_data:
55 | data_dict[item["input"]] = item
56 |
57 | prompt_template = '''
58 |
59 | "Below is an instruction that describes a task. " "Write a response that appropriately completes the request.\n\n" "### Instruction:\n{instruction}\n\n### Response: Let's think step by step."
60 |
61 | '''
62 | llm = LLM(model=args.model_dir, tensor_parallel_size=args.tensor_parallel_size)
63 |
64 | sampling_params = SamplingParams(max_tokens=args.max_tokens, temperature=args.temperature, top_k=args.top_k, n=args.sample_num)
65 | all_input = []
66 |
67 | for item in sample_data:
68 | all_input.append(item["input"])
69 | outputs = llm.generate(all_input, sampling_params)
70 |
71 | all_output = []
72 | all_prompt = []
73 |
74 | all_json = []
75 |
76 | for output in outputs:
77 |
78 | temp_json = data_dict[output.prompt]
79 | all_prompt.append(output.prompt)
80 |
81 | temp_json["prompt"] = output.prompt
82 |
83 | for i in range(args.sample_num):
84 | temp_json["output"+str(i)] = output.outputs[i].text
85 |
86 | all_json.append(temp_json)
87 |
88 | with open(args.output_file, "w") as w:
89 | json.dump(all_json, w)
90 |
91 |     # Upload the output file to Hugging Face (disabled by default)
92 | #upload_to_huggingface(args.output_file, args.repo_id, args.hf_token)
93 |
94 |
95 |
96 | def parse_args():
97 | parser = argparse.ArgumentParser()
98 | parser.add_argument("--input_data", type=str, required=True) # 输入数据文件路径
99 | parser.add_argument("--model_dir", type=str, required=True) # 模型目录
100 | parser.add_argument("--sample_num", type=int, default=10) # 采样数量
101 | parser.add_argument("--temperature", type=float, default=1.0) # 温度
102 | parser.add_argument("--top_k", type=int, default=20) # top_k
103 | parser.add_argument("--max_tokens", type=int, default=768) # 最大token数
104 | parser.add_argument("--output_file", type=str, required=True) # 输出文件路径
105 | parser.add_argument("--tensor_parallel_size", type=int, default=1) # 并行大小
106 | #parser.add_argument("--repo_id", type=str, required=True, help="Hugging Face repository ID") # Hugging Face存储库ID
107 | #parser.add_argument("--hf_token", type=str, required=True, help="Hugging Face access token") # Hugging Face访问令牌
108 | #parser.add_argument("--wandb_project", type=str, default="my_project") # wandb project name
109 | return parser.parse_args()
110 |
111 | if __name__ == "__main__":
112 | args = parse_args()
113 | sample_resp(args=args)
114 |
--------------------------------------------------------------------------------