├── pyproject.toml ├── benchmarking ├── __init__.py ├── benchmark_utils.py └── profiler.py ├── .gitignore ├── models └── llama │ ├── requirements.txt │ ├── __init__.py │ ├── llama │ ├── __init__.py │ ├── math_ops.py │ ├── tokenizer.py │ ├── model.py │ └── generation.py │ ├── example_text_completion.py │ └── example_chat_completion.py ├── kernels ├── blocksparse │ ├── __init__.py │ ├── softmax.py │ └── matmul.py ├── __init__.py ├── cross_entropy.py ├── matmul_perf_model.py ├── matmul.py └── flash_attention.py ├── test ├── conftest.py ├── test_cross_entropy.py ├── test_flash_attention.py ├── test_inductor.py ├── test_blocksparse.py └── test_matmul.py ├── LICENSE └── main.py /pyproject.toml: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /benchmarking/__init__.py: -------------------------------------------------------------------------------- 1 | from .profiler import Profiler 2 | from .benchmark_utils import compare_benchmarks 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python caches 2 | __pycache__/ 3 | *.py[cod] 4 | .pytest_cache 5 | **/.cache 6 | **/meta-llama/**/* 7 | -------------------------------------------------------------------------------- /models/llama/requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | huggingface-hub 3 | fire 4 | blobfile 5 | tiktoken 6 | fairscale 7 | triton 8 | pandas 9 | -------------------------------------------------------------------------------- /kernels/blocksparse/__init__.py: -------------------------------------------------------------------------------- 1 | from .matmul import matmul 2 | from .softmax import softmax 3 | 4 | __all__ = [ 5 | "matmul", 6 | "softmax", 7 | ] 8 | -------------------------------------------------------------------------------- /models/llama/__init__.py: -------------------------------------------------------------------------------- 1 | from .example_chat_completion import main as llama_example_chat_completion 2 | from .example_text_completion import main as llama_example_text_completion 3 | -------------------------------------------------------------------------------- /test/conftest.py: -------------------------------------------------------------------------------- 1 | # content of conftest.py 2 | 3 | 4 | def pytest_configure(config): 5 | config.addinivalue_line( 6 | "markers", "interpreter: indicate whether interpreter supports the test" 7 | ) 8 | -------------------------------------------------------------------------------- /models/llama/llama/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. 3 | 4 | from .generation import Llama 5 | from .model import ModelArgs, Transformer 6 | from .tokenizer import Dialog, Tokenizer 7 | -------------------------------------------------------------------------------- /kernels/__init__.py: -------------------------------------------------------------------------------- 1 | # from .conv import _conv, conv 2 | from . 
import blocksparse 3 | from .cross_entropy import _cross_entropy, cross_entropy 4 | from .flash_attention import attention 5 | from .matmul import _matmul, get_higher_dtype, matmul 6 | 7 | __all__ = [ 8 | "blocksparse", 9 | "_cross_entropy", 10 | "cross_entropy", 11 | "_matmul", 12 | "matmul", 13 | "attention", 14 | "get_higher_dtype", 15 | ] 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWAR 20 | -------------------------------------------------------------------------------- /benchmarking/benchmark_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | import pandas as pd 3 | 4 | 5 | def compare_benchmarks(benchmarks: Dict[str, Dict[str, Any]]) -> Dict[str, Any]: 6 | series_dict = {k: pd.Series(v.values()) for k, v in benchmarks.items()} 7 | series_dict["kernel_path"] = pd.Series( 8 | benchmarks[list(benchmarks.keys())[0]].keys() 9 | ) 10 | series_dict["kernel"] = pd.Series( 11 | [k.split(".")[-1] for k in series_dict["kernel_path"]] 12 | ) 13 | df = pd.DataFrame() 14 | 15 | for k, v in series_dict.items(): 16 | df[k] = v 17 | columns = [c for c in df.columns if not "kernel" in c] 18 | for i in range(len(columns)): 19 | for j in range(i + 1, len(columns)): 20 | # calculate the difference between the two columns 21 | diff_col_name = f"{columns[i]}-{columns[j]}" 22 | df[diff_col_name] = df[columns[i]] - df[columns[j]] 23 | df.sort_values(by="kernel_path", inplace=True) 24 | columns = [c for c in df.columns if not "kernel" in c] 25 | columns = ["kernel", "kernel_path"] + columns 26 | df = df[columns] 27 | df.set_index("kernel", inplace=True) 28 | return df 29 | -------------------------------------------------------------------------------- /test/test_cross_entropy.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | import triton 5 | import triton.ops 6 | 7 | 8 | @pytest.mark.parametrize( 9 | "M, N, dtype, mode", 10 | [ # 11 | (M, N, dtype, mode) 12 | for M in [1024, 821] 13 | for N in [512, 857, 1871, 2089, 8573, 31000] 14 | for dtype in ["float16", "float32"] 15 | for mode in ["forward", "backward"] 16 | ], 17 | ) 18 | def test_op(M, N, dtype, mode, device): 19 | capability = torch.cuda.get_device_capability() 20 | if 
capability[0] < 8 and dtype == "bfloat16": 21 | pytest.skip("Only test bfloat16 on devices with sm >= 80") 22 | dtype = { 23 | "bfloat16": torch.bfloat16, 24 | "float16": torch.float16, 25 | "float32": torch.float32, 26 | }[dtype] 27 | # create inputs 28 | x = torch.randn(M, N, dtype=dtype, device=device, requires_grad=True) 29 | idx = 4 + torch.ones(M, dtype=torch.int64, device=device) 30 | # forward pass 31 | tt_y = triton.ops.cross_entropy(x, idx) 32 | th_y = torch.nn.CrossEntropyLoss(reduction="none")(x, idx) 33 | if mode == "forward": 34 | torch.testing.assert_close(th_y, tt_y) 35 | # backward pass 36 | elif mode == "backward": 37 | dy = torch.randn_like(tt_y) 38 | # triton backward 39 | tt_y.backward(dy) 40 | tt_dx = x.grad.clone() 41 | # torch backward 42 | x.grad = None 43 | th_y.backward(dy) 44 | th_dx = x.grad.clone() 45 | if dtype == torch.float16: 46 | torch.testing.assert_close(th_dx, tt_dx, rtol=0.001, atol=0.001) 47 | else: 48 | torch.testing.assert_close(th_dx, tt_dx) 49 | -------------------------------------------------------------------------------- /models/llama/example_text_completion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. 3 | 4 | from typing import List 5 | 6 | from .llama import Llama 7 | from benchmarking import Profiler 8 | 9 | 10 | @Profiler.profiling_decorator(record_name="text_completion", skip_profiling=True) 11 | def main( 12 | ckpt_dir: str, 13 | tokenizer_path: str, 14 | use_triton: bool = False, 15 | temperature: float = 0.6, 16 | top_p: float = 0.9, 17 | max_seq_len: int = 128, 18 | max_gen_len: int = 64, 19 | max_batch_size: int = 4, 20 | suppress_prints: bool = False, 21 | ): 22 | """ 23 | Examples to run with the pre-trained models (no fine-tuning). Prompts are 24 | usually in the form of an incomplete text prefix that the model can then try to complete. 25 | 26 | The context window of llama3 models is 8192 tokens, so `max_seq_len` needs to be <= 8192. 27 | `max_gen_len` is needed because pre-trained models usually do not stop completions naturally. 
28 | """ 29 | generator = Llama.build( 30 | ckpt_dir=ckpt_dir, 31 | tokenizer_path=tokenizer_path, 32 | max_seq_len=max_seq_len, 33 | max_batch_size=max_batch_size, 34 | use_triton=use_triton, 35 | ) 36 | 37 | prompts: List[str] = [ 38 | # For these prompts, the expected answer is the natural continuation of the prompt 39 | "I believe the meaning of life is", 40 | "Simply put, the theory of relativity states that ", 41 | """A brief message congratulating the team on the launch: 42 | 43 | Hi everyone, 44 | 45 | I just """, 46 | # Few shot prompt (providing a few examples before asking model to complete more); 47 | """Translate English to French: 48 | 49 | sea otter => loutre de mer 50 | peppermint => menthe poivrée 51 | plush girafe => girafe peluche 52 | cheese =>""", 53 | ] 54 | results = generator.text_completion( 55 | prompts, 56 | max_gen_len=max_gen_len, 57 | temperature=temperature, 58 | top_p=top_p, 59 | ) 60 | if suppress_prints: 61 | return 62 | for prompt, result in zip(prompts, results): 63 | print(prompt) 64 | print(f"> {result['generation']}") 65 | print("\n==================================\n") 66 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import os 3 | 4 | import fire 5 | import torch 6 | 7 | from models.llama import llama_example_chat_completion, llama_example_text_completion 8 | from benchmarking import Profiler, compare_benchmarks 9 | import pprint 10 | 11 | 12 | def main(operation: str, profile=False, benchmark=False, **kwargs): 13 | """ 14 | All kwargs are passed to the operation you choose. 15 | 16 | The profile and benchmark flags can be set independently of each other, 17 | *but* if you set both then profiling will be done on both runs (non-Triton and Triton). 18 | """ 19 | p = Profiler(profile, benchmark) 20 | profiles = {} 21 | benchmarks = {} 22 | if benchmark: 23 | # warm-up run 24 | torch.cuda.empty_cache() 25 | kwargs["suppress_prints"] = True 26 | p = Profiler(False, False) 27 | runner(operation, kwargs) 28 | 29 | kwargs["use_triton"] = False 30 | Profiler.reset() 31 | p = Profiler(profile, benchmark) 32 | torch.cuda.empty_cache() 33 | runner(operation, kwargs) 34 | benchmarks["non_triton"] = Profiler.get_benchmark_vals() 35 | profiles["non_triton"] = Profiler.get_profiling_data() 36 | Profiler.reset() 37 | p = Profiler(profile, benchmark) 38 | 39 | kwargs["use_triton"] = True 40 | kwargs["suppress_prints"] = False 41 | Profiler.reset() 42 | p = Profiler(profile, benchmark) 43 | torch.cuda.empty_cache() 44 | runner(operation, kwargs) 45 | benchmarks["triton"] = Profiler.get_benchmark_vals() 46 | profiles["triton"] = Profiler.get_profiling_data() 47 | elif profile: 48 | runner(operation, kwargs) 49 | data = Profiler.get_profiling_data() 50 | if kwargs["use_triton"]: 51 | profiles["triton"] = data 52 | else: 53 | profiles["non_triton"] = data 54 | else: 55 | runner(operation, kwargs) 56 | 57 | if profile: 58 | for k, v in profiles.items(): 59 | print(f"Profile for {k}") 60 | pprint.pprint(v, width=160) 61 | print("\n==================================\n") 62 | if benchmark: 63 | print("Benchmark results") 64 | output = compare_benchmarks(benchmarks) 65 | print(output) 66 | print("\n==================================\n") 67 | 68 | 69 | def runner(operation: str, kwargs): 70 | if operation == "llama_chat_completion": 71 | llama_example_chat_completion(**kwargs) 72 | elif operation == "llama_text_completion": 73 | 
llama_example_text_completion(**kwargs) 74 | else: 75 | raise ValueError(f"Unknown operation: {operation}") 76 | 77 | 78 | if __name__ == "__main__": 79 | os.environ["RANK"] = "0" 80 | os.environ["WORLD_SIZE"] = "1" 81 | os.environ["MASTER_ADDR"] = "127.0.0.1" 82 | os.environ["MASTER_PORT"] = "29500" 83 | fire.Fire(main) 84 | -------------------------------------------------------------------------------- /benchmarking/profiler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import contextlib 3 | import time 4 | from collections import defaultdict 5 | 6 | 7 | class Profiler: 8 | _instance = None 9 | 10 | def __new__(cls, should_profile: bool = False, benchmark: bool = False): 11 | if cls._instance is None: 12 | cls._instance = super().__new__(cls) 13 | cls._instance.profiler = ( 14 | torch.profiler.profile( 15 | record_shapes=True, 16 | with_flops=True, 17 | profile_memory=True, 18 | with_stack=True, 19 | with_modules=True, 20 | ) 21 | if should_profile 22 | else None 23 | ) 24 | cls._instance.benchmark = benchmark 25 | cls._instance.benchmark_vals = defaultdict(list) 26 | cls._instance.function_stack = [] 27 | 28 | return cls._instance 29 | 30 | @classmethod 31 | def reset(cls): 32 | cls._instance = None 33 | 34 | @classmethod 35 | def profiling_decorator( 36 | cls, 37 | record_name: str = None, 38 | skip_profiling: bool = False, 39 | skip_benchmark: bool = False, 40 | ): 41 | def inner(func): 42 | def wrapper(*args, **kwargs): 43 | if not cls._instance or (skip_profiling and skip_benchmark): 44 | return func(*args, **kwargs) 45 | cls._instance.function_stack.append(record_name or func.__name__) 46 | name = ".".join(cls._instance.function_stack) 47 | if cls._instance.profiler and not skip_profiling: 48 | cls._instance.profiler.start() 49 | start_time = time.perf_counter() 50 | 51 | with torch.profiler.record_function(name): 52 | result = func(*args, **kwargs) 53 | 54 | end_time = time.perf_counter() 55 | if cls._instance.benchmark and not skip_benchmark: 56 | cls._instance.benchmark_vals[name].append(end_time - start_time) 57 | if cls._instance.profiler and not skip_profiling: 58 | cls._instance.profiler.stop() 59 | cls._instance.function_stack.pop() 60 | return result 61 | 62 | return wrapper 63 | 64 | return inner 65 | 66 | @classmethod 67 | def step(cls): 68 | if cls._instance and cls._instance.profiler: 69 | cls._instance.profiler.step() 70 | 71 | @classmethod 72 | def get_benchmark_vals(cls): 73 | if cls._instance and cls._instance.benchmark: 74 | return {k: sum(v) / len(v) for k, v in cls._instance.benchmark_vals.items()} 75 | return None 76 | 77 | @classmethod 78 | def get_profiling_data(cls): 79 | if cls._instance and cls._instance.profiler: 80 | return cls._instance.profiler.key_averages() 81 | return None 82 | -------------------------------------------------------------------------------- /models/llama/example_chat_completion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
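# Usage: this module's `main` is normally reached through the repository entry
# point, main.py, which exposes it via python-fire as the "llama_chat_completion"
# operation. A sketch of an invocation (the checkpoint and tokenizer paths are
# placeholders for your local Llama 3 download):
#
#   python main.py llama_chat_completion \
#       --ckpt_dir=<path-to-checkpoint-dir> \
#       --tokenizer_path=<path-to-tokenizer.model> \
#       --use_triton=True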
3 | from typing import List, Optional 4 | 5 | from .llama import Dialog, Llama 6 | from benchmarking import Profiler 7 | 8 | 9 | @Profiler.profiling_decorator(record_name="chat_completion", skip_profiling=True) 10 | def main( 11 | ckpt_dir: str, 12 | tokenizer_path: str, 13 | use_triton: bool = False, 14 | temperature: float = 0.6, 15 | top_p: float = 0.9, 16 | max_seq_len: int = 512, 17 | max_batch_size: int = 4, 18 | max_gen_len: Optional[int] = None, 19 | suppress_prints: bool = False, 20 | ): 21 | """ 22 | Examples to run with the models finetuned for chat. Prompts correspond of chat 23 | turns between the user and assistant with the final one always being the user. 24 | 25 | An optional system prompt at the beginning to control how the model should respond 26 | is also supported. 27 | 28 | The context window of llama3 models is 8192 tokens, so `max_seq_len` needs to be <= 8192. 29 | 30 | `max_gen_len` is optional because finetuned models are able to stop generations naturally. 31 | """ 32 | generator = Llama.build( 33 | ckpt_dir=ckpt_dir, 34 | tokenizer_path=tokenizer_path, 35 | max_seq_len=max_seq_len, 36 | max_batch_size=max_batch_size, 37 | use_triton=use_triton, 38 | ) 39 | 40 | dialogs: List[Dialog] = [ 41 | [{"role": "user", "content": "what is the recipe of mayonnaise?"}], 42 | [ 43 | {"role": "user", "content": "I am going to Paris, what should I see?"}, 44 | { 45 | "role": "assistant", 46 | "content": """\ 47 | Paris, the capital of France, is known for its stunning architecture, art museums, historical landmarks, and romantic atmosphere. Here are some of the top attractions to see in Paris: 48 | 49 | 1. The Eiffel Tower: The iconic Eiffel Tower is one of the most recognizable landmarks in the world and offers breathtaking views of the city. 50 | 2. The Louvre Museum: The Louvre is one of the world's largest and most famous museums, housing an impressive collection of art and artifacts, including the Mona Lisa. 51 | 3. Notre-Dame Cathedral: This beautiful cathedral is one of the most famous landmarks in Paris and is known for its Gothic architecture and stunning stained glass windows. 52 | 53 | These are just a few of the many attractions that Paris has to offer. 
With so much to see and do, it's no wonder that Paris is one of the most popular tourist destinations in the world.""", 54 | }, 55 | {"role": "user", "content": "What is so great about #1?"}, 56 | ], 57 | [ 58 | {"role": "system", "content": "Always answer with Haiku"}, 59 | {"role": "user", "content": "I am going to Paris, what should I see?"}, 60 | ], 61 | [ 62 | { 63 | "role": "system", 64 | "content": "Always answer with emojis", 65 | }, 66 | {"role": "user", "content": "How to go from Beijing to NY?"}, 67 | ], 68 | ] 69 | results = generator.chat_completion( 70 | dialogs, 71 | max_gen_len=max_gen_len, 72 | temperature=temperature, 73 | top_p=top_p, 74 | ) 75 | 76 | if suppress_prints: 77 | return 78 | for dialog, result in zip(dialogs, results): 79 | for msg in dialog: 80 | print(f"{msg['role'].capitalize()}: {msg['content']}\n") 81 | print( 82 | f"> {result['generation']['role'].capitalize()}: {result['generation']['content']}" 83 | ) 84 | print("\n==================================\n") 85 | -------------------------------------------------------------------------------- /kernels/cross_entropy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from triton import heuristics, jit 4 | from triton import language as tl 5 | from triton import next_power_of_2 6 | 7 | 8 | def num_warps(N): 9 | if N < 2048: 10 | return 4 11 | elif N < 8192: 12 | return 8 13 | return 16 14 | 15 | 16 | @heuristics({"num_warps": lambda nargs: num_warps(nargs["N"])}) 17 | @heuristics({"BLOCK": lambda nargs: next_power_of_2(nargs["N"])}) 18 | @jit 19 | def _forward(LOGITS, PROBS, IDX, LOSS, N, BLOCK: tl.constexpr): 20 | row = tl.program_id(0) 21 | cols = tl.arange(0, BLOCK) 22 | idx = tl.load(IDX + row) 23 | # pointers to logit and probs 24 | LOGITS = LOGITS + row * N + cols 25 | WRIT_PROBS = PROBS + row * N + cols 26 | READ_PROBS = PROBS + row * N + idx 27 | # write-back negative log-probs 28 | logits = tl.load(LOGITS, mask=cols < N, other=-float("inf")) 29 | logits = logits.to(tl.float32) 30 | logits = logits - tl.max(logits, 0) 31 | probs = tl.log(tl.sum(tl.exp(logits), 0)) - logits 32 | tl.store(WRIT_PROBS, probs, mask=cols < N) 33 | # There is a bug in the compiler, which fails to insert a barrier here. 34 | # We add it explicitly for now. Will be fixed soon. 35 | tl.debug_barrier() 36 | # write-back loss 37 | probs = tl.load(READ_PROBS) 38 | tl.store(LOSS + row, probs) 39 | 40 | 41 | @heuristics({"num_warps": lambda nargs: num_warps(nargs["N"])}) 42 | @heuristics({"BLOCK": lambda nargs: next_power_of_2(nargs["N"])}) 43 | @jit 44 | def _backward(PROBS, IDX, DPROBS, N, BLOCK: tl.constexpr): 45 | row = tl.program_id(0) 46 | cols = tl.arange(0, BLOCK) 47 | idx = tl.load(IDX + row) 48 | # pointers to probs 49 | PROBS = PROBS + row * N + cols 50 | # We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k] 51 | # and we have -log(p[k]) stored in PROBS, so this is easy 52 | probs = -tl.load(PROBS, mask=cols < N, other=float("inf")) 53 | probs = tl.exp(probs.to(tl.float32)) 54 | delta = cols == idx 55 | # write result in-place in PROBS 56 | dout = tl.load(DPROBS + row) 57 | din = (probs - delta) * dout 58 | tl.store(PROBS, din.to(PROBS.dtype.element_ty), mask=cols < N) 59 | 60 | 61 | class _cross_entropy(torch.autograd.Function): 62 | 63 | @classmethod 64 | def forward(cls, ctx, logits, indices): 65 | # make sure we can use triton 66 | assert indices.dtype == torch.int64, "Indices are expected to be of type long." 
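        # The @heuristics attached to _forward/_backward above pick
        # BLOCK = next_power_of_2(N) and a warp count from N, so each Triton
        # program instance processes one full row of logits; the grid launched
        # below therefore contains one program per row (numel // n_cols).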
67 | # make kernel 68 | device, dtype = logits.device, logits.dtype 69 | n_cols = logits.shape[-1] 70 | # run the kernel 71 | result = torch.empty_like(indices, dtype=dtype, device=device) 72 | neg_logprobs = torch.empty_like(logits, dtype=dtype, device=device) 73 | grid = lambda opt: (logits.numel() // n_cols,) 74 | _forward[grid](logits, neg_logprobs, indices, result, n_cols) 75 | # save for backward 76 | ctx.save_for_backward(neg_logprobs, indices) 77 | return result 78 | 79 | @classmethod 80 | def backward(cls, ctx, dneg_logprobs): 81 | """We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k] 82 | so we initialize the gradient as neg_logprobs, so we can just exponentiate 83 | to get p[k], which is most of what we need... neg_logprobs will be 84 | modified in place to become the gradient we want 85 | """ 86 | # load saved tensors 87 | neg_logprobs, indices = ctx.saved_tensors 88 | # run the kernel 89 | # neg_logprobs will be modified in place to become our gradient: 90 | n_cols = neg_logprobs.shape[-1] 91 | grid = lambda opt: (neg_logprobs.numel() // n_cols,) 92 | _backward[grid](neg_logprobs, indices, dneg_logprobs, n_cols) 93 | return neg_logprobs, None 94 | 95 | 96 | cross_entropy = _cross_entropy.apply 97 | -------------------------------------------------------------------------------- /test/test_flash_attention.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import os 4 | 5 | import triton 6 | import triton.ops 7 | 8 | 9 | @pytest.mark.interpreter 10 | @pytest.mark.parametrize( 11 | "Z, H, N_CTX, D_HEAD", 12 | [ # 13 | (2, 4, 512, 16), 14 | (2, 4, 512, 32), 15 | (2, 4, 512, 64), 16 | (2, 4, 512, 128), 17 | ], 18 | ) 19 | @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) 20 | @pytest.mark.parametrize("causal", [True, False]) 21 | @pytest.mark.parametrize("seq_par", [True, False]) 22 | def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par, device): 23 | capability = torch.cuda.get_device_capability() 24 | if capability[0] < 8: 25 | pytest.skip("Flash attention only supported for compute capability >= 80") 26 | if dtype == torch.bfloat16 and os.environ.get("TRITON_INTERPRET", "0") == "1": 27 | pytest.skip("Flash attention bfloat16 not supported in interpreter mode") 28 | torch.manual_seed(20) 29 | q = ( 30 | torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device=device) 31 | .normal_(mean=0.0, std=0.5) 32 | .requires_grad_() 33 | ) 34 | k = ( 35 | torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device=device) 36 | .normal_(mean=0.0, std=0.5) 37 | .requires_grad_() 38 | ) 39 | v = ( 40 | torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device=device) 41 | .normal_(mean=0.0, std=0.5) 42 | .requires_grad_() 43 | ) 44 | sm_scale = 0.5 45 | dout = torch.randn_like(q) 46 | # reference implementation 47 | M = torch.tril(torch.ones((N_CTX, N_CTX), device=device)) 48 | p = torch.matmul(q, k.transpose(2, 3)) * sm_scale 49 | if causal: 50 | p[:, :, M == 0] = float("-inf") 51 | p = torch.softmax(p.float(), dim=-1).to(dtype) 52 | # p = torch.exp(p) 53 | ref_out = torch.matmul(p, v) 54 | ref_out.backward(dout) 55 | ref_dv, v.grad = v.grad.clone(), None 56 | ref_dk, k.grad = k.grad.clone(), None 57 | ref_dq, q.grad = q.grad.clone(), None 58 | # # triton implementation 59 | tri_out = triton.ops.attention(q, k, v, causal, sm_scale, seq_par) 60 | tri_out.backward(dout) 61 | tri_dv, v.grad = v.grad.clone(), None 62 | tri_dk, k.grad = k.grad.clone(), None 63 | tri_dq, q.grad = q.grad.clone(), None 64 | # compare 
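    # Outputs and gradients are L2-normalized before comparison, so the check is
    # on direction rather than raw magnitude; bfloat16 gets the looser atol below
    # because of its 7-bit mantissa.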
65 | atol = 1e-1 if dtype == torch.bfloat16 else 1e-2 66 | torch.testing.assert_close( 67 | torch.nn.functional.normalize(torch.flatten(ref_out), dim=0), 68 | torch.nn.functional.normalize(torch.flatten(tri_out), dim=0), 69 | atol=atol, 70 | rtol=0, 71 | ) 72 | torch.testing.assert_close( 73 | torch.nn.functional.normalize(torch.flatten(ref_dv), dim=0), 74 | torch.nn.functional.normalize(torch.flatten(tri_dv), dim=0), 75 | atol=atol, 76 | rtol=0, 77 | ) 78 | torch.testing.assert_close( 79 | torch.nn.functional.normalize(torch.flatten(ref_dk), dim=0), 80 | torch.nn.functional.normalize(torch.flatten(tri_dk), dim=0), 81 | atol=atol, 82 | rtol=0, 83 | ) 84 | torch.testing.assert_close( 85 | torch.nn.functional.normalize(torch.flatten(ref_dq), dim=0), 86 | torch.nn.functional.normalize(torch.flatten(tri_dq), dim=0), 87 | atol=atol, 88 | rtol=0, 89 | ) 90 | 91 | 92 | try: 93 | from flash_attn.flash_attn_interface import flash_attn_func 94 | 95 | HAS_FLASH = True 96 | except BaseException: 97 | HAS_FLASH = False 98 | 99 | BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 100 | # vary seq length for fixed head and batch=4 101 | configs = [ 102 | triton.testing.Benchmark( 103 | x_names=["N_CTX"], 104 | x_vals=[2**i for i in range(10, 14)], 105 | line_arg="provider", 106 | line_vals=["triton"] + (["flash"] if HAS_FLASH else []), 107 | line_names=["Triton"] + (["Flash"] if HAS_FLASH else []), 108 | styles=[("red", "-"), ("blue", "-")], 109 | ylabel="ms", 110 | plot_name=f"fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-{casual}-{seq_par}", 111 | args={ 112 | "H": N_HEADS, 113 | "BATCH": BATCH, 114 | "D_HEAD": D_HEAD, 115 | "dtype": torch.float16, 116 | "mode": mode, 117 | "casual": casual, 118 | "seq_par": seq_par, 119 | }, 120 | ) 121 | for mode in ["fwd", "bwd"] 122 | for casual in [True, False] 123 | for seq_par in [True, False] 124 | ] 125 | 126 | 127 | @triton.testing.perf_report(configs) 128 | def bench_flash_attention( 129 | BATCH, 130 | H, 131 | N_CTX, 132 | D_HEAD, 133 | mode, 134 | casual, 135 | seq_par, 136 | provider, 137 | dtype=torch.float16, 138 | device="cuda", 139 | ): 140 | assert mode in ["fwd", "bwd"] 141 | warmup = 25 142 | rep = 100 143 | sm_scale = 1.3 144 | q = torch.randn( 145 | (BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True 146 | ) 147 | k = torch.randn( 148 | (BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True 149 | ) 150 | v = torch.randn( 151 | (BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True 152 | ) 153 | if provider == "triton": 154 | fn = lambda: triton.ops.attention(q, k, v, casual, sm_scale, seq_par) 155 | if mode == "bwd": 156 | o = fn() 157 | do = torch.randn_like(o) 158 | fn = lambda: o.backward(do, retain_graph=True) 159 | ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) 160 | return ms 161 | if provider == "flash": 162 | lengths = torch.full((BATCH,), fill_value=N_CTX, device=device) 163 | cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32) 164 | cu_seqlens[1:] = lengths.cumsum(0) 165 | fn = lambda: flash_attn_func( 166 | q, k, v, dropout_p=0.0, softmax_scale=sm_scale, causal=casual 167 | ) 168 | if mode == "bwd": 169 | o = fn() 170 | do = torch.randn_like(o) 171 | fn = lambda: o.backward(do, retain_graph=True) 172 | ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) 173 | return ms 174 | 175 | 176 | # only works on post-Ampere GPUs right now 177 | # bench_flash_attention.run(save_path='.', print_data=True) 178 | 
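# A minimal sketch of how the benchmark above could be invoked directly
# (assumes a post-Ampere GPU and that this file is run as a script, which the
# test suite itself does not do):
#
#   if __name__ == "__main__":
#       bench_flash_attention.run(save_path=".", print_data=True)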
-------------------------------------------------------------------------------- /models/llama/llama/math_ops.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn.functional as F 3 | import torch 4 | import triton 5 | from typing import Tuple 6 | from torch import nn 7 | from kernels.matmul import matmul 8 | from kernels.cross_entropy import cross_entropy 9 | from kernels.matmul import matmul 10 | from kernels.flash_attention import attention 11 | from benchmarking import Profiler 12 | import time 13 | 14 | 15 | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): 16 | ndim = x.ndim 17 | assert 0 <= 1 < ndim 18 | assert freqs_cis.shape == (x.shape[1], x.shape[-1]) 19 | shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] 20 | return freqs_cis.view(*shape) 21 | 22 | 23 | class _RMSNorm(torch.nn.Module): 24 | def __init__(self, dim: int, eps: float = 1e-6, use_triton=False): 25 | super().__init__() 26 | self.use_triton = use_triton 27 | self.eps = eps 28 | self.weight = nn.Parameter(torch.ones(dim)) 29 | 30 | def __triton_norm(self, x): 31 | """ 32 | TODO: Triton kernel added here as needed. remove if we dont want to convert this one 33 | For now adding the torch version as a placeholder 34 | """ 35 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 36 | 37 | @Profiler.profiling_decorator(record_name="RMSNorm") 38 | def _norm(self, x): 39 | if not self.use_triton: 40 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 41 | else: 42 | return self.__triton_norm(x) 43 | 44 | def forward(self, x): 45 | output = self._norm(x.float()).type_as(x) 46 | return output * self.weight 47 | 48 | 49 | class MathOps: 50 | printed_attention = False 51 | 52 | def __init__(self, use_triton=False): 53 | self.use_triton = use_triton 54 | 55 | @Profiler.profiling_decorator("matmul") 56 | def matmul(self, x, y): 57 | if self.use_triton: 58 | return torch.matmul(x, y) 59 | else: 60 | return torch.matmul(x, y) 61 | 62 | @Profiler.profiling_decorator("attention") 63 | def attention(self, xq, keys, values, head_dim, mask): 64 | scores = self.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim) 65 | if mask is not None: 66 | scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen) 67 | scores = self.softmax(scores.float(), dim=-1).type_as(xq) 68 | output = self.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim) 69 | return output 70 | 71 | @Profiler.profiling_decorator("softmax") 72 | def softmax(self, x, dim): 73 | if self.use_triton: 74 | return F.softmax(x, dim=-1) 75 | else: 76 | return F.softmax(x, dim=-1) 77 | 78 | @Profiler.profiling_decorator("argmax") 79 | def argmax(self, x, dim): 80 | if self.use_triton: 81 | return torch.argmax(x, dim=-1) 82 | else: 83 | return torch.argmax(x, dim=-1) 84 | 85 | @Profiler.profiling_decorator("cross_entropy") 86 | def cross_entropy(self, input_val, target, reduction, ignore_index): 87 | if self.use_triton: 88 | return cross_entropy( 89 | input=input_val, 90 | target=target, 91 | reduction=reduction, 92 | ignore_index=ignore_index, 93 | ) 94 | else: 95 | return -F.cross_entropy( 96 | input=input_val, 97 | target=target, 98 | reduction=reduction, 99 | ignore_index=ignore_index, 100 | ) 101 | 102 | def get_rms_norm(self, dim: int, eps: float = 1e-6): 103 | return _RMSNorm(dim, eps, self.use_triton) 104 | 105 | @Profiler.profiling_decorator("apply_rotary_emb") 106 | def apply_rotary_emb( 107 | self, 108 | xq: 
torch.Tensor, 109 | xk: torch.Tensor, 110 | freqs_cis: torch.Tensor, 111 | ) -> Tuple[torch.Tensor, torch.Tensor]: 112 | 113 | def torch_based(xq, xk, freqs_cis): 114 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) 115 | xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) 116 | freqs_cis = reshape_for_broadcast(freqs_cis, xq_) 117 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) 118 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) 119 | return xq_out.type_as(xq), xk_out.type_as(xk) 120 | 121 | def triton_based(xq, xk, freqs_cis): 122 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) 123 | xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) 124 | freqs_cis = reshape_for_broadcast(freqs_cis, xq_) 125 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) 126 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) 127 | return xq_out.type_as(xq), xk_out.type_as(xk) 128 | 129 | if self.use_triton: 130 | return triton_based(xq, xk, freqs_cis) 131 | else: 132 | return torch_based(xq, xk, freqs_cis) 133 | 134 | @Profiler.profiling_decorator(record_name="precompute_freqs_cis") 135 | def precompute_freqs_cis(self, dim: int, end: int, theta: float = 10000.0): 136 | def torch_based(dim: int, end: int, theta: float = 10000.0): 137 | freqs = 1.0 / ( 138 | theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim) 139 | ) 140 | t = torch.arange(end, device=freqs.device, dtype=torch.float32) 141 | freqs = torch.outer(t, freqs) 142 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 143 | return freqs_cis 144 | 145 | def triton_based(dim: int, end: int, theta: float = 10000.0): 146 | freqs = 1.0 / ( 147 | theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim) 148 | ) 149 | t = torch.arange(end, device=freqs.device, dtype=torch.float32) 150 | freqs = torch.outer(t, freqs) 151 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 152 | return freqs_cis 153 | 154 | return torch_based(dim, end, theta) 155 | if self.use_triton: 156 | return triton_based(dim, end, theta) 157 | else: 158 | return torch_based(dim, end, theta) 159 | -------------------------------------------------------------------------------- /kernels/matmul_perf_model.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import heapq 3 | 4 | import torch 5 | 6 | from triton import cdiv 7 | from triton.runtime import driver 8 | from triton.testing import ( 9 | get_dram_gbps, 10 | get_max_simd_tflops, 11 | get_max_tensorcore_tflops, 12 | nvsmi, 13 | ) 14 | 15 | 16 | @functools.lru_cache() 17 | def get_clock_rate_in_khz(): 18 | try: 19 | return nvsmi(["clocks.max.sm"])[0] * 1e3 20 | except FileNotFoundError: 21 | import pynvml 22 | 23 | pynvml.nvmlInit() 24 | handle = pynvml.nvmlDeviceGetHandleByIndex(0) 25 | return pynvml.nvmlDeviceGetMaxClockInfo(handle, pynvml.NVML_CLOCK_SM) * 1e3 26 | 27 | 28 | def get_tensorcore_tflops(device, num_ctas, num_warps, dtype): 29 | """return compute throughput in TOPS""" 30 | total_warps = num_ctas * min(num_warps, 4) 31 | num_subcores = ( 32 | driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 33 | ) # on recent GPUs 34 | tflops = ( 35 | min(num_subcores, total_warps) 36 | / num_subcores 37 | * get_max_tensorcore_tflops(dtype, get_clock_rate_in_khz(), device) 38 | ) 39 | return tflops 40 | 41 | 42 | def get_simd_tflops(device, num_ctas, num_warps, dtype): 43 | """return compute 
throughput in TOPS""" 44 | total_warps = num_ctas * min(num_warps, 4) 45 | num_subcores = ( 46 | driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 47 | ) # on recent GPUs 48 | tflops = ( 49 | min(num_subcores, total_warps) 50 | / num_subcores 51 | * get_max_simd_tflops(dtype, get_clock_rate_in_khz(), device) 52 | ) 53 | return tflops 54 | 55 | 56 | def get_tflops(device, num_ctas, num_warps, dtype): 57 | capability = torch.cuda.get_device_capability(device) 58 | if capability[0] < 8 and dtype == torch.float32: 59 | return get_simd_tflops(device, num_ctas, num_warps, dtype) 60 | return get_tensorcore_tflops(device, num_ctas, num_warps, dtype) 61 | 62 | 63 | def estimate_matmul_time( 64 | # backend, device, 65 | num_warps, 66 | num_stages, # 67 | A, 68 | B, 69 | C, # 70 | M, 71 | N, 72 | K, # 73 | BLOCK_M, 74 | BLOCK_N, 75 | BLOCK_K, 76 | SPLIT_K, # 77 | debug=False, 78 | **kwargs, # 79 | ): 80 | """return estimated running time in ms 81 | = max(compute, loading) + store""" 82 | device = torch.cuda.current_device() 83 | dtype = A.dtype 84 | dtsize = A.element_size() 85 | 86 | num_cta_m = cdiv(M, BLOCK_M) 87 | num_cta_n = cdiv(N, BLOCK_N) 88 | num_cta_k = SPLIT_K 89 | num_ctas = num_cta_m * num_cta_n * num_cta_k 90 | 91 | # If the input is smaller than the block size 92 | M, N = max(M, BLOCK_M), max(N, BLOCK_N) 93 | 94 | # time to compute 95 | total_ops = 2 * M * N * K / (1024 * 1024 * 1024) # GOPS 96 | tput = get_tflops(device, num_ctas, num_warps, dtype) 97 | compute_ms = total_ops / tput 98 | 99 | # time to load data 100 | num_sm = driver.active.utils.get_device_properties(device)["multiprocessor_count"] 101 | active_cta_ratio = min(1, num_ctas / num_sm) 102 | active_cta_ratio_bw1 = min( 103 | 1, num_ctas / 32 104 | ) # 32 active ctas are enough to saturate 105 | active_cta_ratio_bw2 = max( 106 | min(1, (num_ctas - 32) / (108 - 32)), 0 107 | ) # 32-108, remaining 5% 108 | dram_bw = get_dram_gbps(device) * ( 109 | active_cta_ratio_bw1 * 0.95 + active_cta_ratio_bw2 * 0.05 110 | ) # in GB/s 111 | l2_bw = dram_bw * 4 # rough estimation (should be 4.7 for A100?) 
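    # Load-traffic model: each tile of A is logically re-read once per column of
    # output CTAs (num_cta_n times in total) and each tile of B once per row
    # (num_cta_m times); the first read is charged to DRAM, and of the repeated
    # reads 80% are assumed to hit L2 and 20% to go back to DRAM, which is what
    # the 0.8 / 0.2 factors below encode.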
112 | # assume 80% of (following) loads are in L2 cache 113 | load_a_dram = M * K * dtsize * (1 + 0.2 * (num_cta_n - 1)) 114 | load_a_l2 = M * K * dtsize * 0.8 * (num_cta_n - 1) 115 | load_b_dram = N * K * dtsize * (1 + 0.2 * (num_cta_m - 1)) 116 | load_b_l2 = N * K * dtsize * 0.8 * (num_cta_m - 1) 117 | # total 118 | total_dram = (load_a_dram + load_b_dram) / (1024 * 1024) # MB 119 | total_l2 = (load_a_l2 + load_b_l2) / (1024 * 1024) 120 | # loading time in ms 121 | load_ms = total_dram / dram_bw + total_l2 / l2_bw 122 | 123 | # estimate storing time 124 | store_bw = dram_bw * 0.6 # :o 125 | store_c_dram = M * N * dtsize * SPLIT_K / (1024 * 1024) # MB 126 | if SPLIT_K == 1: 127 | store_ms = store_c_dram / store_bw 128 | else: 129 | reduce_bw = store_bw 130 | store_ms = store_c_dram / reduce_bw 131 | # c.zero_() 132 | zero_ms = M * N * 2 / (1024 * 1024) / store_bw 133 | store_ms += zero_ms 134 | 135 | total_time_ms = max(compute_ms, load_ms) + store_ms 136 | if debug: 137 | print( 138 | f"Total time: {total_time_ms}ms, compute time: {compute_ms}ms, " 139 | f"loading time: {load_ms}ms, store time: {store_ms}ms, " 140 | f"Activate CTAs: {active_cta_ratio*100}%" 141 | ) 142 | return total_time_ms 143 | 144 | 145 | def early_config_prune(configs, named_args, **kwargs): 146 | device = torch.cuda.current_device() 147 | capability = torch.cuda.get_device_capability() 148 | # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages 149 | dtsize = named_args["A"].element_size() 150 | dtype = named_args["A"].dtype 151 | 152 | # 1. make sure we have enough smem 153 | pruned_configs = [] 154 | for config in configs: 155 | kw = config.kwargs 156 | BLOCK_M, BLOCK_N, BLOCK_K, num_stages = ( 157 | kw["BLOCK_M"], 158 | kw["BLOCK_N"], 159 | kw["BLOCK_K"], 160 | config.num_stages, 161 | ) 162 | 163 | max_shared_memory = driver.active.utils.get_device_properties(device)[ 164 | "max_shared_mem" 165 | ] 166 | required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize 167 | if required_shared_memory <= max_shared_memory: 168 | pruned_configs.append(config) 169 | configs = pruned_configs 170 | 171 | # Some dtypes do not allow atomic_add 172 | if dtype not in [torch.float16, torch.float32]: 173 | configs = [config for config in configs if config.kwargs["SPLIT_K"] == 1] 174 | 175 | # group configs by (BLOCK_M,_N,_K, SPLIT_K, num_warps) 176 | configs_map = {} 177 | for config in configs: 178 | kw = config.kwargs 179 | BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages = ( 180 | kw["BLOCK_M"], 181 | kw["BLOCK_N"], 182 | kw["BLOCK_K"], 183 | kw["SPLIT_K"], 184 | config.num_warps, 185 | config.num_stages, 186 | ) 187 | 188 | key = (BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps) 189 | if key in configs_map: 190 | configs_map[key].append((config, num_stages)) 191 | else: 192 | configs_map[key] = [(config, num_stages)] 193 | 194 | pruned_configs = [] 195 | for k, v in configs_map.items(): 196 | BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps = k 197 | if capability[0] >= 8: 198 | # compute cycles (only works for ampere GPUs) 199 | mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16 * 8 * 16) 200 | mma_cycles = mmas / min(4, num_warps) * 8 201 | 202 | ldgsts_latency = 300 # Does this matter? 
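            # Pipeline-depth heuristic: choose enough software-pipeline stages
            # that the assumed ~300-cycle global->shared (LDGSTS) copy latency is
            # hidden behind the tensor-core work of the stages already in flight,
            # i.e. roughly latency / MMA-cycles-per-stage; the nsmallest() pick
            # below then keeps the two configs closest to that optimum, preferring
            # the larger stage count.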
203 | optimal_num_stages = ldgsts_latency / mma_cycles 204 | 205 | # nearest stages, prefer large #stages 206 | nearest = heapq.nsmallest( 207 | 2, 208 | v, 209 | key=lambda x: ( 210 | 10 + abs(x[1] - optimal_num_stages) 211 | if (x[1] - optimal_num_stages) < 0 212 | else x[1] - optimal_num_stages 213 | ), 214 | ) 215 | 216 | for n in nearest: 217 | pruned_configs.append(n[0]) 218 | else: # Volta & Turing only supports num_stages <= 2 219 | random_config = v[0][0] 220 | random_config.num_stages = 2 221 | pruned_configs.append(random_config) 222 | return pruned_configs 223 | -------------------------------------------------------------------------------- /models/llama/llama/tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. 3 | 4 | import os 5 | from logging import getLogger 6 | from pathlib import Path 7 | from typing import ( 8 | AbstractSet, 9 | cast, 10 | Collection, 11 | Dict, 12 | Iterator, 13 | List, 14 | Literal, 15 | Sequence, 16 | TypedDict, 17 | Union, 18 | ) 19 | 20 | import tiktoken 21 | from tiktoken.load import load_tiktoken_bpe 22 | 23 | 24 | logger = getLogger(__name__) 25 | 26 | 27 | Role = Literal["system", "user", "assistant"] 28 | 29 | 30 | class Message(TypedDict): 31 | role: Role 32 | content: str 33 | 34 | 35 | Dialog = Sequence[Message] 36 | 37 | 38 | class Tokenizer: 39 | """ 40 | Tokenizing and encoding/decoding text using the Tiktoken tokenizer. 41 | """ 42 | 43 | special_tokens: Dict[str, int] 44 | 45 | num_reserved_special_tokens = 256 46 | 47 | pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501 48 | 49 | def __init__(self, model_path: str): 50 | """ 51 | Initializes the Tokenizer with a Tiktoken model. 52 | 53 | Args: 54 | model_path (str): The path to the Tiktoken model file. 
55 | """ 56 | assert os.path.isfile(model_path), model_path 57 | 58 | mergeable_ranks = load_tiktoken_bpe(model_path) 59 | num_base_tokens = len(mergeable_ranks) 60 | special_tokens = [ 61 | "<|begin_of_text|>", 62 | "<|end_of_text|>", 63 | "<|reserved_special_token_0|>", 64 | "<|reserved_special_token_1|>", 65 | "<|reserved_special_token_2|>", 66 | "<|reserved_special_token_3|>", 67 | "<|start_header_id|>", 68 | "<|end_header_id|>", 69 | "<|reserved_special_token_4|>", 70 | "<|eot_id|>", # end of turn 71 | ] + [ 72 | f"<|reserved_special_token_{i}|>" 73 | for i in range(5, self.num_reserved_special_tokens - 5) 74 | ] 75 | self.special_tokens = { 76 | token: num_base_tokens + i for i, token in enumerate(special_tokens) 77 | } 78 | self.model = tiktoken.Encoding( 79 | name=Path(model_path).name, 80 | pat_str=self.pat_str, 81 | mergeable_ranks=mergeable_ranks, 82 | special_tokens=self.special_tokens, 83 | ) 84 | logger.info(f"Reloaded tiktoken model from {model_path}") 85 | 86 | self.n_words: int = self.model.n_vocab 87 | # BOS / EOS token IDs 88 | self.bos_id: int = self.special_tokens["<|begin_of_text|>"] 89 | self.eos_id: int = self.special_tokens["<|end_of_text|>"] 90 | self.pad_id: int = -1 91 | self.stop_tokens = { 92 | self.special_tokens["<|end_of_text|>"], 93 | self.special_tokens["<|eot_id|>"], 94 | } 95 | logger.info( 96 | f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}" 97 | ) 98 | 99 | def encode( 100 | self, 101 | s: str, 102 | *, 103 | bos: bool, 104 | eos: bool, 105 | allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), 106 | disallowed_special: Union[Literal["all"], Collection[str]] = (), 107 | ) -> List[int]: 108 | """ 109 | Encodes a string into a list of token IDs. 110 | 111 | Args: 112 | s (str): The input string to be encoded. 113 | bos (bool): Whether to prepend the beginning-of-sequence token. 114 | eos (bool): Whether to append the end-of-sequence token. 115 | allowed_tokens ("all"|set[str]): allowed special tokens in string 116 | disallowed_tokens ("all"|set[str]): special tokens that raise an error when in string 117 | 118 | Returns: 119 | list[int]: A list of token IDs. 120 | 121 | By default, setting disallowed_special=() encodes a string by ignoring 122 | special tokens. Specifically: 123 | - Setting `disallowed_special` to () will cause all text corresponding 124 | to special tokens to be encoded as natural text (insteading of raising 125 | an error). 126 | - Setting `allowed_special` to "all" will treat all text corresponding 127 | to special tokens to be encoded as special tokens. 128 | """ 129 | assert type(s) is str 130 | 131 | # The tiktoken tokenizer can handle <=400k chars without 132 | # pyo3_runtime.PanicException. 133 | TIKTOKEN_MAX_ENCODE_CHARS = 400_000 134 | 135 | # https://github.com/openai/tiktoken/issues/195 136 | # Here we iterate over subsequences and split if we exceed the limit 137 | # of max consecutive non-whitespace or whitespace characters. 
138 | MAX_NO_WHITESPACES_CHARS = 25_000 139 | 140 | substrs = ( 141 | substr 142 | for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS) 143 | for substr in self._split_whitespaces_or_nonwhitespaces( 144 | s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS 145 | ) 146 | ) 147 | t: List[int] = [] 148 | for substr in substrs: 149 | t.extend( 150 | self.model.encode( 151 | substr, 152 | allowed_special=allowed_special, 153 | disallowed_special=disallowed_special, 154 | ) 155 | ) 156 | if bos: 157 | t.insert(0, self.bos_id) 158 | if eos: 159 | t.append(self.eos_id) 160 | return t 161 | 162 | def decode(self, t: Sequence[int]) -> str: 163 | """ 164 | Decodes a list of token IDs into a string. 165 | 166 | Args: 167 | t (List[int]): The list of token IDs to be decoded. 168 | 169 | Returns: 170 | str: The decoded string. 171 | """ 172 | # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence. 173 | return self.model.decode(cast(List[int], t)) 174 | 175 | @staticmethod 176 | def _split_whitespaces_or_nonwhitespaces( 177 | s: str, max_consecutive_slice_len: int 178 | ) -> Iterator[str]: 179 | """ 180 | Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len` 181 | consecutive whitespaces or consecutive non-whitespaces. 182 | """ 183 | current_slice_len = 0 184 | current_slice_is_space = s[0].isspace() if len(s) > 0 else False 185 | slice_start = 0 186 | 187 | for i in range(len(s)): 188 | is_now_space = s[i].isspace() 189 | 190 | if current_slice_is_space ^ is_now_space: 191 | current_slice_len = 1 192 | current_slice_is_space = is_now_space 193 | else: 194 | current_slice_len += 1 195 | if current_slice_len > max_consecutive_slice_len: 196 | yield s[slice_start:i] 197 | slice_start = i 198 | current_slice_len = 1 199 | yield s[slice_start:] 200 | 201 | 202 | class ChatFormat: 203 | def __init__(self, tokenizer: Tokenizer): 204 | self.tokenizer = tokenizer 205 | 206 | def encode_header(self, message: Message) -> List[int]: 207 | tokens = [] 208 | tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"]) 209 | tokens.extend(self.tokenizer.encode(message["role"], bos=False, eos=False)) 210 | tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"]) 211 | tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False)) 212 | return tokens 213 | 214 | def encode_message(self, message: Message) -> List[int]: 215 | tokens = self.encode_header(message) 216 | tokens.extend( 217 | self.tokenizer.encode(message["content"].strip(), bos=False, eos=False) 218 | ) 219 | tokens.append(self.tokenizer.special_tokens["<|eot_id|>"]) 220 | return tokens 221 | 222 | def encode_dialog_prompt(self, dialog: Dialog) -> List[int]: 223 | tokens = [] 224 | tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"]) 225 | for message in dialog: 226 | tokens.extend(self.encode_message(message)) 227 | # Add the start of an assistant message for the model to complete. 
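        # For a single user turn the encoded prompt therefore looks like:
        #   <|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n
        #   {user content}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n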
228 | tokens.extend(self.encode_header({"role": "assistant", "content": ""})) 229 | return tokens 230 | -------------------------------------------------------------------------------- /test/test_inductor.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | import triton 5 | import triton.language as tl 6 | 7 | 8 | def test_normalization_with_remat(device): 9 | 10 | @triton.jit 11 | def triton_( 12 | in_out_ptr0, 13 | in_out_ptr1, 14 | in_ptr0, 15 | in_ptr1, 16 | in_ptr2, 17 | in_ptr3, 18 | xnumel, 19 | rnumel, 20 | XBLOCK: tl.constexpr, 21 | RBLOCK: tl.constexpr, 22 | ): 23 | xnumel = 512 24 | rnumel = 4096 25 | xoffset = tl.program_id(0) * XBLOCK 26 | xindex = xoffset + tl.arange(0, XBLOCK)[:, None] 27 | xmask = xindex < xnumel 28 | rbase = tl.arange(0, RBLOCK)[None, :] 29 | x3 = xindex 30 | x0 = xindex % 64 31 | tmp1 = tl.load(in_ptr0 + (x0), xmask) 32 | tmp3 = tl.load(in_ptr1 + (x0), xmask) 33 | tmp11 = tl.load(in_ptr2 + (x0), xmask) 34 | tmp13 = tl.load(in_ptr3 + (x0), xmask) 35 | _tmp17 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0 36 | for roffset in range(0, rnumel, RBLOCK): 37 | rindex = roffset + rbase 38 | rmask = rindex < rnumel 39 | r2 = rindex 40 | tmp0 = tl.load( 41 | in_out_ptr0 + (r2 + (4096 * x3)), 42 | rmask & xmask, 43 | eviction_policy="evict_last", 44 | other=0, 45 | ) 46 | tmp2 = tmp0 - tmp1 47 | tmp4 = 1e-05 48 | tmp5 = tmp3 + tmp4 49 | tmp6 = tl.sqrt(tmp5) 50 | tmp7 = 1 / tmp6 51 | tmp8 = 1.0 52 | tmp9 = tmp7 * tmp8 53 | tmp10 = tmp2 * tmp9 54 | tmp12 = tmp10 * tmp11 55 | tmp14 = tmp12 + tmp13 56 | _tmp17 = tl.where(rmask & xmask, _tmp17 + tmp14, _tmp17) 57 | tl.store( 58 | in_out_ptr0 + (r2 + (4096 * x3) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), 59 | tmp14, 60 | rmask & xmask, 61 | ) 62 | tmp17 = tl.sum(_tmp17, 1)[:, None] 63 | tmp18 = 4096.0 64 | tmp19 = tmp17 / tmp18 65 | tl.store(in_out_ptr1 + (x3 + tl.zeros([XBLOCK, 1], tl.int32)), tmp19, xmask) 66 | 67 | torch.manual_seed(123) 68 | 69 | buf14 = torch.rand(8, 64, 64, 64, device=device) 70 | buf16 = torch.rand(8, 1, 64, device=device) 71 | arg114_1 = torch.rand(64, device=device) 72 | arg115_1 = torch.rand(64, device=device) 73 | arg8_1 = torch.rand(64, device=device) 74 | arg9_1 = torch.rand(64, device=device) 75 | triton_[(512,)]( 76 | buf14, buf16, arg114_1, arg115_1, arg8_1, arg9_1, 512, 4096, 1, 2048 77 | ) 78 | torch.testing.assert_close( 79 | buf16.mean().item(), buf14.mean().item(), atol=1e-7, rtol=0 80 | ) 81 | 82 | 83 | def test_avg_pool_bw(device): 84 | 85 | @triton.jit 86 | def triton_(in_ptr0, out_ptr0, XBLOCK: tl.constexpr): 87 | xoffset = tl.program_id(0) * XBLOCK 88 | xindex = xoffset + tl.arange(0, XBLOCK)[:] 89 | x1 = (xindex // 8) % 8 90 | x0 = xindex % 8 91 | x2 = xindex // 64 92 | x5 = xindex 93 | tmp0 = (-1) + x1 94 | tmp1 = (-1) + x0 95 | tmp2 = 2 + x1 96 | tmp3 = 2 + x0 97 | tmp4 = 0 98 | tmp5 = tl.where(tmp0 != tmp0, tmp0, tl.where(tmp0 > tmp4, tmp0, tmp4)) 99 | tmp6 = tl.where(tmp1 != tmp1, tmp1, tl.where(tmp1 > tmp4, tmp1, tmp4)) 100 | tmp7 = 8 101 | tmp8 = tl.where(tmp2 != tmp2, tmp2, tl.where(tmp2 < tmp7, tmp2, tmp7)) 102 | tmp9 = tl.where(tmp3 != tmp3, tmp3, tl.where(tmp3 < tmp7, tmp3, tmp7)) 103 | tmp10 = tmp5 + tmp4 104 | tmp11 = tmp6 + tmp4 105 | tmp12 = 1 106 | tmp13 = tmp8 - tmp12 107 | tmp14 = tl.where(tmp10 != tmp10, tmp10, tl.where(tmp10 < tmp13, tmp10, tmp13)) 108 | tmp15 = tmp9 - tmp12 109 | tmp16 = tl.where(tmp11 != tmp11, tmp11, tl.where(tmp11 < tmp15, tmp11, tmp15)) 110 | tmp17 = tl.load(in_ptr0 + 
(tmp16 + (8 * tmp14) + (64 * x2)), None).to( 111 | tl.float32 112 | ) 113 | tmp18 = tmp17 / 9 114 | tmp19 = tmp10 < tmp8 115 | tmp20 = tmp11 < tmp9 116 | tmp21 = tmp19 & tmp20 117 | tmp22 = 0.0 118 | tmp23 = tl.where(tmp21, tmp18, tmp22) 119 | tmp24 = tmp6 + tmp12 120 | tmp25 = tl.where(tmp24 != tmp24, tmp24, tl.where(tmp24 < tmp15, tmp24, tmp15)) 121 | tmp26 = tl.load(in_ptr0 + (tmp25 + (8 * tmp14) + (64 * x2)), None).to( 122 | tl.float32 123 | ) 124 | tmp27 = tmp26 / 9 125 | tmp28 = tmp24 < tmp9 126 | tmp29 = tmp19 & tmp28 127 | tmp30 = tmp23 + tmp27 128 | tmp31 = tl.where(tmp29, tmp30, tmp23) 129 | tmp32 = 2 130 | tmp33 = tmp6 + tmp32 131 | tmp34 = tl.where(tmp33 != tmp33, tmp33, tl.where(tmp33 < tmp15, tmp33, tmp15)) 132 | tmp35 = tl.load(in_ptr0 + (tmp34 + (8 * tmp14) + (64 * x2)), None).to( 133 | tl.float32 134 | ) 135 | tmp36 = tmp35 / 9 136 | tmp37 = tmp33 < tmp9 137 | tmp38 = tmp19 & tmp37 138 | tmp39 = tmp31 + tmp36 139 | tmp40 = tl.where(tmp38, tmp39, tmp31) 140 | tmp41 = tmp5 + tmp12 141 | tmp42 = tl.where(tmp41 != tmp41, tmp41, tl.where(tmp41 < tmp13, tmp41, tmp13)) 142 | tmp43 = tl.load(in_ptr0 + (tmp16 + (8 * tmp42) + (64 * x2)), None).to( 143 | tl.float32 144 | ) 145 | tmp44 = tmp43 / 9 146 | tmp45 = tmp41 < tmp8 147 | tmp46 = tmp45 & tmp20 148 | tmp47 = tmp40 + tmp44 149 | tmp48 = tl.where(tmp46, tmp47, tmp40) 150 | tmp49 = tl.load(in_ptr0 + (tmp25 + (8 * tmp42) + (64 * x2)), None).to( 151 | tl.float32 152 | ) 153 | tmp50 = tmp49 / 9 154 | tmp51 = tmp45 & tmp28 155 | tmp52 = tmp48 + tmp50 156 | tmp53 = tl.where(tmp51, tmp52, tmp48) 157 | tmp54 = tl.load(in_ptr0 + (tmp34 + (8 * tmp42) + (64 * x2)), None).to( 158 | tl.float32 159 | ) 160 | tmp55 = tmp54 / 9 161 | tmp56 = tmp45 & tmp37 162 | tmp57 = tmp53 + tmp55 163 | tmp58 = tl.where(tmp56, tmp57, tmp53) 164 | tmp59 = tmp5 + tmp32 165 | tmp60 = tl.where(tmp59 != tmp59, tmp59, tl.where(tmp59 < tmp13, tmp59, tmp13)) 166 | tmp61 = tl.load(in_ptr0 + (tmp16 + (8 * tmp60) + (64 * x2)), None).to( 167 | tl.float32 168 | ) 169 | tmp62 = tmp61 / 9 170 | tmp63 = tmp59 < tmp8 171 | tmp64 = tmp63 & tmp20 172 | tmp65 = tmp58 + tmp62 173 | tmp66 = tl.where(tmp64, tmp65, tmp58) 174 | tmp67 = tl.load(in_ptr0 + (tmp25 + (8 * tmp60) + (64 * x2)), None).to( 175 | tl.float32 176 | ) 177 | tmp68 = tmp67 / 9 178 | tmp69 = tmp63 & tmp28 179 | tmp70 = tmp66 + tmp68 180 | tmp71 = tl.where(tmp69, tmp70, tmp66) 181 | tmp72 = tl.load(in_ptr0 + (tmp34 + (8 * tmp60) + (64 * x2)), None).to( 182 | tl.float32 183 | ) 184 | tmp73 = tmp72 / 9 185 | tmp74 = tmp63 & tmp37 186 | tmp75 = tmp71 + tmp73 187 | tmp76 = tl.where(tmp74, tmp75, tmp71) 188 | tl.store(out_ptr0 + (x5 + tl.zeros([XBLOCK], tl.int32)), tmp76, None) 189 | 190 | inp = torch.ones(8, 2048, 8, 8, device=device, dtype=torch.half) 191 | out = torch.ones_like(inp) * 3 192 | numel = inp.numel() 193 | triton_[(numel // 1024,)](inp, out, 1024) 194 | out_ref = torch.ones_like(inp) 195 | out_ref[:, :, 1:7, 0::7] = 2 / 3 196 | out_ref[:, :, 0::7, 1:7] = 2 / 3 197 | out_ref[:, :, 0::7, 0::7] = 4 / 9 198 | torch.testing.assert_close(out, out_ref) 199 | 200 | 201 | @pytest.mark.parametrize("RBLOCK", [1, 16, 32, 64, 128]) 202 | @pytest.mark.parametrize("num_warps", [1, 4]) 203 | def test_scan2d_broadcast(RBLOCK, num_warps, device): 204 | 205 | @triton.jit(debug=True) 206 | def fn(in_ptr, out_ptr, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr): 207 | rindex = tl.arange(0, RBLOCK)[None, :] 208 | xindex = tl.arange(0, XBLOCK)[:, None] 209 | data = tl.load(in_ptr + rindex) 210 | scan = tl.cumsum(data, 1) 211 | 
expected_max = tl.sum(data, 1) 212 | tl.device_assert(scan <= expected_max) 213 | tl.store(out_ptr + xindex * RBLOCK + rindex, scan) 214 | 215 | XBLOCK = 4 216 | input = torch.randint(0, 10, (1, RBLOCK), dtype=torch.int64, device=device) 217 | output = torch.empty((XBLOCK, RBLOCK), dtype=torch.int64, device=device) 218 | fn[(1,)](input, output, XBLOCK, RBLOCK, num_warps=num_warps) 219 | ref = input.cumsum(1).broadcast_to((XBLOCK, RBLOCK)) 220 | torch.testing.assert_close(output, ref) 221 | 222 | 223 | def test_scan2d_for(device): 224 | 225 | @triton.jit 226 | def fn(out_ptr0, rnumel, RBLOCK: tl.constexpr): 227 | rbase = tl.arange(0, RBLOCK)[None, :] 228 | for roffset in range(0, rnumel, RBLOCK): 229 | rindex = roffset + rbase 230 | rmask = rindex < rnumel 231 | tmp3 = tl.where(rmask, 1, 0) 232 | tmp6 = tl.cumsum(tmp3, 1) 233 | tl.store(out_ptr0 + rindex, tmp6, rmask) 234 | 235 | RBLOCK = 8 236 | out0 = torch.empty(RBLOCK, device=device, dtype=torch.int64) 237 | fn[(1,)](out0, RBLOCK, RBLOCK) 238 | ref = torch.arange(RBLOCK, device=device, dtype=torch.int64) + 1 239 | torch.testing.assert_close(out0, ref) 240 | -------------------------------------------------------------------------------- /kernels/blocksparse/softmax.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from triton import jit 4 | from triton import language as tl 5 | from triton import next_power_of_2 6 | 7 | 8 | def num_warps(n): 9 | if n <= 128: 10 | return 1 11 | if n <= 256: 12 | return 2 13 | if n <= 512: 14 | return 4 15 | if n <= 4096: 16 | return 8 17 | return 16 18 | 19 | 20 | @jit 21 | def _blocksparse_softmax_fwd( 22 | Out, 23 | A, 24 | stride_xz, 25 | LUT, # 26 | R, 27 | extent, 28 | stride_zr, 29 | stride_hr, # relative attention 30 | scale, 31 | is_causal, # 32 | ROW_SIZE: tl.constexpr, # 33 | BLOCK_SIZE: tl.constexpr, # 34 | IS_DENSE: tl.constexpr, # 35 | ): 36 | h = tl.program_id(0) 37 | m = tl.program_id(1) 38 | z = tl.program_id(2) 39 | # create index ranges 40 | hm = h * tl.num_programs(1) + m 41 | lane_n = tl.arange(0, ROW_SIZE) % BLOCK_SIZE 42 | block_n = tl.arange(0, ROW_SIZE) // BLOCK_SIZE 43 | # extract information from LUT 44 | header = LUT + (hm // BLOCK_SIZE) * 2 45 | size = tl.load(header + 0) 46 | offset = tl.load(header + 1) 47 | # pointer offset 48 | off_a = z * stride_xz 49 | off_a += (offset + block_n) * BLOCK_SIZE * BLOCK_SIZE # block indx 50 | off_a += (m % BLOCK_SIZE) * BLOCK_SIZE # row indx 51 | # do not need to read column indices in the dense case 52 | if IS_DENSE: 53 | ns = tl.arange(0, ROW_SIZE) 54 | else: 55 | off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE 56 | start_n = tl.load(LUT + off_lut + block_n, mask=block_n < size, other=0) 57 | ns = start_n * BLOCK_SIZE + lane_n 58 | # load X 59 | mask = block_n < size 60 | a = tl.load(A + off_a + lane_n, mask=mask, other=-float("inf")) 61 | a = a.to(tl.float32) 62 | # compute 63 | out = a 64 | out *= scale 65 | # apply relative attention 66 | if R is not None: 67 | R += z * stride_zr 68 | R += h * stride_hr 69 | off_lo = (extent - m - 1) + ns 70 | mask_lo = (off_lo >= 0) & (off_lo < extent) 71 | rel_logits = tl.load(R + m * extent + off_lo, mask=mask_lo, other=0.0) 72 | out += rel_logits 73 | out = out.to(tl.float32) 74 | # apply causal mask 75 | out = tl.where((ns > m) & is_causal, -float("inf"), out) 76 | # computation 77 | out = tl.softmax(out) 78 | # write-back 79 | tl.store(Out + off_a + lane_n, out, mask=mask) 80 | 81 | 82 | @jit 83 | def 
_blocksparse_softmax_bwd( 84 | DA, 85 | stride_zdx, # 86 | DOut, 87 | stride_zdout, # 88 | Out, 89 | stride_zout, # 90 | scale, # 91 | LUT, # 92 | DR, 93 | extent, 94 | stride_zr, 95 | stride_hr, 96 | stride_er, # 97 | is_causal, # 98 | ROW_SIZE: tl.constexpr, # 99 | BLOCK_SIZE: tl.constexpr, # 100 | IS_DENSE: tl.constexpr, 101 | ): 102 | h = tl.program_id(0) 103 | m = tl.program_id(1) 104 | z = tl.program_id(2) 105 | # create index ranges 106 | hm = h * tl.num_programs(1) + m 107 | lane_n = tl.arange(0, ROW_SIZE) % BLOCK_SIZE 108 | block_n = tl.arange(0, ROW_SIZE) // BLOCK_SIZE 109 | # extract information from LUT 110 | header = LUT + (hm // BLOCK_SIZE) * 2 111 | size = tl.load(header + 0) 112 | offset = tl.load(header + 1) 113 | # row-col offset 114 | off_mn = (offset + block_n) * BLOCK_SIZE * BLOCK_SIZE 115 | off_mn += (m % BLOCK_SIZE) * BLOCK_SIZE 116 | mask = block_n < size 117 | # pointers 118 | As = Out + z * stride_zout + off_mn 119 | DOuts = DOut + z * stride_zdout + off_mn 120 | # do not need to read column indices in the dense case 121 | if IS_DENSE: 122 | ns = tl.arange(0, ROW_SIZE) 123 | else: 124 | off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE 125 | start_n = tl.load(LUT + off_lut + block_n, mask=mask, other=0) 126 | ns = start_n * BLOCK_SIZE + lane_n 127 | # load data 128 | a = tl.load(As + lane_n, mask=mask, other=0.0) 129 | a = a.to(tl.float32) 130 | dout = tl.load(DOuts + lane_n, mask=mask, other=0.0) 131 | dout = dout.to(tl.float32) 132 | # compute 133 | a = tl.where((ns > m) & is_causal & (a == a), 0.0, a) 134 | da = a * (dout - tl.sum(a * dout, 0)) 135 | # apply relative attention 136 | if DR is not None: 137 | DR += z * stride_zr 138 | DR += h * stride_hr 139 | off_lo = (extent - m - 1) + ns 140 | mask_lo = (off_lo >= 0) & (off_lo < extent) & mask 141 | tl.store(DR + m * extent + off_lo, da, mask=mask_lo) 142 | da = da * scale 143 | # convert da 144 | # write-back 145 | DAs = DA + z * stride_zdx + off_mn 146 | tl.store(DAs + lane_n, da, mask=mask) 147 | 148 | 149 | class _softmax(torch.autograd.Function): 150 | 151 | @staticmethod 152 | def make_lut(layout, block, device): 153 | _empty = torch.tensor([], dtype=torch.int64, device=layout.device) 154 | sizes = _empty.clone() 155 | # sizes along rows 156 | for h in range(layout.shape[0]): 157 | sizes = torch.cat((sizes, layout[h, :, :].sum(-1))) 158 | total_sizes = sizes * block 159 | # offsets in block format 160 | offsets = torch.zeros_like(sizes) 161 | offsets[1:] = torch.cumsum(sizes[:-1], dim=0) 162 | # block indices 163 | columns = layout.nonzero(as_tuple=False)[:, 2] 164 | header = torch.stack((sizes, offsets), dim=1).view(-1) 165 | lut = torch.cat((header, columns)).type(torch.int32).to(device) 166 | return lut, int(total_sizes.max()) 167 | 168 | @staticmethod 169 | def forward( 170 | ctx, a, scale, rel_logits, is_causal, spdims, block, lut, maxlut, is_dense 171 | ): 172 | if scale is not None and isinstance(scale, torch.Tensor): 173 | assert scale.device.type == "cpu" 174 | scale = scale.item() 175 | M = a.shape[0] 176 | grid = [spdims[0], spdims[1] * block, M] 177 | rel_shape = (1, 1, 1, 1) if rel_logits is None else rel_logits.shape 178 | rel_strides = (1, 1, 1, 1) if rel_logits is None else rel_logits.stride() 179 | # enqueue kernel 180 | out = torch.empty_like(a) 181 | _blocksparse_softmax_fwd[grid]( 182 | out, 183 | a, 184 | a.stride(0), 185 | lut, # 186 | rel_logits, 187 | rel_shape[-1], 188 | rel_strides[0], 189 | rel_strides[1], # relative attn# 190 | scale, # 191 | 
is_causal, # 192 | BLOCK_SIZE=block, # 193 | ROW_SIZE=next_power_of_2(maxlut), # 194 | IS_DENSE=is_dense, # 195 | num_warps=num_warps(maxlut), # 196 | ) 197 | # save to context 198 | # ctx.mark_dirty(x) 199 | ctx.save_for_backward(out, lut) 200 | ctx.spdims = spdims 201 | ctx.block = block 202 | ctx.maxlut = maxlut 203 | ctx.scale = scale 204 | ctx.rel_shape = rel_shape 205 | ctx.rel_strides = rel_strides 206 | ctx.rel_dtype = a.dtype 207 | ctx.is_dense = is_dense 208 | ctx.is_causal = is_causal 209 | return out 210 | 211 | @staticmethod 212 | def backward(ctx, dout): 213 | # retrieve from context 214 | out, lut = ctx.saved_tensors 215 | # relative logits gradients 216 | dr = None 217 | if ctx.needs_input_grad[3]: 218 | dr = torch.zeros(ctx.rel_shape, dtype=ctx.rel_dtype, device=out.device) 219 | # run kernel 220 | M = out.shape[0] 221 | grid = (ctx.spdims[0], ctx.spdims[1] * ctx.block, M) 222 | da = torch.empty_like(dout) 223 | _blocksparse_softmax_bwd[grid]( 224 | da, 225 | da.stride(0), # 226 | dout, 227 | dout.stride(0), # 228 | out, 229 | out.stride(0), # 230 | ctx.scale, # 231 | lut, # 232 | dr, 233 | ctx.rel_shape[-1], 234 | ctx.rel_strides[0], 235 | ctx.rel_strides[1], 236 | ctx.rel_strides[2], # 237 | ctx.is_causal, # 238 | BLOCK_SIZE=ctx.block, # 239 | ROW_SIZE=next_power_of_2(ctx.maxlut), # 240 | IS_DENSE=ctx.is_dense, # 241 | num_warps=num_warps(ctx.maxlut), # 242 | ) 243 | return ( 244 | da, 245 | None, 246 | None, 247 | dr, 248 | None, 249 | None, 250 | None, 251 | None, 252 | None, 253 | None, 254 | None, 255 | None, 256 | None, 257 | None, 258 | None, 259 | None, 260 | None, 261 | None, 262 | ) 263 | 264 | 265 | class softmax: 266 | 267 | def __init__(self, layout, block, device, is_dense=False): 268 | self.spdims = layout.shape 269 | self.layout = layout 270 | self.block = block 271 | self.lut, self.maxlut = _softmax.make_lut(self.layout, self.block, device) 272 | self.is_dense = is_dense 273 | 274 | def __call__(self, a, *, scale=1.0, rel_logits=None, is_causal=False): 275 | if rel_logits is not None and rel_logits.dtype != a.dtype: 276 | raise ValueError(f"relative position embedding must be {a.dtype}") 277 | a = _softmax.apply( 278 | a, 279 | scale, 280 | rel_logits, 281 | is_causal, 282 | self.spdims, 283 | self.block, 284 | self.lut, 285 | self.maxlut, 286 | self.is_dense, 287 | ) 288 | return a 289 | -------------------------------------------------------------------------------- /test/test_blocksparse.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | import triton 5 | import triton.ops 6 | 7 | 8 | def is_hip_mi200(): 9 | target = triton.runtime.driver.active.get_current_target() 10 | return target.backend == "hip" and target.arch == "gfx90a" 11 | 12 | 13 | def sparsify_tensor(x, mask, block): 14 | ret = torch.empty( 15 | (x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device 16 | ) 17 | for idx, (h, i, j) in enumerate(zip(*mask.nonzero(as_tuple=True))): 18 | ret[:, idx, :, :] = x[ 19 | :, h, i * block : (i + 1) * block, j * block : (j + 1) * block 20 | ] 21 | return ret 22 | 23 | 24 | def make_pair( 25 | shape, 26 | device="cuda", 27 | alpha=1e-2, 28 | beta=0.0, 29 | trans=False, 30 | data=None, 31 | dtype=torch.float32, 32 | ): 33 | if data is None: 34 | data = torch.randn( 35 | shape, dtype=torch.float32, requires_grad=True, device=device 36 | ) 37 | ref_ret = data 38 | ref_ret = ref_ret * alpha + beta 39 | ref_ret = ref_ret.half().to(dtype) 40 | if trans: 41 | ref_ret = 
ref_ret.t().requires_grad_() 42 | ref_ret = ref_ret.detach().requires_grad_() 43 | tri_ret = ref_ret.clone().detach().requires_grad_() 44 | return ref_ret, tri_ret 45 | 46 | 47 | def mask_tensor(x, mask, block, value=0): 48 | ret = x.clone() 49 | for h, i, j in zip(*(mask == 0).nonzero(as_tuple=True)): 50 | ret[:, h, i * block : (i + 1) * block, j * block : (j + 1) * block] = value 51 | return ret 52 | 53 | 54 | @pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"]) 55 | @pytest.mark.parametrize("TRANS_A", [False, True]) 56 | @pytest.mark.parametrize("TRANS_B", [False, True]) 57 | @pytest.mark.parametrize("BLOCK", [16, 32, 64]) 58 | @pytest.mark.parametrize("DTYPE", [torch.float16]) 59 | def test_matmul( 60 | MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, device, Z=3, H=2, M=512, N=384, K=256 61 | ): 62 | seed = 0 63 | torch.manual_seed(seed) 64 | is_sdd = MODE == "sdd" 65 | is_dsd = MODE == "dsd" 66 | is_dds = MODE == "dds" 67 | do_sparsify = lambda x: sparsify_tensor(x, layout, BLOCK) 68 | do_mask = lambda x: mask_tensor(x, layout, BLOCK) 69 | # create inputs 70 | # create op 71 | a_shape = (Z, H, K, M) if TRANS_A else (Z, H, M, K) 72 | b_shape = (Z, H, N, K) if TRANS_B else (Z, H, K, N) 73 | c_shape = (Z, H, M, N) 74 | shape = { 75 | "sdd": (M, N), 76 | "dsd": (a_shape[2], a_shape[3]), 77 | "dds": (b_shape[2], b_shape[3]), 78 | }[MODE] 79 | layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK)) 80 | layout[1, 2, :] = 0 81 | layout[1, :, 1] = 0 82 | # create data 83 | a_ref, a_tri = make_pair(a_shape, alpha=0.1, dtype=DTYPE) 84 | b_ref, b_tri = make_pair(b_shape, alpha=0.1, dtype=DTYPE) 85 | dc_ref, dc_tri = make_pair(c_shape, dtype=DTYPE) 86 | # compute [torch] 87 | dc_ref = do_mask(dc_ref) if is_sdd else dc_ref 88 | a_ref = do_mask(a_ref) if is_dsd else a_ref 89 | b_ref = do_mask(b_ref) if is_dds else b_ref 90 | a_ref.retain_grad() 91 | b_ref.retain_grad() 92 | c_ref = torch.matmul( 93 | a_ref.transpose(2, 3) if TRANS_A else a_ref, 94 | b_ref.transpose(2, 3) if TRANS_B else b_ref, 95 | ) 96 | c_ref.backward(dc_ref) 97 | c_ref = do_sparsify(c_ref) if is_sdd else c_ref 98 | da_ref = do_sparsify(a_ref.grad) if is_dsd else a_ref.grad 99 | db_ref = do_sparsify(b_ref.grad) if is_dds else b_ref.grad 100 | # triton result 101 | dc_tri = do_sparsify(dc_tri) if is_sdd else dc_tri 102 | a_tri = do_sparsify(a_tri) if is_dsd else a_tri 103 | b_tri = do_sparsify(b_tri) if is_dds else b_tri 104 | a_tri.retain_grad() 105 | b_tri.retain_grad() 106 | op = triton.ops.blocksparse.matmul( 107 | layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B, device=device 108 | ) 109 | c_tri = op(a_tri, b_tri) 110 | c_tri.backward(dc_tri) 111 | da_tri = a_tri.grad 112 | db_tri = b_tri.grad 113 | 114 | # Bigger tolerance for AMD MI200 devices. 115 | # MI200 devices use reduced precision fp16 and bf16 and flush input and 116 | # output denormal values to zero. 
Detailed info is at: https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices 117 | tol = {"atol": 1e-3, "rtol": 0} if is_hip_mi200() else {} 118 | 119 | # compare 120 | torch.testing.assert_close(c_ref, c_tri, **tol) 121 | torch.testing.assert_close(da_ref, da_tri, **tol) 122 | torch.testing.assert_close(db_ref, db_tri, **tol) 123 | 124 | 125 | configs = [ 126 | (16, 256), 127 | (32, 576), 128 | (64, 1871), 129 | (128, 2511), 130 | ] 131 | 132 | 133 | @pytest.mark.parametrize("is_dense", [False, True]) 134 | @pytest.mark.parametrize("BLOCK, WIDTH", configs) 135 | def test_softmax(BLOCK, WIDTH, is_dense, device, Z=2, H=2, is_causal=True, scale=0.4): 136 | # set seed 137 | torch.random.manual_seed(0) 138 | Z, H, M, N = 2, 3, WIDTH, WIDTH 139 | # initialize layout 140 | # make sure each row has at least one non-zero element 141 | layout = torch.randint(2, (H, M // BLOCK, N // BLOCK)) 142 | if is_dense: 143 | layout[:] = 1 144 | else: 145 | layout[1, 2, :] = 0 146 | layout[1, :, 1] = 0 147 | # initialize data 148 | a_shape = (Z, H, M, N) 149 | a_ref, a_tri = make_pair(a_shape) 150 | dout_ref, dout_tri = make_pair(a_shape) 151 | # compute [torch] 152 | a_ref = mask_tensor(a_ref, layout, BLOCK, value=float("-inf")) 153 | a_ref.retain_grad() 154 | at_mask = torch.ones((M, N), device=device) 155 | if is_causal: 156 | at_mask = torch.tril(at_mask) 157 | M = at_mask[None, None, :, :] + torch.zeros_like(a_ref) 158 | a_ref[M == 0] = float("-inf") 159 | out_ref = torch.softmax(a_ref * scale, -1) 160 | out_ref.backward(dout_ref) 161 | out_ref = sparsify_tensor(out_ref, layout, BLOCK) 162 | da_ref = sparsify_tensor(a_ref.grad, layout, BLOCK) 163 | # compute [triton] 164 | a_tri = sparsify_tensor(a_tri, layout, BLOCK) 165 | a_tri.retain_grad() 166 | dout_tri = sparsify_tensor(dout_tri, layout, BLOCK) 167 | op = triton.ops.blocksparse.softmax(layout, BLOCK, device=device, is_dense=is_dense) 168 | out_tri = op(a_tri, scale=scale, is_causal=is_causal) 169 | out_tri.backward(dout_tri) 170 | da_tri = a_tri.grad 171 | # compare 172 | torch.testing.assert_close(out_tri, out_ref, equal_nan=True) 173 | torch.testing.assert_close(da_tri, da_ref, equal_nan=True) 174 | 175 | 176 | @pytest.mark.parametrize("block", [16, 32, 64]) 177 | @pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) 178 | def test_attention_fwd_bwd( 179 | block, 180 | dtype, 181 | device, 182 | input_scale=1.0, 183 | scale=1 / 8.0, 184 | n_ctx=256, 185 | batch_size=2, 186 | n_heads=2, 187 | ): 188 | capability = torch.cuda.get_device_capability() 189 | if capability[0] < 7: 190 | pytest.skip("Only test tl.dot() on devices with sm >= 70") 191 | 192 | # inputs 193 | qkv_shape = (batch_size, n_heads, n_ctx, 64) 194 | qkvs = [ 195 | torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True) 196 | .to(dtype) 197 | .cuda() 198 | for _ in range(3) 199 | ] 200 | 201 | # Triton: 202 | n_blocks = n_ctx // block 203 | layout = torch.tril(torch.ones([n_heads, n_blocks, n_blocks], dtype=torch.long)) 204 | query, key, value = [x.clone() for x in qkvs] 205 | query.retain_grad() 206 | key.retain_grad() 207 | value.retain_grad() 208 | attn_out = triton_attention( 209 | layout, block, query=query, key=key, value=value, scale=scale 210 | ) 211 | # ad hoc loss 212 | loss = (attn_out**2).mean() 213 | loss.backward() 214 | grads = [query.grad, key.grad, value.grad] 215 | 216 | # Torch version: 217 | torch_q, torch_k, torch_v = [x.clone() for x in 
qkvs] 218 | attn_mask = torch.ones([n_ctx, n_ctx], device=device, dtype=dtype) 219 | attn_mask = torch.tril(attn_mask, diagonal=0) 220 | attn_mask = 1e6 * (-1 + (attn_mask.reshape((1, 1, n_ctx, n_ctx)).cuda())) 221 | torch_q.retain_grad() 222 | torch_k.retain_grad() 223 | torch_v.retain_grad() 224 | scores = scale * torch.einsum("bhsd,bhtd->bhst", torch_q, torch_k) 225 | scores = scores + attn_mask 226 | probs = torch.softmax(scores, dim=-1) 227 | torch_attn_out = torch.einsum("bhst,bhtd->bhsd", probs, torch_v) 228 | # ad hoc loss 229 | torch_loss = (torch_attn_out**2).mean() 230 | torch_loss.backward() 231 | torch_grads = [torch_q.grad, torch_k.grad, torch_v.grad] 232 | 233 | # comparison 234 | # print(f"Triton loss {loss} and torch loss {torch_loss}. Also checking grads...") 235 | torch.testing.assert_close(loss, torch_loss, atol=1e-3, rtol=0) 236 | 237 | # Bigger tolerance for AMD MI200 devices. 238 | # MI200 devices use reduced precision fp16 and bf16 and flush input and 239 | # output denormal values to zero. Detailed info is at: https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices 240 | tol = {"atol": 1e-3, "rtol": 0} if is_hip_mi200() else {} 241 | for g1, g2 in zip(grads, torch_grads): 242 | torch.testing.assert_close(g1, g2, **tol) 243 | 244 | 245 | @pytest.mark.parametrize("block", [16, 32, 64]) 246 | def triton_attention( 247 | layout, 248 | block: int, 249 | query: torch.Tensor, 250 | key: torch.Tensor, 251 | value: torch.Tensor, 252 | scale: float, 253 | ): 254 | sparse_dot_sdd_nt = triton.ops.blocksparse.matmul( 255 | layout, block, "sdd", trans_a=False, trans_b=True, device=value.device 256 | ) 257 | sparse_dot_dsd_nn = triton.ops.blocksparse.matmul( 258 | layout, block, "dsd", trans_a=False, trans_b=False, device=value.device 259 | ) 260 | sparse_softmax = triton.ops.blocksparse.softmax(layout, block, device=value.device) 261 | 262 | w = sparse_dot_sdd_nt(query, key) 263 | w = sparse_softmax(w, scale=scale, is_causal=True) 264 | a = sparse_dot_dsd_nn(w, value) 265 | return a 266 | -------------------------------------------------------------------------------- /models/llama/llama/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. 
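#
# This file defines the Llama 3 model proper: the ModelArgs hyper-parameters,
# grouped-query Attention with a statically allocated KV cache, a SwiGLU
# FeedForward block, TransformerBlock, and the Transformer wrapper. Rotary
# embeddings, RMSNorm and the attention product are delegated to MathOps
# (math_ops.py), which is constructed with the use_triton flag so a
# Triton-backed implementation can be swapped in for the plain PyTorch path.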
3 | 4 | import math 5 | from dataclasses import dataclass 6 | from typing import Optional 7 | 8 | import fairscale.nn.model_parallel.initialize as fs_init 9 | import torch 10 | import torch.nn.functional as F 11 | from fairscale.nn.model_parallel.layers import ( 12 | ColumnParallelLinear, 13 | RowParallelLinear, 14 | VocabParallelEmbedding, 15 | ) 16 | from torch import nn 17 | from .math_ops import MathOps 18 | from benchmarking import Profiler 19 | 20 | 21 | @dataclass 22 | class ModelArgs: 23 | dim: int = 4096 24 | n_layers: int = 32 25 | n_heads: int = 32 26 | n_kv_heads: Optional[int] = None 27 | vocab_size: int = -1 28 | multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 29 | ffn_dim_multiplier: Optional[float] = None 30 | norm_eps: float = 1e-5 31 | rope_theta: float = 500000 32 | 33 | max_batch_size: int = 32 34 | max_seq_len: int = 2048 35 | 36 | 37 | def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: 38 | """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" 39 | bs, slen, n_kv_heads, head_dim = x.shape 40 | if n_rep == 1: 41 | return x 42 | return ( 43 | x[:, :, :, None, :] 44 | .expand(bs, slen, n_kv_heads, n_rep, head_dim) 45 | .reshape(bs, slen, n_kv_heads * n_rep, head_dim) 46 | ) 47 | 48 | 49 | class Attention(nn.Module): 50 | def __init__(self, args: ModelArgs, use_triton=False): 51 | super().__init__() 52 | self.use_triton = use_triton 53 | self.Math = MathOps(use_triton) 54 | self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads 55 | model_parallel_size = fs_init.get_model_parallel_world_size() 56 | self.n_local_heads = args.n_heads // model_parallel_size 57 | self.n_local_kv_heads = self.n_kv_heads // model_parallel_size 58 | self.n_rep = self.n_local_heads // self.n_local_kv_heads 59 | self.head_dim = args.dim // args.n_heads 60 | 61 | self.wq = ColumnParallelLinear( 62 | args.dim, 63 | args.n_heads * self.head_dim, 64 | bias=False, 65 | gather_output=False, 66 | init_method=lambda x: x, 67 | ) 68 | self.wk = ColumnParallelLinear( 69 | args.dim, 70 | self.n_kv_heads * self.head_dim, 71 | bias=False, 72 | gather_output=False, 73 | init_method=lambda x: x, 74 | ) 75 | self.wv = ColumnParallelLinear( 76 | args.dim, 77 | self.n_kv_heads * self.head_dim, 78 | bias=False, 79 | gather_output=False, 80 | init_method=lambda x: x, 81 | ) 82 | self.wo = RowParallelLinear( 83 | args.n_heads * self.head_dim, 84 | args.dim, 85 | bias=False, 86 | input_is_parallel=True, 87 | init_method=lambda x: x, 88 | ) 89 | 90 | self.cache_k = torch.zeros( 91 | ( 92 | args.max_batch_size, 93 | args.max_seq_len, 94 | self.n_local_kv_heads, 95 | self.head_dim, 96 | ) 97 | ).cuda() 98 | self.cache_v = torch.zeros( 99 | ( 100 | args.max_batch_size, 101 | args.max_seq_len, 102 | self.n_local_kv_heads, 103 | self.head_dim, 104 | ) 105 | ).cuda() 106 | 107 | @Profiler.profiling_decorator(record_name="attention_forward", skip_profiling=True) 108 | def forward( 109 | self, 110 | x: torch.Tensor, 111 | start_pos: int, 112 | freqs_cis: torch.Tensor, 113 | mask: Optional[torch.Tensor], 114 | ): 115 | bsz, seqlen, _ = x.shape 116 | xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) 117 | 118 | xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) 119 | xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) 120 | xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) 121 | 122 | xq, xk = self.Math.apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) 123 | 124 | self.cache_k = self.cache_k.to(xq) 125 | self.cache_v = 
self.cache_v.to(xq) 126 | 127 | self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk 128 | self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv 129 | 130 | keys = self.cache_k[:bsz, : start_pos + seqlen] 131 | values = self.cache_v[:bsz, : start_pos + seqlen] 132 | 133 | # repeat k/v heads if n_kv_heads < n_heads 134 | keys = repeat_kv( 135 | keys, self.n_rep 136 | ) # (bs, cache_len + seqlen, n_local_heads, head_dim) 137 | values = repeat_kv( 138 | values, self.n_rep 139 | ) # (bs, cache_len + seqlen, n_local_heads, head_dim) 140 | 141 | xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) 142 | keys = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) 143 | values = values.transpose( 144 | 1, 2 145 | ) # (bs, n_local_heads, cache_len + seqlen, head_dim) 146 | output = self.Math.attention(xq, keys, values, self.head_dim, mask) 147 | output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) 148 | return self.wo(output) 149 | 150 | 151 | class FeedForward(nn.Module): 152 | def __init__( 153 | self, 154 | dim: int, 155 | hidden_dim: int, 156 | multiple_of: int, 157 | ffn_dim_multiplier: Optional[float], 158 | ): 159 | super().__init__() 160 | hidden_dim = int(2 * hidden_dim / 3) 161 | # custom dim factor multiplier 162 | if ffn_dim_multiplier is not None: 163 | hidden_dim = int(ffn_dim_multiplier * hidden_dim) 164 | hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) 165 | 166 | self.w1 = ColumnParallelLinear( 167 | dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x 168 | ) 169 | self.w2 = RowParallelLinear( 170 | hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x 171 | ) 172 | self.w3 = ColumnParallelLinear( 173 | dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x 174 | ) 175 | 176 | @Profiler.profiling_decorator( 177 | record_name="feed_forward_forward", skip_profiling=True 178 | ) 179 | def forward(self, x): 180 | return self.w2(F.silu(self.w1(x)) * self.w3(x)) 181 | 182 | 183 | class TransformerBlock(nn.Module): 184 | def __init__(self, layer_id: int, args: ModelArgs, use_triton=False): 185 | super().__init__() 186 | self.use_triton = use_triton 187 | self.Math = MathOps(use_triton) 188 | 189 | self.n_heads = args.n_heads 190 | self.dim = args.dim 191 | self.head_dim = args.dim // args.n_heads 192 | self.attention = Attention(args, use_triton=self.use_triton) 193 | self.feed_forward = FeedForward( 194 | dim=args.dim, 195 | hidden_dim=4 * args.dim, 196 | multiple_of=args.multiple_of, 197 | ffn_dim_multiplier=args.ffn_dim_multiplier, 198 | ) 199 | self.layer_id = layer_id 200 | self.attention_norm = self.Math.get_rms_norm(args.dim, eps=args.norm_eps) 201 | self.ffn_norm = self.Math.get_rms_norm(args.dim, eps=args.norm_eps) 202 | 203 | @Profiler.profiling_decorator( 204 | record_name="transform_block_forward", skip_profiling=True 205 | ) 206 | def forward( 207 | self, 208 | x: torch.Tensor, 209 | start_pos: int, 210 | freqs_cis: torch.Tensor, 211 | mask: Optional[torch.Tensor], 212 | ): 213 | h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask) 214 | out = h + self.feed_forward(self.ffn_norm(h)) 215 | return out 216 | 217 | 218 | class Transformer(nn.Module): 219 | def __init__(self, params: ModelArgs, use_triton=False): 220 | super().__init__() 221 | self.use_triton = use_triton 222 | self.Math = MathOps(use_triton) 223 | self.params = params 224 | self.vocab_size = params.vocab_size 225 | self.n_layers = params.n_layers 226 | 
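        # Model skeleton built below: vocab-parallel token embeddings, a stack of
        # TransformerBlocks, a final RMSNorm, a column-parallel output projection
        # back to vocab size, and rotary frequencies precomputed once for up to
        # 2 * max_seq_len positions.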
227 | self.tok_embeddings = VocabParallelEmbedding( 228 | params.vocab_size, params.dim, init_method=lambda x: x 229 | ) 230 | 231 | self.layers = torch.nn.ModuleList() 232 | for layer_id in range(params.n_layers): 233 | self.layers.append(TransformerBlock(layer_id, params, self.use_triton)) 234 | 235 | self.norm = self.Math.get_rms_norm(params.dim, eps=params.norm_eps) 236 | self.output = ColumnParallelLinear( 237 | params.dim, params.vocab_size, bias=False, init_method=lambda x: x 238 | ) 239 | 240 | self.freqs_cis = self.Math.precompute_freqs_cis( 241 | params.dim // params.n_heads, 242 | params.max_seq_len * 2, 243 | params.rope_theta, 244 | ) 245 | 246 | @torch.inference_mode() 247 | @Profiler.profiling_decorator( 248 | record_name="transformer_forward", skip_profiling=True 249 | ) 250 | def forward(self, tokens: torch.Tensor, start_pos: int): 251 | _bsz, seqlen = tokens.shape 252 | h = self.tok_embeddings(tokens) 253 | self.freqs_cis = self.freqs_cis.to(h.device) 254 | freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen] 255 | 256 | mask = None 257 | if seqlen > 1: 258 | mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device) 259 | 260 | mask = torch.triu(mask, diagonal=1) 261 | 262 | # When performing key-value caching, we compute the attention scores 263 | # only for the new sequence. Thus, the matrix of scores is of size 264 | # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for 265 | # j > cache_len + i, since row i corresponds to token cache_len + i. 266 | mask = torch.hstack( 267 | [torch.zeros((seqlen, start_pos), device=tokens.device), mask] 268 | ).type_as(h) 269 | 270 | for layer in self.layers: 271 | h = layer(h, start_pos, freqs_cis, mask) 272 | h = self.norm(h) 273 | output = self.output(h).float() 274 | return output 275 | -------------------------------------------------------------------------------- /kernels/matmul.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from triton import Config, autotune, cdiv, heuristics, jit 4 | from triton import language as tl 5 | from .matmul_perf_model import early_config_prune, estimate_matmul_time 6 | 7 | _ordered_datatypes = [torch.int8, torch.float16, torch.bfloat16, torch.float32] 8 | 9 | 10 | def upcast_if_fp8(a): 11 | if "fp8" in str(a): 12 | return torch.float16 13 | return a 14 | 15 | 16 | def get_higher_dtype(a, b): 17 | a = upcast_if_fp8(a) 18 | b = upcast_if_fp8(b) 19 | if a is b: 20 | return a 21 | 22 | assert a in _ordered_datatypes 23 | assert b in _ordered_datatypes 24 | 25 | for d in _ordered_datatypes: 26 | if a is d: 27 | return b 28 | if b is d: 29 | return a 30 | 31 | 32 | def init_to_zero(name): 33 | return lambda nargs: nargs[name].zero_() 34 | 35 | 36 | def get_configs_io_bound(): 37 | configs = [] 38 | for num_stages in [2, 3, 4, 5, 6]: 39 | for block_m in [16, 32]: 40 | for block_k in [32, 64]: 41 | for block_n in [32, 64, 128, 256]: 42 | num_warps = 2 if block_n <= 64 else 4 43 | configs.append( 44 | Config( 45 | { 46 | "BLOCK_M": block_m, 47 | "BLOCK_N": block_n, 48 | "BLOCK_K": block_k, 49 | "SPLIT_K": 1, 50 | }, 51 | num_stages=num_stages, 52 | num_warps=num_warps, 53 | ) 54 | ) 55 | # split_k 56 | for split_k in [2, 4, 8, 16]: 57 | configs.append( 58 | Config( 59 | { 60 | "BLOCK_M": block_m, 61 | "BLOCK_N": block_n, 62 | "BLOCK_K": block_k, 63 | "SPLIT_K": split_k, 64 | }, 65 | num_stages=num_stages, 66 | num_warps=num_warps, 67 | pre_hook=init_to_zero("C"), 68 | ) 69 | ) 70 | return configs 71 | 72 
| 73 | @autotune( 74 | configs=[ 75 | # basic configs for compute-bound matmuls 76 | Config( 77 | {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, 78 | num_stages=3, 79 | num_warps=8, 80 | ), 81 | Config( 82 | {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, 83 | num_stages=3, 84 | num_warps=8, 85 | ), 86 | Config( 87 | {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, 88 | num_stages=4, 89 | num_warps=4, 90 | ), 91 | Config( 92 | {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, 93 | num_stages=4, 94 | num_warps=4, 95 | ), 96 | Config( 97 | {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, 98 | num_stages=4, 99 | num_warps=4, 100 | ), 101 | Config( 102 | {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, 103 | num_stages=4, 104 | num_warps=4, 105 | ), 106 | Config( 107 | {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, 108 | num_stages=4, 109 | num_warps=4, 110 | ), 111 | Config( 112 | {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, 113 | num_stages=4, 114 | num_warps=4, 115 | ), 116 | Config( 117 | {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, 118 | num_stages=5, 119 | num_warps=2, 120 | ), 121 | # good for int8 122 | Config( 123 | {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, 124 | num_stages=3, 125 | num_warps=8, 126 | ), 127 | Config( 128 | {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, 129 | num_stages=3, 130 | num_warps=8, 131 | ), 132 | Config( 133 | {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, 134 | num_stages=4, 135 | num_warps=4, 136 | ), 137 | Config( 138 | {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, 139 | num_stages=4, 140 | num_warps=4, 141 | ), 142 | Config( 143 | {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, 144 | num_stages=4, 145 | num_warps=4, 146 | ), 147 | Config( 148 | {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, 149 | num_stages=4, 150 | num_warps=4, 151 | ), 152 | Config( 153 | {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, 154 | num_stages=4, 155 | num_warps=4, 156 | ), 157 | Config( 158 | {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, 159 | num_stages=4, 160 | num_warps=4, 161 | ), 162 | Config( 163 | {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, 164 | num_stages=5, 165 | num_warps=2, 166 | ), 167 | ] 168 | + get_configs_io_bound(), 169 | key=["M", "N", "K"], 170 | prune_configs_by={ 171 | "early_config_prune": early_config_prune, 172 | "perf_model": estimate_matmul_time, 173 | "top_k": 10, 174 | }, 175 | ) 176 | @heuristics( 177 | { 178 | "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, 179 | } 180 | ) 181 | @jit 182 | def _kernel( 183 | A, 184 | B, 185 | C, 186 | M, 187 | N, 188 | K, # 189 | stride_am, 190 | stride_ak, # 191 | stride_bk, 192 | stride_bn, # 193 | stride_cm, 194 | stride_cn, # 195 | acc_dtype: tl.constexpr, # 196 | input_precision: tl.constexpr, # 197 | fp8_fast_accum: tl.constexpr, # 198 | BLOCK_M: tl.constexpr, 199 | BLOCK_N: tl.constexpr, 200 | BLOCK_K: tl.constexpr, # 201 | GROUP_M: tl.constexpr, 202 | SPLIT_K: tl.constexpr, 203 | EVEN_K: tl.constexpr, 204 | AB_DTYPE: tl.constexpr, # 205 | ): 206 | # matrix multiplication 207 | pid = tl.program_id(0) 208 | pid_z = tl.program_id(1) 209 | grid_m = tl.cdiv(M, BLOCK_M) 210 | grid_n = tl.cdiv(N, BLOCK_N) 211 | # re-order program ID for better L2 performance 212 | width = GROUP_M * grid_n 213 | group_id = 
pid // width 214 | group_size = min(grid_m - group_id * GROUP_M, GROUP_M) 215 | pid_m = group_id * GROUP_M + (pid % group_size) 216 | pid_n = (pid % width) // (group_size) 217 | # do matrix multiplication 218 | rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) 219 | rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) 220 | ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) 221 | rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) 222 | rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) 223 | # pointers 224 | A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) 225 | B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) 226 | acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype) 227 | for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): 228 | if EVEN_K: 229 | a = tl.load(A) 230 | b = tl.load(B) 231 | else: 232 | k_remaining = K - k * (BLOCK_K * SPLIT_K) 233 | _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty) 234 | a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0) 235 | b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0) 236 | if AB_DTYPE is not None: 237 | a = a.to(AB_DTYPE) 238 | b = b.to(AB_DTYPE) 239 | if fp8_fast_accum: 240 | acc = tl.dot( 241 | a, b, acc, out_dtype=acc_dtype, input_precision=input_precision 242 | ) 243 | else: 244 | acc += tl.dot(a, b, out_dtype=acc_dtype, input_precision=input_precision) 245 | A += BLOCK_K * SPLIT_K * stride_ak 246 | B += BLOCK_K * SPLIT_K * stride_bk 247 | acc = acc.to(C.dtype.element_ty) 248 | # rematerialize rm and rn to save registers 249 | rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) 250 | rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) 251 | C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) 252 | mask = (rm < M)[:, None] & (rn < N)[None, :] 253 | # handles write-back with reduction-splitting 254 | if SPLIT_K == 1: 255 | tl.store(C, acc, mask=mask) 256 | else: 257 | tl.atomic_add(C, acc, mask=mask) 258 | 259 | 260 | class _matmul(torch.autograd.Function): 261 | kernel = _kernel 262 | 263 | _locks = {} 264 | 265 | @staticmethod 266 | def _call(a, b, acc_dtype, input_precision, fp8_fast_accum, output_dtype): 267 | device = a.device 268 | # handle non-contiguous inputs if necessary 269 | if a.stride(0) > 1 and a.stride(1) > 1: 270 | a = a.contiguous() 271 | if b.stride(0) > 1 and b.stride(1) > 1: 272 | b = b.contiguous() 273 | # checks constraints 274 | assert ( 275 | a.shape[1] == b.shape[0] 276 | ), f"incompatible dimensions {a.shape} and {b.shape}" 277 | M, K = a.shape 278 | _, N = b.shape 279 | 280 | # common type between a and b 281 | ab_dtype = get_higher_dtype(a.dtype, b.dtype) 282 | 283 | # allocates output 284 | if output_dtype is None: 285 | output_dtype = ab_dtype 286 | 287 | c = torch.empty((M, N), device=device, dtype=output_dtype) 288 | 289 | # Allowed types for acc_type given the types of a and b. 
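        # (int8 accumulates in int32; fp16/bf16 may accumulate either in fp32 for
        # accuracy or in their own width for speed; fp32 accumulates in fp32 only.)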
290 | supported_acc_dtypes = { 291 | torch.float16: (torch.float32, torch.float16), 292 | torch.bfloat16: (torch.float32, torch.bfloat16), 293 | torch.float32: (torch.float32,), 294 | torch.int8: (torch.int32,), 295 | } 296 | 297 | if acc_dtype is None: 298 | acc_dtype = supported_acc_dtypes[ab_dtype][0] 299 | else: 300 | assert isinstance(acc_dtype, torch.dtype), "acc_dtype must be a torch.dtype" 301 | assert ( 302 | acc_dtype in supported_acc_dtypes[a.dtype] 303 | ), "acc_dtype not compatible with the type of a" 304 | assert ( 305 | acc_dtype in supported_acc_dtypes[b.dtype] 306 | ), "acc_dtype not compatible with the type of b" 307 | 308 | def to_tl_type(ty): 309 | return getattr(tl, str(ty).split(".")[-1]) 310 | 311 | acc_dtype = to_tl_type(acc_dtype) 312 | ab_dtype = to_tl_type(ab_dtype) 313 | output_dtype = to_tl_type(output_dtype) 314 | 315 | # Tensor cores support input with mixed float8 types. 316 | if a.dtype in [tl.float8e4nv, tl.float8e5] and b.dtype in [ 317 | tl.float8e4nv, 318 | tl.float8e5, 319 | ]: 320 | ab_dtype = None 321 | # launch kernel 322 | grid = lambda META: ( 323 | cdiv(M, META["BLOCK_M"]) * cdiv(N, META["BLOCK_N"]), 324 | META["SPLIT_K"], 325 | ) 326 | _kernel[grid]( 327 | a, 328 | b, 329 | c, 330 | M, 331 | N, 332 | K, # 333 | a.stride(0), 334 | a.stride(1), # 335 | b.stride(0), 336 | b.stride(1), # 337 | c.stride(0), 338 | c.stride(1), # 339 | acc_dtype=acc_dtype, # 340 | input_precision=input_precision, # 341 | fp8_fast_accum=fp8_fast_accum, # 342 | GROUP_M=8, 343 | AB_DTYPE=ab_dtype, 344 | ) 345 | return c 346 | 347 | @staticmethod 348 | def forward( 349 | ctx, 350 | a, 351 | b, 352 | acc_dtype=None, 353 | input_precision=None, 354 | fp8_fast_accum=True, 355 | output_dtype=None, 356 | ): 357 | return _matmul._call( 358 | a, 359 | b, 360 | acc_dtype=acc_dtype, 361 | input_precision=input_precision, 362 | fp8_fast_accum=fp8_fast_accum, 363 | output_dtype=output_dtype, 364 | ) 365 | 366 | 367 | matmul = _matmul.apply 368 | -------------------------------------------------------------------------------- /models/llama/llama/generation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. 
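#
# This module wraps checkpoint loading (Llama.build) and autoregressive decoding
# (generate / text_completion / chat_completion) around the Transformer defined in
# model.py. A minimal usage sketch — paths are hypothetical, and it assumes a
# torchrun-style launch so the distributed environment variables are set:
#
#     generator = Llama.build(
#         ckpt_dir="Meta-Llama-3-8B/",
#         tokenizer_path="Meta-Llama-3-8B/tokenizer.model",
#         max_seq_len=512,
#         max_batch_size=4,
#         use_triton=True,
#     )
#     results = generator.text_completion(["Triton kernels are"], max_gen_len=32)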
3 | 4 | import json 5 | import os 6 | import sys 7 | import time 8 | from pathlib import Path 9 | from typing import List, Optional, Tuple, TypedDict 10 | 11 | import torch 12 | from fairscale.nn.model_parallel.initialize import ( 13 | get_model_parallel_rank, 14 | initialize_model_parallel, 15 | model_parallel_is_initialized, 16 | ) 17 | 18 | from .model import ModelArgs, Transformer 19 | from .math_ops import MathOps 20 | from .tokenizer import ChatFormat, Dialog, Message, Tokenizer 21 | from benchmarking import Profiler 22 | 23 | 24 | class CompletionPrediction(TypedDict, total=False): 25 | generation: str 26 | tokens: List[str] # not required 27 | logprobs: List[float] # not required 28 | 29 | 30 | class ChatPrediction(TypedDict, total=False): 31 | generation: Message 32 | tokens: List[str] # not required 33 | logprobs: List[float] # not required 34 | 35 | 36 | class Llama: 37 | @staticmethod 38 | def build( 39 | ckpt_dir: str, 40 | tokenizer_path: str, 41 | max_seq_len: int, 42 | max_batch_size: int, 43 | model_parallel_size: Optional[int] = None, 44 | seed: int = 1, 45 | use_triton: bool = False, 46 | ) -> "Llama": 47 | """ 48 | Build a Llama instance by initializing and loading a model checkpoint. 49 | 50 | Args: 51 | ckpt_dir (str): Path to the directory containing checkpoint files. 52 | tokenizer_path (str): Path to the tokenizer file. 53 | max_seq_len (int): Maximum sequence length for input text. 54 | max_batch_size (int): Maximum batch size for inference. 55 | model_parallel_size (Optional[int], optional): Number of model parallel processes. 56 | If not provided, it's determined from the environment. Defaults to None. 57 | 58 | Returns: 59 | Llama: An instance of the Llama class with the loaded model and tokenizer. 60 | 61 | Raises: 62 | AssertionError: If there are no checkpoint files in the specified directory, 63 | or if the model parallel size does not match the number of checkpoint files. 64 | 65 | Note: 66 | This method initializes the distributed process group, sets the device to CUDA, 67 | and loads the pre-trained model and tokenizer. 68 | """ 69 | assert ( 70 | 1 <= max_seq_len <= 8192 71 | ), f"max_seq_len must be between 1 and 8192, got {max_seq_len}." 72 | assert os.path.isdir( 73 | ckpt_dir 74 | ), f"Checkpoint directory '{ckpt_dir}' does not exist." 75 | assert os.path.isfile( 76 | tokenizer_path 77 | ), f"Tokenizer file '{tokenizer_path}' does not exist." 
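        # Distributed / model-parallel setup: WORLD_SIZE and LOCAL_RANK are read
        # from the environment (typically set by torchrun), the CUDA device is
        # pinned to the local rank, the RNG seed is kept identical across ranks,
        # and stdout is silenced on non-zero ranks.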
78 | 79 | if not torch.distributed.is_initialized(): 80 | torch.distributed.init_process_group("nccl") 81 | if model_parallel_size is None: 82 | model_parallel_size = int(os.environ.get("WORLD_SIZE", 1)) 83 | if not model_parallel_is_initialized(): 84 | initialize_model_parallel(model_parallel_size) 85 | 86 | local_rank = int(os.environ.get("LOCAL_RANK", 0)) 87 | torch.cuda.set_device(local_rank) 88 | 89 | # seed must be the same in all processes 90 | torch.manual_seed(seed) 91 | 92 | if local_rank > 0: 93 | sys.stdout = open(os.devnull, "w") 94 | 95 | start_time = time.time() 96 | checkpoints = sorted(Path(ckpt_dir).glob("*.pth")) 97 | assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}" 98 | assert model_parallel_size == len( 99 | checkpoints 100 | ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}" 101 | ckpt_path = checkpoints[get_model_parallel_rank()] 102 | checkpoint = torch.load(ckpt_path, map_location="cpu") 103 | with open(Path(ckpt_dir) / "params.json", "r") as f: 104 | params = json.loads(f.read()) 105 | 106 | model_args: ModelArgs = ModelArgs( 107 | max_seq_len=max_seq_len, 108 | max_batch_size=max_batch_size, 109 | **params, 110 | ) 111 | tokenizer = Tokenizer(model_path=tokenizer_path) 112 | assert model_args.vocab_size == tokenizer.n_words 113 | if torch.cuda.is_bf16_supported(): 114 | torch.set_default_tensor_type(torch.cuda.BFloat16Tensor) 115 | else: 116 | torch.set_default_tensor_type(torch.cuda.HalfTensor) 117 | model = Transformer(model_args, use_triton=use_triton) 118 | model.load_state_dict(checkpoint, strict=False) 119 | print(f"Loaded in {time.time() - start_time:.2f} seconds") 120 | 121 | return Llama(model, tokenizer, use_triton) 122 | 123 | def __init__( 124 | self, model: Transformer, tokenizer: Tokenizer, use_triton: bool = False 125 | ): 126 | self.model = model 127 | self.tokenizer = tokenizer 128 | self.formatter = ChatFormat(tokenizer) 129 | self.use_triton = use_triton 130 | self.Math = MathOps(use_triton) 131 | 132 | @torch.inference_mode() 133 | @Profiler.profiling_decorator(record_name="generate", skip_profiling=True) 134 | def generate( 135 | self, 136 | prompt_tokens: List[List[int]], 137 | max_gen_len: int, 138 | temperature: float = 0.6, 139 | top_p: float = 0.9, 140 | logprobs: bool = False, 141 | echo: bool = False, 142 | ) -> Tuple[List[List[int]], Optional[List[List[float]]]]: 143 | """ 144 | Generate text sequences based on provided prompts using the language generation model. 145 | 146 | Args: 147 | prompt_tokens (List[List[int]]): List of tokenized prompts, where each prompt is represented as a list of integers. 148 | max_gen_len (int): Maximum length of the generated text sequence. 149 | temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6. 150 | top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9. 151 | logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False. 152 | echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False. 153 | 154 | Returns: 155 | Tuple[List[List[int]], Optional[List[List[float]]]]: A tuple containing generated token sequences and, if logprobs is True, corresponding token log probabilities. 156 | 157 | Note: 158 | This method uses the provided prompts as a basis for generating text. It employs nucleus sampling to produce text with controlled randomness. 
159 | If logprobs is True, token log probabilities are computed for each generated token. 160 | 161 | """ 162 | params = self.model.params 163 | bsz = len(prompt_tokens) 164 | assert bsz <= params.max_batch_size, (bsz, params.max_batch_size) 165 | 166 | min_prompt_len = min(len(t) for t in prompt_tokens) 167 | max_prompt_len = max(len(t) for t in prompt_tokens) 168 | assert max_prompt_len <= params.max_seq_len 169 | total_len = min(params.max_seq_len, max_gen_len + max_prompt_len) 170 | 171 | pad_id = self.tokenizer.pad_id 172 | tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda") 173 | for k, t in enumerate(prompt_tokens): 174 | tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda") 175 | if logprobs: 176 | token_logprobs = torch.zeros_like(tokens, dtype=torch.float) 177 | 178 | prev_pos = 0 179 | eos_reached = torch.tensor([False] * bsz, device="cuda") 180 | input_text_mask = tokens != pad_id 181 | if min_prompt_len == total_len: 182 | logits = self.model.forward(tokens, prev_pos) 183 | 184 | token_logprobs = self.Math.cross_entropy( 185 | input=logits.transpose(1, 2), 186 | target=tokens, 187 | reduction="none", 188 | ignore_index=pad_id, 189 | ) 190 | 191 | stop_tokens = torch.tensor(list(self.tokenizer.stop_tokens)) 192 | 193 | for cur_pos in range(min_prompt_len, total_len): 194 | logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos) 195 | if temperature > 0: 196 | probs = self.Math.softmax(logits[:, -1] / temperature, dim=-1) 197 | next_token = sample_top_p(probs, top_p) 198 | else: 199 | next_token = self.Math.argmax(logits[:, -1], dim=-1) 200 | 201 | next_token = next_token.reshape(-1) 202 | # only replace token if prompt has already been generated 203 | next_token = torch.where( 204 | input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token 205 | ) 206 | tokens[:, cur_pos] = next_token 207 | if logprobs: 208 | token_logprobs[:, prev_pos + 1 : cur_pos + 1] = self.Math.cross_entropy( 209 | input=logits.transpose(1, 2), 210 | target=tokens[:, prev_pos + 1 : cur_pos + 1], 211 | reduction="none", 212 | ignore_index=pad_id, 213 | ) 214 | eos_reached |= (~input_text_mask[:, cur_pos]) & ( 215 | torch.isin(next_token, stop_tokens) 216 | ) 217 | prev_pos = cur_pos 218 | if all(eos_reached): 219 | break 220 | 221 | if logprobs: 222 | token_logprobs = token_logprobs.tolist() 223 | out_tokens, out_logprobs = [], [] 224 | for i, toks in enumerate(tokens.tolist()): 225 | # cut to max gen len 226 | start = 0 if echo else len(prompt_tokens[i]) 227 | toks = toks[start : len(prompt_tokens[i]) + max_gen_len] 228 | probs = None 229 | if logprobs: 230 | probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len] 231 | # cut to after eos tok if any 232 | for stop_token in self.tokenizer.stop_tokens: 233 | try: 234 | eos_idx = toks.index(stop_token) 235 | toks = toks[:eos_idx] 236 | probs = probs[:eos_idx] if logprobs else None 237 | except ValueError: 238 | pass 239 | out_tokens.append(toks) 240 | out_logprobs.append(probs) 241 | return (out_tokens, out_logprobs if logprobs else None) 242 | 243 | @Profiler.profiling_decorator(skip_profiling=True) 244 | def text_completion( 245 | self, 246 | prompts: List[str], 247 | temperature: float = 0.6, 248 | top_p: float = 0.9, 249 | max_gen_len: Optional[int] = None, 250 | logprobs: bool = False, 251 | echo: bool = False, 252 | ) -> List[CompletionPrediction]: 253 | """ 254 | Perform text completion for a list of prompts using the language generation model. 
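        Prompts are tokenized with a leading BOS token, generate() is invoked once
        for the whole batch, and the decoded completions (plus per-token log
        probabilities when logprobs is True) are returned.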
255 | 256 | Args: 257 | prompts (List[str]): List of text prompts for completion. 258 | temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6. 259 | top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9. 260 | max_gen_len (Optional[int], optional): Maximum length of the generated completion sequence. 261 | If not provided, it's set to the model's maximum sequence length minus 1. 262 | logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False. 263 | echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False. 264 | 265 | Returns: 266 | List[CompletionPrediction]: List of completion predictions, each containing the generated text completion. 267 | 268 | Note: 269 | This method generates text completions for the provided prompts, employing nucleus sampling to introduce controlled randomness. 270 | If logprobs is True, token log probabilities are computed for each generated token. 271 | 272 | """ 273 | if max_gen_len is None: 274 | max_gen_len = self.model.params.max_seq_len - 1 275 | prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts] 276 | generation_tokens, generation_logprobs = self.generate( 277 | prompt_tokens=prompt_tokens, 278 | max_gen_len=max_gen_len, 279 | temperature=temperature, 280 | top_p=top_p, 281 | logprobs=logprobs, 282 | echo=echo, 283 | ) 284 | if logprobs: 285 | return [ 286 | { 287 | "generation": self.tokenizer.decode(t), 288 | "tokens": [self.tokenizer.decode([x]) for x in t], 289 | "logprobs": logprobs_i, 290 | } 291 | for t, logprobs_i in zip(generation_tokens, generation_logprobs) 292 | ] 293 | return [{"generation": self.tokenizer.decode(t)} for t in generation_tokens] 294 | 295 | @Profiler.profiling_decorator(skip_profiling=True) 296 | def chat_completion( 297 | self, 298 | dialogs: List[Dialog], 299 | temperature: float = 0.6, 300 | top_p: float = 0.9, 301 | max_gen_len: Optional[int] = None, 302 | logprobs: bool = False, 303 | ) -> List[ChatPrediction]: 304 | """ 305 | Generate assistant responses for a list of conversational dialogs using the language generation model. 306 | 307 | Args: 308 | dialogs (List[Dialog]): List of conversational dialogs, where each dialog is a list of messages. 309 | temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6. 310 | top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9. 311 | max_gen_len (Optional[int], optional): Maximum length of the generated response sequence. 312 | If not provided, it's set to the model's maximum sequence length minus 1. 313 | logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False. 314 | 315 | Returns: 316 | List[ChatPrediction]: List of chat predictions, each containing the assistant's generated response. 317 | 318 | Note: 319 | This method generates assistant responses for the provided conversational dialogs. 320 | It employs nucleus sampling to introduce controlled randomness in text generation. 321 | If logprobs is True, token log probabilities are computed for each generated token. 
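
        Example (illustrative only; the generated text will vary):
            >>> dialogs = [[{"role": "user", "content": "What is the capital of France?"}]]
            >>> llama.chat_completion(dialogs, max_gen_len=64)[0]["generation"]
            {'role': 'assistant', 'content': 'The capital of France is Paris.'}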
322 | """ 323 | if max_gen_len is None: 324 | max_gen_len = self.model.params.max_seq_len - 1 325 | 326 | prompt_tokens = [ 327 | self.formatter.encode_dialog_prompt(dialog) for dialog in dialogs 328 | ] 329 | generation_tokens, generation_logprobs = self.generate( 330 | prompt_tokens=prompt_tokens, 331 | max_gen_len=max_gen_len, 332 | temperature=temperature, 333 | top_p=top_p, 334 | logprobs=logprobs, 335 | ) 336 | if logprobs: 337 | return [ 338 | { 339 | "generation": { 340 | "role": "assistant", 341 | "content": self.tokenizer.decode(t), 342 | }, 343 | "tokens": [self.tokenizer.decode([x]) for x in t], 344 | "logprobs": logprobs_i, 345 | } 346 | for t, logprobs_i in zip(generation_tokens, generation_logprobs) 347 | ] 348 | return [ 349 | { 350 | "generation": { 351 | "role": "assistant", 352 | "content": self.tokenizer.decode(t), 353 | }, 354 | } 355 | for t in generation_tokens 356 | ] 357 | 358 | 359 | def sample_top_p(probs, p): 360 | """ 361 | Perform top-p (nucleus) sampling on a probability distribution. 362 | 363 | Args: 364 | probs (torch.Tensor): Probability distribution tensor. 365 | p (float): Probability threshold for top-p sampling. 366 | 367 | Returns: 368 | torch.Tensor: Sampled token indices. 369 | 370 | Note: 371 | Top-p sampling selects the smallest set of tokens whose cumulative probability mass 372 | exceeds the threshold p. The distribution is renormalized based on the selected tokens. 373 | """ 374 | probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) 375 | probs_sum = torch.cumsum(probs_sort, dim=-1) 376 | mask = probs_sum - probs_sort > p 377 | probs_sort[mask] = 0.0 378 | probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) 379 | next_token = torch.multinomial(probs_sort, num_samples=1) 380 | next_token = torch.gather(probs_idx, -1, next_token) 381 | return next_token 382 | -------------------------------------------------------------------------------- /kernels/blocksparse/matmul.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from triton import cdiv, heuristics, jit 4 | from triton import language as tl 5 | 6 | # ******************************************************** 7 | # -------------------------------------------------------- 8 | # Sparse = Dense x Dense (SDD) 9 | # This operation uses super-blocking to make sure that 10 | # it's done efficiently when small blocks can be grouped 11 | # together 12 | # -------------------------------------------------------- 13 | # ******************************************************** 14 | 15 | 16 | @heuristics( 17 | { 18 | "EVEN_K": lambda nargs: nargs["K"] % nargs["TILE_K"] == 0, 19 | } 20 | ) 21 | @jit 22 | def _sdd_kernel( 23 | A, 24 | B, 25 | C, # 26 | stride_za, 27 | stride_ha, 28 | stride_ma, 29 | stride_ak, # 30 | stride_zb, 31 | stride_hb, 32 | stride_bk, 33 | stride_nb, # 34 | stride_zc, 35 | stride_hc, 36 | stride_mc, 37 | stride_nc, # 38 | K, 39 | grid_offset, 40 | lut, # 41 | TILE_M: tl.constexpr, 42 | TILE_N: tl.constexpr, 43 | TILE_K: tl.constexpr, # 44 | BLOCK: tl.constexpr, 45 | EVEN_K: tl.constexpr, # 46 | ): 47 | # ------------ # 48 | # - Prologue - # 49 | # ------------ # 50 | block_id = tl.program_id(0) + grid_offset 51 | lut += block_id * 3 52 | # offsets 53 | off_z = tl.program_id(2) # batch 54 | off_h = tl.load(lut + 0) # head 55 | 56 | # initialize pointers to A 57 | start_am = tl.load(lut + 1) 58 | offs_am = start_am * BLOCK + (tl.arange(0, TILE_M) % BLOCK) 59 | offs_ak = tl.arange(0, TILE_K) 60 | a_ptrs = ( 61 | A 
62 | + off_z * stride_za 63 | + off_h * stride_ha 64 | + offs_am[:, None] * stride_ma 65 | + offs_ak[None, :] * stride_ak 66 | ) 67 | # initialize pointers to B 68 | start_bn = tl.load(lut + 2) 69 | offs_bn = start_bn * BLOCK + (tl.arange(0, TILE_N) % BLOCK) 70 | offs_bk = tl.arange(0, TILE_K) 71 | b_ptrs = ( 72 | B 73 | + off_z * stride_zb 74 | + off_h * stride_hb 75 | + offs_bn[None, :] * stride_nb 76 | + offs_bk[:, None] * stride_bk 77 | ) 78 | # ---------------- # 79 | # Inner Loop # 80 | # ---------------- # 81 | acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32) 82 | for k in range(K, 0, -TILE_K): 83 | if EVEN_K: 84 | a = tl.load(a_ptrs) 85 | b = tl.load(b_ptrs) 86 | else: 87 | a = tl.load(a_ptrs, mask=offs_ak[None, :] < k, other=0.0) 88 | b = tl.load(b_ptrs, mask=offs_bk[:, None] < k, other=0.0) 89 | acc += tl.dot(a, b, out_dtype=tl.float32) 90 | a_ptrs += TILE_K * stride_ak 91 | b_ptrs += TILE_K * stride_bk 92 | c = acc.to(C.dtype.element_ty) 93 | # ---------------- # 94 | # Epilogue # 95 | # ---------------- # 96 | offs_cm = tl.arange(0, TILE_M) % BLOCK 97 | offs_cn = tl.arange(0, TILE_N) % BLOCK 98 | pc = ( 99 | C 100 | + off_z * stride_zc 101 | + block_id * stride_hc 102 | + offs_cm[:, None] * stride_mc 103 | + offs_cn[None, :] * stride_nc 104 | ) 105 | tl.store(pc, c, mask=True) 106 | 107 | 108 | def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out=None): 109 | if a.stride(2) != 1 and a.stride(3) != 1: 110 | a = a.contiguous() 111 | if b.stride(2) != 1 and b.stride(3) != 1: 112 | b = b.contiguous() 113 | # (A * B)^T = B^T * A^T 114 | if trans_c: 115 | a, b = b, a 116 | trans_a, trans_b = not trans_b, not trans_a 117 | # shape constraints 118 | a_dim = -2 if trans_a else -1 119 | b_dim = -1 if trans_b else -2 120 | Ka, Kb = a.shape[a_dim], b.shape[b_dim] 121 | if Ka != Kb: 122 | raise ValueError(f"Inner dimension mismatch (A: {Ka} vs B: {Kb})") 123 | # allocate output 124 | if out is None: 125 | c = torch.empty( 126 | (a.shape[0], lut.shape[0], block, block), dtype=a.dtype, device=a.device 127 | ) 128 | else: 129 | assert out.shape == (a.shape[0], lut.shape[0], block, block) 130 | c = out 131 | grid = [c.shape[1], 1, c.shape[0]] 132 | _sdd_kernel[grid]( 133 | a, 134 | b, 135 | c, # 136 | a.stride(0), 137 | a.stride(1), 138 | a.stride(3 if trans_a else 2), 139 | a.stride(2 if trans_a else 3), # 140 | b.stride(0), 141 | b.stride(1), 142 | b.stride(3 if trans_b else 2), 143 | b.stride(2 if trans_b else 3), # 144 | c.stride(0), 145 | c.stride(1), 146 | c.stride(2), 147 | c.stride(3), # 148 | Ka, 149 | 0, 150 | lut, # 151 | TILE_M=block, 152 | TILE_N=block, 153 | TILE_K=32, 154 | BLOCK=block, 155 | num_stages=4, # 156 | num_warps=4, # 157 | ) 158 | return c 159 | 160 | 161 | def sdd_lut(layout, block, device): 162 | lut = layout.nonzero(as_tuple=False).to(device).int() 163 | lut = lut.contiguous() 164 | return lut, None 165 | 166 | 167 | # ----------------------------- 168 | # Dense = Sparse x Dense (DSD) 169 | # This operation uses a look-up table that contains pre-computed pointer increments 170 | # in order to minimize computations in the inner loop of the matmul kernel. 
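# The look-up table itself is assembled on the host by dsd_lut() further down:
# a 4-int header per row of the sparse operand (its offset into the increment
# list, the reduction length, its destination index in the output, and its head)
# followed by interleaved (dense, sparse) pointer increments, so the inner loop
# only adds pre-computed deltas instead of re-deriving addresses from the layout.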
171 | # ----------------------------- 172 | 173 | 174 | @jit 175 | def _dsd_kernel( 176 | A, 177 | B, 178 | C, # 179 | stride_az, 180 | stride_ha, 181 | stride_am, 182 | stride_ak, # 183 | stride_zb, 184 | stride_hb, 185 | stride_bk, 186 | stride_bn, # 187 | stride_zc, 188 | stride_hc, 189 | stride_cm, 190 | stride_cn, # 191 | DS0, 192 | DS1, 193 | lut, # 194 | TILE_M: tl.constexpr, 195 | TILE_N: tl.constexpr, 196 | TILE_K: tl.constexpr, # 197 | GROUP_SIZE_M: tl.constexpr, 198 | BLOCK: tl.constexpr, # 199 | ): 200 | # ------------ # 201 | # - Prologue - # 202 | # ------------ # 203 | pid_m = tl.program_id(0) 204 | pid_n = tl.program_id(1) 205 | num_pid_m = tl.num_programs(0) 206 | num_pid_n = tl.num_programs(1) 207 | pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_SIZE_M) 208 | pidz = tl.program_id(2) 209 | header = lut + pid_n * 4 210 | offset = tl.load(header + 0) 211 | K = tl.load(header + 1) 212 | column = tl.load(header + 2) 213 | off_h = tl.load(header + 3) 214 | pinc = lut + offset 215 | # initialize pointers to A (sparse) 216 | block_id = tl.load(pinc + 1) 217 | block_id = tl.multiple_of(block_id, 8) # compiler hint 218 | offs_am = tl.arange(0, TILE_M) 219 | offs_ak = tl.arange(0, TILE_K) 220 | pa = ( 221 | A 222 | + pidz * stride_az 223 | + block_id * stride_ha 224 | + offs_am[:, None] * stride_am 225 | + offs_ak[None, :] * stride_ak 226 | ) 227 | # initialize pointers to B (dense) 228 | offs_bn = pid_m * TILE_N + tl.arange(0, TILE_N) 229 | offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn % DS0, TILE_N), TILE_N) 230 | start_bk = tl.load(pinc) 231 | start_bk = tl.multiple_of(start_bk, 8) # compiler hint 232 | offs_bk = start_bk + tl.arange(0, TILE_K) 233 | pb = ( 234 | B 235 | + pidz * stride_zb 236 | + off_h * stride_hb 237 | + offs_bn[None, :] * stride_bn 238 | + offs_bk[:, None] * stride_bk 239 | ) 240 | # ---------------- # 241 | # Inner Loop # 242 | # ---------------- # 243 | acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32) 244 | pinc += 2 245 | inc_a = tl.load(pinc + 1) 246 | inc_a = tl.multiple_of(inc_a, 8) 247 | inc_b = tl.load(pinc) 248 | inc_b = tl.multiple_of(inc_b, 8) 249 | for k in range(K, 0, -TILE_K): 250 | a = tl.load(pa) 251 | b = tl.load(pb) 252 | acc += tl.dot(a, b, out_dtype=tl.float32) 253 | pa += inc_a 254 | pb += inc_b * stride_bk 255 | pinc += 2 256 | inc_a = tl.load(pinc + 1) 257 | inc_a = tl.multiple_of(inc_a, 8) 258 | inc_b = tl.load(pinc) 259 | inc_b = tl.multiple_of(inc_b, 8) 260 | c = acc.to(C.dtype.element_ty) 261 | # initialize pointers to C 262 | offs_cm = column * TILE_M + tl.arange(0, TILE_M) 263 | offs_cn = pid_m * TILE_N + tl.arange(0, TILE_N) 264 | pc = ( 265 | C 266 | + off_h * stride_hc 267 | + pidz * stride_zc 268 | + offs_cm[:, None] * stride_cm 269 | + offs_cn[None, :] * stride_cn 270 | ) 271 | tl.store(pc, c, mask=offs_cn[None, :] < DS0) 272 | 273 | 274 | def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None): 275 | if a.stride(2) != 1 and a.stride(3) != 1: 276 | a = a.contiguous() 277 | if b.stride(2) != 1 and b.stride(3) != 1: 278 | b = b.contiguous() 279 | # shapes / dtypes 280 | AS1 = block * spdims[2 if trans_a else 1] 281 | BS0 = b.size(0) 282 | BS1 = b.size(1) 283 | BS3 = b.size(2 if trans_b else 3) 284 | dtype = a.dtype 285 | # allocate output 286 | CS0 = BS0 287 | CS1 = BS1 288 | CS2 = BS3 if trans_c else AS1 289 | CS3 = AS1 if trans_c else BS3 290 | if out is None: 291 | c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device) 292 | else: 293 | assert out.shape 
== (CS0, CS1, CS2, CS3) 294 | c = out 295 | # meta-parameter heuristics 296 | TILE_N = 128 297 | # compute output 298 | grid = lambda meta: [cdiv(BS3, meta["TILE_N"]), width, BS0] 299 | _dsd_kernel[grid]( 300 | a, 301 | b, 302 | c, # 303 | a.stride(0), 304 | a.stride(1), 305 | a.stride(3 if trans_a else 2), 306 | a.stride(2 if trans_a else 3), # 307 | b.stride(0), 308 | b.stride(1), 309 | b.stride(3 if trans_b else 2), 310 | b.stride(2 if trans_b else 3), # 311 | c.stride(0), 312 | c.stride(1), 313 | c.stride(3 if trans_c else 2), 314 | c.stride(2 if trans_c else 3), # 315 | BS3, 316 | AS1, 317 | lut, # 318 | TILE_M=block, 319 | TILE_N=TILE_N, 320 | TILE_K=min(block, 32), 321 | BLOCK=block, 322 | num_stages=4, # 323 | num_warps=4, 324 | GROUP_SIZE_M=4, # 325 | ) 326 | # exit() 327 | return c 328 | 329 | 330 | def dsd_lut(layout, block, step, trans, device): 331 | """ 332 | Generates the look-up table for incrementing pointers in the DSD/DDS matmul. 333 | Example (BLOCK=32, STEP=16) 334 | [[1, 0, 0, 1, 0], 335 | [0, 1, 1, 0, 1], 336 | [1, 0, 1, 0, 0]] 337 | 338 | Then the offsets for A are 339 | [0 , 16, 32, 48] <- row 0 340 | \\----/ \\----/ 341 | col=0 col=3 342 | [64, 80, 96, 112, 128, 144] <- row 1 343 | \\----/ \\----/ \\------/ 344 | col=1 col=2 col=3 345 | [160, 176, 192, 208] 346 | which leads to increments table 347 | [0, 16, 16, 16, || 64, 16, 16, 16, 16, 16, || 160, 16, 16, 16] 348 | 349 | Because B is dense, the offsets are 350 | [0, 16, 96, 112] <- row 0 351 | [32, 48, 64, 80] <- row 1 352 | [0, 16, 64, 80] <- row 2 353 | """ 354 | sizes = torch.sum(layout, 2 if trans else 1) 355 | head_id, col_id = torch.ones_like(sizes).nonzero(as_tuple=True) 356 | sizes = sizes.flatten() 357 | segments = sizes * step 358 | # pointer increments 359 | if trans: 360 | nnz = layout.nonzero(as_tuple=False) 361 | else: 362 | nnz = layout.transpose(1, 2).nonzero(as_tuple=False) 363 | num_blocks = nnz.size(0) 364 | offsets = torch.zeros_like(sizes) 365 | offsets[1:] = torch.cumsum(sizes[:-1], dim=0) 366 | offsets = torch.min(offsets, (num_blocks - 1) * torch.ones_like(offsets)) 367 | # ------------------------------- 368 | # dense input pointer increments 369 | # ------------------------------- 370 | # Note that the inner loop matmul kernel may have a fixed step size (e.g., TILE_K) 371 | # that is smaller than the block size, so we need to do a bit of extra work 372 | # to handle this case 373 | B_idx = nnz[:, 2] * block 374 | B_incs = B_idx.clone() 375 | B_incs[1:] -= B_idx[:-1] 376 | div = block // step 377 | B_incs = B_incs.view(-1, 1).repeat(1, div) 378 | B_incs[:, 1:] = step 379 | B_incs[:, 0] -= (div - 1) * step 380 | # first increment for each reduction is actually the offset 381 | B_incs[offsets[segments > 0], 0] = B_idx[offsets[segments > 0]] 382 | B_incs = B_incs.view(-1) 383 | # ------------------------------- 384 | # sparse input pointer increments 385 | # ------------------------------- 386 | # same as above, except that the increments are in the sparse memory layout 387 | if trans: 388 | A_idx = torch.arange(num_blocks, device=layout.device) 389 | else: 390 | A_idx = torch.tensor([], dtype=torch.int64, device=layout.device) 391 | current_offset = 0 392 | for z in range(layout.size(0)): 393 | layoutw = layout[z, :, :].clone().long() 394 | msum = layoutw.sum() 395 | layoutw[layoutw > 0] = 1 + torch.arange(msum, device=layout.device) 396 | A_idx = torch.cat((A_idx, current_offset + layoutw.T[layoutw.T > 0] - 1)) 397 | current_offset += msum 398 | A_incs = A_idx * block * block 399 | 
A_incs[1:] -= A_idx[:-1] * block * block 400 | A_incs = A_incs.view(-1, 1).repeat(1, div) 401 | if trans: 402 | A_incs[:, 1:] = step 403 | A_incs[:, 0] -= (div - 1) * step 404 | else: 405 | A_incs[:, 1:] = step * block 406 | A_incs[:, 0] -= (div - 1) * step * block 407 | A_incs[offsets[segments > 0], 0] = A_idx[offsets[segments > 0]] 408 | A_incs = A_incs.view(-1) 409 | # create header 410 | width = col_id.size(0) 411 | offsets = offsets * 2 * div + 4 * width 412 | segments = segments * div 413 | header = ( 414 | torch.stack((offsets, segments, col_id, head_id), dim=1).view(-1).contiguous() 415 | ) 416 | # create increments 417 | incs = torch.stack((B_incs, A_incs), dim=1).view(-1).contiguous() 418 | # pad by a factor 2*MAX_NUM_STAGES 419 | # to accommodate pre-fetching inside the kernel 420 | pad = torch.zeros(20, device=incs.device, dtype=incs.dtype) 421 | incs = torch.cat((incs, pad)) 422 | # create lut 423 | lut = torch.cat((header, incs)) 424 | lut = lut.type(torch.int32).to(device) 425 | # create locks 426 | return lut, width 427 | 428 | 429 | # ----------------------------- 430 | # Dense = Dense x Sparse (DDS) 431 | # ----------------------------- 432 | # AB = (B^T A^T)^T 433 | 434 | 435 | def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None): 436 | return dsd_matmul( 437 | b, a, not trans_b, not trans_a, not trans_c, spdims, block, lut, width, out=out 438 | ) 439 | 440 | 441 | ############## 442 | # MAIN API # 443 | ############## 444 | 445 | 446 | class _matmul(torch.autograd.Function): 447 | 448 | fn = {"sdd": sdd_matmul, "dsd": dsd_matmul, "dds": dds_matmul} 449 | 450 | @staticmethod 451 | def forward( 452 | ctx, 453 | a, 454 | b, 455 | trans_a, 456 | trans_b, 457 | trans_c, 458 | mode, 459 | spdims, 460 | block, 461 | c_lut, 462 | c_width, 463 | da_lut, 464 | da_width, 465 | db_lut, 466 | db_width, 467 | out, 468 | ): 469 | c = _matmul.fn[mode]( 470 | a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_width, out=out 471 | ) 472 | # save for backward 473 | ctx.save_for_backward(a, b) 474 | ctx.da_lut = da_lut 475 | ctx.da_width = da_width 476 | ctx.db_lut = db_lut 477 | ctx.db_width = db_width 478 | ctx.mode = mode 479 | ctx.spdims = spdims 480 | ctx.block = block 481 | ctx.trans_a = trans_a 482 | ctx.trans_b = trans_b 483 | ctx.trans_c = trans_c 484 | ctx.has_out = out is not None 485 | return c 486 | 487 | @staticmethod 488 | def backward(ctx, dc): 489 | # saved for backward 490 | a, b = ctx.saved_tensors 491 | da, db = None, None 492 | mode = ctx.mode 493 | # gradients w.r.t. a 494 | if ctx.needs_input_grad[0]: 495 | mode_da = mode[1] + mode[0] + mode[2] 496 | da = _matmul.fn[mode_da]( 497 | dc, 498 | b, 499 | ctx.trans_c, 500 | not ctx.trans_b, 501 | ctx.trans_a, 502 | ctx.spdims, 503 | ctx.block, 504 | ctx.da_lut, 505 | ctx.da_width, 506 | ) 507 | # gradients w.r.t. 
b 508 | if ctx.needs_input_grad[1]: 509 | mode_db = mode[2] + mode[1] + mode[0] 510 | db = _matmul.fn[mode_db]( 511 | a, 512 | dc, 513 | not ctx.trans_a, 514 | ctx.trans_c, 515 | ctx.trans_b, 516 | ctx.spdims, 517 | ctx.block, 518 | ctx.db_lut, 519 | ctx.db_width, 520 | ) 521 | dout = dc if ctx.has_out else None 522 | return ( 523 | da, 524 | db, 525 | None, 526 | None, 527 | None, 528 | None, 529 | None, 530 | None, 531 | None, 532 | None, 533 | None, 534 | None, 535 | None, 536 | None, 537 | dout, 538 | ) 539 | 540 | 541 | class matmul: 542 | 543 | def __init__( 544 | self, layout, block, mode, device, trans_a=False, trans_b=False, trans_c=False 545 | ): 546 | if mode not in ["sdd", "dsd", "dds"]: 547 | raise NotImplementedError("Supported modes are: sdd, dsd, dds") 548 | self.block = block 549 | self.mode = mode 550 | self.trans_a = trans_a 551 | self.trans_b = trans_b 552 | self.trans_c = trans_c 553 | self.layout = layout 554 | self.spdims = layout.shape 555 | step = min(block, 32) 556 | if self.mode == "sdd": 557 | self.c_lut, self.c_width = sdd_lut(layout, block, device) 558 | self.da_lut, self.da_width = dsd_lut(layout, block, step, True, device) 559 | self.db_lut, self.db_width = dsd_lut(layout, block, step, False, device) 560 | if self.mode == "dsd": 561 | self.c_lut, self.c_width = dsd_lut( 562 | layout, block, step, not self.trans_a, device 563 | ) 564 | self.da_lut, self.da_width = sdd_lut(layout, block, device) 565 | self.db_lut, self.db_width = dsd_lut( 566 | layout, block, step, self.trans_a, device 567 | ) 568 | if self.mode == "dds": 569 | self.c_lut, self.c_width = dsd_lut( 570 | layout, block, step, self.trans_b, device 571 | ) 572 | self.da_lut, self.da_width = dsd_lut( 573 | layout, block, step, not self.trans_b, device 574 | ) 575 | self.db_lut, self.db_width = sdd_lut(layout, block, device) 576 | 577 | def __call__(self, a, b, out=None): 578 | c = _matmul.apply( 579 | a, 580 | b, 581 | self.trans_a, 582 | self.trans_b, 583 | self.trans_c, 584 | self.mode, 585 | self.spdims, 586 | self.block, # 587 | self.c_lut, 588 | self.c_width, # 589 | self.da_lut, 590 | self.da_width, # 591 | self.db_lut, 592 | self.db_width, # 593 | out, 594 | ) 595 | return c 596 | -------------------------------------------------------------------------------- /kernels/flash_attention.py: -------------------------------------------------------------------------------- 1 | """ 2 | Fused Attention 3 | =============== 4 | This is a Triton implementation of the Flash Attention algorithm 5 | (see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf) 6 | 7 | Sequence Parallel implementation inspired by HazyResearch 8 | (see https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py) 9 | """ 10 | 11 | import torch 12 | import triton 13 | 14 | from triton import cdiv, jit 15 | from triton import language as tl 16 | 17 | 18 | def is_hip(): 19 | return triton.runtime.driver.active.get_current_target().backend == "hip" 20 | 21 | 22 | @jit 23 | def _fwd_kernel( 24 | Q, 25 | K, 26 | V, 27 | sm_scale, # 28 | L, # 29 | Out, # 30 | stride_qz, 31 | stride_qh, 32 | stride_qm, 33 | stride_qk, # 34 | stride_kz, 35 | stride_kh, 36 | stride_kn, 37 | stride_kk, # 38 | stride_vz, 39 | stride_vh, 40 | stride_vn, 41 | stride_vk, # 42 | stride_oz, 43 | stride_oh, 44 | stride_om, 45 | stride_on, # 46 | Z, 47 | H, 48 | N_CTX, # 49 | Z_H_N_CTX, # 50 | BLOCK_M: tl.constexpr, 51 | BLOCK_DMODEL: tl.constexpr, # 52 | BLOCK_N: tl.constexpr, # 
53 | IS_CAUSAL: tl.constexpr, # 54 | ): 55 | start_m = tl.program_id(0) 56 | off_hz = tl.program_id(1) 57 | qvk_offset = off_hz * stride_qh 58 | vk_offset = qvk_offset // stride_qm 59 | 60 | K_block_ptr = tl.make_block_ptr( 61 | base=K, 62 | shape=(BLOCK_DMODEL, Z_H_N_CTX), 63 | strides=(stride_kk, stride_kn), 64 | offsets=(0, vk_offset), 65 | block_shape=(BLOCK_DMODEL, BLOCK_N), 66 | order=(0, 1), 67 | ) 68 | V_block_ptr = tl.make_block_ptr( 69 | base=V, 70 | shape=(Z_H_N_CTX, BLOCK_DMODEL), 71 | strides=(stride_vn, stride_vk), 72 | offsets=(vk_offset, 0), 73 | block_shape=(BLOCK_N, BLOCK_DMODEL), 74 | order=(1, 0), 75 | ) 76 | # initialize offsets 77 | offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) 78 | offs_n = tl.arange(0, BLOCK_N) 79 | # initialize pointer to m and l 80 | m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") 81 | l_i = tl.zeros([BLOCK_M], dtype=tl.float32) 82 | acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) 83 | # credits to: Adam P. Goucher (https://github.com/apgoucher): 84 | # scale sm_scale by 1/log_2(e) and use 85 | # 2^x instead of exp in the loop because CSE and LICM 86 | # don't work as expected with `exp` in the loop 87 | qk_scale = sm_scale * 1.44269504 88 | # load q: it will stay in SRAM throughout 89 | 90 | offs_k = tl.arange(0, BLOCK_DMODEL) 91 | Q_ptrs = Q + qvk_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk 92 | q = tl.load(Q_ptrs) 93 | 94 | q = (q * qk_scale).to(K.dtype.element_ty) 95 | lo = 0 96 | hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX 97 | for start_n in range(lo, hi, BLOCK_N): 98 | # -- load k, v -- 99 | k = tl.load(K_block_ptr) 100 | v = tl.load(V_block_ptr) 101 | # -- compute qk --- 102 | qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) 103 | if IS_CAUSAL: 104 | qk = tl.where( 105 | offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf") 106 | ) 107 | qk += tl.dot(q, k) 108 | # -- compute scaling constant --- 109 | m_i_new = tl.maximum(m_i, tl.max(qk, 1)) 110 | alpha = tl.math.exp2(m_i - m_i_new) 111 | p = tl.math.exp2(qk - m_i_new[:, None]) 112 | # -- scale and update acc -- 113 | acc *= alpha[:, None] 114 | acc += tl.dot(p.to(V.dtype.element_ty), v) 115 | # -- update m_i and l_i -- 116 | l_i = l_i * alpha + tl.sum(p, 1) 117 | m_i = m_i_new 118 | # update pointers 119 | K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) 120 | V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) 121 | # write back l and m 122 | acc = acc / l_i[:, None] 123 | l_ptrs = L + off_hz * N_CTX + offs_m 124 | tl.store(l_ptrs, m_i + tl.math.log2(l_i)) 125 | # write back O 126 | O_block_ptr = tl.make_block_ptr( 127 | base=Out, 128 | shape=(Z_H_N_CTX, BLOCK_DMODEL), 129 | strides=(stride_om, stride_on), 130 | offsets=(vk_offset + start_m * BLOCK_M, 0), 131 | block_shape=(BLOCK_M, BLOCK_DMODEL), 132 | order=(1, 0), 133 | ) 134 | # O_ptrs = Out + qvk_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk 135 | tl.store(O_block_ptr, acc.to(K.dtype.element_ty)) 136 | 137 | 138 | @jit 139 | def _bwd_preprocess( 140 | Out, 141 | DO, 142 | Delta, 143 | BLOCK_M: tl.constexpr, 144 | D_HEAD: tl.constexpr, 145 | ): 146 | off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) 147 | off_n = tl.arange(0, D_HEAD) 148 | # load 149 | o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) 150 | do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) 151 | # compute 152 | delta = tl.sum(o * do, axis=1) 153 | # write-back 154 | tl.store(Delta + off_m, delta) 155 | 156 | 157 | 
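# Reference sketch (not code used by the kernels): _bwd_preprocess above computes the
# per-row correction term delta_i = sum_d O[i, d] * dO[i, d]; in PyTorch this is roughly
#     delta = (o.float() * do.float()).sum(dim=-1)
# The backward kernel below consumes it as ds = p * (dp - delta[:, None]) * sm_scale.
# Like the forward pass, it works in base-2 exponentials: exp(x) = 2**(x * log2(e)), so
# qk is scaled by qk_scale = sm_scale * 1.44269504 and p is recovered with tl.math.exp2.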
@jit 158 | def _bwd_kernel_one_col_block( 159 | Q, 160 | K, 161 | V, 162 | sm_scale, 163 | qk_scale, # 164 | Out, 165 | DO, # 166 | DQ, 167 | DK, 168 | DV, # 169 | L, # 170 | D, # 171 | Q_block_ptr, 172 | K_block_ptr, 173 | V_block_ptr, # 174 | DO_block_ptr, 175 | DQ_block_ptr, 176 | DK_block_ptr, 177 | DV_block_ptr, # 178 | stride_dqa, 179 | stride_qz, 180 | stride_qh, 181 | stride_qm, 182 | stride_qk, # 183 | stride_kz, 184 | stride_kh, 185 | stride_kn, 186 | stride_kk, # 187 | stride_vz, 188 | stride_vh, 189 | stride_vn, 190 | stride_vk, # 191 | Z, 192 | H, 193 | N_CTX, # 194 | off_h, 195 | off_z, 196 | off_hz, 197 | start_n, 198 | num_block, # 199 | BLOCK_M: tl.constexpr, 200 | BLOCK_DMODEL: tl.constexpr, # 201 | BLOCK_N: tl.constexpr, # 202 | SEQUENCE_PARALLEL: tl.constexpr, # 203 | CAUSAL: tl.constexpr, # 204 | MMA_V3: tl.constexpr, # 205 | ): 206 | if CAUSAL: 207 | lo = start_n * BLOCK_M 208 | else: 209 | lo = 0 210 | 211 | Q_offset = (off_z * stride_qz + off_h * stride_qh) // stride_qm 212 | DQ_offset = off_z * stride_qz + off_h * stride_qh 213 | K_offset = (off_z * stride_kz + off_h * stride_kh) // stride_kn 214 | V_offset = (off_z * stride_vz + off_h * stride_vh) // stride_vn 215 | if SEQUENCE_PARALLEL: 216 | DQ_offset += stride_dqa * start_n 217 | DQ_offset = DQ_offset // stride_qm 218 | 219 | Q_block_ptr = tl.advance(Q_block_ptr, (lo + Q_offset, 0)) 220 | K_block_ptr = tl.advance(K_block_ptr, (start_n * BLOCK_M + K_offset, 0)) 221 | V_block_ptr = tl.advance(V_block_ptr, (start_n * BLOCK_M + V_offset, 0)) 222 | DO_block_ptr = tl.advance(DO_block_ptr, (lo + Q_offset, 0)) 223 | DQ_block_ptr = tl.advance(DQ_block_ptr, (lo + DQ_offset, 0)) 224 | DK_block_ptr = tl.advance(DK_block_ptr, (start_n * BLOCK_M + K_offset, 0)) 225 | DV_block_ptr = tl.advance(DV_block_ptr, (start_n * BLOCK_M + V_offset, 0)) 226 | 227 | # initialize row/col offsets 228 | offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M) 229 | offs_m = tl.arange(0, BLOCK_N) 230 | # pointer to row-wise quantities in value-like data 231 | D_ptrs = D + off_hz * N_CTX 232 | l_ptrs = L + off_hz * N_CTX 233 | # initialize dv amd dk 234 | dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) 235 | dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) 236 | # k and v stay in SRAM throughout 237 | k = tl.load(K_block_ptr) 238 | v = tl.load(V_block_ptr) 239 | # loop over rows 240 | for start_m in range(lo, num_block * BLOCK_M, BLOCK_M): 241 | offs_m_curr = start_m + offs_m 242 | # load q, k, v, do on-chip 243 | q = tl.load(Q_block_ptr) 244 | # recompute p = softmax(qk, dim=-1).T 245 | # NOTE: `do` is pre-divided by `l`; no normalization here 246 | if CAUSAL: 247 | qk = tl.where( 248 | offs_m_curr[:, None] >= (offs_n[None, :]), float(0.0), float("-inf") 249 | ) 250 | else: 251 | qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) 252 | qk += tl.dot(q, tl.trans(k)) 253 | qk *= qk_scale 254 | l_i = tl.load(l_ptrs + offs_m_curr) 255 | p = tl.math.exp2(qk - l_i[:, None]) 256 | # compute dv 257 | do = tl.load(DO_block_ptr) 258 | dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do) 259 | # compute dp = dot(v, do) 260 | Di = tl.load(D_ptrs + offs_m_curr) 261 | # dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] 262 | dp = tl.dot(do, tl.trans(v)) 263 | # compute ds = p * (dp - delta[:, None]) 264 | ds = (p * (dp - Di[:, None]) * sm_scale).to(Q.dtype.element_ty) 265 | # compute dk = dot(ds.T, q) 266 | dk += tl.dot(tl.trans(ds), q) 267 | # compute dq 268 | if not SEQUENCE_PARALLEL: 269 | dq = tl.load(DQ_block_ptr) 270 | dq 
+= tl.dot(ds, k) 271 | tl.store(DQ_block_ptr, dq.to(Q.dtype.element_ty)) 272 | elif SEQUENCE_PARALLEL: 273 | if MMA_V3: 274 | dq = tl.dot(ds, k) 275 | else: 276 | # not work with mma v3, because M % 64 != 0 277 | dq = tl.trans(tl.dot(tl.trans(k), tl.trans(ds))) 278 | tl.store(DQ_block_ptr, dq.to(Q.dtype.element_ty)) 279 | 280 | # increment pointers 281 | DQ_block_ptr = tl.advance(DQ_block_ptr, (BLOCK_M, 0)) 282 | Q_block_ptr = tl.advance(Q_block_ptr, (BLOCK_M, 0)) 283 | DO_block_ptr = tl.advance(DO_block_ptr, (BLOCK_M, 0)) 284 | # write-back 285 | tl.store(DV_block_ptr, dv.to(V.dtype.element_ty)) 286 | tl.store(DK_block_ptr, dk.to(K.dtype.element_ty)) 287 | 288 | 289 | @jit 290 | def _bwd_kernel( 291 | Q, 292 | K, 293 | V, 294 | sm_scale, # 295 | Out, 296 | DO, # 297 | DQ, 298 | DK, 299 | DV, # 300 | L, # 301 | D, # 302 | stride_dqa, 303 | stride_qz, 304 | stride_qh, 305 | stride_qm, 306 | stride_qk, # 307 | stride_kz, 308 | stride_kh, 309 | stride_kn, 310 | stride_kk, # 311 | stride_vz, 312 | stride_vh, 313 | stride_vn, 314 | stride_vk, # 315 | Z, 316 | H, 317 | N_CTX, # 318 | Z_H_N_CTX, # 319 | SQ_Z_H_N_CTX, # 320 | BLOCK_M: tl.constexpr, 321 | BLOCK_DMODEL: tl.constexpr, # 322 | BLOCK_N: tl.constexpr, # 323 | SEQUENCE_PARALLEL: tl.constexpr, # 324 | CAUSAL: tl.constexpr, # 325 | MMA_V3: tl.constexpr, # 326 | ): 327 | qk_scale = sm_scale * 1.44269504 328 | off_hz = tl.program_id(0) 329 | off_z = off_hz // H 330 | off_h = off_hz % H 331 | 332 | Q_block_ptr = tl.make_block_ptr( 333 | base=Q, 334 | shape=(Z_H_N_CTX, BLOCK_DMODEL), 335 | strides=(stride_qm, stride_qk), 336 | offsets=(0, 0), 337 | block_shape=(BLOCK_M, BLOCK_DMODEL), 338 | order=(1, 0), 339 | ) 340 | K_block_ptr = tl.make_block_ptr( 341 | base=K, 342 | shape=(Z_H_N_CTX, BLOCK_DMODEL), 343 | strides=(stride_kn, stride_kk), 344 | offsets=(0, 0), 345 | block_shape=(BLOCK_M, BLOCK_DMODEL), 346 | order=(1, 0), 347 | ) 348 | V_block_ptr = tl.make_block_ptr( 349 | base=V, 350 | shape=(Z_H_N_CTX, BLOCK_DMODEL), 351 | strides=(stride_vn, stride_vk), 352 | offsets=(0, 0), 353 | block_shape=(BLOCK_M, BLOCK_DMODEL), 354 | order=(1, 0), 355 | ) 356 | DO_block_ptr = tl.make_block_ptr( 357 | base=DO, 358 | shape=(Z_H_N_CTX, BLOCK_DMODEL), 359 | strides=(stride_qm, stride_qk), 360 | offsets=(0, 0), 361 | block_shape=(BLOCK_M, BLOCK_DMODEL), 362 | order=(1, 0), 363 | ) 364 | if SEQUENCE_PARALLEL: 365 | DQ_block_ptr = tl.make_block_ptr( 366 | base=DQ, 367 | shape=(SQ_Z_H_N_CTX, BLOCK_DMODEL), 368 | strides=(stride_qm, stride_qk), 369 | offsets=(0, 0), 370 | block_shape=(BLOCK_M, BLOCK_DMODEL), 371 | order=(1, 0), 372 | ) 373 | else: 374 | DQ_block_ptr = tl.make_block_ptr( 375 | base=DQ, 376 | shape=(Z_H_N_CTX, BLOCK_DMODEL), 377 | strides=(stride_qm, stride_qk), 378 | offsets=(0, 0), 379 | block_shape=(BLOCK_M, BLOCK_DMODEL), 380 | order=(1, 0), 381 | ) 382 | 383 | DK_block_ptr = tl.make_block_ptr( 384 | base=DK, 385 | shape=(Z_H_N_CTX, BLOCK_DMODEL), 386 | strides=(stride_kn, stride_kk), 387 | offsets=(0, 0), 388 | block_shape=(BLOCK_M, BLOCK_DMODEL), 389 | order=(1, 0), 390 | ) 391 | DV_block_ptr = tl.make_block_ptr( 392 | base=DV, 393 | shape=(Z_H_N_CTX, BLOCK_DMODEL), 394 | strides=(stride_vn, stride_vk), 395 | offsets=(0, 0), 396 | block_shape=(BLOCK_M, BLOCK_DMODEL), 397 | order=(1, 0), 398 | ) 399 | 400 | num_block_n = tl.cdiv(N_CTX, BLOCK_N) 401 | if not SEQUENCE_PARALLEL: 402 | for start_n in range(0, num_block_n): 403 | _bwd_kernel_one_col_block( 404 | Q, 405 | K, 406 | V, 407 | sm_scale, 408 | qk_scale, 409 | Out, 410 | DO, # 411 | 
DQ, 412 | DK, 413 | DV, # 414 | L, # 415 | D, # 416 | Q_block_ptr, 417 | K_block_ptr, 418 | V_block_ptr, # 419 | DO_block_ptr, 420 | DQ_block_ptr, 421 | DK_block_ptr, 422 | DV_block_ptr, # 423 | stride_dqa, 424 | stride_qz, 425 | stride_qh, 426 | stride_qm, 427 | stride_qk, # 428 | stride_kz, 429 | stride_kh, 430 | stride_kn, 431 | stride_kk, # 432 | stride_vz, 433 | stride_vh, 434 | stride_vn, 435 | stride_vk, # 436 | Z, 437 | H, 438 | N_CTX, # 439 | off_h, 440 | off_z, 441 | off_hz, 442 | start_n, 443 | num_block_n, # 444 | BLOCK_M=BLOCK_M, 445 | BLOCK_DMODEL=BLOCK_DMODEL, # 446 | BLOCK_N=BLOCK_N, # 447 | SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, # 448 | CAUSAL=CAUSAL, # 449 | MMA_V3=MMA_V3, # 450 | ) 451 | else: 452 | start_n = tl.program_id(1) 453 | _bwd_kernel_one_col_block( 454 | Q, 455 | K, 456 | V, 457 | sm_scale, 458 | qk_scale, 459 | Out, 460 | DO, # 461 | DQ, 462 | DK, 463 | DV, # 464 | L, # 465 | D, # 466 | Q_block_ptr, 467 | K_block_ptr, 468 | V_block_ptr, # 469 | DO_block_ptr, 470 | DQ_block_ptr, 471 | DK_block_ptr, 472 | DV_block_ptr, # 473 | stride_dqa, 474 | stride_qz, 475 | stride_qh, 476 | stride_qm, 477 | stride_qk, # 478 | stride_kz, 479 | stride_kh, 480 | stride_kn, 481 | stride_kk, # 482 | stride_vz, 483 | stride_vh, 484 | stride_vn, 485 | stride_vk, # 486 | Z, 487 | H, 488 | N_CTX, # 489 | off_h, 490 | off_z, 491 | off_hz, 492 | start_n, 493 | num_block_n, # 494 | BLOCK_M=BLOCK_M, 495 | BLOCK_DMODEL=BLOCK_DMODEL, # 496 | BLOCK_N=BLOCK_N, # 497 | SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, # 498 | CAUSAL=CAUSAL, # 499 | MMA_V3=MMA_V3, # 500 | ) 501 | 502 | 503 | class _attention(torch.autograd.Function): 504 | 505 | @staticmethod 506 | def forward(ctx, q, k, v, causal, sm_scale, sequence_parallel=False): 507 | # only support for Ampere now 508 | capability = torch.cuda.get_device_capability() 509 | if capability[0] < 8: 510 | raise RuntimeError( 511 | "Flash attention currently only supported for compute capability >= 80" 512 | ) 513 | BLOCK_M = 128 514 | BLOCK_N = 64 515 | # shape constraints 516 | Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] 517 | assert Lq == Lk and Lk == Lv 518 | assert Lk in {16, 32, 64, 128} 519 | o = torch.empty_like(q) 520 | grid = (cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1) 521 | L = torch.empty( 522 | (q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32 523 | ) 524 | num_warps = 4 if Lk <= 64 else 8 525 | _fwd_kernel[grid]( 526 | q, 527 | k, 528 | v, 529 | sm_scale, # 530 | L, # 531 | o, # 532 | q.stride(0), 533 | q.stride(1), 534 | q.stride(2), 535 | q.stride(3), # 536 | k.stride(0), 537 | k.stride(1), 538 | k.stride(2), 539 | k.stride(3), # 540 | v.stride(0), 541 | v.stride(1), 542 | v.stride(2), 543 | v.stride(3), # 544 | o.stride(0), 545 | o.stride(1), 546 | o.stride(2), 547 | o.stride(3), # 548 | q.shape[0], 549 | q.shape[1], 550 | q.shape[2], # 551 | q.shape[0] * q.shape[1] * q.shape[2], # 552 | BLOCK_M=BLOCK_M, 553 | BLOCK_N=BLOCK_N, 554 | BLOCK_DMODEL=Lk, # 555 | IS_CAUSAL=causal, # 556 | num_warps=num_warps, # 557 | num_stages=4, # 558 | ) 559 | 560 | ctx.save_for_backward(q, k, v, o, L) 561 | ctx.grid = grid 562 | ctx.sm_scale = sm_scale 563 | ctx.BLOCK_DMODEL = Lk 564 | ctx.causal = causal 565 | ctx.sequence_parallel = sequence_parallel 566 | return o 567 | 568 | @staticmethod 569 | def backward(ctx, do): 570 | capability = torch.cuda.get_device_capability() 571 | MMA_V3 = capability[0] >= 9 572 | BLOCK = 128 573 | 574 | if is_hip(): 575 | # Bwd pass runs out of shared memory on HIP with larger block 
size. 576 | BLOCK = 64 577 | 578 | q, k, v, o, L = ctx.saved_tensors 579 | sequence_parallel = ctx.sequence_parallel 580 | seq_len_kv = k.shape[2] 581 | do = do.contiguous() 582 | if sequence_parallel: 583 | replicas = cdiv(seq_len_kv, BLOCK) 584 | new_dq_shape = (replicas,) + q.shape 585 | dq = torch.zeros(new_dq_shape, device=q.device, dtype=q.dtype) 586 | else: 587 | dq = torch.zeros_like(q, dtype=q.dtype) 588 | dk = torch.empty_like(k) 589 | dv = torch.empty_like(v) 590 | delta = torch.empty_like(L) 591 | _bwd_preprocess[(cdiv(q.shape[2], BLOCK) * ctx.grid[1],)]( 592 | o, 593 | do, 594 | delta, 595 | BLOCK_M=BLOCK, 596 | D_HEAD=ctx.BLOCK_DMODEL, 597 | ) 598 | _bwd_kernel[(ctx.grid[1], cdiv(seq_len_kv, BLOCK) if sequence_parallel else 1)]( 599 | q, 600 | k, 601 | v, 602 | ctx.sm_scale, # 603 | o, 604 | do, # 605 | dq, 606 | dk, 607 | dv, # 608 | L, # 609 | delta, # 610 | o.numel(), 611 | q.stride(0), 612 | q.stride(1), 613 | q.stride(2), 614 | q.stride(3), # 615 | k.stride(0), 616 | k.stride(1), 617 | k.stride(2), 618 | k.stride(3), # 619 | v.stride(0), 620 | v.stride(1), 621 | v.stride(2), 622 | v.stride(3), # 623 | q.shape[0], 624 | q.shape[1], 625 | q.shape[2], # 626 | q.shape[0] * q.shape[1] * q.shape[2], # 627 | cdiv(seq_len_kv, BLOCK) * q.shape[0] * q.shape[1] * q.shape[2], # 628 | BLOCK_M=BLOCK, 629 | BLOCK_N=BLOCK, # 630 | BLOCK_DMODEL=ctx.BLOCK_DMODEL, # 631 | SEQUENCE_PARALLEL=sequence_parallel, # 632 | CAUSAL=ctx.causal, # 633 | MMA_V3=MMA_V3, # 634 | num_warps=8, # 635 | num_stages=1, # 636 | ) 637 | 638 | if len(dq.shape) == 5: 639 | dq = dq.sum(dim=0) 640 | return dq, dk, dv, None, None, None 641 | 642 | 643 | attention = _attention.apply 644 | -------------------------------------------------------------------------------- /test/test_matmul.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | 3 | import pytest 4 | import torch 5 | 6 | import triton 7 | import triton.language as tl 8 | import triton.ops 9 | 10 | 11 | def is_hip(): 12 | return triton.runtime.driver.active.get_current_target().backend == "hip" 13 | 14 | 15 | @pytest.mark.parametrize( 16 | "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, INPUT_PRECISION, F8_FASTACCUM, ACC_DTYPE, OUTPUT_DTYPE", 17 | itertools.chain( 18 | *[ 19 | [ 20 | # 1 warp 21 | ( 22 | 16, 23 | 16, 24 | 16, 25 | 1, 26 | 1, 27 | 2, 28 | None, 29 | None, 30 | None, 31 | AT, 32 | BT, 33 | DTYPE, 34 | DTYPE, 35 | None, 36 | True, 37 | None, 38 | None, 39 | ), 40 | ( 41 | 32, 42 | 16, 43 | 16, 44 | 1, 45 | 1, 46 | 2, 47 | None, 48 | None, 49 | None, 50 | AT, 51 | BT, 52 | DTYPE, 53 | DTYPE, 54 | None, 55 | True, 56 | None, 57 | None, 58 | ), 59 | ( 60 | 16, 61 | 32, 62 | 16, 63 | 1, 64 | 1, 65 | 2, 66 | None, 67 | None, 68 | None, 69 | AT, 70 | BT, 71 | DTYPE, 72 | DTYPE, 73 | None, 74 | True, 75 | None, 76 | None, 77 | ), 78 | ( 79 | 16, 80 | 16, 81 | 32, 82 | 1, 83 | 1, 84 | 2, 85 | None, 86 | None, 87 | None, 88 | AT, 89 | BT, 90 | DTYPE, 91 | DTYPE, 92 | None, 93 | True, 94 | None, 95 | None, 96 | ), 97 | ( 98 | 32, 99 | 16, 100 | 32, 101 | 1, 102 | 1, 103 | 2, 104 | None, 105 | None, 106 | None, 107 | AT, 108 | BT, 109 | DTYPE, 110 | DTYPE, 111 | None, 112 | True, 113 | None, 114 | None, 115 | ), 116 | ( 117 | 16, 118 | 32, 119 | 32, 120 | 1, 121 | 1, 122 | 2, 123 | None, 124 | None, 125 | None, 126 | AT, 127 | BT, 128 | DTYPE, 129 | DTYPE, 130 | None, 131 | True, 132 | None, 133 | None, 134 | ), 135 | ( 136 | 16, 137 | 16, 138 | 64, 139 | 1, 140 | 
1, 141 | 2, 142 | None, 143 | None, 144 | None, 145 | AT, 146 | BT, 147 | DTYPE, 148 | DTYPE, 149 | None, 150 | True, 151 | None, 152 | None, 153 | ), 154 | ( 155 | 64, 156 | 16, 157 | 64, 158 | 1, 159 | 1, 160 | 2, 161 | None, 162 | None, 163 | None, 164 | AT, 165 | BT, 166 | DTYPE, 167 | DTYPE, 168 | None, 169 | True, 170 | None, 171 | None, 172 | ), 173 | ( 174 | 16, 175 | 64, 176 | 64, 177 | 1, 178 | 1, 179 | 2, 180 | None, 181 | None, 182 | None, 183 | AT, 184 | BT, 185 | DTYPE, 186 | DTYPE, 187 | None, 188 | True, 189 | None, 190 | None, 191 | ), 192 | # 2 warp 193 | ( 194 | 64, 195 | 32, 196 | 64, 197 | 1, 198 | 2, 199 | 2, 200 | None, 201 | None, 202 | None, 203 | AT, 204 | BT, 205 | DTYPE, 206 | DTYPE, 207 | None, 208 | True, 209 | None, 210 | None, 211 | ), 212 | ( 213 | 32, 214 | 64, 215 | 64, 216 | 1, 217 | 2, 218 | 2, 219 | None, 220 | None, 221 | None, 222 | AT, 223 | BT, 224 | DTYPE, 225 | DTYPE, 226 | None, 227 | True, 228 | None, 229 | None, 230 | ), 231 | ( 232 | 64, 233 | 32, 234 | 16, 235 | 1, 236 | 2, 237 | 2, 238 | None, 239 | None, 240 | None, 241 | AT, 242 | BT, 243 | DTYPE, 244 | DTYPE, 245 | None, 246 | True, 247 | None, 248 | None, 249 | ), 250 | ( 251 | 32, 252 | 64, 253 | 16, 254 | 1, 255 | 2, 256 | 2, 257 | None, 258 | None, 259 | None, 260 | AT, 261 | BT, 262 | DTYPE, 263 | DTYPE, 264 | None, 265 | True, 266 | None, 267 | None, 268 | ), 269 | ( 270 | 128, 271 | 32, 272 | 32, 273 | 1, 274 | 2, 275 | 2, 276 | None, 277 | None, 278 | None, 279 | AT, 280 | BT, 281 | DTYPE, 282 | DTYPE, 283 | None, 284 | True, 285 | None, 286 | None, 287 | ), 288 | ( 289 | 32, 290 | 128, 291 | 32, 292 | 1, 293 | 2, 294 | 2, 295 | None, 296 | None, 297 | None, 298 | AT, 299 | BT, 300 | DTYPE, 301 | DTYPE, 302 | None, 303 | True, 304 | None, 305 | None, 306 | ), 307 | # 4 warp 308 | ( 309 | 128, 310 | 64, 311 | 16, 312 | 1, 313 | 4, 314 | 2, 315 | None, 316 | None, 317 | None, 318 | AT, 319 | BT, 320 | DTYPE, 321 | DTYPE, 322 | None, 323 | True, 324 | None, 325 | None, 326 | ), 327 | ( 328 | 64, 329 | 128, 330 | 16, 331 | 1, 332 | 4, 333 | 2, 334 | None, 335 | None, 336 | None, 337 | AT, 338 | BT, 339 | DTYPE, 340 | DTYPE, 341 | None, 342 | True, 343 | None, 344 | None, 345 | ), 346 | ( 347 | 128, 348 | 32, 349 | 32, 350 | 1, 351 | 4, 352 | 2, 353 | None, 354 | None, 355 | None, 356 | AT, 357 | BT, 358 | DTYPE, 359 | DTYPE, 360 | None, 361 | True, 362 | None, 363 | None, 364 | ), 365 | ( 366 | 32, 367 | 128, 368 | 32, 369 | 1, 370 | 4, 371 | 2, 372 | None, 373 | None, 374 | None, 375 | AT, 376 | BT, 377 | DTYPE, 378 | DTYPE, 379 | None, 380 | True, 381 | None, 382 | None, 383 | ), 384 | ( 385 | 128, 386 | 32, 387 | 64, 388 | 1, 389 | 4, 390 | 2, 391 | None, 392 | None, 393 | None, 394 | AT, 395 | BT, 396 | DTYPE, 397 | DTYPE, 398 | None, 399 | True, 400 | None, 401 | None, 402 | ), 403 | ( 404 | 32, 405 | 128, 406 | 64, 407 | 1, 408 | 4, 409 | 2, 410 | None, 411 | None, 412 | None, 413 | AT, 414 | BT, 415 | DTYPE, 416 | DTYPE, 417 | None, 418 | True, 419 | None, 420 | None, 421 | ), 422 | # 8 warp 423 | ( 424 | 128, 425 | 256, 426 | 16, 427 | 1, 428 | 8, 429 | 2, 430 | None, 431 | None, 432 | None, 433 | AT, 434 | BT, 435 | DTYPE, 436 | DTYPE, 437 | None, 438 | True, 439 | None, 440 | None, 441 | ), 442 | ( 443 | 256, 444 | 128, 445 | 16, 446 | 1, 447 | 8, 448 | 2, 449 | None, 450 | None, 451 | None, 452 | AT, 453 | BT, 454 | DTYPE, 455 | DTYPE, 456 | None, 457 | True, 458 | None, 459 | None, 460 | ), 461 | ( 462 | 256, 463 | 128, 464 | 32, 465 | 1, 466 | 8, 467 | 2, 468 | None, 469 
| None, 470 | None, 471 | AT, 472 | BT, 473 | DTYPE, 474 | DTYPE, 475 | None, 476 | True, 477 | None, 478 | None, 479 | ), 480 | # variable input 481 | ( 482 | 128, 483 | 128, 484 | 32, 485 | 1, 486 | 4, 487 | 2, 488 | 256, 489 | 384, 490 | 160, 491 | AT, 492 | BT, 493 | DTYPE, 494 | DTYPE, 495 | None, 496 | True, 497 | None, 498 | None, 499 | ), 500 | ( 501 | 128, 502 | 128, 503 | 32, 504 | 1, 505 | 4, 506 | 2, 507 | 107, 508 | 233, 509 | 128, 510 | AT, 511 | BT, 512 | DTYPE, 513 | DTYPE, 514 | None, 515 | True, 516 | None, 517 | None, 518 | ), 519 | ( 520 | 128, 521 | 128, 522 | 32, 523 | 1, 524 | 4, 525 | 2, 526 | 107, 527 | 233, 528 | 83, 529 | AT, 530 | BT, 531 | DTYPE, 532 | DTYPE, 533 | None, 534 | True, 535 | None, 536 | None, 537 | ), 538 | ( 539 | 128, 540 | 256, 541 | 64, 542 | 1, 543 | 8, 544 | 3, 545 | 256, 546 | 512, 547 | 160, 548 | AT, 549 | BT, 550 | DTYPE, 551 | DTYPE, 552 | None, 553 | True, 554 | None, 555 | None, 556 | ), 557 | ] 558 | for DTYPE in ["float16", "bfloat16", "float32"] 559 | for AT in [False, True] 560 | for BT in [False, True] 561 | ], 562 | # n-stage 563 | *[ 564 | [ 565 | ( 566 | 16, 567 | 16, 568 | 16, 569 | 1, 570 | 1, 571 | STAGES, 572 | 32, 573 | 32, 574 | 80, 575 | AT, 576 | BT, 577 | DTYPE, 578 | DTYPE, 579 | None, 580 | True, 581 | None, 582 | None, 583 | ), 584 | ( 585 | 64, 586 | 32, 587 | 64, 588 | 1, 589 | 2, 590 | STAGES, 591 | 128, 592 | 64, 593 | 128, 594 | AT, 595 | BT, 596 | DTYPE, 597 | DTYPE, 598 | None, 599 | True, 600 | None, 601 | None, 602 | ), 603 | ( 604 | 128, 605 | 64, 606 | 16, 607 | 1, 608 | 4, 609 | STAGES, 610 | 256, 611 | 128, 612 | 80, 613 | AT, 614 | BT, 615 | DTYPE, 616 | DTYPE, 617 | None, 618 | True, 619 | None, 620 | None, 621 | ), 622 | ( 623 | 256, 624 | 128, 625 | 32, 626 | 1, 627 | 8, 628 | STAGES, 629 | 512, 630 | 256, 631 | 160, 632 | AT, 633 | BT, 634 | DTYPE, 635 | DTYPE, 636 | None, 637 | True, 638 | None, 639 | None, 640 | ), 641 | ( 642 | 128, 643 | 128, 644 | 32, 645 | 1, 646 | 4, 647 | STAGES, 648 | 256, 649 | 256, 650 | 160, 651 | AT, 652 | BT, 653 | DTYPE, 654 | DTYPE, 655 | None, 656 | True, 657 | None, 658 | None, 659 | ), 660 | ] 661 | for DTYPE in ["float16", "bfloat16", "float32"] 662 | for AT in [False, True] 663 | for BT in [False, True] 664 | for STAGES in [4] 665 | ], 666 | # tf32x3 667 | *[ 668 | [ 669 | ( 670 | 16, 671 | 16, 672 | 16, 673 | 1, 674 | 1, 675 | 2, 676 | 32, 677 | 32, 678 | 80, 679 | AT, 680 | BT, 681 | "float32", 682 | "float32", 683 | "tf32x3", 684 | True, 685 | None, 686 | None, 687 | ), 688 | ( 689 | 64, 690 | 32, 691 | 64, 692 | 1, 693 | 2, 694 | 2, 695 | 128, 696 | 64, 697 | 128, 698 | AT, 699 | BT, 700 | "float32", 701 | "float32", 702 | "tf32x3", 703 | True, 704 | None, 705 | None, 706 | ), 707 | ( 708 | 128, 709 | 64, 710 | 16, 711 | 1, 712 | 4, 713 | 2, 714 | 256, 715 | 128, 716 | 80, 717 | AT, 718 | BT, 719 | "float32", 720 | "float32", 721 | "tf32x3", 722 | True, 723 | None, 724 | None, 725 | ), 726 | ( 727 | 256, 728 | 128, 729 | 32, 730 | 1, 731 | 8, 732 | 2, 733 | 512, 734 | 256, 735 | 160, 736 | AT, 737 | BT, 738 | "float32", 739 | "float32", 740 | "tf32x3", 741 | True, 742 | None, 743 | None, 744 | ), 745 | ( 746 | 128, 747 | 128, 748 | 32, 749 | 1, 750 | 4, 751 | 2, 752 | 256, 753 | 256, 754 | 160, 755 | AT, 756 | BT, 757 | "float32", 758 | "float32", 759 | "tf32x3", 760 | True, 761 | None, 762 | None, 763 | ), 764 | ] 765 | for AT in [False, True] 766 | for BT in [False, True] 767 | ], 768 | # mixed-precision 769 | *[ 770 | [ 771 | ( 772 | 32, 773 | 32, 774 | 
32, 775 | 1, 776 | 1, 777 | 2, 778 | None, 779 | None, 780 | None, 781 | AT, 782 | BT, 783 | ADTYPE, 784 | BDTYPE, 785 | None, 786 | FASTACCUM, 787 | None, 788 | None, 789 | ), 790 | ( 791 | 128, 792 | 256, 793 | 32, 794 | 1, 795 | 8, 796 | 2, 797 | None, 798 | None, 799 | None, 800 | AT, 801 | BT, 802 | ADTYPE, 803 | BDTYPE, 804 | None, 805 | FASTACCUM, 806 | None, 807 | None, 808 | ), 809 | ( 810 | 32, 811 | 64, 812 | 32, 813 | 1, 814 | 1, 815 | 2, 816 | 64, 817 | 128, 818 | 32, 819 | AT, 820 | BT, 821 | ADTYPE, 822 | BDTYPE, 823 | None, 824 | FASTACCUM, 825 | None, 826 | None, 827 | ), 828 | ] 829 | for ADTYPE, BDTYPE in [ 830 | ("float8e4nv", "float8e5"), 831 | ("float8e4nv", "float8e4nv"), 832 | ("float8e5", "float8e4nv"), 833 | ("float8e5", "float8e5"), 834 | ("float8e4b15", "float8e4b15"), 835 | ("float8e4nv", "float16"), 836 | ("float16", "float8e5"), 837 | ("int8", "bfloat16"), 838 | ("float16", "int8"), 839 | ("float16", "float32"), 840 | ("float32", "float16"), 841 | ("bfloat16", "float32"), 842 | ("float32", "bfloat16"), 843 | ] 844 | for AT in [False, True] 845 | for BT in [False, True] 846 | for FASTACCUM in [True, False] 847 | ], 848 | # mixed-precision block layout 849 | *[ 850 | [ 851 | ( 852 | 32, 853 | 32, 854 | 32, 855 | 1, 856 | 1, 857 | 2, 858 | None, 859 | None, 860 | None, 861 | AT, 862 | BT, 863 | ADTYPE, 864 | BDTYPE, 865 | None, 866 | True, 867 | None, 868 | None, 869 | ), 870 | ( 871 | 128, 872 | 256, 873 | 32, 874 | 1, 875 | 8, 876 | 2, 877 | None, 878 | None, 879 | None, 880 | AT, 881 | BT, 882 | ADTYPE, 883 | BDTYPE, 884 | None, 885 | True, 886 | None, 887 | None, 888 | ), 889 | ( 890 | 32, 891 | 64, 892 | 32, 893 | 1, 894 | 1, 895 | 2, 896 | 64, 897 | 128, 898 | 32, 899 | AT, 900 | BT, 901 | ADTYPE, 902 | BDTYPE, 903 | None, 904 | True, 905 | None, 906 | None, 907 | ), 908 | ] 909 | for ADTYPE, BDTYPE in [ 910 | ("float8e4nv", "float16"), 911 | ("float16", "float8e5"), 912 | ("float16", "float32"), 913 | ("float32", "float16"), 914 | ("bfloat16", "float32"), 915 | ("float32", "bfloat16"), 916 | ] 917 | for AT in [False, True] 918 | for BT in [False, True] 919 | ], 920 | # acc-out-dtype and output_dtype 921 | *[ 922 | [ 923 | ( 924 | 32, 925 | 32, 926 | 32, 927 | 1, 928 | 1, 929 | 2, 930 | None, 931 | None, 932 | None, 933 | False, 934 | False, 935 | "float16", 936 | "float16", 937 | None, 938 | True, 939 | ACC_DTYPE, 940 | OUTPUT_DTYPE, 941 | ), 942 | ( 943 | 128, 944 | 256, 945 | 32, 946 | 1, 947 | 8, 948 | 2, 949 | None, 950 | None, 951 | None, 952 | False, 953 | False, 954 | "float16", 955 | "float16", 956 | None, 957 | True, 958 | ACC_DTYPE, 959 | OUTPUT_DTYPE, 960 | ), 961 | ] 962 | for ACC_DTYPE in [None, "float16", "float32"] 963 | for OUTPUT_DTYPE in [None, "float16", "float32"] 964 | ], 965 | ), 966 | ) 967 | def test_op( 968 | BLOCK_M, 969 | BLOCK_N, 970 | BLOCK_K, 971 | SPLIT_K, 972 | NWARP, 973 | NSTAGE, 974 | M, 975 | N, 976 | K, 977 | AT, 978 | BT, 979 | ADTYPE, 980 | BDTYPE, 981 | INPUT_PRECISION, 982 | F8_FASTACCUM, 983 | ACC_DTYPE, 984 | OUTPUT_DTYPE, 985 | ): 986 | capability = torch.cuda.get_device_capability() 987 | if capability[0] < 7: 988 | pytest.skip("Only test tl.dot() on devices with sm >= 70") 989 | if capability[0] < 8 and (ADTYPE == "bfloat16" or BDTYPE == "bfloat16"): 990 | pytest.skip("Only test bfloat16 on devices with sm >= 80") 991 | if capability[0] < 9 and (ADTYPE == "float8e4nv" or BDTYPE == "float8e4nv"): 992 | pytest.skip("Only test float8e4nv on devices with sm >= 90") 993 | if (ADTYPE == "bfloat16" or BDTYPE == 
"bfloat16") and SPLIT_K != 1: 994 | pytest.skip("bfloat16 matmuls don't allow split_k for now") 995 | torch.manual_seed(0) 996 | # nuke kernel decorators -- will set meta-parameters manually 997 | kwargs = { 998 | "BLOCK_M": BLOCK_M, 999 | "BLOCK_N": BLOCK_N, 1000 | "BLOCK_K": BLOCK_K, 1001 | "SPLIT_K": SPLIT_K, 1002 | } 1003 | pre_hook = None if SPLIT_K == 1 else lambda nargs: nargs["C"].zero_() 1004 | configs = [ 1005 | triton.Config( 1006 | kwargs=kwargs, num_warps=NWARP, num_stages=NSTAGE, pre_hook=pre_hook 1007 | ) 1008 | ] 1009 | kernel = triton.ops._matmul.kernel 1010 | kernel.configs = configs 1011 | # kernel.run = kernel.run.run.run 1012 | 1013 | # get matrix shape 1014 | M = BLOCK_M if M is None else M 1015 | N = BLOCK_N if N is None else N 1016 | K = BLOCK_K * SPLIT_K if K is None else K 1017 | 1018 | def is_fp8(dtype): 1019 | return "float8" in dtype 1020 | 1021 | def f8_to_f16(x, dtype): 1022 | 1023 | @triton.jit 1024 | def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr): 1025 | pid = tl.program_id(0) 1026 | offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 1027 | mask = offs < N 1028 | x = tl.load(X + offs, mask=mask) 1029 | tl.store(Y + offs, x, mask=mask) 1030 | 1031 | ret = torch.empty_strided( 1032 | x.shape, x.stride(), dtype=torch.float16, device=x.device 1033 | ) 1034 | grid = lambda META: (triton.cdiv(x.numel(), META["BLOCK_SIZE"]),) 1035 | dtype = getattr(tl, dtype) 1036 | kernel[grid](ret, triton.reinterpret(x, dtype), ret.numel(), BLOCK_SIZE=1024) 1037 | return ret 1038 | 1039 | def upcast_if_fp8(x, dtype): 1040 | if is_fp8(dtype): 1041 | return f8_to_f16(x, dtype) 1042 | return x 1043 | 1044 | def init_input(m, n, dtype, acc_dtype): 1045 | if "float8" in dtype: 1046 | ewidth = {"float8e4b15": 4, "float8e4nv": 4, "float8e5": 5}[dtype] 1047 | sign = torch.randint(2, size=(m, n), device="cuda", dtype=torch.int8) * 128 1048 | val = ( 1049 | torch.randint(2**3 - 1, size=(m, n), device="cuda", dtype=torch.int8) 1050 | << 7 - ewidth 1051 | ) 1052 | return sign | val 1053 | if dtype == "int8": 1054 | return torch.randint(-128, 127, (m, n), device="cuda", dtype=torch.int8) 1055 | # Use small range of values to prevent numerical issues. 
1056 | min_exp = -4 if acc_dtype == "float16" else -10 1057 | exponents = torch.randint(min_exp, 0, size=(m, n)) 1058 | ret = (2.0**exponents).to(getattr(torch, dtype)).to("cuda") 1059 | return ret 1060 | 1061 | if is_hip(): 1062 | if INPUT_PRECISION == "tf32x3" or is_fp8(ADTYPE) or is_fp8(BDTYPE): 1063 | pytest.skip( 1064 | "fp8 inputs or tf32x3 precision does not have native support on hip" 1065 | ) 1066 | # allocate/transpose inputs 1067 | a = init_input(M, K, ADTYPE, ACC_DTYPE) 1068 | b = init_input(K, N, BDTYPE, ACC_DTYPE) 1069 | a = a if not AT else a.T.contiguous().T 1070 | b = b if not BT else b.T.contiguous().T 1071 | # run test 1072 | th_a = upcast_if_fp8(a, ADTYPE) 1073 | th_b = upcast_if_fp8(b, BDTYPE) 1074 | ab_dtype = triton.ops.get_higher_dtype(th_a.dtype, th_b.dtype) 1075 | acc_dtype = getattr(torch, ACC_DTYPE) if ACC_DTYPE else ab_dtype 1076 | output_dtype = getattr(torch, OUTPUT_DTYPE) if OUTPUT_DTYPE else ab_dtype 1077 | th_c = torch.matmul(th_a.to(output_dtype), th_b.to(output_dtype)) 1078 | try: 1079 | if is_fp8(ADTYPE): 1080 | a = triton.reinterpret(a, getattr(tl, ADTYPE)) 1081 | if is_fp8(BDTYPE): 1082 | b = triton.reinterpret(b, getattr(tl, BDTYPE)) 1083 | tt_c = triton.ops.matmul( 1084 | a, 1085 | b, 1086 | acc_dtype if ACC_DTYPE else None, 1087 | INPUT_PRECISION, 1088 | F8_FASTACCUM, 1089 | output_dtype, 1090 | ) 1091 | torch.testing.assert_close(th_c, tt_c) 1092 | except triton.OutOfResources as e: 1093 | pytest.skip(str(e)) 1094 | --------------------------------------------------------------------------------
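A minimal usage sketch for the fused attention op defined in kernels/flash_attention.py (illustrative only: the tensor names and sizes below are not from the repository; it assumes the repository root is on PYTHONPATH and a CUDA device with compute capability >= 8.0):

import torch

from kernels.flash_attention import attention  # attention = _attention.apply

# Illustrative shapes (Z, H, N_CTX, D); the kernel requires D in {16, 32, 64, 128}.
Z, H, N_CTX, D = 2, 4, 1024, 64
q = torch.randn(Z, H, N_CTX, D, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

sm_scale = D**-0.5  # standard softmax scaling
out = attention(q, k, v, True, sm_scale)  # positional args: q, k, v, causal, sm_scale
out.sum().backward()  # exercises _bwd_preprocess and _bwd_kernel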