├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── GPTQ.py ├── LICENSE ├── README.md ├── download.py ├── eval.py ├── generate.py ├── model.py ├── quantize.py ├── requirements.txt ├── script.py ├── scripts ├── .ipynb_checkpoints │ ├── convert_hf_checkpoint-checkpoint.py │ └── prepare-checkpoint.sh ├── convert_hf_checkpoint.py ├── download.py ├── prepare.sh ├── speculate_34B_bf16.sh ├── speculate_70B_int4.sh ├── speculate_7B_int4.sh ├── speculate_tp_70B_bf16.sh └── test_flow.sh ├── setup.py └── tp.py /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at . All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. 
The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to gpt-fast 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | 6 | ## Pull Requests 7 | We actively welcome your pull requests. 8 | 9 | 1. Fork the repo and create your branch from `main`. 10 | 2. If you've added code that should be tested, add tests. 11 | 3. If you've changed APIs, update the documentation. 12 | 4. Ensure the test suite passes. 13 | 5. Make sure your code lints. 14 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 15 | 16 | ## Contributor License Agreement ("CLA") 17 | In order to accept your pull request, we need you to submit a CLA. You only need 18 | to do this once to work on any of Meta's open source projects. 19 | 20 | Complete your CLA here: 21 | 22 | ## Issues 23 | We use GitHub issues to track public bugs. Please ensure your description is 24 | clear and has sufficient instructions to be able to reproduce the issue. 25 | 26 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe 27 | disclosure of security bugs. In those cases, please go through the process 28 | outlined on that page and do not file a public issue. 29 | 30 | ## License 31 | By contributing to `gpt-fast`, you agree that your contributions will be licensed 32 | under the LICENSE file in the root directory of this source tree. 33 | -------------------------------------------------------------------------------- /GPTQ.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | 9 | import torch.fx as fx 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torch.utils._pytree import tree_flatten, tree_unflatten 13 | 14 | aten = torch.ops.aten 15 | 16 | from eval import ( 17 | setup_cache_padded_seq_input_pos_max_seq_length_for_prefill, 18 | GPTFastEvalWrapper 19 | ) 20 | 21 | 22 | class InputRecorder(GPTFastEvalWrapper): 23 | """ 24 | This is a fake evaluation wrapper that just records the inputs 25 | so that they can be used in calibration. 26 | 27 | If pad_calibration_inputs is enabled, the input recorder will take 28 | each input and pad/truncate it down to the calibration_seq_length. 29 | It will also edit the model embeddings to be zero for the 0 token used 30 | in padding and avoid any inputs with the 0 token. 
31 | 32 | If not, it will only truncate inputs to the desired length. 33 | """ 34 | 35 | def __init__( 36 | self, 37 | model, 38 | tokenizer, 39 | calibration_seq_length, 40 | pad_calibration_inputs=False, 41 | ): 42 | super().__init__(model, tokenizer, calibration_seq_length) 43 | self._model = model 44 | self._tokenizer = tokenizer 45 | self._device = torch.device("cpu") 46 | self.vocab_size = model.config.vocab_size 47 | self.calibration_seq_length = calibration_seq_length 48 | self.pad_calibration_inputs = pad_calibration_inputs 49 | self.inputs = None 50 | 51 | if self.pad_calibration_inputs: 52 | # This is needed for the pad_calibration_inputs option 53 | # to work properly, the 0 token's embeddings are set to 0 so that 54 | # the padded inputs will not affect the model numerics. This token isn't used 55 | # commonly in the eval tasks for the meta-llama tokenizer and we skip any inputs 56 | # where it appears 57 | try: 58 | if isinstance(self._model.transformer.wte, nn.Embedding): 59 | self.mod.transformer.wte.weight.data[0, :] *= 0 60 | except: 61 | print( 62 | "Did not find embeddings in model.transformer.wte, disabling padding" 63 | ) 64 | self.pad_calibration_inputs = False 65 | 66 | 67 | def add_input(self, args): 68 | if self.inputs is None: 69 | self.inputs = [MultiInput([arg]) for arg in args] 70 | else: 71 | self.inputs = [ 72 | multi.add_input(arg) for (multi, arg) in zip(self.inputs, args) 73 | ] 74 | 75 | def get_recorded_inputs(self): 76 | return self.inputs 77 | 78 | def _model_call(self, inps): 79 | inps = inps.squeeze(0) 80 | T = len(inps) 81 | if ( 82 | # can't use inputs that are too short when padding disabled 83 | (T < self.calibration_seq_length and not self.pad_calibration_inputs) 84 | or 85 | # can't use inputs that actually use token we use for padding 86 | (self.pad_calibration_inputs and 0 in inps) 87 | ): 88 | # give random output 89 | return torch.randn( 90 | (1, T, self.vocab_size), dtype=torch.bfloat16, device=self._device 91 | ) 92 | 93 | # pad or truncate to the right size 94 | if T >= self.calibration_seq_length: 95 | inps = inps[: self.calibration_seq_length] 96 | else: 97 | inps = F.pad(inps, (0, self.calibration_seq_length - T)) 98 | 99 | max_new_tokens = 1 100 | ( 101 | seq, 102 | input_pos, 103 | max_seq_length, 104 | ) = setup_cache_padded_seq_input_pos_max_seq_length_for_prefill( 105 | self._model, inps, max_new_tokens, self.max_length 106 | ) 107 | x = seq.index_select(0, input_pos).view(1, -1) 108 | self.add_input((x, input_pos)) 109 | 110 | # output `something` with correct shape to keep eval going 111 | return torch.randn( 112 | (1, T, self.vocab_size), dtype=torch.bfloat16, device=self._device 113 | ) 114 | 115 | 116 | 117 | class MultiInput: 118 | def __init__(self, inputs): 119 | self.values = list(inputs) 120 | 121 | def add_input(self, input): 122 | self.values.append(input) 123 | return self 124 | 125 | def __getitem__(self, slice): 126 | return MultiInput(self.values[slice]) 127 | 128 | def cuda(self): 129 | self.values = [val.cuda() if isinstance(val, torch.Tensor) else val for val in self.values] 130 | 131 | 132 | class GenericGPTQRunner(fx.Interpreter): 133 | """ 134 | This is a generic GPTQ runner that takes an existing model and applies GPTQ. 135 | It uses torch._dynamo.export to obtain a graph of the model and then hooks 136 | into function calls and when it detects a linear, it applies GPTQ to the weight 137 | given the calibration of inputs passed in at initialization. 
It puts the results 138 | into the state_dict so that the quantized model weights/qparams can be loaded 139 | directly into the model. 140 | 141 | This class is expected to work in concert with a GPTQSimpleQuantizer 142 | class to define the specific type of quantization being done. 143 | """ 144 | 145 | def __init__( 146 | self, model, inputs: MultiInput, blocksize=128, percdamp=0.01, groupsize=128 147 | ): 148 | self.id_to_name = { 149 | id(value): name for name, value in dict(model.named_parameters()).items() 150 | } 151 | 152 | # trace model for one input 153 | one_input = [multi.values[0] for multi in inputs] 154 | exported_model = torch._dynamo.export( 155 | model, aten_graph=True, pre_dispatch=True, tracing_mode="fake" 156 | )(*one_input) 157 | super().__init__(exported_model.graph_module) 158 | self.new_state_dict = model.state_dict() 159 | self.blocksize = blocksize 160 | self.percdamp = percdamp 161 | self.groupsize = groupsize 162 | self.inputs = inputs 163 | self.gptq_done = False 164 | self.debug = False 165 | 166 | def configure_quantization_mode( 167 | self, 168 | get_qparams_func, 169 | quantize_func, 170 | dequantize_func, 171 | combine_qparams_list_func, 172 | make_names_and_values_dict_func, 173 | skip_layer_func, 174 | ): 175 | # these functions need to already be curried with all inputs other than weight, qparams 176 | self.get_qparams_func = ( 177 | get_qparams_func # accepts [2d weight tensor], outputs qparams. 178 | ) 179 | 180 | self.quantize_func = quantize_func # accepts [2d weight tensor], [qparams], outputs a 2d quantized tensor of desired dtype 181 | 182 | self.dequantize_func = dequantize_func 183 | # accepts [quantized] tensor and [qparams], outputs a 2d dequantized tensor of type float, 184 | # assumes this output .to(w_orig_dtype) is ~eventual desired dequant behavior 185 | 186 | self.combine_qparams_list_func = combine_qparams_list_func 187 | # accepts [`list` of qparams] from quantizing one group at a time, 188 | # outputs a qparams object that could be passed into quant/dequantize_func 189 | 190 | self.skip_layer_func = skip_layer_func # accepts [weight tensor], outputs a bool on whether or not to apply gptq to this layer 191 | 192 | self.make_names_and_values_dict_func = make_names_and_values_dict_func # accepts [2d quantized tensor], [qparams], returns a dict of names, values to put in state_dict 193 | # note any final packing for storage should happen here 194 | return self 195 | 196 | def run(self): 197 | assert ( 198 | self.get_qparams_func is not None 199 | ), "need to configure quantization mode before running" 200 | self.gptq_done = True 201 | super().run(*self.inputs) 202 | 203 | def get_quantized_state_dict(self): 204 | assert ( 205 | self.gptq_done 206 | ), "need to run GPTQRunner before you can get_quantized_state_dict" 207 | quantized_state_dict = self.new_state_dict 208 | # Don't want to store/load the kv_cache so remove it from the state_dict 209 | del_list = [] 210 | for param_fqn in quantized_state_dict: 211 | if "kv_cache" in param_fqn: 212 | del_list.append(param_fqn) 213 | for param_fqn in del_list: 214 | quantized_state_dict.pop(param_fqn) 215 | return quantized_state_dict 216 | 217 | def call_function(self, target, args, kwargs, skip_quant=False): 218 | def tensors_to_cuda(args): 219 | new_args = [] 220 | for x in args: 221 | new_args.append(x.cuda() if isinstance(x, torch.Tensor) else x) 222 | return new_args 223 | 224 | # flatten args and kwargs together 225 | flat_args, spec = tree_flatten((args, kwargs)) 226 | # move all single 
tensors to cuda, will move MultiInputs to cuda one at a time 227 | flat_args = tensors_to_cuda(flat_args) 228 | 229 | has_multi_input = MultiInput in [type(x) for x in flat_args] 230 | if has_multi_input: 231 | # Just some trickery to convert 232 | # [MultiInput[a, a, a], MultiInput(b, b, b)] => [a, b], [a, b], [a, b] 233 | multi_input_count = max( 234 | [len(x.values) if isinstance(x, MultiInput) else 1 for x in flat_args] 235 | ) 236 | transposed_args = list( 237 | zip( 238 | *[x.values if isinstance(x, MultiInput) else [x] * multi_input_count for x in flat_args] 239 | ) 240 | ) 241 | else: 242 | transposed_args = [flat_args] 243 | outputs = [] 244 | 245 | # check whether we apply GPTQ to this module 246 | quantize_linear = ( 247 | (target == aten.linear.default) # if its a linear 248 | and id(args[1]) in self.id_to_name # and if we know the layer name 249 | and not skip_quant # and if we weren't told to skip quantization 250 | # and if the skip_layer_func doesn't say we should skip 251 | and not (self.skip_layer_func is not None and self.skip_layer_func(args[1])) 252 | ) # then we will quantize this linear layer/weight 253 | 254 | if quantize_linear: # instantiate variables for GPTQ 255 | H = 0 256 | total_batches = 0 257 | 258 | for inp in transposed_args: 259 | inp = tensors_to_cuda(inp) 260 | cur_args, cur_kwargs = tree_unflatten(inp, spec) 261 | 262 | if ( 263 | quantize_linear 264 | ): # calculate H instead of output (will run the linear eventually with updated weight) 265 | x = cur_args[0].float() 266 | shape = x.shape 267 | n = 1 if len(shape) == 2 else shape[0] 268 | H *= total_batches / (total_batches + n) 269 | total_batches += n 270 | x = ((2 / total_batches) ** (1 / 2)) * x.reshape( 271 | -1, shape[-1] 272 | ).t().float() 273 | H += x.matmul(x.t()) 274 | else: 275 | # get output if its not a linear 276 | out = super().call_function(target, cur_args, cur_kwargs) 277 | 278 | if isinstance(out, torch.Tensor): 279 | outputs.append(out.cpu()) 280 | else: 281 | outputs.append(out) 282 | 283 | if quantize_linear: 284 | mod_fqn = ".".join(self.id_to_name[id(args[1])].split(".")[:-1]) 285 | W = args[1].to(H.device) 286 | Q, DQ, qparams = self.faster_quant(H, W.detach()) 287 | print(mod_fqn) 288 | names_and_values_dict = self.make_names_and_values_dict_func(Q, qparams) 289 | 290 | # delete old weight 291 | if mod_fqn + ".weight" in self.new_state_dict: 292 | self.new_state_dict.pop(mod_fqn + ".weight") 293 | if len(args) > 2: 294 | self.new_state_dict[mod_fqn + ".bias"] = args[2] 295 | for name, value in names_and_values_dict.items(): 296 | self.new_state_dict[mod_fqn + "." 
+ name] = value 297 | 298 | # run linear with new weight to get corrected output 299 | new_out = self.call_function( 300 | target, (args[0], DQ, *args[2:]), kwargs, skip_quant=True 301 | ) 302 | 303 | if self.debug: 304 | old_out = self.call_function( 305 | target, (args[0][:2], args[1], *args[2:]), kwargs, skip_quant=True 306 | ) 307 | 308 | def SQNR(x, y): 309 | return 20 * torch.log10(torch.norm(x) / torch.norm(x - y)) 310 | 311 | DQ_after = self.dequantize_func(Q, qparams).to(W.dtype) 312 | print( 313 | "SQNR for QDQ (this should be inf)", SQNR(DQ, DQ_after) 314 | ) # matches 315 | 316 | print( 317 | "SQNR for weight (can be low)", SQNR(W, DQ.cuda()) 318 | ) # fine to not match 319 | print( 320 | "SQNR for output with GPTQ (hopefully 35+)", 321 | torch.cat( 322 | [ 323 | SQNR(old.cpu(), new.cpu()).unsqueeze(0) 324 | for (old, new) in zip(old_out.values, new_out.values[:2]) 325 | ] 326 | ).mean(), 327 | ) 328 | 329 | qparams2 = self.get_qparams_func(W) 330 | Q2 = self.quantize_func(W, qparams2) 331 | DQ2 = self.dequantize_func(Q2, qparams2).to(W.dtype) 332 | old_q_out = self.call_function( 333 | target, (args[0][:2], DQ2, *args[2:]), kwargs, skip_quant=True 334 | ) 335 | 336 | print("SQNR for output without GPTQ (should be less than above)", 337 | torch.cat([ 338 | SQNR(old.cpu(), old_q.cpu()).unsqueeze(0) 339 | for (old, old_q) in zip(old_out.values, old_q_out.values) 340 | ]).mean(), 341 | ) 342 | return new_out 343 | 344 | return MultiInput(outputs) if has_multi_input else outputs[0] 345 | 346 | def faster_quant(self, H, W): 347 | percdamp = self.percdamp 348 | blocksize = self.blocksize 349 | groupsize = self.groupsize 350 | orig_dtype = W.dtype 351 | W = W.detach().float() 352 | rows, columns = W.shape[0], W.shape[1] 353 | device = W.device 354 | 355 | if groupsize == -1: 356 | cur_qparams = self.get_qparams_func(W) 357 | dead = torch.diag(H) == 0 358 | H[dead, dead] = 1 359 | W[:, dead] = 0 360 | 361 | Losses = torch.zeros_like(W) 362 | DQ = torch.zeros_like(W) 363 | 364 | damp = percdamp * torch.mean(torch.diag(H)) 365 | diag = torch.arange(columns, device=device) 366 | H[diag, diag] += damp 367 | H = torch.linalg.cholesky(H) 368 | H = torch.cholesky_inverse(H) 369 | H = torch.linalg.cholesky(H, upper=True) 370 | Hinv = H 371 | 372 | all_qparams = [] 373 | for i1 in range(0, columns, blocksize): 374 | i2 = min(i1 + blocksize, columns) 375 | count = i2 - i1 376 | W1 = W[:, i1:i2].clone() 377 | DQ1 = torch.zeros_like(W1) 378 | Err1 = torch.zeros_like(W1) 379 | Losses1 = torch.zeros_like(W1) 380 | Hinv1 = Hinv[i1:i2, i1:i2] 381 | for i in range(count): 382 | w = W1[:, i] 383 | d = Hinv1[i, i] 384 | 385 | if groupsize != -1 and (i1 + i) % groupsize == 0: # start of new group 386 | cur_qparams = self.get_qparams_func( 387 | W[:, (i1 + i) : (i1 + i + groupsize)] 388 | ) 389 | all_qparams.append(cur_qparams) 390 | 391 | q = self.quantize_func(w.unsqueeze(1), cur_qparams).flatten() 392 | dq = self.dequantize_func(q.unsqueeze(1), cur_qparams).flatten() 393 | 394 | DQ1[:, i] = dq 395 | Losses1[:, i] = (w - dq) ** 2 / d**2 396 | 397 | err1 = (w - dq) / d 398 | W1[:, i:] -= ( 399 | err1.to(Hinv1.dtype).unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) 400 | ) 401 | Err1[:, i] = err1 402 | 403 | DQ[:, i1:i2] = DQ1 404 | Losses[:, i1:i2] = Losses1 / 2 405 | 406 | W[:, i2:] -= Err1.to(Hinv.dtype).matmul(Hinv[i1:i2, i2:]) 407 | 408 | torch.cuda.synchronize() 409 | 410 | if all_qparams == []: 411 | all_qparams.append(cur_qparams) 412 | 413 | # convert a list of qparams objects into a single one. 
enerally by 414 | # concatenating a bunch of n,1 scale/zeros tensors into a n,num_groups tensor 415 | all_qparams = self.combine_qparams_list_func(all_qparams) 416 | Q = self.quantize_func(DQ, all_qparams) 417 | return Q, DQ.to(orig_dtype), all_qparams 418 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2023 Meta 2 | 3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 4 | 5 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 6 | 7 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 8 | 9 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 10 | 11 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # real-time inference demo for paligemma 2 | 3 | ### Setup and run: 4 | ``` 5 | pip3 install -r requirements.txt 6 | export MODEL_REPO=google/paligemma-3b-mix-224 7 | python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model_int8.pth --prompt "detect car" --vid_path cars.mp4 --vid_start 20 --vid_end 35 --max_new_tokens 10 8 | 9 | ``` 10 | 11 | # gpt-fast 12 | Simple and efficient pytorch-native transformer text generation. 13 | 14 | Featuring: 15 | 1. Very low latency 16 | 2. <1000 lines of python 17 | 3. No dependencies other than PyTorch and sentencepiece 18 | 4. int8/int4 quantization 19 | 5. Speculative decoding 20 | 6. Tensor parallelism 21 | 7. Supports Nvidia and AMD GPUs 22 | 23 | This is *NOT* intended to be a "framework" or "library" - it is intended to show off what kind of performance you can get with native PyTorch :) Please copy-paste and fork as you desire. 24 | 25 | For an in-depth walkthrough of what's in this codebase, see this [blog post](https://pytorch.org/blog/accelerating-generative-ai-2/). 26 | 27 | ## Supported Models 28 | 29 | ### LLaMA family 30 | Please check the rest of this page about benchmark of LLaMA family models. 
31 | 32 | ### Mixtral 8x7B 33 | We also support [Mixtral 8x7B](https://mistral.ai/news/mixtral-of-experts/), a high-quality sparse mixture-of-experts (MoE) model. The average token generation rates are: 34 | 35 | | | 1 GPU | 2 GPU | 4 GPU | 8 GPU | 36 | |------------------|---------|-----------|--------|------------| 37 | |baseline(bfloat16)| OOM | 78.75 | 118.23 | 203.69 | 38 | | int8 | 56.04 | 99.91 | 149.53 | 218.48 | 39 | 40 | Note that these benchmarks were run on an 8xA100-80GB, power limited to 330W with a hybrid cube mesh topology. Note that all benchmarks are run at *batch size=1*, making the reported tokens/s numbers equivalent to "tokens/s/user". In addition, they are run with a very small prompt length (just 5 tokens). 41 | 42 | For more details about Mixtral 8x7B, please check [this page](./mixtral-moe) or this [note](https://thonking.substack.com/p/short-supporting-mixtral-in-gpt-fast). 43 | 44 | ## Community 45 | 46 | Projects inspired by gpt-fast in the community: 47 | 48 | - [gpt-blazing](https://github.com/armed-gpt/gpt-blazing): applies the same performance optimization strategy to more models (e.g., baichuan2). 49 | - [gptfast](https://github.com/MDK8888/GPTFast): applies a subset of the performance optimizations to all Hugging Face models. 50 | 51 | ## Installation 52 | [Download PyTorch nightly](https://pytorch.org/get-started/locally/) 53 | Install sentencepiece and huggingface_hub: 54 | ```bash 55 | pip install sentencepiece huggingface_hub 56 | ``` 57 | 58 | To download the Llama models, go to https://huggingface.co/meta-llama/Llama-2-7b and go through the steps to obtain access. 59 | Then log in with `huggingface-cli login`. 60 | 61 | 62 | 63 | ## Downloading Weights 64 | Models tested/supported: 65 | ```text 66 | openlm-research/open_llama_7b 67 | meta-llama/Llama-2-7b-chat-hf 68 | meta-llama/Llama-2-13b-chat-hf 69 | meta-llama/Llama-2-70b-chat-hf 70 | codellama/CodeLlama-7b-Python-hf 71 | codellama/CodeLlama-34b-Python-hf 72 | ``` 73 | 74 | For example, to convert Llama-2-7b-chat-hf: 75 | ```bash 76 | export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf 77 | ./scripts/prepare.sh $MODEL_REPO 78 | ``` 79 | 80 | ## Benchmarks 81 | Benchmarks were run on an 8xA100-80GB, power limited to 330W with a hybrid cube mesh topology. Note that all benchmarks are run at *batch size=1*, making the reported tokens/s numbers equivalent to "tokens/s/user". In addition, they are run with a very small prompt length (just 5 tokens).
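The Memory Bandwidth column in the tables below is derived from the token rate the same way `generate.py` reports it: the bytes of weights (and buffers) that must be read per generated token, multiplied by tokens per second. A back-of-the-envelope sketch, with the ~7B parameter count and bfloat16 weight width used purely as illustrative assumptions:

```python
# Back-of-the-envelope version of the bandwidth figure generate.py prints
# ("Bandwidth achieved: {model_size * tokens_sec / 1e9} GB/s").
# The parameter count and dtype below are illustrative assumptions.
params = 7_000_000_000                   # roughly Llama-2-7B
bytes_per_param = 2                      # bfloat16 weights
model_size = params * bytes_per_param    # ~14e9 bytes read per generated token
tokens_sec = 104.9                       # base bf16 figure from the table below
print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
# -> 1468.60 GB/s, the same ballpark as the 1397.31 GB/s reported below
# (the real parameter count is somewhat under 7B).
```

This is also why the quantized rows trade lower bandwidth for higher tokens/s: fewer bytes have to be moved per token.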
82 | 83 | | Model | Technique | Tokens/Second | Memory Bandwidth (GB/s) | 84 | | -------- | ------- | ------ | ------ | 85 | | Llama-2-7B | Base | 104.9 | 1397.31 | 86 | | | 8-bit | 155.58 | 1069.20 | 87 | | | 4-bit (G=32) | 196.80 | 862.69 | 88 | | Llama-2-70B | Base | OOM || 89 | | | 8-bit | 19.13 | 1322.58 | 90 | | | 4-bit (G=32) | 25.25 | 1097.66 | 91 | 92 | ### Speculative Sampling 93 | [Verifier: Llama-70B (int4), Draft: Llama-7B (int4)](./scripts/speculate_70B_int4.sh): 48.4 tok/s 94 | 95 | ### Tensor Parallelism 96 | | Model | Number of GPUs | Tokens/Second | Memory Bandwidth (GB/s) | 97 | | -------- | ------- | ------ | ------ | 98 | | Llama-2-7B | 1 | 104.9 | 1397.31 | 99 | | | 2 | 168.84 | 1181.99 | 100 | | | 4 | 254.02 | 955.83 | 101 | | | 8 | 328.43 | 704.10 | 102 | | Llama-2-70B | 1 | OOM | | 103 | | | 2 | 21.32 | 1481.87 | 104 | | | 4 | 38.01 | 1340.76 | 105 | | | 8 | 62.50 | 1135.29 | 106 | 107 | ### Tensor Parallelism + Quantization 108 | | Model | Technique | Tokens/Second | Memory Bandwidth (GB/s) | 109 | | -------- | ------- | ------ | ------ | 110 | | Llama-2-70B | Base | 62.50 | 1135.29 | 111 | | | 8-bit | 80.44 | 752.04 | 112 | | | 4-bit (G=32) | 90.77 | 548.10 | 113 | 114 | ### AMD 115 | Benchmarks run on one GCD of a MI-250x. 116 | 117 | | Model | Technique | Tokens/Second | Memory Bandwidth (GB/s) | 118 | | -------- | ------- | ------ | ------ | 119 | | Llama-2-7B | Base | 76.33 | 1028.70 | 120 | | | 8-bit | 101.86 | 700.06 | 121 | 122 | ## Generate Text 123 | 124 | Model definition in `model.py`, generation code in `generate.py`. 125 | 126 | ```bash 127 | python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model.pth --prompt "Hello, my name is" 128 | ``` 129 | 130 | To squeeze out a little bit more performance, you can also compile the prefill with `--compile_prefill`. This will increase compilation times though. 131 | 132 | ## Quantization 133 | ### Int8 Weight-Only Quantization 134 | To generate this version of the model 135 | ```bash 136 | # Spits out model at checkpoints/$MODEL_REPO/model_int8.pth 137 | python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int8 138 | ``` 139 | To run with int8, just pass the int8 checkpoint to generate.py. 140 | ```bash 141 | python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model_int8.pth 142 | ``` 143 | 144 | ### Int4 Weight-Only Quantization 145 | To generate int4 version of model 146 | ```bash 147 | # Spits out model at checkpoints/$MODEL_REPO/model_int4.g32.pth 148 | python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4 --groupsize 32 149 | ``` 150 | 151 | To run with int4, just pass the int4 checkpoint to generate.py. 152 | ```bash 153 | python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.pth --compile 154 | ``` 155 | 156 | ## Speculative Sampling 157 | To generate with speculative sampling (DRAFT_MODEL_REPO should point to a smaller model compared with MODEL_REPO). 158 | 159 | In this example, the "smaller" model is just the int8 quantized version of the model. 160 | ``` 161 | export DRAFT_MODEL_REPO=meta-llama/Llama-2-7b-chat-hf 162 | python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model.pth --draft_checkpoint_path checkpoints/$DRAFT_MODEL_REPO/model_int8.pth 163 | ``` 164 | 165 | Note: Running on an A100 80GB, albeit power-limited to 330 watts. Empirically, seems like peak bandwidth is about 1700 GB/s. 
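The accept/reject logic behind this mode lives in `speculative_decode` in `generate.py` (shown later in this file): each of the `speculate_k` draft tokens is accepted with probability min(1, q/p), where p is the draft model's probability of the token and q is the target model's, and the first rejected position is resampled from the renormalized residual max(q - p, 0). A minimal, self-contained sketch of just that step, using made-up per-token probabilities rather than real model outputs:

```python
import torch

# Toy illustration of the accept/reject rule in speculative_decode()
# (generate.py). p and q below are made-up probabilities, not model outputs.
torch.manual_seed(0)
speculate_k = 5
p = torch.rand(speculate_k)  # draft model's probability of each drafted token
q = torch.rand(speculate_k)  # target model's probability of the same tokens

accept_prob = torch.minimum(torch.ones(()), q / p)       # accept w.p. min(1, q/p)
rejected = (torch.rand(speculate_k) > accept_prob).nonzero()

if rejected.shape[0] == 0:
    # All draft tokens accepted; the target model also samples one bonus token.
    accept_length = speculate_k
else:
    # Keep everything before the first rejection; that position is then
    # resampled from the residual distribution max(q - p, 0), renormalized.
    accept_length = rejected[0].item()

print(f"accepted {accept_length} of {speculate_k} draft tokens")
```

The overall speedup depends on how often draft tokens are accepted; `generate.py` prints the acceptance probabilities and mean accepted length at the end of a speculative run.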
166 | 167 | 168 | ## Tensor Parallelism 169 | ```bash 170 | ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=2 generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model.pth 171 | ``` 172 | 173 | ## Experimental 174 | ### Evaluation 175 | We use the EleutherAI evaluation harness to evaluate our model accuracy. To evaluate the accuracy, make sure the evaluation harness is installed and pass your model checkpoint and desired tasks to eval.py. 176 | 177 | ```bash 178 | python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --compile --tasks hellaswag winogrande 179 | ``` 180 | 181 | Note: Generative tasks are currently not supported for gpt-fast 182 | 183 | Installation Instructions for the evaluation harness: https://github.com/EleutherAI/lm-evaluation-harness/tree/master#install 184 | 185 | ### GPTQ 186 | We have a pure pytorch implementation of GPTQ that utilizes torch._dynamo.export to access the model structure. You can generate a GPTQ quantized 187 | version of int4 quantization by using the same command to quantize it but adding 'gptq' to the quantization mode i.e. 188 | ```bash 189 | # Spits out model at checkpoints/$MODEL_REPO/model_int4-gptq.g32.pth 190 | python quantize.py --mode int4-gptq --calibration_tasks wikitext --calibration_seq_length 2048 191 | ``` 192 | 193 | You can then eval or generate text with this model in the same way as above. 194 | 195 | ## License 196 | 197 | `gpt-fast` is released under the [BSD 3](https://github.com/pytorch-labs/gpt-fast/main/LICENSE) license. 198 | 199 | ## Acknowledgements 200 | Thanks to: 201 | * Lightning AI for supporting pytorch and work in flash attention, int8 quantization, and LoRA fine-tuning. 202 | * GGML for driving forward fast, on device inference of LLMs 203 | * Karpathy for spearheading simple, interpretable and fast LLM implementations 204 | * MLC-LLM for pushing 4-bit quantization performance on heterogeneous hardware 205 | -------------------------------------------------------------------------------- /download.py: -------------------------------------------------------------------------------- 1 | from pytube import YouTube 2 | 3 | def download_youtube_video(url, resolution="480p"): 4 | try: 5 | # Create a YouTube object 6 | yt = YouTube(url) 7 | 8 | # Get the stream with the specified resolution 9 | stream = yt.streams.filter(res=resolution, file_extension='mp4').first() 10 | 11 | # If no stream found with the specified resolution 12 | if not stream: 13 | print(f"No stream found with resolution {resolution}. Downloading the highest resolution available.") 14 | stream = yt.streams.get_highest_resolution() 15 | 16 | # Download the video 17 | stream.download() 18 | 19 | print(f"Video downloaded successfully: {yt.title}") 20 | 21 | except Exception as e: 22 | print(f"An error occurred: {e}") 23 | 24 | # Example usage 25 | url = "https://www.youtube.com/watch?v=A28zps9Q-gE" 26 | download_youtube_video(url) 27 | 28 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | import sys 7 | import time 8 | from pathlib import Path 9 | from typing import Optional 10 | 11 | import torch 12 | import torch._dynamo.config 13 | import torch._inductor.config 14 | 15 | torch._dynamo.config.automatic_dynamic_shapes = True 16 | torch._inductor.config.triton.unique_kernel_names = True 17 | torch._inductor.config.epilogue_fusion = False 18 | torch._inductor.config.triton.cudagraphs = True 19 | torch._dynamo.config.cache_size_limit = 100000 20 | 21 | from sentencepiece import SentencePieceProcessor 22 | 23 | from model import Transformer 24 | 25 | try: 26 | import lm_eval 27 | lm_eval_available = True 28 | except: 29 | lm_eval_available = False 30 | 31 | from generate import _load_model, encode_tokens, model_forward 32 | 33 | if lm_eval_available: 34 | try: # lm_eval version 0.4 35 | from lm_eval.models.huggingface import HFLM as eval_wrapper 36 | from lm_eval.tasks import get_task_dict 37 | from lm_eval.evaluator import evaluate 38 | except: #lm_eval version 0.3 39 | from lm_eval import base 40 | from lm_eval import tasks 41 | from lm_eval import evaluator 42 | eval_wrapper=base.BaseLM 43 | get_task_dict=tasks.get_task_dict 44 | evaluate=evaluator.evaluate 45 | 46 | 47 | def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill( 48 | model: Transformer, 49 | prompt: torch.Tensor, 50 | max_new_tokens: int, 51 | max_seq_length: Optional[int] = None, 52 | ): 53 | """ 54 | Sets up model cache and does some bookkeeping calculations for prompt, input_pos and max_seq_length 55 | that are needed for prefill or model_forward 56 | 57 | Args: 58 | model (LLaMA): The model whose cache gets set up 59 | prompt (torch.Tensor): Tensor of shape (T) with indices of the prompt sequence. 60 | max_new_tokens (int): The desired maximum number of new tokens that can be generated. 61 | max_seq_length (Optional[int], optional): The maximum sequence length allowed. 62 | 63 | Returns: 64 | seq (torch.Tensor): prompt but padded with zeros to size max_seq_length 65 | input_pos (torch.Tensor): tensor of integers in increasing order 66 | max_seq_length (int): The maximum sequence length allowed, updated based on other numbers 67 | """ 68 | T = prompt.size(0) 69 | T_new = T + max_new_tokens 70 | if max_seq_length is None: 71 | max_seq_length = min(T_new, model.config.block_size) 72 | 73 | device, dtype = prompt.device, prompt.dtype 74 | # create an empty tensor of the expected final shape and fill in the current tokens 75 | empty = torch.empty(T_new, dtype=dtype, device=device) 76 | empty[:T] = prompt 77 | seq = empty 78 | input_pos = torch.arange(0, T, device=device) 79 | 80 | with torch.device(device): 81 | model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) 82 | 83 | return seq, input_pos, max_seq_length 84 | 85 | class GPTFastEvalWrapper(eval_wrapper): 86 | """ 87 | A wrapper class for GPTFast, providing integration with the lm-evaluation-harness library. 
88 | """ 89 | def __init__( 90 | self, 91 | model: Transformer, 92 | tokenizer, 93 | max_seq_length: Optional[int]=None, 94 | ): 95 | super().__init__() 96 | self._model = model 97 | self._tokenizer = tokenizer 98 | self._device = torch.device('cuda') 99 | self._max_seq_length = 2048 if max_seq_length is None else max_seq_length 100 | 101 | @property 102 | def eot_token_id(self): 103 | return self._tokenizer.eos_id() 104 | 105 | @property 106 | def max_length(self): 107 | return self._max_seq_length 108 | 109 | @property 110 | def max_gen_toks(self): 111 | return 50 112 | 113 | @property 114 | def batch_size(self): 115 | return 1 116 | 117 | @property 118 | def device(self): 119 | return self._device 120 | 121 | def tok_encode(self, string: str, **kwargs): 122 | encoded = encode_tokens(self._tokenizer, 123 | string, bos=True, device=self._device) 124 | # encoded is a pytorch tensor, but some internal logic in the 125 | # eval harness expects it to be a list instead 126 | # TODO: verify this for multi-batch as well 127 | encoded = encoded.tolist() 128 | return encoded 129 | 130 | def tok_decode(self, tokens): 131 | decoded = self._tokenizer.decode(tokens) 132 | return decoded 133 | 134 | def _model_call(self, inps): 135 | # TODO: make batches work 136 | inps = inps.squeeze(0) 137 | 138 | max_new_tokens = 1 139 | seq, input_pos, max_seq_length = \ 140 | setup_cache_padded_seq_input_pos_max_seq_length_for_prefill( 141 | self._model, 142 | inps, 143 | max_new_tokens, 144 | self.max_length, 145 | ) 146 | x = seq.index_select(0, input_pos).view(1, -1) 147 | logits = model_forward(self._model, x, input_pos) 148 | return logits 149 | 150 | def _model_generate(self, context, max_length, eos_token_id): 151 | raise Exception('unimplemented') 152 | 153 | 154 | @torch.no_grad() 155 | def eval( 156 | model: Transformer, 157 | tokenizer, 158 | tasks: list = ["hellaswag"], 159 | limit: Optional[int] = None, 160 | max_seq_length: Optional[int] = None, 161 | ) -> dict: 162 | """ 163 | Evaluates a language model on a specified task using the lm-evaluation-harness library. 164 | 165 | Args: 166 | model (Transformer): The pre-trained language model to evaluate. 167 | tokenizer: The tokenizer to use for encoding/decoding text. 168 | task (str): The name of the evaluation task to perform. 169 | limit (Optional[int]): The maximum number of samples to evaluate (None for all available). 170 | max_seq_length (Optional[int]): The maximum sequence length allowed for input text. 171 | 172 | Returns: 173 | eval_results (dict): A dictionary of evaluation results for the specified task(s). 174 | """ 175 | model_eval_wrapper = GPTFastEvalWrapper( 176 | model, 177 | tokenizer, 178 | max_seq_length, 179 | ) 180 | 181 | try: 182 | lm_eval.tasks.initialize_tasks() 183 | except: 184 | pass 185 | 186 | if 'hendrycks_test' in tasks: 187 | tasks.remove('hendrycks_test') 188 | tasks += [x for x in lm_eval.tasks.hendrycks_test.create_all_tasks().keys()] 189 | task_dict = get_task_dict(tasks) 190 | 191 | eval_results = evaluate( 192 | model_eval_wrapper, 193 | task_dict, 194 | limit=limit, 195 | ) 196 | return eval_results 197 | 198 | 199 | def main( 200 | checkpoint_path: Path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/lit_model.pth"), 201 | compile: bool = False, 202 | tasks: list = ["hellaswag"], 203 | limit: Optional[int] = None, 204 | max_seq_length: Optional[int] = None, 205 | ) -> None: 206 | """Evaluates model on a task from the `lm-evaluation-harness` library. 
207 | 208 | Args: 209 | checkpoint_path (Path): The path to the model checkpoint file to load. 210 | compile (bool): Whether or not to compile the model for optimization. 211 | task (Optional[str]): The name of the evaluation task or a list of tasks to perform. 212 | limit (Optional[int]): The maximum number of samples to evaluate (None for all available). 213 | max_seq_length (Optional[int]): The maximum sequence length allowed for input text. 214 | 215 | """ 216 | 217 | assert checkpoint_path.is_file(), checkpoint_path 218 | 219 | tokenizer_path = checkpoint_path.parent / "tokenizer.model" 220 | assert tokenizer_path.is_file(), tokenizer_path 221 | 222 | device = 'cuda' 223 | precision = torch.bfloat16 224 | 225 | print("Loading model ...") 226 | t0 = time.time() 227 | model = _load_model(checkpoint_path, device, precision, False) 228 | 229 | torch.cuda.synchronize() 230 | print(f"Time to load model: {time.time() - t0:.02f} seconds.") 231 | 232 | model.eval() 233 | 234 | tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path)) 235 | 236 | torch.manual_seed(1234) 237 | 238 | if compile: 239 | global model_forward 240 | model_forward = torch.compile(model_forward, mode="reduce-overhead", dynamic=True, fullgraph=True) 241 | torch._inductor.config.coordinate_descent_tuning = True 242 | 243 | t1 = time.time() 244 | result = eval( 245 | model, 246 | tokenizer, 247 | tasks, 248 | limit, 249 | max_seq_length, 250 | ) 251 | print(f"Time to run eval: {time.time() - t1:.02f} seconds.") 252 | print(f"For model {checkpoint_path}") 253 | for task, res in result["results"].items(): 254 | print(f"{task}: {res}") 255 | 256 | 257 | if __name__ == '__main__': 258 | import argparse 259 | parser = argparse.ArgumentParser(description='Your CLI description.') 260 | 261 | parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/lit_model.pth"), help='Model checkpoint path.') 262 | parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') 263 | parser.add_argument('--tasks', nargs='+', type=str, default=["hellaswag"], help='list of lm-eluther tasks to evaluate usage: --tasks task1 task2') 264 | parser.add_argument('--limit', type=int, default=None, help='number of samples to evalulate') 265 | parser.add_argument('--max_seq_length', type=int, default=None, help='maximum length sequence to evaluate') 266 | 267 | args = parser.parse_args() 268 | main( 269 | Path(args.checkpoint_path), args.compile, args.tasks, args.limit, args.max_seq_length, 270 | ) 271 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | import itertools 7 | import sys 8 | import time 9 | from pathlib import Path 10 | from typing import Optional, Tuple 11 | 12 | import torch 13 | import torch._dynamo.config 14 | import torch._inductor.config 15 | 16 | def device_sync(device): 17 | if "cuda" in device: 18 | torch.cuda.synchronize() 19 | elif "cpu" in device: 20 | pass 21 | else: 22 | print(f"device={device} is not yet suppported") 23 | 24 | from transformers import AutoProcessor, PaliGemmaForConditionalGeneration 25 | from PIL import Image, ImageDraw, ImageFont 26 | import requests 27 | import torch 28 | import numpy as np 29 | import cv2 30 | 31 | torch._inductor.config.coordinate_descent_tuning = True 32 | torch._inductor.config.triton.unique_kernel_names = True 33 | torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future 34 | 35 | # support running without installing as a package 36 | wd = Path(__file__).parent.parent.resolve() 37 | sys.path.append(str(wd)) 38 | 39 | from sentencepiece import SentencePieceProcessor 40 | 41 | from model import Transformer 42 | 43 | 44 | def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization 45 | q = torch.empty_like(probs_sort).exponential_(1) 46 | return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) 47 | 48 | def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None): 49 | logits = logits / max(temperature, 1e-5) 50 | 51 | if top_k is not None: 52 | v, _ = torch.topk(logits, min(top_k, logits.size(-1))) 53 | pivot = v.select(-1, -1).unsqueeze(-1) 54 | logits = torch.where(logits < pivot, -float("Inf"), logits) 55 | probs = torch.nn.functional.softmax(logits, dim=-1) 56 | return probs 57 | 58 | def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None): 59 | #logits[0, -1, 1] = -10 60 | 61 | probs = logits_to_probs(logits[0, -1], temperature, top_k) 62 | idx_next = multinomial_sample_one_no_sync(probs) 63 | idx_next = torch.tensor([torch.argmax(logits[0, -1])]).to('cuda:0') 64 | 65 | return idx_next, probs 66 | 67 | def prefill(model: Transformer, x: torch.Tensor, embeds: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor: 68 | # input_pos: [B, S] 69 | logits = model(x, input_pos, embeds=embeds) 70 | return sample(logits, **sampling_kwargs)[0] 71 | 72 | def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]: 73 | # input_pos: [B, 1] 74 | assert input_pos.shape[-1] == 1 75 | logits = model(x, input_pos) 76 | return sample(logits, **sampling_kwargs) 77 | 78 | def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs): 79 | new_tokens, new_probs = [], [] 80 | for i in range(num_new_tokens): 81 | with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here 82 | next_token, next_prob = decode_one_token( 83 | model, cur_token, input_pos, **sampling_kwargs 84 | ) 85 | input_pos += 1 86 | new_tokens.append(next_token.clone()) 87 | callback(new_tokens[-1]) 88 | new_probs.append(next_prob.clone()) 89 | cur_token = next_token.view(1, -1) 90 | 91 | return new_tokens, new_probs 92 | 93 | 94 | def model_forward(model, x, input_pos): 95 | return model(x, input_pos) 96 | 97 | def speculative_decode( 98 | model: 
Transformer, 99 | draft_model: Transformer, 100 | cur_token: torch.Tensor, 101 | input_pos: int, 102 | speculate_k: int, 103 | **sampling_kwargs 104 | ) -> torch.Tensor: 105 | # draft model inference sequentially 106 | device = cur_token.device 107 | orig_input_pos = torch.tensor([input_pos], dtype=torch.int64, device=cur_token.device) 108 | draft_tokens, draft_probs = decode_n_tokens(draft_model, cur_token.view(1, -1), orig_input_pos.clone(), speculate_k, **sampling_kwargs) 109 | 110 | draft_tokens = torch.cat(draft_tokens) 111 | # parallel inference on target model using draft tokens 112 | target_logits = model_forward( 113 | model, 114 | torch.cat([cur_token.view(1), draft_tokens]).view(1, -1), 115 | torch.arange(input_pos, input_pos + speculate_k + 1, device=cur_token.device) 116 | ) 117 | target_probs = logits_to_probs(target_logits[0], **sampling_kwargs) 118 | draft_probs = torch.stack(draft_probs) 119 | # q: target prob, p: draft prob 120 | # q >= p: always accept draft token 121 | # q < p: q/p prob to accept draft token 122 | p = draft_probs[torch.arange(0, speculate_k, device=device), draft_tokens] 123 | q = target_probs[torch.arange(0, speculate_k, device=device), draft_tokens] 124 | accept_draft_prob = torch.minimum(torch.ones(()), q[:speculate_k]/ p) 125 | rejected_locations = (torch.rand_like(accept_draft_prob) > accept_draft_prob).nonzero() 126 | 127 | if rejected_locations.shape[0] == 0: # All draft tokens have been accepted 128 | accept_length = speculate_k + 1 129 | last_token = multinomial_sample_one_no_sync(target_probs[-1]) 130 | # fill last token into draft model 131 | model_forward( 132 | draft_model, 133 | draft_tokens[-1].view(1, -1), 134 | orig_input_pos + speculate_k, 135 | ) 136 | return torch.cat([draft_tokens, last_token]) 137 | else: 138 | accept_length = rejected_locations[0].item() 139 | p = draft_probs[accept_length] 140 | q = target_probs[accept_length] 141 | new = q - p 142 | new = torch.where(new > 0, new, 0.0) 143 | new = new / new.sum() 144 | next_token = multinomial_sample_one_no_sync(new) 145 | return torch.cat([draft_tokens[:accept_length], next_token]) 146 | 147 | @torch.no_grad() 148 | def generate( 149 | model: Transformer, 150 | prompt: torch.Tensor, 151 | embeds: torch.Tensor, 152 | max_new_tokens: int, 153 | 154 | *, 155 | interactive: bool, 156 | draft_model: Transformer, 157 | speculate_k: Optional[int] = 8, 158 | callback = lambda x: x, 159 | **sampling_kwargs 160 | ) -> torch.Tensor: 161 | """ 162 | Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. 
163 | """ 164 | 165 | is_speculative = draft_model is not None 166 | # create an empty tensor of the expected final shape and fill in the current tokens 167 | T = prompt.size(0) 168 | T_new = T + max_new_tokens 169 | if interactive: 170 | max_seq_length = 350 171 | else: 172 | max_seq_length = min(T_new, model.config.block_size) 173 | 174 | device, dtype = prompt.device, prompt.dtype 175 | max_seq_length = max_seq_length + speculate_k + 1 if is_speculative else max_seq_length 176 | with torch.device(device): 177 | model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) 178 | if is_speculative and draft_model is not model: 179 | draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) 180 | 181 | # create an empty tensor of the expected final shape and fill in the current tokens 182 | empty = torch.empty(T_new, dtype=dtype, device=device) 183 | empty[:T] = prompt 184 | seq = empty 185 | input_pos = torch.arange(0, T, device=device) 186 | 187 | print("prefill") 188 | next_token = prefill(model, prompt.view(1, -1), embeds, input_pos, **sampling_kwargs) 189 | if is_speculative: 190 | prefill(draft_model, prompt.view(1, -1), input_pos, **sampling_kwargs) 191 | seq[T] = next_token 192 | 193 | input_pos = torch.tensor([T], device=device, dtype=torch.int) 194 | accept_counts = [0] * (speculate_k + 1) 195 | 196 | if is_speculative: 197 | input_pos = input_pos.item() # for speculative decoding easier to keep on host 198 | while input_pos < T_new - 1: 199 | cur_token = next_token.view(()) 200 | 201 | next_tokens = speculative_decode( 202 | model, draft_model, cur_token, input_pos, speculate_k, **sampling_kwargs 203 | ) 204 | 205 | accept_counts[len(next_tokens) - 1] += 1 206 | num_added = min(T_new - input_pos - 1, len(next_tokens)) 207 | seq[input_pos + 1 : input_pos + num_added + 1] = next_tokens[: num_added] 208 | for i in next_tokens[: num_added,]: 209 | callback(i) 210 | input_pos = input_pos + num_added 211 | next_token = next_tokens[-1] 212 | else: 213 | generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs) 214 | seq[T + 1:] = torch.cat(generated_tokens) 215 | 216 | generate_stats = { 217 | 'accept_counts': accept_counts 218 | } 219 | return seq, generate_stats 220 | 221 | def encode_tokens(tokenizer, string, bos=True, device='cuda'): 222 | tokens = tokenizer.encode(string) 223 | if bos: 224 | tokens = [tokenizer.bos_id()] + tokens 225 | return torch.tensor(tokens, dtype=torch.int, device=device) 226 | 227 | def _load_model(checkpoint_path, device, precision, use_tp): 228 | with torch.device('meta'): 229 | model = Transformer.from_name(checkpoint_path.parent.name) 230 | 231 | if "int8" in str(checkpoint_path): 232 | print("Using int8 weight-only quantization!") 233 | from quantize import WeightOnlyInt8QuantHandler 234 | simple_quantizer = WeightOnlyInt8QuantHandler(model) 235 | model = simple_quantizer.convert_for_runtime() 236 | 237 | if "int4" in str(checkpoint_path): 238 | print("Using int4 quantization!") 239 | path_comps = checkpoint_path.name.split(".") 240 | assert path_comps[-2].startswith("g") 241 | groupsize = int(path_comps[-2][1:]) 242 | from quantize import WeightOnlyInt4QuantHandler 243 | simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize) 244 | model = simple_quantizer.convert_for_runtime() 245 | 246 | checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) 247 | model.load_state_dict(checkpoint, assign=True) 248 | 249 | if use_tp: 250 
| from tp import apply_tp 251 | print("Applying tensor parallel to model ...") 252 | apply_tp(model) 253 | 254 | #print(model.get_tok_embeddings().bias) 255 | model = model.to(device=device, dtype=torch.bfloat16) 256 | return model.eval() 257 | 258 | B_INST, E_INST = "[INST]", "[/INST]" 259 | 260 | def main( 261 | prompt: str = "Hello, my name is", 262 | vid_path: str = "", 263 | vid_start: int = 1, 264 | vid_end: int = 2, 265 | interactive: bool = False, 266 | max_new_tokens: int = 100, 267 | top_k: int = 200, 268 | temperature: float = 0.0, 269 | checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), 270 | compile: bool = True, 271 | compile_prefill: bool = False, 272 | profile: Optional[Path] = None, 273 | draft_checkpoint_path: Optional[Path] = None, 274 | speculate_k: int = 5, 275 | device='cuda', 276 | ) -> None: 277 | 278 | 279 | """Generates text samples based on a pre-trained Transformer model and tokenizer. 280 | """ 281 | assert checkpoint_path.is_file(), checkpoint_path 282 | 283 | tokenizer_path = checkpoint_path.parent / "tokenizer.model" 284 | assert tokenizer_path.is_file(), tokenizer_path 285 | 286 | global print 287 | from tp import maybe_init_dist 288 | rank = maybe_init_dist() 289 | use_tp = rank is not None 290 | if use_tp: 291 | if rank != 0: 292 | # only print on rank 0 293 | print = lambda *args, **kwargs: None 294 | 295 | print(f"Using device={device}") 296 | precision = torch.bfloat16 297 | is_speculative = draft_checkpoint_path is not None 298 | is_chat = "chat" in str(checkpoint_path) 299 | 300 | print("Loading model ...") 301 | t0 = time.time() 302 | model = _load_model(checkpoint_path, device, precision, use_tp) 303 | 304 | if is_speculative: 305 | draft_model = _load_model(draft_checkpoint_path, device, precision, use_tp) 306 | else: 307 | draft_model = None 308 | 309 | ### EDIT 310 | 311 | model_id = "google/paligemma-3b-mix-224" 312 | device = "cuda:0" 313 | dtype = torch.bfloat16 314 | 315 | #url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true" 316 | #image = Image.open("sidewalk.jpg") 317 | 318 | #url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true" 319 | #image = Image.open(requests.get(url, stream=True).raw) 320 | 321 | _model = PaliGemmaForConditionalGeneration.from_pretrained( 322 | checkpoint_path.parent, 323 | torch_dtype=dtype, 324 | device_map=device, 325 | revision="bfloat16", 326 | ).eval() 327 | 328 | vision_model = _model.vision_tower 329 | projector = _model.multi_modal_projector 330 | processor = AutoProcessor.from_pretrained(model_id) 331 | 332 | # Instruct the model to create a caption in Spanish 333 | #model_inputs = processor(text=prompt, images=image, return_tensors="pt").to('cuda:0') 334 | 335 | 336 | #input_len = model_inputs["input_ids"].shape[-1] 337 | 338 | device_sync(device=device) # MKG 339 | print(f"Time to load model: {time.time() - t0:.02f} seconds") 340 | 341 | cap = cv2.VideoCapture(vid_path) 342 | fps = cap.get(cv2.CAP_PROP_FPS) 343 | #frame_interval = int(fps // 8) 344 | frames = [] 345 | 346 | 347 | for i in range(int(fps * vid_end)): 348 | ret, frame = cap.read() 349 | 350 | if i > fps * vid_start: 351 | if not ret: 352 | break 353 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 354 | pil_frame = Image.fromarray(frame) 355 | frames.append(pil_frame) 356 | 357 | cap.release() 358 | 359 | out = cv2.VideoWriter('output_video.mp4', 
cv2.VideoWriter_fourcc(*'mp4v'), fps, (frame.shape[1], frame.shape[0])) 360 | 361 | 362 | #tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path)) 363 | #encoded = encode_tokens(tokenizer, prompt, bos=True, device=device) 364 | 365 | #encoded = model_inputs['input_ids'][0] 366 | #prompt_length = encoded.size(0) 367 | 368 | #torch.manual_seed(1234) 369 | model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())]) 370 | if compile: 371 | if is_speculative and use_tp: # and ("cuda" in device): 372 | torch._inductor.config.triton.cudagraph_trees = False # Bug with cudagraph trees in this case 373 | 374 | if is_speculative: 375 | global model_forward, logits_to_prob 376 | model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True) 377 | 378 | global decode_one_token, prefill 379 | decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True) 380 | 381 | # Uncomment to squeeze more perf out of prefill 382 | if args.compile_prefill: 383 | prefill = torch.compile(prefill, fullgraph=True, dynamic=True) 384 | 385 | 386 | aggregate_metrics = { 387 | 'tokens_per_sec': [], 388 | 'accept_counts': [], 389 | } 390 | start = -1 if compile else 0 391 | 392 | embed = model.get_tok_embeddings() 393 | 394 | """ 395 | 396 | embedding_values = embed(encoded) 397 | #print(embedding_values) 398 | 399 | img_embed = projector(vision_model(model_inputs.pixel_values.to(dtype=torch.bfloat16)).last_hidden_state) 400 | 401 | img_embed = img_embed / (2048 ** 0.5) 402 | 403 | print(embedding_values.shape) 404 | 405 | embedding_values[:256, :] = img_embed[0] 406 | 407 | embedding_values = embedding_values.unsqueeze(0) 408 | """ 409 | 410 | model_fps = 16 411 | 412 | #print(len(frames)) 413 | bounding_boxes = [] 414 | 415 | 416 | for i, frame in enumerate(frames): 417 | 418 | if i % 2== 0: 419 | 420 | model_inputs = processor(text=prompt, images=frame, return_tensors="pt").to('cuda:0') 421 | encoded = model_inputs['input_ids'][0] 422 | prompt_length = encoded.size(0) 423 | 424 | embedding_values = embed(encoded) 425 | 426 | img_embed = projector(vision_model(model_inputs.pixel_values.to(dtype=torch.bfloat16)).last_hidden_state) 427 | 428 | img_embed = img_embed / (2048 ** 0.5) 429 | #print(embedding_values.shape) 430 | 431 | embedding_values[:256, :] = img_embed[0] 432 | 433 | embedding_values = embedding_values.unsqueeze(0) 434 | 435 | #exit(0) 436 | 437 | #exit(0) 438 | device_sync(device=device) # MKG 439 | if i >= 0 and interactive: 440 | prompt = input("What is your prompt? 
") 441 | if is_chat: 442 | prompt = f"{B_INST} {prompt.strip()} {E_INST}" 443 | encoded = encode_tokens(tokenizer, prompt, bos=True, device=device) 444 | 445 | if interactive and i >= 0: 446 | buffer = [] 447 | period_id = tokenizer.encode('.')[0] 448 | done_generating = False 449 | def callback(x): 450 | nonlocal done_generating 451 | if done_generating: 452 | return 453 | buffer.append(tokenizer.decode([period_id] + x.tolist())[1:]) 454 | if x.item() == tokenizer.eos_id(): 455 | done_generating = True 456 | if len(buffer) == 4 or done_generating: 457 | print(''.join(buffer), end='', flush=True) 458 | buffer.clear() 459 | # print(, end='', flush=True) 460 | else: 461 | callback = lambda x : x 462 | t0 = time.perf_counter() 463 | import contextlib 464 | prof = contextlib.nullcontext() 465 | 466 | with prof: 467 | y, metrics = generate( 468 | model, 469 | encoded, 470 | embedding_values, 471 | max_new_tokens, 472 | draft_model=draft_model, 473 | speculate_k=speculate_k, 474 | interactive=interactive, 475 | callback=callback, 476 | temperature=temperature, 477 | top_k=top_k, 478 | ) 479 | aggregate_metrics['accept_counts'].append(metrics['accept_counts']) 480 | if i == -1: 481 | print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") 482 | continue 483 | if hasattr(prof, "export_chrome_trace"): 484 | if use_tp: 485 | prof.export_chrome_trace(f"{profile}_rank_{rank}.json") 486 | else: 487 | prof.export_chrome_trace(f"{profile}.json") 488 | device_sync(device=device) # MKG 489 | t = time.perf_counter() - t0 490 | 491 | if not interactive: 492 | #print(y) 493 | print(processor.decode(y, skip_special_tokens=True)) 494 | #print(tokenizer.decode(y.tolist())) 495 | else: 496 | print() 497 | 498 | decoded_output = processor.decode(y, skip_special_tokens=True) 499 | tokens_generated = y.size(0) - prompt_length 500 | tokens_sec = tokens_generated / t 501 | 502 | 503 | aggregate_metrics['tokens_per_sec'].append(tokens_sec) 504 | 505 | new_model_fps = int(1 / t) 506 | if new_model_fps != model_fps: 507 | model_fps=new_model_fps 508 | print(f"Model fps {new_model_fps}") 509 | print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec") 510 | print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s") 511 | 512 | print(processor.decode(y, skip_special_tokens=True)) 513 | if ';' not in decoded_output and ('loc' in decoded_output): 514 | locations = [int(loc.replace('loc', '').replace('<', '').replace('>', '').replace('detect car\n', '').replace('car', '')) for loc in decoded_output.split("><") if 'loc' in loc] 515 | else: 516 | locations = [] 517 | if len(locations) > 0: 518 | bounding_boxes = [] 519 | 520 | if len(locations) > 0: 521 | # Convert locations to bounding boxes 522 | bounding_boxes.append(locations[0]) 523 | 524 | bounding_boxes.append(locations[1]) 525 | bounding_boxes.append(locations[2]) 526 | 527 | bounding_boxes.append(locations[3]) 528 | 529 | def convert_bbox(bbox, original_size=(1024, 1024), target_size=(480, 854)): 530 | """ 531 | Convert bounding box coordinates from the original resolution to the target resolution. 532 | 533 | Parameters: 534 | bbox (tuple): A tuple (x1, y1, x2, y2) representing the bounding box coordinates in the original resolution. 535 | original_size (tuple): A tuple (width, height) representing the original resolution. 536 | target_size (tuple): A tuple (width, height) representing the target resolution. 
537 | 538 | Returns: 539 | tuple: A tuple (x1, y1, x2, y2) representing the bounding box coordinates in the target resolution. 540 | """ 541 | original_width, original_height = original_size 542 | target_width, target_height = target_size 543 | 544 | x1, y1, x2, y2 = bbox 545 | 546 | x1 = int(x1 * target_width / original_width) 547 | y1 = int(y1 * target_height / original_height) 548 | x2 = int(x2 * target_width / original_width) 549 | y2 = int(y2 * target_height / original_height) 550 | 551 | return (x1, y1, x2, y2) 552 | 553 | bounding_boxes = convert_bbox(bounding_boxes) 554 | bounding_boxes = [bounding_boxes[1], bounding_boxes[0], bounding_boxes[3], bounding_boxes[2]] 555 | 556 | # Draw bounding boxes on the frame if locations are detected 557 | if bounding_boxes: 558 | draw = ImageDraw.Draw(frame) 559 | #font = ImageFont.truetype("arial.ttf", 20) # Adjust the font and size as needed 560 | 561 | draw.rectangle(bounding_boxes, outline="lime", width=3) 562 | text_position = (bounding_boxes[2] - 5, bounding_boxes[3] - 5) 563 | draw.text(text_position, "car", fill="lime", font_size=30) 564 | 565 | frame = cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR) 566 | out.write(frame) 567 | 568 | 569 | out.release() 570 | print("Video saved as output_video.mp4") 571 | print("==========") 572 | 573 | if is_speculative: 574 | counts_aggregated = [sum(i) for i in zip(*aggregate_metrics['accept_counts'])] 575 | acceptance_probs = [i/sum(counts_aggregated) for i in counts_aggregated] 576 | print(f"Acceptance probs: {acceptance_probs}") 577 | print(f"Mean Accepted: {sum([idx * i for idx, i in enumerate(counts_aggregated)])/sum(counts_aggregated)}") 578 | 579 | print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}") 580 | print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") 581 | 582 | 583 | 584 | if __name__ == '__main__': 585 | import argparse 586 | parser = argparse.ArgumentParser(description='Your CLI description.') 587 | 588 | ### NEW PARAMS 589 | 590 | parser.add_argument('--prompt', type=str, default="detect car", help='Input prompt.') 591 | parser.add_argument('--vid_path', type=str, default="", help='path to mp4 video.') 592 | parser.add_argument('--vid_start', type=int, default=0, help='Where in video to start detecting (seconds).') 593 | parser.add_argument('--vid_end', type=int, default=10, help='Where in video to end detecting (seconds).') 594 | 595 | ### OLD PARAMS 596 | 597 | parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode') 598 | parser.add_argument('--max_new_tokens', type=int, default=10, help='Maximum number of new tokens.') 599 | parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.') 600 | parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.') 601 | parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), help='Model checkpoint path.') 602 | parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') 603 | parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)') 604 | parser.add_argument('--profile', type=Path, default=None, help='Profile path.') 605 | parser.add_argument('--speculate_k', type=int, default=5, help='Speculative execution depth.') 606 | parser.add_argument('--draft_checkpoint_path', 
type=Path, default=None, help='Draft checkpoint path.') 607 | parser.add_argument('--device', type=str, default="cuda", help='device to use') 608 | 609 | 610 | args = parser.parse_args() 611 | 612 | main( 613 | args.prompt, args.vid_path, args.vid_start, args.vid_end, args.interactive, args.max_new_tokens, args.top_k, 614 | args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.draft_checkpoint_path, 615 | args.speculate_k, args.device 616 | ) 617 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | from dataclasses import dataclass 7 | from typing import Optional 8 | 9 | import torch 10 | import torch.nn as nn 11 | from torch import Tensor 12 | from torch.nn import functional as F 13 | 14 | 15 | def find_multiple(n: int, k: int) -> int: 16 | if n % k == 0: 17 | return n 18 | return n + k - (n % k) 19 | 20 | @dataclass 21 | class ModelArgs: 22 | block_size: int = 8192 23 | vocab_size: int = 32000 24 | n_layer: int = 32 25 | n_head: int = 32 26 | dim: int = 4096 27 | intermediate_size: int = None 28 | n_local_heads: int = -1 29 | head_dim: int = None 30 | rope_base: float = 10000 31 | norm_eps: float = 1e-6 32 | 33 | def __post_init__(self): 34 | if self.n_local_heads == -1: 35 | self.n_local_heads = self.n_head 36 | if self.intermediate_size is None: 37 | hidden_dim = 4 * self.dim 38 | n_hidden = int(2 * hidden_dim / 3) 39 | self.intermediate_size = find_multiple(n_hidden, 256) 40 | if self.head_dim is None: 41 | self.head_dim = self.dim // self.n_head 42 | 43 | @classmethod 44 | def from_name(cls, name: str): 45 | if name in transformer_configs: 46 | return cls(**transformer_configs[name]) 47 | # fuzzy search 48 | config = [config for config in transformer_configs if config in str(name).upper() or config in str(name)] 49 | assert len(config) == 1, name 50 | return cls(**transformer_configs[config[0]]) 51 | 52 | 53 | transformer_configs = { 54 | "gemma-2b": dict(dim=2048, vocab_size=256000, n_layer=18, n_head=8, n_local_heads=1, intermediate_size=16384), 55 | 56 | "paligemma-3b-mix-448": dict(block_size = 8192, dim=2048, vocab_size=257216, n_layer=18, n_head=8, n_local_heads=1, intermediate_size=16384, rope_base=10000), 57 | 58 | "paligemma-3b-mix-224": dict(block_size = 8192, dim=2048, vocab_size=257216, n_layer=18, n_head=8, n_local_heads=1, intermediate_size=16384, rope_base=10000), 59 | "gemma-7b": dict(dim=3072, vocab_size=256000, n_layer=28, n_head=16, n_local_heads=16, intermediate_size=24576, head_dim=256), 60 | "CodeLlama-7b-Python-hf": dict(block_size=16384, vocab_size=32000, n_layer=32, dim = 4096, rope_base=1000000), 61 | "7B": dict(n_layer=32, n_head=32, dim=4096), 62 | "13B": dict(n_layer=40, n_head=40, dim=5120), 63 | "30B": dict(n_layer=60, n_head=52, dim=6656), 64 | "34B": dict(n_layer=48, n_head=64, dim=8192, vocab_size=32000, n_local_heads=8, intermediate_size=22016, rope_base=1000000), # CodeLlama-34B-Python-hf 65 | "70B": dict(n_layer=80, n_head=64, dim=8192, n_local_heads=8, intermediate_size=28672), 66 | } 67 | 68 | class KVCache(nn.Module): 69 | def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16): 70 | super().__init__() 71 | cache_shape = 
(max_batch_size, n_heads, max_seq_length, head_dim) 72 | self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype)) 73 | self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype)) 74 | 75 | def update(self, input_pos, k_val, v_val): 76 | # input_pos: [S], k_val: [B, H, S, D] 77 | assert input_pos.shape[0] == k_val.shape[2] 78 | 79 | k_out = self.k_cache 80 | v_out = self.v_cache 81 | k_out[:, :, input_pos] = k_val 82 | v_out[:, :, input_pos] = v_val 83 | 84 | return k_out, v_out 85 | 86 | class Transformer(nn.Module): 87 | def __init__(self, config: ModelArgs) -> None: 88 | super().__init__() 89 | self.config = config 90 | 91 | self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim, padding_idx=0) 92 | self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer)) 93 | self.norm = RMSNorm(config.dim, eps=config.norm_eps) 94 | self.output = nn.Linear(config.dim, config.vocab_size, bias=False) 95 | 96 | self.freqs_cis: Optional[Tensor] = None 97 | self.mask_cache: Optional[Tensor] = None 98 | self.max_batch_size = -1 99 | self.max_seq_length = -1 100 | 101 | 102 | def get_tok_embeddings(self): 103 | return self.tok_embeddings 104 | 105 | def setup_caches(self, max_batch_size, max_seq_length): 106 | if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size: 107 | return 108 | max_seq_length = find_multiple(max_seq_length, 8) 109 | self.max_seq_length = max_seq_length 110 | self.max_batch_size = max_batch_size 111 | for b in self.layers: 112 | b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, self.config.head_dim) 113 | 114 | self.freqs_cis = precompute_freqs_cis(8192 * 2, self.config.head_dim, self.config.rope_base) 115 | self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)) 116 | 117 | #self.causal_mask[:256, :256] = torch.ones_like(self.causal_mask[:256, :256]).to('cuda:0', 118 | # dtype=torch.bool) 119 | 120 | 121 | def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None, embeds: Optional[Tensor] = None) -> Tensor: 122 | assert self.freqs_cis is not None, "Caches must be initialized first" 123 | mask = self.causal_mask[None, None, input_pos] 124 | freqs_cis = self.freqs_cis[input_pos] 125 | 126 | if embeds is not None: 127 | x = embeds 128 | else: 129 | #print("setting embs") 130 | x = self.tok_embeddings(idx) 131 | 132 | x = (self.config.dim ** 0.5) * x 133 | 134 | for i, layer in enumerate(self.layers): 135 | x = layer(x, input_pos, freqs_cis, mask) 136 | x = self.norm(x) 137 | logits = self.output(x) 138 | return logits 139 | 140 | @classmethod 141 | def from_name(cls, name: str): 142 | return cls(ModelArgs.from_name(name)) 143 | 144 | 145 | class TransformerBlock(nn.Module): 146 | def __init__(self, config: ModelArgs) -> None: 147 | super().__init__() 148 | self.attention = Attention(config) 149 | self.feed_forward = FeedForward(config) 150 | self.ffn_norm = RMSNorm(config.dim, config.norm_eps) 151 | self.attention_norm = RMSNorm(config.dim, config.norm_eps) 152 | 153 | def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor: 154 | #print(mask.shape) 155 | #print(mask.shape) 156 | #print(mask.shape) 157 | if mask.shape[2] > 1: 158 | inp_size = mask.shape[2] 159 | #print(mask.shape) 160 | mask[:, :, :inp_size, :inp_size] = torch.ones_like(mask[:, :, :inp_size, :inp_size]).to('cuda:0', dtype=torch.bool) 161 | 162 | #print(mask.shape) 163 | h = x + 
self.attention(self.attention_norm(x), freqs_cis, mask, input_pos) 164 | out = h + self.feed_forward(self.ffn_norm(h)) 165 | return out 166 | 167 | 168 | class Attention(nn.Module): 169 | def __init__(self, config: ModelArgs): 170 | super().__init__() 171 | assert config.dim % config.n_head == 0 172 | 173 | total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim 174 | # key, query, value projections for all heads, but in a batch 175 | self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False) 176 | self.wo = nn.Linear(config.n_head * config.head_dim, config.dim, bias=False) 177 | self.kv_cache = None 178 | 179 | self.n_head = config.n_head 180 | self.head_dim = config.head_dim 181 | self.n_local_heads = config.n_local_heads 182 | self.dim = config.dim 183 | self._register_load_state_dict_pre_hook(self.load_hook) 184 | 185 | def load_hook(self, state_dict, prefix, *args): 186 | 187 | if prefix + "wq.weight" in state_dict: 188 | wq = state_dict.pop(prefix + "wq.weight") 189 | wk = state_dict.pop(prefix + "wk.weight") 190 | wv = state_dict.pop(prefix + "wv.weight") 191 | state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) 192 | 193 | def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: 194 | bsz, seqlen, _ = x.shape 195 | 196 | kv_size = self.n_local_heads * self.head_dim 197 | q, k, v = self.wqkv(x).split([self.n_head * self.head_dim, kv_size, kv_size], dim=-1) 198 | 199 | q = q.view(bsz, seqlen, self.n_head, self.head_dim) 200 | k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim) 201 | v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim) 202 | 203 | q = apply_rotary_emb(q, freqs_cis) 204 | k = apply_rotary_emb(k, freqs_cis) 205 | 206 | q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) 207 | 208 | if self.kv_cache is not None: 209 | k, v = self.kv_cache.update(input_pos, k, v) 210 | 211 | k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) 212 | v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) 213 | y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) 214 | 215 | y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.n_head * self.head_dim) 216 | 217 | y = self.wo(y) 218 | return y 219 | 220 | 221 | class FeedForward(nn.Module): 222 | def __init__(self, config: ModelArgs) -> None: 223 | super().__init__() 224 | self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False) 225 | self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False) 226 | self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False) 227 | 228 | def forward(self, x: Tensor) -> Tensor: 229 | return self.w2(F.gelu(self.w1(x), approximate="tanh") * self.w3(x)) 230 | 231 | 232 | class RMSNorm(nn.Module): 233 | def __init__(self, dim: int, eps: float = 1e-05): 234 | super().__init__() 235 | self.eps = eps 236 | self.weight = nn.Parameter(torch.ones(dim)) 237 | 238 | def _norm(self, x): 239 | return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + 1e-5) 240 | 241 | def forward(self, x: Tensor) -> Tensor: 242 | output = self._norm(x.float()).type_as(x) 243 | return output * (1 + self.weight) 244 | 245 | 246 | def precompute_freqs_cis( 247 | seq_len: int, n_elem: int, base: int = 10000 248 | ) -> Tensor: 249 | freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)) 250 | t = torch.arange(seq_len, device=freqs.device) 251 | freqs = torch.outer(t, freqs) 252 | freqs_cis = torch.polar(torch.ones_like(freqs), 
freqs) 253 | cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) 254 | return cache.to(dtype=torch.bfloat16) 255 | 256 | 257 | def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: 258 | xshaped = x.float().reshape(*x.shape[:-1], -1, 2) 259 | freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) 260 | x_out2 = torch.stack( 261 | [ 262 | xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1], 263 | xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1], 264 | ], 265 | -1, 266 | ) 267 | 268 | x_out2 = x_out2.flatten(3) 269 | return x_out2.type_as(x) 270 | -------------------------------------------------------------------------------- /quantize.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import time 7 | from pathlib import Path 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from sentencepiece import SentencePieceProcessor 13 | 14 | try: 15 | from GPTQ import GenericGPTQRunner, InputRecorder 16 | from eval import get_task_dict, evaluate, lm_eval 17 | except: 18 | pass 19 | 20 | from model import Transformer 21 | 22 | ##### Quantization Primitives ###### 23 | 24 | def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype): 25 | # assumes symmetric quantization 26 | # assumes axis == 0 27 | # assumes dense memory format 28 | # TODO(future): relax ^ as needed 29 | 30 | # default setup for affine quantization of activations 31 | eps = torch.finfo(torch.float32).eps 32 | 33 | # get min and max 34 | min_val, max_val = torch.aminmax(x, dim=1) 35 | 36 | # calculate scales and zero_points based on min and max 37 | # reference: https://fburl.com/code/srbiybme 38 | min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) 39 | max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) 40 | device = min_val_neg.device 41 | 42 | # reference: https://fburl.com/code/4wll53rk 43 | max_val_pos = torch.max(-min_val_neg, max_val_pos) 44 | scales = max_val_pos / (float(quant_max - quant_min) / 2) 45 | # ensure scales is the same dtype as the original tensor 46 | scales = torch.clamp(scales, min=eps).to(x.dtype) 47 | zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device) 48 | 49 | # quantize based on qmin/qmax/scales/zp 50 | # reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63 51 | x_div = x / scales.unsqueeze(-1) 52 | x_round = torch.round(x_div) 53 | x_zp = x_round + zero_points.unsqueeze(-1) 54 | quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype) 55 | 56 | return quant, scales, zero_points 57 | 58 | def get_group_qparams(w, n_bit=4, groupsize=128): 59 | # needed for GPTQ with padding 60 | if groupsize > w.shape[-1]: 61 | groupsize = w.shape[-1] 62 | assert groupsize > 1 63 | assert w.shape[-1] % groupsize == 0 64 | assert w.dim() == 2 65 | 66 | to_quant = w.reshape(-1, groupsize) 67 | assert torch.isnan(to_quant).sum() == 0 68 | 69 | max_val = to_quant.amax(dim=1, keepdim=True) 70 | min_val = to_quant.amin(dim=1, keepdim=True) 71 | max_int = 2**n_bit - 1 72 | scales = (max_val - min_val).clamp(min=1e-6) / max_int 73 | zeros = min_val + scales * (2 ** (n_bit - 1)) 74 | return 
scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to( 75 | torch.bfloat16 76 | ).reshape(w.shape[0], -1) 77 | 78 | 79 | def pack_scales_and_zeros(scales, zeros): 80 | assert scales.shape == zeros.shape 81 | assert scales.dtype == torch.bfloat16 82 | assert zeros.dtype == torch.bfloat16 83 | return ( 84 | torch.cat( 85 | [ 86 | scales.reshape(scales.size(0), scales.size(1), 1), 87 | zeros.reshape(zeros.size(0), zeros.size(1), 1), 88 | ], 89 | 2, 90 | ) 91 | .transpose(0, 1) 92 | .contiguous() 93 | ) 94 | 95 | 96 | def unpack_scales_and_zeros(scales_and_zeros): 97 | assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2 98 | assert scales_and_zeros.dtype == torch.float 99 | return torch.split(scales_and_zeros.transpose(0, 1), 1, 2) 100 | 101 | 102 | def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128): 103 | assert groupsize > 1 104 | # needed for GPTQ single column quantize 105 | if groupsize > w.shape[-1] and scales.shape[-1] == 1: 106 | groupsize = w.shape[-1] 107 | 108 | assert w.shape[-1] % groupsize == 0 109 | assert w.dim() == 2 110 | 111 | to_quant = w.reshape(-1, groupsize) 112 | assert torch.isnan(to_quant).sum() == 0 113 | 114 | scales = scales.reshape(-1, 1) 115 | zeros = zeros.reshape(-1, 1) 116 | min_val = zeros - scales * (2 ** (n_bit - 1)) 117 | max_int = 2**n_bit - 1 118 | min_int = 0 119 | w_int32 = ( 120 | to_quant.sub(min_val) 121 | .div(scales) 122 | .round() 123 | .clamp_(min_int, max_int) 124 | .to(torch.int32) 125 | .reshape_as(w) 126 | ) 127 | 128 | return w_int32 129 | 130 | 131 | def group_quantize_tensor(w, n_bit=4, groupsize=128): 132 | scales, zeros = get_group_qparams(w, n_bit, groupsize) 133 | w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize) 134 | scales_and_zeros = pack_scales_and_zeros(scales, zeros) 135 | return w_int32, scales_and_zeros 136 | 137 | 138 | def group_dequantize_tensor_from_qparams( 139 | w_int32, scales, zeros, n_bit=4, groupsize=128 140 | ): 141 | assert groupsize > 1 142 | # needed for GPTQ single column dequantize 143 | if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1: 144 | groupsize = w_int32.shape[-1] 145 | assert w_int32.shape[-1] % groupsize == 0 146 | assert w_int32.dim() == 2 147 | 148 | w_int32_grouped = w_int32.reshape(-1, groupsize) 149 | scales = scales.reshape(-1, 1) 150 | zeros = zeros.reshape(-1, 1) 151 | 152 | w_dq = ( 153 | w_int32_grouped.sub(2 ** (n_bit - 1)).mul(scales).add(zeros).reshape_as(w_int32) 154 | ) 155 | return w_dq 156 | 157 | 158 | def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128): 159 | scales, zeros = unpack_scales_and_zeros(scales_and_zeros) 160 | return group_dequantize_tensor_from_qparams( 161 | w_int32, scales, zeros, n_bit, groupsize 162 | ) 163 | 164 | class QuantHandler: 165 | def __init__(self, mod): 166 | self.mod = mod 167 | 168 | def create_quantized_state_dict(self) -> "StateDict": 169 | pass 170 | 171 | def convert_for_runtime(self) -> "nn.Module": 172 | pass 173 | 174 | class GPTQQuantHandler(QuantHandler): 175 | """ 176 | This class implements a GPTQ QuantHandler that can be used to apply GPTQ to a model in concert with the GenericGPTQRunner class. 177 | Unlike the base QuantHandler class, the user does not need to implement the create_quantized_state_dict, instead they have to reimplement 178 | __init__ such that it defines the functions for the quantization mode. User is expected to reimplement convert_for_runtime. 
179 | 180 | The following functions (which must be defined in __init__) are used to define the quantization mode for both GPTQ and 181 | create_quantized_state_dict. Here is a description of each function. 182 | 183 | get_qparams_func: 184 | A function that calculates the quantization qparams for an input tensor. 185 | Args: 186 | weight: A 2d weight tensor with non-integer dtype. 187 | Returns: 188 | qparams: it can have any format but will need to be handled by the other defined functions below. 189 | 190 | quantize_func: 191 | A function that applies quantization to an input tensor. It should be noted 192 | that this function needs to be able to handle quantizing the entire weight tensor, a single group, 193 | or a single column. 194 | Args: 195 | weight: A 2d weight tensor with non-integer dtype. 196 | qparams: the output from get_qparams_func 197 | Returns: 198 | quantized_weight: A 2d quantized weight tensor (generally with an integer dtype) 199 | 200 | 201 | dequantize_func: 202 | A function that dequantizes an input quantized weight tensor. It should be noted 203 | that this function needs to be able to handle dequantizing the entire weight tensor, a single group, 204 | or a single column. 205 | Args: 206 | quantized_weight: A 2d quantized weight tensor (generally with an integer dtype) 207 | qparams: the output from get_qparams_func 208 | Returns: 209 | weight: A 2d weight tensor with non-integer dtype. 210 | 211 | combine_qparams_list_func: 212 | A function that combines several qparams into one qparam. 213 | Args: 214 | qparams_list: a list of qparams objects, each obtained by calling get_qparams_func 215 | on a single group from a weight tensor 216 | Returns: 217 | qparams: an object of the same format as the qparams above. 218 | 219 | skip_layer_func: 220 | A function that determines which linear layers should be skipped during GPTQ 221 | Args: 222 | weight: A 2d weight tensor with non-integer dtype. 223 | Returns: 224 | skip: boolean indicating whether layer should be skipped 225 | 226 | make_names_and_values_dict_func: 227 | A function that prepares the qparams and quantized_weight and creates a dictionary indicating how they 228 | should be inserted into the state_dict. Generally any packing of the weight and qparams should be done here. 229 | Args: 230 | quantized_weight: A 2d quantized weight tensor (generally with an integer dtype) 231 | qparams: the output from get_qparams_func 232 | Returns: 233 | names_and_values_dict: a dictionary mapping the name of the parameters of the quantized module to the 234 | corresponding quantized weights and qparams. 
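    Example (an illustrative sketch only -- the class name and groupsize are
    placeholders, and make_names_and_values_dict_func is simplified; see
    WeightOnlyInt4GPTQQuantHandler later in this file for the real int4 packing):

        class MyInt4GPTQHandler(GPTQQuantHandler):
            def __init__(self, mod, groupsize=128):
                self.mod = mod
                # 4-bit groupwise qparams from the helpers defined above
                self.get_qparams_func = lambda w: get_group_qparams(w, 4, groupsize)
                self.quantize_func = lambda w, qparams: group_quantize_tensor_from_qparams(
                    w, qparams[0], qparams[1], 4, groupsize)
                self.dequantize_func = lambda q, qparams: group_dequantize_tensor_from_qparams(
                    q, qparams[0], qparams[1], 4, groupsize).float()
                # concatenate per-group scales/zeros back into full-tensor qparams
                self.combine_qparams_list_func = lambda qparams_list: [
                    torch.cat(x, dim=1) for x in zip(*qparams_list)]
                self.skip_layer_func = lambda weight: False  # quantize every linear layer
                self.make_names_and_values_dict_func = lambda q, qparams: {
                    "weight": q, "scales_and_zeros": pack_scales_and_zeros(*qparams)}
                super().__init__()

            def convert_for_runtime(self):
                # swap nn.Linear modules for their quantized equivalents here
                return self.mod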
235 | """ 236 | def __init__(self): 237 | assert self.mod is not None 238 | assert self.get_qparams_func is not None 239 | assert self.quantize_func is not None 240 | assert self.dequantize_func is not None 241 | assert self.combine_qparams_list_func is not None 242 | assert self.make_names_and_values_dict_func is not None 243 | 244 | @staticmethod 245 | def get_inputs(model, tokenizer, calibration_tasks, calibration_limit, calibration_seq_length, pad_calibration_inputs) -> "MultiInput": 246 | input_recorder = InputRecorder( 247 | model, 248 | tokenizer, 249 | calibration_seq_length, 250 | pad_calibration_inputs, 251 | ) 252 | 253 | try: 254 | lm_eval.tasks.initialize_tasks() 255 | except: 256 | pass 257 | task_dict = get_task_dict(calibration_tasks) 258 | print("Obtaining GPTQ calibration inputs on: ", calibration_tasks) 259 | 260 | evaluate( 261 | input_recorder, 262 | task_dict, 263 | limit=calibration_limit, 264 | ) 265 | inputs = input_recorder.get_recorded_inputs() 266 | assert inputs is not None, ( 267 | f"No inputs were collected, use a task other than {calibration_tasks}, "+ 268 | f"use option pad_calibration_inputs, or decrease calibration_sequence_length (currently "+ 269 | f"{calibration_seq_length})" 270 | ) 271 | print(f"Obtained {len(inputs[0].values)} calibration samples") 272 | return inputs 273 | 274 | @torch.no_grad() 275 | def create_quantized_state_dict( 276 | self, 277 | tokenizer, 278 | blocksize, 279 | percdamp, 280 | groupsize, 281 | calibration_tasks, 282 | calibration_limit, 283 | calibration_seq_length, 284 | pad_calibration_inputs, 285 | ) -> "StateDict": 286 | inputs = GPTQQuantHandler.get_inputs(self.mod, tokenizer, calibration_tasks, calibration_limit, calibration_seq_length, pad_calibration_inputs) 287 | print("Tracing model for GPTQ") 288 | GPTQ_runner = GenericGPTQRunner( 289 | self.mod, 290 | inputs, 291 | blocksize, 292 | percdamp, 293 | groupsize, 294 | ).configure_quantization_mode( 295 | self.get_qparams_func, 296 | self.quantize_func, 297 | self.dequantize_func, 298 | self.combine_qparams_list_func, 299 | self.make_names_and_values_dict_func, 300 | self.skip_layer_func 301 | ) 302 | 303 | print("Applying GPTQ to weights") 304 | GPTQ_runner.run() 305 | return GPTQ_runner.get_quantized_state_dict() 306 | 307 | def convert_for_runtime(self) -> "nn.Module": 308 | pass 309 | 310 | ##### Weight-only int8 per-channel quantized code ###### 311 | 312 | def replace_linear_weight_only_int8_per_channel(module): 313 | for name, child in module.named_children(): 314 | if isinstance(child, nn.Linear): 315 | setattr(module, name, WeightOnlyInt8Linear(child.in_features, child.out_features)) 316 | else: 317 | replace_linear_weight_only_int8_per_channel(child) 318 | 319 | class WeightOnlyInt8QuantHandler: 320 | def __init__(self, mod): 321 | self.mod = mod 322 | 323 | @torch.no_grad() 324 | def create_quantized_state_dict(self): 325 | cur_state_dict = self.mod.state_dict() 326 | for fqn, mod in self.mod.named_modules(): 327 | if isinstance(mod, torch.nn.Linear): 328 | int8_weight, scales, _ = dynamically_quantize_per_channel(mod.weight.float(), -128, 127, torch.int8) 329 | cur_state_dict[f"{fqn}.weight"] = int8_weight 330 | cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype) 331 | 332 | return cur_state_dict 333 | 334 | def convert_for_runtime(self): 335 | replace_linear_weight_only_int8_per_channel(self.mod) 336 | return self.mod 337 | 338 | 339 | class WeightOnlyInt8Linear(torch.nn.Module): 340 | __constants__ = ['in_features', 'out_features'] 341 | 
in_features: int 342 | out_features: int 343 | weight: torch.Tensor 344 | 345 | def __init__(self, in_features: int, out_features: int, bias: bool = True, 346 | device=None, dtype=None) -> None: 347 | factory_kwargs = {'device': device, 'dtype': dtype} 348 | super().__init__() 349 | self.in_features = in_features 350 | self.out_features = out_features 351 | self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8)) 352 | self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16)) 353 | 354 | def forward(self, input: torch.Tensor) -> torch.Tensor: 355 | return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales 356 | 357 | ##### weight only int4 per channel groupwise quantized code ###### 358 | 359 | def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles): 360 | weight_int32, scales_and_zeros = group_quantize_tensor( 361 | weight_bf16, n_bit=4, groupsize=groupsize 362 | ) 363 | weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles) 364 | return weight_int4pack, scales_and_zeros 365 | 366 | 367 | def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize): 368 | origin_x_size = x.size() 369 | x = x.reshape(-1, origin_x_size[-1]) 370 | c = torch.ops.aten._weight_int4pack_mm(x, weight_int4pack, groupsize, scales_and_zeros) 371 | new_shape = origin_x_size[:-1] + (out_features,) 372 | c = c.reshape(new_shape) 373 | return c 374 | 375 | 376 | def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = 1): 377 | return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0 378 | 379 | def replace_linear_int4(module, groupsize, inner_k_tiles, padding): 380 | for name, child in module.named_children(): 381 | if isinstance(child, nn.Linear): 382 | if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles): 383 | setattr(module, name, WeightOnlyInt4Linear( 384 | child.in_features, child.out_features, bias=False, 385 | groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=False, 386 | )) 387 | elif padding: 388 | setattr(module, name, WeightOnlyInt4Linear( 389 | child.in_features, child.out_features, bias=False, 390 | groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=True, 391 | )) 392 | else: 393 | replace_linear_int4(child, groupsize, inner_k_tiles, padding) 394 | 395 | 396 | class WeightOnlyInt4QuantHandler: 397 | def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True): 398 | self.mod = mod 399 | self.groupsize = groupsize 400 | self.inner_k_tiles = inner_k_tiles 401 | self.padding = padding 402 | assert groupsize in [32, 64, 128, 256] 403 | assert inner_k_tiles in [2, 4, 8] 404 | 405 | @torch.no_grad() 406 | def create_quantized_state_dict(self, use_cuda = True): 407 | if use_cuda: 408 | device="cuda" 409 | else: 410 | device="cpu" 411 | 412 | cur_state_dict = self.mod.state_dict() 413 | for fqn, mod in self.mod.named_modules(): 414 | if isinstance(mod, torch.nn.Linear): 415 | assert not mod.bias 416 | out_features = mod.out_features 417 | in_features = mod.in_features 418 | assert out_features % 8 == 0, "require out_features % 8 == 0" 419 | print(f"linear: {fqn}, in={in_features}, out={out_features}") 420 | 421 | weight = mod.weight.data 422 | if not _check_linear_int4_k(in_features, self.groupsize, self.inner_k_tiles): 423 | if self.padding: 424 | from model import find_multiple 425 | import torch.nn.functional as F 426 | print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0") 427 | 
padded_in_features = find_multiple(in_features, 1024) 428 | weight = F.pad(weight, pad=(0, padded_in_features - in_features)) 429 | else: 430 | print(f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " + 431 | "and that groupsize and inner_k_tiles*16 evenly divide into it") 432 | continue 433 | weight_int4pack, scales_and_zeros = prepare_int4_weight_and_scales_and_zeros( 434 | weight.to(torch.bfloat16).to(device=device), self.groupsize, self.inner_k_tiles 435 | ) 436 | cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to('cpu') 437 | cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to('cpu') 438 | 439 | return cur_state_dict 440 | 441 | def convert_for_runtime(self): 442 | replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding) 443 | return self.mod 444 | 445 | class WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler): 446 | def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True): 447 | from model import find_multiple 448 | self.mod = mod 449 | self.groupsize = groupsize 450 | self.inner_k_tiles = inner_k_tiles 451 | self.padding = padding 452 | self.get_qparams_func = lambda w: get_group_qparams(w, 4, groupsize) 453 | self.quantize_func = lambda w, qparams: \ 454 | group_quantize_tensor_from_qparams(w, qparams[0], qparams[1], 4, groupsize) 455 | self.dequantize_func = lambda q, qparams: \ 456 | group_dequantize_tensor_from_qparams(q, qparams[0], qparams[1], 4, groupsize).float() 457 | self.combine_qparams_list_func = lambda qparams_list: \ 458 | [torch.cat(x, dim=1) for x in zip(*qparams_list)] 459 | # skip unless padding=True or its correctly sized 460 | self.skip_layer_func = lambda linear_weight: not ( 461 | _check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles) or padding 462 | ) 463 | # we need to do the padding here, both for q and the qparams if necessary 464 | def make_names_and_values_dict_func(q, qparams): 465 | k = q.shape[1] 466 | new_k = find_multiple(k, 1024) 467 | # how much we need to pad the weight 468 | delta_k = new_k - q.shape[1] 469 | final_q = torch.ops.aten._convert_weight_to_int4pack(F.pad(q, pad=(0, delta_k)), inner_k_tiles) 470 | scales_and_zeros = pack_scales_and_zeros(*qparams) 471 | # how many new groups we need for padded weight 472 | delta_groups = new_k // groupsize - scales_and_zeros.shape[0] 473 | final_s_and_z = F.pad(scales_and_zeros, pad=(0,0,0,0,0, delta_groups), value=1) 474 | return {"weight": final_q, "scales_and_zeros": final_s_and_z} 475 | self.make_names_and_values_dict_func = make_names_and_values_dict_func 476 | super().__init__() 477 | 478 | 479 | def convert_for_runtime(self): 480 | replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding) 481 | return self.mod 482 | 483 | class WeightOnlyInt4Linear(torch.nn.Module): 484 | __constants__ = ['in_features', 'out_features'] 485 | in_features: int 486 | out_features: int 487 | weight: torch.Tensor 488 | 489 | def __init__( 490 | self, in_features: int, out_features: int, 491 | bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, padding: bool = True, 492 | ) -> None: 493 | super().__init__() 494 | self.padding = padding 495 | if padding: 496 | from model import find_multiple 497 | self.origin_in_features = in_features 498 | in_features = find_multiple(in_features, 1024) 499 | 500 | self.in_features = in_features 501 | self.out_features = out_features 502 | assert not bias, "require bias=False" 503 | self.groupsize = groupsize 504 | 
self.inner_k_tiles = inner_k_tiles 505 | 506 | assert out_features % 8 == 0, "require out_features % 8 == 0" 507 | assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0" 508 | self.register_buffer( 509 | "weight", 510 | torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32) 511 | ) 512 | self.register_buffer( 513 | "scales_and_zeros", 514 | torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16) 515 | ) 516 | 517 | def forward(self, input: torch.Tensor) -> torch.Tensor: 518 | input = input.to(torch.bfloat16) 519 | if self.padding: 520 | import torch.nn.functional as F 521 | input = F.pad(input, pad=(0, self.in_features - self.origin_in_features)) 522 | return linear_forward_int4( 523 | input, 524 | self.weight, self.scales_and_zeros, self.out_features, self.groupsize 525 | ) 526 | 527 | 528 | def quantize( 529 | checkpoint_path: Path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), 530 | mode: str = 'int8', 531 | # following arguments only available when setting int4 quantization. 532 | groupsize: int = 128, 533 | # following arguments only used for GPTQ 534 | calibration_tasks: list = ["hellaswag"], 535 | calibration_limit: int = 1000, 536 | calibration_seq_length: int = 100, 537 | pad_calibration_inputs: bool = False, 538 | percdamp: float = .01, 539 | blocksize: int = 128, 540 | label: str = '', 541 | ) -> None: 542 | assert checkpoint_path.is_file(), checkpoint_path 543 | 544 | device = 'cpu' 545 | precision = torch.bfloat16 546 | 547 | print("Loading model ...") 548 | t0 = time.time() 549 | 550 | with torch.device('meta'): 551 | model = Transformer.from_name(checkpoint_path.parent.name) 552 | 553 | checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) 554 | model.load_state_dict(checkpoint, assign=True) 555 | model = model.to(dtype=precision, device=device) 556 | 557 | if mode == 'int8': 558 | print("Quantizing model weights for int8 weight-only symmetric per-channel quantization") 559 | quant_handler = WeightOnlyInt8QuantHandler(model) 560 | quantized_state_dict = quant_handler.create_quantized_state_dict() 561 | 562 | dir_name = checkpoint_path.parent 563 | base_name = checkpoint_path.name 564 | new_base_name = base_name.replace('.pth', f'{label}int8.pth') 565 | 566 | elif mode == 'int4': 567 | print("Quantizing model weights for int4 weight-only affine per-channel groupwise quantization") 568 | quant_handler = WeightOnlyInt4QuantHandler(model, groupsize) 569 | quantized_state_dict = quant_handler.create_quantized_state_dict() 570 | 571 | dir_name = checkpoint_path.parent 572 | base_name = checkpoint_path.name 573 | new_base_name = base_name.replace('.pth', f"{label}int4.g{groupsize}.pth") 574 | 575 | elif mode == 'int4-gptq': 576 | print("Quantizing model weights for int4 weight-only affine per-channel groupwise quantization using GPTQ...") 577 | quant_handler = WeightOnlyInt4GPTQQuantHandler(model, groupsize) 578 | 579 | tokenizer_path = checkpoint_path.parent / "tokenizer.model" 580 | assert tokenizer_path.is_file(), tokenizer_path 581 | tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path)) 582 | 583 | quantized_state_dict = quant_handler.create_quantized_state_dict( 584 | tokenizer, 585 | blocksize, 586 | percdamp, 587 | groupsize, 588 | calibration_tasks, 589 | calibration_limit, 590 | calibration_seq_length, 591 | pad_calibration_inputs 592 | ) 593 | 594 | dir_name = checkpoint_path.parent 595 | base_name = 
checkpoint_path.name 596 | new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.pth") 597 | else: 598 | raise ValueError(f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gpptq]") 599 | 600 | quantize_path = dir_name / new_base_name 601 | print(f"Writing quantized weights to {quantize_path}") 602 | quantize_path.unlink(missing_ok=True) # remove existing file if one already there 603 | torch.save(quantized_state_dict, quantize_path) 604 | print(f"Quantization complete took {time.time() - t0:.02f} seconds") 605 | return 606 | 607 | if __name__ == '__main__': 608 | import argparse 609 | parser = argparse.ArgumentParser(description='Quantize a model.') 610 | parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Path to the model checkpoint to be quantized.') 611 | parser.add_argument('--mode', '-q', type=str, default='int8', choices=['int8', 'int4', 'int4-gptq'], help='type of quantization to perform') 612 | parser.add_argument('--groupsize', type=int, default=32, help='Group size for int4 quantization.') 613 | parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['wikitext'], help='tasks to do gptq calibration on, if doing gptq') 614 | parser.add_argument('--calibration_limit', type=int, default=1000, help='number of samples to use for gptq calibration') 615 | parser.add_argument('--calibration_seq_length', type=int, default=100, help='length of sequences to use for gptq calibration') 616 | parser.add_argument('--pad_calibration_inputs', type=bool, default=False, help='pads sequences shorter than calibration_seq_length to that length, yielding more calibration inputs but running much slower') 617 | parser.add_argument('--percdamp', type=float, default=.01, help='gptq percentage dampening') 618 | parser.add_argument('--blocksize', type=int, default=128, help='blocksize for gptq') 619 | parser.add_argument('--label', type=str, default='_', help='label to add to output filename') 620 | 621 | args = parser.parse_args() 622 | quantize(args.checkpoint_path, args.mode, args.groupsize, args.calibration_tasks, args.calibration_limit, args.calibration_seq_length, args.pad_calibration_inputs, args.percdamp, args.blocksize, args.label) 623 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.3.0 2 | sentencepiece 3 | tiktoken 4 | accelerate 5 | opencv-python 6 | Pillow 7 | transformers 8 | numpy 9 | 10 | -------------------------------------------------------------------------------- /script.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoProcessor, PaliGemmaForConditionalGeneration 2 | from PIL import Image 3 | import requests 4 | import torch 5 | 6 | model_id = "google/paligemma-3b-mix-224" 7 | device = "cuda:0" 8 | dtype = torch.bfloat16 9 | 10 | url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true" 11 | image = Image.open(requests.get(url, stream=True).raw) 12 | 13 | model = PaliGemmaForConditionalGeneration.from_pretrained( 14 | model_id, 15 | torch_dtype=dtype, 16 | device_map=device, 17 | revision="bfloat16", 18 | ).eval() 19 | processor = AutoProcessor.from_pretrained(model_id) 20 | 21 | # Instruct the model to create a caption in Spanish 22 | prompt = "segment car" 23 | model_inputs 
= processor(text=prompt, images=image, return_tensors="pt").to(model.device) 24 | input_len = model_inputs["input_ids"].shape[-1] 25 | 26 | print(model_inputs['input_ids']) 27 | with torch.inference_mode(): 28 | 29 | generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False) 30 | generation = generation[0] 31 | 32 | #print(generation) 33 | decoded = processor.decode(generation, skip_special_tokens=False) 34 | print(decoded) 35 | 36 | #print(processor.decode([256000])) 37 | -------------------------------------------------------------------------------- /scripts/.ipynb_checkpoints/convert_hf_checkpoint-checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import json 7 | import re 8 | import sys 9 | from pathlib import Path 10 | from typing import Optional 11 | 12 | import torch 13 | 14 | # support running without installing as a package 15 | wd = Path(__file__).parent.parent.resolve() 16 | sys.path.append(str(wd)) 17 | 18 | from model import ModelArgs 19 | 20 | 21 | @torch.inference_mode() 22 | def convert_hf_checkpoint( 23 | *, 24 | checkpoint_dir: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf"), 25 | model_name: Optional[str] = None, 26 | ) -> None: 27 | if model_name is None: 28 | model_name = checkpoint_dir.name 29 | 30 | config = ModelArgs.from_name(model_name) 31 | print(f"Model config {config.__dict__}") 32 | 33 | from safetensors import safe_open 34 | 35 | # Load the json file containing weight mapping 36 | model_map_json = checkpoint_dir / "model.safetensors.index.json" 37 | 38 | assert model_map_json.is_file() 39 | 40 | with open(model_map_json) as json_map: 41 | bin_index = json.load(json_map) 42 | 43 | 44 | 45 | weight_map = { 46 | "language_model.model.embed_tokens.weight": "tok_embeddings.weight", 47 | "language_model.model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight", 48 | "language_model.model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight", 49 | "language_model.model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight", 50 | "language_model.model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight", 51 | 'language_model.model.layers.{}.self_attn.rotary_emb.inv_freq': None, 52 | 'language_model.model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight', 53 | "language_model.model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight", 54 | "language_model.model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight", 55 | "language_model.model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight", 56 | "language_model.model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight", 57 | "language_model.model.norm.weight": "norm.weight", 58 | "language_model.lm_head.weight": "output.weight", 59 | } 60 | bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()} 61 | 62 | def permute(w, n_head): 63 | dim = config.dim 64 | return ( 65 | w.view(n_head, 2, config.head_dim // 2, dim) 66 | .transpose(1, 2) 67 | .reshape(config.head_dim * n_head, dim) 68 | ) 69 | 70 | merged_result = {} 71 | for file in sorted(bin_files): 72 | state_dict = safe_open(str(file), framework="pt", device='cpu') 73 | state_dict = {k: 
state_dict.get_tensor(k) for k in state_dict.keys()} 74 | merged_result.update(state_dict) 75 | final_result = {} 76 | for key, value in merged_result.items(): 77 | if 'language_model' in key: 78 | if "layers" in key: 79 | abstract_key = re.sub(r'(\d+)', '{}', key) 80 | layer_num = re.search(r'\d+', key).group(0) 81 | new_key = weight_map[abstract_key] 82 | if new_key is None: 83 | continue 84 | new_key = new_key.format(layer_num) 85 | else: 86 | new_key = weight_map[key] 87 | 88 | final_result[new_key] = value 89 | 90 | for key in tuple(final_result.keys()): 91 | if "wq" in key: 92 | q = final_result[key] 93 | k = final_result[key.replace("wq", "wk")] 94 | v = final_result[key.replace("wq", "wv")] 95 | q = permute(q, config.n_head) 96 | k = permute(k, config.n_local_heads) 97 | final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v]) 98 | del final_result[key] 99 | del final_result[key.replace("wq", "wk")] 100 | del final_result[key.replace("wq", "wv")] 101 | if "output.weight" not in final_result: 102 | final_result["output.weight"] = final_result["tok_embeddings.weight"] 103 | 104 | 105 | print(final_result["tok_embeddings.weight"]) 106 | print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}") 107 | torch.save(final_result, checkpoint_dir / "model.pth") 108 | 109 | if __name__ == '__main__': 110 | import argparse 111 | parser = argparse.ArgumentParser(description='Convert HuggingFace checkpoint.') 112 | parser.add_argument('--checkpoint_dir', type=Path, default=Path("checkpoints/meta-llama/llama-2-7b-chat-hf")) 113 | parser.add_argument('--model_name', type=str, default=None) 114 | 115 | args = parser.parse_args() 116 | convert_hf_checkpoint( 117 | checkpoint_dir=args.checkpoint_dir, 118 | model_name=args.model_name, 119 | ) 120 | -------------------------------------------------------------------------------- /scripts/.ipynb_checkpoints/prepare-checkpoint.sh: -------------------------------------------------------------------------------- 1 | python scripts/download.py --repo_id $1 && python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$1 && python quantize.py --checkpoint_path checkpoints/$1/model.pth --mode int8 2 | -------------------------------------------------------------------------------- /scripts/convert_hf_checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
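# Usage sketch (assumes the checkpoint directory was populated by scripts/download.py
# and contains model.safetensors.index.json plus the *.safetensors shards it lists):
#
#   python scripts/convert_hf_checkpoint.py \
#       --checkpoint_dir checkpoints/google/paligemma-3b-mix-224
#
# The script remaps the HuggingFace "language_model.*" parameter names onto this
# repo's Transformer layout, fuses the separate q/k/v projections into a single
# wqkv weight, ties output.weight to tok_embeddings.weight when no lm_head weight
# is mapped, and writes the result to <checkpoint_dir>/model.pth.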
6 | import json 7 | import re 8 | import sys 9 | from pathlib import Path 10 | from typing import Optional 11 | 12 | import torch 13 | 14 | # support running without installing as a package 15 | wd = Path(__file__).parent.parent.resolve() 16 | sys.path.append(str(wd)) 17 | 18 | from model import ModelArgs 19 | 20 | 21 | @torch.inference_mode() 22 | def convert_hf_checkpoint( 23 | *, 24 | checkpoint_dir: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf"), 25 | model_name: Optional[str] = None, 26 | ) -> None: 27 | if model_name is None: 28 | model_name = checkpoint_dir.name 29 | 30 | config = ModelArgs.from_name(model_name) 31 | print(f"Model config {config.__dict__}") 32 | 33 | from safetensors import safe_open 34 | 35 | # Load the json file containing weight mapping 36 | model_map_json = checkpoint_dir / "model.safetensors.index.json" 37 | 38 | assert model_map_json.is_file() 39 | 40 | with open(model_map_json) as json_map: 41 | bin_index = json.load(json_map) 42 | 43 | 44 | 45 | weight_map = { 46 | "language_model.model.embed_tokens.weight": "tok_embeddings.weight", 47 | "language_model.model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight", 48 | "language_model.model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight", 49 | "language_model.model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight", 50 | "language_model.model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight", 51 | 'language_model.model.layers.{}.self_attn.rotary_emb.inv_freq': None, 52 | 'language_model.model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight', 53 | "language_model.model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight", 54 | "language_model.model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight", 55 | "language_model.model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight", 56 | "language_model.model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight", 57 | "language_model.model.norm.weight": "norm.weight", 58 | } 59 | bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()} 60 | 61 | def permute(w, n_head): 62 | dim = config.dim 63 | print(config.head_dim) 64 | print(config.dim) 65 | return ( 66 | w.view(n_head, 2, config.head_dim // 2, dim) 67 | .transpose(1, 2) 68 | .reshape(config.head_dim * n_head, dim) 69 | ) 70 | 71 | merged_result = {} 72 | for file in sorted(bin_files): 73 | state_dict = safe_open(str(file), framework="pt", device='cpu') 74 | state_dict = {k: state_dict.get_tensor(k) for k in state_dict.keys()} 75 | merged_result.update(state_dict) 76 | final_result = {} 77 | for key, value in merged_result.items(): 78 | if 'language_model' in key: 79 | if "layers" in key: 80 | abstract_key = re.sub(r'(\d+)', '{}', key) 81 | layer_num = re.search(r'\d+', key).group(0) 82 | new_key = weight_map[abstract_key] 83 | if new_key is None: 84 | continue 85 | new_key = new_key.format(layer_num) 86 | else: 87 | new_key = weight_map[key] 88 | 89 | final_result[new_key] = value 90 | 91 | for key in tuple(final_result.keys()): 92 | if "wq" in key: 93 | q = final_result[key] 94 | k = final_result[key.replace("wq", "wk")] 95 | v = final_result[key.replace("wq", "wv")] 96 | q = permute(q, config.n_head) 97 | k = permute(k, config.n_local_heads) 98 | final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v]) 99 | del final_result[key] 100 | del final_result[key.replace("wq", "wk")] 101 | del 
final_result[key.replace("wq", "wv")] 102 | if "output.weight" not in final_result: 103 | final_result["output.weight"] = final_result["tok_embeddings.weight"] 104 | 105 | 106 | print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}") 107 | torch.save(final_result, checkpoint_dir / "model.pth") 108 | 109 | if __name__ == '__main__': 110 | import argparse 111 | parser = argparse.ArgumentParser(description='Convert HuggingFace checkpoint.') 112 | parser.add_argument('--checkpoint_dir', type=Path, default=Path("checkpoints/meta-llama/llama-2-7b-chat-hf")) 113 | parser.add_argument('--model_name', type=str, default=None) 114 | 115 | args = parser.parse_args() 116 | convert_hf_checkpoint( 117 | checkpoint_dir=args.checkpoint_dir, 118 | model_name=args.model_name, 119 | ) 120 | -------------------------------------------------------------------------------- /scripts/download.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import os 7 | from typing import Optional 8 | 9 | from requests.exceptions import HTTPError 10 | 11 | 12 | def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -> None: 13 | from huggingface_hub import snapshot_download 14 | os.makedirs(f"checkpoints/{repo_id}", exist_ok=True) 15 | try: 16 | snapshot_download(repo_id, local_dir=f"checkpoints/{repo_id}", local_dir_use_symlinks=False, token=hf_token) 17 | except HTTPError as e: 18 | if e.response.status_code == 401: 19 | print("You need to pass a valid `--hf_token=...` to download private checkpoints.") 20 | else: 21 | raise e 22 | 23 | if __name__ == '__main__': 24 | import argparse 25 | parser = argparse.ArgumentParser(description='Download data from HuggingFace Hub.') 26 | parser.add_argument('--repo_id', type=str, default="checkpoints/meta-llama/llama-2-7b-chat-hf", help='Repository ID to download from.') 27 | parser.add_argument('--hf_token', type=str, default=None, help='HuggingFace API token.') 28 | 29 | args = parser.parse_args() 30 | hf_download(args.repo_id, args.hf_token) 31 | -------------------------------------------------------------------------------- /scripts/prepare.sh: -------------------------------------------------------------------------------- 1 | python scripts/download.py --repo_id $1 && python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$1 && python quantize.py --checkpoint_path checkpoints/$1/model.pth --mode int8 2 | -------------------------------------------------------------------------------- /scripts/speculate_34B_bf16.sh: -------------------------------------------------------------------------------- 1 | # 56.80 2 | export MODEL_REPO=codellama/CodeLlama-34b-Python-hf 3 | export DRAFT_MODEL_REPO=codellama/CodeLlama-7b-Python-hf 4 | time python generate.py --compile --draft_checkpoint_path checkpoints/$DRAFT_MODEL_REPO/model_int4.g32.pth --checkpoint_path checkpoints/$MODEL_REPO/model.pth --speculate_k 6 --prompt "def quicksort(arr):" --max_new_tokens 200 --num_samples 50 5 | -------------------------------------------------------------------------------- /scripts/speculate_70B_int4.sh: -------------------------------------------------------------------------------- 1 | # 49.26 2 | export MODEL_REPO=meta-llama/Llama-2-70b-chat-hf 3 | export DRAFT_MODEL_REPO=meta-llama/Llama-2-7b-chat-hf 4 | time 
python generate.py --compile --draft_checkpoint_path checkpoints/$DRAFT_MODEL_REPO/model_int4.g32.pth --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.pth --speculate_k 4 --prompt "def quicksort(arr):" --max_new_tokens 100 --num_samples 50 --temperature 0 5 | -------------------------------------------------------------------------------- /scripts/speculate_7B_int4.sh: -------------------------------------------------------------------------------- 1 | export MODEL_REPO=codellama/CodeLlama-7b-Python-hf 2 | export DRAFT_MODEL_REPO=PY007/TinyLlama-1.1B-intermediate-step-480k-1T 3 | time python generate.py --compile --draft_checkpoint_path checkpoints/$DRAFT_MODEL_REPO/model_int4.g32.pth --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.pth --speculate_k 5 --prompt "Hi my name is" --max_new_tokens 200 --num_samples 50 --temperature 0 --compile_prefill 4 | -------------------------------------------------------------------------------- /scripts/speculate_tp_70B_bf16.sh: -------------------------------------------------------------------------------- 1 | export MODEL_REPO=meta-llama/Llama-2-70b-chat-hf 2 | export DRAFT_MODEL_REPO=meta-llama/Llama-2-7b-chat-hf 3 | time torchrun --standalone --nproc_per_node=8 generate.py --compile --draft_checkpoint_path checkpoints/$DRAFT_MODEL_REPO/model_int8.pth --checkpoint_path checkpoints/$MODEL_REPO/model.pth --speculate_k 5 --prompt "def quicksort(arr):" --max_new_tokens 200 --num_samples 50 --temperature 0 4 | -------------------------------------------------------------------------------- /scripts/test_flow.sh: -------------------------------------------------------------------------------- 1 | export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf 2 | rm -r checkpoints/$MODEL_REPO 3 | python scripts/download.py --repo_id $MODEL_REPO 4 | python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$MODEL_REPO 5 | python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth 6 | python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model_int8.pth --max_new_tokens 100 7 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | from setuptools import setup, find_packages 7 | 8 | setup( 9 | name='gpt-fast', 10 | version='0.1', 11 | packages=find_packages(), 12 | install_requires=[ 13 | 'torch', 14 | ], 15 | description='A simple, fast, pure PyTorch Llama inference engine', 16 | long_description=open('README.md').read(), 17 | long_description_content_type='text/markdown', 18 | url='https://github.com/pytorch-labs/gpt-fast', 19 | ) 20 | -------------------------------------------------------------------------------- /tp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | import os 7 | from typing import List, Optional 8 | 9 | import torch 10 | import torch.distributed as dist 11 | from torch import nn 12 | from torch.distributed import _functional_collectives as funcol 13 | 14 | from model import Attention, FeedForward, Transformer 15 | from quantize import WeightOnlyInt4Linear 16 | 17 | 18 | def _get_rank() -> int: 19 | return int(os.environ.get("LOCAL_RANK", "0")) 20 | 21 | def is_local(): 22 | return _get_rank() == 0 23 | 24 | def local_break(): 25 | if is_local(): 26 | breakpoint() 27 | dist.barrier() 28 | 29 | def _get_world_size() -> int: 30 | return int(os.environ.get("LOCAL_WORLD_SIZE", "1")) 31 | 32 | def maybe_init_dist() -> Optional[int]: 33 | try: 34 | # provided by torchrun 35 | rank = _get_rank() 36 | world_size = _get_world_size() 37 | 38 | if world_size < 2: 39 | # too few gpus to parallelize, tp is no-op 40 | return None 41 | except KeyError: 42 | # not run via torchrun, no-op 43 | return None 44 | 45 | torch.cuda.set_device(rank) 46 | dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) 47 | return rank 48 | 49 | 50 | def _apply_tp_linear(linear: nn.Linear, style: str, weight_splits: List[int] = []) -> None: 51 | rank = _get_rank() 52 | world_size = _get_world_size() 53 | 54 | # Linear's weight matrix is transposed, and is of shape 55 | # (linear.out_features, linear.in_features) 56 | dim_lookup = { 57 | "colwise": (0, "out_features"), 58 | "rowwise": (1, "in_features") 59 | } 60 | assert style in dim_lookup 61 | shard_dim, size_attr = dim_lookup[style] 62 | 63 | # ensure we can shard evenly 64 | assert getattr(linear, size_attr) % world_size == 0 65 | def shard(x, dim): 66 | assert x.size(dim=dim) % world_size == 0 67 | return torch.tensor_split(x, world_size, dim=dim)[rank] 68 | 69 | def shard_qkv(qkv, dim, weight_splits): 70 | q, k, v = qkv.split(weight_splits, dim=dim) 71 | q = shard(q, dim) 72 | k = shard(k, dim) 73 | v = shard(v, dim) 74 | return torch.cat((q,k,v), dim=dim) 75 | 76 | # shard 77 | if weight_splits: 78 | # attention 79 | assert len(weight_splits) == 3 80 | 81 | if isinstance(linear, WeightOnlyInt4Linear): 82 | sharded_weight = shard_qkv(linear.weight, shard_dim, [i//8 for i in weight_splits]) 83 | linear.scales_and_zeros = shard_qkv(linear.scales_and_zeros, 1 - shard_dim, weight_splits) 84 | else: 85 | sharded_weight = shard_qkv(linear.weight, shard_dim, weight_splits) 86 | if hasattr(linear, "scales") and style == "colwise": 87 | linear.scales = shard_qkv(linear.scales, 0, weight_splits) 88 | else: 89 | sharded_weight = shard(linear.weight, shard_dim) 90 | if isinstance(linear, WeightOnlyInt4Linear): 91 | linear.scales_and_zeros = shard(linear.scales_and_zeros, 1 - shard_dim) 92 | if style == "rowwise": 93 | assert linear.scales_and_zeros.shape[0] * 32 == sharded_weight.shape[1] * sharded_weight.shape[2] * sharded_weight.shape[3] 94 | assert linear.scales_and_zeros.shape[1] == sharded_weight.shape[0] * 8 95 | if hasattr(linear, "scales") and style == "colwise": 96 | linear.scales = shard(linear.scales, 0) 97 | 98 | # local_break() 99 | linear.weight = nn.Parameter(sharded_weight, requires_grad=False) 100 | setattr(linear, size_attr, getattr(linear, size_attr) // world_size) 101 | 102 | # shape info should still be synced 103 | # assert linear.weight.shape == (linear.out_features, linear.in_features) 104 | 105 | 106 | def _apply_tp_ffn(mlp: FeedForward) -> None: 107 | assert hasattr(mlp, "w1") 108 | assert hasattr(mlp, "w3") 109 | assert hasattr(mlp, "w2") 110 | 111 | _apply_tp_linear(mlp.w1, 
"colwise") 112 | _apply_tp_linear(mlp.w3, "colwise") 113 | _apply_tp_linear(mlp.w2, "rowwise") 114 | 115 | world_size = _get_world_size() 116 | mlp.register_forward_hook(lambda _module, _input, output: funcol.all_reduce( 117 | output, "sum", list(range(world_size)))) 118 | 119 | 120 | def _apply_tp_attn(attn: Attention) -> None: 121 | assert hasattr(attn, "wqkv") 122 | assert hasattr(attn, "wo") 123 | 124 | kv_size = attn.n_local_heads * attn.head_dim 125 | _apply_tp_linear(attn.wqkv, "colwise", [attn.dim, kv_size, kv_size]) 126 | _apply_tp_linear(attn.wo, "rowwise") 127 | 128 | # overwrite 129 | world_size = _get_world_size() 130 | attn.n_head = attn.n_head // world_size 131 | attn.dim = attn.dim // world_size 132 | attn.head_dim = attn.dim // attn.n_head 133 | attn.n_local_heads = attn.n_local_heads // world_size 134 | 135 | attn.register_forward_hook(lambda _module, _input, output: funcol.all_reduce( 136 | output[0], "sum", list(range(world_size)))) 137 | 138 | 139 | def _apply_tp_Transformer(Transformer: Transformer) -> None: 140 | # overwrite config before Transformer.setup_cache is called 141 | world_size = _get_world_size() 142 | Transformer.config.n_head = Transformer.config.n_head // world_size 143 | Transformer.config.dim = Transformer.config.dim // world_size 144 | Transformer.config.n_local_heads = Transformer.config.n_local_heads // world_size 145 | 146 | 147 | def apply_tp(model: Transformer) -> None: 148 | _apply_tp_Transformer(model) 149 | for block in model.layers: 150 | # Apply to MLP 151 | _apply_tp_ffn(block.feed_forward) 152 | _apply_tp_attn(block.attention) 153 | --------------------------------------------------------------------------------