├── .gitignore
├── common.py
├── test.py
├── README.md
├── mha.py
├── benchmark.py
└── flash_mha.py
/.gitignore:
--------------------------------------------------------------------------------
__pycache__/
--------------------------------------------------------------------------------
/common.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass
import torch

if torch.cuda.is_available():
    torch.cuda.manual_seed_all(1)
    DEVICE = torch.device("cuda:0")
    # Ensure that all operations are deterministic on the GPU for reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
else:
    raise RuntimeError("CUDA is not available; this code requires a GPU.")


MASKOUT_VAL = -float("inf")
EPS = 1e-10


@dataclass
class ModelParams:
    d_model: int
    num_heads: int
    block_size: int
    d_ffn: int


@dataclass
class QKV:
    Q: torch.Tensor
    K: torch.Tensor
    V: torch.Tensor
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import unittest
import torch

from flash_mha import MultiheadFlashAttention
from mha import MultiHeadAttention
import common


class TestFlashAttention(unittest.TestCase):
    def test_equivalence(self):
        # Parameters
        seq_len = 64
        batch = 10
        num_heads = 4
        d_model = 100
        block_size = 16
        device = common.DEVICE
        test_data = 0.01 * torch.randn(batch, seq_len, d_model).to(device)
        mha = MultiHeadAttention(d_model=d_model, num_heads=num_heads).to(device)
        flash_mha = MultiheadFlashAttention(
            d_model=d_model,
            num_heads=num_heads,
            block_size=block_size,
        ).to(device)
        for param1, param2 in zip(mha.parameters(), flash_mha.parameters()):
            param2.data = param1.data
        # Sanity check: the two modules must now have identical parameters.
        for param1, param2 in zip(mha.parameters(), flash_mha.parameters()):
            if (param2.data != param1.data).any():
                raise ValueError("The two modules have different parameters!")

        flash_out = flash_mha(test_data)
        mha_out = mha(test_data)
        self.assertEqual(flash_out.shape, mha_out.shape)
        torch.testing.assert_close(flash_out, mha_out)


if __name__ == "__main__":
    unittest.main()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# ⚡ Experimental Flash Attention implementation ⚡

A basic, pure PyTorch implementation of flash attention.

The codebase is written mainly for educational purposes and is not meant for production
or anything serious. For more practical use cases, consider using [FlexAttention](https://pytorch.org/blog/flexattention/).

Refer to the original paper for the details of the algorithm. This implementation
is based on Algorithm 1 of the paper:

```
Dao, Tri, Dan Fu, Stefano Ermon, Atri Rudra, and Christopher Ré. "FlashAttention: Fast and memory-efficient exact attention with IO-awareness."
Advances in Neural Information Processing Systems 35 (2022): 16344-16359.
```

The original implementation relies on custom CUDA kernels to fuse operations and to move data
between HBM and SRAM; this implementation does not attempt any of that.
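
For each query block $Q_i$ and key/value block $(K_j, V_j)$, Algorithm 1 keeps a running row-wise maximum $m_i$, a running softmax normalizer $\ell_i$, and a partial output $O_i$, and applies the streaming ("online") softmax update below, written here in the paper's notation; `flash_mha.py` mirrors it with its `m`, `l`, and `O` buffers:

```math
\begin{aligned}
S_{ij} &= \tfrac{1}{\sqrt{d}}\, Q_i K_j^\top, \qquad
\tilde{m}_{ij} = \operatorname{rowmax}(S_{ij}), \qquad
\tilde{P}_{ij} = \exp(S_{ij} - \tilde{m}_{ij}), \qquad
\tilde{\ell}_{ij} = \operatorname{rowsum}(\tilde{P}_{ij}), \\
m_i^{\mathrm{new}} &= \max(m_i, \tilde{m}_{ij}), \qquad
\ell_i^{\mathrm{new}} = e^{m_i - m_i^{\mathrm{new}}}\,\ell_i + e^{\tilde{m}_{ij} - m_i^{\mathrm{new}}}\,\tilde{\ell}_{ij}, \\
O_i &\leftarrow \frac{\ell_i\, e^{m_i - m_i^{\mathrm{new}}}\, O_i + e^{\tilde{m}_{ij} - m_i^{\mathrm{new}}}\, \tilde{P}_{ij} V_j}{\ell_i^{\mathrm{new}}}.
\end{aligned}
```

Once every key/value block has been folded in, each $O_i$ equals the exact softmax attention output for its query block, so the result matches vanilla attention without ever materializing the full attention matrix.
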
## Implementation
There are four important files:
- `flash_mha.py`: A very basic implementation of flash attention in PyTorch.
For educational purposes only.
- `mha.py`: A vanilla implementation of multi-head attention.
- `test.py`: A unit test that checks that the forward pass of flash attention matches
the vanilla multi-head attention.
- `benchmark.py`: Speed and memory comparisons between flash attention and
the vanilla multi-head attention.

To run the test, simply run:

`python test.py`

To run the benchmark, simply run:

`python benchmark.py`
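
The modules can also be used directly. A minimal usage sketch (the sizes below are arbitrary; the projection layers are shared so that both modules compute the same function, as in `benchmark.py`):

```python
import torch

import common
from flash_mha import MultiheadFlashAttention
from mha import MultiHeadAttention

# A random input batch of shape [batch_size, seq_len, d_model] on the CUDA device.
x = 0.01 * torch.randn(2, 256, 128).to(common.DEVICE)

mha = MultiHeadAttention(d_model=128, num_heads=8).to(common.DEVICE)
flash = MultiheadFlashAttention(
    d_model=128, num_heads=8, block_size=64,
    proj=mha.proj, out_proj=mha.out_proj,
).to(common.DEVICE)

# Both outputs have shape [2, 256, 128] and should agree up to numerical tolerance.
torch.testing.assert_close(flash(x), mha(x))
```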
--------------------------------------------------------------------------------
/mha.py:
--------------------------------------------------------------------------------
# A simple multi-head attention implementation
import math

from einops import einsum
import torch
from torch import nn

import common


class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        d_model: int = None,
        num_heads: int = None,
        proj: nn.Module = None,
        out_proj: nn.Module = None,
    ):
        super().__init__()
        assert d_model % num_heads == 0, "d_model should be divisible by num_heads."

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.scaling = 1.0 / math.sqrt(self.d_head)
        self.out_proj = out_proj if out_proj else nn.Linear(d_model, d_model)
        self.proj = proj if proj else nn.Linear(d_model, 3 * d_model, bias=False)

    def scaled_dot_product_attention(self, qkv: common.QKV, mask: torch.Tensor = None):
        # Attention scores: [batch, num_heads, q_length, kv_length]
        score = (
            einsum(
                qkv.Q,
                qkv.K,
                "b q_length n d, b kv_length n d -> b n q_length kv_length",
            ) * self.scaling
        )
        if mask is not None:
            score = score.masked_fill(mask == 0, -1e9)

        softmax_scores = torch.softmax(score, dim=-1)
        output = einsum(
            softmax_scores,
            qkv.V,
            "b n q_length kv_length, b kv_length n d -> b q_length n d",
        )
        return output

    def combine_heads(self, x):
        # [batch, seq_len, num_heads, d_head] -> [batch, seq_len, d_model]
        batch_size, seq_len, *_ = x.size()
        return x.contiguous().view(batch_size, seq_len, -1)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
        # x: [batch_size, seq_len, d_model]
        QKV_proj = self.proj(x)
        Q, K, V = torch.split(QKV_proj, self.d_model, dim=-1)
        # Reshape Q, K, V to [batch_size, seq_len, num_heads, d_head].
        Q, K, V = map(
            lambda t: t.view(t.size(0), t.size(1), self.num_heads, -1),
            (Q, K, V),
        )
        qkv = common.QKV(Q=Q, K=K, V=V)
        attention_out = self.scaled_dot_product_attention(qkv, mask=mask)
        output = self.out_proj(self.combine_heads(attention_out))
        return output
--------------------------------------------------------------------------------
/benchmark.py:
--------------------------------------------------------------------------------
import argparse
import logging
import os
import time

import numpy as np
import torch

from flash_mha import MultiheadFlashAttention
from mha import MultiHeadAttention
import common

logging.basicConfig(level=logging.INFO)

if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True
    os.environ['USE_KINETO'] = "1"

parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', default=2, required=False, type=int)
parser.add_argument('--seq_len', default=8192, required=False, type=int)
parser.add_argument('--d_model', default=1024, required=False, type=int)
parser.add_argument('--num_heads', default=16, required=False, type=int)
parser.add_argument('--block_size', default=512, required=False, type=int)
# argparse's type=bool treats any non-empty string as True, so parse the flag explicitly.
parser.add_argument('--profile', default=True, required=False,
                    type=lambda s: str(s).lower() not in ("0", "false", "no"))
args = parser.parse_args()


def time_attention(data: torch.Tensor = None,
                   num_trials: int = 10,
                   attention_module: torch.nn.Module = None):

    ns_per_ms = 1_000_000
    device = common.DEVICE
    attention_module = attention_module.to(device)
    data = data.to(device)
    # Warmup so that kernel compilation and caching do not pollute the timings:
    for _ in range(10):
        attention_module(data)
    torch.cuda.synchronize()
    logging.warning(f"Memory allocated: {torch.cuda.memory_allocated()}")
    # The real timing starts now:
    times = []
    for _ in range(num_trials):
        start = time.time_ns()
        attention_module(data)
        # CUDA kernels run asynchronously; wait for them before stopping the clock.
        torch.cuda.synchronize()
        end = time.time_ns()
        duration = (end - start) / ns_per_ms
        times.append(duration)
    logging.warning(f"Average time: {np.mean(times)} ms, std: {np.std(times)} ms.")
    attention_module.cpu()
    torch.cuda.empty_cache()
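
# Note: an alternative to the hand-rolled timing loop above is torch.utils.benchmark.Timer,
# which is meant to take care of warmup and CUDA synchronization itself. A rough sketch:
#
#   from torch.utils.benchmark import Timer
#   timer = Timer(stmt="module(data)", globals={"module": attention_module, "data": data})
#   print(timer.timeit(10))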


def run_profile(data: torch.Tensor = None,
                num_trials: int = 10,
                attention_module: torch.nn.Module = None,
                trace_name: str = "bench_log"):
    device = common.DEVICE
    attention_module = attention_module.to(device)
    data = data.to(device)
    with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
        on_trace_ready=torch.profiler.tensorboard_trace_handler(f'/tmp/profiler_logs/{trace_name}'),
        record_shapes=True,
        profile_memory=True,
        with_stack=False,  # incurs an additional overhead, disable if not needed
        with_flops=True,
        with_modules=False,  # only for torchscript models atm
    ) as prof:
        for _ in range(num_trials):
            attention_module(data)
    logging.warning(prof.key_averages().table(sort_by="cuda_memory_usage", row_limit=10))
    logging.warning(f"Memory allocated: {torch.cuda.memory_allocated()}")
    attention_module.cpu()
    torch.cuda.empty_cache()


def run_benchmarks():
    data = 0.01 * torch.randn((args.batch_size, args.seq_len, args.d_model))
    mha_module = MultiHeadAttention(d_model=args.d_model, num_heads=args.num_heads)
    # Share the projection layers so that both modules compute exactly the same function.
    proj = mha_module.proj
    out_proj = mha_module.out_proj
    flash_mha_module = MultiheadFlashAttention(d_model=args.d_model, num_heads=args.num_heads,
                                               block_size=args.block_size, proj=proj, out_proj=out_proj)
    logging.warning("**Benchmarking flash multihead attention...**")
    time_attention(data, attention_module=flash_mha_module)
    logging.warning("========")
    logging.warning("**Benchmarking the regular multihead attention...**")
    time_attention(data, attention_module=mha_module)
    logging.warning("Finished timing!")
    if args.profile:
        logging.warning("**Profiling flash multihead attention...**")
        run_profile(data, num_trials=20, attention_module=flash_mha_module, trace_name="bench_log_flash")
        logging.warning("========")
        logging.warning("**Profiling the regular multihead attention...**")
        run_profile(data, num_trials=20, attention_module=mha_module, trace_name="bench_log_mha")
        logging.warning("**Finished profiling!**")


if __name__ == "__main__":
    run_benchmarks()
--------------------------------------------------------------------------------
/flash_mha.py:
--------------------------------------------------------------------------------
# A basic multi-head flash attention implementation
import math

from einops import einsum, rearrange
import torch
from torch import nn

import common


class FlashAttention(nn.Module):
    def __init__(self, block_size=64):
        super().__init__()
        self.block_size = block_size
        self._initialized = False

    def _init_vars(self, size: torch.Size, device: torch.device):
        # Running output O, softmax normalizer l, and row-wise maximum m.
        O = torch.zeros(size)
        l = torch.zeros(size[:-1])
        m = torch.ones(size[:-1]) * common.MASKOUT_VAL
        self.O, self.l, self.m = map(lambda x: x.to(device), (O, l, m))
        self._initialized = True

    def forward(self, qkv: common.QKV, mask=None):
        if not self._initialized or self.O.shape != qkv.Q.shape:
            # (Re-)initialize the running buffers whenever the input shape changes.
            self._init_vars(qkv.Q.shape, qkv.Q.device)
        O, l, m = self.O, self.l, self.m
        q_seqlen = qkv.Q.shape[1]
        kv_seqlen = qkv.K.shape[1]
        Q_block_size = min(self.block_size, q_seqlen)
        KV_block_size = min(self.block_size, kv_seqlen)

        Q_blocks = torch.split(qkv.Q, Q_block_size, dim=1)
        O_blocks = list(torch.split(O, Q_block_size, dim=1))
        l_blocks = list(torch.split(l, Q_block_size, dim=1))
        m_blocks = list(torch.split(m, Q_block_size, dim=1))

        K_blocks = torch.split(qkv.K, KV_block_size, dim=1)
        V_blocks = torch.split(qkv.V, KV_block_size, dim=1)
        if mask is not None:
            mask_blocks = torch.split(mask, KV_block_size, dim=1)

        scale = 1.0 / math.sqrt(qkv.Q.shape[-1])
        num_q_blocks = len(Q_blocks)
        num_kv_blocks = len(K_blocks)
        for j in range(num_kv_blocks):
            K_block_j, V_block_j = K_blocks[j], V_blocks[j]
            if mask is not None:
                mask_block_j = mask_blocks[j]
            for i in range(num_q_blocks):
                Q_block_i = Q_blocks[i]
                if j == 0:
                    # On the first pass the blocks are still in [b, l, n, d] layout;
                    # move them to [b, n, l, d] once and keep them that way afterwards.
                    O_block_i = rearrange(O_blocks[i], "b l n d -> b n l d")
                    m_i = rearrange(m_blocks[i], "b l n -> b n l 1")
                    l_i = rearrange(l_blocks[i], "b l n -> b n l 1")
                else:
                    O_block_i = O_blocks[i]
                    m_i = m_blocks[i]
                    l_i = l_blocks[i]

                # Attention scores for this (query block, key/value block) pair.
                S_ij = (
                    einsum(
                        Q_block_i,
                        K_block_j,
                        "b block_len_i n d, b block_len_j n d -> b n block_len_i block_len_j",
                    )
                    * scale
                )
                if mask is not None:
                    S_ij = S_ij.masked_fill(mask_block_j == 0, common.MASKOUT_VAL)
                # Block-local softmax statistics ...
                m_ij = torch.max(S_ij, dim=-1, keepdim=True).values
                P_ij = torch.exp(S_ij - m_ij)
                l_ij = P_ij.sum(-1, keepdim=True) + common.EPS
                # ... merged with the running statistics of the query block.
                m_i_new = torch.maximum(m_i, m_ij)
                l_i_new = (
                    torch.exp(m_i - m_i_new) * l_i + torch.exp(m_ij - m_i_new) * l_ij
                )
                O_block_i_new = einsum(
                    P_ij,
                    V_block_j,
                    "b n q_length kv_length, b kv_length n d -> b n q_length d",
                )
                # Rescale the old partial output and fold in the new contribution.
                O_blocks[i] = (l_i * torch.exp(m_i - m_i_new) * O_block_i +
                               torch.exp(m_ij - m_i_new) * O_block_i_new) / l_i_new
                l_blocks[i] = l_i_new
                m_blocks[i] = m_i_new
        attn_output = rearrange(torch.cat(O_blocks, dim=2), "b n l d -> b l n d")
        return attn_output
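
# Usage sketch (illustrative): FlashAttention consumes a common.QKV whose tensors are laid
# out as [batch, seq_len, num_heads, d_head] and returns the attention output in the same
# layout, e.g.:
#
#   q = k = v = torch.randn(2, 128, 4, 32, device=common.DEVICE)
#   out = FlashAttention(block_size=32)(common.QKV(Q=q, K=k, V=v))  # -> [2, 128, 4, 32]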


class MultiheadFlashAttention(nn.Module):

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        block_size: int,
        proj: nn.Module = None,
        out_proj: nn.Module = None,
    ):
        super().__init__()
        assert d_model % num_heads == 0, "d_model should be divisible by num_heads."
        assert (
            block_size <= d_model
        ), "the block size must be less than or equal to d_model."
        self.attention = FlashAttention(block_size=block_size)
        self.d_model = d_model
        self.num_heads = num_heads
        self.out_proj = out_proj if out_proj else nn.Linear(d_model, d_model)
        self.proj = proj if proj else nn.Linear(d_model, 3 * d_model, bias=False)

    def combine_heads(self, x):
        # [batch, seq_len, num_heads, d_head] -> [batch, seq_len, d_model]
        batch_size, seq_len, *_ = x.size()
        return x.contiguous().view(batch_size, seq_len, -1)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
        # x: [batch_size, seq_len, d_model]
        QKV_proj = self.proj(x)
        Q, K, V = torch.split(QKV_proj, self.d_model, dim=-1)

        # Reshape Q, K, V to [batch_size, seq_len, num_heads, d_head].
        Q, K, V = map(
            lambda t: t.view(t.size(0), t.size(1), self.num_heads, -1),
            (Q, K, V),
        )
        qkv = common.QKV(Q=Q, K=K, V=V)
        attention_out = self.attention(qkv, mask=mask)
        output = self.out_proj(self.combine_heads(attention_out))
        return output
--------------------------------------------------------------------------------