├── assets ├── logo │ ├── vertical-blue.png │ ├── vertical-dark.png │ ├── baai-flagopen.jpeg │ ├── horizontal-blue.png │ ├── horizontal-dark.png │ ├── horizontal-light.png │ └── vertical-light.png ├── headdim64-causal-A100.png ├── piecewise_attention.png ├── v0.2 │ ├── flash_attention.png │ ├── flash_attention_d64.png │ └── piecewise_attention.png ├── headdim128-causal-A100.png └── piecewise_attention_interface.png ├── src └── flag_attn │ ├── testing │ ├── __init__.py │ ├── dropout.py │ ├── paged.py │ ├── flash.py │ └── piecewise.py │ ├── __init__.py │ ├── dropout.py │ ├── total.py │ ├── split_kv.py │ ├── paged.py │ ├── piecewise.py │ └── flash.py ├── LICENSE ├── examples ├── use_cutom_config_func.py ├── flash_attention_example.py ├── piecewise_example.py ├── flash_attention_with_aux_outputs.py └── paged_example.py ├── tests └── flag_attn │ ├── test_dropout.py │ ├── test_paged_attention.py │ ├── test_piecewise_attention.py │ └── test_flash_attention.py ├── .pre-commit-config.yaml ├── pyproject.toml ├── .github └── workflows │ └── code-check.yml ├── .gitignore ├── benchmark ├── flash_decoding_benchmark.py ├── flash_benchmark.py ├── piecewise_benchmark.py └── paged_benchmark.py ├── README_cn.md └── README.md /assets/logo/vertical-blue.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flagos-ai/FlagAttention/HEAD/assets/logo/vertical-blue.png -------------------------------------------------------------------------------- /assets/logo/vertical-dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flagos-ai/FlagAttention/HEAD/assets/logo/vertical-dark.png -------------------------------------------------------------------------------- /assets/headdim64-causal-A100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flagos-ai/FlagAttention/HEAD/assets/headdim64-causal-A100.png -------------------------------------------------------------------------------- /assets/logo/baai-flagopen.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flagos-ai/FlagAttention/HEAD/assets/logo/baai-flagopen.jpeg -------------------------------------------------------------------------------- /assets/logo/horizontal-blue.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flagos-ai/FlagAttention/HEAD/assets/logo/horizontal-blue.png -------------------------------------------------------------------------------- /assets/logo/horizontal-dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flagos-ai/FlagAttention/HEAD/assets/logo/horizontal-dark.png -------------------------------------------------------------------------------- /assets/logo/horizontal-light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flagos-ai/FlagAttention/HEAD/assets/logo/horizontal-light.png -------------------------------------------------------------------------------- /assets/logo/vertical-light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flagos-ai/FlagAttention/HEAD/assets/logo/vertical-light.png -------------------------------------------------------------------------------- 
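The tree above is the whole project: Triton kernels under src/flag_attn (flash, piecewise, split_kv, paged, dropout, total) with pure-PyTorch references in flag_attn.testing, runnable scripts in examples/, correctness tests in tests/, and performance scripts in benchmark/. As a rough orientation — a hypothetical usage sketch, not a file from the repository — this is how the entry points re-exported by src/flag_attn/__init__.py (listed further down) are typically called; the (batch, heads, seq_len, head_dim) layout follows the example scripts below, and a CUDA device plus an installed flag_attn are assumed:

import torch
import flag_attn

B, H, M, D = 2, 16, 4096, 128  # batch, heads, sequence length, head dim
q = torch.randn(B, H, M, D, dtype=torch.float16, device="cuda")
k = torch.randn(B, H, M, D, dtype=torch.float16, device="cuda")
v = torch.randn(B, H, M, D, dtype=torch.float16, device="cuda")

o = flag_attn.flash_attention(q, k, v, causal=True)              # Triton kernel
o_ref = flag_attn.testing.flash_attention(q, k, v, causal=True)  # PyTorch reference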
/assets/piecewise_attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flagos-ai/FlagAttention/HEAD/assets/piecewise_attention.png -------------------------------------------------------------------------------- /assets/v0.2/flash_attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flagos-ai/FlagAttention/HEAD/assets/v0.2/flash_attention.png -------------------------------------------------------------------------------- /assets/headdim128-causal-A100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flagos-ai/FlagAttention/HEAD/assets/headdim128-causal-A100.png -------------------------------------------------------------------------------- /assets/v0.2/flash_attention_d64.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flagos-ai/FlagAttention/HEAD/assets/v0.2/flash_attention_d64.png -------------------------------------------------------------------------------- /assets/v0.2/piecewise_attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flagos-ai/FlagAttention/HEAD/assets/v0.2/piecewise_attention.png -------------------------------------------------------------------------------- /assets/piecewise_attention_interface.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flagos-ai/FlagAttention/HEAD/assets/piecewise_attention_interface.png -------------------------------------------------------------------------------- /src/flag_attn/testing/__init__.py: -------------------------------------------------------------------------------- 1 | from flag_attn.testing.flash import attention as flash_attention # noqa: F401 2 | from flag_attn.testing.piecewise import attention as piecewise_attention # noqa: F401 3 | from flag_attn.testing.paged import attention as paged_attention # noqa: F401 4 | from flag_attn.testing.dropout import recompute_mask -------------------------------------------------------------------------------- /src/flag_attn/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | from ._version import version as __version__ 3 | from ._version import version_tuple 4 | except ImportError: 5 | __version__ = "0.0.0" 6 | version_tuple = (0, 0, 0) 7 | 8 | 9 | from flag_attn.piecewise import attention as piecewise_attention # noqa: F401 10 | from flag_attn.flash import attention as flash_attention # noqa: F401 11 | from flag_attn.split_kv import attention as flash_attention_split_kv # noqa: F401 12 | from flag_attn.paged import attention as paged_attention # noqa: F401 13 | 14 | from flag_attn import testing # noqa: F401 15 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2023 BAAI 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | -------------------------------------------------------------------------------- /src/flag_attn/dropout.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import triton 3 | import triton.language as tl 4 | 5 | def philox_cuda_seed_offset(increment, device=None): 6 | device = device or torch.cuda.current_device() 7 | gen = torch.cuda.default_generators[device] 8 | state_copy = gen.get_state() 9 | c0, c1 = state_copy.view(torch.int64) 10 | seed, offset = int(c0), int(c1) 11 | increment = (increment + 3) // 4 * 4 12 | c1 += increment 13 | # get_state returns a new tensor, so it needs set_state to update the actual generator state. 14 | gen.set_state(state_copy) 15 | return seed, offset 16 | -------------------------------------------------------------------------------- /examples/use_cutom_config_func.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import flag_attn 3 | 4 | 5 | # replace the default config function 6 | from flag_attn import flash 7 | def get_fwd_config(B, H, M, N, D, causal): 8 | return (64, 64, 1, 4) 9 | flash.get_fwd_config = get_fwd_config 10 | 11 | B, H, M, N, D = 2, 16, 4096, 4096, 128 12 | causal = True 13 | 14 | q = torch.randn(B, H, M, D, dtype=torch.bfloat16, device="cuda:0", requires_grad=True) 15 | k = torch.randn(B, H, N, D, dtype=torch.bfloat16, device="cuda:0", requires_grad=True) 16 | v = torch.randn(B, H, N, D, dtype=torch.bfloat16, device="cuda:0", requires_grad=True) 17 | 18 | o = flag_attn.flash_attention(q, k, v, causal=causal) 19 | go = torch.randn_like(o) 20 | gq, gk, gv = torch.autograd.grad(o, (q, k, v), go) 21 | -------------------------------------------------------------------------------- /examples/flash_attention_example.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import flag_attn 3 | 4 | B, H, M, N, D = 2, 16, 4000, 4000, 128 5 | causal = True 6 | 7 | q = torch.randn(B, H, M, D, dtype=torch.bfloat16, device="cuda:0", requires_grad=True) 8 | k = torch.randn(B, H, N, D, dtype=torch.bfloat16, device="cuda:0", requires_grad=True) 9 | v = torch.randn(B, H, N, D, dtype=torch.bfloat16, device="cuda:0", requires_grad=True) 10 | 11 | 12 | o_ref = flag_attn.testing.flash_attention(q, k, v, causal=causal, upcast=True) 13 | o = flag_attn.flash_attention(q, k, v, causal=causal) 14 | o_torch = flag_attn.testing.flash_attention(q, k, v, causal=causal) 15 | 16 | go = torch.randn_like(o) 17 | gq_ref, gk_ref, gv_ref = torch.autograd.grad(o_ref, (q, k, v), go) 18 | gq, gk, gv = torch.autograd.grad(o, (q, k, v), go) 19 | gq_torch, gk_torch, gv_torch = torch.autograd.grad(o_torch, (q, k, v), go) 20 | -------------------------------------------------------------------------------- /examples/piecewise_example.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from flag_attn import piecewise_attention 3 | 4 | B, H, T, D = 2, 16, 8192, 128 5 | dist_threshold = T // 2 6 | 7 | q1 = torch.randn((B, H, T, D), 
dtype=torch.float16, device="cuda:0").requires_grad_()
8 | q2 = torch.randn((B, H, T, D), dtype=torch.float16, device="cuda:0").requires_grad_()
9 | k1 = torch.randn((B, H, T, D), dtype=torch.float16, device="cuda:0").requires_grad_()
10 | k2 = torch.randn((B, H, T, D), dtype=torch.float16, device="cuda:0").requires_grad_()
11 | v = torch.randn((B, H, T, D), dtype=torch.float16, device="cuda:0").requires_grad_()
12 | o = piecewise_attention(q1, k1, q2, k2, v, dist_threshold, causal=True)
13 | print(o)
14 |
15 | go = torch.randn((B, H, T, D), dtype=torch.float16, device="cuda:0")
16 | gq1, gk1, gq2, gk2, gv = torch.autograd.grad(
17 |     o, (q1, k1, q2, k2, v), go
18 | )
19 | print(gq1)
20 |
21 |
22 |
--------------------------------------------------------------------------------
/examples/flash_attention_with_aux_outputs.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import flag_attn
3 |
4 |
5 | B, H, M, N, D = 2, 4, 10, 20, 128
6 | causal = False
7 | q = torch.randn((B, H, M, D), dtype=torch.bfloat16, device="cuda", requires_grad=True)
8 | k = torch.randn((B, H, N, D), dtype=torch.bfloat16, device="cuda", requires_grad=True)
9 | v = torch.randn((B, H, N, D), dtype=torch.bfloat16, device="cuda", requires_grad=True)
10 |
11 | o, logz, tot_attn = flag_attn.flash_attention(
12 |     q, k, v, causal=causal, return_log_normalizer=True, return_total_attention=True)
13 | o_ref, logz_ref, tot_attn_ref = flag_attn.testing.flash_attention(
14 |     q, k, v, causal=causal,
15 |     return_log_normalizer=True, return_total_attention=True, upcast=True)
16 |
17 | print("log normalizer")
18 | print(logz[0, 0])
19 | print(logz_ref[0, 0])
20 |
21 | print("total attention")
22 | print(tot_attn[0, 0])
23 | print(tot_attn_ref[0, 0])
24 |
--------------------------------------------------------------------------------
/tests/flag_attn/test_dropout.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pytest
3 | from flag_attn.testing import recompute_mask
4 |
5 |
6 | @pytest.mark.parametrize('B, H, M, N', [
7 |     (2, 4, 512, 612),
8 |     (2, 4, 1024, 1034),
9 |     (2, 4, 2048, 2048),
10 |     (2, 4, 4096, 4096),
11 |     (2, 4, 4001, 4001),
12 |     (2, 4, 4001, 4096),
13 |     (2, 4, 4096, 4000),
14 |     (1, 2, 8192, 8202),
15 |     (1, 2, 8192, 8192),
16 | ])
17 | @pytest.mark.parametrize('p', [0.5, 0.8])
18 | def test_recompute_mask(B, H, M, N, p):
19 |     import math
20 |     seed = 123456789
21 |     offset = 123456789123456789
22 |     device = torch.cuda.current_device()
23 |     mask = recompute_mask(B, H, M, N, p, seed, offset, device)
24 |     # zeros indicate dropped positions
25 |     # k follows the Binomial distribution B(k; n, p)
26 |     n = mask.numel()
27 |     k = torch.sum(mask == 0)
28 |     p_cap = k / n
29 |     tol = 0.01
30 |     assert math.fabs(p_cap - p) < tol * p
31 |
--------------------------------------------------------------------------------
/src/flag_attn/testing/dropout.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import triton
3 | import triton.language as tl
4 |
5 | @triton.jit
6 | def recompute_mask_kernel(mask, B, H, M, N, dropout_p, seed, offset):
7 |     row, b, h = tl.program_id(0), tl.program_id(1), tl.program_id(2)
8 |     offs_base = b * H * M * N + h * M * N + row * N
9 |     BLOCK: tl.constexpr = 1024
10 |     offs_base += tl.arange(0, BLOCK)
11 |     for start_n in range(0, N, BLOCK):
12 |         offs = start_n + offs_base
13 |         rng_offs = offset + offs
14 |         pmask = tl.rand(seed, rng_offs, n_rounds=6) > dropout_p
15 |
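# tl.rand generates Philox-based uniforms at rng_offs = offset + flat_index, so a given
# (seed, offset) pair deterministically reproduces the same pattern: positions whose draw
# exceeds dropout_p stay True (kept), the rest become False (dropped), and the masked
# store below only writes columns that fall inside N.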
row_mask = start_n + tl.arange(0, BLOCK) < N 16 | tl.store(mask + offs, pmask, mask=row_mask) 17 | 18 | def recompute_mask(B, H, M, N, dropout_p, seed, offset, device): 19 | mask = torch.full((B, H, M, N), True, dtype=torch.bool, device=device) 20 | if dropout_p == 0: 21 | return mask 22 | grid = (M, B, H) 23 | with torch.cuda.device(device): 24 | recompute_mask_kernel[grid](mask, B, H, M, N, dropout_p, seed, offset) 25 | return mask 26 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | files: '^src/.*' 2 | repos: 3 | - repo: https://github.com/pre-commit/pre-commit-hooks 4 | rev: v4.5.0 5 | hooks: 6 | - id: trailing-whitespace 7 | - id: end-of-file-fixer 8 | - id: check-yaml 9 | - id: check-toml 10 | - id: check-ast 11 | - id: check-added-large-files 12 | - id: check-merge-conflict 13 | - id: check-executables-have-shebangs 14 | - id: check-shebang-scripts-are-executable 15 | - id: detect-private-key 16 | - id: debug-statements 17 | 18 | # - repo: https://github.com/google/yapf 19 | # rev: v0.40.2 20 | # hooks: 21 | # - id: yapf 22 | # args: ["-p", "-i"] 23 | # stages: [commit, push, manual] 24 | 25 | # - repo: https://github.com/pylint-dev/pylint 26 | # rev: v3.0.3 27 | # hooks: 28 | # - id: pylint 29 | 30 | 31 | - repo: https://github.com/astral-sh/ruff-pre-commit 32 | rev: v0.1.14 33 | hooks: 34 | - id: ruff 35 | args: ["--fix"] 36 | stages: [commit, push, manual] 37 | # - id: ruff-format 38 | # stages: [commit, push, manual] 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "flag_attn" 3 | dynamic = ["version"] 4 | authors = [ 5 | {name = "Chen Feiyu", email = "iclementine@outlook.com"}, 6 | ] 7 | description = "A collection of memory efficient attention operators implemented in triton language." 8 | readme = {file= "README.md", content-type="text/markdown"} 9 | requires-python = ">=3.7" 10 | license = {text = "LICENSE.txt"} 11 | classifiers = [ 12 | "Development Status :: 3 - Alpha", 13 | "Programming Language :: Python :: 3", 14 | "License :: OSI Approved :: Apache Software License", 15 | ] 16 | 17 | # Not specifing triton version here because torch has its own required triton version 18 | # FlagAttention needs a recent version of triton (triton nightly or 2.2.0) to run. 
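# A typical setup (it is also what .github/workflows/code-check.yml below uses) is an
# editable install that leaves the torch-provided triton untouched:
#     pip install --no-dependencies -e .
# A plain `pip install .` additionally resolves the triton>=2.2.0 requirement declared below.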
19 | dependencies = [ 20 | "triton>=2.2.0" 21 | ] 22 | 23 | [project.optional-dependencies] 24 | test = [ 25 | "pytest>=7.1.0", 26 | ] 27 | 28 | [project.urls] 29 | homepage = "https://github.com/FlagOpen/FlagAttention" 30 | 31 | 32 | [build-system] 33 | requires = ["setuptools>=60", "setuptools-scm>=8.0"] 34 | build-backend = "setuptools.build_meta" 35 | 36 | [tool.setuptools_scm] 37 | version_file = "src/flag_attn/_version.py" 38 | 39 | [tool.setuptools.packages.find] 40 | where = ["src"] 41 | include = ["flag_attn"] 42 | namespaces = false 43 | 44 | # helps for setting up pytest in pyprojects 45 | # https://docs.pytest.org/en/7.3.x/reference/customize.html#rootdir 46 | # https://docs.pytest.org/en/7.3.x/reference/reference.html#confval-pythonpath 47 | [tool.pytest.ini_options] 48 | testpaths = [ 49 | "tests", 50 | ] 51 | pythonpath = [ 52 | "src", 53 | "tests/flag_attn", 54 | ] 55 | 56 | [tool.ruff] 57 | ignore = ["E741"] 58 | line-length = 120 59 | -------------------------------------------------------------------------------- /.github/workflows/code-check.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: Python application 5 | 6 | on: 7 | push: 8 | branches: [ "main" ] 9 | pull_request: 10 | branches: [ "main" ] 11 | 12 | permissions: 13 | contents: read 14 | 15 | jobs: 16 | build: 17 | runs-on: self-hosted 18 | steps: 19 | - uses: actions/checkout@v4 20 | with: 21 | ssh-key: ${{ secrets.SSH_PRIVATE_KEY }} 22 | 23 | # - name: Set up Python 3.10 24 | # uses: actions/setup-python@v3 25 | # with: 26 | # python-version: "3.10" 27 | 28 | # - name: Install dependencies 29 | # run: | 30 | # python -m pip install --upgrade pip 31 | # pip install flake8 pytest 32 | # if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 33 | # pip install . 34 | 35 | - name: Activate Virtualenv 36 | run: | 37 | source /home/flagattn_ci/.virtualenvs/release/bin/activate 38 | echo PATH=$PATH >> $GITHUB_ENV 39 | 40 | - name: Editable Install 41 | run: | 42 | pip install --no-dependencies -e . 43 | 44 | # - name: Lint with flake8 45 | # run: | 46 | # # stop the build if there are Python syntax errors or undefined names 47 | # flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 48 | # # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 49 | # flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 50 | - name: Test with pytest 51 | run: | 52 | pytest tests -------------------------------------------------------------------------------- /src/flag_attn/testing/paged.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | 5 | def attention( 6 | query: torch.Tensor, 7 | key_cache: torch.Tensor, 8 | value_cache: torch.Tensor, 9 | block_tables: torch.Tensor, 10 | context_lens: torch.Tensor, 11 | scale: float, 12 | ) -> None: 13 | output = torch.empty_like(query) 14 | 15 | num_query_heads = query.shape[1] 16 | num_kv_heads = value_cache.shape[1] 17 | num_queries_per_kv = num_query_heads // num_kv_heads 18 | block_size = value_cache.shape[2] 19 | head_size = value_cache.shape[3] 20 | num_seqs = query.shape[0] 21 | 22 | block_tables = block_tables.cpu().tolist() 23 | context_lens = context_lens.cpu().tolist() 24 | for i in range(num_seqs): 25 | q = query[i].unsqueeze(0) 26 | block_table = block_tables[i] 27 | context_len = int(context_lens[i]) 28 | 29 | keys = [] 30 | values = [] 31 | for j in range(context_len): 32 | block_number = int(block_table[j // block_size]) 33 | block_offset = j % block_size 34 | k = key_cache[block_number, :, block_offset, :] 35 | keys.append(k) 36 | v = value_cache[block_number, :, block_offset, :] 37 | values.append(v) 38 | keys = torch.stack(keys, dim=0) 39 | values = torch.stack(values, dim=0) 40 | if num_queries_per_kv > 1: 41 | # Handle MQA and GQA 42 | keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) 43 | values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) 44 | 45 | S = torch.bmm(q.transpose(0, 1).float(), keys.permute(1, 2, 0).float()) * scale 46 | P = torch.softmax(S, dim=-1) 47 | out = torch.bmm(P, values.transpose(0, 1).float()).transpose(0, 1) 48 | out = out.to(values.dtype) 49 | out = out.view(num_query_heads, head_size) 50 | output[i].copy_(out, non_blocking=True) 51 | 52 | return output 53 | -------------------------------------------------------------------------------- /examples/paged_example.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Tuple 3 | 4 | import flag_attn 5 | 6 | MAX_SEQ_LEN = 4096 7 | NUM_BLOCKS = 2000 8 | 9 | 10 | def test_paged_attention( 11 | num_seqs: int, 12 | num_heads: Tuple[int, int], 13 | head_size: int, 14 | block_size: int, 15 | dtype: torch.dtype, 16 | seed: int, 17 | device: int, 18 | ): 19 | torch.set_default_dtype(dtype) 20 | torch.set_default_device(device=device) 21 | 22 | torch.cuda.manual_seed(seed) 23 | 24 | num_query_heads, num_kv_heads = num_heads 25 | 26 | context_lens = torch.randint(1, MAX_SEQ_LEN, [num_seqs], dtype=torch.int32) 27 | max_context_len = context_lens.max().item() 28 | max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size 29 | 30 | attn_scale = head_size**-0.5 31 | q = torch.empty(num_seqs, num_query_heads, head_size) 32 | q.uniform_(-attn_scale, attn_scale) 33 | 34 | k_cache = torch.empty(NUM_BLOCKS, num_kv_heads, block_size, head_size) 35 | k_cache.uniform_(-attn_scale, attn_scale) 36 | v_cache = torch.empty_like(k_cache) 37 | v_cache.uniform_(-attn_scale, attn_scale) 38 | 39 | # (NUM_SEQS, MAX_NUM_BLOCKS_PER_SEQ) 40 | block_tables = torch.randint(0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq)) 41 | 42 | out = flag_attn.paged_attention( 43 | q, 44 | k_cache, 45 | v_cache, 46 | context_lens, 47 | block_tables, 48 | 
attn_scale, 49 | max_context_len, 50 | ) 51 | 52 | ref_out = flag_attn.testing.paged_attention( 53 | q, 54 | k_cache, 55 | v_cache, 56 | block_tables, 57 | context_lens, 58 | attn_scale, 59 | ) 60 | print(torch.abs(out - ref_out).max()) 61 | assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-5) 62 | 63 | 64 | def main(): 65 | test_paged_attention( 66 | num_seqs=32, 67 | num_heads=(64, 64), 68 | head_size=64, 69 | block_size=16, 70 | dtype=torch.float16, 71 | seed=1, 72 | device="cuda:0", 73 | ) 74 | 75 | 76 | if __name__ == "__main__": 77 | main() 78 | -------------------------------------------------------------------------------- /src/flag_attn/testing/flash.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | 5 | def attention(q, 6 | k, 7 | v, 8 | causal, 9 | dropout_p=0.0, 10 | dropout_mask=None, 11 | sm_scale=None, 12 | return_log_normalizer=False, 13 | return_total_attention=False, 14 | upcast=False): 15 | input_dtype = q.dtype 16 | if upcast: 17 | q, k, v = q.float(), k.float(), v.float() 18 | # (B, H, T, D) 19 | D = q.shape[-1] 20 | if sm_scale is None: 21 | sm_scale = 1. / math.sqrt(D) 22 | 23 | num_heads_q = q.shape[1] 24 | num_heads_k = k.shape[1] 25 | assert num_heads_q % num_heads_k == 0 26 | num_groups = num_heads_q // num_heads_k 27 | 28 | if num_groups > 1: 29 | k = torch.repeat_interleave(k, repeats=num_groups, dim=1) 30 | v = torch.repeat_interleave(v, repeats=num_groups, dim=1) 31 | kv_seq_len = k.shape[-2] 32 | q_seq_len = q.shape[-2] 33 | p_seq = kv_seq_len - q_seq_len 34 | device = q.device 35 | 36 | ms = torch.arange(q_seq_len, device=device).unsqueeze(-1) 37 | ns = torch.arange(kv_seq_len, device=device) 38 | 39 | S = torch.matmul(q, k.transpose(2, 3)) * sm_scale 40 | if causal: 41 | S = torch.where(ms + p_seq >= ns, S, float("-inf")) 42 | 43 | S = S.to(torch.float32) 44 | if return_log_normalizer: 45 | log_normalizer = torch.logsumexp(S, dim=-1) 46 | 47 | # upcast attention to fp32 48 | P = torch.softmax(S, dim=-1, dtype=torch.float32) 49 | if causal: 50 | P = torch.where(ms + p_seq >= ns, P, 0.0) 51 | 52 | if return_total_attention: 53 | tot_attn = torch.sum(P, dim=-2) 54 | 55 | # Applies dropout 56 | dropout_scaling = 1.0 / (1 - dropout_p) 57 | if dropout_mask is not None: 58 | P = P.masked_fill(~dropout_mask, 0.0) 59 | 60 | attn_output = torch.matmul(P.to(v.dtype), v) * dropout_scaling 61 | attn_output = attn_output.to(input_dtype) 62 | 63 | has_extra_return = return_log_normalizer or return_total_attention 64 | if has_extra_return: 65 | outs = (attn_output, 66 | log_normalizer if return_log_normalizer else None, 67 | tot_attn if return_total_attention else None) 68 | return outs 69 | else: 70 | return attn_output 71 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # results of benchmark 2 | benchmark/results*/ 3 | playground/ 4 | 5 | # version, since setuptools-scm is used, this file is automatic generated when building the package 6 | src/flag_attn/_version.py 7 | 8 | # Editors 9 | .vscode/ 10 | .idea/ 11 | 12 | # Vagrant 13 | .vagrant/ 14 | 15 | # Mac/OSX 16 | .DS_Store 17 | 18 | # Windows 19 | Thumbs.db 20 | 21 | # Source for the following rules: https://raw.githubusercontent.com/github/gitignore/master/Python.gitignore 22 | # Byte-compiled / optimized / DLL files 23 | __pycache__/ 24 | *.py[cod] 25 | *$py.class 26 | 27 | # C extensions 28 | *.so 29 | 
30 | # Distribution / packaging 31 | .Python 32 | build/ 33 | develop-eggs/ 34 | dist/ 35 | downloads/ 36 | eggs/ 37 | .eggs/ 38 | lib/ 39 | lib64/ 40 | parts/ 41 | sdist/ 42 | var/ 43 | wheels/ 44 | *.egg-info/ 45 | .installed.cfg 46 | *.egg 47 | MANIFEST 48 | 49 | # PyInstaller 50 | # Usually these files are written by a python script from a template 51 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 52 | *.manifest 53 | *.spec 54 | 55 | # Installer logs 56 | pip-log.txt 57 | pip-delete-this-directory.txt 58 | 59 | # Unit test / coverage reports 60 | htmlcov/ 61 | .tox/ 62 | .nox/ 63 | .coverage 64 | .coverage.* 65 | .cache 66 | nosetests.xml 67 | coverage.xml 68 | *.cover 69 | .hypothesis/ 70 | .pytest_cache/ 71 | 72 | # Translations 73 | *.mo 74 | *.pot 75 | 76 | # Django stuff: 77 | *.log 78 | local_settings.py 79 | db.sqlite3 80 | 81 | # Flask stuff: 82 | instance/ 83 | .webassets-cache 84 | 85 | # Scrapy stuff: 86 | .scrapy 87 | 88 | # Sphinx documentation 89 | docs/_build/ 90 | 91 | # PyBuilder 92 | target/ 93 | 94 | # Jupyter Notebook 95 | .ipynb_checkpoints 96 | 97 | # IPython 98 | profile_default/ 99 | ipython_config.py 100 | 101 | # pyenv 102 | .python-version 103 | 104 | # celery beat schedule file 105 | celerybeat-schedule 106 | 107 | # SageMath parsed files 108 | *.sage.py 109 | 110 | # Environments 111 | .env 112 | .venv 113 | env/ 114 | venv/ 115 | ENV/ 116 | env.bak/ 117 | venv.bak/ 118 | 119 | # Spyder project settings 120 | .spyderproject 121 | .spyproject 122 | 123 | # Rope project settings 124 | .ropeproject 125 | 126 | # mkdocs documentation 127 | /site 128 | 129 | # mypy 130 | .mypy_cache/ 131 | .dmypy.json 132 | dmypy.json 133 | 134 | -------------------------------------------------------------------------------- /benchmark/flash_decoding_benchmark.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import pathlib 4 | import torch 5 | import triton 6 | 7 | import flag_attn 8 | 9 | 10 | try: 11 | from flash_attn import flash_attn_func 12 | FLASH_VER = 2 13 | except BaseException: 14 | try: 15 | from flash_attn.flash_attn_interface import flash_attn_func 16 | FLASH_VER = 1 17 | except BaseException: 18 | FLASH_VER = None 19 | HAS_FLASH = FLASH_VER is not None 20 | 21 | 22 | configs = [triton.testing.Benchmark( 23 | x_names=['N_CTX'], 24 | x_vals=[2**i for i in range(9, 20)], 25 | line_arg='provider', 26 | line_vals=['flag_attn', 'torch', ] + (['flash'] if HAS_FLASH else []), 27 | line_names=['flag_attn', 'torch', ] + ([f'flash-{FLASH_VER}'] if HAS_FLASH else []), 28 | styles=[('red', '-'), ('green', '-'), ('blue', '-'), ('cyan', '-')], 29 | ylabel='tflop/s', 30 | plot_name=f'attention_d-{D_HEAD}_dtype-{dtype} (ms)', 31 | args={'D_HEAD': D_HEAD, 'dtype': dtype} 32 | ) for D_HEAD in [64, 128] 33 | for dtype in [torch.float16]] 34 | 35 | @triton.testing.perf_report(configs) 36 | def bench_flash_attention(N_CTX, D_HEAD, provider, dtype=torch.float16): 37 | BATCH = 2 38 | H = 2048 // D_HEAD 39 | causal = False 40 | if provider == "flag_attn": 41 | q = torch.randn((BATCH, H, 1, D_HEAD), dtype=dtype, device="cuda") 42 | k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda") 43 | v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda") 44 | fn = lambda: flag_attn.flash_attention(q, k, v, causal=causal) 45 | ms = triton.testing.do_bench(fn) 46 | if provider == "torch": 47 | q = torch.randn((BATCH, H, 1, D_HEAD), dtype=dtype, 
device="cuda") 48 | k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda") 49 | v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda") 50 | try: 51 | fn = lambda: flag_attn.testing.flash_attention(q, k, v, causal=causal, upcast=False) 52 | ms = triton.testing.do_bench(fn) 53 | except torch.cuda.OutOfMemoryError as e: 54 | logging.info(f"torch OOM for batch_size: {BATCH}, num_heads: {H}, seqlen: {N_CTX}, headdim: {D_HEAD}") 55 | ms = float("inf") 56 | if provider == "flash": 57 | q = torch.randn((BATCH, 1, H, D_HEAD), dtype=dtype, device="cuda") 58 | k = torch.randn((BATCH, N_CTX, H, D_HEAD), dtype=dtype, device="cuda") 59 | v = torch.randn((BATCH, N_CTX, H, D_HEAD), dtype=dtype, device="cuda") 60 | fn = lambda: flash_attn_func(q, k, v, causal=causal) 61 | ms = triton.testing.do_bench(fn) 62 | 63 | return ms 64 | # # total TFLOPS: following Flash Attention v2, only gemms are counted. 65 | # macs = 2. * BATCH * H * N_CTX * D_HEAD # Q@K, P@V 66 | # total_flops = 2 * macs 67 | # return total_flops / ms * 1e-9 68 | 69 | # only works on post-Ampere GPUs right now 70 | today = datetime.date.today().strftime(format("%Y%m%d")) 71 | output_dir = pathlib.Path(f"results_flash_attention_with_split_kv_{today}") 72 | output_dir.mkdir(exist_ok=True) 73 | bench_flash_attention.run(save_path=output_dir, print_data=True) 74 | -------------------------------------------------------------------------------- /tests/flag_attn/test_paged_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | 4 | import flag_attn 5 | 6 | NUM_BLOCKS = 1000 7 | 8 | 9 | def base_paged_attention( 10 | num_seqs, 11 | num_query_heads, 12 | query_group_size, 13 | head_size, 14 | block_size, 15 | max_seq_len, 16 | num_splits=0, 17 | dtype=torch.float16, 18 | device="cuda", 19 | ): 20 | torch.set_default_dtype(dtype) 21 | torch.set_default_device(device=device) 22 | 23 | num_kv_heads = num_query_heads // query_group_size 24 | 25 | context_lens = torch.randint(1, max_seq_len, [num_seqs], dtype=torch.int32) 26 | context_lens[0] = max_seq_len 27 | max_context_len = context_lens.max().item() 28 | max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size 29 | 30 | attn_scale = head_size**-0.5 31 | q = torch.empty(num_seqs, num_query_heads, head_size) 32 | q.uniform_(-attn_scale, attn_scale) 33 | 34 | k_cache = torch.empty(NUM_BLOCKS, num_kv_heads, block_size, head_size) 35 | k_cache.uniform_(-attn_scale, attn_scale) 36 | v_cache = torch.empty_like(k_cache) 37 | v_cache.uniform_(-attn_scale, attn_scale) 38 | 39 | # (NUM_SEQS, MAX_NUM_BLOCKS_PER_SEQ) 40 | block_tables = torch.randint(0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq)) 41 | 42 | out = flag_attn.paged_attention( 43 | q, 44 | k_cache, 45 | v_cache, 46 | context_lens, 47 | block_tables, 48 | attn_scale, 49 | max_context_len, 50 | num_splits, 51 | ) 52 | 53 | ref_out = flag_attn.testing.paged_attention( 54 | q, 55 | k_cache, 56 | v_cache, 57 | block_tables, 58 | context_lens, 59 | attn_scale, 60 | ) 61 | print(torch.abs(out - ref_out).max()) 62 | assert torch.allclose(out, ref_out, atol=2e-3, rtol=1e-5) 63 | 64 | 65 | @pytest.mark.parametrize("num_seqs", [1, 32]) 66 | @pytest.mark.parametrize("num_query_heads", [64]) 67 | @pytest.mark.parametrize("query_group_size", [1, 8]) 68 | @pytest.mark.parametrize("head_size", [64, 128]) 69 | @pytest.mark.parametrize("block_size", [16, 128, 256]) 70 | @pytest.mark.parametrize("max_seq_len", [512, 4096]) 71 | def 
test_paged_attention_default( 72 | num_seqs, 73 | num_query_heads, 74 | query_group_size, 75 | head_size, 76 | block_size, 77 | max_seq_len, 78 | dtype=torch.float16, 79 | device="cuda", 80 | ): 81 | base_paged_attention( 82 | num_seqs, 83 | num_query_heads, 84 | query_group_size, 85 | head_size, 86 | block_size, 87 | max_seq_len, 88 | ) 89 | 90 | 91 | @pytest.mark.parametrize("num_seqs", [1, 16]) 92 | @pytest.mark.parametrize("num_query_heads", [64]) 93 | @pytest.mark.parametrize("query_group_size", [1, 8]) 94 | @pytest.mark.parametrize("head_size", [32, 64]) 95 | @pytest.mark.parametrize("block_size", [16]) 96 | @pytest.mark.parametrize("max_seq_len", [2048]) 97 | @pytest.mark.parametrize("num_splits", [1, 2, 3, 4, 5, 6, 7, 8]) 98 | def test_paged_attention_by_num_splits( 99 | num_seqs, 100 | num_query_heads, 101 | query_group_size, 102 | head_size, 103 | block_size, 104 | max_seq_len, 105 | num_splits, 106 | dtype=torch.float16, 107 | device="cuda", 108 | ): 109 | base_paged_attention( 110 | num_seqs, 111 | num_query_heads, 112 | query_group_size, 113 | head_size, 114 | block_size, 115 | max_seq_len, 116 | num_splits=num_splits, 117 | ) 118 | 119 | @pytest.mark.parametrize("num_seqs, num_query_heads, query_group_size, head_size, block_size, max_seq_len, num_splits", [ 120 | (1, 12, 1, 64, 16, 2, 0), 121 | (16, 64, 8, 32, 16, 2048, 2), 122 | (16, 64, 1, 64, 16, 2048, 6), 123 | ]) 124 | def test_paged_attention_by_case( 125 | num_seqs, 126 | num_query_heads, 127 | query_group_size, 128 | head_size, 129 | block_size, 130 | max_seq_len, 131 | num_splits, 132 | dtype=torch.float16, 133 | device="cuda", 134 | ): 135 | base_paged_attention( 136 | num_seqs, 137 | num_query_heads, 138 | query_group_size, 139 | head_size, 140 | block_size, 141 | max_seq_len, 142 | num_splits, 143 | ) 144 | -------------------------------------------------------------------------------- /src/flag_attn/testing/piecewise.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import pytest 4 | 5 | def attention(q1, k1, q2, k2, v, dist_threshold, causal, sm_scale=None, upcast=False): 6 | input_dtype = q1.dtype 7 | if upcast: 8 | q1, k1, q2, k2, v = q1.float(), k1.float(), q2.float(), k2.float(), v.float() 9 | # (B, H, T, D) 10 | D = q1.shape[-1] 11 | if sm_scale is None: 12 | sm_scale = 1. 
/ math.sqrt(D) 13 | kv_seq_len = k1.shape[-2] 14 | q_seq_len = q1.shape[-2] 15 | p_seq = kv_seq_len - q_seq_len 16 | device = q1.device 17 | 18 | ms = torch.arange(q_seq_len, device=device).unsqueeze(-1) 19 | ns = torch.arange(kv_seq_len, device=device) 20 | 21 | S1 = torch.matmul(q1, k1.transpose(2, 3)) 22 | S2 = torch.matmul(q2, k2.transpose(2, 3)) 23 | long_distance = ((ms + p_seq - ns) >= dist_threshold) 24 | S = torch.where(long_distance, S2, S1) * sm_scale 25 | 26 | if causal: 27 | S = torch.where(ms + p_seq >= ns, S, torch.finfo(S.dtype).min) 28 | 29 | # upcast attention to fp32 30 | P = torch.softmax(S, dim=-1, dtype=torch.float32).to(v.dtype) 31 | if causal: 32 | P = torch.where(ms + p_seq >= ns, P, 0.0) 33 | attn_output = torch.matmul(P, v) 34 | return attn_output.to(input_dtype) 35 | 36 | def attention_grad(q1, k1, q2, k2, v, w, causal, sm_scale, o, do, upcast=False): 37 | input_dtype = q1.dtype 38 | 39 | if upcast: 40 | q1, k1, q2, k2, v, o, do = [item.float() for item in [q1, k1, q2, k2, v, o, do]] 41 | kv_seq_len = k1.shape[-2] 42 | q_seq_len = q1.shape[-2] 43 | p_seq = kv_seq_len - q_seq_len 44 | device = q1.device 45 | 46 | ms = torch.arange(q_seq_len, device=device).unsqueeze(-1) 47 | ns = torch.arange(kv_seq_len, device=device) 48 | 49 | S1 = torch.matmul(q1, k1.transpose(2, 3)) 50 | S2 = torch.matmul(q2, k2.transpose(2, 3)) 51 | long_distance = ((ms + p_seq - ns) >= w) 52 | S = torch.where(long_distance, S2, S1) * sm_scale 53 | 54 | if causal: 55 | S = torch.where((ms + p_seq) >= ns, S, torch.finfo(S.dtype).min) 56 | 57 | # upcast attention to fp32 58 | P = torch.softmax(S, dim=-1, dtype=torch.float32).to(v.dtype) 59 | 60 | # dP & dv 61 | dv = torch.matmul(P.transpose(2, 3), do) 62 | dP = torch.matmul(do, v.transpose(2, 3)) 63 | 64 | # dS 65 | delta = (do * o).sum(-1, keepdim=True) # (B,H,T) 66 | dS = P * (dP - delta) * sm_scale 67 | dS2 = torch.where(long_distance, dS, 0.0) 68 | dS1 = torch.where(long_distance, 0.0, dS) 69 | 70 | # dq & dk 71 | dq1 = torch.matmul(dS1, k1) 72 | dk1 = torch.matmul(dS1.transpose(2, 3), q1) 73 | 74 | dq2 = torch.matmul(dS2, k2) 75 | dk2 = torch.matmul(dS2.transpose(2, 3), q2) 76 | 77 | dq1, dk1, dq2, dk2, dv = [item.to(input_dtype) for item in [dq1, dk1, dq2, dk2, dv]] 78 | return dq1, dk1, dq2, dk2, dv 79 | 80 | @pytest.mark.parametrize('B, H, T, D, P_SEQ', [(2, 3, 1024, 32, 100), (2, 3, 1024, 32, 0)]) 81 | @pytest.mark.parametrize('causal', [True, False]) 82 | @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) 83 | def test_op(B, H, T, D, P_SEQ, causal, dtype): 84 | q1 = torch.empty((B, H, T, D), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() 85 | q2 = torch.empty((B, H, T, D), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() 86 | k1 = torch.empty((B, H, T + P_SEQ, D), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() 87 | k2 = torch.empty((B, H, T + P_SEQ, D), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() 88 | v = torch.empty((B, H, T + P_SEQ, D), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() 89 | sm_scale = 0.5 90 | w = 780 91 | 92 | o = attention(q1, k1, q2, k2, v, w, causal, sm_scale) 93 | do = torch.empty((B, H, T, D), dtype=dtype, device="cuda").normal_(mean=0., std=0.5) 94 | o.backward(do) 95 | dq1, dk1, dq2, dk2, dv = attention_grad(q1, k1, q2, k2, v, w, causal, sm_scale, o, do) 96 | 97 | torch.testing.assert_close(dv, v.grad, atol=1e-2, rtol=0.0) 98 | torch.testing.assert_close(dq1, q1.grad, 
atol=1e-2, rtol=0.0) 99 | torch.testing.assert_close(dk1, k1.grad, atol=1e-2, rtol=0.0) 100 | torch.testing.assert_close(dq2, q2.grad, atol=1e-2, rtol=0.0) 101 | torch.testing.assert_close(dk2, k2.grad, atol=1e-2, rtol=0.0) 102 | -------------------------------------------------------------------------------- /benchmark/flash_benchmark.py: -------------------------------------------------------------------------------- 1 | import math 2 | import datetime 3 | import logging 4 | import pathlib 5 | import torch 6 | import triton 7 | 8 | import flag_attn 9 | 10 | 11 | try: 12 | from flash_attn import flash_attn_func 13 | FLASH_VER = 2 14 | except BaseException: 15 | try: 16 | from flash_attn.flash_attn_interface import flash_attn_func 17 | FLASH_VER = 1 18 | except BaseException: 19 | FLASH_VER = None 20 | HAS_FLASH = FLASH_VER is not None 21 | 22 | 23 | configs = [triton.testing.Benchmark( 24 | x_names=['N_CTX'], 25 | x_vals=[2**i for i in range(9, 16)], 26 | line_arg='provider', 27 | line_vals=['flag_attn', 'torch', ] + (['flash'] if HAS_FLASH else []), 28 | line_names=['flag_attn', 'torch', ] + ([f'flash-{FLASH_VER}'] if HAS_FLASH else []), 29 | styles=[('red', '-'), ('green', '-'), ('blue', '-'), ('cyan', '-')], 30 | ylabel='tflop/s', 31 | plot_name=f'attention_d-{D_HEAD}_mode-{mode}_causal-{causal}_dtype-{dtype}', 32 | args={'D_HEAD': D_HEAD, 'dtype': dtype, 'mode': mode, 'causal': causal} 33 | ) for mode in ['fwd', 'bwd'] 34 | for causal in [False, True] 35 | for D_HEAD in [64, 128] 36 | for dtype in [torch.float16, torch.bfloat16]] 37 | 38 | @triton.testing.perf_report(configs) 39 | def bench_flash_attention(N_CTX, D_HEAD, causal, mode, provider, dtype=torch.float16, device="cuda"): 40 | assert mode in ['fwd', 'bwd'] 41 | w = N_CTX // 2 # dist thresold 42 | warmup = 25 43 | rep = 100 44 | 45 | is_bwd = mode == "bwd" 46 | 47 | BATCH = 32768 // N_CTX 48 | H = 2048 // D_HEAD 49 | if provider == "flag_attn": 50 | q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=is_bwd) 51 | k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=is_bwd) 52 | v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=is_bwd) 53 | fn = lambda: flag_attn.flash_attention(q, k, v, causal=causal) 54 | if mode == 'bwd': 55 | o = fn() 56 | do = torch.randn_like(o) 57 | fn = lambda: o.backward(do, retain_graph=True) 58 | ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) 59 | if provider == "torch": 60 | q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=is_bwd) 61 | k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=is_bwd) 62 | v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=is_bwd) 63 | 64 | try: 65 | fn = lambda: flag_attn.testing.flash_attention(q, k, v, causal=causal, upcast=False) 66 | if mode == 'bwd': 67 | o = fn() 68 | do = torch.randn_like(o) 69 | fn = lambda: o.backward(do, retain_graph=True) 70 | ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) 71 | except torch.cuda.OutOfMemoryError as e: 72 | logging.info(f"torch OOM for batch_size: {BATCH}, num_heads: {H}, seqlen: {N_CTX}, headdim: {D_HEAD}") 73 | ms = float("inf") 74 | if provider == "flash": 75 | if FLASH_VER == 1: 76 | qkv = torch.randn((BATCH, N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=is_bwd) 77 | lengths = torch.full((BATCH,), fill_value=N_CTX, device=device) 78 | cu_seqlens = torch.zeros((BATCH + 1,), device=device, 
dtype=torch.int32) 79 | cu_seqlens[1:] = lengths.cumsum(0) 80 | qkv = qkv.reshape(BATCH * N_CTX, 3, H, D_HEAD) 81 | fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=causal) 82 | elif FLASH_VER == 2: 83 | q = torch.randn((BATCH, N_CTX, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=is_bwd) 84 | k = torch.randn((BATCH, N_CTX, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=is_bwd) 85 | v = torch.randn((BATCH, N_CTX, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=is_bwd) 86 | fn = lambda: flash_attn_func(q, k, v, causal=causal) 87 | else: 88 | raise ValueError(f'unknown {FLASH_VER = }') 89 | if mode == 'bwd': 90 | o = fn() 91 | do = torch.randn_like(o) 92 | fn = lambda: o.backward(do, retain_graph=True) 93 | ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) 94 | 95 | # total TFLOPS: following Flash Attention v2, only gemms are counted. 96 | macs = 2. * BATCH * H * N_CTX * N_CTX * D_HEAD # Q@K, P@V 97 | if mode == 'bwd': 98 | macs *= 2.5 # Q@K, dO@V, dO@P, dS@Q dS@K 99 | total_flops = 2 * macs 100 | 101 | if causal: 102 | total_flops *= 0.5 103 | return total_flops / ms * 1e-9 104 | 105 | # only works on post-Ampere GPUs right now 106 | today = datetime.date.today().strftime(format("%Y%m%d")) 107 | output_dir = pathlib.Path(f"results_flash_attention_{today}") 108 | output_dir.mkdir(exist_ok=True) 109 | bench_flash_attention.run(save_path=output_dir, print_data=True) 110 | -------------------------------------------------------------------------------- /src/flag_attn/total.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import triton 4 | import triton.language as tl 5 | 6 | def get_fwd_config(B, H, M, N, D, causal): 7 | return (32, 32, 1, 4) 8 | 9 | def total_attention(q, k, l, causal=False, sm_scale=None): 10 | Dq, Dk = q.shape[-1], k.shape[-1] 11 | assert Dq == Dk 12 | assert Dk in {16, 32, 64, 128} 13 | # assert L is contiguous 14 | 15 | B, H, M, D = q.shape 16 | N = k.shape[2] 17 | Hk = k.shape[1] 18 | assert H % Hk == 0, "number of heads in q must be a multiple of that in k" 19 | num_groups = H // Hk 20 | 21 | P_SEQ = N - M 22 | 23 | if sm_scale is None: 24 | sm_scale = 1. 
/ math.sqrt(D) 25 | 26 | # to work around https://github.com/openai/triton/issues/2441 27 | device = torch.cuda.device_of(q) 28 | with torch.cuda.device(device): 29 | config = get_fwd_config(B, H, M, N, D, causal) 30 | BLOCK_M, BLOCK_N, num_stages, num_warps = config 31 | 32 | divisible_m = M % BLOCK_M == 0 33 | divisible_n = N % BLOCK_N == 0 34 | # consider using 3d grid to avoid div & rem 35 | grid = (triton.cdiv(N, BLOCK_N), H, B) 36 | tot_attn = torch.empty((B, H, N), dtype=torch.float32, device=q.device) 37 | _total_attention_kernel[grid]( 38 | q, k, l, tot_attn, sm_scale, 39 | q.stride(0), q.stride(1), q.stride(2), q.stride(3), 40 | k.stride(0), k.stride(1), k.stride(2), k.stride(3), 41 | B, H, M, N, P_SEQ, num_groups, 42 | BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N, 43 | CAUSAL=causal, 44 | DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n, 45 | num_stages=num_stages, num_warps=num_warps, 46 | ) 47 | return tot_attn 48 | 49 | 50 | @triton.jit 51 | def _total_attention_kernel( 52 | Q, K, L, TA, sm_scale, 53 | stride_qz, stride_qh, stride_qm, stride_qk, 54 | stride_kz, stride_kh, stride_kn, stride_kk, 55 | Z, H, M, N, P_SEQ, num_groups, 56 | BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, 57 | CAUSAL: tl.constexpr, 58 | DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr, 59 | ): 60 | # -- grid id -- 61 | start_n = tl.program_id(0) 62 | off_h = tl.program_id(1) 63 | off_z = tl.program_id(2) 64 | log2e: tl.constexpr = 1.4426950408889634 65 | qk_scale = sm_scale * log2e 66 | 67 | # offset pointers for (batch, head) 68 | off_hk = off_h // num_groups 69 | Q += off_z * stride_qz + off_h * stride_qh 70 | K += off_z * stride_kz + off_hk * stride_kh 71 | L += (off_z * H + off_h) * M 72 | TA += (off_z * H + off_h) * N # (b, h, n) 73 | 74 | if CAUSAL: 75 | lo = tl.maximum(start_n * BLOCK_N - P_SEQ, 0) 76 | lo = (lo // BLOCK_M) * BLOCK_M 77 | else: 78 | lo = 0 79 | 80 | offs_m_init = lo + tl.arange(0, BLOCK_M) 81 | offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) 82 | offs_m_base = tl.arange(0, BLOCK_M) 83 | offs_k = tl.arange(0, BLOCK_DMODEL) 84 | 85 | # initialize pointers to value-like data 86 | q_ptrs = Q + (offs_m_init[:, None] * stride_qm + offs_k[None, :] * stride_qk) # (BLOCK_M, BLOCK_DMODEL) 87 | k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) # (BLOCK_N, BLOCK_DMODEL) 88 | ta_ptrs = TA + offs_n # (BLOCK_N, ) 89 | 90 | # k and v stay in SRAM throughout 91 | if DIVISIBLE_N: 92 | k = tl.load(k_ptrs) 93 | else: 94 | mask_n = offs_n < N 95 | k = tl.load(k_ptrs, mask=mask_n[:, None]) 96 | 97 | # initialize total attention 98 | tot_attn = tl.zeros([BLOCK_N], dtype=tl.float32) 99 | 100 | # loop over a col 101 | for start_m in range(lo, M, BLOCK_M): 102 | start_m = tl.multiple_of(start_m, BLOCK_M) 103 | offs_m = start_m + offs_m_base 104 | causal_mask = (P_SEQ + offs_m[:, None]) >= (offs_n[None, :]) # (BLOCK_M, BLOCK_N) 105 | 106 | if DIVISIBLE_M: 107 | q = tl.load(q_ptrs) 108 | else: 109 | mask_m = offs_m < M 110 | valid_mask = mask_m[:, None] # & mask_n 111 | q = tl.load(q_ptrs, mask=mask_m[:, None]) 112 | # recompute p = softmax(qk * sm_scale, dim=-1) 113 | s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) 114 | s += tl.dot(q, tl.trans(k)) 115 | 116 | # NOTE: since softmax in backward is pointwise, the normalizer has been saved in fwd) 117 | # So masking on s is not needed. 
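# The saved normalizer L holds the per-row logsumexp computed in the forward pass, so the
# probabilities can be recomputed directly as
#     p_ij = exp(s_ij * sm_scale - l_i) = exp2(s_ij * qk_scale - l_i * log2e)
# and the boundary / causal masks are applied to p afterwards (the tl.where calls further
# down) instead of pushing the corresponding scores to -inf before exponentiation.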
118 | # s = tl.where(valid_mask, s , float("-inf")) 119 | # if CAUSAL: 120 | # s = tl.where(causal_mask, s, float("-inf")) 121 | 122 | # -- recompute p --- 123 | if DIVISIBLE_M: 124 | l = tl.load(L + offs_m) 125 | else: 126 | l = tl.load(L + offs_m, mask=mask_m) 127 | p = tl.math.exp2(s * qk_scale - l[:, None] * log2e) # (BLOCK_M, BLOCK_N) 128 | 129 | if not DIVISIBLE_M: 130 | p = tl.where(valid_mask, p, 0.0) 131 | if CAUSAL: 132 | p = tl.where(causal_mask, p, 0.0) 133 | 134 | tot_attn += tl.sum(p, 0) 135 | # increment pointers 136 | q_ptrs += BLOCK_M * stride_qm 137 | 138 | 139 | if DIVISIBLE_N: 140 | tl.store(ta_ptrs, tot_attn) # (BLOCK_N,) 141 | else: 142 | tl.store(ta_ptrs, tot_attn, mask=mask_n) # (BLOCK_N, ) 143 | -------------------------------------------------------------------------------- /benchmark/piecewise_benchmark.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pathlib 3 | import datetime 4 | import torch 5 | import triton 6 | 7 | import flag_attn 8 | 9 | try: 10 | from flash_attn import flash_attn_func 11 | FLASH_VER = 2 12 | except BaseException: 13 | try: 14 | from flash_attn.flash_attn_interface import flash_attn_func 15 | FLASH_VER = 1 16 | except BaseException: 17 | FLASH_VER = None 18 | HAS_FLASH = FLASH_VER is not None 19 | 20 | 21 | configs = [triton.testing.Benchmark( 22 | x_names=['N_CTX'], 23 | x_vals=[2**i for i in range(9, 16)], 24 | line_arg='provider', 25 | line_vals=['piecewise', 'torch'] + (['flash'] if HAS_FLASH else []), 26 | line_names=['piecewise', 'torch'] + ([f'flash-{FLASH_VER}'] if HAS_FLASH else []), 27 | styles=[('red', '-'), ('green', '-'), ('blue', '-')], 28 | ylabel='tflop/s', 29 | plot_name=f'piecewise_attention_d-{D_HEAD}_mode-{mode}_causal-{causal}_dtype-{dtype}', 30 | args={'D_HEAD': D_HEAD, 'dtype': dtype, 'mode': mode, 'causal': causal} 31 | ) for mode in ['fwd', 'bwd'] 32 | for causal in [False, True] 33 | for D_HEAD in [64, 128] 34 | for dtype in [torch.float16, torch.bfloat16]] 35 | 36 | @triton.testing.perf_report(configs) 37 | def bench_flash_attention(N_CTX, D_HEAD, causal, mode, provider, dtype=torch.float16, device="cuda"): 38 | assert mode in ['fwd', 'bwd'] 39 | w = N_CTX // 2 # dist thresold 40 | warmup = 25 41 | rep = 100 42 | 43 | BATCH = 32768 // N_CTX 44 | H = 2048 // D_HEAD 45 | if provider == "piecewise": 46 | q1 = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) 47 | k1 = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) 48 | q2 = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) 49 | k2 = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) 50 | v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) 51 | fn = lambda: flag_attn.piecewise_attention(q1, k1, q2, k2, v, w, causal=causal) 52 | if mode == 'bwd': 53 | o = fn() 54 | do = torch.randn_like(o) 55 | fn = lambda: o.backward(do, retain_graph=True) 56 | ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) 57 | if provider == "torch": 58 | q1 = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) 59 | k1 = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) 60 | q2 = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) 61 | k2 = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) 62 | v 
= torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) 63 | 64 | try: 65 | fn = lambda: flag_attn.testing.piecewise_attention(q1, k1, q2, k2, v, w, causal=causal) 66 | if mode == 'bwd': 67 | o = fn() 68 | do = torch.randn_like(o) 69 | fn = lambda: o.backward(do, retain_graph=True) 70 | ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) 71 | except torch.cuda.OutOfMemoryError as e: 72 | logging.info(f"torch OOM for batch_size: {BATCH}, num_heads: {H}, seqlen: {N_CTX}, headdim: {D_HEAD}") 73 | ms = float("inf") 74 | if provider == "flash": 75 | if FLASH_VER == 1: 76 | qkv = torch.randn((BATCH, N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True) 77 | lengths = torch.full((BATCH,), fill_value=N_CTX, device=device) 78 | cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32) 79 | cu_seqlens[1:] = lengths.cumsum(0) 80 | qkv = qkv.reshape(BATCH * N_CTX, 3, H, D_HEAD) 81 | fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=causal) 82 | elif FLASH_VER == 2: 83 | q = torch.randn((BATCH, N_CTX, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) 84 | k = torch.randn((BATCH, N_CTX, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) 85 | v = torch.randn((BATCH, N_CTX, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) 86 | fn = lambda: flash_attn_func(q, k, v, causal=causal) 87 | else: 88 | raise ValueError(f'unknown {FLASH_VER = }') 89 | if mode == 'bwd': 90 | o = fn() 91 | do = torch.randn_like(o) 92 | fn = lambda: o.backward(do, retain_graph=True) 93 | ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) 94 | 95 | # total TFLOPS: following Flash Attention v2, only gemms are counted. 96 | # NOTE: It is not a fair play here, the total amount of flops and the elapsed time are different, 97 | # the tflop/s is a used as a metric of the performance of the operator, for refernce only. 98 | if provider == "flash": 99 | macs = 2. * BATCH * H * N_CTX * N_CTX * D_HEAD # Q@K, P@V 100 | if mode == 'bwd': 101 | macs *= 2.5 # Q@K, dO@V, dO@P, dS@Q dS@K 102 | else: 103 | macs = 3. 
* BATCH * H * N_CTX * N_CTX * D_HEAD # Q1@K1, Q2@K2, P@V 104 | if mode == 'bwd': 105 | macs *= 8.0/3.0 # 1 *(Q1@K1, Q2@K2, dO@V), dO@P, dS1@@Q1, dS1@K1, dS2@@Q2, dS2@K2 106 | total_flops = 2 * macs 107 | 108 | if causal: 109 | total_flops *= 0.5 110 | return total_flops / ms * 1e-9 111 | 112 | # only works on post-Ampere GPUs right now 113 | today = datetime.date.today().strftime(format("%Y%m%d")) 114 | output_dir = pathlib.Path(f"results_piecewise_attention_{today}") 115 | output_dir.mkdir(exist_ok=True) 116 | bench_flash_attention.run(save_path=output_dir, print_data=True) 117 | -------------------------------------------------------------------------------- /benchmark/paged_benchmark.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import triton 3 | import flag_attn 4 | 5 | NUM_BLOCKS = 1000 6 | warmup = 200 7 | rep = 200 8 | 9 | try: 10 | from vllm._C import ops as vllm_ops 11 | 12 | HAS_VLLM = True 13 | 14 | # required vllm 0.3.0 15 | import vllm 16 | 17 | print("vllm.__version__", vllm.__version__) 18 | except BaseException: 19 | HAS_VLLM = False 20 | 21 | 22 | def vllm_paged_attention( 23 | out: torch.Tensor, 24 | query: torch.Tensor, 25 | key_cache: torch.Tensor, 26 | value_cache: torch.Tensor, 27 | num_kv_heads: int, 28 | scale: float, 29 | block_tables: torch.Tensor, 30 | context_lens: torch.Tensor, 31 | block_size: int, 32 | max_context_len: int, 33 | PARTITION_SIZE: int = 512, 34 | version: int = 1, 35 | ): 36 | if version == 1: 37 | vllm_ops.paged_attention_v1( 38 | out, 39 | query, 40 | key_cache, 41 | value_cache, 42 | num_kv_heads, 43 | scale, 44 | block_tables, 45 | context_lens, 46 | block_size, 47 | max_context_len, 48 | None, # alibi_slopes 49 | "auto", # kv_cache_dtype for vllm 0.3.0 50 | ) 51 | elif version == 2: 52 | num_partitions = (max_context_len + PARTITION_SIZE - 1) // PARTITION_SIZE 53 | assert PARTITION_SIZE % block_size == 0 54 | num_seqs, num_heads, head_size = out.shape 55 | tmp_out = torch.empty( 56 | size=(num_seqs, num_heads, num_partitions, head_size), 57 | dtype=out.dtype, 58 | device=out.device, 59 | ) 60 | exp_sums = torch.empty( 61 | size=(num_seqs, num_heads, num_partitions), 62 | dtype=torch.float32, 63 | device=out.device, 64 | ) 65 | max_logits = torch.empty_like(exp_sums) 66 | vllm_ops.paged_attention_v2( 67 | out, 68 | exp_sums, 69 | max_logits, 70 | tmp_out, 71 | query, 72 | key_cache, 73 | value_cache, 74 | num_kv_heads, 75 | scale, 76 | block_tables, 77 | context_lens, 78 | block_size, 79 | max_context_len, 80 | None, 81 | "auto", # vllm 0.3.0 82 | ) 83 | else: 84 | raise AssertionError(f"Unknown version: {version}") 85 | 86 | 87 | @triton.testing.perf_report( 88 | [ 89 | triton.testing.Benchmark( 90 | x_names=["context_len"], 91 | x_vals=[2**i for i in range(9, 15)], 92 | line_arg="provider", 93 | line_vals=["triton"] + (["vllm"] if HAS_VLLM else []), 94 | line_names=["triton"] + ([f"vllm-{vllm.__version__}"] if HAS_VLLM else []), 95 | styles=[("red", "-"), ("blue", "-")], 96 | ylabel="tflop/s", 97 | plot_name=f"vllm_paged_attention-B{num_seqs}-G{query_group_size}-D{head_size}-bs{block_size}-v{version}", 98 | args={ 99 | "num_seqs": num_seqs, 100 | "num_query_heads": 64, 101 | "query_group_size": query_group_size, 102 | "head_size": head_size, 103 | "block_size": block_size, 104 | "vllm_version": version, 105 | "dtype": dtype, 106 | }, 107 | ) 108 | for num_seqs in [1, 32, 64] 109 | for query_group_size in [1, 8] 110 | for head_size in [64, 128] 111 | for block_size in [16, 32] 112 
| for version in [1, 2] 113 | for dtype in [torch.float16] 114 | ] 115 | ) 116 | def paged_attention_benchmark_with_vllm( 117 | num_seqs, 118 | num_query_heads, 119 | query_group_size, 120 | head_size, 121 | block_size, 122 | context_len, 123 | vllm_version, 124 | provider, 125 | dtype=torch.float16, 126 | device="cuda", 127 | ): 128 | num_kv_heads = num_query_heads // query_group_size 129 | 130 | context_lens = torch.zeros(num_seqs, dtype=torch.int32, device=device) + context_len 131 | max_num_blocks_per_seq = (context_len + block_size - 1) // block_size 132 | 133 | attn_scale = head_size**-0.5 134 | q = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype, device=device) 135 | q.uniform_(-attn_scale, attn_scale) 136 | out = torch.empty_like(q) 137 | 138 | k_cache = torch.empty( 139 | NUM_BLOCKS, num_kv_heads, block_size, head_size, dtype=dtype, device=device 140 | ) 141 | k_cache.uniform_(-attn_scale, attn_scale) 142 | v_cache = torch.empty_like(k_cache) 143 | v_cache.uniform_(-attn_scale, attn_scale) 144 | 145 | # (NUM_SEQS, MAX_NUM_BLOCKS_PER_SEQ) 146 | block_tables = torch.randint( 147 | 0, 148 | NUM_BLOCKS, 149 | (num_seqs, max_num_blocks_per_seq), 150 | dtype=torch.int32, 151 | device=device, 152 | ) 153 | 154 | if provider == "triton": 155 | fn = lambda: flag_attn.paged_attention( 156 | q, 157 | k_cache, 158 | v_cache, 159 | context_lens, 160 | block_tables, 161 | attn_scale, 162 | context_len, 163 | ) 164 | ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) 165 | 166 | if provider == "vllm": 167 | # Correctness error, does not affect performance results 168 | fn = lambda: vllm_paged_attention( 169 | out, 170 | q, 171 | k_cache, 172 | v_cache, 173 | num_kv_heads, 174 | attn_scale, 175 | block_tables, 176 | context_lens, 177 | block_size, 178 | context_len, 179 | PARTITION_SIZE=512, 180 | version=vllm_version, 181 | ) 182 | ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) 183 | 184 | total_flops = 2.0 * num_seqs * num_query_heads * 2 * context_len * head_size 185 | return total_flops / ms * 1e-9 186 | 187 | 188 | if HAS_VLLM: 189 | paged_attention_benchmark_with_vllm.run(print_data=True) 190 | -------------------------------------------------------------------------------- /tests/flag_attn/test_piecewise_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | import logging 4 | 5 | import flag_attn 6 | 7 | torch.random.manual_seed(10086) 8 | 9 | def max_diff(a, b): 10 | return (a - b).abs().max().item() 11 | 12 | @pytest.mark.parametrize('device_id', list(range(torch.cuda.device_count()))) 13 | @pytest.mark.parametrize('scale', [1.0, 2.0, 3.0, 4.0]) 14 | @pytest.mark.parametrize('B, H, M, N, D', [ 15 | (2, 4, 512, 612, 128), 16 | (2, 4, 1024, 1034, 64), 17 | (2, 4, 2048, 2048, 32), 18 | (2, 4, 4096, 4096, 16), 19 | (2, 4, 4096, 4001, 16), 20 | (1, 2, 8192, 8192, 16), 21 | (1, 2, 8192, 8192, 32), 22 | (1, 2, 8192, 3867, 32), 23 | ]) 24 | @pytest.mark.parametrize('causal', [True, False]) 25 | @pytest.mark.parametrize('stride_order', ['BHTD', 'BTHD']) 26 | @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) 27 | def test_attention_fwd(B, H, M, N, D, causal, stride_order, dtype, scale, device_id): 28 | device = f"cuda:{device_id}" 29 | if stride_order == "BHTD": 30 | q1 = torch.empty((B, H, M, D), dtype=dtype, device=device).normal_(mean=0., std=scale) 31 | q2 = torch.empty((B, H, M, D), dtype=dtype, device=device).normal_(mean=0., std=scale) 32 | k1 = torch.empty((B, H, 
N, D), dtype=dtype, device=device).normal_(mean=0., std=scale) 33 | k2 = torch.empty((B, H, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale) 34 | v = torch.empty((B, H, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale) 35 | else: 36 | q1 = torch.empty((B, M, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2) 37 | q2 = torch.empty((B, M, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2) 38 | k1 = torch.empty((B, N, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2) 39 | k2 = torch.empty((B, N, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2) 40 | v = torch.empty((B, N, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2) 41 | w = (M // 2) if M < N else (M - N // 2) 42 | 43 | o_ref = flag_attn.testing.piecewise_attention(q1, k1, q2, k2, v, w, causal=causal, upcast=True) 44 | o_torch = flag_attn.testing.piecewise_attention(q1, k1, q2, k2, v, w, causal=causal) 45 | o_hyp = flag_attn.piecewise_attention(q1, k1, q2, k2, v, w, causal=causal) 46 | 47 | torch_max_diff = max_diff(o_torch, o_ref) 48 | triton_max_diff = max_diff(o_hyp, o_ref) 49 | logging.info("torch_max_diff: {:.8f}\ttriton_max_diff: {:.8f}".format(torch_max_diff, triton_max_diff)) 50 | assert triton_max_diff <= 2 * torch_max_diff + 1e-5 51 | # assert torch.testing.assert_close(o_hyp, o_ref) 52 | 53 | 54 | 55 | @pytest.mark.parametrize('device_id', list(range(torch.cuda.device_count()))) 56 | @pytest.mark.parametrize('scale', [1.0, 2.0, 3.0, 4.0]) 57 | @pytest.mark.parametrize('B, H, M, N, D', [ 58 | (2, 4, 512, 612, 128), 59 | (2, 4, 1024, 1034, 64), 60 | (2, 4, 2048, 2048, 32), 61 | (2, 4, 4096, 4096, 16), 62 | (2, 4, 4096, 4001, 16), 63 | (1, 2, 8192, 8192, 16), 64 | (1, 2, 8192, 8192, 32), 65 | (1, 2, 8192, 3867, 32), 66 | ]) 67 | @pytest.mark.parametrize('causal', [True, False]) 68 | @pytest.mark.parametrize('stride_order', ['BHTD', 'BTHD']) 69 | @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) 70 | def test_attention_bwd(B, H, M, N, D, causal, stride_order, dtype, scale, device_id): 71 | device = f"cuda:{device_id}" 72 | if stride_order == "BHTD": 73 | q1 = torch.empty((B, H, M, D), dtype=dtype, device=device).normal_(mean=0., std=scale).requires_grad_() 74 | q2 = torch.empty((B, H, M, D), dtype=dtype, device=device).normal_(mean=0., std=scale).requires_grad_() 75 | k1 = torch.empty((B, H, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale).requires_grad_() 76 | k2 = torch.empty((B, H, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale).requires_grad_() 77 | v = torch.empty((B, H, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale).requires_grad_() 78 | do = torch.empty((B, H, M, D), dtype=dtype, device=device).normal_(mean=0., std=scale) 79 | else: 80 | q1 = torch.empty((B, M, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2).requires_grad_() 81 | q2 = torch.empty((B, M, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2).requires_grad_() 82 | k1 = torch.empty((B, N, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2).requires_grad_() 83 | k2 = torch.empty((B, N, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2).requires_grad_() 84 | v = torch.empty((B, N, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2).requires_grad_() 85 | do = torch.empty((B, M, H, D), dtype=dtype, 
device=device).normal_(mean=0., std=scale).transpose(1, 2) 86 | 87 | w = (M // 2) if M < N else (M - N // 2) 88 | 89 | o_ref = flag_attn.testing.piecewise_attention(q1, k1, q2, k2, v, w, causal=causal, upcast=True) 90 | dq1_ref, dk1_ref, dq2_ref, dk2_ref, dv_ref = torch.autograd.grad(o_ref, (q1, k1, q2, k2, v), do) 91 | 92 | o_torch = flag_attn.testing.piecewise_attention(q1, k1, q2, k2, v, w, causal=causal, upcast=False) 93 | dq1_torch, dk1_torch, dq2_torch, dk2_torch, dv_torch = torch.autograd.grad(o_torch, (q1, k1, q2, k2, v), do) 94 | 95 | o_hyp = flag_attn.piecewise_attention(q1, k1, q2, k2, v, w, causal=causal) 96 | dq1_hyp, dk1_hyp, dq2_hyp, dk2_hyp, dv_hyp = torch.autograd.grad(o_hyp, (q1, k1, q2, k2, v), do) 97 | 98 | o_torch_max_diff = max_diff(o_torch, o_ref) 99 | dq1_torch_max_diff = max_diff(dq1_torch, dq1_ref) 100 | dq2_torch_max_diff = max_diff(dq2_torch, dq2_ref) 101 | dk1_torch_max_diff = max_diff(dk1_torch, dk1_ref) 102 | dk2_torch_max_diff = max_diff(dk2_torch, dk2_ref) 103 | dv_torch_max_diff = max_diff(dv_torch, dv_ref) 104 | 105 | o_triton_max_diff = max_diff(o_hyp, o_ref) 106 | dq1_triton_max_diff = max_diff(dq1_hyp, dq1_ref) 107 | dq2_triton_max_diff = max_diff(dq2_hyp, dq2_ref) 108 | dk1_triton_max_diff = max_diff(dk1_hyp, dk1_ref) 109 | dk2_triton_max_diff = max_diff(dk2_hyp, dk2_ref) 110 | dv_triton_max_diff = max_diff(dv_hyp, dv_ref) 111 | 112 | logging.info("o torch_max_diff: {:.8f}\ttriton_max_diff: {:.8f}".format(o_torch_max_diff, o_triton_max_diff)) 113 | logging.info("dq1 torch_max_diff: {:.8f}\ttriton_max_diff: {:.8f}".format(dq1_torch_max_diff, dq1_triton_max_diff)) 114 | logging.info("dq2 torch_max_diff: {:.8f}\ttriton_max_diff: {:.8f}".format(dq2_torch_max_diff, dq2_triton_max_diff)) 115 | logging.info("dk1 torch_max_diff: {:.8f}\ttriton_max_diff: {:.8f}".format(dk1_torch_max_diff, dk1_triton_max_diff)) 116 | logging.info("dk2 torch_max_diff: {:.8f}\ttriton_max_diff: {:.8f}".format(dk2_torch_max_diff, dk2_triton_max_diff)) 117 | logging.info("dv torch_max_diff: {:.8f}\ttriton_max_diff: {:.8f}".format(dv_torch_max_diff, dv_triton_max_diff)) 118 | 119 | assert o_triton_max_diff <= 2 * o_torch_max_diff + 1e-5 120 | assert dq1_triton_max_diff <= 2 * dq1_torch_max_diff + 1e-5 121 | assert dq2_triton_max_diff <= 2 * dq2_torch_max_diff + 1e-5 122 | assert dk1_triton_max_diff <= 2 * dk1_torch_max_diff + 1e-5 123 | assert dk2_triton_max_diff <= 2 * dk2_torch_max_diff + 1e-5 124 | assert dv_triton_max_diff <= 2 * dv_torch_max_diff + 1e-5 125 | -------------------------------------------------------------------------------- /README_cn.md: -------------------------------------------------------------------------------- 1 | # FlagAttention 2 | 3 |

4 | flag-attention 5 |

6 | 7 | [English](./README.md) 8 | 9 | 10 | FlagAttention 是一个用 Triton 语言(https://github.com/openai/triton)实现的内存高效 Attention 算子项目。FlagAttention 由语言模型中对非标准 attention 算子的需求驱动,对 multihead attention 算子进行扩展。 11 | 12 | FlagAttention 和 [FlashAttention](https://arxiv.org/abs/2205.14135) 和 [FlashAttention v2](https://tridao.me/publications/flash2/flash2.pdf) 一样内存高效,可以节省内存占用和访存。因为使用 Triton 语言实现,它更容易理解和修改。原版的 CUDA 实现的 [FlashAttention](https://github.com/Dao-AILab/flash-attention) 提供了如何设计算法以考虑不同内存层级的良好范例。通过分块和重计算的技巧, FlashAttention 避免了实体化 attention score 这个容量和文本长度的平方成正比的中间变量。但是使用 FlashAttention 的时候,无法对 attention score 进行自定义的变换,除非这个变换本身就被 FlashAttention 支持。对 FlashAttention 算子进行扩展需要熟练的 CUDA 编程技巧, 但用 Triton 语言实现的 FlagAttention 则更好修改。 13 | 14 | FlagAttention 目前提供了两个算子。 15 | 16 | 1. flash_attention. 用 Triton 语言实现的 FlashAttention. 17 | 2. piecewise_attention. 这个算子用于实现 NLPE(non linear position embedding),目前用于 [Aquila-2-34B](https://github.com/FlagAI-Open/Aquila2) 模型的训练和推理。 18 | 19 | 如果需要更多的定制,FlagAttention 中的算子实现也可以作为参考。 20 | 21 | ## 更新日志 22 | 23 | ### v0.1 24 | 25 | 添加 piecewise_attention 和 flash_attention 算子。 26 | 27 | ### v0.2 28 | 29 | 优化算子性能。 30 | 1. 仅在必要时使用 masking. 31 | 2. 使用一个单独的 kernel 来计算 q 的梯度,以避免对全局内存的 RMW 操作。 32 | 33 | ## 依赖 34 | 35 | FlagAttention 依赖 Torch 和 Triton。 为了使用 Triton 的新功能,建议使用 nightly 版。 36 | 37 | ```sh 38 | pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly 39 | ``` 40 | FlagAttention 需要 Ampere 架构的 Nvidia GPU (e.g. A100, RTX-3090, ...) 以及 CUDA Toolkit 11.6 及以上的版本运行。其他的 GPU 可能也能运行,但暂未测试。 41 | 42 | ## 安装 43 | 44 | FlagAttention 可以通过以下两种方式安装。 45 | 46 | 1. 可编辑安装。对本地代码的修改会立即生效,无需重新安装。 47 | 2. 构建并安装。这种方式只有 `flag_attn` 包的内容会被安装。 48 | 49 | ### 可编辑安装 50 | 51 | 通过 `pip` 进行可编辑安装 52 | 53 | ```sh 54 | git clone https://github.com/FlagOpen/FlagAttention && cd FlagAttention 55 | pip install -e . 56 | ``` 57 | 58 | ### 构建并安装 59 | 60 | 遵循现代 python 打包惯例,FlagAttention 通过 [`pyproject.toml`](https://pip.pypa.io/en/stable/reference/build-system/pyproject-toml/) 文件来配置,因此没有 `setup.py`. 推荐使用 python 的 `build` 包来构建发行版,包括源码发行版(sdist) 和二进制发行版(whl). 61 | 62 | 首先通过 pip 安装 `build` 包。 63 | 64 | ```sh 65 | pip install build 66 | ``` 67 | 68 | 然后构建包。 69 | 70 | ```sh 71 | git clone https://github.com/FlagOpen/FlagAttention && cd FlagAttention 72 | # 以非隔离模式安装需要自行安装依赖 73 | pip install -U setuptools setuptools-scm 74 | python -m build --no-isolation 75 | ``` 76 | 77 | 构建好的包在 `dist/` 目录,可用于安装。 78 | 79 | ```sh 80 | pip install dist/flag_attn-xxx.whl 81 | ``` 82 | 83 | ## 使用方式 84 | 85 | FlagAttention 提供了自定义的 attention 算子。当一个算子的功能和 torch 函数等价的时候,就可以用它替换对应的 torch 函数。 86 | 87 | ## 运行测试 88 | 89 | 需要较新版本的 `pytest`(>=7.1.0) 以运行 `tests/` 中的测试。FlagAttention 中的运算符以 `flag_attn.testing` 中的 PyTorch [参考实现](src/flag_attn/testing) 为参考进行测试,包括前向和反向。对于支持 `float16` 和 `bfloat16` 数据类型的算子,测试中包含了三种实现用于对比。 90 | 91 | 1. **Pytorch 参考实现**:在这个实现中,输入先被转换为 `float32` 类型,此后全程使用 `float32` 进行运算,再将结果转换为 `float16` 或 `bfloat16` 类型。 92 | 2. **Triton 实现**: 算子的 Triton 实现,使用 `float16` 或 `bfloat16` 作为矩阵乘(MMA)的输入类型,而使用 `float32` 作为矩阵乘的输出类型,以及其他运算的计算类型。 93 | 3. **Pytorch 实现**: 这个实现使用和 Pytorch 参考实现相同的运算,但计算精度和 Triton 实现一致。 94 | 95 | 我们的测试要求在相同情况下,Triton 实现与 Pytorch 参考实现之间的最大误差不大于 Pytorch 实现与 Pytorch 参考实现之间最大误差的两倍。 96 | 97 | ```sh 98 | pytest . 
99 | ``` 100 | 101 | ## 运行性能测试 102 | 103 | 项目中提供了性能基准测试来衡量算子所能达到的的 TFLOPs/s。FLOPs/s 用来作为衡量算子运行速度的指标。算子的浮点数运算总量 (FLOPs) 仅考虑矩阵乘。总计算量除以运行时间的中位数,得到算子运行的 FLOPs/s。 104 | 105 | 我们对比了算子的 Triton 实现和 PyTorch 实现的性能。当输入规模较大时,PyTorch 参考实现会遇到内存不足的问题,这种情况下,FLOPs/s 记为 0. 106 | 107 | ```sh 108 | cd benchmarks/ 109 | python flash_benchmark.py 110 | python piecewise_benchmark.py 111 | ``` 112 | 113 | ## 算子 114 | 115 | ### flash_attention 116 | 117 | Triton 语言实现的 FlashAttention, 接口如下。 118 | 119 | ```python 120 | flash_attention(q, k, v, causal=False, sm_scale=None, return_log_normalizer=False, return_total_attention=False) 121 | ``` 122 | 123 | 除了 attention 的输出之外,它还可以根据 `return_log_normalizer` 和 `return_total_attention=False` 返回一些额外的输出。 124 | 125 | 1. log_normalizer: 形状 (batch_size, num_heads, seqlen_q), attention 运算内部的 softmax 运算的 log normalizer. 126 | 2. total_attention: 形状 (batch_size, num_heads, seqlen_k). attention weights 沿着 q 的序列轴上求和的结果。 127 | 128 | ### piecewise_attention 129 | 130 | 对 FlashAttention 的第一个扩展是 [piecewise attention](src/flag_attn/piecewise.py). 该算子增强了 FlashAttention 的功能:使用两个 `q` 和两个 `k` 来计算 attention score(S) ,然后使用 softmax 来计算 attention weight(P). 131 | 132 | 这个设计源于具有旋转位置编码的 Transformer 模型在预测的序列长度超过其最大训练序列长度时存在困难。当距离超过训练集中最大序列长度是,这样的 (q,k) 对会得到较高的 attention score,这是不符合预期的现象。 133 | 134 | 为了解决这个,BAAI提出了 NLPE(Non-Linearized Position Embedding, 非线性位置编码)。该方法根据q和k之间的距离是否超过预定义的阈值,为q和k应用两个不同的位置嵌入,产生q1, q2和k1, k2。然后,根据q和k之间的距离,注意力得分计算为q1, k1或q2, k2的点积。 135 | 136 | 接口如下: 137 | 138 | ![piecewise_attention_interface](./assets/piecewise_attention_interface.png) 139 | 140 | ```python 141 | piecewise_attention(q1, k1, q2, k2, v, dist_threshold, softmax_scale=None, causal=False) 142 | ``` 143 | 144 | ![piecewise attention](assets/piecewise_attention.png) 145 | 146 | #### 使用示例 147 | 148 | ```python 149 | # piecewise_attention 150 | import torch 151 | from flag_attn import piecewise_attention 152 | 153 | B, H, T, D = 2, 16, 8192, 128 154 | dist_threshold = T // 2 155 | 156 | q1 = torch.randn((B, H, T, D), dtype=torch.float16, device="cuda:0").requires_grad_() 157 | q2 = torch.randn((B, H, T, D), dtype=torch.float16, device="cuda:0").requires_grad_() 158 | k1 = torch.randn((B, H, T, D), dtype=torch.float16, device="cuda:0").requires_grad_() 159 | k2 = torch.randn((B, H, T, D), dtype=torch.float16, device="cuda:0").requires_grad_() 160 | v = torch.randn((B, H, T, D), dtype=torch.float16, device="cuda:0").requires_grad_() 161 | o = piecewise_attention(q1, k1, q2, k2, v, dist_threshold, causal=True) 162 | print(o) 163 | 164 | go = torch.randn((B, H, T, D), dtype=torch.float16, device="cuda:0") 165 | gq1, gk1, gq2, gk2, gv = torch.autograd.grad( 166 | o, (q1, k1, q2, k2, v), go 167 | ) 168 | print(gq1) 169 | ``` 170 | 171 | ```python 172 | # flash_attention 173 | import torch 174 | from flag_attn import flash_attention 175 | 176 | B, H, T, D = 2, 16, 8192, 128 177 | 178 | q = torch.randn((B, H, T, D), dtype=torch.float16, device="cuda:0").requires_grad_() 179 | k = torch.randn((B, H, T, D), dtype=torch.float16, device="cuda:0").requires_grad_() 180 | v = torch.randn((B, H, T, D), dtype=torch.float16, device="cuda:0").requires_grad_() 181 | o = flash_attention(q, k, v, causal=True) 182 | print(o) 183 | 184 | go = torch.randn((B, H, T, D), dtype=torch.float16, device="cuda:0") 185 | gq, gk, gv = torch.autograd.grad( 186 | o, (q, k, v), go 187 | ) 188 | print(gq) 189 | ``` 190 | 191 | #### 性能 192 | 193 | 性能测试条件如下: 194 | 195 | 1. seqlen 为 `[512, 1k, 2k, 4k, 16k, 32k]`; 196 | 2. 
batch size 为 `32k / seqlen`; 197 | 3. headdim 为 `[64, 128]`; 198 | 4. num_heads 为 `2048 / headdim`. 199 | 200 | ##### flash_attention 201 | 202 | 在使用 causal masking 条件下, flash_attention 算子性能如下: 203 | 204 | ![headdim64](./assets/v0.2/flash_attention_d64.png) 205 | 206 | ![headdim128](./assets/v0.2/flash_attention.png) 207 | 208 | 前向算子和 FlashAttention(CUDA) 一样快,甚至在某些情况下比 FlashAttention(CUDA)更快。但反向算子比 FlashAttention 慢。一开始的实现中,我们依照论文中的使用原子加的方式更新 q 的梯度,但这样运行非常慢。所以我们将反向的 kernel 分成两个,一个用来计算 k&v 的梯度,一个用来计算 q 的梯度。这避免了原子加运算,但是增加了更多的重计算。这样的修改将反向算子速度提升到了 4~5 倍,但仍然比 FlashAttention 慢。 209 | 210 | 相同的技巧也用在了 piecewise_attention 上。 211 | 212 | ##### piecewise_attention 213 | 214 | 相比 v0.1, piecewise_attention 算子的性能得到了提升。在 head dim 为 128 且使用 causal masking 的情况下,正向和反向算子的速度分别提升了 36% 和 9%. 215 | 216 | ![piecewise_attention](./assets/v0.2/piecewise_attention.png) 217 | 218 | #### 特征 219 | 220 | - 支持[英伟达](https://www.nvidia.com/) 安培架构的 GPU(在 RTX-3090, A100 上验证); 221 | - 支持[天数智芯](https://www.iluvatar.com/)的 GPU(在 MR-V100 上验证); 222 | - 数据类型支持,float16, 在英伟达安培架构 GPU 上也支持 bfloat16; 223 | - 支持 causal 和非 causal 模式; 224 | - 支持前向和反向计算; 225 | - K/V 的序列长度可以不等于 Q 的序列长度; 226 | - 支持计算每个 k 从所有 q 得到的 attention 总和。 227 | - 支持 [MQA](https://arxiv.org/abs/1911.02150) and [GQA](https://arxiv.org/pdf/2305.13245). 228 | - 支持对 attention weights 进行 dropout. 229 | 230 | #### 限制 231 | 232 | - `headdim` 必须为 `[16, 32, 64, 128]` 之一; 233 | 234 | ## TODOs 235 | 236 | 1. 在其他 GPU 上测试; 237 | 2. 在更多 Triton 版本上进行测试; 238 | 3. 提高算子的性能; 239 | 4. 支持对 FlashAttention 的其他功能扩展。 240 | 241 | ## 更多 242 | 243 | 关于智源研究院的更多大模型开源技术,请访问 [BAAI/FlagOpen](https://flagopen.baai.ac.cn/) 查看。 244 | [](https://flagopen.baai.ac.cn/) 245 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FlagAttention 2 | 3 |

4 | flag-attention 5 |

6 | 7 | [中文版](./README_cn.md) 8 | 9 | FlagAttention is a project for memory-efficient attention operators implemented in the [Triton language](https://github.com/openai/triton). Motivated by the need for non-standard attention operators in language modeling, it starts as an extension of multi-head attention. 10 | 11 | It saves memory footprint and traffic like [FlashAttention](https://arxiv.org/abs/2205.14135) and [FlashAttention v2](https://tridao.me/publications/flash2/flash2.pdf). Implemented in the Triton language, it is easier to understand and modify. The original CUDA implementation of FlashAttention ([flash-attention](https://github.com/Dao-AILab/flash-attention)) provides a good example of how to design an algorithm that takes different levels of memory into account. By tiling and re-computation, FlashAttention avoids materializing the attention scores, whose size is proportional to the square of the sequence length. However, a custom transformation of the attention scores is not possible when using FlashAttention, unless it is supported by FlashAttention out-of-the-box. 12 | While extending FlashAttention requires proficiency in CUDA programming, FlagAttention, implemented in the Triton language, is easier to modify. 13 | 14 | FlagAttention now offers two operators. 15 | 16 | 1. **flash_attention**: FlashAttention implemented in the Triton language. 17 | 2. **piecewise_attention**: currently employed for NLPE (non-linearized position embedding) in both training and inference of the [Aquila-2-34B](https://github.com/FlagAI-Open/Aquila2) model. 18 | 19 | When further customization is required, FlagAttention serves as an example. 20 | 21 | ## Changelog 22 | 23 | ### v0.1 24 | 25 | Add piecewise_attention & flash_attention. 26 | 27 | ### v0.2 28 | 29 | Optimizations to the operators: 30 | 1. Apply masking only when needed. 31 | 2. Use a separate kernel to compute the gradient of q to avoid atomic RMW to global memory. 32 | 33 | 34 | ## Requirements 35 | 36 | FlagAttention requires Pytorch and Triton. To use the new features of Triton, a nightly release is recommended. 37 | 38 | 39 | ```sh 40 | # install a nightly release of Triton 41 | pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly 42 | ``` 43 | 44 | FlagAttention requires Ampere Nvidia GPUs (e.g. A100, RTX-3090, ...) and CUDA Toolkit 11.6 or above. Other GPUs may work but have not been tested yet. 45 | 46 | ## Installation 47 | 48 | FlagAttention can be installed in either of the ways below. 49 | 50 | 1. Editable Installation. Changes to the code in the local source tree are effective without re-installation. 51 | 2. Build a distribution and then install. Only the package is installed. 52 | 53 | ### Editable Installation 54 | 55 | Editable installation with pip. 56 | 57 | ```sh 58 | git clone https://github.com/FlagOpen/FlagAttention && cd FlagAttention 59 | pip install -e . 60 | ``` 61 | 62 | ### Build a Distribution & Install 63 | 64 | Following modern Python packaging convention (PEP 517), FlagAttention is configured by [`pyproject.toml`](https://pip.pypa.io/en/stable/reference/build-system/pyproject-toml/), and no `setup.py` is provided. To build a distribution, either a source distribution or a binary distribution, the Python package `build` is recommended. 65 | 66 | First, install the `build` package via pip. 67 | 68 | ```sh 69 | pip install build 70 | ``` 71 | 72 | Then build the package. 73 | 74 | ```sh 75 | git clone https://github.com/FlagOpen/FlagAttention && cd FlagAttention 76 | # to build in `no-isolation` mode requires installing build requirements manually 77 | pip install -U setuptools setuptools-scm 78 | python -m build --no-isolation 79 | ``` 80 | 81 | The built package is in `dist/` for installation. 82 | 83 | ```sh 84 | pip install dist/flag_attn-xxx.whl 85 | ``` 86 | 87 | ## Usage 88 | 89 | FlagAttention provides customized operators for attention. When an operator is equivalent to a torch function, it can be used as a drop-in replacement. 90 |
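For example, when the inputs are laid out as `(batch_size, num_heads, seqlen, headdim)`, `flash_attention` computes the same scaled-dot-product attention as `torch.nn.functional.scaled_dot_product_attention`, so the two can be swapped. The snippet below is a minimal sketch of this claim, not taken from the original examples; it assumes a CUDA device supported by FlagAttention and PyTorch 2.0+ for `scaled_dot_product_attention`.

```python
# A minimal sketch of the drop-in claim above (illustrative only).
import torch
import torch.nn.functional as F
from flag_attn import flash_attention

B, H, T, D = 2, 16, 1024, 128
q = torch.randn(B, H, T, D, dtype=torch.float16, device="cuda")
k = torch.randn(B, H, T, D, dtype=torch.float16, device="cuda")
v = torch.randn(B, H, T, D, dtype=torch.float16, device="cuda")

o_triton = flash_attention(q, k, v, causal=True)                    # FlagAttention operator
o_torch = F.scaled_dot_product_attention(q, k, v, is_causal=True)   # torch counterpart
print((o_triton - o_torch).abs().max())  # expected to be small (fp16 rounding differences)
```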
91 | ## Run the Tests 92 | 93 | A recent version of `pytest` (>=7.1.0) is required to run the tests in `tests/`. Operators in `FlagAttention` are tested against [reference implementations](src/flag_attn/testing) in Pytorch provided by `flag_attn.testing`, both for the forward and backward operators. For operators with support for inputs of `float16` or `bfloat16`, three different implementations are included for numerical accuracy testing. 94 | 95 | 1. **Reference Implementation in Pytorch**: This implementation upcasts the inputs to `float32` and performs the computations in `float32` all the way through before casting the outputs to `float16` or `bfloat16`. 96 | 2. **Triton Implementation**: The Triton implementation uses `float16` or `bfloat16` for MMA (matrix multiplication accumulation) inputs and `float32` for MMA outputs and other computations. 97 | 3. **Pytorch Implementation**: This implementation mirrors the computations in the reference implementation, except that the precision is the same as the Triton implementation. 98 | 99 | The tests for numerical accuracy enforce that the maximum difference between the Triton implementation and the reference implementation is not greater than twice the maximum difference between the Pytorch implementation and the reference implementation. 100 | 101 | ```sh 102 | pytest . 103 | ``` 104 | 105 | ## Run the Benchmark 106 | 107 | Benchmarks are included to quantify the achieved `TFLOP/s`, which serves as a metric of operator speed. The FLOP count for an operator considers only the matmul operations. The total is then divided by the median runtime to determine the achieved FLOP/s. 108 | 109 | The benchmarks compare the Triton implementations with their counterparts in Pytorch. When the input is so large that the Pytorch implementation runs out of memory, its FLOP/s is recorded as zero. 110 | 111 | ```sh 112 | cd benchmark/ 113 | python flash_benchmark.py 114 | python piecewise_benchmark.py 115 | ``` 116 | 117 | ## Operators 118 | 119 | ### flash_attention 120 | 121 | The implementation of FlashAttention in the Triton language. The interface is shown below. 122 | 123 | ```python 124 | flash_attention(q, k, v, causal=False, sm_scale=None, return_log_normalizer=False, return_total_attention=False) 125 | ``` 126 | 127 | In addition to the attention output, it can return some extra outputs depending on `return_log_normalizer` and `return_total_attention`. 128 | 129 | 1. log_normalizer: shape (batch_size, num_heads, seqlen_q). The log normalizer of the softmax inside the attention operation. 130 | 2. total_attention: shape (batch_size, num_heads, seqlen_k). The sum of attention weights along q's sequence axis. 131 |
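A minimal sketch of requesting the auxiliary outputs is shown below. It is illustrative rather than normative: the tuple unpacking with `*_` follows the pattern used in `tests/flag_attn/test_flash_attention.py`, and `examples/flash_attention_with_aux_outputs.py` in this repository covers the same feature.

```python
# Illustrative sketch: flash_attention with auxiliary outputs (assumes a CUDA device).
import torch
from flag_attn import flash_attention

B, H, M, N, D = 2, 8, 1024, 2048, 64
q = torch.randn(B, H, M, D, dtype=torch.float16, device="cuda")
k = torch.randn(B, H, N, D, dtype=torch.float16, device="cuda")
v = torch.randn(B, H, N, D, dtype=torch.float16, device="cuda")

o, log_normalizer, total_attention, *_ = flash_attention(
    q, k, v, causal=False,
    return_log_normalizer=True, return_total_attention=True,
)
print(o.shape)                # (B, H, M, D)
print(log_normalizer.shape)   # (B, H, M): log normalizer of the softmax
print(total_attention.shape)  # (B, H, N): attention weights summed over q's sequence axis
```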
132 | ### piecewise_attention 133 | 134 | The first extension to FlashAttention is [piecewise_attention](src/flag_attn/piecewise.py). This operator enhances FlashAttention by using two `q`'s and two `k`'s to calculate the attention scores (S) before applying softmax to obtain the attention weights (P). 135 | 136 | The rationale behind this design is rooted in the observation that a transformer with rotary position embedding struggles to predict sequences longer than the maximum sequence length it is trained on. Pairs of `(q, k)` yield unexpectedly high attention scores when their distance exceeds the maximum sequence length in the training set. 137 | 138 | To address this issue, BAAI proposes NLPE (Non-Linearized Position Embedding), which applies two different position embeddings to `q` and `k` based on whether the distance between `q` and `k` exceeds a pre-defined threshold, producing `q1, q2` and `k1, k2`. The attention score is then computed as the dot product of `q1, k1` or `q2, k2`, depending on the distance between `q` and `k`. 139 | 140 | 141 | 142 | The interface is shown below. 143 | 144 | ![piecewise_attention_interface](./assets/piecewise_attention_interface.png) 145 | 146 | ```python 147 | piecewise_attention(q1, k1, q2, k2, v, dist_threshold, causal=False, sm_scale=None) 148 | ``` 149 | 150 | It splices the two sets of attention scores (S) in the forward computation and splits the gradient of S in the backward computation. 151 | 152 | ![piecewise attention](assets/piecewise_attention.png) 153 |
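For intuition, the piecewise rule can be written down directly in PyTorch. The sketch below is illustrative only: it materializes the full score matrix, which the Triton operator deliberately avoids, and it assumes that the distance is `key_index - query_index` for equal sequence lengths and that the `(q1, k1)` pair is used within the threshold. The authoritative definition is the reference implementation in `src/flag_attn/testing/piecewise.py`.

```python
# Dense, illustrative sketch of piecewise attention (NOT the memory-efficient operator).
# Assumptions (for illustration only): distance = j - i; (q1, k1) is used when the distance
# is below dist_threshold and (q2, k2) otherwise; causal masking assumes M == N.
import math
import torch

def piecewise_attention_sketch(q1, k1, q2, k2, v, dist_threshold, causal=False, sm_scale=None):
    B, H, M, D = q1.shape
    N = k1.shape[2]
    sm_scale = sm_scale if sm_scale is not None else 1.0 / math.sqrt(D)
    s1 = torch.einsum("bhmd,bhnd->bhmn", q1.float(), k1.float())   # scores from (q1, k1)
    s2 = torch.einsum("bhmd,bhnd->bhmn", q2.float(), k2.float())   # scores from (q2, k2)
    dist = torch.arange(N, device=q1.device)[None, :] - torch.arange(M, device=q1.device)[:, None]
    s = torch.where(dist < dist_threshold, s1, s2) * sm_scale      # piecewise selection
    if causal:
        s = s.masked_fill(dist > 0, float("-inf"))
    p = torch.softmax(s, dim=-1)
    return torch.einsum("bhmn,bhnd->bhmd", p, v.float()).to(v.dtype)
```

The Triton operator performs this piecewise selection tile by tile, so neither `s` nor `p` is ever materialized in full.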
154 | #### Usage 155 | 156 | ```python 157 | # piecewise_attention 158 | import torch 159 | from flag_attn import piecewise_attention 160 | 161 | B, H, T, D = 2, 16, 8192, 128 162 | dist_threshold = T // 2 163 | 164 | q1 = torch.randn((B, H, T, D), dtype=torch.float16, device="cuda:0").requires_grad_() 165 | q2 = torch.randn((B, H, T, D), dtype=torch.float16, device="cuda:0").requires_grad_() 166 | k1 = torch.randn((B, H, T, D), dtype=torch.float16, device="cuda:0").requires_grad_() 167 | k2 = torch.randn((B, H, T, D), dtype=torch.float16, device="cuda:0").requires_grad_() 168 | v = torch.randn((B, H, T, D), dtype=torch.float16, device="cuda:0").requires_grad_() 169 | o = piecewise_attention(q1, k1, q2, k2, v, dist_threshold, causal=True) 170 | print(o) 171 | 172 | go = torch.randn((B, H, T, D), dtype=torch.float16, device="cuda:0") 173 | gq1, gk1, gq2, gk2, gv = torch.autograd.grad( 174 | o, (q1, k1, q2, k2, v), go 175 | ) 176 | print(gq1) 177 | ``` 178 | 179 | ```python 180 | # flash_attention 181 | import torch 182 | from flag_attn import flash_attention 183 | 184 | B, H, T, D = 2, 16, 8192, 128 185 | 186 | q = torch.randn((B, H, T, D), dtype=torch.float16, device="cuda:0").requires_grad_() 187 | k = torch.randn((B, H, T, D), dtype=torch.float16, device="cuda:0").requires_grad_() 188 | v = torch.randn((B, H, T, D), dtype=torch.float16, device="cuda:0").requires_grad_() 189 | o = flash_attention(q, k, v, causal=True) 190 | print(o) 191 | 192 | go = torch.randn((B, H, T, D), dtype=torch.float16, device="cuda:0") 193 | gq, gk, gv = torch.autograd.grad( 194 | o, (q, k, v), go 195 | ) 196 | print(gq) 197 | ``` 198 | 199 | #### Performance 200 | 201 | Benchmarks are performed under the following conditions: 202 | 203 | 1. seqlen in `[512, 1k, 2k, 4k, 16k, 32k]`; 204 | 2. batch size: `32k / seqlen`; 205 | 3. headdim in `[64, 128]`; 206 | 4. num_heads: `2048 / headdim`. 207 | 208 | ##### flash_attention 209 | 210 | The performance of flash_attention with causal masking is shown below. 211 | 212 | ![headdim64](./assets/v0.2/flash_attention_d64.png) 213 | 214 | ![headdim128](./assets/v0.2/flash_attention.png) 215 | 216 | The forward operator runs as fast as, and in some cases faster than, FlashAttention (CUDA), but the backward operator is generally slower than FlashAttention. We first follow the paper and update the gradient of Q with atomic addition in the backward operator, which runs extremely slowly. Then we split the backward operator into two kernels, one to compute the gradient of k and v, the other to compute the gradient of q. This change avoids atomic additions but introduces more re-computation. Although this strategy yields a 4x to 5x speedup in the backward operator, it is still slower than FlashAttention (CUDA). 217 | 218 | The same split-kernel trick is also applied to `piecewise_attention` for efficiency. 219 | 220 | ##### piecewise_attention 221 | 222 | The performance of piecewise_attention has improved compared to that in v0.1. In the case where the head dim is 128 and causal masking is applied, the forward and backward operators are faster than those in v0.1 by 36% and 9%, respectively. 223 | 224 | ![piecewise_attention](./assets/v0.2/piecewise_attention.png) 225 | 226 | #### Features 227 | 228 | - support for [Nvidia](https://www.nvidia.com/) Ampere GPUs (tested on RTX-3090 and A100); 229 | - support for [Iluvatar CoreX](https://www.iluvatar.com/) GPUs (tested on Iluvatar CoreX MR-V100); 230 | - datatype support: `float16`, and also `bfloat16` on Ampere Nvidia GPUs; 231 | - support for causal and non-causal modes; 232 | - support for forward & backward computation; 233 | - the sequence length of k/v can be different from that of q; 234 | - support for computing the total attention that each `k` receives from all `q`'s; 235 | - support for returning the accumulated attention of each key; 236 | - support for [MQA](https://arxiv.org/abs/1911.02150) and [GQA](https://arxiv.org/pdf/2305.13245); 237 | - support for dropout of attention weights. 238 | 239 | #### Limitations 240 | 241 | - `headdim` should be in `[16, 32, 64, 128]`. 242 | 243 | ## TODOs 244 | 245 | 1. Test on other GPUs; 246 | 2. Test on more versions of Triton; 247 | 3. Improve performance of the attention operators (especially the backward op); 248 | 4. Support other extensions to flash attention. 249 | 250 | ## More 251 | 252 | For more about the open source system for large models from BAAI, please visit [BAAI/FlagOpen](https://flagopen.baai.ac.cn/). 253 | [](https://flagopen.baai.ac.cn/) 254 | -------------------------------------------------------------------------------- /src/flag_attn/split_kv.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import triton 4 | import triton.language as tl 5 | 6 | """ 7 | This file implements flash decoding, flash attention with split_kv, which exposes another 8 | dimension of parallelism when batch_size * num_heads * blocks_along_seqlen_q cannot saturate 9 | the GPU's SMs. 10 | 11 | For more details, refer to https://princeton-nlp.github.io/flash-decoding/.
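
A minimal usage sketch (added here for illustration; it assumes a CUDA device and a head
dim supported by the kernel, and calls this module's `attention` function directly):

    import torch
    from flag_attn.split_kv import attention as flash_attention_split_kv

    B, H, M, N, D = 1, 16, 1, 8192, 128   # decoding: one query position, long k/v
    q = torch.randn(B, H, M, D, dtype=torch.float16, device="cuda")
    k = torch.randn(B, H, N, D, dtype=torch.float16, device="cuda")
    v = torch.randn(B, H, N, D, dtype=torch.float16, device="cuda")
    o = flash_attention_split_kv(q, k, v, causal=False)   # (B, H, M, D)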
12 | """ 13 | 14 | @triton.jit 15 | def _fwd_split_kv_kernel( 16 | Q, K, V, sm_scale, 17 | L, O, 18 | stride_qz, stride_qh, stride_qm, stride_qk, 19 | stride_kz, stride_kh, stride_kn, stride_kk, 20 | stride_vz, stride_vh, stride_vn, stride_vk, 21 | stride_oz, stride_oh, stride_os, stride_om, stride_ok, 22 | Z, H, M, N, P_SEQ, N_SPLIT_SIZE, S, num_groups, 23 | BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, 24 | IS_CAUSAL: tl.constexpr, LARGER_M: tl.constexpr, 25 | DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr, 26 | ): 27 | input_dtype = Q.dtype.element_ty 28 | # -- grid id -- 29 | start_m = tl.program_id(0) 30 | n_split_id = tl.program_id(1) 31 | off_zh = tl.program_id(2) 32 | off_h = off_zh % H 33 | off_z = off_zh // H 34 | off_hk = off_h // num_groups 35 | 36 | # scale sm_scale by log_2(e) and use 37 | # 2^x instead of exp in the loop because CSE and LICM 38 | # don't work as expected with `exp` in the loop 39 | log2e: tl.constexpr = 1.4426950408889634 40 | qk_scale = sm_scale * log2e 41 | 42 | # offset pointers for (batch & head) 43 | Q += off_z * stride_qz + off_h * stride_qh 44 | K += off_z * stride_kz + off_hk * stride_kh 45 | V += off_z * stride_vz + off_hk * stride_vh 46 | 47 | # offset pointers for (batch & head, split) 48 | O += off_z * stride_oz + off_h * stride_oh + n_split_id * stride_os # o's shape is (B, H, S, M, D) 49 | L += ((off_z * H + off_h) * S + n_split_id) * M # l's shape is (B, H, S, M) 50 | 51 | offs_m_base = tl.arange(0, BLOCK_M) 52 | offs_m = start_m * BLOCK_M + offs_m_base 53 | offs_n_base = tl.arange(0, BLOCK_N) 54 | offs_k = tl.arange(0, BLOCK_DMODEL) 55 | 56 | # initialize pointers to value-like data 57 | q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk) # (BLOCK_M, BLOCK_DMODEL) 58 | o_ptrs = O + (offs_m[:, None] * stride_om + offs_k[None, :] * stride_ok) # (BLOCK_M, BLOCK_DMODEL) 59 | l_ptrs = L + offs_m 60 | 61 | # initialize pointer to m and l, fp32 for accumulators 62 | m_i = tl.full([BLOCK_M], value=-float("inf"), dtype=tl.float32) 63 | l_i = tl.zeros([BLOCK_M], dtype=tl.float32) 64 | acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) 65 | 66 | # load q 67 | if DIVISIBLE_M: 68 | q = tl.load(q_ptrs) 69 | else: 70 | mask_m = offs_m < M 71 | q = tl.load(q_ptrs, mask=mask_m[:, None]) 72 | 73 | #Dot I trick: to place q in registers, it saves shared memory 74 | if BLOCK_DMODEL < 128: 75 | I = tl.where(offs_k[:, None] == offs_k, 76 | tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 1.0, dtype=input_dtype), 77 | tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 0.0, dtype=input_dtype)) 78 | q = tl.dot(q, I).to(input_dtype) 79 | # else: 80 | # I = tl.where(offs_m_base[:, None] == offs_m_base, 81 | # tl.full((BLOCK_M, BLOCK_M), 1.0, dtype=input_dtype), 82 | # tl.full((BLOCK_M, BLOCK_M), 0.0, dtype=input_dtype)) 83 | # q = tl.dot(I, q).to(input_dtype) 84 | 85 | # NOTE: Loop-Bound-For-N 86 | # The indices in m-dimension that this block may access is in `[start_m * BLOCK_M, (start_m + 1) * BLOCK_M)`. 87 | # According to the rule of causal masking, then max index in n-dimension that this block may access 88 | # is `P_SEQ + (start_m + 1) * BLOCK_M`. 89 | # However, the upper bound of index in n-dimension should never exceed the sequence length of k/v(`P_SEQ + N_CTX`). 90 | # `P_SEQ + (start_m + 1) * BLOCK_M` may be larger than `N`. 91 | # At this case, there would be illegal memory access when loading k & v tiles 92 | # if mask_n is not applied for loading(only when `DIVISIBLE_N`` is true). 
93 | # See also https://github.com/FlagOpen/FlagAttention/pull/8 94 | N_LEFT = n_split_id * N_SPLIT_SIZE 95 | N_RIGHT = tl.minimum(N_LEFT + N_SPLIT_SIZE, N) 96 | if IS_CAUSAL: 97 | hi = tl.minimum(N_RIGHT, P_SEQ + (start_m + 1) * BLOCK_M) 98 | if LARGER_M: 99 | hi = tl.maximum(N_LEFT, hi) 100 | else: 101 | hi = N_RIGHT 102 | 103 | # loop over k, v and update accumulators 104 | offs_n_init = N_LEFT + offs_n_base 105 | k_ptrs = K + (offs_k[:, None] * stride_vk + offs_n_init[None, :] * stride_vn) # (BLOCK_DMODEL, BLOCK_N) 106 | v_ptrs = V + (offs_n_init[:, None] * stride_kn + offs_k[None, :] * stride_kk) # (BLOCK_N, BLOCK_DMODEL) 107 | for start_n in range(N_LEFT, hi, BLOCK_N): 108 | start_n = tl.multiple_of(start_n, BLOCK_N) 109 | offs_n = start_n + offs_n_base 110 | 111 | # -- load k, v -- 112 | if DIVISIBLE_N: 113 | k = tl.load(k_ptrs, cache_modifier=".cg") 114 | v = tl.load(v_ptrs, cache_modifier=".cg") 115 | else: 116 | mask_n = offs_n < N 117 | k = tl.load(k_ptrs, mask=mask_n[None, :], cache_modifier=".cg") 118 | v = tl.load(v_ptrs, mask=mask_n[:, None], cache_modifier=".cg") 119 | 120 | # -- compute qk --- 121 | s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) 122 | s += tl.dot(q, k) 123 | 124 | if not DIVISIBLE_N: 125 | s = tl.where(mask_n[None, :], s, float("-inf")) 126 | if IS_CAUSAL: 127 | causal_mask = (P_SEQ + offs_m[:, None]) >= offs_n[None, :] 128 | s = tl.where(causal_mask, s, float("-inf")) 129 | 130 | # -- compute scaling constant --- 131 | m_i_new = tl.maximum(m_i, tl.max(s, 1)) 132 | alpha = tl.math.exp2((m_i - m_i_new) * qk_scale) 133 | p = tl.math.exp2(s * qk_scale - m_i_new[:, None] * qk_scale) 134 | 135 | # -- scale and update acc: acc *= alpha[:, None]-- 136 | acc *= alpha[:, None] 137 | acc += tl.dot(p.to(input_dtype), v) 138 | 139 | # -- update m_i and l_i -- 140 | l_i = l_i * alpha + tl.sum(p, 1) 141 | m_i = m_i_new 142 | # update pointers 143 | k_ptrs += BLOCK_N * stride_kn 144 | v_ptrs += BLOCK_N * stride_vn 145 | 146 | # write back l & o 147 | if IS_CAUSAL and LARGER_M: 148 | is_empty_line = (offs_m + P_SEQ) < 0 149 | acc = tl.where(is_empty_line[:, None], 0.0, acc * (1.0 / l_i[:, None])) 150 | l = tl.where(is_empty_line, float("-inf"), m_i * sm_scale + tl.log(l_i)) 151 | else: 152 | acc = acc * (1.0 / l_i[:, None]) 153 | l = m_i * sm_scale + tl.log(l_i) # log(normalizer) 154 | 155 | if DIVISIBLE_M: 156 | tl.store(l_ptrs, l, cache_modifier=".cg") 157 | tl.store(o_ptrs, acc.to(input_dtype), cache_modifier=".cg") 158 | else: 159 | tl.store(l_ptrs, l, mask=mask_m, cache_modifier=".cg") 160 | tl.store(o_ptrs, acc.to(input_dtype), mask=mask_m[:, None], cache_modifier=".cg") 161 | 162 | @triton.jit 163 | def _fwd_combine_kv_splits( 164 | multiple_o, multiple_l, 165 | final_o, final_l, 166 | stride_mul_oz, stride_mul_oh, stride_mul_os, stride_mul_om, stride_mul_ok, 167 | stride_fin_oz, stride_fin_oh, stride_fin_om, stride_fin_ok, 168 | Z, H, M, S, 169 | BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, 170 | DIVISIBLE_M: tl.constexpr, 171 | ): 172 | start_m = tl.program_id(0) 173 | offs_h = tl.program_id(1) 174 | offs_z = tl.program_id(2) 175 | 176 | # offset 177 | multiple_o += offs_z * stride_mul_oz + offs_h * stride_mul_oh # (B, H, S, M, D) 178 | multiple_l += (offs_z * H + offs_h) * S * M # (B, H, S, M) 179 | 180 | offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) 181 | if not DIVISIBLE_M: 182 | mask_m = offs_m < M 183 | 184 | # 1st loop: online logsumexp to save a swipe 185 | m = tl.full([BLOCK_M], value=float("-inf"), dtype=tl.float32) 186 | acc = 
tl.full([BLOCK_M], value=float(0.0), dtype=tl.float32) 187 | l_ptrs = multiple_l + offs_m 188 | for _ in range(0, S): 189 | if DIVISIBLE_M: 190 | l = tl.load(l_ptrs) 191 | else: 192 | l = tl.load(l_ptrs, mask=mask_m) 193 | m_new = tl.maximum(m, l) 194 | acc = acc * tl.exp(m - m_new) + tl.exp(l - m_new) 195 | m = m_new 196 | l_ptrs += M 197 | l_acc = m + tl.log(acc) 198 | 199 | # 2rd loop to rescale and accumulate o 200 | o_acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) 201 | l_ptrs = multiple_l + offs_m 202 | offs_k = tl.arange(0, BLOCK_DMODEL) 203 | o_ptrs = multiple_o + offs_m[:, None] * stride_mul_om + offs_k[None, :] * stride_mul_ok 204 | for _ in range(0, S): 205 | l = tl.load(l_ptrs, mask=offs_m < M) 206 | rescale = tl.exp(l - l_acc) 207 | if DIVISIBLE_M: 208 | o = tl.load(o_ptrs, ) 209 | else: 210 | o = tl.load(o_ptrs, mask=mask_m[:, None]) 211 | o_acc += o * rescale[:, None] 212 | 213 | l_ptrs += M 214 | o_ptrs += stride_mul_os 215 | 216 | # write back 217 | final_o += offs_z * stride_fin_oz + offs_h * stride_fin_oh 218 | final_l += (offs_z * H + offs_h) * M 219 | a_ptrs = final_o + offs_m[:, None] * stride_fin_om + offs_k * stride_fin_ok 220 | b_ptrs = final_l + offs_m 221 | 222 | if DIVISIBLE_M: 223 | tl.store(a_ptrs, o_acc) 224 | tl.store(b_ptrs, l_acc) 225 | else: 226 | tl.store(a_ptrs, o_acc, mask=mask_m[:, None]) 227 | tl.store(b_ptrs, l_acc, mask=mask_m) 228 | 229 | def get_fwd_config(B, H, M, N, D, causal): 230 | # BLOCK_M, BLOCK_N, num_stages, num_warps 231 | return (16, 128, 1, 4) 232 | 233 | # this function is adapted from https://github.com/Dao-AILab/flash-attention/blob/61a777247900f6c2a37376f3ffd7134385fdc95c/csrc/flash_attn/flash_api.cpp#L235 234 | def num_splits_herustic(B, H, M, N, BLOCK_M, BLOCK_N, num_sms, max_splits): 235 | num_blocks_without_split_kv = B * H * triton.cdiv(M, BLOCK_M) 236 | if num_blocks_without_split_kv >= 0.8 * num_sms: 237 | return 1 238 | 239 | num_n_blocks = triton.cdiv(N, BLOCK_N) 240 | def num_split_avaiable(s): 241 | blocks_per_split = triton.cdiv(num_n_blocks, s) 242 | return s == 1 or (blocks_per_split * s - num_n_blocks < blocks_per_split) 243 | 244 | def efficiency(s): 245 | n_waves = (num_blocks_without_split_kv * s) / num_sms 246 | eff = n_waves / math.ceil(n_waves) 247 | return eff 248 | 249 | max_efficiency = 0.0 250 | plans = [] # (num_split, efficiency) 251 | max_splits = min(num_sms, num_n_blocks, max_splits) 252 | 253 | for num_split in range(1, max_splits + 1): 254 | if num_split_avaiable(num_split): 255 | eff = efficiency(num_split) 256 | plans.append((num_split, eff)) 257 | max_efficiency = max(eff, max_efficiency) 258 | 259 | for num_split, eff in plans: 260 | if eff >= 0.85 * max_efficiency: 261 | return num_split 262 | return 1 263 | 264 | 265 | # flash decoding 266 | def attention(q, k, v, causal=False, sm_scale=None): 267 | Dq, Dk, Dv = q.shape[-1], k.shape[-1], v.shape[-1] 268 | assert Dq == Dk == Dv 269 | assert Dk in {16, 32, 64, 128} 270 | 271 | B, H, M, D = q.shape 272 | N = k.shape[2] 273 | Hk, Hv = k.shape[1], v.shape[1] 274 | assert Hk == Hv, "num of heads in k and v should be equal" 275 | assert H % Hk == 0, "number of heads in q must be a multiple of that in k & v" 276 | num_groups = H // Hk 277 | P_SEQ = N - M 278 | larger_m = M > N 279 | 280 | if sm_scale is None: 281 | sm_scale = 1. 
/ math.sqrt(D) 282 | 283 | # to work around https://github.com/openai/triton/issues/2441 284 | device = torch.cuda.device_of(q) 285 | num_sms = torch.cuda.get_device_properties(device).multi_processor_count 286 | 287 | with torch.cuda.device(device): 288 | config = get_fwd_config(B, H, M, N, D, causal) 289 | BLOCK_M, BLOCK_N, num_stages, num_warps = config 290 | S = num_splits_herustic(B, H, M, N, BLOCK_M, BLOCK_N, num_sms, 128) 291 | 292 | divisible_m = M % BLOCK_M == 0 293 | divisible_n = N % BLOCK_N == 0 294 | 295 | # consider using 3d grid to avoid div & rem 296 | multiple_l = torch.empty((B, H, S, M), dtype=torch.float32, device="cuda") 297 | multiple_o = torch.empty((B, H, S, M, D), dtype=torch.float16, device="cuda") 298 | grid = (triton.cdiv(M, BLOCK_M), S, H * B) 299 | N_SPLIT_SIZE = triton.cdiv(triton.cdiv(N, BLOCK_N), S) * BLOCK_N 300 | _fwd_split_kv_kernel[grid]( 301 | q, k, v, sm_scale, 302 | multiple_l, multiple_o, 303 | q.stride(0), q.stride(1), q.stride(2), q.stride(3), 304 | k.stride(0), k.stride(1), k.stride(2), k.stride(3), 305 | v.stride(0), v.stride(1), v.stride(2), v.stride(3), 306 | multiple_o.stride(0), multiple_o.stride(1), multiple_o.stride(2), multiple_o.stride(3), multiple_o.stride(4), 307 | B, H, M, N, P_SEQ, N_SPLIT_SIZE, S, num_groups, 308 | BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N, 309 | IS_CAUSAL=causal, LARGER_M=larger_m, 310 | DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n, 311 | num_stages=num_stages, num_warps=num_warps, 312 | ) 313 | 314 | if S == 1: 315 | return multiple_o.squeeze(2) 316 | 317 | final_l = torch.empty((B, H, M), dtype=torch.float32, device="cuda") 318 | final_o = torch.empty_like(q) 319 | grid = (triton.cdiv(M, BLOCK_M), H, B) 320 | _fwd_combine_kv_splits[grid]( 321 | multiple_o, multiple_l, 322 | final_o, final_l, 323 | multiple_o.stride(0), multiple_o.stride(1), multiple_o.stride(2), multiple_o.stride(3), multiple_o.stride(4), 324 | final_o.stride(0), final_o.stride(1), final_o.stride(2), final_o.stride(3), 325 | B, H, M, S, 326 | BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, 327 | DIVISIBLE_M=divisible_m, 328 | num_stages=num_stages, num_warps=num_warps, 329 | ) 330 | return final_o 331 | -------------------------------------------------------------------------------- /tests/flag_attn/test_flash_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | 4 | import flag_attn 5 | 6 | torch.random.manual_seed(10086) 7 | 8 | def max_diff(a, b): 9 | return (a - b).abs().max().item() 10 | 11 | def zero_percent(a, b): 12 | diff = (a - b).abs() 13 | num_non_zeros = diff.nonzero().shape[0] 14 | return (1.0 - num_non_zeros/ diff.numel()) * 100.0 15 | 16 | def report(name, actual, expected): 17 | print(f"{name}: \tmax_difference: {max_diff(actual, expected):0.6f}\tzero_diff elements: {zero_percent(actual, expected):0.3f}%") 18 | 19 | 20 | @pytest.mark.parametrize('device_id', list(range(torch.cuda.device_count()))) 21 | @pytest.mark.parametrize('scale', [1.0, 2.0, 3.0, 4.0]) 22 | @pytest.mark.parametrize('B, Hq, Hk, M, N, D', [ 23 | (2, 4, 4, 512, 612, 128), 24 | (2, 4, 4, 1024, 1034, 64), 25 | (2, 4, 4, 2048, 2048, 32), 26 | (2, 4, 4, 4096, 4096, 16), 27 | (2, 4, 4, 4001, 4001, 32), 28 | (2, 4, 4, 4001, 4096, 64), 29 | (2, 4, 4, 4096, 4000, 128), 30 | (1, 2, 2, 8192, 8202, 16), 31 | (1, 2, 2, 8192, 8192, 32), 32 | # test for mqa/gqa 33 | (2, 4, 2, 512, 612, 128), 34 | (2, 4, 1, 1024, 1034, 64), 35 | (2, 4, 2, 2048, 2048, 32), 36 | (2, 4, 1, 4096, 4096, 16), 37 | (2, 4, 2, 
4001, 4001, 32), 38 | (2, 4, 1, 4001, 4096, 64), 39 | (2, 4, 2, 4096, 4000, 128), 40 | (1, 2, 1, 8192, 8202, 16), 41 | (1, 2, 1, 8192, 8192, 32), 42 | ]) 43 | @pytest.mark.parametrize('causal', [True, False]) 44 | @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) 45 | @pytest.mark.parametrize('stride_order', ['BHTD', 'BTHD']) 46 | def test_attention_fwd(B, Hq, Hk, M, N, D, causal, stride_order, dtype, scale, device_id): 47 | device = f"cuda:{device_id}" 48 | if stride_order == "BHTD": 49 | q = torch.empty((B, Hq, M, D), dtype=dtype, device=device).normal_(mean=0., std=scale) 50 | k = torch.empty((B, Hk, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale) 51 | v = torch.empty((B, Hk, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale) 52 | else: 53 | q = torch.empty((B, M, Hq, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2) 54 | k = torch.empty((B, N, Hk, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2) 55 | v = torch.empty((B, N, Hk, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2) 56 | 57 | o_ref = flag_attn.testing.flash_attention(q, k, v, causal, upcast=True) 58 | o_torch = flag_attn.testing.flash_attention(q, k, v, causal, upcast=False) 59 | o_hyp = flag_attn.flash_attention(q, k, v, causal) 60 | 61 | torch_max_diff = max_diff(o_torch, o_ref) 62 | triton_max_diff = max_diff(o_hyp, o_ref) 63 | report("o hyp", o_hyp, o_ref) 64 | report("o torch", o_hyp, o_ref) 65 | assert triton_max_diff <= 2 * torch_max_diff + 1e-5 66 | 67 | 68 | @pytest.mark.parametrize('device_id', list(range(torch.cuda.device_count()))) 69 | @pytest.mark.parametrize('scale', [10.0]) 70 | @pytest.mark.parametrize('B, Hq, Hk, M, N, D', [ 71 | (2, 4, 4, 1, 612, 128), 72 | (2, 4, 4, 1, 1034, 64), 73 | (2, 4, 4, 1, 2048, 32), 74 | (2, 4, 4, 1, 4096, 16), 75 | (2, 4, 4, 1, 4001, 32), 76 | (2, 4, 4, 1, 4096, 64), 77 | (2, 4, 4, 2, 4000, 128), 78 | (1, 2, 2, 4, 8202, 16), 79 | (1, 2, 2, 1, 8192, 32), 80 | # test for mqa/gqa 81 | (2, 4, 2, 1, 612, 128), 82 | (2, 4, 1, 1, 1034, 64), 83 | (2, 4, 2, 1, 2048, 32), 84 | (2, 4, 1, 1, 4096, 16), 85 | (2, 4, 2, 1, 4001, 32), 86 | (2, 4, 1, 1, 4096, 64), 87 | (2, 4, 2, 2, 4000, 128), 88 | (1, 2, 1, 4, 8202, 16), 89 | (1, 2, 1, 1, 8192, 32), 90 | ]) 91 | @pytest.mark.parametrize('causal', [True, False]) 92 | @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) 93 | @pytest.mark.parametrize('stride_order', ['BHTD', 'BTHD']) 94 | def test_attention_splitkv(B, Hq, Hk, M, N, D, causal, stride_order, dtype, scale, device_id): 95 | device = f"cuda:{device_id}" 96 | if stride_order == "BHTD": 97 | q = torch.empty((B, Hq, M, D), dtype=dtype, device=device).normal_(mean=0., std=scale) 98 | k = torch.empty((B, Hk, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale) 99 | v = torch.empty((B, Hk, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale) 100 | else: 101 | q = torch.empty((B, M, Hq, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2) 102 | k = torch.empty((B, N, Hk, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2) 103 | v = torch.empty((B, N, Hk, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2) 104 | 105 | o_ref = flag_attn.testing.flash_attention(q, k, v, causal, upcast=True) 106 | o_torch = flag_attn.testing.flash_attention(q, k, v, causal, upcast=False) 107 | o_hyp = flag_attn.flash_attention(q, k, v, causal) 108 | 109 | torch_max_diff = max_diff(o_torch, 
o_ref) 110 | triton_max_diff = max_diff(o_hyp, o_ref) 111 | report("o hyp", o_hyp, o_ref) 112 | report("o torch", o_hyp, o_ref) 113 | assert triton_max_diff <= 2 * torch_max_diff + 1e-5 114 | 115 | @pytest.mark.parametrize('device_id', list(range(torch.cuda.device_count()))) 116 | @pytest.mark.parametrize('scale', [1.0, 2.0, 3.0, 4.0]) 117 | @pytest.mark.parametrize('B, Hq, Hk, M, N, D', [ 118 | (2, 4, 4, 512, 612, 128), 119 | (2, 4, 4, 1024, 1034, 64), 120 | (2, 4, 4, 2048, 2048, 32), 121 | (2, 4, 4, 4096, 4096, 16), 122 | (2, 4, 4, 4001, 4001, 32), 123 | (2, 4, 4, 4001, 4096, 64), 124 | (2, 4, 4, 4096, 4001, 128), 125 | (1, 2, 2, 8192, 8202, 16), 126 | (1, 2, 2, 8192, 8192, 32), 127 | (2, 4, 4, 10006, 10, 128), 128 | # test for mqa/gqa 129 | (2, 4, 2, 512, 612, 128), 130 | (2, 4, 1, 1024, 1034, 64), 131 | (2, 4, 2, 2048, 2048, 32), 132 | (2, 4, 1, 4096, 4096, 16), 133 | (2, 4, 2, 4001, 4001, 32), 134 | (2, 4, 1, 4001, 4096, 64), 135 | (2, 4, 2, 4096, 4001, 128), 136 | (1, 2, 1, 8192, 8202, 16), 137 | (1, 2, 1, 8192, 8192, 32), 138 | (2, 4, 2, 10006, 10, 128), 139 | ]) 140 | @pytest.mark.parametrize('causal', [True, False]) 141 | @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) 142 | @pytest.mark.parametrize('stride_order', ['BHTD', 'BTHD']) 143 | def test_attention_bwd(B, Hq, Hk, M, N, D, causal, stride_order, dtype, scale, device_id): 144 | device = f"cuda:{device_id}" 145 | if stride_order == "BHTD": 146 | q = torch.empty((B, Hq, M, D), dtype=dtype, device=device).normal_(mean=0., std=scale).requires_grad_() 147 | k = torch.empty((B, Hk, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale).requires_grad_() 148 | v = torch.empty((B, Hk, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale).requires_grad_() 149 | do = torch.randn((B, Hq, M, D), dtype=dtype, device=device) 150 | else: 151 | q = torch.empty((B, M, Hq, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2).requires_grad_() 152 | k = torch.empty((B, N, Hk, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2).requires_grad_() 153 | v = torch.empty((B, N, Hk, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2).requires_grad_() 154 | do = torch.randn((B, M, Hq, D), dtype=dtype, device=device).transpose(1, 2) 155 | 156 | o_ref = flag_attn.testing.flash_attention(q, k, v, causal=causal, upcast=True) 157 | o_torch = flag_attn.testing.flash_attention(q, k, v, causal=causal, upcast=False) 158 | o_hyp = flag_attn.flash_attention(q, k, v, causal=causal) 159 | 160 | gq_ref, gk_ref, gv_ref = torch.autograd.grad(o_ref, (q, k, v), do) 161 | gq_torch, gk_torch, gv_torch = torch.autograd.grad(o_torch, (q, k, v), do) 162 | gq_hyp, gk_hyp, gv_hyp = torch.autograd.grad(o_hyp, (q, k, v), do) 163 | 164 | o_torch_max_diff = max_diff(o_torch, o_ref) 165 | gq_torch_max_diff = max_diff(gq_torch, gq_ref) 166 | gk_torch_max_diff = max_diff(gk_torch, gk_ref) 167 | gv_torch_max_diff = max_diff(gv_torch, gv_ref) 168 | 169 | o_triton_max_diff = max_diff(o_hyp, o_ref) 170 | gq_triton_max_diff = max_diff(gq_hyp, gq_ref) 171 | gk_triton_max_diff = max_diff(gk_hyp, gk_ref) 172 | gv_triton_max_diff = max_diff(gv_hyp, gv_ref) 173 | 174 | assert o_triton_max_diff < 2 * o_torch_max_diff + 1e-5 175 | assert gq_triton_max_diff < 2 * gq_torch_max_diff + 1e-5 176 | assert gk_triton_max_diff < 2 * gk_torch_max_diff + 1e-5 177 | assert gv_triton_max_diff < 2 * gv_torch_max_diff + 1e-5 178 | 179 | 180 | @pytest.mark.parametrize('device_id', 
list(range(torch.cuda.device_count()))) 181 | @pytest.mark.parametrize('scale', [1.0, 2.0, 3.0, 4.0]) 182 | @pytest.mark.parametrize('B, H, M, N, D', [ 183 | (2, 4, 512, 612, 128), 184 | (2, 4, 1024, 1034, 64), 185 | (2, 4, 2048, 2048, 32), 186 | (2, 4, 4096, 4096, 16), 187 | (2, 4, 4001, 4001, 32), 188 | (2, 4, 4001, 4096, 64), 189 | (2, 4, 4096, 4001, 128), 190 | (1, 2, 8192, 8202, 16), 191 | (1, 2, 8192, 8192, 32), 192 | ]) 193 | @pytest.mark.parametrize('causal', [False]) 194 | @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) 195 | @pytest.mark.parametrize('stride_order', ['BHTD', 'BTHD']) 196 | def test_attention_with_aux_outs(B, H, M, N, D, causal, stride_order, dtype, scale, device_id): 197 | device = f"cuda:{device_id}" 198 | if stride_order == "BHTD": 199 | q = torch.empty((B, H, M, D), dtype=dtype, device=device).normal_(mean=0., std=scale) 200 | k = torch.empty((B, H, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale) 201 | v = torch.empty((B, H, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale) 202 | else: 203 | q = torch.empty((B, M, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2) 204 | k = torch.empty((B, N, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2) 205 | v = torch.empty((B, N, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2) 206 | 207 | o_ref, log_norm_ref, tot_attn_ref = flag_attn.testing.flash_attention(q, k, v, causal, return_log_normalizer=True, return_total_attention=True, upcast=True) 208 | o_torch, log_norm_torch, tot_attn_torch = flag_attn.testing.flash_attention(q, k, v, causal, return_log_normalizer=True, return_total_attention=True, upcast=False) 209 | o_hyp, log_norm_hyp, tot_attn_hyp, *_ = flag_attn.flash_attention(q, k, v, causal, return_log_normalizer=True, return_total_attention=True) 210 | 211 | 212 | torch_max_diff = max_diff(o_torch, o_ref) 213 | triton_max_diff = max_diff(o_hyp, o_ref) 214 | assert triton_max_diff <= 2 * torch_max_diff + 1e-5 215 | 216 | torch_max_diff = max_diff(log_norm_torch, log_norm_ref) 217 | triton_max_diff = max_diff(log_norm_hyp, log_norm_ref) 218 | assert triton_max_diff <= 2 * torch_max_diff + 1e-5 219 | 220 | torch_max_diff = max_diff(tot_attn_torch, tot_attn_ref) 221 | triton_max_diff = max_diff(tot_attn_hyp, tot_attn_ref) 222 | assert triton_max_diff <= 2 * torch_max_diff + 1e-5 223 | 224 | 225 | @pytest.mark.parametrize('device_id', list(range(torch.cuda.device_count()))) 226 | @pytest.mark.parametrize('scale', [1.0, 2.0]) 227 | @pytest.mark.parametrize('B, H, M, N, D', [ 228 | (2, 4, 512, 612, 128), 229 | (2, 4, 1024, 1034, 64), 230 | (2, 4, 2048, 2048, 32), 231 | (2, 4, 4096, 4096, 16), 232 | (2, 4, 4001, 4001, 32), 233 | (2, 4, 4001, 4096, 64), 234 | (2, 4, 4096, 4000, 128), 235 | (1, 2, 8192, 8202, 16), 236 | (1, 2, 8192, 8192, 32), 237 | ]) 238 | @pytest.mark.parametrize('causal', [False, True]) 239 | @pytest.mark.parametrize('dropout_p', [0.5, 0.8]) 240 | @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) 241 | @pytest.mark.parametrize('stride_order', ['BHTD', 'BTHD']) 242 | def test_attention_fwd_dropout(B, H, M, N, D, causal, dropout_p, stride_order, dtype, scale, device_id): 243 | device = f"cuda:{device_id}" 244 | if stride_order == "BHTD": 245 | q = torch.empty((B, H, M, D), dtype=dtype, device=device).normal_(mean=0., std=scale) 246 | k = torch.empty((B, H, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale) 247 | v = torch.empty((B, H, N, D), 
dtype=dtype, device=device).normal_(mean=0., std=scale) 248 | else: 249 | q = torch.empty((B, M, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2) 250 | k = torch.empty((B, N, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2) 251 | v = torch.empty((B, N, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2) 252 | 253 | o_hyp, _, _, seed, offset = flag_attn.flash_attention(q, k, v, causal, dropout_p=dropout_p, return_seed_offset=True) 254 | mask = flag_attn.testing.recompute_mask(B, H, M, N, dropout_p, seed, offset, device) 255 | o_ref = flag_attn.testing.flash_attention(q, k, v, causal, dropout_p=dropout_p, dropout_mask=mask, upcast=True) 256 | o_torch = flag_attn.testing.flash_attention(q, k, v, causal, dropout_p=dropout_p, dropout_mask=mask, upcast=False) 257 | 258 | torch_max_diff = max_diff(o_torch, o_ref) 259 | triton_max_diff = max_diff(o_hyp, o_ref) 260 | report("o hyp", o_hyp, o_ref) 261 | report("o torch", o_torch, o_ref) 262 | assert triton_max_diff <= 2 * torch_max_diff + 1e-5 263 | 264 | 265 | import random 266 | # @pytest.mark.parametrize('increment', [random.randint(0, 1000000000) for i in range(100)]) 267 | @pytest.mark.parametrize('device_id', list(range(torch.cuda.device_count()))) 268 | @pytest.mark.parametrize('scale', [1.0, 2.0]) 269 | @pytest.mark.parametrize('B, H, M, N, D', [ 270 | (2, 4, 512, 612, 128), 271 | (2, 4, 1024, 1034, 64), 272 | (2, 4, 2048, 2048, 32), 273 | (2, 4, 4096, 4096, 16), 274 | (2, 4, 4001, 4001, 32), 275 | (2, 4, 4001, 4096, 64), 276 | (2, 4, 4096, 4000, 128), 277 | (1, 2, 8192, 8202, 16), 278 | (1, 2, 8192, 8192, 32), 279 | ]) 280 | @pytest.mark.parametrize('causal', [True, False]) 281 | @pytest.mark.parametrize('dropout_p', [0.5, 0.8]) 282 | @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) 283 | @pytest.mark.parametrize('stride_order', ['BHTD', 'BTHD']) 284 | def test_attention_bwd_dropout(B, H, M, N, D, causal, dropout_p, stride_order, dtype, scale, device_id): 285 | device = f"cuda:{device_id}" 286 | if stride_order == "BHTD": 287 | q = torch.empty((B, H, M, D), dtype=dtype, device=device).normal_(mean=0., std=scale).requires_grad_() 288 | k = torch.empty((B, H, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale).requires_grad_() 289 | v = torch.empty((B, H, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale).requires_grad_() 290 | do = torch.randn((B, H, M, D), dtype=dtype, device=device) 291 | else: 292 | q = torch.empty((B, M, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2).requires_grad_() 293 | k = torch.empty((B, N, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2).requires_grad_() 294 | v = torch.empty((B, N, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2).requires_grad_() 295 | do = torch.randn((B, M, H, D), dtype=dtype, device=device).transpose(1, 2) 296 | 297 | # from flag_attn.dropout import philox_cuda_seed_offset 298 | o_hyp, _, _, seed, offset = flag_attn.flash_attention(q, k, v, causal=causal, dropout_p=dropout_p, return_seed_offset=True) 299 | mask = flag_attn.testing.recompute_mask(B, H, M, N, dropout_p, seed, offset, device) 300 | o_ref = flag_attn.testing.flash_attention(q, k, v, causal=causal, dropout_p=dropout_p, dropout_mask=mask, upcast=True) 301 | o_torch = flag_attn.testing.flash_attention(q, k, v, causal=causal, dropout_p=dropout_p, dropout_mask=mask, upcast=False) 302 | 303 | gq_ref, gk_ref, gv_ref = 
torch.autograd.grad(o_ref, (q, k, v), do) 304 | gq_torch, gk_torch, gv_torch = torch.autograd.grad(o_torch, (q, k, v), do) 305 | gq_hyp, gk_hyp, gv_hyp = torch.autograd.grad(o_hyp, (q, k, v), do) 306 | 307 | o_torch_max_diff = max_diff(o_torch, o_ref) 308 | gq_torch_max_diff = max_diff(gq_torch, gq_ref) 309 | gk_torch_max_diff = max_diff(gk_torch, gk_ref) 310 | gv_torch_max_diff = max_diff(gv_torch, gv_ref) 311 | 312 | o_triton_max_diff = max_diff(o_hyp, o_ref) 313 | gq_triton_max_diff = max_diff(gq_hyp, gq_ref) 314 | gk_triton_max_diff = max_diff(gk_hyp, gk_ref) 315 | gv_triton_max_diff = max_diff(gv_hyp, gv_ref) 316 | 317 | assert o_triton_max_diff < 2 * o_torch_max_diff + 1e-5 318 | assert gq_triton_max_diff < 2 * gq_torch_max_diff + 1e-5 319 | assert gk_triton_max_diff < 2 * gk_torch_max_diff + 1e-5 320 | assert gv_triton_max_diff < 2 * gv_torch_max_diff + 1e-5 -------------------------------------------------------------------------------- /src/flag_attn/paged.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import triton 3 | import triton.language as tl 4 | 5 | # Requires triton 2.2.0 6 | def attention( 7 | query: torch.Tensor, # [num_seqs, NUM_KV_HEADS * QUERY_GROUP_SIZE, HEAD_SIZE] 8 | key_cache: torch.Tensor, # [num_blocks, NUM_KV_HEADS, KV_BLOCK_SIZE, HEAD_SIZE] 9 | value_cache: torch.Tensor, # [num_blocks, NUM_KV_HEADS, KV_BLOCK_SIZE, HEAD_SIZE], required same stride with key_cache 10 | context_lens: torch.Tensor, # [num_seqs] 11 | block_tables: torch.Tensor, # [num_seqs, max_num_blocks_per_seq] 12 | attn_scale: float, 13 | max_context_len: int, 14 | num_splits: int = 0, 15 | ) -> None: 16 | out = torch.empty_like(query) 17 | 18 | num_seqs = query.shape[0] 19 | num_kv_heads = key_cache.shape[1] 20 | kv_block_size = key_cache.shape[2] 21 | head_size = key_cache.shape[3] 22 | query_group_size = query.shape[1] // num_kv_heads 23 | 24 | if query_group_size == 1: 25 | padded_group_size = 1 26 | elif query_group_size < 16: 27 | padded_group_size = 16 28 | else: 29 | padded_group_size = triton.next_power_of_2(query_group_size) 30 | 31 | assert head_size in (16, 32, 64, 128, 256, 512), f"head_size={head_size}" 32 | assert padded_group_size == 1 or kv_block_size >= 16, f"kv_block_size={kv_block_size}" 33 | # query_group_size in (1, 2, 4, 8, 16, 32, 64, 128, 256) 34 | # assert query_group_size > 0 and query_group_size & (query_group_size-1) == 0, f"query_group_size={query_group_size}" 35 | 36 | # config for A100 37 | # TODO: support more devices and optimize 38 | device = torch.cuda.device_of(query) 39 | num_sms = torch.cuda.get_device_properties(device).multi_processor_count 40 | if num_splits == 0: 41 | if num_seqs * num_kv_heads > 2 * num_sms: 42 | num_splits = 1 43 | if max_context_len >= 4096: 44 | partition_size = max(256, kv_block_size) 45 | num_splits = triton.cdiv(max_context_len, partition_size) 46 | else: 47 | partition_size = max(256, kv_block_size) 48 | num_splits = triton.cdiv(max_context_len, partition_size) 49 | if max_context_len <= 1024 or kv_block_size >= 256: 50 | num_splits = 1 51 | elif num_splits > 1: 52 | partition_size = triton.cdiv(max_context_len, num_splits) 53 | partition_size = triton.next_power_of_2(partition_size) 54 | 55 | with torch.cuda.device(device): 56 | if num_splits == 1: 57 | grid = (num_seqs, num_kv_heads, 1) 58 | _paged_attn_kernel[grid]( 59 | out, # dummy input 60 | out, # dummy input 61 | out, 62 | query, 63 | key_cache, 64 | value_cache, 65 | context_lens, 66 | block_tables, 67 | 
attn_scale, 68 | block_tables.stride(0), 69 | block_tables.stride(1), 70 | query.stride(0), 71 | query.stride(1), 72 | query.stride(2), 73 | key_cache.stride(0), 74 | key_cache.stride(1), 75 | key_cache.stride(2), 76 | key_cache.stride(3), 77 | out.stride(0), 78 | out.stride(1), 79 | out.stride(1), 80 | out.stride(1), 81 | out.stride(2), 82 | head_size, 83 | query_group_size, 84 | padded_group_size, 85 | num_kv_heads, 86 | kv_block_size, 87 | PARTITION_SIZE=0, 88 | ) 89 | 90 | else: 91 | grid = (num_seqs, num_kv_heads, num_splits) 92 | m_i = torch.empty( 93 | size=(num_seqs, num_kv_heads, num_splits, query_group_size), 94 | dtype=torch.float32, 95 | device=query.device, 96 | ) 97 | l_i = torch.empty_like(m_i) 98 | tmp_out = torch.empty( 99 | size=( 100 | num_seqs, 101 | num_kv_heads, 102 | num_splits, 103 | query_group_size, 104 | head_size, 105 | ), 106 | dtype=out.dtype, 107 | device=out.device, 108 | ) 109 | 110 | assert (partition_size >= kv_block_size) and (partition_size % kv_block_size == 0), \ 111 | f"partition_size={partition_size}, kv_block_size={kv_block_size}" 112 | _paged_attn_kernel[grid]( 113 | m_i, 114 | l_i, 115 | tmp_out, 116 | query, 117 | key_cache, 118 | value_cache, 119 | context_lens, 120 | block_tables, 121 | attn_scale, 122 | block_tables.stride(0), 123 | block_tables.stride(1), 124 | query.stride(0), 125 | query.stride(1), 126 | query.stride(2), 127 | key_cache.stride(0), 128 | key_cache.stride(1), 129 | key_cache.stride(2), 130 | key_cache.stride(3), 131 | tmp_out.stride(0), 132 | tmp_out.stride(1), 133 | tmp_out.stride(2), 134 | tmp_out.stride(3), 135 | tmp_out.stride(4), 136 | head_size, 137 | query_group_size, 138 | padded_group_size, 139 | num_kv_heads, 140 | kv_block_size, 141 | partition_size, 142 | ) 143 | 144 | reduce_grid = (num_seqs, num_kv_heads) 145 | next_num_splits = triton.next_power_of_2(num_splits) 146 | 147 | _paged_attn_v2_reduce_kernel[reduce_grid]( 148 | out, 149 | m_i, 150 | l_i, 151 | tmp_out, 152 | context_lens, 153 | num_splits, 154 | out.stride(0), 155 | out.stride(1), 156 | out.stride(2), 157 | head_size, 158 | query_group_size, 159 | num_kv_heads, 160 | partition_size, 161 | next_num_splits, 162 | ) 163 | return out 164 | 165 | 166 | def get_num_warps(QUERY_GROUP_SIZE, HEAD_SIZE, KV_BLOCK_SIZE): 167 | if QUERY_GROUP_SIZE == 1: 168 | if HEAD_SIZE >= 128 and KV_BLOCK_SIZE >= 32: 169 | return 16 170 | else: 171 | return 8 172 | else: 173 | return 4 174 | 175 | 176 | def get_num_stages(PARTITION_SIZE, KV_BLOCK_SIZE): 177 | if PARTITION_SIZE == 0: 178 | return 1 179 | else: 180 | if torch.cuda.get_device_capability() == (8, 0): 181 | if KV_BLOCK_SIZE < 256: 182 | return 3 183 | else: 184 | return 2 185 | elif torch.cuda.get_device_capability() == (8, 6): 186 | if KV_BLOCK_SIZE < 256: 187 | return 2 188 | else: 189 | return 1 190 | else: 191 | return 1 192 | 193 | 194 | @triton.heuristics( 195 | { 196 | "num_warps": lambda args: get_num_warps( 197 | args["QUERY_GROUP_SIZE"], args["HEAD_SIZE"], args["KV_BLOCK_SIZE"] 198 | ), 199 | "num_stages": lambda args: get_num_stages( 200 | args["QUERY_GROUP_SIZE"], args["KV_BLOCK_SIZE"] 201 | ), 202 | } 203 | ) 204 | @triton.jit 205 | def _paged_attn_kernel( 206 | m_i_ptr, # [num_seqs, NUM_KV_HEADS, max_num_partitions, QUERY_GROUP_SIZE] 207 | l_i_ptr, # [num_seqs, NUM_KV_HEADS, max_num_partitions, QUERY_GROUP_SIZE] 208 | out_ptr, # [num_seqs, NUM_KV_HEADS, max_num_partitions, QUERY_GROUP_SIZE, HEAD_SIZE] 209 | q_ptr, # [num_seqs, NUM_KV_HEADS * QUERY_GROUP_SIZE, HEAD_SIZE] 210 | k_cache_ptr, # 
[num_blocks, NUM_KV_HEADS, KV_BLOCK_SIZE, HEAD_SIZE] 211 | v_cache_ptr, # [num_blocks, NUM_KV_HEADS, KV_BLOCK_SIZE, HEAD_SIZE] 212 | context_lens_ptr, # [num_seqs] 213 | block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] 214 | attn_scale, 215 | stride_bt0, 216 | stride_bt1, 217 | stride_q0, 218 | stride_q1, 219 | stride_q2, 220 | stride_kv0, 221 | stride_kv1, 222 | stride_kv2, 223 | stride_kv3, 224 | stride_o0, 225 | stride_o1, 226 | stride_o2, 227 | stride_o3, 228 | stride_o4, 229 | HEAD_SIZE: tl.constexpr, 230 | QUERY_GROUP_SIZE: tl.constexpr, 231 | PADDED_QUERY_GROUP_SIZE: tl.constexpr, 232 | NUM_KV_HEADS: tl.constexpr, 233 | KV_BLOCK_SIZE: tl.constexpr, 234 | PARTITION_SIZE: tl.constexpr, 235 | ): 236 | seq_idx = tl.program_id(0) 237 | kv_head_idx = tl.program_id(1) 238 | part_idx = tl.program_id(2) 239 | max_num_partitions = tl.num_programs(2) 240 | 241 | # scale sm_scale by log_2(e) and use 242 | # 2^x instead of exp in the loop because CSE and LICM 243 | # don't work as expected with `exp` in the loop 244 | log2e: tl.constexpr = 1.4426950408889634 245 | 246 | USE_PARTITIONING = PARTITION_SIZE > 0 247 | context_len = tl.load(context_lens_ptr + seq_idx) 248 | if USE_PARTITIONING: 249 | context_start_idx = part_idx * PARTITION_SIZE 250 | if context_start_idx >= context_len: 251 | return 252 | context_end_idx = tl.minimum(context_start_idx + PARTITION_SIZE, context_len) 253 | num_blocks = tl.cdiv(context_end_idx - context_start_idx, KV_BLOCK_SIZE) 254 | else: 255 | num_blocks = tl.cdiv(context_len, KV_BLOCK_SIZE) 256 | 257 | block_offset = tl.arange(0, KV_BLOCK_SIZE) 258 | head_offset = tl.arange(0, HEAD_SIZE) 259 | padding_group_offset = tl.arange(0, PADDED_QUERY_GROUP_SIZE) 260 | 261 | kv_offset = ( 262 | kv_head_idx * stride_kv1 263 | + block_offset[:, None] * stride_kv2 264 | + head_offset[None, :] * stride_kv3 265 | ) 266 | 267 | # Load queries. 268 | q_offset = ( 269 | seq_idx * stride_q0 270 | + (kv_head_idx * QUERY_GROUP_SIZE + padding_group_offset[:, None]) * stride_q1 271 | + head_offset[None, :] * stride_q2 272 | ) 273 | group_mask = padding_group_offset[:, None] < QUERY_GROUP_SIZE 274 | # q: [PADDED_QUERY_GROUP_SIZE, HEAD_SIZE] 275 | q = tl.load(q_ptr + q_offset, mask=group_mask, other=0.0) 276 | 277 | m_i = tl.zeros([PADDED_QUERY_GROUP_SIZE], dtype=tl.float32) - float("inf") 278 | l_i = tl.zeros([PADDED_QUERY_GROUP_SIZE], dtype=tl.float32) 279 | acc = tl.zeros([PADDED_QUERY_GROUP_SIZE, HEAD_SIZE], dtype=tl.float32) 280 | 281 | num_prev_blocks = part_idx * (PARTITION_SIZE // KV_BLOCK_SIZE) 282 | for i in range(num_blocks): 283 | block_idx = num_prev_blocks + i 284 | block_number = tl.load( 285 | block_tables_ptr + seq_idx * stride_bt0 + block_idx * stride_bt1 286 | ) 287 | 288 | # Load a key block. 
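        # block_tables maps each sequence's logical block index to a physical block id
        # in the cache; in tensor terms the loads below roughly fetch
        # key_cache[block_number, kv_head_idx] and value_cache[block_number, kv_head_idx],
        # while kv_mask zero-fills the tail positions of the final block once
        # block_idx * KV_BLOCK_SIZE + block_offset reaches context_len.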
289 | kv_block_offset = block_number * stride_kv0 + kv_offset 290 | mask_offset = block_idx * KV_BLOCK_SIZE + block_offset 291 | kv_mask = mask_offset[:, None] < context_len 292 | 293 | # k: [KV_BLOCK_SIZE, HEAD_SIZE] 294 | k = tl.load(k_cache_ptr + kv_block_offset, mask=kv_mask, other=0.0) 295 | 296 | # qk: [PADDED_QUERY_GROUP_SIZE, KV_BLOCK_SIZE] 297 | if PADDED_QUERY_GROUP_SIZE == 1: 298 | qk = tl.sum(q[:, None, :] * k[None, :, :], axis=2) 299 | else: 300 | qk = tl.dot(q, k.T, out_dtype=tl.float32) 301 | 302 | qk *= attn_scale 303 | qk = tl.where(mask_offset < context_len, qk, float("-inf")) 304 | 305 | m_i_new = tl.maximum(m_i, tl.max(qk, axis=1)) 306 | 307 | # p: [PADDED_QUERY_GROUP_SIZE, KV_BLOCK_SIZE] 308 | p = tl.math.exp2((qk - m_i_new[:, None]) * log2e) 309 | alpha = tl.math.exp2((m_i - m_i_new) * log2e) 310 | acc *= alpha[:, None] 311 | 312 | # v: [KV_BLOCK_SIZE, HEAD_SIZE] 313 | v = tl.load(v_cache_ptr + kv_block_offset, mask=kv_mask, other=0.0) 314 | 315 | if PADDED_QUERY_GROUP_SIZE == 1: 316 | acc += tl.sum(p.T[:, :, None] * v[:, None, :], axis=0) 317 | else: 318 | p = p.to(v.dtype) 319 | acc += tl.dot(p, v, out_dtype=tl.float32) 320 | 321 | l_i = l_i * alpha + tl.sum(p, axis=1) 322 | m_i = m_i_new 323 | acc = acc / l_i[:, None] 324 | 325 | if USE_PARTITIONING: 326 | part_offset = ( 327 | (seq_idx * NUM_KV_HEADS + kv_head_idx) 328 | * max_num_partitions 329 | * QUERY_GROUP_SIZE 330 | + part_idx * QUERY_GROUP_SIZE 331 | + padding_group_offset 332 | ) 333 | mask = padding_group_offset < QUERY_GROUP_SIZE 334 | tl.store(m_i_ptr + part_offset, m_i, mask=mask) 335 | tl.store(l_i_ptr + part_offset, l_i, mask=mask) 336 | 337 | out_offset = seq_idx * stride_o0 338 | if USE_PARTITIONING: 339 | out_offset += kv_head_idx * stride_o1 340 | else: 341 | out_offset += kv_head_idx * QUERY_GROUP_SIZE * stride_o1 342 | out_offset += ( 343 | part_idx * stride_o2 344 | + padding_group_offset[:, None] * stride_o3 345 | + head_offset[None, :] * stride_o4 346 | ) 347 | 348 | group_mask = padding_group_offset[:, None] < QUERY_GROUP_SIZE 349 | tl.store(out_ptr + out_offset, acc, mask=group_mask) 350 | 351 | 352 | @triton.jit 353 | def _paged_attn_v2_reduce_kernel( 354 | out_ptr, # [num_seqs, NUM_KV_HEADS, QUERY_GROUP_SIZE, HEAD_SIZE] 355 | m_i_ptr, # [num_seqs, NUM_KV_HEADS, max_num_partitions, QUERY_GROUP_SIZE] 356 | l_i_ptr, # [num_seqs, NUM_KV_HEADS, max_num_partitions, QUERY_GROUP_SIZE] 357 | tmp_out_ptr, # [num_seqs, NUM_KV_HEADS, max_num_partitions, QUERY_GROUP_SIZE, HEAD_SIZE] 358 | context_lens_ptr, # [num_seqs] 359 | max_num_partitions, # partition stride 360 | stride_o0, 361 | stride_o1, 362 | stride_o2, 363 | HEAD_SIZE: tl.constexpr, 364 | QUERY_GROUP_SIZE: tl.constexpr, 365 | NUM_KV_HEADS: tl.constexpr, 366 | PARTITION_SIZE: tl.constexpr, 367 | NUM_PARTITIONS: tl.constexpr, 368 | ): 369 | seq_idx = tl.program_id(0) 370 | kv_head_idx = tl.program_id(1) 371 | 372 | context_len = tl.load(context_lens_ptr + seq_idx) 373 | 374 | num_partitions = tl.cdiv(context_len, PARTITION_SIZE) 375 | group_head_offset = ( 376 | tl.arange(0, QUERY_GROUP_SIZE)[:, None] * HEAD_SIZE 377 | + tl.arange(0, HEAD_SIZE)[None, :] 378 | ) 379 | if num_partitions == 1: 380 | tmp_out_offset = ( 381 | seq_idx * NUM_KV_HEADS + kv_head_idx 382 | ) * max_num_partitions * QUERY_GROUP_SIZE * HEAD_SIZE + group_head_offset 383 | tmp_out = tl.load(tmp_out_ptr + tmp_out_offset) 384 | 385 | out_offset = ( 386 | seq_idx * stride_o0 387 | + kv_head_idx * QUERY_GROUP_SIZE * stride_o1 388 | + group_head_offset * stride_o2 389 | ) 390 | 
tl.store(out_ptr + out_offset, tmp_out) 391 | return 392 | 393 | # Get the global max logit. 394 | ml_offset = ( 395 | (seq_idx * NUM_KV_HEADS + kv_head_idx) * max_num_partitions * QUERY_GROUP_SIZE 396 | + tl.arange(0, NUM_PARTITIONS)[:, None] * QUERY_GROUP_SIZE 397 | + tl.arange(0, QUERY_GROUP_SIZE)[None, :] 398 | ) 399 | 400 | mask = tl.arange(0, NUM_PARTITIONS)[:, None] < num_partitions 401 | # m_i: [NUM_PARTITIONS, QUERY_GROUP_SIZE] 402 | m_i = tl.load(m_i_ptr + ml_offset, mask=mask, other=float("-inf")) 403 | # m: [QUERY_GROUP_SIZE] 404 | m = tl.max(m_i, axis=0) 405 | 406 | # Rescale the exp sums and compute the global sum. 407 | # l_i: [NUM_PARTITIONS, QUERY_GROUP_SIZE] 408 | l_i = tl.load(l_i_ptr + ml_offset, mask=mask, other=0.0) 409 | l_i *= tl.exp(m_i - m[None, :]) 410 | # l: [QUERY_GROUP_SIZE] 411 | l = tl.sum(l_i, axis=0) 412 | # r: [NUM_PARTITIONS, QUERY_GROUP_SIZE] 413 | r = l_i / l[None, :] 414 | r = tl.reshape(r, (NUM_PARTITIONS, QUERY_GROUP_SIZE, 1)) 415 | 416 | tmp_out_offset = ( 417 | (seq_idx * NUM_KV_HEADS + kv_head_idx) 418 | * max_num_partitions 419 | * QUERY_GROUP_SIZE 420 | * HEAD_SIZE 421 | + tl.arange(0, NUM_PARTITIONS)[:, None, None] * QUERY_GROUP_SIZE * HEAD_SIZE 422 | + tl.arange(0, QUERY_GROUP_SIZE)[None, :, None] * HEAD_SIZE 423 | + tl.arange(0, HEAD_SIZE)[None, None, :] 424 | ) 425 | # tmp_out: [NUM_PARTITIONS, QUERY_GROUP_SIZE, HEAD_SIZE] 426 | tmp_out = tl.load(tmp_out_ptr + tmp_out_offset, mask=mask[:, :, None], other=0.0) 427 | # out: [QUERY_GROUP_SIZE, HEAD_SIZE] 428 | out = tl.sum((tmp_out * r).to(tl.float32), axis=0) 429 | 430 | out_offset = ( 431 | seq_idx * stride_o0 432 | + kv_head_idx * QUERY_GROUP_SIZE * stride_o1 433 | + group_head_offset * stride_o2 434 | ) 435 | tl.store(out_ptr + out_offset, out) 436 | -------------------------------------------------------------------------------- /src/flag_attn/piecewise.py: -------------------------------------------------------------------------------- 1 | """ 2 | Piecewise Attention 3 | ==================== 4 | 5 | This is a extension to Flash Attention v2 algorithm from Tri Dao 6 | (https://tridao.me/publications/flash2/flash2.pdf) that performs piecewise computation 7 | of attention scores(The scores to which softmax is applied). This design originates from 8 | the need to make better predictions when the predicted sequence is longer than sequences 9 | in the training set. 10 | 11 | It takes as input two q's and two k's as inputs. The attention score is the dot product 12 | of (q1, k1) or (q2, k2) depending on whether the distance between q & k exceeds a threshold. 13 | 14 | The code is adapted from triton's [tutorial](https://github.com/openai/triton/blob/5162871c6cae01a8508a309cf21a8e6b68a4c091/python/tutorials/06-fused-attention.py). 15 | """ 16 | 17 | import math 18 | import torch 19 | import triton 20 | import triton.language as tl 21 | 22 | __all__ = ["attention"] 23 | 24 | # --------------------------- public API --------------------------- 25 | class PiecewiseAttention(torch.autograd.Function): 26 | @staticmethod 27 | def forward(ctx, q1, k1, q2, k2, v, w, causal, sm_scale): 28 | # shape constraints 29 | Dq1, Dk1, Dq2, Dk2, Dv = q1.shape[-1], k1.shape[-1], q2.shape[-1], k2.shape[-1], v.shape[-1] 30 | assert Dq1 == Dk1 == Dq2 == Dk2 == Dv 31 | assert Dk1 in {16, 32, 64, 128} 32 | 33 | B, H, M, D = q1.shape 34 | N = k1.shape[2] 35 | P_SEQ = N - M 36 | larger_m = M > N 37 | 38 | if sm_scale is None: 39 | sm_scale = 1. 
/ math.sqrt(D) 40 | 41 | # to work around https://github.com/openai/triton/issues/2441 42 | device = torch.cuda.device_of(q1) 43 | with torch.cuda.device(device): 44 | config = get_fwd_config(B, H, M, N, D, causal) 45 | BLOCK_M, BLOCK_N, num_stages, num_warps = config 46 | 47 | divisible_m = M % BLOCK_M == 0 48 | divisible_n = N % BLOCK_N == 0 49 | 50 | grid = (triton.cdiv(M, BLOCK_M), H, B) 51 | o = torch.empty_like(q1) 52 | L = torch.empty((B, H, M), device=q1.device, dtype=torch.float32) 53 | 54 | _fwd_kernel[grid]( 55 | q1, k1, q2, k2, v, sm_scale, 56 | L, 57 | o, 58 | q1.stride(0), q1.stride(1), q1.stride(2), q1.stride(3), 59 | k1.stride(0), k1.stride(1), k1.stride(2), k1.stride(3), 60 | q2.stride(0), q2.stride(1), q2.stride(2), q2.stride(3), 61 | k2.stride(0), k2.stride(1), k2.stride(2), k2.stride(3), 62 | v.stride(0), v.stride(1), v.stride(2), v.stride(3), 63 | o.stride(0), o.stride(1), o.stride(2), o.stride(3), 64 | B, H, M, N, P_SEQ, 65 | w = w, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=D, 66 | IS_CAUSAL=causal, LARGER_M=larger_m, 67 | DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n, 68 | num_warps=num_warps, num_stages=num_stages, 69 | ) 70 | 71 | ctx.save_for_backward(q1, k1, q2, k2, v, o, L) 72 | ctx.sm_scale = sm_scale 73 | ctx.causal = causal 74 | ctx.w = w 75 | return o 76 | 77 | @staticmethod 78 | def backward(ctx, do): 79 | q1, k1, q2, k2, v, o, L = ctx.saved_tensors 80 | w = ctx.w 81 | causal = ctx.causal 82 | sm_scale = ctx.sm_scale 83 | 84 | B, H, M, D = q1.shape 85 | N = k1.shape[2] 86 | P_SEQ = N - M 87 | larger_m = M > N 88 | 89 | if sm_scale is None: 90 | sm_scale = 1. / math.sqrt(D) 91 | 92 | # to work around https://github.com/openai/triton/issues/2441 93 | device = torch.cuda.device_of(q1) 94 | with torch.cuda.device(device): 95 | config = get_bwd_config(B, H, M, N, D, causal) 96 | BLOCK_M, BLOCK_N, num_stages, num_warps = config 97 | 98 | divisible_m = M % BLOCK_M == 0 99 | divisible_n = N % BLOCK_N == 0 100 | 101 | delta = torch.empty((B, H, M), device=q1.device, dtype=torch.float32) 102 | grid = (triton.cdiv(M, BLOCK_M), H, B) 103 | _bwd_preprocess[grid]( 104 | o, do, 105 | delta, 106 | o.stride(0), o.stride(1), o.stride(2), o.stride(3), 107 | do.stride(0), do.stride(1), do.stride(2), do.stride(3), 108 | delta.stride(0), delta.stride(1), delta.stride(2), 109 | M, 110 | BLOCK_M=BLOCK_M, D_HEAD=D, 111 | DIVISIBLE_M=divisible_m, 112 | ) 113 | 114 | dk1 = torch.empty_like(k1) 115 | dk2 = torch.empty_like(k2) 116 | dv = torch.empty_like(v) 117 | grid = (triton.cdiv(N, BLOCK_N), H, B) 118 | _bwd_kv_kernel[grid]( 119 | q1, k1, q2, k2, v, sm_scale, do, 120 | dk1,dk2, dv, 121 | L, delta, 122 | q1.stride(0), q1.stride(1), q1.stride(2), q1.stride(3), 123 | k1.stride(0), k1.stride(1), k1.stride(2), k1.stride(3), 124 | q2.stride(0), q2.stride(1), q2.stride(2), q2.stride(3), 125 | k2.stride(0), k2.stride(1), k2.stride(2), k2.stride(3), 126 | v.stride(0), v.stride(1), v.stride(2), v.stride(3), 127 | do.stride(0), do.stride(1), do.stride(2), do.stride(3), 128 | dk1.stride(0), dk1.stride(1), dk1.stride(2), dk1.stride(3), 129 | dk2.stride(0), dk2.stride(1), dk2.stride(2), dk2.stride(3), 130 | dv.stride(0), dv.stride(1), dv.stride(2), dv.stride(3), 131 | B, H, M, N, P_SEQ, 132 | w=w, 133 | BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, 134 | BLOCK_N=BLOCK_N, 135 | CAUSAL=causal, 136 | DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n, 137 | num_stages=num_stages, 138 | num_warps=num_warps, 139 | ) 140 | 141 | dq1 = torch.zeros_like(q1) 142 | dq2 = torch.zeros_like(q2) 143 | grid = 
(triton.cdiv(M, BLOCK_M), H, B) 144 | _bwd_q_kernel[grid]( 145 | q1, k1, q2, k2, v, sm_scale, do, 146 | dq1, dq2, 147 | L, delta, 148 | q1.stride(0), q1.stride(1), q1.stride(2), q1.stride(3), 149 | k1.stride(0), k1.stride(1), k1.stride(2), k1.stride(3), 150 | q2.stride(0), q2.stride(1), q2.stride(2), q2.stride(3), 151 | k2.stride(0), k2.stride(1), k2.stride(2), k2.stride(3), 152 | v.stride(0), v.stride(1), v.stride(2), v.stride(3), 153 | do.stride(0), do.stride(1), do.stride(2), do.stride(3), 154 | dq1.stride(0), dq1.stride(1), dq1.stride(2), dq1.stride(3), 155 | dq2.stride(0), dq2.stride(1), dq2.stride(2), dq2.stride(3), 156 | B, H, M, N, P_SEQ, 157 | w=w, 158 | BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, 159 | BLOCK_N=BLOCK_N, 160 | CAUSAL=causal, LARGER_M=larger_m, 161 | DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n, 162 | num_stages=num_stages, 163 | num_warps=num_warps, 164 | ) 165 | 166 | return dq1, dk1, dq2, dk2, dv, None, None, None 167 | 168 | 169 | def attention(q1, k1, q2, k2, v, dist_threshold, causal=False, sm_scale=None): 170 | """ 171 | PiecewiseAttention 172 | 173 | Piecewise deviates from standard scaled dot product attention in that takes 174 | as inputs two q's and two k's as inputs. The attention score is dot product 175 | of (q1, k1) or (q2, k2) depending on whether the distance between q & k 176 | exceeds a threshold. 177 | 178 | Arguments: 179 | q1(torch.Tensor): The first queries. The shape is (batch_size, nheads, seqlen_q, headdim). 180 | k1(torch.Tensor): The first keys. The shape is (batch_size, nheads, seqlen_k, headdim). 181 | q2(torch.Tensor): The second queries. The shape is (batch_size, nheads, seqlen_q, headdim). 182 | k2(torch.Tensor): The second keys. The shape is (batch_size, nheads, seqlen_k, headdim). 183 | v(torch.Tensor): The values. The shape is (batch_size, nheads, seqlen_k, headdim). 184 | dist_threshold(int): The threshold of distance between q and k. When the distance is not greater than w, the attention score is dot(q1, k1); otherwise dot(q2, k2). 185 | causal(bool): Whether causal masking is applied to attention scores before applying softmax. 186 | sm_scale(float): The scaling of attention scores before applying softmax. 187 | 188 | Returns: 189 | out: (torch.Tensor): The output. The shape is (batch_size, nheads, seqlen_q, headdim). 190 | """ 191 | return PiecewiseAttention.apply(q1, k1, q2, k2, v, dist_threshold, causal, sm_scale) 192 | 193 | # --------------------------- Forward --------------------------- 194 | def get_fwd_config(B, H, M, N, D, causal): 195 | # A100 196 | if torch.cuda.get_device_capability() == (8, 0): 197 | if not causal: 198 | if D <= 64: 199 | BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 3, 4 200 | else: 201 | BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 3, 8 202 | else: 203 | if D <= 64: 204 | BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 3, 4 205 | else: 206 | BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 4, 8 207 | # RTX-3090, ... 
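    # (compute capability (8, 6); these cards have less shared memory per SM than the
    # A100, which is presumably why smaller tiles / fewer pipeline stages are chosen below)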
208 | elif torch.cuda.get_device_capability() == (8, 6): 209 | if not causal: 210 | if D <= 64: 211 | BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 32, 3, 4 212 | else: 213 | BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 3, 8 214 | else: 215 | if D <= 64: 216 | BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 32, 3, 4 217 | else: 218 | BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 2, 8 219 | else: 220 | BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 32, 1, 4 221 | return BLOCK_M, BLOCK_N, num_stages, num_warps 222 | 223 | @triton.jit 224 | def _fwd_kernel( 225 | Q1, K1, Q2, K2, V, sm_scale, 226 | L, 227 | O, 228 | stride_q1z, stride_q1h, stride_q1m, stride_q1k, 229 | stride_k1z, stride_k1h, stride_k1n, stride_k1k, 230 | stride_q2z, stride_q2h, stride_q2m, stride_q2k, 231 | stride_k2z, stride_k2h, stride_k2n, stride_k2k, 232 | stride_vz, stride_vh, stride_vn, stride_vk, 233 | stride_oz, stride_oh, stride_om, stride_ok, 234 | Z, H, M, N, P_SEQ, 235 | w: tl.constexpr, 236 | BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, 237 | BLOCK_N: tl.constexpr, 238 | IS_CAUSAL: tl.constexpr, LARGER_M: tl.constexpr, 239 | DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr, 240 | ): 241 | input_dtype = Q1.dtype.element_ty 242 | # -- grid id -- 243 | start_m = tl.program_id(0) 244 | off_h = tl.program_id(1) 245 | off_z = tl.program_id(2) 246 | 247 | # scale sm_scale by log_2(e) and use 248 | # 2^x instead of exp in the loop because CSE and LICM 249 | # don't work as expected with `exp` in the loop 250 | log2e: tl.constexpr = 1.4426950408889634 251 | qk_scale = sm_scale * log2e 252 | 253 | # offset pointers for (batch, head) 254 | Q1 += off_z * stride_q1z + off_h * stride_q1h 255 | Q2 += off_z * stride_q2z + off_h * stride_q2h 256 | K1 += off_z * stride_k1z + off_h * stride_k1h 257 | K2 += off_z * stride_k2z + off_h * stride_k2h 258 | V += off_z * stride_vz + off_h * stride_vh 259 | O += off_z * stride_oz + off_h * stride_oh 260 | L += (off_z * H + off_h) * M # L: shape(B, H, N_CTX), C-contiguous 261 | 262 | offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) 263 | offs_n_base = tl.arange(0, BLOCK_N) 264 | offs_n_init = offs_n_base 265 | offs_k = tl.arange(0, BLOCK_DMODEL) 266 | 267 | # initialize pointers to v alue-like data 268 | q1_ptrs = Q1 + (offs_m[:, None] * stride_q1m + offs_k[None, :] * stride_q1k) # (BLOCK_M, BLOCK_DMODEL) 269 | q2_ptrs = Q2 + (offs_m[:, None] * stride_q2m + offs_k[None, :] * stride_q2k) # (BLOCK_M, BLOCK_DMODEL) 270 | k1_ptrs = K1 + (offs_n_init[:, None] * stride_k1n + offs_k[None, :] * stride_k1k) # (BLOCK_N, BLOCK_DMODEL) 271 | k2_ptrs = K2 + (offs_n_init[:, None] * stride_k2n + offs_k[None, :] * stride_k2k) # (BLOCK_N, BLOCK_DMODEL) 272 | v_ptrs = V + (offs_n_init[:, None] * stride_vn + offs_k[None, :] * stride_vk) # (BLOCK_N, BLOCK_DMODEL) 273 | o_ptrs = O + (offs_m[:, None] * stride_om + offs_k[None, :] * stride_ok) # (BLOCK_M, BLOCK_DMODEL) 274 | l_ptrs = L + offs_m 275 | 276 | # initialize pointer to m and l, fp32 for accumulators 277 | m_i = tl.full([BLOCK_M], value=-float("inf"), dtype=tl.float32) 278 | l_i = tl.zeros([BLOCK_M], dtype=tl.float32) 279 | acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) 280 | 281 | # load q: it will stay in SRAM throughout 282 | if DIVISIBLE_M: 283 | q1 = tl.load(q1_ptrs) 284 | q2 = tl.load(q2_ptrs) 285 | else: 286 | mask_m = offs_m < M 287 | q1 = tl.load(q1_ptrs, mask=mask_m[:, None]) 288 | q2 = tl.load(q2_ptrs, mask=mask_m[:, None]) 289 | 290 | # Dot I trick: it converts q1, q2 into mma layout and saves shared memory 291 
| # better way to generate a eye matrix. avoid casting from bool 292 | I = tl.where(offs_k[:, None] == offs_k, 293 | tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 1.0, dtype=input_dtype), 294 | tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 0.0, dtype=input_dtype)) 295 | q1 = tl.dot(q1, I).to(input_dtype) 296 | q2 = tl.dot(q2, I).to(input_dtype) 297 | 298 | # loop over k, v and update accumulator 299 | # see note "Loop-Bound-For-N" 300 | if IS_CAUSAL: 301 | hi = tl.minimum(N, P_SEQ + (start_m + 1) * BLOCK_M) 302 | if LARGER_M: 303 | hi = tl.maximum(0, hi) 304 | else: 305 | hi = N 306 | 307 | for start_n in range(0, hi, BLOCK_N): 308 | # -- offsets & masking -- 309 | start_n = tl.multiple_of(start_n, BLOCK_N) 310 | offs_n = start_n + offs_n_base 311 | piecewise_mask = (P_SEQ + offs_m[:, None]) >= (offs_n[None, :] + w) 312 | 313 | # -- load k, v -- 314 | if DIVISIBLE_N: 315 | k1 = tl.load(k1_ptrs) 316 | k2 = tl.load(k2_ptrs) 317 | v = tl.load(v_ptrs) 318 | else: 319 | mask_n = offs_n < N 320 | k1 = tl.load(k1_ptrs, mask=mask_n[:, None]) 321 | k2 = tl.load(k2_ptrs, mask=mask_n[:, None]) 322 | v = tl.load(v_ptrs, mask=mask_n[:, None]) 323 | 324 | # -- compute s = qk --- 325 | s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) 326 | 327 | # TODO: more careful masking 328 | s += tl.where(piecewise_mask, 329 | tl.dot(q2, tl.trans(k2)), 330 | tl.dot(q1, tl.trans(k1))) 331 | if not DIVISIBLE_N: 332 | s = tl.where(mask_n, s, float("-inf")) 333 | if IS_CAUSAL: 334 | causal_mask = (P_SEQ + offs_m[:, None]) >= offs_n[None, :] 335 | s = tl.where(causal_mask, s, float("-inf")) 336 | 337 | # -- compute scaling constant --- 338 | # loop l2r, so no extra handling of inf is needed 339 | m_i_new = tl.maximum(m_i, tl.max(s, 1)) 340 | alpha = tl.math.exp2((m_i - m_i_new) * qk_scale) 341 | p = tl.math.exp2(s * qk_scale - m_i_new[:, None] * qk_scale) 342 | 343 | # -- scale and update acc -- 344 | acc *= alpha[:, None] 345 | acc += tl.dot(p.to(input_dtype), v) 346 | 347 | # -- update m_i and l_i -- 348 | l_i = l_i * alpha + tl.sum(p, 1) 349 | m_i = m_i_new 350 | 351 | # update pointers 352 | k1_ptrs += BLOCK_N * stride_k1n 353 | k2_ptrs += BLOCK_N * stride_k2n 354 | v_ptrs += BLOCK_N * stride_vn 355 | 356 | # write back l & o 357 | if IS_CAUSAL and LARGER_M: 358 | is_empty_line = (offs_m + P_SEQ) < 0 359 | acc = tl.where(is_empty_line[:, None], 0.0, acc * (1.0 / l_i[:, None])) 360 | l_i = tl.where(is_empty_line, float("-inf"), m_i * sm_scale + tl.log(l_i)) 361 | else: 362 | acc = acc * (1.0 / l_i[:, None]) 363 | l_i = m_i * sm_scale + tl.log(l_i) 364 | 365 | if DIVISIBLE_M: 366 | tl.store(l_ptrs, l_i) 367 | tl.store(o_ptrs, acc.to(input_dtype)) 368 | else: 369 | tl.store(l_ptrs, l_i, mask=mask_m) 370 | tl.store(o_ptrs, acc.to(input_dtype), mask=mask_m[:, None]) 371 | 372 | # --------------------------- Backward --------------------------- 373 | def get_bwd_config(B, H, M, N, D, causal): 374 | # A100 375 | if torch.cuda.get_device_capability() == (8, 0): 376 | if not causal: 377 | if D <= 64: 378 | BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 4 379 | else: 380 | BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 2, 8 381 | else: 382 | if D <= 64: 383 | BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 3, 4 384 | else: 385 | BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 64, 2, 4 386 | 387 | # BLOCK_M = 64 if D<=64 else 128 388 | # BLOCK_N = 64 389 | # num_stages = 1 if D<=64 else (2 if not causal else 1) 390 | # num_warps = 4 if D <=64 else 8 391 | # RTX-3090, ... 
392 | elif torch.cuda.get_device_capability() == (8, 6): 393 | if not causal: 394 | if D <= 64: 395 | BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 32, 2, 4 396 | else: 397 | BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 64, 2, 8 398 | else: 399 | if D <= 64: 400 | BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 32, 2, 4 401 | else: 402 | BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 64, 2, 8 403 | else: 404 | BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 32, 1, 4 405 | return BLOCK_M, BLOCK_N, num_stages, num_warps 406 | 407 | @triton.jit 408 | def _bwd_preprocess( 409 | Out, DO, 410 | Delta, 411 | stride_oz, stride_oh, stride_om, stride_ok, 412 | stride_doz, stride_doh, stride_dom, stride_dok, 413 | stride_dz, stride_dh, stride_dm, 414 | M, 415 | BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr, 416 | DIVISIBLE_M: tl.constexpr, 417 | ): 418 | off_h = tl.program_id(1) 419 | off_z = tl.program_id(2) 420 | Out += off_z * stride_oz + off_h * stride_oh 421 | DO += off_z * stride_doz + off_h * stride_doh 422 | Delta += off_z * stride_dz + off_h * stride_dh 423 | 424 | # compute (Out * Dout).sum() for vector interpretation 425 | off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) 426 | off_n = tl.arange(0, D_HEAD) 427 | 428 | # load 429 | o_ptrs = Out + off_m[:, None] * stride_om + off_n[None, :] * stride_ok 430 | do_ptrs = DO + off_m[:, None] * stride_dom + off_n[None, :] * stride_dok 431 | 432 | if DIVISIBLE_M: 433 | o = tl.load(o_ptrs).to(tl.float32) 434 | do = tl.load(do_ptrs).to(tl.float32) 435 | else: 436 | mask_m = off_m < M 437 | o = tl.load(o_ptrs, mask=mask_m[:, None]).to(tl.float32) 438 | do = tl.load(do_ptrs, mask=mask_m[:, None]).to(tl.float32) 439 | 440 | # compute 441 | delta = tl.sum(o * do, axis=1) 442 | # write-back 443 | d_ptrs = Delta + off_m * stride_dm 444 | if DIVISIBLE_M: 445 | tl.store(d_ptrs, delta) 446 | else: 447 | tl.store(d_ptrs, delta, mask=mask_m) 448 | 449 | 450 | @triton.jit 451 | def _bwd_kv_kernel( 452 | Q1, K1, Q2, K2, V, sm_scale, DO, 453 | DK1, DK2, DV, 454 | L, 455 | D, 456 | stride_q1z, stride_q1h, stride_q1m, stride_q1k, 457 | stride_k1z, stride_k1h, stride_k1n, stride_k1k, 458 | stride_q2z, stride_q2h, stride_q2m, stride_q2k, 459 | stride_k2z, stride_k2h, stride_k2n, stride_k2k, 460 | stride_vz, stride_vh, stride_vn, stride_vk, 461 | stride_doz, stride_doh, stride_dom, stride_dok, 462 | stride_dk1z, stride_dk1h, stride_dk1n, stride_dk1k, 463 | stride_dk2z, stride_dk2h, stride_dk2n, stride_dk2k, 464 | stride_dvz, stride_dvh, stride_dvn, stride_dvk, 465 | Z, H, M, N, P_SEQ, 466 | w: tl.constexpr, 467 | BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, 468 | BLOCK_N: tl.constexpr, 469 | CAUSAL: tl.constexpr, 470 | DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr, 471 | ): 472 | input_dtype = Q1.dtype.element_ty 473 | # -- grid id -- 474 | start_n = tl.program_id(0) 475 | off_h = tl.program_id(1) 476 | off_z = tl.program_id(2) 477 | 478 | log2e: tl.constexpr = 1.4426950408889634 479 | qk_scale = sm_scale * log2e 480 | 481 | # offset pointers for (batch, head) 482 | Q1 += off_z * stride_q1z + off_h * stride_q1h 483 | Q2 += off_z * stride_q2z + off_h * stride_q2h 484 | K1 += off_z * stride_k1z + off_h * stride_k1h 485 | K2 += off_z * stride_k2z + off_h * stride_k2h 486 | V += off_z * stride_vz + off_h * stride_vh 487 | DO += off_z * stride_doz + off_h * stride_doh 488 | D += (off_z * H + off_h) * M 489 | L += (off_z * H + off_h) * M 490 | 491 | # offset pointers for batch/head 492 | DK1 += off_z * stride_dk1z + off_h * stride_dk1h 493 | DK2 += off_z * 
stride_dk2z + off_h * stride_dk2h 494 | DV += off_z * stride_dvz + off_h * stride_dvh 495 | 496 | if CAUSAL: 497 | lo = tl.maximum(start_n * BLOCK_N - P_SEQ, 0) 498 | lo = (lo // BLOCK_M) * BLOCK_M 499 | else: 500 | lo = 0 501 | 502 | offs_m_init = lo + tl.arange(0, BLOCK_M) 503 | offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) 504 | offs_m_base = tl.arange(0, BLOCK_M) 505 | offs_k = tl.arange(0, BLOCK_DMODEL) 506 | 507 | 508 | # initialize pointers to value-like data 509 | q1_ptrs = Q1 + (offs_m_init[:, None] * stride_q1m + offs_k[None, :] * stride_q1k) # (BLOCK_M, BLOCK_DMODEL) 510 | q2_ptrs = Q2 + (offs_m_init[:, None] * stride_q2m + offs_k[None, :] * stride_q2k) # (BLOCK_M, BLOCK_DMODEL) 511 | k1_ptrs = K1 + (offs_k[:, None] * stride_k1k + offs_n[None, :] * stride_k1n) # (BLOCK_DMODEL, BLOCK_N) 512 | k2_ptrs = K2 + (offs_k[:, None] * stride_k2k + offs_n[None, :] * stride_k2n) # (BLOCK_DMODEL, BLOCK_N) 513 | v_ptrs = V + (offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk) # (BLOCK_N, BLOCK_DMODEL) 514 | do_ptrs = DO + (offs_m_init[:, None] * stride_dom + offs_k[None, :] * stride_dok) # (BLOCK_M, BLOCK_DMODEL) 515 | 516 | dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_k[None, :] * stride_dvk) # (BLOCK_N, BLOCK_DMODEL) 517 | dk1_ptrs = DK1 + (offs_n[:, None] * stride_dk1n + offs_k[None, :] * stride_dk1k) # (BLOCK_N, BLOCK_DMODEL) 518 | dk2_ptrs = DK2 + (offs_n[:, None] * stride_dk2n + offs_k[None, :] * stride_dk2k) # (BLOCK_N, BLOCK_DMODEL) 519 | 520 | # k and v stay in SRAM throughout 521 | if DIVISIBLE_N: 522 | k1 = tl.load(k1_ptrs) 523 | k2 = tl.load(k2_ptrs) 524 | v = tl.load(v_ptrs) 525 | else: 526 | mask_n = offs_n < N 527 | k1 = tl.load(k1_ptrs, mask=mask_n[None, :]) 528 | k2 = tl.load(k2_ptrs, mask=mask_n[None, :]) 529 | v = tl.load(v_ptrs, mask=mask_n[:, None]) 530 | 531 | # initialize dk amd dv 532 | dk1 = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) 533 | dk2 = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) 534 | dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) 535 | 536 | # loop over a column 537 | for start_m in range(lo, M, BLOCK_M): 538 | start_m = tl.multiple_of(start_m, BLOCK_M) 539 | offs_m = start_m + offs_m_base 540 | 541 | # load q1, k1, q2, k2, v, do on-chip 542 | if DIVISIBLE_M: 543 | q1 = tl.load(q1_ptrs) 544 | q2 = tl.load(q2_ptrs) 545 | do = tl.load(do_ptrs) # (BLOCK_M, BLOCK_DMODEL) 546 | delta = tl.load(D + offs_m) 547 | l = tl.load(L + offs_m) 548 | else: 549 | mask_m = offs_m < M 550 | q1 = tl.load(q1_ptrs, mask=mask_m[:, None]) 551 | q2 = tl.load(q2_ptrs, mask=mask_m[:, None]) 552 | do = tl.load(do_ptrs, mask=mask_m[:, None]) # (BLOCK_M, BLOCK_DMODEL) 553 | delta = tl.load(D + offs_m, mask=mask_m) 554 | l = tl.load(L + offs_m, mask=mask_m) 555 | 556 | # recompute p = softmax(qk, dim=-1).T 557 | piecewise_mask = (P_SEQ + offs_m[:, None]) >= (offs_n[None, :] + w) # (BLOCK_M, BLOCK_N) 558 | s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) 559 | s += tl.where(piecewise_mask, 560 | tl.dot(q2, k2), 561 | tl.dot(q1, k1)) 562 | 563 | # NOTE: since softmax in backward is pointwise, the normalizer has been saved in fwd) 564 | # So masking on s is not needed. 
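        # Concretely: the forward pass stored L = logsumexp(sm_scale * s) per row, so
        # p = exp2(s * qk_scale - l * log2e) = exp(sm_scale * s - L) recomputed below is
        # already the normalized softmax probability; invalid rows and non-causal
        # entries are zeroed out on p / ds afterwards instead.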
565 | # if CAUSAL: 566 | # s = tl.where(causal_mask & valid_mask, s, float("-inf")) 567 | # else: 568 | # s = tl.where(valid_mask, s, float("-inf")) 569 | 570 | # -- recompute p --- 571 | # l = tl.load(L + offs_m, mask=mask_m) 572 | p = tl.math.exp2(s * qk_scale - l[:, None] * log2e) # (BLOCK_M, BLOCK_N) 573 | if not DIVISIBLE_M: 574 | valid_mask = mask_m[:, None] # & mask_n 575 | p = tl.where(valid_mask, p, 0.0) 576 | if CAUSAL: 577 | causal_mask = (P_SEQ + offs_m[:, None]) >= (offs_n[None, :]) # (BLOCK_M, BLOCK_N) 578 | p = tl.where(causal_mask, p, 0.0) 579 | 580 | 581 | # compute dv = dot(p, do) 582 | # do = tl.load(do_ptrs, mask=mask_m[:, None]) # (BLOCK_M, BLOCK_DMODEL) 583 | dv += tl.dot(tl.trans(p.to(do.dtype)), do) # (BLOCK_N, BLOCK_DMODEL) 584 | 585 | # compute dp = dot(v, do) 586 | # delta = tl.load(D + offs_m, mask=mask_m) 587 | dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) 588 | dp += tl.dot(do.to(input_dtype), tl.trans(v)) 589 | # no need to mask dp 590 | # if CAUSAL: 591 | # dp = tl.where(causal_mask & valid_mask, dp, 0.0) 592 | # else: 593 | # dp = tl.where(valid_mask, dp, 0.0) 594 | 595 | # compute ds = p * (dp - delta[:, None]) 596 | # move scale out to dk at last 597 | ds = p * (dp - delta[:, None]) # (BLOCK_M, BLOCK_N) 598 | 599 | # mask ds To ensure no small values 600 | if not DIVISIBLE_M: 601 | ds = tl.where(valid_mask, ds, 0.0) 602 | if CAUSAL: 603 | ds = tl.where(causal_mask, ds, 0.0) 604 | 605 | ds2 = tl.where(piecewise_mask, ds, 0.0).to(input_dtype) 606 | ds1 = tl.where(piecewise_mask, 0.0, ds).to(input_dtype) 607 | 608 | # compute dk = dot(ds.T, q) masking 609 | dk1 += tl.dot(tl.trans(ds1), q1) 610 | dk2 += tl.dot(tl.trans(ds2), q2) 611 | 612 | # increment pointers 613 | q1_ptrs += BLOCK_M * stride_q1m 614 | q2_ptrs += BLOCK_M * stride_q2m 615 | do_ptrs += BLOCK_M * stride_dom 616 | 617 | dk1 *= sm_scale 618 | dk2 *= sm_scale 619 | 620 | if DIVISIBLE_N: 621 | tl.store(dk1_ptrs, dk1.to(input_dtype)) # (BLOCK_N, BLOCK_DMODEL) 622 | tl.store(dk2_ptrs, dk2.to(input_dtype)) # (BLOCK_N, BLOCK_DMODEL) 623 | tl.store(dv_ptrs, dv.to(input_dtype)) # (BLOCK_N, BLOCK_DMODEL) 624 | else: 625 | tl.store(dk1_ptrs, dk1.to(input_dtype), mask=mask_n[:, None]) # (BLOCK_N, BLOCK_DMODEL) 626 | tl.store(dk2_ptrs, dk2.to(input_dtype), mask=mask_n[:, None]) # (BLOCK_N, BLOCK_DMODEL) 627 | tl.store(dv_ptrs, dv.to(input_dtype), mask=mask_n[:, None]) # (BLOCK_N, BLOCK_DMODEL) 628 | 629 | 630 | @triton.jit 631 | def _bwd_q_kernel( 632 | Q1, K1, Q2, K2, V, sm_scale, DO, 633 | DQ1, DQ2, 634 | L, 635 | D, 636 | stride_q1z, stride_q1h, stride_q1m, stride_q1k, 637 | stride_k1z, stride_k1h, stride_k1n, stride_k1k, 638 | stride_q2z, stride_q2h, stride_q2m, stride_q2k, 639 | stride_k2z, stride_k2h, stride_k2n, stride_k2k, 640 | stride_vz, stride_vh, stride_vn, stride_vk, 641 | stride_doz, stride_doh, stride_dom, stride_dok, 642 | stride_dq1z, stride_dq1h, stride_dq1m, stride_dq1k, 643 | stride_dq2z, stride_dq2h, stride_dq2m, stride_dq2k, 644 | Z, H, M, N, P_SEQ, 645 | w: tl.constexpr, 646 | BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, 647 | BLOCK_N: tl.constexpr, 648 | CAUSAL: tl.constexpr, LARGER_M: tl.constexpr, 649 | DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr, 650 | ): 651 | input_dtype = Q1.dtype.element_ty 652 | # -- grid id -- 653 | start_m = tl.program_id(0) 654 | off_h = tl.program_id(1) 655 | off_z = tl.program_id(2) 656 | 657 | # scale sm_scale by log_2(e) and use 658 | # 2^x instead of exp in the loop because CSE and LICM 659 | # don't work as expected with `exp` in 
the loop 660 | log2e: tl.constexpr = 1.4426950408889634 661 | qk_scale = sm_scale * log2e 662 | 663 | # offset pointers for (batch, head) 664 | Q1 += off_z * stride_q1z + off_h * stride_q1h 665 | Q2 += off_z * stride_q2z + off_h * stride_q2h 666 | K1 += off_z * stride_k1z + off_h * stride_k1h 667 | K2 += off_z * stride_k2z + off_h * stride_k2h 668 | V += off_z * stride_vz + off_h * stride_vh 669 | DO += off_z * stride_doz + off_h * stride_doh 670 | D += (off_z * H + off_h) * M 671 | L += (off_z * H + off_h) * M 672 | 673 | # offset pointers for batch/head 674 | DQ1 += off_z * stride_dq1z + off_h * stride_dq1h 675 | DQ2 += off_z * stride_dq2z + off_h * stride_dq2h 676 | 677 | offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) 678 | offs_n_base = tl.arange(0, BLOCK_N) 679 | offs_n_init = offs_n_base 680 | offs_k = tl.arange(0, BLOCK_DMODEL) 681 | 682 | # initialize pointers to value-like data 683 | q1_ptrs = Q1 + (offs_m[:, None] * stride_q1m + offs_k[None, :] * stride_q1k) # (BLOCK_M, BLOCK_DMODEL) 684 | q2_ptrs = Q2 + (offs_m[:, None] * stride_q2m + offs_k[None, :] * stride_q2k) # (BLOCK_M, BLOCK_DMODEL) 685 | k1_ptrs = K1 + (offs_n_init[:, None] * stride_k1n + offs_k[None, :] * stride_k1k) # (BLOCK_N, BLOCK_DMODEL) 686 | k2_ptrs = K2 + (offs_n_init[:, None] * stride_k2n + offs_k[None, :] * stride_k2k) # (BLOCK_N, BLOCK_DMODEL) 687 | v_ptrs = V + (offs_n_init[:, None] * stride_vn + offs_k[None, :] * stride_vk) # (BLOCK_N, BLOCK_DMODEL) 688 | 689 | dq1_ptrs = DQ1 + (offs_m[:, None] * stride_dq1m + offs_k[None, :] * stride_dq1k) # (BLOCK_M, BLOCK_DMODEL) 690 | dq2_ptrs = DQ2 + (offs_m[:, None] * stride_dq2m + offs_k[None, :] * stride_dq2k) # (BLOCK_M, BLOCK_DMODEL) 691 | do_ptrs = DO + (offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok) # (BLOCK_M, BLOCK_DMODEL) 692 | 693 | # pointer to row-wise quantities in value-like data 694 | d_ptrs = D + offs_m 695 | l_ptrs = L + offs_m 696 | 697 | # load q: it will stay in SRAM throughout 698 | if DIVISIBLE_M: 699 | q1 = tl.load(q1_ptrs) 700 | q2 = tl.load(q2_ptrs) 701 | do = tl.load(do_ptrs) 702 | delta = tl.load(d_ptrs) 703 | l = tl.load(l_ptrs) 704 | else: 705 | mask_m = offs_m < M 706 | q1 = tl.load(q1_ptrs, mask=mask_m[:, None]) 707 | q2 = tl.load(q2_ptrs, mask=mask_m[:, None]) 708 | do = tl.load(do_ptrs, mask=mask_m[:, None]) 709 | delta = tl.load(d_ptrs, mask=mask_m) 710 | l = tl.load(l_ptrs, mask=mask_m) 711 | 712 | # initialize dq 713 | dq1 = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) 714 | dq2 = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) 715 | 716 | # loop over k, v and update accumulator 717 | # see note "Loop-Bound-For-N" 718 | if CAUSAL: 719 | hi = tl.minimum(N, P_SEQ + (start_m + 1) * BLOCK_M) 720 | if LARGER_M: 721 | hi = tl.maximum(0, hi) 722 | else: 723 | hi = N 724 | 725 | # loop over a row 726 | for start_n in range(0, hi, BLOCK_N): 727 | offs_n = start_n + offs_n_base 728 | 729 | # load k1, k2, v on chip 730 | if DIVISIBLE_N: 731 | v = tl.load(v_ptrs) 732 | k1 = tl.load(k1_ptrs) 733 | k2 = tl.load(k2_ptrs) 734 | else: 735 | mask_n = offs_n < N 736 | v = tl.load(v_ptrs, mask=mask_n[:, None]) 737 | k1 = tl.load(k1_ptrs, mask=mask_n[:, None]) 738 | k2 = tl.load(k2_ptrs, mask=mask_n[:, None]) 739 | 740 | # recompute p = softmax(qk * sm_scale, dim=-1) 741 | piecewise_mask = (P_SEQ + offs_m[:, None]) >= (offs_n[None, :] + w) # (BLOCK_M, BLOCK_N) 742 | s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) 743 | s += tl.where(piecewise_mask, 744 | tl.dot(q2, tl.trans(k2)), 745 | tl.dot(q1, tl.trans(k1))) 746 | # 
NOTE: since softmax in backward is pointwise, the normalizer has been saved in fwd) 747 | # So masking on s is not needed. 748 | # if CAUSAL: 749 | # s = tl.where(causal_mask & valid_mask, s, float("-inf")) 750 | # else: 751 | # s = tl.where(valid_mask, s, float("-inf")) 752 | p = tl.math.exp2(s * qk_scale - l[:, None] * log2e) # (BLOCK_M, BLOCK_N) 753 | 754 | # compute dp = dot(v, do) 755 | dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) 756 | dp += tl.dot(do.to(input_dtype), tl.trans(v)) 757 | # no need to mask dp 758 | # if CAUSAL: 759 | # dp = tl.where(causal_mask & valid_mask, dp, 0.0) 760 | # else: 761 | # dp = tl.where(valid_mask, dp, 0.0) 762 | 763 | # compute ds = p * (dp - delta[:, None]) 764 | # move scale out to dq at last 765 | ds = p * (dp - delta[:, None]) # (BLOCK_M, BLOCK_N) 766 | 767 | # mask ds to ensure no small values 768 | if not DIVISIBLE_N: 769 | ds = tl.where(mask_n, ds, 0.0) 770 | if CAUSAL: 771 | causal_mask = (P_SEQ + offs_m[:, None]) >= (offs_n[None, :]) # (BLOCK_M, BLOCK_N) 772 | ds = tl.where(causal_mask, ds, 0.0) 773 | 774 | ds2 = tl.where(piecewise_mask, ds, 0.0).to(input_dtype) 775 | ds1 = tl.where(piecewise_mask, 0.0, ds).to(input_dtype) 776 | 777 | dq1 += tl.dot(ds1, k1) 778 | dq2 += tl.dot(ds2, k2) 779 | 780 | # increment pointers 781 | k1_ptrs += BLOCK_N * stride_k1n 782 | k2_ptrs += BLOCK_N * stride_k2n 783 | v_ptrs += BLOCK_N * stride_vn 784 | 785 | dq1 *= sm_scale 786 | dq2 *= sm_scale 787 | if DIVISIBLE_M: 788 | tl.store(dq1_ptrs, dq1.to(input_dtype)) 789 | tl.store(dq2_ptrs, dq2.to(input_dtype)) 790 | else: 791 | tl.store(dq1_ptrs, dq1.to(input_dtype), mask=mask_m[:, None]) 792 | tl.store(dq2_ptrs, dq2.to(input_dtype), mask=mask_m[:, None]) 793 | -------------------------------------------------------------------------------- /src/flag_attn/flash.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import triton 4 | import triton.language as tl 5 | from flag_attn.total import _total_attention_kernel 6 | from flag_attn.split_kv import _fwd_split_kv_kernel, _fwd_combine_kv_splits, num_splits_herustic 7 | from flag_attn.split_kv import get_fwd_config as get_fwd_config_kv_split 8 | 9 | from .dropout import philox_cuda_seed_offset 10 | 11 | __all__ = ["attention"] 12 | 13 | 14 | def maybe_contiguous(x): 15 | # only when the inner most dimension is contiguous can LDGSTS be used 16 | # so inner-dimension contiguity is enforced. 17 | return x.contiguous() if x.stride(-1) != 1 else x 18 | 19 | def rounded_multiple(a, b): 20 | return (a + b - 1) // b * b 21 | 22 | # --------------------------- public API --------------------------- 23 | class FlashAttention(torch.autograd.Function): 24 | @staticmethod 25 | def forward(ctx, q, k, v, causal, sm_scale, dropout_p, return_log_normalizer, return_total_attention, return_seed_offset): 26 | Dq, Dk, Dv = q.shape[-1], k.shape[-1], v.shape[-1] 27 | assert Dq == Dk == Dv, "feature size of q, k, v should be equal" 28 | assert Dk in {16, 32, 64, 128} 29 | 30 | B, H, M, D = q.shape 31 | N = k.shape[2] 32 | Hk, Hv = k.shape[1], v.shape[1] 33 | assert Hk == Hv, "num of heads in k and v should be equal" 34 | assert H % Hk == 0, "number of heads in q must be a multiple of that in k & v" 35 | num_groups = H // Hk 36 | 37 | P_SEQ = N - M 38 | larger_m = M > N 39 | 40 | if sm_scale is None: 41 | sm_scale = 1. 
/ math.sqrt(D) 42 | 43 | # contiguity 44 | q, k, v = maybe_contiguous(q), maybe_contiguous(k), maybe_contiguous(v) 45 | 46 | # to work around https://github.com/openai/triton/issues/2441 47 | device = torch.cuda.device_of(q) 48 | num_sms = torch.cuda.get_device_properties(device).multi_processor_count 49 | 50 | with torch.cuda.device(device): 51 | # Dropout preparation. 52 | is_dropout = dropout_p > 0 53 | if is_dropout: 54 | offset_increment = B * H * M * N 55 | seed, offset = philox_cuda_seed_offset(offset_increment) 56 | else: 57 | seed, offset = 0, 0 58 | 59 | config_for_split_kv = get_fwd_config_kv_split(B, H, M, N, D, causal) 60 | S = num_splits_herustic(B, H, M, N, config_for_split_kv[0], config_for_split_kv[1], num_sms, 128) 61 | split_kv: bool = S > 1 62 | # print(f"flag_attn choose {S} splits") 63 | 64 | if not split_kv: 65 | config = get_fwd_config(B, H, M, N, D, causal) 66 | BLOCK_M, BLOCK_N, num_stages, num_warps = config 67 | 68 | divisible_m = M % BLOCK_M == 0 69 | divisible_n = N % BLOCK_N == 0 70 | # consider using 3d grid to avoid div & rem 71 | grid = (triton.cdiv(M, BLOCK_M), H, B) 72 | o = torch.empty_like(q) 73 | L = torch.empty((B, H, M), device=q.device, dtype=torch.float32) 74 | _fwd_kernel[grid]( 75 | q, k, v, sm_scale, 76 | dropout_p, seed, offset, 77 | L, o, 78 | q.stride(0), q.stride(1), q.stride(2), q.stride(3), 79 | k.stride(0), k.stride(1), k.stride(2), k.stride(3), 80 | v.stride(0), v.stride(1), v.stride(2), v.stride(3), 81 | o.stride(0), o.stride(1), o.stride(2), o.stride(3), 82 | B, H, M, N, P_SEQ, num_groups, 83 | BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=D, 84 | IS_CAUSAL=causal, IS_DROPOUT=is_dropout, LARGER_M=larger_m, 85 | DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n, 86 | num_warps=num_warps, num_stages=num_stages, 87 | ) 88 | else: # split kv 89 | assert not is_dropout, "Cannot apply dropout with splitkv." 
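                # Split-KV (flash-decoding style) path: the K/V sequence is cut into S
                # chunks; each chunk writes a partial output and a partial log-normalizer
                # into multiple_o / multiple_l (allocated below), and the combine kernel
                # merges them. Assuming each partial output is already normalized within
                # its own split, the merge is in effect (illustration only, not executed):
                #   L_final = torch.logsumexp(multiple_l, dim=2)                  # (B, H, M)
                #   weights = torch.exp(multiple_l - L_final.unsqueeze(2))        # (B, H, S, M)
                #   o_final = (weights.unsqueeze(-1) * multiple_o.float()).sum(2) # (B, H, M, D)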
90 | BLOCK_M, BLOCK_N, num_stages, num_warps = config_for_split_kv 91 | 92 | divisible_m = M % BLOCK_M == 0 93 | divisible_n = N % BLOCK_N == 0 94 | # consider using 3d grid to avoid div & rem 95 | multiple_l = torch.empty((B, H, S, M), dtype=torch.float32, device="cuda") 96 | multiple_o = torch.empty((B, H, S, M, D), dtype=torch.float16, device="cuda") 97 | grid = (triton.cdiv(M, BLOCK_M), S, H * B) 98 | N_SPLIT_SIZE = triton.cdiv(triton.cdiv(N, BLOCK_N), S) * BLOCK_N 99 | _fwd_split_kv_kernel[grid]( 100 | q, k, v, sm_scale, 101 | multiple_l, multiple_o, 102 | q.stride(0), q.stride(1), q.stride(2), q.stride(3), 103 | k.stride(0), k.stride(1), k.stride(2), k.stride(3), 104 | v.stride(0), v.stride(1), v.stride(2), v.stride(3), 105 | multiple_o.stride(0), multiple_o.stride(1), multiple_o.stride(2), multiple_o.stride(3), multiple_o.stride(4), 106 | B, H, M, N, P_SEQ, N_SPLIT_SIZE, S, num_groups, 107 | BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N, 108 | IS_CAUSAL=causal, LARGER_M=larger_m, 109 | DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n, 110 | num_stages=num_stages, num_warps=num_warps, 111 | ) 112 | 113 | L = torch.empty((B, H, M), dtype=torch.float32, device="cuda") 114 | o = torch.empty_like(q) 115 | grid = (triton.cdiv(M, BLOCK_M), H, B) 116 | _fwd_combine_kv_splits[grid]( 117 | multiple_o, multiple_l, 118 | o, L, 119 | multiple_o.stride(0), multiple_o.stride(1), multiple_o.stride(2), multiple_o.stride(3), multiple_o.stride(4), 120 | o.stride(0), o.stride(1), o.stride(2), o.stride(3), 121 | B, H, M, S, 122 | BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, 123 | DIVISIBLE_M=divisible_m, 124 | num_stages=num_stages, num_warps=num_warps, 125 | ) 126 | 127 | # total attention 128 | if return_total_attention: 129 | tot_attn = torch.empty((B, H, N), device=q.device, dtype=torch.float32) 130 | grid = (triton.cdiv(N, BLOCK_N), H, B) 131 | _total_attention_kernel[grid]( 132 | q, k, L, tot_attn, sm_scale, 133 | q.stride(0), q.stride(1), q.stride(2), q.stride(3), 134 | k.stride(0), k.stride(1), k.stride(2), k.stride(3), 135 | B, H, M, N, P_SEQ, num_groups, 136 | BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N, 137 | CAUSAL=causal, 138 | DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n, 139 | num_stages=num_stages, num_warps=num_warps, 140 | ) 141 | 142 | # autograd context maintenance 143 | ctx.save_for_backward(q, k, v, o, L) 144 | ctx.sm_scale = sm_scale 145 | ctx.causal = causal 146 | ctx.dropout_p = dropout_p 147 | ctx.seed = seed 148 | ctx.offset = offset 149 | 150 | has_extra_return = True in (return_log_normalizer, return_total_attention, return_seed_offset) 151 | if has_extra_return: 152 | outs = ( 153 | o, 154 | L if return_log_normalizer else None, 155 | tot_attn if return_total_attention else None, 156 | seed if is_dropout and return_seed_offset else None, 157 | offset if is_dropout and return_seed_offset else None 158 | ) 159 | return outs 160 | return o 161 | 162 | @staticmethod 163 | def backward(ctx, do, *ignored): 164 | q, k, v, o, L = ctx.saved_tensors 165 | sm_scale = ctx.sm_scale 166 | causal = ctx.causal 167 | dropout_p = ctx.dropout_p 168 | is_dropout = ctx.dropout_p > 0 169 | seed = ctx.seed 170 | offset = ctx.offset 171 | 172 | B, H, M, D = q.shape 173 | N = k.shape[2] 174 | Hk = k.shape[1] 175 | num_groups = H // Hk 176 | P_SEQ = N - M 177 | larger_m = M > N 178 | 179 | if sm_scale is None: 180 | sm_scale = 1. 
/ math.sqrt(D) 181 | 182 | # to work around https://github.com/openai/triton/issues/2441 183 | device = torch.cuda.device_of(q) 184 | with torch.cuda.device(device): 185 | config = get_bwd_config(B, H, M, N, D, causal) 186 | BLOCK_M, BLOCK_N, num_stages, num_warps = config 187 | 188 | divisible_m = M % BLOCK_M == 0 189 | divisible_n = N % BLOCK_N == 0 190 | 191 | delta = torch.empty_like(L) 192 | grid = (triton.cdiv(M, BLOCK_M), H, B) 193 | _bwd_preprocess[grid]( 194 | o, do, 195 | delta, 196 | o.stride(0), o.stride(1), o.stride(2), o.stride(3), 197 | do.stride(0), do.stride(1), do.stride(2), do.stride(3), 198 | delta.stride(0), delta.stride(1), delta.stride(2), 199 | M, 200 | BLOCK_M=BLOCK_M, D_HEAD=D, 201 | DIVISIBLE_M=divisible_m, 202 | ) 203 | 204 | # NOTE that dk & dv always have the same number of heads as q, instead of q. 205 | dk = torch.empty((B, H, N, D), dtype=k.dtype, device=q.device) 206 | dv = torch.empty((B, H, N, D), dtype=v.dtype, device=q.device) 207 | grid = (triton.cdiv(N, BLOCK_N), H, B) 208 | _bwd_kv_kernel[grid]( 209 | q, k, v, sm_scale, do, 210 | dk, dv, 211 | L, delta, 212 | dropout_p, 213 | seed, 214 | offset, 215 | q.stride(0), q.stride(1), q.stride(2), q.stride(3), 216 | k.stride(0), k.stride(1), k.stride(2), k.stride(3), 217 | v.stride(0), v.stride(1), v.stride(2), v.stride(3), 218 | do.stride(0), do.stride(1), do.stride(2), do.stride(3), 219 | dk.stride(0), dk.stride(1), dk.stride(2), dk.stride(3), 220 | dv.stride(0), dv.stride(1), dv.stride(2), dv.stride(3), 221 | B, H, M, N, P_SEQ, 222 | num_groups, 223 | BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N, CAUSAL=causal, 224 | IS_DROPOUT=is_dropout, 225 | DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n, 226 | num_stages=num_stages, num_warps=num_warps, 227 | ) 228 | 229 | dq = torch.zeros_like(q) 230 | grid = (triton.cdiv(M, BLOCK_M), H, B) 231 | _bwd_q_kernel[grid]( 232 | q, k, v, sm_scale, do, 233 | dq, 234 | L, delta, 235 | dropout_p, 236 | seed, 237 | offset, 238 | q.stride(0), q.stride(1), q.stride(2), q.stride(3), 239 | k.stride(0), k.stride(1), k.stride(2), k.stride(3), 240 | v.stride(0), v.stride(1), v.stride(2), v.stride(3), 241 | do.stride(0), do.stride(1), do.stride(2), do.stride(3), 242 | dq.stride(0), dq.stride(1), dq.stride(2), dq.stride(3), 243 | B, H, M, N, P_SEQ, 244 | num_groups, 245 | BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N, 246 | CAUSAL=causal, IS_DROPOUT=is_dropout, LARGER_M=larger_m, 247 | DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n, 248 | num_stages=num_stages, num_warps = num_warps, 249 | ) 250 | dk = dk.reshape((B, Hk, num_groups, N, D)).sum(2) 251 | dv = dv.reshape((B, Hk, num_groups, N, D)).sum(2) 252 | return dq, dk, dv, None, None, None, None, None, None 253 | 254 | 255 | def attention(q, k, v, causal=False, sm_scale=None, dropout_p=0.0, 256 | return_log_normalizer=False, return_total_attention=False, return_seed_offset=False 257 | ): 258 | """ 259 | An implementation of FlashAttention v2(https://arxiv.org/abs/2307.08691). 260 | 261 | Arguments: 262 | q(torch.Tensor): The first queries. The shape is (batch_size, num_heads_q, seqlen_q, headdim). 263 | k(torch.Tensor): The first keys. The shape is (batch_size, num_heads_k, seqlen_k, headdim). 264 | v(torch.Tensor): The values. The shape is (batch_size, num_heads_k, seqlen_k, headdim). 265 | causal(bool): Whether causal masking is applied to attention scores before applying softmax. 266 | sm_scale(float): The scaling of attention scores before applying softmax. 267 | dropout_p(float): Dropout probability. 
268 | return_log_normalizer(bool): Whether to return the log normalizer of softmax inside attention. 269 | return_total_attention(bool): Whether to return the sum of attention along q's sequence dimendion. 270 | return_seed_offset(bool): Whether to return dropout seed and offset 271 | 272 | Returns: 273 | out(torch.Tensor): The output. The shape is (batch_size, num_heads_q, seqlen_q, headdim). 274 | 275 | If `return_log_normalizer` or `return_total_attention` or `return_seed_offset` is True, 276 | return the following results in addition. 277 | 278 | log_normalizer(torch.Tensor): The log normalizer. The shape is (batch_size, num_heads_q, seqlen_q). 279 | total_attention(torch.Tensor): The total attention. The shape is (batch_size, num_heads_q, seqlen_k). 280 | seed(int): The Philox seed used in dropout. 281 | offset(int): The starting Philox offset used in dropout. 282 | 283 | Notes: 284 | `num_heads_q` must be a multiple of `num_heads_k`. 285 | """ 286 | return FlashAttention.apply(q, k, v, causal, sm_scale, dropout_p, return_log_normalizer, return_total_attention, return_seed_offset) 287 | 288 | 289 | # --------------------------- Forward --------------------------- 290 | # NOTE: this function can be overwritten at runtime to use your custom config 291 | def get_fwd_config(B, H, M, N, D, causal): 292 | if torch.cuda.get_device_capability() == (8, 0): 293 | if not causal: 294 | if D <= 64: 295 | BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 3, 4 296 | else: 297 | if M <= 1024: 298 | BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 3, 4 299 | else: 300 | BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 128, 3, 8 301 | else: 302 | if D <= 64: 303 | BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 4, 4 304 | else: 305 | if M <= 1024: 306 | BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 2, 4 307 | else: 308 | BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 128, 3, 8 309 | elif torch.cuda.get_device_capability() == (8, 6): 310 | if not causal: 311 | if D <= 64: 312 | BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 3, 4 313 | else: 314 | BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 2, 4 315 | else: # causal 316 | if D <= 64: 317 | BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 3, 4 318 | else: 319 | BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 2, 4 320 | else: 321 | BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 32, 1, 4 322 | return (BLOCK_M, BLOCK_N, num_stages, num_warps) 323 | 324 | 325 | @triton.jit 326 | def _fwd_kernel( 327 | Q, K, V, sm_scale, 328 | dropout_p, 329 | seed, 330 | offset, 331 | L, O, 332 | stride_qz, stride_qh, stride_qm, stride_qk, 333 | stride_kz, stride_kh, stride_kn, stride_kk, 334 | stride_vz, stride_vh, stride_vn, stride_vk, 335 | stride_oz, stride_oh, stride_om, stride_ok, 336 | Z, H, M, N, P_SEQ, 337 | num_groups, 338 | BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, 339 | IS_CAUSAL: tl.constexpr, IS_DROPOUT: tl.constexpr, LARGER_M: tl.constexpr, 340 | DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr, 341 | ): 342 | input_dtype = Q.dtype.element_ty 343 | # -- grid id -- 344 | start_m = tl.program_id(0) 345 | off_h = tl.program_id(1) 346 | off_z = tl.program_id(2) 347 | 348 | # scale sm_scale by log_2(e) and use 349 | # 2^x instead of exp in the loop because CSE and LICM 350 | # don't work as expected with `exp` in the loop 351 | log2e: tl.constexpr = 1.4426950408889634 352 | qk_scale = sm_scale * log2e 353 | 354 | # offset pointers for (batch, head) 355 | off_hk = off_h // num_groups 356 | Q += off_z * 
stride_qz + off_h * stride_qh 357 | K += off_z * stride_kz + off_hk * stride_kh 358 | V += off_z * stride_vz + off_hk * stride_vh 359 | O += off_z * stride_oz + off_h * stride_oh 360 | L += (off_z * H + off_h) * M # l's shape is (B, H, M) 361 | 362 | offs_m_base = tl.arange(0, BLOCK_M) 363 | offs_m = start_m * BLOCK_M + offs_m_base 364 | offs_n_base = tl.arange(0, BLOCK_N) 365 | offs_k = tl.arange(0, BLOCK_DMODEL) 366 | 367 | if IS_DROPOUT: 368 | rowblock_base = off_z * H * M * N + off_h * M * N + start_m * BLOCK_M * N 369 | offs_rng_base = offset + rowblock_base 370 | offs_rng_base += tl.arange(0, BLOCK_M)[:, None] * N 371 | offs_rng_base += tl.arange(0, BLOCK_N)[None, :] 372 | 373 | # initialize pointers to value-like data 374 | q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk) # (BLOCK_M, BLOCK_DMODEL) 375 | o_ptrs = O + (offs_m[:, None] * stride_om + offs_k[None, :] * stride_ok) # (BLOCK_M, BLOCK_DMODEL) 376 | l_ptrs = L + offs_m 377 | 378 | # initialize pointer to m and l, fp32 for accumulators 379 | m_i = tl.full([BLOCK_M], value=-float("inf"), dtype=tl.float32) 380 | l_i = tl.zeros([BLOCK_M], dtype=tl.float32) 381 | acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) 382 | 383 | # load q 384 | if DIVISIBLE_M: 385 | q = tl.load(q_ptrs, cache_modifier=".cg") 386 | else: 387 | mask_m = offs_m < M 388 | q = tl.load(q_ptrs, mask=mask_m[:, None], cache_modifier=".cg") 389 | 390 | #Dot I trick: to place q in registers, it saves shared memory 391 | if BLOCK_DMODEL < 128: 392 | I = tl.where(offs_k[:, None] == offs_k, 393 | tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 1.0, dtype=input_dtype), 394 | tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 0.0, dtype=input_dtype)) 395 | q = tl.dot(q, I).to(input_dtype) 396 | # else: 397 | # I = tl.where(offs_m_base[:, None] == offs_m_base, 398 | # tl.full((BLOCK_M, BLOCK_M), 1.0, dtype=input_dtype), 399 | # tl.full((BLOCK_M, BLOCK_M), 0.0, dtype=input_dtype)) 400 | # q = tl.dot(I, q).to(input_dtype) 401 | 402 | # NOTE: Loop-Bound-For-N 403 | # The indices in m-dimension that this block may access is in `[start_m * BLOCK_M, (start_m + 1) * BLOCK_M)`. 404 | # According to the rule of causal masking, then max index in n-dimension that this block may access 405 | # is `P_SEQ + (start_m + 1) * BLOCK_M`. 406 | # However, the upper bound of index in n-dimension should never exceed the sequence length of k/v(`P_SEQ + N_CTX`). 407 | # `P_SEQ + (start_m + 1) * BLOCK_M` may be larger than `N`. 408 | # At this case, there would be illegal memory access when loading k & v tiles 409 | # if mask_n is not applied for loading(only when `DIVISIBLE_N`` is true). 
410 | # See also https://github.com/FlagOpen/FlagAttention/pull/8 411 | if IS_CAUSAL: 412 | hi = tl.minimum(N, P_SEQ + (start_m + 1) * BLOCK_M) 413 | if LARGER_M: 414 | hi = tl.maximum(0, hi) 415 | else: 416 | hi = N 417 | 418 | # loop over k, v and update accumulators 419 | offs_n_init = offs_n_base 420 | k_ptrs = K + (offs_k[:, None] * stride_vk + offs_n_init[None, :] * stride_vn) # (BLOCK_DMODEL, BLOCK_N) 421 | v_ptrs = V + (offs_n_init[:, None] * stride_kn + offs_k[None, :] * stride_kk) # (BLOCK_N, BLOCK_DMODEL) 422 | for start_n in range(0, hi, BLOCK_N): 423 | start_n = tl.multiple_of(start_n, BLOCK_N) 424 | offs_n = start_n + offs_n_base 425 | 426 | # -- load k, v -- 427 | if DIVISIBLE_N: 428 | k = tl.load(k_ptrs, cache_modifier=".cg") 429 | v = tl.load(v_ptrs, cache_modifier=".cg") 430 | else: 431 | mask_n = offs_n < N 432 | k = tl.load(k_ptrs, mask=mask_n[None, :], cache_modifier=".cg") 433 | v = tl.load(v_ptrs, mask=mask_n[:, None], cache_modifier=".cg") 434 | 435 | # -- compute qk --- 436 | s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) 437 | s += tl.dot(q, k) 438 | 439 | if not DIVISIBLE_N: 440 | s = tl.where(mask_n[None, :], s, float("-inf")) 441 | if IS_CAUSAL: 442 | causal_mask = (P_SEQ + offs_m[:, None]) >= offs_n[None, :] 443 | s = tl.where(causal_mask, s, float("-inf")) 444 | 445 | # -- compute scaling constant --- 446 | m_i_new = tl.maximum(m_i, tl.max(s, 1)) 447 | alpha = tl.math.exp2((m_i - m_i_new) * qk_scale) 448 | p = tl.math.exp2(s * qk_scale - m_i_new[:, None] * qk_scale) 449 | 450 | # -- compute partial sumexpn before applying dropout 451 | p_sum = tl.sum(p, 1) 452 | 453 | # -- apply dropout -- 454 | if IS_DROPOUT: 455 | offs_rng = start_n + offs_rng_base 456 | pmask = tl.rand(seed, offs_rng, n_rounds=6) > dropout_p 457 | p *= pmask.to(tl.float32) 458 | 459 | # -- scale and update acc: acc *= alpha[:, None]-- 460 | acc *= alpha[:, None] 461 | acc += tl.dot(p.to(input_dtype), v) 462 | 463 | # -- update m_i and l_i -- 464 | l_i = l_i * alpha + p_sum 465 | m_i = m_i_new 466 | # update pointers 467 | k_ptrs += BLOCK_N * stride_kn 468 | v_ptrs += BLOCK_N * stride_vn 469 | 470 | # write back l & o 471 | if IS_CAUSAL and LARGER_M: 472 | is_empty_line = (offs_m + P_SEQ) < 0 473 | acc = tl.where(is_empty_line[:, None], 0.0, acc * (1.0 / l_i[:, None])) 474 | l = tl.where(is_empty_line, float("-inf"), m_i * sm_scale + tl.log(l_i)) 475 | else: 476 | acc = acc * (1.0 / l_i[:, None]) 477 | l = m_i * sm_scale + tl.log(l_i) # log(normalizer) 478 | 479 | # -- scale o due to dropout 480 | if IS_DROPOUT: 481 | scale = 1.0 / (1.0 - dropout_p) 482 | acc *= scale 483 | 484 | if DIVISIBLE_M: 485 | tl.store(l_ptrs, l, cache_modifier=".cg") 486 | tl.store(o_ptrs, acc.to(input_dtype), cache_modifier=".cg") 487 | else: 488 | tl.store(l_ptrs, l, mask=mask_m, cache_modifier=".cg") 489 | tl.store(o_ptrs, acc.to(input_dtype), mask=mask_m[:, None], cache_modifier=".cg") 490 | 491 | 492 | # --------------------------- Backward --------------------------- 493 | # NOTE: this function can be overwritten at runtime to use your custom config 494 | def get_bwd_config(B, H, M, N, D, causal): 495 | if torch.cuda.get_device_capability() == (8, 0): 496 | if not causal: 497 | BLOCK_M = 128 if D <= 64 else 64 498 | BLOCK_N = 64 499 | num_stages = 2 500 | num_warps = 4 501 | else: 502 | BLOCK_M = 64 503 | BLOCK_N = 64 504 | num_stages = 3 if D <= 64 else 2 505 | num_warps = 4 506 | elif torch.cuda.get_device_capability() == (8, 6): # tune for RTX-3090, device_capability(8, 6) 507 | if not causal: 508 | if 
D <= 64: 509 | BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 4 510 | else: 511 | BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 8 512 | else: 513 | if D <= 64: 514 | BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 4 515 | else: 516 | BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 32, 2, 4 517 | else: 518 | BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 32, 1, 4 519 | return (BLOCK_M, BLOCK_N, num_stages, num_warps) 520 | 521 | 522 | @triton.jit 523 | def _bwd_preprocess( 524 | Out, DO, 525 | Delta, 526 | stride_oz, stride_oh, stride_om, stride_ok, 527 | stride_doz, stride_doh, stride_dom, stride_dok, 528 | stride_dz, stride_dh, stride_dm, 529 | M, 530 | BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr, 531 | DIVISIBLE_M: tl.constexpr, 532 | ): 533 | off_h = tl.program_id(1) 534 | off_z = tl.program_id(2) 535 | Out += off_z * stride_oz + off_h * stride_oh 536 | DO += off_z * stride_doz + off_h * stride_doh 537 | Delta += off_z * stride_dz + off_h * stride_dh 538 | 539 | # compute (Out * Dout).sum() for vector interpretation 540 | off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) 541 | off_n = tl.arange(0, D_HEAD) 542 | 543 | # load 544 | o_ptrs = Out + off_m[:, None] * stride_om + off_n[None, :] * stride_ok 545 | do_ptrs = DO + off_m[:, None] * stride_dom + off_n[None, :] * stride_dok 546 | 547 | if DIVISIBLE_M: 548 | o = tl.load(o_ptrs).to(tl.float32) 549 | do = tl.load(do_ptrs).to(tl.float32) 550 | else: 551 | mask_m = off_m < M 552 | o = tl.load(o_ptrs, mask=mask_m[:, None]).to(tl.float32) 553 | do = tl.load(do_ptrs, mask=mask_m[:, None]).to(tl.float32) 554 | 555 | # compute 556 | delta = tl.sum(o * do, axis=1) 557 | 558 | # (NOTE) dropout scaling doesn't affect delta's value 559 | # when dropout is applied, o and do are actually scaled. 560 | # original_o equals o times reverse scale while original_do is do times scale, 561 | # and thus delta remains unchanged. 
562 | 563 | # write-back 564 | d_ptrs = Delta + off_m * stride_dm 565 | if DIVISIBLE_M: 566 | tl.store(d_ptrs, delta) 567 | else: 568 | tl.store(d_ptrs, delta, mask=mask_m) 569 | 570 | 571 | @triton.jit 572 | def _bwd_kv_kernel( 573 | Q, K, V, sm_scale, DO, 574 | DK, DV, 575 | L, 576 | D, 577 | dropout_p, 578 | seed, 579 | offset, 580 | stride_qz, stride_qh, stride_qm, stride_qk, 581 | stride_kz, stride_kh, stride_kn, stride_kk, 582 | stride_vz, stride_vh, stride_vn, stride_vk, 583 | stride_doz, stride_doh, stride_dom, stride_dok, 584 | stride_dkz, stride_dkh, stride_dkn, stride_dkk, 585 | stride_dvz, stride_dvh, stride_dvn, stride_dvk, 586 | Z, H, M, N, P_SEQ, 587 | num_groups, 588 | BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, 589 | CAUSAL: tl.constexpr, IS_DROPOUT: tl.constexpr, 590 | DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr, 591 | ): 592 | input_dtype = Q.dtype.element_ty 593 | # -- grid id -- 594 | start_n = tl.program_id(0) 595 | off_h = tl.program_id(1) 596 | off_z = tl.program_id(2) 597 | log2e: tl.constexpr = 1.4426950408889634 598 | qk_scale = sm_scale * log2e 599 | 600 | # offset pointers for (batch, head) 601 | off_hk = off_h // num_groups 602 | Q += off_z * stride_qz + off_h * stride_qh 603 | K += off_z * stride_kz + off_hk * stride_kh 604 | V += off_z * stride_vz + off_hk * stride_vh 605 | DO += off_z * stride_doz + off_h * stride_doh 606 | 607 | # offset pointers for batch/head 608 | DK += off_z * stride_dkz + off_h * stride_dkh 609 | DV += off_z * stride_dvz + off_h * stride_dvh 610 | 611 | # offset pointers for batch/head 612 | D += (off_z * H + off_h) * M 613 | L += (off_z * H + off_h) * M 614 | 615 | if CAUSAL: 616 | lo = tl.maximum(start_n * BLOCK_N - P_SEQ, 0) 617 | lo = (lo // BLOCK_M) * BLOCK_M 618 | else: 619 | lo = 0 620 | 621 | offs_m_init = lo + tl.arange(0, BLOCK_M) 622 | offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) 623 | offs_m_base = tl.arange(0, BLOCK_M) 624 | offs_k = tl.arange(0, BLOCK_DMODEL) 625 | 626 | # initialize pointers to value-like data 627 | q_ptrs = Q + (offs_m_init[:, None] * stride_qm + offs_k[None, :] * stride_qk) # (BLOCK_M, BLOCK_DMODEL) 628 | k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) # (BLOCK_N, BLOCK_DMODEL) 629 | v_ptrs = V + (offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk) # (BLOCK_N, BLOCK_DMODEL) 630 | do_ptrs = DO + (offs_m_init[:, None] * stride_dom + offs_k[None, :] * stride_dok) # (BLOCK_M, BLOCK_DMODEL) 631 | 632 | dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_k[None, :] * stride_dvk) # (BLOCK_N, BLOCK_DMODEL) 633 | dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_k[None, :] * stride_dkk) # (BLOCK_N, BLOCK_DMODEL) 634 | 635 | # k and v stay in SRAM throughout 636 | if DIVISIBLE_N: 637 | v = tl.load(v_ptrs) 638 | k = tl.load(k_ptrs) 639 | else: 640 | mask_n = offs_n < N 641 | v = tl.load(v_ptrs, mask=mask_n[:, None]) 642 | k = tl.load(k_ptrs, mask=mask_n[:, None]) 643 | 644 | # dropout 645 | if IS_DROPOUT: 646 | colblock_base = off_z * H * M * N + off_h * M * N + start_n * BLOCK_N 647 | offs_rng_base = offset + colblock_base 648 | offs_rng_base += tl.arange(0, BLOCK_M)[:, None] * N 649 | offs_rng_base += tl.arange(0, BLOCK_N)[None, :] 650 | rp = 1. / (1. 
- dropout_p) 651 | 652 | # initialize dk amd dv 653 | dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) 654 | dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) 655 | 656 | # loop over a col 657 | for start_m in range(lo, M, BLOCK_M): 658 | start_m = tl.multiple_of(start_m, BLOCK_M) 659 | offs_m = start_m + offs_m_base 660 | causal_mask = (P_SEQ + offs_m[:, None]) >= (offs_n[None, :]) # (BLOCK_M, BLOCK_N) 661 | 662 | # load q1, k1, q2, k2, v, do on-chip 663 | if DIVISIBLE_M: 664 | q = tl.load(q_ptrs) 665 | else: 666 | mask_m = offs_m < M 667 | valid_mask = mask_m[:, None] # & mask_n 668 | q = tl.load(q_ptrs, mask=mask_m[:, None]) 669 | # recompute p = softmax(qk * sm_scale, dim=-1) 670 | s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) 671 | s += tl.dot(q, tl.trans(k)) 672 | 673 | # NOTE: since softmax in backward is pointwise, the normalizer has been saved in fwd) 674 | # So masking on s is not needed. 675 | # s = tl.where(valid_mask, s , float("-inf")) 676 | # if CAUSAL: 677 | # s = tl.where(causal_mask, s, float("-inf")) 678 | 679 | # -- recompute p --- 680 | if DIVISIBLE_M: 681 | l = tl.load(L + offs_m) 682 | else: 683 | l = tl.load(L + offs_m, mask=mask_m) 684 | p = tl.math.exp2(s * qk_scale - l[:, None] * log2e) # (BLOCK_M, BLOCK_N) 685 | 686 | if not DIVISIBLE_M: 687 | p = tl.where(valid_mask, p, 0.0) 688 | if CAUSAL: 689 | p = tl.where(causal_mask, p, 0.0) 690 | 691 | # compute dv = dot(p, do) 692 | if DIVISIBLE_M: 693 | do = tl.load(do_ptrs) 694 | else: 695 | do = tl.load(do_ptrs, mask=mask_m[:, None]) # (BLOCK_M, BLOCK_DMODEL) 696 | 697 | if IS_DROPOUT: 698 | # do *= rp 699 | offs_rng = offs_rng_base + start_m * N 700 | pmask = tl.rand(seed, offs_rng, n_rounds=6) > dropout_p 701 | p_masked = p * pmask 702 | p_masked = p_masked.to(input_dtype) 703 | 704 | # -- apply dropout -- 705 | if IS_DROPOUT: 706 | dv += tl.dot(tl.trans(p_masked), do) * rp # (BLOCK_N, BLOCK_DMODEL) # still correct 707 | else: 708 | dv += tl.dot(tl.trans(p).to(input_dtype), do) # (BLOCK_N, BLOCK_DMODEL) # still correct 709 | 710 | # compute dp = dot(v, do) 711 | if DIVISIBLE_M: 712 | delta = tl.load(D + offs_m) 713 | else: 714 | delta = tl.load(D + offs_m, mask=mask_m) 715 | dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) 716 | dp += tl.dot(do, tl.trans(v)) 717 | 718 | # -- apply dropout -- 719 | if IS_DROPOUT: 720 | dp *= rp 721 | dp *= pmask 722 | 723 | # compute ds = p * (dp - delta[:, None]) 724 | ds = p * (dp - delta[:, None]) # (BLOCK_M, BLOCK_N) 725 | 726 | if not DIVISIBLE_M: 727 | ds = tl.where(valid_mask, ds, 0.0) 728 | if CAUSAL: 729 | ds = tl.where(causal_mask, ds, 0.0) 730 | ds = ds.to(input_dtype) 731 | 732 | # compute dk = dot(ds.T, q) masking 733 | dk += tl.dot(tl.trans(ds), q) 734 | 735 | # increment pointers 736 | q_ptrs += BLOCK_M * stride_qm 737 | do_ptrs += BLOCK_M * stride_dom 738 | 739 | dk *= sm_scale 740 | if DIVISIBLE_N: 741 | tl.store(dk_ptrs, dk.to(input_dtype)) # (BLOCK_N, BLOCK_DMODEL) 742 | tl.store(dv_ptrs, dv.to(input_dtype)) # (BLOCK_N, BLOCK_DMODEL,) 743 | else: 744 | tl.store(dk_ptrs, dk.to(input_dtype), mask=mask_n[:, None]) # (BLOCK_N, BLOCK_DMODEL) 745 | tl.store(dv_ptrs, dv.to(input_dtype), mask=mask_n[:, None]) # (BLOCK_N, BLOCK_DMODEL,) 746 | 747 | 748 | @triton.jit 749 | def _bwd_q_kernel( 750 | Q, K, V, sm_scale, DO, 751 | DQ, 752 | L, 753 | D, 754 | dropout_p, 755 | seed, 756 | offset, 757 | stride_qz, stride_qh, stride_qm, stride_qk, 758 | stride_kz, stride_kh, stride_kn, stride_kk, 759 | stride_vz, stride_vh, stride_vn, stride_vk, 760 | 
stride_doz, stride_doh, stride_dom, stride_dok, 761 | stride_dqz, stride_dqh, stride_dqm, stride_dqk, 762 | Z, H, M, N, P_SEQ, 763 | num_groups, 764 | BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, 765 | CAUSAL: tl.constexpr, IS_DROPOUT: tl.constexpr, LARGER_M: tl.constexpr, 766 | DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr, 767 | ): 768 | input_dtype = Q.dtype.element_ty 769 | # -- grid id -- 770 | start_m = tl.program_id(0) 771 | off_h = tl.program_id(1) 772 | off_z = tl.program_id(2) 773 | 774 | # scale sm_scale by log_2(e) and use 775 | # 2^x instead of exp in the loop because CSE and LICM 776 | # don't work as expected with `exp` in the loop 777 | log2e: tl.constexpr = 1.4426950408889634 778 | qk_scale = sm_scale * log2e 779 | 780 | # offset pointers for (batch, head) 781 | off_hk = off_h // num_groups 782 | Q += off_z * stride_qz + off_h * stride_qh 783 | K += off_z * stride_kz + off_hk * stride_kh 784 | V += off_z * stride_vz + off_hk * stride_vh 785 | DO += off_z * stride_doz + off_h * stride_doh 786 | D += (off_z * H + off_h) * M 787 | L += (off_z * H + off_h) * M 788 | 789 | # offset pointers for batch/head 790 | DQ += off_z * stride_dqz + off_h * stride_dqh 791 | 792 | offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) 793 | offs_n_base = tl.arange(0, BLOCK_N) 794 | offs_n_init = offs_n_base 795 | offs_k = tl.arange(0, BLOCK_DMODEL) 796 | 797 | # initialize pointers to value-like data 798 | q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk) # (BLOCK_M, BLOCK_DMODEL) 799 | k_ptrs = K + (offs_n_init[:, None] * stride_kn + offs_k[None, :] * stride_kk) # (BLOCK_N, BLOCK_DMODEL) 800 | v_ptrs = V + (offs_n_init[:, None] * stride_vn + offs_k[None, :] * stride_vk) # (BLOCK_N, BLOCK_DMODEL) 801 | 802 | dq_ptrs = DQ + (offs_m[:, None] * stride_dqm + offs_k[None, :] * stride_dqk) # (BLOCK_M, BLOCK_DMODEL) 803 | do_ptrs = DO + (offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok) # (BLOCK_M, BLOCK_DMODEL) 804 | 805 | # pointer to row-wise quantities in value-like data 806 | d_ptrs = D + offs_m 807 | l_ptrs = L + offs_m 808 | 809 | # load q: it will stay in SRAM throughout 810 | if DIVISIBLE_M: 811 | q = tl.load(q_ptrs) 812 | do = tl.load(do_ptrs) 813 | delta = tl.load(d_ptrs) 814 | l = tl.load(l_ptrs) 815 | else: 816 | mask_m = offs_m < M 817 | q = tl.load(q_ptrs, mask=mask_m[:, None]) 818 | do = tl.load(do_ptrs, mask=mask_m[:, None]) 819 | delta = tl.load(d_ptrs, mask=mask_m) 820 | l = tl.load(l_ptrs, mask=mask_m) 821 | 822 | # initialize dq 823 | dq = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) 824 | 825 | # loop over k, v and update accumulator 826 | # see note "Loop-Bound-For-N" 827 | if CAUSAL: 828 | hi = tl.minimum(N, P_SEQ + (start_m + 1) * BLOCK_M) 829 | if LARGER_M: 830 | hi = tl.maximum(0, hi) 831 | else: 832 | hi = N 833 | 834 | # dropout 835 | if IS_DROPOUT: 836 | rowblock_base = off_z * H * M * N + off_h * M * N + start_m * BLOCK_M * N 837 | offs_rng_base = offset + rowblock_base 838 | offs_rng_base += tl.arange(0, BLOCK_M)[:, None] * N 839 | offs_rng_base += tl.arange(0, BLOCK_N)[None, :] 840 | rp = 1. / (1. 
- dropout_p) 841 | do *= rp.to(do.dtype) 842 | 843 | # loop over a row 844 | for start_n in range(0, hi, BLOCK_N): 845 | offs_n = start_n + offs_n_base 846 | 847 | # load k1, k2, v on chip 848 | if DIVISIBLE_N: 849 | v = tl.load(v_ptrs) 850 | k = tl.load(k_ptrs) 851 | else: 852 | mask_n = offs_n < N 853 | v = tl.load(v_ptrs, mask=mask_n[:, None]) 854 | k = tl.load(k_ptrs, mask=mask_n[:, None]) 855 | 856 | 857 | # recompute p = softmax(qk * sm_scale, dim=-1) 858 | if not DIVISIBLE_N: 859 | valid_mask = mask_n # & mask_m[:, None] 860 | if CAUSAL: 861 | causal_mask = (P_SEQ + offs_m[:, None]) >= (offs_n[None, :]) # (BLOCK_M, BLOCK_N) 862 | s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) 863 | s += tl.dot(q, tl.trans(k)) 864 | 865 | # NOTE: since softmax in backward is pointwise, the normalizer has been saved in fwd) 866 | # So masking on s is not needed. 867 | # if CAUSAL: 868 | # s = tl.where(causal_mask & valid_mask, s, float("-inf")) 869 | # else: 870 | # s = tl.where(valid_mask, s, float("-inf")) 871 | p = tl.math.exp2(s * qk_scale - l[:, None] * log2e) # (BLOCK_M, BLOCK_N) 872 | 873 | # compute dp = dot(v, do) 874 | dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) 875 | dp += tl.dot(do.to(input_dtype), tl.trans(v)) 876 | 877 | if IS_DROPOUT: 878 | offs_rng = start_n + offs_rng_base 879 | pmask = tl.rand(seed, offs_rng, n_rounds=6) > dropout_p 880 | dp *= pmask 881 | # p_dropout = p * pmask.to(tl.float32) 882 | 883 | # no need to mask dp 884 | # if CAUSAL: 885 | # dp = tl.where(causal_mask & valid_mask, dp, 0.0) 886 | # else: 887 | # dp = tl.where(valid_mask, dp, 0.0) 888 | 889 | # compute ds = p * (dp - delta[:, None]) 890 | # move scale out to dq at last 891 | ds = p * (dp - delta[:, None]) # (BLOCK_M, BLOCK_N) 892 | 893 | # mask ds to ensure no small values 894 | if not DIVISIBLE_N: 895 | ds = tl.where(valid_mask, ds, 0.0) 896 | if CAUSAL: 897 | ds = tl.where(causal_mask, ds, 0.0) 898 | 899 | dq += tl.dot(ds.to(input_dtype), k) 900 | 901 | # increment pointers 902 | k_ptrs += BLOCK_N * stride_kn 903 | v_ptrs += BLOCK_N * stride_vn 904 | 905 | dq *= sm_scale 906 | if DIVISIBLE_M: 907 | tl.store(dq_ptrs, dq.to(input_dtype)) 908 | else: 909 | tl.store(dq_ptrs, dq.to(input_dtype), mask=mask_m[:, None]) 910 | --------------------------------------------------------------------------------
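
The `attention` wrapper defined above is what `src/flag_attn/__init__.py` re-exports as `flag_attn.flash_attention`. A minimal usage sketch based on its docstring; the shapes and dtype are illustrative, and the exact packing of the auxiliary outputs follows the docstring rather than being spelled out here:

    import torch
    import flag_attn

    B, Hq, Hk, M, N, D = 2, 8, 2, 1024, 1024, 64   # illustrative sizes; num_heads_q % num_heads_k == 0
    q = torch.randn(B, Hq, M, D, device="cuda", dtype=torch.float16, requires_grad=True)
    k = torch.randn(B, Hk, N, D, device="cuda", dtype=torch.float16, requires_grad=True)
    v = torch.randn(B, Hk, N, D, device="cuda", dtype=torch.float16, requires_grad=True)

    o = flag_attn.flash_attention(q, k, v, causal=True)   # plain forward
    o.backward(torch.ones_like(o))                         # exercises the Triton backward kernels above

    # with auxiliary outputs: log normalizer, total attention, dropout seed/offset
    outs = flag_attn.flash_attention(q, k, v, causal=True, dropout_p=0.1,
                                     return_log_normalizer=True, return_total_attention=True,
                                     return_seed_offset=True)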
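
`get_fwd_config` and `get_bwd_config` both carry the NOTE that they can be overridden at runtime (the repository also ships `examples/use_cutom_config_func.py`). Because the wrappers look these functions up as module globals at call time, reassigning the attribute on `flag_attn.flash` should be enough; the block sizes below are made-up illustrations, not tuned values:

    import flag_attn.flash as flash

    def my_fwd_config(B, H, M, N, D, causal):
        # must return (BLOCK_M, BLOCK_N, num_stages, num_warps)
        if causal and D <= 64:
            return (64, 64, 3, 4)
        return (128, 32, 2, 4)

    flash.get_fwd_config = my_fwd_config   # later flag_attn.flash_attention calls pick this up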
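
In `_fwd_kernel`, the online softmax keeps a running row maximum `m_i` and running sum `l_i`, rescales the accumulator by `alpha = exp2((m_i - m_i_new) * qk_scale)`, and finally stores the log normalizer `L = m_i * sm_scale + log(l_i)`. A small PyTorch reference of the same streaming update, useful for checking that the base-2 formulation (`qk_scale = sm_scale * log2(e)`) reproduces ordinary softmax; inputs and block size are arbitrary:

    import math
    import torch

    def streaming_softmax_av(s, v, sm_scale, block_n=4):
        # s: (M, N) raw scores, v: (N, D); consume s in column blocks like the kernel does
        M, N = s.shape
        qk_scale = sm_scale * math.log2(math.e)
        m_i = torch.full((M,), -float("inf"))
        l_i = torch.zeros(M)
        acc = torch.zeros(M, v.shape[1])
        for start in range(0, N, block_n):
            blk = s[:, start:start + block_n]
            m_new = torch.maximum(m_i, blk.max(dim=1).values)
            alpha = torch.exp2((m_i - m_new) * qk_scale)
            p = torch.exp2(blk * qk_scale - (m_new * qk_scale)[:, None])
            acc = acc * alpha[:, None] + p @ v[start:start + block_n]
            l_i = l_i * alpha + p.sum(dim=1)
            m_i = m_new
        out = acc / l_i[:, None]
        log_normalizer = m_i * sm_scale + torch.log(l_i)   # what the kernel writes to L
        return out, log_normalizer

    s, v = torch.randn(8, 16), torch.randn(16, 32)
    o, L = streaming_softmax_av(s, v, sm_scale=0.125)
    assert torch.allclose(o, torch.softmax(s * 0.125, dim=-1) @ v, atol=1e-5)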
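
For grouped-query attention the wrapper passes `num_groups = num_heads_q // num_heads_k`; inside the kernels, query head `off_h` reads key/value head `off_hk = off_h // num_groups`, and the backward first materializes `dk`/`dv` with `num_heads_q` heads before reducing them with `reshape((B, Hk, num_groups, N, D)).sum(2)`. A PyTorch sketch of that head mapping and reduction, with illustrative sizes:

    import torch

    B, Hq, Hk, N, D = 2, 8, 2, 16, 32
    num_groups = Hq // Hk

    k = torch.randn(B, Hk, N, D)
    # query head h reads kv head h // num_groups, i.e. each kv head is shared by num_groups query heads
    k_expanded = k[:, torch.arange(Hq) // num_groups]                 # (B, Hq, N, D)
    assert torch.equal(k_expanded, k.repeat_interleave(num_groups, dim=1))

    # per-query-head dk (as written by _bwd_kv_kernel) summed back to Hk heads, as in the wrapper
    dk_per_q_head = torch.randn(B, Hq, N, D)
    dk = dk_per_q_head.reshape(B, Hk, num_groups, N, D).sum(2)        # (B, Hk, N, D)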
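
`_bwd_preprocess` stores `delta = (o * do).sum(-1)` per query row, and `_bwd_kv_kernel` / `_bwd_q_kernel` then use `ds = p * (dp - delta)` with `dp = do @ v^T`, which is the standard softmax backward: since `o = p @ v`, `delta` equals the row-wise sum of `p * dp`. A small autograd check of these identities without masking or dropout; sizes are arbitrary:

    import torch

    M, N, D, sm_scale = 8, 16, 32, 0.125
    q = torch.randn(M, D, requires_grad=True)
    k = torch.randn(N, D, requires_grad=True)
    v = torch.randn(N, D, requires_grad=True)
    do = torch.randn(M, D)

    p = torch.softmax((q @ k.T) * sm_scale, dim=-1)
    o = p @ v

    delta = (o * do).sum(-1)              # what _bwd_preprocess writes
    dp = do @ v.T
    ds = p * (dp - delta[:, None])        # softmax backward
    dq_ref = (ds @ k) * sm_scale
    dk_ref = (ds.T @ q) * sm_scale
    dv_ref = p.T @ do

    o.backward(do)
    assert torch.allclose(q.grad, dq_ref, atol=1e-5)
    assert torch.allclose(k.grad, dk_ref, atol=1e-5)
    assert torch.allclose(v.grad, dv_ref, atol=1e-5)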
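
Dropout is driven by a counter-based Philox RNG: each score element `(z, h, m, n)` gets the counter `offset + ((z * H + h) * M + m) * N + n` (this is what `rowblock_base` / `colblock_base` plus the in-block offsets compute), and the element is kept when `tl.rand(seed, counter, n_rounds=6) > dropout_p`. Because the counter depends only on the element's position, the forward and both backward kernels regenerate exactly the same mask from the saved `seed` and `offset`, which is presumably also what `flag_attn.testing.dropout.recompute_mask` relies on. A plain-Python sketch of the counter layout; the helper name is made up for illustration:

    def philox_counter(z, h, m, n, H, M, N, base_offset):
        # flat counter of score element (z, h, m, n), matching
        # rowblock_base / colblock_base + row * N + col in the kernels above
        return base_offset + ((z * H + h) * M + m) * N + n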