├── requirements.txt
├── data
│   ├── results-causal.png
│   ├── results-random.png
│   └── results-causal-fa.png
├── fa2_custom_mask
│   ├── __init__.py
│   ├── utils.py
│   ├── fa2_custom_mask.py
│   ├── fa2_fwd.py
│   └── fa2_bwd.py
├── setup.py
├── pyproject.toml
├── .gitignore
├── README.md
├── LICENSE
├── test_benchmark.py
└── fa2_original.py
/requirements.txt: -------------------------------------------------------------------------------- 1 | pytest~=8.3.1 2 | triton~=3.0.0 3 | torch~=2.2.1 4 | numpy==1.26.4 -------------------------------------------------------------------------------- /data/results-causal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexzhang13/flashattention2-custom-mask/HEAD/data/results-causal.png -------------------------------------------------------------------------------- /data/results-random.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexzhang13/flashattention2-custom-mask/HEAD/data/results-random.png -------------------------------------------------------------------------------- /data/results-causal-fa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexzhang13/flashattention2-custom-mask/HEAD/data/results-causal-fa.png -------------------------------------------------------------------------------- /fa2_custom_mask/__init__.py: -------------------------------------------------------------------------------- 1 | from fa2_custom_mask.fa2_custom_mask import flash_attention_custom_mask 2 | from fa2_custom_mask.fa2_fwd import _attn_fwd 3 | from fa2_custom_mask.fa2_bwd import _attn_bwd_preprocess, _attn_bwd 4 | from fa2_custom_mask.utils import is_hip, keep -------------------------------------------------------------------------------- /fa2_custom_mask/utils.py: -------------------------------------------------------------------------------- 1 | import triton 2 | 3 | def is_hip(): 4 | return False 5 | # bugged in older versions of Triton 6 | # return triton.runtime.driver.active.get_current_target().backend == "hip" 7 | 8 | 9 | def keep(conf): 10 | BLOCK_M = conf.kwargs["BLOCK_M"] 11 | BLOCK_N = conf.kwargs["BLOCK_N"] 12 | if BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8: 13 | return False 14 | return True -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r", encoding="utf-8") as f: 4 | long_description = f.read() 5 | 6 | setuptools.setup( 7 | name="flashattention2-custom-mask", 8 | version="0.1.0", 9 | packages=["fa2_custom_mask"], 10 | description='Unofficial implementation of FlashAttention2 with Custom Masks', 11 | long_description=long_description, 12 | long_description_content_type="text/markdown", 13 | install_requires=[ 14 | "triton", 15 | "torch", 16 | ], 17 | python_requires=">=3.8", 18 | ) 19 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "flashattention2-custom-mask" 3 | version = "0.1.0" 4 | description = 'Unofficial FlashAttention2 with Custom Masks' 5 | readme = "README.md" 6 | requires-python = ">=3.8" 7 | authors = [{ name = "Alex Zhang", email = "alzhang@alumni.princeton.edu" }] 8 | keywords =
["flash attention", "triton", "pytorch"] 9 | dynamic=["dependencies"] 10 | 11 | 12 | [tool.setuptools.dynamic] 13 | dependencies = {file = ["requirements.txt"]} 14 | 15 | [tool.setuptools.packages.find] 16 | where = ["."] # list of folders that contain the packages (["."] by default) 17 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Data 2 | results/ 3 | *.DS_Store 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # poetry 102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 103 | # This is especially recommended for binary packages to ensure reproducibility, and is more 104 | # commonly ignored for libraries. 105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 106 | #poetry.lock 107 | 108 | # pdm 109 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 110 | #pdm.lock 111 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 112 | # in version control. 113 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 114 | .pdm.toml 115 | .pdm-python 116 | .pdm-build/ 117 | 118 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 119 | __pypackages__/ 120 | 121 | # Celery stuff 122 | celerybeat-schedule 123 | celerybeat.pid 124 | 125 | # SageMath parsed files 126 | *.sage.py 127 | 128 | # Environments 129 | .env 130 | .venv 131 | env/ 132 | venv/ 133 | ENV/ 134 | env.bak/ 135 | venv.bak/ 136 | 137 | # Spyder project settings 138 | .spyderproject 139 | .spyproject 140 | 141 | # Rope project settings 142 | .ropeproject 143 | 144 | # mkdocs documentation 145 | /site 146 | 147 | # mypy 148 | .mypy_cache/ 149 | .dmypy.json 150 | dmypy.json 151 | 152 | # Pyre type checker 153 | .pyre/ 154 | 155 | # pytype static type analyzer 156 | .pytype/ 157 | 158 | # Cython debug symbols 159 | cython_debug/ 160 | 161 | # PyCharm 162 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 163 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 164 | # and can be added to the global gitignore or merged into this file. For a more nuclear 165 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 166 | #.idea/ 167 | -------------------------------------------------------------------------------- /fa2_custom_mask/fa2_custom_mask.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import triton 3 | 4 | from fa2_custom_mask.fa2_fwd import _attn_fwd 5 | from fa2_custom_mask.fa2_bwd import _attn_bwd_preprocess, _attn_bwd 6 | from fa2_custom_mask.utils import is_hip 7 | 8 | class _attention(torch.autograd.Function): 9 | 10 | @staticmethod 11 | def forward(ctx, q, k, v, mask=None, sm_scale=1.3): 12 | # shape constraints 13 | HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1] 14 | USE_MASK = mask is not None 15 | # when v is in float8_e5m2 it is transposed. 
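# Expected layouts (see README / test_benchmark.py): q, k, v are (batch, heads, seq_len, head_dim);
# mask, if provided, is expected to be (or broadcast to) (batch, heads, seq_len, seq_len), where nonzero
# entries are attended to and zero entries are masked out (a large negative bias is added to those logits).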
16 | HEAD_DIM_V = v.shape[-1] 17 | assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V 18 | assert HEAD_DIM_K in {16, 32, 64, 128, 256} 19 | o = torch.empty_like(q) 20 | 21 | # TODO: verify this means mask is not None 22 | stage = 3 if mask is not None else 2 23 | extra_kern_args = {} 24 | # Tuning for AMD target 25 | if is_hip(): 26 | waves_per_eu = 3 if HEAD_DIM_K <= 64 else 2 27 | extra_kern_args = {"waves_per_eu": waves_per_eu, "allow_flush_denorm": True} 28 | 29 | grid = lambda args: (triton.cdiv(q.shape[2], args["BLOCK_M"]), q.shape[0] * q.shape[1], 1) 30 | M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) 31 | 32 | mask_stride_0 = (None if not USE_MASK else mask.stride(0)) 33 | mask_stride_1 = (None if not USE_MASK else mask.stride(1)) 34 | mask_stride_2 = (None if not USE_MASK else mask.stride(2)) 35 | mask_stride_3 = (None if not USE_MASK else mask.stride(3)) 36 | 37 | _attn_fwd[grid]( 38 | q, k, v, mask, sm_scale, M, o, # 39 | q.stride(0), q.stride(1), q.stride(2), q.stride(3), # 40 | k.stride(0), k.stride(1), k.stride(2), k.stride(3), # 41 | v.stride(0), v.stride(1), v.stride(2), v.stride(3), # 42 | o.stride(0), o.stride(1), o.stride(2), o.stride(3), # 43 | mask_stride_0, mask_stride_1, mask_stride_2, mask_stride_3, # 44 | q.shape[0], q.shape[1], # 45 | N_CTX=q.shape[2], # 46 | HEAD_DIM=HEAD_DIM_K, # 47 | STAGE=stage, 48 | USE_MASK=USE_MASK, # 49 | **extra_kern_args) 50 | 51 | ctx.save_for_backward(q, k, v, o, mask, M) 52 | ctx.grid = grid 53 | ctx.sm_scale = sm_scale 54 | ctx.HEAD_DIM = HEAD_DIM_K 55 | ctx.USE_MASK = USE_MASK 56 | return o 57 | 58 | @staticmethod 59 | def backward(ctx, do): 60 | q, k, v, o, mask, M = ctx.saved_tensors 61 | assert do.is_contiguous() 62 | assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride() 63 | dq = torch.empty_like(q) 64 | dk = torch.empty_like(k) 65 | dv = torch.empty_like(v) 66 | BATCH, N_HEAD, N_CTX = q.shape[:3] 67 | PRE_BLOCK = 128 68 | NUM_WARPS, NUM_STAGES = 4, 5 69 | BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 70 | BLK_SLICE_FACTOR = 2 71 | RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2) 72 | arg_k = k 73 | arg_k = arg_k * (ctx.sm_scale * RCP_LN2) 74 | PRE_BLOCK = 128 75 | assert N_CTX % PRE_BLOCK == 0 76 | pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD) 77 | delta = torch.empty_like(M) 78 | _attn_bwd_preprocess[pre_grid]( 79 | o, do, # 80 | delta, # 81 | BATCH, N_HEAD, N_CTX, # 82 | BLOCK_M=PRE_BLOCK, HEAD_DIM=ctx.HEAD_DIM # 83 | ) 84 | grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD) 85 | 86 | if ctx.USE_MASK: 87 | _attn_bwd[grid]( 88 | q, arg_k, v, mask, ctx.sm_scale, do, dq, dk, dv, # 89 | M, delta, # 90 | q.stride(0), q.stride(1), q.stride(2), q.stride(3), # 91 | mask.stride(0), mask.stride(1), mask.stride(2), mask.stride(3), # 92 | N_HEAD, N_CTX, # 93 | BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, # 94 | BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2, # 95 | BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, # 96 | HEAD_DIM=ctx.HEAD_DIM, 97 | USE_MASK=ctx.USE_MASK, # 98 | num_warps=NUM_WARPS, # 99 | num_stages=NUM_STAGES, # 100 | ) 101 | else: 102 | _attn_bwd[grid]( 103 | q, arg_k, v, None, ctx.sm_scale, do, dq, dk, dv, # 104 | M, delta, # 105 | q.stride(0), q.stride(1), q.stride(2), q.stride(3), # 106 | None, None, None, None, # 107 | N_HEAD, N_CTX, # 108 | BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, # 109 | BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2, # 110 | BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, # 111 | HEAD_DIM=ctx.HEAD_DIM, 112 | USE_MASK=ctx.USE_MASK,# 113 | num_warps=NUM_WARPS, # 114 | 
num_stages=NUM_STAGES # 115 | ) 116 | 117 | return dq, dk, dv, None, None 118 | 119 | flash_attention_custom_mask = _attention.apply 120 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FlashAttention2 with Custom Masks 🎭 2 | **Note: This is an unofficial implementation of FlashAttention2.** 3 | 4 | For efficiency purposes, the standard implementations of FlashAttention currently do not support **arbitrary custom masks**. 5 | Specific masks such as the causal mask used for language modeling are instead implemented with branch logic to save memory. This repository is a modified version of the tutorial Triton implementation of FlashAttention2 that allows the user 6 | to define a custom mask (or a batch of them). It modifies both the forward and backward pass to handle custom masking (you can define a different mask per head and per batch element). 7 | 8 | Original Triton code: [https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html](https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html) 9 | 10 | See the original thread: [https://github.com/Dao-AILab/flash-attention/issues/352](https://github.com/Dao-AILab/flash-attention/issues/352) 11 | 12 | ## Quick Install 13 | Create a Python environment (>=3.8) and install through pip: 14 | ``` 15 | pip install flashattention2-custom-mask 16 | ``` 17 | 18 | ## Example Setup 19 | The relevant libraries needed to use the custom-mask FlashAttention2 kernel are below: 20 | ``` 21 | pip install triton>=3.0.0 22 | pip install torch 23 | ``` 24 | 25 | #### For Viewing Benchmarking Results 26 | Other libraries for evaluating the performance of the different implementations are listed below. These are primarily for `test_benchmark.py`, which also verifies the correctness of the implementation. 27 | ``` 28 | pip install pytest 29 | pip install matplotlib 30 | pip install pandas 31 | ``` 32 | To compare with the official FlashAttention and `xformers.ops.memory_efficient_attention` implementations, make sure to install both libraries separately (follow the instructions on these repositories). 33 | ``` 34 | pip install flash-attn --no-build-isolation 35 | pip3 install -U xformers --index-url https://download.pytorch.org/whl/cu121 36 | ``` 37 | 38 | ## Testing Correctness 39 | There are two `pytest` functions in `test_benchmark.py`: one tests whether a reference implementation of multi-head attention with a causal mask matches the Triton version in both the forward pass and the backward-pass gradients; the second runs the same check with **random** masks. You can modify these tests to do more rigorous correctness checks and run them with `pytest`. 40 | 41 | ## Simple Example 42 | You can insert this module into your standard attention pipeline. 43 | ```python 44 | from fa2_custom_mask import flash_attention_custom_mask 45 | 46 | B, H, L, D = 4, 16, 4096, 64 47 | sm_scale = 1 / (D ** 0.5) 48 | 49 | fp32_q = torch.randn(B, H, L, D).float().cuda() 50 | fp32_k = torch.randn(B, H, L, D).float().cuda() 51 | fp32_v = torch.randn(B, H, L, D).float().cuda() 52 | mask = torch.randint(0, 2, (B, 1, L, L)).int().cuda() 53 | mask = torch.broadcast_to(mask, (B, H, L, L)) 54 | 55 | out = flash_attention_custom_mask(fp32_q, fp32_k, fp32_v, mask=mask, sm_scale=sm_scale) 56 | ... 57 | out.backward(loss) 58 | ``` 59 |
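If you want the standard causal (autoregressive) pattern instead of a random mask, you can build it with `torch.tril`, the same way `test_benchmark.py` does. Below is a minimal sketch (not from the original README): it reuses the shapes from the example above, uses fp16 inputs as in the benchmarks, and broadcasts the mask from `(1, 1, L, L)` to keep its memory footprint small.

```python
import torch

from fa2_custom_mask import flash_attention_custom_mask

B, H, L, D = 4, 16, 4096, 64
sm_scale = 1 / (D ** 0.5)

q = torch.randn(B, H, L, D, dtype=torch.float16, device="cuda", requires_grad=True)
k = torch.randn(B, H, L, D, dtype=torch.float16, device="cuda", requires_grad=True)
v = torch.randn(B, H, L, D, dtype=torch.float16, device="cuda", requires_grad=True)

# Lower-triangular causal mask: nonzero entries are kept, zeros are masked out.
causal_mask = torch.tril(torch.ones((1, 1, L, L), dtype=torch.uint8, device="cuda"))
causal_mask = torch.broadcast_to(causal_mask, (B, H, L, L))

# Arguments are passed positionally, as in test_benchmark.py.
out = flash_attention_custom_mask(q, k, v, causal_mask, sm_scale)
out.backward(torch.randn_like(out))
```

Inside the kernel, zero mask entries are excluded by adding a large negative bias (-1e6) to their attention logits before the softmax.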
60 | ## Benchmarking 61 | A simple benchmark against the base Triton implementation. In our custom mask version, we pass the canonical causal mask in as an explicit input (so it is stored in global device memory). We run `test_benchmark.py` 62 | with batch size 4, 16 heads, hidden dim 64, and sequence length `N_CTX` ranging from 256 to 16384 in powers of 2. You can replicate the experiments by running 63 | ``` 64 | pytest 65 | python test_benchmark.py 66 | ``` 67 | 68 | #### Causal Masks and No Masks Comparisons 69 | We compare against the original experiments and original implementation, as well as the official FlashAttention and xformers implementations (note: there seems to be a versioning issue, so xformers uses a different implementation here; I corrected the version in the later benchmarking experiments). 70 | ![causal and no masking with flash attn](./data/results-causal-fa.png) 71 | 72 | #### Causal Masks and No Masks Comparisons (with Correct xformers version) 73 | We compare against the original experiments and original implementation, as well as the xformers implementation. Notably, the original implementation does well for causal masking because of pipelining tricks and because it never has to store or load a mask. 74 | ![causal and no masking](./data/results-causal.png) 75 | #### Custom Masking Comparison 76 | We compare directly to [xformers memory-efficient attention](https://facebookresearch.github.io/xformers/components/ops.html), which allows for custom masking. We generate random masks (fixed across the head dimension). 77 | ![custom masking](./data/results-random.png) 78 | 79 | 80 | ## Notes and Bugs 81 | 1. This implementation only works on Ampere devices and up. I originally tried running it on a V100 (Volta) and it failed. 82 | 2. You need to be on `triton>=3.0.0`, or it'll complain about permutation indices on the value vector pointer. The `torch` and `flash-attn` libraries may force you to install `triton=2.x.x`, but you can just re-install `triton>=3.0.0` and it should work. I may fix this manually in the future. 83 | * This is oddly specific, but I'm not able to have `flash-attn` and `xformers` installed at the same time. I had to run them separately and generate the plots. 84 | 3. TODO: Add benchmarking for peak memory consumption and other efficiency metrics. 85 | 86 | If time permits, I'm interested in making this implementation more generalizable / changing the CUDA implementation for FA3 (if that turns out to be necessary). I'll also probably run some more realistic workloads and see what happens. 87 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /fa2_custom_mask/fa2_fwd.py: -------------------------------------------------------------------------------- 1 | import triton 2 | import triton.language as tl 3 | 4 | from fa2_custom_mask.utils import is_hip, keep 5 | 6 | # We don't run auto-tuning every time to keep the tutorial fast. Keeping 7 | # the code below and commenting out the equivalent parameters is convenient for 8 | # re-tuning.
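# Note: the `keep` filter applied in the autotune decorator further below (imported from
# fa2_custom_mask.utils) drops configs that request 8 warps for tiles smaller than 128x128.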
9 | configs = [ 10 | triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w) \ 11 | for BM in [64, 128]\ 12 | for BN in [32, 64]\ 13 | for s in ([1] if is_hip() else [3, 4, 7])\ 14 | for w in [4, 8]\ 15 | ] 16 | 17 | @triton.jit 18 | def _attn_fwd_inner( 19 | acc, 20 | l_i, 21 | m_i, 22 | q, # 23 | K_block_ptr, 24 | V_block_ptr, # 25 | mask_block_ptr, # TODO: make sure it's added 26 | start_m, # 27 | qk_scale, # 28 | BLOCK_M: tl.constexpr, 29 | HEAD_DIM: tl.constexpr, 30 | BLOCK_N: tl.constexpr, # 31 | STAGE: tl.constexpr, 32 | offs_m: tl.constexpr, 33 | offs_n: tl.constexpr, # 34 | N_CTX: tl.constexpr, 35 | fp8_v: tl.constexpr, 36 | USE_MASK: tl.constexpr, 37 | ): 38 | """ 39 | 40 | """ 41 | # range of values handled by this stage 42 | # if STAGE == 1: 43 | # lo, hi = 0, start_m * BLOCK_M 44 | # elif STAGE == 2: 45 | # lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M 46 | # lo = tl.multiple_of(lo, BLOCK_M) 47 | # causal = False 48 | # else: 49 | # lo, hi = 0, N_CTX 50 | lo, hi = 0, N_CTX 51 | 52 | K_block_ptr = tl.advance(K_block_ptr, (0, lo)) 53 | V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) 54 | 55 | if USE_MASK: 56 | # TODO: advance mask pointer along N dim 57 | mask_block_ptr = tl.advance(mask_block_ptr, (0, lo)) 58 | 59 | # loop over k, v and update accumulator 60 | for start_n in range(lo, hi, BLOCK_N): 61 | start_n = tl.multiple_of(start_n, BLOCK_N) 62 | # -- compute qk ---- 63 | k = tl.load(K_block_ptr) 64 | qk = tl.dot(q, k) 65 | if USE_MASK: 66 | # mask = offs_m[:, None] >= (start_n + offs_n[None, :]) 67 | # TODO: replace mask! 68 | mask_ = tl.load(mask_block_ptr) 69 | qk = qk * qk_scale + tl.where(mask_, 0, -1.0e6) 70 | m_ij = tl.maximum(m_i, tl.max(qk, 1)) 71 | qk -= m_ij[:, None] 72 | else: 73 | m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) 74 | qk = qk * qk_scale - m_ij[:, None] 75 | p = tl.math.exp2(qk) 76 | l_ij = tl.sum(p, 1) 77 | # -- update m_i and l_i 78 | alpha = tl.math.exp2(m_i - m_ij) 79 | l_i = l_i * alpha + l_ij 80 | # -- update output accumulator -- 81 | acc = acc * alpha[:, None] 82 | # update acc 83 | v = tl.load(V_block_ptr) 84 | if fp8_v: 85 | p = p.to(tl.float8e5) 86 | else: 87 | p = p.to(tl.float16) 88 | acc = tl.dot(p, v, acc) 89 | # update m_i and l_i 90 | m_i = m_ij 91 | 92 | # TODO: is this wrong? 
no, just stride 93 | V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) 94 | K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) 95 | 96 | # TODO: update mask pointer offset 97 | if USE_MASK: 98 | mask_block_ptr = tl.advance(mask_block_ptr, (0, BLOCK_N)) 99 | 100 | return acc, l_i, m_i 101 | 102 | 103 | @triton.autotune(list(filter(keep, configs)), key=["N_CTX", "HEAD_DIM"]) 104 | @triton.jit 105 | def _attn_fwd(Q, K, V, mask, sm_scale, M, Out, # 106 | stride_qz, stride_qh, stride_qm, stride_qk, # 107 | stride_kz, stride_kh, stride_kn, stride_kk, # 108 | stride_vz, stride_vh, stride_vk, stride_vn, # 109 | stride_oz, stride_oh, stride_om, stride_on, # 110 | stride_mask_z, stride_mask_h, stride_mask_m, stride_mask_n, # 111 | Z, H, N_CTX, # 112 | HEAD_DIM: tl.constexpr, # 113 | BLOCK_M: tl.constexpr, # 114 | BLOCK_N: tl.constexpr, # 115 | STAGE: tl.constexpr, # 116 | USE_MASK: tl.constexpr, 117 | ): 118 | tl.static_assert(BLOCK_N <= HEAD_DIM) 119 | start_m = tl.program_id(0) 120 | off_hz = tl.program_id(1) 121 | off_z = off_hz // H 122 | off_h = off_hz % H 123 | qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh 124 | 125 | if USE_MASK: 126 | mask_offset = off_z.to(tl.int64) * stride_mask_z + off_h.to(tl.int64) * stride_mask_h 127 | 128 | # block pointers 129 | Q_block_ptr = tl.make_block_ptr( 130 | base=Q + qvk_offset, 131 | shape=(N_CTX, HEAD_DIM), 132 | strides=(stride_qm, stride_qk), 133 | offsets=(start_m * BLOCK_M, 0), 134 | block_shape=(BLOCK_M, HEAD_DIM), 135 | order=(1, 0), 136 | ) 137 | v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0) 138 | V_block_ptr = tl.make_block_ptr( 139 | base=V + qvk_offset, 140 | shape=(N_CTX, HEAD_DIM), 141 | strides=(stride_vk, stride_vn), 142 | offsets=(0, 0), 143 | block_shape=(BLOCK_N, HEAD_DIM), 144 | order=v_order, 145 | ) 146 | K_block_ptr = tl.make_block_ptr( 147 | base=K + qvk_offset, 148 | shape=(HEAD_DIM, N_CTX), 149 | strides=(stride_kk, stride_kn), 150 | offsets=(0, 0), 151 | block_shape=(HEAD_DIM, BLOCK_N), 152 | order=(0, 1), 153 | ) 154 | O_block_ptr = tl.make_block_ptr( 155 | base=Out + qvk_offset, 156 | shape=(N_CTX, HEAD_DIM), 157 | strides=(stride_om, stride_on), 158 | offsets=(start_m * BLOCK_M, 0), 159 | block_shape=(BLOCK_M, HEAD_DIM), 160 | order=(1, 0), 161 | ) 162 | 163 | # TODO: Make a mask block pointer and compute offsets 164 | mask_block_ptr = None if not USE_MASK else tl.make_block_ptr( 165 | base=mask + mask_offset, 166 | shape=(N_CTX, N_CTX), 167 | strides=(stride_mask_m, stride_mask_n), # TODO 168 | offsets=(start_m * BLOCK_M, 0), 169 | block_shape=(BLOCK_M, BLOCK_N), 170 | order=(0, 1), 171 | ) 172 | 173 | # initialize offsets 174 | offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) 175 | offs_n = tl.arange(0, BLOCK_N) 176 | # initialize pointer to m and l 177 | m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") 178 | l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 179 | acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) 180 | # load scales 181 | qk_scale = sm_scale 182 | qk_scale *= 1.44269504 # 1/log(2) 183 | # load q: it will stay in SRAM throughout 184 | q = tl.load(Q_block_ptr) 185 | # stage 1: off-band 186 | # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE 187 | # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE 188 | if USE_MASK: 189 | acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, # 190 | mask_block_ptr, 191 | start_m, qk_scale, # 192 | BLOCK_M, HEAD_DIM, BLOCK_N, # 193 | 
4 - STAGE, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5, # 194 | USE_MASK 195 | ) 196 | else: 197 | acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, # 198 | None, 199 | start_m, qk_scale, # 200 | BLOCK_M, HEAD_DIM, BLOCK_N, # 201 | 2, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5, # 202 | USE_MASK 203 | ) 204 | # epilogue 205 | m_i += tl.math.log2(l_i) 206 | acc = acc / l_i[:, None] 207 | m_ptrs = M + off_hz * N_CTX + offs_m 208 | tl.store(m_ptrs, m_i) 209 | tl.store(O_block_ptr, acc.to(Out.type.element_ty)) 210 | -------------------------------------------------------------------------------- /fa2_custom_mask/fa2_bwd.py: -------------------------------------------------------------------------------- 1 | import triton 2 | import triton.language as tl 3 | 4 | 5 | @triton.jit 6 | def _attn_bwd_preprocess(O, DO, # 7 | Delta, # 8 | Z, H, N_CTX, # 9 | BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr # 10 | ): 11 | off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) 12 | off_hz = tl.program_id(1) 13 | off_n = tl.arange(0, HEAD_DIM) 14 | # load 15 | o = tl.load(O + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]) 16 | do = tl.load(DO + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]).to(tl.float32) 17 | delta = tl.sum(o * do, axis=1) 18 | # write-back 19 | tl.store(Delta + off_hz * N_CTX + off_m, delta) 20 | 21 | 22 | # The main inner-loop logic for computing dK and dV. 23 | @triton.jit 24 | def _attn_bwd_dkdv(dk, dv, # 25 | Q, k, v, mask, sm_scale, # 26 | DO, # 27 | M, D, # 28 | # shared by Q/K/V/DO. 29 | stride_tok, stride_d, # 30 | mask_stride_tok, mask_stride_tokk, 31 | H, N_CTX, BLOCK_M1: tl.constexpr, # 32 | BLOCK_N1: tl.constexpr, # 33 | HEAD_DIM: tl.constexpr, # 34 | # Filled in by the wrapper. 35 | start_n, start_m, num_steps, # 36 | MASK: tl.constexpr): 37 | offs_m = start_m + tl.arange(0, BLOCK_M1) 38 | offs_n = start_n + tl.arange(0, BLOCK_N1) 39 | offs_k = tl.arange(0, HEAD_DIM) 40 | qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d 41 | do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d 42 | 43 | # TODO: verify correctness 44 | if MASK: 45 | maskT_ptrs = mask + offs_m[None, :] * mask_stride_tok + offs_n[:, None] * mask_stride_tokk 46 | 47 | # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. 48 | tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) 49 | curr_m = start_m 50 | step_m = BLOCK_M1 51 | for blk_idx in range(num_steps): 52 | qT = tl.load(qT_ptrs) 53 | # Load m before computing qk to reduce pipeline stall. 54 | offs_m = curr_m + tl.arange(0, BLOCK_M1) 55 | m = tl.load(M + offs_m) 56 | qkT = tl.dot(k, qT) 57 | pT = tl.math.exp2(qkT - m[None, :]) 58 | # Autoregressive masking. 59 | if MASK: 60 | maskT = tl.load(maskT_ptrs) 61 | # mask = (offs_m[None, :] >= offs_n[:, None]) 62 | pT = tl.where(maskT, pT, 0.0) 63 | do = tl.load(do_ptrs) 64 | # Compute dV. 65 | ppT = pT 66 | ppT = ppT.to(tl.float16) 67 | dv += tl.dot(ppT, do) 68 | # D (= delta) is pre-divided by ds_scale. 69 | Di = tl.load(D + offs_m) 70 | # Compute dP and dS. 71 | dpT = tl.dot(v, tl.trans(do)).to(tl.float32) 72 | dsT = pT * (dpT - Di[None, :]) 73 | dsT = dsT.to(tl.float16) 74 | dk += tl.dot(dsT, tl.trans(qT)) 75 | # Increment pointers. 
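# (All pointers below step along the query dimension; the transposed mask block, when present, advances by its own query-token stride.)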
76 | curr_m += step_m 77 | qT_ptrs += step_m * stride_tok 78 | do_ptrs += step_m * stride_tok 79 | if MASK: 80 | maskT_ptrs += step_m * mask_stride_tok 81 | return dk, dv 82 | 83 | 84 | # the main inner-loop logic for computing dQ 85 | @triton.jit 86 | def _attn_bwd_dq(dq, q, K, V, # 87 | do, m, D, mask, 88 | # shared by Q/K/V/DO. 89 | mask_stride_tok, mask_stride_tokk, # 90 | stride_tok, stride_d, # 91 | H, N_CTX, # 92 | BLOCK_M2: tl.constexpr, # 93 | BLOCK_N2: tl.constexpr, # 94 | HEAD_DIM: tl.constexpr, 95 | # Filled in by the wrapper. 96 | start_m, start_n, num_steps, # 97 | MASK: tl.constexpr): 98 | offs_m = start_m + tl.arange(0, BLOCK_M2) 99 | offs_n = start_n + tl.arange(0, BLOCK_N2) 100 | offs_k = tl.arange(0, HEAD_DIM) 101 | kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d 102 | vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d 103 | 104 | if MASK: 105 | mask_ptrs = mask + offs_m[:, None] * mask_stride_tok + offs_n[None, :] * mask_stride_tokk 106 | 107 | # D (= delta) is pre-divided by ds_scale. 108 | Di = tl.load(D + offs_m) 109 | # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 110 | tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) 111 | curr_n = start_n 112 | step_n = BLOCK_N2 113 | for blk_idx in range(num_steps): 114 | kT = tl.load(kT_ptrs) 115 | vT = tl.load(vT_ptrs) 116 | qk = tl.dot(q, kT) 117 | p = tl.math.exp2(qk - m) 118 | # Autoregressive masking. 119 | if MASK: 120 | # offs_n = curr_n + tl.arange(0, BLOCK_N2) 121 | mask_ = tl.load(mask_ptrs) 122 | # mask_ = (offs_m[:, None] >= offs_n[None, :]) 123 | p = tl.where(mask_, p, 0.0) 124 | # Compute dP and dS. 125 | dp = tl.dot(do, vT).to(tl.float32) 126 | ds = p * (dp - Di[:, None]) 127 | ds = ds.to(tl.float16) 128 | # Compute dQ. 129 | # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. 130 | dq += tl.dot(ds, tl.trans(kT)) 131 | # Increment pointers. 132 | curr_n += step_n 133 | kT_ptrs += step_n * stride_tok 134 | vT_ptrs += step_n * stride_tok 135 | if MASK: 136 | mask_ptrs += step_n * mask_stride_tokk 137 | return dq 138 | 139 | 140 | @triton.jit 141 | def _attn_bwd(Q, K, V, mask, sm_scale, # 142 | DO, # 143 | DQ, DK, DV, # 144 | M, D, 145 | # shared by Q/K/V/DO. 
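# (The mask has its own separate stride set below, since its shape is (Z, H, N_CTX, N_CTX) rather than (Z, H, N_CTX, HEAD_DIM).)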
146 | stride_z, stride_h, stride_tok, stride_d, # 147 | mask_stride_z, mask_stride_h, mask_stride_tok, mask_stride_tokk, # 148 | H, N_CTX, # 149 | BLOCK_M1: tl.constexpr, # 150 | BLOCK_N1: tl.constexpr, # 151 | BLOCK_M2: tl.constexpr, # 152 | BLOCK_N2: tl.constexpr, # 153 | BLK_SLICE_FACTOR: tl.constexpr, # 154 | HEAD_DIM: tl.constexpr, 155 | USE_MASK: tl.constexpr): 156 | LN2: tl.constexpr = 0.6931471824645996 # = ln(2) 157 | 158 | bhid = tl.program_id(2) 159 | off_chz = (bhid * N_CTX).to(tl.int64) 160 | adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64) 161 | 162 | if USE_MASK: 163 | m_adj = (mask_stride_h * (bhid % H) + mask_stride_z * (bhid // H)).to(tl.int64) 164 | mask += m_adj 165 | 166 | # TODO: verify this is the same as adj 167 | pid = tl.program_id(0) 168 | 169 | # offset pointers for batch/head 170 | Q += adj 171 | K += adj 172 | V += adj 173 | DO += adj 174 | DQ += adj 175 | DK += adj 176 | DV += adj 177 | M += off_chz 178 | D += off_chz 179 | 180 | # load scales 181 | offs_k = tl.arange(0, HEAD_DIM) 182 | 183 | start_n = pid * BLOCK_N1 184 | start_m = start_n 185 | 186 | MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR 187 | offs_n = start_n + tl.arange(0, BLOCK_N1) 188 | 189 | dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) 190 | dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) 191 | 192 | # load K and V: they stay in SRAM throughout the inner loop. 193 | k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) 194 | v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) 195 | 196 | num_steps = BLOCK_N1 // MASK_BLOCK_M1 197 | # num_steps = N_CTX // BLOCK_M1 198 | 199 | dk, dv = _attn_bwd_dkdv(dk, dv, # 200 | Q, k, v, mask, sm_scale, # 201 | DO, # 202 | M, D, # 203 | stride_tok, stride_d, # 204 | mask_stride_tok, mask_stride_tokk, 205 | H, N_CTX, # 206 | MASK_BLOCK_M1, BLOCK_N1, HEAD_DIM, # 207 | start_n, start_m, num_steps, # 208 | MASK=USE_MASK # 209 | ) 210 | 211 | start_m += num_steps * MASK_BLOCK_M1 212 | num_steps = (N_CTX - start_m) // BLOCK_M1 213 | 214 | # Compute dK and dV for non-masked blocks. 215 | dk, dv = _attn_bwd_dkdv( # 216 | dk, dv, # 217 | Q, k, v, mask, sm_scale, # 218 | DO, # 219 | M, D, # 220 | stride_tok, stride_d, # 221 | mask_stride_tok, mask_stride_tokk, 222 | H, N_CTX, # 223 | BLOCK_M1, BLOCK_N1, HEAD_DIM, # 224 | start_n, start_m, num_steps, # 225 | MASK=False # 226 | ) 227 | 228 | dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d 229 | tl.store(dv_ptrs, dv) 230 | 231 | # Write back dK. 232 | dk *= sm_scale 233 | dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d 234 | tl.store(dk_ptrs, dk) 235 | 236 | # THIS BLOCK DOES DQ: 237 | start_m = pid * BLOCK_M2 238 | start_n = 0 239 | end_n = start_m + BLOCK_M2 240 | 241 | MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR 242 | offs_m = start_m + tl.arange(0, BLOCK_M2) 243 | 244 | q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) 245 | dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) 246 | do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) 247 | 248 | m = tl.load(M + offs_m) 249 | m = m[:, None] 250 | 251 | # Compute dQ for masked (diagonal) blocks. 252 | # NOTE: This code scans each row of QK^T backward (from right to left, 253 | # but inside each call to _attn_bwd_dq, from left to right), but that's 254 | # not due to anything important. 
I just wanted to reuse the loop 255 | # structure for dK & dV above as much as possible. 256 | 257 | num_steps = BLOCK_M2 // MASK_BLOCK_N2 258 | # num_steps = N_CTX // BLOCK_N2 259 | dq = _attn_bwd_dq(dq, q, K, V, # 260 | do, m, D, mask, # 261 | mask_stride_tok, mask_stride_tokk, # 262 | stride_tok, stride_d, # 263 | H, N_CTX, # 264 | BLOCK_M2, MASK_BLOCK_N2, HEAD_DIM, # 265 | start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, # 266 | # BLOCK_M2, BLOCK_N2, HEAD_DIM, # 267 | # start_m, start_n, num_steps, # 268 | MASK=USE_MASK # 269 | ) 270 | end_n -= num_steps * MASK_BLOCK_N2 271 | # stage 2 272 | num_steps = end_n // BLOCK_N2 273 | dq = _attn_bwd_dq(dq, q, K, V, # 274 | do, m, D, mask, # 275 | mask_stride_tok, mask_stride_tokk, # 276 | stride_tok, stride_d, # 277 | H, N_CTX, # 278 | BLOCK_M2, BLOCK_N2, HEAD_DIM, # 279 | start_m, end_n - num_steps * BLOCK_N2, num_steps, # 280 | MASK=USE_MASK # 281 | ) 282 | # Write back dQ. 283 | dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d 284 | dq *= LN2 285 | tl.store(dq_ptrs, dq) 286 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /test_benchmark.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import torch.nn.functional as F 4 | import triton 5 | 6 | from fa2_custom_mask import flash_attention_custom_mask 7 | 8 | @pytest.mark.parametrize("Z, H, N_CTX, HEAD_DIM", [(1, 2, 128, 64)]) 9 | @pytest.mark.parametrize("causal", [True]) 10 | def test_op_causal(Z, H, N_CTX, causal, HEAD_DIM, dtype=torch.float16): 11 | torch.manual_seed(20) 12 | q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) 13 | k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) 14 | v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) 15 | if causal: 16 | mask = torch.tril(torch.ones((Z, H, N_CTX, N_CTX), dtype=torch.uint8, device="cuda", requires_grad=False)) 17 | else: 18 | mask = None 19 | sm_scale = 0.5 20 | dout = torch.randn_like(q) 21 | # reference implementation 22 | M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) 23 | p = torch.matmul(q, k.transpose(2, 3)) * sm_scale 24 | if causal: 25 | p[:, :, M == 0] = float("-inf") 26 | p = torch.softmax(p.float(), dim=-1).half() 27 | # p = torch.exp(p) 28 | ref_out = torch.matmul(p, v) 29 | ref_out.backward(dout) 30 | ref_dv, v.grad = v.grad.clone(), None 31 | ref_dk, k.grad = k.grad.clone(), None 32 | ref_dq, q.grad = q.grad.clone(), None 33 | # triton implementation 34 | tri_out = flash_attention_custom_mask(q, k, v, mask, sm_scale).half() 35 | tri_out.backward(dout) 36 | tri_dv, v.grad = v.grad.clone(), None 37 | tri_dk, k.grad = k.grad.clone(), None 38 | tri_dq, q.grad = q.grad.clone(), None 39 | # compare 40 | assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0) 41 | rtol = 0.0 42 | # Relative tolerance workaround for known hardware limitation of MI200 GPU. 
43 | # For details see https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices 44 | if torch.version.hip is not None and triton.runtime.driver.active.get_current_target().arch == "gfx90a": 45 | rtol = 1e-2 46 | 47 | assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=rtol) 48 | assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=rtol) 49 | assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=rtol) 50 | 51 | 52 | @pytest.mark.parametrize("Z, H, N_CTX, HEAD_DIM", [(1, 2, 128, 64)]) 53 | def test_op_random(Z, H, N_CTX, HEAD_DIM, dtype=torch.float16): 54 | torch.manual_seed(20) 55 | q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) 56 | k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) 57 | v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) 58 | sm_scale = 0.5 59 | dout = torch.randn_like(q) 60 | # reference implementation 61 | p = torch.matmul(q, k.transpose(2, 3)) * sm_scale 62 | 63 | # create random mask 64 | M = mask = torch.randint(0, 2, (N_CTX, N_CTX), dtype=torch.uint8, device="cuda", requires_grad=False) 65 | mask = torch.broadcast_to(mask, (Z, H, N_CTX, N_CTX)) 66 | p[:, :, M == 0] = float("-inf") 67 | p = torch.softmax(p.float(), dim=-1).half() 68 | # p = torch.exp(p) 69 | ref_out = torch.matmul(p, v) 70 | ref_out.backward(dout) 71 | ref_dv, v.grad = v.grad.clone(), None 72 | ref_dk, k.grad = k.grad.clone(), None 73 | ref_dq, q.grad = q.grad.clone(), None 74 | # triton implementation 75 | tri_out = flash_attention_custom_mask(q, k, v, mask, sm_scale).half() 76 | tri_out.backward(dout) 77 | tri_dv, v.grad = v.grad.clone(), None 78 | tri_dk, k.grad = k.grad.clone(), None 79 | tri_dq, q.grad = q.grad.clone(), None 80 | # compare 81 | assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0) 82 | rtol = 0.0 83 | # Relative tolerance workaround for known hardware limitation of MI200 GPU. 
84 | # For details see https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices 85 | if torch.version.hip is not None and triton.runtime.driver.active.get_current_target().arch == "gfx90a": 86 | rtol = 1e-2 87 | 88 | assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=rtol) 89 | assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=rtol) 90 | assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=rtol) 91 | 92 | try: 93 | from flash_attn.flash_attn_interface import \ 94 | flash_attn_qkvpacked_func as flash_attn_func 95 | HAS_FLASH = True 96 | except BaseException: 97 | HAS_FLASH = False 98 | 99 | try: 100 | from fa2_original import attention 101 | USE_FA2_TRITON_ORIGINAL = True 102 | except BaseException: 103 | USE_FA2_TRITON_ORIGINAL = False 104 | 105 | try: 106 | import xformers 107 | import xformers.ops 108 | from xformers.ops import memory_efficient_attention 109 | import xformers.ops.fmha as fmha 110 | HAS_XFORMERS = True 111 | except BaseException: 112 | HAS_XFORMERS = False 113 | 114 | 115 | TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2') 116 | BATCH, N_HEADS, HEAD_DIM = 4, 16, 64 117 | # vary seq length for fixed head and batch=4 118 | configs = [] 119 | for mode in ["fwd", "bwd"]: 120 | for causal in [True, False]: 121 | configs.append( 122 | triton.testing.Benchmark( 123 | x_names=["N_CTX"], 124 | x_vals=[2**i for i in range(8, 15)], 125 | line_arg="provider", 126 | line_vals=["triton_custom_mask-fp16"] + (["triton_custom_mask-fp8"] if TORCH_HAS_FP8 else []) + 127 | (["flash"] if HAS_FLASH else []) + (["triton-original-fp16"] if USE_FA2_TRITON_ORIGINAL else []) + (["xformers-memory_efficient_attention"] if HAS_XFORMERS else []), 128 | line_names=["Triton (Custom Mask) [FP16]"] + (["Triton (Custom Mask) [FP8]"] if TORCH_HAS_FP8 else []) + 129 | (["Flash-2"] if HAS_FLASH else []) + (["Original Triton [FP16]"] if USE_FA2_TRITON_ORIGINAL else []) + (["XFormers Memory-Efficient Attn"] if HAS_XFORMERS else []), 130 | styles=[("red", "-"), ("blue", "-"), ("green", "-"), ("yellow", "-"), ("pink", "-")], 131 | ylabel="GFLOPS", 132 | plot_name=f"fused-attention-batch{BATCH}-head{N_HEADS}-d{HEAD_DIM}-{mode}-causal={causal}", 133 | args={ 134 | "H": N_HEADS, 135 | "BATCH": BATCH, 136 | "HEAD_DIM": HEAD_DIM, 137 | "mode": mode, 138 | "causal": causal, 139 | }, 140 | )) 141 | 142 | 143 | @triton.testing.perf_report(configs) 144 | def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode, provider, device="cuda"): 145 | assert mode in ["fwd", "bwd"] 146 | warmup = 25 147 | rep = 100 148 | dtype = torch.float16 149 | if "triton_custom_mask" in provider: 150 | q = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) 151 | k = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) 152 | v = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) 153 | if causal: 154 | mask = torch.tril(torch.ones((BATCH, H, N_CTX, N_CTX), dtype=torch.uint8, device=device, requires_grad=False)) 155 | else: 156 | mask = None 157 | if mode == "fwd" and "fp8" in provider: 158 | q = q.to(torch.float8_e5m2) 159 | k = k.to(torch.float8_e5m2) 160 | v = v.permute(0, 1, 3, 2).contiguous() 161 | v = v.permute(0, 1, 3, 2) 162 | v = v.to(torch.float8_e5m2) 163 | sm_scale = 1.3 164 | fn = lambda: flash_attention_custom_mask(q, k, v, mask, sm_scale) 165 | if mode == "bwd": 166 | o = fn() 167 | do = torch.randn_like(o) 168 | fn = 
lambda: o.backward(do, retain_graph=True) 169 | ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) 170 | elif "triton-original" in provider: 171 | q = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) 172 | k = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) 173 | v = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) 174 | if mode == "fwd" and "fp8" in provider: 175 | q = q.to(torch.float8_e5m2) 176 | k = k.to(torch.float8_e5m2) 177 | v = v.permute(0, 1, 3, 2).contiguous() 178 | v = v.permute(0, 1, 3, 2) 179 | v = v.to(torch.float8_e5m2) 180 | sm_scale = 1.3 181 | fn = lambda: attention(q, k, v, causal, sm_scale) 182 | if mode == "bwd": 183 | o = fn() 184 | do = torch.randn_like(o) 185 | fn = lambda: o.backward(do, retain_graph=True) 186 | ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) 187 | if provider == "flash": 188 | qkv = torch.randn((BATCH, N_CTX, 3, H, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) 189 | fn = lambda: flash_attn_func(qkv, causal=causal) 190 | if mode == "bwd": 191 | o = fn() 192 | do = torch.randn_like(o) 193 | fn = lambda: o.backward(do, retain_graph=True) 194 | ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) 195 | if provider == "xformers-memory_efficient_attention": 196 | q = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) 197 | k = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) 198 | v = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) 199 | with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True): 200 | fn = lambda: F.scaled_dot_product_attention(q, k, v) 201 | # fn = lambda: flash_attn_func(qkv, causal=causal) 202 | if mode == "bwd": 203 | o = fn() 204 | do = torch.randn_like(o) 205 | fn = lambda: o.backward(do, retain_graph=True) 206 | ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) 207 | flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * HEAD_DIM 208 | total_flops = 2 * flops_per_matmul 209 | if causal and provider in ["flash", "triton-original-fp16"]: 210 | total_flops *= 0.5 211 | if mode == "bwd": 212 | total_flops *= 2.5 # 2.0(bwd) + 0.5(recompute) 213 | return total_flops / ms * 1e-9 214 | 215 | configs_random = [] 216 | for mode in ["fwd", "bwd"]: 217 | configs_random.append( 218 | triton.testing.Benchmark( 219 | x_names=["N_CTX"], 220 | x_vals=[2**i for i in range(8, 15)], 221 | line_arg="provider", 222 | line_vals=["triton_custom_mask-fp16"] + (["triton_custom_mask-fp8"] if TORCH_HAS_FP8 else []) + (["xformers-memory_efficient_attention"] if HAS_XFORMERS else []), 223 | line_names=["Triton (Custom Mask) [FP16]"] + (["Triton (Custom Mask) [FP8]"] if TORCH_HAS_FP8 else []) + (["XFormers Memory-Efficient Attn"] if HAS_XFORMERS else []), 224 | styles=[("red", "-"), ("blue", "-"), ("pink", "-")], 225 | ylabel="GFLOPS", 226 | plot_name=f"fused-attention-batch{BATCH}-head{N_HEADS}-d{HEAD_DIM}-{mode}-random-mask", 227 | args={ 228 | "H": N_HEADS, 229 | "BATCH": BATCH, 230 | "HEAD_DIM": HEAD_DIM, 231 | "mode": mode, 232 | }, 233 | )) 234 | 235 | @triton.testing.perf_report(configs_random) 236 | def bench_flash_attention_random_mask(BATCH, H, N_CTX, HEAD_DIM, mode, provider, device="cuda"): 237 | assert mode in ["fwd", "bwd"] 238 | warmup = 25 239 | rep = 100 240 | dtype = torch.float16 241 | if "triton_custom_mask" in provider: 242 
| q = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) 243 | k = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) 244 | v = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) 245 | mask = torch.randint(0, 2, (BATCH, 1, N_CTX, N_CTX), dtype=dtype, device="cuda", requires_grad=False) 246 | mask = torch.broadcast_to(mask, (BATCH, H, N_CTX, N_CTX)) 247 | if mode == "fwd" and "fp8" in provider: 248 | q = q.to(torch.float8_e5m2) 249 | k = k.to(torch.float8_e5m2) 250 | v = v.permute(0, 1, 3, 2).contiguous() 251 | v = v.permute(0, 1, 3, 2) 252 | v = v.to(torch.float8_e5m2) 253 | sm_scale = 1.3 254 | fn = lambda: flash_attention_custom_mask(q, k, v, mask, sm_scale) 255 | if mode == "bwd": 256 | o = fn() 257 | do = torch.randn_like(o) 258 | fn = lambda: o.backward(do, retain_graph=True) 259 | ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) 260 | if provider == "xformers-memory_efficient_attention": 261 | q = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) 262 | k = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) 263 | v = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) 264 | mask = torch.randint(0, 2, (BATCH, 1, N_CTX, N_CTX), dtype=dtype, device="cuda", requires_grad=False) # doesn't allow uint8 265 | mask = torch.broadcast_to(mask, (BATCH, H, N_CTX, N_CTX)) 266 | q = q.permute(0, 2, 1, 3).contiguous() 267 | k = k.permute(0, 2, 1, 3).contiguous() 268 | v = v.permute(0, 2, 1, 3).contiguous() 269 | # with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True): 270 | fn = lambda: fmha.memory_efficient_attention(q, k, v, attn_bias=mask) 271 | if mode == "bwd": 272 | o = fn() 273 | do = torch.randn_like(o) 274 | fn = lambda: o.backward(do, retain_graph=True) 275 | ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) 276 | flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * HEAD_DIM 277 | total_flops = 2 * flops_per_matmul 278 | if causal and provider in ["flash", "triton-original-fp16"]: 279 | total_flops *= 0.5 280 | if mode == "bwd": 281 | total_flops *= 2.5 # 2.0(bwd) + 0.5(recompute) 282 | return total_flops / ms * 1e-9 283 | 284 | 285 | if __name__ == "__main__": 286 | # only works on post-Ampere GPUs right now 287 | bench_flash_attention_random_mask.run(save_path="data/", print_data=True) 288 | bench_flash_attention.run(save_path="data/", print_data=True) 289 | -------------------------------------------------------------------------------- /fa2_original.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | import triton 5 | import triton.language as tl 6 | 7 | 8 | def is_hip(): 9 | return triton.runtime.driver.active.get_current_target().backend == "hip" 10 | 11 | 12 | @triton.jit 13 | def _attn_fwd_inner(acc, l_i, m_i, q, # 14 | K_block_ptr, V_block_ptr, # 15 | start_m, qk_scale, # 16 | BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, # 17 | STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, # 18 | N_CTX: tl.constexpr, fp8_v: tl.constexpr): 19 | # range of values handled by this stage 20 | if STAGE == 1: 21 | lo, hi = 0, start_m * BLOCK_M 22 | elif STAGE == 2: 23 | lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M 24 | lo = tl.multiple_of(lo, BLOCK_M) 25 | # causal = False 26 | else: 27 | lo, hi 
= 0, N_CTX 28 | K_block_ptr = tl.advance(K_block_ptr, (0, lo)) 29 | V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) 30 | # loop over k, v and update accumulator 31 | for start_n in range(lo, hi, BLOCK_N): 32 | start_n = tl.multiple_of(start_n, BLOCK_N) 33 | # -- compute qk ---- 34 | k = tl.load(K_block_ptr) 35 | qk = tl.dot(q, k) 36 | if STAGE == 2: 37 | mask = offs_m[:, None] >= (start_n + offs_n[None, :]) 38 | qk = qk * qk_scale + tl.where(mask, 0, -1.0e6) 39 | m_ij = tl.maximum(m_i, tl.max(qk, 1)) 40 | qk -= m_ij[:, None] 41 | else: 42 | m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) 43 | qk = qk * qk_scale - m_ij[:, None] 44 | p = tl.math.exp2(qk) 45 | l_ij = tl.sum(p, 1) 46 | # -- update m_i and l_i 47 | alpha = tl.math.exp2(m_i - m_ij) 48 | l_i = l_i * alpha + l_ij 49 | # -- update output accumulator -- 50 | acc = acc * alpha[:, None] 51 | # update acc 52 | v = tl.load(V_block_ptr) 53 | if fp8_v: 54 | p = p.to(tl.float8e5) 55 | else: 56 | p = p.to(tl.float16) 57 | acc = tl.dot(p, v, acc) 58 | # update m_i and l_i 59 | m_i = m_ij 60 | V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) 61 | K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) 62 | return acc, l_i, m_i 63 | 64 | 65 | # We don't run auto-tuning every time to keep the tutorial fast. Keeping 66 | # the code below and commenting out the equivalent parameters is convenient for 67 | # re-tuning. 68 | configs = [ 69 | triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w) \ 70 | for BM in [64, 128]\ 71 | for BN in [32, 64]\ 72 | for s in ([1] if is_hip() else [3, 4, 7])\ 73 | for w in [4, 8]\ 74 | ] 75 | 76 | 77 | def keep(conf): 78 | BLOCK_M = conf.kwargs["BLOCK_M"] 79 | BLOCK_N = conf.kwargs["BLOCK_N"] 80 | if BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8: 81 | return False 82 | return True 83 | 84 | 85 | @triton.autotune(list(filter(keep, configs)), key=["N_CTX", "HEAD_DIM"]) 86 | @triton.jit 87 | def _attn_fwd(Q, K, V, sm_scale, M, Out, # 88 | stride_qz, stride_qh, stride_qm, stride_qk, # 89 | stride_kz, stride_kh, stride_kn, stride_kk, # 90 | stride_vz, stride_vh, stride_vk, stride_vn, # 91 | stride_oz, stride_oh, stride_om, stride_on, # 92 | Z, H, N_CTX, # 93 | HEAD_DIM: tl.constexpr, # 94 | BLOCK_M: tl.constexpr, # 95 | BLOCK_N: tl.constexpr, # 96 | STAGE: tl.constexpr # 97 | ): 98 | tl.static_assert(BLOCK_N <= HEAD_DIM) 99 | start_m = tl.program_id(0) 100 | off_hz = tl.program_id(1) 101 | off_z = off_hz // H 102 | off_h = off_hz % H 103 | qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh 104 | 105 | # block pointers 106 | Q_block_ptr = tl.make_block_ptr( 107 | base=Q + qvk_offset, 108 | shape=(N_CTX, HEAD_DIM), 109 | strides=(stride_qm, stride_qk), 110 | offsets=(start_m * BLOCK_M, 0), 111 | block_shape=(BLOCK_M, HEAD_DIM), 112 | order=(1, 0), 113 | ) 114 | v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0) 115 | V_block_ptr = tl.make_block_ptr( 116 | base=V + qvk_offset, 117 | shape=(N_CTX, HEAD_DIM), 118 | strides=(stride_vk, stride_vn), 119 | offsets=(0, 0), 120 | block_shape=(BLOCK_N, HEAD_DIM), 121 | order=v_order, 122 | ) 123 | K_block_ptr = tl.make_block_ptr( 124 | base=K + qvk_offset, 125 | shape=(HEAD_DIM, N_CTX), 126 | strides=(stride_kk, stride_kn), 127 | offsets=(0, 0), 128 | block_shape=(HEAD_DIM, BLOCK_N), 129 | order=(0, 1), 130 | ) 131 | O_block_ptr = tl.make_block_ptr( 132 | base=Out + qvk_offset, 133 | shape=(N_CTX, HEAD_DIM), 134 | strides=(stride_om, stride_on), 135 | offsets=(start_m * BLOCK_M, 0), 136 | 
block_shape=(BLOCK_M, HEAD_DIM), 137 | order=(1, 0), 138 | ) 139 | # initialize offsets 140 | offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) 141 | offs_n = tl.arange(0, BLOCK_N) 142 | # initialize pointer to m and l 143 | m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") 144 | l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 145 | acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) 146 | # load scales 147 | qk_scale = sm_scale 148 | qk_scale *= 1.44269504 # 1/ln(2) = log2(e) 149 | # load q: it will stay in SRAM throughout 150 | q = tl.load(Q_block_ptr) 151 | # stage 1: off-band 152 | # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE 153 | # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE 154 | if STAGE & 1: 155 | acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, # 156 | start_m, qk_scale, # 157 | BLOCK_M, HEAD_DIM, BLOCK_N, # 158 | 4 - STAGE, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5 # 159 | ) 160 | # stage 2: on-band 161 | if STAGE & 2: 162 | # barrier makes it easier for compiler to schedule the 163 | # two loops independently 164 | acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, # 165 | start_m, qk_scale, # 166 | BLOCK_M, HEAD_DIM, BLOCK_N, # 167 | 2, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5 # 168 | ) 169 | # epilogue 170 | m_i += tl.math.log2(l_i) 171 | acc = acc / l_i[:, None] 172 | m_ptrs = M + off_hz * N_CTX + offs_m 173 | tl.store(m_ptrs, m_i) 174 | tl.store(O_block_ptr, acc.to(Out.type.element_ty)) 175 | 176 | 177 | @triton.jit 178 | def _attn_bwd_preprocess(O, DO, # 179 | Delta, # 180 | Z, H, N_CTX, # 181 | BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr # 182 | ): 183 | off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) 184 | off_hz = tl.program_id(1) 185 | off_n = tl.arange(0, HEAD_DIM) 186 | # load 187 | o = tl.load(O + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]) 188 | do = tl.load(DO + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]).to(tl.float32) 189 | delta = tl.sum(o * do, axis=1) 190 | # write-back 191 | tl.store(Delta + off_hz * N_CTX + off_m, delta) 192 | 193 | 194 | # The main inner-loop logic for computing dK and dV. 195 | @triton.jit 196 | def _attn_bwd_dkdv(dk, dv, # 197 | Q, k, v, sm_scale, # 198 | DO, # 199 | M, D, # 200 | # shared by Q/K/V/DO. 201 | stride_tok, stride_d, # 202 | H, N_CTX, BLOCK_M1: tl.constexpr, # 203 | BLOCK_N1: tl.constexpr, # 204 | HEAD_DIM: tl.constexpr, # 205 | # Filled in by the wrapper. 206 | start_n, start_m, num_steps, # 207 | MASK: tl.constexpr): 208 | offs_m = start_m + tl.arange(0, BLOCK_M1) 209 | offs_n = start_n + tl.arange(0, BLOCK_N1) 210 | offs_k = tl.arange(0, HEAD_DIM) 211 | qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d 212 | do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d 213 | # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. 214 | tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) 215 | curr_m = start_m 216 | step_m = BLOCK_M1 217 | for blk_idx in range(num_steps): 218 | qT = tl.load(qT_ptrs) 219 | # Load m before computing qk to reduce pipeline stall. 220 | offs_m = curr_m + tl.arange(0, BLOCK_M1) 221 | m = tl.load(M + offs_m) 222 | qkT = tl.dot(k, qT) 223 | pT = tl.math.exp2(qkT - m[None, :]) 224 | # Autoregressive masking.
225 | if MASK: 226 | mask = (offs_m[None, :] >= offs_n[:, None]) 227 | pT = tl.where(mask, pT, 0.0) 228 | do = tl.load(do_ptrs) 229 | # Compute dV. 230 | ppT = pT 231 | ppT = ppT.to(tl.float16) 232 | dv += tl.dot(ppT, do) 233 | # D (= delta) is pre-divided by ds_scale. 234 | Di = tl.load(D + offs_m) 235 | # Compute dP and dS. 236 | dpT = tl.dot(v, tl.trans(do)).to(tl.float32) 237 | dsT = pT * (dpT - Di[None, :]) 238 | dsT = dsT.to(tl.float16) 239 | dk += tl.dot(dsT, tl.trans(qT)) 240 | # Increment pointers. 241 | curr_m += step_m 242 | qT_ptrs += step_m * stride_tok 243 | do_ptrs += step_m * stride_tok 244 | return dk, dv 245 | 246 | 247 | # the main inner-loop logic for computing dQ 248 | @triton.jit 249 | def _attn_bwd_dq(dq, q, K, V, # 250 | do, m, D, 251 | # shared by Q/K/V/DO. 252 | stride_tok, stride_d, # 253 | H, N_CTX, # 254 | BLOCK_M2: tl.constexpr, # 255 | BLOCK_N2: tl.constexpr, # 256 | HEAD_DIM: tl.constexpr, 257 | # Filled in by the wrapper. 258 | start_m, start_n, num_steps, # 259 | MASK: tl.constexpr): 260 | offs_m = start_m + tl.arange(0, BLOCK_M2) 261 | offs_n = start_n + tl.arange(0, BLOCK_N2) 262 | offs_k = tl.arange(0, HEAD_DIM) 263 | kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d 264 | vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d 265 | # D (= delta) is pre-divided by ds_scale. 266 | Di = tl.load(D + offs_m) 267 | # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 268 | tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) 269 | curr_n = start_n 270 | step_n = BLOCK_N2 271 | for blk_idx in range(num_steps): 272 | kT = tl.load(kT_ptrs) 273 | vT = tl.load(vT_ptrs) 274 | qk = tl.dot(q, kT) 275 | p = tl.math.exp2(qk - m) 276 | # Autoregressive masking. 277 | if MASK: 278 | offs_n = curr_n + tl.arange(0, BLOCK_N2) 279 | mask = (offs_m[:, None] >= offs_n[None, :]) 280 | p = tl.where(mask, p, 0.0) 281 | # Compute dP and dS. 282 | dp = tl.dot(do, vT).to(tl.float32) 283 | ds = p * (dp - Di[:, None]) 284 | ds = ds.to(tl.float16) 285 | # Compute dQ. 286 | # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. 287 | dq += tl.dot(ds, tl.trans(kT)) 288 | # Increment pointers. 289 | curr_n += step_n 290 | kT_ptrs += step_n * stride_tok 291 | vT_ptrs += step_n * stride_tok 292 | return dq 293 | 294 | 295 | @triton.jit 296 | def _attn_bwd(Q, K, V, sm_scale, # 297 | DO, # 298 | DQ, DK, DV, # 299 | M, D, 300 | # shared by Q/K/V/DO. 
301 | stride_z, stride_h, stride_tok, stride_d, # 302 | H, N_CTX, # 303 | BLOCK_M1: tl.constexpr, # 304 | BLOCK_N1: tl.constexpr, # 305 | BLOCK_M2: tl.constexpr, # 306 | BLOCK_N2: tl.constexpr, # 307 | BLK_SLICE_FACTOR: tl.constexpr, # 308 | HEAD_DIM: tl.constexpr): 309 | LN2: tl.constexpr = 0.6931471824645996 # = ln(2) 310 | 311 | bhid = tl.program_id(2) 312 | off_chz = (bhid * N_CTX).to(tl.int64) 313 | adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64) 314 | pid = tl.program_id(0) 315 | 316 | # offset pointers for batch/head 317 | Q += adj 318 | K += adj 319 | V += adj 320 | DO += adj 321 | DQ += adj 322 | DK += adj 323 | DV += adj 324 | M += off_chz 325 | D += off_chz 326 | 327 | # load scales 328 | offs_k = tl.arange(0, HEAD_DIM) 329 | 330 | start_n = pid * BLOCK_N1 331 | start_m = start_n 332 | 333 | MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR 334 | offs_n = start_n + tl.arange(0, BLOCK_N1) 335 | 336 | dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) 337 | dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) 338 | 339 | # load K and V: they stay in SRAM throughout the inner loop. 340 | k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) 341 | v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) 342 | 343 | num_steps = BLOCK_N1 // MASK_BLOCK_M1 344 | 345 | dk, dv = _attn_bwd_dkdv(dk, dv, # 346 | Q, k, v, sm_scale, # 347 | DO, # 348 | M, D, # 349 | stride_tok, stride_d, # 350 | H, N_CTX, # 351 | MASK_BLOCK_M1, BLOCK_N1, HEAD_DIM, # 352 | start_n, start_m, num_steps, # 353 | MASK=True # 354 | ) 355 | 356 | start_m += num_steps * MASK_BLOCK_M1 357 | num_steps = (N_CTX - start_m) // BLOCK_M1 358 | 359 | # Compute dK and dV for non-masked blocks. 360 | dk, dv = _attn_bwd_dkdv( # 361 | dk, dv, # 362 | Q, k, v, sm_scale, # 363 | DO, # 364 | M, D, # 365 | stride_tok, stride_d, # 366 | H, N_CTX, # 367 | BLOCK_M1, BLOCK_N1, HEAD_DIM, # 368 | start_n, start_m, num_steps, # 369 | MASK=False # 370 | ) 371 | 372 | dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d 373 | tl.store(dv_ptrs, dv) 374 | 375 | # Write back dK. 376 | dk *= sm_scale 377 | dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d 378 | tl.store(dk_ptrs, dk) 379 | 380 | # THIS BLOCK DOES DQ: 381 | start_m = pid * BLOCK_M2 382 | end_n = start_m + BLOCK_M2 383 | 384 | MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR 385 | offs_m = start_m + tl.arange(0, BLOCK_M2) 386 | 387 | q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) 388 | dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) 389 | do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) 390 | 391 | m = tl.load(M + offs_m) 392 | m = m[:, None] 393 | 394 | # Compute dQ for masked (diagonal) blocks. 395 | # NOTE: This code scans each row of QK^T backward (from right to left, 396 | # but inside each call to _attn_bwd_dq, from left to right), but that's 397 | # not due to anything important. I just wanted to reuse the loop 398 | # structure for dK & dV above as much as possible. 
399 | num_steps = BLOCK_M2 // MASK_BLOCK_N2 400 | dq = _attn_bwd_dq(dq, q, K, V, # 401 | do, m, D, # 402 | stride_tok, stride_d, # 403 | H, N_CTX, # 404 | BLOCK_M2, MASK_BLOCK_N2, HEAD_DIM, # 405 | start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, # 406 | MASK=True # 407 | ) 408 | end_n -= num_steps * MASK_BLOCK_N2 409 | # stage 2 410 | num_steps = end_n // BLOCK_N2 411 | dq = _attn_bwd_dq(dq, q, K, V, # 412 | do, m, D, # 413 | stride_tok, stride_d, # 414 | H, N_CTX, # 415 | BLOCK_M2, BLOCK_N2, HEAD_DIM, # 416 | start_m, end_n - num_steps * BLOCK_N2, num_steps, # 417 | MASK=False # 418 | ) 419 | # Write back dQ. 420 | dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d 421 | dq *= LN2 422 | tl.store(dq_ptrs, dq) 423 | 424 | 425 | class _attention(torch.autograd.Function): 426 | 427 | @staticmethod 428 | def forward(ctx, q, k, v, causal, sm_scale): 429 | # shape constraints 430 | HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1] 431 | # when v is in float8_e5m2 it is transposed. 432 | HEAD_DIM_V = v.shape[-1] 433 | assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V 434 | assert HEAD_DIM_K in {16, 32, 64, 128, 256} 435 | o = torch.empty_like(q) 436 | stage = 3 if causal else 1 437 | extra_kern_args = {} 438 | # Tuning for AMD target 439 | if is_hip(): 440 | waves_per_eu = 3 if HEAD_DIM_K <= 64 else 2 441 | extra_kern_args = {"waves_per_eu": waves_per_eu, "allow_flush_denorm": True} 442 | 443 | grid = lambda args: (triton.cdiv(q.shape[2], args["BLOCK_M"]), q.shape[0] * q.shape[1], 1) 444 | M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) 445 | _attn_fwd[grid]( 446 | q, k, v, sm_scale, M, o, # 447 | q.stride(0), q.stride(1), q.stride(2), q.stride(3), # 448 | k.stride(0), k.stride(1), k.stride(2), k.stride(3), # 449 | v.stride(0), v.stride(1), v.stride(2), v.stride(3), # 450 | o.stride(0), o.stride(1), o.stride(2), o.stride(3), # 451 | q.shape[0], q.shape[1], # 452 | N_CTX=q.shape[2], # 453 | HEAD_DIM=HEAD_DIM_K, # 454 | STAGE=stage, # 455 | **extra_kern_args) 456 | 457 | ctx.save_for_backward(q, k, v, o, M) 458 | ctx.grid = grid 459 | ctx.sm_scale = sm_scale 460 | ctx.HEAD_DIM = HEAD_DIM_K 461 | ctx.causal = causal 462 | return o 463 | 464 | @staticmethod 465 | def backward(ctx, do): 466 | q, k, v, o, M = ctx.saved_tensors 467 | assert do.is_contiguous() 468 | assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride() 469 | dq = torch.empty_like(q) 470 | dk = torch.empty_like(k) 471 | dv = torch.empty_like(v) 472 | BATCH, N_HEAD, N_CTX = q.shape[:3] 473 | PRE_BLOCK = 128 474 | NUM_WARPS, NUM_STAGES = 4, 5 475 | BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 476 | BLK_SLICE_FACTOR = 2 477 | RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2) 478 | arg_k = k 479 | arg_k = arg_k * (ctx.sm_scale * RCP_LN2) 480 | PRE_BLOCK = 128 481 | assert N_CTX % PRE_BLOCK == 0 482 | pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD) 483 | delta = torch.empty_like(M) 484 | _attn_bwd_preprocess[pre_grid]( 485 | o, do, # 486 | delta, # 487 | BATCH, N_HEAD, N_CTX, # 488 | BLOCK_M=PRE_BLOCK, HEAD_DIM=ctx.HEAD_DIM # 489 | ) 490 | grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD) 491 | _attn_bwd[grid]( 492 | q, arg_k, v, ctx.sm_scale, do, dq, dk, dv, # 493 | M, delta, # 494 | q.stride(0), q.stride(1), q.stride(2), q.stride(3), # 495 | N_HEAD, N_CTX, # 496 | BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, # 497 | BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2, # 498 | BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, # 499 | HEAD_DIM=ctx.HEAD_DIM, # 500 | 
num_warps=NUM_WARPS, # 501 | num_stages=NUM_STAGES # 502 | ) 503 | 504 | return dq, dk, dv, None, None 505 | 506 | 507 | attention = _attention.apply 508 | 509 | 510 | @pytest.mark.parametrize("Z, H, N_CTX, HEAD_DIM", [(1, 2, 1024, 64)]) 511 | @pytest.mark.parametrize("causal", [True]) 512 | def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16): 513 | torch.manual_seed(20) 514 | q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) 515 | k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) 516 | v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) 517 | sm_scale = 0.5 518 | dout = torch.randn_like(q) 519 | # reference implementation 520 | M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) 521 | p = torch.matmul(q, k.transpose(2, 3)) * sm_scale 522 | if causal: 523 | p[:, :, M == 0] = float("-inf") 524 | p = torch.softmax(p.float(), dim=-1).half() 525 | # p = torch.exp(p) 526 | ref_out = torch.matmul(p, v) 527 | ref_out.backward(dout) 528 | ref_dv, v.grad = v.grad.clone(), None 529 | ref_dk, k.grad = k.grad.clone(), None 530 | ref_dq, q.grad = q.grad.clone(), None 531 | # triton implementation 532 | tri_out = attention(q, k, v, causal, sm_scale).half() 533 | tri_out.backward(dout) 534 | tri_dv, v.grad = v.grad.clone(), None 535 | tri_dk, k.grad = k.grad.clone(), None 536 | tri_dq, q.grad = q.grad.clone(), None 537 | # compare 538 | assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0) 539 | rtol = 0.0 540 | # Relative tolerance workaround for known hardware limitation of MI200 GPU. 541 | # For details see https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices 542 | if torch.version.hip is not None and triton.runtime.driver.active.get_current_target().arch == "gfx90a": 543 | rtol = 1e-2 544 | assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=rtol) 545 | assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=rtol) 546 | assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=rtol) 547 | 548 | 549 | try: 550 | from flash_attn.flash_attn_interface import \ 551 | flash_attn_qkvpacked_func as flash_attn_func 552 | HAS_FLASH = True 553 | except BaseException: 554 | HAS_FLASH = False 555 | 556 | TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2') 557 | BATCH, N_HEADS, HEAD_DIM = 4, 32, 64 558 | # vary seq length for fixed head and batch=4 559 | configs = [] 560 | for mode in ["fwd", "bwd"]: 561 | for causal in [True, False]: 562 | if mode == "bwd" and not causal: 563 | continue 564 | configs.append( 565 | triton.testing.Benchmark( 566 | x_names=["N_CTX"], 567 | x_vals=[2**i for i in range(10, 15)], 568 | line_arg="provider", 569 | line_vals=["triton-fp16"] + (["triton-fp8"] if TORCH_HAS_FP8 else []) + 570 | (["flash"] if HAS_FLASH else []), 571 | line_names=["Triton [FP16]"] + (["Triton [FP8]"] if TORCH_HAS_FP8 else []) + 572 | (["Flash-2"] if HAS_FLASH else []), 573 | styles=[("red", "-"), ("blue", "-"), ("green", "-")], 574 | ylabel="TFLOPS", 575 | plot_name=f"fused-attention-batch{BATCH}-head{N_HEADS}-d{HEAD_DIM}-{mode}-causal={causal}", 576 | args={ 577 | "H": N_HEADS, 578 | "BATCH": BATCH, 579 | "HEAD_DIM": HEAD_DIM, 580 | "mode": mode, 581 | "causal": causal, 582 | }, 583 | )) 584 | 585 | 586 | @triton.testing.perf_report(configs) 587 | def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode,
provider, device="cuda"): 588 | assert mode in ["fwd", "bwd"] 589 | warmup = 25 590 | rep = 100 591 | dtype = torch.float16 592 | if "triton" in provider: 593 | q = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) 594 | k = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) 595 | v = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) 596 | if mode == "fwd" and "fp8" in provider: 597 | q = q.to(torch.float8_e5m2) 598 | k = k.to(torch.float8_e5m2) 599 | v = v.permute(0, 1, 3, 2).contiguous() 600 | v = v.permute(0, 1, 3, 2) 601 | v = v.to(torch.float8_e5m2) 602 | sm_scale = 1.3 603 | fn = lambda: attention(q, k, v, causal, sm_scale) 604 | if mode == "bwd": 605 | o = fn() 606 | do = torch.randn_like(o) 607 | fn = lambda: o.backward(do, retain_graph=True) 608 | ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) 609 | if provider == "flash": 610 | qkv = torch.randn((BATCH, N_CTX, 3, H, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) 611 | fn = lambda: flash_attn_func(qkv, causal=causal) 612 | if mode == "bwd": 613 | o = fn() 614 | do = torch.randn_like(o) 615 | fn = lambda: o.backward(do, retain_graph=True) 616 | ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) 617 | flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * HEAD_DIM 618 | total_flops = 2 * flops_per_matmul 619 | if causal: 620 | total_flops *= 0.5 621 | if mode == "bwd": 622 | total_flops *= 2.5 # 2.0(bwd) + 0.5(recompute) 623 | return total_flops / ms * 1e-9 624 | 625 | 626 | if __name__ == "__main__": 627 | # only works on post-Ampere GPUs right now 628 | bench_flash_attention.run(save_path=".", print_data=True) --------------------------------------------------------------------------------
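Usage sketch: the snippet below is not part of the repository; it is a minimal illustration assembled from the tensor shapes and call pattern in test_benchmark.py. The mask convention (1 = attend, 0 = masked out) and the 1/sqrt(HEAD_DIM) softmax scale are assumptions for illustration only; the benchmark itself hard-codes sm_scale = 1.3, and fa2_custom_mask/fa2_custom_mask.py is the authoritative reference for the interface.

import torch

from fa2_custom_mask import flash_attention_custom_mask

# Shapes mirror test_benchmark.py: (BATCH, H, N_CTX, HEAD_DIM) fp16 tensors on CUDA.
BATCH, H, N_CTX, HEAD_DIM = 4, 32, 1024, 64
dtype, device = torch.float16, "cuda"

q = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True)
k = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True)
v = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True)

# A per-batch 0/1 mask broadcast across heads, as in bench_flash_attention_random_mask.
# Assumed convention: 1 = attend, 0 = masked out (check fa2_custom_mask/fa2_custom_mask.py).
mask = torch.randint(0, 2, (BATCH, 1, N_CTX, N_CTX), dtype=dtype, device=device)
mask = torch.broadcast_to(mask, (BATCH, H, N_CTX, N_CTX))

sm_scale = 1.0 / (HEAD_DIM ** 0.5)  # standard softmax scale; the benchmark uses 1.3
out = flash_attention_custom_mask(q, k, v, mask, sm_scale)  # (BATCH, H, N_CTX, HEAD_DIM)

# Backward pass, following the benchmark's pattern.
do = torch.randn_like(out)
out.backward(do)

As noted in the benchmark entry points, this currently targets post-Ampere CUDA GPUs, with the pinned versions from requirements.txt (triton ~3.0, torch ~2.2).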