├── .gitignore
├── README.md
├── data
│   ├── test
│   │   ├── GSM8K_test.jsonl
│   │   └── MATH_test.jsonl
│   └── train
│       ├── MetaMath-40K_split1.json
│       ├── MetaMath-40K_split2.json
│       └── README.md
├── deepspeed_config.json
├── env_setup.sh
├── eval_gsm8k.py
├── eval_math.py
├── eval_script
│   ├── codellama_eval.sh
│   ├── llama2_eval.sh
│   └── llemma_eval.sh
├── imgs
│   └── metamath.svg
├── train_math.py
├── train_script
│   ├── train_codellama_full.sh
│   ├── train_llama2_full.sh
│   └── train_llemma_7b_full.sh
└── util.py

/.gitignore:
--------------------------------------------------------------------------------
*.out
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Llemma: MetaMathQA Finetunes

Code for finetuning the Code Llama 7B and Llemma 7B models on the MetaMathQA dataset.

Instructions for replicating the finetuning experiments in Azerbayev et al. (2023) are below.

### Replication Instructions
First, modify `env_setup.sh` to declare the `BASE_DIR` and `TRAIN_FILE` environment variables correctly (a sketch appears below). Then, from the base directory of this repository, run
```
./train_script/train_llama2_full.sh
./train_script/train_codellama_full.sh
./train_script/train_llemma_7b_full.sh
```
Note that the `train_llama2_full.sh` script is designed to replicate the experiments in Yu et al. (2023). The scripts are designed for an 8xA100 80GB configuration: modify them for your hardware as appropriate.
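For reference, here is a minimal sketch of those two declarations (the path and filename below are illustrative placeholders, not files shipped with this repository):
```
BASE_DIR=/home/user/llemma-metamath                      # absolute path to this repository
TRAIN_FILE=${BASE_DIR}/data/train/MetaMathQA-395K.json   # MetaMathQA-395K dataset in json form
```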
Once the models have finished finetuning, run
```
./eval_script/llama2_eval.sh
./eval_script/codellama_eval.sh
./eval_script/llemma_eval.sh
```
from the base directory of this repository to replicate the evaluation results.

### Citation

```
# Add Llemma citation

@misc{yu2023metamath,
      title={MetaMath: Bootstrap Your Own Mathematical Questions for Large Language Models},
      author={Longhui Yu and Weisen Jiang and Han Shi and Jincheng Yu and Zhengying Liu and Yu Zhang and James T. Kwok and Zhenguo Li and Adrian Weller and Weiyang Liu},
      year={2023},
      eprint={2309.12284},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
```
--------------------------------------------------------------------------------
/data/train/README.md:
--------------------------------------------------------------------------------
# MetaMathQA
--------------------------------------------------------------------------------
/deepspeed_config.json:
--------------------------------------------------------------------------------
{
  "precision": "bfloat16",
  "bf16": {
    "enabled": true
  },
  "fp16": {
    "enabled": false
  },
  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": "auto",
      "betas": "auto",
      "eps": "auto",
      "weight_decay": "auto"
    }
  },
  "scheduler": {
    "type": "WarmupLR",
    "params": {
      "warmup_min_lr": "auto",
      "warmup_max_lr": "auto",
      "warmup_num_steps": "auto"
    }
  },
  "zero_optimization": {
    "stage": 3,
    "overlap_comm": true,
    "contiguous_gradients": true,
    "sub_group_size": 1e9,
    "reduce_bucket_size": "auto",
    "stage3_prefetch_bucket_size": "auto",
    "stage3_param_persistence_threshold": "auto",
    "stage3_max_live_parameters": 1e9,
    "stage3_max_reuse_distance": 1e9,
    "stage3_gather_fp16_weights_on_model_save": true
  },
  "gradient_accumulation_steps": "auto",
  "gradient_clipping": "auto",
  "steps_per_print": 2000,
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "wall_clock_breakdown": false
}
--------------------------------------------------------------------------------
/env_setup.sh:
--------------------------------------------------------------------------------
BASE_DIR=/replace/with/repo/base/dir  # Replace with the path to this repository
TRAIN_FILE=/path/to/MetaMathQA/json   # Path to the MetaMathQA-395K dataset in json form

# System-specific configuration: environment, CUDA, etc.

source /home/hailey81/miniconda3/bin/activate metainstruct

which python

export LD_LIBRARY_PATH=/home/hailey81/miniconda3/envs/metainstruct/lib/
export PATH=/home/hailey81/cuda_install/bin:$PATH

ln -s /home/hailey81/miniconda3/envs/metainstruct-updated/bin/gcc/ ~/.local/bin/gcc
export PATH=$HOME/.local/bin:$PATH
--------------------------------------------------------------------------------
/eval_gsm8k.py:
--------------------------------------------------------------------------------
import argparse
import json
import re
import jsonlines
from fraction import Fraction
from vllm import LLM, SamplingParams
import sys

MAX_INT = sys.maxsize

def is_number(s):
    try:
        float(s)
        return True
    except ValueError:
        pass
    try:
        import unicodedata
        unicodedata.numeric(s)
        return True
    except (TypeError, ValueError):
        pass
    return False

def extract_answer_number(completion):
    text = completion.split('The answer is: ')
    if len(text) > 1:
        extract_ans = text[-1].strip()
        match = re.search(r'[\-+]?\d*[\.,/]?\d+', extract_ans)
        if match:
            if '/' in match.group():
                denominator = match.group().split('/')[1]
                numerator = match.group().split('/')[0]
                if is_number(denominator) and is_number(numerator):
                    if denominator == '0':
                        # Degenerate fraction: fall back to the numerator alone.
                        return round(float(numerator.replace(',', '')))
                    else:
                        frac = Fraction(match.group().replace(',', ''))
                        num_numerator = frac.numerator
                        num_denominator = frac.denominator
                        return round(float(num_numerator / num_denominator))
                else:
                    return None
            else:
                if float(match.group().replace(',', '')) == float('inf'):
                    return None
                return round(float(match.group().replace(',', '')))
        else:
            return None
    else:
        return None

def batch_data(data_list, batch_size=1):
    n = len(data_list) // batch_size
    batch_data = []
    for i in range(n-1):
        start = i * batch_size
        end = (i+1) * batch_size
        batch_data.append(data_list[start:end])

    # The final batch absorbs the remainder, so the list length need not be a
    # multiple of batch_size.
    last_start = (n-1) * batch_size
    last_end = MAX_INT
    batch_data.append(data_list[last_start:last_end])
    return batch_data
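
# Illustrative behaviour of the helpers above (made-up completions, not real
# model outputs):
#
#   extract_answer_number("... The answer is: 72")     -> 72
#   extract_answer_number("... The answer is: 1,000")  -> 1000
#   extract_answer_number("... The answer is: 3/4")    -> 1    (fractions are rounded to the nearest integer)
#   extract_answer_number("no answer marker")          -> None
#
# batch_data() chunks the prompt list before handing it to vLLM, so generation
# runs over lists of prompts rather than one prompt at a time.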

def gsm8k_test(model, data_path, start=0, end=MAX_INT, batch_size=1, tensor_parallel_size=1):
    INVALID_ANS = "[invalid]"
    gsm8k_ins = []
    gsm8k_answers = []
    problem_prompt = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response: Let's think step by step."
    )
    print('prompt =====', problem_prompt)
    with open(data_path, "r+", encoding="utf8") as f:
        for idx, item in enumerate(jsonlines.Reader(f)):
            temp_instr = problem_prompt.format(instruction=item["query"])
            gsm8k_ins.append(temp_instr)
            temp_ans = item['response'].split('#### ')[1]
            temp_ans = int(temp_ans.replace(',', ''))
            gsm8k_answers.append(temp_ans)

    gsm8k_ins = gsm8k_ins[start:end]
    gsm8k_answers = gsm8k_answers[start:end]
    print('length ====', len(gsm8k_ins))
    batch_gsm8k_ins = batch_data(gsm8k_ins, batch_size=batch_size)

    stop_tokens = ["Question:", "Question", "USER:", "USER", "ASSISTANT:", "ASSISTANT", "Instruction:", "Instruction", "Response:", "Response"]
    sampling_params = SamplingParams(temperature=0.0, top_p=1, max_tokens=512, stop=stop_tokens)
    print('sampling =====', sampling_params)
    llm = LLM(model=model, tensor_parallel_size=tensor_parallel_size)
    result = []
    res_completions = []
    for idx, (prompt, prompt_answer) in enumerate(zip(batch_gsm8k_ins, gsm8k_answers)):
        if not isinstance(prompt, list):
            prompt = [prompt]

        completions = llm.generate(prompt, sampling_params)
        for output in completions:
            prompt = output.prompt
            generated_text = output.outputs[0].text
            res_completions.append(generated_text)

    invalid_outputs = []
    for idx, (prompt, completion, prompt_answer) in enumerate(zip(gsm8k_ins, res_completions, gsm8k_answers)):
        doc = {'question': prompt}
        y_pred = extract_answer_number(completion)
        if y_pred is not None:
            result.append(float(y_pred) == float(prompt_answer))
        else:
            result.append(False)
            temp = {'question': prompt, 'output': completion, 'answer': prompt_answer}
            invalid_outputs.append(temp)
    acc = sum(result) / len(result)
    print('len invalid outputs ====', len(invalid_outputs), ', invalid_outputs ===', invalid_outputs)
    print('start ===', start, ', end ====', end)
    print('gsm8k length ====', len(result), ', gsm8k acc ====', acc)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str)  # model path
    parser.add_argument("--data_file", type=str, default='/home/wliu/longhui/llms-all/gsm8k-inverse/data/test_use.jsonl')  # data path
    parser.add_argument("--start", type=int, default=0)  # start index
    parser.add_argument("--end", type=int, default=MAX_INT)  # end index
    parser.add_argument("--batch_size", type=int, default=400)  # batch size
    parser.add_argument("--tensor_parallel_size", type=int, default=8)  # tensor parallel size
    return parser.parse_args()

# module load cuda/11.7
# export TORCH_EXTENSIONS_DIR=./tmp
# export PATH=/home/wliu/anaconda3/envs/llama_adapter/bin:$PATH
# python eval_wizard_gsm8k.py --model /lustre/fast/fast/wliu/longhui/inverse_ckpt/lora_7b
# python eval_math.py --model /lustre/fast/fast/wliu/longhui/inverse_ckpt/lora_7b
# python eval_wizard_gsm8k.py --model /lustre/fast/fast/wliu/longhui/inverse_ckpt/lora_7b_cosine
# python eval_wizard_gsm8k.py --model /lustre/fast/fast/wliu/longhui/inverse_ckpt/llama-70b-merged-qlora

if __name__ == "__main__":
    args = parse_args()
    gsm8k_test(model=args.model, data_path=args.data_file, start=args.start, end=args.end, batch_size=args.batch_size, tensor_parallel_size=args.tensor_parallel_size)

# MODEL_DIR='/lustre/fast/fast/wliu/longhui/inverse_ckpt/MATH_llama-7b-re'
# MODEL_DIR='/lustre/fast/fast/wliu/longhui/inverse_ckpt/MATH_llama-7b-back'
# MODEL_DIR='/lustre/fast/fast/wliu/longhui/inverse_ckpt/MATH_llama-7b-merge'
# MODEL_DIR='/lustre/fast/fast/wliu/longhui/inverse_ckpt/MATH_llama-7b-gsm_240k'
# MODEL_DIR='/lustre/fast/fast/wliu/longhui/inverse_ckpt/MATH_llama-7b-gsm_merge_353k'
# MODEL_DIR='/lustre/fast/fast/wliu/longhui/inverse_ckpt/MATH_llama-7b-gsm_merge_no_special_353k'
# MODEL_DIR='/lustre/fast/fast/wliu/longhui/inverse_ckpt/MATH_llama-7b-gsm_no_special_240k'

# python eval_wizard_gsm8k.py --model /lustre/fast/fast/wliu/longhui/inverse_ckpt/MATH_llama-7b-gsm_no_special_240k
# python eval_wizard_gsm8k.py --model /lustre/fast/fast/wliu/longhui/inverse_ckpt/MATH_llama-7b-gsm_merge_no_special_353k
# python eval_wizard_gsm8k.py --model /lustre/fast/fast/wliu/longhui/inverse_ckpt/MATH_llama-7b-gsm_merge_353k
# python eval_wizard_gsm8k.py --model /lustre/fast/fast/wliu/longhui/inverse_ckpt/MATH_llama-7b-gsm_240k

# python eval_wizard_gsm8k.py --model /lustre/fast/fast/wliu/longhui/inverse_ckpt/MATH_llama-7b-for
# python eval_wizard_gsm8k.py --model /lustre/fast/fast/wliu/longhui/inverse_ckpt/MATH_llama-7b-re
# python eval_wizard_gsm8k.py --model /lustre/fast/fast/wliu/longhui/inverse_ckpt/MATH_llama-7b-back
# python eval_wizard_gsm8k.py --model /lustre/fast/fast/wliu/longhui/inverse_ckpt/MATH_llama-7b-merge

# python eval_wizard_gsm8k.py --model /lustre/fast/fast/wliu/longhui/inverse_ckpt/MATH_gsm_llama-7b-398k --tensor_parallel_size 8 --batch_size 400 --data_file /home/wliu/longhui/llms-all/gsm8k-inverse/data/test_use.jsonl
--------------------------------------------------------------------------------
/eval_math.py:
--------------------------------------------------------------------------------
import argparse
import json
import pdb
import jsonlines

import util
from vllm import LLM, SamplingParams
import sys

MAX_INT = sys.maxsize
INVALID_ANS = "[invalid]"

invalid_outputs = []

def remove_boxed(s):
    left = "\\boxed{"
    try:
        assert s[:len(left)] == left
        assert s[-1] == "}"
        return s[len(left):-1]
    except Exception:
        return None

def process_results(doc, completion, answer):
    split_ans = completion.split('The answer is: ')
    if len(split_ans) > 1:
        ans = split_ans[-1]
        extract_ans_temp = ans.split('.\n')[0]
        extract_ans_temp = extract_ans_temp.strip()
        if len(extract_ans_temp) > 0 and extract_ans_temp[-1] == '.':
            extract_ans = extract_ans_temp[0:-1]
        else:
            extract_ans = extract_ans_temp
        extract_ans = extract_ans.strip()
        if util.is_equiv(extract_ans, answer):
            return True
        else:
            return False
    else:
        temp = {'question': doc, 'output': completion, 'answer': answer}
        invalid_outputs.append(temp)
        return False

def batch_data(data_list, batch_size=1):
    n = len(data_list) // batch_size
    batch_data = []
    for i in range(n-1):
        start = i * batch_size
        end = (i+1) * batch_size
        batch_data.append(data_list[start:end])

    # The final batch absorbs the remainder.
    last_start = (n-1) * batch_size
    last_end = MAX_INT
    batch_data.append(data_list[last_start:last_end])
    return batch_data
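
# Illustrative behaviour of the helpers above (hypothetical strings, not
# dataset records):
#
#   remove_boxed("\\boxed{\\frac{1}{2}}")  -> "\\frac{1}{2}"
#   remove_boxed("\\frac{1}{2}")           -> None   (no \boxed{...} wrapper)
#
# process_results() then compares the text following "The answer is: " against
# the reference answer via util.is_equiv, which canonicalizes the LaTeX first.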

def test_hendrycks_math(model, data_path, start=0, end=MAX_INT, batch_size=1, tensor_parallel_size=1):
    hendrycks_math_ins = []
    hendrycks_math_answers = []
    problem_prompt = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response: Let's think step by step."
    )
    print('prompt =====', problem_prompt)
    with open(data_path, "r+", encoding="utf8") as f:
        for idx, item in enumerate(jsonlines.Reader(f)):
            temp_instr = problem_prompt.format(instruction=item["instruction"])
            hendrycks_math_ins.append(temp_instr)
            solution = item['output']
            temp_ans = remove_boxed(util.last_boxed_only_string(solution))
            hendrycks_math_answers.append(temp_ans)

    print('total length ===', len(hendrycks_math_ins))
    hendrycks_math_ins = hendrycks_math_ins[start:end]
    hendrycks_math_answers = hendrycks_math_answers[start:end]
    print('length ====', len(hendrycks_math_ins))
    batch_hendrycks_math_ins = batch_data(hendrycks_math_ins, batch_size=batch_size)

    stop_tokens = ["Question:", "Question", "USER:", "USER", "ASSISTANT:", "ASSISTANT", "Instruction:", "Instruction", "Response:", "Response"]
    sampling_params = SamplingParams(temperature=0, top_p=1, max_tokens=2048, stop=stop_tokens)
    print('sampling =====', sampling_params)
    llm = LLM(model=model, tensor_parallel_size=tensor_parallel_size)
    res_completions = []
    for idx, (prompt, prompt_answer) in enumerate(zip(batch_hendrycks_math_ins, hendrycks_math_answers)):
        if not isinstance(prompt, list):
            prompt = [prompt]
        completions = llm.generate(prompt, sampling_params)
        for output in completions:
            prompt_temp = output.prompt
            generated_text = output.outputs[0].text
            res_completions.append(generated_text)

    results = []
    for idx, (prompt, completion, prompt_answer) in enumerate(zip(hendrycks_math_ins, res_completions, hendrycks_math_answers)):
        res = process_results(prompt, completion, prompt_answer)
        results.append(res)

    acc = sum(results) / len(results)
    print('len invalid outputs ====', len(invalid_outputs), ', invalid_outputs ===', invalid_outputs)
    print('start ===', start, ', end ====', end)
    print('length ====', len(results), ', acc ====', acc)

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default='/lustre/fast/fast/wliu/longhui/inverse_ckpt/llama-2-7b-sft/')  # model path
    parser.add_argument("--data_file", type=str, default='/home/wliu/longhui/llms-all/gsm8k-inverse/data/MATH/MATH_test.jsonl')  # data path
    parser.add_argument("--start", type=int, default=0)  # start index
    parser.add_argument("--end", type=int, default=MAX_INT)  # end index
    parser.add_argument("--batch_size", type=int, default=400)  # batch size
    parser.add_argument("--tensor_parallel_size", type=int, default=8)  # tensor parallel size
    return parser.parse_args()

if __name__ == "__main__":
    args = parse_args()
    test_hendrycks_math(model=args.model, data_path=args.data_file, start=args.start, end=args.end, batch_size=args.batch_size, tensor_parallel_size=args.tensor_parallel_size)

# python eval_math.py --tensor_parallel_size 8 --batch_size 384 --model
--------------------------------------------------------------------------------
/eval_script/codellama_eval.sh:
--------------------------------------------------------------------------------
#!/bin/bash

source env_setup.sh
cd ${BASE_DIR}

MODEL=model/codellama_metainstruct_full/checkpoint-9258

python eval_gsm8k.py --model $MODEL --data_file data/test/GSM8K_test.jsonl --tensor_parallel_size 1
python eval_math.py --model $MODEL --data_file data/test/MATH_test.jsonl --tensor_parallel_size 1
--------------------------------------------------------------------------------
/eval_script/llama2_eval.sh:
--------------------------------------------------------------------------------
#!/bin/bash

source env_setup.sh
cd ${BASE_DIR}

MODEL=model/llama2_metainstruct_full/checkpoint-9258

python eval_gsm8k.py --model ${MODEL} --data_file data/test/GSM8K_test.jsonl --tensor_parallel_size 1
python eval_math.py --model ${MODEL} --data_file data/test/MATH_test.jsonl --tensor_parallel_size 1
--------------------------------------------------------------------------------
/eval_script/llemma_eval.sh:
--------------------------------------------------------------------------------
#!/bin/bash

source env_setup.sh
cd ${BASE_DIR}

MODEL=model/llemma_7b_metainstruct_full/checkpoint-9258  # Replace with path to the finetuned Llemma model

python eval_gsm8k.py --model $MODEL --data_file data/test/GSM8K_test.jsonl --tensor_parallel_size 1
python eval_math.py --model $MODEL --data_file data/test/MATH_test.jsonl --tensor_parallel_size 1
--------------------------------------------------------------------------------
/train_math.py:
--------------------------------------------------------------------------------
#    Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.
#    Modified by Zheng Yuan and Hongyi Yuan

import os
import copy
import logging
from dataclasses import dataclass, field
from typing import Optional, Dict, Sequence
import io
import torch
import transformers
from torch.utils.data import Dataset
from transformers import Trainer
import argparse
import json
import random

random.seed(42)

def _make_r_io_base(f, mode: str):
    if not isinstance(f, io.IOBase):
        f = open(f, mode=mode)
    return f

def jload(f, mode="r"):
    """Load a .json file into a dictionary."""
    f = _make_r_io_base(f, mode)
    jdict = json.load(f)
    f.close()
    return jdict

IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"
PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
" 55 | "Write a response that appropriately completes the request.\n\n" 56 | "### Instruction:\n{instruction}\n\n### Response:" 57 | ), 58 | } 59 | #### 28 60 | @dataclass 61 | class ModelArguments: 62 | model_name_or_path: Optional[str] = field(default="facebook/opt-125m") 63 | 64 | 65 | @dataclass 66 | class DataArguments: 67 | data_path: str = field(default=None, metadata={"help": "Path to the training data."}) 68 | 69 | 70 | @dataclass 71 | class TrainingArguments(transformers.TrainingArguments): 72 | cache_dir: Optional[str] = field(default=None) 73 | optim: str = field(default="adamw_torch") 74 | model_max_length: int = field( 75 | default=512, 76 | metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."}, 77 | ) 78 | overwrite_output_dir: bool = field(default=True) 79 | 80 | 81 | def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str): 82 | """Collects the state dict and dump to disk.""" 83 | state_dict = trainer.model.state_dict() 84 | if trainer.args.should_save: 85 | cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()} 86 | del state_dict 87 | trainer._save(output_dir, state_dict=cpu_state_dict) # noqa 88 | 89 | 90 | def smart_tokenizer_and_embedding_resize( 91 | special_tokens_dict: Dict, 92 | tokenizer: transformers.PreTrainedTokenizer, 93 | model: transformers.PreTrainedModel, 94 | ): 95 | """Resize tokenizer and embedding. 96 | 97 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64. 98 | """ 99 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) 100 | model.resize_token_embeddings(len(tokenizer)) 101 | 102 | if num_new_tokens > 0: 103 | input_embeddings = model.get_input_embeddings().weight.data 104 | output_embeddings = model.get_output_embeddings().weight.data 105 | 106 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 107 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 108 | 109 | input_embeddings[-num_new_tokens:] = input_embeddings_avg 110 | output_embeddings[-num_new_tokens:] = output_embeddings_avg 111 | 112 | 113 | def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict: 114 | """Tokenize a list of strings.""" 115 | tokenized_list = [ 116 | tokenizer( 117 | text, 118 | return_tensors="pt", 119 | padding="longest", 120 | max_length=tokenizer.model_max_length, 121 | truncation=True, 122 | ) 123 | for text in strings 124 | ] 125 | input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list] 126 | input_ids_lens = labels_lens = [ 127 | tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list 128 | ] 129 | return dict( 130 | input_ids=input_ids, 131 | labels=labels, 132 | input_ids_lens=input_ids_lens, 133 | labels_lens=labels_lens, 134 | ) 135 | 136 | 137 | def preprocess( 138 | sources: Sequence[str], 139 | targets: Sequence[str], 140 | tokenizer: transformers.PreTrainedTokenizer, 141 | ) -> Dict: 142 | """Preprocess the data by tokenizing.""" 143 | examples = [s + t for s, t in zip(sources, targets)] 144 | examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)] 145 | input_ids = examples_tokenized["input_ids"] 146 | labels = copy.deepcopy(input_ids) 147 | for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]): 148 | label[:source_len] = IGNORE_INDEX 149 | return 

class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self, data_args, tokenizer: transformers.PreTrainedTokenizer):
        super(SupervisedDataset, self).__init__()
        logging.warning("Loading data...")
        data_path = data_args.data_path
        try:
            # Optional alias table; falls through to the raw path when no
            # `data_path_map` is defined in this module.
            data_path = data_path_map[data_path]
        except Exception:
            data_path = data_path
        try:
            list_data_dict = jload(data_path)
        except BaseException:
            # Fall back to JSON-lines files.
            with open(data_path, 'r') as f:
                lines = f.readlines()
            list_data_dict = [json.loads(line.strip()) for line in lines]

        list_data_dict = random.sample(list_data_dict, len(list_data_dict))
        list_data_dict = list_data_dict[:data_args.data_length]

        # logging.warning("Formatting inputs...")
        prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
        # print(list_data_dict[0])
        if 'instruction' not in list_data_dict[0]:
            # MetaMathQA-style records use query/response keys; convert them
            # to the instruction/input/output schema expected below.
            def get_input(query):
                if query.find('\n') == -1:
                    return ''
                return '\n'.join(query.split('\n')[1:])
            list_data_dict = [{'instruction': data['query'].split('\n')[0], 'input': get_input(data['query']), 'output': data['response']} for data in list_data_dict]
        # import ipdb; ipdb.set_trace()
        sources = [
            prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
            for example in list_data_dict
        ]
        targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]

        self.sources = sources
        self.targets = targets

    def __len__(self):
        return len(self.sources)

    def naive__getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])

    def __getitem__(self, i):
        return dict(input_ids=self.sources[i], labels=self.targets[i])

@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def naive__call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        sources = []
        targets = []
        for instance in instances:
            source = instance['input_ids']
            target = instance['labels']
            sources.append(source)
            targets.append(target)

        # Tokenization is deferred to collation time: each batch of raw
        # (source, target) strings is tokenized and label-masked here.
        data_dict = preprocess(sources, targets, self.tokenizer)
        input_ids, labels = data_dict['input_ids'], data_dict['labels']
        # input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )

def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    train_dataset = SupervisedDataset(tokenizer=tokenizer, data_args=data_args)
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)


def train():
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args, remaining_args = parser.parse_args_into_dataclasses(return_remaining_strings=True)
    # --data_length is not a dataclass field, so it arrives via the remaining
    # argument strings as ["--data_length", "<value>"].
    data_args.data_length = int(remaining_args[1])

    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
    )

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        "hf-internal-testing/llama-tokenizer",
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    if tokenizer.pad_token is None:
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
            tokenizer=tokenizer,
            model=model,
        )
    if "llama" in model_args.model_name_or_path:
        tokenizer.add_special_tokens(
            {
                "eos_token": DEFAULT_EOS_TOKEN,
                "bos_token": DEFAULT_BOS_TOKEN,
                "unk_token": DEFAULT_UNK_TOKEN,
            }
        )

    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
    trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
    trainer.train()
    trainer.save_state()
    # if os.environ.get('LOCAL_RANK') == '0':
    safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)


if __name__ == "__main__":
    train()
--------------------------------------------------------------------------------
/train_script/train_codellama_full.sh:
--------------------------------------------------------------------------------
#!/bin/bash

source env_setup.sh
cd ${BASE_DIR}

MODEL=codellama/CodeLlama-7b-hf
CONFIG=${BASE_DIR}/deepspeed_config.json

OUTDIR=${BASE_DIR}/model/codellama_metainstruct_full

NUM_STEPS=9258

deepspeed --include localhost:0,1,2,3,4,5,6,7 ${BASE_DIR}/train_math.py \
    --deepspeed ${CONFIG} \
    --model_name_or_path ${MODEL} \
    --data_path ${TRAIN_FILE} \
    --data_length 395000 \
    --bf16 \
    --output_dir ${OUTDIR} \
    --max_steps ${NUM_STEPS} \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 4 \
    --save_strategy "steps" \
    --save_steps ${NUM_STEPS} \
    --save_total_limit 1 \
    --learning_rate 2e-5 \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 10 \
    --logging_dir "$OUTDIR" \
    --report_to="tensorboard"
--------------------------------------------------------------------------------
/train_script/train_llama2_full.sh:
--------------------------------------------------------------------------------
#!/bin/bash

source env_setup.sh
cd ${BASE_DIR}

MODEL=meta-llama/Llama-2-7b-hf
CONFIG=${BASE_DIR}/deepspeed_config.json

OUTDIR=${BASE_DIR}/model/llama2_metainstruct_full

NUM_STEPS=9258
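
# Effective batch size: 8 GPUs x 4 sequences per device x 4 gradient-accumulation
# steps = 128 sequences per optimizer step. With --data_length 395000, one epoch
# is ~3086 optimizer steps, so NUM_STEPS=9258 corresponds to roughly 3 epochs.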

deepspeed --include localhost:0,1,2,3,4,5,6,7 ${BASE_DIR}/train_math.py \
    --deepspeed ${CONFIG} \
    --model_name_or_path ${MODEL} \
    --data_path ${TRAIN_FILE} \
    --data_length 395000 \
    --bf16 \
    --output_dir ${OUTDIR} \
    --max_steps ${NUM_STEPS} \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 4 \
    --save_strategy "steps" \
    --save_steps ${NUM_STEPS} \
    --save_total_limit 1 \
    --learning_rate 2e-5 \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 10 \
    --logging_dir "$OUTDIR" \
    --report_to="tensorboard"
--------------------------------------------------------------------------------
/train_script/train_llemma_7b_full.sh:
--------------------------------------------------------------------------------
#!/bin/bash

source env_setup.sh
cd ${BASE_DIR}

MODEL=open-web-math/codellama_7b_200btok_step42000
CONFIG=${BASE_DIR}/deepspeed_config.json

OUTDIR=${BASE_DIR}/model/llemma_7b_metainstruct_full

NUM_STEPS=9258

deepspeed --include localhost:0,1,2,3,4,5,6,7 ${BASE_DIR}/train_math.py \
    --deepspeed ${CONFIG} \
    --model_name_or_path ${MODEL} \
    --data_path ${TRAIN_FILE} \
    --data_length 395000 \
    --bf16 \
    --output_dir ${OUTDIR} \
    --max_steps ${NUM_STEPS} \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 4 \
    --save_strategy "steps" \
    --save_steps ${NUM_STEPS} \
    --save_total_limit 1 \
    --learning_rate 2e-5 \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 10 \
    --logging_dir "$OUTDIR" \
    --report_to="tensorboard"
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
import pprint

def last_boxed_only(sample):
    q, a = sample
    a = last_boxed_only_string(a)
    if a is None:
        return None
    return (q, a)

def last_boxed_only_string(string):
    idx = string.rfind("\\boxed")
    if idx < 0:
        idx = string.rfind("\\fbox")
        if idx < 0:
            return None

    i = idx
    right_brace_idx = None
    num_left_braces_open = 0
    while i < len(string):
        if string[i] == "{":
            num_left_braces_open += 1
        if string[i] == "}":
            num_left_braces_open -= 1
            if num_left_braces_open == 0:
                right_brace_idx = i
                break
        i += 1

    if right_brace_idx is None:
        retval = None
    else:
        retval = string[idx:right_brace_idx + 1]

    return retval

def only_until_first_boxed_from_tokens(string, tokens):
    idx = string.find("\\boxed")
    if idx < 0:
        idx = string.find("\\fbox")
        if idx < 0:
            return None

    cum_length = 0
    for i, t in enumerate(tokens):
        cum_length += len(t)
        if cum_length >= idx:
            break

    return tokens[:i]



def clean_numbers(sample):
    if not sample:
        return None
    new_sample = list()
    for s in sample:
        new_sample.append(_clean_numbers(s))

    return tuple(new_sample)

def _clean_numbers(string):
    """
    Clean Numbers in the given string

    >>> _clean_numbers("Hello 123")
    'Hello 123'
    >>> _clean_numbers("Hello 1234")
    'Hello 1,234'
    >>> _clean_numbers("Hello 1234324asdasd")
    'Hello 1,234,324asdasd'
    """
    num_prev_digits = 0
    new_string = ""
    for i, c in enumerate(string):
        # isdigit() doesn't work here because of weird unicode chars.
        if c in {'1', '2', '3', '4', '5', '6', '7', '8', '9', '0'}:
            num_prev_digits += 1
        else:
            if num_prev_digits > 3:
                # Insert thousands separators into the digit run that just ended.
                string_number = new_string[-num_prev_digits:]
                new_string = new_string[:-num_prev_digits] + "{0:,}".format(int(string_number))
            num_prev_digits = 0
        new_string += c

    if num_prev_digits > 3:
        # Handle a digit run that reaches the end of the string.
        string_number = new_string[-num_prev_digits:]
        new_string = new_string[:-num_prev_digits] + "{0:,}".format(int(string_number))

    return new_string

def fix_fracs(string):
    substrs = string.split("\\frac")
    new_str = substrs[0]
    if len(substrs) > 1:
        substrs = substrs[1:]
        for substr in substrs:
            new_str += "\\frac"
            if substr[0] == "{":
                new_str += substr
            else:
                try:
                    assert len(substr) >= 2
                except AssertionError:
                    return string
                a = substr[0]
                b = substr[1]
                if b != "{":
                    if len(substr) > 2:
                        post_substr = substr[2:]
                        new_str += "{" + a + "}{" + b + "}" + post_substr
                    else:
                        new_str += "{" + a + "}{" + b + "}"
                else:
                    if len(substr) > 2:
                        post_substr = substr[2:]
                        new_str += "{" + a + "}" + b + post_substr
                    else:
                        new_str += "{" + a + "}" + b
    string = new_str
    return string

def fix_a_slash_b(string):
    if len(string.split("/")) != 2:
        return string
    a = string.split("/")[0]
    b = string.split("/")[1]
    try:
        a = int(a)
        b = int(b)
        assert string == "{}/{}".format(a, b)
        new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
        return new_string
    except (ValueError, AssertionError):
        # int() raises ValueError for non-integer parts (e.g. "x/y");
        # return the input unchanged in that case too.
        return string

def remove_right_units(string):
    # "\\text{ " only ever occurs (at least in the val set) when describing units
    if "\\text{ " in string:
        splits = string.split("\\text{ ")
        assert len(splits) == 2
        return splits[0]
    else:
        return string

def fix_sqrt(string):
    if "\\sqrt" not in string:
        return string
    splits = string.split("\\sqrt")
    new_string = splits[0]
    for split in splits[1:]:
        if split[0] != "{":
            a = split[0]
            new_substr = "\\sqrt{" + a + "}" + split[1:]
        else:
            new_substr = "\\sqrt" + split
        new_string += new_substr
    return new_string
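
# Illustrative input/output pairs for the fixers above (hypothetical model
# outputs):
#
#   fix_fracs("\\frac12")  -> "\\frac{1}{2}"
#   fix_a_slash_b("3/4")   -> "\\frac{3}{4}"
#   fix_sqrt("\\sqrt3")    -> "\\sqrt{3}"
#
# strip_string() below chains these with other normalizations so that
# is_equiv() compares canonicalized answer strings.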
string.replace("\%", "") # noqa: W605 195 | 196 | # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string 197 | string = string.replace(" .", " 0.") 198 | string = string.replace("{.", "{0.") 199 | # if empty, return empty string 200 | if len(string) == 0: 201 | return string 202 | if string[0] == ".": 203 | string = "0" + string 204 | 205 | # to consider: get rid of e.g. "k = " or "q = " at beginning 206 | if len(string.split("=")) == 2: 207 | if len(string.split("=")[0]) <= 2: 208 | string = string.split("=")[1] 209 | 210 | # fix sqrt3 --> sqrt{3} 211 | string = fix_sqrt(string) 212 | 213 | # remove spaces 214 | string = string.replace(" ", "") 215 | 216 | # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b} 217 | string = fix_fracs(string) 218 | 219 | # manually change 0.5 --> \frac{1}{2} 220 | if string == "0.5": 221 | string = "\\frac{1}{2}" 222 | 223 | # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y 224 | string = fix_a_slash_b(string) 225 | 226 | return string 227 | 228 | 229 | def is_equiv(str1, str2, verbose=False): 230 | if str1 is None and str2 is None: 231 | print("WARNING: Both None") 232 | return True 233 | if str1 is None or str2 is None: 234 | return False 235 | 236 | try: 237 | ss1 = strip_string(str1) 238 | ss2 = strip_string(str2) 239 | #pdb.set_trace() 240 | if verbose: 241 | print(ss1, ss2) 242 | return ss1 == ss2 243 | except Exception: 244 | return str1 == str2 245 | 246 | class NotEqual: 247 | def __eq__(self, other): 248 | return False --------------------------------------------------------------------------------