├── speedup_plots
│   ├── fig_hyper_attn_causal_masking.pdf
│   ├── fig_hyper_attn_causal_masking.png
│   ├── fig_hyper_attn_no_causal_masking.pdf
│   └── fig_hyper_attn_no_causal_masking.png
├── replace_llm_attention.py
├── src
│   ├── attn_utils.py
│   ├── angular_lsh_triton.py
│   ├── hyper_attn_triton.py
│   └── flash_attn_triton.py
├── unit_tests
│   ├── test_lsh.py
│   └── test_hyper_attention.py
├── chatglm_fast_attention.py
├── README.md
├── benchmark_single_attention.py
├── benchmark_patch_llm.py
├── hyper_attention.py
└── LICENSE
/speedup_plots/fig_hyper_attn_causal_masking.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amirzandieh/HyperAttention/HEAD/speedup_plots/fig_hyper_attn_causal_masking.pdf
--------------------------------------------------------------------------------
/speedup_plots/fig_hyper_attn_causal_masking.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amirzandieh/HyperAttention/HEAD/speedup_plots/fig_hyper_attn_causal_masking.png
--------------------------------------------------------------------------------
/speedup_plots/fig_hyper_attn_no_causal_masking.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amirzandieh/HyperAttention/HEAD/speedup_plots/fig_hyper_attn_no_causal_masking.pdf
--------------------------------------------------------------------------------
/speedup_plots/fig_hyper_attn_no_causal_masking.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amirzandieh/HyperAttention/HEAD/speedup_plots/fig_hyper_attn_no_causal_masking.png
--------------------------------------------------------------------------------
/replace_llm_attention.py:
--------------------------------------------------------------------------------
1 | NUM_TOTAL_LAYERS = {
2 | 'chatglm2-6b-32k': 28,
3 | }
4 |
5 |
6 | def patch_attention_layers(model, model_name, patch_config, num_patch_layers, **kwargs):
7 | num_total_layers = NUM_TOTAL_LAYERS[model_name]
8 | num_patch_layers = num_total_layers if num_patch_layers < 0 else num_patch_layers
9 |
10 | if patch_config == 'last':
11 | patch_layer_indices = range(num_total_layers - 1, num_total_layers - num_patch_layers - 1, -1)
12 |
13 | elif patch_config == 'first':
14 | patch_layer_indices = range(num_patch_layers)
15 |
16 | elif patch_config == 'odd':
17 | patch_layer_indices = range(1, num_total_layers, 2)
18 |
19 | elif patch_config == 'even':
20 | patch_layer_indices = range(0, num_total_layers, 2)
21 |
22 | elif patch_config == 'odd_first':
23 | patch_layer_indices = range(1, 2 * num_patch_layers, 2)
24 |
25 | elif patch_config == 'odd_last':
26 |         patch_layer_indices = range(1, num_total_layers, 2)[-num_patch_layers:]  # last num_patch_layers odd-indexed layers
27 |
28 | elif patch_config == 'even_first':
29 | patch_layer_indices = range(0, num_total_layers, 2)[:num_patch_layers]
30 |
31 | elif patch_config == 'even_last':
32 |         patch_layer_indices = range(0, num_total_layers, 2)[-num_patch_layers:]  # last num_patch_layers even-indexed layers
33 |
34 | else:
35 | raise NotImplementedError(f"Invalid patch_config option: {patch_config}")
36 |
37 | if model_name == 'chatglm2-6b-32k':
38 | from chatglm_fast_attention import FastCoreAttention
39 |
40 | print(
41 | f"patch_config: {patch_config}, attn_method: {kwargs['attn_method']}, num_patch_layers: {num_patch_layers}, patch_indices: {list(patch_layer_indices)}")
42 | for i in patch_layer_indices:
43 | model.transformer.encoder.layers[i].self_attention.core_attention = FastCoreAttention(model.config, i,
44 | **kwargs)
45 |
--------------------------------------------------------------------------------
/src/attn_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 |
5 |
6 | def add_self_attentions(attn1, lse1, attn2, lse2):
7 | """
8 | inputs:
9 | - attn1, attn2: 4d-tensors with shape [b, h, n, d]
10 | - lse1, lse2: 4d-tensors of log-sum-exp with shape [b, h, n, 1]
11 | output:
12 | - attn
13 | = (attn1 * exp(lse1) + attn2 * exp(lse2)) / (exp(lse1) + exp(lse2))
14 | = (attn1 + attn2 * exp(lse2 - lse1)) / (1 + exp(lse2-lse1))
15 | = attn1 * c + attn2 * (1-c), where c=1/(1 + exp(lse2-lse1)),
16 | - lse
17 | = log(exp(lse1) + exp(lse2))
18 | = log(exp(lse1) * (1 + exp(lse2 - lse1)))
19 | = lse1 + log(1 + exp(lse2 - lse1)) = lse1 - log(c)
20 | """
21 | c = (1 / (1 + (lse2 - lse1).exp())).to(dtype=attn1.dtype)
22 | attn = c * attn1 + (1-c) * attn2
23 | lse = lse1 - (c + torch.finfo(lse1.dtype).eps).log()
24 | return attn, lse
25 |
26 |
27 | def indexing(x, indices, chunk_size=-1):
28 | """
29 | inputs:
30 | - x: 4d-tensor with shape [b, h, n, d]
31 | - indices: 3d-tensor with shape [b, h, s] where each entry should be in [0, n-1]
32 | output:
33 | - out: 4d-tensor with shape [b, h, s, d] where out[i,j] = x[i,j][indices[i,j],:]
34 |
35 | A naive implementation:
36 | out = torch.zeros(b, h, s, d)
37 | for i in range(b):
38 | for j in range(h):
39 | out[i,j] = x[i,j][idx[i,j],:]
40 | return out
41 | """
42 | if chunk_size < 0 or (chunk_size > 0 and x.shape[-2] % chunk_size == 0):
43 | return x.gather(2, indices.unsqueeze(-1).expand(-1, -1, -1, x.shape[-1]))
44 | else:
45 | x = x.gather(2, indices.unsqueeze(-1).expand(-1, -1, -1, x.shape[-1]))
46 | new_n = math.ceil(x.shape[2] / chunk_size) * chunk_size
47 |         # new_n >= x.shape[2] by construction; pad the gathered rows with zeros up to a multiple of chunk_size
48 |         if new_n == x.shape[2]:
49 |             return x
50 |         return torch.nn.functional.pad(x, (0, 0, 0, new_n - x.shape[2]), mode='constant', value=0.)
51 |
--------------------------------------------------------------------------------
/unit_tests/test_lsh.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import time
3 | import torch
4 | import triton
5 |
6 | import os
7 | import sys
8 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))  # repository root
9 | from src.angular_lsh_triton import AngularLSHTriton
10 |
11 | cnt = 0
12 |
13 | def check_memory():
14 | global cnt
15 | mem_alloc = torch.cuda.memory_allocated() / 1024 / 1024 / 1024
16 | mem_reserve = torch.cuda.memory_reserved() / 1024 / 1024 / 1024
17 | mem_peak = torch.cuda.memory_stats()['active_bytes.all.peak'] / 1024 / 1024 / 1024
18 | print(f"[{cnt}] mem_alloc: {mem_alloc:.4f}, mem_reserve: {mem_reserve:.4f}, mem_peak: {mem_peak:.4f}")
19 | cnt += 1
20 |
21 |
22 | class MyTestCase(unittest.TestCase):
23 | def test_1_validation(self):
24 | print("1. this is validation test")
25 | dtype = torch.float16
26 | block_size, dim, batch_size, head_size, seq_len = 256, 128, 4, 32, 2048
27 | num_projs = 8
28 |
29 | query = torch.randn((batch_size, head_size, seq_len, dim), device='cuda', dtype=dtype)
30 |
31 | self.lsh = AngularLSHTriton(num_projs=num_projs, dim=(1, 1, dim)).to(device='cuda', dtype=dtype)
32 |
33 | # apply lsh in pytorch
34 | t0 = time.time()
35 | query_hash_buckets = self.lsh.hash_torch(query)
36 | t1 = time.time()
37 | print('the runtime of torch lsh:', t1 - t0)
38 |
39 | # apply lsh in triton
40 | check_memory()
41 | t2 = time.time()
42 | query_hash_buckets_triton = self.lsh.hash_triton(query)
43 | t3 = time.time()
44 | check_memory()
45 | print('the runtime of triton lsh:', t3 - t2)
46 |
47 | print('difference between torch and triton hashes: ', (query_hash_buckets.float() - query_hash_buckets_triton.float()).norm())
48 |
49 | def test_2_runtime(self):
50 | print()
51 | print("2. this is runtime test")
52 |
53 | block_size, dim, batch_size, head_size = 256, 128, 4, 32
54 | num_projs = 8
55 | seq_len = 2048
56 | dtype = torch.float16
57 |
58 | query = torch.randn((batch_size, head_size, seq_len, dim), device='cuda', dtype=dtype)
59 |
60 | self.lsh = AngularLSHTriton(num_projs=num_projs, dim=(1, 1, dim)).to(device='cuda', dtype=dtype)
61 |
62 | def test_fn1():
63 | query_hash_buckets = self.lsh.hash_torch(query)
64 |
65 | warmup = 20
66 | rep = 1000
67 |
68 | tim_py_q20, tim_py_q50, tim_py_q80 = triton.testing.do_bench(test_fn1, warmup=warmup, rep=rep,
69 | quantiles=[0.2, 0.5, 0.8])
70 | print(f"pytorch runtime: {tim_py_q50:.5f} ms ({tim_py_q20:.5f}, {tim_py_q80:.5f})")
71 |
72 | def test_fn2():
73 | query_hash_buckets_triton = self.lsh.hash_triton(query)
74 |
75 | tim_tr_q20, tim_tr_q50, tim_tr_q80 = triton.testing.do_bench(test_fn2, warmup=warmup, rep=rep,
76 | quantiles=[0.2, 0.5, 0.8])
77 | print(f"triton runtime: {tim_tr_q50:.5f} ms ({tim_tr_q20:.5f}, {tim_tr_q80:.5f})")
78 |
79 |
80 | if __name__ == '__main__':
81 | unittest.main()
82 |
--------------------------------------------------------------------------------
/chatglm_fast_attention.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 |
4 | from hyper_attention import HyperAttention
5 |
6 |
7 | # Edited from https://huggingface.co/THUDM/chatglm2-6b-32k/blob/main/modeling_chatglm.py#L194
8 | class FastCoreAttention(torch.nn.Module):
9 |
10 | def __init__(self, config, layer_number, **kwargs):
11 | super(FastCoreAttention, self).__init__()
12 |
13 | self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
14 | self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
15 | if self.apply_query_key_layer_scaling:
16 | self.attention_softmax_in_fp32 = True
17 | self.layer_number = max(1, layer_number)
18 |
19 | projection_size = config.kv_channels * config.num_attention_heads
20 |
21 | # Per attention head and per partition values.
22 | self.hidden_size_per_partition = projection_size
23 | self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
24 | self.num_attention_heads_per_partition = config.num_attention_heads
25 |
26 | coeff = None
27 | self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
28 | if self.apply_query_key_layer_scaling:
29 | coeff = self.layer_number
30 | self.norm_factor *= coeff
31 | self.coeff = coeff
32 |
33 | self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
34 |
35 | self.attn_method = kwargs.get('attn_method')
36 | if self.attn_method == 'hyper':
37 | lsh_num_projs = kwargs.get('lsh_num_projs')
38 | block_size = kwargs.get('block_size')
39 | sample_size = kwargs.get('sample_size')
40 | min_seq_len = kwargs.get('min_seq_len')
41 | smooth_block = kwargs.get('smooth_block')
42 | self.attn = HyperAttention(
43 | input_dim=128,
44 | lsh_num_projs=lsh_num_projs,
45 | block_size=block_size,
46 | sample_size=sample_size,
47 | min_seq_len=min_seq_len,
48 | smooth_block=smooth_block,
49 | )
50 | else:
51 | raise NotImplementedError("Invalid attn_method option")
52 |
53 | def forward(self, query_layer, key_layer, value_layer, attention_mask):
54 |
55 | query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
56 | if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
57 |             softmax_scale = query_layer.shape[-1] ** (-0.5)
58 |             context_layer = self.attn(query_layer, key_layer, value_layer, scale=softmax_scale, causal=True)
59 |
60 | else:
61 |             assert False, 'unreachable: in this branch the query and key lengths may differ, and it is not a computational bottleneck, so it is not patched'
62 | if attention_mask is not None:
63 | attention_mask = ~attention_mask
64 | context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
65 | attention_mask)
66 |
67 | context_layer = context_layer.permute(2, 0, 1, 3)
68 | new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
69 | context_layer = context_layer.reshape(*new_context_layer_shape)
70 | return context_layer
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # HyperAttention: Long-context Attention in Near-Linear Time
2 |
3 | A Triton implementation of the HyperAttention algorithm.
4 |
5 | # Requirements
6 |
7 | The code requires ``pytorch`` and [``triton``](https://github.com/openai/triton).
8 | PyTorch version 2.0.1 has been tested, but any version >= 2.0.0 should work.
9 | The code also makes use of the [triton](https://github.com/openai/triton) implementation of [FlashAttention](https://github.com/Dao-AILab/flash-attention/tree/main); the FlashAttention kernel has been adapted to compile with triton version **2.1.0**.
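A quick way to confirm the installed versions (a minimal check; only the versions mentioned above have been tested):

```python
import torch
import triton

# tested with torch 2.0.1 and triton 2.1.0; nearby versions may also work
print(f"torch {torch.__version__}, triton {triton.__version__}")
```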
10 |
11 | # How to use
12 |
13 | The implementation of HyperAttention can be found in ``hyper_attention.py``. An example of usage:
14 |
15 | ```python
16 | from hyper_attention import HyperAttention
17 |
18 | attn = HyperAttention(
19 | input_dim=64,
20 | lsh_num_projs=8,
21 | block_size=256,
22 | sample_size=256,
23 | min_seq_len=2048,
24 | smooth_block=False,)
25 |
26 | attn_output = attn(query, key, value, causal=True)
27 | ```
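Here `query`, `key`, and `value` are assumed to be CUDA tensors of shape `[batch_size, num_heads, seq_len, dim]` in `float16` or `bfloat16`, the layout and dtypes used in `hyper_attention.py` and the unit tests, and the module should be moved to the same device and dtype. A minimal, illustrative setup (the sizes below are arbitrary):

```python
import torch

batch_size, num_heads, seq_len, dim = 1, 32, 4096, 64
shape = (batch_size, num_heads, seq_len, dim)
query, key, value = (torch.randn(shape, device='cuda', dtype=torch.bfloat16) for _ in range(3))

attn = attn.to(device='cuda', dtype=query.dtype)  # as done in benchmark_single_attention.py
```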
28 |
29 | The module has the following parameters:
30 | - ```input_dim```: the dimension of input query and key. (Required)
31 | - ```lsh_num_projs```: the number of random projection vectors used in the locality-sensitive hashing scheme. The default is 8.
32 | - ```block_size```: the size of blocks for the block-diagonal approximation. It must be divisible by 128. The default is 256.
33 | - ```sample_size```: the number of sampled columns in the attention matrix $A$. It must be divisible by 128. The default is 256.
34 | - ```min_seq_len```: the minimum sequence length at which HyperAttention is applied. When the sequence length is smaller than this value, attention is computed exactly with FlashAttention, because the overheads of HyperAttention may dominate the runtime for short sequences. The default is ```2048```.
35 | - ```smooth_block```: smooth the block-diagonal approximation by letting the blocks overlap, so that it resembles a banded-diagonal approximation. The default is False.
36 | - The sequence lengths of both ```query``` and ```key``` must be divisible by ```block_size``` (see the padding sketch below).
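If your sequence length is not already a multiple of ```block_size```, one option is to zero-pad the inputs before the call and trim the output afterwards, mirroring what `hyper_attention.py` does internally for odd lengths (a rough sketch; note that padded zero key/value rows slightly perturb the softmax):

```python
import torch.nn.functional as F

def pad_to_block(q, k, v, block_size=256):
    # q, k, v: [batch, heads, seq_len, dim]; append zero rows along the sequence axis
    n = q.shape[2]
    pad = (-n) % block_size
    if pad:
        q, k, v = (F.pad(t, (0, 0, 0, pad)) for t in (q, k, v))
    return q, k, v, n  # keep n so the output can be trimmed back: attn_output[:, :, :n, :]
```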
37 |
38 | # Speedup on single attention layer
39 |
40 | In this section, we showcase the speedup achieved by HyperAttention in comparison to the Triton implementation of FlashAttention (v2) across a range of sequence lengths. The configuration uses 32 heads with a head dimension of 64, and the results are obtained by running both methods on NVIDIA A10 Tensor Core GPUs.
41 |
42 | ## Causal masking (decoder-style attention)
43 |
44 | The speedup factors for both the forward pass and the combined forward+backward pass of the attention decoder with causal masking are plotted below. HyperAttention achieves over a ```22x``` speedup for the forward pass and over a ```16x``` speedup for the combined forward+backward pass when the sequence length is ```131k```.
45 |
46 | ![Speedup of HyperAttention over FlashAttention, with causal masking](./speedup_plots/fig_hyper_attn_causal_masking.png)
47 |
50 | ## No causal masking (encoder-style attention)
51 |
52 | The speedup factors for both the forward pass and the combined forward+backward pass of the attention encoder, without causal masking, are shown below. In the absence of causal masking, HyperAttention reduces to a notably simpler and more efficient algorithm that avoids recursive partitioning of the attention matrix. As a result, it achieves remarkable speedups, surpassing ```270x``` for both the forward pass and the combined forward+backward pass when the sequence length is ```131k```.
53 |
54 | ![Speedup of HyperAttention over FlashAttention, without causal masking](./speedup_plots/fig_hyper_attn_no_causal_masking.png)
55 |
--------------------------------------------------------------------------------
/benchmark_single_attention.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import triton
4 |
5 | from src.flash_attn_triton import flash_attn_func
6 | from hyper_attention import HyperAttention
7 |
8 | try:
9 | from flash_attn import flash_attn_func as flash_attn_func_cuda
10 | except ImportError:
11 | flash_attn_func_cuda = None
12 |
13 |
14 | def get_arguments():
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument("--no_causal", action="store_true")
17 | parser.add_argument("--smooth_block", action="store_true")
18 | parser.add_argument("--mode", type=str, default="fwd+bwd", choices=['fwd', 'bwd', 'fwd+bwd'])
19 | parser.add_argument("--attn_method", type=str, default="flash",
20 | choices=['flash', 'flash-cuda', 'hyper'])
21 | return parser.parse_args()
22 |
23 |
24 | def get_tensors(batch_size, seq_len, head_size, dim):
25 | q = torch.randn((batch_size, seq_len, head_size, dim), dtype=torch.bfloat16, device="cuda", requires_grad=True)
26 | k = torch.randn((batch_size, seq_len, head_size, dim), dtype=torch.bfloat16, device="cuda", requires_grad=True)
27 | v = torch.randn((batch_size, seq_len, head_size, dim), dtype=torch.bfloat16, device="cuda", requires_grad=True)
28 | return q, k, v
29 |
30 |
31 | def run_flash_attn(batch_size, head_size, seq_len, dim, causal, mode, impl="triton", warmup=20, rep=100):
32 | q, k, v = get_tensors(batch_size, seq_len, head_size, dim)
33 | if impl == "cuda":
34 | if flash_attn_func_cuda is None:
35 | raise ImportError("Please install flash_attn (pip install flash-attn --no-build-isolation)")
36 | fn = lambda: flash_attn_func_cuda(q, k, v, causal=causal)
37 | else:
38 | fn = lambda: flash_attn_func(q, k, v, None, causal, None)[0]
39 | if mode == 'fwd':
40 | return triton.testing.do_bench(fn, warmup=warmup, rep=rep, quantiles=[0.2, 0.5, 0.8])
41 | elif mode == 'bwd':
42 | o = fn()
43 | do = torch.randn_like(o)
44 | fn = lambda: o.backward(do, retain_graph=True)
45 | return triton.testing.do_bench(fn, warmup=warmup, rep=rep, quantiles=[0.2, 0.5, 0.8])
46 | else: # mode == 'fwd+bwd'
47 | q20_fwd, median_fwd, q80_fwd = triton.testing.do_bench(fn, warmup=warmup, rep=rep, quantiles=[0.2, 0.5, 0.8])
48 | o = fn()
49 | do = torch.randn_like(o)
50 | fn = lambda: o.backward(do, retain_graph=True)
51 | q20_bwd, median_bwd, q80_bwd = triton.testing.do_bench(fn, warmup=warmup, rep=rep, quantiles=[0.2, 0.5, 0.8])
52 | return q20_fwd + q20_bwd, median_fwd + median_bwd, q80_fwd + q80_bwd
53 |
54 |
55 | def run_hyper_attn(batch_size, head_size, seq_len, dim, causal, mode, smooth_block, warmup=20, rep=100):
56 |     q, k, v = get_tensors(batch_size, head_size, seq_len, dim)  # note: this yields shape [batch, heads, seq_len, dim], the layout HyperAttention expects
57 | block_size = 256
58 | sample_size = 256
59 |
60 | attn = HyperAttention(
61 | input_dim=dim,
62 | block_size=block_size,
63 | sample_size=sample_size,
64 | smooth_block=smooth_block,).to(device='cuda', dtype=q.dtype)
65 |
66 | fn = lambda: attn(q, k, v, causal=causal)
67 |
68 | if mode == 'fwd':
69 | return triton.testing.do_bench(fn, warmup=warmup, rep=rep, quantiles=[0.2, 0.5, 0.8])
70 | elif mode == 'bwd':
71 | o = fn()
72 | do = torch.randn_like(o)
73 | fn = lambda: o.backward(do, retain_graph=True)
74 | return triton.testing.do_bench(fn, warmup=warmup, rep=rep, quantiles=[0.2, 0.5, 0.8])
75 | else: # mode == 'fwd+bwd'
76 | q20_fwd, median_fwd, q80_fwd = triton.testing.do_bench(fn, warmup=warmup, rep=rep, quantiles=[0.2, 0.5, 0.8])
77 | o = fn()
78 | do = torch.randn_like(o)
79 | fn = lambda: o.backward(do, retain_graph=True)
80 | q20_bwd, median_bwd, q80_bwd = triton.testing.do_bench(fn, warmup=warmup, rep=rep, quantiles=[0.2, 0.5, 0.8])
81 | return q20_fwd + q20_bwd, median_fwd + median_bwd, q80_fwd + q80_bwd
82 |
83 |
84 | def main():
85 | args = get_arguments()
86 | for arg_name, arg_var in args.__dict__.items():
87 | print(f"{arg_name:<16} : {arg_var}")
88 |
89 | seq_lens = [2 ** i for i in range(10, 18)]
90 |
91 | attn_method = args.attn_method # ['flash', 'hyper']
92 | mode = args.mode # ['fwd', 'bwd', 'fwd+bwd']
93 | batch_size, head_size, dim = 1, 24, 64
94 | print(f"mode: {mode}, attn_method: {attn_method}, batch_size: {batch_size}, head_size: {head_size}, dim: {dim}")
95 |
96 | causal = not args.no_causal
97 |
98 | for seq_len in seq_lens:
99 | if attn_method == 'flash':
100 | ms = run_flash_attn(batch_size, head_size, seq_len, dim, causal, mode=args.mode)
101 | elif attn_method == 'flash-cuda':
102 | ms = run_flash_attn(batch_size, head_size, seq_len, dim, causal, mode=args.mode, impl="cuda")
103 | elif attn_method == 'hyper':
104 | ms = run_hyper_attn(batch_size, head_size, seq_len, dim, causal, mode=args.mode,
105 | smooth_block=args.smooth_block)
106 | else:
107 | raise NotImplementedError
108 |
109 |         print(
110 |             f"[{mode:<8}], {attn_method}, seq_len: {seq_len:<8}, causal: {causal}, ms: {ms[1]:.5f} ({ms[0]:.5f}, {ms[2]:.5f})")
111 |
112 |
113 | if __name__ == "__main__":
114 | main()
115 |
116 |
--------------------------------------------------------------------------------
/benchmark_patch_llm.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from tqdm import tqdm
3 | import numpy as np
4 | import torch
5 | from torch.nn import CrossEntropyLoss
6 | from datasets import load_dataset, concatenate_datasets
7 | from transformers import AutoModelForCausalLM, AutoTokenizer
8 |
9 | from replace_llm_attention import patch_attention_layers
10 |
11 |
12 | def get_model_and_tokenizer(model_name):
13 | if model_name == "chatglm2-6b-32k":
14 | tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b-32k", trust_remote_code=True)
15 | model = AutoModelForCausalLM.from_pretrained("THUDM/chatglm2-6b-32k", trust_remote_code=True)
16 | else:
17 | raise NotImplementedError("Currently we only support chatglm2")
18 |
19 | return model, tokenizer
20 |
21 |
22 | def get_arguments():
23 | parser = argparse.ArgumentParser()
24 | parser.add_argument("--seq_len", type=int, default=32768)
25 | # patch config
26 |     parser.add_argument("--patch_config", type=str, default="last", choices=['last', 'first', 'even', 'odd', 'odd_first', 'odd_last', 'even_first', 'even_last'])
27 | parser.add_argument("--attn_method", type=str, default="hyper", choices=['flash', 'hyper', 'hyper-cuda'])
28 | parser.add_argument("--num_patch_layers", type=int, default=-1)
29 | # params of HyperAttention
30 | parser.add_argument("--block_size", type=int, default=256)
31 | parser.add_argument("--sample_size", type=int, default=256)
32 | parser.add_argument("--lsh_num_projs", type=int, default=8)
33 | parser.add_argument("--min_seq_len", type=int, default=2048)
34 | parser.add_argument("--smooth_block", action="store_true")
35 | # currently only supports **chatglm2-6b-32k**
36 | parser.add_argument("--model_name", type=str, default="chatglm2-6b-32k")
37 | return parser.parse_args()
38 |
39 |
40 | @torch.no_grad()
41 | def main():
42 | args = get_arguments()
43 | for arg_name, arg_var in args.__dict__.items():
44 | print(f"{arg_name:<16} : {arg_var}")
45 |
46 | model, tokenizer = get_model_and_tokenizer(args.model_name)
47 | tokenizer.model_max_length = args.seq_len
48 | device = "cuda"
49 | dtype = torch.bfloat16
50 |
51 | # Load LongBench datasets
52 | dataset = 'longbench'
53 | dataset_names = ["narrativeqa", "qasper", "multifieldqa_en", "multifieldqa_zh", "hotpotqa", "2wikimqa", "musique", \
54 | "dureader", "gov_report", "qmsum", "multi_news", "vcsum", "trec", "triviaqa", "samsum", "lsht", \
55 | "passage_count", "passage_retrieval_en", "passage_retrieval_zh", "lcc", "repobench-p"]
56 |
57 | data_subset_all = []
58 | for dataset in dataset_names:
59 | data_ = load_dataset('THUDM/LongBench', f"{dataset}", split='test')
60 | data_subset = data_.filter(lambda x: len(tokenizer.encode(x['context'])) >= args.seq_len)
61 | if len(data_subset) > 0:
62 | data_subset_all.append(data_subset)
63 | data = concatenate_datasets(data_subset_all)
64 |
65 | encoded_texts = []
66 | pbar = tqdm(data)
67 | for i, data_i in enumerate(pbar):
68 | encoded_text = tokenizer.encode(data_i['context'], return_tensors='pt', truncation=True)
69 | pbar.set_description(f"seq_len: {len(encoded_text[0])}, n_data: {len(encoded_texts)}")
70 | if len(encoded_text[0]) < args.seq_len:
71 | continue
72 | encoded_texts.append(encoded_text)
73 | print(f"# of data longer than {args.seq_len}: {len(encoded_texts)}")
74 |
75 | if args.attn_method != 'flash':
76 | patch_attention_layers(model=model, **args.__dict__)
77 |
78 | model.to(device=device, dtype=dtype)
79 | model.eval()
80 | loss_fct = CrossEntropyLoss(reduction="none")
81 |
82 | ppls = []
83 |
84 | pbar = tqdm(range(len(encoded_texts)))
85 | for bid in pbar:
86 | encoded_batch = encoded_texts[bid:bid + 1]
87 | if type(encoded_batch) == dict:
88 | attn_mask = encoded_batch['attention_mask'] if 'attention_mask' in encoded_batch.keys() else None
89 | encoded_batch = encoded_batch['input_ids']
90 | elif type(encoded_batch) == list:
91 | encoded_batch = encoded_batch[0]
92 |
93 | encoded_batch = encoded_batch.to(device)
94 | attn_mask = torch.ones_like(encoded_batch)
95 |
96 | out_logits = model(encoded_batch).logits
97 |
98 | labels = encoded_batch
99 |
100 | shift_logits = out_logits[..., :-1, :].contiguous()
101 | shift_labels = labels[..., 1:].contiguous()
102 | shift_attention_mask_batch = attn_mask[..., 1:].contiguous()
103 |
104 | loss_ = loss_fct(shift_logits.transpose(1, 2), shift_labels).float()
105 | perplexity_batch = torch.exp2(
106 | (loss_ * shift_attention_mask_batch).sum(1)
107 | / shift_attention_mask_batch.sum(1)
108 | )
109 | ppls += perplexity_batch.tolist()
110 |
111 | pbar.set_description(
112 | f"[{bid:<4}/{len(encoded_texts)}] avg_ppls: {np.mean(np.array(ppls)[~np.isnan(np.array(ppls))]):.4f}")
113 |
114 | del out_logits, encoded_batch, attn_mask, shift_logits, shift_labels, shift_attention_mask_batch, perplexity_batch
115 |
116 | nan_cnt = sum(np.isnan(np.array(ppls)))
117 | ppl_mean = np.mean(np.array(ppls)[~np.isnan(np.array(ppls))])
118 |
119 | print(f"ppl: {ppl_mean}, nan_cnt: {nan_cnt}")
120 | res_str = f"model: {args.model_name}, dtype: {dtype}, seq_len: {args.seq_len}, num_patch_layers: {args.num_patch_layers}, n_data: {len(encoded_texts)}, ppl: {ppl_mean}, nan_cnt: {nan_cnt}\n"
121 | print(res_str)
122 |
123 |
124 | if __name__ == "__main__":
125 | main()
126 |
--------------------------------------------------------------------------------
/hyper_attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from src.attn_utils import add_self_attentions
4 | from src.flash_attn_triton import flash_attn_func
5 | from src.hyper_attn_triton import hyper_attn_func
6 | from src.angular_lsh_triton import AngularLSHTriton
7 |
8 |
9 | class HyperAttention(torch.nn.Module):
10 |
11 | def __init__(self, input_dim=64, lsh_num_projs=8, block_size=256, sample_size=256, min_seq_len=2048,
12 | smooth_block=False, **kwargs):
13 | """
14 | - block_size and sample_size must be divisible by 128
15 | """
16 | super().__init__()
17 | self.input_dim = input_dim
18 | self.lsh_num_projs = lsh_num_projs
19 | self.block_size = block_size
20 | self.sample_size = sample_size
21 | self.min_seq_len = min_seq_len
22 | self.smooth_block = smooth_block
23 | self.lsh = AngularLSHTriton(num_projs=self.lsh_num_projs, dim=(1, 1, input_dim))
24 |
25 |     def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, scale=None, causal=False,
26 | return_lse=False):
27 | """
28 | Forward function for HyperAttention. If no causal masking, simply invokes forward_no_causal_mask method.
29 | If there is causal masking, it partitions the attention matrix and recurses on the partitions.
30 | inputs:
31 |             - query, key, and value: must have the same sequence length, but the dimension of the value vectors can
32 |             differ from that of the query or key
33 | - sequence lengths must be divisible by block_size
34 | output:
35 | - attn: (approximation of) the final attention output tensor
36 | - lse: (approximation of) log sum exp of the qk matrix
37 | """
38 | query = query.contiguous()
39 | key = key.contiguous()
40 | value = value.contiguous()
41 |
42 | n_query = query.shape[2]
43 | batch_size, n_heads, n_key, dim = key.shape
44 | scale = scale or dim ** (-0.5)
45 | assert n_query == n_key
46 |
47 | # without causal masking
48 | if causal is False:
49 | attn, lse = self.forward_no_causal_mask(query, key, value, scale)
50 |
51 | else: # with causal masking
52 | if n_key <= self.min_seq_len:
53 | attn, lse = flash_attn_func(query.transpose(1, 2),
54 | key.transpose(1, 2),
55 | value.transpose(1, 2),
56 | None, True, scale)
57 | attn = attn.transpose(1, 2)
58 |
59 | else:
60 | # If n_query is odd we pad inputs by zero rows
61 | if n_query % 2:
62 | query = torch.nn.functional.pad(query, (0, 0, 0, 1), mode='constant', value=0.)
63 | key = torch.nn.functional.pad(key, (0, 0, 0, 1), mode='constant', value=0.)
64 | value = torch.nn.functional.pad(value, (0, 0, 0, 1), mode='constant', value=0.)
65 |
66 | # extract block diagonal parts
67 | q_bd = query.view(batch_size, 2 * n_heads, query.shape[2] // 2, query.shape[-1])
68 | k_bd = key.view(batch_size, 2 * n_heads, key.shape[2] // 2, key.shape[-1])
69 | v_bd = value.view(batch_size, 2 * n_heads, key.shape[2] // 2, value.shape[-1])
70 |
71 | attn_bd, lse_bd = self.forward(q_bd, k_bd, v_bd, scale, True, True)
72 |
73 | if attn_bd.shape[2] not in attn_bd.stride():
74 | attn_bd = attn_bd.contiguous()
75 | attn_bd = attn_bd.view(batch_size, n_heads, -1, dim)
76 |
77 | if lse_bd.shape[2] not in lse_bd.stride():
78 | lse_bd = lse_bd.contiguous()
79 | lse_bd = lse_bd.view(batch_size, n_heads, -1, 1)
80 |
81 |                 # the lower-left (off-diagonal) block is unmasked attention
82 | attn_unmasked, lse_unmasked = self.forward_no_causal_mask(
83 | query[:, :, key.shape[2] // 2:, :], key[:, :, :key.shape[2] // 2, :],
84 | value[:, :, :key.shape[2] // 2, :], scale)
85 |
86 | attn_up, lse_up = attn_bd[:, :, :query.shape[2] // 2, :], lse_bd[:, :, :query.shape[2] // 2, :]
87 | attn_down, lse_down = add_self_attentions(attn_bd[:, :, query.shape[2] // 2:, :],
88 | lse_bd[:, :, query.shape[2] // 2:, :],
89 | attn_unmasked, lse_unmasked)
90 |
91 | attn = torch.cat((attn_up, attn_down), dim=-2)
92 | lse = torch.cat((lse_up, lse_down), dim=-2)
93 |
94 | if n_query % 2:
95 | attn = attn[:, :, :-1, :]
96 | lse = lse[:, :, :-1, :]
97 |
98 | if not return_lse:
99 | return attn
100 | else:
101 | return attn, lse
102 |
103 | def forward_no_causal_mask(self, query, key, value, scale):
104 | """
105 | - sequence lengths must be divisible by block_size
106 | """
107 | batch_size, head_size, n_query, dim = query.shape
108 |
109 | if self.min_seq_len > n_query:
110 | attn, lse = flash_attn_func(query.transpose(1, 2),
111 | key.transpose(1, 2),
112 | value.transpose(1, 2),
113 | None, False, scale)
114 | else:
115 | # Hash keys and queries via SortLSH and obtain buckets
116 | _, query_sort_idx = torch.sort(self.lsh.hash_triton(query), dim=2, stable=True) # batch_size x head_size x n
117 | _, key_sort_idx = torch.sort(self.lsh.hash_triton(key), dim=2, stable=True)
118 |
119 | # Now run hyper attention function on q,k,v and the permutations
120 | attn, lse = hyper_attn_func(query.transpose(1, 2),
121 | key.transpose(1, 2),
122 | value.transpose(1, 2),
123 | query_sort_idx.transpose(1, 2),
124 | key_sort_idx.transpose(1, 2),
125 | self.block_size,
126 | self.sample_size,
127 | scale,
128 | self.smooth_block,
129 | )
130 | attn = attn.transpose(1, 2)
131 |
132 | return attn, lse.unsqueeze(-1)
133 |
--------------------------------------------------------------------------------
/src/angular_lsh_triton.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import triton
4 | import triton.language as tl
5 |
6 |
7 | @triton.heuristics(
8 | {
9 | "EVEN_M": lambda args: args["seqlen"] % args["BLOCK_M"] == 0,
10 | "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
11 | }
12 | )
13 | @triton.jit
14 | def _angular_lsh_kernel(
15 | in_mat,
16 | proj_dir,
17 | perm,
18 | enc_vec,
19 | buckets,
20 | stride_in_matb,
21 | stride_in_math,
22 | stride_in_matm,
23 | stride_proj_dirb,
24 | stride_proj_dirh,
25 | stride_proj_dird,
26 | stride_bucketsb,
27 | stride_bucketsh,
28 | nheads,
29 | seqlen,
30 | seqlen_rounded,
31 | headdim,
32 | NUM_PROJ_ROUNDED: tl.constexpr,
33 | num_projs: tl.constexpr,
34 | BLOCK_HEADDIM: tl.constexpr,
35 | EVEN_M: tl.constexpr,
36 | EVEN_HEADDIM: tl.constexpr,
37 | BLOCK_M: tl.constexpr,
38 | ):
39 | start_m = tl.program_id(0)
40 | off_hb = tl.program_id(1)
41 | off_b = off_hb // nheads
42 | off_h = off_hb % nheads
43 | offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
44 | offs_n = tl.arange(0, NUM_PROJ_ROUNDED)
45 | offs_d = tl.arange(0, BLOCK_HEADDIM)
46 |
47 | in_mat_ptrs = (
48 | in_mat + off_b * stride_in_matb + off_h * stride_in_math + (offs_m[:, None] * stride_in_matm +
49 | offs_d[None, :])
50 | )
51 | proj_dir_ptrs = (
52 | proj_dir + off_b * stride_proj_dirb + off_h * stride_proj_dirh + (offs_d[:, None] * stride_proj_dird +
53 | offs_n[None, :])
54 | )
55 |
56 | # load in_mat block
57 | if EVEN_M:
58 | if EVEN_HEADDIM:
59 | mat = tl.load(in_mat_ptrs)
60 | else:
61 | mat = tl.load(in_mat_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
62 | else:
63 | if EVEN_HEADDIM:
64 | mat = tl.load(in_mat_ptrs, mask=offs_m[:, None] < seqlen, other=0.0)
65 | else:
66 | mat = tl.load(in_mat_ptrs, mask=(offs_m[:, None] < seqlen) & (offs_d[None, :] < headdim), other=0.0)
67 |
68 | # load proj_dir block, need to mask out out of bound offsets
69 | if EVEN_HEADDIM:
70 | proj_dir_block = tl.load(proj_dir_ptrs, mask=offs_n[None, :] < num_projs, other=0.0)
71 | else:
72 | proj_dir_block = tl.load(proj_dir_ptrs,
73 | mask=(offs_n[None, :] < num_projs) & (offs_d[:, None] * stride_proj_dird < headdim),
74 | other=0.0)
75 |
76 | # multiply the in_mat block with proj_dir block to get the mask
77 | mask = tl.dot(mat, proj_dir_block)
78 | mask = tl.where(mask > 0.0, 1.0, 0.0)
79 |
80 | # form enc_vec
81 | encoding_vectors = tl.load(enc_vec+offs_n, mask=offs_n < num_projs, other=0.0)
82 |
83 | # multiply mask by enc_vec
84 | bin_ids = tl.sum(mask * encoding_vectors[None, :], 1).to(tl.int32)
85 | # bin_ids = tl.ravel(bin_ids) # flatten the bin_ids into a 1d tensor
86 |
87 | # read hash buckets from look up table
88 | hash_buckets = tl.load(perm+bin_ids)
89 |
90 |     # write back the hash buckets
91 | # initialize pointers to output
92 | buckets_ptrs = buckets + off_b * stride_bucketsb + off_h * stride_bucketsh + offs_m
93 | if EVEN_M:
94 | tl.store(buckets_ptrs, hash_buckets)
95 | else:
96 | tl.store(buckets_ptrs, hash_buckets, mask=offs_m < seqlen)
97 |
98 |
99 | def _angular_lsh(in_mat, proj_dir, perm, enc_vec):
100 | # shape constraints
101 | num_projs = proj_dir.shape[-1]
102 | batch, nheads, seqlen, d = in_mat.shape
103 | assert (proj_dir.shape == (batch, nheads, d, num_projs)) or (proj_dir.shape == (1, 1, d, num_projs))
104 | assert in_mat.dtype == proj_dir.dtype, "All three tensors must have the same type"
105 | assert in_mat.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
106 | assert in_mat.is_cuda and proj_dir.is_cuda and perm.is_cuda and enc_vec.is_cuda
107 | if proj_dir.shape[:2] == (1, 1):
108 | stride_proj_dirb, stride_proj_dirh = 0, 0
109 | else:
110 | stride_proj_dirb, stride_proj_dirh = proj_dir.stride()[:2]
111 |
112 | seqlen_rounded = math.ceil(seqlen / 128) * 128
113 | num_projs_rounded = 16
114 | buckets = torch.empty((batch, nheads, seqlen), device=in_mat.device, dtype=torch.int32)
115 |
116 | BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
117 | BLOCK = 128
118 | num_warps = 4 if d <= 64 else 8
119 | grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch * nheads)
120 | _angular_lsh_kernel[grid](
121 | in_mat=in_mat,
122 | proj_dir=proj_dir,
123 | perm=perm,
124 | enc_vec=enc_vec,
125 | buckets=buckets,
126 | stride_in_matb=in_mat.stride(0),
127 | stride_in_math=in_mat.stride(1),
128 | stride_in_matm=in_mat.stride(2),
129 | stride_proj_dirb=stride_proj_dirb,
130 | stride_proj_dirh=stride_proj_dirh,
131 | stride_proj_dird=proj_dir.stride(2),
132 | stride_bucketsb=buckets.stride(0),
133 | stride_bucketsh=buckets.stride(1),
134 | nheads=nheads,
135 | seqlen=seqlen,
136 | seqlen_rounded=seqlen_rounded,
137 | headdim=d,
138 | NUM_PROJ_ROUNDED=num_projs_rounded,
139 | num_projs=num_projs,
140 | BLOCK_HEADDIM=BLOCK_HEADDIM,
141 | BLOCK_M=BLOCK,
142 | num_warps=num_warps,
143 | num_stages=1,
144 | )
145 | return buckets
146 |
147 |
148 | class AngularLSHTriton(torch.nn.Module):
149 | """
150 | inputs:
151 | - num_projs: a positive integer that determines the number of random projections used by hash function
152 | - dim: positive integer that determines the dimension of input vectors
153 | - mat: a tensor whose last shape is equal to dim and gets hashed by the lsh function
154 | output:
155 | - buckets: a tensor with shape mat.shape[:-1] and each entry is an integer in [0, 2^num_proj - 1]
156 | """
157 | def __init__(self, num_projs, dim, rng=None):
158 | super().__init__()
159 | self.num_projs = num_projs
160 |
161 | if num_projs > 0:
162 | self.register_buffer('perm', self._unit_hamming_distance_array(self.num_projs), persistent=False)
163 | self.register_buffer('proj_dir', torch.randn(dim + (num_projs,), generator=rng), persistent=False)
164 | self.register_buffer('enc_vec', 2 ** torch.arange(self.num_projs).view(1, 1, 1, -1), persistent=False)
165 | else:
166 | raise ValueError("Invalid value for num_projs")
167 |
168 | def _unit_hamming_distance_array(self, size_n):
169 | if size_n == 1:
170 | return torch.tensor([0, 1], dtype=torch.int32)
171 | a = self._unit_hamming_distance_array(size_n - 1)
172 | b = torch.concat([a, torch.flip(a, dims=[0]) + 2 ** (size_n - 1)], 0)
173 | return b if b.stride(-1) == 1 else b.contiguous()
174 |
175 | def hash_torch(self, mat):
176 | mask = torch.einsum('...nd,...dr -> ...nr', mat, self.proj_dir)
177 | mask = mask > 0
178 | bin_ids = (mask * self.enc_vec).sum(-1)
179 | return self.perm[bin_ids]
180 |
181 | def hash_triton(self, mat):
182 | return _angular_lsh(mat, self.proj_dir, self.perm, self.enc_vec)
183 |
184 | def __repr__(self):
185 | return f"AngularLSH(num_proj={self.num_projs}, proj_dir.shape={self.proj_dir.shape})"
186 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/unit_tests/test_hyper_attention.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import time
3 | import torch
4 | import triton
5 | import math
6 |
7 | import os, sys; sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))  # repository root
8 | from src.flash_attn_triton import flash_attn_func
9 | from src.hyper_attn_triton import hyper_attn_func
10 | from src.attn_utils import add_self_attentions, indexing
11 |
12 | cnt = 0
13 |
14 |
15 | def check_memory(new_cnt=-1):
16 | global cnt
17 | if new_cnt != -1:
18 | cnt = new_cnt
19 | mem_alloc = torch.cuda.memory_allocated()/1024/1024/1024
20 | mem_reserve = torch.cuda.memory_reserved()/1024/1024/1024
21 | mem_peak = torch.cuda.memory_stats()['active_bytes.all.peak']/1024/1024/1024
22 | print(f"[{cnt}] mem_alloc: {mem_alloc:.4f}, mem_reserve: {mem_reserve:.4f}, mem_peak: {mem_peak:.4f}")
23 | cnt += 1
24 | return
25 |
26 |
27 | class MyTestCase(unittest.TestCase):
28 | def test_forward_attention(self):
29 | print("1. this is forward test")
30 |
31 | dtype = torch.bfloat16
32 |
33 | batch_size = 4
34 | block_size = 512
35 | dim = 128
36 | head_size = 32
37 | seq_len = 2048
38 | sample_size = 128
39 | smooth_block = False
40 |
41 | query = torch.randn((batch_size, head_size, seq_len, dim), device='cuda', dtype=dtype)
42 | key = torch.randn((batch_size, head_size, seq_len, dim), device='cuda', dtype=dtype)
43 | key[:, :, :sample_size, :] *= 4.
44 | value = torch.randn((batch_size, head_size, seq_len, dim), device='cuda', dtype=dtype)
45 |
46 | q_buckets_idx = torch.randint(0, 64, (batch_size, head_size, seq_len), device='cuda')
47 | k_buckets_idx = torch.randint(0, 64, (batch_size, head_size, seq_len), device='cuda')
48 | _, query_sort_idx = torch.sort(q_buckets_idx, dim=2, stable=True)
49 | _, key_sort_idx = torch.sort(k_buckets_idx, dim=2, stable=True)
50 | check_memory()
51 |
52 | # compute attention by presorting queries and sorting back the output attention
53 | t0 = time.time()
54 | query_sort_idx_inv = torch.argsort(query_sort_idx, dim=2, stable=True)
55 | query_sorted = indexing(query, query_sort_idx)
56 | key_sorted = indexing(key, key_sort_idx)
57 | value_sorted = indexing(value, key_sort_idx)
58 | query_split_per_block = query_sorted.view(-1, 1, block_size, dim)
59 | key_split_per_block = key_sorted.view(-1, 1, block_size, dim)
60 | value_split_per_block = value_sorted.view(-1, 1, block_size, dim)
61 | attn_block, lse_block = flash_attn_func(query_split_per_block.transpose(1, 2),
62 | key_split_per_block.transpose(1, 2),
63 | value_split_per_block.transpose(1, 2))
64 |
65 | attn_sample, lse_sample = flash_attn_func(query.transpose(1, 2),
66 | key[:, :, :sample_size, :].transpose(1, 2),
67 | value[:, :, :sample_size, :].transpose(1, 2))
68 | attn_block = attn_block.transpose(1, 2)
69 | attn_block = attn_block.view(batch_size, head_size, query_sorted.shape[2], -1)
70 | attn_sample = attn_sample.transpose(1, 2)
71 | lse_block = lse_block[:, :, :query_sorted.shape[2]]
72 | lse_block = lse_block.view(batch_size, head_size, query_sorted.shape[2], -1)
73 |         flash_attn_block = indexing(attn_block, query_sort_idx_inv)  # attn_sample is merged below via add_self_attentions
74 | lse_block = indexing(lse_block, query_sort_idx_inv)
75 | attn, lse = add_self_attentions(flash_attn_block, lse_block, attn_sample, lse_sample.unsqueeze(-1))
76 | lse = lse.squeeze(-1)
77 | t1 = time.time()
78 | check_memory()
79 | print('the runtime of flash attention with permutation and indexing of queries:', t1-t0)
80 |
81 | # torch lse computation
82 | qk = query_split_per_block @ key_split_per_block.transpose(-1, -2) / math.sqrt(dim)
83 | lse_block_torch = torch.logsumexp(qk, dim=-1, keepdim=True)
84 | lse_block_torch = lse_block_torch.view(batch_size, head_size, query_sorted.shape[2], -1)
85 | lse_block_torch = indexing(lse_block_torch, query_sort_idx_inv).squeeze(-1)
86 | lse_sample_torch = torch.logsumexp(
87 | query @ key[:, :, :sample_size, :].transpose(-1, -2) / math.sqrt(dim),
88 | dim=-1,
89 | keepdim=True
90 | ).squeeze(-1)
91 | lse_torch = (lse_sample_torch.exp() + lse_block_torch.exp()).log().to(dtype=lse_block_torch.dtype)
92 | print('diff between lse with sample and without: ', (lse_block_torch - lse_torch).norm(), lse_torch.norm())
93 | print('error flash attention:', (lse - lse_torch).norm(), lse_torch.norm())
94 |
95 | # compute attention kernel which permutes queries in triton
96 | check_memory(0)
97 | t2 = time.time()
98 | attn_hyper, lse_hyper = hyper_attn_func(
99 | query.transpose(1, 2),
100 | key.transpose(1, 2),
101 | value.transpose(1, 2),
102 | query_sort_idx.transpose(1, 2),
103 | key_sort_idx.transpose(1, 2),
104 | block_size,
105 | sample_size,
106 | 1./math.sqrt(dim),
107 | smooth_block,
108 | )
109 | attn_hyper = attn_hyper.transpose(1, 2)
110 | t3 = time.time()
111 | check_memory()
112 |
113 | print('the runtime of hyper attention:', t3 - t2)
114 |
115 | print('diff lse hyper_attention and flash with indexing and permutation: ', (lse - lse_hyper).norm(), lse.norm())
116 |
117 | print('error hyper attention lse: ', (lse_hyper - lse_torch).norm(), lse_torch.norm())
118 |
119 | # check if dimension of V can be different from that of Q and K
120 | value_small = value[:, :, :, :dim//2].clone()
121 | attn_triton_unequal_dim, lse_triton_unequal_dim = hyper_attn_func(
122 | query.transpose(1, 2),
123 | key.transpose(1, 2),
124 | value_small.transpose(1, 2),
125 | query_sort_idx.transpose(1, 2),
126 | key_sort_idx.transpose(1, 2),
127 | block_size,
128 | sample_size,
129 | )
130 | attn_triton_unequal_dim = attn_triton_unequal_dim.transpose(1, 2)
131 |
132 | print('testing unequal dimension for V compared to Q, K')
133 | print((attn_hyper[:, :, :, :dim//2] - attn_triton_unequal_dim).norm())
134 |
135 | def test_gradient(self):
136 | print()
137 | print("2. this is gradients test")
138 |
139 | dtype = torch.bfloat16
140 |
141 | batch_size = 4
142 | block_size = 256
143 | dim = 64
144 | head_size = 32
145 | seq_len = 2048
146 | sample_size = 128
147 |
148 | query = torch.randn((batch_size, head_size, seq_len, dim), device='cuda', dtype=dtype, requires_grad=True)
149 | key = torch.randn((batch_size, head_size, seq_len, dim), device='cuda', dtype=dtype, requires_grad=True)
150 | value = torch.randn((batch_size, head_size, seq_len, dim), device='cuda', dtype=dtype, requires_grad=True)
151 | do = torch.randn_like(value)
152 |
153 | q_buckets_idx = torch.randint(0, 64, (batch_size, head_size, seq_len), device='cuda')
154 | k_buckets_idx = torch.randint(0, 64, (batch_size, head_size, seq_len), device='cuda')
155 | _, query_sort_idx = torch.sort(q_buckets_idx, dim=2, stable=True)
156 | _, key_sort_idx = torch.sort(k_buckets_idx, dim=2, stable=True)
157 |
158 | t0 = time.time()
159 | query_sort_idx_inv = torch.argsort(query_sort_idx, dim=2, stable=True)
160 | query_sorted = indexing(query, query_sort_idx)
161 | key_sorted = indexing(key, key_sort_idx)
162 | value_sorted = indexing(value, key_sort_idx)
163 | query_split_per_block = query_sorted.view(-1, 1, block_size, dim)
164 | key_split_per_block = key_sorted.view(-1, 1, block_size, dim)
165 | value_split_per_block = value_sorted.view(-1, 1, block_size, dim)
166 |
167 | attn_block, lse_block = flash_attn_func(query_split_per_block.transpose(1, 2),
168 | key_split_per_block.transpose(1, 2),
169 | value_split_per_block.transpose(1, 2))
170 |
171 | attn_block = attn_block.transpose(1, 2)
172 | attn_block = attn_block.view(batch_size, head_size, query_sorted.shape[2], -1)
173 | attn = indexing(attn_block, query_sort_idx_inv)
174 |
175 | attn.backward(do, retain_graph=True)
176 | t1 = time.time()
177 | print('flash attention and indexing forward+backward time: ', t1-t0)
178 | q_grad = query.grad.detach().clone()
179 | k_grad = key.grad.detach().clone()
180 | v_grad = value.grad.detach().clone()
181 |
182 | query.grad = None
183 | key.grad = None
184 | value.grad = None
185 |
186 | # torch computation
187 |
188 | qk = query_split_per_block @ key_split_per_block.transpose(-1, -2) / math.sqrt(dim)
189 | attn_block_torch = qk.softmax(dim=-1) @ value_split_per_block
190 | attn_torch_block = indexing(
191 | attn_block_torch.view(batch_size, head_size, query_sorted.shape[2], -1),
192 | query_sort_idx_inv
193 | )
194 | lse_block_torch = torch.logsumexp(qk, dim=-1, keepdim=True)
195 | lse_torch_block = indexing(
196 | lse_block_torch.view(batch_size, head_size, query_sorted.shape[2], -1),
197 | query_sort_idx_inv
198 | )
199 |
200 | qk_sample = query @ key[:, :, :sample_size, :].transpose(-1, -2) / math.sqrt(dim)
201 | attn_torch_sample = qk_sample.softmax(dim=-1) @ value[:, :, :sample_size, :]
202 | lse_torch_sample = torch.logsumexp(qk_sample, dim=-1, keepdim=True)
203 |
204 | attn_torch, lse_torch = add_self_attentions(attn_torch_block, lse_torch_block, attn_torch_sample, lse_torch_sample)
205 | lse_torch = lse_torch.squeeze(-1)
206 | attn_torch.backward(do, retain_graph=True)
207 |
208 | q_grad_torch = query.grad.detach().clone()
209 | k_grad_torch = key.grad.detach().clone()
210 | v_grad_torch = value.grad.detach().clone()
211 |
212 | query.grad = None
213 | key.grad = None
214 | value.grad = None
215 |
216 | # hyper attention computation
217 |
218 | t2 = time.time()
219 | hyper_attn, hyper_lse = hyper_attn_func(
220 | query.transpose(1, 2),
221 | key.transpose(1, 2),
222 | value.transpose(1, 2),
223 | query_sort_idx.transpose(1, 2),
224 | key_sort_idx.transpose(1, 2),
225 | block_size,
226 | sample_size,
227 | )
228 | hyper_attn = hyper_attn.transpose(1, 2)
229 | hyper_attn.backward(do, retain_graph=True)
230 | t3 = time.time()
231 | print('hyper attention triton forward+backward time: ', t3 - t2)
232 |
233 | q_grad_hyper = query.grad.detach().clone()
234 | k_grad_hyper = key.grad.detach().clone()
235 | v_grad_hyper = value.grad.detach().clone()
236 |
237 | print('difference of torch attention and flash attn: ', (attn_torch-attn).norm(), attn_torch.norm())
238 | print('difference of torch lse and hyper_attention lse: ', (lse_torch - hyper_lse).norm(), lse_torch.norm())
239 |
240 | print('difference between gradients of queries, flash vs hyper:')
241 | print((q_grad - q_grad_hyper).norm())
242 |
243 | print('difference between gradients of keys, flash vs hyper:')
244 | print((k_grad - k_grad_hyper).norm())
245 |
246 | print('difference between gradients of values, flash vs hyper:')
247 | print((v_grad - v_grad_hyper).norm())
248 |
249 | print('difference gradients of queries, torch vs hyper:')
250 | print((q_grad_torch - q_grad_hyper).norm(), q_grad_torch.norm(), q_grad_hyper.norm())
251 |
252 | print('difference gradients of keys, torch vs hyper:')
253 | print((k_grad_torch - k_grad_hyper).norm(), k_grad_torch.norm(), k_grad_hyper.norm())
254 |
255 | print('difference gradients of values, torch vs hyper:')
256 | print((v_grad_torch - v_grad_hyper).norm(), v_grad_torch.norm(), v_grad_hyper.norm())
257 |
258 |
259 | if __name__ == '__main__':
260 | unittest.main()
261 | # test_runtime()
--------------------------------------------------------------------------------
/src/hyper_attn_triton.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation of HyperAttention in Triton.
3 | Tested with triton==2.1.0.
4 |
5 | Requirements:
6 | - This implementation does not support attention bias (additive mask to qk).
7 | - This implementation only supports sequence lengths that are integer powers of two.
8 | - The permutation indices for q and k must have the same sequence length as q and k themselves.
9 | - The sequence lengths of q and k must be equal (see the usage sketch at the end of this file).
10 | """
11 |
12 | import math
13 |
14 | import torch
15 | import triton
16 | import triton.language as tl
17 |
18 | @triton.heuristics(
19 | {
20 | "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
21 | "EVEN_V_HEADDIM": lambda args: args["v_headdim"] == args["V_BLOCK_HEADDIM"],
22 | }
23 | )
24 | # bug when seqlen_q is not divisible by BLOCK_M=128
25 | @triton.jit
26 | def _fwd_hyper_kernel(
27 | Q,
28 | K,
29 | V,
30 | q_sort_idx,
31 | k_sort_idx,
32 | Out,
33 | Lse,
34 | softmax_scale,
35 | stride_qb,
36 | stride_qh,
37 | stride_qm,
38 | stride_kb,
39 | stride_kh,
40 | stride_kn,
41 | stride_vb,
42 | stride_vh,
43 | stride_vn,
44 | stride_q_sort_idxb,
45 | stride_q_sort_idxh,
46 | stride_q_sort_idxm,
47 | stride_k_sort_idxb,
48 | stride_k_sort_idxh,
49 | stride_k_sort_idxn,
50 | stride_ob,
51 | stride_oh,
52 | stride_om,
53 | nheads,
54 | block_size,
55 | sample_size,
56 | seqlen_k,
57 | seqlen_q,
58 | headdim,
59 | v_headdim,
60 | smooth_block,
61 | CACHE_KEY_SEQLEN_Q,
62 | CACHE_KEY_SEQLEN_K,
63 | BLOCK_HEADDIM: tl.constexpr,
64 | V_BLOCK_HEADDIM: tl.constexpr,
65 | EVEN_HEADDIM: tl.constexpr,
66 | EVEN_V_HEADDIM: tl.constexpr,
67 | BLOCK_M: tl.constexpr,
68 | BLOCK_N: tl.constexpr,
69 | ):
70 | start_m = tl.program_id(0)
71 | off_hb = tl.program_id(1)
72 | off_b = off_hb // nheads
73 | off_h = off_hb % nheads
74 | # initialize offsets
75 | offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
76 | offs_n = tl.arange(0, BLOCK_N)
77 | offs_d = tl.arange(0, BLOCK_HEADDIM)
78 | offs_vd = tl.arange(0, V_BLOCK_HEADDIM)
79 | # Initialize pointers to Q, K, V
80 | q_idx_ptrs = (
81 | q_sort_idx + off_b * stride_q_sort_idxb + off_h * stride_q_sort_idxh + offs_m * stride_q_sort_idxm
82 | )
83 | q_idx = tl.load(q_idx_ptrs).to(tl.int32)
84 |
85 | k_sort_idx += off_b * stride_k_sort_idxb + off_h * stride_k_sort_idxh
86 |
87 | # initialize the running softmax statistics: m_i (running max) and lse_i (running log-sum-exp)
88 | lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
89 | m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
90 | acc_o = tl.zeros([BLOCK_M, V_BLOCK_HEADDIM], dtype=tl.float32)
91 | q_ptrs = (
92 | Q + off_b * stride_qb + off_h * stride_qh + (q_idx[:, None] * stride_qm + offs_d[None, :])
93 | )
94 | if EVEN_HEADDIM:
95 | q = tl.load(q_ptrs)
96 | else:
97 | q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
98 |
99 | # block-diagonal part: attend to the keys that fall in the same block of the sorted order as this query tile
100 | # loop over k, v and update accumulator
101 | block_id = start_m // block_size
102 | block_offs = seqlen_k + (start_m % block_size) * BLOCK_N - (block_size-1) * BLOCK_N//2
103 | end_n = tl.minimum((block_id + 1) * BLOCK_N * block_size, seqlen_k)
104 | for start_n in range(block_id * BLOCK_N * block_size, end_n, BLOCK_N):
105 | start_n = tl.multiple_of(start_n, BLOCK_N)
106 | if smooth_block:
107 | k_idx_ptrs = ((start_n + block_offs + offs_n) * stride_k_sort_idxn) % seqlen_k
108 | else:
109 | k_idx_ptrs = (start_n + offs_n) * stride_k_sort_idxn
110 |
111 | k_idx = tl.load(k_sort_idx + k_idx_ptrs).to(tl.int32)
112 | k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (k_idx[:, None] * stride_kn + offs_d[None, :])
113 | # -- compute qk ----
114 | if EVEN_HEADDIM:
115 | k = tl.load(k_ptrs)
116 | else:
117 | k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
118 | qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
119 | qk += tl.dot(q, tl.trans(k))
120 | m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
121 | p = tl.exp(qk * softmax_scale - m_ij[:, None])
122 | l_ij = tl.sum(p, 1)
123 |
124 | # scale acc_o
125 | acc_o_scale = tl.exp(m_i - m_ij)
126 |
127 | # # -- update output accumulator acc_o --
128 | acc_o = acc_o * acc_o_scale[:, None]
129 |
130 | v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (k_idx[:, None] * stride_vn + offs_vd[None, :])
131 | if EVEN_V_HEADDIM:
132 | v = tl.load(v_ptrs)
133 | else:
134 | v = tl.load(v_ptrs, mask=offs_vd[None, :] < v_headdim, other=0.0)
135 | p = p.to(v.dtype)
136 | acc_o += tl.dot(p, v)
137 |
138 | # -- update statistics
139 | m_i = m_ij
140 | l_i_new = tl.exp(lse_i - m_ij) + l_ij
141 | lse_i = m_ij + tl.log(l_i_new)
142 | # sampled columns: additionally attend to the first sampled key blocks taken in the original (unpermuted) order
143 | for col_block in range(0, sample_size):
144 | curr_offs_n = col_block * BLOCK_N + offs_n
145 | k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (curr_offs_n[:, None] * stride_kn + offs_d[None, :])
146 | # -- compute qk ----
147 | if EVEN_HEADDIM:
148 | k = tl.load(k_ptrs)
149 | else:
150 | k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
151 | qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
152 | qk += tl.dot(q, tl.trans(k))
153 | m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
154 | p = tl.exp(qk * softmax_scale - m_ij[:, None])
155 | l_ij = tl.sum(p, 1)
156 |
157 | # scale acc_o
158 | acc_o_scale = tl.exp(m_i - m_ij)
159 | # # -- update output accumulator acc_o --
160 | acc_o = acc_o * acc_o_scale[:, None]
161 |
162 | v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (curr_offs_n[:, None] * stride_vn + offs_vd[None, :])
163 | if EVEN_V_HEADDIM:
164 | v = tl.load(v_ptrs)
165 | else:
166 | v = tl.load(v_ptrs, mask=offs_vd[None, :] < v_headdim, other=0.0)
167 | p = p.to(v.dtype)
168 | acc_o += tl.dot(p, v)
169 |
170 | # -- update statistics
171 | m_i = m_ij
172 | l_i_new = tl.exp(lse_i - m_ij) + l_ij
173 | lse_i = m_ij + tl.log(l_i_new)
174 |
175 |
176 | o_scale = tl.exp(m_i - lse_i)
177 | acc_o = acc_o * o_scale[:, None]
178 |
179 | # initialize pointers to outputs
180 | lse_ptrs = Lse + off_hb * seqlen_q + q_idx
181 | out_ptrs = (
182 | Out
183 | + off_b * stride_ob
184 | + off_h * stride_oh
185 | + (q_idx[:, None] * stride_om + offs_vd[None, :])
186 | )
187 | # write back l and m
188 | tl.store(lse_ptrs, lse_i)
189 | if EVEN_V_HEADDIM:
190 | tl.store(out_ptrs, acc_o)
191 | else:
192 | tl.store(out_ptrs, acc_o, mask=offs_vd[None, :] < v_headdim)
193 |
194 |
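# The preprocessing kernel below computes Delta[i] = sum_j Out[i, j] * dOut[i, j] for every query row.
# In the softmax backward this row-wise dot product equals sum_j P[i, j] * dP[i, j], which is the term
# subtracted in ds = p * (dp - Delta[:, None]) inside the backward kernels further down.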
195 | @triton.jit
196 | def _bwd_preprocess_do_o_dot(
197 | Out,
198 | DO,
199 | Delta,
200 | stride_ob,
201 | stride_oh,
202 | stride_om,
203 | stride_dob,
204 | stride_doh,
205 | stride_dom,
206 | nheads,
207 | seqlen_q,
208 | v_headdim,
209 | BLOCK_M: tl.constexpr,
210 | V_BLOCK_HEADDIM: tl.constexpr,
211 | ):
212 | start_m = tl.program_id(0)
213 | off_hb = tl.program_id(1)
214 | off_b = off_hb // nheads
215 | off_h = off_hb % nheads
216 | # initialize offsets
217 | offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
218 | offs_d = tl.arange(0, V_BLOCK_HEADDIM)
219 | # load
220 | o = tl.load(
221 | Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :],
222 | mask=offs_d[None, :] < v_headdim,
223 | other=0.0,
224 | ).to(tl.float32)
225 | do = tl.load(
226 | DO
227 | + off_b * stride_dob
228 | + off_h * stride_doh
229 | + offs_m[:, None] * stride_dom
230 | + offs_d[None, :],
231 | mask=offs_d[None, :] < v_headdim,
232 | other=0.0,
233 | ).to(tl.float32)
234 | delta = tl.sum(o * do, axis=1)
235 | # write-back
236 | tl.store(Delta + off_hb * seqlen_q + offs_m, delta)
237 |
238 |
239 | @triton.jit
240 | def _bwd_store_dx(
241 | dx_ptrs,
242 | dx,
243 | offs_d,
244 | headdim,
245 | even_headdim,
246 | ):
247 | if even_headdim:
248 | tl.store(dx_ptrs, dx)
249 | else:
250 | tl.store(dx_ptrs, dx, mask=offs_d[None, :] < headdim)
251 |
252 |
253 | @triton.jit
254 | def _bwd_blocked_kernel_one_col(
255 | start_n,
256 | Q,
257 | K,
258 | V,
259 | Q_idx,
260 | K_idx,
261 | DO,
262 | DQ,
263 | DK,
264 | DV,
265 | LSE,
266 | D,
267 | softmax_scale,
268 | stride_qm,
269 | stride_kn,
270 | stride_vn,
271 | stride_dom,
272 | stride_dqm,
273 | stride_dkn,
274 | stride_dvn,
275 | stride_q_idxm,
276 | stride_k_idxn,
277 | seqlen_q,
278 | block_size,
279 | headdim,
280 | v_headdim,
281 | smooth_block,
282 | BLOCK_HEADDIM: tl.constexpr,
283 | V_BLOCK_HEADDIM: tl.constexpr,
284 | EVEN_HEADDIM: tl.constexpr,
285 | EVEN_V_HEADDIM: tl.constexpr,
286 | BLOCK_M: tl.constexpr,
287 | BLOCK_N: tl.constexpr,
288 | ):
289 | # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)
290 | block_id = start_n // block_size
291 | block_offs = seqlen_q + (start_n % block_size) * BLOCK_M - (block_size - 1) * BLOCK_M // 2
292 | begin_m = block_id * BLOCK_M * block_size
293 | # initialize row / col offsets
294 | offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
295 | offs_m = tl.arange(0, BLOCK_M)
296 | offs_d = tl.arange(0, BLOCK_HEADDIM)
297 | offs_vd = tl.arange(0, V_BLOCK_HEADDIM)
298 | # initialize pointers to value-like data
299 | k_idx_ptrs = K_idx + offs_n * stride_k_idxn
300 | k_idx = tl.load(k_idx_ptrs).to(tl.int32)
301 | k_ptrs = K + (k_idx[:, None] * stride_kn + offs_d[None, :])
302 | v_ptrs = V + (k_idx[:, None] * stride_vn + offs_vd[None, :])
303 | # initialize dv and dk
304 | dv = tl.zeros([BLOCK_N, V_BLOCK_HEADDIM], dtype=tl.float32)
305 | dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
306 |
307 | # k and v stay in SRAM throughout
308 | if EVEN_HEADDIM:
309 | k = tl.load(k_ptrs)
310 | else:
311 | k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
312 | if EVEN_V_HEADDIM:
313 | v = tl.load(v_ptrs)
314 | else:
315 | v = tl.load(v_ptrs, mask=offs_vd[None, :] < v_headdim, other=0.0)
316 |
317 | # loop over rows
318 | end_m = tl.minimum((block_id + 1) * BLOCK_M * block_size, seqlen_q)
319 | for start_m in range(begin_m, end_m, BLOCK_M):
320 | start_m = tl.multiple_of(start_m, BLOCK_M)
321 | if smooth_block:
322 | q_idx_ptrs = ((start_m + block_offs + offs_m) * stride_q_idxm) % seqlen_q
323 | else:
324 | q_idx_ptrs = (start_m + offs_m) * stride_q_idxm
325 | q_idx = tl.load(Q_idx + q_idx_ptrs).to(tl.int32)
326 | q_ptrs = Q + (q_idx[:, None] * stride_qm + offs_d[None, :])
327 | # load q, k, v, do on-chip
328 | if EVEN_HEADDIM:
329 | q = tl.load(q_ptrs)
330 | else:
331 | q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
332 | # recompute p = softmax(qk, dim=-1).T
333 | qk = tl.dot(q, tl.trans(k))
334 | if not EVEN_HEADDIM:
335 | tl.debug_barrier()
336 | lse_i = tl.load(LSE + q_idx)
337 | p = tl.exp(qk * softmax_scale - lse_i[:, None])
338 | # compute dv
339 | do_ptrs = DO + (q_idx[:, None] * stride_dom + offs_vd[None, :])
340 | if EVEN_V_HEADDIM:
341 | do = tl.load(do_ptrs)
342 | else:
343 | do = tl.load(do_ptrs, mask=offs_vd[None, :] < v_headdim, other=0.0)
344 | dv += tl.dot(tl.trans(p.to(do.dtype)), do)
345 | # compute dp = dot(v, do)
346 | if not EVEN_HEADDIM:
347 | tl.debug_barrier()
348 | dp = tl.dot(do, tl.trans(v))
349 | # There's a race condition for headdim=48
350 | if not EVEN_HEADDIM:
351 | tl.debug_barrier()
352 | # compute ds = p * (dp - delta[:, None])
353 | # Putting the subtraction after the dp matmul (instead of before) is slightly faster
354 | Di = tl.load(D + q_idx)
355 | # Converting ds to q.dtype here reduces register pressure and makes it much faster
356 | # for BLOCK_HEADDIM=128
357 | ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype)
358 | # compute dk = dot(ds.T, q)
359 | dk += tl.dot(tl.trans(ds), q)
360 | # compute dq
361 | if not EVEN_HEADDIM: # Otherwise there's a race condition
362 | tl.debug_barrier()
363 |
364 | dq_ptrs = DQ + (q_idx[:, None] * stride_dqm + offs_d[None, :])
365 | dq = tl.dot(ds, k)
366 | if EVEN_HEADDIM:
367 | tl.atomic_add(dq_ptrs, dq)
368 | else:
369 | tl.atomic_add(dq_ptrs, dq, mask=offs_d[None, :] < headdim)
370 |
371 |
372 | # write-back
373 | dv_ptrs = DV + (k_idx[:, None] * stride_dvn + offs_vd[None, :])
374 | dk_ptrs = DK + (k_idx[:, None] * stride_dkn + offs_d[None, :])
375 | _bwd_store_dx(
376 | dk_ptrs,
377 | dk,
378 | offs_d,
379 | headdim,
380 | even_headdim=EVEN_HEADDIM,
381 | )
382 | _bwd_store_dx(
383 | dv_ptrs,
384 | dv,
385 | offs_vd,
386 | v_headdim,
387 | even_headdim=EVEN_V_HEADDIM,
388 | )
389 |
390 |
391 |
392 | @triton.heuristics(
393 | {
394 | "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
395 | "EVEN_V_HEADDIM": lambda args: args["v_headdim"] == args["V_BLOCK_HEADDIM"],
396 | }
397 | )
398 | @triton.jit
399 | def _bwd_permuted_block_diagonal_kernel(
400 | Q,
401 | K,
402 | V,
403 | q_sort_idx,
404 | k_sort_idx,
405 | DO,
406 | DQ,
407 | DK,
408 | DV,
409 | LSE,
410 | D,
411 | softmax_scale,
412 | stride_qb,
413 | stride_qh,
414 | stride_qm,
415 | stride_kb,
416 | stride_kh,
417 | stride_kn,
418 | stride_vb,
419 | stride_vh,
420 | stride_vn,
421 | stride_q_sort_idxb,
422 | stride_q_sort_idxh,
423 | stride_q_sort_idxm,
424 | stride_k_sort_idxb,
425 | stride_k_sort_idxh,
426 | stride_k_sort_idxn,
427 | stride_dob,
428 | stride_doh,
429 | stride_dom,
430 | stride_dqb,
431 | stride_dqh,
432 | stride_dqm,
433 | stride_dkb,
434 | stride_dkh,
435 | stride_dkn,
436 | stride_dvb,
437 | stride_dvh,
438 | stride_dvn,
439 | nheads,
440 | seqlen_q,
441 | block_size,
442 | headdim,
443 | v_headdim,
444 | smooth_block,
445 | CACHE_KEY_SEQLEN_Q,
446 | CACHE_KEY_SEQLEN_K,
447 | BLOCK_HEADDIM: tl.constexpr,
448 | V_BLOCK_HEADDIM: tl.constexpr,
449 | EVEN_HEADDIM: tl.constexpr,
450 | EVEN_V_HEADDIM: tl.constexpr,
451 | BLOCK_M: tl.constexpr,
452 | BLOCK_N: tl.constexpr,
453 | ):
454 | off_hb = tl.program_id(1)
455 | off_b = off_hb // nheads
456 | off_h = off_hb % nheads
457 | # offset pointers for batch/head
458 | Q += off_b * stride_qb + off_h * stride_qh
459 | K += off_b * stride_kb + off_h * stride_kh
460 | V += off_b * stride_vb + off_h * stride_vh
461 | Q_idx = q_sort_idx + off_b * stride_q_sort_idxb + off_h * stride_q_sort_idxh
462 | K_idx = k_sort_idx + off_b * stride_k_sort_idxb + off_h * stride_k_sort_idxh
463 | DO += off_b * stride_dob + off_h * stride_doh
464 | DQ += off_b * stride_dqb + off_h * stride_dqh
465 | DK += off_b * stride_dkb + off_h * stride_dkh
466 | DV += off_b * stride_dvb + off_h * stride_dvh
467 | # pointer to row-wise quantities in value-like data
468 | D += off_hb * seqlen_q
469 | LSE += off_hb * seqlen_q
470 |
471 | start_n = tl.program_id(0)
472 | _bwd_blocked_kernel_one_col(
473 | start_n=start_n,
474 | Q=Q,
475 | K=K,
476 | V=V,
477 | Q_idx=Q_idx,
478 | K_idx=K_idx,
479 | DO=DO,
480 | DQ=DQ,
481 | DK=DK,
482 | DV=DV,
483 | LSE=LSE,
484 | D=D,
485 | softmax_scale=softmax_scale,
486 | stride_qm=stride_qm,
487 | stride_kn=stride_kn,
488 | stride_vn=stride_vn,
489 | stride_dom=stride_dom,
490 | stride_dqm=stride_dqm,
491 | stride_dkn=stride_dkn,
492 | stride_dvn=stride_dvn,
493 | stride_q_idxm=stride_q_sort_idxm,
494 | stride_k_idxn=stride_k_sort_idxn,
495 | seqlen_q=seqlen_q,
496 | block_size=block_size // BLOCK_N,
497 | headdim=headdim,
498 | v_headdim=v_headdim,
499 | smooth_block=smooth_block,
500 | BLOCK_HEADDIM=BLOCK_HEADDIM,
501 | V_BLOCK_HEADDIM=V_BLOCK_HEADDIM,
502 | EVEN_HEADDIM=EVEN_HEADDIM,
503 | EVEN_V_HEADDIM=EVEN_V_HEADDIM,
504 | BLOCK_M=BLOCK_M,
505 | BLOCK_N=BLOCK_N,
506 | )
507 |
508 |
509 |
510 | @triton.heuristics(
511 | {
512 | "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
513 | "EVEN_V_HEADDIM": lambda args: args["v_headdim"] == args["V_BLOCK_HEADDIM"],
514 | }
515 | )
516 | @triton.jit
517 | def _bwd_sampled_col_kernel(
518 | Q,
519 | K,
520 | V,
521 | DO,
522 | DQ,
523 | DK,
524 | DV,
525 | LSE,
526 | D,
527 | softmax_scale,
528 | stride_qb,
529 | stride_qh,
530 | stride_qm,
531 | stride_kb,
532 | stride_kh,
533 | stride_kn,
534 | stride_vb,
535 | stride_vh,
536 | stride_vn,
537 | stride_dob,
538 | stride_doh,
539 | stride_dom,
540 | stride_dqb,
541 | stride_dqh,
542 | stride_dqm,
543 | stride_dkb,
544 | stride_dkh,
545 | stride_dkn,
546 | stride_dvb,
547 | stride_dvh,
548 | stride_dvn,
549 | nheads,
550 | seqlen_q,
551 | headdim,
552 | v_headdim,
553 | CACHE_KEY_SEQLEN_Q,
554 | CACHE_KEY_SEQLEN_K,
555 | BLOCK_HEADDIM: tl.constexpr,
556 | V_BLOCK_HEADDIM: tl.constexpr,
557 | EVEN_HEADDIM: tl.constexpr,
558 | EVEN_V_HEADDIM: tl.constexpr,
559 | BLOCK_M: tl.constexpr,
560 | BLOCK_N: tl.constexpr,
561 | ):
562 | off_hb = tl.program_id(1)
563 | off_b = off_hb // nheads
564 | off_h = off_hb % nheads
565 | # offset pointers for batch/head
566 | Q += off_b * stride_qb + off_h * stride_qh
567 | DO += off_b * stride_dob + off_h * stride_doh
568 | DQ += off_b * stride_dqb + off_h * stride_dqh
569 | # pointer to row-wise quantities in value-like data
570 | D += off_hb * seqlen_q
571 | LSE += off_hb * seqlen_q
572 |
573 | start_n = tl.program_id(0)
574 |
575 | # initialize row / col offsets
576 | offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
577 | offs_m = tl.arange(0, BLOCK_M)
578 | offs_d = tl.arange(0, BLOCK_HEADDIM)
579 | offs_vd = tl.arange(0, V_BLOCK_HEADDIM)
580 | # initialize pointers to value-like data
581 | k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
582 | v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_vd[None, :])
583 | # initialize dv and dk
584 | dv = tl.zeros([BLOCK_N, V_BLOCK_HEADDIM], dtype=tl.float32)
585 | dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
586 |
587 | # k and v stay in SRAM throughout
588 | if EVEN_HEADDIM:
589 | k = tl.load(k_ptrs)
590 | else:
591 | k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
592 | if EVEN_V_HEADDIM:
593 | v = tl.load(v_ptrs)
594 | else:
595 | v = tl.load(v_ptrs, mask=offs_vd[None, :] < v_headdim, other=0.0)
596 |
597 | # loop over rows
598 | for start_m in range(0, seqlen_q, BLOCK_M):
599 | start_m = tl.multiple_of(start_m, BLOCK_M)
600 | offs_m_curr = start_m + offs_m
601 | q_ptrs = Q + (offs_m_curr[:, None] * stride_qm + offs_d[None, :])
602 | # load q, k, v, do on-chip
603 | if EVEN_HEADDIM:
604 | q = tl.load(q_ptrs)
605 | else:
606 | q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
607 | # recompute p = softmax(qk, dim=-1).T
608 | qk = tl.dot(q, tl.trans(k))
609 | if not EVEN_HEADDIM:
610 | tl.debug_barrier()
611 | lse_i = tl.load(LSE + offs_m_curr)
612 | p = tl.exp(qk * softmax_scale - lse_i[:, None])
613 | # compute dv
614 | do_ptrs = DO + (offs_m_curr[:, None] * stride_dom + offs_vd[None, :])
615 | if EVEN_V_HEADDIM:
616 | do = tl.load(do_ptrs)
617 | else:
618 | do = tl.load(do_ptrs, mask=offs_vd[None, :] < v_headdim, other=0.0)
619 | dv += tl.dot(tl.trans(p.to(do.dtype)), do)
620 | # compute dp = dot(v, do)
621 | if not EVEN_HEADDIM:
622 | tl.debug_barrier()
623 | dp = tl.dot(do, tl.trans(v))
624 | # There's a race condition for headdim=48
625 | if not EVEN_HEADDIM:
626 | tl.debug_barrier()
627 | # compute ds = p * (dp - delta[:, None])
628 | # Putting the subtraction after the dp matmul (instead of before) is slightly faster
629 | Di = tl.load(D + offs_m_curr)
630 | # Converting ds to q.dtype here reduces register pressure and makes it much faster
631 | # for BLOCK_HEADDIM=128
632 | ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype)
633 | # compute dk = dot(ds.T, q)
634 | dk += tl.dot(tl.trans(ds), q)
635 | # compute dq
636 | if not EVEN_HEADDIM: # Otherwise there's a race condition
637 | tl.debug_barrier()
638 |
639 | dq_ptrs = DQ + (offs_m_curr[:, None] * stride_dqm + offs_d[None, :])
640 | dq = tl.dot(ds, k)
641 | if EVEN_HEADDIM:
642 | tl.atomic_add(dq_ptrs, dq)
643 | else:
644 | tl.atomic_add(dq_ptrs, dq, mask=offs_d[None, :] < headdim)
645 |
646 | dv_ptrs = DV + off_b * stride_dvb + off_h * stride_dvh + (offs_n[:, None] * stride_dvn + offs_vd[None, :])
647 | dk_ptrs = DK + off_b * stride_dkb + off_h * stride_dkh + (offs_n[:, None] * stride_dkn + offs_d[None, :])
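# The block-diagonal backward kernel has already written partial dk/dv for these sampled key
# columns, so this kernel reads back that result and adds its own contribution before storing.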
648 | dk += tl.load(dk_ptrs)
649 | dv += tl.load(dv_ptrs)
650 | _bwd_store_dx(
651 | dk_ptrs,
652 | dk,
653 | offs_d,
654 | headdim,
655 | even_headdim=EVEN_HEADDIM,
656 | )
657 | _bwd_store_dx(
658 | dv_ptrs,
659 | dv,
660 | offs_vd,
661 | v_headdim,
662 | even_headdim=EVEN_V_HEADDIM,
663 | )
664 |
665 | return
666 |
667 |
668 | def _hyper_attn_forward(q, k, v, q_sort_idx, k_sort_idx, block_size, sample_size, softmax_scale=None,
669 | smooth_block=False):
670 | """
671 | Initializes the forward kernel, schedules the thread blocks, and launches them in parallel.
672 | """
673 | # shape constraints
674 | batch, seqlen_q, nheads, d = q.shape
675 | _, seqlen_k, _, _ = k.shape
676 | _, seqlen_q_idx, _ = q_sort_idx.shape
677 | _, seqlen_k_idx, _ = k_sort_idx.shape
678 | assert k.shape == (batch, seqlen_k, nheads, d)
679 | assert v.shape[:3] == (batch, seqlen_k, nheads)
680 | assert q_sort_idx.shape == q.shape[:3]
681 | assert k_sort_idx.shape == k.shape[:3]
682 | assert d <= 128, "FlashAttention only supports head dimensions up to 128"
683 | assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
684 | assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
685 | assert q.is_cuda and k.is_cuda and v.is_cuda and q_sort_idx.is_cuda and k_sort_idx.is_cuda
686 | softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
687 | lse = torch.empty((batch, nheads, seqlen_q), device=q.device, dtype=torch.float32)
688 | # o = torch.empty_like(q)
689 | o = torch.empty((batch, seqlen_q, nheads, v.shape[-1]), device=q.device, dtype=q.dtype)
690 |
691 | BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
692 | v_headdim = v.shape[3]
693 | V_BLOCK_HEADDIM = max(triton.next_power_of_2(v_headdim), 16)
694 | BLOCK = 128
695 | assert seqlen_k % BLOCK == 0, f'key sequence length must be divisible by {BLOCK}'
696 | num_warps = 4 if d <= 64 else 8
697 | grid = lambda META: (triton.cdiv(seqlen_q_idx, META["BLOCK_M"]), batch * nheads)
698 | _fwd_hyper_kernel[grid](
699 | Q=q,
700 | K=k,
701 | V=v,
702 | q_sort_idx=q_sort_idx,
703 | k_sort_idx=k_sort_idx,
704 | Out=o,
705 | Lse=lse,
706 | softmax_scale=softmax_scale,
707 | stride_qb=q.stride(0),
708 | stride_qh=q.stride(2),
709 | stride_qm=q.stride(1),
710 | stride_kb=k.stride(0),
711 | stride_kh=k.stride(2),
712 | stride_kn=k.stride(1),
713 | stride_vb=v.stride(0),
714 | stride_vh=v.stride(2),
715 | stride_vn=v.stride(1),
716 | stride_q_sort_idxb=q_sort_idx.stride(0),
717 | stride_q_sort_idxh=q_sort_idx.stride(2),
718 | stride_q_sort_idxm=q_sort_idx.stride(1),
719 | stride_k_sort_idxb=k_sort_idx.stride(0),
720 | stride_k_sort_idxh=k_sort_idx.stride(2),
721 | stride_k_sort_idxn=k_sort_idx.stride(1),
722 | stride_ob=o.stride(0),
723 | stride_oh=o.stride(2),
724 | stride_om=o.stride(1),
725 | nheads=nheads,
726 | block_size=triton.cdiv(block_size, BLOCK),
727 | sample_size=triton.cdiv(sample_size, BLOCK),
728 | seqlen_k=seqlen_k,
729 | seqlen_q=seqlen_q,
730 | headdim=d,
731 | v_headdim=v_headdim,
732 | smooth_block=smooth_block,
733 | CACHE_KEY_SEQLEN_Q=seqlen_q // 32,
734 | CACHE_KEY_SEQLEN_K=seqlen_k // 32,
735 | BLOCK_HEADDIM=BLOCK_HEADDIM,
736 | V_BLOCK_HEADDIM=V_BLOCK_HEADDIM,
737 | BLOCK_M=BLOCK,
738 | BLOCK_N=BLOCK,
739 | num_warps=num_warps,
740 | num_stages=1,
741 | )
742 | return o, lse, softmax_scale # softmax_scale could have been updated
743 |
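# Worked example of the launch grid above (illustrative values): with batch=4, nheads=32,
# seqlen_q=2048 and BLOCK_M=BLOCK_N=128, grid = (cdiv(2048, 128), 4 * 32) = (16, 128),
# i.e. 2048 independent programs, each handling one tile of 128 permuted queries for one
# (batch, head) pair.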
744 |
745 | def _hyper_attn_backward(
746 | do, q, k, v, q_sort_idx, k_sort_idx, o, lse, dq, dk, dv, block_size, sample_size, softmax_scale=None,
747 | smooth_block=False):
748 | """
749 | Initializes the backward kernel, schedules the thread blocks, and launches them in parallel.
750 | """
751 | # Make sure that the last dimension is contiguous
752 | if do.stride(-1) != 1:
753 | do = do.contiguous()
754 | batch, seqlen_q, nheads, d = q.shape
755 | _, seqlen_k, _, _ = k.shape
756 | # assert d in {16, 32, 64, 128}
757 | assert d <= 128
758 | assert lse.shape == (batch, nheads, seqlen_q)
759 | assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1
760 | assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == do.stride(-1) == 1
761 | softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
762 |
763 | dq_accum = torch.zeros_like(q, dtype=torch.float32)
764 | delta = torch.empty_like(lse)
765 |
766 | v_headdim = v.shape[3]
767 | V_BLOCK_HEADDIM = max(triton.next_power_of_2(v_headdim), 16)
768 | BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
769 | grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
770 | _bwd_preprocess_do_o_dot[grid](
771 | Out=o,
772 | DO=do,
773 | Delta=delta,
774 | stride_ob=o.stride(0),
775 | stride_oh=o.stride(2),
776 | stride_om=o.stride(1),
777 | stride_dob=do.stride(0),
778 | stride_doh=do.stride(2),
779 | stride_dom=do.stride(1),
780 | nheads=nheads,
781 | seqlen_q=seqlen_q,
782 | v_headdim=v_headdim,
783 | BLOCK_M=128,
784 | V_BLOCK_HEADDIM=V_BLOCK_HEADDIM,
785 | )
786 |
787 | BLOCK = 128
788 | num_warps = 8
789 | grid = lambda META: (triton.cdiv(seqlen_k, BLOCK), batch * nheads)
790 | _bwd_permuted_block_diagonal_kernel[grid](
791 | Q=q,
792 | K=k,
793 | V=v,
794 | q_sort_idx=q_sort_idx,
795 | k_sort_idx=k_sort_idx,
796 | DO=do,
797 | DQ=dq_accum,
798 | DK=dk,
799 | DV=dv,
800 | LSE=lse,
801 | D=delta,
802 | softmax_scale=softmax_scale,
803 | stride_qb=q.stride(0),
804 | stride_qh=q.stride(2),
805 | stride_qm=q.stride(1),
806 | stride_kb=k.stride(0),
807 | stride_kh=k.stride(2),
808 | stride_kn=k.stride(1),
809 | stride_vb=v.stride(0),
810 | stride_vh=v.stride(2),
811 | stride_vn=v.stride(1),
812 | stride_q_sort_idxb=q_sort_idx.stride(0),
813 | stride_q_sort_idxh=q_sort_idx.stride(2),
814 | stride_q_sort_idxm=q_sort_idx.stride(1),
815 | stride_k_sort_idxb=k_sort_idx.stride(0),
816 | stride_k_sort_idxh=k_sort_idx.stride(2),
817 | stride_k_sort_idxn=k_sort_idx.stride(1),
818 | stride_dob=do.stride(0),
819 | stride_doh=do.stride(2),
820 | stride_dom=do.stride(1),
821 | stride_dqb=dq_accum.stride(0),
822 | stride_dqh=dq_accum.stride(2),
823 | stride_dqm=dq_accum.stride(1),
824 | stride_dkb=dk.stride(0),
825 | stride_dkh=dk.stride(2),
826 | stride_dkn=dk.stride(1),
827 | stride_dvb=dv.stride(0),
828 | stride_dvh=dv.stride(2),
829 | stride_dvn=dv.stride(1),
830 | nheads=nheads,
831 | seqlen_q=seqlen_q,
832 | block_size=block_size,
833 | headdim=d,
834 | v_headdim=v_headdim,
835 | smooth_block=smooth_block,
836 | CACHE_KEY_SEQLEN_Q=seqlen_q // 32,
837 | CACHE_KEY_SEQLEN_K=seqlen_k // 32, # key for triton cache (limit number of compilations)
838 | BLOCK_HEADDIM=BLOCK_HEADDIM,
839 | V_BLOCK_HEADDIM=V_BLOCK_HEADDIM,
840 | BLOCK_M=BLOCK,
841 | BLOCK_N=BLOCK,
842 | num_warps=num_warps,
843 | num_stages=1,
844 | )
845 |
846 | grid = lambda META: (triton.cdiv(sample_size, BLOCK), batch * nheads)
847 | _bwd_sampled_col_kernel[grid](
848 | Q=q,
849 | K=k,
850 | V=v,
851 | DO=do,
852 | DQ=dq_accum,
853 | DK=dk,
854 | DV=dv,
855 | LSE=lse,
856 | D=delta,
857 | softmax_scale=softmax_scale,
858 | stride_qb=q.stride(0),
859 | stride_qh=q.stride(2),
860 | stride_qm=q.stride(1),
861 | stride_kb=k.stride(0),
862 | stride_kh=k.stride(2),
863 | stride_kn=k.stride(1),
864 | stride_vb=v.stride(0),
865 | stride_vh=v.stride(2),
866 | stride_vn=v.stride(1),
867 | stride_dob=do.stride(0),
868 | stride_doh=do.stride(2),
869 | stride_dom=do.stride(1),
870 | stride_dqb=dq_accum.stride(0),
871 | stride_dqh=dq_accum.stride(2),
872 | stride_dqm=dq_accum.stride(1),
873 | stride_dkb=dk.stride(0),
874 | stride_dkh=dk.stride(2),
875 | stride_dkn=dk.stride(1),
876 | stride_dvb=dv.stride(0),
877 | stride_dvh=dv.stride(2),
878 | stride_dvn=dv.stride(1),
879 | nheads=nheads,
880 | seqlen_q=seqlen_q,
881 | headdim=d,
882 | v_headdim=v_headdim,
883 | CACHE_KEY_SEQLEN_Q=seqlen_q // 32,
884 | CACHE_KEY_SEQLEN_K=seqlen_k // 32, # key for triton cache (limit number of compilations)
885 | BLOCK_HEADDIM=BLOCK_HEADDIM,
886 | V_BLOCK_HEADDIM=V_BLOCK_HEADDIM,
887 | BLOCK_M=BLOCK,
888 | BLOCK_N=BLOCK,
889 | num_warps=num_warps,
890 | num_stages=1,
891 | )
892 | dq.copy_(dq_accum)
893 |
894 |
895 | class HyperAttnFunc(torch.autograd.Function):
896 | @staticmethod
897 | def forward(ctx, q, k, v, q_sort_idx, k_sort_idx, block_size, sample_size=0, softmax_scale=None,
898 | smooth_block=False):
899 | """
900 | q, k: queries and keys (batch_size, seqlen, nheads, headdim); seqlen must be an integer power of two
901 | v: values (batch_size, seqlen, nheads, v_headdim)
902 | q_sort_idx: the permutation for queries (batch_size, seqlen, nheads)
903 | k_sort_idx: the permutation for keys and values (batch_size, seqlen, nheads)
904 | block_size: side length of the block-diagonal blocks
905 | sample_size: number of sampled columns; must be a multiple of 128
906 | softmax_scale: if None, the scale defaults to 1/sqrt(headdim)
907 | smooth_block: if True, the block diagonals are smoothed to resemble banded diagonal patterns
908 | """
909 | # Make sure that the last dimension is contiguous
910 | q, k, v = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]]
911 | assert sample_size % 128 == 0
912 | o, lse, ctx.softmax_scale = _hyper_attn_forward(
913 | q, k, v, q_sort_idx, k_sort_idx, block_size, sample_size,
914 | softmax_scale=softmax_scale, smooth_block=smooth_block,
915 | )
916 | ctx.save_for_backward(q, k, v, q_sort_idx, k_sort_idx, o, lse)
917 | ctx.block_size = block_size
918 | ctx.sample_size = sample_size
919 | ctx.smooth_block = smooth_block
920 | return o, lse
921 |
922 | @staticmethod
923 | def backward(ctx, do, dlse_use_needed=None):
924 | q, k, v, q_sort_idx, k_sort_idx, o, lse = ctx.saved_tensors
925 | dq = torch.zeros_like(q)
926 | dk = torch.zeros_like(k)
927 | dv = torch.zeros_like(v)
928 | _hyper_attn_backward(
929 | do,
930 | q,
931 | k,
932 | v,
933 | q_sort_idx,
934 | k_sort_idx,
935 | o,
936 | lse,
937 | dq,
938 | dk,
939 | dv,
940 | ctx.block_size,
941 | ctx.sample_size,
942 | softmax_scale=ctx.softmax_scale,
943 | smooth_block=ctx.smooth_block,
944 | )
945 | return dq, dk, dv, None, None, None, None, None, None
946 |
947 |
948 | hyper_attn_func = HyperAttnFunc.apply
949 |
950 |
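# A minimal usage sketch, for illustration only: it builds inputs that satisfy the requirements listed
# in the module docstring and in HyperAttnFunc.forward (seqlen a power of two, fp16/bf16 CUDA tensors,
# sort indices of shape (batch, seqlen, nheads), sample_size a multiple of 128). The random bucket ids
# stand in for the LSH bucket ids that the full pipeline computes elsewhere in this repository
# (see angular_lsh_triton.py and hyper_attention.py). Running it requires a CUDA GPU.
if __name__ == "__main__":
    batch, nheads, seqlen, dim = 2, 4, 2048, 64
    q = torch.randn(batch, seqlen, nheads, dim, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(batch, seqlen, nheads, dim, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(batch, seqlen, nheads, dim, device="cuda", dtype=torch.bfloat16)
    # Sort queries and keys along the sequence dimension by (here random) bucket ids.
    q_sort_idx = torch.argsort(torch.randint(0, 64, (batch, seqlen, nheads), device="cuda"), dim=1)
    k_sort_idx = torch.argsort(torch.randint(0, 64, (batch, seqlen, nheads), device="cuda"), dim=1)
    out, lse = hyper_attn_func(q, k, v, q_sort_idx, k_sort_idx, 256, 128)
    print(out.shape, lse.shape)  # expected: (2, 2048, 4, 64) and (2, 4, 2048)
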
--------------------------------------------------------------------------------
/src/flash_attn_triton.py:
--------------------------------------------------------------------------------
1 | """
2 | *Experimental* implementation of FlashAttention in Triton.
3 | Tested with triton==2.0.0.dev20221202.
4 | Triton 2.0 has a new backend (MLIR), but it doesn't yet seem to work for head dimensions
5 | other than 64:
6 | https://github.com/openai/triton/blob/d376020f90002757eea3ea9475d4f7cfc2ec5ead/python/triton/ops/flash_attention.py#L207
7 | We'll update this implementation with the new Triton backend once this is fixed.
8 |
9 | We use the FlashAttention implementation from Phil Tillet as a starting point.
10 | https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
11 |
12 | Changes:
13 | - Implement both causal and non-causal attention.
14 | - Implement both self-attention and cross-attention.
15 | - Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
16 | - Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward.
17 | - Support attention bias.
18 | - Speed up the forward pass a bit, and only store the LSE instead of m and l.
19 | - Make the backward for d=128 much faster by reducing register spilling.
20 | - Optionally parallelize the backward pass across seqlen_k, to deal with the case of
21 | small batch size * nheads.
22 |
23 | Caution:
24 | - This is an *experimental* implementation. The forward pass should be quite robust but
25 | I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler).
26 | - This implementation has only been tested on A100.
27 | - If you plan to use headdim other than 64 and 128, you should test for race conditions
28 | (due to the Triton compiler), as done in tests/test_flash_attn.py
29 | "test_flash_attn_triton_race_condition". I've tested and fixed many race conditions
30 | for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident
31 | that there are none left for other head dimensions.
32 |
33 | Differences between this Triton version and the CUDA version:
34 | - Triton version doesn't support dropout.
35 | - Triton forward is generally faster than CUDA forward, while Triton backward is
36 | generally slower than CUDA backward. Overall Triton forward + backward is slightly slower
37 | than CUDA forward + backward.
38 | - Triton version doesn't support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
39 | - Triton version supports attention bias, while CUDA version doesn't.
40 | """
41 |
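# Attention-bias note (based on the checks in _flash_attn_forward below): bias must be a 4-D CUDA
# tensor whose last two dimensions are either (1, seqlen_k), giving a per-key "vector" bias, or
# (seqlen_q, seqlen_k), giving a full "matrix" bias; the first two dimensions may be 1 or
# (batch, nheads) and are broadcast via expand. For example, an additive mask with 0 for allowed
# positions and a large negative value for masked ones could be passed as a
# (1, 1, seqlen_q, seqlen_k) tensor.
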
42 | import math
43 |
44 | import torch
45 | import triton
46 | import triton.language as tl
47 |
48 |
49 | @triton.heuristics(
50 | {
51 | "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
52 | "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
53 | "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
54 | }
55 | )
56 | @triton.jit
57 | def _fwd_kernel(
58 | Q,
59 | K,
60 | V,
61 | Bias,
62 | Out,
63 | Lse,
64 | softmax_scale,
65 | stride_qb,
66 | stride_qh,
67 | stride_qm,
68 | stride_kb,
69 | stride_kh,
70 | stride_kn,
71 | stride_vb,
72 | stride_vh,
73 | stride_vn,
74 | stride_bb,
75 | stride_bh,
76 | stride_bm,
77 | stride_ob,
78 | stride_oh,
79 | stride_om,
80 | nheads,
81 | seqlen_q,
82 | seqlen_k,
83 | seqlen_q_rounded,
84 | headdim,
85 | CACHE_KEY_SEQLEN_Q,
86 | CACHE_KEY_SEQLEN_K,
87 | BIAS_TYPE: tl.constexpr,
88 | IS_CAUSAL: tl.constexpr,
89 | BLOCK_HEADDIM: tl.constexpr,
90 | EVEN_M: tl.constexpr,
91 | EVEN_N: tl.constexpr,
92 | EVEN_HEADDIM: tl.constexpr,
93 | BLOCK_M: tl.constexpr,
94 | BLOCK_N: tl.constexpr,
95 | ):
96 | start_m = tl.program_id(0)
97 | off_hb = tl.program_id(1)
98 | off_b = off_hb // nheads
99 | off_h = off_hb % nheads
100 | # off_b = tl.program_id(1)
101 | # off_h = tl.program_id(2)
102 | # off_hb = off_b * nheads + off_h
103 | # initialize offsets
104 | offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
105 | offs_n = tl.arange(0, BLOCK_N)
106 | offs_d = tl.arange(0, BLOCK_HEADDIM)
107 | # Initialize pointers to Q, K, V
108 | # Adding parentheses around indexing might use int32 math instead of int64 math?
109 | # https://github.com/openai/triton/issues/741
110 | # I'm seeing a tiny bit of difference (5-7us)
111 | q_ptrs = (
112 | Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
113 | )
114 | k_ptrs = (
115 | K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
116 | )
117 | v_ptrs = (
118 | V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
119 | )
120 | if BIAS_TYPE == "vector":
121 | b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
122 | elif BIAS_TYPE == "matrix":
123 | b_ptrs = (
124 | Bias
125 | + off_b * stride_bb
126 | + off_h * stride_bh
127 | + (offs_m[:, None] * stride_bm + offs_n[None, :])
128 | )
130 | # initialize the running softmax statistics: m_i (running max) and lse_i (running log-sum-exp)
130 | lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
131 | m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
132 | acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
133 | # load q: it will stay in SRAM throughout
134 | # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
135 | # tl.load(q_ptrs), we get the wrong output!
136 | if EVEN_M & EVEN_N:
137 | if EVEN_HEADDIM:
138 | q = tl.load(q_ptrs)
139 | else:
140 | q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
141 | else:
142 | if EVEN_HEADDIM:
143 | q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
144 | else:
145 | q = tl.load(
146 | q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0
147 | )
148 | # loop over k, v and update accumulator
149 | end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
150 | for start_n in range(0, end_n, BLOCK_N):
151 | start_n = tl.multiple_of(start_n, BLOCK_N)
152 | # -- compute qk ----
153 | if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition
154 | if EVEN_HEADDIM:
155 | k = tl.load(k_ptrs + start_n * stride_kn)
156 | else:
157 | k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
158 | else:
159 | if EVEN_HEADDIM:
160 | k = tl.load(
161 | k_ptrs + start_n * stride_kn,
162 | mask=(start_n + offs_n)[:, None] < seqlen_k,
163 | other=0.0,
164 | )
165 | else:
166 | k = tl.load(
167 | k_ptrs + start_n * stride_kn,
168 | mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
169 | other=0.0,
170 | )
171 | qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
172 | qk += tl.dot(q, tl.trans(k))
173 | # Trying to combine the two masks seems to make the result wrong
174 | if not EVEN_N: # Need to mask out otherwise the softmax is wrong
175 | qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
176 | if IS_CAUSAL:
177 | qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
178 | if BIAS_TYPE != "none":
179 | if BIAS_TYPE == "vector":
180 | if EVEN_N:
181 | bias = tl.load(b_ptrs + start_n).to(tl.float32)
182 | else:
183 | bias = tl.load(
184 | b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0
185 | ).to(tl.float32)
186 | bias = bias[None, :]
187 | elif BIAS_TYPE == "matrix":
188 | if EVEN_M & EVEN_N:
189 | bias = tl.load(b_ptrs + start_n).to(tl.float32)
190 | else:
191 | bias = tl.load(
192 | b_ptrs + start_n,
193 | mask=(offs_m[:, None] < seqlen_q)
194 | & ((start_n + offs_n)[None, :] < seqlen_k),
195 | other=0.0,
196 | ).to(tl.float32)
197 | # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
198 | # can then fuse the mult and add into an fma instruction. But if we have bias we need
199 | # to multiply by softmax_scale here.
200 | qk = qk * softmax_scale + bias
201 | m_ij = tl.maximum(tl.max(qk, 1), lse_i)
202 | p = tl.exp(qk - m_ij[:, None])
203 | else:
204 | m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
205 | p = tl.exp(qk * softmax_scale - m_ij[:, None])
206 | l_ij = tl.sum(p, 1)
207 |
208 | # scale acc_o
209 | acc_o_scale = tl.exp(m_i - m_ij)
210 |
211 | # # -- update output accumulator acc_o --
212 | acc_o = acc_o * acc_o_scale[:, None]
213 |
214 | if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition
215 | if EVEN_HEADDIM:
216 | v = tl.load(v_ptrs + start_n * stride_vn)
217 | else:
218 | v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
219 | else:
220 | if EVEN_HEADDIM:
221 | v = tl.load(
222 | v_ptrs + start_n * stride_vn,
223 | mask=(start_n + offs_n)[:, None] < seqlen_k,
224 | other=0.0,
225 | )
226 | else:
227 | v = tl.load(
228 | v_ptrs + start_n * stride_vn,
229 | mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
230 | other=0.0,
231 | )
232 | p = p.to(v.dtype)
233 | acc_o += tl.dot(p, v)
234 |
235 | # -- update statistics
236 | m_i = m_ij
237 | l_i_new = tl.exp(lse_i - m_ij) + l_ij
238 | lse_i = m_ij + tl.log(l_i_new)
239 |
240 | o_scale = tl.exp(m_i - lse_i)
241 | acc_o = acc_o * o_scale[:, None]
242 | # rematerialize offsets to save registers
243 | start_m = tl.program_id(0)
244 | offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
245 | # write back l and m
246 | lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
247 | tl.store(lse_ptrs, lse_i)
248 | # initialize pointers to output
249 | offs_d = tl.arange(0, BLOCK_HEADDIM)
250 | out_ptrs = (
251 | Out
252 | + off_b * stride_ob
253 | + off_h * stride_oh
254 | + (offs_m[:, None] * stride_om + offs_d[None, :])
255 | )
256 | if EVEN_M:
257 | if EVEN_HEADDIM:
258 | tl.store(out_ptrs, acc_o)
259 | else:
260 | tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
261 | else:
262 | if EVEN_HEADDIM:
263 | tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
264 | else:
265 | tl.store(
266 | out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
267 | )
268 |
269 |
270 | @triton.jit
271 | def _bwd_preprocess_do_o_dot(
272 | Out,
273 | DO,
274 | Delta,
275 | stride_ob,
276 | stride_oh,
277 | stride_om,
278 | stride_dob,
279 | stride_doh,
280 | stride_dom,
281 | nheads,
282 | seqlen_q,
283 | seqlen_q_rounded,
284 | headdim,
285 | BLOCK_M: tl.constexpr,
286 | BLOCK_HEADDIM: tl.constexpr,
287 | ):
288 | start_m = tl.program_id(0)
289 | off_hb = tl.program_id(1)
290 | off_b = off_hb // nheads
291 | off_h = off_hb % nheads
292 | # initialize offsets
293 | offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
294 | offs_d = tl.arange(0, BLOCK_HEADDIM)
295 | # load
296 | o = tl.load(
297 | Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :],
298 | mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
299 | other=0.0,
300 | ).to(tl.float32)
301 | do = tl.load(
302 | DO
303 | + off_b * stride_dob
304 | + off_h * stride_doh
305 | + offs_m[:, None] * stride_dom
306 | + offs_d[None, :],
307 | mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
308 | other=0.0,
309 | ).to(tl.float32)
310 | delta = tl.sum(o * do, axis=1)
311 | # write-back
312 | tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)
313 |
314 |
315 | @triton.jit
316 | def _bwd_store_dx(
317 | dx_ptrs,
318 | dx,
319 | offs_n,
320 | offs_d,
321 | seqlen,
322 | headdim,
323 | EVEN_M: tl.constexpr,
324 | EVEN_N: tl.constexpr,
325 | even_headdim,
326 | ):
327 | # [2022-11-01] TD: Same bug. In the case of EVEN_N=True and EVEN_M=False,
328 | # if we just call tl.store(dv_ptrs), there's a race condition
329 | if EVEN_N & EVEN_M:
330 | if even_headdim:
331 | tl.store(dx_ptrs, dx)
332 | else:
333 | tl.store(dx_ptrs, dx, mask=offs_d[None, :] < headdim)
334 | else:
335 | if even_headdim:
336 | tl.store(dx_ptrs, dx, mask=offs_n[:, None] < seqlen)
337 | else:
338 | tl.store(dx_ptrs, dx, mask=(offs_n[:, None] < seqlen) & (offs_d[None, :] < headdim))
339 |
340 |
341 | @triton.jit
342 | def _bwd_kernel_one_col_block(
343 | start_n,
344 | Q,
345 | K,
346 | V,
347 | Bias,
348 | DO,
349 | DQ,
350 | DK,
351 | DV,
352 | LSE,
353 | D,
354 | softmax_scale,
355 | stride_qm,
356 | stride_kn,
357 | stride_vn,
358 | stride_bm,
359 | stride_dom,
360 | stride_dqm,
361 | stride_dkn,
362 | stride_dvn,
363 | seqlen_q,
364 | seqlen_k,
365 | headdim,
366 | ATOMIC_ADD: tl.constexpr,
367 | BIAS_TYPE: tl.constexpr,
368 | IS_CAUSAL: tl.constexpr,
369 | BLOCK_HEADDIM: tl.constexpr,
370 | EVEN_M: tl.constexpr,
371 | EVEN_N: tl.constexpr,
372 | EVEN_HEADDIM: tl.constexpr,
373 | BLOCK_M: tl.constexpr,
374 | BLOCK_N: tl.constexpr,
375 | ):
376 | # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)
377 | begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M
378 | # initialize row/col offsets
379 | offs_qm = begin_m + tl.arange(0, BLOCK_M)
380 | offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
381 | offs_m = tl.arange(0, BLOCK_M)
382 | offs_d = tl.arange(0, BLOCK_HEADDIM)
383 | # initialize pointers to value-like data
384 | q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :])
385 | k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :])
386 | v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
387 | do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])
388 | dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])
389 | if BIAS_TYPE == "vector":
390 | b_ptrs = Bias + offs_n
391 | elif BIAS_TYPE == "matrix":
392 | b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :])
393 | # initialize dv and dk
394 | dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
395 | dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
396 | # There seems to be some problem with Triton pipelining that makes results wrong for
397 | # headdim=64, seqlen=(113, 255), bias_type='matrix'. In this case the for loop
398 | # may have zero step, and pipelining with the bias matrix could screw it up.
399 | # So we just exit early.
400 | if begin_m >= seqlen_q:
401 | dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
402 | dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
403 | _bwd_store_dx(
404 | dk_ptrs,
405 | dk,
406 | offs_n,
407 | offs_d,
408 | seqlen_k,
409 | headdim,
410 | EVEN_M=EVEN_M,
411 | EVEN_N=EVEN_N,
412 | even_headdim=EVEN_HEADDIM,
413 | )
414 | _bwd_store_dx(
415 | dv_ptrs,
416 | dv,
417 | offs_n,
418 | offs_d,
419 | seqlen_k,
420 | headdim,
421 | EVEN_M=EVEN_M,
422 | EVEN_N=EVEN_N,
423 | even_headdim=EVEN_HEADDIM,
424 | )
425 | return
426 | # k and v stay in SRAM throughout
427 | # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False,
428 | # if we just call tl.load(k_ptrs), we get the wrong output!
429 | if EVEN_N & EVEN_M:
430 | if EVEN_HEADDIM:
431 | k = tl.load(k_ptrs)
432 | v = tl.load(v_ptrs)
433 | else:
434 | k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
435 | v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
436 | else:
437 | if EVEN_HEADDIM:
438 | k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
439 | v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
440 | else:
441 | k = tl.load(
442 | k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0
443 | )
444 | v = tl.load(
445 | v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0
446 | )
447 | # loop over rows
448 | num_block_m = tl.cdiv(seqlen_q, BLOCK_M)
449 | for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M):
450 | start_m = tl.multiple_of(start_m, BLOCK_M)
451 | offs_m_curr = start_m + offs_m
452 | # load q, k, v, do on-chip
453 | # Same bug as below. Otherwise gives wrong result for headdim=40, seqlen=(128, 117)
454 | if EVEN_M & EVEN_HEADDIM:
455 | q = tl.load(q_ptrs)
456 | else:
457 | if EVEN_HEADDIM:
458 | q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
459 | else:
460 | q = tl.load(
461 | q_ptrs,
462 | mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
463 | other=0.0,
464 | )
465 | # recompute p = softmax(qk, dim=-1).T
466 | qk = tl.dot(q, tl.trans(k))
467 | # Trying to combine the two masks seems to make the result wrong
468 | if not EVEN_N: # Need to mask out otherwise the softmax is wrong
469 | qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf"))
470 | if IS_CAUSAL:
471 | qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
472 | if BIAS_TYPE != "none":
473 | tl.debug_barrier() # Race condition otherwise
474 | if BIAS_TYPE == "vector":
475 | if EVEN_N:
476 | bias = tl.load(b_ptrs).to(tl.float32)
477 | else:
478 | bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32)
479 | bias = bias[None, :]
480 | elif BIAS_TYPE == "matrix":
481 | if EVEN_M & EVEN_N:
482 | bias = tl.load(b_ptrs).to(tl.float32)
483 | else:
484 | bias = tl.load(
485 | b_ptrs,
486 | mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k),
487 | other=0.0,
488 | ).to(tl.float32)
489 | qk = qk * softmax_scale + bias
490 | # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong.
491 | # Also wrong for headdim=64.
492 | if not (EVEN_M & EVEN_HEADDIM):
493 | tl.debug_barrier()
494 | lse_i = tl.load(LSE + offs_m_curr)
495 | if BIAS_TYPE == "none":
496 | p = tl.exp(qk * softmax_scale - lse_i[:, None])
497 | else:
498 | p = tl.exp(qk - lse_i[:, None])
499 | # compute dv
500 | # [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call
501 | # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0), we get wrong outputs
502 | # in the case of headdim=48/96, seqlen_q & seqlen_k >= 512. If headdim=40 or seqlen < 512,
503 | # the output is correct.
504 | if EVEN_M & EVEN_HEADDIM:
505 | do = tl.load(do_ptrs)
506 | else:
507 | # [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask.
508 | do = tl.load(
509 | do_ptrs,
510 | mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
511 | other=0.0,
512 | )
513 | # if EVEN_M:
514 | # if EVEN_HEADDIM:
515 | # do = tl.load(do_ptrs)
516 | # else:
517 | # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
518 | # else:
519 | # if EVEN_HEADDIM:
520 | # do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
521 | # else:
522 | # do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
523 | # & (offs_d[None, :] < headdim), other=0.0)
524 | dv += tl.dot(tl.trans(p.to(do.dtype)), do)
525 | # compute dp = dot(v, do)
526 | # There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
527 | # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True
528 | # Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False
529 | if not (EVEN_M & EVEN_HEADDIM):
530 | tl.debug_barrier()
531 | dp = tl.dot(do, tl.trans(v))
532 | # There's a race condition for headdim=48
533 | if not EVEN_HEADDIM:
534 | tl.debug_barrier()
535 | # compute ds = p * (dp - delta[:, None])
536 | # Putting the subtraction after the dp matmul (instead of before) is slightly faster
537 | Di = tl.load(D + offs_m_curr)
538 | # Converting ds to q.dtype here reduces register pressure and makes it much faster
539 | # for BLOCK_HEADDIM=128
540 | ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype)
541 | # compute dk = dot(ds.T, q)
542 | dk += tl.dot(tl.trans(ds), q)
543 | # compute dq
544 | if not (
545 | EVEN_M & EVEN_HEADDIM
546 | ): # Otherwise there's a race condition when BIAS_TYPE='matrix'
547 | tl.debug_barrier()
548 | if not ATOMIC_ADD:
549 | if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M
550 | dq = tl.load(dq_ptrs, eviction_policy="evict_last")
551 | dq += tl.dot(ds, k)
552 | tl.store(dq_ptrs, dq, eviction_policy="evict_last")
553 | else:
554 | if EVEN_HEADDIM:
555 | dq = tl.load(
556 | dq_ptrs,
557 | mask=offs_m_curr[:, None] < seqlen_q,
558 | other=0.0,
559 | eviction_policy="evict_last",
560 | )
561 | dq += tl.dot(ds, k)
562 | tl.store(
563 | dq_ptrs,
564 | dq,
565 | mask=offs_m_curr[:, None] < seqlen_q,
566 | eviction_policy="evict_last",
567 | )
568 | else:
569 | dq = tl.load(
570 | dq_ptrs,
571 | mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
572 | other=0.0,
573 | eviction_policy="evict_last",
574 | )
575 | dq += tl.dot(ds, k)
576 | tl.store(
577 | dq_ptrs,
578 | dq,
579 | mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
580 | eviction_policy="evict_last",
581 | )
582 | else: # If we're parallelizing across the seqlen_k dimension
583 | dq = tl.dot(ds, k)
584 | if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M
585 | tl.atomic_add(dq_ptrs, dq)
586 | else:
587 | if EVEN_HEADDIM:
588 | tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q)
589 | else:
590 | tl.atomic_add(
591 | dq_ptrs,
592 | dq,
593 | mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
594 | )
595 | # increment pointers
596 | dq_ptrs += BLOCK_M * stride_dqm
597 | q_ptrs += BLOCK_M * stride_qm
598 | do_ptrs += BLOCK_M * stride_dom
599 | if BIAS_TYPE == "matrix":
600 | b_ptrs += BLOCK_M * stride_bm
601 | # write-back
602 | dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
603 | dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
604 | _bwd_store_dx(
605 | dk_ptrs,
606 | dk,
607 | offs_n,
608 | offs_d,
609 | seqlen_k,
610 | headdim,
611 | EVEN_M=EVEN_M,
612 | EVEN_N=EVEN_N,
613 | even_headdim=EVEN_HEADDIM,
614 | )
615 | _bwd_store_dx(
616 | dv_ptrs,
617 | dv,
618 | offs_n,
619 | offs_d,
620 | seqlen_k,
621 | headdim,
622 | EVEN_M=EVEN_M,
623 | EVEN_N=EVEN_N,
624 | even_headdim=EVEN_HEADDIM,
625 | )
626 |
627 |
628 | def init_to_zero(name):
629 | return lambda nargs: nargs[name].zero_()
630 |
631 | # Compiler bug when using autotune in Triton v2; only a single config is used here.
632 | @triton.autotune(
633 | configs=[
634 | triton.Config(
635 | {"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True},
636 | num_warps=8,
637 | num_stages=1,
638 | pre_hook=init_to_zero("DQ"),
639 | ),
640 | ],
641 | key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "BIAS_TYPE", "IS_CAUSAL", "BLOCK_HEADDIM"],
642 | )
643 | @triton.heuristics(
644 | {
645 | "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
646 | "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
647 | "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
648 | }
649 | )
650 | @triton.jit
651 | def _bwd_kernel(
652 | Q,
653 | K,
654 | V,
655 | Bias,
656 | DO,
657 | DQ,
658 | DK,
659 | DV,
660 | LSE,
661 | D,
662 | softmax_scale,
663 | stride_qb,
664 | stride_qh,
665 | stride_qm,
666 | stride_kb,
667 | stride_kh,
668 | stride_kn,
669 | stride_vb,
670 | stride_vh,
671 | stride_vn,
672 | stride_bb,
673 | stride_bh,
674 | stride_bm,
675 | stride_dob,
676 | stride_doh,
677 | stride_dom,
678 | stride_dqb,
679 | stride_dqh,
680 | stride_dqm,
681 | stride_dkb,
682 | stride_dkh,
683 | stride_dkn,
684 | stride_dvb,
685 | stride_dvh,
686 | stride_dvn,
687 | nheads,
688 | seqlen_q,
689 | seqlen_k,
690 | seqlen_q_rounded,
691 | headdim,
692 | CACHE_KEY_SEQLEN_Q,
693 | CACHE_KEY_SEQLEN_K,
694 | BIAS_TYPE: tl.constexpr,
695 | IS_CAUSAL: tl.constexpr,
696 | BLOCK_HEADDIM: tl.constexpr,
697 | SEQUENCE_PARALLEL: tl.constexpr,
698 | EVEN_M: tl.constexpr,
699 | EVEN_N: tl.constexpr,
700 | EVEN_HEADDIM: tl.constexpr,
701 | BLOCK_M: tl.constexpr,
702 | BLOCK_N: tl.constexpr,
703 | ):
704 | off_hb = tl.program_id(1)
705 | off_b = off_hb // nheads
706 | off_h = off_hb % nheads
707 | # offset pointers for batch/head
708 | Q += off_b * stride_qb + off_h * stride_qh
709 | K += off_b * stride_kb + off_h * stride_kh
710 | V += off_b * stride_vb + off_h * stride_vh
711 | DO += off_b * stride_dob + off_h * stride_doh
712 | DQ += off_b * stride_dqb + off_h * stride_dqh
713 | DK += off_b * stride_dkb + off_h * stride_dkh
714 | DV += off_b * stride_dvb + off_h * stride_dvh
715 | if BIAS_TYPE != "none":
716 | Bias += off_b * stride_bb + off_h * stride_bh
717 | # pointer to row-wise quantities in value-like data
718 | D += off_hb * seqlen_q_rounded
719 | LSE += off_hb * seqlen_q_rounded
720 | if not SEQUENCE_PARALLEL:
721 | num_block_n = tl.cdiv(seqlen_k, BLOCK_N)
722 | for start_n in range(0, num_block_n):
723 | _bwd_kernel_one_col_block(
724 | start_n,
725 | Q,
726 | K,
727 | V,
728 | Bias,
729 | DO,
730 | DQ,
731 | DK,
732 | DV,
733 | LSE,
734 | D,
735 | softmax_scale,
736 | stride_qm,
737 | stride_kn,
738 | stride_vn,
739 | stride_bm,
740 | stride_dom,
741 | stride_dqm,
742 | stride_dkn,
743 | stride_dvn,
744 | seqlen_q,
745 | seqlen_k,
746 | headdim,
747 | ATOMIC_ADD=False,
748 | BIAS_TYPE=BIAS_TYPE,
749 | IS_CAUSAL=IS_CAUSAL,
750 | BLOCK_HEADDIM=BLOCK_HEADDIM,
751 | EVEN_M=EVEN_M,
752 | EVEN_N=EVEN_N,
753 | EVEN_HEADDIM=EVEN_HEADDIM,
754 | BLOCK_M=BLOCK_M,
755 | BLOCK_N=BLOCK_N,
756 | )
757 | else:
758 | start_n = tl.program_id(0)
759 | _bwd_kernel_one_col_block(
760 | start_n,
761 | Q,
762 | K,
763 | V,
764 | Bias,
765 | DO,
766 | DQ,
767 | DK,
768 | DV,
769 | LSE,
770 | D,
771 | softmax_scale,
772 | stride_qm,
773 | stride_kn,
774 | stride_vn,
775 | stride_bm,
776 | stride_dom,
777 | stride_dqm,
778 | stride_dkn,
779 | stride_dvn,
780 | seqlen_q,
781 | seqlen_k,
782 | headdim,
783 | ATOMIC_ADD=True,
784 | BIAS_TYPE=BIAS_TYPE,
785 | IS_CAUSAL=IS_CAUSAL,
786 | BLOCK_HEADDIM=BLOCK_HEADDIM,
787 | EVEN_M=EVEN_M,
788 | EVEN_N=EVEN_N,
789 | EVEN_HEADDIM=EVEN_HEADDIM,
790 | BLOCK_M=BLOCK_M,
791 | BLOCK_N=BLOCK_N,
792 | )
793 |
794 |
795 | def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
796 | # shape constraints
797 | batch, seqlen_q, nheads, d = q.shape
798 | _, seqlen_k, _, _ = k.shape
799 | assert k.shape == (batch, seqlen_k, nheads, d)
800 | assert v.shape == (batch, seqlen_k, nheads, d)
801 | assert d <= 128, "FlashAttention only supports head dimensions up to 128"
802 | assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
803 | assert q.dtype in [torch.float16, torch.bfloat16], "Only fp16 and bf16 are supported"
804 | assert q.is_cuda and k.is_cuda and v.is_cuda
805 | softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
806 |
807 | has_bias = bias is not None
808 | bias_type = "none"
809 | if has_bias:
810 | assert bias.dtype in [q.dtype, torch.float]
811 | assert bias.is_cuda
812 | assert bias.dim() == 4
813 | if bias.stride(-1) != 1:
814 | bias = bias.contiguous()
815 | if bias.shape[2:] == (1, seqlen_k):
816 | bias_type = "vector"
817 | elif bias.shape[2:] == (seqlen_q, seqlen_k):
818 | bias_type = "matrix"
819 | else:
820 | raise RuntimeError(
821 | "Last 2 dimensions of bias must be (1, seqlen_k)" " or (seqlen_q, seqlen_k)"
822 | )
823 | bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
824 | bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
825 |
826 | seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
827 | lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
828 | o = torch.empty_like(q)
829 |
830 | BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
831 | BLOCK = 128
832 | num_warps = 4 if d <= 64 else 8
833 | grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
834 | _fwd_kernel[grid](
835 | q,
836 | k,
837 | v,
838 | bias,
839 | o,
840 | lse,
841 | softmax_scale,
842 | q.stride(0),
843 | q.stride(2),
844 | q.stride(1),
845 | k.stride(0),
846 | k.stride(2),
847 | k.stride(1),
848 | v.stride(0),
849 | v.stride(2),
850 | v.stride(1),
851 | *bias_strides,
852 | o.stride(0),
853 | o.stride(2),
854 | o.stride(1),
855 | nheads,
856 | seqlen_q,
857 | seqlen_k,
858 | seqlen_q_rounded,
859 | d,
860 | seqlen_q // 32,
861 | seqlen_k // 32, # key for triton cache (limit number of compilations)
862 | # Can't use kwargs here because triton autotune expects key to be args, not kwargs
863 | # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
864 | bias_type,
865 | causal,
866 | BLOCK_HEADDIM,
867 | BLOCK_M=BLOCK,
868 | BLOCK_N=BLOCK,
869 | num_warps=num_warps,
870 | num_stages=1,
871 | )
872 | return o, lse, softmax_scale # softmax_scale could have been updated
873 |
874 |
875 | def _flash_attn_backward(
876 | do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None
877 | ):
878 | # Make sure that the last dimension is contiguous
879 | if do.stride(-1) != 1:
880 | do = do.contiguous()
881 | batch, seqlen_q, nheads, d = q.shape
882 | _, seqlen_k, _, _ = k.shape
883 | # assert d in {16, 32, 64, 128}
884 | assert d <= 128
885 | seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
886 | assert lse.shape == (batch, nheads, seqlen_q_rounded)
887 | assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1
888 | assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1
889 | softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
890 | # dq_accum = torch.zeros_like(q, dtype=torch.float32)
891 | dq_accum = torch.empty_like(q, dtype=torch.float32)
892 | delta = torch.empty_like(lse)
893 | # delta = torch.zeros_like(lse)
894 |
895 | BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
896 | grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
897 | _bwd_preprocess_do_o_dot[grid](
898 | o,
899 | do,
900 | delta,
901 | o.stride(0),
902 | o.stride(2),
903 | o.stride(1),
904 | do.stride(0),
905 | do.stride(2),
906 | do.stride(1),
907 | nheads,
908 | seqlen_q,
909 | seqlen_q_rounded,
910 | d,
911 | BLOCK_M=128,
912 | BLOCK_HEADDIM=BLOCK_HEADDIM,
913 | )
914 |
915 | has_bias = bias is not None
916 | bias_type = "none"
917 | if has_bias:
918 | assert bias.dtype in [q.dtype, torch.float]
919 | assert bias.is_cuda
920 | assert bias.dim() == 4
921 | assert bias.stride(-1) == 1
922 | if bias.shape[2:] == (1, seqlen_k):
923 | bias_type = "vector"
924 | elif bias.shape[2:] == (seqlen_q, seqlen_k):
925 | bias_type = "matrix"
926 | else:
927 | raise RuntimeError(
928 | "Last 2 dimensions of bias must be (1, seqlen_k)" " or (seqlen_q, seqlen_k)"
929 | )
930 | bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
931 | bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
932 |
933 | # BLOCK_M = 128
934 | # BLOCK_N = 128
935 | # num_warps = 8
936 | grid = lambda META: (
937 | triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1,
938 | batch * nheads,
939 | )
940 | _bwd_kernel[grid](
941 | q,
942 | k,
943 | v,
944 | bias,
945 | do,
946 | dq_accum,
947 | dk,
948 | dv,
949 | lse,
950 | delta,
951 | softmax_scale,
952 | q.stride(0),
953 | q.stride(2),
954 | q.stride(1),
955 | k.stride(0),
956 | k.stride(2),
957 | k.stride(1),
958 | v.stride(0),
959 | v.stride(2),
960 | v.stride(1),
961 | *bias_strides,
962 | do.stride(0),
963 | do.stride(2),
964 | do.stride(1),
965 | dq_accum.stride(0),
966 | dq_accum.stride(2),
967 | dq_accum.stride(1),
968 | dk.stride(0),
969 | dk.stride(2),
970 | dk.stride(1),
971 | dv.stride(0),
972 | dv.stride(2),
973 | dv.stride(1),
974 | nheads,
975 | seqlen_q,
976 | seqlen_k,
977 | seqlen_q_rounded,
978 | d,
979 | seqlen_q // 32,
980 | seqlen_k // 32, # key for triton cache (limit number of compilations)
981 | # Can't use kwargs here because triton autotune expects key to be args, not kwargs
982 | # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
983 | bias_type,
984 | causal,
985 | BLOCK_HEADDIM,
986 | # SEQUENCE_PARALLEL=False,
987 | # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
988 | # num_warps=num_warps,
989 | # num_stages=1,
990 | )
991 | dq.copy_(dq_accum)
992 |
993 |
994 |
995 | class FlashAttnFunc(torch.autograd.Function):
996 | @staticmethod
997 | def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None):
998 | """
999 | q: (batch_size, seqlen_q, nheads, headdim)
1000 | k, v: (batch_size, seqlen_k, nheads, headdim)
1001 | bias: optional, shape broadcastable to (batch, nheads, seqlen_q, seqlen_k).
1002 | For example, an ALiBi mask for causal attention would have shape (1, nheads, 1, seqlen_k);
1003 | an ALiBi mask for non-causal attention would have shape (1, nheads, seqlen_q, seqlen_k).
1004 | """
1005 | # Make sure that the last dimension is contiguous
1006 | q, k, v = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]]
1007 | o, lse, ctx.softmax_scale = _flash_attn_forward(
1008 | q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale
1009 | )
1010 | ctx.save_for_backward(q, k, v, o, lse, bias)
1011 | ctx.causal = causal
1012 | return o, lse
1013 |
1014 | @staticmethod
1015 | def backward(ctx, do, dlse_use_needed=None):
1016 | q, k, v, o, lse, bias = ctx.saved_tensors
1017 | if len(ctx.needs_input_grad) > 3:
1018 | assert not ctx.needs_input_grad[3], "FlashAttention does not support bias gradient yet"
1019 | # Triton's autotune causes the Tensor._version to change, and so PyTorch autograd
1020 | # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
1021 | with torch.inference_mode():
1022 | dq = torch.empty_like(q)
1023 | dk = torch.empty_like(k)
1024 | dv = torch.empty_like(v)
1025 | _flash_attn_backward(
1026 | do,
1027 | q,
1028 | k,
1029 | v,
1030 | o,
1031 | lse,
1032 | dq,
1033 | dk,
1034 | dv,
1035 | bias=bias,
1036 | causal=ctx.causal,
1037 | softmax_scale=ctx.softmax_scale,
1038 | )
1039 | return dq, dk, dv, None, None, None
1040 |
1041 |
1042 | flash_attn_func = FlashAttnFunc.apply
1043 |
1044 |
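1045 | # Minimal usage sketch: a quick smoke test of flash_attn_func. This block is purely
1046 | # illustrative and assumes a CUDA device with Triton available; shapes and dtypes
1047 | # follow the FlashAttnFunc docstring (q, k, v of shape (batch, seqlen, nheads, headdim)
1048 | # in fp16/bf16).
1049 | if __name__ == "__main__":
1050 |     batch, seqlen, nheads, headdim = 2, 1024, 8, 64
1051 |     q = torch.randn(batch, seqlen, nheads, headdim, device="cuda",
1052 |                     dtype=torch.float16, requires_grad=True)
1053 |     k = torch.randn_like(q, requires_grad=True)
1054 |     v = torch.randn_like(q, requires_grad=True)
1055 |     # positional args: bias=None, causal=True, softmax_scale=None (defaults to 1/sqrt(headdim))
1056 |     out, lse = flash_attn_func(q, k, v, None, True, None)
1057 |     out.sum().backward()  # runs _flash_attn_backward and fills q.grad, k.grad, v.grad
1058 |     print(out.shape, lse.shape)  # torch.Size([2, 1024, 8, 64]) torch.Size([2, 8, 1024])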
--------------------------------------------------------------------------------