├── .gitignore
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── GPTQ.py
├── LICENSE
├── README.md
├── eval.py
├── generate.py
├── mixtral-moe
├── README.md
├── generate.py
├── model.py
├── quantize.py
├── scripts
│ ├── convert_hf_checkpoint.py
│ └── download.py
└── tp.py
├── model.py
├── quantize.py
├── requirements.txt
├── scripts
├── convert_hf_checkpoint.py
├── download.py
├── prepare.sh
├── speculate_34B_bf16.sh
├── speculate_70B_int4.sh
├── speculate_7B_int4.sh
├── speculate_tp_70B_bf16.sh
└── test_flow.sh
├── setup.py
├── tokenizer.py
└── tp.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | .idea
3 | .DS_Store
4 | *.egg-info
5 | build
6 |
7 | # data
8 | data
9 | checkpoints
10 | out
11 | !data/shakespeare/prepare.py
12 | wandb
13 |
14 | # downloaded by our tests
15 | original_model.py
16 | original_adapter.py
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to make participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, sex characteristics, gender identity and expression,
9 | level of experience, education, socio-economic status, nationality, personal
10 | appearance, race, religion, or sexual identity and orientation.
11 |
12 | ## Our Standards
13 |
14 | Examples of behavior that contributes to creating a positive environment
15 | include:
16 |
17 | * Using welcoming and inclusive language
18 | * Being respectful of differing viewpoints and experiences
19 | * Gracefully accepting constructive criticism
20 | * Focusing on what is best for the community
21 | * Showing empathy towards other community members
22 |
23 | Examples of unacceptable behavior by participants include:
24 |
25 | * The use of sexualized language or imagery and unwelcome sexual attention or
26 | advances
27 | * Trolling, insulting/derogatory comments, and personal or political attacks
28 | * Public or private harassment
29 | * Publishing others' private information, such as a physical or electronic
30 | address, without explicit permission
31 | * Other conduct which could reasonably be considered inappropriate in a
32 | professional setting
33 |
34 | ## Our Responsibilities
35 |
36 | Project maintainers are responsible for clarifying the standards of acceptable
37 | behavior and are expected to take appropriate and fair corrective action in
38 | response to any instances of unacceptable behavior.
39 |
40 | Project maintainers have the right and responsibility to remove, edit, or
41 | reject comments, commits, code, wiki edits, issues, and other contributions
42 | that are not aligned to this Code of Conduct, or to ban temporarily or
43 | permanently any contributor for other behaviors that they deem inappropriate,
44 | threatening, offensive, or harmful.
45 |
46 | ## Scope
47 |
48 | This Code of Conduct applies within all project spaces, and it also applies when
49 | an individual is representing the project or its community in public spaces.
50 | Examples of representing a project or community include using an official
51 | project e-mail address, posting via an official social media account, or acting
52 | as an appointed representative at an online or offline event. Representation of
53 | a project may be further defined and clarified by project maintainers.
54 |
55 | ## Enforcement
56 |
57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
58 | reported by contacting the project team at <conduct@pytorch.org>. All
59 | complaints will be reviewed and investigated and will result in a response that
60 | is deemed necessary and appropriate to the circumstances. The project team is
61 | obligated to maintain confidentiality with regard to the reporter of an incident.
62 | Further details of specific enforcement policies may be posted separately.
63 |
64 | Project maintainers who do not follow or enforce the Code of Conduct in good
65 | faith may face temporary or permanent repercussions as determined by other
66 | members of the project's leadership.
67 |
68 | ## Attribution
69 |
70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
72 |
73 | [homepage]: https://www.contributor-covenant.org
74 |
75 | For answers to common questions about this code of conduct, see
76 | https://www.contributor-covenant.org/faq
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to gpt-fast
2 | We want to make contributing to this project as easy and transparent as
3 | possible.
4 |
5 |
6 | ## Pull Requests
7 | We actively welcome your pull requests.
8 |
9 | 1. Fork the repo and create your branch from `main`.
10 | 2. If you've added code that should be tested, add tests.
11 | 3. If you've changed APIs, update the documentation.
12 | 4. Ensure the test suite passes.
13 | 5. Make sure your code lints.
14 | 6. If you haven't already, complete the Contributor License Agreement ("CLA").
15 |
16 | ## Contributor License Agreement ("CLA")
17 | In order to accept your pull request, we need you to submit a CLA. You only need
18 | to do this once to work on any of Meta's open source projects.
19 |
20 | Complete your CLA here: <https://code.facebook.com/cla>
21 |
22 | ## Issues
23 | We use GitHub issues to track public bugs. Please ensure your description is
24 | clear and has sufficient instructions to be able to reproduce the issue.
25 |
26 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe
27 | disclosure of security bugs. In those cases, please go through the process
28 | outlined on that page and do not file a public issue.
29 |
30 | ## License
31 | By contributing to `gpt-fast`, you agree that your contributions will be licensed
32 | under the LICENSE file in the root directory of this source tree.
33 |
--------------------------------------------------------------------------------
/GPTQ.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 |
9 | import torch.fx as fx
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from torch.utils._pytree import tree_flatten, tree_unflatten
13 |
14 | aten = torch.ops.aten
15 |
16 | from eval import (
17 | setup_cache_padded_seq_input_pos_max_seq_length_for_prefill,
18 | GPTFastEvalWrapper
19 | )
20 |
21 |
22 | class InputRecorder(GPTFastEvalWrapper):
23 | """
24 | This is a fake evaluation wrapper that just records the inputs
25 | so that they can be used in calibration.
26 |
27 | If pad_calibration_inputs is enabled, the input recorder will take
28 | each input and pad/truncate it down to the calibration_seq_length.
29 | It will also edit the model embeddings to be zero for the 0 token used
30 | in padding and avoid any inputs with the 0 token.
31 |
32 | If not, it will only truncate inputs to the desired length.
33 | """
34 |
35 | def __init__(
36 | self,
37 | model,
38 | tokenizer,
39 | calibration_seq_length,
40 | pad_calibration_inputs=False,
41 | ):
42 | super().__init__(model, tokenizer, calibration_seq_length)
43 | self._model = model
44 | self._tokenizer = tokenizer
45 | self._device = torch.device("cpu")
46 | self.vocab_size = model.config.vocab_size
47 | self.calibration_seq_length = calibration_seq_length
48 | self.pad_calibration_inputs = pad_calibration_inputs
49 | self.inputs = None
50 |
51 | if self.pad_calibration_inputs:
52 | # This is needed for the pad_calibration_inputs option
53 | # to work properly, the 0 token's embeddings are set to 0 so that
54 | # the padded inputs will not affect the model numerics. This token isn't used
55 | # commonly in the eval tasks for the meta-llama tokenizer and we skip any inputs
56 | # where it appears
57 | try:
58 | if isinstance(self._model.transformer.wte, nn.Embedding):
59 | self.mod.transformer.wte.weight.data[0, :] *= 0
60 | except:
61 | print(
62 | "Did not find embeddings in model.transformer.wte, disabling padding"
63 | )
64 | self.pad_calibration_inputs = False
65 |
66 |
67 | def add_input(self, args):
68 | if self.inputs is None:
69 | self.inputs = [MultiInput([arg]) for arg in args]
70 | else:
71 | self.inputs = [
72 | multi.add_input(arg) for (multi, arg) in zip(self.inputs, args)
73 | ]
74 |
75 | def get_recorded_inputs(self):
76 | return self.inputs
77 |
78 | def _model_call(self, inps):
79 | inps = inps.squeeze(0)
80 | T = len(inps)
81 | if (
82 | # can't use inputs that are too short when padding disabled
83 | (T < self.calibration_seq_length and not self.pad_calibration_inputs)
84 | or
85 | # can't use inputs that actually use token we use for padding
86 | (self.pad_calibration_inputs and 0 in inps)
87 | ):
88 | # give random output
89 | return torch.randn(
90 | (1, T, self.vocab_size), dtype=torch.bfloat16, device=self._device
91 | )
92 |
93 | # pad or truncate to the right size
94 | if T >= self.calibration_seq_length:
95 | inps = inps[: self.calibration_seq_length]
96 | else:
97 | inps = F.pad(inps, (0, self.calibration_seq_length - T))
98 |
99 | max_new_tokens = 1
100 | (
101 | seq,
102 | input_pos,
103 | max_seq_length,
104 | ) = setup_cache_padded_seq_input_pos_max_seq_length_for_prefill(
105 | self._model, inps, max_new_tokens, self.max_length
106 | )
107 | x = seq.index_select(0, input_pos).view(1, -1)
108 | self.add_input((x, input_pos))
109 |
110 | # output `something` with correct shape to keep eval going
111 | return torch.randn(
112 | (1, T, self.vocab_size), dtype=torch.bfloat16, device=self._device
113 | )
114 |
115 |
116 |
117 | class MultiInput:
118 | def __init__(self, inputs):
119 | self.values = list(inputs)
120 |
121 | def add_input(self, input):
122 | self.values.append(input)
123 | return self
124 |
125 | def __getitem__(self, slice):
126 | return MultiInput(self.values[slice])
127 |
128 | def cuda(self):
129 | self.values = [val.cuda() if isinstance(val, torch.Tensor) else val for val in self.values]
130 |
131 |
132 | class GenericGPTQRunner(fx.Interpreter):
133 | """
134 | This is a generic GPTQ runner that takes an existing model and applies GPTQ.
135 | It uses torch._dynamo.export to obtain a graph of the model and then hooks
136 | into function calls and when it detects a linear, it applies GPTQ to the weight
137 | given the calibration of inputs passed in at initialization. It puts the results
138 | into the state_dict so that the quantized model weights/qparams can be loaded
139 | directly into the model.
140 |
141 | This class is expected to work in concert with a GPTQSimpleQuantizer
142 | class to define the specific type of quantization being done.
143 | """
144 |
145 | def __init__(
146 | self, model, inputs: MultiInput, blocksize=128, percdamp=0.01, groupsize=128
147 | ):
148 | self.id_to_name = {
149 | id(value): name for name, value in dict(model.named_parameters()).items()
150 | }
151 |
152 | # trace model for one input
153 | one_input = [multi.values[0].cpu() for multi in inputs]
154 | exported_model = torch._dynamo.export(
155 | model.cpu(), aten_graph=True, pre_dispatch=True, tracing_mode="fake"
156 | )(*one_input)
157 | super().__init__(exported_model.graph_module)
158 | self.new_state_dict = model.state_dict()
159 | self.blocksize = blocksize
160 | self.percdamp = percdamp
161 | self.groupsize = groupsize
162 | self.inputs = inputs
163 | self.gptq_done = False
164 | self.debug = False
165 |
166 | def configure_quantization_mode(
167 | self,
168 | get_qparams_func,
169 | quantize_func,
170 | dequantize_func,
171 | combine_qparams_list_func,
172 | make_names_and_values_dict_func,
173 | skip_layer_func,
174 | ):
175 | # these functions need to already be curried with all inputs other than weight, qparams
176 | self.get_qparams_func = (
177 | get_qparams_func # accepts [2d weight tensor], outputs qparams.
178 | )
179 |
180 | self.quantize_func = quantize_func # accepts [2d weight tensor], [qparams], outputs a 2d quantized tensor of desired dtype
181 |
182 | self.dequantize_func = dequantize_func
183 | # accepts [quantized] tensor and [qparams], outputs a 2d dequantized tensor of type float,
184 | # assumes this output .to(w_orig_dtype) is ~eventual desired dequant behavior
185 |
186 | self.combine_qparams_list_func = combine_qparams_list_func
187 | # accepts [`list` of qparams] from quantizing one group at a time,
188 | # outputs a qparams object that could be passed into quant/dequantize_func
189 |
190 | self.skip_layer_func = skip_layer_func # accepts [weight tensor], outputs a bool on whether or not to apply gptq to this layer
191 |
192 | self.make_names_and_values_dict_func = make_names_and_values_dict_func # accepts [2d quantized tensor], [qparams], returns a dict of names, values to put in state_dict
193 | # note any final packing for storage should happen here
194 | return self
195 |
196 | def run(self):
197 | assert (
198 | self.get_qparams_func is not None
199 | ), "need to configure quantization mode before running"
200 | self.gptq_done = True
201 | super().run(*self.inputs)
202 |
203 | def get_quantized_state_dict(self):
204 | assert (
205 | self.gptq_done
206 | ), "need to run GPTQRunner before you can get_quantized_state_dict"
207 | quantized_state_dict = self.new_state_dict
208 | # Don't want to store/load the kv_cache so remove it from the state_dict
209 | del_list = []
210 | for param_fqn in quantized_state_dict:
211 | if "kv_cache" in param_fqn:
212 | del_list.append(param_fqn)
213 | for param_fqn in del_list:
214 | quantized_state_dict.pop(param_fqn)
215 | return quantized_state_dict
216 |
217 | def call_function(self, target, args, kwargs, skip_quant=False):
218 | def tensors_to_cuda(args):
219 | new_args = []
220 | for x in args:
221 | new_args.append(x.cuda() if isinstance(x, torch.Tensor) else x)
222 | return new_args
223 |
224 | # flatten args and kwargs together
225 | flat_args, spec = tree_flatten((args, kwargs))
226 | # move all single tensors to cuda, will move MultiInputs to cuda one at a time
227 | flat_args = tensors_to_cuda(flat_args)
228 |
229 | has_multi_input = MultiInput in [type(x) for x in flat_args]
230 | if has_multi_input:
231 | # Just some trickery to convert
232 | # [MultiInput[a, a, a], MultiInput(b, b, b)] => [a, b], [a, b], [a, b]
233 | multi_input_count = max(
234 | [len(x.values) if isinstance(x, MultiInput) else 1 for x in flat_args]
235 | )
236 | transposed_args = list(
237 | zip(
238 | *[x.values if isinstance(x, MultiInput) else [x] * multi_input_count for x in flat_args]
239 | )
240 | )
241 | else:
242 | transposed_args = [flat_args]
243 | outputs = []
244 |
245 | # check whether we apply GPTQ to this module
246 | quantize_linear = (
247 | (target == aten.linear.default) # if its a linear
248 | and id(args[1]) in self.id_to_name # and if we know the layer name
249 | and not skip_quant # and if we weren't told to skip quantization
250 | # and if the skip_layer_func doesn't say we should skip
251 | and not (self.skip_layer_func is not None and self.skip_layer_func(args[1]))
252 | ) # then we will quantize this linear layer/weight
253 |
254 | if quantize_linear: # instantiate variables for GPTQ
255 | H = 0
256 | total_batches = 0
257 |
258 | for inp in transposed_args:
259 | inp = tensors_to_cuda(inp)
260 | cur_args, cur_kwargs = tree_unflatten(inp, spec)
261 |
262 | if (
263 | quantize_linear
264 | ): # calculate H instead of output (will run the linear eventually with updated weight)
265 | x = cur_args[0].float()
266 | shape = x.shape
267 | n = 1 if len(shape) == 2 else shape[0]
268 | H *= total_batches / (total_batches + n)
269 | total_batches += n
270 | x = ((2 / total_batches) ** (1 / 2)) * x.reshape(
271 | -1, shape[-1]
272 | ).t().float()
273 | H += x.matmul(x.t())
274 | else:
275 | # get output if its not a linear
276 | out = super().call_function(target, cur_args, cur_kwargs)
277 |
278 | if isinstance(out, torch.Tensor):
279 | outputs.append(out.cpu())
280 | else:
281 | outputs.append(out)
282 |
283 | if quantize_linear:
284 | mod_fqn = ".".join(self.id_to_name[id(args[1])].split(".")[:-1])
285 | W = args[1].to(H.device)
286 | Q, DQ, qparams = self.faster_quant(H, W.detach())
287 | print(mod_fqn)
288 | names_and_values_dict = self.make_names_and_values_dict_func(Q, qparams)
289 |
290 | # delete old weight
291 | if mod_fqn + ".weight" in self.new_state_dict:
292 | self.new_state_dict.pop(mod_fqn + ".weight")
293 | if len(args) > 2:
294 | self.new_state_dict[mod_fqn + ".bias"] = args[2]
295 | for name, value in names_and_values_dict.items():
296 | self.new_state_dict[mod_fqn + "." + name] = value
297 |
298 | # run linear with new weight to get corrected output
299 | new_out = self.call_function(
300 | target, (args[0], DQ, *args[2:]), kwargs, skip_quant=True
301 | )
302 |
303 | if self.debug:
304 | old_out = self.call_function(
305 | target, (args[0][:2], args[1], *args[2:]), kwargs, skip_quant=True
306 | )
307 |
308 | def SQNR(x, y):
309 | return 20 * torch.log10(torch.norm(x) / torch.norm(x - y))
310 |
311 | DQ_after = self.dequantize_func(Q, qparams).to(W.dtype)
312 | print(
313 | "SQNR for QDQ (this should be inf)", SQNR(DQ, DQ_after)
314 | ) # matches
315 |
316 | print(
317 | "SQNR for weight (can be low)", SQNR(W, DQ.cuda())
318 | ) # fine to not match
319 | print(
320 | "SQNR for output with GPTQ (hopefully 35+)",
321 | torch.cat(
322 | [
323 | SQNR(old.cpu(), new.cpu()).unsqueeze(0)
324 | for (old, new) in zip(old_out.values, new_out.values[:2])
325 | ]
326 | ).mean(),
327 | )
328 |
329 | qparams2 = self.get_qparams_func(W)
330 | Q2 = self.quantize_func(W, qparams2)
331 | DQ2 = self.dequantize_func(Q2, qparams2).to(W.dtype)
332 | old_q_out = self.call_function(
333 | target, (args[0][:2], DQ2, *args[2:]), kwargs, skip_quant=True
334 | )
335 |
336 | print("SQNR for output without GPTQ (should be less than above)",
337 | torch.cat([
338 | SQNR(old.cpu(), old_q.cpu()).unsqueeze(0)
339 | for (old, old_q) in zip(old_out.values, old_q_out.values)
340 | ]).mean(),
341 | )
342 | return new_out
343 |
344 | return MultiInput(outputs) if has_multi_input else outputs[0]
345 |
346 | def faster_quant(self, H, W):
347 | percdamp = self.percdamp
348 | blocksize = self.blocksize
349 | groupsize = self.groupsize
350 | orig_dtype = W.dtype
351 | W = W.detach().float()
352 | rows, columns = W.shape[0], W.shape[1]
353 | device = W.device
354 |
355 | if groupsize == -1:
356 | cur_qparams = self.get_qparams_func(W)
357 | dead = torch.diag(H) == 0
358 | H[dead, dead] = 1
359 | W[:, dead] = 0
360 |
361 | Losses = torch.zeros_like(W)
362 | DQ = torch.zeros_like(W)
363 |
364 | damp = percdamp * torch.mean(torch.diag(H))
365 | diag = torch.arange(columns, device=device)
366 | H[diag, diag] += damp
367 | H = torch.linalg.cholesky(H)
368 | H = torch.cholesky_inverse(H)
369 | H = torch.linalg.cholesky(H, upper=True)
370 | Hinv = H
371 |
372 | all_qparams = []
373 | for i1 in range(0, columns, blocksize):
374 | i2 = min(i1 + blocksize, columns)
375 | count = i2 - i1
376 | W1 = W[:, i1:i2].clone()
377 | DQ1 = torch.zeros_like(W1)
378 | Err1 = torch.zeros_like(W1)
379 | Losses1 = torch.zeros_like(W1)
380 | Hinv1 = Hinv[i1:i2, i1:i2]
381 | for i in range(count):
382 | w = W1[:, i]
383 | d = Hinv1[i, i]
384 |
385 | if groupsize != -1 and (i1 + i) % groupsize == 0: # start of new group
386 | cur_qparams = self.get_qparams_func(
387 | W[:, (i1 + i) : (i1 + i + groupsize)]
388 | )
389 | all_qparams.append(cur_qparams)
390 |
391 | q = self.quantize_func(w.unsqueeze(1), cur_qparams).flatten()
392 | dq = self.dequantize_func(q.unsqueeze(1), cur_qparams).flatten()
393 |
394 | DQ1[:, i] = dq
395 | Losses1[:, i] = (w - dq) ** 2 / d**2
396 |
397 | err1 = (w - dq) / d
398 | W1[:, i:] -= (
399 | err1.to(Hinv1.dtype).unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
400 | )
401 | Err1[:, i] = err1
402 |
403 | DQ[:, i1:i2] = DQ1
404 | Losses[:, i1:i2] = Losses1 / 2
405 |
406 | W[:, i2:] -= Err1.to(Hinv.dtype).matmul(Hinv[i1:i2, i2:])
407 |
408 | torch.cuda.synchronize()
409 |
410 | if all_qparams == []:
411 | all_qparams.append(cur_qparams)
412 |
413 | # convert a list of qparams objects into a single one. enerally by
414 | # concatenating a bunch of n,1 scale/zeros tensors into a n,num_groups tensor
415 | all_qparams = self.combine_qparams_list_func(all_qparams)
416 | Q = self.quantize_func(DQ, all_qparams)
417 | return Q, DQ.to(orig_dtype), all_qparams
418 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright 2023 Meta
2 |
3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
4 |
5 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
6 |
7 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
8 |
9 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
10 |
11 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
12 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # gpt-fast
2 | Simple and efficient pytorch-native transformer text generation.
3 |
4 | Featuring:
5 | 1. Very low latency
6 | 2. <1000 lines of python
7 | 3. No dependencies other than PyTorch and sentencepiece
8 | 4. int8/int4 quantization
9 | 5. Speculative decoding
10 | 6. Tensor parallelism
11 | 7. Supports Nvidia and AMD GPUs
12 |
13 | This is *NOT* intended to be a "framework" or "library" - it is intended to show off what kind of performance you can get with native PyTorch :) Please copy-paste and fork as you desire.
14 |
15 | For an in-depth walkthrough of what's in this codebase, see this [blog post](https://pytorch.org/blog/accelerating-generative-ai-2/).
16 |
17 | ## Supported Models
18 |
19 | ### LLaMA family
20 | Please check the rest of this page about benchmark of LLaMA family models.
21 |
22 | ### Mixtral 8x7B
23 | We also supported [Mixtral 8x7B](https://mistral.ai/news/mixtral-of-experts/) which is a high-quality sparse mixture of experts (MoE) model, the average token generation rates are:
24 |
25 | | | 1 GPU | 2 GPU | 4 GPU | 8 GPU |
26 | |------------------|---------|-----------|--------|------------|
27 | |baseline(bfloat16)| OOM | 96.67 | 155.35 | 227.82 |
28 | | int8 | 97.92 | 155.03 | 216.87 | 279.35 |
29 |
30 | Note that the benchmarks run on an 8xA100-80GB, power limited to 330W with a hybrid cube mesh topology. Note that all benchmarks are run at *batch size=1*, making the reported tokens/s numbers equivalent to "tokens/s/user". In addition, they are run with a very small prompt length (just 5 tokens).
31 |
32 | For more details about Mixtral 8x7B, please check [this page](./mixtral-moe) or this [note](https://thonking.substack.com/p/short-supporting-mixtral-in-gpt-fast).
33 |
34 | ## Examples
35 | In the spirit of keeping the repo minimal, here are various examples of extensions you can make to gpt-fast as PRs.
36 | - [Google Gemma](https://github.com/meta-pytorch/gpt-fast/pull/115)
37 | - [xAI Grok-1](https://github.com/meta-pytorch/gpt-fast/pull/171)
38 | - [Databricks DBRX](https://github.com/meta-pytorch/gpt-fast/pull/174)
39 |
40 | ## Community
41 |
42 | Projects inspired by gpt-fast in the community:
43 |
44 | - [gpt-blazing](https://github.com/armed-gpt/gpt-blazing): applies the same performance optimization strategy to more models (e.g., baichuan2).
45 | - [gptfast](https://github.com/MDK8888/GPTFast): applies a subset of the performance optimizations to all Huggingface models
46 | - [gpt-accelera](https://github.com/Edward-Sun/gpt-accelera): extends `gpt-fast` to SFT/RM/PPO training and batched inference to optimize the throughput
47 |
48 | ## Installation
49 | [Download PyTorch nightly](https://pytorch.org/get-started/locally/)
50 |
51 | Install required packages:
52 |
53 | ```bash
54 | pip install -r requirements.txt
55 | ```
56 |
57 | To download llama models, go to https://huggingface.co/meta-llama/Llama-2-7b and go through steps to obtain access.
58 | Then login with `huggingface-cli login`
59 |
60 |
61 |
62 | ## Downloading Weights
63 | Models tested/supported
64 | ```text
65 | tinyllamas/stories{15,42,100}
66 | openlm-research/open_llama_7b
67 | meta-llama/Llama-2-7b-chat-hf
68 | meta-llama/Llama-2-13b-chat-hf
69 | meta-llama/Llama-2-70b-chat-hf
70 | codellama/CodeLlama-7b-Python-hf
71 | codellama/CodeLlama-34b-Python-hf
72 | mistralai/Mistral-7B-v0.1
73 | mistralai/Mistral-7B-Instruct-v0.1
74 | mistralai/Mistral-7B-Instruct-v0.2
75 | meta-llama/Meta-Llama-3-8B
76 | meta-llama/Meta-Llama-3.1-8B
77 | meta-llama/Meta-Llama-3.1-70B
78 | meta-llama/Meta-Llama-3.1-405B
79 | ```
80 |
81 | For example, to convert Llama-2-7b-chat-hf
82 | ```bash
83 | export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
84 | ./scripts/prepare.sh $MODEL_REPO
85 | ```
86 |
87 | ## Benchmarks
88 | Benchmarks run on an 8xA100-80GB, power limited to 330W with a hybrid cube mesh topology. Note that all benchmarks are run at *batch size=1*, making the reported tokens/s numbers equivalent to "tokens/s/user". In addition, they are run with a very small prompt length (just 5 tokens).
89 |
90 | | Model | Technique | Tokens/Second | Memory Bandwidth (GB/s) |
91 | | -------- | ------- | ------ | ------ |
92 | | Llama-2-7B | Base | 104.9 | 1397.31 |
93 | | | 8-bit | 155.58 | 1069.20 |
94 | | | 4-bit (G=32) | 196.80 | 862.69 |
95 | | Llama-2-70B | Base | OOM ||
96 | | | 8-bit | 19.13 | 1322.58 |
97 | | | 4-bit (G=32) | 25.25 | 1097.66 |
98 | | Llama-3.1-8B | Base | 93.89 | 1410.76 |
99 | | | 8-bit | 137.64 | 1030.89 |
100 | | Llama-3.1-70B | Base | OOM ||
101 | | | 8-bit | 18.04 | 1253.78 |
102 |
103 | ### Speculative Sampling
104 | [Verifier: Llama-70B (int4), Draft: Llama-7B (int4)](./scripts/speculate_70B_int4.sh): 48.4 tok/s
105 |
106 | ### Tensor Parallelism
107 | | Model | Number of GPUs | Tokens/Second | Memory Bandwidth (GB/s) |
108 | | -------- | ------- | ------ | ------ |
109 | | Llama-2-7B | 1 | 104.9 | 1397.31 |
110 | | | 2 | 168.84 | 1181.99 |
111 | | | 4 | 254.02 | 955.83 |
112 | | | 8 | 328.43 | 704.10 |
113 | | Llama-2-70B | 1 | OOM | |
114 | | | 2 | 21.32 | 1481.87 |
115 | | | 4 | 38.01 | 1340.76 |
116 | | | 8 | 62.50 | 1135.29 |
117 | | Llama-3.1-8B | 1 | 93.83 | 1408.37 |
118 | | | 2 | 149.10 | 1197.32 |
119 | | | 4 | 217.21 | 986.32 |
120 | | | 8 | 276.01 | 772.60 |
121 | | Llama-3.1-70B | 1 | OOM | |
122 | | | 2 | 16.03 | 1130.81 |
123 | | | 4 | 37.45 | 1360.53 |
124 | | | 8 | 58.78 | 1129.61 |
125 |
126 | ### Tensor Parallelism + Quantization
127 | | Model | Technique | Tokens/Second | Memory Bandwidth (GB/s) |
128 | | -------- | ------- | ------ | ------ |
129 | | Llama-2-70B | Base | 62.50 | 1135.29 |
130 | | | 8-bit | 80.44 | 752.04 |
131 | | | 4-bit (G=32) | 90.77 | 548.10 |
132 | | Llama-3.1-70B | Base | 58.78 | 1129.61 |
133 | | | 8-bit | 75.58 | 726.57 |
134 | | Llama-3.1-405B | 8-bit | 15.60 | 815.87 |
135 |
136 | ### AMD
137 | Benchmarks run on one GCD of a MI-250x.
138 |
139 | | Model | Technique | Tokens/Second | Memory Bandwidth (GB/s) |
140 | | -------- | ------- | ------ | ------ |
141 | | Llama-2-7B | Base | 76.33 | 1028.70 |
142 | | | 8-bit | 101.86 | 700.06 |
143 |
144 | ## Generate Text
145 |
146 | Model definition in `model.py`, generation code in `generate.py`.
147 |
148 | ```bash
149 | python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model.pth --prompt "Hello, my name is"
150 | ```
151 |
152 | To squeeze out a little bit more performance, you can also compile the prefill with `--compile_prefill`. This will increase compilation times though.
153 |
154 | ## Quantization
155 | Choose device to use by
156 | ```bash
157 | # The current support devices: cuda, cpu
158 | export DEVICE=cuda
159 | ```
160 | ### Int8 Weight-Only Quantization
161 | To generate this version of the model
162 | ```bash
163 | # Spits out model at checkpoints/$MODEL_REPO/model_int8.pth
164 | python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int8
165 | ```
166 | To run with int8, just pass the int8 checkpoint to generate.py.
167 | ```bash
168 | python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model_int8.pth --device $DEVICE
169 | ```
170 |
171 | ### Int4 Weight-Only Quantization
172 | To generate int4 version of model
173 | ```bash
174 | # Spits out model at checkpoints/$MODEL_REPO/model_int4.g32.$DEVICE.pth
175 | python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4 --groupsize 32
176 | ```
177 |
178 | To run with int4, just pass the int4 checkpoint to generate.py.
179 | ```bash
180 | python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.pth --compile
181 | ```
182 |
183 | ## Speculative Sampling
184 | To generate with speculative sampling (DRAFT_MODEL_REPO should point to a smaller model compared with MODEL_REPO).
185 |
186 | In this example, the "smaller" model is just the int8 quantized version of the model.
187 | ```
188 | export DRAFT_MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
189 | python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model.pth --draft_checkpoint_path checkpoints/$DRAFT_MODEL_REPO/model_int8.pth
190 | ```
191 |
192 | Note: Running on an A100 80GB, albeit power-limited to 330 watts. Empirically, seems like peak bandwidth is about 1700 GB/s.
193 |
194 |
195 | ## Tensor Parallelism
196 | ```bash
197 | ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=2 generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model.pth
198 | ```
199 |
200 | ## Experimental
201 | ### Evaluation
202 | We use the EleutherAI evaluation harness to evaluate our model accuracy. To evaluate the accuracy, make sure the evaluation harness is installed and pass your model checkpoint and desired tasks to eval.py.
203 |
204 | ```bash
205 | python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --compile --tasks hellaswag winogrande
206 | ```
207 |
208 | Note: Generative tasks are currently not supported for gpt-fast
209 |
210 | Installation Instructions for the evaluation harness: https://github.com/EleutherAI/lm-evaluation-harness/tree/master#install
211 |
212 | ### GPTQ
213 | We have a pure pytorch implementation of GPTQ that utilizes torch._dynamo.export to access the model structure. You can generate a GPTQ quantized
214 | version of int4 quantization by using the same command to quantize it but adding 'gptq' to the quantization mode i.e.
215 | ```bash
216 | # Spits out model at checkpoints/$MODEL_REPO/model_int4-gptq.g32.pth
217 | python quantize.py --mode int4-gptq --calibration_tasks wikitext --calibration_seq_length 2048
218 | ```
219 |
220 | You can then eval or generate text with this model in the same way as above.
221 |
222 | ## License
223 |
224 | `gpt-fast` is released under the [BSD 3](https://github.com/meta-pytorch/gpt-fast/main/LICENSE) license.
225 |
226 | ## Acknowledgements
227 | Thanks to:
228 | * Lightning AI for supporting pytorch and work in flash attention, int8 quantization, and LoRA fine-tuning.
229 | * GGML for driving forward fast, on device inference of LLMs
230 | * Karpathy for spearheading simple, interpretable and fast LLM implementations
231 | * MLC-LLM for pushing 4-bit quantization performance on heterogeneous hardware
232 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | import sys
7 | import time
8 | from pathlib import Path
9 | from typing import Optional
10 |
11 | import torch
12 | import torch._dynamo.config
13 | import torch._inductor.config
14 |
15 | torch._dynamo.config.automatic_dynamic_shapes = True
16 | torch._inductor.config.triton.unique_kernel_names = True
17 | torch._inductor.config.epilogue_fusion = False
18 | torch._dynamo.config.cache_size_limit = 100000
19 |
20 | from tokenizer import get_tokenizer
21 |
22 | from model import Transformer
23 |
24 | try:
25 | import lm_eval
26 | lm_eval_available = True
27 | except:
28 | lm_eval_available = False
29 |
30 | from generate import _load_model, encode_tokens, model_forward
31 |
32 | if lm_eval_available:
33 | try: # lm_eval version 0.4
34 | from lm_eval.models.huggingface import HFLM as eval_wrapper
35 | from lm_eval.tasks import get_task_dict
36 | from lm_eval.evaluator import evaluate
37 | except: #lm_eval version 0.3
38 | from lm_eval import base
39 | from lm_eval import tasks
40 | from lm_eval import evaluator
41 | eval_wrapper=base.BaseLM
42 | get_task_dict=tasks.get_task_dict
43 | evaluate=evaluator.evaluate
44 |
45 |
46 | def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill(
47 | model: Transformer,
48 | prompt: torch.Tensor,
49 | max_new_tokens: int,
50 | max_seq_length: Optional[int] = None,
51 | ):
52 | """
53 | Sets up model cache and does some bookkeeping calculations for prompt, input_pos and max_seq_length
54 | that are needed for prefill or model_forward
55 |
56 | Args:
57 | model (LLaMA): The model whose cache gets set up
58 | prompt (torch.Tensor): Tensor of shape (T) with indices of the prompt sequence.
59 | max_new_tokens (int): The desired maximum number of new tokens that can be generated.
60 | max_seq_length (Optional[int], optional): The maximum sequence length allowed.
61 |
62 | Returns:
63 | seq (torch.Tensor): prompt but padded with zeros to size max_seq_length
64 | input_pos (torch.Tensor): tensor of integers in increasing order
65 | max_seq_length (int): The maximum sequence length allowed, updated based on other numbers
66 | """
67 | T = prompt.size(0)
68 | T_new = T + max_new_tokens
69 | if max_seq_length is None:
70 | max_seq_length = min(T_new, model.config.block_size)
71 |
72 | device, dtype = prompt.device, prompt.dtype
73 | # create an empty tensor of the expected final shape and fill in the current tokens
74 | empty = torch.empty(T_new, dtype=dtype, device=device)
75 | empty[:T] = prompt
76 | seq = empty
77 | input_pos = torch.arange(0, T, device=device)
78 |
79 | with torch.device(device):
80 | model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
81 |
82 | return seq, input_pos, max_seq_length
83 |
84 | class GPTFastEvalWrapper(eval_wrapper):
85 | """
86 | A wrapper class for GPTFast, providing integration with the lm-evaluation-harness library.
87 | """
88 | def __init__(
89 | self,
90 | model: Transformer,
91 | tokenizer,
92 | max_seq_length: Optional[int]=None,
93 | ):
94 | super().__init__()
95 | self._model = model
96 | self._tokenizer = tokenizer
97 | self._device = torch.device('cuda')
98 | self._max_seq_length = 2048 if max_seq_length is None else max_seq_length
99 |
100 | @property
101 | def eot_token_id(self):
102 | return self._tokenizer.eos_id()
103 |
104 | @property
105 | def max_length(self):
106 | return self._max_seq_length
107 |
108 | @property
109 | def max_gen_toks(self):
110 | return 50
111 |
112 | @property
113 | def batch_size(self):
114 | return 1
115 |
116 | @property
117 | def device(self):
118 | return self._device
119 |
120 | def tok_encode(self, string: str, **kwargs):
121 | encoded = encode_tokens(self._tokenizer,
122 | string, bos=True, device=self._device)
123 | # encoded is a pytorch tensor, but some internal logic in the
124 | # eval harness expects it to be a list instead
125 | # TODO: verify this for multi-batch as well
126 | encoded = encoded.tolist()
127 | return encoded
128 |
129 | def tok_decode(self, tokens):
130 | decoded = self._tokenizer.decode(tokens)
131 | return decoded
132 |
133 | def _model_call(self, inps):
134 | # TODO: make batches work
135 | inps = inps.squeeze(0)
136 |
137 | max_new_tokens = 1
138 | seq, input_pos, max_seq_length = \
139 | setup_cache_padded_seq_input_pos_max_seq_length_for_prefill(
140 | self._model,
141 | inps,
142 | max_new_tokens,
143 | self.max_length,
144 | )
145 | x = seq.index_select(0, input_pos).view(1, -1)
146 | logits = model_forward(self._model, x, input_pos)
147 | return logits
148 |
149 | def _model_generate(self, context, max_length, eos_token_id):
150 | raise Exception('unimplemented')
151 |
152 |
153 | @torch.no_grad()
154 | def eval(
155 | model: Transformer,
156 | tokenizer,
157 | tasks: list = ["hellaswag"],
158 | limit: Optional[int] = None,
159 | max_seq_length: Optional[int] = None,
160 | ) -> dict:
161 | """
162 | Evaluates a language model on a specified task using the lm-evaluation-harness library.
163 |
164 | Args:
165 | model (Transformer): The pre-trained language model to evaluate.
166 | tokenizer: The tokenizer to use for encoding/decoding text.
167 | tasks (list): The names of the evaluation tasks to perform.
168 | limit (Optional[int]): The maximum number of samples to evaluate (None for all available).
169 | max_seq_length (Optional[int]): The maximum sequence length allowed for input text.
170 |
171 | Returns:
172 | eval_results (dict): A dictionary of evaluation results for the specified task(s).
173 | """
174 | model_eval_wrapper = GPTFastEvalWrapper(
175 | model,
176 | tokenizer,
177 | max_seq_length,
178 | )
179 |
180 | try:
181 | lm_eval.tasks.initialize_tasks()
182 | except:
183 | pass
184 |
185 | if 'hendrycks_test' in tasks:
186 | tasks.remove('hendrycks_test')
187 | tasks += [x for x in lm_eval.tasks.hendrycks_test.create_all_tasks().keys()]
188 | task_dict = get_task_dict(tasks)
189 |
190 | eval_results = evaluate(
191 | model_eval_wrapper,
192 | task_dict,
193 | limit=limit,
194 | )
195 | return eval_results
196 |
197 |
198 | def main(
199 | checkpoint_path: Path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/lit_model.pth"),
200 | compile: bool = False,
201 | tasks: list = ["hellaswag"],
202 | limit: Optional[int] = None,
203 | max_seq_length: Optional[int] = None,
204 | ) -> None:
205 | """Evaluates model on a task from the `lm-evaluation-harness` library.
206 |
207 | Args:
208 | checkpoint_path (Path): The path to the model checkpoint file to load.
209 | compile (bool): Whether or not to compile the model for optimization.
210 | tasks (list): The names of the evaluation tasks to perform.
211 | limit (Optional[int]): The maximum number of samples to evaluate (None for all available).
212 | max_seq_length (Optional[int]): The maximum sequence length allowed for input text.
213 |
214 | """
215 |
216 | assert checkpoint_path.is_file(), checkpoint_path
217 |
218 | tokenizer_path = checkpoint_path.parent / "tokenizer.model"
219 | assert tokenizer_path.is_file(), str(tokenizer_path)
220 |
221 | device = 'cuda'
222 | precision = torch.bfloat16
223 |
224 | print("Loading model ...")
225 | t0 = time.time()
226 | model = _load_model(checkpoint_path, device, precision, False)
227 |
228 | torch.cuda.synchronize()
229 | print(f"Time to load model: {time.time() - t0:.02f} seconds.")
230 |
231 | model.eval()
232 |
233 | tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)
234 |
235 | torch.manual_seed(1234)
236 |
237 | if compile:
238 | global model_forward
239 | model_forward = torch.compile(model_forward, mode="reduce-overhead", dynamic=True, fullgraph=True)
240 | torch._inductor.config.coordinate_descent_tuning = True
241 |
242 | t1 = time.time()
243 | result = eval(
244 | model,
245 | tokenizer,
246 | tasks,
247 | limit,
248 | max_seq_length,
249 | )
250 | print(f"Time to run eval: {time.time() - t1:.02f} seconds.")
251 | print(f"For model {checkpoint_path}")
252 | for task, res in result["results"].items():
253 | print(f"{task}: {res}")
254 |
255 |
256 | if __name__ == '__main__':
257 | import argparse
258 | parser = argparse.ArgumentParser(description='Your CLI description.')
259 |
260 | parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/lit_model.pth"), help='Model checkpoint path.')
261 | parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
262 | parser.add_argument('--tasks', nargs='+', type=str, default=["hellaswag"], help='list of lm-eluther tasks to evaluate usage: --tasks task1 task2')
263 | parser.add_argument('--limit', type=int, default=None, help='number of samples to evalulate')
264 | parser.add_argument('--max_seq_length', type=int, default=None, help='maximum length sequence to evaluate')
265 |
266 | args = parser.parse_args()
267 | main(
268 | Path(args.checkpoint_path), args.compile, args.tasks, args.limit, args.max_seq_length,
269 | )
270 |
--------------------------------------------------------------------------------
/generate.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | import itertools
7 | import sys
8 | import time
9 | from pathlib import Path
10 | from typing import Optional, Tuple, Union
11 |
12 | import torch
13 | import torch._dynamo.config
14 | import torch._inductor.config
15 | from torch.nn.attention.flex_attention import BlockMask, create_block_mask
16 |
17 | def device_sync(device):
18 | if "cuda" in device:
19 | torch.cuda.synchronize(device)
20 | elif ("cpu" in device) or ("mps" in device):
21 | pass
22 | else:
23 | print(f"device={device} is not yet suppported")
24 |
25 |
26 | torch._inductor.config.coordinate_descent_tuning = True
27 | torch._inductor.config.triton.unique_kernel_names = True
28 | # Experimental features to reduce compilation times, will be on by default in future
29 | torch._inductor.config.fx_graph_cache = True
30 | torch._functorch.config.enable_autograd_cache = True
31 |
32 | default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
33 |
34 | create_block_mask = torch.compile(create_block_mask)
35 |
36 | # support running without installing as a package
37 | wd = Path(__file__).parent.parent.resolve()
38 | sys.path.append(str(wd))
39 |
40 | from model import Transformer
41 | from tokenizer import get_tokenizer
42 |
43 | def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization
44 | q = torch.empty_like(probs_sort).exponential_(1)
45 | return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
46 |
47 | def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
48 | logits = logits / max(temperature, 1e-5)
49 |
50 | if top_k is not None:
51 | v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
52 | pivot = v.select(-1, -1).unsqueeze(-1)
53 | logits = torch.where(logits < pivot, -float("Inf"), logits)
54 | probs = torch.nn.functional.softmax(logits, dim=-1)
55 | return probs
56 |
57 | def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
58 | probs = logits_to_probs(logits[:, -1], temperature, top_k)
59 | idx_next = multinomial_sample_one_no_sync(probs)
60 | return idx_next, probs
61 |
62 | def roundup(val, multiplier):
63 | return ((val - 1) // multiplier + 1) * multiplier
64 |
65 | def causal_mask(b, h, q, kv):
66 | return q >= kv
67 |
68 | def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor:
69 | # input_pos: [B, S]
70 | mask = create_block_mask(causal_mask, 1, 1, input_pos.shape[0], model.max_seq_length, device=x.device)
71 | logits = model(mask, x, input_pos)
72 | return sample(logits, **sampling_kwargs)[0]
73 |
74 | def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, block_mask: BlockMask, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
75 | # input_pos: [B, 1]
76 | assert input_pos.shape[-1] == 1
77 | block_index = input_pos // block_mask.BLOCK_SIZE[0]
78 | mask = block_mask[:, :, block_index]
79 | mask.mask_mod = block_mask.mask_mod
80 | mask.seq_lengths = (1, model.max_seq_length)
81 | logits = model(mask, x, input_pos)
82 | return sample(logits, **sampling_kwargs)
83 |
84 | def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs):
85 | block_mask = create_block_mask(causal_mask, 1, 1, model.max_seq_length, model.max_seq_length, device=cur_token.device)
86 | new_tokens, new_probs = [], []
87 | for i in range(num_new_tokens):
88 | next_token, next_prob = decode_one_token(
89 | model, cur_token, input_pos, block_mask, **sampling_kwargs
90 | )
91 | input_pos += 1
92 | new_tokens.append(next_token.clone())
93 | callback(new_tokens[-1])
94 | new_probs.append(next_prob.clone())
95 | cur_token = next_token.clone()
96 |
97 | return new_tokens, new_probs
98 |
99 |
100 | def model_forward(model, x, input_pos):
101 | return model(x, input_pos)
102 |
103 | def speculative_decode(
104 | model: Transformer,
105 | draft_model: Transformer,
106 | cur_token: torch.Tensor,
107 | input_pos: int,
108 | speculate_k: int,
109 | **sampling_kwargs
110 | ) -> torch.Tensor:
111 | # draft model inference sequentially
112 | device = cur_token.device
113 | orig_input_pos = torch.tensor([input_pos], dtype=torch.int64, device=cur_token.device)
114 | draft_tokens, draft_probs = decode_n_tokens(draft_model, cur_token.view(1, -1), orig_input_pos.clone(), speculate_k, **sampling_kwargs)
115 |
116 | draft_tokens = torch.cat(draft_tokens)
117 | # parallel inference on target model using draft tokens
118 | target_logits = model_forward(
119 | model,
120 | torch.cat([cur_token.view(1), draft_tokens]).view(1, -1),
121 | torch.arange(input_pos, input_pos + speculate_k + 1, device=cur_token.device)
122 | )
123 | target_probs = logits_to_probs(target_logits[0], **sampling_kwargs)
124 | draft_probs = torch.stack(draft_probs)
125 | # q: target prob, p: draft prob
126 | # q >= p: always accept draft token
127 | # q < p: q/p prob to accept draft token
128 | p = draft_probs[torch.arange(0, speculate_k, device=device), draft_tokens]
129 | q = target_probs[torch.arange(0, speculate_k, device=device), draft_tokens]
130 | accept_draft_prob = torch.minimum(torch.ones(()), q[:speculate_k]/ p)
131 | rejected_locations = (torch.rand_like(accept_draft_prob) > accept_draft_prob).nonzero()
132 |
133 | if rejected_locations.shape[0] == 0: # All draft tokens have been accepted
134 | accept_length = speculate_k + 1
135 | last_token = multinomial_sample_one_no_sync(target_probs[-1])
136 | # fill last token into draft model
137 | model_forward(
138 | draft_model,
139 | draft_tokens[-1].view(1, -1),
140 | orig_input_pos + speculate_k,
141 | )
142 | return torch.cat([draft_tokens, last_token])
143 | else:
144 | accept_length = rejected_locations[0].item()
145 | p = draft_probs[accept_length]
146 | q = target_probs[accept_length]
147 | new = q - p
148 | new = torch.where(new > 0, new, 0.0)
149 | new = new / new.sum()
150 | next_token = multinomial_sample_one_no_sync(new)
151 | return torch.cat([draft_tokens[:accept_length], next_token])
152 |
153 | @torch.no_grad()
154 | def generate(
155 | model: Transformer,
156 | prompt: torch.Tensor,
157 | max_new_tokens: int,
158 | batch_size: int,
159 | *,
160 | interactive: bool,
161 | draft_model: Transformer,
162 | speculate_k: Optional[int] = 8,
163 | callback = lambda x: x,
164 | **sampling_kwargs
165 | ) -> torch.Tensor:
166 | """
167 | Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
168 | """
169 |
170 | is_speculative = draft_model is not None
171 | # create an empty tensor of the expected final shape and fill in the current tokens
172 | T = prompt.size(-1)
173 | T_new = T + max_new_tokens
174 | if interactive:
175 | max_seq_length = 350
176 | else:
177 | max_seq_length = min(T_new, model.config.block_size)
178 |
179 | device, dtype = prompt.device, prompt.dtype
180 | max_seq_length = max_seq_length + speculate_k + 1 if is_speculative else max_seq_length
181 | with torch.device(device):
182 | model.setup_caches(max_batch_size=batch_size, max_seq_length=max_seq_length)
183 | if is_speculative and draft_model is not model:
184 | draft_model.setup_caches(max_batch_size=batch_size, max_seq_length=max_seq_length)
185 |
186 | # create an empty tensor of the expected final shape and fill in the current tokens
187 | empty = torch.empty(batch_size, T_new, dtype=dtype, device=device)
188 | # We are just making the same prompt for every batch
189 | prompt = prompt.view(1, -1).repeat(batch_size, 1)
190 | empty[:, :T] = prompt
191 | seq = empty
192 | input_pos = torch.arange(0, T, device=device)
193 |
194 | next_token = prefill(model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs).clone()
195 | if is_speculative:
196 | prefill(draft_model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs)
197 | seq[:, T] = next_token.squeeze()
198 |
199 | input_pos = torch.tensor([T], device=device, dtype=torch.int)
200 | accept_counts = [0] * (speculate_k + 1)
201 |
202 | if is_speculative:
203 | input_pos = input_pos.item() # for speculative decoding easier to keep on host
204 | while input_pos < T_new - 1:
205 | cur_token = next_token.view(())
206 |
207 | next_tokens = speculative_decode(
208 | model, draft_model, cur_token, input_pos, speculate_k, **sampling_kwargs
209 | )
210 |
211 | accept_counts[len(next_tokens) - 1] += 1
212 | num_added = min(T_new - input_pos - 1, len(next_tokens))
213 | seq[input_pos + 1 : input_pos + num_added + 1] = next_tokens[: num_added]
214 | for i in next_tokens[: num_added,]:
215 | callback(i)
216 | input_pos = input_pos + num_added
217 | next_token = next_tokens[-1]
218 | else:
219 | generated_tokens, _ = decode_n_tokens(model, next_token.view(batch_size, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
220 | seq[:, T + 1:] = torch.cat(generated_tokens, dim=-1)
221 |
222 | generate_stats = {
223 | 'accept_counts': accept_counts
224 | }
225 | return seq, generate_stats
226 |
227 | def encode_tokens(tokenizer, string, bos=True, device=default_device):
228 | tokens = tokenizer.encode(string)
229 | if bos:
230 | tokens = [tokenizer.bos_id()] + tokens
231 | return torch.tensor(tokens, dtype=torch.int, device=device)
232 |
233 | def _load_model(checkpoint_path, device, precision, use_tp):
234 | use_cuda = 'cuda' in device
235 | with torch.device('meta'):
236 | model = Transformer.from_name(checkpoint_path.parent.name)
237 |
238 | if "int8" in str(checkpoint_path):
239 | print("Using int8 weight-only quantization!")
240 | from quantize import WeightOnlyInt8QuantHandler
241 | simple_quantizer = WeightOnlyInt8QuantHandler(model)
242 | model = simple_quantizer.convert_for_runtime()
243 |
244 | if "int4" in str(checkpoint_path):
245 | print("Using int4 weight-only quantization!")
246 | path_comps = checkpoint_path.name.split(".")
247 | groupsize = int(path_comps[-2][1:])
248 | from quantize import WeightOnlyInt4QuantHandler
249 | simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
250 | model = simple_quantizer.convert_for_runtime()
251 |
252 | checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
253 | if "model" in checkpoint and "stories" in str(checkpoint_path):
254 | checkpoint = checkpoint["model"]
255 | model.load_state_dict(checkpoint, assign=True)
256 |
257 | if use_tp:
258 | from tp import apply_tp
259 | print("Applying tensor parallel to model ...")
260 | apply_tp(model)
261 |
262 | model = model.to(device=device, dtype=precision)
263 | return model.eval()
264 |
265 | def _get_model_size(model):
266 | model_size = 0
267 | params = 0
268 | for name, child in model.named_children():
269 | if not isinstance(child, torch.nn.Embedding):
270 | model_size += sum(
271 | [
272 | p.numel() * p.dtype.itemsize
273 | for p in itertools.chain(child.parameters(), child.buffers())
274 | ]
275 | )
276 | params += sum(
277 | [
278 | p.numel()
279 | for p in itertools.chain(child.parameters(), child.buffers())
280 | ]
281 | )
282 | return model_size, params
283 |
284 | B_INST, E_INST = "[INST]", "[/INST]"
285 |
286 | def main(
287 | prompt: Union[int, str] = "Hello, my name is",
288 | interactive: bool = False,
289 | num_samples: int = 5,
290 | max_new_tokens: int = 100,
291 | batch_size: int = 1,
292 | top_k: int = 200,
293 | temperature: float = 0.8,
294 | checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
295 | compile: bool = True,
296 | compile_prefill: bool = False,
297 | profile: Optional[Path] = None,
298 | draft_checkpoint_path: Optional[Path] = None,
299 | speculate_k: int = 5,
300 | device=default_device,
301 | ) -> None:
302 | """Generates text samples based on a pre-trained Transformer model and tokenizer.
303 | """
304 | assert checkpoint_path.is_file(), checkpoint_path
305 |
306 | tokenizer_path = checkpoint_path.parent / "tokenizer.model"
307 | assert tokenizer_path.is_file(), str(tokenizer_path)
308 |
309 | global print
310 | from tp import maybe_init_dist
311 | rank = maybe_init_dist()
312 | use_tp = rank is not None
313 | if use_tp:
314 | if rank != 0:
315 | # only print on rank 0
316 | print = lambda *args, **kwargs: None
317 |
318 | print(f"Using device={device}")
319 | precision = torch.bfloat16
320 | is_speculative = draft_checkpoint_path is not None
321 | is_chat = "chat" in str(checkpoint_path)
322 |
323 | print("Loading model ...")
324 | t0 = time.time()
325 | model = _load_model(checkpoint_path, device, precision, use_tp)
326 |
327 | if is_speculative:
328 | draft_model = _load_model(draft_checkpoint_path, device, precision, use_tp)
329 | else:
330 | draft_model = None
331 |
332 | device_sync(device=device) # MKG
333 | print(f"Time to load model: {time.time() - t0:.02f} seconds")
334 |
335 | tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)
336 |
337 | if isinstance(prompt, str):
338 | encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
339 | else:
340 | # generate a fully synthetic prompt
341 | encoded = torch.randint(0, 1024, (prompt,), device=device, dtype=torch.int64)
342 | prompt_length = encoded.size(-1)
343 |
344 | torch.manual_seed(1234)
345 | model_size, params = _get_model_size(model)
346 | if compile:
347 | if is_speculative and use_tp: # and ("cuda" in device):
348 | torch._inductor.config.triton.cudagraph_trees = False # Bug with cudagraph trees in this case
349 |
350 | if is_speculative:
351 | global model_forward, logits_to_prob
352 | model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True)
353 |
354 | global decode_one_token, prefill
355 | decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True)
356 |
357 | # Uncomment to squeeze more perf out of prefill
358 | if compile_prefill:
359 | prefill = torch.compile(prefill, fullgraph=True, dynamic=True)
360 |
361 |
362 | aggregate_metrics = {
363 | 'tokens_per_sec': [],
364 | 'accept_counts': [],
365 | }
366 | start = -1 if compile else 0
367 |
368 | for i in range(start, num_samples):
369 | device_sync(device=device) # MKG
370 | if i >= 0 and interactive:
371 | prompt = input("What is your prompt? ")
372 | if is_chat:
373 | prompt = f"{B_INST} {prompt.strip()} {E_INST}"
374 | encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
375 |
376 | if interactive and i >= 0:
377 | buffer = []
378 | period_id = tokenizer.encode('.')[0]
379 | done_generating = False
380 | def callback(x):
381 | nonlocal done_generating
382 | if done_generating:
383 | return
384 | buffer.append(tokenizer.decode([period_id] + x.tolist())[1:])
385 | if x.item() == tokenizer.eos_id():
386 | done_generating = True
387 | if len(buffer) == 4 or done_generating:
388 | print(''.join(buffer), end='', flush=True)
389 | buffer.clear()
390 | # print(, end='', flush=True)
391 | else:
392 | callback = lambda x : x
393 | t0 = time.perf_counter()
394 | import contextlib
395 | if (i != num_samples - 1 or not profile) or (use_tp and rank != 0):
396 | prof = contextlib.nullcontext()
397 | else:
398 | torch.profiler._utils._init_for_cuda_graphs()
399 | prof = torch.profiler.profile()
400 | with prof:
401 | y, metrics = generate(
402 | model,
403 | encoded,
404 | max_new_tokens,
405 | batch_size=batch_size,
406 | draft_model=draft_model,
407 | speculate_k=speculate_k,
408 | interactive=interactive,
409 | callback=callback,
410 | temperature=temperature,
411 | top_k=top_k,
412 | )
413 | aggregate_metrics['accept_counts'].append(metrics['accept_counts'])
414 | if i == -1:
415 | print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
416 | continue
417 | if hasattr(prof, "export_chrome_trace"):
418 | if use_tp:
419 | prof.export_chrome_trace(f"{profile}_rank_{rank}.json")
420 | else:
421 | prof.export_chrome_trace(f"{profile}.json")
422 | device_sync(device=device) # MKG
423 | t = time.perf_counter() - t0
424 |
425 | if not interactive:
426 | # Just displaying the first generation
427 | if batch_size > 1:
428 | print("Only displaying the first generation of the batch")
429 | print(tokenizer.decode(y[0].tolist()))
430 | else:
431 | print()
432 | tokens_generated = y.size(-1) - prompt_length
433 | generated_tokens_sec = tokens_generated / t
434 | aggregate_metrics['tokens_per_sec'].append(generated_tokens_sec)
435 | print(f"Time for inference {i + 1}: {t:.02f} sec total, {generated_tokens_sec:.02f} tokens/sec")
436 | print(f"Bandwidth achieved: {model_size * generated_tokens_sec / 1e9:.02f} GB/s")
437 | total_tokens_sec = y.numel() / t
438 | print(f"FLOPS achieved: {params * total_tokens_sec * 2 / 1e12:.02f} TF/s")
439 | print()
440 | print("==========")
441 | if is_speculative:
442 | counts_aggregated = [sum(i) for i in zip(*aggregate_metrics['accept_counts'])]
443 | acceptance_probs = [i/sum(counts_aggregated) for i in counts_aggregated]
444 | print(f"Acceptance probs: {acceptance_probs}")
445 | print(f"Mean Accepted: {sum([idx * i for idx, i in enumerate(counts_aggregated)])/sum(counts_aggregated)}")
446 |
447 | print(f"Batch Size: {batch_size}")
448 | print(f"Prompt Length: {prompt_length}")
449 | print(f"Generated tokens: {max_new_tokens}")
450 | print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}")
451 | print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
452 |
453 |
454 | if __name__ == '__main__':
455 | import argparse
456 | parser = argparse.ArgumentParser(description='Your CLI description.')
457 |
458 | def int_or_str(x):
459 | try:
460 | return int(x)
461 | except:
462 | return x
463 |
464 | parser.add_argument('--prompt', type=int_or_str, default="Hello, my name is", help="Input prompt. If it's an integer, will instead generate a synthetic prompt.")
465 | parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode')
466 | parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.')
467 | parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.')
468 | parser.add_argument('--batch_size', type=int, default=1, help='Batch size to benchmark with')
469 | parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
470 | parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
471 | parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
472 | parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
473 | parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)')
474 | parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
475 | parser.add_argument('--speculate_k', type=int, default=5, help='Speculative execution depth.')
476 | parser.add_argument('--draft_checkpoint_path', type=Path, default=None, help='Draft checkpoint path.')
477 | parser.add_argument('--device', type=str, default=default_device, help='Device to use')
478 |
479 | args = parser.parse_args()
480 | main(
481 | args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.batch_size, args.top_k,
482 | args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.draft_checkpoint_path,
483 | args.speculate_k, args.device
484 | )
485 |
--------------------------------------------------------------------------------
/mixtral-moe/README.md:
--------------------------------------------------------------------------------
1 | # Mixtral 8x7B
2 | [Mixtral 8x7B](https://mistral.ai/news/mixtral-of-experts/) is a high-quality sparse mixture of experts (MoE) model that matches or beats GPT3.5 on most benchmarks. This repro is a simple and efficient PyTorch native implementation of Mixtral 8x7B.
3 |
4 | ## Downloading Weights
5 |
6 | ```bash
7 | export MODEL_REPO=mistralai/Mixtral-8x7B-v0.1
8 | python scripts/download.py --repo_id $MODEL_REPO
9 | python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$MODEL_REPO
10 | ```
11 |
12 | ## Benchmarks
13 | Benchmarks run on an 8xA100-80GB, power limited to 330W with a hybrid cube mesh topology. Note that all benchmarks are run at *batch size=1*, making the reported tokens/s numbers equivalent to "tokens/s/user". In addition, they are run with a very small prompt length (just 5 tokens).
14 |
15 | | | 1 GPU | 2 GPU | 4 GPU | 8 GPU |
16 | |------------------|---------|-----------|--------|------------|
17 | |baseline(bfloat16)| OOM | 96.67 | 155.35 | 227.82 |
18 | | int8 | 97.92 | 155.03 | 216.87 | 279.35 |
19 |
20 |
21 | ## Generate Text
22 |
23 | Model definition in `model.py`, generation code in `generate.py`.
24 |
25 | ```bash
26 | python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model.pth --prompt "Hello, my name is"
27 | ```
28 |
29 | To squeeze out a little bit more performance, you can also compile the prefill with `--compile_prefill`. This will increase compilation times though.
30 |
31 | ## Quantization
32 | ### Int8 Weight-Only Quantization
33 | To generate this version of the model
34 | ```bash
35 | # Spits out model at checkpoints/$MODEL_REPO/model_int8.pth
36 | python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int8
37 | ```
38 | To run with int8, just pass the int8 checkpoint to generate.py.
39 | ```bash
40 | python generate.py --compile --compile_prefill --checkpoint_path checkpoints/$MODEL_REPO/model_int8.pth
41 | ```
42 |
43 | ## Tensor Parallelism
44 | ```bash
45 | ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=8 generate.py --compile --compile_prefill --checkpoint_path checkpoints/$MODEL_REPO/model.pth
46 | ```
47 |
--------------------------------------------------------------------------------
/mixtral-moe/generate.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | import itertools
7 | import sys
8 | import time
9 | from pathlib import Path
10 | from typing import Optional, Tuple
11 |
12 | import torch
13 | import torch._dynamo.config
14 | import torch._inductor.config
15 |
16 | def device_sync(device):
17 | if "cuda" in device:
18 | torch.cuda.synchronize(device)
19 | elif "cpu" in device:
20 | pass
21 | else:
22 | print(f"device={device} is not yet suppported")
23 |
24 |
25 | torch._inductor.config.coordinate_descent_tuning = True
26 | torch._inductor.config.triton.unique_kernel_names = True
27 | torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
28 |
29 |
30 | # support running without installing as a package
31 | wd = Path(__file__).parent.parent.resolve()
32 | sys.path.append(str(wd))
33 |
34 | from sentencepiece import SentencePieceProcessor
35 |
36 | from model import Transformer
37 | from tp import maybe_init_dist
38 |
39 |
40 | def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization
41 | q = torch.empty_like(probs_sort).exponential_(1)
42 | return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
43 |
44 | def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
45 | logits = logits / max(temperature, 1e-5)
46 |
47 | if top_k is not None:
48 | v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
49 | pivot = v.select(-1, -1).unsqueeze(-1)
50 | logits = torch.where(logits < pivot, -float("Inf"), logits)
51 | probs = torch.nn.functional.softmax(logits, dim=-1)
52 | return probs
53 |
54 | def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
55 | probs = logits_to_probs(logits[0, -1], temperature, top_k)
56 | idx_next = multinomial_sample_one_no_sync(probs)
57 | return idx_next, probs
58 |
59 | def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor:
60 | # input_pos: [B, S]
61 | logits = model(x, input_pos)
62 | return sample(logits, **sampling_kwargs)[0]
63 |
64 | def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
65 | # input_pos: [B, 1]
66 | assert input_pos.shape[-1] == 1
67 | logits = model(x, input_pos)
68 | return sample(logits, **sampling_kwargs)
69 |
70 | def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs):
71 | new_tokens, new_probs = [], []
72 | for i in range(num_new_tokens):
73 | with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here
74 | next_token, next_prob = decode_one_token(
75 | model, cur_token, input_pos, **sampling_kwargs
76 | )
77 | input_pos += 1
78 | new_tokens.append(next_token.clone())
79 | callback(new_tokens[-1])
80 | new_probs.append(next_prob.clone())
81 | cur_token = next_token.view(1, -1)
82 |
83 | return new_tokens, new_probs
84 |
85 |
86 | def model_forward(model, x, input_pos):
87 | return model(x, input_pos)
88 |
89 | @torch.no_grad()
90 | def generate(
91 | model: Transformer,
92 | prompt: torch.Tensor,
93 | max_new_tokens: int,
94 | *,
95 | interactive: bool,
96 | callback = lambda x: x,
97 | **sampling_kwargs
98 | ) -> torch.Tensor:
99 | """
100 | Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
101 | """
102 |
103 | # create an empty tensor of the expected final shape and fill in the current tokens
104 | T = prompt.size(0)
105 | T_new = T + max_new_tokens
106 | if interactive:
107 | max_seq_length = 350
108 | else:
109 | max_seq_length = min(T_new, model.config.block_size)
110 |
111 | device, dtype = prompt.device, prompt.dtype
112 | with torch.device(device):
113 | model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
114 |
115 | # create an empty tensor of the expected final shape and fill in the current tokens
116 | empty = torch.empty(T_new, dtype=dtype, device=device)
117 | empty[:T] = prompt
118 | seq = empty
119 | input_pos = torch.arange(0, T, device=device)
120 |
121 | next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs)
122 | seq[T] = next_token
123 |
124 | input_pos = torch.tensor([T], device=device, dtype=torch.int)
125 |
126 | generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
127 | seq[T + 1:] = torch.cat(generated_tokens)
128 |
129 | return seq
130 |
131 | def encode_tokens(tokenizer, string, bos=True, device='cuda'):
132 | tokens = tokenizer.encode(string)
133 | if bos:
134 | tokens = [tokenizer.bos_id()] + tokens
135 | return torch.tensor(tokens, dtype=torch.int, device=device)
136 |
137 | def _load_model(checkpoint_path, device, precision, use_tp):
138 | with torch.device('meta'):
139 | model = Transformer.from_name(checkpoint_path.parent.name)
140 |
141 | if "int8" in str(checkpoint_path):
142 | print("Using int8 weight-only quantization!")
143 | from quantize import WeightOnlyBit8QuantHandler
144 | simple_quantizer = WeightOnlyBit8QuantHandler(model, torch.int8)
145 | model = simple_quantizer.convert_for_runtime()
146 |
147 | checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
148 | model.load_state_dict(checkpoint, assign=True)
149 |
150 | if use_tp:
151 | from tp import apply_tp
152 | print("Applying tensor parallel to model ...")
153 | apply_tp(model)
154 |
155 | model = model.to(device=device, dtype=precision)
156 | return model.eval()
157 |
158 | B_INST, E_INST = "[INST]", "[/INST]"
159 |
160 | def main(
161 | prompt: str = "Hello, my name is",
162 | interactive: bool = False,
163 | num_samples: int = 5,
164 | max_new_tokens: int = 100,
165 | top_k: int = 200,
166 | temperature: float = 0.8,
167 | checkpoint_path: Path = Path("checkpoints/mistralai/Mixtral-8x7B-v0.1/model.pth"),
168 | compile: bool = True,
169 | compile_prefill: bool = False,
170 | profile: Optional[Path] = None,
171 | device='cuda',
172 | ) -> None:
173 | """Generates text samples based on a pre-trained Transformer model and tokenizer.
174 | """
175 | assert checkpoint_path.is_file(), checkpoint_path
176 |
177 | tokenizer_path = checkpoint_path.parent / "tokenizer.model"
178 | assert tokenizer_path.is_file(), str(tokenizer_path)
179 |
180 | global print
181 | rank = maybe_init_dist()
182 | use_tp = rank is not None
183 | if use_tp:
184 | if rank != 0:
185 | # only print on rank 0
186 | print = lambda *args, **kwargs: None
187 |
188 | print(f"Using device={device}")
189 | precision = torch.bfloat16
190 | is_chat = "chat" in str(checkpoint_path)
191 |
192 | print("Loading model ...")
193 | t0 = time.time()
194 | model = _load_model(checkpoint_path, device, precision, use_tp)
195 |
196 | device_sync(device=device) # MKG
197 | print(f"Time to load model: {time.time() - t0:.02f} seconds")
198 |
199 | tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
200 | encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
201 | prompt_length = encoded.size(0)
202 |
203 | torch.manual_seed(1234)
204 | model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())])
205 | if compile:
206 | torch._inductor.config.assert_indirect_indexing = False
207 |
208 | global decode_one_token, prefill
209 | decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True)
210 |
211 | # Uncomment to squeeze more perf out of prefill
212 | if args.compile_prefill:
213 | prefill = torch.compile(prefill, fullgraph=True, dynamic=True)
214 |
215 |
216 | aggregate_metrics = {
217 | 'tokens_per_sec': [],
218 | }
219 | start = -1 if compile else 0
220 |
221 | for i in range(start, num_samples):
222 | device_sync(device=device) # MKG
223 | if i >= 0 and interactive:
224 | prompt = input("What is your prompt? ")
225 | if is_chat:
226 | prompt = f"{B_INST} {prompt.strip()} {E_INST}"
227 | encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
228 |
229 | if interactive and i >= 0:
230 | buffer = []
231 | period_id = tokenizer.encode('.')[0]
232 | done_generating = False
233 | def callback(x):
234 | nonlocal done_generating
235 | if done_generating:
236 | return
237 | buffer.append(tokenizer.decode([period_id] + x.tolist())[1:])
238 | if x.item() == tokenizer.eos_id():
239 | done_generating = True
240 | if len(buffer) == 4 or done_generating:
241 | print(''.join(buffer), end='', flush=True)
242 | buffer.clear()
243 | # print(, end='', flush=True)
244 | else:
245 | callback = lambda x : x
246 | t0 = time.perf_counter()
247 | import contextlib
248 | if (i != num_samples - 1 or not profile) or (use_tp and rank != 0):
249 | prof = contextlib.nullcontext()
250 | else:
251 | torch.profiler._utils._init_for_cuda_graphs()
252 | prof = torch.profiler.profile()
253 | with prof:
254 | y = generate(
255 | model,
256 | encoded,
257 | max_new_tokens,
258 | interactive=interactive,
259 | callback=callback,
260 | temperature=temperature,
261 | top_k=top_k,
262 | )
263 | if i == -1:
264 | print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
265 | continue
266 | if hasattr(prof, "export_chrome_trace"):
267 | if use_tp:
268 | prof.export_chrome_trace(f"{profile}_rank_{rank}.json")
269 | else:
270 | prof.export_chrome_trace(f"{profile}.json")
271 | device_sync(device=device) # MKG
272 | t = time.perf_counter() - t0
273 |
274 | if not interactive:
275 | print(tokenizer.decode(y.tolist()))
276 | else:
277 | print()
278 | tokens_generated = y.size(0) - prompt_length
279 | tokens_sec = tokens_generated / t
280 | aggregate_metrics['tokens_per_sec'].append(tokens_sec)
281 | print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec")
282 | print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
283 |
284 | print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}")
285 | print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
286 |
287 |
288 | if __name__ == '__main__':
289 | import argparse
290 | parser = argparse.ArgumentParser(description='Your CLI description.')
291 |
292 | parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.')
293 | parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode')
294 | parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.')
295 | parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.')
296 | parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
297 | parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
298 | parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
299 | parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
300 | parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)')
301 | parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
302 | parser.add_argument('--device', type=str, default="cuda", help='device to use')
303 |
304 | args = parser.parse_args()
305 | main(
306 | args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
307 | args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.device
308 | )
309 |
--------------------------------------------------------------------------------
/mixtral-moe/model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | from dataclasses import dataclass
7 | from typing import Optional
8 |
9 | import torch
10 | import torch.nn as nn
11 | from torch import Tensor
12 | from torch.nn import functional as F
13 |
14 |
15 | def find_multiple(n: int, k: int) -> int:
16 | if n % k == 0:
17 | return n
18 | return n + k - (n % k)
19 |
20 | @dataclass
21 | class ModelArgs:
22 | block_size: int = 2048
23 | vocab_size: int = 32000
24 | n_layer: int = 32
25 | n_head: int = 32
26 | dim: int = 4096
27 | intermediate_size: int = None
28 | n_local_heads: int = -1
29 | head_dim: int = 64
30 | rope_base: float = 10000
31 | norm_eps: float = 1e-5
32 | num_experts: int = 8
33 | num_activated_experts: int = 2
34 |
35 | def __post_init__(self):
36 | if self.n_local_heads == -1:
37 | self.n_local_heads = self.n_head
38 | if self.intermediate_size is None:
39 | hidden_dim = 4 * self.dim
40 | n_hidden = int(2 * hidden_dim / 3)
41 | self.intermediate_size = find_multiple(n_hidden, 256)
42 | self.head_dim = self.dim // self.n_head
43 |
44 | @classmethod
45 | def from_name(cls, name: str):
46 | if name in transformer_configs:
47 | return cls(**transformer_configs[name])
48 | # fuzzy search
49 | config = [config for config in transformer_configs if config in str(name).upper() or config in str(name)]
50 | assert len(config) == 1, name
51 | return cls(**transformer_configs[config[0]])
52 |
53 |
54 | transformer_configs = {
55 | "Mixtral-8x7B-v0.1": dict(block_size=32768, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, rope_base=1000000.0, num_experts=8, num_activated_experts=2),
56 | }
57 |
58 | class KVCache(nn.Module):
59 | def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16):
60 | super().__init__()
61 | cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
62 | self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
63 | self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))
64 |
65 | def update(self, input_pos, k_val, v_val):
66 | # input_pos: [S], k_val: [B, H, S, D]
67 | assert input_pos.shape[0] == k_val.shape[2]
68 |
69 | k_out = self.k_cache
70 | v_out = self.v_cache
71 | k_out[:, :, input_pos] = k_val
72 | v_out[:, :, input_pos] = v_val
73 |
74 | return k_out, v_out
75 |
76 | class Transformer(nn.Module):
77 | def __init__(self, config: ModelArgs) -> None:
78 | super().__init__()
79 | self.config = config
80 |
81 | self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
82 | self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
83 | self.norm = RMSNorm(config.dim, eps=config.norm_eps)
84 | self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
85 |
86 | self.freqs_cis: Optional[Tensor] = None
87 | self.mask_cache: Optional[Tensor] = None
88 | self.max_batch_size = -1
89 | self.max_seq_length = -1
90 |
91 | def setup_caches(self, max_batch_size, max_seq_length):
92 | if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
93 | return
94 | head_dim = self.config.dim // self.config.n_head
95 | max_seq_length = find_multiple(max_seq_length, 8)
96 | self.max_seq_length = max_seq_length
97 | self.max_batch_size = max_batch_size
98 | for b in self.layers:
99 | b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim)
100 |
101 | self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base)
102 | self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))
103 |
104 | def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
105 | assert self.freqs_cis is not None, "Caches must be initialized first"
106 | mask = self.causal_mask[None, None, input_pos]
107 | freqs_cis = self.freqs_cis[input_pos]
108 | x = self.tok_embeddings(idx)
109 |
110 | for i, layer in enumerate(self.layers):
111 | x = layer(x, input_pos, freqs_cis, mask)
112 | x = self.norm(x)
113 | logits = self.output(x)
114 | return logits
115 |
116 | @classmethod
117 | def from_name(cls, name: str):
118 | return cls(ModelArgs.from_name(name))
119 |
120 |
121 | class TransformerBlock(nn.Module):
122 | def __init__(self, config: ModelArgs) -> None:
123 | super().__init__()
124 | self.attention = Attention(config)
125 | self.block_sparse_moe = MOEFeedForward(config)
126 | self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
127 | self.attention_norm = RMSNorm(config.dim, config.norm_eps)
128 |
129 | def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
130 | h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
131 | out = h + self.block_sparse_moe(self.ffn_norm(h))
132 | return out
133 |
134 |
135 | class Attention(nn.Module):
136 | def __init__(self, config: ModelArgs):
137 | super().__init__()
138 | assert config.dim % config.n_head == 0
139 |
140 | total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
141 | # key, query, value projections for all heads, but in a batch
142 | self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
143 | self.wo = nn.Linear(config.dim, config.dim, bias=False)
144 | self.kv_cache = None
145 |
146 | self.n_head = config.n_head
147 | self.head_dim = config.head_dim
148 | self.n_local_heads = config.n_local_heads
149 | self.dim = config.dim
150 | self._register_load_state_dict_pre_hook(self.load_hook)
151 |
152 | def load_hook(self, state_dict, prefix, *args):
153 | if prefix + "wq.weight" in state_dict:
154 | wq = state_dict.pop(prefix + "wq.weight")
155 | wk = state_dict.pop(prefix + "wk.weight")
156 | wv = state_dict.pop(prefix + "wv.weight")
157 | state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
158 |
159 | def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
160 | bsz, seqlen, _ = x.shape
161 |
162 | kv_size = self.n_local_heads * self.head_dim
163 | q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
164 |
165 | q = q.view(bsz, seqlen, self.n_head, self.head_dim)
166 | k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
167 | v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
168 |
169 | q = apply_rotary_emb(q, freqs_cis)
170 | k = apply_rotary_emb(k, freqs_cis)
171 |
172 | q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
173 |
174 | if self.kv_cache is not None:
175 | k, v = self.kv_cache.update(input_pos, k, v)
176 |
177 | k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
178 | v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
179 | y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
180 |
181 | y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
182 |
183 | y = self.wo(y)
184 | return y
185 |
186 |
187 | class ConditionalFeedForward(nn.Module):
188 | def __init__(self, config):
189 | super().__init__()
190 | self.w1 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim))
191 | self.w2 = nn.Parameter(torch.empty(config.num_experts, config.dim, config.intermediate_size))
192 | self.w3 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim))
193 |
194 | def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor:
195 | w1_weights = self.w1[expert_indices] # [T, A, D, D]
196 | w3_weights = self.w3[expert_indices] # [T, A, D, D]
197 | w2_weights = self.w2[expert_indices] # [T, A, D, D]
198 | x1 = F.silu(torch.einsum('ti,taoi -> tao', x, w1_weights))
199 | x3 = torch.einsum('ti, taoi -> tao', x, w3_weights)
200 | expert_outs = torch.einsum('tao, taio -> tai', (x1 * x3), w2_weights)
201 | return expert_outs
202 |
203 |
204 | class MOEFeedForward(nn.Module):
205 | def __init__(self, config) -> None:
206 | super().__init__()
207 | self.gate = nn.Linear(config.dim, config.num_experts, bias=False)
208 | self.cond_ffn = ConditionalFeedForward(config)
209 | self.dim = config.dim
210 | self.num_activated_experts = config.num_activated_experts
211 | def forward(self, x: Tensor) -> Tensor:
212 | x = x.view(-1, self.dim)
213 | # T = num_tokens, E = num_experts, D = hidden dim, A = activated experts
214 | # x: [T, D]
215 | scores = self.gate(x) # [T, E]
216 | expert_weights = F.softmax(scores, dim=-1)
217 | expert_weights, expert_indices = torch.topk(expert_weights, self.num_activated_experts, dim=-1) # [T, A], [T, A]
218 | expert_weights /= expert_weights.sum(dim=-1, keepdim=True) # [T, A]
219 | expert_outs = self.cond_ffn(x, expert_indices)
220 | return torch.einsum('tai,ta -> ti', expert_outs, expert_weights)
221 |
222 |
223 | class RMSNorm(nn.Module):
224 | def __init__(self, dim: int, eps: float = 1e-5):
225 | super().__init__()
226 | self.eps = eps
227 | self.weight = nn.Parameter(torch.ones(dim))
228 |
229 | def _norm(self, x):
230 | return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
231 |
232 | def forward(self, x: Tensor) -> Tensor:
233 | output = self._norm(x.float()).type_as(x)
234 | return output * self.weight
235 |
236 |
237 | def precompute_freqs_cis(
238 | seq_len: int, n_elem: int, base: int = 10000
239 | ) -> Tensor:
240 | freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
241 | t = torch.arange(seq_len, device=freqs.device)
242 | freqs = torch.outer(t, freqs)
243 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
244 | cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
245 | return cache.to(dtype=torch.bfloat16)
246 |
247 |
248 | def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
249 | xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
250 | freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
251 | x_out2 = torch.stack(
252 | [
253 | xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
254 | xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
255 | ],
256 | -1,
257 | )
258 |
259 | x_out2 = x_out2.flatten(3)
260 | return x_out2.type_as(x)
261 |
--------------------------------------------------------------------------------
/mixtral-moe/quantize.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | import time
7 | from pathlib import Path
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 | from model import Transformer, ConditionalFeedForward
14 |
15 | ##### Quantization Primitives ######
16 |
17 | def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
18 | # assumes symmetric quantization
19 | # assumes axis == 0
20 | # assumes dense memory format
21 | # TODO(future): relax ^ as needed
22 |
23 | # default setup for affine quantization of activations
24 | eps = torch.finfo(torch.float32).eps
25 |
26 | # get min and max
27 | min_val, max_val = torch.aminmax(x, dim=1)
28 |
29 | # calculate scales and zero_points based on min and max
30 | # reference: https://fburl.com/code/srbiybme
31 | min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
32 | max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
33 | device = min_val_neg.device
34 |
35 | # reference: https://fburl.com/code/4wll53rk
36 | max_val_pos = torch.max(-min_val_neg, max_val_pos)
37 | scales = max_val_pos / (float(quant_max - quant_min) / 2)
38 | # ensure scales is the same dtype as the original tensor
39 | scales = torch.clamp(scales, min=eps).to(x.dtype)
40 | zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
41 |
42 | # quantize based on qmin/qmax/scales/zp
43 | # reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63
44 | x_div = x / scales.unsqueeze(-1)
45 | x_round = torch.round(x_div)
46 | x_zp = x_round + zero_points.unsqueeze(-1)
47 | quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype)
48 |
49 | return quant, scales, zero_points
50 |
51 |
52 | ##### Weight-only int8 per-channel quantized code ######
53 |
54 | def replace_linear_weight_only_bit8_per_channel(module, target_dtype):
55 | for name, child in module.named_children():
56 | if isinstance(child, nn.Linear) and name != "gate":
57 | setattr(module, name, WeightOnlyBit8Linear(child.in_features, child.out_features, target_dtype=target_dtype))
58 | elif isinstance(child, ConditionalFeedForward):
59 | num_experts, intermediate_size, dim = child.w1.shape
60 | setattr(module, name, ConditionalFeedForwardBit8(num_experts, intermediate_size, dim, target_dtype=target_dtype))
61 | else:
62 | replace_linear_weight_only_bit8_per_channel(child, target_dtype)
63 |
64 | class WeightOnlyBit8QuantHandler:
65 | def __init__(self, mod, target_dtype):
66 | self.mod = mod
67 | self.target_dtype = target_dtype
68 |
69 | @torch.no_grad()
70 | def create_quantized_state_dict(self):
71 | cur_state_dict = self.mod.state_dict()
72 | for fqn, mod in self.mod.named_modules():
73 | if isinstance(mod, torch.nn.Linear) and not fqn.endswith(".gate"):
74 | int8_weight, scales, _ = dynamically_quantize_per_channel(mod.weight.float(), -128, 127, self.target_dtype)
75 | cur_state_dict[f"{fqn}.weight"] = int8_weight
76 | cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype)
77 | elif isinstance(mod, ConditionalFeedForward):
78 | for weight_idx in range(0, 3):
79 | weight_name = f"w{weight_idx + 1}"
80 | scales_name = f"scales{weight_idx + 1}"
81 | weight = getattr(mod, weight_name)
82 | num_experts, intermediate_size, dim = weight.shape
83 |
84 | bit8_weight_list = []
85 | scales_list = []
86 | for expert_idx in range(num_experts):
87 | bit8_weight, scales, _ = dynamically_quantize_per_channel(weight[expert_idx].float(), -128, 127, self.target_dtype)
88 | bit8_weight_list.append(bit8_weight.reshape(1, intermediate_size, dim))
89 | scales_list.append(scales.reshape(1, intermediate_size))
90 |
91 | cur_state_dict[f"{fqn}.{weight_name}"] = torch.cat(bit8_weight_list, dim=0)
92 | cur_state_dict[f"{fqn}.{scales_name}"] = torch.cat(scales_list, dim=0)
93 |
94 | return cur_state_dict
95 |
96 | def convert_for_runtime(self):
97 | replace_linear_weight_only_bit8_per_channel(self.mod, self.target_dtype)
98 | return self.mod
99 |
100 |
101 | class WeightOnlyBit8Linear(torch.nn.Module):
102 | __constants__ = ['in_features', 'out_features']
103 | in_features: int
104 | out_features: int
105 | weight: torch.Tensor
106 |
107 | def __init__(self, in_features: int, out_features: int, bias: bool = True,
108 | device=None, dtype=None, target_dtype=None) -> None:
109 | assert target_dtype is not None
110 | factory_kwargs = {'device': device, 'dtype': dtype}
111 | super().__init__()
112 | self.in_features = in_features
113 | self.out_features = out_features
114 | self.register_buffer("weight", torch.empty((out_features, in_features), dtype=target_dtype))
115 | self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))
116 |
117 | def forward(self, input: torch.Tensor) -> torch.Tensor:
118 | return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
119 |
120 |
121 | class ConditionalFeedForwardBit8(nn.Module):
122 | def __init__(self, num_experts, intermediate_size, dim, target_dtype):
123 | super().__init__()
124 |
125 | self.target_dtype = target_dtype
126 |
127 | self.register_buffer("w1", torch.empty(num_experts, intermediate_size, dim, dtype=target_dtype))
128 | self.register_buffer("w2", torch.empty(num_experts, dim, intermediate_size, dtype=target_dtype))
129 | self.register_buffer("w3", torch.empty(num_experts, intermediate_size, dim, dtype=target_dtype))
130 |
131 | self.register_buffer("scales1", torch.empty(num_experts, intermediate_size, dtype=torch.bfloat16))
132 | self.register_buffer("scales2", torch.empty(num_experts, dim, dtype=torch.bfloat16))
133 | self.register_buffer("scales3", torch.empty(num_experts, intermediate_size, dtype=torch.bfloat16))
134 |
135 | def forward(self, x, expert_indices):
136 | w1_weights = self.w1.to(x.dtype)[expert_indices] # [T, A, D, D]
137 | w3_weights = self.w3.to(x.dtype)[expert_indices] # [T, A, D, D]
138 | w2_weights = self.w2.to(x.dtype)[expert_indices]
139 | x1 = F.silu(torch.einsum('ti,taoi -> tao', x, w1_weights) * self.scales1[expert_indices].to(x.dtype))
140 | x3 = torch.einsum('ti, taoi -> tao', x, w3_weights) * self.scales3[expert_indices].to(x.dtype)
141 | expert_outs = torch.einsum('tao, taio -> tai', (x1 * x3), w2_weights) * self.scales2[expert_indices].to(x.dtype) # [T, A, D, D]
142 | return expert_outs
143 |
144 |
145 | def quantize(
146 | checkpoint_path: Path = Path("checkpoints/mistralai/Mixtral-8x7B-v0.1/model.pth"),
147 | mode: str = 'int8',
148 | label: str = '',
149 | ) -> None:
150 | assert checkpoint_path.is_file(), checkpoint_path
151 |
152 | device = 'cpu'
153 | precision = torch.bfloat16
154 |
155 | print("Loading model ...")
156 | t0 = time.time()
157 |
158 | with torch.device('meta'):
159 | model = Transformer.from_name(checkpoint_path.parent.name)
160 |
161 | checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
162 | model.load_state_dict(checkpoint, assign=True)
163 | model = model.to(dtype=precision, device=device)
164 |
165 | if mode == 'int8':
166 | print("Quantizing model weights for int8 weight-only symmetric per-channel quantization")
167 | quant_handler = WeightOnlyBit8QuantHandler(model, target_dtype=torch.int8)
168 | quantized_state_dict = quant_handler.create_quantized_state_dict()
169 |
170 | dir_name = checkpoint_path.parent
171 | base_name = checkpoint_path.name
172 | new_base_name = base_name.replace('.pth', f'{label}int8.pth')
173 |
174 | else:
175 | raise ValueError(f"Invalid quantization mode {mode} needs to be one of [int8,]")
176 |
177 | quantize_path = dir_name / new_base_name
178 | print(f"Writing quantized weights to {quantize_path}")
179 | quantize_path.unlink(missing_ok=True) # remove existing file if one already there
180 | torch.save(quantized_state_dict, quantize_path)
181 | print(f"Quantization complete took {time.time() - t0:.02f} seconds")
182 | return
183 |
184 | if __name__ == '__main__':
185 | import argparse
186 | parser = argparse.ArgumentParser(description='Quantize a model.')
187 | parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Path to the model checkpoint to be quantized.')
188 | parser.add_argument('--mode', '-q', type=str, default='int8', choices=['int8', 'int4', 'int4-gptq'], help='type of quantization to perform')
189 | parser.add_argument('--label', type=str, default='_', help='label to add to output filename')
190 |
191 | args = parser.parse_args()
192 | quantize(args.checkpoint_path, args.mode, args.label)
193 |
--------------------------------------------------------------------------------
/mixtral-moe/scripts/convert_hf_checkpoint.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | import glob
7 | import json
8 | import re
9 | import sys
10 | from pathlib import Path
11 | from typing import Optional
12 |
13 | import torch
14 |
15 | # support running without installing as a package
16 | wd = Path(__file__).parent.parent.resolve()
17 | sys.path.append(str(wd))
18 |
19 | from model import ModelArgs
20 |
21 |
22 | @torch.inference_mode()
23 | def convert_hf_checkpoint(
24 | *,
25 | checkpoint_dir: Path = Path("checkpoints/mistralai/Mixtral-8x7B-v0.1"),
26 | model_name: Optional[str] = None,
27 | ) -> None:
28 | if model_name is None:
29 | model_name = checkpoint_dir.name
30 |
31 | config = ModelArgs.from_name(model_name)
32 | print(f"Model config {config.__dict__}")
33 |
34 | weight_map = {
35 | "tok_embeddings.weight": "tok_embeddings.weight",
36 | "layers.{}.attention.wq.weight": "layers.{}.attention.wq.weight",
37 | "layers.{}.attention.wk.weight": "layers.{}.attention.wk.weight",
38 | "layers.{}.attention.wv.weight": "layers.{}.attention.wv.weight",
39 | "layers.{}.attention.wo.weight": "layers.{}.attention.wo.weight",
40 | "layers.{}.block_sparse_moe.w1": "layers.{}.block_sparse_moe.cond_ffn.w1",
41 | "layers.{}.block_sparse_moe.w2": "layers.{}.block_sparse_moe.cond_ffn.w2",
42 | "layers.{}.block_sparse_moe.w3": "layers.{}.block_sparse_moe.cond_ffn.w3",
43 | "layers.{}.block_sparse_moe.gate.weight": "layers.{}.block_sparse_moe.gate.weight",
44 | "layers.{}.attention_norm.weight": "layers.{}.attention_norm.weight",
45 | "layers.{}.ffn_norm.weight": "layers.{}.ffn_norm.weight",
46 | "norm.weight": "norm.weight",
47 | "output.weight": "output.weight",
48 | }
49 |
50 | pt_files = glob.glob(str(checkpoint_dir / "*.pt"))
51 |
52 | merged_result = {}
53 | for file in sorted(pt_files):
54 | state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
55 | merged_result.update(state_dict)
56 | final_result = {}
57 | for key, value in merged_result.items():
58 | if "layers" in key:
59 | abstract_key = re.sub(r'.(\d+).', '.{}.', key)
60 | layer_num = re.search(r'\d+', key).group(0)
61 | new_key = weight_map[abstract_key]
62 | if new_key is None:
63 | continue
64 | new_key = new_key.format(layer_num)
65 | else:
66 | new_key = weight_map[key]
67 |
68 | final_result[new_key] = value
69 |
70 | for key in tuple(final_result.keys()):
71 | if "wq" in key:
72 | q = final_result[key]
73 | k = final_result[key.replace("wq", "wk")]
74 | v = final_result[key.replace("wq", "wv")]
75 | final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
76 | del final_result[key]
77 | del final_result[key.replace("wq", "wk")]
78 | del final_result[key.replace("wq", "wv")]
79 | elif "w1" in key or "w3" in key:
80 | final_result[key] = final_result[key].reshape(config.num_experts, config.intermediate_size, config.dim).contiguous()
81 | elif "w2" in key:
82 | final_result[key] = final_result[key].reshape(config.num_experts, config.intermediate_size, config.dim).permute(0, 2, 1).contiguous()
83 | elif "gate" in key:
84 | final_result[key] = final_result[key].contiguous()
85 |
86 | print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
87 | torch.save(final_result, checkpoint_dir / "model.pth")
88 |
89 |
90 | if __name__ == '__main__':
91 | import argparse
92 | parser = argparse.ArgumentParser(description='Convert HuggingFace checkpoint.')
93 | parser.add_argument('--checkpoint_dir', type=Path, default=Path("checkpoints/meta-llama/llama-2-7b-chat-hf"))
94 | parser.add_argument('--model_name', type=str, default=None)
95 |
96 | args = parser.parse_args()
97 | convert_hf_checkpoint(
98 | checkpoint_dir=args.checkpoint_dir,
99 | model_name=args.model_name,
100 | )
101 |
--------------------------------------------------------------------------------
/mixtral-moe/scripts/download.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | import os
7 | from typing import Optional
8 |
9 | from requests.exceptions import HTTPError
10 |
11 |
12 | def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -> None:
13 | from huggingface_hub import snapshot_download
14 | os.makedirs(f"checkpoints/{repo_id}", exist_ok=True)
15 | try:
16 | snapshot_download(repo_id, local_dir=f"checkpoints/{repo_id}", local_dir_use_symlinks=False, token=hf_token, ignore_patterns="*.safetensors")
17 | except HTTPError as e:
18 | if e.response.status_code == 401:
19 | print("You need to pass a valid `--hf_token=...` to download private checkpoints.")
20 | else:
21 | raise e
22 |
23 | if __name__ == '__main__':
24 | import argparse
25 | parser = argparse.ArgumentParser(description='Download data from HuggingFace Hub.')
26 | parser.add_argument('--repo_id', type=str, default="checkpoints/meta-llama/llama-2-7b-chat-hf", help='Repository ID to download from.')
27 | parser.add_argument('--hf_token', type=str, default=None, help='HuggingFace API token.')
28 |
29 | args = parser.parse_args()
30 | hf_download(args.repo_id, args.hf_token)
31 |
--------------------------------------------------------------------------------
/mixtral-moe/tp.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | import os
7 | from typing import List, Optional
8 |
9 | import torch
10 | import torch.distributed as dist
11 | from torch import nn
12 | from torch.distributed import _functional_collectives as funcol
13 |
14 | from model import Attention, MOEFeedForward, Transformer
15 |
16 |
17 | def _get_rank() -> int:
18 | return int(os.environ.get("LOCAL_RANK", "0"))
19 |
20 | def is_local():
21 | return _get_rank() == 0
22 |
23 | def local_break():
24 | if is_local():
25 | breakpoint()
26 | dist.barrier()
27 |
28 | def _get_world_size() -> int:
29 | return int(os.environ.get("LOCAL_WORLD_SIZE", "1"))
30 |
31 | def maybe_init_dist() -> Optional[int]:
32 | try:
33 | # provided by torchrun
34 | rank = _get_rank()
35 | world_size = _get_world_size()
36 |
37 | if world_size < 2:
38 | # too few gpus to parallelize, tp is no-op
39 | return None
40 | except KeyError:
41 | # not run via torchrun, no-op
42 | return None
43 |
44 | torch.cuda.set_device(rank)
45 | dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
46 | return rank
47 |
48 | rank = _get_rank()
49 | world_size = _get_world_size()
50 |
51 | def shard(x, dim):
52 | assert x.size(dim=dim) % world_size == 0
53 | return torch.tensor_split(x, world_size, dim=dim)[rank]
54 |
55 | def _apply_tp_linear(linear: nn.Linear, style: str, weight_splits: List[int] = []) -> None:
56 | rank = _get_rank()
57 | world_size = _get_world_size()
58 |
59 | # Linear's weight matrix is transposed, and is of shape
60 | # (linear.out_features, linear.in_features)
61 | dim_lookup = {
62 | "colwise": (0, "out_features"),
63 | "rowwise": (1, "in_features")
64 | }
65 | assert style in dim_lookup
66 | shard_dim, size_attr = dim_lookup[style]
67 |
68 | # ensure we can shard evenly
69 | assert getattr(linear, size_attr) % world_size == 0
70 |
71 | def shard_qkv(qkv, dim, weight_splits):
72 | q, k, v = qkv.split(weight_splits, dim=dim)
73 | q = shard(q, dim)
74 | k = shard(k, dim)
75 | v = shard(v, dim)
76 | return torch.cat((q,k,v), dim=dim)
77 |
78 | # shard
79 | if weight_splits:
80 | # attention
81 | assert len(weight_splits) == 3
82 |
83 | sharded_weight = shard_qkv(linear.weight, shard_dim, weight_splits)
84 | if hasattr(linear, "scales") and style == "colwise":
85 | linear.scales = shard_qkv(linear.scales, 0, weight_splits)
86 | else:
87 | sharded_weight = shard(linear.weight, shard_dim)
88 | if hasattr(linear, "scales") and style == "colwise":
89 | linear.scales = shard(linear.scales, 0)
90 |
91 | # local_break()
92 | linear.weight = nn.Parameter(sharded_weight, requires_grad=False)
93 | setattr(linear, size_attr, getattr(linear, size_attr) // world_size)
94 |
95 | # shape info should still be synced
96 | # assert linear.weight.shape == (linear.out_features, linear.in_features)
97 |
98 |
99 | def _apply_tp_moe_ffn(mlp: MOEFeedForward) -> None:
100 | mlp.cond_ffn.w1 = nn.Parameter(shard(mlp.cond_ffn.w1, 1), requires_grad=False)
101 | mlp.cond_ffn.w3 = nn.Parameter(shard(mlp.cond_ffn.w3, 1), requires_grad=False)
102 | mlp.cond_ffn.w2 = nn.Parameter(shard(mlp.cond_ffn.w2, 2), requires_grad=False)
103 |
104 | if hasattr(mlp.cond_ffn, "scales1"):
105 | mlp.cond_ffn.scales1 = nn.Parameter(shard(mlp.cond_ffn.scales1, 1), requires_grad=False)
106 | mlp.cond_ffn.scales3 = nn.Parameter(shard(mlp.cond_ffn.scales3, 1), requires_grad=False)
107 | mlp.cond_ffn.scales2 = nn.Parameter(mlp.cond_ffn.scales2, requires_grad=False)
108 |
109 | world_size = _get_world_size()
110 | mlp.cond_ffn.register_forward_hook(lambda _module, _input, output: funcol.all_reduce(
111 | output, "sum", list(range(world_size))))
112 |
113 |
114 | def _apply_tp_attn(attn: Attention) -> None:
115 | assert hasattr(attn, "wqkv")
116 | assert hasattr(attn, "wo")
117 |
118 | kv_size = attn.n_local_heads * attn.head_dim
119 | _apply_tp_linear(attn.wqkv, "colwise", [attn.dim, kv_size, kv_size])
120 | _apply_tp_linear(attn.wo, "rowwise")
121 |
122 | # overwrite
123 | world_size = _get_world_size()
124 | assert attn.n_head % world_size == 0, "assert attn.n_head % world_size == 0"
125 | attn.n_head = attn.n_head // world_size
126 | attn.dim = attn.dim // world_size
127 | attn.head_dim = attn.dim // attn.n_head
128 | attn.n_local_heads = attn.n_local_heads // world_size
129 |
130 | attn.register_forward_hook(lambda _module, _input, output: funcol.all_reduce(
131 | output[0], "sum", list(range(world_size))))
132 |
133 |
134 | def _apply_tp_Transformer(Transformer: Transformer) -> None:
135 | # overwrite config before Transformer.setup_cache is called
136 | world_size = _get_world_size()
137 | Transformer.config.n_head = Transformer.config.n_head // world_size
138 | Transformer.config.dim = Transformer.config.dim // world_size
139 | Transformer.config.n_local_heads = Transformer.config.n_local_heads // world_size
140 |
141 |
142 | def apply_tp(model: Transformer) -> None:
143 | _apply_tp_Transformer(model)
144 | for block in model.layers:
145 | # Apply to MLP
146 | _apply_tp_moe_ffn(block.block_sparse_moe)
147 | _apply_tp_attn(block.attention)
148 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | import math
7 | from dataclasses import dataclass
8 | from typing import Optional
9 |
10 | import torch
11 | import torch.nn as nn
12 | from torch import Tensor
13 | from torch.nn import functional as F
14 | from torch.nn.attention.flex_attention import (
15 | _mask_mod_signature,
16 | BlockMask,
17 | flex_attention,
18 | )
19 |
20 |
21 | def find_multiple(n: int, k: int) -> int:
22 | if n % k == 0:
23 | return n
24 | return n + k - (n % k)
25 |
26 |
27 | def get_mask_mod(mask_mod: _mask_mod_signature, offset: int):
28 | def _mask_mod(b, h, q, kv):
29 | return mask_mod(b, h, q + offset, kv)
30 |
31 | return _mask_mod
32 |
33 |
34 | @dataclass
35 | class ModelArgs:
36 | block_size: int = 2048
37 | vocab_size: int = 32000
38 | n_layer: int = 32
39 | n_head: int = 32
40 | dim: int = 4096
41 | intermediate_size: int = None
42 | n_local_heads: int = -1
43 | head_dim: int = 64
44 | rope_base: float = 10000
45 | norm_eps: float = 1e-5
46 | rope_scaling: Optional[dict] = None
47 |
48 | def __post_init__(self):
49 | if self.n_local_heads == -1:
50 | self.n_local_heads = self.n_head
51 | if self.intermediate_size is None:
52 | hidden_dim = 4 * self.dim
53 | n_hidden = int(2 * hidden_dim / 3)
54 | self.intermediate_size = find_multiple(n_hidden, 256)
55 | self.head_dim = self.dim // self.n_head
56 |
57 | @classmethod
58 | def from_name(cls, name: str):
59 | if name in transformer_configs:
60 | return cls(**transformer_configs[name])
61 | # fuzzy search
62 | config = [config for config in transformer_configs if config.lower() in str(name).lower()]
63 |
64 | # We may have two or more configs matched (e.g. "7B" and "Mistral-7B"). Find the best config match,
65 | # take longer name (as it have more symbols matched)
66 | if len(config) > 1:
67 | config.sort(key=len, reverse=True)
68 | assert len(config[0]) != len(config[1]), name # make sure only one 'best' match
69 |
70 | return cls(**transformer_configs[config[0]])
71 |
72 |
73 | transformer_configs = {
74 | "CodeLlama-7b-Python-hf": dict(block_size=16384, vocab_size=32000, n_layer=32, dim = 4096, rope_base=1000000),
75 | "7B": dict(n_layer=32, n_head=32, dim=4096),
76 | "13B": dict(n_layer=40, n_head=40, dim=5120),
77 | "30B": dict(n_layer=60, n_head=52, dim=6656),
78 | "34B": dict(n_layer=48, n_head=64, dim=8192, vocab_size=32000, n_local_heads=8, intermediate_size=22016, rope_base=1000000), # CodeLlama-34B-Python-hf
79 | "70B": dict(n_layer=80, n_head=64, dim=8192, n_local_heads=8, intermediate_size=28672),
80 | "Mistral-7B": dict(n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=32000),
81 | "stories15M": dict(n_layer=6, n_head=6, dim=288),
82 | "stories110M": dict(n_layer=12, n_head=12, dim=768),
83 |
84 | "llama-3-8b": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000),
85 | "llama-3-70b": dict(block_size=8192, n_layer=80, n_head=64, n_local_heads=8, dim=8192, intermediate_size=28672, vocab_size=128256, rope_base=500000),
86 | "llama-3.1-8b": dict(block_size=131072, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000,
87 | rope_scaling=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_position_embeddings=8192),
88 | ),
89 | "llama-3.1-70b": dict(block_size=131072, n_layer=80, n_head=64, n_local_heads=8, dim=8192, intermediate_size=28672, vocab_size=128256, rope_base=500000,
90 | rope_scaling=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_position_embeddings=8192),
91 | ),
92 | "llama-3.1-405b": dict(block_size=131072, n_layer=126, n_head=128, n_local_heads=8, dim=16384, intermediate_size=53248, vocab_size=128256, rope_base=500000,
93 | rope_scaling=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_position_embeddings=8192),
94 | ),
95 | }
96 |
97 | class KVCache(nn.Module):
98 | def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16):
99 | super().__init__()
100 | cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
101 | self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
102 | self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))
103 |
104 | def update(self, input_pos, k_val, v_val):
105 | # input_pos: [S], k_val: [B, H, S, D]
106 | assert input_pos.shape[0] == k_val.shape[2]
107 |
108 | k_out = self.k_cache
109 | v_out = self.v_cache
110 | k_out[:, :, input_pos] = k_val
111 | v_out[:, :, input_pos] = v_val
112 |
113 | return k_out, v_out
114 |
115 | class Transformer(nn.Module):
116 | def __init__(self, config: ModelArgs) -> None:
117 | super().__init__()
118 | self.config = config
119 |
120 | self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
121 | self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
122 | self.norm = RMSNorm(config.dim, eps=config.norm_eps)
123 | self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
124 |
125 | self.freqs_cis: Optional[Tensor] = None
126 | self.mask_cache: Optional[Tensor] = None
127 | self.max_batch_size = -1
128 | self.max_seq_length = -1
129 | self.get_mask_mod = get_mask_mod
130 |
131 | def setup_caches(self, max_batch_size, max_seq_length):
132 | if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
133 | return
134 | head_dim = self.config.dim // self.config.n_head
135 | max_seq_length = find_multiple(max_seq_length, 8)
136 | self.max_seq_length = max_seq_length
137 | self.max_batch_size = max_batch_size
138 | dtype = self.output.weight.dtype
139 | # For quantized layers, dtype is encoded in scales
140 | if hasattr(self.output, "scales"):
141 | dtype = self.output.scales.dtype
142 | elif hasattr(self.output, "scales_and_zeros"):
143 | dtype = self.output.scales_and_zeros.dtype
144 | for b in self.layers:
145 | b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype)
146 |
147 | self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype, self.config.rope_scaling)
148 |
149 | def forward(self, mask: BlockMask, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
150 | assert self.freqs_cis is not None, "Caches must be initialized first"
151 | mask.mask_mod = self.get_mask_mod(mask.mask_mod, input_pos[0])
152 | freqs_cis = self.freqs_cis[input_pos]
153 | x = self.tok_embeddings(idx)
154 |
155 | for i, layer in enumerate(self.layers):
156 | x = layer(x, input_pos, freqs_cis, mask)
157 | x = self.norm(x)
158 | logits = self.output(x)
159 | return logits
160 |
161 | @classmethod
162 | def from_name(cls, name: str):
163 | return cls(ModelArgs.from_name(name))
164 |
165 |
166 | class TransformerBlock(nn.Module):
167 | def __init__(self, config: ModelArgs) -> None:
168 | super().__init__()
169 | self.attention = Attention(config)
170 | self.feed_forward = FeedForward(config)
171 | self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
172 | self.attention_norm = RMSNorm(config.dim, config.norm_eps)
173 |
174 | def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: BlockMask) -> Tensor:
175 | h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
176 | out = h + self.feed_forward(self.ffn_norm(h))
177 | return out
178 |
179 |
180 | class Attention(nn.Module):
181 | def __init__(self, config: ModelArgs):
182 | super().__init__()
183 | assert config.dim % config.n_head == 0
184 |
185 | total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
186 | # key, query, value projections for all heads, but in a batch
187 | self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
188 | self.wo = nn.Linear(config.dim, config.dim, bias=False)
189 | self.kv_cache = None
190 |
191 | self.n_head = config.n_head
192 | self.head_dim = config.head_dim
193 | self.n_local_heads = config.n_local_heads
194 | self.dim = config.dim
195 | self._register_load_state_dict_pre_hook(self.load_hook)
196 |
197 | def load_hook(self, state_dict, prefix, *args):
198 | if prefix + "wq.weight" in state_dict:
199 | wq = state_dict.pop(prefix + "wq.weight")
200 | wk = state_dict.pop(prefix + "wk.weight")
201 | wv = state_dict.pop(prefix + "wv.weight")
202 | state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
203 |
204 | def forward(self, x: Tensor, freqs_cis: Tensor, mask: BlockMask, input_pos: Optional[Tensor] = None) -> Tensor:
205 | bsz, seqlen, _ = x.shape
206 |
207 | kv_size = self.n_local_heads * self.head_dim
208 | q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
209 |
210 | q = q.view(bsz, seqlen, self.n_head, self.head_dim)
211 | k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
212 | v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
213 |
214 | q = apply_rotary_emb(q, freqs_cis)
215 | k = apply_rotary_emb(k, freqs_cis)
216 |
217 | q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
218 |
219 | if self.kv_cache is not None:
220 | k, v = self.kv_cache.update(input_pos, k, v)
221 |
222 | y = flex_attention(q, k, v, block_mask=mask, enable_gqa=(self.n_head != self.n_local_heads))
223 |
224 | y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
225 |
226 | y = self.wo(y)
227 | return y
228 |
229 |
230 | class FeedForward(nn.Module):
231 | def __init__(self, config: ModelArgs) -> None:
232 | super().__init__()
233 | self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
234 | self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
235 | self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
236 |
237 | def forward(self, x: Tensor) -> Tensor:
238 | return self.w2(F.silu(self.w1(x)) * self.w3(x))
239 |
240 |
241 | class RMSNorm(nn.Module):
242 | def __init__(self, dim: int, eps: float = 1e-5):
243 | super().__init__()
244 | self.eps = eps
245 | self.weight = nn.Parameter(torch.ones(dim))
246 |
247 | def _norm(self, x):
248 | return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
249 |
250 | def forward(self, x: Tensor) -> Tensor:
251 | output = self._norm(x.float()).type_as(x)
252 | return output * self.weight
253 |
254 |
255 | def apply_rope_scaling(freqs: torch.Tensor, rope_scaling: Optional[dict] = None):
256 | factor = rope_scaling["factor"]
257 | low_freq_factor = rope_scaling["low_freq_factor"]
258 | high_freq_factor = rope_scaling["high_freq_factor"]
259 | old_context_len = rope_scaling["original_max_position_embeddings"]
260 |
261 | low_freq_wavelen = old_context_len / low_freq_factor
262 | high_freq_wavelen = old_context_len / high_freq_factor
263 | new_freqs = []
264 | for freq in freqs:
265 | wavelen = 2 * math.pi / freq
266 | if wavelen < high_freq_wavelen:
267 | new_freqs.append(freq)
268 | elif wavelen > low_freq_wavelen:
269 | new_freqs.append(freq / factor)
270 | else:
271 | assert low_freq_wavelen != high_freq_wavelen
272 | smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
273 | new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
274 | return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
275 |
276 |
277 | def precompute_freqs_cis(
278 | seq_len: int, n_elem: int, base: int = 10000,
279 | dtype: torch.dtype = torch.bfloat16,
280 | rope_scaling: Optional[dict] = None,
281 | ) -> Tensor:
282 | freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
283 | if rope_scaling is not None:
284 | freqs = apply_rope_scaling(freqs, rope_scaling)
285 | t = torch.arange(seq_len, device=freqs.device)
286 | freqs = torch.outer(t, freqs)
287 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
288 | cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
289 | return cache.to(dtype=dtype)
290 |
291 |
292 | def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
293 | xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
294 | freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
295 | x_out2 = torch.stack(
296 | [
297 | xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
298 | xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
299 | ],
300 | -1,
301 | )
302 |
303 | x_out2 = x_out2.flatten(3)
304 | return x_out2.type_as(x)
305 |
--------------------------------------------------------------------------------
/quantize.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | import time
7 | from pathlib import Path
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from tokenizer import get_tokenizer
13 |
14 | try:
15 | from GPTQ import GenericGPTQRunner, InputRecorder
16 | from eval import get_task_dict, evaluate, lm_eval
17 | except:
18 | pass
19 |
20 | from model import Transformer
21 |
22 | ##### Quantization Primitives ######
23 |
24 | def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
25 | # assumes symmetric quantization
26 | # assumes axis == 0
27 | # assumes dense memory format
28 | # TODO(future): relax ^ as needed
29 |
30 | # default setup for affine quantization of activations
31 | eps = torch.finfo(torch.float32).eps
32 |
33 | # get min and max
34 | min_val, max_val = torch.aminmax(x, dim=1)
35 |
36 | # calculate scales and zero_points based on min and max
37 | # reference: https://fburl.com/code/srbiybme
38 | min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
39 | max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
40 | device = min_val_neg.device
41 |
42 | # reference: https://fburl.com/code/4wll53rk
43 | max_val_pos = torch.max(-min_val_neg, max_val_pos)
44 | scales = max_val_pos / (float(quant_max - quant_min) / 2)
45 | # ensure scales is the same dtype as the original tensor
46 | scales = torch.clamp(scales, min=eps).to(x.dtype)
47 | zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
48 |
49 | # quantize based on qmin/qmax/scales/zp
50 | # reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63
51 | x_div = x / scales.unsqueeze(-1)
52 | x_round = torch.round(x_div)
53 | x_zp = x_round + zero_points.unsqueeze(-1)
54 | quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype)
55 |
56 | return quant, scales, zero_points
57 |
58 | def get_group_qparams(w, n_bit=4, groupsize=128):
59 | # needed for GPTQ with padding
60 | if groupsize > w.shape[-1]:
61 | groupsize = w.shape[-1]
62 | assert groupsize > 1
63 | assert w.shape[-1] % groupsize == 0
64 | assert w.dim() == 2
65 |
66 | to_quant = w.reshape(-1, groupsize)
67 | assert torch.isnan(to_quant).sum() == 0
68 |
69 | max_val = to_quant.amax(dim=1, keepdim=True)
70 | min_val = to_quant.amin(dim=1, keepdim=True)
71 | max_int = 2**n_bit - 1
72 | scales = (max_val - min_val).clamp(min=1e-6) / max_int
73 | zeros = min_val + scales * (2 ** (n_bit - 1))
74 | return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to(
75 | torch.bfloat16
76 | ).reshape(w.shape[0], -1)
77 |
78 |
79 | def pack_scales_and_zeros(scales, zeros):
80 | assert scales.shape == zeros.shape
81 | assert scales.dtype == torch.bfloat16
82 | assert zeros.dtype == torch.bfloat16
83 | return (
84 | torch.cat(
85 | [
86 | scales.reshape(scales.size(0), scales.size(1), 1),
87 | zeros.reshape(zeros.size(0), zeros.size(1), 1),
88 | ],
89 | 2,
90 | )
91 | .transpose(0, 1)
92 | .contiguous()
93 | )
94 |
95 |
96 | def unpack_scales_and_zeros(scales_and_zeros):
97 | assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2
98 | assert scales_and_zeros.dtype == torch.float
99 | return torch.split(scales_and_zeros.transpose(0, 1), 1, 2)
100 |
101 |
102 | def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128):
103 | assert groupsize > 1
104 | # needed for GPTQ single column quantize
105 | if groupsize > w.shape[-1] and scales.shape[-1] == 1:
106 | groupsize = w.shape[-1]
107 |
108 | assert w.shape[-1] % groupsize == 0
109 | assert w.dim() == 2
110 |
111 | to_quant = w.reshape(-1, groupsize)
112 | assert torch.isnan(to_quant).sum() == 0
113 |
114 | scales = scales.reshape(-1, 1)
115 | zeros = zeros.reshape(-1, 1)
116 | min_val = zeros - scales * (2 ** (n_bit - 1))
117 | max_int = 2**n_bit - 1
118 | min_int = 0
119 | w_int32 = (
120 | to_quant.sub(min_val)
121 | .div(scales)
122 | .round()
123 | .clamp_(min_int, max_int)
124 | .to(torch.int32)
125 | .reshape_as(w)
126 | )
127 |
128 | return w_int32
129 |
130 |
131 | def group_quantize_tensor(w, n_bit=4, groupsize=128):
132 | scales, zeros = get_group_qparams(w, n_bit, groupsize)
133 | w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize)
134 | scales_and_zeros = pack_scales_and_zeros(scales, zeros)
135 | return w_int32, scales_and_zeros
136 |
137 |
138 | def group_dequantize_tensor_from_qparams(
139 | w_int32, scales, zeros, n_bit=4, groupsize=128
140 | ):
141 | assert groupsize > 1
142 | # needed for GPTQ single column dequantize
143 | if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1:
144 | groupsize = w_int32.shape[-1]
145 | assert w_int32.shape[-1] % groupsize == 0
146 | assert w_int32.dim() == 2
147 |
148 | w_int32_grouped = w_int32.reshape(-1, groupsize)
149 | scales = scales.reshape(-1, 1)
150 | zeros = zeros.reshape(-1, 1)
151 |
152 | w_dq = (
153 | w_int32_grouped.sub(2 ** (n_bit - 1)).mul(scales).add(zeros).reshape_as(w_int32)
154 | )
155 | return w_dq
156 |
157 |
158 | def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128):
159 | scales, zeros = unpack_scales_and_zeros(scales_and_zeros)
160 | return group_dequantize_tensor_from_qparams(
161 | w_int32, scales, zeros, n_bit, groupsize
162 | )
163 |
164 | class QuantHandler:
165 | def __init__(self, mod):
166 | self.mod = mod
167 |
168 | def create_quantized_state_dict(self) -> "StateDict":
169 | pass
170 |
171 | def convert_for_runtime(self) -> "nn.Module":
172 | pass
173 |
174 | class GPTQQuantHandler(QuantHandler):
175 | """
176 | This class implements a GPTQ QuantHandler that can be used to apply GPTQ to a model in concert with the GenericGPTQRunner class.
177 | Unlike the base QuantHandler class, the user does not need to implement the create_quantized_state_dict, instead they have to reimplement
178 | __init__ such that it defines the functions for the quantization mode. User is expected to reimplement convert_for_runtime.
179 |
180 | The following functions (which must be defined in __init__) are used to define the quantization mode for both GPTQ and
181 | create_quantized_state_dict. Here is a description of each function.
182 |
183 | get_qparams_func:
184 | A function that calculates the quantization qparams for an input tensor.
185 | Args:
186 | weight: A 2d weight tensor with non-integer dtype.
187 | Returns:
188 | qparams: it can have any format but will need to be handled by the other defined functions below.
189 |
190 | quantize_func:
191 | A function that applies quantization to an input tensor. It should be noted
192 | that this function needs to be able to handle quantizing the entire weight tensor, a single group,
193 | or a single column.
194 | Args:
195 | weight: A 2d weight tensor with non-integer dtype.
196 | qparams: the output from get_qparams_func
197 | Returns:
198 | quantized_weight: A 2d quantized weight tensor (generally with an integer dtype)
199 |
200 |
201 | dequantize_func:
202 | A function that dequantizes an input quantized weight tensor. It should be noted
203 | that this function needs to be able to handle dequantizing the entire weight tensor, a single group,
204 | or a single column.
205 | Args:
206 | quantized_weight: A 2d quantized weight tensor (generally with an integer dtype)
207 | qparams: the output from get_qparams_func
208 | Returns:
209 | weight: A 2d weight tensor with non-integer dtype.
210 |
211 | combine_qparams_list_func:
212 | A function that combines several qparams into one qparam.
213 | Args:
214 | qparams_list: a list of qparams objects, each obtained by calling get_qparams_func
215 | on a single group from a weight tensor
216 | Returns:
217 | qparams: an object of the same format as the qparams above.
218 |
219 | skip_layer_func:
220 | A function that determines which linear layers should be skipped during GPTQ
221 | Args:
222 | weight: A 2d weight tensor with non-integer dtype.
223 | Returns:
224 | skip: boolean indicating whether layer should be skipped
225 |
226 | make_names_and_values_dict_func:
227 | A function that prepares the qparams and quantized_weight and creates a dictionary indicating how they
228 | should be inserted into the state_dict. Generally any packing of the weight and qparams should be done here.
229 | Args:
230 | quantized_weight: A 2d quantized weight tensor (generally with an integer dtype)
231 | qparams: the output from get_qparams_func
232 | Returns:
233 | names_and_values_dict: a dictionary mapping the name of the parameters of the quantized module to the
234 | corresponding quantized weights and qparams.
235 | """
236 | def __init__(self):
237 | assert self.mod is not None
238 | assert self.get_qparams_func is not None
239 | assert self.quantize_func is not None
240 | assert self.dequantize_func is not None
241 | assert self.combine_qparams_list_func is not None
242 | assert self.make_names_and_values_dict_func is not None
243 |
244 | @staticmethod
245 | def get_inputs(model, tokenizer, calibration_tasks, calibration_limit, calibration_seq_length, pad_calibration_inputs) -> "MultiInput":
246 | input_recorder = InputRecorder(
247 | model,
248 | tokenizer,
249 | calibration_seq_length,
250 | pad_calibration_inputs,
251 | )
252 |
253 | try:
254 | lm_eval.tasks.initialize_tasks()
255 | except:
256 | pass
257 | task_dict = get_task_dict(calibration_tasks)
258 | print("Obtaining GPTQ calibration inputs on: ", calibration_tasks)
259 |
260 | evaluate(
261 | input_recorder,
262 | task_dict,
263 | limit=calibration_limit,
264 | )
265 | inputs = input_recorder.get_recorded_inputs()
266 | assert inputs is not None, (
267 | f"No inputs were collected, use a task other than {calibration_tasks}, "+
268 | f"use option pad_calibration_inputs, or decrease calibration_sequence_length (currently "+
269 | f"{calibration_seq_length})"
270 | )
271 | print(f"Obtained {len(inputs[0].values)} calibration samples")
272 | return inputs
273 |
274 | @torch.no_grad()
275 | def create_quantized_state_dict(
276 | self,
277 | tokenizer,
278 | blocksize,
279 | percdamp,
280 | groupsize,
281 | calibration_tasks,
282 | calibration_limit,
283 | calibration_seq_length,
284 | pad_calibration_inputs,
285 | ) -> "StateDict":
286 | inputs = GPTQQuantHandler.get_inputs(self.mod, tokenizer, calibration_tasks, calibration_limit, calibration_seq_length, pad_calibration_inputs)
287 | print("Tracing model for GPTQ")
288 | GPTQ_runner = GenericGPTQRunner(
289 | self.mod,
290 | inputs,
291 | blocksize,
292 | percdamp,
293 | groupsize,
294 | ).configure_quantization_mode(
295 | self.get_qparams_func,
296 | self.quantize_func,
297 | self.dequantize_func,
298 | self.combine_qparams_list_func,
299 | self.make_names_and_values_dict_func,
300 | self.skip_layer_func
301 | )
302 |
303 | print("Applying GPTQ to weights")
304 | GPTQ_runner.run()
305 | return GPTQ_runner.get_quantized_state_dict()
306 |
307 | def convert_for_runtime(self) -> "nn.Module":
308 | pass
309 |
310 | ##### Weight-only int8 per-channel quantized code ######
311 |
312 | def replace_linear_weight_only_int8_per_channel(module):
313 | for name, child in module.named_children():
314 | if isinstance(child, nn.Linear):
315 | setattr(module, name, WeightOnlyInt8Linear(child.in_features, child.out_features))
316 | else:
317 | replace_linear_weight_only_int8_per_channel(child)
318 |
319 | class WeightOnlyInt8QuantHandler:
320 | def __init__(self, mod):
321 | self.mod = mod
322 |
323 | @torch.no_grad()
324 | def create_quantized_state_dict(self):
325 | cur_state_dict = self.mod.state_dict()
326 | for fqn, mod in self.mod.named_modules():
327 | if isinstance(mod, torch.nn.Linear):
328 | int8_weight, scales, _ = dynamically_quantize_per_channel(mod.weight.float(), -128, 127, torch.int8)
329 | cur_state_dict[f"{fqn}.weight"] = int8_weight
330 | cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype)
331 |
332 | return cur_state_dict
333 |
334 | def convert_for_runtime(self):
335 | replace_linear_weight_only_int8_per_channel(self.mod)
336 | return self.mod
337 |
338 |
339 | class WeightOnlyInt8Linear(torch.nn.Module):
340 | __constants__ = ['in_features', 'out_features']
341 | in_features: int
342 | out_features: int
343 | weight: torch.Tensor
344 |
345 | def __init__(self, in_features: int, out_features: int, bias: bool = True,
346 | device=None, dtype=None) -> None:
347 | factory_kwargs = {'device': device, 'dtype': dtype}
348 | super().__init__()
349 | self.in_features = in_features
350 | self.out_features = out_features
351 | self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8))
352 | self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))
353 |
354 | def forward(self, input: torch.Tensor) -> torch.Tensor:
355 | return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
356 |
357 | ##### weight only int4 per channel groupwise quantized code ######
358 |
359 | def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
360 | weight_int32, scales_and_zeros = group_quantize_tensor(
361 | weight_bf16, n_bit=4, groupsize=groupsize
362 | )
363 | weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles)
364 | return weight_int4pack, scales_and_zeros
365 |
366 |
367 | def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
368 | origin_x_size = x.size()
369 | x = x.reshape(-1, origin_x_size[-1])
370 | c = torch.ops.aten._weight_int4pack_mm(x, weight_int4pack, groupsize, scales_and_zeros)
371 | new_shape = origin_x_size[:-1] + (out_features,)
372 | c = c.reshape(new_shape)
373 | return c
374 |
375 |
376 | def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = 1):
377 | return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
378 |
379 | def replace_linear_int4(module, groupsize, inner_k_tiles, padding):
380 | for name, child in module.named_children():
381 | if isinstance(child, nn.Linear):
382 | if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles):
383 | setattr(module, name, WeightOnlyInt4Linear(
384 | child.in_features, child.out_features, bias=False,
385 | groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=False,
386 | ))
387 | elif padding:
388 | setattr(module, name, WeightOnlyInt4Linear(
389 | child.in_features, child.out_features, bias=False,
390 | groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=True,
391 | ))
392 | else:
393 | replace_linear_int4(child, groupsize, inner_k_tiles, padding)
394 |
395 |
396 | class WeightOnlyInt4QuantHandler:
397 | def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
398 | self.mod = mod
399 | self.groupsize = groupsize
400 | self.inner_k_tiles = inner_k_tiles
401 | self.padding = padding
402 | assert groupsize in [32, 64, 128, 256]
403 | assert inner_k_tiles in [2, 4, 8]
404 |
405 | @torch.no_grad()
406 | def create_quantized_state_dict(self, use_cuda = True):
407 | if use_cuda:
408 | device="cuda"
409 | else:
410 | device="cpu"
411 |
412 | cur_state_dict = self.mod.state_dict()
413 | for fqn, mod in self.mod.named_modules():
414 | if isinstance(mod, torch.nn.Linear):
415 | assert not mod.bias
416 | out_features = mod.out_features
417 | in_features = mod.in_features
418 | assert out_features % 8 == 0, "require out_features % 8 == 0"
419 | print(f"linear: {fqn}, in={in_features}, out={out_features}")
420 |
421 | weight = mod.weight.data
422 | if not _check_linear_int4_k(in_features, self.groupsize, self.inner_k_tiles):
423 | if self.padding:
424 | from model import find_multiple
425 | import torch.nn.functional as F
426 | print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0")
427 | padded_in_features = find_multiple(in_features, 1024)
428 | weight = F.pad(weight, pad=(0, padded_in_features - in_features))
429 | else:
430 | print(f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " +
431 | "and that groupsize and inner_k_tiles*16 evenly divide into it")
432 | continue
433 | weight_int4pack, scales_and_zeros = prepare_int4_weight_and_scales_and_zeros(
434 | weight.to(torch.bfloat16).to(device=device), self.groupsize, self.inner_k_tiles
435 | )
436 | cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to('cpu')
437 | cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to('cpu')
438 |
439 | return cur_state_dict
440 |
441 | def convert_for_runtime(self):
442 | replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
443 | return self.mod
444 |
445 | class WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler):
446 | def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
447 | from model import find_multiple
448 | self.mod = mod
449 | self.groupsize = groupsize
450 | self.inner_k_tiles = inner_k_tiles
451 | self.padding = padding
452 | self.get_qparams_func = lambda w: get_group_qparams(w, 4, groupsize)
453 | self.quantize_func = lambda w, qparams: \
454 | group_quantize_tensor_from_qparams(w, qparams[0], qparams[1], 4, groupsize)
455 | self.dequantize_func = lambda q, qparams: \
456 | group_dequantize_tensor_from_qparams(q, qparams[0], qparams[1], 4, groupsize).float()
457 | self.combine_qparams_list_func = lambda qparams_list: \
458 | [torch.cat(x, dim=1) for x in zip(*qparams_list)]
459 | # skip unless padding=True or its correctly sized
460 | self.skip_layer_func = lambda linear_weight: not (
461 | _check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles) or padding
462 | )
463 | # we need to do the padding here, both for q and the qparams if necessary
464 | def make_names_and_values_dict_func(q, qparams):
465 | k = q.shape[1]
466 | new_k = find_multiple(k, 1024)
467 | # how much we need to pad the weight
468 | delta_k = new_k - q.shape[1]
469 | final_q = torch.ops.aten._convert_weight_to_int4pack(F.pad(q, pad=(0, delta_k)), inner_k_tiles)
470 | scales_and_zeros = pack_scales_and_zeros(*qparams)
471 | # how many new groups we need for padded weight
472 | delta_groups = new_k // groupsize - scales_and_zeros.shape[0]
473 | final_s_and_z = F.pad(scales_and_zeros, pad=(0,0,0,0,0, delta_groups), value=1)
474 | return {"weight": final_q, "scales_and_zeros": final_s_and_z}
475 | self.make_names_and_values_dict_func = make_names_and_values_dict_func
476 | super().__init__()
477 |
478 |
479 | def convert_for_runtime(self):
480 | replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
481 | return self.mod
482 |
483 | class WeightOnlyInt4Linear(torch.nn.Module):
484 | __constants__ = ['in_features', 'out_features']
485 | in_features: int
486 | out_features: int
487 | weight: torch.Tensor
488 |
489 | def __init__(
490 | self, in_features: int, out_features: int,
491 | bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, padding: bool = True,
492 | ) -> None:
493 | super().__init__()
494 | self.padding = padding
495 | if padding:
496 | from model import find_multiple
497 | self.origin_in_features = in_features
498 | in_features = find_multiple(in_features, 1024)
499 |
500 | self.in_features = in_features
501 | self.out_features = out_features
502 | assert not bias, "require bias=False"
503 | self.groupsize = groupsize
504 | self.inner_k_tiles = inner_k_tiles
505 |
506 | assert out_features % 8 == 0, "require out_features % 8 == 0"
507 | assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0"
508 | self.register_buffer(
509 | "weight",
510 | torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
511 | )
512 | self.register_buffer(
513 | "scales_and_zeros",
514 | torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16)
515 | )
516 |
517 | def forward(self, input: torch.Tensor) -> torch.Tensor:
518 | input = input.to(torch.bfloat16)
519 | if self.padding:
520 | import torch.nn.functional as F
521 | input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
522 | return linear_forward_int4(
523 | input,
524 | self.weight, self.scales_and_zeros, self.out_features, self.groupsize
525 | )
526 |
527 |
528 | def quantize(
529 | checkpoint_path: Path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"),
530 | mode: str = 'int8',
531 | # following arguments only available when setting int4 quantization.
532 | groupsize: int = 128,
533 | # following arguments only used for GPTQ
534 | calibration_tasks: list = ["hellaswag"],
535 | calibration_limit: int = 1000,
536 | calibration_seq_length: int = 100,
537 | pad_calibration_inputs: bool = False,
538 | percdamp: float = .01,
539 | blocksize: int = 128,
540 | label: str = '',
541 | ) -> None:
542 | assert checkpoint_path.is_file(), checkpoint_path
543 |
544 | device = 'cpu'
545 | precision = torch.bfloat16
546 |
547 | print("Loading model ...")
548 | t0 = time.time()
549 |
550 | with torch.device('meta'):
551 | model = Transformer.from_name(checkpoint_path.parent.name)
552 |
553 | checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
554 | model.load_state_dict(checkpoint, assign=True)
555 | model = model.to(dtype=precision, device=device)
556 |
557 | if mode == 'int8':
558 | print("Quantizing model weights for int8 weight-only symmetric per-channel quantization")
559 | quant_handler = WeightOnlyInt8QuantHandler(model)
560 | quantized_state_dict = quant_handler.create_quantized_state_dict()
561 |
562 | dir_name = checkpoint_path.parent
563 | base_name = checkpoint_path.name
564 | new_base_name = base_name.replace('.pth', f'{label}int8.pth')
565 |
566 | elif mode == 'int4':
567 | print("Quantizing model weights for int4 weight-only affine per-channel groupwise quantization")
568 | quant_handler = WeightOnlyInt4QuantHandler(model, groupsize)
569 | quantized_state_dict = quant_handler.create_quantized_state_dict()
570 |
571 | dir_name = checkpoint_path.parent
572 | base_name = checkpoint_path.name
573 | new_base_name = base_name.replace('.pth', f"{label}int4.g{groupsize}.pth")
574 |
575 | elif mode == 'int4-gptq':
576 | print("Quantizing model weights for int4 weight-only affine per-channel groupwise quantization using GPTQ...")
577 | quant_handler = WeightOnlyInt4GPTQQuantHandler(model, groupsize)
578 |
579 | tokenizer_path = checkpoint_path.parent / "tokenizer.model"
580 | assert tokenizer_path.is_file(), str(tokenizer_path)
581 | tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)
582 |
583 | quantized_state_dict = quant_handler.create_quantized_state_dict(
584 | tokenizer,
585 | blocksize,
586 | percdamp,
587 | groupsize,
588 | calibration_tasks,
589 | calibration_limit,
590 | calibration_seq_length,
591 | pad_calibration_inputs
592 | )
593 |
594 | dir_name = checkpoint_path.parent
595 | base_name = checkpoint_path.name
596 | new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.pth")
597 | else:
598 | raise ValueError(f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gpptq]")
599 |
600 | quantize_path = dir_name / new_base_name
601 | print(f"Writing quantized weights to {quantize_path}")
602 | quantize_path.unlink(missing_ok=True) # remove existing file if one already there
603 | torch.save(quantized_state_dict, quantize_path)
604 | print(f"Quantization complete took {time.time() - t0:.02f} seconds")
605 | return
606 |
607 | if __name__ == '__main__':
608 | import argparse
609 | parser = argparse.ArgumentParser(description='Quantize a model.')
610 | parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Path to the model checkpoint to be quantized.')
611 | parser.add_argument('--mode', '-q', type=str, default='int8', choices=['int8', 'int4', 'int4-gptq'], help='type of quantization to perform')
612 | parser.add_argument('--groupsize', type=int, default=32, help='Group size for int4 quantization.')
613 | parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['wikitext'], help='tasks to do gptq calibration on, if doing gptq')
614 | parser.add_argument('--calibration_limit', type=int, default=1000, help='number of samples to use for gptq calibration')
615 | parser.add_argument('--calibration_seq_length', type=int, default=100, help='length of sequences to use for gptq calibration')
616 | parser.add_argument('--pad_calibration_inputs', type=bool, default=False, help='pads sequences shorter than calibration_seq_length to that length, yielding more calibration inputs but running much slower')
617 | parser.add_argument('--percdamp', type=float, default=.01, help='gptq percentage dampening')
618 | parser.add_argument('--blocksize', type=int, default=128, help='blocksize for gptq')
619 | parser.add_argument('--label', type=str, default='_', help='label to add to output filename')
620 |
621 | args = parser.parse_args()
622 | quantize(args.checkpoint_path, args.mode, args.groupsize, args.calibration_tasks, args.calibration_limit, args.calibration_seq_length, args.pad_calibration_inputs, args.percdamp, args.blocksize, args.label)
623 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | sentencepiece
3 | tiktoken
4 | blobfile
5 | safetensors
6 |
--------------------------------------------------------------------------------
/scripts/convert_hf_checkpoint.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | import json
7 | import re
8 | import shutil
9 | import sys
10 | from pathlib import Path
11 | from typing import Optional
12 | from safetensors.torch import load_file as load_safetensors_file
13 | import torch
14 |
15 | # support running without installing as a package
16 | wd = Path(__file__).parent.parent.resolve()
17 | sys.path.append(str(wd))
18 |
19 | from model import ModelArgs
20 |
21 |
22 | @torch.inference_mode()
23 | def convert_hf_checkpoint(
24 | *,
25 | checkpoint_dir: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf"),
26 | model_name: Optional[str] = None,
27 | ) -> None:
28 | if model_name is None:
29 | model_name = checkpoint_dir.name
30 |
31 | config = ModelArgs.from_name(model_name)
32 | print(f"Model config {config.__dict__}")
33 |
34 | # Load the json file containing weight mapping
35 | model_map_json_safetensors = checkpoint_dir / 'model.safetensors.index.json'
36 | model_map_json_pytorch = checkpoint_dir / "pytorch_model.bin.index.json"
37 | model_map_json = None
38 |
39 | try:
40 | assert model_map_json_safetensors.is_file()
41 | model_map_json = model_map_json_safetensors
42 | print(f"Found safetensors index at {model_map_json_safetensors}")
43 | except AssertionError:
44 | print(f"{model_map_json_safetensors} not found")
45 | if model_map_json is None:
46 | try:
47 | assert model_map_json_pytorch.is_file()
48 | model_map_json = model_map_json_pytorch
49 | print(f"Found pytorch index at {model_map_json_pytorch}")
50 | except AssertionError:
51 | print(f"{model_map_json_pytorch} not found")
52 |
53 | if model_map_json is None: raise Exception("No model map found!")
54 |
55 | with open(model_map_json) as json_map:
56 | bin_index = json.load(json_map)
57 |
58 | weight_map = {
59 | "model.embed_tokens.weight": "tok_embeddings.weight",
60 | "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
61 | "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
62 | "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
63 | "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
64 | 'model.layers.{}.self_attn.rotary_emb.inv_freq': None,
65 | 'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight',
66 | "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
67 | "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
68 | "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
69 | "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
70 | "model.norm.weight": "norm.weight",
71 | "lm_head.weight": "output.weight",
72 | }
73 | bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
74 |
75 | def permute(w, n_head):
76 | dim = config.dim
77 | return (
78 | w.view(n_head, 2, config.head_dim // 2, dim)
79 | .transpose(1, 2)
80 | .reshape(config.head_dim * n_head, dim)
81 | )
82 |
83 | merged_result = {}
84 | for file in sorted(bin_files):
85 | if "safetensors" in str(file):
86 | state_dict = load_safetensors_file(str(file), device="cpu")
87 | merged_result.update(state_dict)
88 | else:
89 | state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
90 | merged_result.update(state_dict)
91 | final_result = {}
92 | for key, value in merged_result.items():
93 | if "layers" in key:
94 | abstract_key = re.sub(r'(\d+)', '{}', key)
95 | layer_num = re.search(r'\d+', key).group(0)
96 | new_key = weight_map[abstract_key]
97 | if new_key is None:
98 | continue
99 | new_key = new_key.format(layer_num)
100 | else:
101 | new_key = weight_map[key]
102 |
103 | final_result[new_key] = value
104 |
105 | for key in tuple(final_result.keys()):
106 | if "wq" in key:
107 | q = final_result[key]
108 | k = final_result[key.replace("wq", "wk")]
109 | v = final_result[key.replace("wq", "wv")]
110 | q = permute(q, config.n_head)
111 | k = permute(k, config.n_local_heads)
112 | final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
113 | del final_result[key]
114 | del final_result[key.replace("wq", "wk")]
115 | del final_result[key.replace("wq", "wv")]
116 | print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
117 | torch.save(final_result, checkpoint_dir / "model.pth")
118 | if 'llama-3-' in model_name.lower() or 'llama-3.1-' in model_name.lower():
119 | if 'llama-3.1-405b' in model_name.lower():
120 | original_dir = checkpoint_dir / "original" / "mp16"
121 | else:
122 | original_dir = checkpoint_dir / "original"
123 | tokenizer_model = original_dir / "tokenizer.model"
124 | tokenizer_model_tiktoken = checkpoint_dir / "tokenizer.model"
125 | print(f"Copying {tokenizer_model} to {tokenizer_model_tiktoken}")
126 | shutil.copy(tokenizer_model, tokenizer_model_tiktoken)
127 |
128 | if __name__ == '__main__':
129 | import argparse
130 | parser = argparse.ArgumentParser(description='Convert HuggingFace checkpoint.')
131 | parser.add_argument('--checkpoint_dir', type=Path, default=Path("checkpoints/meta-llama/llama-2-7b-chat-hf"))
132 | parser.add_argument('--model_name', type=str, default=None)
133 |
134 | args = parser.parse_args()
135 | convert_hf_checkpoint(
136 | checkpoint_dir=args.checkpoint_dir,
137 | model_name=args.model_name,
138 | )
139 |
--------------------------------------------------------------------------------
/scripts/download.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | import os
7 | from typing import Optional
8 |
9 | from requests.exceptions import HTTPError
10 |
11 |
12 | def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -> None:
13 | from huggingface_hub import snapshot_download
14 | os.makedirs(f"checkpoints/{repo_id}", exist_ok=True)
15 | try:
16 | snapshot_download(repo_id, local_dir=f"checkpoints/{repo_id}", local_dir_use_symlinks=False, token=hf_token)
17 | except HTTPError as e:
18 | if e.response.status_code == 401:
19 | print("You need to pass a valid `--hf_token=...` to download private checkpoints.")
20 | else:
21 | raise e
22 |
23 | if __name__ == '__main__':
24 | import argparse
25 | parser = argparse.ArgumentParser(description='Download data from HuggingFace Hub.')
26 | parser.add_argument('--repo_id', type=str, default="checkpoints/meta-llama/llama-2-7b-chat-hf", help='Repository ID to download from.')
27 | parser.add_argument('--hf_token', type=str, default=None, help='HuggingFace API token.')
28 |
29 | args = parser.parse_args()
30 | hf_download(args.repo_id, args.hf_token)
31 |
--------------------------------------------------------------------------------
/scripts/prepare.sh:
--------------------------------------------------------------------------------
1 | python scripts/download.py --repo_id $1 && python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$1 && python quantize.py --checkpoint_path checkpoints/$1/model.pth --mode int8
2 |
--------------------------------------------------------------------------------
/scripts/speculate_34B_bf16.sh:
--------------------------------------------------------------------------------
1 | # 56.80
2 | export MODEL_REPO=codellama/CodeLlama-34b-Python-hf
3 | export DRAFT_MODEL_REPO=codellama/CodeLlama-7b-Python-hf
4 | time python generate.py --compile --draft_checkpoint_path checkpoints/$DRAFT_MODEL_REPO/model_int4.g32.pth --checkpoint_path checkpoints/$MODEL_REPO/model.pth --speculate_k 6 --prompt "def quicksort(arr):" --max_new_tokens 200 --num_samples 50
5 |
--------------------------------------------------------------------------------
/scripts/speculate_70B_int4.sh:
--------------------------------------------------------------------------------
1 | # 49.26
2 | export MODEL_REPO=meta-llama/Llama-2-70b-chat-hf
3 | export DRAFT_MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
4 | time python generate.py --compile --draft_checkpoint_path checkpoints/$DRAFT_MODEL_REPO/model_int4.g32.pth --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.pth --speculate_k 4 --prompt "def quicksort(arr):" --max_new_tokens 100 --num_samples 50 --temperature 0
5 |
--------------------------------------------------------------------------------
/scripts/speculate_7B_int4.sh:
--------------------------------------------------------------------------------
1 | export MODEL_REPO=codellama/CodeLlama-7b-Python-hf
2 | export DRAFT_MODEL_REPO=PY007/TinyLlama-1.1B-intermediate-step-480k-1T
3 | time python generate.py --compile --draft_checkpoint_path checkpoints/$DRAFT_MODEL_REPO/model_int4.g32.pth --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.pth --speculate_k 5 --prompt "Hi my name is" --max_new_tokens 200 --num_samples 50 --temperature 0 --compile_prefill
4 |
--------------------------------------------------------------------------------
/scripts/speculate_tp_70B_bf16.sh:
--------------------------------------------------------------------------------
1 | export MODEL_REPO=meta-llama/Llama-2-70b-chat-hf
2 | export DRAFT_MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
3 | time torchrun --standalone --nproc_per_node=8 generate.py --compile --draft_checkpoint_path checkpoints/$DRAFT_MODEL_REPO/model_int8.pth --checkpoint_path checkpoints/$MODEL_REPO/model.pth --speculate_k 5 --prompt "def quicksort(arr):" --max_new_tokens 200 --num_samples 50 --temperature 0
4 |
--------------------------------------------------------------------------------
/scripts/test_flow.sh:
--------------------------------------------------------------------------------
1 | export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
2 | rm -r checkpoints/$MODEL_REPO
3 | python scripts/download.py --repo_id $MODEL_REPO
4 | python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$MODEL_REPO
5 | python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth
6 | python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model_int8.pth --max_new_tokens 100
7 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | from setuptools import setup, find_packages
7 |
8 | setup(
9 | name='gpt-fast',
10 | version='0.1',
11 | packages=find_packages(),
12 | install_requires=[
13 | 'torch',
14 | ],
15 | description='A simple, fast, pure PyTorch Llama inference engine',
16 | long_description=open('README.md').read(),
17 | long_description_content_type='text/markdown',
18 | url='https://github.com/meta-pytorch/gpt-fast',
19 | )
20 |
--------------------------------------------------------------------------------
/tokenizer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sentencepiece as spm
3 | import tiktoken
4 | from tiktoken.load import load_tiktoken_bpe
5 | from pathlib import Path
6 | from typing import Dict
7 |
8 | class TokenizerInterface:
9 | def __init__(self, model_path):
10 | self.model_path = model_path
11 |
12 | def encode(self, text):
13 | raise NotImplementedError("This method should be overridden by subclasses.")
14 |
15 | def decode(self, tokens):
16 | raise NotImplementedError("This method should be overridden by subclasses.")
17 |
18 | def bos_id(self):
19 | raise NotImplementedError("This method should be overridden by subclasses.")
20 |
21 | def eos_id(self):
22 | raise NotImplementedError("This method should be overridden by subclasses.")
23 |
24 | class SentencePieceWrapper(TokenizerInterface):
25 | def __init__(self, model_path):
26 | super().__init__(model_path)
27 | self.processor = spm.SentencePieceProcessor(str(model_path))
28 |
29 | def encode(self, text):
30 | return self.processor.EncodeAsIds(text)
31 |
32 | def decode(self, tokens):
33 | return self.processor.DecodeIds(tokens)
34 |
35 | def bos_id(self):
36 | return self.processor.bos_id()
37 |
38 | def eos_id(self):
39 | return self.processor.eos_id()
40 |
41 | class TiktokenWrapper(TokenizerInterface):
42 | """
43 | Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
44 | """
45 |
46 | special_tokens: Dict[str, int]
47 |
48 | num_reserved_special_tokens = 256
49 |
50 | pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501
51 |
52 | def __init__(self, model_path):
53 | super().__init__(model_path)
54 | assert os.path.isfile(model_path), str(model_path)
55 | mergeable_ranks = load_tiktoken_bpe(str(model_path))
56 | num_base_tokens = len(mergeable_ranks)
57 | special_tokens = [
58 | "<|begin_of_text|>",
59 | "<|end_of_text|>",
60 | "<|reserved_special_token_0|>",
61 | "<|reserved_special_token_1|>",
62 | "<|reserved_special_token_2|>",
63 | "<|reserved_special_token_3|>",
64 | "<|start_header_id|>",
65 | "<|end_header_id|>",
66 | "<|reserved_special_token_4|>",
67 | "<|eot_id|>", # end of turn
68 | ] + [
69 | f"<|reserved_special_token_{i}|>"
70 | for i in range(5, self.num_reserved_special_tokens - 5)
71 | ]
72 | self.special_tokens = {
73 | token: num_base_tokens + i for i, token in enumerate(special_tokens)
74 | }
75 | self.model = tiktoken.Encoding(
76 | name=Path(model_path).name,
77 | pat_str=self.pat_str,
78 | mergeable_ranks=mergeable_ranks,
79 | special_tokens=self.special_tokens,
80 | )
81 | # BOS / EOS token IDs
82 | self._bos_id: int = self.special_tokens["<|begin_of_text|>"]
83 | self._eos_id: int = self.special_tokens["<|end_of_text|>"]
84 |
85 | def encode(self, text):
86 | return self.model.encode(text)
87 |
88 | def decode(self, tokens):
89 | return self.model.decode(tokens)
90 |
91 | def bos_id(self):
92 | return self._bos_id
93 |
94 | def eos_id(self):
95 | return self._eos_id
96 |
97 | def get_tokenizer(tokenizer_model_path, model_name):
98 | """
99 | Factory function to get the appropriate tokenizer based on the model name.
100 |
101 | Args:
102 | - tokenizer_model_path (str): The file path to the tokenizer model.
103 | - model_name (str): The name of the model, used to determine the tokenizer type.
104 |
105 | Returns:
106 | - TokenizerInterface: An instance of a tokenizer.
107 | """
108 |
109 | if "llama-3" in str(model_name).lower():
110 | return TiktokenWrapper(tokenizer_model_path)
111 | else:
112 | return SentencePieceWrapper(tokenizer_model_path)
113 |
--------------------------------------------------------------------------------
/tp.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | import os
7 | from typing import List, Optional
8 |
9 | import torch
10 | import torch.distributed as dist
11 | from torch import nn
12 | if os.uname().sysname != "Darwin":
13 | from torch.distributed import _functional_collectives as funcol
14 | else:
15 | # Distributed is not supported on MacOS
16 | funcol = None
17 |
18 | from model import Attention, FeedForward, Transformer
19 | from quantize import WeightOnlyInt4Linear
20 |
21 |
22 | def _get_rank() -> int:
23 | return int(os.environ.get("LOCAL_RANK", "0"))
24 |
25 | def is_local():
26 | return _get_rank() == 0
27 |
28 | def local_break():
29 | if is_local():
30 | breakpoint()
31 | dist.barrier()
32 |
33 | def _get_world_size() -> int:
34 | return int(os.environ.get("LOCAL_WORLD_SIZE", "1"))
35 |
36 | def maybe_init_dist() -> Optional[int]:
37 | try:
38 | # provided by torchrun
39 | rank = _get_rank()
40 | world_size = _get_world_size()
41 |
42 | if world_size < 2:
43 | # too few gpus to parallelize, tp is no-op
44 | return None
45 | except KeyError:
46 | # not run via torchrun, no-op
47 | return None
48 |
49 | torch.cuda.set_device(rank)
50 | dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
51 | return rank
52 |
53 |
54 | def _apply_tp_linear(linear: nn.Linear, style: str, weight_splits: List[int] = []) -> None:
55 | rank = _get_rank()
56 | world_size = _get_world_size()
57 |
58 | # Linear's weight matrix is transposed, and is of shape
59 | # (linear.out_features, linear.in_features)
60 | dim_lookup = {
61 | "colwise": (0, "out_features"),
62 | "rowwise": (1, "in_features")
63 | }
64 | assert style in dim_lookup
65 | shard_dim, size_attr = dim_lookup[style]
66 |
67 | # ensure we can shard evenly
68 | assert getattr(linear, size_attr) % world_size == 0
69 | def shard(x, dim):
70 | assert x.size(dim=dim) % world_size == 0
71 | return torch.tensor_split(x, world_size, dim=dim)[rank]
72 |
73 | def shard_qkv(qkv, dim, weight_splits):
74 | q, k, v = qkv.split(weight_splits, dim=dim)
75 | q = shard(q, dim)
76 | k = shard(k, dim)
77 | v = shard(v, dim)
78 | return torch.cat((q,k,v), dim=dim)
79 |
80 | # shard
81 | if weight_splits:
82 | # attention
83 | assert len(weight_splits) == 3
84 |
85 | if isinstance(linear, WeightOnlyInt4Linear):
86 | sharded_weight = shard_qkv(linear.weight, shard_dim, [i//8 for i in weight_splits])
87 | linear.scales_and_zeros = shard_qkv(linear.scales_and_zeros, 1 - shard_dim, weight_splits)
88 | else:
89 | sharded_weight = shard_qkv(linear.weight, shard_dim, weight_splits)
90 | if hasattr(linear, "scales") and style == "colwise":
91 | linear.scales = shard_qkv(linear.scales, 0, weight_splits)
92 | else:
93 | sharded_weight = shard(linear.weight, shard_dim)
94 | if isinstance(linear, WeightOnlyInt4Linear):
95 | linear.scales_and_zeros = shard(linear.scales_and_zeros, 1 - shard_dim)
96 | if style == "rowwise":
97 | assert linear.scales_and_zeros.shape[0] * 32 == sharded_weight.shape[1] * sharded_weight.shape[2] * sharded_weight.shape[3]
98 | assert linear.scales_and_zeros.shape[1] == sharded_weight.shape[0] * 8
99 | if hasattr(linear, "scales") and style == "colwise":
100 | linear.scales = shard(linear.scales, 0)
101 |
102 | # local_break()
103 | linear.weight = nn.Parameter(sharded_weight, requires_grad=False)
104 | setattr(linear, size_attr, getattr(linear, size_attr) // world_size)
105 |
106 | # shape info should still be synced
107 | # assert linear.weight.shape == (linear.out_features, linear.in_features)
108 |
109 |
110 | def _apply_tp_ffn(mlp: FeedForward) -> None:
111 | assert hasattr(mlp, "w1")
112 | assert hasattr(mlp, "w3")
113 | assert hasattr(mlp, "w2")
114 |
115 | _apply_tp_linear(mlp.w1, "colwise")
116 | _apply_tp_linear(mlp.w3, "colwise")
117 | _apply_tp_linear(mlp.w2, "rowwise")
118 |
119 | world_size = _get_world_size()
120 | mlp.register_forward_hook(lambda _module, _input, output: funcol.all_reduce(
121 | output, "sum", list(range(world_size))))
122 |
123 |
124 | def _apply_tp_attn(attn: Attention) -> None:
125 | assert hasattr(attn, "wqkv")
126 | assert hasattr(attn, "wo")
127 |
128 | kv_size = attn.n_local_heads * attn.head_dim
129 | _apply_tp_linear(attn.wqkv, "colwise", [attn.dim, kv_size, kv_size])
130 | _apply_tp_linear(attn.wo, "rowwise")
131 |
132 | # overwrite
133 | world_size = _get_world_size()
134 | attn.n_head = attn.n_head // world_size
135 | attn.dim = attn.dim // world_size
136 | attn.head_dim = attn.dim // attn.n_head
137 | attn.n_local_heads = attn.n_local_heads // world_size
138 |
139 | attn.register_forward_hook(lambda _module, _input, output: funcol.all_reduce(
140 | output[0], "sum", list(range(world_size))))
141 |
142 |
143 | def _apply_tp_Transformer(Transformer: Transformer) -> None:
144 | # overwrite config before Transformer.setup_cache is called
145 | world_size = _get_world_size()
146 | Transformer.config.n_head = Transformer.config.n_head // world_size
147 | Transformer.config.dim = Transformer.config.dim // world_size
148 | Transformer.config.n_local_heads = Transformer.config.n_local_heads // world_size
149 |
150 |
151 | def apply_tp(model: Transformer) -> None:
152 | _apply_tp_Transformer(model)
153 | for block in model.layers:
154 | # Apply to MLP
155 | _apply_tp_ffn(block.feed_forward)
156 | _apply_tp_attn(block.attention)
157 |
--------------------------------------------------------------------------------