├── README.md
├── data
│   ├── GSM8K_test.jsonl
│   ├── GSM8K_train.jsonl
│   ├── MATH_test.jsonl
│   └── MATH_train.jsonl
├── dpo
│   ├── alignment
│   │   ├── __init__.py
│   │   ├── configs.py
│   │   ├── data.py
│   │   ├── model_utils.py
│   │   └── release.py
│   ├── custom_trainer.py
│   ├── prepare_dataset.py
│   ├── run_dpo.py
│   └── run_kto.py
├── eval
│   ├── gsm8k
│   │   ├── eval_gsm8k.py
│   │   ├── run_gsm8k_eval.sh
│   │   └── utils_ans.py
│   └── math
│       ├── configs
│       │   ├── few_shot_test_configs.json
│       │   └── zero_shot_test_configs.json
│       ├── data_processing
│       │   ├── __pycache__
│       │   │   ├── answer_extraction.cpython-310.pyc
│       │   │   ├── answer_extraction.cpython-311.pyc
│       │   │   └── process_utils.cpython-310.pyc
│       │   ├── answer_extraction.py
│       │   └── process_utils.py
│       ├── eval
│       │   ├── __pycache__
│       │   │   ├── eval_script.cpython-310.pyc
│       │   │   ├── eval_script.cpython-311.pyc
│       │   │   ├── eval_utils.cpython-310.pyc
│       │   │   ├── eval_utils.cpython-311.pyc
│       │   │   ├── ocwcourses_eval_utils.cpython-310.pyc
│       │   │   ├── ocwcourses_eval_utils.cpython-311.pyc
│       │   │   ├── python_executor.cpython-310.pyc
│       │   │   └── utils.cpython-310.pyc
│       │   ├── eval_script.py
│       │   ├── eval_utils.py
│       │   ├── ocwcourses_eval_utils.py
│       │   ├── python_executor.py
│       │   └── utils.py
│       ├── run_cot_eval.py
│       ├── run_math_eval.sh
│       ├── run_subset_parallel.py
│       └── utils.py
├── gen
│   ├── gen_rft_data.py
│   ├── gen_rft_data.sh
│   ├── gen_step_explore.py
│   ├── gen_step_explore.sh
│   ├── get_dpo_data.py
│   ├── get_rft_data.py
│   ├── math_utils
│   │   ├── answer_extraction.py
│   │   ├── eval_script.py
│   │   ├── eval_utils.py
│   │   └── ocwcourses_eval_utils.py
│   └── utils_others.py
├── images
│   ├── main_result_image.png
│   └── overview_image.png
├── requirements.txt
├── scripts
│   ├── gsm8k
│   │   ├── dpo
│   │   │   ├── config.yaml
│   │   │   ├── deepspeed_zero3.yaml
│   │   │   └── run_dpo.sh
│   │   └── sft
│   │       ├── config.yaml
│   │       ├── run_ft.sh
│   │       └── run_rft.sh
│   └── math
│       ├── dpo
│       │   ├── config.yaml
│       │   ├── deepspeed_zero3.yaml
│       │   └── run_dpo.sh
│       └── sft
│           ├── config.yaml
│           ├── run_ft.sh
│           └── run_rft.sh
└── sft
    ├── sft_datasets.py
    ├── train_generator.py
    └── utils
        ├── cached_models.py
        ├── constants.py
        ├── datasets.py
        ├── flash_attn_monkey_patch.py
        ├── gsm8k
        │   ├── __init__.py
        │   ├── decoding.py
        │   └── metrics.py
        ├── metrics.py
        ├── models.py
        ├── optim.py
        ├── sampling.py
        ├── states.py
        └── verifier_models.py
/README.md:
--------------------------------------------------------------------------------
1 | # Self-Explore
2 | #### Self-Explore to avoid the pit! Improving the Reasoning Capabilities of Language Models with Fine-grained Rewards
3 | ---
4 | This is the official GitHub repository for **Self-Explore**.
5 | Paper Link: https://arxiv.org/abs/2404.10346
6 |
7 | ## Overview:
8 | 
9 |
10 | ## Setup
11 |
12 | Run ``pip install -r requirements.txt``
13 | All experiments were carried out using 4 x NVIDIA A100 80GB, with CUDA version 12.0.
14 |
15 |
16 | ## Data
17 |
18 | In the `data` directory, you will find the train and test files for `GSM8K` and `MATH`.
19 |
20 | ## Training
21 |
22 | > #### Stage 1. Run SFT:
23 | Run **SFT** (or FT, in short) to get the base generator.
24 | In `/scripts/{task}/sft/run_ft.sh` you'll find the script necessary for this. (For `data_path`, please put the train file.)
25 | Put the necessary paths to the files and models, then simply run `sh scripts/{task}/sft/run_ft.sh` in the main directory.
26 |
27 | > #### Stage 2. Get RFT Data:
28 | Now you'll need to generate *N* instances per problem.
29 | To do this, go to the `gen` directory and run `sh gen_rft_data.sh`.
30 | This assumes you are using 4 GPUs and generates the predictions in parallel on each GPU.
31 | Once completed, you will see the **RFT** and **DPO** training files.
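For intuition only, here is a minimal sketch of how such sampled completions are typically split into RFT and DPO data: completions whose final answer matches the reference become RFT targets, and any question with both a correct and an incorrect completion yields a chosen/rejected pair. The actual logic lives in `gen/gen_rft_data.py`, `gen/get_rft_data.py`, and `gen/get_dpo_data.py`; all names below are illustrative.

```python
# Illustrative sketch only -- not the repository's actual script.
def split_rft_dpo(samples, extract_answer):
    """samples: list of {"query": str, "answer": str, "preds": [str, ...]} (hypothetical schema)."""
    rft_rows, dpo_rows = [], []
    for s in samples:
        correct = [p for p in s["preds"] if extract_answer(p) == s["answer"]]
        wrong = [p for p in s["preds"] if extract_answer(p) != s["answer"]]
        # RFT data: keep deduplicated correct rationales as new fine-tuning targets.
        for c in dict.fromkeys(correct):
            rft_rows.append({"query": s["query"], "response": c})
        # DPO data: pair one correct with one incorrect completion per question.
        if correct and wrong:
            dpo_rows.append({"prompt": s["query"], "chosen": correct[0], "rejected": wrong[0]})
    return rft_rows, dpo_rows
```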
32 |
33 | > #### Stage 3. Run RFT:
34 | Run **RFT** to get the RFT model, which acts our explorer and reference model when training for DPO.
35 | in `/scripts/{task}/sft/run_rft.sh` you'll see the script necessary for this.
36 | Put necessary paths to the files and models then simply run `sh /scripts/{task}/sft/run_rft.sh` in the main directory.
37 |
38 | > #### Stage 4. 🔎 Explore:
39 | To find the first ***pit***, let the RFT model explore from each step within the rejected sample.
40 | You can do this by running `gen_step_explore.sh` in the `gen` directory. (For `data_path` here, please put the DPO file generated earlier.)
41 | Then you will get a file whose name ends in `gpair_{k}.jsonl`,
42 | which is your fine-grained pairwise training data.
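Conceptually, the exploration works roughly as sketched below, assuming steps are newline-separated; the real implementation is `gen/gen_step_explore.py`, and `sample_k_continuations` is a hypothetical helper standing in for the actual sampling code.

```python
# Conceptual sketch of "find the first pit" -- not the repository's actual script.
def first_pit(question, rejected, answer, sample_k_continuations, extract_answer, k=4):
    """rejected: a full rejected rationale whose steps are assumed to be newline-separated."""
    steps = rejected.split("\n")
    prefix = ""
    for i, step in enumerate(steps):
        candidate = prefix + step + "\n"
        # Let the RFT model explore k completions from this partial solution (hypothetical helper).
        continuations = sample_k_continuations(question, candidate, k)
        if not any(extract_answer(c) == answer for c in continuations):
            # No continuation recovers the correct answer: step i is the first pit.
            return i, prefix
        prefix = candidate
    return None, prefix  # no pit found within the rejected rationale
```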
43 |
44 | > #### Stage 5. Train with Preference Learning Objective:
45 | You can apply any preference learning objective; in our work, we chose **DPO (Direct Preference Optimization)**.
46 | To do this refer to `scripts/{task}/dpo/run_dpo.sh`.
47 | - To run with the outcome-supervision labels, set the training data as the DPO file generated in Stage 2.
48 | - To run with the step-level fine-grained labels (ours), set the training data as the gpair file generated in Stage 4. (See the example record below.)
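Whichever file you use, `dpo/prepare_dataset.py` expects JSONL where each line carries `prompt` (or `input`), `chosen`, and `rejected` fields. A minimal record with illustrative content might look like this:

```python
import json

# Placeholder record illustrating the expected fields (content is for illustration only).
example = {
    "prompt": "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether?",
    "chosen": "She sold 48 / 2 = 24 clips in May.\nIn total she sold 48 + 24 = 72 clips.\nThe answer is 72",
    "rejected": "She sold 48 - 2 = 46 clips in May.\nIn total she sold 48 + 46 = 94 clips.\nThe answer is 94",
}
print(json.dumps(example))
```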
49 |
50 | ## Evaluation
51 |
52 | Under the `eval/{task}` directory, you'll find the scripts needed for running evaluation.
53 |
54 | ## Results
55 | 
56 |
57 | ## Models
58 | We release our best DeepSeek-Math checkpoints trained on **GSM8K** and **MATH** on HuggingFace.
59 | | Model | Accuracy | Download |
60 | | :----------------------- | :-------------: | :----------------------------------------------------------: |
61 | |**DeepSeek_Math_Self_Explore_GSM8K** | 78.62 | 🤗 [HuggingFace](https://huggingface.co/hbin0701/DeepSeek_MATH_Self_Explore) |
62 | |**DeepSeek_Math_Self_Explore_MATH** | 37.68 | 🤗 [HuggingFace](https://huggingface.co/hbin0701/DeepSeek_GSM8K_Self_Exploret) |
63 |
64 | ## Acknowledgements
65 | Our evaluation code is borrowed from:
66 | - GSM8K: [OVM](https://github.com/FreedomIntelligence/OVM)
67 | - MATH: [DeepSeek-Math](https://github.com/deepseek-ai/DeepSeek-Math)
68 |
69 | ## Citation
70 | ```
71 | @misc{hwang2024selfexplore,
72 | title={Self-Explore to Avoid the Pit: Improving the Reasoning Capabilities of Language Models with Fine-grained Rewards},
73 | author={Hyeonbin Hwang and Doyoung Kim and Seungone Kim and Seonghyeon Ye and Minjoon Seo},
74 | year={2024},
75 | eprint={2404.10346},
76 | archivePrefix={arXiv},
77 | primaryClass={cs.CL}
78 | }
79 | ```
80 |
--------------------------------------------------------------------------------
/dpo/alignment/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.2.0.dev0"
2 |
3 | from .configs import DataArguments, DPOConfig, H4ArgumentParser, ModelArguments, SFTConfig
4 | from .data import apply_chat_template, get_datasets
5 | from .model_utils import get_kbit_device_map, get_peft_config, get_quantization_config, get_tokenizer, is_adapter_model
6 |
--------------------------------------------------------------------------------
/dpo/alignment/data.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # coding=utf-8
3 | # Copyright 2023 The HuggingFace Team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import re
18 | from typing import List, Literal, Optional
19 |
20 | from datasets import DatasetDict, concatenate_datasets, load_dataset
21 |
22 | from .configs import DataArguments
23 |
24 |
25 | DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
26 |
27 |
28 | def apply_chat_template(
29 | example, tokenizer, task: Literal["sft", "generation", "rm", "dpo"] = "sft", assistant_prefix="<|assistant|>\n"
30 | ):
31 | def _strip_prefix(s, pattern):
32 | # Use re.escape to escape any special characters in the pattern
33 | return re.sub(f"^{re.escape(pattern)}", "", s)
34 |
35 | if task in ["sft", "generation"]:
36 | messages = example["messages"]
37 | # We add an empty system message if there is none
38 | if messages[0]["role"] != "system":
39 | messages.insert(0, {"role": "system", "content": ""})
40 | example["text"] = tokenizer.apply_chat_template(
41 | messages, tokenize=False, add_generation_prompt=True if task == "generation" else False
42 | )
43 | elif task == "rm":
44 | if all(k in example.keys() for k in ("chosen", "rejected")):
45 | chosen_messages = example["chosen"]
46 | rejected_messages = example["rejected"]
47 | # We add an empty system message if there is none
48 | if chosen_messages[0]["role"] != "system":
49 | chosen_messages.insert(0, {"role": "system", "content": ""})
50 | if rejected_messages[0]["role"] != "system":
51 | rejected_messages.insert(0, {"role": "system", "content": ""})
52 | example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
53 | example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False)
54 | else:
55 | raise ValueError(
56 | f"Could not format example as dialogue for `rm` task! Require `[chosen, rejected]` keys but found {list(example.keys())}"
57 | )
58 | elif task == "dpo":
59 | if all(k in example.keys() for k in ("chosen", "rejected")):
60 | # Compared to reward modeling, we filter out the prompt, so the text is everything after the last assistant token
61 | prompt_messages = [[msg for msg in example["chosen"] if msg["role"] == "user"][0]]
62 | # Insert system message
63 | if example["chosen"][0]["role"] != "system":
64 | prompt_messages.insert(0, {"role": "system", "content": ""})
65 | else:
66 | prompt_messages.insert(0, example["chosen"][0])
67 | # TODO: handle case where chosen/rejected also have system messages
68 | chosen_messages = example["chosen"][1:]
69 | rejected_messages = example["rejected"][1:]
70 | example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
71 | example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False)
72 | example["text_prompt"] = tokenizer.apply_chat_template(
73 | prompt_messages, tokenize=False, add_generation_prompt=True
74 | )
75 |
76 | example["text_chosen"] = _strip_prefix(example["text_chosen"], assistant_prefix)
77 | example["text_rejected"] = _strip_prefix(example["text_rejected"], assistant_prefix)
78 | else:
79 | raise ValueError(
80 | f"Could not format example as dialogue for `dpo` task! Require `[chosen, rejected]` keys but found {list(example.keys())}"
81 | )
82 | return example
83 |
84 |
85 | def get_datasets(
86 | data_config,
87 | splits: List[str] = ["train", "test"],
88 | shuffle: bool = True,
89 | ) -> DatasetDict:
90 | """
91 | Loads one or more datasets with varying training set proportions.
92 |
93 | Args:
94 | data_config (`DataArguments` or `dict`):
95 | Dataset configuration and split proportions.
96 | splits (`List[str]`, *optional*, defaults to `['train', 'test']`):
97 | Dataset splits to load and mix. Assumes the splits exist in all datasets and have a `train_` or `test_` prefix.
98 | shuffle (`bool`, *optional*, defaults to `True`):
99 | Whether to shuffle the training data.
100 |
101 | Returns
102 | [`DatasetDict`]: The dataset dictionary containing the loaded datasets.
103 | """
104 |
105 | if type(data_config) is DataArguments:
106 | # Structure of the config to read the datasets and their mix
107 | # datasets_mixer:
108 | # - 'dataset1': 0.5
109 | # - 'dataset2': 0.3
110 | # - 'dataset3': 0.2
111 | dataset_mixer = data_config.dataset_mixer
112 | elif type(data_config) is dict:
113 | # Structure of the input is:
114 | # dataset_mixer = {
115 | # "dataset1": 0.5,
116 | # "dataset1": 0.3,
117 | # "dataset1": 0.2,
118 | # }
119 | dataset_mixer = data_config
120 | else:
121 | raise ValueError(f"Data config {data_config} not recognized.")
122 |
123 | raw_datasets = mix_datasets(dataset_mixer, splits=splits, shuffle=shuffle)
124 | return raw_datasets
125 |
126 |
127 | def mix_datasets(dataset_mixer: dict, splits: Optional[List[str]] = None, shuffle=True) -> DatasetDict:
128 | """
129 | Loads and mixes datasets according to proportions specified in `dataset_mixer`.
130 |
131 | Args:
132 | dataset_mixer (`dict`):
133 | Dictionary containing the dataset names and their training proportions. By default, all test proportions are 1.
134 | splits (Optional[List[str]], *optional*, defaults to `None`):
135 | Dataset splits to load and mix. Assumes the splits exist in all datasets and have a `train_` or `test_` prefix.
136 | shuffle (`bool`, *optional*, defaults to `True`):
137 | Whether to shuffle the training data.
138 | """
139 | raw_datasets = DatasetDict()
140 | raw_train_datasets = []
141 | raw_val_datasets = []
142 | fracs = []
143 | for ds, frac in dataset_mixer.items():
144 | fracs.append(frac)
145 | for split in splits:
146 | if "train" in split:
147 | raw_train_datasets.append(
148 | load_dataset(
149 | ds,
150 | split=split,
151 | )
152 | )
153 | elif "test" in split:
154 | raw_val_datasets.append(
155 | load_dataset(
156 | ds,
157 | split=split,
158 | )
159 | )
160 | else:
161 | raise ValueError(f"Split type {split} not recognized as one of test or train.")
162 |
163 | if any(frac < 0 for frac in fracs):
164 | raise ValueError("Dataset fractions cannot be negative.")
165 |
166 | if len(raw_train_datasets) > 0:
167 | train_subsets = []
168 | for dataset, frac in zip(raw_train_datasets, fracs):
169 | train_subset = dataset.select(range(int(frac * len(dataset))))
170 | train_subsets.append(train_subset)
171 | if shuffle:
172 | raw_datasets["train"] = concatenate_datasets(train_subsets).shuffle(seed=42)
173 | else:
174 | raw_datasets["train"] = concatenate_datasets(train_subsets)
175 | # No subsampling for test datasets to enable fair comparison across models
176 | if len(raw_val_datasets) > 0:
177 | if shuffle:
178 | raw_datasets["test"] = concatenate_datasets(raw_val_datasets).shuffle(seed=42)
179 | else:
180 | raw_datasets["test"] = concatenate_datasets(raw_val_datasets)
181 |
182 | if len(raw_datasets) == 0:
183 | raise ValueError(
184 | f"Dataset {dataset_mixer} not recognized with split {split}. Check the dataset has been correctly formatted."
185 | )
186 |
187 | return raw_datasets
188 |
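A minimal usage sketch for `get_datasets` (the dataset names below are placeholders, and the datasets are assumed to exist on the Hugging Face Hub with `train`/`test` splits):

```python
# Hypothetical usage of get_datasets; "org/dataset1" and "org/dataset2" are placeholders.
from alignment import get_datasets

dataset_mixer = {
    "org/dataset1": 0.5,  # use 50% of this dataset's train split
    "org/dataset2": 1.0,  # use all of this one
}
raw_datasets = get_datasets(dataset_mixer, splits=["train", "test"])
print(raw_datasets)
```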
--------------------------------------------------------------------------------
/dpo/alignment/model_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # coding=utf-8
3 | # Copyright 2023 The HuggingFace Team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | from typing import Dict
18 |
19 | import torch
20 | from transformers import AutoTokenizer, BitsAndBytesConfig, PreTrainedTokenizer
21 |
22 | from accelerate import Accelerator
23 | from huggingface_hub import list_repo_files
24 | from peft import LoraConfig, PeftConfig
25 |
26 | from .configs import DataArguments, ModelArguments
27 | from .data import DEFAULT_CHAT_TEMPLATE
28 |
29 |
30 | def get_current_device() -> int:
31 | """Get the current device. For GPU we return the local process index to enable multiple GPU training."""
32 | return Accelerator().local_process_index if torch.cuda.is_available() else "cpu"
33 |
34 |
35 | def get_kbit_device_map():
36 | """Useful for running inference with quantized models by setting `device_map=get_peft_device_map()`"""
37 | return {"": get_current_device()} if torch.cuda.is_available() else None
38 |
39 |
40 | def get_quantization_config(model_args):
41 | if model_args.load_in_4bit:
42 | quantization_config = BitsAndBytesConfig(
43 | load_in_4bit=True,
44 | bnb_4bit_compute_dtype=torch.float16, # For consistency with model weights, we use the same value as `torch_dtype` which is float16 for PEFT models
45 | bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
46 | bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
47 | )
48 | elif model_args.load_in_8bit:
49 | quantization_config = BitsAndBytesConfig(
50 | load_in_8bit=True,
51 | )
52 | else:
53 | quantization_config = None
54 |
55 | return quantization_config
56 |
57 |
58 | def get_tokenizer(model_args: ModelArguments, data_args: DataArguments) -> PreTrainedTokenizer:
59 | """Get the tokenizer for the model."""
60 | tokenizer = AutoTokenizer.from_pretrained(
61 | model_args.model_name_or_path,
62 | revision=model_args.model_revision,
63 | )
64 | if tokenizer.pad_token_id is None:
65 | tokenizer.pad_token_id = tokenizer.eos_token_id
66 |
67 | if data_args.truncation_side is not None:
68 | tokenizer.truncation_side = data_args.truncation_side
69 |
70 | # Set reasonable default for models without max length
71 | if tokenizer.model_max_length > 100_000:
72 | tokenizer.model_max_length = 2048
73 |
74 | if data_args.chat_template is not None:
75 | tokenizer.chat_template = data_args.chat_template
76 | elif tokenizer.chat_template is None:
77 | tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE
78 |
79 | return tokenizer
80 |
81 |
82 | def get_peft_config(model_args: ModelArguments):
83 | if model_args.use_peft is False:
84 | return None
85 |
86 | peft_config = LoraConfig(
87 | r=model_args.lora_r,
88 | lora_alpha=model_args.lora_alpha,
89 | lora_dropout=model_args.lora_dropout,
90 | bias="none",
91 | task_type="CAUSAL_LM",
92 | target_modules=model_args.lora_target_modules,
93 | modules_to_save=model_args.lora_modules_to_save,
94 | )
95 |
96 | return peft_config
97 |
98 |
99 | def is_adapter_model(model_name_or_path: str, revision: str = "main") -> bool:
100 | repo_files = list_repo_files(model_name_or_path, revision=revision)
101 | return "adapter_model.safetensors" in repo_files or "adapter_model.bin" in repo_files
102 |
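A small usage sketch for the helpers above. The real `ModelArguments`/`DataArguments` dataclasses live in `configs.py` (not shown here), so this uses `SimpleNamespace` stand-ins with only the fields these helpers read; the checkpoint name is a placeholder.

```python
# Hypothetical wiring; SimpleNamespace stands in for the real ModelArguments / DataArguments.
from types import SimpleNamespace
from alignment import get_tokenizer, get_quantization_config, get_peft_config

model_args = SimpleNamespace(
    model_name_or_path="deepseek-ai/deepseek-math-7b-base",  # placeholder checkpoint
    model_revision="main",
    load_in_4bit=False,
    load_in_8bit=False,
    use_peft=False,
)
data_args = SimpleNamespace(truncation_side=None, chat_template=None)

tokenizer = get_tokenizer(model_args, data_args)
quantization_config = get_quantization_config(model_args)  # None: no k-bit loading requested
peft_config = get_peft_config(model_args)                  # None: use_peft is False
```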
--------------------------------------------------------------------------------
/dpo/alignment/release.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The HuggingFace Team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import argparse
17 | import re
18 |
19 | import packaging.version
20 |
21 |
22 | REPLACE_PATTERNS = {
23 | "init": (re.compile(r'^__version__\s+=\s+"([^"]+)"\s*$', re.MULTILINE), '__version__ = "VERSION"\n'),
24 | "setup": (re.compile(r'^(\s*)version\s*=\s*"[^"]+",', re.MULTILINE), r'\1version="VERSION",'),
25 | }
26 | REPLACE_FILES = {
27 | "init": "src/alignment/__init__.py",
28 | "setup": "setup.py",
29 | }
30 | README_FILE = "README.md"
31 |
32 |
33 | def update_version_in_file(fname, version, pattern):
34 | """Update the version in one file using a specific pattern."""
35 | with open(fname, "r", encoding="utf-8", newline="\n") as f:
36 | code = f.read()
37 | re_pattern, replace = REPLACE_PATTERNS[pattern]
38 | replace = replace.replace("VERSION", version)
39 | code = re_pattern.sub(replace, code)
40 | with open(fname, "w", encoding="utf-8", newline="\n") as f:
41 | f.write(code)
42 |
43 |
44 | def global_version_update(version, patch=False):
45 | """Update the version in all needed files."""
46 | for pattern, fname in REPLACE_FILES.items():
47 | update_version_in_file(fname, version, pattern)
48 |
49 |
50 | def get_version():
51 | """Reads the current version in the __init__."""
52 | with open(REPLACE_FILES["init"], "r") as f:
53 | code = f.read()
54 | default_version = REPLACE_PATTERNS["init"][0].search(code).groups()[0]
55 | return packaging.version.parse(default_version)
56 |
57 |
58 | def pre_release_work(patch=False):
59 | """Do all the necessary pre-release steps."""
60 | # First let's get the default version: base version if we are in dev, bump minor otherwise.
61 | default_version = get_version()
62 | if patch and default_version.is_devrelease:
63 | raise ValueError("Can't create a patch version from the dev branch, checkout a released version!")
64 | if default_version.is_devrelease:
65 | default_version = default_version.base_version
66 | elif patch:
67 | default_version = f"{default_version.major}.{default_version.minor}.{default_version.micro + 1}"
68 | else:
69 | default_version = f"{default_version.major}.{default_version.minor + 1}.0"
70 |
71 | # Now let's ask nicely if that's the right one.
72 | version = input(f"Which version are you releasing? [{default_version}]")
73 | if len(version) == 0:
74 | version = default_version
75 |
76 | print(f"Updating version to {version}.")
77 | global_version_update(version, patch=patch)
78 |
79 |
80 | def post_release_work():
81 | """Do all the necessary post-release steps."""
82 | # First let's get the current version
83 | current_version = get_version()
84 | dev_version = f"{current_version.major}.{current_version.minor + 1}.0.dev0"
85 | current_version = current_version.base_version
86 |
87 | # Check with the user we got that right.
88 | version = input(f"Which version are we developing now? [{dev_version}]")
89 | if len(version) == 0:
90 | version = dev_version
91 |
92 | print(f"Updating version to {version}.")
93 | global_version_update(version)
94 |
95 |
96 | if __name__ == "__main__":
97 | parser = argparse.ArgumentParser()
98 | parser.add_argument("--post_release", action="store_true", help="Whether this is pre or post release.")
99 | parser.add_argument("--patch", action="store_true", help="Whether or not this is a patch release.")
100 | args = parser.parse_args()
101 | if not args.post_release:
102 | pre_release_work(patch=args.patch)
103 | elif args.patch:
104 | print("Nothing to do after a patch :-)")
105 | else:
106 | post_release_work()
107 |
--------------------------------------------------------------------------------
/dpo/prepare_dataset.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import json
3 | from datasets import Dataset, DatasetDict
4 |
5 |
6 | def get_dataset(path_train, path_test):
7 |
8 | final_dict = {}
9 |
10 | for idx, path in enumerate([path_train, path_test]):
11 | li = [json.loads(x) for x in open(path)]
12 |
13 | text_prompts = []
14 | chosen = []
15 | rejected = []
16 |
17 | # If test (i.e. validation set), just use dummy data: subset of training.
18 | if idx == 1:
19 | li = li[:100]
20 |
21 | for elem in li:
22 |
23 | if 'prompt' not in elem.keys():
24 | elem['prompt'] = elem['input']
25 |
26 | text_prompts.append(elem['prompt'].rstrip("\n") + "\n")
27 | chosen.append(elem['chosen'])
28 |
29 | rej = elem['rejected']
30 |
31 | # [Note] Remove last line for rejected sample.
32 | tgt_string = "The answer is"
33 | if tgt_string in rej and len(rej.split("\n")) > 1:
34 | rej = rej[:rej.index(tgt_string)].strip()
35 |
36 | rejected.append(rej)
37 |
38 | d = {"prompt": text_prompts, "text_prompt": text_prompts, "chosen": chosen, "rejected": rejected}
39 |
40 | if idx == 0:
41 | final_dict["train"] = d
42 | else:
43 | final_dict["test"] = d
44 |
45 | train_dataset = Dataset.from_dict(final_dict["train"])
46 | test_dataset = Dataset.from_dict(final_dict["test"])
47 |
48 | print("train example", train_dataset[0])
49 |
50 | # Create a DatasetDict
51 | dataset_dict = DatasetDict({
52 | 'train': train_dataset.shuffle(seed=42),
53 | 'test': test_dataset.shuffle(seed=42)
54 | })
55 |
56 | return dataset_dict
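A quick usage sketch (file paths are placeholders; both arguments point at JSONL files whose lines contain `prompt`/`input`, `chosen`, and `rejected` keys, and the test split is just a dummy subset as noted in the code):

```python
# Hypothetical usage; paths are placeholders.
from prepare_dataset import get_dataset

dataset_dict = get_dataset("path/to/dpo_or_gpair_train.jsonl", "path/to/dpo_or_gpair_train.jsonl")
print(dataset_dict)
print(dataset_dict["train"][0]["prompt"])
```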
--------------------------------------------------------------------------------
/dpo/run_dpo.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import sys
3 | import wandb
4 | import torch
5 | import transformers
6 | from transformers import AutoModelForCausalLM, set_seed, LlamaTokenizer, AutoTokenizer
7 | from accelerate import Accelerator
8 |
9 | from alignment import (
10 | DataArguments,
11 | DPOConfig,
12 | H4ArgumentParser,
13 | ModelArguments,
14 | apply_chat_template,
15 | get_datasets,
16 | get_kbit_device_map,
17 | get_peft_config,
18 | get_quantization_config,
19 | get_tokenizer,
20 | is_adapter_model,
21 | )
22 |
23 |
24 | from peft import PeftConfig, PeftModel
25 | from custom_trainer import DPOTrainer
26 | from prepare_dataset import get_dataset
27 | import os
28 |
29 | os.environ["WANDB_API_KEY"] = "" # PUR YOUR WANDB KEY
30 | os.environ["WANDB_ENTITY"] = "" # PUT YOUR WANDB ID
31 | os.environ["WANDB_PROJECT"] = "" # PUT YOUR WANDB PROJECT
32 |
33 | logger = logging.getLogger(__name__)
34 |
35 |
36 | def main():
37 | parser = H4ArgumentParser((ModelArguments, DataArguments, DPOConfig))
38 | model_args, data_args, training_args = parser.parse()
39 |
40 | #######
41 | # Setup
42 | #######
43 | logging.basicConfig(
44 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
45 | datefmt="%Y-%m-%d %H:%M:%S",
46 | handlers=[logging.StreamHandler(sys.stdout)],
47 | )
48 | log_level = training_args.get_process_log_level()
49 | logger.setLevel(log_level)
50 | transformers.utils.logging.set_verbosity(log_level)
51 | transformers.utils.logging.enable_default_handler()
52 | transformers.utils.logging.enable_explicit_format()
53 |
54 | # Log on each process the small summary:
55 | logger.info(f"Model parameters {model_args}")
56 | logger.info(f"Data parameters {data_args}")
57 | logger.info(f"Training/evaluation parameters {training_args}")
58 |
59 | # Set seed for reproducibility
60 | set_seed(training_args.seed)
61 |
62 | # Increase distributed timeout to 3h to enable push to Hub to complete
63 | accelerator = Accelerator()
64 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
65 |
66 | # There should be "train" and "test" splits, with column names "text_prompt", "text_chosen", "text_rejected".
67 | raw_datasets = get_dataset(data_args.train_data_file, data_args.test_data_file)
68 | print("dataset", raw_datasets)
69 | print(("---" * 30))
70 | torch_dtype = (
71 | model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
72 | )
73 | model_kwargs = dict(
74 | revision=model_args.model_revision,
75 | trust_remote_code=model_args.trust_remote_code,
76 | use_flash_attention_2=model_args.use_flash_attention_2,
77 | torch_dtype=torch_dtype,
78 | use_cache=False if training_args.gradient_checkpointing else True,
79 | device_map=get_kbit_device_map(),
80 | quantization_config=get_quantization_config(model_args),
81 | )
82 |
83 | model = model_args.model_name_or_path
84 | ref_model = model
85 | ref_model_kwargs = model_kwargs
86 |
87 | if model_args.use_peft is True:
88 | ref_model = None
89 | ref_model_kwargs = None
90 |
91 | #########################
92 | # Instantiate DPO trainer
93 | #########################
94 | dpo_trainer = DPOTrainer(
95 | model,
96 | ref_model,
97 | model_init_kwargs=model_kwargs,
98 | ref_model_init_kwargs=ref_model_kwargs,
99 | args=training_args,
100 | beta=training_args.beta,
101 | train_dataset= raw_datasets["train"],
102 | eval_dataset=raw_datasets["test"],
103 | tokenizer=tokenizer,
104 | max_length=training_args.max_length,
105 | max_prompt_length=training_args.max_prompt_length,
106 | peft_config=get_peft_config(model_args),
107 | loss_type=training_args.loss_type
108 | )
109 |
110 | ###############
111 | # Training loop
112 | ###############
113 | train_result = dpo_trainer.train()
114 | metrics = train_result.metrics
115 | max_train_samples = (
116 | data_args.max_train_samples if data_args.max_train_samples is not None else len(raw_datasets["train"])
117 | )
118 | metrics["train_samples"] = min(max_train_samples, len(raw_datasets["train"]))
119 | dpo_trainer.log_metrics("train", metrics)
120 | dpo_trainer.save_metrics("train", metrics)
121 | dpo_trainer.save_state()
122 |
123 | logger.info("*** Training complete ***")
124 |
125 | # Evaluate
126 | if training_args.do_eval:
127 | logger.info("*** Evaluate ***")
128 | metrics = dpo_trainer.evaluate()
129 | max_eval_samples = (
130 | data_args.max_eval_samples if data_args.max_eval_samples is not None else len(raw_datasets["test"])
131 | )
132 | metrics["eval_samples"] = min(max_eval_samples, len(raw_datasets["test"]))
133 | dpo_trainer.log_metrics("eval", metrics)
134 | dpo_trainer.save_metrics("eval", metrics)
135 |
136 | # Save model and create model card
137 | dpo_trainer.save_model(training_args.output_dir)
138 |
139 | # Save everything else on main process
140 | if accelerator.is_main_process:
141 | kwargs = {
142 | "finetuned_from": model_args.model_name_or_path,
143 | "dataset": list(data_args.dataset_mixer.keys()),
144 | "dataset_tags": list(data_args.dataset_mixer.keys()),
145 | "tags": ["alignment-handbook"],
146 | }
147 | dpo_trainer.create_model_card(**kwargs)
148 | # Restore k,v cache for fast inference
149 | dpo_trainer.model.config.use_cache = True
150 | dpo_trainer.model.config.save_pretrained(training_args.output_dir)
151 | if training_args.push_to_hub is True:
152 | dpo_trainer.push_to_hub()
153 |
154 | # Ensure we don't timeout on model save / push to Hub
155 | logger.info("*** Waiting for all processes to finish ***")
156 | accelerator.wait_for_everyone()
157 |
158 | logger.info("*** Run complete! ***")
159 |
160 |
161 | if __name__ == "__main__":
162 | main()
163 |
--------------------------------------------------------------------------------
/dpo/run_kto.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import sys
3 | import wandb
4 | import torch
5 | import transformers
6 | from transformers import AutoModelForCausalLM, set_seed, LlamaTokenizer, AutoTokenizer
7 | from accelerate import Accelerator
8 |
9 | from alignment import (
10 | DataArguments,
11 | DPOConfig,
12 | H4ArgumentParser,
13 | ModelArguments,
14 | apply_chat_template,
15 | get_datasets,
16 | get_kbit_device_map,
17 | get_peft_config,
18 | get_quantization_config,
19 | get_tokenizer,
20 | is_adapter_model,
21 | )
22 |
23 | from peft import PeftConfig, PeftModel
24 | # from custom_trainer import DPOTrainer, KTOTrainer, KTOConfig
25 | from prepare_dataset import get_dataset
26 | # from trls.trl.trainer import KTOTrainer, KTOConfig
27 | from custom_trainer import KTOTrainer, KTOConfig
28 | # from step_prepare_dataset import get_dataset
29 | import os
30 |
31 | os.environ["WANDB_API_KEY"] = "" # PUR YOUR WANDB KEY
32 | os.environ["WANDB_ENTITY"] = "" # PUT YOUR WANDB ID
33 | os.environ["WANDB_PROJECT"] = "" # PUT YOUR WANDB PROJECT
34 |
35 | logger = logging.getLogger(__name__)
36 |
37 |
38 | def main():
39 | parser = H4ArgumentParser((ModelArguments, DataArguments, KTOConfig))
40 | model_args, data_args, training_args = parser.parse()
41 |
42 | #######
43 | # Setup
44 | #######
45 | logging.basicConfig(
46 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
47 | datefmt="%Y-%m-%d %H:%M:%S",
48 | handlers=[logging.StreamHandler(sys.stdout)],
49 | )
50 | log_level = training_args.get_process_log_level()
51 | logger.setLevel(log_level)
52 | transformers.utils.logging.set_verbosity(log_level)
53 | transformers.utils.logging.enable_default_handler()
54 | transformers.utils.logging.enable_explicit_format()
55 |
56 | # Log on each process the small summary:
57 | logger.info(f"Model parameters {model_args}")
58 | logger.info(f"Data parameters {data_args}")
59 | logger.info(f"Training/evaluation parameters {training_args}")
60 |
61 | # Set seed for reproducibility
62 | set_seed(training_args.seed)
63 |
64 | # Increase distributed timeout to 3h to enable push to Hub to complete
65 | accelerator = Accelerator()
66 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
67 |
68 | # There should be "train" and "test" splits, with column names "text_prompt", "text_chosen", "text_rejected".
69 | # raw_datasets = get_dataset(data_args.train_data_file, data_args.test_data_file)
70 | import json
71 | from datasets import Dataset
72 |
73 | data = json.load(open(data_args.train_data_file))
74 |
75 | for idx in range(len(data['label'])):
76 | if data['label'][idx] == False:
77 | rej = data['completion'][idx]
78 | tgt_string = "The answer is"
79 | if tgt_string in rej and len(rej.split("\n")) > 1:
80 | rej = rej[:rej.index(tgt_string)].strip()
81 | data['completion'][idx] = rej
82 |
83 | # For format of KTO Dataset, see: https://huggingface.co/docs/trl/main/en/kto_trainer
84 | train_set = Dataset.from_dict({k: v for k, v in data.items()})
85 | train_set = train_set.shuffle(seed=42)
86 | test_set = Dataset.from_dict({k: v[:100] for k, v in data.items()})
87 |
88 | model_args.torch_dtype = "bfloat16"
89 | torch_dtype = (
90 | model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
91 | )
92 | print("torch_dtype", torch_dtype)
93 |
94 | model_kwargs = dict(
95 | revision=model_args.model_revision,
96 | trust_remote_code=model_args.trust_remote_code,
97 | use_flash_attention_2=True,
98 | torch_dtype=torch_dtype,
99 | use_cache=False,
100 | device_map=get_kbit_device_map(),
101 | quantization_config=get_quantization_config(model_args),
102 | )
103 |
104 | model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs)
105 | ref_model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs)
106 |
107 | # if model_args.use_peft is True:
108 | # ref_model = None
109 | # ref_model_kwargs = None
110 |
111 | #########################
112 | # Instantiate DPO trainer
113 | #########################
114 | # training_args.model_init_kwargs = model_kwargs
115 | # training_args.ref_model_init_kwargs = ref_model_kwargs
116 |
117 | kto_trainer = KTOTrainer(
118 | model,
119 | ref_model,
120 | args=training_args,
121 | train_dataset= train_set,
122 | eval_dataset=test_set,
123 | tokenizer=tokenizer,
124 | peft_config=get_peft_config(model_args),
125 | )
126 |
127 | ###############
128 | # Training loop
129 | ###############
130 | train_result = kto_trainer.train()
131 | metrics = train_result.metrics
132 | max_train_samples = (
133 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_set)
134 | )
135 | metrics["train_samples"] = min(max_train_samples, len(train_set))
136 | kto_trainer.log_metrics("train", metrics)
137 | kto_trainer.save_metrics("train", metrics)
138 | kto_trainer.save_state()
139 |
140 | logger.info("*** Training complete ***")
141 |
142 | # Evaluate
143 | # if training_args.do_eval:
144 | # logger.info("*** Evaluate ***")
145 | # metrics = kto_trainer.evaluate()
146 | # max_eval_samples = (
147 | # data_args.max_eval_samples if data_args.max_eval_samples is not None else len(test_set)
148 | # )
149 | # metrics["eval_samples"] = min(max_eval_samples, len(kto_trainer["test"]))
150 | # kto_trainer.log_metrics("eval", metrics)
151 | # kto_trainer.save_metrics("eval", metrics)
152 |
153 | # Save model and create model card
154 | kto_trainer.save_model(training_args.output_dir)
155 |
156 | # Save everything else on main process
157 | if accelerator.is_main_process:
158 | kwargs = {
159 | "finetuned_from": model_args.model_name_or_path,
160 | "dataset": list(data_args.dataset_mixer.keys()),
161 | "dataset_tags": list(data_args.dataset_mixer.keys()),
162 | "tags": ["alignment-handbook"],
163 | }
164 | kto_trainer.create_model_card(**kwargs)
165 | # Restore k,v cache for fast inference
166 | kto_trainer.model.config.use_cache = True
167 | kto_trainer.model.config.save_pretrained(training_args.output_dir)
168 | if training_args.push_to_hub is True:
169 | kto_trainer.push_to_hub()
170 |
171 | # Ensure we don't timeout on model save / push to Hub
172 | logger.info("*** Waiting for all processes to finish ***")
173 | accelerator.wait_for_everyone()
174 |
175 | logger.info("*** Run complete! ***")
176 |
177 |
178 | if __name__ == "__main__":
179 | main()
180 |
--------------------------------------------------------------------------------
/eval/gsm8k/eval_gsm8k.py:
--------------------------------------------------------------------------------
1 | import os
2 | # os.environ["CUDA_VISIBLE_DEVICES"]="4,5,6,7"
3 |
4 | from huggingface_hub import login
5 | import argparse
6 | import json
7 | import re
8 | import jsonlines
9 | from fraction import Fraction
10 | from vllm import LLM, SamplingParams
11 | import sys
12 | from tqdm.auto import tqdm
13 | from utils_ans import extract_answer
14 |
15 | MAX_INT = sys.maxsize
16 |
17 | def batch_data(data_list, batch_size=1):
18 | n = len(data_list) // batch_size
19 | batch_data = []
20 | for i in range(n-1):
21 | start = i * batch_size
22 | end = (i+1)*batch_size
23 | batch_data.append(data_list[start:end])
24 |
25 | last_start = (n-1) * batch_size
26 | last_end = MAX_INT
27 | batch_data.append(data_list[last_start:last_end])
28 | return batch_data
29 |
30 | def gsm8k_test(model, data_path, start=0, end=MAX_INT, batch_size=1, tensor_parallel_size=1, temp=0.7):
31 | gsm8k_ins = []
32 | gsm8k_answers = []
33 |
34 | # Check if it already exists.
35 | try:
36 | already_done = [json.loads(x) for x in open(args.result_file)]
37 | except:
38 | already_done = []
39 |
40 | with open(data_path,"r+", encoding="utf8") as f:
41 | for idx, item in enumerate(jsonlines.Reader(f)):
42 |
43 | if idx < len(already_done):
44 | continue
45 |
46 | gsm8k_ins.append(item["query"])
47 | temp_ans = int(item['response'].split('#### ')[1].replace(',', ''))
48 | gsm8k_answers.append(temp_ans)
49 |
50 | gsm8k_ins = gsm8k_ins[start:end]
51 | gsm8k_answers = gsm8k_answers[start:end]
52 | print('length ====', len(gsm8k_ins))
53 | batch_gsm8k_ins = batch_data(gsm8k_ins, batch_size=batch_size)
54 |
55 | # stop_tokens = ["\n\n", "Question:", "Question", "USER:", "USER", "ASSISTANT:", "ASSISTANT", "Instruction:", "Instruction", "Response:", "Response"]
56 | stop_tokens = []
57 |
58 | if temp == 0.7:
59 | n = 100
60 | else:
61 | n = 1
62 |
63 | sampling_params = SamplingParams(temperature=temp, top_p=1, max_tokens=512, stop=stop_tokens, n=n)
64 | print('sampling =====', sampling_params)
65 | llm = LLM(model=model,tensor_parallel_size=tensor_parallel_size, enforce_eager=True)
66 | result = []
67 | res_completions = []
68 |
69 | for idx, (prompt, prompt_answer) in tqdm(enumerate(zip(batch_gsm8k_ins, gsm8k_answers))):
70 | if isinstance(prompt, list):
71 | pass
72 | else:
73 | prompt = [prompt]
74 |
75 | completions = llm.generate(prompt, sampling_params)
76 |
77 | for num, output in enumerate(completions):
78 | prompt = output.prompt
79 | all_texts = [out.text for out in output.outputs]
80 | res_completions.append(all_texts)
81 |
82 | answer = gsm8k_answers[idx*batch_size + num]
83 | dict_ = {"prompt": prompt, "preds": all_texts, "answer": answer}
84 |
85 | with jsonlines.open(args.result_file, 'a') as writer:
86 | writer.write(dict_)
87 |
88 | li = [json.loads(x) for x in open(args.result_file)]
89 |
90 | sa = []  # single acc (per-problem accuracy)
91 |
92 | for x in li:
93 | if 'answers' in x:
94 | lbl = str(x['answers'])
95 | else:
96 | lbl = str(x['answer'])
97 |
98 | answers = [str(extract_answer(pred)) for pred in x['preds']]
99 | eq_answers = [ans == lbl for ans in answers]
100 | sa.append(eq_answers.count(True) / len(eq_answers))
101 |
102 | print("Final Acc:", sum(sa) / len(sa))
103 |
104 | def parse_args():
105 | parser = argparse.ArgumentParser()
106 | parser.add_argument("--model", type=str) # model path
107 | parser.add_argument("--data_file", type=str, default='') # data path
108 | parser.add_argument("--start", type=int, default=0) # start index
109 | parser.add_argument("--end", type=int, default=MAX_INT) # end index
110 | parser.add_argument("--batch_size", type=int, default=1000) # batch_size
111 | parser.add_argument("--tensor_parallel_size", type=int, default=1) # tensor_parallel_size
112 | parser.add_argument("--result_file", type=str, default="./log_file.jsonl") # tensor_parallel_size
113 | parser.add_argument("--temp", type=float, default=0.7)
114 |
115 | return parser.parse_args()
116 |
117 | if __name__ == "__main__":
118 | # Login First.
119 | # login(token="your_huggingface_token")
120 |
121 | args = parse_args()
122 | gsm8k_test(model=args.model, data_path=args.data_file, start=args.start, end=args.end, batch_size=args.batch_size, tensor_parallel_size=args.tensor_parallel_size, temp=args.temp)
--------------------------------------------------------------------------------
/eval/gsm8k/run_gsm8k_eval.sh:
--------------------------------------------------------------------------------
1 | model_path=""
2 | result_file=""
3 | data_file=""
4 |
5 | # To run evaluation for GSM8K:
6 | python eval_gsm8k.py --data_file $data_file --model $model_path --result_file $result_file --temp 0
7 |
--------------------------------------------------------------------------------
/eval/gsm8k/utils_ans.py:
--------------------------------------------------------------------------------
1 | ### From OVM/utils/gsm8k/decoding.py
2 |
3 | from contextlib import contextmanager
4 | import signal
5 | import json
6 | import os
7 | import re
8 |
9 | # ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
10 | ANS_RE = re.compile(r"The answer is:?\s*(\-?[0-9\.\,]+)")
11 | INVALID_ANS = "[invalid]"
12 |
13 |
14 | def extract_answer(completion):
15 | match = ANS_RE.search(completion)
16 | if match:
17 | match_str = match.group(1).strip()
18 | st_str = standardize_value_str(match_str)
19 | try: eval(st_str); return st_str
20 | except: ...
21 | return INVALID_ANS
22 |
23 | def extract_answers(completions):
24 | return [extract_answer(completion) for completion in completions]
25 |
26 | def standardize_value_str(x):
27 | """Standardize numerical values"""
28 | y = x.replace(",", "")
29 | if '.' in y:
30 | y = y.rstrip('0')
31 | if y[-1] == '.':
32 | y = y[:-1]
33 | if not len(y):
34 | return INVALID_ANS
35 | if y[0] == '.':
36 | y = '0' + y
37 | if y[-1] == '%':
38 | y = str(eval(y[:-1]) / 100)
39 | return y.rstrip('.')
40 |
41 | def get_answer_label(response_answer, gt):
42 | if response_answer == INVALID_ANS:
43 | return INVALID_ANS
44 | return response_answer == gt
45 |
46 |
47 | # taken from
48 | # https://stackoverflow.com/questions/492519/timeout-on-a-function-call
49 | @contextmanager
50 | def timeout(duration, formula):
51 | def timeout_handler(signum, frame):
52 | raise Exception(f"'{formula}': timed out after {duration} seconds")
53 |
54 | signal.signal(signal.SIGALRM, timeout_handler)
55 | signal.alarm(duration)
56 | yield
57 | signal.alarm(0)
58 |
59 |
60 | def eval_with_timeout(formula, max_time=3):
61 | try:
62 | with timeout(max_time, formula):
63 | return round(eval(formula), ndigits=4)
64 | except Exception as e:
65 | signal.alarm(0)
66 | print(f"Warning: Failed to eval {formula}, exception: {e}")
67 | return None
68 |
69 | # refer to https://github.com/openai/grade-school-math/blob/master/grade_school_math/calculator.py
70 | def use_calculator(sample):
71 | if "<<" not in sample:
72 | return None
73 |
74 | parts = sample.split("<<")
75 | remaining = parts[-1]
76 | if ">>" in remaining:
77 | return None
78 | if "=" not in remaining:
79 | return None
80 | lhs = remaining.split("=")[0]
81 | lhs = lhs.replace(",", "")
82 | if any([x not in "0123456789*+-/.()" for x in lhs]):
83 | return None
84 | ans = eval_with_timeout(lhs)
85 | if remaining[-1] == '-' and ans is not None and ans < 0:
86 | ans = -ans
87 | return ans
88 |
89 |
90 |
91 |
92 |
93 |
94 |
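A quick illustration of the answer-extraction behaviour defined above (the completion text is made up):

```python
# Hypothetical completion; extract_answer pulls the number after "The answer is".
from utils_ans import extract_answer, get_answer_label

completion = "Natalia sold 48 / 2 = 24 clips in May.\nIn total: 48 + 24 = 72.\nThe answer is 72"
pred = extract_answer(completion)          # -> "72"
print(pred, get_answer_label(pred, "72"))  # -> 72 True
```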
--------------------------------------------------------------------------------
/eval/math/configs/few_shot_test_configs.json:
--------------------------------------------------------------------------------
1 | {
2 | "math-cot-test": {
3 | "test_path": "../../data/MATH_test.jsonl",
4 | "language": "en",
5 | "tasks": [
6 | "cot"
7 | ],
8 | "process_fn": "process_math_test",
9 | "answer_extraction_fn": "extract_math_few_shot_cot_answer",
10 | "eval_fn": "eval_math",
11 | "few_shot_prompt": "MinervaMathPrompt"
12 | }
13 | }
--------------------------------------------------------------------------------
/eval/math/configs/zero_shot_test_configs.json:
--------------------------------------------------------------------------------
1 | {
2 | "gsm8k-test": {
3 | "test_path": "datasets/gsm8k/test.jsonl",
4 | "language": "en",
5 | "tasks": ["tool", "cot"],
6 | "process_fn": "process_gsm8k_test",
7 | "answer_extraction_fn": "extract_last_single_answer",
8 | "eval_fn": "eval_last_single_answer"
9 | },
10 | "math-test": {
11 | "test_path": "datasets/math/test.jsonl",
12 | "language": "en",
13 | "tasks": ["tool", "cot"],
14 | "process_fn": "process_math_test",
15 | "answer_extraction_fn": "extract_math_answer",
16 | "eval_fn": "eval_math"
17 | },
18 | "mgsm-zh": {
19 | "test_path": "datasets/mgsm_zh/mgsm_zh.jsonl",
20 | "language": "zh",
21 | "tasks": ["tool", "cot"],
22 | "process_fn": "process_mgsm_zh",
23 | "answer_extraction_fn": "extract_last_single_answer",
24 | "eval_fn": "eval_last_single_answer"
25 | },
26 | "cmath": {
27 | "test_path": "datasets/cmath/test.jsonl",
28 | "language": "zh",
29 | "tasks": ["tool", "cot"],
30 | "process_fn": "process_cmath",
31 | "answer_extraction_fn": "extract_last_single_answer",
32 | "eval_fn": "eval_last_single_answer"
33 | }
34 | }
--------------------------------------------------------------------------------
/eval/math/data_processing/__pycache__/answer_extraction.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hbin0701/Self-Explore/54b8ba9f4cee1ef52f72de916e100e0da2ae6863/eval/math/data_processing/__pycache__/answer_extraction.cpython-310.pyc
--------------------------------------------------------------------------------
/eval/math/data_processing/__pycache__/answer_extraction.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hbin0701/Self-Explore/54b8ba9f4cee1ef52f72de916e100e0da2ae6863/eval/math/data_processing/__pycache__/answer_extraction.cpython-311.pyc
--------------------------------------------------------------------------------
/eval/math/data_processing/__pycache__/process_utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hbin0701/Self-Explore/54b8ba9f4cee1ef52f72de916e100e0da2ae6863/eval/math/data_processing/__pycache__/process_utils.cpython-310.pyc
--------------------------------------------------------------------------------
/eval/math/data_processing/process_utils.py:
--------------------------------------------------------------------------------
1 | import regex
2 |
3 | from data_processing.answer_extraction import extract_math_answer, strip_string
4 |
5 | def process_gsm8k_test(item):
6 | sample = {
7 | 'dataset': 'gsm8k-cot',
8 | 'id': item['id'],
9 | 'messages': [
10 | {'role': 'user', 'content': item['question']},
11 | {'role': 'assistant', 'content': regex.sub(r"<<[^<>]*>>", "", item['cot']) + "\nSo the answer is $\\boxed{" + item['answer'].strip() + "}$."}
12 | ],
13 | 'answer': item['answer'].replace(',', '')
14 | }
15 | yield sample
16 |
17 | def process_math_test(item):
18 | question = item["problem"]
19 | try:
20 | answer = extract_math_answer(question, item['solution'], task="cot")
21 | except:
22 | return
23 | sample = {
24 | "dataset": "math-cot",
25 | "id": item['id'],
26 | "level": item["level"],
27 | "type": item["type"],
28 | "category": item["category"],
29 | "messages": [
30 | {"role": "user", "content": question},
31 | {"role": "assistant", "content": "\n".join(regex.split(r"(?<=\.) (?=[A-Z])", item["solution"]))}
32 | ],
33 | "answer": answer
34 | }
35 | yield sample
36 |
37 | def process_math_sat(item):
38 | options = item['options'].strip()
39 | assert 'A' == options[0]
40 | options = '(' + options
41 | for ch in 'BCDEFG':
42 | if f' {ch}) ' in options:
43 | options = regex.sub(f' {ch}\) ', f" ({ch}) ", options)
44 | question = f"{item['question'].strip()}\nWhat of the following is the right choice? Explain your answer.\n{options.strip()}"
45 | messages = [
46 | {'role': 'user', 'content': question},
47 | {'role': 'assistant', 'content': item['Answer']}
48 | ]
49 | item = {
50 | 'dataset': 'math_sat',
51 | 'id': item['id'],
52 | 'language': 'en',
53 | 'messages': messages,
54 | 'answer': item['Answer'],
55 | }
56 | yield item
57 |
58 | def process_ocwcourses(item):
59 | messages = [
60 | {'role': 'user', 'content': item['problem'].strip()},
61 | {'role': 'assistant', 'content': item['solution'].strip()}
62 | ]
63 | item = {
64 | "dataset": "OCWCourses",
65 | "id": item['id'],
66 | "language": "en",
67 | "messages": messages,
68 | "answer": item['answer']
69 | }
70 | yield item
71 |
72 | def process_mmlu_stem(item):
73 | options = item['options']
74 | for i, (label, option) in enumerate(zip('ABCD', options)):
75 | options[i] = f"({label}) {str(option).strip()}"
76 | options = ", ".join(options)
77 | question = f"{item['question'].strip()}\nWhat of the following is the right choice? Explain your answer.\n{options}"
78 | messages = [
79 | {'role': 'user', 'content': question},
80 | {'role': 'assistant', 'content': item['answer']}
81 | ]
82 | item = {
83 | "dataset": "MMLU-STEM",
84 | "id": item['id'],
85 | "language": "en",
86 | "messages": messages,
87 | "answer": item['answer']
88 | }
89 | yield item
90 |
91 | def process_mgsm_zh(item):
92 | item['answer'] = item['answer'].replace(',', '')
93 | yield item
94 |
95 | def process_cmath(item):
96 | item = {
97 | 'dataset': 'cmath',
98 | 'id': item['id'],
99 | 'grade': item['grade'],
100 | 'reasoning_step': item['reasoning_step'],
101 | 'messages': [
102 | {'role': 'user', 'content': item['question'].strip()},
103 | {'role': 'assistant', 'content': ''}
104 | ],
105 | 'answer': item['golden'].strip().replace(",", "")
106 | }
107 | yield item
108 |
109 | def process_agieval_gaokao_math_cloze(item):
110 | item = {
111 | 'dataset': 'agieval-gaokao-math-cloze',
112 | 'id': item['id'],
113 | 'messages': [
114 | {'role': 'user', 'content': item['question'].strip()},
115 | {'role': 'assistant', 'content': ''}
116 | ],
117 | 'answer': [strip_string(ans) for ans in item['answer'].strip().split(";")]
118 | }
119 | yield item
120 |
121 | def process_agieval_gaokao_mathqa(item):
122 | question = item['question'].strip()
123 | options = []
124 | for option in item['options']:
125 | option = option.strip()
126 | assert option[0] == '('
127 | assert option[2] == ')'
128 | assert option[1] in 'ABCD'
129 | option = f"{option[1]}: {option[3:].strip()}"
130 | options.append(option.strip())
131 | question = f"{question}\n{options}"
132 | item = {
133 | 'dataset': 'agieval-gaokao-mathqa',
134 | 'id': item['id'],
135 | 'messages': [
136 | {'role': 'user', 'content': question},
137 | {'role': 'assistant', 'content': ''}
138 | ],
139 | "answer": item['label']
140 | }
141 | yield item
142 |
143 | def process_agieval_gaokao_mathqa_few_shot_cot_test(item):
144 | question = item['question'].strip().rstrip('\\')
145 | options = " ".join([opt.strip() for opt in item['options']])
146 | question = f"{question}\n从以下选项中选择: {options}"
147 | item = {
148 | 'dataset': 'agieval-gaokao-mathqa',
149 | 'id': item['id'],
150 | 'messages': [
151 | {'role': 'user', 'content': question},
152 | {'role': 'assistant', 'content': ''}
153 | ],
154 | "answer": item['label']
155 | }
156 | yield item
157 |
158 | def process_minif2f_isabelle(item):
159 | question = f"(*### Problem\n\n{item['informal_statement'].strip()}\n\n### Solution\n\n{item['informal_proof'].strip()} *)\n\nFormal:\n{item['formal_statement'].strip()}"
160 | item = {
161 | 'dataset': 'minif2f-isabelle',
162 | 'id': item['id'],
163 | 'messages': [
164 | {'role': 'user', 'content': question},
165 | {'role': 'assistant', 'content': ''}
166 | ],
167 | "answer": "placeholder"
168 | }
169 | yield item
170 |
--------------------------------------------------------------------------------
/eval/math/eval/__pycache__/eval_script.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hbin0701/Self-Explore/54b8ba9f4cee1ef52f72de916e100e0da2ae6863/eval/math/eval/__pycache__/eval_script.cpython-310.pyc
--------------------------------------------------------------------------------
/eval/math/eval/__pycache__/eval_script.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hbin0701/Self-Explore/54b8ba9f4cee1ef52f72de916e100e0da2ae6863/eval/math/eval/__pycache__/eval_script.cpython-311.pyc
--------------------------------------------------------------------------------
/eval/math/eval/__pycache__/eval_utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hbin0701/Self-Explore/54b8ba9f4cee1ef52f72de916e100e0da2ae6863/eval/math/eval/__pycache__/eval_utils.cpython-310.pyc
--------------------------------------------------------------------------------
/eval/math/eval/__pycache__/eval_utils.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hbin0701/Self-Explore/54b8ba9f4cee1ef52f72de916e100e0da2ae6863/eval/math/eval/__pycache__/eval_utils.cpython-311.pyc
--------------------------------------------------------------------------------
/eval/math/eval/__pycache__/ocwcourses_eval_utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hbin0701/Self-Explore/54b8ba9f4cee1ef52f72de916e100e0da2ae6863/eval/math/eval/__pycache__/ocwcourses_eval_utils.cpython-310.pyc
--------------------------------------------------------------------------------
/eval/math/eval/__pycache__/ocwcourses_eval_utils.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hbin0701/Self-Explore/54b8ba9f4cee1ef52f72de916e100e0da2ae6863/eval/math/eval/__pycache__/ocwcourses_eval_utils.cpython-311.pyc
--------------------------------------------------------------------------------
/eval/math/eval/__pycache__/python_executor.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hbin0701/Self-Explore/54b8ba9f4cee1ef52f72de916e100e0da2ae6863/eval/math/eval/__pycache__/python_executor.cpython-310.pyc
--------------------------------------------------------------------------------
/eval/math/eval/__pycache__/utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hbin0701/Self-Explore/54b8ba9f4cee1ef52f72de916e100e0da2ae6863/eval/math/eval/__pycache__/utils.cpython-310.pyc
--------------------------------------------------------------------------------
/eval/math/eval/eval_script.py:
--------------------------------------------------------------------------------
1 | import regex
2 | from copy import deepcopy
3 | from eval.eval_utils import math_equal
4 | from eval.ocwcourses_eval_utils import normalize_numeric, numeric_equality, normalize_symbolic_equation, SymbolicMathMixin
5 |
6 | def is_correct(item, pred_key='prediction', prec=1e-3):
7 | # print("called again ...", item)
8 | pred = item[pred_key]
9 | ans = item['answer']
10 | if isinstance(pred, list) and isinstance(ans, list):
11 | pred_matched = set()
12 | ans_matched = set()
13 | for i in range(len(pred)):
14 | for j in range(len(ans)):
15 | item_cpy = deepcopy(item)
16 | item_cpy.update({
17 | pred_key: pred[i],
18 | 'answer': ans[j]
19 | })
20 | if is_correct(item_cpy, pred_key=pred_key, prec=prec):
21 | pred_matched.add(i)
22 | ans_matched.add(j)
23 |                 if item_cpy[pred_key] == '2,3,4':
24 |                     # debug: flag an unexpected comma-separated prediction
25 |                     print(item, flush=True)
26 | return len(pred_matched) == len(pred) and len(ans_matched) == len(ans)
27 | elif isinstance(pred, str) and isinstance(ans, str):
28 | if '\\cup' in pred and '\\cup' in ans:
29 | item = deepcopy(item)
30 | item.update({
31 | pred_key: pred.split('\\cup'),
32 | 'answer': ans.split('\\cup'),
33 | })
34 | return is_correct(item, pred_key=pred_key, prec=prec)
35 | else:
36 | label = False
37 | try:
38 | label = abs(float(regex.sub(r',', '', str(pred))) - float(regex.sub(r',', '', str(ans)))) < prec
39 | except:
40 | pass
41 | label = label or (ans and pred == ans) or math_equal(pred, ans)
42 | return label
43 | else:
44 | print(item, flush=True)
45 | raise NotImplementedError()
46 |
47 | def eval_math(item, pred_key='prediction', prec=1e-3):
48 | pred = item[pred_key]
49 | if pred_key == 'program_output' and isinstance(pred, str):
50 | pred = [pred]
51 | ans = item['answer']
52 | if isinstance(pred, list) and isinstance(ans, list):
53 | # for some questions in MATH, `reference` repeats answers
54 | _ans = []
55 | for a in ans:
56 | if a not in _ans:
57 | _ans.append(a)
58 | ans = _ans
59 | # some predictions for MATH questions also repeats answers
60 | _pred = []
61 | for a in pred:
62 | if a not in _pred:
63 | _pred.append(a)
64 | # some predictions mistakenly box non-answer strings
65 | pred = _pred[-len(ans):]
66 |
67 | item.update({
68 | pred_key: pred,
69 | 'answer': ans
70 | })
71 | return is_correct(item, pred_key=pred_key, prec=prec)
72 |
73 | def eval_last_single_answer(item, pred_key='prediction', prec=1e-3):
74 | for key in [pred_key, 'answer']:
75 | assert isinstance(item[key], str), f"{key} = `{item[key]}` is not a str"
76 | return is_correct(item, pred_key=pred_key, prec=prec)
77 |
78 | def eval_agieval_gaokao_math_cloze(item, pred_key='prediction', prec=1e-3):
79 | if pred_key == 'program_output' and isinstance(item[pred_key], str):
80 | item[pred_key] = [item[pred_key]]
81 | for key in [pred_key, 'answer']:
82 | assert isinstance(item[key], list), f"{key} = `{item[key]}` is not a list"
83 | pred = item[pred_key]
84 | ans = item['answer']
85 | _pred = []
86 | for p in pred:
87 | p = p + ";"
88 | while p:
89 | left_brackets = 0
90 | for i in range(len(p)):
91 | if p[i] == ';' or (p[i] == ',' and left_brackets == 0):
92 | _p, p = p[:i].strip(), p[i + 1:].strip()
93 | if _p not in _pred:
94 | _pred.append(_p)
95 | break
96 | elif p[i] in '([{':
97 | left_brackets += 1
98 | elif p[i] in ')]}':
99 | left_brackets -= 1
100 | pred = _pred[-len(ans):]
101 | if len(pred) == len(ans):
102 | for p, a in zip(pred, ans):
103 | item.update({
104 | pred_key: p,
105 | 'answer': a,
106 | })
107 | if not is_correct(item, pred_key=pred_key, prec=prec):
108 | return False
109 | return True
110 | else:
111 | return False
112 |
113 | def eval_agieval_gaokao_mathqa(item, pred_key='prediction', prec=1e-3):
114 | if pred_key == 'program_output' and isinstance(item[pred_key], str):
115 | item[pred_key] = [item[pred_key]]
116 | pred_str = " ".join(item[pred_key])
117 | ans = item['answer']
118 | tag = None
119 | idx = -1
120 | for t in 'ABCD':
121 | if t in pred_str and pred_str.index(t) > idx:
122 | tag = t
123 | idx = pred_str.index(t)
124 | return tag == ans
125 |
126 | def eval_math_sat(item, pred_key='prediction', prec=1e-3):
127 | for key in [pred_key, 'answer']:
128 | assert isinstance(item[key], str), f"{key} = `{item[key]}` is not a str"
129 | return item[pred_key].lower() == item['answer'].lower()
130 |
131 | def eval_mmlu_stem(item, pred_key='prediction', prec=1e-3):
132 | return eval_math_sat(item, pred_key=pred_key, prec=prec)
133 |
134 | def eval_ocwcourses(item, pred_key='prediction', prec=1e-3):
135 | INVALID_ANSWER = "[invalidanswer]"
136 | for key in [pred_key, 'answer']:
137 | assert isinstance(item[key], str), f"{key} = `{item[key]}` is not a str"
138 | pred = item[pred_key]
139 | ans = item['answer']
140 |
141 | try:
142 | float(ans)
143 | normalize_fn = normalize_numeric
144 | is_equiv = numeric_equality
145 | answer_type = "numeric"
146 | except ValueError:
147 | if "=" in ans:
148 | normalize_fn = normalize_symbolic_equation
149 | is_equiv = lambda x, y: x==y
150 | answer_type = "equation"
151 | else:
152 | normalize_fn = SymbolicMathMixin().normalize_tex
153 | is_equiv = SymbolicMathMixin().is_tex_equiv
154 | answer_type = "expression"
155 |
156 | correct_answer = normalize_fn(ans)
157 |
158 | unnormalized_answer = pred if pred else INVALID_ANSWER
159 | model_answer = normalize_fn(unnormalized_answer)
160 |
161 | if unnormalized_answer == INVALID_ANSWER:
162 | acc = 0
163 | elif model_answer == INVALID_ANSWER:
164 | acc = 0
165 | elif is_equiv(model_answer, correct_answer):
166 | acc = 1
167 | else:
168 | acc = 0
169 |
170 | return acc
171 |
172 | def eval_minif2f_isabelle(item, pred_key='prediction', prec=1e-3):
173 | return True
174 |
--------------------------------------------------------------------------------
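For orientation, below is a minimal, hypothetical usage sketch (not part of the repository) of the two main entry points defined in eval_script.py above; the item dicts are toy examples, and the first result ultimately depends on `math_equal` from eval_utils.py.

```python
# Hypothetical sketch: calling the evaluators from eval_script.py on toy items.
from eval.eval_script import eval_last_single_answer, eval_math

# Single string answers: numeric closeness (prec=1e-3) or math_equal decides correctness.
single = {"prediction": "0.5", "answer": "1/2"}
print(eval_last_single_answer(single))

# List answers (e.g. multi-part MATH references): after de-duplication, every
# prediction must match some reference and vice versa.
multi = {"prediction": ["2", "3"], "answer": ["3", "2"]}
print(eval_math(multi))
```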
/eval/math/eval/ocwcourses_eval_utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | import numpy as np
3 | import sympy
4 | from sympy.core.sympify import SympifyError
5 | from sympy.parsing.latex import parse_latex
6 |
7 | import signal
8 |
9 | INVALID_ANSWER = "[invalidanswer]"
10 |
11 | class timeout:
12 | def __init__(self, seconds=1, error_message="Timeout"):
13 | self.seconds = seconds
14 | self.error_message = error_message
15 |
16 | def handle_timeout(self, signum, frame):
17 | raise TimeoutError(self.error_message)
18 |
19 | def __enter__(self):
20 | signal.signal(signal.SIGALRM, self.handle_timeout)
21 | signal.alarm(self.seconds)
22 |
23 | def __exit__(self, type, value, traceback):
24 | signal.alarm(0)
25 |
26 | def normalize_numeric(s):
27 | if s is None:
28 | return None
29 | for unit in [
30 | "eV",
31 | " \\mathrm{~kg} \\cdot \\mathrm{m} / \\mathrm{s}",
32 | " kg m/s",
33 | "kg*m/s",
34 | "kg",
35 | "m/s",
36 | "m / s",
37 | "m s^{-1}",
38 | "\\text{ m/s}",
39 | " \\mathrm{m/s}",
40 | " \\text{ m/s}",
41 | "g/mole",
42 | "g/mol",
43 | "\\mathrm{~g}",
44 | "\\mathrm{~g} / \\mathrm{mol}",
45 | "W",
46 | "erg/s",
47 | "years",
48 | "year",
49 | "cm",
50 | ]:
51 | s = s.replace(unit, "")
52 | s = s.strip()
53 | for maybe_unit in ["m", "s", "cm"]:
54 | s = s.replace("\\mathrm{" + maybe_unit + "}", "")
55 | s = s.replace("\\mathrm{~" + maybe_unit + "}", "")
56 | s = s.strip()
57 | s = s.strip("$")
58 | try:
59 | return float(eval(s))
60 | except:
61 | try:
62 | expr = parse_latex(s)
63 | if expr.is_number:
64 | return float(expr)
65 | return INVALID_ANSWER
66 | except:
67 | return INVALID_ANSWER
68 |
69 | def numeric_equality(n1, n2, threshold=0.01):
70 | if n1 is None or n2 is None:
71 | return False
72 | if np.isclose(n1, 0) or np.isclose(n2, 0) or np.isclose(n1 - n2, 0):
73 | return np.abs(n1 - n2) < threshold * (n1 + n2) / 2
74 | else:
75 | return np.isclose(n1, n2)
76 |
77 | def normalize_symbolic_equation(s):
78 | if not isinstance(s, str):
79 | return INVALID_ANSWER
80 | if s.startswith("\\["):
81 | s = s[2:]
82 | if s.endswith("\\]"):
83 | s = s[:-2]
84 | s = s.replace("\\left(", "(")
85 | s = s.replace("\\right)", ")")
86 | s = s.replace("\\\\", "\\")
87 | if s.startswith("$") or s.endswith("$"):
88 | s = s.strip("$")
89 | try:
90 | maybe_expression = parse_latex(s)
91 | if not isinstance(maybe_expression, sympy.core.relational.Equality):
92 |             # expected an Equality (an equation), not a bare expression
93 | return INVALID_ANSWER
94 | else:
95 | return maybe_expression
96 | except:
97 | return INVALID_ANSWER
98 |
99 | class SymbolicMathMixin:
100 | """
101 | Methods useful for parsing mathematical expressions from text and determining equivalence of expressions.
102 | """
103 |
104 | SUBSTITUTIONS = [ # used for text normalize
105 | ("an ", ""),
106 | ("a ", ""),
107 | (".$", "$"),
108 | ("\\$", ""),
109 | (r"\ ", ""),
110 | (" ", ""),
111 | ("mbox", "text"),
112 | (",\\text{and}", ","),
113 | ("\\text{and}", ","),
114 | ("\\text{m}", "\\text{}"),
115 | ]
116 | REMOVED_EXPRESSIONS = [ # used for text normalizer
117 | "square",
118 | "ways",
119 | "integers",
120 | "dollars",
121 | "mph",
122 | "inches",
123 | "ft",
124 | "hours",
125 | "km",
126 | "units",
127 | "\\ldots",
128 | "sue",
129 | "points",
130 | "feet",
131 | "minutes",
132 | "digits",
133 | "cents",
134 | "degrees",
135 | "cm",
136 | "gm",
137 | "pounds",
138 | "meters",
139 | "meals",
140 | "edges",
141 | "students",
142 | "childrentickets",
143 | "multiples",
144 | "\\text{s}",
145 | "\\text{.}",
146 | "\\text{\ns}",
147 | "\\text{}^2",
148 | "\\text{}^3",
149 | "\\text{\n}",
150 | "\\text{}",
151 | r"\mathrm{th}",
152 | r"^\circ",
153 | r"^{\circ}",
154 | r"\;",
155 | r",\!",
156 | "{,}",
157 | '"',
158 | "\\dots",
159 | ]
160 |
161 | def normalize_tex(self, final_answer: str) -> str:
162 | """
163 | Normalizes a string representing a mathematical expression.
164 | Used as a preprocessing step before parsing methods.
165 |
166 | Copied character for character from appendix D of Lewkowycz et al. (2022)
167 | """
168 | final_answer = final_answer.split("=")[-1]
169 |
170 | for before, after in self.SUBSTITUTIONS:
171 | final_answer = final_answer.replace(before, after)
172 | for expr in self.REMOVED_EXPRESSIONS:
173 | final_answer = final_answer.replace(expr, "")
174 |
175 | # Extract answer that is in LaTeX math, is bold,
176 | # is surrounded by a box, etc.
177 | final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer)
178 | final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer)
179 | final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer)
180 | final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer)
181 | final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer)
182 |
183 | # Normalize shorthand TeX:
184 | # \fracab -> \frac{a}{b}
185 | # \frac{abc}{bef} -> \frac{abc}{bef}
186 | # \fracabc -> \frac{a}{b}c
187 | # \sqrta -> \sqrt{a}
188 | # \sqrtab -> sqrt{a}b
189 | final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer)
190 | final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer)
191 | final_answer = final_answer.replace("$", "")
192 |
193 | # Normalize 100,000 -> 100000
194 | if final_answer.replace(",", "").isdigit():
195 | final_answer = final_answer.replace(",", "")
196 |
197 | return final_answer
198 |
199 | def parse_tex(self, text: str, time_limit: int = 5) -> sympy.Basic:
200 | """
201 |         Wrapper around `sympy.parsing.latex.parse_latex` that outputs a SymPy expression.
202 |         Typically, you want to apply `normalize_tex` as a preprocessing step.
203 | """
204 | try:
205 | with timeout(seconds=time_limit):
206 | parsed = parse_latex(text)
207 | except (
208 | # general error handling: there is a long tail of possible sympy/other
209 | # errors we would like to catch
210 | Exception
211 | ) as e:
212 | print(f"failed to parse {text} with exception {e}")
213 | return None
214 |
215 | return parsed
216 |
217 | def is_exp_equiv(self, x1: sympy.Basic, x2: sympy.Basic, time_limit=5) -> bool:
218 | """
219 | Determines whether two sympy expressions are equal.
220 | """
221 | try:
222 | with timeout(seconds=time_limit):
223 | try:
224 | diff = x1 - x2
225 | except (SympifyError, ValueError, TypeError) as e:
226 | print(
227 | f"Couldn't subtract {x1} and {x2} with exception {e}"
228 | )
229 | return False
230 |
231 | try:
232 | if sympy.simplify(diff) == 0:
233 | return True
234 | else:
235 | return False
236 | except (SympifyError, ValueError, TypeError) as e:
237 | print(f"Failed to simplify {x1}-{x2} with {e}")
238 | return False
239 | except TimeoutError as e:
240 | print(f"Timed out comparing {x1} and {x2}")
241 | return False
242 | except Exception as e:
243 | print(f"failed on unrecognized exception {e}")
244 | return False
245 |
246 | def is_tex_equiv(self, x1: str, x2: str, time_limit=5) -> bool:
247 | """
248 |         Determines whether two (ideally normalized using `normalize_tex`) TeX expressions are equal.
249 |
250 | Does so by first checking for string exact-match, then falls back on sympy-equivalence,
251 | following the (Lewkowycz et al. 2022) methodology.
252 | """
253 | if x1 == x2:
254 | # don't resort to sympy if we have full string match, post-normalization
255 | return True
256 |         # otherwise, fall back on sympy-based equivalence checking,
257 |         # following the (Lewkowycz et al. 2022) methodology
258 | parsed_x2 = self.parse_tex(x2)
259 | if not parsed_x2:
260 | # if our reference fails to parse into a Sympy object,
261 | # we forgo parsing + checking our generated answer.
262 | return False
263 | return self.is_exp_equiv(self.parse_tex(x1), parsed_x2, time_limit=time_limit)
264 |
--------------------------------------------------------------------------------
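The normalizers above are combined in `eval_ocwcourses` (defined in eval_script.py). Here is a minimal, hypothetical sketch of their behavior on illustrative toy inputs:

```python
# Hypothetical sketch: exercising the OCW normalizers on toy inputs.
from eval.ocwcourses_eval_utils import normalize_numeric, SymbolicMathMixin

# Numeric answers: known unit strings are stripped, then the remainder is eval'd / parsed.
print(normalize_numeric("7.0 m/s"))              # 7.0

# TeX answers: normalize_tex unboxes and expands shorthand, so an exact string
# match already covers many equivalent forms.
mixin = SymbolicMathMixin()
a = mixin.normalize_tex("$\\boxed{\\frac12}$")   # '\\frac{1}{2}'
b = mixin.normalize_tex("\\frac12")              # '\\frac{1}{2}'
print(mixin.is_tex_equiv(a, b))                  # True (exact match after normalization)
```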
/eval/math/eval/python_executor.py:
--------------------------------------------------------------------------------
1 | import os
2 | import io
3 | from contextlib import redirect_stdout
4 | import pickle
5 | import regex
6 | import copy
7 | from typing import Any, Dict, Optional
8 | import multiprocess
9 | from pebble import ProcessPool
10 | from concurrent.futures import TimeoutError
11 | from functools import partial
12 | import traceback
13 | from timeout_decorator import timeout
14 |
15 | class GenericRuntime:
16 | GLOBAL_DICT = {}
17 | LOCAL_DICT = None
18 | HEADERS = []
19 | def __init__(self):
20 | self._global_vars = copy.copy(self.GLOBAL_DICT)
21 | self._local_vars = copy.copy(self.LOCAL_DICT) if self.LOCAL_DICT else None
22 |
23 | for c in self.HEADERS:
24 | self.exec_code(c)
25 |
26 | def exec_code(self, code_piece: str) -> None:
27 | if regex.search(r'(\s|^)?input\(', code_piece) or regex.search(r'(\s|^)?os.system\(', code_piece):
28 | raise RuntimeError()
29 | exec(code_piece, self._global_vars)
30 |
31 | def eval_code(self, expr: str) -> Any:
32 | return eval(expr, self._global_vars)
33 |
34 | def inject(self, var_dict: Dict[str, Any]) -> None:
35 | for k, v in var_dict.items():
36 | self._global_vars[k] = v
37 |
38 | @property
39 | def answer(self):
40 | return self._global_vars['answer']
41 |
42 | class PythonExecutor:
43 | def __init__(
44 | self,
45 | runtime: Optional[Any] = None,
46 | get_answer_symbol: Optional[str] = None,
47 | get_answer_expr: Optional[str] = None,
48 | get_answer_from_stdout: bool = False,
49 | ) -> None:
50 | self.runtime = runtime if runtime else GenericRuntime()
51 | self.answer_symbol = get_answer_symbol
52 | self.answer_expr = get_answer_expr
53 | self.get_answer_from_stdout = get_answer_from_stdout
54 |
55 | def process_generation_to_code(self, gens: str):
56 | batch_code = []
57 | for g in gens:
58 | multiline_comments = False
59 | code = []
60 | for line in g.split('\n'):
61 | strip_line = line.strip()
62 | if strip_line.startswith("#"):
63 | line = line.split("#", 1)[0] + "# comments"
64 | elif not multiline_comments and strip_line.startswith('"""') and strip_line.endswith('"""') and len(strip_line) >= 6:
65 | line = line.split('"""', 1)[0] + '"""comments"""'
66 | elif not multiline_comments and strip_line.startswith('"""'):
67 | multiline_comments = True
68 | elif multiline_comments and strip_line.endswith('"""'):
69 | multiline_comments = False
70 | line = ""
71 | if not multiline_comments:
72 | code.append(line)
73 | batch_code.append(code)
74 | return batch_code
75 |
76 | @staticmethod
77 | def execute(
78 | code,
79 | get_answer_from_stdout = None,
80 | runtime = None,
81 | answer_symbol = None,
82 | answer_expr = None,
83 | timeout_length = 10,
84 | ):
85 | try:
86 | if get_answer_from_stdout:
87 | program_io = io.StringIO()
88 | with redirect_stdout(program_io):
89 | timeout(timeout_length)(runtime.exec_code)('\n'.join(code))
90 | program_io.seek(0)
91 | result = "".join(program_io.readlines()) # [-1]
92 | elif answer_symbol:
93 | timeout(timeout_length)(runtime.exec_code)('\n'.join(code))
94 | result = runtime._global_vars[answer_symbol]
95 | elif answer_expr:
96 | timeout(timeout_length)(runtime.exec_code)('\n'.join(code))
97 | result = timeout(timeout_length)(runtime.eval_code)(answer_expr)
98 | else:
99 | timeout(timeout_length)(runtime.exec_code)('\n'.join(code[:-1]))
100 | result = timeout(timeout_length)(runtime.eval_code)(code[-1])
101 | concise_exec_info = ""
102 | exec_info = ""
103 | str(result)
104 | pickle.dumps(result) # serialization check
105 | except:
106 | # traceback.print_exc()
107 | result = ''
108 | concise_exec_info = traceback.format_exc().split('\n')[-2]
109 | exec_info = traceback.format_exc()
110 | if get_answer_from_stdout and 'exec(code_piece, self._global_vars)' in exec_info:
111 | exec_info = exec_info.split('exec(code_piece, self._global_vars)')[-1].strip()
112 | msg = []
113 | for line in exec_info.split("\n"):
114 | patt = regex.search(r'(?P.*)File "(?P.*)", line (?P\d+), (?P.*)', line)
115 | if patt is not None:
116 | if '' in patt.group('end'):
117 | continue
118 | fname = patt.group("file")
119 | if "site-packages" in fname:
120 | fname = f"site-packages{fname.split('site-packages', 1)[1]}"
121 | line = f'{patt.group("start")}File "{fname}", {patt.group("end")}'
122 | else:
123 | line = f'{patt.group("start")}{patt.group("end")}'
124 | else:
125 | patt = regex.search(r'(?P.*)(?P/.*site-packages/.*\.py)(?P.*)', line)
126 | if patt is not None:
127 | line = f'{patt.group("start")}site-packages{patt.group("file").split("site-packages", 1)[1]}{patt.group("end")}'
128 | msg.append(line)
129 | exec_info = "\n".join(msg)
130 | return result, concise_exec_info, exec_info
131 |
132 | def apply(self, code):
133 | return self.batch_apply([code])[0]
134 |
135 | def batch_apply(self, batch_code):
136 | all_code_snippets = self.process_generation_to_code(batch_code)
137 | all_exec_results = []
138 | executor = partial(
139 | self.execute,
140 | get_answer_from_stdout=self.get_answer_from_stdout,
141 | runtime=self.runtime,
142 | answer_symbol=self.answer_symbol,
143 | answer_expr=self.answer_expr,
144 | timeout_length=10,
145 | )
146 | with ProcessPool(max_workers=multiprocess.cpu_count()) as pool:
147 | iterator = pool.map(executor, all_code_snippets, timeout=10).result()
148 |
149 | while True:
150 | try:
151 | result = next(iterator)
152 | all_exec_results.append(result)
153 | except StopIteration:
154 | break
155 | except TimeoutError as error:
156 | all_exec_results.append(("", "Timeout Error", "Timeout Error"))
157 | except Exception as error:
158 | print(error)
159 | exit()
160 |
161 | batch_results = []
162 | for code, (result, concise_exec_info, exec_info) in zip(all_code_snippets, all_exec_results):
163 | metadata = {'code': code, 'exec_result': result, 'concise_exec_info': concise_exec_info, 'exec_info': exec_info}
164 | batch_results.append((result, metadata))
165 | return batch_results
166 |
--------------------------------------------------------------------------------
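A minimal, hypothetical sketch of driving the `PythonExecutor` above in stdout mode; the program strings are illustrative assumptions, not repository outputs.

```python
# Hypothetical sketch: executing generated programs and reading stdout as the answer.
from eval.python_executor import PythonExecutor

if __name__ == "__main__":  # ProcessPool spawns worker processes, so guard the entry point
    executor = PythonExecutor(get_answer_from_stdout=True)
    programs = [
        "x = 2 + 3\nprint(x)",
        "total = sum(range(10))\nprint(total)",
    ]
    for output, metadata in executor.batch_apply(programs):
        print(repr(output), "|", metadata["concise_exec_info"] or "ok")
```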
/eval/math/run_cot_eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import sys
5 | import_path = os.path.abspath(__file__)
6 | for _ in range(2):
7 | import_path = os.path.dirname(import_path)
8 | sys.path.append(import_path)
9 |
10 | import numpy
11 | from tqdm import tqdm
12 | import json
13 | from copy import deepcopy
14 | from vllm import LLM, SamplingParams
15 | from pebble import ProcessPool
16 | from concurrent.futures import TimeoutError
17 | import random
18 | from eval.utils import generate_completions, load_hf_lm_and_tokenizer
19 | from transformers import AutoTokenizer
20 | from data_processing.answer_extraction import *
21 | from eval.eval_script import *
22 |
23 | def evaluate(eval_fn, tasks, _timeout=15):
24 | with ProcessPool() as pool:
25 | timeout_cnt = 0
26 | iterator = pool.map(eval_fn, tasks, timeout=_timeout).result()
27 | labels = []
28 | while True:
29 | try:
30 | labels.append(int(next(iterator)))
31 | except StopIteration:
32 | break
33 | except TimeoutError as error:
34 | labels.append(0)
35 | timeout_cnt += 1
36 | except Exception as error:
37 | print(error.traceback, flush=True)
38 | exit()
39 | return labels, timeout_cnt
40 |
41 | def infer(args, test_data):
42 | global tokenizer
43 | if tokenizer is None:
44 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name_or_path, trust_remote_code=True)
45 |
46 | prompts = []
47 | for example in test_data:
48 |         prompt = example['messages'][-2]['content'] + "\n"  # the user turn is the second-to-last message
49 | prompts.append(prompt.lstrip())
50 |
51 | global model
52 | print("Loading model and tokenizer...")
53 | if args.use_vllm:
54 | if model is None:
55 | model = LLM(model=args.model_name_or_path, tokenizer=args.tokenizer_name_or_path, trust_remote_code=True, tensor_parallel_size=len(os.environ['CUDA_VISIBLE_DEVICES'].split(",")))
56 | eos_token = tokenizer.eos_token if tokenizer is not None and tokenizer.eos_token is not None else ''
57 | stop_words = [eos_token]
58 |
59 | outputs = model.generate(prompts, SamplingParams(temperature=args.temperature, top_p=1.0, max_tokens=1024, n=1, stop=stop_words))
60 | outputs = sorted(outputs, key=lambda x: int(x.request_id)) # sort outputs by request_id
61 | outputs = [output.outputs[0].text for output in outputs]
62 | else:
63 | model, tokenizer = load_hf_lm_and_tokenizer(
64 | model_name_or_path=args.model_name_or_path,
65 | tokenizer_name_or_path=args.tokenizer_name_or_path,
66 | load_in_8bit=args.load_in_8bit,
67 | load_in_half=args.load_in_half,
68 | gptq_model=args.gptq
69 | )
70 |
71 | stop_id_sequences = []
72 | if tokenizer.eos_token_id is not None:
73 | stop_id_sequences = [[tokenizer.eos_token_id]]
74 |
75 | outputs, finish_completion = generate_completions(
76 | model=model,
77 | tokenizer=tokenizer,
78 | prompts=prompts,
79 | max_new_tokens=512,
80 | batch_size=args.eval_batch_size,
81 | stop_id_sequences=stop_id_sequences if stop_id_sequences else None,
82 | end_of_generation_id_sequence=[tokenizer.eos_token_id] if tokenizer.eos_token_id is not None else None
83 | )
84 |
85 | if args.complete_partial_output:
86 | model_outputs = [example['messages'][-1]['content'] + output for example, output in zip(test_data, outputs)]
87 | else:
88 | model_outputs = outputs
89 |
90 | predictions = [eval(args.answer_extraction_fn)(item['messages'][-2]['content'], output, task='cot') for item, output in tqdm(zip(test_data, model_outputs), desc="extract answer", total=len(model_outputs))]
91 | assert len(model_outputs) > 0, f"{len(model_outputs)}"
92 |
93 | results = []
94 | for example, output, pred in zip(test_data, model_outputs, predictions):
95 | item = deepcopy(example)
96 | item.update({
97 | 'model_output': output,
98 | 'prediction': pred,
99 | })
100 | results.append(item)
101 |
102 | print("Returning!")
103 | return results
104 |
105 | def main(args):
106 | random.seed(42)
107 |
108 | print("Loading data...")
109 | test_data = []
110 | with open(os.path.join(args.data_dir, f"train.jsonl" if args.infer_train_set else f"test.jsonl")) as fin:
111 | for line in fin:
112 | example = json.loads(line)
113 | messages = example['messages']
114 | assert messages[-1]['role'] == 'assistant'
115 | if not args.complete_partial_output:
116 | example['reference'] = example.get('reference', '') or [mess['content'] for mess in messages if mess['role'] == 'assistant']
117 | for mess in messages:
118 | if mess['role'] == 'assistant':
119 | mess['content'] = ''
120 | example['messages'] = messages
121 | test_data.append(example)
122 |
123 | if args.max_num_examples and len(test_data) > args.max_num_examples:
124 | test_data = random.sample(test_data, args.max_num_examples)
125 |
126 | if args.n_subsets > 1:
127 | assert args.subset_id >= 0 and args.subset_id < args.n_subsets
128 | test_data = [item for i, item in enumerate(test_data) if i % args.n_subsets == args.subset_id]
129 |
130 | if not test_data:
131 | return
132 |
133 | if not os.path.exists(args.save_dir):
134 | os.makedirs(args.save_dir, exist_ok=True)
135 |
136 | results = infer(args, test_data)
137 |
138 | import jsonlines
139 | with jsonlines.open(f"{args.save_dir}/results_{args.subset_id}.jsonl", "a") as writer:
140 | writer.write_all(results)
141 |
142 | if __name__ == "__main__":
143 | parser = argparse.ArgumentParser()
144 | parser.add_argument("--data_dir", type=str, default="data/mgsm")
145 | parser.add_argument("--max_num_examples", type=int, default=None, help="maximum number of examples to evaluate.")
146 | parser.add_argument("--save_dir", type=str, default="results/mgsm")
147 | parser.add_argument("--model_name_or_path", type=str, default=None, help="if specified, we will load the model to generate the predictions.")
148 | parser.add_argument("--tokenizer_name_or_path", type=str, default=None, help="if specified, we will load the tokenizer from here.")
149 | parser.add_argument("--eval_batch_size", type=int, default=1, help="batch size for evaluation.")
150 | parser.add_argument("--load_in_8bit", action="store_true", help="load model in 8bit mode, which will reduce memory and speed up inference.")
151 | parser.add_argument("--gptq", action="store_true", help="If given, we're evaluating a 4-bit quantized GPTQ model.")
152 | parser.add_argument("--use_vllm", action="store_true")
153 | parser.add_argument("--load_in_half", action='store_true')
154 | parser.add_argument("--infer_train_set", action="store_true")
155 | parser.add_argument("--n_subsets", type=int, default=1)
156 | parser.add_argument("--subset_id", type=int, default=0)
157 | parser.add_argument("--temperature", type=float, default=0.0)
158 | parser.add_argument("--repeat_id_start", type=int, default=0)
159 | parser.add_argument("--n_repeat_sampling", type=int, default=1)
160 | parser.add_argument("--complete_partial_output", action='store_true')
161 | parser.add_argument("--prompt_format", type=str, choices=['sft', 'few_shot'], default='sft')
162 | parser.add_argument("--few_shot_prompt", type=str, default=None)
163 | parser.add_argument("--answer_extraction_fn", type=str, required=True)
164 | parser.add_argument("--eval_fn", type=str, required=True)
165 | parser.add_argument("--gpus", type=str, default=None)
166 | args, unparsed_args = parser.parse_known_args()
167 | if args.gpus is not None:
168 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
169 |
170 | print(unparsed_args, flush=True)
171 |
172 | if 'math6' in args.data_dir:
173 | args.multi_turn = True
174 |
175 | # model_name_or_path cannot be both None or both not None.
176 | model = None
177 | tokenizer = None
178 | pool = None
179 | if args.n_repeat_sampling > 1 or args.repeat_id_start != 0:
180 | assert args.temperature > 0
181 | save_dir = args.save_dir
182 | for i in range(args.repeat_id_start, args.repeat_id_start + args.n_repeat_sampling):
183 | print(f"working on the {i} trials ...", flush=True)
184 | args.save_dir = os.path.join(save_dir, str(i))
185 | os.makedirs(args.save_dir, exist_ok=True)
186 | main(args)
187 | else:
188 | main(args)
189 |
190 | if pool is not None:
191 | pool.close()
192 |
--------------------------------------------------------------------------------
/eval/math/run_math_eval.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | output_dir="outputs/DeepSeek_MATH_Self_Explore"
4 | model_path=""
5 | tokenizer_path=""
6 | model_size="7b"
7 | use_vllm="--use-vllm"
8 | no_markup_question="--no-markup-question"
9 | test_conf="configs/few_shot_test_configs.json" # We reuse this config, but the code does not actually evaluate in a few-shot setting.
10 | prompt_format="few_shot"
11 | n_repeats="1"
12 | temperature="0"
13 | ngpus="4"
14 | rank="0"
15 |
16 | python run_subset_parallel.py --output-dir $output_dir \
17 | --model-path $model_path \
18 | --tokenizer-path $tokenizer_path \
19 | --model-size $model_size \
20 | $use_vllm \
21 | $no_markup_question \
22 | --test-conf $test_conf \
23 | --prompt_format $prompt_format \
24 | --n-repeats $n_repeats \
25 | --temperature $temperature \
26 | --ngpus $ngpus \
27 | --rank $rank
28 |
--------------------------------------------------------------------------------
/eval/math/run_subset_parallel.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from tqdm import tqdm
4 | from glob import glob
5 | import time
6 | import json
7 | import subprocess
8 | import numpy
9 | from utils import read_data
10 | from eval.eval_script import eval_math
11 | from data_processing.process_utils import *
12 |
13 | _worker_num = int(os.environ.get('WORLD_SIZE', 1))
14 | _worker_id = int(os.environ.get('RANK', 0))
15 |
16 | def markup_question(args, item, language, src, task):
17 | for i in range(len(item['messages']) - 2, -1, -2):
18 | if language == 'zh':
19 | if task == 'cot':
20 | item['messages'][i]['content'] = f"{item['messages'][i]['content']}\n请通过逐步推理来解答问题,并把最终答案放置于" + "\\boxed{}中。"
21 | elif task == 'tool':
22 | item['messages'][i]['content'] = f"{item['messages'][i]['content']}\n请结合自然语言和Python程序语言来解答问题,并把最终答案放置于" + "\\boxed{}中。"
23 | else:
24 | pass
25 | elif language == 'en':
26 | if task == 'cot':
27 | item['messages'][i]['content'] = f"{item['messages'][i]['content']}\nPlease reason step by step, and put your final answer within " + "\\boxed{}."
28 | elif task == 'tool':
29 | item['messages'][i]['content'] = f"{item['messages'][i]['content']}\nPlease integrate natural language reasoning with programs to solve the problem above, and put your final answer within " + "\\boxed{}."
30 | else:
31 | pass
32 | return item
33 |
34 | def do_parallel_sampling(args, task, answer_extraction_fn, eval_fn, input_dir, output_dir, log_dir):
35 | if task == 'pal':
36 | code_fname = "run_pal_eval"
37 | elif task == 'cot':
38 | code_fname = "run_cot_eval"
39 | elif task == 'tool':
40 | code_fname = "run_tool_integrated_eval"
41 | else:
42 | raise NotImplementedError()
43 |
44 | n_procs = args.ngpus // args.ngpus_per_model
45 |     n_procs = 4  # temporary override: run 4 parallel processes (one per GPU)
46 |
47 | gpus = [str(i) for i in range(args.ngpus)]
48 | gpu_groups = []
49 | for i in range(n_procs):
50 | gpu_groups.append(gpus[i * args.ngpus_per_model: (i + 1) * args.ngpus_per_model])
51 |
52 | global_n_procs = n_procs * _worker_num
53 |
54 | procs = []
55 | for pid, gpus in enumerate(gpu_groups):
56 | global_pid = n_procs * (args.rank or _worker_id) + pid
57 | logpath = os.path.join(log_dir, f"{global_pid}.log")
58 | f = open(logpath, "w")
59 | cmd = f"python {code_fname}.py " \
60 | f"--data_dir {input_dir} " \
61 | f"--max_num_examples 100000000000000 " \
62 | f"--save_dir {output_dir} " \
63 | f"--model {args.model_path} " \
64 | f"--tokenizer {args.tokenizer_path or args.model_path} " \
65 | f"--eval_batch_size 1 " \
66 | f"--temperature {args.temperature} " \
67 | f"--repeat_id_start 0 " \
68 | f"--n_repeat_sampling {args.n_repeats} " \
69 | f"--n_subsets {global_n_procs} " \
70 | f"--prompt_format {args.prompt_format} " \
71 | f"--few_shot_prompt {args.few_shot_prompt} " \
72 | f"--answer_extraction_fn {answer_extraction_fn} " \
73 | f"--eval_fn {eval_fn} " \
74 | f"--subset_id {global_pid} " \
75 | f"--gpus {','.join(gpus)} "
76 | if args.use_vllm:
77 | cmd += " --use_vllm "
78 | if args.load_in_half:
79 | cmd += " --load_in_half "
80 |
81 | local_metric_path = os.path.join(output_dir, f"metrics.{global_pid}.json")
82 | if not args.overwrite and os.path.exists(local_metric_path) and read_data(local_metric_path)['n_samples'] > 0:
83 | continue
84 |
85 | print("LOGPATH", logpath)
86 | procs.append((global_pid, subprocess.Popen(cmd.split(), stdout=f, stderr=f), f))
87 |
88 | for (global_pid, proc, f) in procs:
89 | print(f"Waiting for the {global_pid}th process to finish ...", flush=True)
90 | proc.wait()
91 |
92 | for (global_pid, proc, f) in procs:
93 | print(f"Closing the {global_pid}th process ...", flush=True)
94 | f.close()
95 |
96 | ### ADDED
97 | agg_li = []
98 | for i in range(n_procs):
99 | agg_li += [json.loads(x) for x in open(f"{output_dir.strip('/')}/results_{i}.jsonl")]
100 |
101 | import jsonlines
102 | with jsonlines.open(f"{output_dir.strip('/')}/results_combined.jsonl", "w") as writer:
103 | writer.write_all(agg_li)
104 |
105 | all_labels = []
106 | for item in agg_li:
107 | all_labels.append(eval_math(item))
108 |
109 | with open(os.path.join(output_dir.strip('/'), "summary.json"), "w") as fout:
110 | json.dump({
111 | "n_samples": len(all_labels),
112 | "accuracy": all_labels.count(True) / len(all_labels)
113 | }, fout, indent=4)
114 |
115 |
116 | # labels, eval_timeout_cnt = evaluate(eval(args.eval_fn), results)
117 | # for item, label in zip(results, labels):
118 | # item['accuracy'] = label
119 |
120 | # print("Calculating accuracy...")
121 | # acc = 0
122 | # for item in results:
123 | # acc += item['accuracy']
124 | # print("output acc = {:.5f}".format(acc / len(results) * 100), flush=True)
125 |
126 | # print(f"Timeout count >>> output eval = {eval_timeout_cnt}", flush=True)
127 |
128 | # pred_fname = "predictions.json"
129 | # if args.n_subsets > 1:
130 | # pred_fname = f"predictions.{args.subset_id}.json"
131 | # with open(os.path.join(args.save_dir, pred_fname), "w") as fout:
132 | # json.dump(results, fout, ensure_ascii=True)
133 |
134 | # metric_fname = "metrics.json"
135 | # if args.n_subsets > 1:
136 | # metric_fname = f"metrics.{args.subset_id}.json"
137 | # with open(os.path.join(args.save_dir, metric_fname), "w") as fout:
138 | # json.dump({
139 | # "n_samples": len(results),
140 | # "accuracy": sum(item['accuracy'] for item in results) / len(results),
141 | # }, fout, indent=4)
142 |
143 |
144 |
145 |
146 | ###
147 |
148 | time.sleep(1)
149 |
150 | local_pids = [global_pid for (global_pid, _, _) in procs]
151 |
152 | agg_preds = []
153 | for fname in glob(os.path.join(output_dir, "predictions.*.json")):
154 | if any(str(pid) in fname for pid in local_pids):
155 | agg_preds.extend(read_data(fname))
156 |
157 | metrics = {}
158 | n_samples = 0
159 | for fname in glob(os.path.join(output_dir, "metrics.*.json")):
160 | if not any(str(pid) in fname for pid in local_pids):
161 | continue
162 | _metrics = read_data(fname)
163 | n_samples += _metrics['n_samples']
164 | for key, val in _metrics.items():
165 | if key != 'n_samples':
166 | metrics[key] = metrics.get(key, 0) + val * _metrics['n_samples']
167 | for key, val in metrics.items():
168 | metrics[key] = val / max(n_samples, 1)
169 |
170 | result_msg = f"n samples = {n_samples}"
171 | for key, val in metrics.items():
172 | result_msg += f"\n{key} = {val * 100}"
173 |
174 | metrics['n_samples'] = n_samples
175 |
176 | return metrics, agg_preds, result_msg
177 |
178 | def main():
179 | parser = argparse.ArgumentParser()
180 | parser.add_argument("--output-dir", type=str, required=True, help="default to `model_path`_predictions")
181 | parser.add_argument("--model-path", type=str, required=True)
182 | parser.add_argument("--tokenizer-path", type=str, default=None)
183 | parser.add_argument("--model-size", type=str, choices=['1b', '7b', '13b', '33b', '34b', '70b'], default="7b")
184 |
185 | parser.add_argument("--test-conf", type=str, default="configs/zero_shot_test_configs.json", help="path to testing data config file that maps from a source to its info")
186 | parser.add_argument("--ngpus", type=int, default=8)
187 | parser.add_argument("--overwrite", action='store_true')
188 | parser.add_argument("--temperature", type=float, default=0)
189 | parser.add_argument("--n-repeats", type=int, default=1)
190 | parser.add_argument("--use-vllm", action='store_true')
191 | parser.add_argument("--load_in_half", action='store_true')
192 |
193 | parser.add_argument("--prompt_format", type=str, default="sft")
194 | parser.add_argument("--few_shot_prompt", type=str, default=None)
195 |
196 | parser.add_argument("--no-markup-question", action='store_true')
197 |
198 | parser.add_argument("--rank", type=int, default=None)
199 | parser.add_argument("--seed", type=int, default=42)
200 | args, _ = parser.parse_known_args()
201 |
202 | print(f"Evaluating {args.model_path}", flush=True)
203 |
204 | if args.output_dir is None:
205 | args.output_dir = f"{args.model_path.rstrip('/')}_predictions"
206 |
207 | args.ngpus_per_model = 4 if args.model_size in ['70b', '33b', '34b'] else 1
208 | assert args.ngpus % args.ngpus_per_model == 0
209 |
210 | default_few_shot_prompt = args.few_shot_prompt
211 |
212 | test_conf = read_data(args.test_conf)
213 |
214 | for src, info in test_conf.items():
215 | if args.n_repeats > 1:
216 | _src = f"{src}/sample_logs"
217 | else:
218 | _src = f"{src}/infer_logs"
219 | if _worker_num > 1:
220 | _src = f"{_src}/{args.rank or _worker_id}"
221 | for task in info['tasks']:
222 | fname = os.path.join(args.output_dir, _src, task, "test_data", "test.jsonl")
223 | input_dir = os.path.dirname(fname)
224 | os.makedirs(input_dir, exist_ok=True)
225 | metric_path = os.path.join(args.output_dir, _src, task, "samples", "metrics.json")
226 | if not args.overwrite and os.path.exists(metric_path) and read_data(metric_path)['n_samples'] > 0:
227 | continue
228 | with open(fname, "w") as file:
229 | data = read_data(info['test_path'])
230 | for i, sample in enumerate(tqdm(data, desc=f'processing {src}')):
231 | fn = eval(info['process_fn'])
232 | sample['id'] = sample.get('id', f"{src}-{i}")
233 | for j, item in enumerate(fn(sample)):
234 | item['dataset'] = src
235 | item['id'] = f"{src}-test-{i}-{j}"
236 | assert 'answer' in item
237 | if not args.no_markup_question:
238 | item = markup_question(args, item, info['language'], src, task)
239 | print(json.dumps(item), file=file, flush=True)
240 |
241 | output_dir = os.path.join(args.output_dir, _src, task, "samples")
242 | log_dir = os.path.join(args.output_dir, _src, task, "logs")
243 | os.makedirs(output_dir, exist_ok=True)
244 | os.makedirs(log_dir, exist_ok=True)
245 | metrics, agg_preds, result_msg = do_parallel_sampling(args, task, info['answer_extraction_fn'], info['eval_fn'], input_dir, output_dir, log_dir)
246 |
247 | os.makedirs(os.path.dirname(metric_path), exist_ok=True)
248 | json.dump(metrics, open(metric_path, "w"), indent=4)
249 | data_path = os.path.join(args.output_dir, _src, task, "samples", "predictions.json")
250 | os.makedirs(os.path.dirname(data_path), exist_ok=True)
251 | with open(data_path, "w") as file:
252 | json.dump(agg_preds, file, ensure_ascii=False)
253 | print(f"src = {src} | task = {task} >>>\n{result_msg}\n\n", flush=True)
254 |
255 | if __name__ == '__main__':
256 | main()
257 |
--------------------------------------------------------------------------------
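For reference, the `--test-conf` JSON maps each data source to its evaluation settings. The keys below are inferred from how `info` is used in run_subset_parallel.py above; the values are hypothetical placeholders (the actual configs live in configs/few_shot_test_configs.json and configs/zero_shot_test_configs.json).

```python
# Hypothetical illustration of the test-config structure consumed by run_subset_parallel.py.
test_conf_example = {
    "math-test": {
        "test_path": "data/MATH_test.jsonl",                         # read via read_data(info['test_path'])
        "language": "en",                                            # used by markup_question
        "tasks": ["cot"],                                            # selects run_cot_eval.py
        "process_fn": "process_math_test",                           # placeholder name; resolved with eval(info['process_fn'])
        "answer_extraction_fn": "extract_math_few_shot_cot_answer",  # passed to the worker script
        "eval_fn": "eval_math",
    }
}
```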
/eval/math/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | import numpy as np
4 |
5 | def set_seed(seed):
6 | if seed > 0:
7 | random.seed(seed)
8 | np.random.seed(seed)
9 |
10 | def shuffle(data, seed):
11 | if seed < 0:
12 | return data
13 | set_seed(seed)
14 | indices = list(range(len(data)))
15 | np.random.shuffle(indices)
16 | data = [data[i] for i in indices]
17 | return data
18 |
19 | def read_data(path):
20 | if path.endswith("json"):
21 | data = json.load(open(path, "r"))
22 | elif path.endswith("jsonl"):
23 | data = []
24 | with open(path, "r") as file:
25 | for line in file:
26 | line = json.loads(line)
27 | data.append(line)
28 | else:
29 | raise NotImplementedError()
30 | return data
31 |
--------------------------------------------------------------------------------
/gen/gen_rft_data.py:
--------------------------------------------------------------------------------
1 | from huggingface_hub import login
2 | import argparse
3 | import json
4 | import re
5 | import jsonlines
6 | from vllm import LLM, SamplingParams
7 | import sys
8 | from tqdm.auto import tqdm
9 |
10 | def generate(model, data_path, tensor_parallel_size=1, temp=0.7, id=0, task="GSM8K"):
11 | inputs = []
12 | answers = []
13 |
14 |     # Assumes 4 GPUs; shards whose generations take longer are assigned slightly fewer problems.
15 |
16 | if task == "GSM8K":
17 | if id == 0:
18 | start, end = 0, 1870
19 | elif id == 1:
20 | start, end = 1870, 3740
21 | elif id == 2:
22 | start, end = 3740, 5610
23 | elif id == 3:
24 | start, end = 5610, 7473
25 |
26 | elif task == "MATH":
27 | if id == 0:
28 | start, end = 0, 1750
29 | elif id == 1:
30 | start, end = 1750, 3750
31 | elif id == 2:
32 | start, end = 3750, 5750
33 | elif id == 3:
34 | start, end = 5750, 7500
35 |
36 | with open(data_path,"r+", encoding="utf8") as f:
37 | for idx, item in enumerate(jsonlines.Reader(f)):
38 |
39 | if not (start <= idx < end):
40 | continue
41 |
42 | inputs.append(item["query"].strip() + "\n")
43 |
44 | temp_ans = item['response'] # just add the response. We will worry about it later.
45 | answers.append(temp_ans)
46 |
47 | result_file = args.result_file.replace(".jsonl", f"_{id}.jsonl")
48 |
49 | try:
50 | already_done = [json.loads(x) for x in open(result_file)]
51 | except:
52 | already_done = []
53 |
54 | if len(already_done) != 0:
55 | inputs = inputs[len(already_done):]
56 | answers = answers[len(already_done):]
57 |
58 | if len(inputs) == 0 and len(answers) == 0:
59 | print("Already completed. Exiting.")
60 | return
61 |
62 | stop_tokens = ["Problem:"]
63 | print("[GPU ID]", id, "Length of inputs", len(inputs))
64 |
65 | if temp == 0.7:
66 |         n = 100  # number of completions sampled per problem (N); adjust as needed.
67 | else:
68 | n = 1
69 |
70 | if task == "GSM8K":
71 | max_token_num = 512
72 | else:
73 | max_token_num = 1024
74 |
75 | sampling_params = SamplingParams(temperature=temp, top_p=1, max_tokens=max_token_num, stop=stop_tokens, n=n)
76 | print('sampling =====', sampling_params)
77 | llm = LLM(model=model,tensor_parallel_size=tensor_parallel_size, enforce_eager=True, swap_space=32)
78 | result = []
79 | res_completions = []
80 |
81 | completions = llm.generate(inputs, sampling_params)
82 |
83 | for num, output in enumerate(completions):
84 | prompt = output.prompt
85 | all_texts = [out.text for out in output.outputs]
86 | res_completions.append(all_texts)
87 |
88 | answer = answers[num]
89 | dict_ = {"prompt": prompt, "preds": all_texts, "answer": answer}
90 |
91 | with jsonlines.open(result_file, 'a') as writer:
92 | writer.write(dict_)
93 |
94 | print('start===', start, ', end====', end)
95 |
96 | def parse_args():
97 | parser = argparse.ArgumentParser()
98 | parser.add_argument("--model", type=str) # model path
99 | parser.add_argument("--data_file", type=str, default='') # data path
100 | parser.add_argument("--tensor_parallel_size", type=int, default=1) # tensor_parallel_size
101 |     parser.add_argument("--result_file", type=str, default="./log_file.jsonl") # path of the result file to write
102 | parser.add_argument("--temp", type=float, default=0.7)
103 | parser.add_argument("--id", type=int, default=0)
104 | parser.add_argument("--task", type=str, default="GSM8K")
105 |
106 | return parser.parse_args()
107 |
108 | if __name__ == "__main__":
109 | # Login First.
110 | login(token="huggingface_token_here")
111 |
112 | args = parse_args()
113 | generate(model=args.model, data_path=args.data_file, tensor_parallel_size=args.tensor_parallel_size, temp=args.temp, id=args.id, task=args.task)
114 |
115 | file_lists = [args.result_file.replace(".jsonl", f"_{i}.jsonl") for i in range(4)]
116 |
117 | # IF completed:
118 | import os
119 | # if all 4 exists, now unite file:
120 | if all([os.path.exists(x) for x in file_lists]) and sum([len([json.loads(x) for x in open(file_name)]) for file_name in file_lists]) == len([json.loads(x) for x in open(args.data_file)]):
121 | from utils_others import unite_file
122 | unite_file(args.result_file, 4, "_gen")
123 |
124 | # Then get rft and dpo data.
125 | gen_name = args.result_file.replace(".jsonl", "_gen.jsonl")
126 |
127 | # This will generate rft and dpo data.
128 | from get_rft_data import run_rft_and_dpo
129 | run_rft_and_dpo(gen_name, args.task)
--------------------------------------------------------------------------------
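Each line gen_rft_data.py appends to its result file is a JSON record for one problem. A minimal, hypothetical sketch of reading those records back (the filename assumes the default `./log_file.jsonl` after the per-GPU shards are united with the `_gen` suffix):

```python
# Hypothetical sketch: inspecting the per-problem records emitted by gen_rft_data.py.
import json

with open("log_file_gen.jsonl") as f:      # united result file (assumed default name)
    for line in f:
        record = json.loads(line)
        prompt = record["prompt"]          # problem statement fed to the model
        preds = record["preds"]            # sampled completions (n=100 at temp 0.7)
        answer = record["answer"]          # reference response from the train file
        print(len(preds), prompt.splitlines()[0])
        break
```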
/gen/gen_rft_data.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | model_path="put_your_FT_model_path" # Put the model path.
3 | result_file="put_your_result_file" # Put the result file path you wish to save the result.
4 | data_file="put_your_data_file" # Put the train jsonl file.
5 | task="GSM8K"
6 |
7 | # Launch one generation process per GPU (model given by $model_path).
8 | for i in $(seq 0 3) # Loop over GPU ids 0 to 3
9 | do
10 | CUDA_VISIBLE_DEVICES=$i python gen_rft_data.py --task $task --model $model_path --data_file $data_file --result_file $result_file --temp 0.7 --id $i &
11 | done
12 | wait
--------------------------------------------------------------------------------
/gen/gen_step_explore.py:
--------------------------------------------------------------------------------
1 | import optparse
2 | import sys
3 | from math_utils.eval_script import eval_math, is_correct
4 | from math_utils.answer_extraction import extract_answer, extract_math_few_shot_cot_answer
5 |
6 | from huggingface_hub import login
7 | import argparse
8 | import json
9 | import jsonlines
10 | from vllm import LLM, SamplingParams
11 | import sys
12 | from utils_others import get_final_steps
13 | from tqdm.auto import tqdm
14 |
15 | MAX_INT = sys.maxsize
16 |
17 | def extract_from_pred(x, y, z):
18 | try:
19 | return extract_math_few_shot_cot_answer(x,y,z)[0]
20 | except:
21 | return "[Invalid]"
22 |
23 | def test(model, data_path, tensor_parallel_size=1, temp=0.7, id=0, k=4, task="GSM8K", result_file=""):
24 | stop_tokens = []
25 |
26 | if temp == 0.7:
27 | n = k # Change this if you want larger 'k' for exploration.
28 | else:
29 | n = 1
30 |
31 | li = [json.loads(x) for x in open(data_path)]
32 |
33 | # Get the rejected sample divided in steps.
34 | if 'final_steps' not in li[0].keys():
35 | li = get_final_steps(li, task)
36 |
37 | li = [x for x in li if x['rejected'].strip("\n").strip() != ""] # filter out here.
38 |
39 | # This is based on the assumption there are 4 GPUs.
40 | portion = (len(li) + 4) // 4
41 | li = li[int(portion * id): int(portion * (id + 1))]
42 |
43 | all_elems = []
44 |
45 | result_file = args.result_file.replace(".jsonl", f"_{id}.jsonl")
46 | ## Exclude already processed.
47 | try:
48 | already_processed = [json.loads(x) for x in open(result_file)]
49 |
50 | if len(already_processed) >= len(li) - 1:
51 | return
52 | else:
53 | li = li[len(already_processed):]
54 | except:
55 | already_processed = []
56 |
57 | max_token_num = 512 if task == "GSM8K" else 1024
58 | sampling_params = SamplingParams(temperature=temp, top_p=1, max_tokens=max_token_num, stop=stop_tokens, n=n)
59 |
60 | print('sampling =====', sampling_params)
61 | print("Length Remaining ...", len(li))
62 | llm = LLM(model=model,tensor_parallel_size=tensor_parallel_size, enforce_eager=True, swap_space=64)
63 |
64 | ## PREPARE SAMPLES
65 | for idx, elem in enumerate(li):
66 | elem['step_pred'] = {}
67 | elem['all_prompt'] = elem['final_steps']
68 | elem['completed'] = False
69 | elem['num_steps'] = len(elem['all_prompt'])
70 | elem['sample_idx'] = idx
71 |
72 | # Get answer.
73 | if task == "GSM8K":
74 | elem['answer'] = elem['chosen'][elem['chosen'].index("The answer is") + len("The answer is"):].strip()
75 | else:
76 | elem['answer'] = extract_math_few_shot_cot_answer(elem['prompt'], elem['chosen'], "")[0]
77 |
78 | all_elems.append(elem)
79 |
80 | print(f"***** [Length Remaining] {len(all_elems)} *****")
81 |
82 |     # Number of samples to process per chunk.
83 | SAMPLES_NUM = 1000
84 |
85 | for i in tqdm(range(0, len(all_elems) + SAMPLES_NUM, SAMPLES_NUM)):
86 | current_samples = all_elems[i:i+SAMPLES_NUM]
87 | max_step = min(max([len(x['all_prompt']) for x in current_samples]), 20) # set 20 as hard limit.
88 |
89 | for step_idx in range(max_step):
90 | # each samples' step_idx-th step.
91 | print("step_idx", step_idx)
92 | curr_step_to_proc = [x['all_prompt'][step_idx] for x in current_samples if x['completed'] == False]
93 | sample_idxs = [x['sample_idx'] for x in current_samples if x['completed'] == False]
94 |
95 | completions = llm.generate(curr_step_to_proc, sampling_params)
96 |
97 | # check if reached answer, if not discard.
98 | for sample_idx, output in zip(sample_idxs, completions):
99 | all_texts = [out.text for out in output.outputs]
100 |
101 | if task == "GSM8K":
102 | def get_answer(text):
103 | try:
104 | ans = text[text.index("The answer is") + len("The answer is"):].strip()
105 | except:
106 | ans = "[invalid]"
107 | return ans
108 |
109 | all_answers = [get_answer(text) for text in all_texts]
110 | else:
111 | all_answers = [extract_from_pred(all_elems[sample_idx]['prompt'], text, "") for text in all_texts]
112 |
113 | # If all false ... no need to continue. (i.e. we found the first pit!)
114 | if task == "GSM8K":
115 | if True not in [str(pred) == str(all_elems[sample_idx]['answer']) for pred in all_answers]:
116 | all_elems[sample_idx]['completed'] = True
117 | else:
118 | if True not in [eval_math({"prediction": pred, "answer": all_elems[sample_idx]['answer']}) for pred in all_answers]:
119 | all_elems[sample_idx]['completed'] = True
120 |
121 | # If processed all steps, mark as completed.
122 |                 if step_idx + 1 >= all_elems[sample_idx]['num_steps']:
123 | all_elems[sample_idx]['completed'] = True
124 |
125 | all_elems[sample_idx]["step_pred"][step_idx + 1] = {"prompt": all_elems[sample_idx]['all_prompt'][step_idx], "preds": all_texts, "answers": all_answers}
126 |
127 | with jsonlines.open(result_file, "a") as writer:
128 | writer.write_all(current_samples)
129 |
130 |         # Process one chunk per invocation and return; the outer shell loop relaunches the script (avoids crashes from memory build-up).
131 | return
132 |
133 | def parse_args():
134 | parser = argparse.ArgumentParser()
135 | parser.add_argument("--model", type=str) # model path
136 | parser.add_argument("--data_file", type=str, default='') # data path
137 | parser.add_argument("--tensor_parallel_size", type=int, default=1) # tensor_parallel_size
138 |     parser.add_argument("--result_file", type=str, default="./log_file.jsonl") # path of the result file to write
139 | parser.add_argument("--temp", type=float, default=0.7)
140 | parser.add_argument("--id", type=int, default=0)
141 | parser.add_argument("--k", type=int, default=4) # how many completions to generate per step.
142 |     parser.add_argument("--task", type=str, default="GSM8K") # task name: GSM8K or MATH.
143 |
144 | return parser.parse_args()
145 |
146 | if __name__ == "__main__":
147 | # Login First.
148 | # login(token="your_hugging_face_token_here")
149 | args = parse_args()
150 | test(model=args.model, data_path=args.data_file, tensor_parallel_size=args.tensor_parallel_size, temp=args.temp, id=args.id, k=args.k, task=args.task, result_file=args.result_file)
151 |
152 | # See if completed.
153 | import os
154 | file_lists = [args.result_file.replace(".jsonl", f"_{i}.jsonl") for i in range(4)]
155 |
156 | # IF completed:
157 | if all([os.path.exists(x) for x in file_lists]) and sum([len([json.loads(x) for x in open(file_name)]) for file_name in file_lists]) == len([json.loads(x) for x in open(args.data_file)]):
158 | from utils_others import unite_file
159 | unite_file(args.result_file, 4)
160 |
161 | # Then run form_gpair
162 | from utils_others import form_gpair
163 | form_gpair(args.result_file, args.result_file.replace(".jsonl", "_gpair_4.jsonl"), args.task)
--------------------------------------------------------------------------------
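After exploration, each record in the per-GPU result file carries a `step_pred` map from step index to the k completions sampled from that step prefix; for a given problem, exploration stops at the first step whose continuations all miss the reference answer, i.e. the first pit. A minimal, hypothetical sketch of locating that step for GSM8K-style string answers (the filename is an assumed example):

```python
# Hypothetical sketch: locating the first "pit" step recorded by gen_step_explore.py.
import json

for line in open("explore_result_0.jsonl"):      # per-GPU result file (suffix _0 .. _3)
    elem = json.loads(line)
    for step, info in sorted(elem["step_pred"].items(), key=lambda kv: int(kv[0])):
        # mirrors the GSM8K branch above: string comparison against the gold answer
        if not any(str(a) == str(elem["answer"]) for a in info["answers"]):
            print("first pit at step", step, "for sample", elem["sample_idx"])
            break
```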
/gen/gen_step_explore.sh:
--------------------------------------------------------------------------------
1 | # Assuming you have 4 GPUs, run as follows to speed up the process.
2 | model_path="put_your_model_path" # Put the model path.
3 | result_file="put_your_result_file" # Put the name of file to save exploration result.
4 | data_file="put_the_dpo_data_here" # Here, put the DPO file generated.
5 | temp=0.7
6 | task="GSM8K" # or MATH
7 |
8 | # Adjust the number of iterations according to the size of the data_file.
9 | for j in $(seq 1 10); do
10 | CUDA_VISIBLE_DEVICES=0 python gen_step_explore.py --task $task --model $model_path --temp $temp --result_file $result_file --data_file $data_file --id 0 &
11 | CUDA_VISIBLE_DEVICES=1 python gen_step_explore.py --task $task --model $model_path --temp $temp --result_file $result_file --data_file $data_file --id 1 &
12 | CUDA_VISIBLE_DEVICES=2 python gen_step_explore.py --task $task --model $model_path --temp $temp --result_file $result_file --data_file $data_file --id 2 &
13 | CUDA_VISIBLE_DEVICES=3 python gen_step_explore.py --task $task --model $model_path --temp $temp --result_file $result_file --data_file $data_file --id 3 &
14 | wait
15 | done
--------------------------------------------------------------------------------
/gen/get_dpo_data.py:
--------------------------------------------------------------------------------
1 | # COLLECT DPO DATA
2 | import json
3 | import editdistance
4 | from tqdm.auto import tqdm
5 | import jsonlines
6 | import argparse
7 | from utils_others import extract_answer
8 |
9 | def collect_dpo_gsm8k(wo_dup_data_path, w_dup_data_path):
10 | wo_dup = [json.loads(x) for x in open(wo_dup_data_path)]
11 | w_dup = [json.loads(x) for x in open(w_dup_data_path)]
12 |
13 | query_preds_dict = {}
14 | dpo_data = []
15 |
16 | # Set Wrong Outputs
17 | for elem in w_dup:
18 |
19 | elem['output_answers'] = [extract_answer(completion) for completion in elem['preds']]
20 | elem['outputs'] = elem['preds']
21 | elem['ground_truth'] = extract_answer(elem['answer'].replace("####", "The answer is"))
22 | elem['input'] = elem['prompt']
23 |
24 | query = elem['input'].rstrip("\n")
25 | wrong_outputs = [out for out, ans in zip(elem['outputs'], elem['output_answers']) if str(ans) != str(elem['ground_truth'])]
26 | query_preds_dict[query] = list(set(wrong_outputs))
27 |
28 | # Find by Query:
29 | for elem in tqdm(wo_dup):
30 | query = elem['query'].rstrip("\n")
31 | if query not in query_preds_dict:
32 | raise AssertionError
33 | else:
34 | new_elem = {}
35 | new_elem['prompt'] = query.rstrip() + "\n"
36 | wrong_outputs = query_preds_dict[query]
37 |
38 | if len(wrong_outputs) == 0:
39 | continue # Skip
40 |
41 | new_elem['chosen'] = elem['response'].strip()
42 | new_elem['rejected'] = sorted([(ref, editdistance.eval(new_elem['chosen'], ref)) for ref in wrong_outputs], key=lambda x: x[1])[-1][0] # Choose the one with the largest edit distance.
43 |
44 | # Remove the rejected from the pool
45 | query_preds_dict[query].remove(new_elem['rejected'])
46 | dpo_data.append(new_elem)
47 |
48 | return dpo_data
49 |
50 | def collect_dpo_math(old_li, fname=""):
51 |
52 | dpo_set = []
53 |
54 | for elem in tqdm(old_li):
55 | corr = elem['filtered_corr']
56 |
57 |         incorr = list(set(elem['filtered_incorr'])) # For incorrect samples, one could deduplicate with a string-level set, i.e. list(set(elem['incorr'])), or with an edit-distance-based criterion.
58 |
59 | # for every correct sample, find the one with the maximum edit distance.
60 | MIN_NUM = min(len(corr), len(incorr))
61 |
62 | for sample in corr[:MIN_NUM]:
63 | all_rej = sorted([(ref, editdistance.eval(sample, ref) / (len(sample) + len(ref))) for ref in incorr], key=lambda x: x[1])
64 |
65 |             # Among candidates with a large enough normalized edit distance, pick the shortest one - this avoids selecting a degenerate (repetitive) sample.
66 | filtered = sorted([x for x in all_rej if x[1] > 0.5], key=lambda x: len(x[0]))
67 | if len(filtered) == 0:
68 | filtered = sorted([x for x in all_rej if x[1] > 0.4], key=lambda x: len(x[0]))
69 | if len(filtered) == 0:
70 | filtered = sorted([x for x in all_rej if x[1] > 0.3], key=lambda x: len(x[0]))
71 | if len(filtered) == 0:
72 | print("Not passed any :( Selecting shortest ...")
73 | filtered = sorted(all_rej, key=lambda x: len(x[0]))
74 |
75 | rej = filtered[0][0]
76 | dpo_set.append({'prompt': elem['question'], 'chosen': sample, 'rejected': rej})
77 | incorr.remove(rej)
78 |
79 | return dpo_set
--------------------------------------------------------------------------------
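Both collectors pick the rejected completion by edit distance to the chosen one: the GSM8K collector takes the wrong completion farthest from the chosen solution, while the MATH collector prefers the shortest among sufficiently distant candidates. An illustrative, self-contained sketch of the GSM8K rule on toy strings:

```python
# Illustrative sketch of the GSM8K rejected-selection rule used in collect_dpo_gsm8k:
# among wrong completions, take the one farthest (by edit distance) from the chosen one.
import editdistance

chosen = "She sold 5 + 3 = <<5+3=8>>8 boxes. The answer is 8"
wrong_outputs = [
    "She sold 5 - 3 = <<5-3=2>>2 boxes. The answer is 2",
    "5 * 3 = 15. The answer is 15",
]
rejected = sorted(
    [(ref, editdistance.eval(chosen, ref)) for ref in wrong_outputs],
    key=lambda x: x[1],
)[-1][0]
print(rejected)  # the wrong completion most dissimilar to the chosen solution
```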
/gen/get_rft_data.py:
--------------------------------------------------------------------------------
1 | # GET RFT DATA
2 | import re
3 | import editdistance
4 | from tqdm.auto import tqdm
5 | import json
6 | import jsonlines
7 | import argparse
8 | from utils_others import extract_answer
9 | from get_dpo_data import collect_dpo_gsm8k, collect_dpo_math
10 | import random
11 | import editdistance
12 |
13 | def read_jsonl(fname):
14 | return [json.loads(line) for line in open(fname)]
15 |
16 | def check_equation(eq):
17 | if eq.find('=') == -1:
18 | return False
19 |
20 | lhs = eq.split('=')[0]
21 | rhs = eq.split('=')[1]
22 |
23 | try:
24 | lhs_result = eval(str(lhs))
25 | if abs(float(lhs_result) - float(rhs)) < 1e-3:
26 | return True
27 | except BaseException:
28 | return False
29 | return False
30 |
31 | def is_eq_correct(equations):
32 | for eq in equations:
33 | if not check_equation(eq):
34 | return False
35 | return True
36 |
37 | def collect_rft_data_gsm8k(fname):
38 |
39 | eq_pattern = r'<<([^>]*)>>' # Match everything inside << >>
40 |     # eq_pattern = r'\b\d+\.?\d*\b'  # Use this pattern instead for Mistral-MetaMath.
41 |
42 | all_sets = [] # list of dicts of 'query' and 'response'.
43 | old_li = read_jsonl(fname)
44 |
45 | for sample in tqdm(old_li):
46 | sample['preds'] = list(set(sample['preds']))
47 | sample['output_answers'] = [extract_answer(completion) for completion in sample['preds']]
48 | sample['outputs'] = [x.strip() for x in sample['preds']]
49 | sample['ground_truth'] = extract_answer(sample['answer'].replace("####", "The answer is"))
50 | sample['input'] = sample['prompt'].rstrip() + "\n"
51 |
52 | exist_match = []
53 | correct_preds = [reasoning for reasoning, answer in zip(sample['outputs'], sample['output_answers']) if str(answer) == str(sample['ground_truth'])]
54 |
55 | matches = [{"reasoning": r, "equations": re.findall(eq_pattern, r)} for r in correct_preds] # Find all matches
56 |
57 |         # Remove the following line when using Mistral-MetaMath.
58 | matches = [m for m in matches if is_eq_correct(m['equations'])]
59 |
60 | final_preds = {}
61 |
62 | for elem in matches:
63 | match_string = '|'.join(elem['equations']).replace(' ', '')
64 | if match_string in final_preds:
65 | other_solutions = [final_preds[k] for k in final_preds if k != match_string]
66 | now_score = sum([editdistance.eval(elem['reasoning'], ref) for ref in other_solutions])
67 | original_score = sum([editdistance.eval(final_preds[match_string], ref) for ref in other_solutions])
68 | if now_score > original_score:
69 | final_preds[match_string] = elem['reasoning']
70 | else:
71 | final_preds[match_string] = elem['reasoning']
72 |
73 | sample['rft_outputs'] = list(final_preds.values())
74 |
75 | MAX_NUM = 8
76 | for rft_sample in sample['rft_outputs'][:MAX_NUM]:
77 | all_sets.append({'query': sample['input'].rstrip("\n"), 'response': rft_sample})
78 |
79 | return old_li, all_sets
80 |
81 | def collect_rft_data_math(fname):
82 | from math_utils.eval_script import eval_math, is_correct
83 | from math_utils.answer_extraction import extract_answer, extract_math_few_shot_cot_answer
84 |
85 | random.seed(42)
86 | MAX_SAMPLES = 8
87 |
88 | def remove_dup(corr, N=100):
89 | corr = list(set(corr))
90 | random.shuffle(corr)
91 | final_results = []
92 |
93 | for elem in corr:
94 |             if all([editdistance.eval(elem, x) / (len(elem) + len(x) // 2) > 0.2 for x in final_results]):  # NOTE: `//` binds tighter than `+`; use (len(elem) + len(x)) / 2 if the average length was intended.
95 | final_results.append(elem)
96 | if len(final_results) >= N:
97 | break
98 | return final_results
99 |
100 |     # The input file should already be aggregated.
101 | li = [json.loads(x) for x in open(fname)]
102 |
103 |     # Make sure to check this yourself.
104 | # if len(li) != 7500 or len(set([x['prompt'] for x in li])) != 7500:
105 | # raise ValueError("RFT is missing some file. Make sure all questions are properly handled.")
106 |
107 |     # [To-do] Change the data file path below if needed.
108 | all_q = [json.loads(x) for x in open("../data/MATH_train.jsonl")]
109 |
110 | for orig_elem, elem in zip(all_q, li):
111 | elem['answer'] = extract_answer(elem['answer'])
112 | elem['question'] = orig_elem['query']
113 |
114 | new_li = []
115 |
116 | for idx in tqdm(range(len(li))):
117 |
118 | li[idx]['incorr'] = []
119 | li[idx]['corr'] = []
120 | q = li[idx]['question']
121 |
122 | for pred in li[idx]['preds']:
123 |             # Only keep predictions whose final line contains the answer marker, so that
124 |             # degenerate repetitions are not used as negatives. We want good-quality negatives. :)
125 |             # tgt_str = "Final Answer:"  # alternative marker, kept for reference
126 |             tgt_str = "The answer is"
127 |
128 | if tgt_str not in pred.strip().split("\n")[-1]:
129 | if tgt_str in pred:
130 | pred_idx = pred.index(tgt_str)
131 | end_of_line_idx = pred.find('\n', pred_idx + len(tgt_str))
132 | if end_of_line_idx != -1:
133 | pred = pred[:end_of_line_idx]
134 | else:
135 | print("Final Answer Error.")
136 | continue # just skip.
137 | else:
138 | continue # If not found, likely to be repetition, or incomplete solution.
139 |
140 | try:
141 | new_x = {"prediction": extract_math_few_shot_cot_answer(q, pred, "")[0], "answer": li[idx]['answer']}
142 | out = eval_math(new_x)
143 | except:
144 | out = False
145 |
146 | if out:
147 | li[idx]['corr'].append(pred)
148 | else:
149 | li[idx]['incorr'].append(pred)
150 |
151 | filtered_corr = remove_dup(li[idx]['corr'], N=MAX_SAMPLES)
152 |         filtered_corr = [x.replace("I hope it is correct.", "").strip() for x in filtered_corr]
153 | li[idx]['filtered_corr'] = filtered_corr
154 |
155 | if len(filtered_corr) == 0:
156 | li[idx]['filtered_incorr'] = []
157 | continue
158 |
159 | filtered_incorr = remove_dup(li[idx]['incorr'], N=MAX_SAMPLES * 3) # for simplicity.
160 |         filtered_incorr = [x.replace("I hope it is correct.", "").strip() for x in filtered_incorr]
161 | li[idx]['filtered_incorr'] = filtered_incorr
162 |
163 | for filtered_corr_sample in filtered_corr:
164 | new_li.append({"query": q.rstrip() + "\n", "response": filtered_corr_sample})
165 |
166 | return li, new_li
167 |
168 | def run_rft_and_dpo(fname, task):
169 |
170 | fname1 = fname
171 | fname2 = fname1.replace("gen", "rft")
172 | fname3 = fname1.replace("gen", "dpo")
173 |
174 | collect_rft_fn = collect_rft_data_gsm8k if task == "GSM8K" else collect_rft_data_math
175 | collect_dpo_fn = collect_dpo_gsm8k if task == "GSM8K" else collect_dpo_math
176 |
177 | old_set, new_set = collect_rft_fn(fname1)
178 |
179 | # Generate RFT file
180 | with jsonlines.open(fname2, mode='w') as writer:
181 | writer.write_all(new_set)
182 |
183 | if task == "GSM8K":
184 | new_set = collect_dpo_fn(fname2, fname1)
185 | else:
186 | new_set = collect_dpo_fn(old_set, fname1)
187 |
188 | # Generate DPO file
189 | with jsonlines.open(fname3, mode='w') as writer:
190 | writer.write_all(new_set)
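191 | 
192 | # --- Illustrative entry point (a sketch; the original pipeline may call run_rft_and_dpo differently, e.g. from gen_rft_data.sh) ---
193 | # if __name__ == "__main__":
194 | #     parser = argparse.ArgumentParser()
195 | #     parser.add_argument("--fname", type=str, required=True, help="aggregated generation file whose name contains 'gen'")
196 | #     parser.add_argument("--task", type=str, choices=["GSM8K", "MATH"], default="GSM8K")
197 | #     args = parser.parse_args()
198 | #     run_rft_and_dpo(args.fname, args.task)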
--------------------------------------------------------------------------------
/gen/math_utils/eval_script.py:
--------------------------------------------------------------------------------
1 | import regex
2 | from copy import deepcopy
3 | from .eval_utils import math_equal
4 | from .ocwcourses_eval_utils import normalize_numeric, numeric_equality, normalize_symbolic_equation, SymbolicMathMixin
5 |
6 | def is_correct(item, pred_key='prediction', prec=1e-3):
7 | # print("called again ...", item)
8 | pred = item[pred_key]
9 | ans = item['answer']
10 | if isinstance(pred, list) and isinstance(ans, list):
11 | pred_matched = set()
12 | ans_matched = set()
13 | for i in range(len(pred)):
14 | for j in range(len(ans)):
15 | item_cpy = deepcopy(item)
16 | item_cpy.update({
17 | pred_key: pred[i],
18 | 'answer': ans[j]
19 | })
20 | if is_correct(item_cpy, pred_key=pred_key, prec=prec):
21 | pred_matched.add(i)
22 | ans_matched.add(j)
23 |                 # Leftover debug trace for one specific prediction; disabled to keep evaluation output clean.
24 |                 # if item_cpy[pred_key] == '2,3,4':
25 |                 #     print(item, flush=True)
26 | return len(pred_matched) == len(pred) and len(ans_matched) == len(ans)
27 | elif isinstance(pred, str) and isinstance(ans, str):
28 | if '\\cup' in pred and '\\cup' in ans:
29 | item = deepcopy(item)
30 | item.update({
31 | pred_key: pred.split('\\cup'),
32 | 'answer': ans.split('\\cup'),
33 | })
34 | return is_correct(item, pred_key=pred_key, prec=prec)
35 | else:
36 | label = False
37 | try:
38 | label = abs(float(regex.sub(r',', '', str(pred))) - float(regex.sub(r',', '', str(ans)))) < prec
39 | except:
40 | pass
41 | label = label or (ans and pred == ans) or math_equal(pred, ans)
42 | return label
43 | else:
44 | print(item, flush=True)
45 | raise NotImplementedError()
46 |
47 | def eval_math(item, pred_key='prediction', prec=1e-3):
48 | pred = item[pred_key]
49 | if pred_key == 'program_output' and isinstance(pred, str):
50 | pred = [pred]
51 | ans = item['answer']
52 | if isinstance(pred, list) and isinstance(ans, list):
53 | # for some questions in MATH, `reference` repeats answers
54 | _ans = []
55 | for a in ans:
56 | if a not in _ans:
57 | _ans.append(a)
58 | ans = _ans
59 |         # some predictions for MATH questions also repeat answers
60 | _pred = []
61 | for a in pred:
62 | if a not in _pred:
63 | _pred.append(a)
64 | # some predictions mistakenly box non-answer strings
65 | pred = _pred[-len(ans):]
66 |
67 | item.update({
68 | pred_key: pred,
69 | 'answer': ans
70 | })
71 | return is_correct(item, pred_key=pred_key, prec=prec)
72 |
73 | def eval_last_single_answer(item, pred_key='prediction', prec=1e-3):
74 | for key in [pred_key, 'answer']:
75 | assert isinstance(item[key], str), f"{key} = `{item[key]}` is not a str"
76 | return is_correct(item, pred_key=pred_key, prec=prec)
77 |
78 | def eval_agieval_gaokao_math_cloze(item, pred_key='prediction', prec=1e-3):
79 | if pred_key == 'program_output' and isinstance(item[pred_key], str):
80 | item[pred_key] = [item[pred_key]]
81 | for key in [pred_key, 'answer']:
82 | assert isinstance(item[key], list), f"{key} = `{item[key]}` is not a list"
83 | pred = item[pred_key]
84 | ans = item['answer']
85 | _pred = []
86 | for p in pred:
87 | p = p + ";"
88 | while p:
89 | left_brackets = 0
90 | for i in range(len(p)):
91 | if p[i] == ';' or (p[i] == ',' and left_brackets == 0):
92 | _p, p = p[:i].strip(), p[i + 1:].strip()
93 | if _p not in _pred:
94 | _pred.append(_p)
95 | break
96 | elif p[i] in '([{':
97 | left_brackets += 1
98 | elif p[i] in ')]}':
99 | left_brackets -= 1
100 | pred = _pred[-len(ans):]
101 | if len(pred) == len(ans):
102 | for p, a in zip(pred, ans):
103 | item.update({
104 | pred_key: p,
105 | 'answer': a,
106 | })
107 | if not is_correct(item, pred_key=pred_key, prec=prec):
108 | return False
109 | return True
110 | else:
111 | return False
112 |
113 | def eval_agieval_gaokao_mathqa(item, pred_key='prediction', prec=1e-3):
114 | if pred_key == 'program_output' and isinstance(item[pred_key], str):
115 | item[pred_key] = [item[pred_key]]
116 | pred_str = " ".join(item[pred_key])
117 | ans = item['answer']
118 | tag = None
119 | idx = -1
120 | for t in 'ABCD':
121 | if t in pred_str and pred_str.index(t) > idx:
122 | tag = t
123 | idx = pred_str.index(t)
124 | return tag == ans
125 |
126 | def eval_math_sat(item, pred_key='prediction', prec=1e-3):
127 | for key in [pred_key, 'answer']:
128 | assert isinstance(item[key], str), f"{key} = `{item[key]}` is not a str"
129 | return item[pred_key].lower() == item['answer'].lower()
130 |
131 | def eval_mmlu_stem(item, pred_key='prediction', prec=1e-3):
132 | return eval_math_sat(item, pred_key=pred_key, prec=prec)
133 |
134 | def eval_ocwcourses(item, pred_key='prediction', prec=1e-3):
135 | INVALID_ANSWER = "[invalidanswer]"
136 | for key in [pred_key, 'answer']:
137 | assert isinstance(item[key], str), f"{key} = `{item[key]}` is not a str"
138 | pred = item[pred_key]
139 | ans = item['answer']
140 |
141 | try:
142 | float(ans)
143 | normalize_fn = normalize_numeric
144 | is_equiv = numeric_equality
145 | answer_type = "numeric"
146 | except ValueError:
147 | if "=" in ans:
148 | normalize_fn = normalize_symbolic_equation
149 | is_equiv = lambda x, y: x==y
150 | answer_type = "equation"
151 | else:
152 | normalize_fn = SymbolicMathMixin().normalize_tex
153 | is_equiv = SymbolicMathMixin().is_tex_equiv
154 | answer_type = "expression"
155 |
156 | correct_answer = normalize_fn(ans)
157 |
158 | unnormalized_answer = pred if pred else INVALID_ANSWER
159 | model_answer = normalize_fn(unnormalized_answer)
160 |
161 | if unnormalized_answer == INVALID_ANSWER:
162 | acc = 0
163 | elif model_answer == INVALID_ANSWER:
164 | acc = 0
165 | elif is_equiv(model_answer, correct_answer):
166 | acc = 1
167 | else:
168 | acc = 0
169 |
170 | return acc
171 |
172 | def eval_minif2f_isabelle(item, pred_key='prediction', prec=1e-3):
173 | return True
174 |
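175 | # --- Minimal usage sketch (illustrative, not part of the original module) ---
176 | # eval_math({'prediction': '\\frac{1}{2}', 'answer': '\\frac{1}{2}'})  # -> True via exact string match in is_correct
177 | # eval_math({'prediction': '0.5', 'answer': '1/2'})  # -> True, provided math_equal treats the two as equivalent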
--------------------------------------------------------------------------------
/gen/math_utils/ocwcourses_eval_utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | import numpy as np
3 | import sympy
4 | from sympy.core.sympify import SympifyError
5 | from sympy.parsing.latex import parse_latex
6 |
7 | import signal
8 |
9 | INVALID_ANSWER = "[invalidanswer]"
10 |
11 | class timeout:
12 | def __init__(self, seconds=1, error_message="Timeout"):
13 | self.seconds = seconds
14 | self.error_message = error_message
15 |
16 | def handle_timeout(self, signum, frame):
17 | raise TimeoutError(self.error_message)
18 |
19 | def __enter__(self):
20 | signal.signal(signal.SIGALRM, self.handle_timeout)
21 | signal.alarm(self.seconds)
22 |
23 | def __exit__(self, type, value, traceback):
24 | signal.alarm(0)
25 |
26 | def normalize_numeric(s):
27 | if s is None:
28 | return None
29 | for unit in [
30 | "eV",
31 | " \\mathrm{~kg} \\cdot \\mathrm{m} / \\mathrm{s}",
32 | " kg m/s",
33 | "kg*m/s",
34 | "kg",
35 | "m/s",
36 | "m / s",
37 | "m s^{-1}",
38 | "\\text{ m/s}",
39 | " \\mathrm{m/s}",
40 | " \\text{ m/s}",
41 | "g/mole",
42 | "g/mol",
43 | "\\mathrm{~g}",
44 | "\\mathrm{~g} / \\mathrm{mol}",
45 | "W",
46 | "erg/s",
47 | "years",
48 | "year",
49 | "cm",
50 | ]:
51 | s = s.replace(unit, "")
52 | s = s.strip()
53 | for maybe_unit in ["m", "s", "cm"]:
54 | s = s.replace("\\mathrm{" + maybe_unit + "}", "")
55 | s = s.replace("\\mathrm{~" + maybe_unit + "}", "")
56 | s = s.strip()
57 | s = s.strip("$")
58 | try:
59 | return float(eval(s))
60 | except:
61 | try:
62 | expr = parse_latex(s)
63 | if expr.is_number:
64 | return float(expr)
65 | return INVALID_ANSWER
66 | except:
67 | return INVALID_ANSWER
68 |
69 | def numeric_equality(n1, n2, threshold=0.01):
70 | if n1 is None or n2 is None:
71 | return False
72 | if np.isclose(n1, 0) or np.isclose(n2, 0) or np.isclose(n1 - n2, 0):
73 | return np.abs(n1 - n2) < threshold * (n1 + n2) / 2
74 | else:
75 | return np.isclose(n1, n2)
76 |
77 | def normalize_symbolic_equation(s):
78 | if not isinstance(s, str):
79 | return INVALID_ANSWER
80 | if s.startswith("\\["):
81 | s = s[2:]
82 | if s.endswith("\\]"):
83 | s = s[:-2]
84 | s = s.replace("\\left(", "(")
85 | s = s.replace("\\right)", ")")
86 | s = s.replace("\\\\", "\\")
87 | if s.startswith("$") or s.endswith("$"):
88 | s = s.strip("$")
89 | try:
90 | maybe_expression = parse_latex(s)
91 | if not isinstance(maybe_expression, sympy.core.relational.Equality):
92 |             # we expected an equation, but parsed a bare expression
93 | return INVALID_ANSWER
94 | else:
95 | return maybe_expression
96 | except:
97 | return INVALID_ANSWER
98 |
99 | class SymbolicMathMixin:
100 | """
101 | Methods useful for parsing mathematical expressions from text and determining equivalence of expressions.
102 | """
103 |
104 | SUBSTITUTIONS = [ # used for text normalize
105 | ("an ", ""),
106 | ("a ", ""),
107 | (".$", "$"),
108 | ("\\$", ""),
109 | (r"\ ", ""),
110 | (" ", ""),
111 | ("mbox", "text"),
112 | (",\\text{and}", ","),
113 | ("\\text{and}", ","),
114 | ("\\text{m}", "\\text{}"),
115 | ]
116 | REMOVED_EXPRESSIONS = [ # used for text normalizer
117 | "square",
118 | "ways",
119 | "integers",
120 | "dollars",
121 | "mph",
122 | "inches",
123 | "ft",
124 | "hours",
125 | "km",
126 | "units",
127 | "\\ldots",
128 | "sue",
129 | "points",
130 | "feet",
131 | "minutes",
132 | "digits",
133 | "cents",
134 | "degrees",
135 | "cm",
136 | "gm",
137 | "pounds",
138 | "meters",
139 | "meals",
140 | "edges",
141 | "students",
142 | "childrentickets",
143 | "multiples",
144 | "\\text{s}",
145 | "\\text{.}",
146 | "\\text{\ns}",
147 | "\\text{}^2",
148 | "\\text{}^3",
149 | "\\text{\n}",
150 | "\\text{}",
151 | r"\mathrm{th}",
152 | r"^\circ",
153 | r"^{\circ}",
154 | r"\;",
155 | r",\!",
156 | "{,}",
157 | '"',
158 | "\\dots",
159 | ]
160 |
161 | def normalize_tex(self, final_answer: str) -> str:
162 | """
163 | Normalizes a string representing a mathematical expression.
164 | Used as a preprocessing step before parsing methods.
165 |
166 | Copied character for character from appendix D of Lewkowycz et al. (2022)
167 | """
168 | final_answer = final_answer.split("=")[-1]
169 |
170 | for before, after in self.SUBSTITUTIONS:
171 | final_answer = final_answer.replace(before, after)
172 | for expr in self.REMOVED_EXPRESSIONS:
173 | final_answer = final_answer.replace(expr, "")
174 |
175 | # Extract answer that is in LaTeX math, is bold,
176 | # is surrounded by a box, etc.
177 | final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer)
178 | final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer)
179 | final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer)
180 | final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer)
181 | final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer)
182 |
183 | # Normalize shorthand TeX:
184 | # \fracab -> \frac{a}{b}
185 | # \frac{abc}{bef} -> \frac{abc}{bef}
186 | # \fracabc -> \frac{a}{b}c
187 | # \sqrta -> \sqrt{a}
188 |         # \sqrtab -> \sqrt{a}b
189 | final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer)
190 | final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer)
191 | final_answer = final_answer.replace("$", "")
192 |
193 | # Normalize 100,000 -> 100000
194 | if final_answer.replace(",", "").isdigit():
195 | final_answer = final_answer.replace(",", "")
196 |
197 | return final_answer
198 |
199 | def parse_tex(self, text: str, time_limit: int = 5) -> sympy.Basic:
200 | """
201 |         Wrapper around `sympy.parsing.latex.parse_latex` that outputs a SymPy expression.
202 |         Typically, you want to apply `normalize_tex` as a preprocessing step.
203 | """
204 | try:
205 | with timeout(seconds=time_limit):
206 | parsed = parse_latex(text)
207 | except (
208 | # general error handling: there is a long tail of possible sympy/other
209 | # errors we would like to catch
210 | Exception
211 | ) as e:
212 | print(f"failed to parse {text} with exception {e}")
213 | return None
214 |
215 | return parsed
216 |
217 | def is_exp_equiv(self, x1: sympy.Basic, x2: sympy.Basic, time_limit=5) -> bool:
218 | """
219 | Determines whether two sympy expressions are equal.
220 | """
221 | try:
222 | with timeout(seconds=time_limit):
223 | try:
224 | diff = x1 - x2
225 | except (SympifyError, ValueError, TypeError) as e:
226 | print(
227 | f"Couldn't subtract {x1} and {x2} with exception {e}"
228 | )
229 | return False
230 |
231 | try:
232 | if sympy.simplify(diff) == 0:
233 | return True
234 | else:
235 | return False
236 | except (SympifyError, ValueError, TypeError) as e:
237 | print(f"Failed to simplify {x1}-{x2} with {e}")
238 | return False
239 | except TimeoutError as e:
240 | print(f"Timed out comparing {x1} and {x2}")
241 | return False
242 | except Exception as e:
243 | print(f"failed on unrecognized exception {e}")
244 | return False
245 |
246 | def is_tex_equiv(self, x1: str, x2: str, time_limit=5) -> bool:
247 | """
248 |         Determines whether two (ideally normalized using `normalize_tex`) TeX expressions are equal.
249 |
250 | Does so by first checking for string exact-match, then falls back on sympy-equivalence,
251 | following the (Lewkowycz et al. 2022) methodology.
252 | """
253 | if x1 == x2:
254 | # don't resort to sympy if we have full string match, post-normalization
255 | return True
256 | else:
257 |             return False  # NOTE: this early return makes the sympy fallback below unreachable; drop it to enable symbolic checking.
258 | parsed_x2 = self.parse_tex(x2)
259 | if not parsed_x2:
260 | # if our reference fails to parse into a Sympy object,
261 | # we forgo parsing + checking our generated answer.
262 | return False
263 | return self.is_exp_equiv(self.parse_tex(x1), parsed_x2, time_limit=time_limit)
264 |
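265 | # --- Illustrative normalization example (not part of the original module) ---
266 | # SymbolicMathMixin().normalize_tex("$\\boxed{\\frac{1}{2}}$ dollars")  # -> "\\frac{1}{2}"
267 | # The unit word "dollars" and the \boxed{...} wrapper are stripped before comparison.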
--------------------------------------------------------------------------------
/images/main_result_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hbin0701/Self-Explore/54b8ba9f4cee1ef52f72de916e100e0da2ae6863/images/main_result_image.png
--------------------------------------------------------------------------------
/images/overview_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hbin0701/Self-Explore/54b8ba9f4cee1ef52f72de916e100e0da2ae6863/images/overview_image.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.27.0
2 | datasets==2.14.6
3 | deepspeed==0.13.1
4 | editdistance==0.8.1
5 | einops==0.7.0
6 | flash-attn==2.5.2
7 | huggingface-hub==0.20.3
8 | json5==0.9.14
9 | jsonlines==4.0.0
10 | numpy==1.24.3
11 | openai==1.13.3
12 | sentencepiece==0.1.99
13 | transformers==4.38.1
14 | triton==2.1.0
15 | trl==0.7.10
16 | vllm==0.3.2
17 | wandb==0.16.1
--------------------------------------------------------------------------------
/scripts/gsm8k/dpo/config.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments
2 | model_name_or_path: # put model_name_to_train_on_here
3 | run_name: # put_wandb_run_name_here
4 |
5 | # Data training arguments
6 | # For definitions, see: src/h4/training/config.py
7 | dataset_mixer:
8 | HuggingFaceH4/ultrafeedback_binarized: 1.0
9 | dataset_splits:
10 | - train_prefs
11 | - test_prefs
12 | preprocessing_num_workers: 12
13 | train_data_file: train_data_file_directory # put the training data file here (used for both train and test).
14 | test_data_file: train_data_file_directory # the test/eval set is dummy data (a subset of the training data), since eval data is not actually used.
15 |
16 | # DPOTrainer argument
17 | bf16: true
18 | beta: 0.1
19 | do_eval: true
20 | evaluation_strategy: epoch
21 | gradient_accumulation_steps: 1
22 | gradient_checkpointing: true
23 | hub_model_id: zephyr-7b-dpo-full
24 | learning_rate: 1.0e-7 # use 1.0e-7 for Mistral, 1.0e-6 for others.
25 | log_level: info
26 | logging_steps: 10
27 | lr_scheduler_type: linear
28 | max_length: 384
29 | num_train_epochs: 3
30 | optim: rmsprop
31 | output_dir: some_directory_to_save # Put directory for saving model checkpoints here
32 | per_device_train_batch_size: 8
33 | per_device_eval_batch_size: 8
34 | push_to_hub: false
35 | save_strategy: "epoch"
36 | save_only_model: true
37 | save_total_limit: 3
38 | seed: 42
39 | warmup_ratio: 0.1
--------------------------------------------------------------------------------
/scripts/gsm8k/dpo/deepspeed_zero3.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | main_process_port: 32899
3 | debug: false
4 | deepspeed_config:
5 | deepspeed_multinode_launcher: standard
6 | offload_optimizer_device: none
7 | offload_param_device: none
8 | zero3_init_flag: true
9 | zero3_save_16bit_model: true
10 | zero_stage: 3
11 | distributed_type: DEEPSPEED
12 | downcast_bf16: 'no'
13 | machine_rank: 0
14 | main_training_function: main
15 | mixed_precision: bf16
16 | num_machines: 1
17 | num_processes: 4
18 | rdzv_backend: static
19 | same_network: true
20 | tpu_env: []
21 | tpu_use_cluster: false
22 | tpu_use_sudo: false
23 | use_cpu: false
24 |
--------------------------------------------------------------------------------
/scripts/gsm8k/dpo/run_dpo.sh:
--------------------------------------------------------------------------------
1 | # Run from the main directory.
2 | CUDA_VISIBLE_DEVICES=0,1,2,3 ACCELERATE_LOG_LEVEL=info accelerate launch --config_file scripts/gsm8k/dpo/deepspeed_zero3.yaml dpo/run_dpo.py scripts/gsm8k/dpo/config.yaml
--------------------------------------------------------------------------------
/scripts/gsm8k/sft/config.yaml:
--------------------------------------------------------------------------------
1 | command_file: null
2 | commands: null
3 | main_process_port: 51888 # note: duplicated by 'main_process_port: null' below; the run scripts also pass --main_process_port explicitly
4 | compute_environment: LOCAL_MACHINE
5 | deepspeed_config:
6 | deepspeed_multinode_launcher: standard
7 | gradient_clipping: 1.0
8 | zero_stage: 1
9 | distributed_type: DEEPSPEED
10 | downcast_bf16: true
11 | dynamo_backend: 'NO'
12 | gpu_ids: null
13 | machine_rank: 0
14 | main_process_ip: null
15 | main_process_port: null
16 | main_training_function: main
17 | fsdp_config: {}
18 | megatron_lm_config: {}
19 | mixed_precision: bf16
20 | num_machines: 1
21 | num_processes: 4
22 | rdzv_backend: static
23 | same_network: true
24 | tpu_name: null
25 | tpu_zone: null
26 | use_cpu: false
27 |
--------------------------------------------------------------------------------
/scripts/gsm8k/sft/run_ft.sh:
--------------------------------------------------------------------------------
1 | export WANDB_API_KEY=wandb_key # put your wandb key here
2 | export WANDB_PROJECT=project_name # put your project name here
3 | export WANDB_ENTITY=my_id # put your wandb id here.
4 |
5 | model_name_or_path=/home/user/models/deepseek-math-7b-base # model path to train on.
6 | save_generator_id=deepseek_GSM8K_FT # model name to be saved.
7 |
8 | save_dir=/home/user/models/${save_generator_id}/
9 | export WANDB_NAME=${save_generator_id}
10 |
11 | # lr: 1e-6 for Mistral and 1e-5 for Others.
12 |
13 | accelerate launch \
14 | --config_file scripts/gsm8k/sft/config.yaml \
15 | --main_process_port=40999 \
16 | sft/train_generator.py \
17 | --model_name_or_path ${model_name_or_path} \
18 | --data_dir put_your_data_file_here \
19 | --target_set train \
20 | --save_dir ${save_dir} \
21 | --num_train_epoches 5 \
22 | --save_strategy epoch \
23 | --per_device_train_batch_size 16 \
24 | --per_device_eval_batch_size 4 \
25 | --gradient_accumulation_steps 1 \
26 | --gradient_checkpointing True \
27 | --learning_rate 1e-5 \
28 | --weight_decay 0 \
29 | --lr_scheduler_type "linear" \
30 | --warmup_steps 0 \
31 | --save_best False \
32 | --save_total_limit 5 \
33 | --logging_dir ./wandb \
34 | --logging_steps 8 \
35 | --seed 42 \
36 | --save_model_only True \
37 | --mode "ft_GSM8K" # mode is one of "ft_GSM8K", "ft_MATH", "rft_GSM8K", "rft_MATH"
--------------------------------------------------------------------------------
/scripts/gsm8k/sft/run_rft.sh:
--------------------------------------------------------------------------------
1 | export WANDB_API_KEY=wandb_key # put your wandb key here
2 | export WANDB_PROJECT=project_name # put your project name here
3 | export WANDB_ENTITY=my_id # put your wandb id here.
4 |
5 | model_name_or_path=/home/user/models/deepseek-math-7b-base # model path to train on.
6 | save_generator_id=deepseek_GSM8K_RFT # model name to be saved.
7 |
8 | save_dir=/home/user/models/${save_generator_id}/
9 | export WANDB_NAME=${save_generator_id}
10 |
11 | # lr: 1e-6 for Mistral and 1e-5 for Others.
12 |
13 | accelerate launch \
14 | --config_file scripts/gsm8k/sft/config.yaml \
15 | --main_process_port=40999 \
16 | sft/train_generator.py \
17 | --model_name_or_path ${model_name_or_path} \
18 | --data_dir put_your_data_file_here \
19 | --target_set train \
20 | --save_dir ${save_dir} \
21 | --num_train_epoches 5 \
22 | --save_strategy epoch \
23 | --per_device_train_batch_size 16 \
24 | --per_device_eval_batch_size 4 \
25 | --gradient_accumulation_steps 1 \
26 | --gradient_checkpointing True \
27 | --learning_rate 1e-5 \
28 | --weight_decay 0 \
29 | --lr_scheduler_type "linear" \
30 | --warmup_steps 0 \
31 | --save_best False \
32 | --save_total_limit 5 \
33 | --logging_dir ./wandb \
34 | --logging_steps 8 \
35 | --seed 42 \
36 | --save_model_only True \
37 | --mode "rft_GSM8K" # mode is one of "ft_GSM8K", "ft_MATH", "rft_GSM8K", "rft_MATH"
--------------------------------------------------------------------------------
/scripts/math/dpo/config.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments
2 | model_name_or_path: # put model_name_to_train_on_here
3 | run_name: # put_wandb_run_name_here
4 |
5 | # Data training arguments
6 | # For definitions, see: src/h4/training/config.py
7 | dataset_mixer:
8 | HuggingFaceH4/ultrafeedback_binarized: 1.0
9 | dataset_splits:
10 | - train_prefs
11 | - test_prefs
12 | preprocessing_num_workers: 12
13 | train_data_file: train_data_file_directory # put the training data file here (used for both train and test).
14 | test_data_file: train_data_file_directory # the test/eval set is dummy data (a subset of the training data), since eval data is not actually used.
15 |
16 | # DPOTrainer argument
17 | bf16: true
18 | beta: 0.1
19 | do_eval: true
20 | evaluation_strategy: epoch
21 | gradient_accumulation_steps: 1
22 | gradient_checkpointing: true
23 | hub_model_id: zephyr-7b-dpo-full
24 | learning_rate: 1.0e-7 # use 1.0e-7 for Mistral, 1.0e-6 for others.
25 | log_level: info
26 | logging_steps: 10
27 | lr_scheduler_type: linear
28 | max_length: 1024
29 | max_prompt_length: 512
30 | num_train_epochs: 3
31 | optim: rmsprop
32 | output_dir: some_directory_to_save # Put directory for saving model checkpoints here
33 | per_device_train_batch_size: 8
34 | per_device_eval_batch_size: 8
35 | push_to_hub: false
36 | save_strategy: "epoch"
37 | save_only_model: true
38 | save_total_limit: 3
39 | seed: 42
40 | warmup_ratio: 0.1
--------------------------------------------------------------------------------
/scripts/math/dpo/deepspeed_zero3.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | main_process_port: 48882
3 | debug: false
4 | deepspeed_config:
5 | deepspeed_multinode_launcher: standard
6 | offload_optimizer_device: none
7 | offload_param_device: none
8 | zero3_init_flag: true
9 | zero3_save_16bit_model: true
10 | zero_stage: 3
11 | distributed_type: DEEPSPEED
12 | downcast_bf16: 'no'
13 | machine_rank: 0
14 | main_training_function: main
15 | mixed_precision: bf16
16 | num_machines: 1
17 | num_processes: 4
18 | rdzv_backend: static
19 | same_network: true
20 | tpu_env: []
21 | tpu_use_cluster: false
22 | tpu_use_sudo: false
23 | use_cpu: false
24 |
--------------------------------------------------------------------------------
/scripts/math/dpo/run_dpo.sh:
--------------------------------------------------------------------------------
1 | # Run from the main directory.
2 | CUDA_VISIBLE_DEVICES=0,1,2,3 ACCELERATE_LOG_LEVEL=info accelerate launch --config_file scripts/math/dpo/deepspeed_zero3.yaml dpo/run_dpo.py scripts/math/dpo/config.yaml
--------------------------------------------------------------------------------
/scripts/math/sft/config.yaml:
--------------------------------------------------------------------------------
1 | command_file: null
2 | commands: null
3 | main_process_port: 51888 # note: duplicated by 'main_process_port: null' below; the run scripts also pass --main_process_port explicitly
4 | compute_environment: LOCAL_MACHINE
5 | deepspeed_config:
6 | deepspeed_multinode_launcher: standard
7 | gradient_clipping: 1.0
8 | zero_stage: 1
9 | distributed_type: DEEPSPEED
10 | downcast_bf16: true
11 | dynamo_backend: 'NO'
12 | gpu_ids: null
13 | machine_rank: 0
14 | main_process_ip: null
15 | main_process_port: null
16 | main_training_function: main
17 | fsdp_config: {}
18 | megatron_lm_config: {}
19 | mixed_precision: bf16
20 | num_machines: 1
21 | num_processes: 4
22 | rdzv_backend: static
23 | same_network: true
24 | tpu_name: null
25 | tpu_zone: null
26 | use_cpu: false
27 |
--------------------------------------------------------------------------------
/scripts/math/sft/run_ft.sh:
--------------------------------------------------------------------------------
1 | export WANDB_API_KEY=wandb_key # put your wandb key here
2 | export WANDB_PROJECT=project_name # put your project name here
3 | export WANDB_ENTITY=my_id # put your wandb id here.
4 |
5 | model_name_or_path=/home/user/models/deepseek-math-7b-base # model path to train on.
6 | save_generator_id=deepseek_math_FT # model name to be saved.
7 |
8 | save_dir=/home/user/models/${save_generator_id}/
9 | export WANDB_NAME=${save_generator_id}
10 |
11 | # lr: 1e-6 for Mistral and 1e-5 for Others.
12 |
13 | accelerate launch \
14 | --config_file scripts/math/sft/config.yaml \
15 | --main_process_port=40999 \
16 | sft/train_generator.py \
17 | --model_name_or_path ${model_name_or_path} \
18 | --data_dir put_your_data_file_here \
19 | --target_set train \
20 | --save_dir ${save_dir} \
21 | --num_train_epoches 5 \
22 | --save_strategy epoch \
23 | --max_length 1024 \
24 | --per_device_train_batch_size 8 \
25 | --per_device_eval_batch_size 2 \
26 | --gradient_accumulation_steps 2 \
27 | --gradient_checkpointing True \
28 | --learning_rate 1e-5 \
29 | --weight_decay 0 \
30 | --lr_scheduler_type "linear" \
31 | --warmup_steps 0 \
32 | --save_best False \
33 | --save_total_limit 5 \
34 | --logging_dir ./wandb \
35 | --logging_steps 8 \
36 | --seed 42 \
37 | --save_model_only True \
38 | --mode "ft_MATH" # mode is one of "ft_GSM8K", "ft_MATH", "rft_GSM8K", "rft_MATH"
--------------------------------------------------------------------------------
/scripts/math/sft/run_rft.sh:
--------------------------------------------------------------------------------
1 | export WANDB_API_KEY=wandb_key # put your wandb key here
2 | export WANDB_PROJECT=project_name # put your project name here
3 | export WANDB_ENTITY=my_id # put your wandb id here.
4 |
5 | model_name_or_path=/home/user/models/deepseek-math-7b-base # model path to train on.
6 | save_generator_id=deepseek_math_RFT # model name to be saved.
7 |
8 | save_dir=/home/user/models/${save_generator_id}/
9 | export WANDB_NAME=${save_generator_id}
10 |
11 | # lr: 1e-6 for Mistral and 1e-5 for Others.
12 |
13 | accelerate launch \
14 | --config_file scripts/math/sft/config.yaml \
15 | --main_process_port=40999 \
16 | sft/train_generator.py \
17 | --model_name_or_path ${model_name_or_path} \
18 | --data_dir put_your_data_file_here \
19 | --target_set train \
20 | --save_dir ${save_dir} \
21 | --num_train_epoches 5 \
22 | --save_strategy epoch \
23 | --max_length 1024 \
24 | --per_device_train_batch_size 8 \
25 | --per_device_eval_batch_size 2 \
26 | --gradient_accumulation_steps 2 \
27 | --gradient_checkpointing True \
28 | --learning_rate 1e-5 \
29 | --weight_decay 0 \
30 | --lr_scheduler_type "linear" \
31 | --warmup_steps 0 \
32 | --save_best False \
33 | --save_total_limit 5 \
34 | --logging_dir ./wandb \
35 | --logging_steps 8 \
36 | --seed 42 \
37 | --save_model_only True \
38 | --mode "rft_MATH" # mode is one of "ft_GSM8K", "ft_MATH", "rft_GSM8K", "rft_MATH"
--------------------------------------------------------------------------------
/sft/sft_datasets.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import re
4 | import torch
5 | import torch.nn.functional as F
6 | from typing import Optional, Sequence, List, Set, Dict, Any, Union
7 | import transformers
8 | import logging
9 | from dataclasses import dataclass
10 | import pathlib
11 |
12 | from utils.datasets import read_jsonl, get_few_shot_prompt, left_pad_sequences, right_pad_sequences, mask_labels
13 | from utils.constants import DEFAULT_BOS_TOKEN, DEFAULT_EOS_TOKEN, IGNORE_INDEX
14 | from utils.gsm8k.decoding import extract_answer
15 |
16 | def get_examples(data_dir, split, mode):
17 |
18 | train_mode, dataset = mode.split("_")
19 |
20 | if train_mode == "ft":
21 | data_dir = {
22 | 'train': data_dir,
23 | 'test': data_dir,
24 | }[split]
25 |
26 | examples = read_jsonl(data_dir)
27 |
28 | if dataset == "GSM8K":
29 | for ex in examples:
30 | ex['response'] = ex['response'].replace('#### ', 'The answer is ')
31 |
32 | elif train_mode == "rft":
33 | examples = read_jsonl(data_dir)
34 |
35 | else:
36 | raise NotImplementedError
37 |
38 | if dataset == "GSM8K":
39 | for ex in examples:
40 | if "question" not in ex:
41 | ex["question"] = ex["query"].strip()
42 | if "answer" not in ex:
43 | ex["answer"] = ex["response"].strip()
44 |
45 | if 'mm' not in data_dir.lower() and 'metamath' not in data_dir.lower():
46 | ex.update(question=ex["question"].rstrip() + "\n")
47 | ex.update(answer=ex["answer"].replace('#### ', 'The answer is '))
48 | else:
49 | ex.update(question=ex["question"].rstrip() + "\n")
50 | ex.update(answer=ex["answer"].strip())
51 | else:
52 | for ex in examples:
53 | if "question" not in ex:
54 | ex["question"] = ex["query"].rstrip() + "\n"
55 | if "answer" not in ex:
56 | ex["answer"] = ex["response"].strip()
57 |
58 | print(examples[0])
59 | print(f"{len(examples)} {split} examples")
60 | return examples
61 |
62 |
63 | def make_finetuning_generator_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args: dataclass) -> Dict:
64 | train_dataset = FineTuningGeneratorDataset(
65 | tokenizer=tokenizer,
66 | data_dir=data_args.data_dir,
67 | mode=data_args.mode,
68 | target_set=data_args.target_set,
69 | loss_on_prefix=data_args.loss_on_prefix,
70 | )
71 | val_dataset = None
72 |
73 | return dict(train_dataset=train_dataset, val_dataset=val_dataset)
74 |
75 |
76 | def make_test_generator_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args: dataclass, inference_args: dataclass) -> Dict:
77 | test_dataset = TestGeneratorDataset(
78 | tokenizer=tokenizer,
79 | data_dir=data_args.data_dir,
80 | target_set=data_args.target_set
81 | )
82 | return test_dataset
83 |
84 | class FineTuningGeneratorDataset(torch.utils.data.Dataset):
85 | def __init__(
86 | self,
87 | tokenizer: transformers.PreTrainedTokenizer = None,
88 | mode: str="ft",
89 | data_dir: str = 'data/gsm8k',
90 | target_set: str = 'train',
91 | loss_on_prefix=True,
92 | ):
93 | self.tokenizer = tokenizer
94 | self.data_dir = data_dir
95 | self.target_set = target_set
96 | self.loss_on_prefix = loss_on_prefix
97 | self.pad_token_id = tokenizer.pad_token_id
98 | self.eos_token_id = tokenizer.eos_token_id
99 |
100 | print("+ [Dataset] Loading Training Data")
101 | self.examples = get_examples(self.data_dir, target_set, mode)
102 | qns_str = [ex["question"] for ex in self.examples]
103 | ans_str = [ex["answer"] for ex in self.examples]
104 |
105 |         print("+ [Dataset] Tokenizing Training Data")
106 |
107 | qns_tokens = []
108 | ans_tokens = []
109 |
110 | for x, y in zip(qns_str, ans_str):
111 | x_ = tokenizer(x, padding=False).input_ids
112 | y_ = tokenizer(y, padding=False, add_special_tokens=False).input_ids
113 |
114 | if len(x_) + len(y_) >= 1024:
115 | continue
116 |
117 | else:
118 | qns_tokens.append(x_)
119 | ans_tokens.append(y_)
120 |
121 | # qns_tokens = tokenizer(qns_str, padding=False, max_length=1024).input_ids
122 | # ans_tokens = tokenizer(ans_str, padding=False, add_special_tokens=False, max_length=1024).input_ids
123 |
124 | self.qns_str = qns_str
125 | self.ans_str = ans_str
126 | self.qns_tokens = qns_tokens
127 | self.ans_tokens = ans_tokens
128 |
129 | print("MAX QNS TOKENS", max([len(qns_tokens[i]) for i in range(len(qns_tokens))]))
130 | print("MAX ANS TOKENS", max([len(ans_tokens[i]) for i in range(len(ans_tokens))]))
131 |
132 | self.max_len = max([
133 | len(qns_tokens[i]) + len(ans_tokens[i]) + 1
134 | for i in range(len(qns_tokens))
135 | ]
136 | )
137 | print(f"Max tokens: {self.max_len}")
138 | print("Length:", len(self.qns_tokens))
139 |
140 | def __len__(self):
141 | return len(self.qns_tokens)
142 |
143 | def __getitem__(self, idx):
144 | qn_tokens = self.qns_tokens[idx]
145 | ans_tokens = self.ans_tokens[idx]
146 |
147 | input_ids = qn_tokens + ans_tokens + [self.eos_token_id]
148 | # input_ids = qn_tokens + ans_tokens
149 | labels = input_ids
150 |
151 | masks = (
152 | ([1] if self.loss_on_prefix else [0]) * len(qn_tokens)
153 | + ([1] * len(ans_tokens))
154 | + ([1])
155 | )
156 | labels = mask_labels(labels, masks)
157 |
158 | input_ids = torch.tensor(input_ids)
159 | labels = torch.tensor(labels)
160 | return dict(input_ids=input_ids, labels=labels)
161 |
162 | def collate_fn(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
163 | input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
164 | input_ids, attention_mask = right_pad_sequences(input_ids, padding_value=self.pad_token_id, return_attention_mask=True)
165 | labels = right_pad_sequences(labels, padding_value=IGNORE_INDEX, return_attention_mask=False)
166 |
167 | return dict(
168 | input_ids=input_ids,
169 | attention_mask=attention_mask,
170 | labels=labels,
171 | )
172 |
173 |
174 |
175 |
176 | class TestGeneratorDataset(torch.utils.data.Dataset):
177 | """Left Padding"""
178 | def __init__(
179 | self,
180 | tokenizer: transformers.PreTrainedTokenizer = None,
181 | data_dir: str = 'data/gsm8k',
182 | target_set: str = None,
183 | ):
184 | self.tokenizer = tokenizer
185 | self.data_dir = data_dir
186 | self.target_set = target_set
187 | self.pad_token_id = tokenizer.pad_token_id
188 |
189 | print("+ [Dataset] Loading Testing Data")
190 |         self.examples = get_examples(data_dir, target_set)  # NOTE: get_examples() also expects a `mode` argument (see its definition above).
191 | qns_str = [ex["question"] for ex in self.examples]
192 | ans_str = [ex["answer"] for ex in self.examples]
193 | gts_str = [extract_answer(ans) for ans in ans_str]
194 |
195 | print("+ [Dataset] Tokenizing Testing Data")
196 |         qns_tokens = tokenizer(qns_str, padding=False, max_length=1024, truncation=True).input_ids
197 |         ans_tokens = tokenizer(ans_str, padding=False, add_special_tokens=False, max_length=1024, truncation=True).input_ids
198 |
199 | self.qns_str = qns_str
200 | self.qns_tokens = qns_tokens
201 | self.ans_str = ans_str
202 | self.gts_str = gts_str
203 |
204 | self.max_len = max([
205 | len(qns_tokens[i]) + len(ans_tokens[i]) + 1
206 | for i in range(len(qns_tokens))
207 | ]
208 | )
209 | print(f"Max tokens: {self.max_len}")
210 |
211 | def __len__(self):
212 | return len(self.examples)
213 |
214 | def __getitem__(self, idx):
215 | qn_tokens = self.qns_tokens[idx]
216 | qn_str = self.qns_str[idx]
217 | ans_str = self.ans_str[idx]
218 | gt = self.gts_str[idx]
219 |
220 | input_ids = torch.tensor(qn_tokens)
221 | return dict(
222 | idx=idx,
223 | input_ids=input_ids,
224 | input=qn_str,
225 | question=qn_str,
226 | reference=gt,
227 | record_data=dict(answer=ans_str, ground_truth=gt),
228 | )
229 |
230 | def collate_fn(self, instances: Sequence[Dict]) -> Dict[str, Any]:
231 | idx, input_ids, input, question, reference, record_data = tuple([instance[key] for instance in instances] for key in ("idx", "input_ids", "input", "question", "reference", "record_data"))
232 | record_data = {k: [instance[k] for instance in record_data] for k in record_data[0].keys()}
233 |
234 | input_ids, attention_mask = left_pad_sequences(input_ids, padding_value=self.pad_token_id, return_attention_mask=True)
235 |
236 | return dict(
237 | idx=idx,
238 | input_ids=input_ids,
239 | attention_mask=attention_mask,
240 | input=input,
241 | question=question,
242 | reference=reference,
243 | record_data=record_data,
244 | )
245 |
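246 | # --- Illustrative construction (a sketch; the tokenizer name, data path, and field names are assumptions) ---
247 | # import transformers
248 | # tok = transformers.AutoTokenizer.from_pretrained("some/causal-lm")
249 | # ds = FineTuningGeneratorDataset(tokenizer=tok, data_dir="data/GSM8K_train.jsonl",
250 | #                                 mode="ft_GSM8K", target_set="train")
251 | # batch = ds.collate_fn([ds[0], ds[1]])  # right-padded input_ids / attention_mask / labels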
--------------------------------------------------------------------------------
/sft/train_generator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import transformers
3 | from tqdm.auto import tqdm
4 | from dataclasses import dataclass, field
5 | from typing import Optional, List, Dict, Set, Any, Union
6 | import gc
7 | from accelerate import Accelerator
8 | import wandb
9 | import os
10 | import re
11 | from huggingface_hub import login
12 |
13 | from utils.states import set_deepspeed_config, set_training_states, set_random_seed
14 | from utils.optim import get_optimizers
15 | from utils.models import build_model, save_llm_checkpoint, save_llm, save_training_args_with_accelerator
16 | from utils.datasets import make_training_dataloaders
17 |
18 |
19 | @dataclass
20 | class ModelArguments:
21 | model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
22 |
23 | @dataclass
24 | class DataArguments:
25 | dataset: str = field(default='gsm8k')
26 | data_dir: str = field(default='data/gsm8k/', metadata={"help": "Path to the training data."})
27 | target_set: str = field(default='train')
28 | mode: str = field(default="ft", metadata={"help": "ft or rft"})
29 | loss_on_prefix: bool = field(default=True, metadata={"help": "Whether to compute loss on the prefix"})
30 |
31 | @dataclass
32 | class TrainingArguments:
33 | cache_dir: Optional[str] = field(default=None)
34 | model_max_length: int = field(
35 | default=2048,
36 | metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
37 | )
38 |
39 | max_steps: int = field(default=-1, metadata={"help": "When it is specified, num_train_epoches is ignored"})
40 | num_train_epoches: int = field(default=1)
41 | per_device_train_batch_size: int = field(default=4)
42 | gradient_accumulation_steps: int = field(default=1)
43 | gradient_checkpointing: bool = field(default=True)
44 |
45 | eval_steps: int = field(default=-1, metadata={"help": "When it is specified, eval_epoches is ignored"})
46 | eval_epoches: int = field(default=1)
47 | evaluation_strategy: str = field(default="epoch")
48 |
49 | per_device_eval_batch_size: int = field(default=4)
50 |
51 | learning_rate: float = field(default=1e-5)
52 | weight_decay: float = field(default=0)
53 | lr_scheduler_type: str = field(default="linear")
54 | warmup_steps: int = field(default=-1, metadata={"help": "When it is specified, warmup_ratio is ignored"})
55 | warmup_ratio: float = field(default=0)
56 |
57 | logging_steps: int = field(default=-1, metadata={"help": "When it is specified, logging_epoches is ignored"})
58 | logging_epoches: int = field(default=1)
59 |
60 | save_steps: int = field(default=-1, metadata={"help": "When it is specified, save_epoches is ignored"})
61 | save_epoches: int = field(default=1)
62 | save_total_limit: int = field(default=3)
63 | save_best: bool = field(default=False)
64 | save_strategy: str = field(default="epoch")
65 | save_model_only: bool = field(default=True)
66 |
67 | seed: int = field(default=42)
68 |
69 | @dataclass
70 | class GenerationArguments:
71 | do_sample: bool = field(default=False)
72 | num_beams: int = field(default=1)
73 |
74 | temperature: float = field(default=0.7)
75 | top_k: int = field(default=50)
76 | top_p: float = field(default=1.0)
77 | repetition_penalty: float = field(default=1.0)
78 | length_penalty: float = field(default=1.0)
79 |
80 | max_length : int = field(default=2048)
81 | max_new_tokens: int = field(default=400)
82 |
83 | @dataclass
84 | class OutputArguments:
85 | logging_dir: str = field(default='wandb/')
86 | save_dir: str = field(default='checkpoints/')
87 |
88 |
89 | def main():
90 | parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments, GenerationArguments, OutputArguments))
91 | model_args, data_args, training_args, generation_args, output_args = parser.parse_args_into_dataclasses()
92 |     config_args_dict = {**model_args.__dict__, **data_args.__dict__, **training_args.__dict__}  # dict.update() returns None, so merge the dicts directly
93 | set_random_seed(training_args.seed)
94 |
95 | # print("TORCH VERSION", torch.__version__)
96 | # print("TORCH NCCL VERSION", torch.cuda.nccl.version())
97 |
98 | from sft_datasets import make_finetuning_generator_data_module
99 |
100 | accelerator = Accelerator(gradient_accumulation_steps=training_args.gradient_accumulation_steps)
101 |
102 | # load model, tokenizer, and dataloader
103 | set_deepspeed_config(accelerator, training_args)
104 | model, tokenizer = build_model(model_args, training_args)
105 |
106 | data_module = make_finetuning_generator_data_module(tokenizer, data_args)
107 | train_dataloader, val_dataloader = make_training_dataloaders(data_module, training_args)
108 |
109 | # config optimizer and scheduler
110 | set_training_states(data_module, training_args)
111 | optimizer, lr_scheduler = get_optimizers(model, training_args)
112 |
113 | model, train_dataloader, optimizer = accelerator.prepare(model, train_dataloader, optimizer)
114 |
115 | model.to(torch.bfloat16)
116 |
117 | cur_epoch = local_step = global_step = 0
118 | start_local_step = start_global_step = -1
119 |
120 | # init wandb
121 | if accelerator.is_main_process:
122 | project_name = os.environ['WANDB_PROJECT']
123 | logging_dir = os.path.join(output_args.logging_dir, project_name)
124 |
125 | os.makedirs(logging_dir, exist_ok=True)
126 | wandb_id = wandb.util.generate_id()
127 | wandb.init(id=wandb_id, dir=logging_dir, config=config_args_dict)
128 |
129 |
130 | # training
131 | global_step = 0
132 | model.train()
133 | while global_step < training_args.num_training_steps:
134 | train_dataloader_iterator = tqdm(enumerate(train_dataloader), total=len(train_dataloader), desc=f'Training - Epoch {cur_epoch+1} / {training_args.num_train_epoches}') if accelerator.is_main_process else enumerate(train_dataloader)
135 |
136 | for local_step, batch in train_dataloader_iterator:
137 | if global_step < start_global_step:
138 | global_step += 1
139 | continue
140 |
141 | batch_input = {k: v for k, v in batch.items() if k in ('input_ids', 'attention_mask', 'labels')}
142 | # backpropagation
143 | with accelerator.accumulate(model):
144 | output = model(**batch_input, return_dict=True)
145 | loss = output.loss
146 | accelerator.backward(loss)
147 |
148 | optimizer.step()
149 | if not accelerator.optimizer_step_was_skipped and global_step % training_args.gradient_accumulation_steps == 0:
150 | lr_scheduler.step()
151 | optimizer.zero_grad()
152 |
153 | # training logging
154 | if accelerator.is_main_process:
155 | train_dataloader_iterator.set_postfix(epoch=cur_epoch, step=local_step, loss=loss.item())
156 |
157 | if global_step % training_args.num_logging_steps == 0:
158 | wandb.log({
159 | 'loss': loss.item(),
160 | 'lr': lr_scheduler.get_last_lr()[0]
161 | }, step=global_step)
162 |
163 | # save checkpoint
164 | if global_step != 0 and global_step % training_args.per_save_steps == 0:
165 | accelerator.wait_for_everyone()
166 | save_llm_checkpoint(accelerator, model, tokenizer, output_args.save_dir, global_step, training_args.save_total_limit)
167 |
168 | # save states for resuming [Not needed]
169 | # if global_step != 0 and global_step % training_args.per_save_steps == 0:
170 | # accelerator.wait_for_everyone()
171 | # accelerator.save_state(os.path.join(output_args.save_dir, 'resume'))
172 | global_step += 1
173 |
174 | cur_epoch += 1
175 |
176 | gc.collect(); torch.cuda.empty_cache()
177 |
178 | accelerator.wait_for_everyone()
179 | save_llm(accelerator, model, tokenizer, output_args.save_dir)
180 | save_training_args_with_accelerator(accelerator, training_args, output_args.save_dir)
181 |
182 | accelerator.save_state(os.path.join(output_args.save_dir, 'resume'))
183 | if accelerator.is_main_process:
184 | wandb.finish()
185 |
186 | if __name__ == "__main__":
187 | login(token="your_hugging_face_token_here") # put your huggingface token here!
188 | main()
189 |
--------------------------------------------------------------------------------
/sft/utils/constants.py:
--------------------------------------------------------------------------------
1 | IGNORE_INDEX = -100
2 | DEFAULT_PAD_TOKEN = ""
3 | DEFAULT_BOS_TOKEN = ""
4 | DEFAULT_EOS_TOKEN = ""
5 | DEFAULT_UNK_TOKEN = ""
6 |
7 | LLAMA_EQUALS_TOKENS = set([353, 3892, 29922, 10457]) # _=, )=, =, =-
8 | LLAMA_LEFTMARK_TOKENS = set([3532, 9314]) # <<, _<<
9 | LLAMA_RIGHTMARK_TOKEN = 6778 # >>
10 | LLAMA_NEWLINE_TOKEN = 13 # \n
11 |
12 |
--------------------------------------------------------------------------------
/sft/utils/flash_attn_monkey_patch.py:
--------------------------------------------------------------------------------
1 | # refer to https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama2_flash_attn_monkey_patch.py
2 |
3 | import warnings
4 | from typing import Optional, Tuple
5 |
6 | import torch
7 | from flash_attn import __version__ as flash_attn_version
8 | from flash_attn.bert_padding import pad_input, unpad_input
9 | from flash_attn.flash_attn_interface import (
10 | flash_attn_func,
11 | flash_attn_varlen_kvpacked_func,
12 | )
13 | from transformers.models.llama.modeling_llama import (
14 | LlamaAttention,
15 | LlamaModel,
16 | rotate_half,
17 | )
18 | from transformers.models.mistral.modeling_mistral import (
19 | MistralAttention,
20 | MistralModel,
21 | rotate_half,
22 | )
23 |
24 |
25 |
26 | def apply_rotary_pos_emb(q, k, cos_sin, position_ids):
27 | gather_indices = position_ids[:, :, None, None] # [bsz, seq_len, 1, 1]
28 | gather_indices = gather_indices.repeat(
29 | 1, 1, cos_sin[0].shape[1], cos_sin[0].shape[3]
30 | )
31 | bsz = gather_indices.shape[0]
32 | cos, sin = (
33 | torch.gather(x.transpose(1, 2).repeat(bsz, 1, 1, 1), 1, gather_indices)
34 | for x in cos_sin
35 | )
36 | q, k = ((x * cos) + (rotate_half(x) * sin) for x in (q, k))
37 | return q, k
38 |
39 |
40 | def forward(
41 | self,
42 | hidden_states: torch.Tensor,
43 | attention_mask: Optional[torch.Tensor] = None,
44 | position_ids: Optional[torch.Tensor] = None,
45 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
46 | output_attentions: bool = False,
47 | use_cache: bool = False,
48 | padding_mask: Optional[torch.Tensor] = None,
49 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
50 | if output_attentions:
51 | warnings.warn(
52 | "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
53 | )
54 |
55 | bsz, q_len, _ = hidden_states.size()
56 | kv_heads = getattr(self, "num_key_value_heads", self.num_heads)
57 |
58 | q, k, v = (
59 | op(hidden_states).view(bsz, q_len, nh, self.head_dim)
60 | for op, nh in (
61 | (self.q_proj, self.num_heads),
62 | (self.k_proj, kv_heads),
63 | (self.v_proj, kv_heads),
64 | )
65 | )
66 | # shape: (b, s, num_heads, head_dim)
67 |
68 | kv_seq_len = k.shape[1]
69 | past_kv_len = 0
70 | if past_key_value is not None:
71 | past_kv_len = past_key_value[0].shape[2]
72 | kv_seq_len += past_kv_len
73 |
74 | cos_sin = self.rotary_emb(v, seq_len=kv_seq_len)
75 | q, k = apply_rotary_pos_emb(q, k, cos_sin, position_ids)
76 |
77 | if past_key_value is not None:
78 | assert (
79 | flash_attn_version >= "2.1.0"
80 | ), "past_key_value support requires flash-attn >= 2.1.0"
81 | # reuse k, v
82 | k = torch.cat([past_key_value[0].transpose(1, 2), k], dim=1)
83 | v = torch.cat([past_key_value[1].transpose(1, 2), v], dim=1)
84 |
85 | past_key_value = (k.transpose(1, 2), v.transpose(1, 2)) if use_cache else None
86 |
87 | if attention_mask is None:
88 | output = flash_attn_func(q, k, v, 0.0, softmax_scale=None, causal=True).view(
89 | bsz, q_len, -1
90 | )
91 | else:
92 | q, indices, cu_q_lens, max_s = unpad_input(q, attention_mask[:, -q_len:])
93 | # We can skip concat and call unpad twice but seems better to call unpad only once.
94 | kv, _, cu_k_lens, max_k = unpad_input(
95 | torch.stack((k, v), dim=2), attention_mask
96 | )
97 | output_unpad = flash_attn_varlen_kvpacked_func(
98 | q,
99 | kv,
100 | cu_q_lens,
101 | cu_k_lens,
102 | max_s,
103 | max_k,
104 | 0.0,
105 | softmax_scale=None,
106 | causal=True,
107 | )
108 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
109 | output = pad_input(output_unpad, indices, bsz, q_len)
110 |
111 | return self.o_proj(output), None, past_key_value
112 |
113 |
114 | # Disable the transformation of the attention mask in LlamaModel as flash attention
115 | # takes a boolean key_padding_mask. Fills in the past kv length for use in forward.
116 | def _prepare_decoder_attention_mask(
117 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length
118 | ):
119 |
120 | if attention_mask is not None and torch.all(attention_mask):
121 | return None # This uses the faster call when training with full samples
122 |
123 | return attention_mask
124 |
125 |
126 | def replace_llama_attn_with_flash_attn():
127 | cuda_major, cuda_minor = torch.cuda.get_device_capability()
128 | if cuda_major < 8:
129 | warnings.warn(
130 |             "Flash attention is only supported on A100 or H100 GPUs during training due to head dim > 64 backward. "
131 |             "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
132 | )
133 |
134 | LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
135 | LlamaAttention.forward = forward
136 |
137 |
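138 | # --- Usage note (illustrative) ---
139 | # from utils.flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
140 | # replace_llama_attn_with_flash_attn()  # typically called once, before the LLaMA model is instantiated
141 | # Note: only the LLaMA classes are patched here; the Mistral imports above are currently unused.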
--------------------------------------------------------------------------------
/sft/utils/gsm8k/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hbin0701/Self-Explore/54b8ba9f4cee1ef52f72de916e100e0da2ae6863/sft/utils/gsm8k/__init__.py
--------------------------------------------------------------------------------
/sft/utils/gsm8k/decoding.py:
--------------------------------------------------------------------------------
1 | ### From OVM/utils/gsm8k/decoding.py
2 |
3 | from contextlib import contextmanager
4 | import signal
5 | import json
6 | import os
7 | import re
8 |
9 |
10 | # ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
11 | ANS_RE = re.compile(r"The answer is:?\s*(\-?[0-9\.\,]+)")
12 | INVALID_ANS = "[invalid]"
13 |
14 |
15 | def extract_answer(completion):
16 | match = ANS_RE.search(completion)
17 | if match:
18 | match_str = match.group(1).strip()
19 | st_str = standardize_value_str(match_str)
20 | try: eval(st_str); return st_str
21 | except: ...
22 | return INVALID_ANS
23 |
24 | def extract_answers(completions):
25 | return [extract_answer(completion) for completion in completions]
26 |
27 | def standardize_value_str(x):
28 | """Standardize numerical values"""
29 | y = x.replace(",", "")
30 | if '.' in y:
31 | y = y.rstrip('0')
32 | if y[-1] == '.':
33 | y = y[:-1]
34 | if not len(y):
35 | return INVALID_ANS
36 | if y[0] == '.':
37 | y = '0' + y
38 | if y[-1] == '%':
39 | y = str(eval(y[:-1]) / 100)
40 | return y.rstrip('.')
41 |
42 | def get_answer_label(response_answer, gt):
43 | if response_answer == INVALID_ANS:
44 | return INVALID_ANS
45 | return response_answer == gt
46 |
47 |
48 | # taken from
49 | # https://stackoverflow.com/questions/492519/timeout-on-a-function-call
50 | @contextmanager
51 | def timeout(duration, formula):
52 | def timeout_handler(signum, frame):
53 | raise Exception(f"'{formula}': timed out after {duration} seconds")
54 |
55 | signal.signal(signal.SIGALRM, timeout_handler)
56 | signal.alarm(duration)
57 | yield
58 | signal.alarm(0)
59 |
60 |
61 | def eval_with_timeout(formula, max_time=3):
62 | try:
63 | with timeout(max_time, formula):
64 | return round(eval(formula), ndigits=4)
65 | except Exception as e:
66 | signal.alarm(0)
67 | print(f"Warning: Failed to eval {formula}, exception: {e}")
68 | return None
69 |
70 |
71 | # refer to https://github.com/openai/grade-school-math/blob/master/grade_school_math/calculator.py
72 | def use_calculator(sample):
73 | if "<<" not in sample:
74 | return None
75 |
76 | parts = sample.split("<<")
77 | remaining = parts[-1]
78 | if ">>" in remaining:
79 | return None
80 | if "=" not in remaining:
81 | return None
82 | lhs = remaining.split("=")[0]
83 | lhs = lhs.replace(",", "")
84 | if any([x not in "0123456789*+-/.()" for x in lhs]):
85 | return None
86 | ans = eval_with_timeout(lhs)
87 | if remaining[-1] == '-' and ans is not None and ans < 0:
88 | ans = -ans
89 | return ans
90 |
91 |
92 |
93 |
94 |
95 |
96 |
--------------------------------------------------------------------------------
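A quick, illustrative use of the helpers above (the input strings are made up for demonstration):

```python
from utils.gsm8k.decoding import extract_answer, standardize_value_str, use_calculator

completion = "Natalia sold 48 + 24 = 72 clips. The answer is: 72"
print(extract_answer(completion))                  # '72'
print(standardize_value_str("1,250.50"))           # '1250.5'
print(use_calculator("She pays 3 * 4 = <<3*4="))   # 12 (evaluates the open <<...>> expression)
```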
/sft/utils/gsm8k/metrics.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from typing import Optional, List, Dict, Set, Any, Union
5 | import torch.distributed as dist
6 | from utils.gsm8k.decoding import INVALID_ANS, extract_answers, get_answer_label
7 |
8 |
9 |
10 | class GeneratorAnswerAcc:
11 | def __init__(self, n_data: int):
12 | self.n_data = n_data
13 |
14 | self.world_size = dist.get_world_size() if dist.is_initialized() else 1
15 |
16 | self.corrs = []
17 | self.gather = False
18 |
19 | @torch.inference_mode(mode=True)
20 | def __call__(self, completions: List[str], gts: List[str]):
21 | answers = extract_answers(completions)
22 |
23 | corrs = [float(get_answer_label(answer, gt) == True) for answer, gt in zip(answers, gts)]
24 |
25 | self.corrs.append(corrs)
26 |
27 | def get_metric(self, reset=True):
28 | if not self.gather:
29 | if self.world_size != 1:
30 | gathered_corrs = [None] * self.world_size
31 | for obj, container in [
32 | (self.corrs, gathered_corrs),
33 | ]:
34 | dist.all_gather_object(container, obj)
35 |
36 | flatten_corrs = []
37 | for corrs_gpus in zip(*gathered_corrs):
38 | for corrs in corrs_gpus:
39 | flatten_corrs.extend(corrs)
40 |
41 | else:
42 | flatten_corrs = [item for sublist in self.corrs for item in sublist]
43 |
44 | self.corrs = flatten_corrs[:self.n_data]
45 | self.gather = True
46 |
47 | acc = (sum(self.corrs) / len(self.corrs))
48 |
49 | if reset:
50 | self.corrs = []
51 | self.gather = False
52 | return acc
53 |
54 |
55 | class MultiSamplingAnswerAcc:
56 | def __init__(self, n_data: int = None):
57 | self.n_data = n_data
58 |
59 | self.world_size = dist.get_world_size() if dist.is_initialized() else 1
60 |
61 | self.answers = []
62 | self.gts = []
63 |
64 | def start_new_sol_epoch(self):
65 | self.cur_answers = []
66 | self.cur_gts = []
67 |
68 | def end_the_sol_epoch(self):
69 |
70 | if self.world_size != 1:
71 | gathered_answers, gathered_gts = tuple([None] * self.world_size for _ in range(2))
72 | for obj, container in [
73 | (self.cur_answers, gathered_answers),
74 | (self.cur_gts, gathered_gts),
75 | ]:
76 | dist.all_gather_object(container, obj)
77 |
78 | flatten_answers, flatten_gts = [], []
79 | for answers_gpus, gts_gpus in zip(zip(*gathered_answers), zip(*gathered_gts)):
80 | for answers, gts in zip(answers_gpus, gts_gpus):
81 | flatten_answers.extend(answers)
82 | flatten_gts.extend(gts)
83 |
84 | else:
85 | flatten_answers, flatten_gts = tuple([item for sublist in container for item in sublist]
86 | for container in [self.cur_answers, self.cur_gts])
87 |
88 | self.answers.append(flatten_answers[:self.n_data])
89 | self.gts.append(flatten_gts[:self.n_data])
90 |
91 |
92 | @torch.inference_mode(mode=True)
93 | def __call__(self, completions: List[str], gts: List[str]):
94 | answers = extract_answers(completions)
95 |
96 | answers = [float(a) if a != INVALID_ANS else float('nan') for a in answers]
97 | gts = [float(gt) for gt in gts]
98 |
99 | self.cur_answers.append(answers)
100 | self.cur_gts.append(gts)
101 |
102 |
103 | def get_metric(self, n_solution: int=3, reset=True):
104 |
105 | assert all(x == self.gts[0] for x in self.gts)
106 |
107 | # [n_question]
108 | gts = np.array(self.gts[0])
109 | # [n_question, n_solution]
110 | answers = np.stack(self.answers[:n_solution], axis=1)
111 | # print('answers:', answers.shape)
112 |
113 | pass_k = (answers == gts.reshape((-1, 1))).any(1).mean(0)
114 | acc_majority = np.mean([is_majority(a, gt, ignore=float('nan')) for a, gt in zip(answers, gts)])
115 |
116 | if reset:
117 | self.gts = []
118 | self.answers = []
119 | return pass_k, acc_majority
120 |
121 |
122 |
123 | def is_passk(answers, gt):
124 | return gt in answers
125 |
126 | def is_majority(answers, gt, ignore = INVALID_ANS):
127 |     # NaN cannot be filtered with `!=` (nan != nan), so drop NaN entries explicitly.
128 |     filter_answers = [x for x in answers if x != ignore and not (isinstance(x, float) and np.isnan(x))]
129 |     final_answer = max(filter_answers, key=filter_answers.count) if filter_answers else ignore
130 |     return final_answer == gt
131 |
132 |
133 |
--------------------------------------------------------------------------------
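A single-process sketch of `GeneratorAnswerAcc` with two toy completions, one of them correct; under `torch.distributed` the per-rank results are all-gathered inside `get_metric`:

```python
from utils.gsm8k.metrics import GeneratorAnswerAcc

metric = GeneratorAnswerAcc(n_data=2)
metric(
    ["... The answer is: 18", "... The answer is: 7"],  # model completions
    ["18", "5"],                                        # ground-truth answers
)
print(metric.get_metric())  # 0.5
```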
/sft/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | import torch.distributed as dist
5 | from typing import Optional, List, Dict, Set, Any, Union
6 | from utils.constants import IGNORE_INDEX
7 |
8 |
9 |
10 | class VerifierClassificationAcc:
11 | def __init__(self, n_data: int):
12 | self.n_data = n_data
13 |
14 | self.world_size = dist.get_world_size() if dist.is_initialized() else 1
15 |
16 | self.scores = []
17 | self.gts = []
18 | self.gather = False
19 |
20 | @torch.inference_mode(mode=True)
21 | def __call__(self, v_scores: torch.FloatTensor, v_labels: torch.LongTensor):
22 | bsz, n_seq = v_labels.shape
23 | index = ((n_seq - 1) - v_labels.ne(IGNORE_INDEX).flip(dims=[1]).float().argmax(1)).view(-1, 1)
24 |
25 | scores = v_scores.squeeze(-1).gather(1, index).squeeze()
26 | gts = v_labels.gather(1, index).squeeze()
27 |
28 | self.scores.append(scores.tolist())
29 | self.gts.append(gts.tolist())
30 |
31 | def get_metric(self, thres: float=0.5, reset=True):
32 | if not self.gather:
33 | if self.world_size != 1:
34 | gathered_scores, gathered_gts = tuple([None] * self.world_size for _ in range(2))
35 | for obj, container in [
36 | (self.scores, gathered_scores),
37 | (self.gts, gathered_gts),
38 | ]:
39 | dist.all_gather_object(container, obj)
40 |
41 | flatten_scores, flatten_gts = [], []
42 | for scores_gpus, gts_gpus in zip(zip(*gathered_scores), zip(*gathered_gts)):
43 | for scores, gts in zip(scores_gpus, gts_gpus):
44 | flatten_scores.extend(scores)
45 | flatten_gts.extend(gts)
46 |
47 | else:
48 | flatten_scores, flatten_gts = tuple([item for sublist in container for item in sublist]
49 | for container in [self.scores, self.gts])
50 |
51 | self.scores = flatten_scores[:self.n_data]
52 | self.gts = flatten_gts[:self.n_data]
53 | self.gather = True
54 |
55 |
56 | pred = (np.array(self.scores) > thres)
57 | corrs = np.where(np.array(self.gts).astype(bool), pred, ~pred)
58 | acc = (sum(corrs) / len(corrs))
59 |
60 | if reset:
61 | self.scores = []
62 | self.gts = []
63 | self.gather = False
64 | return acc
65 |
66 |
67 | class VerifierMPk:
68 | def __init__(self, n_data: int, n_solution_per_problem: int):
69 | self.n_data = n_data
70 | self.n_solution_per_problem = n_solution_per_problem
71 |
72 | self.world_size = dist.get_world_size() if dist.is_initialized() else 1
73 |
74 | self.preds = []
75 | self.gts = []
76 | self.gather = False
77 |
78 | @torch.inference_mode(mode=True)
79 | def __call__(self, v_scores: torch.FloatTensor, v_labels: torch.LongTensor):
80 | bsz, n_seq = v_labels.shape
81 | index = ((n_seq - 1) - v_labels.ne(IGNORE_INDEX).flip(dims=[1]).float().argmax(1)).view(-1, 1)
82 |
83 | preds = v_scores.squeeze(-1).gather(1, index).squeeze()
84 | gts = v_labels.gather(1, index).squeeze()
85 |
86 | self.preds.append(preds.tolist())
87 | self.gts.append(gts.tolist())
88 |
89 | def get_metric(self, k, reset=True):
90 | if not self.gather:
91 | if self.world_size != 1:
92 | gathered_preds, gathered_gts = tuple([None] * self.world_size for _ in range(2))
93 | for obj, container in [
94 | (self.preds, gathered_preds),
95 | (self.gts, gathered_gts),
96 | ]:
97 | dist.all_gather_object(container, obj)
98 |
99 | flatten_preds, flatten_gts = [], []
100 | for preds_gpus, gts_gpus in zip(zip(*gathered_preds), zip(*gathered_gts)):
101 | for preds, gts in zip(preds_gpus, gts_gpus):
102 | flatten_preds.extend(preds)
103 | flatten_gts.extend(gts)
104 |
105 | else:
106 | flatten_preds, flatten_gts = tuple([item for sublist in container for item in sublist]
107 | for container in [self.preds, self.gts])
108 |
109 | self.preds = flatten_preds[:self.n_data]
110 | self.gts = flatten_gts[:self.n_data]
111 | self.gather = True
112 |
113 | preds = np.array(self.preds).reshape(-1, self.n_solution_per_problem)
114 | gts = np.array(self.gts).reshape(-1, self.n_solution_per_problem)
115 |
116 | indices = np.argsort(-preds, axis=1)
117 | gts = np.take_along_axis(gts, indices, axis=1)
118 |
119 | # [n_question, k]
120 | gts_topk = gts[:, :k]
121 |
122 |         # fraction of the top-k highest-scored solutions that are actually correct
123 | mpk = gts_topk.mean(1).mean(0)
124 |
125 | if reset:
126 | self.preds = []
127 | self.gts = []
128 | self.gather = False
129 | return mpk
130 |
131 |
132 | class GenWithVerifierAcc:
133 | def __init__(self, n_data: int, n_solution_per_problem: int):
134 | self.n_data = n_data
135 | self.n_solution_per_problem = n_solution_per_problem
136 |
137 | self.world_size = dist.get_world_size() if dist.is_initialized() else 1
138 |
139 | self.preds = []
140 | self.gts = []
141 | self.gather = False
142 |
143 | @torch.inference_mode(mode=True)
144 | def __call__(self, v_scores: torch.FloatTensor, v_labels: torch.LongTensor):
145 | bsz, n_seq = v_labels.shape
146 | index = ((n_seq - 1) - v_labels.ne(IGNORE_INDEX).flip(dims=[1]).float().argmax(1)).view(-1, 1)
147 |
148 | preds = v_scores.squeeze(-1).gather(1, index).squeeze()
149 | gts = v_labels.gather(1, index).squeeze()
150 |
151 | self.preds.append(preds.tolist())
152 | self.gts.append(gts.tolist())
153 |
154 |
155 | def get_metric(self, k, reset=True):
156 | if not self.gather:
157 | if self.world_size != 1:
158 | gathered_preds, gathered_gts = tuple([None] * self.world_size for _ in range(2))
159 | for obj, container in [
160 | (self.preds, gathered_preds),
161 | (self.gts, gathered_gts),
162 | ]:
163 | dist.all_gather_object(container, obj)
164 |
165 | flatten_preds, flatten_gts = [], []
166 | for preds_gpus, gts_gpus in zip(zip(*gathered_preds), zip(*gathered_gts)):
167 | for preds, gts in zip(preds_gpus, gts_gpus):
168 | flatten_preds.extend(preds)
169 | flatten_gts.extend(gts)
170 |
171 | else:
172 | flatten_preds, flatten_gts = tuple([item for sublist in container for item in sublist]
173 | for container in [self.preds, self.gts])
174 |
175 | self.preds = flatten_preds[:self.n_data]
176 | self.gts = flatten_gts[:self.n_data]
177 | self.gather = True
178 |
179 |
180 | preds = np.array(self.preds).reshape(-1, self.n_solution_per_problem)
181 | gts = np.array(self.gts).reshape(-1, self.n_solution_per_problem)
182 | gts = gts[:, :k]
183 | preds = preds[:, :k]
184 |
185 | indices = np.argsort(-preds, axis=1)
186 | gts = np.take_along_axis(gts, indices, axis=1)
187 |
188 | acc = gts[:, 0].mean(0)
189 |
190 | if reset:
191 | self.preds = []
192 | self.gts = []
193 | self.gather = False
194 | return acc
195 |
196 |
--------------------------------------------------------------------------------
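A minimal sketch of `VerifierClassificationAcc`: each row is scored at its last non-ignored label position, and accuracy is computed against a 0.5 threshold (the tensors below are toy values):

```python
import torch
from utils.constants import IGNORE_INDEX
from utils.metrics import VerifierClassificationAcc

# (bsz=2, n_seq=3, 1) verifier scores and matching 0/1 labels; IGNORE_INDEX marks prompt tokens.
v_scores = torch.tensor([[[0.1], [0.2], [0.9]],
                         [[0.8], [0.3], [0.2]]])
v_labels = torch.tensor([[IGNORE_INDEX, 1, 1],
                         [IGNORE_INDEX, 0, 0]])

metric = VerifierClassificationAcc(n_data=2)
metric(v_scores, v_labels)
print(metric.get_metric(thres=0.5))  # 1.0: 0.9 > 0.5 on the positive row, 0.2 <= 0.5 on the negative row
```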
/sft/utils/models.py:
--------------------------------------------------------------------------------
1 | from utils.flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
2 | from utils.cached_models import build_transformers_mapping_to_cached_models, build_transformers_mapping_to_custom_tokenizers
3 | replace_llama_attn_with_flash_attn()
4 | build_transformers_mapping_to_cached_models()
5 | build_transformers_mapping_to_custom_tokenizers()
6 |
7 | from typing import Optional, List, Dict, Set, Any, Union, Callable, Mapping
8 | from torch import nn
9 | import torch
10 | import pathlib
11 | from dataclasses import dataclass
12 | from accelerate import Accelerator
13 | import os
14 | import re
15 | import shutil
16 | from functools import wraps
17 | import transformers
18 | from utils.constants import DEFAULT_PAD_TOKEN, DEFAULT_BOS_TOKEN, DEFAULT_EOS_TOKEN, DEFAULT_UNK_TOKEN
19 |
20 |
21 | def smart_tokenizer_and_embedding_resize_for_pad(
22 | special_tokens_dict: Dict,
23 | tokenizer: transformers.PreTrainedTokenizer,
24 | model: transformers.PreTrainedModel,
25 | ):
26 | """Resize tokenizer and embedding.
27 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
28 | """
29 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
30 | model.resize_token_embeddings(len(tokenizer))
31 | if num_new_tokens > 0:
32 | input_embeddings = model.get_input_embeddings().weight.data
33 | output_embeddings = model.get_output_embeddings().weight.data
34 |
35 | input_embeddings[-1] = torch.zeros_like(input_embeddings)[0]
36 | output_embeddings[-1] = torch.zeros_like(output_embeddings)[0]
37 |
38 |
39 | def build_model(model_args: dataclass, training_args: dataclass):
40 | # Step 1: Initialize LLM
41 | print(f"+ [Model] Initializing LM: {model_args.model_name_or_path}")
42 | model = transformers.AutoModelForCausalLM.from_pretrained(
43 | model_args.model_name_or_path,
44 | use_cache=False,
45 | trust_remote_code=True
46 | )
47 | try:
48 | if training_args.gradient_checkpointing:
49 | model.gradient_checkpointing_enable()
50 | except:
51 | pass
52 |
53 | # Step 2: Initialize tokenizer
54 | print(f"+ [Model] Initializing Tokenizer: {model_args.model_name_or_path}")
55 |
56 | tokenizer = transformers.AutoTokenizer.from_pretrained(
57 | model_args.model_name_or_path,
58 | cache_dir=training_args.cache_dir,
59 | model_max_length=training_args.model_max_length,
60 | padding_side="right",
61 | use_fast=False,
62 | )
63 |
64 | # Step 3: Add special tokens
65 | if "mistral" in model_args.model_name_or_path.lower() or "llama" in model_args.model_name_or_path.lower() or "llema" in model_args.model_name_or_path.lower():
66 | if tokenizer.pad_token is None:
67 | smart_tokenizer_and_embedding_resize_for_pad(
68 | special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
69 | tokenizer=tokenizer,
70 | model=model,
71 | )
72 | tokenizer.add_special_tokens({
73 | "eos_token": DEFAULT_EOS_TOKEN,
74 | "bos_token": DEFAULT_BOS_TOKEN,
75 | "unk_token": DEFAULT_UNK_TOKEN,
76 | })
77 |
78 | # Step 4: Align special token ids between tokenizer and model.config
79 | model.config.pad_token_id = tokenizer.pad_token_id
80 | model.config.bos_token_id = tokenizer.bos_token_id
81 | model.config.eos_token_id = tokenizer.eos_token_id
82 |
83 | else: # Deepseek_math
84 | tokenizer.pad_token_id = tokenizer.eos_token_id
85 | model.generation_config = transformers.GenerationConfig.from_pretrained(model_args.model_name_or_path)
86 | model.generation_config.pad_token_id = model.generation_config.eos_token_id
87 |
88 | return model, tokenizer
89 |
90 | def safe_delete_with_accelerator(accelerator: Accelerator, path: str):
91 | @accelerator.on_main_process
92 | def delete(path):
93 | shutil.rmtree(path, ignore_errors=True)
94 |
95 | delete(path)
96 |
97 | def safe_move_with_accelerator(accelerator: Accelerator, ori_path: str, new_path: str):
98 | @accelerator.on_main_process
99 | def move(ori_path, new_path):
100 | try:
101 | shutil.move(ori_path, new_path)
102 | except:
103 | ...
104 |
105 | move(ori_path, new_path)
106 |
107 |
108 |
109 | def wrapper_safe_save_model_with_accelerator(save_model_func):
110 | @wraps(save_model_func)
111 | def wrapper(accelerator: Accelerator,
112 | model: nn.Module,
113 | tokenizer: transformers.AutoTokenizer,
114 | output_dir: str):
115 | @accelerator.on_main_process
116 | def save_model(cpu_state_dict, output_dir):
117 | save_model_func(accelerator=accelerator, model=model, cpu_state_dict=cpu_state_dict, output_dir=output_dir)
118 | @accelerator.on_main_process
119 | def save_tokenizer(output_dir):
120 | tokenizer.save_pretrained(output_dir)
121 |
122 | os.makedirs(output_dir, exist_ok=True)
123 | state_dict = accelerator.get_state_dict(model)
124 | cpu_state_dict = {
125 | key: value.cpu()
126 | for key, value in state_dict.items()
127 | }
128 |
129 | save_model(cpu_state_dict, output_dir)
130 | save_tokenizer(output_dir)
131 |
132 | print(f"+ [Save] Save model and tokenizer to: {output_dir}")
133 | return wrapper
134 |
135 |
136 | # refer to transformers.trainer._sorted_checkpoints "https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/trainer.py#L2848"
137 | def wrapper_save_checkpoint(save_func):
138 | @wraps(save_func)
139 | def outwrapper(func):
140 | def wrapper(accelerator: Accelerator,
141 | model: transformers.AutoModelForCausalLM,
142 | tokenizer: transformers.PreTrainedTokenizer,
143 | output_dir: str,
144 | global_step: int,
145 | save_total_limit: int=None):
146 | checkpoint_output_dir = os.path.join(output_dir, f'checkpoint-{global_step}')
147 |             if os.path.exists(checkpoint_output_dir) or (save_total_limit is not None and save_total_limit < 1):
148 | return
149 | save_func(accelerator=accelerator, model=model, tokenizer=tokenizer, output_dir=checkpoint_output_dir)
150 |
151 | ordering_and_checkpoint_path = []
152 | glob_checkpoints = [str(x) for x in pathlib.Path(output_dir).glob('*checkpoint-*')]
153 | for path in glob_checkpoints:
154 | regex_match = re.match(r".*checkpoint-([0-9]+)", path)
155 | if regex_match is not None:
156 | ordering_and_checkpoint_path.append((int(regex_match.group(1)), path))
157 |
158 | checkpoints_sorted = sorted(ordering_and_checkpoint_path)
159 | checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
160 |
161 | best_checkpoint = [str(x) for x in pathlib.Path(output_dir).glob('best-checkpoint-*')]
162 | if best_checkpoint:
163 | best_checkpoint = best_checkpoint[0]
164 | best_model_index = checkpoints_sorted.index(best_checkpoint)
165 | for i in range(best_model_index, len(checkpoints_sorted) - 2):
166 | checkpoints_sorted[i], checkpoints_sorted[i + 1] = checkpoints_sorted[i + 1], checkpoints_sorted[i]
167 |
168 | if save_total_limit:
169 | checkpoints_to_be_deleted = checkpoints_sorted[:-save_total_limit]
170 | for checkpoint in checkpoints_to_be_deleted:
171 | safe_delete_with_accelerator(accelerator, checkpoint)
172 | return wrapper
173 | return outwrapper
174 |
175 |
176 | def wrapper_save_best_checkpoint(save_checkpoint_func):
177 | @wraps(save_checkpoint_func)
178 | def outwrapper(func):
179 | def wrapper(accelerator: Accelerator,
180 | model: transformers.AutoModelForCausalLM,
181 | tokenizer: transformers.PreTrainedTokenizer,
182 | output_dir: str,
183 | global_step: int,
184 | save_total_limit: int=None):
185 |
186 | ori_best_checkpoint = [str(x) for x in pathlib.Path(output_dir).glob('best-checkpoint-*')]
187 | if ori_best_checkpoint:
188 | ori_best_checkpoint = ori_best_checkpoint[0]
189 | filename = os.path.basename(os.path.normpath(ori_best_checkpoint))[5:]
190 | safe_move_with_accelerator(accelerator, ori_best_checkpoint, os.path.join(output_dir, filename))
191 |
192 | save_checkpoint_func(accelerator=accelerator, model=model, tokenizer=tokenizer, output_dir=output_dir, global_step=global_step, save_total_limit=save_total_limit)
193 | checkpoint_dir = os.path.join(output_dir, f'checkpoint-{global_step}')
194 | best_checkpoint_dir = os.path.join(output_dir, f'best-checkpoint-{global_step}')
195 | safe_move_with_accelerator(accelerator, checkpoint_dir, best_checkpoint_dir)
196 | return wrapper
197 | return outwrapper
198 |
199 |
200 |
201 |
202 |
203 | @wrapper_safe_save_model_with_accelerator
204 | def save_llm(accelerator: Accelerator,
205 | model: transformers.AutoModelForCausalLM,
206 | cpu_state_dict: Mapping,
207 | output_dir: str):
208 | accelerator.unwrap_model(model).save_pretrained(
209 | output_dir,
210 | state_dict=cpu_state_dict,
211 | is_main_process=accelerator.is_main_process,
212 | save_function=accelerator.save,
213 | )
214 |
215 |
216 | @wrapper_save_checkpoint(save_func=save_llm)
217 | def save_llm_checkpoint(accelerator: Accelerator,
218 | model: transformers.AutoModelForCausalLM,
219 | tokenizer: transformers.PreTrainedTokenizer,
220 | checkpoint_output_dir: str):
221 | ...
222 |
223 |
224 | @wrapper_save_best_checkpoint(save_checkpoint_func=save_llm_checkpoint)
225 | def save_best_llm_checkpoint(accelerator: Accelerator,
226 | model: transformers.AutoModelForCausalLM,
227 | tokenizer: transformers.PreTrainedTokenizer,
228 | output_dir: str,
229 | global_step: int,
230 | save_total_limit: int=None):
231 | ...
232 |
233 |
234 |
235 | def save_training_args_with_accelerator(accelerator: Accelerator,
236 | training_args: dataclass,
237 | output_dir: str):
238 | output_file = os.path.join(output_dir, 'training_args.bin')
239 | accelerator.save(training_args, output_file)
240 |
241 | print(f"+ [Save] Save training_args to: {output_file}")
242 |
243 |
244 |
245 |
246 |
--------------------------------------------------------------------------------
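A hedged sketch of calling `build_model`; the argument objects below are stand-ins for the training scripts' dataclasses, with field names inferred from the code above (the checkpoint path is a placeholder):

```python
from types import SimpleNamespace
from utils.models import build_model

model_args = SimpleNamespace(model_name_or_path="meta-llama/Llama-2-7b-hf")  # placeholder
training_args = SimpleNamespace(
    cache_dir=None,
    model_max_length=1024,
    gradient_checkpointing=True,
)

model, tokenizer = build_model(model_args, training_args)
print(tokenizer.pad_token, model.config.pad_token_id)
```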
/sft/utils/optim.py:
--------------------------------------------------------------------------------
1 | from transformers import AdamW
2 | from transformers import get_scheduler
3 | import transformers
4 | from typing import Optional, List, Dict, Set, Any, Union
5 | from dataclasses import dataclass
6 | import os
7 |
8 | def get_optimizers(model: transformers.AutoModelForCausalLM, training_args: dataclass) -> Dict:
9 | no_decay = ["bias", "LayerNorm.weight"]
10 | optimizer_grouped_parameters = [
11 | {
12 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
13 | "weight_decay": training_args.weight_decay,
14 | },
15 | {
16 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
17 | "weight_decay": 0.0,
18 | },
19 | ]
20 |
21 | optim = AdamW(
22 | optimizer_grouped_parameters,
23 | lr=training_args.learning_rate,
24 | # weight_decay=training_args.weight_decay
25 | )
26 | lr_scheduler = get_scheduler(
27 | training_args.lr_scheduler_type,
28 | optimizer=optim,
29 | # num_warmup_steps=training_args.num_updating_warmup_steps_aggr_devices,
30 | # num_training_steps=training_args.num_updating_steps_aggr_devices,
31 | num_warmup_steps=training_args.num_updating_warmup_steps,
32 | num_training_steps=training_args.num_updating_steps,
33 | )
34 | return optim, lr_scheduler
35 |
36 |
37 |
38 |
--------------------------------------------------------------------------------
/sft/utils/states.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import json
4 | from dataclasses import dataclass
5 | import random
6 | import math
7 | import numpy as np
8 | from accelerate import Accelerator
9 |
10 |
11 | def set_deepspeed_config(accelerator: Accelerator, training_args: dataclass):
12 | world_size = int(os.environ.get("WORLD_SIZE", 1))
13 | accelerator.state.deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = training_args.per_device_train_batch_size
14 | accelerator.state.deepspeed_plugin.deepspeed_config['train_batch_size'] = training_args.per_device_train_batch_size * world_size * accelerator.gradient_accumulation_steps
15 |
16 |
17 | def set_training_states(data_module: dict, training_args: dataclass):
18 | set_num_steps_per_epoch(data_module, training_args)
19 | set_num_training_steps(training_args)
20 | set_num_updating_steps(training_args)
21 | set_num_eval_steps(training_args)
22 | set_per_eval_steps(training_args)
23 | set_num_warmup_steps(training_args)
24 |
25 | set_num_logging_steps(training_args)
26 | set_per_save_steps(training_args)
27 |
28 | print(f"+ [Training States] There are {training_args.num_training_steps} steps in total.")
29 |
30 |
31 | def set_num_steps_per_epoch(data_module: dict, training_args: dataclass):
32 | num_devices = int(os.environ.get("WORLD_SIZE", 1))
33 |
34 | len_train_set_per_device = math.ceil(len(data_module["train_dataset"]) / num_devices)
35 | num_train_steps_per_device = math.ceil(len_train_set_per_device / training_args.per_device_train_batch_size)
36 | num_updating_steps_per_epoch = num_train_steps_per_device // training_args.gradient_accumulation_steps
37 |
38 | len_eval_set_per_device = math.ceil(len(data_module["val_dataset"]) / num_devices) if data_module["val_dataset"] is not None else None
39 | num_eval_steps_per_device = math.ceil(len_eval_set_per_device / training_args.per_device_eval_batch_size) if data_module["val_dataset"] is not None else None
40 |
41 | training_args.num_training_steps_per_epoch = num_train_steps_per_device
42 | training_args.num_updating_steps_per_epoch = num_updating_steps_per_epoch
43 | training_args.num_eval_steps_per_epoch = num_eval_steps_per_device
44 |
45 | def set_num_training_steps(training_args: dataclass):
46 | if training_args.max_steps != -1:
47 | num_training_steps = training_args.max_steps
48 | else:
49 | assert training_args.num_train_epoches != -1
50 | num_training_steps = training_args.num_training_steps_per_epoch * training_args.num_train_epoches
51 | num_training_steps_aggr_devices = num_training_steps * int(os.environ.get("WORLD_SIZE", 1))
52 |
53 | training_args.num_training_steps = num_training_steps
54 | training_args.num_training_steps_aggr_devices = num_training_steps_aggr_devices
55 |
56 | def set_num_updating_steps(training_args: dataclass):
57 | num_updating_steps = training_args.num_training_steps // training_args.gradient_accumulation_steps
58 | num_updating_steps_aggr_devices = num_updating_steps * int(os.environ.get("WORLD_SIZE", 1))
59 |
60 | training_args.num_updating_steps = num_updating_steps
61 | training_args.num_updating_steps_aggr_devices = num_updating_steps_aggr_devices
62 |
63 |
64 | def set_num_eval_steps(training_args: dataclass):
65 | training_args.num_eval_steps = training_args.num_eval_steps_per_epoch
66 |
67 | def set_per_eval_steps(training_args: dataclass):
68 | if training_args.eval_steps != -1:
69 | per_eval_steps = training_args.eval_steps
70 | else:
71 | assert training_args.eval_epoches != -1
72 | per_eval_steps = training_args.num_training_steps_per_epoch * training_args.eval_epoches
73 |
74 | training_args.per_eval_steps = per_eval_steps
75 |
76 | def set_num_warmup_steps(training_args: dataclass):
77 | # if training_args.warmup_steps != -1:
78 | # num_warmup_steps_forward = training_args.warmup_steps
79 | # else:
80 | # assert training_args.warmup_ratio != -1
81 | # num_warmup_steps_forward = int(training_args.num_training_steps * training_args.warmup_ratio)
82 | # num_updating_warmup_steps = num_warmup_steps_forward // training_args.gradient_accumulation_steps
83 | # num_updating_warmup_steps_aggr_devices = num_updating_warmup_steps * int(os.environ.get("WORLD_SIZE", 1))
84 | if training_args.warmup_steps != -1:
85 | num_updating_warmup_steps = training_args.warmup_steps
86 | else:
87 | assert training_args.warmup_ratio != -1
88 | num_updating_warmup_steps = int(training_args.num_updating_steps * training_args.warmup_ratio)
89 | num_updating_warmup_steps_aggr_devices = num_updating_warmup_steps * int(os.environ.get("WORLD_SIZE", 1))
90 |
91 | training_args.num_updating_warmup_steps = num_updating_warmup_steps
92 | training_args.num_updating_warmup_steps_aggr_devices = num_updating_warmup_steps_aggr_devices
93 |
94 | def set_num_logging_steps(training_args: dataclass):
95 | if training_args.logging_steps != -1:
96 | num_logging_steps = training_args.logging_steps
97 | else:
98 | assert training_args.logging_epoches != -1
99 | num_logging_steps = training_args.num_training_steps_per_epoch * training_args.logging_epoches
100 |
101 | training_args.num_logging_steps = num_logging_steps
102 |
103 | def set_per_save_steps(training_args: dataclass):
104 | if training_args.save_steps != -1:
105 | per_save_steps = training_args.save_steps
106 | else:
107 | assert training_args.save_epoches != -1
108 | per_save_steps = training_args.num_training_steps_per_epoch * training_args.save_epoches
109 |
110 | training_args.per_save_steps = per_save_steps
111 |
112 |
113 | def set_random_seed(seed: int):
114 | random.seed(seed)
115 | np.random.seed(seed)
116 | torch.manual_seed(seed)
117 | torch.cuda.manual_seed_all(seed)
118 | torch.cuda.manual_seed(seed)
119 |
120 |
121 |
--------------------------------------------------------------------------------
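As a concrete (hypothetical) example of the bookkeeping above: with 8,000 training examples on 4 GPUs, per-device batch size 4, and gradient accumulation 8, each epoch gives ceil(ceil(8000 / 4) / 4) = 500 forward steps per device and 500 // 8 = 62 optimizer updates; over 3 epochs, `num_training_steps` = 1,500 and `num_updating_steps` = 187, and a `warmup_ratio` of 0.03 yields int(187 * 0.03) = 5 warmup updates.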
/sft/utils/verifier_models.py:
--------------------------------------------------------------------------------
1 | from utils.models import wrapper_safe_save_model_with_accelerator, wrapper_save_checkpoint, wrapper_save_best_checkpoint, build_model, load_model
2 | from utils.sampling import shift_padding_to_left_2D, shift_padding_to_right_2D, find_rightmost_notpadded_positions
3 | from utils.constants import IGNORE_INDEX
4 |
5 |
6 | from typing import Optional, List, Dict, Set, Any, Union, Callable, Mapping
7 | import transformers
8 | from transformers.generation.utils import ModelOutput
9 | from torch import nn
10 | import torch.nn.functional as F
11 | import torch
12 | import pathlib
13 | import logging
14 | from dataclasses import dataclass, field
15 | from accelerate import Accelerator
16 | import os
17 | import re
18 | import shutil
19 |
20 |
21 | @dataclass
22 | class VerifierModelOutput(ModelOutput):
23 | loss: Optional[torch.FloatTensor] = None
24 | v_scores: torch.FloatTensor = None
25 | all_losses: Optional[Dict[str, torch.FloatTensor]] = None
26 |
27 |
28 | class Verifier(nn.Module):
29 | def __init__(self, backbone, checkpoint_dir=None):
30 | super(Verifier, self).__init__()
31 | self.backbone = backbone
32 |
33 | self.gain = nn.Parameter(torch.randn(1,))
34 | self.bias = nn.Parameter(torch.randn(1,))
35 | self.dropout = nn.Dropout(p=0.2)
36 | self.vscore_head = nn.Linear(self.backbone.get_input_embeddings().embedding_dim, 1, bias=False)
37 |
38 | if checkpoint_dir and os.path.exists(os.path.join(checkpoint_dir, 'verifier.pth')):
39 | verifier_params = torch.load(os.path.join(checkpoint_dir, 'verifier.pth'))
40 | self.load_state_dict(verifier_params, strict=False)
41 | else:
42 | self.init_head_params()
43 |
44 | self.pad_token_id = backbone.config.pad_token_id
45 |
46 | def init_head_params(self):
47 | output_embeddings = self.backbone.get_output_embeddings().weight.data
48 | output_embeddings_avg = output_embeddings.mean(dim=0, keepdim=True)
49 |
50 | self.vscore_head.weight = nn.Parameter(output_embeddings_avg)
51 |
52 | def loss_fct(self, v_scores: torch.FloatTensor, v_labels: torch.LongTensor):
53 | # (batch_size, n_seq, 1)
54 | return mse_loss_with_mask(v_scores.squeeze(), v_labels.type_as(v_scores))
55 |
56 | def transform(self, last_hidden_states):
57 | return self.gain * last_hidden_states + self.bias
58 |
59 | def forward(self,
60 | input_ids: torch.LongTensor,
61 | attention_mask: Optional[torch.Tensor] = None,
62 | position_ids: Optional[torch.LongTensor] = None,
63 | past_key_values: Optional[List[torch.FloatTensor]] = None,
64 | labels: Optional[torch.LongTensor] = None,
65 | v_labels: Optional[torch.LongTensor] = None,
66 | output_all_losses: Optional[bool] = None,
67 | ):
68 | outputs = self.backbone(
69 | input_ids=input_ids,
70 | attention_mask=attention_mask,
71 | position_ids=position_ids,
72 | past_key_values=past_key_values,
73 | labels=labels,
74 | use_cache=False,
75 | output_hidden_states=True,
76 | return_dict=True,
77 | )
78 | llm_logits = outputs.logits
79 | llm_loss = outputs.loss
80 | llm_hidden_states = outputs.hidden_states
81 |
82 | # (batch_size, n_seq, embed_dim)
83 | v_hidden_states = self.transform(llm_hidden_states[-1])
84 | # (batch_size, n_seq, 1)
85 | v_scores = self.vscore_head(self.dropout(v_hidden_states))
86 |
87 | v_loss, loss = None, None
88 | if v_labels is not None:
89 | v_loss = self.loss_fct(v_scores, v_labels)
90 | loss = v_loss + (llm_loss if labels is not None else 0)
91 |
92 | all_losses = None
93 | if output_all_losses:
94 | all_losses = {'llm_loss': llm_loss, 'v_loss': v_loss}
95 |
96 | return VerifierModelOutput(
97 | loss=loss,
98 | v_scores=v_scores,
99 | all_losses=all_losses,
100 | )
101 |
102 | @torch.inference_mode(mode=True)
103 | def scoring_sequences(self, input_ids: torch.LongTensor):
104 | input_ids = shift_padding_to_right_2D(input_ids, pad_value=self.pad_token_id)
105 | outputs = self(
106 | input_ids=input_ids,
107 | attention_mask=input_ids.ne(self.pad_token_id),
108 | )
109 | inds = find_rightmost_notpadded_positions(input_ids, pad_value=self.pad_token_id)
110 | return outputs.v_scores[:, :, -1].gather(1, inds.view(-1, 1)).squeeze(-1)
111 |
112 | def gradient_checkpointing_enable(self):
113 | self.backbone.gradient_checkpointing_enable()
114 |
115 | def gradient_checkpointing_disable(self):
116 | self.backbone.gradient_checkpointing_disable()
117 |
118 |
119 | def mse_loss_with_mask(scores: torch.FloatTensor, labels: torch.FloatTensor):
120 | scores = torch.where(labels.ne(IGNORE_INDEX), scores, 0)
121 | labels = torch.where(labels.ne(IGNORE_INDEX), labels, 0)
122 | return F.mse_loss(scores, labels, reduction='sum') / scores.shape[0]
123 |
124 |
125 |
126 | @wrapper_safe_save_model_with_accelerator
127 | def save_verifier(accelerator: Accelerator,
128 | model: transformers.AutoModelForCausalLM,
129 | cpu_state_dict: Mapping,
130 | output_dir: str):
131 | cpu_state_dict_backbone = {
132 | k.split('backbone.')[1]: v
133 | for k, v in cpu_state_dict.items() if k.startswith('backbone')
134 | }
135 | cpu_state_dict_verifier = {
136 | k: v
137 | for k, v in cpu_state_dict.items() if not k.startswith('backbone')
138 | }
139 | accelerator.unwrap_model(model).backbone.save_pretrained(
140 | output_dir,
141 | state_dict=cpu_state_dict_backbone,
142 | is_main_process=accelerator.is_main_process,
143 | save_function=accelerator.save,
144 | )
145 | accelerator.save(cpu_state_dict_verifier, os.path.join(output_dir, 'verifier.pth'))
146 |
147 |
148 | @wrapper_save_checkpoint(save_func=save_verifier)
149 | def save_verifier_checkpoint(accelerator: Accelerator,
150 | model: transformers.AutoModelForCausalLM,
151 | tokenizer: transformers.PreTrainedTokenizer,
152 | checkpoint_output_dir: str):
153 | ...
154 |
155 |
156 | @wrapper_save_best_checkpoint(save_checkpoint_func=save_verifier_checkpoint)
157 | def save_best_verifier_checkpoint(accelerator: Accelerator,
158 | model: transformers.AutoModelForCausalLM,
159 | tokenizer: transformers.PreTrainedTokenizer,
160 | output_dir: str,
161 | global_step: int,
162 | save_total_limit: int=None):
163 | ...
164 |
165 |
166 |
167 |
168 | def build_verifier(model_args: dataclass, training_args: dataclass):
169 | backbone, tokenizer = build_model(model_args, training_args)
170 | return Verifier(backbone), tokenizer
171 |
172 |
173 | def load_verifier(model_args: dataclass):
174 | backbone, tokenizer = load_model(model_args)
175 | return Verifier(backbone, checkpoint_dir=model_args.model_name_or_path), tokenizer
176 |
177 |
178 | def load_generator_and_verifier(model_args: dataclass):
179 | generator, tokenizer = load_model(model_args)
180 |
181 | v_backbone = transformers.AutoModelForCausalLM.from_pretrained(
182 | model_args.verifier_model_name_or_path,
183 | torch_dtype=torch.float16 if model_args.fp16 else torch.bfloat16,
184 | )
185 |
186 | verifier = Verifier(v_backbone, checkpoint_dir=model_args.verifier_model_name_or_path)
187 | return generator, verifier, tokenizer
188 |
189 |
190 |
191 |
--------------------------------------------------------------------------------
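A minimal sketch of scoring completions with the verifier defined above; the argument objects and paths are placeholders, and `load_verifier` would additionally expect a saved `verifier.pth` in the checkpoint directory:

```python
from types import SimpleNamespace
from utils.verifier_models import build_verifier

# Hypothetical args mirroring what build_model expects; the path is a placeholder.
model_args = SimpleNamespace(model_name_or_path="path/to/sft_generator")
training_args = SimpleNamespace(cache_dir=None, model_max_length=1024,
                                gradient_checkpointing=False)

verifier, tokenizer = build_verifier(model_args, training_args)
batch = tokenizer(["Q: 2 + 2? A: The answer is: 4"], return_tensors="pt", padding=True)
scores = verifier.scoring_sequences(batch.input_ids)  # one verifier score per sequence
print(scores)
```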