├── .gitignore
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── GPTQ.py
├── LICENSE
├── README.md
├── eval.py
├── generate.py
├── mixtral-moe
│   ├── README.md
│   ├── generate.py
│   ├── model.py
│   ├── quantize.py
│   ├── scripts
│   │   ├── convert_hf_checkpoint.py
│   │   └── download.py
│   └── tp.py
├── model.py
├── quantize.py
├── requirements.txt
├── scripts
│   ├── convert_hf_checkpoint.py
│   ├── download.py
│   ├── prepare.sh
│   ├── speculate_34B_bf16.sh
│   ├── speculate_70B_int4.sh
│   ├── speculate_7B_int4.sh
│   ├── speculate_tp_70B_bf16.sh
│   └── test_flow.sh
├── setup.py
├── tokenizer.py
└── tp.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | .idea
3 | .DS_Store
4 | *.egg-info
5 | build
6 | 
7 | # data
8 | data
9 | checkpoints
10 | out
11 | !data/shakespeare/prepare.py
12 | wandb
13 | 
14 | # downloaded by our tests
15 | original_model.py
16 | original_adapter.py
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 | 
3 | ## Our Pledge
4 | 
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to make participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, sex characteristics, gender identity and expression,
9 | level of experience, education, socio-economic status, nationality, personal
10 | appearance, race, religion, or sexual identity and orientation.
11 | 
12 | ## Our Standards
13 | 
14 | Examples of behavior that contributes to creating a positive environment
15 | include:
16 | 
17 | * Using welcoming and inclusive language
18 | * Being respectful of differing viewpoints and experiences
19 | * Gracefully accepting constructive criticism
20 | * Focusing on what is best for the community
21 | * Showing empathy towards other community members
22 | 
23 | Examples of unacceptable behavior by participants include:
24 | 
25 | * The use of sexualized language or imagery and unwelcome sexual attention or
26 | advances
27 | * Trolling, insulting/derogatory comments, and personal or political attacks
28 | * Public or private harassment
29 | * Publishing others' private information, such as a physical or electronic
30 | address, without explicit permission
31 | * Other conduct which could reasonably be considered inappropriate in a
32 | professional setting
33 | 
34 | ## Our Responsibilities
35 | 
36 | Project maintainers are responsible for clarifying the standards of acceptable
37 | behavior and are expected to take appropriate and fair corrective action in
38 | response to any instances of unacceptable behavior.
39 | 
40 | Project maintainers have the right and responsibility to remove, edit, or
41 | reject comments, commits, code, wiki edits, issues, and other contributions
42 | that are not aligned to this Code of Conduct, or to ban temporarily or
43 | permanently any contributor for other behaviors that they deem inappropriate,
44 | threatening, offensive, or harmful.
45 | 
46 | ## Scope
47 | 
48 | This Code of Conduct applies within all project spaces, and it also applies when
49 | an individual is representing the project or its community in public spaces. 
50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at . All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to gpt-fast 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | 6 | ## Pull Requests 7 | We actively welcome your pull requests. 8 | 9 | 1. Fork the repo and create your branch from `main`. 10 | 2. If you've added code that should be tested, add tests. 11 | 3. If you've changed APIs, update the documentation. 12 | 4. Ensure the test suite passes. 13 | 5. Make sure your code lints. 14 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 15 | 16 | ## Contributor License Agreement ("CLA") 17 | In order to accept your pull request, we need you to submit a CLA. You only need 18 | to do this once to work on any of Meta's open source projects. 19 | 20 | Complete your CLA here: 21 | 22 | ## Issues 23 | We use GitHub issues to track public bugs. Please ensure your description is 24 | clear and has sufficient instructions to be able to reproduce the issue. 25 | 26 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe 27 | disclosure of security bugs. In those cases, please go through the process 28 | outlined on that page and do not file a public issue. 29 | 30 | ## License 31 | By contributing to `gpt-fast`, you agree that your contributions will be licensed 32 | under the LICENSE file in the root directory of this source tree. 33 | -------------------------------------------------------------------------------- /GPTQ.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
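# GPTQ quantization utilities provided by this module:
# - InputRecorder: a fake eval-harness wrapper that records calibration inputs, optionally
#   padding or truncating them to calibration_seq_length.
# - MultiInput: a small container holding one value per calibration sample for a given argument.
# - GenericGPTQRunner: an fx.Interpreter over a torch._dynamo.export graph that, for each
#   linear layer, accumulates a Hessian from the calibration inputs, applies the GPTQ weight
#   update (faster_quant), and writes the quantized weights/qparams into the model state_dict.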
6 | 7 | import torch 8 | 9 | import torch.fx as fx 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torch.utils._pytree import tree_flatten, tree_unflatten 13 | 14 | aten = torch.ops.aten 15 | 16 | from eval import ( 17 | setup_cache_padded_seq_input_pos_max_seq_length_for_prefill, 18 | GPTFastEvalWrapper 19 | ) 20 | 21 | 22 | class InputRecorder(GPTFastEvalWrapper): 23 | """ 24 | This is a fake evaluation wrapper that just records the inputs 25 | so that they can be used in calibration. 26 | 27 | If pad_calibration_inputs is enabled, the input recorder will take 28 | each input and pad/truncate it down to the calibration_seq_length. 29 | It will also edit the model embeddings to be zero for the 0 token used 30 | in padding and avoid any inputs with the 0 token. 31 | 32 | If not, it will only truncate inputs to the desired length. 33 | """ 34 | 35 | def __init__( 36 | self, 37 | model, 38 | tokenizer, 39 | calibration_seq_length, 40 | pad_calibration_inputs=False, 41 | ): 42 | super().__init__(model, tokenizer, calibration_seq_length) 43 | self._model = model 44 | self._tokenizer = tokenizer 45 | self._device = torch.device("cpu") 46 | self.vocab_size = model.config.vocab_size 47 | self.calibration_seq_length = calibration_seq_length 48 | self.pad_calibration_inputs = pad_calibration_inputs 49 | self.inputs = None 50 | 51 | if self.pad_calibration_inputs: 52 | # This is needed for the pad_calibration_inputs option 53 | # to work properly, the 0 token's embeddings are set to 0 so that 54 | # the padded inputs will not affect the model numerics. This token isn't used 55 | # commonly in the eval tasks for the meta-llama tokenizer and we skip any inputs 56 | # where it appears 57 | try: 58 | if isinstance(self._model.transformer.wte, nn.Embedding): 59 | self.mod.transformer.wte.weight.data[0, :] *= 0 60 | except: 61 | print( 62 | "Did not find embeddings in model.transformer.wte, disabling padding" 63 | ) 64 | self.pad_calibration_inputs = False 65 | 66 | 67 | def add_input(self, args): 68 | if self.inputs is None: 69 | self.inputs = [MultiInput([arg]) for arg in args] 70 | else: 71 | self.inputs = [ 72 | multi.add_input(arg) for (multi, arg) in zip(self.inputs, args) 73 | ] 74 | 75 | def get_recorded_inputs(self): 76 | return self.inputs 77 | 78 | def _model_call(self, inps): 79 | inps = inps.squeeze(0) 80 | T = len(inps) 81 | if ( 82 | # can't use inputs that are too short when padding disabled 83 | (T < self.calibration_seq_length and not self.pad_calibration_inputs) 84 | or 85 | # can't use inputs that actually use token we use for padding 86 | (self.pad_calibration_inputs and 0 in inps) 87 | ): 88 | # give random output 89 | return torch.randn( 90 | (1, T, self.vocab_size), dtype=torch.bfloat16, device=self._device 91 | ) 92 | 93 | # pad or truncate to the right size 94 | if T >= self.calibration_seq_length: 95 | inps = inps[: self.calibration_seq_length] 96 | else: 97 | inps = F.pad(inps, (0, self.calibration_seq_length - T)) 98 | 99 | max_new_tokens = 1 100 | ( 101 | seq, 102 | input_pos, 103 | max_seq_length, 104 | ) = setup_cache_padded_seq_input_pos_max_seq_length_for_prefill( 105 | self._model, inps, max_new_tokens, self.max_length 106 | ) 107 | x = seq.index_select(0, input_pos).view(1, -1) 108 | self.add_input((x, input_pos)) 109 | 110 | # output `something` with correct shape to keep eval going 111 | return torch.randn( 112 | (1, T, self.vocab_size), dtype=torch.bfloat16, device=self._device 113 | ) 114 | 115 | 116 | 117 | class MultiInput: 118 | def 
__init__(self, inputs): 119 | self.values = list(inputs) 120 | 121 | def add_input(self, input): 122 | self.values.append(input) 123 | return self 124 | 125 | def __getitem__(self, slice): 126 | return MultiInput(self.values[slice]) 127 | 128 | def cuda(self): 129 | self.values = [val.cuda() if isinstance(val, torch.Tensor) else val for val in self.values] 130 | 131 | 132 | class GenericGPTQRunner(fx.Interpreter): 133 | """ 134 | This is a generic GPTQ runner that takes an existing model and applies GPTQ. 135 | It uses torch._dynamo.export to obtain a graph of the model and then hooks 136 | into function calls and when it detects a linear, it applies GPTQ to the weight 137 | given the calibration of inputs passed in at initialization. It puts the results 138 | into the state_dict so that the quantized model weights/qparams can be loaded 139 | directly into the model. 140 | 141 | This class is expected to work in concert with a GPTQSimpleQuantizer 142 | class to define the specific type of quantization being done. 143 | """ 144 | 145 | def __init__( 146 | self, model, inputs: MultiInput, blocksize=128, percdamp=0.01, groupsize=128 147 | ): 148 | self.id_to_name = { 149 | id(value): name for name, value in dict(model.named_parameters()).items() 150 | } 151 | 152 | # trace model for one input 153 | one_input = [multi.values[0].cpu() for multi in inputs] 154 | exported_model = torch._dynamo.export( 155 | model.cpu(), aten_graph=True, pre_dispatch=True, tracing_mode="fake" 156 | )(*one_input) 157 | super().__init__(exported_model.graph_module) 158 | self.new_state_dict = model.state_dict() 159 | self.blocksize = blocksize 160 | self.percdamp = percdamp 161 | self.groupsize = groupsize 162 | self.inputs = inputs 163 | self.gptq_done = False 164 | self.debug = False 165 | 166 | def configure_quantization_mode( 167 | self, 168 | get_qparams_func, 169 | quantize_func, 170 | dequantize_func, 171 | combine_qparams_list_func, 172 | make_names_and_values_dict_func, 173 | skip_layer_func, 174 | ): 175 | # these functions need to already be curried with all inputs other than weight, qparams 176 | self.get_qparams_func = ( 177 | get_qparams_func # accepts [2d weight tensor], outputs qparams. 
178 | ) 179 | 180 | self.quantize_func = quantize_func # accepts [2d weight tensor], [qparams], outputs a 2d quantized tensor of desired dtype 181 | 182 | self.dequantize_func = dequantize_func 183 | # accepts [quantized] tensor and [qparams], outputs a 2d dequantized tensor of type float, 184 | # assumes this output .to(w_orig_dtype) is ~eventual desired dequant behavior 185 | 186 | self.combine_qparams_list_func = combine_qparams_list_func 187 | # accepts [`list` of qparams] from quantizing one group at a time, 188 | # outputs a qparams object that could be passed into quant/dequantize_func 189 | 190 | self.skip_layer_func = skip_layer_func # accepts [weight tensor], outputs a bool on whether or not to apply gptq to this layer 191 | 192 | self.make_names_and_values_dict_func = make_names_and_values_dict_func # accepts [2d quantized tensor], [qparams], returns a dict of names, values to put in state_dict 193 | # note any final packing for storage should happen here 194 | return self 195 | 196 | def run(self): 197 | assert ( 198 | self.get_qparams_func is not None 199 | ), "need to configure quantization mode before running" 200 | self.gptq_done = True 201 | super().run(*self.inputs) 202 | 203 | def get_quantized_state_dict(self): 204 | assert ( 205 | self.gptq_done 206 | ), "need to run GPTQRunner before you can get_quantized_state_dict" 207 | quantized_state_dict = self.new_state_dict 208 | # Don't want to store/load the kv_cache so remove it from the state_dict 209 | del_list = [] 210 | for param_fqn in quantized_state_dict: 211 | if "kv_cache" in param_fqn: 212 | del_list.append(param_fqn) 213 | for param_fqn in del_list: 214 | quantized_state_dict.pop(param_fqn) 215 | return quantized_state_dict 216 | 217 | def call_function(self, target, args, kwargs, skip_quant=False): 218 | def tensors_to_cuda(args): 219 | new_args = [] 220 | for x in args: 221 | new_args.append(x.cuda() if isinstance(x, torch.Tensor) else x) 222 | return new_args 223 | 224 | # flatten args and kwargs together 225 | flat_args, spec = tree_flatten((args, kwargs)) 226 | # move all single tensors to cuda, will move MultiInputs to cuda one at a time 227 | flat_args = tensors_to_cuda(flat_args) 228 | 229 | has_multi_input = MultiInput in [type(x) for x in flat_args] 230 | if has_multi_input: 231 | # Just some trickery to convert 232 | # [MultiInput[a, a, a], MultiInput(b, b, b)] => [a, b], [a, b], [a, b] 233 | multi_input_count = max( 234 | [len(x.values) if isinstance(x, MultiInput) else 1 for x in flat_args] 235 | ) 236 | transposed_args = list( 237 | zip( 238 | *[x.values if isinstance(x, MultiInput) else [x] * multi_input_count for x in flat_args] 239 | ) 240 | ) 241 | else: 242 | transposed_args = [flat_args] 243 | outputs = [] 244 | 245 | # check whether we apply GPTQ to this module 246 | quantize_linear = ( 247 | (target == aten.linear.default) # if its a linear 248 | and id(args[1]) in self.id_to_name # and if we know the layer name 249 | and not skip_quant # and if we weren't told to skip quantization 250 | # and if the skip_layer_func doesn't say we should skip 251 | and not (self.skip_layer_func is not None and self.skip_layer_func(args[1])) 252 | ) # then we will quantize this linear layer/weight 253 | 254 | if quantize_linear: # instantiate variables for GPTQ 255 | H = 0 256 | total_batches = 0 257 | 258 | for inp in transposed_args: 259 | inp = tensors_to_cuda(inp) 260 | cur_args, cur_kwargs = tree_unflatten(inp, spec) 261 | 262 | if ( 263 | quantize_linear 264 | ): # calculate H instead of output (will 
run the linear eventually with updated weight) 265 | x = cur_args[0].float() 266 | shape = x.shape 267 | n = 1 if len(shape) == 2 else shape[0] 268 | H *= total_batches / (total_batches + n) 269 | total_batches += n 270 | x = ((2 / total_batches) ** (1 / 2)) * x.reshape( 271 | -1, shape[-1] 272 | ).t().float() 273 | H += x.matmul(x.t()) 274 | else: 275 | # get output if its not a linear 276 | out = super().call_function(target, cur_args, cur_kwargs) 277 | 278 | if isinstance(out, torch.Tensor): 279 | outputs.append(out.cpu()) 280 | else: 281 | outputs.append(out) 282 | 283 | if quantize_linear: 284 | mod_fqn = ".".join(self.id_to_name[id(args[1])].split(".")[:-1]) 285 | W = args[1].to(H.device) 286 | Q, DQ, qparams = self.faster_quant(H, W.detach()) 287 | print(mod_fqn) 288 | names_and_values_dict = self.make_names_and_values_dict_func(Q, qparams) 289 | 290 | # delete old weight 291 | if mod_fqn + ".weight" in self.new_state_dict: 292 | self.new_state_dict.pop(mod_fqn + ".weight") 293 | if len(args) > 2: 294 | self.new_state_dict[mod_fqn + ".bias"] = args[2] 295 | for name, value in names_and_values_dict.items(): 296 | self.new_state_dict[mod_fqn + "." + name] = value 297 | 298 | # run linear with new weight to get corrected output 299 | new_out = self.call_function( 300 | target, (args[0], DQ, *args[2:]), kwargs, skip_quant=True 301 | ) 302 | 303 | if self.debug: 304 | old_out = self.call_function( 305 | target, (args[0][:2], args[1], *args[2:]), kwargs, skip_quant=True 306 | ) 307 | 308 | def SQNR(x, y): 309 | return 20 * torch.log10(torch.norm(x) / torch.norm(x - y)) 310 | 311 | DQ_after = self.dequantize_func(Q, qparams).to(W.dtype) 312 | print( 313 | "SQNR for QDQ (this should be inf)", SQNR(DQ, DQ_after) 314 | ) # matches 315 | 316 | print( 317 | "SQNR for weight (can be low)", SQNR(W, DQ.cuda()) 318 | ) # fine to not match 319 | print( 320 | "SQNR for output with GPTQ (hopefully 35+)", 321 | torch.cat( 322 | [ 323 | SQNR(old.cpu(), new.cpu()).unsqueeze(0) 324 | for (old, new) in zip(old_out.values, new_out.values[:2]) 325 | ] 326 | ).mean(), 327 | ) 328 | 329 | qparams2 = self.get_qparams_func(W) 330 | Q2 = self.quantize_func(W, qparams2) 331 | DQ2 = self.dequantize_func(Q2, qparams2).to(W.dtype) 332 | old_q_out = self.call_function( 333 | target, (args[0][:2], DQ2, *args[2:]), kwargs, skip_quant=True 334 | ) 335 | 336 | print("SQNR for output without GPTQ (should be less than above)", 337 | torch.cat([ 338 | SQNR(old.cpu(), old_q.cpu()).unsqueeze(0) 339 | for (old, old_q) in zip(old_out.values, old_q_out.values) 340 | ]).mean(), 341 | ) 342 | return new_out 343 | 344 | return MultiInput(outputs) if has_multi_input else outputs[0] 345 | 346 | def faster_quant(self, H, W): 347 | percdamp = self.percdamp 348 | blocksize = self.blocksize 349 | groupsize = self.groupsize 350 | orig_dtype = W.dtype 351 | W = W.detach().float() 352 | rows, columns = W.shape[0], W.shape[1] 353 | device = W.device 354 | 355 | if groupsize == -1: 356 | cur_qparams = self.get_qparams_func(W) 357 | dead = torch.diag(H) == 0 358 | H[dead, dead] = 1 359 | W[:, dead] = 0 360 | 361 | Losses = torch.zeros_like(W) 362 | DQ = torch.zeros_like(W) 363 | 364 | damp = percdamp * torch.mean(torch.diag(H)) 365 | diag = torch.arange(columns, device=device) 366 | H[diag, diag] += damp 367 | H = torch.linalg.cholesky(H) 368 | H = torch.cholesky_inverse(H) 369 | H = torch.linalg.cholesky(H, upper=True) 370 | Hinv = H 371 | 372 | all_qparams = [] 373 | for i1 in range(0, columns, blocksize): 374 | i2 = min(i1 + blocksize, 
columns) 375 | count = i2 - i1 376 | W1 = W[:, i1:i2].clone() 377 | DQ1 = torch.zeros_like(W1) 378 | Err1 = torch.zeros_like(W1) 379 | Losses1 = torch.zeros_like(W1) 380 | Hinv1 = Hinv[i1:i2, i1:i2] 381 | for i in range(count): 382 | w = W1[:, i] 383 | d = Hinv1[i, i] 384 | 385 | if groupsize != -1 and (i1 + i) % groupsize == 0: # start of new group 386 | cur_qparams = self.get_qparams_func( 387 | W[:, (i1 + i) : (i1 + i + groupsize)] 388 | ) 389 | all_qparams.append(cur_qparams) 390 | 391 | q = self.quantize_func(w.unsqueeze(1), cur_qparams).flatten() 392 | dq = self.dequantize_func(q.unsqueeze(1), cur_qparams).flatten() 393 | 394 | DQ1[:, i] = dq 395 | Losses1[:, i] = (w - dq) ** 2 / d**2 396 | 397 | err1 = (w - dq) / d 398 | W1[:, i:] -= ( 399 | err1.to(Hinv1.dtype).unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) 400 | ) 401 | Err1[:, i] = err1 402 | 403 | DQ[:, i1:i2] = DQ1 404 | Losses[:, i1:i2] = Losses1 / 2 405 | 406 | W[:, i2:] -= Err1.to(Hinv.dtype).matmul(Hinv[i1:i2, i2:]) 407 | 408 | torch.cuda.synchronize() 409 | 410 | if all_qparams == []: 411 | all_qparams.append(cur_qparams) 412 | 413 | # convert a list of qparams objects into a single one. enerally by 414 | # concatenating a bunch of n,1 scale/zeros tensors into a n,num_groups tensor 415 | all_qparams = self.combine_qparams_list_func(all_qparams) 416 | Q = self.quantize_func(DQ, all_qparams) 417 | return Q, DQ.to(orig_dtype), all_qparams 418 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2023 Meta 2 | 3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 4 | 5 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 6 | 7 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 8 | 9 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 10 | 11 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # gpt-fast 2 | Simple and efficient pytorch-native transformer text generation. 3 | 4 | Featuring: 5 | 1. Very low latency 6 | 2. <1000 lines of python 7 | 3. No dependencies other than PyTorch and sentencepiece 8 | 4. int8/int4 quantization 9 | 5. Speculative decoding 10 | 6. 
Tensor parallelism
11 | 7. Supports Nvidia and AMD GPUs
12 | 
13 | This is *NOT* intended to be a "framework" or "library" - it is intended to show off what kind of performance you can get with native PyTorch :) Please copy-paste and fork as you desire.
14 | 
15 | For an in-depth walkthrough of what's in this codebase, see this [blog post](https://pytorch.org/blog/accelerating-generative-ai-2/).
16 | 
17 | ## Supported Models
18 | 
19 | ### LLaMA family
20 | Please check the rest of this page for benchmarks of the LLaMA family of models.
21 | 
22 | ### Mixtral 8x7B
23 | We also support [Mixtral 8x7B](https://mistral.ai/news/mixtral-of-experts/), a high-quality sparse mixture of experts (MoE) model. The average token generation rates are:
24 | 
25 | | | 1 GPU | 2 GPU | 4 GPU | 8 GPU |
26 | |------------------|---------|-----------|--------|------------|
27 | |baseline(bfloat16)| OOM | 96.67 | 155.35 | 227.82 |
28 | | int8 | 97.92 | 155.03 | 216.87 | 279.35 |
29 | 
30 | Note that the benchmarks run on an 8xA100-80GB, power limited to 330W with a hybrid cube mesh topology. All benchmarks are run at *batch size=1*, making the reported tokens/s numbers equivalent to "tokens/s/user". In addition, they are run with a very small prompt length (just 5 tokens).
31 | 
32 | For more details about Mixtral 8x7B, please check [this page](./mixtral-moe) or this [note](https://thonking.substack.com/p/short-supporting-mixtral-in-gpt-fast).
33 | 
34 | ## Examples
35 | In the spirit of keeping the repo minimal, here are various examples of extensions you can make to gpt-fast as PRs.
36 | - [Google Gemma](https://github.com/pytorch-labs/gpt-fast/pull/115)
37 | - [xAI Grok-1](https://github.com/pytorch-labs/gpt-fast/pull/171)
38 | - [Databricks DBRX](https://github.com/pytorch-labs/gpt-fast/pull/174)
39 | 
40 | ## Community
41 | 
42 | Projects inspired by gpt-fast in the community:
43 | 
44 | - [gpt-blazing](https://github.com/armed-gpt/gpt-blazing): applies the same performance optimization strategy to more models (e.g., baichuan2).
45 | - [gptfast](https://github.com/MDK8888/GPTFast): applies a subset of the performance optimizations to all Huggingface models.
46 | - [gpt-accelera](https://github.com/Edward-Sun/gpt-accelera): extends `gpt-fast` to SFT/RM/PPO training and batched inference to optimize the throughput.
47 | 
48 | ## Installation
49 | [Download PyTorch nightly](https://pytorch.org/get-started/locally/)
50 | 
51 | Install required packages:
52 | 
53 | ```bash
54 | pip install -r requirements.txt
55 | ```
56 | 
57 | To download Llama models, go to https://huggingface.co/meta-llama/Llama-2-7b and go through the steps to obtain access. 
58 | Then login with `huggingface-cli login` 59 | 60 | 61 | 62 | ## Downloading Weights 63 | Models tested/supported 64 | ```text 65 | tinyllamas/stories{15,42,100} 66 | openlm-research/open_llama_7b 67 | meta-llama/Llama-2-7b-chat-hf 68 | meta-llama/Llama-2-13b-chat-hf 69 | meta-llama/Llama-2-70b-chat-hf 70 | codellama/CodeLlama-7b-Python-hf 71 | codellama/CodeLlama-34b-Python-hf 72 | mistralai/Mistral-7B-v0.1 73 | mistralai/Mistral-7B-Instruct-v0.1 74 | mistralai/Mistral-7B-Instruct-v0.2 75 | meta-llama/Meta-Llama-3-8B 76 | meta-llama/Meta-Llama-3.1-8B 77 | meta-llama/Meta-Llama-3.1-70B 78 | meta-llama/Meta-Llama-3.1-405B 79 | ``` 80 | 81 | For example, to convert Llama-2-7b-chat-hf 82 | ```bash 83 | export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf 84 | ./scripts/prepare.sh $MODEL_REPO 85 | ``` 86 | 87 | ## Benchmarks 88 | Benchmarks run on an 8xA100-80GB, power limited to 330W with a hybrid cube mesh topology. Note that all benchmarks are run at *batch size=1*, making the reported tokens/s numbers equivalent to "tokens/s/user". In addition, they are run with a very small prompt length (just 5 tokens). 89 | 90 | | Model | Technique | Tokens/Second | Memory Bandwidth (GB/s) | 91 | | -------- | ------- | ------ | ------ | 92 | | Llama-2-7B | Base | 104.9 | 1397.31 | 93 | | | 8-bit | 155.58 | 1069.20 | 94 | | | 4-bit (G=32) | 196.80 | 862.69 | 95 | | Llama-2-70B | Base | OOM || 96 | | | 8-bit | 19.13 | 1322.58 | 97 | | | 4-bit (G=32) | 25.25 | 1097.66 | 98 | | Llama-3.1-8B | Base | 93.89 | 1410.76 | 99 | | | 8-bit | 137.64 | 1030.89 | 100 | | Llama-3.1-70B | Base | OOM || 101 | | | 8-bit | 18.04 | 1253.78 | 102 | 103 | ### Speculative Sampling 104 | [Verifier: Llama-70B (int4), Draft: Llama-7B (int4)](./scripts/speculate_70B_int4.sh): 48.4 tok/s 105 | 106 | ### Tensor Parallelism 107 | | Model | Number of GPUs | Tokens/Second | Memory Bandwidth (GB/s) | 108 | | -------- | ------- | ------ | ------ | 109 | | Llama-2-7B | 1 | 104.9 | 1397.31 | 110 | | | 2 | 168.84 | 1181.99 | 111 | | | 4 | 254.02 | 955.83 | 112 | | | 8 | 328.43 | 704.10 | 113 | | Llama-2-70B | 1 | OOM | | 114 | | | 2 | 21.32 | 1481.87 | 115 | | | 4 | 38.01 | 1340.76 | 116 | | | 8 | 62.50 | 1135.29 | 117 | | Llama-3.1-8B | 1 | 93.83 | 1408.37 | 118 | | | 2 | 149.10 | 1197.32 | 119 | | | 4 | 217.21 | 986.32 | 120 | | | 8 | 276.01 | 772.60 | 121 | | Llama-3.1-70B | 1 | OOM | | 122 | | | 2 | 16.03 | 1130.81 | 123 | | | 4 | 37.45 | 1360.53 | 124 | | | 8 | 58.78 | 1129.61 | 125 | 126 | ### Tensor Parallelism + Quantization 127 | | Model | Technique | Tokens/Second | Memory Bandwidth (GB/s) | 128 | | -------- | ------- | ------ | ------ | 129 | | Llama-2-70B | Base | 62.50 | 1135.29 | 130 | | | 8-bit | 80.44 | 752.04 | 131 | | | 4-bit (G=32) | 90.77 | 548.10 | 132 | | Llama-3.1-70B | Base | 58.78 | 1129.61 | 133 | | | 8-bit | 75.58 | 726.57 | 134 | | Llama-3.1-405B | 8-bit | 15.60 | 815.87 | 135 | 136 | ### AMD 137 | Benchmarks run on one GCD of a MI-250x. 138 | 139 | | Model | Technique | Tokens/Second | Memory Bandwidth (GB/s) | 140 | | -------- | ------- | ------ | ------ | 141 | | Llama-2-7B | Base | 76.33 | 1028.70 | 142 | | | 8-bit | 101.86 | 700.06 | 143 | 144 | ## Generate Text 145 | 146 | Model definition in `model.py`, generation code in `generate.py`. 147 | 148 | ```bash 149 | python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model.pth --prompt "Hello, my name is" 150 | ``` 151 | 152 | To squeeze out a little bit more performance, you can also compile the prefill with `--compile_prefill`. 
This will increase compilation times though.
153 | 
154 | ## Quantization
155 | Choose the device to use:
156 | ```bash
157 | # The currently supported devices: cuda, cpu
158 | export DEVICE=cuda
159 | ```
160 | ### Int8 Weight-Only Quantization
161 | To generate this version of the model:
162 | ```bash
163 | # Spits out model at checkpoints/$MODEL_REPO/model_int8.pth
164 | python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int8
165 | ```
166 | To run with int8, just pass the int8 checkpoint to generate.py.
167 | ```bash
168 | python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model_int8.pth --device $DEVICE
169 | ```
170 | 
171 | ### Int4 Weight-Only Quantization
172 | To generate the int4 version of the model:
173 | ```bash
174 | # Spits out model at checkpoints/$MODEL_REPO/model_int4.g32.$DEVICE.pth
175 | python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4 --groupsize 32
176 | ```
177 | 
178 | To run with int4, just pass the int4 checkpoint to generate.py.
179 | ```bash
180 | python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.pth --compile
181 | ```
182 | 
183 | ## Speculative Sampling
184 | To generate with speculative sampling, DRAFT_MODEL_REPO should point to a smaller model than MODEL_REPO.
185 | 
186 | In this example, the "smaller" model is just the int8 quantized version of the model.
187 | ```bash
188 | export DRAFT_MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
189 | python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model.pth --draft_checkpoint_path checkpoints/$DRAFT_MODEL_REPO/model_int8.pth
190 | ```
191 | 
192 | Note: Run on an A100 80GB, power-limited to 330 watts. Empirically, peak bandwidth appears to be about 1700 GB/s.
193 | 
194 | 
195 | ## Tensor Parallelism
196 | ```bash
197 | ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=2 generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model.pth
198 | ```
199 | 
200 | ## Experimental
201 | ### Evaluation
202 | We use the EleutherAI evaluation harness to evaluate model accuracy. Make sure the evaluation harness is installed, then pass your model checkpoint and desired tasks to eval.py.
203 | 
204 | ```bash
205 | python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --compile --tasks hellaswag winogrande
206 | ```
207 | 
208 | Note: Generative tasks are currently not supported for gpt-fast.
209 | 
210 | Installation instructions for the evaluation harness: https://github.com/EleutherAI/lm-evaluation-harness/tree/master#install
211 | 
212 | ### GPTQ
213 | We have a pure PyTorch implementation of GPTQ that utilizes torch._dynamo.export to access the model structure. You can generate a GPTQ-quantized
214 | int4 version of the model by using the same quantization command and adding 'gptq' to the quantization mode, i.e.
215 | ```bash
216 | # Spits out model at checkpoints/$MODEL_REPO/model_int4-gptq.g32.pth
217 | python quantize.py --mode int4-gptq --calibration_tasks wikitext --calibration_seq_length 2048
218 | ```
219 | 
220 | You can then eval or generate text with this model in the same way as above.
221 | 
222 | ## License
223 | 
224 | `gpt-fast` is released under the [BSD 3](https://github.com/pytorch-labs/gpt-fast/main/LICENSE) license.
225 | 
226 | ## Acknowledgements
227 | Thanks to:
228 | * Lightning AI for supporting pytorch and work in flash attention, int8 quantization, and LoRA fine-tuning. 
229 | * GGML for driving forward fast, on device inference of LLMs 230 | * Karpathy for spearheading simple, interpretable and fast LLM implementations 231 | * MLC-LLM for pushing 4-bit quantization performance on heterogeneous hardware 232 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import sys 7 | import time 8 | from pathlib import Path 9 | from typing import Optional 10 | 11 | import torch 12 | import torch._dynamo.config 13 | import torch._inductor.config 14 | 15 | torch._dynamo.config.automatic_dynamic_shapes = True 16 | torch._inductor.config.triton.unique_kernel_names = True 17 | torch._inductor.config.epilogue_fusion = False 18 | torch._dynamo.config.cache_size_limit = 100000 19 | 20 | from tokenizer import get_tokenizer 21 | 22 | from model import Transformer 23 | 24 | try: 25 | import lm_eval 26 | lm_eval_available = True 27 | except: 28 | lm_eval_available = False 29 | 30 | from generate import _load_model, encode_tokens, model_forward 31 | 32 | if lm_eval_available: 33 | try: # lm_eval version 0.4 34 | from lm_eval.models.huggingface import HFLM as eval_wrapper 35 | from lm_eval.tasks import get_task_dict 36 | from lm_eval.evaluator import evaluate 37 | except: #lm_eval version 0.3 38 | from lm_eval import base 39 | from lm_eval import tasks 40 | from lm_eval import evaluator 41 | eval_wrapper=base.BaseLM 42 | get_task_dict=tasks.get_task_dict 43 | evaluate=evaluator.evaluate 44 | 45 | 46 | def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill( 47 | model: Transformer, 48 | prompt: torch.Tensor, 49 | max_new_tokens: int, 50 | max_seq_length: Optional[int] = None, 51 | ): 52 | """ 53 | Sets up model cache and does some bookkeeping calculations for prompt, input_pos and max_seq_length 54 | that are needed for prefill or model_forward 55 | 56 | Args: 57 | model (LLaMA): The model whose cache gets set up 58 | prompt (torch.Tensor): Tensor of shape (T) with indices of the prompt sequence. 59 | max_new_tokens (int): The desired maximum number of new tokens that can be generated. 60 | max_seq_length (Optional[int], optional): The maximum sequence length allowed. 61 | 62 | Returns: 63 | seq (torch.Tensor): prompt but padded with zeros to size max_seq_length 64 | input_pos (torch.Tensor): tensor of integers in increasing order 65 | max_seq_length (int): The maximum sequence length allowed, updated based on other numbers 66 | """ 67 | T = prompt.size(0) 68 | T_new = T + max_new_tokens 69 | if max_seq_length is None: 70 | max_seq_length = min(T_new, model.config.block_size) 71 | 72 | device, dtype = prompt.device, prompt.dtype 73 | # create an empty tensor of the expected final shape and fill in the current tokens 74 | empty = torch.empty(T_new, dtype=dtype, device=device) 75 | empty[:T] = prompt 76 | seq = empty 77 | input_pos = torch.arange(0, T, device=device) 78 | 79 | with torch.device(device): 80 | model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) 81 | 82 | return seq, input_pos, max_seq_length 83 | 84 | class GPTFastEvalWrapper(eval_wrapper): 85 | """ 86 | A wrapper class for GPTFast, providing integration with the lm-evaluation-harness library. 
87 | """ 88 | def __init__( 89 | self, 90 | model: Transformer, 91 | tokenizer, 92 | max_seq_length: Optional[int]=None, 93 | ): 94 | super().__init__() 95 | self._model = model 96 | self._tokenizer = tokenizer 97 | self._device = torch.device('cuda') 98 | self._max_seq_length = 2048 if max_seq_length is None else max_seq_length 99 | 100 | @property 101 | def eot_token_id(self): 102 | return self._tokenizer.eos_id() 103 | 104 | @property 105 | def max_length(self): 106 | return self._max_seq_length 107 | 108 | @property 109 | def max_gen_toks(self): 110 | return 50 111 | 112 | @property 113 | def batch_size(self): 114 | return 1 115 | 116 | @property 117 | def device(self): 118 | return self._device 119 | 120 | def tok_encode(self, string: str, **kwargs): 121 | encoded = encode_tokens(self._tokenizer, 122 | string, bos=True, device=self._device) 123 | # encoded is a pytorch tensor, but some internal logic in the 124 | # eval harness expects it to be a list instead 125 | # TODO: verify this for multi-batch as well 126 | encoded = encoded.tolist() 127 | return encoded 128 | 129 | def tok_decode(self, tokens): 130 | decoded = self._tokenizer.decode(tokens) 131 | return decoded 132 | 133 | def _model_call(self, inps): 134 | # TODO: make batches work 135 | inps = inps.squeeze(0) 136 | 137 | max_new_tokens = 1 138 | seq, input_pos, max_seq_length = \ 139 | setup_cache_padded_seq_input_pos_max_seq_length_for_prefill( 140 | self._model, 141 | inps, 142 | max_new_tokens, 143 | self.max_length, 144 | ) 145 | x = seq.index_select(0, input_pos).view(1, -1) 146 | logits = model_forward(self._model, x, input_pos) 147 | return logits 148 | 149 | def _model_generate(self, context, max_length, eos_token_id): 150 | raise Exception('unimplemented') 151 | 152 | 153 | @torch.no_grad() 154 | def eval( 155 | model: Transformer, 156 | tokenizer, 157 | tasks: list = ["hellaswag"], 158 | limit: Optional[int] = None, 159 | max_seq_length: Optional[int] = None, 160 | ) -> dict: 161 | """ 162 | Evaluates a language model on a specified task using the lm-evaluation-harness library. 163 | 164 | Args: 165 | model (Transformer): The pre-trained language model to evaluate. 166 | tokenizer: The tokenizer to use for encoding/decoding text. 167 | tasks (list): The names of the evaluation tasks to perform. 168 | limit (Optional[int]): The maximum number of samples to evaluate (None for all available). 169 | max_seq_length (Optional[int]): The maximum sequence length allowed for input text. 170 | 171 | Returns: 172 | eval_results (dict): A dictionary of evaluation results for the specified task(s). 173 | """ 174 | model_eval_wrapper = GPTFastEvalWrapper( 175 | model, 176 | tokenizer, 177 | max_seq_length, 178 | ) 179 | 180 | try: 181 | lm_eval.tasks.initialize_tasks() 182 | except: 183 | pass 184 | 185 | if 'hendrycks_test' in tasks: 186 | tasks.remove('hendrycks_test') 187 | tasks += [x for x in lm_eval.tasks.hendrycks_test.create_all_tasks().keys()] 188 | task_dict = get_task_dict(tasks) 189 | 190 | eval_results = evaluate( 191 | model_eval_wrapper, 192 | task_dict, 193 | limit=limit, 194 | ) 195 | return eval_results 196 | 197 | 198 | def main( 199 | checkpoint_path: Path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/lit_model.pth"), 200 | compile: bool = False, 201 | tasks: list = ["hellaswag"], 202 | limit: Optional[int] = None, 203 | max_seq_length: Optional[int] = None, 204 | ) -> None: 205 | """Evaluates model on a task from the `lm-evaluation-harness` library. 
206 | 207 | Args: 208 | checkpoint_path (Path): The path to the model checkpoint file to load. 209 | compile (bool): Whether or not to compile the model for optimization. 210 | tasks (list): The names of the evaluation tasks to perform. 211 | limit (Optional[int]): The maximum number of samples to evaluate (None for all available). 212 | max_seq_length (Optional[int]): The maximum sequence length allowed for input text. 213 | 214 | """ 215 | 216 | assert checkpoint_path.is_file(), checkpoint_path 217 | 218 | tokenizer_path = checkpoint_path.parent / "tokenizer.model" 219 | assert tokenizer_path.is_file(), str(tokenizer_path) 220 | 221 | device = 'cuda' 222 | precision = torch.bfloat16 223 | 224 | print("Loading model ...") 225 | t0 = time.time() 226 | model = _load_model(checkpoint_path, device, precision, False) 227 | 228 | torch.cuda.synchronize() 229 | print(f"Time to load model: {time.time() - t0:.02f} seconds.") 230 | 231 | model.eval() 232 | 233 | tokenizer = get_tokenizer(tokenizer_path, checkpoint_path) 234 | 235 | torch.manual_seed(1234) 236 | 237 | if compile: 238 | global model_forward 239 | model_forward = torch.compile(model_forward, mode="reduce-overhead", dynamic=True, fullgraph=True) 240 | torch._inductor.config.coordinate_descent_tuning = True 241 | 242 | t1 = time.time() 243 | result = eval( 244 | model, 245 | tokenizer, 246 | tasks, 247 | limit, 248 | max_seq_length, 249 | ) 250 | print(f"Time to run eval: {time.time() - t1:.02f} seconds.") 251 | print(f"For model {checkpoint_path}") 252 | for task, res in result["results"].items(): 253 | print(f"{task}: {res}") 254 | 255 | 256 | if __name__ == '__main__': 257 | import argparse 258 | parser = argparse.ArgumentParser(description='Your CLI description.') 259 | 260 | parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/lit_model.pth"), help='Model checkpoint path.') 261 | parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') 262 | parser.add_argument('--tasks', nargs='+', type=str, default=["hellaswag"], help='list of lm-eluther tasks to evaluate usage: --tasks task1 task2') 263 | parser.add_argument('--limit', type=int, default=None, help='number of samples to evalulate') 264 | parser.add_argument('--max_seq_length', type=int, default=None, help='maximum length sequence to evaluate') 265 | 266 | args = parser.parse_args() 267 | main( 268 | Path(args.checkpoint_path), args.compile, args.tasks, args.limit, args.max_seq_length, 269 | ) 270 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
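# Text generation entry point: loads a Transformer checkpoint (optionally int8/int4
# weight-only quantized), compiles the decode step with torch.compile, and generates text
# token by token. Also implements speculative decoding with a smaller draft model and can
# run tensor-parallel across GPUs via tp.apply_tp. See main() for the CLI flags.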
6 | import itertools 7 | import sys 8 | import time 9 | from pathlib import Path 10 | from typing import Optional, Tuple, Union 11 | 12 | import torch 13 | import torch._dynamo.config 14 | import torch._inductor.config 15 | from torch.nn.attention.flex_attention import BlockMask, create_block_mask 16 | 17 | def device_sync(device): 18 | if "cuda" in device: 19 | torch.cuda.synchronize(device) 20 | elif ("cpu" in device) or ("mps" in device): 21 | pass 22 | else: 23 | print(f"device={device} is not yet suppported") 24 | 25 | 26 | torch._inductor.config.coordinate_descent_tuning = True 27 | torch._inductor.config.triton.unique_kernel_names = True 28 | # Experimental features to reduce compilation times, will be on by default in future 29 | torch._inductor.config.fx_graph_cache = True 30 | torch._functorch.config.enable_autograd_cache = True 31 | 32 | default_device = 'cuda' if torch.cuda.is_available() else 'cpu' 33 | 34 | create_block_mask = torch.compile(create_block_mask) 35 | 36 | # support running without installing as a package 37 | wd = Path(__file__).parent.parent.resolve() 38 | sys.path.append(str(wd)) 39 | 40 | from model import Transformer 41 | from tokenizer import get_tokenizer 42 | 43 | def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization 44 | q = torch.empty_like(probs_sort).exponential_(1) 45 | return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) 46 | 47 | def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None): 48 | logits = logits / max(temperature, 1e-5) 49 | 50 | if top_k is not None: 51 | v, _ = torch.topk(logits, min(top_k, logits.size(-1))) 52 | pivot = v.select(-1, -1).unsqueeze(-1) 53 | logits = torch.where(logits < pivot, -float("Inf"), logits) 54 | probs = torch.nn.functional.softmax(logits, dim=-1) 55 | return probs 56 | 57 | def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None): 58 | probs = logits_to_probs(logits[:, -1], temperature, top_k) 59 | idx_next = multinomial_sample_one_no_sync(probs) 60 | return idx_next, probs 61 | 62 | def roundup(val, multiplier): 63 | return ((val - 1) // multiplier + 1) * multiplier 64 | 65 | def causal_mask(b, h, q, kv): 66 | return q >= kv 67 | 68 | def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor: 69 | # input_pos: [B, S] 70 | mask = create_block_mask(causal_mask, 1, 1, input_pos.shape[0], model.max_seq_length, device=x.device) 71 | logits = model(mask, x, input_pos) 72 | return sample(logits, **sampling_kwargs)[0] 73 | 74 | def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, block_mask: BlockMask, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]: 75 | # input_pos: [B, 1] 76 | assert input_pos.shape[-1] == 1 77 | block_index = input_pos // block_mask.BLOCK_SIZE[0] 78 | mask = block_mask[:, :, block_index] 79 | mask.mask_mod = block_mask.mask_mod 80 | mask.seq_lengths = (1, model.max_seq_length) 81 | logits = model(mask, x, input_pos) 82 | return sample(logits, **sampling_kwargs) 83 | 84 | def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs): 85 | block_mask = create_block_mask(causal_mask, 1, 1, model.max_seq_length, model.max_seq_length, device=cur_token.device) 86 | new_tokens, new_probs = [], [] 87 | for i in range(num_new_tokens): 88 | next_token, next_prob = decode_one_token( 89 | model, cur_token, 
input_pos, block_mask, **sampling_kwargs 90 | ) 91 | input_pos += 1 92 | new_tokens.append(next_token.clone()) 93 | callback(new_tokens[-1]) 94 | new_probs.append(next_prob.clone()) 95 | cur_token = next_token.clone() 96 | 97 | return new_tokens, new_probs 98 | 99 | 100 | def model_forward(model, x, input_pos): 101 | return model(x, input_pos) 102 | 103 | def speculative_decode( 104 | model: Transformer, 105 | draft_model: Transformer, 106 | cur_token: torch.Tensor, 107 | input_pos: int, 108 | speculate_k: int, 109 | **sampling_kwargs 110 | ) -> torch.Tensor: 111 | # draft model inference sequentially 112 | device = cur_token.device 113 | orig_input_pos = torch.tensor([input_pos], dtype=torch.int64, device=cur_token.device) 114 | draft_tokens, draft_probs = decode_n_tokens(draft_model, cur_token.view(1, -1), orig_input_pos.clone(), speculate_k, **sampling_kwargs) 115 | 116 | draft_tokens = torch.cat(draft_tokens) 117 | # parallel inference on target model using draft tokens 118 | target_logits = model_forward( 119 | model, 120 | torch.cat([cur_token.view(1), draft_tokens]).view(1, -1), 121 | torch.arange(input_pos, input_pos + speculate_k + 1, device=cur_token.device) 122 | ) 123 | target_probs = logits_to_probs(target_logits[0], **sampling_kwargs) 124 | draft_probs = torch.stack(draft_probs) 125 | # q: target prob, p: draft prob 126 | # q >= p: always accept draft token 127 | # q < p: q/p prob to accept draft token 128 | p = draft_probs[torch.arange(0, speculate_k, device=device), draft_tokens] 129 | q = target_probs[torch.arange(0, speculate_k, device=device), draft_tokens] 130 | accept_draft_prob = torch.minimum(torch.ones(()), q[:speculate_k]/ p) 131 | rejected_locations = (torch.rand_like(accept_draft_prob) > accept_draft_prob).nonzero() 132 | 133 | if rejected_locations.shape[0] == 0: # All draft tokens have been accepted 134 | accept_length = speculate_k + 1 135 | last_token = multinomial_sample_one_no_sync(target_probs[-1]) 136 | # fill last token into draft model 137 | model_forward( 138 | draft_model, 139 | draft_tokens[-1].view(1, -1), 140 | orig_input_pos + speculate_k, 141 | ) 142 | return torch.cat([draft_tokens, last_token]) 143 | else: 144 | accept_length = rejected_locations[0].item() 145 | p = draft_probs[accept_length] 146 | q = target_probs[accept_length] 147 | new = q - p 148 | new = torch.where(new > 0, new, 0.0) 149 | new = new / new.sum() 150 | next_token = multinomial_sample_one_no_sync(new) 151 | return torch.cat([draft_tokens[:accept_length], next_token]) 152 | 153 | @torch.no_grad() 154 | def generate( 155 | model: Transformer, 156 | prompt: torch.Tensor, 157 | max_new_tokens: int, 158 | batch_size: int, 159 | *, 160 | interactive: bool, 161 | draft_model: Transformer, 162 | speculate_k: Optional[int] = 8, 163 | callback = lambda x: x, 164 | **sampling_kwargs 165 | ) -> torch.Tensor: 166 | """ 167 | Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. 
168 | """ 169 | 170 | is_speculative = draft_model is not None 171 | # create an empty tensor of the expected final shape and fill in the current tokens 172 | T = prompt.size(-1) 173 | T_new = T + max_new_tokens 174 | if interactive: 175 | max_seq_length = 350 176 | else: 177 | max_seq_length = min(T_new, model.config.block_size) 178 | 179 | device, dtype = prompt.device, prompt.dtype 180 | max_seq_length = max_seq_length + speculate_k + 1 if is_speculative else max_seq_length 181 | with torch.device(device): 182 | model.setup_caches(max_batch_size=batch_size, max_seq_length=max_seq_length) 183 | if is_speculative and draft_model is not model: 184 | draft_model.setup_caches(max_batch_size=batch_size, max_seq_length=max_seq_length) 185 | 186 | # create an empty tensor of the expected final shape and fill in the current tokens 187 | empty = torch.empty(batch_size, T_new, dtype=dtype, device=device) 188 | # We are just making the same prompt for every batch 189 | prompt = prompt.view(1, -1).repeat(batch_size, 1) 190 | empty[:, :T] = prompt 191 | seq = empty 192 | input_pos = torch.arange(0, T, device=device) 193 | 194 | next_token = prefill(model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs).clone() 195 | if is_speculative: 196 | prefill(draft_model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs) 197 | seq[:, T] = next_token.squeeze() 198 | 199 | input_pos = torch.tensor([T], device=device, dtype=torch.int) 200 | accept_counts = [0] * (speculate_k + 1) 201 | 202 | if is_speculative: 203 | input_pos = input_pos.item() # for speculative decoding easier to keep on host 204 | while input_pos < T_new - 1: 205 | cur_token = next_token.view(()) 206 | 207 | next_tokens = speculative_decode( 208 | model, draft_model, cur_token, input_pos, speculate_k, **sampling_kwargs 209 | ) 210 | 211 | accept_counts[len(next_tokens) - 1] += 1 212 | num_added = min(T_new - input_pos - 1, len(next_tokens)) 213 | seq[input_pos + 1 : input_pos + num_added + 1] = next_tokens[: num_added] 214 | for i in next_tokens[: num_added,]: 215 | callback(i) 216 | input_pos = input_pos + num_added 217 | next_token = next_tokens[-1] 218 | else: 219 | generated_tokens, _ = decode_n_tokens(model, next_token.view(batch_size, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs) 220 | seq[:, T + 1:] = torch.cat(generated_tokens, dim=-1) 221 | 222 | generate_stats = { 223 | 'accept_counts': accept_counts 224 | } 225 | return seq, generate_stats 226 | 227 | def encode_tokens(tokenizer, string, bos=True, device=default_device): 228 | tokens = tokenizer.encode(string) 229 | if bos: 230 | tokens = [tokenizer.bos_id()] + tokens 231 | return torch.tensor(tokens, dtype=torch.int, device=device) 232 | 233 | def _load_model(checkpoint_path, device, precision, use_tp): 234 | use_cuda = 'cuda' in device 235 | with torch.device('meta'): 236 | model = Transformer.from_name(checkpoint_path.parent.name) 237 | 238 | if "int8" in str(checkpoint_path): 239 | print("Using int8 weight-only quantization!") 240 | from quantize import WeightOnlyInt8QuantHandler 241 | simple_quantizer = WeightOnlyInt8QuantHandler(model) 242 | model = simple_quantizer.convert_for_runtime() 243 | 244 | if "int4" in str(checkpoint_path): 245 | print("Using int4 weight-only quantization!") 246 | path_comps = checkpoint_path.name.split(".") 247 | groupsize = int(path_comps[-2][1:]) 248 | from quantize import WeightOnlyInt4QuantHandler 249 | simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize) 250 | model = 
simple_quantizer.convert_for_runtime() 251 | 252 | checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) 253 | if "model" in checkpoint and "stories" in str(checkpoint_path): 254 | checkpoint = checkpoint["model"] 255 | model.load_state_dict(checkpoint, assign=True) 256 | 257 | if use_tp: 258 | from tp import apply_tp 259 | print("Applying tensor parallel to model ...") 260 | apply_tp(model) 261 | 262 | model = model.to(device=device, dtype=precision) 263 | return model.eval() 264 | 265 | def _get_model_size(model): 266 | model_size = 0 267 | params = 0 268 | for name, child in model.named_children(): 269 | if not isinstance(child, torch.nn.Embedding): 270 | model_size += sum( 271 | [ 272 | p.numel() * p.dtype.itemsize 273 | for p in itertools.chain(child.parameters(), child.buffers()) 274 | ] 275 | ) 276 | params += sum( 277 | [ 278 | p.numel() 279 | for p in itertools.chain(child.parameters(), child.buffers()) 280 | ] 281 | ) 282 | return model_size, params 283 | 284 | B_INST, E_INST = "[INST]", "[/INST]" 285 | 286 | def main( 287 | prompt: Union[int, str] = "Hello, my name is", 288 | interactive: bool = False, 289 | num_samples: int = 5, 290 | max_new_tokens: int = 100, 291 | batch_size: int = 1, 292 | top_k: int = 200, 293 | temperature: float = 0.8, 294 | checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), 295 | compile: bool = True, 296 | compile_prefill: bool = False, 297 | profile: Optional[Path] = None, 298 | draft_checkpoint_path: Optional[Path] = None, 299 | speculate_k: int = 5, 300 | device=default_device, 301 | ) -> None: 302 | """Generates text samples based on a pre-trained Transformer model and tokenizer. 303 | """ 304 | assert checkpoint_path.is_file(), checkpoint_path 305 | 306 | tokenizer_path = checkpoint_path.parent / "tokenizer.model" 307 | assert tokenizer_path.is_file(), str(tokenizer_path) 308 | 309 | global print 310 | from tp import maybe_init_dist 311 | rank = maybe_init_dist() 312 | use_tp = rank is not None 313 | if use_tp: 314 | if rank != 0: 315 | # only print on rank 0 316 | print = lambda *args, **kwargs: None 317 | 318 | print(f"Using device={device}") 319 | precision = torch.bfloat16 320 | is_speculative = draft_checkpoint_path is not None 321 | is_chat = "chat" in str(checkpoint_path) 322 | 323 | print("Loading model ...") 324 | t0 = time.time() 325 | model = _load_model(checkpoint_path, device, precision, use_tp) 326 | 327 | if is_speculative: 328 | draft_model = _load_model(draft_checkpoint_path, device, precision, use_tp) 329 | else: 330 | draft_model = None 331 | 332 | device_sync(device=device) # MKG 333 | print(f"Time to load model: {time.time() - t0:.02f} seconds") 334 | 335 | tokenizer = get_tokenizer(tokenizer_path, checkpoint_path) 336 | 337 | if isinstance(prompt, str): 338 | encoded = encode_tokens(tokenizer, prompt, bos=True, device=device) 339 | else: 340 | # generate a fully synthetic prompt 341 | encoded = torch.randint(0, 1024, (prompt,), device=device, dtype=torch.int64) 342 | prompt_length = encoded.size(-1) 343 | 344 | torch.manual_seed(1234) 345 | model_size, params = _get_model_size(model) 346 | if compile: 347 | if is_speculative and use_tp: # and ("cuda" in device): 348 | torch._inductor.config.triton.cudagraph_trees = False # Bug with cudagraph trees in this case 349 | 350 | if is_speculative: 351 | global model_forward, logits_to_prob 352 | model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True) 353 | 354 | global decode_one_token, 
prefill 355 | decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True) 356 | 357 | # Uncomment to squeeze more perf out of prefill 358 | if compile_prefill: 359 | prefill = torch.compile(prefill, fullgraph=True, dynamic=True) 360 | 361 | 362 | aggregate_metrics = { 363 | 'tokens_per_sec': [], 364 | 'accept_counts': [], 365 | } 366 | start = -1 if compile else 0 367 | 368 | for i in range(start, num_samples): 369 | device_sync(device=device) # MKG 370 | if i >= 0 and interactive: 371 | prompt = input("What is your prompt? ") 372 | if is_chat: 373 | prompt = f"{B_INST} {prompt.strip()} {E_INST}" 374 | encoded = encode_tokens(tokenizer, prompt, bos=True, device=device) 375 | 376 | if interactive and i >= 0: 377 | buffer = [] 378 | period_id = tokenizer.encode('.')[0] 379 | done_generating = False 380 | def callback(x): 381 | nonlocal done_generating 382 | if done_generating: 383 | return 384 | buffer.append(tokenizer.decode([period_id] + x.tolist())[1:]) 385 | if x.item() == tokenizer.eos_id(): 386 | done_generating = True 387 | if len(buffer) == 4 or done_generating: 388 | print(''.join(buffer), end='', flush=True) 389 | buffer.clear() 390 | # print(, end='', flush=True) 391 | else: 392 | callback = lambda x : x 393 | t0 = time.perf_counter() 394 | import contextlib 395 | if (i != num_samples - 1 or not profile) or (use_tp and rank != 0): 396 | prof = contextlib.nullcontext() 397 | else: 398 | torch.profiler._utils._init_for_cuda_graphs() 399 | prof = torch.profiler.profile() 400 | with prof: 401 | y, metrics = generate( 402 | model, 403 | encoded, 404 | max_new_tokens, 405 | batch_size=batch_size, 406 | draft_model=draft_model, 407 | speculate_k=speculate_k, 408 | interactive=interactive, 409 | callback=callback, 410 | temperature=temperature, 411 | top_k=top_k, 412 | ) 413 | aggregate_metrics['accept_counts'].append(metrics['accept_counts']) 414 | if i == -1: 415 | print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") 416 | continue 417 | if hasattr(prof, "export_chrome_trace"): 418 | if use_tp: 419 | prof.export_chrome_trace(f"{profile}_rank_{rank}.json") 420 | else: 421 | prof.export_chrome_trace(f"{profile}.json") 422 | device_sync(device=device) # MKG 423 | t = time.perf_counter() - t0 424 | 425 | if not interactive: 426 | # Just displaying the first generation 427 | if batch_size > 1: 428 | print("Only displaying the first generation of the batch") 429 | print(tokenizer.decode(y[0].tolist())) 430 | else: 431 | print() 432 | tokens_generated = y.size(-1) - prompt_length 433 | generated_tokens_sec = tokens_generated / t 434 | aggregate_metrics['tokens_per_sec'].append(generated_tokens_sec) 435 | print(f"Time for inference {i + 1}: {t:.02f} sec total, {generated_tokens_sec:.02f} tokens/sec") 436 | print(f"Bandwidth achieved: {model_size * generated_tokens_sec / 1e9:.02f} GB/s") 437 | total_tokens_sec = y.numel() / t 438 | print(f"FLOPS achieved: {params * total_tokens_sec * 2 / 1e12:.02f} TF/s") 439 | print() 440 | print("==========") 441 | if is_speculative: 442 | counts_aggregated = [sum(i) for i in zip(*aggregate_metrics['accept_counts'])] 443 | acceptance_probs = [i/sum(counts_aggregated) for i in counts_aggregated] 444 | print(f"Acceptance probs: {acceptance_probs}") 445 | print(f"Mean Accepted: {sum([idx * i for idx, i in enumerate(counts_aggregated)])/sum(counts_aggregated)}") 446 | 447 | print(f"Batch Size: {batch_size}") 448 | print(f"Prompt Length: {prompt_length}") 449 | print(f"Generated tokens: {max_new_tokens}") 450 | 
print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}")
451 | print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
452 |
453 |
454 | if __name__ == '__main__':
455 | import argparse
456 | parser = argparse.ArgumentParser(description='Your CLI description.')
457 |
458 | def int_or_str(x):
459 | try:
460 | return int(x)
461 | except ValueError:
462 | return x
463 |
464 | parser.add_argument('--prompt', type=int_or_str, default="Hello, my name is", help="Input prompt. If it's an integer, will instead generate a synthetic prompt.")
465 | parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode')
466 | parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.')
467 | parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.')
468 | parser.add_argument('--batch_size', type=int, default=1, help='Batch size to benchmark with')
469 | parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
470 | parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
471 | parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
472 | parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
473 | parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)')
474 | parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
475 | parser.add_argument('--speculate_k', type=int, default=5, help='Speculative execution depth.')
476 | parser.add_argument('--draft_checkpoint_path', type=Path, default=None, help='Draft checkpoint path.')
477 | parser.add_argument('--device', type=str, default=default_device, help='Device to use')
478 |
479 | args = parser.parse_args()
480 | main(
481 | args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.batch_size, args.top_k,
482 | args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.draft_checkpoint_path,
483 | args.speculate_k, args.device
484 | )
485 |
--------------------------------------------------------------------------------
/mixtral-moe/README.md:
--------------------------------------------------------------------------------
1 | # Mixtral 8x7B
2 | [Mixtral 8x7B](https://mistral.ai/news/mixtral-of-experts/) is a high-quality sparse mixture of experts (MoE) model that matches or beats GPT3.5 on most benchmarks. This repo is a simple and efficient PyTorch native implementation of Mixtral 8x7B.
3 |
4 | ## Downloading Weights
5 |
6 | ```bash
7 | export MODEL_REPO=mistralai/Mixtral-8x7B-v0.1
8 | python scripts/download.py --repo_id $MODEL_REPO
9 | python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$MODEL_REPO
10 | ```
11 |
12 | ## Benchmarks
13 | Benchmarks were run on an 8xA100-80GB node, power limited to 330W, with a hybrid cube mesh topology. Note that all benchmarks are run at *batch size=1*, making the reported tokens/s numbers equivalent to "tokens/s/user". In addition, they are run with a very small prompt length (just 5 tokens).
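For example, the 8-GPU int8 row combines the quantization and tensor-parallelism commands documented below (a sketch; see those sections for details):

```bash
# Quantize once, then benchmark across 8 GPUs with the int8 checkpoint.
python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int8
ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=8 generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model_int8.pth
```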
14 | 15 | | | 1 GPU | 2 GPU | 4 GPU | 8 GPU | 16 | |------------------|---------|-----------|--------|------------| 17 | |baseline(bfloat16)| OOM | 96.67 | 155.35 | 227.82 | 18 | | int8 | 97.92 | 155.03 | 216.87 | 279.35 | 19 | 20 | 21 | ## Generate Text 22 | 23 | Model definition in `model.py`, generation code in `generate.py`. 24 | 25 | ```bash 26 | python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model.pth --prompt "Hello, my name is" 27 | ``` 28 | 29 | To squeeze out a little bit more performance, you can also compile the prefill with `--compile_prefill`. This will increase compilation times though. 30 | 31 | ## Quantization 32 | ### Int8 Weight-Only Quantization 33 | To generate this version of the model 34 | ```bash 35 | # Spits out model at checkpoints/$MODEL_REPO/model_int8.pth 36 | python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int8 37 | ``` 38 | To run with int8, just pass the int8 checkpoint to generate.py. 39 | ```bash 40 | python generate.py --compile --compile_prefill --checkpoint_path checkpoints/$MODEL_REPO/model_int8.pth 41 | ``` 42 | 43 | ## Tensor Parallelism 44 | ```bash 45 | ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=8 generate.py --compile --compile_prefill --checkpoint_path checkpoints/$MODEL_REPO/model.pth 46 | ``` 47 | -------------------------------------------------------------------------------- /mixtral-moe/generate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import itertools 7 | import sys 8 | import time 9 | from pathlib import Path 10 | from typing import Optional, Tuple 11 | 12 | import torch 13 | import torch._dynamo.config 14 | import torch._inductor.config 15 | 16 | def device_sync(device): 17 | if "cuda" in device: 18 | torch.cuda.synchronize(device) 19 | elif "cpu" in device: 20 | pass 21 | else: 22 | print(f"device={device} is not yet suppported") 23 | 24 | 25 | torch._inductor.config.coordinate_descent_tuning = True 26 | torch._inductor.config.triton.unique_kernel_names = True 27 | torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future 28 | 29 | 30 | # support running without installing as a package 31 | wd = Path(__file__).parent.parent.resolve() 32 | sys.path.append(str(wd)) 33 | 34 | from sentencepiece import SentencePieceProcessor 35 | 36 | from model import Transformer 37 | from tp import maybe_init_dist 38 | 39 | 40 | def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization 41 | q = torch.empty_like(probs_sort).exponential_(1) 42 | return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) 43 | 44 | def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None): 45 | logits = logits / max(temperature, 1e-5) 46 | 47 | if top_k is not None: 48 | v, _ = torch.topk(logits, min(top_k, logits.size(-1))) 49 | pivot = v.select(-1, -1).unsqueeze(-1) 50 | logits = torch.where(logits < pivot, -float("Inf"), logits) 51 | probs = torch.nn.functional.softmax(logits, dim=-1) 52 | return probs 53 | 54 | def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None): 55 | probs = logits_to_probs(logits[0, -1], temperature, top_k) 56 | idx_next = 
multinomial_sample_one_no_sync(probs) 57 | return idx_next, probs 58 | 59 | def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor: 60 | # input_pos: [B, S] 61 | logits = model(x, input_pos) 62 | return sample(logits, **sampling_kwargs)[0] 63 | 64 | def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]: 65 | # input_pos: [B, 1] 66 | assert input_pos.shape[-1] == 1 67 | logits = model(x, input_pos) 68 | return sample(logits, **sampling_kwargs) 69 | 70 | def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs): 71 | new_tokens, new_probs = [], [] 72 | for i in range(num_new_tokens): 73 | with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here 74 | next_token, next_prob = decode_one_token( 75 | model, cur_token, input_pos, **sampling_kwargs 76 | ) 77 | input_pos += 1 78 | new_tokens.append(next_token.clone()) 79 | callback(new_tokens[-1]) 80 | new_probs.append(next_prob.clone()) 81 | cur_token = next_token.view(1, -1) 82 | 83 | return new_tokens, new_probs 84 | 85 | 86 | def model_forward(model, x, input_pos): 87 | return model(x, input_pos) 88 | 89 | @torch.no_grad() 90 | def generate( 91 | model: Transformer, 92 | prompt: torch.Tensor, 93 | max_new_tokens: int, 94 | *, 95 | interactive: bool, 96 | callback = lambda x: x, 97 | **sampling_kwargs 98 | ) -> torch.Tensor: 99 | """ 100 | Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. 101 | """ 102 | 103 | # create an empty tensor of the expected final shape and fill in the current tokens 104 | T = prompt.size(0) 105 | T_new = T + max_new_tokens 106 | if interactive: 107 | max_seq_length = 350 108 | else: 109 | max_seq_length = min(T_new, model.config.block_size) 110 | 111 | device, dtype = prompt.device, prompt.dtype 112 | with torch.device(device): 113 | model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) 114 | 115 | # create an empty tensor of the expected final shape and fill in the current tokens 116 | empty = torch.empty(T_new, dtype=dtype, device=device) 117 | empty[:T] = prompt 118 | seq = empty 119 | input_pos = torch.arange(0, T, device=device) 120 | 121 | next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs) 122 | seq[T] = next_token 123 | 124 | input_pos = torch.tensor([T], device=device, dtype=torch.int) 125 | 126 | generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs) 127 | seq[T + 1:] = torch.cat(generated_tokens) 128 | 129 | return seq 130 | 131 | def encode_tokens(tokenizer, string, bos=True, device='cuda'): 132 | tokens = tokenizer.encode(string) 133 | if bos: 134 | tokens = [tokenizer.bos_id()] + tokens 135 | return torch.tensor(tokens, dtype=torch.int, device=device) 136 | 137 | def _load_model(checkpoint_path, device, precision, use_tp): 138 | with torch.device('meta'): 139 | model = Transformer.from_name(checkpoint_path.parent.name) 140 | 141 | if "int8" in str(checkpoint_path): 142 | print("Using int8 weight-only quantization!") 143 | from quantize import WeightOnlyBit8QuantHandler 144 | simple_quantizer = WeightOnlyBit8QuantHandler(model, torch.int8) 145 | model = simple_quantizer.convert_for_runtime() 146 | 147 | 
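# mmap=True keeps the checkpoint memory-mapped instead of copying it into host RAM up front,
# and assign=True lets load_state_dict materialize the meta-initialized parameters directly
# from the checkpoint tensors.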
checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
148 | model.load_state_dict(checkpoint, assign=True)
149 |
150 | if use_tp:
151 | from tp import apply_tp
152 | print("Applying tensor parallel to model ...")
153 | apply_tp(model)
154 |
155 | model = model.to(device=device, dtype=precision)
156 | return model.eval()
157 |
158 | B_INST, E_INST = "[INST]", "[/INST]"
159 |
160 | def main(
161 | prompt: str = "Hello, my name is",
162 | interactive: bool = False,
163 | num_samples: int = 5,
164 | max_new_tokens: int = 100,
165 | top_k: int = 200,
166 | temperature: float = 0.8,
167 | checkpoint_path: Path = Path("checkpoints/mistralai/Mixtral-8x7B-v0.1/model.pth"),
168 | compile: bool = True,
169 | compile_prefill: bool = False,
170 | profile: Optional[Path] = None,
171 | device='cuda',
172 | ) -> None:
173 | """Generates text samples based on a pre-trained Transformer model and tokenizer.
174 | """
175 | assert checkpoint_path.is_file(), checkpoint_path
176 |
177 | tokenizer_path = checkpoint_path.parent / "tokenizer.model"
178 | assert tokenizer_path.is_file(), str(tokenizer_path)
179 |
180 | global print
181 | rank = maybe_init_dist()
182 | use_tp = rank is not None
183 | if use_tp:
184 | if rank != 0:
185 | # only print on rank 0
186 | print = lambda *args, **kwargs: None
187 |
188 | print(f"Using device={device}")
189 | precision = torch.bfloat16
190 | is_chat = "chat" in str(checkpoint_path)
191 |
192 | print("Loading model ...")
193 | t0 = time.time()
194 | model = _load_model(checkpoint_path, device, precision, use_tp)
195 |
196 | device_sync(device=device) # MKG
197 | print(f"Time to load model: {time.time() - t0:.02f} seconds")
198 |
199 | tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
200 | encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
201 | prompt_length = encoded.size(0)
202 |
203 | torch.manual_seed(1234)
204 | model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())])
205 | if compile:
206 | torch._inductor.config.assert_indirect_indexing = False
207 |
208 | global decode_one_token, prefill
209 | decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True)
210 |
211 | # Uncomment to squeeze more perf out of prefill
212 | if compile_prefill:
213 | prefill = torch.compile(prefill, fullgraph=True, dynamic=True)
214 |
215 |
216 | aggregate_metrics = {
217 | 'tokens_per_sec': [],
218 | }
219 | start = -1 if compile else 0
220 |
221 | for i in range(start, num_samples):
222 | device_sync(device=device) # MKG
223 | if i >= 0 and interactive:
224 | prompt = input("What is your prompt? 
") 225 | if is_chat: 226 | prompt = f"{B_INST} {prompt.strip()} {E_INST}" 227 | encoded = encode_tokens(tokenizer, prompt, bos=True, device=device) 228 | 229 | if interactive and i >= 0: 230 | buffer = [] 231 | period_id = tokenizer.encode('.')[0] 232 | done_generating = False 233 | def callback(x): 234 | nonlocal done_generating 235 | if done_generating: 236 | return 237 | buffer.append(tokenizer.decode([period_id] + x.tolist())[1:]) 238 | if x.item() == tokenizer.eos_id(): 239 | done_generating = True 240 | if len(buffer) == 4 or done_generating: 241 | print(''.join(buffer), end='', flush=True) 242 | buffer.clear() 243 | # print(, end='', flush=True) 244 | else: 245 | callback = lambda x : x 246 | t0 = time.perf_counter() 247 | import contextlib 248 | if (i != num_samples - 1 or not profile) or (use_tp and rank != 0): 249 | prof = contextlib.nullcontext() 250 | else: 251 | torch.profiler._utils._init_for_cuda_graphs() 252 | prof = torch.profiler.profile() 253 | with prof: 254 | y = generate( 255 | model, 256 | encoded, 257 | max_new_tokens, 258 | interactive=interactive, 259 | callback=callback, 260 | temperature=temperature, 261 | top_k=top_k, 262 | ) 263 | if i == -1: 264 | print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") 265 | continue 266 | if hasattr(prof, "export_chrome_trace"): 267 | if use_tp: 268 | prof.export_chrome_trace(f"{profile}_rank_{rank}.json") 269 | else: 270 | prof.export_chrome_trace(f"{profile}.json") 271 | device_sync(device=device) # MKG 272 | t = time.perf_counter() - t0 273 | 274 | if not interactive: 275 | print(tokenizer.decode(y.tolist())) 276 | else: 277 | print() 278 | tokens_generated = y.size(0) - prompt_length 279 | tokens_sec = tokens_generated / t 280 | aggregate_metrics['tokens_per_sec'].append(tokens_sec) 281 | print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec") 282 | print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s") 283 | 284 | print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}") 285 | print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") 286 | 287 | 288 | if __name__ == '__main__': 289 | import argparse 290 | parser = argparse.ArgumentParser(description='Your CLI description.') 291 | 292 | parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.') 293 | parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode') 294 | parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.') 295 | parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.') 296 | parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.') 297 | parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.') 298 | parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), help='Model checkpoint path.') 299 | parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') 300 | parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)') 301 | parser.add_argument('--profile', type=Path, default=None, help='Profile path.') 302 | parser.add_argument('--device', type=str, default="cuda", help='device to use') 303 | 304 | args = parser.parse_args() 305 | main( 
306 | args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k, 307 | args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.device 308 | ) 309 | -------------------------------------------------------------------------------- /mixtral-moe/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | from dataclasses import dataclass 7 | from typing import Optional 8 | 9 | import torch 10 | import torch.nn as nn 11 | from torch import Tensor 12 | from torch.nn import functional as F 13 | 14 | 15 | def find_multiple(n: int, k: int) -> int: 16 | if n % k == 0: 17 | return n 18 | return n + k - (n % k) 19 | 20 | @dataclass 21 | class ModelArgs: 22 | block_size: int = 2048 23 | vocab_size: int = 32000 24 | n_layer: int = 32 25 | n_head: int = 32 26 | dim: int = 4096 27 | intermediate_size: int = None 28 | n_local_heads: int = -1 29 | head_dim: int = 64 30 | rope_base: float = 10000 31 | norm_eps: float = 1e-5 32 | num_experts: int = 8 33 | num_activated_experts: int = 2 34 | 35 | def __post_init__(self): 36 | if self.n_local_heads == -1: 37 | self.n_local_heads = self.n_head 38 | if self.intermediate_size is None: 39 | hidden_dim = 4 * self.dim 40 | n_hidden = int(2 * hidden_dim / 3) 41 | self.intermediate_size = find_multiple(n_hidden, 256) 42 | self.head_dim = self.dim // self.n_head 43 | 44 | @classmethod 45 | def from_name(cls, name: str): 46 | if name in transformer_configs: 47 | return cls(**transformer_configs[name]) 48 | # fuzzy search 49 | config = [config for config in transformer_configs if config in str(name).upper() or config in str(name)] 50 | assert len(config) == 1, name 51 | return cls(**transformer_configs[config[0]]) 52 | 53 | 54 | transformer_configs = { 55 | "Mixtral-8x7B-v0.1": dict(block_size=32768, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, rope_base=1000000.0, num_experts=8, num_activated_experts=2), 56 | } 57 | 58 | class KVCache(nn.Module): 59 | def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16): 60 | super().__init__() 61 | cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) 62 | self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype)) 63 | self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype)) 64 | 65 | def update(self, input_pos, k_val, v_val): 66 | # input_pos: [S], k_val: [B, H, S, D] 67 | assert input_pos.shape[0] == k_val.shape[2] 68 | 69 | k_out = self.k_cache 70 | v_out = self.v_cache 71 | k_out[:, :, input_pos] = k_val 72 | v_out[:, :, input_pos] = v_val 73 | 74 | return k_out, v_out 75 | 76 | class Transformer(nn.Module): 77 | def __init__(self, config: ModelArgs) -> None: 78 | super().__init__() 79 | self.config = config 80 | 81 | self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim) 82 | self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer)) 83 | self.norm = RMSNorm(config.dim, eps=config.norm_eps) 84 | self.output = nn.Linear(config.dim, config.vocab_size, bias=False) 85 | 86 | self.freqs_cis: Optional[Tensor] = None 87 | self.mask_cache: Optional[Tensor] = None 88 | self.max_batch_size = -1 89 | self.max_seq_length = -1 90 | 91 | def setup_caches(self, max_batch_size, max_seq_length): 92 | 
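# (Re)build the per-layer KV caches plus the RoPE table and causal mask; this is a no-op
# when the existing caches already cover the requested batch size and sequence length.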
if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size: 93 | return 94 | head_dim = self.config.dim // self.config.n_head 95 | max_seq_length = find_multiple(max_seq_length, 8) 96 | self.max_seq_length = max_seq_length 97 | self.max_batch_size = max_batch_size 98 | for b in self.layers: 99 | b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim) 100 | 101 | self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base) 102 | self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)) 103 | 104 | def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: 105 | assert self.freqs_cis is not None, "Caches must be initialized first" 106 | mask = self.causal_mask[None, None, input_pos] 107 | freqs_cis = self.freqs_cis[input_pos] 108 | x = self.tok_embeddings(idx) 109 | 110 | for i, layer in enumerate(self.layers): 111 | x = layer(x, input_pos, freqs_cis, mask) 112 | x = self.norm(x) 113 | logits = self.output(x) 114 | return logits 115 | 116 | @classmethod 117 | def from_name(cls, name: str): 118 | return cls(ModelArgs.from_name(name)) 119 | 120 | 121 | class TransformerBlock(nn.Module): 122 | def __init__(self, config: ModelArgs) -> None: 123 | super().__init__() 124 | self.attention = Attention(config) 125 | self.block_sparse_moe = MOEFeedForward(config) 126 | self.ffn_norm = RMSNorm(config.dim, config.norm_eps) 127 | self.attention_norm = RMSNorm(config.dim, config.norm_eps) 128 | 129 | def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor: 130 | h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos) 131 | out = h + self.block_sparse_moe(self.ffn_norm(h)) 132 | return out 133 | 134 | 135 | class Attention(nn.Module): 136 | def __init__(self, config: ModelArgs): 137 | super().__init__() 138 | assert config.dim % config.n_head == 0 139 | 140 | total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim 141 | # key, query, value projections for all heads, but in a batch 142 | self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False) 143 | self.wo = nn.Linear(config.dim, config.dim, bias=False) 144 | self.kv_cache = None 145 | 146 | self.n_head = config.n_head 147 | self.head_dim = config.head_dim 148 | self.n_local_heads = config.n_local_heads 149 | self.dim = config.dim 150 | self._register_load_state_dict_pre_hook(self.load_hook) 151 | 152 | def load_hook(self, state_dict, prefix, *args): 153 | if prefix + "wq.weight" in state_dict: 154 | wq = state_dict.pop(prefix + "wq.weight") 155 | wk = state_dict.pop(prefix + "wk.weight") 156 | wv = state_dict.pop(prefix + "wv.weight") 157 | state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) 158 | 159 | def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: 160 | bsz, seqlen, _ = x.shape 161 | 162 | kv_size = self.n_local_heads * self.head_dim 163 | q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) 164 | 165 | q = q.view(bsz, seqlen, self.n_head, self.head_dim) 166 | k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim) 167 | v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim) 168 | 169 | q = apply_rotary_emb(q, freqs_cis) 170 | k = apply_rotary_emb(k, freqs_cis) 171 | 172 | q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) 173 | 174 | if self.kv_cache is not None: 175 | k, v = 
self.kv_cache.update(input_pos, k, v) 176 | 177 | k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) 178 | v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) 179 | y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) 180 | 181 | y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) 182 | 183 | y = self.wo(y) 184 | return y 185 | 186 | 187 | class ConditionalFeedForward(nn.Module): 188 | def __init__(self, config): 189 | super().__init__() 190 | self.w1 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim)) 191 | self.w2 = nn.Parameter(torch.empty(config.num_experts, config.dim, config.intermediate_size)) 192 | self.w3 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim)) 193 | 194 | def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor: 195 | w1_weights = self.w1[expert_indices] # [T, A, D, D] 196 | w3_weights = self.w3[expert_indices] # [T, A, D, D] 197 | w2_weights = self.w2[expert_indices] # [T, A, D, D] 198 | x1 = F.silu(torch.einsum('ti,taoi -> tao', x, w1_weights)) 199 | x3 = torch.einsum('ti, taoi -> tao', x, w3_weights) 200 | expert_outs = torch.einsum('tao, taio -> tai', (x1 * x3), w2_weights) 201 | return expert_outs 202 | 203 | 204 | class MOEFeedForward(nn.Module): 205 | def __init__(self, config) -> None: 206 | super().__init__() 207 | self.gate = nn.Linear(config.dim, config.num_experts, bias=False) 208 | self.cond_ffn = ConditionalFeedForward(config) 209 | self.dim = config.dim 210 | self.num_activated_experts = config.num_activated_experts 211 | def forward(self, x: Tensor) -> Tensor: 212 | x = x.view(-1, self.dim) 213 | # T = num_tokens, E = num_experts, D = hidden dim, A = activated experts 214 | # x: [T, D] 215 | scores = self.gate(x) # [T, E] 216 | expert_weights = F.softmax(scores, dim=-1) 217 | expert_weights, expert_indices = torch.topk(expert_weights, self.num_activated_experts, dim=-1) # [T, A], [T, A] 218 | expert_weights /= expert_weights.sum(dim=-1, keepdim=True) # [T, A] 219 | expert_outs = self.cond_ffn(x, expert_indices) 220 | return torch.einsum('tai,ta -> ti', expert_outs, expert_weights) 221 | 222 | 223 | class RMSNorm(nn.Module): 224 | def __init__(self, dim: int, eps: float = 1e-5): 225 | super().__init__() 226 | self.eps = eps 227 | self.weight = nn.Parameter(torch.ones(dim)) 228 | 229 | def _norm(self, x): 230 | return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) 231 | 232 | def forward(self, x: Tensor) -> Tensor: 233 | output = self._norm(x.float()).type_as(x) 234 | return output * self.weight 235 | 236 | 237 | def precompute_freqs_cis( 238 | seq_len: int, n_elem: int, base: int = 10000 239 | ) -> Tensor: 240 | freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)) 241 | t = torch.arange(seq_len, device=freqs.device) 242 | freqs = torch.outer(t, freqs) 243 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) 244 | cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) 245 | return cache.to(dtype=torch.bfloat16) 246 | 247 | 248 | def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: 249 | xshaped = x.float().reshape(*x.shape[:-1], -1, 2) 250 | freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) 251 | x_out2 = torch.stack( 252 | [ 253 | xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1], 254 | xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1], 255 | ], 256 | -1, 257 | ) 258 | 259 | 
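# Collapse the trailing (head_dim // 2, 2) pair dimensions back into head_dim.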
x_out2 = x_out2.flatten(3) 260 | return x_out2.type_as(x) 261 | -------------------------------------------------------------------------------- /mixtral-moe/quantize.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import time 7 | from pathlib import Path 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from model import Transformer, ConditionalFeedForward 14 | 15 | ##### Quantization Primitives ###### 16 | 17 | def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype): 18 | # assumes symmetric quantization 19 | # assumes axis == 0 20 | # assumes dense memory format 21 | # TODO(future): relax ^ as needed 22 | 23 | # default setup for affine quantization of activations 24 | eps = torch.finfo(torch.float32).eps 25 | 26 | # get min and max 27 | min_val, max_val = torch.aminmax(x, dim=1) 28 | 29 | # calculate scales and zero_points based on min and max 30 | # reference: https://fburl.com/code/srbiybme 31 | min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) 32 | max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) 33 | device = min_val_neg.device 34 | 35 | # reference: https://fburl.com/code/4wll53rk 36 | max_val_pos = torch.max(-min_val_neg, max_val_pos) 37 | scales = max_val_pos / (float(quant_max - quant_min) / 2) 38 | # ensure scales is the same dtype as the original tensor 39 | scales = torch.clamp(scales, min=eps).to(x.dtype) 40 | zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device) 41 | 42 | # quantize based on qmin/qmax/scales/zp 43 | # reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63 44 | x_div = x / scales.unsqueeze(-1) 45 | x_round = torch.round(x_div) 46 | x_zp = x_round + zero_points.unsqueeze(-1) 47 | quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype) 48 | 49 | return quant, scales, zero_points 50 | 51 | 52 | ##### Weight-only int8 per-channel quantized code ###### 53 | 54 | def replace_linear_weight_only_bit8_per_channel(module, target_dtype): 55 | for name, child in module.named_children(): 56 | if isinstance(child, nn.Linear) and name != "gate": 57 | setattr(module, name, WeightOnlyBit8Linear(child.in_features, child.out_features, target_dtype=target_dtype)) 58 | elif isinstance(child, ConditionalFeedForward): 59 | num_experts, intermediate_size, dim = child.w1.shape 60 | setattr(module, name, ConditionalFeedForwardBit8(num_experts, intermediate_size, dim, target_dtype=target_dtype)) 61 | else: 62 | replace_linear_weight_only_bit8_per_channel(child, target_dtype) 63 | 64 | class WeightOnlyBit8QuantHandler: 65 | def __init__(self, mod, target_dtype): 66 | self.mod = mod 67 | self.target_dtype = target_dtype 68 | 69 | @torch.no_grad() 70 | def create_quantized_state_dict(self): 71 | cur_state_dict = self.mod.state_dict() 72 | for fqn, mod in self.mod.named_modules(): 73 | if isinstance(mod, torch.nn.Linear) and not fqn.endswith(".gate"): 74 | int8_weight, scales, _ = dynamically_quantize_per_channel(mod.weight.float(), -128, 127, self.target_dtype) 75 | cur_state_dict[f"{fqn}.weight"] = int8_weight 76 | cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype) 77 | elif isinstance(mod, ConditionalFeedForward): 78 | for 
weight_idx in range(0, 3): 79 | weight_name = f"w{weight_idx + 1}" 80 | scales_name = f"scales{weight_idx + 1}" 81 | weight = getattr(mod, weight_name) 82 | num_experts, intermediate_size, dim = weight.shape 83 | 84 | bit8_weight_list = [] 85 | scales_list = [] 86 | for expert_idx in range(num_experts): 87 | bit8_weight, scales, _ = dynamically_quantize_per_channel(weight[expert_idx].float(), -128, 127, self.target_dtype) 88 | bit8_weight_list.append(bit8_weight.reshape(1, intermediate_size, dim)) 89 | scales_list.append(scales.reshape(1, intermediate_size)) 90 | 91 | cur_state_dict[f"{fqn}.{weight_name}"] = torch.cat(bit8_weight_list, dim=0) 92 | cur_state_dict[f"{fqn}.{scales_name}"] = torch.cat(scales_list, dim=0) 93 | 94 | return cur_state_dict 95 | 96 | def convert_for_runtime(self): 97 | replace_linear_weight_only_bit8_per_channel(self.mod, self.target_dtype) 98 | return self.mod 99 | 100 | 101 | class WeightOnlyBit8Linear(torch.nn.Module): 102 | __constants__ = ['in_features', 'out_features'] 103 | in_features: int 104 | out_features: int 105 | weight: torch.Tensor 106 | 107 | def __init__(self, in_features: int, out_features: int, bias: bool = True, 108 | device=None, dtype=None, target_dtype=None) -> None: 109 | assert target_dtype is not None 110 | factory_kwargs = {'device': device, 'dtype': dtype} 111 | super().__init__() 112 | self.in_features = in_features 113 | self.out_features = out_features 114 | self.register_buffer("weight", torch.empty((out_features, in_features), dtype=target_dtype)) 115 | self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16)) 116 | 117 | def forward(self, input: torch.Tensor) -> torch.Tensor: 118 | return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales 119 | 120 | 121 | class ConditionalFeedForwardBit8(nn.Module): 122 | def __init__(self, num_experts, intermediate_size, dim, target_dtype): 123 | super().__init__() 124 | 125 | self.target_dtype = target_dtype 126 | 127 | self.register_buffer("w1", torch.empty(num_experts, intermediate_size, dim, dtype=target_dtype)) 128 | self.register_buffer("w2", torch.empty(num_experts, dim, intermediate_size, dtype=target_dtype)) 129 | self.register_buffer("w3", torch.empty(num_experts, intermediate_size, dim, dtype=target_dtype)) 130 | 131 | self.register_buffer("scales1", torch.empty(num_experts, intermediate_size, dtype=torch.bfloat16)) 132 | self.register_buffer("scales2", torch.empty(num_experts, dim, dtype=torch.bfloat16)) 133 | self.register_buffer("scales3", torch.empty(num_experts, intermediate_size, dtype=torch.bfloat16)) 134 | 135 | def forward(self, x, expert_indices): 136 | w1_weights = self.w1.to(x.dtype)[expert_indices] # [T, A, D, D] 137 | w3_weights = self.w3.to(x.dtype)[expert_indices] # [T, A, D, D] 138 | w2_weights = self.w2.to(x.dtype)[expert_indices] 139 | x1 = F.silu(torch.einsum('ti,taoi -> tao', x, w1_weights) * self.scales1[expert_indices].to(x.dtype)) 140 | x3 = torch.einsum('ti, taoi -> tao', x, w3_weights) * self.scales3[expert_indices].to(x.dtype) 141 | expert_outs = torch.einsum('tao, taio -> tai', (x1 * x3), w2_weights) * self.scales2[expert_indices].to(x.dtype) # [T, A, D, D] 142 | return expert_outs 143 | 144 | 145 | def quantize( 146 | checkpoint_path: Path = Path("checkpoints/mistralai/Mixtral-8x7B-v0.1/model.pth"), 147 | mode: str = 'int8', 148 | label: str = '', 149 | ) -> None: 150 | assert checkpoint_path.is_file(), checkpoint_path 151 | 152 | device = 'cpu' 153 | precision = torch.bfloat16 154 | 155 | print("Loading model ...") 156 | 
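# The model is instantiated on the meta device, the bf16 checkpoint is loaded, and the
# weights are then re-emitted as an int8 per-channel weight-only state dict on disk.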
t0 = time.time() 157 | 158 | with torch.device('meta'): 159 | model = Transformer.from_name(checkpoint_path.parent.name) 160 | 161 | checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) 162 | model.load_state_dict(checkpoint, assign=True) 163 | model = model.to(dtype=precision, device=device) 164 | 165 | if mode == 'int8': 166 | print("Quantizing model weights for int8 weight-only symmetric per-channel quantization") 167 | quant_handler = WeightOnlyBit8QuantHandler(model, target_dtype=torch.int8) 168 | quantized_state_dict = quant_handler.create_quantized_state_dict() 169 | 170 | dir_name = checkpoint_path.parent 171 | base_name = checkpoint_path.name 172 | new_base_name = base_name.replace('.pth', f'{label}int8.pth') 173 | 174 | else: 175 | raise ValueError(f"Invalid quantization mode {mode} needs to be one of [int8,]") 176 | 177 | quantize_path = dir_name / new_base_name 178 | print(f"Writing quantized weights to {quantize_path}") 179 | quantize_path.unlink(missing_ok=True) # remove existing file if one already there 180 | torch.save(quantized_state_dict, quantize_path) 181 | print(f"Quantization complete took {time.time() - t0:.02f} seconds") 182 | return 183 | 184 | if __name__ == '__main__': 185 | import argparse 186 | parser = argparse.ArgumentParser(description='Quantize a model.') 187 | parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Path to the model checkpoint to be quantized.') 188 | parser.add_argument('--mode', '-q', type=str, default='int8', choices=['int8', 'int4', 'int4-gptq'], help='type of quantization to perform') 189 | parser.add_argument('--label', type=str, default='_', help='label to add to output filename') 190 | 191 | args = parser.parse_args() 192 | quantize(args.checkpoint_path, args.mode, args.label) 193 | -------------------------------------------------------------------------------- /mixtral-moe/scripts/convert_hf_checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
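# Merges the sharded HuggingFace *.pt files into a single model.pth in this repo's layout:
# wq/wk/wv are fused into wqkv and the expert weights are reshaped (and, for w2, transposed)
# for ConditionalFeedForward.
# Typical usage, matching the README:
#   python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/mistralai/Mixtral-8x7B-v0.1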
6 | import glob 7 | import json 8 | import re 9 | import sys 10 | from pathlib import Path 11 | from typing import Optional 12 | 13 | import torch 14 | 15 | # support running without installing as a package 16 | wd = Path(__file__).parent.parent.resolve() 17 | sys.path.append(str(wd)) 18 | 19 | from model import ModelArgs 20 | 21 | 22 | @torch.inference_mode() 23 | def convert_hf_checkpoint( 24 | *, 25 | checkpoint_dir: Path = Path("checkpoints/mistralai/Mixtral-8x7B-v0.1"), 26 | model_name: Optional[str] = None, 27 | ) -> None: 28 | if model_name is None: 29 | model_name = checkpoint_dir.name 30 | 31 | config = ModelArgs.from_name(model_name) 32 | print(f"Model config {config.__dict__}") 33 | 34 | weight_map = { 35 | "tok_embeddings.weight": "tok_embeddings.weight", 36 | "layers.{}.attention.wq.weight": "layers.{}.attention.wq.weight", 37 | "layers.{}.attention.wk.weight": "layers.{}.attention.wk.weight", 38 | "layers.{}.attention.wv.weight": "layers.{}.attention.wv.weight", 39 | "layers.{}.attention.wo.weight": "layers.{}.attention.wo.weight", 40 | "layers.{}.block_sparse_moe.w1": "layers.{}.block_sparse_moe.cond_ffn.w1", 41 | "layers.{}.block_sparse_moe.w2": "layers.{}.block_sparse_moe.cond_ffn.w2", 42 | "layers.{}.block_sparse_moe.w3": "layers.{}.block_sparse_moe.cond_ffn.w3", 43 | "layers.{}.block_sparse_moe.gate.weight": "layers.{}.block_sparse_moe.gate.weight", 44 | "layers.{}.attention_norm.weight": "layers.{}.attention_norm.weight", 45 | "layers.{}.ffn_norm.weight": "layers.{}.ffn_norm.weight", 46 | "norm.weight": "norm.weight", 47 | "output.weight": "output.weight", 48 | } 49 | 50 | pt_files = glob.glob(str(checkpoint_dir / "*.pt")) 51 | 52 | merged_result = {} 53 | for file in sorted(pt_files): 54 | state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True) 55 | merged_result.update(state_dict) 56 | final_result = {} 57 | for key, value in merged_result.items(): 58 | if "layers" in key: 59 | abstract_key = re.sub(r'.(\d+).', '.{}.', key) 60 | layer_num = re.search(r'\d+', key).group(0) 61 | new_key = weight_map[abstract_key] 62 | if new_key is None: 63 | continue 64 | new_key = new_key.format(layer_num) 65 | else: 66 | new_key = weight_map[key] 67 | 68 | final_result[new_key] = value 69 | 70 | for key in tuple(final_result.keys()): 71 | if "wq" in key: 72 | q = final_result[key] 73 | k = final_result[key.replace("wq", "wk")] 74 | v = final_result[key.replace("wq", "wv")] 75 | final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v]) 76 | del final_result[key] 77 | del final_result[key.replace("wq", "wk")] 78 | del final_result[key.replace("wq", "wv")] 79 | elif "w1" in key or "w3" in key: 80 | final_result[key] = final_result[key].reshape(config.num_experts, config.intermediate_size, config.dim).contiguous() 81 | elif "w2" in key: 82 | final_result[key] = final_result[key].reshape(config.num_experts, config.intermediate_size, config.dim).permute(0, 2, 1).contiguous() 83 | elif "gate" in key: 84 | final_result[key] = final_result[key].contiguous() 85 | 86 | print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}") 87 | torch.save(final_result, checkpoint_dir / "model.pth") 88 | 89 | 90 | if __name__ == '__main__': 91 | import argparse 92 | parser = argparse.ArgumentParser(description='Convert HuggingFace checkpoint.') 93 | parser.add_argument('--checkpoint_dir', type=Path, default=Path("checkpoints/meta-llama/llama-2-7b-chat-hf")) 94 | parser.add_argument('--model_name', type=str, default=None) 95 | 96 | args = parser.parse_args() 97 | 
convert_hf_checkpoint( 98 | checkpoint_dir=args.checkpoint_dir, 99 | model_name=args.model_name, 100 | ) 101 | -------------------------------------------------------------------------------- /mixtral-moe/scripts/download.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import os 7 | from typing import Optional 8 | 9 | from requests.exceptions import HTTPError 10 | 11 | 12 | def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -> None: 13 | from huggingface_hub import snapshot_download 14 | os.makedirs(f"checkpoints/{repo_id}", exist_ok=True) 15 | try: 16 | snapshot_download(repo_id, local_dir=f"checkpoints/{repo_id}", local_dir_use_symlinks=False, token=hf_token, ignore_patterns="*.safetensors") 17 | except HTTPError as e: 18 | if e.response.status_code == 401: 19 | print("You need to pass a valid `--hf_token=...` to download private checkpoints.") 20 | else: 21 | raise e 22 | 23 | if __name__ == '__main__': 24 | import argparse 25 | parser = argparse.ArgumentParser(description='Download data from HuggingFace Hub.') 26 | parser.add_argument('--repo_id', type=str, default="checkpoints/meta-llama/llama-2-7b-chat-hf", help='Repository ID to download from.') 27 | parser.add_argument('--hf_token', type=str, default=None, help='HuggingFace API token.') 28 | 29 | args = parser.parse_args() 30 | hf_download(args.repo_id, args.hf_token) 31 | -------------------------------------------------------------------------------- /mixtral-moe/tp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
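# Tensor-parallel helpers: when launched via torchrun with more than one GPU, the attention
# and expert weights are sharded across ranks and the partial outputs are all-reduced with
# forward hooks (see apply_tp).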
6 | import os 7 | from typing import List, Optional 8 | 9 | import torch 10 | import torch.distributed as dist 11 | from torch import nn 12 | from torch.distributed import _functional_collectives as funcol 13 | 14 | from model import Attention, MOEFeedForward, Transformer 15 | 16 | 17 | def _get_rank() -> int: 18 | return int(os.environ.get("LOCAL_RANK", "0")) 19 | 20 | def is_local(): 21 | return _get_rank() == 0 22 | 23 | def local_break(): 24 | if is_local(): 25 | breakpoint() 26 | dist.barrier() 27 | 28 | def _get_world_size() -> int: 29 | return int(os.environ.get("LOCAL_WORLD_SIZE", "1")) 30 | 31 | def maybe_init_dist() -> Optional[int]: 32 | try: 33 | # provided by torchrun 34 | rank = _get_rank() 35 | world_size = _get_world_size() 36 | 37 | if world_size < 2: 38 | # too few gpus to parallelize, tp is no-op 39 | return None 40 | except KeyError: 41 | # not run via torchrun, no-op 42 | return None 43 | 44 | torch.cuda.set_device(rank) 45 | dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) 46 | return rank 47 | 48 | rank = _get_rank() 49 | world_size = _get_world_size() 50 | 51 | def shard(x, dim): 52 | assert x.size(dim=dim) % world_size == 0 53 | return torch.tensor_split(x, world_size, dim=dim)[rank] 54 | 55 | def _apply_tp_linear(linear: nn.Linear, style: str, weight_splits: List[int] = []) -> None: 56 | rank = _get_rank() 57 | world_size = _get_world_size() 58 | 59 | # Linear's weight matrix is transposed, and is of shape 60 | # (linear.out_features, linear.in_features) 61 | dim_lookup = { 62 | "colwise": (0, "out_features"), 63 | "rowwise": (1, "in_features") 64 | } 65 | assert style in dim_lookup 66 | shard_dim, size_attr = dim_lookup[style] 67 | 68 | # ensure we can shard evenly 69 | assert getattr(linear, size_attr) % world_size == 0 70 | 71 | def shard_qkv(qkv, dim, weight_splits): 72 | q, k, v = qkv.split(weight_splits, dim=dim) 73 | q = shard(q, dim) 74 | k = shard(k, dim) 75 | v = shard(v, dim) 76 | return torch.cat((q,k,v), dim=dim) 77 | 78 | # shard 79 | if weight_splits: 80 | # attention 81 | assert len(weight_splits) == 3 82 | 83 | sharded_weight = shard_qkv(linear.weight, shard_dim, weight_splits) 84 | if hasattr(linear, "scales") and style == "colwise": 85 | linear.scales = shard_qkv(linear.scales, 0, weight_splits) 86 | else: 87 | sharded_weight = shard(linear.weight, shard_dim) 88 | if hasattr(linear, "scales") and style == "colwise": 89 | linear.scales = shard(linear.scales, 0) 90 | 91 | # local_break() 92 | linear.weight = nn.Parameter(sharded_weight, requires_grad=False) 93 | setattr(linear, size_attr, getattr(linear, size_attr) // world_size) 94 | 95 | # shape info should still be synced 96 | # assert linear.weight.shape == (linear.out_features, linear.in_features) 97 | 98 | 99 | def _apply_tp_moe_ffn(mlp: MOEFeedForward) -> None: 100 | mlp.cond_ffn.w1 = nn.Parameter(shard(mlp.cond_ffn.w1, 1), requires_grad=False) 101 | mlp.cond_ffn.w3 = nn.Parameter(shard(mlp.cond_ffn.w3, 1), requires_grad=False) 102 | mlp.cond_ffn.w2 = nn.Parameter(shard(mlp.cond_ffn.w2, 2), requires_grad=False) 103 | 104 | if hasattr(mlp.cond_ffn, "scales1"): 105 | mlp.cond_ffn.scales1 = nn.Parameter(shard(mlp.cond_ffn.scales1, 1), requires_grad=False) 106 | mlp.cond_ffn.scales3 = nn.Parameter(shard(mlp.cond_ffn.scales3, 1), requires_grad=False) 107 | mlp.cond_ffn.scales2 = nn.Parameter(mlp.cond_ffn.scales2, requires_grad=False) 108 | 109 | world_size = _get_world_size() 110 | mlp.cond_ffn.register_forward_hook(lambda _module, _input, output: funcol.all_reduce( 111 
| output, "sum", list(range(world_size)))) 112 | 113 | 114 | def _apply_tp_attn(attn: Attention) -> None: 115 | assert hasattr(attn, "wqkv") 116 | assert hasattr(attn, "wo") 117 | 118 | kv_size = attn.n_local_heads * attn.head_dim 119 | _apply_tp_linear(attn.wqkv, "colwise", [attn.dim, kv_size, kv_size]) 120 | _apply_tp_linear(attn.wo, "rowwise") 121 | 122 | # overwrite 123 | world_size = _get_world_size() 124 | assert attn.n_head % world_size == 0, "assert attn.n_head % world_size == 0" 125 | attn.n_head = attn.n_head // world_size 126 | attn.dim = attn.dim // world_size 127 | attn.head_dim = attn.dim // attn.n_head 128 | attn.n_local_heads = attn.n_local_heads // world_size 129 | 130 | attn.register_forward_hook(lambda _module, _input, output: funcol.all_reduce( 131 | output[0], "sum", list(range(world_size)))) 132 | 133 | 134 | def _apply_tp_Transformer(Transformer: Transformer) -> None: 135 | # overwrite config before Transformer.setup_cache is called 136 | world_size = _get_world_size() 137 | Transformer.config.n_head = Transformer.config.n_head // world_size 138 | Transformer.config.dim = Transformer.config.dim // world_size 139 | Transformer.config.n_local_heads = Transformer.config.n_local_heads // world_size 140 | 141 | 142 | def apply_tp(model: Transformer) -> None: 143 | _apply_tp_Transformer(model) 144 | for block in model.layers: 145 | # Apply to MLP 146 | _apply_tp_moe_ffn(block.block_sparse_moe) 147 | _apply_tp_attn(block.attention) 148 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import math 7 | from dataclasses import dataclass 8 | from typing import Optional 9 | 10 | import torch 11 | import torch.nn as nn 12 | from torch import Tensor 13 | from torch.nn import functional as F 14 | from torch.nn.attention.flex_attention import ( 15 | _mask_mod_signature, 16 | BlockMask, 17 | flex_attention, 18 | ) 19 | 20 | 21 | def find_multiple(n: int, k: int) -> int: 22 | if n % k == 0: 23 | return n 24 | return n + k - (n % k) 25 | 26 | 27 | def get_mask_mod(mask_mod: _mask_mod_signature, offset: int): 28 | def _mask_mod(b, h, q, kv): 29 | return mask_mod(b, h, q + offset, kv) 30 | 31 | return _mask_mod 32 | 33 | 34 | @dataclass 35 | class ModelArgs: 36 | block_size: int = 2048 37 | vocab_size: int = 32000 38 | n_layer: int = 32 39 | n_head: int = 32 40 | dim: int = 4096 41 | intermediate_size: int = None 42 | n_local_heads: int = -1 43 | head_dim: int = 64 44 | rope_base: float = 10000 45 | norm_eps: float = 1e-5 46 | rope_scaling: Optional[dict] = None 47 | 48 | def __post_init__(self): 49 | if self.n_local_heads == -1: 50 | self.n_local_heads = self.n_head 51 | if self.intermediate_size is None: 52 | hidden_dim = 4 * self.dim 53 | n_hidden = int(2 * hidden_dim / 3) 54 | self.intermediate_size = find_multiple(n_hidden, 256) 55 | self.head_dim = self.dim // self.n_head 56 | 57 | @classmethod 58 | def from_name(cls, name: str): 59 | if name in transformer_configs: 60 | return cls(**transformer_configs[name]) 61 | # fuzzy search 62 | config = [config for config in transformer_configs if config.lower() in str(name).lower()] 63 | 64 | # We may have two or more configs matched (e.g. "7B" and "Mistral-7B"). 
Find the best config match, 65 | # take longer name (as it have more symbols matched) 66 | if len(config) > 1: 67 | config.sort(key=len, reverse=True) 68 | assert len(config[0]) != len(config[1]), name # make sure only one 'best' match 69 | 70 | return cls(**transformer_configs[config[0]]) 71 | 72 | 73 | transformer_configs = { 74 | "CodeLlama-7b-Python-hf": dict(block_size=16384, vocab_size=32000, n_layer=32, dim = 4096, rope_base=1000000), 75 | "7B": dict(n_layer=32, n_head=32, dim=4096), 76 | "13B": dict(n_layer=40, n_head=40, dim=5120), 77 | "30B": dict(n_layer=60, n_head=52, dim=6656), 78 | "34B": dict(n_layer=48, n_head=64, dim=8192, vocab_size=32000, n_local_heads=8, intermediate_size=22016, rope_base=1000000), # CodeLlama-34B-Python-hf 79 | "70B": dict(n_layer=80, n_head=64, dim=8192, n_local_heads=8, intermediate_size=28672), 80 | "Mistral-7B": dict(n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=32000), 81 | "stories15M": dict(n_layer=6, n_head=6, dim=288), 82 | "stories110M": dict(n_layer=12, n_head=12, dim=768), 83 | 84 | "llama-3-8b": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000), 85 | "llama-3-70b": dict(block_size=8192, n_layer=80, n_head=64, n_local_heads=8, dim=8192, intermediate_size=28672, vocab_size=128256, rope_base=500000), 86 | "llama-3.1-8b": dict(block_size=131072, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000, 87 | rope_scaling=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_position_embeddings=8192), 88 | ), 89 | "llama-3.1-70b": dict(block_size=131072, n_layer=80, n_head=64, n_local_heads=8, dim=8192, intermediate_size=28672, vocab_size=128256, rope_base=500000, 90 | rope_scaling=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_position_embeddings=8192), 91 | ), 92 | "llama-3.1-405b": dict(block_size=131072, n_layer=126, n_head=128, n_local_heads=8, dim=16384, intermediate_size=53248, vocab_size=128256, rope_base=500000, 93 | rope_scaling=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_position_embeddings=8192), 94 | ), 95 | } 96 | 97 | class KVCache(nn.Module): 98 | def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16): 99 | super().__init__() 100 | cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) 101 | self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype)) 102 | self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype)) 103 | 104 | def update(self, input_pos, k_val, v_val): 105 | # input_pos: [S], k_val: [B, H, S, D] 106 | assert input_pos.shape[0] == k_val.shape[2] 107 | 108 | k_out = self.k_cache 109 | v_out = self.v_cache 110 | k_out[:, :, input_pos] = k_val 111 | v_out[:, :, input_pos] = v_val 112 | 113 | return k_out, v_out 114 | 115 | class Transformer(nn.Module): 116 | def __init__(self, config: ModelArgs) -> None: 117 | super().__init__() 118 | self.config = config 119 | 120 | self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim) 121 | self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer)) 122 | self.norm = RMSNorm(config.dim, eps=config.norm_eps) 123 | self.output = nn.Linear(config.dim, config.vocab_size, bias=False) 124 | 125 | self.freqs_cis: Optional[Tensor] = None 126 | self.mask_cache: Optional[Tensor] = None 127 | self.max_batch_size = -1 128 | 
self.max_seq_length = -1 129 | self.get_mask_mod = get_mask_mod 130 | 131 | def setup_caches(self, max_batch_size, max_seq_length): 132 | if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size: 133 | return 134 | head_dim = self.config.dim // self.config.n_head 135 | max_seq_length = find_multiple(max_seq_length, 8) 136 | self.max_seq_length = max_seq_length 137 | self.max_batch_size = max_batch_size 138 | dtype = self.output.weight.dtype 139 | # For quantized layers, dtype is encoded in scales 140 | if hasattr(self.output, "scales"): 141 | dtype = self.output.scales.dtype 142 | elif hasattr(self.output, "scales_and_zeros"): 143 | dtype = self.output.scales_and_zeros.dtype 144 | for b in self.layers: 145 | b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype) 146 | 147 | self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype, self.config.rope_scaling) 148 | 149 | def forward(self, mask: BlockMask, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: 150 | assert self.freqs_cis is not None, "Caches must be initialized first" 151 | mask.mask_mod = self.get_mask_mod(mask.mask_mod, input_pos[0]) 152 | freqs_cis = self.freqs_cis[input_pos] 153 | x = self.tok_embeddings(idx) 154 | 155 | for i, layer in enumerate(self.layers): 156 | x = layer(x, input_pos, freqs_cis, mask) 157 | x = self.norm(x) 158 | logits = self.output(x) 159 | return logits 160 | 161 | @classmethod 162 | def from_name(cls, name: str): 163 | return cls(ModelArgs.from_name(name)) 164 | 165 | 166 | class TransformerBlock(nn.Module): 167 | def __init__(self, config: ModelArgs) -> None: 168 | super().__init__() 169 | self.attention = Attention(config) 170 | self.feed_forward = FeedForward(config) 171 | self.ffn_norm = RMSNorm(config.dim, config.norm_eps) 172 | self.attention_norm = RMSNorm(config.dim, config.norm_eps) 173 | 174 | def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: BlockMask) -> Tensor: 175 | h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos) 176 | out = h + self.feed_forward(self.ffn_norm(h)) 177 | return out 178 | 179 | 180 | class Attention(nn.Module): 181 | def __init__(self, config: ModelArgs): 182 | super().__init__() 183 | assert config.dim % config.n_head == 0 184 | 185 | total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim 186 | # key, query, value projections for all heads, but in a batch 187 | self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False) 188 | self.wo = nn.Linear(config.dim, config.dim, bias=False) 189 | self.kv_cache = None 190 | 191 | self.n_head = config.n_head 192 | self.head_dim = config.head_dim 193 | self.n_local_heads = config.n_local_heads 194 | self.dim = config.dim 195 | self._register_load_state_dict_pre_hook(self.load_hook) 196 | 197 | def load_hook(self, state_dict, prefix, *args): 198 | if prefix + "wq.weight" in state_dict: 199 | wq = state_dict.pop(prefix + "wq.weight") 200 | wk = state_dict.pop(prefix + "wk.weight") 201 | wv = state_dict.pop(prefix + "wv.weight") 202 | state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) 203 | 204 | def forward(self, x: Tensor, freqs_cis: Tensor, mask: BlockMask, input_pos: Optional[Tensor] = None) -> Tensor: 205 | bsz, seqlen, _ = x.shape 206 | 207 | kv_size = self.n_local_heads * self.head_dim 208 | q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) 209 | 210 | q = q.view(bsz, 
seqlen, self.n_head, self.head_dim) 211 | k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim) 212 | v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim) 213 | 214 | q = apply_rotary_emb(q, freqs_cis) 215 | k = apply_rotary_emb(k, freqs_cis) 216 | 217 | q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) 218 | 219 | if self.kv_cache is not None: 220 | k, v = self.kv_cache.update(input_pos, k, v) 221 | 222 | y = flex_attention(q, k, v, block_mask=mask, enable_gqa=(self.n_head != self.n_local_heads)) 223 | 224 | y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) 225 | 226 | y = self.wo(y) 227 | return y 228 | 229 | 230 | class FeedForward(nn.Module): 231 | def __init__(self, config: ModelArgs) -> None: 232 | super().__init__() 233 | self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False) 234 | self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False) 235 | self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False) 236 | 237 | def forward(self, x: Tensor) -> Tensor: 238 | return self.w2(F.silu(self.w1(x)) * self.w3(x)) 239 | 240 | 241 | class RMSNorm(nn.Module): 242 | def __init__(self, dim: int, eps: float = 1e-5): 243 | super().__init__() 244 | self.eps = eps 245 | self.weight = nn.Parameter(torch.ones(dim)) 246 | 247 | def _norm(self, x): 248 | return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) 249 | 250 | def forward(self, x: Tensor) -> Tensor: 251 | output = self._norm(x.float()).type_as(x) 252 | return output * self.weight 253 | 254 | 255 | def apply_rope_scaling(freqs: torch.Tensor, rope_scaling: Optional[dict] = None): 256 | factor = rope_scaling["factor"] 257 | low_freq_factor = rope_scaling["low_freq_factor"] 258 | high_freq_factor = rope_scaling["high_freq_factor"] 259 | old_context_len = rope_scaling["original_max_position_embeddings"] 260 | 261 | low_freq_wavelen = old_context_len / low_freq_factor 262 | high_freq_wavelen = old_context_len / high_freq_factor 263 | new_freqs = [] 264 | for freq in freqs: 265 | wavelen = 2 * math.pi / freq 266 | if wavelen < high_freq_wavelen: 267 | new_freqs.append(freq) 268 | elif wavelen > low_freq_wavelen: 269 | new_freqs.append(freq / factor) 270 | else: 271 | assert low_freq_wavelen != high_freq_wavelen 272 | smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) 273 | new_freqs.append((1 - smooth) * freq / factor + smooth * freq) 274 | return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) 275 | 276 | 277 | def precompute_freqs_cis( 278 | seq_len: int, n_elem: int, base: int = 10000, 279 | dtype: torch.dtype = torch.bfloat16, 280 | rope_scaling: Optional[dict] = None, 281 | ) -> Tensor: 282 | freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)) 283 | if rope_scaling is not None: 284 | freqs = apply_rope_scaling(freqs, rope_scaling) 285 | t = torch.arange(seq_len, device=freqs.device) 286 | freqs = torch.outer(t, freqs) 287 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) 288 | cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) 289 | return cache.to(dtype=dtype) 290 | 291 | 292 | def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: 293 | xshaped = x.float().reshape(*x.shape[:-1], -1, 2) 294 | freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) 295 | x_out2 = torch.stack( 296 | [ 297 | xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1], 298 | xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * 
freqs_cis[..., 1], 299 | ], 300 | -1, 301 | ) 302 | 303 | x_out2 = x_out2.flatten(3) 304 | return x_out2.type_as(x) 305 | -------------------------------------------------------------------------------- /quantize.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import time 7 | from pathlib import Path 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from tokenizer import get_tokenizer 13 | 14 | try: 15 | from GPTQ import GenericGPTQRunner, InputRecorder 16 | from eval import get_task_dict, evaluate, lm_eval 17 | except: 18 | pass 19 | 20 | from model import Transformer 21 | 22 | ##### Quantization Primitives ###### 23 | 24 | def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype): 25 | # assumes symmetric quantization 26 | # assumes axis == 0 27 | # assumes dense memory format 28 | # TODO(future): relax ^ as needed 29 | 30 | # default setup for affine quantization of activations 31 | eps = torch.finfo(torch.float32).eps 32 | 33 | # get min and max 34 | min_val, max_val = torch.aminmax(x, dim=1) 35 | 36 | # calculate scales and zero_points based on min and max 37 | # reference: https://fburl.com/code/srbiybme 38 | min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) 39 | max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) 40 | device = min_val_neg.device 41 | 42 | # reference: https://fburl.com/code/4wll53rk 43 | max_val_pos = torch.max(-min_val_neg, max_val_pos) 44 | scales = max_val_pos / (float(quant_max - quant_min) / 2) 45 | # ensure scales is the same dtype as the original tensor 46 | scales = torch.clamp(scales, min=eps).to(x.dtype) 47 | zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device) 48 | 49 | # quantize based on qmin/qmax/scales/zp 50 | # reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63 51 | x_div = x / scales.unsqueeze(-1) 52 | x_round = torch.round(x_div) 53 | x_zp = x_round + zero_points.unsqueeze(-1) 54 | quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype) 55 | 56 | return quant, scales, zero_points 57 | 58 | def get_group_qparams(w, n_bit=4, groupsize=128): 59 | # needed for GPTQ with padding 60 | if groupsize > w.shape[-1]: 61 | groupsize = w.shape[-1] 62 | assert groupsize > 1 63 | assert w.shape[-1] % groupsize == 0 64 | assert w.dim() == 2 65 | 66 | to_quant = w.reshape(-1, groupsize) 67 | assert torch.isnan(to_quant).sum() == 0 68 | 69 | max_val = to_quant.amax(dim=1, keepdim=True) 70 | min_val = to_quant.amin(dim=1, keepdim=True) 71 | max_int = 2**n_bit - 1 72 | scales = (max_val - min_val).clamp(min=1e-6) / max_int 73 | zeros = min_val + scales * (2 ** (n_bit - 1)) 74 | return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to( 75 | torch.bfloat16 76 | ).reshape(w.shape[0], -1) 77 | 78 | 79 | def pack_scales_and_zeros(scales, zeros): 80 | assert scales.shape == zeros.shape 81 | assert scales.dtype == torch.bfloat16 82 | assert zeros.dtype == torch.bfloat16 83 | return ( 84 | torch.cat( 85 | [ 86 | scales.reshape(scales.size(0), scales.size(1), 1), 87 | zeros.reshape(zeros.size(0), zeros.size(1), 1), 88 | ], 89 | 2, 90 | ) 91 | .transpose(0, 1) 92 | .contiguous() 93 | ) 94 | 95 | 96 | def 
unpack_scales_and_zeros(scales_and_zeros): 97 | assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2 98 | assert scales_and_zeros.dtype == torch.float 99 | return torch.split(scales_and_zeros.transpose(0, 1), 1, 2) 100 | 101 | 102 | def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128): 103 | assert groupsize > 1 104 | # needed for GPTQ single column quantize 105 | if groupsize > w.shape[-1] and scales.shape[-1] == 1: 106 | groupsize = w.shape[-1] 107 | 108 | assert w.shape[-1] % groupsize == 0 109 | assert w.dim() == 2 110 | 111 | to_quant = w.reshape(-1, groupsize) 112 | assert torch.isnan(to_quant).sum() == 0 113 | 114 | scales = scales.reshape(-1, 1) 115 | zeros = zeros.reshape(-1, 1) 116 | min_val = zeros - scales * (2 ** (n_bit - 1)) 117 | max_int = 2**n_bit - 1 118 | min_int = 0 119 | w_int32 = ( 120 | to_quant.sub(min_val) 121 | .div(scales) 122 | .round() 123 | .clamp_(min_int, max_int) 124 | .to(torch.int32) 125 | .reshape_as(w) 126 | ) 127 | 128 | return w_int32 129 | 130 | 131 | def group_quantize_tensor(w, n_bit=4, groupsize=128): 132 | scales, zeros = get_group_qparams(w, n_bit, groupsize) 133 | w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize) 134 | scales_and_zeros = pack_scales_and_zeros(scales, zeros) 135 | return w_int32, scales_and_zeros 136 | 137 | 138 | def group_dequantize_tensor_from_qparams( 139 | w_int32, scales, zeros, n_bit=4, groupsize=128 140 | ): 141 | assert groupsize > 1 142 | # needed for GPTQ single column dequantize 143 | if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1: 144 | groupsize = w_int32.shape[-1] 145 | assert w_int32.shape[-1] % groupsize == 0 146 | assert w_int32.dim() == 2 147 | 148 | w_int32_grouped = w_int32.reshape(-1, groupsize) 149 | scales = scales.reshape(-1, 1) 150 | zeros = zeros.reshape(-1, 1) 151 | 152 | w_dq = ( 153 | w_int32_grouped.sub(2 ** (n_bit - 1)).mul(scales).add(zeros).reshape_as(w_int32) 154 | ) 155 | return w_dq 156 | 157 | 158 | def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128): 159 | scales, zeros = unpack_scales_and_zeros(scales_and_zeros) 160 | return group_dequantize_tensor_from_qparams( 161 | w_int32, scales, zeros, n_bit, groupsize 162 | ) 163 | 164 | class QuantHandler: 165 | def __init__(self, mod): 166 | self.mod = mod 167 | 168 | def create_quantized_state_dict(self) -> "StateDict": 169 | pass 170 | 171 | def convert_for_runtime(self) -> "nn.Module": 172 | pass 173 | 174 | class GPTQQuantHandler(QuantHandler): 175 | """ 176 | This class implements a GPTQ QuantHandler that can be used to apply GPTQ to a model in concert with the GenericGPTQRunner class. 177 | Unlike the base QuantHandler class, the user does not need to implement the create_quantized_state_dict, instead they have to reimplement 178 | __init__ such that it defines the functions for the quantization mode. User is expected to reimplement convert_for_runtime. 179 | 180 | The following functions (which must be defined in __init__) are used to define the quantization mode for both GPTQ and 181 | create_quantized_state_dict. Here is a description of each function. 182 | 183 | get_qparams_func: 184 | A function that calculates the quantization qparams for an input tensor. 185 | Args: 186 | weight: A 2d weight tensor with non-integer dtype. 187 | Returns: 188 | qparams: it can have any format but will need to be handled by the other defined functions below. 
189 | 190 | quantize_func: 191 | A function that applies quantization to an input tensor. It should be noted 192 | that this function needs to be able to handle quantizing the entire weight tensor, a single group, 193 | or a single column. 194 | Args: 195 | weight: A 2d weight tensor with non-integer dtype. 196 | qparams: the output from get_qparams_func 197 | Returns: 198 | quantized_weight: A 2d quantized weight tensor (generally with an integer dtype) 199 | 200 | 201 | dequantize_func: 202 | A function that dequantizes an input quantized weight tensor. It should be noted 203 | that this function needs to be able to handle dequantizing the entire weight tensor, a single group, 204 | or a single column. 205 | Args: 206 | quantized_weight: A 2d quantized weight tensor (generally with an integer dtype) 207 | qparams: the output from get_qparams_func 208 | Returns: 209 | weight: A 2d weight tensor with non-integer dtype. 210 | 211 | combine_qparams_list_func: 212 | A function that combines several qparams into one qparam. 213 | Args: 214 | qparams_list: a list of qparams objects, each obtained by calling get_qparams_func 215 | on a single group from a weight tensor 216 | Returns: 217 | qparams: an object of the same format as the qparams above. 218 | 219 | skip_layer_func: 220 | A function that determines which linear layers should be skipped during GPTQ 221 | Args: 222 | weight: A 2d weight tensor with non-integer dtype. 223 | Returns: 224 | skip: boolean indicating whether layer should be skipped 225 | 226 | make_names_and_values_dict_func: 227 | A function that prepares the qparams and quantized_weight and creates a dictionary indicating how they 228 | should be inserted into the state_dict. Generally any packing of the weight and qparams should be done here. 229 | Args: 230 | quantized_weight: A 2d quantized weight tensor (generally with an integer dtype) 231 | qparams: the output from get_qparams_func 232 | Returns: 233 | names_and_values_dict: a dictionary mapping the name of the parameters of the quantized module to the 234 | corresponding quantized weights and qparams. 
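Example (a minimal sketch of how a subclass wires these together; see WeightOnlyInt4GPTQQuantHandler below for the full int4 version):

        self.get_qparams_func = lambda w: get_group_qparams(w, 4, groupsize)
        self.quantize_func = lambda w, qparams: group_quantize_tensor_from_qparams(w, qparams[0], qparams[1], 4, groupsize)
        self.dequantize_func = lambda q, qparams: group_dequantize_tensor_from_qparams(q, qparams[0], qparams[1], 4, groupsize).float()

    with the inherited create_quantized_state_dict then driving GenericGPTQRunner over the calibration inputs.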
235 | """ 236 | def __init__(self): 237 | assert self.mod is not None 238 | assert self.get_qparams_func is not None 239 | assert self.quantize_func is not None 240 | assert self.dequantize_func is not None 241 | assert self.combine_qparams_list_func is not None 242 | assert self.make_names_and_values_dict_func is not None 243 | 244 | @staticmethod 245 | def get_inputs(model, tokenizer, calibration_tasks, calibration_limit, calibration_seq_length, pad_calibration_inputs) -> "MultiInput": 246 | input_recorder = InputRecorder( 247 | model, 248 | tokenizer, 249 | calibration_seq_length, 250 | pad_calibration_inputs, 251 | ) 252 | 253 | try: 254 | lm_eval.tasks.initialize_tasks() 255 | except: 256 | pass 257 | task_dict = get_task_dict(calibration_tasks) 258 | print("Obtaining GPTQ calibration inputs on: ", calibration_tasks) 259 | 260 | evaluate( 261 | input_recorder, 262 | task_dict, 263 | limit=calibration_limit, 264 | ) 265 | inputs = input_recorder.get_recorded_inputs() 266 | assert inputs is not None, ( 267 | f"No inputs were collected, use a task other than {calibration_tasks}, "+ 268 | f"use option pad_calibration_inputs, or decrease calibration_sequence_length (currently "+ 269 | f"{calibration_seq_length})" 270 | ) 271 | print(f"Obtained {len(inputs[0].values)} calibration samples") 272 | return inputs 273 | 274 | @torch.no_grad() 275 | def create_quantized_state_dict( 276 | self, 277 | tokenizer, 278 | blocksize, 279 | percdamp, 280 | groupsize, 281 | calibration_tasks, 282 | calibration_limit, 283 | calibration_seq_length, 284 | pad_calibration_inputs, 285 | ) -> "StateDict": 286 | inputs = GPTQQuantHandler.get_inputs(self.mod, tokenizer, calibration_tasks, calibration_limit, calibration_seq_length, pad_calibration_inputs) 287 | print("Tracing model for GPTQ") 288 | GPTQ_runner = GenericGPTQRunner( 289 | self.mod, 290 | inputs, 291 | blocksize, 292 | percdamp, 293 | groupsize, 294 | ).configure_quantization_mode( 295 | self.get_qparams_func, 296 | self.quantize_func, 297 | self.dequantize_func, 298 | self.combine_qparams_list_func, 299 | self.make_names_and_values_dict_func, 300 | self.skip_layer_func 301 | ) 302 | 303 | print("Applying GPTQ to weights") 304 | GPTQ_runner.run() 305 | return GPTQ_runner.get_quantized_state_dict() 306 | 307 | def convert_for_runtime(self) -> "nn.Module": 308 | pass 309 | 310 | ##### Weight-only int8 per-channel quantized code ###### 311 | 312 | def replace_linear_weight_only_int8_per_channel(module): 313 | for name, child in module.named_children(): 314 | if isinstance(child, nn.Linear): 315 | setattr(module, name, WeightOnlyInt8Linear(child.in_features, child.out_features)) 316 | else: 317 | replace_linear_weight_only_int8_per_channel(child) 318 | 319 | class WeightOnlyInt8QuantHandler: 320 | def __init__(self, mod): 321 | self.mod = mod 322 | 323 | @torch.no_grad() 324 | def create_quantized_state_dict(self): 325 | cur_state_dict = self.mod.state_dict() 326 | for fqn, mod in self.mod.named_modules(): 327 | if isinstance(mod, torch.nn.Linear): 328 | int8_weight, scales, _ = dynamically_quantize_per_channel(mod.weight.float(), -128, 127, torch.int8) 329 | cur_state_dict[f"{fqn}.weight"] = int8_weight 330 | cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype) 331 | 332 | return cur_state_dict 333 | 334 | def convert_for_runtime(self): 335 | replace_linear_weight_only_int8_per_channel(self.mod) 336 | return self.mod 337 | 338 | 339 | class WeightOnlyInt8Linear(torch.nn.Module): 340 | __constants__ = ['in_features', 'out_features'] 341 | 
in_features: int 342 | out_features: int 343 | weight: torch.Tensor 344 | 345 | def __init__(self, in_features: int, out_features: int, bias: bool = True, 346 | device=None, dtype=None) -> None: 347 | factory_kwargs = {'device': device, 'dtype': dtype} 348 | super().__init__() 349 | self.in_features = in_features 350 | self.out_features = out_features 351 | self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8)) 352 | self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16)) 353 | 354 | def forward(self, input: torch.Tensor) -> torch.Tensor: 355 | return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales 356 | 357 | ##### weight only int4 per channel groupwise quantized code ###### 358 | 359 | def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles): 360 | weight_int32, scales_and_zeros = group_quantize_tensor( 361 | weight_bf16, n_bit=4, groupsize=groupsize 362 | ) 363 | weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles) 364 | return weight_int4pack, scales_and_zeros 365 | 366 | 367 | def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize): 368 | origin_x_size = x.size() 369 | x = x.reshape(-1, origin_x_size[-1]) 370 | c = torch.ops.aten._weight_int4pack_mm(x, weight_int4pack, groupsize, scales_and_zeros) 371 | new_shape = origin_x_size[:-1] + (out_features,) 372 | c = c.reshape(new_shape) 373 | return c 374 | 375 | 376 | def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = 1): 377 | return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0 378 | 379 | def replace_linear_int4(module, groupsize, inner_k_tiles, padding): 380 | for name, child in module.named_children(): 381 | if isinstance(child, nn.Linear): 382 | if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles): 383 | setattr(module, name, WeightOnlyInt4Linear( 384 | child.in_features, child.out_features, bias=False, 385 | groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=False, 386 | )) 387 | elif padding: 388 | setattr(module, name, WeightOnlyInt4Linear( 389 | child.in_features, child.out_features, bias=False, 390 | groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=True, 391 | )) 392 | else: 393 | replace_linear_int4(child, groupsize, inner_k_tiles, padding) 394 | 395 | 396 | class WeightOnlyInt4QuantHandler: 397 | def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True): 398 | self.mod = mod 399 | self.groupsize = groupsize 400 | self.inner_k_tiles = inner_k_tiles 401 | self.padding = padding 402 | assert groupsize in [32, 64, 128, 256] 403 | assert inner_k_tiles in [2, 4, 8] 404 | 405 | @torch.no_grad() 406 | def create_quantized_state_dict(self, use_cuda = True): 407 | if use_cuda: 408 | device="cuda" 409 | else: 410 | device="cpu" 411 | 412 | cur_state_dict = self.mod.state_dict() 413 | for fqn, mod in self.mod.named_modules(): 414 | if isinstance(mod, torch.nn.Linear): 415 | assert not mod.bias 416 | out_features = mod.out_features 417 | in_features = mod.in_features 418 | assert out_features % 8 == 0, "require out_features % 8 == 0" 419 | print(f"linear: {fqn}, in={in_features}, out={out_features}") 420 | 421 | weight = mod.weight.data 422 | if not _check_linear_int4_k(in_features, self.groupsize, self.inner_k_tiles): 423 | if self.padding: 424 | from model import find_multiple 425 | import torch.nn.functional as F 426 | print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0") 427 | 
padded_in_features = find_multiple(in_features, 1024) 428 | weight = F.pad(weight, pad=(0, padded_in_features - in_features)) 429 | else: 430 | print(f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " + 431 | "and that groupsize and inner_k_tiles*16 evenly divide into it") 432 | continue 433 | weight_int4pack, scales_and_zeros = prepare_int4_weight_and_scales_and_zeros( 434 | weight.to(torch.bfloat16).to(device=device), self.groupsize, self.inner_k_tiles 435 | ) 436 | cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to('cpu') 437 | cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to('cpu') 438 | 439 | return cur_state_dict 440 | 441 | def convert_for_runtime(self): 442 | replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding) 443 | return self.mod 444 | 445 | class WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler): 446 | def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True): 447 | from model import find_multiple 448 | self.mod = mod 449 | self.groupsize = groupsize 450 | self.inner_k_tiles = inner_k_tiles 451 | self.padding = padding 452 | self.get_qparams_func = lambda w: get_group_qparams(w, 4, groupsize) 453 | self.quantize_func = lambda w, qparams: \ 454 | group_quantize_tensor_from_qparams(w, qparams[0], qparams[1], 4, groupsize) 455 | self.dequantize_func = lambda q, qparams: \ 456 | group_dequantize_tensor_from_qparams(q, qparams[0], qparams[1], 4, groupsize).float() 457 | self.combine_qparams_list_func = lambda qparams_list: \ 458 | [torch.cat(x, dim=1) for x in zip(*qparams_list)] 459 | # skip unless padding=True or its correctly sized 460 | self.skip_layer_func = lambda linear_weight: not ( 461 | _check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles) or padding 462 | ) 463 | # we need to do the padding here, both for q and the qparams if necessary 464 | def make_names_and_values_dict_func(q, qparams): 465 | k = q.shape[1] 466 | new_k = find_multiple(k, 1024) 467 | # how much we need to pad the weight 468 | delta_k = new_k - q.shape[1] 469 | final_q = torch.ops.aten._convert_weight_to_int4pack(F.pad(q, pad=(0, delta_k)), inner_k_tiles) 470 | scales_and_zeros = pack_scales_and_zeros(*qparams) 471 | # how many new groups we need for padded weight 472 | delta_groups = new_k // groupsize - scales_and_zeros.shape[0] 473 | final_s_and_z = F.pad(scales_and_zeros, pad=(0,0,0,0,0, delta_groups), value=1) 474 | return {"weight": final_q, "scales_and_zeros": final_s_and_z} 475 | self.make_names_and_values_dict_func = make_names_and_values_dict_func 476 | super().__init__() 477 | 478 | 479 | def convert_for_runtime(self): 480 | replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding) 481 | return self.mod 482 | 483 | class WeightOnlyInt4Linear(torch.nn.Module): 484 | __constants__ = ['in_features', 'out_features'] 485 | in_features: int 486 | out_features: int 487 | weight: torch.Tensor 488 | 489 | def __init__( 490 | self, in_features: int, out_features: int, 491 | bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, padding: bool = True, 492 | ) -> None: 493 | super().__init__() 494 | self.padding = padding 495 | if padding: 496 | from model import find_multiple 497 | self.origin_in_features = in_features 498 | in_features = find_multiple(in_features, 1024) 499 | 500 | self.in_features = in_features 501 | self.out_features = out_features 502 | assert not bias, "require bias=False" 503 | self.groupsize = groupsize 504 | 
self.inner_k_tiles = inner_k_tiles 505 | 506 | assert out_features % 8 == 0, "require out_features % 8 == 0" 507 | assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0" 508 | self.register_buffer( 509 | "weight", 510 | torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32) 511 | ) 512 | self.register_buffer( 513 | "scales_and_zeros", 514 | torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16) 515 | ) 516 | 517 | def forward(self, input: torch.Tensor) -> torch.Tensor: 518 | input = input.to(torch.bfloat16) 519 | if self.padding: 520 | import torch.nn.functional as F 521 | input = F.pad(input, pad=(0, self.in_features - self.origin_in_features)) 522 | return linear_forward_int4( 523 | input, 524 | self.weight, self.scales_and_zeros, self.out_features, self.groupsize 525 | ) 526 | 527 | 528 | def quantize( 529 | checkpoint_path: Path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), 530 | mode: str = 'int8', 531 | # following arguments only available when setting int4 quantization. 532 | groupsize: int = 128, 533 | # following arguments only used for GPTQ 534 | calibration_tasks: list = ["hellaswag"], 535 | calibration_limit: int = 1000, 536 | calibration_seq_length: int = 100, 537 | pad_calibration_inputs: bool = False, 538 | percdamp: float = .01, 539 | blocksize: int = 128, 540 | label: str = '', 541 | ) -> None: 542 | assert checkpoint_path.is_file(), checkpoint_path 543 | 544 | device = 'cpu' 545 | precision = torch.bfloat16 546 | 547 | print("Loading model ...") 548 | t0 = time.time() 549 | 550 | with torch.device('meta'): 551 | model = Transformer.from_name(checkpoint_path.parent.name) 552 | 553 | checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) 554 | model.load_state_dict(checkpoint, assign=True) 555 | model = model.to(dtype=precision, device=device) 556 | 557 | if mode == 'int8': 558 | print("Quantizing model weights for int8 weight-only symmetric per-channel quantization") 559 | quant_handler = WeightOnlyInt8QuantHandler(model) 560 | quantized_state_dict = quant_handler.create_quantized_state_dict() 561 | 562 | dir_name = checkpoint_path.parent 563 | base_name = checkpoint_path.name 564 | new_base_name = base_name.replace('.pth', f'{label}int8.pth') 565 | 566 | elif mode == 'int4': 567 | print("Quantizing model weights for int4 weight-only affine per-channel groupwise quantization") 568 | quant_handler = WeightOnlyInt4QuantHandler(model, groupsize) 569 | quantized_state_dict = quant_handler.create_quantized_state_dict() 570 | 571 | dir_name = checkpoint_path.parent 572 | base_name = checkpoint_path.name 573 | new_base_name = base_name.replace('.pth', f"{label}int4.g{groupsize}.pth") 574 | 575 | elif mode == 'int4-gptq': 576 | print("Quantizing model weights for int4 weight-only affine per-channel groupwise quantization using GPTQ...") 577 | quant_handler = WeightOnlyInt4GPTQQuantHandler(model, groupsize) 578 | 579 | tokenizer_path = checkpoint_path.parent / "tokenizer.model" 580 | assert tokenizer_path.is_file(), str(tokenizer_path) 581 | tokenizer = get_tokenizer(tokenizer_path, checkpoint_path) 582 | 583 | quantized_state_dict = quant_handler.create_quantized_state_dict( 584 | tokenizer, 585 | blocksize, 586 | percdamp, 587 | groupsize, 588 | calibration_tasks, 589 | calibration_limit, 590 | calibration_seq_length, 591 | pad_calibration_inputs 592 | ) 593 | 594 | dir_name = checkpoint_path.parent 595 | base_name = 
checkpoint_path.name 596 | new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.pth") 597 | else: 598 | raise ValueError(f"Invalid quantization mode {mode}; needs to be one of [int8, int4, int4-gptq]") 599 | 600 | quantize_path = dir_name / new_base_name 601 | print(f"Writing quantized weights to {quantize_path}") 602 | quantize_path.unlink(missing_ok=True) # remove existing file if one is already there 603 | torch.save(quantized_state_dict, quantize_path) 604 | print(f"Quantization completed in {time.time() - t0:.02f} seconds") 605 | return 606 | 607 | if __name__ == '__main__': 608 | import argparse 609 | parser = argparse.ArgumentParser(description='Quantize a model.') 610 | parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Path to the model checkpoint to be quantized.') 611 | parser.add_argument('--mode', '-q', type=str, default='int8', choices=['int8', 'int4', 'int4-gptq'], help='type of quantization to perform') 612 | parser.add_argument('--groupsize', type=int, default=32, help='Group size for int4 quantization.') 613 | parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['wikitext'], help='tasks to do gptq calibration on, if doing gptq') 614 | parser.add_argument('--calibration_limit', type=int, default=1000, help='number of samples to use for gptq calibration') 615 | parser.add_argument('--calibration_seq_length', type=int, default=100, help='length of sequences to use for gptq calibration') 616 | parser.add_argument('--pad_calibration_inputs', type=bool, default=False, help='pads sequences shorter than calibration_seq_length to that length, yielding more calibration inputs but running much slower') 617 | parser.add_argument('--percdamp', type=float, default=.01, help='gptq percentage dampening') 618 | parser.add_argument('--blocksize', type=int, default=128, help='blocksize for gptq') 619 | parser.add_argument('--label', type=str, default='_', help='label to add to output filename') 620 | 621 | args = parser.parse_args() 622 | quantize(args.checkpoint_path, args.mode, args.groupsize, args.calibration_tasks, args.calibration_limit, args.calibration_seq_length, args.pad_calibration_inputs, args.percdamp, args.blocksize, args.label) 623 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | sentencepiece 3 | tiktoken 4 | blobfile 5 | safetensors 6 | -------------------------------------------------------------------------------- /scripts/convert_hf_checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
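# Converts a downloaded HuggingFace checkpoint (safetensors or pytorch .bin shards) into the
# single model.pth layout used by this repo, merging the separate wq/wk/wv projections into wqkv.
# Typical invocation (as in scripts/prepare.sh), after fetching the weights with scripts/download.py:
#   python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Llama-2-7b-chat-hf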
6 | import json 7 | import re 8 | import shutil 9 | import sys 10 | from pathlib import Path 11 | from typing import Optional 12 | from safetensors.torch import load_file as load_safetensors_file 13 | import torch 14 | 15 | # support running without installing as a package 16 | wd = Path(__file__).parent.parent.resolve() 17 | sys.path.append(str(wd)) 18 | 19 | from model import ModelArgs 20 | 21 | 22 | @torch.inference_mode() 23 | def convert_hf_checkpoint( 24 | *, 25 | checkpoint_dir: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf"), 26 | model_name: Optional[str] = None, 27 | ) -> None: 28 | if model_name is None: 29 | model_name = checkpoint_dir.name 30 | 31 | config = ModelArgs.from_name(model_name) 32 | print(f"Model config {config.__dict__}") 33 | 34 | # Load the json file containing weight mapping 35 | model_map_json_safetensors = checkpoint_dir / 'model.safetensors.index.json' 36 | model_map_json_pytorch = checkpoint_dir / "pytorch_model.bin.index.json" 37 | model_map_json = None 38 | 39 | try: 40 | assert model_map_json_safetensors.is_file() 41 | model_map_json = model_map_json_safetensors 42 | print(f"Found safetensors index at {model_map_json_safetensors}") 43 | except AssertionError: 44 | print(f"{model_map_json_safetensors} not found") 45 | if model_map_json is None: 46 | try: 47 | assert model_map_json_pytorch.is_file() 48 | model_map_json = model_map_json_pytorch 49 | print(f"Found pytorch index at {model_map_json_pytorch}") 50 | except AssertionError: 51 | print(f"{model_map_json_pytorch} not found") 52 | 53 | if model_map_json is None: raise Exception("No model map found!") 54 | 55 | with open(model_map_json) as json_map: 56 | bin_index = json.load(json_map) 57 | 58 | weight_map = { 59 | "model.embed_tokens.weight": "tok_embeddings.weight", 60 | "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight", 61 | "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight", 62 | "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight", 63 | "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight", 64 | 'model.layers.{}.self_attn.rotary_emb.inv_freq': None, 65 | 'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight', 66 | "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight", 67 | "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight", 68 | "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight", 69 | "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight", 70 | "model.norm.weight": "norm.weight", 71 | "lm_head.weight": "output.weight", 72 | } 73 | bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()} 74 | 75 | def permute(w, n_head): 76 | dim = config.dim 77 | return ( 78 | w.view(n_head, 2, config.head_dim // 2, dim) 79 | .transpose(1, 2) 80 | .reshape(config.head_dim * n_head, dim) 81 | ) 82 | 83 | merged_result = {} 84 | for file in sorted(bin_files): 85 | if "safetensors" in str(file): 86 | state_dict = load_safetensors_file(str(file), device="cpu") 87 | merged_result.update(state_dict) 88 | else: 89 | state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True) 90 | merged_result.update(state_dict) 91 | final_result = {} 92 | for key, value in merged_result.items(): 93 | if "layers" in key: 94 | abstract_key = re.sub(r'(\d+)', '{}', key) 95 | layer_num = re.search(r'\d+', key).group(0) 96 | new_key = weight_map[abstract_key] 97 | if 
new_key is None: 98 | continue 99 | new_key = new_key.format(layer_num) 100 | else: 101 | new_key = weight_map[key] 102 | 103 | final_result[new_key] = value 104 | 105 | for key in tuple(final_result.keys()): 106 | if "wq" in key: 107 | q = final_result[key] 108 | k = final_result[key.replace("wq", "wk")] 109 | v = final_result[key.replace("wq", "wv")] 110 | q = permute(q, config.n_head) 111 | k = permute(k, config.n_local_heads) 112 | final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v]) 113 | del final_result[key] 114 | del final_result[key.replace("wq", "wk")] 115 | del final_result[key.replace("wq", "wv")] 116 | print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}") 117 | torch.save(final_result, checkpoint_dir / "model.pth") 118 | if 'llama-3-' in model_name.lower() or 'llama-3.1-' in model_name.lower(): 119 | if 'llama-3.1-405b' in model_name.lower(): 120 | original_dir = checkpoint_dir / "original" / "mp16" 121 | else: 122 | original_dir = checkpoint_dir / "original" 123 | tokenizer_model = original_dir / "tokenizer.model" 124 | tokenizer_model_tiktoken = checkpoint_dir / "tokenizer.model" 125 | print(f"Copying {tokenizer_model} to {tokenizer_model_tiktoken}") 126 | shutil.copy(tokenizer_model, tokenizer_model_tiktoken) 127 | 128 | if __name__ == '__main__': 129 | import argparse 130 | parser = argparse.ArgumentParser(description='Convert HuggingFace checkpoint.') 131 | parser.add_argument('--checkpoint_dir', type=Path, default=Path("checkpoints/meta-llama/llama-2-7b-chat-hf")) 132 | parser.add_argument('--model_name', type=str, default=None) 133 | 134 | args = parser.parse_args() 135 | convert_hf_checkpoint( 136 | checkpoint_dir=args.checkpoint_dir, 137 | model_name=args.model_name, 138 | ) 139 | -------------------------------------------------------------------------------- /scripts/download.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
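# Downloads a HuggingFace Hub snapshot into checkpoints/<repo_id>.
# Typical invocation (as in scripts/prepare.sh and scripts/test_flow.sh):
#   python scripts/download.py --repo_id meta-llama/Llama-2-7b-chat-hf --hf_token <your_token>
# --hf_token is only required for gated or private checkpoints.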
6 | import os 7 | from typing import Optional 8 | 9 | from requests.exceptions import HTTPError 10 | 11 | 12 | def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -> None: 13 | from huggingface_hub import snapshot_download 14 | os.makedirs(f"checkpoints/{repo_id}", exist_ok=True) 15 | try: 16 | snapshot_download(repo_id, local_dir=f"checkpoints/{repo_id}", local_dir_use_symlinks=False, token=hf_token) 17 | except HTTPError as e: 18 | if e.response.status_code == 401: 19 | print("You need to pass a valid `--hf_token=...` to download private checkpoints.") 20 | else: 21 | raise e 22 | 23 | if __name__ == '__main__': 24 | import argparse 25 | parser = argparse.ArgumentParser(description='Download data from HuggingFace Hub.') 26 | parser.add_argument('--repo_id', type=str, default="checkpoints/meta-llama/llama-2-7b-chat-hf", help='Repository ID to download from.') 27 | parser.add_argument('--hf_token', type=str, default=None, help='HuggingFace API token.') 28 | 29 | args = parser.parse_args() 30 | hf_download(args.repo_id, args.hf_token) 31 | -------------------------------------------------------------------------------- /scripts/prepare.sh: -------------------------------------------------------------------------------- 1 | python scripts/download.py --repo_id $1 && python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$1 && python quantize.py --checkpoint_path checkpoints/$1/model.pth --mode int8 2 | -------------------------------------------------------------------------------- /scripts/speculate_34B_bf16.sh: -------------------------------------------------------------------------------- 1 | # 56.80 2 | export MODEL_REPO=codellama/CodeLlama-34b-Python-hf 3 | export DRAFT_MODEL_REPO=codellama/CodeLlama-7b-Python-hf 4 | time python generate.py --compile --draft_checkpoint_path checkpoints/$DRAFT_MODEL_REPO/model_int4.g32.pth --checkpoint_path checkpoints/$MODEL_REPO/model.pth --speculate_k 6 --prompt "def quicksort(arr):" --max_new_tokens 200 --num_samples 50 5 | -------------------------------------------------------------------------------- /scripts/speculate_70B_int4.sh: -------------------------------------------------------------------------------- 1 | # 49.26 2 | export MODEL_REPO=meta-llama/Llama-2-70b-chat-hf 3 | export DRAFT_MODEL_REPO=meta-llama/Llama-2-7b-chat-hf 4 | time python generate.py --compile --draft_checkpoint_path checkpoints/$DRAFT_MODEL_REPO/model_int4.g32.pth --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.pth --speculate_k 4 --prompt "def quicksort(arr):" --max_new_tokens 100 --num_samples 50 --temperature 0 5 | -------------------------------------------------------------------------------- /scripts/speculate_7B_int4.sh: -------------------------------------------------------------------------------- 1 | export MODEL_REPO=codellama/CodeLlama-7b-Python-hf 2 | export DRAFT_MODEL_REPO=PY007/TinyLlama-1.1B-intermediate-step-480k-1T 3 | time python generate.py --compile --draft_checkpoint_path checkpoints/$DRAFT_MODEL_REPO/model_int4.g32.pth --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.pth --speculate_k 5 --prompt "Hi my name is" --max_new_tokens 200 --num_samples 50 --temperature 0 --compile_prefill 4 | -------------------------------------------------------------------------------- /scripts/speculate_tp_70B_bf16.sh: -------------------------------------------------------------------------------- 1 | export MODEL_REPO=meta-llama/Llama-2-70b-chat-hf 2 | export DRAFT_MODEL_REPO=meta-llama/Llama-2-7b-chat-hf 
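# 8-way tensor-parallel speculative decoding: bf16 70B target model with an int8-quantized 7B draft.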
3 | time torchrun --standalone --nproc_per_node=8 generate.py --compile --draft_checkpoint_path checkpoints/$DRAFT_MODEL_REPO/model_int8.pth --checkpoint_path checkpoints/$MODEL_REPO/model.pth --speculate_k 5 --prompt "def quicksort(arr):" --max_new_tokens 200 --num_samples 50 --temperature 0 4 | -------------------------------------------------------------------------------- /scripts/test_flow.sh: -------------------------------------------------------------------------------- 1 | export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf 2 | rm -r checkpoints/$MODEL_REPO 3 | python scripts/download.py --repo_id $MODEL_REPO 4 | python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$MODEL_REPO 5 | python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth 6 | python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model_int8.pth --max_new_tokens 100 7 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | from setuptools import setup, find_packages 7 | 8 | setup( 9 | name='gpt-fast', 10 | version='0.1', 11 | packages=find_packages(), 12 | install_requires=[ 13 | 'torch', 14 | ], 15 | description='A simple, fast, pure PyTorch Llama inference engine', 16 | long_description=open('README.md').read(), 17 | long_description_content_type='text/markdown', 18 | url='https://github.com/pytorch-labs/gpt-fast', 19 | ) 20 | -------------------------------------------------------------------------------- /tokenizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sentencepiece as spm 3 | import tiktoken 4 | from tiktoken.load import load_tiktoken_bpe 5 | from pathlib import Path 6 | from typing import Dict 7 | 8 | class TokenizerInterface: 9 | def __init__(self, model_path): 10 | self.model_path = model_path 11 | 12 | def encode(self, text): 13 | raise NotImplementedError("This method should be overridden by subclasses.") 14 | 15 | def decode(self, tokens): 16 | raise NotImplementedError("This method should be overridden by subclasses.") 17 | 18 | def bos_id(self): 19 | raise NotImplementedError("This method should be overridden by subclasses.") 20 | 21 | def eos_id(self): 22 | raise NotImplementedError("This method should be overridden by subclasses.") 23 | 24 | class SentencePieceWrapper(TokenizerInterface): 25 | def __init__(self, model_path): 26 | super().__init__(model_path) 27 | self.processor = spm.SentencePieceProcessor(str(model_path)) 28 | 29 | def encode(self, text): 30 | return self.processor.EncodeAsIds(text) 31 | 32 | def decode(self, tokens): 33 | return self.processor.DecodeIds(tokens) 34 | 35 | def bos_id(self): 36 | return self.processor.bos_id() 37 | 38 | def eos_id(self): 39 | return self.processor.eos_id() 40 | 41 | class TiktokenWrapper(TokenizerInterface): 42 | """ 43 | Tokenizing and encoding/decoding text using the Tiktoken tokenizer. 
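Used for Llama 3 style checkpoints, whose tokenizer.model is a tiktoken BPE ranks file rather than a SentencePiece model (see get_tokenizer below).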
44 | """ 45 | 46 | special_tokens: Dict[str, int] 47 | 48 | num_reserved_special_tokens = 256 49 | 50 | pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501 51 | 52 | def __init__(self, model_path): 53 | super().__init__(model_path) 54 | assert os.path.isfile(model_path), str(model_path) 55 | mergeable_ranks = load_tiktoken_bpe(str(model_path)) 56 | num_base_tokens = len(mergeable_ranks) 57 | special_tokens = [ 58 | "<|begin_of_text|>", 59 | "<|end_of_text|>", 60 | "<|reserved_special_token_0|>", 61 | "<|reserved_special_token_1|>", 62 | "<|reserved_special_token_2|>", 63 | "<|reserved_special_token_3|>", 64 | "<|start_header_id|>", 65 | "<|end_header_id|>", 66 | "<|reserved_special_token_4|>", 67 | "<|eot_id|>", # end of turn 68 | ] + [ 69 | f"<|reserved_special_token_{i}|>" 70 | for i in range(5, self.num_reserved_special_tokens - 5) 71 | ] 72 | self.special_tokens = { 73 | token: num_base_tokens + i for i, token in enumerate(special_tokens) 74 | } 75 | self.model = tiktoken.Encoding( 76 | name=Path(model_path).name, 77 | pat_str=self.pat_str, 78 | mergeable_ranks=mergeable_ranks, 79 | special_tokens=self.special_tokens, 80 | ) 81 | # BOS / EOS token IDs 82 | self._bos_id: int = self.special_tokens["<|begin_of_text|>"] 83 | self._eos_id: int = self.special_tokens["<|end_of_text|>"] 84 | 85 | def encode(self, text): 86 | return self.model.encode(text) 87 | 88 | def decode(self, tokens): 89 | return self.model.decode(tokens) 90 | 91 | def bos_id(self): 92 | return self._bos_id 93 | 94 | def eos_id(self): 95 | return self._eos_id 96 | 97 | def get_tokenizer(tokenizer_model_path, model_name): 98 | """ 99 | Factory function to get the appropriate tokenizer based on the model name. 100 | 101 | Args: 102 | - tokenizer_model_path (str): The file path to the tokenizer model. 103 | - model_name (str): The name of the model, used to determine the tokenizer type. 104 | 105 | Returns: 106 | - TokenizerInterface: An instance of a tokenizer. 107 | """ 108 | 109 | if "llama-3" in str(model_name).lower(): 110 | return TiktokenWrapper(tokenizer_model_path) 111 | else: 112 | return SentencePieceWrapper(tokenizer_model_path) 113 | -------------------------------------------------------------------------------- /tp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
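# Simple tensor parallelism helpers: Linear layers are sharded column-wise or row-wise across
# ranks, and forward hooks all-reduce the partial attention/feed-forward outputs.
# Intended to be launched via torchrun (see scripts/speculate_tp_70B_bf16.sh); maybe_init_dist()
# returns None when fewer than two ranks are available, in which case tensor parallelism is a no-op.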
6 | import os 7 | from typing import List, Optional 8 | 9 | import torch 10 | import torch.distributed as dist 11 | from torch import nn 12 | if os.uname().sysname != "Darwin": 13 | from torch.distributed import _functional_collectives as funcol 14 | else: 15 | # Distributed is not supported on MacOS 16 | funcol = None 17 | 18 | from model import Attention, FeedForward, Transformer 19 | from quantize import WeightOnlyInt4Linear 20 | 21 | 22 | def _get_rank() -> int: 23 | return int(os.environ.get("LOCAL_RANK", "0")) 24 | 25 | def is_local(): 26 | return _get_rank() == 0 27 | 28 | def local_break(): 29 | if is_local(): 30 | breakpoint() 31 | dist.barrier() 32 | 33 | def _get_world_size() -> int: 34 | return int(os.environ.get("LOCAL_WORLD_SIZE", "1")) 35 | 36 | def maybe_init_dist() -> Optional[int]: 37 | try: 38 | # provided by torchrun 39 | rank = _get_rank() 40 | world_size = _get_world_size() 41 | 42 | if world_size < 2: 43 | # too few gpus to parallelize, tp is no-op 44 | return None 45 | except KeyError: 46 | # not run via torchrun, no-op 47 | return None 48 | 49 | torch.cuda.set_device(rank) 50 | dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) 51 | return rank 52 | 53 | 54 | def _apply_tp_linear(linear: nn.Linear, style: str, weight_splits: List[int] = []) -> None: 55 | rank = _get_rank() 56 | world_size = _get_world_size() 57 | 58 | # Linear's weight matrix is transposed, and is of shape 59 | # (linear.out_features, linear.in_features) 60 | dim_lookup = { 61 | "colwise": (0, "out_features"), 62 | "rowwise": (1, "in_features") 63 | } 64 | assert style in dim_lookup 65 | shard_dim, size_attr = dim_lookup[style] 66 | 67 | # ensure we can shard evenly 68 | assert getattr(linear, size_attr) % world_size == 0 69 | def shard(x, dim): 70 | assert x.size(dim=dim) % world_size == 0 71 | return torch.tensor_split(x, world_size, dim=dim)[rank] 72 | 73 | def shard_qkv(qkv, dim, weight_splits): 74 | q, k, v = qkv.split(weight_splits, dim=dim) 75 | q = shard(q, dim) 76 | k = shard(k, dim) 77 | v = shard(v, dim) 78 | return torch.cat((q,k,v), dim=dim) 79 | 80 | # shard 81 | if weight_splits: 82 | # attention 83 | assert len(weight_splits) == 3 84 | 85 | if isinstance(linear, WeightOnlyInt4Linear): 86 | sharded_weight = shard_qkv(linear.weight, shard_dim, [i//8 for i in weight_splits]) 87 | linear.scales_and_zeros = shard_qkv(linear.scales_and_zeros, 1 - shard_dim, weight_splits) 88 | else: 89 | sharded_weight = shard_qkv(linear.weight, shard_dim, weight_splits) 90 | if hasattr(linear, "scales") and style == "colwise": 91 | linear.scales = shard_qkv(linear.scales, 0, weight_splits) 92 | else: 93 | sharded_weight = shard(linear.weight, shard_dim) 94 | if isinstance(linear, WeightOnlyInt4Linear): 95 | linear.scales_and_zeros = shard(linear.scales_and_zeros, 1 - shard_dim) 96 | if style == "rowwise": 97 | assert linear.scales_and_zeros.shape[0] * 32 == sharded_weight.shape[1] * sharded_weight.shape[2] * sharded_weight.shape[3] 98 | assert linear.scales_and_zeros.shape[1] == sharded_weight.shape[0] * 8 99 | if hasattr(linear, "scales") and style == "colwise": 100 | linear.scales = shard(linear.scales, 0) 101 | 102 | # local_break() 103 | linear.weight = nn.Parameter(sharded_weight, requires_grad=False) 104 | setattr(linear, size_attr, getattr(linear, size_attr) // world_size) 105 | 106 | # shape info should still be synced 107 | # assert linear.weight.shape == (linear.out_features, linear.in_features) 108 | 109 | 110 | def _apply_tp_ffn(mlp: FeedForward) -> None: 111 | assert 
hasattr(mlp, "w1") 112 | assert hasattr(mlp, "w3") 113 | assert hasattr(mlp, "w2") 114 | 115 | _apply_tp_linear(mlp.w1, "colwise") 116 | _apply_tp_linear(mlp.w3, "colwise") 117 | _apply_tp_linear(mlp.w2, "rowwise") 118 | 119 | world_size = _get_world_size() 120 | mlp.register_forward_hook(lambda _module, _input, output: funcol.all_reduce( 121 | output, "sum", list(range(world_size)))) 122 | 123 | 124 | def _apply_tp_attn(attn: Attention) -> None: 125 | assert hasattr(attn, "wqkv") 126 | assert hasattr(attn, "wo") 127 | 128 | kv_size = attn.n_local_heads * attn.head_dim 129 | _apply_tp_linear(attn.wqkv, "colwise", [attn.dim, kv_size, kv_size]) 130 | _apply_tp_linear(attn.wo, "rowwise") 131 | 132 | # overwrite 133 | world_size = _get_world_size() 134 | attn.n_head = attn.n_head // world_size 135 | attn.dim = attn.dim // world_size 136 | attn.head_dim = attn.dim // attn.n_head 137 | attn.n_local_heads = attn.n_local_heads // world_size 138 | 139 | attn.register_forward_hook(lambda _module, _input, output: funcol.all_reduce( 140 | output[0], "sum", list(range(world_size)))) 141 | 142 | 143 | def _apply_tp_Transformer(Transformer: Transformer) -> None: 144 | # overwrite config before Transformer.setup_cache is called 145 | world_size = _get_world_size() 146 | Transformer.config.n_head = Transformer.config.n_head // world_size 147 | Transformer.config.dim = Transformer.config.dim // world_size 148 | Transformer.config.n_local_heads = Transformer.config.n_local_heads // world_size 149 | 150 | 151 | def apply_tp(model: Transformer) -> None: 152 | _apply_tp_Transformer(model) 153 | for block in model.layers: 154 | # Apply to MLP 155 | _apply_tp_ffn(block.feed_forward) 156 | _apply_tp_attn(block.attention) 157 | --------------------------------------------------------------------------------