├── assets ├── logo │ ├── vertical-blue.png │ ├── vertical-dark.png │ ├── baai-flagopen.jpeg │ ├── horizontal-blue.png │ ├── horizontal-dark.png │ ├── horizontal-light.png │ └── vertical-light.png ├── headdim64-causal-A100.png ├── piecewise_attention.png ├── v0.2 │ ├── flash_attention.png │ ├── flash_attention_d64.png │ └── piecewise_attention.png ├── headdim128-causal-A100.png └── piecewise_attention_interface.png ├── src └── flag_attn │ ├── testing │ ├── __init__.py │ ├── dropout.py │ ├── paged.py │ ├── flash.py │ └── piecewise.py │ ├── __init__.py │ ├── dropout.py │ ├── total.py │ ├── split_kv.py │ ├── paged.py │ ├── piecewise.py │ └── flash.py ├── LICENSE ├── examples ├── use_cutom_config_func.py ├── flash_attention_example.py ├── piecewise_example.py ├── flash_attention_with_aux_outputs.py └── paged_example.py ├── tests └── flag_attn │ ├── test_dropout.py │ ├── test_paged_attention.py │ ├── test_piecewise_attention.py │ └── test_flash_attention.py ├── .pre-commit-config.yaml ├── pyproject.toml ├── .github └── workflows │ └── code-check.yml ├── .gitignore ├── benchmark ├── flash_decoding_benchmark.py ├── flash_benchmark.py ├── piecewise_benchmark.py └── paged_benchmark.py ├── README_cn.md └── README.md /assets/logo/vertical-blue.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flagos-ai/FlagAttention/HEAD/assets/logo/vertical-blue.png -------------------------------------------------------------------------------- /assets/logo/vertical-dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flagos-ai/FlagAttention/HEAD/assets/logo/vertical-dark.png -------------------------------------------------------------------------------- /assets/headdim64-causal-A100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flagos-ai/FlagAttention/HEAD/assets/headdim64-causal-A100.png -------------------------------------------------------------------------------- /assets/logo/baai-flagopen.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flagos-ai/FlagAttention/HEAD/assets/logo/baai-flagopen.jpeg -------------------------------------------------------------------------------- /assets/logo/horizontal-blue.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flagos-ai/FlagAttention/HEAD/assets/logo/horizontal-blue.png -------------------------------------------------------------------------------- /assets/logo/horizontal-dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flagos-ai/FlagAttention/HEAD/assets/logo/horizontal-dark.png -------------------------------------------------------------------------------- /assets/logo/horizontal-light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flagos-ai/FlagAttention/HEAD/assets/logo/horizontal-light.png -------------------------------------------------------------------------------- /assets/logo/vertical-light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flagos-ai/FlagAttention/HEAD/assets/logo/vertical-light.png -------------------------------------------------------------------------------- 
/assets/piecewise_attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flagos-ai/FlagAttention/HEAD/assets/piecewise_attention.png -------------------------------------------------------------------------------- /assets/v0.2/flash_attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flagos-ai/FlagAttention/HEAD/assets/v0.2/flash_attention.png -------------------------------------------------------------------------------- /assets/headdim128-causal-A100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flagos-ai/FlagAttention/HEAD/assets/headdim128-causal-A100.png -------------------------------------------------------------------------------- /assets/v0.2/flash_attention_d64.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flagos-ai/FlagAttention/HEAD/assets/v0.2/flash_attention_d64.png -------------------------------------------------------------------------------- /assets/v0.2/piecewise_attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flagos-ai/FlagAttention/HEAD/assets/v0.2/piecewise_attention.png -------------------------------------------------------------------------------- /assets/piecewise_attention_interface.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flagos-ai/FlagAttention/HEAD/assets/piecewise_attention_interface.png -------------------------------------------------------------------------------- /src/flag_attn/testing/__init__.py: -------------------------------------------------------------------------------- 1 | from flag_attn.testing.flash import attention as flash_attention # noqa: F401 2 | from flag_attn.testing.piecewise import attention as piecewise_attention # noqa: F401 3 | from flag_attn.testing.paged import attention as paged_attention # noqa: F401 4 | from flag_attn.testing.dropout import recompute_mask -------------------------------------------------------------------------------- /src/flag_attn/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | from ._version import version as __version__ 3 | from ._version import version_tuple 4 | except ImportError: 5 | __version__ = "0.0.0" 6 | version_tuple = (0, 0, 0) 7 | 8 | 9 | from flag_attn.piecewise import attention as piecewise_attention # noqa: F401 10 | from flag_attn.flash import attention as flash_attention # noqa: F401 11 | from flag_attn.split_kv import attention as flash_attention_split_kv # noqa: F401 12 | from flag_attn.paged import attention as paged_attention # noqa: F401 13 | 14 | from flag_attn import testing # noqa: F401 15 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2023 BAAI 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | -------------------------------------------------------------------------------- /src/flag_attn/dropout.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import triton 3 | import triton.language as tl 4 | 5 | def philox_cuda_seed_offset(increment, device=None): 6 | device = device or torch.cuda.current_device() 7 | gen = torch.cuda.default_generators[device] 8 | state_copy = gen.get_state() 9 | c0, c1 = state_copy.view(torch.int64) 10 | seed, offset = int(c0), int(c1) 11 | increment = (increment + 3) // 4 * 4 12 | c1 += increment 13 | # get_state returns a new tensor, so it needs set_state to update the actual generator state. 14 | gen.set_state(state_copy) 15 | return seed, offset 16 | -------------------------------------------------------------------------------- /examples/use_cutom_config_func.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import flag_attn 3 | 4 | 5 | # replace the default config function 6 | from flag_attn import flash 7 | def get_fwd_config(B, H, M, N, D, causal): 8 | return (64, 64, 1, 4) 9 | flash.get_fwd_config = get_fwd_config 10 | 11 | B, H, M, N, D = 2, 16, 4096, 4096, 128 12 | causal = True 13 | 14 | q = torch.randn(B, H, M, D, dtype=torch.bfloat16, device="cuda:0", requires_grad=True) 15 | k = torch.randn(B, H, N, D, dtype=torch.bfloat16, device="cuda:0", requires_grad=True) 16 | v = torch.randn(B, H, N, D, dtype=torch.bfloat16, device="cuda:0", requires_grad=True) 17 | 18 | o = flag_attn.flash_attention(q, k, v, causal=causal) 19 | go = torch.randn_like(o) 20 | gq, gk, gv = torch.autograd.grad(o, (q, k, v), go) 21 | -------------------------------------------------------------------------------- /examples/flash_attention_example.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import flag_attn 3 | 4 | B, H, M, N, D = 2, 16, 4000, 4000, 128 5 | causal = True 6 | 7 | q = torch.randn(B, H, M, D, dtype=torch.bfloat16, device="cuda:0", requires_grad=True) 8 | k = torch.randn(B, H, N, D, dtype=torch.bfloat16, device="cuda:0", requires_grad=True) 9 | v = torch.randn(B, H, N, D, dtype=torch.bfloat16, device="cuda:0", requires_grad=True) 10 | 11 | 12 | o_ref = flag_attn.testing.flash_attention(q, k, v, causal=causal, upcast=True) 13 | o = flag_attn.flash_attention(q, k, v, causal=causal) 14 | o_torch = flag_attn.testing.flash_attention(q, k, v, causal=causal) 15 | 16 | go = torch.randn_like(o) 17 | gq_ref, gk_ref, gv_ref = torch.autograd.grad(o_ref, (q, k, v), go) 18 | gq, gk, gv = torch.autograd.grad(o, (q, k, v), go) 19 | gq_torch, gk_torch, gv_torch = torch.autograd.grad(o_torch, (q, k, v), go) 20 | -------------------------------------------------------------------------------- /examples/piecewise_example.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from flag_attn import piecewise_attention 3 | 4 | B, H, T, D = 2, 16, 8192, 128 5 | dist_threshold = T // 2 6 | 7 | q1 = torch.randn((B, H, T, D), 
dtype=torch.float16, device="cuda:0").requires_grad_() 8 | q2 = torch.randn((B, H, T, D), dtype=torch.float16, device="cuda:0").requires_grad_() 9 | k1 = torch.randn((B, H, T, D), dtype=torch.float16, device="cuda:0").requires_grad_() 10 | k2 = torch.randn((B, H, T, D), dtype=torch.float16, device="cuda:0").requires_grad_() 11 | v = torch.randn((B, H, T, D), dtype=torch.float16, device="cuda:0").requires_grad_() 12 | o = piecewise_attention(q1, k1, q2, k2, v, dist_threshold, causal=True) 13 | print(o) 14 | 15 | go = torch.randn((B, H, T, D), dtype=torch.float16, device="cuda:0") 16 | gq1, gk1, gq2, gk2, gv = torch.autograd.grad( 17 | o, (q1, k1, q2, k2, v), go 18 | ) 19 | print(gq1) 20 | 21 | 22 | -------------------------------------------------------------------------------- /examples/flash_attention_with_aux_outputs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import flag_attn 3 | 4 | 5 | B, H, M, N, D = 2, 4, 10, 20, 128 6 | causal = False 7 | q = torch.randn((B, H, M, D), dtype=torch.bfloat16, device="cuda", requires_grad=True) 8 | k = torch.randn((B, H, N, D), dtype=torch.bfloat16, device="cuda", requires_grad=True) 9 | v = torch.randn((B, H, N, D), dtype=torch.bfloat16, device="cuda", requires_grad=True) 10 | 11 | o, logz, tot_attn = flag_attn.flash_attention( 12 | q, k, v, causal=causal, return_log_normalizer=True, return_total_attention=True) 13 | o_ref, logz_ref, tot_attn_ref = flag_attn.testing.flash_attention( 14 | q, k, v, causal=causal, 15 | return_log_normalizer=True, return_total_attention=True, upcast=True) 16 | 17 | print("log normalizer") 18 | print(logz[0, 0]) 19 | print(logz_ref[0, 0]) 20 | 21 | print("total attention") 22 | print(tot_attn[0, 0]) 23 | print(tot_attn_ref[0, 0]) 24 | -------------------------------------------------------------------------------- /tests/flag_attn/test_dropout.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | from flag_attn.testing import recompute_mask 4 | 5 | 6 | @pytest.mark.parametrize('B, H, M, N', [ 7 | (2, 4, 512, 612), 8 | (2, 4, 1024, 1034), 9 | (2, 4, 2048, 2048), 10 | (2, 4, 4096, 4096), 11 | (2, 4, 4001, 4001), 12 | (2, 4, 4001, 4096), 13 | (2, 4, 4096, 4000), 14 | (1, 2, 8192, 8202), 15 | (1, 2, 8192, 8192), 16 | ]) 17 | @pytest.mark.parametrize('p', [0.5, 0.8]) 18 | def test_recompute_mask(B, H, M, N, p): 19 | import math 20 | seed = 123456789 21 | offset = 123456789123456789 22 | device = torch.cuda.current_device() 23 | mask = recompute_mask(B, H, M, N, p, seed, offset, device) 24 | # zeros indicate positions to drop 25 | # k follows the Binomial distribution B(k; n, p) 26 | n = mask.numel() 27 | k = torch.sum(mask == 0) 28 | p_cap = k / n 29 | tol = 0.01 30 | assert math.fabs(p_cap - p) < tol * p 31 | -------------------------------------------------------------------------------- /src/flag_attn/testing/dropout.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import triton 3 | import triton.language as tl 4 | 5 | @triton.jit 6 | def recompute_mask_kernel(mask, B, H, M, N, dropout_p, seed, offset): 7 | row, b, h = tl.program_id(0), tl.program_id(1), tl.program_id(2) 8 | offs_base = b * H * M * N + h * M * N + row * N 9 | BLOCK: tl.constexpr = 1024 10 | offs_base += tl.arange(0, BLOCK) 11 | for start_n in range(0, N, BLOCK): 12 | offs = start_n + offs_base 13 | rng_offs = offset + offs 14 | pmask = tl.rand(seed, rng_offs, n_rounds=6) > dropout_p 15 |
row_mask = start_n + tl.arange(0, BLOCK) < N 16 | tl.store(mask + offs, pmask, mask=row_mask) 17 | 18 | def recompute_mask(B, H, M, N, dropout_p, seed, offset, device): 19 | mask = torch.full((B, H, M, N), True, dtype=torch.bool, device=device) 20 | if dropout_p == 0: 21 | return mask 22 | grid = (M, B, H) 23 | with torch.cuda.device(device): 24 | recompute_mask_kernel[grid](mask, B, H, M, N, dropout_p, seed, offset) 25 | return mask 26 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | files: '^src/.*' 2 | repos: 3 | - repo: https://github.com/pre-commit/pre-commit-hooks 4 | rev: v4.5.0 5 | hooks: 6 | - id: trailing-whitespace 7 | - id: end-of-file-fixer 8 | - id: check-yaml 9 | - id: check-toml 10 | - id: check-ast 11 | - id: check-added-large-files 12 | - id: check-merge-conflict 13 | - id: check-executables-have-shebangs 14 | - id: check-shebang-scripts-are-executable 15 | - id: detect-private-key 16 | - id: debug-statements 17 | 18 | # - repo: https://github.com/google/yapf 19 | # rev: v0.40.2 20 | # hooks: 21 | # - id: yapf 22 | # args: ["-p", "-i"] 23 | # stages: [commit, push, manual] 24 | 25 | # - repo: https://github.com/pylint-dev/pylint 26 | # rev: v3.0.3 27 | # hooks: 28 | # - id: pylint 29 | 30 | 31 | - repo: https://github.com/astral-sh/ruff-pre-commit 32 | rev: v0.1.14 33 | hooks: 34 | - id: ruff 35 | args: ["--fix"] 36 | stages: [commit, push, manual] 37 | # - id: ruff-format 38 | # stages: [commit, push, manual] 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "flag_attn" 3 | dynamic = ["version"] 4 | authors = [ 5 | {name = "Chen Feiyu", email = "iclementine@outlook.com"}, 6 | ] 7 | description = "A collection of memory-efficient attention operators implemented in the Triton language." 8 | readme = {file= "README.md", content-type="text/markdown"} 9 | requires-python = ">=3.7" 10 | license = {text = "LICENSE.txt"} 11 | classifiers = [ 12 | "Development Status :: 3 - Alpha", 13 | "Programming Language :: Python :: 3", 14 | "License :: OSI Approved :: Apache Software License", 15 | ] 16 | 17 | # Not specifying the triton version here because torch has its own required triton version 18 | # FlagAttention needs a recent version of triton (triton nightly or 2.2.0) to run.
19 | dependencies = [ 20 | "triton>=2.2.0" 21 | ] 22 | 23 | [project.optional-dependencies] 24 | test = [ 25 | "pytest>=7.1.0", 26 | ] 27 | 28 | [project.urls] 29 | homepage = "https://github.com/FlagOpen/FlagAttention" 30 | 31 | 32 | [build-system] 33 | requires = ["setuptools>=60", "setuptools-scm>=8.0"] 34 | build-backend = "setuptools.build_meta" 35 | 36 | [tool.setuptools_scm] 37 | version_file = "src/flag_attn/_version.py" 38 | 39 | [tool.setuptools.packages.find] 40 | where = ["src"] 41 | include = ["flag_attn"] 42 | namespaces = false 43 | 44 | # helps for setting up pytest in pyprojects 45 | # https://docs.pytest.org/en/7.3.x/reference/customize.html#rootdir 46 | # https://docs.pytest.org/en/7.3.x/reference/reference.html#confval-pythonpath 47 | [tool.pytest.ini_options] 48 | testpaths = [ 49 | "tests", 50 | ] 51 | pythonpath = [ 52 | "src", 53 | "tests/flag_attn", 54 | ] 55 | 56 | [tool.ruff] 57 | ignore = ["E741"] 58 | line-length = 120 59 | -------------------------------------------------------------------------------- /.github/workflows/code-check.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: Python application 5 | 6 | on: 7 | push: 8 | branches: [ "main" ] 9 | pull_request: 10 | branches: [ "main" ] 11 | 12 | permissions: 13 | contents: read 14 | 15 | jobs: 16 | build: 17 | runs-on: self-hosted 18 | steps: 19 | - uses: actions/checkout@v4 20 | with: 21 | ssh-key: ${{ secrets.SSH_PRIVATE_KEY }} 22 | 23 | # - name: Set up Python 3.10 24 | # uses: actions/setup-python@v3 25 | # with: 26 | # python-version: "3.10" 27 | 28 | # - name: Install dependencies 29 | # run: | 30 | # python -m pip install --upgrade pip 31 | # pip install flake8 pytest 32 | # if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 33 | # pip install . 34 | 35 | - name: Activate Virtualenv 36 | run: | 37 | source /home/flagattn_ci/.virtualenvs/release/bin/activate 38 | echo PATH=$PATH >> $GITHUB_ENV 39 | 40 | - name: Editable Install 41 | run: | 42 | pip install --no-dependencies -e . 43 | 44 | # - name: Lint with flake8 45 | # run: | 46 | # # stop the build if there are Python syntax errors or undefined names 47 | # flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 48 | # # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 49 | # flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 50 | - name: Test with pytest 51 | run: | 52 | pytest tests -------------------------------------------------------------------------------- /src/flag_attn/testing/paged.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | 5 | def attention( 6 | query: torch.Tensor, 7 | key_cache: torch.Tensor, 8 | value_cache: torch.Tensor, 9 | block_tables: torch.Tensor, 10 | context_lens: torch.Tensor, 11 | scale: float, 12 | ) -> torch.Tensor: 13 | output = torch.empty_like(query) 14 | 15 | num_query_heads = query.shape[1] 16 | num_kv_heads = value_cache.shape[1] 17 | num_queries_per_kv = num_query_heads // num_kv_heads 18 | block_size = value_cache.shape[2] 19 | head_size = value_cache.shape[3] 20 | num_seqs = query.shape[0] 21 | 22 | block_tables = block_tables.cpu().tolist() 23 | context_lens = context_lens.cpu().tolist() 24 | for i in range(num_seqs): 25 | q = query[i].unsqueeze(0) 26 | block_table = block_tables[i] 27 | context_len = int(context_lens[i]) 28 | 29 | keys = [] 30 | values = [] 31 | for j in range(context_len): 32 | block_number = int(block_table[j // block_size]) 33 | block_offset = j % block_size 34 | k = key_cache[block_number, :, block_offset, :] 35 | keys.append(k) 36 | v = value_cache[block_number, :, block_offset, :] 37 | values.append(v) 38 | keys = torch.stack(keys, dim=0) 39 | values = torch.stack(values, dim=0) 40 | if num_queries_per_kv > 1: 41 | # Handle MQA and GQA 42 | keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) 43 | values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) 44 | 45 | S = torch.bmm(q.transpose(0, 1).float(), keys.permute(1, 2, 0).float()) * scale 46 | P = torch.softmax(S, dim=-1) 47 | out = torch.bmm(P, values.transpose(0, 1).float()).transpose(0, 1) 48 | out = out.to(values.dtype) 49 | out = out.view(num_query_heads, head_size) 50 | output[i].copy_(out, non_blocking=True) 51 | 52 | return output 53 | -------------------------------------------------------------------------------- /examples/paged_example.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Tuple 3 | 4 | import flag_attn 5 | 6 | MAX_SEQ_LEN = 4096 7 | NUM_BLOCKS = 2000 8 | 9 | 10 | def test_paged_attention( 11 | num_seqs: int, 12 | num_heads: Tuple[int, int], 13 | head_size: int, 14 | block_size: int, 15 | dtype: torch.dtype, 16 | seed: int, 17 | device: str, 18 | ): 19 | torch.set_default_dtype(dtype) 20 | torch.set_default_device(device=device) 21 | 22 | torch.cuda.manual_seed(seed) 23 | 24 | num_query_heads, num_kv_heads = num_heads 25 | 26 | context_lens = torch.randint(1, MAX_SEQ_LEN, [num_seqs], dtype=torch.int32) 27 | max_context_len = context_lens.max().item() 28 | max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size 29 | 30 | attn_scale = head_size**-0.5 31 | q = torch.empty(num_seqs, num_query_heads, head_size) 32 | q.uniform_(-attn_scale, attn_scale) 33 | 34 | k_cache = torch.empty(NUM_BLOCKS, num_kv_heads, block_size, head_size) 35 | k_cache.uniform_(-attn_scale, attn_scale) 36 | v_cache = torch.empty_like(k_cache) 37 | v_cache.uniform_(-attn_scale, attn_scale) 38 | 39 | # (NUM_SEQS, MAX_NUM_BLOCKS_PER_SEQ) 40 | block_tables = torch.randint(0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq)) 41 | 42 | out = flag_attn.paged_attention( 43 | q, 44 | k_cache, 45 | v_cache, 46 | context_lens, 47 | block_tables, 48 |
attn_scale, 49 | max_context_len, 50 | ) 51 | 52 | ref_out = flag_attn.testing.paged_attention( 53 | q, 54 | k_cache, 55 | v_cache, 56 | block_tables, 57 | context_lens, 58 | attn_scale, 59 | ) 60 | print(torch.abs(out - ref_out).max()) 61 | assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-5) 62 | 63 | 64 | def main(): 65 | test_paged_attention( 66 | num_seqs=32, 67 | num_heads=(64, 64), 68 | head_size=64, 69 | block_size=16, 70 | dtype=torch.float16, 71 | seed=1, 72 | device="cuda:0", 73 | ) 74 | 75 | 76 | if __name__ == "__main__": 77 | main() 78 | -------------------------------------------------------------------------------- /src/flag_attn/testing/flash.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | 5 | def attention(q, 6 | k, 7 | v, 8 | causal, 9 | dropout_p=0.0, 10 | dropout_mask=None, 11 | sm_scale=None, 12 | return_log_normalizer=False, 13 | return_total_attention=False, 14 | upcast=False): 15 | input_dtype = q.dtype 16 | if upcast: 17 | q, k, v = q.float(), k.float(), v.float() 18 | # (B, H, T, D) 19 | D = q.shape[-1] 20 | if sm_scale is None: 21 | sm_scale = 1. / math.sqrt(D) 22 | 23 | num_heads_q = q.shape[1] 24 | num_heads_k = k.shape[1] 25 | assert num_heads_q % num_heads_k == 0 26 | num_groups = num_heads_q // num_heads_k 27 | 28 | if num_groups > 1: 29 | k = torch.repeat_interleave(k, repeats=num_groups, dim=1) 30 | v = torch.repeat_interleave(v, repeats=num_groups, dim=1) 31 | kv_seq_len = k.shape[-2] 32 | q_seq_len = q.shape[-2] 33 | p_seq = kv_seq_len - q_seq_len 34 | device = q.device 35 | 36 | ms = torch.arange(q_seq_len, device=device).unsqueeze(-1) 37 | ns = torch.arange(kv_seq_len, device=device) 38 | 39 | S = torch.matmul(q, k.transpose(2, 3)) * sm_scale 40 | if causal: 41 | S = torch.where(ms + p_seq >= ns, S, float("-inf")) 42 | 43 | S = S.to(torch.float32) 44 | if return_log_normalizer: 45 | log_normalizer = torch.logsumexp(S, dim=-1) 46 | 47 | # upcast attention to fp32 48 | P = torch.softmax(S, dim=-1, dtype=torch.float32) 49 | if causal: 50 | P = torch.where(ms + p_seq >= ns, P, 0.0) 51 | 52 | if return_total_attention: 53 | tot_attn = torch.sum(P, dim=-2) 54 | 55 | # Applies dropout 56 | dropout_scaling = 1.0 / (1 - dropout_p) 57 | if dropout_mask is not None: 58 | P = P.masked_fill(~dropout_mask, 0.0) 59 | 60 | attn_output = torch.matmul(P.to(v.dtype), v) * dropout_scaling 61 | attn_output = attn_output.to(input_dtype) 62 | 63 | has_extra_return = return_log_normalizer or return_total_attention 64 | if has_extra_return: 65 | outs = (attn_output, 66 | log_normalizer if return_log_normalizer else None, 67 | tot_attn if return_total_attention else None) 68 | return outs 69 | else: 70 | return attn_output 71 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # results of benchmark 2 | benchmark/results*/ 3 | playground/ 4 | 5 | # version, since setuptools-scm is used, this file is automatic generated when building the package 6 | src/flag_attn/_version.py 7 | 8 | # Editors 9 | .vscode/ 10 | .idea/ 11 | 12 | # Vagrant 13 | .vagrant/ 14 | 15 | # Mac/OSX 16 | .DS_Store 17 | 18 | # Windows 19 | Thumbs.db 20 | 21 | # Source for the following rules: https://raw.githubusercontent.com/github/gitignore/master/Python.gitignore 22 | # Byte-compiled / optimized / DLL files 23 | __pycache__/ 24 | *.py[cod] 25 | *$py.class 26 | 27 | # C extensions 28 | *.so 29 | 
30 | # Distribution / packaging 31 | .Python 32 | build/ 33 | develop-eggs/ 34 | dist/ 35 | downloads/ 36 | eggs/ 37 | .eggs/ 38 | lib/ 39 | lib64/ 40 | parts/ 41 | sdist/ 42 | var/ 43 | wheels/ 44 | *.egg-info/ 45 | .installed.cfg 46 | *.egg 47 | MANIFEST 48 | 49 | # PyInstaller 50 | # Usually these files are written by a python script from a template 51 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 52 | *.manifest 53 | *.spec 54 | 55 | # Installer logs 56 | pip-log.txt 57 | pip-delete-this-directory.txt 58 | 59 | # Unit test / coverage reports 60 | htmlcov/ 61 | .tox/ 62 | .nox/ 63 | .coverage 64 | .coverage.* 65 | .cache 66 | nosetests.xml 67 | coverage.xml 68 | *.cover 69 | .hypothesis/ 70 | .pytest_cache/ 71 | 72 | # Translations 73 | *.mo 74 | *.pot 75 | 76 | # Django stuff: 77 | *.log 78 | local_settings.py 79 | db.sqlite3 80 | 81 | # Flask stuff: 82 | instance/ 83 | .webassets-cache 84 | 85 | # Scrapy stuff: 86 | .scrapy 87 | 88 | # Sphinx documentation 89 | docs/_build/ 90 | 91 | # PyBuilder 92 | target/ 93 | 94 | # Jupyter Notebook 95 | .ipynb_checkpoints 96 | 97 | # IPython 98 | profile_default/ 99 | ipython_config.py 100 | 101 | # pyenv 102 | .python-version 103 | 104 | # celery beat schedule file 105 | celerybeat-schedule 106 | 107 | # SageMath parsed files 108 | *.sage.py 109 | 110 | # Environments 111 | .env 112 | .venv 113 | env/ 114 | venv/ 115 | ENV/ 116 | env.bak/ 117 | venv.bak/ 118 | 119 | # Spyder project settings 120 | .spyderproject 121 | .spyproject 122 | 123 | # Rope project settings 124 | .ropeproject 125 | 126 | # mkdocs documentation 127 | /site 128 | 129 | # mypy 130 | .mypy_cache/ 131 | .dmypy.json 132 | dmypy.json 133 | 134 | -------------------------------------------------------------------------------- /benchmark/flash_decoding_benchmark.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import pathlib 4 | import torch 5 | import triton 6 | 7 | import flag_attn 8 | 9 | 10 | try: 11 | from flash_attn import flash_attn_func 12 | FLASH_VER = 2 13 | except BaseException: 14 | try: 15 | from flash_attn.flash_attn_interface import flash_attn_func 16 | FLASH_VER = 1 17 | except BaseException: 18 | FLASH_VER = None 19 | HAS_FLASH = FLASH_VER is not None 20 | 21 | 22 | configs = [triton.testing.Benchmark( 23 | x_names=['N_CTX'], 24 | x_vals=[2**i for i in range(9, 20)], 25 | line_arg='provider', 26 | line_vals=['flag_attn', 'torch', ] + (['flash'] if HAS_FLASH else []), 27 | line_names=['flag_attn', 'torch', ] + ([f'flash-{FLASH_VER}'] if HAS_FLASH else []), 28 | styles=[('red', '-'), ('green', '-'), ('blue', '-'), ('cyan', '-')], 29 | ylabel='tflop/s', 30 | plot_name=f'attention_d-{D_HEAD}_dtype-{dtype} (ms)', 31 | args={'D_HEAD': D_HEAD, 'dtype': dtype} 32 | ) for D_HEAD in [64, 128] 33 | for dtype in [torch.float16]] 34 | 35 | @triton.testing.perf_report(configs) 36 | def bench_flash_attention(N_CTX, D_HEAD, provider, dtype=torch.float16): 37 | BATCH = 2 38 | H = 2048 // D_HEAD 39 | causal = False 40 | if provider == "flag_attn": 41 | q = torch.randn((BATCH, H, 1, D_HEAD), dtype=dtype, device="cuda") 42 | k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda") 43 | v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda") 44 | fn = lambda: flag_attn.flash_attention(q, k, v, causal=causal) 45 | ms = triton.testing.do_bench(fn) 46 | if provider == "torch": 47 | q = torch.randn((BATCH, H, 1, D_HEAD), dtype=dtype, 
device="cuda") 48 | k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda") 49 | v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda") 50 | try: 51 | fn = lambda: flag_attn.testing.flash_attention(q, k, v, causal=causal, upcast=False) 52 | ms = triton.testing.do_bench(fn) 53 | except torch.cuda.OutOfMemoryError as e: 54 | logging.info(f"torch OOM for batch_size: {BATCH}, num_heads: {H}, seqlen: {N_CTX}, headdim: {D_HEAD}") 55 | ms = float("inf") 56 | if provider == "flash": 57 | q = torch.randn((BATCH, 1, H, D_HEAD), dtype=dtype, device="cuda") 58 | k = torch.randn((BATCH, N_CTX, H, D_HEAD), dtype=dtype, device="cuda") 59 | v = torch.randn((BATCH, N_CTX, H, D_HEAD), dtype=dtype, device="cuda") 60 | fn = lambda: flash_attn_func(q, k, v, causal=causal) 61 | ms = triton.testing.do_bench(fn) 62 | 63 | return ms 64 | # # total TFLOPS: following Flash Attention v2, only gemms are counted. 65 | # macs = 2. * BATCH * H * N_CTX * D_HEAD # Q@K, P@V 66 | # total_flops = 2 * macs 67 | # return total_flops / ms * 1e-9 68 | 69 | # only works on post-Ampere GPUs right now 70 | today = datetime.date.today().strftime(format("%Y%m%d")) 71 | output_dir = pathlib.Path(f"results_flash_attention_with_split_kv_{today}") 72 | output_dir.mkdir(exist_ok=True) 73 | bench_flash_attention.run(save_path=output_dir, print_data=True) 74 | -------------------------------------------------------------------------------- /tests/flag_attn/test_paged_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | 4 | import flag_attn 5 | 6 | NUM_BLOCKS = 1000 7 | 8 | 9 | def base_paged_attention( 10 | num_seqs, 11 | num_query_heads, 12 | query_group_size, 13 | head_size, 14 | block_size, 15 | max_seq_len, 16 | num_splits=0, 17 | dtype=torch.float16, 18 | device="cuda", 19 | ): 20 | torch.set_default_dtype(dtype) 21 | torch.set_default_device(device=device) 22 | 23 | num_kv_heads = num_query_heads // query_group_size 24 | 25 | context_lens = torch.randint(1, max_seq_len, [num_seqs], dtype=torch.int32) 26 | context_lens[0] = max_seq_len 27 | max_context_len = context_lens.max().item() 28 | max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size 29 | 30 | attn_scale = head_size**-0.5 31 | q = torch.empty(num_seqs, num_query_heads, head_size) 32 | q.uniform_(-attn_scale, attn_scale) 33 | 34 | k_cache = torch.empty(NUM_BLOCKS, num_kv_heads, block_size, head_size) 35 | k_cache.uniform_(-attn_scale, attn_scale) 36 | v_cache = torch.empty_like(k_cache) 37 | v_cache.uniform_(-attn_scale, attn_scale) 38 | 39 | # (NUM_SEQS, MAX_NUM_BLOCKS_PER_SEQ) 40 | block_tables = torch.randint(0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq)) 41 | 42 | out = flag_attn.paged_attention( 43 | q, 44 | k_cache, 45 | v_cache, 46 | context_lens, 47 | block_tables, 48 | attn_scale, 49 | max_context_len, 50 | num_splits, 51 | ) 52 | 53 | ref_out = flag_attn.testing.paged_attention( 54 | q, 55 | k_cache, 56 | v_cache, 57 | block_tables, 58 | context_lens, 59 | attn_scale, 60 | ) 61 | print(torch.abs(out - ref_out).max()) 62 | assert torch.allclose(out, ref_out, atol=2e-3, rtol=1e-5) 63 | 64 | 65 | @pytest.mark.parametrize("num_seqs", [1, 32]) 66 | @pytest.mark.parametrize("num_query_heads", [64]) 67 | @pytest.mark.parametrize("query_group_size", [1, 8]) 68 | @pytest.mark.parametrize("head_size", [64, 128]) 69 | @pytest.mark.parametrize("block_size", [16, 128, 256]) 70 | @pytest.mark.parametrize("max_seq_len", [512, 4096]) 71 | def 
test_paged_attention_default( 72 | num_seqs, 73 | num_query_heads, 74 | query_group_size, 75 | head_size, 76 | block_size, 77 | max_seq_len, 78 | dtype=torch.float16, 79 | device="cuda", 80 | ): 81 | base_paged_attention( 82 | num_seqs, 83 | num_query_heads, 84 | query_group_size, 85 | head_size, 86 | block_size, 87 | max_seq_len, 88 | ) 89 | 90 | 91 | @pytest.mark.parametrize("num_seqs", [1, 16]) 92 | @pytest.mark.parametrize("num_query_heads", [64]) 93 | @pytest.mark.parametrize("query_group_size", [1, 8]) 94 | @pytest.mark.parametrize("head_size", [32, 64]) 95 | @pytest.mark.parametrize("block_size", [16]) 96 | @pytest.mark.parametrize("max_seq_len", [2048]) 97 | @pytest.mark.parametrize("num_splits", [1, 2, 3, 4, 5, 6, 7, 8]) 98 | def test_paged_attention_by_num_splits( 99 | num_seqs, 100 | num_query_heads, 101 | query_group_size, 102 | head_size, 103 | block_size, 104 | max_seq_len, 105 | num_splits, 106 | dtype=torch.float16, 107 | device="cuda", 108 | ): 109 | base_paged_attention( 110 | num_seqs, 111 | num_query_heads, 112 | query_group_size, 113 | head_size, 114 | block_size, 115 | max_seq_len, 116 | num_splits=num_splits, 117 | ) 118 | 119 | @pytest.mark.parametrize("num_seqs, num_query_heads, query_group_size, head_size, block_size, max_seq_len, num_splits", [ 120 | (1, 12, 1, 64, 16, 2, 0), 121 | (16, 64, 8, 32, 16, 2048, 2), 122 | (16, 64, 1, 64, 16, 2048, 6), 123 | ]) 124 | def test_paged_attention_by_case( 125 | num_seqs, 126 | num_query_heads, 127 | query_group_size, 128 | head_size, 129 | block_size, 130 | max_seq_len, 131 | num_splits, 132 | dtype=torch.float16, 133 | device="cuda", 134 | ): 135 | base_paged_attention( 136 | num_seqs, 137 | num_query_heads, 138 | query_group_size, 139 | head_size, 140 | block_size, 141 | max_seq_len, 142 | num_splits, 143 | ) 144 | -------------------------------------------------------------------------------- /src/flag_attn/testing/piecewise.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import pytest 4 | 5 | def attention(q1, k1, q2, k2, v, dist_threshold, causal, sm_scale=None, upcast=False): 6 | input_dtype = q1.dtype 7 | if upcast: 8 | q1, k1, q2, k2, v = q1.float(), k1.float(), q2.float(), k2.float(), v.float() 9 | # (B, H, T, D) 10 | D = q1.shape[-1] 11 | if sm_scale is None: 12 | sm_scale = 1. 
/ math.sqrt(D) 13 | kv_seq_len = k1.shape[-2] 14 | q_seq_len = q1.shape[-2] 15 | p_seq = kv_seq_len - q_seq_len 16 | device = q1.device 17 | 18 | ms = torch.arange(q_seq_len, device=device).unsqueeze(-1) 19 | ns = torch.arange(kv_seq_len, device=device) 20 | 21 | S1 = torch.matmul(q1, k1.transpose(2, 3)) 22 | S2 = torch.matmul(q2, k2.transpose(2, 3)) 23 | long_distance = ((ms + p_seq - ns) >= dist_threshold) 24 | S = torch.where(long_distance, S2, S1) * sm_scale 25 | 26 | if causal: 27 | S = torch.where(ms + p_seq >= ns, S, torch.finfo(S.dtype).min) 28 | 29 | # upcast attention to fp32 30 | P = torch.softmax(S, dim=-1, dtype=torch.float32).to(v.dtype) 31 | if causal: 32 | P = torch.where(ms + p_seq >= ns, P, 0.0) 33 | attn_output = torch.matmul(P, v) 34 | return attn_output.to(input_dtype) 35 | 36 | def attention_grad(q1, k1, q2, k2, v, w, causal, sm_scale, o, do, upcast=False): 37 | input_dtype = q1.dtype 38 | 39 | if upcast: 40 | q1, k1, q2, k2, v, o, do = [item.float() for item in [q1, k1, q2, k2, v, o, do]] 41 | kv_seq_len = k1.shape[-2] 42 | q_seq_len = q1.shape[-2] 43 | p_seq = kv_seq_len - q_seq_len 44 | device = q1.device 45 | 46 | ms = torch.arange(q_seq_len, device=device).unsqueeze(-1) 47 | ns = torch.arange(kv_seq_len, device=device) 48 | 49 | S1 = torch.matmul(q1, k1.transpose(2, 3)) 50 | S2 = torch.matmul(q2, k2.transpose(2, 3)) 51 | long_distance = ((ms + p_seq - ns) >= w) 52 | S = torch.where(long_distance, S2, S1) * sm_scale 53 | 54 | if causal: 55 | S = torch.where((ms + p_seq) >= ns, S, torch.finfo(S.dtype).min) 56 | 57 | # upcast attention to fp32 58 | P = torch.softmax(S, dim=-1, dtype=torch.float32).to(v.dtype) 59 | 60 | # dP & dv 61 | dv = torch.matmul(P.transpose(2, 3), do) 62 | dP = torch.matmul(do, v.transpose(2, 3)) 63 | 64 | # dS 65 | delta = (do * o).sum(-1, keepdim=True) # (B,H,T) 66 | dS = P * (dP - delta) * sm_scale 67 | dS2 = torch.where(long_distance, dS, 0.0) 68 | dS1 = torch.where(long_distance, 0.0, dS) 69 | 70 | # dq & dk 71 | dq1 = torch.matmul(dS1, k1) 72 | dk1 = torch.matmul(dS1.transpose(2, 3), q1) 73 | 74 | dq2 = torch.matmul(dS2, k2) 75 | dk2 = torch.matmul(dS2.transpose(2, 3), q2) 76 | 77 | dq1, dk1, dq2, dk2, dv = [item.to(input_dtype) for item in [dq1, dk1, dq2, dk2, dv]] 78 | return dq1, dk1, dq2, dk2, dv 79 | 80 | @pytest.mark.parametrize('B, H, T, D, P_SEQ', [(2, 3, 1024, 32, 100), (2, 3, 1024, 32, 0)]) 81 | @pytest.mark.parametrize('causal', [True, False]) 82 | @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) 83 | def test_op(B, H, T, D, P_SEQ, causal, dtype): 84 | q1 = torch.empty((B, H, T, D), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() 85 | q2 = torch.empty((B, H, T, D), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() 86 | k1 = torch.empty((B, H, T + P_SEQ, D), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() 87 | k2 = torch.empty((B, H, T + P_SEQ, D), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() 88 | v = torch.empty((B, H, T + P_SEQ, D), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() 89 | sm_scale = 0.5 90 | w = 780 91 | 92 | o = attention(q1, k1, q2, k2, v, w, causal, sm_scale) 93 | do = torch.empty((B, H, T, D), dtype=dtype, device="cuda").normal_(mean=0., std=0.5) 94 | o.backward(do) 95 | dq1, dk1, dq2, dk2, dv = attention_grad(q1, k1, q2, k2, v, w, causal, sm_scale, o, do) 96 | 97 | torch.testing.assert_close(dv, v.grad, atol=1e-2, rtol=0.0) 98 | torch.testing.assert_close(dq1, q1.grad, 
atol=1e-2, rtol=0.0) 99 | torch.testing.assert_close(dk1, k1.grad, atol=1e-2, rtol=0.0) 100 | torch.testing.assert_close(dq2, q2.grad, atol=1e-2, rtol=0.0) 101 | torch.testing.assert_close(dk2, k2.grad, atol=1e-2, rtol=0.0) 102 | -------------------------------------------------------------------------------- /benchmark/flash_benchmark.py: -------------------------------------------------------------------------------- 1 | import math 2 | import datetime 3 | import logging 4 | import pathlib 5 | import torch 6 | import triton 7 | 8 | import flag_attn 9 | 10 | 11 | try: 12 | from flash_attn import flash_attn_func 13 | FLASH_VER = 2 14 | except BaseException: 15 | try: 16 | from flash_attn.flash_attn_interface import flash_attn_func 17 | FLASH_VER = 1 18 | except BaseException: 19 | FLASH_VER = None 20 | HAS_FLASH = FLASH_VER is not None 21 | 22 | 23 | configs = [triton.testing.Benchmark( 24 | x_names=['N_CTX'], 25 | x_vals=[2**i for i in range(9, 16)], 26 | line_arg='provider', 27 | line_vals=['flag_attn', 'torch', ] + (['flash'] if HAS_FLASH else []), 28 | line_names=['flag_attn', 'torch', ] + ([f'flash-{FLASH_VER}'] if HAS_FLASH else []), 29 | styles=[('red', '-'), ('green', '-'), ('blue', '-'), ('cyan', '-')], 30 | ylabel='tflop/s', 31 | plot_name=f'attention_d-{D_HEAD}_mode-{mode}_causal-{causal}_dtype-{dtype}', 32 | args={'D_HEAD': D_HEAD, 'dtype': dtype, 'mode': mode, 'causal': causal} 33 | ) for mode in ['fwd', 'bwd'] 34 | for causal in [False, True] 35 | for D_HEAD in [64, 128] 36 | for dtype in [torch.float16, torch.bfloat16]] 37 | 38 | @triton.testing.perf_report(configs) 39 | def bench_flash_attention(N_CTX, D_HEAD, causal, mode, provider, dtype=torch.float16, device="cuda"): 40 | assert mode in ['fwd', 'bwd'] 41 | w = N_CTX // 2 # dist thresold 42 | warmup = 25 43 | rep = 100 44 | 45 | is_bwd = mode == "bwd" 46 | 47 | BATCH = 32768 // N_CTX 48 | H = 2048 // D_HEAD 49 | if provider == "flag_attn": 50 | q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=is_bwd) 51 | k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=is_bwd) 52 | v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=is_bwd) 53 | fn = lambda: flag_attn.flash_attention(q, k, v, causal=causal) 54 | if mode == 'bwd': 55 | o = fn() 56 | do = torch.randn_like(o) 57 | fn = lambda: o.backward(do, retain_graph=True) 58 | ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) 59 | if provider == "torch": 60 | q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=is_bwd) 61 | k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=is_bwd) 62 | v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=is_bwd) 63 | 64 | try: 65 | fn = lambda: flag_attn.testing.flash_attention(q, k, v, causal=causal, upcast=False) 66 | if mode == 'bwd': 67 | o = fn() 68 | do = torch.randn_like(o) 69 | fn = lambda: o.backward(do, retain_graph=True) 70 | ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) 71 | except torch.cuda.OutOfMemoryError as e: 72 | logging.info(f"torch OOM for batch_size: {BATCH}, num_heads: {H}, seqlen: {N_CTX}, headdim: {D_HEAD}") 73 | ms = float("inf") 74 | if provider == "flash": 75 | if FLASH_VER == 1: 76 | qkv = torch.randn((BATCH, N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=is_bwd) 77 | lengths = torch.full((BATCH,), fill_value=N_CTX, device=device) 78 | cu_seqlens = torch.zeros((BATCH + 1,), device=device, 
dtype=torch.int32) 79 | cu_seqlens[1:] = lengths.cumsum(0) 80 | qkv = qkv.reshape(BATCH * N_CTX, 3, H, D_HEAD) 81 | fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=causal) 82 | elif FLASH_VER == 2: 83 | q = torch.randn((BATCH, N_CTX, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=is_bwd) 84 | k = torch.randn((BATCH, N_CTX, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=is_bwd) 85 | v = torch.randn((BATCH, N_CTX, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=is_bwd) 86 | fn = lambda: flash_attn_func(q, k, v, causal=causal) 87 | else: 88 | raise ValueError(f'unknown {FLASH_VER = }') 89 | if mode == 'bwd': 90 | o = fn() 91 | do = torch.randn_like(o) 92 | fn = lambda: o.backward(do, retain_graph=True) 93 | ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) 94 | 95 | # total TFLOPS: following Flash Attention v2, only gemms are counted. 96 | macs = 2. * BATCH * H * N_CTX * N_CTX * D_HEAD # Q@K, P@V 97 | if mode == 'bwd': 98 | macs *= 2.5 # Q@K, dO@V, dO@P, dS@Q dS@K 99 | total_flops = 2 * macs 100 | 101 | if causal: 102 | total_flops *= 0.5 103 | return total_flops / ms * 1e-9 104 | 105 | # only works on post-Ampere GPUs right now 106 | today = datetime.date.today().strftime(format("%Y%m%d")) 107 | output_dir = pathlib.Path(f"results_flash_attention_{today}") 108 | output_dir.mkdir(exist_ok=True) 109 | bench_flash_attention.run(save_path=output_dir, print_data=True) 110 | -------------------------------------------------------------------------------- /src/flag_attn/total.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import triton 4 | import triton.language as tl 5 | 6 | def get_fwd_config(B, H, M, N, D, causal): 7 | return (32, 32, 1, 4) 8 | 9 | def total_attention(q, k, l, causal=False, sm_scale=None): 10 | Dq, Dk = q.shape[-1], k.shape[-1] 11 | assert Dq == Dk 12 | assert Dk in {16, 32, 64, 128} 13 | # assert L is contiguous 14 | 15 | B, H, M, D = q.shape 16 | N = k.shape[2] 17 | Hk = k.shape[1] 18 | assert H % Hk == 0, "number of heads in q must be a multiple of that in k" 19 | num_groups = H // Hk 20 | 21 | P_SEQ = N - M 22 | 23 | if sm_scale is None: 24 | sm_scale = 1. 
/ math.sqrt(D) 25 | 26 | # to work around https://github.com/openai/triton/issues/2441 27 | device = torch.cuda.device_of(q) 28 | with torch.cuda.device(device): 29 | config = get_fwd_config(B, H, M, N, D, causal) 30 | BLOCK_M, BLOCK_N, num_stages, num_warps = config 31 | 32 | divisible_m = M % BLOCK_M == 0 33 | divisible_n = N % BLOCK_N == 0 34 | # consider using 3d grid to avoid div & rem 35 | grid = (triton.cdiv(N, BLOCK_N), H, B) 36 | tot_attn = torch.empty((B, H, N), dtype=torch.float32, device=q.device) 37 | _total_attention_kernel[grid]( 38 | q, k, l, tot_attn, sm_scale, 39 | q.stride(0), q.stride(1), q.stride(2), q.stride(3), 40 | k.stride(0), k.stride(1), k.stride(2), k.stride(3), 41 | B, H, M, N, P_SEQ, num_groups, 42 | BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N, 43 | CAUSAL=causal, 44 | DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n, 45 | num_stages=num_stages, num_warps=num_warps, 46 | ) 47 | return tot_attn 48 | 49 | 50 | @triton.jit 51 | def _total_attention_kernel( 52 | Q, K, L, TA, sm_scale, 53 | stride_qz, stride_qh, stride_qm, stride_qk, 54 | stride_kz, stride_kh, stride_kn, stride_kk, 55 | Z, H, M, N, P_SEQ, num_groups, 56 | BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, 57 | CAUSAL: tl.constexpr, 58 | DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr, 59 | ): 60 | # -- grid id -- 61 | start_n = tl.program_id(0) 62 | off_h = tl.program_id(1) 63 | off_z = tl.program_id(2) 64 | log2e: tl.constexpr = 1.4426950408889634 65 | qk_scale = sm_scale * log2e 66 | 67 | # offset pointers for (batch, head) 68 | off_hk = off_h // num_groups 69 | Q += off_z * stride_qz + off_h * stride_qh 70 | K += off_z * stride_kz + off_hk * stride_kh 71 | L += (off_z * H + off_h) * M 72 | TA += (off_z * H + off_h) * N # (b, h, n) 73 | 74 | if CAUSAL: 75 | lo = tl.maximum(start_n * BLOCK_N - P_SEQ, 0) 76 | lo = (lo // BLOCK_M) * BLOCK_M 77 | else: 78 | lo = 0 79 | 80 | offs_m_init = lo + tl.arange(0, BLOCK_M) 81 | offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) 82 | offs_m_base = tl.arange(0, BLOCK_M) 83 | offs_k = tl.arange(0, BLOCK_DMODEL) 84 | 85 | # initialize pointers to value-like data 86 | q_ptrs = Q + (offs_m_init[:, None] * stride_qm + offs_k[None, :] * stride_qk) # (BLOCK_M, BLOCK_DMODEL) 87 | k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) # (BLOCK_N, BLOCK_DMODEL) 88 | ta_ptrs = TA + offs_n # (BLOCK_N, ) 89 | 90 | # k and v stay in SRAM throughout 91 | if DIVISIBLE_N: 92 | k = tl.load(k_ptrs) 93 | else: 94 | mask_n = offs_n < N 95 | k = tl.load(k_ptrs, mask=mask_n[:, None]) 96 | 97 | # initialize total attention 98 | tot_attn = tl.zeros([BLOCK_N], dtype=tl.float32) 99 | 100 | # loop over a col 101 | for start_m in range(lo, M, BLOCK_M): 102 | start_m = tl.multiple_of(start_m, BLOCK_M) 103 | offs_m = start_m + offs_m_base 104 | causal_mask = (P_SEQ + offs_m[:, None]) >= (offs_n[None, :]) # (BLOCK_M, BLOCK_N) 105 | 106 | if DIVISIBLE_M: 107 | q = tl.load(q_ptrs) 108 | else: 109 | mask_m = offs_m < M 110 | valid_mask = mask_m[:, None] # & mask_n 111 | q = tl.load(q_ptrs, mask=mask_m[:, None]) 112 | # recompute p = softmax(qk * sm_scale, dim=-1) 113 | s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) 114 | s += tl.dot(q, tl.trans(k)) 115 | 116 | # NOTE: since softmax in backward is pointwise, the normalizer has been saved in fwd) 117 | # So masking on s is not needed. 
118 | # s = tl.where(valid_mask, s , float("-inf")) 119 | # if CAUSAL: 120 | # s = tl.where(causal_mask, s, float("-inf")) 121 | 122 | # -- recompute p --- 123 | if DIVISIBLE_M: 124 | l = tl.load(L + offs_m) 125 | else: 126 | l = tl.load(L + offs_m, mask=mask_m) 127 | p = tl.math.exp2(s * qk_scale - l[:, None] * log2e) # (BLOCK_M, BLOCK_N) 128 | 129 | if not DIVISIBLE_M: 130 | p = tl.where(valid_mask, p, 0.0) 131 | if CAUSAL: 132 | p = tl.where(causal_mask, p, 0.0) 133 | 134 | tot_attn += tl.sum(p, 0) 135 | # increment pointers 136 | q_ptrs += BLOCK_M * stride_qm 137 | 138 | 139 | if DIVISIBLE_N: 140 | tl.store(ta_ptrs, tot_attn) # (BLOCK_N,) 141 | else: 142 | tl.store(ta_ptrs, tot_attn, mask=mask_n) # (BLOCK_N, ) 143 | -------------------------------------------------------------------------------- /benchmark/piecewise_benchmark.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pathlib 3 | import datetime 4 | import torch 5 | import triton 6 | 7 | import flag_attn 8 | 9 | try: 10 | from flash_attn import flash_attn_func 11 | FLASH_VER = 2 12 | except BaseException: 13 | try: 14 | from flash_attn.flash_attn_interface import flash_attn_func 15 | FLASH_VER = 1 16 | except BaseException: 17 | FLASH_VER = None 18 | HAS_FLASH = FLASH_VER is not None 19 | 20 | 21 | configs = [triton.testing.Benchmark( 22 | x_names=['N_CTX'], 23 | x_vals=[2**i for i in range(9, 16)], 24 | line_arg='provider', 25 | line_vals=['piecewise', 'torch'] + (['flash'] if HAS_FLASH else []), 26 | line_names=['piecewise', 'torch'] + ([f'flash-{FLASH_VER}'] if HAS_FLASH else []), 27 | styles=[('red', '-'), ('green', '-'), ('blue', '-')], 28 | ylabel='tflop/s', 29 | plot_name=f'piecewise_attention_d-{D_HEAD}_mode-{mode}_causal-{causal}_dtype-{dtype}', 30 | args={'D_HEAD': D_HEAD, 'dtype': dtype, 'mode': mode, 'causal': causal} 31 | ) for mode in ['fwd', 'bwd'] 32 | for causal in [False, True] 33 | for D_HEAD in [64, 128] 34 | for dtype in [torch.float16, torch.bfloat16]] 35 | 36 | @triton.testing.perf_report(configs) 37 | def bench_flash_attention(N_CTX, D_HEAD, causal, mode, provider, dtype=torch.float16, device="cuda"): 38 | assert mode in ['fwd', 'bwd'] 39 | w = N_CTX // 2 # dist thresold 40 | warmup = 25 41 | rep = 100 42 | 43 | BATCH = 32768 // N_CTX 44 | H = 2048 // D_HEAD 45 | if provider == "piecewise": 46 | q1 = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) 47 | k1 = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) 48 | q2 = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) 49 | k2 = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) 50 | v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) 51 | fn = lambda: flag_attn.piecewise_attention(q1, k1, q2, k2, v, w, causal=causal) 52 | if mode == 'bwd': 53 | o = fn() 54 | do = torch.randn_like(o) 55 | fn = lambda: o.backward(do, retain_graph=True) 56 | ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) 57 | if provider == "torch": 58 | q1 = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) 59 | k1 = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) 60 | q2 = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) 61 | k2 = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) 62 | v 
= torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) 63 | 64 | try: 65 | fn = lambda: flag_attn.testing.piecewise_attention(q1, k1, q2, k2, v, w, causal=causal) 66 | if mode == 'bwd': 67 | o = fn() 68 | do = torch.randn_like(o) 69 | fn = lambda: o.backward(do, retain_graph=True) 70 | ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) 71 | except torch.cuda.OutOfMemoryError as e: 72 | logging.info(f"torch OOM for batch_size: {BATCH}, num_heads: {H}, seqlen: {N_CTX}, headdim: {D_HEAD}") 73 | ms = float("inf") 74 | if provider == "flash": 75 | if FLASH_VER == 1: 76 | qkv = torch.randn((BATCH, N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True) 77 | lengths = torch.full((BATCH,), fill_value=N_CTX, device=device) 78 | cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32) 79 | cu_seqlens[1:] = lengths.cumsum(0) 80 | qkv = qkv.reshape(BATCH * N_CTX, 3, H, D_HEAD) 81 | fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=causal) 82 | elif FLASH_VER == 2: 83 | q = torch.randn((BATCH, N_CTX, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) 84 | k = torch.randn((BATCH, N_CTX, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) 85 | v = torch.randn((BATCH, N_CTX, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) 86 | fn = lambda: flash_attn_func(q, k, v, causal=causal) 87 | else: 88 | raise ValueError(f'unknown {FLASH_VER = }') 89 | if mode == 'bwd': 90 | o = fn() 91 | do = torch.randn_like(o) 92 | fn = lambda: o.backward(do, retain_graph=True) 93 | ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) 94 | 95 | # total TFLOPS: following Flash Attention v2, only gemms are counted. 96 | # NOTE: this is not an entirely fair comparison, since the total amount of flops and the elapsed time differ across providers; 97 | # the tflop/s is used as a metric of the operator's performance, for reference only. 98 | if provider == "flash": 99 | macs = 2. * BATCH * H * N_CTX * N_CTX * D_HEAD # Q@K, P@V 100 | if mode == 'bwd': 101 | macs *= 2.5 # Q@K, dO@V, dO@P, dS@Q, dS@K 102 | else: 103 | macs = 3.
* BATCH * H * N_CTX * N_CTX * D_HEAD # Q1@K1, Q2@K2, P@V 104 | if mode == 'bwd': 105 | macs *= 8.0/3.0 # 1 *(Q1@K1, Q2@K2, dO@V), dO@P, dS1@@Q1, dS1@K1, dS2@@Q2, dS2@K2 106 | total_flops = 2 * macs 107 | 108 | if causal: 109 | total_flops *= 0.5 110 | return total_flops / ms * 1e-9 111 | 112 | # only works on post-Ampere GPUs right now 113 | today = datetime.date.today().strftime(format("%Y%m%d")) 114 | output_dir = pathlib.Path(f"results_piecewise_attention_{today}") 115 | output_dir.mkdir(exist_ok=True) 116 | bench_flash_attention.run(save_path=output_dir, print_data=True) 117 | -------------------------------------------------------------------------------- /benchmark/paged_benchmark.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import triton 3 | import flag_attn 4 | 5 | NUM_BLOCKS = 1000 6 | warmup = 200 7 | rep = 200 8 | 9 | try: 10 | from vllm._C import ops as vllm_ops 11 | 12 | HAS_VLLM = True 13 | 14 | # required vllm 0.3.0 15 | import vllm 16 | 17 | print("vllm.__version__", vllm.__version__) 18 | except BaseException: 19 | HAS_VLLM = False 20 | 21 | 22 | def vllm_paged_attention( 23 | out: torch.Tensor, 24 | query: torch.Tensor, 25 | key_cache: torch.Tensor, 26 | value_cache: torch.Tensor, 27 | num_kv_heads: int, 28 | scale: float, 29 | block_tables: torch.Tensor, 30 | context_lens: torch.Tensor, 31 | block_size: int, 32 | max_context_len: int, 33 | PARTITION_SIZE: int = 512, 34 | version: int = 1, 35 | ): 36 | if version == 1: 37 | vllm_ops.paged_attention_v1( 38 | out, 39 | query, 40 | key_cache, 41 | value_cache, 42 | num_kv_heads, 43 | scale, 44 | block_tables, 45 | context_lens, 46 | block_size, 47 | max_context_len, 48 | None, # alibi_slopes 49 | "auto", # kv_cache_dtype for vllm 0.3.0 50 | ) 51 | elif version == 2: 52 | num_partitions = (max_context_len + PARTITION_SIZE - 1) // PARTITION_SIZE 53 | assert PARTITION_SIZE % block_size == 0 54 | num_seqs, num_heads, head_size = out.shape 55 | tmp_out = torch.empty( 56 | size=(num_seqs, num_heads, num_partitions, head_size), 57 | dtype=out.dtype, 58 | device=out.device, 59 | ) 60 | exp_sums = torch.empty( 61 | size=(num_seqs, num_heads, num_partitions), 62 | dtype=torch.float32, 63 | device=out.device, 64 | ) 65 | max_logits = torch.empty_like(exp_sums) 66 | vllm_ops.paged_attention_v2( 67 | out, 68 | exp_sums, 69 | max_logits, 70 | tmp_out, 71 | query, 72 | key_cache, 73 | value_cache, 74 | num_kv_heads, 75 | scale, 76 | block_tables, 77 | context_lens, 78 | block_size, 79 | max_context_len, 80 | None, 81 | "auto", # vllm 0.3.0 82 | ) 83 | else: 84 | raise AssertionError(f"Unknown version: {version}") 85 | 86 | 87 | @triton.testing.perf_report( 88 | [ 89 | triton.testing.Benchmark( 90 | x_names=["context_len"], 91 | x_vals=[2**i for i in range(9, 15)], 92 | line_arg="provider", 93 | line_vals=["triton"] + (["vllm"] if HAS_VLLM else []), 94 | line_names=["triton"] + ([f"vllm-{vllm.__version__}"] if HAS_VLLM else []), 95 | styles=[("red", "-"), ("blue", "-")], 96 | ylabel="tflop/s", 97 | plot_name=f"vllm_paged_attention-B{num_seqs}-G{query_group_size}-D{head_size}-bs{block_size}-v{version}", 98 | args={ 99 | "num_seqs": num_seqs, 100 | "num_query_heads": 64, 101 | "query_group_size": query_group_size, 102 | "head_size": head_size, 103 | "block_size": block_size, 104 | "vllm_version": version, 105 | "dtype": dtype, 106 | }, 107 | ) 108 | for num_seqs in [1, 32, 64] 109 | for query_group_size in [1, 8] 110 | for head_size in [64, 128] 111 | for block_size in [16, 32] 112 
| for version in [1, 2] 113 | for dtype in [torch.float16] 114 | ] 115 | ) 116 | def paged_attention_benchmark_with_vllm( 117 | num_seqs, 118 | num_query_heads, 119 | query_group_size, 120 | head_size, 121 | block_size, 122 | context_len, 123 | vllm_version, 124 | provider, 125 | dtype=torch.float16, 126 | device="cuda", 127 | ): 128 | num_kv_heads = num_query_heads // query_group_size 129 | 130 | context_lens = torch.zeros(num_seqs, dtype=torch.int32, device=device) + context_len 131 | max_num_blocks_per_seq = (context_len + block_size - 1) // block_size 132 | 133 | attn_scale = head_size**-0.5 134 | q = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype, device=device) 135 | q.uniform_(-attn_scale, attn_scale) 136 | out = torch.empty_like(q) 137 | 138 | k_cache = torch.empty( 139 | NUM_BLOCKS, num_kv_heads, block_size, head_size, dtype=dtype, device=device 140 | ) 141 | k_cache.uniform_(-attn_scale, attn_scale) 142 | v_cache = torch.empty_like(k_cache) 143 | v_cache.uniform_(-attn_scale, attn_scale) 144 | 145 | # (NUM_SEQS, MAX_NUM_BLOCKS_PER_SEQ) 146 | block_tables = torch.randint( 147 | 0, 148 | NUM_BLOCKS, 149 | (num_seqs, max_num_blocks_per_seq), 150 | dtype=torch.int32, 151 | device=device, 152 | ) 153 | 154 | if provider == "triton": 155 | fn = lambda: flag_attn.paged_attention( 156 | q, 157 | k_cache, 158 | v_cache, 159 | context_lens, 160 | block_tables, 161 | attn_scale, 162 | context_len, 163 | ) 164 | ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) 165 | 166 | if provider == "vllm": 167 | # Correctness error, does not affect performance results 168 | fn = lambda: vllm_paged_attention( 169 | out, 170 | q, 171 | k_cache, 172 | v_cache, 173 | num_kv_heads, 174 | attn_scale, 175 | block_tables, 176 | context_lens, 177 | block_size, 178 | context_len, 179 | PARTITION_SIZE=512, 180 | version=vllm_version, 181 | ) 182 | ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) 183 | 184 | total_flops = 2.0 * num_seqs * num_query_heads * 2 * context_len * head_size 185 | return total_flops / ms * 1e-9 186 | 187 | 188 | if HAS_VLLM: 189 | paged_attention_benchmark_with_vllm.run(print_data=True) 190 | -------------------------------------------------------------------------------- /tests/flag_attn/test_piecewise_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | import logging 4 | 5 | import flag_attn 6 | 7 | torch.random.manual_seed(10086) 8 | 9 | def max_diff(a, b): 10 | return (a - b).abs().max().item() 11 | 12 | @pytest.mark.parametrize('device_id', list(range(torch.cuda.device_count()))) 13 | @pytest.mark.parametrize('scale', [1.0, 2.0, 3.0, 4.0]) 14 | @pytest.mark.parametrize('B, H, M, N, D', [ 15 | (2, 4, 512, 612, 128), 16 | (2, 4, 1024, 1034, 64), 17 | (2, 4, 2048, 2048, 32), 18 | (2, 4, 4096, 4096, 16), 19 | (2, 4, 4096, 4001, 16), 20 | (1, 2, 8192, 8192, 16), 21 | (1, 2, 8192, 8192, 32), 22 | (1, 2, 8192, 3867, 32), 23 | ]) 24 | @pytest.mark.parametrize('causal', [True, False]) 25 | @pytest.mark.parametrize('stride_order', ['BHTD', 'BTHD']) 26 | @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) 27 | def test_attention_fwd(B, H, M, N, D, causal, stride_order, dtype, scale, device_id): 28 | device = f"cuda:{device_id}" 29 | if stride_order == "BHTD": 30 | q1 = torch.empty((B, H, M, D), dtype=dtype, device=device).normal_(mean=0., std=scale) 31 | q2 = torch.empty((B, H, M, D), dtype=dtype, device=device).normal_(mean=0., std=scale) 32 | k1 = torch.empty((B, H, 
N, D), dtype=dtype, device=device).normal_(mean=0., std=scale) 33 | k2 = torch.empty((B, H, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale) 34 | v = torch.empty((B, H, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale) 35 | else: 36 | q1 = torch.empty((B, M, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2) 37 | q2 = torch.empty((B, M, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2) 38 | k1 = torch.empty((B, N, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2) 39 | k2 = torch.empty((B, N, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2) 40 | v = torch.empty((B, N, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2) 41 | w = (M // 2) if M < N else (M - N // 2) 42 | 43 | o_ref = flag_attn.testing.piecewise_attention(q1, k1, q2, k2, v, w, causal=causal, upcast=True) 44 | o_torch = flag_attn.testing.piecewise_attention(q1, k1, q2, k2, v, w, causal=causal) 45 | o_hyp = flag_attn.piecewise_attention(q1, k1, q2, k2, v, w, causal=causal) 46 | 47 | torch_max_diff = max_diff(o_torch, o_ref) 48 | triton_max_diff = max_diff(o_hyp, o_ref) 49 | logging.info("torch_max_diff: {:.8f}\ttriton_max_diff: {:.8f}".format(torch_max_diff, triton_max_diff)) 50 | assert triton_max_diff <= 2 * torch_max_diff + 1e-5 51 | # assert torch.testing.assert_close(o_hyp, o_ref) 52 | 53 | 54 | 55 | @pytest.mark.parametrize('device_id', list(range(torch.cuda.device_count()))) 56 | @pytest.mark.parametrize('scale', [1.0, 2.0, 3.0, 4.0]) 57 | @pytest.mark.parametrize('B, H, M, N, D', [ 58 | (2, 4, 512, 612, 128), 59 | (2, 4, 1024, 1034, 64), 60 | (2, 4, 2048, 2048, 32), 61 | (2, 4, 4096, 4096, 16), 62 | (2, 4, 4096, 4001, 16), 63 | (1, 2, 8192, 8192, 16), 64 | (1, 2, 8192, 8192, 32), 65 | (1, 2, 8192, 3867, 32), 66 | ]) 67 | @pytest.mark.parametrize('causal', [True, False]) 68 | @pytest.mark.parametrize('stride_order', ['BHTD', 'BTHD']) 69 | @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) 70 | def test_attention_bwd(B, H, M, N, D, causal, stride_order, dtype, scale, device_id): 71 | device = f"cuda:{device_id}" 72 | if stride_order == "BHTD": 73 | q1 = torch.empty((B, H, M, D), dtype=dtype, device=device).normal_(mean=0., std=scale).requires_grad_() 74 | q2 = torch.empty((B, H, M, D), dtype=dtype, device=device).normal_(mean=0., std=scale).requires_grad_() 75 | k1 = torch.empty((B, H, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale).requires_grad_() 76 | k2 = torch.empty((B, H, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale).requires_grad_() 77 | v = torch.empty((B, H, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale).requires_grad_() 78 | do = torch.empty((B, H, M, D), dtype=dtype, device=device).normal_(mean=0., std=scale) 79 | else: 80 | q1 = torch.empty((B, M, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2).requires_grad_() 81 | q2 = torch.empty((B, M, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2).requires_grad_() 82 | k1 = torch.empty((B, N, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2).requires_grad_() 83 | k2 = torch.empty((B, N, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2).requires_grad_() 84 | v = torch.empty((B, N, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2).requires_grad_() 85 | do = torch.empty((B, M, H, D), dtype=dtype, 
device=device).normal_(mean=0., std=scale).transpose(1, 2) 86 | 87 | w = (M // 2) if M < N else (M - N // 2) 88 | 89 | o_ref = flag_attn.testing.piecewise_attention(q1, k1, q2, k2, v, w, causal=causal, upcast=True) 90 | dq1_ref, dk1_ref, dq2_ref, dk2_ref, dv_ref = torch.autograd.grad(o_ref, (q1, k1, q2, k2, v), do) 91 | 92 | o_torch = flag_attn.testing.piecewise_attention(q1, k1, q2, k2, v, w, causal=causal, upcast=False) 93 | dq1_torch, dk1_torch, dq2_torch, dk2_torch, dv_torch = torch.autograd.grad(o_torch, (q1, k1, q2, k2, v), do) 94 | 95 | o_hyp = flag_attn.piecewise_attention(q1, k1, q2, k2, v, w, causal=causal) 96 | dq1_hyp, dk1_hyp, dq2_hyp, dk2_hyp, dv_hyp = torch.autograd.grad(o_hyp, (q1, k1, q2, k2, v), do) 97 | 98 | o_torch_max_diff = max_diff(o_torch, o_ref) 99 | dq1_torch_max_diff = max_diff(dq1_torch, dq1_ref) 100 | dq2_torch_max_diff = max_diff(dq2_torch, dq2_ref) 101 | dk1_torch_max_diff = max_diff(dk1_torch, dk1_ref) 102 | dk2_torch_max_diff = max_diff(dk2_torch, dk2_ref) 103 | dv_torch_max_diff = max_diff(dv_torch, dv_ref) 104 | 105 | o_triton_max_diff = max_diff(o_hyp, o_ref) 106 | dq1_triton_max_diff = max_diff(dq1_hyp, dq1_ref) 107 | dq2_triton_max_diff = max_diff(dq2_hyp, dq2_ref) 108 | dk1_triton_max_diff = max_diff(dk1_hyp, dk1_ref) 109 | dk2_triton_max_diff = max_diff(dk2_hyp, dk2_ref) 110 | dv_triton_max_diff = max_diff(dv_hyp, dv_ref) 111 | 112 | logging.info("o torch_max_diff: {:.8f}\ttriton_max_diff: {:.8f}".format(o_torch_max_diff, o_triton_max_diff)) 113 | logging.info("dq1 torch_max_diff: {:.8f}\ttriton_max_diff: {:.8f}".format(dq1_torch_max_diff, dq1_triton_max_diff)) 114 | logging.info("dq2 torch_max_diff: {:.8f}\ttriton_max_diff: {:.8f}".format(dq2_torch_max_diff, dq2_triton_max_diff)) 115 | logging.info("dk1 torch_max_diff: {:.8f}\ttriton_max_diff: {:.8f}".format(dk1_torch_max_diff, dk1_triton_max_diff)) 116 | logging.info("dk2 torch_max_diff: {:.8f}\ttriton_max_diff: {:.8f}".format(dk2_torch_max_diff, dk2_triton_max_diff)) 117 | logging.info("dv torch_max_diff: {:.8f}\ttriton_max_diff: {:.8f}".format(dv_torch_max_diff, dv_triton_max_diff)) 118 | 119 | assert o_triton_max_diff <= 2 * o_torch_max_diff + 1e-5 120 | assert dq1_triton_max_diff <= 2 * dq1_torch_max_diff + 1e-5 121 | assert dq2_triton_max_diff <= 2 * dq2_torch_max_diff + 1e-5 122 | assert dk1_triton_max_diff <= 2 * dk1_torch_max_diff + 1e-5 123 | assert dk2_triton_max_diff <= 2 * dk2_torch_max_diff + 1e-5 124 | assert dv_triton_max_diff <= 2 * dv_torch_max_diff + 1e-5 125 | -------------------------------------------------------------------------------- /README_cn.md: -------------------------------------------------------------------------------- 1 | # FlagAttention 2 | 3 |
4 | [FlagOpen](https://flagopen.baai.ac.cn/)
5 |
245 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FlagAttention
2 |
3 |
4 | [FlagOpen](https://flagopen.baai.ac.cn/)
5 |
254 |
--------------------------------------------------------------------------------
/src/flag_attn/split_kv.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import triton
4 | import triton.language as tl
5 |
6 | """
7 | This file implements Flash Decoding, i.e. flash attention with split-kv, which exposes another
8 | dimension of parallelism when batch_size * num_heads * blocks_along_seqlen_q cannot saturate
9 | the GPU's SMs.
10 |
11 | For more details, refer to https://princeton-nlp.github.io/flash-decoding/.
12 | """
13 |
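# Usage sketch (illustrative; shapes and values below are assumptions, not taken from this file).
# The `attention` function defined below is exported by the package as
# `flag_attn.flash_attention_split_kv`.
#
#   import torch
#   import flag_attn
#   q = torch.randn(1, 32, 1, 128, dtype=torch.float16, device="cuda")     # (B, H, M, D); M == 1 for decoding
#   k = torch.randn(1, 32, 4096, 128, dtype=torch.float16, device="cuda")  # (B, Hk, N, D)
#   v = torch.randn(1, 32, 4096, 128, dtype=torch.float16, device="cuda")
#   o = flag_attn.flash_attention_split_kv(q, k, v, causal=False)          # (1, 32, 1, 128)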
14 | @triton.jit
15 | def _fwd_split_kv_kernel(
16 | Q, K, V, sm_scale,
17 | L, O,
18 | stride_qz, stride_qh, stride_qm, stride_qk,
19 | stride_kz, stride_kh, stride_kn, stride_kk,
20 | stride_vz, stride_vh, stride_vn, stride_vk,
21 | stride_oz, stride_oh, stride_os, stride_om, stride_ok,
22 | Z, H, M, N, P_SEQ, N_SPLIT_SIZE, S, num_groups,
23 | BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
24 | IS_CAUSAL: tl.constexpr, LARGER_M: tl.constexpr,
25 | DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr,
26 | ):
27 | input_dtype = Q.dtype.element_ty
28 | # -- grid id --
29 | start_m = tl.program_id(0)
30 | n_split_id = tl.program_id(1)
31 | off_zh = tl.program_id(2)
32 | off_h = off_zh % H
33 | off_z = off_zh // H
34 | off_hk = off_h // num_groups
35 |
36 | # scale sm_scale by log_2(e) and use
37 | # 2^x instead of exp in the loop because CSE and LICM
38 | # don't work as expected with `exp` in the loop
39 | log2e: tl.constexpr = 1.4426950408889634
40 | qk_scale = sm_scale * log2e
41 |
42 | # offset pointers for (batch & head)
43 | Q += off_z * stride_qz + off_h * stride_qh
44 | K += off_z * stride_kz + off_hk * stride_kh
45 | V += off_z * stride_vz + off_hk * stride_vh
46 |
47 | # offset pointers for (batch & head, split)
48 | O += off_z * stride_oz + off_h * stride_oh + n_split_id * stride_os # o's shape is (B, H, S, M, D)
49 | L += ((off_z * H + off_h) * S + n_split_id) * M # l's shape is (B, H, S, M)
50 |
51 | offs_m_base = tl.arange(0, BLOCK_M)
52 | offs_m = start_m * BLOCK_M + offs_m_base
53 | offs_n_base = tl.arange(0, BLOCK_N)
54 | offs_k = tl.arange(0, BLOCK_DMODEL)
55 |
56 | # initialize pointers to value-like data
57 | q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk) # (BLOCK_M, BLOCK_DMODEL)
58 | o_ptrs = O + (offs_m[:, None] * stride_om + offs_k[None, :] * stride_ok) # (BLOCK_M, BLOCK_DMODEL)
59 | l_ptrs = L + offs_m
60 |
61 | # initialize pointer to m and l, fp32 for accumulators
62 | m_i = tl.full([BLOCK_M], value=-float("inf"), dtype=tl.float32)
63 | l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
64 | acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
65 |
66 | # load q
67 | if DIVISIBLE_M:
68 | q = tl.load(q_ptrs)
69 | else:
70 | mask_m = offs_m < M
71 | q = tl.load(q_ptrs, mask=mask_m[:, None])
72 |
73 | # Dot-I trick: multiplying q by an identity matrix forces q into registers, which saves shared memory
74 | if BLOCK_DMODEL < 128:
75 | I = tl.where(offs_k[:, None] == offs_k,
76 | tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 1.0, dtype=input_dtype),
77 | tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 0.0, dtype=input_dtype))
78 | q = tl.dot(q, I).to(input_dtype)
79 | # else:
80 | # I = tl.where(offs_m_base[:, None] == offs_m_base,
81 | # tl.full((BLOCK_M, BLOCK_M), 1.0, dtype=input_dtype),
82 | # tl.full((BLOCK_M, BLOCK_M), 0.0, dtype=input_dtype))
83 | # q = tl.dot(I, q).to(input_dtype)
84 |
85 | # NOTE: Loop-Bound-For-N
86 | # The indices in the m-dimension that this block may access are in `[start_m * BLOCK_M, (start_m + 1) * BLOCK_M)`.
87 | # According to the rule of causal masking, the max index in the n-dimension that this block may access
88 | # is `P_SEQ + (start_m + 1) * BLOCK_M`.
89 | # However, the upper bound of the index in the n-dimension should never exceed the sequence length of k/v (`P_SEQ + N_CTX`),
90 | # and `P_SEQ + (start_m + 1) * BLOCK_M` may be larger than `N`.
91 | # In that case, there would be illegal memory accesses when loading the k & v tiles
92 | # if mask_n were not applied to the loads (which only happens when `DIVISIBLE_N` is true).
93 | # See also https://github.com/FlagOpen/FlagAttention/pull/8
94 | N_LEFT = n_split_id * N_SPLIT_SIZE
95 | N_RIGHT = tl.minimum(N_LEFT + N_SPLIT_SIZE, N)
96 | if IS_CAUSAL:
97 | hi = tl.minimum(N_RIGHT, P_SEQ + (start_m + 1) * BLOCK_M)
98 | if LARGER_M:
99 | hi = tl.maximum(N_LEFT, hi)
100 | else:
101 | hi = N_RIGHT
102 |
103 | # loop over k, v and update accumulators
104 | offs_n_init = N_LEFT + offs_n_base
105 | k_ptrs = K + (offs_k[:, None] * stride_kk + offs_n_init[None, :] * stride_kn) # (BLOCK_DMODEL, BLOCK_N)
106 | v_ptrs = V + (offs_n_init[:, None] * stride_vn + offs_k[None, :] * stride_vk) # (BLOCK_N, BLOCK_DMODEL)
107 | for start_n in range(N_LEFT, hi, BLOCK_N):
108 | start_n = tl.multiple_of(start_n, BLOCK_N)
109 | offs_n = start_n + offs_n_base
110 |
111 | # -- load k, v --
112 | if DIVISIBLE_N:
113 | k = tl.load(k_ptrs, cache_modifier=".cg")
114 | v = tl.load(v_ptrs, cache_modifier=".cg")
115 | else:
116 | mask_n = offs_n < N
117 | k = tl.load(k_ptrs, mask=mask_n[None, :], cache_modifier=".cg")
118 | v = tl.load(v_ptrs, mask=mask_n[:, None], cache_modifier=".cg")
119 |
120 | # -- compute qk ---
121 | s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
122 | s += tl.dot(q, k)
123 |
124 | if not DIVISIBLE_N:
125 | s = tl.where(mask_n[None, :], s, float("-inf"))
126 | if IS_CAUSAL:
127 | causal_mask = (P_SEQ + offs_m[:, None]) >= offs_n[None, :]
128 | s = tl.where(causal_mask, s, float("-inf"))
129 |
130 | # -- compute scaling constant ---
131 | m_i_new = tl.maximum(m_i, tl.max(s, 1))
132 | alpha = tl.math.exp2((m_i - m_i_new) * qk_scale)
133 | p = tl.math.exp2(s * qk_scale - m_i_new[:, None] * qk_scale)
134 |
135 | # -- scale and update acc: acc *= alpha[:, None]--
136 | acc *= alpha[:, None]
137 | acc += tl.dot(p.to(input_dtype), v)
138 |
139 | # -- update m_i and l_i --
140 | l_i = l_i * alpha + tl.sum(p, 1)
141 | m_i = m_i_new
142 | # update pointers
143 | k_ptrs += BLOCK_N * stride_kn
144 | v_ptrs += BLOCK_N * stride_vn
145 |
146 | # write back l & o
147 | if IS_CAUSAL and LARGER_M:
148 | is_empty_line = (offs_m + P_SEQ) < 0
149 | acc = tl.where(is_empty_line[:, None], 0.0, acc * (1.0 / l_i[:, None]))
150 | l = tl.where(is_empty_line, float("-inf"), m_i * sm_scale + tl.log(l_i))
151 | else:
152 | acc = acc * (1.0 / l_i[:, None])
153 | l = m_i * sm_scale + tl.log(l_i) # log(normalizer)
154 |
155 | if DIVISIBLE_M:
156 | tl.store(l_ptrs, l, cache_modifier=".cg")
157 | tl.store(o_ptrs, acc.to(input_dtype), cache_modifier=".cg")
158 | else:
159 | tl.store(l_ptrs, l, mask=mask_m, cache_modifier=".cg")
160 | tl.store(o_ptrs, acc.to(input_dtype), mask=mask_m[:, None], cache_modifier=".cg")
161 |
162 | @triton.jit
163 | def _fwd_combine_kv_splits(
164 | multiple_o, multiple_l,
165 | final_o, final_l,
166 | stride_mul_oz, stride_mul_oh, stride_mul_os, stride_mul_om, stride_mul_ok,
167 | stride_fin_oz, stride_fin_oh, stride_fin_om, stride_fin_ok,
168 | Z, H, M, S,
169 | BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
170 | DIVISIBLE_M: tl.constexpr,
171 | ):
172 | start_m = tl.program_id(0)
173 | offs_h = tl.program_id(1)
174 | offs_z = tl.program_id(2)
175 |
176 | # offset
177 | multiple_o += offs_z * stride_mul_oz + offs_h * stride_mul_oh # (B, H, S, M, D)
178 | multiple_l += (offs_z * H + offs_h) * S * M # (B, H, S, M)
179 |
180 | offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
181 | if not DIVISIBLE_M:
182 | mask_m = offs_m < M
183 |
184 | # 1st loop: compute the combined log-normalizer with an online logsumexp to save an extra pass over the splits
185 | m = tl.full([BLOCK_M], value=float("-inf"), dtype=tl.float32)
186 | acc = tl.full([BLOCK_M], value=float(0.0), dtype=tl.float32)
187 | l_ptrs = multiple_l + offs_m
188 | for _ in range(0, S):
189 | if DIVISIBLE_M:
190 | l = tl.load(l_ptrs)
191 | else:
192 | l = tl.load(l_ptrs, mask=mask_m)
193 | m_new = tl.maximum(m, l)
194 | acc = acc * tl.exp(m - m_new) + tl.exp(l - m_new)
195 | m = m_new
196 | l_ptrs += M
197 | l_acc = m + tl.log(acc)
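# l_acc is the log-normalizer of the full row, i.e. log(sum_s exp(l_s)) over all S splits,
# accumulated by the streaming logsumexp above so only a single pass over the splits is needed.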
198 |
199 | # 2nd loop: rescale each split's partial output and accumulate into o
200 | o_acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
201 | l_ptrs = multiple_l + offs_m
202 | offs_k = tl.arange(0, BLOCK_DMODEL)
203 | o_ptrs = multiple_o + offs_m[:, None] * stride_mul_om + offs_k[None, :] * stride_mul_ok
204 | for _ in range(0, S):
205 | l = tl.load(l_ptrs, mask=offs_m < M)
206 | rescale = tl.exp(l - l_acc)
207 | if DIVISIBLE_M:
208 | o = tl.load(o_ptrs)
209 | else:
210 | o = tl.load(o_ptrs, mask=mask_m[:, None])
211 | o_acc += o * rescale[:, None]
212 |
213 | l_ptrs += M
214 | o_ptrs += stride_mul_os
215 |
216 | # write back
217 | final_o += offs_z * stride_fin_oz + offs_h * stride_fin_oh
218 | final_l += (offs_z * H + offs_h) * M
219 | a_ptrs = final_o + offs_m[:, None] * stride_fin_om + offs_k * stride_fin_ok
220 | b_ptrs = final_l + offs_m
221 |
222 | if DIVISIBLE_M:
223 | tl.store(a_ptrs, o_acc)
224 | tl.store(b_ptrs, l_acc)
225 | else:
226 | tl.store(a_ptrs, o_acc, mask=mask_m[:, None])
227 | tl.store(b_ptrs, l_acc, mask=mask_m)
228 |
229 | def get_fwd_config(B, H, M, N, D, causal):
230 | # BLOCK_M, BLOCK_N, num_stages, num_warps
231 | return (16, 128, 1, 4)
232 |
233 | # this function is adapted from https://github.com/Dao-AILab/flash-attention/blob/61a777247900f6c2a37376f3ffd7134385fdc95c/csrc/flash_attn/flash_api.cpp#L235
234 | def num_splits_herustic(B, H, M, N, BLOCK_M, BLOCK_N, num_sms, max_splits):
235 | num_blocks_without_split_kv = B * H * triton.cdiv(M, BLOCK_M)
236 | if num_blocks_without_split_kv >= 0.8 * num_sms:
237 | return 1
238 |
239 | num_n_blocks = triton.cdiv(N, BLOCK_N)
240 | def num_split_available(s):
241 | blocks_per_split = triton.cdiv(num_n_blocks, s)
242 | return s == 1 or (blocks_per_split * s - num_n_blocks < blocks_per_split)
243 |
244 | def efficiency(s):
245 | n_waves = (num_blocks_without_split_kv * s) / num_sms
246 | eff = n_waves / math.ceil(n_waves)
247 | return eff
248 |
249 | max_efficiency = 0.0
250 | plans = [] # (num_split, efficiency)
251 | max_splits = min(num_sms, num_n_blocks, max_splits)
252 |
253 | for num_split in range(1, max_splits + 1):
254 | if num_split_available(num_split):
255 | eff = efficiency(num_split)
256 | plans.append((num_split, eff))
257 | max_efficiency = max(eff, max_efficiency)
258 |
259 | for num_split, eff in plans:
260 | if eff >= 0.85 * max_efficiency:
261 | return num_split
262 | return 1
263 |
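# Illustrative walk-through (numbers are assumptions, not measurements): with B=1, H=32,
# M=1 and BLOCK_M=16 there are only 32 query blocks, well below 0.8 * num_sms on an A100
# (108 SMs), so the heuristic searches candidate split counts and returns the smallest one
# whose wave efficiency reaches at least 85% of the best achievable efficiency.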
264 |
265 | # flash decoding
266 | def attention(q, k, v, causal=False, sm_scale=None):
267 | Dq, Dk, Dv = q.shape[-1], k.shape[-1], v.shape[-1]
268 | assert Dq == Dk == Dv
269 | assert Dk in {16, 32, 64, 128}
270 |
271 | B, H, M, D = q.shape
272 | N = k.shape[2]
273 | Hk, Hv = k.shape[1], v.shape[1]
274 | assert Hk == Hv, "num of heads in k and v should be equal"
275 | assert H % Hk == 0, "number of heads in q must be a multiple of that in k & v"
276 | num_groups = H // Hk
277 | P_SEQ = N - M
278 | larger_m = M > N
279 |
280 | if sm_scale is None:
281 | sm_scale = 1. / math.sqrt(D)
282 |
283 | # to work around https://github.com/openai/triton/issues/2441
284 | device = torch.cuda.device_of(q)
285 | num_sms = torch.cuda.get_device_properties(device).multi_processor_count
286 |
287 | with torch.cuda.device(device):
288 | config = get_fwd_config(B, H, M, N, D, causal)
289 | BLOCK_M, BLOCK_N, num_stages, num_warps = config
290 | S = num_splits_herustic(B, H, M, N, BLOCK_M, BLOCK_N, num_sms, 128)
291 |
292 | divisible_m = M % BLOCK_M == 0
293 | divisible_n = N % BLOCK_N == 0
294 |
295 | # consider using 3d grid to avoid div & rem
296 | multiple_l = torch.empty((B, H, S, M), dtype=torch.float32, device="cuda")
297 | multiple_o = torch.empty((B, H, S, M, D), dtype=q.dtype, device="cuda")
298 | grid = (triton.cdiv(M, BLOCK_M), S, H * B)
299 | N_SPLIT_SIZE = triton.cdiv(triton.cdiv(N, BLOCK_N), S) * BLOCK_N
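# N_SPLIT_SIZE is the number of k/v positions handled by each split, rounded up to a
# multiple of BLOCK_N so that every split iterates over whole key/value blocks.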
300 | _fwd_split_kv_kernel[grid](
301 | q, k, v, sm_scale,
302 | multiple_l, multiple_o,
303 | q.stride(0), q.stride(1), q.stride(2), q.stride(3),
304 | k.stride(0), k.stride(1), k.stride(2), k.stride(3),
305 | v.stride(0), v.stride(1), v.stride(2), v.stride(3),
306 | multiple_o.stride(0), multiple_o.stride(1), multiple_o.stride(2), multiple_o.stride(3), multiple_o.stride(4),
307 | B, H, M, N, P_SEQ, N_SPLIT_SIZE, S, num_groups,
308 | BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N,
309 | IS_CAUSAL=causal, LARGER_M=larger_m,
310 | DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n,
311 | num_stages=num_stages, num_warps=num_warps,
312 | )
313 |
314 | if S == 1:
315 | return multiple_o.squeeze(2)
316 |
317 | final_l = torch.empty((B, H, M), dtype=torch.float32, device="cuda")
318 | final_o = torch.empty_like(q)
319 | grid = (triton.cdiv(M, BLOCK_M), H, B)
320 | _fwd_combine_kv_splits[grid](
321 | multiple_o, multiple_l,
322 | final_o, final_l,
323 | multiple_o.stride(0), multiple_o.stride(1), multiple_o.stride(2), multiple_o.stride(3), multiple_o.stride(4),
324 | final_o.stride(0), final_o.stride(1), final_o.stride(2), final_o.stride(3),
325 | B, H, M, S,
326 | BLOCK_M=BLOCK_M, BLOCK_DMODEL=D,
327 | DIVISIBLE_M=divisible_m,
328 | num_stages=num_stages, num_warps=num_warps,
329 | )
330 | return final_o
331 |
--------------------------------------------------------------------------------
/tests/flag_attn/test_flash_attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pytest
3 |
4 | import flag_attn
5 |
6 | torch.random.manual_seed(10086)
7 |
8 | def max_diff(a, b):
9 | return (a - b).abs().max().item()
10 |
11 | def zero_percent(a, b):
12 | diff = (a - b).abs()
13 | num_non_zeros = diff.nonzero().shape[0]
14 | return (1.0 - num_non_zeros/ diff.numel()) * 100.0
15 |
16 | def report(name, actual, expected):
17 | print(f"{name}: \tmax_difference: {max_diff(actual, expected):0.6f}\tzero_diff elements: {zero_percent(actual, expected):0.3f}%")
18 |
19 |
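# Tolerance policy used by the tests below: each Triton output is compared against an
# fp32 (upcast) reference, and its max absolute error must stay within twice the max
# absolute error of the low-precision torch reference, plus a small epsilon.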
20 | @pytest.mark.parametrize('device_id', list(range(torch.cuda.device_count())))
21 | @pytest.mark.parametrize('scale', [1.0, 2.0, 3.0, 4.0])
22 | @pytest.mark.parametrize('B, Hq, Hk, M, N, D', [
23 | (2, 4, 4, 512, 612, 128),
24 | (2, 4, 4, 1024, 1034, 64),
25 | (2, 4, 4, 2048, 2048, 32),
26 | (2, 4, 4, 4096, 4096, 16),
27 | (2, 4, 4, 4001, 4001, 32),
28 | (2, 4, 4, 4001, 4096, 64),
29 | (2, 4, 4, 4096, 4000, 128),
30 | (1, 2, 2, 8192, 8202, 16),
31 | (1, 2, 2, 8192, 8192, 32),
32 | # test for mqa/gqa
33 | (2, 4, 2, 512, 612, 128),
34 | (2, 4, 1, 1024, 1034, 64),
35 | (2, 4, 2, 2048, 2048, 32),
36 | (2, 4, 1, 4096, 4096, 16),
37 | (2, 4, 2, 4001, 4001, 32),
38 | (2, 4, 1, 4001, 4096, 64),
39 | (2, 4, 2, 4096, 4000, 128),
40 | (1, 2, 1, 8192, 8202, 16),
41 | (1, 2, 1, 8192, 8192, 32),
42 | ])
43 | @pytest.mark.parametrize('causal', [True, False])
44 | @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
45 | @pytest.mark.parametrize('stride_order', ['BHTD', 'BTHD'])
46 | def test_attention_fwd(B, Hq, Hk, M, N, D, causal, stride_order, dtype, scale, device_id):
47 | device = f"cuda:{device_id}"
48 | if stride_order == "BHTD":
49 | q = torch.empty((B, Hq, M, D), dtype=dtype, device=device).normal_(mean=0., std=scale)
50 | k = torch.empty((B, Hk, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale)
51 | v = torch.empty((B, Hk, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale)
52 | else:
53 | q = torch.empty((B, M, Hq, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2)
54 | k = torch.empty((B, N, Hk, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2)
55 | v = torch.empty((B, N, Hk, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2)
56 |
57 | o_ref = flag_attn.testing.flash_attention(q, k, v, causal, upcast=True)
58 | o_torch = flag_attn.testing.flash_attention(q, k, v, causal, upcast=False)
59 | o_hyp = flag_attn.flash_attention(q, k, v, causal)
60 |
61 | torch_max_diff = max_diff(o_torch, o_ref)
62 | triton_max_diff = max_diff(o_hyp, o_ref)
63 | report("o hyp", o_hyp, o_ref)
64 | report("o torch", o_hyp, o_ref)
65 | assert triton_max_diff <= 2 * torch_max_diff + 1e-5
66 |
67 |
68 | @pytest.mark.parametrize('device_id', list(range(torch.cuda.device_count())))
69 | @pytest.mark.parametrize('scale', [10.0])
70 | @pytest.mark.parametrize('B, Hq, Hk, M, N, D', [
71 | (2, 4, 4, 1, 612, 128),
72 | (2, 4, 4, 1, 1034, 64),
73 | (2, 4, 4, 1, 2048, 32),
74 | (2, 4, 4, 1, 4096, 16),
75 | (2, 4, 4, 1, 4001, 32),
76 | (2, 4, 4, 1, 4096, 64),
77 | (2, 4, 4, 2, 4000, 128),
78 | (1, 2, 2, 4, 8202, 16),
79 | (1, 2, 2, 1, 8192, 32),
80 | # test for mqa/gqa
81 | (2, 4, 2, 1, 612, 128),
82 | (2, 4, 1, 1, 1034, 64),
83 | (2, 4, 2, 1, 2048, 32),
84 | (2, 4, 1, 1, 4096, 16),
85 | (2, 4, 2, 1, 4001, 32),
86 | (2, 4, 1, 1, 4096, 64),
87 | (2, 4, 2, 2, 4000, 128),
88 | (1, 2, 1, 4, 8202, 16),
89 | (1, 2, 1, 1, 8192, 32),
90 | ])
91 | @pytest.mark.parametrize('causal', [True, False])
92 | @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
93 | @pytest.mark.parametrize('stride_order', ['BHTD', 'BTHD'])
94 | def test_attention_splitkv(B, Hq, Hk, M, N, D, causal, stride_order, dtype, scale, device_id):
95 | device = f"cuda:{device_id}"
96 | if stride_order == "BHTD":
97 | q = torch.empty((B, Hq, M, D), dtype=dtype, device=device).normal_(mean=0., std=scale)
98 | k = torch.empty((B, Hk, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale)
99 | v = torch.empty((B, Hk, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale)
100 | else:
101 | q = torch.empty((B, M, Hq, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2)
102 | k = torch.empty((B, N, Hk, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2)
103 | v = torch.empty((B, N, Hk, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2)
104 |
105 | o_ref = flag_attn.testing.flash_attention(q, k, v, causal, upcast=True)
106 | o_torch = flag_attn.testing.flash_attention(q, k, v, causal, upcast=False)
107 | o_hyp = flag_attn.flash_attention_split_kv(q, k, v, causal)
108 |
109 | torch_max_diff = max_diff(o_torch, o_ref)
110 | triton_max_diff = max_diff(o_hyp, o_ref)
111 | report("o hyp", o_hyp, o_ref)
112 | report("o torch", o_hyp, o_ref)
113 | assert triton_max_diff <= 2 * torch_max_diff + 1e-5
114 |
115 | @pytest.mark.parametrize('device_id', list(range(torch.cuda.device_count())))
116 | @pytest.mark.parametrize('scale', [1.0, 2.0, 3.0, 4.0])
117 | @pytest.mark.parametrize('B, Hq, Hk, M, N, D', [
118 | (2, 4, 4, 512, 612, 128),
119 | (2, 4, 4, 1024, 1034, 64),
120 | (2, 4, 4, 2048, 2048, 32),
121 | (2, 4, 4, 4096, 4096, 16),
122 | (2, 4, 4, 4001, 4001, 32),
123 | (2, 4, 4, 4001, 4096, 64),
124 | (2, 4, 4, 4096, 4001, 128),
125 | (1, 2, 2, 8192, 8202, 16),
126 | (1, 2, 2, 8192, 8192, 32),
127 | (2, 4, 4, 10006, 10, 128),
128 | # test for mqa/gqa
129 | (2, 4, 2, 512, 612, 128),
130 | (2, 4, 1, 1024, 1034, 64),
131 | (2, 4, 2, 2048, 2048, 32),
132 | (2, 4, 1, 4096, 4096, 16),
133 | (2, 4, 2, 4001, 4001, 32),
134 | (2, 4, 1, 4001, 4096, 64),
135 | (2, 4, 2, 4096, 4001, 128),
136 | (1, 2, 1, 8192, 8202, 16),
137 | (1, 2, 1, 8192, 8192, 32),
138 | (2, 4, 2, 10006, 10, 128),
139 | ])
140 | @pytest.mark.parametrize('causal', [True, False])
141 | @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
142 | @pytest.mark.parametrize('stride_order', ['BHTD', 'BTHD'])
143 | def test_attention_bwd(B, Hq, Hk, M, N, D, causal, stride_order, dtype, scale, device_id):
144 | device = f"cuda:{device_id}"
145 | if stride_order == "BHTD":
146 | q = torch.empty((B, Hq, M, D), dtype=dtype, device=device).normal_(mean=0., std=scale).requires_grad_()
147 | k = torch.empty((B, Hk, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale).requires_grad_()
148 | v = torch.empty((B, Hk, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale).requires_grad_()
149 | do = torch.randn((B, Hq, M, D), dtype=dtype, device=device)
150 | else:
151 | q = torch.empty((B, M, Hq, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2).requires_grad_()
152 | k = torch.empty((B, N, Hk, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2).requires_grad_()
153 | v = torch.empty((B, N, Hk, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2).requires_grad_()
154 | do = torch.randn((B, M, Hq, D), dtype=dtype, device=device).transpose(1, 2)
155 |
156 | o_ref = flag_attn.testing.flash_attention(q, k, v, causal=causal, upcast=True)
157 | o_torch = flag_attn.testing.flash_attention(q, k, v, causal=causal, upcast=False)
158 | o_hyp = flag_attn.flash_attention(q, k, v, causal=causal)
159 |
160 | gq_ref, gk_ref, gv_ref = torch.autograd.grad(o_ref, (q, k, v), do)
161 | gq_torch, gk_torch, gv_torch = torch.autograd.grad(o_torch, (q, k, v), do)
162 | gq_hyp, gk_hyp, gv_hyp = torch.autograd.grad(o_hyp, (q, k, v), do)
163 |
164 | o_torch_max_diff = max_diff(o_torch, o_ref)
165 | gq_torch_max_diff = max_diff(gq_torch, gq_ref)
166 | gk_torch_max_diff = max_diff(gk_torch, gk_ref)
167 | gv_torch_max_diff = max_diff(gv_torch, gv_ref)
168 |
169 | o_triton_max_diff = max_diff(o_hyp, o_ref)
170 | gq_triton_max_diff = max_diff(gq_hyp, gq_ref)
171 | gk_triton_max_diff = max_diff(gk_hyp, gk_ref)
172 | gv_triton_max_diff = max_diff(gv_hyp, gv_ref)
173 |
174 | assert o_triton_max_diff < 2 * o_torch_max_diff + 1e-5
175 | assert gq_triton_max_diff < 2 * gq_torch_max_diff + 1e-5
176 | assert gk_triton_max_diff < 2 * gk_torch_max_diff + 1e-5
177 | assert gv_triton_max_diff < 2 * gv_torch_max_diff + 1e-5
178 |
179 |
180 | @pytest.mark.parametrize('device_id', list(range(torch.cuda.device_count())))
181 | @pytest.mark.parametrize('scale', [1.0, 2.0, 3.0, 4.0])
182 | @pytest.mark.parametrize('B, H, M, N, D', [
183 | (2, 4, 512, 612, 128),
184 | (2, 4, 1024, 1034, 64),
185 | (2, 4, 2048, 2048, 32),
186 | (2, 4, 4096, 4096, 16),
187 | (2, 4, 4001, 4001, 32),
188 | (2, 4, 4001, 4096, 64),
189 | (2, 4, 4096, 4001, 128),
190 | (1, 2, 8192, 8202, 16),
191 | (1, 2, 8192, 8192, 32),
192 | ])
193 | @pytest.mark.parametrize('causal', [False])
194 | @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
195 | @pytest.mark.parametrize('stride_order', ['BHTD', 'BTHD'])
196 | def test_attention_with_aux_outs(B, H, M, N, D, causal, stride_order, dtype, scale, device_id):
197 | device = f"cuda:{device_id}"
198 | if stride_order == "BHTD":
199 | q = torch.empty((B, H, M, D), dtype=dtype, device=device).normal_(mean=0., std=scale)
200 | k = torch.empty((B, H, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale)
201 | v = torch.empty((B, H, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale)
202 | else:
203 | q = torch.empty((B, M, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2)
204 | k = torch.empty((B, N, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2)
205 | v = torch.empty((B, N, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2)
206 |
207 | o_ref, log_norm_ref, tot_attn_ref = flag_attn.testing.flash_attention(q, k, v, causal, return_log_normalizer=True, return_total_attention=True, upcast=True)
208 | o_torch, log_norm_torch, tot_attn_torch = flag_attn.testing.flash_attention(q, k, v, causal, return_log_normalizer=True, return_total_attention=True, upcast=False)
209 | o_hyp, log_norm_hyp, tot_attn_hyp, *_ = flag_attn.flash_attention(q, k, v, causal, return_log_normalizer=True, return_total_attention=True)
210 |
211 |
212 | torch_max_diff = max_diff(o_torch, o_ref)
213 | triton_max_diff = max_diff(o_hyp, o_ref)
214 | assert triton_max_diff <= 2 * torch_max_diff + 1e-5
215 |
216 | torch_max_diff = max_diff(log_norm_torch, log_norm_ref)
217 | triton_max_diff = max_diff(log_norm_hyp, log_norm_ref)
218 | assert triton_max_diff <= 2 * torch_max_diff + 1e-5
219 |
220 | torch_max_diff = max_diff(tot_attn_torch, tot_attn_ref)
221 | triton_max_diff = max_diff(tot_attn_hyp, tot_attn_ref)
222 | assert triton_max_diff <= 2 * torch_max_diff + 1e-5
223 |
224 |
225 | @pytest.mark.parametrize('device_id', list(range(torch.cuda.device_count())))
226 | @pytest.mark.parametrize('scale', [1.0, 2.0])
227 | @pytest.mark.parametrize('B, H, M, N, D', [
228 | (2, 4, 512, 612, 128),
229 | (2, 4, 1024, 1034, 64),
230 | (2, 4, 2048, 2048, 32),
231 | (2, 4, 4096, 4096, 16),
232 | (2, 4, 4001, 4001, 32),
233 | (2, 4, 4001, 4096, 64),
234 | (2, 4, 4096, 4000, 128),
235 | (1, 2, 8192, 8202, 16),
236 | (1, 2, 8192, 8192, 32),
237 | ])
238 | @pytest.mark.parametrize('causal', [False, True])
239 | @pytest.mark.parametrize('dropout_p', [0.5, 0.8])
240 | @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
241 | @pytest.mark.parametrize('stride_order', ['BHTD', 'BTHD'])
242 | def test_attention_fwd_dropout(B, H, M, N, D, causal, dropout_p, stride_order, dtype, scale, device_id):
243 | device = f"cuda:{device_id}"
244 | if stride_order == "BHTD":
245 | q = torch.empty((B, H, M, D), dtype=dtype, device=device).normal_(mean=0., std=scale)
246 | k = torch.empty((B, H, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale)
247 | v = torch.empty((B, H, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale)
248 | else:
249 | q = torch.empty((B, M, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2)
250 | k = torch.empty((B, N, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2)
251 | v = torch.empty((B, N, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2)
252 |
253 | o_hyp, _, _, seed, offset = flag_attn.flash_attention(q, k, v, causal, dropout_p=dropout_p, return_seed_offset=True)
254 | mask = flag_attn.testing.recompute_mask(B, H, M, N, dropout_p, seed, offset, device)
255 | o_ref = flag_attn.testing.flash_attention(q, k, v, causal, dropout_p=dropout_p, dropout_mask=mask, upcast=True)
256 | o_torch = flag_attn.testing.flash_attention(q, k, v, causal, dropout_p=dropout_p, dropout_mask=mask, upcast=False)
257 |
258 | torch_max_diff = max_diff(o_torch, o_ref)
259 | triton_max_diff = max_diff(o_hyp, o_ref)
260 | report("o hyp", o_hyp, o_ref)
261 | report("o torch", o_torch, o_ref)
262 | assert triton_max_diff <= 2 * torch_max_diff + 1e-5
263 |
264 |
265 | import random
266 | # @pytest.mark.parametrize('increment', [random.randint(0, 1000000000) for i in range(100)])
267 | @pytest.mark.parametrize('device_id', list(range(torch.cuda.device_count())))
268 | @pytest.mark.parametrize('scale', [1.0, 2.0])
269 | @pytest.mark.parametrize('B, H, M, N, D', [
270 | (2, 4, 512, 612, 128),
271 | (2, 4, 1024, 1034, 64),
272 | (2, 4, 2048, 2048, 32),
273 | (2, 4, 4096, 4096, 16),
274 | (2, 4, 4001, 4001, 32),
275 | (2, 4, 4001, 4096, 64),
276 | (2, 4, 4096, 4000, 128),
277 | (1, 2, 8192, 8202, 16),
278 | (1, 2, 8192, 8192, 32),
279 | ])
280 | @pytest.mark.parametrize('causal', [True, False])
281 | @pytest.mark.parametrize('dropout_p', [0.5, 0.8])
282 | @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
283 | @pytest.mark.parametrize('stride_order', ['BHTD', 'BTHD'])
284 | def test_attention_bwd_dropout(B, H, M, N, D, causal, dropout_p, stride_order, dtype, scale, device_id):
285 | device = f"cuda:{device_id}"
286 | if stride_order == "BHTD":
287 | q = torch.empty((B, H, M, D), dtype=dtype, device=device).normal_(mean=0., std=scale).requires_grad_()
288 | k = torch.empty((B, H, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale).requires_grad_()
289 | v = torch.empty((B, H, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale).requires_grad_()
290 | do = torch.randn((B, H, M, D), dtype=dtype, device=device)
291 | else:
292 | q = torch.empty((B, M, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2).requires_grad_()
293 | k = torch.empty((B, N, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2).requires_grad_()
294 | v = torch.empty((B, N, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2).requires_grad_()
295 | do = torch.randn((B, M, H, D), dtype=dtype, device=device).transpose(1, 2)
296 |
297 | # from flag_attn.dropout import philox_cuda_seed_offset
298 | o_hyp, _, _, seed, offset = flag_attn.flash_attention(q, k, v, causal=causal, dropout_p=dropout_p, return_seed_offset=True)
299 | mask = flag_attn.testing.recompute_mask(B, H, M, N, dropout_p, seed, offset, device)
300 | o_ref = flag_attn.testing.flash_attention(q, k, v, causal=causal, dropout_p=dropout_p, dropout_mask=mask, upcast=True)
301 | o_torch = flag_attn.testing.flash_attention(q, k, v, causal=causal, dropout_p=dropout_p, dropout_mask=mask, upcast=False)
302 |
303 | gq_ref, gk_ref, gv_ref = torch.autograd.grad(o_ref, (q, k, v), do)
304 | gq_torch, gk_torch, gv_torch = torch.autograd.grad(o_torch, (q, k, v), do)
305 | gq_hyp, gk_hyp, gv_hyp = torch.autograd.grad(o_hyp, (q, k, v), do)
306 |
307 | o_torch_max_diff = max_diff(o_torch, o_ref)
308 | gq_torch_max_diff = max_diff(gq_torch, gq_ref)
309 | gk_torch_max_diff = max_diff(gk_torch, gk_ref)
310 | gv_torch_max_diff = max_diff(gv_torch, gv_ref)
311 |
312 | o_triton_max_diff = max_diff(o_hyp, o_ref)
313 | gq_triton_max_diff = max_diff(gq_hyp, gq_ref)
314 | gk_triton_max_diff = max_diff(gk_hyp, gk_ref)
315 | gv_triton_max_diff = max_diff(gv_hyp, gv_ref)
316 |
317 | assert o_triton_max_diff < 2 * o_torch_max_diff + 1e-5
318 | assert gq_triton_max_diff < 2 * gq_torch_max_diff + 1e-5
319 | assert gk_triton_max_diff < 2 * gk_torch_max_diff + 1e-5
320 | assert gv_triton_max_diff < 2 * gv_torch_max_diff + 1e-5
--------------------------------------------------------------------------------
/src/flag_attn/paged.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import triton
3 | import triton.language as tl
4 |
5 | # Requires triton 2.2.0
6 | def attention(
7 | query: torch.Tensor, # [num_seqs, NUM_KV_HEADS * QUERY_GROUP_SIZE, HEAD_SIZE]
8 | key_cache: torch.Tensor, # [num_blocks, NUM_KV_HEADS, KV_BLOCK_SIZE, HEAD_SIZE]
9 | value_cache: torch.Tensor, # [num_blocks, NUM_KV_HEADS, KV_BLOCK_SIZE, HEAD_SIZE], must have the same strides as key_cache
10 | context_lens: torch.Tensor, # [num_seqs]
11 | block_tables: torch.Tensor, # [num_seqs, max_num_blocks_per_seq]
12 | attn_scale: float,
13 | max_context_len: int,
14 | num_splits: int = 0,
15 | ) -> torch.Tensor:
16 | out = torch.empty_like(query)
17 |
18 | num_seqs = query.shape[0]
19 | num_kv_heads = key_cache.shape[1]
20 | kv_block_size = key_cache.shape[2]
21 | head_size = key_cache.shape[3]
22 | query_group_size = query.shape[1] // num_kv_heads
23 |
24 | if query_group_size == 1:
25 | padded_group_size = 1
26 | elif query_group_size < 16:
27 | padded_group_size = 16
28 | else:
29 | padded_group_size = triton.next_power_of_2(query_group_size)
30 |
31 | assert head_size in (16, 32, 64, 128, 256, 512), f"head_size={head_size}"
32 | assert padded_group_size == 1 or kv_block_size >= 16, f"kv_block_size={kv_block_size}"
33 | # query_group_size in (1, 2, 4, 8, 16, 32, 64, 128, 256)
34 | # assert query_group_size > 0 and query_group_size & (query_group_size-1) == 0, f"query_group_size={query_group_size}"
35 |
36 | # config for A100
37 | # TODO: support more devices and optimize
38 | device = torch.cuda.device_of(query)
39 | num_sms = torch.cuda.get_device_properties(device).multi_processor_count
40 | if num_splits == 0:
41 | if num_seqs * num_kv_heads > 2 * num_sms:
42 | num_splits = 1
43 | if max_context_len >= 4096:
44 | partition_size = max(256, kv_block_size)
45 | num_splits = triton.cdiv(max_context_len, partition_size)
46 | else:
47 | partition_size = max(256, kv_block_size)
48 | num_splits = triton.cdiv(max_context_len, partition_size)
49 | if max_context_len <= 1024 or kv_block_size >= 256:
50 | num_splits = 1
51 | elif num_splits > 1:
52 | partition_size = triton.cdiv(max_context_len, num_splits)
53 | partition_size = triton.next_power_of_2(partition_size)
54 |
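# At this point num_splits == 1 means the single-pass kernel writes directly into `out`;
# num_splits > 1 means each partition of `partition_size` tokens is handled by a separate
# program, and the partial results (m_i, l_i, tmp_out) are combined by
# _paged_attn_v2_reduce_kernel below.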
55 | with torch.cuda.device(device):
56 | if num_splits == 1:
57 | grid = (num_seqs, num_kv_heads, 1)
58 | _paged_attn_kernel[grid](
59 | out, # m_i_ptr: unused when PARTITION_SIZE == 0, out is passed as a dummy
60 | out, # l_i_ptr: unused when PARTITION_SIZE == 0, out is passed as a dummy
61 | out,
62 | query,
63 | key_cache,
64 | value_cache,
65 | context_lens,
66 | block_tables,
67 | attn_scale,
68 | block_tables.stride(0),
69 | block_tables.stride(1),
70 | query.stride(0),
71 | query.stride(1),
72 | query.stride(2),
73 | key_cache.stride(0),
74 | key_cache.stride(1),
75 | key_cache.stride(2),
76 | key_cache.stride(3),
77 | out.stride(0),
78 | out.stride(1),
79 | out.stride(1),
80 | out.stride(1),
81 | out.stride(2),
82 | head_size,
83 | query_group_size,
84 | padded_group_size,
85 | num_kv_heads,
86 | kv_block_size,
87 | PARTITION_SIZE=0,
88 | )
89 |
90 | else:
91 | grid = (num_seqs, num_kv_heads, num_splits)
92 | m_i = torch.empty(
93 | size=(num_seqs, num_kv_heads, num_splits, query_group_size),
94 | dtype=torch.float32,
95 | device=query.device,
96 | )
97 | l_i = torch.empty_like(m_i)
98 | tmp_out = torch.empty(
99 | size=(
100 | num_seqs,
101 | num_kv_heads,
102 | num_splits,
103 | query_group_size,
104 | head_size,
105 | ),
106 | dtype=out.dtype,
107 | device=out.device,
108 | )
109 |
110 | assert (partition_size >= kv_block_size) and (partition_size % kv_block_size == 0), \
111 | f"partition_size={partition_size}, kv_block_size={kv_block_size}"
112 | _paged_attn_kernel[grid](
113 | m_i,
114 | l_i,
115 | tmp_out,
116 | query,
117 | key_cache,
118 | value_cache,
119 | context_lens,
120 | block_tables,
121 | attn_scale,
122 | block_tables.stride(0),
123 | block_tables.stride(1),
124 | query.stride(0),
125 | query.stride(1),
126 | query.stride(2),
127 | key_cache.stride(0),
128 | key_cache.stride(1),
129 | key_cache.stride(2),
130 | key_cache.stride(3),
131 | tmp_out.stride(0),
132 | tmp_out.stride(1),
133 | tmp_out.stride(2),
134 | tmp_out.stride(3),
135 | tmp_out.stride(4),
136 | head_size,
137 | query_group_size,
138 | padded_group_size,
139 | num_kv_heads,
140 | kv_block_size,
141 | partition_size,
142 | )
143 |
144 | reduce_grid = (num_seqs, num_kv_heads)
145 | next_num_splits = triton.next_power_of_2(num_splits)
146 |
147 | _paged_attn_v2_reduce_kernel[reduce_grid](
148 | out,
149 | m_i,
150 | l_i,
151 | tmp_out,
152 | context_lens,
153 | num_splits,
154 | out.stride(0),
155 | out.stride(1),
156 | out.stride(2),
157 | head_size,
158 | query_group_size,
159 | num_kv_heads,
160 | partition_size,
161 | next_num_splits,
162 | )
163 | return out
164 |
165 |
166 | def get_num_warps(QUERY_GROUP_SIZE, HEAD_SIZE, KV_BLOCK_SIZE):
167 | if QUERY_GROUP_SIZE == 1:
168 | if HEAD_SIZE >= 128 and KV_BLOCK_SIZE >= 32:
169 | return 16
170 | else:
171 | return 8
172 | else:
173 | return 4
174 |
175 |
176 | def get_num_stages(PARTITION_SIZE, KV_BLOCK_SIZE):
177 | if PARTITION_SIZE == 0:
178 | return 1
179 | else:
180 | if torch.cuda.get_device_capability() == (8, 0):
181 | if KV_BLOCK_SIZE < 256:
182 | return 3
183 | else:
184 | return 2
185 | elif torch.cuda.get_device_capability() == (8, 6):
186 | if KV_BLOCK_SIZE < 256:
187 | return 2
188 | else:
189 | return 1
190 | else:
191 | return 1
192 |
193 |
194 | @triton.heuristics(
195 | {
196 | "num_warps": lambda args: get_num_warps(
197 | args["QUERY_GROUP_SIZE"], args["HEAD_SIZE"], args["KV_BLOCK_SIZE"]
198 | ),
199 | "num_stages": lambda args: get_num_stages(
200 | args["QUERY_GROUP_SIZE"], args["KV_BLOCK_SIZE"]
201 | ),
202 | }
203 | )
204 | @triton.jit
205 | def _paged_attn_kernel(
206 | m_i_ptr, # [num_seqs, NUM_KV_HEADS, max_num_partitions, QUERY_GROUP_SIZE]
207 | l_i_ptr, # [num_seqs, NUM_KV_HEADS, max_num_partitions, QUERY_GROUP_SIZE]
208 | out_ptr, # [num_seqs, NUM_KV_HEADS, max_num_partitions, QUERY_GROUP_SIZE, HEAD_SIZE]
209 | q_ptr, # [num_seqs, NUM_KV_HEADS * QUERY_GROUP_SIZE, HEAD_SIZE]
210 | k_cache_ptr, # [num_blocks, NUM_KV_HEADS, KV_BLOCK_SIZE, HEAD_SIZE]
211 | v_cache_ptr, # [num_blocks, NUM_KV_HEADS, KV_BLOCK_SIZE, HEAD_SIZE]
212 | context_lens_ptr, # [num_seqs]
213 | block_tables_ptr, # [num_seqs, max_num_blocks_per_seq]
214 | attn_scale,
215 | stride_bt0,
216 | stride_bt1,
217 | stride_q0,
218 | stride_q1,
219 | stride_q2,
220 | stride_kv0,
221 | stride_kv1,
222 | stride_kv2,
223 | stride_kv3,
224 | stride_o0,
225 | stride_o1,
226 | stride_o2,
227 | stride_o3,
228 | stride_o4,
229 | HEAD_SIZE: tl.constexpr,
230 | QUERY_GROUP_SIZE: tl.constexpr,
231 | PADDED_QUERY_GROUP_SIZE: tl.constexpr,
232 | NUM_KV_HEADS: tl.constexpr,
233 | KV_BLOCK_SIZE: tl.constexpr,
234 | PARTITION_SIZE: tl.constexpr,
235 | ):
236 | seq_idx = tl.program_id(0)
237 | kv_head_idx = tl.program_id(1)
238 | part_idx = tl.program_id(2)
239 | max_num_partitions = tl.num_programs(2)
240 |
241 | # scale sm_scale by log_2(e) and use
242 | # 2^x instead of exp in the loop because CSE and LICM
243 | # don't work as expected with `exp` in the loop
244 | log2e: tl.constexpr = 1.4426950408889634
245 |
246 | USE_PARTITIONING = PARTITION_SIZE > 0
247 | context_len = tl.load(context_lens_ptr + seq_idx)
248 | if USE_PARTITIONING:
249 | context_start_idx = part_idx * PARTITION_SIZE
250 | if context_start_idx >= context_len:
251 | return
252 | context_end_idx = tl.minimum(context_start_idx + PARTITION_SIZE, context_len)
253 | num_blocks = tl.cdiv(context_end_idx - context_start_idx, KV_BLOCK_SIZE)
254 | else:
255 | num_blocks = tl.cdiv(context_len, KV_BLOCK_SIZE)
256 |
257 | block_offset = tl.arange(0, KV_BLOCK_SIZE)
258 | head_offset = tl.arange(0, HEAD_SIZE)
259 | padding_group_offset = tl.arange(0, PADDED_QUERY_GROUP_SIZE)
260 |
261 | kv_offset = (
262 | kv_head_idx * stride_kv1
263 | + block_offset[:, None] * stride_kv2
264 | + head_offset[None, :] * stride_kv3
265 | )
266 |
267 | # Load queries.
268 | q_offset = (
269 | seq_idx * stride_q0
270 | + (kv_head_idx * QUERY_GROUP_SIZE + padding_group_offset[:, None]) * stride_q1
271 | + head_offset[None, :] * stride_q2
272 | )
273 | group_mask = padding_group_offset[:, None] < QUERY_GROUP_SIZE
274 | # q: [PADDED_QUERY_GROUP_SIZE, HEAD_SIZE]
275 | q = tl.load(q_ptr + q_offset, mask=group_mask, other=0.0)
276 |
277 | m_i = tl.zeros([PADDED_QUERY_GROUP_SIZE], dtype=tl.float32) - float("inf")
278 | l_i = tl.zeros([PADDED_QUERY_GROUP_SIZE], dtype=tl.float32)
279 | acc = tl.zeros([PADDED_QUERY_GROUP_SIZE, HEAD_SIZE], dtype=tl.float32)
280 |
281 | num_prev_blocks = part_idx * (PARTITION_SIZE // KV_BLOCK_SIZE)
282 | for i in range(num_blocks):
283 | block_idx = num_prev_blocks + i
284 | block_number = tl.load(
285 | block_tables_ptr + seq_idx * stride_bt0 + block_idx * stride_bt1
286 | )
287 |
288 | # Load a key block.
289 | kv_block_offset = block_number * stride_kv0 + kv_offset
290 | mask_offset = block_idx * KV_BLOCK_SIZE + block_offset
291 | kv_mask = mask_offset[:, None] < context_len
292 |
293 | # k: [KV_BLOCK_SIZE, HEAD_SIZE]
294 | k = tl.load(k_cache_ptr + kv_block_offset, mask=kv_mask, other=0.0)
295 |
296 | # qk: [PADDED_QUERY_GROUP_SIZE, KV_BLOCK_SIZE]
297 | if PADDED_QUERY_GROUP_SIZE == 1:
298 | qk = tl.sum(q[:, None, :] * k[None, :, :], axis=2)
299 | else:
300 | qk = tl.dot(q, k.T, out_dtype=tl.float32)
301 |
302 | qk *= attn_scale
303 | qk = tl.where(mask_offset < context_len, qk, float("-inf"))
304 |
305 | m_i_new = tl.maximum(m_i, tl.max(qk, axis=1))
306 |
307 | # p: [PADDED_QUERY_GROUP_SIZE, KV_BLOCK_SIZE]
308 | p = tl.math.exp2((qk - m_i_new[:, None]) * log2e)
309 | alpha = tl.math.exp2((m_i - m_i_new) * log2e)
310 | acc *= alpha[:, None]
311 |
312 | # v: [KV_BLOCK_SIZE, HEAD_SIZE]
313 | v = tl.load(v_cache_ptr + kv_block_offset, mask=kv_mask, other=0.0)
314 |
315 | if PADDED_QUERY_GROUP_SIZE == 1:
316 | acc += tl.sum(p.T[:, :, None] * v[:, None, :], axis=0)
317 | else:
318 | p = p.to(v.dtype)
319 | acc += tl.dot(p, v, out_dtype=tl.float32)
320 |
321 | l_i = l_i * alpha + tl.sum(p, axis=1)
322 | m_i = m_i_new
323 | acc = acc / l_i[:, None]
324 |
325 | if USE_PARTITIONING:
326 | part_offset = (
327 | (seq_idx * NUM_KV_HEADS + kv_head_idx)
328 | * max_num_partitions
329 | * QUERY_GROUP_SIZE
330 | + part_idx * QUERY_GROUP_SIZE
331 | + padding_group_offset
332 | )
333 | mask = padding_group_offset < QUERY_GROUP_SIZE
334 | tl.store(m_i_ptr + part_offset, m_i, mask=mask)
335 | tl.store(l_i_ptr + part_offset, l_i, mask=mask)
336 |
337 | out_offset = seq_idx * stride_o0
338 | if USE_PARTITIONING:
339 | out_offset += kv_head_idx * stride_o1
340 | else:
341 | out_offset += kv_head_idx * QUERY_GROUP_SIZE * stride_o1
342 | out_offset += (
343 | part_idx * stride_o2
344 | + padding_group_offset[:, None] * stride_o3
345 | + head_offset[None, :] * stride_o4
346 | )
347 |
348 | group_mask = padding_group_offset[:, None] < QUERY_GROUP_SIZE
349 | tl.store(out_ptr + out_offset, acc, mask=group_mask)
350 |
351 |
352 | @triton.jit
353 | def _paged_attn_v2_reduce_kernel(
354 | out_ptr, # [num_seqs, NUM_KV_HEADS, QUERY_GROUP_SIZE, HEAD_SIZE]
355 | m_i_ptr, # [num_seqs, NUM_KV_HEADS, max_num_partitions, QUERY_GROUP_SIZE]
356 | l_i_ptr, # [num_seqs, NUM_KV_HEADS, max_num_partitions, QUERY_GROUP_SIZE]
357 | tmp_out_ptr, # [num_seqs, NUM_KV_HEADS, max_num_partitions, QUERY_GROUP_SIZE, HEAD_SIZE]
358 | context_lens_ptr, # [num_seqs]
359 | max_num_partitions, # partition stride
360 | stride_o0,
361 | stride_o1,
362 | stride_o2,
363 | HEAD_SIZE: tl.constexpr,
364 | QUERY_GROUP_SIZE: tl.constexpr,
365 | NUM_KV_HEADS: tl.constexpr,
366 | PARTITION_SIZE: tl.constexpr,
367 | NUM_PARTITIONS: tl.constexpr,
368 | ):
369 | seq_idx = tl.program_id(0)
370 | kv_head_idx = tl.program_id(1)
371 |
372 | context_len = tl.load(context_lens_ptr + seq_idx)
373 |
374 | num_partitions = tl.cdiv(context_len, PARTITION_SIZE)
375 | group_head_offset = (
376 | tl.arange(0, QUERY_GROUP_SIZE)[:, None] * HEAD_SIZE
377 | + tl.arange(0, HEAD_SIZE)[None, :]
378 | )
379 | if num_partitions == 1:
380 | tmp_out_offset = (
381 | seq_idx * NUM_KV_HEADS + kv_head_idx
382 | ) * max_num_partitions * QUERY_GROUP_SIZE * HEAD_SIZE + group_head_offset
383 | tmp_out = tl.load(tmp_out_ptr + tmp_out_offset)
384 |
385 | out_offset = (
386 | seq_idx * stride_o0
387 | + kv_head_idx * QUERY_GROUP_SIZE * stride_o1
388 | + group_head_offset * stride_o2
389 | )
390 | tl.store(out_ptr + out_offset, tmp_out)
391 | return
392 |
393 | # Get the global max logit.
394 | ml_offset = (
395 | (seq_idx * NUM_KV_HEADS + kv_head_idx) * max_num_partitions * QUERY_GROUP_SIZE
396 | + tl.arange(0, NUM_PARTITIONS)[:, None] * QUERY_GROUP_SIZE
397 | + tl.arange(0, QUERY_GROUP_SIZE)[None, :]
398 | )
399 |
400 | mask = tl.arange(0, NUM_PARTITIONS)[:, None] < num_partitions
401 | # m_i: [NUM_PARTITIONS, QUERY_GROUP_SIZE]
402 | m_i = tl.load(m_i_ptr + ml_offset, mask=mask, other=float("-inf"))
403 | # m: [QUERY_GROUP_SIZE]
404 | m = tl.max(m_i, axis=0)
405 |
406 | # Rescale the exp sums and compute the global sum.
407 | # l_i: [NUM_PARTITIONS, QUERY_GROUP_SIZE]
408 | l_i = tl.load(l_i_ptr + ml_offset, mask=mask, other=0.0)
409 | l_i *= tl.exp(m_i - m[None, :])
410 | # l: [QUERY_GROUP_SIZE]
411 | l = tl.sum(l_i, axis=0)
412 | # r: [NUM_PARTITIONS, QUERY_GROUP_SIZE]
413 | r = l_i / l[None, :]
414 | r = tl.reshape(r, (NUM_PARTITIONS, QUERY_GROUP_SIZE, 1))
415 |
416 | tmp_out_offset = (
417 | (seq_idx * NUM_KV_HEADS + kv_head_idx)
418 | * max_num_partitions
419 | * QUERY_GROUP_SIZE
420 | * HEAD_SIZE
421 | + tl.arange(0, NUM_PARTITIONS)[:, None, None] * QUERY_GROUP_SIZE * HEAD_SIZE
422 | + tl.arange(0, QUERY_GROUP_SIZE)[None, :, None] * HEAD_SIZE
423 | + tl.arange(0, HEAD_SIZE)[None, None, :]
424 | )
425 | # tmp_out: [NUM_PARTITIONS, QUERY_GROUP_SIZE, HEAD_SIZE]
426 | tmp_out = tl.load(tmp_out_ptr + tmp_out_offset, mask=mask[:, :, None], other=0.0)
427 | # out: [QUERY_GROUP_SIZE, HEAD_SIZE]
428 | out = tl.sum((tmp_out * r).to(tl.float32), axis=0)
429 |
430 | out_offset = (
431 | seq_idx * stride_o0
432 | + kv_head_idx * QUERY_GROUP_SIZE * stride_o1
433 | + group_head_offset * stride_o2
434 | )
435 | tl.store(out_ptr + out_offset, out)
436 |
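# Usage sketch (illustrative; shapes and values below are assumptions, not taken from this file).
# The `attention` function above is exported by the package as `flag_attn.paged_attention`.
#
#   import torch
#   import flag_attn
#   num_seqs, num_kv_heads, group_size, head_size, block_size, num_blocks = 4, 8, 4, 128, 16, 256
#   q = torch.randn(num_seqs, num_kv_heads * group_size, head_size, dtype=torch.float16, device="cuda")
#   k_cache = torch.randn(num_blocks, num_kv_heads, block_size, head_size, dtype=torch.float16, device="cuda")
#   v_cache = torch.randn_like(k_cache)
#   context_lens = torch.full((num_seqs,), 1024, dtype=torch.int32, device="cuda")
#   block_tables = torch.randint(0, num_blocks, (num_seqs, 1024 // block_size), dtype=torch.int32, device="cuda")
#   out = flag_attn.paged_attention(q, k_cache, v_cache, context_lens, block_tables,
#                                   head_size ** -0.5, max_context_len=1024)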
--------------------------------------------------------------------------------
/src/flag_attn/piecewise.py:
--------------------------------------------------------------------------------
1 | """
2 | Piecewise Attention
3 | ====================
4 |
5 | This is an extension of the Flash Attention v2 algorithm from Tri Dao
6 | (https://tridao.me/publications/flash2/flash2.pdf) that performs piecewise computation
7 | of the attention scores (the scores to which softmax is applied). This design originates from
8 | the need to make better predictions when the predicted sequence is longer than the sequences
9 | in the training set.
10 |
11 | It takes two q's and two k's as inputs. The attention score is the dot product
12 | of (q1, k1) or (q2, k2), depending on whether the distance between q & k exceeds a threshold.
13 |
14 | The code is adapted from triton's [tutorial](https://github.com/openai/triton/blob/5162871c6cae01a8508a309cf21a8e6b68a4c091/python/tutorials/06-fused-attention.py).
15 | """
16 |
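# Usage sketch (illustrative; shapes and values below are assumptions, not taken from this file).
# The `attention` function defined below is exported as `flag_attn.piecewise_attention`;
# `w` is the distance threshold that decides whether (q1, k1) or (q2, k2) produces the score.
#
#   import torch
#   import flag_attn
#   B, H, M, N, D, w = 2, 4, 1024, 1024, 64, 512
#   q1 = torch.randn(B, H, M, D, dtype=torch.float16, device="cuda")
#   q2 = torch.randn(B, H, M, D, dtype=torch.float16, device="cuda")
#   k1 = torch.randn(B, H, N, D, dtype=torch.float16, device="cuda")
#   k2 = torch.randn(B, H, N, D, dtype=torch.float16, device="cuda")
#   v  = torch.randn(B, H, N, D, dtype=torch.float16, device="cuda")
#   o  = flag_attn.piecewise_attention(q1, k1, q2, k2, v, w, causal=True)  # (B, H, M, D)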
17 | import math
18 | import torch
19 | import triton
20 | import triton.language as tl
21 |
22 | __all__ = ["attention"]
23 |
24 | # --------------------------- public API ---------------------------
25 | class PiecewiseAttention(torch.autograd.Function):
26 | @staticmethod
27 | def forward(ctx, q1, k1, q2, k2, v, w, causal, sm_scale):
28 | # shape constraints
29 | Dq1, Dk1, Dq2, Dk2, Dv = q1.shape[-1], k1.shape[-1], q2.shape[-1], k2.shape[-1], v.shape[-1]
30 | assert Dq1 == Dk1 == Dq2 == Dk2 == Dv
31 | assert Dk1 in {16, 32, 64, 128}
32 |
33 | B, H, M, D = q1.shape
34 | N = k1.shape[2]
35 | P_SEQ = N - M
36 | larger_m = M > N
37 |
38 | if sm_scale is None:
39 | sm_scale = 1. / math.sqrt(D)
40 |
41 | # to work around https://github.com/openai/triton/issues/2441
42 | device = torch.cuda.device_of(q1)
43 | with torch.cuda.device(device):
44 | config = get_fwd_config(B, H, M, N, D, causal)
45 | BLOCK_M, BLOCK_N, num_stages, num_warps = config
46 |
47 | divisible_m = M % BLOCK_M == 0
48 | divisible_n = N % BLOCK_N == 0
49 |
50 | grid = (triton.cdiv(M, BLOCK_M), H, B)
51 | o = torch.empty_like(q1)
52 | L = torch.empty((B, H, M), device=q1.device, dtype=torch.float32)
53 |
54 | _fwd_kernel[grid](
55 | q1, k1, q2, k2, v, sm_scale,
56 | L,
57 | o,
58 | q1.stride(0), q1.stride(1), q1.stride(2), q1.stride(3),
59 | k1.stride(0), k1.stride(1), k1.stride(2), k1.stride(3),
60 | q2.stride(0), q2.stride(1), q2.stride(2), q2.stride(3),
61 | k2.stride(0), k2.stride(1), k2.stride(2), k2.stride(3),
62 | v.stride(0), v.stride(1), v.stride(2), v.stride(3),
63 | o.stride(0), o.stride(1), o.stride(2), o.stride(3),
64 | B, H, M, N, P_SEQ,
65 | w=w, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=D,
66 | IS_CAUSAL=causal, LARGER_M=larger_m,
67 | DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n,
68 | num_warps=num_warps, num_stages=num_stages,
69 | )
70 |
71 | ctx.save_for_backward(q1, k1, q2, k2, v, o, L)
72 | ctx.sm_scale = sm_scale
73 | ctx.causal = causal
74 | ctx.w = w
75 | return o
76 |
77 | @staticmethod
78 | def backward(ctx, do):
79 | q1, k1, q2, k2, v, o, L = ctx.saved_tensors
80 | w = ctx.w
81 | causal = ctx.causal
82 | sm_scale = ctx.sm_scale
83 |
84 | B, H, M, D = q1.shape
85 | N = k1.shape[2]
86 | P_SEQ = N - M
87 | larger_m = M > N
88 |
89 | if sm_scale is None:
90 | sm_scale = 1. / math.sqrt(D)
91 |
92 | # to work around https://github.com/openai/triton/issues/2441
93 | device = torch.cuda.device_of(q1)
94 | with torch.cuda.device(device):
95 | config = get_bwd_config(B, H, M, N, D, causal)
96 | BLOCK_M, BLOCK_N, num_stages, num_warps = config
97 |
98 | divisible_m = M % BLOCK_M == 0
99 | divisible_n = N % BLOCK_N == 0
100 |
101 | delta = torch.empty((B, H, M), device=q1.device, dtype=torch.float32)
102 | grid = (triton.cdiv(M, BLOCK_M), H, B)
103 | _bwd_preprocess[grid](
104 | o, do,
105 | delta,
106 | o.stride(0), o.stride(1), o.stride(2), o.stride(3),
107 | do.stride(0), do.stride(1), do.stride(2), do.stride(3),
108 | delta.stride(0), delta.stride(1), delta.stride(2),
109 | M,
110 | BLOCK_M=BLOCK_M, D_HEAD=D,
111 | DIVISIBLE_M=divisible_m,
112 | )
113 |
114 | dk1 = torch.empty_like(k1)
115 | dk2 = torch.empty_like(k2)
116 | dv = torch.empty_like(v)
117 | grid = (triton.cdiv(N, BLOCK_N), H, B)
118 | _bwd_kv_kernel[grid](
119 | q1, k1, q2, k2, v, sm_scale, do,
120 | dk1,dk2, dv,
121 | L, delta,
122 | q1.stride(0), q1.stride(1), q1.stride(2), q1.stride(3),
123 | k1.stride(0), k1.stride(1), k1.stride(2), k1.stride(3),
124 | q2.stride(0), q2.stride(1), q2.stride(2), q2.stride(3),
125 | k2.stride(0), k2.stride(1), k2.stride(2), k2.stride(3),
126 | v.stride(0), v.stride(1), v.stride(2), v.stride(3),
127 | do.stride(0), do.stride(1), do.stride(2), do.stride(3),
128 | dk1.stride(0), dk1.stride(1), dk1.stride(2), dk1.stride(3),
129 | dk2.stride(0), dk2.stride(1), dk2.stride(2), dk2.stride(3),
130 | dv.stride(0), dv.stride(1), dv.stride(2), dv.stride(3),
131 | B, H, M, N, P_SEQ,
132 | w=w,
133 | BLOCK_M=BLOCK_M, BLOCK_DMODEL=D,
134 | BLOCK_N=BLOCK_N,
135 | CAUSAL=causal,
136 | DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n,
137 | num_stages=num_stages,
138 | num_warps=num_warps,
139 | )
140 |
141 | dq1 = torch.zeros_like(q1)
142 | dq2 = torch.zeros_like(q2)
143 | grid = (triton.cdiv(M, BLOCK_M), H, B)
144 | _bwd_q_kernel[grid](
145 | q1, k1, q2, k2, v, sm_scale, do,
146 | dq1, dq2,
147 | L, delta,
148 | q1.stride(0), q1.stride(1), q1.stride(2), q1.stride(3),
149 | k1.stride(0), k1.stride(1), k1.stride(2), k1.stride(3),
150 | q2.stride(0), q2.stride(1), q2.stride(2), q2.stride(3),
151 | k2.stride(0), k2.stride(1), k2.stride(2), k2.stride(3),
152 | v.stride(0), v.stride(1), v.stride(2), v.stride(3),
153 | do.stride(0), do.stride(1), do.stride(2), do.stride(3),
154 | dq1.stride(0), dq1.stride(1), dq1.stride(2), dq1.stride(3),
155 | dq2.stride(0), dq2.stride(1), dq2.stride(2), dq2.stride(3),
156 | B, H, M, N, P_SEQ,
157 | w=w,
158 | BLOCK_M=BLOCK_M, BLOCK_DMODEL=D,
159 | BLOCK_N=BLOCK_N,
160 | CAUSAL=causal, LARGER_M=larger_m,
161 | DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n,
162 | num_stages=num_stages,
163 | num_warps=num_warps,
164 | )
165 |
166 | return dq1, dk1, dq2, dk2, dv, None, None, None
167 |
168 |
169 | def attention(q1, k1, q2, k2, v, dist_threshold, causal=False, sm_scale=None):
170 | """
171 | PiecewiseAttention
172 |
173 | Piecewise attention deviates from standard scaled dot product attention in that it takes
174 | two q's and two k's as inputs. The attention score is the dot product
175 | of (q1, k1) or (q2, k2) depending on whether the distance between q & k
176 | exceeds a threshold.
177 |
178 | Arguments:
179 | q1(torch.Tensor): The first queries. The shape is (batch_size, nheads, seqlen_q, headdim).
180 | k1(torch.Tensor): The first keys. The shape is (batch_size, nheads, seqlen_k, headdim).
181 | q2(torch.Tensor): The second queries. The shape is (batch_size, nheads, seqlen_q, headdim).
182 | k2(torch.Tensor): The second keys. The shape is (batch_size, nheads, seqlen_k, headdim).
183 | v(torch.Tensor): The values. The shape is (batch_size, nheads, seqlen_k, headdim).
184 | dist_threshold(int): The distance threshold between q and k. When the distance from q to k is less than dist_threshold, the attention score is dot(q1, k1); otherwise it is dot(q2, k2).
185 | causal(bool): Whether causal masking is applied to attention scores before applying softmax.
186 | sm_scale(float): The scaling of attention scores before applying softmax.
187 |
188 | Returns:
189 | out(torch.Tensor): The output. The shape is (batch_size, nheads, seqlen_q, headdim).
190 | """
191 | return PiecewiseAttention.apply(q1, k1, q2, k2, v, dist_threshold, causal, sm_scale)
192 |
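A minimal usage sketch of this entry point follows (shapes, dtypes, and the threshold value are illustrative; inputs are expected to be 4-D CUDA tensors with head dimension in {16, 32, 64, 128}).

```python
import torch
from flag_attn import piecewise_attention

B, H, M, N, D = 2, 8, 1024, 1024, 64
q1, q2 = (torch.randn(B, H, M, D, device="cuda", dtype=torch.float16, requires_grad=True)
          for _ in range(2))
k1, k2, v = (torch.randn(B, H, N, D, device="cuda", dtype=torch.float16, requires_grad=True)
             for _ in range(3))

# switch from (q1, k1) to (q2, k2) once the q-k distance reaches 512
o = piecewise_attention(q1, k1, q2, k2, v, 512, causal=True)
o.backward(torch.randn_like(o))  # populates q1.grad, k1.grad, q2.grad, k2.grad, v.grad
```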
193 | # --------------------------- Forward ---------------------------
194 | def get_fwd_config(B, H, M, N, D, causal):
195 | # A100
196 | if torch.cuda.get_device_capability() == (8, 0):
197 | if not causal:
198 | if D <= 64:
199 | BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 3, 4
200 | else:
201 | BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 3, 8
202 | else:
203 | if D <= 64:
204 | BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 3, 4
205 | else:
206 | BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 4, 8
207 | # RTX-3090, ...
208 | elif torch.cuda.get_device_capability() == (8, 6):
209 | if not causal:
210 | if D <= 64:
211 | BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 32, 3, 4
212 | else:
213 | BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 3, 8
214 | else:
215 | if D <= 64:
216 | BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 32, 3, 4
217 | else:
218 | BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 2, 8
219 | else:
220 | BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 32, 1, 4
221 | return BLOCK_M, BLOCK_N, num_stages, num_warps
222 |
223 | @triton.jit
224 | def _fwd_kernel(
225 | Q1, K1, Q2, K2, V, sm_scale,
226 | L,
227 | O,
228 | stride_q1z, stride_q1h, stride_q1m, stride_q1k,
229 | stride_k1z, stride_k1h, stride_k1n, stride_k1k,
230 | stride_q2z, stride_q2h, stride_q2m, stride_q2k,
231 | stride_k2z, stride_k2h, stride_k2n, stride_k2k,
232 | stride_vz, stride_vh, stride_vn, stride_vk,
233 | stride_oz, stride_oh, stride_om, stride_ok,
234 | Z, H, M, N, P_SEQ,
235 | w: tl.constexpr,
236 | BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
237 | BLOCK_N: tl.constexpr,
238 | IS_CAUSAL: tl.constexpr, LARGER_M: tl.constexpr,
239 | DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr,
240 | ):
241 | input_dtype = Q1.dtype.element_ty
242 | # -- grid id --
243 | start_m = tl.program_id(0)
244 | off_h = tl.program_id(1)
245 | off_z = tl.program_id(2)
246 |
247 | # scale sm_scale by log_2(e) and use
248 | # 2^x instead of exp in the loop because CSE and LICM
249 | # don't work as expected with `exp` in the loop
250 | log2e: tl.constexpr = 1.4426950408889634
251 | qk_scale = sm_scale * log2e
252 |
253 | # offset pointers for (batch, head)
254 | Q1 += off_z * stride_q1z + off_h * stride_q1h
255 | Q2 += off_z * stride_q2z + off_h * stride_q2h
256 | K1 += off_z * stride_k1z + off_h * stride_k1h
257 | K2 += off_z * stride_k2z + off_h * stride_k2h
258 | V += off_z * stride_vz + off_h * stride_vh
259 | O += off_z * stride_oz + off_h * stride_oh
260 | L += (off_z * H + off_h) * M # L: shape(B, H, N_CTX), C-contiguous
261 |
262 | offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
263 | offs_n_base = tl.arange(0, BLOCK_N)
264 | offs_n_init = offs_n_base
265 | offs_k = tl.arange(0, BLOCK_DMODEL)
266 |
267 | # initialize pointers to value-like data
268 | q1_ptrs = Q1 + (offs_m[:, None] * stride_q1m + offs_k[None, :] * stride_q1k) # (BLOCK_M, BLOCK_DMODEL)
269 | q2_ptrs = Q2 + (offs_m[:, None] * stride_q2m + offs_k[None, :] * stride_q2k) # (BLOCK_M, BLOCK_DMODEL)
270 | k1_ptrs = K1 + (offs_n_init[:, None] * stride_k1n + offs_k[None, :] * stride_k1k) # (BLOCK_N, BLOCK_DMODEL)
271 | k2_ptrs = K2 + (offs_n_init[:, None] * stride_k2n + offs_k[None, :] * stride_k2k) # (BLOCK_N, BLOCK_DMODEL)
272 | v_ptrs = V + (offs_n_init[:, None] * stride_vn + offs_k[None, :] * stride_vk) # (BLOCK_N, BLOCK_DMODEL)
273 | o_ptrs = O + (offs_m[:, None] * stride_om + offs_k[None, :] * stride_ok) # (BLOCK_M, BLOCK_DMODEL)
274 | l_ptrs = L + offs_m
275 |
276 | # initialize pointer to m and l, fp32 for accumulators
277 | m_i = tl.full([BLOCK_M], value=-float("inf"), dtype=tl.float32)
278 | l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
279 | acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
280 |
281 | # load q: it will stay in SRAM throughout
282 | if DIVISIBLE_M:
283 | q1 = tl.load(q1_ptrs)
284 | q2 = tl.load(q2_ptrs)
285 | else:
286 | mask_m = offs_m < M
287 | q1 = tl.load(q1_ptrs, mask=mask_m[:, None])
288 | q2 = tl.load(q2_ptrs, mask=mask_m[:, None])
289 |
290 | # Dot-I trick: it converts q1, q2 into mma layout and saves shared memory
291 | # a better way to generate an eye matrix, avoiding a cast from bool
292 | I = tl.where(offs_k[:, None] == offs_k,
293 | tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 1.0, dtype=input_dtype),
294 | tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 0.0, dtype=input_dtype))
295 | q1 = tl.dot(q1, I).to(input_dtype)
296 | q2 = tl.dot(q2, I).to(input_dtype)
297 |
298 | # loop over k, v and update accumulator
299 | # see note "Loop-Bound-For-N"
300 | if IS_CAUSAL:
301 | hi = tl.minimum(N, P_SEQ + (start_m + 1) * BLOCK_M)
302 | if LARGER_M:
303 | hi = tl.maximum(0, hi)
304 | else:
305 | hi = N
306 |
307 | for start_n in range(0, hi, BLOCK_N):
308 | # -- offsets & masking --
309 | start_n = tl.multiple_of(start_n, BLOCK_N)
310 | offs_n = start_n + offs_n_base
311 | piecewise_mask = (P_SEQ + offs_m[:, None]) >= (offs_n[None, :] + w)
312 |
313 | # -- load k, v --
314 | if DIVISIBLE_N:
315 | k1 = tl.load(k1_ptrs)
316 | k2 = tl.load(k2_ptrs)
317 | v = tl.load(v_ptrs)
318 | else:
319 | mask_n = offs_n < N
320 | k1 = tl.load(k1_ptrs, mask=mask_n[:, None])
321 | k2 = tl.load(k2_ptrs, mask=mask_n[:, None])
322 | v = tl.load(v_ptrs, mask=mask_n[:, None])
323 |
324 | # -- compute s = qk ---
325 | s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
326 |
327 | # TODO: more careful masking
328 | s += tl.where(piecewise_mask,
329 | tl.dot(q2, tl.trans(k2)),
330 | tl.dot(q1, tl.trans(k1)))
331 | if not DIVISIBLE_N:
332 | s = tl.where(mask_n, s, float("-inf"))
333 | if IS_CAUSAL:
334 | causal_mask = (P_SEQ + offs_m[:, None]) >= offs_n[None, :]
335 | s = tl.where(causal_mask, s, float("-inf"))
336 |
337 | # -- compute scaling constant ---
338 | # loop l2r, so no extra handling of inf is needed
339 | m_i_new = tl.maximum(m_i, tl.max(s, 1))
340 | alpha = tl.math.exp2((m_i - m_i_new) * qk_scale)
341 | p = tl.math.exp2(s * qk_scale - m_i_new[:, None] * qk_scale)
342 |
343 | # -- scale and update acc --
344 | acc *= alpha[:, None]
345 | acc += tl.dot(p.to(input_dtype), v)
346 |
347 | # -- update m_i and l_i --
348 | l_i = l_i * alpha + tl.sum(p, 1)
349 | m_i = m_i_new
350 |
351 | # update pointers
352 | k1_ptrs += BLOCK_N * stride_k1n
353 | k2_ptrs += BLOCK_N * stride_k2n
354 | v_ptrs += BLOCK_N * stride_vn
355 |
356 | # write back l & o
357 | if IS_CAUSAL and LARGER_M:
358 | is_empty_line = (offs_m + P_SEQ) < 0
359 | acc = tl.where(is_empty_line[:, None], 0.0, acc * (1.0 / l_i[:, None]))
360 | l_i = tl.where(is_empty_line, float("-inf"), m_i * sm_scale + tl.log(l_i))
361 | else:
362 | acc = acc * (1.0 / l_i[:, None])
363 | l_i = m_i * sm_scale + tl.log(l_i)
364 |
365 | if DIVISIBLE_M:
366 | tl.store(l_ptrs, l_i)
367 | tl.store(o_ptrs, acc.to(input_dtype))
368 | else:
369 | tl.store(l_ptrs, l_i, mask=mask_m)
370 | tl.store(o_ptrs, acc.to(input_dtype), mask=mask_m[:, None])
371 |
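For reference, the running-max/rescale update in the loop above is the standard online-softmax recurrence. The tiny PyTorch sketch below (illustrative only; one query row, unblocked along the head dimension, using `exp` with `sm_scale` where the kernel uses `exp2` with `qk_scale`) shows how the accumulator, the normalizer `l_i`, and the stored log-normalizer `L` relate; `L` is exactly what the backward kernels use to recompute `p`.

```python
import torch

def online_softmax_row(s_blocks, v_blocks, sm_scale):
    # s_blocks: list of (BLOCK_N,) unscaled score tiles; v_blocks: list of (BLOCK_N, D) value tiles
    m = torch.tensor(float("-inf"))
    l = torch.tensor(0.0)
    acc = torch.zeros(v_blocks[0].shape[-1])
    for s, v in zip(s_blocks, v_blocks):
        m_new = torch.maximum(m, s.max())
        alpha = torch.exp((m - m_new) * sm_scale)   # rescale factor for the old partial sums
        p = torch.exp((s - m_new) * sm_scale)       # unnormalized probabilities for this tile
        acc = acc * alpha + p @ v
        l = l * alpha + p.sum()
        m = m_new
    out = acc / l                                   # normalized attention output for this row
    log_normalizer = m * sm_scale + torch.log(l)    # the value written to L
    return out, log_normalizer
```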
372 | # --------------------------- Backward ---------------------------
373 | def get_bwd_config(B, H, M, N, D, causal):
374 | # A100
375 | if torch.cuda.get_device_capability() == (8, 0):
376 | if not causal:
377 | if D <= 64:
378 | BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 4
379 | else:
380 | BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 2, 8
381 | else:
382 | if D <= 64:
383 | BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 3, 4
384 | else:
385 | BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 64, 2, 4
386 |
387 | # BLOCK_M = 64 if D<=64 else 128
388 | # BLOCK_N = 64
389 | # num_stages = 1 if D<=64 else (2 if not causal else 1)
390 | # num_warps = 4 if D <=64 else 8
391 | # RTX-3090, ...
392 | elif torch.cuda.get_device_capability() == (8, 6):
393 | if not causal:
394 | if D <= 64:
395 | BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 32, 2, 4
396 | else:
397 | BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 64, 2, 8
398 | else:
399 | if D <= 64:
400 | BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 32, 2, 4
401 | else:
402 | BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 64, 2, 8
403 | else:
404 | BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 32, 1, 4
405 | return BLOCK_M, BLOCK_N, num_stages, num_warps
406 |
407 | @triton.jit
408 | def _bwd_preprocess(
409 | Out, DO,
410 | Delta,
411 | stride_oz, stride_oh, stride_om, stride_ok,
412 | stride_doz, stride_doh, stride_dom, stride_dok,
413 | stride_dz, stride_dh, stride_dm,
414 | M,
415 | BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
416 | DIVISIBLE_M: tl.constexpr,
417 | ):
418 | off_h = tl.program_id(1)
419 | off_z = tl.program_id(2)
420 | Out += off_z * stride_oz + off_h * stride_oh
421 | DO += off_z * stride_doz + off_h * stride_doh
422 | Delta += off_z * stride_dz + off_h * stride_dh
423 |
424 | # compute delta = rowsum(Out * DO), the per-row dot product of the output and its gradient
425 | off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
426 | off_n = tl.arange(0, D_HEAD)
427 |
428 | # load
429 | o_ptrs = Out + off_m[:, None] * stride_om + off_n[None, :] * stride_ok
430 | do_ptrs = DO + off_m[:, None] * stride_dom + off_n[None, :] * stride_dok
431 |
432 | if DIVISIBLE_M:
433 | o = tl.load(o_ptrs).to(tl.float32)
434 | do = tl.load(do_ptrs).to(tl.float32)
435 | else:
436 | mask_m = off_m < M
437 | o = tl.load(o_ptrs, mask=mask_m[:, None]).to(tl.float32)
438 | do = tl.load(do_ptrs, mask=mask_m[:, None]).to(tl.float32)
439 |
440 | # compute
441 | delta = tl.sum(o * do, axis=1)
442 | # write-back
443 | d_ptrs = Delta + off_m * stride_dm
444 | if DIVISIBLE_M:
445 | tl.store(d_ptrs, delta)
446 | else:
447 | tl.store(d_ptrs, delta, mask=mask_m)
448 |
449 |
450 | @triton.jit
451 | def _bwd_kv_kernel(
452 | Q1, K1, Q2, K2, V, sm_scale, DO,
453 | DK1, DK2, DV,
454 | L,
455 | D,
456 | stride_q1z, stride_q1h, stride_q1m, stride_q1k,
457 | stride_k1z, stride_k1h, stride_k1n, stride_k1k,
458 | stride_q2z, stride_q2h, stride_q2m, stride_q2k,
459 | stride_k2z, stride_k2h, stride_k2n, stride_k2k,
460 | stride_vz, stride_vh, stride_vn, stride_vk,
461 | stride_doz, stride_doh, stride_dom, stride_dok,
462 | stride_dk1z, stride_dk1h, stride_dk1n, stride_dk1k,
463 | stride_dk2z, stride_dk2h, stride_dk2n, stride_dk2k,
464 | stride_dvz, stride_dvh, stride_dvn, stride_dvk,
465 | Z, H, M, N, P_SEQ,
466 | w: tl.constexpr,
467 | BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
468 | BLOCK_N: tl.constexpr,
469 | CAUSAL: tl.constexpr,
470 | DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr,
471 | ):
472 | input_dtype = Q1.dtype.element_ty
473 | # -- grid id --
474 | start_n = tl.program_id(0)
475 | off_h = tl.program_id(1)
476 | off_z = tl.program_id(2)
477 |
478 | log2e: tl.constexpr = 1.4426950408889634
479 | qk_scale = sm_scale * log2e
480 |
481 | # offset pointers for (batch, head)
482 | Q1 += off_z * stride_q1z + off_h * stride_q1h
483 | Q2 += off_z * stride_q2z + off_h * stride_q2h
484 | K1 += off_z * stride_k1z + off_h * stride_k1h
485 | K2 += off_z * stride_k2z + off_h * stride_k2h
486 | V += off_z * stride_vz + off_h * stride_vh
487 | DO += off_z * stride_doz + off_h * stride_doh
488 | D += (off_z * H + off_h) * M
489 | L += (off_z * H + off_h) * M
490 |
491 | # offset pointers for batch/head
492 | DK1 += off_z * stride_dk1z + off_h * stride_dk1h
493 | DK2 += off_z * stride_dk2z + off_h * stride_dk2h
494 | DV += off_z * stride_dvz + off_h * stride_dvh
495 |
496 | if CAUSAL:
497 | lo = tl.maximum(start_n * BLOCK_N - P_SEQ, 0)
498 | lo = (lo // BLOCK_M) * BLOCK_M
499 | else:
500 | lo = 0
501 |
502 | offs_m_init = lo + tl.arange(0, BLOCK_M)
503 | offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
504 | offs_m_base = tl.arange(0, BLOCK_M)
505 | offs_k = tl.arange(0, BLOCK_DMODEL)
506 |
507 |
508 | # initialize pointers to value-like data
509 | q1_ptrs = Q1 + (offs_m_init[:, None] * stride_q1m + offs_k[None, :] * stride_q1k) # (BLOCK_M, BLOCK_DMODEL)
510 | q2_ptrs = Q2 + (offs_m_init[:, None] * stride_q2m + offs_k[None, :] * stride_q2k) # (BLOCK_M, BLOCK_DMODEL)
511 | k1_ptrs = K1 + (offs_k[:, None] * stride_k1k + offs_n[None, :] * stride_k1n) # (BLOCK_DMODEL, BLOCK_N)
512 | k2_ptrs = K2 + (offs_k[:, None] * stride_k2k + offs_n[None, :] * stride_k2n) # (BLOCK_DMODEL, BLOCK_N)
513 | v_ptrs = V + (offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk) # (BLOCK_N, BLOCK_DMODEL)
514 | do_ptrs = DO + (offs_m_init[:, None] * stride_dom + offs_k[None, :] * stride_dok) # (BLOCK_M, BLOCK_DMODEL)
515 |
516 | dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_k[None, :] * stride_dvk) # (BLOCK_N, BLOCK_DMODEL)
517 | dk1_ptrs = DK1 + (offs_n[:, None] * stride_dk1n + offs_k[None, :] * stride_dk1k) # (BLOCK_N, BLOCK_DMODEL)
518 | dk2_ptrs = DK2 + (offs_n[:, None] * stride_dk2n + offs_k[None, :] * stride_dk2k) # (BLOCK_N, BLOCK_DMODEL)
519 |
520 | # k and v stay in SRAM throughout
521 | if DIVISIBLE_N:
522 | k1 = tl.load(k1_ptrs)
523 | k2 = tl.load(k2_ptrs)
524 | v = tl.load(v_ptrs)
525 | else:
526 | mask_n = offs_n < N
527 | k1 = tl.load(k1_ptrs, mask=mask_n[None, :])
528 | k2 = tl.load(k2_ptrs, mask=mask_n[None, :])
529 | v = tl.load(v_ptrs, mask=mask_n[:, None])
530 |
531 | # initialize dk and dv
532 | dk1 = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
533 | dk2 = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
534 | dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
535 |
536 | # loop over a column
537 | for start_m in range(lo, M, BLOCK_M):
538 | start_m = tl.multiple_of(start_m, BLOCK_M)
539 | offs_m = start_m + offs_m_base
540 |
541 | # load q1, k1, q2, k2, v, do on-chip
542 | if DIVISIBLE_M:
543 | q1 = tl.load(q1_ptrs)
544 | q2 = tl.load(q2_ptrs)
545 | do = tl.load(do_ptrs) # (BLOCK_M, BLOCK_DMODEL)
546 | delta = tl.load(D + offs_m)
547 | l = tl.load(L + offs_m)
548 | else:
549 | mask_m = offs_m < M
550 | q1 = tl.load(q1_ptrs, mask=mask_m[:, None])
551 | q2 = tl.load(q2_ptrs, mask=mask_m[:, None])
552 | do = tl.load(do_ptrs, mask=mask_m[:, None]) # (BLOCK_M, BLOCK_DMODEL)
553 | delta = tl.load(D + offs_m, mask=mask_m)
554 | l = tl.load(L + offs_m, mask=mask_m)
555 |
556 | # recompute p = softmax(qk, dim=-1).T
557 | piecewise_mask = (P_SEQ + offs_m[:, None]) >= (offs_n[None, :] + w) # (BLOCK_M, BLOCK_N)
558 | s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
559 | s += tl.where(piecewise_mask,
560 | tl.dot(q2, k2),
561 | tl.dot(q1, k1))
562 |
563 | # NOTE: since the softmax in the backward pass is pointwise, the normalizer has been saved in the forward pass,
564 | # so masking on s is not needed.
565 | # if CAUSAL:
566 | # s = tl.where(causal_mask & valid_mask, s, float("-inf"))
567 | # else:
568 | # s = tl.where(valid_mask, s, float("-inf"))
569 |
570 | # -- recompute p ---
571 | # l = tl.load(L + offs_m, mask=mask_m)
572 | p = tl.math.exp2(s * qk_scale - l[:, None] * log2e) # (BLOCK_M, BLOCK_N)
573 | if not DIVISIBLE_M:
574 | valid_mask = mask_m[:, None] # & mask_n
575 | p = tl.where(valid_mask, p, 0.0)
576 | if CAUSAL:
577 | causal_mask = (P_SEQ + offs_m[:, None]) >= (offs_n[None, :]) # (BLOCK_M, BLOCK_N)
578 | p = tl.where(causal_mask, p, 0.0)
579 |
580 |
581 | # compute dv = dot(p, do)
582 | # do = tl.load(do_ptrs, mask=mask_m[:, None]) # (BLOCK_M, BLOCK_DMODEL)
583 | dv += tl.dot(tl.trans(p.to(do.dtype)), do) # (BLOCK_N, BLOCK_DMODEL)
584 |
585 | # compute dp = dot(v, do)
586 | # delta = tl.load(D + offs_m, mask=mask_m)
587 | dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
588 | dp += tl.dot(do.to(input_dtype), tl.trans(v))
589 | # no need to mask dp
590 | # if CAUSAL:
591 | # dp = tl.where(causal_mask & valid_mask, dp, 0.0)
592 | # else:
593 | # dp = tl.where(valid_mask, dp, 0.0)
594 |
595 | # compute ds = p * (dp - delta[:, None])
596 | # move scale out to dk at last
597 | ds = p * (dp - delta[:, None]) # (BLOCK_M, BLOCK_N)
598 |
599 | # mask ds to ensure no small values
600 | if not DIVISIBLE_M:
601 | ds = tl.where(valid_mask, ds, 0.0)
602 | if CAUSAL:
603 | ds = tl.where(causal_mask, ds, 0.0)
604 |
605 | ds2 = tl.where(piecewise_mask, ds, 0.0).to(input_dtype)
606 | ds1 = tl.where(piecewise_mask, 0.0, ds).to(input_dtype)
607 |
608 | # compute dk = dot(ds.T, q) (ds already masked)
609 | dk1 += tl.dot(tl.trans(ds1), q1)
610 | dk2 += tl.dot(tl.trans(ds2), q2)
611 |
612 | # increment pointers
613 | q1_ptrs += BLOCK_M * stride_q1m
614 | q2_ptrs += BLOCK_M * stride_q2m
615 | do_ptrs += BLOCK_M * stride_dom
616 |
617 | dk1 *= sm_scale
618 | dk2 *= sm_scale
619 |
620 | if DIVISIBLE_N:
621 | tl.store(dk1_ptrs, dk1.to(input_dtype)) # (BLOCK_N, BLOCK_DMODEL)
622 | tl.store(dk2_ptrs, dk2.to(input_dtype)) # (BLOCK_N, BLOCK_DMODEL)
623 | tl.store(dv_ptrs, dv.to(input_dtype)) # (BLOCK_N, BLOCK_DMODEL)
624 | else:
625 | tl.store(dk1_ptrs, dk1.to(input_dtype), mask=mask_n[:, None]) # (BLOCK_N, BLOCK_DMODEL)
626 | tl.store(dk2_ptrs, dk2.to(input_dtype), mask=mask_n[:, None]) # (BLOCK_N, BLOCK_DMODEL)
627 | tl.store(dv_ptrs, dv.to(input_dtype), mask=mask_n[:, None]) # (BLOCK_N, BLOCK_DMODEL)
628 |
629 |
630 | @triton.jit
631 | def _bwd_q_kernel(
632 | Q1, K1, Q2, K2, V, sm_scale, DO,
633 | DQ1, DQ2,
634 | L,
635 | D,
636 | stride_q1z, stride_q1h, stride_q1m, stride_q1k,
637 | stride_k1z, stride_k1h, stride_k1n, stride_k1k,
638 | stride_q2z, stride_q2h, stride_q2m, stride_q2k,
639 | stride_k2z, stride_k2h, stride_k2n, stride_k2k,
640 | stride_vz, stride_vh, stride_vn, stride_vk,
641 | stride_doz, stride_doh, stride_dom, stride_dok,
642 | stride_dq1z, stride_dq1h, stride_dq1m, stride_dq1k,
643 | stride_dq2z, stride_dq2h, stride_dq2m, stride_dq2k,
644 | Z, H, M, N, P_SEQ,
645 | w: tl.constexpr,
646 | BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
647 | BLOCK_N: tl.constexpr,
648 | CAUSAL: tl.constexpr, LARGER_M: tl.constexpr,
649 | DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr,
650 | ):
651 | input_dtype = Q1.dtype.element_ty
652 | # -- grid id --
653 | start_m = tl.program_id(0)
654 | off_h = tl.program_id(1)
655 | off_z = tl.program_id(2)
656 |
657 | # scale sm_scale by log_2(e) and use
658 | # 2^x instead of exp in the loop because CSE and LICM
659 | # don't work as expected with `exp` in the loop
660 | log2e: tl.constexpr = 1.4426950408889634
661 | qk_scale = sm_scale * log2e
662 |
663 | # offset pointers for (batch, head)
664 | Q1 += off_z * stride_q1z + off_h * stride_q1h
665 | Q2 += off_z * stride_q2z + off_h * stride_q2h
666 | K1 += off_z * stride_k1z + off_h * stride_k1h
667 | K2 += off_z * stride_k2z + off_h * stride_k2h
668 | V += off_z * stride_vz + off_h * stride_vh
669 | DO += off_z * stride_doz + off_h * stride_doh
670 | D += (off_z * H + off_h) * M
671 | L += (off_z * H + off_h) * M
672 |
673 | # offset pointers for batch/head
674 | DQ1 += off_z * stride_dq1z + off_h * stride_dq1h
675 | DQ2 += off_z * stride_dq2z + off_h * stride_dq2h
676 |
677 | offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
678 | offs_n_base = tl.arange(0, BLOCK_N)
679 | offs_n_init = offs_n_base
680 | offs_k = tl.arange(0, BLOCK_DMODEL)
681 |
682 | # initialize pointers to value-like data
683 | q1_ptrs = Q1 + (offs_m[:, None] * stride_q1m + offs_k[None, :] * stride_q1k) # (BLOCK_M, BLOCK_DMODEL)
684 | q2_ptrs = Q2 + (offs_m[:, None] * stride_q2m + offs_k[None, :] * stride_q2k) # (BLOCK_M, BLOCK_DMODEL)
685 | k1_ptrs = K1 + (offs_n_init[:, None] * stride_k1n + offs_k[None, :] * stride_k1k) # (BLOCK_N, BLOCK_DMODEL)
686 | k2_ptrs = K2 + (offs_n_init[:, None] * stride_k2n + offs_k[None, :] * stride_k2k) # (BLOCK_N, BLOCK_DMODEL)
687 | v_ptrs = V + (offs_n_init[:, None] * stride_vn + offs_k[None, :] * stride_vk) # (BLOCK_N, BLOCK_DMODEL)
688 |
689 | dq1_ptrs = DQ1 + (offs_m[:, None] * stride_dq1m + offs_k[None, :] * stride_dq1k) # (BLOCK_M, BLOCK_DMODEL)
690 | dq2_ptrs = DQ2 + (offs_m[:, None] * stride_dq2m + offs_k[None, :] * stride_dq2k) # (BLOCK_M, BLOCK_DMODEL)
691 | do_ptrs = DO + (offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok) # (BLOCK_M, BLOCK_DMODEL)
692 |
693 | # pointer to row-wise quantities in value-like data
694 | d_ptrs = D + offs_m
695 | l_ptrs = L + offs_m
696 |
697 | # load q: it will stay in SRAM throughout
698 | if DIVISIBLE_M:
699 | q1 = tl.load(q1_ptrs)
700 | q2 = tl.load(q2_ptrs)
701 | do = tl.load(do_ptrs)
702 | delta = tl.load(d_ptrs)
703 | l = tl.load(l_ptrs)
704 | else:
705 | mask_m = offs_m < M
706 | q1 = tl.load(q1_ptrs, mask=mask_m[:, None])
707 | q2 = tl.load(q2_ptrs, mask=mask_m[:, None])
708 | do = tl.load(do_ptrs, mask=mask_m[:, None])
709 | delta = tl.load(d_ptrs, mask=mask_m)
710 | l = tl.load(l_ptrs, mask=mask_m)
711 |
712 | # initialize dq
713 | dq1 = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
714 | dq2 = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
715 |
716 | # loop over k, v and update accumulator
717 | # see note "Loop-Bound-For-N"
718 | if CAUSAL:
719 | hi = tl.minimum(N, P_SEQ + (start_m + 1) * BLOCK_M)
720 | if LARGER_M:
721 | hi = tl.maximum(0, hi)
722 | else:
723 | hi = N
724 |
725 | # loop over a row
726 | for start_n in range(0, hi, BLOCK_N):
727 | offs_n = start_n + offs_n_base
728 |
729 | # load k1, k2, v on chip
730 | if DIVISIBLE_N:
731 | v = tl.load(v_ptrs)
732 | k1 = tl.load(k1_ptrs)
733 | k2 = tl.load(k2_ptrs)
734 | else:
735 | mask_n = offs_n < N
736 | v = tl.load(v_ptrs, mask=mask_n[:, None])
737 | k1 = tl.load(k1_ptrs, mask=mask_n[:, None])
738 | k2 = tl.load(k2_ptrs, mask=mask_n[:, None])
739 |
740 | # recompute p = softmax(qk * sm_scale, dim=-1)
741 | piecewise_mask = (P_SEQ + offs_m[:, None]) >= (offs_n[None, :] + w) # (BLOCK_M, BLOCK_N)
742 | s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
743 | s += tl.where(piecewise_mask,
744 | tl.dot(q2, tl.trans(k2)),
745 | tl.dot(q1, tl.trans(k1)))
746 | # NOTE: since the softmax in the backward pass is pointwise, the normalizer has been saved in the forward pass,
747 | # so masking on s is not needed.
748 | # if CAUSAL:
749 | # s = tl.where(causal_mask & valid_mask, s, float("-inf"))
750 | # else:
751 | # s = tl.where(valid_mask, s, float("-inf"))
752 | p = tl.math.exp2(s * qk_scale - l[:, None] * log2e) # (BLOCK_M, BLOCK_N)
753 |
754 | # compute dp = dot(v, do)
755 | dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
756 | dp += tl.dot(do.to(input_dtype), tl.trans(v))
757 | # no need to mask dp
758 | # if CAUSAL:
759 | # dp = tl.where(causal_mask & valid_mask, dp, 0.0)
760 | # else:
761 | # dp = tl.where(valid_mask, dp, 0.0)
762 |
763 | # compute ds = p * (dp - delta[:, None])
764 | # move scale out to dq at last
765 | ds = p * (dp - delta[:, None]) # (BLOCK_M, BLOCK_N)
766 |
767 | # mask ds to ensure no small values
768 | if not DIVISIBLE_N:
769 | ds = tl.where(mask_n, ds, 0.0)
770 | if CAUSAL:
771 | causal_mask = (P_SEQ + offs_m[:, None]) >= (offs_n[None, :]) # (BLOCK_M, BLOCK_N)
772 | ds = tl.where(causal_mask, ds, 0.0)
773 |
774 | ds2 = tl.where(piecewise_mask, ds, 0.0).to(input_dtype)
775 | ds1 = tl.where(piecewise_mask, 0.0, ds).to(input_dtype)
776 |
777 | dq1 += tl.dot(ds1, k1)
778 | dq2 += tl.dot(ds2, k2)
779 |
780 | # increment pointers
781 | k1_ptrs += BLOCK_N * stride_k1n
782 | k2_ptrs += BLOCK_N * stride_k2n
783 | v_ptrs += BLOCK_N * stride_vn
784 |
785 | dq1 *= sm_scale
786 | dq2 *= sm_scale
787 | if DIVISIBLE_M:
788 | tl.store(dq1_ptrs, dq1.to(input_dtype))
789 | tl.store(dq2_ptrs, dq2.to(input_dtype))
790 | else:
791 | tl.store(dq1_ptrs, dq1.to(input_dtype), mask=mask_m[:, None])
792 | tl.store(dq2_ptrs, dq2.to(input_dtype), mask=mask_m[:, None])
793 |
--------------------------------------------------------------------------------
/src/flag_attn/flash.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import triton
4 | import triton.language as tl
5 | from flag_attn.total import _total_attention_kernel
6 | from flag_attn.split_kv import _fwd_split_kv_kernel, _fwd_combine_kv_splits, num_splits_herustic
7 | from flag_attn.split_kv import get_fwd_config as get_fwd_config_kv_split
8 |
9 | from .dropout import philox_cuda_seed_offset
10 |
11 | __all__ = ["attention"]
12 |
13 |
14 | def maybe_contiguous(x):
15 | # only when the innermost dimension is contiguous can LDGSTS be used
16 | # so inner-dimension contiguity is enforced.
17 | return x.contiguous() if x.stride(-1) != 1 else x
18 |
19 | def rounded_multiple(a, b):
20 | return (a + b - 1) // b * b
21 |
22 | # --------------------------- public API ---------------------------
23 | class FlashAttention(torch.autograd.Function):
24 | @staticmethod
25 | def forward(ctx, q, k, v, causal, sm_scale, dropout_p, return_log_normalizer, return_total_attention, return_seed_offset):
26 | Dq, Dk, Dv = q.shape[-1], k.shape[-1], v.shape[-1]
27 | assert Dq == Dk == Dv, "feature size of q, k, v should be equal"
28 | assert Dk in {16, 32, 64, 128}
29 |
30 | B, H, M, D = q.shape
31 | N = k.shape[2]
32 | Hk, Hv = k.shape[1], v.shape[1]
33 | assert Hk == Hv, "num of heads in k and v should be equal"
34 | assert H % Hk == 0, "number of heads in q must be a multiple of that in k & v"
35 | num_groups = H // Hk
36 |
37 | P_SEQ = N - M
38 | larger_m = M > N
39 |
40 | if sm_scale is None:
41 | sm_scale = 1. / math.sqrt(D)
42 |
43 | # contiguity
44 | q, k, v = maybe_contiguous(q), maybe_contiguous(k), maybe_contiguous(v)
45 |
46 | # to work around https://github.com/openai/triton/issues/2441
47 | device = torch.cuda.device_of(q)
48 | num_sms = torch.cuda.get_device_properties(device).multi_processor_count
49 |
50 | with torch.cuda.device(device):
51 | # Dropout preparation.
52 | is_dropout = dropout_p > 0
53 | if is_dropout:
54 | offset_increment = B * H * M * N
55 | seed, offset = philox_cuda_seed_offset(offset_increment)
56 | else:
57 | seed, offset = 0, 0
58 |
59 | config_for_split_kv = get_fwd_config_kv_split(B, H, M, N, D, causal)
60 | S = num_splits_herustic(B, H, M, N, config_for_split_kv[0], config_for_split_kv[1], num_sms, 128)
61 | split_kv: bool = S > 1
62 | # print(f"flag_attn choose {S} splits")
63 |
64 | if not split_kv:
65 | config = get_fwd_config(B, H, M, N, D, causal)
66 | BLOCK_M, BLOCK_N, num_stages, num_warps = config
67 |
68 | divisible_m = M % BLOCK_M == 0
69 | divisible_n = N % BLOCK_N == 0
70 | # consider using 3d grid to avoid div & rem
71 | grid = (triton.cdiv(M, BLOCK_M), H, B)
72 | o = torch.empty_like(q)
73 | L = torch.empty((B, H, M), device=q.device, dtype=torch.float32)
74 | _fwd_kernel[grid](
75 | q, k, v, sm_scale,
76 | dropout_p, seed, offset,
77 | L, o,
78 | q.stride(0), q.stride(1), q.stride(2), q.stride(3),
79 | k.stride(0), k.stride(1), k.stride(2), k.stride(3),
80 | v.stride(0), v.stride(1), v.stride(2), v.stride(3),
81 | o.stride(0), o.stride(1), o.stride(2), o.stride(3),
82 | B, H, M, N, P_SEQ, num_groups,
83 | BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=D,
84 | IS_CAUSAL=causal, IS_DROPOUT=is_dropout, LARGER_M=larger_m,
85 | DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n,
86 | num_warps=num_warps, num_stages=num_stages,
87 | )
88 | else: # split kv
89 | assert not is_dropout, "Cannot apply dropout with splitkv."
90 | BLOCK_M, BLOCK_N, num_stages, num_warps = config_for_split_kv
91 |
92 | divisible_m = M % BLOCK_M == 0
93 | divisible_n = N % BLOCK_N == 0
94 | # consider using 3d grid to avoid div & rem
95 | multiple_l = torch.empty((B, H, S, M), dtype=torch.float32, device="cuda")
96 | multiple_o = torch.empty((B, H, S, M, D), dtype=torch.float16, device="cuda")
97 | grid = (triton.cdiv(M, BLOCK_M), S, H * B)
98 | N_SPLIT_SIZE = triton.cdiv(triton.cdiv(N, BLOCK_N), S) * BLOCK_N
99 | _fwd_split_kv_kernel[grid](
100 | q, k, v, sm_scale,
101 | multiple_l, multiple_o,
102 | q.stride(0), q.stride(1), q.stride(2), q.stride(3),
103 | k.stride(0), k.stride(1), k.stride(2), k.stride(3),
104 | v.stride(0), v.stride(1), v.stride(2), v.stride(3),
105 | multiple_o.stride(0), multiple_o.stride(1), multiple_o.stride(2), multiple_o.stride(3), multiple_o.stride(4),
106 | B, H, M, N, P_SEQ, N_SPLIT_SIZE, S, num_groups,
107 | BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N,
108 | IS_CAUSAL=causal, LARGER_M=larger_m,
109 | DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n,
110 | num_stages=num_stages, num_warps=num_warps,
111 | )
112 |
113 | L = torch.empty((B, H, M), dtype=torch.float32, device="cuda")
114 | o = torch.empty_like(q)
115 | grid = (triton.cdiv(M, BLOCK_M), H, B)
116 | _fwd_combine_kv_splits[grid](
117 | multiple_o, multiple_l,
118 | o, L,
119 | multiple_o.stride(0), multiple_o.stride(1), multiple_o.stride(2), multiple_o.stride(3), multiple_o.stride(4),
120 | o.stride(0), o.stride(1), o.stride(2), o.stride(3),
121 | B, H, M, S,
122 | BLOCK_M=BLOCK_M, BLOCK_DMODEL=D,
123 | DIVISIBLE_M=divisible_m,
124 | num_stages=num_stages, num_warps=num_warps,
125 | )
126 |
127 | # total attention
128 | if return_total_attention:
129 | tot_attn = torch.empty((B, H, N), device=q.device, dtype=torch.float32)
130 | grid = (triton.cdiv(N, BLOCK_N), H, B)
131 | _total_attention_kernel[grid](
132 | q, k, L, tot_attn, sm_scale,
133 | q.stride(0), q.stride(1), q.stride(2), q.stride(3),
134 | k.stride(0), k.stride(1), k.stride(2), k.stride(3),
135 | B, H, M, N, P_SEQ, num_groups,
136 | BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N,
137 | CAUSAL=causal,
138 | DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n,
139 | num_stages=num_stages, num_warps=num_warps,
140 | )
141 |
142 | # autograd context maintenance
143 | ctx.save_for_backward(q, k, v, o, L)
144 | ctx.sm_scale = sm_scale
145 | ctx.causal = causal
146 | ctx.dropout_p = dropout_p
147 | ctx.seed = seed
148 | ctx.offset = offset
149 |
150 | has_extra_return = True in (return_log_normalizer, return_total_attention, return_seed_offset)
151 | if has_extra_return:
152 | outs = (
153 | o,
154 | L if return_log_normalizer else None,
155 | tot_attn if return_total_attention else None,
156 | seed if is_dropout and return_seed_offset else None,
157 | offset if is_dropout and return_seed_offset else None
158 | )
159 | return outs
160 | return o
161 |
162 | @staticmethod
163 | def backward(ctx, do, *ignored):
164 | q, k, v, o, L = ctx.saved_tensors
165 | sm_scale = ctx.sm_scale
166 | causal = ctx.causal
167 | dropout_p = ctx.dropout_p
168 | is_dropout = ctx.dropout_p > 0
169 | seed = ctx.seed
170 | offset = ctx.offset
171 |
172 | B, H, M, D = q.shape
173 | N = k.shape[2]
174 | Hk = k.shape[1]
175 | num_groups = H // Hk
176 | P_SEQ = N - M
177 | larger_m = M > N
178 |
179 | if sm_scale is None:
180 | sm_scale = 1. / math.sqrt(D)
181 |
182 | # to work around https://github.com/openai/triton/issues/2441
183 | device = torch.cuda.device_of(q)
184 | with torch.cuda.device(device):
185 | config = get_bwd_config(B, H, M, N, D, causal)
186 | BLOCK_M, BLOCK_N, num_stages, num_warps = config
187 |
188 | divisible_m = M % BLOCK_M == 0
189 | divisible_n = N % BLOCK_N == 0
190 |
191 | delta = torch.empty_like(L)
192 | grid = (triton.cdiv(M, BLOCK_M), H, B)
193 | _bwd_preprocess[grid](
194 | o, do,
195 | delta,
196 | o.stride(0), o.stride(1), o.stride(2), o.stride(3),
197 | do.stride(0), do.stride(1), do.stride(2), do.stride(3),
198 | delta.stride(0), delta.stride(1), delta.stride(2),
199 | M,
200 | BLOCK_M=BLOCK_M, D_HEAD=D,
201 | DIVISIBLE_M=divisible_m,
202 | )
203 |
204 | # NOTE that dk & dv always have the same number of heads as q, instead of k & v.
205 | dk = torch.empty((B, H, N, D), dtype=k.dtype, device=q.device)
206 | dv = torch.empty((B, H, N, D), dtype=v.dtype, device=q.device)
207 | grid = (triton.cdiv(N, BLOCK_N), H, B)
208 | _bwd_kv_kernel[grid](
209 | q, k, v, sm_scale, do,
210 | dk, dv,
211 | L, delta,
212 | dropout_p,
213 | seed,
214 | offset,
215 | q.stride(0), q.stride(1), q.stride(2), q.stride(3),
216 | k.stride(0), k.stride(1), k.stride(2), k.stride(3),
217 | v.stride(0), v.stride(1), v.stride(2), v.stride(3),
218 | do.stride(0), do.stride(1), do.stride(2), do.stride(3),
219 | dk.stride(0), dk.stride(1), dk.stride(2), dk.stride(3),
220 | dv.stride(0), dv.stride(1), dv.stride(2), dv.stride(3),
221 | B, H, M, N, P_SEQ,
222 | num_groups,
223 | BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N, CAUSAL=causal,
224 | IS_DROPOUT=is_dropout,
225 | DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n,
226 | num_stages=num_stages, num_warps=num_warps,
227 | )
228 |
229 | dq = torch.zeros_like(q)
230 | grid = (triton.cdiv(M, BLOCK_M), H, B)
231 | _bwd_q_kernel[grid](
232 | q, k, v, sm_scale, do,
233 | dq,
234 | L, delta,
235 | dropout_p,
236 | seed,
237 | offset,
238 | q.stride(0), q.stride(1), q.stride(2), q.stride(3),
239 | k.stride(0), k.stride(1), k.stride(2), k.stride(3),
240 | v.stride(0), v.stride(1), v.stride(2), v.stride(3),
241 | do.stride(0), do.stride(1), do.stride(2), do.stride(3),
242 | dq.stride(0), dq.stride(1), dq.stride(2), dq.stride(3),
243 | B, H, M, N, P_SEQ,
244 | num_groups,
245 | BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N,
246 | CAUSAL=causal, IS_DROPOUT=is_dropout, LARGER_M=larger_m,
247 | DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n,
248 | num_stages=num_stages, num_warps = num_warps,
249 | )
250 | dk = dk.reshape((B, Hk, num_groups, N, D)).sum(2)
251 | dv = dv.reshape((B, Hk, num_groups, N, D)).sum(2)
252 | return dq, dk, dv, None, None, None, None, None, None
253 |
254 |
255 | def attention(q, k, v, causal=False, sm_scale=None, dropout_p=0.0,
256 | return_log_normalizer=False, return_total_attention=False, return_seed_offset=False
257 | ):
258 | """
259 | An implementation of FlashAttention v2 (https://arxiv.org/abs/2307.08691).
260 |
261 | Arguments:
262 | q(torch.Tensor): The queries. The shape is (batch_size, num_heads_q, seqlen_q, headdim).
263 | k(torch.Tensor): The keys. The shape is (batch_size, num_heads_k, seqlen_k, headdim).
264 | v(torch.Tensor): The values. The shape is (batch_size, num_heads_k, seqlen_k, headdim).
265 | causal(bool): Whether causal masking is applied to attention scores before applying softmax.
266 | sm_scale(float): The scaling of attention scores before applying softmax.
267 | dropout_p(float): Dropout probability.
268 | return_log_normalizer(bool): Whether to return the log normalizer of softmax inside attention.
269 | return_total_attention(bool): Whether to return the sum of attention along q's sequence dimension.
270 | return_seed_offset(bool): Whether to return the dropout seed and offset.
271 |
272 | Returns:
273 | out(torch.Tensor): The output. The shape is (batch_size, num_heads_q, seqlen_q, headdim).
274 |
275 | If `return_log_normalizer` or `return_total_attention` or `return_seed_offset` is True,
276 | the following results are returned in addition.
277 |
278 | log_normalizer(torch.Tensor): The log normalizer. The shape is (batch_size, num_heads_q, seqlen_q).
279 | total_attention(torch.Tensor): The total attention. The shape is (batch_size, num_heads_q, seqlen_k).
280 | seed(int): The Philox seed used in dropout.
281 | offset(int): The starting Philox offset used in dropout.
282 |
283 | Notes:
284 | `num_heads_q` must be a multiple of `num_heads_k`.
285 | """
286 | return FlashAttention.apply(q, k, v, causal, sm_scale, dropout_p, return_log_normalizer, return_total_attention, return_seed_offset)
287 |
288 |
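A minimal usage sketch of this entry point follows (shapes are illustrative; `num_heads_q` being a multiple of `num_heads_k` also covers the grouped-query case).

```python
import torch
from flag_attn import flash_attention

B, Hq, Hk, M, N, D = 2, 16, 4, 2048, 2048, 128
q = torch.randn(B, Hq, M, D, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn(B, Hk, N, D, device="cuda", dtype=torch.float16, requires_grad=True)
v = torch.randn(B, Hk, N, D, device="cuda", dtype=torch.float16, requires_grad=True)

# plain call: returns only the output
o = flash_attention(q, k, v, causal=True, dropout_p=0.1)
o.backward(torch.randn_like(o))

# with auxiliary outputs: a 5-tuple is returned when any of the flags is set
o, log_norm, tot_attn, seed, offset = flash_attention(
    q, k, v, causal=True, dropout_p=0.1,
    return_log_normalizer=True, return_total_attention=True, return_seed_offset=True,
)
```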
289 | # --------------------------- Forward ---------------------------
290 | # NOTE: this function can be overwritten at runtime to use your custom config
291 | def get_fwd_config(B, H, M, N, D, causal):
292 | if torch.cuda.get_device_capability() == (8, 0):
293 | if not causal:
294 | if D <= 64:
295 | BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 3, 4
296 | else:
297 | if M <= 1024:
298 | BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 3, 4
299 | else:
300 | BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 128, 3, 8
301 | else:
302 | if D <= 64:
303 | BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 4, 4
304 | else:
305 | if M <= 1024:
306 | BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 2, 4
307 | else:
308 | BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 128, 3, 8
309 | elif torch.cuda.get_device_capability() == (8, 6):
310 | if not causal:
311 | if D <= 64:
312 | BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 3, 4
313 | else:
314 | BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 2, 4
315 | else: # causal
316 | if D <= 64:
317 | BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 3, 4
318 | else:
319 | BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 2, 4
320 | else:
321 | BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 32, 1, 4
322 | return (BLOCK_M, BLOCK_N, num_stages, num_warps)
323 |
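Since `get_fwd_config` is looked up from this module at call time, it can be swapped for a workload-specific schedule by monkey-patching, as in the sketch below. The tile values here are illustrative, not a recommendation, and the repository's examples directory contains its own variant of this pattern; the same approach applies to `get_bwd_config`.

```python
import flag_attn.flash as flash

_default_fwd_config = flash.get_fwd_config

def my_fwd_config(B, H, M, N, D, causal):
    # hypothetical: pin a schedule for one workload, defer to the default otherwise
    if causal and D == 128 and M >= 4096:
        return (128, 64, 3, 8)  # BLOCK_M, BLOCK_N, num_stages, num_warps
    return _default_fwd_config(B, H, M, N, D, causal)

flash.get_fwd_config = my_fwd_config  # picked up on the next flash_attention call
```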
324 |
325 | @triton.jit
326 | def _fwd_kernel(
327 | Q, K, V, sm_scale,
328 | dropout_p,
329 | seed,
330 | offset,
331 | L, O,
332 | stride_qz, stride_qh, stride_qm, stride_qk,
333 | stride_kz, stride_kh, stride_kn, stride_kk,
334 | stride_vz, stride_vh, stride_vn, stride_vk,
335 | stride_oz, stride_oh, stride_om, stride_ok,
336 | Z, H, M, N, P_SEQ,
337 | num_groups,
338 | BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
339 | IS_CAUSAL: tl.constexpr, IS_DROPOUT: tl.constexpr, LARGER_M: tl.constexpr,
340 | DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr,
341 | ):
342 | input_dtype = Q.dtype.element_ty
343 | # -- grid id --
344 | start_m = tl.program_id(0)
345 | off_h = tl.program_id(1)
346 | off_z = tl.program_id(2)
347 |
348 | # scale sm_scale by log_2(e) and use
349 | # 2^x instead of exp in the loop because CSE and LICM
350 | # don't work as expected with `exp` in the loop
351 | log2e: tl.constexpr = 1.4426950408889634
352 | qk_scale = sm_scale * log2e
353 |
354 | # offset pointers for (batch, head)
355 | off_hk = off_h // num_groups
356 | Q += off_z * stride_qz + off_h * stride_qh
357 | K += off_z * stride_kz + off_hk * stride_kh
358 | V += off_z * stride_vz + off_hk * stride_vh
359 | O += off_z * stride_oz + off_h * stride_oh
360 | L += (off_z * H + off_h) * M # l's shape is (B, H, M)
361 |
362 | offs_m_base = tl.arange(0, BLOCK_M)
363 | offs_m = start_m * BLOCK_M + offs_m_base
364 | offs_n_base = tl.arange(0, BLOCK_N)
365 | offs_k = tl.arange(0, BLOCK_DMODEL)
366 |
367 | if IS_DROPOUT:
368 | rowblock_base = off_z * H * M * N + off_h * M * N + start_m * BLOCK_M * N
369 | offs_rng_base = offset + rowblock_base
370 | offs_rng_base += tl.arange(0, BLOCK_M)[:, None] * N
371 | offs_rng_base += tl.arange(0, BLOCK_N)[None, :]
372 |
373 | # initialize pointers to value-like data
374 | q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk) # (BLOCK_M, BLOCK_DMODEL)
375 | o_ptrs = O + (offs_m[:, None] * stride_om + offs_k[None, :] * stride_ok) # (BLOCK_M, BLOCK_DMODEL)
376 | l_ptrs = L + offs_m
377 |
378 | # initialize pointer to m and l, fp32 for accumulators
379 | m_i = tl.full([BLOCK_M], value=-float("inf"), dtype=tl.float32)
380 | l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
381 | acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
382 |
383 | # load q
384 | if DIVISIBLE_M:
385 | q = tl.load(q_ptrs, cache_modifier=".cg")
386 | else:
387 | mask_m = offs_m < M
388 | q = tl.load(q_ptrs, mask=mask_m[:, None], cache_modifier=".cg")
389 |
390 | # Dot-I trick: it places q in registers (mma layout) and saves shared memory
391 | if BLOCK_DMODEL < 128:
392 | I = tl.where(offs_k[:, None] == offs_k,
393 | tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 1.0, dtype=input_dtype),
394 | tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 0.0, dtype=input_dtype))
395 | q = tl.dot(q, I).to(input_dtype)
396 | # else:
397 | # I = tl.where(offs_m_base[:, None] == offs_m_base,
398 | # tl.full((BLOCK_M, BLOCK_M), 1.0, dtype=input_dtype),
399 | # tl.full((BLOCK_M, BLOCK_M), 0.0, dtype=input_dtype))
400 | # q = tl.dot(I, q).to(input_dtype)
401 |
402 | # NOTE: Loop-Bound-For-N
403 | # The indices in the m-dimension that this block may access lie in `[start_m * BLOCK_M, (start_m + 1) * BLOCK_M)`.
404 | # According to the rule of causal masking, the max index in the n-dimension that this block may access
405 | # is `P_SEQ + (start_m + 1) * BLOCK_M`.
406 | # However, the upper bound of the index in the n-dimension should never exceed the sequence length of k/v (`P_SEQ + N_CTX`).
407 | # `P_SEQ + (start_m + 1) * BLOCK_M` may be larger than `N`.
408 | # In this case, there would be an illegal memory access when loading the k & v tiles
409 | # if mask_n is not applied to the loads (which happens only when `DIVISIBLE_N` is true).
410 | # See also https://github.com/FlagOpen/FlagAttention/pull/8
411 | if IS_CAUSAL:
412 | hi = tl.minimum(N, P_SEQ + (start_m + 1) * BLOCK_M)
413 | if LARGER_M:
414 | hi = tl.maximum(0, hi)
415 | else:
416 | hi = N
417 |
418 | # loop over k, v and update accumulators
419 | offs_n_init = offs_n_base
420 | k_ptrs = K + (offs_k[:, None] * stride_kk + offs_n_init[None, :] * stride_kn) # (BLOCK_DMODEL, BLOCK_N)
421 | v_ptrs = V + (offs_n_init[:, None] * stride_vn + offs_k[None, :] * stride_vk) # (BLOCK_N, BLOCK_DMODEL)
422 | for start_n in range(0, hi, BLOCK_N):
423 | start_n = tl.multiple_of(start_n, BLOCK_N)
424 | offs_n = start_n + offs_n_base
425 |
426 | # -- load k, v --
427 | if DIVISIBLE_N:
428 | k = tl.load(k_ptrs, cache_modifier=".cg")
429 | v = tl.load(v_ptrs, cache_modifier=".cg")
430 | else:
431 | mask_n = offs_n < N
432 | k = tl.load(k_ptrs, mask=mask_n[None, :], cache_modifier=".cg")
433 | v = tl.load(v_ptrs, mask=mask_n[:, None], cache_modifier=".cg")
434 |
435 | # -- compute qk ---
436 | s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
437 | s += tl.dot(q, k)
438 |
439 | if not DIVISIBLE_N:
440 | s = tl.where(mask_n[None, :], s, float("-inf"))
441 | if IS_CAUSAL:
442 | causal_mask = (P_SEQ + offs_m[:, None]) >= offs_n[None, :]
443 | s = tl.where(causal_mask, s, float("-inf"))
444 |
445 | # -- compute scaling constant ---
446 | m_i_new = tl.maximum(m_i, tl.max(s, 1))
447 | alpha = tl.math.exp2((m_i - m_i_new) * qk_scale)
448 | p = tl.math.exp2(s * qk_scale - m_i_new[:, None] * qk_scale)
449 |
450 | # -- compute the partial sum of exponentials before applying dropout
451 | p_sum = tl.sum(p, 1)
452 |
453 | # -- apply dropout --
454 | if IS_DROPOUT:
455 | offs_rng = start_n + offs_rng_base
456 | pmask = tl.rand(seed, offs_rng, n_rounds=6) > dropout_p
457 | p *= pmask.to(tl.float32)
458 |
459 | # -- scale and update acc: acc *= alpha[:, None]--
460 | acc *= alpha[:, None]
461 | acc += tl.dot(p.to(input_dtype), v)
462 |
463 | # -- update m_i and l_i --
464 | l_i = l_i * alpha + p_sum
465 | m_i = m_i_new
466 | # update pointers
467 | k_ptrs += BLOCK_N * stride_kn
468 | v_ptrs += BLOCK_N * stride_vn
469 |
470 | # write back l & o
471 | if IS_CAUSAL and LARGER_M:
472 | is_empty_line = (offs_m + P_SEQ) < 0
473 | acc = tl.where(is_empty_line[:, None], 0.0, acc * (1.0 / l_i[:, None]))
474 | l = tl.where(is_empty_line, float("-inf"), m_i * sm_scale + tl.log(l_i))
475 | else:
476 | acc = acc * (1.0 / l_i[:, None])
477 | l = m_i * sm_scale + tl.log(l_i) # log(normalizer)
478 |
479 | # -- scale o due to dropout
480 | if IS_DROPOUT:
481 | scale = 1.0 / (1.0 - dropout_p)
482 | acc *= scale
483 |
484 | if DIVISIBLE_M:
485 | tl.store(l_ptrs, l, cache_modifier=".cg")
486 | tl.store(o_ptrs, acc.to(input_dtype), cache_modifier=".cg")
487 | else:
488 | tl.store(l_ptrs, l, mask=mask_m, cache_modifier=".cg")
489 | tl.store(o_ptrs, acc.to(input_dtype), mask=mask_m[:, None], cache_modifier=".cg")
490 |
491 |
492 | # --------------------------- Backward ---------------------------
493 | # NOTE: this function can be overwritten at runtime to use your custom config
494 | def get_bwd_config(B, H, M, N, D, causal):
495 | if torch.cuda.get_device_capability() == (8, 0):
496 | if not causal:
497 | BLOCK_M = 128 if D <= 64 else 64
498 | BLOCK_N = 64
499 | num_stages = 2
500 | num_warps = 4
501 | else:
502 | BLOCK_M = 64
503 | BLOCK_N = 64
504 | num_stages = 3 if D <= 64 else 2
505 | num_warps = 4
506 | elif torch.cuda.get_device_capability() == (8, 6): # tune for RTX-3090, device_capability(8, 6)
507 | if not causal:
508 | if D <= 64:
509 | BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 4
510 | else:
511 | BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 8
512 | else:
513 | if D <= 64:
514 | BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 4
515 | else:
516 | BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 32, 2, 4
517 | else:
518 | BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 32, 1, 4
519 | return (BLOCK_M, BLOCK_N, num_stages, num_warps)
520 |
521 |
522 | @triton.jit
523 | def _bwd_preprocess(
524 | Out, DO,
525 | Delta,
526 | stride_oz, stride_oh, stride_om, stride_ok,
527 | stride_doz, stride_doh, stride_dom, stride_dok,
528 | stride_dz, stride_dh, stride_dm,
529 | M,
530 | BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
531 | DIVISIBLE_M: tl.constexpr,
532 | ):
533 | off_h = tl.program_id(1)
534 | off_z = tl.program_id(2)
535 | Out += off_z * stride_oz + off_h * stride_oh
536 | DO += off_z * stride_doz + off_h * stride_doh
537 | Delta += off_z * stride_dz + off_h * stride_dh
538 |
539 | # compute delta = rowsum(Out * DO), the per-row dot product of the output and its gradient
540 | off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
541 | off_n = tl.arange(0, D_HEAD)
542 |
543 | # load
544 | o_ptrs = Out + off_m[:, None] * stride_om + off_n[None, :] * stride_ok
545 | do_ptrs = DO + off_m[:, None] * stride_dom + off_n[None, :] * stride_dok
546 |
547 | if DIVISIBLE_M:
548 | o = tl.load(o_ptrs).to(tl.float32)
549 | do = tl.load(do_ptrs).to(tl.float32)
550 | else:
551 | mask_m = off_m < M
552 | o = tl.load(o_ptrs, mask=mask_m[:, None]).to(tl.float32)
553 | do = tl.load(do_ptrs, mask=mask_m[:, None]).to(tl.float32)
554 |
555 | # compute
556 | delta = tl.sum(o * do, axis=1)
557 |
558 | # (NOTE) dropout scaling doesn't affect delta's value:
559 | # when dropout is applied, o and do are the scaled tensors.
560 | # the unscaled o equals o times the reverse scale while its gradient equals do times the scale,
561 | # so their product, and thus delta, remains unchanged.
562 |
563 | # write-back
564 | d_ptrs = Delta + off_m * stride_dm
565 | if DIVISIBLE_M:
566 | tl.store(d_ptrs, delta)
567 | else:
568 | tl.store(d_ptrs, delta, mask=mask_m)
569 |
570 |
571 | @triton.jit
572 | def _bwd_kv_kernel(
573 | Q, K, V, sm_scale, DO,
574 | DK, DV,
575 | L,
576 | D,
577 | dropout_p,
578 | seed,
579 | offset,
580 | stride_qz, stride_qh, stride_qm, stride_qk,
581 | stride_kz, stride_kh, stride_kn, stride_kk,
582 | stride_vz, stride_vh, stride_vn, stride_vk,
583 | stride_doz, stride_doh, stride_dom, stride_dok,
584 | stride_dkz, stride_dkh, stride_dkn, stride_dkk,
585 | stride_dvz, stride_dvh, stride_dvn, stride_dvk,
586 | Z, H, M, N, P_SEQ,
587 | num_groups,
588 | BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
589 | CAUSAL: tl.constexpr, IS_DROPOUT: tl.constexpr,
590 | DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr,
591 | ):
592 | input_dtype = Q.dtype.element_ty
593 | # -- grid id --
594 | start_n = tl.program_id(0)
595 | off_h = tl.program_id(1)
596 | off_z = tl.program_id(2)
597 | log2e: tl.constexpr = 1.4426950408889634
598 | qk_scale = sm_scale * log2e
599 |
600 | # offset pointers for (batch, head)
601 | off_hk = off_h // num_groups
602 | Q += off_z * stride_qz + off_h * stride_qh
603 | K += off_z * stride_kz + off_hk * stride_kh
604 | V += off_z * stride_vz + off_hk * stride_vh
605 | DO += off_z * stride_doz + off_h * stride_doh
606 |
607 | # offset pointers for batch/head
608 | DK += off_z * stride_dkz + off_h * stride_dkh
609 | DV += off_z * stride_dvz + off_h * stride_dvh
610 |
611 | # offset pointers for batch/head
612 | D += (off_z * H + off_h) * M
613 | L += (off_z * H + off_h) * M
614 |
615 | if CAUSAL:
616 | lo = tl.maximum(start_n * BLOCK_N - P_SEQ, 0)
617 | lo = (lo // BLOCK_M) * BLOCK_M
618 | else:
619 | lo = 0
620 |
621 | offs_m_init = lo + tl.arange(0, BLOCK_M)
622 | offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
623 | offs_m_base = tl.arange(0, BLOCK_M)
624 | offs_k = tl.arange(0, BLOCK_DMODEL)
625 |
626 | # initialize pointers to value-like data
627 | q_ptrs = Q + (offs_m_init[:, None] * stride_qm + offs_k[None, :] * stride_qk) # (BLOCK_M, BLOCK_DMODEL)
628 | k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) # (BLOCK_N, BLOCK_DMODEL)
629 | v_ptrs = V + (offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk) # (BLOCK_N, BLOCK_DMODEL)
630 | do_ptrs = DO + (offs_m_init[:, None] * stride_dom + offs_k[None, :] * stride_dok) # (BLOCK_M, BLOCK_DMODEL)
631 |
632 | dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_k[None, :] * stride_dvk) # (BLOCK_N, BLOCK_DMODEL)
633 | dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_k[None, :] * stride_dkk) # (BLOCK_N, BLOCK_DMODEL)
634 |
635 | # k and v stay in SRAM throughout
636 | if DIVISIBLE_N:
637 | v = tl.load(v_ptrs)
638 | k = tl.load(k_ptrs)
639 | else:
640 | mask_n = offs_n < N
641 | v = tl.load(v_ptrs, mask=mask_n[:, None])
642 | k = tl.load(k_ptrs, mask=mask_n[:, None])
643 |
644 | # dropout
645 | if IS_DROPOUT:
646 | colblock_base = off_z * H * M * N + off_h * M * N + start_n * BLOCK_N
647 | offs_rng_base = offset + colblock_base
648 | offs_rng_base += tl.arange(0, BLOCK_M)[:, None] * N
649 | offs_rng_base += tl.arange(0, BLOCK_N)[None, :]
650 | rp = 1. / (1. - dropout_p)
651 |
652 | # initialize dk and dv
653 | dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
654 | dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
655 |
656 | # loop over a col
657 | for start_m in range(lo, M, BLOCK_M):
658 | start_m = tl.multiple_of(start_m, BLOCK_M)
659 | offs_m = start_m + offs_m_base
660 | causal_mask = (P_SEQ + offs_m[:, None]) >= (offs_n[None, :]) # (BLOCK_M, BLOCK_N)
661 |
662 | # load q on-chip (do, l and delta are loaded below)
663 | if DIVISIBLE_M:
664 | q = tl.load(q_ptrs)
665 | else:
666 | mask_m = offs_m < M
667 | valid_mask = mask_m[:, None] # & mask_n
668 | q = tl.load(q_ptrs, mask=mask_m[:, None])
669 | # recompute p = softmax(qk * sm_scale, dim=-1)
670 | s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
671 | s += tl.dot(q, tl.trans(k))
672 |
673 | # NOTE: since the softmax in the backward pass is pointwise, the normalizer has been saved in the forward pass,
674 | # so masking on s is not needed.
675 | # s = tl.where(valid_mask, s , float("-inf"))
676 | # if CAUSAL:
677 | # s = tl.where(causal_mask, s, float("-inf"))
678 |
679 | # -- recompute p ---
680 | if DIVISIBLE_M:
681 | l = tl.load(L + offs_m)
682 | else:
683 | l = tl.load(L + offs_m, mask=mask_m)
684 | p = tl.math.exp2(s * qk_scale - l[:, None] * log2e) # (BLOCK_M, BLOCK_N)
685 |
686 | if not DIVISIBLE_M:
687 | p = tl.where(valid_mask, p, 0.0)
688 | if CAUSAL:
689 | p = tl.where(causal_mask, p, 0.0)
690 |
691 | # compute dv = dot(p, do)
692 | if DIVISIBLE_M:
693 | do = tl.load(do_ptrs)
694 | else:
695 | do = tl.load(do_ptrs, mask=mask_m[:, None]) # (BLOCK_M, BLOCK_DMODEL)
696 |
697 | if IS_DROPOUT:
698 | # do *= rp
699 | offs_rng = offs_rng_base + start_m * N
700 | pmask = tl.rand(seed, offs_rng, n_rounds=6) > dropout_p
701 | p_masked = p * pmask
702 | p_masked = p_masked.to(input_dtype)
703 |
704 | # -- apply dropout --
705 | if IS_DROPOUT:
706 | dv += tl.dot(tl.trans(p_masked), do) * rp # (BLOCK_N, BLOCK_DMODEL) # still correct
707 | else:
708 | dv += tl.dot(tl.trans(p).to(input_dtype), do) # (BLOCK_N, BLOCK_DMODEL) # still correct
709 |
710 | # compute dp = dot(v, do)
711 | if DIVISIBLE_M:
712 | delta = tl.load(D + offs_m)
713 | else:
714 | delta = tl.load(D + offs_m, mask=mask_m)
715 | dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
716 | dp += tl.dot(do, tl.trans(v))
717 |
718 | # -- apply dropout --
719 | if IS_DROPOUT:
720 | dp *= rp
721 | dp *= pmask
722 |
723 | # compute ds = p * (dp - delta[:, None])
724 | ds = p * (dp - delta[:, None]) # (BLOCK_M, BLOCK_N)
725 |
726 | if not DIVISIBLE_M:
727 | ds = tl.where(valid_mask, ds, 0.0)
728 | if CAUSAL:
729 | ds = tl.where(causal_mask, ds, 0.0)
730 | ds = ds.to(input_dtype)
731 |
732 | # compute dk = dot(ds.T, q) (ds already masked)
733 | dk += tl.dot(tl.trans(ds), q)
734 |
735 | # increment pointers
736 | q_ptrs += BLOCK_M * stride_qm
737 | do_ptrs += BLOCK_M * stride_dom
738 |
739 | dk *= sm_scale
740 | if DIVISIBLE_N:
741 | tl.store(dk_ptrs, dk.to(input_dtype)) # (BLOCK_N, BLOCK_DMODEL)
742 | tl.store(dv_ptrs, dv.to(input_dtype)) # (BLOCK_N, BLOCK_DMODEL,)
743 | else:
744 | tl.store(dk_ptrs, dk.to(input_dtype), mask=mask_n[:, None]) # (BLOCK_N, BLOCK_DMODEL)
745 | tl.store(dv_ptrs, dv.to(input_dtype), mask=mask_n[:, None]) # (BLOCK_N, BLOCK_DMODEL,)
746 |
747 |
748 | @triton.jit
749 | def _bwd_q_kernel(
750 | Q, K, V, sm_scale, DO,
751 | DQ,
752 | L,
753 | D,
754 | dropout_p,
755 | seed,
756 | offset,
757 | stride_qz, stride_qh, stride_qm, stride_qk,
758 | stride_kz, stride_kh, stride_kn, stride_kk,
759 | stride_vz, stride_vh, stride_vn, stride_vk,
760 | stride_doz, stride_doh, stride_dom, stride_dok,
761 | stride_dqz, stride_dqh, stride_dqm, stride_dqk,
762 | Z, H, M, N, P_SEQ,
763 | num_groups,
764 | BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
765 | CAUSAL: tl.constexpr, IS_DROPOUT: tl.constexpr, LARGER_M: tl.constexpr,
766 | DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr,
767 | ):
768 | input_dtype = Q.dtype.element_ty
769 | # -- grid id --
770 | start_m = tl.program_id(0)
771 | off_h = tl.program_id(1)
772 | off_z = tl.program_id(2)
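    | # one program instance handles one (query row block, query head, batch) triple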
773 |
774 | # scale sm_scale by log_2(e) and use
775 | # 2^x instead of exp in the loop because CSE and LICM
776 | # don't work as expected with `exp` in the loop
777 | log2e: tl.constexpr = 1.4426950408889634
778 | qk_scale = sm_scale * log2e
779 |
780 | # offset pointers for (batch, head)
781 | off_hk = off_h // num_groups
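    | # grouped-query attention: every num_groups query heads share one k/v head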
782 | Q += off_z * stride_qz + off_h * stride_qh
783 | K += off_z * stride_kz + off_hk * stride_kh
784 | V += off_z * stride_vz + off_hk * stride_vh
785 | DO += off_z * stride_doz + off_h * stride_doh
786 | D += (off_z * H + off_h) * M
787 | L += (off_z * H + off_h) * M
788 |
789 | # offset pointers for batch/head
790 | DQ += off_z * stride_dqz + off_h * stride_dqh
791 |
792 | offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
793 | offs_n_base = tl.arange(0, BLOCK_N)
794 | offs_n_init = offs_n_base
795 | offs_k = tl.arange(0, BLOCK_DMODEL)
796 |
797 | # initialize pointers to value-like data
798 | q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk) # (BLOCK_M, BLOCK_DMODEL)
799 | k_ptrs = K + (offs_n_init[:, None] * stride_kn + offs_k[None, :] * stride_kk) # (BLOCK_N, BLOCK_DMODEL)
800 | v_ptrs = V + (offs_n_init[:, None] * stride_vn + offs_k[None, :] * stride_vk) # (BLOCK_N, BLOCK_DMODEL)
801 |
802 | dq_ptrs = DQ + (offs_m[:, None] * stride_dqm + offs_k[None, :] * stride_dqk) # (BLOCK_M, BLOCK_DMODEL)
803 | do_ptrs = DO + (offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok) # (BLOCK_M, BLOCK_DMODEL)
804 |
805 | # pointer to row-wise quantities in value-like data
806 | d_ptrs = D + offs_m
807 | l_ptrs = L + offs_m
808 |
809 | # load q: it will stay in SRAM throughout
810 | if DIVISIBLE_M:
811 | q = tl.load(q_ptrs)
812 | do = tl.load(do_ptrs)
813 | delta = tl.load(d_ptrs)
814 | l = tl.load(l_ptrs)
815 | else:
816 | mask_m = offs_m < M
817 | q = tl.load(q_ptrs, mask=mask_m[:, None])
818 | do = tl.load(do_ptrs, mask=mask_m[:, None])
819 | delta = tl.load(d_ptrs, mask=mask_m)
820 | l = tl.load(l_ptrs, mask=mask_m)
821 |
822 | # initialize dq
823 | dq = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
824 |
825 | # loop over k, v and update accumulator
826 | # see note "Loop-Bound-For-N"
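    | # causal bound: query row i attends only to keys j <= P_SEQ + i, so this row block needs key blocks up to
    | # P_SEQ + (start_m + 1) * BLOCK_M; the clamp below keeps the bound non-negative when it would underflow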
827 | if CAUSAL:
828 | hi = tl.minimum(N, P_SEQ + (start_m + 1) * BLOCK_M)
829 | if LARGER_M:
830 | hi = tl.maximum(0, hi)
831 | else:
832 | hi = N
833 |
834 | # dropout
835 | if IS_DROPOUT:
836 | rowblock_base = off_z * H * M * N + off_h * M * N + start_m * BLOCK_M * N
837 | offs_rng_base = offset + rowblock_base
838 | offs_rng_base += tl.arange(0, BLOCK_M)[:, None] * N
839 | offs_rng_base += tl.arange(0, BLOCK_N)[None, :]
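    | # the rng offset is unique per (batch, head, row, col) element, so the same seed and offset reproduce the dropout mask drawn in the forward pass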
840 | rp = 1. / (1. - dropout_p)
841 | do *= rp.to(do.dtype)
842 |
843 | # loop along the row block: iterate over the key/value column blocks (start_n)
844 | for start_n in range(0, hi, BLOCK_N):
845 | offs_n = start_n + offs_n_base
846 |
847 | # load k and v on-chip
848 | if DIVISIBLE_N:
849 | v = tl.load(v_ptrs)
850 | k = tl.load(k_ptrs)
851 | else:
852 | mask_n = offs_n < N
853 | v = tl.load(v_ptrs, mask=mask_n[:, None])
854 | k = tl.load(k_ptrs, mask=mask_n[:, None])
855 |
856 |
857 | # recompute p = softmax(qk * sm_scale, dim=-1)
858 | if not DIVISIBLE_N:
859 | valid_mask = mask_n # & mask_m[:, None]
860 | if CAUSAL:
861 | causal_mask = (P_SEQ + offs_m[:, None]) >= (offs_n[None, :]) # (BLOCK_M, BLOCK_N)
862 | s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
863 | s += tl.dot(q, tl.trans(k))
864 |
865 | # NOTE: softmax in the backward pass is pointwise and its normalizer was saved in the forward pass,
866 | # so masking s is not needed here.
867 | # if CAUSAL:
868 | # s = tl.where(causal_mask & valid_mask, s, float("-inf"))
869 | # else:
870 | # s = tl.where(valid_mask, s, float("-inf"))
871 | p = tl.math.exp2(s * qk_scale - l[:, None] * log2e) # (BLOCK_M, BLOCK_N)
872 |
873 | # compute dp = dot(v, do)
874 | dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
875 | dp += tl.dot(do.to(input_dtype), tl.trans(v))
876 |
877 | if IS_DROPOUT:
878 | offs_rng = start_n + offs_rng_base
879 | pmask = tl.rand(seed, offs_rng, n_rounds=6) > dropout_p
880 | dp *= pmask
881 | # p_dropout = p * pmask.to(tl.float32)
882 |
883 | # no need to mask dp
884 | # if CAUSAL:
885 | # dp = tl.where(causal_mask & valid_mask, dp, 0.0)
886 | # else:
887 | # dp = tl.where(valid_mask, dp, 0.0)
888 |
889 | # compute ds = p * (dp - delta[:, None])
890 | # the sm_scale factor is moved out of ds and applied to dq once after the loop
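    | # softmax backward: dS = P * (dP - rowsum(P * dP)); rowsum(P * dP) equals delta, so it is not recomputed here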
891 | ds = p * (dp - delta[:, None]) # (BLOCK_M, BLOCK_N)
892 |
893 | # mask ds so that out-of-range and non-causal positions contribute exactly zero
894 | if not DIVISIBLE_N:
895 | ds = tl.where(valid_mask, ds, 0.0)
896 | if CAUSAL:
897 | ds = tl.where(causal_mask, ds, 0.0)
898 |
899 | dq += tl.dot(ds.to(input_dtype), k)
900 |
901 | # increment pointers
902 | k_ptrs += BLOCK_N * stride_kn
903 | v_ptrs += BLOCK_N * stride_vn
904 |
905 | dq *= sm_scale
906 | if DIVISIBLE_M:
907 | tl.store(dq_ptrs, dq.to(input_dtype))
908 | else:
909 | tl.store(dq_ptrs, dq.to(input_dtype), mask=mask_m[:, None])
910 |
--------------------------------------------------------------------------------