├── requirements.txt ├── tmp.py ├── Makefile ├── LICENSE ├── README.md ├── .gitignore ├── bench.py ├── test.py ├── gpt.py └── kernels.py
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | transformers
3 | numpy<2.0.0
4 | black
5 | isort
6 | ipdb
7 | 
--------------------------------------------------------------------------------
/tmp.py:
--------------------------------------------------------------------------------
1 | # TRITON_INTERPRET=1 python3 tmp.py
2 | # scratch script for quick manual checks of the fused GPT implementation
3 | 
4 | import torch
5 | 
6 | from gpt import FusedGPT, GPTConfig, estimate_days, get_num_parameters
7 | 
8 | device = "cuda" if torch.cuda.is_available() else "cpu"
9 | 
10 | torch.manual_seed(1337)
11 | 
12 | config = GPTConfig()
13 | model = FusedGPT(config).to(device).eval()
14 | 
15 | num_parameters = get_num_parameters(model)
16 | print("num parameters:", num_parameters)
17 | 
18 | # rough pre-training time for 10B tokens on 8 A100s at 45% MFU
19 | flops = 6 * num_parameters * 10 * 10**9
20 | print("training time (in days):", estimate_days(flops, mfu=0.45, gpu="a100", num_gpus=8))
21 | 
22 | # import ipdb; ipdb.set_trace()
23 | 
24 | x = torch.randint(0, config.vocab_size, size=(2, 16), device=device)
25 | with torch.no_grad():
26 |     z = model(x)
27 | print(z.shape, z)
28 | 
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | checkdirs := .
2 | 
3 | style:
4 | 	black $(checkdirs)
5 | 	isort $(checkdirs)
6 | 
7 | install:
8 | 	pip install jaxtyping
9 | 	pip install git+https://github.com/Deep-Learning-Profiling-Tools/triton-viz@v1
10 | 	wget "https://dl.cloudsmith.io/public/test-wha/triton-puzzles/raw/files/triton-3.0.0-cp310-cp310-linux_x86_64.whl"
11 | 	pip install triton-3.0.0-cp310-cp310-linux_x86_64.whl
12 | 	# export LC_ALL="en_US.UTF-8"
13 | 	# export LD_LIBRARY_PATH="/usr/lib64-nvidia"
14 | 	# export LIBRARY_PATH="/usr/local/cuda/lib64/stubs"
15 | 	# ldconfig /usr/lib64-nvidia
16 | 	pip3 install -r requirements.txt
17 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 Vasudev Gupta
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Triton implementation of GPT/LLAMA models. The objective of this project is to understand how much performance can be squeezed out if we implement the full GPT block in one triton kernel.
2 | 
3 | **Performance**
4 | 
5 | The triton implementation is faster and more memory efficient than the HuggingFace Transformers implementation.
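The comparison loads pre-trained GPT-2 weights into the fused model via `convert_hf_and_load_model`. Below is a minimal sketch of a forward pass against the HuggingFace reference (it mirrors `test_gpt2` in `test.py`; on CPU, prefix the command with `TRITON_INTERPRET=1` as in the tests):

```python
import torch
from transformers import AutoTokenizer

from gpt import convert_hf_and_load_model

device = "cuda" if torch.cuda.is_available() else "cpu"

# returns the triton-backed FusedGPT and the reference HuggingFace GPT2Model
model, hf_model = convert_hf_and_load_model("gpt2", device)
tokenizer = AutoTokenizer.from_pretrained("gpt2")

inputs = tokenizer("I am vasudev gupta. I like AI.", return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}

with torch.no_grad():
    z = model(inputs["input_ids"])  # (batch_size, seqlen, hidden_size)
    z_hf = hf_model(**inputs).last_hidden_state

print("max abs diff vs HF:", (z - z_hf).abs().max().item())
```

Run the end-to-end benchmark with: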
6 | 
7 | ```bash
8 | python3 bench.py
9 | ```
10 | 
11 | **Latency**
12 | 
13 | | precision              | HuggingFace GPT | Triton GPT |
14 | |------------------------|-----------------|------------|
15 | | fp32                   | 1800 ms         | -          |
16 | | tf32                   | 631.35 ms       | 462.63 ms  |
17 | | mixed precision (fp16) | 510.80 ms       | 273 ms     |
18 | | fp16                   | 301.92 ms       | -          |
19 | 
20 | _time taken to process a batch of shape 512x300 (batch size x seqlen) on 1 A100 40 GB_
21 | 
22 | **Max Batch Size**
23 | 
24 | |                        | max batch size |
25 | |------------------------|----------------|
26 | | HuggingFace GPT        | 1024           |
27 | | Triton GPT             | 2048           |
28 | 
29 | _I considered batch sizes that are powers of 2 only. Both runs had seqlen=300 and mixed precision was enabled._
30 | 
31 | **MFU**
32 | 
33 | ```python
34 | from gpt import compute_mfu
35 | # fwd MFU
36 | 
37 | # HuggingFace GPT (fp16)
38 | compute_mfu(2 * 124 * 10**6 * 512*512 / 0.302, gpu="h100")
39 | # 21.76%
40 | 
41 | # HuggingFace GPT (mixed precision)
42 | compute_mfu(2 * 124 * 10**6 * 512*512 / 0.510, gpu="h100")
43 | # 12.88%
44 | 
45 | # triton (mixed precision)
46 | compute_mfu(2 * 124 * 10**6 * 512*512 / 0.273, gpu="h100")
47 | # 24.07%
48 | ```
49 | 
50 | **Supported Features**
51 | * [x] fused implementation of several components of the GPT block (e.g. `dropout(wte(x) + wpe(x))`, `dropout(wx + b)`, `gelu(wx + b)`)
52 | * [x] flash attention v1 algorithm
53 | * [x] GPT2 implementation in triton
54 | * [x] support for loading pre-trained weights of huggingface-gpt2
55 | * [ ] support KV cache & sampling for inference loop
56 | * [ ] implement back-propagation of GPT block in triton (i.e. solving the math problem)
57 | * [ ] implement paged-attention from vLLM project in triton
58 | * [ ] implement flash attention v2 & v3
59 | * [ ] add kernels for LLAMA-3.1
60 | * [ ] implement adamw in triton (with FSDP-stage2 support)
61 | 
62 | **Installation**
63 | 
64 | ```bash
65 | pip3 install -r requirements.txt
66 | # `numpy<2` is a hard requirement for running on CPU
67 | # otherwise triton produces garbage output - likely a bug in triton
68 | ```
69 | 
70 | **Running tests**
71 | 
72 | ```bash
73 | # run the following command on CPU
74 | TRITON_INTERPRET=1 pytest -sv test.py
75 | 
76 | # run the following command on GPU
77 | pytest -sv test.py
78 | ```
79 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 | 
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ 163 | -------------------------------------------------------------------------------- /bench.py: -------------------------------------------------------------------------------- 1 | # TRITON_INTERPRET=1 python3 bench.py 2 | 3 | import torch 4 | import triton 5 | from transformers import AutoTokenizer 6 | from gpt import convert_hf_and_load_model 7 | 8 | STRING = """\ 9 | Large language models have been shown to achieve remarkable performance across a variety of natural\ 10 | language tasks using few-shot learning, which drastically reduces the number of task-specific training\ 11 | examples needed to adapt the model to a particular application. To further our understanding of the\ 12 | impact of scale on few-shot learning, we trained a 540-billion parameter, densely activated, Transformer\ 13 | language model, which we call Pathways Language Model (PaLM).\ 14 | We trained PaLM on 6144 TPU v4 chips using Pathways, a new ML system which enables highly efficient\ 15 | training across multiple TPU Pods. We demonstrate continued benefits of scaling by achieving state-ofthe-art few-shot learning results on hundreds of language understanding and generation benchmarks. On a\ 16 | number of these tasks, PaLM 540B achieves breakthrough performance, outperforming the finetuned stateof-the-art on a suite of multi-step reasoning tasks, and outperforming average human performance on the\ 17 | recently released BIG-bench benchmark. A significant number of BIG-bench tasks showed discontinuous\ 18 | improvements from model scale, meaning that performance steeply increased as we scaled to our largest\ 19 | model. PaLM also has strong capabilities in multilingual tasks and source code generation, which we\ 20 | demonstrate on a wide array of benchmarks. We additionally provide a comprehensive analysis on bias\ 21 | and toxicity, and study the extent of training data memorization with respect to model scale. 
Finally,\ 22 | we discuss the ethical considerations related to large language models and discuss potential mitigation\ 23 | strategies.\ 24 | """ 25 | 26 | def run_benchmark(provider, warmup=25, rep=100, mixed_precison=False): 27 | assert torch.cuda.is_available() 28 | device = "cuda" 29 | model_id = "gpt2" 30 | model, hf_model = convert_hf_and_load_model(model_id, device) 31 | if mixed_precison: 32 | model.to(torch.float16) 33 | # hf_model.to(torch.float16) 34 | tokenizer = AutoTokenizer.from_pretrained(model_id) 35 | # triton is slow for batch_size = 1 with current settings but much faster with batch > 1 36 | inputs = tokenizer([STRING] * 512, return_tensors="pt", max_length=512, truncation=True) 37 | inputs = {k: v.to(device) for k, v in inputs.items()} 38 | with torch.no_grad(): 39 | # z_torch = hf_model(**inputs).last_hidden_state 40 | # z = model(inputs["input_ids"]) 41 | # print("diff:", z - z_torch) 42 | if provider == "torch": 43 | def fn(): 44 | if mixed_precison: 45 | with torch.autocast(device_type="cuda", dtype=torch.float16): 46 | return hf_model(**inputs).last_hidden_state 47 | else: 48 | return hf_model(**inputs).last_hidden_state 49 | return triton.testing.do_bench(fn, warmup=warmup, rep=rep) 50 | if provider == "triton": 51 | fn = lambda: model(inputs["input_ids"]) 52 | return triton.testing.do_bench(fn, warmup=warmup, rep=rep) 53 | 54 | # 1 A100 40 GB 55 | # torch: batch_size = 512 && t = 1801.32 56 | # triton: batch_size = 512 && t = 789.14 57 | # torch: batch_size = 1024 && OOM 58 | # triton: batch_size = 2048 && t = 3153.70 59 | 60 | print("triton:", run_benchmark("triton")) 61 | print("torch:", run_benchmark("torch")) 62 | 63 | # OLD SUMMARY 64 | # fp32 65 | # torch: 1800 66 | # triton: 789.14 67 | 68 | # mixed precision 69 | # torch: 510.80 70 | # triton: 429.80 71 | 72 | # fp16 73 | # torch: 301.92 74 | 75 | # triton with mixed precison = False 76 | # ffn cast enabled: 791.13 77 | # flash cast enabled: 759.71 78 | # num_warps = 8 & BLOCK_SIZE = 64 ffn :: 759.18 79 | # num_warps = 8 & BLOCK_SIZE = 128 ffn :: 463.80 80 | # layer norm BLOCK_SIZE = 32768 :: 832.63 81 | # layer norm BLOCK_SIZE = 512 :: 462.61 82 | # embeddings BLOCK_SIZE = 512 :: 462.87 83 | # attention BLOCK_SIZE = 128 & num_stages = 4 :: 1279.38 84 | # attention BLOCK_SIZE = 128 & num_stages = 8 :: 460.27 85 | # final config: embeddings (512, 4) + layer norm (512, 4) + ffn (128, 128, 64, 8) + attention (128, 8) 86 | 87 | # mixed precision = True 88 | # triton: 273.61 89 | # with attention (128, 8), t = 900 but with attention (64, 4), t = 273! 
90 | 91 | # mixed precision = False 92 | # torch.backends.cuda.matmul.allow_tf32 = True 93 | # torch.backends.cudnn.allow_tf32 = True 94 | # torch: 623.3262329101562 95 | 96 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # TRITON_INTERPRET=1 pytest -sv test.py 2 | 3 | import math 4 | 5 | import pytest 6 | import torch 7 | import torch.nn as nn 8 | from transformers import AutoTokenizer 9 | from transformers.activations import ACT2FN 10 | 11 | from gpt import (FusedGPT, GPTConfig, convert_hf_and_load_model, estimate_days, 12 | get_num_parameters) 13 | from kernels import (flash_attention_v1, fused_embeddings, fused_ffn, 14 | fused_layer_norm) 15 | 16 | 17 | def _get_inputs(M, K, N, device): 18 | torch.manual_seed(1337) 19 | x = torch.rand((M, K), device=device, dtype=torch.float32) 20 | w = torch.rand((K, N), device=device, dtype=torch.float32) 21 | b = torch.rand((N,), device=device, dtype=torch.float32) 22 | r = torch.rand_like(x, dtype=torch.float32) 23 | if K != N: 24 | r = r_torch = None 25 | return x, w, b, r 26 | 27 | 28 | @pytest.mark.parametrize("vocab_size", [2, 32]) 29 | @pytest.mark.parametrize("batch_size", [8]) 30 | @pytest.mark.parametrize("hidden_size", [32, 128, 256]) 31 | @pytest.mark.parametrize("seqlen, block_size", [(10, 20), (20, 20)]) 32 | def test_fused_embeddings(batch_size, seqlen, vocab_size, block_size, hidden_size): 33 | device = "cuda" if torch.cuda.is_available() else "cpu" 34 | 35 | x = torch.randint( 36 | 0, vocab_size, size=(batch_size, seqlen), dtype=torch.long, device=device 37 | ) 38 | wte = torch.rand((vocab_size, hidden_size), device=device) 39 | wpe = torch.rand((block_size, hidden_size), device=device) 40 | 41 | z_torch = wte[x] + wpe[torch.arange(x.shape[1], device=device)][None] 42 | z = fused_embeddings(x, wte, wpe) 43 | 44 | assert torch.allclose(z, z_torch, atol=1e-5), (z - z_torch).abs().max() 45 | 46 | 47 | @pytest.mark.parametrize("M", [249, 32]) 48 | @pytest.mark.parametrize("K", [123, 128, 64]) 49 | def test_fused_layer_norm(M, K): 50 | N = 32 51 | device = "cuda" if torch.cuda.is_available() else "cpu" 52 | x, *_ = _get_inputs(M, K, N, device) 53 | x_torch, *_ = _get_inputs(M, K, N, device) 54 | 55 | layer_norm = nn.LayerNorm(K).to(device) 56 | x_torch = layer_norm(x_torch) 57 | x = fused_layer_norm(x, layer_norm.weight.data, layer_norm.bias.data) 58 | 59 | assert torch.allclose(x, x_torch, atol=1e-5), (x - x_torch).abs().max() 60 | 61 | 62 | def torch_ffn(x, w, b=None, r=None): 63 | z = x @ w 64 | if b is not None: 65 | z += b 66 | z = ACT2FN["gelu_new"](z) 67 | if r is not None: 68 | z += r 69 | return z 70 | 71 | 72 | @pytest.mark.parametrize("M,N,K", [(128, 128, 256), (199, 129, 129), (61, 31, 23)]) 73 | @pytest.mark.parametrize("add_gelu", [True, False]) 74 | @pytest.mark.parametrize("add_bias", [True, False]) 75 | def test_fused_ffn(M, N, K, add_gelu, add_bias): 76 | device = "cuda" if torch.cuda.is_available() else "cpu" 77 | x_torch, w_torch, b_torch, r_torch = _get_inputs(M, K, N, device) 78 | x, w, b, r = _get_inputs(M, K, N, device) 79 | 80 | if not add_bias: 81 | b_torch = None 82 | b = None 83 | 84 | z_torch = torch_ffn(x_torch, w_torch, b=b_torch, r=r_torch) 85 | 86 | z = fused_ffn(x, w, bias=b, residual=r, add_gelu=True) 87 | assert torch.allclose(z, z_torch, atol=1e-5), (z - z_torch).abs().max() 88 | 89 | 90 | def _get_attn_inputs(B, N, L, H, device): 91 | torch.manual_seed(1337) 92 | q = 
torch.rand((B, N, L, H), device=device) 93 | k = torch.rand_like(q) 94 | v = torch.rand_like(q) 95 | return q, k, v 96 | 97 | 98 | def torch_attention(q, k, v): 99 | assert q.shape == k.shape == v.shape 100 | B, N, L, H = q.shape 101 | q, k, v = map(lambda x: x.view(B * N, L, H), (q, k, v)) 102 | z = (q @ k.transpose(1, 2)) / math.sqrt(H) 103 | attn_mask = torch.tril(torch.ones((L, L), dtype=torch.bool)) 104 | z = torch.where(attn_mask, z, float("-inf")) 105 | z = z.softmax(-1) @ v 106 | return z.view(B, N, L, H) 107 | 108 | 109 | @pytest.mark.parametrize("B,N", [(3, 9), (2, 7)]) 110 | @pytest.mark.parametrize("L", [199, 128, 63]) 111 | @pytest.mark.parametrize("H", [64, 128, 256]) 112 | def test_flash_attention_v1(B, N, L, H): 113 | device = "cuda" if torch.cuda.is_available() else "cpu" 114 | q, k, v = _get_attn_inputs(B, N, L, H, device) 115 | z_torch = torch_attention(q, k, v) 116 | z = flash_attention_v1(q, k, v) 117 | assert torch.allclose(z, z_torch, atol=1e-5), (z - z_torch).abs().max() 118 | 119 | 120 | def test_gpt2(): 121 | device = "cuda" if torch.cuda.is_available() else "cpu" 122 | model_id = "gpt2" 123 | model, hf_model = convert_hf_and_load_model(model_id, device) 124 | tokenizer = AutoTokenizer.from_pretrained(model_id) 125 | with torch.no_grad(): 126 | string = "I am vasudev gupta. I like AI." 127 | inputs = tokenizer(string, return_tensors="pt") 128 | inputs = {k: v.to(device) for k, v in inputs.items()} 129 | hf_out = hf_model(**inputs).last_hidden_state 130 | out = model(inputs["input_ids"]) 131 | print((out - hf_out).abs()) 132 | # TODO: need to look at why we can't do low precision 133 | assert torch.allclose(out, hf_out, atol=1e-1), (out - hf_out).abs().max() 134 | 135 | 136 | def test_flops(): 137 | config = GPTConfig() 138 | model = FusedGPT(config).eval() 139 | num_tokens = 1024 140 | fwd_flops = model.get_fwd_flops(num_tokens) 141 | total_flops = fwd_flops * 3 142 | num_parameters = get_num_parameters(model) 143 | r = (fwd_flops * 3) / (6 * num_parameters * num_tokens) 144 | assert r >= 0.9995, r 145 | 146 | 147 | def test_estimate_days(): 148 | # llama-3.1 paper reports 54 days for pre-training 405B parameter model 149 | # its very close to what we get from following equation 150 | flops = 6 * (405 * 10**9) * (15 * 10**12) 151 | t = estimate_days(flops, mfu=0.45, gpu="h100", num_gpus=16_000) 152 | assert t == 59.24544994944388, t 153 | -------------------------------------------------------------------------------- /gpt.py: -------------------------------------------------------------------------------- 1 | # TRITON_INTERPRET=1 python3 gpt.py 2 | 3 | from dataclasses import dataclass 4 | 5 | import torch 6 | import torch.nn as nn 7 | from tqdm.auto import tqdm 8 | from transformers import AutoTokenizer 9 | from transformers import GPT2Model as HFGPT2 10 | 11 | from kernels import (flash_attention_v1, fused_embeddings, fused_ffn, 12 | fused_layer_norm, matmul_and_split_qkv) 13 | 14 | GPU_TO_FLOPS = { 15 | "v100": 130 * 10**12, 16 | "a100": 312 * 10**12, 17 | "h100": 989 * 10**12, 18 | } 19 | 20 | 21 | class FusedAttention(nn.Module): 22 | def __init__(self, hidden_size, num_heads, dropout_prob=0.0): 23 | super().__init__() 24 | self.dropout_prob = dropout_prob 25 | self.num_heads = num_heads 26 | 27 | self.hidden_size = hidden_size 28 | 29 | self.layer_norm_weight = nn.Parameter(torch.ones(hidden_size)) 30 | self.layer_norm_bias = nn.Parameter(torch.zeros(hidden_size)) 31 | 32 | self.c_attn_weight = nn.Parameter(torch.rand(hidden_size, 3 * hidden_size)) 33 | 
self.c_attn_bias = nn.Parameter(torch.rand(3 * hidden_size)) 34 | 35 | self.c_proj_weight = nn.Parameter(torch.rand(hidden_size, hidden_size)) 36 | self.c_proj_bias = nn.Parameter(torch.rand(hidden_size)) 37 | 38 | def forward(self, x): 39 | residual = x 40 | x = fused_layer_norm(x, self.layer_norm_weight.data, self.layer_norm_bias.data) 41 | q, k, v = matmul_and_split_qkv( 42 | x, self.c_attn_weight.data, self.c_attn_bias.data, self.num_heads 43 | ) 44 | dropout_prob = self.dropout_prob if self.training else 0.0 45 | x = flash_attention_v1( 46 | q, 47 | k, 48 | v, 49 | dropout_prob=dropout_prob, 50 | ) 51 | x = x.transpose(1, 2).contiguous().view(residual.shape) 52 | x = fused_ffn( 53 | x, 54 | self.c_proj_weight.data, 55 | bias=self.c_proj_bias.data, 56 | residual=residual, 57 | add_gelu=False, 58 | dropout_prob=dropout_prob, 59 | ) 60 | return x 61 | 62 | def get_fwd_flops(self, num_tokens): 63 | h = self.hidden_size 64 | layer_norm = num_tokens * h + num_tokens * h 65 | c_attn = num_tokens * (3 * h) * (2 * h) + num_tokens * (3 * h) 66 | c_proj = num_tokens * h * (2 * h) + num_tokens * h 67 | return layer_norm + c_attn + c_proj 68 | 69 | 70 | class FusedMLP(nn.Module): 71 | def __init__(self, hidden_size, dropout_prob=0.0): 72 | super().__init__() 73 | 74 | self.dropout_prob = dropout_prob 75 | 76 | self.layer_norm_weight = nn.Parameter(torch.ones((hidden_size,))) 77 | self.layer_norm_bias = nn.Parameter(torch.zeros((hidden_size,))) 78 | 79 | intermediate_size = 4 * hidden_size 80 | 81 | self.ffn1_weight = nn.Parameter(torch.rand(hidden_size, intermediate_size)) 82 | self.ffn1_bias = nn.Parameter(torch.rand(intermediate_size)) 83 | 84 | self.ffn2_weight = nn.Parameter(torch.rand(intermediate_size, hidden_size)) 85 | self.ffn2_bias = nn.Parameter(torch.rand(hidden_size)) 86 | 87 | self.hidden_size = hidden_size 88 | self.intermediate_size = intermediate_size 89 | 90 | def forward(self, x): 91 | # mlp = DROPOUT(GELU(LN(X) @ A + a) @ B + b) + X 92 | dropout_prob = self.dropout_prob if self.training else 0.0 93 | residual = x 94 | x = fused_layer_norm(x, self.layer_norm_weight.data, self.layer_norm_bias.data) 95 | x = fused_ffn( 96 | x, 97 | self.ffn1_weight.data, 98 | bias=self.ffn1_bias.data, 99 | residual=None, 100 | add_gelu=True, 101 | dropout_prob=dropout_prob, 102 | ) 103 | x = fused_ffn( 104 | x, 105 | self.ffn2_weight.data, 106 | bias=self.ffn2_bias.data, 107 | residual=residual, 108 | add_gelu=False, 109 | dropout_prob=dropout_prob, 110 | ) 111 | return x 112 | 113 | def get_fwd_flops(self, num_tokens): 114 | h = self.hidden_size 115 | mid = self.intermediate_size 116 | layer_norm = num_tokens * h + num_tokens * h 117 | ffn1 = num_tokens * mid * (2 * h) + num_tokens * mid 118 | ffn2 = num_tokens * h * (2 * mid) + num_tokens * h 119 | return layer_norm + ffn1 + ffn2 120 | 121 | 122 | @dataclass 123 | class GPTConfig: 124 | vocab_size: int = 50304 125 | block_size: int = 512 126 | n_layer: int = 12 127 | n_head: int = 12 128 | n_embd: int = 768 129 | dropout: float = 0.1 130 | 131 | 132 | class FusedGPT(nn.Module): 133 | def __init__(self, config): 134 | super().__init__() 135 | self.config = config 136 | 137 | self.wte_weight = nn.Parameter(torch.rand(config.vocab_size, config.n_embd)) 138 | self.wpe_weight = nn.Parameter(torch.rand(config.block_size, config.n_embd)) 139 | 140 | self.blocks = nn.ModuleList( 141 | [ 142 | nn.Sequential( 143 | FusedAttention( 144 | config.n_embd, 145 | config.n_head, 146 | dropout_prob=config.dropout, 147 | ), 148 | FusedMLP( 149 | config.n_embd, 150 
|                         dropout_prob=config.dropout,
151 |                     ),
152 |                 )
153 |                 for _ in range(config.n_layer)
154 |             ]
155 |         )
156 |         self.layer_norm_weight = nn.Parameter(torch.ones((config.n_embd,)))
157 |         self.layer_norm_bias = nn.Parameter(torch.zeros((config.n_embd,)))
158 | 
159 |         # TODO: we don't want to consume 2x memory here because of transpose and contiguous
160 |         # instead implement transposed matmul in triton kernel
161 |         # self.lm_head_weight = self.wte.weight.data.T.contiguous()
162 | 
163 |     def forward(self, x):
164 |         # attention is causal by default, so no separate attention/padding mask is needed
165 |         dropout_prob = self.config.dropout if self.training else 0.0
166 |         x = fused_embeddings(
167 |             x, self.wte_weight.data, self.wpe_weight.data, dropout_prob=dropout_prob
168 |         )
169 |         for block in self.blocks:
170 |             x = block(x)
171 |         x = fused_layer_norm(x, self.layer_norm_weight, self.layer_norm_bias)
172 |         # x = fused_ffn(
173 |         #     x,
174 |         #     self.lm_head_weight,
175 |         #     bias=None,
176 |         #     residual=None,
177 |         #     add_gelu=False,
178 |         #     dropout_prob=0.0,
179 |         # )
180 |         return x
181 | 
182 |     def get_fwd_flops(self, num_tokens):
183 |         h = self.config.n_embd
184 |         v = self.config.vocab_size
185 |         p = self.config.block_size
186 |         wte = num_tokens * h * (2 * v)
187 |         wpe = num_tokens * h * (2 * p)
188 |         blocks = sum(
189 |             [
190 |                 module.get_fwd_flops(num_tokens)
191 |                 for block in self.blocks
192 |                 for module in block
193 |             ]
194 |         )
195 |         layer_norm = num_tokens * h + num_tokens * h
196 |         return blocks + layer_norm + wte + wpe
197 | 
198 | 
199 | def convert_huggingface_to_triton(hf_sd, hf_config):
200 |     config = GPTConfig(
201 |         vocab_size=hf_config.vocab_size,
202 |         block_size=hf_config.n_ctx,
203 |         n_layer=hf_config.n_layer,
204 |         n_head=hf_config.n_head,
205 |         n_embd=hf_config.n_embd,
206 |         dropout=0.1,
207 |     )
208 |     mapping = {
209 |         "wte.weight": "wte_weight",
210 |         "wpe.weight": "wpe_weight",
211 |         "ln_f.weight": "layer_norm_weight",
212 |         "ln_f.bias": "layer_norm_bias",
213 |     }
214 |     block = {
215 |         "h.{i}.ln_1.weight": "blocks.{i}.0.layer_norm_weight",
216 |         "h.{i}.ln_1.bias": "blocks.{i}.0.layer_norm_bias",
217 |         "h.{i}.attn.bias": None,
218 |         "h.{i}.attn.c_attn.weight": "blocks.{i}.0.c_attn_weight",
219 |         "h.{i}.attn.c_attn.bias": "blocks.{i}.0.c_attn_bias",
220 |         "h.{i}.attn.c_proj.weight": "blocks.{i}.0.c_proj_weight",
221 |         "h.{i}.attn.c_proj.bias": "blocks.{i}.0.c_proj_bias",
222 |         "h.{i}.ln_2.weight": "blocks.{i}.1.layer_norm_weight",
223 |         "h.{i}.ln_2.bias": "blocks.{i}.1.layer_norm_bias",
224 |         "h.{i}.mlp.c_fc.weight": "blocks.{i}.1.ffn1_weight",
225 |         "h.{i}.mlp.c_fc.bias": "blocks.{i}.1.ffn1_bias",
226 |         "h.{i}.mlp.c_proj.weight": "blocks.{i}.1.ffn2_weight",
227 |         "h.{i}.mlp.c_proj.bias": "blocks.{i}.1.ffn2_bias",
228 |     }
229 |     for k, v in block.items():
230 |         if v is None:
231 |             continue
232 |         for i in range(config.n_layer):
233 |             mapping[k.format(i=i)] = v.format(i=i)
234 |     sd = {}
235 |     for k, v in tqdm(hf_sd.items()):
236 |         sd[mapping[k]] = v
237 |     return sd, config
238 | 
239 | 
240 | def convert_hf_and_load_model(model_id, device):
241 |     hf_model = HFGPT2.from_pretrained(model_id)
242 |     state_dict, config = convert_huggingface_to_triton(
243 |         hf_model.state_dict(), hf_model.config
244 |     )
245 |     model = FusedGPT(config)
246 |     model.load_state_dict(state_dict)
247 |     return model.to(device).eval(), hf_model.to(device).eval()
248 | 
249 | 
250 | def estimate_days(flops, mfu=0.45, gpu="h100", num_gpus=1):
251 |     # it's probably very hard to achieve 0.45 MFU - LOL
252 |     # but that's roughly 
SOTA in papers from top labs 253 | assert gpu in GPU_TO_FLOPS 254 | return flops / (mfu * GPU_TO_FLOPS[gpu] * 3600 * 24 * num_gpus) 255 | 256 | 257 | def get_num_parameters(model): 258 | return sum([p.numel() for p in model.parameters()]) 259 | 260 | 261 | def compute_mfu(flops_per_second, gpu="h100"): 262 | assert gpu in GPU_TO_FLOPS 263 | return flops_per_second / GPU_TO_FLOPS[gpu] 264 | -------------------------------------------------------------------------------- /kernels.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import triton 5 | import triton.language as tl 6 | 7 | # torch becomes 3x faster with following lines for fp32 8 | torch.backends.cuda.matmul.allow_tf32 = True 9 | torch.backends.cudnn.allow_tf32 = True 10 | 11 | 12 | # TODO: shift to `make_block_ptr`? 13 | 14 | 15 | # tl.math.tanh doesn't exist in CPU version of triton 16 | @triton.jit 17 | def tanh(x): 18 | return 2 * tl.sigmoid(2 * x) - 1 19 | 20 | 21 | @triton.jit 22 | def gelu_new(x): 23 | pi = math.pi 24 | a = tl.math.sqrt(2.0 / pi) 25 | b = x + 0.044715 * x * x * x 26 | return 0.5 * x * (1.0 + tanh(a * b)) 27 | 28 | 29 | # TODO: fixed seed would hurt the performance 30 | # but how do we modify seed design wise? 31 | @triton.jit 32 | def dropout(x, p, seed, offset): 33 | random = tl.rand(seed, offset) 34 | return tl.where(random > p, x / (1 - p), 0.0) 35 | 36 | 37 | @triton.jit 38 | def fused_embeddings_kernel( 39 | x_ptr, 40 | wte_ptr, 41 | wpe_ptr, 42 | z_ptr, 43 | B, 44 | L, 45 | V, 46 | P, 47 | H, 48 | dropout_prob=0.0, 49 | seed=1337, 50 | BLOCK_SIZE: tl.constexpr = 512, 51 | ): 52 | # f = dropout(wte(x) + wpe(x)) 53 | 54 | # x: (B*S,) 55 | # wte: (V, H) 56 | # wpe: (P, H) 57 | # z: (B*S, H) 58 | 59 | pid = tl.program_id(0) 60 | wte_ptr += tl.load(x_ptr + pid) * H 61 | wpe_ptr += (pid % L) * H 62 | z_ptr += pid * H 63 | 64 | for k in range(0, H, BLOCK_SIZE): 65 | offset = k + tl.arange(0, BLOCK_SIZE) 66 | mask = offset < H 67 | 68 | z = tl.load(wte_ptr + offset, mask=mask, other=0.0) 69 | z += tl.load(wpe_ptr + offset, mask=mask, other=0.0) 70 | z = dropout(z, dropout_prob, seed, offset) 71 | 72 | tl.store(z_ptr + offset, z, mask=mask) 73 | 74 | 75 | @torch.no_grad() 76 | def fused_embeddings(x, wte, wpe, dropout_prob=0.0): 77 | # x: (batch_size, seqlen) 78 | # wte: (vocab_size, hidden_size) 79 | # wpe: (block_size, hidden_size) 80 | assert wte.shape[1] == wpe.shape[1] 81 | assert x.is_contiguous() 82 | assert wte.is_contiguous() 83 | assert wpe.is_contiguous() 84 | B, L = x.shape 85 | V, H = wte.shape 86 | P = wpe.shape[0] 87 | z = torch.empty((B * L, H), device=x.device, dtype=wte.dtype) 88 | grid = (z.shape[0],) 89 | fused_embeddings_kernel[grid]( 90 | x.view(-1), 91 | wte, 92 | wpe, 93 | z, 94 | B, 95 | L, 96 | V, 97 | P, 98 | H, 99 | dropout_prob=dropout_prob, 100 | ) 101 | return z.view((B, L, H)) 102 | 103 | 104 | @triton.jit 105 | def fused_layer_norm_kernel( 106 | x_ptr, w_ptr, b_ptr, z_ptr, H, eps=1e-5, BLOCK_SIZE: tl.constexpr = 512 107 | ): 108 | # f = ((x - mean) / (std + eps)) * w + b 109 | # x: (M, H) 110 | # launch with 1D grid along M direction 111 | 112 | row_id = tl.program_id(0) 113 | x_ptr += row_id * H 114 | z_ptr += row_id * H 115 | 116 | x_mean = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) 117 | for i in range(0, H, BLOCK_SIZE): 118 | offset = i + tl.arange(0, BLOCK_SIZE) 119 | x = tl.load(x_ptr + offset, mask=(offset < H), other=0.0) 120 | x_mean += x.to(tl.float32) 121 | x_mean = tl.sum(x_mean) / H 122 | 123 
| x_var = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) 124 | for i in range(0, H, BLOCK_SIZE): 125 | offset = i + tl.arange(0, BLOCK_SIZE) 126 | x = tl.load(x_ptr + offset, mask=(offset < H), other=x_mean) 127 | x = x.to(tl.float32) 128 | x_var += (x - x_mean) * (x - x_mean) 129 | x_var = tl.sum(x_var) / H 130 | rstd = 1 / tl.sqrt(x_var + eps) 131 | 132 | # TODO: we could prevent this extra loop if we fuse it in ffn block? 133 | # but thats quite hacky - so, lets move with extra loop for now 134 | for i in range(0, H, BLOCK_SIZE): 135 | offset = i + tl.arange(0, BLOCK_SIZE) 136 | mask = offset < H 137 | 138 | x = tl.load(x_ptr + offset, mask=mask, other=0.0) 139 | w = tl.load(w_ptr + offset, mask=mask, other=0.0) 140 | b = tl.load(b_ptr + offset, mask=mask, other=0.0) 141 | 142 | z = (x - x_mean) * rstd 143 | z = z * w + b 144 | 145 | tl.store(z_ptr + offset, z, mask=mask) 146 | 147 | 148 | @torch.no_grad() 149 | def fused_layer_norm(x, weight, bias): 150 | # x: (*, hidden_size) 151 | # weight: (hidden_size,) 152 | # bias: (hidden_size,) 153 | assert x.is_contiguous() 154 | assert weight.is_contiguous() 155 | assert bias.is_contiguous() 156 | assert weight.shape == bias.shape 157 | assert x.shape[-1] == weight.shape[0] 158 | out_shape = x.shape 159 | x = x.view((-1, x.shape[-1])) 160 | B, H = x.shape 161 | x = x.view((B, H)) 162 | z = torch.empty(x.shape, device=x.device, dtype=x.dtype) 163 | fused_layer_norm_kernel[(B,)](x, weight, bias, z, H) 164 | return z.view(out_shape) 165 | 166 | 167 | # TODO: implement grouping for extra 10% speedup 168 | # also, need to understand what's gemm matmul 169 | @triton.jit 170 | def fused_ffn_kernel( 171 | x_ptr, 172 | w_ptr, 173 | z_ptr, 174 | M, 175 | N, 176 | K, 177 | b_ptr=None, 178 | r_ptr=None, 179 | apply_gelu=False, 180 | dropout_prob=0.0, 181 | seed=1337, 182 | BLOCK_SIZE_M: tl.constexpr = 128, 183 | BLOCK_SIZE_N: tl.constexpr = 128, 184 | BLOCK_SIZE_K: tl.constexpr = 64, 185 | ): 186 | # f = dropout(gelu(x @ w + b)) + residual 187 | # launch with 2D grid of blocks along M & N directions 188 | 189 | pid_m = tl.program_id(0) 190 | pid_n = tl.program_id(1) 191 | 192 | # intuition is this: In normal math, we basically take 1 row of X & 1 column of W 193 | # and just multiply element wise and add stuff 194 | # but here we add multiple consecutive rows of X & multiple consecutive rows of W 195 | # and do dot product basically 196 | 197 | # pid_m: vertical 198 | # pid_n: horizontal 199 | 200 | # we basically move over output matrix and computes each block in each kernel 201 | 202 | # x: (M, K) 203 | # w: (K, N) 204 | # b: (N,) 205 | # z: (M, N) 206 | 207 | # x block size: (BLOCK_SIZE_M, BLOCK_SIZE_K) 208 | # w block size: (BLOCK_SIZE_K, BLOCK_SIZE_N) 209 | # z block size: (BLOCK_SIZE_M, BLOCK_SIZE_N) 210 | 211 | # these are the pointer of 1st element for each block in output matrix 212 | 213 | # we basically add row-block-shift here 214 | offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)[:, None] 215 | 216 | # we basically add column-block-shift here 217 | offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)[None, :] 218 | 219 | # each block in z would be of shape-(M, N) 220 | # block of size: BLOCK_SIZE_M x BLOCK_SIZE_K would move in horizontal direction 221 | # block of size: BLOCK_SIZE_K x BLOCK_SIZE_N would move in vertical direction 222 | 223 | # we need this loop because we might not be able to fit full row of X & full column of W in-memory 224 | z = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) 225 | for k in range(0, K, 
BLOCK_SIZE_K): 226 | x_k = tl.arange(0, BLOCK_SIZE_K)[None, :] + k 227 | x = tl.load(x_ptr + offs_m * K + x_k, mask=(offs_m < M) & (x_k < K), other=0.0) 228 | # TODO: need to read why casting to fp16 is important here 229 | x = x.to(tl.float16) 230 | # (BLOCK_SIZE_M, BLOCK_SIZE_K) 231 | 232 | w_k = tl.arange(0, BLOCK_SIZE_K)[:, None] + k 233 | w = tl.load(w_ptr + w_k * N + offs_n, mask=(w_k < K) & (offs_n < N), other=0.0) 234 | w = w.to(tl.float16) 235 | # (BLOCK_SIZE_K, BLOCK_SIZE_N) 236 | 237 | z = tl.dot(x, w, acc=z) 238 | # (BLOCK_SIZE_M, BLOCK_SIZE_N) 239 | 240 | if b_ptr is not None: 241 | b = tl.load(b_ptr + offs_n, mask=(offs_n < N), other=0.0) 242 | z += b.to(tl.float32) 243 | # (1, BLOCK_SIZE_N) 244 | 245 | z_offset = offs_m * N + offs_n 246 | z_mask = (offs_m < M) & (offs_n < N) 247 | 248 | if apply_gelu: 249 | z = gelu_new(z) 250 | if dropout_prob > 0.0: 251 | z = dropout(z, dropout_prob, seed, z_offset) 252 | 253 | if r_ptr is not None: 254 | r = tl.load(r_ptr + z_offset, mask=z_mask) 255 | z += r.to(tl.float32) 256 | 257 | tl.store(z_ptr + z_offset, z, mask=z_mask) 258 | 259 | 260 | @torch.no_grad() 261 | def fused_ffn( 262 | x, 263 | weight, 264 | bias=None, 265 | residual=None, 266 | add_gelu=False, 267 | dropout_prob=0.0, 268 | ): 269 | # x: (*, K) 270 | # weight: (K, N) 271 | # bias: (N,) 272 | # f = dropout(gelu(x @ w + b)) + residual 273 | 274 | out_shape_0 = x.shape[:-1] 275 | x = x.view((-1, x.shape[-1])) 276 | 277 | M, K = x.shape 278 | N = weight.shape[1] 279 | 280 | x = x.view((M, K)) 281 | z = torch.empty((M, N), device=x.device, dtype=x.dtype) 282 | 283 | assert x.is_contiguous() 284 | assert weight.is_contiguous() 285 | assert x.shape[1] == weight.shape[0] 286 | if bias is not None: 287 | assert bias.is_contiguous() 288 | assert weight.shape[1] == bias.shape[0] 289 | if residual is not None: 290 | residual = residual.view(z.shape) 291 | assert residual.is_contiguous() 292 | 293 | # (128, 128, 64) leads to 6x slowdown with num_stages == 4 294 | # while its 40% faster with num_stages = 8 295 | BLOCK_SIZE_M = 128 296 | BLOCK_SIZE_N = 128 297 | BLOCK_SIZE_K = 64 298 | grid = (triton.cdiv(M, BLOCK_SIZE_M), triton.cdiv(N, BLOCK_SIZE_N), 1) 299 | fused_ffn_kernel[grid]( 300 | x, 301 | weight, 302 | z, 303 | M, 304 | N, 305 | K, 306 | apply_gelu=add_gelu, 307 | dropout_prob=dropout_prob, 308 | b_ptr=bias, 309 | r_ptr=residual, 310 | BLOCK_SIZE_M=BLOCK_SIZE_M, 311 | BLOCK_SIZE_N=BLOCK_SIZE_N, 312 | BLOCK_SIZE_K=BLOCK_SIZE_K, 313 | num_warps=8, 314 | ) 315 | return z.view((*out_shape_0, N)) 316 | 317 | 318 | # @triton.jit 319 | # def softmax_kernel(x_ptr, z_ptr, L, N1, H, BLOCK_SIZE_L: tl.constexpr, B1: tl.constexpr): 320 | # # x: (L, H) 321 | # # out: (L, H) 322 | # pid_0 = tl.program_id(0) 323 | # x_ptr += pid_0 * H 324 | # z_ptr += pid_0 * H 325 | # max_value, denominator = 0., 0. 
326 | # for i in range(0, H, B1): 327 | # offset = tl.arange(i, i + B1) 328 | # x = tl.load(x_ptr + offset, mask=offset < H, other=0) 329 | # block_max_value = tl.max(x, keep_dims=True) 330 | # new_max_value = tl.where( 331 | # block_max_value > max_value, block_max_value, max_value 332 | # ) 333 | # x = tl.exp(x - new_max_value) 334 | # denominator = denominator / tl.exp(new_max_value - max_value) 335 | # denominator += tl.sum(x) 336 | # max_value = new_max_value 337 | # for i in range(0, H, B1): 338 | # offset = tl.arange(i, i + B1) 339 | # x = tl.load(x_ptr + offset, mask=offset < H, other=0) 340 | # z = tl.exp(x - max_value) 341 | # z = z / denominator 342 | # tl.store(z_ptr + offset, z, mask=offset < H) 343 | 344 | 345 | # TODO: what if we just write separate kernel for this? 346 | # TODO: can we fuse this in attention kernel? 347 | @torch.no_grad() 348 | def matmul_and_split_qkv(x, weight, bias, num_heads): 349 | # x: (batch_size, seqlen, hidden_size) 350 | x = fused_ffn(x, weight, bias=bias) 351 | # (batch_size, seqlen, 3 * hidden_size) 352 | batch_size, seqlen, hidden_size = x.shape 353 | assert hidden_size % 3 == 0, hidden_size 354 | hidden_size = hidden_size // 3 355 | q, k, v = x.split(hidden_size, dim=2) 356 | assert hidden_size % num_heads == 0, (hidden_size, num_heads) 357 | head_size = hidden_size // num_heads 358 | # (batch_size, seqlen, num_heads, head_size) 359 | # TODO: following is unecessary read & write - memory bound operation 360 | q, k, v = map( 361 | lambda x: x.view(batch_size, seqlen, num_heads, head_size) 362 | .transpose(1, 2) 363 | .contiguous(), 364 | (q, k, v), 365 | ) 366 | # (batch_size, num_heads, seqlen, head_size) 367 | return q, k, v 368 | 369 | 370 | # TODO: does triton re-compile when different tl.constexpr is passed? 371 | # TODO: read about flash-2 and see if we can switch to that 372 | # TODO: then read about flash-3 and see if we can switch to that instead 373 | # TODO: can we do score computation for only unmasked positions? 
374 | # pytorch flex-attention does something like that - it would make computation 50% efficient 375 | @triton.jit 376 | def flash_attention_v1_kernel( 377 | q_ptr, 378 | k_ptr, 379 | v_ptr, 380 | z_ptr, 381 | BN, 382 | Lq, 383 | Lk, 384 | scale, 385 | H: tl.constexpr, 386 | dropout_prob=0.0, 387 | seed=1337, 388 | BLOCK_SIZE_L: tl.constexpr = 64, 389 | ): 390 | # f = (q @ k.T) / math.sqrt(head_size) 391 | # f = dropout(F.softmax(apply_causal_mask(f), dim=-1)) 392 | # f = f @ v 393 | 394 | # q, z: (B * N, Lq, H) 395 | # k, v: (B * N, Lk, H) 396 | 397 | q_ptr += tl.program_id(0) * (Lq * H) 398 | z_ptr += tl.program_id(0) * (Lq * H) 399 | k_ptr += tl.program_id(0) * (Lk * H) 400 | v_ptr += tl.program_id(0) * (Lk * H) 401 | 402 | # assuming that `H` can stay SRAM fully and doesn't require blocking 403 | # this assumptions was made for original implementation of flash attention as well 404 | # its reasonable as most of LLMs use head size <= 256 405 | offs_lq = tl.program_id(1) * BLOCK_SIZE_L + tl.arange(0, BLOCK_SIZE_L) 406 | offs_h = tl.arange(0, H) 407 | 408 | q_mask = offs_lq[:, None] < Lq 409 | q_offs = offs_lq[:, None] * H + offs_h[None, :] 410 | # this remains in sram throughtout computation 411 | q = tl.load(q_ptr + q_offs, mask=q_mask, other=0.0) 412 | # (BLOCK_SIZE_L, H) 413 | 414 | q = q.to(tl.float16) 415 | 416 | # loop over k, v and compute attention & weighted v 417 | z = tl.zeros((BLOCK_SIZE_L, H), dtype=tl.float32) 418 | max_value = tl.zeros((BLOCK_SIZE_L, 1), dtype=tl.float32) + float("-inf") 419 | denominator = tl.zeros((BLOCK_SIZE_L, 1), dtype=tl.float32) 420 | for i in range(0, Lk, BLOCK_SIZE_L): 421 | offs_lk = i + tl.arange(0, BLOCK_SIZE_L) 422 | kv_mask = offs_lk[:, None] < Lk 423 | kv_offs = offs_lk[:, None] * H + offs_h[None, :] 424 | 425 | k = tl.load(k_ptr + kv_offs, mask=kv_mask, other=0.0) 426 | # (BLOCK_SIZE_L, H) 427 | 428 | k = k.to(q.dtype) 429 | qk = tl.dot(q, k.trans(1, 0)) * scale 430 | # (BLOCK_SIZE_L, BLOCK_SIZE_L) 431 | 432 | # TODO: remove eventually, its for debugging 433 | # qk_offs = offs_lq[:, None] * Lk + offs_lk[None, :] 434 | # tl.store(z_ptr + qk_offs, qk) 435 | 436 | # apply causal mask ; we still compute the attention over the future blocks 437 | # we wanna optimise that eventually 438 | qk = tl.where(offs_lq[:, None] >= offs_lk[None, :], qk, float("-inf")) 439 | 440 | block_max_value = tl.max(qk, axis=1, keep_dims=True) 441 | # (BLOCK_SIZE_L, 1) 442 | new_max_value = tl.where( 443 | block_max_value > max_value, block_max_value, max_value 444 | ) 445 | # (BLOCK_SIZE_L, 1) 446 | 447 | qk = tl.exp(qk - new_max_value) 448 | # (BLOCK_SIZE_L, BLOCK_SIZE_L) 449 | 450 | multiplier = tl.exp(max_value - new_max_value) 451 | denominator *= multiplier 452 | z *= multiplier 453 | 454 | denominator += tl.sum(qk, axis=1, keep_dims=True) 455 | max_value = new_max_value 456 | # (BLOCK_SIZE_L, 1) 457 | 458 | if dropout_prob > 0.0: 459 | qk_offs = offs_lq[:, None] * Lk + offs_lk[None, :] 460 | qk = dropout(qk, dropout_prob, seed, qk_offs) 461 | 462 | v = tl.load(v_ptr + kv_offs, mask=kv_mask, other=0.0) 463 | # (BLOCK_SIZE_L, H) 464 | 465 | v = v.to(q.dtype) 466 | qk = qk.to(q.dtype) 467 | 468 | z = tl.dot(qk, v, acc=z) 469 | # (BLOCK_SIZE_L, H) 470 | 471 | z /= denominator 472 | z = z.to(z_ptr.dtype.element_ty) 473 | 474 | tl.store(z_ptr + q_offs, z, mask=q_mask) 475 | 476 | 477 | @torch.no_grad() 478 | def flash_attention_v1(q, k, v, dropout_prob=0.0): 479 | # (batch_size, num_heads, seqlen, head_size) 480 | assert q.shape[:2] == k.shape[:2] 481 | assert 
q.shape[-1] == k.shape[-1] 482 | assert k.shape == v.shape 483 | # B: batch_size 484 | # N: num_heads 485 | # L: seqlen 486 | # H: head_size 487 | B, N, Lq, H = q.shape 488 | Lk = k.shape[2] 489 | 490 | assert H in {16, 32, 64, 128, 256} 491 | # above condition is necessary because shared memory is limited 492 | # and we don't do additional blocking over head_size dim 493 | 494 | q = q.view(B * N, Lq, H) 495 | k = k.view(B * N, Lk, H) 496 | v = v.view(B * N, Lk, H) 497 | 498 | z = torch.empty_like(q) 499 | 500 | # z = torch.rand((B * N, Lq, Lk), dtype=q.dtype, device=q.device) 501 | 502 | assert q.is_contiguous() 503 | assert k.is_contiguous() 504 | assert v.is_contiguous() 505 | assert z.is_contiguous() 506 | 507 | scale = 1 / math.sqrt(H) 508 | 509 | BLOCK_SIZE_L = 64 510 | grid = (B * N, triton.cdiv(Lq, BLOCK_SIZE_L), 1) 511 | flash_attention_v1_kernel[grid]( 512 | q, 513 | k, 514 | v, 515 | z, 516 | B * N, 517 | Lq, 518 | Lk, 519 | scale, 520 | H, 521 | dropout_prob=dropout_prob, 522 | BLOCK_SIZE_L=BLOCK_SIZE_L, 523 | # num_warps=8, 524 | ) 525 | return z.view(B, N, Lq, H) 526 | --------------------------------------------------------------------------------
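A quick standalone sanity check for `flash_attention_v1` is sketched below; it mirrors `torch_attention` in `test.py`. The shapes are arbitrary examples, with the constraint from `kernels.py` that the head size must be one of {16, 32, 64, 128, 256}:

```python
import math

import torch

from kernels import flash_attention_v1

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(1337)

B, N, L, H = 2, 4, 128, 64  # batch_size, num_heads, seqlen, head_size
q = torch.rand((B, N, L, H), device=device)
k = torch.rand_like(q)
v = torch.rand_like(q)

# reference: causal softmax(q @ k.T / sqrt(H)) @ v
scores = (q @ k.transpose(-1, -2)) / math.sqrt(H)
mask = torch.tril(torch.ones((L, L), dtype=torch.bool, device=device))
ref = torch.where(mask, scores, float("-inf")).softmax(-1) @ v

z = flash_attention_v1(q, k, v)
print("max abs diff:", (z - ref).abs().max().item())  # test.py checks this with atol=1e-5
```

As with the test suite, this runs on CPU under `TRITON_INTERPRET=1` or directly on a GPU.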