├── .gitignore
├── LICENSE
├── README.md
├── bench.py
├── bench_causal.py
├── check_backward.py
├── check_backward_causal.py
├── flash_attention.py
├── flash_attention_causal.py
└── requirements.txt

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
kineto
profiler_logs
__pycache__
flash_attention_v2_wip.py

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 Shreyansh Singh

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# FlashAttention in PyTorch

A simplified implementation of [FlashAttention](https://arxiv.org/abs/2205.14135) in PyTorch. I have implemented the forward pass and backward pass algorithms from the paper, and have also shown that the result is equivalent to the normal attention formulation used in Transformers. I also include some code for benchmarking.

Note that this is for educational purposes only, as I haven't implemented any of the CUDA and SRAM memory tricks described in the paper.

## Requirements
* einops==0.6.1
* torch==2.0.1

## Files
* [flash_attention.py](flash_attention.py) - Implementation of the general formulation of FlashAttention, which takes in Q, K, V and a mask. The code includes both the forward and backward algorithms, as well as a simple test that the forward pass is equivalent to normal attention (see the usage sketch after this list).
* [flash_attention_causal.py](flash_attention_causal.py) - The causal version of FlashAttention, which takes in Q, K and V. The mask is calculated in a causal fashion, as typically used in autoregressive models. This file also includes the forward and backward algorithms and a simple test that the forward pass is equivalent to normal (causal) attention.
* [bench.py](bench.py), [bench_causal.py](bench_causal.py) - Benchmarking code for the general and causal versions of FlashAttention.
* [check_backward.py](check_backward.py), [check_backward_causal.py](check_backward_causal.py) - These scripts verify two things: 1. whether the gradients of Q, K and V (calculated using PyTorch's `jacrev`) match between the normal version of attention and FlashAttention, and 2. whether these results match the implementation of the backward pass given in the paper. The loss function is simply taken to be the sum of the final output tensor.
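
A minimal sketch of how the two implementations can be called and compared. The shapes and the CUDA device below are illustrative assumptions, mirroring the self-test at the bottom of [flash_attention.py](flash_attention.py):

```python
import torch
from flash_attention import flash_attention, normal_attention

# (batch, heads, seq_len, head_dim) tensors on the GPU
Q = torch.randn(1, 2, 4096, 512, device='cuda')
K = torch.randn(1, 2, 4096, 512, device='cuda')
V = torch.randn(1, 2, 4096, 512, device='cuda')
# Key mask of shape (batch, kv_len): 1 = attend, 0 = masked out
mask = torch.randint(0, 2, (1, 4096), device='cuda')

out_flash = flash_attention(Q, K, V, mask)
out_normal = normal_attention(Q, K, V, mask)
print(torch.allclose(out_flash, out_normal, atol=1e-5))  # should print True
```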

## To run

### Forward pass

**Causal mask**
```python flash_attention_causal.py```

**Random mask**
```python flash_attention.py```

### Benchmarking - Causal mask

**FlashAttention**
```python bench_causal.py --b 1 --h 2 --q_len 16384 --kv_len 16384 --d 512 --type flash```

**Normal attention**
```python bench_causal.py --b 1 --h 2 --q_len 16384 --kv_len 16384 --d 512 --type normal```

Add `--profile` to log additional details using the PyTorch Profiler.

### Benchmarking - Random mask

**FlashAttention**
```python bench.py --b 1 --h 2 --q_len 16384 --kv_len 16384 --d 512 --type flash```

**Normal attention**
```python bench.py --b 1 --h 2 --q_len 16384 --kv_len 16384 --d 512 --type normal```

Add `--profile` to log additional details using the PyTorch Profiler.

### Backward Pass

**Causal mask**
```python check_backward_causal.py```

**Random mask**
```python check_backward.py```

--------------------------------------------------------------------------------
/bench.py:
--------------------------------------------------------------------------------
import time
import torch
from flash_attention import flash_attention, normal_attention
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--type', type=str, required=True, help="flash/normal")
parser.add_argument('--b', type=int, required=False, default=1, help="Batch size")
parser.add_argument('--h', type=int, required=False, default=2, help="Number of heads")
parser.add_argument('--q_len', type=int, required=False, default=4096, help="Length/first dimension of Q matrix")
parser.add_argument('--kv_len', type=int, required=False, default=4096, help="Length/first dimension of K/V matrix")
parser.add_argument('--d', type=int, required=False, default=512, help="Dimension of vector")
parser.add_argument('--profile', action='store_true', help="For PyTorch profiling")

args = parser.parse_args()

Q = torch.randn(args.b, args.h, args.q_len, args.d, requires_grad=True).to(device='cuda')
K = torch.randn(args.b, args.h, args.kv_len, args.d, requires_grad=True).to(device='cuda')
V = torch.randn(args.b, args.h, args.kv_len, args.d, requires_grad=True).to(device='cuda')
mask = torch.randint(0, 2, (args.b, args.kv_len)).to(device='cuda')  # key mask: 1 = attend, 0 = masked out

if args.type == "flash":
    # Warm-up iterations so one-time CUDA initialisation does not distort the measurement
    for _ in range(10):
        flash_attention(Q, K, V, mask)

    torch.cuda.synchronize()  # CUDA kernels run asynchronously; synchronise around the timed call
    start = time.time_ns()
    flash_attention(Q, K, V, mask)
    torch.cuda.synchronize()
    end = time.time_ns()

    t = (end - start) / 1000000
    print(f'{t}ms')
else:
    # Warm-up iterations so one-time CUDA initialisation does not distort the measurement
    for _ in range(10):
        normal_attention(Q, K, V, mask)

    torch.cuda.synchronize()  # CUDA kernels run asynchronously; synchronise around the timed call
    start = time.time_ns()
    normal_attention(Q, K, V, mask)
    torch.cuda.synchronize()
    end = time.time_ns()

    t = (end - start) / 1000000
    print(f'{t}ms')
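
# Optional profiling: tensorboard_trace_handler writes traces under ./profiler_logs, which can be
# inspected with TensorBoard's PyTorch Profiler plugin (pip package torch-tb-profiler), e.g.
# `tensorboard --logdir ./profiler_logs`.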
if args.profile:
    if args.type == "flash":
        with torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
            on_trace_ready=torch.profiler.tensorboard_trace_handler('./profiler_logs/bench_log_flash'),
            record_shapes=True,
            profile_memory=True,
            with_stack=False, # incurs an additional overhead, disable if not needed
            with_flops=True,
            with_modules=False, # only for torchscript models atm
        ) as prof:
            flash_attention(Q, K, V, mask)
        print(prof.key_averages().table(sort_by="cuda_memory_usage", row_limit=10))
    else:
        with torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
            on_trace_ready=torch.profiler.tensorboard_trace_handler('./profiler_logs/bench_log_normal'),
            record_shapes=True,
            profile_memory=True,
            with_stack=False, # incurs an additional overhead, disable if not needed
            with_flops=True,
            with_modules=False, # only for torchscript models atm
        ) as prof:
            normal_attention(Q, K, V, mask)
        print(prof.key_averages().table(sort_by="cuda_memory_usage", row_limit=10))

--------------------------------------------------------------------------------
/bench_causal.py:
--------------------------------------------------------------------------------
import time
import torch
from flash_attention_causal import flash_attention_causal, normal_attention_causal
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--type', type=str, required=True, help="flash/normal")
parser.add_argument('--b', type=int, required=False, default=1, help="Batch size")
parser.add_argument('--h', type=int, required=False, default=2, help="Number of heads")
parser.add_argument('--q_len', type=int, required=False, default=4096, help="Length/first dimension of Q matrix")
parser.add_argument('--kv_len', type=int, required=False, default=4096, help="Length/first dimension of K/V matrix")
parser.add_argument('--d', type=int, required=False, default=512, help="Dimension of vector")
parser.add_argument('--profile', action='store_true', help="For PyTorch profiling")

args = parser.parse_args()

Q = torch.randn(args.b, args.h, args.q_len, args.d, requires_grad=True).to(device='cuda')
K = torch.randn(args.b, args.h, args.kv_len, args.d, requires_grad=True).to(device='cuda')
V = torch.randn(args.b, args.h, args.kv_len, args.d, requires_grad=True).to(device='cuda')

if args.type == "flash":
    # Warm-up iterations so one-time CUDA initialisation does not distort the measurement
    for _ in range(10):
        flash_attention_causal(Q, K, V)

    torch.cuda.synchronize()  # CUDA kernels run asynchronously; synchronise around the timed call
    start = time.time_ns()
    flash_attention_causal(Q, K, V)
    torch.cuda.synchronize()
    end = time.time_ns()

    t = (end - start) / 1000000
    print(f'{t}ms')
else:
    # Warm-up iterations so one-time CUDA initialisation does not distort the measurement
    for _ in range(10):
        normal_attention_causal(Q, K, V)

    torch.cuda.synchronize()  # CUDA kernels run asynchronously; synchronise around the timed call
    start = time.time_ns()
    normal_attention_causal(Q, K, V)
    torch.cuda.synchronize()
    end = time.time_ns()

    t = (end - start) / 1000000
    print(f'{t}ms')

if args.profile:
    if args.type == "flash":
        with torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
            on_trace_ready=torch.profiler.tensorboard_trace_handler('./profiler_logs/bench_log_flash'),
            record_shapes=True,
            profile_memory=True,
            with_stack=False, # incurs an additional overhead, disable if not needed
            with_flops=True,
            with_modules=False, # only for torchscript models atm
        ) as prof:
            flash_attention_causal(Q, K, V)
        print(prof.key_averages().table(sort_by="cuda_memory_usage", row_limit=10))
    else:
        with torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
            on_trace_ready=torch.profiler.tensorboard_trace_handler('./profiler_logs/bench_log_normal'),
            record_shapes=True,
            profile_memory=True,
            with_stack=False, # incurs an additional overhead, disable if not needed
            with_flops=True,
            with_modules=False, # only for torchscript models atm
        ) as prof:
            normal_attention_causal(Q, K, V)
        print(prof.key_averages().table(sort_by="cuda_memory_usage", row_limit=10))

--------------------------------------------------------------------------------
/check_backward.py:
--------------------------------------------------------------------------------
import time
import torch
from flash_attention import flash_attention, normal_attention, flash_attention_backward, flash_attention_forward
from torch.func import jacrev

Q = torch.randn(1, 1, 2048, 512, requires_grad=True).to(device='cuda')
K = torch.randn(1, 1, 2048, 512, requires_grad=True).to(device='cuda')
V = torch.randn(1, 1, 2048, 512, requires_grad=True).to(device='cuda')
mask = torch.randint(0, 2, (1, 2048)).to(device='cuda')

def loss_fn(fn, *args):
    return torch.sum(fn(*args))

args = (Q, K, V, mask)

# jacrev of the scalar loss w.r.t. arguments 1-3 gives dL/dQ, dL/dK, dL/dV via autograd
dq_flash, dk_flash, dv_flash = jacrev(loss_fn, argnums=(1,2,3))(flash_attention, *args)
dq_normal, dk_normal, dv_normal = jacrev(loss_fn, argnums=(1,2,3))(normal_attention, *args)

# Gradients of FlashAttention and normal attention should match
print(torch.allclose(dq_flash, dq_normal, atol=1e-5))
print(torch.allclose(dk_flash, dk_normal, atol=1e-5))
print(torch.allclose(dv_flash, dv_normal, atol=1e-5))

O, l, m = flash_attention_forward(Q, K, V, mask)
dO = torch.ones_like(O) # Since "loss" here is the sum of the elements of the output matrix
dq_flash_manual, dk_flash_manual, dv_flash_manual = flash_attention_backward(Q, K, V, mask, O, l, m, dO)

# The manual backward pass from the paper should match the autograd gradients
print(torch.allclose(dq_flash, dq_flash_manual, atol=1e-5))
print(torch.allclose(dk_flash, dk_flash_manual, atol=1e-5))
print(torch.allclose(dv_flash, dv_flash_manual, atol=1e-5))

--------------------------------------------------------------------------------
/check_backward_causal.py:
--------------------------------------------------------------------------------
import time
import torch
from flash_attention_causal import flash_attention_causal, normal_attention_causal, flash_attention_causal_backward, flash_attention_causal_forward
from torch.func import jacrev

Q = torch.randn(1, 1, 2048, 512, requires_grad=True).to(device='cuda')
K = torch.randn(1, 1, 2048, 512, requires_grad=True).to(device='cuda')
V = torch.randn(1, 1, 2048, 512, requires_grad=True).to(device='cuda')

def loss_fn(fn, *args):
    return torch.sum(fn(*args))

args = (Q, K, V)

# jacrev of the scalar loss w.r.t. arguments 1-3 gives dL/dQ, dL/dK, dL/dV via autograd
dq_flash, dk_flash, dv_flash = jacrev(loss_fn, argnums=(1,2,3))(flash_attention_causal, *args)
dq_normal, dk_normal, dv_normal = jacrev(loss_fn, argnums=(1,2,3))(normal_attention_causal, *args)

# Gradients of FlashAttention and normal attention should match
print(torch.allclose(dq_flash, dq_normal, atol=1e-5))
print(torch.allclose(dk_flash, dk_normal, atol=1e-5))
print(torch.allclose(dv_flash, dv_normal, atol=1e-5))

O, l, m = flash_attention_causal_forward(Q, K, V)
dO = torch.ones_like(O) # Since "loss" here is the sum of the elements of the output matrix
dq_flash_manual, dk_flash_manual, dv_flash_manual = flash_attention_causal_backward(Q, K, V, O, l, m, dO)

# The manual backward pass from the paper should match the autograd gradients
print(torch.allclose(dq_flash, dq_flash_manual, atol=1e-5))
print(torch.allclose(dk_flash, dk_flash_manual, atol=1e-5))
print(torch.allclose(dv_flash, dv_flash_manual, atol=1e-5))

--------------------------------------------------------------------------------
/flash_attention.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np
import sys
import time
from einops import rearrange

BLOCK_SIZE = 1024
NEG_INF = -1e10 # -infinity
EPSILON = 1e-10
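# NEG_INF stands in for -infinity when masking attention logits, and EPSILON is added to each
# block's softmax denominator so that fully-masked rows do not divide by zero; both are
# implementation conveniences on top of the algorithm in the paper.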

def normal_attention(Q, K, V, mask=None):
    scale = 1 / np.sqrt(Q.shape[-1])
    Q = Q * scale
    QKt = torch.einsum('... i d, ... j d -> ... i j', Q, K)

    # Key mask: positions with mask == 0 get a large negative score before the softmax
    key_mask = rearrange(mask, 'b j -> b 1 1 j')
    QKt = torch.where(key_mask > 0, QKt, NEG_INF)

    attn = nn.functional.softmax(QKt, dim=-1)
    return attn @ V

def flash_attention_forward(Q, K, V, mask=None):
    O = torch.zeros_like(Q, requires_grad=True)
    l = torch.zeros(Q.shape[:-1])[...,None]
    m = torch.ones(Q.shape[:-1])[...,None] * NEG_INF

    O = O.to(device='cuda')
    l = l.to(device='cuda')
    m = m.to(device='cuda')

    Q_BLOCK_SIZE = min(BLOCK_SIZE, Q.shape[-1])
    KV_BLOCK_SIZE = BLOCK_SIZE

    # Split along the sequence dimension into Tr query blocks and Tc key/value blocks
    Q_BLOCKS = torch.split(Q, Q_BLOCK_SIZE, dim=2)
    K_BLOCKS = torch.split(K, KV_BLOCK_SIZE, dim=2)
    V_BLOCKS = torch.split(V, KV_BLOCK_SIZE, dim=2)
    mask_BLOCKS = list(torch.split(mask, KV_BLOCK_SIZE, dim=1))

    Tr = len(Q_BLOCKS)
    Tc = len(K_BLOCKS)

    O_BLOCKS = list(torch.split(O, Q_BLOCK_SIZE, dim=2))
    l_BLOCKS = list(torch.split(l, Q_BLOCK_SIZE, dim=2))
    m_BLOCKS = list(torch.split(m, Q_BLOCK_SIZE, dim=2))

    for j in range(Tc):
        Kj = K_BLOCKS[j]
        Vj = V_BLOCKS[j]
        maskj = mask_BLOCKS[j]

        for i in range(Tr):
            Qi = Q_BLOCKS[i]
            Oi = O_BLOCKS[i]
            li = l_BLOCKS[i]
            mi = m_BLOCKS[i]

            scale = 1 / np.sqrt(Q.shape[-1])
            Qi_scaled = Qi * scale

            S_ij = torch.einsum('... i d, ... j d -> ... i j', Qi_scaled, Kj)

            # Masking
            maskj_temp = rearrange(maskj, 'b j -> b 1 1 j')
            S_ij = torch.where(maskj_temp > 0, S_ij, NEG_INF)

            m_block_ij, _ = torch.max(S_ij, dim=-1, keepdims=True)
            P_ij = torch.exp(S_ij - m_block_ij)
            # Masking
            P_ij = torch.where(maskj_temp > 0, P_ij, 0.)

            l_block_ij = torch.sum(P_ij, dim=-1, keepdims=True) + EPSILON

            P_ij_Vj = torch.einsum('... i j, ... j d -> ... i d', P_ij, Vj)

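            # Online softmax update (the heart of FlashAttention): keep a running row-wise max m
            # and running normaliser l, and rescale the partial output O_i so that after the last
            # key/value block it equals softmax(S_i) @ V computed over the full row.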
            mi_new = torch.maximum(m_block_ij, mi)
            li_new = torch.exp(mi - mi_new) * li + torch.exp(m_block_ij - mi_new) * l_block_ij

            O_BLOCKS[i] = (li/li_new) * torch.exp(mi - mi_new) * Oi + (torch.exp(m_block_ij - mi_new) / li_new) * P_ij_Vj
            l_BLOCKS[i] = li_new
            m_BLOCKS[i] = mi_new

    O = torch.cat(O_BLOCKS, dim=2)
    l = torch.cat(l_BLOCKS, dim=2)
    m = torch.cat(m_BLOCKS, dim=2)
    return O, l, m

def flash_attention_backward(Q, K, V, mask, O, l, m, dO):
    Q_BLOCK_SIZE = min(BLOCK_SIZE, Q.shape[-1])
    KV_BLOCK_SIZE = BLOCK_SIZE

    Q_BLOCKS = torch.split(Q, Q_BLOCK_SIZE, dim=2)
    K_BLOCKS = torch.split(K, KV_BLOCK_SIZE, dim=2)
    V_BLOCKS = torch.split(V, KV_BLOCK_SIZE, dim=2)
    mask_BLOCKS = list(torch.split(mask, KV_BLOCK_SIZE, dim=1))

    Tr = len(Q_BLOCKS)
    Tc = len(K_BLOCKS)

    O_BLOCKS = list(torch.split(O, Q_BLOCK_SIZE, dim=2))
    dO_BLOCKS = list(torch.split(dO, Q_BLOCK_SIZE, dim=2))
    l_BLOCKS = list(torch.split(l, Q_BLOCK_SIZE, dim=2))
    m_BLOCKS = list(torch.split(m, Q_BLOCK_SIZE, dim=2))

    dQ = torch.zeros_like(Q, requires_grad=True).to(device='cuda')
    dK = torch.zeros_like(K, requires_grad=True).to(device='cuda')
    dV = torch.zeros_like(V, requires_grad=True).to(device='cuda')

    dQ_BLOCKS = list(torch.split(dQ, Q_BLOCK_SIZE, dim=2))
    dK_BLOCKS = list(torch.split(dK, KV_BLOCK_SIZE, dim=2))
    dV_BLOCKS = list(torch.split(dV, KV_BLOCK_SIZE, dim=2))

    for j in range(Tc):
        Kj = K_BLOCKS[j]
        Vj = V_BLOCKS[j]
        maskj = mask_BLOCKS[j]

        dKj_block = torch.zeros_like(dK_BLOCKS[j], requires_grad=True).to(device='cuda')
        dVj_block = torch.zeros_like(dV_BLOCKS[j], requires_grad=True).to(device='cuda')

        for i in range(Tr):
            Qi = Q_BLOCKS[i]
            Oi = O_BLOCKS[i]
            dOi = dO_BLOCKS[i]
            li = l_BLOCKS[i]
            mi = m_BLOCKS[i]

            scale = 1 / np.sqrt(Q.shape[-1])
            Qi_scaled = Qi * scale

            # Recompute S_ij and P_ij for this block pair from the saved statistics l and m,
            # instead of storing the full attention matrix from the forward pass
            S_ij = torch.einsum('... i d, ... j d -> ... i j', Qi_scaled, Kj)

            # Masking
            maskj_temp = rearrange(maskj, 'b j -> b 1 1 j')
            S_ij = torch.where(maskj_temp > 0, S_ij, NEG_INF)

            P_ij = (1/li) * torch.exp(S_ij - mi)
            # Masking
            P_ij = torch.where(maskj_temp > 0, P_ij, 0.)

            dVj_block = dVj_block + torch.einsum('... r c, ... r d -> ... c d', P_ij, dOi)
            dP_ij = torch.einsum('... r d, ... c d -> ... r c', dOi, Vj)

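            # Di = rowsum(dOi * Oi) is the softmax Jacobian correction term; the score gradient
            # is then dS_ij = P_ij * (dP_ij - Di), following the backward pass in the paper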
            Di = torch.sum(dOi * Oi, dim=-1, keepdims=True)
            dS_ij = P_ij * (dP_ij - Di)

            dQ_BLOCKS[i] = dQ_BLOCKS[i] + scale * torch.einsum('... r c, ... c d -> ... r d', dS_ij, Kj)

            dKj_block = dKj_block + scale * torch.einsum('... r c, ... r d -> ... c d', dS_ij, Qi)

        dK_BLOCKS[j] = dKj_block
        dV_BLOCKS[j] = dVj_block

    dQ = torch.cat(dQ_BLOCKS, dim=2)
    dK = torch.cat(dK_BLOCKS, dim=2)
    dV = torch.cat(dV_BLOCKS, dim=2)
    return dQ, dK, dV

def flash_attention(Q, K, V, mask):
    out = flash_attention_forward(Q, K, V, mask)
    return out[0]

if __name__ == "__main__":
    Q = torch.randn(1, 2, 4096, 1024, requires_grad=True).to(device='cuda')
    K = torch.randn(1, 2, 4096, 1024, requires_grad=True).to(device='cuda')
    V = torch.randn(1, 2, 4096, 1024, requires_grad=True).to(device='cuda')
    mask = torch.randint(0, 2, (1, 4096)).to(device='cuda')

    for i in range(10):
        start1 = time.time_ns()
        out1 = flash_attention(Q, K, V, mask)
        end1 = time.time_ns()

        start2 = time.time_ns()
        out2 = normal_attention(Q, K, V, mask)
        end2 = time.time_ns()

        t1 = (end1 - start1) / 1000000
        t2 = (end2 - start2) / 1000000

        print(f'{t1}ms, {t2}ms')
        print(torch.allclose(out1, out2, atol=1e-5))

--------------------------------------------------------------------------------
/flash_attention_causal.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np
import sys
import time
from einops import rearrange

BLOCK_SIZE = 1024
NEG_INF = -1e10 # -infinity
EPSILON = 1e-10

def normal_attention_causal(Q, K, V):
    scale = 1 / np.sqrt(Q.shape[-1])
    Q = Q * scale
    QKt = torch.einsum('... i d, ... j d -> ... i j', Q, K)

    Q_LEN = Q.shape[2]
    K_LEN = K.shape[2]

    # Upper-triangular mask marks positions where the key index is ahead of the (right-aligned) query index
    causal_mask = torch.triu(torch.ones((Q_LEN, K_LEN)), K_LEN - Q_LEN + 1)
    causal_mask = causal_mask.to(device='cuda')
    QKt = torch.where(causal_mask > 0, NEG_INF, QKt)

    attn = nn.functional.softmax(QKt, dim=-1)
    return attn @ V

def flash_attention_causal_forward(Q, K, V):
    O = torch.zeros_like(Q, requires_grad=True)
    l = torch.zeros(Q.shape[:-1])[...,None]
    m = torch.ones(Q.shape[:-1])[...,None] * NEG_INF

    O = O.to(device='cuda')
    l = l.to(device='cuda')
    m = m.to(device='cuda')

    Q_BLOCK_SIZE = min(BLOCK_SIZE, Q.shape[-1])
    KV_BLOCK_SIZE = BLOCK_SIZE

    Q_BLOCKS = torch.split(Q, Q_BLOCK_SIZE, dim=2)
    K_BLOCKS = torch.split(K, KV_BLOCK_SIZE, dim=2)
    V_BLOCKS = torch.split(V, KV_BLOCK_SIZE, dim=2)

    Tr = len(Q_BLOCKS)
    Tc = len(K_BLOCKS)

    O_BLOCKS = list(torch.split(O, Q_BLOCK_SIZE, dim=2))
    l_BLOCKS = list(torch.split(l, Q_BLOCK_SIZE, dim=2))
    m_BLOCKS = list(torch.split(m, Q_BLOCK_SIZE, dim=2))

    Q_LEN = Q.shape[2]
    K_LEN = K.shape[2]

    # Key-aligned absolute positions (queries are right-aligned against the keys when K_LEN > Q_LEN)
    Q_RANGE = torch.arange(Q_LEN)[:,None] + (K_LEN - Q_LEN)
    K_RANGE = torch.arange(K_LEN)[None,:]

    Q_RANGE = Q_RANGE.to(device='cuda')
    K_RANGE = K_RANGE.to(device='cuda')

    Q_RANGE_BLOCKS = torch.split(Q_RANGE, Q_BLOCK_SIZE, dim=0)
    K_RANGE_BLOCKS = torch.split(K_RANGE, KV_BLOCK_SIZE, dim=1)

    for j in range(Tc):
        Kj = K_BLOCKS[j]
        Vj = V_BLOCKS[j]
        K_RANGE_BLOCKSj = K_RANGE_BLOCKS[j]

        for i in range(Tr):
            Qi = Q_BLOCKS[i]
            Oi = O_BLOCKS[i]
            li = l_BLOCKS[i]
            mi = m_BLOCKS[i]
            Q_RANGE_BLOCKSi = Q_RANGE_BLOCKS[i]

            scale = 1 / np.sqrt(Q.shape[-1])
            Qi_scaled = Qi * scale

            S_ij = torch.einsum('... i d, ... j d -> ... i j', Qi_scaled, Kj)

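            # Block-wise causal mask: a query position may only attend to key positions at or before
            # it (positions come from Q_RANGE/K_RANGE above, so unequal Q/K lengths stay right-aligned)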
            # Masking
            causal_mask = Q_RANGE_BLOCKSi >= K_RANGE_BLOCKSj
            S_ij = torch.where(causal_mask > 0, S_ij, NEG_INF)

            m_block_ij, _ = torch.max(S_ij, dim=-1, keepdims=True)
            P_ij = torch.exp(S_ij - m_block_ij)
            # Masking
            P_ij = torch.where(causal_mask > 0, P_ij, 0)

            l_block_ij = torch.sum(P_ij, dim=-1, keepdims=True) + EPSILON

            P_ij_Vj = torch.einsum('... i j, ... j d -> ... i d', P_ij, Vj)

            # Online softmax update, as in flash_attention.py
            mi_new = torch.maximum(m_block_ij, mi)
            li_new = torch.exp(mi - mi_new) * li + torch.exp(m_block_ij - mi_new) * l_block_ij

            O_BLOCKS[i] = (li/li_new) * torch.exp(mi - mi_new) * Oi + (torch.exp(m_block_ij - mi_new) / li_new) * P_ij_Vj
            l_BLOCKS[i] = li_new
            m_BLOCKS[i] = mi_new

    O = torch.cat(O_BLOCKS, dim=2)
    l = torch.cat(l_BLOCKS, dim=2)
    m = torch.cat(m_BLOCKS, dim=2)
    return O, l, m

def flash_attention_causal_backward(Q, K, V, O, l, m, dO):
    Q_BLOCK_SIZE = min(BLOCK_SIZE, Q.shape[-1])
    KV_BLOCK_SIZE = BLOCK_SIZE

    Q_BLOCKS = torch.split(Q, Q_BLOCK_SIZE, dim=2)
    K_BLOCKS = torch.split(K, KV_BLOCK_SIZE, dim=2)
    V_BLOCKS = torch.split(V, KV_BLOCK_SIZE, dim=2)

    Tr = len(Q_BLOCKS)
    Tc = len(K_BLOCKS)

    O_BLOCKS = list(torch.split(O, Q_BLOCK_SIZE, dim=2))
    dO_BLOCKS = list(torch.split(dO, Q_BLOCK_SIZE, dim=2))
    l_BLOCKS = list(torch.split(l, Q_BLOCK_SIZE, dim=2))
    m_BLOCKS = list(torch.split(m, Q_BLOCK_SIZE, dim=2))

    dQ = torch.zeros_like(Q, requires_grad=True).to(device='cuda')
    dK = torch.zeros_like(K, requires_grad=True).to(device='cuda')
    dV = torch.zeros_like(V, requires_grad=True).to(device='cuda')

    dQ_BLOCKS = list(torch.split(dQ, Q_BLOCK_SIZE, dim=2))
    dK_BLOCKS = list(torch.split(dK, KV_BLOCK_SIZE, dim=2))
    dV_BLOCKS = list(torch.split(dV, KV_BLOCK_SIZE, dim=2))

    Q_LEN = Q.shape[2]
    K_LEN = K.shape[2]

    Q_RANGE = torch.arange(Q_LEN)[:,None] + (K_LEN - Q_LEN)
    K_RANGE = torch.arange(K_LEN)[None,:]

    Q_RANGE = Q_RANGE.to(device='cuda')
    K_RANGE = K_RANGE.to(device='cuda')

    Q_RANGE_BLOCKS = torch.split(Q_RANGE, Q_BLOCK_SIZE, dim=0)
    K_RANGE_BLOCKS = torch.split(K_RANGE, KV_BLOCK_SIZE, dim=1)

    for j in range(Tc):
        Kj = K_BLOCKS[j]
        Vj = V_BLOCKS[j]
        K_RANGE_BLOCKSj = K_RANGE_BLOCKS[j]

        dKj_block = torch.zeros_like(dK_BLOCKS[j], requires_grad=True).to(device='cuda')
        dVj_block = torch.zeros_like(dV_BLOCKS[j], requires_grad=True).to(device='cuda')

        for i in range(Tr):
            Qi = Q_BLOCKS[i]
            Oi = O_BLOCKS[i]
            dOi = dO_BLOCKS[i]
            li = l_BLOCKS[i]
            mi = m_BLOCKS[i]
            Q_RANGE_BLOCKSi = Q_RANGE_BLOCKS[i]

            scale = 1 / np.sqrt(Q.shape[-1])
            Qi_scaled = Qi * scale

            # Recompute S_ij and P_ij for this block pair from the saved statistics l and m
            S_ij = torch.einsum('... i d, ... j d -> ... i j', Qi_scaled, Kj)

            # Masking
            causal_mask = Q_RANGE_BLOCKSi >= K_RANGE_BLOCKSj
            S_ij = torch.where(causal_mask > 0, S_ij, NEG_INF)

            P_ij = (1/li) * torch.exp(S_ij - mi)
            # Masking
            P_ij = torch.where(causal_mask > 0, P_ij, 0)

            dVj_block = dVj_block + torch.einsum('... r c, ... r d -> ... c d', P_ij, dOi)
            dP_ij = torch.einsum('... r d, ... c d -> ... r c', dOi, Vj)

            Di = torch.sum(dOi * Oi, dim=-1, keepdims=True)
            dS_ij = P_ij * (dP_ij - Di)

            dQ_BLOCKS[i] = dQ_BLOCKS[i] + scale * torch.einsum('... r c, ... c d -> ... r d', dS_ij, Kj)

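            # dQ for query block i accumulates a contribution from every key/value block j, while
            # dKj_block and dVj_block accumulate over all query blocks before being written back below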
            dKj_block = dKj_block + scale * torch.einsum('... r c, ... r d -> ... c d', dS_ij, Qi)

        dK_BLOCKS[j] = dKj_block
        dV_BLOCKS[j] = dVj_block

    dQ = torch.cat(dQ_BLOCKS, dim=2)
    dK = torch.cat(dK_BLOCKS, dim=2)
    dV = torch.cat(dV_BLOCKS, dim=2)
    return dQ, dK, dV

def flash_attention_causal(Q, K, V):
    out = flash_attention_causal_forward(Q, K, V)
    return out[0]

if __name__ == "__main__":
    Q = torch.randn(1, 2, 4096, 1024, requires_grad=True).to(device='cuda')
    K = torch.randn(1, 2, 4096, 1024, requires_grad=True).to(device='cuda')
    V = torch.randn(1, 2, 4096, 1024, requires_grad=True).to(device='cuda')

    for i in range(10):
        start1 = time.time_ns()
        out1 = flash_attention_causal(Q, K, V)
        end1 = time.time_ns()

        start2 = time.time_ns()
        out2 = normal_attention_causal(Q, K, V)
        end2 = time.time_ns()

        t1 = (end1 - start1) / 1000000
        t2 = (end2 - start2) / 1000000

        print(f'{t1}ms, {t2}ms')
        print(torch.allclose(out1, out2, atol=1e-5))

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
einops==0.6.1
torch==2.0.1
--------------------------------------------------------------------------------