├── .gitignore ├── LICENSE ├── README.md ├── arguments.py ├── checkpoint_utils.py ├── data_utils ├── data_utils_dpo.py ├── data_utils_ppo.py ├── data_utils_rm_pairwise.py ├── data_utils_rm_pointwise.py └── data_utils_sft.py ├── examples └── README.md ├── hf_argparser.py ├── inference_generate.py ├── inference_reward.py ├── interactive.py ├── models ├── frozen_layers.py ├── model.py ├── quantize.py ├── reward_model.py ├── rl_model.py ├── tokenizer_utils.py └── tp.py ├── scripts ├── convert_checkpoint_to_hf.py ├── convert_hf_checkpoint.py ├── download.py ├── prepare_ds_math_7b.sh ├── prepare_llemma_34b.sh └── prepare_llemma_7b.sh ├── train_rl_dpo.py ├── train_rl_ppo.py ├── train_rm_pairwise.py ├── train_rm_pointwise.py ├── train_sft.py ├── trainers ├── common_utils.py ├── ppo_trainer.py └── rl_trainer.py └── training_utils ├── fsdp_utils.py ├── memory_efficient_adam.py └── trainer_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2024, Zhiqing Sun 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # gpt-accelera 2 | 3 | Simple and efficient pytorch-native transformer training and inference (batched). 
4 | 5 | `gpt-accelera` is a codebase based on [`gpt-fast`](https://github.com/pytorch-labs/gpt-fast/tree/main) -- the state-of-the-art pytorch-native tensor-parallel implementation of transformer text generation that minimizes latency (i.e., batch size = 1) -- with the following improvements. 6 | 7 | New features: 8 | 9 | - Batched (i.e., batch size > 1) inference with a compiled graph (i.e., torch.compile) 10 | - 2-D parallelism (Tensor Parallel (TP) + Fully Sharded Data Parallel (FSDP)) training with mixed precision (i.e., torch.cuda.amp) 11 | - Support for both LLaMA and DeepSeek models 12 | - Support for training policy models with Supervised Fine-Tuning (SFT) 13 | - Support for training reward models (RM) with pointwise and pairwise losses 14 | - Support for on-policy (PPO) and off-policy (DPO) reinforcement learning (RL) training 15 | - All training modes support full fine-tuning of `7b-34b` LLaMA/Llemma models 16 | 17 | Shared features w/ `gpt-fast`: 18 | 19 | - Very low latency (for inference, batched inference, SFT, and PPO) 20 | - No dependencies other than PyTorch and sentencepiece 21 | - int8/int4 quantization (for inference and ref_policy / reward_model in PPO) 22 | - Supports NVIDIA and AMD GPUs (TODO: test the codebase on AMD) 23 | 24 | Following the spirit of `gpt-fast`, this repository is NOT intended to be a "framework" or "library", but to show off what kind of performance you can get with native PyTorch. Please copy-paste and fork as you desire. 25 | 26 | ## Installation 27 | 28 | Install `torch==2.2.0`, `sentencepiece`, and `huggingface_hub`: 29 | 30 | ```bash 31 | pip install torch==2.2.0 sentencepiece huggingface_hub 32 | ``` 33 | 34 | ## Downloading Weights 35 | 36 | Models tested/supported: 37 | 38 | ``` 39 | meta-llama/Llama-2-7b-chat-hf 40 | meta-llama/Llama-2-13b-chat-hf 41 | meta-llama/Llama-2-70b-chat-hf 42 | codellama/CodeLlama-7b-Python-hf 43 | codellama/CodeLlama-34b-Python-hf 44 | EleutherAI/llemma_7b 45 | EleutherAI/llemma_34b 46 | deepseek-ai/deepseek-llm-7b-base 47 | deepseek-ai/deepseek-coder-6.7b-base 48 | deepseek-ai/deepseek-math-7b-base 49 | ``` 50 | 51 | ## Benchmarks 52 | 53 | TODO: Add benchmarks 54 | 55 | ## Running Reference Methods 56 | 57 | TODO: Add reference methods 58 | 59 | ## License 60 | 61 | Following `gpt-fast`, `gpt-accelera` is licensed under the BSD 3-Clause license. See the LICENSE file for details. 62 | 63 | ### Community 64 | 65 | The `gpt-accelera` codebase was developed during the research and development of the [Easy-to-Hard Generalization](https://github.com/Edward-Sun/easy-to-hard/tree/main) project. 66 | 67 | ### Citation 68 | 69 | Please consider citing our work if you use the data or code in this repo. 70 | 71 | ``` 72 | @misc{gpt_accelera, 73 | author = {Zhiqing Sun}, 74 | title = {GPT-Accelera: Simple and efficient pytorch-native transformer training and inference (batched)}, 75 | year = {2024}, 76 | publisher = {GitHub}, 77 | journal = {GitHub repository}, 78 | howpublished = {\url{https://github.com/Edward-Sun/gpt-accelera}} 79 | } 80 | ``` 81 | 82 | ### Acknowledgements 83 | 84 | We thank the authors of the following works for their open-source efforts in democratizing large language models. 85 | 86 | - The compiled generation part of `gpt-accelera` is adopted from [`gpt-fast`](https://github.com/pytorch-labs/gpt-fast/tree/main) 87 | - The RL part of `gpt-accelera` is adopted from [`SALMON`](https://github.com/IBM/SALMON), which in turn builds on [`alpaca_farm`](https://github.com/tatsu-lab/alpaca_farm).
88 | - The tokenization part of `gpt-accelera` is adopted from [`transformers`](https://github.com/huggingface/transformers/tree/main) 89 | -------------------------------------------------------------------------------- /data_utils/data_utils_dpo.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The GPT-Accelera Team 2 | # Copyright 2023 The Alpaca Team 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from typing import Dict, Sequence, Union 17 | 18 | import numpy as np 19 | import torch 20 | from torch.utils.data import Dataset 21 | from datasets import Dataset as HFDataset 22 | 23 | from arguments import Arguments 24 | import trainers.common_utils as utils 25 | from models.tokenizer_utils import AcceleraTokenizer 26 | from data_utils.data_utils_sft import preprocess_for_sft, extract_alpaca_dataset 27 | 28 | 29 | class DPODataset(Dataset): 30 | def __init__( 31 | self, 32 | args: Arguments, 33 | dataset: HFDataset, 34 | tokenizer: AcceleraTokenizer, 35 | ): 36 | super(DPODataset, self).__init__() 37 | self.tensors = preprocess_for_dpo( 38 | args=args, 39 | dataset=dataset, 40 | tokenizer=tokenizer, 41 | ) 42 | 43 | def __len__(self): 44 | return len(next(iter(self.tensors.values()))) 45 | 46 | def __getitem__(self, i) -> Dict[str, torch.Tensor]: 47 | return {key: value[i] for key, value in self.tensors.items()} 48 | 49 | 50 | def preprocess_for_dpo( 51 | args: Arguments, 52 | dataset: HFDataset, 53 | tokenizer: AcceleraTokenizer, 54 | reorder_wl: bool = True, 55 | ) -> dict[str, Union[torch.Tensor, Sequence[torch.Tensor]]]: 56 | df = dataset.to_pandas() 57 | output_1, output_2, preference = df["output_1"], df["output_2"], df["preference"] 58 | 59 | assign_w_kwargs = dict( 60 | output=np.where(preference == 1, output_1, output_2), 61 | ) 62 | assign_l_kwargs = dict( 63 | output=np.where(preference == 2, output_1, output_2), 64 | ) 65 | assign_keys = ["instruction", "input", "output"] 66 | 67 | if "is_eos_1" in df.columns: 68 | is_eos_1, is_eos_2 = df["is_eos_1"], df["is_eos_2"] 69 | assign_w_kwargs.update( 70 | is_eos=np.where(preference == 1, is_eos_1, is_eos_2), 71 | ) 72 | assign_l_kwargs.update( 73 | is_eos=np.where(preference == 2, is_eos_1, is_eos_2), 74 | ) 75 | assign_keys.extend(["is_eos"]) 76 | 77 | if "win_rate_1" in df.columns: 78 | win_rate_1, win_rate_2 = df["win_rate_1"], df["win_rate_2"] 79 | assign_w_kwargs.update( 80 | win_rate=np.where(preference == 1, win_rate_1, win_rate_2), 81 | ) 82 | assign_l_kwargs.update( 83 | win_rate=np.where(preference == 2, win_rate_1, win_rate_2), 84 | ) 85 | assign_keys.extend(["win_rate"]) 86 | 87 | if reorder_wl: 88 | df_w = df.assign(**assign_w_kwargs)[assign_keys] 89 | df_l = df.assign(**assign_l_kwargs)[assign_keys] 90 | else: 91 | df_w = df.assign(output=output_1)[assign_w_kwargs] 92 | df_l = df.assign(output=output_2)[assign_l_kwargs] 93 | 94 | df_w_list = df_w.to_dict("records") 95 | df_l_list = df_l.to_dict("records") 96 | 97 | 
assert len(df_w_list) == len(df_l_list) 98 | 99 | if args.dataset_format == "alpaca": 100 | for i in range(len(df_w_list)): 101 | df_w_list[i].update(extract_alpaca_dataset(df_w_list[i])) 102 | df_l_list[i].update(extract_alpaca_dataset(df_l_list[i])) 103 | elif args.dataset_format is None: 104 | pass 105 | else: 106 | raise ValueError(f"Unknown dataset format: {args.dataset_format}") 107 | 108 | tensors_w = preprocess_for_sft( 109 | instances=df_w_list, 110 | tokenizer=tokenizer, 111 | source_max_len=args.source_max_len, 112 | target_max_len=args.target_max_len, 113 | total_max_len=args.total_max_len, 114 | train_on_source=args.train_on_source, 115 | add_eos_to_target=args.add_eos_to_target, 116 | add_eos_to_marked_target=args.add_eos_to_marked_target, 117 | return_win_rate=True, 118 | ) 119 | tensors_l = preprocess_for_sft( 120 | instances=df_l_list, 121 | tokenizer=tokenizer, 122 | source_max_len=args.source_max_len, 123 | target_max_len=args.target_max_len, 124 | total_max_len=args.total_max_len, 125 | train_on_source=args.train_on_source, 126 | add_eos_to_target=args.add_eos_to_target, 127 | add_eos_to_marked_target=args.add_eos_to_marked_target, 128 | return_win_rate=True, 129 | ) 130 | return dict( 131 | input_ids_w=tensors_w["input_ids"], 132 | labels_w=tensors_w["labels"], 133 | win_rate_w=tensors_w["win_rate"], 134 | input_ids_l=tensors_l["input_ids"], 135 | labels_l=tensors_l["labels"], 136 | win_rate_l=tensors_l["win_rate"], 137 | ) 138 | 139 | 140 | def make_dpo_data_module( 141 | tokenizer: AcceleraTokenizer, 142 | args: Arguments, 143 | ) -> dict: 144 | preference_dataset = utils.local_dataset(args.dataset) 145 | train_preference = preference_dataset["train"] 146 | 147 | train_dataset = DPODataset( 148 | args=args, 149 | dataset=train_preference, 150 | tokenizer=tokenizer, 151 | ) 152 | 153 | eval_dataset = None 154 | if args.eval_size > 0: 155 | train_dataset, eval_dataset = utils.split_train_into_train_and_eval( 156 | train_dataset=train_dataset, 157 | eval_size=args.eval_size, 158 | seed=args.seed, 159 | ) 160 | data_collator = utils.DataCollatorForStackableDataset() 161 | return dict( 162 | train_dataset=train_dataset, 163 | eval_dataset=eval_dataset, 164 | data_collator=data_collator, 165 | ) 166 | -------------------------------------------------------------------------------- /data_utils/data_utils_ppo.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The GPT-Accelera Team 2 | # Copyright 2023 The Alpaca Team 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
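# Data utilities for PPO training. QueryDataset tokenizes each `input` prompt,
# drops prompts longer than `query_len`, and left-pads the rest with the
# tokenizer's pad id; the `answer`/`gt_answer`/`level` fields are joined with
# " ;;; ", tokenized, and kept alongside the queries for evaluation. Each item
# is a dict with keys `queries`, `query_attn_masks`, and `answers`.
# make_rl_data_module builds train/eval/test QueryDatasets from a local dataset
# with columns {input, answer, gt_answer, level} and pairs them with the
# stackable-dataset collator.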
15 | 16 | import os 17 | from typing import Dict 18 | import logging 19 | 20 | import torch 21 | from torch.utils.data import Dataset 22 | from datasets import Dataset as HFDataset 23 | 24 | from arguments import Arguments 25 | import trainers.common_utils as utils 26 | from models.tokenizer_utils import AcceleraTokenizer 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | 31 | class QueryDataset(Dataset): 32 | """Dataset that emits tokenized left-padded queries.""" 33 | 34 | def __init__( 35 | self, 36 | dataset: HFDataset, 37 | tokenizer: AcceleraTokenizer, 38 | query_len: int, 39 | ): 40 | super(QueryDataset, self).__init__() 41 | 42 | list_dict_data = dataset.to_pandas().to_dict("records") 43 | 44 | # prompts are strings; queries are tensors. 45 | queries = [dict_data["input"] for dict_data in list_dict_data] 46 | answers = [ 47 | f"{dict_data['answer']} ;;; {dict_data['gt_answer']} ;;; {dict_data['level']}" 48 | for dict_data in list_dict_data 49 | ] 50 | 51 | logger.warning(f"Debugging: {answers[:10]}") 52 | queries = [ 53 | tokenizer(query, return_tensors="pt", truncation=False).input_ids.squeeze( 54 | dim=0 55 | ) 56 | for query in queries 57 | ] 58 | 59 | answers = [ 60 | tokenizer(answer, return_tensors="pt", truncation=False).input_ids.squeeze( 61 | dim=0 62 | ) 63 | for answer in answers 64 | ] 65 | 66 | filtered_queries = [] 67 | filtered_answers = [] 68 | 69 | for query, answer in zip(queries, answers): 70 | if len(query) <= query_len: 71 | filtered_queries.append(query) 72 | filtered_answers.append(answer) 73 | 74 | logger.warning( 75 | f"Filtered out {len(queries) - len(filtered_queries)} instances out of {len(queries)} that " 76 | f"exceed length limit. These examples are not used for training, but will still be used in evaluation. " 77 | ) 78 | 79 | queries = torch.stack( 80 | [ 81 | utils.left_pad(query, target_size=(query_len,), value=tokenizer.pad_id) 82 | for query in filtered_queries 83 | ] 84 | ) 85 | 86 | max_answer_len = max([len(answer) for answer in filtered_answers]) 87 | answers = torch.stack( 88 | [ 89 | utils.left_pad( 90 | answer, 91 | target_size=(max_answer_len,), 92 | value=tokenizer.pad_id, 93 | ) 94 | for answer in filtered_answers 95 | ] 96 | ) 97 | 98 | assert queries.shape[0] == answers.shape[0] 99 | 100 | self.queries = queries 101 | self.query_attn_masks = queries.ne(tokenizer.pad_id).long() 102 | self.answers = answers 103 | # Auxiliary data. 104 | self.list_dict_data = list_dict_data 105 | 106 | def __getitem__(self, i): 107 | return dict( 108 | queries=self.queries[i], 109 | query_attn_masks=self.query_attn_masks[i], 110 | answers=self.answers[i], 111 | ) 112 | 113 | def __len__(self): 114 | return len(self.queries) 115 | 116 | 117 | def make_rl_data_module( 118 | tokenizer: AcceleraTokenizer, 119 | args: Arguments, 120 | ) -> Dict: 121 | """ 122 | Make dataset and collator for supervised fine-tuning. 123 | Datasets are expected to have the following columns: { `input`, `output` } 124 | """ 125 | 126 | def load_data(dataset_name): 127 | if os.path.exists(dataset_name): 128 | try: 129 | full_dataset = utils.local_dataset(dataset_name) 130 | return full_dataset 131 | except: 132 | raise ValueError(f"Error loading dataset from {dataset_name}") 133 | else: 134 | raise NotImplementedError(f"Dataset {dataset_name} not implemented yet.") 135 | 136 | def format_dataset(dataset): 137 | # Remove unused columns. 
138 | dataset = dataset.remove_columns( 139 | [ 140 | col 141 | for col in dataset.column_names["train"] 142 | if col not in ["input", "answer", "gt_answer", "level"] 143 | ] 144 | ) 145 | return dataset 146 | 147 | # Load dataset. 148 | dataset = load_data(args.dataset) 149 | dataset = format_dataset(dataset) 150 | 151 | # Split train/eval, reduce size 152 | eval_dataset = None 153 | if args.do_eval: 154 | if args.eval_dataset is not None: 155 | eval_dataset = load_data(args.eval_dataset) 156 | eval_dataset = format_dataset(eval_dataset) 157 | eval_dataset = eval_dataset["train"] 158 | else: 159 | print( 160 | "Splitting train dataset in train and validation according to `eval_dataset_size`" 161 | ) 162 | dataset = dataset["train"].train_test_split( 163 | test_size=args.eval_dataset_size, shuffle=True, seed=42 164 | ) 165 | eval_dataset = dataset["test"] 166 | if ( 167 | args.max_eval_samples is not None 168 | and len(eval_dataset) > args.max_eval_samples 169 | ): 170 | eval_dataset = eval_dataset.select(range(args.max_eval_samples)) 171 | 172 | test_dataset = None 173 | if args.do_test: 174 | if args.test_dataset is not None: 175 | test_dataset = load_data(args.test_dataset) 176 | test_dataset = format_dataset(test_dataset) 177 | test_dataset = test_dataset["train"] 178 | else: 179 | raise NotImplementedError("Must specify test dataset if `do_test` is True.") 180 | 181 | train_dataset = dataset["train"] 182 | if ( 183 | args.max_train_samples is not None 184 | and len(train_dataset) > args.max_train_samples 185 | ): 186 | train_dataset = train_dataset.select(range(args.max_train_samples)) 187 | 188 | train_dataset = QueryDataset( 189 | dataset=train_dataset, 190 | tokenizer=tokenizer, 191 | query_len=args.source_max_len, 192 | ) 193 | 194 | if eval_dataset is not None: 195 | eval_dataset = QueryDataset( 196 | dataset=eval_dataset, 197 | tokenizer=tokenizer, 198 | query_len=args.source_max_len, 199 | ) 200 | 201 | if test_dataset is not None: 202 | test_dataset = QueryDataset( 203 | dataset=test_dataset, 204 | tokenizer=tokenizer, 205 | query_len=args.source_max_len, 206 | ) 207 | 208 | data_collator = utils.DataCollatorForStackableDataset() 209 | return dict( 210 | train_dataset=train_dataset, 211 | eval_dataset=eval_dataset, 212 | test_dataset=test_dataset, 213 | data_collator=data_collator, 214 | ) 215 | -------------------------------------------------------------------------------- /data_utils/data_utils_rm_pairwise.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The GPT-Accelera Team 2 | # Copyright 2023 The Alpaca Team 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
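# Data utilities for pairwise reward modeling. Each example is rendered with an
# optional prompt template (ALPACA_PROMPT_DICT when dataset_format is "alpaca",
# otherwise the raw `input`), both candidate outputs are tokenized to the same
# maximum length, and the pair is stacked into an input_ids tensor of shape
# (batch, 2, seq_len); the stored 1/2 preference becomes a 0/1 `choice` label.
# Each dataset item is a dict with `input_ids` of shape (2, seq_len) and its
# `choice`.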
15 | 16 | import logging 17 | from typing import Optional, Dict, Sequence 18 | 19 | import torch 20 | from torch.utils.data import Dataset 21 | 22 | from datasets import Dataset as HFDataset 23 | 24 | from arguments import Arguments 25 | import trainers.common_utils as utils 26 | from models.tokenizer_utils import AcceleraTokenizer 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | 31 | DROMEDARY_PROMPT_DICT = { 32 | "prompt_input": ( 33 | "{meta_prompt}\n" "{instruction}\n\n" "{input}\n\n" "### Dromedary" 34 | ), 35 | "prompt_no_input": ("{meta_prompt}\n" "{instruction}\n\n" "### Dromedary"), 36 | } 37 | 38 | 39 | ALPACA_PROMPT_DICT = { 40 | "prompt_input": ( 41 | "Below is an instruction that describes a task, paired with an input that provides further context. " 42 | "Write a response that appropriately completes the request.\n\n" 43 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" 44 | ), 45 | "prompt_no_input": ( 46 | "Below is an instruction that describes a task. " 47 | "Write a response that appropriately completes the request.\n\n" 48 | "### Instruction:\n{instruction}\n\n### Response:\n" 49 | ), 50 | } 51 | 52 | 53 | def format_prompt( 54 | example: Dict[str, str], 55 | prompt_dict: Dict[str, str], 56 | ) -> str: 57 | if prompt_dict is not None: 58 | assert ( 59 | "instruction" in example 60 | ), "Internal error: example missing required keys." 61 | 62 | if example.get("input", "") != "": 63 | prompt_format = prompt_dict["prompt_input"] 64 | else: 65 | prompt_format = prompt_dict["prompt_no_input"] 66 | else: 67 | prompt_format = "{input}" 68 | 69 | format_prompt = prompt_format.format(**example) 70 | return format_prompt 71 | 72 | 73 | def format_output( 74 | example: dict, 75 | output_key="output", 76 | ) -> str: 77 | return example[output_key] 78 | 79 | 80 | def _tokenize_fn( 81 | strings: Sequence[str], 82 | tokenizer: AcceleraTokenizer, 83 | max_length: int, 84 | end_sequence_with_eos: bool, 85 | use_data_frame: bool = False, 86 | ) -> dict: 87 | """Tokenize a list of strings.""" 88 | if use_data_frame: 89 | raise NotImplementedError 90 | strings_ds = strings 91 | 92 | tokenized_strings = tokenizer( 93 | strings_ds, 94 | max_length=max_length, 95 | padding="max_length", 96 | truncation=True, 97 | add_bos=True, 98 | add_eos=True if end_sequence_with_eos else False, 99 | padding_side="right", 100 | truncation_side="right", 101 | ) 102 | 103 | input_ids = torch.stack( 104 | [torch.tensor(tokenized) for tokenized in tokenized_strings["input_ids"]], 105 | dim=0, 106 | ) 107 | 108 | return input_ids 109 | 110 | 111 | def preprocess_for_reward_modeling( 112 | data: HFDataset, 113 | tokenizer: AcceleraTokenizer, 114 | end_sequence_with_eos: bool = False, 115 | max_length: Optional[int] = None, 116 | query_len: Optional[int] = None, 117 | response_len: Optional[int] = None, 118 | prompt_dict: Optional[Dict[str, str]] = None, 119 | ) -> Dict[str, torch.Tensor]: 120 | list_dict_data = data.to_pandas().to_dict("records") 121 | 122 | def _get_numeric_preference(example: dict): 123 | # 1 vs 2 is stored in table, but for modeling we use 0 vs 1; remap here. 
124 | return {1: 0, 2: 1}[example["preference"]] 125 | 126 | choice = torch.tensor( 127 | [[_get_numeric_preference(dict_data)] for dict_data in list_dict_data] 128 | ) 129 | 130 | def _get_text(example: dict, output_key: str): 131 | full_prompt = format_prompt(example, prompt_dict) + format_output( 132 | example, output_key 133 | ) 134 | return full_prompt 135 | 136 | text_list_0, text_list_1 = tuple( 137 | [_get_text(dict_data, key) for dict_data in list_dict_data] 138 | for key in ("output_1", "output_2") 139 | ) 140 | 141 | if max_length is None: 142 | max_length = query_len + response_len 143 | 144 | logger.warning(f"Tokenizing {len(list_dict_data)} pairs...") 145 | tokenized_0, tokenized_1 = tuple( 146 | _tokenize_fn(text_list, tokenizer, max_length, end_sequence_with_eos) 147 | for text_list in (text_list_0, text_list_1) 148 | ) 149 | # "size" (bsz, 2, seq_len) 150 | input_ids = torch.stack( 151 | [tokenized_0, tokenized_1], 152 | dim=1, 153 | ) 154 | 155 | packaged_data = dict( 156 | input_ids=input_ids, 157 | choice=choice, 158 | metadata=dict(mean_choice=choice.float().mean().item()), 159 | ) 160 | 161 | return packaged_data 162 | 163 | 164 | class PairwiseRewardModelingDataset(Dataset): 165 | def __init__( 166 | self, 167 | data: HFDataset, 168 | tokenizer: AcceleraTokenizer, 169 | end_sequence_with_eos: bool = False, 170 | max_length: Optional[int] = None, 171 | query_len: Optional[int] = None, 172 | response_len: Optional[int] = None, 173 | prompt_dict: Optional[Dict[str, str]] = None, 174 | ): 175 | super(PairwiseRewardModelingDataset, self).__init__() 176 | data_dict = preprocess_for_reward_modeling( 177 | data=data, 178 | tokenizer=tokenizer, 179 | end_sequence_with_eos=end_sequence_with_eos, 180 | max_length=max_length, 181 | query_len=query_len, 182 | response_len=response_len, 183 | prompt_dict=prompt_dict, 184 | ) 185 | self.input_ids = data_dict["input_ids"] 186 | self.choice = data_dict["choice"] 187 | self.metadata = data_dict["metadata"] 188 | 189 | def __len__(self): 190 | return len(self.input_ids) 191 | 192 | def __getitem__(self, i) -> Dict[str, torch.Tensor]: 193 | return dict( 194 | input_ids=self.input_ids[i], 195 | choice=self.choice[i], 196 | ) 197 | 198 | 199 | def make_pairwise_reward_modeling_data_module( 200 | tokenizer: AcceleraTokenizer, 201 | args: Arguments, 202 | ): 203 | preference_dataset = utils.local_dataset(args.dataset) 204 | train_preference = preference_dataset["train"] 205 | 206 | if args.dataset_format == "alpaca": 207 | prompt_dict = ALPACA_PROMPT_DICT 208 | elif args.dataset_format is None: 209 | prompt_dict = None 210 | else: 211 | raise ValueError( 212 | f"Unsupported dataset_format: {args.dataset_format}." 213 | "Only alpaca and None are supported." 
214 | ) 215 | 216 | train_dataset = PairwiseRewardModelingDataset( 217 | data=train_preference, 218 | tokenizer=tokenizer, 219 | end_sequence_with_eos=args.add_eos_to_target, 220 | max_length=args.total_max_len, 221 | query_len=args.source_max_len, 222 | response_len=args.target_max_len, 223 | prompt_dict=prompt_dict, 224 | ) 225 | 226 | eval_dataset = None 227 | if args.eval_size > 0: 228 | train_dataset, eval_dataset = utils.split_train_into_train_and_eval( 229 | train_dataset=train_dataset, 230 | eval_size=args.eval_size, 231 | seed=args.seed, 232 | ) 233 | 234 | data_collator = utils.DataCollatorForStackableDataset() 235 | return dict( 236 | train_dataset=train_dataset, 237 | eval_dataset=eval_dataset, 238 | data_collator=data_collator, 239 | ) 240 | -------------------------------------------------------------------------------- /data_utils/data_utils_rm_pointwise.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The GPT-Accelera Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | from dataclasses import dataclass 17 | from typing import Dict, Sequence 18 | 19 | import torch 20 | from torch.nn.utils.rnn import pad_sequence 21 | 22 | from arguments import Arguments 23 | import trainers.common_utils as utils 24 | from models.tokenizer_utils import AcceleraTokenizer 25 | 26 | SPLITTER = " ;;; " 27 | 28 | 29 | @dataclass 30 | class DataCollatorForPointwiseRewardModeling(object): 31 | tokenizer: AcceleraTokenizer 32 | source_max_len: int 33 | target_max_len: int 34 | total_max_len: int 35 | train_on_every_token: bool 36 | 37 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 38 | # Extract elements 39 | sources = [example["input"] for example in instances] 40 | targets = [f"\n{example['output']}" for example in instances] 41 | labels = [example["label"] for example in instances] 42 | 43 | begin_padding_len = self.tokenizer( 44 | ["\n"], return_tensors="pt", add_bos=False, add_eos=False 45 | ).input_ids.shape[1] 46 | 47 | # Tokenize 48 | tokenized_sources_with_prompt = self.tokenizer( 49 | sources, 50 | max_length=self.source_max_len, 51 | padding="max_length", 52 | truncation=True, 53 | add_bos=True, 54 | add_eos=False, 55 | padding_side="left", 56 | truncation_side="left", 57 | ) 58 | 59 | tokenized_targets = self.tokenizer( 60 | targets, 61 | max_length=self.target_max_len + begin_padding_len, 62 | padding="max_length", 63 | truncation=True, 64 | add_bos=False, 65 | add_eos=False, 66 | padding_side="right", 67 | truncation_side="right", 68 | ) 69 | # Build the input and labels for causal LM 70 | input_ids = [] 71 | weights = [] 72 | for ( 73 | source_length, 74 | target_length, 75 | tokenized_source, 76 | tokenized_target, 77 | ) in zip( 78 | tokenized_sources_with_prompt["length"], 79 | tokenized_targets["length"], 80 | tokenized_sources_with_prompt["input_ids"], 81 | tokenized_targets["input_ids"], 82 | ): 83 | real_target_length = target_length 
- begin_padding_len 84 | tokenized_target = tokenized_target[begin_padding_len:] 85 | full_seq = tokenized_source + tokenized_target 86 | 87 | # move the beginning padding to the end of the full_seq 88 | num_begin_padding = len(tokenized_source) - source_length 89 | full_seq = full_seq[num_begin_padding:] + full_seq[:num_begin_padding] 90 | 91 | if self.total_max_len is not None: 92 | full_seq = full_seq[: self.total_max_len] 93 | 94 | weight = ( 95 | [0 for _ in range(source_length)] 96 | + [1 for _ in range(real_target_length)] 97 | + [0 for _ in range(len(tokenized_target) - real_target_length)] 98 | + [0 for _ in range(num_begin_padding)] 99 | ) 100 | 101 | if not self.train_on_every_token: 102 | # we only train on the last three tokens of the target 103 | if real_target_length > 3: 104 | weight = ( 105 | [0 for _ in range(source_length)] 106 | + [0 for _ in range(real_target_length - 3)] 107 | + [1 for _ in range(3)] 108 | + [0 for _ in range(len(tokenized_target) - real_target_length)] 109 | + [0 for _ in range(num_begin_padding)] 110 | ) 111 | 112 | if self.total_max_len is not None: 113 | weight = weight[: self.total_max_len] 114 | 115 | input_ids.append(torch.tensor(full_seq)) 116 | weights.append(torch.tensor(weight)) 117 | 118 | # Apply padding 119 | input_ids = pad_sequence( 120 | input_ids, batch_first=True, padding_value=self.tokenizer.pad_id 121 | ) 122 | weights = pad_sequence(weights, batch_first=True, padding_value=0) 123 | weights = weights.float() 124 | labels = ( 125 | torch.tensor(labels).view(-1, 1).repeat(1, input_ids.shape[1]).contiguous() 126 | ) 127 | data_dict = { 128 | "input_ids": input_ids, 129 | "attention_mask": input_ids.ne(self.tokenizer.pad_id), 130 | "weights": weights, 131 | "labels": labels, 132 | } 133 | return data_dict 134 | 135 | 136 | @dataclass 137 | class DataCollatorForPointwiseRewardModelingV2(object): 138 | tokenizer: AcceleraTokenizer 139 | source_max_len: int 140 | target_max_len: int 141 | total_max_len: int 142 | train_on_every_token: bool 143 | 144 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 145 | # Extract elements 146 | sources = [example["input"] for example in instances] 147 | batch_size = len(sources) 148 | targets = [] 149 | target_batch_sizes = [] 150 | for example in instances: 151 | target_steps = [ 152 | f"\n{output}" for output in example["output"].split(SPLITTER) 153 | ] 154 | targets.extend(target_steps) 155 | target_batch_sizes.append(len(target_steps)) 156 | step_labels = [ 157 | [int(_) for _ in example["label"].split(SPLITTER)] for example in instances 158 | ] 159 | 160 | begin_padding_len = self.tokenizer( 161 | ["\n"], return_tensors="pt", add_bos=False, add_eos=False 162 | ).input_ids.shape[1] 163 | 164 | # Tokenize 165 | tokenized_sources_with_prompt = self.tokenizer( 166 | sources, 167 | max_length=self.source_max_len, 168 | padding="max_length", 169 | truncation=True, 170 | add_bos=True, 171 | add_eos=False, 172 | padding_side="left", 173 | truncation_side="left", 174 | ) 175 | 176 | tokenized_targets = self.tokenizer( 177 | targets, 178 | max_length=self.target_max_len + begin_padding_len, 179 | padding="max_length", 180 | truncation=True, 181 | add_bos=False, 182 | add_eos=False, 183 | padding_side="right", 184 | truncation_side="right", 185 | ) 186 | # Build the input and labels for causal LM 187 | input_ids = [] 188 | weights = [] 189 | labels = [] 190 | 191 | batched_tokenized_targets = {} 192 | batched_tokenized_targets["input_ids"] = [] 193 | 
batched_tokenized_targets["length"] = [] 194 | start_idx = 0 195 | for i in range(0, batch_size): 196 | end_idx = start_idx + target_batch_sizes[i] 197 | batched_tokenized_targets["input_ids"].append( 198 | tokenized_targets["input_ids"][start_idx:end_idx] 199 | ) 200 | batched_tokenized_targets["length"].append( 201 | tokenized_targets["length"][start_idx:end_idx] 202 | ) 203 | start_idx = end_idx 204 | 205 | assert len(batched_tokenized_targets["input_ids"]) == len( 206 | tokenized_sources_with_prompt["input_ids"] 207 | ), f"{len(batched_tokenized_targets['input_ids'])} != {len(tokenized_sources_with_prompt['input_ids'])}" 208 | assert len(batched_tokenized_targets["length"]) == len( 209 | tokenized_sources_with_prompt["length"] 210 | ), f"{len(batched_tokenized_targets['length'])} != {len(tokenized_sources_with_prompt['length'])}" 211 | 212 | for ( 213 | source_length, 214 | batched_target_length, 215 | tokenized_source, 216 | batched_tokenized_target, 217 | batched_step_label, 218 | ) in zip( 219 | tokenized_sources_with_prompt["length"], 220 | batched_tokenized_targets["length"], 221 | tokenized_sources_with_prompt["input_ids"], 222 | batched_tokenized_targets["input_ids"], 223 | step_labels, 224 | ): 225 | weight = [] 226 | full_seq = [] 227 | label = [] 228 | 229 | # add source 230 | num_begin_padding = len(tokenized_source) - source_length 231 | full_seq = full_seq + tokenized_source[num_begin_padding:] 232 | weight = weight + [0 for _ in range(source_length)] 233 | label = label + [0 for _ in range(source_length)] 234 | 235 | # add target one by one 236 | for target_length, tokenized_target, step_label in zip( 237 | batched_target_length, batched_tokenized_target, batched_step_label 238 | ): 239 | real_target_length = target_length - begin_padding_len 240 | tokenized_target = tokenized_target[begin_padding_len:target_length] 241 | full_seq = full_seq + tokenized_target 242 | 243 | if not self.train_on_every_token and real_target_length > 3: 244 | weight = ( 245 | weight 246 | + [0 for _ in range(real_target_length - 3)] 247 | + [1 for _ in range(3)] 248 | ) 249 | else: 250 | weight = weight + [1 for _ in range(real_target_length)] 251 | label = label + [step_label for _ in range(real_target_length)] 252 | 253 | # add padding 254 | if self.total_max_len is not None: 255 | full_seq = full_seq[: self.total_max_len] 256 | weight = weight[: self.total_max_len] 257 | label = label[: self.total_max_len] 258 | 259 | if self.total_max_len > len(full_seq): 260 | padding_length = self.total_max_len - len(full_seq) 261 | weight = weight + [0 for _ in range(padding_length)] 262 | full_seq = full_seq + [ 263 | self.tokenizer.pad_id for _ in range(padding_length) 264 | ] 265 | label = label + [0 for _ in range(padding_length)] 266 | 267 | assert len(full_seq) == len(weight) 268 | assert len(full_seq) == len(label) 269 | input_ids.append(torch.tensor(full_seq)) 270 | weights.append(torch.tensor(weight)) 271 | labels.append(torch.tensor(label)) 272 | 273 | # Apply padding 274 | input_ids = pad_sequence( 275 | input_ids, batch_first=True, padding_value=self.tokenizer.pad_id 276 | ) 277 | weights = pad_sequence(weights, batch_first=True, padding_value=0) 278 | weights = weights.float() 279 | labels = pad_sequence(labels, batch_first=True, padding_value=0) 280 | labels = labels.long() 281 | data_dict = { 282 | "input_ids": input_ids, 283 | "attention_mask": input_ids.ne(self.tokenizer.pad_id), 284 | "weights": weights, 285 | "labels": labels, 286 | } 287 | return data_dict 288 | 289 | 290 | def 
extract_prm_dataset(example): 291 | if example["output_prefix"] == "": 292 | ret = { 293 | "input": "Question: " + example["input"], 294 | "output": "\n\nAnswer: " + example["output"], 295 | } 296 | else: 297 | ret = { 298 | "input": "Question: " 299 | + example["input"] 300 | + "\n\nAnswer: " 301 | + example["output_prefix"], 302 | "output": example["output"], 303 | } 304 | 305 | ret["label"] = example["label"] 306 | 307 | return ret 308 | 309 | 310 | def extract_prm_v2_dataset(example): 311 | if example["output_prefix"] == "": 312 | ret = { 313 | "input": "# Question\n\n" + example["input"] + "\n\n# Solution", 314 | "output": "\n\n" + example["output"], 315 | } 316 | else: 317 | ret = { 318 | "input": "# Question\n\n" 319 | + example["input"] 320 | + "\n\n# Solution\n\n" 321 | + example["output_prefix"], 322 | "output": example["output"], 323 | } 324 | 325 | ret["label"] = example["label"] 326 | 327 | return ret 328 | 329 | 330 | def extract_prm_v3_dataset(example): 331 | if example["output_prefix"] == "": 332 | ret = { 333 | "input": "# Question\n\n" + example["input"] + "\n\n# Solution\n\n", 334 | "output": example["output"], 335 | } 336 | else: 337 | ret = { 338 | "input": "# Question\n\n" 339 | + example["input"] 340 | + "\n\n# Solution\n\n" 341 | + example["output_prefix"], 342 | "output": example["output"], 343 | } 344 | 345 | ret["label"] = example["label"] 346 | 347 | return ret 348 | 349 | 350 | def extract_prm_v4_dataset(example): 351 | output = [_ + "\n\n" for _ in example["output"][:-1]] + [example["output"][-1]] 352 | assert len(output) == len(example["label"]) 353 | assert all([SPLITTER not in _ for _ in output]) 354 | 355 | _input = "# Question\n\n" + example["input"] + "\n\n# Solution\n\n" 356 | if "output_prefix" in example and example["output_prefix"] is not None: 357 | _input = _input + example["output_prefix"] 358 | 359 | ret = { 360 | "input": _input, 361 | "output": SPLITTER.join(output), 362 | "label": SPLITTER.join([str(_) for _ in example["label"]]), 363 | } 364 | return ret 365 | 366 | 367 | def make_pointwise_reward_modeling_data_module( 368 | tokenizer: AcceleraTokenizer, 369 | args: Arguments, 370 | ) -> Dict: 371 | """ 372 | Make dataset and collator for supervised fine-tuning. 373 | Datasets are expected to have the following columns: { `input`, `output` } 374 | """ 375 | 376 | def load_data(dataset_name): 377 | if os.path.exists(dataset_name): 378 | try: 379 | full_dataset = utils.local_dataset(dataset_name) 380 | return full_dataset 381 | except: 382 | raise ValueError(f"Error loading dataset from {dataset_name}") 383 | else: 384 | raise NotImplementedError(f"Dataset {dataset_name} not implemented yet.") 385 | 386 | multiple_output_dataset = False 387 | if args.dataset_format == "prm-v4": 388 | multiple_output_dataset = True 389 | 390 | def format_dataset(dataset, dataset_format): 391 | if dataset_format == "prm": 392 | dataset = dataset.map(extract_prm_dataset) 393 | elif dataset_format == "prm-v2": 394 | dataset = dataset.map(extract_prm_v2_dataset) 395 | elif dataset_format == "prm-v3": 396 | dataset = dataset.map(extract_prm_v3_dataset) 397 | elif dataset_format == "prm-v4": 398 | dataset = dataset.map(extract_prm_v4_dataset) 399 | else: 400 | raise ValueError(f"Unsupported dataset format: {dataset_format}") 401 | 402 | # Remove unused columns. 
403 | dataset = dataset.remove_columns( 404 | [ 405 | col 406 | for col in dataset.column_names["train"] 407 | if col not in ["input", "output", "label"] 408 | ] 409 | ) 410 | return dataset 411 | 412 | # Load dataset. 413 | dataset = load_data(args.dataset) 414 | dataset = format_dataset(dataset, args.dataset_format) 415 | 416 | # Split train/eval, reduce size 417 | if args.do_eval: 418 | if "eval" in dataset: 419 | eval_dataset = dataset["eval"] 420 | else: 421 | print( 422 | "Splitting train dataset in train and validation according to `eval_dataset_size`" 423 | ) 424 | dataset = dataset["train"].train_test_split( 425 | test_size=args.eval_dataset_size, shuffle=True, seed=42 426 | ) 427 | eval_dataset = dataset["test"] 428 | if ( 429 | args.max_eval_samples is not None 430 | and len(eval_dataset) > args.max_eval_samples 431 | ): 432 | eval_dataset = eval_dataset.select(range(args.max_eval_samples)) 433 | 434 | if args.do_train: 435 | train_dataset = dataset["train"] 436 | if ( 437 | args.max_train_samples is not None 438 | and len(train_dataset) > args.max_train_samples 439 | ): 440 | train_dataset = train_dataset.select(range(args.max_train_samples)) 441 | 442 | if multiple_output_dataset: 443 | data_collator = DataCollatorForPointwiseRewardModelingV2( 444 | tokenizer=tokenizer, 445 | source_max_len=args.source_max_len, 446 | target_max_len=args.target_max_len, 447 | total_max_len=args.total_max_len, 448 | train_on_every_token=args.train_on_every_token, 449 | ) 450 | else: 451 | data_collator = DataCollatorForPointwiseRewardModeling( 452 | tokenizer=tokenizer, 453 | source_max_len=args.source_max_len, 454 | target_max_len=args.target_max_len, 455 | total_max_len=args.total_max_len, 456 | train_on_every_token=args.train_on_every_token, 457 | ) 458 | return dict( 459 | train_dataset=train_dataset if args.do_train else None, 460 | eval_dataset=eval_dataset if args.do_eval else None, 461 | data_collator=data_collator, 462 | ) 463 | -------------------------------------------------------------------------------- /data_utils/data_utils_sft.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The GPT-Accelera Team 2 | # Copyright 2023 The Alpaca Team 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
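# Data utilities for supervised fine-tuning (SFT). preprocess_for_sft left-pads
# and left-truncates sources, right-pads and right-truncates targets, then
# concatenates the two (moving the source's left padding to the end of the
# sequence) and truncates to `total_max_len`. Unless `train_on_source` is set,
# source positions in `labels` are replaced with the pad id so they can be
# excluded from the loss. DataCollatorForCausalLM applies this preprocessing
# per batch, and make_sft_data_module handles dataset loading/formatting
# (alpaca, hh-rlhf, prm, prm-v2, metamath, mapped) plus the train/eval split.
#
# A hypothetical use from a training script (names are illustrative, not from
# this repo):
#     data_module = make_sft_data_module(tokenizer=tokenizer, args=args)
#     loader = torch.utils.data.DataLoader(
#         data_module["train_dataset"],
#         batch_size=per_device_batch_size,
#         collate_fn=data_module["data_collator"],
#     )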
15 | 16 | import os 17 | from dataclasses import dataclass 18 | import logging 19 | from typing import Dict, Sequence, Union 20 | 21 | import torch 22 | 23 | from datasets import load_dataset 24 | 25 | from arguments import Arguments 26 | import trainers.common_utils as utils 27 | from models.tokenizer_utils import AcceleraTokenizer 28 | 29 | logger = logging.getLogger(__name__) 30 | 31 | DROMEDARY_PROMPT_DICT = { 32 | "prompt_input": ( 33 | "{meta_prompt}\n" "{instruction}\n\n" "{input}\n\n" "### Dromedary" 34 | ), 35 | "prompt_no_input": ("{meta_prompt}\n" "{instruction}\n\n" "### Dromedary"), 36 | } 37 | 38 | ALPACA_PROMPT_DICT = { 39 | "prompt_input": ( 40 | "Below is an instruction that describes a task, paired with an input that provides further context. " 41 | "Write a response that appropriately completes the request.\n\n" 42 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" 43 | ), 44 | "prompt_no_input": ( 45 | "Below is an instruction that describes a task. " 46 | "Write a response that appropriately completes the request.\n\n" 47 | "### Instruction:\n{instruction}\n\n### Response:\n" 48 | ), 49 | } 50 | 51 | 52 | def preprocess_for_sft( 53 | instances: Sequence[Dict], 54 | tokenizer: AcceleraTokenizer, 55 | source_max_len: int, 56 | target_max_len: int, 57 | total_max_len: int, 58 | train_on_source: bool, 59 | add_eos_to_target: bool, 60 | add_eos_to_marked_target: bool, 61 | return_win_rate: bool = False, 62 | ) -> Dict[str, Union[torch.Tensor, Sequence[torch.Tensor]]]: 63 | # Extract elements 64 | sources = [example["input"] for example in instances] 65 | targets = [f"\n{example['output']}" for example in instances] 66 | 67 | begin_padding_len = tokenizer( 68 | ["\n"], return_tensors="pt", add_bos=False, add_eos=False 69 | ).input_ids.shape[1] 70 | 71 | # Tokenize 72 | tokenized_sources_with_prompt = tokenizer( 73 | sources, 74 | max_length=source_max_len, 75 | padding="max_length", 76 | truncation=True, 77 | add_bos=True, 78 | add_eos=False, 79 | padding_side="left", 80 | truncation_side="left", 81 | ) 82 | 83 | marked_eos = None 84 | if "is_eos" in instances[0] and add_eos_to_marked_target: 85 | marked_eos = [example["is_eos"] for example in instances] 86 | 87 | win_rate = None 88 | if return_win_rate: 89 | if "win_rate" in instances[0]: 90 | win_rate = [example["win_rate"] for example in instances] 91 | else: 92 | win_rate = [0.5 for _ in instances] 93 | 94 | # logger.warning(f"Tokenizing {len(targets)} pairs...") 95 | tokenized_targets = tokenizer( 96 | targets, 97 | max_length=target_max_len + begin_padding_len, 98 | padding="max_length", 99 | truncation=True, 100 | add_bos=False, 101 | add_eos=add_eos_to_target, 102 | marked_eos=marked_eos, 103 | padding_side="right", 104 | truncation_side="right", 105 | ) 106 | # Build the input and labels for causal LM 107 | input_ids = [] 108 | labels = [] 109 | for source_length, tokenized_source, tokenized_target in zip( 110 | tokenized_sources_with_prompt["length"], 111 | tokenized_sources_with_prompt["input_ids"], 112 | tokenized_targets["input_ids"], 113 | ): 114 | tokenized_target = tokenized_target[begin_padding_len:] 115 | full_seq = tokenized_source + tokenized_target 116 | 117 | # move the beginning padding to the end of the full_seq 118 | num_begin_padding = len(tokenized_source) - source_length 119 | full_seq = full_seq[num_begin_padding:] + full_seq[:num_begin_padding] 120 | 121 | if total_max_len is not None: 122 | full_seq = full_seq[:total_max_len] 123 | 124 | # 
input_ids.append(torch.tensor(full_seq)) 125 | input_ids.append(full_seq) 126 | if not train_on_source: 127 | full_seq_label = ( 128 | [tokenizer.pad_id for _ in range(source_length)] 129 | + tokenized_target 130 | + [tokenizer.pad_id for _ in range(num_begin_padding)] 131 | ) 132 | if total_max_len is not None: 133 | full_seq_label = full_seq_label[:total_max_len] 134 | # labels.append(torch.tensor(full_seq_label)) 135 | labels.append(full_seq_label) 136 | else: 137 | # labels.append(torch.tensor(copy.deepcopy(full_seq))) 138 | labels.append(full_seq) 139 | # Apply padding 140 | # input_ids = pad_sequence( 141 | # input_ids, batch_first=True, padding_value=tokenizer.pad_id 142 | # ) 143 | # labels = pad_sequence(labels, batch_first=True, padding_value=tokenizer.pad_id) 144 | input_ids = torch.tensor(input_ids) 145 | labels = torch.tensor(labels) 146 | data_dict = { 147 | "input_ids": input_ids, 148 | "attention_mask": input_ids.ne(tokenizer.pad_id), 149 | } 150 | if labels is not None: 151 | data_dict["labels"] = labels 152 | if return_win_rate: 153 | data_dict["win_rate"] = torch.tensor(win_rate).view(-1, 1) 154 | return data_dict 155 | 156 | 157 | @dataclass 158 | class DataCollatorForCausalLM(object): 159 | tokenizer: AcceleraTokenizer 160 | source_max_len: int 161 | target_max_len: int 162 | total_max_len: int 163 | train_on_source: bool 164 | add_eos_to_target: bool 165 | add_eos_to_marked_target: bool 166 | 167 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 168 | return preprocess_for_sft( 169 | instances=instances, 170 | tokenizer=self.tokenizer, 171 | source_max_len=self.source_max_len, 172 | target_max_len=self.target_max_len, 173 | total_max_len=self.total_max_len, 174 | train_on_source=self.train_on_source, 175 | add_eos_to_target=self.add_eos_to_target, 176 | add_eos_to_marked_target=self.add_eos_to_marked_target, 177 | ) 178 | 179 | 180 | def extract_alpaca_dataset(example): 181 | if example.get("input", "") != "": 182 | prompt_format = ALPACA_PROMPT_DICT["prompt_input"] 183 | else: 184 | prompt_format = ALPACA_PROMPT_DICT["prompt_no_input"] 185 | return {"input": prompt_format.format(**example)} 186 | 187 | 188 | def extract_dromedary_dataset(example, meta_prompts): 189 | assert "example_id" in example 190 | total_meta_prompt = len(meta_prompts) 191 | meta_prompt = meta_prompts[int(example["example_id"]) % total_meta_prompt] 192 | 193 | if example.get("input", "") != "": 194 | prompt_format = DROMEDARY_PROMPT_DICT["prompt_input"] 195 | else: 196 | prompt_format = DROMEDARY_PROMPT_DICT["prompt_no_input"] 197 | 198 | return { 199 | "input": prompt_format.format(meta_prompt=meta_prompt, **example), 200 | "output": "\n" + example["output"], 201 | } 202 | 203 | 204 | def extract_prm_dataset(example): 205 | if example["output_prefix"] == "": 206 | ret = { 207 | "input": "Question: " + example["input"], 208 | "output": "\n\nAnswer: " + example["output"], 209 | } 210 | else: 211 | ret = { 212 | "input": "Question: " 213 | + example["input"] 214 | + "\n\nAnswer: " 215 | + example["output_prefix"], 216 | "output": example["output"], 217 | } 218 | 219 | if "is_eos" in example: 220 | ret["is_eos"] = example["is_eos"] 221 | 222 | return ret 223 | 224 | 225 | def extract_prm_v2_dataset(example): 226 | if example["output_prefix"] == "": 227 | ret = { 228 | "input": "# Question\n\n" + example["input"] + "\n\n# Solution", 229 | "output": "\n\n" + example["output"], 230 | } 231 | else: 232 | ret = { 233 | "input": "# Question\n\n" 234 | + example["input"] 235 | + 
"\n\n# Solution\n\n" 236 | + example["output_prefix"], 237 | "output": example["output"], 238 | } 239 | 240 | if "is_eos" in example: 241 | ret["is_eos"] = example["is_eos"] 242 | 243 | return ret 244 | 245 | 246 | def extract_metamath_dataset(example): 247 | ret = { 248 | "input": "# Question\n\n" + example["query"] + "\n\n# Solution", 249 | "output": "\n\n" + example["output"], 250 | "is_eos": True, 251 | } 252 | 253 | return ret 254 | 255 | 256 | def make_sft_data_module( 257 | tokenizer: AcceleraTokenizer, 258 | args: Arguments, 259 | ) -> Dict: 260 | """ 261 | Make dataset and collator for supervised fine-tuning. 262 | Datasets are expected to have the following columns: { `input`, `output` } 263 | """ 264 | 265 | def load_data(dataset_name): 266 | if dataset_name == "alpaca": 267 | return load_dataset("tatsu-lab/alpaca") 268 | elif dataset_name == "alpaca-clean": 269 | return load_dataset("yahma/alpaca-cleaned") 270 | elif dataset_name == "chip2": 271 | return load_dataset("laion/OIG", data_files="unified_chip2.jsonl") 272 | elif dataset_name == "self-instruct": 273 | return load_dataset("yizhongw/self_instruct", name="self_instruct") 274 | elif dataset_name == "hh-rlhf": 275 | return load_dataset("Anthropic/hh-rlhf") 276 | elif dataset_name == "longform": 277 | return load_dataset("akoksal/LongForm") 278 | elif dataset_name == "oasst1": 279 | return load_dataset("timdettmers/openassistant-guanaco") 280 | elif dataset_name == "vicuna": 281 | raise NotImplementedError("Vicuna data was not released.") 282 | else: 283 | if os.path.exists(dataset_name): 284 | try: 285 | args.dataset_format = ( 286 | args.dataset_format if args.dataset_format else "alpaca" 287 | ) 288 | full_dataset = utils.local_dataset(dataset_name) 289 | return full_dataset 290 | except: 291 | raise ValueError(f"Error loading dataset from {dataset_name}") 292 | else: 293 | raise NotImplementedError( 294 | f"Dataset {dataset_name} not implemented yet." 295 | ) 296 | 297 | def format_dataset(dataset, dataset_format): 298 | if ( 299 | dataset_format == "alpaca" 300 | or dataset_format == "alpaca-clean" 301 | or (dataset_format is None and args.dataset in ["alpaca", "alpaca-clean"]) 302 | ): 303 | dataset = dataset.map( 304 | extract_alpaca_dataset, remove_columns=["instruction"] 305 | ) 306 | elif dataset_format == "hh-rlhf" or ( 307 | dataset_format is None and args.dataset == "hh-rlhf" 308 | ): 309 | dataset = dataset.map(lambda x: {"input": "", "output": x["chosen"]}) 310 | elif dataset_format == "prm": 311 | dataset = dataset.map(extract_prm_dataset) 312 | elif dataset_format == "prm-v2": 313 | dataset = dataset.map(extract_prm_v2_dataset) 314 | elif dataset_format == "metamath": 315 | dataset = dataset.map(extract_metamath_dataset) 316 | elif dataset_format == "mapped": 317 | dataset = dataset 318 | else: 319 | raise ValueError(f"Unsupported dataset format: {dataset_format}") 320 | 321 | # Remove unused columns. 322 | dataset = dataset.remove_columns( 323 | [ 324 | col 325 | for col in dataset.column_names["train"] 326 | if col not in ["input", "output", "is_eos"] 327 | ] 328 | ) 329 | return dataset 330 | 331 | # Load dataset. 
332 | dataset = load_data(args.dataset) 333 | dataset = format_dataset(dataset, args.dataset_format) 334 | 335 | # Split train/eval, reduce size 336 | if args.do_eval: 337 | if "eval" in dataset: 338 | eval_dataset = dataset["eval"] 339 | else: 340 | print( 341 | "Splitting train dataset in train and validation according to `eval_dataset_size`" 342 | ) 343 | dataset = dataset["train"].train_test_split( 344 | test_size=args.eval_dataset_size, shuffle=True, seed=42 345 | ) 346 | eval_dataset = dataset["test"] 347 | if ( 348 | args.max_eval_samples is not None 349 | and len(eval_dataset) > args.max_eval_samples 350 | ): 351 | eval_dataset = eval_dataset.select(range(args.max_eval_samples)) 352 | 353 | if args.do_train: 354 | train_dataset = dataset["train"] 355 | if ( 356 | args.max_train_samples is not None 357 | and len(train_dataset) > args.max_train_samples 358 | ): 359 | train_dataset = train_dataset.select(range(args.max_train_samples)) 360 | 361 | data_collator = DataCollatorForCausalLM( 362 | tokenizer=tokenizer, 363 | source_max_len=args.source_max_len, 364 | target_max_len=args.target_max_len, 365 | total_max_len=args.total_max_len, 366 | train_on_source=args.train_on_source, 367 | add_eos_to_target=args.add_eos_to_target, 368 | add_eos_to_marked_target=args.add_eos_to_marked_target, 369 | ) 370 | return dict( 371 | train_dataset=train_dataset if args.do_train else None, 372 | eval_dataset=eval_dataset if args.do_eval else None, 373 | data_collator=data_collator, 374 | ) 375 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | TODO: Add examples 4 | -------------------------------------------------------------------------------- /inference_reward.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The GPT-Accelera Team 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
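# Batched reward-model inference. Loads a RewardModel from a converted
# checkpoint (optionally tensor-parallel, optionally quantized to int8 on the
# fly), restores fine-tuned weights from `finetune_checkpoint_path` when given,
# and scores batches of prompts with model_score(), which runs the backbone
# without a KV cache under the flash SDP kernel.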
7 | 8 | import json 9 | import gc 10 | import os 11 | import sys 12 | import time 13 | from pathlib import Path 14 | from typing import Optional, Dict 15 | from collections import OrderedDict 16 | import itertools 17 | import fcntl 18 | 19 | import torch 20 | 21 | import torch._inductor.config 22 | import torch._dynamo.config 23 | 24 | torch._inductor.config.coordinate_descent_tuning = True 25 | torch._inductor.config.triton.unique_kernel_names = True 26 | torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future 27 | 28 | 29 | # support running without installing as a package 30 | wd = Path(__file__).parent.parent.resolve() 31 | sys.path.append(str(wd)) 32 | 33 | from models.reward_model import RewardModel, apply_reward_modeling_head 34 | from models.tp import ( 35 | maybe_init_dist, 36 | initialize_model_parallel, 37 | apply_tp, 38 | apply_reward_head_tp, 39 | get_model_parallel_rank, 40 | get_data_parallel_rank, 41 | get_data_parallel_world_size, 42 | ) 43 | from models.tokenizer_utils import ( 44 | AcceleraTokenizer, 45 | batch_encode_tokens, 46 | ) 47 | from checkpoint_utils import ( 48 | get_latest_checkpoint_path, 49 | load_inference_checkpoint, 50 | ) 51 | 52 | 53 | def model_forward(model, x): 54 | return model(x) 55 | 56 | 57 | def remove_all_backward_hooks(model: torch.nn.Module) -> Dict[str, OrderedDict]: 58 | all_backward_hooks = {} 59 | 60 | for name, module in model.named_modules(): 61 | all_backward_hooks[name] = module._backward_hooks 62 | module._backward_hooks = OrderedDict() 63 | 64 | return all_backward_hooks 65 | 66 | 67 | @torch.no_grad() 68 | def model_score( 69 | model: RewardModel, 70 | prompt: torch.Tensor, 71 | max_seq_len: Optional[int] = None, 72 | ) -> torch.Tensor: 73 | """ 74 | Scores a batch of prompts using a reward model. 
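    Returns a (batch_size, seq_len) tensor with one score per position,
    i.e. the reward-head outputs averaged at each token.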
75 | """ 76 | B, T = prompt.size(0), prompt.size(1) 77 | 78 | max_seq_len = max_seq_len or T 79 | 80 | device = prompt.device 81 | with torch.device(device): 82 | model.backbone_model.setup_caches( 83 | max_batch_size=B, max_seq_length=max_seq_len, kv_cache=False 84 | ) 85 | 86 | with torch.backends.cuda.sdp_kernel( 87 | enable_flash=True, enable_mem_efficient=False, enable_math=False 88 | ): 89 | rewards = model(prompt) 90 | 91 | return rewards 92 | 93 | 94 | def _load_reward_model(checkpoint_path, device, precision, use_tp): 95 | with torch.device("meta"): 96 | model = RewardModel.from_name(checkpoint_path.parent.name) 97 | 98 | if "int8" in str(checkpoint_path): 99 | raise NotImplementedError("int8 quantization cannot be used for reward model!") 100 | 101 | if "int4" in str(checkpoint_path): 102 | raise NotImplementedError("int4 quantization cannot be used for reward model!") 103 | 104 | checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) 105 | model.backbone_model.load_state_dict(checkpoint, assign=True) 106 | 107 | if use_tp: 108 | print("Applying tensor parallel to model ...") 109 | apply_tp(model.backbone_model) 110 | 111 | apply_reward_modeling_head(model.backbone_model) 112 | 113 | if use_tp: 114 | print("Applying tensor parallel to reward head ...") 115 | apply_reward_head_tp(model.backbone_model) 116 | 117 | model = model.to(device=device, dtype=precision) 118 | return model.eval() 119 | 120 | 121 | def main( 122 | prompt_file: Path, 123 | output_file: Path, 124 | batch_size: int = 4, 125 | checkpoint_path: Path = Path( 126 | "checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth" 127 | ), 128 | compile: bool = True, 129 | finetune_checkpoint_path: Optional[Path] = None, 130 | resume_generation: bool = False, 131 | tensor_parallel_size: Optional[int] = None, 132 | on_the_fly_8bit_quantization: bool = False, 133 | process_reward_with_answer: bool = False, 134 | ) -> None: 135 | """Generates text samples based on a pre-trained Transformer model and tokenizer.""" 136 | assert checkpoint_path.is_file(), checkpoint_path 137 | 138 | tokenizer_path = checkpoint_path.parent / "tokenizer.model" 139 | assert tokenizer_path.is_file(), tokenizer_path 140 | 141 | global print 142 | rank = maybe_init_dist() 143 | use_tp = rank is not None 144 | tp_size = 1 145 | if use_tp: 146 | tp_size = tensor_parallel_size or torch.distributed.get_world_size() 147 | initialize_model_parallel(tp_size) 148 | if rank != 0: 149 | # only print on rank 0 150 | print = lambda *args, **kwargs: None 151 | 152 | device = "cuda" 153 | precision = torch.bfloat16 154 | 155 | print("Loading model ...") 156 | t0 = time.time() 157 | model = _load_reward_model(checkpoint_path, device, precision, use_tp) 158 | 159 | if finetune_checkpoint_path is not None: 160 | finetune_checkpoint_path, _, _ = get_latest_checkpoint_path( 161 | finetune_checkpoint_path 162 | ) 163 | 164 | print("Loading finetune model ...") 165 | 166 | if finetune_checkpoint_path is not None: 167 | load_inference_checkpoint(finetune_checkpoint_path, model) 168 | 169 | model = model.eval() 170 | 171 | if on_the_fly_8bit_quantization: 172 | print("Quantizing model ...") 173 | from models.quantize import WeightOnlyInt8QuantHandler 174 | 175 | simple_quantizer = WeightOnlyInt8QuantHandler(model) 176 | model = simple_quantizer.convert_for_runtime_on_the_fly() 177 | model = model.to(device=device) 178 | model = model.eval() 179 | 180 | torch.cuda.synchronize() 181 | print(f"Time to load model: {time.time() - t0:.02f} seconds") 
182 | 183 | tokenizer = AcceleraTokenizer(tokenizer=tokenizer_path) 184 | 185 | torch.manual_seed(1234) 186 | model_size = sum( 187 | [ 188 | p.numel() * p.dtype.itemsize 189 | for p in itertools.chain(model.parameters(), model.buffers()) 190 | ] 191 | ) 192 | if compile: 193 | global model_forward 194 | model_forward = torch.compile( 195 | model_forward, mode="reduce-overhead", fullgraph=True 196 | ) 197 | 198 | prompts = [] 199 | 200 | with open(prompt_file, "r") as f: 201 | for line in f: 202 | prompts.append(json.loads(line)) 203 | 204 | # sort prompts by length to minimize padding 205 | 206 | # # debug 207 | # prompts = prompts[:1000] 208 | 209 | assert "idx" in prompts[0] 210 | assert "sample_idx" in prompts[0] 211 | 212 | all_full_seq = [prompt["prompt"] + prompt["output"] for prompt in prompts] 213 | 214 | print("Tokenizing prompts ...") 215 | tokenized_full_seq = tokenizer.batch_encode( 216 | all_full_seq, bos=[False] * len(all_full_seq), eos=[False] * len(all_full_seq) 217 | ) 218 | 219 | for prompt, tokenized in zip(prompts, tokenized_full_seq): 220 | prompt["full_seq"] = prompt["prompt"] + prompt["output"] 221 | prompt["full_seq_len"] = len(tokenized) 222 | 223 | prompts = sorted(prompts, key=lambda x: x["full_seq_len"]) 224 | 225 | skipped_prompt_sample_ids = set() 226 | 227 | if rank == 0 or not use_tp: 228 | output_parent = output_file.parent 229 | if not output_parent.is_dir(): 230 | output_parent.mkdir(exist_ok=True, parents=True) 231 | 232 | if use_tp: 233 | torch.distributed.barrier() 234 | 235 | print("Skipping prompts that have already been generated ...") 236 | if resume_generation and os.path.isfile(output_file): 237 | with open(output_file, "r") as f: 238 | for line in f: 239 | sample = json.loads(line) 240 | prompt_sample_ids = (sample["idx"], sample["sample_idx"]) 241 | skipped_prompt_sample_ids.add(prompt_sample_ids) 242 | 243 | # prompts = [prompt for prompt in prompts if prompt["idx"] not in skipped_prompt_ids] 244 | new_prompts = [] 245 | for prompt in prompts: 246 | if (prompt["idx"], prompt["sample_idx"]) not in skipped_prompt_sample_ids: 247 | new_prompts.append(prompt) 248 | skipped_prompt_sample_ids.add((prompt["idx"], prompt["sample_idx"])) 249 | prompts = new_prompts 250 | 251 | while len(prompts) % batch_size != 0: 252 | prompts.insert(0, prompts[0]) 253 | 254 | dp_rank = get_data_parallel_rank() 255 | tp_rank = get_model_parallel_rank() 256 | 257 | dp_size = get_data_parallel_world_size() 258 | 259 | if tp_rank == 0: 260 | output_writer = open(output_file, "a") 261 | 262 | batch_idx = 0 263 | 264 | gc.collect() 265 | torch.cuda.empty_cache() 266 | 267 | max_seq_len = prompts[-1]["full_seq_len"] + 2 268 | print("Max sequence length:", max_seq_len) 269 | print("Max vocab size:", model.backbone_model.config.vocab_size) 270 | 271 | if compile: 272 | remove_all_backward_hooks(model) 273 | 274 | for batched_prompt_idx in range(0, len(prompts), batch_size): 275 | batch_idx += 1 276 | if batch_idx % dp_size != dp_rank: 277 | continue 278 | 279 | batched_prompts = prompts[batched_prompt_idx : batched_prompt_idx + batch_size] 280 | 281 | encoded = batch_encode_tokens( 282 | tokenizer, 283 | [_["full_seq"] for _ in batched_prompts], 284 | bos=True, 285 | eos=True, 286 | device=device, 287 | padding_side="right", 288 | ) 289 | prompt_length = encoded.size(1) 290 | 291 | model_vocab_size = model.backbone_model.config.vocab_size 292 | encoded[encoded >= model_vocab_size] = model_vocab_size - 1 293 | 294 | # torch.cuda.synchronize() 295 | t0 = time.perf_counter() 
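        # model_score returns per-position rewards of shape (batch, prompt_length);
        # the full per-token list for each sample is serialized under "reward" below.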
296 | 297 | y = model_score( 298 | model, 299 | encoded, 300 | max_seq_len=max_seq_len, 301 | ) 302 | 303 | assert y.size(0) == len(batched_prompts) 304 | assert y.size(1) == prompt_length 305 | 306 | outputs = y.tolist() 307 | 308 | print(outputs[0]) 309 | 310 | # torch.cuda.synchronize() 311 | t = time.perf_counter() - t0 312 | tokens_generated = prompt_length * y.size(0) 313 | tokens_sec = tokens_generated / t 314 | print(f"Prompt length: {prompt_length}") 315 | print( 316 | f"Time for inference {batched_prompt_idx + batch_size} / {len(prompts)}" 317 | f": {t:.02f} sec total, {tokens_sec:.02f} tokens/sec" 318 | ) 319 | print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s") 320 | 321 | if tp_rank == 0: 322 | fcntl.flock(output_writer, fcntl.LOCK_EX) 323 | try: 324 | for prompt, score in zip(batched_prompts, outputs): 325 | output_writer.write( 326 | json.dumps( 327 | { 328 | "idx": prompt["idx"], 329 | "sample_idx": prompt["sample_idx"], 330 | "prompt": prompt["prompt"], 331 | "output": prompt["output"], 332 | "reward": score, 333 | } 334 | ) 335 | + "\n" 336 | ) 337 | output_writer.flush() 338 | finally: 339 | fcntl.flock(output_writer, fcntl.LOCK_UN) 340 | 341 | 342 | if __name__ == "__main__": 343 | import argparse 344 | 345 | parser = argparse.ArgumentParser(description="Your CLI description.") 346 | 347 | parser.add_argument( 348 | "--prompt_file", 349 | type=Path, 350 | required=True, 351 | help="File containing prompts, one per line.", 352 | ) 353 | parser.add_argument( 354 | "--output_file", 355 | type=Path, 356 | required=True, 357 | help="File to write rewards to, one per line.", 358 | ) 359 | parser.add_argument("--batch_size", type=int, default=4, help="Batch size.") 360 | parser.add_argument( 361 | "--checkpoint_path", 362 | type=Path, 363 | default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), 364 | help="Model checkpoint path.", 365 | ) 366 | parser.add_argument( 367 | "--compile", action="store_true", help="Whether to compile the model." 368 | ) 369 | parser.add_argument( 370 | "--finetune_checkpoint_path", 371 | type=Path, 372 | default=None, 373 | help="Finetune checkpoint path.", 374 | ) 375 | 376 | parser.add_argument( 377 | "--resume_generation", action="store_true", help="Whether to resume generation." 378 | ) 379 | 380 | parser.add_argument( 381 | "--tensor_parallel_size", 382 | type=int, 383 | default=None, 384 | help="Size of tensor parallelism.", 385 | ) 386 | 387 | parser.add_argument( 388 | "--on_the_fly_8bit_quantization", 389 | action="store_true", 390 | help="Whether to quantize after loading the model.", 391 | ) 392 | 393 | parser.add_argument( 394 | "--process_reward_with_answer", 395 | action="store_true", 396 | help="Whether to apply process reward with answer.", 397 | ) 398 | 399 | args = parser.parse_args() 400 | main( 401 | args.prompt_file, 402 | args.output_file, 403 | args.batch_size, 404 | args.checkpoint_path, 405 | args.compile, 406 | args.finetune_checkpoint_path, 407 | args.resume_generation, 408 | args.tensor_parallel_size, 409 | args.on_the_fly_8bit_quantization, 410 | args.process_reward_with_answer, 411 | ) 412 | -------------------------------------------------------------------------------- /models/frozen_layers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The GPT-Accelera Team 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
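# The modules below mirror nn.Embedding / nn.Linear / RMSNorm but register their
# weights as buffers rather than parameters, so frozen parts of a model carry no
# gradients and no optimizer state. Minimal sketch (illustrative only; the
# pretrained tensor is a placeholder):
#
#   frozen = FrozenLinear(4096, 4096, bias=False, dtype=torch.bfloat16)
#   frozen.weight.copy_(pretrained_weight)          # buffers are loaded, never trained
#   assert len(list(frozen.parameters())) == 0      # nothing for the optimizer to update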
6 | 7 | from typing import Optional 8 | 9 | import torch 10 | import torch.nn as nn 11 | from torch.nn import functional as F 12 | from torch import Tensor 13 | 14 | 15 | try: 16 | from apex.normalization.fused_layer_norm import FusedRMSNormFunction 17 | 18 | # print( 19 | # "`apex` is installed. You can use fused RMSNorm by set_global_compile_mode(False)." 20 | # ) 21 | except ImportError as e: 22 | FusedRMSNormFunction = None 23 | # print("`apex` is not installed. Reverting to non-fused RMSNorm.") 24 | 25 | # whether to use fused RMSNorm or not (default: no) 26 | _GLOBAL_IN_COMPILE_MODE = True 27 | 28 | 29 | def find_multiple(n: int, k: int) -> int: 30 | if n % k == 0: 31 | return n 32 | return n + k - (n % k) 33 | 34 | 35 | class FrozenEmbedding(nn.Module): 36 | __constants__ = [ 37 | "num_embeddings", 38 | "embedding_dim", 39 | "padding_idx", 40 | "max_norm", 41 | "norm_type", 42 | "scale_grad_by_freq", 43 | "sparse", 44 | ] 45 | 46 | num_embeddings: int 47 | embedding_dim: int 48 | padding_idx: Optional[int] 49 | max_norm: Optional[float] 50 | norm_type: float 51 | scale_grad_by_freq: bool 52 | weight: Tensor 53 | freeze: bool 54 | sparse: bool 55 | 56 | def __init__( 57 | self, 58 | num_embeddings: int, 59 | embedding_dim: int, 60 | device=None, 61 | dtype=None, 62 | ) -> None: 63 | factory_kwargs = {"device": device, "dtype": dtype} 64 | super().__init__() 65 | self.num_embeddings = num_embeddings 66 | self.embedding_dim = embedding_dim 67 | self.padding_idx = None 68 | self.max_norm = None 69 | self.norm_type = 2.0 70 | self.scale_grad_by_freq = False 71 | self.sparse = False 72 | self.vocab_start_index = None 73 | self.vocab_end_index = None 74 | self.num_embeddings_per_partition = None 75 | self.register_buffer( 76 | "weight", torch.empty((num_embeddings, embedding_dim), **factory_kwargs) 77 | ) 78 | 79 | def forward(self, input: Tensor) -> Tensor: 80 | if self.num_embeddings_per_partition is None: 81 | return F.embedding( 82 | input, 83 | self.weight, 84 | self.padding_idx, 85 | self.max_norm, 86 | self.norm_type, 87 | self.scale_grad_by_freq, 88 | self.sparse, 89 | ) 90 | else: 91 | # Build the mask. 92 | print("vocab_start_index", self.vocab_start_index) 93 | print("vocab_end_index", self.vocab_end_index) 94 | input_mask = (input < self.vocab_start_index) | ( 95 | input >= self.vocab_end_index 96 | ) 97 | # Mask the input. 98 | masked_input = input.clone() - self.vocab_start_index 99 | masked_input[input_mask] = 0 100 | # Get the embeddings. 101 | output_parallel = F.embedding( 102 | masked_input, 103 | self.weight, 104 | self.padding_idx, 105 | self.max_norm, 106 | self.norm_type, 107 | self.scale_grad_by_freq, 108 | self.sparse, 109 | ) 110 | # Mask the output embedding. 
111 | output_parallel[input_mask, :] = 0.0 112 | return output_parallel 113 | 114 | def extra_repr(self) -> str: 115 | s = "{num_embeddings}, {embedding_dim}" 116 | if self.padding_idx is not None: 117 | s += ", padding_idx={padding_idx}" 118 | if self.max_norm is not None: 119 | s += ", max_norm={max_norm}" 120 | if self.norm_type != 2.0: 121 | s += ", norm_type={norm_type}" 122 | if self.scale_grad_by_freq is not False: 123 | s += ", scale_grad_by_freq={scale_grad_by_freq}" 124 | if self.sparse is not False: 125 | s += ", sparse=True" 126 | return s.format(**self.__dict__) 127 | 128 | 129 | class FrozenRMSNorm(nn.Module): 130 | def __init__(self, dim: int, eps: float = 1e-5): 131 | super().__init__() 132 | self.eps = eps 133 | self.register_buffer("weight", torch.ones(dim)) 134 | 135 | global _GLOBAL_IN_COMPILE_MODE 136 | self.in_compile_mode = _GLOBAL_IN_COMPILE_MODE 137 | 138 | def _norm(self, x): 139 | return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) 140 | 141 | def forward(self, x: Tensor) -> Tensor: 142 | if self.in_compile_mode or FusedRMSNormFunction is None: 143 | with torch.autocast(device_type="cuda", enabled=False): 144 | output = self._norm(x.float()).to(dtype=x.dtype) 145 | return output * self.weight 146 | else: 147 | with torch.autocast(device_type="cuda", enabled=False): 148 | output = FusedRMSNormFunction.apply( 149 | x, 150 | self.weight.size(), 151 | self.eps, 152 | False, 153 | ) 154 | return output * self.weight 155 | 156 | 157 | class FrozenLinear(nn.Module): 158 | __constants__ = ["in_features", "out_features"] 159 | in_features: int 160 | out_features: int 161 | weight: Tensor 162 | 163 | def __init__( 164 | self, 165 | in_features: int, 166 | out_features: int, 167 | bias: bool = True, 168 | device=None, 169 | dtype=None, 170 | ) -> None: 171 | factory_kwargs = {"device": device, "dtype": dtype} 172 | super().__init__() 173 | self.in_features = in_features 174 | self.out_features = out_features 175 | self.register_buffer( 176 | "weight", torch.empty((out_features, in_features), **factory_kwargs) 177 | ) 178 | if bias: 179 | self.register_buffer("bias", torch.empty((out_features,), **factory_kwargs)) 180 | else: 181 | self.register_buffer("bias", None) 182 | 183 | def forward(self, input: Tensor) -> Tensor: 184 | return F.linear(input, self.weight, self.bias) 185 | 186 | def extra_repr(self) -> str: 187 | return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}" 188 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The GPT-Accelera Team 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | from dataclasses import dataclass 9 | from typing import Optional, Union 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.nn import functional as F 14 | from torch import Tensor 15 | import torch.utils.checkpoint as activation_checkpoint 16 | 17 | import models.frozen_layers as frozen_layers 18 | from models.frozen_layers import ( 19 | FrozenEmbedding, 20 | FrozenLinear, 21 | FrozenRMSNorm, 22 | ) 23 | 24 | 25 | try: 26 | from apex.normalization.fused_layer_norm import FusedRMSNormFunction 27 | 28 | # print( 29 | # "`apex` is installed. 
You can use fused RMSNorm by set_global_compile_mode(False)." 30 | # ) 31 | except ImportError as e: 32 | FusedRMSNormFunction = None 33 | # print("`apex` is not installed. Reverting to non-fused RMSNorm.") 34 | 35 | 36 | def find_multiple(n: int, k: int) -> int: 37 | if n % k == 0: 38 | return n 39 | return n + k - (n % k) 40 | 41 | 42 | @dataclass 43 | class ModelArgs: 44 | block_size: int = 2048 45 | vocab_size: int = 32000 46 | n_layer: int = 32 47 | n_head: int = 32 48 | dim: int = 4096 49 | intermediate_size: int = None 50 | n_local_heads: int = -1 51 | head_dim: int = 64 52 | rope_base: float = 10000 53 | norm_eps: float = 1e-5 54 | 55 | def __post_init__(self): 56 | if self.n_local_heads == -1: 57 | self.n_local_heads = self.n_head 58 | if self.intermediate_size is None: 59 | hidden_dim = 4 * self.dim 60 | n_hidden = int(2 * hidden_dim / 3) 61 | self.intermediate_size = find_multiple(n_hidden, 256) 62 | self.head_dim = self.dim // self.n_head 63 | 64 | @classmethod 65 | def from_name(cls, name: str): 66 | if name in transformer_configs: 67 | return cls(**transformer_configs[name]) 68 | # fuzzy search 69 | config = [ 70 | config 71 | for config in transformer_configs 72 | if config in str(name).upper() or config in str(name) 73 | ] 74 | assert len(config) >= 1, name 75 | return cls(**transformer_configs[config[0]]) 76 | 77 | 78 | transformer_configs = { 79 | "CodeLlama-7b-Python-hf": dict( 80 | block_size=16384, vocab_size=32000, n_layer=32, dim=4096, rope_base=1000000 81 | ), 82 | "llemma-7b": dict(block_size=4096, n_layer=32, n_head=32, dim=4096), 83 | "deepseek-math-7b": dict( 84 | block_size=4096, 85 | vocab_size=102400, 86 | n_layer=30, 87 | n_head=32, 88 | dim=4096, 89 | intermediate_size=11008, 90 | rope_base=10000, 91 | norm_eps=1e-6, 92 | ), 93 | "7B": dict(n_layer=32, n_head=32, dim=4096), 94 | "13B": dict(n_layer=40, n_head=40, dim=5120), 95 | "30B": dict(n_layer=60, n_head=52, dim=6656), 96 | "34B": dict( 97 | block_size=4096, 98 | n_layer=48, 99 | n_head=64, 100 | dim=8192, 101 | vocab_size=32000, 102 | n_local_heads=8, 103 | intermediate_size=22016, 104 | rope_base=1000000, 105 | ), # CodeLlama-34B-Python-hf 106 | "70B": dict( 107 | n_layer=80, n_head=64, dim=8192, n_local_heads=8, intermediate_size=28672 108 | ), 109 | } 110 | 111 | 112 | class KVCache(nn.Module): 113 | def __init__( 114 | self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16 115 | ): 116 | super().__init__() 117 | cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) 118 | self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype)) 119 | self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype)) 120 | 121 | def update(self, input_pos, k_val, v_val): 122 | # input_pos: [S], k_val: [B, H, S, D] 123 | assert input_pos.shape[0] == k_val.shape[2] 124 | 125 | k_out = self.k_cache 126 | v_out = self.v_cache 127 | k_out[: k_val.size(0), :, input_pos] = k_val 128 | v_out[: k_val.size(0), :, input_pos] = v_val 129 | return k_out[: k_val.size(0)], v_out[: k_val.size(0)] 130 | 131 | 132 | class Transformer(nn.Module): 133 | def __init__( 134 | self, 135 | config: ModelArgs, 136 | freeze_tok_embeddings: bool = False, 137 | freeze_norm: bool = False, 138 | freeze_output: bool = False, 139 | vocab_parallel: bool = False, 140 | ) -> None: 141 | super().__init__() 142 | self.config = config 143 | 144 | self.tok_embeddings = ( 145 | FrozenEmbedding(config.vocab_size, config.dim) 146 | if freeze_tok_embeddings 147 | else nn.Embedding(config.vocab_size, 
config.dim) 148 | ) 149 | 150 | self.layers = nn.ModuleList( 151 | TransformerBlock(config, freeze_norm=freeze_norm) 152 | for _ in range(config.n_layer) 153 | ) 154 | self.norm = ( 155 | FrozenRMSNorm(config.dim, eps=config.norm_eps) 156 | if freeze_norm 157 | else RMSNorm(config.dim, eps=config.norm_eps) 158 | ) 159 | self.output = ( 160 | FrozenLinear(config.dim, config.vocab_size, bias=False) 161 | if freeze_output 162 | else nn.Linear(config.dim, config.vocab_size, bias=False) 163 | ) 164 | 165 | self.freqs_cis: Optional[Tensor] = None 166 | self.mask_cache: Optional[Tensor] = None 167 | self.max_batch_size = -1 168 | self.max_seq_length = -1 169 | self.kv_cache_enabled = False 170 | self.vocab_parallel = False 171 | 172 | def setup_caches(self, max_batch_size, max_seq_length, kv_cache=True): 173 | if ( 174 | self.max_seq_length >= max_seq_length 175 | and self.max_batch_size >= max_batch_size 176 | ): 177 | if self.kv_cache_enabled or not kv_cache: 178 | return 179 | 180 | if (self.max_seq_length > 0 and self.max_seq_length < max_seq_length) or ( 181 | self.max_batch_size > 0 and self.max_batch_size < max_batch_size 182 | ): 183 | raise ValueError( 184 | "Cannot increase the size of the cache after compiled. " 185 | "Please create a new model with the desired cache size." 186 | ) 187 | 188 | head_dim = self.config.dim // self.config.n_head 189 | max_seq_length = find_multiple(max_seq_length, 8) 190 | self.max_seq_length = max_seq_length 191 | self.max_batch_size = max_batch_size 192 | for b in self.layers: 193 | if kv_cache: 194 | b.attention.kv_cache = KVCache( 195 | max_batch_size, max_seq_length, self.config.n_local_heads, head_dim 196 | ) 197 | self.kv_cache_enabled = True 198 | else: 199 | b.attention.kv_cache = None 200 | self.kv_cache_enabled = kv_cache 201 | 202 | self.freqs_cis = precompute_freqs_cis( 203 | self.config.block_size, 204 | self.config.dim // self.config.n_head, 205 | self.config.rope_base, 206 | ) 207 | self.causal_mask = torch.tril( 208 | torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool) 209 | ) 210 | self.self_mask = torch.eye(self.max_seq_length, dtype=torch.bool) 211 | 212 | def forward( 213 | self, 214 | idx: Tensor, 215 | input_pos: Optional[Tensor] = None, 216 | left_pad_mask_pos: Optional[Tensor] = None, 217 | fully_causal: bool = False, 218 | ) -> Tensor: 219 | assert self.freqs_cis is not None, "Caches must be initialized first" 220 | assert not (fully_causal and left_pad_mask_pos is not None), "Invalid mask" 221 | mask = self.causal_mask[None, None, input_pos] 222 | 223 | if left_pad_mask_pos is not None: 224 | pad_mask = torch.arange(mask.size(-1), device=mask.device).view( 225 | 1, -1 226 | ) >= left_pad_mask_pos.view(-1, 1) 227 | mask = torch.logical_and(mask, pad_mask[:, None, None, :].contiguous()) 228 | mask = torch.logical_or(mask, self.self_mask[None, None, input_pos]) 229 | 230 | x = self.tok_embeddings(idx) 231 | freqs_cis = self.freqs_cis[input_pos].to(dtype=x.dtype) 232 | 233 | for i, layer in enumerate(self.layers): 234 | if self.training: 235 | x = activation_checkpoint.checkpoint( 236 | layer, 237 | x, 238 | input_pos, 239 | freqs_cis, 240 | mask, 241 | fully_causal, 242 | use_reentrant=False, 243 | ) 244 | else: 245 | x = layer(x, input_pos, freqs_cis, mask, fully_causal) 246 | x = self.norm(x) 247 | logits = self.output(x) 248 | return logits 249 | 250 | @classmethod 251 | def from_name(cls, name: str, **kwargs): 252 | return cls(ModelArgs.from_name(name), **kwargs) 253 | 254 | 255 | class 
TransformerBlock(nn.Module): 256 | def __init__(self, config: ModelArgs, freeze_norm: bool = False) -> None: 257 | super().__init__() 258 | self.attention = Attention(config) 259 | self.feed_forward = FeedForward(config) 260 | self.ffn_norm = ( 261 | FrozenRMSNorm(config.dim, config.norm_eps) 262 | if freeze_norm 263 | else RMSNorm(config.dim, config.norm_eps) 264 | ) 265 | self.attention_norm = ( 266 | FrozenRMSNorm(config.dim, config.norm_eps) 267 | if freeze_norm 268 | else RMSNorm(config.dim, config.norm_eps) 269 | ) 270 | 271 | def forward( 272 | self, 273 | x: Tensor, 274 | input_pos: Tensor, 275 | freqs_cis: Tensor, 276 | mask: Union[Tensor, str], 277 | fully_causal: bool = False, 278 | ) -> Tensor: 279 | h = x + self.attention( 280 | self.attention_norm(x), freqs_cis, mask, input_pos, fully_causal 281 | ) 282 | out = h + self.feed_forward(self.ffn_norm(h)) 283 | return out 284 | 285 | 286 | class Attention(nn.Module): 287 | def __init__(self, config: ModelArgs): 288 | super().__init__() 289 | assert config.dim % config.n_head == 0 290 | 291 | total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim 292 | # key, query, value projections for all heads, but in a batch 293 | self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False) 294 | self.wo = nn.Linear(config.dim, config.dim, bias=False) 295 | self.kv_cache = None 296 | 297 | self.n_head = config.n_head 298 | self.head_dim = config.head_dim 299 | self.n_local_heads = config.n_local_heads 300 | self.dim = config.dim 301 | self._register_load_state_dict_pre_hook(self.load_hook) 302 | 303 | def load_hook(self, state_dict, prefix, *args): 304 | if prefix + "wq.weight" in state_dict: 305 | wq = state_dict.pop(prefix + "wq.weight") 306 | wk = state_dict.pop(prefix + "wk.weight") 307 | wv = state_dict.pop(prefix + "wv.weight") 308 | state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) 309 | 310 | def forward( 311 | self, 312 | x: Tensor, 313 | freqs_cis: Tensor, 314 | mask: Union[Tensor, str], 315 | input_pos: Optional[Tensor] = None, 316 | fully_causal: bool = False, 317 | ) -> Tensor: 318 | bsz, seqlen, _ = x.shape 319 | 320 | kv_size = self.n_local_heads * self.head_dim 321 | q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) 322 | 323 | q = q.view(bsz, seqlen, self.n_head, self.head_dim) 324 | k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim) 325 | v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim) 326 | 327 | q = apply_rotary_emb(q, freqs_cis) 328 | k = apply_rotary_emb(k, freqs_cis) 329 | 330 | q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) 331 | 332 | if not fully_causal: 333 | k, v = self.kv_cache.update(input_pos, k, v) 334 | 335 | k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) 336 | v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) 337 | 338 | if q.size(2) == k.size(2) and fully_causal: 339 | with torch.backends.cuda.sdp_kernel( 340 | enable_flash=True, enable_math=False, enable_mem_efficient=False 341 | ): 342 | y = F.scaled_dot_product_attention(q, k, v, is_causal=True) 343 | else: 344 | y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) 345 | 346 | y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) 347 | 348 | y = self.wo(y) 349 | return y 350 | 351 | 352 | class FeedForward(nn.Module): 353 | def __init__(self, config: ModelArgs) -> None: 354 | super().__init__() 355 | self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False) 356 | self.w3 = nn.Linear(config.dim, 
config.intermediate_size, bias=False) 357 | self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False) 358 | 359 | def forward(self, x: Tensor) -> Tensor: 360 | return self.w2(F.silu(self.w1(x)) * self.w3(x)) 361 | 362 | 363 | class RMSNorm(nn.Module): 364 | def __init__(self, dim: int, eps: float = 1e-5): 365 | super().__init__() 366 | self.eps = eps 367 | self.weight = nn.Parameter(torch.ones(dim)) 368 | self.in_compile_mode = frozen_layers._GLOBAL_IN_COMPILE_MODE 369 | 370 | def _norm(self, x): 371 | return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) 372 | 373 | def forward(self, x: Tensor) -> Tensor: 374 | if self.in_compile_mode or FusedRMSNormFunction is None: 375 | with torch.autocast(device_type="cuda", enabled=False): 376 | output = self._norm(x.float()).to(dtype=x.dtype) 377 | return output * self.weight 378 | else: 379 | with torch.autocast(device_type="cuda", enabled=False): 380 | output = FusedRMSNormFunction.apply( 381 | x, 382 | self.weight.size(), 383 | self.eps, 384 | False, 385 | ) 386 | return output * self.weight 387 | 388 | 389 | def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor: 390 | freqs = 1.0 / ( 391 | base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem) 392 | ) 393 | t = torch.arange(seq_len, device=freqs.device) 394 | freqs = torch.outer(t, freqs) 395 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) 396 | cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) 397 | return cache.to(dtype=torch.bfloat16) 398 | 399 | 400 | def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: 401 | xshaped = x.float().reshape(*x.shape[:-1], -1, 2) 402 | freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) 403 | x_out2 = torch.stack( 404 | [ 405 | xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1], 406 | xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1], 407 | ], 408 | -1, 409 | ) 410 | 411 | x_out2 = x_out2.flatten(3) 412 | return x_out2.type_as(x) 413 | 414 | 415 | def set_global_compile_mode(mode: bool): 416 | frozen_layers._GLOBAL_IN_COMPILE_MODE = mode 417 | -------------------------------------------------------------------------------- /models/reward_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The GPT-Accelera Team 2 | # Copyright 2023 The Self-Align Team 3 | # Copyright 2023 The Alpaca Team 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
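# RewardModel wraps the causal Transformer backbone; after
# apply_reward_modeling_head() replaces the LM output projection with a 2-way
# head, the per-position reward is the mean of the head outputs. Minimal scoring
# sketch (illustrative; weight loading is omitted and B / T / eos_positions are
# placeholders):
#
#   model = RewardModel.from_name("llemma-7b")
#   apply_reward_modeling_head(model.backbone_model)
#   model.backbone_model.setup_caches(max_batch_size=B, max_seq_length=T, kv_cache=False)
#   per_token = model(token_ids)                       # (B, T) score at every position
#   at_eos = model(token_ids, eos_pos=eos_positions)   # (B,) score gathered at each EOS index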
16 | 17 | from dataclasses import dataclass 18 | import math 19 | from typing import Optional, Dict, Sequence, Union 20 | 21 | import einops 22 | import torch 23 | from torch import Tensor, nn 24 | import torch.nn.functional as F 25 | 26 | from models.model import ModelArgs, Transformer 27 | 28 | 29 | def unpack_dict( 30 | d: Dict, keys: Sequence[str], return_type: type = tuple 31 | ) -> Union[Sequence, Dict]: 32 | if return_type in (tuple, list): 33 | return return_type(d[key] for key in keys) 34 | elif return_type == dict: 35 | return {key: d[key] for key in keys} 36 | else: 37 | raise ValueError(f"Unknown return_type: {return_type}") 38 | 39 | 40 | def batch_select(input: Tensor, index: Tensor): 41 | """Select elements from a batched tensor with a batched index tensor. 42 | 43 | Example: 44 | input = torch.tensor([ 45 | [0, 1, 2], 46 | [3, 0, 9], 47 | [6, 7, 8], 48 | ]) 49 | index = torch.tensor([[0, 1], [1, 0], [0, 0]]) 50 | batch_select(input, index) = tensor([ 51 | [0, 1], 52 | [0, 3], 53 | [6, 6] 54 | ]) 55 | """ 56 | dummy_index = torch.arange(input.size(0), device=input.device).unsqueeze(-1) 57 | return input[dummy_index, index] 58 | 59 | 60 | @dataclass 61 | class RewardArgs: 62 | backbone_args: ModelArgs 63 | 64 | @classmethod 65 | def from_name(cls, name: str): 66 | return cls(backbone_args=ModelArgs.from_name(name)) 67 | 68 | 69 | class RewardModel(nn.Module): 70 | def __init__(self, config: RewardArgs, **kwargs) -> None: 71 | super().__init__() 72 | self.config = config 73 | self.backbone_model = Transformer(config.backbone_args, **kwargs) 74 | 75 | def forward( 76 | self, 77 | idx: Tensor, 78 | eos_pos: Optional[Tensor] = None, 79 | ) -> Tensor: 80 | input_pos = torch.arange(0, idx.size(-1), device=idx.device) 81 | rewards = self.backbone_model(idx, input_pos=input_pos, fully_causal=True) 82 | rewards = rewards.mean(dim=-1) 83 | 84 | if eos_pos is not None: 85 | eos_pos = eos_pos.unsqueeze(-1) 86 | rewards = batch_select(rewards, eos_pos).squeeze(-1) 87 | 88 | return rewards 89 | 90 | @classmethod 91 | def from_name(cls, name: str, **kwargs): 92 | return cls(RewardArgs.from_name(name), **kwargs) 93 | 94 | 95 | def apply_reward_modeling_head( 96 | transformer: Transformer, requires_grad=False, init_sceheme="zeros" 97 | ): 98 | output_module = transformer.output 99 | # Linear's weight matrix is transposed, and is of shape 100 | # (linear.out_features, linear.in_features) 101 | 102 | # Temp fix due to https://github.com/pytorch/pytorch/issues/106951 103 | reward_head_weight = torch.zeros_like(output_module.weight)[:2, :] 104 | if init_sceheme == "zeros": 105 | output_module.weight = nn.Parameter( 106 | reward_head_weight, 107 | requires_grad=requires_grad, 108 | ) 109 | elif init_sceheme == "semantic": 110 | # ['### Preferred Output is '] [835, 4721, 14373, 10604, 338, 29871] 111 | # ['### Preferred Output is 1.'] [835, 4721, 14373, 10604, 338, 29871, 29896, 29889] 112 | # ['### Preferred Output is 2.'] [835, 4721, 14373, 10604, 338, 29871, 29906, 29889] 113 | token_1_id = 29896 114 | token_2_id = 29906 115 | reward_head_weight[0, :] = output_module.weight[token_2_id, :] 116 | reward_head_weight[1, :] = -output_module.weight[token_1_id, :] 117 | output_module.weight = nn.Parameter( 118 | reward_head_weight, 119 | requires_grad=requires_grad, 120 | ) 121 | elif init_sceheme == "random": 122 | generator = torch.Generator(device=reward_head_weight.device) 123 | generator.manual_seed(42) 124 | nn.init.kaiming_uniform_( 125 | reward_head_weight, a=math.sqrt(5), generator=generator 
126 | ) 127 | output_module.weight = nn.Parameter( 128 | reward_head_weight * math.sqrt(2.0), 129 | requires_grad=requires_grad, 130 | ) 131 | else: 132 | raise ValueError(f"Unknown init_scheme: {init_sceheme}") 133 | setattr(output_module, "out_features", 2) 134 | 135 | 136 | def compute_pairwise_reward_modeling_loss(model, inputs, return_outputs=False): 137 | # input_ids, attention_mask each of size (bsz, num_candidates, seq_len). 138 | # index_0, index_1 each of size (bsz, num_pairs); indexes into input_ids. 139 | # choice of size (bsz, num_pairs); 1 if index_1's seq is chosen, 0 otherwise. 140 | input_ids, eos_pos, index_0, index_1, choice = unpack_dict( 141 | inputs, keys=("input_ids", "eos_pos", "index_0", "index_1", "choice") 142 | ) 143 | num_candidates, num_pairs = input_ids.size(1), choice.size(1) 144 | input_ids_flat = einops.rearrange(input_ids, "b c l -> (b c) l") 145 | eos_pos_flat = einops.rearrange(eos_pos, "b c -> (b c)") 146 | input_pos_flat = torch.arange( 147 | 0, input_ids_flat.size(-1), device=input_ids_flat.device 148 | ) 149 | outputs = model( 150 | input_ids=input_ids_flat, 151 | input_pos=input_pos_flat, 152 | eos_pos=eos_pos_flat, 153 | ) 154 | rewards_flat = outputs.rewards 155 | rewards = einops.rearrange( 156 | rewards_flat, "(b c) -> b c", c=num_candidates 157 | ) # Size: (bsz, num_candidates). 158 | 159 | rewards_0, rewards_1 = tuple( 160 | batch_select(rewards, index) for index in (index_0, index_1) 161 | ) # Size: (bsz, num_pairs). 162 | logits = rewards_1 - rewards_0 # Size: (bsz, num_pairs). 163 | # Type casting of `choice` is due to amp.autocast context manager. 164 | loss = F.binary_cross_entropy_with_logits( 165 | logits, choice.to(logits.dtype), reduction="mean" 166 | ) 167 | return (loss, dict(logits=logits)) if return_outputs else loss 168 | 169 | 170 | def compute_pairwise_reward_modeling_metrics( 171 | predictions: torch.Tensor, label_ids: torch.Tensor 172 | ) -> Dict: 173 | # eval_prediction.label_ids is a tuple that matches up with `training_args.label_names`. 174 | logits = torch.tensor(predictions).squeeze(-1) 175 | labels = torch.tensor(label_ids[-1]).squeeze(-1) 176 | predictions = (logits >= 0.0).long() 177 | accuracy = predictions.eq(labels).float().mean().item() 178 | label_positive_rate = (labels == 1).float().mean().item() 179 | return dict( 180 | accuracy=accuracy, 181 | label_positive_rate=label_positive_rate, 182 | ) 183 | -------------------------------------------------------------------------------- /models/rl_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The GPT-Accelera Team 2 | # Copyright 2023 The Self-Align Team 3 | # Copyright 2023 The Alpaca Team 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Model classes that are shared across different algorithms. 18 | 19 | WARNING: 20 | Do not tamper with the state_dict function for any of these classes. 
21 | If you tamper, make sure the keys are the same, otherwise FSDP will get confused. 22 | """ 23 | 24 | import abc 25 | import logging 26 | from typing import Dict, Optional, Tuple 27 | 28 | import torch 29 | from torch import Tensor, nn 30 | import torch.nn.functional as F 31 | from torch.nn.parallel import DistributedDataParallel as DDP 32 | from torch.distributed import _functional_collectives as funcol 33 | 34 | from arguments import Arguments 35 | from models.model import Transformer 36 | from models.tokenizer_utils import AcceleraTokenizer 37 | from models.tp import get_model_parallel_group, compute_vocab_parallel_logprobs 38 | 39 | logger = logging.getLogger(__name__) 40 | 41 | 42 | class Policy(nn.Module, abc.ABC): 43 | def __init__( 44 | self, 45 | args: Arguments, 46 | base_model: Transformer, 47 | base_tokenizer: AcceleraTokenizer, 48 | ): 49 | super().__init__() 50 | self.args = args 51 | self.base_model = base_model 52 | self.base_tokenizer = base_tokenizer 53 | 54 | global decode_one_token 55 | 56 | if decode_one_token is None: 57 | if self.args.compile: 58 | decode_one_token = torch.compile( 59 | _decode_one_token, mode="default", fullgraph=True 60 | ) 61 | else: 62 | decode_one_token = _decode_one_token 63 | 64 | @abc.abstractmethod 65 | def forward( 66 | self, 67 | queries: Tensor, 68 | query_attn_masks: Tensor, 69 | responses: Optional[Tensor] = None, 70 | temperature: Optional[float] = None, 71 | mode: Optional[str] = None, 72 | ) -> Dict[str, Tensor]: 73 | raise NotImplementedError 74 | 75 | def respond( 76 | self, 77 | queries: Tensor, 78 | query_attn_masks: Tensor, 79 | temperature: Optional[float] = None, 80 | num_return_sequences=1, 81 | ) -> Dict[str, Tensor]: 82 | assert not self.training, "Policy must be in eval model for generation." 83 | return self._post_respond( 84 | self._respond(queries, query_attn_masks, temperature, num_return_sequences) 85 | ) 86 | 87 | @abc.abstractmethod 88 | def _respond( 89 | self, 90 | queries: Tensor, 91 | query_attn_masks: Tensor, 92 | temperature: Optional[float] = None, 93 | num_return_sequences=1, 94 | ) -> Dict[str, Tensor]: 95 | raise NotImplementedError 96 | 97 | def _post_respond(self, respond_outputs: Dict[str, Tensor]) -> Dict[str, Tensor]: 98 | return respond_outputs 99 | 100 | 101 | class AutoregressivePolicy(Policy): 102 | def forward( 103 | self, 104 | queries: Tensor, 105 | query_attn_masks: Tensor, 106 | responses: Optional[Tensor] = None, 107 | temperature: Optional[float] = None, 108 | mode: Optional[str] = None, 109 | ) -> Dict[str, Tensor]: 110 | # TODO(lxuechen): Refactor attention mask. Here query_attn_masks overrides padding-based attention mask. 
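        # Queries are expected to arrive left-padded (see _respond /
        # prepare_left_pad_mask_pos); prepare_right_pad_sequences below rolls each
        # row so padding ends up on the right for the fully-causal pass, and
        # logprobs / entropies are then computed only over the last
        # `target_max_len` (response) positions.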
111 | if mode == "respond": 112 | return self.respond(queries, query_attn_masks, temperature) 113 | 114 | assert responses is not None 115 | if temperature is None: 116 | temperature = self.args.temperature 117 | input_ids = torch.cat([queries, responses], dim=1) 118 | attention_mask = input_ids.ne(self.base_tokenizer.pad_id) 119 | attention_mask[:, : queries.size(1)] = query_attn_masks 120 | 121 | batch_size, T = input_ids.size(0), input_ids.size(1) 122 | device = input_ids.device 123 | 124 | inputs, shifts = prepare_right_pad_sequences( 125 | input_ids=input_ids, 126 | attention_mask=attention_mask, 127 | pad_token_id=self.base_tokenizer.pad_id, 128 | ) 129 | input_pos = torch.arange(0, T, device=device) 130 | 131 | logits = self.base_model(inputs, input_pos, fully_causal=True).float() 132 | logits = restore_from_right_pad_sequences(logits, shifts) 133 | 134 | original_logits = logits[:, -self.args.target_max_len - 1 : -1] 135 | logits = original_logits / temperature 136 | labels = input_ids[:, -self.args.target_max_len :] 137 | 138 | with torch.autocast(device_type="cuda", enabled=False): 139 | dtype_logits = logits.float() 140 | if self.base_model.vocab_parallel: 141 | logprobs = compute_vocab_parallel_logprobs( 142 | dtype_logits, labels, ignore_index=self.base_tokenizer.pad_id 143 | ) 144 | else: 145 | logprobs = compute_logprobs( 146 | dtype_logits, labels, ignore_index=self.base_tokenizer.pad_id 147 | ) 148 | entropies = -( 149 | dtype_logits.softmax(dim=-1) * dtype_logits.log_softmax(dim=-1) 150 | ).sum(dim=-1) 151 | non_ignore_mask = labels.ne(self.base_tokenizer.pad_id).to( 152 | dtype=entropies.dtype 153 | ) 154 | reg_entropies = entropies * non_ignore_mask 155 | return dict( 156 | logprobs=logprobs, 157 | entropies=entropies, 158 | reg_entropies=reg_entropies, 159 | reg_entropies_weight=non_ignore_mask, 160 | ) 161 | 162 | def _respond( 163 | self, 164 | queries: Tensor, 165 | query_attn_masks: Tensor, 166 | temperature: Optional[float] = None, 167 | num_return_sequences=1, 168 | ) -> Dict[str, Tensor]: 169 | del num_return_sequences # Unused. 
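        # What follows: allocate KV caches for query_len + target_max_len positions,
        # prefill on the (left-padded) queries, then sample one token at a time with
        # temperature / top-k until target_max_len tokens are produced or every row
        # in the batch has emitted EOS.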
170 | 171 | unwrapped_base_model = self.base_model 172 | if isinstance(self.base_model, DDP): 173 | unwrapped_base_model = self.base_model.module 174 | 175 | B, T = queries.size(0), queries.size(1) 176 | T_new = T + self.args.target_max_len 177 | assert T_new <= unwrapped_base_model.config.block_size 178 | 179 | device, dtype = queries.device, queries.dtype 180 | with torch.device(device): 181 | unwrapped_base_model.setup_caches(max_batch_size=B, max_seq_length=T_new) 182 | 183 | if temperature is None: 184 | temperature = self.args.temperature 185 | 186 | # create an zero's tensor of the expected final shape and fill in the current tokens 187 | empty = torch.zeros((B, T_new), dtype=dtype, device=device) 188 | empty[:, :T] = queries 189 | seq = empty 190 | input_pos = torch.arange(0, T, device=device) 191 | 192 | sampling_kwargs = dict( 193 | temperature=temperature, 194 | top_k=50, 195 | ) 196 | 197 | shifts = prepare_left_pad_mask_pos( 198 | queries, 199 | attention_mask=query_attn_masks, 200 | pad_token_id=self.base_tokenizer.pad_id, 201 | ) 202 | 203 | with torch.backends.cuda.sdp_kernel( 204 | enable_flash=False, enable_mem_efficient=False, enable_math=True 205 | ): 206 | next_token = prefill( 207 | unwrapped_base_model, queries, input_pos, shifts, **sampling_kwargs 208 | ) 209 | 210 | seq[:, T] = next_token.view(B) 211 | 212 | input_pos = torch.tensor([T], device=device, dtype=torch.int) 213 | 214 | generated_tokens, _, _ = decode_n_tokens( 215 | unwrapped_base_model, 216 | next_token.view(B, -1), 217 | input_pos, 218 | shifts, 219 | self.args.target_max_len - 1, 220 | self.base_tokenizer.eos_id, 221 | **sampling_kwargs, 222 | ) 223 | 224 | generated_tokens = torch.cat(generated_tokens, dim=-1).view(B, -1) 225 | seq[:, T + 1 : T + 1 + generated_tokens.size(1)] = generated_tokens 226 | assert seq[:, T:].size(1) == self.args.target_max_len 227 | 228 | return dict( 229 | responses=seq[:, T:] 230 | ) # Size (bsz * num_return_sequences, response_len). 231 | 232 | 233 | class Value(nn.Module, abc.ABC): 234 | def __init__( 235 | self, 236 | args: Arguments, 237 | base_model: Transformer, 238 | base_tokenizer: AcceleraTokenizer, 239 | ): 240 | super().__init__() 241 | self.args = args 242 | self.base_model = base_model 243 | self.base_tokenizer = base_tokenizer 244 | self.initialized = False 245 | 246 | @abc.abstractmethod 247 | def forward( 248 | self, queries: Tensor, query_attn_masks: Tensor, responses: Tensor 249 | ) -> Dict[str, Tensor]: 250 | raise NotImplementedError 251 | 252 | 253 | class AutoregressiveValue(Value): 254 | def forward( 255 | self, queries: Tensor, query_attn_masks: Tensor, responses: Tensor 256 | ) -> Dict[str, Tensor]: 257 | assert self.initialized, "Value model must be initialized before forward pass." 
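        # The value estimate reuses the backbone's output projection: per-position
        # outputs are averaged into a single scalar per token, then sliced so that
        # values[:, t] is the estimate produced just before response token t.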
258 | 259 | sequences = torch.cat([queries, responses], dim=1) 260 | sequence_attn_masks = sequences.ne(self.base_tokenizer.pad_id) 261 | sequence_attn_masks[:, : queries.size(1)] = query_attn_masks 262 | 263 | B, T = sequences.size(0), sequences.size(1) 264 | inputs, shifts = prepare_right_pad_sequences( 265 | input_ids=sequences, 266 | attention_mask=sequence_attn_masks, 267 | pad_token_id=self.base_tokenizer.pad_id, 268 | ) 269 | 270 | device = queries.device 271 | values = self.base_model( 272 | inputs, torch.arange(0, T, device=device), fully_causal=True 273 | ) 274 | values = values.mean(dim=-1) 275 | 276 | values = restore_from_right_pad_sequences(values, shifts) 277 | values = values[:, queries.size(1) - 1 : -1] 278 | assert values.size(1) == responses.size(1) 279 | 280 | return dict(values=values) 281 | 282 | 283 | def make_policy_with_base_model( 284 | args: Arguments, 285 | base_model: Transformer, 286 | base_tokenizer: AcceleraTokenizer, 287 | ) -> AutoregressivePolicy: 288 | policy = AutoregressivePolicy(args, base_model, base_tokenizer) 289 | return policy 290 | 291 | 292 | def make_value_with_base_model( 293 | args: Arguments, 294 | base_model: Transformer, 295 | base_tokenizer: AcceleraTokenizer, 296 | ) -> AutoregressiveValue: 297 | value_model = AutoregressiveValue(args, base_model, base_tokenizer) 298 | value_model.initialized = True 299 | return value_model 300 | 301 | 302 | def prepare_right_pad_sequences(input_ids, attention_mask=None, pad_token_id=0): 303 | # Assuming '0' is the padding value 304 | if attention_mask is None: 305 | attention_mask = input_ids != pad_token_id 306 | # torch.argmax: If there are multiple maximal values 307 | # then the indices of the first maximal value are returned. 308 | shifts = torch.argmax(attention_mask.to(torch.int), dim=1) 309 | 310 | # if (shifts == 0).all(): 311 | # return input_ids, None 312 | 313 | ind0 = torch.arange(input_ids.size(0), device=input_ids.device) 314 | ind0 = ind0[:, None].expand(-1, input_ids.size(1)) 315 | ind1 = torch.arange(input_ids.size(1), device=input_ids.device) 316 | ind1 = ind1[None, :].expand(input_ids.size(0), -1) 317 | 318 | rolled_input_ids = input_ids[ 319 | ind0, (ind1 + shifts[:, None] + input_ids.size(1)) % input_ids.size(1) 320 | ] 321 | return rolled_input_ids, shifts 322 | 323 | 324 | def restore_from_right_pad_sequences(inputs, shifts): 325 | if shifts is None: 326 | return inputs 327 | 328 | ind0 = torch.arange(inputs.size(0), device=inputs.device) 329 | ind0 = ind0[:, None].expand(-1, inputs.size(1)) 330 | ind1 = torch.arange(inputs.size(1), device=inputs.device) 331 | ind1 = ind1[None, :].expand(inputs.size(0), -1) 332 | 333 | rolled_inputs = inputs[ 334 | ind0, (ind1 - shifts[:, None] - inputs.size(1)) % inputs.size(1) 335 | ] 336 | return rolled_inputs 337 | 338 | 339 | def prepare_left_pad_mask_pos(input_ids, attention_mask=None, pad_token_id=0): 340 | # Assuming '0' is the padding value 341 | if attention_mask is None: 342 | attention_mask = input_ids != pad_token_id 343 | shifts = torch.argmax(attention_mask.to(torch.int), dim=1) 344 | return shifts 345 | 346 | 347 | def multinomial_sample_one_no_sync( 348 | probs_sort, 349 | ): # Does multinomial sampling without a cuda synchronization 350 | q = torch.empty_like(probs_sort).exponential_(1) 351 | return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) 352 | # return torch.argmax(probs_sort, dim=-1, keepdim=True).to(dtype=torch.int) 353 | 354 | 355 | def logits_to_probs(logits, temperature: float = 1.0, top_k: 
Optional[int] = None): 356 | logits = logits / max(temperature, 1e-5) 357 | 358 | if top_k is not None: 359 | v, _ = torch.topk(logits, min(top_k, logits.size(-1)), dim=-1) 360 | pivot = v.select(-1, -1).view(-1, 1) 361 | logits = torch.where(logits < pivot, -float("Inf"), logits) 362 | probs = torch.nn.functional.softmax(logits, dim=-1) 363 | return probs 364 | 365 | 366 | def sample( 367 | logits, vocab_parallel, temperature: float = 1.0, top_k: Optional[int] = None 368 | ): 369 | with torch.autocast(device_type="cuda", enabled=False): 370 | logits = logits[:, -1].float() 371 | 372 | if vocab_parallel: 373 | logits = funcol.all_gather_tensor( 374 | logits, gather_dim=-1, group=get_model_parallel_group() 375 | ) 376 | 377 | probs = logits_to_probs(logits, temperature, top_k) 378 | idx_next = multinomial_sample_one_no_sync(probs) 379 | return idx_next, probs 380 | 381 | 382 | def prefill( 383 | model: Transformer, 384 | x: torch.Tensor, 385 | input_pos: torch.Tensor, 386 | left_pad_mask_pos: torch.Tensor, 387 | **sampling_kwargs, 388 | ) -> torch.Tensor: 389 | # input_pos: [B, S] 390 | logits = model(x, input_pos, left_pad_mask_pos) 391 | return sample(logits, model.vocab_parallel, **sampling_kwargs)[0] 392 | 393 | 394 | def _decode_one_token( 395 | model: Transformer, 396 | x: torch.Tensor, 397 | input_pos: torch.Tensor, 398 | left_pad_mask_pos: torch.Tensor, 399 | **sampling_kwargs, 400 | ) -> Tuple[torch.Tensor, torch.Tensor]: 401 | # input_pos: [B, 1] 402 | assert input_pos.shape[-1] == 1 403 | logits = model(x, input_pos, left_pad_mask_pos) 404 | return sample(logits, model.vocab_parallel, **sampling_kwargs) 405 | 406 | 407 | decode_one_token = None 408 | 409 | 410 | def decode_n_tokens( 411 | model: Transformer, 412 | cur_token: torch.Tensor, 413 | input_pos: torch.Tensor, 414 | left_pad_mask_pos: torch.Tensor, 415 | num_new_tokens: int, 416 | eos_id: Optional[int] = None, 417 | **sampling_kwargs, 418 | ): 419 | eos_flag = None 420 | if eos_id is not None: 421 | eos_flag = torch.zeros_like( 422 | cur_token, dtype=torch.bool, device=cur_token.device 423 | ) 424 | 425 | new_tokens, new_probs = [], [] 426 | for i in range(num_new_tokens): 427 | with torch.backends.cuda.sdp_kernel( 428 | enable_flash=False, enable_mem_efficient=False, enable_math=True 429 | ): # Actually better for Inductor to codegen attention here 430 | next_token, next_prob = decode_one_token( 431 | model, cur_token, input_pos, left_pad_mask_pos, **sampling_kwargs 432 | ) 433 | input_pos += 1 434 | new_tokens.append(next_token.clone().view(-1, 1)) 435 | new_probs.append(next_prob.clone().view(-1, 1)) 436 | cur_token = next_token.view(-1, 1) 437 | 438 | if eos_flag is not None: 439 | eos_flag = eos_flag | (next_token == eos_id) 440 | 441 | if eos_flag is not None and eos_flag.all(): 442 | break 443 | 444 | return new_tokens, new_probs, i 445 | 446 | 447 | def compute_logprobs( 448 | logits: torch.Tensor, labels: torch.Tensor, ignore_index: int 449 | ) -> torch.Tensor: 450 | """Compute per-token logprobs, zeroing out places with ignore_index (padding).""" 451 | return -F.cross_entropy( 452 | logits.permute(0, 2, 1), labels, reduction="none", ignore_index=ignore_index 453 | ) 454 | -------------------------------------------------------------------------------- /scripts/convert_checkpoint_to_hf.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig 2 | from tqdm import tqdm 3 | import torch 4 | import re 5 | import 
argparse 6 | import os 7 | import glob 8 | 9 | # we need to check that we have login the HF account 10 | # !huggingface-cli whoami 11 | # !huggingface-cli login 12 | 13 | 14 | def load_and_merge_models( 15 | tp_ckpt_name, pretrain_name, tokenizer_name, save_name_hf, push_to_hf_hub_name 16 | ): 17 | assert ( 18 | save_name_hf or push_to_hf_hub_name 19 | ), "Please provide a save path or push to HF hub name" 20 | 21 | tp_model_list = [] 22 | 23 | last_checkpoint_file = os.path.join(tp_ckpt_name, "last_checkpoint") 24 | with open(last_checkpoint_file, "r") as f: 25 | last_checkpoint_file = f.readline().strip() 26 | 27 | last_checkpoint_file = last_checkpoint_file.split("/")[-1] 28 | last_checkpoint_file = os.path.join(tp_ckpt_name, last_checkpoint_file) 29 | 30 | print("Loading checkpoint files:", last_checkpoint_file) 31 | for file in sorted(glob.glob(last_checkpoint_file)): 32 | tp_model_list.append( 33 | torch.load( 34 | file, 35 | mmap=True, 36 | )["model"] 37 | ) 38 | 39 | print("Loading HF model...") 40 | tokenizer = AutoTokenizer.from_pretrained( 41 | tokenizer_name, 42 | ) 43 | 44 | model = AutoModelForCausalLM.from_pretrained( 45 | pretrain_name, 46 | # device_map="cpu", 47 | load_in_8bit=False, 48 | torch_dtype=torch.bfloat16, 49 | ) 50 | cpu_state_dict = model.cpu().state_dict() 51 | 52 | replaced_keys = set() 53 | 54 | print("Convert to HF model...") 55 | num_tp = len(tp_model_list) 56 | 57 | state_dict = {} 58 | 59 | for key in tp_model_list[0].keys(): 60 | if "wo" in key or "w2" in key: 61 | state_dict[key] = torch.cat( 62 | [tp_model_list[i][key].cpu() for i in range(num_tp)], dim=1 63 | ) 64 | elif "wqkv" in key: 65 | state_dict[key] = torch.stack( 66 | [tp_model_list[i][key].cpu() for i in range(num_tp)], dim=0 67 | ) 68 | elif "output" in key: 69 | state_dict[key] = torch.cat( 70 | [tp_model_list[i][key].cpu() for i in range(num_tp)], dim=1 71 | ) 72 | else: 73 | state_dict[key] = torch.cat( 74 | [tp_model_list[i][key].cpu() for i in range(num_tp)], dim=0 75 | ) 76 | 77 | pattern = r"layers\.(\d+)\." 
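    # Extracts the layer index from keys such as "layers.12.attention.wqkv.weight";
    # the only key without a layer index that is handled below is the output head,
    # which is mapped onto the HF "lm_head.weight" (anything else raises).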
78 | 79 | for key in state_dict.keys(): 80 | layer = None 81 | match = re.search(pattern, key) 82 | # layer number except for: 83 | # lm_head.weight 84 | if match: 85 | layer = match.group(1) 86 | elif "output.weight" in key: 87 | name = f"lm_head.weight" 88 | print(cpu_state_dict[name].size(), state_dict[key].size()) 89 | # repeat on dim 0 to match the size 90 | repeat_size = cpu_state_dict[name].size(0) // state_dict[key].size(0) 91 | new_state_dict = state_dict[key].repeat(repeat_size, 1) 92 | cpu_state_dict[name] = 0.0 * cpu_state_dict[name] + new_state_dict 93 | replaced_keys.add(name) 94 | else: 95 | raise ValueError(f"Invalid key: {key}") 96 | 97 | print("Converting layer", key) 98 | if "wqkv" in key: 99 | merged_q, merged_k, merged_v = [], [], [] 100 | reconstruct_q, reconstruct_k = [], [] 101 | 102 | if state_dict[key].size(2) == 4096: 103 | n_heads, n_local_heads = 32, 32 104 | elif state_dict[key].size(2) == 5120: 105 | n_heads, n_local_heads = 40, 40 106 | elif state_dict[key].size(2) == 6656: 107 | n_heads, n_local_heads = 52, 52 108 | elif state_dict[key].size(2) == 8192: 109 | n_heads, n_local_heads = 64, 8 110 | else: 111 | raise ValueError(f"Invalid size for {key}: {state_dict[key].size()}") 112 | 113 | head_dim = state_dict[key].size(1) // (n_heads + n_local_heads * 2) 114 | 115 | weight_splits = [ 116 | head_dim * n_heads, 117 | head_dim * n_local_heads, 118 | head_dim * n_local_heads, 119 | ] 120 | 121 | for split_idx in range(state_dict[key].size(0)): 122 | chunk = state_dict[key][split_idx] 123 | q, k, v = chunk.split(weight_splits, dim=0) 124 | merged_q.append(q) 125 | merged_k.append(k) 126 | merged_v.append(v) 127 | merged_q = torch.cat(merged_q, dim=0) 128 | merged_k = torch.cat(merged_k, dim=0) 129 | merged_v = torch.cat(merged_v, dim=0) 130 | 131 | #### qk need reconstruction #### 132 | split_qs = torch.split(merged_q, split_size_or_sections=128, dim=0) 133 | split_ks = torch.split(merged_k, split_size_or_sections=128, dim=0) 134 | for split in split_qs: 135 | matrix0 = split[::2, :] 136 | matrix1 = split[1::2, :] 137 | reconstruct_q.append(matrix0) 138 | reconstruct_q.append(matrix1) 139 | reconstruct_q = torch.cat(reconstruct_q, dim=0) 140 | for split in split_ks: 141 | matrix0 = split[::2, :] 142 | matrix1 = split[1::2, :] 143 | reconstruct_k.append(matrix0) 144 | reconstruct_k.append(matrix1) 145 | reconstruct_k = torch.cat(reconstruct_k, dim=0) 146 | #### qk need reconstruction #### 147 | 148 | name = f"model.layers.{layer}.self_attn.q_proj.weight" 149 | cpu_state_dict[name] = reconstruct_q 150 | replaced_keys.add(name) 151 | 152 | name = f"model.layers.{layer}.self_attn.k_proj.weight" 153 | cpu_state_dict[name] = reconstruct_k 154 | replaced_keys.add(name) 155 | 156 | name = f"model.layers.{layer}.self_attn.v_proj.weight" 157 | cpu_state_dict[name] = merged_v 158 | replaced_keys.add(name) 159 | 160 | if "wo" in key: 161 | name = f"model.layers.{layer}.self_attn.o_proj.weight" 162 | cpu_state_dict[name] = state_dict[key] 163 | replaced_keys.add(name) 164 | if "w1" in key: 165 | name = f"model.layers.{layer}.mlp.gate_proj.weight" 166 | cpu_state_dict[name] = state_dict[key] 167 | replaced_keys.add(name) 168 | if "w3" in key: 169 | name = f"model.layers.{layer}.mlp.up_proj.weight" 170 | cpu_state_dict[name] = state_dict[key] 171 | replaced_keys.add(name) 172 | if "w2" in key: 173 | name = f"model.layers.{layer}.mlp.down_proj.weight" 174 | cpu_state_dict[name] = state_dict[key] 175 | replaced_keys.add(name) 176 | 177 | unreplaced_keys = 
set(cpu_state_dict.keys()) - replaced_keys 178 | print("Unreplaced keys:", unreplaced_keys) 179 | 180 | print("Loading state dict...") 181 | 182 | model.load_state_dict(cpu_state_dict, strict=False) 183 | 184 | print("Saving HF model...") 185 | 186 | if save_name_hf is not None: 187 | model.save_pretrained(save_name_hf) 188 | config = AutoConfig.from_pretrained(pretrain_name) 189 | tokenizer.save_pretrained(save_name_hf) 190 | config.save_pretrained(save_name_hf) 191 | else: 192 | model.push_to_hub(push_to_hf_hub_name, private=True, safe_serialization=False) 193 | 194 | 195 | if __name__ == "__main__": 196 | parser = argparse.ArgumentParser(description="Convert a tensor-parallel training checkpoint to a HuggingFace model.") 197 | parser.add_argument( 198 | "--tp_ckpt_name", type=str, help="Path to the TP checkpoint name", required=True 199 | ) 200 | parser.add_argument( 201 | "--tokenizer_name", type=str, help="Path to the tokenizer name", required=True 202 | ) 203 | parser.add_argument( 204 | "--pretrain_name", type=str, help="Path to the pretrain name", required=True 205 | ) 206 | parser.add_argument( 207 | "--save_name_hf", type=str, default=None, help="Path to save the HF model" 208 | ) 209 | parser.add_argument( 210 | "--push_to_hf_hub_name", type=str, default=None, help="Push to HF hub" 211 | ) 212 | 213 | args = parser.parse_args() 214 | load_and_merge_models( 215 | args.tp_ckpt_name, 216 | args.pretrain_name, 217 | args.tokenizer_name, 218 | args.save_name_hf, 219 | args.push_to_hf_hub_name, 220 | ) 221 | -------------------------------------------------------------------------------- /scripts/convert_hf_checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
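# Converts a HuggingFace Llama-style checkpoint (pytorch_model.bin shards) into a
# single model.pth in the layout expected by models/model.py: parameter names are
# remapped, the q/k/v projections are fused into a single wqkv weight, and tensors
# are optionally cast to fp16/bf16 via --target_precision.
#
# Illustrative invocation (the checkpoint path is a placeholder; see the
# scripts/prepare_*.sh helpers for how this repo actually calls it):
#   python scripts/convert_hf_checkpoint.py \
#       --checkpoint_dir checkpoints/EleutherAI/llemma_7b \
#       --target_precision bf16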
6 | import json 7 | import sys 8 | from pathlib import Path 9 | from typing import Optional 10 | 11 | import torch 12 | import re 13 | 14 | # support running without installing as a package 15 | wd = Path(__file__).parent.parent.resolve() 16 | sys.path.append(str(wd)) 17 | 18 | from models.model import ModelArgs 19 | 20 | 21 | @torch.inference_mode() 22 | def convert_hf_checkpoint( 23 | *, 24 | checkpoint_dir: Path = Path( 25 | "checkpoints/meta-Transformer/Transformer-2-7b-chat-hf" 26 | ), 27 | model_name: Optional[str] = None, 28 | target_precision: str = "fp32", 29 | ) -> None: 30 | if model_name is None: 31 | model_name = checkpoint_dir.name 32 | 33 | config = ModelArgs.from_name(model_name) 34 | print(f"Model config {config.__dict__}") 35 | 36 | # Load the json file containing weight mapping 37 | model_map_json = checkpoint_dir / "pytorch_model.bin.index.json" 38 | 39 | assert model_map_json.is_file() 40 | 41 | with open(model_map_json) as json_map: 42 | bin_index = json.load(json_map) 43 | 44 | weight_map = { 45 | "model.embed_tokens.weight": "tok_embeddings.weight", 46 | "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight", 47 | "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight", 48 | "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight", 49 | "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight", 50 | "model.layers.{}.self_attn.rotary_emb.inv_freq": None, 51 | "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight", 52 | "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight", 53 | "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight", 54 | "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight", 55 | "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight", 56 | "model.norm.weight": "norm.weight", 57 | "lm_head.weight": "output.weight", 58 | } 59 | bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()} 60 | 61 | def permute(w, n_head): 62 | dim = config.dim 63 | return ( 64 | w.view(n_head, 2, config.head_dim // 2, dim) 65 | .transpose(1, 2) 66 | .reshape(config.head_dim * n_head, dim) 67 | ) 68 | 69 | merged_result = {} 70 | for file in sorted(bin_files): 71 | state_dict = torch.load( 72 | str(file), map_location="cpu", mmap=True, weights_only=True 73 | ) 74 | 75 | if target_precision == "fp16": 76 | for key in tuple(state_dict.keys()): 77 | state_dict[key] = state_dict[key].half() 78 | elif target_precision == "bf16": 79 | for key in tuple(state_dict.keys()): 80 | state_dict[key] = state_dict[key].bfloat16() 81 | elif target_precision == "fp32": 82 | pass 83 | else: 84 | raise ValueError(f"Unsupported target_precision {target_precision}") 85 | merged_result.update(state_dict) 86 | final_result = {} 87 | for key, value in merged_result.items(): 88 | if "layers" in key: 89 | abstract_key = re.sub(r"(\d+)", "{}", key) 90 | layer_num = re.search(r"\d+", key).group(0) 91 | new_key = weight_map[abstract_key] 92 | if new_key is None: 93 | continue 94 | new_key = new_key.format(layer_num) 95 | else: 96 | new_key = weight_map[key] 97 | 98 | if len(value.shape) == 2 and value.size(1) == 32016: 99 | value = value[:, :32000] 100 | if value.size(0) == 32016: 101 | value = value[:32000, :] 102 | 103 | final_result[new_key] = value 104 | 105 | for key in tuple(final_result.keys()): 106 | if "wq" in key: 107 | q = final_result[key] 108 | k = final_result[key.replace("wq", "wk")] 109 | v 
= final_result[key.replace("wq", "wv")] 110 | q = permute(q, config.n_head) 111 | k = permute(k, config.n_local_heads) 112 | final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v]) 113 | del final_result[key] 114 | del final_result[key.replace("wq", "wk")] 115 | del final_result[key.replace("wq", "wv")] 116 | print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}") 117 | torch.save(final_result, checkpoint_dir / "model.pth") 118 | 119 | 120 | if __name__ == "__main__": 121 | import argparse 122 | 123 | parser = argparse.ArgumentParser(description="Convert HuggingFace checkpoint.") 124 | parser.add_argument( 125 | "--checkpoint_dir", 126 | type=Path, 127 | default=Path("checkpoints/meta-llama/llama-2-7b-chat-hf"), 128 | ) 129 | parser.add_argument("--model_name", type=str, default=None) 130 | parser.add_argument("--target_precision", type=str, default="fp32") 131 | 132 | args = parser.parse_args() 133 | convert_hf_checkpoint( 134 | checkpoint_dir=args.checkpoint_dir, 135 | model_name=args.model_name, 136 | target_precision=args.target_precision, 137 | ) 138 | -------------------------------------------------------------------------------- /scripts/download.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import os 7 | from requests.exceptions import HTTPError 8 | import sys 9 | from pathlib import Path 10 | from typing import Optional 11 | 12 | 13 | def hf_download( 14 | repo_id: Optional[str] = None, 15 | hf_token: Optional[str] = None, 16 | local_dir: Optional[str] = None, 17 | ) -> None: 18 | from huggingface_hub import snapshot_download 19 | 20 | local_dir = local_dir or "checkpoints" 21 | 22 | os.makedirs(f"{local_dir}/{repo_id}", exist_ok=True) 23 | try: 24 | snapshot_download( 25 | repo_id, 26 | local_dir=f"{local_dir}/{repo_id}", 27 | local_dir_use_symlinks=False, 28 | token=hf_token, 29 | ) 30 | except HTTPError as e: 31 | if e.response.status_code == 401: 32 | print( 33 | "You need to pass a valid `--hf_token=...` to download private checkpoints." 34 | ) 35 | else: 36 | raise e 37 | 38 | 39 | if __name__ == "__main__": 40 | import argparse 41 | 42 | parser = argparse.ArgumentParser(description="Download data from HuggingFace Hub.") 43 | parser.add_argument( 44 | "--repo_id", 45 | type=str, 46 | default="checkpoints/meta-llama/llama-2-7b-chat-hf", 47 | help="Repository ID to download from.", 48 | ) 49 | parser.add_argument( 50 | "--local_dir", type=str, default=None, help="Local directory to download to." 51 | ) 52 | parser.add_argument( 53 | "--hf_token", type=str, default=None, help="HuggingFace API token." 
54 | ) 55 | 56 | args = parser.parse_args() 57 | hf_download(args.repo_id, args.hf_token, args.local_dir) 58 | -------------------------------------------------------------------------------- /scripts/prepare_ds_math_7b.sh: -------------------------------------------------------------------------------- 1 | set -e 2 | set -x 3 | 4 | export DATA_DIR=/path/to/your/data/directory 5 | export MODEL_REPO=deepseek-ai/deepseek-math-7b-base 6 | 7 | python scripts/download.py \ 8 | --repo_id $MODEL_REPO \ 9 | --local_dir $DATA_DIR/checkpoints 10 | 11 | python scripts/convert_hf_checkpoint.py \ 12 | --checkpoint_dir $DATA_DIR/checkpoints/$MODEL_REPO \ 13 | --target_precision bf16 14 | -------------------------------------------------------------------------------- /scripts/prepare_llemma_34b.sh: -------------------------------------------------------------------------------- 1 | set -e 2 | set -x 3 | 4 | export DATA_DIR=/path/to/your/data/directory 5 | export MODEL_REPO=EleutherAI/llemma_34b 6 | 7 | python scripts/download.py \ 8 | --repo_id $MODEL_REPO \ 9 | --local_dir $DATA_DIR/checkpoints 10 | 11 | python scripts/convert_hf_checkpoint.py \ 12 | --checkpoint_dir $DATA_DIR/checkpoints/$MODEL_REPO \ 13 | --target_precision bf16 14 | -------------------------------------------------------------------------------- /scripts/prepare_llemma_7b.sh: -------------------------------------------------------------------------------- 1 | set -e 2 | set -x 3 | 4 | export DATA_DIR=/path/to/your/data/directory 5 | export MODEL_REPO=EleutherAI/llemma_7b 6 | 7 | python scripts/download.py \ 8 | --repo_id $MODEL_REPO \ 9 | --local_dir $DATA_DIR/checkpoints 10 | 11 | python scripts/convert_hf_checkpoint.py \ 12 | --checkpoint_dir $DATA_DIR/checkpoints/$MODEL_REPO \ 13 | --target_precision bf16 14 | -------------------------------------------------------------------------------- /train_rl_ppo.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The GPT-Accelera Team 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
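# PPO training entry point: initializes (optional) tensor parallelism, builds the RL
# data module and the policy/value/reward models via make_models(), and runs
# PPOTrainer, resuming from the latest "policy_*" checkpoint under --save_dir when
# --resume_from_checkpoint is set.
#
# Illustrative multi-GPU launch (the launcher and flag values below are assumptions
# for illustration, not taken from this repo):
#   torchrun --nproc_per_node=8 train_rl_ppo.py \
#       --base_checkpoint_path /path/to/checkpoints/$MODEL_REPO/model.pth \
#       --save_dir /path/to/save_dir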
7 | 8 | import os 9 | import sys 10 | import tempfile 11 | from pathlib import Path 12 | import logging 13 | 14 | import torch 15 | 16 | import torch._inductor.config 17 | import torch._dynamo.config 18 | 19 | torch._inductor.config.coordinate_descent_tuning = True 20 | torch._inductor.config.triton.unique_kernel_names = True 21 | torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future 22 | 23 | torch.backends.cuda.matmul.allow_tf32 = True 24 | torch.backends.cudnn.allow_tf32 = True 25 | 26 | try: 27 | import wandb 28 | except ImportError: 29 | wandb = None 30 | 31 | # support running without installing as a package 32 | wd = Path(__file__).parent.parent.resolve() 33 | sys.path.append(str(wd)) 34 | 35 | from models.tokenizer_utils import AcceleraTokenizer 36 | from models.tp import ( 37 | maybe_init_dist, 38 | initialize_model_parallel, 39 | ) 40 | from trainers.ppo_trainer import PPOTrainer, make_models 41 | from trainers.common_utils import manual_seed 42 | 43 | from data_utils.data_utils_ppo import make_rl_data_module 44 | 45 | from hf_argparser import HfArgumentParser 46 | from arguments import Arguments as TrainingArguments 47 | from checkpoint_utils import get_latest_checkpoint_path 48 | 49 | logger = logging.getLogger(__name__) 50 | 51 | 52 | def main(args: TrainingArguments): 53 | base_model_name_or_path = args.base_checkpoint_path 54 | tokenizer_path = base_model_name_or_path.parent / "tokenizer.model" 55 | if not tokenizer_path.is_file(): 56 | tokenizer_path = base_model_name_or_path.parent 57 | 58 | global print 59 | device_id = maybe_init_dist() 60 | use_tp = device_id is not None 61 | if use_tp: 62 | tp_size = args.tensor_parallel_size or torch.distributed.get_world_size() 63 | initialize_model_parallel(tp_size) 64 | torch.distributed.barrier() 65 | if device_id != 0: 66 | # only print on rank 0 67 | print = lambda *_args, **_kwargs: None 68 | 69 | checkpoint_dir, _, _ = get_latest_checkpoint_path(args.save_dir, prefix="policy_") 70 | checkpoint_dir = Path(checkpoint_dir).parent if checkpoint_dir is not None else None 71 | 72 | torch.distributed.barrier() 73 | if args.report_to == "wandb" and wandb is not None: 74 | if device_id == 0: 75 | wandb_logging_dir = os.path.join( 76 | tempfile.gettempdir(), f"{os.getuid()}_wandb" 77 | ) 78 | if not os.path.exists(wandb_logging_dir): 79 | os.makedirs(wandb_logging_dir, exist_ok=True) 80 | os.environ["WANDB_DIR"] = wandb_logging_dir 81 | wandb.init( 82 | name=args.wandb_name, 83 | project=args.wandb_project, 84 | entity=args.wandb_entity, 85 | resume="allow", 86 | magic=True, 87 | dir=wandb_logging_dir, 88 | force=True, 89 | ) 90 | wandb.config.update(vars(args)) 91 | 92 | if checkpoint_dir is None: 93 | print("Training from scratch.") 94 | else: 95 | print("Loading from checkpoint:", checkpoint_dir) 96 | 97 | tokenizer = AcceleraTokenizer(tokenizer_path) 98 | tokenizer.pad_id = tokenizer.unk_id 99 | 100 | manual_seed(args.seed) 101 | 102 | data_module: dict = make_rl_data_module(tokenizer, args) 103 | 104 | for i in range(3): 105 | token_ids = data_module["train_dataset"][i]["queries"] 106 | print(tokenizer.decode(token_ids, skip_special_tokens=True)) 107 | print("=" * 20) 108 | 109 | model_module = make_models( 110 | tokenizer, 111 | args, 112 | resume_from_checkpoint=( 113 | checkpoint_dir if args.resume_from_checkpoint else None 114 | ), 115 | ) 116 | 117 | trainer = PPOTrainer( 118 | args=args, 119 | **data_module, 120 | **model_module, 121 | 
tokenizer=tokenizer, 122 | ) 123 | 124 | trainer.train( 125 | resume_training_ckpt=checkpoint_dir if args.resume_from_checkpoint else None 126 | ) 127 | 128 | 129 | if __name__ == "__main__": 130 | parser = HfArgumentParser((TrainingArguments,)) 131 | args = parser.parse_args_into_dataclasses()[0] 132 | main(args) 133 | -------------------------------------------------------------------------------- /train_rm_pairwise.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The GPT-Accelera Team 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import os 9 | import sys 10 | import tempfile 11 | import time 12 | from pathlib import Path 13 | from typing import Dict 14 | import itertools 15 | 16 | import tqdm 17 | import einops 18 | 19 | import torch 20 | import torch.nn.functional as F 21 | from torch.utils.data.distributed import DistributedSampler 22 | 23 | import torch._inductor.config 24 | import torch._dynamo.config 25 | 26 | torch._inductor.config.coordinate_descent_tuning = True 27 | torch._inductor.config.triton.unique_kernel_names = True 28 | torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future 29 | 30 | torch.backends.cuda.matmul.allow_tf32 = True 31 | torch.backends.cudnn.allow_tf32 = True 32 | 33 | try: 34 | import wandb 35 | except ImportError: 36 | wandb = None 37 | 38 | # support running without installing as a package 39 | wd = Path(__file__).parent.parent.resolve() 40 | sys.path.append(str(wd)) 41 | 42 | from models.model import set_global_compile_mode 43 | from models.reward_model import RewardModel 44 | from models.tokenizer_utils import AcceleraTokenizer 45 | from models.tp import ( 46 | maybe_init_dist, 47 | initialize_model_parallel, 48 | get_model_parallel_group, 49 | get_model_parallel_world_size, 50 | get_data_parallel_world_size, 51 | clip_grad_norm_, 52 | ) 53 | 54 | from trainers.common_utils import manual_seed 55 | from data_utils.data_utils_rm_pairwise import ( 56 | make_pairwise_reward_modeling_data_module, 57 | ) 58 | 59 | from hf_argparser import HfArgumentParser 60 | from arguments import Arguments as TrainingArguments 61 | from checkpoint_utils import ( 62 | checkpoint_hook, 63 | get_latest_checkpoint_path, 64 | load_checkpoint, 65 | load_reward_model_from_sft_ckpt, 66 | ) 67 | from training_utils.trainer_utils import ( 68 | create_optimizer, 69 | create_fsdp_model_for_finetune, 70 | get_cosine_schedule_with_warmup, 71 | ) 72 | 73 | IGNORE_INDEX = -100 74 | 75 | 76 | def model_forward(model, x, eos_pos): 77 | return model(x, eos_pos) 78 | 79 | 80 | def model_forward_with_loss( 81 | model: RewardModel, 82 | input_ids: torch.Tensor, 83 | choice: torch.Tensor, 84 | ) -> torch.Tensor: 85 | """ 86 | Compute the loss for a given model and prompts. 
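    Each example holds two candidate responses; both are scored by the reward model,
    and the loss is a binary cross-entropy on the reward difference
    (rewards[:, 0] - rewards[:, 1]) against `choice`.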
87 | """ 88 | # create an empty tensor of the expected final shape and fill in the current tokens 89 | # input_ids: (bsz, num_candidates, max_seq_len) 90 | # choice: (bsz, num_pairs) 91 | batch_size, num_candidates, T = ( 92 | input_ids.size(0), 93 | input_ids.size(1), 94 | input_ids.size(2), 95 | ) 96 | 97 | assert choice.size(0) == batch_size 98 | assert choice.size(1) == 1 99 | assert num_candidates == 2 100 | 101 | device = input_ids.device 102 | with torch.device(device): 103 | model.backbone_model.setup_caches( 104 | max_batch_size=batch_size * 2, max_seq_length=T, kv_cache=False 105 | ) 106 | 107 | input_ids = einops.rearrange(input_ids, "b c t -> (b c) t") 108 | eos_pos = input_ids.ne(0).long().sum(dim=-1) - 1 109 | 110 | with torch.backends.cuda.sdp_kernel( 111 | enable_flash=True, enable_math=False, enable_mem_efficient=False 112 | ): 113 | rewards = model_forward(model, input_ids, eos_pos) 114 | 115 | rewards = einops.rearrange(rewards, "(b c) -> b c", b=batch_size, c=num_candidates) 116 | logits = rewards[:, 0] - rewards[:, 1] 117 | 118 | with torch.autocast(device_type="cuda", enabled=False): 119 | dtype_logits = logits.float() 120 | loss = F.binary_cross_entropy_with_logits( 121 | dtype_logits.view(-1).float(), 122 | choice.view(-1).float(), 123 | reduction="mean", 124 | ) 125 | 126 | metrics = compute_reward_modeling_metrics(logits.view(-1), choice.view(-1)) 127 | 128 | return loss, metrics 129 | 130 | 131 | def compute_reward_modeling_metrics(logits, labels) -> Dict: 132 | # eval_prediction.label_ids is a tuple that matches up with `training_args.label_names`. 133 | predictions = (logits >= 0.0).long() 134 | accuracy = predictions.eq(labels).float().mean() 135 | label_positive_rate = (labels == 1).float().mean() 136 | positive_rate = (predictions == 1).float().mean() 137 | true_positive_rate = (predictions * labels).float().sum() / labels.sum() 138 | false_positive_rate = (predictions * (1 - labels)).float().sum() / ( 139 | 1 - labels 140 | ).sum() 141 | return dict( 142 | accuracy=accuracy, 143 | label_positive_rate=label_positive_rate, 144 | positive_rate=positive_rate, 145 | true_positive_rate=true_positive_rate, 146 | false_positive_rate=false_positive_rate, 147 | ) 148 | 149 | 150 | def encode_tokens(tokenizer, string, bos=True, device="cuda"): 151 | tokens = tokenizer.encode(string) 152 | if bos: 153 | tokens = [tokenizer.bos_id()] + tokens 154 | return torch.tensor(tokens, dtype=torch.int, device=device) 155 | 156 | 157 | def main( 158 | args: TrainingArguments, 159 | ) -> None: 160 | """Finetune a model on a given dataset.""" 161 | checkpoint_path = args.checkpoint_path 162 | sft_checkpoint_path = args.sft_checkpoint_path 163 | compile = args.compile 164 | assert checkpoint_path.is_file(), checkpoint_path 165 | 166 | if sft_checkpoint_path is not None: 167 | assert sft_checkpoint_path.is_dir(), sft_checkpoint_path 168 | 169 | tokenizer_path = checkpoint_path.parent / "tokenizer.model" 170 | assert tokenizer_path.is_file(), tokenizer_path 171 | 172 | global print 173 | device_id = maybe_init_dist() 174 | use_tp = device_id is not None 175 | if use_tp: 176 | tp_size = args.tensor_parallel_size or torch.distributed.get_world_size() 177 | initialize_model_parallel(tp_size) 178 | torch.distributed.barrier() 179 | tp_group = get_model_parallel_group() 180 | 181 | if device_id != 0: 182 | # only print on rank 0 183 | print = lambda *args, **kwargs: None 184 | 185 | if args.report_to == "wandb" and wandb is not None: 186 | if device_id == 0: 187 | wandb_logging_dir = 
os.path.join( 188 | tempfile.gettempdir(), f"{os.getuid()}_wandb" 189 | ) 190 | if not os.path.exists(wandb_logging_dir): 191 | os.makedirs(wandb_logging_dir, exist_ok=True) 192 | os.environ["WANDB_DIR"] = wandb_logging_dir 193 | wandb.init( 194 | name=args.wandb_name, 195 | project=args.wandb_project, 196 | entity=args.wandb_entity, 197 | resume="allow", 198 | magic=True, 199 | dir=wandb_logging_dir, 200 | force=True, 201 | ) 202 | wandb.config.update(vars(args)) 203 | 204 | device = "cuda" 205 | precision = args.param_dtype 206 | 207 | print("Loading model ...") 208 | t0 = time.time() 209 | 210 | resume_from_checkpoint = None 211 | resume_epoch = 0 212 | resume_global_step = 0 213 | 214 | if args.resume_from_checkpoint: 215 | ( 216 | resume_from_checkpoint, 217 | resume_epoch, 218 | resume_global_step, 219 | ) = get_latest_checkpoint_path(args.save_dir) 220 | 221 | if resume_from_checkpoint is not None: 222 | sft_checkpoint_path = None 223 | 224 | model = load_reward_model_from_sft_ckpt( 225 | checkpoint_path, 226 | sft_checkpoint_path, 227 | device, 228 | precision, 229 | use_tp, 230 | requires_grad=True, 231 | reward_head_init_scheme=args.reward_head_init_scheme, 232 | ) 233 | 234 | torch.cuda.synchronize() 235 | if use_tp: 236 | torch.distributed.barrier() 237 | print(f"Time to load model: {time.time() - t0:.02f} seconds") 238 | 239 | tokenizer = AcceleraTokenizer(tokenizer_path) 240 | 241 | data_module = make_pairwise_reward_modeling_data_module( 242 | tokenizer=tokenizer, 243 | args=args, 244 | ) 245 | train_dataset = data_module["train_dataset"] 246 | data_collator = data_module["data_collator"] 247 | 248 | model_size = sum( 249 | [ 250 | p.numel() * p.dtype.itemsize 251 | for p in itertools.chain(model.parameters(), model.buffers()) 252 | ] 253 | ) 254 | 255 | print(f"Model size: {model_size / 1e6:.02f} MB") 256 | manual_seed(args.seed) 257 | 258 | sampler = None 259 | if use_tp: 260 | sampler = DistributedSampler( 261 | train_dataset, 262 | shuffle=True, 263 | drop_last=True, 264 | ) 265 | 266 | train_loader = torch.utils.data.DataLoader( 267 | train_dataset, 268 | batch_size=args.per_device_train_batch_size, 269 | shuffle=(sampler is None), 270 | sampler=sampler, 271 | num_workers=0, 272 | pin_memory=True, 273 | collate_fn=data_collator, 274 | ) 275 | 276 | if args.print_training_examples: 277 | print("Training examples:") 278 | cnt = 3 279 | for batch in train_loader: 280 | print("Input 1:") 281 | print( 282 | tokenizer.decode( 283 | batch["input_ids"][0, 0].tolist(), skip_special_tokens=False 284 | ), 285 | ) 286 | print("=" * 20) 287 | print("Input 2:") 288 | print( 289 | tokenizer.decode( 290 | batch["input_ids"][0, 1].tolist(), skip_special_tokens=False 291 | ), 292 | ) 293 | print("=" * 40) 294 | cnt -= 1 295 | if cnt == 0: 296 | break 297 | 298 | if compile: 299 | model = torch.compile(model) 300 | 301 | trainable_param_names = [ 302 | name for name, param in model.named_parameters() if param.requires_grad 303 | ] 304 | 305 | use_fsdp = False 306 | 307 | if get_data_parallel_world_size() > 1: 308 | use_fsdp = True 309 | model = create_fsdp_model_for_finetune(args, model) 310 | 311 | print("Using FSDP ...") 312 | print(model) 313 | 314 | optimizer = create_optimizer( 315 | args, 316 | model=model, 317 | optimizer_cpu_offload=args.optimizer_cpu_offload, 318 | model_cpu_offload=False, 319 | ) 320 | 321 | scheduler = get_cosine_schedule_with_warmup( 322 | optimizer, 323 | warmup_epochs=len(train_loader) * args.warmup_ratio, 324 | max_epochs=len(train_loader) * 
args.num_train_epochs, 325 | warmup_start_ratio=0.0, 326 | eta_min_ratio=args.lr_eta_min / args.learning_rate, 327 | ) 328 | 329 | if resume_from_checkpoint is not None: 330 | print( 331 | f"Resuming from checkpoint: {resume_from_checkpoint} (epoch {resume_epoch}, global step {resume_global_step})" 332 | ) 333 | load_checkpoint( 334 | resume_from_checkpoint, model, optimizer, scheduler, use_fsdp=use_fsdp 335 | ) 336 | 337 | micro_train_batch_size = ( 338 | args.micro_train_batch_size or args.per_device_train_batch_size 339 | ) 340 | 341 | assert ( 342 | args.per_device_train_batch_size % micro_train_batch_size == 0 343 | ), f"per_device_train_batch_size ({args.per_device_train_batch_size}) must be divisible by micro_train_batch_size ({micro_train_batch_size})" 344 | accumulate_steps = args.per_device_train_batch_size // micro_train_batch_size 345 | 346 | print( 347 | "Batch size per GPU for training: {}\n".format( 348 | args.per_device_train_batch_size 349 | ), 350 | "Micro batch size for training: {}\n".format(micro_train_batch_size), 351 | "Gradient accumulation steps: {}\n".format(accumulate_steps), 352 | ) 353 | 354 | micro_train_batch_size = micro_train_batch_size * torch.distributed.get_world_size() 355 | 356 | epoch_length = len(train_loader) 357 | 358 | if args.do_train: 359 | print("Starting training ...") 360 | t0 = time.time() 361 | for epoch in tqdm.trange( 362 | args.num_train_epochs, desc="Epoch", disable=device_id != 0 363 | ): 364 | if sampler is not None: 365 | train_loader.sampler.set_epoch(epoch) 366 | pbar = tqdm.tqdm( 367 | enumerate(train_loader), 368 | desc="Iteration", 369 | disable=device_id != 0, 370 | total=len(train_loader), 371 | ) 372 | for it, batch in pbar: 373 | global_step = epoch * epoch_length + it 374 | if global_step < resume_global_step: 375 | continue 376 | 377 | # torch.cuda.synchronize() 378 | model.zero_grad() 379 | 380 | input_ids = batch["input_ids"].to(device=device) 381 | choice = batch["choice"].to(device=device) 382 | 383 | input_ids, choice = prepare_batch( 384 | input_ids, 385 | choice, 386 | tokenizer=tokenizer, 387 | use_tp=use_tp, 388 | sync_group=tp_group, 389 | ) 390 | 391 | loss_scale = 1.0 / accumulate_steps 392 | for ex_idx in range(0, input_ids.size(0), micro_train_batch_size): 393 | if ex_idx + micro_train_batch_size < input_ids.size(0): 394 | with torch.cuda.amp.autocast(dtype=args.compute_dtype): 395 | loss, metrics = model_forward_with_loss( 396 | model, 397 | input_ids[ex_idx : ex_idx + micro_train_batch_size], 398 | choice[ex_idx : ex_idx + micro_train_batch_size], 399 | ) 400 | (loss_scale * loss).backward() 401 | else: 402 | with torch.cuda.amp.autocast(dtype=args.compute_dtype): 403 | loss, metrics = model_forward_with_loss( 404 | model, 405 | input_ids[ex_idx:], 406 | choice[ex_idx:], 407 | ) 408 | (loss_scale * loss).backward() 409 | grad_norm = clip_grad_norm_(model, 5.0) 410 | optimizer.step() 411 | scheduler.step() 412 | 413 | loss_copy = loss.detach().clone() 414 | acc_copy = metrics["accuracy"].detach().clone() 415 | torch.distributed.all_reduce(loss_copy) 416 | torch.distributed.all_reduce(acc_copy) 417 | avg_loss = (loss_copy / torch.distributed.get_world_size()).item() 418 | avg_acc = (acc_copy / torch.distributed.get_world_size()).item() 419 | grad_norm_copy = grad_norm.detach().clone().item() 420 | 421 | if device_id == 0: 422 | if args.report_to == "wandb" and wandb is not None: 423 | wandb.log( 424 | { 425 | "loss": avg_loss, 426 | "accuracy": avg_acc, 427 | "learning_rate": scheduler.get_last_lr()[0], 428 | 
"epoch": epoch, 429 | "step": it, 430 | "grad_norm": grad_norm_copy, 431 | }, 432 | step=global_step, 433 | ) 434 | else: 435 | # Just print to stdout. 436 | print( 437 | { 438 | "loss": avg_loss, 439 | "accuracy": avg_acc, 440 | "learning_rate": scheduler.get_last_lr()[0], 441 | "epoch": epoch, 442 | "step": it, 443 | "grad_norm": grad_norm_copy, 444 | } 445 | ) 446 | 447 | checkpoint_hook( 448 | args, 449 | model, 450 | optimizer, 451 | scheduler, 452 | epoch, 453 | global_step, 454 | epoch_length, 455 | use_fsdp=use_fsdp, 456 | trainable_param_names=trainable_param_names, 457 | ) 458 | 459 | torch.cuda.synchronize() 460 | 461 | epoch = args.num_train_epochs 462 | 463 | checkpoint_hook( 464 | args, 465 | model, 466 | optimizer, 467 | scheduler, 468 | epoch, 469 | epoch * epoch_length, 470 | epoch_length, 471 | use_fsdp=use_fsdp, 472 | trainable_param_names=trainable_param_names, 473 | ) 474 | 475 | print(f"Time to train: {time.time() - t0:.02f} seconds") 476 | 477 | 478 | def prepare_batch(input_ids, choice, tokenizer, use_tp, sync_group): 479 | pad_id = tokenizer.pad_id 480 | unk_id = tokenizer.unk_id 481 | if pad_id < 0: 482 | input_ids[input_ids == pad_id] = unk_id 483 | 484 | if use_tp and get_model_parallel_world_size() > 1: 485 | # aggregate (concat) all the inputs across tp sync_group 486 | new_input_ids = torch.empty_like(input_ids).repeat(sync_group.size(), 1) 487 | new_choice = torch.empty_like(new_choice).repeat(sync_group.size(), 1) 488 | 489 | torch.distributed.all_gather_into_tensor( 490 | new_input_ids, input_ids, group=sync_group 491 | ) 492 | torch.distributed.all_gather_into_tensor(new_choice, choice, group=sync_group) 493 | 494 | return new_input_ids, new_choice 495 | 496 | return input_ids, choice 497 | 498 | 499 | if __name__ == "__main__": 500 | parser = HfArgumentParser((TrainingArguments,)) 501 | args = parser.parse_args_into_dataclasses()[0] 502 | main(args) 503 | -------------------------------------------------------------------------------- /train_rm_pointwise.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The GPT-Accelera Team 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | import os 9 | import sys 10 | import tempfile 11 | import time 12 | from pathlib import Path 13 | import itertools 14 | 15 | import tqdm 16 | 17 | import torch 18 | import torch.nn.functional as F 19 | from torch.utils.data.distributed import DistributedSampler 20 | 21 | import torch._inductor.config 22 | import torch._dynamo.config 23 | 24 | torch._inductor.config.coordinate_descent_tuning = True 25 | torch._inductor.config.triton.unique_kernel_names = True 26 | torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future 27 | 28 | torch.backends.cuda.matmul.allow_tf32 = True 29 | torch.backends.cudnn.allow_tf32 = True 30 | 31 | try: 32 | import wandb 33 | except ImportError: 34 | wandb = None 35 | 36 | # support running without installing as a package 37 | wd = Path(__file__).parent.parent.resolve() 38 | sys.path.append(str(wd)) 39 | 40 | from models.model import set_global_compile_mode 41 | from models.reward_model import RewardModel 42 | from models.tokenizer_utils import AcceleraTokenizer 43 | from models.tp import ( 44 | maybe_init_dist, 45 | initialize_model_parallel, 46 | get_model_parallel_group, 47 | get_model_parallel_world_size, 48 | get_data_parallel_world_size, 49 | clip_grad_norm_, 50 | ) 51 | 52 | from trainers.common_utils import manual_seed 53 | from data_utils.data_utils_rm_pointwise import ( 54 | make_pointwise_reward_modeling_data_module, 55 | ) 56 | 57 | from hf_argparser import HfArgumentParser 58 | from arguments import Arguments as TrainingArguments 59 | from checkpoint_utils import ( 60 | checkpoint_hook, 61 | get_latest_checkpoint_path, 62 | load_checkpoint, 63 | load_reward_model_from_sft_ckpt, 64 | ) 65 | from training_utils.trainer_utils import ( 66 | create_optimizer, 67 | create_fsdp_model_for_finetune, 68 | get_cosine_schedule_with_warmup, 69 | ) 70 | 71 | IGNORE_INDEX = -100 72 | 73 | 74 | def model_forward(model, x): 75 | return model(x) 76 | 77 | 78 | def model_forward_with_loss( 79 | model: RewardModel, 80 | input_ids: torch.Tensor, 81 | labels: torch.Tensor, 82 | weights: torch.Tensor, 83 | ) -> torch.Tensor: 84 | """ 85 | Compute the loss for a given model and prompts. 
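    Each sequence is scored by the reward model; the loss is a weighted binary
    cross-entropy between the scalar logits and `labels`, normalized by the sum of
    `weights`.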
86 | """ 87 | # create an empty tensor of the expected final shape and fill in the current tokens 88 | batch_size, T = input_ids.size(0), input_ids.size(1) 89 | 90 | device = input_ids.device 91 | with torch.device(device): 92 | model.backbone_model.setup_caches( 93 | max_batch_size=batch_size, max_seq_length=T, kv_cache=False 94 | ) 95 | 96 | with torch.backends.cuda.sdp_kernel( 97 | enable_flash=True, enable_math=False, enable_mem_efficient=False 98 | ): 99 | logits = model_forward(model, input_ids) 100 | 101 | with torch.autocast(device_type="cuda", enabled=False): 102 | dtype_logits = logits.float() 103 | loss = F.binary_cross_entropy_with_logits( 104 | dtype_logits.view(-1).float(), 105 | labels.view(-1).float(), 106 | reduction="none", 107 | ) 108 | 109 | weights = weights.view(-1).float() 110 | loss = (loss * weights).sum() / (weights.sum() + 1e-6) 111 | return loss 112 | 113 | 114 | def encode_tokens(tokenizer, string, bos=True, device="cuda"): 115 | tokens = tokenizer.encode(string) 116 | if bos: 117 | tokens = [tokenizer.bos_id()] + tokens 118 | return torch.tensor(tokens, dtype=torch.int, device=device) 119 | 120 | 121 | def main( 122 | args: TrainingArguments, 123 | ) -> None: 124 | """Finetune a model on a given dataset.""" 125 | checkpoint_path = args.checkpoint_path 126 | sft_checkpoint_path = args.sft_checkpoint_path 127 | compile = args.compile 128 | assert checkpoint_path.is_file(), checkpoint_path 129 | 130 | if sft_checkpoint_path is not None: 131 | assert sft_checkpoint_path.is_dir(), sft_checkpoint_path 132 | 133 | tokenizer_path = checkpoint_path.parent / "tokenizer.model" 134 | if not tokenizer_path.is_file(): 135 | tokenizer_path = checkpoint_path.parent 136 | 137 | set_global_compile_mode(args.compile) 138 | 139 | global print 140 | device_id = maybe_init_dist() 141 | use_tp = device_id is not None 142 | if use_tp: 143 | tp_size = args.tensor_parallel_size or torch.distributed.get_world_size() 144 | initialize_model_parallel(tp_size) 145 | torch.distributed.barrier() 146 | tp_group = get_model_parallel_group() 147 | 148 | if device_id != 0: 149 | # only print on rank 0 150 | print = lambda *args, **kwargs: None 151 | 152 | if args.report_to == "wandb" and wandb is not None: 153 | if device_id == 0: 154 | wandb_logging_dir = os.path.join( 155 | tempfile.gettempdir(), f"{os.getuid()}_wandb" 156 | ) 157 | if not os.path.exists(wandb_logging_dir): 158 | os.makedirs(wandb_logging_dir, exist_ok=True) 159 | os.environ["WANDB_DIR"] = wandb_logging_dir 160 | wandb.init( 161 | name=args.wandb_name, 162 | project=args.wandb_project, 163 | entity=args.wandb_entity, 164 | resume="allow", 165 | magic=True, 166 | dir=wandb_logging_dir, 167 | force=True, 168 | ) 169 | wandb.config.update(vars(args)) 170 | 171 | device = "cuda" 172 | precision = args.param_dtype 173 | 174 | print("Loading model ...") 175 | t0 = time.time() 176 | 177 | resume_from_checkpoint = None 178 | resume_epoch = 0 179 | resume_global_step = 0 180 | 181 | if args.resume_from_checkpoint: 182 | ( 183 | resume_from_checkpoint, 184 | resume_epoch, 185 | resume_global_step, 186 | ) = get_latest_checkpoint_path(args.save_dir) 187 | 188 | if resume_from_checkpoint is not None: 189 | sft_checkpoint_path = None 190 | 191 | model = load_reward_model_from_sft_ckpt( 192 | checkpoint_path, 193 | sft_checkpoint_path, 194 | device, 195 | precision, 196 | use_tp, 197 | requires_grad=True, 198 | sequence_parallel=args.sequence_parallel, 199 | ) 200 | 201 | torch.cuda.synchronize() 202 | if use_tp: 203 | 
torch.distributed.barrier() 204 | print(f"Time to load model: {time.time() - t0:.02f} seconds") 205 | 206 | tokenizer = AcceleraTokenizer(tokenizer_path) 207 | 208 | data_module = make_pointwise_reward_modeling_data_module( 209 | tokenizer=tokenizer, 210 | args=args, 211 | ) 212 | train_dataset = data_module["train_dataset"] 213 | data_collator = data_module["data_collator"] 214 | 215 | model_size = sum( 216 | [ 217 | p.numel() * p.dtype.itemsize 218 | for p in itertools.chain(model.parameters(), model.buffers()) 219 | ] 220 | ) 221 | 222 | print(f"Model size: {model_size / 1e6:.02f} MB") 223 | manual_seed(args.seed) 224 | 225 | sampler = None 226 | if use_tp: 227 | sampler = DistributedSampler( 228 | train_dataset, 229 | shuffle=True, 230 | drop_last=True, 231 | ) 232 | 233 | train_loader = torch.utils.data.DataLoader( 234 | train_dataset, 235 | batch_size=args.per_device_train_batch_size, 236 | shuffle=(sampler is None), 237 | sampler=sampler, 238 | num_workers=0, 239 | pin_memory=True, 240 | collate_fn=data_collator, 241 | ) 242 | 243 | if args.print_training_examples: 244 | print("Training examples:") 245 | cnt = 16 246 | for batch in train_loader: 247 | print( 248 | "Input:", 249 | tokenizer.decode( 250 | batch["input_ids"][0].tolist(), skip_special_tokens=False 251 | ), 252 | ) 253 | cnt -= 1 254 | if cnt == 0: 255 | break 256 | 257 | if compile: 258 | model = torch.compile(model) 259 | 260 | trainable_param_names = [ 261 | name for name, param in model.named_parameters() if param.requires_grad 262 | ] 263 | 264 | use_fsdp = False 265 | 266 | if get_data_parallel_world_size() > 1: 267 | use_fsdp = True 268 | model = create_fsdp_model_for_finetune(args, model) 269 | print("Using FSDP ...") 270 | print(model) 271 | 272 | optimizer = create_optimizer( 273 | args, 274 | model=model, 275 | optimizer_cpu_offload=args.optimizer_cpu_offload, 276 | model_cpu_offload=False, 277 | ) 278 | 279 | scheduler = get_cosine_schedule_with_warmup( 280 | optimizer, 281 | warmup_epochs=len(train_loader) * args.warmup_ratio, 282 | max_epochs=len(train_loader) * args.num_train_epochs, 283 | warmup_start_ratio=0.0, 284 | eta_min_ratio=args.lr_eta_min / args.learning_rate, 285 | ) 286 | 287 | if resume_from_checkpoint is not None: 288 | print( 289 | f"Resuming from checkpoint: {resume_from_checkpoint} (epoch {resume_epoch}, global step {resume_global_step})" 290 | ) 291 | load_checkpoint( 292 | resume_from_checkpoint, model, optimizer, scheduler, use_fsdp=use_fsdp 293 | ) 294 | 295 | micro_train_batch_size = ( 296 | args.micro_train_batch_size or args.per_device_train_batch_size 297 | ) 298 | 299 | assert ( 300 | args.per_device_train_batch_size % micro_train_batch_size == 0 301 | ), f"per_device_train_batch_size ({args.per_device_train_batch_size}) must be divisible by micro_train_batch_size ({micro_train_batch_size})" 302 | accumulate_steps = args.per_device_train_batch_size // micro_train_batch_size 303 | 304 | print( 305 | "Batch size per GPU for training: {}\n".format( 306 | args.per_device_train_batch_size 307 | ), 308 | "Micro batch size for training: {}\n".format(micro_train_batch_size), 309 | "Gradient accumulation steps: {}\n".format(accumulate_steps), 310 | ) 311 | 312 | micro_train_batch_size = micro_train_batch_size * torch.distributed.get_world_size() 313 | 314 | epoch_length = len(train_loader) 315 | 316 | if args.do_train: 317 | print("Starting training ...") 318 | t0 = time.time() 319 | for epoch in tqdm.trange( 320 | args.num_train_epochs, desc="Epoch", disable=device_id != 0 321 | ): 322 | 
if sampler is not None: 323 | train_loader.sampler.set_epoch(epoch) 324 | pbar = tqdm.tqdm( 325 | enumerate(train_loader), 326 | desc="Iteration", 327 | disable=device_id != 0, 328 | total=len(train_loader), 329 | ) 330 | for it, batch in pbar: 331 | global_step = epoch * epoch_length + it 332 | if global_step < resume_global_step: 333 | continue 334 | 335 | # torch.cuda.synchronize() 336 | model.zero_grad() 337 | 338 | input_ids = batch["input_ids"].to(device=device) 339 | labels = batch["labels"].to(device=device) 340 | weights = batch["weights"].to(device=device) 341 | 342 | input_ids, labels, weights = prepare_batch( 343 | input_ids, 344 | labels, 345 | weights, 346 | tokenizer=tokenizer, 347 | use_tp=use_tp, 348 | sync_group=tp_group, 349 | ) 350 | 351 | loss_scale = 1.0 / accumulate_steps 352 | for ex_idx in range(0, input_ids.size(0), micro_train_batch_size): 353 | if ex_idx + micro_train_batch_size < input_ids.size(0): 354 | with torch.cuda.amp.autocast(dtype=args.compute_dtype): 355 | loss = model_forward_with_loss( 356 | model, 357 | input_ids[ex_idx : ex_idx + micro_train_batch_size], 358 | labels[ex_idx : ex_idx + micro_train_batch_size], 359 | weights[ex_idx : ex_idx + micro_train_batch_size], 360 | ) 361 | (loss_scale * loss).backward() 362 | else: 363 | with torch.cuda.amp.autocast(dtype=args.compute_dtype): 364 | loss = model_forward_with_loss( 365 | model, 366 | input_ids[ex_idx:], 367 | labels[ex_idx:], 368 | weights[ex_idx:], 369 | ) 370 | (loss_scale * loss).backward() 371 | grad_norm = clip_grad_norm_(model, 1.0) 372 | optimizer.step() 373 | scheduler.step() 374 | 375 | if it % 5 == 0: 376 | loss_copy = loss.detach().clone() 377 | torch.distributed.all_reduce(loss_copy) 378 | avg_loss = (loss_copy / torch.distributed.get_world_size()).item() 379 | grad_norm_copy = grad_norm.detach().clone().item() 380 | 381 | if device_id == 0: 382 | if args.report_to == "wandb" and wandb is not None: 383 | wandb.log( 384 | { 385 | "loss": avg_loss, 386 | "learning_rate": scheduler.get_last_lr()[0], 387 | "epoch": epoch, 388 | "step": it, 389 | "grad_norm": grad_norm_copy, 390 | }, 391 | step=global_step, 392 | ) 393 | else: 394 | # Just print to stdout. 
395 | print( 396 | { 397 | "loss": avg_loss, 398 | "learning_rate": scheduler.get_last_lr()[0], 399 | "epoch": epoch, 400 | "step": it, 401 | "grad_norm": grad_norm_copy, 402 | } 403 | ) 404 | 405 | checkpoint_hook( 406 | args, 407 | model, 408 | optimizer, 409 | scheduler, 410 | epoch, 411 | global_step, 412 | epoch_length, 413 | use_fsdp=use_fsdp, 414 | trainable_param_names=trainable_param_names, 415 | ) 416 | 417 | torch.cuda.synchronize() 418 | 419 | epoch = args.num_train_epochs 420 | 421 | checkpoint_hook( 422 | args, 423 | model, 424 | optimizer, 425 | scheduler, 426 | epoch, 427 | epoch * epoch_length, 428 | epoch_length, 429 | use_fsdp=use_fsdp, 430 | trainable_param_names=trainable_param_names, 431 | ) 432 | 433 | print(f"Time to train: {time.time() - t0:.02f} seconds") 434 | 435 | 436 | def prepare_batch(input_ids, labels, weights, tokenizer, use_tp, sync_group): 437 | pad_id = tokenizer.pad_id 438 | unk_id = tokenizer.unk_id 439 | # if pad_id < 0, replace pad_id with unk_id 440 | labels[labels == pad_id] = IGNORE_INDEX 441 | if pad_id < 0: 442 | input_ids[input_ids == pad_id] = unk_id 443 | 444 | if use_tp and get_model_parallel_world_size() > 1: 445 | # aggregate (concat) all the inputs across tp sync_group 446 | new_input_ids = torch.empty_like(input_ids).repeat(sync_group.size(), 1) 447 | new_labels = torch.empty_like(labels).repeat(sync_group.size(), 1) 448 | new_weights = torch.empty_like(weights).repeat(sync_group.size(), 1) 449 | 450 | torch.distributed.all_gather_into_tensor( 451 | new_input_ids, input_ids, group=sync_group 452 | ) 453 | torch.distributed.all_gather_into_tensor(new_labels, labels, group=sync_group) 454 | torch.distributed.all_gather_into_tensor(new_weights, weights, group=sync_group) 455 | 456 | return new_input_ids, new_labels, new_weights 457 | 458 | return input_ids, labels, weights 459 | 460 | 461 | if __name__ == "__main__": 462 | parser = HfArgumentParser((TrainingArguments,)) 463 | args = parser.parse_args_into_dataclasses()[0] 464 | main(args) 465 | -------------------------------------------------------------------------------- /train_sft.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The GPT-Accelera Team 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
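# Supervised fine-tuning entry point: next-token cross-entropy on (input_ids, labels)
# pairs, with optional tensor/sequence/vocab parallelism, FSDP across the
# data-parallel dimension, gradient accumulation over micro batches, and a cosine
# learning-rate schedule with warmup.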
7 | 8 | import os 9 | import sys 10 | import tempfile 11 | import time 12 | from pathlib import Path 13 | import itertools 14 | 15 | import tqdm 16 | 17 | import torch 18 | import torch.nn.functional as F 19 | from torch.utils.data.distributed import DistributedSampler 20 | 21 | import torch._inductor.config 22 | import torch._dynamo.config 23 | 24 | torch._inductor.config.coordinate_descent_tuning = True 25 | torch._inductor.config.triton.unique_kernel_names = True 26 | torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future 27 | 28 | torch.backends.cuda.matmul.allow_tf32 = True 29 | torch.backends.cudnn.allow_tf32 = True 30 | 31 | try: 32 | import wandb 33 | except ImportError: 34 | wandb = None 35 | 36 | # support running without installing as a package 37 | wd = Path(__file__).parent.parent.resolve() 38 | sys.path.append(str(wd)) 39 | 40 | from models.model import Transformer, set_global_compile_mode 41 | from models.tokenizer_utils import AcceleraTokenizer 42 | from models.tp import ( 43 | maybe_init_dist, 44 | initialize_model_parallel, 45 | get_model_parallel_group, 46 | get_model_parallel_world_size, 47 | get_data_parallel_world_size, 48 | clip_grad_norm_, 49 | compute_vocab_parallel_logprobs, 50 | ) 51 | 52 | from trainers.common_utils import manual_seed 53 | from data_utils.data_utils_sft import make_sft_data_module 54 | 55 | from hf_argparser import HfArgumentParser 56 | from arguments import Arguments as TrainingArguments 57 | from checkpoint_utils import ( 58 | checkpoint_hook, 59 | get_latest_checkpoint_path, 60 | load_checkpoint, 61 | load_model_from_from_ckpt, 62 | ) 63 | from training_utils.trainer_utils import ( 64 | create_optimizer, 65 | create_fsdp_model_for_finetune, 66 | get_cosine_schedule_with_warmup, 67 | ) 68 | 69 | IGNORE_INDEX = -100 70 | 71 | 72 | def model_forward(model, x, input_pos): 73 | return model(x, input_pos, fully_causal=True) 74 | 75 | 76 | def model_forward_with_loss( 77 | model: Transformer, 78 | input_ids: torch.Tensor, 79 | labels: torch.Tensor, 80 | ) -> torch.Tensor: 81 | """ 82 | Compute the loss for a given model and prompts. 
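    Logits are shifted one position against `labels` for next-token prediction, and
    label positions equal to IGNORE_INDEX (-100) do not contribute to the loss.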
83 | """ 84 | # create an empty tensor of the expected final shape and fill in the current tokens 85 | batch_size, T = input_ids.size(0), input_ids.size(1) 86 | 87 | device = input_ids.device 88 | with torch.device(device): 89 | model.setup_caches(max_batch_size=batch_size, max_seq_length=T, kv_cache=False) 90 | # create an empty tensor of the expected final shape and fill in the current tokens 91 | input_pos = torch.arange(0, T, device=device) 92 | 93 | with torch.backends.cuda.sdp_kernel( 94 | enable_flash=True, enable_math=False, enable_mem_efficient=False 95 | ): 96 | logits = model_forward(model, input_ids, input_pos) 97 | 98 | with torch.autocast(device_type="cuda", enabled=False): 99 | logits = logits.float() 100 | logits = logits[..., :-1, :].contiguous() 101 | labels = labels[..., 1:].contiguous() 102 | 103 | if model.vocab_parallel: 104 | loss = -compute_vocab_parallel_logprobs( 105 | logits.view(-1, logits.size(-1)), 106 | labels.view(-1), 107 | ignore_index=IGNORE_INDEX, 108 | reduction="mean", 109 | ) 110 | else: 111 | loss = F.cross_entropy( 112 | logits.view(-1, logits.size(-1)), 113 | labels.view(-1), 114 | reduction="mean", 115 | ) 116 | return loss 117 | 118 | 119 | def encode_tokens(tokenizer, string, bos=True, device="cuda"): 120 | tokens = tokenizer.encode(string) 121 | if bos: 122 | tokens = [tokenizer.bos_id()] + tokens 123 | return torch.tensor(tokens, dtype=torch.int, device=device) 124 | 125 | 126 | def main( 127 | args: TrainingArguments, 128 | ) -> None: 129 | """Finetune a model on a given dataset.""" 130 | checkpoint_path = args.checkpoint_path 131 | sft_checkpoint_path = args.sft_checkpoint_path 132 | compile = args.compile 133 | assert checkpoint_path.is_file(), checkpoint_path 134 | 135 | if sft_checkpoint_path is not None: 136 | assert sft_checkpoint_path.is_dir(), sft_checkpoint_path 137 | 138 | tokenizer_path = checkpoint_path.parent / "tokenizer.model" 139 | if not tokenizer_path.is_file(): 140 | tokenizer_path = checkpoint_path.parent 141 | 142 | set_global_compile_mode(compile) 143 | 144 | global print 145 | device_id = maybe_init_dist() 146 | use_tp = device_id is not None 147 | if use_tp: 148 | group_size = args.tensor_parallel_size or torch.distributed.get_world_size() 149 | initialize_model_parallel(group_size) 150 | torch.distributed.barrier() 151 | intra_node_group = get_model_parallel_group() 152 | 153 | if device_id != 0: 154 | # only print on rank 0 155 | print = lambda *args, **kwargs: None 156 | 157 | if args.report_to == "wandb" and wandb is not None: 158 | if device_id == 0: 159 | wandb_logging_dir = os.path.join( 160 | tempfile.gettempdir(), f"{os.getuid()}_wandb" 161 | ) 162 | if not os.path.exists(wandb_logging_dir): 163 | os.makedirs(wandb_logging_dir, exist_ok=True) 164 | os.environ["WANDB_DIR"] = wandb_logging_dir 165 | wandb.init( 166 | name=args.wandb_name, 167 | project=args.wandb_project, 168 | entity=args.wandb_entity, 169 | resume="allow", 170 | magic=True, 171 | dir=wandb_logging_dir, 172 | force=True, 173 | ) 174 | wandb.config.update(vars(args)) 175 | 176 | device = "cuda" 177 | precision = args.param_dtype 178 | 179 | print("Loading model ...") 180 | t0 = time.time() 181 | model = load_model_from_from_ckpt( 182 | checkpoint_path, 183 | sft_checkpoint_path, 184 | device, 185 | precision, 186 | use_tp, 187 | requires_grad=True, 188 | sequence_parallel=args.sequence_parallel, 189 | vocab_parallel=args.vocab_parallel, 190 | ) 191 | 192 | torch.cuda.synchronize() 193 | print(f"Time to load model: {time.time() - t0:.02f} 
seconds") 194 | 195 | tokenizer = AcceleraTokenizer(tokenizer_path) 196 | 197 | data_module = make_sft_data_module( 198 | tokenizer=tokenizer, 199 | args=args, 200 | ) 201 | train_dataset = data_module["train_dataset"] 202 | data_collator = data_module["data_collator"] 203 | 204 | model_size = sum( 205 | [ 206 | p.numel() * p.dtype.itemsize 207 | for p in itertools.chain(model.parameters(), model.buffers()) 208 | ] 209 | ) 210 | 211 | print(f"Model size: {model_size / 1e6:.02f} MB") 212 | manual_seed(args.seed) 213 | 214 | sampler = None 215 | if use_tp: 216 | sampler = DistributedSampler( 217 | train_dataset, 218 | shuffle=True, 219 | drop_last=True, 220 | ) 221 | 222 | train_loader = torch.utils.data.DataLoader( 223 | train_dataset, 224 | batch_size=args.per_device_train_batch_size, 225 | shuffle=(sampler is None), 226 | sampler=sampler, 227 | num_workers=0, 228 | pin_memory=True, 229 | collate_fn=data_collator, 230 | ) 231 | 232 | if args.print_training_examples: 233 | print("Training examples:") 234 | cnt = 16 235 | for batch in train_loader: 236 | print( 237 | "Input:", 238 | tokenizer.decode( 239 | batch["input_ids"][0].tolist(), skip_special_tokens=False 240 | ), 241 | ) 242 | print( 243 | "Target:", 244 | tokenizer.decode( 245 | batch["labels"][0].tolist(), skip_special_tokens=False 246 | ), 247 | ) 248 | cnt -= 1 249 | if cnt == 0: 250 | break 251 | 252 | if compile: 253 | model = torch.compile(model) 254 | 255 | trainable_param_names = [ 256 | name for name, param in model.named_parameters() if param.requires_grad 257 | ] 258 | 259 | use_fsdp = False 260 | 261 | if get_data_parallel_world_size() > 1: 262 | use_fsdp = True 263 | model = create_fsdp_model_for_finetune(args, model) 264 | print("Using FSDP ...") 265 | print(model) 266 | 267 | optimizer = create_optimizer( 268 | args, 269 | model=model, 270 | optimizer_cpu_offload=args.optimizer_cpu_offload, 271 | model_cpu_offload=False, 272 | ) 273 | 274 | scheduler = get_cosine_schedule_with_warmup( 275 | optimizer, 276 | warmup_epochs=len(train_loader) * args.warmup_ratio, 277 | max_epochs=len(train_loader) * args.num_train_epochs, 278 | warmup_start_ratio=0.0, 279 | eta_min_ratio=args.lr_eta_min / args.learning_rate, 280 | ) 281 | 282 | resume_from_checkpoint = None 283 | resume_epoch = 0 284 | resume_global_step = 0 285 | 286 | if args.resume_from_checkpoint: 287 | ( 288 | resume_from_checkpoint, 289 | resume_epoch, 290 | resume_global_step, 291 | ) = get_latest_checkpoint_path(args.save_dir) 292 | 293 | if resume_from_checkpoint is not None: 294 | print( 295 | f"Resuming from checkpoint: {resume_from_checkpoint} (epoch {resume_epoch}, global step {resume_global_step})" 296 | ) 297 | load_checkpoint( 298 | resume_from_checkpoint, model, optimizer, scheduler, use_fsdp=use_fsdp 299 | ) 300 | 301 | micro_train_batch_size = ( 302 | args.micro_train_batch_size or args.per_device_train_batch_size 303 | ) 304 | 305 | assert ( 306 | args.per_device_train_batch_size % micro_train_batch_size == 0 307 | ), f"per_device_train_batch_size ({args.per_device_train_batch_size}) must be divisible by micro_train_batch_size ({micro_train_batch_size})" 308 | accumulate_steps = args.per_device_train_batch_size // micro_train_batch_size 309 | 310 | print( 311 | "Batch size per GPU for training: {}\n".format( 312 | args.per_device_train_batch_size 313 | ), 314 | "Micro batch size for training: {}\n".format(micro_train_batch_size), 315 | "Gradient accumulation steps: {}\n".format(accumulate_steps), 316 | ) 317 | 318 | micro_train_batch_size = 
micro_train_batch_size * torch.distributed.get_world_size() 319 | 320 | epoch_length = len(train_loader) 321 | 322 | if args.do_train: 323 | print("Starting training ...") 324 | t0 = time.time() 325 | for epoch in tqdm.trange( 326 | args.num_train_epochs, desc="Epoch", disable=device_id != 0 327 | ): 328 | if sampler is not None: 329 | train_loader.sampler.set_epoch(epoch) 330 | pbar = tqdm.tqdm( 331 | enumerate(train_loader), 332 | desc="Iteration", 333 | disable=device_id != 0, 334 | total=len(train_loader), 335 | ) 336 | for it, batch in pbar: 337 | global_step = epoch * epoch_length + it 338 | if global_step < resume_global_step: 339 | continue 340 | 341 | torch.cuda.synchronize() 342 | model.zero_grad() 343 | 344 | input_ids = batch["input_ids"].to(device=device) 345 | labels = batch["labels"].to(device=device) 346 | 347 | input_ids, labels = prepare_batch( 348 | input_ids, 349 | labels, 350 | tokenizer=tokenizer, 351 | use_tp=use_tp, 352 | sync_group=intra_node_group, 353 | ) 354 | 355 | loss_scale = 1.0 / accumulate_steps 356 | for ex_idx in range(0, input_ids.size(0), micro_train_batch_size): 357 | if ex_idx + micro_train_batch_size < input_ids.size(0): 358 | with torch.cuda.amp.autocast(dtype=args.compute_dtype): 359 | loss = model_forward_with_loss( 360 | model, 361 | input_ids[ex_idx : ex_idx + micro_train_batch_size], 362 | labels[ex_idx : ex_idx + micro_train_batch_size], 363 | ) 364 | (loss_scale * loss).backward() 365 | else: 366 | with torch.cuda.amp.autocast(dtype=args.compute_dtype): 367 | loss = model_forward_with_loss( 368 | model, 369 | input_ids[ex_idx:], 370 | labels[ex_idx:], 371 | ) 372 | (loss_scale * loss).backward() 373 | grad_norm = clip_grad_norm_(model, 1.0) 374 | optimizer.step() 375 | scheduler.step() 376 | 377 | if it % 5 == 0: 378 | loss_copy = loss.detach().clone() 379 | torch.distributed.all_reduce(loss_copy) 380 | avg_loss = (loss_copy / torch.distributed.get_world_size()).item() 381 | grad_norm_copy = grad_norm.detach().clone().item() 382 | 383 | if device_id == 0: 384 | if args.report_to == "wandb" and wandb is not None: 385 | wandb.log( 386 | { 387 | "loss": avg_loss, 388 | "learning_rate": scheduler.get_last_lr()[0], 389 | "epoch": epoch, 390 | "step": it, 391 | "grad_norm": grad_norm_copy, 392 | }, 393 | step=global_step, 394 | ) 395 | else: 396 | # Just print to stdout. 
397 | print( 398 | { 399 | "loss": avg_loss, 400 | "learning_rate": scheduler.get_last_lr()[0], 401 | "epoch": epoch, 402 | "step": it, 403 | "grad_norm": grad_norm_copy, 404 | } 405 | ) 406 | 407 | checkpoint_hook( 408 | args, 409 | model, 410 | optimizer, 411 | scheduler, 412 | epoch, 413 | global_step, 414 | epoch_length, 415 | use_fsdp=use_fsdp, 416 | trainable_param_names=trainable_param_names, 417 | ) 418 | 419 | torch.cuda.synchronize() 420 | 421 | epoch = args.num_train_epochs 422 | 423 | checkpoint_hook( 424 | args, 425 | model, 426 | optimizer, 427 | scheduler, 428 | epoch, 429 | epoch * epoch_length, 430 | epoch_length, 431 | use_fsdp=use_fsdp, 432 | trainable_param_names=trainable_param_names, 433 | ) 434 | 435 | print(f"Time to train: {time.time() - t0:.02f} seconds") 436 | 437 | 438 | def prepare_batch(input_ids, labels, tokenizer, use_tp, sync_group): 439 | pad_id = tokenizer.pad_id 440 | unk_id = tokenizer.unk_id 441 | # if pad_id < 0, replace pad_id with unk_id 442 | labels[labels == pad_id] = IGNORE_INDEX 443 | if pad_id < 0: 444 | input_ids[input_ids == pad_id] = unk_id 445 | 446 | if use_tp and get_model_parallel_world_size() > 1: 447 | # aggregate (concat) all the inputs across tp sync_group 448 | new_input_ids = torch.empty_like(input_ids).repeat(sync_group.size(), 1) 449 | new_labels = torch.empty_like(labels).repeat(sync_group.size(), 1) 450 | 451 | torch.distributed.all_gather_into_tensor( 452 | new_input_ids, input_ids, group=sync_group 453 | ) 454 | torch.distributed.all_gather_into_tensor(new_labels, labels, group=sync_group) 455 | 456 | return new_input_ids, new_labels 457 | 458 | return input_ids, labels 459 | 460 | 461 | if __name__ == "__main__": 462 | parser = HfArgumentParser((TrainingArguments,)) 463 | args = parser.parse_args_into_dataclasses()[0] 464 | main(args) 465 | -------------------------------------------------------------------------------- /trainers/common_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The GPT-Accelera Team 2 | # Copyright 2023 The Alpaca Team 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
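# Shared helpers for the trainers: small sequence/dict utilities, tensor padding,
# input preparation, seeding, an infinite DataLoader wrapper, and a simple stacking
# data collator.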
15 | 16 | import argparse 17 | from dataclasses import dataclass 18 | import os 19 | import tempfile 20 | import random 21 | from typing import Callable, Dict, Optional, Sequence, Union, Mapping, Any, Tuple 22 | import logging 23 | 24 | import numpy as np 25 | import torch 26 | import torch.nn.functional as F 27 | import torch.distributed as dist 28 | 29 | from torch.utils.data import Dataset 30 | from torch.utils.data import DataLoader 31 | from torch.utils.data import random_split 32 | 33 | from datasets import load_dataset 34 | 35 | Numeric = Union[int, float] 36 | 37 | 38 | def zip_(*args: Sequence): 39 | """Assert sequences of same length before zipping.""" 40 | if len(args) == 0: 41 | return [] 42 | assert alleq(args, lambda x, y: len(x) == len(y)) 43 | return zip(*args) 44 | 45 | 46 | def mean(*seqs: Sequence[Numeric]) -> Union[Numeric, Sequence[Numeric]]: 47 | singleton = len(seqs) == 1 48 | means = [float(np.mean(seq)) for seq in seqs] 49 | return means[0] if singleton else means 50 | 51 | 52 | def alleq(l: Sequence, f: Optional[Callable] = lambda x, y: x == y): 53 | """Check all arguments in a sequence are equal according to a given criterion. 54 | 55 | Args: 56 | f: A bi-variate boolean function. 57 | l: A list/tuple. 58 | 59 | Returns: 60 | True if everything is equal; otherwise False. 61 | """ 62 | return all(f(l[0], li) for li in l[1:]) 63 | 64 | 65 | def flatten_dict(nested, sep=".", postprocess_fn=lambda *args: args): 66 | def rec(nest, prefix, into): 67 | for k, v in nest.items(): 68 | if sep in k: 69 | raise ValueError(f"separator '{sep}' not allowed to be in key '{k}'") 70 | if isinstance(v, dict): # collections.Mapping fails in py3.10. 71 | rec(v, prefix + k + sep, into) 72 | else: 73 | v = postprocess_fn(v) 74 | into[prefix + k] = v 75 | 76 | flat = {} 77 | rec(nested, "", flat) 78 | return flat 79 | 80 | 81 | def unpack_dict( 82 | d: Dict, keys: Sequence[str], return_type: type = tuple 83 | ) -> Union[Sequence, Dict]: 84 | if return_type in (tuple, list): 85 | return return_type(d[key] for key in keys) 86 | elif return_type == dict: 87 | return {key: d[key] for key in keys} 88 | else: 89 | raise ValueError(f"Unknown return_type: {return_type}") 90 | 91 | 92 | def merge_dict(dicts: Sequence[dict], merge_fn: Callable = lambda *args: args) -> dict: 93 | """Merge a sequence of dicts (with the same set of keys) into a single dict.""" 94 | if len(dicts) == 0: 95 | return dict() 96 | return {key: merge_fn([dict_[key] for dict_ in dicts]) for key in dicts[0].keys()} 97 | 98 | 99 | def prepare_inputs( 100 | data: Union[torch.Tensor, Any], device: Union[str, int, torch.device] 101 | ) -> Union[torch.Tensor, Any]: 102 | if isinstance(data, Mapping): 103 | return type(data)( 104 | {k: prepare_inputs(v, device) for k, v in data.items()} 105 | ) # noqa 106 | elif isinstance(data, (tuple, list)): 107 | return type(data)(prepare_inputs(v, device) for v in data) 108 | elif isinstance(data, torch.Tensor): 109 | return data.to(device) # This can break with deepspeed. 
110 | return data 111 | 112 | 113 | def pad_inputs_on_batch( 114 | data: Sequence[torch.Tensor], per_device_batch_size: int 115 | ) -> Sequence[torch.Tensor]: 116 | batch_size = None 117 | output_tensors = [] 118 | for tensor in data: 119 | if batch_size is None: 120 | batch_size = tensor.size(0) 121 | assert tensor.size(0) == batch_size 122 | 123 | if batch_size % per_device_batch_size != 0: 124 | filled_size = per_device_batch_size - (batch_size % per_device_batch_size) 125 | tensor = torch.cat( 126 | [ 127 | tensor, 128 | tensor[0:1].expand(filled_size, *tensor.size()[1:]), 129 | ], 130 | dim=0, 131 | ) 132 | output_tensors.append(tensor) 133 | return output_tensors 134 | 135 | 136 | def pad( 137 | inputs: torch.Tensor, 138 | target_size: Union[torch.Size, Sequence[int]], 139 | value=0.0, 140 | left=True, 141 | ): 142 | current_size = inputs.size() 143 | diffs = tuple(ti - ci for ti, ci in zip_(target_size, current_size)) 144 | pad_params = [] 145 | for diff in diffs: 146 | pad_params = ([diff, 0] if left else [0, diff]) + pad_params 147 | res = F.pad(inputs, pad=pad_params, value=value) 148 | return res 149 | 150 | 151 | def left_pad( 152 | inputs: torch.Tensor, target_size: Union[torch.Size, Sequence[int]], value=0.0 153 | ): 154 | return pad(inputs=inputs, target_size=target_size, value=value, left=True) 155 | 156 | 157 | def right_pad( 158 | inputs: torch.Tensor, target_size: Union[torch.Size, Sequence[int]], value=0.0 159 | ): 160 | return pad(inputs=inputs, target_size=target_size, value=value, left=False) 161 | 162 | 163 | def manual_seed(args_or_seed: Union[int, argparse.Namespace], fix_cudnn=False): 164 | if hasattr(args_or_seed, "seed"): 165 | args_or_seed = args_or_seed.seed 166 | random.seed(args_or_seed) 167 | np.random.seed(args_or_seed) 168 | torch.manual_seed(args_or_seed) 169 | torch.cuda.manual_seed_all(args_or_seed) 170 | os.environ["PYTHONHASHSEED"] = str(args_or_seed) 171 | if fix_cudnn: 172 | torch.backends.cudnn.deterministic = True # noqa 173 | torch.backends.cudnn.benchmark = False # noqa 174 | 175 | 176 | class InfiniteLoader(object): 177 | """Wraps an existing DataLoader so that it outputs stuff indefinitely; useful for semi-supervised learning and DDP.""" 178 | 179 | def __init__(self, loader: DataLoader): 180 | super(InfiniteLoader, self).__init__() 181 | self.loader = loader 182 | self.data_iterator = iter(loader) 183 | self.epoch = 0 184 | 185 | def __iter__(self): 186 | return self 187 | 188 | def __next__(self): 189 | try: 190 | data = next(self.data_iterator) 191 | except StopIteration: 192 | # Increment the epoch count 193 | self.epoch += 1 194 | 195 | # If using Distributed Data Parallel, set the epoch for the sampler 196 | if dist.is_initialized(): 197 | self.loader.sampler.set_epoch(self.epoch) 198 | 199 | # Create a new iterator for the next epoch 200 | self.data_iterator = iter(self.loader) 201 | data = next(self.data_iterator) 202 | 203 | return data 204 | 205 | 206 | class DisableLogger: 207 | def __enter__(self): 208 | logging.disable(logging.CRITICAL) 209 | 210 | def __exit__(self, exit_type, exit_value, exit_traceback): 211 | logging.disable(logging.NOTSET) 212 | 213 | 214 | @dataclass 215 | class DataCollatorForStackableDataset(object): 216 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 217 | return { 218 | key: torch.stack([instance[key] for instance in instances]) 219 | for key in instances[0].keys() 220 | } 221 | 222 | 223 | def local_dataset(dataset_name): 224 | if dataset_name.endswith(".json"): 225 | 
full_dataset = load_dataset(
226 | "json",
227 | data_files=dataset_name,
228 | cache_dir=os.path.join(
229 | tempfile.gettempdir(), f"{os.getuid()}_cache", "huggingface", "datasets"
230 | ),
231 | )
232 | else:
233 | raise ValueError(f"Unsupported dataset format: {dataset_name}")
234 |
235 | return full_dataset
236 |
237 |
238 | def _get_generator(seed: int) -> torch.Generator:
239 | rng = torch.Generator()
240 | rng.manual_seed(seed)
241 | return rng
242 |
243 |
244 | def split_train_into_train_and_eval(
245 | train_dataset: Dataset, eval_size: int, seed: int
246 | ) -> Tuple[Dataset, Dataset]:
247 | assert eval_size < len(
248 | train_dataset # noqa
249 | ), "Requested eval_size cannot be equal to or larger than the original train data size."
250 | new_train_size = len(train_dataset) - eval_size # noqa
251 | train_dataset, eval_dataset = random_split(
252 | train_dataset, [new_train_size, eval_size], generator=_get_generator(seed)
253 | )
254 | return train_dataset, eval_dataset
255 |
--------------------------------------------------------------------------------
/training_utils/memory_efficient_adam.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The GPT-Accelera Team
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import math
8 |
9 | from collections import defaultdict
10 | from copy import deepcopy
11 | from itertools import chain
12 | from typing import Any, DefaultDict, Dict, Iterable
13 |
14 | import torch
15 | from torch.optim import Optimizer
16 | from torch.optim.optimizer import StateDict
17 |
18 |
19 | class MemoryEfficientAdamW(Optimizer):
20 | """
21 | Arguments:
22 | model_params (iterable): iterable of parameters or dicts defining
23 | parameter groups.
24 | lr (float, optional): learning rate. (default: 1e-3)
25 | betas (Tuple[float, float], optional): coefficients used for computing
26 | running averages of the gradient and its square. (default: (0.9, 0.999))
27 | eps (float, optional): term added to the denominator to improve
28 | numerical stability. (default: 1e-8)
29 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
30 | adamw_mode (boolean, optional): apply L2 regularization or weight decay;
31 | True for decoupled weight decay (also known as AdamW). (default: True)
32 |
33 | .. _Adam\: A Method for Stochastic Optimization:
34 | https://arxiv.org/abs/1412.6980
35 | ..
_On the Convergence of Adam and Beyond:
36 | https://openreview.net/forum?id=ryQu7f-RZ
37 | """
38 |
39 | def __init__(
40 | self,
41 | model_params,
42 | lr=1e-3,
43 | betas=(0.9, 0.999),
44 | eps=1e-8,
45 | weight_decay=0,
46 | adamw_mode=True,
47 | optim_dtype=torch.bfloat16,
48 | optim_device=torch.device("cpu"),
49 | ):
50 | default_args = dict(
51 | lr=lr,
52 | betas=betas,
53 | eps=eps,
54 | weight_decay=weight_decay,
55 | )
56 | super(MemoryEfficientAdamW, self).__init__(model_params, default_args)
57 | self.adamw_mode = adamw_mode
58 | self.optim_dtype = optim_dtype
59 | self.optim_device = optim_device
60 |
61 | def torch_adam_update_cpu(
62 | self,
63 | data,
64 | grad,
65 | exp_avg,
66 | exp_avg_sq,
67 | lr,
68 | beta1,
69 | beta2,
70 | eps,
71 | weight_decay,
72 | bias_correction1,
73 | bias_correction2,
74 | use_adamw=False,
75 | ):
76 | assert data.dtype == grad.dtype
77 | if weight_decay != 0:
78 | if use_adamw:
79 | data.mul_(1 - lr * weight_decay)
80 | else:
81 | grad = grad.add(data, alpha=weight_decay)
82 |
83 | non_blocking = self.optim_device.type == "cpu"
84 |
85 | exp_avg_cuda, exp_avg_sq_cuda = (
86 | exp_avg.to(data.device, non_blocking=non_blocking),
87 | exp_avg_sq.to(data.device, non_blocking=non_blocking),
88 | )
89 |
90 | dtype_grad = grad.to(dtype=self.optim_dtype)
91 | exp_avg_cuda.mul_(beta1).add_(dtype_grad, alpha=1 - beta1)
92 | exp_avg_sq_cuda.mul_(beta2).addcmul_(dtype_grad, dtype_grad, value=1 - beta2)
93 | denom_cuda = (exp_avg_sq_cuda.sqrt() / math.sqrt(bias_correction2)).add_(eps)
94 |
95 | step_size = lr / bias_correction1
96 | data.addcdiv_(
97 | exp_avg_cuda.to(dtype=data.dtype),
98 | denom_cuda.to(dtype=data.dtype),
99 | value=-step_size,
100 | )
101 |
102 | # Write the updated moment estimates back to the optimizer-state device (CPU when offloading).
103 | exp_avg.copy_(exp_avg_cuda, non_blocking=non_blocking)
104 | exp_avg_sq.copy_(exp_avg_sq_cuda, non_blocking=non_blocking)
105 |
106 | @torch.no_grad()
107 | def step(self, closure=None):
108 | loss = None
109 | if closure is not None:
110 | with torch.enable_grad():
111 | loss = closure()
112 |
113 | for _, group in enumerate(self.param_groups):
114 | for _, p in enumerate(group["params"]):
115 | if p.grad is None:
116 | continue
117 |
118 | state = self.state[p]
119 | assert (
120 | p.device.type == "cuda"
121 | ), "MemoryEfficientAdamW assumes all parameters are on CUDA"
122 | if len(state) == 0:
123 | state["step"] = 0
124 | # gradient momentums
125 | state["exp_avg"] = torch.zeros_like(
126 | p,
127 | device=self.optim_device,
128 | dtype=self.optim_dtype,
129 | )
130 | # gradient variances
131 | state["exp_avg_sq"] = torch.zeros_like(
132 | p,
133 | device=self.optim_device,
134 | dtype=self.optim_dtype,
135 | )
136 | if self.optim_device.type == "cpu":
137 | state["exp_avg"] = state["exp_avg"].pin_memory()
138 | state["exp_avg_sq"] = state["exp_avg_sq"].pin_memory()
139 |
140 | state["step"] += 1
141 | beta1, beta2 = group["betas"]
142 |
143 | assert (
144 | p.data.numel() == p.grad.data.numel()
145 | ), "parameter and gradient should have the same size"
146 | assert (
147 | state["exp_avg"].device.type == self.optim_device.type
148 | ), f"exp_avg should stay on {self.optim_device.type}"
149 | assert (
150 | state["exp_avg_sq"].device.type == self.optim_device.type
151 | ), f"exp_avg_sq should stay on {self.optim_device.type}"
152 | bias_correction1 = 1 - beta1 ** state["step"]
153 | bias_correction2 = 1 - beta2 ** state["step"]
154 | self.torch_adam_update_cpu(
155 | p.data,
156 | p.grad.data,
157 | state["exp_avg"],
158 | state["exp_avg_sq"],
159 | group["lr"],
160 | beta1, 161 | beta2, 162 | group["eps"], 163 | group["weight_decay"], 164 | bias_correction1, 165 | bias_correction2, 166 | self.adamw_mode, 167 | ) 168 | return loss 169 | 170 | @torch._disable_dynamo 171 | def load_state_dict(self, state_dict: StateDict) -> None: 172 | r"""Loads the optimizer state. 173 | 174 | Args: 175 | state_dict (dict): optimizer state. Should be an object returned 176 | from a call to :meth:`state_dict`. 177 | """ 178 | # shallow copy, to be consistent with module API 179 | state_dict = state_dict.copy() 180 | 181 | for pre_hook in self._optimizer_load_state_dict_pre_hooks.values(): 182 | hook_result = pre_hook(self, state_dict) 183 | if hook_result is not None: 184 | state_dict = hook_result 185 | 186 | # Validate the state_dict 187 | groups = self.param_groups 188 | 189 | # Deepcopy as we write into saved_groups later to update state 190 | saved_groups = deepcopy(state_dict["param_groups"]) 191 | 192 | if len(groups) != len(saved_groups): 193 | raise ValueError( 194 | "loaded state dict has a different number of " "parameter groups" 195 | ) 196 | param_lens = (len(g["params"]) for g in groups) 197 | saved_lens = (len(g["params"]) for g in saved_groups) 198 | if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): 199 | raise ValueError( 200 | "loaded state dict contains a parameter group " 201 | "that doesn't match the size of optimizer's group" 202 | ) 203 | 204 | # Update the state 205 | id_map = dict( 206 | zip( 207 | chain.from_iterable(g["params"] for g in saved_groups), 208 | chain.from_iterable(g["params"] for g in groups), 209 | ) 210 | ) 211 | 212 | def _cast(param, value, param_id=None, param_groups=None, key=None): 213 | r"""Make a deep copy of value, casting all tensors to device of param.""" 214 | if isinstance(value, torch.Tensor): 215 | if param.is_floating_point(): 216 | casted_value = value.to( 217 | dtype=self.optim_dtype, device=self.optim_device 218 | ) 219 | if self.optim_device.type == "cpu": 220 | casted_value = casted_value.pin_memory() 221 | else: 222 | casted_value = Optimizer._process_value_according_to_param_policy( 223 | param, value, param_id, param_groups, key 224 | ) 225 | return casted_value 226 | elif isinstance(value, dict): 227 | return { 228 | k: _cast( 229 | param, v, param_id=param_id, param_groups=param_groups, key=k 230 | ) 231 | for k, v in value.items() 232 | } 233 | elif isinstance(value, Iterable): 234 | return type(value)(_cast(param, v, param_id=param_id, param_groups=param_groups) for v in value) # type: ignore[call-arg] 235 | else: 236 | return value 237 | 238 | # Copy state assigned to params (and cast tensors to appropriate types). 239 | # State that is not assigned to params is copied as is (needed for 240 | # backward compatibility). 
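# Note: in contrast to the stock ``torch.optim.Optimizer.load_state_dict``, the ``_cast`` helper
# above moves state tensors that belong to floating-point parameters (e.g. ``exp_avg`` and
# ``exp_avg_sq``) to ``self.optim_device`` / ``self.optim_dtype`` and pins them when the state is
# kept on CPU, so a resumed run preserves the memory-efficient state layout.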
241 | state: DefaultDict[torch.Tensor, Dict[Any, Any]] = defaultdict(dict) 242 | for k, v in state_dict["state"].items(): 243 | if k in id_map: 244 | param = id_map[k] 245 | state[param] = _cast( 246 | param, v, param_id=k, param_groups=state_dict["param_groups"] 247 | ) 248 | else: 249 | state[k] = v 250 | 251 | # Update parameter groups, setting their 'params' value 252 | def update_group( 253 | group: Dict[str, Any], new_group: Dict[str, Any] 254 | ) -> Dict[str, Any]: 255 | new_group["params"] = group["params"] 256 | return new_group 257 | 258 | param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)] 259 | self.__setstate__({"state": state, "param_groups": param_groups}) 260 | 261 | for post_hook in self._optimizer_load_state_dict_post_hooks.values(): 262 | post_hook(self) 263 | -------------------------------------------------------------------------------- /training_utils/trainer_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The GPT-Accelera Team 2 | # Copyright 2023 The Alpaca Team 3 | # Copyright 2022 The HuggingFace Team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import math 18 | from functools import partial 19 | 20 | import torch 21 | import torch.nn as nn 22 | 23 | import torch.optim as optim 24 | from torch.optim.lr_scheduler import LambdaLR 25 | 26 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 27 | from torch.distributed.fsdp import MixedPrecision 28 | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy 29 | 30 | from training_utils.memory_efficient_adam import MemoryEfficientAdamW 31 | from arguments import Arguments 32 | 33 | from models.model import TransformerBlock 34 | from models.tp import get_data_parallel_group, get_data_parallel_world_size 35 | 36 | 37 | def create_optimizer( 38 | args: Arguments, 39 | model: nn.Module, 40 | optimizer_cpu_offload: bool = False, 41 | model_cpu_offload: bool = False, 42 | ) -> optim.Optimizer: 43 | if not model_cpu_offload: 44 | model_device = next(iter(model.parameters())).device 45 | 46 | optimizer = MemoryEfficientAdamW( 47 | [p for p in model.parameters() if p.requires_grad], 48 | lr=args.learning_rate, 49 | betas=(args.adam_beta1, args.adam_beta2), 50 | eps=args.adam_eps, 51 | weight_decay=args.weight_decay, 52 | optim_dtype=args.optim_dtype, 53 | optim_device=( 54 | torch.device("cpu") if optimizer_cpu_offload else model_device 55 | ), 56 | ) 57 | else: 58 | optimizer = torch.optim.AdamW( 59 | [p for p in model.parameters() if p.requires_grad], 60 | lr=args.learning_rate, 61 | betas=(args.adam_beta1, args.adam_beta2), 62 | eps=args.adam_eps, 63 | weight_decay=args.weight_decay, 64 | fused=True, 65 | ) 66 | 67 | return optimizer 68 | 69 | 70 | def create_fsdp_model_for_finetune( 71 | args: Arguments, 72 | model: nn.Module, 73 | bf16_all_reduce_upper_bound: int = 16, 74 | ) -> FSDP: 75 | model = FSDP( 76 | module=model, 77 | 
process_group=get_data_parallel_group(),
78 | auto_wrap_policy=partial(
79 | transformer_auto_wrap_policy,
80 | transformer_layer_cls={
81 | TransformerBlock,
82 | },
83 | ),
84 | mixed_precision=MixedPrecision(
85 | param_dtype=args.compute_dtype,
86 | reduce_dtype=(
87 | torch.float32
88 | if get_data_parallel_world_size() >= bf16_all_reduce_upper_bound
89 | else args.compute_dtype
90 | ),
91 | keep_low_precision_grads=(args.optim_dtype != torch.float32),
92 | buffer_dtype=args.compute_dtype,
93 | ),
94 | cpu_offload=False,
95 | use_orig_params=False,
96 | forward_prefetch=True,
97 | limit_all_gathers=True,
98 | )
99 | return model
100 |
101 |
102 | # https://github.com/huggingface/transformers/blob/976189a6df796a2ff442dd81b022626c840d8c27/src/transformers/optimization.py
103 | def _get_cosine_schedule_with_warmup_lr_lambda(
104 | current_step: int,
105 | *,
106 | num_warmup_steps: int,
107 | num_training_steps: int,
108 | warmup_start_ratio: float,
109 | eta_min_ratio: float,
110 | ):
111 | if current_step < num_warmup_steps:
112 | return warmup_start_ratio + (1.0 - warmup_start_ratio) * float(
113 | current_step
114 | ) / float(max(1, num_warmup_steps))
115 |
116 | progress = float(current_step - num_warmup_steps) / float(
117 | max(1, num_training_steps - num_warmup_steps)
118 | )
119 | return eta_min_ratio + (1.0 - eta_min_ratio) * max(
120 | 0.0, 0.5 * (1.0 + math.cos(math.pi * progress))
121 | )
122 |
123 |
124 | def get_cosine_schedule_with_warmup(
125 | optimizer: optim.Optimizer,
126 | warmup_epochs: int,
127 | max_epochs: int,
128 | warmup_start_ratio: float = 0.0,
129 | eta_min_ratio: float = 0.0,
130 | last_epoch: int = -1,
131 | ):
132 | """
133 | Create a schedule whose learning rate increases linearly from `warmup_start_ratio` times the
134 | initial lr set in the optimizer up to that initial lr over the warmup period, then decreases
135 | following a cosine curve down to `eta_min_ratio` times the initial lr.
136 |
137 | Return:
138 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
139 | """
140 |
141 | assert 0.0 <= warmup_start_ratio <= 1.0, "warmup_start_ratio should be in [0, 1]"
142 | assert 0.0 <= eta_min_ratio <= 1.0, "eta_min_ratio should be in [0, 1]"
143 |
144 | lr_lambda = partial(
145 | _get_cosine_schedule_with_warmup_lr_lambda,
146 | num_warmup_steps=warmup_epochs,
147 | num_training_steps=max_epochs,
148 | warmup_start_ratio=warmup_start_ratio,
149 | eta_min_ratio=eta_min_ratio,
150 | )
151 | return LambdaLR(optimizer, lr_lambda, last_epoch)
152 |
--------------------------------------------------------------------------------
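Usage sketch: a minimal example of how `create_optimizer` and `get_cosine_schedule_with_warmup` from `training_utils/trainer_utils.py` might be wired together. `DummyArgs` is a hypothetical stand-in for the repository's `Arguments` dataclass and covers only the fields that `create_optimizer` reads; the real training entry points build their args with `HfArgumentParser` instead.

from dataclasses import dataclass

import torch
import torch.nn as nn

from training_utils.trainer_utils import create_optimizer, get_cosine_schedule_with_warmup


@dataclass
class DummyArgs:
    # Hypothetical stand-in: only the fields consumed by create_optimizer are defined.
    learning_rate: float = 1e-5
    adam_beta1: float = 0.9
    adam_beta2: float = 0.95
    adam_eps: float = 1e-8
    weight_decay: float = 0.0
    optim_dtype: torch.dtype = torch.bfloat16


args = DummyArgs()
model = nn.Linear(16, 16).cuda()  # stand-in model; MemoryEfficientAdamW expects CUDA parameters

# AdamW with optimizer states held in pinned CPU memory (MemoryEfficientAdamW under the hood).
optimizer = create_optimizer(args, model, optimizer_cpu_offload=True)

# Linear warmup over the first 100 steps, then cosine decay to 10% of the peak learning rate.
scheduler = get_cosine_schedule_with_warmup(
    optimizer, warmup_epochs=100, max_epochs=1_000, eta_min_ratio=0.1
)

for _ in range(1_000):
    loss = model(torch.randn(4, 16, device="cuda")).sum()
    loss.backward()
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()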