├── README.md └── train_code ├── .DS_Store ├── arc_data ├── arc_7k_train_data.jsonl ├── arc_7k_val_data.jsonl ├── arc_challenge_train_val │ ├── train_data.jsonl │ └── val_data.jsonl ├── arc_small_200_query_data.jsonl ├── arc_test_data.jsonl └── arc_train_data.jsonl ├── arc_prompt ├── 0-shot-prompt.txt └── 1-shot-prompt.txt ├── arguments.py ├── cal_metric_vllm.py ├── checkpoint_utils.py ├── convert4reward_auto_ground_sample.py ├── convert_auto.py ├── convert_checkpoint_to_hf.py ├── data_utils ├── __pycache__ │ ├── data_utils_ppo.cpython-310.pyc │ ├── data_utils_ppo.cpython-311.pyc │ └── data_utils_rm_pointwise.cpython-310.pyc ├── data_utils_dpo.py ├── data_utils_ppo.py ├── data_utils_rm_pairwise.py ├── data_utils_rm_pointwise.py └── data_utils_sft.py ├── determine_hyper.py ├── eval_arc_save_metrics.py ├── gsm8k_test.py ├── hf_argparser.py ├── infer_arc.sh ├── infer_gsm8k1.sh ├── inference_reward_llama3.py ├── math_utils ├── README.md ├── __pycache__ │ ├── grader.cpython-310.pyc │ ├── grader.cpython-311.pyc │ ├── math_normalize.cpython-310.pyc │ ├── math_normalize.cpython-311.pyc │ ├── math_rl_utils.cpython-310.pyc │ └── math_rl_utils.cpython-311.pyc ├── grader.py ├── math_normalize.py └── math_rl_utils.py ├── metric_modiacc_auto.py ├── models ├── __pycache__ │ ├── frozen_layers.cpython-310.pyc │ ├── frozen_layers.cpython-311.pyc │ ├── frozen_layers.cpython-39.pyc │ ├── model.cpython-310.pyc │ ├── model.cpython-311.pyc │ ├── model.cpython-39.pyc │ ├── quantize.cpython-310.pyc │ ├── quantize.cpython-311.pyc │ ├── reward_model.cpython-310.pyc │ ├── reward_model.cpython-311.pyc │ ├── rl_model.cpython-310.pyc │ ├── rl_model.cpython-311.pyc │ ├── tokenizer_utils.cpython-310.pyc │ ├── tokenizer_utils.cpython-311.pyc │ ├── tokenizer_utils.cpython-38.pyc │ ├── tp.cpython-310.pyc │ └── tp.cpython-311.pyc ├── frozen_layers.py ├── model.py ├── quantize.py ├── reward_model.py ├── rl_model.py ├── tokenizer_utils.py └── tp.py ├── scripts ├── convert.sh ├── convert_checkpoint_to_hf.py ├── convert_hf_checkpoint.py ├── convert_hf_checkpoint_llama3.py ├── download.py ├── prepare_ds_math_7b.sh ├── prepare_llemma_34b.sh └── prepare_llemma_7b.sh ├── test.jsonl ├── test_ppo.json ├── train_bstar.sh ├── train_reward.sh ├── train_rm_pointwise.py ├── train_sft.py ├── train_sft.sh ├── train_sft_step.py ├── trainers ├── __pycache__ │ ├── common_utils.cpython-310.pyc │ ├── common_utils.cpython-311.pyc │ ├── ppo_trainer.cpython-310.pyc │ ├── ppo_trainer.cpython-311.pyc │ ├── rl_trainer.cpython-310.pyc │ └── rl_trainer.cpython-311.pyc ├── common_utils.py ├── ppo_trainer.py └── rl_trainer.py ├── training_utils ├── __pycache__ │ ├── fsdp_utils.cpython-310.pyc │ ├── fsdp_utils.cpython-311.pyc │ ├── memory_efficient_adam.cpython-310.pyc │ ├── memory_efficient_adam.cpython-311.pyc │ ├── trainer_utils.cpython-310.pyc │ └── trainer_utils.cpython-311.pyc ├── fsdp_utils.py ├── memory_efficient_adam.py └── trainer_utils.py ├── vllm_infer.py ├── vllm_infer_arc.py └── vllm_infer_auto.py /README.md: -------------------------------------------------------------------------------- 1 | # B-STAR: Monitoring and Balancing Exploration and Exploitation in Self-Taught Reasoners 2 | 3 |

4 | 5 | 📄 Paper 6 | 8 |

9 | 10 | B-STAR (Balanced Self-Taught Reasoner) is a framework designed to improve the self-improvement process of reasoning models by dynamically balancing exploration and exploitation throughout training. This approach is particularly effective in enhancing performance on tasks requiring complex reasoning, such as mathematical problem-solving, coding, and commonsense reasoning. 11 | 12 | 13 | ![Screenshot 2024-12-22 17 35 44](https://github.com/user-attachments/assets/fb97aec4-dbfa-45f3-a64a-f3022aeff599) 14 | 15 | 16 | ## Overview 17 | 18 | Self-improvement in reasoning models involves iterative training where models generate their own training data from their outputs. However, existing methods often stagnate after a few iterations due to imbalances between two critical factors: 19 | 20 | 1. **Exploration**: The model's ability to generate diverse and high-quality responses. 21 | 2. **Exploitation**: The effectiveness of external rewards in distinguishing and leveraging high-quality responses. 22 | 23 | ![Screenshot 2024-12-22 17 40 13](https://github.com/user-attachments/assets/3970c997-8a9c-4c40-9c7a-4884b4897076) 24 | 25 | B-STAR introduces an adaptive mechanism to monitor and balance these factors dynamically, ensuring consistent performance improvements over multiple training iterations. 26 | 27 | 28 | ## Key Features 29 | 30 | - **Dynamic Configuration Adjustments**: Automatically tunes exploration and exploitation configurations (e.g., sampling temperature, reward thresholds) to optimize the self-improvement process. 31 | - **Balance Score Metric**: Quantifies the interplay between exploration and exploitation, guiding dynamic adjustments. 32 | - **Generalization Across Tasks**: Demonstrates effectiveness in mathematical reasoning, coding challenges, and commonsense reasoning tasks. 33 | 34 | 35 | ## Results 36 | 37 | B-STAR achieves state-of-the-art performance across various benchmarks: 38 | 39 | - Significant improvements compared to previous self-improvement methods. 40 | ![Screenshot 2024-12-22 17 39 06](https://github.com/user-attachments/assets/6fe32096-6099-49df-8824-f912ee31f71d) 41 | 42 | 43 | - Sustained performance growth across multiple iterations, outperforming existing methods that stagnate after a few iterations. 44 | ![Screenshot 2024-12-22 17 39 31](https://github.com/user-attachments/assets/76f35782-6617-4d54-a6ea-f9a89fe0b2bb) 45 | 46 | ## Reproduction 47 | 48 | Our code builds upon [easy-to-hard](https://github.com/Edward-Sun/easy-to-hard/tree/main) and [gpt-accelerate](https://github.com/Edward-Sun/gpt-accelera). Please refer to gpt-accelerate for environment setup and model weight conversion instructions. 49 | 50 | ### 1. Prepare Model 51 | 52 | We first need to prepare the model checkpoint in the gpt-fast format. 53 | 54 | ```shell 55 | export DATA_DIR=/path/to/your/data/directory 56 | export MODEL_REPO=mistralai/Mistral-7B-v0.1 57 | 58 | python scripts/download.py \ 59 | --repo_id $MODEL_REPO \ 60 | --local_dir $DATA_DIR/checkpoints 61 | 62 | python scripts/convert_hf_checkpoint.py \ 63 | --checkpoint_dir $DATA_DIR/checkpoints/$MODEL_REPO \ 64 | --target_precision bf16 65 | ``` 66 | 67 | ### 2. Train SFT Model
68 | 69 | ```shell 70 | export DATA_DIR=/path/to/your/data/directory 71 | export MODEL_REPO=$DATA_DIR/checkpoints/Mistral-7B-v0.1 72 | 73 | export OMP_NUM_THREADS=8 74 | 75 | 76 | SFT_TRAIN_DATA=https://huggingface.co/datasets/AndrewZeng/math-trn-format/blob/main/math_format.json 77 | 78 | # Please download this dataset to a local folder first 79 | SFT_MODEL_SAVE_NAME=math_format_11k_mistral 80 | 81 | torchrun --standalone --nproc_per_node=8 \ 82 | train_sft.py \ 83 | --do_train \ 84 | --checkpoint_path $MODEL_REPO/model.pth \ 85 | --source_max_len 768 \ 86 | --target_max_len 768 \ 87 | --total_max_len 1024 \ 88 | --per_device_train_batch_size 16 \ 89 | --micro_train_batch_size 4 \ 90 | --learning_rate 5e-6 \ 91 | --lr_eta_min 2e-7 \ 92 | --num_train_epochs 3 \ 93 | --dataset "$SFT_TRAIN_DATA" \ 94 | --dataset_format "metamath" \ 95 | --add_eos_to_marked_target \ 96 | --save_strategy "steps" \ 97 | --save_steps 25 \ 98 | --optim_dtype bf16 \ 99 | --save_total_limit 40 \ 100 | --tensor_parallel_size 1 \ 101 | --save_dir $DATA_DIR/checkpoints/$SFT_MODEL_SAVE_NAME \ 102 | --resume_from_checkpoint 103 | ``` 104 | 105 | ### 3. Train PRM Model 106 | 107 | We constructed the [PRM training data](https://huggingface.co/datasets/AndrewZeng/prm-reward-data) using the [math-shepherd](https://arxiv.org/abs/2312.08935) approach and trained the reward model using a pointwise objective. 108 | 109 | ```shell 110 | export DATA_DIR=/path/to/your/data/directory 111 | 112 | export MODEL_REPO=$DATA_DIR/checkpoints/Mistral-7B-v0.1 113 | export OMP_NUM_THREADS=4 114 | 115 | 116 | RM_DATA=train_prm_math_shepherd_mistral.json 117 | RM_MODEL_SAVE_NAME=prm_model_mistral_sample_complete 118 | 119 | torchrun --standalone --nproc_per_node=8 \ 120 | train_rm_pointwise.py \ 121 | --do_train \ 122 | --checkpoint_path $MODEL_REPO/model.pth \ 123 | --source_max_len 768 \ 124 | --target_max_len 768 \ 125 | --total_max_len 1024 \ 126 | --per_device_train_batch_size 32 \ 127 | --micro_train_batch_size 32 \ 128 | --learning_rate 2e-6 \ 129 | --lr_eta_min 2e-7 \ 130 | --num_train_epochs 2 \ 131 | --dataset "$RM_DATA" \ 132 | --dataset_format "prm-v4" \ 133 | --save_strategy epoch \ 134 | --save_total_limit 5 \ 135 | --train_on_every_token \ 136 | --tensor_parallel_size 1 \ 137 | --save_only_model True \ 138 | --optim_dtype bf16 \ 139 | --save_dir $DATA_DIR/checkpoints/$RM_MODEL_SAVE_NAME \ 140 | --resume_from_checkpoint 141 | ``` 142 | 143 | ### 4. Train B-STaR 144 | 145 | ```shell 146 | ## This is our initial release code. 147 | ## We are working hard to clean it up and make the code clearer and more readable. 148 | cd train_code 149 | bash train_bstar.sh 150 | ``` 151 | 152 | ### 5. Evaluation 153 | 154 | Coming soon!
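In the meantime, generated answers can be scored offline with the grader the training code already uses (`grade_answer` in `train_code/math_utils/grader.py`, together with the `\n\n# Answer\n\n` marker handled by `cal_metric_vllm.py`). The snippet below is only a rough sketch meant to be run from inside `train_code/`; the input file name and its JSON fields are illustrative assumptions, not part of the released pipeline.

```python
import json

from math_utils import grader  # grader module shipped in train_code/math_utils


def extract_answer(text: str, marker: str = "\n\n# Answer\n\n") -> str:
    # cal_metric_vllm.py treats the text after the answer marker as the final answer
    if marker in text:
        return text.split(marker)[1].strip()
    return "No answer found."


correct, total = 0, 0
# samples.jsonl is a hypothetical file of {"output": ..., "gt_answer": ...} records
with open("samples.jsonl") as f:
    for line in f:
        item = json.loads(line)
        pred = extract_answer(item["output"])
        correct += 1 if grader.grade_answer(pred, item["gt_answer"]) else 0
        total += 1

print(f"accuracy: {correct / total:.4f}")
```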
155 | 156 | ## Citation 157 | 158 | If you find B-STaR useful, please cite our paper: 159 | 160 | ``` 161 | @article{zeng2024bstar, 162 | title={B-STAR: Monitoring and Balancing Exploration and Exploitation in Self-Taught Reasoners}, 163 | author={Weihao Zeng, Yuzhen Huang, Lulu Zhao, Yijun Wang, Zifei Shan, Junxian He}, 164 | journal={arXiv preprint arXiv:2412.17256}, 165 | year={2024}, 166 | url={https://arxiv.org/abs/2412.17256} 167 | } 168 | ``` 169 | 170 | 171 | -------------------------------------------------------------------------------- /train_code/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/.DS_Store -------------------------------------------------------------------------------- /train_code/arc_prompt/0-shot-prompt.txt: -------------------------------------------------------------------------------- 1 | Your task is to answer the question below. Give step by step reasoning before you answer. 2 | 3 | Question: {{question}} 4 | Solution: 5 | -------------------------------------------------------------------------------- /train_code/arc_prompt/1-shot-prompt.txt: -------------------------------------------------------------------------------- 1 | Your task is to answer the question below. Give step by step reasoning before you answer. 2 | 3 | Question: George wants to warm his hands quickly by rubbing them. Which skin surface will produce the most heat? 4 | (A) dry palms 5 | (B) wet palms 6 | (C) palms covered with oil 7 | (D) palms covered with lotion 8 | Solution: 9 | To determine which skin surface will produce the most heat when George rubs his hands, let's analyze each option step by step. 10 | 1. Heat Production and Friction: When two surfaces are rubbed together, the heat generated is a result of friction between those surfaces. The higher the friction, the more heat is produced. 11 | 2. Effect of Different Surface Conditions: 12 | Dry Palms (Option A): Dry palms have a relatively high friction coefficient, as they lack any lubricants or moisture that could reduce friction. This means rubbing dry palms together would generate significant friction and, thus, more heat. 13 | Wet Palms (Option B): Wet palms reduce friction due to the presence of water. Water acts as a lubricant, decreasing the friction between the palms, resulting in less heat production. 14 | Palms Covered with Oil (Option C): Oil is an effective lubricant. It reduces friction significantly, which would decrease the heat produced when rubbing the palms. 15 | Palms Covered with Lotion (Option D): Lotion, similar to oil, acts as a lubricant and reduces friction between the palms, leading to lower heat production compared to dry palms. 16 | 3. Conclusion: Since friction is highest with dry palms, they will produce the most heat when rubbed together, as no lubricants are present to reduce friction. 
17 | Final answer: (A) 18 | 19 | Question: {{question}} 20 | Solution: 21 | -------------------------------------------------------------------------------- /train_code/cal_metric_vllm.py: -------------------------------------------------------------------------------- 1 | import json 2 | from models.tokenizer_utils import AcceleraTokenizer 3 | #from math_utils.math_rl_utils import post_process_math_rollouts 4 | 5 | from typing import Dict, List, Optional, Tuple 6 | import torch 7 | import trainers.common_utils as common_utils 8 | from itertools import chain 9 | from math_utils import grader 10 | import torch.distributed as dist 11 | from pathlib import Path 12 | def last_boxed_only_string(string): 13 | idx = string.rfind("\\boxed") 14 | if idx < 0: 15 | return None 16 | 17 | i = idx 18 | right_brace_idx = None 19 | num_left_braces_open = 0 20 | while i < len(string): 21 | if string[i] == "{": 22 | num_left_braces_open += 1 23 | if string[i] == "}": 24 | num_left_braces_open -= 1 25 | if num_left_braces_open == 0: 26 | right_brace_idx = i 27 | break 28 | i += 1 29 | 30 | if right_brace_idx is None: 31 | retval = None 32 | else: 33 | retval = string[idx : right_brace_idx + 1] 34 | 35 | return retval 36 | 37 | 38 | def remove_boxed(s): 39 | left = "\\boxed{" 40 | try: 41 | assert s[: len(left)] == left 42 | assert s[-1] == "}" 43 | return s[len(left) : -1] 44 | except Exception: 45 | return None 46 | def _calculate_outcome_accuracy( 47 | predicted_answers: List[str], 48 | gt_answers: List[str], 49 | answers: List[str], 50 | levels: List[int], 51 | outcome_reward: bool, 52 | easy_outcome_reward: bool, 53 | device: torch.device, 54 | ): 55 | assert len(predicted_answers) == len(answers) 56 | 57 | assert not ( 58 | outcome_reward and easy_outcome_reward 59 | ), "Cannot use both outcome_reward and easy_outcome_reward." 60 | 61 | with common_utils.DisableLogger(): 62 | outcome_accuracy = [ 63 | 1.0 if grader.grade_answer(predicted_answer, gt_answer) else 0.0 64 | for predicted_answer, gt_answer in zip(predicted_answers, gt_answers) 65 | ] 66 | 67 | # TODO (zhiqings): 0.25 is a magic number.
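# Note on the default below: when neither outcome_reward nor easy_outcome_reward is set, every
# response simply receives this constant reward; with easy_outcome_reward it is also assigned
# whenever the reference answer is the literal string "Unavailable".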
68 | unavailable_reward = 0.25 69 | if outcome_reward: 70 | symbolic_rewards = outcome_accuracy 71 | elif easy_outcome_reward: 72 | symbolic_rewards = [] 73 | for predicted_answer, answer in zip(predicted_answers, answers): 74 | if answer == "Unavailable": 75 | score = unavailable_reward 76 | elif grader.grade_answer(predicted_answer, answer): 77 | score = 1.0 78 | else: 79 | score = 0.0 80 | symbolic_rewards.append(score) 81 | else: 82 | symbolic_rewards = [ 83 | unavailable_reward for _ in range(len(predicted_answers)) 84 | ] 85 | 86 | assert len(symbolic_rewards) == len(predicted_answers) 87 | 88 | per_level_counts = {} 89 | per_level_accuracy = {} 90 | 91 | all_unique_levels = list(range(1, 6)) 92 | 93 | for level in all_unique_levels: 94 | per_level_counts[level] = [] 95 | per_level_accuracy[level] = [] 96 | 97 | for level, accuracy in zip(levels, outcome_accuracy): 98 | for unique_level in all_unique_levels: 99 | if level == unique_level: 100 | per_level_counts[unique_level].append(1.0) 101 | per_level_accuracy[unique_level].append(accuracy) 102 | else: 103 | per_level_counts[unique_level].append(0.0) 104 | per_level_accuracy[unique_level].append(0.0) 105 | 106 | for level in all_unique_levels: 107 | assert len(per_level_counts[level]) == len(outcome_accuracy) 108 | assert len(per_level_accuracy[level]) == len(outcome_accuracy) 109 | per_level_counts[level] = torch.tensor(per_level_counts[level], device=device) 110 | per_level_accuracy[level] = torch.tensor( 111 | per_level_accuracy[level], device=device 112 | ) 113 | 114 | original_symbolic_rewards = symbolic_rewards 115 | 116 | symbolic_rewards = torch.tensor(symbolic_rewards, device=device) 117 | outcome_accuracy = torch.tensor(outcome_accuracy, device=device) 118 | 119 | ret_dict = { 120 | "symbolic_rewards": symbolic_rewards, 121 | "outcome_accuracy": outcome_accuracy, 122 | } 123 | 124 | for level in sorted(list(all_unique_levels)): 125 | ret_dict[f"level_{level}_counts"] = per_level_counts[level] 126 | ret_dict[f"level_{level}_accuracy"] = per_level_accuracy[level] 127 | 128 | return ret_dict, original_symbolic_rewards 129 | def merge_fn(tensor_or_list): 130 | if isinstance(tensor_or_list[0], list): 131 | return list(chain(*tensor_or_list)) 132 | else: 133 | return torch.cat(tensor_or_list, dim=0) 134 | def post_process_math_rollouts( 135 | text_responses: List[str], 136 | answers: List[str], 137 | gt_answers: List[str], 138 | levels: List[str], 139 | tokenizer: AcceleraTokenizer, 140 | stop_token: Optional[str], 141 | outcome_reward: bool, 142 | easy_outcome_reward: bool, 143 | device: torch.device, 144 | ): 145 | if stop_token is not None: 146 | parsed_stop_token = stop_token 147 | parsed_stop_token = parsed_stop_token.replace(r"\n", "\n") 148 | parsed_stop_token = parsed_stop_token.replace(r"\\", "\\") 149 | else: 150 | parsed_stop_token = tokenizer.eos_token 151 | 152 | predicted_answers = [] 153 | for text_response in text_responses: 154 | predicted_answer = "No answer found." 
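# Answer extraction depends on the stop token: a stop token containing "\n\n" means the final
# answer is the text between the stop token and EOS, while the literal "\boxed{}" stop token
# means the answer is recovered from the last \boxed{...} span via last_boxed_only_string/remove_boxed.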
155 | if "\n\n" in parsed_stop_token: 156 | if parsed_stop_token in text_response: 157 | predicted_answer = text_response.split(parsed_stop_token)[1] 158 | predicted_answer = predicted_answer.split(tokenizer.eos_token)[0] 159 | elif "\\boxed{}" == parsed_stop_token: 160 | boxed_predicted_answer = text_response.split(tokenizer.eos_token)[0] 161 | boxed_predicted_answer = remove_boxed( 162 | last_boxed_only_string(boxed_predicted_answer) 163 | ) 164 | if boxed_predicted_answer is not None: 165 | predicted_answer = boxed_predicted_answer 166 | else: 167 | raise ValueError(f"Unknown stop token: {parsed_stop_token}") 168 | predicted_answers.append(predicted_answer) 169 | 170 | # text_answers_gt_levels = tokenizer.batch_decode( 171 | # answer_gt_levels, 172 | # skip_special_tokens=True, 173 | # ) 174 | 175 | #answers, gt_answers, levels = [], [], [] 176 | # for text_answers_gt_level in text_answers_gt_levels: 177 | # assert len(text_answers_gt_level.split(";;;")) == 3, text_answers_gt_level 178 | # answer, gt_answer, level = text_answers_gt_level.split(";;;") 179 | # answers.append(answer.strip()) 180 | # gt_answers.append(gt_answer.strip()) 181 | # levels.append(int(level.strip())) 182 | 183 | outcome_metrics, symbolic_rewards = _calculate_outcome_accuracy( 184 | predicted_answers, 185 | gt_answers, 186 | answers, 187 | levels, 188 | outcome_reward, 189 | easy_outcome_reward, 190 | device, 191 | ) 192 | return ( 193 | predicted_answers, 194 | gt_answers, 195 | levels, 196 | symbolic_rewards, 197 | outcome_metrics, 198 | ) 199 | def main( 200 | tokenizer_path: Path = Path( 201 | "/ssddata/weihao00/model_zoo/llemma_7b/checkpoints/EleutherAI/llemma_7b/tokenizer.model" 202 | ), 203 | 204 | answer_file: Path = Path( 205 | "/ssddata/weihao00/easy2hard/easy-to-hard-main/data/test_ppo.json" 206 | ), 207 | output_file: Path = Path( 208 | "/ssddata/weihao00/easy2hard/save_file/test_ppo_infer_metric.json" 209 | ), 210 | 211 | 212 | 213 | ): 214 | 215 | tokenizer = AcceleraTokenizer(tokenizer_path) 216 | tokenizer.pad_id = tokenizer.unk_id 217 | 218 | with open(answer_file, "r") as r: 219 | test_ppo = json.load(r) 220 | 221 | 222 | # with open(prompt_file , "r") as r: 223 | # data_lines = r.readlines() 224 | 225 | # data_json = [json.loads(l) for l in data_lines] 226 | #test_ppo = data_json 227 | # for idx, item in enumerate(test_ppo): 228 | # item["idx"] = idx 229 | 230 | # for i in data_json: 231 | # if i["idx"] == idx: 232 | # item["prompt"] = i["prompt"] 233 | 234 | # item["output"] = i["output"] 235 | 236 | 237 | test_queries = [] 238 | test_responses = [] 239 | answers_list = [] 240 | gt_answers_list = [] 241 | levels_list = [] 242 | for item in test_ppo: 243 | test_queries.append(item["input"]) 244 | test_responses.append(item["output0"]) 245 | answers_list.append(item["answer"]) 246 | gt_answers_list.append(item["gt_answer"]) 247 | levels_list.append(item["level"]) 248 | 249 | eval_rollouts_batch = {} 250 | eval_rollouts_batch["text_queries"] = test_queries 251 | eval_rollouts_batch["text_responses"] = test_responses 252 | 253 | outcome_metrics = post_process_math_rollouts(test_responses, answers_list, gt_answers_list, levels_list, tokenizer, "\n\n# Answer\n\n", False, False, torch.device('cpu')) 254 | 255 | eval_rollouts_batch.update(outcome_metrics[-1]) 256 | cpu_eval_rollouts = [] 257 | 258 | cpu_eval_rollouts.append( 259 | { 260 | key: value.cpu() if torch.is_tensor(value) else value 261 | for key, value in eval_rollouts_batch.items() 262 | } 263 | ) 264 | eval_rollouts = cpu_eval_rollouts 265 
| 266 | eval_rollouts = common_utils.merge_dict(eval_rollouts, merge_fn=merge_fn) 267 | 268 | eval_stats = {} 269 | overall_counts = 0.0 270 | overall_accuracy = 0.0 271 | for level in range(9): 272 | if f"level_{level}_counts" in eval_rollouts: 273 | level_counts = eval_rollouts[f"level_{level}_counts"].sum() 274 | level_accuracy = eval_rollouts[f"level_{level}_accuracy"].sum() 275 | overall_counts += level_counts 276 | overall_accuracy += level_accuracy 277 | 278 | eval_stats[f"accuracy_level_{level}"] = level_accuracy / level_counts 279 | 280 | eval_stats[f"counts_level_{level}"] = level_counts 281 | 282 | eval_stats[f"accuracy_overall"] = overall_accuracy.view(1) / ( 283 | overall_counts.view(1) 284 | ) 285 | eval_stats[f"counts_overall"] = overall_counts 286 | eval_stats = { 287 | key: value.item() if torch.is_tensor(value) else value 288 | for key, value in eval_stats.items() 289 | } 290 | print(eval_stats) 291 | with open(output_file, "w") as w: 292 | json.dump(eval_stats, w) 293 | 294 | if __name__ == "__main__": 295 | 296 | import argparse 297 | 298 | parser = argparse.ArgumentParser(description="Your CLI description.") 299 | 300 | parser.add_argument( 301 | "--tokenizer_path", 302 | type=Path, 303 | required=True, 304 | help="File containing prompts, one per line.", 305 | ) 306 | parser.add_argument( 307 | "--answer_file", 308 | type=Path, 309 | required=True, 310 | help="File containing prompts, one per line.", 311 | ) 312 | parser.add_argument( 313 | "--output_file", 314 | type=Path, 315 | required=True, 316 | help="File to write generated samples to.", 317 | ) 318 | args = parser.parse_args() 319 | main( 320 | args.tokenizer_path, 321 | args.answer_file, 322 | args.output_file, 323 | 324 | ) 325 | 326 | 327 | # tokenizer = AcceleraTokenizer("/ssddata/weihao00/model_zoo/llemma_7b/checkpoints/EleutherAI/llemma_7b/tokenizer.model") 328 | # tokenizer.pad_id = tokenizer.unk_id 329 | # with open("/ssddata/weihao00/easy2hard/easy-to-hard-main/data/test_ppo.json", "r") as r: 330 | # test_ppo = json.load(r) 331 | 332 | # with open("/ssddata/weihao00/easy2hard/save_file/test_ppo_infer.json" , "r") as r: 333 | # data_lines = r.readlines() 334 | 335 | 336 | # data_set = set() 337 | 338 | # data_json = [json.loads(l) for l in data_lines] 339 | 340 | # # for item in data_json: 341 | # # data_set.add(item) 342 | 343 | 344 | # for idx, item in enumerate(test_ppo): 345 | # item["idx"] = idx 346 | 347 | # for i in data_json: 348 | # if i["idx"] == idx: 349 | # item["prompt"] = i["prompt"] 350 | 351 | # item["output"] = i["output"] 352 | 353 | 354 | # test_queries = [] 355 | # test_responses = [] 356 | # answers_list = [] 357 | # gt_answers_list = [] 358 | # levels_list = [] 359 | # for item in test_ppo: 360 | # test_queries.append(item["input"]) 361 | # test_responses.append(item["output"]) 362 | # answers_list.append(item["answer"]) 363 | # gt_answers_list.append(item["gt_answer"]) 364 | # levels_list.append(item["level"]) 365 | # eval_rollouts_batch = {} 366 | # eval_rollouts_batch["text_queries"] = test_queries 367 | # eval_rollouts_batch["text_responses"] = test_responses 368 | 369 | # outcome_metrics = post_process_math_rollouts(test_responses, answers_list, gt_answers_list, levels_list, tokenizer, "\n\n# Answer\n\n", False, False, "cuda:6") 370 | 371 | # eval_rollouts_batch.update(outcome_metrics[-1]) 372 | # cpu_eval_rollouts = [] 373 | 374 | # cpu_eval_rollouts.append( 375 | # { 376 | # key: value.cpu() if torch.is_tensor(value) else value 377 | # for key, value in 
eval_rollouts_batch.items() 378 | # } 379 | # ) 380 | # eval_rollouts = cpu_eval_rollouts 381 | 382 | # eval_rollouts = common_utils.merge_dict(eval_rollouts, merge_fn=merge_fn) 383 | # # filtered_eval_rollouts = {} 384 | 385 | # # for key, value in eval_rollouts.items(): 386 | # # filtered_eval_rollouts[key] = value[:eval_data_size] 387 | # # eval_rollouts = filtered_eval_rollouts 388 | 389 | # eval_stats = {} 390 | # overall_counts = 0.0 391 | # overall_accuracy = 0.0 392 | # for level in range(9): 393 | # if f"level_{level}_counts" in eval_rollouts: 394 | # level_counts = eval_rollouts[f"level_{level}_counts"].sum() 395 | # level_accuracy = eval_rollouts[f"level_{level}_accuracy"].sum() 396 | # overall_counts += level_counts 397 | # overall_accuracy += level_accuracy 398 | 399 | # eval_stats[f"accuracy_level_{level}"] = level_accuracy / level_counts 400 | 401 | # eval_stats[f"counts_level_{level}"] = level_counts 402 | 403 | # eval_stats[f"accuracy_overall"] = overall_accuracy.view(1) / ( 404 | # overall_counts.view(1) 405 | # ) 406 | # eval_stats[f"counts_overall"] = overall_counts 407 | # eval_stats = { 408 | # key: value.item() if torch.is_tensor(value) else value 409 | # for key, value in eval_stats.items() 410 | # } 411 | # print("bupt") 412 | 413 | 414 | 415 | 416 | 417 | 418 | 419 | 420 | 421 | 422 | 423 | 424 | 425 | # print("bupt") -------------------------------------------------------------------------------- /train_code/convert4reward_auto_ground_sample.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import os 4 | 5 | def process_files(base_path, output_path, num_files, sample_num): 6 | data_list = [] 7 | 8 | if num_files == -1: 9 | with open(base_path, "r") as r: 10 | data_list = json.load(r) 11 | else: 12 | for i in range(1, num_files + 1): 13 | file_path = base_path.format(i) 14 | if not os.path.exists(file_path): 15 | print(f"File {file_path} does not exist, skipping.") 16 | continue 17 | with open(file_path, "r") as r: 18 | data = json.load(r) 19 | data_list.extend(data) 20 | 21 | trn_json = [] 22 | sample_json = [] 23 | 24 | for item in data_list: 25 | trn_json.append(item) 26 | 27 | data_json = trn_json 28 | format_data_json = [] 29 | 30 | for idx, item in enumerate(data_json): 31 | for i in range(sample_num): 32 | temp_json = { 33 | "idx": idx, 34 | "sample_idx": i, 35 | "prompt": item["prompt"], 36 | "response": item["output"], 37 | "output": item["output" + str(i)] 38 | } 39 | 40 | if "\n\n# Answer\n\n" in temp_json["output"]: 41 | format_data_json.append(temp_json) 42 | 43 | with open(output_path, "w") as w: 44 | for item in format_data_json: 45 | w.write(json.dumps(item)) 46 | w.write("\n") 47 | 48 | if __name__ == "__main__": 49 | parser = argparse.ArgumentParser(description="Process JSON files and output formatted data.") 50 | parser.add_argument("--input_path", type=str, required=True, help="Base path for input files, use {} for file number placeholder") 51 | parser.add_argument("--output_path", type=str, required=True, help="Path for output file") 52 | parser.add_argument("--num_files", type=int, required=True, help="Number of input files to process") 53 | parser.add_argument("--sample_num", type=int, required=True, help="Number of samples to process") 54 | 55 | args = parser.parse_args() 56 | 57 | process_files(args.input_path, args.output_path, args.num_files, args.sample_num) 58 | -------------------------------------------------------------------------------- 
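A note on `convert4reward_auto_ground_sample.py` above: it flattens multi-sample generations into one record per sample for reward scoring, keeping only completions that contain the `\n\n# Answer\n\n` marker. A minimal illustrative call is sketched below; the paths, shard count, and samples-per-prompt are assumptions, and each input record is expected to provide `prompt`, `output`, and `output0` ... `output{n}` fields.

```python
# Illustrative sketch only: file locations and counts are placeholders.
from convert4reward_auto_ground_sample import process_files

process_files(
    base_path="rollouts/rollout_part{}.json",  # "{}" is filled with the shard index 1..num_files
    output_path="rollouts/for_reward.jsonl",   # one JSON line per kept sample: idx, sample_idx, prompt, response, output
    num_files=4,                               # pass -1 to read base_path as a single JSON file instead
    sample_num=8,                              # reads keys "output0" .. "output7" from every record
)
```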
/train_code/convert_auto.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | def convert_checkpoints(checkpoint_dir, pretrain_name, tokenizer_name): 5 | # Define the checkpoint folder path and other parameters 6 | last_checkpoint_file = os.path.join(checkpoint_dir, "last_checkpoint") 7 | 8 | # Create the folder that will hold the converted files 9 | converted_dir = os.path.join(checkpoint_dir, "converted") 10 | os.makedirs(converted_dir, exist_ok=True) 11 | 12 | max_step = -1 13 | latest_ckpt_file = None 14 | 15 | # Iterate over all checkpoint files in the folder 16 | for file_name in os.listdir(checkpoint_dir): 17 | if file_name.endswith(".pt"): 18 | # Update the contents of the last_checkpoint file 19 | with open(last_checkpoint_file, 'w') as file: 20 | file.write(file_name) 21 | # Extract the step number 22 | step_number = int(file_name.split('_')[3]) 23 | 24 | # Track the latest ckpt file 25 | if step_number > max_step: 26 | max_step = step_number 27 | latest_ckpt_file = file_name 28 | 29 | save_name_hf = os.path.join(converted_dir, f"ckpt{step_number}") 30 | 31 | # Skip if the converted output already exists 32 | if os.path.exists(save_name_hf): 33 | print(f"{save_name_hf} already exists. Skipping...") 34 | continue 35 | 36 | # Build and run the conversion command 37 | command = f""" 38 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python convert_checkpoint_to_hf.py \\ 39 | --tp_ckpt_name {checkpoint_dir} \\ 40 | --pretrain_name {pretrain_name} \\ 41 | --tokenizer_name {tokenizer_name} \\ 42 | --save_name_hf {save_name_hf} 43 | """ 44 | print(f"Executing command: {command}") 45 | os.system(command) 46 | 47 | # After all conversions are done, update the contents of the last_checkpoint file 48 | if latest_ckpt_file: 49 | with open(last_checkpoint_file, 'w') as file: 50 | file.write(latest_ckpt_file) 51 | print(f"Updated {last_checkpoint_file} with {latest_ckpt_file}") 52 | 53 | if __name__ == "__main__": 54 | parser = argparse.ArgumentParser(description='Convert checkpoints to Hugging Face format') 55 | parser.add_argument('--checkpoint_dir', type=str, required=True, help='Path to the checkpoint directory') 56 | parser.add_argument('--pretrain_name', type=str, required=True, help='Name of the pre-trained model') 57 | parser.add_argument('--tokenizer_name', type=str, required=True, help='Name of the tokenizer') 58 | 59 | args = parser.parse_args() 60 | 61 | convert_checkpoints(args.checkpoint_dir, args.pretrain_name, args.tokenizer_name) 62 | -------------------------------------------------------------------------------- /train_code/convert_checkpoint_to_hf.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig 2 | from tqdm import tqdm 3 | import torch 4 | import re 5 | import argparse 6 | import os 7 | import glob 8 | 9 | # we need to check that we have logged into the HF account 10 | # !huggingface-cli whoami 11 | # !huggingface-cli login 12 | 13 | 14 | def load_and_merge_models( 15 | tp_ckpt_name, pretrain_name, tokenizer_name, save_name_hf, push_to_hf_hub_name 16 | ): 17 | assert ( 18 | save_name_hf or push_to_hf_hub_name 19 | ), "Please provide a save path or push to HF hub name" 20 | 21 | tp_model_list = [] 22 | 23 | last_checkpoint_file = os.path.join(tp_ckpt_name, "last_checkpoint") 24 | with open(last_checkpoint_file, "r") as f: 25 | last_checkpoint_file = f.readline().strip() 26 | 27 | last_checkpoint_file = last_checkpoint_file.split("/")[-1] 28 | last_checkpoint_file = os.path.join(tp_ckpt_name, last_checkpoint_file) 29 | 30 | print("Loading checkpoint files:", last_checkpoint_file) 31 | for file in sorted(glob.glob(last_checkpoint_file)): 32 | tp_model_list.append( 33 |
torch.load( 34 | file, 35 | mmap=True, 36 | )["model"] 37 | ) 38 | 39 | print("Loading HF model...") 40 | tokenizer = AutoTokenizer.from_pretrained( 41 | tokenizer_name, 42 | ) 43 | 44 | model = AutoModelForCausalLM.from_pretrained( 45 | pretrain_name, 46 | # device_map="cpu", 47 | load_in_8bit=False, 48 | torch_dtype=torch.bfloat16, 49 | ) 50 | cpu_state_dict = model.cpu().state_dict() 51 | 52 | replaced_keys = set() 53 | 54 | print("Convert to HF model...") 55 | num_tp = len(tp_model_list) 56 | 57 | state_dict = {} 58 | 59 | for key in tp_model_list[0].keys(): 60 | if "wo" in key or "w2" in key: 61 | state_dict[key] = torch.cat( 62 | [tp_model_list[i][key].cpu() for i in range(num_tp)], dim=1 63 | ) 64 | elif "wqkv" in key: 65 | state_dict[key] = torch.stack( 66 | [tp_model_list[i][key].cpu() for i in range(num_tp)], dim=0 67 | ) 68 | elif "output" in key: 69 | state_dict[key] = torch.cat( 70 | [tp_model_list[i][key].cpu() for i in range(num_tp)], dim=1 71 | ) 72 | else: 73 | state_dict[key] = torch.cat( 74 | [tp_model_list[i][key].cpu() for i in range(num_tp)], dim=0 75 | ) 76 | 77 | pattern = r"layers\.(\d+)\." 78 | 79 | for key in state_dict.keys(): 80 | layer = None 81 | match = re.search(pattern, key) 82 | # layer number except for: 83 | # lm_head.weight 84 | if match: 85 | layer = match.group(1) 86 | elif "output.weight" in key: 87 | name = f"lm_head.weight" 88 | print(cpu_state_dict[name].size(), state_dict[key].size()) 89 | # repeat on dim 0 to match the size 90 | repeat_size = cpu_state_dict[name].size(0) // state_dict[key].size(0) 91 | new_state_dict = state_dict[key].repeat(repeat_size, 1) 92 | cpu_state_dict[name] = 0.0 * cpu_state_dict[name] + new_state_dict 93 | replaced_keys.add(name) 94 | else: 95 | raise ValueError(f"Invalid key: {key}") 96 | 97 | print("Converting layer", key) 98 | if "wqkv" in key: 99 | merged_q, merged_k, merged_v = [], [], [] 100 | reconstruct_q, reconstruct_k = [], [] 101 | 102 | print("state_dict[key].size(2)", state_dict[key].size(2)) 103 | 104 | 105 | 106 | if state_dict[key].size(2) == 4096: 107 | if "Mistral-7B-v0.1" in pretrain_name: 108 | n_heads, n_local_heads = 32, 8 109 | print("Mistral-7B-v0.1") 110 | 111 | elif "Mistral-7B" in pretrain_name: 112 | n_heads, n_local_heads = 32, 8 113 | print("Mistral-7B") 114 | 115 | elif "Meta-Llama-3-8B" in pretrain_name: 116 | n_heads, n_local_heads = 32, 8 117 | print("Meta-Llama-3-8B") 118 | 119 | elif "Llama-3.1-8B" in pretrain_name: 120 | n_heads, n_local_heads = 32, 8 121 | print("Meta-Llama-3-8B") 122 | elif "Meta-Llama-3-70B" in pretrain_name: 123 | n_heads, n_local_heads = 64, 8 124 | print("Meta-Llama-3-70B") 125 | else: 126 | 127 | n_heads, n_local_heads = 32, 32 128 | elif state_dict[key].size(2) == 5120: 129 | n_heads, n_local_heads = 40, 40 130 | elif state_dict[key].size(2) == 6656: 131 | n_heads, n_local_heads = 52, 52 132 | elif state_dict[key].size(2) == 8192: 133 | n_heads, n_local_heads = 64, 8 134 | else: 135 | raise ValueError(f"Invalid size for {key}: {state_dict[key].size()}") 136 | 137 | head_dim = state_dict[key].size(1) // (n_heads + n_local_heads * 2) 138 | 139 | weight_splits = [ 140 | head_dim * n_heads, 141 | head_dim * n_local_heads, 142 | head_dim * n_local_heads, 143 | ] 144 | 145 | for split_idx in range(state_dict[key].size(0)): 146 | chunk = state_dict[key][split_idx] 147 | q, k, v = chunk.split(weight_splits, dim=0) 148 | merged_q.append(q) 149 | merged_k.append(k) 150 | merged_v.append(v) 151 | merged_q = torch.cat(merged_q, dim=0) 152 | merged_k = 
torch.cat(merged_k, dim=0) 153 | merged_v = torch.cat(merged_v, dim=0) 154 | 155 | #### qk need reconstruction #### 156 | split_qs = torch.split(merged_q, split_size_or_sections=128, dim=0) 157 | split_ks = torch.split(merged_k, split_size_or_sections=128, dim=0) 158 | for split in split_qs: 159 | matrix0 = split[::2, :] 160 | matrix1 = split[1::2, :] 161 | reconstruct_q.append(matrix0) 162 | reconstruct_q.append(matrix1) 163 | reconstruct_q = torch.cat(reconstruct_q, dim=0) 164 | for split in split_ks: 165 | matrix0 = split[::2, :] 166 | matrix1 = split[1::2, :] 167 | reconstruct_k.append(matrix0) 168 | reconstruct_k.append(matrix1) 169 | reconstruct_k = torch.cat(reconstruct_k, dim=0) 170 | #### qk need reconstruction #### 171 | 172 | name = f"model.layers.{layer}.self_attn.q_proj.weight" 173 | cpu_state_dict[name] = reconstruct_q 174 | replaced_keys.add(name) 175 | 176 | name = f"model.layers.{layer}.self_attn.k_proj.weight" 177 | cpu_state_dict[name] = reconstruct_k 178 | replaced_keys.add(name) 179 | 180 | name = f"model.layers.{layer}.self_attn.v_proj.weight" 181 | cpu_state_dict[name] = merged_v 182 | replaced_keys.add(name) 183 | 184 | if "wo" in key: 185 | name = f"model.layers.{layer}.self_attn.o_proj.weight" 186 | cpu_state_dict[name] = state_dict[key] 187 | replaced_keys.add(name) 188 | if "w1" in key: 189 | name = f"model.layers.{layer}.mlp.gate_proj.weight" 190 | cpu_state_dict[name] = state_dict[key] 191 | replaced_keys.add(name) 192 | if "w3" in key: 193 | name = f"model.layers.{layer}.mlp.up_proj.weight" 194 | cpu_state_dict[name] = state_dict[key] 195 | replaced_keys.add(name) 196 | if "w2" in key: 197 | name = f"model.layers.{layer}.mlp.down_proj.weight" 198 | cpu_state_dict[name] = state_dict[key] 199 | replaced_keys.add(name) 200 | 201 | unreplaced_keys = set(cpu_state_dict.keys()) - replaced_keys 202 | print("Unreplaced keys:", unreplaced_keys) 203 | 204 | print("Loading state dict...") 205 | 206 | model.load_state_dict(cpu_state_dict, strict=False) 207 | 208 | print("Saving HF model...") 209 | 210 | if save_name_hf is not None: 211 | model.save_pretrained(save_name_hf) 212 | config = AutoConfig.from_pretrained(pretrain_name) 213 | tokenizer.save_pretrained(save_name_hf) 214 | config.save_pretrained(save_name_hf) 215 | else: 216 | model.push_to_hub(push_to_hf_hub_name, private=True, safe_serialization=False) 217 | 218 | 219 | if __name__ == "__main__": 220 | parser = argparse.ArgumentParser(description="Process some integers.") 221 | parser.add_argument( 222 | "--tp_ckpt_name", type=str, help="Path to the TP checkpoint name", required=True 223 | ) 224 | parser.add_argument( 225 | "--tokenizer_name", type=str, help="Path to the tokenizer name", required=True 226 | ) 227 | parser.add_argument( 228 | "--pretrain_name", type=str, help="Path to the pretrain name", required=True 229 | ) 230 | parser.add_argument( 231 | "--save_name_hf", type=str, default=None, help="Path to save the HF model" 232 | ) 233 | parser.add_argument( 234 | "--push_to_hf_hub_name", type=str, default=None, help="Push to HF hub" 235 | ) 236 | 237 | args = parser.parse_args() 238 | load_and_merge_models( 239 | args.tp_ckpt_name, 240 | args.pretrain_name, 241 | args.tokenizer_name, 242 | args.save_name_hf, 243 | args.push_to_hf_hub_name, 244 | ) 245 | -------------------------------------------------------------------------------- /train_code/data_utils/__pycache__/data_utils_ppo.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/data_utils/__pycache__/data_utils_ppo.cpython-310.pyc -------------------------------------------------------------------------------- /train_code/data_utils/__pycache__/data_utils_ppo.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/data_utils/__pycache__/data_utils_ppo.cpython-311.pyc -------------------------------------------------------------------------------- /train_code/data_utils/__pycache__/data_utils_rm_pointwise.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/data_utils/__pycache__/data_utils_rm_pointwise.cpython-310.pyc -------------------------------------------------------------------------------- /train_code/data_utils/data_utils_dpo.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The GPT-Accelera Team 2 | # Copyright 2023 The Alpaca Team 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
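# data_utils_dpo.py: builds winner/loser input_ids and labels tensors for DPO from pairwise
# preference records (output_1 / output_2 plus a 1-or-2 preference), reusing preprocess_for_sft
# from data_utils_sft.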
15 | 16 | from typing import Dict, Sequence, Union 17 | 18 | import numpy as np 19 | import torch 20 | from torch.utils.data import Dataset 21 | from datasets import Dataset as HFDataset 22 | 23 | from arguments import Arguments 24 | import trainers.common_utils as utils 25 | from models.tokenizer_utils import AcceleraTokenizer 26 | from data_utils.data_utils_sft import preprocess_for_sft, extract_alpaca_dataset 27 | 28 | 29 | class DPODataset(Dataset): 30 | def __init__( 31 | self, 32 | args: Arguments, 33 | dataset: HFDataset, 34 | tokenizer: AcceleraTokenizer, 35 | ): 36 | super(DPODataset, self).__init__() 37 | self.tensors = preprocess_for_dpo( 38 | args=args, 39 | dataset=dataset, 40 | tokenizer=tokenizer, 41 | ) 42 | 43 | def __len__(self): 44 | return len(next(iter(self.tensors.values()))) 45 | 46 | def __getitem__(self, i) -> Dict[str, torch.Tensor]: 47 | return {key: value[i] for key, value in self.tensors.items()} 48 | 49 | 50 | def preprocess_for_dpo( 51 | args: Arguments, 52 | dataset: HFDataset, 53 | tokenizer: AcceleraTokenizer, 54 | reorder_wl: bool = True, 55 | ) -> dict[str, Union[torch.Tensor, Sequence[torch.Tensor]]]: 56 | df = dataset.to_pandas() 57 | output_1, output_2, preference = df["output_1"], df["output_2"], df["preference"] 58 | 59 | assign_w_kwargs = dict( 60 | output=np.where(preference == 1, output_1, output_2), 61 | ) 62 | assign_l_kwargs = dict( 63 | output=np.where(preference == 2, output_1, output_2), 64 | ) 65 | assign_keys = ["instruction", "input", "output"] 66 | 67 | if "is_eos_1" in df.columns: 68 | is_eos_1, is_eos_2 = df["is_eos_1"], df["is_eos_2"] 69 | assign_w_kwargs.update( 70 | is_eos=np.where(preference == 1, is_eos_1, is_eos_2), 71 | ) 72 | assign_l_kwargs.update( 73 | is_eos=np.where(preference == 2, is_eos_1, is_eos_2), 74 | ) 75 | assign_keys.extend(["is_eos"]) 76 | 77 | if "win_rate_1" in df.columns: 78 | win_rate_1, win_rate_2 = df["win_rate_1"], df["win_rate_2"] 79 | assign_w_kwargs.update( 80 | win_rate=np.where(preference == 1, win_rate_1, win_rate_2), 81 | ) 82 | assign_l_kwargs.update( 83 | win_rate=np.where(preference == 2, win_rate_1, win_rate_2), 84 | ) 85 | assign_keys.extend(["win_rate"]) 86 | 87 | if reorder_wl: 88 | df_w = df.assign(**assign_w_kwargs)[assign_keys] 89 | df_l = df.assign(**assign_l_kwargs)[assign_keys] 90 | else: 91 | df_w = df.assign(output=output_1)[assign_w_kwargs] 92 | df_l = df.assign(output=output_2)[assign_l_kwargs] 93 | 94 | df_w_list = df_w.to_dict("records") 95 | df_l_list = df_l.to_dict("records") 96 | 97 | assert len(df_w_list) == len(df_l_list) 98 | 99 | if args.dataset_format == "alpaca": 100 | for i in range(len(df_w_list)): 101 | df_w_list[i].update(extract_alpaca_dataset(df_w_list[i])) 102 | df_l_list[i].update(extract_alpaca_dataset(df_l_list[i])) 103 | elif args.dataset_format is None: 104 | pass 105 | else: 106 | raise ValueError(f"Unknown dataset format: {args.dataset_format}") 107 | 108 | tensors_w = preprocess_for_sft( 109 | instances=df_w_list, 110 | tokenizer=tokenizer, 111 | source_max_len=args.source_max_len, 112 | target_max_len=args.target_max_len, 113 | total_max_len=args.total_max_len, 114 | train_on_source=args.train_on_source, 115 | add_eos_to_target=args.add_eos_to_target, 116 | add_eos_to_marked_target=args.add_eos_to_marked_target, 117 | return_win_rate=True, 118 | ) 119 | tensors_l = preprocess_for_sft( 120 | instances=df_l_list, 121 | tokenizer=tokenizer, 122 | source_max_len=args.source_max_len, 123 | target_max_len=args.target_max_len, 124 | 
total_max_len=args.total_max_len, 125 | train_on_source=args.train_on_source, 126 | add_eos_to_target=args.add_eos_to_target, 127 | add_eos_to_marked_target=args.add_eos_to_marked_target, 128 | return_win_rate=True, 129 | ) 130 | return dict( 131 | input_ids_w=tensors_w["input_ids"], 132 | labels_w=tensors_w["labels"], 133 | win_rate_w=tensors_w["win_rate"], 134 | input_ids_l=tensors_l["input_ids"], 135 | labels_l=tensors_l["labels"], 136 | win_rate_l=tensors_l["win_rate"], 137 | ) 138 | 139 | 140 | def make_dpo_data_module( 141 | tokenizer: AcceleraTokenizer, 142 | args: Arguments, 143 | ) -> dict: 144 | preference_dataset = utils.local_dataset(args.dataset) 145 | train_preference = preference_dataset["train"] 146 | 147 | train_dataset = DPODataset( 148 | args=args, 149 | dataset=train_preference, 150 | tokenizer=tokenizer, 151 | ) 152 | 153 | eval_dataset = None 154 | if args.eval_size > 0: 155 | train_dataset, eval_dataset = utils.split_train_into_train_and_eval( 156 | train_dataset=train_dataset, 157 | eval_size=args.eval_size, 158 | seed=args.seed, 159 | ) 160 | data_collator = utils.DataCollatorForStackableDataset() 161 | return dict( 162 | train_dataset=train_dataset, 163 | eval_dataset=eval_dataset, 164 | data_collator=data_collator, 165 | ) 166 | -------------------------------------------------------------------------------- /train_code/data_utils/data_utils_ppo.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The GPT-Accelera Team 2 | # Copyright 2023 The Alpaca Team 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import os 17 | from typing import Dict 18 | import logging 19 | 20 | import torch 21 | from torch.utils.data import Dataset 22 | from datasets import Dataset as HFDataset 23 | 24 | from arguments import Arguments 25 | import trainers.common_utils as utils 26 | from models.tokenizer_utils import AcceleraTokenizer 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | 31 | class QueryDataset(Dataset): 32 | """Dataset that emits tokenized left-padded queries.""" 33 | 34 | def __init__( 35 | self, 36 | dataset: HFDataset, 37 | tokenizer: AcceleraTokenizer, 38 | query_len: int, 39 | ): 40 | super(QueryDataset, self).__init__() 41 | 42 | list_dict_data = dataset.to_pandas().to_dict("records") 43 | 44 | # prompts are strings; queries are tensors. 
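# The auxiliary answer string packs "answer ;;; gt_answer ;;; level" into a single field so that
# all three can be tokenized alongside the query and split apart again downstream.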
45 | queries = [dict_data["input"] for dict_data in list_dict_data] 46 | answers = [ 47 | f"{dict_data['answer']} ;;; {dict_data['gt_answer']} ;;; {dict_data['level']}" 48 | for dict_data in list_dict_data 49 | ] 50 | 51 | logger.warning(f"Debugging: {answers[:10]}") 52 | queries = [ 53 | tokenizer(query, return_tensors="pt", truncation=False).input_ids.squeeze( 54 | dim=0 55 | ) 56 | for query in queries 57 | ] 58 | 59 | answers = [ 60 | tokenizer(answer, return_tensors="pt", truncation=False).input_ids.squeeze( 61 | dim=0 62 | ) 63 | for answer in answers 64 | ] 65 | 66 | filtered_queries = [] 67 | filtered_answers = [] 68 | 69 | for query, answer in zip(queries, answers): 70 | if len(query) <= query_len: 71 | filtered_queries.append(query) 72 | filtered_answers.append(answer) 73 | 74 | logger.warning( 75 | f"Filtered out {len(queries) - len(filtered_queries)} instances out of {len(queries)} that " 76 | f"exceed length limit. These examples are not used for training, but will still be used in evaluation. " 77 | ) 78 | 79 | queries = torch.stack( 80 | [ 81 | utils.left_pad(query, target_size=(query_len,), value=tokenizer.pad_id) 82 | for query in filtered_queries 83 | ] 84 | ) 85 | 86 | max_answer_len = max([len(answer) for answer in filtered_answers]) 87 | answers = torch.stack( 88 | [ 89 | utils.left_pad( 90 | answer, 91 | target_size=(max_answer_len,), 92 | value=tokenizer.pad_id, 93 | ) 94 | for answer in filtered_answers 95 | ] 96 | ) 97 | 98 | assert queries.shape[0] == answers.shape[0] 99 | 100 | self.queries = queries 101 | self.query_attn_masks = queries.ne(tokenizer.pad_id).long() 102 | self.answers = answers 103 | # Auxiliary data. 104 | self.list_dict_data = list_dict_data 105 | 106 | def __getitem__(self, i): 107 | return dict( 108 | queries=self.queries[i], 109 | query_attn_masks=self.query_attn_masks[i], 110 | answers=self.answers[i], 111 | ) 112 | 113 | def __len__(self): 114 | return len(self.queries) 115 | 116 | 117 | def make_rl_data_module( 118 | tokenizer: AcceleraTokenizer, 119 | args: Arguments, 120 | ) -> Dict: 121 | """ 122 | Make dataset and collator for supervised fine-tuning. 123 | Datasets are expected to have the following columns: { `input`, `output` } 124 | """ 125 | 126 | def load_data(dataset_name): 127 | if os.path.exists(dataset_name): 128 | try: 129 | full_dataset = utils.local_dataset(dataset_name) 130 | return full_dataset 131 | except: 132 | raise ValueError(f"Error loading dataset from {dataset_name}") 133 | else: 134 | raise NotImplementedError(f"Dataset {dataset_name} not implemented yet.") 135 | 136 | def format_dataset(dataset): 137 | # Remove unused columns. 138 | dataset = dataset.remove_columns( 139 | [ 140 | col 141 | for col in dataset.column_names["train"] 142 | if col not in ["input", "answer", "gt_answer", "level"] 143 | ] 144 | ) 145 | return dataset 146 | 147 | # Load dataset. 
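# Each example is expected to provide the columns input, answer, gt_answer and level;
# format_dataset above drops every other column.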
148 | dataset = load_data(args.dataset) 149 | dataset = format_dataset(dataset) 150 | 151 | # Split train/eval, reduce size 152 | eval_dataset = None 153 | if args.do_eval: 154 | if args.eval_dataset is not None: 155 | eval_dataset = load_data(args.eval_dataset) 156 | eval_dataset = format_dataset(eval_dataset) 157 | eval_dataset = eval_dataset["train"] 158 | else: 159 | print( 160 | "Splitting train dataset in train and validation according to `eval_dataset_size`" 161 | ) 162 | dataset = dataset["train"].train_test_split( 163 | test_size=args.eval_dataset_size, shuffle=True, seed=42 164 | ) 165 | eval_dataset = dataset["test"] 166 | if ( 167 | args.max_eval_samples is not None 168 | and len(eval_dataset) > args.max_eval_samples 169 | ): 170 | eval_dataset = eval_dataset.select(range(args.max_eval_samples)) 171 | 172 | test_dataset = None 173 | if args.do_test: 174 | if args.test_dataset is not None: 175 | test_dataset = load_data(args.test_dataset) 176 | test_dataset = format_dataset(test_dataset) 177 | test_dataset = test_dataset["train"] 178 | else: 179 | raise NotImplementedError("Must specify test dataset if `do_test` is True.") 180 | 181 | train_dataset = dataset["train"] 182 | if ( 183 | args.max_train_samples is not None 184 | and len(train_dataset) > args.max_train_samples 185 | ): 186 | train_dataset = train_dataset.select(range(args.max_train_samples)) 187 | 188 | train_dataset = QueryDataset( 189 | dataset=train_dataset, 190 | tokenizer=tokenizer, 191 | query_len=args.source_max_len, 192 | ) 193 | 194 | if eval_dataset is not None: 195 | eval_dataset = QueryDataset( 196 | dataset=eval_dataset, 197 | tokenizer=tokenizer, 198 | query_len=args.source_max_len, 199 | ) 200 | 201 | if test_dataset is not None: 202 | test_dataset = QueryDataset( 203 | dataset=test_dataset, 204 | tokenizer=tokenizer, 205 | query_len=args.source_max_len, 206 | ) 207 | 208 | data_collator = utils.DataCollatorForStackableDataset() 209 | return dict( 210 | train_dataset=train_dataset, 211 | eval_dataset=eval_dataset, 212 | test_dataset=test_dataset, 213 | data_collator=data_collator, 214 | ) 215 | -------------------------------------------------------------------------------- /train_code/data_utils/data_utils_rm_pairwise.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The GPT-Accelera Team 2 | # Copyright 2023 The Alpaca Team 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
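# data_utils_rm_pairwise.py: pairwise reward-model data pipeline; each record holds
# output_1 / output_2 and a 1-or-2 preference, and both completions are tokenized and stacked
# into a (batch, 2, seq_len) input_ids tensor with a binary choice label.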
15 | 16 | import logging 17 | from typing import Optional, Dict, Sequence 18 | 19 | import torch 20 | from torch.utils.data import Dataset 21 | 22 | from datasets import Dataset as HFDataset 23 | 24 | from arguments import Arguments 25 | import trainers.common_utils as utils 26 | from models.tokenizer_utils import AcceleraTokenizer 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | 31 | DROMEDARY_PROMPT_DICT = { 32 | "prompt_input": ( 33 | "{meta_prompt}\n" "{instruction}\n\n" "{input}\n\n" "### Dromedary" 34 | ), 35 | "prompt_no_input": ("{meta_prompt}\n" "{instruction}\n\n" "### Dromedary"), 36 | } 37 | 38 | 39 | ALPACA_PROMPT_DICT = { 40 | "prompt_input": ( 41 | "Below is an instruction that describes a task, paired with an input that provides further context. " 42 | "Write a response that appropriately completes the request.\n\n" 43 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" 44 | ), 45 | "prompt_no_input": ( 46 | "Below is an instruction that describes a task. " 47 | "Write a response that appropriately completes the request.\n\n" 48 | "### Instruction:\n{instruction}\n\n### Response:\n" 49 | ), 50 | } 51 | 52 | 53 | def format_prompt( 54 | example: Dict[str, str], 55 | prompt_dict: Dict[str, str], 56 | ) -> str: 57 | if prompt_dict is not None: 58 | assert ( 59 | "instruction" in example 60 | ), "Internal error: example missing required keys." 61 | 62 | if example.get("input", "") != "": 63 | prompt_format = prompt_dict["prompt_input"] 64 | else: 65 | prompt_format = prompt_dict["prompt_no_input"] 66 | else: 67 | prompt_format = "{input}" 68 | 69 | format_prompt = prompt_format.format(**example) 70 | return format_prompt 71 | 72 | 73 | def format_output( 74 | example: dict, 75 | output_key="output", 76 | ) -> str: 77 | return example[output_key] 78 | 79 | 80 | def _tokenize_fn( 81 | strings: Sequence[str], 82 | tokenizer: AcceleraTokenizer, 83 | max_length: int, 84 | end_sequence_with_eos: bool, 85 | use_data_frame: bool = False, 86 | ) -> dict: 87 | """Tokenize a list of strings.""" 88 | if use_data_frame: 89 | raise NotImplementedError 90 | strings_ds = strings 91 | 92 | tokenized_strings = tokenizer( 93 | strings_ds, 94 | max_length=max_length, 95 | padding="max_length", 96 | truncation=True, 97 | add_bos=True, 98 | add_eos=True if end_sequence_with_eos else False, 99 | padding_side="right", 100 | truncation_side="right", 101 | ) 102 | 103 | input_ids = torch.stack( 104 | [torch.tensor(tokenized) for tokenized in tokenized_strings["input_ids"]], 105 | dim=0, 106 | ) 107 | 108 | return input_ids 109 | 110 | 111 | def preprocess_for_reward_modeling( 112 | data: HFDataset, 113 | tokenizer: AcceleraTokenizer, 114 | end_sequence_with_eos: bool = False, 115 | max_length: Optional[int] = None, 116 | query_len: Optional[int] = None, 117 | response_len: Optional[int] = None, 118 | prompt_dict: Optional[Dict[str, str]] = None, 119 | ) -> Dict[str, torch.Tensor]: 120 | list_dict_data = data.to_pandas().to_dict("records") 121 | 122 | def _get_numeric_preference(example: dict): 123 | # 1 vs 2 is stored in table, but for modeling we use 0 vs 1; remap here. 
124 | return {1: 0, 2: 1}[example["preference"]] 125 | 126 | choice = torch.tensor( 127 | [[_get_numeric_preference(dict_data)] for dict_data in list_dict_data] 128 | ) 129 | 130 | def _get_text(example: dict, output_key: str): 131 | full_prompt = format_prompt(example, prompt_dict) + format_output( 132 | example, output_key 133 | ) 134 | return full_prompt 135 | 136 | text_list_0, text_list_1 = tuple( 137 | [_get_text(dict_data, key) for dict_data in list_dict_data] 138 | for key in ("output_1", "output_2") 139 | ) 140 | 141 | if max_length is None: 142 | max_length = query_len + response_len 143 | 144 | logger.warning(f"Tokenizing {len(list_dict_data)} pairs...") 145 | tokenized_0, tokenized_1 = tuple( 146 | _tokenize_fn(text_list, tokenizer, max_length, end_sequence_with_eos) 147 | for text_list in (text_list_0, text_list_1) 148 | ) 149 | # "size" (bsz, 2, seq_len) 150 | input_ids = torch.stack( 151 | [tokenized_0, tokenized_1], 152 | dim=1, 153 | ) 154 | 155 | packaged_data = dict( 156 | input_ids=input_ids, 157 | choice=choice, 158 | metadata=dict(mean_choice=choice.float().mean().item()), 159 | ) 160 | 161 | return packaged_data 162 | 163 | 164 | class PairwiseRewardModelingDataset(Dataset): 165 | def __init__( 166 | self, 167 | data: HFDataset, 168 | tokenizer: AcceleraTokenizer, 169 | end_sequence_with_eos: bool = False, 170 | max_length: Optional[int] = None, 171 | query_len: Optional[int] = None, 172 | response_len: Optional[int] = None, 173 | prompt_dict: Optional[Dict[str, str]] = None, 174 | ): 175 | super(PairwiseRewardModelingDataset, self).__init__() 176 | data_dict = preprocess_for_reward_modeling( 177 | data=data, 178 | tokenizer=tokenizer, 179 | end_sequence_with_eos=end_sequence_with_eos, 180 | max_length=max_length, 181 | query_len=query_len, 182 | response_len=response_len, 183 | prompt_dict=prompt_dict, 184 | ) 185 | self.input_ids = data_dict["input_ids"] 186 | self.choice = data_dict["choice"] 187 | self.metadata = data_dict["metadata"] 188 | 189 | def __len__(self): 190 | return len(self.input_ids) 191 | 192 | def __getitem__(self, i) -> Dict[str, torch.Tensor]: 193 | return dict( 194 | input_ids=self.input_ids[i], 195 | choice=self.choice[i], 196 | ) 197 | 198 | 199 | def make_pairwise_reward_modeling_data_module( 200 | tokenizer: AcceleraTokenizer, 201 | args: Arguments, 202 | ): 203 | preference_dataset = utils.local_dataset(args.dataset) 204 | train_preference = preference_dataset["train"] 205 | 206 | if args.dataset_format == "alpaca": 207 | prompt_dict = ALPACA_PROMPT_DICT 208 | elif args.dataset_format is None: 209 | prompt_dict = None 210 | else: 211 | raise ValueError( 212 | f"Unsupported dataset_format: {args.dataset_format}." 213 | "Only alpaca and None are supported." 
214 | ) 215 | 216 | train_dataset = PairwiseRewardModelingDataset( 217 | data=train_preference, 218 | tokenizer=tokenizer, 219 | end_sequence_with_eos=args.add_eos_to_target, 220 | max_length=args.total_max_len, 221 | query_len=args.source_max_len, 222 | response_len=args.target_max_len, 223 | prompt_dict=prompt_dict, 224 | ) 225 | 226 | eval_dataset = None 227 | if args.eval_size > 0: 228 | train_dataset, eval_dataset = utils.split_train_into_train_and_eval( 229 | train_dataset=train_dataset, 230 | eval_size=args.eval_size, 231 | seed=args.seed, 232 | ) 233 | 234 | data_collator = utils.DataCollatorForStackableDataset() 235 | return dict( 236 | train_dataset=train_dataset, 237 | eval_dataset=eval_dataset, 238 | data_collator=data_collator, 239 | ) 240 | -------------------------------------------------------------------------------- /train_code/data_utils/data_utils_sft.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The GPT-Accelera Team 2 | # Copyright 2023 The Alpaca Team 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import os 17 | from dataclasses import dataclass 18 | import logging 19 | from typing import Dict, Sequence, Union 20 | 21 | import torch 22 | 23 | from datasets import load_dataset 24 | 25 | from arguments import Arguments 26 | import trainers.common_utils as utils 27 | from models.tokenizer_utils import AcceleraTokenizer 28 | 29 | logger = logging.getLogger(__name__) 30 | 31 | DROMEDARY_PROMPT_DICT = { 32 | "prompt_input": ( 33 | "{meta_prompt}\n" "{instruction}\n\n" "{input}\n\n" "### Dromedary" 34 | ), 35 | "prompt_no_input": ("{meta_prompt}\n" "{instruction}\n\n" "### Dromedary"), 36 | } 37 | 38 | ALPACA_PROMPT_DICT = { 39 | "prompt_input": ( 40 | "Below is an instruction that describes a task, paired with an input that provides further context. " 41 | "Write a response that appropriately completes the request.\n\n" 42 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" 43 | ), 44 | "prompt_no_input": ( 45 | "Below is an instruction that describes a task. 
" 46 | "Write a response that appropriately completes the request.\n\n" 47 | "### Instruction:\n{instruction}\n\n### Response:\n" 48 | ), 49 | } 50 | 51 | 52 | def preprocess_for_sft( 53 | instances: Sequence[Dict], 54 | tokenizer: AcceleraTokenizer, 55 | source_max_len: int, 56 | target_max_len: int, 57 | total_max_len: int, 58 | train_on_source: bool, 59 | add_eos_to_target: bool, 60 | add_eos_to_marked_target: bool, 61 | return_win_rate: bool = False, 62 | ) -> Dict[str, Union[torch.Tensor, Sequence[torch.Tensor]]]: 63 | # Extract elements 64 | sources = [example["input"] for example in instances] 65 | targets = [f"\n{example['output']}" for example in instances] 66 | 67 | begin_padding_len = tokenizer( 68 | ["\n"], return_tensors="pt", add_bos=False, add_eos=False 69 | ).input_ids.shape[1] 70 | 71 | # Tokenize 72 | tokenized_sources_with_prompt = tokenizer( 73 | sources, 74 | max_length=source_max_len, 75 | padding="max_length", 76 | truncation=True, 77 | add_bos=True, 78 | add_eos=False, 79 | padding_side="left", 80 | truncation_side="left", 81 | ) 82 | 83 | marked_eos = None 84 | if "is_eos" in instances[0] and add_eos_to_marked_target: 85 | marked_eos = [example["is_eos"] for example in instances] 86 | 87 | win_rate = None 88 | if return_win_rate: 89 | if "win_rate" in instances[0]: 90 | win_rate = [example["win_rate"] for example in instances] 91 | else: 92 | win_rate = [0.5 for _ in instances] 93 | 94 | # logger.warning(f"Tokenizing {len(targets)} pairs...") 95 | tokenized_targets = tokenizer( 96 | targets, 97 | max_length=target_max_len + begin_padding_len, 98 | padding="max_length", 99 | truncation=True, 100 | add_bos=False, 101 | add_eos=add_eos_to_target, 102 | marked_eos=marked_eos, 103 | padding_side="right", 104 | truncation_side="right", 105 | ) 106 | # Build the input and labels for causal LM 107 | input_ids = [] 108 | labels = [] 109 | for source_length, tokenized_source, tokenized_target in zip( 110 | tokenized_sources_with_prompt["length"], 111 | tokenized_sources_with_prompt["input_ids"], 112 | tokenized_targets["input_ids"], 113 | ): 114 | tokenized_target = tokenized_target[begin_padding_len:] 115 | full_seq = tokenized_source + tokenized_target 116 | 117 | # move the beginning padding to the end of the full_seq 118 | num_begin_padding = len(tokenized_source) - source_length 119 | full_seq = full_seq[num_begin_padding:] + full_seq[:num_begin_padding] 120 | 121 | if total_max_len is not None: 122 | full_seq = full_seq[:total_max_len] 123 | 124 | # input_ids.append(torch.tensor(full_seq)) 125 | input_ids.append(full_seq) 126 | if not train_on_source: 127 | full_seq_label = ( 128 | [tokenizer.pad_id for _ in range(source_length)] 129 | + tokenized_target 130 | + [tokenizer.pad_id for _ in range(num_begin_padding)] 131 | ) 132 | if total_max_len is not None: 133 | full_seq_label = full_seq_label[:total_max_len] 134 | # labels.append(torch.tensor(full_seq_label)) 135 | labels.append(full_seq_label) 136 | else: 137 | # labels.append(torch.tensor(copy.deepcopy(full_seq))) 138 | labels.append(full_seq) 139 | # Apply padding 140 | # input_ids = pad_sequence( 141 | # input_ids, batch_first=True, padding_value=tokenizer.pad_id 142 | # ) 143 | # labels = pad_sequence(labels, batch_first=True, padding_value=tokenizer.pad_id) 144 | input_ids = torch.tensor(input_ids) 145 | labels = torch.tensor(labels) 146 | data_dict = { 147 | "input_ids": input_ids, 148 | "attention_mask": input_ids.ne(tokenizer.pad_id), 149 | } 150 | if labels is not None: 151 | data_dict["labels"] = labels 
152 | if return_win_rate: 153 | data_dict["win_rate"] = torch.tensor(win_rate).view(-1, 1) 154 | return data_dict 155 | 156 | 157 | @dataclass 158 | class DataCollatorForCausalLM(object): 159 | tokenizer: AcceleraTokenizer 160 | source_max_len: int 161 | target_max_len: int 162 | total_max_len: int 163 | train_on_source: bool 164 | add_eos_to_target: bool 165 | add_eos_to_marked_target: bool 166 | 167 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 168 | return preprocess_for_sft( 169 | instances=instances, 170 | tokenizer=self.tokenizer, 171 | source_max_len=self.source_max_len, 172 | target_max_len=self.target_max_len, 173 | total_max_len=self.total_max_len, 174 | train_on_source=self.train_on_source, 175 | add_eos_to_target=self.add_eos_to_target, 176 | add_eos_to_marked_target=self.add_eos_to_marked_target, 177 | ) 178 | 179 | 180 | def extract_alpaca_dataset(example): 181 | if example.get("input", "") != "": 182 | prompt_format = ALPACA_PROMPT_DICT["prompt_input"] 183 | else: 184 | prompt_format = ALPACA_PROMPT_DICT["prompt_no_input"] 185 | return {"input": prompt_format.format(**example)} 186 | 187 | 188 | def extract_dromedary_dataset(example, meta_prompts): 189 | assert "example_id" in example 190 | total_meta_prompt = len(meta_prompts) 191 | meta_prompt = meta_prompts[int(example["example_id"]) % total_meta_prompt] 192 | 193 | if example.get("input", "") != "": 194 | prompt_format = DROMEDARY_PROMPT_DICT["prompt_input"] 195 | else: 196 | prompt_format = DROMEDARY_PROMPT_DICT["prompt_no_input"] 197 | 198 | return { 199 | "input": prompt_format.format(meta_prompt=meta_prompt, **example), 200 | "output": "\n" + example["output"], 201 | } 202 | 203 | 204 | def extract_prm_dataset(example): 205 | if example["output_prefix"] == "": 206 | ret = { 207 | "input": "Question: " + example["input"], 208 | "output": "\n\nAnswer: " + example["output"], 209 | } 210 | else: 211 | ret = { 212 | "input": "Question: " 213 | + example["input"] 214 | + "\n\nAnswer: " 215 | + example["output_prefix"], 216 | "output": example["output"], 217 | } 218 | 219 | if "is_eos" in example: 220 | ret["is_eos"] = example["is_eos"] 221 | 222 | return ret 223 | 224 | 225 | def extract_prm_v2_dataset(example): 226 | if example["output_prefix"] == "": 227 | ret = { 228 | "input": "# Question\n\n" + example["input"] + "\n\n# Solution", 229 | "output": "\n\n" + example["output"], 230 | } 231 | else: 232 | ret = { 233 | "input": "# Question\n\n" 234 | + example["input"] 235 | + "\n\n# Solution\n\n" 236 | + example["output_prefix"], 237 | "output": example["output"], 238 | } 239 | 240 | if "is_eos" in example: 241 | ret["is_eos"] = example["is_eos"] 242 | 243 | return ret 244 | 245 | 246 | def extract_metamath_dataset(example): 247 | ret = { 248 | "input": "# Question\n\n" + example["query"] + "\n\n# Solution", 249 | "output": "\n\n" + example["output"], 250 | "is_eos": True, 251 | } 252 | 253 | return ret 254 | 255 | def extract_arc_dataset(example): 256 | ret = { 257 | "input": example["input"], 258 | "output": example["output"], 259 | "is_eos": True, 260 | } 261 | return ret 262 | 263 | def extract_apps_dataset(example): 264 | ret = { 265 | "input": example["input"], 266 | "output": example["output"], 267 | "is_eos": True, 268 | } 269 | 270 | return ret 271 | 272 | def make_sft_data_module( 273 | tokenizer: AcceleraTokenizer, 274 | args: Arguments, 275 | ) -> Dict: 276 | """ 277 | Make dataset and collator for supervised fine-tuning. 
278 | Datasets are expected to have the following columns: { `input`, `output` } 279 | """ 280 | 281 | def load_data(dataset_name): 282 | if dataset_name == "alpaca": 283 | return load_dataset("tatsu-lab/alpaca") 284 | elif dataset_name == "alpaca-clean": 285 | return load_dataset("yahma/alpaca-cleaned") 286 | elif dataset_name == "chip2": 287 | return load_dataset("laion/OIG", data_files="unified_chip2.jsonl") 288 | elif dataset_name == "self-instruct": 289 | return load_dataset("yizhongw/self_instruct", name="self_instruct") 290 | elif dataset_name == "hh-rlhf": 291 | return load_dataset("Anthropic/hh-rlhf") 292 | elif dataset_name == "longform": 293 | return load_dataset("akoksal/LongForm") 294 | elif dataset_name == "oasst1": 295 | return load_dataset("timdettmers/openassistant-guanaco") 296 | elif dataset_name == "vicuna": 297 | raise NotImplementedError("Vicuna data was not released.") 298 | else: 299 | if os.path.exists(dataset_name): 300 | try: 301 | args.dataset_format = ( 302 | args.dataset_format if args.dataset_format else "alpaca" 303 | ) 304 | full_dataset = utils.local_dataset(dataset_name) 305 | return full_dataset 306 | except: 307 | raise ValueError(f"Error loading dataset from {dataset_name}") 308 | else: 309 | raise NotImplementedError( 310 | f"Dataset {dataset_name} not implemented yet." 311 | ) 312 | 313 | def format_dataset(dataset, dataset_format): 314 | if ( 315 | dataset_format == "alpaca" 316 | or dataset_format == "alpaca-clean" 317 | or (dataset_format is None and args.dataset in ["alpaca", "alpaca-clean"]) 318 | ): 319 | dataset = dataset.map( 320 | extract_alpaca_dataset, remove_columns=["instruction"] 321 | ) 322 | elif dataset_format == "hh-rlhf" or ( 323 | dataset_format is None and args.dataset == "hh-rlhf" 324 | ): 325 | dataset = dataset.map(lambda x: {"input": "", "output": x["chosen"]}) 326 | elif dataset_format == "prm": 327 | dataset = dataset.map(extract_prm_dataset) 328 | elif dataset_format == "prm-v2": 329 | dataset = dataset.map(extract_prm_v2_dataset) 330 | elif dataset_format == "arc": 331 | dataset = dataset.map(extract_arc_dataset) 332 | elif dataset_format == "apps": 333 | dataset = dataset.map(extract_apps_dataset) 334 | elif dataset_format == "metamath": 335 | dataset = dataset.map(extract_metamath_dataset) 336 | elif dataset_format == "mapped": 337 | dataset = dataset 338 | else: 339 | raise ValueError(f"Unsupported dataset format: {dataset_format}") 340 | 341 | # Remove unused columns. 342 | dataset = dataset.remove_columns( 343 | [ 344 | col 345 | for col in dataset.column_names["train"] 346 | if col not in ["input", "output", "is_eos"] 347 | ] 348 | ) 349 | return dataset 350 | 351 | # Load dataset. 
352 | dataset = load_data(args.dataset) 353 | dataset = format_dataset(dataset, args.dataset_format) 354 | 355 | # Split train/eval, reduce size 356 | if args.do_eval: 357 | if "eval" in dataset: 358 | eval_dataset = dataset["eval"] 359 | else: 360 | print( 361 | "Splitting train dataset in train and validation according to `eval_dataset_size`" 362 | ) 363 | dataset = dataset["train"].train_test_split( 364 | test_size=args.eval_dataset_size, shuffle=True, seed=42 365 | ) 366 | eval_dataset = dataset["test"] 367 | if ( 368 | args.max_eval_samples is not None 369 | and len(eval_dataset) > args.max_eval_samples 370 | ): 371 | eval_dataset = eval_dataset.select(range(args.max_eval_samples)) 372 | 373 | if args.do_train: 374 | train_dataset = dataset["train"] 375 | if ( 376 | args.max_train_samples is not None 377 | and len(train_dataset) > args.max_train_samples 378 | ): 379 | train_dataset = train_dataset.select(range(args.max_train_samples)) 380 | 381 | data_collator = DataCollatorForCausalLM( 382 | tokenizer=tokenizer, 383 | source_max_len=args.source_max_len, 384 | target_max_len=args.target_max_len, 385 | total_max_len=args.total_max_len, 386 | train_on_source=args.train_on_source, 387 | add_eos_to_target=args.add_eos_to_target, 388 | add_eos_to_marked_target=args.add_eos_to_marked_target, 389 | ) 390 | return dict( 391 | train_dataset=train_dataset if args.do_train else None, 392 | eval_dataset=eval_dataset if args.do_eval else None, 393 | data_collator=data_collator, 394 | ) 395 | -------------------------------------------------------------------------------- /train_code/determine_hyper.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import random 4 | import pandas as pd 5 | import numpy as np 6 | from collections import Counter 7 | import re 8 | import argparse 9 | import os 10 | import argparse 11 | from itertools import product 12 | 13 | 14 | 15 | def process_data(input_file, ref_file): 16 | if not os.path.exists(input_file) or not os.path.exists(ref_file): 17 | return [] 18 | 19 | with open(ref_file, "r") as r: 20 | data_lines = r.readlines() 21 | 22 | infer4reward_data = [json.loads(l) for l in data_lines] 23 | 24 | infer_dict = {} 25 | for item in infer4reward_data: # 获取官方答案 26 | infer_dict[item["prompt"]] = item["response"] 27 | 28 | with open(input_file, "r") as r: 29 | data_lines = r.readlines() 30 | 31 | data_json = [] 32 | for item in data_lines: 33 | try: 34 | data_json.append(json.loads(item)) 35 | except json.JSONDecodeError: 36 | continue 37 | 38 | random.shuffle(data_json) 39 | 40 | trn_data = {} 41 | for item in data_json: 42 | if item["idx"] in trn_data: 43 | trn_data[item["idx"]]["sample_list"].append({"output": item["output"], "reward": item["reward"], "response": infer_dict[item["prompt"]]}) 44 | else: 45 | trn_data[item["idx"]] = { 46 | "prompt": item["prompt"], 47 | "sample_list": [{"output": item["output"], "reward": item["reward"], "response": infer_dict[item["prompt"]]}] 48 | } 49 | 50 | trn_list = [] 51 | for item in trn_data: 52 | trn_list.append(trn_data[item]) 53 | 54 | return trn_list 55 | 56 | def get_unique_trn_json(max_samples, trn_data, score_threshold): 57 | trn_json = [] 58 | for item in trn_data: 59 | solutions = [] 60 | for sample in item["sample_list"]: 61 | if "\n\n# Answer\n\n" in sample["output"]: 62 | final_answer = sample["output"].split("\n\n# Answer\n\n")[-1] 63 | prm_score = min(sample["reward"]) 64 | 65 | orm_score = 0.0 66 | if sample["output"].split("\n\n# Answer\n\n")[-1] == 
sample["response"].split("\n\n# Answer\n\n")[-1]: 67 | orm_score = 1.0 68 | 69 | final_score = prm_score / 5.0 + orm_score 70 | solutions.append({'final_answer': final_answer, 'prm_score': prm_score, 'output': sample["output"], "score": final_score}) #计算prm_score和orm_score的加权平均, 也就是最终的reward 分数, 用于数据的筛选 71 | 72 | solutions.append({'final_answer': final_answer, 'prm_score': 0.0, 'output': sample["response"], "score": 1.0}) # 添加官方数据 73 | if len(solutions) == 0: 74 | continue 75 | 76 | solutions_sorted = sorted(solutions, key=lambda x: x['score'], reverse=True) 77 | idx = 0 78 | temp_input = item["prompt"].split("\n\n# Solution\n\n")[0] 79 | temp_input = temp_input.split("# Question\n\n")[-1] 80 | 81 | for solu in solutions_sorted: 82 | if solu["score"] > score_threshold: 83 | trn_json.append({"query": temp_input, "output": solu["output"], "response": sample["response"], "reward": solu["prm_score"]}) 84 | idx += 1 85 | if idx >= max_samples: 86 | break 87 | # 去重部分 88 | unique_trn_json = [] 89 | seen = set() 90 | for item in trn_json: 91 | identifier = (item["query"], item["output"]) 92 | if identifier not in seen: 93 | seen.add(identifier) 94 | unique_trn_json.append(item) 95 | 96 | return unique_trn_json 97 | 98 | 99 | 100 | def cal_modi_acc_soft(data_json): 101 | unique_set = set() 102 | for item in data_json: 103 | unique_set.add(item["query"]) 104 | unique_dict = {} 105 | 106 | for item in unique_set: 107 | unique_dict[item] = [] 108 | 109 | #trn_json = [] 110 | for item in data_json: 111 | unique_dict[item["query"]].append({"output": item["output"], "response": item["response"]}) 112 | 113 | correct_num = [] 114 | 115 | actual_num = [] 116 | correct_ratio = [] 117 | 118 | for item in unique_dict: 119 | temp_count = 0 120 | for output in unique_dict[item]: 121 | if output["output"].split("\n\n# Answer\n\n")[-1] == output["response"].split("\n\n# Answer\n\n")[-1]: 122 | temp_count = temp_count + 1 123 | 124 | correct_num.append(temp_count) 125 | actual_num.append(len(unique_dict[item])) 126 | correct_ratio.append(temp_count/len(unique_dict[item])) 127 | modi_acc_list = [] 128 | for num, ratio in zip(correct_num, correct_ratio): 129 | if num >= 8: 130 | modi_acc_list.append(ratio) 131 | 132 | else: 133 | modi_acc_list.append((num * ratio)/ 8) 134 | 135 | return np.mean(modi_acc_list) 136 | 137 | 138 | 139 | def find_rewardbase(ana_list, reward_base, target_size): 140 | max_samples = 1 141 | max_iterations = 64 142 | iteration = 0 143 | while iteration < max_iterations: 144 | unique_trn_json = get_unique_trn_json(max_samples, ana_list, reward_base) 145 | if len(unique_trn_json) >= target_size: 146 | break 147 | max_samples += 1 148 | iteration += 1 149 | 150 | modi_acc = cal_modi_acc_soft(unique_trn_json) 151 | random.shuffle(unique_trn_json) 152 | 153 | return modi_acc 154 | 155 | 156 | 157 | def find_strategy(input, ref, target_size): 158 | ana_list = process_data(input, ref) # 合并 input 和 ref, 结合数据中的 response 和 reward 数据 159 | 160 | if not ana_list: 161 | return float('-inf') 162 | hyper_param = -1.0 163 | 164 | max_modi_acc = float('-inf') 165 | 166 | while True: 167 | modi_acc = find_rewardbase(ana_list, hyper_param, target_size) 168 | if modi_acc > max_modi_acc: 169 | max_modi_acc = modi_acc 170 | 171 | hyper_param += 0.01 172 | 173 | if hyper_param > 1.0: 174 | break 175 | 176 | return max_modi_acc 177 | 178 | 179 | def find_dataset(input_list, ref_list, target_size): 180 | modi_acc_list = [] 181 | 182 | 183 | for input_item, ref_item in zip(input_list, ref_list): 184 | modi_acc = 
find_strategy(input_item, ref_item, target_size) 185 | 186 | 187 | modi_acc_list.append(modi_acc) 188 | 189 | return modi_acc_list 190 | 191 | 192 | 193 | def parse_args(): 194 | parser = argparse.ArgumentParser(description="Generate input and reference file paths based on provided temps and sample numbers.") 195 | parser.add_argument('--temps', type=float, nargs='+', required=True, help="List of temperatures.") 196 | parser.add_argument('--sample_nums', type=int, nargs='+', required=True, help="List of sample numbers.") 197 | parser.add_argument('--input_path_template', type=str, required=True, help="Template for the input path.") 198 | parser.add_argument('--ref_path_template', type=str, required=True, help="Template for the reference path.") 199 | parser.add_argument('--iter', type=int, required=True, help="Iteration number to be included in the file path.") 200 | parser.add_argument('--valid_sample_size', type=int, required=True, help="Valid sample size.") 201 | return parser.parse_args() 202 | 203 | def main(): 204 | args = parse_args() 205 | 206 | input_list = [] 207 | ref_list = [] 208 | combinations = list(product(args.temps, args.sample_nums)) # enumerate all (temp, sample_num) combinations 209 | 210 | 211 | for temp, sample_num in combinations: 212 | # build the input and reference paths, substituting iter into the path as well 213 | input_path = args.input_path_template.format(temp=temp, sample_num=sample_num, iter=args.iter) 214 | ref_path = args.ref_path_template.format(temp=temp, sample_num=sample_num, iter=args.iter) 215 | 216 | input_list.append(input_path) 217 | ref_list.append(ref_path) 218 | 219 | # find_dataset returns a list of effectiveness scores, one per combination 220 | effect_values = find_dataset(input_list, ref_list, args.valid_sample_size) 221 | 222 | # pick the combination with the best score 223 | max_effect_index = effect_values.index(max(effect_values)) 224 | best_temp, best_sample_num = combinations[max_effect_index] 225 | 226 | # print the best combination, i.e. the best temp and sample_num values 227 | print(f"{best_temp} {best_sample_num}") 228 | 229 | if __name__ == "__main__": 230 | main() 231 | -------------------------------------------------------------------------------- /train_code/eval_arc_save_metrics.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import os 4 | def main(): 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("--input_file", type=str, required=True) 7 | parser.add_argument("--output_file", type=str, required=True) 8 | args = parser.parse_args() 9 | 10 | print(args) 11 | data = json.load(open(args.input_file, "r")) 12 | correct_count = 0 13 | for item in data.values(): 14 | assert "output1" not in item, "Pass@k is not allowed" 15 | score = item["output0"]["score"] 16 | if abs(float(score - 1.0)) < 1e-5: 17 | correct_count += 1 18 | print(f"Accuracy: {correct_count / len(data)}") 19 | metrics = {"accuracy": correct_count / len(data)} 20 | os.makedirs(os.path.dirname(args.output_file), exist_ok=True) 21 | with open(args.output_file, "w") as f: 22 | json.dump(metrics, f) 23 | 24 | 25 | 26 | if __name__ == "__main__": 27 | main() -------------------------------------------------------------------------------- /train_code/gsm8k_test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import pdb 4 | import re 5 | 6 | import jsonlines 7 | from fraction import Fraction 8 | 9 | from vllm import LLM, SamplingParams 10 | #from extract_answer_use_chatgpt import _extract_answer_chatgpt 11 | import sys 12 | 13 | MAX_INT = 
sys.maxsize # maximum integer, used as an unbounded end index 14 | 15 | ####CUDA_VISIBLE_DEVICES=0 python gsm8k_test_the_answer_is_original_batch.py --model xxxxx --start 0 --end 1400 --batch_size 80 --tensor_parallel_size 1 16 | 17 | def is_number(s): 18 | try: # if float(s) succeeds, s is a floating-point number 19 | float(s) 20 | return True 21 | except ValueError: # ValueError is the standard Python exception for "invalid argument passed in" 22 | pass # if a ValueError is raised, do nothing and fall through to the unicode check 23 | try: 24 | import unicodedata # module for handling unicode character data 25 | unicodedata.numeric(s) # converts a string that represents a number into a float 26 | return True 27 | except (TypeError, ValueError): 28 | pass 29 | return False 30 | 31 | def extract_answer_number(completion): 32 | text = completion.split('\n\n# Answer\n\n') 33 | if len(text) > 1: 34 | extract_ans = text[-1].strip() 35 | match = re.search(r'[\-+]?\d*[\.,/]?\d+', extract_ans) 36 | if match: ###### GSM8K answers are all integers 37 | if '/' in match.group(): 38 | denominator = match.group().split('/')[1] # denominator 39 | numerator = match.group().split('/')[0] # numerator 40 | if is_number(denominator) == True and is_number(numerator) == True: 41 | if denominator == '0': ## denominator is 0 42 | 43 | print('denominator is 0 ====:', match.group()) 44 | return round(float(numerator.replace(',', ''))) 45 | else: ## denominator is not 0 46 | frac = Fraction(match.group().replace(',', '')) 47 | num_numerator = frac.numerator 48 | num_denominator = frac.denominator 49 | return round(float(num_numerator / num_denominator)) # fraction, round to the nearest integer 50 | else: 51 | return None 52 | else: 53 | if float(match.group().replace(',', '')) == float('inf'): 54 | return None 55 | return round(float(match.group().replace(',', ''))) ### decimals and comma-formatted numbers, round to the nearest integer 56 | else: 57 | return None 58 | else: 59 | return None 60 | 61 | def batch_data(data_list, batch_size=1): 62 | n = len(data_list) // batch_size 63 | batch_data = [] 64 | for i in range(n-1): 65 | start = i * batch_size 66 | end = (i+1)*batch_size 67 | batch_data.append(data_list[start:end]) 68 | 69 | last_start = (n-1) * batch_size 70 | last_end = MAX_INT 71 | batch_data.append(data_list[last_start:last_end]) 72 | return batch_data 73 | 74 | 75 | def gsm8k_test(args, model, start=0, end=MAX_INT, batch_size=1, tensor_parallel_size=1): 76 | INVALID_ANS = "[invalid]" 77 | gsm8k_ins = [] 78 | gsm8k_answers = [] 79 | 80 | 81 | #problem_prompt = ("A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {instruction} ASSISTANT: Let's think step by step.") 82 | #problem_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response: Let's think step by step." 83 | 84 | problem_prompt = "# Question\n\n{instruction}\n\n# Solution\n\n" 85 | #problem_prompt = ("A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant's response is from ChatGPT3.5. USER: {instruction} ASSISTANT: Let's think step by step.") 86 | # problem_prompt = ( 87 | # "You are a math assistant, skilled at solving various mathematical problems. 
USER: {instruction} ASSISTANT: Let's think step by step.") 88 | print('promt =====', problem_prompt) 89 | with open('./test.jsonl',"r+", encoding="utf8") as f: 90 | for idx, item in enumerate(jsonlines.Reader(f)): 91 | temp_instr = problem_prompt.format(instruction=item["question"]) 92 | gsm8k_ins.append(temp_instr) 93 | temp_ans = item['answer'].split('#### ')[1] 94 | temp_ans = int(temp_ans.replace(',', '')) 95 | gsm8k_answers.append(temp_ans) 96 | 97 | gsm8k_ins = gsm8k_ins[start:end] 98 | gsm8k_answers = gsm8k_answers[start:end] 99 | print('lenght ====', len(gsm8k_ins)) 100 | batch_gsm8k_ins = batch_data(gsm8k_ins, batch_size=batch_size) 101 | 102 | 103 | 104 | 105 | stop_tokens = ["Question:", "Question", "USER:", "USER", "ASSISTANT:", "ASSISTANT", "Instruction:", "Instruction", "Response:", "Response"] 106 | sampling_params = SamplingParams(temperature=0, top_p=1, max_tokens=args.max_tokens) 107 | print('sampleing =====', sampling_params) 108 | llm = LLM(model=model,tensor_parallel_size=tensor_parallel_size) 109 | result = [] 110 | res_completions = [] 111 | for idx, (prompt, prompt_answer) in enumerate(zip(batch_gsm8k_ins, gsm8k_answers)): 112 | print('llm idx ====', idx) 113 | if isinstance(prompt, list): 114 | pass 115 | else: 116 | prompt = [prompt] 117 | 118 | completions = llm.generate(prompt, sampling_params) 119 | for output in completions: 120 | prompt = output.prompt 121 | generated_text = output.outputs[0].text 122 | res_completions.append(generated_text) 123 | 124 | invalid_outputs = [] 125 | all_outputs = [] 126 | for idx, (prompt, completion, prompt_answer) in enumerate(zip(gsm8k_ins, res_completions, gsm8k_answers)): 127 | print('chatgpt idx =====', idx) 128 | doc = {'question': prompt} 129 | y_pred = extract_answer_number(completion) 130 | if y_pred != None: 131 | result.append(float(y_pred) == float(prompt_answer)) 132 | temp = {'question': prompt, 'output': completion, 'answer': prompt_answer} 133 | all_outputs.append(temp) 134 | else: 135 | result.append(False) 136 | temp = {'question': prompt, 'output': completion, 'answer': prompt_answer} 137 | invalid_outputs.append(temp) 138 | all_outputs.append(temp) 139 | # pdb.set_trace() 140 | acc = sum(result) / len(result) 141 | print('start===', start, ', end====', end) 142 | print('length====', len(result), ', acc====', acc) 143 | #print('len invalid outputs ====', len(invalid_outputs), ', valid_outputs===', invalid_outputs) 144 | model_name = '_'.join(model.split('/')[-2:]) 145 | 146 | OUTPUT_FILE_SUFFIX = ".json" 147 | import os 148 | output_data_path = os.path.join(args.output_path, args.model_id + OUTPUT_FILE_SUFFIX) 149 | import os 150 | output_dir = os.path.dirname(output_data_path) 151 | if not os.path.exists(output_dir): 152 | os.makedirs(output_dir) 153 | 154 | with open(output_data_path, 'w') as f: 155 | json.dump({"length====": len(result), "acc====": acc}, f) 156 | 157 | 158 | #pdb.set_trace() 159 | 160 | 161 | def parse_args(): 162 | parser = argparse.ArgumentParser() 163 | parser.add_argument("--model", type=str) # start index 164 | parser.add_argument("--model_id", type=str) # start index 165 | parser.add_argument("--output_path", type=str) # start index 166 | parser.add_argument("--start", type=int, default=0) #start index 167 | parser.add_argument("--end", type=int, default=MAX_INT) # start index 168 | parser.add_argument("--batch_size", type=int, default=1) # batch_size 169 | parser.add_argument("--tensor_parallel_size", type=int, default=1) # tensor_parallel_size 170 | 
parser.add_argument("--max_tokens", type=int, default=768) 171 | return parser.parse_args() 172 | 173 | if __name__ == "__main__": 174 | args = parse_args() 175 | gsm8k_test(args=args, model=args.model, start=args.start, end=args.end, batch_size=args.batch_size, tensor_parallel_size=args.tensor_parallel_size) 176 | #pdb.set_trace() 177 | -------------------------------------------------------------------------------- /train_code/infer_arc.sh: -------------------------------------------------------------------------------- 1 | RUN_NAME=arc_train_online_rft_mistral_temp0.4_7kdata 2 | LOG_PATH=logs/${RUN_NAME} 3 | QUERY_FILE=/mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/zengweihao02/easy2hard/share/project/weihao/easy-to-hard-main-share/dynamic_ana/arc_data/arc_test_data.jsonl 4 | PROMPT_DIR="/mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/zengweihao02/easy2hard/share/project/weihao/easy-to-hard-main-share/dynamic_ana/arc_prompt" 5 | mkdir -p $LOG_PATH 6 | 7 | # for i in {2500..4500..500} 8 | # do 9 | # python vllm_infer_arc.py \ 10 | # --input_file $QUERY_FILE \ 11 | # --prompt_template_dir $PROMPT_DIR \ 12 | # --shot_num 0 \ 13 | # --model /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/zengweihao02/checkpoints/$RUN_NAME/converted/ckpt$i \ 14 | # --sample_num 1 \ 15 | # --temperature 0 \ 16 | # --top_k -1 \ 17 | # --max_tokens 1024 \ 18 | # --split test \ 19 | # --save /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/zengweihao02/easy2hard/share/project/weihao/easy-to-hard-main-share/dynamic_ana/arc/results/${RUN_NAME}/ckpt$i/gen.json \ 20 | # --tensor-parallel-size 1 \ 21 | # --cuda_ids 0 \ 22 | # --cache_dir /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/zengweihao02/easy2hard/share/project/weihao/easy-to-hard-main-share/dynamic_ana/ARC/.cache \ 23 | # 2>&1 | tee -a $LOG_PATH/ckpt$i.log 24 | 25 | # python eval_arc_save_metrics.py \ 26 | # --input_file /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/zengweihao02/easy2hard/share/project/weihao/easy-to-hard-main-share/dynamic_ana/arc/results/${RUN_NAME}/ckpt$i/gen.json \ 27 | # --output_file /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/zengweihao02/easy2hard/share/project/weihao/easy-to-hard-main-share/dynamic_ana/arc/results/${RUN_NAME}/ckpt$i/metrics.json \ 28 | # 2>&1 | tee -a $LOG_PATH/ckpt$i.log 29 | # done 30 | 31 | RUN_NAME=Mistral-7B-Instruct-v0.1 32 | for i in {2500..2500..500} 33 | do 34 | python vllm_infer_arc.py \ 35 | --input_file $QUERY_FILE \ 36 | --prompt_template_dir $PROMPT_DIR \ 37 | --shot_num 1 \ 38 | --model /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/zengweihao02/modelzoo/Mistral-7B-Instruct-v0.1 \ 39 | --sample_num 1 \ 40 | --temperature 0 \ 41 | --top_k -1 \ 42 | --max_tokens 1024 \ 43 | --split test \ 44 | --save /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/zengweihao02/easy2hard/share/project/weihao/easy-to-hard-main-share/dynamic_ana/arc/results/${RUN_NAME}/ckpt$i/gen.json \ 45 | --tensor-parallel-size 1 \ 46 | --cuda_ids 0 \ 47 | --cache_dir /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/zengweihao02/easy2hard/share/project/weihao/easy-to-hard-main-share/dynamic_ana/ARC/.cache \ 48 | 2>&1 | tee -a $LOG_PATH/ckpt$i.log 49 | 50 | python eval_arc_save_metrics.py \ 51 | --input_file /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/zengweihao02/easy2hard/share/project/weihao/easy-to-hard-main-share/dynamic_ana/arc/results/${RUN_NAME}/ckpt$i/gen.json \ 52 | --output_file 
/mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/zengweihao02/easy2hard/share/project/weihao/easy-to-hard-main-share/dynamic_ana/arc/results/${RUN_NAME}/ckpt$i/metrics.json \ 53 | 2>&1 | tee -a $LOG_PATH/ckpt$i.log 54 | done -------------------------------------------------------------------------------- /train_code/infer_gsm8k1.sh: -------------------------------------------------------------------------------- 1 | #export HF_HOME="/ssddata/model_hub" 2 | # export CUDA_HOME=/usr/local/cuda-11.7 #指定cuda根目录 3 | # export PATH=$PATH:/usr/local/cuda-11.7/bin #安装的cuda的路径下的bin文件夹,包含了nvcc等二进制程序 4 | # export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-11.7/lib64 ##安装的cuda的路径下的lib64文件夹,包含很多库文件 5 | 6 | 7 | # for i in {100..500..100} 8 | # do 9 | # CUDA_VISIBLE_DEVICES=1 python gsm8k_test.py \ 10 | # --model /share/project/weihao/save_dir/checkpoints/math_format_infer_iter9_fix_rest_64k/iter7/converted/ckpt$i \ 11 | # --output_path /share/project/weihao/save_dir/checkpoints/math_format_infer_iter9_fix_rest_64k/iter7/converted/ckpt$i \ 12 | # --model_id gsm8k \ 13 | # --start 0 \ 14 | # --end 1400 \ 15 | # --batch_size 800 \ 16 | # --tensor_parallel_size 1 \ 17 | # --max_tokens 768 18 | # done 19 | 20 | 21 | # for iter in {2..6} 22 | # do 23 | # for i in {100..500..100} 24 | # do 25 | # CUDA_VISIBLE_DEVICES=0 python gsm8k_test.py \ 26 | # --model /share/project/weihao/save_dir/checkpoints/math_format_infer_iter9_fix_rewardneg0.0_pre_iterrft_64k/iter$iter/converted/ckpt$i \ 27 | # --output_path /share/project/weihao/save_dir/checkpoints/math_format_infer_iter9_fix_rewardneg0.0_pre_iterrft_64k/iter$iter/converted/ckpt$i \ 28 | # --model_id gsm8k \ 29 | # --start 0 \ 30 | # --end 1400 \ 31 | # --batch_size 800 \ 32 | # --tensor_parallel_size 1 \ 33 | # --max_tokens 768 34 | # done 35 | # done 36 | 37 | 38 | 39 | # for i in {1750..2250..250} 40 | # do 41 | # CUDA_VISIBLE_DEVICES=0 python gsm8k_test.py \ 42 | # --model /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/checkpoints/math_format_b_star_deepseek_math_16k_fix_bsz_fine_stand_fix_moretemp/converted/ckpt$i \ 43 | # --output_path /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/checkpoints/math_format_b_star_deepseek_math_16k_fix_bsz_fine_stand_fix_moretemp/converted/ckpt$i \ 44 | # --model_id gsm8k \ 45 | # --start 0 \ 46 | # --end 1400 \ 47 | # --batch_size 800 \ 48 | # --tensor_parallel_size 1 \ 49 | # --max_tokens 768 50 | # done 51 | 52 | # for i in {1750..2250..250} 53 | # do 54 | # CUDA_VISIBLE_DEVICES=0 python vllm_infer.py \ 55 | # --input_data /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/easy2hard/share/project/weihao/easy-to-hard-main-share/easy-to-hard-main-share/data/test_ppo.json \ 56 | # --model_dir /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/checkpoints/math_format_b_star_deepseek_math_16k_fix_bsz_fine_stand_fix_moretemp/converted/ckpt$i \ 57 | # --sample_num 1 \ 58 | # --output_file /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/checkpoints/math_format_b_star_deepseek_math_16k_fix_bsz_fine_stand_fix_moretemp/converted/ckpt$i/test_ppo_1to5_infer_greedy.json \ 59 | # --tensor_parallel_size 1 \ 60 | # --temperature 0.0 \ 61 | # --top_k -1 \ 62 | # --max_tokens 768 63 | # done 64 | 65 | 66 | # # python cal_metric_vllm.py \ 67 | # # --tokenizer_path /share/project/weihao/model_zoo/llemma_7b/tokenizer.model \ 68 | # # --answer_file 
/share/project/weihao/save_dir/checkpoints/math_format_online_rft_deepseek_math_48k_fix_bsz_upreward_update/converted/ckpt4000/test_ppo_1to5_infer_greedy.json \ 69 | # # --output_file /share/project/weihao/save_dir/checkpoints/math_format_online_rft_deepseek_math_48k_fix_bsz_upreward_update/converted/ckpt4000/test_ppo_1to5_infer_greedy_metric.json \ 70 | 71 | 72 | 73 | # for i in {1750..2250..250} 74 | # do 75 | # python cal_metric_vllm.py \ 76 | # --tokenizer_path /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/modelzoo/llemma_7b/tokenizer.model \ 77 | # --answer_file /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/checkpoints/math_format_b_star_deepseek_math_16k_fix_bsz_fine_stand_fix_moretemp/converted/ckpt$i/test_ppo_1to5_infer_greedy.json \ 78 | # --output_file /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/checkpoints/math_format_b_star_deepseek_math_16k_fix_bsz_fine_stand_fix_moretemp/converted/ckpt$i/test_ppo_1to5_infer_greedy_metric.json 79 | # done 80 | 81 | 82 | for iter in {2..10} 83 | do 84 | for i in {250..250..250} 85 | do 86 | CUDA_VISIBLE_DEVICES=0 python gsm8k_test.py \ 87 | --model /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/checkpoints/math_format_rest_llama3.1_math_32k_lr2e_5/iter$iter/converted/ckpt$i \ 88 | --output_path /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/checkpoints/math_format_rest_llama3.1_math_32k_lr2e_5/iter$iter/converted/ckpt$i \ 89 | --model_id gsm8k \ 90 | --start 0 \ 91 | --end 1400 \ 92 | --batch_size 800 \ 93 | --tensor_parallel_size 1 \ 94 | --max_tokens 768 95 | done 96 | done 97 | 98 | 99 | for iter in {2..10} 100 | do 101 | for i in {250..250..250} 102 | do 103 | CUDA_VISIBLE_DEVICES=0 python vllm_infer.py \ 104 | --input_data /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/easy2hard/share/project/weihao/easy-to-hard-main-share/easy-to-hard-main-share/data/test_ppo.json \ 105 | --model_dir /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/checkpoints/math_format_rest_llama3.1_math_32k_lr2e_5/iter$iter/converted/ckpt$i \ 106 | --sample_num 1 \ 107 | --output_file /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/checkpoints/math_format_rest_llama3.1_math_32k_lr2e_5/iter$iter/converted/ckpt$i/test_ppo_1to5_infer_greedy.json \ 108 | --tensor_parallel_size 1 \ 109 | --temperature 0.0 \ 110 | --top_k -1 \ 111 | --max_tokens 768 112 | done 113 | done 114 | 115 | 116 | 117 | for iter in {2..10} 118 | do 119 | for i in {250..250..250} 120 | do 121 | python cal_metric_vllm.py \ 122 | --tokenizer_path /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/modelzoo/llemma_7b/tokenizer.model \ 123 | --answer_file /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/checkpoints/math_format_rest_llama3.1_math_32k_lr2e_5/iter$iter/converted/ckpt$i/test_ppo_1to5_infer_greedy.json \ 124 | --output_file /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/checkpoints/math_format_rest_llama3.1_math_32k_lr2e_5/iter$iter/converted/ckpt$i/test_ppo_1to5_infer_greedy_metric.json 125 | done 126 | done 127 | 128 | 129 | # for t in 0.5 0.7 0.9 1.1 130 | # do 131 | # for i in {500..4500..500} 132 | # do 133 | # CUDA_VISIBLE_DEVICES=0 python gsm8k_test.py \ 134 | # --model /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/checkpoints/math_format_online_rft_mistral_math_64k_${t}t_neg0.4r/converted/ckpt$i \ 135 | # --output_path 
/mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/checkpoints/math_format_online_rft_mistral_math_64k_${t}t_neg0.4r/converted/ckpt$i \ 136 | # --model_id gsm8k \ 137 | # --start 0 \ 138 | # --end 1400 \ 139 | # --batch_size 800 \ 140 | # --tensor_parallel_size 1 \ 141 | # --max_tokens 768 142 | # done 143 | 144 | # for i in {500..4500..500} 145 | # do 146 | # CUDA_VISIBLE_DEVICES=0 python vllm_infer.py \ 147 | # --input_data /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/easy2hard/share/project/weihao/easy-to-hard-main-share/easy-to-hard-main-share/data/test_ppo.json \ 148 | # --model_dir /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/checkpoints/math_format_online_rft_mistral_math_64k_${t}t_neg0.4r/converted/ckpt$i \ 149 | # --sample_num 1 \ 150 | # --output_file /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/checkpoints/math_format_online_rft_mistral_math_64k_${t}t_neg0.4r/converted/ckpt$i/test_ppo_1to5_infer_greedy.json \ 151 | # --tensor_parallel_size 1 \ 152 | # --temperature 0.0 \ 153 | # --top_k -1 \ 154 | # --max_tokens 768 155 | # done 156 | 157 | # for i in {500..4500..500} 158 | # do 159 | # python cal_metric_vllm.py \ 160 | # --tokenizer_path /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/modelzoo/llemma_7b/tokenizer.model \ 161 | # --answer_file /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/checkpoints/math_format_online_rft_mistral_math_64k_${t}t_neg0.4r/converted/ckpt$i/test_ppo_1to5_infer_greedy.json \ 162 | # --output_file /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/FMG/zengweihao02/checkpoints/math_format_online_rft_mistral_math_64k_${t}t_neg0.4r/converted/ckpt$i/test_ppo_1to5_infer_greedy_metric.json 163 | # done 164 | # done 165 | -------------------------------------------------------------------------------- /train_code/math_utils/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/math_utils/README.md -------------------------------------------------------------------------------- /train_code/math_utils/__pycache__/grader.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/math_utils/__pycache__/grader.cpython-310.pyc -------------------------------------------------------------------------------- /train_code/math_utils/__pycache__/grader.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/math_utils/__pycache__/grader.cpython-311.pyc -------------------------------------------------------------------------------- /train_code/math_utils/__pycache__/math_normalize.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/math_utils/__pycache__/math_normalize.cpython-310.pyc -------------------------------------------------------------------------------- /train_code/math_utils/__pycache__/math_normalize.cpython-311.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/math_utils/__pycache__/math_normalize.cpython-311.pyc -------------------------------------------------------------------------------- /train_code/math_utils/__pycache__/math_rl_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/math_utils/__pycache__/math_rl_utils.cpython-310.pyc -------------------------------------------------------------------------------- /train_code/math_utils/__pycache__/math_rl_utils.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/math_utils/__pycache__/math_rl_utils.cpython-311.pyc -------------------------------------------------------------------------------- /train_code/math_utils/grader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Answer checker API that uses sympy to simplify expressions and check for equality. 3 | 4 | Call grade_answer(given_answer: str, ground_truth: str). 5 | """ 6 | 7 | import re 8 | import sympy 9 | from pylatexenc import latex2text 10 | from sympy.parsing import sympy_parser 11 | 12 | from math_utils import math_normalize 13 | 14 | 15 | # sympy might hang -- we don't care about trying to be lenient in these cases 16 | BAD_SUBSTRINGS = ["^{", "^("] 17 | BAD_REGEXES = ["\^[0-9]+\^", "\^[0-9][0-9]+"] 18 | TUPLE_CHARS = "()[]" 19 | 20 | 21 | def _sympy_parse(expr: str): 22 | """Parses an expression with sympy.""" 23 | py_expr = expr.replace("^", "**") 24 | return sympy_parser.parse_expr( 25 | py_expr, 26 | transformations=( 27 | sympy_parser.standard_transformations 28 | + (sympy_parser.implicit_multiplication_application,) 29 | ), 30 | ) 31 | 32 | 33 | def _parse_latex(expr: str) -> str: 34 | """Attempts to parse latex to an expression sympy can read.""" 35 | expr = expr.replace("\\tfrac", "\\frac") 36 | expr = expr.replace("\\dfrac", "\\frac") 37 | expr = expr.replace("\\frac", " \\frac") # Play nice with mixed numbers. 38 | expr = latex2text.LatexNodes2Text().latex_to_text(expr) 39 | 40 | # Replace the specific characters that this parser uses. 41 | expr = expr.replace("√", "sqrt") 42 | expr = expr.replace("π", "pi") 43 | expr = expr.replace("∞", "inf") 44 | expr = expr.replace("∪", "U") 45 | expr = expr.replace("·", "*") 46 | expr = expr.replace("×", "*") 47 | 48 | return expr.strip() 49 | 50 | 51 | def _is_float(num: str) -> bool: 52 | try: 53 | float(num) 54 | return True 55 | except ValueError: 56 | return False 57 | 58 | 59 | def _is_int(x: float) -> bool: 60 | try: 61 | return abs(x - int(round(x))) <= 1e-7 62 | except: 63 | return False 64 | 65 | 66 | def _is_frac(expr: str) -> bool: 67 | return bool(re.search(r"^-?[0-9]+.?/0*[1-9][0-9]*.?$", expr)) 68 | 69 | 70 | def _str_is_int(x: str) -> bool: 71 | try: 72 | x = _strip_properly_formatted_commas(x) 73 | x = float(x) 74 | return abs(x - int(round(x))) <= 1e-7 75 | except: 76 | return False 77 | 78 | 79 | def _str_to_int(x: str) -> bool: 80 | x = x.replace(",", "") 81 | x = float(x) 82 | return int(x) 83 | 84 | 85 | def _inject_implicit_mixed_number(step: str): 86 | """ 87 | Automatically make a mixed number evalable 88 | e.g. 
7 3/4 => 7+3/4 89 | """ 90 | p1 = re.compile("([0-9]) +([0-9])") 91 | step = p1.sub("\\1+\\2", step) ## implicit mults 92 | return step 93 | 94 | 95 | def _strip_properly_formatted_commas(expr: str): 96 | # We want to be careful because we don't want to strip tuple commas 97 | p1 = re.compile("(\d)(,)(\d\d\d)($|\D)") 98 | while True: 99 | next_expr = p1.sub("\\1\\3\\4", expr) 100 | if next_expr == expr: 101 | break 102 | expr = next_expr 103 | return next_expr 104 | 105 | 106 | def _normalize(expr: str) -> str: 107 | """Normalize answer expressions.""" 108 | if expr is None: 109 | return None 110 | 111 | # Remove enclosing `\text{}`. 112 | m = re.search("^\\\\text\{(?P<text>.+?)\}$", expr) 113 | if m is not None: 114 | expr = m.group("text") 115 | 116 | expr = expr.replace("\\%", "%") 117 | expr = expr.replace("\\$", "$") 118 | expr = expr.replace("$", "") 119 | expr = expr.replace("%", "") 120 | expr = expr.replace(" or ", " , ") 121 | expr = expr.replace(" and ", " , ") 122 | 123 | expr = expr.replace("million", "*10^6") 124 | expr = expr.replace("billion", "*10^9") 125 | expr = expr.replace("trillion", "*10^12") 126 | 127 | for unit in [ 128 | "degree", 129 | "cm", 130 | "centimeter", 131 | "meter", 132 | "mile", 133 | "second", 134 | "minute", 135 | "hour", 136 | "day", 137 | "week", 138 | "month", 139 | "year", 140 | "foot", 141 | "feet", 142 | "inch", 143 | "yard", 144 | ]: 145 | expr = re.sub(f"{unit}(es)?(s)? *(\^[0-9]+)?", "", expr) 146 | expr = re.sub(f"\^ *\\\\circ", "", expr) 147 | 148 | if len(expr) > 0 and expr[0] == "{" and expr[-1] == "}": 149 | expr = expr[1:-1] 150 | 151 | expr = re.sub(",\\\\! *", "", expr) 152 | if _is_float(expr) and _is_int(float(expr)): 153 | expr = str(int(round(float(expr)))) 154 | if "\\" in expr: 155 | try: 156 | expr = _parse_latex(expr) 157 | except: 158 | pass 159 | 160 | # edge case with mixed numbers and negative signs 161 | expr = re.sub("- *", "-", expr) 162 | 163 | expr = _inject_implicit_mixed_number(expr) 164 | expr = expr.replace(" ", "") 165 | 166 | # if we somehow still have latex braces here, just drop them 167 | expr = expr.replace("{", "") 168 | expr = expr.replace("}", "") 169 | 170 | # don't be case sensitive for text answers 171 | expr = expr.lower() 172 | 173 | if _str_is_int(expr): 174 | expr = str(_str_to_int(expr)) 175 | 176 | return expr 177 | 178 | 179 | def count_unknown_letters_in_expr(expr: str): 180 | expr = expr.replace("sqrt", "") 181 | expr = expr.replace("frac", "") 182 | letters_in_expr = set([x for x in expr if x.isalpha()]) 183 | return len(letters_in_expr) 184 | 185 | 186 | def should_allow_eval(expr: str): 187 | # we don't want to try parsing unknown text or functions of more than two variables 188 | if count_unknown_letters_in_expr(expr) > 2: 189 | return False 190 | 191 | for bad_string in BAD_SUBSTRINGS: 192 | if bad_string in expr: 193 | return False 194 | 195 | for bad_regex in BAD_REGEXES: 196 | if re.search(bad_regex, expr) is not None: 197 | return False 198 | 199 | return True 200 | 201 | 202 | def are_equal_under_sympy(ground_truth_normalized: str, given_normalized: str): 203 | are_equal = False 204 | try: 205 | expr = f"({ground_truth_normalized})-({given_normalized})" 206 | if should_allow_eval(expr): 207 | sympy_diff = _sympy_parse(expr) 208 | simplified = sympy.simplify(sympy_diff) 209 | if simplified == 0: 210 | are_equal = True 211 | except: 212 | pass 213 | return are_equal 214 | 215 | 216 | def split_tuple(expr: str): 217 | """ 218 | Split the elements in a tuple/interval, while handling 
well-formatted commas in large numbers 219 | """ 220 | expr = _strip_properly_formatted_commas(expr) 221 | if len(expr) == 0: 222 | return [] 223 | if ( 224 | len(expr) > 2 225 | and expr[0] in TUPLE_CHARS 226 | and expr[-1] in TUPLE_CHARS 227 | and all([ch not in expr[1:-1] for ch in TUPLE_CHARS]) 228 | ): 229 | elems = [elem.strip() for elem in expr[1:-1].split(",")] 230 | else: 231 | elems = [expr] 232 | return elems 233 | 234 | 235 | def grade_answer(given_answer: str, ground_truth: str) -> bool: 236 | """ 237 | The answer will be considered correct if: 238 | (a) it normalizes to the same string as the ground truth answer 239 | OR 240 | (b) sympy can simplify the difference between the expressions to 0 241 | """ 242 | if given_answer is None: 243 | return False 244 | 245 | ground_truth_normalized_mathd = math_normalize.normalize_answer(ground_truth) 246 | given_answer_normalized_mathd = math_normalize.normalize_answer(given_answer) 247 | 248 | # be at least as lenient as mathd 249 | if ground_truth_normalized_mathd == given_answer_normalized_mathd: 250 | return True 251 | 252 | ground_truth_normalized = _normalize(ground_truth) 253 | given_normalized = _normalize(given_answer) 254 | 255 | if ground_truth_normalized is None: 256 | return False 257 | 258 | if ground_truth_normalized == given_normalized: 259 | return True 260 | 261 | if len(given_normalized) == 0: 262 | return False 263 | 264 | ground_truth_elems = split_tuple(ground_truth_normalized) 265 | given_elems = split_tuple(given_normalized) 266 | 267 | if len(ground_truth_elems) > 1 and ( 268 | ground_truth_normalized[0] != given_normalized[0] 269 | or ground_truth_normalized[-1] != given_normalized[-1] 270 | ): 271 | is_correct = False 272 | elif len(ground_truth_elems) != len(given_elems): 273 | is_correct = False 274 | else: 275 | for ground_truth_elem, given_elem in zip(ground_truth_elems, given_elems): 276 | if _is_frac(ground_truth_elem) and _is_frac(given_elem): 277 | # if fractions aren't reduced, then shouldn't be marked as correct 278 | # so, we don't want to allow sympy.simplify in this case 279 | is_correct = ground_truth_elem == given_elem 280 | elif _str_is_int(ground_truth_elem) != _str_is_int(given_elem): 281 | # if the ground truth answer is an integer, we require the given answer to be a strict match (no sympy.simplify) 282 | is_correct = False 283 | else: 284 | # is_correct = are_equal_under_sympy(ground_truth_elem, given_elem) 285 | is_correct = False 286 | if not is_correct: 287 | break 288 | 289 | return is_correct 290 | -------------------------------------------------------------------------------- /train_code/math_utils/math_normalize.py: -------------------------------------------------------------------------------- 1 | """ 2 | This logic is largely copied from the Hendrycks' MATH release (math_equivalence). 3 | """ 4 | import re 5 | from typing import Optional 6 | 7 | 8 | def normalize_answer(answer: Optional[str]) -> Optional[str]: 9 | if answer is None: 10 | return None 11 | answer = answer.strip() 12 | try: 13 | # Remove enclosing `\text{}`. 
14 | m = re.search("^\\\\text\{(?P<text>.+?)\}$", answer) 15 | if m is not None: 16 | answer = m.group("text").strip() 17 | return _strip_string(answer) 18 | except: 19 | return answer 20 | 21 | 22 | def _fix_fracs(string): 23 | substrs = string.split("\\frac") 24 | new_str = substrs[0] 25 | if len(substrs) > 1: 26 | substrs = substrs[1:] 27 | for substr in substrs: 28 | new_str += "\\frac" 29 | if substr[0] == "{": 30 | new_str += substr 31 | else: 32 | try: 33 | assert len(substr) >= 2 34 | except: 35 | return string 36 | a = substr[0] 37 | b = substr[1] 38 | if b != "{": 39 | if len(substr) > 2: 40 | post_substr = substr[2:] 41 | new_str += "{" + a + "}{" + b + "}" + post_substr 42 | else: 43 | new_str += "{" + a + "}{" + b + "}" 44 | else: 45 | if len(substr) > 2: 46 | post_substr = substr[2:] 47 | new_str += "{" + a + "}" + b + post_substr 48 | else: 49 | new_str += "{" + a + "}" + b 50 | string = new_str 51 | return string 52 | 53 | 54 | def _fix_a_slash_b(string): 55 | if len(string.split("/")) != 2: 56 | return string 57 | a = string.split("/")[0] 58 | b = string.split("/")[1] 59 | try: 60 | a = int(a) 61 | b = int(b) 62 | assert string == "{}/{}".format(a, b) 63 | new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" 64 | return new_string 65 | except: 66 | return string 67 | 68 | 69 | def _remove_right_units(string): 70 | # "\\text{ " only ever occurs (at least in the val set) when describing units 71 | if "\\text{ " in string: 72 | splits = string.split("\\text{ ") 73 | assert len(splits) == 2 74 | return splits[0] 75 | else: 76 | return string 77 | 78 | 79 | def _fix_sqrt(string): 80 | if "\\sqrt" not in string: 81 | return string 82 | splits = string.split("\\sqrt") 83 | new_string = splits[0] 84 | for split in splits[1:]: 85 | if split[0] != "{": 86 | a = split[0] 87 | new_substr = "\\sqrt{" + a + "}" + split[1:] 88 | else: 89 | new_substr = "\\sqrt" + split 90 | new_string += new_substr 91 | return new_string 92 | 93 | 94 | def _strip_string(string): 95 | # linebreaks 96 | string = string.replace("\n", "") 97 | # print(string) 98 | 99 | # remove inverse spaces 100 | string = string.replace("\\!", "") 101 | # print(string) 102 | 103 | # replace \\ with \ 104 | string = string.replace("\\\\", "\\") 105 | # print(string) 106 | 107 | # replace tfrac and dfrac with frac 108 | string = string.replace("tfrac", "frac") 109 | string = string.replace("dfrac", "frac") 110 | # print(string) 111 | 112 | # remove \left and \right 113 | string = string.replace("\\left", "") 114 | string = string.replace("\\right", "") 115 | # print(string) 116 | 117 | # Remove circ (degrees) 118 | string = string.replace("^{\\circ}", "") 119 | string = string.replace("^\\circ", "") 120 | 121 | # remove dollar signs 122 | string = string.replace("\\$", "") 123 | 124 | # remove units (on the right) 125 | string = _remove_right_units(string) 126 | 127 | # remove percentage 128 | string = string.replace("\\%", "") 129 | string = string.replace("\%", "") 130 | 131 | # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string 132 | string = string.replace(" .", " 0.") 133 | string = string.replace("{.", "{0.") 134 | # if empty, return empty string 135 | if len(string) == 0: 136 | return string 137 | if string[0] == ".": 138 | string = "0" + string 139 | 140 | # to consider: get rid of e.g. 
"k = " or "q = " at beginning 141 | if len(string.split("=")) == 2: 142 | if len(string.split("=")[0]) <= 2: 143 | string = string.split("=")[1] 144 | 145 | # fix sqrt3 --> sqrt{3} 146 | string = _fix_sqrt(string) 147 | 148 | # remove spaces 149 | string = string.replace(" ", "") 150 | 151 | # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b} 152 | string = _fix_fracs(string) 153 | 154 | # manually change 0.5 --> \frac{1}{2} 155 | if string == "0.5": 156 | string = "\\frac{1}{2}" 157 | 158 | # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y 159 | string = _fix_a_slash_b(string) 160 | 161 | return string 162 | -------------------------------------------------------------------------------- /train_code/models/__pycache__/frozen_layers.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/models/__pycache__/frozen_layers.cpython-310.pyc -------------------------------------------------------------------------------- /train_code/models/__pycache__/frozen_layers.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/models/__pycache__/frozen_layers.cpython-311.pyc -------------------------------------------------------------------------------- /train_code/models/__pycache__/frozen_layers.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/models/__pycache__/frozen_layers.cpython-39.pyc -------------------------------------------------------------------------------- /train_code/models/__pycache__/model.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/models/__pycache__/model.cpython-310.pyc -------------------------------------------------------------------------------- /train_code/models/__pycache__/model.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/models/__pycache__/model.cpython-311.pyc -------------------------------------------------------------------------------- /train_code/models/__pycache__/model.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/models/__pycache__/model.cpython-39.pyc -------------------------------------------------------------------------------- /train_code/models/__pycache__/quantize.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/models/__pycache__/quantize.cpython-310.pyc -------------------------------------------------------------------------------- /train_code/models/__pycache__/quantize.cpython-311.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/models/__pycache__/quantize.cpython-311.pyc -------------------------------------------------------------------------------- /train_code/models/__pycache__/reward_model.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/models/__pycache__/reward_model.cpython-310.pyc -------------------------------------------------------------------------------- /train_code/models/__pycache__/reward_model.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/models/__pycache__/reward_model.cpython-311.pyc -------------------------------------------------------------------------------- /train_code/models/__pycache__/rl_model.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/models/__pycache__/rl_model.cpython-310.pyc -------------------------------------------------------------------------------- /train_code/models/__pycache__/rl_model.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/models/__pycache__/rl_model.cpython-311.pyc -------------------------------------------------------------------------------- /train_code/models/__pycache__/tokenizer_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/models/__pycache__/tokenizer_utils.cpython-310.pyc -------------------------------------------------------------------------------- /train_code/models/__pycache__/tokenizer_utils.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/models/__pycache__/tokenizer_utils.cpython-311.pyc -------------------------------------------------------------------------------- /train_code/models/__pycache__/tokenizer_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/models/__pycache__/tokenizer_utils.cpython-38.pyc -------------------------------------------------------------------------------- /train_code/models/__pycache__/tp.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/models/__pycache__/tp.cpython-310.pyc -------------------------------------------------------------------------------- /train_code/models/__pycache__/tp.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/models/__pycache__/tp.cpython-311.pyc -------------------------------------------------------------------------------- 
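The `Frozen*` modules in frozen_layers.py below keep their weights in registered buffers rather than `nn.Parameter`s, so no gradients or optimizer state are ever allocated for them. A minimal sketch of that pattern (illustrative only; `TinyFrozenLinear` is not part of this repository):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyFrozenLinear(nn.Module):
    """The weight lives in a buffer, so it is invisible to optimizers and keeps no grad."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.register_buffer("weight", torch.randn(out_features, in_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.weight)


layer = TinyFrozenLinear(4, 2)
print(list(layer.parameters()))                      # [] -> nothing for an optimizer to update
print([name for name, _ in layer.named_buffers()])   # ['weight']
print(layer(torch.randn(3, 4)).shape)                # torch.Size([3, 2])
```

Because the buffers are excluded from `parameters()`, wrapping frameworks never build gradient or optimizer state for these layers.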
/train_code/models/frozen_layers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The GPT-Accelera Team 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Optional 8 | 9 | import torch 10 | import torch.nn as nn 11 | from torch.nn import functional as F 12 | from torch import Tensor 13 | 14 | 15 | try: 16 | from apex.normalization.fused_layer_norm import FusedRMSNormFunction 17 | 18 | # print( 19 | # "`apex` is installed. You can use fused RMSNorm by set_global_compile_mode(False)." 20 | # ) 21 | except ImportError as e: 22 | FusedRMSNormFunction = None 23 | # print("`apex` is not installed. Reverting to non-fused RMSNorm.") 24 | 25 | # whether to use fused RMSNorm or not (default: no) 26 | _GLOBAL_IN_COMPILE_MODE = True 27 | 28 | 29 | def find_multiple(n: int, k: int) -> int: 30 | if n % k == 0: 31 | return n 32 | return n + k - (n % k) 33 | 34 | 35 | class FrozenEmbedding(nn.Module): 36 | __constants__ = [ 37 | "num_embeddings", 38 | "embedding_dim", 39 | "padding_idx", 40 | "max_norm", 41 | "norm_type", 42 | "scale_grad_by_freq", 43 | "sparse", 44 | ] 45 | 46 | num_embeddings: int 47 | embedding_dim: int 48 | padding_idx: Optional[int] 49 | max_norm: Optional[float] 50 | norm_type: float 51 | scale_grad_by_freq: bool 52 | weight: Tensor 53 | freeze: bool 54 | sparse: bool 55 | 56 | def __init__( 57 | self, 58 | num_embeddings: int, 59 | embedding_dim: int, 60 | device=None, 61 | dtype=None, 62 | ) -> None: 63 | factory_kwargs = {"device": device, "dtype": dtype} 64 | super().__init__() 65 | self.num_embeddings = num_embeddings 66 | self.embedding_dim = embedding_dim 67 | self.padding_idx = None 68 | self.max_norm = None 69 | self.norm_type = 2.0 70 | self.scale_grad_by_freq = False 71 | self.sparse = False 72 | self.vocab_start_index = None 73 | self.vocab_end_index = None 74 | self.num_embeddings_per_partition = None 75 | self.register_buffer( 76 | "weight", torch.empty((num_embeddings, embedding_dim), **factory_kwargs) 77 | ) 78 | 79 | def forward(self, input: Tensor) -> Tensor: 80 | if self.num_embeddings_per_partition is None: 81 | return F.embedding( 82 | input, 83 | self.weight, 84 | self.padding_idx, 85 | self.max_norm, 86 | self.norm_type, 87 | self.scale_grad_by_freq, 88 | self.sparse, 89 | ) 90 | else: 91 | # Build the mask. 92 | print("vocab_start_index", self.vocab_start_index) 93 | print("vocab_end_index", self.vocab_end_index) 94 | input_mask = (input < self.vocab_start_index) | ( 95 | input >= self.vocab_end_index 96 | ) 97 | # Mask the input. 98 | masked_input = input.clone() - self.vocab_start_index 99 | masked_input[input_mask] = 0 100 | # Get the embeddings. 101 | output_parallel = F.embedding( 102 | masked_input, 103 | self.weight, 104 | self.padding_idx, 105 | self.max_norm, 106 | self.norm_type, 107 | self.scale_grad_by_freq, 108 | self.sparse, 109 | ) 110 | # Mask the output embedding. 
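# Rows whose token ids fall outside this partition's vocab range are zeroed,
# so each rank contributes only its local slice of the embedding table
# (assumption: the partial outputs are summed across tensor-parallel ranks by
# the surrounding TP wrapper, as in standard vocab-parallel embeddings).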
111 | output_parallel[input_mask, :] = 0.0 112 | return output_parallel 113 | 114 | def extra_repr(self) -> str: 115 | s = "{num_embeddings}, {embedding_dim}" 116 | if self.padding_idx is not None: 117 | s += ", padding_idx={padding_idx}" 118 | if self.max_norm is not None: 119 | s += ", max_norm={max_norm}" 120 | if self.norm_type != 2.0: 121 | s += ", norm_type={norm_type}" 122 | if self.scale_grad_by_freq is not False: 123 | s += ", scale_grad_by_freq={scale_grad_by_freq}" 124 | if self.sparse is not False: 125 | s += ", sparse=True" 126 | return s.format(**self.__dict__) 127 | 128 | 129 | class FrozenRMSNorm(nn.Module): 130 | def __init__(self, dim: int, eps: float = 1e-5): 131 | super().__init__() 132 | self.eps = eps 133 | self.register_buffer("weight", torch.ones(dim)) 134 | 135 | global _GLOBAL_IN_COMPILE_MODE 136 | self.in_compile_mode = _GLOBAL_IN_COMPILE_MODE 137 | 138 | def _norm(self, x): 139 | return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) 140 | 141 | def forward(self, x: Tensor) -> Tensor: 142 | if self.in_compile_mode or FusedRMSNormFunction is None: 143 | with torch.autocast(device_type="cuda", enabled=False): 144 | output = self._norm(x.float()).to(dtype=x.dtype) 145 | return output * self.weight 146 | else: 147 | with torch.autocast(device_type="cuda", enabled=False): 148 | output = FusedRMSNormFunction.apply( 149 | x, 150 | self.weight.size(), 151 | self.eps, 152 | False, 153 | ) 154 | return output * self.weight 155 | 156 | 157 | class FrozenLinear(nn.Module): 158 | __constants__ = ["in_features", "out_features"] 159 | in_features: int 160 | out_features: int 161 | weight: Tensor 162 | 163 | def __init__( 164 | self, 165 | in_features: int, 166 | out_features: int, 167 | bias: bool = True, 168 | device=None, 169 | dtype=None, 170 | ) -> None: 171 | factory_kwargs = {"device": device, "dtype": dtype} 172 | super().__init__() 173 | self.in_features = in_features 174 | self.out_features = out_features 175 | self.register_buffer( 176 | "weight", torch.empty((out_features, in_features), **factory_kwargs) 177 | ) 178 | if bias: 179 | self.register_buffer("bias", torch.empty((out_features,), **factory_kwargs)) 180 | else: 181 | self.register_buffer("bias", None) 182 | 183 | def forward(self, input: Tensor) -> Tensor: 184 | return F.linear(input, self.weight, self.bias) 185 | 186 | def extra_repr(self) -> str: 187 | return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}" 188 | -------------------------------------------------------------------------------- /train_code/models/reward_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The GPT-Accelera Team 2 | # Copyright 2023 The Self-Align Team 3 | # Copyright 2023 The Alpaca Team 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
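# Defines RewardArgs / RewardModel on top of the Transformer backbone from
# models/model.py, plus helpers that swap the LM head for a 2-way reward head
# and compute the pairwise reward-modeling loss and metrics.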
16 | 17 | from dataclasses import dataclass 18 | import math 19 | from typing import Optional, Dict, Sequence, Union 20 | 21 | import einops 22 | import torch 23 | from torch import Tensor, nn 24 | import torch.nn.functional as F 25 | 26 | from models.model import ModelArgs, Transformer 27 | 28 | 29 | def unpack_dict( 30 | d: Dict, keys: Sequence[str], return_type: type = tuple 31 | ) -> Union[Sequence, Dict]: 32 | if return_type in (tuple, list): 33 | return return_type(d[key] for key in keys) 34 | elif return_type == dict: 35 | return {key: d[key] for key in keys} 36 | else: 37 | raise ValueError(f"Unknown return_type: {return_type}") 38 | 39 | 40 | def batch_select(input: Tensor, index: Tensor): 41 | """Select elements from a batched tensor with a batched index tensor. 42 | 43 | Example: 44 | input = torch.tensor([ 45 | [0, 1, 2], 46 | [3, 0, 9], 47 | [6, 7, 8], 48 | ]) 49 | index = torch.tensor([[0, 1], [1, 0], [0, 0]]) 50 | batch_select(input, index) = tensor([ 51 | [0, 1], 52 | [0, 3], 53 | [6, 6] 54 | ]) 55 | """ 56 | dummy_index = torch.arange(input.size(0), device=input.device).unsqueeze(-1) 57 | return input[dummy_index, index] 58 | 59 | 60 | @dataclass 61 | class RewardArgs: 62 | backbone_args: ModelArgs 63 | 64 | @classmethod 65 | def from_name(cls, name: str): 66 | return cls(backbone_args=ModelArgs.from_name(name)) 67 | 68 | 69 | class RewardModel(nn.Module): 70 | def __init__(self, config: RewardArgs, **kwargs) -> None: 71 | super().__init__() 72 | self.config = config 73 | self.backbone_model = Transformer(config.backbone_args, **kwargs) 74 | 75 | def forward( 76 | self, 77 | idx: Tensor, 78 | eos_pos: Optional[Tensor] = None, 79 | ) -> Tensor: 80 | input_pos = torch.arange(0, idx.size(-1), device=idx.device) 81 | rewards = self.backbone_model(idx, input_pos=input_pos, fully_causal=True) 82 | rewards = rewards.mean(dim=-1) 83 | 84 | if eos_pos is not None: 85 | eos_pos = eos_pos.unsqueeze(-1) 86 | rewards = batch_select(rewards, eos_pos).squeeze(-1) 87 | 88 | return rewards 89 | 90 | @classmethod 91 | def from_name(cls, name: str, **kwargs): 92 | return cls(RewardArgs.from_name(name), **kwargs) 93 | 94 | 95 | def apply_reward_modeling_head( 96 | transformer: Transformer, requires_grad=False, init_sceheme="zeros" 97 | ): 98 | output_module = transformer.output 99 | # Linear's weight matrix is transposed, and is of shape 100 | # (linear.out_features, linear.in_features) 101 | 102 | # Temp fix due to https://github.com/pytorch/pytorch/issues/106951 103 | reward_head_weight = torch.zeros_like(output_module.weight)[:2, :] 104 | if init_sceheme == "zeros": 105 | output_module.weight = nn.Parameter( 106 | reward_head_weight, 107 | requires_grad=requires_grad, 108 | ) 109 | elif init_sceheme == "semantic": 110 | # ['### Preferred Output is '] [835, 4721, 14373, 10604, 338, 29871] 111 | # ['### Preferred Output is 1.'] [835, 4721, 14373, 10604, 338, 29871, 29896, 29889] 112 | # ['### Preferred Output is 2.'] [835, 4721, 14373, 10604, 338, 29871, 29906, 29889] 113 | token_1_id = 29896 114 | token_2_id = 29906 115 | reward_head_weight[0, :] = output_module.weight[token_2_id, :] 116 | reward_head_weight[1, :] = -output_module.weight[token_1_id, :] 117 | output_module.weight = nn.Parameter( 118 | reward_head_weight, 119 | requires_grad=requires_grad, 120 | ) 121 | elif init_sceheme == "random": 122 | generator = torch.Generator(device=reward_head_weight.device) 123 | generator.manual_seed(42) 124 | nn.init.kaiming_uniform_( 125 | reward_head_weight, a=math.sqrt(5), generator=generator 
126 | ) 127 | output_module.weight = nn.Parameter( 128 | reward_head_weight * math.sqrt(2.0), 129 | requires_grad=requires_grad, 130 | ) 131 | else: 132 | raise ValueError(f"Unknown init_scheme: {init_sceheme}") 133 | setattr(output_module, "out_features", 2) 134 | 135 | 136 | def compute_pairwise_reward_modeling_loss(model, inputs, return_outputs=False): 137 | # input_ids, attention_mask each of size (bsz, num_candidates, seq_len). 138 | # index_0, index_1 each of size (bsz, num_pairs); indexes into input_ids. 139 | # choice of size (bsz, num_pairs); 1 if index_1's seq is chosen, 0 otherwise. 140 | input_ids, eos_pos, index_0, index_1, choice = unpack_dict( 141 | inputs, keys=("input_ids", "eos_pos", "index_0", "index_1", "choice") 142 | ) 143 | num_candidates, num_pairs = input_ids.size(1), choice.size(1) 144 | input_ids_flat = einops.rearrange(input_ids, "b c l -> (b c) l") 145 | eos_pos_flat = einops.rearrange(eos_pos, "b c -> (b c)") 146 | input_pos_flat = torch.arange( 147 | 0, input_ids_flat.size(-1), device=input_ids_flat.device 148 | ) 149 | outputs = model( 150 | input_ids=input_ids_flat, 151 | input_pos=input_pos_flat, 152 | eos_pos=eos_pos_flat, 153 | ) 154 | rewards_flat = outputs.rewards 155 | rewards = einops.rearrange( 156 | rewards_flat, "(b c) -> b c", c=num_candidates 157 | ) # Size: (bsz, num_candidates). 158 | 159 | rewards_0, rewards_1 = tuple( 160 | batch_select(rewards, index) for index in (index_0, index_1) 161 | ) # Size: (bsz, num_pairs). 162 | logits = rewards_1 - rewards_0 # Size: (bsz, num_pairs). 163 | # Type casting of `choice` is due to amp.autocast context manager. 164 | loss = F.binary_cross_entropy_with_logits( 165 | logits, choice.to(logits.dtype), reduction="mean" 166 | ) 167 | return (loss, dict(logits=logits)) if return_outputs else loss 168 | 169 | 170 | def compute_pairwise_reward_modeling_metrics( 171 | predictions: torch.Tensor, label_ids: torch.Tensor 172 | ) -> Dict: 173 | # eval_prediction.label_ids is a tuple that matches up with `training_args.label_names`. 
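# `predictions` here are the pairwise logits (reward_1 - reward_0); a logit >= 0 is
# read as preferring candidate 1 and is compared against the stored `choice` labels.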
174 | logits = torch.tensor(predictions).squeeze(-1) 175 | labels = torch.tensor(label_ids[-1]).squeeze(-1) 176 | predictions = (logits >= 0.0).long() 177 | accuracy = predictions.eq(labels).float().mean().item() 178 | label_positive_rate = (labels == 1).float().mean().item() 179 | return dict( 180 | accuracy=accuracy, 181 | label_positive_rate=label_positive_rate, 182 | ) 183 | -------------------------------------------------------------------------------- /train_code/scripts/convert.sh: -------------------------------------------------------------------------------- 1 | # export DATA_DIR=/cfs/hadoop-aipnlp/zengweihao02/modelzoo/Llama-2-70b-hf/Llama-2-70b-hf 2 | # #export MODEL_REPO=ScalableMath/llemma-7b-sft-metamath-level-1to3-hf 3 | 4 | # python convert_hf_checkpoint.py \ 5 | # --checkpoint_dir $DATA_DIR \ 6 | # --target_precision bf16 7 | 8 | 9 | export DATA_DIR=/cfs/hadoop-aipnlp/zengweihao02/modelzoo/Llama-3.1-8B-Instruct 10 | #export MODEL_REPO=ScalableMath/llemma-7b-sft-metamath-level-1to3-hf 11 | 12 | 13 | #export DATA_DIR=/mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/zengweihao02/modelzoo/deepseek-math-7b-base/deepseek-math-7b-base 14 | # python convert_hf_checkpoint.py \ 15 | # --checkpoint_dir $DATA_DIR \ 16 | # --target_precision bf16 17 | #/cfs/hadoop-aipnlp/zengweihao02/b-star/easy-to-hard-main-share/scripts/convert_hf_checkpoint_llama3.py 18 | 19 | python /cfs/hadoop-aipnlp/zengweihao02/b-star/easy-to-hard-main-share/scripts/convert_hf_checkpoint_llama3.py \ 20 | --checkpoint_dir $DATA_DIR \ 21 | --target_precision bf16 22 | 23 | # python /mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/zengweihao02/easy2hard/share/project/weihao/easy-to-hard-main-share/easy-to-hard-main-share/scripts/convert_hf_checkpoint.py \ 24 | # --checkpoint_dir $DATA_DIR \ 25 | # --target_precision bf16 -------------------------------------------------------------------------------- /train_code/scripts/convert_checkpoint_to_hf.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig 2 | from tqdm import tqdm 3 | import torch 4 | import re 5 | import argparse 6 | import os 7 | import glob 8 | 9 | # we need to check that we have login the HF account 10 | # !huggingface-cli whoami 11 | # !huggingface-cli login 12 | 13 | 14 | def load_and_merge_models( 15 | tp_ckpt_name, pretrain_name, tokenizer_name, save_name_hf, push_to_hf_hub_name 16 | ): 17 | assert ( 18 | save_name_hf or push_to_hf_hub_name 19 | ), "Please provide a save path or push to HF hub name" 20 | 21 | tp_model_list = [] 22 | 23 | last_checkpoint_file = os.path.join(tp_ckpt_name, "last_checkpoint") 24 | with open(last_checkpoint_file, "r") as f: 25 | last_checkpoint_file = f.readline().strip() 26 | 27 | last_checkpoint_file = last_checkpoint_file.split("/")[-1] 28 | last_checkpoint_file = os.path.join(tp_ckpt_name, last_checkpoint_file) 29 | 30 | print("Loading checkpoint files:", last_checkpoint_file) 31 | for file in sorted(glob.glob(last_checkpoint_file)): 32 | tp_model_list.append( 33 | torch.load( 34 | file, 35 | mmap=True, 36 | )["model"] 37 | ) 38 | 39 | print("Loading HF model...") 40 | tokenizer = AutoTokenizer.from_pretrained( 41 | tokenizer_name, 42 | ) 43 | 44 | model = AutoModelForCausalLM.from_pretrained( 45 | pretrain_name, 46 | # device_map="cpu", 47 | load_in_8bit=False, 48 | torch_dtype=torch.bfloat16, 49 | ) 50 | cpu_state_dict = model.cpu().state_dict() 51 | 52 | replaced_keys = set() 53 | 54 | print("Convert to 
HF model...") 55 | num_tp = len(tp_model_list) 56 | 57 | state_dict = {} 58 | 59 | for key in tp_model_list[0].keys(): 60 | if "wo" in key or "w2" in key: 61 | state_dict[key] = torch.cat( 62 | [tp_model_list[i][key].cpu() for i in range(num_tp)], dim=1 63 | ) 64 | elif "wqkv" in key: 65 | state_dict[key] = torch.stack( 66 | [tp_model_list[i][key].cpu() for i in range(num_tp)], dim=0 67 | ) 68 | elif "output" in key: 69 | state_dict[key] = torch.cat( 70 | [tp_model_list[i][key].cpu() for i in range(num_tp)], dim=1 71 | ) 72 | else: 73 | state_dict[key] = torch.cat( 74 | [tp_model_list[i][key].cpu() for i in range(num_tp)], dim=0 75 | ) 76 | 77 | pattern = r"layers\.(\d+)\." 78 | 79 | for key in state_dict.keys(): 80 | layer = None 81 | match = re.search(pattern, key) 82 | # layer number except for: 83 | # lm_head.weight 84 | if match: 85 | layer = match.group(1) 86 | elif "output.weight" in key: 87 | name = f"lm_head.weight" 88 | print(cpu_state_dict[name].size(), state_dict[key].size()) 89 | # repeat on dim 0 to match the size 90 | repeat_size = cpu_state_dict[name].size(0) // state_dict[key].size(0) 91 | new_state_dict = state_dict[key].repeat(repeat_size, 1) 92 | cpu_state_dict[name] = 0.0 * cpu_state_dict[name] + new_state_dict 93 | replaced_keys.add(name) 94 | else: 95 | raise ValueError(f"Invalid key: {key}") 96 | 97 | print("Converting layer", key) 98 | if "wqkv" in key: 99 | merged_q, merged_k, merged_v = [], [], [] 100 | reconstruct_q, reconstruct_k = [], [] 101 | 102 | if state_dict[key].size(2) == 4096: 103 | n_heads, n_local_heads = 32, 32 104 | elif state_dict[key].size(2) == 5120: 105 | n_heads, n_local_heads = 40, 40 106 | elif state_dict[key].size(2) == 6656: 107 | n_heads, n_local_heads = 52, 52 108 | elif state_dict[key].size(2) == 8192: 109 | n_heads, n_local_heads = 64, 8 110 | else: 111 | raise ValueError(f"Invalid size for {key}: {state_dict[key].size()}") 112 | 113 | head_dim = state_dict[key].size(1) // (n_heads + n_local_heads * 2) 114 | 115 | weight_splits = [ 116 | head_dim * n_heads, 117 | head_dim * n_local_heads, 118 | head_dim * n_local_heads, 119 | ] 120 | 121 | for split_idx in range(state_dict[key].size(0)): 122 | chunk = state_dict[key][split_idx] 123 | q, k, v = chunk.split(weight_splits, dim=0) 124 | merged_q.append(q) 125 | merged_k.append(k) 126 | merged_v.append(v) 127 | merged_q = torch.cat(merged_q, dim=0) 128 | merged_k = torch.cat(merged_k, dim=0) 129 | merged_v = torch.cat(merged_v, dim=0) 130 | 131 | #### qk need reconstruction #### 132 | split_qs = torch.split(merged_q, split_size_or_sections=128, dim=0) 133 | split_ks = torch.split(merged_k, split_size_or_sections=128, dim=0) 134 | for split in split_qs: 135 | matrix0 = split[::2, :] 136 | matrix1 = split[1::2, :] 137 | reconstruct_q.append(matrix0) 138 | reconstruct_q.append(matrix1) 139 | reconstruct_q = torch.cat(reconstruct_q, dim=0) 140 | for split in split_ks: 141 | matrix0 = split[::2, :] 142 | matrix1 = split[1::2, :] 143 | reconstruct_k.append(matrix0) 144 | reconstruct_k.append(matrix1) 145 | reconstruct_k = torch.cat(reconstruct_k, dim=0) 146 | #### qk need reconstruction #### 147 | 148 | name = f"model.layers.{layer}.self_attn.q_proj.weight" 149 | cpu_state_dict[name] = reconstruct_q 150 | replaced_keys.add(name) 151 | 152 | name = f"model.layers.{layer}.self_attn.k_proj.weight" 153 | cpu_state_dict[name] = reconstruct_k 154 | replaced_keys.add(name) 155 | 156 | name = f"model.layers.{layer}.self_attn.v_proj.weight" 157 | cpu_state_dict[name] = merged_v 158 | 
replaced_keys.add(name) 159 | 160 | if "wo" in key: 161 | name = f"model.layers.{layer}.self_attn.o_proj.weight" 162 | cpu_state_dict[name] = state_dict[key] 163 | replaced_keys.add(name) 164 | if "w1" in key: 165 | name = f"model.layers.{layer}.mlp.gate_proj.weight" 166 | cpu_state_dict[name] = state_dict[key] 167 | replaced_keys.add(name) 168 | if "w3" in key: 169 | name = f"model.layers.{layer}.mlp.up_proj.weight" 170 | cpu_state_dict[name] = state_dict[key] 171 | replaced_keys.add(name) 172 | if "w2" in key: 173 | name = f"model.layers.{layer}.mlp.down_proj.weight" 174 | cpu_state_dict[name] = state_dict[key] 175 | replaced_keys.add(name) 176 | 177 | unreplaced_keys = set(cpu_state_dict.keys()) - replaced_keys 178 | print("Unreplaced keys:", unreplaced_keys) 179 | 180 | print("Loading state dict...") 181 | 182 | model.load_state_dict(cpu_state_dict, strict=False) 183 | 184 | print("Saving HF model...") 185 | 186 | if save_name_hf is not None: 187 | model.save_pretrained(save_name_hf) 188 | config = AutoConfig.from_pretrained(pretrain_name) 189 | tokenizer.save_pretrained(save_name_hf) 190 | config.save_pretrained(save_name_hf) 191 | else: 192 | model.push_to_hub(push_to_hf_hub_name, private=True, safe_serialization=False) 193 | 194 | 195 | if __name__ == "__main__": 196 | parser = argparse.ArgumentParser(description="Process some integers.") 197 | parser.add_argument( 198 | "--tp_ckpt_name", type=str, help="Path to the TP checkpoint name", required=True 199 | ) 200 | parser.add_argument( 201 | "--tokenizer_name", type=str, help="Path to the tokenizer name", required=True 202 | ) 203 | parser.add_argument( 204 | "--pretrain_name", type=str, help="Path to the pretrain name", required=True 205 | ) 206 | parser.add_argument( 207 | "--save_name_hf", type=str, default=None, help="Path to save the HF model" 208 | ) 209 | parser.add_argument( 210 | "--push_to_hf_hub_name", type=str, default=None, help="Push to HF hub" 211 | ) 212 | 213 | args = parser.parse_args() 214 | load_and_merge_models( 215 | args.tp_ckpt_name, 216 | args.pretrain_name, 217 | args.tokenizer_name, 218 | args.save_name_hf, 219 | args.push_to_hf_hub_name, 220 | ) 221 | -------------------------------------------------------------------------------- /train_code/scripts/convert_hf_checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
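# Typical invocation (see scripts/prepare_llemma_7b.sh):
#   python scripts/convert_hf_checkpoint.py \
#       --checkpoint_dir $DATA_DIR/checkpoints/$MODEL_REPO \
#       --target_precision bf16
# Reads pytorch_model.bin.index.json, remaps HF parameter names to the gpt-fast
# layout, fuses the q/k/v projections into a single wqkv tensor, and writes
# model.pth into the checkpoint directory.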
6 | import json 7 | import sys 8 | from pathlib import Path 9 | from typing import Optional 10 | 11 | import torch 12 | import re 13 | 14 | # support running without installing as a package 15 | wd = Path(__file__).parent.parent.resolve() 16 | sys.path.append(str(wd)) 17 | 18 | from models.model import ModelArgs 19 | 20 | 21 | @torch.inference_mode() 22 | def convert_hf_checkpoint( 23 | *, 24 | checkpoint_dir: Path = Path( 25 | "checkpoints/meta-Transformer/Transformer-2-7b-chat-hf" 26 | ), 27 | model_name: Optional[str] = None, 28 | target_precision: str = "fp32", 29 | ) -> None: 30 | if model_name is None: 31 | model_name = checkpoint_dir.name 32 | 33 | config = ModelArgs.from_name(model_name) 34 | print(f"Model config {config.__dict__}") 35 | 36 | # Load the json file containing weight mapping 37 | model_map_json = checkpoint_dir / "pytorch_model.bin.index.json" 38 | 39 | assert model_map_json.is_file() 40 | 41 | with open(model_map_json) as json_map: 42 | bin_index = json.load(json_map) 43 | 44 | weight_map = { 45 | "model.embed_tokens.weight": "tok_embeddings.weight", 46 | "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight", 47 | "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight", 48 | "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight", 49 | "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight", 50 | "model.layers.{}.self_attn.rotary_emb.inv_freq": None, 51 | "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight", 52 | "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight", 53 | "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight", 54 | "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight", 55 | "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight", 56 | "model.norm.weight": "norm.weight", 57 | "lm_head.weight": "output.weight", 58 | } 59 | bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()} 60 | 61 | def permute(w, n_head): 62 | dim = config.dim 63 | return ( 64 | w.view(n_head, 2, config.head_dim // 2, dim) 65 | .transpose(1, 2) 66 | .reshape(config.head_dim * n_head, dim) 67 | ) 68 | 69 | merged_result = {} 70 | for file in sorted(bin_files): 71 | state_dict = torch.load( 72 | str(file), map_location="cpu", mmap=True, weights_only=True 73 | ) 74 | 75 | if target_precision == "fp16": 76 | for key in tuple(state_dict.keys()): 77 | state_dict[key] = state_dict[key].half() 78 | elif target_precision == "bf16": 79 | for key in tuple(state_dict.keys()): 80 | state_dict[key] = state_dict[key].bfloat16() 81 | elif target_precision == "fp32": 82 | pass 83 | else: 84 | raise ValueError(f"Unsupported target_precision {target_precision}") 85 | merged_result.update(state_dict) 86 | final_result = {} 87 | for key, value in merged_result.items(): 88 | if "layers" in key: 89 | abstract_key = re.sub(r"(\d+)", "{}", key) 90 | layer_num = re.search(r"\d+", key).group(0) 91 | new_key = weight_map[abstract_key] 92 | if new_key is None: 93 | continue 94 | new_key = new_key.format(layer_num) 95 | else: 96 | new_key = weight_map[key] 97 | 98 | if len(value.shape) == 2 and value.size(1) == 32016: 99 | value = value[:, :32000] 100 | if value.size(0) == 32016: 101 | value = value[:32000, :] 102 | 103 | final_result[new_key] = value 104 | 105 | for key in tuple(final_result.keys()): 106 | if "wq" in key: 107 | q = final_result[key] 108 | k = final_result[key.replace("wq", "wk")] 109 | v 
= final_result[key.replace("wq", "wv")] 110 | q = permute(q, config.n_head) 111 | k = permute(k, config.n_local_heads) 112 | final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v]) 113 | del final_result[key] 114 | del final_result[key.replace("wq", "wk")] 115 | del final_result[key.replace("wq", "wv")] 116 | print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}") 117 | torch.save(final_result, checkpoint_dir / "model.pth") 118 | 119 | 120 | if __name__ == "__main__": 121 | import argparse 122 | 123 | parser = argparse.ArgumentParser(description="Convert HuggingFace checkpoint.") 124 | parser.add_argument( 125 | "--checkpoint_dir", 126 | type=Path, 127 | default=Path("checkpoints/meta-llama/llama-2-7b-chat-hf"), 128 | ) 129 | parser.add_argument("--model_name", type=str, default=None) 130 | parser.add_argument("--target_precision", type=str, default="fp32") 131 | 132 | args = parser.parse_args() 133 | convert_hf_checkpoint( 134 | checkpoint_dir=args.checkpoint_dir, 135 | model_name=args.model_name, 136 | target_precision=args.target_precision, 137 | ) 138 | -------------------------------------------------------------------------------- /train_code/scripts/convert_hf_checkpoint_llama3.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import json 7 | import sys 8 | from pathlib import Path 9 | from typing import Optional 10 | # from safetensors import load_file 11 | from safetensors.torch import load_file 12 | import torch 13 | import re 14 | 15 | # support running without installing as a package 16 | wd = Path(__file__).parent.parent.resolve() 17 | sys.path.append(str(wd)) 18 | 19 | from models.model import ModelArgs 20 | 21 | 22 | @torch.inference_mode() 23 | def convert_hf_checkpoint( 24 | *, 25 | checkpoint_dir: Path = Path( 26 | "checkpoints/meta-Transformer/Transformer-2-7b-chat-hf" 27 | ), 28 | model_name: Optional[str] = None, 29 | target_precision: str = "fp32", 30 | ) -> None: 31 | if model_name is None: 32 | model_name = checkpoint_dir.name 33 | print(f"Model Name {model_name}") 34 | config = ModelArgs.from_name(model_name) 35 | print(f"Model config {config.__dict__}") 36 | 37 | # Load the json file containing weight mapping 38 | #model_map_json = checkpoint_dir / "pytorch_model.bin.index.json" 39 | 40 | model_map_json = checkpoint_dir / "model.safetensors.index.json" 41 | 42 | assert model_map_json.is_file() 43 | 44 | with open(model_map_json) as json_map: 45 | bin_index = json.load(json_map) 46 | 47 | weight_map = { 48 | "model.embed_tokens.weight": "tok_embeddings.weight", 49 | "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight", 50 | "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight", 51 | "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight", 52 | "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight", 53 | "model.layers.{}.self_attn.rotary_emb.inv_freq": None, 54 | "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight", 55 | "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight", 56 | "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight", 57 | "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight", 58 | 
"model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight", 59 | "model.norm.weight": "norm.weight", 60 | "lm_head.weight": "output.weight", 61 | } 62 | 63 | 64 | # if model_name == "Meta-Llama-3-8B": 65 | # weight_map = { 66 | # "embed_tokens.weight": "tok_embeddings.weight", 67 | # "layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight", 68 | # "layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight", 69 | # "layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight", 70 | # "layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight", 71 | # "layers.{}.self_attn.rotary_emb.inv_freq": None, 72 | # "layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight", 73 | # "layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight", 74 | # "layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight", 75 | # "layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight", 76 | # "layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight", 77 | # "norm.weight": "norm.weight", 78 | # "lm_head.weight": "output.weight", 79 | # } 80 | bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()} 81 | 82 | def permute(w, n_head): 83 | dim = config.dim 84 | return ( 85 | w.view(n_head, 2, config.head_dim // 2, dim) 86 | .transpose(1, 2) 87 | .reshape(config.head_dim * n_head, dim) 88 | ) 89 | 90 | merged_result = {} 91 | 92 | print(bin_files) 93 | # for file in sorted(bin_files): 94 | # state_dict = torch.load( 95 | # str(file), map_location="cpu", mmap=True, weights_only=True 96 | # ) 97 | 98 | # if target_precision == "fp16": 99 | # for key in tuple(state_dict.keys()): 100 | # state_dict[key] = state_dict[key].half() 101 | # elif target_precision == "bf16": 102 | # for key in tuple(state_dict.keys()): 103 | # state_dict[key] = state_dict[key].bfloat16() 104 | # elif target_precision == "fp32": 105 | # pass 106 | # else: 107 | # raise ValueError(f"Unsupported target_precision {target_precision}") 108 | # merged_result.update(state_dict) 109 | 110 | # 假设 bin_files 是存储 .safetensors 文件路径的列表 111 | for file in sorted(bin_files): 112 | # 使用 safetensors 的 load_file 函数来加载权重 113 | state_dict = load_file(str(file)) 114 | 115 | # 根据目标精度进行转换 116 | if target_precision == "fp16": 117 | for key in tuple(state_dict.keys()): 118 | state_dict[key] = state_dict[key].half() 119 | elif target_precision == "bf16": 120 | for key in tuple(state_dict.keys()): 121 | state_dict[key] = state_dict[key].bfloat16() 122 | elif target_precision == "fp32": 123 | pass 124 | else: 125 | raise ValueError(f"Unsupported target_precision {target_precision}") 126 | 127 | # 将加载的权重更新到最终的 merged_result 字典中 128 | merged_result.update(state_dict) 129 | final_result = {} 130 | for key, value in merged_result.items(): 131 | print("key", key) 132 | if "layers" in key: 133 | abstract_key = re.sub(r"(\d+)", "{}", key) 134 | layer_num = re.search(r"\d+", key).group(0) 135 | new_key = weight_map[abstract_key] 136 | if new_key is None: 137 | continue 138 | new_key = new_key.format(layer_num) 139 | else: 140 | new_key = weight_map[key] 141 | 142 | if len(value.shape) == 2 and value.size(1) == 32016: 143 | value = value[:, :32000] 144 | if value.size(0) == 32016: 145 | value = value[:32000, :] 146 | 147 | final_result[new_key] = value 148 | 149 | for key in tuple(final_result.keys()): 150 | if "wq" in key: 151 | q = final_result[key] 152 | k = final_result[key.replace("wq", "wk")] 153 | v = final_result[key.replace("wq", "wv")] 154 | q = 
permute(q, config.n_head) 155 | k = permute(k, config.n_local_heads) 156 | final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v]) 157 | del final_result[key] 158 | del final_result[key.replace("wq", "wk")] 159 | del final_result[key.replace("wq", "wv")] 160 | print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}") 161 | torch.save(final_result, checkpoint_dir / "model.pth") 162 | 163 | 164 | if __name__ == "__main__": 165 | import argparse 166 | 167 | parser = argparse.ArgumentParser(description="Convert HuggingFace checkpoint.") 168 | parser.add_argument( 169 | "--checkpoint_dir", 170 | type=Path, 171 | default=Path("checkpoints/meta-llama/llama-2-7b-chat-hf"), 172 | ) 173 | parser.add_argument("--model_name", type=str, default=None) 174 | parser.add_argument("--target_precision", type=str, default="fp32") 175 | 176 | args = parser.parse_args() 177 | convert_hf_checkpoint( 178 | checkpoint_dir=args.checkpoint_dir, 179 | model_name=args.model_name, 180 | target_precision=args.target_precision, 181 | ) 182 | -------------------------------------------------------------------------------- /train_code/scripts/download.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import os 7 | from requests.exceptions import HTTPError 8 | import sys 9 | from pathlib import Path 10 | from typing import Optional 11 | 12 | 13 | def hf_download( 14 | repo_id: Optional[str] = None, 15 | hf_token: Optional[str] = None, 16 | local_dir: Optional[str] = None, 17 | ) -> None: 18 | from huggingface_hub import snapshot_download 19 | 20 | local_dir = local_dir or "checkpoints" 21 | 22 | os.makedirs(f"{local_dir}/{repo_id}", exist_ok=True) 23 | try: 24 | snapshot_download( 25 | repo_id, 26 | local_dir=f"{local_dir}/{repo_id}", 27 | local_dir_use_symlinks=False, 28 | token=hf_token, 29 | ) 30 | except HTTPError as e: 31 | if e.response.status_code == 401: 32 | print( 33 | "You need to pass a valid `--hf_token=...` to download private checkpoints." 34 | ) 35 | else: 36 | raise e 37 | 38 | 39 | if __name__ == "__main__": 40 | import argparse 41 | 42 | parser = argparse.ArgumentParser(description="Download data from HuggingFace Hub.") 43 | parser.add_argument( 44 | "--repo_id", 45 | type=str, 46 | default="checkpoints/meta-llama/llama-2-7b-chat-hf", 47 | help="Repository ID to download from.", 48 | ) 49 | parser.add_argument( 50 | "--local_dir", type=str, default=None, help="Local directory to download to." 51 | ) 52 | parser.add_argument( 53 | "--hf_token", type=str, default=None, help="HuggingFace API token." 
54 | ) 55 | 56 | args = parser.parse_args() 57 | hf_download(args.repo_id, args.hf_token, args.local_dir) 58 | -------------------------------------------------------------------------------- /train_code/scripts/prepare_ds_math_7b.sh: -------------------------------------------------------------------------------- 1 | set -e 2 | set -x 3 | 4 | export DATA_DIR=/path/to/your/data/directory 5 | export MODEL_REPO=deepseek-ai/deepseek-math-7b-base 6 | 7 | python scripts/download.py \ 8 | --repo_id $MODEL_REPO \ 9 | --local_dir $DATA_DIR/checkpoints 10 | 11 | python scripts/convert_hf_checkpoint.py \ 12 | --checkpoint_dir $DATA_DIR/checkpoints/$MODEL_REPO \ 13 | --target_precision bf16 14 | -------------------------------------------------------------------------------- /train_code/scripts/prepare_llemma_34b.sh: -------------------------------------------------------------------------------- 1 | set -e 2 | set -x 3 | 4 | export DATA_DIR=/path/to/your/data/directory 5 | export MODEL_REPO=EleutherAI/llemma_34b 6 | 7 | python scripts/download.py \ 8 | --repo_id $MODEL_REPO \ 9 | --local_dir $DATA_DIR/checkpoints 10 | 11 | python scripts/convert_hf_checkpoint.py \ 12 | --checkpoint_dir $DATA_DIR/checkpoints/$MODEL_REPO \ 13 | --target_precision bf16 14 | -------------------------------------------------------------------------------- /train_code/scripts/prepare_llemma_7b.sh: -------------------------------------------------------------------------------- 1 | set -e 2 | set -x 3 | 4 | export DATA_DIR=/path/to/your/data/directory 5 | export MODEL_REPO=EleutherAI/llemma_7b 6 | 7 | python scripts/download.py \ 8 | --repo_id $MODEL_REPO \ 9 | --local_dir $DATA_DIR/checkpoints 10 | 11 | python scripts/convert_hf_checkpoint.py \ 12 | --checkpoint_dir $DATA_DIR/checkpoints/$MODEL_REPO \ 13 | --target_precision bf16 14 | -------------------------------------------------------------------------------- /train_code/train_bstar.sh: -------------------------------------------------------------------------------- 1 | export MODEL_REPO= $DATA_DIR/checkpoints/Mistral-7B-v0.1 2 | export OMP_NUM_THREADS=8 3 | 4 | SFT_Midlle_Dir=/yourfolder/bstar_math 5 | SFT_MODEL_SAVE_NAME=math_format_b_star_mistral 6 | NUM_EPOCHS=9 7 | NUM_ITERATIONS=9 8 | NUM_GPUS=8 9 | 10 | BASE_STEP=500 11 | BASE_TARGET_SIZE=135000 12 | GSM8K_TARGET_SIZE=81000 13 | 14 | MATH_TARGET_SIZE=67500 15 | 16 | 17 | 18 | 19 | # Check if the directory exists, if not, create it 20 | if [ ! -d "$SFT_Midlle_Dir" ]; then 21 | mkdir -p "$SFT_Midlle_Dir" 22 | echo "Directory $SFT_Midlle_Dir created." 23 | else 24 | echo "Directory $SFT_Midlle_Dir already exists." 
25 | fi 26 | 27 | LOG_DIR=yourfolder/logs 28 | mkdir -p $LOG_DIR 29 | LOG_FILE=$LOG_DIR/train_bstar_log.txt 30 | 31 | 32 | for ((iter=1; iter<=NUM_ITERATIONS; iter++)) 33 | do 34 | 35 | GEN_STEP=$(((iter-1) * BASE_STEP)) 36 | 37 | sample_num=64 38 | 39 | # You should download following data 40 | 41 | input_data="https://huggingface.co/datasets/AndrewZeng/bstar-math-dev/blob/main/dynamic_ana_1k_withans_math.json" 42 | model_dir="$DATA_DIR/checkpoints/$SFT_MODEL_SAVE_NAME/converted/ckpt$GEN_STEP" 43 | output_dir="$SFT_Midlle_Dir" 44 | tensor_parallel_size=1 45 | top_k=-1 46 | max_tokens=768 47 | wandb_project="vllm_gen" 48 | 49 | # select temperature 50 | start_temp=0.50 51 | end_temp=0.85 52 | temp_step=0.05 53 | 54 | gpu=0 55 | 56 | for temp in $(seq $start_temp $temp_step $end_temp); do 57 | output_file="$output_dir/dynamic_ana_1k_withans_iter${iter}_temp${temp}_sample${sample_num}_math.json" 58 | 59 | # For the first step, the model should be the first epoch checkpoint of SFT model 60 | 61 | # We will determine some configurations on small dev set 62 | CUDA_VISIBLE_DEVICES=$gpu python vllm_infer_auto.py \ 63 | --input_data $input_data \ 64 | --model_dir $model_dir \ 65 | --sample_num $sample_num \ 66 | --output_file $output_file \ 67 | --tensor_parallel_size $tensor_parallel_size \ 68 | --temperature $temp \ 69 | --top_k $top_k \ 70 | --max_tokens $max_tokens 71 | 72 | gpu=$(( (gpu + 1) % 9 )) 73 | done 74 | 75 | wait 76 | 77 | 78 | start_temp=0.9 79 | end_temp=1.2 80 | temp_step=0.05 81 | 82 | 83 | gpu=0 84 | 85 | 86 | for temp in $(seq $start_temp $temp_step $end_temp); do 87 | output_file="$output_dir/dynamic_ana_1k_withans_iter${iter}_temp${temp}_sample${sample_num}_math.json" 88 | 89 | CUDA_VISIBLE_DEVICES=$gpu python vllm_infer_auto.py \ 90 | --input_data $input_data \ 91 | --model_dir $model_dir \ 92 | --sample_num $sample_num \ 93 | --output_file $output_file \ 94 | --tensor_parallel_size $tensor_parallel_size \ 95 | --temperature $temp \ 96 | --top_k $top_k \ 97 | --max_tokens $max_tokens 98 | 99 | 100 | gpu=$(( (gpu + 1) % 9 )) 101 | done 102 | 103 | wait 104 | 105 | 106 | temps=(0.50 0.55 0.60 0.65 0.70 0.75 0.80 0.85 0.90 0.95 1.00 1.05 1.10 1.15 1.20) 107 | sample_nums=(64) 108 | 109 | # We will then use reward model to give reward 110 | 111 | 112 | 113 | for temp in "${temps[@]}"; do 114 | for sample_num in "${sample_nums[@]}"; do 115 | 116 | input_path="$output_dir/dynamic_ana_1k_withans_iter${iter}_temp${temp}_sample${sample_num}_math.json" 117 | output_path="$output_dir/dynamic_ana_1k_withans_infer4reward_iter${iter}_temp${temp}_sample${sample_num}_math.json" 118 | 119 | python convert4reward_auto_ground_sample.py \ 120 | --input_path "$input_path" \ 121 | --output_path "$output_path" \ 122 | --sample_num $sample_num \ 123 | --num_files -1 124 | 125 | prompt_file="$output_path" 126 | reward_output_file="$output_dir/dynamic_ana_1k_withans_allreward_iter${iter}_temp${temp}_sample${sample_num}_math.json" 127 | 128 | torchrun --standalone --nproc_per_node=8 \ 129 | inference_reward_llama3.py \ 130 | --prompt_file "$prompt_file" \ 131 | --output_file "$reward_output_file" \ 132 | --batch_size 200 \ 133 | --process_reward_with_answer \ 134 | --tensor_parallel_size 1 \ 135 | --checkpoint_path $DATA_DIR/checkpoints/Mistral-7B-v0.1/model.pth \ 136 | --finetune_checkpoint_path $DATA_DIR/checkpoints/prm_model_mistral_sample_complete 137 | 138 | done 139 | done 140 | 141 | 142 | 143 | best_combination=$(python determine_hyper.py --temps 0.50 0.55 0.60 0.65 0.70 0.75 0.80 0.85 0.90 
0.95 1.00 1.05 1.10 1.15 1.20 --sample_nums 64 \ 144 | --input_path_template "$output_dir/dynamic_ana_1k_withans_allreward_iter${iter}_temp{temp}_sample{sample_num}_math.json" \ 145 | --ref_path_template "$output_dir/dynamic_ana_1k_withans_infer4reward_iter${iter}_temp{temp}_sample{sample_num}_math.json" \ 146 | --valid_sample_size 3550 \ 147 | --iter $iter) 148 | 149 | 150 | temp=$(echo $best_combination | cut -d' ' -f1) 151 | sample_num=$(echo $best_combination | cut -d' ' -f2) 152 | 153 | 154 | export BEST_MATH_TEMP=$temp 155 | export BEST_MATH_SAMPLE_NUM=$sample_num 156 | 157 | 158 | echo "Best math temp: $BEST_MATH_TEMP" | tee -a $LOG_FILE 159 | echo "Best math sample_num: $BEST_MATH_SAMPLE_NUM" | tee -a $LOG_FILE 160 | 161 | 162 | for ((i=0; i&1 | tee -a $LOG_FILE 183 | 184 | 185 | torchrun --standalone --nproc_per_node=8 \ 186 | inference_reward_llama3.py \ 187 | --prompt_file $SFT_Midlle_Dir/gsm8k_math_format_infer_part_infer4reward_sample32_iter$((iter))_math.json \ 188 | --output_file $SFT_Midlle_Dir/gsm8k_math_format_infer_part_allreward_sample32_iter$((iter))_math.json \ 189 | --batch_size 200 \ 190 | --process_reward_with_answer \ 191 | --tensor_parallel_size 1 \ 192 | --checkpoint_path $DATA_DIR/checkpoints/Mistral-7B-v0.1/model.pth \ 193 | --finetune_checkpoint_path $DATA_DIR/checkpoints/prm_1to5_model_mistral_sample_complete \ 194 | 2>&1 | tee -a $LOG_FILE 195 | 196 | python metric_modiacc_auto.py \ 197 | --input $SFT_Midlle_Dir/gsm8k_math_format_infer_part_allreward_sample32_iter$((iter))_math.json \ 198 | --output $SFT_Midlle_Dir/gsm8k_math_format_infer_iter$((iter))_135k_math.json \ 199 | --ref $SFT_Midlle_Dir/gsm8k_math_format_infer_part_infer4reward_sample32_iter$((iter))_math.json \ 200 | --target_size $MATH_TARGET_SIZE \ 201 | --correct_num 4 \ 202 | --correct_ratio 0.9 \ 203 | 2>&1 | tee -a $LOG_FILE 204 | 205 | SFT_TRAIN_DATA=$SFT_Midlle_Dir/gsm8k_math_format_infer_iter$((iter))_135k_math.json 206 | 207 | END_STEP=$((iter * BASE_STEP)) 208 | 209 | if [ $iter -eq 1 ]; then 210 | torchrun --standalone --nproc_per_node=$NUM_GPUS \ 211 | train_sft_step.py \ 212 | --do_train \ 213 | --checkpoint_path $MODEL_REPO/model.pth \ 214 | --sft_checkpoint_path $DATA_DIR/checkpoints/$SFT_MODEL_SAVE_NAME \ 215 | --source_max_len 768 \ 216 | --target_max_len 768 \ 217 | --total_max_len 1024 \ 218 | --per_device_train_batch_size 16 \ 219 | --micro_train_batch_size 8 \ 220 | --learning_rate 4.12e-6 \ 221 | --lr_eta_min 2e-7 \ 222 | --num_train_epochs $NUM_EPOCHS \ 223 | --dataset "$SFT_TRAIN_DATA" \ 224 | --dataset_format "metamath" \ 225 | --add_eos_to_marked_target \ 226 | --save_strategy "steps" \ 227 | --save_steps 500 \ 228 | --optim_dtype bf16 \ 229 | --save_total_limit 40 \ 230 | --tensor_parallel_size 1 \ 231 | --end_step $END_STEP \ 232 | --save_dir $DATA_DIR/checkpoints/$SFT_MODEL_SAVE_NAME \ 233 | 2>&1 | tee -a $LOG_FILE 234 | else 235 | torchrun --standalone --nproc_per_node=$NUM_GPUS \ 236 | train_sft_step.py \ 237 | --do_train \ 238 | --checkpoint_path $MODEL_REPO/model.pth \ 239 | --sft_checkpoint_path $DATA_DIR/checkpoints/$SFT_MODEL_SAVE_NAME \ 240 | --source_max_len 768 \ 241 | --target_max_len 768 \ 242 | --total_max_len 1024 \ 243 | --per_device_train_batch_size 16 \ 244 | --micro_train_batch_size 8 \ 245 | --learning_rate 4.12e-6 \ 246 | --lr_eta_min 2e-7 \ 247 | --num_train_epochs $NUM_EPOCHS \ 248 | --dataset "$SFT_TRAIN_DATA" \ 249 | --dataset_format "metamath" \ 250 | --add_eos_to_marked_target \ 251 | --save_strategy "steps" \ 252 | --save_steps 500 \ 253 | 
--optim_dtype bf16 \ 254 | --save_total_limit 40 \ 255 | --tensor_parallel_size 1 \ 256 | --end_step $END_STEP \ 257 | --save_dir $DATA_DIR/checkpoints/$SFT_MODEL_SAVE_NAME \ 258 | --resume_from_checkpoint \ 259 | 2>&1 | tee -a $LOG_FILE 260 | fi 261 | 262 | python convert_auto.py \ 263 | --checkpoint_dir $DATA_DIR/checkpoints/$SFT_MODEL_SAVE_NAME \ 264 | --pretrain_name $MODEL_REPO \ 265 | --tokenizer_name $MODEL_REPO \ 266 | 2>&1 | tee -a $LOG_FILE 267 | done 268 | 269 | -------------------------------------------------------------------------------- /train_code/train_reward.sh: -------------------------------------------------------------------------------- 1 | export DATA_DIR=/path/to/your/data/directory 2 | 3 | export MODEL_REPO= $DATA_DIR/checkpoints/Mistral-7B-v0.1 4 | export OMP_NUM_THREADS=4 5 | 6 | 7 | RM_DATA=train_prm_math_shepherd_mistral.json 8 | RM_MODEL_SAVE_NAME=prm_model_mistral_sample_complete 9 | 10 | torchrun --standalone --nproc_per_node=8 \ 11 | train_rm_pointwise.py \ 12 | --do_train \ 13 | --checkpoint_path $MODEL_REPO/model.pth \ 14 | --source_max_len 768 \ 15 | --target_max_len 768 \ 16 | --total_max_len 1024 \ 17 | --per_device_train_batch_size 32 \ 18 | --micro_train_batch_size 32 \ 19 | --learning_rate 2e-6 \ 20 | --lr_eta_min 2e-7 \ 21 | --num_train_epochs 2 \ 22 | --dataset "$RM_DATA" \ 23 | --dataset_format "prm-v4" \ 24 | --save_strategy epoch \ 25 | --save_total_limit 5 \ 26 | --train_on_every_token \ 27 | --tensor_parallel_size 1 \ 28 | --save_only_model True \ 29 | --optim_dtype bf16 \ 30 | --save_dir $DATA_DIR/checkpoints/$RM_MODEL_SAVE_NAME \ 31 | --resume_from_checkpoint -------------------------------------------------------------------------------- /train_code/train_sft.sh: -------------------------------------------------------------------------------- 1 | export DATA_DIR=/path/to/your/data/directory 2 | export MODEL_REPO= $DATA_DIR/checkpoints/Mistral-7B-v0.1 3 | 4 | export OMP_NUM_THREADS=8 5 | 6 | 7 | SFT_TRAIN_DATA=https://huggingface.co/datasets/AndrewZeng/math-trn-format/blob/main/math_format.json 8 | 9 | # Please download this dataset to local folder 10 | SFT_MODEL_SAVE_NAME=math_format_11k_mistral 11 | 12 | torchrun --standalone --nproc_per_node=8 \ 13 | train_sft.py \ 14 | --do_train \ 15 | --checkpoint_path $MODEL_REPO/model.pth \ 16 | --source_max_len 768 \ 17 | --target_max_len 768 \ 18 | --total_max_len 1024 \ 19 | --per_device_train_batch_size 16 \ 20 | --micro_train_batch_size 4 \ 21 | --learning_rate 5e-6 \ 22 | --lr_eta_min 2e-7 \ 23 | --num_train_epochs 3 \ 24 | --dataset "$SFT_TRAIN_DATA" \ 25 | --dataset_format "metamath" \ 26 | --add_eos_to_marked_target \ 27 | --save_strategy "steps" \ 28 | --save_steps 25 \ 29 | --optim_dtype bf16 \ 30 | --save_total_limit 40 \ 31 | --tensor_parallel_size 1 \ 32 | --save_dir $DATA_DIR/checkpoints/$SFT_MODEL_SAVE_NAME \ 33 | --resume_from_checkpoint -------------------------------------------------------------------------------- /train_code/trainers/__pycache__/common_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/trainers/__pycache__/common_utils.cpython-310.pyc -------------------------------------------------------------------------------- /train_code/trainers/__pycache__/common_utils.cpython-311.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/trainers/__pycache__/common_utils.cpython-311.pyc -------------------------------------------------------------------------------- /train_code/trainers/__pycache__/ppo_trainer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/trainers/__pycache__/ppo_trainer.cpython-310.pyc -------------------------------------------------------------------------------- /train_code/trainers/__pycache__/ppo_trainer.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/trainers/__pycache__/ppo_trainer.cpython-311.pyc -------------------------------------------------------------------------------- /train_code/trainers/__pycache__/rl_trainer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/trainers/__pycache__/rl_trainer.cpython-310.pyc -------------------------------------------------------------------------------- /train_code/trainers/__pycache__/rl_trainer.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/trainers/__pycache__/rl_trainer.cpython-311.pyc -------------------------------------------------------------------------------- /train_code/trainers/common_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The GPT-Accelera Team 2 | # Copyright 2023 The Alpaca Team 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
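# Shared helpers for the trainer modules: sequence/dict utilities, tensor padding
# and device placement, seeding, an infinite DataLoader wrapper, and local
# dataset loading.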
15 | 16 | import argparse 17 | from dataclasses import dataclass 18 | import os 19 | import tempfile 20 | import random 21 | from typing import Callable, Dict, Optional, Sequence, Union, Mapping, Any, Tuple 22 | import logging 23 | 24 | import numpy as np 25 | import torch 26 | import torch.nn.functional as F 27 | import torch.distributed as dist 28 | 29 | from torch.utils.data import Dataset 30 | from torch.utils.data import DataLoader 31 | from torch.utils.data import random_split 32 | 33 | from datasets import load_dataset 34 | 35 | Numeric = Union[int, float] 36 | 37 | 38 | def zip_(*args: Sequence): 39 | """Assert sequences of same length before zipping.""" 40 | if len(args) == 0: 41 | return [] 42 | assert alleq(args, lambda x, y: len(x) == len(y)) 43 | return zip(*args) 44 | 45 | 46 | def mean(*seqs: Sequence[Numeric]) -> Union[Numeric, Sequence[Numeric]]: 47 | singleton = len(seqs) == 1 48 | means = [float(np.mean(seq)) for seq in seqs] 49 | return means[0] if singleton else means 50 | 51 | 52 | def alleq(l: Sequence, f: Optional[Callable] = lambda x, y: x == y): 53 | """Check all arguments in a sequence are equal according to a given criterion. 54 | 55 | Args: 56 | f: A bi-variate boolean function. 57 | l: A list/tuple. 58 | 59 | Returns: 60 | True if everything is equal; otherwise False. 61 | """ 62 | return all(f(l[0], li) for li in l[1:]) 63 | 64 | 65 | def flatten_dict(nested, sep=".", postprocess_fn=lambda *args: args): 66 | def rec(nest, prefix, into): 67 | for k, v in nest.items(): 68 | if sep in k: 69 | raise ValueError(f"separator '{sep}' not allowed to be in key '{k}'") 70 | if isinstance(v, dict): # collections.Mapping fails in py3.10. 71 | rec(v, prefix + k + sep, into) 72 | else: 73 | v = postprocess_fn(v) 74 | into[prefix + k] = v 75 | 76 | flat = {} 77 | rec(nested, "", flat) 78 | return flat 79 | 80 | 81 | def unpack_dict( 82 | d: Dict, keys: Sequence[str], return_type: type = tuple 83 | ) -> Union[Sequence, Dict]: 84 | if return_type in (tuple, list): 85 | return return_type(d[key] for key in keys) 86 | elif return_type == dict: 87 | return {key: d[key] for key in keys} 88 | else: 89 | raise ValueError(f"Unknown return_type: {return_type}") 90 | 91 | 92 | def merge_dict(dicts: Sequence[dict], merge_fn: Callable = lambda *args: args) -> dict: 93 | """Merge a sequence of dicts (with the same set of keys) into a single dict.""" 94 | if len(dicts) == 0: 95 | return dict() 96 | return {key: merge_fn([dict_[key] for dict_ in dicts]) for key in dicts[0].keys()} 97 | 98 | 99 | def prepare_inputs( 100 | data: Union[torch.Tensor, Any], device: Union[str, int, torch.device] 101 | ) -> Union[torch.Tensor, Any]: 102 | if isinstance(data, Mapping): 103 | return type(data)( 104 | {k: prepare_inputs(v, device) for k, v in data.items()} 105 | ) # noqa 106 | elif isinstance(data, (tuple, list)): 107 | return type(data)(prepare_inputs(v, device) for v in data) 108 | elif isinstance(data, torch.Tensor): 109 | return data.to(device) # This can break with deepspeed. 
110 | return data 111 | 112 | 113 | def pad_inputs_on_batch( 114 | data: Sequence[torch.Tensor], per_device_batch_size: int 115 | ) -> Sequence[torch.Tensor]: 116 | batch_size = None 117 | output_tensors = [] 118 | for tensor in data: 119 | if batch_size is None: 120 | batch_size = tensor.size(0) 121 | assert tensor.size(0) == batch_size 122 | 123 | if batch_size % per_device_batch_size != 0: 124 | filled_size = per_device_batch_size - (batch_size % per_device_batch_size) 125 | tensor = torch.cat( 126 | [ 127 | tensor, 128 | tensor[0:1].expand(filled_size, *tensor.size()[1:]), 129 | ], 130 | dim=0, 131 | ) 132 | output_tensors.append(tensor) 133 | return output_tensors 134 | 135 | 136 | def pad( 137 | inputs: torch.Tensor, 138 | target_size: Union[torch.Size, Sequence[int]], 139 | value=0.0, 140 | left=True, 141 | ): 142 | current_size = inputs.size() 143 | diffs = tuple(ti - ci for ti, ci in zip_(target_size, current_size)) 144 | pad_params = [] 145 | for diff in diffs: 146 | pad_params = ([diff, 0] if left else [0, diff]) + pad_params 147 | res = F.pad(inputs, pad=pad_params, value=value) 148 | return res 149 | 150 | 151 | def left_pad( 152 | inputs: torch.Tensor, target_size: Union[torch.Size, Sequence[int]], value=0.0 153 | ): 154 | return pad(inputs=inputs, target_size=target_size, value=value, left=True) 155 | 156 | 157 | def right_pad( 158 | inputs: torch.Tensor, target_size: Union[torch.Size, Sequence[int]], value=0.0 159 | ): 160 | return pad(inputs=inputs, target_size=target_size, value=value, left=False) 161 | 162 | 163 | def manual_seed(args_or_seed: Union[int, argparse.Namespace], fix_cudnn=False): 164 | if hasattr(args_or_seed, "seed"): 165 | args_or_seed = args_or_seed.seed 166 | random.seed(args_or_seed) 167 | np.random.seed(args_or_seed) 168 | torch.manual_seed(args_or_seed) 169 | torch.cuda.manual_seed_all(args_or_seed) 170 | os.environ["PYTHONHASHSEED"] = str(args_or_seed) 171 | if fix_cudnn: 172 | torch.backends.cudnn.deterministic = True # noqa 173 | torch.backends.cudnn.benchmark = False # noqa 174 | 175 | 176 | class InfiniteLoader(object): 177 | """Wraps an existing DataLoader so that it outputs stuff indefinitely; useful for semi-supervised learning and DDP.""" 178 | 179 | def __init__(self, loader: DataLoader): 180 | super(InfiniteLoader, self).__init__() 181 | self.loader = loader 182 | self.data_iterator = iter(loader) 183 | self.epoch = 0 184 | 185 | def __iter__(self): 186 | return self 187 | 188 | def __next__(self): 189 | try: 190 | data = next(self.data_iterator) 191 | except StopIteration: 192 | # Increment the epoch count 193 | self.epoch += 1 194 | 195 | # If using Distributed Data Parallel, set the epoch for the sampler 196 | if dist.is_initialized(): 197 | self.loader.sampler.set_epoch(self.epoch) 198 | 199 | # Create a new iterator for the next epoch 200 | self.data_iterator = iter(self.loader) 201 | data = next(self.data_iterator) 202 | 203 | return data 204 | 205 | 206 | class DisableLogger: 207 | def __enter__(self): 208 | logging.disable(logging.CRITICAL) 209 | 210 | def __exit__(self, exit_type, exit_value, exit_traceback): 211 | logging.disable(logging.NOTSET) 212 | 213 | 214 | @dataclass 215 | class DataCollatorForStackableDataset(object): 216 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 217 | return { 218 | key: torch.stack([instance[key] for instance in instances]) 219 | for key in instances[0].keys() 220 | } 221 | 222 | 223 | def local_dataset(dataset_name): 224 | if dataset_name.endswith(".json"): 225 | 
full_dataset = load_dataset( 226 | "json", 227 | data_files=dataset_name, 228 | cache_dir=os.path.join( 229 | tempfile.gettempdir(), f"{os.getuid()}_cache", "huggingface", "datasets" 230 | ), 231 | ) 232 | else: 233 | raise ValueError(f"Unsupported dataset format: {dataset_name}") 234 | 235 | return full_dataset 236 | 237 | 238 | def _get_generator(seed: int) -> torch.Generator: 239 | rng = torch.Generator() 240 | rng.manual_seed(seed) 241 | return rng 242 | 243 | 244 | def split_train_into_train_and_eval( 245 | train_dataset: Dataset, eval_size: int, seed: int 246 | ) -> Tuple[Dataset, Dataset]: 247 | assert eval_size < len( 248 | train_dataset # noqa 249 | ), "Requested eval_size cannot be equal/larger than original train data size." 250 | new_train_size = len(train_dataset) - eval_size # noqa 251 | train_dataset, eval_dataset = random_split( 252 | train_dataset, [new_train_size, eval_size], generator=_get_generator(seed) 253 | ) 254 | return train_dataset, eval_dataset 255 | -------------------------------------------------------------------------------- /train_code/training_utils/__pycache__/fsdp_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/training_utils/__pycache__/fsdp_utils.cpython-310.pyc -------------------------------------------------------------------------------- /train_code/training_utils/__pycache__/fsdp_utils.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/training_utils/__pycache__/fsdp_utils.cpython-311.pyc -------------------------------------------------------------------------------- /train_code/training_utils/__pycache__/memory_efficient_adam.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/training_utils/__pycache__/memory_efficient_adam.cpython-310.pyc -------------------------------------------------------------------------------- /train_code/training_utils/__pycache__/memory_efficient_adam.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/training_utils/__pycache__/memory_efficient_adam.cpython-311.pyc -------------------------------------------------------------------------------- /train_code/training_utils/__pycache__/trainer_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/training_utils/__pycache__/trainer_utils.cpython-310.pyc -------------------------------------------------------------------------------- /train_code/training_utils/__pycache__/trainer_utils.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkust-nlp/B-STaR/a0f5d95eaba692327be71bb57716f621024af4ec/train_code/training_utils/__pycache__/trainer_utils.cpython-311.pyc -------------------------------------------------------------------------------- /train_code/training_utils/memory_efficient_adam.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2024 The GPT-Accelera Team 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | from collections import defaultdict 10 | from copy import deepcopy 11 | from itertools import chain 12 | from typing import Any, DefaultDict, Dict, Iterable 13 | 14 | import torch 15 | from torch.optim import Optimizer 16 | from torch.optim.optimizer import StateDict 17 | 18 | 19 | class MemoryEfficientAdamW(Optimizer): 20 | """ 21 | Arguments: 22 | model_params (iterable): iterable of parameters of dicts defining 23 | parameter groups. 24 | lr (float, optional): learning rate. (default: 1e-3) 25 | betas (Tuple[float, float], optional): coefficients used for computing 26 | running averages of gradient and its square. (default: (0.9, 0.999)) 27 | eps (float, optional): term added to the denominator to improve 28 | numerical stability. (default: 1e-8) 29 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 30 | adamw_mode (boolean, optional): Apply L2 regularization or weight decay 31 | True for decoupled weight decay(also known as AdamW) (default: True) 32 | 33 | .. _Adam\: A Method for Stochastic Optimization: 34 | https://arxiv.org/abs/1412.6980 35 | .. _On the Convergence of Adam and Beyond: 36 | https://openreview.net/forum?id=ryQu7f-RZ 37 | """ 38 | 39 | def __init__( 40 | self, 41 | model_params, 42 | lr=1e-3, 43 | betas=(0.9, 0.999), 44 | eps=1e-8, 45 | weight_decay=0, 46 | adamw_mode=True, 47 | optim_dtype=torch.bfloat16, 48 | optim_device=torch.device("cpu"), 49 | ): 50 | default_args = dict( 51 | lr=lr, 52 | betas=betas, 53 | eps=eps, 54 | weight_decay=weight_decay, 55 | ) 56 | super(MemoryEfficientAdamW, self).__init__(model_params, default_args) 57 | self.adamw_mode = adamw_mode 58 | self.optim_dtype = optim_dtype 59 | self.optim_device = optim_device 60 | 61 | def torch_adam_update_cpu( 62 | self, 63 | data, 64 | grad, 65 | exp_avg, 66 | exp_avg_sq, 67 | lr, 68 | beta1, 69 | beta2, 70 | eps, 71 | weight_decay, 72 | bias_correction1, 73 | bias_correction2, 74 | use_adamw=False, 75 | ): 76 | assert data.dtype == grad.dtype 77 | if weight_decay != 0: 78 | if use_adamw: 79 | data.mul_(1 - lr * weight_decay) 80 | else: 81 | grad = grad.add(data, alpha=weight_decay) 82 | 83 | non_blocking = self.optim_device.type == "cpu" 84 | 85 | exp_avg_cuda, exp_avg_sq_cuda = ( 86 | exp_avg.to(data.device, non_blocking=non_blocking), 87 | exp_avg_sq.to(data.device, non_blocking=non_blocking), 88 | ) 89 | 90 | dtype_grad = grad.to(dtype=self.optim_dtype) 91 | exp_avg_cuda.mul_(beta1).add_(dtype_grad, alpha=1 - beta1) 92 | exp_avg_sq_cuda.mul_(beta2).addcmul_(dtype_grad, dtype_grad, value=1 - beta2) 93 | denom_cuda = (exp_avg_sq_cuda.sqrt() / math.sqrt(bias_correction2)).add_(eps) 94 | 95 | step_size = lr / bias_correction1 96 | data.addcdiv_( 97 | exp_avg_cuda.to(dtype=data.dtype), 98 | denom_cuda.to(dtype=data.dtype), 99 | value=-step_size, 100 | ) 101 | 102 | # Write back to cpu 103 | exp_avg.copy_(exp_avg_cuda, non_blocking=non_blocking) 104 | exp_avg_sq.copy_(exp_avg_sq_cuda, non_blocking=non_blocking) 105 | 106 | @torch.no_grad() 107 | def step(self, closure=None): 108 | loss = None 109 | if closure is not None: 110 | with torch.enable_grad(): 111 | loss = closure() 112 | 113 | for _, group in enumerate(self.param_groups): 114 | for _, p in enumerate(group["params"]): 
115 | if p.grad is None: 116 | continue 117 | 118 | state = self.state[p] 119 | assert ( 120 | p.device.type == "cuda" 121 | ), f"PinMemoryCPUAdam assume all parameters are on cuda" 122 | if len(state) == 0: 123 | state["step"] = 0 124 | # gradient momentums 125 | state["exp_avg"] = torch.zeros_like( 126 | p, 127 | device=self.optim_device, 128 | dtype=self.optim_dtype, 129 | ) 130 | # gradient variances 131 | state["exp_avg_sq"] = torch.zeros_like( 132 | p, 133 | device=self.optim_device, 134 | dtype=self.optim_dtype, 135 | ) 136 | if self.optim_device.type == "cpu": 137 | state["exp_avg"] = state["exp_avg"].pin_memory() 138 | state["exp_avg_sq"] = state["exp_avg_sq"].pin_memory() 139 | 140 | state["step"] += 1 141 | beta1, beta2 = group["betas"] 142 | 143 | assert ( 144 | p.data.numel() == p.grad.data.numel() 145 | ), "parameter and gradient should have the same size" 146 | assert ( 147 | state["exp_avg"].device.type == self.optim_device.type 148 | ), f"exp_avg should stay on {self.optim_device.type}" 149 | assert ( 150 | state["exp_avg_sq"].device.type == self.optim_device.type 151 | ), f"exp_avg should stay on {self.optim_device.type}" 152 | bias_correction1 = 1 - beta1 ** state["step"] 153 | bias_correction2 = 1 - beta2 ** state["step"] 154 | self.torch_adam_update_cpu( 155 | p.data, 156 | p.grad.data, 157 | state["exp_avg"], 158 | state["exp_avg_sq"], 159 | group["lr"], 160 | beta1, 161 | beta2, 162 | group["eps"], 163 | group["weight_decay"], 164 | bias_correction1, 165 | bias_correction2, 166 | self.adamw_mode, 167 | ) 168 | return loss 169 | 170 | @torch._disable_dynamo 171 | def load_state_dict(self, state_dict: StateDict) -> None: 172 | r"""Loads the optimizer state. 173 | 174 | Args: 175 | state_dict (dict): optimizer state. Should be an object returned 176 | from a call to :meth:`state_dict`. 
177 | """ 178 | # shallow copy, to be consistent with module API 179 | state_dict = state_dict.copy() 180 | 181 | for pre_hook in self._optimizer_load_state_dict_pre_hooks.values(): 182 | hook_result = pre_hook(self, state_dict) 183 | if hook_result is not None: 184 | state_dict = hook_result 185 | 186 | # Validate the state_dict 187 | groups = self.param_groups 188 | 189 | # Deepcopy as we write into saved_groups later to update state 190 | saved_groups = deepcopy(state_dict["param_groups"]) 191 | 192 | if len(groups) != len(saved_groups): 193 | raise ValueError( 194 | "loaded state dict has a different number of " "parameter groups" 195 | ) 196 | param_lens = (len(g["params"]) for g in groups) 197 | saved_lens = (len(g["params"]) for g in saved_groups) 198 | if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): 199 | raise ValueError( 200 | "loaded state dict contains a parameter group " 201 | "that doesn't match the size of optimizer's group" 202 | ) 203 | 204 | # Update the state 205 | id_map = dict( 206 | zip( 207 | chain.from_iterable(g["params"] for g in saved_groups), 208 | chain.from_iterable(g["params"] for g in groups), 209 | ) 210 | ) 211 | 212 | def _cast(param, value, param_id=None, param_groups=None, key=None): 213 | r"""Make a deep copy of value, casting all tensors to device of param.""" 214 | if isinstance(value, torch.Tensor): 215 | if param.is_floating_point(): 216 | casted_value = value.to( 217 | dtype=self.optim_dtype, device=self.optim_device 218 | ) 219 | if self.optim_device.type == "cpu": 220 | casted_value = casted_value.pin_memory() 221 | else: 222 | casted_value = Optimizer._process_value_according_to_param_policy( 223 | param, value, param_id, param_groups, key 224 | ) 225 | return casted_value 226 | elif isinstance(value, dict): 227 | return { 228 | k: _cast( 229 | param, v, param_id=param_id, param_groups=param_groups, key=k 230 | ) 231 | for k, v in value.items() 232 | } 233 | elif isinstance(value, Iterable): 234 | return type(value)(_cast(param, v, param_id=param_id, param_groups=param_groups) for v in value) # type: ignore[call-arg] 235 | else: 236 | return value 237 | 238 | # Copy state assigned to params (and cast tensors to appropriate types). 239 | # State that is not assigned to params is copied as is (needed for 240 | # backward compatibility). 241 | state: DefaultDict[torch.Tensor, Dict[Any, Any]] = defaultdict(dict) 242 | for k, v in state_dict["state"].items(): 243 | if k in id_map: 244 | param = id_map[k] 245 | state[param] = _cast( 246 | param, v, param_id=k, param_groups=state_dict["param_groups"] 247 | ) 248 | else: 249 | state[k] = v 250 | 251 | # Update parameter groups, setting their 'params' value 252 | def update_group( 253 | group: Dict[str, Any], new_group: Dict[str, Any] 254 | ) -> Dict[str, Any]: 255 | new_group["params"] = group["params"] 256 | return new_group 257 | 258 | param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)] 259 | self.__setstate__({"state": state, "param_groups": param_groups}) 260 | 261 | for post_hook in self._optimizer_load_state_dict_post_hooks.values(): 262 | post_hook(self) 263 | -------------------------------------------------------------------------------- /train_code/training_utils/trainer_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The GPT-Accelera Team 2 | # Copyright 2023 The Alpaca Team 3 | # Copyright 2022 The HuggingFace Team. All rights reserved. 
4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import math 18 | from functools import partial 19 | 20 | import torch 21 | import torch.nn as nn 22 | 23 | import torch.optim as optim 24 | from torch.optim.lr_scheduler import LambdaLR 25 | 26 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 27 | from torch.distributed.fsdp import MixedPrecision 28 | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy 29 | 30 | from training_utils.memory_efficient_adam import MemoryEfficientAdamW 31 | from arguments import Arguments 32 | 33 | from models.model import TransformerBlock 34 | from models.tp import get_data_parallel_group, get_data_parallel_world_size 35 | 36 | 37 | def create_optimizer( 38 | args: Arguments, 39 | model: nn.Module, 40 | optimizer_cpu_offload: bool = False, 41 | model_cpu_offload: bool = False, 42 | ) -> optim.Optimizer: 43 | if not model_cpu_offload: 44 | model_device = next(iter(model.parameters())).device 45 | 46 | optimizer = MemoryEfficientAdamW( 47 | [p for p in model.parameters() if p.requires_grad], 48 | lr=args.learning_rate, 49 | betas=(args.adam_beta1, args.adam_beta2), 50 | eps=args.adam_eps, 51 | weight_decay=args.weight_decay, 52 | optim_dtype=args.optim_dtype, 53 | optim_device=( 54 | torch.device("cpu") if optimizer_cpu_offload else model_device 55 | ), 56 | ) 57 | else: 58 | optimizer = torch.optim.AdamW( 59 | [p for p in model.parameters() if p.requires_grad], 60 | lr=args.learning_rate, 61 | betas=(args.adam_beta1, args.adam_beta2), 62 | eps=args.adam_eps, 63 | weight_decay=args.weight_decay, 64 | fused=True, 65 | ) 66 | 67 | return optimizer 68 | 69 | 70 | def create_fsdp_model_for_finetune( 71 | args: Arguments, 72 | model: nn.Module, 73 | bf16_all_reduce_upper_bound: int = 16, 74 | ) -> FSDP: 75 | model = FSDP( 76 | module=model, 77 | process_group=get_data_parallel_group(), 78 | auto_wrap_policy=partial( 79 | transformer_auto_wrap_policy, 80 | transformer_layer_cls={ 81 | TransformerBlock, 82 | }, 83 | ), 84 | mixed_precision=MixedPrecision( 85 | param_dtype=args.compute_dtype, 86 | reduce_dtype=( 87 | torch.float32 88 | if get_data_parallel_world_size() >= bf16_all_reduce_upper_bound 89 | else args.compute_dtype 90 | ), 91 | keep_low_precision_grads=(args.optim_dtype != torch.float32), 92 | buffer_dtype=args.compute_dtype, 93 | ), 94 | cpu_offload=False, 95 | use_orig_params=False, 96 | forward_prefetch=True, 97 | limit_all_gathers=True, 98 | ) 99 | return model 100 | 101 | 102 | # https://github.com/huggingface/transformers/blob/976189a6df796a2ff442dd81b022626c840d8c27/src/transformers/optimization.py 103 | def _get_cosine_schedule_with_warmup_lr_lambda( 104 | current_step: int, 105 | *, 106 | num_warmup_steps: int, 107 | num_training_steps: int, 108 | warmup_start_ratio: float, 109 | eta_min_ratio: float, 110 | ): 111 | if current_step < num_warmup_steps: 112 | return warmup_start_ratio + (1.0 - warmup_start_ratio) * float( 113 | current_step 114 | ) / float(max(1, 
num_warmup_steps)) 115 | 116 | progress = float(current_step - num_warmup_steps) / float( 117 | max(1, num_training_steps - num_warmup_steps) 118 | ) 119 | return eta_min_ratio + (1.0 - eta_min_ratio) * max( 120 | 0.0, 0.5 * (1.0 + math.cos(math.pi * progress)) 121 | ) 122 | 123 | 124 | def get_cosine_schedule_with_warmup( 125 | optimizer: optim.Optimizer, 126 | warmup_epochs: int, 127 | max_epochs: int, 128 | warmup_start_ratio: float = 0.0, 129 | eta_min_ratio: float = 0.0, 130 | last_epoch: int = -1, 131 | ): 132 | """ 133 | Create a schedule with a learning rate that decreases following the values of the cosine function between the 134 | initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the 135 | initial lr set in the optimizer. 136 | 137 | Return: 138 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 139 | """ 140 | 141 | assert 0.0 <= warmup_start_ratio <= 1.0, "warmup_start_ratio should be in [0, 1]" 142 | assert 0.0 <= eta_min_ratio <= 1.0, "eta_min_ratio should be in [0, 1]" 143 | 144 | lr_lambda = partial( 145 | _get_cosine_schedule_with_warmup_lr_lambda, 146 | num_warmup_steps=warmup_epochs, 147 | num_training_steps=max_epochs, 148 | warmup_start_ratio=warmup_start_ratio, 149 | eta_min_ratio=eta_min_ratio, 150 | ) 151 | return LambdaLR(optimizer, lr_lambda, last_epoch) 152 | 153 | 154 | def _get_constant_cosine_schedule_lr_lambda( 155 | current_step: int, 156 | *, 157 | num_training_steps: int, 158 | constant_lr: float, 159 | eta_min_ratio: float, 160 | constant_ratio: float, 161 | ): 162 | """ 163 | Implements a learning rate schedule with a constant learning rate for the specified proportion 164 | of the training steps and a cosine annealing schedule for the remaining steps. 165 | 166 | Args: 167 | current_step (int): The current training step. 168 | num_training_steps (int): The total number of training steps. 169 | constant_lr (float): The constant learning rate for the specified proportion of the training steps. 170 | eta_min_ratio (float): The minimum learning rate ratio at the end of the cosine annealing schedule. 171 | constant_ratio (float): The proportion of the training steps to keep the learning rate constant. 172 | 173 | Returns: 174 | float: The learning rate multiplier. 175 | """ 176 | constant_steps = int(num_training_steps * constant_ratio) 177 | 178 | 179 | if current_step < constant_steps: 180 | return constant_lr 181 | 182 | #progress = float(current_step - constant_steps) / float(num_training_steps - constant_steps) 183 | progress = float(current_step - constant_steps) / float( 184 | max(1, num_training_steps - constant_steps) 185 | ) 186 | return eta_min_ratio + (constant_lr - eta_min_ratio) * 0.5 * (1.0 + math.cos(math.pi * progress)) 187 | 188 | def get_constant_cosine_schedule_with_warmup( 189 | optimizer: torch.optim.Optimizer, 190 | max_epochs: int, 191 | constant_lr: float = 1.0, 192 | eta_min_ratio: float = 0.0, 193 | constant_ratio: float = 0.67, 194 | last_epoch: int = -1, 195 | ): 196 | """ 197 | Create a schedule with a learning rate that is constant for a specified proportion of the training steps, 198 | then decreases following the values of the cosine function for the remaining steps. 199 | 200 | Args: 201 | optimizer (torch.optim.Optimizer): The optimizer for which to schedule the learning rate. 202 | max_epochs (int): The total number of training epochs. 
203 | constant_lr (float, optional): The constant learning rate for the specified proportion of the training steps. Default is 1.0. 204 | eta_min_ratio (float, optional): The minimum learning rate ratio at the end of the cosine annealing schedule. Default is 0.0. 205 | constant_ratio (float, optional): The proportion of the training steps to keep the learning rate constant. Default is 0.67. 206 | last_epoch (int, optional): The index of the last epoch when resuming training. Default is -1. 207 | 208 | Returns: 209 | torch.optim.lr_scheduler.LambdaLR: The learning rate scheduler. 210 | """ 211 | lr_lambda = partial( 212 | _get_constant_cosine_schedule_lr_lambda, 213 | num_training_steps=max_epochs, 214 | constant_lr=constant_lr, 215 | eta_min_ratio=eta_min_ratio, 216 | constant_ratio=constant_ratio, 217 | ) 218 | return LambdaLR(optimizer, lr_lambda, last_epoch) -------------------------------------------------------------------------------- /train_code/vllm_infer.py: -------------------------------------------------------------------------------- 1 | # Run self-consistency sampling over the data selected by the diversity filter 2 | 3 | from vllm import LLM, SamplingParams 4 | 5 | from transformers import AutoModelForCausalLM, AutoTokenizer 6 | import json 7 | import random 8 | 9 | import re 10 | import argparse 11 | 12 | def sample_resp(args): 13 | with open(args.input_data, "r") as r: 14 | data_json = json.load(r) 15 | sample_data = data_json 16 | 17 | 18 | data_dict = {} 19 | 20 | 21 | for item in sample_data: 22 | data_dict[item["input"]] = item 23 | 24 | prompt_template = ''' 25 | 26 | "Below is an instruction that describes a task. " "Write a response that appropriately completes the request.\n\n" "### Instruction:\n{instruction}\n\n### Response: Let's think step by step." 27 | 28 | ''' 29 | llm = LLM(model=args.model_dir, tensor_parallel_size=args.tensor_parallel_size) 30 | stop_tokens = ["\nQUESTION:\n", "Question:", "Question", "USER:", "USER", "ASSISTANT:", "ASSISTANT", "Instruction:", "Instruction", "Response:", "Response", "# your code here", "QUESTION", "# Your code goes here", "# Write your code", "\n\n\n\n", "<|end_of_text|>", "\n\nSolved"] 31 | sampling_params = SamplingParams(max_tokens=args.max_tokens, temperature=args.temperature, top_k=args.top_k, n=args.sample_num, stop=stop_tokens) 32 | all_input = [] 33 | 34 | for item in sample_data: 35 | #user_input = prompt_template.replace("{instruction}", item) 36 | all_input.append(item["input"]) 37 | outputs = llm.generate(all_input, sampling_params) 38 | 39 | all_output = [] 40 | all_prompt = [] 41 | 42 | all_json = [] 43 | 44 | for output in outputs: 45 | 46 | temp_json = data_dict[output.prompt] 47 | all_prompt.append(output.prompt) 48 | 49 | temp_json["prompt"] = output.prompt 50 | 51 | for i in range(args.sample_num): 52 | temp_json["output"+str(i)] = output.outputs[i].text 53 | 54 | all_json.append(temp_json) 55 | 56 | with open(args.output_file, "w") as w: 57 | json.dump(all_json, w) 58 | 59 | def parse_args(): 60 | parser = argparse.ArgumentParser() 61 | parser.add_argument("--input_data", type=str) # path to the input data file 62 | parser.add_argument("--model_dir", type=str) # model directory 63 | parser.add_argument("--sample_num", type=int, default=10) # number of samples per prompt 64 | parser.add_argument("--temperature", type=float, default=1.0) # sampling temperature 65 | parser.add_argument("--top_k", type=int, default=20) # top-k sampling 66 | parser.add_argument("--max_tokens", type=int, default=768) # maximum number of generated tokens 67 | parser.add_argument("--output_file", type=str) # output file path 68 |
parser.add_argument("--tensor_parallel_size", type=int, default=1) # tensor_parallel_size 69 | return parser.parse_args() 70 | 71 | if __name__ == "__main__": 72 | args = parse_args() 73 | 74 | sample_resp(args=args) 75 | 76 | -------------------------------------------------------------------------------- /train_code/vllm_infer_arc.py: -------------------------------------------------------------------------------- 1 | import io 2 | import json 3 | import logging 4 | import math 5 | import random 6 | import numpy as np 7 | import os 8 | import pprint 9 | import sys 10 | import time 11 | import transformers 12 | import torch 13 | 14 | from datasets import load_dataset 15 | 16 | from datetime import datetime, date 17 | from tqdm import tqdm 18 | from vllm import LLM, SamplingParams 19 | from ray_vllm import LLMRayWrapper 20 | from string import ascii_uppercase 21 | import re 22 | 23 | def load_jsonl(file_path): 24 | with open(file_path, "r") as f: 25 | return [json.loads(line) for line in f] 26 | 27 | def construct_prompt(template, data): 28 | choices = data['choices']["text"] 29 | question = data['question'] + "\n" 30 | i = 0 31 | for choice in choices: 32 | question += f"({ascii_uppercase[i]}) {choice}\n" 33 | i += 1 34 | question = question.strip() 35 | return template.replace("{{question}}", question) 36 | 37 | 38 | def extract_answer(response): 39 | pattern_list = [r"Final answer:\s*[\[\(]?([A-Za-z])[\]\)]?"] 40 | for pattern in pattern_list: 41 | match = re.search(pattern, response) 42 | if match: 43 | return match.group(1) 44 | return None 45 | 46 | 47 | def main(args): 48 | 49 | argsdict = vars(args) 50 | print(pprint.pformat(argsdict)) 51 | 52 | if args.input_file is not None: 53 | problems = load_jsonl(args.input_file) 54 | problems = problems[args.start:args.end] 55 | print("Loading problems from ", min(len(problems),args.end)) 56 | print("Number of problems: ", len(problems)) 57 | else: 58 | problems = load_dataset("allenai/ai2_arc", "ARC-Challenge", split=args.split, cache_dir=args.cache_dir) 59 | # problems = [{"id": idx , **item} for idx, item in enumerate(problems)] 60 | # random.seed(42) 61 | problems = problems.select(range(args.start, min(len(problems),args.end))) 62 | 63 | if not os.path.exists(os.path.dirname(args.save)): 64 | os.makedirs(os.path.dirname(args.save), exist_ok=True) 65 | 66 | 67 | llm = LLMRayWrapper(model_path=args.model, tensor_parallel_size=args.tensor_parallel_size, max_model_len=4096, cuda_ids=args.cuda_ids, swap_space = 20) # Adjust tensor_parallel_size as needed 68 | stop_tokens = ["\nQUESTION:\n", "Question:", "Question", "USER:", "USER", "ASSISTANT:", "ASSISTANT", "Instruction:", "Instruction", "Response:", "Response", "\n\n"] 69 | sampling_params = SamplingParams( 70 | n=args.sample_num, 71 | temperature=args.temperature, 72 | top_k=args.top_k, 73 | max_tokens=args.max_tokens, 74 | stop=stop_tokens 75 | ) 76 | prompt = open(args.prompt_template_dir+"/"+str(args.shot_num)+"-shot-prompt.txt", "r").read() 77 | zero_shot_prompt = open(args.prompt_template_dir+"/0-shot-prompt.txt", "r").read() 78 | # tokenizer = AutoTokenizer.from_pretrained(args.model) 79 | 80 | # Filter too long problems with tokenizer 81 | # problems = problems.filter(lambda x: len(tokenizer(x["question"])["input_ids"]) < 2048) 82 | if not os.path.exists(os.path.dirname(args.save)): 83 | os.makedirs(os.path.dirname(args.save), exist_ok=True) 84 | # ans_file = open(args.save, "w") 85 | 86 | prompt_list = [] 87 | for item in problems: 88 | prompt_list.append(construct_prompt(prompt, 
item)) 89 | outputs = llm.generate(prompts = prompt_list, sampling_params = sampling_params) 90 | correct = 0 91 | gpt_codes = {} 92 | for i, (p, output) in enumerate(zip(problems, outputs)): 93 | question = p["question"] 94 | output_list = { 95 | "id": p["id"], 96 | "question": question, 97 | "input": construct_prompt(zero_shot_prompt, p), 98 | "gt": p["answerKey"] 99 | } 100 | for q in range(args.sample_num): 101 | raw_response = output.outputs[q].text 102 | answer = extract_answer(raw_response) 103 | gt = p["answerKey"] 104 | if answer and answer.lower() == gt.lower(): 105 | correct += 1 106 | score = 1.0 107 | else: 108 | score = 0.0 109 | output_list["output"+str(q)] ={ 110 | "text": raw_response, 111 | "score": score, 112 | "finish_reason": output.outputs[q].finish_reason, 113 | "extracted_answer": answer 114 | } 115 | gpt_codes[p["id"]] = output_list 116 | 117 | # ans_file.write(json.dumps(output_list) + "\n") 118 | # ans_file.flush() 119 | 120 | 121 | 122 | if args.debug: 123 | print("Prompt: ", "-" * 100) 124 | print(output.prompt) 125 | print("Completion: ", "-" * 100) 126 | print(output_list['output0']['text']) 127 | print("Ground Truth: ", gt) 128 | print("Score: ", output_list['output0']['score']) 129 | # ans_file.close() 130 | 131 | with open(args.save, "w") as f: 132 | json.dump(gpt_codes, f, indent=2) 133 | 134 | print(f"Accuracy: {correct / (len(problems) * args.sample_num)}") 135 | 136 | 137 | 138 | 139 | if __name__ == "__main__": 140 | import argparse 141 | 142 | parser = argparse.ArgumentParser(description="Run a trained model to generate ARC Challenge answers.") 143 | parser.add_argument("--model", default="gpt2") 144 | parser.add_argument("--prompt_template_dir", default=None, type=str) 145 | parser.add_argument("--shot_num", default=None, type=int) 146 | parser.add_argument("--input_file", default=None, type=str) 147 | # parser.add_argument("-t","--test_loc", default="~/apps/data_split/test.json", type=str, help="path to the test folder.") 148 | # parser.add_argument("-r","--root", default="../", type=str, help="where the data is stored.") 149 | # parser.add_argument("-l","--load", default="", type=str) 150 | # parser.add_argument("--peeking", default=0.0, type=float) 151 | parser.add_argument("--sample_num", type=int, default=10) # number of samples per prompt 152 | parser.add_argument("--temperature", type=float, default=1.0) # sampling temperature 153 | parser.add_argument("--top_k", type=int, default=20) # top_k 154 | parser.add_argument("--max_tokens", type=int, default=768) # maximum number of generated tokens 155 | # parser.add_argument("--difficulty", default="introductory", type=str) 156 | # parser.add_argument("--num-beams", default=5, type=int) 157 | parser.add_argument("-s","--start", default=0, type=int) 158 | parser.add_argument("-e","--end", default=10000000, type=int) 159 | # parser.add_argument("-i", "--index", default=None, type=int) 160 | parser.add_argument("--cuda_ids", type=str, default="0") 161 | parser.add_argument("-d", "--debug", action="store_true") 162 | parser.add_argument("--split", type=str, default="train", help="What split to use.") 163 | parser.add_argument("--save", type=str, default="./results") 164 | parser.add_argument("--tensor-parallel-size", type=int, default=1, help="Tensor parallel size for vllm") 165 | parser.add_argument("--cache_dir", type=str, default=None, help="Cache directory for datasets") 166 | args = parser.parse_args() 167 | 168 | main(args) 169 | -------------------------------------------------------------------------------- /train_code/vllm_infer_auto.py:
-------------------------------------------------------------------------------- 1 | from vllm import LLM, SamplingParams 2 | from transformers import AutoModelForCausalLM, AutoTokenizer 3 | import json 4 | import random 5 | import re 6 | import argparse 7 | # import wandb 8 | # from huggingface_hub import HfApi 9 | # from huggingface_hub import HfApi, upload_file 10 | 11 | # def upload_to_huggingface(output_file, repo_id, token, commit_message="Upload output file"): 12 | # # Initialize HfApi with token 13 | # api = HfApi() 14 | 15 | # # Upload the file 16 | # api.upload_file( 17 | # path_or_fileobj=output_file, 18 | # path_in_repo=output_file, # You can change the path in the repo if needed 19 | # repo_id=repo_id, 20 | # token=token, 21 | # commit_message=commit_message 22 | # ) 23 | 24 | def upload_to_huggingface(output_file, repo_id, token, repo_type='dataset', commit_message="Upload output file"): 25 | # Upload the file to the specified repository 26 | upload_file( 27 | path_or_fileobj=output_file, 28 | path_in_repo=output_file, # You can change the path in the repo if needed 29 | repo_id=repo_id, 30 | token=token, 31 | repo_type=repo_type, 32 | commit_message=commit_message 33 | ) 34 | print(f"File {output_file} uploaded successfully to {repo_id}.") 35 | 36 | def sample_resp(args): 37 | # wandb.init(project=args.wandb_project, config={ 38 | # "input_data": args.input_data, 39 | # "model_dir": args.model_dir, 40 | # "sample_num": args.sample_num, 41 | # "temperature": args.temperature, 42 | # "top_k": args.top_k, 43 | # "max_tokens": args.max_tokens, 44 | # "output_file": args.output_file, 45 | # "tensor_parallel_size": args.tensor_parallel_size, 46 | # }) 47 | 48 | with open(args.input_data, "r") as r: 49 | data_json = json.load(r) 50 | sample_data = data_json 51 | 52 | data_dict = {} 53 | 54 | for item in sample_data: 55 | data_dict[item["input"]] = item 56 | 57 | prompt_template = ''' 58 | 59 | "Below is an instruction that describes a task. " "Write a response that appropriately completes the request.\n\n" "### Instruction:\n{instruction}\n\n### Response: Let's think step by step." 
60 | 61 | ''' 62 | llm = LLM(model=args.model_dir, tensor_parallel_size=args.tensor_parallel_size) 63 | 64 | sampling_params = SamplingParams(max_tokens=args.max_tokens, temperature=args.temperature, top_k=args.top_k, n=args.sample_num) 65 | all_input = [] 66 | 67 | for item in sample_data: 68 | all_input.append(item["input"]) 69 | outputs = llm.generate(all_input, sampling_params) 70 | 71 | all_output = [] 72 | all_prompt = [] 73 | 74 | all_json = [] 75 | 76 | for output in outputs: 77 | 78 | temp_json = data_dict[output.prompt] 79 | all_prompt.append(output.prompt) 80 | 81 | temp_json["prompt"] = output.prompt 82 | 83 | for i in range(args.sample_num): 84 | temp_json["output"+str(i)] = output.outputs[i].text 85 | 86 | all_json.append(temp_json) 87 | 88 | with open(args.output_file, "w") as w: 89 | json.dump(all_json, w) 90 | 91 | # Optionally upload the output file to Hugging Face 92 | #upload_to_huggingface(args.output_file, args.repo_id, args.hf_token) 93 | 94 | 95 | 96 | def parse_args(): 97 | parser = argparse.ArgumentParser() 98 | parser.add_argument("--input_data", type=str, required=True) # path to the input data file 99 | parser.add_argument("--model_dir", type=str, required=True) # model directory 100 | parser.add_argument("--sample_num", type=int, default=10) # number of samples per prompt 101 | parser.add_argument("--temperature", type=float, default=1.0) # sampling temperature 102 | parser.add_argument("--top_k", type=int, default=20) # top_k 103 | parser.add_argument("--max_tokens", type=int, default=768) # maximum number of generated tokens 104 | parser.add_argument("--output_file", type=str, required=True) # output file path 105 | parser.add_argument("--tensor_parallel_size", type=int, default=1) # tensor parallel size 106 | #parser.add_argument("--repo_id", type=str, required=True, help="Hugging Face repository ID") # Hugging Face repository ID 107 | #parser.add_argument("--hf_token", type=str, required=True, help="Hugging Face access token") # Hugging Face access token 108 | #parser.add_argument("--wandb_project", type=str, default="my_project") # wandb project name 109 | return parser.parse_args() 110 | 111 | if __name__ == "__main__": 112 | args = parse_args() 113 | sample_resp(args=args) 114 | --------------------------------------------------------------------------------
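Usage note: the exploration (sampling) step implemented in vllm_infer_auto.py is driven entirely by the argparse flags shown above. The invocation below is only a minimal sketch based on those flags; the input file, checkpoint directory, and output path are placeholders, not files shipped with the repository.

```shell
# Minimal sketch: sample 10 responses per query from a converted checkpoint.
# All paths are placeholders; replace them with your own locations.
python vllm_infer_auto.py \
    --input_data /path/to/query_data.json \
    --model_dir /path/to/converted/hf/checkpoint \
    --sample_num 10 \
    --temperature 1.0 \
    --top_k 20 \
    --max_tokens 768 \
    --tensor_parallel_size 1 \
    --output_file /path/to/sampled_responses.json
```

Each record in the resulting JSON keeps the original query fields and gains a `prompt` field plus `output0` through `output{n-1}` completions, which the subsequent reward-scoring and data-selection steps can then consume.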