├── speedup_plots ├── fig_hyper_attn_causal_masking.pdf ├── fig_hyper_attn_causal_masking.png ├── fig_hyper_attn_no_causal_masking.pdf └── fig_hyper_attn_no_causal_masking.png ├── replace_llm_attention.py ├── src ├── attn_utils.py ├── angular_lsh_triton.py ├── hyper_attn_triton.py └── flash_attn_triton.py ├── unit_tests ├── test_lsh.py └── test_hyper_attention.py ├── chatglm_fast_attention.py ├── README.md ├── benchmark_single_attention.py ├── benchmark_patch_llm.py ├── hyper_attention.py └── LICENSE /speedup_plots/fig_hyper_attn_causal_masking.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amirzandieh/HyperAttention/HEAD/speedup_plots/fig_hyper_attn_causal_masking.pdf -------------------------------------------------------------------------------- /speedup_plots/fig_hyper_attn_causal_masking.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amirzandieh/HyperAttention/HEAD/speedup_plots/fig_hyper_attn_causal_masking.png -------------------------------------------------------------------------------- /speedup_plots/fig_hyper_attn_no_causal_masking.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amirzandieh/HyperAttention/HEAD/speedup_plots/fig_hyper_attn_no_causal_masking.pdf -------------------------------------------------------------------------------- /speedup_plots/fig_hyper_attn_no_causal_masking.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amirzandieh/HyperAttention/HEAD/speedup_plots/fig_hyper_attn_no_causal_masking.png -------------------------------------------------------------------------------- /replace_llm_attention.py: -------------------------------------------------------------------------------- 1 | NUM_TOTAL_LAYERS = { 2 | 'chatglm2-6b-32k': 28, 3 | } 4 | 5 | 6 | def patch_attention_layers(model, model_name, patch_config, num_patch_layers, **kwargs): 7 | num_total_layers = NUM_TOTAL_LAYERS[model_name] 8 | num_patch_layers = num_total_layers if num_patch_layers < 0 else num_patch_layers 9 | 10 | if patch_config == 'last': 11 | patch_layer_indices = range(num_total_layers - 1, num_total_layers - num_patch_layers - 1, -1) 12 | 13 | elif patch_config == 'first': 14 | patch_layer_indices = range(num_patch_layers) 15 | 16 | elif patch_config == 'odd': 17 | patch_layer_indices = range(1, num_total_layers, 2) 18 | 19 | elif patch_config == 'even': 20 | patch_layer_indices = range(0, num_total_layers, 2) 21 | 22 | elif patch_config == 'odd_first': 23 | patch_layer_indices = range(1, 2 * num_patch_layers, 2) 24 | 25 | elif patch_config == 'odd_last': 26 | patch_layer_indices = range(num_total_layers - 1, num_total_layers - num_patch_layers, -1) 27 | 28 | elif patch_config == 'even_first': 29 | patch_layer_indices = range(0, num_total_layers, 2)[:num_patch_layers] 30 | 31 | elif patch_config == 'even_last': 32 | patch_layer_indices = range(1, num_total_layers, 2)[-num_patch_layers:] 33 | 34 | else: 35 | raise NotImplementedError(f"Invalid patch_config option: {patch_config}") 36 | 37 | if model_name == 'chatglm2-6b-32k': 38 | from chatglm_fast_attention import FastCoreAttention 39 | 40 | print( 41 | f"patch_config: {patch_config}, attn_method: {kwargs['attn_method']}, num_patch_layers: {num_patch_layers}, patch_indices: {list(patch_layer_indices)}") 42 | for i in patch_layer_indices: 43 | 
model.transformer.encoder.layers[i].self_attention.core_attention = FastCoreAttention(model.config, i, 44 | **kwargs) 45 | -------------------------------------------------------------------------------- /src/attn_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | 5 | 6 | def add_self_attentions(attn1, lse1, attn2, lse2): 7 | """ 8 | inputs: 9 | - attn1, attn2: 4d-tensors with shape [b, h, n, d] 10 | - lse1, lse2: 4d-tensors of log-sum-exp with shape [b, h, n, 1] 11 | output: 12 | - attn 13 | = (attn1 * exp(lse1) + attn2 * exp(lse2)) / (exp(lse1) + exp(lse2)) 14 | = (attn1 + attn2 * exp(lse2 - lse1)) / (1 + exp(lse2-lse1)) 15 | = attn1 * c + attn2 * (1-c), where c=1/(1 + exp(lse2-lse1)), 16 | - lse 17 | = log(exp(lse1) + exp(lse2)) 18 | = log(exp(lse1) * (1 + exp(lse2 - lse1))) 19 | = lse1 + log(1 + exp(lse2 - lse1)) = lse1 - log(c) 20 | """ 21 | c = (1 / (1 + (lse2 - lse1).exp())).to(dtype=attn1.dtype) 22 | attn = c * attn1 + (1-c) * attn2 23 | lse = lse1 - (c + torch.finfo(lse1.dtype).eps).log() 24 | return attn, lse 25 | 26 | 27 | def indexing(x, indices, chunk_size=-1): 28 | """ 29 | inputs: 30 | - x: 4d-tensor with shape [b, h, n, d] 31 | - indices: 3d-tensor with shape [b, h, s] where each entry should be in [0, n-1] 32 | output: 33 | - out: 4d-tensor with shape [b, h, s, d] where out[i,j] = x[i,j][indices[i,j],:] 34 | 35 | A naive implementation: 36 | out = torch.zeros(b, h, s, d) 37 | for i in range(b): 38 | for j in range(h): 39 | out[i,j] = x[i,j][idx[i,j],:] 40 | return out 41 | """ 42 | if chunk_size < 0 or (chunk_size > 0 and x.shape[-2] % chunk_size == 0): 43 | return x.gather(2, indices.unsqueeze(-1).expand(-1, -1, -1, x.shape[-1])) 44 | else: 45 | x = x.gather(2, indices.unsqueeze(-1).expand(-1, -1, -1, x.shape[-1])) 46 | new_n = math.ceil(x.shape[2] / chunk_size) * chunk_size 47 | if new_n <= 0 or new_n - x.shape[2] <= 0: 48 | import pdb; 49 | pdb.set_trace(); 50 | return torch.nn.functional.pad(x, (0, 0, 0, new_n - x.shape[2]), mode='constant', value=0.) 51 | -------------------------------------------------------------------------------- /unit_tests/test_lsh.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import time 3 | import torch 4 | import triton 5 | 6 | import sys; 7 | 8 | sys.path.append("/home/ec2-user/workspace/hyper_attention") 9 | from src.angular_lsh_triton import AngularLSHTriton 10 | 11 | cnt = 0 12 | 13 | def check_memory(): 14 | global cnt 15 | mem_alloc = torch.cuda.memory_allocated() / 1024 / 1024 / 1024 16 | mem_reserve = torch.cuda.memory_reserved() / 1024 / 1024 / 1024 17 | mem_peak = torch.cuda.memory_stats()['active_bytes.all.peak'] / 1024 / 1024 / 1024 18 | print(f"[{cnt}] mem_alloc: {mem_alloc:.4f}, mem_reserve: {mem_reserve:.4f}, mem_peak: {mem_peak:.4f}") 19 | cnt += 1 20 | 21 | 22 | class MyTestCase(unittest.TestCase): 23 | def test_1_validation(self): 24 | print("1. 
this is validation test") 25 | dtype = torch.float16 26 | block_size, dim, batch_size, head_size, seq_len = 256, 128, 4, 32, 2048 27 | num_projs = 8 28 | 29 | query = torch.randn((batch_size, head_size, seq_len, dim), device='cuda', dtype=dtype) 30 | 31 | self.lsh = AngularLSHTriton(num_projs=num_projs, dim=(1, 1, dim)).to(device='cuda', dtype=dtype) 32 | 33 | # apply lsh in pytorch 34 | t0 = time.time() 35 | query_hash_buckets = self.lsh.hash_torch(query) 36 | t1 = time.time() 37 | print('the runtime of torch lsh:', t1 - t0) 38 | 39 | # apply lsh in triton 40 | check_memory() 41 | t2 = time.time() 42 | query_hash_buckets_triton = self.lsh.hash_triton(query) 43 | t3 = time.time() 44 | check_memory() 45 | print('the runtime of triton lsh:', t3 - t2) 46 | 47 | print('difference between torch and triton hashes: ', (query_hash_buckets.float() - query_hash_buckets_triton.float()).norm()) 48 | 49 | def test_2_runtime(self): 50 | print() 51 | print("2. this is runtime test") 52 | 53 | block_size, dim, batch_size, head_size = 256, 128, 4, 32 54 | num_projs = 8 55 | seq_len = 2048 56 | dtype = torch.float16 57 | 58 | query = torch.randn((batch_size, head_size, seq_len, dim), device='cuda', dtype=dtype) 59 | 60 | self.lsh = AngularLSHTriton(num_projs=num_projs, dim=(1, 1, dim)).to(device='cuda', dtype=dtype) 61 | 62 | def test_fn1(): 63 | query_hash_buckets = self.lsh.hash_torch(query) 64 | 65 | warmup = 20 66 | rep = 1000 67 | 68 | tim_py_q20, tim_py_q50, tim_py_q80 = triton.testing.do_bench(test_fn1, warmup=warmup, rep=rep, 69 | quantiles=[0.2, 0.5, 0.8]) 70 | print(f"pytorch runtime: {tim_py_q50:.5f} ms ({tim_py_q20:.5f}, {tim_py_q80:.5f})") 71 | 72 | def test_fn2(): 73 | query_hash_buckets_triton = self.lsh.hash_triton(query) 74 | 75 | tim_tr_q20, tim_tr_q50, tim_tr_q80 = triton.testing.do_bench(test_fn2, warmup=warmup, rep=rep, 76 | quantiles=[0.2, 0.5, 0.8]) 77 | print(f"triton runtime: {tim_tr_q50:.5f} ms ({tim_tr_q20:.5f}, {tim_tr_q80:.5f})") 78 | 79 | 80 | if __name__ == '__main__': 81 | unittest.main() 82 | -------------------------------------------------------------------------------- /chatglm_fast_attention.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | from hyper_attention import HyperAttention 5 | 6 | 7 | # Edited from https://huggingface.co/THUDM/chatglm2-6b-32k/blob/main/modeling_chatglm.py#L194 8 | class FastCoreAttention(torch.nn.Module): 9 | 10 | def __init__(self, config, layer_number, **kwargs): 11 | super(FastCoreAttention, self).__init__() 12 | 13 | self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling 14 | self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 15 | if self.apply_query_key_layer_scaling: 16 | self.attention_softmax_in_fp32 = True 17 | self.layer_number = max(1, layer_number) 18 | 19 | projection_size = config.kv_channels * config.num_attention_heads 20 | 21 | # Per attention head and per partition values. 
22 | self.hidden_size_per_partition = projection_size 23 | self.hidden_size_per_attention_head = projection_size // config.num_attention_heads 24 | self.num_attention_heads_per_partition = config.num_attention_heads 25 | 26 | coeff = None 27 | self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) 28 | if self.apply_query_key_layer_scaling: 29 | coeff = self.layer_number 30 | self.norm_factor *= coeff 31 | self.coeff = coeff 32 | 33 | self.attention_dropout = torch.nn.Dropout(config.attention_dropout) 34 | 35 | self.attn_method = kwargs.get('attn_method') 36 | if self.attn_method == 'hyper': 37 | lsh_num_projs = kwargs.get('lsh_num_projs') 38 | block_size = kwargs.get('block_size') 39 | sample_size = kwargs.get('sample_size') 40 | min_seq_len = kwargs.get('min_seq_len') 41 | smooth_block = kwargs.get('smooth_block') 42 | self.attn = HyperAttention( 43 | input_dim=128, 44 | lsh_num_projs=lsh_num_projs, 45 | block_size=block_size, 46 | sample_size=sample_size, 47 | min_seq_len=min_seq_len, 48 | smooth_block=smooth_block, 49 | ) 50 | else: 51 | raise NotImplementedError("Invalid attn_method option") 52 | 53 | def forward(self, query_layer, key_layer, value_layer, attention_mask): 54 | 55 | query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]] 56 | if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: 57 | softmax_scale = query_layer.shape[-1] ** (-0.5) 58 | context_layer = self.attn(query_layer, key_layer, value_layer, causal=True) 59 | 60 | else: 61 | assert False, 'this part the query length and key length may be different and not be a computational bottleneck.' 62 | if attention_mask is not None: 63 | attention_mask = ~attention_mask 64 | context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, 65 | attention_mask) 66 | 67 | context_layer = context_layer.permute(2, 0, 1, 3) 68 | new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) 69 | context_layer = context_layer.reshape(*new_context_layer_shape) 70 | return context_layer -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # HyperAttention: Long-context Attention in Near-Linear Time 2 | 3 | Triton Implementation of HyperAttention Algorithm 4 | 5 | # Requirements 6 | 7 | The code requires ``pytorch`` and [``triton``](https://github.com/openai/triton). 8 | pytorch version 2.0.1 tested, but any version >= 2.0.0 might work. 9 | Also makes use of [triton](https://github.com/openai/triton) implementation of [FlashAttention](https://github.com/Dao-AILab/flash-attention/tree/main). Flash attention kernel adapted to be compilable with triton version **2.1.0.** 10 | 11 | # How to use 12 | 13 | The impelmentation of HyperAttention can be found in ``hyper_attention.py``. An example of usage: 14 | 15 | ```python 16 | from hyper_attention import HyperAttention 17 | 18 | attn = HyperAttention( 19 | input_dim=64, 20 | lsh_num_projs=8, 21 | block_size=256, 22 | sample_size=256, 23 | min_seq_len=2048, 24 | smooth_block=False,) 25 | 26 | attn_output = attn(query, key, value, causal=True) 27 | ``` 28 | 29 | The module has the following parameters: 30 | - ```input_dim```: the dimension of input query and key. (Required) 31 | - ```lsh_num_projs```: the number of random projection vectors used in the locality-sensitive hashing scheme. The default is 8. 
32 | - ```block_size```: the size of the blocks in the block-diagonal approximation. It must be divisible by 128. The default is 256. 33 | - ```sample_size```: the number of sampled columns of the attention matrix $A$. It must be divisible by 128. The default is 256. 34 | - ```min_seq_len```: the minimum sequence length at which HyperAttention is applied. For shorter sequences the attention is computed exactly with FlashAttention, since the overheads of HyperAttention may dominate the runtime. The default is ```2048```. 35 | - ```smooth_block```: smooths the block-diagonal approximation by letting neighboring blocks overlap so that it resembles a banded-diagonal approximation. The default is ```False```. 36 | - The sequence lengths of both ```query``` and ```key``` must be divisible by ```block_size``` (a minimal padding sketch is given at the end of this README). 37 | 38 | # Speedup on single attention layer 39 | 40 | In this section we showcase the speedup achieved by HyperAttention over the Triton implementation of FlashAttention (v2) across a range of sequence lengths. The configuration uses 32 heads with a head dimension of 64, and the results are obtained on NVIDIA A10 Tensor Core GPUs. 41 | 42 | ## Causal masking (decoder-style attention) 43 | 44 | The speedup factors of both the forward pass and the combined forward+backward pass for the attention decoder with causal masking are plotted below. HyperAttention achieves over a ```22x``` speedup for the forward pass and over a ```16x``` speedup for the combined forward+backward pass at sequence length ```131k```. 45 | 46 |
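These timings can be reproduced with ```benchmark_single_attention.py```. Below is a condensed sketch of its measurement loop; the batch size and sequence length here are illustrative rather than the exact plotting configuration (see the script for the full set of options):

```python
import torch
import triton

from hyper_attention import HyperAttention

batch_size, head_size, seq_len, dim = 1, 32, 131072, 64

# HyperAttention expects tensors of shape [batch, heads, seq_len, dim]
q, k, v = (torch.randn((batch_size, head_size, seq_len, dim), dtype=torch.bfloat16, device="cuda")
           for _ in range(3))

attn = HyperAttention(input_dim=dim, block_size=256, sample_size=256).to(device="cuda", dtype=q.dtype)

# median forward-pass runtime (ms) with causal masking, with 20%/80% quantiles as spread
fwd = lambda: attn(q, k, v, causal=True)
q20, median_ms, q80 = triton.testing.do_bench(fwd, warmup=20, rep=100, quantiles=[0.2, 0.5, 0.8])
print(f"hyper fwd: {median_ms:.3f} ms ({q20:.3f}, {q80:.3f})")
```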

47 | ![Speedup of HyperAttention over FlashAttention (v2) with causal masking](./speedup_plots/fig_hyper_attn_causal_masking.png) 48 |
49 | 50 | ## No causal masking (encoder-style attention) 51 | 52 | The speedup factors of both the forward pass and the combined forward+backward pass for the attention encoder, without causal masking, are shown below. Without causal masking HyperAttention reduces to a notably simpler and more efficient algorithm that avoids the recursive partitioning of the attention matrix. As a result, it achieves speedups of more than ```270x``` for both the forward pass and the combined forward+backward pass at sequence length ```131k```. 53 | 54 |
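In both settings, the final output is obtained by merging two partial softmax results, the block-diagonal part and the sampled-column part, through their log-sum-exp values. The merge rule, as implemented in ```add_self_attentions``` in ```src/attn_utils.py```, is

$$
\mathrm{attn} = \frac{\mathrm{attn}_1\, e^{\mathrm{lse}_1} + \mathrm{attn}_2\, e^{\mathrm{lse}_2}}{e^{\mathrm{lse}_1} + e^{\mathrm{lse}_2}} = c\,\mathrm{attn}_1 + (1-c)\,\mathrm{attn}_2,
\qquad c = \frac{1}{1 + e^{\mathrm{lse}_2 - \mathrm{lse}_1}},
\qquad \mathrm{lse} = \mathrm{lse}_1 - \log c.
$$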

55 | ![Speedup of HyperAttention over FlashAttention (v2) without causal masking](./speedup_plots/fig_hyper_attn_no_causal_masking.png) 56 |
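## Padding inputs to a multiple of the block size

As noted in the parameter list above, the sequence lengths of ```query``` and ```key``` must be divisible by ```block_size```. Below is a minimal sketch that zero-pads the inputs to the next multiple of ```block_size``` and slices the padded rows off the output. The helper name and the ```[batch, heads, seq_len, dim]``` layout are illustrative assumptions; note also that zero-padded key rows still receive a small amount of attention mass.

```python
import torch
import torch.nn.functional as F

def pad_to_block_multiple(x: torch.Tensor, block_size: int) -> torch.Tensor:
    """Zero-pad a [batch, heads, seq_len, dim] tensor along seq_len up to a multiple of block_size."""
    pad_len = (-x.shape[2]) % block_size
    # F.pad spec is (last-dim left, last-dim right, seq left, seq right)
    return F.pad(x, (0, 0, 0, pad_len)) if pad_len else x

# usage, assuming q, k, v of shape [batch, heads, seq_len, dim] and attn = HyperAttention(...):
# seq_len = q.shape[2]
# q, k, v = (pad_to_block_multiple(t, attn.block_size) for t in (q, k, v))
# out = attn(q, k, v, causal=True)[:, :, :seq_len, :]  # drop the padded query rows
```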
57 | -------------------------------------------------------------------------------- /benchmark_single_attention.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import triton 4 | 5 | from src.flash_attn_triton import flash_attn_func 6 | from hyper_attention import HyperAttention 7 | 8 | try: 9 | from flash_attn import flash_attn_func as flash_attn_func_cuda 10 | except ImportError: 11 | flash_attn_func_cuda = None 12 | 13 | 14 | def get_arguments(): 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--no_causal", action="store_true") 17 | parser.add_argument("--smooth_block", action="store_true") 18 | parser.add_argument("--mode", type=str, default="fwd+bwd", choices=['fwd', 'bwd', 'fwd+bwd']) 19 | parser.add_argument("--attn_method", type=str, default="flash", 20 | choices=['flash', 'flash-cuda', 'hyper']) 21 | return parser.parse_args() 22 | 23 | 24 | def get_tensors(batch_size, seq_len, head_size, dim): 25 | q = torch.randn((batch_size, seq_len, head_size, dim), dtype=torch.bfloat16, device="cuda", requires_grad=True) 26 | k = torch.randn((batch_size, seq_len, head_size, dim), dtype=torch.bfloat16, device="cuda", requires_grad=True) 27 | v = torch.randn((batch_size, seq_len, head_size, dim), dtype=torch.bfloat16, device="cuda", requires_grad=True) 28 | return q, k, v 29 | 30 | 31 | def run_flash_attn(batch_size, head_size, seq_len, dim, causal, mode, impl="triton", warmup=20, rep=100): 32 | q, k, v = get_tensors(batch_size, seq_len, head_size, dim) 33 | if impl == "cuda": 34 | if flash_attn_func_cuda is None: 35 | raise ImportError("Please install flash_attn (pip install flash-attn --no-build-isolation)") 36 | fn = lambda: flash_attn_func_cuda(q, k, v, causal=causal) 37 | else: 38 | fn = lambda: flash_attn_func(q, k, v, None, causal, None)[0] 39 | if mode == 'fwd': 40 | return triton.testing.do_bench(fn, warmup=warmup, rep=rep, quantiles=[0.2, 0.5, 0.8]) 41 | elif mode == 'bwd': 42 | o = fn() 43 | do = torch.randn_like(o) 44 | fn = lambda: o.backward(do, retain_graph=True) 45 | return triton.testing.do_bench(fn, warmup=warmup, rep=rep, quantiles=[0.2, 0.5, 0.8]) 46 | else: # mode == 'fwd+bwd' 47 | q20_fwd, median_fwd, q80_fwd = triton.testing.do_bench(fn, warmup=warmup, rep=rep, quantiles=[0.2, 0.5, 0.8]) 48 | o = fn() 49 | do = torch.randn_like(o) 50 | fn = lambda: o.backward(do, retain_graph=True) 51 | q20_bwd, median_bwd, q80_bwd = triton.testing.do_bench(fn, warmup=warmup, rep=rep, quantiles=[0.2, 0.5, 0.8]) 52 | return q20_fwd + q20_bwd, median_fwd + median_bwd, q80_fwd + q80_bwd 53 | 54 | 55 | def run_hyper_attn(batch_size, head_size, seq_len, dim, causal, mode, smooth_block, warmup=20, rep=100): 56 | q, k, v = get_tensors(batch_size, head_size, seq_len, dim) 57 | block_size = 256 58 | sample_size = 256 59 | 60 | attn = HyperAttention( 61 | input_dim=dim, 62 | block_size=block_size, 63 | sample_size=sample_size, 64 | smooth_block=smooth_block,).to(device='cuda', dtype=q.dtype) 65 | 66 | fn = lambda: attn(q, k, v, causal=causal) 67 | 68 | if mode == 'fwd': 69 | return triton.testing.do_bench(fn, warmup=warmup, rep=rep, quantiles=[0.2, 0.5, 0.8]) 70 | elif mode == 'bwd': 71 | o = fn() 72 | do = torch.randn_like(o) 73 | fn = lambda: o.backward(do, retain_graph=True) 74 | return triton.testing.do_bench(fn, warmup=warmup, rep=rep, quantiles=[0.2, 0.5, 0.8]) 75 | else: # mode == 'fwd+bwd' 76 | q20_fwd, median_fwd, q80_fwd = triton.testing.do_bench(fn, warmup=warmup, rep=rep, quantiles=[0.2, 
0.5, 0.8]) 77 | o = fn() 78 | do = torch.randn_like(o) 79 | fn = lambda: o.backward(do, retain_graph=True) 80 | q20_bwd, median_bwd, q80_bwd = triton.testing.do_bench(fn, warmup=warmup, rep=rep, quantiles=[0.2, 0.5, 0.8]) 81 | return q20_fwd + q20_bwd, median_fwd + median_bwd, q80_fwd + q80_bwd 82 | 83 | 84 | def main(): 85 | args = get_arguments() 86 | for arg_name, arg_var in args.__dict__.items(): 87 | print(f"{arg_name:<16} : {arg_var}") 88 | 89 | seq_lens = [2 ** i for i in range(10, 18)] 90 | 91 | attn_method = args.attn_method # ['flash', 'hyper'] 92 | mode = args.mode # ['fwd', 'bwd', 'fwd+bwd'] 93 | batch_size, head_size, dim = 1, 24, 64 94 | print(f"mode: {mode}, attn_method: {attn_method}, batch_size: {batch_size}, head_size: {head_size}, dim: {dim}") 95 | 96 | causal = not args.no_causal 97 | 98 | for seq_len in seq_lens: 99 | if attn_method == 'flash': 100 | ms = run_flash_attn(batch_size, head_size, seq_len, dim, causal, mode=args.mode) 101 | elif attn_method == 'flash-cuda': 102 | ms = run_flash_attn(batch_size, head_size, seq_len, dim, causal, mode=args.mode, impl="cuda") 103 | elif attn_method == 'hyper': 104 | ms = run_hyper_attn(batch_size, head_size, seq_len, dim, causal, mode=args.mode, 105 | smooth_block=args.smooth_block) 106 | else: 107 | raise NotImplementedError 108 | 109 | print( 110 | f"[{mode:<8}], {attn_method}, seq_len: {seq_len:<8}, causal: {causal}, ms: {ms[0]:.5f} ({ms[1]:.5f}, {ms[2]:.5f}) | ") 111 | 112 | 113 | if __name__ == "__main__": 114 | main() 115 | 116 | -------------------------------------------------------------------------------- /benchmark_patch_llm.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from tqdm import tqdm 3 | import numpy as np 4 | import torch 5 | from torch.nn import CrossEntropyLoss 6 | from datasets import load_dataset, concatenate_datasets 7 | from transformers import AutoModelForCausalLM, AutoTokenizer 8 | 9 | from replace_llm_attention import patch_attention_layers 10 | 11 | 12 | def get_model_and_tokenizer(model_name): 13 | if model_name == "chatglm2-6b-32k": 14 | tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b-32k", trust_remote_code=True) 15 | model = AutoModelForCausalLM.from_pretrained("THUDM/chatglm2-6b-32k", trust_remote_code=True) 16 | else: 17 | raise NotImplementedError("Currently we only support chatglm2") 18 | 19 | return model, tokenizer 20 | 21 | 22 | def get_arguments(): 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--seq_len", type=int, default=32768) 25 | # patch config 26 | parser.add_argument("--patch_config", type=str, default="last", choices=['last', 'first', 'even', 'odd']) 27 | parser.add_argument("--attn_method", type=str, default="hyper", choices=['flash', 'hyper', 'hyper-cuda']) 28 | parser.add_argument("--num_patch_layers", type=int, default=-1) 29 | # params of HyperAttention 30 | parser.add_argument("--block_size", type=int, default=256) 31 | parser.add_argument("--sample_size", type=int, default=256) 32 | parser.add_argument("--lsh_num_projs", type=int, default=8) 33 | parser.add_argument("--min_seq_len", type=int, default=2048) 34 | parser.add_argument("--smooth_block", action="store_true") 35 | # currently only supports **chatglm2-6b-32k** 36 | parser.add_argument("--model_name", type=str, default="chatglm2-6b-32k") 37 | return parser.parse_args() 38 | 39 | 40 | @torch.no_grad() 41 | def main(): 42 | args = get_arguments() 43 | for arg_name, arg_var in args.__dict__.items(): 44 | 
print(f"{arg_name:<16} : {arg_var}") 45 | 46 | model, tokenizer = get_model_and_tokenizer(args.model_name) 47 | tokenizer.model_max_length = args.seq_len 48 | device = "cuda" 49 | dtype = torch.bfloat16 50 | 51 | # Load LongBench datasets 52 | dataset = 'longbench' 53 | dataset_names = ["narrativeqa", "qasper", "multifieldqa_en", "multifieldqa_zh", "hotpotqa", "2wikimqa", "musique", \ 54 | "dureader", "gov_report", "qmsum", "multi_news", "vcsum", "trec", "triviaqa", "samsum", "lsht", \ 55 | "passage_count", "passage_retrieval_en", "passage_retrieval_zh", "lcc", "repobench-p"] 56 | 57 | data_subset_all = [] 58 | for dataset in dataset_names: 59 | data_ = load_dataset('THUDM/LongBench', f"{dataset}", split='test') 60 | data_subset = data_.filter(lambda x: len(tokenizer.encode(x['context'])) >= args.seq_len) 61 | if len(data_subset) > 0: 62 | data_subset_all.append(data_subset) 63 | data = concatenate_datasets(data_subset_all) 64 | 65 | encoded_texts = [] 66 | pbar = tqdm(data) 67 | for i, data_i in enumerate(pbar): 68 | encoded_text = tokenizer.encode(data_i['context'], return_tensors='pt', truncation=True) 69 | pbar.set_description(f"seq_len: {len(encoded_text[0])}, n_data: {len(encoded_texts)}") 70 | if len(encoded_text[0]) < args.seq_len: 71 | continue 72 | encoded_texts.append(encoded_text) 73 | print(f"# of data longer than {args.seq_len}: {len(encoded_texts)}") 74 | 75 | if args.attn_method != 'flash': 76 | patch_attention_layers(model=model, **args.__dict__) 77 | 78 | model.to(device=device, dtype=dtype) 79 | model.eval() 80 | loss_fct = CrossEntropyLoss(reduction="none") 81 | 82 | ppls = [] 83 | 84 | pbar = tqdm(range(len(encoded_texts))) 85 | for bid in pbar: 86 | encoded_batch = encoded_texts[bid:bid + 1] 87 | if type(encoded_batch) == dict: 88 | attn_mask = encoded_batch['attention_mask'] if 'attention_mask' in encoded_batch.keys() else None 89 | encoded_batch = encoded_batch['input_ids'] 90 | elif type(encoded_batch) == list: 91 | encoded_batch = encoded_batch[0] 92 | 93 | encoded_batch = encoded_batch.to(device) 94 | attn_mask = torch.ones_like(encoded_batch) 95 | 96 | out_logits = model(encoded_batch).logits 97 | 98 | labels = encoded_batch 99 | 100 | shift_logits = out_logits[..., :-1, :].contiguous() 101 | shift_labels = labels[..., 1:].contiguous() 102 | shift_attention_mask_batch = attn_mask[..., 1:].contiguous() 103 | 104 | loss_ = loss_fct(shift_logits.transpose(1, 2), shift_labels).float() 105 | perplexity_batch = torch.exp2( 106 | (loss_ * shift_attention_mask_batch).sum(1) 107 | / shift_attention_mask_batch.sum(1) 108 | ) 109 | ppls += perplexity_batch.tolist() 110 | 111 | pbar.set_description( 112 | f"[{bid:<4}/{len(encoded_texts)}] avg_ppls: {np.mean(np.array(ppls)[~np.isnan(np.array(ppls))]):.4f}") 113 | 114 | del out_logits, encoded_batch, attn_mask, shift_logits, shift_labels, shift_attention_mask_batch, perplexity_batch 115 | 116 | nan_cnt = sum(np.isnan(np.array(ppls))) 117 | ppl_mean = np.mean(np.array(ppls)[~np.isnan(np.array(ppls))]) 118 | 119 | print(f"ppl: {ppl_mean}, nan_cnt: {nan_cnt}") 120 | res_str = f"model: {args.model_name}, dtype: {dtype}, seq_len: {args.seq_len}, num_patch_layers: {args.num_patch_layers}, n_data: {len(encoded_texts)}, ppl: {ppl_mean}, nan_cnt: {nan_cnt}\n" 121 | print(res_str) 122 | 123 | 124 | if __name__ == "__main__": 125 | main() 126 | -------------------------------------------------------------------------------- /hyper_attention.py: -------------------------------------------------------------------------------- 1 | import 
torch 2 | 3 | from src.attn_utils import add_self_attentions 4 | from src.flash_attn_triton import flash_attn_func 5 | from src.hyper_attn_triton import hyper_attn_func 6 | from src.angular_lsh_triton import AngularLSHTriton 7 | 8 | 9 | class HyperAttention(torch.nn.Module): 10 | 11 | def __init__(self, input_dim=64, lsh_num_projs=8, block_size=256, sample_size=256, min_seq_len=2048, 12 | smooth_block=False, **kwargs): 13 | """ 14 | - block_size and sample_size must be divisible by 128 15 | """ 16 | super().__init__() 17 | self.input_dim = input_dim 18 | self.lsh_num_projs = lsh_num_projs 19 | self.block_size = block_size 20 | self.sample_size = sample_size 21 | self.min_seq_len = min_seq_len 22 | self.smooth_block = smooth_block 23 | self.lsh = AngularLSHTriton(num_projs=self.lsh_num_projs, dim=(1, 1, input_dim)) 24 | 25 | def forward(self, query: torch.tensor, key: torch.tensor, value: torch.tensor, scale=None, causal=False, 26 | return_lse=False): 27 | """ 28 | Forward function for HyperAttention. If no causal masking, simply invokes forward_no_causal_mask method. 29 | If there is causal masking, it partitions the attention matrix and recurses on the partitions. 30 | inputs: 31 | - query, key, and valu: must have same sequence lengths but dimension of values vectors can be different 32 | from that of query or key 33 | - sequence lengths must be divisible by block_size 34 | output: 35 | - attn: (approximation of) the final attention output tensor 36 | - lse: (approximation of) log sum exp of the qk matrix 37 | """ 38 | query = query.contiguous() 39 | key = key.contiguous() 40 | value = value.contiguous() 41 | 42 | n_query = query.shape[2] 43 | batch_size, n_heads, n_key, dim = key.shape 44 | scale = scale or dim ** (-0.5) 45 | assert n_query == n_key 46 | 47 | # without causal masking 48 | if causal is False: 49 | attn, lse = self.forward_no_causal_mask(query, key, value, scale) 50 | 51 | else: # with causal masking 52 | if n_key <= self.min_seq_len: 53 | attn, lse = flash_attn_func(query.transpose(1, 2), 54 | key.transpose(1, 2), 55 | value.transpose(1, 2), 56 | None, True, scale) 57 | attn = attn.transpose(1, 2) 58 | 59 | else: 60 | # If n_query is odd we pad inputs by zero rows 61 | if n_query % 2: 62 | query = torch.nn.functional.pad(query, (0, 0, 0, 1), mode='constant', value=0.) 63 | key = torch.nn.functional.pad(key, (0, 0, 0, 1), mode='constant', value=0.) 64 | value = torch.nn.functional.pad(value, (0, 0, 0, 1), mode='constant', value=0.) 
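# The causal attention matrix is handled as a 2x2 block structure:
#     [ A_11    0   ]
#     [ A_21   A_22 ]
# A_11 and A_22 are themselves causal: the two halves of the sequence are folded
# into the head dimension (2 * n_heads) below and this forward() recurses on them.
# A_21 attends from the second half of the queries to the first half of the keys
# with no mask, so it is approximated by forward_no_causal_mask(); the two partial
# results for the lower half are then merged via add_self_attentions().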
65 | 66 | # extract block diagonal parts 67 | q_bd = query.view(batch_size, 2 * n_heads, query.shape[2] // 2, query.shape[-1]) 68 | k_bd = key.view(batch_size, 2 * n_heads, key.shape[2] // 2, key.shape[-1]) 69 | v_bd = value.view(batch_size, 2 * n_heads, key.shape[2] // 2, value.shape[-1]) 70 | 71 | attn_bd, lse_bd = self.forward(q_bd, k_bd, v_bd, scale, True, True) 72 | 73 | if attn_bd.shape[2] not in attn_bd.stride(): 74 | attn_bd = attn_bd.contiguous() 75 | attn_bd = attn_bd.view(batch_size, n_heads, -1, dim) 76 | 77 | if lse_bd.shape[2] not in lse_bd.stride(): 78 | lse_bd = lse_bd.contiguous() 79 | lse_bd = lse_bd.view(batch_size, n_heads, -1, 1) 80 | 81 | # lowe diagonal block is an unmasked attention 82 | attn_unmasked, lse_unmasked = self.forward_no_causal_mask( 83 | query[:, :, key.shape[2] // 2:, :], key[:, :, :key.shape[2] // 2, :], 84 | value[:, :, :key.shape[2] // 2, :], scale) 85 | 86 | attn_up, lse_up = attn_bd[:, :, :query.shape[2] // 2, :], lse_bd[:, :, :query.shape[2] // 2, :] 87 | attn_down, lse_down = add_self_attentions(attn_bd[:, :, query.shape[2] // 2:, :], 88 | lse_bd[:, :, query.shape[2] // 2:, :], 89 | attn_unmasked, lse_unmasked) 90 | 91 | attn = torch.cat((attn_up, attn_down), dim=-2) 92 | lse = torch.cat((lse_up, lse_down), dim=-2) 93 | 94 | if n_query % 2: 95 | attn = attn[:, :, :-1, :] 96 | lse = lse[:, :, :-1, :] 97 | 98 | if not return_lse: 99 | return attn 100 | else: 101 | return attn, lse 102 | 103 | def forward_no_causal_mask(self, query, key, value, scale): 104 | """ 105 | - sequence lengths must be divisible by block_size 106 | """ 107 | batch_size, head_size, n_query, dim = query.shape 108 | 109 | if self.min_seq_len > n_query: 110 | attn, lse = flash_attn_func(query.transpose(1, 2), 111 | key.transpose(1, 2), 112 | value.transpose(1, 2), 113 | None, False, scale) 114 | else: 115 | # Hash keys and queries via SortLSH and obtain buckets 116 | _, query_sort_idx = torch.sort(self.lsh.hash_triton(query), dim=2, stable=True) # batch_size x head_size x n 117 | _, key_sort_idx = torch.sort(self.lsh.hash_triton(key), dim=2, stable=True) 118 | 119 | # Now run hyper attention function on q,k,v and the permutations 120 | attn, lse = hyper_attn_func(query.transpose(1, 2), 121 | key.transpose(1, 2), 122 | value.transpose(1, 2), 123 | query_sort_idx.transpose(1, 2), 124 | key_sort_idx.transpose(1, 2), 125 | self.block_size, 126 | self.sample_size, 127 | scale, 128 | self.smooth_block, 129 | ) 130 | attn = attn.transpose(1, 2) 131 | 132 | return attn, lse.unsqueeze(-1) 133 | -------------------------------------------------------------------------------- /src/angular_lsh_triton.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import triton 4 | import triton.language as tl 5 | 6 | 7 | @triton.heuristics( 8 | { 9 | "EVEN_M": lambda args: args["seqlen"] % args["BLOCK_M"] == 0, 10 | "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], 11 | } 12 | ) 13 | @triton.jit 14 | def _angular_lsh_kernel( 15 | in_mat, 16 | proj_dir, 17 | perm, 18 | enc_vec, 19 | buckets, 20 | stride_in_matb, 21 | stride_in_math, 22 | stride_in_matm, 23 | stride_proj_dirb, 24 | stride_proj_dirh, 25 | stride_proj_dird, 26 | stride_bucketsb, 27 | stride_bucketsh, 28 | nheads, 29 | seqlen, 30 | seqlen_rounded, 31 | headdim, 32 | NUM_PROJ_ROUNDED: tl.constexpr, 33 | num_projs: tl.constexpr, 34 | BLOCK_HEADDIM: tl.constexpr, 35 | EVEN_M: tl.constexpr, 36 | EVEN_HEADDIM: tl.constexpr, 37 | BLOCK_M: tl.constexpr, 38 
| ): 39 | start_m = tl.program_id(0) 40 | off_hb = tl.program_id(1) 41 | off_b = off_hb // nheads 42 | off_h = off_hb % nheads 43 | offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) 44 | offs_n = tl.arange(0, NUM_PROJ_ROUNDED) 45 | offs_d = tl.arange(0, BLOCK_HEADDIM) 46 | 47 | in_mat_ptrs = ( 48 | in_mat + off_b * stride_in_matb + off_h * stride_in_math + (offs_m[:, None] * stride_in_matm + 49 | offs_d[None, :]) 50 | ) 51 | proj_dir_ptrs = ( 52 | proj_dir + off_b * stride_proj_dirb + off_h * stride_proj_dirh + (offs_d[:, None] * stride_proj_dird + 53 | offs_n[None, :]) 54 | ) 55 | 56 | # load in_mat block 57 | if EVEN_M: 58 | if EVEN_HEADDIM: 59 | mat = tl.load(in_mat_ptrs) 60 | else: 61 | mat = tl.load(in_mat_ptrs, mask=offs_d[None, :] < headdim, other=0.0) 62 | else: 63 | if EVEN_HEADDIM: 64 | mat = tl.load(in_mat_ptrs, mask=offs_m[:, None] < seqlen, other=0.0) 65 | else: 66 | mat = tl.load(in_mat_ptrs, mask=(offs_m[:, None] < seqlen) & (offs_d[None, :] < headdim), other=0.0) 67 | 68 | # load proj_dir block, need to mask out out of bound offsets 69 | if EVEN_HEADDIM: 70 | proj_dir_block = tl.load(proj_dir_ptrs, mask=offs_n[None, :] < num_projs, other=0.0) 71 | else: 72 | proj_dir_block = tl.load(proj_dir_ptrs, 73 | mask=(offs_n[None, :] < num_projs) & (offs_d[:, None] * stride_proj_dird < headdim), 74 | other=0.0) 75 | 76 | # multiply the in_mat block with proj_dir block to get the mask 77 | mask = tl.dot(mat, proj_dir_block) 78 | mask = tl.where(mask > 0.0, 1.0, 0.0) 79 | 80 | # form enc_vec 81 | encoding_vectors = tl.load(enc_vec+offs_n, mask=offs_n < num_projs, other=0.0) 82 | 83 | # multiply mask by enc_vec 84 | bin_ids = tl.sum(mask * encoding_vectors[None, :], 1).to(tl.int32) 85 | # bin_ids = tl.ravel(bin_ids) # flatten the bin_ids into a 1d tensor 86 | 87 | # read hash buckets from look up table 88 | hash_buckets = tl.load(perm+bin_ids) 89 | 90 | # write back bin_ids 91 | # initialize pointers to output 92 | buckets_ptrs = buckets + off_b * stride_bucketsb + off_h * stride_bucketsh + offs_m 93 | if EVEN_M: 94 | tl.store(buckets_ptrs, hash_buckets) 95 | else: 96 | tl.store(buckets_ptrs, hash_buckets, mask=offs_m < seqlen) 97 | 98 | 99 | def _angular_lsh(in_mat, proj_dir, perm, enc_vec): 100 | # shape constraints 101 | num_projs = proj_dir.shape[-1] 102 | batch, nheads, seqlen, d = in_mat.shape 103 | assert (proj_dir.shape == (batch, nheads, d, num_projs)) or (proj_dir.shape == (1, 1, d, num_projs)) 104 | assert in_mat.dtype == proj_dir.dtype, "All three tensors must have the same type" 105 | assert in_mat.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16" 106 | assert in_mat.is_cuda and proj_dir.is_cuda and perm.is_cuda and enc_vec.is_cuda 107 | if proj_dir.shape[:2] == (1, 1): 108 | stride_proj_dirb, stride_proj_dirh = 0, 0 109 | else: 110 | stride_proj_dirb, stride_proj_dirh = proj_dir.stride()[:2] 111 | 112 | seqlen_rounded = math.ceil(seqlen / 128) * 128 113 | num_projs_rounded = 16 114 | buckets = torch.empty((batch, nheads, seqlen), device=in_mat.device, dtype=torch.int32) 115 | 116 | BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) 117 | BLOCK = 128 118 | num_warps = 4 if d <= 64 else 8 119 | grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch * nheads) 120 | _angular_lsh_kernel[grid]( 121 | in_mat=in_mat, 122 | proj_dir=proj_dir, 123 | perm=perm, 124 | enc_vec=enc_vec, 125 | buckets=buckets, 126 | stride_in_matb=in_mat.stride(0), 127 | stride_in_math=in_mat.stride(1), 128 | stride_in_matm=in_mat.stride(2), 129 | 
stride_proj_dirb=stride_proj_dirb, 130 | stride_proj_dirh=stride_proj_dirh, 131 | stride_proj_dird=proj_dir.stride(2), 132 | stride_bucketsb=buckets.stride(0), 133 | stride_bucketsh=buckets.stride(1), 134 | nheads=nheads, 135 | seqlen=seqlen, 136 | seqlen_rounded=seqlen_rounded, 137 | headdim=d, 138 | NUM_PROJ_ROUNDED=num_projs_rounded, 139 | num_projs=num_projs, 140 | BLOCK_HEADDIM=BLOCK_HEADDIM, 141 | BLOCK_M=BLOCK, 142 | num_warps=num_warps, 143 | num_stages=1, 144 | ) 145 | return buckets 146 | 147 | 148 | class AngularLSHTriton(torch.nn.Module): 149 | """ 150 | inputs: 151 | - num_projs: a positive integer that determines the number of random projections used by hash function 152 | - dim: positive integer that determines the dimension of input vectors 153 | - mat: a tensor whose last shape is equal to dim and gets hashed by the lsh function 154 | output: 155 | - buckets: a tensor with shape mat.shape[:-1] and each entry is an integer in [0, 2^num_proj - 1] 156 | """ 157 | def __init__(self, num_projs, dim, rng=None): 158 | super().__init__() 159 | self.num_projs = num_projs 160 | 161 | if num_projs > 0: 162 | self.register_buffer('perm', self._unit_hamming_distance_array(self.num_projs), persistent=False) 163 | self.register_buffer('proj_dir', torch.randn(dim + (num_projs,), generator=rng), persistent=False) 164 | self.register_buffer('enc_vec', 2 ** torch.arange(self.num_projs).view(1, 1, 1, -1), persistent=False) 165 | else: 166 | raise ValueError("Invalid value for num_projs") 167 | 168 | def _unit_hamming_distance_array(self, size_n): 169 | if size_n == 1: 170 | return torch.tensor([0, 1], dtype=torch.int32) 171 | a = self._unit_hamming_distance_array(size_n - 1) 172 | b = torch.concat([a, torch.flip(a, dims=[0]) + 2 ** (size_n - 1)], 0) 173 | return b if b.stride(-1) == 1 else b.contiguous() 174 | 175 | def hash_torch(self, mat): 176 | mask = torch.einsum('...nd,...dr -> ...nr', mat, self.proj_dir) 177 | mask = mask > 0 178 | bin_ids = (mask * self.enc_vec).sum(-1) 179 | return self.perm[bin_ids] 180 | 181 | def hash_triton(self, mat): 182 | return _angular_lsh(mat, self.proj_dir, self.perm, self.enc_vec) 183 | 184 | def __repr__(self): 185 | return f"AngularLSH(num_proj={self.num_projs}, proj_dir.shape={self.proj_dir.shape})" 186 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /unit_tests/test_hyper_attention.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import time 3 | import torch 4 | import triton 5 | import math 6 | 7 | import sys; sys.path.append("/home/ec2-user/workspace/hyper_attention") 8 | from src.flash_attn_triton import flash_attn_func 9 | from src.hyper_attn_triton import hyper_attn_func 10 | from src.attn_utils import add_self_attentions, indexing 11 | 12 | cnt = 0 13 | 14 | 15 | def check_memory(new_cnt=-1): 16 | global cnt 17 | if new_cnt != -1: 18 | cnt = new_cnt 19 | mem_alloc = torch.cuda.memory_allocated()/1024/1024/1024 20 | mem_reserve = torch.cuda.memory_reserved()/1024/1024/1024 21 | mem_peak = torch.cuda.memory_stats()['active_bytes.all.peak']/1024/1024/1024 22 | print(f"[{cnt}] mem_alloc: {mem_alloc:.4f}, mem_reserve: {mem_reserve:.4f}, mem_peak: {mem_peak:.4f}") 23 | cnt += 1 24 | return 25 | 26 | 27 | class MyTestCase(unittest.TestCase): 28 | def test_forward_attention(self): 29 | print("1. this is forward test") 30 | 31 | dtype = torch.bfloat16 32 | 33 | batch_size = 4 34 | block_size = 512 35 | dim = 128 36 | head_size = 32 37 | seq_len = 2048 38 | sample_size = 128 39 | smooth_block = False 40 | 41 | query = torch.randn((batch_size, head_size, seq_len, dim), device='cuda', dtype=dtype) 42 | key = torch.randn((batch_size, head_size, seq_len, dim), device='cuda', dtype=dtype) 43 | key[:, :, :sample_size, :] *= 4. 44 | value = torch.randn((batch_size, head_size, seq_len, dim), device='cuda', dtype=dtype) 45 | 46 | q_buckets_idx = torch.randint(0, 64, (batch_size, head_size, seq_len), device='cuda') 47 | k_buckets_idx = torch.randint(0, 64, (batch_size, head_size, seq_len), device='cuda') 48 | _, query_sort_idx = torch.sort(q_buckets_idx, dim=2, stable=True) 49 | _, key_sort_idx = torch.sort(k_buckets_idx, dim=2, stable=True) 50 | check_memory() 51 | 52 | # compute attention by presorting queries and sorting back the output attention 53 | t0 = time.time() 54 | query_sort_idx_inv = torch.argsort(query_sort_idx, dim=2, stable=True) 55 | query_sorted = indexing(query, query_sort_idx) 56 | key_sorted = indexing(key, key_sort_idx) 57 | value_sorted = indexing(value, key_sort_idx) 58 | query_split_per_block = query_sorted.view(-1, 1, block_size, dim) 59 | key_split_per_block = key_sorted.view(-1, 1, block_size, dim) 60 | value_split_per_block = value_sorted.view(-1, 1, block_size, dim) 61 | attn_block, lse_block = flash_attn_func(query_split_per_block.transpose(1, 2), 62 | key_split_per_block.transpose(1, 2), 63 | value_split_per_block.transpose(1, 2)) 64 | 65 | attn_sample, lse_sample = flash_attn_func(query.transpose(1, 2), 66 | key[:, :, :sample_size, :].transpose(1, 2), 67 | value[:, :, :sample_size, :].transpose(1, 2)) 68 | attn_block = attn_block.transpose(1, 2) 69 | attn_block = attn_block.view(batch_size, head_size, query_sorted.shape[2], -1) 70 | attn_sample = attn_sample.transpose(1, 2) 71 | lse_block = lse_block[:, :, :query_sorted.shape[2]] 72 | lse_block = lse_block.view(batch_size, head_size, query_sorted.shape[2], -1) 73 | flash_attn_block = indexing(attn_block, query_sort_idx_inv)+attn_sample 74 | lse_block = indexing(lse_block, query_sort_idx_inv) 75 | attn, lse = add_self_attentions(flash_attn_block, lse_block, attn_sample, lse_sample.unsqueeze(-1)) 76 | lse = lse.squeeze(-1) 77 | t1 = time.time() 78 | check_memory() 79 | print('the runtime of flash attention with permutation 
and indexing of queries:', t1-t0) 80 | 81 | # torch lse computation 82 | qk = query_split_per_block @ key_split_per_block.transpose(-1, -2) / math.sqrt(dim) 83 | lse_block_torch = torch.logsumexp(qk, dim=-1, keepdim=True) 84 | lse_block_torch = lse_block_torch.view(batch_size, head_size, query_sorted.shape[2], -1) 85 | lse_block_torch = indexing(lse_block_torch, query_sort_idx_inv).squeeze(-1) 86 | lse_sample_torch = torch.logsumexp( 87 | query @ key[:, :, :sample_size, :].transpose(-1, -2) / math.sqrt(dim), 88 | dim=-1, 89 | keepdim=True 90 | ).squeeze(-1) 91 | lse_torch = (lse_sample_torch.exp() + lse_block_torch.exp()).log().to(dtype=lse_block_torch.dtype) 92 | print('diff between lse with sample and without: ', (lse_block_torch - lse_torch).norm(), lse_torch.norm()) 93 | print('error flash attention:', (lse - lse_torch).norm(), lse_torch.norm()) 94 | 95 | # compute attention kernel which permutes queries in triton 96 | check_memory(0) 97 | t2 = time.time() 98 | attn_hyper, lse_hyper = hyper_attn_func( 99 | query.transpose(1, 2), 100 | key.transpose(1, 2), 101 | value.transpose(1, 2), 102 | query_sort_idx.transpose(1, 2), 103 | key_sort_idx.transpose(1, 2), 104 | block_size, 105 | sample_size, 106 | 1./math.sqrt(dim), 107 | smooth_block, 108 | ) 109 | attn_hyper = attn_hyper.transpose(1, 2) 110 | t3 = time.time() 111 | check_memory() 112 | 113 | print('the runtime of hyper attention:', t3 - t2) 114 | 115 | print('diff lse hyper_attention and flash with indexing and permutation: ', (lse - lse_hyper).norm(), lse.norm()) 116 | 117 | print('error hyper attention lse: ', (lse_hyper - lse_torch).norm(), lse_torch.norm()) 118 | 119 | # check if dimension of V can be different from that of Q and K 120 | value_small = value[:, :, :, :dim//2].clone() 121 | attn_triton_unequal_dim, lse_triton_unequal_dim = hyper_attn_func( 122 | query.transpose(1, 2), 123 | key.transpose(1, 2), 124 | value_small.transpose(1, 2), 125 | query_sort_idx.transpose(1, 2), 126 | key_sort_idx.transpose(1, 2), 127 | block_size, 128 | sample_size, 129 | ) 130 | attn_triton_unequal_dim = attn_triton_unequal_dim.transpose(1, 2) 131 | 132 | print('testing unequal dimension for V compared to Q, K') 133 | print((attn_hyper[:, :, :, :dim//2] - attn_triton_unequal_dim).norm()) 134 | 135 | def test_gradient(self): 136 | print() 137 | print("2. 
this is gradients test") 138 | 139 | dtype = torch.bfloat16 140 | 141 | batch_size = 4 142 | block_size = 256 143 | dim = 64 144 | head_size = 32 145 | seq_len = 2048 146 | sample_size = 128 147 | 148 | query = torch.randn((batch_size, head_size, seq_len, dim), device='cuda', dtype=dtype, requires_grad=True) 149 | key = torch.randn((batch_size, head_size, seq_len, dim), device='cuda', dtype=dtype, requires_grad=True) 150 | value = torch.randn((batch_size, head_size, seq_len, dim), device='cuda', dtype=dtype, requires_grad=True) 151 | do = torch.randn_like(value) 152 | 153 | q_buckets_idx = torch.randint(0, 64, (batch_size, head_size, seq_len), device='cuda') 154 | k_buckets_idx = torch.randint(0, 64, (batch_size, head_size, seq_len), device='cuda') 155 | _, query_sort_idx = torch.sort(q_buckets_idx, dim=2, stable=True) 156 | _, key_sort_idx = torch.sort(k_buckets_idx, dim=2, stable=True) 157 | 158 | t0 = time.time() 159 | query_sort_idx_inv = torch.argsort(query_sort_idx, dim=2, stable=True) 160 | query_sorted = indexing(query, query_sort_idx) 161 | key_sorted = indexing(key, key_sort_idx) 162 | value_sorted = indexing(value, key_sort_idx) 163 | query_split_per_block = query_sorted.view(-1, 1, block_size, dim) 164 | key_split_per_block = key_sorted.view(-1, 1, block_size, dim) 165 | value_split_per_block = value_sorted.view(-1, 1, block_size, dim) 166 | 167 | attn_block, lse_block = flash_attn_func(query_split_per_block.transpose(1, 2), 168 | key_split_per_block.transpose(1, 2), 169 | value_split_per_block.transpose(1, 2)) 170 | 171 | attn_block = attn_block.transpose(1, 2) 172 | attn_block = attn_block.view(batch_size, head_size, query_sorted.shape[2], -1) 173 | attn = indexing(attn_block, query_sort_idx_inv) 174 | 175 | attn.backward(do, retain_graph=True) 176 | t1 = time.time() 177 | print('flash attention and indexing forward+backward time: ', t1-t0) 178 | q_grad = query.grad.detach().clone() 179 | k_grad = key.grad.detach().clone() 180 | v_grad = value.grad.detach().clone() 181 | 182 | query.grad = None 183 | key.grad = None 184 | value.grad = None 185 | 186 | # torch computation 187 | 188 | qk = query_split_per_block @ key_split_per_block.transpose(-1, -2) / math.sqrt(dim) 189 | attn_block_torch = qk.softmax(dim=-1) @ value_split_per_block 190 | attn_torch_block = indexing( 191 | attn_block_torch.view(batch_size, head_size, query_sorted.shape[2], -1), 192 | query_sort_idx_inv 193 | ) 194 | lse_block_torch = torch.logsumexp(qk, dim=-1, keepdim=True) 195 | lse_torch_block = indexing( 196 | lse_block_torch.view(batch_size, head_size, query_sorted.shape[2], -1), 197 | query_sort_idx_inv 198 | ) 199 | 200 | qk_sample = query @ key[:, :, :sample_size, :].transpose(-1, -2) / math.sqrt(dim) 201 | attn_torch_sample = qk_sample.softmax(dim=-1) @ value[:, :, :sample_size, :] 202 | lse_torch_sample = torch.logsumexp(qk_sample, dim=-1, keepdim=True) 203 | 204 | attn_torch, lse_torch = add_self_attentions(attn_torch_block, lse_torch_block, attn_torch_sample, lse_torch_sample) 205 | lse_torch = lse_torch.squeeze(-1) 206 | attn_torch.backward(do, retain_graph=True) 207 | 208 | q_grad_torch = query.grad.detach().clone() 209 | k_grad_torch = key.grad.detach().clone() 210 | v_grad_torch = value.grad.detach().clone() 211 | 212 | query.grad = None 213 | key.grad = None 214 | value.grad = None 215 | 216 | # hyper attention computation 217 | 218 | t2 = time.time() 219 | hyper_attn, hyper_lse = hyper_attn_func( 220 | query.transpose(1, 2), 221 | key.transpose(1, 2), 222 | value.transpose(1, 2), 223 | 
query_sort_idx.transpose(1, 2), 224 | key_sort_idx.transpose(1, 2), 225 | block_size, 226 | sample_size, 227 | ) 228 | hyper_attn = hyper_attn.transpose(1, 2) 229 | hyper_attn.backward(do, retain_graph=True) 230 | t3 = time.time() 231 | print('hyper attention triton forward+backward time: ', t3 - t2) 232 | 233 | q_grad_hyper = query.grad.detach().clone() 234 | k_grad_hyper = key.grad.detach().clone() 235 | v_grad_hyper = value.grad.detach().clone() 236 | 237 | print('difference of torch attention and flash attn: ', (attn_torch-attn).norm(), attn_torch.norm()) 238 | print('difference of torch lse and hyper_attention lse: ', (lse_torch - hyper_lse).norm(), lse_torch.norm()) 239 | 240 | print('difference between gradients of queries, flash vs hyper:') 241 | print((q_grad - q_grad_hyper).norm()) 242 | 243 | print('difference between gradients of keys, flash vs hyper:') 244 | print((k_grad - k_grad_hyper).norm()) 245 | 246 | print('difference between gradients of values, flash vs hyper:') 247 | print((v_grad - v_grad_hyper).norm()) 248 | 249 | print('difference gradients of queries, torch vs hyper:') 250 | print((q_grad_torch - q_grad_hyper).norm(), q_grad_torch.norm(), q_grad_hyper.norm()) 251 | 252 | print('difference gradients of keys, torch vs hyper:') 253 | print((k_grad_torch - k_grad_hyper).norm(), k_grad_torch.norm(), k_grad_hyper.norm()) 254 | 255 | print('difference gradients of values, torch vs hyper:') 256 | print((v_grad_torch - v_grad_hyper).norm(), v_grad_torch.norm(), v_grad_hyper.norm()) 257 | 258 | 259 | if __name__ == '__main__': 260 | unittest.main() 261 | # test_runtime() -------------------------------------------------------------------------------- /src/hyper_attn_triton.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of HyperAttention in Triton. 3 | Tested with triton==2.1.0. 4 | 5 | Requirements: 6 | - This implementation does not support attention bias (additive mask to qk). 7 | - This implementation only supports sequence lengths that are integer powers of two. 
8 | - the permutation indices for q and k must have the same sequence length as q and k themselves 9 | - sequence length for q and k must be equal 10 | """ 11 | 12 | import math 13 | 14 | import torch 15 | import triton 16 | import triton.language as tl 17 | 18 | @triton.heuristics( 19 | { 20 | "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], 21 | "EVEN_V_HEADDIM": lambda args: args["v_headdim"] == args["V_BLOCK_HEADDIM"], 22 | } 23 | ) 24 | # bug when seqlen_q is not divisible by BLOCK_M=128 25 | @triton.jit 26 | def _fwd_hyper_kernel( 27 | Q, 28 | K, 29 | V, 30 | q_sort_idx, 31 | k_sort_idx, 32 | Out, 33 | Lse, 34 | softmax_scale, 35 | stride_qb, 36 | stride_qh, 37 | stride_qm, 38 | stride_kb, 39 | stride_kh, 40 | stride_kn, 41 | stride_vb, 42 | stride_vh, 43 | stride_vn, 44 | stride_q_sort_idxb, 45 | stride_q_sort_idxh, 46 | stride_q_sort_idxm, 47 | stride_k_sort_idxb, 48 | stride_k_sort_idxh, 49 | stride_k_sort_idxn, 50 | stride_ob, 51 | stride_oh, 52 | stride_om, 53 | nheads, 54 | block_size, 55 | sample_size, 56 | seqlen_k, 57 | seqlen_q, 58 | headdim, 59 | v_headdim, 60 | smooth_block, 61 | CACHE_KEY_SEQLEN_Q, 62 | CACHE_KEY_SEQLEN_K, 63 | BLOCK_HEADDIM: tl.constexpr, 64 | V_BLOCK_HEADDIM: tl.constexpr, 65 | EVEN_HEADDIM: tl.constexpr, 66 | EVEN_V_HEADDIM: tl.constexpr, 67 | BLOCK_M: tl.constexpr, 68 | BLOCK_N: tl.constexpr, 69 | ): 70 | start_m = tl.program_id(0) 71 | off_hb = tl.program_id(1) 72 | off_b = off_hb // nheads 73 | off_h = off_hb % nheads 74 | # initialize offsets 75 | offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) 76 | offs_n = tl.arange(0, BLOCK_N) 77 | offs_d = tl.arange(0, BLOCK_HEADDIM) 78 | offs_vd = tl.arange(0, V_BLOCK_HEADDIM) 79 | # Initialize pointers to Q, K, V 80 | q_idx_ptrs = ( 81 | q_sort_idx + off_b * stride_q_sort_idxb + off_h * stride_q_sort_idxh + offs_m * stride_q_sort_idxm 82 | ) 83 | q_idx = tl.load(q_idx_ptrs).to(tl.int32) 84 | 85 | k_sort_idx += off_b * stride_k_sort_idxb + off_h * stride_k_sort_idxh 86 | 87 | # initialize pointer to m and l 88 | lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") 89 | m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") 90 | acc_o = tl.zeros([BLOCK_M, V_BLOCK_HEADDIM], dtype=tl.float32) 91 | q_ptrs = ( 92 | Q + off_b * stride_qb + off_h * stride_qh + (q_idx[:, None] * stride_qm + offs_d[None, :]) 93 | ) 94 | if EVEN_HEADDIM: 95 | q = tl.load(q_ptrs) 96 | else: 97 | q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0) 98 | 99 | # block diagonal part 100 | # loop over k, v and update accumulator 101 | block_id = start_m // block_size 102 | block_offs = seqlen_k + (start_m % block_size) * BLOCK_N - (block_size-1) * BLOCK_N//2 103 | end_n = tl.minimum((block_id + 1) * BLOCK_N * block_size, seqlen_k) 104 | for start_n in range(block_id * BLOCK_N * block_size, end_n, BLOCK_N): 105 | start_n = tl.multiple_of(start_n, BLOCK_N) 106 | if smooth_block: 107 | k_idx_ptrs = ((start_n + block_offs + offs_n) * stride_k_sort_idxn) % seqlen_k 108 | else: 109 | k_idx_ptrs = (start_n + offs_n) * stride_k_sort_idxn 110 | 111 | k_idx = tl.load(k_sort_idx + k_idx_ptrs).to(tl.int32) 112 | k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (k_idx[:, None] * stride_kn + offs_d[None, :]) 113 | # -- compute qk ---- 114 | if EVEN_HEADDIM: 115 | k = tl.load(k_ptrs) 116 | else: 117 | k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0) 118 | qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) 119 | qk += tl.dot(q, tl.trans(k)) 120 | m_ij = tl.maximum(tl.max(qk, 1) * 
softmax_scale, lse_i) 121 | p = tl.exp(qk * softmax_scale - m_ij[:, None]) 122 | l_ij = tl.sum(p, 1) 123 | 124 | # scale acc_o 125 | acc_o_scale = tl.exp(m_i - m_ij) 126 | 127 | # # -- update output accumulator acc_o -- 128 | acc_o = acc_o * acc_o_scale[:, None] 129 | 130 | v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (k_idx[:, None] * stride_vn + offs_vd[None, :]) 131 | if EVEN_V_HEADDIM: 132 | v = tl.load(v_ptrs) 133 | else: 134 | v = tl.load(v_ptrs, mask=offs_vd[None, :] < v_headdim, other=0.0) 135 | p = p.to(v.dtype) 136 | acc_o += tl.dot(p, v) 137 | 138 | # -- update statistics 139 | m_i = m_ij 140 | l_i_new = tl.exp(lse_i - m_ij) + l_ij 141 | lse_i = m_ij + tl.log(l_i_new) 142 | # compute sampled columns 143 | for col_block in range(0, sample_size): 144 | curr_offs_n = col_block * BLOCK_N * stride_kn + offs_n 145 | k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (curr_offs_n[:, None] * stride_kn + offs_d[None, :]) 146 | # -- compute qk ---- 147 | if EVEN_HEADDIM: 148 | k = tl.load(k_ptrs) 149 | else: 150 | k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0) 151 | qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) 152 | qk += tl.dot(q, tl.trans(k)) 153 | m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i) 154 | p = tl.exp(qk * softmax_scale - m_ij[:, None]) 155 | l_ij = tl.sum(p, 1) 156 | 157 | # scale acc_o 158 | acc_o_scale = tl.exp(m_i - m_ij) 159 | # # -- update output accumulator acc_o -- 160 | acc_o = acc_o * acc_o_scale[:, None] 161 | 162 | v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (curr_offs_n[:, None] * stride_vn + offs_vd[None, :]) 163 | if EVEN_V_HEADDIM: 164 | v = tl.load(v_ptrs) 165 | else: 166 | v = tl.load(v_ptrs, mask=offs_vd[None, :] < v_headdim, other=0.0) 167 | p = p.to(v.dtype) 168 | acc_o += tl.dot(p, v) 169 | 170 | # -- update statistics 171 | m_i = m_ij 172 | l_i_new = tl.exp(lse_i - m_ij) + l_ij 173 | lse_i = m_ij + tl.log(l_i_new) 174 | 175 | 176 | o_scale = tl.exp(m_i - lse_i) 177 | acc_o = acc_o * o_scale[:, None] 178 | 179 | # initialize pointers to outputs 180 | lse_ptrs = Lse + off_hb * seqlen_q + q_idx 181 | out_ptrs = ( 182 | Out 183 | + off_b * stride_ob 184 | + off_h * stride_oh 185 | + (q_idx[:, None] * stride_om + offs_vd[None, :]) 186 | ) 187 | # write back l and m 188 | tl.store(lse_ptrs, lse_i) 189 | if EVEN_V_HEADDIM: 190 | tl.store(out_ptrs, acc_o) 191 | else: 192 | tl.store(out_ptrs, acc_o, mask=offs_vd[None, :] < v_headdim) 193 | 194 | 195 | @triton.jit 196 | def _bwd_preprocess_do_o_dot( 197 | Out, 198 | DO, 199 | Delta, 200 | stride_ob, 201 | stride_oh, 202 | stride_om, 203 | stride_dob, 204 | stride_doh, 205 | stride_dom, 206 | nheads, 207 | seqlen_q, 208 | v_headdim, 209 | BLOCK_M: tl.constexpr, 210 | V_BLOCK_HEADDIM: tl.constexpr, 211 | ): 212 | start_m = tl.program_id(0) 213 | off_hb = tl.program_id(1) 214 | off_b = off_hb // nheads 215 | off_h = off_hb % nheads 216 | # initialize offsets 217 | offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) 218 | offs_d = tl.arange(0, V_BLOCK_HEADDIM) 219 | # load 220 | o = tl.load( 221 | Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :], 222 | mask=offs_d[None, :] < v_headdim, 223 | other=0.0, 224 | ).to(tl.float32) 225 | do = tl.load( 226 | DO 227 | + off_b * stride_dob 228 | + off_h * stride_doh 229 | + offs_m[:, None] * stride_dom 230 | + offs_d[None, :], 231 | mask=offs_d[None, :] < v_headdim, 232 | other=0.0, 233 | ).to(tl.float32) 234 | delta = tl.sum(o * do, axis=1) 235 | # write-back 236 | tl.store(Delta + 
off_hb * seqlen_q + offs_m, delta) 237 | 238 | 239 | @triton.jit 240 | def _bwd_store_dx( 241 | dx_ptrs, 242 | dx, 243 | offs_d, 244 | headdim, 245 | even_headdim, 246 | ): 247 | if even_headdim: 248 | tl.store(dx_ptrs, dx) 249 | else: 250 | tl.store(dx_ptrs, dx, mask=offs_d[None, :] < headdim) 251 | 252 | 253 | @triton.jit 254 | def _bwd_blocked_kernel_one_col( 255 | start_n, 256 | Q, 257 | K, 258 | V, 259 | Q_idx, 260 | K_idx, 261 | DO, 262 | DQ, 263 | DK, 264 | DV, 265 | LSE, 266 | D, 267 | softmax_scale, 268 | stride_qm, 269 | stride_kn, 270 | stride_vn, 271 | stride_dom, 272 | stride_dqm, 273 | stride_dkn, 274 | stride_dvn, 275 | stride_q_idxm, 276 | stride_k_idxn, 277 | seqlen_q, 278 | block_size, 279 | headdim, 280 | v_headdim, 281 | smooth_block, 282 | BLOCK_HEADDIM: tl.constexpr, 283 | V_BLOCK_HEADDIM: tl.constexpr, 284 | EVEN_HEADDIM: tl.constexpr, 285 | EVEN_V_HEADDIM: tl.constexpr, 286 | BLOCK_M: tl.constexpr, 287 | BLOCK_N: tl.constexpr, 288 | ): 289 | # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N) 290 | block_id = start_n // block_size 291 | block_offs = seqlen_q + (start_n % block_size) * BLOCK_M - (block_size - 1) * BLOCK_M // 2 292 | begin_m = block_id * BLOCK_M * block_size 293 | # initialize row / col offsets 294 | offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) 295 | offs_m = tl.arange(0, BLOCK_M) 296 | offs_d = tl.arange(0, BLOCK_HEADDIM) 297 | offs_vd = tl.arange(0, V_BLOCK_HEADDIM) 298 | # initialize pointers to value-like data 299 | k_idx_ptrs = K_idx + offs_n * stride_k_idxn 300 | k_idx = tl.load(k_idx_ptrs).to(tl.int32) 301 | k_ptrs = K + (k_idx[:, None] * stride_kn + offs_d[None, :]) 302 | v_ptrs = V + (k_idx[:, None] * stride_vn + offs_vd[None, :]) 303 | # initialize dv and dk 304 | dv = tl.zeros([BLOCK_N, V_BLOCK_HEADDIM], dtype=tl.float32) 305 | dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) 306 | 307 | # k and v stay in SRAM throughout 308 | if EVEN_HEADDIM: 309 | k = tl.load(k_ptrs) 310 | else: 311 | k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0) 312 | if EVEN_V_HEADDIM: 313 | v = tl.load(v_ptrs) 314 | else: 315 | v = tl.load(v_ptrs, mask=offs_vd[None, :] < v_headdim, other=0.0) 316 | 317 | # loop over rows 318 | end_m = tl.minimum((block_id + 1) * BLOCK_M * block_size, seqlen_q) 319 | for start_m in range(begin_m, end_m, BLOCK_M): 320 | start_m = tl.multiple_of(start_m, BLOCK_M) 321 | if smooth_block: 322 | q_idx_ptrs = ((start_m + block_offs + offs_m) * stride_q_idxm) % seqlen_q 323 | else: 324 | q_idx_ptrs = (start_m + offs_m) * stride_q_idxm 325 | q_idx = tl.load(Q_idx + q_idx_ptrs).to(tl.int32) 326 | q_ptrs = Q + (q_idx[:, None] * stride_qm + offs_d[None, :]) 327 | # load q, k, v, do on-chip 328 | if EVEN_HEADDIM: 329 | q = tl.load(q_ptrs) 330 | else: 331 | q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0) 332 | # recompute p = softmax(qk, dim=-1).T 333 | qk = tl.dot(q, tl.trans(k)) 334 | if not EVEN_HEADDIM: 335 | tl.debug_barrier() 336 | lse_i = tl.load(LSE + q_idx) 337 | p = tl.exp(qk * softmax_scale - lse_i[:, None]) 338 | # compute dv 339 | do_ptrs = DO + (q_idx[:, None] * stride_dom + offs_vd[None, :]) 340 | if EVEN_V_HEADDIM: 341 | do = tl.load(do_ptrs) 342 | else: 343 | do = tl.load(do_ptrs, mask=offs_vd[None, :] < v_headdim, other=0.0) 344 | dv += tl.dot(tl.trans(p.to(do.dtype)), do) 345 | # compute dp = dot(v, do) 346 | if not EVEN_HEADDIM: 347 | tl.debug_barrier() 348 | dp = tl.dot(do, tl.trans(v)) 349 | # There's a race condition for headdim=48 350 | if not EVEN_HEADDIM: 351 | 
tl.debug_barrier() 352 | # compute ds = p * (dp - delta[:, None]) 353 | # Putting the subtraction after the dp matmul (instead of before) is slightly faster 354 | Di = tl.load(D + q_idx) 355 | # Converting ds to q.dtype here reduces register pressure and makes it much faster 356 | # for BLOCK_HEADDIM=128 357 | ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype) 358 | # compute dk = dot(ds.T, q) 359 | dk += tl.dot(tl.trans(ds), q) 360 | # compute dq 361 | if not EVEN_HEADDIM: # Otherewise there's a race condition when BIAS_TYPE='matrix' 362 | tl.debug_barrier() 363 | 364 | dq_ptrs = DQ + (q_idx[:, None] * stride_dqm + offs_d[None, :]) 365 | dq = tl.dot(ds, k) 366 | if EVEN_HEADDIM: 367 | tl.atomic_add(dq_ptrs, dq) 368 | else: 369 | tl.atomic_add(dq_ptrs, dq, mask=offs_d[None, :] < headdim) 370 | 371 | 372 | # write-back 373 | dv_ptrs = DV + (k_idx[:, None] * stride_dvn + offs_vd[None, :]) 374 | dk_ptrs = DK + (k_idx[:, None] * stride_dkn + offs_d[None, :]) 375 | _bwd_store_dx( 376 | dk_ptrs, 377 | dk, 378 | offs_d, 379 | headdim, 380 | even_headdim=EVEN_HEADDIM, 381 | ) 382 | _bwd_store_dx( 383 | dv_ptrs, 384 | dv, 385 | offs_vd, 386 | v_headdim, 387 | even_headdim=EVEN_V_HEADDIM, 388 | ) 389 | 390 | 391 | 392 | @triton.heuristics( 393 | { 394 | "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], 395 | "EVEN_V_HEADDIM": lambda args: args["v_headdim"] == args["V_BLOCK_HEADDIM"], 396 | } 397 | ) 398 | @triton.jit 399 | def _bwd_permuted_block_diagonal_kernel( 400 | Q, 401 | K, 402 | V, 403 | q_sort_idx, 404 | k_sort_idx, 405 | DO, 406 | DQ, 407 | DK, 408 | DV, 409 | LSE, 410 | D, 411 | softmax_scale, 412 | stride_qb, 413 | stride_qh, 414 | stride_qm, 415 | stride_kb, 416 | stride_kh, 417 | stride_kn, 418 | stride_vb, 419 | stride_vh, 420 | stride_vn, 421 | stride_q_sort_idxb, 422 | stride_q_sort_idxh, 423 | stride_q_sort_idxm, 424 | stride_k_sort_idxb, 425 | stride_k_sort_idxh, 426 | stride_k_sort_idxn, 427 | stride_dob, 428 | stride_doh, 429 | stride_dom, 430 | stride_dqb, 431 | stride_dqh, 432 | stride_dqm, 433 | stride_dkb, 434 | stride_dkh, 435 | stride_dkn, 436 | stride_dvb, 437 | stride_dvh, 438 | stride_dvn, 439 | nheads, 440 | seqlen_q, 441 | block_size, 442 | headdim, 443 | v_headdim, 444 | smooth_block, 445 | CACHE_KEY_SEQLEN_Q, 446 | CACHE_KEY_SEQLEN_K, 447 | BLOCK_HEADDIM: tl.constexpr, 448 | V_BLOCK_HEADDIM: tl.constexpr, 449 | EVEN_HEADDIM: tl.constexpr, 450 | EVEN_V_HEADDIM: tl.constexpr, 451 | BLOCK_M: tl.constexpr, 452 | BLOCK_N: tl.constexpr, 453 | ): 454 | off_hb = tl.program_id(1) 455 | off_b = off_hb // nheads 456 | off_h = off_hb % nheads 457 | # offset pointers for batch/head 458 | Q += off_b * stride_qb + off_h * stride_qh 459 | K += off_b * stride_kb + off_h * stride_kh 460 | V += off_b * stride_vb + off_h * stride_vh 461 | Q_idx = q_sort_idx + off_b * stride_q_sort_idxb + off_h * stride_q_sort_idxh 462 | K_idx = k_sort_idx + off_b * stride_k_sort_idxb + off_h * stride_k_sort_idxh 463 | DO += off_b * stride_dob + off_h * stride_doh 464 | DQ += off_b * stride_dqb + off_h * stride_dqh 465 | DK += off_b * stride_dkb + off_h * stride_dkh 466 | DV += off_b * stride_dvb + off_h * stride_dvh 467 | # pointer to row-wise quantities in value-like data 468 | D += off_hb * seqlen_q 469 | LSE += off_hb * seqlen_q 470 | 471 | start_n = tl.program_id(0) 472 | _bwd_blocked_kernel_one_col( 473 | start_n=start_n, 474 | Q=Q, 475 | K=K, 476 | V=V, 477 | Q_idx=Q_idx, 478 | K_idx=K_idx, 479 | DO=DO, 480 | DQ=DQ, 481 | DK=DK, 482 | DV=DV, 483 | LSE=LSE, 484 | D=D, 
485 | softmax_scale=softmax_scale, 486 | stride_qm=stride_qm, 487 | stride_kn=stride_kn, 488 | stride_vn=stride_vn, 489 | stride_dom=stride_dom, 490 | stride_dqm=stride_dqm, 491 | stride_dkn=stride_dkn, 492 | stride_dvn=stride_dvn, 493 | stride_q_idxm=stride_q_sort_idxm, 494 | stride_k_idxn=stride_k_sort_idxn, 495 | seqlen_q=seqlen_q, 496 | block_size=block_size // BLOCK_N, 497 | headdim=headdim, 498 | v_headdim=v_headdim, 499 | smooth_block=smooth_block, 500 | BLOCK_HEADDIM=BLOCK_HEADDIM, 501 | V_BLOCK_HEADDIM=V_BLOCK_HEADDIM, 502 | EVEN_HEADDIM=EVEN_HEADDIM, 503 | EVEN_V_HEADDIM=EVEN_V_HEADDIM, 504 | BLOCK_M=BLOCK_M, 505 | BLOCK_N=BLOCK_N, 506 | ) 507 | 508 | 509 | 510 | @triton.heuristics( 511 | { 512 | "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], 513 | "EVEN_V_HEADDIM": lambda args: args["v_headdim"] == args["V_BLOCK_HEADDIM"], 514 | } 515 | ) 516 | @triton.jit 517 | def _bwd_sampled_col_kernel( 518 | Q, 519 | K, 520 | V, 521 | DO, 522 | DQ, 523 | DK, 524 | DV, 525 | LSE, 526 | D, 527 | softmax_scale, 528 | stride_qb, 529 | stride_qh, 530 | stride_qm, 531 | stride_kb, 532 | stride_kh, 533 | stride_kn, 534 | stride_vb, 535 | stride_vh, 536 | stride_vn, 537 | stride_dob, 538 | stride_doh, 539 | stride_dom, 540 | stride_dqb, 541 | stride_dqh, 542 | stride_dqm, 543 | stride_dkb, 544 | stride_dkh, 545 | stride_dkn, 546 | stride_dvb, 547 | stride_dvh, 548 | stride_dvn, 549 | nheads, 550 | seqlen_q, 551 | headdim, 552 | v_headdim, 553 | CACHE_KEY_SEQLEN_Q, 554 | CACHE_KEY_SEQLEN_K, 555 | BLOCK_HEADDIM: tl.constexpr, 556 | V_BLOCK_HEADDIM: tl.constexpr, 557 | EVEN_HEADDIM: tl.constexpr, 558 | EVEN_V_HEADDIM: tl.constexpr, 559 | BLOCK_M: tl.constexpr, 560 | BLOCK_N: tl.constexpr, 561 | ): 562 | off_hb = tl.program_id(1) 563 | off_b = off_hb // nheads 564 | off_h = off_hb % nheads 565 | # offset pointers for batch/head 566 | Q += off_b * stride_qb + off_h * stride_qh 567 | DO += off_b * stride_dob + off_h * stride_doh 568 | DQ += off_b * stride_dqb + off_h * stride_dqh 569 | # pointer to row-wise quantities in value-like data 570 | D += off_hb * seqlen_q 571 | LSE += off_hb * seqlen_q 572 | 573 | start_n = tl.program_id(0) 574 | 575 | # initialize row / col offsets 576 | offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) 577 | offs_m = tl.arange(0, BLOCK_M) 578 | offs_d = tl.arange(0, BLOCK_HEADDIM) 579 | offs_vd = tl.arange(0, V_BLOCK_HEADDIM) 580 | # initialize pointers to value-like data 581 | k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :]) 582 | v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_vd[None, :]) 583 | # initialize dv and dk 584 | dv = tl.zeros([BLOCK_N, V_BLOCK_HEADDIM], dtype=tl.float32) 585 | dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) 586 | 587 | # k and v stay in SRAM throughout 588 | if EVEN_HEADDIM: 589 | k = tl.load(k_ptrs) 590 | else: 591 | k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0) 592 | if EVEN_V_HEADDIM: 593 | v = tl.load(v_ptrs) 594 | else: 595 | v = tl.load(v_ptrs, mask=offs_vd[None, :] < v_headdim, other=0.0) 596 | 597 | # loop over rows 598 | for start_m in range(0, seqlen_q, BLOCK_M): 599 | start_m = tl.multiple_of(start_m, BLOCK_M) 600 | offs_m_curr = start_m + offs_m 601 | q_ptrs = Q + (offs_m_curr[:, None] * stride_qm + offs_d[None, :]) 602 | # load q, k, v, do on-chip 603 | if EVEN_HEADDIM: 604 | q = tl.load(q_ptrs) 605 | else: 606 | q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0) 607 | # 
recompute p = softmax(qk, dim=-1).T 608 | qk = tl.dot(q, tl.trans(k)) 609 | if not EVEN_HEADDIM: 610 | tl.debug_barrier() 611 | lse_i = tl.load(LSE + offs_m_curr) 612 | p = tl.exp(qk * softmax_scale - lse_i[:, None]) 613 | # compute dv 614 | do_ptrs = DO + (offs_m_curr[:, None] * stride_dom + offs_vd[None, :]) 615 | if EVEN_V_HEADDIM: 616 | do = tl.load(do_ptrs) 617 | else: 618 | do = tl.load(do_ptrs, mask=offs_vd[None, :] < v_headdim, other=0.0) 619 | dv += tl.dot(tl.trans(p.to(do.dtype)), do) 620 | # compute dp = dot(v, do) 621 | if not EVEN_HEADDIM: 622 | tl.debug_barrier() 623 | dp = tl.dot(do, tl.trans(v)) 624 | # There's a race condition for headdim=48 625 | if not EVEN_HEADDIM: 626 | tl.debug_barrier() 627 | # compute ds = p * (dp - delta[:, None]) 628 | # Putting the subtraction after the dp matmul (instead of before) is slightly faster 629 | Di = tl.load(D + offs_m_curr) 630 | # Converting ds to q.dtype here reduces register pressure and makes it much faster 631 | # for BLOCK_HEADDIM=128 632 | ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype) 633 | # compute dk = dot(ds.T, q) 634 | dk += tl.dot(tl.trans(ds), q) 635 | # compute dq 636 | if not EVEN_HEADDIM: # Otherewise there's a race condition when BIAS_TYPE='matrix' 637 | tl.debug_barrier() 638 | 639 | dq_ptrs = DQ + (offs_m_curr[:, None] * stride_dqm + offs_d[None, :]) 640 | dq = tl.dot(ds, k) 641 | if EVEN_HEADDIM: 642 | tl.atomic_add(dq_ptrs, dq) 643 | else: 644 | tl.atomic_add(dq_ptrs, dq, mask=offs_d[None, :] < headdim) 645 | 646 | dv_ptrs = DV + off_b * stride_dvb + off_h * stride_dvh + (offs_n[:, None] * stride_dvn + offs_vd[None, :]) 647 | dk_ptrs = DK + off_b * stride_dkb + off_h * stride_dkh + (offs_n[:, None] * stride_dkn + offs_d[None, :]) 648 | dk += tl.load(dk_ptrs) 649 | dv += tl.load(dv_ptrs) 650 | _bwd_store_dx( 651 | dk_ptrs, 652 | dk, 653 | offs_d, 654 | headdim, 655 | even_headdim=EVEN_HEADDIM, 656 | ) 657 | _bwd_store_dx( 658 | dv_ptrs, 659 | dv, 660 | offs_vd, 661 | v_headdim, 662 | even_headdim=EVEN_V_HEADDIM, 663 | ) 664 | 665 | return 666 | 667 | 668 | def _hyper_attn_forward(q, k, v, q_sort_idx, k_sort_idx, block_size, sample_size, softmax_scale=None, 669 | smooth_block=False): 670 | """ 671 | Initializes the forward kernel and schedules thread blocks and runs them in parallel 672 | """ 673 | # shape constraints 674 | batch, seqlen_q, nheads, d = q.shape 675 | _, seqlen_k, _, _ = k.shape 676 | _, seqlen_q_idx,_ = q_sort_idx.shape 677 | _, seqlen_k_idx, _ = k_sort_idx.shape 678 | assert k.shape == (batch, seqlen_k, nheads, d) 679 | assert v.shape[:3] == (batch, seqlen_k, nheads) 680 | assert q_sort_idx.shape == q.shape[:3] 681 | assert k_sort_idx.shape == k.shape[:3] 682 | assert d <= 128, "FlashAttention only support head dimensions up to 128" 683 | assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type" 684 | assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16" 685 | assert q.is_cuda and k.is_cuda and v.is_cuda and q_sort_idx.is_cuda and k_sort_idx.is_cuda 686 | softmax_scale = softmax_scale or 1.0 / math.sqrt(d) 687 | lse = torch.empty((batch, nheads, seqlen_q), device=q.device, dtype=torch.float32) 688 | # o = torch.empty_like(q) 689 | o = torch.empty((batch, seqlen_q, nheads, v.shape[-1]), device=q.device, dtype=q.dtype) 690 | 691 | BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) 692 | v_headdim = v.shape[3] 693 | V_BLOCK_HEADDIM = max(triton.next_power_of_2(v_headdim), 16) 694 | BLOCK = 128 695 | assert seqlen_k % BLOCK == 0, f'keys sequence 
length must be divisible by {BLOCK}' 696 | num_warps = 4 if d <= 64 else 8 697 | grid = lambda META: (triton.cdiv(seqlen_q_idx, META["BLOCK_M"]), batch * nheads) 698 | _fwd_hyper_kernel[grid]( 699 | Q=q, 700 | K=k, 701 | V=v, 702 | q_sort_idx=q_sort_idx, 703 | k_sort_idx=k_sort_idx, 704 | Out=o, 705 | Lse=lse, 706 | softmax_scale=softmax_scale, 707 | stride_qb=q.stride(0), 708 | stride_qh=q.stride(2), 709 | stride_qm=q.stride(1), 710 | stride_kb=k.stride(0), 711 | stride_kh=k.stride(2), 712 | stride_kn=k.stride(1), 713 | stride_vb=v.stride(0), 714 | stride_vh=v.stride(2), 715 | stride_vn=v.stride(1), 716 | stride_q_sort_idxb=q_sort_idx.stride(0), 717 | stride_q_sort_idxh=q_sort_idx.stride(2), 718 | stride_q_sort_idxm=q_sort_idx.stride(1), 719 | stride_k_sort_idxb=k_sort_idx.stride(0), 720 | stride_k_sort_idxh=k_sort_idx.stride(2), 721 | stride_k_sort_idxn=k_sort_idx.stride(1), 722 | stride_ob=o.stride(0), 723 | stride_oh=o.stride(2), 724 | stride_om=o.stride(1), 725 | nheads=nheads, 726 | block_size=triton.cdiv(block_size, BLOCK), 727 | sample_size=triton.cdiv(sample_size, BLOCK), 728 | seqlen_k=seqlen_k, 729 | seqlen_q=seqlen_q, 730 | headdim=d, 731 | v_headdim=v_headdim, 732 | smooth_block=smooth_block, 733 | CACHE_KEY_SEQLEN_Q=seqlen_q // 32, 734 | CACHE_KEY_SEQLEN_K=seqlen_k // 32, 735 | BLOCK_HEADDIM=BLOCK_HEADDIM, 736 | V_BLOCK_HEADDIM=V_BLOCK_HEADDIM, 737 | BLOCK_M=BLOCK, 738 | BLOCK_N=BLOCK, 739 | num_warps=num_warps, 740 | num_stages=1, 741 | ) 742 | return o, lse, softmax_scale # softmax_scale could have been updated 743 | 744 | 745 | def _hyper_attn_backward( 746 | do, q, k, v, q_sort_idx, k_sort_idx, o, lse, dq, dk, dv, block_size, sample_size, softmax_scale=None, 747 | smooth_block=False): 748 | """ 749 | Initializes the backward kernel and schedules thread blocks and runs them in parallel 750 | """ 751 | # Make sure that the last dimension is contiguous 752 | if do.stride(-1) != 1: 753 | do = do.contiguous() 754 | batch, seqlen_q, nheads, d = q.shape 755 | _, seqlen_k, _, _ = k.shape 756 | # assert d in {16, 32, 64, 128} 757 | assert d <= 128 758 | assert lse.shape == (batch, nheads, seqlen_q) 759 | assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1 760 | assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == do.stride(-1) == 1 761 | softmax_scale = softmax_scale or 1.0 / math.sqrt(d) 762 | 763 | dq_accum = torch.zeros_like(q, dtype=torch.float32) 764 | delta = torch.empty_like(lse) 765 | 766 | v_headdim = v.shape[3] 767 | V_BLOCK_HEADDIM = max(triton.next_power_of_2(v_headdim), 16) 768 | BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) 769 | grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) 770 | _bwd_preprocess_do_o_dot[grid]( 771 | Out=o, 772 | DO=do, 773 | Delta=delta, 774 | stride_ob=o.stride(0), 775 | stride_oh=o.stride(2), 776 | stride_om=o.stride(1), 777 | stride_dob=do.stride(0), 778 | stride_doh=do.stride(2), 779 | stride_dom=do.stride(1), 780 | nheads=nheads, 781 | seqlen_q=seqlen_q, 782 | v_headdim=v_headdim, 783 | BLOCK_M=128, 784 | V_BLOCK_HEADDIM=V_BLOCK_HEADDIM, 785 | ) 786 | 787 | BLOCK = 128 788 | num_warps = 8 789 | grid = lambda META: (triton.cdiv(seqlen_k, BLOCK), batch * nheads) 790 | _bwd_permuted_block_diagonal_kernel[grid]( 791 | Q=q, 792 | K=k, 793 | V=v, 794 | q_sort_idx=q_sort_idx, 795 | k_sort_idx=k_sort_idx, 796 | DO=do, 797 | DQ=dq_accum, 798 | DK=dk, 799 | DV=dv, 800 | LSE=lse, 801 | D=delta, 802 | softmax_scale=softmax_scale, 803 | stride_qb=q.stride(0), 804 | stride_qh=q.stride(2), 805 | 
stride_qm=q.stride(1), 806 | stride_kb=k.stride(0), 807 | stride_kh=k.stride(2), 808 | stride_kn=k.stride(1), 809 | stride_vb=v.stride(0), 810 | stride_vh=v.stride(2), 811 | stride_vn=v.stride(1), 812 | stride_q_sort_idxb=q_sort_idx.stride(0), 813 | stride_q_sort_idxh=q_sort_idx.stride(2), 814 | stride_q_sort_idxm=q_sort_idx.stride(1), 815 | stride_k_sort_idxb=k_sort_idx.stride(0), 816 | stride_k_sort_idxh=k_sort_idx.stride(2), 817 | stride_k_sort_idxn=k_sort_idx.stride(1), 818 | stride_dob=do.stride(0), 819 | stride_doh=do.stride(2), 820 | stride_dom=do.stride(1), 821 | stride_dqb=dq_accum.stride(0), 822 | stride_dqh=dq_accum.stride(2), 823 | stride_dqm=dq_accum.stride(1), 824 | stride_dkb=dk.stride(0), 825 | stride_dkh=dk.stride(2), 826 | stride_dkn=dk.stride(1), 827 | stride_dvb=dv.stride(0), 828 | stride_dvh=dv.stride(2), 829 | stride_dvn=dv.stride(1), 830 | nheads=nheads, 831 | seqlen_q=seqlen_q, 832 | block_size=block_size, 833 | headdim=d, 834 | v_headdim=v_headdim, 835 | smooth_block=smooth_block, 836 | CACHE_KEY_SEQLEN_Q=seqlen_q // 32, 837 | CACHE_KEY_SEQLEN_K=seqlen_k // 32, # key for triton cache (limit number of compilations) 838 | BLOCK_HEADDIM=BLOCK_HEADDIM, 839 | V_BLOCK_HEADDIM=V_BLOCK_HEADDIM, 840 | BLOCK_M=BLOCK, 841 | BLOCK_N=BLOCK, 842 | num_warps=num_warps, 843 | num_stages=1, 844 | ) 845 | 846 | grid = lambda META: (triton.cdiv(sample_size, BLOCK), batch * nheads) 847 | _bwd_sampled_col_kernel[grid]( 848 | Q=q, 849 | K=k, 850 | V=v, 851 | DO=do, 852 | DQ=dq_accum, 853 | DK=dk, 854 | DV=dv, 855 | LSE=lse, 856 | D=delta, 857 | softmax_scale=softmax_scale, 858 | stride_qb=q.stride(0), 859 | stride_qh=q.stride(2), 860 | stride_qm=q.stride(1), 861 | stride_kb=k.stride(0), 862 | stride_kh=k.stride(2), 863 | stride_kn=k.stride(1), 864 | stride_vb=v.stride(0), 865 | stride_vh=v.stride(2), 866 | stride_vn=v.stride(1), 867 | stride_dob=do.stride(0), 868 | stride_doh=do.stride(2), 869 | stride_dom=do.stride(1), 870 | stride_dqb=dq_accum.stride(0), 871 | stride_dqh=dq_accum.stride(2), 872 | stride_dqm=dq_accum.stride(1), 873 | stride_dkb=dk.stride(0), 874 | stride_dkh=dk.stride(2), 875 | stride_dkn=dk.stride(1), 876 | stride_dvb=dv.stride(0), 877 | stride_dvh=dv.stride(2), 878 | stride_dvn=dv.stride(1), 879 | nheads=nheads, 880 | seqlen_q=seqlen_q, 881 | headdim=d, 882 | v_headdim=v_headdim, 883 | CACHE_KEY_SEQLEN_Q=seqlen_q // 32, 884 | CACHE_KEY_SEQLEN_K=seqlen_k // 32, # key for triton cache (limit number of compilations) 885 | BLOCK_HEADDIM=BLOCK_HEADDIM, 886 | V_BLOCK_HEADDIM=V_BLOCK_HEADDIM, 887 | BLOCK_M=BLOCK, 888 | BLOCK_N=BLOCK, 889 | num_warps=num_warps, 890 | num_stages=1, 891 | ) 892 | dq.copy_(dq_accum) 893 | 894 | 895 | class HyperAttnFunc(torch.autograd.Function): 896 | @staticmethod 897 | def forward(ctx, q, k, v, q_sort_idx, k_sort_idx, block_size, sample_size=0, softmax_scale=None, 898 | smooth_block=False): 899 | """ 900 | q, k: queries and keys (batch_size, seqlen, nheads, headdim), seqlen must be integer power of two 901 | v: values (batch_size, seqlen, nheads, v_headdim) 902 | q_sort_idx: the permutation for queries (batch_size, seqlen, nheads) 903 | k_sort_idx: the permutation for keys and values (batch_size, seqlen, nheads) 904 | block_size: side length of block diagonal blocks 905 | sample_size: number of sampled columns, must be multiple of 128 906 | softmax_scale: if none then scale will be 1/sqrt(headdim) 907 | smooth_block: if true the block diagonals will be smoothened to resemble banded digonal patterns 908 | """ 909 | # Make sure that the last 
dimension is contiguous 910 | q, k, v = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]] 911 | assert sample_size % 128 == 0 912 | o, lse, ctx.softmax_scale = _hyper_attn_forward( 913 | q, k, v, q_sort_idx, k_sort_idx, block_size, sample_size, 914 | softmax_scale=softmax_scale, smooth_block=smooth_block, 915 | ) 916 | ctx.save_for_backward(q, k, v, q_sort_idx, k_sort_idx, o, lse) 917 | ctx.block_size = block_size 918 | ctx.sample_size = sample_size 919 | ctx.smooth_block = smooth_block 920 | return o, lse 921 | 922 | @staticmethod 923 | def backward(ctx, do, dlse_use_needed=None): 924 | q, k, v, q_sort_idx, k_sort_idx, o, lse = ctx.saved_tensors 925 | dq = torch.zeros_like(q) 926 | dk = torch.zeros_like(k) 927 | dv = torch.zeros_like(v) 928 | _hyper_attn_backward( 929 | do, 930 | q, 931 | k, 932 | v, 933 | q_sort_idx, 934 | k_sort_idx, 935 | o, 936 | lse, 937 | dq, 938 | dk, 939 | dv, 940 | ctx.block_size, 941 | ctx.sample_size, 942 | softmax_scale=ctx.softmax_scale, 943 | smooth_block=ctx.smooth_block, 944 | ) 945 | return dq, dk, dv, None, None, None, None, None, None 946 | 947 | 948 | hyper_attn_func = HyperAttnFunc.apply 949 | 950 | -------------------------------------------------------------------------------- /src/flash_attn_triton.py: -------------------------------------------------------------------------------- 1 | """ 2 | *Experimental* implementation of FlashAttention in Triton. 3 | Tested with triton==2.0.0.dev20221202. 4 | Triton 2.0 has a new backend (MLIR) but seems like it doesn't yet work for head dimensions 5 | other than 64: 6 | https://github.com/openai/triton/blob/d376020f90002757eea3ea9475d4f7cfc2ec5ead/python/triton/ops/flash_attention.py#L207 7 | We'll update this implementation with the new Triton backend once this is fixed. 8 | 9 | We use the FlashAttention implementation from Phil Tillet a starting point. 10 | https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py 11 | 12 | Changes: 13 | - Implement both causal and non-causal attention. 14 | - Implement both self-attention and cross-attention. 15 | - Support arbitrary seqlens (not just multiples of 128), for both forward and backward. 16 | - Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward. 17 | - Support attention bias. 18 | - Speed up the forward pass a bit, and only store the LSE instead of m and l. 19 | - Make the backward for d=128 much faster by reducing register spilling. 20 | - Optionally parallelize the backward pass across seqlen_k, to deal with the case of 21 | small batch size * nheads. 22 | 23 | Caution: 24 | - This is an *experimental* implementation. The forward pass should be quite robust but 25 | I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler). 26 | - This implementation has only been tested on A100. 27 | - If you plan to use headdim other than 64 and 128, you should test for race conditions 28 | (due to the Triton compiler), as done in tests/test_flash_attn.py 29 | "test_flash_attn_triton_race_condition". I've tested and fixed many race conditions 30 | for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident 31 | that there are none left for other head dimensions. 32 | 33 | Differences between this Triton version and the CUDA version: 34 | - Triton version doesn't support dropout. 35 | - Triton forward is generally faster than CUDA forward, while Triton backward is 36 | generally slower than CUDA backward. 
Overall Triton forward + backward is slightly slower 37 | than CUDA forward + backward. 38 | - Triton version doesn't support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor). 39 | - Triton version supports attention bias, while CUDA version doesn't. 40 | """ 41 | 42 | import math 43 | 44 | import torch 45 | import triton 46 | import triton.language as tl 47 | 48 | 49 | @triton.heuristics( 50 | { 51 | "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, 52 | "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, 53 | "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], 54 | } 55 | ) 56 | @triton.jit 57 | def _fwd_kernel( 58 | Q, 59 | K, 60 | V, 61 | Bias, 62 | Out, 63 | Lse, 64 | softmax_scale, 65 | stride_qb, 66 | stride_qh, 67 | stride_qm, 68 | stride_kb, 69 | stride_kh, 70 | stride_kn, 71 | stride_vb, 72 | stride_vh, 73 | stride_vn, 74 | stride_bb, 75 | stride_bh, 76 | stride_bm, 77 | stride_ob, 78 | stride_oh, 79 | stride_om, 80 | nheads, 81 | seqlen_q, 82 | seqlen_k, 83 | seqlen_q_rounded, 84 | headdim, 85 | CACHE_KEY_SEQLEN_Q, 86 | CACHE_KEY_SEQLEN_K, 87 | BIAS_TYPE: tl.constexpr, 88 | IS_CAUSAL: tl.constexpr, 89 | BLOCK_HEADDIM: tl.constexpr, 90 | EVEN_M: tl.constexpr, 91 | EVEN_N: tl.constexpr, 92 | EVEN_HEADDIM: tl.constexpr, 93 | BLOCK_M: tl.constexpr, 94 | BLOCK_N: tl.constexpr, 95 | ): 96 | start_m = tl.program_id(0) 97 | off_hb = tl.program_id(1) 98 | off_b = off_hb // nheads 99 | off_h = off_hb % nheads 100 | # off_b = tl.program_id(1) 101 | # off_h = tl.program_id(2) 102 | # off_hb = off_b * nheads + off_h 103 | # initialize offsets 104 | offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) 105 | offs_n = tl.arange(0, BLOCK_N) 106 | offs_d = tl.arange(0, BLOCK_HEADDIM) 107 | # Initialize pointers to Q, K, V 108 | # Adding parenthesis around indexing might use int32 math instead of int64 math? 109 | # https://github.com/openai/triton/issues/741 110 | # I'm seeing a tiny bit of difference (5-7us) 111 | q_ptrs = ( 112 | Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :]) 113 | ) 114 | k_ptrs = ( 115 | K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :]) 116 | ) 117 | v_ptrs = ( 118 | V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :]) 119 | ) 120 | if BIAS_TYPE == "vector": 121 | b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n 122 | elif BIAS_TYPE == "matrix": 123 | b_ptrs = ( 124 | Bias 125 | + off_b * stride_bb 126 | + off_h * stride_bh 127 | + (offs_m[:, None] * stride_bm + offs_n[None, :]) 128 | ) 129 | # initialize pointer to m and l 130 | lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") 131 | m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") 132 | acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) 133 | # load q: it will stay in SRAM throughout 134 | # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call 135 | # tl.load(q_ptrs), we get the wrong output! 
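    # Descriptive note: the EVEN_* flags come from the @triton.heuristics above. EVEN_M / EVEN_N
    # mean that seqlen_q / seqlen_k are exact multiples of BLOCK_M / BLOCK_N, and EVEN_HEADDIM
    # means headdim fills BLOCK_HEADDIM exactly, so each branch below only applies boundary
    # masks along the dimensions that can actually run past the end of the tensors.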
136 | if EVEN_M & EVEN_N: 137 | if EVEN_HEADDIM: 138 | q = tl.load(q_ptrs) 139 | else: 140 | q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0) 141 | else: 142 | if EVEN_HEADDIM: 143 | q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0) 144 | else: 145 | q = tl.load( 146 | q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0 147 | ) 148 | # loop over k, v and update accumulator 149 | end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k) 150 | for start_n in range(0, end_n, BLOCK_N): 151 | start_n = tl.multiple_of(start_n, BLOCK_N) 152 | # -- compute qk ---- 153 | if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition 154 | if EVEN_HEADDIM: 155 | k = tl.load(k_ptrs + start_n * stride_kn) 156 | else: 157 | k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0) 158 | else: 159 | if EVEN_HEADDIM: 160 | k = tl.load( 161 | k_ptrs + start_n * stride_kn, 162 | mask=(start_n + offs_n)[:, None] < seqlen_k, 163 | other=0.0, 164 | ) 165 | else: 166 | k = tl.load( 167 | k_ptrs + start_n * stride_kn, 168 | mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), 169 | other=0.0, 170 | ) 171 | qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) 172 | qk += tl.dot(q, tl.trans(k)) 173 | # Trying to combine the two masks seem to make the result wrong 174 | if not EVEN_N: # Need to mask out otherwise the softmax is wrong 175 | qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf")) 176 | if IS_CAUSAL: 177 | qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf")) 178 | if BIAS_TYPE != "none": 179 | if BIAS_TYPE == "vector": 180 | if EVEN_N: 181 | bias = tl.load(b_ptrs + start_n).to(tl.float32) 182 | else: 183 | bias = tl.load( 184 | b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0 185 | ).to(tl.float32) 186 | bias = bias[None, :] 187 | elif BIAS_TYPE == "matrix": 188 | if EVEN_M & EVEN_N: 189 | bias = tl.load(b_ptrs + start_n).to(tl.float32) 190 | else: 191 | bias = tl.load( 192 | b_ptrs + start_n, 193 | mask=(offs_m[:, None] < seqlen_q) 194 | & ((start_n + offs_n)[None, :] < seqlen_k), 195 | other=0.0, 196 | ).to(tl.float32) 197 | # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler 198 | # can then fuse the mult and add into an fma instruction. But if we have bias we need to 199 | # to multiply with softmax_scale here. 
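            # Descriptive note: the lines below perform one step of the streaming (online) softmax.
            # m_ij is the new running maximum in log space (lower-bounded by the running lse_i),
            # p holds the unnormalized probabilities for this key block, and acc_o is rescaled by
            # exp(m_i - m_ij) before adding p @ v, so that the update
            #   lse_i = m_ij + log(exp(lse_i_old - m_ij) + sum(p))
            # stays numerically stable.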
200 | qk = qk * softmax_scale + bias 201 | m_ij = tl.maximum(tl.max(qk, 1), lse_i) 202 | p = tl.exp(qk - m_ij[:, None]) 203 | else: 204 | m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i) 205 | p = tl.exp(qk * softmax_scale - m_ij[:, None]) 206 | l_ij = tl.sum(p, 1) 207 | 208 | # scale acc_o 209 | acc_o_scale = tl.exp(m_i - m_ij) 210 | 211 | # # -- update output accumulator acc_o -- 212 | acc_o = acc_o * acc_o_scale[:, None] 213 | 214 | if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition 215 | if EVEN_HEADDIM: 216 | v = tl.load(v_ptrs + start_n * stride_vn) 217 | else: 218 | v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0) 219 | else: 220 | if EVEN_HEADDIM: 221 | v = tl.load( 222 | v_ptrs + start_n * stride_vn, 223 | mask=(start_n + offs_n)[:, None] < seqlen_k, 224 | other=0.0, 225 | ) 226 | else: 227 | v = tl.load( 228 | v_ptrs + start_n * stride_vn, 229 | mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), 230 | other=0.0, 231 | ) 232 | p = p.to(v.dtype) 233 | acc_o += tl.dot(p, v) 234 | 235 | # -- update statistics 236 | m_i = m_ij 237 | l_i_new = tl.exp(lse_i - m_ij) + l_ij 238 | lse_i = m_ij + tl.log(l_i_new) 239 | 240 | o_scale = tl.exp(m_i - lse_i) 241 | acc_o = acc_o * o_scale[:, None] 242 | # rematerialize offsets to save registers 243 | start_m = tl.program_id(0) 244 | offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) 245 | # write back l and m 246 | lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m 247 | tl.store(lse_ptrs, lse_i) 248 | # initialize pointers to output 249 | offs_d = tl.arange(0, BLOCK_HEADDIM) 250 | out_ptrs = ( 251 | Out 252 | + off_b * stride_ob 253 | + off_h * stride_oh 254 | + (offs_m[:, None] * stride_om + offs_d[None, :]) 255 | ) 256 | if EVEN_M: 257 | if EVEN_HEADDIM: 258 | tl.store(out_ptrs, acc_o) 259 | else: 260 | tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim) 261 | else: 262 | if EVEN_HEADDIM: 263 | tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q) 264 | else: 265 | tl.store( 266 | out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim) 267 | ) 268 | 269 | 270 | @triton.jit 271 | def _bwd_preprocess_do_o_dot( 272 | Out, 273 | DO, 274 | Delta, 275 | stride_ob, 276 | stride_oh, 277 | stride_om, 278 | stride_dob, 279 | stride_doh, 280 | stride_dom, 281 | nheads, 282 | seqlen_q, 283 | seqlen_q_rounded, 284 | headdim, 285 | BLOCK_M: tl.constexpr, 286 | BLOCK_HEADDIM: tl.constexpr, 287 | ): 288 | start_m = tl.program_id(0) 289 | off_hb = tl.program_id(1) 290 | off_b = off_hb // nheads 291 | off_h = off_hb % nheads 292 | # initialize offsets 293 | offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) 294 | offs_d = tl.arange(0, BLOCK_HEADDIM) 295 | # load 296 | o = tl.load( 297 | Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :], 298 | mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), 299 | other=0.0, 300 | ).to(tl.float32) 301 | do = tl.load( 302 | DO 303 | + off_b * stride_dob 304 | + off_h * stride_doh 305 | + offs_m[:, None] * stride_dom 306 | + offs_d[None, :], 307 | mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), 308 | other=0.0, 309 | ).to(tl.float32) 310 | delta = tl.sum(o * do, axis=1) 311 | # write-back 312 | tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta) 313 | 314 | 315 | @triton.jit 316 | def _bwd_store_dx( 317 | dx_ptrs, 318 | dx, 319 | offs_n, 320 | offs_d, 321 | seqlen, 322 | headdim, 323 | EVEN_M: tl.constexpr, 
324 | EVEN_N: tl.constexpr, 325 | even_headdim, 326 | ): 327 | # [2022-11-01] TD: Same bug. In the case of EVEN_N=True and EVEN_M=False, 328 | # if we just call tl.store(dv_ptrs), there's a race condition 329 | if EVEN_N & EVEN_M: 330 | if even_headdim: 331 | tl.store(dx_ptrs, dx) 332 | else: 333 | tl.store(dx_ptrs, dx, mask=offs_d[None, :] < headdim) 334 | else: 335 | if even_headdim: 336 | tl.store(dx_ptrs, dx, mask=offs_n[:, None] < seqlen) 337 | else: 338 | tl.store(dx_ptrs, dx, mask=(offs_n[:, None] < seqlen) & (offs_d[None, :] < headdim)) 339 | 340 | 341 | @triton.jit 342 | def _bwd_kernel_one_col_block( 343 | start_n, 344 | Q, 345 | K, 346 | V, 347 | Bias, 348 | DO, 349 | DQ, 350 | DK, 351 | DV, 352 | LSE, 353 | D, 354 | softmax_scale, 355 | stride_qm, 356 | stride_kn, 357 | stride_vn, 358 | stride_bm, 359 | stride_dom, 360 | stride_dqm, 361 | stride_dkn, 362 | stride_dvn, 363 | seqlen_q, 364 | seqlen_k, 365 | headdim, 366 | ATOMIC_ADD: tl.constexpr, 367 | BIAS_TYPE: tl.constexpr, 368 | IS_CAUSAL: tl.constexpr, 369 | BLOCK_HEADDIM: tl.constexpr, 370 | EVEN_M: tl.constexpr, 371 | EVEN_N: tl.constexpr, 372 | EVEN_HEADDIM: tl.constexpr, 373 | BLOCK_M: tl.constexpr, 374 | BLOCK_N: tl.constexpr, 375 | ): 376 | # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N) 377 | begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M 378 | # initialize row/col offsets 379 | offs_qm = begin_m + tl.arange(0, BLOCK_M) 380 | offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) 381 | offs_m = tl.arange(0, BLOCK_M) 382 | offs_d = tl.arange(0, BLOCK_HEADDIM) 383 | # initialize pointers to value-like data 384 | q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :]) 385 | k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :]) 386 | v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :]) 387 | do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :]) 388 | dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :]) 389 | if BIAS_TYPE == "vector": 390 | b_ptrs = Bias + offs_n 391 | elif BIAS_TYPE == "matrix": 392 | b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :]) 393 | # initialize dv and dk 394 | dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) 395 | dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) 396 | # There seems to be some problem with Triton pipelining that makes results wrong for 397 | # headdim=64, seqlen=(113, 255), bias_type='matrix'. In this case the for loop 398 | # may have zero step, and pipelining with the bias matrix could screw it up. 399 | # So we just exit early. 400 | if begin_m >= seqlen_q: 401 | dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) 402 | dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) 403 | _bwd_store_dx( 404 | dk_ptrs, 405 | dk, 406 | offs_n, 407 | offs_d, 408 | seqlen_k, 409 | headdim, 410 | EVEN_M=EVEN_M, 411 | EVEN_N=EVEN_N, 412 | even_headdim=EVEN_HEADDIM, 413 | ) 414 | _bwd_store_dx( 415 | dv_ptrs, 416 | dv, 417 | offs_n, 418 | offs_d, 419 | seqlen_k, 420 | headdim, 421 | EVEN_M=EVEN_M, 422 | EVEN_N=EVEN_N, 423 | even_headdim=EVEN_HEADDIM, 424 | ) 425 | return 426 | # k and v stay in SRAM throughout 427 | # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False, 428 | # if we just call tl.load(k_ptrs), we get the wrong output! 
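    # Descriptive note: each program instance owns one BLOCK_N block of key/value columns.
    # k and v are loaded once below and reused for every BLOCK_M row block in the loop, which
    # accumulates dk/dv locally and writes dq back either with a read-modify-write or with
    # tl.atomic_add, depending on the ATOMIC_ADD flag.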
429 | if EVEN_N & EVEN_M: 430 | if EVEN_HEADDIM: 431 | k = tl.load(k_ptrs) 432 | v = tl.load(v_ptrs) 433 | else: 434 | k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0) 435 | v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0) 436 | else: 437 | if EVEN_HEADDIM: 438 | k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0) 439 | v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0) 440 | else: 441 | k = tl.load( 442 | k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0 443 | ) 444 | v = tl.load( 445 | v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0 446 | ) 447 | # loop over rows 448 | num_block_m = tl.cdiv(seqlen_q, BLOCK_M) 449 | for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M): 450 | start_m = tl.multiple_of(start_m, BLOCK_M) 451 | offs_m_curr = start_m + offs_m 452 | # load q, k, v, do on-chip 453 | # Same bug as below. Otherwise gives wrong result for headdim=40, seqlen=(128, 117) 454 | if EVEN_M & EVEN_HEADDIM: 455 | q = tl.load(q_ptrs) 456 | else: 457 | if EVEN_HEADDIM: 458 | q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0) 459 | else: 460 | q = tl.load( 461 | q_ptrs, 462 | mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), 463 | other=0.0, 464 | ) 465 | # recompute p = softmax(qk, dim=-1).T 466 | qk = tl.dot(q, tl.trans(k)) 467 | # Trying to combine the two masks seem to make the result wrong 468 | if not EVEN_N: # Need to mask out otherwise the softmax is wrong 469 | qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf")) 470 | if IS_CAUSAL: 471 | qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf")) 472 | if BIAS_TYPE != "none": 473 | tl.debug_barrier() # Race condition otherwise 474 | if BIAS_TYPE == "vector": 475 | if EVEN_N: 476 | bias = tl.load(b_ptrs).to(tl.float32) 477 | else: 478 | bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32) 479 | bias = bias[None, :] 480 | elif BIAS_TYPE == "matrix": 481 | if EVEN_M & EVEN_N: 482 | bias = tl.load(b_ptrs).to(tl.float32) 483 | else: 484 | bias = tl.load( 485 | b_ptrs, 486 | mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k), 487 | other=0.0, 488 | ).to(tl.float32) 489 | qk = qk * softmax_scale + bias 490 | # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong. 491 | # Also wrong for headdim=64. 492 | if not (EVEN_M & EVEN_HEADDIM): 493 | tl.debug_barrier() 494 | lse_i = tl.load(LSE + offs_m_curr) 495 | if BIAS_TYPE == "none": 496 | p = tl.exp(qk * softmax_scale - lse_i[:, None]) 497 | else: 498 | p = tl.exp(qk - lse_i[:, None]) 499 | # compute dv 500 | # [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call 501 | # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0), we get wrong outputs 502 | # in the case of headdim=48/96, seqlen_q & seqlen_k >= 512. If headdim=40 or seqlen < 512, 503 | # the output is correct. 504 | if EVEN_M & EVEN_HEADDIM: 505 | do = tl.load(do_ptrs) 506 | else: 507 | # [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask. 
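        # Descriptive note: gradient identities implemented in the remainder of this iteration,
        # with P = softmax(QK^T * scale):
        #   dV += P^T dO,  dP = dO V^T,  dS = P * (dP - delta) * scale,  dK += dS^T Q,  dQ += dS K,
        # where delta_i = sum_d O[i, d] * dO[i, d] is precomputed per row by _bwd_preprocess_do_o_dot.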
508 | do = tl.load( 509 | do_ptrs, 510 | mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), 511 | other=0.0, 512 | ) 513 | # if EVEN_M: 514 | # if EVEN_HEADDIM: 515 | # do = tl.load(do_ptrs) 516 | # else: 517 | # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0) 518 | # else: 519 | # if EVEN_HEADDIM: 520 | # do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0) 521 | # else: 522 | # do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) 523 | # & (offs_d[None, :] < headdim), other=0.0) 524 | dv += tl.dot(tl.trans(p.to(do.dtype)), do) 525 | # compute dp = dot(v, do) 526 | # There seems to be a race condition when headdim=48/96, and dq, dk are wrong. 527 | # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True 528 | # Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False 529 | if not (EVEN_M & EVEN_HEADDIM): 530 | tl.debug_barrier() 531 | dp = tl.dot(do, tl.trans(v)) 532 | # There's a race condition for headdim=48 533 | if not EVEN_HEADDIM: 534 | tl.debug_barrier() 535 | # compute ds = p * (dp - delta[:, None]) 536 | # Putting the subtraction after the dp matmul (instead of before) is slightly faster 537 | Di = tl.load(D + offs_m_curr) 538 | # Converting ds to q.dtype here reduces register pressure and makes it much faster 539 | # for BLOCK_HEADDIM=128 540 | ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype) 541 | # compute dk = dot(ds.T, q) 542 | dk += tl.dot(tl.trans(ds), q) 543 | # compute dq 544 | if not ( 545 | EVEN_M & EVEN_HEADDIM 546 | ): # Otherewise there's a race condition when BIAS_TYPE='matrix' 547 | tl.debug_barrier() 548 | if not ATOMIC_ADD: 549 | if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M 550 | dq = tl.load(dq_ptrs, eviction_policy="evict_last") 551 | dq += tl.dot(ds, k) 552 | tl.store(dq_ptrs, dq, eviction_policy="evict_last") 553 | else: 554 | if EVEN_HEADDIM: 555 | dq = tl.load( 556 | dq_ptrs, 557 | mask=offs_m_curr[:, None] < seqlen_q, 558 | other=0.0, 559 | eviction_policy="evict_last", 560 | ) 561 | dq += tl.dot(ds, k) 562 | tl.store( 563 | dq_ptrs, 564 | dq, 565 | mask=offs_m_curr[:, None] < seqlen_q, 566 | eviction_policy="evict_last", 567 | ) 568 | else: 569 | dq = tl.load( 570 | dq_ptrs, 571 | mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), 572 | other=0.0, 573 | eviction_policy="evict_last", 574 | ) 575 | dq += tl.dot(ds, k) 576 | tl.store( 577 | dq_ptrs, 578 | dq, 579 | mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), 580 | eviction_policy="evict_last", 581 | ) 582 | else: # If we're parallelizing across the seqlen_k dimension 583 | dq = tl.dot(ds, k) 584 | if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M 585 | tl.atomic_add(dq_ptrs, dq) 586 | else: 587 | if EVEN_HEADDIM: 588 | tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q) 589 | else: 590 | tl.atomic_add( 591 | dq_ptrs, 592 | dq, 593 | mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), 594 | ) 595 | # increment pointers 596 | dq_ptrs += BLOCK_M * stride_dqm 597 | q_ptrs += BLOCK_M * stride_qm 598 | do_ptrs += BLOCK_M * stride_dom 599 | if BIAS_TYPE == "matrix": 600 | b_ptrs += BLOCK_M * stride_bm 601 | # write-back 602 | dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) 603 | dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) 604 | _bwd_store_dx( 605 | dk_ptrs, 606 | dk, 607 | offs_n, 608 | offs_d, 609 | seqlen_k, 610 | headdim, 611 | EVEN_M=EVEN_M, 612 | EVEN_N=EVEN_N, 613 | 
even_headdim=EVEN_HEADDIM, 614 | ) 615 | _bwd_store_dx( 616 | dv_ptrs, 617 | dv, 618 | offs_n, 619 | offs_d, 620 | seqlen_k, 621 | headdim, 622 | EVEN_M=EVEN_M, 623 | EVEN_N=EVEN_N, 624 | even_headdim=EVEN_HEADDIM, 625 | ) 626 | 627 | 628 | def init_to_zero(name): 629 | return lambda nargs: nargs[name].zero_() 630 | 631 | # compiler bug with using autotune in triton v2. 632 | @triton.autotune( 633 | configs=[ 634 | triton.Config( 635 | {"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, 636 | num_warps=8, 637 | num_stages=1, 638 | pre_hook=init_to_zero("DQ"), 639 | ), 640 | ], 641 | key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "BIAS_TYPE", "IS_CAUSAL", "BLOCK_HEADDIM"], 642 | ) 643 | @triton.heuristics( 644 | { 645 | "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, 646 | "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, 647 | "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], 648 | } 649 | ) 650 | @triton.jit 651 | def _bwd_kernel( 652 | Q, 653 | K, 654 | V, 655 | Bias, 656 | DO, 657 | DQ, 658 | DK, 659 | DV, 660 | LSE, 661 | D, 662 | softmax_scale, 663 | stride_qb, 664 | stride_qh, 665 | stride_qm, 666 | stride_kb, 667 | stride_kh, 668 | stride_kn, 669 | stride_vb, 670 | stride_vh, 671 | stride_vn, 672 | stride_bb, 673 | stride_bh, 674 | stride_bm, 675 | stride_dob, 676 | stride_doh, 677 | stride_dom, 678 | stride_dqb, 679 | stride_dqh, 680 | stride_dqm, 681 | stride_dkb, 682 | stride_dkh, 683 | stride_dkn, 684 | stride_dvb, 685 | stride_dvh, 686 | stride_dvn, 687 | nheads, 688 | seqlen_q, 689 | seqlen_k, 690 | seqlen_q_rounded, 691 | headdim, 692 | CACHE_KEY_SEQLEN_Q, 693 | CACHE_KEY_SEQLEN_K, 694 | BIAS_TYPE: tl.constexpr, 695 | IS_CAUSAL: tl.constexpr, 696 | BLOCK_HEADDIM: tl.constexpr, 697 | SEQUENCE_PARALLEL: tl.constexpr, 698 | EVEN_M: tl.constexpr, 699 | EVEN_N: tl.constexpr, 700 | EVEN_HEADDIM: tl.constexpr, 701 | BLOCK_M: tl.constexpr, 702 | BLOCK_N: tl.constexpr, 703 | ): 704 | off_hb = tl.program_id(1) 705 | off_b = off_hb // nheads 706 | off_h = off_hb % nheads 707 | # offset pointers for batch/head 708 | Q += off_b * stride_qb + off_h * stride_qh 709 | K += off_b * stride_kb + off_h * stride_kh 710 | V += off_b * stride_vb + off_h * stride_vh 711 | DO += off_b * stride_dob + off_h * stride_doh 712 | DQ += off_b * stride_dqb + off_h * stride_dqh 713 | DK += off_b * stride_dkb + off_h * stride_dkh 714 | DV += off_b * stride_dvb + off_h * stride_dvh 715 | if BIAS_TYPE != "none": 716 | Bias += off_b * stride_bb + off_h * stride_bh 717 | # pointer to row-wise quantities in value-like data 718 | D += off_hb * seqlen_q_rounded 719 | LSE += off_hb * seqlen_q_rounded 720 | if not SEQUENCE_PARALLEL: 721 | num_block_n = tl.cdiv(seqlen_k, BLOCK_N) 722 | for start_n in range(0, num_block_n): 723 | _bwd_kernel_one_col_block( 724 | start_n, 725 | Q, 726 | K, 727 | V, 728 | Bias, 729 | DO, 730 | DQ, 731 | DK, 732 | DV, 733 | LSE, 734 | D, 735 | softmax_scale, 736 | stride_qm, 737 | stride_kn, 738 | stride_vn, 739 | stride_bm, 740 | stride_dom, 741 | stride_dqm, 742 | stride_dkn, 743 | stride_dvn, 744 | seqlen_q, 745 | seqlen_k, 746 | headdim, 747 | ATOMIC_ADD=False, 748 | BIAS_TYPE=BIAS_TYPE, 749 | IS_CAUSAL=IS_CAUSAL, 750 | BLOCK_HEADDIM=BLOCK_HEADDIM, 751 | EVEN_M=EVEN_M, 752 | EVEN_N=EVEN_N, 753 | EVEN_HEADDIM=EVEN_HEADDIM, 754 | BLOCK_M=BLOCK_M, 755 | BLOCK_N=BLOCK_N, 756 | ) 757 | else: 758 | start_n = tl.program_id(0) 759 | _bwd_kernel_one_col_block( 760 | start_n, 761 | Q, 762 | K, 763 | V, 764 | Bias, 765 | DO, 766 | DQ, 767 
| DK, 768 | DV, 769 | LSE, 770 | D, 771 | softmax_scale, 772 | stride_qm, 773 | stride_kn, 774 | stride_vn, 775 | stride_bm, 776 | stride_dom, 777 | stride_dqm, 778 | stride_dkn, 779 | stride_dvn, 780 | seqlen_q, 781 | seqlen_k, 782 | headdim, 783 | ATOMIC_ADD=True, 784 | BIAS_TYPE=BIAS_TYPE, 785 | IS_CAUSAL=IS_CAUSAL, 786 | BLOCK_HEADDIM=BLOCK_HEADDIM, 787 | EVEN_M=EVEN_M, 788 | EVEN_N=EVEN_N, 789 | EVEN_HEADDIM=EVEN_HEADDIM, 790 | BLOCK_M=BLOCK_M, 791 | BLOCK_N=BLOCK_N, 792 | ) 793 | 794 | 795 | def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None): 796 | # shape constraints 797 | batch, seqlen_q, nheads, d = q.shape 798 | _, seqlen_k, _, _ = k.shape 799 | assert k.shape == (batch, seqlen_k, nheads, d) 800 | assert v.shape == (batch, seqlen_k, nheads, d) 801 | assert d <= 128, "FlashAttention only support head dimensions up to 128" 802 | assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type" 803 | assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16" 804 | assert q.is_cuda and k.is_cuda and v.is_cuda 805 | softmax_scale = softmax_scale or 1.0 / math.sqrt(d) 806 | 807 | has_bias = bias is not None 808 | bias_type = "none" 809 | if has_bias: 810 | assert bias.dtype in [q.dtype, torch.float] 811 | assert bias.is_cuda 812 | assert bias.dim() == 4 813 | if bias.stride(-1) != 1: 814 | bias = bias.contiguous() 815 | if bias.shape[2:] == (1, seqlen_k): 816 | bias_type = "vector" 817 | elif bias.shape[2:] == (seqlen_q, seqlen_k): 818 | bias_type = "matrix" 819 | else: 820 | raise RuntimeError( 821 | "Last 2 dimensions of bias must be (1, seqlen_k)" " or (seqlen_q, seqlen_k)" 822 | ) 823 | bias = bias.expand(batch, nheads, seqlen_q, seqlen_k) 824 | bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0) 825 | 826 | seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 827 | lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) 828 | o = torch.empty_like(q) 829 | 830 | BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) 831 | BLOCK = 128 832 | num_warps = 4 if d <= 64 else 8 833 | grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) 834 | _fwd_kernel[grid]( 835 | q, 836 | k, 837 | v, 838 | bias, 839 | o, 840 | lse, 841 | softmax_scale, 842 | q.stride(0), 843 | q.stride(2), 844 | q.stride(1), 845 | k.stride(0), 846 | k.stride(2), 847 | k.stride(1), 848 | v.stride(0), 849 | v.stride(2), 850 | v.stride(1), 851 | *bias_strides, 852 | o.stride(0), 853 | o.stride(2), 854 | o.stride(1), 855 | nheads, 856 | seqlen_q, 857 | seqlen_k, 858 | seqlen_q_rounded, 859 | d, 860 | seqlen_q // 32, 861 | seqlen_k // 32, # key for triton cache (limit number of compilations) 862 | # Can't use kwargs here because triton autotune expects key to be args, not kwargs 863 | # IS_CAUSAL=causal, BLOCK_HEADDIM=d, 864 | bias_type, 865 | causal, 866 | BLOCK_HEADDIM, 867 | BLOCK_M=BLOCK, 868 | BLOCK_N=BLOCK, 869 | num_warps=num_warps, 870 | num_stages=1, 871 | ) 872 | return o, lse, softmax_scale # softmax_scale could have been updated 873 | 874 | 875 | def _flash_attn_backward( 876 | do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None 877 | ): 878 | # Make sure that the last dimension is contiguous 879 | if do.stride(-1) != 1: 880 | do = do.contiguous() 881 | batch, seqlen_q, nheads, d = q.shape 882 | _, seqlen_k, _, _ = k.shape 883 | # assert d in {16, 32, 64, 128} 884 | assert d <= 128 885 | seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 
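    # Descriptive note: the forward pass allocated lse with seqlen_q rounded up to a multiple of
    # 128 so that every BLOCK_M=128 program can store its statistics without a row mask; the same
    # rounded length is recomputed here so that delta and the saved lse are addressed consistently.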
886 | assert lse.shape == (batch, nheads, seqlen_q_rounded) 887 | assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1 888 | assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1 889 | softmax_scale = softmax_scale or 1.0 / math.sqrt(d) 890 | # dq_accum = torch.zeros_like(q, dtype=torch.float32) 891 | dq_accum = torch.empty_like(q, dtype=torch.float32) 892 | delta = torch.empty_like(lse) 893 | # delta = torch.zeros_like(lse) 894 | 895 | BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) 896 | grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) 897 | _bwd_preprocess_do_o_dot[grid]( 898 | o, 899 | do, 900 | delta, 901 | o.stride(0), 902 | o.stride(2), 903 | o.stride(1), 904 | do.stride(0), 905 | do.stride(2), 906 | do.stride(1), 907 | nheads, 908 | seqlen_q, 909 | seqlen_q_rounded, 910 | d, 911 | BLOCK_M=128, 912 | BLOCK_HEADDIM=BLOCK_HEADDIM, 913 | ) 914 | 915 | has_bias = bias is not None 916 | bias_type = "none" 917 | if has_bias: 918 | assert bias.dtype in [q.dtype, torch.float] 919 | assert bias.is_cuda 920 | assert bias.dim() == 4 921 | assert bias.stride(-1) == 1 922 | if bias.shape[2:] == (1, seqlen_k): 923 | bias_type = "vector" 924 | elif bias.shape[2:] == (seqlen_q, seqlen_k): 925 | bias_type = "matrix" 926 | else: 927 | raise RuntimeError( 928 | "Last 2 dimensions of bias must be (1, seqlen_k)" " or (seqlen_q, seqlen_k)" 929 | ) 930 | bias = bias.expand(batch, nheads, seqlen_q, seqlen_k) 931 | bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0) 932 | 933 | # BLOCK_M = 128 934 | # BLOCK_N = 128 935 | # num_warps = 8 936 | grid = lambda META: ( 937 | triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1, 938 | batch * nheads, 939 | ) 940 | _bwd_kernel[grid]( 941 | q, 942 | k, 943 | v, 944 | bias, 945 | do, 946 | dq_accum, 947 | dk, 948 | dv, 949 | lse, 950 | delta, 951 | softmax_scale, 952 | q.stride(0), 953 | q.stride(2), 954 | q.stride(1), 955 | k.stride(0), 956 | k.stride(2), 957 | k.stride(1), 958 | v.stride(0), 959 | v.stride(2), 960 | v.stride(1), 961 | *bias_strides, 962 | do.stride(0), 963 | do.stride(2), 964 | do.stride(1), 965 | dq_accum.stride(0), 966 | dq_accum.stride(2), 967 | dq_accum.stride(1), 968 | dk.stride(0), 969 | dk.stride(2), 970 | dk.stride(1), 971 | dv.stride(0), 972 | dv.stride(2), 973 | dv.stride(1), 974 | nheads, 975 | seqlen_q, 976 | seqlen_k, 977 | seqlen_q_rounded, 978 | d, 979 | seqlen_q // 32, 980 | seqlen_k // 32, # key for triton cache (limit number of compilations) 981 | # Can't use kwargs here because triton autotune expects key to be args, not kwargs 982 | # IS_CAUSAL=causal, BLOCK_HEADDIM=d, 983 | bias_type, 984 | causal, 985 | BLOCK_HEADDIM, 986 | # SEQUENCE_PARALLEL=False, 987 | # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, 988 | # num_warps=num_warps, 989 | # num_stages=1, 990 | ) 991 | dq.copy_(dq_accum) 992 | 993 | 994 | 995 | class FlashAttnFunc(torch.autograd.Function): 996 | @staticmethod 997 | def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None): 998 | """ 999 | q: (batch_size, seqlen_q, nheads, headdim) 1000 | k, v: (batch_size, seqlen_k, nheads, headdim) 1001 | bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k). 1002 | For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k). 
1003 | ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k) 1004 | """ 1005 | # Make sure that the last dimension is contiguous 1006 | q, k, v = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]] 1007 | o, lse, ctx.softmax_scale = _flash_attn_forward( 1008 | q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale 1009 | ) 1010 | ctx.save_for_backward(q, k, v, o, lse, bias) 1011 | ctx.causal = causal 1012 | return o, lse 1013 | 1014 | @staticmethod 1015 | def backward(ctx, do, dlse_use_needed=None): 1016 | q, k, v, o, lse, bias = ctx.saved_tensors 1017 | if len(ctx.needs_input_grad) > 3: 1018 | assert not ctx.needs_input_grad[3], "FlashAttention does not support bias gradient yet" 1019 | # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd 1020 | # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version. 1021 | with torch.inference_mode(): 1022 | dq = torch.empty_like(q) 1023 | dk = torch.empty_like(k) 1024 | dv = torch.empty_like(v) 1025 | _flash_attn_backward( 1026 | do, 1027 | q, 1028 | k, 1029 | v, 1030 | o, 1031 | lse, 1032 | dq, 1033 | dk, 1034 | dv, 1035 | bias=bias, 1036 | causal=ctx.causal, 1037 | softmax_scale=ctx.softmax_scale, 1038 | ) 1039 | return dq, dk, dv, None, None, None 1040 | 1041 | 1042 | flash_attn_func = FlashAttnFunc.apply 1043 | 1044 | --------------------------------------------------------------------------------
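For reference, here is a minimal usage-and-sanity-check sketch of the flash_attn_func wrapper defined at the end of src/flash_attn_triton.py. The import path, tensor sizes, seed, and tolerance expectations below are illustrative assumptions rather than part of the repository; arguments are passed positionally because flash_attn_func is FlashAttnFunc.apply itself, which expects positional arguments.

import math
import torch

# Assumed import path, following the repo layout and the sys.path handling in unit_tests/.
from src.flash_attn_triton import flash_attn_func

torch.manual_seed(0)
batch, seqlen, nheads, headdim = 2, 1024, 8, 64   # headdim must be <= 128; fp16/bf16 CUDA tensors only
q, k, v = [
    torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16, requires_grad=True)
    for _ in range(3)
]

# Positional arguments: (q, k, v, bias, causal, softmax_scale).
# Returns the attention output and the per-row log-sum-exp,
# padded along the query axis to a multiple of 128.
out, lse = flash_attn_func(q, k, v, None, True, None)

# Plain-PyTorch reference in fp32 for a rough numerical check.
scale = 1.0 / math.sqrt(headdim)
scores = torch.einsum("bqhd,bkhd->bhqk", q.float(), k.float()) * scale
causal_mask = torch.triu(torch.ones(seqlen, seqlen, device="cuda"), diagonal=1).bool()
scores = scores.masked_fill(causal_mask, float("-inf"))
ref_out = torch.einsum("bhqk,bkhd->bqhd", scores.softmax(dim=-1), v.float())
ref_lse = torch.logsumexp(scores, dim=-1)          # [batch, nheads, seqlen]

print((out.float() - ref_out).abs().max())         # expected to be small (fp16-level error)
print((lse[..., :seqlen] - ref_lse).abs().max())

# Gradients flow through FlashAttnFunc.backward; bias gradients are not supported.
out.sum().backward()

The optional bias argument, when supplied, must be a 4-d CUDA tensor whose last two dimensions are either (1, seqlen_k) ("vector" bias, e.g. an ALiBi-style mask) or (seqlen_q, seqlen_k) ("matrix" bias); _flash_attn_forward broadcasts it to (batch, nheads, seqlen_q, seqlen_k) before launching the kernel.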