├── .gitignore ├── .pre-commit-config.yaml ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── Makefile ├── README.md ├── assets └── architecture.png ├── benchmark.py ├── gpt_fast ├── __init__.py ├── gpt_dense_TP.py ├── gpt_desync_TP.py ├── gpt_ladder_TP.py ├── gpt_parallel_TP.py ├── tp.py └── utils.py ├── hf_modeling_utils ├── configs │ ├── Llama-3.1-8B-Instruct-Ladder-last16L.json │ └── Llama-3.1-8B-Instruct-Ladder-last20L.json ├── configuration_llama_ladder.py └── modeling_llama_ladder.py ├── requirements.txt ├── scripts ├── throughput-1B.sh ├── throughput-34B.sh ├── throughput-3B.sh ├── throughput-405B.sh ├── throughput-70B.sh ├── throughput-8B.sh └── throughput-bloom176b.sh ├── setup.py └── tools ├── plot-405b.py └── plot-70b.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .idea 3 | .DS_Store 4 | *.egg-info 5 | build 6 | 7 | # data 8 | data 9 | checkpoints 10 | out 11 | !data/shakespeare/prepare.py 12 | wandb 13 | 14 | # downloaded by our tests 15 | original_model.py 16 | original_adapter.py 17 | 18 | *.trace.json 19 | 20 | # torch inductor & torch elastic 21 | tmp* 22 | torchelastic* 23 | torchinductor* 24 | 25 | # logs file path 26 | logs/ 27 | 28 | # profiling results 29 | /profiles -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/PyCQA/autoflake 3 | rev: v2.3.1 4 | hooks: 5 | - id: autoflake 6 | - repo: https://github.com/pycqa/isort 7 | rev: 5.13.2 8 | hooks: 9 | - id: isort 10 | name: isort (python) 11 | - repo: https://github.com/psf/black 12 | rev: 24.8.0 13 | hooks: 14 | - id: black 15 | args: [--line-length=119,--target-version=py311] 16 | - repo: https://github.com/pre-commit/mirrors-clang-format 17 | rev: v18.1.8 18 | hooks: 19 | - id: clang-format 20 | types_or: [c++, c, cuda] 21 | args: [-style=file:.clang-format] 22 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at . All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to gpt-fast 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | 6 | ## Pull Requests 7 | We actively welcome your pull requests. 8 | 9 | 1. 
Fork the repo and create your branch from `main`. 10 | 2. If you've added code that should be tested, add tests. 11 | 3. If you've changed APIs, update the documentation. 12 | 4. Ensure the test suite passes. 13 | 5. Make sure your code lints. 14 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 15 | 16 | ## Contributor License Agreement ("CLA") 17 | In order to accept your pull request, we need you to submit a CLA. You only need 18 | to do this once to work on any of Meta's open source projects. 19 | 20 | Complete your CLA here: 21 | 22 | ## Issues 23 | We use GitHub issues to track public bugs. Please ensure your description is 24 | clear and has sufficient instructions to be able to reproduce the issue. 25 | 26 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe 27 | disclosure of security bugs. In those cases, please go through the process 28 | outlined on that page and do not file a public issue. 29 | 30 | ## License 31 | By contributing to `gpt-fast`, you agree that your contributions will be licensed 32 | under the LICENSE file in the root directory of this source tree. 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2023 Meta 2 | 3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 4 | 5 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 6 | 7 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 8 | 9 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 10 | 11 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 12 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | style: 2 | pre-commit run --all-files 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Ladder-Residual-Inference 2 | This repository contains the code for inference benchmarking for the paper [Ladder-residual: parallelism-aware architecture for accelerating large model inference with communication overlapping](https://arxiv.org/abs/2501.06589). 
3 | 
4 | If you are interested in training Ladder Residual models, you can find the training code in [dolomite-engine](https://github.com/IBM/dolomite-engine).
5 | 
6 | ## Ladder Residual
7 | Tensor Parallelism (TP) is commonly used to partition large language models (LLMs) across multiple accelerators to reduce the memory load and computation time during training and inference. However, TP is communication-bound and therefore requires fast interconnects (e.g., NVLink for NVIDIA GPUs) between the devices, and such interconnects are only available on high-end datacenter GPUs. Even when they are present, TP communication is often a significant bottleneck and limits the gains that can be achieved by adding more accelerators.
8 | 
9 | To mitigate this issue, we propose Ladder Residual: a simple architectural modification, compatible with all residual-based models, that enables straightforward overlapping and effectively hides the communication latency. Our insight is that, rather than relying solely on systems-level optimizations, we can re-architect the model to decouple communication from computation, so the model keeps processing input while communication runs in the background (a minimal sketch of the resulting block structure is given at the end of this README).
10 | 
11 | For a Transformer model (Llama-3.1-8B), applying Ladder Residual to all of its layers achieves a 29% end-to-end wall-clock speedup at inference time with a TP world size of 8 devices. We refer to such a model as the Ladder Transformer. We train 1B and 3B Ladder Transformers from scratch and observe performance comparable to a standard dense Transformer baseline. We also conduct adaptation experiments and show that parts of the Llama-3.1-8B model can be adapted with minimal accuracy degradation by retraining on only 3B tokens.
12 | 
13 | We further explore an advanced architectural variant that eliminates communication altogether, enabling fast LLM inference on systems lacking high-speed interconnects.
14 | 
15 | The Ladder Residual Transformer (Ladder Transformer) is a decoder-based LLM architecture that overlaps computation with communication at inference time purely through a model-architecture change. The approach requires no custom kernels, which keeps it easily scalable and applicable to different hardware architectures and ML frameworks.
16 | 
17 | ![image](./assets/architecture.png)
18 | 
19 | ## Usage
20 | To run the code, install this repository and run one of the benchmarking scripts (70B in this example):
21 | ```shell
22 | pip install -e .
23 | sh scripts/throughput-70B.sh
24 | ```
25 | For multi-node benchmarking, you can use the following command:
26 | ```shell
27 | sh scripts/throughput-405B.sh <node_rank>
28 | ```
29 | where `<node_rank>` is the rank of the current node; please make sure to update `master_addr` and `master_port` in the script.
30 | 
31 | ## Acknowledgement
32 | 
33 | This repository is based on [gpt-fast](https://github.com/pytorch-labs/gpt-fast) and runs fully with PyTorch compile. The model architecture in this repository is based on the Llama architecture.
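
## Ladder Residual in a nutshell

The snippet below is a minimal, simplified sketch of the per-block pattern implemented in `gpt_fast/gpt_ladder_TP.py`; rotary embeddings, masks, the KV cache and the fused Triton add are omitted, and `all_reduce_async` is a small stand-in helper (the repository itself uses `all_reduce_func(..., async_op=True)` from `gpt_fast/utils.py`). Each block consumes the *previous* block's attention and MLP outputs, so the all-reduce of the current block's outputs can overlap with the next block's computation:

```python
import torch
import torch.distributed as dist


def all_reduce_async(x: torch.Tensor):
    # Launch a SUM all-reduce without waiting on it; a *later* block waits on the
    # returned handle, which is what hides the communication latency.
    handle = dist.all_reduce(x, op=dist.ReduceOp.SUM, async_op=True)
    return x, handle


def ladder_block(block, residual, prev_attn_out, prev_mlp_out, attn_handle, mlp_handle):
    # Fold in the previous block's attention output only once its all-reduce has finished.
    if attn_handle is not None:
        attn_handle.wait()
    residual = residual + prev_attn_out

    # Compute this block's attention and start (but do not wait for) its all-reduce.
    attn_out = block.attention(block.attention_norm(residual))
    attn_out, attn_handle = all_reduce_async(attn_out)

    # Same pattern for the MLP: wait for the previous block's MLP all-reduce first.
    if mlp_handle is not None:
        mlp_handle.wait()
    residual = residual + prev_mlp_out

    mlp_out = block.feed_forward(block.ffn_norm(residual))
    mlp_out, mlp_handle = all_reduce_async(mlp_out)

    # The pending handles are resolved by the next block (or after the final layer).
    return attn_out, mlp_out, residual, attn_handle, mlp_handle
```

After the final layer, `GPTLadder.forward` waits on the two outstanding handles and adds the last attention and MLP outputs back into the residual stream before the final norm and output projection.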
34 | -------------------------------------------------------------------------------- /assets/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mayank31398/ladder-residual-inference/a19b1b8068cf7d6b37eeb5de1b2b108d6550e61a/assets/architecture.png -------------------------------------------------------------------------------- /benchmark.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | import argparse 5 | 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | import sys 9 | import time 10 | from pathlib import Path 11 | from typing import Optional, Tuple 12 | 13 | import torch 14 | import torch._dynamo.config 15 | import torch._inductor.config 16 | import torch.distributed as dist 17 | 18 | from gpt_fast import GPTDense, GPTDesync, GPTLadder, GPTParallel 19 | from gpt_fast.utils import _get_model_size, set_flash_attention 20 | 21 | 22 | def print_rank_0(*args, **kwargs): 23 | if dist.get_rank() == 0: 24 | print(*args, **kwargs) 25 | 26 | 27 | torch._inductor.config.coordinate_descent_tuning = True 28 | torch._inductor.config.triton.unique_kernel_names = True 29 | # experimental features to reduce compilation times, will be on by default in future 30 | torch._inductor.config.fx_graph_cache = True 31 | # torch._functorch.config.enable_autograd_cache = True 32 | torch._inductor.config.reorder_for_compute_comm_overlap = True # allows overlap 33 | 34 | default_device = "cuda" if torch.cuda.is_available() else "cpu" 35 | 36 | # support running without installing as a package 37 | wd = Path(__file__).parent.parent.resolve() 38 | sys.path.append(str(wd)) 39 | 40 | _MODELS = { 41 | "gpt_dense": GPTDense, 42 | "gpt_desync": GPTDesync, 43 | "gpt_parallel": GPTParallel, 44 | "gpt_ladder": GPTLadder, 45 | } 46 | 47 | 48 | def device_sync(device): 49 | device = str(device) 50 | if "cuda" in device: 51 | torch.cuda.synchronize(device) 52 | elif ("cpu" in device) or ("mps" in device): 53 | pass 54 | else: 55 | print_rank_0(f"device={device} is not yet suppported") 56 | 57 | 58 | def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization 59 | q = torch.empty_like(probs_sort).exponential_(1) 60 | return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) 61 | 62 | 63 | def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None): 64 | logits = logits / max(temperature, 1e-5) 65 | 66 | if top_k is not None: 67 | v, _ = torch.topk(logits, min(top_k, logits.size(-1))) 68 | pivot = v.select(-1, -1).unsqueeze(-1) 69 | logits = torch.where(logits < pivot, -float("Inf"), logits) 70 | probs = torch.nn.functional.softmax(logits, dim=-1) 71 | return probs 72 | 73 | 74 | def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None): 75 | probs = logits_to_probs(logits[:, -1], temperature, top_k) 76 | idx_next = multinomial_sample_one_no_sync(probs) 77 | return idx_next, probs 78 | 79 | 80 | @torch.no_grad() 81 | def prefill(model: torch.nn.Module, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor: 82 | logits = model(x, input_pos) 83 | return sample(logits, **sampling_kwargs)[0] 84 | 85 | 86 | def decode_one_token( 87 | model: torch.nn.Module, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs 88 | ) 
-> Tuple[torch.Tensor, torch.Tensor]: 89 | # input_pos: [B, 1] 90 | assert input_pos.shape[-1] == 1 91 | logits = model(x, input_pos) 92 | return sample(logits, **sampling_kwargs) 93 | 94 | 95 | @torch.no_grad() 96 | def decode_n_tokens( 97 | model: torch.nn.Module, 98 | cur_token: torch.Tensor, 99 | input_pos: torch.Tensor, 100 | num_new_tokens: int, 101 | callback=lambda _: _, 102 | **sampling_kwargs, 103 | ): 104 | new_tokens, new_probs = [], [] 105 | for i in range(num_new_tokens): 106 | # Actually better for Inductor to codegen attention here 107 | with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): 108 | next_token, next_prob = decode_one_token(model, cur_token, input_pos, **sampling_kwargs) 109 | input_pos += 1 110 | new_tokens.append(next_token.clone()) 111 | callback(new_tokens[-1]) 112 | new_probs.append(next_prob.clone()) 113 | cur_token = next_token.clone() 114 | 115 | return new_tokens, new_probs 116 | 117 | 118 | @torch.no_grad() 119 | def generate( 120 | model: torch.nn.Module, 121 | prompt: torch.Tensor, 122 | max_new_tokens: int, 123 | batch_size: int, 124 | empty: torch.Tensor, 125 | *, 126 | callback=lambda x: x, 127 | **sampling_kwargs, 128 | ) -> torch.Tensor: 129 | """ 130 | Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. 131 | """ 132 | 133 | T = prompt.size(-1) 134 | device = prompt.device 135 | # We are just making the same prompt for every batch 136 | prompt = prompt.view(1, -1).repeat(batch_size, 1) 137 | empty[:, :T] = prompt 138 | input_pos = torch.arange(0, T, device=device) 139 | 140 | device_sync(device) 141 | prefill_start = time.perf_counter() 142 | with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): 143 | next_token = prefill(model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs) 144 | device_sync(device) 145 | prefill_latency = time.perf_counter() - prefill_start 146 | print_rank_0(f"Prefill latency: {prefill_latency} sec") 147 | 148 | next_token = next_token.clone() 149 | empty[:, T] = next_token.squeeze() 150 | 151 | input_pos = torch.tensor([T], device=device, dtype=torch.int) 152 | 153 | device_sync(device) 154 | decode_start = time.perf_counter() 155 | generated_tokens, _ = decode_n_tokens( 156 | model, next_token.view(batch_size, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs 157 | ) 158 | device_sync(device) 159 | decode_latency = time.perf_counter() - decode_start 160 | print_rank_0(f"Decode latency: {decode_latency} sec") 161 | 162 | empty[:, T + 1 :] = torch.cat(generated_tokens, dim=-1) 163 | 164 | return empty, decode_latency, prefill_latency 165 | 166 | 167 | @torch.no_grad() 168 | def generate_using_cuda_graphs( 169 | prefill_graph, 170 | static_x: torch.Tensor, 171 | static_input_pos: torch.Tensor, 172 | static_next_token_prefill: torch.Tensor, 173 | decode_graph, 174 | static_cur_token: torch.Tensor, 175 | static_decode_input_pos: torch.Tensor, 176 | static_next_token_decode: torch.Tensor, 177 | prompt: torch.Tensor, 178 | batch_size: int, 179 | empty: torch.Tensor, 180 | num_new_tokens: int, 181 | ) -> torch.Tensor: 182 | """ 183 | Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. 
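    All tensors with a static_ prefix are the buffers that were captured into the CUDA graphs:
    the prompt and its positions are copied into static_x / static_input_pos before replaying
    prefill_graph, and each decode step copies the previous token into static_cur_token,
    replays decode_graph and reads the new token from static_next_token_decode.
    Returns (tokens, decode_latency, prefill_latency), mirroring generate().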
184 | """ 185 | 186 | T = prompt.size(-1) 187 | device = prompt.device 188 | 189 | # We are just making the same prompt for every batch 190 | prompt = prompt.view(1, -1).repeat(batch_size, 1) 191 | empty[:, :T] = prompt 192 | 193 | static_x.copy_(prompt) 194 | static_input_pos.copy_(torch.arange(0, T, device=device)) 195 | 196 | device_sync(device) 197 | prefill_start = time.perf_counter() 198 | prefill_graph.replay() 199 | torch.cuda.synchronize() 200 | prefill_latency = time.perf_counter() - prefill_start 201 | print_rank_0(f"Prefill latency: {prefill_latency} sec") 202 | 203 | empty[:, T] = static_next_token_prefill.squeeze() 204 | 205 | device_sync(device) 206 | decode_start = time.perf_counter() 207 | 208 | static_cur_token.copy_(static_next_token_prefill) 209 | static_decode_input_pos.copy_(torch.tensor([T], device=device, dtype=torch.int)) 210 | 211 | new_tokens, new_probs = [], [] 212 | for _ in range(num_new_tokens - 1): 213 | decode_graph.replay() 214 | static_decode_input_pos += 1 215 | 216 | new_tokens.append(static_next_token_decode.clone()) 217 | static_cur_token.copy_(static_next_token_decode.clone()) 218 | 219 | torch.cuda.synchronize() 220 | decode_latency = time.perf_counter() - decode_start 221 | print_rank_0(f"Decode latency: {decode_latency} sec") 222 | 223 | empty[:, T + 1 :] = torch.cat(new_tokens, dim=-1) 224 | 225 | return empty, decode_latency, prefill_latency 226 | 227 | 228 | def encode_tokens(tokenizer, string, bos=True, device=default_device): 229 | tokens = tokenizer.encode(string) 230 | if bos: 231 | tokens = [tokenizer.bos_id()] + tokens 232 | return torch.tensor(tokens, dtype=torch.int, device=device) 233 | 234 | 235 | def _load_model(model_name, device, precision): 236 | with torch.device("meta"): 237 | model = _MODELS[model_name.split(":")[0]].from_name(model_name.split(":")[1]) 238 | 239 | model = model.to(dtype=precision) 240 | model = model.to_empty(device=device) 241 | 242 | for p in model.parameters(): 243 | torch.nn.init.normal_(p, mean=0, std=0.02) 244 | 245 | print_rank_0(model) 246 | 247 | return model.eval() 248 | 249 | 250 | B_INST, E_INST = "[INST]", "[/INST]" 251 | 252 | 253 | @torch.no_grad() 254 | def get_cuda_graphs_for_prefill(model: torch.nn.Module, prompt: torch.Tensor, batch_size: int, **sampling_kwargs): 255 | T = prompt.size(-1) 256 | device = prompt.device 257 | 258 | # We are just making the same prompt for every batch 259 | static_x = prompt.view(1, -1).repeat(batch_size, 1) 260 | static_input_pos = torch.arange(0, T, device=device) 261 | 262 | s = torch.cuda.Stream() 263 | s.wait_stream(torch.cuda.current_stream()) 264 | 265 | with torch.cuda.stream(s): 266 | for _ in range(3): 267 | static_next_token = prefill(model, static_x.view(batch_size, -1), static_input_pos, **sampling_kwargs) 268 | 269 | torch.cuda.current_stream().wait_stream(s) 270 | 271 | g = torch.cuda.CUDAGraph() 272 | with torch.cuda.graph(g): 273 | static_next_token = prefill(model, static_x.view(batch_size, -1), static_input_pos, **sampling_kwargs) 274 | 275 | return g, static_x, static_input_pos, static_next_token 276 | 277 | 278 | @torch.no_grad() 279 | def get_cuda_graphs_for_decode( 280 | model: torch.nn.Module, 281 | prompt: torch.Tensor, 282 | batch_size: int, 283 | max_new_tokens: int, 284 | cur_token: torch.Tensor, 285 | **sampling_kwargs, 286 | ): 287 | T = prompt.size(-1) 288 | device = prompt.device 289 | 290 | static_cur_token = cur_token.clone() 291 | static_input_pos = torch.tensor([T], device=device, dtype=torch.int) 292 | 293 | # Warm up 294 | 
for _ in range(3): 295 | static_input_pos.copy_(torch.tensor([T], device=device, dtype=torch.int)) 296 | decode_one_token(model, static_cur_token, static_input_pos, **sampling_kwargs) 297 | 298 | static_input_pos.copy_(torch.tensor([T], device=device, dtype=torch.int)) 299 | 300 | # Capture CUDA graph 301 | g_decode = torch.cuda.CUDAGraph() 302 | with torch.cuda.graph(g_decode): 303 | static_next_token, _ = decode_one_token(model, static_cur_token, static_input_pos, **sampling_kwargs) 304 | 305 | return g_decode, static_cur_token, static_input_pos, static_next_token 306 | 307 | 308 | def main( 309 | model_name: str, 310 | prompt_length: int = 1, 311 | num_samples: int = 5, 312 | max_new_tokens: int = 100, 313 | batch_size: int = 1, 314 | top_k: int = 200, 315 | temperature: float = 0.8, 316 | compile: bool = True, 317 | compile_prefill: bool = False, 318 | profile: Optional[Path] = None, 319 | device=default_device, 320 | use_cuda_graphs: bool = False, 321 | ) -> None: 322 | """Generates text samples based on a pre-trained Transformer model and tokenizer.""" 323 | 324 | from gpt_fast import maybe_init_dist 325 | 326 | rank = maybe_init_dist() 327 | use_tp = rank is not None 328 | 329 | print_rank_0(f"our world size={dist.get_world_size()}") 330 | print_rank_0(f"Using device={device}") 331 | 332 | precision = torch.float16 333 | 334 | print_rank_0("Loading model ...") 335 | t0 = time.time() 336 | model = _load_model(model_name, device, precision) 337 | device_sync(device=device) # MKG 338 | 339 | print_rank_0(f"Time to load model: {time.time() - t0:.02f} seconds") 340 | # generate a fully synthetic prompt 341 | encoded = torch.randint(0, 1024, (prompt_length,), device=device, dtype=torch.int64) 342 | 343 | torch.manual_seed(1234) 344 | model_size, params = _get_model_size(model) 345 | 346 | T_new = encoded.size(-1) + max_new_tokens # include encode sequence length 347 | max_seq_length = min(T_new, model.config.block_size) 348 | 349 | with torch.device(device): 350 | model.setup_caches(max_batch_size=batch_size, max_seq_length=max_seq_length) 351 | 352 | if compile: 353 | global decode_one_token, decode_multi_token, prefill 354 | decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True) 355 | 356 | if compile_prefill: 357 | dynamic = False 358 | print_rank_0(f"Compiling prefill with dynamic={dynamic}") 359 | prefill = torch.compile(prefill, fullgraph=True, dynamic=dynamic) 360 | 361 | elif use_cuda_graphs: 362 | print_rank_0("CUDA_GRAPH are activate") 363 | prefill_graph, static_x, static_input_pos, static_next_token_prefill = get_cuda_graphs_for_prefill( 364 | model, 365 | prompt=encoded, 366 | batch_size=batch_size, 367 | temperature=temperature, 368 | top_k=top_k, 369 | ) 370 | 371 | decode_graph, static_cur_token, static_decode_input_pos, static_next_token_decode = get_cuda_graphs_for_decode( 372 | model, 373 | prompt=encoded, 374 | batch_size=batch_size, 375 | max_new_tokens=max_new_tokens, 376 | cur_token=static_next_token_prefill, 377 | temperature=temperature, 378 | top_k=top_k, 379 | ) 380 | 381 | aggregate_metrics = { 382 | "tokens_per_sec_per_user": [], 383 | "total_tokens_per_sec": [], 384 | "decode_latency": [], 385 | "prefill_latency": [], 386 | } 387 | start = 0 if profile else -5 388 | 389 | for i in range(start, num_samples): 390 | device_sync(device=device) # MKG 391 | 392 | callback = lambda x: x 393 | t0 = time.perf_counter() 394 | 395 | import contextlib 396 | 397 | if not profile or (use_tp and rank != 0) or i != num_samples - 1: 398 | 
prof = contextlib.nullcontext() 399 | else: 400 | torch.profiler._utils._init_for_cuda_graphs() 401 | prof = torch.profiler.profile( 402 | activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], 403 | on_trace_ready=torch.profiler.tensorboard_trace_handler(profile), 404 | record_shapes=True, 405 | ) 406 | 407 | empty = torch.empty(batch_size, T_new, dtype=encoded.dtype, device=device) 408 | 409 | with prof: 410 | if use_cuda_graphs: 411 | # NOTE we need to reset the static variable pointers for CUDA graph on each geenration here 412 | # however, for benchmarking throughput, it doesn't matter 413 | y, decode_latency, prefill_latency = generate_using_cuda_graphs( 414 | prefill_graph, 415 | static_x, 416 | static_input_pos, 417 | static_next_token_prefill, 418 | decode_graph, 419 | static_cur_token, 420 | static_decode_input_pos, 421 | static_next_token_decode, 422 | encoded, 423 | batch_size=batch_size, 424 | empty=empty, 425 | num_new_tokens=max_new_tokens, 426 | ) 427 | else: 428 | y, decode_latency, prefill_latency = generate( 429 | model, 430 | encoded, 431 | max_new_tokens, 432 | batch_size=batch_size, 433 | empty=empty, 434 | callback=callback, 435 | temperature=temperature, 436 | top_k=top_k, 437 | ) 438 | 439 | if i == -5: 440 | print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") 441 | 442 | device_sync(device=device) # MKG 443 | 444 | if i < 0: 445 | continue 446 | 447 | t = time.perf_counter() - t0 448 | 449 | num_users = y.size(0) 450 | 451 | tokens_generated_per_user = y.size(-1) - prompt_length 452 | tokens_generated_per_user_per_sec = tokens_generated_per_user / t 453 | 454 | total_tokens_generated_per_sec = tokens_generated_per_user_per_sec * num_users 455 | 456 | aggregate_metrics["tokens_per_sec_per_user"].append(tokens_generated_per_user_per_sec) 457 | aggregate_metrics["total_tokens_per_sec"].append(total_tokens_generated_per_sec) 458 | aggregate_metrics["decode_latency"].append(decode_latency) 459 | aggregate_metrics["prefill_latency"].append(prefill_latency) 460 | 461 | print_rank_0(f"Time for inference {i + 1}: {t:.02f} sec total") 462 | print_rank_0(f"Tokens per second per user: {tokens_generated_per_user_per_sec:.02f} tokens/sec/user") 463 | print_rank_0(f"Total tokens per second: {total_tokens_generated_per_sec:.02f} tokens/sec") 464 | print_rank_0(f"Decode latency: {decode_latency:.02f} sec") 465 | print_rank_0(f"Prefill latency: {prefill_latency:.02f} sec") 466 | print_rank_0(f"Bandwidth achieved: {model_size * tokens_generated_per_user_per_sec / 1e9:.02f} GB/s") 467 | total_tokens_sec = y.numel() / t 468 | print_rank_0(f"FLOPS achieved: {params * total_tokens_sec * 2 / 1e12:.02f} TF/s") 469 | print_rank_0() 470 | 471 | print_rank_0("==========") 472 | 473 | print_rank_0(f"Batch Size: {batch_size}") 474 | print_rank_0(f"Prompt Length: {prompt_length}") 475 | print_rank_0(f"Generated tokens: {max_new_tokens}") 476 | print_rank_0( 477 | f"Average decode latency: {torch.mean(torch.tensor(aggregate_metrics['decode_latency'])).item():.04f} sec" 478 | ) 479 | print_rank_0( 480 | f"Average prefill latency: {torch.mean(torch.tensor(aggregate_metrics['prefill_latency'])).item():.04f} sec" 481 | ) 482 | print_rank_0( 483 | f"Average tokens/sec/user: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec_per_user'])).item():.2f}" 484 | ) 485 | print_rank_0( 486 | f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['total_tokens_per_sec'])).item():.2f}" 487 | ) 488 | print_rank_0(f"Memory used: 
{torch.cuda.max_memory_reserved() / 1e9:.02f} GB") 489 | 490 | dist.barrier() 491 | print_rank_0("Done. we are killing the process") 492 | exit() 493 | 494 | 495 | if __name__ == "__main__": 496 | parser = argparse.ArgumentParser(description="Your CLI description.") 497 | 498 | parser.add_argument("--model_name", type=str, required=True, help="model name") 499 | parser.add_argument("--prompt_length", type=int, required=True, help="Input prompt length") 500 | parser.add_argument("--num_samples", type=int, default=5, help="Number of samples.") 501 | parser.add_argument("--max_new_tokens", type=int, default=200, help="Maximum number of new tokens.") 502 | parser.add_argument("--batch_size", type=int, default=1, help="Batch size to benchmark with") 503 | parser.add_argument("--top_k", type=int, default=200, help="Top-k for sampling.") 504 | parser.add_argument("--temperature", type=float, default=0.8, help="Temperature for sampling.") 505 | parser.add_argument("--compile", action="store_true", help="Whether to compile the model.") 506 | parser.add_argument( 507 | "--compile_prefill", 508 | action="store_true", 509 | help="Whether to compile the prefill (improves prefill perf, but higher compile times)", 510 | ) 511 | parser.add_argument("--cuda_graph", action="store_true", help="Whether to use cuda graphs the model.") 512 | parser.add_argument( 513 | "--use_flash_attention", 514 | action="store_true", 515 | help="Whether to flash decode with kv cache in attn (not compile generated one)", 516 | ) 517 | parser.add_argument("--profile", type=Path, default=None, help="Profile path.") 518 | parser.add_argument("--device", type=str, default=default_device, help="Device to use") 519 | 520 | args = parser.parse_args() 521 | 522 | if args.cuda_graph: 523 | assert not args.compile 524 | 525 | print_rank_0(f"flash_kv_decode is set to {args.use_flash_attention}") 526 | with set_flash_attention(args.use_flash_attention): 527 | main( 528 | args.model_name, 529 | args.prompt_length, 530 | args.num_samples, 531 | args.max_new_tokens, 532 | args.batch_size, 533 | args.top_k, 534 | args.temperature, 535 | args.compile, 536 | args.compile_prefill, 537 | args.profile, 538 | args.device, 539 | args.cuda_graph, 540 | ) 541 | -------------------------------------------------------------------------------- /gpt_fast/__init__.py: -------------------------------------------------------------------------------- 1 | from .gpt_dense_TP import GPTDense 2 | from .gpt_desync_TP import GPTDesync 3 | from .gpt_ladder_TP import GPTLadder 4 | from .gpt_parallel_TP import GPTParallel 5 | from .tp import maybe_init_dist 6 | -------------------------------------------------------------------------------- /gpt_fast/gpt_dense_TP.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
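# Note: GPTDense is the dense-Transformer baseline used for comparison in this repository. Under
# tensor parallelism, every DenseTransformerBlock all-reduces the attention output and then the MLP
# output and consumes each result immediately, so the two per-layer reductions sit on the critical
# path. The ladder/desync variants in this package restructure exactly these two synchronization
# points.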
6 | from dataclasses import dataclass 7 | from typing import Optional 8 | 9 | import torch 10 | import torch.distributed as dist 11 | import torch.nn as nn 12 | from torch import Tensor 13 | 14 | from .tp import maybe_init_dist 15 | from .utils import Attention, FeedForward, KVCache, RMSNorm, all_reduce_func, precompute_freqs_cis 16 | 17 | 18 | def find_multiple(n: int, k: int) -> int: 19 | if n % k == 0: 20 | return n 21 | return n + k - (n % k) 22 | 23 | 24 | maybe_init_dist() 25 | tp_rank = dist.get_rank() 26 | tp_world_size = dist.get_world_size() 27 | tp_group = list(range(tp_world_size)) 28 | 29 | 30 | @dataclass 31 | class ModelArgs: 32 | block_size: int = 2048 33 | vocab_size: int = 32000 34 | n_layer: int = 32 35 | n_head: int = 32 36 | dim: int = 4096 37 | intermediate_size: int = None 38 | n_local_heads: int = -1 39 | head_dim: int = 64 40 | rope_base: float = 10000 41 | norm_eps: float = 1e-5 42 | rope_scaling: Optional[dict] = None 43 | semi_compiled_model: bool = False 44 | 45 | def __post_init__(self): 46 | if self.n_local_heads == -1: 47 | self.n_local_heads = self.n_head 48 | if self.intermediate_size is None: 49 | hidden_dim = 4 * self.dim 50 | n_hidden = int(2 * hidden_dim / 3) 51 | self.intermediate_size = find_multiple(n_hidden, 256) 52 | self.head_dim = self.dim // self.n_head 53 | 54 | assert self.dim % tp_world_size == 0 55 | assert self.intermediate_size % tp_world_size == 0 56 | 57 | @classmethod 58 | def from_name(cls, name: str): 59 | if name in transformer_configs: 60 | return cls(**transformer_configs[name]) 61 | 62 | 63 | transformer_configs = { 64 | "CodeLlama-7b-Python-hf": dict(block_size=16384, vocab_size=32000, n_layer=32, dim=4096, rope_base=1000000), 65 | "1b": dict( 66 | block_size=2048, 67 | n_layer=40, 68 | n_head=24, 69 | n_local_heads=24, 70 | dim=1536, 71 | intermediate_size=4096, 72 | vocab_size=49152, 73 | rope_base=10000, 74 | ), 75 | "3b": dict( 76 | block_size=2048, 77 | n_layer=40, 78 | n_head=32, 79 | n_local_heads=32, 80 | dim=2304, 81 | intermediate_size=9216, 82 | vocab_size=49152, 83 | rope_base=10000, 84 | ), 85 | "3.9bh": dict( 86 | block_size=2048, 87 | n_layer=52, 88 | n_head=32, 89 | n_local_heads=32, 90 | dim=2304, 91 | intermediate_size=9216, 92 | vocab_size=49152, 93 | rope_base=10000, 94 | ), 95 | "3.9bw1": dict( 96 | block_size=2048, 97 | n_layer=40, 98 | n_head=32, 99 | n_local_heads=32, 100 | dim=3200, 101 | intermediate_size=9216, 102 | vocab_size=49152, 103 | rope_base=10000, 104 | ), 105 | "3.9bw2": dict( 106 | block_size=2048, 107 | n_layer=40, 108 | n_head=32, 109 | n_local_heads=32, 110 | dim=2560, 111 | intermediate_size=11000, 112 | vocab_size=49152, 113 | rope_base=10000, 114 | ), 115 | "7B": dict(n_layer=32, n_head=32, dim=4096), 116 | "13B": dict(n_layer=40, n_head=40, dim=5120), 117 | "30B": dict(n_layer=60, n_head=52, dim=6656), 118 | "34B": dict( 119 | n_layer=48, n_head=64, dim=8192, vocab_size=32000, n_local_heads=8, intermediate_size=22016, rope_base=1000000 120 | ), # CodeLlama-34B-Python-hf 121 | "70B": dict(n_layer=80, n_head=64, dim=8192, n_local_heads=8, intermediate_size=28672), 122 | "70B-semi-compiled": dict( 123 | n_layer=80, n_head=64, dim=8192, n_local_heads=8, intermediate_size=28672, semi_compiled_model=True 124 | ), 125 | "Mistral-7B": dict(n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=32000), 126 | "stories15M": dict(n_layer=6, n_head=6, dim=288), 127 | "stories110M": dict(n_layer=12, n_head=12, dim=768), 128 | "llama-3.2-1b": dict(n_layer=16, 
n_head=32, dim=2048, intermediate_size=8192, vocab_size=128256), 129 | "llama-3.2-3b": dict(n_layer=28, n_head=24, dim=3072, intermediate_size=8192, vocab_size=128256), 130 | "llama-3-8b-4layers": dict( 131 | block_size=8192, 132 | n_layer=4, 133 | n_head=32, 134 | n_local_heads=8, 135 | dim=4096, 136 | intermediate_size=14336, 137 | vocab_size=128256, 138 | rope_base=500000, 139 | ), 140 | "llama-3-8b": dict( 141 | block_size=8192, 142 | n_layer=32, 143 | n_head=32, 144 | n_local_heads=8, 145 | dim=4096, 146 | intermediate_size=14336, 147 | vocab_size=128256, 148 | rope_base=500000, 149 | ), 150 | "llama-3-70b": dict( 151 | block_size=8192, 152 | n_layer=80, 153 | n_head=64, 154 | n_local_heads=8, 155 | dim=8192, 156 | intermediate_size=28672, 157 | vocab_size=128256, 158 | rope_base=500000, 159 | ), 160 | "llama-3.1-405b": dict( 161 | block_size=131072, 162 | n_layer=126, 163 | n_head=128, 164 | n_local_heads=16, 165 | dim=16384, 166 | intermediate_size=53248, 167 | vocab_size=128256, 168 | rope_base=500000, 169 | rope_scaling=dict( 170 | factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_position_embeddings=8192 171 | ), 172 | ), 173 | "bloom-176b": dict( 174 | block_size=8192, 175 | n_layer=70, 176 | n_head=112, 177 | dim=14336, 178 | itermediate_size=50176, 179 | vocab_size=250880, 180 | rope_base=500000, 181 | ), 182 | } 183 | 184 | 185 | class GPTDense(nn.Module): 186 | def __init__(self, config: ModelArgs) -> None: 187 | super().__init__() 188 | self.config = config 189 | 190 | self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim) 191 | self.layers = nn.ModuleList(DenseTransformerBlock(config) for _ in range(config.n_layer)) 192 | self.norm = RMSNorm(config.dim, eps=config.norm_eps) 193 | self.output = nn.Linear(config.dim, config.vocab_size, bias=False) 194 | 195 | self.freqs_cis: Optional[Tensor] = None 196 | self.mask_cache: Optional[Tensor] = None 197 | self.max_batch_size = -1 198 | self.max_seq_length = -1 199 | 200 | def setup_caches(self, max_batch_size, max_seq_length): 201 | if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size: 202 | return 203 | head_dim = self.config.dim // self.config.n_head 204 | max_seq_length = find_multiple(max_seq_length, 8) 205 | self.max_seq_length = max_seq_length 206 | self.max_batch_size = max_batch_size 207 | dtype = self.output.weight.dtype 208 | # For quantized layers, dtype is encoded in scales 209 | if hasattr(self.output, "scales"): 210 | dtype = self.output.scales.dtype 211 | elif hasattr(self.output, "scales_and_zeros"): 212 | dtype = self.output.scales_and_zeros.dtype 213 | for b in self.layers: 214 | b.attention.kv_cache = KVCache( 215 | max_batch_size, max_seq_length, self.config.n_local_heads // tp_world_size, head_dim, dtype 216 | ) 217 | 218 | self.freqs_cis = precompute_freqs_cis( 219 | self.config.block_size, 220 | self.config.dim // self.config.n_head, 221 | self.config.rope_base, 222 | dtype, 223 | self.config.rope_scaling, 224 | ) 225 | self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)) 226 | 227 | def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: 228 | assert self.freqs_cis is not None, "Caches must be initialized first" 229 | mask = self.causal_mask[None, None, input_pos] 230 | freqs_cis = self.freqs_cis[input_pos] 231 | x = self.tok_embeddings(idx) 232 | 233 | for i, layer in enumerate(self.layers): 234 | x = layer(x, input_pos, freqs_cis, mask) 235 | x = self.norm(x) 236 | logits = 
self.output(x) 237 | return logits 238 | 239 | @classmethod 240 | def from_name(cls, name: str): 241 | return cls(ModelArgs.from_name(name)) 242 | 243 | 244 | class DenseTransformerBlock(nn.Module): 245 | def __init__(self, config: ModelArgs) -> None: 246 | super().__init__() 247 | self.attention = Attention(config) 248 | self.feed_forward = FeedForward(config) 249 | self.ffn_norm = RMSNorm(config.dim, config.norm_eps) 250 | self.attention_norm = RMSNorm(config.dim, config.norm_eps) 251 | 252 | def _attn(x, residual, freqs_cis, mask, input_pos): 253 | x = self.attention_norm(x) 254 | x = self.attention(x, freqs_cis, mask, input_pos) 255 | 256 | if tp_rank == 0: 257 | x = x + residual 258 | 259 | return x 260 | 261 | def _ffn(x, residual): 262 | x = self.ffn_norm(x) 263 | x = self.feed_forward(x) 264 | 265 | if tp_rank == 0: 266 | x = x + residual 267 | 268 | return x 269 | 270 | self._attn = torch.compile(_attn) 271 | self._ffn = torch.compile(_ffn) 272 | self.semi_compiled_model = config.semi_compiled_model 273 | 274 | def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor: 275 | if self.semi_compiled_model: 276 | x = self._attn(x, x, freqs_cis, mask, input_pos) 277 | x = all_reduce_func(x, clone=True)[0] 278 | x = self._ffn(x, x) 279 | x = all_reduce_func(x, clone=True)[0] 280 | else: 281 | x = x + all_reduce_func(self.attention(self.attention_norm(x), freqs_cis, mask, input_pos), clone=False)[0] 282 | x = x + all_reduce_func(self.feed_forward(self.ffn_norm(x)), clone=False)[0] 283 | 284 | return x 285 | 286 | def extra_repr(self) -> str: 287 | return f"semi_compiled = {self.semi_compiled_model}" 288 | -------------------------------------------------------------------------------- /gpt_fast/gpt_desync_TP.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
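# Note: GPTDesync skips selected all-reduces according to ModelArgs.reduce_pattern, letting the
# tensor-parallel ranks run partially desynchronized between reductions. When a sub-layer does
# perform an all-reduce, the residual is pre-divided by the TP world size so that the SUM
# all-reduce does not scale it by the number of ranks; when the all-reduce is skipped, the
# residual is added in full and rank-local activations stay unsynchronized until the next
# reduction (if any).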
6 | from dataclasses import dataclass 7 | from typing import Optional 8 | 9 | import torch 10 | import torch.distributed as dist 11 | import torch.nn as nn 12 | from torch import Tensor 13 | 14 | from .tp import maybe_init_dist 15 | from .utils import Attention, FeedForward, KVCache, RMSNorm, all_reduce_func, precompute_freqs_cis 16 | 17 | 18 | def find_multiple(n: int, k: int) -> int: 19 | if n % k == 0: 20 | return n 21 | return n + k - (n % k) 22 | 23 | 24 | maybe_init_dist() 25 | tp_rank = dist.get_rank() 26 | tp_world_size = dist.get_world_size() 27 | tp_group = list(range(tp_world_size)) 28 | 29 | 30 | @dataclass 31 | class ModelArgs: 32 | block_size: int = 2048 33 | vocab_size: int = 32000 34 | n_layer: int = 32 35 | n_head: int = 32 36 | dim: int = 4096 37 | intermediate_size: int = None 38 | n_local_heads: int = -1 39 | head_dim: int = 64 40 | rope_base: float = 10000 41 | norm_eps: float = 1e-5 42 | rope_scaling: Optional[dict] = None 43 | semi_compiled_model: bool = False 44 | reduce_pattern: Optional[dict] = None 45 | force_disable_last_all_reduce: bool = False 46 | 47 | def __post_init__(self): 48 | if self.n_local_heads == -1: 49 | self.n_local_heads = self.n_head 50 | if self.intermediate_size is None: 51 | hidden_dim = 4 * self.dim 52 | n_hidden = int(2 * hidden_dim / 3) 53 | self.intermediate_size = find_multiple(n_hidden, 256) 54 | self.head_dim = self.dim // self.n_head 55 | 56 | assert self.dim % tp_world_size == 0 57 | assert self.intermediate_size % tp_world_size == 0 58 | 59 | if self.reduce_pattern is None: 60 | self.reduce_pattern = [{"attention": False, "mlp": True} for i in range(self.n_layer)] 61 | 62 | @classmethod 63 | def from_name(cls, name: str): 64 | if name in transformer_configs: 65 | return cls(**transformer_configs[name]) 66 | 67 | 68 | transformer_configs = { 69 | "CodeLlama-7b-Python-hf": dict(block_size=16384, vocab_size=32000, n_layer=32, dim=4096, rope_base=1000000), 70 | "1b": dict( 71 | block_size=2048, 72 | n_layer=40, 73 | n_head=24, 74 | n_local_heads=24, 75 | dim=1536, 76 | intermediate_size=4096, 77 | vocab_size=49152, 78 | rope_base=10000, 79 | reduce_pattern=[{"attention": False, "mlp": False} for _ in range(40)], 80 | force_disable_last_all_reduce=True, 81 | ), 82 | "1b-upper-bound": dict( 83 | block_size=2048, 84 | n_layer=40, 85 | n_head=24, 86 | n_local_heads=24, 87 | dim=1536, 88 | intermediate_size=4096, 89 | vocab_size=49152, 90 | rope_base=10000, 91 | ), 92 | "3b": dict( 93 | block_size=2048, 94 | n_layer=40, 95 | n_head=36, 96 | n_local_heads=36, 97 | dim=2304, 98 | intermediate_size=9216, 99 | vocab_size=49152, 100 | rope_base=10000, 101 | ), 102 | "3b-upper-bound": dict( 103 | block_size=2048, 104 | n_layer=40, 105 | n_head=36, 106 | n_local_heads=36, 107 | dim=2304, 108 | intermediate_size=9216, 109 | vocab_size=49152, 110 | rope_base=10000, 111 | reduce_pattern=[{"attention": False, "mlp": False} for _ in range(40)], 112 | force_disable_last_all_reduce=True, 113 | ), 114 | "7B": dict(n_layer=32, n_head=32, dim=4096), 115 | "13B": dict(n_layer=40, n_head=40, dim=5120), 116 | "30B": dict(n_layer=60, n_head=52, dim=6656), 117 | "34B": dict( 118 | n_layer=48, n_head=64, dim=8192, vocab_size=32000, n_local_heads=8, intermediate_size=22016, rope_base=1000000 119 | ), # CodeLlama-34B-Python-hf 120 | "70B": dict(n_layer=80, n_head=64, dim=8192, n_local_heads=8, intermediate_size=28672), 121 | "70B-infinite": dict( 122 | n_layer=80, 123 | n_head=64, 124 | dim=8192, 125 | n_local_heads=8, 126 | intermediate_size=28672, 127 | 
reduce_pattern=[{"attention": False, "mlp": False} for _ in range(80)], 128 | ), 129 | "70B-semi-compiled": dict( 130 | n_layer=80, n_head=64, dim=8192, n_local_heads=8, intermediate_size=28672, semi_compiled_model=True 131 | ), 132 | "70B-infinite-semi-compiled": dict( 133 | n_layer=80, 134 | n_head=64, 135 | dim=8192, 136 | n_local_heads=8, 137 | intermediate_size=28672, 138 | semi_compiled_model=True, 139 | reduce_pattern=[{"attention": False, "mlp": False} for _ in range(80)], 140 | ), 141 | "Mistral-7B": dict(n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=32000), 142 | "stories15M": dict(n_layer=6, n_head=6, dim=288), 143 | "stories110M": dict(n_layer=12, n_head=12, dim=768), 144 | "llama-3-8b": dict( 145 | block_size=8192, 146 | n_layer=32, 147 | n_head=32, 148 | n_local_heads=8, 149 | dim=4096, 150 | intermediate_size=14336, 151 | vocab_size=128256, 152 | rope_base=500000, 153 | ), 154 | "llama-3-8b-4x": dict( 155 | block_size=8192, 156 | n_layer=32, 157 | n_head=32, 158 | n_local_heads=8, 159 | dim=4096, 160 | intermediate_size=14336, 161 | vocab_size=128256, 162 | rope_base=500000, 163 | reduce_pattern=[{"attention": False, "mlp": False}, {"attention": False, "mlp": True}] * 16, 164 | ), 165 | "llama-3-8b-upper-bound": dict( 166 | block_size=8192, 167 | n_layer=32, 168 | n_head=32, 169 | n_local_heads=8, 170 | dim=4096, 171 | intermediate_size=14336, 172 | vocab_size=128256, 173 | rope_base=500000, 174 | reduce_pattern=[{"attention": False, "mlp": False} for _ in range(32)], 175 | force_disable_last_all_reduce=True, 176 | ), 177 | "llama-3-8b-infinite": dict( 178 | block_size=8192, 179 | n_layer=32, 180 | n_head=32, 181 | n_local_heads=8, 182 | dim=4096, 183 | intermediate_size=14336, 184 | vocab_size=128256, 185 | rope_base=500000, 186 | reduce_pattern=[{"attention": False, "mlp": False} for _ in range(32)], 187 | ), 188 | "llama-3-70b-4x": dict( 189 | block_size=8192, 190 | n_layer=80, 191 | n_head=64, 192 | n_local_heads=8, 193 | dim=8192, 194 | intermediate_size=28672, 195 | vocab_size=128256, 196 | rope_base=500000, 197 | reduce_pattern=[{"attention": False, "mlp": False}, {"attention": False, "mlp": True}] * 40, 198 | ), 199 | "llama-3-70b": dict( 200 | block_size=8192, 201 | n_layer=80, 202 | n_head=64, 203 | n_local_heads=8, 204 | dim=8192, 205 | intermediate_size=28672, 206 | vocab_size=128256, 207 | rope_base=500000, 208 | ), 209 | "llama-3-70b-upper-bound": dict( 210 | block_size=8192, 211 | n_layer=80, 212 | n_head=64, 213 | n_local_heads=8, 214 | dim=8192, 215 | intermediate_size=28672, 216 | vocab_size=128256, 217 | rope_base=500000, 218 | reduce_pattern=[{"attention": False, "mlp": False} for _ in range(80)], 219 | force_disable_last_all_reduce=True, 220 | ), 221 | "llama-3.1-405b": dict( 222 | block_size=131072, 223 | n_layer=126, 224 | n_head=128, 225 | n_local_heads=16, 226 | dim=16384, 227 | intermediate_size=53248, 228 | vocab_size=128256, 229 | rope_base=500000, 230 | rope_scaling=dict( 231 | factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_position_embeddings=8192 232 | ), 233 | ), 234 | "llama-3.1-405b-upper-bound": dict( 235 | block_size=131072, 236 | n_layer=126, 237 | n_head=128, 238 | n_local_heads=16, 239 | dim=16384, 240 | intermediate_size=53248, 241 | vocab_size=128256, 242 | rope_base=500000, 243 | reduce_pattern=[{"attention": False, "mlp": False} for _ in range(126)], 244 | force_disable_last_all_reduce=True, 245 | rope_scaling=dict( 246 | factor=8.0, low_freq_factor=1.0, 
high_freq_factor=4.0, original_max_position_embeddings=8192 247 | ), 248 | ), 249 | } 250 | 251 | 252 | class GPTDesync(nn.Module): 253 | def __init__(self, config: ModelArgs) -> None: 254 | super().__init__() 255 | self.config = config 256 | 257 | self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim) 258 | self.layers = nn.ModuleList(DesyncTransformerBlock(config, layer_idx=i) for i in range(config.n_layer)) 259 | self.norm = RMSNorm(config.dim, eps=config.norm_eps) 260 | self.output = nn.Linear(config.dim, config.vocab_size, bias=False) 261 | 262 | self.freqs_cis: Optional[Tensor] = None 263 | self.mask_cache: Optional[Tensor] = None 264 | self.max_batch_size = -1 265 | self.max_seq_length = -1 266 | 267 | def setup_caches(self, max_batch_size, max_seq_length): 268 | if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size: 269 | return 270 | head_dim = self.config.dim // self.config.n_head 271 | max_seq_length = find_multiple(max_seq_length, 8) 272 | self.max_seq_length = max_seq_length 273 | self.max_batch_size = max_batch_size 274 | dtype = self.output.weight.dtype 275 | # For quantized layers, dtype is encoded in scales 276 | if hasattr(self.output, "scales"): 277 | dtype = self.output.scales.dtype 278 | elif hasattr(self.output, "scales_and_zeros"): 279 | dtype = self.output.scales_and_zeros.dtype 280 | for b in self.layers: 281 | b.attention.kv_cache = KVCache( 282 | max_batch_size, max_seq_length, self.config.n_local_heads // tp_world_size, head_dim, dtype 283 | ) 284 | 285 | self.freqs_cis = precompute_freqs_cis( 286 | self.config.block_size, 287 | self.config.dim // self.config.n_head, 288 | self.config.rope_base, 289 | dtype, 290 | self.config.rope_scaling, 291 | ) 292 | self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)) 293 | 294 | def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: 295 | assert self.freqs_cis is not None, "Caches must be initialized first" 296 | mask = self.causal_mask[None, None, input_pos] 297 | freqs_cis = self.freqs_cis[input_pos] 298 | x = self.tok_embeddings(idx) 299 | 300 | for i, layer in enumerate(self.layers): 301 | x = layer(x, input_pos, freqs_cis, mask) 302 | x = self.norm(x) 303 | logits = self.output(x) 304 | return logits 305 | 306 | @classmethod 307 | def from_name(cls, name: str): 308 | return cls(ModelArgs.from_name(name)) 309 | 310 | 311 | class DesyncTransformerBlock(nn.Module): 312 | def __init__(self, config: ModelArgs, layer_idx: int) -> None: 313 | super().__init__() 314 | self.attention = Attention(config) 315 | self.feed_forward = FeedForward(config) 316 | self.ffn_norm = RMSNorm(config.dim, config.norm_eps) 317 | self.attention_norm = RMSNorm(config.dim, config.norm_eps) 318 | 319 | self.do_attention_all_reduce = config.reduce_pattern[layer_idx]["attention"] 320 | self.do_mlp_all_reduce = layer_idx == config.n_layer - 1 or config.reduce_pattern[layer_idx]["mlp"] 321 | 322 | if layer_idx == config.n_layer - 1 and config.force_disable_last_all_reduce: 323 | self.do_mlp_all_reduce = False 324 | 325 | def _attn(x, freqs_cis, mask, input_pos): 326 | y = self.attention(self.attention_norm(x), freqs_cis, mask, input_pos) 327 | 328 | if self.do_attention_all_reduce: 329 | y = y + x / tp_world_size 330 | else: 331 | y = y + x 332 | 333 | return y 334 | 335 | def _ffn(x): 336 | y = self.feed_forward(self.ffn_norm(x)) 337 | 338 | if self.do_mlp_all_reduce: 339 | y = y + x / tp_world_size 340 | else: 341 | y = y + x 342 | 343 | 
return y 344 | 345 | self._attn = torch.compile(_attn) 346 | self._ffn = torch.compile(_ffn) 347 | 348 | self.semi_compiled_model = config.semi_compiled_model 349 | 350 | def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor: 351 | if self.semi_compiled_model: 352 | x = self._attn(x, freqs_cis, mask, input_pos) 353 | 354 | if self.do_attention_all_reduce: 355 | x = all_reduce_func(x, clone=True)[0] 356 | else: 357 | x = x.clone() 358 | 359 | x = self._ffn(x) 360 | 361 | if self.do_mlp_all_reduce: 362 | x = all_reduce_func(x, clone=True)[0] 363 | else: 364 | x = x.clone() 365 | else: 366 | residual = x 367 | x = self.attention(self.attention_norm(x), freqs_cis, mask, input_pos) 368 | 369 | if self.do_attention_all_reduce: 370 | x = x + residual / tp_world_size 371 | x = all_reduce_func(x, clone=False)[0] 372 | else: 373 | x = x + residual 374 | 375 | residual = x 376 | x = self.feed_forward(self.ffn_norm(x)) 377 | 378 | if self.do_mlp_all_reduce: 379 | x = x + residual / tp_world_size 380 | x = all_reduce_func(x, clone=False)[0] 381 | else: 382 | x = x + residual 383 | 384 | return x 385 | 386 | def extra_repr(self) -> str: 387 | return f"semi_compiled = {self.semi_compiled_model}" 388 | -------------------------------------------------------------------------------- /gpt_fast/gpt_ladder_TP.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | from dataclasses import dataclass 7 | from typing import Optional 8 | 9 | import torch 10 | import torch.distributed as dist 11 | import torch.nn as nn 12 | import triton 13 | import triton.language as tl 14 | from torch import Tensor 15 | 16 | from .tp import maybe_init_dist 17 | from .utils import Attention, FeedForward, KVCache, RMSNorm, all_reduce_func, precompute_freqs_cis 18 | 19 | 20 | def find_multiple(n: int, k: int) -> int: 21 | if n % k == 0: 22 | return n 23 | return n + k - (n % k) 24 | 25 | 26 | maybe_init_dist() 27 | tp_rank = dist.get_rank() 28 | tp_world_size = dist.get_world_size() 29 | tp_group = list(range(tp_world_size)) 30 | 31 | 32 | @dataclass 33 | class ModelArgs: 34 | block_size: int = 2048 35 | vocab_size: int = 32000 36 | n_layer: int = 32 37 | n_head: int = 32 38 | dim: int = 4096 39 | intermediate_size: int = None 40 | n_local_heads: int = -1 41 | head_dim: int = 64 42 | rope_base: float = 10000 43 | norm_eps: float = 1e-5 44 | rope_scaling: Optional[dict] = None 45 | semi_compiled_model: bool = False 46 | 47 | def __post_init__(self): 48 | if self.n_local_heads == -1: 49 | self.n_local_heads = self.n_head 50 | if self.intermediate_size is None: 51 | hidden_dim = 4 * self.dim 52 | n_hidden = int(2 * hidden_dim / 3) 53 | self.intermediate_size = find_multiple(n_hidden, 256) 54 | self.head_dim = self.dim // self.n_head 55 | 56 | assert self.dim % tp_world_size == 0 57 | assert self.intermediate_size % tp_world_size == 0 58 | 59 | @classmethod 60 | def from_name(cls, name: str): 61 | if name in transformer_configs: 62 | return cls(**transformer_configs[name]) 63 | 64 | 65 | transformer_configs = { 66 | "CodeLlama-7b-Python-hf": dict(block_size=16384, vocab_size=32000, n_layer=32, dim=4096, rope_base=1000000), 67 | "1b": dict( 68 | block_size=2048, 69 | n_layer=40, 70 | n_head=24, 71 | n_local_heads=24, 72 | dim=1536, 73 | 
intermediate_size=4096, 74 | vocab_size=49152, 75 | rope_base=10000, 76 | ), 77 | "1.55b": dict( 78 | block_size=2048, 79 | n_layer=52, 80 | n_head=24, 81 | n_local_heads=24, 82 | dim=1536, 83 | intermediate_size=4096, 84 | vocab_size=49152, 85 | rope_base=10000, 86 | ), 87 | "3b": dict( 88 | block_size=2048, 89 | n_layer=40, 90 | n_head=32, 91 | n_local_heads=32, 92 | dim=2304, 93 | intermediate_size=9216, 94 | vocab_size=49152, 95 | rope_base=10000, 96 | ), 97 | "3.9bh": dict( 98 | block_size=2048, 99 | n_layer=52, 100 | n_head=32, 101 | n_local_heads=32, 102 | dim=2304, 103 | intermediate_size=9216, 104 | vocab_size=49152, 105 | rope_base=10000, 106 | ), 107 | "3.9bw1": dict( 108 | block_size=2048, 109 | n_layer=40, 110 | n_head=32, 111 | n_local_heads=32, 112 | dim=3200, 113 | intermediate_size=9216, 114 | vocab_size=49152, 115 | rope_base=10000, 116 | ), 117 | "3.9bw2": dict( 118 | block_size=2048, 119 | n_layer=40, 120 | n_head=32, 121 | n_local_heads=32, 122 | dim=2560, 123 | intermediate_size=11000, 124 | vocab_size=49152, 125 | rope_base=10000, 126 | ), 127 | "7B": dict(n_layer=32, n_head=32, dim=4096), 128 | "13B": dict(n_layer=40, n_head=40, dim=5120), 129 | "30B": dict(n_layer=60, n_head=52, dim=6656), 130 | "34B": dict( 131 | n_layer=48, n_head=64, dim=8192, vocab_size=32000, n_local_heads=8, intermediate_size=22016, rope_base=1000000 132 | ), # CodeLlama-34B-Python-hf 133 | "70B": dict(n_layer=80, n_head=64, dim=8192, n_local_heads=8, intermediate_size=28672), 134 | "70B-semi-compiled": dict( 135 | n_layer=80, n_head=64, dim=8192, n_local_heads=8, intermediate_size=28672, semi_compiled_model=True 136 | ), 137 | "Mistral-7B": dict(n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=32000), 138 | "stories15M": dict(n_layer=6, n_head=6, dim=288), 139 | "stories110M": dict(n_layer=12, n_head=12, dim=768), 140 | "llama-3.2-1b": dict(n_layer=16, n_head=32, dim=2048, intermediate_size=8192, vocab_size=128256), 141 | "llama-3.2-3b": dict(n_layer=28, n_head=24, dim=3072, intermediate_size=8192, vocab_size=128256), 142 | "llama-3-8b": dict( 143 | block_size=8192, 144 | n_layer=32, 145 | n_head=32, 146 | n_local_heads=8, 147 | dim=4096, 148 | intermediate_size=14336, 149 | vocab_size=128256, 150 | rope_base=500000, 151 | ), 152 | "llama-3-8b-semi-compiled": dict( 153 | block_size=8192, 154 | n_layer=32, 155 | n_head=32, 156 | n_local_heads=8, 157 | dim=4096, 158 | intermediate_size=14336, 159 | vocab_size=128256, 160 | rope_base=500000, 161 | semi_compiled_model=True, 162 | ), 163 | "llama-3-70b": dict( 164 | block_size=8192, 165 | n_layer=80, 166 | n_head=64, 167 | n_local_heads=8, 168 | dim=8192, 169 | intermediate_size=28672, 170 | vocab_size=128256, 171 | rope_base=500000, 172 | ), 173 | "llama-3-70b-semi-compiled": dict( 174 | block_size=8192, 175 | n_layer=80, 176 | n_head=64, 177 | n_local_heads=8, 178 | dim=8192, 179 | intermediate_size=28672, 180 | vocab_size=128256, 181 | rope_base=500000, 182 | semi_compiled_model=True, 183 | ), 184 | "llama-3.1-405b": dict( 185 | block_size=131072, 186 | n_layer=126, 187 | n_head=128, 188 | n_local_heads=16, 189 | dim=16384, 190 | intermediate_size=53248, 191 | vocab_size=128256, 192 | rope_base=500000, 193 | rope_scaling=dict( 194 | factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_position_embeddings=8192 195 | ), 196 | ), 197 | "bloom-176b": dict( 198 | block_size=8192, 199 | n_layer=70, 200 | n_head=112, 201 | n_local_heads=112, 202 | dim=14336, 203 | intermediate_size=50176, 204 | 
vocab_size=250880, 205 | rope_base=500000, 206 | ), 207 | } 208 | 209 | 210 | class GPTLadder(nn.Module): 211 | def __init__(self, config: ModelArgs) -> None: 212 | super().__init__() 213 | self.config = config 214 | 215 | self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim) 216 | self.layers = nn.ModuleList(LadderTransformerBlock(config) for _ in range(config.n_layer)) 217 | self.norm = RMSNorm(config.dim, eps=config.norm_eps) 218 | self.output = nn.Linear(config.dim, config.vocab_size, bias=False) 219 | 220 | self.freqs_cis: Optional[Tensor] = None 221 | self.mask_cache: Optional[Tensor] = None 222 | self.max_batch_size = -1 223 | self.max_seq_length = -1 224 | 225 | def setup_caches(self, max_batch_size, max_seq_length): 226 | if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size: 227 | return 228 | head_dim = self.config.dim // self.config.n_head 229 | max_seq_length = find_multiple(max_seq_length, 8) 230 | self.max_seq_length = max_seq_length 231 | self.max_batch_size = max_batch_size 232 | dtype = self.output.weight.dtype 233 | # For quantized layers, dtype is encoded in scales 234 | if hasattr(self.output, "scales"): 235 | dtype = self.output.scales.dtype 236 | elif hasattr(self.output, "scales_and_zeros"): 237 | dtype = self.output.scales_and_zeros.dtype 238 | for b in self.layers: 239 | b.attention.kv_cache = KVCache( 240 | max_batch_size, max_seq_length, self.config.n_local_heads // tp_world_size, head_dim, dtype 241 | ) 242 | 243 | self.freqs_cis = precompute_freqs_cis( 244 | self.config.block_size, 245 | self.config.dim // self.config.n_head, 246 | self.config.rope_base, 247 | dtype, 248 | self.config.rope_scaling, 249 | ) 250 | self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)) 251 | 252 | def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: 253 | assert self.freqs_cis is not None, "Caches must be initialized first" 254 | mask = self.causal_mask[None, None, input_pos] 255 | freqs_cis = self.freqs_cis[input_pos] 256 | x = self.tok_embeddings(idx) 257 | 258 | previous_attention_out = torch.zeros_like(x) 259 | previous_mlp_out = torch.zeros_like(x) 260 | attention_handle = None 261 | mlp_handle = None 262 | for i, layer in enumerate(self.layers): 263 | previous_attention_out, previous_mlp_out, x, attention_handle, mlp_handle = layer( 264 | previous_attention_out, 265 | previous_mlp_out, 266 | x, 267 | attention_handle, 268 | mlp_handle, 269 | input_pos, 270 | freqs_cis, 271 | mask, 272 | ) 273 | 274 | if attention_handle is not None: 275 | attention_handle.wait() 276 | 277 | if mlp_handle is not None: 278 | mlp_handle.wait() 279 | 280 | x = x + previous_attention_out + previous_mlp_out 281 | 282 | x = self.norm(x) 283 | logits = self.output(x) 284 | return logits 285 | 286 | @classmethod 287 | def from_name(cls, name: str): 288 | return cls(ModelArgs.from_name(name)) 289 | 290 | 291 | class LadderTransformerBlock(nn.Module): 292 | def __init__(self, config: ModelArgs) -> None: 293 | super().__init__() 294 | self.attention = Attention(config) 295 | self.feed_forward = FeedForward(config) 296 | self.ffn_norm = RMSNorm(config.dim, config.norm_eps) 297 | self.attention_norm = RMSNorm(config.dim, config.norm_eps) 298 | 299 | def _attn(x, freqs_cis, mask, input_pos): 300 | current_attention_out = self.attention(self.attention_norm(x), freqs_cis, mask, input_pos) 301 | return current_attention_out 302 | 303 | def _ffn(x): 304 | current_mlp_out = 
self.feed_forward(self.ffn_norm(x)) 305 | return current_mlp_out 306 | 307 | self.semi_compiled_model = config.semi_compiled_model 308 | if self.semi_compiled_model: 309 | self._attn = torch.compile(_attn) 310 | self._ffn = torch.compile(_ffn) 311 | 312 | def forward( 313 | self, 314 | previous_attention_out: Tensor, 315 | previous_mlp_out: Tensor, 316 | residual: Tensor, 317 | attention_handle, 318 | mlp_handle, 319 | input_pos: Tensor, 320 | freqs_cis: Tensor, 321 | mask: Tensor, 322 | ) -> Tensor: 323 | if attention_handle is not None: 324 | attention_handle.wait() 325 | 326 | numel = residual.numel() 327 | grid = (triton.cdiv(numel, 1024),) 328 | 329 | output = torch.empty_like(residual) 330 | # with torch.device(residual.device): 331 | add_tensor_forward_triton_kernel[grid](residual, previous_attention_out, output, numel, 1024) 332 | residual = output 333 | 334 | if self.semi_compiled_model: 335 | current_attention_out = self._attn(residual, freqs_cis, mask, input_pos) 336 | else: 337 | current_attention_out = self.attention(self.attention_norm(residual), freqs_cis, mask, input_pos) 338 | 339 | current_attention_out, attention_handle = all_reduce_func( 340 | current_attention_out, clone=self.semi_compiled_model, async_op=True 341 | ) 342 | 343 | if mlp_handle is not None: 344 | mlp_handle.wait() 345 | 346 | output = torch.empty_like(residual) 347 | # with torch.device(residual.device): 348 | add_tensor_forward_triton_kernel[grid](residual, previous_mlp_out, output, numel, 1024) 349 | residual = output 350 | 351 | if self.semi_compiled_model: 352 | current_mlp_out = self._ffn(residual) 353 | else: 354 | current_mlp_out = self.feed_forward(self.ffn_norm(residual)) 355 | 356 | current_mlp_out, mlp_handle = all_reduce_func(current_mlp_out, clone=self.semi_compiled_model, async_op=True) 357 | 358 | return current_attention_out, current_mlp_out, residual, attention_handle, mlp_handle 359 | 360 | def extra_repr(self) -> str: 361 | return f"semi_compiled = {self.semi_compiled_model}" 362 | 363 | 364 | @triton.jit 365 | def add_tensor_forward_triton_kernel(x_ptr, y_ptr, output_ptr, num_elements, BLOCK_SIZE: tl.constexpr): 366 | pid = tl.program_id(axis=0) 367 | 368 | block_start = pid * BLOCK_SIZE 369 | indices = block_start + tl.arange(0, BLOCK_SIZE) 370 | mask = indices < num_elements 371 | 372 | x = tl.load(x_ptr + indices, mask=mask) 373 | y = tl.load(y_ptr + indices, mask=mask) 374 | 375 | output = x + y 376 | 377 | tl.store(output_ptr + indices, output, mask=mask) 378 | -------------------------------------------------------------------------------- /gpt_fast/gpt_parallel_TP.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
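# NOTE (descriptive, illustrative only): in this "parallel" variant the attention and MLP branches
# are fused into a single FuseAttentionMLP module (see gpt_fast/utils.py), so each transformer block
# issues only one tensor-parallel all-reduce instead of the usual two (one after attention, one after
# the MLP). A rough sketch of the per-block data flow in the eager path, not the exact code below:
#
#     y = FuseAttentionMLP(attention_norm(x), freqs_cis, mask, input_pos)  # attn + MLP from the same input
#     y = all_reduce(y)                                                    # single TP all-reduce per block
#     y = y + x                                                            # residual added once, after the reduce
#
# In the semi-compiled path the residual is instead added on rank 0 inside the compiled region, so the
# subsequent all-reduce folds it into the sum exactly once.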
6 | from dataclasses import dataclass 7 | from typing import Optional 8 | 9 | import torch 10 | import torch.distributed as dist 11 | import torch.nn as nn 12 | from torch import Tensor 13 | 14 | from .tp import maybe_init_dist 15 | from .utils import FuseAttentionMLP, KVCache, RMSNorm, all_reduce_func, precompute_freqs_cis 16 | 17 | 18 | def find_multiple(n: int, k: int) -> int: 19 | if n % k == 0: 20 | return n 21 | return n + k - (n % k) 22 | 23 | 24 | maybe_init_dist() 25 | tp_rank = dist.get_rank() 26 | tp_world_size = dist.get_world_size() 27 | tp_group = list(range(tp_world_size)) 28 | 29 | 30 | @dataclass 31 | class ModelArgs: 32 | block_size: int = 2048 33 | vocab_size: int = 32000 34 | n_layer: int = 32 35 | n_head: int = 32 36 | dim: int = 4096 37 | intermediate_size: int = None 38 | n_local_heads: int = -1 39 | head_dim: int = 64 40 | rope_base: float = 10000 41 | norm_eps: float = 1e-5 42 | rope_scaling: Optional[dict] = None 43 | semi_compiled_model: bool = False 44 | 45 | def __post_init__(self): 46 | if self.n_local_heads == -1: 47 | self.n_local_heads = self.n_head 48 | if self.intermediate_size is None: 49 | hidden_dim = 4 * self.dim 50 | n_hidden = int(2 * hidden_dim / 3) 51 | self.intermediate_size = find_multiple(n_hidden, 256) 52 | self.head_dim = self.dim // self.n_head 53 | 54 | assert self.dim % tp_world_size == 0 55 | assert self.intermediate_size % tp_world_size == 0 56 | 57 | @classmethod 58 | def from_name(cls, name: str): 59 | if name in transformer_configs: 60 | return cls(**transformer_configs[name]) 61 | 62 | 63 | transformer_configs = { 64 | "CodeLlama-7b-Python-hf": dict(block_size=16384, vocab_size=32000, n_layer=32, dim=4096, rope_base=1000000), 65 | "1b": dict( 66 | block_size=2048, 67 | n_layer=40, 68 | n_head=24, 69 | n_local_heads=24, 70 | dim=1536, 71 | intermediate_size=4096, 72 | vocab_size=49152, 73 | rope_base=10000, 74 | ), 75 | "3b": dict( 76 | block_size=2048, 77 | n_layer=40, 78 | n_head=32, 79 | n_local_heads=32, 80 | dim=2304, 81 | intermediate_size=9216, 82 | vocab_size=49152, 83 | rope_base=10000, 84 | ), 85 | "7B": dict(n_layer=32, n_head=32, dim=4096), 86 | "13B": dict(n_layer=40, n_head=40, dim=5120), 87 | "30B": dict(n_layer=60, n_head=52, dim=6656), 88 | "34B": dict( 89 | n_layer=48, n_head=64, dim=8192, vocab_size=32000, n_local_heads=8, intermediate_size=22016, rope_base=1000000 90 | ), # CodeLlama-34B-Python-hf 91 | "70B": dict(n_layer=80, n_head=64, dim=8192, n_local_heads=8, intermediate_size=28672), 92 | "70B-semi-compiled": dict( 93 | n_layer=80, n_head=64, dim=8192, n_local_heads=8, intermediate_size=28672, semi_compiled_model=True 94 | ), 95 | "Mistral-7B": dict(n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=32000), 96 | "stories15M": dict(n_layer=6, n_head=6, dim=288), 97 | "stories110M": dict(n_layer=12, n_head=12, dim=768), 98 | "llama-3.2-1b": dict(n_layer=16, n_head=32, dim=2048, intermediate_size=8192, vocab_size=128256), 99 | "llama-3.2-3b": dict(n_layer=28, n_head=24, dim=3072, intermediate_size=8192, vocab_size=128256), 100 | "llama-3-8b": dict( 101 | block_size=8192, 102 | n_layer=32, 103 | n_head=32, 104 | n_local_heads=8, 105 | dim=4096, 106 | intermediate_size=14336, 107 | vocab_size=128256, 108 | rope_base=500000, 109 | ), 110 | "llama-3-70b": dict( 111 | block_size=8192, 112 | n_layer=80, 113 | n_head=64, 114 | n_local_heads=8, 115 | dim=8192, 116 | intermediate_size=28672, 117 | vocab_size=128256, 118 | rope_base=500000, 119 | ), 120 | "llama-3.1-405b": dict( 121 | 
block_size=131072, 122 | n_layer=126, 123 | n_head=128, 124 | n_local_heads=16, 125 | dim=16384, 126 | intermediate_size=53248, 127 | vocab_size=128256, 128 | rope_base=500000, 129 | rope_scaling=dict( 130 | factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_position_embeddings=8192 131 | ), 132 | ), 133 | } 134 | 135 | 136 | class GPTParallel(nn.Module): 137 | def __init__(self, config: ModelArgs) -> None: 138 | super().__init__() 139 | self.config = config 140 | 141 | self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim) 142 | self.layers = nn.ModuleList(ParallelTransformerBlock(config) for _ in range(config.n_layer)) 143 | self.norm = RMSNorm(config.dim, eps=config.norm_eps) 144 | self.output = nn.Linear(config.dim, config.vocab_size, bias=False) 145 | 146 | self.freqs_cis: Optional[Tensor] = None 147 | self.mask_cache: Optional[Tensor] = None 148 | self.max_batch_size = -1 149 | self.max_seq_length = -1 150 | 151 | def setup_caches(self, max_batch_size, max_seq_length): 152 | if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size: 153 | return 154 | head_dim = self.config.dim // self.config.n_head 155 | max_seq_length = find_multiple(max_seq_length, 8) 156 | self.max_seq_length = max_seq_length 157 | self.max_batch_size = max_batch_size 158 | dtype = self.output.weight.dtype 159 | # For quantized layers, dtype is encoded in scales 160 | if hasattr(self.output, "scales"): 161 | dtype = self.output.scales.dtype 162 | elif hasattr(self.output, "scales_and_zeros"): 163 | dtype = self.output.scales_and_zeros.dtype 164 | for b in self.layers: 165 | b.attention.kv_cache = KVCache( 166 | max_batch_size, max_seq_length, self.config.n_local_heads // tp_world_size, head_dim, dtype 167 | ) 168 | 169 | self.freqs_cis = precompute_freqs_cis( 170 | self.config.block_size, 171 | self.config.dim // self.config.n_head, 172 | self.config.rope_base, 173 | dtype, 174 | self.config.rope_scaling, 175 | ) 176 | self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)) 177 | 178 | def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: 179 | assert self.freqs_cis is not None, "Caches must be initialized first" 180 | mask = self.causal_mask[None, None, input_pos] 181 | freqs_cis = self.freqs_cis[input_pos] 182 | x = self.tok_embeddings(idx) 183 | 184 | for i, layer in enumerate(self.layers): 185 | x = layer(x, input_pos, freqs_cis, mask) 186 | x = self.norm(x) 187 | logits = self.output(x) 188 | return logits 189 | 190 | @classmethod 191 | def from_name(cls, name: str): 192 | return cls(ModelArgs.from_name(name)) 193 | 194 | 195 | class ParallelTransformerBlock(nn.Module): 196 | def __init__(self, config: ModelArgs) -> None: 197 | super().__init__() 198 | self.attention = FuseAttentionMLP(config) 199 | # self.feed_forward = FeedForward(config) 200 | self.attention_norm = RMSNorm(config.dim, config.norm_eps) 201 | 202 | def _attn_ffn(x, freqs_cis, mask, input_pos): 203 | y = self.attention(self.attention_norm(x), freqs_cis, mask, input_pos) 204 | 205 | if tp_rank == 0: 206 | y = y + x 207 | 208 | return y 209 | 210 | self._attn_ffn = torch.compile(_attn_ffn) 211 | 212 | self.semi_compiled_model = config.semi_compiled_model 213 | 214 | def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor: 215 | if self.semi_compiled_model: 216 | x = self._attn_ffn(x, freqs_cis, mask, input_pos) 217 | x = all_reduce_func(x, clone=True)[0] 218 | y = x 219 | else: 220 | y = 
self.attention( 221 | self.attention_norm(x), freqs_cis, mask, input_pos 222 | ) # + self.feed_forward(self.ffn_norm(x)) 223 | y = all_reduce_func(y, clone=False)[0] 224 | y = y + x 225 | 226 | return y 227 | 228 | def extra_repr(self) -> str: 229 | return f"semi_compiled = {self.semi_compiled_model}" 230 | -------------------------------------------------------------------------------- /gpt_fast/tp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import os 7 | from datetime import timedelta 8 | from typing import Optional 9 | 10 | import torch 11 | import torch.distributed as dist 12 | 13 | 14 | def _get_rank() -> int: 15 | return int(os.environ.get("LOCAL_RANK", "0")) 16 | 17 | 18 | def __get_global_rank() -> int: 19 | return int(os.environ.get("RANK", "0")) 20 | 21 | 22 | def _get_world_size() -> int: 23 | return int(os.environ.get("LOCAL_WORLD_SIZE", "1")) 24 | 25 | 26 | def _get_global_world_size() -> int: 27 | return int(os.environ.get("WORLD_SIZE", "1")) 28 | 29 | 30 | def is_local(): 31 | return _get_rank() == 0 32 | 33 | 34 | def local_break(): 35 | if is_local(): 36 | breakpoint() 37 | dist.barrier() 38 | 39 | 40 | def maybe_init_dist() -> Optional[int]: 41 | try: 42 | # provided by torchrun 43 | rank = _get_rank() 44 | global_rank = __get_global_rank() 45 | world_size = _get_world_size() 46 | global_world_size = _get_global_world_size() 47 | print( 48 | f"rank: {rank}, global_rank: {global_rank}, world_size: {world_size}, global_world_size: {global_world_size}" 49 | ) 50 | except KeyError: 51 | # not run via torchrun, no-op 52 | return None 53 | 54 | if not dist.is_initialized(): 55 | torch.cuda.set_device(rank) 56 | if global_world_size > 1: 57 | dist.init_process_group( 58 | backend="nccl", rank=global_rank, world_size=global_world_size, timeout=timedelta(seconds=600) 59 | ) 60 | else: 61 | dist.init_process_group(backend="nccl", rank=rank, world_size=world_size, timeout=timedelta(seconds=600)) 62 | 63 | return rank 64 | -------------------------------------------------------------------------------- /gpt_fast/utils.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import math 3 | from contextlib import contextmanager 4 | from typing import Optional 5 | 6 | import torch 7 | import torch.distributed as dist 8 | import torch.distributed._functional_collectives as funcol 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache 12 | from liger_kernel.ops.rms_norm import LigerRMSNormFunction 13 | from torch import Tensor 14 | 15 | from .tp import maybe_init_dist 16 | 17 | _USE_FLASH_ATTENTION: bool = False 18 | 19 | 20 | @contextmanager 21 | def set_flash_attention(enable: bool): 22 | global _USE_FLASH_ATTENTION 23 | 24 | original_value = _USE_FLASH_ATTENTION 25 | _USE_FLASH_ATTENTION = enable 26 | 27 | yield 28 | 29 | _USE_FLASH_ATTENTION = original_value 30 | 31 | 32 | def is_flash_attention_enabled() -> bool: 33 | global _USE_FLASH_ATTENTION 34 | return _USE_FLASH_ATTENTION 35 | 36 | 37 | maybe_init_dist() 38 | tp_rank = dist.get_rank() 39 | tp_world_size = dist.get_world_size() 40 | tp_group = list(range(tp_world_size)) 41 | 42 | 43 | class KVCache(nn.Module): 44 | def __init__(self, 
max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16): 45 | super().__init__() 46 | cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) 47 | self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype)) 48 | self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype)) 49 | 50 | def update(self, input_pos, k_val, v_val): 51 | # input_pos: [S], k_val: [B, H, S, D] 52 | assert input_pos.shape[0] == k_val.shape[2] 53 | 54 | k_out = self.k_cache 55 | v_out = self.v_cache 56 | k_out[:, :, input_pos] = k_val 57 | v_out[:, :, input_pos] = v_val 58 | 59 | return k_out, v_out 60 | 61 | 62 | class RMSNorm(nn.Module): 63 | def __init__(self, dim: int, eps: float = 1e-5): 64 | super().__init__() 65 | self.eps = eps 66 | self.weight = nn.Parameter(torch.ones(dim)) 67 | 68 | def _norm(self, x): 69 | return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) 70 | 71 | def forward(self, x: Tensor) -> Tensor: 72 | if torch.compiler.is_compiling(): 73 | output = self._norm(x.float()).type_as(x) 74 | output = output * self.weight 75 | else: 76 | output = LigerRMSNormFunction.apply(x, self.weight, self.eps) 77 | 78 | return output 79 | 80 | 81 | def apply_rope_scaling(freqs: torch.Tensor, rope_scaling: Optional[dict] = None): 82 | factor = rope_scaling["factor"] 83 | low_freq_factor = rope_scaling["low_freq_factor"] 84 | high_freq_factor = rope_scaling["high_freq_factor"] 85 | old_context_len = rope_scaling["original_max_position_embeddings"] 86 | 87 | low_freq_wavelen = old_context_len / low_freq_factor 88 | high_freq_wavelen = old_context_len / high_freq_factor 89 | new_freqs = [] 90 | for freq in freqs: 91 | wavelen = 2 * math.pi / freq 92 | if wavelen < high_freq_wavelen: 93 | new_freqs.append(freq) 94 | elif wavelen > low_freq_wavelen: 95 | new_freqs.append(freq / factor) 96 | else: 97 | assert low_freq_wavelen != high_freq_wavelen 98 | smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) 99 | new_freqs.append((1 - smooth) * freq / factor + smooth * freq) 100 | return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) 101 | 102 | 103 | def precompute_freqs_cis( 104 | seq_len: int, 105 | n_elem: int, 106 | base: int = 10000, 107 | dtype: torch.dtype = torch.bfloat16, 108 | rope_scaling: Optional[dict] = None, 109 | ) -> Tensor: 110 | freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)) 111 | if rope_scaling is not None: 112 | freqs = apply_rope_scaling(freqs, rope_scaling) 113 | t = torch.arange(seq_len, device=freqs.device) 114 | freqs = torch.outer(t, freqs) 115 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) 116 | cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) 117 | return cache.to(dtype=dtype) 118 | 119 | 120 | def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: 121 | xshaped = x.float().reshape(*x.shape[:-1], -1, 2) 122 | freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) 123 | x_out2 = torch.stack( 124 | [ 125 | xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1], 126 | xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1], 127 | ], 128 | -1, 129 | ) 130 | 131 | x_out2 = x_out2.flatten(3) 132 | return x_out2.type_as(x) 133 | 134 | 135 | class FeedForward(nn.Module): 136 | def __init__(self, config) -> None: 137 | super().__init__() 138 | 139 | assert config.intermediate_size % tp_world_size == 0 140 | assert config.dim % tp_world_size == 0 141 | 142 | 
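        # Tensor-parallel SwiGLU MLP: w1 is the column-parallel, fused up/gate projection (hence the
        # factor of 2 in its output width) and w2 is the row-parallel down projection, so each rank
        # produces a partial (..., dim) output that the caller combines with an all-reduce. In
        # forward(), the w1 output is chunked in two and SiLU is applied to the second chunk before
        # the elementwise product, matching the usual SwiGLU formulation.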
self.w1 = nn.Linear(config.dim, 2 * config.intermediate_size // tp_world_size, bias=False) 143 | self.w2 = nn.Linear(config.intermediate_size // tp_world_size, config.dim, bias=False) 144 | 145 | def forward(self, x: Tensor) -> Tensor: 146 | x = self.w1(x) 147 | u, g = x.chunk(2, dim=-1) 148 | y = self.w2(F.silu(g) * u) 149 | return y 150 | 151 | 152 | class Attention(nn.Module): 153 | def __init__(self, config): 154 | super().__init__() 155 | assert config.dim % config.n_head == 0 156 | 157 | total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim 158 | 159 | assert total_head_dim % tp_world_size == 0 160 | assert config.dim % tp_world_size == 0 161 | assert config.n_head % tp_world_size == 0 162 | assert config.n_local_heads % tp_world_size == 0 163 | 164 | # key, query, value projections for all heads, but in a batch 165 | self.wqkv = nn.Linear(config.dim, total_head_dim // tp_world_size, bias=False) 166 | self.wo = nn.Linear(config.dim // tp_world_size, config.dim, bias=False) 167 | self.kv_cache = None 168 | 169 | self.n_head = config.n_head 170 | self.head_dim = config.head_dim 171 | self.n_local_heads = config.n_local_heads 172 | self.dim = config.dim 173 | 174 | self.n_head = self.n_head // tp_world_size 175 | self.dim = self.dim // tp_world_size 176 | self.n_local_heads = self.n_local_heads // tp_world_size 177 | 178 | self._register_load_state_dict_pre_hook(self.load_hook) 179 | 180 | def load_hook(self, state_dict, prefix, *args): 181 | if prefix + "wq.weight" in state_dict: 182 | wq = state_dict.pop(prefix + "wq.weight") 183 | wk = state_dict.pop(prefix + "wk.weight") 184 | wv = state_dict.pop(prefix + "wv.weight") 185 | state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) 186 | 187 | def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: 188 | bsz, seqlen, _ = x.shape 189 | 190 | kv_size = self.n_local_heads * self.head_dim 191 | q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) 192 | 193 | q = q.view(bsz, seqlen, self.n_head, self.head_dim) 194 | k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim) 195 | v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim) 196 | 197 | q = apply_rotary_emb(q, freqs_cis) 198 | k = apply_rotary_emb(k, freqs_cis) 199 | 200 | if is_flash_attention_enabled(): 201 | device = q.device 202 | 203 | if seqlen <= 1: # decode time 204 | k_cache = self.kv_cache.k_cache # (batch_size, n_local_heads, seqlen_cache, head_dim) 205 | k_cache = k_cache.transpose(1, 2) 206 | v_cache = self.kv_cache.v_cache 207 | v_cache = v_cache.transpose(1, 2) 208 | cache_seqlens = k_cache.size(1) 209 | y = flash_attn_with_kvcache( 210 | q, # (batch_size, seqlen_q, n_heads, head_dim) 211 | k_cache, # (batch_size, seqlen_cache, n_local_heads, head_dim) 212 | v_cache, # (batch_size, seqlen_new, n_local_heads, head_dim) 213 | k=k, # (batch_size, seqlen_new, n_local_heads, head_dim) 214 | v=v, # (batch_size, seqlen_new, n_local_heads, head_dim) 215 | cache_seqlens=cache_seqlens, 216 | cache_batch_idx=None, 217 | cache_leftpad=None, 218 | block_table=None, 219 | rotary_cos=None, 220 | rotary_sin=None, 221 | softmax_scale=None, 222 | causal=True, 223 | ) 224 | k_cache = k_cache.transpose(1, 2) 225 | v_cache = v_cache.transpose(1, 2) 226 | self.kv_cache.k_cache = k_cache 227 | self.kv_cache.v_cache = v_cache 228 | else: 229 | if self.kv_cache is not None: 230 | k, v = map(lambda x: x.transpose(1, 2), (k, v)) 231 | k, v = self.kv_cache.update(input_pos, k, v) 232 | k, v = 
map(lambda x: x.transpose(1, 2), (k, v)) 233 | q_var = q.reshape(-1, q.shape[-2], q.shape[-1]) 234 | k_var = k.reshape(-1, k.shape[-2], k.shape[-1]) 235 | v_var = v.reshape(-1, v.shape[-2], v.shape[-1]) 236 | lens = torch.full([q.shape[0]], seqlen, dtype=torch.int32, device=device) 237 | 238 | cu_seqlens = torch.cat( 239 | [ 240 | torch.zeros(1, dtype=torch.int32, device=device), 241 | torch.cumsum(lens, dim=0, dtype=torch.int32), 242 | ] 243 | ).int() 244 | y = flash_attn_varlen_func( 245 | q_var, k_var, v_var, cu_seqlens, cu_seqlens, q.size(1), k.size(1), causal=True 246 | ) 247 | else: 248 | q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) 249 | if self.kv_cache is not None: 250 | k, v = self.kv_cache.update(input_pos, k, v) 251 | 252 | k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) 253 | v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) 254 | 255 | y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) 256 | 257 | y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) 258 | y = self.wo(y) 259 | 260 | return y 261 | 262 | 263 | class FuseAttentionMLP(nn.Module): 264 | def __init__(self, config): 265 | super().__init__() 266 | assert config.dim % config.n_head == 0 267 | 268 | total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim 269 | 270 | assert total_head_dim % tp_world_size == 0 271 | assert config.dim % tp_world_size == 0 272 | assert config.intermediate_size % tp_world_size == 0 273 | assert config.n_head % tp_world_size == 0 274 | assert config.n_local_heads % tp_world_size == 0 275 | 276 | # key, query, value projections for all heads, but in a batch 277 | self.wqkv1 = nn.Linear( 278 | config.dim, total_head_dim // tp_world_size + 2 * config.intermediate_size // tp_world_size, bias=False 279 | ) 280 | self.wo = nn.Linear(config.dim // tp_world_size, config.dim, bias=False) 281 | self.w2 = nn.Linear(config.intermediate_size // tp_world_size, config.dim, bias=False) 282 | 283 | self.kv_cache = None 284 | 285 | self.n_head = config.n_head 286 | self.head_dim = config.head_dim 287 | self.n_local_heads = config.n_local_heads 288 | self.dim = config.dim 289 | 290 | self.n_head = self.n_head // tp_world_size 291 | self.dim = self.dim // tp_world_size 292 | self.n_local_heads = self.n_local_heads // tp_world_size 293 | 294 | self.intermediate_size = config.intermediate_size // tp_world_size 295 | 296 | self._register_load_state_dict_pre_hook(self.load_hook) 297 | 298 | def load_hook(self, state_dict, prefix, *args): 299 | if prefix + "wq.weight" in state_dict: 300 | wq = state_dict.pop(prefix + "wq.weight") 301 | wk = state_dict.pop(prefix + "wk.weight") 302 | wv = state_dict.pop(prefix + "wv.weight") 303 | state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) 304 | 305 | def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: 306 | bsz, seqlen, _ = x.shape 307 | 308 | kv_size = self.n_local_heads * self.head_dim 309 | # q, k, v = self.wqkv1(x).split([self.dim, kv_size, kv_size], dim=-1) 310 | # use fuse qkv1 311 | q, k, v, u, g = self.wqkv1(x).split( 312 | [self.dim, kv_size, kv_size, self.intermediate_size, self.intermediate_size], dim=-1 313 | ) 314 | 315 | q = q.view(bsz, seqlen, self.n_head, self.head_dim) 316 | k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim) 317 | v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim) 318 | 319 | q = apply_rotary_emb(q, freqs_cis) 320 | k = apply_rotary_emb(k, freqs_cis) 321 | 322 | 
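        # Two attention paths follow: when flash attention is enabled, single-token decode uses
        # flash_attn_with_kvcache against the (transposed) KV cache, and longer prefill sequences use
        # flash_attn_varlen_func; otherwise the KV cache is updated in place and attention falls back
        # to F.scaled_dot_product_attention with repeat_interleave'd KV heads.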
if is_flash_attention_enabled(): 323 | device = q.device 324 | 325 | if seqlen <= 1: # decode 326 | k_cache = self.kv_cache.k_cache 327 | k_cache = k_cache.transpose(1, 2) 328 | v_cache = self.kv_cache.v_cache 329 | v_cache = v_cache.transpose(1, 2) 330 | cache_seqlens = k_cache.size(1) 331 | 332 | y = flash_attn_with_kvcache( 333 | q, # (batch_size, seqlen_q, n_heads, head_dim) 334 | k_cache, # (batch_size, seqlen_cache, n_local_heads, head_dim) 335 | v_cache, # (batch_size, seqlen_cache, n_local_heads, head_dim) 336 | k=k, # (batch_size, seqlen_new, n_local_heads, head_dim) 337 | v=v, # (batch_size, seqlen_new, n_local_heads, head_dim) 338 | cache_seqlens=cache_seqlens, 339 | cache_batch_idx=None, 340 | cache_leftpad=None, 341 | block_table=None, 342 | rotary_cos=None, 343 | rotary_sin=None, 344 | softmax_scale=None, 345 | causal=True, 346 | ) 347 | 348 | k_cache = k_cache.transpose(1, 2) 349 | v_cache = v_cache.transpose(1, 2) 350 | self.kv_cache.k_cache = k_cache 351 | self.kv_cache.v_cache = v_cache 352 | else: 353 | if self.kv_cache is not None: 354 | k, v = map(lambda x: x.transpose(1, 2), (k, v)) 355 | k, v = self.kv_cache.update(input_pos, k, v) 356 | k, v = map(lambda x: x.transpose(1, 2), (k, v)) 357 | 358 | q_var = q.reshape(-1, q.shape[-2], q.shape[-1]) 359 | k_var = k.reshape(-1, k.shape[-2], k.shape[-1]) 360 | v_var = v.reshape(-1, v.shape[-2], v.shape[-1]) 361 | lens = torch.full([q.shape[0]], seqlen, dtype=torch.int32, device=device) 362 | 363 | cu_seqlens = torch.cat( 364 | [ 365 | torch.zeros(1, dtype=torch.int32, device=device), 366 | torch.cumsum(lens, dim=0, dtype=torch.int32), 367 | ] 368 | ).int() 369 | y = flash_attn_varlen_func( 370 | q_var, k_var, v_var, cu_seqlens, cu_seqlens, q.size(1), k.size(1), causal=True 371 | ) 372 | else: 373 | q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) 374 | 375 | if self.kv_cache is not None: 376 | k, v = self.kv_cache.update(input_pos, k, v) 377 | 378 | k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) 379 | v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) 380 | 381 | y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) 382 | 383 | y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) 384 | y = self.wo(y) 385 | y = self.w2(F.silu(g) * u) + y 386 | 387 | return y 388 | 389 | 390 | def all_reduce_func(x: torch.Tensor, clone: bool, async_op=False) -> torch.Tensor: 391 | if torch.compiler.is_compiling() or clone: 392 | x = funcol.all_reduce(x, reduceOp="sum", group=tp_group) 393 | handle = None 394 | else: 395 | handle = dist.all_reduce(x, async_op=async_op) 396 | 397 | return x, handle 398 | 399 | 400 | def _get_model_size(model): 401 | model_size = 0 402 | params = 0 403 | for name, child in model.named_children(): 404 | if not isinstance(child, torch.nn.Embedding): 405 | model_size += sum( 406 | [p.numel() * p.dtype.itemsize for p in itertools.chain(child.parameters(), child.buffers())] 407 | ) 408 | params += sum([p.numel() for p in itertools.chain(child.parameters(), child.buffers())]) 409 | return model_size, params 410 | -------------------------------------------------------------------------------- /hf_modeling_utils/configs/Llama-3.1-8B-Instruct-Ladder-last16L.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LlamaForCausalLM" 4 | ], 5 | "auto_map": { 6 | "AutoConfig": "configuration_llama_ladder.LlamaLadderConfig", 7 | "AutoModelForCausalLM": 
"modeling_llama_ladder.LlamaLadderForCausalLM" 8 | }, 9 | "attention_bias": false, 10 | "attention_dropout": 0.0, 11 | "bos_token_id": 128000, 12 | "eos_token_id": [ 13 | 128001, 14 | 128008, 15 | 128009 16 | ], 17 | "hidden_act": "silu", 18 | "hidden_size": 4096, 19 | "initializer_range": 0.02, 20 | "intermediate_size": 14336, 21 | "max_position_embeddings": 131072, 22 | "mlp_bias": false, 23 | "model_type": "llamaLadder", 24 | "num_attention_heads": 32, 25 | "num_hidden_layers": 32, 26 | "num_key_value_heads": 8, 27 | "pretraining_tp": 1, 28 | "rms_norm_eps": 1e-05, 29 | "rope_scaling": { 30 | "factor": 8.0, 31 | "low_freq_factor": 1.0, 32 | "high_freq_factor": 4.0, 33 | "original_max_position_embeddings": 8192, 34 | "rope_type": "llama3" 35 | }, 36 | "rope_theta": 500000.0, 37 | "tie_word_embeddings": false, 38 | "torch_dtype": "bfloat16", 39 | "transformers_version": "4.42.3", 40 | "use_cache": true, 41 | "vocab_size": 128256, 42 | "ladder_layers": 16 43 | } -------------------------------------------------------------------------------- /hf_modeling_utils/configs/Llama-3.1-8B-Instruct-Ladder-last20L.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LlamaForCausalLM" 4 | ], 5 | "auto_map": { 6 | "AutoConfig": "configuration_llama_ladder.LlamaLadderConfig", 7 | "AutoModelForCausalLM": "modeling_llama_ladder.LlamaLadderForCausalLM" 8 | }, 9 | "attention_bias": false, 10 | "attention_dropout": 0.0, 11 | "bos_token_id": 128000, 12 | "eos_token_id": [ 13 | 128001, 14 | 128008, 15 | 128009 16 | ], 17 | "hidden_act": "silu", 18 | "hidden_size": 4096, 19 | "initializer_range": 0.02, 20 | "intermediate_size": 14336, 21 | "max_position_embeddings": 131072, 22 | "mlp_bias": false, 23 | "model_type": "llamaLadder", 24 | "num_attention_heads": 32, 25 | "num_hidden_layers": 32, 26 | "num_key_value_heads": 8, 27 | "pretraining_tp": 1, 28 | "rms_norm_eps": 1e-05, 29 | "rope_scaling": { 30 | "factor": 8.0, 31 | "low_freq_factor": 1.0, 32 | "high_freq_factor": 4.0, 33 | "original_max_position_embeddings": 8192, 34 | "rope_type": "llama3" 35 | }, 36 | "rope_theta": 500000.0, 37 | "tie_word_embeddings": false, 38 | "torch_dtype": "bfloat16", 39 | "transformers_version": "4.42.3", 40 | "use_cache": true, 41 | "vocab_size": 128256, 42 | "ladder_layers": 20 43 | } -------------------------------------------------------------------------------- /hf_modeling_utils/configuration_llama_ladder.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 5 | # and OPT implementations in this library. It has been modified from its 6 | # original forms to accommodate minor architectural differences compared 7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 20 | """LLaMA model configuration""" 21 | 22 | from transformers.models.llama.configuration_llama import LlamaConfig 23 | 24 | 25 | class LlamaLadderConfig(LlamaConfig): 26 | 27 | model_type = "llamaLadder" 28 | keys_to_ignore_at_inference = ["past_key_values"] 29 | 30 | def __init__( 31 | self, 32 | ladder_layers=None, 33 | **kwargs, 34 | ): 35 | super().__init__( 36 | **kwargs, 37 | ) 38 | if ladder_layers is None: 39 | self.ladder_layers = [] 40 | elif isinstance(ladder_layers, int): 41 | self.ladder_layers = list(range(self.num_hidden_layers - ladder_layers, self.num_hidden_layers)) 42 | elif isinstance(ladder_layers, list): 43 | self.ladder_layers = ladder_layers 44 | else: 45 | raise ValueError(f"Invalid ladder_layers type: {type(ladder_layers)}") 46 | 47 | print(f"Ladder layers: {self.ladder_layers}") 48 | -------------------------------------------------------------------------------- /hf_modeling_utils/modeling_llama_ladder.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 5 | # and OPT implementations in this library. It has been modified from its 6 | # original forms to accommodate minor architectural differences compared 7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 
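# NOTE (descriptive): this file is adapted from the Hugging Face LLaMA implementation (see the
# transformers version noted below). The addition is the "ladder" decoder layer: when a layer runs in
# ladder mode, its attention is computed from the *previous* layer's attention output
# (prev_attn_output) while the incoming hidden_states carry the previous layer's MLP stream, which
# desynchronizes the attention and MLP branches across layers. The set of ladder layers comes from
# LlamaLadderConfig.ladder_layers (see configuration_llama_ladder.py); passing an int selects the
# last N decoder layers. For example (illustrative only):
#
#     cfg = LlamaLadderConfig(num_hidden_layers=32, ladder_layers=16)
#     cfg.ladder_layers  # -> [16, 17, ..., 31], i.e. the last 16 decoder layers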
20 | import math 21 | from typing import List, Optional, Tuple, Union 22 | 23 | import torch 24 | import torch.utils.checkpoint 25 | from torch import nn 26 | 27 | from transformers.activations import ACT2FN 28 | from transformers.cache_utils import Cache, DynamicCache, StaticCache 29 | from transformers.generation import GenerationMixin 30 | from transformers.modeling_attn_mask_utils import AttentionMaskConverter 31 | from transformers.modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward 32 | from transformers.modeling_outputs import ( 33 | BaseModelOutputWithPast, 34 | CausalLMOutputWithPast, 35 | ) 36 | from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS 37 | from transformers.modeling_utils import PreTrainedModel 38 | from transformers.processing_utils import Unpack 39 | from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS 40 | from transformers.utils import ( 41 | LossKwargs, 42 | add_code_sample_docstrings, 43 | add_start_docstrings, 44 | add_start_docstrings_to_model_forward, 45 | is_flash_attn_greater_or_equal_2_10, 46 | logging, 47 | replace_return_docstrings, 48 | ) 49 | from transformers.models.llama.configuration_llama import LlamaConfig 50 | from .configuration_llama_ladder import LlamaLadderConfig 51 | 52 | 53 | logger = logging.get_logger(__name__) 54 | 55 | _CHECKPOINT_FOR_DOC = "meta-llama/Llama-2-7b-hf" 56 | _CONFIG_FOR_DOC = "LlamaLadderConfig" 57 | # HF version of this file: 4.47.0 58 | 59 | 60 | class LlamaRMSNorm(nn.Module): 61 | def __init__(self, hidden_size, eps=1e-6): 62 | """ 63 | LlamaRMSNorm is equivalent to T5LayerNorm 64 | """ 65 | super().__init__() 66 | self.weight = nn.Parameter(torch.ones(hidden_size)) 67 | self.variance_epsilon = eps 68 | 69 | def forward(self, hidden_states): 70 | input_dtype = hidden_states.dtype 71 | hidden_states = hidden_states.to(torch.float32) 72 | variance = hidden_states.pow(2).mean(-1, keepdim=True) 73 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) 74 | return self.weight * hidden_states.to(input_dtype) 75 | 76 | def extra_repr(self): 77 | return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" 78 | 79 | 80 | ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm) 81 | 82 | 83 | class LlamaRotaryEmbedding(nn.Module): 84 | def __init__( 85 | self, 86 | dim=None, 87 | max_position_embeddings=2048, 88 | base=10000, 89 | device=None, 90 | scaling_factor=1.0, 91 | rope_type="default", 92 | config: Optional[LlamaConfig] = None, 93 | ): 94 | super().__init__() 95 | # TODO (joao): remove the `if` below, only used for BC 96 | self.rope_kwargs = {} 97 | if config is None: 98 | logger.warning_once( 99 | "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the " 100 | "`config` argument. 
All other arguments will be removed in v4.46" 101 | ) 102 | self.rope_kwargs = { 103 | "rope_type": rope_type, 104 | "factor": scaling_factor, 105 | "dim": dim, 106 | "base": base, 107 | "max_position_embeddings": max_position_embeddings, 108 | } 109 | self.rope_type = rope_type 110 | self.max_seq_len_cached = max_position_embeddings 111 | self.original_max_seq_len = max_position_embeddings 112 | else: 113 | # BC: "rope_type" was originally "type" 114 | if config.rope_scaling is not None: 115 | self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) 116 | else: 117 | self.rope_type = "default" 118 | self.max_seq_len_cached = config.max_position_embeddings 119 | self.original_max_seq_len = config.max_position_embeddings 120 | 121 | self.config = config 122 | self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] 123 | 124 | inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) 125 | self.register_buffer("inv_freq", inv_freq, persistent=False) 126 | self.original_inv_freq = self.inv_freq 127 | 128 | def _dynamic_frequency_update(self, position_ids, device): 129 | """ 130 | dynamic RoPE layers should recompute `inv_freq` in the following situations: 131 | 1 - growing beyond the cached sequence length (allow scaling) 132 | 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) 133 | """ 134 | seq_len = torch.max(position_ids) + 1 135 | if seq_len > self.max_seq_len_cached: # growth 136 | inv_freq, self.attention_scaling = self.rope_init_fn( 137 | self.config, device, seq_len=seq_len, **self.rope_kwargs 138 | ) 139 | self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation 140 | self.max_seq_len_cached = seq_len 141 | 142 | if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset 143 | self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) 144 | self.max_seq_len_cached = self.original_max_seq_len 145 | 146 | @torch.no_grad() 147 | def forward(self, x, position_ids): 148 | if "dynamic" in self.rope_type: 149 | self._dynamic_frequency_update(position_ids, device=x.device) 150 | 151 | # Core RoPE block 152 | inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) 153 | position_ids_expanded = position_ids[:, None, :].float() 154 | # Force float32 (see https://github.com/huggingface/transformers/pull/29285) 155 | device_type = x.device.type 156 | device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" 157 | with torch.autocast(device_type=device_type, enabled=False): 158 | freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) 159 | emb = torch.cat((freqs, freqs), dim=-1) 160 | cos = emb.cos() 161 | sin = emb.sin() 162 | 163 | # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention 164 | cos = cos * self.attention_scaling 165 | sin = sin * self.attention_scaling 166 | 167 | return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) 168 | 169 | 170 | class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): 171 | """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" 172 | 173 | def __init__(self, *args, **kwargs): 174 | logger.warning_once( 175 | "`LlamaLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. 
Please use " 176 | "`LlamaRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)." 177 | ) 178 | kwargs["rope_type"] = "linear" 179 | super().__init__(*args, **kwargs) 180 | 181 | 182 | class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): 183 | """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" 184 | 185 | def __init__(self, *args, **kwargs): 186 | logger.warning_once( 187 | "`LlamaDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use " 188 | "`LlamaRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to " 189 | "__init__)." 190 | ) 191 | kwargs["rope_type"] = "dynamic" 192 | super().__init__(*args, **kwargs) 193 | 194 | 195 | def rotate_half(x): 196 | """Rotates half the hidden dims of the input.""" 197 | x1 = x[..., : x.shape[-1] // 2] 198 | x2 = x[..., x.shape[-1] // 2 :] 199 | return torch.cat((-x2, x1), dim=-1) 200 | 201 | 202 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): 203 | """Applies Rotary Position Embedding to the query and key tensors. 204 | 205 | Args: 206 | q (`torch.Tensor`): The query tensor. 207 | k (`torch.Tensor`): The key tensor. 208 | cos (`torch.Tensor`): The cosine part of the rotary embedding. 209 | sin (`torch.Tensor`): The sine part of the rotary embedding. 210 | position_ids (`torch.Tensor`, *optional*): 211 | Deprecated and unused. 212 | unsqueeze_dim (`int`, *optional*, defaults to 1): 213 | The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and 214 | sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note 215 | that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and 216 | k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes 217 | cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have 218 | the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. 219 | Returns: 220 | `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 221 | """ 222 | cos = cos.unsqueeze(unsqueeze_dim) 223 | sin = sin.unsqueeze(unsqueeze_dim) 224 | q_embed = (q * cos) + (rotate_half(q) * sin) 225 | k_embed = (k * cos) + (rotate_half(k) * sin) 226 | return q_embed, k_embed 227 | 228 | 229 | class LlamaMLP(nn.Module): 230 | def __init__(self, config): 231 | super().__init__() 232 | self.config = config 233 | self.hidden_size = config.hidden_size 234 | self.intermediate_size = config.intermediate_size 235 | self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) 236 | self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) 237 | self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) 238 | self.act_fn = ACT2FN[config.hidden_act] 239 | 240 | def forward(self, x): 241 | down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) 242 | return down_proj 243 | 244 | 245 | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: 246 | """ 247 | This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, 248 | num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) 249 | """ 250 | batch, num_key_value_heads, slen, head_dim = hidden_states.shape 251 | if n_rep == 1: 252 | return hidden_states 253 | hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) 254 | return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) 255 | 256 | 257 | class LlamaAttention(nn.Module): 258 | """Multi-headed attention from 'Attention Is All You Need' paper""" 259 | 260 | def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): 261 | super().__init__() 262 | self.config = config 263 | self.layer_idx = layer_idx 264 | if layer_idx is None: 265 | logger.warning_once( 266 | f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " 267 | "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " 268 | "when creating this class." 269 | ) 270 | 271 | self.attention_dropout = config.attention_dropout 272 | self.hidden_size = config.hidden_size 273 | self.num_heads = config.num_attention_heads 274 | self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) 275 | self.num_key_value_heads = config.num_key_value_heads 276 | self.num_key_value_groups = self.num_heads // self.num_key_value_heads 277 | self.max_position_embeddings = config.max_position_embeddings 278 | self.rope_theta = config.rope_theta 279 | self.is_causal = True 280 | 281 | self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) 282 | self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) 283 | self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) 284 | self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) 285 | 286 | # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers) 287 | self.rotary_emb = LlamaRotaryEmbedding(config=self.config) 288 | 289 | def forward( 290 | self, 291 | hidden_states: torch.Tensor, 292 | attention_mask: Optional[torch.Tensor] = None, 293 | position_ids: Optional[torch.LongTensor] = None, 294 | past_key_value: Optional[Cache] = None, 295 | output_attentions: bool = False, 296 | use_cache: bool = False, 297 | cache_position: Optional[torch.LongTensor] = None, 298 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 299 | **kwargs, 300 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 301 | bsz, q_len, _ = hidden_states.size() 302 | 303 | query_states = self.q_proj(hidden_states) 304 | key_states = self.k_proj(hidden_states) 305 | value_states = self.v_proj(hidden_states) 306 | 307 | # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used 308 | query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) 309 | key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) 310 | value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) 311 | 312 | if position_embeddings is None: 313 | logger.warning_once( 314 | "The attention layers in this model are transitioning from computing the RoPE embeddings internally " 315 | "through `position_ids` (2D tensor 
with the indexes of the tokens), to using externally computed " 316 | "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " 317 | "removed and `position_embeddings` will be mandatory." 318 | ) 319 | cos, sin = self.rotary_emb(value_states, position_ids) 320 | else: 321 | cos, sin = position_embeddings 322 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) 323 | 324 | if past_key_value is not None: 325 | # sin and cos are specific to RoPE models; cache_position needed for the static cache 326 | cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} 327 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) 328 | 329 | key_states = repeat_kv(key_states, self.num_key_value_groups) 330 | value_states = repeat_kv(value_states, self.num_key_value_groups) 331 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) 332 | 333 | if attention_mask is not None: # no matter the length, we just slice it 334 | causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] 335 | attn_weights = attn_weights + causal_mask 336 | 337 | # upcast attention to fp32 338 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) 339 | attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) 340 | attn_output = torch.matmul(attn_weights, value_states) 341 | 342 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 343 | raise ValueError( 344 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 345 | f" {attn_output.size()}" 346 | ) 347 | 348 | attn_output = attn_output.transpose(1, 2).contiguous() 349 | 350 | attn_output = attn_output.reshape(bsz, q_len, -1) 351 | 352 | attn_output = self.o_proj(attn_output) 353 | 354 | if not output_attentions: 355 | attn_weights = None 356 | 357 | return attn_output, attn_weights, past_key_value 358 | 359 | 360 | class LlamaFlashAttention2(LlamaAttention): 361 | """ 362 | Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays 363 | untouched. The only required change would be on the forward pass where it needs to correctly call the public API of 364 | flash attention and deal with padding tokens in case the input contains any of them. 365 | """ 366 | 367 | def __init__(self, *args, **kwargs): 368 | super().__init__(*args, **kwargs) 369 | 370 | # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. 371 | # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. 372 | # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
373 | self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() 374 | 375 | def forward( 376 | self, 377 | hidden_states: torch.Tensor, 378 | attention_mask: Optional[torch.LongTensor] = None, 379 | position_ids: Optional[torch.LongTensor] = None, 380 | past_key_value: Optional[Cache] = None, 381 | output_attentions: bool = False, 382 | use_cache: bool = False, 383 | cache_position: Optional[torch.LongTensor] = None, 384 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 385 | **kwargs: Unpack[FlashAttentionKwargs], 386 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 387 | if isinstance(past_key_value, StaticCache): 388 | raise ValueError( 389 | "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " 390 | "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" 391 | ) 392 | 393 | output_attentions = False 394 | 395 | bsz, q_len, _ = hidden_states.size() 396 | 397 | query_states = self.q_proj(hidden_states) 398 | key_states = self.k_proj(hidden_states) 399 | value_states = self.v_proj(hidden_states) 400 | 401 | # Flash attention requires the input to have the shape 402 | # batch_size x seq_length x head_dim x hidden_dim 403 | # therefore we just need to keep the original shape 404 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 405 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 406 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 407 | 408 | if position_embeddings is None: 409 | logger.warning_once( 410 | "The attention layers in this model are transitioning from computing the RoPE embeddings internally " 411 | "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " 412 | "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " 413 | "removed and `position_embeddings` will be mandatory." 414 | ) 415 | cos, sin = self.rotary_emb(value_states, position_ids) 416 | else: 417 | cos, sin = position_embeddings 418 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) 419 | 420 | if past_key_value is not None: 421 | # sin and cos are specific to RoPE models; cache_position needed for the static cache 422 | cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} 423 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) 424 | 425 | # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache 426 | # to be able to avoid many of these transpose/reshape/view. 427 | query_states = query_states.transpose(1, 2) 428 | key_states = key_states.transpose(1, 2) 429 | value_states = value_states.transpose(1, 2) 430 | 431 | dropout_rate = self.attention_dropout if self.training else 0.0 432 | 433 | # In PEFT, usually we cast the layer norms in float32 for training stability reasons 434 | # therefore the input hidden states gets silently casted in float32. Hence, we need 435 | # cast them back in the correct dtype just to be sure everything works as expected. 
436 | # This might slowdown training & inference so it is recommended to not cast the LayerNorms 437 | # in fp32. (LlamaRMSNorm handles it correctly) 438 | 439 | input_dtype = query_states.dtype 440 | if input_dtype == torch.float32: 441 | if torch.is_autocast_enabled(): 442 | target_dtype = torch.get_autocast_gpu_dtype() 443 | # Handle the case where the model is quantized 444 | elif hasattr(self.config, "_pre_quantization_dtype"): 445 | target_dtype = self.config._pre_quantization_dtype 446 | else: 447 | target_dtype = self.q_proj.weight.dtype 448 | 449 | logger.warning_once( 450 | f"The input hidden states seems to be silently casted in float32, this might be related to" 451 | f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" 452 | f" {target_dtype}." 453 | ) 454 | 455 | query_states = query_states.to(target_dtype) 456 | key_states = key_states.to(target_dtype) 457 | value_states = value_states.to(target_dtype) 458 | 459 | attn_output = _flash_attention_forward( 460 | query_states, 461 | key_states, 462 | value_states, 463 | attention_mask, 464 | q_len, 465 | position_ids=position_ids, 466 | dropout=dropout_rate, 467 | sliding_window=getattr(self, "sliding_window", None), 468 | use_top_left_mask=self._flash_attn_uses_top_left_mask, 469 | is_causal=self.is_causal, 470 | **kwargs, 471 | ) 472 | 473 | attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() 474 | attn_output = self.o_proj(attn_output) 475 | 476 | if not output_attentions: 477 | attn_weights = None 478 | 479 | return attn_output, attn_weights, past_key_value 480 | 481 | 482 | class LlamaSdpaAttention(LlamaAttention): 483 | """ 484 | Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from 485 | `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to 486 | SDPA API. 487 | """ 488 | 489 | # Adapted from LlamaAttention.forward 490 | def forward( 491 | self, 492 | hidden_states: torch.Tensor, 493 | attention_mask: Optional[torch.Tensor] = None, 494 | position_ids: Optional[torch.LongTensor] = None, 495 | past_key_value: Optional[Cache] = None, 496 | output_attentions: bool = False, 497 | use_cache: bool = False, 498 | cache_position: Optional[torch.LongTensor] = None, 499 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 500 | **kwargs, 501 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 502 | if output_attentions: 503 | # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. 504 | logger.warning_once( 505 | "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " 506 | 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
507 | ) 508 | return super().forward( 509 | hidden_states=hidden_states, 510 | attention_mask=attention_mask, 511 | position_ids=position_ids, 512 | past_key_value=past_key_value, 513 | output_attentions=output_attentions, 514 | use_cache=use_cache, 515 | cache_position=cache_position, 516 | position_embeddings=position_embeddings, 517 | ) 518 | 519 | bsz, q_len, _ = hidden_states.size() 520 | 521 | query_states = self.q_proj(hidden_states) 522 | key_states = self.k_proj(hidden_states) 523 | value_states = self.v_proj(hidden_states) 524 | 525 | # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used 526 | query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) 527 | key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) 528 | value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) 529 | 530 | if position_embeddings is None: 531 | logger.warning_once( 532 | "The attention layers in this model are transitioning from computing the RoPE embeddings internally " 533 | "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " 534 | "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " 535 | "removed and `position_embeddings` will be mandatory." 536 | ) 537 | cos, sin = self.rotary_emb(value_states, position_ids) 538 | else: 539 | cos, sin = position_embeddings 540 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) 541 | 542 | if past_key_value is not None: 543 | # sin and cos are specific to RoPE models; cache_position needed for the static cache 544 | cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} 545 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) 546 | 547 | key_states = repeat_kv(key_states, self.num_key_value_groups) 548 | value_states = repeat_kv(value_states, self.num_key_value_groups) 549 | 550 | causal_mask = attention_mask 551 | if attention_mask is not None: 552 | causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] 553 | 554 | # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, 555 | # Reference: https://github.com/pytorch/pytorch/issues/112577. 556 | if query_states.device.type == "cuda" and causal_mask is not None: 557 | query_states = query_states.contiguous() 558 | key_states = key_states.contiguous() 559 | value_states = value_states.contiguous() 560 | 561 | # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment 562 | # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
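        # Additionally, `scaled_dot_product_attention` does not accept an explicit `attn_mask` together with
        # `is_causal=True`, so causality is only requested when no mask is provided; a single query token
        # (q_len == 1) needs no causal masking during decoding.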
563 | is_causal = True if causal_mask is None and q_len > 1 else False 564 | 565 | attn_output = torch.nn.functional.scaled_dot_product_attention( 566 | query_states, 567 | key_states, 568 | value_states, 569 | attn_mask=causal_mask, 570 | dropout_p=self.attention_dropout if self.training else 0.0, 571 | is_causal=is_causal, 572 | ) 573 | 574 | attn_output = attn_output.transpose(1, 2).contiguous() 575 | attn_output = attn_output.view(bsz, q_len, -1) 576 | 577 | attn_output = self.o_proj(attn_output) 578 | 579 | return attn_output, None, past_key_value 580 | 581 | 582 | LLAMA_ATTENTION_CLASSES = { 583 | "eager": LlamaAttention, 584 | "flash_attention_2": LlamaFlashAttention2, 585 | "sdpa": LlamaSdpaAttention, 586 | } 587 | 588 | 589 | class LlamaLadderDecoderLayer(nn.Module): 590 | def __init__(self, config: LlamaLadderConfig, layer_idx: int): 591 | super().__init__() 592 | self.hidden_size = config.hidden_size 593 | 594 | self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) 595 | 596 | self.mlp = LlamaMLP(config) 597 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 598 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 599 | 600 | def forward( 601 | self, 602 | hidden_states: torch.Tensor, 603 | attention_mask: Optional[torch.Tensor] = None, 604 | position_ids: Optional[torch.LongTensor] = None, 605 | past_key_value: Optional[Cache] = None, 606 | output_attentions: Optional[bool] = False, 607 | use_cache: Optional[bool] = False, 608 | cache_position: Optional[torch.LongTensor] = None, 609 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC 610 | prev_attn_output: Optional[torch.Tensor] = None, 611 | **kwargs: Unpack[FlashAttentionKwargs], 612 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: 613 | 614 | if prev_attn_output is None: 615 | # Normal computation flow 616 | residual = hidden_states 617 | hidden_states = self.input_layernorm(hidden_states) 618 | 619 | # Self Attention 620 | hidden_states, self_attn_weights, present_key_value = self.self_attn( 621 | hidden_states=hidden_states, 622 | attention_mask=attention_mask, 623 | position_ids=position_ids, 624 | past_key_value=past_key_value, 625 | output_attentions=output_attentions, 626 | use_cache=use_cache, 627 | cache_position=cache_position, 628 | position_embeddings=position_embeddings, 629 | **kwargs, 630 | ) 631 | hidden_states = residual + hidden_states 632 | attn_output = hidden_states 633 | 634 | # Fully Connected 635 | residual = hidden_states 636 | hidden_states = self.post_attention_layernorm(hidden_states) 637 | hidden_states = self.mlp(hidden_states) 638 | hidden_states = residual + hidden_states 639 | mlp_output = hidden_states 640 | 641 | else: 642 | # Ladder computation flow 643 | prev_mlp_output = hidden_states 644 | attn_input = self.input_layernorm(prev_attn_output) 645 | 646 | # Self Attention 647 | attn_output, self_attn_weights, present_key_value = self.self_attn( 648 | hidden_states=attn_input, 649 | attention_mask=attention_mask, 650 | position_ids=position_ids, 651 | past_key_value=past_key_value, 652 | output_attentions=output_attentions, 653 | use_cache=use_cache, 654 | cache_position=cache_position, 655 | position_embeddings=position_embeddings, 656 | **kwargs, 657 | ) 658 | attn_output = prev_mlp_output + attn_output 659 | 660 | # Fully Connected 661 | mlp_input = 
self.post_attention_layernorm(prev_mlp_output) 662 | mlp_output = self.mlp(mlp_input) 663 | mlp_output = attn_output + mlp_output 664 | 665 | outputs = ((attn_output, mlp_output),) 666 | if output_attentions: 667 | outputs += (self_attn_weights,) 668 | 669 | if use_cache: 670 | outputs += (present_key_value,) 671 | 672 | return outputs 673 | 674 | 675 | LLAMA_START_DOCSTRING = r""" 676 | This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the 677 | library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads 678 | etc.) 679 | 680 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 681 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage 682 | and behavior. 683 | 684 | Parameters: 685 | config ([`LlamaConfig`]): 686 | Model configuration class with all the parameters of the model. Initializing with a config file does not 687 | load the weights associated with the model, only the configuration. Check out the 688 | [`~PreTrainedModel.from_pretrained`] method to load the model weights. 689 | """ 690 | 691 | 692 | @add_start_docstrings( 693 | "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", 694 | LLAMA_START_DOCSTRING, 695 | ) 696 | class LlamaPreTrainedModel(PreTrainedModel): 697 | config_class = LlamaConfig 698 | base_model_prefix = "model" 699 | supports_gradient_checkpointing = True 700 | _no_split_modules = ["LlamaDecoderLayer", "LlamaLadderDecoderLayer"] 701 | _skip_keys_device_placement = ["past_key_values"] 702 | _supports_flash_attn_2 = True 703 | _supports_sdpa = True 704 | _supports_flex_attn = True 705 | _supports_cache_class = True 706 | _supports_quantized_cache = True 707 | _supports_static_cache = True 708 | 709 | def _init_weights(self, module): 710 | std = self.config.initializer_range 711 | if isinstance(module, nn.Linear): 712 | module.weight.data.normal_(mean=0.0, std=std) 713 | if module.bias is not None: 714 | module.bias.data.zero_() 715 | elif isinstance(module, nn.Embedding): 716 | module.weight.data.normal_(mean=0.0, std=std) 717 | if module.padding_idx is not None: 718 | module.weight.data[module.padding_idx].zero_() 719 | 720 | 721 | LLAMA_INPUTS_DOCSTRING = r""" 722 | Args: 723 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 724 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide 725 | it. 726 | 727 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 728 | [`PreTrainedTokenizer.__call__`] for details. 729 | 730 | [What are input IDs?](../glossary#input-ids) 731 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 732 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 733 | 734 | - 1 for tokens that are **not masked**, 735 | - 0 for tokens that are **masked**. 736 | 737 | [What are attention masks?](../glossary#attention-mask) 738 | 739 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 740 | [`PreTrainedTokenizer.__call__`] for details. 741 | 742 | If `past_key_values` is used, optionally only the last `input_ids` have to be input (see 743 | `past_key_values`). 
744 | 745 | If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] 746 | and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more 747 | information on the default strategy. 748 | 749 | - 1 indicates the head is **not masked**, 750 | - 0 indicates the head is **masked**. 751 | position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 752 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, 753 | config.n_positions - 1]`. 754 | 755 | [What are position IDs?](../glossary#position-ids) 756 | past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): 757 | Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention 758 | blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` 759 | returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. 760 | 761 | Two formats are allowed: 762 | - a [`~cache_utils.Cache`] instance, see our 763 | [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); 764 | - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of 765 | shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy 766 | cache format. 767 | 768 | The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the 769 | legacy cache format will be returned. 770 | 771 | If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't 772 | have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` 773 | of shape `(batch_size, sequence_length)`. 774 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 775 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This 776 | is useful if you want more control over how to convert `input_ids` indices into associated vectors than the 777 | model's internal embedding lookup matrix. 778 | use_cache (`bool`, *optional*): 779 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see 780 | `past_key_values`). 781 | output_attentions (`bool`, *optional*): 782 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 783 | tensors for more detail. 784 | output_hidden_states (`bool`, *optional*): 785 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 786 | more detail. 787 | return_dict (`bool`, *optional*): 788 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 789 | cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): 790 | Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, 791 | this tensor is not affected by padding. It is used to update the cache in the correct position and to infer 792 | the complete sequence length. 
793 | """ 794 | 795 | 796 | @add_start_docstrings( 797 | "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", 798 | LLAMA_START_DOCSTRING, 799 | ) 800 | class LlamaLadderModel(LlamaPreTrainedModel): 801 | """ 802 | Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaLadderDecoderLayer`] 803 | 804 | Args: 805 | config: LlamaConfig 806 | """ 807 | 808 | def __init__(self, config: LlamaLadderConfig): 809 | super().__init__(config) 810 | self.padding_idx = config.pad_token_id 811 | self.vocab_size = config.vocab_size 812 | 813 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) 814 | self.layers = nn.ModuleList( 815 | [LlamaLadderDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] 816 | ) 817 | self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 818 | self.rotary_emb = LlamaRotaryEmbedding(config=config) 819 | self.gradient_checkpointing = False 820 | 821 | # Initialize weights and apply final processing 822 | self.post_init() 823 | 824 | def get_input_embeddings(self): 825 | return self.embed_tokens 826 | 827 | def set_input_embeddings(self, value): 828 | self.embed_tokens = value 829 | 830 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) 831 | def forward( 832 | self, 833 | input_ids: torch.LongTensor = None, 834 | attention_mask: Optional[torch.Tensor] = None, 835 | position_ids: Optional[torch.LongTensor] = None, 836 | past_key_values: Optional[Cache] = None, 837 | inputs_embeds: Optional[torch.FloatTensor] = None, 838 | use_cache: Optional[bool] = None, 839 | output_attentions: Optional[bool] = None, 840 | output_hidden_states: Optional[bool] = None, 841 | return_dict: Optional[bool] = None, 842 | cache_position: Optional[torch.LongTensor] = None, 843 | **flash_attn_kwargs: Unpack[FlashAttentionKwargs], 844 | ) -> Union[Tuple, BaseModelOutputWithPast]: 845 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 846 | output_hidden_states = ( 847 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 848 | ) 849 | use_cache = use_cache if use_cache is not None else self.config.use_cache 850 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 851 | 852 | if (input_ids is None) ^ (inputs_embeds is not None): 853 | raise ValueError("You must specify exactly one of input_ids or inputs_embeds") 854 | 855 | if self.gradient_checkpointing and self.training and use_cache: 856 | logger.warning_once( 857 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
858 | ) 859 | use_cache = False 860 | 861 | if inputs_embeds is None: 862 | inputs_embeds = self.embed_tokens(input_ids) 863 | 864 | if use_cache and past_key_values is None: 865 | past_key_values = DynamicCache() 866 | 867 | if cache_position is None: 868 | past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 869 | cache_position = torch.arange( 870 | past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device 871 | ) 872 | 873 | if position_ids is None: 874 | position_ids = cache_position.unsqueeze(0) 875 | 876 | causal_mask = self._update_causal_mask( 877 | attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions 878 | ) 879 | 880 | hidden_states = inputs_embeds 881 | 882 | # create position embeddings to be shared across the decoder layers 883 | position_embeddings = self.rotary_emb(hidden_states, position_ids) 884 | 885 | # decoder layers 886 | all_hidden_states = () if output_hidden_states else None 887 | all_self_attns = () if output_attentions else None 888 | 889 | prev_attn_output = hidden_states # No previous attn output for the first layer, we feed the embedding to both first attention and mlp if we do ladder 890 | for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): 891 | if output_hidden_states: 892 | all_hidden_states += (hidden_states,) 893 | 894 | if layer_idx not in self.config.ladder_layers: 895 | if self.gradient_checkpointing and self.training: 896 | layer_outputs = self._gradient_checkpointing_func( 897 | decoder_layer.__call__, 898 | hidden_states, 899 | causal_mask, 900 | position_ids, 901 | past_key_values, 902 | output_attentions, 903 | use_cache, 904 | cache_position, 905 | position_embeddings, 906 | ) 907 | else: 908 | layer_outputs = decoder_layer( 909 | hidden_states, 910 | attention_mask=causal_mask, 911 | position_ids=position_ids, 912 | past_key_value=past_key_values, 913 | output_attentions=output_attentions, 914 | use_cache=use_cache, 915 | cache_position=cache_position, 916 | position_embeddings=position_embeddings, 917 | **flash_attn_kwargs, 918 | ) 919 | else: 920 | if self.gradient_checkpointing and self.training: 921 | layer_outputs = self._gradient_checkpointing_func( 922 | decoder_layer.__call__, 923 | hidden_states, 924 | causal_mask, 925 | position_ids, 926 | past_key_values, 927 | output_attentions, 928 | use_cache, 929 | cache_position, 930 | position_embeddings, 931 | prev_attn_output, 932 | ) 933 | else: 934 | layer_outputs = decoder_layer( 935 | hidden_states, 936 | attention_mask=causal_mask, 937 | position_ids=position_ids, 938 | past_key_value=past_key_values, 939 | output_attentions=output_attentions, 940 | use_cache=use_cache, 941 | cache_position=cache_position, 942 | position_embeddings=position_embeddings, 943 | **flash_attn_kwargs, 944 | prev_attn_output=prev_attn_output, 945 | ) 946 | 947 | hidden_states = layer_outputs[0][1] # This will correspond to the mlp output 948 | prev_attn_output = layer_outputs[0][0] # Store the attention output to be used as the stale input for next attention 949 | 950 | if output_attentions: 951 | all_self_attns += (layer_outputs[1],) 952 | 953 | hidden_states = self.norm(hidden_states) 954 | 955 | # add hidden states from the last decoder layer 956 | if output_hidden_states: 957 | all_hidden_states += (hidden_states,) 958 | 959 | output = BaseModelOutputWithPast( 960 | last_hidden_state=hidden_states, 961 | past_key_values=past_key_values if use_cache else None, 962 | 
hidden_states=all_hidden_states, 963 | attentions=all_self_attns, 964 | ) 965 | return output if return_dict else output.to_tuple() 966 | 967 | def _update_causal_mask( 968 | self, 969 | attention_mask: torch.Tensor, 970 | input_tensor: torch.Tensor, 971 | cache_position: torch.Tensor, 972 | past_key_values: Cache, 973 | output_attentions: bool, 974 | ): 975 | if self.config._attn_implementation == "flash_attention_2": 976 | if attention_mask is not None and (attention_mask == 0.0).any(): 977 | return attention_mask 978 | return None 979 | 980 | # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in 981 | # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail 982 | # to infer the attention mask. 983 | past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 984 | using_static_cache = isinstance(past_key_values, StaticCache) 985 | 986 | # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward 987 | if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: 988 | if AttentionMaskConverter._ignore_causal_mask_sdpa( 989 | attention_mask, 990 | inputs_embeds=input_tensor, 991 | past_key_values_length=past_seen_tokens, 992 | is_training=self.training, 993 | ): 994 | return None 995 | 996 | dtype, device = input_tensor.dtype, input_tensor.device 997 | sequence_length = input_tensor.shape[1] 998 | if using_static_cache: 999 | target_length = past_key_values.get_max_cache_shape() 1000 | else: 1001 | target_length = ( 1002 | attention_mask.shape[-1] 1003 | if isinstance(attention_mask, torch.Tensor) 1004 | else past_seen_tokens + sequence_length + 1 1005 | ) 1006 | 1007 | # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). 1008 | causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( 1009 | attention_mask, 1010 | sequence_length=sequence_length, 1011 | target_length=target_length, 1012 | dtype=dtype, 1013 | device=device, 1014 | cache_position=cache_position, 1015 | batch_size=input_tensor.shape[0], 1016 | ) 1017 | 1018 | if ( 1019 | self.config._attn_implementation == "sdpa" 1020 | and attention_mask is not None 1021 | and attention_mask.device.type == "cuda" 1022 | and not output_attentions 1023 | ): 1024 | # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when 1025 | # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 1026 | # Details: https://github.com/pytorch/pytorch/issues/110213 1027 | min_dtype = torch.finfo(dtype).min 1028 | causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) 1029 | 1030 | return causal_mask 1031 | 1032 | @staticmethod 1033 | def _prepare_4d_causal_attention_mask_with_cache_position( 1034 | attention_mask: torch.Tensor, 1035 | sequence_length: int, 1036 | target_length: int, 1037 | dtype: torch.dtype, 1038 | device: torch.device, 1039 | cache_position: torch.Tensor, 1040 | batch_size: int, 1041 | **kwargs, 1042 | ): 1043 | """ 1044 | Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape 1045 | `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. 
1046 | 1047 | Args: 1048 | attention_mask (`torch.Tensor`): 1049 | A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape 1050 | `(batch_size, 1, query_length, key_value_length)`. 1051 | sequence_length (`int`): 1052 | The sequence length being processed. 1053 | target_length (`int`): 1054 | The target length: when generating with static cache, the mask should be as long as the static cache, 1055 | to account for the 0 padding (the part of the cache that is not yet filled). 1056 | dtype (`torch.dtype`): 1057 | The dtype to use for the 4D attention mask. 1058 | device (`torch.device`): 1059 | The device to place the 4D attention mask on. 1060 | cache_position (`torch.Tensor`): 1061 | Indices depicting the position of the input sequence tokens in the sequence. 1062 | batch_size (`int`): 1063 | Batch size. 1064 | """ 1065 | if attention_mask is not None and attention_mask.dim() == 4: 1066 | # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 1067 | causal_mask = attention_mask 1068 | else: 1069 | min_dtype = torch.finfo(dtype).min 1070 | causal_mask = torch.full( 1071 | (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device 1072 | ) 1073 | if sequence_length != 1: 1074 | causal_mask = torch.triu(causal_mask, diagonal=1) 1075 | causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) 1076 | causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) 1077 | if attention_mask is not None: 1078 | causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit 1079 | mask_length = attention_mask.shape[-1] 1080 | padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] 1081 | padding_mask = padding_mask == 0 1082 | causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( 1083 | padding_mask, min_dtype 1084 | ) 1085 | 1086 | return causal_mask 1087 | 1088 | 1089 | class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
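# Minimal usage sketch for the causal LM wrapper defined below (illustrative only; the config path is a placeholder):
#
#     config = LlamaLadderConfig.from_json_file("path/to/ladder_config.json")
#     model = LlamaLadderForCausalLM(config)  # randomly initialized; load converted ladder weights separately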
1090 | 1091 | 1092 | class LlamaLadderForCausalLM(LlamaPreTrainedModel, GenerationMixin): 1093 | config_class = LlamaLadderConfig 1094 | _tied_weights_keys = ["lm_head.weight"] 1095 | _tp_plan = {"lm_head": "colwise_rep"} 1096 | 1097 | def __init__(self, config): 1098 | super().__init__(config) 1099 | self.model = LlamaLadderModel(config) 1100 | self.vocab_size = config.vocab_size 1101 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 1102 | 1103 | # Initialize weights and apply final processing 1104 | self.post_init() 1105 | 1106 | def get_input_embeddings(self): 1107 | return self.model.embed_tokens 1108 | 1109 | def set_input_embeddings(self, value): 1110 | self.model.embed_tokens = value 1111 | 1112 | def get_output_embeddings(self): 1113 | return self.lm_head 1114 | 1115 | def set_output_embeddings(self, new_embeddings): 1116 | self.lm_head = new_embeddings 1117 | 1118 | def set_decoder(self, decoder): 1119 | self.model = decoder 1120 | 1121 | def get_decoder(self): 1122 | return self.model 1123 | 1124 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) 1125 | @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) 1126 | def forward( 1127 | self, 1128 | input_ids: torch.LongTensor = None, 1129 | attention_mask: Optional[torch.Tensor] = None, 1130 | position_ids: Optional[torch.LongTensor] = None, 1131 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, 1132 | inputs_embeds: Optional[torch.FloatTensor] = None, 1133 | labels: Optional[torch.LongTensor] = None, 1134 | use_cache: Optional[bool] = None, 1135 | output_attentions: Optional[bool] = None, 1136 | output_hidden_states: Optional[bool] = None, 1137 | return_dict: Optional[bool] = None, 1138 | cache_position: Optional[torch.LongTensor] = None, 1139 | num_logits_to_keep: int = 0, 1140 | **kwargs: Unpack[KwargsForCausalLM], 1141 | ) -> Union[Tuple, CausalLMOutputWithPast]: 1142 | r""" 1143 | Args: 1144 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 1145 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., 1146 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored 1147 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 1148 | 1149 | num_logits_to_keep (`int`, *optional*): 1150 | Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all 1151 | `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that 1152 | token can save memory, which becomes pretty significant for long sequences or large vocabulary size. 1153 | 1154 | Returns: 1155 | 1156 | Example: 1157 | 1158 | ```python 1159 | >>> from transformers import AutoTokenizer, LlamaForCausalLM 1160 | 1161 | >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") 1162 | >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") 1163 | 1164 | >>> prompt = "Hey, are you conscious? Can you talk to me?" 1165 | >>> inputs = tokenizer(prompt, return_tensors="pt") 1166 | 1167 | >>> # Generate 1168 | >>> generate_ids = model.generate(inputs.input_ids, max_length=30) 1169 | >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] 1170 | "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
1171 | ```""" 1172 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 1173 | output_hidden_states = ( 1174 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 1175 | ) 1176 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1177 | 1178 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 1179 | outputs = self.model( 1180 | input_ids=input_ids, 1181 | attention_mask=attention_mask, 1182 | position_ids=position_ids, 1183 | past_key_values=past_key_values, 1184 | inputs_embeds=inputs_embeds, 1185 | use_cache=use_cache, 1186 | output_attentions=output_attentions, 1187 | output_hidden_states=output_hidden_states, 1188 | return_dict=return_dict, 1189 | cache_position=cache_position, 1190 | **kwargs, 1191 | ) 1192 | 1193 | hidden_states = outputs[0] 1194 | # Only compute necessary logits, and do not upcast them to float if we are not computing the loss 1195 | logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) 1196 | 1197 | loss = None 1198 | if labels is not None: 1199 | loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) 1200 | 1201 | if not return_dict: 1202 | output = (logits,) + outputs[1:] 1203 | return (loss,) + output if loss is not None else output 1204 | 1205 | return CausalLMOutputWithPast( 1206 | loss=loss, 1207 | logits=logits, 1208 | past_key_values=outputs.past_key_values, 1209 | hidden_states=outputs.hidden_states, 1210 | attentions=outputs.attentions, 1211 | ) 1212 | 1213 | 1214 | __all__ = [ 1215 | "LlamaLadderForCausalLM", 1216 | "LlamaLadderModel", 1217 | "LlamaPreTrainedModel", 1218 | ] -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | sentencepiece 3 | tiktoken 4 | blobfile 5 | safetensors 6 | -------------------------------------------------------------------------------- /scripts/throughput-1B.sh: -------------------------------------------------------------------------------- 1 | mode=compile 2 | nodenum=1 3 | prompt_length=1024 4 | max_new_tokens=512 5 | 6 | for P2P_DISABLE in 0 1 7 | do 8 | export NCCL_P2P_DISABLE=${P2P_DISABLE} 9 | for model_name in "gpt_dense:1b" "gpt_ladder:1b" "gpt_desync:1b-upper-bound" "gpt_parallel:1b" 10 | do 11 | folder=./logs/prompt_length_${prompt_length}_max_new_${max_new_tokens}/p2p_disable${P2P_DISABLE}/${mode}/${model_name} 12 | mkdir -p ${folder} 13 | for bssize in 1 4 16 64 14 | do 15 | for tpsize in 1 2 4 8 16 | do 17 | echo "Running with P2P_DISABLE=${P2P_DISABLE} bs=${bssize} tp=${tpsize}" 18 | ENABLE_INTRA_NODE_COMM=1 NCCL_NVLS_ENABLE=1 NCCL_P2P_DISABLE=${P2P_DISABLE} torchrun --standalone --nproc_per_node=${tpsize} --nnodes=${nodenum} --master_port=15328 benchmark.py \ 19 | --model_name ${model_name} \ 20 | --num_samples 10 \ 21 | --batch_size ${bssize} \ 22 | --prompt_length ${prompt_length} \ 23 | --max_new_tokens ${max_new_tokens} \ 24 | --compile \ 25 | --compile_prefill \ 26 | --device cuda 2>&1 | tee ${folder}/bs_${bssize}_tp_${tpsize}.log 27 | echo "Finished running with P2P_DISABLE=${P2P_DISABLE} bs=${bssize} tp=${tpsize}" 28 | done 29 | done 30 | done 31 | done -------------------------------------------------------------------------------- /scripts/throughput-34B.sh: -------------------------------------------------------------------------------- 1 | 
mode=compile 2 | nodenum=1 3 | prompt_length=1024 4 | max_new_tokens=512 5 | 6 | for P2P_DISABLE in 0 1 7 | do 8 | export NCCL_P2P_DISABLE=${P2P_DISABLE} 9 | for model_name in "gpt_dense:34B" "gpt_ladder:34B" "gpt_desync:34B-upper-bound" "gpt_parallel:34B" 10 | do 11 | folder=./logs/prompt_length_${prompt_length}_max_new_${max_new_tokens}/p2p_disable${P2P_DISABLE}/${mode}/${model_name} 12 | mkdir -p ${folder} 13 | for bssize in 1 4 16 64 14 | do 15 | for tpsize in 1 2 4 8 16 | do 17 | echo "Running with P2P_DISABLE=${P2P_DISABLE} bs=${bssize} tp=${tpsize}" 18 | ENABLE_INTRA_NODE_COMM=1 NCCL_NVLS_ENABLE=1 NCCL_P2P_DISABLE=${P2P_DISABLE} torchrun --standalone --nproc_per_node=${tpsize} --nnodes=${nodenum} --master_port=15328 benchmark.py \ 19 | --model_name ${model_name} \ 20 | --num_samples 10 \ 21 | --batch_size ${bssize} \ 22 | --prompt_length ${prompt_length} \ 23 | --max_new_tokens ${max_new_tokens} \ 24 | --compile \ 25 | --compile_prefill \ 26 | --device cuda 2>&1 | tee ${folder}/bs_${bssize}_tp_${tpsize}.log 27 | echo "Finished running with P2P_DISABLE=${P2P_DISABLE} bs=${bssize} tp=${tpsize}" 28 | done 29 | done 30 | done 31 | done -------------------------------------------------------------------------------- /scripts/throughput-3B.sh: -------------------------------------------------------------------------------- 1 | mode=compile 2 | nodenum=1 3 | prompt_length=1024 4 | max_new_tokens=512 5 | 6 | for P2P_DISABLE in 0 1 7 | do 8 | export NCCL_P2P_DISABLE=${P2P_DISABLE} 9 | for model_name in "gpt_dense:3b" "gpt_ladder:3b" "gpt_desync:3b-upper-bound" "gpt_parallel:3b" 10 | do 11 | folder=./logs/prompt_length_${prompt_length}_max_new_${max_new_tokens}/p2p_disable${P2P_DISABLE}/${mode}/${model_name} 12 | mkdir -p ${folder} 13 | for bssize in 1 4 16 64 14 | do 15 | for tpsize in 1 2 4 8 16 | do 17 | echo "Running with P2P_DISABLE=${P2P_DISABLE} bs=${bssize} tp=${tpsize}" 18 | ENABLE_INTRA_NODE_COMM=1 NCCL_NVLS_ENABLE=1 NCCL_P2P_DISABLE=${P2P_DISABLE} torchrun --standalone --nproc_per_node=${tpsize} --nnodes=${nodenum} --master_port=15328 benchmark.py \ 19 | --model_name ${model_name} \ 20 | --num_samples 10 \ 21 | --batch_size ${bssize} \ 22 | --prompt_length ${prompt_length} \ 23 | --max_new_tokens ${max_new_tokens} \ 24 | --compile \ 25 | --compile_prefill \ 26 | --device cuda 2>&1 | tee ${folder}/bs_${bssize}_tp_${tpsize}.log 27 | echo "Finished running with P2P_DISABLE=${P2P_DISABLE} bs=${bssize} tp=${tpsize}" 28 | done 29 | done 30 | done 31 | done -------------------------------------------------------------------------------- /scripts/throughput-405B.sh: -------------------------------------------------------------------------------- 1 | mode=compile 2 | nodenum=2 3 | prompt_length=1024 4 | max_new_tokens=512 5 | master_addr= # please specify the master_addr 6 | 7 | echo "the master_addr has been set to ${master_addr} and I am in the rank $1" 8 | 9 | for P2P_DISABLE in 0 1 10 | do 11 | export NCCL_P2P_DISABLE=${P2P_DISABLE} 12 | for model_name in "gpt_dense:llama-3-405b" "gpt_ladder:llama-3-405b" "gpt_desync:llama-3-405b-upper-bound" "gpt_parallel:llama-3-405b" 13 | do 14 | folder=./logs/prompt_length_${prompt_length}_max_new_${max_new_tokens}/p2p_disable${P2P_DISABLE}/${mode}/${model_name} 15 | mkdir -p ${folder} 16 | for bssize in 1 4 16 64 17 | do 18 | for tpsize in 16 19 | do 20 | nproc_per_node=$((tpsize/nodenum)) 21 | echo "Running with P2P_DISABLE=${P2P_DISABLE} bs=${bssize} tp=${tpsize}" 22 | ENABLE_INTRA_NODE_COMM=1 NCCL_NVLS_ENABLE=1 NCCL_P2P_DISABLE=${P2P_DISABLE} 
torchrun --nproc_per_node=${nproc_per_node} --nnodes=${nodenum} --master_addr=${master_addr} --master_port=15328 --node_rank=$1 benchmark.py \ 23 | --model_name ${model_name} \ 24 | --num_samples 10 \ 25 | --batch_size ${bssize} \ 26 | --prompt_length ${prompt_length} \ 27 | --max_new_tokens ${max_new_tokens} \ 28 | --compile \ 29 | --compile_prefill \ 30 | --device cuda 2>&1 | tee ${folder}/bs_${bssize}_tp_${tpsize}.log 31 | echo "Finished running with P2P_DISABLE=${P2P_DISABLE} bs=${bssize} tp=${tpsize}" 32 | done 33 | done 34 | done 35 | done -------------------------------------------------------------------------------- /scripts/throughput-70B.sh: -------------------------------------------------------------------------------- 1 | mode=compile 2 | nodenum=1 3 | prompt_length=1024 4 | max_new_tokens=512 5 | 6 | for P2P_DISABLE in 0 7 | do 8 | export NCCL_P2P_DISABLE=${P2P_DISABLE} 9 | for model_name in "gpt_dense:llama-3-70b" "gpt_ladder:llama-3-70b" "gpt_desync:llama-3-70b-upper-bound" "gpt_parallel:llama-3-70b" 10 | do 11 | folder=./logs/prompt_length_${prompt_length}_max_new_${max_new_tokens}/p2p_disable${P2P_DISABLE}/${mode}/${model_name} 12 | mkdir -p ${folder} 13 | for bssize in 1 4 16 64 14 | do 15 | for tpsize in 2 16 | do 17 | echo "Running with P2P_DISABLE=${P2P_DISABLE} bs=${bssize} tp=${tpsize}" 18 | ENABLE_INTRA_NODE_COMM=1 NCCL_NVLS_ENABLE=1 NCCL_P2P_DISABLE=${P2P_DISABLE} torchrun --standalone --nproc_per_node=${tpsize} --nnodes=${nodenum} --master_port=15328 benchmark.py \ 19 | --model_name ${model_name} \ 20 | --num_samples 10 \ 21 | --batch_size ${bssize} \ 22 | --prompt_length ${prompt_length} \ 23 | --max_new_tokens ${max_new_tokens} \ 24 | --compile \ 25 | --compile_prefill \ 26 | --device cuda 2>&1 | tee ${folder}/bs_${bssize}_tp_${tpsize}.log 27 | echo "Finished running with P2P_DISABLE=${P2P_DISABLE} bs=${bssize} tp=${tpsize}" 28 | done 29 | done 30 | done 31 | done -------------------------------------------------------------------------------- /scripts/throughput-8B.sh: -------------------------------------------------------------------------------- 1 | mode=compile 2 | nodenum=1 3 | prompt_length=1024 4 | max_new_tokens=512 5 | 6 | for P2P_DISABLE in 0 1 7 | do 8 | export NCCL_P2P_DISABLE=${P2P_DISABLE} 9 | for model_name in "gpt_dense:llama-3-8b" "gpt_ladder:llama-3-8b" "gpt_desync:llama-3-8b-upper-bound" "gpt_parallel:llama-3-8b" 10 | do 11 | folder=./logs/prompt_length_${prompt_length}_max_new_${max_new_tokens}/p2p_disable${P2P_DISABLE}/${mode}/${model_name} 12 | mkdir -p ${folder} 13 | for bssize in 1 4 16 64 14 | do 15 | for tpsize in 1 2 4 8 16 | do 17 | echo "Running with P2P_DISABLE=${P2P_DISABLE} bs=${bssize} tp=${tpsize}" 18 | ENABLE_INTRA_NODE_COMM=1 NCCL_NVLS_ENABLE=1 NCCL_P2P_DISABLE=${P2P_DISABLE} torchrun --standalone --nproc_per_node=${tpsize} --nnodes=${nodenum} --master_port=15328 benchmark.py \ 19 | --model_name ${model_name} \ 20 | --num_samples 10 \ 21 | --batch_size ${bssize} \ 22 | --prompt_length ${prompt_length} \ 23 | --max_new_tokens ${max_new_tokens} \ 24 | --compile \ 25 | --compile_prefill \ 26 | --device cuda 2>&1 | tee ${folder}/bs_${bssize}_tp_${tpsize}.log 27 | echo "Finished running with P2P_DISABLE=${P2P_DISABLE} bs=${bssize} tp=${tpsize}" 28 | done 29 | done 30 | done 31 | done -------------------------------------------------------------------------------- /scripts/throughput-bloom176b.sh: -------------------------------------------------------------------------------- 1 | mode=compile 2 | nodenum=1 3 | prompt_length=1024 
4 | max_new_tokens=512 5 | 6 | for P2P_DISABLE in 0 1 7 | do 8 | export NCCL_P2P_DISABLE=${P2P_DISABLE} 9 | for model_name in "gpt_dense:bloom-176b" "gpt_ladder:bloom-176b" 10 | do 11 | folder=./logs/12-14/prompt_length_${prompt_length}_max_new_${max_new_tokens}/p2p_disable${P2P_DISABLE}/${mode}/${model_name} 12 | mkdir -p ${folder} 13 | for bssize in 4 14 | do 15 | for tpsize in 8 16 | do 17 | echo "Running with P2P_DISABLE=${P2P_DISABLE} bs=${bssize} tp=${tpsize}" 18 | ENABLE_INTRA_NODE_COMM=1 NCCL_NVLS_ENABLE=1 NCCL_P2P_DISABLE=${P2P_DISABLE} torchrun --standalone --nproc_per_node=${tpsize} --nnodes=${nodenum} --master_port=15328 benchmark.py \ 19 | --model_name ${model_name} \ 20 | --num_samples 10 \ 21 | --batch_size ${bssize} \ 22 | --prompt_length ${prompt_length} \ 23 | --max_new_tokens ${max_new_tokens} \ 24 | --compile \ 25 | --compile_prefill \ 26 | --device cuda 2>&1 | tee ${folder}/bs_${bssize}_tp_${tpsize}.log 27 | echo "Finished running with P2P_DISABLE=${P2P_DISABLE} bs=${bssize} tp=${tpsize}" 28 | done 29 | done 30 | done 31 | done -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name="gpt-fast", 5 | version="0.1", 6 | packages=find_packages(), 7 | install_requires=[ 8 | "torch", 9 | ], 10 | description="A simple, fast, pure PyTorch Llama inference engine", 11 | long_description=open("README.md").read(), 12 | long_description_content_type="text/markdown", 13 | url="https://github.com/pytorch-labs/gpt-fast", 14 | ) 15 | -------------------------------------------------------------------------------- /tools/plot-405b.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | marker = "." 
4 | markersize = 8 5 | y_label = "tokens/sec speedup" 6 | x_label = "batch size" 7 | 8 | 9 | batch_sizes = [1, 4, 8, 16] 10 | ladder_nvl = [1.364019677, 1.307557841, 1.393313367, 1.349331588] 11 | parallel_nvl = [1.237526353, 1.242467866, 1.285679359, 1.272326351] 12 | upper_bound_nvl = [1.52073085, 1.511876607, 1.658883931, 1.631580761] 13 | 14 | ladder_no_nvl = [1.489855072, 1.565769112, 1.434301985, 1.461046963] 15 | parallel_no_nvl = [1.396618357, 1.439028165, 1.415040616, 1.398749487] 16 | upper_bound_no_nvl = [2.106763285, 2.274682761, 2.270580353, 2.15964584] 17 | 18 | plt.plot(batch_sizes, ladder_nvl, marker=marker, markersize=markersize, label="ladder transformer") 19 | plt.plot(batch_sizes, parallel_nvl, marker=marker, markersize=markersize, label="parallel attn") 20 | plt.plot(batch_sizes, upper_bound_nvl, marker=marker, markersize=markersize, label="upper bound") 21 | 22 | plt.ylim(1, 2.5) 23 | plt.xticks(batch_sizes) 24 | 25 | plt.title("P2P enabled (TP = 16)") 26 | plt.ylabel(y_label) 27 | plt.xlabel(x_label) 28 | plt.grid(True, linestyle=":", color="gray", linewidth=0.5) 29 | plt.legend(loc="upper center", ncol=3) 30 | plt.tight_layout() 31 | plt.savefig("405b-nvl.png", dpi=300) 32 | 33 | 34 | plt.figure() 35 | 36 | plt.plot(batch_sizes, ladder_no_nvl, marker=marker, markersize=markersize, label="ladder transformer") 37 | plt.plot(batch_sizes, parallel_no_nvl, marker=marker, markersize=markersize, label="parallel attn") 38 | plt.plot(batch_sizes, upper_bound_no_nvl, marker=marker, markersize=markersize, label="upper bound") 39 | 40 | plt.ylim(1, 2.5) 41 | plt.xticks(batch_sizes) 42 | 43 | plt.title("P2P disabled (TP = 16)") 44 | plt.ylabel(y_label) 45 | plt.xlabel(x_label) 46 | plt.grid(True, linestyle=":", color="gray", linewidth=0.5) 47 | plt.legend(loc="upper center", ncol=3) 48 | plt.tight_layout() 49 | plt.savefig("405b-no-nvl.png", dpi=300) 50 | -------------------------------------------------------------------------------- /tools/plot-70b.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | marker = "." 
4 | markersize = 8 5 | y_label = "tokens/sec" 6 | x_label = "TP world size" 7 | title = "batch size = {batch_size}, 70B model" 8 | tp_world_size = [2, 4, 8] 9 | 10 | blue = "#1f77b4" 11 | orange = "#ff7f0e" 12 | green = "#2ca02c" 13 | red = "#d62728" 14 | 15 | 16 | def plot( 17 | batch_size: int, 18 | standard_nvl: list[float], 19 | ladder_nvl: list[float], 20 | upper_bound_nvl: list[float], 21 | parallel_nvl: list[float], 22 | standard_no_nvl: list[float], 23 | ladder_no_nvl: list[float], 24 | parallel_no_nvl: list[float], 25 | ) -> None: 26 | plt.figure() 27 | 28 | plt.plot(tp_world_size, standard_nvl, marker=marker, markersize=markersize, label="standard transformer P2P=1", linestyle="-", color=blue) 29 | plt.plot(tp_world_size, standard_no_nvl, marker=marker, markersize=markersize, label="standard transformer P2P=0", linestyle="--", color=blue) 30 | plt.plot(tp_world_size, ladder_nvl, marker=marker, markersize=markersize, label="ladder transformer P2P=1", linestyle="-", color=orange) 31 | plt.plot(tp_world_size, ladder_no_nvl, marker=marker, markersize=markersize, label="ladder transformer P2P=0", linestyle="--", color=orange) 32 | plt.plot(tp_world_size, parallel_nvl, marker=marker, markersize=markersize, label="parallel attn P2P=1", linestyle="-", color=green) 33 | plt.plot(tp_world_size, parallel_no_nvl, marker=marker, markersize=markersize, label="parallel attn P2P=0", linestyle="--", color=green) 34 | plt.plot(tp_world_size, upper_bound_nvl, marker=marker, markersize=markersize, label="upper bound P2P=1", linestyle="-", color=red) 35 | 36 | plt.xticks(tp_world_size) 37 | plt.xlim(1.5, 8.5) 38 | 39 | plt.title(title.format(batch_size=batch_size)) 40 | plt.ylabel(y_label) 41 | plt.xlabel(x_label) 42 | plt.grid(True, linestyle=":", color="gray", linewidth=0.5) 43 | plt.legend(loc="upper left") 44 | plt.tight_layout() 45 | 46 | plt.savefig(f"70b-{batch_size}.png", dpi=300) 47 | 48 | plot( 49 | batch_size=1, 50 | standard_nvl=[35.42, 59.41, 77.39], 51 | ladder_nvl=[36.69, 67.51, 101.22], 52 | upper_bound_nvl=[38.24, 69.04, 110.59], 53 | parallel_nvl=[36.67, 65.4, 94.22], 54 | standard_no_nvl=[33.77, 53.35, 51.66], 55 | ladder_no_nvl=[36.94, 66.6, 82.59], 56 | parallel_no_nvl=[36.12, 61.93, 72.36], 57 | ) 58 | 59 | plot( 60 | batch_size=4, 61 | standard_nvl=[120.58, 185.01, 258.56], 62 | ladder_nvl=[126.09, 204.3, 331.45], 63 | upper_bound_nvl=[130.77, 213.73, 355.8], 64 | parallel_nvl=[123.66, 201.62, 307.34], 65 | standard_no_nvl=[106.6, 158.77, 173.62], 66 | ladder_no_nvl=[125.55, 204.24, 271.82], 67 | parallel_no_nvl=[116.83, 183.53, 241.08], 68 | ) 69 | 70 | plot( 71 | batch_size=16, 72 | standard_nvl=[float("nan"), 585.23, 843.15], 73 | ladder_nvl=[float("nan"), 635.98, 1003.52], 74 | upper_bound_nvl=[float("nan"), 665.07, 1109.65], 75 | parallel_nvl=[float("nan"), 628.23, 973.74], 76 | standard_no_nvl=[float("nan"), 518.41, 546.68], 77 | ladder_no_nvl=[float("nan"), 598.71, 738.56], 78 | parallel_no_nvl=[float("nan"), 583.48, 744.33], 79 | ) 80 | 81 | plot( 82 | batch_size=64, 83 | standard_nvl=[float("nan"), 1249.67, 1940.99], 84 | ladder_nvl=[float("nan"), 1358.65, 2242.1], 85 | upper_bound_nvl=[float("nan"), 1433.53, 2474.49], 86 | parallel_nvl=[float("nan"), 1311.54, 2259.38], 87 | standard_no_nvl=[float("nan"), 1199.62, 1454.42], 88 | ladder_no_nvl=[float("nan"), 1313.44, 1864.05], 89 | parallel_no_nvl=[float("nan"), 1276.66, 1873.71], 90 | ) 91 | --------------------------------------------------------------------------------