├── .gitignore ├── LICENSE ├── MCSD ├── README.md ├── evaluation.py ├── inference │ ├── __init__.py │ ├── generate.py │ └── strategies.py └── model │ ├── __init__.py │ └── llama_tree_attn │ ├── __init__.py │ ├── configuration_llama.py │ ├── convert_llama_weights_to_hf.py │ ├── modeling_llama.py │ ├── tokenization_llama.py │ └── tokenization_llama_fast.py ├── README.md └── dataset └── wmt_ende.json /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 NJUNLP 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MCSD/README.md: -------------------------------------------------------------------------------- 1 | # Source Code for Multi-Candidate Speculative Decoding 2 | 3 | We provide Python application interfaces for inference, as well as command-line interfaces for evaluation. 4 | 5 | ## Dependencies 6 | 7 | PyTorch version >= 1.11.0 8 | 9 | Python version >= 3.8 10 | 11 | transformers >= 4.34.0 12 | 13 | ## Evaluation CLI 14 | Run the following script for evaluation: 15 | ``` 16 | python evaluation.py \ 17 | --draft-model PATH_TO_DRAFT_MODEL \ 18 | --target-model PATH_TO_TARGET_MODEL \ 19 | --fp16 \ 20 | --k-config 4,2,2 \ 21 | --datapath PATH_TO_DATA \ 22 | --sampling-type sampling 23 | ``` 24 | 25 | ### Options 26 | ``` 27 | -h, --help show this help message and exit 28 | --draft-model Draft model path. 29 | --target-model Target model path. 30 | --tokenizer Tokenizer path. If not provided, use the Target model path. 31 | --fp16 Use float16 dtype. 
32 | --k-config Use comma separations, e.g. `--k-config 4,2,2`. 33 | --datapath The json data file. 34 | --max-new-tokens 35 | --replacement Sampling with replacement. 36 | --naive-sampling Use multi-candidate naive sampling. 37 | --disable-tree-attn 38 | --sampling-type {argmax,sampling} 39 | --disable-tqdm 40 | --auto-model Use AutoModelForCausalLM and AutoTokenizer to load the model and tokenizer, this will disable the tree attn. 41 | ``` 42 | 43 | Note: 44 | * Tree Attn is currently not supported for models other than LLaMA. Therefore, when using '--auto-model', Tree Attn will be disabled. 45 | * Since flash-attn does not support custom attention masks, it is currently incompatible with Tree Attn. 46 | 47 | ## Python application interfaces 48 | Here is an example of inference using our generator, see here for the function of each argument. 49 | ```python 50 | import torch 51 | from model.llama_tree_attn import LlamaForCausalLM, LlamaTokenizer 52 | from inference.generate import SpeculativeGenerator 53 | 54 | draft_model = LlamaForCausalLM.from_pretrained( 55 | "PATH_TO_DRAFT_MODEL", 56 | torch_dtype=torch.float16, 57 | device_map=0, 58 | ) 59 | target_model = LlamaForCausalLM.from_pretrained( 60 | "PATH_TO_TARGET_MODEL", 61 | torch_dtype=torch.float16, 62 | device_map="auto", 63 | ) 64 | tokenizer = LlamaTokenizer.from_pretrained("PATH_TO_TARGET_MODEL") 65 | 66 | generator = SpeculativeGenerator( 67 | draft_model, 68 | target_model, 69 | eos_token_id=tokenizer.eos_token_id, 70 | k_config=(4, 2, 2), 71 | max_new_tokens=128, 72 | draft_model_temp=1, 73 | target_model_temp=1, 74 | replacement=False, 75 | speculative_sampling=True, 76 | tree_attn=True, 77 | ) 78 | 79 | prompt_text = "Hey, are you conscious? Can you talk to me?" 80 | inputs = tokenizer(prompt_text, return_tensors="pt").to("cuda") 81 | input_ids = inputs.input_ids 82 | with torch.no_grad(): 83 | output = generator.generate(input_ids) 84 | output_text = tokenizer.batch_decode( 85 | output.sequences, skip_special_tokens=True, clean_up_tokenization_spaces=False 86 | )[0] 87 | 88 | print("Output:\n{}".format(output_text)) 89 | 90 | ``` 91 | -------------------------------------------------------------------------------- /MCSD/evaluation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import time 5 | from typing import Literal, Tuple 6 | 7 | import torch 8 | from inference.generate import Generator, BaseGenerator, SpeculativeGenerator 9 | from model.llama_tree_attn import LlamaForCausalLM, LlamaTokenizer 10 | from transformers import AutoModelForCausalLM, AutoTokenizer 11 | from tqdm import tqdm 12 | 13 | # Setup logging 14 | logging.basicConfig( 15 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 16 | datefmt="%m/%d/%Y %H:%M:%S", 17 | level=logging.INFO, 18 | ) 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | class JsonData: 24 | def __init__(self, path) -> None: 25 | with open(path) as fin: 26 | self.data = json.load(fin) 27 | 28 | def __getitem__(self, index) -> Tuple[str, str]: 29 | return self.data[index] 30 | 31 | def __len__(self): 32 | return len(self.data) 33 | 34 | 35 | def run_eval( 36 | draft_model, 37 | target_model, 38 | tokenizer, 39 | k_config: Tuple[int], 40 | datapath: str, 41 | max_new_tokens: int = 128, 42 | replacement=False, 43 | speculative_sampling=True, 44 | tree_attn=True, 45 | sampling_type: Literal["argmax", "sampling"] = "sampling", 46 | disable_tqdm: bool = False, 47 | ): 48 | 
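    # Summary of the metrics reported at the end of this function, written out here for
    # clarity (A = acceptance_count, D = draft_token_count, I = invocation_count, where one
    # invocation is one draft-then-verify round):
    #   token latency    = run_time / (A + I)   # each round emits the accepted tokens plus one
    #   acceptance rate  = A / D                # fraction of proposed draft tokens accepted
    #   block efficiency = 1 + A / I            # average tokens emitted per target-model call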
if sampling_type not in ["argmax", "sampling"]: 49 | raise ValueError( 50 | f'`sampling_type` can be either `"argmax"` or `"sampling"`, but received "{sampling_type}"' 51 | ) 52 | if sampling_type == "argmax": 53 | target_model_temp = 0 54 | draft_model_temp = 0 55 | else: 56 | target_model_temp = 1 57 | draft_model_temp = 1 58 | 59 | dataloader = JsonData(datapath) 60 | generator = SpeculativeGenerator( 61 | draft_model, 62 | target_model, 63 | eos_token_id=tokenizer.eos_token_id, 64 | k_config=k_config, 65 | max_new_tokens=max_new_tokens, 66 | draft_model_temp=draft_model_temp, 67 | target_model_temp=target_model_temp, 68 | replacement=replacement, 69 | speculative_sampling=speculative_sampling, 70 | tree_attn=tree_attn, 71 | ) 72 | 73 | draft_model.eval() 74 | target_model.eval() 75 | 76 | logger.info("evaluation start.") 77 | start_time = time.time() 78 | 79 | acceptance_count = 0 80 | draft_token_count = 0 81 | invocation_count = 0 82 | 83 | iterator = range(len(dataloader)) 84 | with torch.no_grad(): 85 | for sample_idx in iterator if disable_tqdm else tqdm(iterator): 86 | prompt_text = dataloader[sample_idx] 87 | inputs = tokenizer(prompt_text, return_tensors="pt").to("cuda") 88 | input_ids = inputs.input_ids 89 | output = generator.generate(input_ids) 90 | 91 | acceptance_count += output.acceptance_count 92 | draft_token_count += output.draft_token_count 93 | invocation_count += output.invocation_count 94 | end_time = time.time() 95 | 96 | logger.info("evaluation complete.") 97 | 98 | run_time = end_time - start_time 99 | 100 | latency = run_time / (acceptance_count + invocation_count) 101 | acceptance_rate = acceptance_count / draft_token_count 102 | block_efficiency = 1 + acceptance_count / invocation_count 103 | 104 | logger.info("Running time: {:.2f} s".format(run_time)) 105 | logger.info("Token latency: {:.2f} ms".format(latency * 1000)) 106 | logger.info("Acceptance rate: {:.2f}".format(acceptance_rate)) 107 | logger.info("Block efficiency: {:.2f}".format(block_efficiency)) 108 | 109 | 110 | def run_baseline_eval( 111 | target_model, 112 | tokenizer, 113 | datapath: str, 114 | max_new_tokens: int = 128, 115 | sampling_type: Literal["argmax", "sampling"] = "sampling", 116 | disable_tqdm: bool = False, 117 | ): 118 | if sampling_type not in ["argmax", "sampling"]: 119 | raise ValueError( 120 | f'`sampling_type` can be either `"argmax"` or `"sampling"`, but received "{sampling_type}"' 121 | ) 122 | if sampling_type == "argmax": 123 | target_model_temp = 0 124 | else: 125 | target_model_temp = 1 126 | 127 | dataloader = JsonData(datapath) 128 | generator = BaseGenerator( 129 | target_model, 130 | eos_token_id=tokenizer.eos_token_id, 131 | max_new_tokens=max_new_tokens, 132 | temp=target_model_temp, 133 | ) 134 | 135 | target_model.eval() 136 | 137 | logger.info("evaluation start.") 138 | start_time = time.time() 139 | 140 | invocation_count = 0 141 | 142 | iterator = range(len(dataloader)) 143 | with torch.no_grad(): 144 | for sample_idx in iterator if disable_tqdm else tqdm(iterator): 145 | prompt_text = dataloader[sample_idx] 146 | inputs = tokenizer(prompt_text, return_tensors="pt").to("cuda") 147 | input_ids = inputs.input_ids 148 | output = generator.generate(input_ids) 149 | 150 | invocation_count += output.invocation_count 151 | end_time = time.time() 152 | 153 | logger.info("evaluation complete.") 154 | 155 | run_time = end_time - start_time 156 | 157 | latency = run_time / invocation_count 158 | 159 | logger.info("Running time: {:.2f} s".format(run_time)) 160 | 
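    # In the baseline (no speculation) every generated token costs exactly one target-model
    # forward pass, so the token latency logged below is simply run_time / invocation_count.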
logger.info("Token latency: {:.2f} ms".format(latency * 1000)) 161 | 162 | 163 | def main(args): 164 | torch_dtype = torch.float16 if args.fp16 else torch.float32 165 | 166 | logger.info("The full evaluation configuration:\n" + repr(args)) 167 | 168 | if args.auto_model and not args.disable_tree_attn: 169 | logger.warning( 170 | "Tree Attn is currently not supported for models other than LLaMA. Therefore, " 171 | "when using '--auto-model', Tree Attn will be disabled." 172 | ) 173 | args.disable_tree_attn = True 174 | 175 | ModelLoader = AutoModelForCausalLM if args.auto_model else LlamaForCausalLM 176 | TokenizerLoader = AutoTokenizer if args.auto_model else LlamaTokenizer 177 | 178 | logger.info("Loading draft model: {}".format(args.draft_model)) 179 | draft_model = ModelLoader.from_pretrained( 180 | args.draft_model, 181 | torch_dtype=torch_dtype, 182 | device_map=0, 183 | use_flash_attention_2=True if args.flash_attn else False, 184 | ) 185 | 186 | logger.info("Loading target model: {}".format(args.target_model)) 187 | target_model = ModelLoader.from_pretrained( 188 | args.target_model, 189 | torch_dtype=torch_dtype, 190 | device_map="auto", 191 | use_flash_attention_2=True if args.flash_attn else False, 192 | ) 193 | 194 | tokenizer = TokenizerLoader.from_pretrained(args.tokenizer) 195 | 196 | if args.run_baseline: 197 | run_baseline_eval( 198 | target_model, 199 | tokenizer=tokenizer, 200 | datapath=args.datapath, 201 | max_new_tokens=args.max_new_tokens, 202 | sampling_type=args.sampling_type, 203 | disable_tqdm=args.disable_tqdm, 204 | ) 205 | else: 206 | run_eval( 207 | draft_model, 208 | target_model, 209 | tokenizer=tokenizer, 210 | k_config=args.k_config, 211 | datapath=args.datapath, 212 | max_new_tokens=args.max_new_tokens, 213 | replacement=args.replacement, 214 | speculative_sampling=not args.naive_sampling, 215 | tree_attn=not args.disable_tree_attn, 216 | sampling_type=args.sampling_type, 217 | disable_tqdm=args.disable_tqdm, 218 | ) 219 | 220 | 221 | if __name__ == "__main__": 222 | parser = argparse.ArgumentParser() 223 | parser.add_argument( 224 | "--draft-model", type=str, required=True, help="Draft model path." 225 | ) 226 | parser.add_argument( 227 | "--target-model", type=str, required=True, help="Target model path." 228 | ) 229 | parser.add_argument("--tokenizer", type=str, default=None, help="Tokenizer path.") 230 | parser.add_argument("--fp16", action="store_true", help="use float16 dtype.") 231 | 232 | parser.add_argument( 233 | "--k-config", 234 | type=lambda x: tuple(map(int, x.split(","))), 235 | required=True, 236 | help="Use comma separations, e.g. `--k-config 4,2,2`.", 237 | ) 238 | 239 | parser.add_argument( 240 | "--datapath", type=str, required=True, help="The json data file." 
241 | ) 242 | parser.add_argument("--max-new-tokens", type=int, default=128) 243 | parser.add_argument( 244 | "--replacement", 245 | action="store_true", 246 | help="Sampling with replacement.", 247 | ) 248 | parser.add_argument( 249 | "--naive-sampling", 250 | action="store_true", 251 | help="Use multi-candidate naive sampling.", 252 | ) 253 | 254 | parser.add_argument("--disable-tree-attn", action="store_true") 255 | 256 | parser.add_argument( 257 | "--sampling-type", type=str, default="sampling", choices=["argmax", "sampling"] 258 | ) 259 | 260 | parser.add_argument("--disable-tqdm", action="store_true") 261 | 262 | parser.add_argument("--auto-model", action="store_true") 263 | parser.add_argument("--run-baseline", action="store_true") 264 | 265 | parser.add_argument("--flash-attn", action="store_true") 266 | 267 | args = parser.parse_args() 268 | 269 | if args.tokenizer is None: 270 | args.tokenizer = args.target_model 271 | main(args) 272 | -------------------------------------------------------------------------------- /MCSD/inference/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJUNLP/MCSD/8aadd6501a9e987ba5fca6cc8f9ad5949e480ec7/MCSD/inference/__init__.py -------------------------------------------------------------------------------- /MCSD/inference/generate.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Optional, Tuple, Union 3 | 4 | import torch 5 | from transformers.modeling_outputs import ModelOutput, CausalLMOutputWithPast 6 | 7 | from . import strategies 8 | 9 | 10 | @dataclass 11 | class DecoderOnlyOutput(ModelOutput): 12 | """ 13 | Base class for outputs of decoder-only generation models using MCSD. 
14 | """ 15 | 16 | sequences: torch.LongTensor 17 | acceptance_count: int = None 18 | draft_token_count: int = None 19 | invocation_count: int = None 20 | 21 | 22 | class Generator: 23 | def __init__(self) -> None: 24 | pass 25 | 26 | def generate( 27 | self, 28 | input_ids: Optional[torch.Tensor] = None, 29 | ) -> DecoderOnlyOutput: 30 | raise NotImplementedError 31 | 32 | 33 | class BaseGenerator: 34 | def __init__( 35 | self, 36 | model, 37 | eos_token_id: int, 38 | max_new_tokens: int = 128, 39 | temp: float = 1, 40 | ) -> None: 41 | self.model = model 42 | self.eos_token_id = eos_token_id 43 | self.max_new_tokens = max_new_tokens 44 | self.temp = temp 45 | 46 | def generate( 47 | self, 48 | input_ids: Optional[torch.Tensor] = None, 49 | ) -> DecoderOnlyOutput: 50 | past_key_values = None 51 | invocation_count = 0 52 | 53 | init_input_len = input_ids.size(-1) 54 | 55 | while True: 56 | if past_key_values is not None: 57 | pruned_input_ids = input_ids[:, past_key_values[0][0].size(2) :] 58 | else: 59 | pruned_input_ids = input_ids 60 | 61 | outputs: CausalLMOutputWithPast = self.model( 62 | input_ids=pruned_input_ids, 63 | use_cache=True, 64 | past_key_values=past_key_values, 65 | return_dict=True, 66 | output_attentions=False, 67 | output_hidden_states=False, 68 | ) 69 | 70 | logits = outputs.logits 71 | past_key_values = outputs.past_key_values 72 | 73 | batch_num, seq_len, _ = logits.size() 74 | 75 | if self.temp == 0: 76 | _, ground_tokens = logits.topk(k=1, dim=-1) # batch x seq_len x 1 77 | else: 78 | ground_probs = torch.softmax( 79 | logits / self.temp, dim=-1 80 | ) # batch x seq_len x hidden_dim 81 | 82 | ground_tokens = torch.multinomial( 83 | ground_probs.view(batch_num * seq_len, -1), num_samples=1 84 | ) # batch*seq_len x 1 85 | ground_tokens = ground_tokens.view(batch_num, seq_len) 86 | 87 | input_ids = torch.cat( 88 | (input_ids, ground_tokens[:, -1:].to(input_ids)), dim=1 89 | ) 90 | 91 | invocation_count += 1 92 | 93 | if ( 94 | self.eos_token_id in input_ids[0, -1:] 95 | or input_ids.size(-1) - init_input_len >= self.max_new_tokens 96 | ): 97 | break 98 | return DecoderOnlyOutput(sequences=input_ids, invocation_count=invocation_count) 99 | 100 | 101 | class SpeculativeGenerator: 102 | def __init__( 103 | self, 104 | draft_model, 105 | target_model, 106 | eos_token_id: int, 107 | k_config: Tuple[int], 108 | max_new_tokens: int = 128, 109 | draft_model_temp: float = 1, 110 | target_model_temp: float = 1, 111 | replacement: bool = False, 112 | speculative_sampling: bool = True, 113 | tree_attn: bool = True, 114 | ) -> None: 115 | self.eos_token_id = eos_token_id 116 | self.max_new_tokens = max_new_tokens 117 | self.strategy: strategies.Strategy = None 118 | 119 | if tree_attn: 120 | self.strategy = strategies.TreeStrategy( 121 | draft_model=draft_model, 122 | target_model=target_model, 123 | k_config=k_config, 124 | draft_model_temp=draft_model_temp, 125 | target_model_temp=target_model_temp, 126 | replacement=replacement, 127 | speculative_sampling=speculative_sampling, 128 | ) 129 | else: 130 | self.strategy = strategies.BatchStrategy( 131 | draft_model=draft_model, 132 | target_model=target_model, 133 | k_config=k_config, 134 | draft_model_temp=draft_model_temp, 135 | target_model_temp=target_model_temp, 136 | replacement=replacement, 137 | speculative_sampling=speculative_sampling, 138 | ) 139 | 140 | def generate( 141 | self, 142 | input_ids: Optional[torch.Tensor] = None, 143 | ) -> DecoderOnlyOutput: 144 | target_model_past_key_values = None 145 | 
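        # Each pass through the loop below performs one MCSD round:
        #   1. strategy.generate_draft proposes a k_config-shaped set of candidate tokens
        #      with the draft model,
        #   2. strategy.verify scores all candidates with the target model in a single
        #      forward pass, accepts a prefix of one branch, and appends one extra token
        #      sampled from the target model,
        #   3. both KV caches are pruned down to the surviving branch before the next round.
        # acceptance_count counts accepted draft tokens; invocation_count counts rounds.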
draft_model_past_key_values = None 146 | 147 | invocation_count = 0 148 | acceptance_count = 0 149 | 150 | init_input_len = input_ids.size(-1) 151 | 152 | while True: 153 | draft_output = self.strategy.generate_draft( 154 | input_ids, 155 | past_key_values=draft_model_past_key_values, 156 | ) 157 | 158 | draft_model_past_key_values = draft_output.past_key_values 159 | 160 | verification_output = self.strategy.verify( 161 | input_ids=draft_output.sequences, 162 | target_model_past_key_values=target_model_past_key_values, 163 | draft_model_past_key_values=draft_output.past_key_values, 164 | cand_probs=draft_output.cand_probs, 165 | ) 166 | 167 | input_ids = verification_output.sequences 168 | 169 | draft_model_past_key_values = ( 170 | verification_output.draft_model_past_key_values 171 | ) 172 | target_model_past_key_values = ( 173 | verification_output.target_model_past_key_values 174 | ) 175 | 176 | invocation_count += 1 177 | acceptance_count += verification_output.acceptance_count 178 | 179 | if ( 180 | self.eos_token_id in input_ids[0, -self.strategy.max_draft_len :] 181 | or input_ids.size(-1) - init_input_len >= self.max_new_tokens 182 | ): 183 | break 184 | return DecoderOnlyOutput( 185 | sequences=input_ids, 186 | acceptance_count=acceptance_count, 187 | draft_token_count=invocation_count * self.strategy.max_draft_len, 188 | invocation_count=invocation_count, 189 | ) 190 | -------------------------------------------------------------------------------- /MCSD/inference/strategies.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from dataclasses import dataclass 3 | from typing import Callable, List, Literal, Optional, Tuple, Union 4 | 5 | import torch 6 | from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput 7 | 8 | 9 | @dataclass 10 | class DecoderOnlyDraftOutput(ModelOutput): 11 | """ 12 | Base class for draft outputs of decoder-only generation models using speculative decoding. 13 | """ 14 | 15 | sequences: torch.LongTensor = None 16 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None 17 | cand_probs: Optional[Tuple[torch.FloatTensor]] = None 18 | 19 | 20 | @dataclass 21 | class DecoderOnlyVerificationOutput(ModelOutput): 22 | """ 23 | Base class for verification outputs of decoder-only generation models using speculative decoding. 
24 | """ 25 | 26 | sequences: torch.LongTensor = None 27 | target_model_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None 28 | draft_model_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None 29 | acceptance_count: Optional[int] = None 30 | 31 | 32 | def _MCNS( 33 | ground_probs: torch.FloatTensor, 34 | cand_probs: Tuple[torch.FloatTensor], 35 | cand_tokens: torch.LongTensor, 36 | ) -> Optional[int]: 37 | ground_token = torch.multinomial(ground_probs, num_samples=1).item() 38 | 39 | for check_idx, cand_token in enumerate(cand_tokens): 40 | if ground_token == cand_token: 41 | return check_idx 42 | ground_probs[:] = 0 43 | ground_probs[ground_token] = 1 44 | return None 45 | 46 | 47 | def _MCSSwoReplacement( 48 | ground_probs: torch.FloatTensor, 49 | cand_probs: Tuple[torch.FloatTensor], 50 | cand_tokens: torch.LongTensor, 51 | ) -> Optional[int]: 52 | cand_probs = cand_probs.to(ground_probs.device) 53 | for check_idx, cand_token in enumerate(cand_tokens): 54 | accept_threshold = ground_probs[cand_token] / cand_probs[cand_token] 55 | if torch.rand(1, device=accept_threshold.device) <= accept_threshold: 56 | return check_idx 57 | else: 58 | ground_probs -= cand_probs 59 | ground_probs = torch.nn.functional.relu(ground_probs, inplace=True) 60 | ground_probs /= ground_probs.sum() 61 | cand_probs[cand_token] = 0 62 | cand_probs = cand_probs / cand_probs.sum() 63 | return None 64 | 65 | 66 | def _MCSSwReplacement( 67 | ground_probs: torch.FloatTensor, 68 | cand_probs: Tuple[torch.FloatTensor], 69 | cand_tokens: torch.LongTensor, 70 | ) -> Optional[int]: 71 | cand_probs = cand_probs.to(ground_probs.device) 72 | for check_idx, cand_token in enumerate(cand_tokens): 73 | accept_threshold = ground_probs[cand_token] / cand_probs[cand_token] 74 | if torch.rand(1, device=accept_threshold.device) <= accept_threshold: 75 | return check_idx 76 | else: 77 | ground_probs -= cand_probs 78 | ground_probs = torch.nn.functional.relu(ground_probs, inplace=True) 79 | ground_probs /= ground_probs.sum() 80 | return None 81 | 82 | 83 | class Strategy: 84 | def __init__( 85 | self, 86 | draft_model, 87 | target_model, 88 | k_config: Tuple[int], 89 | draft_model_temp: float = 1, 90 | target_model_temp: float = 1, 91 | replacement: bool = False, 92 | speculative_sampling: bool = True, 93 | ) -> None: 94 | self.k_config = k_config 95 | self.draft_model = draft_model 96 | self.target_model = target_model 97 | self.draft_model_device = draft_model.model.get_input_embeddings().weight.device 98 | self.target_model_device = ( 99 | target_model.model.get_input_embeddings().weight.device 100 | ) 101 | self.max_draft_len = len(k_config) 102 | self.draft_model_temp = draft_model_temp 103 | self.target_model_temp = target_model_temp 104 | self.replacement = replacement 105 | self.speculative_sampling = speculative_sampling 106 | 107 | self.acceptance_check: Callable[ 108 | [torch.FloatTensor, Tuple[torch.FloatTensor], torch.LongTensor], 109 | Optional[int], 110 | ] = None 111 | if speculative_sampling: 112 | if replacement: 113 | self.acceptance_check = _MCSSwReplacement 114 | if draft_model_temp == 0: 115 | warnings.warn( 116 | ( 117 | "You have set Temp=0 and are using sampling with replacement. " 118 | "As a result, all the candidates obtained are the same, causing " 119 | "the MCSD algorithm to degenerate into the vanilla SD." 
120 | ), 121 | category=UserWarning, 122 | stacklevel=3, 123 | ) 124 | else: 125 | self.acceptance_check = _MCSSwoReplacement 126 | else: 127 | if replacement: 128 | warnings.warn( 129 | ( 130 | "`replacement` is not applicable when `speculative_sampling` is False." 131 | "The acceptance check algorithm defaults to MCNS (Multi-Candidate Naive Sampling)" 132 | " when `speculative_sampling=False`." 133 | ), 134 | category=UserWarning, 135 | stacklevel=3, 136 | ) 137 | self.acceptance_check = _MCNS 138 | 139 | def generate_draft( 140 | self, 141 | input_ids: torch.LongTensor, 142 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]], 143 | ) -> DecoderOnlyDraftOutput: 144 | raise NotImplementedError 145 | 146 | def acceptance_check(self, ground_probs, cand_probs, cand_tokens) -> Optional[int]: 147 | raise NotImplementedError 148 | 149 | def verify( 150 | self, 151 | input_ids: torch.LongTensor, 152 | target_model_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]], 153 | draft_model_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]], 154 | cand_probs: Optional[Tuple[torch.FloatTensor]], 155 | ) -> DecoderOnlyVerificationOutput: 156 | raise NotImplementedError 157 | 158 | 159 | class BatchStrategy(Strategy): 160 | def __init__( 161 | self, 162 | draft_model, 163 | target_model, 164 | k_config: Tuple[int], 165 | draft_model_temp=1, 166 | target_model_temp=1, 167 | replacement: bool = False, 168 | speculative_sampling: bool = True, 169 | ) -> None: 170 | super().__init__( 171 | draft_model, 172 | target_model, 173 | k_config, 174 | draft_model_temp, 175 | target_model_temp, 176 | replacement, 177 | speculative_sampling, 178 | ) 179 | 180 | reversed_prod_size = [1] 181 | for i in range(1, self.max_draft_len): 182 | reversed_prod_size.insert(0, reversed_prod_size[0] * k_config[-i]) 183 | 184 | self.reversed_prod_size = reversed_prod_size 185 | 186 | def generate_draft( 187 | self, 188 | input_ids: torch.LongTensor, 189 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]], 190 | ) -> DecoderOnlyDraftOutput: 191 | input_ids = input_ids.to(self.draft_model_device) 192 | cand_probs = [] 193 | for step in range(self.max_draft_len): 194 | step_k = self.k_config[step] 195 | if past_key_values is not None: 196 | pruned_input_ids = input_ids[:, past_key_values[0][0].size(2) :] 197 | else: 198 | pruned_input_ids = input_ids 199 | outputs: BaseModelOutputWithPast = self.draft_model.model( 200 | input_ids=pruned_input_ids, 201 | use_cache=True, 202 | past_key_values=past_key_values, 203 | return_dict=True, 204 | output_attentions=False, 205 | output_hidden_states=False, 206 | ) 207 | 208 | hidden_states = outputs.last_hidden_state 209 | 210 | logits = self.draft_model.lm_head(hidden_states[:, -1]) 211 | 212 | past_key_values = list(outputs.past_key_values) 213 | 214 | if self.draft_model_temp == 0: 215 | if not self.replacement: 216 | topk_logit, topk_index = logits.topk(k=step_k, dim=-1) # batch x k 217 | topk_probs = torch.softmax(topk_logit, dim=-1) 218 | step_cand_probs = torch.zeros_like(logits) 219 | step_cand_probs.scatter_(dim=1, index=topk_index, src=topk_probs) 220 | cand_tokens = topk_index.view(-1, 1) 221 | else: 222 | topk_logit, topk_index = logits.topk(k=1, dim=-1) # batch x k 223 | step_cand_probs = torch.zeros_like(logits) 224 | step_cand_probs.scatter_(dim=1, index=topk_index, value=1) 225 | cand_tokens = topk_index.view(-1, 1) 226 | cand_tokens = torch.repeat_interleave(cand_tokens, step_k, dim=0) 227 | else: 228 | step_cand_probs = torch.softmax(logits / 
self.draft_model_temp, dim=-1) 229 | cand_tokens = torch.multinomial( 230 | step_cand_probs, 231 | step_k, 232 | replacement=self.replacement, 233 | ).view(-1, 1) 234 | 235 | cand_probs.append(step_cand_probs) 236 | 237 | input_ids = input_ids.repeat_interleave(step_k, dim=0) 238 | input_ids = torch.cat( 239 | ( 240 | input_ids, 241 | cand_tokens, 242 | ), 243 | dim=1, 244 | ) 245 | if step + 1 != self.max_draft_len: 246 | for i in range(len(past_key_values)): 247 | past_key_values[i] = ( 248 | past_key_values[i][0].repeat_interleave(step_k, dim=0), 249 | past_key_values[i][1].repeat_interleave(step_k, dim=0), 250 | ) 251 | 252 | return DecoderOnlyDraftOutput( 253 | sequences=input_ids, 254 | past_key_values=past_key_values, 255 | cand_probs=tuple(cand_probs), 256 | ) 257 | 258 | def verify( 259 | self, 260 | input_ids: torch.LongTensor, 261 | target_model_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]], 262 | draft_model_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]], 263 | cand_probs: Optional[Tuple[torch.FloatTensor]], 264 | ) -> DecoderOnlyVerificationOutput: 265 | input_ids = input_ids.to(self.target_model_device) 266 | batch_size, input_len = input_ids.size() 267 | if target_model_past_key_values is not None: 268 | pruned_input_ids = input_ids[ 269 | :, target_model_past_key_values[0][0].size(2) : 270 | ] 271 | for i in range(len(target_model_past_key_values)): 272 | target_model_past_key_values[i] = ( 273 | target_model_past_key_values[i][0].repeat_interleave( 274 | batch_size, dim=0 275 | ), 276 | target_model_past_key_values[i][1].repeat_interleave( 277 | batch_size, dim=0 278 | ), 279 | ) 280 | else: 281 | pruned_input_ids = input_ids 282 | 283 | outputs: BaseModelOutputWithPast = self.target_model.model( 284 | input_ids=pruned_input_ids, 285 | use_cache=True, 286 | past_key_values=target_model_past_key_values, 287 | return_dict=True, 288 | output_attentions=False, 289 | output_hidden_states=False, 290 | ) 291 | hidden_states = outputs.last_hidden_state 292 | target_model_past_key_values = list(outputs.past_key_values) 293 | 294 | logits = self.target_model.lm_head(hidden_states[:, -self.max_draft_len - 1 :]) 295 | 296 | if self.target_model_temp == 0: 297 | _, topk_index = logits.topk(k=1, dim=-1) # seq_len x 1 298 | ground_probs = torch.zeros_like(logits) 299 | ground_probs.scatter_(dim=2, index=topk_index, value=1) 300 | else: 301 | ground_probs = torch.softmax(logits / self.target_model_temp, dim=-1) 302 | 303 | unverified_input_ids = input_ids[:, -self.max_draft_len :] 304 | 305 | assert ground_probs.size(1) == unverified_input_ids.size(1) + 1 306 | 307 | cand_probs_idx = 0 308 | alive_group_id = 0 309 | 310 | for depth in range(self.max_draft_len): 311 | verify_batch_ids = [ 312 | alive_group_id + group_offset * self.reversed_prod_size[depth] 313 | for group_offset in range(self.k_config[depth]) 314 | ] 315 | accept_idx_bias = self.acceptance_check( 316 | ground_probs[alive_group_id, depth], 317 | cand_probs[depth][cand_probs_idx], 318 | unverified_input_ids[verify_batch_ids, depth], 319 | ) 320 | if accept_idx_bias is not None: 321 | alive_group_id = verify_batch_ids[accept_idx_bias] 322 | cand_probs_idx = accept_idx_bias + cand_probs_idx * self.k_config[depth] 323 | if depth == self.max_draft_len - 1: 324 | depth = self.max_draft_len 325 | else: 326 | break 327 | input_ids = input_ids[alive_group_id, : input_len - self.max_draft_len + depth] 328 | endpoint_token = torch.multinomial( 329 | ground_probs[alive_group_id, depth], num_samples=1 330 | 
).to(device=input_ids.device) 331 | 332 | input_ids = torch.cat((input_ids, endpoint_token)) 333 | 334 | input_ids.unsqueeze_(0) 335 | 336 | for i in range(len(target_model_past_key_values)): 337 | target_model_past_key_values[i] = ( 338 | target_model_past_key_values[i][0][ 339 | None, alive_group_id, :, : input_len - self.max_draft_len + depth 340 | ], 341 | target_model_past_key_values[i][1][ 342 | None, alive_group_id, :, : input_len - self.max_draft_len + depth 343 | ], 344 | ) 345 | for i in range(len(draft_model_past_key_values)): 346 | draft_model_past_key_values[i] = ( 347 | draft_model_past_key_values[i][0][ 348 | None, 349 | alive_group_id // self.k_config[-1], 350 | :, 351 | : input_len - self.max_draft_len + depth, 352 | ], 353 | draft_model_past_key_values[i][1][ 354 | None, 355 | alive_group_id // self.k_config[-1], 356 | :, 357 | : input_len - self.max_draft_len + depth, 358 | ], 359 | ) 360 | return DecoderOnlyVerificationOutput( 361 | sequences=input_ids, 362 | target_model_past_key_values=target_model_past_key_values, 363 | draft_model_past_key_values=draft_model_past_key_values, 364 | acceptance_count=depth, 365 | ) 366 | 367 | 368 | def get_tree_attn_self_mask(k_config: Tuple[int]): 369 | k_config = torch.tensor(k_config, dtype=torch.int) 370 | prod_size = torch.cumprod(k_config, dim=0) 371 | mask_size = prod_size.sum().item() 372 | attn_mask = torch.zeros((mask_size, mask_size), dtype=torch.bool) 373 | attn_mask = attn_mask.diagonal_scatter(torch.ones(mask_size)) 374 | # run BFS 375 | idx_queue = [ 376 | (0, None, idx) for idx in list(range(k_config[0])) 377 | ] # each node: (depth, parent, idx) 378 | while len(idx_queue) != 0: 379 | depth, parent, idx = idx_queue.pop(0) 380 | if parent is not None: 381 | attn_mask[idx, : parent + 1] = attn_mask[parent, : parent + 1] 382 | 383 | if depth != len(k_config) - 1: 384 | idx_base = prod_size[:depth].sum().item() 385 | child_idx_base = prod_size[: depth + 1].sum().item() 386 | for child_idx_bias in range(k_config[depth + 1]): 387 | real_child_idx = ( 388 | (idx - idx_base) * k_config[depth + 1] 389 | + child_idx_base 390 | + child_idx_bias 391 | ) 392 | idx_queue.append((depth + 1, idx, real_child_idx)) 393 | return attn_mask 394 | 395 | 396 | class TreeStrategy(Strategy): 397 | def __init__( 398 | self, 399 | draft_model, 400 | target_model, 401 | k_config: Tuple[int], 402 | draft_model_temp: float = 1, 403 | target_model_temp: float = 1, 404 | replacement: bool = False, 405 | speculative_sampling: bool = True, 406 | ) -> None: 407 | super().__init__( 408 | draft_model, 409 | target_model, 410 | k_config, 411 | draft_model_temp, 412 | target_model_temp, 413 | replacement, 414 | speculative_sampling, 415 | ) 416 | 417 | prod_size = torch.cumprod(torch.tensor(k_config, dtype=torch.int), dim=0) 418 | prod_size = torch.cat((torch.zeros(1).to(prod_size), prod_size)).tolist() 419 | self.prod_size = prod_size 420 | self.cumulative_prod_size = torch.cumsum( 421 | torch.tensor(prod_size), dim=0 422 | ).tolist() 423 | 424 | self.tree_attn_self_mask = get_tree_attn_self_mask(k_config).to( 425 | device=self.draft_model_device 426 | ) 427 | 428 | def generate_draft( 429 | self, 430 | input_ids: torch.LongTensor, 431 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]], 432 | ) -> DecoderOnlyDraftOutput: 433 | input_ids = input_ids.to(self.draft_model_device) 434 | cand_probs = [] 435 | step_tree_attn_mask = None 436 | position_ids = None 437 | init_input_length = input_ids.size(1) 438 | if past_key_values is not None: 439 | 
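            # past_key_values[i][0] has shape (batch, num_heads, cached_len, head_dim), so
            # size(2) is the number of positions already in the cache; only the tokens that
            # follow the cached prefix are fed to the draft model on this step.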
pruned_input_ids = input_ids[:, past_key_values[0][0].size(2) :] 440 | else: 441 | pruned_input_ids = input_ids 442 | for step in range(self.max_draft_len): 443 | step_k = self.k_config[step] 444 | 445 | # prepare attn mask 446 | if step != 0: 447 | step_tree_attn_self_mask = self.tree_attn_self_mask[ 448 | self.cumulative_prod_size[step - 1] : self.cumulative_prod_size[ 449 | step 450 | ], 451 | : self.cumulative_prod_size[step], 452 | ] 453 | position_ids = torch.full( 454 | (1, self.prod_size[step]), 455 | init_input_length + step - 1, 456 | dtype=torch.long, 457 | device=self.draft_model_device, 458 | ) 459 | context_attn_mask = torch.ones( 460 | (self.prod_size[step], init_input_length), dtype=torch.bool 461 | ).to(self.tree_attn_self_mask) 462 | step_tree_attn_mask = torch.cat( 463 | (context_attn_mask, step_tree_attn_self_mask), dim=1 464 | ) 465 | 466 | outputs: BaseModelOutputWithPast = self.draft_model.model( 467 | input_ids=pruned_input_ids, 468 | use_cache=True, 469 | past_key_values=past_key_values, 470 | return_dict=True, 471 | output_attentions=False, 472 | output_hidden_states=False, 473 | tree_attn_mask=step_tree_attn_mask, 474 | position_ids=position_ids, 475 | ) 476 | 477 | hidden_states = outputs.last_hidden_state 478 | 479 | if step == 0: 480 | hidden_states = hidden_states[0, -1:] 481 | else: 482 | hidden_states = hidden_states[0] 483 | logits = self.draft_model.lm_head(hidden_states) # seq_len x hidden_dim 484 | 485 | past_key_values = list(outputs.past_key_values) 486 | 487 | if self.draft_model_temp == 0: 488 | if not self.replacement: 489 | topk_logit, topk_index = logits.topk( 490 | k=step_k, dim=-1 491 | ) # seq_len x k 492 | topk_probs = torch.softmax(topk_logit, dim=-1) 493 | step_cand_probs = torch.zeros_like(logits) 494 | step_cand_probs.scatter_(dim=1, index=topk_index, src=topk_probs) 495 | cand_tokens = topk_index.view(1, -1) 496 | else: 497 | topk_logit, topk_index = logits.topk(k=1, dim=-1) # seq_len x k 498 | step_cand_probs = torch.zeros_like(logits) 499 | step_cand_probs.scatter_(dim=1, index=topk_index, value=1) 500 | cand_tokens = topk_index.view(1, -1) 501 | cand_tokens = torch.repeat_interleave(cand_tokens, step_k, dim=1) 502 | else: 503 | step_cand_probs = torch.softmax(logits / self.draft_model_temp, dim=-1) 504 | cand_tokens = torch.multinomial( 505 | step_cand_probs, step_k, replacement=self.replacement 506 | ).view(1, -1) 507 | cand_probs.append(step_cand_probs) 508 | 509 | pruned_input_ids = cand_tokens 510 | 511 | input_ids = torch.cat((input_ids, pruned_input_ids), dim=1) 512 | 513 | return DecoderOnlyDraftOutput( 514 | sequences=input_ids, 515 | past_key_values=past_key_values, 516 | cand_probs=tuple(cand_probs), 517 | ) 518 | 519 | def _forward_target_model( 520 | self, 521 | input_ids: torch.LongTensor, 522 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]], 523 | ): 524 | input_ids = input_ids.to(self.target_model_device) 525 | tree_attn_len = self.tree_attn_self_mask.size(0) 526 | init_input_length = input_ids.size(1) - tree_attn_len 527 | init_forward = False 528 | 529 | if past_key_values is not None: 530 | pruned_input_ids = input_ids[:, past_key_values[0][0].size(2) :] 531 | else: 532 | pruned_input_ids = input_ids 533 | init_forward = True 534 | 535 | if init_forward: 536 | tree_attn_mask = torch.zeros( 537 | (input_ids.size(1), input_ids.size(1)), 538 | dtype=torch.bool, 539 | device=self.target_model_device, 540 | ) 541 | mask_cond = torch.arange( 542 | tree_attn_mask.size(-1), device=self.target_model_device 543 | ) 
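            # The fill below builds a standard lower-triangular (causal) mask over the whole
            # prompt-plus-tree sequence; the bottom-right tree_attn_len x tree_attn_len block
            # is then overwritten with the tree mask, so each candidate token attends only to
            # the prompt and to its own ancestors. Position ids are recovered per row as
            # (number of visible positions - 1). For example, with k_config = (2, 2) the tree
            # holds 2 + 2*2 = 6 candidates [c0, c1, c00, c01, c10, c11], and
            # get_tree_attn_self_mask lets c00 attend to {c0, c00} and c11 to {c1, c11}.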
544 | tree_attn_mask.masked_fill_( 545 | mask_cond < (mask_cond + 1).view(tree_attn_mask.size(-1), 1), 1 546 | ) 547 | tree_attn_mask[-tree_attn_len:, -tree_attn_len:] = self.tree_attn_self_mask 548 | position_ids = tree_attn_mask.sum(dim=1) - 1 549 | 550 | else: 551 | tree_attn_mask = torch.ones( 552 | ( 553 | tree_attn_len + 1, 554 | input_ids.size(1), 555 | ), # there is one token not stored in the kv values 556 | dtype=torch.bool, 557 | device=self.target_model_device, 558 | ) 559 | 560 | tree_attn_mask[1:, init_input_length:] = self.tree_attn_self_mask 561 | tree_attn_mask[0, init_input_length:] = 0 562 | position_ids = tree_attn_mask.sum(dim=1) - 1 563 | 564 | outputs: BaseModelOutputWithPast = self.target_model.model( 565 | input_ids=pruned_input_ids, 566 | use_cache=True, 567 | past_key_values=past_key_values, 568 | return_dict=True, 569 | output_attentions=False, 570 | output_hidden_states=False, 571 | tree_attn_mask=tree_attn_mask, 572 | position_ids=position_ids, 573 | ) 574 | hidden_states = outputs.last_hidden_state 575 | past_key_values = list(outputs.past_key_values) 576 | 577 | logits = self.target_model.lm_head( 578 | hidden_states[:, -tree_attn_len - 1 :] 579 | ) # 1 x seq_len x hidden_dim 580 | return logits, past_key_values 581 | 582 | def verify( 583 | self, 584 | input_ids: torch.LongTensor, 585 | target_model_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]], 586 | draft_model_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]], 587 | cand_probs: Optional[Tuple[torch.FloatTensor]], 588 | ) -> DecoderOnlyVerificationOutput: 589 | input_ids = input_ids.to(self.target_model_device) 590 | logits, target_model_past_key_values = self._forward_target_model( 591 | input_ids, target_model_past_key_values 592 | ) 593 | logits = logits[0] # seq_len x hidden_dim 594 | tree_attn_len = self.tree_attn_self_mask.size(0) 595 | unverified_tokens = input_ids[0, -tree_attn_len:] 596 | init_input_length = input_ids.size(1) - tree_attn_len 597 | 598 | if self.target_model_temp == 0: 599 | _, topk_index = logits.topk(k=1, dim=-1) # seq_len x 1 600 | ground_probs = torch.zeros_like(logits) 601 | ground_probs.scatter_(dim=1, index=topk_index, value=1) 602 | else: 603 | ground_probs = torch.softmax(logits / self.target_model_temp, dim=-1) 604 | current_ground_prob = ground_probs[0] 605 | ground_probs = ground_probs[1:] 606 | 607 | keep_indices = list(range(init_input_length)) 608 | to_drop_len = 0 609 | idx_group_bias = 0 610 | cand_probs_idx = 0 611 | 612 | for depth in range(self.max_draft_len): 613 | idx_base = self.cumulative_prod_size[depth] + idx_group_bias 614 | accept_idx_bias = self.acceptance_check( 615 | current_ground_prob, 616 | cand_probs[depth][cand_probs_idx], 617 | unverified_tokens[idx_base : idx_base + self.k_config[depth]], 618 | ) 619 | if accept_idx_bias is not None: 620 | global_idx = idx_base + accept_idx_bias 621 | current_ground_prob = ground_probs[global_idx] 622 | keep_indices.append(init_input_length + global_idx) 623 | if depth == self.max_draft_len - 1: 624 | to_drop_len += 1 625 | depth = self.max_draft_len 626 | else: 627 | cand_probs_idx = idx_group_bias + accept_idx_bias 628 | idx_group_bias = cand_probs_idx * self.k_config[depth + 1] 629 | else: 630 | break 631 | 632 | keep_indices = torch.tensor( 633 | keep_indices, dtype=torch.long, device=self.target_model_device 634 | ) 635 | if to_drop_len != 0: 636 | draft_keep_indices = keep_indices[: len(keep_indices) - to_drop_len] 637 | else: 638 | draft_keep_indices = keep_indices 639 | 640 | 
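        # One final token is always sampled from current_ground_prob: either the target
        # distribution after the deepest accepted node, or the distribution left (adjusted
        # in place by acceptance_check) at the first rejected position. This guarantees that
        # every verification round emits at least one new token.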
tail_ground_token = torch.multinomial(current_ground_prob, num_samples=1).to( 641 | device=input_ids.device 642 | ) 643 | 644 | input_ids = input_ids.index_select(dim=1, index=keep_indices) 645 | input_ids = torch.cat((input_ids, tail_ground_token[None]), dim=1) 646 | 647 | for i in range(len(target_model_past_key_values)): 648 | keep_indices = keep_indices.to( 649 | device=target_model_past_key_values[i][0].device 650 | ) 651 | target_model_past_key_values[i] = ( 652 | target_model_past_key_values[i][0].index_select( 653 | dim=2, index=keep_indices 654 | ), 655 | target_model_past_key_values[i][1].index_select( 656 | dim=2, index=keep_indices 657 | ), 658 | ) 659 | for i in range(len(draft_model_past_key_values)): 660 | draft_model_past_key_values[i] = ( 661 | draft_model_past_key_values[i][0].index_select( 662 | dim=2, index=draft_keep_indices 663 | ), 664 | draft_model_past_key_values[i][1].index_select( 665 | dim=2, index=draft_keep_indices 666 | ), 667 | ) 668 | 669 | return DecoderOnlyVerificationOutput( 670 | sequences=input_ids, 671 | target_model_past_key_values=target_model_past_key_values, 672 | draft_model_past_key_values=draft_model_past_key_values, 673 | acceptance_count=depth, 674 | ) 675 | -------------------------------------------------------------------------------- /MCSD/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJUNLP/MCSD/8aadd6501a9e987ba5fca6cc8f9ad5949e480ec7/MCSD/model/__init__.py -------------------------------------------------------------------------------- /MCSD/model/llama_tree_attn/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from typing import TYPE_CHECKING 15 | 16 | from transformers.utils import ( 17 | OptionalDependencyNotAvailable, 18 | _LazyModule, 19 | is_sentencepiece_available, 20 | is_tokenizers_available, 21 | is_torch_available, 22 | ) 23 | 24 | 25 | _import_structure = { 26 | "configuration_llama": ["LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP", "LlamaConfig"], 27 | } 28 | 29 | try: 30 | if not is_sentencepiece_available(): 31 | raise OptionalDependencyNotAvailable() 32 | except OptionalDependencyNotAvailable: 33 | pass 34 | else: 35 | _import_structure["tokenization_llama"] = ["LlamaTokenizer"] 36 | 37 | try: 38 | if not is_tokenizers_available(): 39 | raise OptionalDependencyNotAvailable() 40 | except OptionalDependencyNotAvailable: 41 | pass 42 | else: 43 | _import_structure["tokenization_llama_fast"] = ["LlamaTokenizerFast"] 44 | 45 | try: 46 | if not is_torch_available(): 47 | raise OptionalDependencyNotAvailable() 48 | except OptionalDependencyNotAvailable: 49 | pass 50 | else: 51 | _import_structure["modeling_llama"] = [ 52 | "LlamaForCausalLM", 53 | "LlamaModel", 54 | "LlamaPreTrainedModel", 55 | "LlamaForSequenceClassification", 56 | ] 57 | 58 | 59 | if TYPE_CHECKING: 60 | from .configuration_llama import LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP, LlamaConfig 61 | 62 | try: 63 | if not is_sentencepiece_available(): 64 | raise OptionalDependencyNotAvailable() 65 | except OptionalDependencyNotAvailable: 66 | pass 67 | else: 68 | from .tokenization_llama import LlamaTokenizer 69 | 70 | try: 71 | if not is_tokenizers_available(): 72 | raise OptionalDependencyNotAvailable() 73 | except OptionalDependencyNotAvailable: 74 | pass 75 | else: 76 | from .tokenization_llama_fast import LlamaTokenizerFast 77 | 78 | try: 79 | if not is_torch_available(): 80 | raise OptionalDependencyNotAvailable() 81 | except OptionalDependencyNotAvailable: 82 | pass 83 | else: 84 | from .modeling_llama import ( 85 | LlamaForCausalLM, 86 | LlamaForSequenceClassification, 87 | LlamaModel, 88 | LlamaPreTrainedModel, 89 | ) 90 | 91 | 92 | else: 93 | import sys 94 | 95 | sys.modules[__name__] = _LazyModule( 96 | __name__, globals()["__file__"], _import_structure, module_spec=__spec__ 97 | ) 98 | -------------------------------------------------------------------------------- /MCSD/model/llama_tree_attn/configuration_llama.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 5 | # and OPT implementations in this library. It has been modified from its 6 | # original forms to accommodate minor architectural differences compared 7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 
20 | """ LLaMA model configuration""" 21 | 22 | from transformers.configuration_utils import PretrainedConfig 23 | from transformers.utils import logging 24 | 25 | 26 | logger = logging.get_logger(__name__) 27 | 28 | LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {} 29 | 30 | 31 | class LlamaConfig(PretrainedConfig): 32 | r""" 33 | This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA 34 | model according to the specified arguments, defining the model architecture. Instantiating a configuration with the 35 | defaults will yield a similar configuration to that of the LLaMA-7B. 36 | 37 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 38 | documentation from [`PretrainedConfig`] for more information. 39 | 40 | 41 | Args: 42 | vocab_size (`int`, *optional*, defaults to 32000): 43 | Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the 44 | `inputs_ids` passed when calling [`LlamaModel`] 45 | hidden_size (`int`, *optional*, defaults to 4096): 46 | Dimension of the hidden representations. 47 | intermediate_size (`int`, *optional*, defaults to 11008): 48 | Dimension of the MLP representations. 49 | num_hidden_layers (`int`, *optional*, defaults to 32): 50 | Number of hidden layers in the Transformer encoder. 51 | num_attention_heads (`int`, *optional*, defaults to 32): 52 | Number of attention heads for each attention layer in the Transformer encoder. 53 | num_key_value_heads (`int`, *optional*): 54 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If 55 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if 56 | `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When 57 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed 58 | by meanpooling all the original heads within that group. For more details checkout [this 59 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to 60 | `num_attention_heads`. 61 | pretraining_tp (`int`, *optional*, defaults to `1`): 62 | Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this 63 | document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is 64 | necessary to ensure exact reproducibility of the pretraining results. Please refer to [this 65 | issue](https://github.com/pytorch/pytorch/issues/76232). 66 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): 67 | The non-linear activation function (function or string) in the decoder. 68 | max_position_embeddings (`int`, *optional*, defaults to 2048): 69 | The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens, 70 | Llama 2 up to 4096, CodeLlama up to 16384. 71 | initializer_range (`float`, *optional*, defaults to 0.02): 72 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 73 | rms_norm_eps (`float`, *optional*, defaults to 1e-12): 74 | The epsilon used by the rms normalization layers. 75 | use_cache (`bool`, *optional*, defaults to `True`): 76 | Whether or not the model should return the last key/values attentions (not used by all models). Only 77 | relevant if `config.is_decoder=True`. 
78 | tie_word_embeddings(`bool`, *optional*, defaults to `False`): 79 | Whether to tie weight embeddings 80 | rope_theta (`float`, *optional*, defaults to 10000.0): 81 | The base period of the RoPE embeddings. 82 | rope_scaling (`Dict`, *optional*): 83 | Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling 84 | strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format 85 | is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update 86 | `max_position_embeddings` to the expected new maximum. See the following thread for more information on how 87 | these scaling strategies behave: 88 | https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an 89 | experimental feature, subject to breaking API changes in future versions. 90 | attention_bias (`bool`, defaults to `False`): 91 | Whether to use a bias in the query, key, value and output projection layers during self-attention. 92 | 93 | Example: 94 | 95 | ```python 96 | >>> from transformers import LlamaModel, LlamaConfig 97 | 98 | >>> # Initializing a LLaMA llama-7b style configuration 99 | >>> configuration = LlamaConfig() 100 | 101 | >>> # Initializing a model from the llama-7b style configuration 102 | >>> model = LlamaModel(configuration) 103 | 104 | >>> # Accessing the model configuration 105 | >>> configuration = model.config 106 | ```""" 107 | model_type = "llama" 108 | keys_to_ignore_at_inference = ["past_key_values"] 109 | 110 | def __init__( 111 | self, 112 | vocab_size=32000, 113 | hidden_size=4096, 114 | intermediate_size=11008, 115 | num_hidden_layers=32, 116 | num_attention_heads=32, 117 | num_key_value_heads=None, 118 | hidden_act="silu", 119 | max_position_embeddings=2048, 120 | initializer_range=0.02, 121 | rms_norm_eps=1e-6, 122 | use_cache=True, 123 | pad_token_id=None, 124 | bos_token_id=1, 125 | eos_token_id=2, 126 | pretraining_tp=1, 127 | tie_word_embeddings=False, 128 | rope_theta=10000.0, 129 | rope_scaling=None, 130 | attention_bias=False, 131 | **kwargs, 132 | ): 133 | self.vocab_size = vocab_size 134 | self.max_position_embeddings = max_position_embeddings 135 | self.hidden_size = hidden_size 136 | self.intermediate_size = intermediate_size 137 | self.num_hidden_layers = num_hidden_layers 138 | self.num_attention_heads = num_attention_heads 139 | 140 | # for backward compatibility 141 | if num_key_value_heads is None: 142 | num_key_value_heads = num_attention_heads 143 | 144 | self.num_key_value_heads = num_key_value_heads 145 | self.hidden_act = hidden_act 146 | self.initializer_range = initializer_range 147 | self.rms_norm_eps = rms_norm_eps 148 | self.pretraining_tp = pretraining_tp 149 | self.use_cache = use_cache 150 | self.rope_theta = rope_theta 151 | self.rope_scaling = rope_scaling 152 | self._rope_scaling_validation() 153 | self.attention_bias = attention_bias 154 | 155 | super().__init__( 156 | pad_token_id=pad_token_id, 157 | bos_token_id=bos_token_id, 158 | eos_token_id=eos_token_id, 159 | tie_word_embeddings=tie_word_embeddings, 160 | **kwargs, 161 | ) 162 | 163 | def _rope_scaling_validation(self): 164 | """ 165 | Validate the `rope_scaling` configuration. 
166 | """ 167 | if self.rope_scaling is None: 168 | return 169 | 170 | if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: 171 | raise ValueError( 172 | "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, " 173 | f"got {self.rope_scaling}" 174 | ) 175 | rope_scaling_type = self.rope_scaling.get("type", None) 176 | rope_scaling_factor = self.rope_scaling.get("factor", None) 177 | if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: 178 | raise ValueError( 179 | f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" 180 | ) 181 | if ( 182 | rope_scaling_factor is None 183 | or not isinstance(rope_scaling_factor, float) 184 | or rope_scaling_factor <= 1.0 185 | ): 186 | raise ValueError( 187 | f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}" 188 | ) 189 | -------------------------------------------------------------------------------- /MCSD/model/llama_tree_attn/convert_llama_weights_to_hf.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import argparse 15 | import gc 16 | import json 17 | import os 18 | import shutil 19 | import warnings 20 | 21 | import torch 22 | 23 | from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer 24 | 25 | 26 | try: 27 | from transformers import LlamaTokenizerFast 28 | except ImportError as e: 29 | warnings.warn(e) 30 | warnings.warn( 31 | "The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion" 32 | ) 33 | LlamaTokenizerFast = None 34 | 35 | """ 36 | Sample usage: 37 | 38 | ``` 39 | python src/transformers/models/llama/convert_llama_weights_to_hf.py \ 40 | --input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path 41 | ``` 42 | 43 | Thereafter, models can be loaded via: 44 | 45 | ```py 46 | from transformers import LlamaForCausalLM, LlamaTokenizer 47 | 48 | model = LlamaForCausalLM.from_pretrained("/output/path") 49 | tokenizer = LlamaTokenizer.from_pretrained("/output/path") 50 | ``` 51 | 52 | Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions 53 | come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM). 
54 | """ 55 | 56 | NUM_SHARDS = { 57 | "7B": 1, 58 | "7Bf": 1, 59 | "13B": 2, 60 | "13Bf": 2, 61 | "34B": 4, 62 | "30B": 4, 63 | "65B": 8, 64 | "70B": 8, 65 | "70Bf": 8, 66 | } 67 | 68 | 69 | def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256): 70 | return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of) 71 | 72 | 73 | def read_json(path): 74 | with open(path, "r") as f: 75 | return json.load(f) 76 | 77 | 78 | def write_json(text, path): 79 | with open(path, "w") as f: 80 | json.dump(text, f) 81 | 82 | 83 | def write_model(model_path, input_base_path, model_size, tokenizer_path=None, safe_serialization=True): 84 | # for backward compatibility, before you needed the repo to be called `my_repo/model_size` 85 | if not os.path.isfile(os.path.join(input_base_path, "params.json")): 86 | input_base_path = os.path.join(input_base_path, model_size) 87 | 88 | os.makedirs(model_path, exist_ok=True) 89 | tmp_model_path = os.path.join(model_path, "tmp") 90 | os.makedirs(tmp_model_path, exist_ok=True) 91 | 92 | params = read_json(os.path.join(input_base_path, "params.json")) 93 | num_shards = NUM_SHARDS[model_size] 94 | n_layers = params["n_layers"] 95 | n_heads = params["n_heads"] 96 | n_heads_per_shard = n_heads // num_shards 97 | dim = params["dim"] 98 | dims_per_head = dim // n_heads 99 | base = params.get("rope_theta", 10000.0) 100 | inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)) 101 | if base > 10000.0: 102 | max_position_embeddings = 16384 103 | else: 104 | max_position_embeddings = 2048 105 | 106 | tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast 107 | if tokenizer_path is not None: 108 | tokenizer = tokenizer_class(tokenizer_path) 109 | tokenizer.save_pretrained(model_path) 110 | vocab_size = tokenizer.vocab_size if tokenizer_path is not None else 32000 111 | 112 | if "n_kv_heads" in params: 113 | num_key_value_heads = params["n_kv_heads"] # for GQA / MQA 114 | num_local_key_value_heads = n_heads_per_shard // num_key_value_heads 115 | key_value_dim = dim // num_key_value_heads 116 | else: # compatibility with other checkpoints 117 | num_key_value_heads = n_heads 118 | num_local_key_value_heads = n_heads_per_shard 119 | key_value_dim = dim 120 | 121 | # permute for sliced rotary 122 | def permute(w, n_heads=n_heads, dim1=dim, dim2=dim): 123 | return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2) 124 | 125 | print(f"Fetching all parameters from the checkpoint at {input_base_path}.") 126 | # Load weights 127 | if model_size == "7B": 128 | # Not sharded 129 | # (The sharded implementation would also work, but this is simpler.) 
130 |         loaded = torch.load(os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu")
131 |     else:
132 |         # Sharded
133 |         loaded = [
134 |             torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu")
135 |             for i in range(num_shards)
136 |         ]
137 |     param_count = 0
138 |     index_dict = {"weight_map": {}}
139 |     for layer_i in range(n_layers):
140 |         filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
141 |         if model_size == "7B":
142 |             # Unsharded
143 |             state_dict = {
144 |                 f"model.layers.{layer_i}.self_attn.q_proj.weight": permute(
145 |                     loaded[f"layers.{layer_i}.attention.wq.weight"]
146 |                 ),
147 |                 f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
148 |                     loaded[f"layers.{layer_i}.attention.wk.weight"]
149 |                 ),
150 |                 f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"layers.{layer_i}.attention.wv.weight"],
151 |                 f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"layers.{layer_i}.attention.wo.weight"],
152 |                 f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w1.weight"],
153 |                 f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w2.weight"],
154 |                 f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w3.weight"],
155 |                 f"model.layers.{layer_i}.input_layernorm.weight": loaded[f"layers.{layer_i}.attention_norm.weight"],
156 |                 f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[f"layers.{layer_i}.ffn_norm.weight"],
157 |             }
158 |         else:
159 |             # Sharded
160 |             # Note that attention.w{q,k,v,o}, feed_forward.w[1,2,3], attention_norm.weight and ffn_norm.weight share
161 |             # the same storage object, so saving attention_norm and ffn_norm would save the other weights too, which is
162 |             # redundant as those weights will be stitched from multiple shards. To avoid that, they are cloned.
163 | 164 | state_dict = { 165 | f"model.layers.{layer_i}.input_layernorm.weight": loaded[0][ 166 | f"layers.{layer_i}.attention_norm.weight" 167 | ].clone(), 168 | f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[0][ 169 | f"layers.{layer_i}.ffn_norm.weight" 170 | ].clone(), 171 | } 172 | state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute( 173 | torch.cat( 174 | [ 175 | loaded[i][f"layers.{layer_i}.attention.wq.weight"].view(n_heads_per_shard, dims_per_head, dim) 176 | for i in range(num_shards) 177 | ], 178 | dim=0, 179 | ).reshape(dim, dim) 180 | ) 181 | state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute( 182 | torch.cat( 183 | [ 184 | loaded[i][f"layers.{layer_i}.attention.wk.weight"].view( 185 | num_local_key_value_heads, dims_per_head, dim 186 | ) 187 | for i in range(num_shards) 188 | ], 189 | dim=0, 190 | ).reshape(key_value_dim, dim), 191 | num_key_value_heads, 192 | key_value_dim, 193 | dim, 194 | ) 195 | state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat( 196 | [ 197 | loaded[i][f"layers.{layer_i}.attention.wv.weight"].view( 198 | num_local_key_value_heads, dims_per_head, dim 199 | ) 200 | for i in range(num_shards) 201 | ], 202 | dim=0, 203 | ).reshape(key_value_dim, dim) 204 | 205 | state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat( 206 | [loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=1 207 | ) 208 | state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat( 209 | [loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0 210 | ) 211 | state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat( 212 | [loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=1 213 | ) 214 | state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat( 215 | [loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0 216 | ) 217 | 218 | state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq 219 | for k, v in state_dict.items(): 220 | index_dict["weight_map"][k] = filename 221 | param_count += v.numel() 222 | torch.save(state_dict, os.path.join(tmp_model_path, filename)) 223 | 224 | filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin" 225 | if model_size == "7B": 226 | # Unsharded 227 | state_dict = { 228 | "model.embed_tokens.weight": loaded["tok_embeddings.weight"], 229 | "model.norm.weight": loaded["norm.weight"], 230 | "lm_head.weight": loaded["output.weight"], 231 | } 232 | else: 233 | state_dict = { 234 | "model.norm.weight": loaded[0]["norm.weight"], 235 | "model.embed_tokens.weight": torch.cat( 236 | [loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=1 237 | ), 238 | "lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0), 239 | } 240 | 241 | for k, v in state_dict.items(): 242 | index_dict["weight_map"][k] = filename 243 | param_count += v.numel() 244 | torch.save(state_dict, os.path.join(tmp_model_path, filename)) 245 | 246 | # Write configs 247 | index_dict["metadata"] = {"total_size": param_count * 2} 248 | write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json")) 249 | ffn_dim_multiplier = params["ffn_dim_multiplier"] if "ffn_dim_multiplier" in params else 1 250 | multiple_of = params["multiple_of"] if "multiple_of" in params else 256 251 | config = LlamaConfig( 252 | hidden_size=dim, 253 | 
intermediate_size=compute_intermediate_size(dim, ffn_dim_multiplier, multiple_of), 254 | num_attention_heads=params["n_heads"], 255 | num_hidden_layers=params["n_layers"], 256 | rms_norm_eps=params["norm_eps"], 257 | num_key_value_heads=num_key_value_heads, 258 | vocab_size=vocab_size, 259 | rope_theta=base, 260 | max_position_embeddings=max_position_embeddings, 261 | ) 262 | config.save_pretrained(tmp_model_path) 263 | 264 | # Make space so we can load the model properly now. 265 | del state_dict 266 | del loaded 267 | gc.collect() 268 | 269 | print("Loading the checkpoint in a Llama model.") 270 | model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True) 271 | # Avoid saving this as part of the config. 272 | del model.config._name_or_path 273 | model.config.torch_dtype = torch.float16 274 | print("Saving in the Transformers format.") 275 | model.save_pretrained(model_path, safe_serialization=safe_serialization) 276 | shutil.rmtree(tmp_model_path) 277 | 278 | 279 | def write_tokenizer(tokenizer_path, input_tokenizer_path): 280 | # Initialize the tokenizer based on the `spm` model 281 | tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast 282 | print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.") 283 | tokenizer = tokenizer_class(input_tokenizer_path) 284 | tokenizer.save_pretrained(tokenizer_path) 285 | 286 | 287 | def main(): 288 | parser = argparse.ArgumentParser() 289 | parser.add_argument( 290 | "--input_dir", 291 | help="Location of LLaMA weights, which contains tokenizer.model and model folders", 292 | ) 293 | parser.add_argument( 294 | "--model_size", 295 | choices=["7B", "7Bf", "13B", "13Bf", "30B", "34B", "65B", "70B", "70Bf", "tokenizer_only"], 296 | help="'f' models correspond to the finetuned versions, and are specific to the Llama2 official release. For more details on Llama2, checkout the original repo: https://huggingface.co/meta-llama", 297 | ) 298 | parser.add_argument( 299 | "--output_dir", 300 | help="Location to write HF model and tokenizer", 301 | ) 302 | parser.add_argument("--safe_serialization", type=bool, help="Whether or not to save using `safetensors`.") 303 | args = parser.parse_args() 304 | spm_path = os.path.join(args.input_dir, "tokenizer.model") 305 | if args.model_size != "tokenizer_only": 306 | write_model( 307 | model_path=args.output_dir, 308 | input_base_path=args.input_dir, 309 | model_size=args.model_size, 310 | safe_serialization=args.safe_serialization, 311 | tokenizer_path=spm_path, 312 | ) 313 | else: 314 | write_tokenizer(args.output_dir, spm_path) 315 | 316 | 317 | if __name__ == "__main__": 318 | main() 319 | -------------------------------------------------------------------------------- /MCSD/model/llama_tree_attn/modeling_llama.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 5 | # and OPT implementations in this library. It has been modified from its 6 | # original forms to accommodate minor architectural differences compared 7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 
11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 20 | """ PyTorch LLaMA model.""" 21 | import math 22 | from typing import List, Optional, Tuple, Union 23 | 24 | import torch 25 | import torch.nn.functional as F 26 | import torch.utils.checkpoint 27 | from torch import nn 28 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 29 | 30 | from transformers.activations import ACT2FN 31 | from transformers.modeling_outputs import ( 32 | BaseModelOutputWithPast, 33 | CausalLMOutputWithPast, 34 | SequenceClassifierOutputWithPast, 35 | ) 36 | from transformers.modeling_utils import PreTrainedModel 37 | from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS 38 | from transformers.utils import ( 39 | add_start_docstrings, 40 | add_start_docstrings_to_model_forward, 41 | is_flash_attn_available, 42 | logging, 43 | replace_return_docstrings, 44 | ) 45 | from .configuration_llama import LlamaConfig 46 | 47 | 48 | if is_flash_attn_available(): 49 | from flash_attn import flash_attn_func, flash_attn_varlen_func 50 | from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa 51 | 52 | 53 | logger = logging.get_logger(__name__) 54 | 55 | _CONFIG_FOR_DOC = "LlamaConfig" 56 | 57 | 58 | def _get_unpad_data(padding_mask): 59 | seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32) 60 | indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten() 61 | max_seqlen_in_batch = seqlens_in_batch.max().item() 62 | cu_seqlens = F.pad( 63 | torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0) 64 | ) 65 | return ( 66 | indices, 67 | cu_seqlens, 68 | max_seqlen_in_batch, 69 | ) 70 | 71 | 72 | # Copied from transformers.models.bart.modeling_bart._make_causal_mask 73 | def _make_causal_mask( 74 | input_ids_shape: torch.Size, 75 | dtype: torch.dtype, 76 | device: torch.device, 77 | past_key_values_length: int = 0, 78 | tree_attn_mask: Optional[torch.Tensor] = None, 79 | ): 80 | """ 81 | Make causal mask used for bi-directional self-attention. 82 | """ 83 | bsz, tgt_len = input_ids_shape 84 | 85 | if tree_attn_mask is not None: 86 | mask = torch.full_like( 87 | tree_attn_mask, 88 | torch.finfo(dtype).min, 89 | dtype=dtype, 90 | device=device, 91 | ) 92 | mask.masked_fill_(tree_attn_mask, 0) 93 | else: 94 | mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) 95 | mask_cond = torch.arange(mask.size(-1), device=device) 96 | mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) 97 | mask = mask.to(dtype) 98 | 99 | if past_key_values_length > 0: 100 | mask = torch.cat( 101 | [ 102 | torch.zeros( 103 | tgt_len, past_key_values_length, dtype=dtype, device=device 104 | ), 105 | mask, 106 | ], 107 | dim=-1, 108 | ) 109 | return mask[None, None, :, :].expand( 110 | bsz, 1, tgt_len, tgt_len + past_key_values_length 111 | ) 112 | 113 | 114 | # Copied from transformers.models.bart.modeling_bart._expand_mask 115 | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): 116 | """ 117 | Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 
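# Sketch of the tree-attention branch of `_make_causal_mask` above (standalone; the
# tiny tree below is illustrative): `tree_attn_mask` is a boolean (tgt_len, tgt_len)
# matrix whose entry (i, j) is True when draft token i may attend to token j, i.e.
# to itself or an ancestor in the candidate tree. True positions become 0 and all
# other positions the dtype minimum, so sibling branches are masked out of the
# softmax while each root-to-leaf path still sees a causal prefix.
import torch  # already imported at the top of this module

tree = torch.tensor(
    [
        [True, False, False],  # token 0: tree root, attends to itself only
        [True, True, False],   # token 1: first candidate child of the root
        [True, False, True],   # token 2: sibling candidate, does not see token 1
    ]
)
mask = _make_causal_mask((1, 3), torch.float32, device=torch.device("cpu"), tree_attn_mask=tree)
assert mask.shape == (1, 1, 3, 3)
assert torch.equal(mask[0, 0] == 0, tree)  # zeros exactly where attention is allowed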
118 | """ 119 | bsz, src_len = mask.size() 120 | tgt_len = tgt_len if tgt_len is not None else src_len 121 | 122 | expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) 123 | 124 | inverted_mask = 1.0 - expanded_mask 125 | 126 | return inverted_mask.masked_fill( 127 | inverted_mask.to(torch.bool), torch.finfo(dtype).min 128 | ) 129 | 130 | 131 | class LlamaRMSNorm(nn.Module): 132 | def __init__(self, hidden_size, eps=1e-6): 133 | """ 134 | LlamaRMSNorm is equivalent to T5LayerNorm 135 | """ 136 | super().__init__() 137 | self.weight = nn.Parameter(torch.ones(hidden_size)) 138 | self.variance_epsilon = eps 139 | 140 | def forward(self, hidden_states): 141 | input_dtype = hidden_states.dtype 142 | hidden_states = hidden_states.to(torch.float32) 143 | variance = hidden_states.pow(2).mean(-1, keepdim=True) 144 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) 145 | return self.weight * hidden_states.to(input_dtype) 146 | 147 | 148 | ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm) 149 | 150 | 151 | class LlamaRotaryEmbedding(nn.Module): 152 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): 153 | super().__init__() 154 | 155 | self.dim = dim 156 | self.max_position_embeddings = max_position_embeddings 157 | self.base = base 158 | inv_freq = 1.0 / ( 159 | self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) 160 | ) 161 | self.register_buffer("inv_freq", inv_freq, persistent=False) 162 | 163 | # Build here to make `torch.jit.trace` work. 164 | self._set_cos_sin_cache( 165 | seq_len=max_position_embeddings, 166 | device=self.inv_freq.device, 167 | dtype=torch.get_default_dtype(), 168 | ) 169 | 170 | def _set_cos_sin_cache(self, seq_len, device, dtype): 171 | self.max_seq_len_cached = seq_len 172 | t = torch.arange( 173 | self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype 174 | ) 175 | 176 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 177 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 178 | emb = torch.cat((freqs, freqs), dim=-1) 179 | self.register_buffer( 180 | "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False 181 | ) 182 | self.register_buffer( 183 | "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False 184 | ) 185 | 186 | def forward(self, x, seq_len=None): 187 | # x: [bs, num_attention_heads, seq_len, head_size] 188 | if seq_len > self.max_seq_len_cached: 189 | self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) 190 | 191 | return ( 192 | self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), 193 | self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), 194 | ) 195 | 196 | 197 | class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): 198 | """LlamaRotaryEmbedding extended with linear scaling. 
Credits to the Reddit user /u/kaiokendev""" 199 | 200 | def __init__( 201 | self, 202 | dim, 203 | max_position_embeddings=2048, 204 | base=10000, 205 | device=None, 206 | scaling_factor=1.0, 207 | ): 208 | self.scaling_factor = scaling_factor 209 | super().__init__(dim, max_position_embeddings, base, device) 210 | 211 | def _set_cos_sin_cache(self, seq_len, device, dtype): 212 | self.max_seq_len_cached = seq_len 213 | t = torch.arange( 214 | self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype 215 | ) 216 | t = t / self.scaling_factor 217 | 218 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 219 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 220 | emb = torch.cat((freqs, freqs), dim=-1) 221 | self.register_buffer( 222 | "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False 223 | ) 224 | self.register_buffer( 225 | "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False 226 | ) 227 | 228 | 229 | class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): 230 | """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" 231 | 232 | def __init__( 233 | self, 234 | dim, 235 | max_position_embeddings=2048, 236 | base=10000, 237 | device=None, 238 | scaling_factor=1.0, 239 | ): 240 | self.scaling_factor = scaling_factor 241 | super().__init__(dim, max_position_embeddings, base, device) 242 | 243 | def _set_cos_sin_cache(self, seq_len, device, dtype): 244 | self.max_seq_len_cached = seq_len 245 | 246 | if seq_len > self.max_position_embeddings: 247 | base = self.base * ( 248 | (self.scaling_factor * seq_len / self.max_position_embeddings) 249 | - (self.scaling_factor - 1) 250 | ) ** (self.dim / (self.dim - 2)) 251 | inv_freq = 1.0 / ( 252 | base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) 253 | ) 254 | self.register_buffer("inv_freq", inv_freq, persistent=False) 255 | 256 | t = torch.arange( 257 | self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype 258 | ) 259 | 260 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 261 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 262 | emb = torch.cat((freqs, freqs), dim=-1) 263 | self.register_buffer( 264 | "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False 265 | ) 266 | self.register_buffer( 267 | "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False 268 | ) 269 | 270 | 271 | def rotate_half(x): 272 | """Rotates half the hidden dims of the input.""" 273 | x1 = x[..., : x.shape[-1] // 2] 274 | x2 = x[..., x.shape[-1] // 2 :] 275 | return torch.cat((-x2, x1), dim=-1) 276 | 277 | 278 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids): 279 | # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. 
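# (`cos_cached` / `sin_cached` are registered above with shape [1, 1, seq_len, dim],
# so the two leading singleton dimensions carry no information. Squeezing them lets
# the `position_ids` gather below produce [bs, 1, seq_len, dim] tensors that
# broadcast against q and k of shape [bs, num_heads, seq_len, head_dim].)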
280 | cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] 281 | sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] 282 | cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] 283 | sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] 284 | q_embed = (q * cos) + (rotate_half(q) * sin) 285 | k_embed = (k * cos) + (rotate_half(k) * sin) 286 | return q_embed, k_embed 287 | 288 | 289 | class LlamaMLP(nn.Module): 290 | def __init__(self, config): 291 | super().__init__() 292 | self.config = config 293 | self.hidden_size = config.hidden_size 294 | self.intermediate_size = config.intermediate_size 295 | self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) 296 | self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) 297 | self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) 298 | self.act_fn = ACT2FN[config.hidden_act] 299 | 300 | def forward(self, x): 301 | if self.config.pretraining_tp > 1: 302 | slice = self.intermediate_size // self.config.pretraining_tp 303 | gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) 304 | up_proj_slices = self.up_proj.weight.split(slice, dim=0) 305 | down_proj_slices = self.down_proj.weight.split(slice, dim=1) 306 | 307 | gate_proj = torch.cat( 308 | [ 309 | F.linear(x, gate_proj_slices[i]) 310 | for i in range(self.config.pretraining_tp) 311 | ], 312 | dim=-1, 313 | ) 314 | up_proj = torch.cat( 315 | [ 316 | F.linear(x, up_proj_slices[i]) 317 | for i in range(self.config.pretraining_tp) 318 | ], 319 | dim=-1, 320 | ) 321 | 322 | intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) 323 | down_proj = [ 324 | F.linear(intermediate_states[i], down_proj_slices[i]) 325 | for i in range(self.config.pretraining_tp) 326 | ] 327 | down_proj = sum(down_proj) 328 | else: 329 | down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) 330 | 331 | return down_proj 332 | 333 | 334 | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: 335 | """ 336 | This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, 337 | num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) 338 | """ 339 | batch, num_key_value_heads, slen, head_dim = hidden_states.shape 340 | if n_rep == 1: 341 | return hidden_states 342 | hidden_states = hidden_states[:, :, None, :, :].expand( 343 | batch, num_key_value_heads, n_rep, slen, head_dim 344 | ) 345 | return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) 346 | 347 | 348 | class LlamaAttention(nn.Module): 349 | """Multi-headed attention from 'Attention Is All You Need' paper""" 350 | 351 | def __init__(self, config: LlamaConfig): 352 | super().__init__() 353 | self.config = config 354 | self.hidden_size = config.hidden_size 355 | self.num_heads = config.num_attention_heads 356 | self.head_dim = self.hidden_size // self.num_heads 357 | self.num_key_value_heads = config.num_key_value_heads 358 | self.num_key_value_groups = self.num_heads // self.num_key_value_heads 359 | self.max_position_embeddings = config.max_position_embeddings 360 | self.rope_theta = config.rope_theta 361 | 362 | if (self.head_dim * self.num_heads) != self.hidden_size: 363 | raise ValueError( 364 | f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" 365 | f" and `num_heads`: {self.num_heads})." 
366 | ) 367 | self.q_proj = nn.Linear( 368 | self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias 369 | ) 370 | self.k_proj = nn.Linear( 371 | self.hidden_size, 372 | self.num_key_value_heads * self.head_dim, 373 | bias=config.attention_bias, 374 | ) 375 | self.v_proj = nn.Linear( 376 | self.hidden_size, 377 | self.num_key_value_heads * self.head_dim, 378 | bias=config.attention_bias, 379 | ) 380 | self.o_proj = nn.Linear( 381 | self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias 382 | ) 383 | self._init_rope() 384 | 385 | def _init_rope(self): 386 | if self.config.rope_scaling is None: 387 | self.rotary_emb = LlamaRotaryEmbedding( 388 | self.head_dim, 389 | max_position_embeddings=self.max_position_embeddings, 390 | base=self.rope_theta, 391 | ) 392 | else: 393 | scaling_type = self.config.rope_scaling["type"] 394 | scaling_factor = self.config.rope_scaling["factor"] 395 | if scaling_type == "linear": 396 | self.rotary_emb = LlamaLinearScalingRotaryEmbedding( 397 | self.head_dim, 398 | max_position_embeddings=self.max_position_embeddings, 399 | scaling_factor=scaling_factor, 400 | base=self.rope_theta, 401 | ) 402 | elif scaling_type == "dynamic": 403 | self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( 404 | self.head_dim, 405 | max_position_embeddings=self.max_position_embeddings, 406 | scaling_factor=scaling_factor, 407 | base=self.rope_theta, 408 | ) 409 | else: 410 | raise ValueError(f"Unknown RoPE scaling type {scaling_type}") 411 | 412 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 413 | return ( 414 | tensor.view(bsz, seq_len, self.num_heads, self.head_dim) 415 | .transpose(1, 2) 416 | .contiguous() 417 | ) 418 | 419 | def forward( 420 | self, 421 | hidden_states: torch.Tensor, 422 | attention_mask: Optional[torch.Tensor] = None, 423 | position_ids: Optional[torch.LongTensor] = None, 424 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 425 | output_attentions: bool = False, 426 | use_cache: bool = False, 427 | padding_mask: Optional[torch.LongTensor] = None, 428 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 429 | bsz, q_len, _ = hidden_states.size() 430 | 431 | if self.config.pretraining_tp > 1: 432 | key_value_slicing = ( 433 | self.num_key_value_heads * self.head_dim 434 | ) // self.config.pretraining_tp 435 | query_slices = self.q_proj.weight.split( 436 | (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 437 | ) 438 | key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) 439 | value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) 440 | 441 | query_states = [ 442 | F.linear(hidden_states, query_slices[i]) 443 | for i in range(self.config.pretraining_tp) 444 | ] 445 | query_states = torch.cat(query_states, dim=-1) 446 | 447 | key_states = [ 448 | F.linear(hidden_states, key_slices[i]) 449 | for i in range(self.config.pretraining_tp) 450 | ] 451 | key_states = torch.cat(key_states, dim=-1) 452 | 453 | value_states = [ 454 | F.linear(hidden_states, value_slices[i]) 455 | for i in range(self.config.pretraining_tp) 456 | ] 457 | value_states = torch.cat(value_states, dim=-1) 458 | 459 | else: 460 | query_states = self.q_proj(hidden_states) 461 | key_states = self.k_proj(hidden_states) 462 | value_states = self.v_proj(hidden_states) 463 | 464 | query_states = query_states.view( 465 | bsz, q_len, self.num_heads, self.head_dim 466 | ).transpose(1, 2) 467 | key_states = key_states.view( 468 | bsz, q_len, self.num_key_value_heads, 
self.head_dim 469 | ).transpose(1, 2) 470 | value_states = value_states.view( 471 | bsz, q_len, self.num_key_value_heads, self.head_dim 472 | ).transpose(1, 2) 473 | 474 | kv_seq_len = key_states.shape[-2] 475 | if past_key_value is not None: 476 | kv_seq_len += past_key_value[0].shape[-2] 477 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 478 | query_states, key_states = apply_rotary_pos_emb( 479 | query_states, key_states, cos, sin, position_ids 480 | ) 481 | 482 | if past_key_value is not None: 483 | # reuse k, v, self_attention 484 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 485 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 486 | 487 | past_key_value = (key_states, value_states) if use_cache else None 488 | 489 | key_states = repeat_kv(key_states, self.num_key_value_groups) 490 | value_states = repeat_kv(value_states, self.num_key_value_groups) 491 | 492 | attn_weights = torch.matmul( 493 | query_states, key_states.transpose(2, 3) 494 | ) / math.sqrt(self.head_dim) 495 | 496 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 497 | raise ValueError( 498 | f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" 499 | f" {attn_weights.size()}" 500 | ) 501 | 502 | if attention_mask is not None: 503 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 504 | raise ValueError( 505 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 506 | ) 507 | attn_weights = attn_weights + attention_mask 508 | 509 | # upcast attention to fp32 510 | attn_weights = nn.functional.softmax( 511 | attn_weights, dim=-1, dtype=torch.float32 512 | ).to(query_states.dtype) 513 | attn_output = torch.matmul(attn_weights, value_states) 514 | 515 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 516 | raise ValueError( 517 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 518 | f" {attn_output.size()}" 519 | ) 520 | 521 | attn_output = attn_output.transpose(1, 2).contiguous() 522 | 523 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 524 | 525 | if self.config.pretraining_tp > 1: 526 | attn_output = attn_output.split( 527 | self.hidden_size // self.config.pretraining_tp, dim=2 528 | ) 529 | o_proj_slices = self.o_proj.weight.split( 530 | self.hidden_size // self.config.pretraining_tp, dim=1 531 | ) 532 | attn_output = sum( 533 | [ 534 | F.linear(attn_output[i], o_proj_slices[i]) 535 | for i in range(self.config.pretraining_tp) 536 | ] 537 | ) 538 | else: 539 | attn_output = self.o_proj(attn_output) 540 | 541 | if not output_attentions: 542 | attn_weights = None 543 | 544 | return attn_output, attn_weights, past_key_value 545 | 546 | 547 | class LlamaFlashAttention2(LlamaAttention): 548 | """ 549 | Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays 550 | untouched. The only required change would be on the forward pass where it needs to correctly call the public API of 551 | flash attention and deal with padding tokens in case the input contains any of them. 
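# Usage sketch for the eager attention path above (standalone; the small dimensions
# are illustrative): a prompt pass returns the (key, value) cache, and a follow-up
# single-token pass reuses it, so the cached length grows while q_len stays 1.
cfg = LlamaConfig(hidden_size=64, num_attention_heads=4, num_key_value_heads=4,
                  num_hidden_layers=1, intermediate_size=128, vocab_size=128)
attn = LlamaAttention(cfg)

prompt = torch.randn(1, 5, 64)
out, _, cache = attn(prompt, position_ids=torch.arange(5).unsqueeze(0), use_cache=True)
assert out.shape == (1, 5, 64)
assert cache[0].shape == (1, 4, 5, 16)  # (bsz, num_key_value_heads, seq_len, head_dim)

nxt = torch.randn(1, 1, 64)
out, _, cache = attn(nxt, position_ids=torch.tensor([[5]]),
                     past_key_value=cache, use_cache=True)
assert out.shape == (1, 1, 64)
assert cache[0].shape == (1, 4, 6, 16)  # cache extended by one position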
552 | """ 553 | 554 | def forward( 555 | self, 556 | hidden_states: torch.Tensor, 557 | attention_mask: Optional[torch.Tensor] = None, 558 | position_ids: Optional[torch.LongTensor] = None, 559 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 560 | output_attentions: bool = False, 561 | use_cache: bool = False, 562 | padding_mask: Optional[torch.LongTensor] = None, 563 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 564 | # LlamaFlashAttention2 attention does not support output_attentions 565 | output_attentions = False 566 | 567 | bsz, q_len, _ = hidden_states.size() 568 | 569 | query_states = self.q_proj(hidden_states) 570 | key_states = self.k_proj(hidden_states) 571 | value_states = self.v_proj(hidden_states) 572 | 573 | # Flash attention requires the input to have the shape 574 | # batch_size x seq_length x head_dime x hidden_dim 575 | # therefore we just need to keep the original shape 576 | query_states = query_states.view( 577 | bsz, q_len, self.num_heads, self.head_dim 578 | ).transpose(1, 2) 579 | key_states = key_states.view( 580 | bsz, q_len, self.num_key_value_heads, self.head_dim 581 | ).transpose(1, 2) 582 | value_states = value_states.view( 583 | bsz, q_len, self.num_key_value_heads, self.head_dim 584 | ).transpose(1, 2) 585 | 586 | kv_seq_len = key_states.shape[-2] 587 | if past_key_value is not None: 588 | kv_seq_len += past_key_value[0].shape[-2] 589 | 590 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 591 | 592 | query_states, key_states = apply_rotary_pos_emb( 593 | query_states, key_states, cos, sin, position_ids 594 | ) 595 | 596 | if past_key_value is not None: 597 | # reuse k, v, self_attention 598 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 599 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 600 | 601 | past_key_value = (key_states, value_states) if use_cache else None 602 | 603 | query_states = query_states.transpose(1, 2) 604 | key_states = key_states.transpose(1, 2) 605 | value_states = value_states.transpose(1, 2) 606 | 607 | # TODO: llama does not have dropout in the config?? 608 | # It is recommended to use dropout with FA according to the docs 609 | # when training. 610 | dropout_rate = 0.0 # if not self.training else self.attn_dropout 611 | 612 | # In PEFT, usually we cast the layer norms in float32 for training stability reasons 613 | # therefore the input hidden states gets silently casted in float32. Hence, we need 614 | # cast them back in float16 just to be sure everything works as expected. 615 | # This might slowdown training & inference so it is recommended to not cast the LayerNorms 616 | # in fp32. (LlamaRMSNorm handles it correctly) 617 | input_dtype = query_states.dtype 618 | if input_dtype == torch.float32: 619 | logger.warning_once( 620 | "The input hidden states seems to be silently casted in float32, this might be related to" 621 | " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" 622 | " float16." 
623 | ) 624 | 625 | query_states = query_states.to(torch.float16) 626 | key_states = key_states.to(torch.float16) 627 | value_states = value_states.to(torch.float16) 628 | 629 | attn_output = self._flash_attention_forward( 630 | query_states, 631 | key_states, 632 | value_states, 633 | padding_mask, 634 | q_len, 635 | dropout=dropout_rate, 636 | ) 637 | 638 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() 639 | attn_output = self.o_proj(attn_output) 640 | 641 | if not output_attentions: 642 | attn_weights = None 643 | 644 | return attn_output, attn_weights, past_key_value 645 | 646 | def _flash_attention_forward( 647 | self, 648 | query_states, 649 | key_states, 650 | value_states, 651 | padding_mask, 652 | query_length, 653 | dropout=0.0, 654 | softmax_scale=None, 655 | ): 656 | """ 657 | Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token 658 | first unpad the input, then computes the attention scores and pad the final attention scores. 659 | 660 | Args: 661 | query_states (`torch.Tensor`): 662 | Input query states to be passed to Flash Attention API 663 | key_states (`torch.Tensor`): 664 | Input key states to be passed to Flash Attention API 665 | value_states (`torch.Tensor`): 666 | Input value states to be passed to Flash Attention API 667 | padding_mask (`torch.Tensor`): 668 | The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the 669 | position of padding tokens and 1 for the position of non-padding tokens. 670 | dropout (`int`, *optional*): 671 | Attention dropout 672 | softmax_scale (`float`, *optional*): 673 | The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) 674 | """ 675 | # Contains at least one padding token in the sequence 676 | if padding_mask is not None: 677 | batch_size = query_states.shape[0] 678 | ( 679 | query_states, 680 | key_states, 681 | value_states, 682 | indices_q, 683 | cu_seq_lens, 684 | max_seq_lens, 685 | ) = self._upad_input( 686 | query_states, key_states, value_states, padding_mask, query_length 687 | ) 688 | 689 | cu_seqlens_q, cu_seqlens_k = cu_seq_lens 690 | max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens 691 | 692 | attn_output_unpad = flash_attn_varlen_func( 693 | query_states, 694 | key_states, 695 | value_states, 696 | cu_seqlens_q=cu_seqlens_q, 697 | cu_seqlens_k=cu_seqlens_k, 698 | max_seqlen_q=max_seqlen_in_batch_q, 699 | max_seqlen_k=max_seqlen_in_batch_k, 700 | dropout_p=dropout, 701 | softmax_scale=softmax_scale, 702 | causal=True, 703 | ) 704 | 705 | attn_output = pad_input( 706 | attn_output_unpad, indices_q, batch_size, query_length 707 | ) 708 | else: 709 | attn_output = flash_attn_func( 710 | query_states, 711 | key_states, 712 | value_states, 713 | dropout, 714 | softmax_scale=softmax_scale, 715 | causal=True, 716 | ) 717 | 718 | return attn_output 719 | 720 | def _upad_input( 721 | self, query_layer, key_layer, value_layer, padding_mask, query_length 722 | ): 723 | indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask) 724 | batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape 725 | 726 | key_layer = index_first_axis( 727 | key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), 728 | indices_k, 729 | ) 730 | value_layer = index_first_axis( 731 | value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), 732 | indices_k, 733 | ) 734 | if query_length == kv_seq_len: 735 | query_layer = 
index_first_axis( 736 | query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), 737 | indices_k, 738 | ) 739 | cu_seqlens_q = cu_seqlens_k 740 | max_seqlen_in_batch_q = max_seqlen_in_batch_k 741 | indices_q = indices_k 742 | elif query_length == 1: 743 | max_seqlen_in_batch_q = 1 744 | cu_seqlens_q = torch.arange( 745 | batch_size + 1, dtype=torch.int32, device=query_layer.device 746 | ) # There is a memcpy here, that is very bad. 747 | indices_q = cu_seqlens_q[:-1] 748 | query_layer = query_layer.squeeze(1) 749 | else: 750 | # The -q_len: slice assumes left padding. 751 | padding_mask = padding_mask[:, -query_length:] 752 | query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( 753 | query_layer, padding_mask 754 | ) 755 | 756 | return ( 757 | query_layer, 758 | key_layer, 759 | value_layer, 760 | indices_q, 761 | (cu_seqlens_q, cu_seqlens_k), 762 | (max_seqlen_in_batch_q, max_seqlen_in_batch_k), 763 | ) 764 | 765 | 766 | class LlamaDecoderLayer(nn.Module): 767 | def __init__(self, config: LlamaConfig): 768 | super().__init__() 769 | self.hidden_size = config.hidden_size 770 | self.self_attn = ( 771 | LlamaAttention(config=config) 772 | if not getattr(config, "_flash_attn_2_enabled", False) 773 | else LlamaFlashAttention2(config=config) 774 | ) 775 | self.mlp = LlamaMLP(config) 776 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 777 | self.post_attention_layernorm = LlamaRMSNorm( 778 | config.hidden_size, eps=config.rms_norm_eps 779 | ) 780 | 781 | def forward( 782 | self, 783 | hidden_states: torch.Tensor, 784 | attention_mask: Optional[torch.Tensor] = None, 785 | position_ids: Optional[torch.LongTensor] = None, 786 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 787 | output_attentions: Optional[bool] = False, 788 | use_cache: Optional[bool] = False, 789 | padding_mask: Optional[torch.LongTensor] = None, 790 | ) -> Tuple[ 791 | torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] 792 | ]: 793 | """ 794 | Args: 795 | hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` 796 | attention_mask (`torch.FloatTensor`, *optional*): attention mask of size 797 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 798 | output_attentions (`bool`, *optional*): 799 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 800 | returned tensors for more detail. 801 | use_cache (`bool`, *optional*): 802 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding 803 | (see `past_key_values`). 
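# Usage sketch for a single decoder layer (standalone; small illustrative dims):
# both sub-blocks are pre-norm residual, so the output keeps the input's shape.
cfg = LlamaConfig(hidden_size=64, num_attention_heads=4, num_key_value_heads=4,
                  num_hidden_layers=1, intermediate_size=128, vocab_size=128)
layer = LlamaDecoderLayer(cfg)
hidden = torch.randn(1, 7, 64)
causal = _make_causal_mask((1, 7), hidden.dtype, device=hidden.device)
(hidden_out,) = layer(hidden, attention_mask=causal,
                      position_ids=torch.arange(7).unsqueeze(0))
assert hidden_out.shape == hidden.shape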
804 | past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states 805 | """ 806 | 807 | residual = hidden_states 808 | 809 | hidden_states = self.input_layernorm(hidden_states) 810 | 811 | # Self Attention 812 | hidden_states, self_attn_weights, present_key_value = self.self_attn( 813 | hidden_states=hidden_states, 814 | attention_mask=attention_mask, 815 | position_ids=position_ids, 816 | past_key_value=past_key_value, 817 | output_attentions=output_attentions, 818 | use_cache=use_cache, 819 | padding_mask=padding_mask, 820 | ) 821 | hidden_states = residual + hidden_states 822 | 823 | # Fully Connected 824 | residual = hidden_states 825 | hidden_states = self.post_attention_layernorm(hidden_states) 826 | hidden_states = self.mlp(hidden_states) 827 | hidden_states = residual + hidden_states 828 | 829 | outputs = (hidden_states,) 830 | 831 | if output_attentions: 832 | outputs += (self_attn_weights,) 833 | 834 | if use_cache: 835 | outputs += (present_key_value,) 836 | 837 | return outputs 838 | 839 | 840 | LLAMA_START_DOCSTRING = r""" 841 | This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the 842 | library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads 843 | etc.) 844 | 845 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 846 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage 847 | and behavior. 848 | 849 | Parameters: 850 | config ([`LlamaConfig`]): 851 | Model configuration class with all the parameters of the model. Initializing with a config file does not 852 | load the weights associated with the model, only the configuration. Check out the 853 | [`~PreTrainedModel.from_pretrained`] method to load the model weights. 854 | """ 855 | 856 | 857 | @add_start_docstrings( 858 | "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", 859 | LLAMA_START_DOCSTRING, 860 | ) 861 | class LlamaPreTrainedModel(PreTrainedModel): 862 | config_class = LlamaConfig 863 | base_model_prefix = "model" 864 | supports_gradient_checkpointing = True 865 | _no_split_modules = ["LlamaDecoderLayer"] 866 | _skip_keys_device_placement = "past_key_values" 867 | _supports_flash_attn_2 = True 868 | 869 | def _init_weights(self, module): 870 | std = self.config.initializer_range 871 | if isinstance(module, nn.Linear): 872 | module.weight.data.normal_(mean=0.0, std=std) 873 | if module.bias is not None: 874 | module.bias.data.zero_() 875 | elif isinstance(module, nn.Embedding): 876 | module.weight.data.normal_(mean=0.0, std=std) 877 | if module.padding_idx is not None: 878 | module.weight.data[module.padding_idx].zero_() 879 | 880 | def _set_gradient_checkpointing(self, module, value=False): 881 | if isinstance(module, LlamaModel): 882 | module.gradient_checkpointing = value 883 | 884 | 885 | LLAMA_INPUTS_DOCSTRING = r""" 886 | Args: 887 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 888 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide 889 | it. 890 | 891 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 892 | [`PreTrainedTokenizer.__call__`] for details. 
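# Sketch of how `input_ids` are usually produced (the tokenizer path below is a
# placeholder, not a real checkpoint name):
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("PATH_TO_TOKENIZER")
batch = tokenizer(["Hello world"], return_tensors="pt")
input_ids, attention_mask = batch.input_ids, batch.attention_mask  # both (batch_size, sequence_length)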
893 | 894 | [What are input IDs?](../glossary#input-ids) 895 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 896 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 897 | 898 | - 1 for tokens that are **not masked**, 899 | - 0 for tokens that are **masked**. 900 | 901 | [What are attention masks?](../glossary#attention-mask) 902 | 903 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 904 | [`PreTrainedTokenizer.__call__`] for details. 905 | 906 | If `past_key_values` is used, optionally only the last `input_ids` have to be input (see 907 | `past_key_values`). 908 | 909 | If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] 910 | and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more 911 | information on the default strategy. 912 | 913 | - 1 indicates the head is **not masked**, 914 | - 0 indicates the head is **masked**. 915 | position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 916 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, 917 | config.n_positions - 1]`. 918 | 919 | [What are position IDs?](../glossary#position-ids) 920 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): 921 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape 922 | `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape 923 | `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. 924 | 925 | Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention 926 | blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 927 | 928 | If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't 929 | have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` 930 | of shape `(batch_size, sequence_length)`. 931 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 932 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This 933 | is useful if you want more control over how to convert `input_ids` indices into associated vectors than the 934 | model's internal embedding lookup matrix. 935 | use_cache (`bool`, *optional*): 936 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see 937 | `past_key_values`). 938 | output_attentions (`bool`, *optional*): 939 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 940 | tensors for more detail. 941 | output_hidden_states (`bool`, *optional*): 942 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 943 | more detail. 944 | return_dict (`bool`, *optional*): 945 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
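# Sketch of the `past_key_values` contract described above (uses the `LlamaModel`
# defined just below; small illustrative dims): after a full prompt pass, only the
# newest token needs to be fed back in, together with the returned cache.
cfg = LlamaConfig(hidden_size=64, num_attention_heads=4, num_key_value_heads=4,
                  num_hidden_layers=2, intermediate_size=128, vocab_size=128)
model = LlamaModel(cfg)

prompt_ids = torch.randint(0, 128, (1, 6))
out = model(prompt_ids, use_cache=True)

next_id = torch.randint(0, 128, (1, 1))
step = model(next_id, past_key_values=out.past_key_values, use_cache=True)
assert step.last_hidden_state.shape == (1, 1, 64)
assert step.past_key_values[0][0].shape[-2] == 7  # cache now covers prompt + new token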
946 | """ 947 | 948 | 949 | @add_start_docstrings( 950 | "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", 951 | LLAMA_START_DOCSTRING, 952 | ) 953 | class LlamaModel(LlamaPreTrainedModel): 954 | """ 955 | Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] 956 | 957 | Args: 958 | config: LlamaConfig 959 | """ 960 | 961 | def __init__(self, config: LlamaConfig): 962 | super().__init__(config) 963 | self.padding_idx = config.pad_token_id 964 | self.vocab_size = config.vocab_size 965 | 966 | self.embed_tokens = nn.Embedding( 967 | config.vocab_size, config.hidden_size, self.padding_idx 968 | ) 969 | self.layers = nn.ModuleList( 970 | [LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)] 971 | ) 972 | self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 973 | 974 | self.gradient_checkpointing = False 975 | # Initialize weights and apply final processing 976 | self.post_init() 977 | 978 | def get_input_embeddings(self): 979 | return self.embed_tokens 980 | 981 | def set_input_embeddings(self, value): 982 | self.embed_tokens = value 983 | 984 | # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask 985 | def _prepare_decoder_attention_mask( 986 | self, 987 | attention_mask, 988 | input_shape, 989 | inputs_embeds, 990 | past_key_values_length, 991 | tree_attn_mask=None, 992 | ): 993 | # create causal mask 994 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 995 | combined_attention_mask = None 996 | if input_shape[-1] > 1: 997 | combined_attention_mask = _make_causal_mask( 998 | input_shape, 999 | inputs_embeds.dtype, 1000 | device=inputs_embeds.device, 1001 | past_key_values_length=past_key_values_length, 1002 | tree_attn_mask=tree_attn_mask, 1003 | ) 1004 | 1005 | if attention_mask is not None: 1006 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 1007 | expanded_attn_mask = _expand_mask( 1008 | attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] 1009 | ).to(inputs_embeds.device) 1010 | combined_attention_mask = ( 1011 | expanded_attn_mask 1012 | if combined_attention_mask is None 1013 | else expanded_attn_mask + combined_attention_mask 1014 | ) 1015 | 1016 | return combined_attention_mask 1017 | 1018 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) 1019 | def forward( 1020 | self, 1021 | input_ids: torch.LongTensor = None, 1022 | attention_mask: Optional[torch.Tensor] = None, 1023 | position_ids: Optional[torch.LongTensor] = None, 1024 | past_key_values: Optional[List[torch.FloatTensor]] = None, 1025 | inputs_embeds: Optional[torch.FloatTensor] = None, 1026 | use_cache: Optional[bool] = None, 1027 | output_attentions: Optional[bool] = None, 1028 | output_hidden_states: Optional[bool] = None, 1029 | return_dict: Optional[bool] = None, 1030 | tree_attn_mask: Optional[torch.Tensor] = None, 1031 | ) -> Union[Tuple, BaseModelOutputWithPast]: 1032 | output_attentions = ( 1033 | output_attentions 1034 | if output_attentions is not None 1035 | else self.config.output_attentions 1036 | ) 1037 | output_hidden_states = ( 1038 | output_hidden_states 1039 | if output_hidden_states is not None 1040 | else self.config.output_hidden_states 1041 | ) 1042 | use_cache = use_cache if use_cache is not None else self.config.use_cache 1043 | 1044 | return_dict = ( 1045 | return_dict if return_dict is not None else self.config.use_return_dict 1046 | ) 1047 | 1048 | # retrieve input_ids and inputs_embeds 1049 | if 
input_ids is not None and inputs_embeds is not None: 1050 | raise ValueError( 1051 | "You cannot specify both input_ids and inputs_embeds at the same time" 1052 | ) 1053 | elif input_ids is not None: 1054 | batch_size, seq_length = input_ids.shape 1055 | elif inputs_embeds is not None: 1056 | batch_size, seq_length, _ = inputs_embeds.shape 1057 | else: 1058 | raise ValueError("You have to specify either input_ids or inputs_embeds") 1059 | 1060 | seq_length_with_past = seq_length 1061 | past_key_values_length = 0 1062 | 1063 | if past_key_values is not None: 1064 | past_key_values_length = past_key_values[0][0].shape[2] 1065 | seq_length_with_past = seq_length_with_past + past_key_values_length 1066 | 1067 | if position_ids is None: 1068 | device = input_ids.device if input_ids is not None else inputs_embeds.device 1069 | position_ids = torch.arange( 1070 | past_key_values_length, 1071 | seq_length + past_key_values_length, 1072 | dtype=torch.long, 1073 | device=device, 1074 | ) 1075 | position_ids = position_ids.unsqueeze(0).view(-1, seq_length) 1076 | else: 1077 | position_ids = position_ids.view(-1, seq_length).long() 1078 | 1079 | if inputs_embeds is None: 1080 | inputs_embeds = self.embed_tokens(input_ids) 1081 | # embed positions 1082 | if attention_mask is None: 1083 | attention_mask = torch.ones( 1084 | (batch_size, seq_length_with_past), 1085 | dtype=torch.bool, 1086 | device=inputs_embeds.device, 1087 | ) 1088 | padding_mask = None 1089 | else: 1090 | if 0 in attention_mask: 1091 | padding_mask = attention_mask 1092 | else: 1093 | padding_mask = None 1094 | 1095 | attention_mask = self._prepare_decoder_attention_mask( 1096 | attention_mask, 1097 | (batch_size, seq_length), 1098 | inputs_embeds, 1099 | past_key_values_length, 1100 | tree_attn_mask=tree_attn_mask, 1101 | ) 1102 | 1103 | hidden_states = inputs_embeds 1104 | 1105 | if self.gradient_checkpointing and self.training: 1106 | if use_cache: 1107 | logger.warning_once( 1108 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
1109 | ) 1110 | use_cache = False 1111 | 1112 | # decoder layers 1113 | all_hidden_states = () if output_hidden_states else None 1114 | all_self_attns = () if output_attentions else None 1115 | next_decoder_cache = () if use_cache else None 1116 | 1117 | for idx, decoder_layer in enumerate(self.layers): 1118 | if output_hidden_states: 1119 | all_hidden_states += (hidden_states,) 1120 | 1121 | past_key_value = ( 1122 | past_key_values[idx] if past_key_values is not None else None 1123 | ) 1124 | 1125 | if self.gradient_checkpointing and self.training: 1126 | 1127 | def create_custom_forward(module): 1128 | def custom_forward(*inputs): 1129 | # None for past_key_value 1130 | return module( 1131 | *inputs, 1132 | past_key_value, 1133 | output_attentions, 1134 | padding_mask=padding_mask, 1135 | ) 1136 | 1137 | return custom_forward 1138 | 1139 | layer_outputs = torch.utils.checkpoint.checkpoint( 1140 | create_custom_forward(decoder_layer), 1141 | hidden_states, 1142 | attention_mask, 1143 | position_ids, 1144 | ) 1145 | else: 1146 | layer_outputs = decoder_layer( 1147 | hidden_states, 1148 | attention_mask=attention_mask, 1149 | position_ids=position_ids, 1150 | past_key_value=past_key_value, 1151 | output_attentions=output_attentions, 1152 | use_cache=use_cache, 1153 | padding_mask=padding_mask, 1154 | ) 1155 | 1156 | hidden_states = layer_outputs[0] 1157 | 1158 | if use_cache: 1159 | next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) 1160 | 1161 | if output_attentions: 1162 | all_self_attns += (layer_outputs[1],) 1163 | 1164 | hidden_states = self.norm(hidden_states) 1165 | 1166 | # add hidden states from the last decoder layer 1167 | if output_hidden_states: 1168 | all_hidden_states += (hidden_states,) 1169 | 1170 | next_cache = next_decoder_cache if use_cache else None 1171 | if not return_dict: 1172 | return tuple( 1173 | v 1174 | for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] 1175 | if v is not None 1176 | ) 1177 | return BaseModelOutputWithPast( 1178 | last_hidden_state=hidden_states, 1179 | past_key_values=next_cache, 1180 | hidden_states=all_hidden_states, 1181 | attentions=all_self_attns, 1182 | ) 1183 | 1184 | 1185 | class LlamaForCausalLM(LlamaPreTrainedModel): 1186 | _tied_weights_keys = ["lm_head.weight"] 1187 | 1188 | def __init__(self, config): 1189 | super().__init__(config) 1190 | self.model = LlamaModel(config) 1191 | self.vocab_size = config.vocab_size 1192 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 1193 | 1194 | # Initialize weights and apply final processing 1195 | self.post_init() 1196 | 1197 | def get_input_embeddings(self): 1198 | return self.model.embed_tokens 1199 | 1200 | def set_input_embeddings(self, value): 1201 | self.model.embed_tokens = value 1202 | 1203 | def get_output_embeddings(self): 1204 | return self.lm_head 1205 | 1206 | def set_output_embeddings(self, new_embeddings): 1207 | self.lm_head = new_embeddings 1208 | 1209 | def set_decoder(self, decoder): 1210 | self.model = decoder 1211 | 1212 | def get_decoder(self): 1213 | return self.model 1214 | 1215 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) 1216 | @replace_return_docstrings( 1217 | output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC 1218 | ) 1219 | def forward( 1220 | self, 1221 | input_ids: torch.LongTensor = None, 1222 | attention_mask: Optional[torch.Tensor] = None, 1223 | position_ids: Optional[torch.LongTensor] = None, 1224 | past_key_values: Optional[List[torch.FloatTensor]] = 
None, 1225 | inputs_embeds: Optional[torch.FloatTensor] = None, 1226 | labels: Optional[torch.LongTensor] = None, 1227 | use_cache: Optional[bool] = None, 1228 | output_attentions: Optional[bool] = None, 1229 | output_hidden_states: Optional[bool] = None, 1230 | return_dict: Optional[bool] = None, 1231 | ) -> Union[Tuple, CausalLMOutputWithPast]: 1232 | r""" 1233 | Args: 1234 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 1235 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., 1236 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored 1237 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 1238 | 1239 | Returns: 1240 | 1241 | Example: 1242 | 1243 | ```python 1244 | >>> from transformers import AutoTokenizer, LlamaForCausalLM 1245 | 1246 | >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) 1247 | >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) 1248 | 1249 | >>> prompt = "Hey, are you conscious? Can you talk to me?" 1250 | >>> inputs = tokenizer(prompt, return_tensors="pt") 1251 | 1252 | >>> # Generate 1253 | >>> generate_ids = model.generate(inputs.input_ids, max_length=30) 1254 | >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] 1255 | "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 1256 | ```""" 1257 | 1258 | output_attentions = ( 1259 | output_attentions 1260 | if output_attentions is not None 1261 | else self.config.output_attentions 1262 | ) 1263 | output_hidden_states = ( 1264 | output_hidden_states 1265 | if output_hidden_states is not None 1266 | else self.config.output_hidden_states 1267 | ) 1268 | return_dict = ( 1269 | return_dict if return_dict is not None else self.config.use_return_dict 1270 | ) 1271 | 1272 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 1273 | outputs = self.model( 1274 | input_ids=input_ids, 1275 | attention_mask=attention_mask, 1276 | position_ids=position_ids, 1277 | past_key_values=past_key_values, 1278 | inputs_embeds=inputs_embeds, 1279 | use_cache=use_cache, 1280 | output_attentions=output_attentions, 1281 | output_hidden_states=output_hidden_states, 1282 | return_dict=return_dict, 1283 | ) 1284 | 1285 | hidden_states = outputs[0] 1286 | if self.config.pretraining_tp > 1: 1287 | lm_head_slices = self.lm_head.weight.split( 1288 | self.vocab_size // self.config.pretraining_tp, dim=0 1289 | ) 1290 | logits = [ 1291 | F.linear(hidden_states, lm_head_slices[i]) 1292 | for i in range(self.config.pretraining_tp) 1293 | ] 1294 | logits = torch.cat(logits, dim=-1) 1295 | else: 1296 | logits = self.lm_head(hidden_states) 1297 | logits = logits.float() 1298 | 1299 | loss = None 1300 | if labels is not None: 1301 | # Shift so that tokens < n predict n 1302 | shift_logits = logits[..., :-1, :].contiguous() 1303 | shift_labels = labels[..., 1:].contiguous() 1304 | # Flatten the tokens 1305 | loss_fct = CrossEntropyLoss() 1306 | shift_logits = shift_logits.view(-1, self.config.vocab_size) 1307 | shift_labels = shift_labels.view(-1) 1308 | # Enable model parallelism 1309 | shift_labels = shift_labels.to(shift_logits.device) 1310 | loss = loss_fct(shift_logits, shift_labels) 1311 | 1312 | if not return_dict: 1313 | output = (logits,) + outputs[1:] 1314 | return (loss,) + output if loss is not None else output 1315 | 
1316 | return CausalLMOutputWithPast( 1317 | loss=loss, 1318 | logits=logits, 1319 | past_key_values=outputs.past_key_values, 1320 | hidden_states=outputs.hidden_states, 1321 | attentions=outputs.attentions, 1322 | ) 1323 | 1324 | def prepare_inputs_for_generation( 1325 | self, 1326 | input_ids, 1327 | past_key_values=None, 1328 | attention_mask=None, 1329 | inputs_embeds=None, 1330 | **kwargs, 1331 | ): 1332 | if past_key_values: 1333 | input_ids = input_ids[:, -1:] 1334 | 1335 | position_ids = kwargs.get("position_ids", None) 1336 | if attention_mask is not None and position_ids is None: 1337 | # create position_ids on the fly for batch generation 1338 | position_ids = attention_mask.long().cumsum(-1) - 1 1339 | position_ids.masked_fill_(attention_mask == 0, 1) 1340 | if past_key_values: 1341 | position_ids = position_ids[:, -1].unsqueeze(-1) 1342 | 1343 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 1344 | if inputs_embeds is not None and past_key_values is None: 1345 | model_inputs = {"inputs_embeds": inputs_embeds} 1346 | else: 1347 | model_inputs = {"input_ids": input_ids} 1348 | 1349 | model_inputs.update( 1350 | { 1351 | "position_ids": position_ids, 1352 | "past_key_values": past_key_values, 1353 | "use_cache": kwargs.get("use_cache"), 1354 | "attention_mask": attention_mask, 1355 | } 1356 | ) 1357 | return model_inputs 1358 | 1359 | @staticmethod 1360 | def _reorder_cache(past_key_values, beam_idx): 1361 | reordered_past = () 1362 | for layer_past in past_key_values: 1363 | reordered_past += ( 1364 | tuple( 1365 | past_state.index_select(0, beam_idx.to(past_state.device)) 1366 | for past_state in layer_past 1367 | ), 1368 | ) 1369 | return reordered_past 1370 | 1371 | 1372 | @add_start_docstrings( 1373 | """ 1374 | The LLaMa Model transformer with a sequence classification head on top (linear layer). 1375 | 1376 | [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models 1377 | (e.g. GPT-2) do. 1378 | 1379 | Since it does classification on the last token, it requires to know the position of the last token. If a 1380 | `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If 1381 | no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the 1382 | padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in 1383 | each row of the batch). 
1384 | """, 1385 | LLAMA_START_DOCSTRING, 1386 | ) 1387 | class LlamaForSequenceClassification(LlamaPreTrainedModel): 1388 | def __init__(self, config): 1389 | super().__init__(config) 1390 | self.num_labels = config.num_labels 1391 | self.model = LlamaModel(config) 1392 | self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) 1393 | 1394 | # Initialize weights and apply final processing 1395 | self.post_init() 1396 | 1397 | def get_input_embeddings(self): 1398 | return self.model.embed_tokens 1399 | 1400 | def set_input_embeddings(self, value): 1401 | self.model.embed_tokens = value 1402 | 1403 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) 1404 | def forward( 1405 | self, 1406 | input_ids: torch.LongTensor = None, 1407 | attention_mask: Optional[torch.Tensor] = None, 1408 | position_ids: Optional[torch.LongTensor] = None, 1409 | past_key_values: Optional[List[torch.FloatTensor]] = None, 1410 | inputs_embeds: Optional[torch.FloatTensor] = None, 1411 | labels: Optional[torch.LongTensor] = None, 1412 | use_cache: Optional[bool] = None, 1413 | output_attentions: Optional[bool] = None, 1414 | output_hidden_states: Optional[bool] = None, 1415 | return_dict: Optional[bool] = None, 1416 | ) -> Union[Tuple, SequenceClassifierOutputWithPast]: 1417 | r""" 1418 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 1419 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., 1420 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If 1421 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 1422 | """ 1423 | return_dict = ( 1424 | return_dict if return_dict is not None else self.config.use_return_dict 1425 | ) 1426 | 1427 | transformer_outputs = self.model( 1428 | input_ids, 1429 | attention_mask=attention_mask, 1430 | position_ids=position_ids, 1431 | past_key_values=past_key_values, 1432 | inputs_embeds=inputs_embeds, 1433 | use_cache=use_cache, 1434 | output_attentions=output_attentions, 1435 | output_hidden_states=output_hidden_states, 1436 | return_dict=return_dict, 1437 | ) 1438 | hidden_states = transformer_outputs[0] 1439 | logits = self.score(hidden_states) 1440 | 1441 | if input_ids is not None: 1442 | batch_size = input_ids.shape[0] 1443 | else: 1444 | batch_size = inputs_embeds.shape[0] 1445 | 1446 | if self.config.pad_token_id is None and batch_size != 1: 1447 | raise ValueError( 1448 | "Cannot handle batch sizes > 1 if no padding token is defined." 
1449 | ) 1450 | if self.config.pad_token_id is None: 1451 | sequence_lengths = -1 1452 | else: 1453 | if input_ids is not None: 1454 | sequence_lengths = ( 1455 | torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1 1456 | ).to(logits.device) 1457 | else: 1458 | sequence_lengths = -1 1459 | 1460 | pooled_logits = logits[ 1461 | torch.arange(batch_size, device=logits.device), sequence_lengths 1462 | ] 1463 | 1464 | loss = None 1465 | if labels is not None: 1466 | labels = labels.to(logits.device) 1467 | if self.config.problem_type is None: 1468 | if self.num_labels == 1: 1469 | self.config.problem_type = "regression" 1470 | elif self.num_labels > 1 and ( 1471 | labels.dtype == torch.long or labels.dtype == torch.int 1472 | ): 1473 | self.config.problem_type = "single_label_classification" 1474 | else: 1475 | self.config.problem_type = "multi_label_classification" 1476 | 1477 | if self.config.problem_type == "regression": 1478 | loss_fct = MSELoss() 1479 | if self.num_labels == 1: 1480 | loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) 1481 | else: 1482 | loss = loss_fct(pooled_logits, labels) 1483 | elif self.config.problem_type == "single_label_classification": 1484 | loss_fct = CrossEntropyLoss() 1485 | loss = loss_fct( 1486 | pooled_logits.view(-1, self.num_labels), labels.view(-1) 1487 | ) 1488 | elif self.config.problem_type == "multi_label_classification": 1489 | loss_fct = BCEWithLogitsLoss() 1490 | loss = loss_fct(pooled_logits, labels) 1491 | if not return_dict: 1492 | output = (pooled_logits,) + transformer_outputs[1:] 1493 | return ((loss,) + output) if loss is not None else output 1494 | 1495 | return SequenceClassifierOutputWithPast( 1496 | loss=loss, 1497 | logits=pooled_logits, 1498 | past_key_values=transformer_outputs.past_key_values, 1499 | hidden_states=transformer_outputs.hidden_states, 1500 | attentions=transformer_outputs.attentions, 1501 | ) 1502 | -------------------------------------------------------------------------------- /MCSD/model/llama_tree_attn/tokenization_llama.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 5 | # and OPT implementations in this library. It has been modified from its 6 | # original forms to accommodate minor architectural differences compared 7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 
20 | 
21 | """Tokenization classes for LLaMA."""
22 | import os
23 | from shutil import copyfile
24 | from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
25 | 
26 | import sentencepiece as spm
27 | 
28 | from transformers.convert_slow_tokenizer import import_protobuf
29 | from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
30 | from transformers.utils import logging
31 | 
32 | 
33 | if TYPE_CHECKING:
34 |     from transformers.tokenization_utils_base import TextInput
35 | 
36 | logger = logging.get_logger(__name__)
37 | 
38 | VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
39 | 
40 | PRETRAINED_VOCAB_FILES_MAP = {
41 |     "vocab_file": {
42 |         "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model",
43 |     },
44 |     "tokenizer_file": {
45 |         "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json",
46 |     },
47 | }
48 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
49 |     "hf-internal-testing/llama-tokenizer": 2048,
50 | }
51 | SPIECE_UNDERLINE = "▁"
52 | 
53 | B_INST, E_INST = "[INST]", "[/INST]"
54 | B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
55 | 
56 | # fmt: off
57 | DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
58 | answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure \
59 | that your responses are socially unbiased and positive in nature.
60 | 
61 | If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
62 | correct. If you don't know the answer to a question, please don't share false information."""
63 | # fmt: on
64 | 
65 | 
66 | class LlamaTokenizer(PreTrainedTokenizer):
67 |     """
68 |     Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as there is
69 |     no padding token in the original model.
70 | 
71 |     Args:
72 |         vocab_file (`str`):
73 |             Path to the vocabulary file.
74 |         legacy (`bool`, *optional*):
75 |             Whether or not the `legacy` behavior of the tokenizer should be used. Legacy is before the merge of #24622
76 |             and #25224 which includes fixes to properly handle tokens that appear after special tokens. A simple
77 |             example:
78 | 
79 |             - `legacy=True`:
80 |             ```python
81 |             >>> from transformers import T5Tokenizer
82 | 
83 |             >>> tokenizer = T5Tokenizer.from_pretrained("t5-base", legacy=True)
84 |             >>> tokenizer.encode("Hello <extra_id_0>.")
85 |             [8774, 32099, 3, 5, 1]
86 |             ```
87 |             - `legacy=False`:
88 |             ```python
89 |             >>> from transformers import T5Tokenizer
90 | 
91 |             >>> tokenizer = T5Tokenizer.from_pretrained("t5-base", legacy=False)
92 |             >>> tokenizer.encode("Hello <extra_id_0>.")  # the extra space `[3]` is no longer here
93 |             [8774, 32099, 5, 1]
94 |             ```
95 |             Check out the [pull request](https://github.com/huggingface/transformers/pull/24565) for more details.
96 | 
97 |     """
98 | 
99 |     vocab_files_names = VOCAB_FILES_NAMES
100 |     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
101 |     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
102 |     model_input_names = ["input_ids", "attention_mask"]
103 | 
104 |     def __init__(
105 |         self,
106 |         vocab_file,
107 |         unk_token="<unk>",
108 |         bos_token="<s>",
109 |         eos_token="</s>",
110 |         pad_token=None,
111 |         sp_model_kwargs: Optional[Dict[str, Any]] = None,
112 |         add_bos_token=True,
113 |         add_eos_token=False,
114 |         clean_up_tokenization_spaces=False,
115 |         use_default_system_prompt=True,
116 |         spaces_between_special_tokens=False,
117 |         legacy=None,
118 |         **kwargs,
119 |     ):
120 |         self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
121 |         bos_token = (
122 |             AddedToken(bos_token, lstrip=False, rstrip=False)
123 |             if isinstance(bos_token, str)
124 |             else bos_token
125 |         )
126 |         eos_token = (
127 |             AddedToken(eos_token, lstrip=False, rstrip=False)
128 |             if isinstance(eos_token, str)
129 |             else eos_token
130 |         )
131 |         unk_token = (
132 |             AddedToken(unk_token, lstrip=False, rstrip=False)
133 |             if isinstance(unk_token, str)
134 |             else unk_token
135 |         )
136 |         pad_token = (
137 |             AddedToken(pad_token, lstrip=False, rstrip=False)
138 |             if isinstance(pad_token, str)
139 |             else pad_token
140 |         )
141 | 
142 |         if legacy is None:
143 |             logger.warning_once(
144 |                 f"You are using the default legacy behaviour of the {self.__class__}. This is"
145 |                 " expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you."
146 |                 " If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it"
147 |                 " means, and thoroughly read the reason why this was added as explained in"
148 |                 " https://github.com/huggingface/transformers/pull/24565"
149 |             )
150 |             legacy = True
151 | 
152 |         self.legacy = legacy
153 |         self.vocab_file = vocab_file
154 |         self.add_bos_token = add_bos_token
155 |         self.add_eos_token = add_eos_token
156 |         self.use_default_system_prompt = use_default_system_prompt
157 |         self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", False))
158 | 
159 |         super().__init__(
160 |             bos_token=bos_token,
161 |             eos_token=eos_token,
162 |             unk_token=unk_token,
163 |             pad_token=pad_token,
164 |             add_bos_token=add_bos_token,
165 |             add_eos_token=add_eos_token,
166 |             sp_model_kwargs=self.sp_model_kwargs,
167 |             clean_up_tokenization_spaces=clean_up_tokenization_spaces,
168 |             use_default_system_prompt=use_default_system_prompt,
169 |             spaces_between_special_tokens=spaces_between_special_tokens,
170 |             legacy=legacy,
171 |             **kwargs,
172 |         )
173 | 
174 |     @property
175 |     def unk_token_length(self):
176 |         return len(self.sp_model.encode(str(self.unk_token)))
177 | 
178 |     # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_spm_processor
179 |     def get_spm_processor(self, from_slow=False):
180 |         tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
181 |         if self.legacy or from_slow:  # no dependency on protobuf
182 |             tokenizer.Load(self.vocab_file)
183 |             return tokenizer
184 | 
185 |         with open(self.vocab_file, "rb") as f:
186 |             sp_model = f.read()
187 |             model_pb2 = import_protobuf(
188 |                 f"The new behaviour of {self.__class__.__name__} (with `self.legacy = False`)"
189 |             )
190 |             model = model_pb2.ModelProto.FromString(sp_model)
191 |             normalizer_spec = model_pb2.NormalizerSpec()
192 |             normalizer_spec.add_dummy_prefix = False
193 |             model.normalizer_spec.MergeFrom(normalizer_spec)
194 |             sp_model = model.SerializeToString()
195 | tokenizer.LoadFromSerializedProto(sp_model) 196 | return tokenizer 197 | 198 | def __getstate__(self): 199 | state = self.__dict__.copy() 200 | state["sp_model"] = None 201 | state["sp_model_proto"] = self.sp_model.serialized_model_proto() 202 | return state 203 | 204 | def __setstate__(self, d): 205 | self.__dict__ = d 206 | self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) 207 | self.sp_model.LoadFromSerializedProto(self.sp_model_proto) 208 | 209 | @property 210 | def vocab_size(self): 211 | """Returns vocab size""" 212 | return self.sp_model.get_piece_size() 213 | 214 | def get_vocab(self): 215 | """Returns vocab as a dict""" 216 | vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} 217 | vocab.update(self.added_tokens_encoder) 218 | return vocab 219 | 220 | # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.tokenize 221 | def tokenize( 222 | self, text: "TextInput", add_special_tokens=False, **kwargs 223 | ) -> List[str]: 224 | """ 225 | Converts a string to a list of tokens. If `self.legacy` is set to `False`, a prefix token is added unless the 226 | first token is special. 227 | """ 228 | if self.legacy or len(text) == 0: 229 | return super().tokenize(text, **kwargs) 230 | 231 | tokens = super().tokenize( 232 | SPIECE_UNDERLINE + text.replace(SPIECE_UNDERLINE, " "), **kwargs 233 | ) 234 | 235 | if ( 236 | len(tokens) > 1 237 | and tokens[0] == SPIECE_UNDERLINE 238 | and tokens[1] in self.all_special_tokens 239 | ): 240 | tokens = tokens[1:] 241 | return tokens 242 | 243 | # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._tokenize 244 | def _tokenize(self, text, **kwargs): 245 | """ 246 | Returns a tokenized string. 247 | 248 | We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any 249 | SPIECE_UNDERLINE. For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give 250 | `['H', 'e', 'y']` instead of `['▁He', 'y']`. Thus we always encode `f"{unk_token}text"` and strip the 251 | `unk_token`. Here is an example with `unk_token = ""` and `unk_token_length = 4`. 252 | `self.tokenizer.sp_model.encode(" Hey", out_type = str)[4:]`. 253 | """ 254 | tokens = self.sp_model.encode(text, out_type=str) 255 | if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")): 256 | return tokens 257 | 258 | # 1. Encode string + prefix ex: " Hey" 259 | tokens = self.sp_model.encode(self.unk_token + text, out_type=str) 260 | # 2. 
Remove self.unk_token from ['<','unk','>', '▁Hey'] 261 | return ( 262 | tokens[self.unk_token_length :] 263 | if len(tokens) >= self.unk_token_length 264 | else tokens 265 | ) 266 | 267 | def _convert_token_to_id(self, token): 268 | """Converts a token (str) in an id using the vocab.""" 269 | return self.sp_model.piece_to_id(token) 270 | 271 | def _convert_id_to_token(self, index): 272 | """Converts an index (integer) in a token (str) using the vocab.""" 273 | token = self.sp_model.IdToPiece(index) 274 | return token 275 | 276 | def convert_tokens_to_string(self, tokens): 277 | """Converts a sequence of tokens (string) in a single string.""" 278 | # since we manually add the prefix space, we have to remove it when decoding 279 | if tokens[0].startswith(SPIECE_UNDERLINE): 280 | tokens[0] = tokens[0][1:] 281 | 282 | current_sub_tokens = [] 283 | out_string = "" 284 | prev_is_special = False 285 | for i, token in enumerate(tokens): 286 | # make sure that special tokens are not decoded using sentencepiece model 287 | if token in self.all_special_tokens: 288 | if not prev_is_special and i != 0 and self.legacy: 289 | out_string += " " 290 | out_string += self.sp_model.decode(current_sub_tokens) + token 291 | prev_is_special = True 292 | current_sub_tokens = [] 293 | else: 294 | current_sub_tokens.append(token) 295 | prev_is_special = False 296 | out_string += self.sp_model.decode(current_sub_tokens) 297 | return out_string 298 | 299 | def save_vocabulary( 300 | self, save_directory, filename_prefix: Optional[str] = None 301 | ) -> Tuple[str]: 302 | """ 303 | Save the vocabulary and special tokens file to a directory. 304 | 305 | Args: 306 | save_directory (`str`): 307 | The directory in which to save the vocabulary. 308 | 309 | Returns: 310 | `Tuple(str)`: Paths to the files saved. 311 | """ 312 | if not os.path.isdir(save_directory): 313 | logger.error(f"Vocabulary path ({save_directory}) should be a directory") 314 | return 315 | out_vocab_file = os.path.join( 316 | save_directory, 317 | (filename_prefix + "-" if filename_prefix else "") 318 | + VOCAB_FILES_NAMES["vocab_file"], 319 | ) 320 | 321 | if os.path.abspath(self.vocab_file) != os.path.abspath( 322 | out_vocab_file 323 | ) and os.path.isfile(self.vocab_file): 324 | copyfile(self.vocab_file, out_vocab_file) 325 | elif not os.path.isfile(self.vocab_file): 326 | with open(out_vocab_file, "wb") as fi: 327 | content_spiece_model = self.sp_model.serialized_model_proto() 328 | fi.write(content_spiece_model) 329 | 330 | return (out_vocab_file,) 331 | 332 | def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): 333 | bos_token_id = [self.bos_token_id] if self.add_bos_token else [] 334 | eos_token_id = [self.eos_token_id] if self.add_eos_token else [] 335 | 336 | output = bos_token_id + token_ids_0 + eos_token_id 337 | 338 | if token_ids_1 is not None: 339 | output = output + bos_token_id + token_ids_1 + eos_token_id 340 | 341 | return output 342 | 343 | def get_special_tokens_mask( 344 | self, 345 | token_ids_0: List[int], 346 | token_ids_1: Optional[List[int]] = None, 347 | already_has_special_tokens: bool = False, 348 | ) -> List[int]: 349 | """ 350 | Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding 351 | special tokens using the tokenizer `prepare_for_model` method. 352 | 353 | Args: 354 | token_ids_0 (`List[int]`): 355 | List of IDs. 356 | token_ids_1 (`List[int]`, *optional*): 357 | Optional second list of IDs for sequence pairs. 
358 | already_has_special_tokens (`bool`, *optional*, defaults to `False`): 359 | Whether or not the token list is already formatted with special tokens for the model. 360 | 361 | Returns: 362 | `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. 363 | """ 364 | if already_has_special_tokens: 365 | return super().get_special_tokens_mask( 366 | token_ids_0=token_ids_0, 367 | token_ids_1=token_ids_1, 368 | already_has_special_tokens=True, 369 | ) 370 | 371 | bos_token_id = [1] if self.add_bos_token else [] 372 | eos_token_id = [1] if self.add_eos_token else [] 373 | 374 | if token_ids_1 is None: 375 | return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id 376 | return ( 377 | bos_token_id 378 | + ([0] * len(token_ids_0)) 379 | + eos_token_id 380 | + bos_token_id 381 | + ([0] * len(token_ids_1)) 382 | + eos_token_id 383 | ) 384 | 385 | def create_token_type_ids_from_sequences( 386 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 387 | ) -> List[int]: 388 | """ 389 | Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT 390 | sequence pair mask has the following format: 391 | 392 | ``` 393 | 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 394 | | first sequence | second sequence | 395 | ``` 396 | 397 | if token_ids_1 is None, only returns the first portion of the mask (0s). 398 | 399 | Args: 400 | token_ids_0 (`List[int]`): 401 | List of ids. 402 | token_ids_1 (`List[int]`, *optional*): 403 | Optional second list of IDs for sequence pairs. 404 | 405 | Returns: 406 | `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). 407 | """ 408 | bos_token_id = [self.bos_token_id] if self.add_bos_token else [] 409 | eos_token_id = [self.eos_token_id] if self.add_eos_token else [] 410 | 411 | output = [0] * len(bos_token_id + token_ids_0 + eos_token_id) 412 | 413 | if token_ids_1 is not None: 414 | output += [1] * len(bos_token_id + token_ids_1 + eos_token_id) 415 | 416 | return output 417 | 418 | @property 419 | def default_chat_template(self): 420 | """ 421 | LLaMA uses [INST] and [/INST] to indicate user messages, and <> and <> to indicate system messages. 422 | Assistant messages do not have special tokens, because LLaMA chat models are generally trained with strict 423 | user/assistant/user/assistant message ordering, and so assistant messages can be identified from the ordering 424 | rather than needing special tokens. The system message is partly 'embedded' in the first user message, which 425 | results in an unusual token ordering when it is present. This template should definitely be changed if you wish 426 | to fine-tune a model with more flexible role ordering! 
427 | 428 | The output should look something like: 429 | 430 | [INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer [INST] Prompt [/INST] Answer 431 | [INST] Prompt [/INST] 432 | """ 433 | 434 | template = ( 435 | "{% if messages[0]['role'] == 'system' %}" 436 | "{% set loop_messages = messages[1:] %}" # Extract system message if it's present 437 | "{% set system_message = messages[0]['content'] %}" 438 | "{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}" 439 | "{% set loop_messages = messages %}" # Or use the default system message if the flag is set 440 | "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}" 441 | "{% else %}" 442 | "{% set loop_messages = messages %}" 443 | "{% set system_message = false %}" 444 | "{% endif %}" 445 | "{% for message in loop_messages %}" # Loop over all non-system messages 446 | "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" 447 | "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}" 448 | "{% endif %}" 449 | "{% if loop.index0 == 0 and system_message != false %}" # Embed system message in first message 450 | "{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}" 451 | "{% else %}" 452 | "{% set content = message['content'] %}" 453 | "{% endif %}" 454 | "{% if message['role'] == 'user' %}" # After all of that, handle messages/roles in a fairly normal way 455 | "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}" 456 | "{% elif message['role'] == 'system' %}" 457 | "{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}" 458 | "{% elif message['role'] == 'assistant' %}" 459 | "{{ ' ' + content.strip() + ' ' + eos_token }}" 460 | "{% endif %}" 461 | "{% endfor %}" 462 | ) 463 | template = template.replace( 464 | "USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false" 465 | ) 466 | default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'") 467 | template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message) 468 | 469 | return template 470 | -------------------------------------------------------------------------------- /MCSD/model/llama_tree_attn/tokenization_llama_fast.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | import os 16 | from shutil import copyfile 17 | from typing import Optional, Tuple 18 | 19 | from tokenizers import processors 20 | 21 | from transformers.tokenization_utils_fast import PreTrainedTokenizerFast 22 | from transformers.utils import is_sentencepiece_available, logging 23 | from transformers.utils.versions import require_version 24 | 25 | 26 | require_version("tokenizers>=0.13.3") 27 | 28 | if is_sentencepiece_available(): 29 | from .tokenization_llama import LlamaTokenizer 30 | else: 31 | LlamaTokenizer = None 32 | 33 | logger = logging.get_logger(__name__) 34 | VOCAB_FILES_NAMES = { 35 | "vocab_file": "tokenizer.model", 36 | "tokenizer_file": "tokenizer.json", 37 | } 38 | 39 | PRETRAINED_VOCAB_FILES_MAP = { 40 | "vocab_file": { 41 | "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model", 42 | }, 43 | "tokenizer_file": { 44 | "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json", 45 | }, 46 | } 47 | B_INST, E_INST = "[INST]", "[/INST]" 48 | B_SYS, E_SYS = "<>\n", "\n<>\n\n" 49 | 50 | # fmt: off 51 | DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \ 52 | answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\ 53 | that your responses are socially unbiased and positive in nature. 54 | 55 | If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \ 56 | correct. If you don't know the answer to a question, please don't share false information.""" 57 | # fmt: on 58 | 59 | 60 | class LlamaTokenizerFast(PreTrainedTokenizerFast): 61 | """ 62 | Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding. 63 | 64 | This uses notably ByteFallback and no normalization. 65 | 66 | ``` 67 | from transformers import LlamaTokenizerFast 68 | 69 | tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer") 70 | tokenizer.encode("Hello this is a test") 71 | >>> [1, 15043, 445, 338, 263, 1243] 72 | ``` 73 | 74 | If you want to change the `bos_token` or the `eos_token`, make sure to specify them when initializing the model, or 75 | call `tokenizer.update_post_processor()` to make sure that the post-processing is correctly done (otherwise the 76 | values of the first token and final token of an encoded sequence will not be correct). For more details, checkout 77 | [post-processors] (https://huggingface.co/docs/tokenizers/api/post-processors) documentation. 78 | 79 | 80 | This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should 81 | refer to this superclass for more information regarding those methods. 82 | 83 | Args: 84 | vocab_file (`str`): 85 | [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that 86 | contains the vocabulary necessary to instantiate a tokenizer. 87 | tokenizer_file (`str`): 88 | [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that 89 | contains everything needed to load the tokenizer. 90 | 91 | clean_up_tokenization_spaces (`str`, *optional*, defaults to `False`): 92 | Wether to cleanup spaces after decoding, cleanup consists in removing potential artifacts like extra 93 | spaces. 
94 | 95 | bos_token (`str`, *optional*, defaults to `""`): 96 | The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. 97 | 98 | eos_token (`str`, *optional*, defaults to `""`): 99 | The end of sequence token. 100 | 101 | unk_token (`str`, *optional*, defaults to `""`): 102 | The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this 103 | token instead. 104 | """ 105 | 106 | vocab_files_names = VOCAB_FILES_NAMES 107 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 108 | slow_tokenizer_class = LlamaTokenizer 109 | padding_side = "left" 110 | model_input_names = ["input_ids", "attention_mask"] 111 | 112 | def __init__( 113 | self, 114 | vocab_file=None, 115 | tokenizer_file=None, 116 | clean_up_tokenization_spaces=False, 117 | unk_token="", 118 | bos_token="", 119 | eos_token="", 120 | add_bos_token=True, 121 | add_eos_token=False, 122 | use_default_system_prompt=True, 123 | **kwargs, 124 | ): 125 | super().__init__( 126 | vocab_file=vocab_file, 127 | tokenizer_file=tokenizer_file, 128 | clean_up_tokenization_spaces=clean_up_tokenization_spaces, 129 | unk_token=unk_token, 130 | bos_token=bos_token, 131 | eos_token=eos_token, 132 | use_default_system_prompt=use_default_system_prompt, 133 | **kwargs, 134 | ) 135 | self._add_bos_token = add_bos_token 136 | self._add_eos_token = add_eos_token 137 | self.update_post_processor() 138 | self.use_default_system_prompt = use_default_system_prompt 139 | self.vocab_file = vocab_file 140 | 141 | @property 142 | def can_save_slow_tokenizer(self) -> bool: 143 | return os.path.isfile(self.vocab_file) if self.vocab_file else False 144 | 145 | def update_post_processor(self): 146 | """ 147 | Updates the underlying post processor with the current `bos_token` and `eos_token`. 148 | """ 149 | bos = self.bos_token 150 | bos_token_id = self.bos_token_id 151 | 152 | eos = self.eos_token 153 | eos_token_id = self.eos_token_id 154 | 155 | single = f"{(bos+':0 ') * self.add_bos_token}$A:0{(' '+eos+':0') if self.add_eos_token else ''}" 156 | pair = f"{single}{(' '+bos+':1') * self.add_bos_token} $B:1{(' '+eos+':1') if self.add_eos_token else ''}" 157 | 158 | special_tokens = [] 159 | if self.add_bos_token: 160 | special_tokens.append((bos, bos_token_id)) 161 | if self.add_eos_token: 162 | special_tokens.append((eos, eos_token_id)) 163 | self._tokenizer.post_processor = processors.TemplateProcessing( 164 | single=single, pair=pair, special_tokens=special_tokens 165 | ) 166 | 167 | @property 168 | def add_eos_token(self): 169 | return self._add_eos_token 170 | 171 | @property 172 | def add_bos_token(self): 173 | return self._add_bos_token 174 | 175 | @add_eos_token.setter 176 | def add_eos_token(self, value): 177 | self._add_eos_token = value 178 | self.update_post_processor() 179 | 180 | @add_bos_token.setter 181 | def add_bos_token(self, value): 182 | self._add_bos_token = value 183 | self.update_post_processor() 184 | 185 | def save_vocabulary( 186 | self, save_directory: str, filename_prefix: Optional[str] = None 187 | ) -> Tuple[str]: 188 | if not self.can_save_slow_tokenizer: 189 | raise ValueError( 190 | "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow " 191 | "tokenizer." 
192 | ) 193 | 194 | if not os.path.isdir(save_directory): 195 | logger.error(f"Vocabulary path ({save_directory}) should be a directory") 196 | return 197 | out_vocab_file = os.path.join( 198 | save_directory, 199 | (filename_prefix + "-" if filename_prefix else "") 200 | + VOCAB_FILES_NAMES["vocab_file"], 201 | ) 202 | 203 | if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): 204 | copyfile(self.vocab_file, out_vocab_file) 205 | 206 | return (out_vocab_file,) 207 | 208 | @property 209 | # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.default_chat_template 210 | def default_chat_template(self): 211 | """ 212 | LLaMA uses [INST] and [/INST] to indicate user messages, and <> and <> to indicate system messages. 213 | Assistant messages do not have special tokens, because LLaMA chat models are generally trained with strict 214 | user/assistant/user/assistant message ordering, and so assistant messages can be identified from the ordering 215 | rather than needing special tokens. The system message is partly 'embedded' in the first user message, which 216 | results in an unusual token ordering when it is present. This template should definitely be changed if you wish 217 | to fine-tune a model with more flexible role ordering! 218 | 219 | The output should look something like: 220 | 221 | [INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer [INST] Prompt [/INST] Answer 222 | [INST] Prompt [/INST] 223 | """ 224 | 225 | template = ( 226 | "{% if messages[0]['role'] == 'system' %}" 227 | "{% set loop_messages = messages[1:] %}" # Extract system message if it's present 228 | "{% set system_message = messages[0]['content'] %}" 229 | "{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}" 230 | "{% set loop_messages = messages %}" # Or use the default system message if the flag is set 231 | "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}" 232 | "{% else %}" 233 | "{% set loop_messages = messages %}" 234 | "{% set system_message = false %}" 235 | "{% endif %}" 236 | "{% for message in loop_messages %}" # Loop over all non-system messages 237 | "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" 238 | "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}" 239 | "{% endif %}" 240 | "{% if loop.index0 == 0 and system_message != false %}" # Embed system message in first message 241 | "{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}" 242 | "{% else %}" 243 | "{% set content = message['content'] %}" 244 | "{% endif %}" 245 | "{% if message['role'] == 'user' %}" # After all of that, handle messages/roles in a fairly normal way 246 | "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}" 247 | "{% elif message['role'] == 'system' %}" 248 | "{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}" 249 | "{% elif message['role'] == 'assistant' %}" 250 | "{{ ' ' + content.strip() + ' ' + eos_token }}" 251 | "{% endif %}" 252 | "{% endfor %}" 253 | ) 254 | template = template.replace( 255 | "USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false" 256 | ) 257 | default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'") 258 | template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message) 259 | 260 | return template 261 | 262 | # TODO ArthurZ let's rely on the template processor instead, refactor all fast tokenizers 263 | # Copied from 
transformers.models.llama.tokenization_llama.LlamaTokenizer.build_inputs_with_special_tokens
264 |     def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
265 |         bos_token_id = [self.bos_token_id] if self.add_bos_token else []
266 |         eos_token_id = [self.eos_token_id] if self.add_eos_token else []
267 | 
268 |         output = bos_token_id + token_ids_0 + eos_token_id
269 | 
270 |         if token_ids_1 is not None:
271 |             output = output + bos_token_id + token_ids_1 + eos_token_id
272 | 
273 |         return output
274 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Multi-Candidate Speculative Decoding
2 | 
3 | ## Code Release
4 | See [here](./MCSD/).
5 | 
6 | ## Data Release
7 | For the [Alpaca dataset](https://github.com/flexflow/FlexFlow/tree/inference?tab=readme-ov-file#prompt-datasets), we use exactly the same source as [SpecInfer](https://arxiv.org/pdf/2305.09781.pdf).
8 | 
9 | For the [WMT dataset](/dataset/wmt_ende.json), we follow the process of SpecInfer: we randomly sample 1000 examples from the test set. We wrap each source sentence using the following template:
10 | ```
11 | Translate the input English sentence into German.
12 | Input: {source sentence}
13 | Output:
14 | ```
15 | 
16 | ## Model Release
17 | We release our fine-tuned draft models on Hugging Face; see [Vicuna-68M](https://huggingface.co/double7/vicuna-68m) and [Vicuna-160M](https://huggingface.co/double7/vicuna-160m). They are fine-tuned from [LLaMA-68M](https://huggingface.co/JackFram/llama-68m) and [LLaMA-160M](https://huggingface.co/JackFram/llama-160m) respectively on ShareGPT data. The training setup follows [FastChat](https://github.com/lm-sys/FastChat).
18 | 
--------------------------------------------------------------------------------
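As a point of reference, the sketch below shows one way the released pieces could be wired together: it wraps a WMT source sentence with the prompt template from the Data Release section and runs it through the released Vicuna-68M draft model with plain Hugging Face `generate`. This is only a minimal illustration, not the MCSD tree-attention inference path in `MCSD/inference`; the variable names, the example sentence, and the generation settings are assumptions made for this example.

```python
# Minimal sketch (not part of the repository code): wrap a WMT En-De source
# sentence with the documented template and decode it with the released draft
# model using standard greedy generation. Generation settings are arbitrary.
from transformers import AutoModelForCausalLM, AutoTokenizer

TEMPLATE = (
    "Translate the input English sentence into German.\n"
    "Input: {source_sentence}\n"
    "Output:"
)


def build_prompt(source_sentence: str) -> str:
    # Apply the prompt template described in the Data Release section.
    return TEMPLATE.format(source_sentence=source_sentence)


if __name__ == "__main__":
    # Draft model released with this work (see Model Release above).
    tokenizer = AutoTokenizer.from_pretrained("double7/vicuna-68m")
    model = AutoModelForCausalLM.from_pretrained("double7/vicuna-68m")

    prompt = build_prompt("The weather is nice today.")
    inputs = tokenizer(prompt, return_tensors="pt")
    output_ids = model.generate(**inputs, max_new_tokens=64)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```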