├── README.md
├── data
│   ├── GSM8K_test.jsonl
│   ├── GSM8K_train.jsonl
│   ├── MATH_test.jsonl
│   └── MATH_train.jsonl
├── dpo
│   ├── alignment
│   │   ├── __init__.py
│   │   ├── configs.py
│   │   ├── data.py
│   │   ├── model_utils.py
│   │   └── release.py
│   ├── custom_trainer.py
│   ├── prepare_dataset.py
│   ├── run_dpo.py
│   └── run_kto.py
├── eval
│   ├── gsm8k
│   │   ├── eval_gsm8k.py
│   │   ├── run_gsm8k_eval.sh
│   │   └── utils_ans.py
│   └── math
│       ├── configs
│       │   ├── few_shot_test_configs.json
│       │   └── zero_shot_test_configs.json
│       ├── data_processing
│       │   ├── __pycache__
│       │   │   ├── answer_extraction.cpython-310.pyc
│       │   │   ├── answer_extraction.cpython-311.pyc
│       │   │   └── process_utils.cpython-310.pyc
│       │   ├── answer_extraction.py
│       │   └── process_utils.py
│       ├── eval
│       │   ├── __pycache__
│       │   │   ├── eval_script.cpython-310.pyc
│       │   │   ├── eval_script.cpython-311.pyc
│       │   │   ├── eval_utils.cpython-310.pyc
│       │   │   ├── eval_utils.cpython-311.pyc
│       │   │   ├── ocwcourses_eval_utils.cpython-310.pyc
│       │   │   ├── ocwcourses_eval_utils.cpython-311.pyc
│       │   │   ├── python_executor.cpython-310.pyc
│       │   │   └── utils.cpython-310.pyc
│       │   ├── eval_script.py
│       │   ├── eval_utils.py
│       │   ├── ocwcourses_eval_utils.py
│       │   ├── python_executor.py
│       │   └── utils.py
│       ├── run_cot_eval.py
│       ├── run_math_eval.sh
│       ├── run_subset_parallel.py
│       └── utils.py
├── gen
│   ├── gen_rft_data.py
│   ├── gen_rft_data.sh
│   ├── gen_step_explore.py
│   ├── gen_step_explore.sh
│   ├── get_dpo_data.py
│   ├── get_rft_data.py
│   ├── math_utils
│   │   ├── answer_extraction.py
│   │   ├── eval_script.py
│   │   ├── eval_utils.py
│   │   └── ocwcourses_eval_utils.py
│   └── utils_others.py
├── images
│   ├── main_result_image.png
│   └── overview_image.png
├── requirements.txt
├── scripts
│   ├── gsm8k
│   │   ├── dpo
│   │   │   ├── config.yaml
│   │   │   ├── deepspeed_zero3.yaml
│   │   │   └── run_dpo.sh
│   │   └── sft
│   │       ├── config.yaml
│   │       ├── run_ft.sh
│   │       └── run_rft.sh
│   └── math
│       ├── dpo
│       │   ├── config.yaml
│       │   ├── deepspeed_zero3.yaml
│       │   └── run_dpo.sh
│       └── sft
│           ├── config.yaml
│           ├── run_ft.sh
│           └── run_rft.sh
└── sft
    ├── sft_datasets.py
    ├── train_generator.py
    └── utils
        ├── cached_models.py
        ├── constants.py
        ├── datasets.py
        ├── flash_attn_monkey_patch.py
        ├── gsm8k
        │   ├── __init__.py
        │   ├── decoding.py
        │   └── metrics.py
        ├── metrics.py
        ├── models.py
        ├── optim.py
        ├── sampling.py
        ├── states.py
        └── verifier_models.py
/README.md:
--------------------------------------------------------------------------------
1 | # Self-Explore
2 | #### Self-Explore to avoid the pit!
Improving the Reasoning Capabilities of Language Models with Fine-grained Rewards 3 | --- 4 | This is the official GitHub repository for **Self-Explore**.

5 | Paper Link: https://arxiv.org/abs/2404.10346 6 | 7 | ## Overview: 8 | ![Overview Image](images/overview_image.png) 9 | 10 | ## Setting 11 | 12 | Run ``pip install -r requirements.txt``
13 | All experiments were carried out using 4 x NVIDIA A100 80GB, with CUDA version 12.0. 14 | 15 | 16 | ## Data 17 | 18 | In the data directory, you will find the train and test files for `GSM8K` and `MATH`. 19 | 20 | ## Training 21 | 22 | > #### Stage 1. Run SFT: 23 | Run **SFT** (or FT, in short) to get the base generator.
24 | In `/scripts/{task}/sft/run_ft.sh` you'll see the script necessary for this. (For data_path, please put the train file.)
25 | Put the necessary paths to the files and models, then simply run `sh scripts/{task}/sft/run_ft.sh` in the main directory. 26 | 27 | > #### Stage 2. Get RFT Data: 28 | Now you'll need to generate *N* instances per problem.
29 | To do this, go to the `gen` directory and run `sh gen_rft_data.sh`.
30 | This assumes you are using 4 GPUs and generates the predictions in parallel across them.
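For intuition, the two resulting files are built from the sampled completions roughly as follows. This is a minimal sketch only: it assumes each problem carries a gold answer, and `sample_n_solutions` / `extract_answer` stand in for the generation and answer-parsing utilities under `gen/`; it is not the actual implementation in `gen_rft_data.py` or `get_rft_data.py`.

```python
def build_rft_and_dpo_rows(problems, sample_n_solutions, extract_answer):
    """Sketch: split the N sampled solutions per problem into correct and
    incorrect ones, yielding RFT rows (SFT on correct rationales) and
    outcome-level DPO rows (chosen vs. rejected)."""
    rft_rows, dpo_rows = [], []
    for prob in problems:                                   # each prob: {"query": ..., "answer": ...}
        solutions = sample_n_solutions(prob["query"])       # N temperature samples from the SFT model
        correct = [s for s in solutions if extract_answer(s) == prob["answer"]]
        wrong = [s for s in solutions if extract_answer(s) != prob["answer"]]
        for sol in dict.fromkeys(correct):                  # de-duplicate correct rationales
            rft_rows.append({"query": prob["query"], "response": sol})
        if correct and wrong:                               # outcome-supervised preference pair
            dpo_rows.append({"prompt": prob["query"],
                             "chosen": correct[0],
                             "rejected": wrong[0]})
    return rft_rows, dpo_rows
```

Whatever the exact construction, the DPO file is what `dpo/prepare_dataset.py` later consumes, so each line carries `prompt`, `chosen`, and `rejected` fields.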
31 | Once completed, you will see **RFT** and **DPO** training file. 32 | 33 | > #### Stage 3. Run RFT: 34 | Run **RFT** to get the RFT model, which acts our explorer and reference model when training for DPO.
35 | in `/scripts/{task}/sft/run_rft.sh` you'll see the script necessary for this.
36 | Put necessary paths to the files and models then simply run `sh /scripts/{task}/sft/run_rft.sh` in the main directory. 37 | 38 | > #### Stage 4. 🔎 Explore : 39 | To find the first ***pit***, let the RFT model explore from each step within rejected sample.
40 | You can do this by running `gen_step_explore.sh` in `gen` directory. (For data_path here, please put the DPO file generated).
41 | Then you will get a file named ending in `gpair_{k}.jsonl`
42 | which is your fine-grained pairwise training data. 43 | 44 | > #### Stage 5. Train with Preference Learning Objective: 45 | You can apply any arbitrary preference learning objective, but in our work, we chose **DPO (Direct Preference Optimization)**.
46 | To do this refer to `scripts/{task}/dpo/run_dpo.sh`. 47 | - To run with the outcome-supervision labels, set the training data as the DPO file generated in Stage 3. 48 | - To run with the step-level fine-grained labels (ours), set the training data as the gpair file generated in Stage 4. 49 | 50 | ## Evaluation 51 | 52 | Under `eval/{task}` directory, you'll find the script needed for running evaluation. 53 | 54 | ## Results 55 | ![Result Image](images/main_result_image.png) 56 | 57 | ## Models 58 | We release our best trained DeepSeek-Math's **GSM8K** and **MATH** trained checkpoints on huggingface. 59 | | Model | Accuracy | Download | 60 | | :----------------------- | :-------------: | :----------------------------------------------------------: | 61 | |**DeepSeek_Math_Self_Explore_GSM8K** | 78.62 | 🤗 [HuggingFace](https://huggingface.co/hbin0701/DeepSeek_MATH_Self_Explore) | 62 | |**DeepSeek_Math_Self_Explore_MATH** | 37.68 | 🤗 [HuggingFace](https://huggingface.co/hbin0701/DeepSeek_GSM8K_Self_Exploret) | 63 | 64 | ## Acknowledgemenets 65 | Our evaluation codes are borrowed from:
66 | - GSM8K: [OVM](https://github.com/FreedomIntelligence/OVM)
67 | - MATH: [DeepSeek-Math](https://github.com/deepseek-ai/DeepSeek-Math) 68 | 69 | ## Citation 70 | ``` 71 | @misc{hwang2024selfexplore, 72 | title={Self-Explore to Avoid the Pit: Improving the Reasoning Capabilities of Language Models with Fine-grained Rewards}, 73 | author={Hyeonbin Hwang and Doyoung Kim and Seungone Kim and Seonghyeon Ye and Minjoon Seo}, 74 | year={2024}, 75 | eprint={2404.10346}, 76 | archivePrefix={arXiv}, 77 | primaryClass={cs.CL} 78 | } 79 | ``` 80 | -------------------------------------------------------------------------------- /dpo/alignment/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.2.0.dev0" 2 | 3 | from .configs import DataArguments, DPOConfig, H4ArgumentParser, ModelArguments, SFTConfig 4 | from .data import apply_chat_template, get_datasets 5 | from .model_utils import get_kbit_device_map, get_peft_config, get_quantization_config, get_tokenizer, is_adapter_model 6 | -------------------------------------------------------------------------------- /dpo/alignment/data.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # coding=utf-8 3 | # Copyright 2023 The HuggingFace Team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | import re 18 | from typing import List, Literal, Optional 19 | 20 | from datasets import DatasetDict, concatenate_datasets, load_dataset 21 | 22 | from .configs import DataArguments 23 | 24 | 25 | DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" 26 | 27 | 28 | def apply_chat_template( 29 | example, tokenizer, task: Literal["sft", "generation", "rm", "dpo"] = "sft", assistant_prefix="<|assistant|>\n" 30 | ): 31 | def _strip_prefix(s, pattern): 32 | # Use re.escape to escape any special characters in the pattern 33 | return re.sub(f"^{re.escape(pattern)}", "", s) 34 | 35 | if task in ["sft", "generation"]: 36 | messages = example["messages"] 37 | # We add an empty system message if there is none 38 | if messages[0]["role"] != "system": 39 | messages.insert(0, {"role": "system", "content": ""}) 40 | example["text"] = tokenizer.apply_chat_template( 41 | messages, tokenize=False, add_generation_prompt=True if task == "generation" else False 42 | ) 43 | elif task == "rm": 44 | if all(k in example.keys() for k in ("chosen", "rejected")): 45 | chosen_messages = example["chosen"] 46 | rejected_messages = example["rejected"] 47 | # We add an empty system message if there is none 48 | if chosen_messages[0]["role"] != "system": 49 | chosen_messages.insert(0, {"role": "system", "content": ""}) 50 | if rejected_messages[0]["role"] != "system": 51 | rejected_messages.insert(0, {"role": "system", "content": ""}) 52 | example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False) 53 | example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False) 54 | else: 55 | raise ValueError( 56 | f"Could not format example as dialogue for `rm` task! Require `[chosen, rejected]` keys but found {list(example.keys())}" 57 | ) 58 | elif task == "dpo": 59 | if all(k in example.keys() for k in ("chosen", "rejected")): 60 | # Compared to reward modeling, we filter out the prompt, so the text is everything after the last assistant token 61 | prompt_messages = [[msg for msg in example["chosen"] if msg["role"] == "user"][0]] 62 | # Insert system message 63 | if example["chosen"][0]["role"] != "system": 64 | prompt_messages.insert(0, {"role": "system", "content": ""}) 65 | else: 66 | prompt_messages.insert(0, example["chosen"][0]) 67 | # TODO: handle case where chosen/rejected also have system messages 68 | chosen_messages = example["chosen"][1:] 69 | rejected_messages = example["rejected"][1:] 70 | example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False) 71 | example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False) 72 | example["text_prompt"] = tokenizer.apply_chat_template( 73 | prompt_messages, tokenize=False, add_generation_prompt=True 74 | ) 75 | 76 | example["text_chosen"] = _strip_prefix(example["text_chosen"], assistant_prefix) 77 | example["text_rejected"] = _strip_prefix(example["text_rejected"], assistant_prefix) 78 | else: 79 | raise ValueError( 80 | f"Could not format example as dialogue for `dpo` task! 
Require `[chosen, rejected]` keys but found {list(example.keys())}" 81 | ) 82 | return example 83 | 84 | 85 | def get_datasets( 86 | data_config, 87 | splits: List[str] = ["train", "test"], 88 | shuffle: bool = True, 89 | ) -> DatasetDict: 90 | """ 91 | Loads one or more datasets with varying training set proportions. 92 | 93 | Args: 94 | data_config (`DataArguments` or `dict`): 95 | Dataset configuration and split proportions. 96 | splits (`List[str]`, *optional*, defaults to `['train', 'test']`): 97 | Dataset splits to load and mix. Assumes the splits exist in all datasets and have a `train_` or `test_` prefix. 98 | shuffle (`bool`, *optional*, defaults to `True`): 99 | Whether to shuffle the training data. 100 | 101 | Returns 102 | [`DatasetDict`]: The dataset dictionary containing the loaded datasets. 103 | """ 104 | 105 | if type(data_config) is DataArguments: 106 | # Structure of the config to read the datasets and their mix 107 | # datasets_mixer: 108 | # - 'dataset1': 0.5 109 | # - 'dataset2': 0.3 110 | # - 'dataset3': 0.2 111 | dataset_mixer = data_config.dataset_mixer 112 | elif type(data_config) is dict: 113 | # Structure of the input is: 114 | # dataset_mixer = { 115 | # "dataset1": 0.5, 116 | # "dataset1": 0.3, 117 | # "dataset1": 0.2, 118 | # } 119 | dataset_mixer = data_config 120 | else: 121 | raise ValueError(f"Data config {data_config} not recognized.") 122 | 123 | raw_datasets = mix_datasets(dataset_mixer, splits=splits, shuffle=shuffle) 124 | return raw_datasets 125 | 126 | 127 | def mix_datasets(dataset_mixer: dict, splits: Optional[List[str]] = None, shuffle=True) -> DatasetDict: 128 | """ 129 | Loads and mixes datasets according to proportions specified in `dataset_mixer`. 130 | 131 | Args: 132 | dataset_mixer (`dict`): 133 | Dictionary containing the dataset names and their training proportions. By default, all test proportions are 1. 134 | splits (Optional[List[str]], *optional*, defaults to `None`): 135 | Dataset splits to load and mix. Assumes the splits exist in all datasets and have a `train_` or `test_` prefix. 136 | shuffle (`bool`, *optional*, defaults to `True`): 137 | Whether to shuffle the training data. 
138 | """ 139 | raw_datasets = DatasetDict() 140 | raw_train_datasets = [] 141 | raw_val_datasets = [] 142 | fracs = [] 143 | for ds, frac in dataset_mixer.items(): 144 | fracs.append(frac) 145 | for split in splits: 146 | if "train" in split: 147 | raw_train_datasets.append( 148 | load_dataset( 149 | ds, 150 | split=split, 151 | ) 152 | ) 153 | elif "test" in split: 154 | raw_val_datasets.append( 155 | load_dataset( 156 | ds, 157 | split=split, 158 | ) 159 | ) 160 | else: 161 | raise ValueError(f"Split type {split} not recognized as one of test or train.") 162 | 163 | if any(frac < 0 for frac in fracs): 164 | raise ValueError("Dataset fractions cannot be negative.") 165 | 166 | if len(raw_train_datasets) > 0: 167 | train_subsets = [] 168 | for dataset, frac in zip(raw_train_datasets, fracs): 169 | train_subset = dataset.select(range(int(frac * len(dataset)))) 170 | train_subsets.append(train_subset) 171 | if shuffle: 172 | raw_datasets["train"] = concatenate_datasets(train_subsets).shuffle(seed=42) 173 | else: 174 | raw_datasets["train"] = concatenate_datasets(train_subsets) 175 | # No subsampling for test datasets to enable fair comparison across models 176 | if len(raw_val_datasets) > 0: 177 | if shuffle: 178 | raw_datasets["test"] = concatenate_datasets(raw_val_datasets).shuffle(seed=42) 179 | else: 180 | raw_datasets["test"] = concatenate_datasets(raw_val_datasets) 181 | 182 | if len(raw_datasets) == 0: 183 | raise ValueError( 184 | f"Dataset {dataset_mixer} not recognized with split {split}. Check the dataset has been correctly formatted." 185 | ) 186 | 187 | return raw_datasets 188 | -------------------------------------------------------------------------------- /dpo/alignment/model_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # coding=utf-8 3 | # Copyright 2023 The HuggingFace Team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | from typing import Dict 18 | 19 | import torch 20 | from transformers import AutoTokenizer, BitsAndBytesConfig, PreTrainedTokenizer 21 | 22 | from accelerate import Accelerator 23 | from huggingface_hub import list_repo_files 24 | from peft import LoraConfig, PeftConfig 25 | 26 | from .configs import DataArguments, ModelArguments 27 | from .data import DEFAULT_CHAT_TEMPLATE 28 | 29 | 30 | def get_current_device() -> int: 31 | """Get the current device. 
For GPU we return the local process index to enable multiple GPU training.""" 32 | return Accelerator().local_process_index if torch.cuda.is_available() else "cpu" 33 | 34 | 35 | def get_kbit_device_map(): 36 | """Useful for running inference with quantized models by setting `device_map=get_peft_device_map()`""" 37 | return {"": get_current_device()} if torch.cuda.is_available() else None 38 | 39 | 40 | def get_quantization_config(model_args): 41 | if model_args.load_in_4bit: 42 | quantization_config = BitsAndBytesConfig( 43 | load_in_4bit=True, 44 | bnb_4bit_compute_dtype=torch.float16, # For consistency with model weights, we use the same value as `torch_dtype` which is float16 for PEFT models 45 | bnb_4bit_quant_type=model_args.bnb_4bit_quant_type, 46 | bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant, 47 | ) 48 | elif model_args.load_in_8bit: 49 | quantization_config = BitsAndBytesConfig( 50 | load_in_8bit=True, 51 | ) 52 | else: 53 | quantization_config = None 54 | 55 | return quantization_config 56 | 57 | 58 | def get_tokenizer(model_args: ModelArguments, data_args: DataArguments) -> PreTrainedTokenizer: 59 | """Get the tokenizer for the model.""" 60 | tokenizer = AutoTokenizer.from_pretrained( 61 | model_args.model_name_or_path, 62 | revision=model_args.model_revision, 63 | ) 64 | if tokenizer.pad_token_id is None: 65 | tokenizer.pad_token_id = tokenizer.eos_token_id 66 | 67 | if data_args.truncation_side is not None: 68 | tokenizer.truncation_side = data_args.truncation_side 69 | 70 | # Set reasonable default for models without max length 71 | if tokenizer.model_max_length > 100_000: 72 | tokenizer.model_max_length = 2048 73 | 74 | if data_args.chat_template is not None: 75 | tokenizer.chat_template = data_args.chat_template 76 | elif tokenizer.chat_template is None: 77 | tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE 78 | 79 | return tokenizer 80 | 81 | 82 | def get_peft_config(model_args: ModelArguments): 83 | if model_args.use_peft is False: 84 | return None 85 | 86 | peft_config = LoraConfig( 87 | r=model_args.lora_r, 88 | lora_alpha=model_args.lora_alpha, 89 | lora_dropout=model_args.lora_dropout, 90 | bias="none", 91 | task_type="CAUSAL_LM", 92 | target_modules=model_args.lora_target_modules, 93 | modules_to_save=model_args.lora_modules_to_save, 94 | ) 95 | 96 | return peft_config 97 | 98 | 99 | def is_adapter_model(model_name_or_path: str, revision: str = "main") -> bool: 100 | repo_files = list_repo_files(model_name_or_path, revision=revision) 101 | return "adapter_model.safetensors" in repo_files or "adapter_model.bin" in repo_files 102 | -------------------------------------------------------------------------------- /dpo/alignment/release.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The HuggingFace Team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | import argparse 17 | import re 18 | 19 | import packaging.version 20 | 21 | 22 | REPLACE_PATTERNS = { 23 | "init": (re.compile(r'^__version__\s+=\s+"([^"]+)"\s*$', re.MULTILINE), '__version__ = "VERSION"\n'), 24 | "setup": (re.compile(r'^(\s*)version\s*=\s*"[^"]+",', re.MULTILINE), r'\1version="VERSION",'), 25 | } 26 | REPLACE_FILES = { 27 | "init": "src/alignment/__init__.py", 28 | "setup": "setup.py", 29 | } 30 | README_FILE = "README.md" 31 | 32 | 33 | def update_version_in_file(fname, version, pattern): 34 | """Update the version in one file using a specific pattern.""" 35 | with open(fname, "r", encoding="utf-8", newline="\n") as f: 36 | code = f.read() 37 | re_pattern, replace = REPLACE_PATTERNS[pattern] 38 | replace = replace.replace("VERSION", version) 39 | code = re_pattern.sub(replace, code) 40 | with open(fname, "w", encoding="utf-8", newline="\n") as f: 41 | f.write(code) 42 | 43 | 44 | def global_version_update(version, patch=False): 45 | """Update the version in all needed files.""" 46 | for pattern, fname in REPLACE_FILES.items(): 47 | update_version_in_file(fname, version, pattern) 48 | 49 | 50 | def get_version(): 51 | """Reads the current version in the __init__.""" 52 | with open(REPLACE_FILES["init"], "r") as f: 53 | code = f.read() 54 | default_version = REPLACE_PATTERNS["init"][0].search(code).groups()[0] 55 | return packaging.version.parse(default_version) 56 | 57 | 58 | def pre_release_work(patch=False): 59 | """Do all the necessary pre-release steps.""" 60 | # First let's get the default version: base version if we are in dev, bump minor otherwise. 61 | default_version = get_version() 62 | if patch and default_version.is_devrelease: 63 | raise ValueError("Can't create a patch version from the dev branch, checkout a released version!") 64 | if default_version.is_devrelease: 65 | default_version = default_version.base_version 66 | elif patch: 67 | default_version = f"{default_version.major}.{default_version.minor}.{default_version.micro + 1}" 68 | else: 69 | default_version = f"{default_version.major}.{default_version.minor + 1}.0" 70 | 71 | # Now let's ask nicely if that's the right one. 72 | version = input(f"Which version are you releasing? [{default_version}]") 73 | if len(version) == 0: 74 | version = default_version 75 | 76 | print(f"Updating version to {version}.") 77 | global_version_update(version, patch=patch) 78 | 79 | 80 | def post_release_work(): 81 | """Do all the necessary post-release steps.""" 82 | # First let's get the current version 83 | current_version = get_version() 84 | dev_version = f"{current_version.major}.{current_version.minor + 1}.0.dev0" 85 | current_version = current_version.base_version 86 | 87 | # Check with the user we got that right. 88 | version = input(f"Which version are we developing now? 
[{dev_version}]") 89 | if len(version) == 0: 90 | version = dev_version 91 | 92 | print(f"Updating version to {version}.") 93 | global_version_update(version) 94 | 95 | 96 | if __name__ == "__main__": 97 | parser = argparse.ArgumentParser() 98 | parser.add_argument("--post_release", action="store_true", help="Whether this is pre or post release.") 99 | parser.add_argument("--patch", action="store_true", help="Whether or not this is a patch release.") 100 | args = parser.parse_args() 101 | if not args.post_release: 102 | pre_release_work(patch=args.patch) 103 | elif args.patch: 104 | print("Nothing to do after a patch :-)") 105 | else: 106 | post_release_work() 107 | -------------------------------------------------------------------------------- /dpo/prepare_dataset.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import json 3 | from datasets import Dataset, DatasetDict 4 | 5 | 6 | def get_dataset(path_train, path_test): 7 | 8 | final_dict = {} 9 | 10 | for idx, path in enumerate([path_train, path_test]): 11 | li = [json.loads(x) for x in open(path)] 12 | 13 | text_prompts = [] 14 | chosen = [] 15 | rejected = [] 16 | 17 | # If test (i.e. validation set), just use dummy data: subset of training. 18 | if idx == 1: 19 | li = li[:100] 20 | 21 | for elem in li: 22 | 23 | if 'prompt' not in elem.keys(): 24 | elem['prompt'] = elem['input'] 25 | 26 | text_prompts.append(elem['prompt'].rstrip("\n") + "\n") 27 | chosen.append(elem['chosen']) 28 | 29 | rej = elem['rejected'] 30 | 31 | # [Note] Remove last line for rejected sample. 32 | tgt_string = "The answer is" 33 | if tgt_string in rej and len(rej.split("\n")) > 1: 34 | rej = rej[:rej.index(tgt_string)].strip() 35 | 36 | rejected.append(rej) 37 | 38 | d = {"prompt": text_prompts, "text_prompt": text_prompts, "chosen": chosen, "rejected": rejected} 39 | 40 | if idx == 0: 41 | final_dict["train"] = d 42 | else: 43 | final_dict["test"] = d 44 | 45 | train_dataset = Dataset.from_dict(final_dict["train"]) 46 | test_dataset = Dataset.from_dict(final_dict["test"]) 47 | 48 | print("train example", train_dataset[0]) 49 | 50 | # Create a DatasetDict 51 | dataset_dict = DatasetDict({ 52 | 'train': train_dataset.shuffle(seed=42), 53 | 'test': test_dataset.shuffle(seed=42) 54 | }) 55 | 56 | return dataset_dict -------------------------------------------------------------------------------- /dpo/run_dpo.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | import wandb 4 | import torch 5 | import transformers 6 | from transformers import AutoModelForCausalLM, set_seed, LlamaTokenizer, AutoTokenizer 7 | from accelerate import Accelerator 8 | 9 | from alignment import ( 10 | DataArguments, 11 | DPOConfig, 12 | H4ArgumentParser, 13 | ModelArguments, 14 | apply_chat_template, 15 | get_datasets, 16 | get_kbit_device_map, 17 | get_peft_config, 18 | get_quantization_config, 19 | get_tokenizer, 20 | is_adapter_model, 21 | ) 22 | 23 | 24 | from peft import PeftConfig, PeftModel 25 | from custom_trainer import DPOTrainer 26 | from prepare_dataset import get_dataset 27 | import os 28 | 29 | os.environ["WANDB_API_KEY"] = "" # PUR YOUR WANDB KEY 30 | os.environ["WANDB_ENTITY"] = "" # PUT YOUR WANDB ID 31 | os.environ["WANDB_PROJECT"] = "" # PUT YOUR WANDB PROJECT 32 | 33 | logger = logging.getLogger(__name__) 34 | 35 | 36 | def main(): 37 | parser = H4ArgumentParser((ModelArguments, DataArguments, DPOConfig)) 38 | model_args, data_args, 
training_args = parser.parse() 39 | 40 | ####### 41 | # Setup 42 | ####### 43 | logging.basicConfig( 44 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 45 | datefmt="%Y-%m-%d %H:%M:%S", 46 | handlers=[logging.StreamHandler(sys.stdout)], 47 | ) 48 | log_level = training_args.get_process_log_level() 49 | logger.setLevel(log_level) 50 | transformers.utils.logging.set_verbosity(log_level) 51 | transformers.utils.logging.enable_default_handler() 52 | transformers.utils.logging.enable_explicit_format() 53 | 54 | # Log on each process the small summary: 55 | logger.info(f"Model parameters {model_args}") 56 | logger.info(f"Data parameters {data_args}") 57 | logger.info(f"Training/evaluation parameters {training_args}") 58 | 59 | # Set seed for reproducibility 60 | set_seed(training_args.seed) 61 | 62 | # Increase distributed timeout to 3h to enable push to Hub to complete 63 | accelerator = Accelerator() 64 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path) 65 | 66 | # there should be "train" and "test", with column anme "text_prompt", "text_chosen", "text_rejected". 67 | raw_datasets = get_dataset(data_args.train_data_file, data_args.test_data_file) 68 | print("dataset", raw_datasets) 69 | print(("---" * 30)) 70 | torch_dtype = ( 71 | model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) 72 | ) 73 | model_kwargs = dict( 74 | revision=model_args.model_revision, 75 | trust_remote_code=model_args.trust_remote_code, 76 | use_flash_attention_2=model_args.use_flash_attention_2, 77 | torch_dtype=torch_dtype, 78 | use_cache=False if training_args.gradient_checkpointing else True, 79 | device_map=get_kbit_device_map(), 80 | quantization_config=get_quantization_config(model_args), 81 | ) 82 | 83 | model = model_args.model_name_or_path 84 | ref_model = model 85 | ref_model_kwargs = model_kwargs 86 | 87 | if model_args.use_peft is True: 88 | ref_model = None 89 | ref_model_kwargs = None 90 | 91 | ######################### 92 | # Instantiate DPO trainer 93 | ######################### 94 | dpo_trainer = DPOTrainer( 95 | model, 96 | ref_model, 97 | model_init_kwargs=model_kwargs, 98 | ref_model_init_kwargs=ref_model_kwargs, 99 | args=training_args, 100 | beta=training_args.beta, 101 | train_dataset= raw_datasets["train"], 102 | eval_dataset=raw_datasets["test"], 103 | tokenizer=tokenizer, 104 | max_length=training_args.max_length, 105 | max_prompt_length=training_args.max_prompt_length, 106 | peft_config=get_peft_config(model_args), 107 | loss_type=training_args.loss_type 108 | ) 109 | 110 | ############### 111 | # Training loop 112 | ############### 113 | train_result = dpo_trainer.train() 114 | metrics = train_result.metrics 115 | max_train_samples = ( 116 | data_args.max_train_samples if data_args.max_train_samples is not None else len(raw_datasets["train"]) 117 | ) 118 | metrics["train_samples"] = min(max_train_samples, len(raw_datasets["train"])) 119 | dpo_trainer.log_metrics("train", metrics) 120 | dpo_trainer.save_metrics("train", metrics) 121 | dpo_trainer.save_state() 122 | 123 | logger.info("*** Training complete ***") 124 | 125 | # Evaluate 126 | if training_args.do_eval: 127 | logger.info("*** Evaluate ***") 128 | metrics = dpo_trainer.evaluate() 129 | max_eval_samples = ( 130 | data_args.max_eval_samples if data_args.max_eval_samples is not None else len(raw_datasets["test"]) 131 | ) 132 | metrics["eval_samples"] = min(max_eval_samples, len(raw_datasets["test"])) 133 | dpo_trainer.log_metrics("eval", 
metrics) 134 | dpo_trainer.save_metrics("eval", metrics) 135 | 136 | # Save model and create model card 137 | dpo_trainer.save_model(training_args.output_dir) 138 | 139 | # Save everything else on main process 140 | if accelerator.is_main_process: 141 | kwargs = { 142 | "finetuned_from": model_args.model_name_or_path, 143 | "dataset": list(data_args.dataset_mixer.keys()), 144 | "dataset_tags": list(data_args.dataset_mixer.keys()), 145 | "tags": ["alignment-handbook"], 146 | } 147 | dpo_trainer.create_model_card(**kwargs) 148 | # Restore k,v cache for fast inference 149 | dpo_trainer.model.config.use_cache = True 150 | dpo_trainer.model.config.save_pretrained(training_args.output_dir) 151 | if training_args.push_to_hub is True: 152 | dpo_trainer.push_to_hub() 153 | 154 | # Ensure we don't timeout on model save / push to Hub 155 | logger.info("*** Waiting for all processes to finish ***") 156 | accelerator.wait_for_everyone() 157 | 158 | logger.info("*** Run complete! ***") 159 | 160 | 161 | if __name__ == "__main__": 162 | main() 163 | -------------------------------------------------------------------------------- /dpo/run_kto.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | import wandb 4 | import torch 5 | import transformers 6 | from transformers import AutoModelForCausalLM, set_seed, LlamaTokenizer, AutoTokenizer 7 | from accelerate import Accelerator 8 | 9 | from alignment import ( 10 | DataArguments, 11 | DPOConfig, 12 | H4ArgumentParser, 13 | ModelArguments, 14 | apply_chat_template, 15 | get_datasets, 16 | get_kbit_device_map, 17 | get_peft_config, 18 | get_quantization_config, 19 | get_tokenizer, 20 | is_adapter_model, 21 | ) 22 | 23 | from peft import PeftConfig, PeftModel 24 | # from custom_trainer import DPOTrainer, KTOTrainer, KTOConfig 25 | from prepare_dataset import get_dataset 26 | # from trls.trl.trainer import KTOTrainer, KTOConfig 27 | from custom_trainer import KTOTrainer, KTOConfig 28 | # from step_prepare_dataset import get_dataset 29 | import os 30 | 31 | os.environ["WANDB_API_KEY"] = "" # PUR YOUR WANDB KEY 32 | os.environ["WANDB_ENTITY"] = "" # PUT YOUR WANDB ID 33 | os.environ["WANDB_PROJECT"] = "" # PUT YOUR WANDB PROJECT 34 | 35 | logger = logging.getLogger(__name__) 36 | 37 | 38 | def main(): 39 | parser = H4ArgumentParser((ModelArguments, DataArguments, KTOConfig)) 40 | model_args, data_args, training_args = parser.parse() 41 | 42 | ####### 43 | # Setup 44 | ####### 45 | logging.basicConfig( 46 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 47 | datefmt="%Y-%m-%d %H:%M:%S", 48 | handlers=[logging.StreamHandler(sys.stdout)], 49 | ) 50 | log_level = training_args.get_process_log_level() 51 | logger.setLevel(log_level) 52 | transformers.utils.logging.set_verbosity(log_level) 53 | transformers.utils.logging.enable_default_handler() 54 | transformers.utils.logging.enable_explicit_format() 55 | 56 | # Log on each process the small summary: 57 | logger.info(f"Model parameters {model_args}") 58 | logger.info(f"Data parameters {data_args}") 59 | logger.info(f"Training/evaluation parameters {training_args}") 60 | 61 | # Set seed for reproducibility 62 | set_seed(training_args.seed) 63 | 64 | # Increase distributed timeout to 3h to enable push to Hub to complete 65 | accelerator = Accelerator() 66 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path) 67 | 68 | # there should be "train" and "test", with column anme "text_prompt", "text_chosen", 
"text_rejected". 69 | # raw_datasets = get_dataset(data_args.train_data_file, data_args.test_data_file) 70 | import json 71 | from datasets import Dataset 72 | 73 | data = json.load(open(data_args.train_data_file)) 74 | 75 | for idx in range(len(data['label'])): 76 | if data['label'][idx] == False: 77 | rej = data['completion'][idx] 78 | tgt_string = "The answer is" 79 | if tgt_string in rej and len(rej.split("\n")) > 1: 80 | rej = rej[:rej.index(tgt_string)].strip() 81 | data['completion'][idx] = rej 82 | 83 | # For format of KTO Dataset, see: https://huggingface.co/docs/trl/main/en/kto_trainer 84 | train_set = Dataset.from_dict({k: v for k, v in data.items()}) 85 | train_set = train_set.shuffle(seed=42) 86 | test_set = Dataset.from_dict({k: v[:100] for k, v in data.items()}) 87 | 88 | model_args.torch_dtype = "bfloat16" 89 | torch_dtype = ( 90 | model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) 91 | ) 92 | print("torch_dtype", torch_dtype) 93 | 94 | model_kwargs = dict( 95 | revision=model_args.model_revision, 96 | trust_remote_code=model_args.trust_remote_code, 97 | use_flash_attention_2=True, 98 | torch_dtype=torch_dtype, 99 | use_cache=False, 100 | device_map=get_kbit_device_map(), 101 | quantization_config=get_quantization_config(model_args), 102 | ) 103 | 104 | model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs) 105 | ref_model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs) 106 | 107 | # if model_args.use_peft is True: 108 | # ref_model = None 109 | # ref_model_kwargs = None 110 | 111 | ######################### 112 | # Instantiate DPO trainer 113 | ######################### 114 | # training_args.model_init_kwargs = model_kwargs 115 | # training_args.ref_model_init_kwargs = ref_model_kwargs 116 | 117 | kto_trainer = KTOTrainer( 118 | model, 119 | ref_model, 120 | args=training_args, 121 | train_dataset= train_set, 122 | eval_dataset=test_set, 123 | tokenizer=tokenizer, 124 | peft_config=get_peft_config(model_args), 125 | ) 126 | 127 | ############### 128 | # Training loop 129 | ############### 130 | train_result = kto_trainer.train() 131 | metrics = train_result.metrics 132 | max_train_samples = ( 133 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_set) 134 | ) 135 | metrics["train_samples"] = min(max_train_samples, len(train_set)) 136 | kto_trainer.log_metrics("train", metrics) 137 | kto_trainer.save_metrics("train", metrics) 138 | kto_trainer.save_state() 139 | 140 | logger.info("*** Training complete ***") 141 | 142 | # Evaluate 143 | # if training_args.do_eval: 144 | # logger.info("*** Evaluate ***") 145 | # metrics = kto_trainer.evaluate() 146 | # max_eval_samples = ( 147 | # data_args.max_eval_samples if data_args.max_eval_samples is not None else len(test_set) 148 | # ) 149 | # metrics["eval_samples"] = min(max_eval_samples, len(kto_trainer["test"])) 150 | # kto_trainer.log_metrics("eval", metrics) 151 | # kto_trainer.save_metrics("eval", metrics) 152 | 153 | # Save model and create model card 154 | kto_trainer.save_model(training_args.output_dir) 155 | 156 | # Save everything else on main process 157 | if accelerator.is_main_process: 158 | kwargs = { 159 | "finetuned_from": model_args.model_name_or_path, 160 | "dataset": list(data_args.dataset_mixer.keys()), 161 | "dataset_tags": list(data_args.dataset_mixer.keys()), 162 | "tags": ["alignment-handbook"], 163 | } 164 | 
kto_trainer.create_model_card(**kwargs) 165 | # Restore k,v cache for fast inference 166 | kto_trainer.model.config.use_cache = True 167 | kto_trainer.model.config.save_pretrained(training_args.output_dir) 168 | if training_args.push_to_hub is True: 169 | kto_trainer.push_to_hub() 170 | 171 | # Ensure we don't timeout on model save / push to Hub 172 | logger.info("*** Waiting for all processes to finish ***") 173 | accelerator.wait_for_everyone() 174 | 175 | logger.info("*** Run complete! ***") 176 | 177 | 178 | if __name__ == "__main__": 179 | main() 180 | -------------------------------------------------------------------------------- /eval/gsm8k/eval_gsm8k.py: -------------------------------------------------------------------------------- 1 | import os 2 | # os.environ["CUDA_VISIBLE_DEVICES"]="4,5,6,7" 3 | 4 | from huggingface_hub import login 5 | import argparse 6 | import json 7 | import re 8 | import jsonlines 9 | from fraction import Fraction 10 | from vllm import LLM, SamplingParams 11 | import sys 12 | from tqdm.auto import tqdm 13 | from utils_ans import extract_answer 14 | 15 | MAX_INT = sys.maxsize 16 | 17 | def batch_data(data_list, batch_size=1): 18 | n = len(data_list) // batch_size 19 | batch_data = [] 20 | for i in range(n-1): 21 | start = i * batch_size 22 | end = (i+1)*batch_size 23 | batch_data.append(data_list[start:end]) 24 | 25 | last_start = (n-1) * batch_size 26 | last_end = MAX_INT 27 | batch_data.append(data_list[last_start:last_end]) 28 | return batch_data 29 | 30 | def gsm8k_test(model, data_path, start=0, end=MAX_INT, batch_size=1, tensor_parallel_size=1, temp=0.7): 31 | gsm8k_ins = [] 32 | gsm8k_answers = [] 33 | 34 | # Check if it already exists. 35 | try: 36 | already_done = [json.loads(x) for x in open(args.result_file)] 37 | except: 38 | already_done = [] 39 | 40 | with open(data_path,"r+", encoding="utf8") as f: 41 | for idx, item in enumerate(jsonlines.Reader(f)): 42 | 43 | if idx < len(already_done): 44 | continue 45 | 46 | gsm8k_ins.append(item["query"]) 47 | temp_ans = int(item['response'].split('#### ')[1].replace(',', '')) 48 | gsm8k_answers.append(temp_ans) 49 | 50 | gsm8k_ins = gsm8k_ins[start:end] 51 | gsm8k_answers = gsm8k_answers[start:end] 52 | print('length ====', len(gsm8k_ins)) 53 | batch_gsm8k_ins = batch_data(gsm8k_ins, batch_size=batch_size) 54 | 55 | # stop_tokens = ["\n\n", "Question:", "Question", "USER:", "USER", "ASSISTANT:", "ASSISTANT", "Instruction:", "Instruction", "Response:", "Response"] 56 | stop_tokens = [] 57 | 58 | if temp == 0.7: 59 | n = 100 60 | else: 61 | n = 1 62 | 63 | sampling_params = SamplingParams(temperature=temp, top_p=1, max_tokens=512, stop=stop_tokens, n=n) 64 | print('sampling =====', sampling_params) 65 | llm = LLM(model=model,tensor_parallel_size=tensor_parallel_size, enforce_eager=True) 66 | result = [] 67 | res_completions = [] 68 | 69 | for idx, (prompt, prompt_answer) in tqdm(enumerate(zip(batch_gsm8k_ins, gsm8k_answers))): 70 | if isinstance(prompt, list): 71 | pass 72 | else: 73 | prompt = [prompt] 74 | 75 | completions = llm.generate(prompt, sampling_params) 76 | 77 | for num, output in enumerate(completions): 78 | prompt = output.prompt 79 | all_texts = [out.text for out in output.outputs] 80 | res_completions.append(all_texts) 81 | 82 | answer = gsm8k_answers[idx*batch_size + num] 83 | dict_ = {"prompt": prompt, "preds": all_texts, "answer": answer} 84 | 85 | with jsonlines.open(args.result_file, 'a') as writer: 86 | writer.write(dict_) 87 | 88 | li = [json.loads(x) for x in 
open(args.result_file)] 89 | 90 | sa = [] # singgle acc 91 | 92 | for x in li: 93 | if 'answers' in x: 94 | lbl = str(x['answers']) 95 | else: 96 | lbl = str(x['answer']) 97 | 98 | answers = [str(extract_answer(pred)) for pred in x['preds']] 99 | eq_answers = [ans == lbl for ans in answers] 100 | sa.append(eq_answers.count(True) / len(eq_answers)) 101 | 102 | print("Final Acc:", sum(sa) / len(sa)) 103 | 104 | def parse_args(): 105 | parser = argparse.ArgumentParser() 106 | parser.add_argument("--model", type=str) # model path 107 | parser.add_argument("--data_file", type=str, default='') # data path 108 | parser.add_argument("--start", type=int, default=0) # start index 109 | parser.add_argument("--end", type=int, default=MAX_INT) # end index 110 | parser.add_argument("--batch_size", type=int, default=1000) # batch_size 111 | parser.add_argument("--tensor_parallel_size", type=int, default=1) # tensor_parallel_size 112 | parser.add_argument("--result_file", type=str, default="./log_file.jsonl") # tensor_parallel_size 113 | parser.add_argument("--temp", type=float, default=0.7) 114 | 115 | return parser.parse_args() 116 | 117 | if __name__ == "__main__": 118 | # Login First. 119 | # login(token="your_huggingface_token") 120 | 121 | args = parse_args() 122 | gsm8k_test(model=args.model, data_path=args.data_file, start=args.start, end=args.end, batch_size=args.batch_size, tensor_parallel_size=args.tensor_parallel_size, temp=args.temp) -------------------------------------------------------------------------------- /eval/gsm8k/run_gsm8k_eval.sh: -------------------------------------------------------------------------------- 1 | model_path="" 2 | result_file="" 3 | data_file="" 4 | 5 | # To run evaluation for GSM8K: 6 | python eval_gsm8k.py --data_file $data_file --model $model_path --result_file $result_file --temp 0 7 | -------------------------------------------------------------------------------- /eval/gsm8k/utils_ans.py: -------------------------------------------------------------------------------- 1 | ### From OVM/utils/gsm8k/decoding.py 2 | 3 | from contextlib import contextmanager 4 | import signal 5 | import json 6 | import os 7 | import re 8 | 9 | # ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)") 10 | ANS_RE = re.compile(r"The answer is:?\s*(\-?[0-9\.\,]+)") 11 | INVALID_ANS = "[invalid]" 12 | 13 | 14 | def extract_answer(completion): 15 | match = ANS_RE.search(completion) 16 | if match: 17 | match_str = match.group(1).strip() 18 | st_str = standardize_value_str(match_str) 19 | try: eval(st_str); return st_str 20 | except: ... 21 | return INVALID_ANS 22 | 23 | def extract_answers(completions): 24 | return [extract_answer(completion) for completion in completions] 25 | 26 | def standardize_value_str(x): 27 | """Standardize numerical values""" 28 | y = x.replace(",", "") 29 | if '.' 
in y: 30 | y = y.rstrip('0') 31 | if y[-1] == '.': 32 | y = y[:-1] 33 | if not len(y): 34 | return INVALID_ANS 35 | if y[0] == '.': 36 | y = '0' + y 37 | if y[-1] == '%': 38 | y = str(eval(y[:-1]) / 100) 39 | return y.rstrip('.') 40 | 41 | def get_answer_label(response_answer, gt): 42 | if response_answer == INVALID_ANS: 43 | return INVALID_ANS 44 | return response_answer == gt 45 | 46 | 47 | # taken from 48 | # https://stackoverflow.com/questions/492519/timeout-on-a-function-call 49 | @contextmanager 50 | def timeout(duration, formula): 51 | def timeout_handler(signum, frame): 52 | raise Exception(f"'{formula}': timed out after {duration} seconds") 53 | 54 | signal.signal(signal.SIGALRM, timeout_handler) 55 | signal.alarm(duration) 56 | yield 57 | signal.alarm(0) 58 | 59 | 60 | def eval_with_timeout(formula, max_time=3): 61 | try: 62 | with timeout(max_time, formula): 63 | return round(eval(formula), ndigits=4) 64 | except Exception as e: 65 | signal.alarm(0) 66 | print(f"Warning: Failed to eval {formula}, exception: {e}") 67 | return None 68 | 69 | # refer to https://github.com/openai/grade-school-math/blob/master/grade_school_math/calculator.py 70 | def use_calculator(sample): 71 | if "<<" not in sample: 72 | return None 73 | 74 | parts = sample.split("<<") 75 | remaining = parts[-1] 76 | if ">>" in remaining: 77 | return None 78 | if "=" not in remaining: 79 | return None 80 | lhs = remaining.split("=")[0] 81 | lhs = lhs.replace(",", "") 82 | if any([x not in "0123456789*+-/.()" for x in lhs]): 83 | return None 84 | ans = eval_with_timeout(lhs) 85 | if remaining[-1] == '-' and ans is not None and ans < 0: 86 | ans = -ans 87 | return ans 88 | 89 | 90 | 91 | 92 | 93 | 94 | -------------------------------------------------------------------------------- /eval/math/configs/few_shot_test_configs.json: -------------------------------------------------------------------------------- 1 | { 2 | "math-cot-test": { 3 | "test_path": "../../data/MATH_test.jsonl", 4 | "language": "en", 5 | "tasks": [ 6 | "cot" 7 | ], 8 | "process_fn": "process_math_test", 9 | "answer_extraction_fn": "extract_math_few_shot_cot_answer", 10 | "eval_fn": "eval_math", 11 | "few_shot_prompt": "MinervaMathPrompt" 12 | } 13 | } -------------------------------------------------------------------------------- /eval/math/configs/zero_shot_test_configs.json: -------------------------------------------------------------------------------- 1 | { 2 | "gsm8k-test": { 3 | "test_path": "datasets/gsm8k/test.jsonl", 4 | "language": "en", 5 | "tasks": ["tool", "cot"], 6 | "process_fn": "process_gsm8k_test", 7 | "answer_extraction_fn": "extract_last_single_answer", 8 | "eval_fn": "eval_last_single_answer" 9 | }, 10 | "math-test": { 11 | "test_path": "datasets/math/test.jsonl", 12 | "language": "en", 13 | "tasks": ["tool", "cot"], 14 | "process_fn": "process_math_test", 15 | "answer_extraction_fn": "extract_math_answer", 16 | "eval_fn": "eval_math" 17 | }, 18 | "mgsm-zh": { 19 | "test_path": "datasets/mgsm_zh/mgsm_zh.jsonl", 20 | "language": "zh", 21 | "tasks": ["tool", "cot"], 22 | "process_fn": "process_mgsm_zh", 23 | "answer_extraction_fn": "extract_last_single_answer", 24 | "eval_fn": "eval_last_single_answer" 25 | }, 26 | "cmath": { 27 | "test_path": "datasets/cmath/test.jsonl", 28 | "language": "zh", 29 | "tasks": ["tool", "cot"], 30 | "process_fn": "process_cmath", 31 | "answer_extraction_fn": "extract_last_single_answer", 32 | "eval_fn": "eval_last_single_answer" 33 | } 34 | } 
-------------------------------------------------------------------------------- /eval/math/data_processing/__pycache__/answer_extraction.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hbin0701/Self-Explore/54b8ba9f4cee1ef52f72de916e100e0da2ae6863/eval/math/data_processing/__pycache__/answer_extraction.cpython-310.pyc -------------------------------------------------------------------------------- /eval/math/data_processing/__pycache__/answer_extraction.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hbin0701/Self-Explore/54b8ba9f4cee1ef52f72de916e100e0da2ae6863/eval/math/data_processing/__pycache__/answer_extraction.cpython-311.pyc -------------------------------------------------------------------------------- /eval/math/data_processing/__pycache__/process_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hbin0701/Self-Explore/54b8ba9f4cee1ef52f72de916e100e0da2ae6863/eval/math/data_processing/__pycache__/process_utils.cpython-310.pyc -------------------------------------------------------------------------------- /eval/math/data_processing/process_utils.py: -------------------------------------------------------------------------------- 1 | import regex 2 | 3 | from data_processing.answer_extraction import extract_math_answer, strip_string 4 | 5 | def process_gsm8k_test(item): 6 | sample = { 7 | 'dataset': 'gsm8k-cot', 8 | 'id': item['id'], 9 | 'messages': [ 10 | {'role': 'user', 'content': item['question']}, 11 | {'role': 'assistant', 'content': regex.sub(r"<<[^<>]*>>", "", item['cot']) + "\nSo the answer is $\\boxed{" + item['answer'].strip() + "}$."} 12 | ], 13 | 'answer': item['answer'].replace(',', '') 14 | } 15 | yield sample 16 | 17 | def process_math_test(item): 18 | question = item["problem"] 19 | try: 20 | answer = extract_math_answer(question, item['solution'], task="cot") 21 | except: 22 | return 23 | sample = { 24 | "dataset": "math-cot", 25 | "id": item['id'], 26 | "level": item["level"], 27 | "type": item["type"], 28 | "category": item["category"], 29 | "messages": [ 30 | {"role": "user", "content": question}, 31 | {"role": "assistant", "content": "\n".join(regex.split(r"(?<=\.) (?=[A-Z])", item["solution"]))} 32 | ], 33 | "answer": answer 34 | } 35 | yield sample 36 | 37 | def process_math_sat(item): 38 | options = item['options'].strip() 39 | assert 'A' == options[0] 40 | options = '(' + options 41 | for ch in 'BCDEFG': 42 | if f' {ch}) ' in options: 43 | options = regex.sub(f' {ch}\) ', f" ({ch}) ", options) 44 | question = f"{item['question'].strip()}\nWhat of the following is the right choice? 
Explain your answer.\n{options.strip()}" 45 | messages = [ 46 | {'role': 'user', 'content': question}, 47 | {'role': 'assistant', 'content': item['Answer']} 48 | ] 49 | item = { 50 | 'dataset': 'math_sat', 51 | 'id': item['id'], 52 | 'language': 'en', 53 | 'messages': messages, 54 | 'answer': item['Answer'], 55 | } 56 | yield item 57 | 58 | def process_ocwcourses(item): 59 | messages = [ 60 | {'role': 'user', 'content': item['problem'].strip()}, 61 | {'role': 'assistant', 'content': item['solution'].strip()} 62 | ] 63 | item = { 64 | "dataset": "OCWCourses", 65 | "id": item['id'], 66 | "language": "en", 67 | "messages": messages, 68 | "answer": item['answer'] 69 | } 70 | yield item 71 | 72 | def process_mmlu_stem(item): 73 | options = item['options'] 74 | for i, (label, option) in enumerate(zip('ABCD', options)): 75 | options[i] = f"({label}) {str(option).strip()}" 76 | options = ", ".join(options) 77 | question = f"{item['question'].strip()}\nWhat of the following is the right choice? Explain your answer.\n{options}" 78 | messages = [ 79 | {'role': 'user', 'content': question}, 80 | {'role': 'assistant', 'content': item['answer']} 81 | ] 82 | item = { 83 | "dataset": "MMLU-STEM", 84 | "id": item['id'], 85 | "language": "en", 86 | "messages": messages, 87 | "answer": item['answer'] 88 | } 89 | yield item 90 | 91 | def process_mgsm_zh(item): 92 | item['answer'] = item['answer'].replace(',', '') 93 | yield item 94 | 95 | def process_cmath(item): 96 | item = { 97 | 'dataset': 'cmath', 98 | 'id': item['id'], 99 | 'grade': item['grade'], 100 | 'reasoning_step': item['reasoning_step'], 101 | 'messages': [ 102 | {'role': 'user', 'content': item['question'].strip()}, 103 | {'role': 'assistant', 'content': ''} 104 | ], 105 | 'answer': item['golden'].strip().replace(",", "") 106 | } 107 | yield item 108 | 109 | def process_agieval_gaokao_math_cloze(item): 110 | item = { 111 | 'dataset': 'agieval-gaokao-math-cloze', 112 | 'id': item['id'], 113 | 'messages': [ 114 | {'role': 'user', 'content': item['question'].strip()}, 115 | {'role': 'assistant', 'content': ''} 116 | ], 117 | 'answer': [strip_string(ans) for ans in item['answer'].strip().split(";")] 118 | } 119 | yield item 120 | 121 | def process_agieval_gaokao_mathqa(item): 122 | question = item['question'].strip() 123 | options = [] 124 | for option in item['options']: 125 | option = option.strip() 126 | assert option[0] == '(' 127 | assert option[2] == ')' 128 | assert option[1] in 'ABCD' 129 | option = f"{option[1]}: {option[3:].strip()}" 130 | options.append(option.strip()) 131 | question = f"{question}\n{options}" 132 | item = { 133 | 'dataset': 'agieval-gaokao-mathqa', 134 | 'id': item['id'], 135 | 'messages': [ 136 | {'role': 'user', 'content': question}, 137 | {'role': 'assistant', 'content': ''} 138 | ], 139 | "answer": item['label'] 140 | } 141 | yield item 142 | 143 | def process_agieval_gaokao_mathqa_few_shot_cot_test(item): 144 | question = item['question'].strip().rstrip('\\') 145 | options = " ".join([opt.strip() for opt in item['options']]) 146 | question = f"{question}\n从以下选项中选择: {options}" 147 | item = { 148 | 'dataset': 'agieval-gaokao-mathqa', 149 | 'id': item['id'], 150 | 'messages': [ 151 | {'role': 'user', 'content': question}, 152 | {'role': 'assistant', 'content': ''} 153 | ], 154 | "answer": item['label'] 155 | } 156 | yield item 157 | 158 | def process_minif2f_isabelle(item): 159 | question = f"(*### Problem\n\n{item['informal_statement'].strip()}\n\n### Solution\n\n{item['informal_proof'].strip()} 
*)\n\nFormal:\n{item['formal_statement'].strip()}" 160 | item = { 161 | 'dataset': 'minif2f-isabelle', 162 | 'id': item['id'], 163 | 'messages': [ 164 | {'role': 'user', 'content': question}, 165 | {'role': 'assistant', 'content': ''} 166 | ], 167 | "answer": "placeholder" 168 | } 169 | yield item 170 | -------------------------------------------------------------------------------- /eval/math/eval/__pycache__/eval_script.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hbin0701/Self-Explore/54b8ba9f4cee1ef52f72de916e100e0da2ae6863/eval/math/eval/__pycache__/eval_script.cpython-310.pyc -------------------------------------------------------------------------------- /eval/math/eval/__pycache__/eval_script.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hbin0701/Self-Explore/54b8ba9f4cee1ef52f72de916e100e0da2ae6863/eval/math/eval/__pycache__/eval_script.cpython-311.pyc -------------------------------------------------------------------------------- /eval/math/eval/__pycache__/eval_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hbin0701/Self-Explore/54b8ba9f4cee1ef52f72de916e100e0da2ae6863/eval/math/eval/__pycache__/eval_utils.cpython-310.pyc -------------------------------------------------------------------------------- /eval/math/eval/__pycache__/eval_utils.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hbin0701/Self-Explore/54b8ba9f4cee1ef52f72de916e100e0da2ae6863/eval/math/eval/__pycache__/eval_utils.cpython-311.pyc -------------------------------------------------------------------------------- /eval/math/eval/__pycache__/ocwcourses_eval_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hbin0701/Self-Explore/54b8ba9f4cee1ef52f72de916e100e0da2ae6863/eval/math/eval/__pycache__/ocwcourses_eval_utils.cpython-310.pyc -------------------------------------------------------------------------------- /eval/math/eval/__pycache__/ocwcourses_eval_utils.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hbin0701/Self-Explore/54b8ba9f4cee1ef52f72de916e100e0da2ae6863/eval/math/eval/__pycache__/ocwcourses_eval_utils.cpython-311.pyc -------------------------------------------------------------------------------- /eval/math/eval/__pycache__/python_executor.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hbin0701/Self-Explore/54b8ba9f4cee1ef52f72de916e100e0da2ae6863/eval/math/eval/__pycache__/python_executor.cpython-310.pyc -------------------------------------------------------------------------------- /eval/math/eval/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hbin0701/Self-Explore/54b8ba9f4cee1ef52f72de916e100e0da2ae6863/eval/math/eval/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /eval/math/eval/eval_script.py: -------------------------------------------------------------------------------- 1 | import regex 2 | from copy import deepcopy 3 | from 
eval.eval_utils import math_equal 4 | from eval.ocwcourses_eval_utils import normalize_numeric, numeric_equality, normalize_symbolic_equation, SymbolicMathMixin 5 | 6 | def is_correct(item, pred_key='prediction', prec=1e-3): 7 | # print("called again ...", item) 8 | pred = item[pred_key] 9 | ans = item['answer'] 10 | if isinstance(pred, list) and isinstance(ans, list): 11 | pred_matched = set() 12 | ans_matched = set() 13 | for i in range(len(pred)): 14 | for j in range(len(ans)): 15 | item_cpy = deepcopy(item) 16 | item_cpy.update({ 17 | pred_key: pred[i], 18 | 'answer': ans[j] 19 | }) 20 | if is_correct(item_cpy, pred_key=pred_key, prec=prec): 21 | pred_matched.add(i) 22 | ans_matched.add(j) 23 | if item_cpy[pred_key] == '2,3,4': 24 | print(item, flush=True) 25 | print("wtf", flush=True) 26 | return len(pred_matched) == len(pred) and len(ans_matched) == len(ans) 27 | elif isinstance(pred, str) and isinstance(ans, str): 28 | if '\\cup' in pred and '\\cup' in ans: 29 | item = deepcopy(item) 30 | item.update({ 31 | pred_key: pred.split('\\cup'), 32 | 'answer': ans.split('\\cup'), 33 | }) 34 | return is_correct(item, pred_key=pred_key, prec=prec) 35 | else: 36 | label = False 37 | try: 38 | label = abs(float(regex.sub(r',', '', str(pred))) - float(regex.sub(r',', '', str(ans)))) < prec 39 | except: 40 | pass 41 | label = label or (ans and pred == ans) or math_equal(pred, ans) 42 | return label 43 | else: 44 | print(item, flush=True) 45 | raise NotImplementedError() 46 | 47 | def eval_math(item, pred_key='prediction', prec=1e-3): 48 | pred = item[pred_key] 49 | if pred_key == 'program_output' and isinstance(pred, str): 50 | pred = [pred] 51 | ans = item['answer'] 52 | if isinstance(pred, list) and isinstance(ans, list): 53 | # for some questions in MATH, `reference` repeats answers 54 | _ans = [] 55 | for a in ans: 56 | if a not in _ans: 57 | _ans.append(a) 58 | ans = _ans 59 | # some predictions for MATH questions also repeats answers 60 | _pred = [] 61 | for a in pred: 62 | if a not in _pred: 63 | _pred.append(a) 64 | # some predictions mistakenly box non-answer strings 65 | pred = _pred[-len(ans):] 66 | 67 | item.update({ 68 | pred_key: pred, 69 | 'answer': ans 70 | }) 71 | return is_correct(item, pred_key=pred_key, prec=prec) 72 | 73 | def eval_last_single_answer(item, pred_key='prediction', prec=1e-3): 74 | for key in [pred_key, 'answer']: 75 | assert isinstance(item[key], str), f"{key} = `{item[key]}` is not a str" 76 | return is_correct(item, pred_key=pred_key, prec=prec) 77 | 78 | def eval_agieval_gaokao_math_cloze(item, pred_key='prediction', prec=1e-3): 79 | if pred_key == 'program_output' and isinstance(item[pred_key], str): 80 | item[pred_key] = [item[pred_key]] 81 | for key in [pred_key, 'answer']: 82 | assert isinstance(item[key], list), f"{key} = `{item[key]}` is not a list" 83 | pred = item[pred_key] 84 | ans = item['answer'] 85 | _pred = [] 86 | for p in pred: 87 | p = p + ";" 88 | while p: 89 | left_brackets = 0 90 | for i in range(len(p)): 91 | if p[i] == ';' or (p[i] == ',' and left_brackets == 0): 92 | _p, p = p[:i].strip(), p[i + 1:].strip() 93 | if _p not in _pred: 94 | _pred.append(_p) 95 | break 96 | elif p[i] in '([{': 97 | left_brackets += 1 98 | elif p[i] in ')]}': 99 | left_brackets -= 1 100 | pred = _pred[-len(ans):] 101 | if len(pred) == len(ans): 102 | for p, a in zip(pred, ans): 103 | item.update({ 104 | pred_key: p, 105 | 'answer': a, 106 | }) 107 | if not is_correct(item, pred_key=pred_key, prec=prec): 108 | return False 109 | return True 110 | else: 111 | 
return False 112 | 113 | def eval_agieval_gaokao_mathqa(item, pred_key='prediction', prec=1e-3): 114 | if pred_key == 'program_output' and isinstance(item[pred_key], str): 115 | item[pred_key] = [item[pred_key]] 116 | pred_str = " ".join(item[pred_key]) 117 | ans = item['answer'] 118 | tag = None 119 | idx = -1 120 | for t in 'ABCD': 121 | if t in pred_str and pred_str.index(t) > idx: 122 | tag = t 123 | idx = pred_str.index(t) 124 | return tag == ans 125 | 126 | def eval_math_sat(item, pred_key='prediction', prec=1e-3): 127 | for key in [pred_key, 'answer']: 128 | assert isinstance(item[key], str), f"{key} = `{item[key]}` is not a str" 129 | return item[pred_key].lower() == item['answer'].lower() 130 | 131 | def eval_mmlu_stem(item, pred_key='prediction', prec=1e-3): 132 | return eval_math_sat(item, pred_key=pred_key, prec=prec) 133 | 134 | def eval_ocwcourses(item, pred_key='prediction', prec=1e-3): 135 | INVALID_ANSWER = "[invalidanswer]" 136 | for key in [pred_key, 'answer']: 137 | assert isinstance(item[key], str), f"{key} = `{item[key]}` is not a str" 138 | pred = item[pred_key] 139 | ans = item['answer'] 140 | 141 | try: 142 | float(ans) 143 | normalize_fn = normalize_numeric 144 | is_equiv = numeric_equality 145 | answer_type = "numeric" 146 | except ValueError: 147 | if "=" in ans: 148 | normalize_fn = normalize_symbolic_equation 149 | is_equiv = lambda x, y: x==y 150 | answer_type = "equation" 151 | else: 152 | normalize_fn = SymbolicMathMixin().normalize_tex 153 | is_equiv = SymbolicMathMixin().is_tex_equiv 154 | answer_type = "expression" 155 | 156 | correct_answer = normalize_fn(ans) 157 | 158 | unnormalized_answer = pred if pred else INVALID_ANSWER 159 | model_answer = normalize_fn(unnormalized_answer) 160 | 161 | if unnormalized_answer == INVALID_ANSWER: 162 | acc = 0 163 | elif model_answer == INVALID_ANSWER: 164 | acc = 0 165 | elif is_equiv(model_answer, correct_answer): 166 | acc = 1 167 | else: 168 | acc = 0 169 | 170 | return acc 171 | 172 | def eval_minif2f_isabelle(item, pred_key='prediction', prec=1e-3): 173 | return True 174 | -------------------------------------------------------------------------------- /eval/math/eval/ocwcourses_eval_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import numpy as np 3 | import sympy 4 | from sympy.core.sympify import SympifyError 5 | from sympy.parsing.latex import parse_latex 6 | 7 | import signal 8 | 9 | INVALID_ANSWER = "[invalidanswer]" 10 | 11 | class timeout: 12 | def __init__(self, seconds=1, error_message="Timeout"): 13 | self.seconds = seconds 14 | self.error_message = error_message 15 | 16 | def handle_timeout(self, signum, frame): 17 | raise TimeoutError(self.error_message) 18 | 19 | def __enter__(self): 20 | signal.signal(signal.SIGALRM, self.handle_timeout) 21 | signal.alarm(self.seconds) 22 | 23 | def __exit__(self, type, value, traceback): 24 | signal.alarm(0) 25 | 26 | def normalize_numeric(s): 27 | if s is None: 28 | return None 29 | for unit in [ 30 | "eV", 31 | " \\mathrm{~kg} \\cdot \\mathrm{m} / \\mathrm{s}", 32 | " kg m/s", 33 | "kg*m/s", 34 | "kg", 35 | "m/s", 36 | "m / s", 37 | "m s^{-1}", 38 | "\\text{ m/s}", 39 | " \\mathrm{m/s}", 40 | " \\text{ m/s}", 41 | "g/mole", 42 | "g/mol", 43 | "\\mathrm{~g}", 44 | "\\mathrm{~g} / \\mathrm{mol}", 45 | "W", 46 | "erg/s", 47 | "years", 48 | "year", 49 | "cm", 50 | ]: 51 | s = s.replace(unit, "") 52 | s = s.strip() 53 | for maybe_unit in ["m", "s", "cm"]: 54 | s = s.replace("\\mathrm{" + maybe_unit + 
"}", "") 55 | s = s.replace("\\mathrm{~" + maybe_unit + "}", "") 56 | s = s.strip() 57 | s = s.strip("$") 58 | try: 59 | return float(eval(s)) 60 | except: 61 | try: 62 | expr = parse_latex(s) 63 | if expr.is_number: 64 | return float(expr) 65 | return INVALID_ANSWER 66 | except: 67 | return INVALID_ANSWER 68 | 69 | def numeric_equality(n1, n2, threshold=0.01): 70 | if n1 is None or n2 is None: 71 | return False 72 | if np.isclose(n1, 0) or np.isclose(n2, 0) or np.isclose(n1 - n2, 0): 73 | return np.abs(n1 - n2) < threshold * (n1 + n2) / 2 74 | else: 75 | return np.isclose(n1, n2) 76 | 77 | def normalize_symbolic_equation(s): 78 | if not isinstance(s, str): 79 | return INVALID_ANSWER 80 | if s.startswith("\\["): 81 | s = s[2:] 82 | if s.endswith("\\]"): 83 | s = s[:-2] 84 | s = s.replace("\\left(", "(") 85 | s = s.replace("\\right)", ")") 86 | s = s.replace("\\\\", "\\") 87 | if s.startswith("$") or s.endswith("$"): 88 | s = s.strip("$") 89 | try: 90 | maybe_expression = parse_latex(s) 91 | if not isinstance(maybe_expression, sympy.core.relational.Equality): 92 | # we have equation, not expression 93 | return INVALID_ANSWER 94 | else: 95 | return maybe_expression 96 | except: 97 | return INVALID_ANSWER 98 | 99 | class SymbolicMathMixin: 100 | """ 101 | Methods useful for parsing mathematical expressions from text and determining equivalence of expressions. 102 | """ 103 | 104 | SUBSTITUTIONS = [ # used for text normalize 105 | ("an ", ""), 106 | ("a ", ""), 107 | (".$", "$"), 108 | ("\\$", ""), 109 | (r"\ ", ""), 110 | (" ", ""), 111 | ("mbox", "text"), 112 | (",\\text{and}", ","), 113 | ("\\text{and}", ","), 114 | ("\\text{m}", "\\text{}"), 115 | ] 116 | REMOVED_EXPRESSIONS = [ # used for text normalizer 117 | "square", 118 | "ways", 119 | "integers", 120 | "dollars", 121 | "mph", 122 | "inches", 123 | "ft", 124 | "hours", 125 | "km", 126 | "units", 127 | "\\ldots", 128 | "sue", 129 | "points", 130 | "feet", 131 | "minutes", 132 | "digits", 133 | "cents", 134 | "degrees", 135 | "cm", 136 | "gm", 137 | "pounds", 138 | "meters", 139 | "meals", 140 | "edges", 141 | "students", 142 | "childrentickets", 143 | "multiples", 144 | "\\text{s}", 145 | "\\text{.}", 146 | "\\text{\ns}", 147 | "\\text{}^2", 148 | "\\text{}^3", 149 | "\\text{\n}", 150 | "\\text{}", 151 | r"\mathrm{th}", 152 | r"^\circ", 153 | r"^{\circ}", 154 | r"\;", 155 | r",\!", 156 | "{,}", 157 | '"', 158 | "\\dots", 159 | ] 160 | 161 | def normalize_tex(self, final_answer: str) -> str: 162 | """ 163 | Normalizes a string representing a mathematical expression. 164 | Used as a preprocessing step before parsing methods. 165 | 166 | Copied character for character from appendix D of Lewkowycz et al. (2022) 167 | """ 168 | final_answer = final_answer.split("=")[-1] 169 | 170 | for before, after in self.SUBSTITUTIONS: 171 | final_answer = final_answer.replace(before, after) 172 | for expr in self.REMOVED_EXPRESSIONS: 173 | final_answer = final_answer.replace(expr, "") 174 | 175 | # Extract answer that is in LaTeX math, is bold, 176 | # is surrounded by a box, etc. 
177 | final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer) 178 | final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer) 179 | final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer) 180 | final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer) 181 | final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer) 182 | 183 | # Normalize shorthand TeX: 184 | # \fracab -> \frac{a}{b} 185 | # \frac{abc}{bef} -> \frac{abc}{bef} 186 | # \fracabc -> \frac{a}{b}c 187 | # \sqrta -> \sqrt{a} 188 | # \sqrtab -> sqrt{a}b 189 | final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer) 190 | final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer) 191 | final_answer = final_answer.replace("$", "") 192 | 193 | # Normalize 100,000 -> 100000 194 | if final_answer.replace(",", "").isdigit(): 195 | final_answer = final_answer.replace(",", "") 196 | 197 | return final_answer 198 | 199 | def parse_tex(self, text: str, time_limit: int = 5) -> sympy.Basic: 200 | """ 201 | Wrapper around `sympy.parse_text` that outputs a SymPy expression. 202 | Typically, you want to apply `normalize_text` as a preprocessing step. 203 | """ 204 | try: 205 | with timeout(seconds=time_limit): 206 | parsed = parse_latex(text) 207 | except ( 208 | # general error handling: there is a long tail of possible sympy/other 209 | # errors we would like to catch 210 | Exception 211 | ) as e: 212 | print(f"failed to parse {text} with exception {e}") 213 | return None 214 | 215 | return parsed 216 | 217 | def is_exp_equiv(self, x1: sympy.Basic, x2: sympy.Basic, time_limit=5) -> bool: 218 | """ 219 | Determines whether two sympy expressions are equal. 220 | """ 221 | try: 222 | with timeout(seconds=time_limit): 223 | try: 224 | diff = x1 - x2 225 | except (SympifyError, ValueError, TypeError) as e: 226 | print( 227 | f"Couldn't subtract {x1} and {x2} with exception {e}" 228 | ) 229 | return False 230 | 231 | try: 232 | if sympy.simplify(diff) == 0: 233 | return True 234 | else: 235 | return False 236 | except (SympifyError, ValueError, TypeError) as e: 237 | print(f"Failed to simplify {x1}-{x2} with {e}") 238 | return False 239 | except TimeoutError as e: 240 | print(f"Timed out comparing {x1} and {x2}") 241 | return False 242 | except Exception as e: 243 | print(f"failed on unrecognized exception {e}") 244 | return False 245 | 246 | def is_tex_equiv(self, x1: str, x2: str, time_limit=5) -> bool: 247 | """ 248 | Determines whether two (ideally normalized using `normalize_text`) TeX expressions are equal. 249 | 250 | Does so by first checking for string exact-match, then falls back on sympy-equivalence, 251 | following the (Lewkowycz et al. 2022) methodology. 252 | """ 253 | if x1 == x2: 254 | # don't resort to sympy if we have full string match, post-normalization 255 | return True 256 | else: 257 | return False 258 | parsed_x2 = self.parse_tex(x2) 259 | if not parsed_x2: 260 | # if our reference fails to parse into a Sympy object, 261 | # we forgo parsing + checking our generated answer. 
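            # NOTE: both branches of the `if x1 == x2:` check above already return, so this
            # sympy fallback (parse_tex + is_exp_equiv) is unreachable as written; the function
            # effectively reduces to an exact string match after normalization.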
262 | return False 263 | return self.is_exp_equiv(self.parse_tex(x1), parsed_x2, time_limit=time_limit) 264 | -------------------------------------------------------------------------------- /eval/math/eval/python_executor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import io 3 | from contextlib import redirect_stdout 4 | import pickle 5 | import regex 6 | import copy 7 | from typing import Any, Dict, Optional 8 | import multiprocess 9 | from pebble import ProcessPool 10 | from concurrent.futures import TimeoutError 11 | from functools import partial 12 | import traceback 13 | from timeout_decorator import timeout 14 | 15 | class GenericRuntime: 16 | GLOBAL_DICT = {} 17 | LOCAL_DICT = None 18 | HEADERS = [] 19 | def __init__(self): 20 | self._global_vars = copy.copy(self.GLOBAL_DICT) 21 | self._local_vars = copy.copy(self.LOCAL_DICT) if self.LOCAL_DICT else None 22 | 23 | for c in self.HEADERS: 24 | self.exec_code(c) 25 | 26 | def exec_code(self, code_piece: str) -> None: 27 | if regex.search(r'(\s|^)?input\(', code_piece) or regex.search(r'(\s|^)?os.system\(', code_piece): 28 | raise RuntimeError() 29 | exec(code_piece, self._global_vars) 30 | 31 | def eval_code(self, expr: str) -> Any: 32 | return eval(expr, self._global_vars) 33 | 34 | def inject(self, var_dict: Dict[str, Any]) -> None: 35 | for k, v in var_dict.items(): 36 | self._global_vars[k] = v 37 | 38 | @property 39 | def answer(self): 40 | return self._global_vars['answer'] 41 | 42 | class PythonExecutor: 43 | def __init__( 44 | self, 45 | runtime: Optional[Any] = None, 46 | get_answer_symbol: Optional[str] = None, 47 | get_answer_expr: Optional[str] = None, 48 | get_answer_from_stdout: bool = False, 49 | ) -> None: 50 | self.runtime = runtime if runtime else GenericRuntime() 51 | self.answer_symbol = get_answer_symbol 52 | self.answer_expr = get_answer_expr 53 | self.get_answer_from_stdout = get_answer_from_stdout 54 | 55 | def process_generation_to_code(self, gens: str): 56 | batch_code = [] 57 | for g in gens: 58 | multiline_comments = False 59 | code = [] 60 | for line in g.split('\n'): 61 | strip_line = line.strip() 62 | if strip_line.startswith("#"): 63 | line = line.split("#", 1)[0] + "# comments" 64 | elif not multiline_comments and strip_line.startswith('"""') and strip_line.endswith('"""') and len(strip_line) >= 6: 65 | line = line.split('"""', 1)[0] + '"""comments"""' 66 | elif not multiline_comments and strip_line.startswith('"""'): 67 | multiline_comments = True 68 | elif multiline_comments and strip_line.endswith('"""'): 69 | multiline_comments = False 70 | line = "" 71 | if not multiline_comments: 72 | code.append(line) 73 | batch_code.append(code) 74 | return batch_code 75 | 76 | @staticmethod 77 | def execute( 78 | code, 79 | get_answer_from_stdout = None, 80 | runtime = None, 81 | answer_symbol = None, 82 | answer_expr = None, 83 | timeout_length = 10, 84 | ): 85 | try: 86 | if get_answer_from_stdout: 87 | program_io = io.StringIO() 88 | with redirect_stdout(program_io): 89 | timeout(timeout_length)(runtime.exec_code)('\n'.join(code)) 90 | program_io.seek(0) 91 | result = "".join(program_io.readlines()) # [-1] 92 | elif answer_symbol: 93 | timeout(timeout_length)(runtime.exec_code)('\n'.join(code)) 94 | result = runtime._global_vars[answer_symbol] 95 | elif answer_expr: 96 | timeout(timeout_length)(runtime.exec_code)('\n'.join(code)) 97 | result = timeout(timeout_length)(runtime.eval_code)(answer_expr) 98 | else: 99 | 
timeout(timeout_length)(runtime.exec_code)('\n'.join(code[:-1])) 100 | result = timeout(timeout_length)(runtime.eval_code)(code[-1]) 101 | concise_exec_info = "" 102 | exec_info = "" 103 | str(result) 104 | pickle.dumps(result) # serialization check 105 | except: 106 | # traceback.print_exc() 107 | result = '' 108 | concise_exec_info = traceback.format_exc().split('\n')[-2] 109 | exec_info = traceback.format_exc() 110 | if get_answer_from_stdout and 'exec(code_piece, self._global_vars)' in exec_info: 111 | exec_info = exec_info.split('exec(code_piece, self._global_vars)')[-1].strip() 112 | msg = [] 113 | for line in exec_info.split("\n"): 114 | patt = regex.search(r'(?P.*)File "(?P.*)", line (?P\d+), (?P.*)', line) 115 | if patt is not None: 116 | if '' in patt.group('end'): 117 | continue 118 | fname = patt.group("file") 119 | if "site-packages" in fname: 120 | fname = f"site-packages{fname.split('site-packages', 1)[1]}" 121 | line = f'{patt.group("start")}File "{fname}", {patt.group("end")}' 122 | else: 123 | line = f'{patt.group("start")}{patt.group("end")}' 124 | else: 125 | patt = regex.search(r'(?P.*)(?P/.*site-packages/.*\.py)(?P.*)', line) 126 | if patt is not None: 127 | line = f'{patt.group("start")}site-packages{patt.group("file").split("site-packages", 1)[1]}{patt.group("end")}' 128 | msg.append(line) 129 | exec_info = "\n".join(msg) 130 | return result, concise_exec_info, exec_info 131 | 132 | def apply(self, code): 133 | return self.batch_apply([code])[0] 134 | 135 | def batch_apply(self, batch_code): 136 | all_code_snippets = self.process_generation_to_code(batch_code) 137 | all_exec_results = [] 138 | executor = partial( 139 | self.execute, 140 | get_answer_from_stdout=self.get_answer_from_stdout, 141 | runtime=self.runtime, 142 | answer_symbol=self.answer_symbol, 143 | answer_expr=self.answer_expr, 144 | timeout_length=10, 145 | ) 146 | with ProcessPool(max_workers=multiprocess.cpu_count()) as pool: 147 | iterator = pool.map(executor, all_code_snippets, timeout=10).result() 148 | 149 | while True: 150 | try: 151 | result = next(iterator) 152 | all_exec_results.append(result) 153 | except StopIteration: 154 | break 155 | except TimeoutError as error: 156 | all_exec_results.append(("", "Timeout Error", "Timeout Error")) 157 | except Exception as error: 158 | print(error) 159 | exit() 160 | 161 | batch_results = [] 162 | for code, (result, concise_exec_info, exec_info) in zip(all_code_snippets, all_exec_results): 163 | metadata = {'code': code, 'exec_result': result, 'concise_exec_info': concise_exec_info, 'exec_info': exec_info} 164 | batch_results.append((result, metadata)) 165 | return batch_results 166 | -------------------------------------------------------------------------------- /eval/math/run_cot_eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import sys 5 | import_path = os.path.abspath(__file__) 6 | for _ in range(2): 7 | import_path = os.path.dirname(import_path) 8 | sys.path.append(import_path) 9 | 10 | import numpy 11 | from tqdm import tqdm 12 | import json 13 | from copy import deepcopy 14 | from vllm import LLM, SamplingParams 15 | from pebble import ProcessPool 16 | from concurrent.futures import TimeoutError 17 | import random 18 | from eval.utils import generate_completions, load_hf_lm_and_tokenizer 19 | from transformers import AutoTokenizer 20 | from data_processing.answer_extraction import * 21 | from eval.eval_script import * 22 | 23 | def evaluate(eval_fn, tasks, 
_timeout=15): 24 | with ProcessPool() as pool: 25 | timeout_cnt = 0 26 | iterator = pool.map(eval_fn, tasks, timeout=_timeout).result() 27 | labels = [] 28 | while True: 29 | try: 30 | labels.append(int(next(iterator))) 31 | except StopIteration: 32 | break 33 | except TimeoutError as error: 34 | labels.append(0) 35 | timeout_cnt += 1 36 | except Exception as error: 37 | print(error.traceback, flush=True) 38 | exit() 39 | return labels, timeout_cnt 40 | 41 | def infer(args, test_data): 42 | global tokenizer 43 | if tokenizer is None: 44 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name_or_path, trust_remote_code=True) 45 | 46 | prompts = [] 47 | for example in test_data: 48 | prompt = example['messages'][-2]['content'] + "\n" # normally 49 | prompts.append(prompt.lstrip()) 50 | 51 | global model 52 | print("Loading model and tokenizer...") 53 | if args.use_vllm: 54 | if model is None: 55 | model = LLM(model=args.model_name_or_path, tokenizer=args.tokenizer_name_or_path, trust_remote_code=True, tensor_parallel_size=len(os.environ['CUDA_VISIBLE_DEVICES'].split(","))) 56 | eos_token = tokenizer.eos_token if tokenizer is not None and tokenizer.eos_token is not None else '' 57 | stop_words = [eos_token] 58 | 59 | outputs = model.generate(prompts, SamplingParams(temperature=args.temperature, top_p=1.0, max_tokens=1024, n=1, stop=stop_words)) 60 | outputs = sorted(outputs, key=lambda x: int(x.request_id)) # sort outputs by request_id 61 | outputs = [output.outputs[0].text for output in outputs] 62 | else: 63 | model, tokenizer = load_hf_lm_and_tokenizer( 64 | model_name_or_path=args.model_name_or_path, 65 | tokenizer_name_or_path=args.tokenizer_name_or_path, 66 | load_in_8bit=args.load_in_8bit, 67 | load_in_half=args.load_in_half, 68 | gptq_model=args.gptq 69 | ) 70 | 71 | stop_id_sequences = [] 72 | if tokenizer.eos_token_id is not None: 73 | stop_id_sequences = [[tokenizer.eos_token_id]] 74 | 75 | outputs, finish_completion = generate_completions( 76 | model=model, 77 | tokenizer=tokenizer, 78 | prompts=prompts, 79 | max_new_tokens=512, 80 | batch_size=args.eval_batch_size, 81 | stop_id_sequences=stop_id_sequences if stop_id_sequences else None, 82 | end_of_generation_id_sequence=[tokenizer.eos_token_id] if tokenizer.eos_token_id is not None else None 83 | ) 84 | 85 | if args.complete_partial_output: 86 | model_outputs = [example['messages'][-1]['content'] + output for example, output in zip(test_data, outputs)] 87 | else: 88 | model_outputs = outputs 89 | 90 | predictions = [eval(args.answer_extraction_fn)(item['messages'][-2]['content'], output, task='cot') for item, output in tqdm(zip(test_data, model_outputs), desc="extract answer", total=len(model_outputs))] 91 | assert len(model_outputs) > 0, f"{len(model_outputs)}" 92 | 93 | results = [] 94 | for example, output, pred in zip(test_data, model_outputs, predictions): 95 | item = deepcopy(example) 96 | item.update({ 97 | 'model_output': output, 98 | 'prediction': pred, 99 | }) 100 | results.append(item) 101 | 102 | print("Returning!") 103 | return results 104 | 105 | def main(args): 106 | random.seed(42) 107 | 108 | print("Loading data...") 109 | test_data = [] 110 | with open(os.path.join(args.data_dir, f"train.jsonl" if args.infer_train_set else f"test.jsonl")) as fin: 111 | for line in fin: 112 | example = json.loads(line) 113 | messages = example['messages'] 114 | assert messages[-1]['role'] == 'assistant' 115 | if not args.complete_partial_output: 116 | example['reference'] = example.get('reference', '') or 
[mess['content'] for mess in messages if mess['role'] == 'assistant'] 117 | for mess in messages: 118 | if mess['role'] == 'assistant': 119 | mess['content'] = '' 120 | example['messages'] = messages 121 | test_data.append(example) 122 | 123 | if args.max_num_examples and len(test_data) > args.max_num_examples: 124 | test_data = random.sample(test_data, args.max_num_examples) 125 | 126 | if args.n_subsets > 1: 127 | assert args.subset_id >= 0 and args.subset_id < args.n_subsets 128 | test_data = [item for i, item in enumerate(test_data) if i % args.n_subsets == args.subset_id] 129 | 130 | if not test_data: 131 | return 132 | 133 | if not os.path.exists(args.save_dir): 134 | os.makedirs(args.save_dir, exist_ok=True) 135 | 136 | results = infer(args, test_data) 137 | 138 | import jsonlines 139 | with jsonlines.open(f"{args.save_dir}/results_{args.subset_id}.jsonl", "a") as writer: 140 | writer.write_all(results) 141 | 142 | if __name__ == "__main__": 143 | parser = argparse.ArgumentParser() 144 | parser.add_argument("--data_dir", type=str, default="data/mgsm") 145 | parser.add_argument("--max_num_examples", type=int, default=None, help="maximum number of examples to evaluate.") 146 | parser.add_argument("--save_dir", type=str, default="results/mgsm") 147 | parser.add_argument("--model_name_or_path", type=str, default=None, help="if specified, we will load the model to generate the predictions.") 148 | parser.add_argument("--tokenizer_name_or_path", type=str, default=None, help="if specified, we will load the tokenizer from here.") 149 | parser.add_argument("--eval_batch_size", type=int, default=1, help="batch size for evaluation.") 150 | parser.add_argument("--load_in_8bit", action="store_true", help="load model in 8bit mode, which will reduce memory and speed up inference.") 151 | parser.add_argument("--gptq", action="store_true", help="If given, we're evaluating a 4-bit quantized GPTQ model.") 152 | parser.add_argument("--use_vllm", action="store_true") 153 | parser.add_argument("--load_in_half", action='store_true') 154 | parser.add_argument("--infer_train_set", action="store_true") 155 | parser.add_argument("--n_subsets", type=int, default=1) 156 | parser.add_argument("--subset_id", type=int, default=0) 157 | parser.add_argument("--temperature", type=float, default=0.0) 158 | parser.add_argument("--repeat_id_start", type=int, default=0) 159 | parser.add_argument("--n_repeat_sampling", type=int, default=1) 160 | parser.add_argument("--complete_partial_output", action='store_true') 161 | parser.add_argument("--prompt_format", type=str, choices=['sft', 'few_shot'], default='sft') 162 | parser.add_argument("--few_shot_prompt", type=str, default=None) 163 | parser.add_argument("--answer_extraction_fn", type=str, required=True) 164 | parser.add_argument("--eval_fn", type=str, required=True) 165 | parser.add_argument("--gpus", type=str, default=None) 166 | args, unparsed_args = parser.parse_known_args() 167 | if args.gpus is not None: 168 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus 169 | 170 | print(unparsed_args, flush=True) 171 | 172 | if 'math6' in args.data_dir: 173 | args.multi_turn = True 174 | 175 | # model_name_or_path cannot be both None or both not None. 
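    # `model` and `tokenizer` are module-level globals: infer() loads them only while they are
    # still None, so repeated-sampling runs (n_repeat_sampling > 1) reuse the same vLLM engine
    # across trials instead of reloading it for every call to main().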
176 | model = None 177 | tokenizer = None 178 | pool = None 179 | if args.n_repeat_sampling > 1 or args.repeat_id_start != 0: 180 | assert args.temperature > 0 181 | save_dir = args.save_dir 182 | for i in range(args.repeat_id_start, args.repeat_id_start + args.n_repeat_sampling): 183 | print(f"working on the {i} trials ...", flush=True) 184 | args.save_dir = os.path.join(save_dir, str(i)) 185 | os.makedirs(args.save_dir, exist_ok=True) 186 | main(args) 187 | else: 188 | main(args) 189 | 190 | if pool is not None: 191 | pool.close() 192 | -------------------------------------------------------------------------------- /eval/math/run_math_eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | output_dir="outputs/DeepSeek_MATH_Self_Explore" 4 | model_path="" 5 | tokenizer_path="" 6 | model_size="7b" 7 | use_vllm="--use-vllm" 8 | no_markup_question="--no-markup-question" 9 | test_conf="configs/few_shot_test_configs.json" # While we use this config, our code doesn't evaluate in few-shot setting. 10 | prompt_format="few_shot" 11 | n_repeats="1" 12 | temperature="0" 13 | ngpus="4" 14 | rank="0" 15 | 16 | python run_subset_parallel.py --output-dir $output_dir \ 17 | --model-path $model_path \ 18 | --tokenizer-path $tokenizer_path \ 19 | --model-size $model_size \ 20 | $use_vllm \ 21 | $no_markup_question \ 22 | --test-conf $test_conf \ 23 | --prompt_format $prompt_format \ 24 | --n-repeats $n_repeats \ 25 | --temperature $temperature \ 26 | --ngpus $ngpus \ 27 | --rank $rank 28 | -------------------------------------------------------------------------------- /eval/math/run_subset_parallel.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from tqdm import tqdm 4 | from glob import glob 5 | import time 6 | import json 7 | import subprocess 8 | import numpy 9 | from utils import read_data 10 | from eval.eval_script import eval_math 11 | from data_processing.process_utils import * 12 | 13 | _worker_num = int(os.environ.get('WORLD_SIZE', 1)) 14 | _worker_id = int(os.environ.get('RANK', 0)) 15 | 16 | def markup_question(args, item, language, src, task): 17 | for i in range(len(item['messages']) - 2, -1, -2): 18 | if language == 'zh': 19 | if task == 'cot': 20 | item['messages'][i]['content'] = f"{item['messages'][i]['content']}\n请通过逐步推理来解答问题,并把最终答案放置于" + "\\boxed{}中。" 21 | elif task == 'tool': 22 | item['messages'][i]['content'] = f"{item['messages'][i]['content']}\n请结合自然语言和Python程序语言来解答问题,并把最终答案放置于" + "\\boxed{}中。" 23 | else: 24 | pass 25 | elif language == 'en': 26 | if task == 'cot': 27 | item['messages'][i]['content'] = f"{item['messages'][i]['content']}\nPlease reason step by step, and put your final answer within " + "\\boxed{}." 28 | elif task == 'tool': 29 | item['messages'][i]['content'] = f"{item['messages'][i]['content']}\nPlease integrate natural language reasoning with programs to solve the problem above, and put your final answer within " + "\\boxed{}." 
30 | else: 31 | pass 32 | return item 33 | 34 | def do_parallel_sampling(args, task, answer_extraction_fn, eval_fn, input_dir, output_dir, log_dir): 35 | if task == 'pal': 36 | code_fname = "run_pal_eval" 37 | elif task == 'cot': 38 | code_fname = "run_cot_eval" 39 | elif task == 'tool': 40 | code_fname = "run_tool_integrated_eval" 41 | else: 42 | raise NotImplementedError() 43 | 44 | n_procs = args.ngpus // args.ngpus_per_model 45 | n_procs = 4 # temporary 46 | 47 | gpus = [str(i) for i in range(args.ngpus)] 48 | gpu_groups = [] 49 | for i in range(n_procs): 50 | gpu_groups.append(gpus[i * args.ngpus_per_model: (i + 1) * args.ngpus_per_model]) 51 | 52 | global_n_procs = n_procs * _worker_num 53 | 54 | procs = [] 55 | for pid, gpus in enumerate(gpu_groups): 56 | global_pid = n_procs * (args.rank or _worker_id) + pid 57 | logpath = os.path.join(log_dir, f"{global_pid}.log") 58 | f = open(logpath, "w") 59 | cmd = f"python {code_fname}.py " \ 60 | f"--data_dir {input_dir} " \ 61 | f"--max_num_examples 100000000000000 " \ 62 | f"--save_dir {output_dir} " \ 63 | f"--model {args.model_path} " \ 64 | f"--tokenizer {args.tokenizer_path or args.model_path} " \ 65 | f"--eval_batch_size 1 " \ 66 | f"--temperature {args.temperature} " \ 67 | f"--repeat_id_start 0 " \ 68 | f"--n_repeat_sampling {args.n_repeats} " \ 69 | f"--n_subsets {global_n_procs} " \ 70 | f"--prompt_format {args.prompt_format} " \ 71 | f"--few_shot_prompt {args.few_shot_prompt} " \ 72 | f"--answer_extraction_fn {answer_extraction_fn} " \ 73 | f"--eval_fn {eval_fn} " \ 74 | f"--subset_id {global_pid} " \ 75 | f"--gpus {','.join(gpus)} " 76 | if args.use_vllm: 77 | cmd += " --use_vllm " 78 | if args.load_in_half: 79 | cmd += " --load_in_half " 80 | 81 | local_metric_path = os.path.join(output_dir, f"metrics.{global_pid}.json") 82 | if not args.overwrite and os.path.exists(local_metric_path) and read_data(local_metric_path)['n_samples'] > 0: 83 | continue 84 | 85 | print("LOGPATH", logpath) 86 | procs.append((global_pid, subprocess.Popen(cmd.split(), stdout=f, stderr=f), f)) 87 | 88 | for (global_pid, proc, f) in procs: 89 | print(f"Waiting for the {global_pid}th process to finish ...", flush=True) 90 | proc.wait() 91 | 92 | for (global_pid, proc, f) in procs: 93 | print(f"Closing the {global_pid}th process ...", flush=True) 94 | f.close() 95 | 96 | ### ADDED 97 | agg_li = [] 98 | for i in range(n_procs): 99 | agg_li += [json.loads(x) for x in open(f"{output_dir.strip('/')}/results_{i}.jsonl")] 100 | 101 | import jsonlines 102 | with jsonlines.open(f"{output_dir.strip('/')}/results_combined.jsonl", "w") as writer: 103 | writer.write_all(agg_li) 104 | 105 | all_labels = [] 106 | for item in agg_li: 107 | all_labels.append(eval_math(item)) 108 | 109 | with open(os.path.join(output_dir.strip('/'), "summary.json"), "w") as fout: 110 | json.dump({ 111 | "n_samples": len(all_labels), 112 | "accuracy": all_labels.count(True) / len(all_labels) 113 | }, fout, indent=4) 114 | 115 | 116 | # labels, eval_timeout_cnt = evaluate(eval(args.eval_fn), results) 117 | # for item, label in zip(results, labels): 118 | # item['accuracy'] = label 119 | 120 | # print("Calculating accuracy...") 121 | # acc = 0 122 | # for item in results: 123 | # acc += item['accuracy'] 124 | # print("output acc = {:.5f}".format(acc / len(results) * 100), flush=True) 125 | 126 | # print(f"Timeout count >>> output eval = {eval_timeout_cnt}", flush=True) 127 | 128 | # pred_fname = "predictions.json" 129 | # if args.n_subsets > 1: 130 | # pred_fname = 
f"predictions.{args.subset_id}.json" 131 | # with open(os.path.join(args.save_dir, pred_fname), "w") as fout: 132 | # json.dump(results, fout, ensure_ascii=True) 133 | 134 | # metric_fname = "metrics.json" 135 | # if args.n_subsets > 1: 136 | # metric_fname = f"metrics.{args.subset_id}.json" 137 | # with open(os.path.join(args.save_dir, metric_fname), "w") as fout: 138 | # json.dump({ 139 | # "n_samples": len(results), 140 | # "accuracy": sum(item['accuracy'] for item in results) / len(results), 141 | # }, fout, indent=4) 142 | 143 | 144 | 145 | 146 | ### 147 | 148 | time.sleep(1) 149 | 150 | local_pids = [global_pid for (global_pid, _, _) in procs] 151 | 152 | agg_preds = [] 153 | for fname in glob(os.path.join(output_dir, "predictions.*.json")): 154 | if any(str(pid) in fname for pid in local_pids): 155 | agg_preds.extend(read_data(fname)) 156 | 157 | metrics = {} 158 | n_samples = 0 159 | for fname in glob(os.path.join(output_dir, "metrics.*.json")): 160 | if not any(str(pid) in fname for pid in local_pids): 161 | continue 162 | _metrics = read_data(fname) 163 | n_samples += _metrics['n_samples'] 164 | for key, val in _metrics.items(): 165 | if key != 'n_samples': 166 | metrics[key] = metrics.get(key, 0) + val * _metrics['n_samples'] 167 | for key, val in metrics.items(): 168 | metrics[key] = val / max(n_samples, 1) 169 | 170 | result_msg = f"n samples = {n_samples}" 171 | for key, val in metrics.items(): 172 | result_msg += f"\n{key} = {val * 100}" 173 | 174 | metrics['n_samples'] = n_samples 175 | 176 | return metrics, agg_preds, result_msg 177 | 178 | def main(): 179 | parser = argparse.ArgumentParser() 180 | parser.add_argument("--output-dir", type=str, required=True, help="default to `model_path`_predictions") 181 | parser.add_argument("--model-path", type=str, required=True) 182 | parser.add_argument("--tokenizer-path", type=str, default=None) 183 | parser.add_argument("--model-size", type=str, choices=['1b', '7b', '13b', '33b', '34b', '70b'], default="7b") 184 | 185 | parser.add_argument("--test-conf", type=str, default="configs/zero_shot_test_configs.json", help="path to testing data config file that maps from a source to its info") 186 | parser.add_argument("--ngpus", type=int, default=8) 187 | parser.add_argument("--overwrite", action='store_true') 188 | parser.add_argument("--temperature", type=float, default=0) 189 | parser.add_argument("--n-repeats", type=int, default=1) 190 | parser.add_argument("--use-vllm", action='store_true') 191 | parser.add_argument("--load_in_half", action='store_true') 192 | 193 | parser.add_argument("--prompt_format", type=str, default="sft") 194 | parser.add_argument("--few_shot_prompt", type=str, default=None) 195 | 196 | parser.add_argument("--no-markup-question", action='store_true') 197 | 198 | parser.add_argument("--rank", type=int, default=None) 199 | parser.add_argument("--seed", type=int, default=42) 200 | args, _ = parser.parse_known_args() 201 | 202 | print(f"Evaluating {args.model_path}", flush=True) 203 | 204 | if args.output_dir is None: 205 | args.output_dir = f"{args.model_path.rstrip('/')}_predictions" 206 | 207 | args.ngpus_per_model = 4 if args.model_size in ['70b', '33b', '34b'] else 1 208 | assert args.ngpus % args.ngpus_per_model == 0 209 | 210 | default_few_shot_prompt = args.few_shot_prompt 211 | 212 | test_conf = read_data(args.test_conf) 213 | 214 | for src, info in test_conf.items(): 215 | if args.n_repeats > 1: 216 | _src = f"{src}/sample_logs" 217 | else: 218 | _src = f"{src}/infer_logs" 219 | if _worker_num > 1: 220 | 
_src = f"{_src}/{args.rank or _worker_id}" 221 | for task in info['tasks']: 222 | fname = os.path.join(args.output_dir, _src, task, "test_data", "test.jsonl") 223 | input_dir = os.path.dirname(fname) 224 | os.makedirs(input_dir, exist_ok=True) 225 | metric_path = os.path.join(args.output_dir, _src, task, "samples", "metrics.json") 226 | if not args.overwrite and os.path.exists(metric_path) and read_data(metric_path)['n_samples'] > 0: 227 | continue 228 | with open(fname, "w") as file: 229 | data = read_data(info['test_path']) 230 | for i, sample in enumerate(tqdm(data, desc=f'processing {src}')): 231 | fn = eval(info['process_fn']) 232 | sample['id'] = sample.get('id', f"{src}-{i}") 233 | for j, item in enumerate(fn(sample)): 234 | item['dataset'] = src 235 | item['id'] = f"{src}-test-{i}-{j}" 236 | assert 'answer' in item 237 | if not args.no_markup_question: 238 | item = markup_question(args, item, info['language'], src, task) 239 | print(json.dumps(item), file=file, flush=True) 240 | 241 | output_dir = os.path.join(args.output_dir, _src, task, "samples") 242 | log_dir = os.path.join(args.output_dir, _src, task, "logs") 243 | os.makedirs(output_dir, exist_ok=True) 244 | os.makedirs(log_dir, exist_ok=True) 245 | metrics, agg_preds, result_msg = do_parallel_sampling(args, task, info['answer_extraction_fn'], info['eval_fn'], input_dir, output_dir, log_dir) 246 | 247 | os.makedirs(os.path.dirname(metric_path), exist_ok=True) 248 | json.dump(metrics, open(metric_path, "w"), indent=4) 249 | data_path = os.path.join(args.output_dir, _src, task, "samples", "predictions.json") 250 | os.makedirs(os.path.dirname(data_path), exist_ok=True) 251 | with open(data_path, "w") as file: 252 | json.dump(agg_preds, file, ensure_ascii=False) 253 | print(f"src = {src} | task = {task} >>>\n{result_msg}\n\n", flush=True) 254 | 255 | if __name__ == '__main__': 256 | main() 257 | -------------------------------------------------------------------------------- /eval/math/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | import numpy as np 4 | 5 | def set_seed(seed): 6 | if seed > 0: 7 | random.seed(seed) 8 | np.random.seed(seed) 9 | 10 | def shuffle(data, seed): 11 | if seed < 0: 12 | return data 13 | set_seed(seed) 14 | indices = list(range(len(data))) 15 | np.random.shuffle(indices) 16 | data = [data[i] for i in indices] 17 | return data 18 | 19 | def read_data(path): 20 | if path.endswith("json"): 21 | data = json.load(open(path, "r")) 22 | elif path.endswith("jsonl"): 23 | data = [] 24 | with open(path, "r") as file: 25 | for line in file: 26 | line = json.loads(line) 27 | data.append(line) 28 | else: 29 | raise NotImplementedError() 30 | return data 31 | -------------------------------------------------------------------------------- /gen/gen_rft_data.py: -------------------------------------------------------------------------------- 1 | from huggingface_hub import login 2 | import argparse 3 | import json 4 | import re 5 | import jsonlines 6 | from vllm import LLM, SamplingParams 7 | import sys 8 | from tqdm.auto import tqdm 9 | 10 | def generate(model, data_path, tensor_parallel_size=1, temp=0.7, id=0, task="GSM8K"): 11 | inputs = [] 12 | answers = [] 13 | 14 | # Assume you are using 4 GPUs. for the parts that take longer, assign little less than others. 
15 | 16 | if task == "GSM8K": 17 | if id == 0: 18 | start, end = 0, 1870 19 | elif id == 1: 20 | start, end = 1870, 3740 21 | elif id == 2: 22 | start, end = 3740, 5610 23 | elif id == 3: 24 | start, end = 5610, 7473 25 | 26 | elif task == "MATH": 27 | if id == 0: 28 | start, end = 0, 1750 29 | elif id == 1: 30 | start, end = 1750, 3750 31 | elif id == 2: 32 | start, end = 3750, 5750 33 | elif id == 3: 34 | start, end = 5750, 7500 35 | 36 | with open(data_path,"r+", encoding="utf8") as f: 37 | for idx, item in enumerate(jsonlines.Reader(f)): 38 | 39 | if not (start <= idx < end): 40 | continue 41 | 42 | inputs.append(item["query"].strip() + "\n") 43 | 44 | temp_ans = item['response'] # just add the response. We will worry about it later. 45 | answers.append(temp_ans) 46 | 47 | result_file = args.result_file.replace(".jsonl", f"_{id}.jsonl") 48 | 49 | try: 50 | already_done = [json.loads(x) for x in open(result_file)] 51 | except: 52 | already_done = [] 53 | 54 | if len(already_done) != 0: 55 | inputs = inputs[len(already_done):] 56 | answers = answers[len(already_done):] 57 | 58 | if len(inputs) == 0 and len(answers) == 0: 59 | print("Already completed. Exiting.") 60 | return 61 | 62 | stop_tokens = ["Problem:"] 63 | print("[GPU ID]", id, "Length of inputs", len(inputs)) 64 | 65 | if temp == 0.7: 66 | n = 100 # needs modification 67 | else: 68 | n = 1 69 | 70 | if task == "GSM8K": 71 | max_token_num = 512 72 | else: 73 | max_token_num = 1024 74 | 75 | sampling_params = SamplingParams(temperature=temp, top_p=1, max_tokens=max_token_num, stop=stop_tokens, n=n) 76 | print('sampling =====', sampling_params) 77 | llm = LLM(model=model,tensor_parallel_size=tensor_parallel_size, enforce_eager=True, swap_space=32) 78 | result = [] 79 | res_completions = [] 80 | 81 | completions = llm.generate(inputs, sampling_params) 82 | 83 | for num, output in enumerate(completions): 84 | prompt = output.prompt 85 | all_texts = [out.text for out in output.outputs] 86 | res_completions.append(all_texts) 87 | 88 | answer = answers[num] 89 | dict_ = {"prompt": prompt, "preds": all_texts, "answer": answer} 90 | 91 | with jsonlines.open(result_file, 'a') as writer: 92 | writer.write(dict_) 93 | 94 | print('start===', start, ', end====', end) 95 | 96 | def parse_args(): 97 | parser = argparse.ArgumentParser() 98 | parser.add_argument("--model", type=str) # model path 99 | parser.add_argument("--data_file", type=str, default='') # data path 100 | parser.add_argument("--tensor_parallel_size", type=int, default=1) # tensor_parallel_size 101 | parser.add_argument("--result_file", type=str, default="./log_file.jsonl") # tensor_parallel_size 102 | parser.add_argument("--temp", type=float, default=0.7) 103 | parser.add_argument("--id", type=int, default=0) 104 | parser.add_argument("--task", type=str, default="GSM8K") 105 | 106 | return parser.parse_args() 107 | 108 | if __name__ == "__main__": 109 | # Login First. 
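    # "huggingface_token_here" is a placeholder: substitute a real Hugging Face token if the
    # model must be pulled from the Hub, or comment this call out (as gen_step_explore.py does)
    # when the model path points to a local checkpoint.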
110 | login(token="huggingface_token_here") 111 | 112 | args = parse_args() 113 | generate(model=args.model, data_path=args.data_file, tensor_parallel_size=args.tensor_parallel_size, temp=args.temp, id=args.id, task=args.task) 114 | 115 | file_lists = [args.result_file.replace(".jsonl", f"_{i}.jsonl") for i in range(4)] 116 | 117 | # IF completed: 118 | import os 119 | # if all 4 exists, now unite file: 120 | if all([os.path.exists(x) for x in file_lists]) and sum([len([json.loads(x) for x in open(file_name)]) for file_name in file_lists]) == len([json.loads(x) for x in open(args.data_file)]): 121 | from utils_others import unite_file 122 | unite_file(args.result_file, 4, "_gen") 123 | 124 | # Then get rft and dpo data. 125 | gen_name = args.result_file.replace(".jsonl", "_gen.jsonl") 126 | 127 | # This will generate rft and dpo data. 128 | from get_rft_data import run_rft_and_dpo 129 | run_rft_and_dpo(gen_name, args.task) -------------------------------------------------------------------------------- /gen/gen_rft_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | model_path="put_your_FT_model_path" # Put the model path. 3 | result_file="put_your_result_file" # Put the result file path you wish to save the result. 4 | data_file="put_your_data_file" # Put the train jsonl file. 5 | task="GSM8K" 6 | 7 | # First set of processes with deepseek-ai/deepseek-math-7b-base 8 | for i in $(seq 0 3) # Loop from 0 to 6, stepping by 2 9 | do 10 | CUDA_VISIBLE_DEVICES=$i python gen_rft_data.py --task $task --model $model_path --data_file $data_file --result_file $result_file --temp 0.7 --id $i & 11 | done 12 | wait -------------------------------------------------------------------------------- /gen/gen_step_explore.py: -------------------------------------------------------------------------------- 1 | import optparse 2 | import sys 3 | from math_utils.eval_script import eval_math, is_correct 4 | from math_utils.answer_extraction import extract_answer, extract_math_few_shot_cot_answer 5 | 6 | from huggingface_hub import login 7 | import argparse 8 | import json 9 | import jsonlines 10 | from vllm import LLM, SamplingParams 11 | import sys 12 | from utils_others import get_final_steps 13 | from tqdm.auto import tqdm 14 | 15 | MAX_INT = sys.maxsize 16 | 17 | def extract_from_pred(x, y, z): 18 | try: 19 | return extract_math_few_shot_cot_answer(x,y,z)[0] 20 | except: 21 | return "[Invalid]" 22 | 23 | def test(model, data_path, tensor_parallel_size=1, temp=0.7, id=0, k=4, task="GSM8K", result_file=""): 24 | stop_tokens = [] 25 | 26 | if temp == 0.7: 27 | n = k # Change this if you want larger 'k' for exploration. 28 | else: 29 | n = 1 30 | 31 | li = [json.loads(x) for x in open(data_path)] 32 | 33 | # Get the rejected sample divided in steps. 34 | if 'final_steps' not in li[0].keys(): 35 | li = get_final_steps(li, task) 36 | 37 | li = [x for x in li if x['rejected'].strip("\n").strip() != ""] # filter out here. 38 | 39 | # This is based on the assumption there are 4 GPUs. 40 | portion = (len(li) + 4) // 4 41 | li = li[int(portion * id): int(portion * (id + 1))] 42 | 43 | all_elems = [] 44 | 45 | result_file = args.result_file.replace(".jsonl", f"_{id}.jsonl") 46 | ## Exclude already processed. 
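    # Resumption logic: if a partial result file exists from an earlier launch, drop the
    # examples it already contains and continue from where that run stopped.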
47 | try: 48 | already_processed = [json.loads(x) for x in open(result_file)] 49 | 50 | if len(already_processed) >= len(li) - 1: 51 | return 52 | else: 53 | li = li[len(already_processed):] 54 | except: 55 | already_processed = [] 56 | 57 | max_token_num = 512 if task == "GSM8K" else 1024 58 | sampling_params = SamplingParams(temperature=temp, top_p=1, max_tokens=max_token_num, stop=stop_tokens, n=n) 59 | 60 | print('sampling =====', sampling_params) 61 | print("Length Remaining ...", len(li)) 62 | llm = LLM(model=model,tensor_parallel_size=tensor_parallel_size, enforce_eager=True, swap_space=64) 63 | 64 | ## PREPARE SAMPLES 65 | for idx, elem in enumerate(li): 66 | elem['step_pred'] = {} 67 | elem['all_prompt'] = elem['final_steps'] 68 | elem['completed'] = False 69 | elem['num_steps'] = len(elem['all_prompt']) 70 | elem['sample_idx'] = idx 71 | 72 | # Get answer. 73 | if task == "GSM8K": 74 | elem['answer'] = elem['chosen'][elem['chosen'].index("The answer is") + len("The answer is"):].strip() 75 | else: 76 | elem['answer'] = extract_math_few_shot_cot_answer(elem['prompt'], elem['chosen'], "")[0] 77 | 78 | all_elems.append(elem) 79 | 80 | print(f"***** [Length Remaining] {len(all_elems)} *****") 81 | 82 | # Samples to process: 100 83 | SAMPLES_NUM = 1000 84 | 85 | for i in tqdm(range(0, len(all_elems) + SAMPLES_NUM, SAMPLES_NUM)): 86 | current_samples = all_elems[i:i+SAMPLES_NUM] 87 | max_step = min(max([len(x['all_prompt']) for x in current_samples]), 20) # set 20 as hard limit. 88 | 89 | for step_idx in range(max_step): 90 | # each samples' step_idx-th step. 91 | print("step_idx", step_idx) 92 | curr_step_to_proc = [x['all_prompt'][step_idx] for x in current_samples if x['completed'] == False] 93 | sample_idxs = [x['sample_idx'] for x in current_samples if x['completed'] == False] 94 | 95 | completions = llm.generate(curr_step_to_proc, sampling_params) 96 | 97 | # check if reached answer, if not discard. 98 | for sample_idx, output in zip(sample_idxs, completions): 99 | all_texts = [out.text for out in output.outputs] 100 | 101 | if task == "GSM8K": 102 | def get_answer(text): 103 | try: 104 | ans = text[text.index("The answer is") + len("The answer is"):].strip() 105 | except: 106 | ans = "[invalid]" 107 | return ans 108 | 109 | all_answers = [get_answer(text) for text in all_texts] 110 | else: 111 | all_answers = [extract_from_pred(all_elems[sample_idx]['prompt'], text, "") for text in all_texts] 112 | 113 | # If all false ... no need to continue. (i.e. we found the first pit!) 114 | if task == "GSM8K": 115 | if True not in [str(pred) == str(all_elems[sample_idx]['answer']) for pred in all_answers]: 116 | all_elems[sample_idx]['completed'] = True 117 | else: 118 | if True not in [eval_math({"prediction": pred, "answer": all_elems[sample_idx]['answer']}) for pred in all_answers]: 119 | all_elems[sample_idx]['completed'] = True 120 | 121 | # If processed all steps, mark as completed. 122 | if step_idx + 1 >= current_samples[sample_idx]['num_steps']: 123 | all_elems[sample_idx]['completed'] = True 124 | 125 | all_elems[sample_idx]["step_pred"][step_idx + 1] = {"prompt": all_elems[sample_idx]['all_prompt'][step_idx], "preds": all_texts, "answers": all_answers} 126 | 127 | with jsonlines.open(result_file, "a") as writer: 128 | writer.write_all(current_samples) 129 | 130 | # return to prevent server crash. 
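        # Each invocation handles only this first chunk of SAMPLES_NUM examples and then exits;
        # the outer loop in gen_step_explore.sh relaunches the script, and the resumption logic
        # above skips everything already written to the result file.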
131 | return 132 | 133 | def parse_args(): 134 | parser = argparse.ArgumentParser() 135 | parser.add_argument("--model", type=str) # model path 136 | parser.add_argument("--data_file", type=str, default='') # data path 137 | parser.add_argument("--tensor_parallel_size", type=int, default=1) # tensor_parallel_size 138 | parser.add_argument("--result_file", type=str, default="./log_file.jsonl") # tensor_parallel_size 139 | parser.add_argument("--temp", type=float, default=0.7) 140 | parser.add_argument("--id", type=int, default=0) 141 | parser.add_argument("--k", type=int, default=4) # how many completions to generate per step. 142 | parser.add_argument("--task", type=str, default="GSM8K") # how many completions to generate per step. 143 | 144 | return parser.parse_args() 145 | 146 | if __name__ == "__main__": 147 | # Login First. 148 | # login(token="your_hugging_face_token_here") 149 | args = parse_args() 150 | test(model=args.model, data_path=args.data_file, tensor_parallel_size=args.tensor_parallel_size, temp=args.temp, id=args.id, k=args.k, task=args.task, result_file=args.result_file) 151 | 152 | # See if completed. 153 | import os 154 | file_lists = [args.result_file.replace(".jsonl", f"_{i}.jsonl") for i in range(4)] 155 | 156 | # IF completed: 157 | if all([os.path.exists(x) for x in file_lists]) and sum([len([json.loads(x) for x in open(file_name)]) for file_name in file_lists]) == len([json.loads(x) for x in open(args.data_file)]): 158 | from utils_others import unite_file 159 | unite_file(args.result_file, 4) 160 | 161 | # Then run form_gpair 162 | from utils_others import form_gpair 163 | form_gpair(args.result_file, args.result_file.replace(".jsonl", "_gpair_4.jsonl"), args.task) -------------------------------------------------------------------------------- /gen/gen_step_explore.sh: -------------------------------------------------------------------------------- 1 | # Assuming you hvae 4 gpus, run as following to speed up the process. 2 | model_path="put_your_model_path" # Put the model path. 3 | result_file="put_your_result_file" # Put the name of file to save exploration result. 4 | data_file="put_the_dpo_data_here" # Here, put the DPO file generated. 5 | temp=0.7 6 | task="GSM8K" # or MATH 7 | 8 | # Change the number of iterations accordingly to the data_file. 
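# Each launch lets a worker finish one chunk of SAMPLES_NUM (1000) examples and exit, so ten
# passes cover up to 10,000 examples per GPU shard; raise the count for larger data files.
# NOTE: the assignments below are spelled CUDA_VISIBLED_DEVICES; the standard name is
# CUDA_VISIBLE_DEVICES, so as written the four processes are not pinned to separate GPUs.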
9 | for j in $(seq 1 10); do 10 | CUDA_VISIBLED_DEVICES=0 python gen_step_explore.py --task $task --model $model_path --temp $temp --result_file $result_file --data_file $data_file --id 0 & 11 | CUDA_VISIBLED_DEVICES=1 python gen_step_explore.py --task $task --model $model_path --temp $temp --result_file $result_file --data_file $data_file --id 1 & 12 | CUDA_VISIBLED_DEVICES=2 python gen_step_explore.py --task $task --model $model_path --temp $temp --result_file $result_file --data_file $data_file --id 2 & 13 | CUDA_VISIBLED_DEVICES=3 python gen_step_explore.py --task $task --model $model_path --temp $temp --result_file $result_file --data_file $data_file --id 3 & 14 | wait 15 | done -------------------------------------------------------------------------------- /gen/get_dpo_data.py: -------------------------------------------------------------------------------- 1 | # COLLECT DPO DATA 2 | import json 3 | import editdistance 4 | from tqdm.auto import tqdm 5 | import jsonlines 6 | import argparse 7 | from utils_others import extract_answer 8 | 9 | def collect_dpo_gsm8k(wo_dup_data_path, w_dup_data_path): 10 | wo_dup = [json.loads(x) for x in open(wo_dup_data_path)] 11 | w_dup = [json.loads(x) for x in open(w_dup_data_path)] 12 | 13 | query_preds_dict = {} 14 | dpo_data = [] 15 | 16 | # Set Wrong Outputs 17 | for elem in w_dup: 18 | 19 | elem['output_answers'] = [extract_answer(completion) for completion in elem['preds']] 20 | elem['outputs'] = elem['preds'] 21 | elem['ground_truth'] = extract_answer(elem['answer'].replace("####", "The answer is")) 22 | elem['input'] = elem['prompt'] 23 | 24 | query = elem['input'].rstrip("\n") 25 | wrong_outputs = [out for out, ans in zip(elem['outputs'], elem['output_answers']) if str(ans) != str(elem['ground_truth'])] 26 | query_preds_dict[query] = list(set(wrong_outputs)) 27 | 28 | # Find by Query: 29 | for elem in tqdm(wo_dup): 30 | query = elem['query'].rstrip("\n") 31 | if query not in query_preds_dict: 32 | raise AssertionError 33 | else: 34 | new_elem = {} 35 | new_elem['prompt'] = query.rstrip() + "\n" 36 | wrong_outputs = query_preds_dict[query] 37 | 38 | if len(wrong_outputs) == 0: 39 | continue # Skip 40 | 41 | new_elem['chosen'] = elem['response'].strip() 42 | new_elem['rejected'] = sorted([(ref, editdistance.eval(new_elem['chosen'], ref)) for ref in wrong_outputs], key=lambda x: x[1])[-1][0] # Choose the one with the largest edit distance. 43 | 44 | # Remove the rejected from the pool 45 | query_preds_dict[query].remove(new_elem['rejected']) 46 | dpo_data.append(new_elem) 47 | 48 | return dpo_data 49 | 50 | def collect_dpo_math(old_li, fname=""): 51 | 52 | dpo_set = [] 53 | 54 | for elem in tqdm(old_li): 55 | corr = elem['filtered_corr'] 56 | 57 | incorr = list(set(elem['filtered_incorr'])) # For incorrect, one could use string-level set, or edit distance based ... i.e., list(set(elem['incorr])) 58 | 59 | # for every correct sample, find the one with the maximum edit distance. 60 | MIN_NUM = min(len(corr), len(incorr)) 61 | 62 | for sample in corr[:MIN_NUM]: 63 | all_rej = sorted([(ref, editdistance.eval(sample, ref) / (len(sample) + len(ref))) for ref in incorr], key=lambda x: x[1]) 64 | 65 | # Try to find shortest one among maximum edit distance - this is to prevent from selecting one that has degenerated sample - i.e. repetition. 
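            # Threshold cascade on normalized edit distance (0.5, then 0.4, then 0.3): prefer a
            # rejected sample that differs substantially from the chosen one, relax the bar only
            # if nothing qualifies, and break ties by picking the shortest candidate.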
66 | filtered = sorted([x for x in all_rej if x[1] > 0.5], key=lambda x: len(x[0])) 67 | if len(filtered) == 0: 68 | filtered = sorted([x for x in all_rej if x[1] > 0.4], key=lambda x: len(x[0])) 69 | if len(filtered) == 0: 70 | filtered = sorted([x for x in all_rej if x[1] > 0.3], key=lambda x: len(x[0])) 71 | if len(filtered) == 0: 72 | print("Not passed any :( Selecting shortest ...") 73 | filtered = sorted(all_rej, key=lambda x: len(x[0])) 74 | 75 | rej = filtered[0][0] 76 | dpo_set.append({'prompt': elem['question'], 'chosen': sample, 'rejected': rej}) 77 | incorr.remove(rej) 78 | 79 | return dpo_set -------------------------------------------------------------------------------- /gen/get_rft_data.py: -------------------------------------------------------------------------------- 1 | # GET RFT DATA 2 | import re 3 | import editdistance 4 | from tqdm.auto import tqdm 5 | import json 6 | import jsonlines 7 | import argparse 8 | from utils_others import extract_answer 9 | from get_dpo_data import collect_dpo_gsm8k, collect_dpo_math 10 | import random 11 | import editdistance 12 | 13 | def read_jsonl(fname): 14 | return [json.loads(line) for line in open(fname)] 15 | 16 | def check_equation(eq): 17 | if eq.find('=') == -1: 18 | return False 19 | 20 | lhs = eq.split('=')[0] 21 | rhs = eq.split('=')[1] 22 | 23 | try: 24 | lhs_result = eval(str(lhs)) 25 | if abs(float(lhs_result) - float(rhs)) < 1e-3: 26 | return True 27 | except BaseException: 28 | return False 29 | return False 30 | 31 | def is_eq_correct(equations): 32 | for eq in equations: 33 | if not check_equation(eq): 34 | return False 35 | return True 36 | 37 | def collect_rft_data_gsm8k(fname): 38 | 39 | eq_pattern = r'<<([^>]*)>>' # Match everything inside << >> 40 | # eq_pattern = r'\b\d+\.?\d*\b' # This for mistral-metamath. 41 | 42 | all_sets = [] # list of dicts of 'query' and 'response'. 43 | old_li = read_jsonl(fname) 44 | 45 | for sample in tqdm(old_li): 46 | sample['preds'] = list(set(sample['preds'])) 47 | sample['output_answers'] = [extract_answer(completion) for completion in sample['preds']] 48 | sample['outputs'] = [x.strip() for x in sample['preds']] 49 | sample['ground_truth'] = extract_answer(sample['answer'].replace("####", "The answer is")) 50 | sample['input'] = sample['prompt'].rstrip() + "\n" 51 | 52 | exist_match = [] 53 | correct_preds = [reasoning for reasoning, answer in zip(sample['outputs'], sample['output_answers']) if str(answer) == str(sample['ground_truth'])] 54 | 55 | matches = [{"reasoning": r, "equations": re.findall(eq_pattern, r)} for r in correct_preds] # Find all matches 56 | 57 | # remove this line in case of mistral-metamath. 
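        # Keep only solutions whose <<...>> calculator annotations actually evaluate (lhs equals
        # rhs within 1e-3), filtering out chains that reach the correct final answer through
        # faulty intermediate arithmetic.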
58 | matches = [m for m in matches if is_eq_correct(m['equations'])] 59 | 60 | final_preds = {} 61 | 62 | for elem in matches: 63 | match_string = '|'.join(elem['equations']).replace(' ', '') 64 | if match_string in final_preds: 65 | other_solutions = [final_preds[k] for k in final_preds if k != match_string] 66 | now_score = sum([editdistance.eval(elem['reasoning'], ref) for ref in other_solutions]) 67 | original_score = sum([editdistance.eval(final_preds[match_string], ref) for ref in other_solutions]) 68 | if now_score > original_score: 69 | final_preds[match_string] = elem['reasoning'] 70 | else: 71 | final_preds[match_string] = elem['reasoning'] 72 | 73 | sample['rft_outputs'] = list(final_preds.values()) 74 | 75 | MAX_NUM = 8 76 | for rft_sample in sample['rft_outputs'][:MAX_NUM]: 77 | all_sets.append({'query': sample['input'].rstrip("\n"), 'response': rft_sample}) 78 | 79 | return old_li, all_sets 80 | 81 | def collect_rft_data_math(fname): 82 | from math_utils.eval_script import eval_math, is_correct 83 | from math_utils.answer_extraction import extract_answer, extract_math_few_shot_cot_answer 84 | 85 | random.seed(42) 86 | MAX_SAMPLES = 8 87 | 88 | def remove_dup(corr, N=100): 89 | corr = list(set(corr)) 90 | random.shuffle(corr) 91 | final_results = [] 92 | 93 | for elem in corr: 94 | if all([editdistance.eval(elem, x) / (len(elem) + len(x) // 2) > 0.2 for x in final_results]): 95 | final_results.append(elem) 96 | if len(final_results) >= N: 97 | break 98 | return final_results 99 | 100 | # file should be already aggregated. 101 | li = [json.loads(x) for x in open(fname)] 102 | 103 | # Make sure you check it by your own. 104 | # if len(li) != 7500 or len(set([x['prompt'] for x in li])) != 7500: 105 | # raise ValueError("RFT is missing some file. Make sure all questions are properly handled.") 106 | 107 | # [To-do] Change data_file dir. 108 | all_q = [json.loads(x) for x in open("../data/MATH_train.jsonl")] 109 | 110 | for orig_elem, elem in zip(all_q, li): 111 | elem['answer'] = extract_answer(elem['answer']) 112 | elem['question'] = orig_elem['query'] 113 | 114 | new_li = [] 115 | 116 | for idx in tqdm(range(len(li))): 117 | 118 | li[idx]['incorr'] = [] 119 | li[idx]['corr'] = [] 120 | q = li[idx]['question'] 121 | 122 | for pred in li[idx]['preds']: 123 | # Only select those that "Final Answer:" is present, to prevent repetition being used as negative. 124 | # We want a good quality negative. :) 125 | tgt_str = "Final Answer:" 126 | tgt_str = "The answer is" 127 | 128 | if tgt_str not in pred.strip().split("\n")[-1]: 129 | if tgt_str in pred: 130 | pred_idx = pred.index(tgt_str) 131 | end_of_line_idx = pred.find('\n', pred_idx + len(tgt_str)) 132 | if end_of_line_idx != -1: 133 | pred = pred[:end_of_line_idx] 134 | else: 135 | print("Final Answer Error.") 136 | continue # just skip. 137 | else: 138 | continue # If not found, likely to be repetition, or incomplete solution. 
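            # Grade the (possibly truncated) prediction against the reference with eval_math;
            # any extraction or comparison error is counted as incorrect rather than crashing.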
139 | 140 | try: 141 | new_x = {"prediction": extract_math_few_shot_cot_answer(q, pred, "")[0], "answer": li[idx]['answer']} 142 | out = eval_math(new_x) 143 | except: 144 | out = False 145 | 146 | if out: 147 | li[idx]['corr'].append(pred) 148 | else: 149 | li[idx]['incorr'].append(pred) 150 | 151 | filtered_corr = remove_dup(li[idx]['corr'], N=MAX_SAMPLES) 152 | filtered_corr = [x.replace("I hope it is correct.", "" ).strip() for x in filtered_corr] 153 | li[idx]['filtered_corr'] = filtered_corr 154 | 155 | if len(filtered_corr) == 0: 156 | li[idx]['filtered_incorr'] = [] 157 | continue 158 | 159 | filtered_incorr = remove_dup(li[idx]['incorr'], N=MAX_SAMPLES * 3) # for simplicity. 160 | filtered_incorr = [x.replace("I hope it is correct.", "" ).strip() for x in filtered_incorr] 161 | li[idx]['filtered_incorr'] = filtered_incorr 162 | 163 | for filtered_corr_sample in filtered_corr: 164 | new_li.append({"query": q.rstrip() + "\n", "response": filtered_corr_sample}) 165 | 166 | return li, new_li 167 | 168 | def run_rft_and_dpo(fname, task): 169 | 170 | fname1 = fname 171 | fname2 = fname1.replace("gen", "rft") 172 | fname3 = fname1.replace("gen", "dpo") 173 | 174 | collect_rft_fn = collect_rft_data_gsm8k if task == "GSM8K" else collect_rft_data_math 175 | collect_dpo_fn = collect_dpo_gsm8k if task == "GSM8K" else collect_dpo_math 176 | 177 | old_set, new_set = collect_rft_fn(fname1) 178 | 179 | # Generate RFT file 180 | with jsonlines.open(fname2, mode='w') as writer: 181 | writer.write_all(new_set) 182 | 183 | if task == "GSM8K": 184 | new_set = collect_dpo_fn(fname2, fname1) 185 | else: 186 | new_set = collect_dpo_fn(old_set, fname1) 187 | 188 | # Generate DPO file 189 | with jsonlines.open(fname3, mode='w') as writer: 190 | writer.write_all(new_set) -------------------------------------------------------------------------------- /gen/math_utils/eval_script.py: -------------------------------------------------------------------------------- 1 | import regex 2 | from copy import deepcopy 3 | from .eval_utils import math_equal 4 | from .ocwcourses_eval_utils import normalize_numeric, numeric_equality, normalize_symbolic_equation, SymbolicMathMixin 5 | 6 | def is_correct(item, pred_key='prediction', prec=1e-3): 7 | # print("called again ...", item) 8 | pred = item[pred_key] 9 | ans = item['answer'] 10 | if isinstance(pred, list) and isinstance(ans, list): 11 | pred_matched = set() 12 | ans_matched = set() 13 | for i in range(len(pred)): 14 | for j in range(len(ans)): 15 | item_cpy = deepcopy(item) 16 | item_cpy.update({ 17 | pred_key: pred[i], 18 | 'answer': ans[j] 19 | }) 20 | if is_correct(item_cpy, pred_key=pred_key, prec=prec): 21 | pred_matched.add(i) 22 | ans_matched.add(j) 23 | if item_cpy[pred_key] == '2,3,4': 24 | print(item, flush=True) 25 | print("wtf", flush=True) 26 | return len(pred_matched) == len(pred) and len(ans_matched) == len(ans) 27 | elif isinstance(pred, str) and isinstance(ans, str): 28 | if '\\cup' in pred and '\\cup' in ans: 29 | item = deepcopy(item) 30 | item.update({ 31 | pred_key: pred.split('\\cup'), 32 | 'answer': ans.split('\\cup'), 33 | }) 34 | return is_correct(item, pred_key=pred_key, prec=prec) 35 | else: 36 | label = False 37 | try: 38 | label = abs(float(regex.sub(r',', '', str(pred))) - float(regex.sub(r',', '', str(ans)))) < prec 39 | except: 40 | pass 41 | label = label or (ans and pred == ans) or math_equal(pred, ans) 42 | return label 43 | else: 44 | print(item, flush=True) 45 | raise NotImplementedError() 46 | 47 | def eval_math(item, 
pred_key='prediction', prec=1e-3): 48 | pred = item[pred_key] 49 | if pred_key == 'program_output' and isinstance(pred, str): 50 | pred = [pred] 51 | ans = item['answer'] 52 | if isinstance(pred, list) and isinstance(ans, list): 53 | # for some questions in MATH, `reference` repeats answers 54 | _ans = [] 55 | for a in ans: 56 | if a not in _ans: 57 | _ans.append(a) 58 | ans = _ans 59 | # some predictions for MATH questions also repeats answers 60 | _pred = [] 61 | for a in pred: 62 | if a not in _pred: 63 | _pred.append(a) 64 | # some predictions mistakenly box non-answer strings 65 | pred = _pred[-len(ans):] 66 | 67 | item.update({ 68 | pred_key: pred, 69 | 'answer': ans 70 | }) 71 | return is_correct(item, pred_key=pred_key, prec=prec) 72 | 73 | def eval_last_single_answer(item, pred_key='prediction', prec=1e-3): 74 | for key in [pred_key, 'answer']: 75 | assert isinstance(item[key], str), f"{key} = `{item[key]}` is not a str" 76 | return is_correct(item, pred_key=pred_key, prec=prec) 77 | 78 | def eval_agieval_gaokao_math_cloze(item, pred_key='prediction', prec=1e-3): 79 | if pred_key == 'program_output' and isinstance(item[pred_key], str): 80 | item[pred_key] = [item[pred_key]] 81 | for key in [pred_key, 'answer']: 82 | assert isinstance(item[key], list), f"{key} = `{item[key]}` is not a list" 83 | pred = item[pred_key] 84 | ans = item['answer'] 85 | _pred = [] 86 | for p in pred: 87 | p = p + ";" 88 | while p: 89 | left_brackets = 0 90 | for i in range(len(p)): 91 | if p[i] == ';' or (p[i] == ',' and left_brackets == 0): 92 | _p, p = p[:i].strip(), p[i + 1:].strip() 93 | if _p not in _pred: 94 | _pred.append(_p) 95 | break 96 | elif p[i] in '([{': 97 | left_brackets += 1 98 | elif p[i] in ')]}': 99 | left_brackets -= 1 100 | pred = _pred[-len(ans):] 101 | if len(pred) == len(ans): 102 | for p, a in zip(pred, ans): 103 | item.update({ 104 | pred_key: p, 105 | 'answer': a, 106 | }) 107 | if not is_correct(item, pred_key=pred_key, prec=prec): 108 | return False 109 | return True 110 | else: 111 | return False 112 | 113 | def eval_agieval_gaokao_mathqa(item, pred_key='prediction', prec=1e-3): 114 | if pred_key == 'program_output' and isinstance(item[pred_key], str): 115 | item[pred_key] = [item[pred_key]] 116 | pred_str = " ".join(item[pred_key]) 117 | ans = item['answer'] 118 | tag = None 119 | idx = -1 120 | for t in 'ABCD': 121 | if t in pred_str and pred_str.index(t) > idx: 122 | tag = t 123 | idx = pred_str.index(t) 124 | return tag == ans 125 | 126 | def eval_math_sat(item, pred_key='prediction', prec=1e-3): 127 | for key in [pred_key, 'answer']: 128 | assert isinstance(item[key], str), f"{key} = `{item[key]}` is not a str" 129 | return item[pred_key].lower() == item['answer'].lower() 130 | 131 | def eval_mmlu_stem(item, pred_key='prediction', prec=1e-3): 132 | return eval_math_sat(item, pred_key=pred_key, prec=prec) 133 | 134 | def eval_ocwcourses(item, pred_key='prediction', prec=1e-3): 135 | INVALID_ANSWER = "[invalidanswer]" 136 | for key in [pred_key, 'answer']: 137 | assert isinstance(item[key], str), f"{key} = `{item[key]}` is not a str" 138 | pred = item[pred_key] 139 | ans = item['answer'] 140 | 141 | try: 142 | float(ans) 143 | normalize_fn = normalize_numeric 144 | is_equiv = numeric_equality 145 | answer_type = "numeric" 146 | except ValueError: 147 | if "=" in ans: 148 | normalize_fn = normalize_symbolic_equation 149 | is_equiv = lambda x, y: x==y 150 | answer_type = "equation" 151 | else: 152 | normalize_fn = SymbolicMathMixin().normalize_tex 153 | is_equiv = 
SymbolicMathMixin().is_tex_equiv 154 | answer_type = "expression" 155 | 156 | correct_answer = normalize_fn(ans) 157 | 158 | unnormalized_answer = pred if pred else INVALID_ANSWER 159 | model_answer = normalize_fn(unnormalized_answer) 160 | 161 | if unnormalized_answer == INVALID_ANSWER: 162 | acc = 0 163 | elif model_answer == INVALID_ANSWER: 164 | acc = 0 165 | elif is_equiv(model_answer, correct_answer): 166 | acc = 1 167 | else: 168 | acc = 0 169 | 170 | return acc 171 | 172 | def eval_minif2f_isabelle(item, pred_key='prediction', prec=1e-3): 173 | return True 174 | -------------------------------------------------------------------------------- /gen/math_utils/ocwcourses_eval_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import numpy as np 3 | import sympy 4 | from sympy.core.sympify import SympifyError 5 | from sympy.parsing.latex import parse_latex 6 | 7 | import signal 8 | 9 | INVALID_ANSWER = "[invalidanswer]" 10 | 11 | class timeout: 12 | def __init__(self, seconds=1, error_message="Timeout"): 13 | self.seconds = seconds 14 | self.error_message = error_message 15 | 16 | def handle_timeout(self, signum, frame): 17 | raise TimeoutError(self.error_message) 18 | 19 | def __enter__(self): 20 | signal.signal(signal.SIGALRM, self.handle_timeout) 21 | signal.alarm(self.seconds) 22 | 23 | def __exit__(self, type, value, traceback): 24 | signal.alarm(0) 25 | 26 | def normalize_numeric(s): 27 | if s is None: 28 | return None 29 | for unit in [ 30 | "eV", 31 | " \\mathrm{~kg} \\cdot \\mathrm{m} / \\mathrm{s}", 32 | " kg m/s", 33 | "kg*m/s", 34 | "kg", 35 | "m/s", 36 | "m / s", 37 | "m s^{-1}", 38 | "\\text{ m/s}", 39 | " \\mathrm{m/s}", 40 | " \\text{ m/s}", 41 | "g/mole", 42 | "g/mol", 43 | "\\mathrm{~g}", 44 | "\\mathrm{~g} / \\mathrm{mol}", 45 | "W", 46 | "erg/s", 47 | "years", 48 | "year", 49 | "cm", 50 | ]: 51 | s = s.replace(unit, "") 52 | s = s.strip() 53 | for maybe_unit in ["m", "s", "cm"]: 54 | s = s.replace("\\mathrm{" + maybe_unit + "}", "") 55 | s = s.replace("\\mathrm{~" + maybe_unit + "}", "") 56 | s = s.strip() 57 | s = s.strip("$") 58 | try: 59 | return float(eval(s)) 60 | except: 61 | try: 62 | expr = parse_latex(s) 63 | if expr.is_number: 64 | return float(expr) 65 | return INVALID_ANSWER 66 | except: 67 | return INVALID_ANSWER 68 | 69 | def numeric_equality(n1, n2, threshold=0.01): 70 | if n1 is None or n2 is None: 71 | return False 72 | if np.isclose(n1, 0) or np.isclose(n2, 0) or np.isclose(n1 - n2, 0): 73 | return np.abs(n1 - n2) < threshold * (n1 + n2) / 2 74 | else: 75 | return np.isclose(n1, n2) 76 | 77 | def normalize_symbolic_equation(s): 78 | if not isinstance(s, str): 79 | return INVALID_ANSWER 80 | if s.startswith("\\["): 81 | s = s[2:] 82 | if s.endswith("\\]"): 83 | s = s[:-2] 84 | s = s.replace("\\left(", "(") 85 | s = s.replace("\\right)", ")") 86 | s = s.replace("\\\\", "\\") 87 | if s.startswith("$") or s.endswith("$"): 88 | s = s.strip("$") 89 | try: 90 | maybe_expression = parse_latex(s) 91 | if not isinstance(maybe_expression, sympy.core.relational.Equality): 92 | # we have equation, not expression 93 | return INVALID_ANSWER 94 | else: 95 | return maybe_expression 96 | except: 97 | return INVALID_ANSWER 98 | 99 | class SymbolicMathMixin: 100 | """ 101 | Methods useful for parsing mathematical expressions from text and determining equivalence of expressions. 
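    Normalization follows appendix D of Lewkowycz et al. (2022) (see `normalize_tex`);
    equivalence of normalized answers is checked with `is_tex_equiv` / `is_exp_equiv`.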
102 | """ 103 | 104 | SUBSTITUTIONS = [ # used for text normalize 105 | ("an ", ""), 106 | ("a ", ""), 107 | (".$", "$"), 108 | ("\\$", ""), 109 | (r"\ ", ""), 110 | (" ", ""), 111 | ("mbox", "text"), 112 | (",\\text{and}", ","), 113 | ("\\text{and}", ","), 114 | ("\\text{m}", "\\text{}"), 115 | ] 116 | REMOVED_EXPRESSIONS = [ # used for text normalizer 117 | "square", 118 | "ways", 119 | "integers", 120 | "dollars", 121 | "mph", 122 | "inches", 123 | "ft", 124 | "hours", 125 | "km", 126 | "units", 127 | "\\ldots", 128 | "sue", 129 | "points", 130 | "feet", 131 | "minutes", 132 | "digits", 133 | "cents", 134 | "degrees", 135 | "cm", 136 | "gm", 137 | "pounds", 138 | "meters", 139 | "meals", 140 | "edges", 141 | "students", 142 | "childrentickets", 143 | "multiples", 144 | "\\text{s}", 145 | "\\text{.}", 146 | "\\text{\ns}", 147 | "\\text{}^2", 148 | "\\text{}^3", 149 | "\\text{\n}", 150 | "\\text{}", 151 | r"\mathrm{th}", 152 | r"^\circ", 153 | r"^{\circ}", 154 | r"\;", 155 | r",\!", 156 | "{,}", 157 | '"', 158 | "\\dots", 159 | ] 160 | 161 | def normalize_tex(self, final_answer: str) -> str: 162 | """ 163 | Normalizes a string representing a mathematical expression. 164 | Used as a preprocessing step before parsing methods. 165 | 166 | Copied character for character from appendix D of Lewkowycz et al. (2022) 167 | """ 168 | final_answer = final_answer.split("=")[-1] 169 | 170 | for before, after in self.SUBSTITUTIONS: 171 | final_answer = final_answer.replace(before, after) 172 | for expr in self.REMOVED_EXPRESSIONS: 173 | final_answer = final_answer.replace(expr, "") 174 | 175 | # Extract answer that is in LaTeX math, is bold, 176 | # is surrounded by a box, etc. 177 | final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer) 178 | final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer) 179 | final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer) 180 | final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer) 181 | final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer) 182 | 183 | # Normalize shorthand TeX: 184 | # \fracab -> \frac{a}{b} 185 | # \frac{abc}{bef} -> \frac{abc}{bef} 186 | # \fracabc -> \frac{a}{b}c 187 | # \sqrta -> \sqrt{a} 188 | # \sqrtab -> sqrt{a}b 189 | final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer) 190 | final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer) 191 | final_answer = final_answer.replace("$", "") 192 | 193 | # Normalize 100,000 -> 100000 194 | if final_answer.replace(",", "").isdigit(): 195 | final_answer = final_answer.replace(",", "") 196 | 197 | return final_answer 198 | 199 | def parse_tex(self, text: str, time_limit: int = 5) -> sympy.Basic: 200 | """ 201 | Wrapper around `sympy.parse_text` that outputs a SymPy expression. 202 | Typically, you want to apply `normalize_text` as a preprocessing step. 203 | """ 204 | try: 205 | with timeout(seconds=time_limit): 206 | parsed = parse_latex(text) 207 | except ( 208 | # general error handling: there is a long tail of possible sympy/other 209 | # errors we would like to catch 210 | Exception 211 | ) as e: 212 | print(f"failed to parse {text} with exception {e}") 213 | return None 214 | 215 | return parsed 216 | 217 | def is_exp_equiv(self, x1: sympy.Basic, x2: sympy.Basic, time_limit=5) -> bool: 218 | """ 219 | Determines whether two sympy expressions are equal. 
220 | """ 221 | try: 222 | with timeout(seconds=time_limit): 223 | try: 224 | diff = x1 - x2 225 | except (SympifyError, ValueError, TypeError) as e: 226 | print( 227 | f"Couldn't subtract {x1} and {x2} with exception {e}" 228 | ) 229 | return False 230 | 231 | try: 232 | if sympy.simplify(diff) == 0: 233 | return True 234 | else: 235 | return False 236 | except (SympifyError, ValueError, TypeError) as e: 237 | print(f"Failed to simplify {x1}-{x2} with {e}") 238 | return False 239 | except TimeoutError as e: 240 | print(f"Timed out comparing {x1} and {x2}") 241 | return False 242 | except Exception as e: 243 | print(f"failed on unrecognized exception {e}") 244 | return False 245 | 246 | def is_tex_equiv(self, x1: str, x2: str, time_limit=5) -> bool: 247 | """ 248 | Determines whether two (ideally normalized using `normalize_text`) TeX expressions are equal. 249 | 250 | Does so by first checking for string exact-match, then falls back on sympy-equivalence, 251 | following the (Lewkowycz et al. 2022) methodology. 252 | """ 253 | if x1 == x2: 254 | # don't resort to sympy if we have full string match, post-normalization 255 | return True 256 | else: 257 | return False 258 | parsed_x2 = self.parse_tex(x2) 259 | if not parsed_x2: 260 | # if our reference fails to parse into a Sympy object, 261 | # we forgo parsing + checking our generated answer. 262 | return False 263 | return self.is_exp_equiv(self.parse_tex(x1), parsed_x2, time_limit=time_limit) 264 | -------------------------------------------------------------------------------- /images/main_result_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hbin0701/Self-Explore/54b8ba9f4cee1ef52f72de916e100e0da2ae6863/images/main_result_image.png -------------------------------------------------------------------------------- /images/overview_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hbin0701/Self-Explore/54b8ba9f4cee1ef52f72de916e100e0da2ae6863/images/overview_image.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.27.0 2 | datasets==2.14.6 3 | deepspeed==0.13.1 4 | editdistance==0.8.1 5 | einops==0.7.0 6 | flash-attn==2.5.2 7 | huggingface-hub==0.20.3 8 | json5==0.9.14 9 | jsonlines==4.0.0 10 | numpy==1.24.3 11 | openai==1.13.3 12 | sentencepiece==0.1.99 13 | transformers==4.38.1 14 | triton==2.1.0 15 | trl==0.7.10 16 | vllm==0.3.2 17 | wandb==0.16.1 -------------------------------------------------------------------------------- /scripts/gsm8k/dpo/config.yaml: -------------------------------------------------------------------------------- 1 | # Model arguments 2 | model_name_or_path: # put model_name_to_train_on_here 3 | run_name: # put_wandb_run_nmae_here 4 | 5 | # Data training arguments 6 | # For definitions, see: src/h4/training/config.py 7 | dataset_mixer: 8 | HuggingFaceH4/ultrafeedback_binarized: 1.0 9 | dataset_splits: 10 | - train_prefs 11 | - test_prefs 12 | preprocessing_num_workers: 12 13 | train_data_file: train_data_file_directory # for both training and test_data, put training data. 14 | test_data_file: train_data_file_directory # testing data (i.e. eval set) is a dummy data(subset of training), because we don't use eval data. 
15 | 16 | # DPOTrainer argument 17 | bf16: true 18 | beta: 0.1 19 | do_eval: true 20 | evaluation_strategy: epoch 21 | gradient_accumulation_steps: 1 22 | gradient_checkpointing: true 23 | hub_model_id: zephyr-7b-dpo-full 24 | learning_rate: 1.0e-7 # use 1.0e-7 for Mistral, 1.0e-6 for others. 25 | log_level: info 26 | logging_steps: 10 27 | lr_scheduler_type: linear 28 | max_length: 384 29 | num_train_epochs: 3 30 | optim: rmsprop 31 | output_dir: some_directory_to_save # Put directory for saving model checkpoints here 32 | per_device_train_batch_size: 8 33 | per_device_eval_batch_size: 8 34 | push_to_hub: false 35 | save_strategy: "epoch" 36 | save_only_model: true 37 | save_total_limit: 3 38 | seed: 42 39 | warmup_ratio: 0.1 -------------------------------------------------------------------------------- /scripts/gsm8k/dpo/deepspeed_zero3.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | main_process_port: 32899 3 | debug: false 4 | deepspeed_config: 5 | deepspeed_multinode_launcher: standard 6 | offload_optimizer_device: none 7 | offload_param_device: none 8 | zero3_init_flag: true 9 | zero3_save_16bit_model: true 10 | zero_stage: 3 11 | distributed_type: DEEPSPEED 12 | downcast_bf16: 'no' 13 | machine_rank: 0 14 | main_training_function: main 15 | mixed_precision: bf16 16 | num_machines: 1 17 | num_processes: 4 18 | rdzv_backend: static 19 | same_network: true 20 | tpu_env: [] 21 | tpu_use_cluster: false 22 | tpu_use_sudo: false 23 | use_cpu: false 24 | -------------------------------------------------------------------------------- /scripts/gsm8k/dpo/run_dpo.sh: -------------------------------------------------------------------------------- 1 | # Run form main directory. 2 | CUDA_VISIBLE_DEVICES=0,1,2,3 ACCELERATE_LOG_LEVEL=info accelerate launch --config_file scripts/gsm8k/dpo/deepspeed_zero3.yaml dpo/run_dpo.py scripts/gsm8k/dpo/config.yaml -------------------------------------------------------------------------------- /scripts/gsm8k/sft/config.yaml: -------------------------------------------------------------------------------- 1 | command_file: null 2 | commands: null 3 | main_process_port: 51888 4 | compute_environment: LOCAL_MACHINE 5 | deepspeed_config: 6 | deepspeed_multinode_launcher: standard 7 | gradient_clipping: 1.0 8 | zero_stage: 1 9 | distributed_type: DEEPSPEED 10 | downcast_bf16: true 11 | dynamo_backend: 'NO' 12 | gpu_ids: null 13 | machine_rank: 0 14 | main_process_ip: null 15 | main_process_port: null 16 | main_training_function: main 17 | fsdp_config: {} 18 | megatron_lm_config: {} 19 | mixed_precision: bf16 20 | num_machines: 1 21 | num_processes: 4 22 | rdzv_backend: static 23 | same_network: true 24 | tpu_name: null 25 | tpu_zone: null 26 | use_cpu: false 27 | -------------------------------------------------------------------------------- /scripts/gsm8k/sft/run_ft.sh: -------------------------------------------------------------------------------- 1 | export WANDB_API_KEY=wandb_key # put your wandb key here 2 | export WANDB_PROJECT=project_name # put your project name here 3 | export WANDB_ENTITY=my_id # put your wandb id here. 4 | 5 | model_name_or_path=/home/user/models/deepseek-math-7b-base # model path to train on. 6 | save_generator_id=deepseek_GSM8K_FT # model name to be saved. 7 | 8 | save_dir=/home/user/models/${save_generator_id}/ 9 | export WANDB_NAME=${save_generator_id} 10 | 11 | # lr: 1e-6 for Mistral and 1e-5 for Others. 
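# e.g. for a Mistral-based model, change "--learning_rate 1e-5" below to "--learning_rate 1e-6".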
12 | 13 | accelerate launch \ 14 | --config_file scripts/gsm8k/sft/config.yaml \ 15 | --main_process_port=40999 \ 16 | sft/train_generator.py \ 17 | --model_name_or_path ${model_name_or_path} \ 18 | --data_dir put_your_data_file_here \ 19 | --target_set train \ 20 | --save_dir ${save_dir} \ 21 | --num_train_epoches 5 \ 22 | --save_strategy epoch \ 23 | --per_device_train_batch_size 16 \ 24 | --per_device_eval_batch_size 4 \ 25 | --gradient_accumulation_steps 1 \ 26 | --gradient_checkpointing True \ 27 | --learning_rate 1e-5 \ 28 | --weight_decay 0 \ 29 | --lr_scheduler_type "linear" \ 30 | --warmup_steps 0 \ 31 | --save_best False \ 32 | --save_total_limit 5 \ 33 | --logging_dir ./wandb \ 34 | --logging_steps 8 \ 35 | --seed 42 \ 36 | --save_model_only True \ 37 | --mode "ft_GSM8K" # mode is one of "ft_GSM8K", "ft_MATH", "rft_GSM8K", "rft_MATH" -------------------------------------------------------------------------------- /scripts/gsm8k/sft/run_rft.sh: -------------------------------------------------------------------------------- 1 | export WANDB_API_KEY=wandb_key # put your wandb key here 2 | export WANDB_PROJECT=project_name # put your project name here 3 | export WANDB_ENTITY=my_id # put your wandb id here. 4 | 5 | model_name_or_path=/home/user/models/deepseek-math-7b-base # model path to train on. 6 | save_generator_id=deepseek_GSM8K_RFT # model name to be saved. 7 | 8 | save_dir=/home/user/models/${save_generator_id}/ 9 | export WANDB_NAME=${save_generator_id} 10 | 11 | # lr: 1e-6 for Mistral and 1e-5 for Others. 12 | 13 | accelerate launch \ 14 | --config_file scripts/gsm8k/sft/config.yaml \ 15 | --main_process_port=40999 \ 16 | sft/train_generator.py \ 17 | --model_name_or_path ${model_name_or_path} \ 18 | --data_dir put_your_data_file_here \ 19 | --target_set train \ 20 | --save_dir ${save_dir} \ 21 | --num_train_epoches 5 \ 22 | --save_strategy epoch \ 23 | --per_device_train_batch_size 16 \ 24 | --per_device_eval_batch_size 4 \ 25 | --gradient_accumulation_steps 1 \ 26 | --gradient_checkpointing True \ 27 | --learning_rate 1e-5 \ 28 | --weight_decay 0 \ 29 | --lr_scheduler_type "linear" \ 30 | --warmup_steps 0 \ 31 | --save_best False \ 32 | --save_total_limit 5 \ 33 | --logging_dir ./wandb \ 34 | --logging_steps 8 \ 35 | --seed 42 \ 36 | --save_model_only True \ 37 | --mode "rft_GSM8K" # mode is one of "ft_GSM8K", "ft_MATH", "rft_GSM8K", "rft_MATH" -------------------------------------------------------------------------------- /scripts/math/dpo/config.yaml: -------------------------------------------------------------------------------- 1 | # Model arguments 2 | model_name_or_path: # put model_name_to_train_on_here 3 | run_name: # put_wandb_run_nmae_here 4 | 5 | # Data training arguments 6 | # For definitions, see: src/h4/training/config.py 7 | dataset_mixer: 8 | HuggingFaceH4/ultrafeedback_binarized: 1.0 9 | dataset_splits: 10 | - train_prefs 11 | - test_prefs 12 | preprocessing_num_workers: 12 13 | train_data_file: train_data_file_directory # for both training and test_data, put training data. 14 | test_data_file: train_data_file_directory # testing data (i.e. eval set) is a dummy data(subset of training), because we don't use eval data. 15 | 16 | # DPOTrainer argument 17 | bf16: true 18 | beta: 0.1 19 | do_eval: true 20 | evaluation_strategy: epoch 21 | gradient_accumulation_steps: 1 22 | gradient_checkpointing: true 23 | hub_model_id: zephyr-7b-dpo-full 24 | learning_rate: 1.0e-7 # use 1.0e-7 for Mistral, 1.0e-6 for others. 
25 | log_level: info 26 | logging_steps: 10 27 | lr_scheduler_type: linear 28 | max_length: 1024 29 | max_prompt_length: 512 30 | num_train_epochs: 3 31 | optim: rmsprop 32 | output_dir: some_directory_to_save # Put directory for saving model checkpoints here 33 | per_device_train_batch_size: 8 34 | per_device_eval_batch_size: 8 35 | push_to_hub: false 36 | save_strategy: "epoch" 37 | save_only_model: true 38 | save_total_limit: 3 39 | seed: 42 40 | warmup_ratio: 0.1 -------------------------------------------------------------------------------- /scripts/math/dpo/deepspeed_zero3.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | main_process_port: 48882 3 | debug: false 4 | deepspeed_config: 5 | deepspeed_multinode_launcher: standard 6 | offload_optimizer_device: none 7 | offload_param_device: none 8 | zero3_init_flag: true 9 | zero3_save_16bit_model: true 10 | zero_stage: 3 11 | distributed_type: DEEPSPEED 12 | downcast_bf16: 'no' 13 | machine_rank: 0 14 | main_training_function: main 15 | mixed_precision: bf16 16 | num_machines: 1 17 | num_processes: 4 18 | rdzv_backend: static 19 | same_network: true 20 | tpu_env: [] 21 | tpu_use_cluster: false 22 | tpu_use_sudo: false 23 | use_cpu: false 24 | -------------------------------------------------------------------------------- /scripts/math/dpo/run_dpo.sh: -------------------------------------------------------------------------------- 1 | # Run form main directory. 2 | CUDA_VISIBLE_DEVICES=0,1,2,3 ACCELERATE_LOG_LEVEL=info accelerate launch --config_file scripts/math/dpo/deepspeed_zero3.yaml dpo/run_dpo.py scripts/math/dpo/config.yaml -------------------------------------------------------------------------------- /scripts/math/sft/config.yaml: -------------------------------------------------------------------------------- 1 | command_file: null 2 | commands: null 3 | main_process_port: 51888 4 | compute_environment: LOCAL_MACHINE 5 | deepspeed_config: 6 | deepspeed_multinode_launcher: standard 7 | gradient_clipping: 1.0 8 | zero_stage: 1 9 | distributed_type: DEEPSPEED 10 | downcast_bf16: true 11 | dynamo_backend: 'NO' 12 | gpu_ids: null 13 | machine_rank: 0 14 | main_process_ip: null 15 | main_process_port: null 16 | main_training_function: main 17 | fsdp_config: {} 18 | megatron_lm_config: {} 19 | mixed_precision: bf16 20 | num_machines: 1 21 | num_processes: 4 22 | rdzv_backend: static 23 | same_network: true 24 | tpu_name: null 25 | tpu_zone: null 26 | use_cpu: false 27 | -------------------------------------------------------------------------------- /scripts/math/sft/run_ft.sh: -------------------------------------------------------------------------------- 1 | export WANDB_API_KEY=wandb_key # put your wandb key here 2 | export WANDB_PROJECT=project_name # put your project name here 3 | export WANDB_ENTITY=my_id # put your wandb id here. 4 | 5 | model_name_or_path=/home/user/models/deepseek-math-7b-base # model path to train on. 6 | save_generator_id=deepseek_math_FT # model name to be saved. 7 | 8 | save_dir=/home/user/models/${save_generator_id}/ 9 | export WANDB_NAME=${save_generator_id} 10 | 11 | # lr: 1e-6 for Mistral and 1e-5 for Others. 
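# Unlike the GSM8K script, this one passes --max_length 1024 and uses per-device batch 8 with
# gradient accumulation 2 (8 x 2 = 16), keeping the effective per-device batch size unchanged.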
12 | 13 | accelerate launch \ 14 | --config_file scripts/math/sft/config.yaml \ 15 | --main_process_port=40999 \ 16 | sft/train_generator.py \ 17 | --model_name_or_path ${model_name_or_path} \ 18 | --data_dir put_your_data_file_here \ 19 | --target_set train \ 20 | --save_dir ${save_dir} \ 21 | --num_train_epoches 5 \ 22 | --save_strategy epoch \ 23 | --max_length 1024 \ 24 | --per_device_train_batch_size 8 \ 25 | --per_device_eval_batch_size 2 \ 26 | --gradient_accumulation_steps 2 \ 27 | --gradient_checkpointing True \ 28 | --learning_rate 1e-5 \ 29 | --weight_decay 0 \ 30 | --lr_scheduler_type "linear" \ 31 | --warmup_steps 0 \ 32 | --save_best False \ 33 | --save_total_limit 5 \ 34 | --logging_dir ./wandb \ 35 | --logging_steps 8 \ 36 | --seed 42 \ 37 | --save_model_only True \ 38 | --mode "ft_MATH" # mode is one of "ft_GSM8K", "ft_MATH", "rft_GSM8K", "rft_MATH" -------------------------------------------------------------------------------- /scripts/math/sft/run_rft.sh: -------------------------------------------------------------------------------- 1 | export WANDB_API_KEY=wandb_key # put your wandb key here 2 | export WANDB_PROJECT=project_name # put your project name here 3 | export WANDB_ENTITY=my_id # put your wandb id here. 4 | 5 | model_name_or_path=/home/user/models/deepseek-math-7b-base # model path to train on. 6 | save_generator_id=deepseek_math_RFT # model name to be saved. 7 | 8 | save_dir=/home/user/models/${save_generator_id}/ 9 | export WANDB_NAME=${save_generator_id} 10 | 11 | # lr: 1e-6 for Mistral and 1e-5 for Others. 12 | 13 | accelerate launch \ 14 | --config_file scripts/math/sft/config.yaml \ 15 | --main_process_port=40999 \ 16 | sft/train_generator.py \ 17 | --model_name_or_path ${model_name_or_path} \ 18 | --data_dir put_your_data_file_here \ 19 | --target_set train \ 20 | --save_dir ${save_dir} \ 21 | --num_train_epoches 5 \ 22 | --save_strategy epoch \ 23 | --max_length 1024 \ 24 | --per_device_train_batch_size 8 \ 25 | --per_device_eval_batch_size 2 \ 26 | --gradient_accumulation_steps 2 \ 27 | --gradient_checkpointing True \ 28 | --learning_rate 1e-5 \ 29 | --weight_decay 0 \ 30 | --lr_scheduler_type "linear" \ 31 | --warmup_steps 0 \ 32 | --save_best False \ 33 | --save_total_limit 5 \ 34 | --logging_dir ./wandb \ 35 | --logging_steps 8 \ 36 | --seed 42 \ 37 | --save_model_only True \ 38 | --mode "rft_MATH" # mode is one of "ft_GSM8K", "ft_MATH", "rft_GSM8K", "rft_MATH" -------------------------------------------------------------------------------- /sft/sft_datasets.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import re 4 | import torch 5 | import torch.nn.functional as F 6 | from typing import Optional, Sequence, List, Set, Dict, Any, Union 7 | import transformers 8 | import logging 9 | from dataclasses import dataclass 10 | import pathlib 11 | 12 | from utils.datasets import read_jsonl, get_few_shot_prompt, left_pad_sequences, right_pad_sequences, mask_labels 13 | from utils.constants import DEFAULT_BOS_TOKEN, DEFAULT_EOS_TOKEN, IGNORE_INDEX 14 | from utils.gsm8k.decoding import extract_answer 15 | 16 | def get_examples(data_dir, split, mode): 17 | 18 | train_mode, dataset = mode.split("_") 19 | 20 | if train_mode == "ft": 21 | data_dir = { 22 | 'train': data_dir, 23 | 'test': data_dir, 24 | }[split] 25 | 26 | examples = read_jsonl(data_dir) 27 | 28 | if dataset == "GSM8K": 29 | for ex in examples: 30 | ex['response'] = ex['response'].replace('#### ', 'The answer is ') 31 | 32 | 
elif train_mode == "rft": 33 | examples = read_jsonl(data_dir) 34 | 35 | else: 36 | raise NotImplementedError 37 | 38 | if dataset == "GSM8K": 39 | for ex in examples: 40 | if "question" not in ex: 41 | ex["question"] = ex["query"].strip() 42 | if "answer" not in ex: 43 | ex["answer"] = ex["response"].strip() 44 | 45 | if 'mm' not in data_dir.lower() and 'metamath' not in data_dir.lower(): 46 | ex.update(question=ex["question"].rstrip() + "\n") 47 | ex.update(answer=ex["answer"].replace('#### ', 'The answer is ')) 48 | else: 49 | ex.update(question=ex["question"].rstrip() + "\n") 50 | ex.update(answer=ex["answer"].strip()) 51 | else: 52 | for ex in examples: 53 | if "question" not in ex: 54 | ex["question"] = ex["query"].rstrip() + "\n" 55 | if "answer" not in ex: 56 | ex["answer"] = ex["response"].strip() 57 | 58 | print(examples[0]) 59 | print(f"{len(examples)} {split} examples") 60 | return examples 61 | 62 | 63 | def make_finetuning_generator_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args: dataclass) -> Dict: 64 | train_dataset = FineTuningGeneratorDataset( 65 | tokenizer=tokenizer, 66 | data_dir=data_args.data_dir, 67 | mode=data_args.mode, 68 | target_set=data_args.target_set, 69 | loss_on_prefix=data_args.loss_on_prefix, 70 | ) 71 | val_dataset = None 72 | 73 | return dict(train_dataset=train_dataset, val_dataset=val_dataset) 74 | 75 | 76 | def make_test_generator_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args: dataclass, inference_args: dataclass) -> Dict: 77 | test_dataset = TestGeneratorDataset( 78 | tokenizer=tokenizer, 79 | data_dir=data_args.data_dir, 80 | target_set=data_args.target_set 81 | ) 82 | return test_dataset 83 | 84 | class FineTuningGeneratorDataset(torch.utils.data.Dataset): 85 | def __init__( 86 | self, 87 | tokenizer: transformers.PreTrainedTokenizer = None, 88 | mode: str="ft", 89 | data_dir: str = 'data/gsm8k', 90 | target_set: str = 'train', 91 | loss_on_prefix=True, 92 | ): 93 | self.tokenizer = tokenizer 94 | self.data_dir = data_dir 95 | self.target_set = target_set 96 | self.loss_on_prefix = loss_on_prefix 97 | self.pad_token_id = tokenizer.pad_token_id 98 | self.eos_token_id = tokenizer.eos_token_id 99 | 100 | print("+ [Dataset] Loading Training Data") 101 | self.examples = get_examples(self.data_dir, target_set, mode) 102 | qns_str = [ex["question"] for ex in self.examples] 103 | ans_str = [ex["answer"] for ex in self.examples] 104 | 105 | print("+ [Dataset] Tokenizing Testing Data") 106 | 107 | qns_tokens = [] 108 | ans_tokens = [] 109 | 110 | for x, y in zip(qns_str, ans_str): 111 | x_ = tokenizer(x, padding=False).input_ids 112 | y_ = tokenizer(y, padding=False, add_special_tokens=False).input_ids 113 | 114 | if len(x_) + len(y_) >= 1024: 115 | continue 116 | 117 | else: 118 | qns_tokens.append(x_) 119 | ans_tokens.append(y_) 120 | 121 | # qns_tokens = tokenizer(qns_str, padding=False, max_length=1024).input_ids 122 | # ans_tokens = tokenizer(ans_str, padding=False, add_special_tokens=False, max_length=1024).input_ids 123 | 124 | self.qns_str = qns_str 125 | self.ans_str = ans_str 126 | self.qns_tokens = qns_tokens 127 | self.ans_tokens = ans_tokens 128 | 129 | print("MAX QNS TOKENS", max([len(qns_tokens[i]) for i in range(len(qns_tokens))])) 130 | print("MAX ANS TOKENS", max([len(ans_tokens[i]) for i in range(len(ans_tokens))])) 131 | 132 | self.max_len = max([ 133 | len(qns_tokens[i]) + len(ans_tokens[i]) + 1 134 | for i in range(len(qns_tokens)) 135 | ] 136 | ) 137 | print(f"Max tokens: {self.max_len}") 138 
| print("Length:", len(self.qns_tokens)) 139 | 140 | def __len__(self): 141 | return len(self.qns_tokens) 142 | 143 | def __getitem__(self, idx): 144 | qn_tokens = self.qns_tokens[idx] 145 | ans_tokens = self.ans_tokens[idx] 146 | 147 | input_ids = qn_tokens + ans_tokens + [self.eos_token_id] 148 | # input_ids = qn_tokens + ans_tokens 149 | labels = input_ids 150 | 151 | masks = ( 152 | ([1] if self.loss_on_prefix else [0]) * len(qn_tokens) 153 | + ([1] * len(ans_tokens)) 154 | + ([1]) 155 | ) 156 | labels = mask_labels(labels, masks) 157 | 158 | input_ids = torch.tensor(input_ids) 159 | labels = torch.tensor(labels) 160 | return dict(input_ids=input_ids, labels=labels) 161 | 162 | def collate_fn(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 163 | input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels")) 164 | input_ids, attention_mask = right_pad_sequences(input_ids, padding_value=self.pad_token_id, return_attention_mask=True) 165 | labels = right_pad_sequences(labels, padding_value=IGNORE_INDEX, return_attention_mask=False) 166 | 167 | return dict( 168 | input_ids=input_ids, 169 | attention_mask=attention_mask, 170 | labels=labels, 171 | ) 172 | 173 | 174 | 175 | 176 | class TestGeneratorDataset(torch.utils.data.Dataset): 177 | """Left Padding""" 178 | def __init__( 179 | self, 180 | tokenizer: transformers.PreTrainedTokenizer = None, 181 | data_dir: str = 'data/gsm8k', 182 | target_set: str = None, 183 | ): 184 | self.tokenizer = tokenizer 185 | self.data_dir = data_dir 186 | self.target_set = target_set 187 | self.pad_token_id = tokenizer.pad_token_id 188 | 189 | print("+ [Dataset] Loading Testing Data") 190 | self.examples = get_examples(data_dir, target_set) 191 | qns_str = [ex["question"] for ex in self.examples] 192 | ans_str = [ex["answer"] for ex in self.examples] 193 | gts_str = [extract_answer(ans) for ans in ans_str] 194 | 195 | print("+ [Dataset] Tokenizing Testing Data") 196 | qns_tokens = tokenizer(qns_str, padding=False, max_length=1024, truncate=True).input_ids 197 | ans_tokens = tokenizer(ans_str, padding=False, add_special_tokens=False, max_length=1024, truncate=True).input_ids 198 | 199 | self.qns_str = qns_str 200 | self.qns_tokens = qns_tokens 201 | self.ans_str = ans_str 202 | self.gts_str = gts_str 203 | 204 | self.max_len = max([ 205 | len(qns_tokens[i]) + len(ans_tokens[i]) + 1 206 | for i in range(len(qns_tokens)) 207 | ] 208 | ) 209 | print(f"Max tokens: {self.max_len}") 210 | 211 | def __len__(self): 212 | return len(self.examples) 213 | 214 | def __getitem__(self, idx): 215 | qn_tokens = self.qns_tokens[idx] 216 | qn_str = self.qns_str[idx] 217 | ans_str = self.ans_str[idx] 218 | gt = self.gts_str[idx] 219 | 220 | input_ids = torch.tensor(qn_tokens) 221 | return dict( 222 | idx=idx, 223 | input_ids=input_ids, 224 | input=qn_str, 225 | question=qn_str, 226 | reference=gt, 227 | record_data=dict(answer=ans_str, ground_truth=gt), 228 | ) 229 | 230 | def collate_fn(self, instances: Sequence[Dict]) -> Dict[str, Any]: 231 | idx, input_ids, input, question, reference, record_data = tuple([instance[key] for instance in instances] for key in ("idx", "input_ids", "input", "question", "reference", "record_data")) 232 | record_data = {k: [instance[k] for instance in record_data] for k in record_data[0].keys()} 233 | 234 | input_ids, attention_mask = left_pad_sequences(input_ids, padding_value=self.pad_token_id, return_attention_mask=True) 235 | 236 | return dict( 237 | idx=idx, 238 | input_ids=input_ids, 239 | 
attention_mask=attention_mask, 240 | input=input, 241 | question=question, 242 | reference=reference, 243 | record_data=record_data, 244 | ) 245 | -------------------------------------------------------------------------------- /sft/train_generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import transformers 3 | from tqdm.auto import tqdm 4 | from dataclasses import dataclass, field 5 | from typing import Optional, List, Dict, Set, Any, Union 6 | import gc 7 | from accelerate import Accelerator 8 | import wandb 9 | import os 10 | import re 11 | from huggingface_hub import login 12 | 13 | from utils.states import set_deepspeed_config, set_training_states, set_random_seed 14 | from utils.optim import get_optimizers 15 | from utils.models import build_model, save_llm_checkpoint, save_llm, save_training_args_with_accelerator 16 | from utils.datasets import make_training_dataloaders 17 | 18 | 19 | @dataclass 20 | class ModelArguments: 21 | model_name_or_path: Optional[str] = field(default="facebook/opt-125m") 22 | 23 | @dataclass 24 | class DataArguments: 25 | dataset: str = field(default='gsm8k') 26 | data_dir: str = field(default='data/gsm8k/', metadata={"help": "Path to the training data."}) 27 | target_set: str = field(default='train') 28 | mode: str = field(default="ft", metadata={"help": "ft or rft"}) 29 | loss_on_prefix: bool = field(default=True, metadata={"help": "Whether to compute loss on the prefix"}) 30 | 31 | @dataclass 32 | class TrainingArguments: 33 | cache_dir: Optional[str] = field(default=None) 34 | model_max_length: int = field( 35 | default=2048, 36 | metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."}, 37 | ) 38 | 39 | max_steps: int = field(default=-1, metadata={"help": "When it is specified, num_train_epoches is ignored"}) 40 | num_train_epoches: int = field(default=1) 41 | per_device_train_batch_size: int = field(default=4) 42 | gradient_accumulation_steps: int = field(default=1) 43 | gradient_checkpointing: bool = field(default=True) 44 | 45 | eval_steps: int = field(default=-1, metadata={"help": "When it is specified, eval_epoches is ignored"}) 46 | eval_epoches: int = field(default=1) 47 | evaluation_strategy: str = field(default="epoch") 48 | 49 | per_device_eval_batch_size: int = field(default=4) 50 | 51 | learning_rate: float = field(default=1e-5) 52 | weight_decay: float = field(default=0) 53 | lr_scheduler_type: str = field(default="linear") 54 | warmup_steps: int = field(default=-1, metadata={"help": "When it is specified, warmup_ratio is ignored"}) 55 | warmup_ratio: float = field(default=0) 56 | 57 | logging_steps: int = field(default=-1, metadata={"help": "When it is specified, logging_epoches is ignored"}) 58 | logging_epoches: int = field(default=1) 59 | 60 | save_steps: int = field(default=-1, metadata={"help": "When it is specified, save_epoches is ignored"}) 61 | save_epoches: int = field(default=1) 62 | save_total_limit: int = field(default=3) 63 | save_best: bool = field(default=False) 64 | save_strategy: str = field(default="epoch") 65 | save_model_only: bool = field(default=True) 66 | 67 | seed: int = field(default=42) 68 | 69 | @dataclass 70 | class GenerationArguments: 71 | do_sample: bool = field(default=False) 72 | num_beams: int = field(default=1) 73 | 74 | temperature: float = field(default=0.7) 75 | top_k: int = field(default=50) 76 | top_p: float = field(default=1.0) 77 | repetition_penalty: float = field(default=1.0) 78 | 
length_penalty: float = field(default=1.0) 79 | 80 | max_length : int = field(default=2048) 81 | max_new_tokens: int = field(default=400) 82 | 83 | @dataclass 84 | class OutputArguments: 85 | logging_dir: str = field(default='wandb/') 86 | save_dir: str = field(default='checkpoints/') 87 | 88 | 89 | def main(): 90 | parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments, GenerationArguments, OutputArguments)) 91 | model_args, data_args, training_args, generation_args, output_args = parser.parse_args_into_dataclasses() 92 | config_args_dict = model_args.__dict__.copy().update(dict(**data_args.__dict__, **training_args.__dict__)) 93 | set_random_seed(training_args.seed) 94 | 95 | # print("TORCH VERSION", torch.__version__) 96 | # print("TORCH NCCL VERSION", torch.cuda.nccl.version()) 97 | 98 | from sft_datasets import make_finetuning_generator_data_module 99 | 100 | accelerator = Accelerator(gradient_accumulation_steps=training_args.gradient_accumulation_steps) 101 | 102 | # load model, tokenizer, and dataloader 103 | set_deepspeed_config(accelerator, training_args) 104 | model, tokenizer = build_model(model_args, training_args) 105 | 106 | data_module = make_finetuning_generator_data_module(tokenizer, data_args) 107 | train_dataloader, val_dataloader = make_training_dataloaders(data_module, training_args) 108 | 109 | # config optimizer and scheduler 110 | set_training_states(data_module, training_args) 111 | optimizer, lr_scheduler = get_optimizers(model, training_args) 112 | 113 | model, train_dataloader, optimizer = accelerator.prepare(model, train_dataloader, optimizer) 114 | 115 | model.to(torch.bfloat16) 116 | 117 | cur_epoch = local_step = global_step = 0 118 | start_local_step = start_global_step = -1 119 | 120 | # init wandb 121 | if accelerator.is_main_process: 122 | project_name = os.environ['WANDB_PROJECT'] 123 | logging_dir = os.path.join(output_args.logging_dir, project_name) 124 | 125 | os.makedirs(logging_dir, exist_ok=True) 126 | wandb_id = wandb.util.generate_id() 127 | wandb.init(id=wandb_id, dir=logging_dir, config=config_args_dict) 128 | 129 | 130 | # training 131 | global_step = 0 132 | model.train() 133 | while global_step < training_args.num_training_steps: 134 | train_dataloader_iterator = tqdm(enumerate(train_dataloader), total=len(train_dataloader), desc=f'Training - Epoch {cur_epoch+1} / {training_args.num_train_epoches}') if accelerator.is_main_process else enumerate(train_dataloader) 135 | 136 | for local_step, batch in train_dataloader_iterator: 137 | if global_step < start_global_step: 138 | global_step += 1 139 | continue 140 | 141 | batch_input = {k: v for k, v in batch.items() if k in ('input_ids', 'attention_mask', 'labels')} 142 | # backpropagation 143 | with accelerator.accumulate(model): 144 | output = model(**batch_input, return_dict=True) 145 | loss = output.loss 146 | accelerator.backward(loss) 147 | 148 | optimizer.step() 149 | if not accelerator.optimizer_step_was_skipped and global_step % training_args.gradient_accumulation_steps == 0: 150 | lr_scheduler.step() 151 | optimizer.zero_grad() 152 | 153 | # training logging 154 | if accelerator.is_main_process: 155 | train_dataloader_iterator.set_postfix(epoch=cur_epoch, step=local_step, loss=loss.item()) 156 | 157 | if global_step % training_args.num_logging_steps == 0: 158 | wandb.log({ 159 | 'loss': loss.item(), 160 | 'lr': lr_scheduler.get_last_lr()[0] 161 | }, step=global_step) 162 | 163 | # save checkpoint 164 | if global_step != 0 and global_step % 
training_args.per_save_steps == 0: 165 | accelerator.wait_for_everyone() 166 | save_llm_checkpoint(accelerator, model, tokenizer, output_args.save_dir, global_step, training_args.save_total_limit) 167 | 168 | # save states for resuming [Not needed] 169 | # if global_step != 0 and global_step % training_args.per_save_steps == 0: 170 | # accelerator.wait_for_everyone() 171 | # accelerator.save_state(os.path.join(output_args.save_dir, 'resume')) 172 | global_step += 1 173 | 174 | cur_epoch += 1 175 | 176 | gc.collect(); torch.cuda.empty_cache() 177 | 178 | accelerator.wait_for_everyone() 179 | save_llm(accelerator, model, tokenizer, output_args.save_dir) 180 | save_training_args_with_accelerator(accelerator, training_args, output_args.save_dir) 181 | 182 | accelerator.save_state(os.path.join(output_args.save_dir, 'resume')) 183 | if accelerator.is_main_process: 184 | wandb.finish() 185 | 186 | if __name__ == "__main__": 187 | login(token="your_hugging_face_token_here") # put your huggingface token here! 188 | main() 189 | -------------------------------------------------------------------------------- /sft/utils/constants.py: -------------------------------------------------------------------------------- 1 | IGNORE_INDEX = -100 2 | DEFAULT_PAD_TOKEN = "" 3 | DEFAULT_BOS_TOKEN = "" 4 | DEFAULT_EOS_TOKEN = "" 5 | DEFAULT_UNK_TOKEN = "" 6 | 7 | LLAMA_EQUALS_TOKENS = set([353, 3892, 29922, 10457]) # _=, )=, =, =- 8 | LLAMA_LEFTMARK_TOKENS = set([3532, 9314]) # <<, _<< 9 | LLAMA_RIGHTMARK_TOKEN = 6778 # >> 10 | LLAMA_NEWLINE_TOKEN = 13 # \n 11 | 12 | -------------------------------------------------------------------------------- /sft/utils/flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | # refer to https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama2_flash_attn_monkey_patch.py 2 | 3 | import warnings 4 | from typing import Optional, Tuple 5 | 6 | import torch 7 | from flash_attn import __version__ as flash_attn_version 8 | from flash_attn.bert_padding import pad_input, unpad_input 9 | from flash_attn.flash_attn_interface import ( 10 | flash_attn_func, 11 | flash_attn_varlen_kvpacked_func, 12 | ) 13 | from transformers.models.llama.modeling_llama import ( 14 | LlamaAttention, 15 | LlamaModel, 16 | rotate_half, 17 | ) 18 | from transformers.models.mistral.modeling_mistral import ( 19 | MistralAttention, 20 | MistralModel, 21 | rotate_half, 22 | ) 23 | 24 | 25 | 26 | def apply_rotary_pos_emb(q, k, cos_sin, position_ids): 27 | gather_indices = position_ids[:, :, None, None] # [bsz, seq_len, 1, 1] 28 | gather_indices = gather_indices.repeat( 29 | 1, 1, cos_sin[0].shape[1], cos_sin[0].shape[3] 30 | ) 31 | bsz = gather_indices.shape[0] 32 | cos, sin = ( 33 | torch.gather(x.transpose(1, 2).repeat(bsz, 1, 1, 1), 1, gather_indices) 34 | for x in cos_sin 35 | ) 36 | q, k = ((x * cos) + (rotate_half(x) * sin) for x in (q, k)) 37 | return q, k 38 | 39 | 40 | def forward( 41 | self, 42 | hidden_states: torch.Tensor, 43 | attention_mask: Optional[torch.Tensor] = None, 44 | position_ids: Optional[torch.Tensor] = None, 45 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 46 | output_attentions: bool = False, 47 | use_cache: bool = False, 48 | padding_mask: Optional[torch.Tensor] = None, 49 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 50 | if output_attentions: 51 | warnings.warn( 52 | "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." 
53 | ) 54 | 55 | bsz, q_len, _ = hidden_states.size() 56 | kv_heads = getattr(self, "num_key_value_heads", self.num_heads) 57 | 58 | q, k, v = ( 59 | op(hidden_states).view(bsz, q_len, nh, self.head_dim) 60 | for op, nh in ( 61 | (self.q_proj, self.num_heads), 62 | (self.k_proj, kv_heads), 63 | (self.v_proj, kv_heads), 64 | ) 65 | ) 66 | # shape: (b, s, num_heads, head_dim) 67 | 68 | kv_seq_len = k.shape[1] 69 | past_kv_len = 0 70 | if past_key_value is not None: 71 | past_kv_len = past_key_value[0].shape[2] 72 | kv_seq_len += past_kv_len 73 | 74 | cos_sin = self.rotary_emb(v, seq_len=kv_seq_len) 75 | q, k = apply_rotary_pos_emb(q, k, cos_sin, position_ids) 76 | 77 | if past_key_value is not None: 78 | assert ( 79 | flash_attn_version >= "2.1.0" 80 | ), "past_key_value support requires flash-attn >= 2.1.0" 81 | # reuse k, v 82 | k = torch.cat([past_key_value[0].transpose(1, 2), k], dim=1) 83 | v = torch.cat([past_key_value[1].transpose(1, 2), v], dim=1) 84 | 85 | past_key_value = (k.transpose(1, 2), v.transpose(1, 2)) if use_cache else None 86 | 87 | if attention_mask is None: 88 | output = flash_attn_func(q, k, v, 0.0, softmax_scale=None, causal=True).view( 89 | bsz, q_len, -1 90 | ) 91 | else: 92 | q, indices, cu_q_lens, max_s = unpad_input(q, attention_mask[:, -q_len:]) 93 | # We can skip concat and call unpad twice but seems better to call unpad only once. 94 | kv, _, cu_k_lens, max_k = unpad_input( 95 | torch.stack((k, v), dim=2), attention_mask 96 | ) 97 | output_unpad = flash_attn_varlen_kvpacked_func( 98 | q, 99 | kv, 100 | cu_q_lens, 101 | cu_k_lens, 102 | max_s, 103 | max_k, 104 | 0.0, 105 | softmax_scale=None, 106 | causal=True, 107 | ) 108 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim) 109 | output = pad_input(output_unpad, indices, bsz, q_len) 110 | 111 | return self.o_proj(output), None, past_key_value 112 | 113 | 114 | # Disable the transformation of the attention mask in LlamaModel as flash attention 115 | # takes a boolean key_padding_mask. Fills in the past kv length for use in forward. 116 | def _prepare_decoder_attention_mask( 117 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length 118 | ): 119 | 120 | if attention_mask is not None and torch.all(attention_mask): 121 | return None # This uses the faster call when training with full samples 122 | 123 | return attention_mask 124 | 125 | 126 | def replace_llama_attn_with_flash_attn(): 127 | cuda_major, cuda_minor = torch.cuda.get_device_capability() 128 | if cuda_major < 8: 129 | warnings.warn( 130 | "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." 
131 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" 132 | ) 133 | 134 | LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask 135 | LlamaAttention.forward = forward 136 | 137 | -------------------------------------------------------------------------------- /sft/utils/gsm8k/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hbin0701/Self-Explore/54b8ba9f4cee1ef52f72de916e100e0da2ae6863/sft/utils/gsm8k/__init__.py -------------------------------------------------------------------------------- /sft/utils/gsm8k/decoding.py: -------------------------------------------------------------------------------- 1 | ### From OVM/utils/gsm8k/decoding.py 2 | 3 | from contextlib import contextmanager 4 | import signal 5 | import json 6 | import os 7 | import re 8 | 9 | 10 | # ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)") 11 | ANS_RE = re.compile(r"The answer is:?\s*(\-?[0-9\.\,]+)") 12 | INVALID_ANS = "[invalid]" 13 | 14 | 15 | def extract_answer(completion): 16 | match = ANS_RE.search(completion) 17 | if match: 18 | match_str = match.group(1).strip() 19 | st_str = standardize_value_str(match_str) 20 | try: eval(st_str); return st_str 21 | except: ... 22 | return INVALID_ANS 23 | 24 | def extract_answers(completions): 25 | return [extract_answer(completion) for completion in completions] 26 | 27 | def standardize_value_str(x): 28 | """Standardize numerical values""" 29 | y = x.replace(",", "") 30 | if '.' in y: 31 | y = y.rstrip('0') 32 | if y[-1] == '.': 33 | y = y[:-1] 34 | if not len(y): 35 | return INVALID_ANS 36 | if y[0] == '.': 37 | y = '0' + y 38 | if y[-1] == '%': 39 | y = str(eval(y[:-1]) / 100) 40 | return y.rstrip('.') 41 | 42 | def get_answer_label(response_answer, gt): 43 | if response_answer == INVALID_ANS: 44 | return INVALID_ANS 45 | return response_answer == gt 46 | 47 | 48 | # taken from 49 | # https://stackoverflow.com/questions/492519/timeout-on-a-function-call 50 | @contextmanager 51 | def timeout(duration, formula): 52 | def timeout_handler(signum, frame): 53 | raise Exception(f"'{formula}': timed out after {duration} seconds") 54 | 55 | signal.signal(signal.SIGALRM, timeout_handler) 56 | signal.alarm(duration) 57 | yield 58 | signal.alarm(0) 59 | 60 | 61 | def eval_with_timeout(formula, max_time=3): 62 | try: 63 | with timeout(max_time, formula): 64 | return round(eval(formula), ndigits=4) 65 | except Exception as e: 66 | signal.alarm(0) 67 | print(f"Warning: Failed to eval {formula}, exception: {e}") 68 | return None 69 | 70 | 71 | # refer to https://github.com/openai/grade-school-math/blob/master/grade_school_math/calculator.py 72 | def use_calculator(sample): 73 | if "<<" not in sample: 74 | return None 75 | 76 | parts = sample.split("<<") 77 | remaining = parts[-1] 78 | if ">>" in remaining: 79 | return None 80 | if "=" not in remaining: 81 | return None 82 | lhs = remaining.split("=")[0] 83 | lhs = lhs.replace(",", "") 84 | if any([x not in "0123456789*+-/.()" for x in lhs]): 85 | return None 86 | ans = eval_with_timeout(lhs) 87 | if remaining[-1] == '-' and ans is not None and ans < 0: 88 | ans = -ans 89 | return ans 90 | 91 | 92 | 93 | 94 | 95 | 96 | -------------------------------------------------------------------------------- /sft/utils/gsm8k/metrics.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from typing import Optional, List, 
Dict, Set, Any, Union 5 | import torch.distributed as dist 6 | from utils.gsm8k.decoding import INVALID_ANS, extract_answers, get_answer_label 7 | 8 | 9 | 10 | class GeneratorAnswerAcc: 11 | def __init__(self, n_data: int): 12 | self.n_data = n_data 13 | 14 | self.world_size = dist.get_world_size() if dist.is_initialized() else 1 15 | 16 | self.corrs = [] 17 | self.gather = False 18 | 19 | @torch.inference_mode(mode=True) 20 | def __call__(self, completions: List[str], gts: List[str]): 21 | answers = extract_answers(completions) 22 | 23 | corrs = [float(get_answer_label(answer, gt) == True) for answer, gt in zip(answers, gts)] 24 | 25 | self.corrs.append(corrs) 26 | 27 | def get_metric(self, reset=True): 28 | if not self.gather: 29 | if self.world_size != 1: 30 | gathered_corrs = [None] * self.world_size 31 | for obj, container in [ 32 | (self.corrs, gathered_corrs), 33 | ]: 34 | dist.all_gather_object(container, obj) 35 | 36 | flatten_corrs = [] 37 | for corrs_gpus in zip(*gathered_corrs): 38 | for corrs in corrs_gpus: 39 | flatten_corrs.extend(corrs) 40 | 41 | else: 42 | flatten_corrs = [item for sublist in self.corrs for item in sublist] 43 | 44 | self.corrs = flatten_corrs[:self.n_data] 45 | self.gather = True 46 | 47 | acc = (sum(self.corrs) / len(self.corrs)) 48 | 49 | if reset: 50 | self.corrs = [] 51 | self.gather = False 52 | return acc 53 | 54 | 55 | class MultiSamplingAnswerAcc: 56 | def __init__(self, n_data: int = None): 57 | self.n_data = n_data 58 | 59 | self.world_size = dist.get_world_size() if dist.is_initialized() else 1 60 | 61 | self.answers = [] 62 | self.gts = [] 63 | 64 | def start_new_sol_epoch(self): 65 | self.cur_answers = [] 66 | self.cur_gts = [] 67 | 68 | def end_the_sol_epoch(self): 69 | 70 | if self.world_size != 1: 71 | gathered_answers, gathered_gts = tuple([None] * self.world_size for _ in range(2)) 72 | for obj, container in [ 73 | (self.cur_answers, gathered_answers), 74 | (self.cur_gts, gathered_gts), 75 | ]: 76 | dist.all_gather_object(container, obj) 77 | 78 | flatten_answers, flatten_gts = [], [] 79 | for answers_gpus, gts_gpus in zip(zip(*gathered_answers), zip(*gathered_gts)): 80 | for answers, gts in zip(answers_gpus, gts_gpus): 81 | flatten_answers.extend(answers) 82 | flatten_gts.extend(gts) 83 | 84 | else: 85 | flatten_answers, flatten_gts = tuple([item for sublist in container for item in sublist] 86 | for container in [self.cur_answers, self.cur_gts]) 87 | 88 | self.answers.append(flatten_answers[:self.n_data]) 89 | self.gts.append(flatten_gts[:self.n_data]) 90 | 91 | 92 | @torch.inference_mode(mode=True) 93 | def __call__(self, completions: List[str], gts: List[str]): 94 | answers = extract_answers(completions) 95 | 96 | answers = [float(a) if a != INVALID_ANS else float('nan') for a in answers] 97 | gts = [float(gt) for gt in gts] 98 | 99 | self.cur_answers.append(answers) 100 | self.cur_gts.append(gts) 101 | 102 | 103 | def get_metric(self, n_solution: int=3, reset=True): 104 | 105 | assert all(x == self.gts[0] for x in self.gts) 106 | 107 | # [n_question] 108 | gts = np.array(self.gts[0]) 109 | # [n_question, n_solution] 110 | answers = np.stack(self.answers[:n_solution], axis=1) 111 | # print('answers:', answers.shape) 112 | 113 | pass_k = (answers == gts.reshape((-1, 1))).any(1).mean(0) 114 | acc_majority = np.mean([is_majority(a, gt, ignore=float('nan')) for a, gt in zip(answers, gts)]) 115 | 116 | if reset: 117 | self.gts = [] 118 | self.answers = [] 119 | return pass_k, acc_majority 120 | 121 | 122 | 123 | def is_passk(answers, gt): 
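    # True if the ground truth appears anywhere among the sampled answers.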
124 | return gt in answers 125 | 126 | def is_majority(answers, gt, ignore = INVALID_ANS): 127 | filter_answers = list(filter(lambda x: x!=ignore, answers)) 128 | final_answer = max(filter_answers, key=filter_answers.count) 129 | return final_answer == gt 130 | 131 | 132 | -------------------------------------------------------------------------------- /sft/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import torch.distributed as dist 5 | from typing import Optional, List, Dict, Set, Any, Union 6 | from utils.constants import IGNORE_INDEX 7 | 8 | 9 | 10 | class VerifierClassificationAcc: 11 | def __init__(self, n_data: int): 12 | self.n_data = n_data 13 | 14 | self.world_size = dist.get_world_size() if dist.is_initialized() else 1 15 | 16 | self.scores = [] 17 | self.gts = [] 18 | self.gather = False 19 | 20 | @torch.inference_mode(mode=True) 21 | def __call__(self, v_scores: torch.FloatTensor, v_labels: torch.LongTensor): 22 | bsz, n_seq = v_labels.shape 23 | index = ((n_seq - 1) - v_labels.ne(IGNORE_INDEX).flip(dims=[1]).float().argmax(1)).view(-1, 1) 24 | 25 | scores = v_scores.squeeze(-1).gather(1, index).squeeze() 26 | gts = v_labels.gather(1, index).squeeze() 27 | 28 | self.scores.append(scores.tolist()) 29 | self.gts.append(gts.tolist()) 30 | 31 | def get_metric(self, thres: float=0.5, reset=True): 32 | if not self.gather: 33 | if self.world_size != 1: 34 | gathered_scores, gathered_gts = tuple([None] * self.world_size for _ in range(2)) 35 | for obj, container in [ 36 | (self.scores, gathered_scores), 37 | (self.gts, gathered_gts), 38 | ]: 39 | dist.all_gather_object(container, obj) 40 | 41 | flatten_scores, flatten_gts = [], [] 42 | for scores_gpus, gts_gpus in zip(zip(*gathered_scores), zip(*gathered_gts)): 43 | for scores, gts in zip(scores_gpus, gts_gpus): 44 | flatten_scores.extend(scores) 45 | flatten_gts.extend(gts) 46 | 47 | else: 48 | flatten_scores, flatten_gts = tuple([item for sublist in container for item in sublist] 49 | for container in [self.scores, self.gts]) 50 | 51 | self.scores = flatten_scores[:self.n_data] 52 | self.gts = flatten_gts[:self.n_data] 53 | self.gather = True 54 | 55 | 56 | pred = (np.array(self.scores) > thres) 57 | corrs = np.where(np.array(self.gts).astype(bool), pred, ~pred) 58 | acc = (sum(corrs) / len(corrs)) 59 | 60 | if reset: 61 | self.scores = [] 62 | self.gts = [] 63 | self.gather = False 64 | return acc 65 | 66 | 67 | class VerifierMPk: 68 | def __init__(self, n_data: int, n_solution_per_problem: int): 69 | self.n_data = n_data 70 | self.n_solution_per_problem = n_solution_per_problem 71 | 72 | self.world_size = dist.get_world_size() if dist.is_initialized() else 1 73 | 74 | self.preds = [] 75 | self.gts = [] 76 | self.gather = False 77 | 78 | @torch.inference_mode(mode=True) 79 | def __call__(self, v_scores: torch.FloatTensor, v_labels: torch.LongTensor): 80 | bsz, n_seq = v_labels.shape 81 | index = ((n_seq - 1) - v_labels.ne(IGNORE_INDEX).flip(dims=[1]).float().argmax(1)).view(-1, 1) 82 | 83 | preds = v_scores.squeeze(-1).gather(1, index).squeeze() 84 | gts = v_labels.gather(1, index).squeeze() 85 | 86 | self.preds.append(preds.tolist()) 87 | self.gts.append(gts.tolist()) 88 | 89 | def get_metric(self, k, reset=True): 90 | if not self.gather: 91 | if self.world_size != 1: 92 | gathered_preds, gathered_gts = tuple([None] * self.world_size for _ in range(2)) 93 | for obj, container in [ 94 | (self.preds, gathered_preds), 95 | 
(self.gts, gathered_gts), 96 | ]: 97 | dist.all_gather_object(container, obj) 98 | 99 | flatten_preds, flatten_gts = [], [] 100 | for preds_gpus, gts_gpus in zip(zip(*gathered_preds), zip(*gathered_gts)): 101 | for preds, gts in zip(preds_gpus, gts_gpus): 102 | flatten_preds.extend(preds) 103 | flatten_gts.extend(gts) 104 | 105 | else: 106 | flatten_preds, flatten_gts = tuple([item for sublist in container for item in sublist] 107 | for container in [self.preds, self.gts]) 108 | 109 | self.preds = flatten_preds[:self.n_data] 110 | self.gts = flatten_gts[:self.n_data] 111 | self.gather = True 112 | 113 | preds = np.array(self.preds).reshape(-1, self.n_solution_per_problem) 114 | gts = np.array(self.gts).reshape(-1, self.n_solution_per_problem) 115 | 116 | indices = np.argsort(-preds, axis=1) 117 | gts = np.take_along_axis(gts, indices, axis=1) 118 | 119 | # [n_question, k] 120 | gts_topk = gts[:, :k] 121 | 122 | # how portion of solutions predicted topest are really correct 123 | mpk = gts_topk.mean(1).mean(0) 124 | 125 | if reset: 126 | self.preds = [] 127 | self.gts = [] 128 | self.gather = False 129 | return mpk 130 | 131 | 132 | class GenWithVerifierAcc: 133 | def __init__(self, n_data: int, n_solution_per_problem: int): 134 | self.n_data = n_data 135 | self.n_solution_per_problem = n_solution_per_problem 136 | 137 | self.world_size = dist.get_world_size() if dist.is_initialized() else 1 138 | 139 | self.preds = [] 140 | self.gts = [] 141 | self.gather = False 142 | 143 | @torch.inference_mode(mode=True) 144 | def __call__(self, v_scores: torch.FloatTensor, v_labels: torch.LongTensor): 145 | bsz, n_seq = v_labels.shape 146 | index = ((n_seq - 1) - v_labels.ne(IGNORE_INDEX).flip(dims=[1]).float().argmax(1)).view(-1, 1) 147 | 148 | preds = v_scores.squeeze(-1).gather(1, index).squeeze() 149 | gts = v_labels.gather(1, index).squeeze() 150 | 151 | self.preds.append(preds.tolist()) 152 | self.gts.append(gts.tolist()) 153 | 154 | 155 | def get_metric(self, k, reset=True): 156 | if not self.gather: 157 | if self.world_size != 1: 158 | gathered_preds, gathered_gts = tuple([None] * self.world_size for _ in range(2)) 159 | for obj, container in [ 160 | (self.preds, gathered_preds), 161 | (self.gts, gathered_gts), 162 | ]: 163 | dist.all_gather_object(container, obj) 164 | 165 | flatten_preds, flatten_gts = [], [] 166 | for preds_gpus, gts_gpus in zip(zip(*gathered_preds), zip(*gathered_gts)): 167 | for preds, gts in zip(preds_gpus, gts_gpus): 168 | flatten_preds.extend(preds) 169 | flatten_gts.extend(gts) 170 | 171 | else: 172 | flatten_preds, flatten_gts = tuple([item for sublist in container for item in sublist] 173 | for container in [self.preds, self.gts]) 174 | 175 | self.preds = flatten_preds[:self.n_data] 176 | self.gts = flatten_gts[:self.n_data] 177 | self.gather = True 178 | 179 | 180 | preds = np.array(self.preds).reshape(-1, self.n_solution_per_problem) 181 | gts = np.array(self.gts).reshape(-1, self.n_solution_per_problem) 182 | gts = gts[:, :k] 183 | preds = preds[:, :k] 184 | 185 | indices = np.argsort(-preds, axis=1) 186 | gts = np.take_along_axis(gts, indices, axis=1) 187 | 188 | acc = gts[:, 0].mean(0) 189 | 190 | if reset: 191 | self.preds = [] 192 | self.gts = [] 193 | self.gather = False 194 | return acc 195 | 196 | -------------------------------------------------------------------------------- /sft/utils/models.py: -------------------------------------------------------------------------------- 1 | from utils.flash_attn_monkey_patch import replace_llama_attn_with_flash_attn 
2 | from utils.cached_models import build_transformers_mapping_to_cached_models, build_transformers_mapping_to_custom_tokenizers 3 | replace_llama_attn_with_flash_attn() 4 | build_transformers_mapping_to_cached_models() 5 | build_transformers_mapping_to_custom_tokenizers() 6 | 7 | from typing import Optional, List, Dict, Set, Any, Union, Callable, Mapping 8 | from torch import nn 9 | import torch 10 | import pathlib 11 | from dataclasses import dataclass 12 | from accelerate import Accelerator 13 | import os 14 | import re 15 | import shutil 16 | from functools import wraps 17 | import transformers 18 | from utils.constants import DEFAULT_PAD_TOKEN, DEFAULT_BOS_TOKEN, DEFAULT_EOS_TOKEN, DEFAULT_UNK_TOKEN 19 | 20 | 21 | def smart_tokenizer_and_embedding_resize_for_pad( 22 | special_tokens_dict: Dict, 23 | tokenizer: transformers.PreTrainedTokenizer, 24 | model: transformers.PreTrainedModel, 25 | ): 26 | """Resize tokenizer and embedding. 27 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64. 28 | """ 29 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) 30 | model.resize_token_embeddings(len(tokenizer)) 31 | if num_new_tokens > 0: 32 | input_embeddings = model.get_input_embeddings().weight.data 33 | output_embeddings = model.get_output_embeddings().weight.data 34 | 35 | input_embeddings[-1] = torch.zeros_like(input_embeddings)[0] 36 | output_embeddings[-1] = torch.zeros_like(output_embeddings)[0] 37 | 38 | 39 | def build_model(model_args: dataclass, training_args: dataclass): 40 | # Step 1: Initialize LLM 41 | print(f"+ [Model] Initializing LM: {model_args.model_name_or_path}") 42 | model = transformers.AutoModelForCausalLM.from_pretrained( 43 | model_args.model_name_or_path, 44 | use_cache=False, 45 | trust_remote_code=True 46 | ) 47 | try: 48 | if training_args.gradient_checkpointing: 49 | model.gradient_checkpointing_enable() 50 | except: 51 | pass 52 | 53 | # Step 2: Initialize tokenizer 54 | print(f"+ [Model] Initializing Tokenizer: {model_args.model_name_or_path}") 55 | 56 | tokenizer = transformers.AutoTokenizer.from_pretrained( 57 | model_args.model_name_or_path, 58 | cache_dir=training_args.cache_dir, 59 | model_max_length=training_args.model_max_length, 60 | padding_side="right", 61 | use_fast=False, 62 | ) 63 | 64 | # Step 3: Add special tokens 65 | if "mistral" in model_args.model_name_or_path.lower() or "llama" in model_args.model_name_or_path.lower() or "llema" in model_args.model_name_or_path.lower(): 66 | if tokenizer.pad_token is None: 67 | smart_tokenizer_and_embedding_resize_for_pad( 68 | special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN), 69 | tokenizer=tokenizer, 70 | model=model, 71 | ) 72 | tokenizer.add_special_tokens({ 73 | "eos_token": DEFAULT_EOS_TOKEN, 74 | "bos_token": DEFAULT_BOS_TOKEN, 75 | "unk_token": DEFAULT_UNK_TOKEN, 76 | }) 77 | 78 | # Step 4: Align special token ids between tokenizer and model.config 79 | model.config.pad_token_id = tokenizer.pad_token_id 80 | model.config.bos_token_id = tokenizer.bos_token_id 81 | model.config.eos_token_id = tokenizer.eos_token_id 82 | 83 | else: # Deepseek_math 84 | tokenizer.pad_token_id = tokenizer.eos_token_id 85 | model.generation_config = transformers.GenerationConfig.from_pretrained(model_args.model_name_or_path) 86 | model.generation_config.pad_token_id = model.generation_config.eos_token_id 87 | 88 | return model, tokenizer 89 | 90 | def safe_delete_with_accelerator(accelerator: Accelerator, path: str): 91 | @accelerator.on_main_process 92 | 
def delete(path): 93 | shutil.rmtree(path, ignore_errors=True) 94 | 95 | delete(path) 96 | 97 | def safe_move_with_accelerator(accelerator: Accelerator, ori_path: str, new_path: str): 98 | @accelerator.on_main_process 99 | def move(ori_path, new_path): 100 | try: 101 | shutil.move(ori_path, new_path) 102 | except: 103 | ... 104 | 105 | move(ori_path, new_path) 106 | 107 | 108 | 109 | def wrapper_safe_save_model_with_accelerator(save_model_func): 110 | @wraps(save_model_func) 111 | def wrapper(accelerator: Accelerator, 112 | model: nn.Module, 113 | tokenizer: transformers.AutoTokenizer, 114 | output_dir: str): 115 | @accelerator.on_main_process 116 | def save_model(cpu_state_dict, output_dir): 117 | save_model_func(accelerator=accelerator, model=model, cpu_state_dict=cpu_state_dict, output_dir=output_dir) 118 | @accelerator.on_main_process 119 | def save_tokenizer(output_dir): 120 | tokenizer.save_pretrained(output_dir) 121 | 122 | os.makedirs(output_dir, exist_ok=True) 123 | state_dict = accelerator.get_state_dict(model) 124 | cpu_state_dict = { 125 | key: value.cpu() 126 | for key, value in state_dict.items() 127 | } 128 | 129 | save_model(cpu_state_dict, output_dir) 130 | save_tokenizer(output_dir) 131 | 132 | print(f"+ [Save] Save model and tokenizer to: {output_dir}") 133 | return wrapper 134 | 135 | 136 | # refer to transformers.trainer._sorted_checkpoints "https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/trainer.py#L2848" 137 | def wrapper_save_checkpoint(save_func): 138 | @wraps(save_func) 139 | def outwrapper(func): 140 | def wrapper(accelerator: Accelerator, 141 | model: transformers.AutoModelForCausalLM, 142 | tokenizer: transformers.PreTrainedTokenizer, 143 | output_dir: str, 144 | global_step: int, 145 | save_total_limit: int=None): 146 | checkpoint_output_dir = os.path.join(output_dir, f'checkpoint-{global_step}') 147 | if os.path.exists(checkpoint_output_dir) or save_total_limit < 1: 148 | return 149 | save_func(accelerator=accelerator, model=model, tokenizer=tokenizer, output_dir=checkpoint_output_dir) 150 | 151 | ordering_and_checkpoint_path = [] 152 | glob_checkpoints = [str(x) for x in pathlib.Path(output_dir).glob('*checkpoint-*')] 153 | for path in glob_checkpoints: 154 | regex_match = re.match(r".*checkpoint-([0-9]+)", path) 155 | if regex_match is not None: 156 | ordering_and_checkpoint_path.append((int(regex_match.group(1)), path)) 157 | 158 | checkpoints_sorted = sorted(ordering_and_checkpoint_path) 159 | checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted] 160 | 161 | best_checkpoint = [str(x) for x in pathlib.Path(output_dir).glob('best-checkpoint-*')] 162 | if best_checkpoint: 163 | best_checkpoint = best_checkpoint[0] 164 | best_model_index = checkpoints_sorted.index(best_checkpoint) 165 | for i in range(best_model_index, len(checkpoints_sorted) - 2): 166 | checkpoints_sorted[i], checkpoints_sorted[i + 1] = checkpoints_sorted[i + 1], checkpoints_sorted[i] 167 | 168 | if save_total_limit: 169 | checkpoints_to_be_deleted = checkpoints_sorted[:-save_total_limit] 170 | for checkpoint in checkpoints_to_be_deleted: 171 | safe_delete_with_accelerator(accelerator, checkpoint) 172 | return wrapper 173 | return outwrapper 174 | 175 | 176 | def wrapper_save_best_checkpoint(save_checkpoint_func): 177 | @wraps(save_checkpoint_func) 178 | def outwrapper(func): 179 | def wrapper(accelerator: Accelerator, 180 | model: transformers.AutoModelForCausalLM, 181 | tokenizer: transformers.PreTrainedTokenizer, 182 | output_dir: str, 183 | 
global_step: int, 184 | save_total_limit: int=None): 185 | 186 | ori_best_checkpoint = [str(x) for x in pathlib.Path(output_dir).glob('best-checkpoint-*')] 187 | if ori_best_checkpoint: 188 | ori_best_checkpoint = ori_best_checkpoint[0] 189 | filename = os.path.basename(os.path.normpath(ori_best_checkpoint))[5:] 190 | safe_move_with_accelerator(accelerator, ori_best_checkpoint, os.path.join(output_dir, filename)) 191 | 192 | save_checkpoint_func(accelerator=accelerator, model=model, tokenizer=tokenizer, output_dir=output_dir, global_step=global_step, save_total_limit=save_total_limit) 193 | checkpoint_dir = os.path.join(output_dir, f'checkpoint-{global_step}') 194 | best_checkpoint_dir = os.path.join(output_dir, f'best-checkpoint-{global_step}') 195 | safe_move_with_accelerator(accelerator, checkpoint_dir, best_checkpoint_dir) 196 | return wrapper 197 | return outwrapper 198 | 199 | 200 | 201 | 202 | 203 | @wrapper_safe_save_model_with_accelerator 204 | def save_llm(accelerator: Accelerator, 205 | model: transformers.AutoModelForCausalLM, 206 | cpu_state_dict: Mapping, 207 | output_dir: str): 208 | accelerator.unwrap_model(model).save_pretrained( 209 | output_dir, 210 | state_dict=cpu_state_dict, 211 | is_main_process=accelerator.is_main_process, 212 | save_function=accelerator.save, 213 | ) 214 | 215 | 216 | @wrapper_save_checkpoint(save_func=save_llm) 217 | def save_llm_checkpoint(accelerator: Accelerator, 218 | model: transformers.AutoModelForCausalLM, 219 | tokenizer: transformers.PreTrainedTokenizer, 220 | checkpoint_output_dir: str): 221 | ... 222 | 223 | 224 | @wrapper_save_best_checkpoint(save_checkpoint_func=save_llm_checkpoint) 225 | def save_best_llm_checkpoint(accelerator: Accelerator, 226 | model: transformers.AutoModelForCausalLM, 227 | tokenizer: transformers.PreTrainedTokenizer, 228 | output_dir: str, 229 | global_step: int, 230 | save_total_limit: int=None): 231 | ... 
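# ---------------------------------------------------------------------------
# Note on how the three saving entry points above compose (descriptive
# comments only; the hypothetical call at the end is an illustration, and its
# argument values are assumptions rather than part of this file):
#
#   save_llm(accelerator, model, tokenizer, output_dir)
#       gathers the full state dict to CPU, then writes the model weights and
#       tokenizer to `output_dir` on the main process only.
#
#   save_llm_checkpoint(accelerator, model, tokenizer, output_dir,
#                       global_step, save_total_limit)
#       writes `{output_dir}/checkpoint-{global_step}` via save_llm, then
#       prunes the oldest checkpoints so that at most `save_total_limit`
#       remain, shuffling any existing best-checkpoint-* toward the end of
#       the deletion queue so it is not removed first. Callers are expected
#       to pass `save_total_limit` as an int, since the early-exit check
#       compares it against 1.
#
#   save_best_llm_checkpoint(accelerator, model, tokenizer, output_dir,
#                            global_step, save_total_limit)
#       renames the previous best-checkpoint-* back to checkpoint-*, saves a
#       fresh checkpoint at `global_step`, then renames that checkpoint to
#       `best-checkpoint-{global_step}`.
#
# A minimal, hypothetical call from a training loop might look like:
#   save_best_llm_checkpoint(accelerator=accelerator, model=model,
#                            tokenizer=tokenizer, output_dir=output_dir,
#                            global_step=step, save_total_limit=2)
# ---------------------------------------------------------------------------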
232 | 233 | 234 | 235 | def save_training_args_with_accelerator(accelerator: Accelerator, 236 | training_args: dataclass, 237 | output_dir: str): 238 | output_file = os.path.join(output_dir, 'training_args.bin') 239 | accelerator.save(training_args, output_file) 240 | 241 | print(f"+ [Save] Save training_args to: {output_file}") 242 | 243 | 244 | 245 | 246 | -------------------------------------------------------------------------------- /sft/utils/optim.py: -------------------------------------------------------------------------------- 1 | from transformers import AdamW 2 | from transformers import get_scheduler 3 | import transformers 4 | from typing import Optional, List, Dict, Set, Any, Union 5 | from dataclasses import dataclass 6 | import os 7 | 8 | def get_optimizers(model: transformers.AutoModelForCausalLM, training_args: dataclass) -> Dict: 9 | no_decay = ["bias", "LayerNorm.weight"] 10 | optimizer_grouped_parameters = [ 11 | { 12 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 13 | "weight_decay": training_args.weight_decay, 14 | }, 15 | { 16 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 17 | "weight_decay": 0.0, 18 | }, 19 | ] 20 | 21 | optim = AdamW( 22 | optimizer_grouped_parameters, 23 | lr=training_args.learning_rate, 24 | # weight_decay=training_args.weight_decay 25 | ) 26 | lr_scheduler = get_scheduler( 27 | training_args.lr_scheduler_type, 28 | optimizer=optim, 29 | # num_warmup_steps=training_args.num_updating_warmup_steps_aggr_devices, 30 | # num_training_steps=training_args.num_updating_steps_aggr_devices, 31 | num_warmup_steps=training_args.num_updating_warmup_steps, 32 | num_training_steps=training_args.num_updating_steps, 33 | ) 34 | return optim, lr_scheduler 35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /sft/utils/states.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import json 4 | from dataclasses import dataclass 5 | import random 6 | import math 7 | import numpy as np 8 | from accelerate import Accelerator 9 | 10 | 11 | def set_deepspeed_config(accelerator: Accelerator, training_args: dataclass): 12 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 13 | accelerator.state.deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = training_args.per_device_train_batch_size 14 | accelerator.state.deepspeed_plugin.deepspeed_config['train_batch_size'] = training_args.per_device_train_batch_size * world_size * accelerator.gradient_accumulation_steps 15 | 16 | 17 | def set_training_states(data_module: dict, training_args: dataclass): 18 | set_num_steps_per_epoch(data_module, training_args) 19 | set_num_training_steps(training_args) 20 | set_num_updating_steps(training_args) 21 | set_num_eval_steps(training_args) 22 | set_per_eval_steps(training_args) 23 | set_num_warmup_steps(training_args) 24 | 25 | set_num_logging_steps(training_args) 26 | set_per_save_steps(training_args) 27 | 28 | print(f"+ [Training States] There are {training_args.num_training_steps} steps in total.") 29 | 30 | 31 | def set_num_steps_per_epoch(data_module: dict, training_args: dataclass): 32 | num_devices = int(os.environ.get("WORLD_SIZE", 1)) 33 | 34 | len_train_set_per_device = math.ceil(len(data_module["train_dataset"]) / num_devices) 35 | num_train_steps_per_device = math.ceil(len_train_set_per_device / training_args.per_device_train_batch_size) 36 | 
num_updating_steps_per_epoch = num_train_steps_per_device // training_args.gradient_accumulation_steps 37 | 38 | len_eval_set_per_device = math.ceil(len(data_module["val_dataset"]) / num_devices) if data_module["val_dataset"] is not None else None 39 | num_eval_steps_per_device = math.ceil(len_eval_set_per_device / training_args.per_device_eval_batch_size) if data_module["val_dataset"] is not None else None 40 | 41 | training_args.num_training_steps_per_epoch = num_train_steps_per_device 42 | training_args.num_updating_steps_per_epoch = num_updating_steps_per_epoch 43 | training_args.num_eval_steps_per_epoch = num_eval_steps_per_device 44 | 45 | def set_num_training_steps(training_args: dataclass): 46 | if training_args.max_steps != -1: 47 | num_training_steps = training_args.max_steps 48 | else: 49 | assert training_args.num_train_epoches != -1 50 | num_training_steps = training_args.num_training_steps_per_epoch * training_args.num_train_epoches 51 | num_training_steps_aggr_devices = num_training_steps * int(os.environ.get("WORLD_SIZE", 1)) 52 | 53 | training_args.num_training_steps = num_training_steps 54 | training_args.num_training_steps_aggr_devices = num_training_steps_aggr_devices 55 | 56 | def set_num_updating_steps(training_args: dataclass): 57 | num_updating_steps = training_args.num_training_steps // training_args.gradient_accumulation_steps 58 | num_updating_steps_aggr_devices = num_updating_steps * int(os.environ.get("WORLD_SIZE", 1)) 59 | 60 | training_args.num_updating_steps = num_updating_steps 61 | training_args.num_updating_steps_aggr_devices = num_updating_steps_aggr_devices 62 | 63 | 64 | def set_num_eval_steps(training_args: dataclass): 65 | training_args.num_eval_steps = training_args.num_eval_steps_per_epoch 66 | 67 | def set_per_eval_steps(training_args: dataclass): 68 | if training_args.eval_steps != -1: 69 | per_eval_steps = training_args.eval_steps 70 | else: 71 | assert training_args.eval_epoches != -1 72 | per_eval_steps = training_args.num_training_steps_per_epoch * training_args.eval_epoches 73 | 74 | training_args.per_eval_steps = per_eval_steps 75 | 76 | def set_num_warmup_steps(training_args: dataclass): 77 | # if training_args.warmup_steps != -1: 78 | # num_warmup_steps_forward = training_args.warmup_steps 79 | # else: 80 | # assert training_args.warmup_ratio != -1 81 | # num_warmup_steps_forward = int(training_args.num_training_steps * training_args.warmup_ratio) 82 | # num_updating_warmup_steps = num_warmup_steps_forward // training_args.gradient_accumulation_steps 83 | # num_updating_warmup_steps_aggr_devices = num_updating_warmup_steps * int(os.environ.get("WORLD_SIZE", 1)) 84 | if training_args.warmup_steps != -1: 85 | num_updating_warmup_steps = training_args.warmup_steps 86 | else: 87 | assert training_args.warmup_ratio != -1 88 | num_updating_warmup_steps = int(training_args.num_updating_steps * training_args.warmup_ratio) 89 | num_updating_warmup_steps_aggr_devices = num_updating_warmup_steps * int(os.environ.get("WORLD_SIZE", 1)) 90 | 91 | training_args.num_updating_warmup_steps = num_updating_warmup_steps 92 | training_args.num_updating_warmup_steps_aggr_devices = num_updating_warmup_steps_aggr_devices 93 | 94 | def set_num_logging_steps(training_args: dataclass): 95 | if training_args.logging_steps != -1: 96 | num_logging_steps = training_args.logging_steps 97 | else: 98 | assert training_args.logging_epoches != -1 99 | num_logging_steps = training_args.num_training_steps_per_epoch * training_args.logging_epoches 100 | 101 | 
training_args.num_logging_steps = num_logging_steps 102 | 103 | def set_per_save_steps(training_args: dataclass): 104 | if training_args.save_steps != -1: 105 | per_save_steps = training_args.save_steps 106 | else: 107 | assert training_args.save_epoches != -1 108 | per_save_steps = training_args.num_training_steps_per_epoch * training_args.save_epoches 109 | 110 | training_args.per_save_steps = per_save_steps 111 | 112 | 113 | def set_random_seed(seed: int): 114 | random.seed(seed) 115 | np.random.seed(seed) 116 | torch.manual_seed(seed) 117 | torch.cuda.manual_seed_all(seed) 118 | torch.cuda.manual_seed(seed) 119 | 120 | 121 | -------------------------------------------------------------------------------- /sft/utils/verifier_models.py: -------------------------------------------------------------------------------- 1 | from utils.models import wrapper_safe_save_model_with_accelerator, wrapper_save_checkpoint, wrapper_save_best_checkpoint, build_model, load_model 2 | from utils.sampling import shift_padding_to_left_2D, shift_padding_to_right_2D, find_rightmost_notpadded_positions 3 | from utils.constants import IGNORE_INDEX 4 | 5 | 6 | from typing import Optional, List, Dict, Set, Any, Union, Callable, Mapping 7 | import transformers 8 | from transformers.generation.utils import ModelOutput 9 | from torch import nn 10 | import torch.nn.functional as F 11 | import torch 12 | import pathlib 13 | import logging 14 | from dataclasses import dataclass, field 15 | from accelerate import Accelerator 16 | import os 17 | import re 18 | import shutil 19 | 20 | 21 | @dataclass 22 | class VerifierModelOutput(ModelOutput): 23 | loss: Optional[torch.FloatTensor] = None 24 | v_scores: torch.FloatTensor = None 25 | all_losses: Optional[Dict[str, torch.FloatTensor]] = None 26 | 27 | 28 | class Verifier(nn.Module): 29 | def __init__(self, backbone, checkpoint_dir=None): 30 | super(Verifier, self).__init__() 31 | self.backbone = backbone 32 | 33 | self.gain = nn.Parameter(torch.randn(1,)) 34 | self.bias = nn.Parameter(torch.randn(1,)) 35 | self.dropout = nn.Dropout(p=0.2) 36 | self.vscore_head = nn.Linear(self.backbone.get_input_embeddings().embedding_dim, 1, bias=False) 37 | 38 | if checkpoint_dir and os.path.exists(os.path.join(checkpoint_dir, 'verifier.pth')): 39 | verifier_params = torch.load(os.path.join(checkpoint_dir, 'verifier.pth')) 40 | self.load_state_dict(verifier_params, strict=False) 41 | else: 42 | self.init_head_params() 43 | 44 | self.pad_token_id = backbone.config.pad_token_id 45 | 46 | def init_head_params(self): 47 | output_embeddings = self.backbone.get_output_embeddings().weight.data 48 | output_embeddings_avg = output_embeddings.mean(dim=0, keepdim=True) 49 | 50 | self.vscore_head.weight = nn.Parameter(output_embeddings_avg) 51 | 52 | def loss_fct(self, v_scores: torch.FloatTensor, v_labels: torch.LongTensor): 53 | # (batch_size, n_seq, 1) 54 | return mse_loss_with_mask(v_scores.squeeze(), v_labels.type_as(v_scores)) 55 | 56 | def transform(self, last_hidden_states): 57 | return self.gain * last_hidden_states + self.bias 58 | 59 | def forward(self, 60 | input_ids: torch.LongTensor, 61 | attention_mask: Optional[torch.Tensor] = None, 62 | position_ids: Optional[torch.LongTensor] = None, 63 | past_key_values: Optional[List[torch.FloatTensor]] = None, 64 | labels: Optional[torch.LongTensor] = None, 65 | v_labels: Optional[torch.LongTensor] = None, 66 | output_all_losses: Optional[bool] = None, 67 | ): 68 | outputs = self.backbone( 69 | input_ids=input_ids, 70 | 
attention_mask=attention_mask, 71 | position_ids=position_ids, 72 | past_key_values=past_key_values, 73 | labels=labels, 74 | use_cache=False, 75 | output_hidden_states=True, 76 | return_dict=True, 77 | ) 78 | llm_logits = outputs.logits 79 | llm_loss = outputs.loss 80 | llm_hidden_states = outputs.hidden_states 81 | 82 | # (batch_size, n_seq, embed_dim) 83 | v_hidden_states = self.transform(llm_hidden_states[-1]) 84 | # (batch_size, n_seq, 1) 85 | v_scores = self.vscore_head(self.dropout(v_hidden_states)) 86 | 87 | v_loss, loss = None, None 88 | if v_labels is not None: 89 | v_loss = self.loss_fct(v_scores, v_labels) 90 | loss = v_loss + (llm_loss if labels is not None else 0) 91 | 92 | all_losses = None 93 | if output_all_losses: 94 | all_losses = {'llm_loss': llm_loss, 'v_loss': v_loss} 95 | 96 | return VerifierModelOutput( 97 | loss=loss, 98 | v_scores=v_scores, 99 | all_losses=all_losses, 100 | ) 101 | 102 | @torch.inference_mode(mode=True) 103 | def scoring_sequences(self, input_ids: torch.LongTensor): 104 | input_ids = shift_padding_to_right_2D(input_ids, pad_value=self.pad_token_id) 105 | outputs = self( 106 | input_ids=input_ids, 107 | attention_mask=input_ids.ne(self.pad_token_id), 108 | ) 109 | inds = find_rightmost_notpadded_positions(input_ids, pad_value=self.pad_token_id) 110 | return outputs.v_scores[:, :, -1].gather(1, inds.view(-1, 1)).squeeze(-1) 111 | 112 | def gradient_checkpointing_enable(self): 113 | self.backbone.gradient_checkpointing_enable() 114 | 115 | def gradient_checkpointing_disable(self): 116 | self.backbone.gradient_checkpointing_disable() 117 | 118 | 119 | def mse_loss_with_mask(scores: torch.FloatTensor, labels: torch.FloatTensor): 120 | scores = torch.where(labels.ne(IGNORE_INDEX), scores, 0) 121 | labels = torch.where(labels.ne(IGNORE_INDEX), labels, 0) 122 | return F.mse_loss(scores, labels, reduction='sum') / scores.shape[0] 123 | 124 | 125 | 126 | @wrapper_safe_save_model_with_accelerator 127 | def save_verifier(accelerator: Accelerator, 128 | model: transformers.AutoModelForCausalLM, 129 | cpu_state_dict: Mapping, 130 | output_dir: str): 131 | cpu_state_dict_backbone = { 132 | k.split('backbone.')[1]: v 133 | for k, v in cpu_state_dict.items() if k.startswith('backbone') 134 | } 135 | cpu_state_dict_verifier = { 136 | k: v 137 | for k, v in cpu_state_dict.items() if not k.startswith('backbone') 138 | } 139 | accelerator.unwrap_model(model).backbone.save_pretrained( 140 | output_dir, 141 | state_dict=cpu_state_dict_backbone, 142 | is_main_process=accelerator.is_main_process, 143 | save_function=accelerator.save, 144 | ) 145 | accelerator.save(cpu_state_dict_verifier, os.path.join(output_dir, 'verifier.pth')) 146 | 147 | 148 | @wrapper_save_checkpoint(save_func=save_verifier) 149 | def save_verifier_checkpoint(accelerator: Accelerator, 150 | model: transformers.AutoModelForCausalLM, 151 | tokenizer: transformers.PreTrainedTokenizer, 152 | checkpoint_output_dir: str): 153 | ... 154 | 155 | 156 | @wrapper_save_best_checkpoint(save_checkpoint_func=save_verifier_checkpoint) 157 | def save_best_verifier_checkpoint(accelerator: Accelerator, 158 | model: transformers.AutoModelForCausalLM, 159 | tokenizer: transformers.PreTrainedTokenizer, 160 | output_dir: str, 161 | global_step: int, 162 | save_total_limit: int=None): 163 | ... 
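# ---------------------------------------------------------------------------
# Minimal, self-contained sanity check for mse_loss_with_mask above. This is
# an illustration only (it runs solely when the file is executed directly,
# and the demo_* tensors are made up): positions whose label equals
# IGNORE_INDEX contribute nothing to the squared error, and the summed error
# is divided by the batch size (scores.shape[0]) rather than by the number of
# unmasked tokens.
if __name__ == "__main__":
    demo_scores = torch.tensor([[0.8, 0.2, 0.5],
                                [0.1, 0.9, 0.4]])
    demo_labels = torch.tensor([[1.0, 0.0, float(IGNORE_INDEX)],
                                [0.0, 1.0, float(IGNORE_INDEX)]])
    # ((0.8 - 1)^2 + (0.2 - 0)^2 + (0.1 - 0)^2 + (0.9 - 1)^2) / 2 = 0.05
    print(mse_loss_with_mask(demo_scores, demo_labels))  # tensor(0.0500)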
164 | 165 | 166 | 167 | 168 | def build_verifier(model_args: dataclass, training_args: dataclass): 169 | backbone, tokenizer = build_model(model_args, training_args) 170 | return Verifier(backbone), tokenizer 171 | 172 | 173 | def load_verifier(model_args: dataclass): 174 | backbone, tokenizer = load_model(model_args) 175 | return Verifier(backbone, checkpoint_dir=model_args.model_name_or_path), tokenizer 176 | 177 | 178 | def load_generator_and_verifier(model_args: dataclass): 179 | generator, tokenizer = load_model(model_args) 180 | 181 | v_backbone = transformers.AutoModelForCausalLM.from_pretrained( 182 | model_args.verifier_model_name_or_path, 183 | torch_dtype=torch.float16 if model_args.fp16 else torch.bfloat16, 184 | ) 185 | 186 | verifier = Verifier(v_backbone, checkpoint_dir=model_args.verifier_model_name_or_path) 187 | return generator, verifier, tokenizer 188 | 189 | 190 | 191 | --------------------------------------------------------------------------------