├── tests ├── __init__.py └── test_jvp_attention.py ├── jvp_flash_attention ├── __init__.py └── jvp_attention.py ├── main.png ├── float32_mem_scaling.png ├── float32_time_scaling.png ├── CITATION.cff ├── .github └── workflows │ └── publish.yaml ├── LICENSE ├── pyproject.toml ├── .pre-commit-config.yaml ├── .gitignore └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /jvp_flash_attention/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /main.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amorehead/jvp_flash_attention/HEAD/main.png -------------------------------------------------------------------------------- /float32_mem_scaling.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amorehead/jvp_flash_attention/HEAD/float32_mem_scaling.png -------------------------------------------------------------------------------- /float32_time_scaling.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amorehead/jvp_flash_attention/HEAD/float32_time_scaling.png -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you want to cite the kernel, feel free to use this (but only if you loved it 😊)" 3 | title: "JVP Flash Attention" 4 | abstract: "A Flash Attention Triton kernel with support for second-order derivatives, such as Jacobian-Vector Products (JVPs) and Hessian-Vector Products (HVPs)." 5 | date-released: 2025-09-05 6 | authors: 7 | - family-names: "Morehead" 8 | given-names: "Alex" 9 | version: 0.10.0 10 | doi: 10.5281/zenodo.17050188 11 | license: "MIT" 12 | url: "https://zenodo.org/records/17050188" 13 | repository-code: "https://github.com/amorehead/jvp_flash_attention" 14 | keywords: 15 | - artificial intelligence 16 | - deep learning 17 | -------------------------------------------------------------------------------- /.github/workflows/publish.yaml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 
8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | jobs: 16 | deploy: 17 | runs-on: ubuntu-latest 18 | 19 | steps: 20 | - uses: actions/checkout@v4 21 | - name: Set up Python 3.10 22 | uses: actions/setup-python@v5 23 | with: 24 | python-version: "3.10" 25 | - name: Install dependencies 26 | run: | 27 | python -m pip install --upgrade pip 28 | pip install build 29 | - name: Build package 30 | run: python -m build 31 | - name: Publish package 32 | uses: pypa/gh-action-pypi-publish@v1.12.4 33 | with: 34 | user: __token__ 35 | password: ${{ secrets.PYPI_API_TOKEN }} 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | JVP Flash Attention (jvp_flash_attention) Copyright (c) 2025, The Regents of the University of California, through Lawrence Berkeley National Laboratory (subject to receipt of any required approvals from the U.S. Dept. of Energy). All rights reserved. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights to use, 8 | copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 9 | Software, and to permit persons to whom the Software is furnished to do so, 10 | subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | 21 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | 23 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 24 | 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | 27 | SOFTWARE. 
28 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "jvp_flash_attention" 3 | version = "0.10.0" 4 | description = "Flash Attention Triton kernel with support for second-order derivatives" 5 | authors = [ 6 | { name = "Alex Morehead", email = "alex.morehead@gmail.edu" } 7 | ] 8 | readme = "README.md" 9 | requires-python = ">=3.10" 10 | license = "MIT" 11 | keywords = [ 12 | 'artificial intelligence', 13 | 'deep learning', 14 | ] 15 | 16 | classifiers=[ 17 | 'Development Status :: 4 - Beta', 18 | 'Intended Audience :: Developers', 19 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 20 | 'License :: OSI Approved :: MIT License', 21 | 'Programming Language :: Python :: 3.10', 22 | ] 23 | 24 | dependencies = [ 25 | # --------- pytorch --------- # 26 | "torch>=2.8.0", 27 | "torchvision>=0.23", 28 | 29 | # --------- others --------- # 30 | # "numpy>=2.2.6", 31 | ] 32 | 33 | [project.urls] 34 | Homepage = "https://pypi.org/project/jvp_flash_attention/" 35 | Repository = "https://github.com/amorehead/jvp-flash-attention" 36 | 37 | [project.optional-dependencies] 38 | lint = ["pre-commit>=4.3.0"] 39 | 40 | [build-system] 41 | requires = ["hatchling"] 42 | build-backend = "hatchling.build" 43 | 44 | [tool.pytest.ini_options] 45 | pythonpath = [ 46 | "." 47 | ] 48 | addopts = [ 49 | "--color=yes", 50 | "--durations=0", 51 | "--strict-markers", 52 | "--doctest-modules", 53 | ] 54 | filterwarnings = [ 55 | "ignore::DeprecationWarning", 56 | "ignore::UserWarning", 57 | ] 58 | log_cli = "True" 59 | markers = [ 60 | "slow: slow tests", 61 | ] 62 | minversion = "6.0" 63 | testpaths = "tests/" 64 | 65 | # Assuming you're developing for Python 3.10 66 | target-version = "py310" 67 | 68 | [tool.hatch.metadata] 69 | allow-direct-references = true 70 | 71 | [tool.hatch.build.targets.wheel] 72 | packages = ["jvp_flash_attention"] 73 | 74 | [tool.coverage.report] 75 | exclude_lines = [ 76 | "pragma: nocover", 77 | "raise NotImplementedError", 78 | "raise NotImplementedError()", 79 | "if __name__ == .__main__.:", 80 | ] 81 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v6.0.0 7 | hooks: 8 | # list of supported hooks: https://pre-commit.com/hooks.html 9 | - id: trailing-whitespace 10 | - id: end-of-file-fixer 11 | - id: check-docstring-first 12 | - id: check-yaml 13 | - id: debug-statements 14 | - id: detect-private-key 15 | - id: check-executables-have-shebangs 16 | - id: check-toml 17 | - id: check-case-conflict 18 | - id: check-added-large-files 19 | args: ["--maxkb=5000"] 20 | 21 | # python code formatting 22 | - repo: https://github.com/psf/black 23 | rev: 25.1.0 24 | hooks: 25 | - id: black 26 | args: [--line-length, "99"] 27 | 28 | # python import sorting 29 | - repo: https://github.com/PyCQA/isort 30 | rev: 6.0.1 31 | hooks: 32 | - id: isort 33 | args: ["--profile", "black", "--filter-files"] 34 | 35 | # python upgrading syntax to newer version 36 | - repo: https://github.com/asottile/pyupgrade 37 | rev: v3.20.0 38 | hooks: 39 | - id: pyupgrade 40 | args: [--py38-plus] 41 | 42 | # python docstring formatting 43 | # address runtime issue: 
https://github.com/PyCQA/docformatter/pull/287 44 | - repo: local 45 | hooks: 46 | - id: docformatter 47 | name: docformatter 48 | description: Formats docstrings to follow Google style. 49 | entry: python -Im docformatter 50 | additional_dependencies: 51 | - docformatter == 1.7.5 52 | args: 53 | [ 54 | --in-place, 55 | --wrap-summaries=99, 56 | --wrap-descriptions=99, 57 | --style=google, 58 | --black, 59 | ] 60 | language: python 61 | types: 62 | - python 63 | 64 | # python docstring coverage checking 65 | - repo: https://github.com/econchick/interrogate 66 | rev: 1.7.0 # or master if you're bold 67 | hooks: 68 | - id: interrogate 69 | args: 70 | [ 71 | --verbose, 72 | --fail-under=80, 73 | --ignore-init-module, 74 | --ignore-init-method, 75 | --ignore-module, 76 | --ignore-nested-functions, 77 | -vv, 78 | ] 79 | 80 | # python check (PEP8), programming errors and code complexity 81 | - repo: https://github.com/PyCQA/flake8 82 | rev: 7.3.0 83 | hooks: 84 | - id: flake8 85 | args: 86 | [ 87 | "--extend-ignore", 88 | "E203,E402,E501,F401,F841,RST2,RST301", 89 | "--exclude", 90 | "logs/*,data/*", 91 | ] 92 | additional_dependencies: [flake8-rst-docstrings==0.3.0] 93 | 94 | # python security linter 95 | - repo: https://github.com/PyCQA/bandit 96 | rev: "1.8.6" 97 | hooks: 98 | - id: bandit 99 | args: ["-s", "B101"] 100 | 101 | # yaml formatting 102 | - repo: https://github.com/pre-commit/mirrors-prettier 103 | rev: v4.0.0-alpha.8 104 | hooks: 105 | - id: prettier 106 | types: [yaml] 107 | exclude: "environment.yaml" 108 | 109 | # shell scripts linter 110 | - repo: https://github.com/shellcheck-py/shellcheck-py 111 | rev: v0.11.0.1 112 | hooks: 113 | - id: shellcheck 114 | 115 | # md formatting 116 | - repo: https://github.com/executablebooks/mdformat 117 | rev: 0.7.22 118 | hooks: 119 | - id: mdformat 120 | args: ["--number"] 121 | additional_dependencies: 122 | - mdformat-gfm 123 | - mdformat-tables 124 | - mdformat_frontmatter 125 | # - mdformat-toc 126 | # - mdformat-black 127 | 128 | # word spelling linter 129 | - repo: https://github.com/codespell-project/codespell 130 | rev: v2.4.1 131 | hooks: 132 | - id: codespell 133 | args: 134 | - --skip=logs/**,data/**,*.ipynb 135 | # - --ignore-words-list=abc,def 136 | 137 | # jupyter notebook cell output clearing 138 | - repo: https://github.com/kynan/nbstripout 139 | rev: 0.8.1 140 | hooks: 141 | - id: nbstripout 142 | 143 | # jupyter notebook linting 144 | - repo: https://github.com/nbQA-dev/nbQA 145 | rev: 1.9.1 146 | hooks: 147 | - id: nbqa-black 148 | args: ["--line-length=99"] 149 | - id: nbqa-isort 150 | args: ["--profile=black"] 151 | - id: nbqa-flake8 152 | args: 153 | [ 154 | "--extend-ignore=E203,E402,E501,F401,F841", 155 | "--exclude=logs/*,data/*", 156 | ] 157 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[codz] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | #poetry.toml 110 | 111 | # pdm 112 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 113 | # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python. 114 | # https://pdm-project.org/en/latest/usage/project/#working-with-version-control 115 | #pdm.lock 116 | #pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # pixi 121 | # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control. 122 | #pixi.lock 123 | # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one 124 | # in the .venv directory. It is recommended not to include this directory in version control. 125 | .pixi 126 | 127 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 128 | __pypackages__/ 129 | 130 | # Celery stuff 131 | celerybeat-schedule 132 | celerybeat.pid 133 | 134 | # SageMath parsed files 135 | *.sage.py 136 | 137 | # Environments 138 | .env 139 | .envrc 140 | .venv 141 | env/ 142 | venv/ 143 | ENV/ 144 | env.bak/ 145 | venv.bak/ 146 | 147 | # Spyder project settings 148 | .spyderproject 149 | .spyproject 150 | 151 | # Rope project settings 152 | .ropeproject 153 | 154 | # mkdocs documentation 155 | /site 156 | 157 | # mypy 158 | .mypy_cache/ 159 | .dmypy.json 160 | dmypy.json 161 | 162 | # Pyre type checker 163 | .pyre/ 164 | 165 | # pytype static type analyzer 166 | .pytype/ 167 | 168 | # Cython debug symbols 169 | cython_debug/ 170 | 171 | # PyCharm 172 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 173 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 174 | # and can be added to the global gitignore or merged into this file. For a more nuclear 175 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 176 | #.idea/ 177 | 178 | # Abstra 179 | # Abstra is an AI-powered process automation framework. 180 | # Ignore directories containing user credentials, local state, and settings. 181 | # Learn more at https://abstra.io/docs 182 | .abstra/ 183 | 184 | # Visual Studio Code 185 | # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore 186 | # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore 187 | # and can be added to the global gitignore or merged into this file. However, if you prefer, 188 | # you could uncomment the following to ignore the entire vscode folder 189 | # .vscode/ 190 | 191 | # Ruff stuff: 192 | .ruff_cache/ 193 | 194 | # PyPI configuration file 195 | .pypirc 196 | 197 | # Cursor 198 | # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to 199 | # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data 200 | # refer to https://docs.cursor.com/context/ignore-files 201 | .cursorignore 202 | .cursorindexingignore 203 | 204 | # Marimo 205 | marimo/_static/ 206 | marimo/_lsp/ 207 | __marimo__/ 208 | 209 | # Visual Studio Code 210 | .vscode/ 211 | 212 | # Unit tests 213 | /tests/*.json 214 | /tests/*.png 215 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # JVP Flash Attention 4 | 5 | PyTorch 6 | [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.17050188.svg)](https://doi.org/10.5281/zenodo.17050188) 7 | [![PyPI version](https://badge.fury.io/py/jvp_flash_attention.svg)](https://badge.fury.io/py/jvp_flash_attention) 8 | [![Project Status: Active – The project has reached a stable, usable state and is being actively developed.](https://www.repostatus.org/badges/latest/active.svg)](https://www.repostatus.org/#active) 9 | Code style: black 10 | [![License: MIT](https://img.shields.io/badge/license-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 11 | 12 | 13 | 14 |
15 |
16 | ## Description
17 |
18 | Flash Attention Triton kernel with support for second-order derivatives, such as Jacobian-Vector Products (JVPs) and Hessian-Vector Products (HVPs).
19 |
20 | ## Installation
21 |
22 | Using `pip`, one can install `jvp_flash_attention` as follows.
23 |
24 | ```bash
25 | # Install package
26 | pip install jvp_flash_attention
27 |
28 | # [OPTIONAL, for development] Install package and pre-commit hooks
29 | pip install -e .
30 | pre-commit install
31 | ```
32 |
33 | ## Usage
34 |
35 | Once installed, one can use `jvp_flash_attention` in place of PyTorch's `scaled_dot_product_attention` as follows.
36 |
37 | ```python
38 | import torch.nn.functional as F
39 |
40 | from torch.nn.attention import SDPBackend, sdpa_kernel
41 | from jvp_flash_attention.jvp_attention import JVPAttn, attention as jvp_attention
42 |
43 | with sdpa_kernel(SDPBackend.MATH):
44 |     # Regular (quadratic) attention
45 |     x = F.scaled_dot_product_attention(
46 |         q,
47 |         k,
48 |         v,
49 |         attn_mask=attn_mask,
50 |         dropout_p=attn_dropout_p if self.training else 0.0,
51 |     )
52 |
53 |     # JVP flash attention
54 |     x = jvp_attention(
55 |         q,
56 |         k,
57 |         v,
58 |         attn_mask=attn_mask,
59 |         # dropout_p=attn_dropout_p if self.training else 0.0, # NOTE: Attention dropout is currently unsupported
60 |     )
61 | ```
62 |
63 | Anecdotally, one can also swap out `F.scaled_dot_product_attention` with `jvp_attention` **even for pretrained models** with minimal impact on numerical accuracy.
64 |
65 | > Note: If calling `torch.func.jvp` manually in your model's forward pass like
66 | > `pred, df = torch.func.jvp(lambda x_jvp: model(x_jvp), (x,), (gt,))`,
67 | > make sure to use JVP Flash Attention in your model as `model = lambda q, k, v: JVPAttn.fwd_dual(q, k, v)` instead of as `model = lambda q, k, v: jvp_attention(q, k, v)` to ensure each input's tangent vectors are computed [prior](https://github.com/amorehead/jvp_flash_attention/issues/10) to running PyTorch's `autograd` engine. Models that rely on `torch.autograd.grad` to compute higher-order derivatives in their forward pass (e.g., energy-based models) should not require this change. A minimal Hessian-vector product sketch using `JVPAttn.fwd_dual` is shown at the end of the Results section below.
68 |
69 | Contributions or enhancements are welcome!
70 |
71 | ## Results
72 |
73 | ### Loss matching
74 |
75 | Model training with either `F.scaled_dot_product_attention` or `JVPAttn.fwd_dual` produces the same loss trajectory.
76 |
77 | image
78 |
79 | ### Speed matching
80 |
81 | Model training with either `F.scaled_dot_product_attention` or `JVPAttn.fwd_dual` achieves the same iteration speed.
82 |
83 | image
84 |
85 | > Note: The following results can be reproduced (for `float32` precision) by running `python tests/test_jvp_attention.py --dtype float32`.
86 |
87 | ### Time scaling
88 |
89 | `jvp_attention` scales better in runtime than (`SDPBackend.MATH`-based) `F.scaled_dot_product_attention` when calculating second-order derivatives.
90 |
91 |
92 | 93 | 94 | 95 |
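To benchmark configurations other than the default sweep, the test script exposes several command-line flags (names taken from the argument parser in `tests/test_jvp_attention.py`). For example, a reduced, timing-only sweep might look like the following.

```bash
# Time a reduced sweep in float32, skipping the gradient-accuracy checks
python tests/test_jvp_attention.py \
    --dtype float32 \
    --seq-lengths 512 1024 2048 \
    --benchmark-iters 50 \
    --no-validate-gradients
```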
96 |
97 | ### Memory scaling
98 |
99 | `jvp_attention` uses substantially less memory than (`SDPBackend.MATH`-based) `F.scaled_dot_product_attention` when calculating second-order derivatives.
100 |
101 |
102 | 103 | 104 | 105 |
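As a concrete illustration of the second-order use case measured above, below is a minimal Hessian-vector product (HVP) sketch computed forward-over-reverse through `JVPAttn.fwd_dual`. The shapes, device, and toy loss are illustrative assumptions rather than part of the package; only the `JVPAttn.fwd_dual(q, k, v, ...)` call mirrors the usage shown earlier.

```python
import torch

from jvp_flash_attention.jvp_attention import JVPAttn

# Illustrative shapes: batch, heads, sequence length, head dimension
B, H, S, D = 2, 12, 256, 64
q = torch.randn(B, H, S, D, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)
vec = torch.randn_like(q)  # direction along which to take the HVP


def scalar_loss(q_in: torch.Tensor) -> torch.Tensor:
    # Any scalar function of the attention output works; a toy squared-mean loss here
    out = JVPAttn.fwd_dual(q_in, k, v, causal=False)
    return out.square().mean()


# Forward-over-reverse: a JVP of the gradient function along `vec` yields (d^2 L / dq^2) @ vec
grad_fn = torch.func.grad(scalar_loss)  # dL/dq
_, hvp = torch.func.jvp(grad_fn, (q,), (vec,))
print(hvp.shape)  # torch.Size([2, 12, 256, 64])
```

The same pattern applies when `q`, `k`, and `v` come from a model's projections; in that case, wrap the model's attention calls with `JVPAttn.fwd_dual` as described in the usage note above.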
106 | 107 | ## Tests 108 | 109 | If you want to run all the unit tests verifying the correctness of the JVP Flash Attention Triton kernel, run the following command(s). 110 | 111 | ```bash 112 | python tests/test_jvp_attention.py --dtype {float16,bfloat16,float32} 113 | ``` 114 | 115 | In principle, the kernel should support ROCm systems as well, though it has not yet been tested on them. macOS is currently unsupported except using a CPU-only backend. 116 | 117 | Full results for `float16`: 118 | 119 | ``` 120 | ============================================================================================================== 121 | BENCHMARK SUMMARY 122 | ============================================================================================================== 123 | Seq Len Causal Mask Method Time (ms) Mem (MB) TFLOP/s Max Error Grad Check 124 | -------------------------------------------------------------------------------------------------------------- 125 | 32 False additive sdpa 0.821 3.09 0.0 TFLOP/s baseline N/A 126 | 32 False additive jvp_attn 0.723 1.08 0.0 TFLOP/s 1.83e+01 ✗ 127 | 128 | 32 False boolean sdpa 0.961 3.14 0.0 TFLOP/s baseline N/A 129 | 32 False boolean jvp_attn 0.504 1.03 0.0 TFLOP/s 3.91e-03 ✓ 130 | 131 | 32 False none sdpa 0.576 3.09 0.0 TFLOP/s baseline N/A 132 | 32 False none jvp_attn 0.447 1.03 0.0 TFLOP/s 1.95e-03 ✓ 133 | 134 | 32 True none sdpa 0.934 3.10 0.0 TFLOP/s baseline N/A 135 | 32 True none jvp_attn 0.458 1.03 0.0 TFLOP/s 3.91e-03 ✓ 136 | 137 | 64 False additive sdpa 0.860 6.75 0.0 TFLOP/s baseline N/A 138 | 64 False additive jvp_attn 0.847 2.26 0.1 TFLOP/s 2.23e+00 ✗ 139 | 140 | 64 False boolean sdpa 0.908 6.94 0.0 TFLOP/s baseline N/A 141 | 64 False boolean jvp_attn 0.521 2.07 0.1 TFLOP/s 3.91e-03 ✓ 142 | 143 | 64 False none sdpa 0.542 6.75 0.0 TFLOP/s baseline N/A 144 | 64 False none jvp_attn 0.414 2.07 0.1 TFLOP/s 1.95e-03 ✓ 145 | 146 | 64 True none sdpa 0.888 6.77 0.0 TFLOP/s baseline N/A 147 | 64 True none jvp_attn 0.437 2.07 0.1 TFLOP/s 2.20e-03 ✓ 148 | 149 | 128 False additive sdpa 0.834 16.51 0.1 TFLOP/s baseline N/A 150 | 128 False additive jvp_attn 0.750 4.89 0.3 TFLOP/s 3.91e-03 ✓ 151 | 152 | 128 False boolean sdpa 0.840 17.26 0.1 TFLOP/s baseline N/A 153 | 128 False boolean jvp_attn 0.520 4.14 0.4 TFLOP/s 3.91e-03 ✓ 154 | 155 | 128 False none sdpa 0.610 16.51 0.2 TFLOP/s baseline N/A 156 | 128 False none jvp_attn 0.459 4.14 0.4 TFLOP/s 9.77e-04 ✓ 157 | 158 | 128 True none sdpa 1.053 16.57 0.0 TFLOP/s baseline N/A 159 | 128 True none jvp_attn 0.438 4.14 0.2 TFLOP/s 2.44e-03 ✓ 160 | 161 | 256 False additive sdpa 0.829 47.77 0.5 TFLOP/s baseline N/A 162 | 256 False additive jvp_attn 0.738 12.02 1.1 TFLOP/s 3.91e-03 ✓ 163 | 164 | 256 False boolean sdpa 0.872 50.77 0.5 TFLOP/s baseline N/A 165 | 256 False boolean jvp_attn 0.482 8.27 1.7 TFLOP/s 3.91e-03 ✓ 166 | 167 | 256 False none sdpa 0.812 47.27 0.5 TFLOP/s baseline N/A 168 | 256 False none jvp_attn 0.460 8.27 1.8 TFLOP/s 9.77e-04 ✓ 169 | 170 | 256 True none sdpa 0.964 47.52 0.2 TFLOP/s baseline N/A 171 | 256 True none jvp_attn 0.436 8.27 0.9 TFLOP/s 3.91e-03 ✓ 172 | 173 | 512 False additive sdpa 1.416 153.55 1.2 TFLOP/s baseline N/A 174 | 512 False additive jvp_attn 0.715 30.55 4.6 TFLOP/s 1.95e-03 ✓ 175 | 176 | 512 False boolean sdpa 1.441 165.05 1.1 TFLOP/s baseline N/A 177 | 512 False boolean jvp_attn 0.500 16.55 6.6 TFLOP/s 1.95e-03 ✓ 178 | 179 | 512 False none sdpa 1.374 153.05 1.2 TFLOP/s baseline N/A 180 | 512 False none jvp_attn 0.407 16.55 8.1 TFLOP/s 4.88e-04 ✓ 181 | 182 | 512 True 
none sdpa 1.402 154.05 0.6 TFLOP/s baseline N/A 183 | 512 True none jvp_attn 0.460 16.55 3.6 TFLOP/s 2.93e-03 ✓ 184 | 185 | 1024 False additive sdpa 4.963 546.84 1.3 TFLOP/s baseline N/A 186 | 1024 False additive jvp_attn 1.183 96.84 11.1 TFLOP/s 1.95e-03 ✓ 187 | 188 | 1024 False boolean sdpa 4.991 594.84 1.3 TFLOP/s baseline N/A 189 | 1024 False boolean jvp_attn 0.622 33.84 21.1 TFLOP/s 1.95e-03 ✓ 190 | 191 | 1024 False none sdpa 4.227 546.84 1.6 TFLOP/s baseline N/A 192 | 1024 False none jvp_attn 0.420 33.84 31.3 TFLOP/s 4.88e-04 ✓ 193 | 194 | 1024 True none sdpa 4.861 550.84 0.7 TFLOP/s baseline N/A 195 | 1024 True none jvp_attn 0.469 33.84 14.0 TFLOP/s 3.91e-03 ✓ 196 | 197 | 2048 False additive sdpa 18.773 2052.19 1.4 TFLOP/s baseline N/A 198 | 2048 False additive jvp_attn 3.379 336.19 15.6 TFLOP/s 1.95e-03 ✓ 199 | 200 | 2048 False boolean sdpa 18.815 2244.19 1.4 TFLOP/s baseline N/A 201 | 2048 False boolean jvp_attn 1.674 66.19 31.4 TFLOP/s 1.95e-03 ✓ 202 | 203 | 2048 False none sdpa 16.156 2052.19 1.6 TFLOP/s baseline N/A 204 | 2048 False none jvp_attn 1.186 66.19 44.3 TFLOP/s 4.88e-04 ✓ 205 | 206 | 2048 True none sdpa 18.587 2068.19 0.7 TFLOP/s baseline N/A 207 | 2048 True none jvp_attn 0.720 66.19 36.5 TFLOP/s 1.95e-03 ✓ 208 | 209 | 210 | ================================================================================ 211 | MASK TYPE PERFORMANCE COMPARISON 212 | ================================================================================ 213 | Seq Len Causal Method No Mask Boolean Mask Additive Mask 214 | -------------------------------------------------------------------------------- 215 | 32 False jvp_attn 0.45 ms 0.50 ms (1.13x) 0.72 ms (1.62x) 216 | 32 True jvp_attn 0.46 ms N/A N/A 217 | 64 False jvp_attn 0.41 ms 0.52 ms (1.26x) 0.85 ms (2.05x) 218 | 64 True jvp_attn 0.44 ms N/A N/A 219 | 128 False jvp_attn 0.46 ms 0.52 ms (1.13x) 0.75 ms (1.63x) 220 | 128 True jvp_attn 0.44 ms N/A N/A 221 | 256 False jvp_attn 0.46 ms 0.48 ms (1.05x) 0.74 ms (1.60x) 222 | 256 True jvp_attn 0.44 ms N/A N/A 223 | 512 False jvp_attn 0.41 ms 0.50 ms (1.23x) 0.72 ms (1.76x) 224 | 512 True jvp_attn 0.46 ms N/A N/A 225 | 1024 False jvp_attn 0.42 ms 0.62 ms (1.48x) 1.18 ms (2.82x) 226 | 1024 True jvp_attn 0.47 ms N/A N/A 227 | 2048 False jvp_attn 1.19 ms 1.67 ms (1.41x) 3.38 ms (2.85x) 228 | 2048 True jvp_attn 0.72 ms N/A N/A 229 | 230 | ============================================================ 231 | STATISTICS 232 | ============================================================ 233 | Average speedup: 4.50x 234 | Min speedup: 1.02x 235 | Max speedup: 25.82x 236 | 237 | Accuracy: 26/28 tests passed 238 | ⚠️ Some accuracy checks failed 239 | 240 | Failed configurations: 241 | - Seq=32, Causal=False, Mask=additive 242 | - Seq=64, Causal=False, Mask=additive 243 | ``` 244 | 245 | Full results for `bfloat16`: 246 | 247 | ``` 248 | ============================================================================================================== 249 | BENCHMARK SUMMARY 250 | ============================================================================================================== 251 | Seq Len Causal Mask Method Time (ms) Mem (MB) TFLOP/s Max Error Grad Check 252 | -------------------------------------------------------------------------------------------------------------- 253 | 32 False additive sdpa 0.864 3.09 0.0 TFLOP/s baseline N/A 254 | 32 False additive jvp_attn 0.773 1.08 0.0 TFLOP/s 1.84e+01 ✗ 255 | 256 | 32 False boolean sdpa 0.949 3.14 0.0 TFLOP/s baseline N/A 257 | 32 False boolean jvp_attn 
0.569 1.03 0.0 TFLOP/s 3.12e-02 ✓ 258 | 259 | 32 False none sdpa 0.662 3.09 0.0 TFLOP/s baseline N/A 260 | 32 False none jvp_attn 0.447 1.03 0.0 TFLOP/s 1.56e-02 ✓ 261 | 262 | 32 True none sdpa 0.945 3.10 0.0 TFLOP/s baseline N/A 263 | 32 True none jvp_attn 0.469 1.03 0.0 TFLOP/s 3.12e-02 ✓ 264 | 265 | 64 False additive sdpa 0.923 6.75 0.0 TFLOP/s baseline N/A 266 | 64 False additive jvp_attn 1.149 2.26 0.0 TFLOP/s 2.23e+00 ✗ 267 | 268 | 64 False boolean sdpa 0.910 6.94 0.0 TFLOP/s baseline N/A 269 | 64 False boolean jvp_attn 0.518 2.07 0.1 TFLOP/s 3.12e-02 ✓ 270 | 271 | 64 False none sdpa 0.554 6.75 0.0 TFLOP/s baseline N/A 272 | 64 False none jvp_attn 0.427 2.07 0.1 TFLOP/s 1.56e-02 ✓ 273 | 274 | 64 True none sdpa 0.886 6.77 0.0 TFLOP/s baseline N/A 275 | 64 True none jvp_attn 0.458 2.07 0.1 TFLOP/s 3.12e-02 ✓ 276 | 277 | 128 False additive sdpa 0.860 16.51 0.1 TFLOP/s baseline N/A 278 | 128 False additive jvp_attn 0.896 4.89 0.2 TFLOP/s 1.56e-02 ✓ 279 | 280 | 128 False boolean sdpa 0.891 17.26 0.1 TFLOP/s baseline N/A 281 | 128 False boolean jvp_attn 0.771 4.14 0.3 TFLOP/s 1.56e-02 ✓ 282 | 283 | 128 False none sdpa 0.578 16.51 0.2 TFLOP/s baseline N/A 284 | 128 False none jvp_attn 0.467 4.14 0.4 TFLOP/s 7.81e-03 ✓ 285 | 286 | 128 True none sdpa 0.917 16.57 0.1 TFLOP/s baseline N/A 287 | 128 True none jvp_attn 0.447 4.14 0.2 TFLOP/s 3.12e-02 ✓ 288 | 289 | 256 False additive sdpa 0.822 47.77 0.5 TFLOP/s baseline N/A 290 | 256 False additive jvp_attn 0.734 12.02 1.1 TFLOP/s 1.56e-02 ✓ 291 | 292 | 256 False boolean sdpa 0.880 50.77 0.5 TFLOP/s baseline N/A 293 | 256 False boolean jvp_attn 0.532 8.27 1.5 TFLOP/s 1.56e-02 ✓ 294 | 295 | 256 False none sdpa 0.597 47.27 0.7 TFLOP/s baseline N/A 296 | 256 False none jvp_attn 0.441 8.27 1.9 TFLOP/s 7.81e-03 ✓ 297 | 298 | 256 True none sdpa 0.869 47.52 0.2 TFLOP/s baseline N/A 299 | 256 True none jvp_attn 0.469 8.27 0.9 TFLOP/s 1.56e-02 ✓ 300 | 301 | 512 False additive sdpa 1.429 153.55 1.1 TFLOP/s baseline N/A 302 | 512 False additive jvp_attn 0.710 30.55 4.6 TFLOP/s 1.56e-02 ✓ 303 | 304 | 512 False boolean sdpa 1.714 165.05 1.0 TFLOP/s baseline N/A 305 | 512 False boolean jvp_attn 0.552 16.55 5.9 TFLOP/s 1.56e-02 ✓ 306 | 307 | 512 False none sdpa 1.314 153.05 1.2 TFLOP/s baseline N/A 308 | 512 False none jvp_attn 0.403 16.55 8.2 TFLOP/s 7.81e-03 ✓ 309 | 310 | 512 True none sdpa 1.788 154.05 0.5 TFLOP/s baseline N/A 311 | 512 True none jvp_attn 0.432 16.55 3.8 TFLOP/s 3.12e-02 ✓ 312 | 313 | 1024 False additive sdpa 5.720 546.84 1.1 TFLOP/s baseline N/A 314 | 1024 False additive jvp_attn 1.133 96.84 11.6 TFLOP/s 1.56e-02 ✓ 315 | 316 | 1024 False boolean sdpa 5.376 594.84 1.2 TFLOP/s baseline N/A 317 | 1024 False boolean jvp_attn 0.634 33.84 20.7 TFLOP/s 1.56e-02 ✓ 318 | 319 | 1024 False none sdpa 4.646 546.84 1.4 TFLOP/s baseline N/A 320 | 1024 False none jvp_attn 0.423 33.84 31.1 TFLOP/s 3.91e-03 ✓ 321 | 322 | 1024 True none sdpa 5.566 550.84 0.6 TFLOP/s baseline N/A 323 | 1024 True none jvp_attn 0.466 33.84 14.1 TFLOP/s 1.56e-02 ✓ 324 | 325 | 2048 False additive sdpa 21.231 2052.19 1.2 TFLOP/s baseline N/A 326 | 2048 False additive jvp_attn 3.735 336.19 14.1 TFLOP/s 1.56e-02 ✓ 327 | 328 | 2048 False boolean sdpa 21.626 2244.19 1.2 TFLOP/s baseline N/A 329 | 2048 False boolean jvp_attn 1.926 66.19 27.3 TFLOP/s 1.56e-02 ✓ 330 | 331 | 2048 False none sdpa 18.311 2052.19 1.4 TFLOP/s baseline N/A 332 | 2048 False none jvp_attn 1.139 66.19 46.1 TFLOP/s 3.91e-03 ✓ 333 | 334 | 2048 True none sdpa 20.748 2068.19 0.6 TFLOP/s baseline N/A 335 | 2048 True 
none jvp_attn 0.750 66.19 35.0 TFLOP/s 3.12e-02 ✓ 336 | 337 | 338 | ================================================================================ 339 | MASK TYPE PERFORMANCE COMPARISON 340 | ================================================================================ 341 | Seq Len Causal Method No Mask Boolean Mask Additive Mask 342 | -------------------------------------------------------------------------------- 343 | 32 False jvp_attn 0.45 ms 0.57 ms (1.27x) 0.77 ms (1.73x) 344 | 32 True jvp_attn 0.47 ms N/A N/A 345 | 64 False jvp_attn 0.43 ms 0.52 ms (1.21x) 1.15 ms (2.69x) 346 | 64 True jvp_attn 0.46 ms N/A N/A 347 | 128 False jvp_attn 0.47 ms 0.77 ms (1.65x) 0.90 ms (1.92x) 348 | 128 True jvp_attn 0.45 ms N/A N/A 349 | 256 False jvp_attn 0.44 ms 0.53 ms (1.21x) 0.73 ms (1.66x) 350 | 256 True jvp_attn 0.47 ms N/A N/A 351 | 512 False jvp_attn 0.40 ms 0.55 ms (1.37x) 0.71 ms (1.76x) 352 | 512 True jvp_attn 0.43 ms N/A N/A 353 | 1024 False jvp_attn 0.42 ms 0.63 ms (1.50x) 1.13 ms (2.68x) 354 | 1024 True jvp_attn 0.47 ms N/A N/A 355 | 2048 False jvp_attn 1.14 ms 1.93 ms (1.69x) 3.74 ms (3.28x) 356 | 2048 True jvp_attn 0.75 ms N/A N/A 357 | 358 | ============================================================ 359 | STATISTICS 360 | ============================================================ 361 | Average speedup: 4.75x 362 | Min speedup: 0.80x 363 | Max speedup: 27.65x 364 | 365 | Accuracy: 26/28 tests passed 366 | ⚠️ Some accuracy checks failed 367 | 368 | Failed configurations: 369 | - Seq=32, Causal=False, Mask=additive 370 | - Seq=64, Causal=False, Mask=additive 371 | ``` 372 | 373 | Full results for `float32`: 374 | 375 | ``` 376 | ============================================================================================================== 377 | BENCHMARK SUMMARY 378 | ============================================================================================================== 379 | Seq Len Causal Mask Method Time (ms) Mem (MB) TFLOP/s Max Error Grad Check 380 | -------------------------------------------------------------------------------------------------------------- 381 | 32 False additive sdpa 0.770 2.44 0.0 TFLOP/s baseline N/A 382 | 32 False additive jvp_attn 0.812 2.16 0.0 TFLOP/s 2.31e-02 ✓ 383 | 384 | 32 False boolean sdpa 0.830 2.53 0.0 TFLOP/s baseline N/A 385 | 32 False boolean jvp_attn 0.575 2.07 0.0 TFLOP/s 9.16e-03 ✓ 386 | 387 | 32 False none sdpa 0.491 2.44 0.0 TFLOP/s baseline N/A 388 | 32 False none jvp_attn 0.528 2.07 0.0 TFLOP/s 7.83e-03 ✓ 389 | 390 | 32 True none sdpa 0.831 2.44 0.0 TFLOP/s baseline N/A 391 | 32 True none jvp_attn 0.457 2.07 0.0 TFLOP/s 8.60e-03 ✓ 392 | 393 | 64 False additive sdpa 0.859 5.25 0.0 TFLOP/s baseline N/A 394 | 64 False additive jvp_attn 0.793 4.51 0.1 TFLOP/s 1.24e-02 ✓ 395 | 396 | 64 False boolean sdpa 0.778 5.62 0.0 TFLOP/s baseline N/A 397 | 64 False boolean jvp_attn 0.522 4.13 0.1 TFLOP/s 1.23e-02 ✓ 398 | 399 | 64 False none sdpa 0.501 5.25 0.1 TFLOP/s baseline N/A 400 | 64 False none jvp_attn 0.437 4.13 0.1 TFLOP/s 7.03e-03 ✓ 401 | 402 | 64 True none sdpa 0.810 5.27 0.0 TFLOP/s baseline N/A 403 | 64 True none jvp_attn 0.450 4.13 0.1 TFLOP/s 1.05e-02 ✓ 404 | 405 | 128 False additive sdpa 0.869 13.51 0.1 TFLOP/s baseline N/A 406 | 128 False additive jvp_attn 0.697 9.76 0.3 TFLOP/s 9.14e-03 ✓ 407 | 408 | 128 False boolean sdpa 0.832 15.76 0.1 TFLOP/s baseline N/A 409 | 128 False boolean jvp_attn 0.527 8.26 0.4 TFLOP/s 8.91e-03 ✓ 410 | 411 | 128 False none sdpa 0.458 14.26 0.2 TFLOP/s baseline N/A 412 | 128 False none 
jvp_attn 0.610 8.26 0.3 TFLOP/s 5.07e-03 ✓ 413 | 414 | 128 True none sdpa 0.817 14.32 0.1 TFLOP/s baseline N/A 415 | 128 True none jvp_attn 0.478 8.26 0.2 TFLOP/s 1.05e-02 ✓ 416 | 417 | 256 False additive sdpa 0.786 43.27 0.5 TFLOP/s baseline N/A 418 | 256 False additive jvp_attn 0.689 23.77 1.2 TFLOP/s 9.98e-03 ✓ 419 | 420 | 256 False boolean sdpa 0.754 48.52 0.5 TFLOP/s baseline N/A 421 | 256 False boolean jvp_attn 0.514 17.02 1.6 TFLOP/s 9.93e-03 ✓ 422 | 423 | 256 False none sdpa 0.596 43.27 0.7 TFLOP/s baseline N/A 424 | 256 False none jvp_attn 0.461 17.77 1.8 TFLOP/s 4.03e-03 ✓ 425 | 426 | 256 True none sdpa 0.837 43.52 0.2 TFLOP/s baseline N/A 427 | 256 True none jvp_attn 0.424 17.77 1.0 TFLOP/s 9.61e-03 ✓ 428 | 429 | 512 False additive sdpa 1.383 144.80 1.2 TFLOP/s baseline N/A 430 | 512 False additive jvp_attn 0.793 57.80 4.1 TFLOP/s 7.29e-03 ✓ 431 | 432 | 512 False boolean sdpa 1.342 168.80 1.2 TFLOP/s baseline N/A 433 | 512 False boolean jvp_attn 0.792 33.80 4.1 TFLOP/s 7.22e-03 ✓ 434 | 435 | 512 False none sdpa 1.479 144.80 1.1 TFLOP/s baseline N/A 436 | 512 False none jvp_attn 0.453 33.80 7.2 TFLOP/s 3.94e-03 ✓ 437 | 438 | 512 True none sdpa 1.582 145.80 0.5 TFLOP/s baseline N/A 439 | 512 True none jvp_attn 0.501 33.80 3.3 TFLOP/s 6.56e-03 ✓ 440 | 441 | 1024 False additive sdpa 5.448 528.09 1.2 TFLOP/s baseline N/A 442 | 1024 False additive jvp_attn 2.148 168.09 6.1 TFLOP/s 6.90e-03 ✓ 443 | 444 | 1024 False boolean sdpa 5.131 624.09 1.3 TFLOP/s baseline N/A 445 | 1024 False boolean jvp_attn 1.139 66.09 11.5 TFLOP/s 6.85e-03 ✓ 446 | 447 | 1024 False none sdpa 4.339 528.09 1.5 TFLOP/s baseline N/A 448 | 1024 False none jvp_attn 0.538 66.09 24.4 TFLOP/s 2.92e-03 ✓ 449 | 450 | 1024 True none sdpa 5.262 532.09 0.6 TFLOP/s baseline N/A 451 | 1024 True none jvp_attn 0.437 66.09 15.0 TFLOP/s 8.65e-03 ✓ 452 | 453 | 2048 False additive sdpa 19.849 2016.19 1.3 TFLOP/s baseline N/A 454 | 2048 False additive jvp_attn 6.468 576.19 8.1 TFLOP/s 7.22e-03 ✓ 455 | 456 | 2048 False boolean sdpa 20.017 2400.19 1.3 TFLOP/s baseline N/A 457 | 2048 False boolean jvp_attn 3.247 132.19 16.2 TFLOP/s 7.16e-03 ✓ 458 | 459 | 2048 False none sdpa 16.573 2016.19 1.6 TFLOP/s baseline N/A 460 | 2048 False none jvp_attn 1.883 132.19 27.9 TFLOP/s 2.61e-03 ✓ 461 | 462 | 2048 True none sdpa 19.577 2032.19 0.7 TFLOP/s baseline N/A 463 | 2048 True none jvp_attn 1.114 132.19 23.6 TFLOP/s 7.71e-03 ✓ 464 | 465 | 466 | ================================================================================ 467 | MASK TYPE PERFORMANCE COMPARISON 468 | ================================================================================ 469 | Seq Len Causal Method No Mask Boolean Mask Additive Mask 470 | -------------------------------------------------------------------------------- 471 | 32 False jvp_attn 0.53 ms 0.57 ms (1.09x) 0.81 ms (1.54x) 472 | 32 True jvp_attn 0.46 ms N/A N/A 473 | 64 False jvp_attn 0.44 ms 0.52 ms (1.20x) 0.79 ms (1.81x) 474 | 64 True jvp_attn 0.45 ms N/A N/A 475 | 128 False jvp_attn 0.61 ms 0.53 ms (0.87x) 0.70 ms (1.14x) 476 | 128 True jvp_attn 0.48 ms N/A N/A 477 | 256 False jvp_attn 0.46 ms 0.51 ms (1.11x) 0.69 ms (1.49x) 478 | 256 True jvp_attn 0.42 ms N/A N/A 479 | 512 False jvp_attn 0.45 ms 0.79 ms (1.75x) 0.79 ms (1.75x) 480 | 512 True jvp_attn 0.50 ms N/A N/A 481 | 1024 False jvp_attn 0.54 ms 1.14 ms (2.12x) 2.15 ms (3.99x) 482 | 1024 True jvp_attn 0.44 ms N/A N/A 483 | 2048 False jvp_attn 1.88 ms 3.25 ms (1.72x) 6.47 ms (3.43x) 484 | 2048 True jvp_attn 1.11 ms N/A N/A 485 | 486 | 
============================================================ 487 | STATISTICS 488 | ============================================================ 489 | Average speedup: 3.37x 490 | Min speedup: 0.75x 491 | Max speedup: 17.57x 492 | 493 | Accuracy: 28/28 tests passed 494 | ✓ All accuracy checks passed! 495 | ``` 496 | 497 | Note: Based on these results, for all precision types, it is recommended to provide a boolean `attn_mask` to `jvp_attention()` where possible. 498 | 499 | ## License 500 | 501 | This project is covered under the **MIT License**. 502 | 503 | ## Copyright 504 | 505 | JVP Flash Attention (jvp_flash_attention) Copyright (c) 2025, The Regents of the University of California, through Lawrence Berkeley National Laboratory (subject to receipt of any required approvals from the U.S. Dept. of Energy). All rights reserved. 506 | 507 | If you have questions about your rights to use or distribute this software, 508 | please contact Berkeley Lab's Intellectual Property Office at 509 | IPO@lbl.gov. 510 | 511 | **NOTICE.** This Software was developed under funding from the U.S. Department 512 | of Energy and the U.S. Government consequently retains certain rights. As 513 | such, the U.S. Government has been granted for itself and others acting on 514 | its behalf a paid-up, nonexclusive, irrevocable, worldwide license in the 515 | Software to reproduce, distribute copies to the public, prepare derivative 516 | works, and perform publicly and display publicly, and to permit others to do so. 517 | 518 | ## Citing this work 519 | 520 | If you use the code associated with this package or otherwise find this work useful, please use GitHub's `Cite this repository` feature or the BibTeX below. 521 | 522 | ```bibtex 523 | @software{Morehead_JVP_Flash_Attention_2025, 524 | author = {Morehead, Alex}, 525 | doi = {10.5281/zenodo.17050188}, 526 | license = {MIT}, 527 | month = sep, 528 | title = {{JVP Flash Attention}}, 529 | url = {https://github.com/amorehead/jvp_flash_attention}, 530 | version = {0.10.0}, 531 | year = {2025} 532 | } 533 | ``` 534 | 535 | ## Acknowledgements 536 | 537 | `jvp_flash_attention` builds upon the contributions and insights from the following sources: 538 | 539 | - [flash-attention](https://github.com/Dao-AILab/flash-attention) 540 | - [JVP Triton kernel thread](https://github.com/Dao-AILab/flash-attention/issues/1672) 541 | - [benjamin-dinkelmann](https://gist.github.com/benjamin-dinkelmann) 542 | - *[Birch-san](https://github.com/Birch-san)* 543 | - [dabeschte](https://github.com/dabeschte) 544 | - [IsaacYQH](https://gist.github.com/IsaacYQH) 545 | - [KohakuBlueleaf](https://github.com/KohakuBlueleaf) 546 | - [leon](https://github.com/leon532) 547 | - [limsanky](https://github.com/limsanky) 548 | - [lucidrains](https://github.com/lucidrains) 549 | - [Peterande](https://gist.github.com/Peterande) 550 | - *[Ryu1845](https://github.com/Ryu1845)* 551 | - [tridao](https://github.com/tridao) 552 | 553 | Thank you to each and every contributor! 
554 | -------------------------------------------------------------------------------- /tests/test_jvp_attention.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import gc 4 | import os 5 | import random 6 | import time 7 | from argparse import ArgumentParser, Namespace 8 | from dataclasses import dataclass, field 9 | from functools import partial 10 | from typing import Any, Callable, NamedTuple 11 | 12 | import torch 13 | import torch.autograd.forward_ad as fwAD 14 | from torch import Tensor, enable_grad 15 | from torch.nn import MSELoss 16 | from torch.nn.attention import SDPBackend, sdpa_kernel 17 | from torch.nn.functional import scaled_dot_product_attention 18 | 19 | try: 20 | import matplotlib.pyplot as plt 21 | import numpy as np 22 | 23 | PLOTTING_AVAILABLE = True 24 | except ImportError: 25 | PLOTTING_AVAILABLE = False 26 | 27 | from jvp_flash_attention.jvp_attention import MASK_CONST, JVPAttn 28 | 29 | 30 | def mpi_to_flops(ms_per_iter: float, flop_count: int) -> float: 31 | """Convert milliseconds per iteration to FLOPS. 32 | 33 | Args: 34 | ms_per_iter: Milliseconds per iteration. 35 | flop_count: Number of floating point operations. 36 | 37 | Returns: 38 | The number of FLOPS. 39 | """ 40 | iters_per_second = 1e3 / ms_per_iter 41 | return iters_per_second * flop_count 42 | 43 | 44 | def fmt_flops(flops: int) -> str: 45 | """Return a string representation of FLOPS in TFLOP/s.""" 46 | return f"{flops / 1e12:5.1f} TFLOP/s" 47 | 48 | 49 | def get_attention_flop_count( 50 | batch_size: int, 51 | num_heads: int, 52 | seq_len: int, 53 | head_dim: int, 54 | is_causal: bool, 55 | is_jvp: bool = False, 56 | ) -> int: 57 | """Calculate FLOPs for attention operations. 58 | 59 | Args: 60 | batch_size: Batch size. 61 | num_heads: Number of attention heads. 62 | seq_len: Sequence length. 63 | head_dim: Dimension of each attention head. 64 | is_causal: Whether the attention is causal. 65 | is_jvp: Whether to include JVP (Jacobian-vector product) FLOPs. 66 | 67 | Returns: 68 | The total FLOPs for the attention operation. 69 | """ 70 | # Base attention FLOPs 71 | qk_flops = 2 * batch_size * num_heads * seq_len * seq_len * head_dim 72 | softmax_flops = 5 * batch_size * num_heads * seq_len * seq_len 73 | av_flops = 2 * batch_size * num_heads * seq_len * seq_len * head_dim 74 | 75 | total_flops = qk_flops + softmax_flops + av_flops 76 | 77 | if is_causal: 78 | total_flops = total_flops // 2 79 | 80 | if is_jvp: 81 | total_flops = total_flops * 2 82 | 83 | return total_flops 84 | 85 | 86 | def measure_memory_usage(f: Callable[[], Any]) -> tuple[float, float]: 87 | """Measure GPU memory usage of a function. 88 | 89 | Args: 90 | f: The function to measure. 91 | 92 | Returns: 93 | Tuple of (allocated_mb, reserved_mb) memory in megabytes. 
94 | """ 95 | torch.cuda.synchronize() 96 | torch.cuda.reset_peak_memory_stats() 97 | gc.collect() 98 | torch.cuda.empty_cache() 99 | 100 | initial_allocated = torch.cuda.memory_allocated() 101 | initial_reserved = torch.cuda.memory_reserved() 102 | 103 | f() 104 | 105 | torch.cuda.synchronize() 106 | 107 | peak_allocated = torch.cuda.max_memory_allocated() 108 | peak_reserved = torch.cuda.max_memory_reserved() 109 | 110 | allocated_mb = (peak_allocated - initial_allocated) / (1024 * 1024) 111 | reserved_mb = (peak_reserved - initial_reserved) / (1024 * 1024) 112 | 113 | return allocated_mb, reserved_mb 114 | 115 | 116 | def benchmark_function( 117 | f: Callable[[], Any], warmup_iters: int = 10, benchmark_iters: int = 100 118 | ) -> float: 119 | """Benchmark a function's execution time. 120 | 121 | Args: 122 | f: The function to benchmark. 123 | warmup_iters: Number of warmup iterations. 124 | benchmark_iters: Number of benchmark iterations. 125 | 126 | Returns: 127 | Average time per iteration in milliseconds. 128 | """ 129 | # Warmup 130 | for _ in range(warmup_iters): 131 | f() 132 | 133 | torch.cuda.synchronize() 134 | start_time = time.perf_counter() 135 | 136 | for _ in range(benchmark_iters): 137 | f() 138 | 139 | torch.cuda.synchronize() 140 | end_time = time.perf_counter() 141 | 142 | avg_time_ms = (end_time - start_time) * 1000 / benchmark_iters 143 | return avg_time_ms 144 | 145 | 146 | class QKV(NamedTuple): 147 | """Query, Key, Value tensors.""" 148 | 149 | q: Tensor 150 | k: Tensor 151 | v: Tensor 152 | 153 | 154 | class UnpackedDualQKV(NamedTuple): 155 | """Unpacked dual Query, Key, Value tensors.""" 156 | 157 | primal: QKV 158 | tangent: QKV 159 | 160 | 161 | @dataclass 162 | class AccuracyMetrics: 163 | """Accuracy metrics for numerical validation.""" 164 | 165 | primal_error: float 166 | tangent_error: float 167 | loss_error: float 168 | q_grad_error: float 169 | k_grad_error: float 170 | v_grad_error: float 171 | tolerance: float = 5e-3 172 | 173 | @property 174 | def max_error(self) -> float: 175 | """Return the maximum error across all metrics.""" 176 | return max( 177 | self.primal_error, 178 | self.tangent_error, 179 | self.loss_error, 180 | self.q_grad_error, 181 | self.k_grad_error, 182 | self.v_grad_error, 183 | ) 184 | 185 | def is_accurate(self) -> bool: 186 | """Check if all errors are within tolerance.""" 187 | return self.max_error < self.tolerance 188 | 189 | 190 | class BenchmarkResult(NamedTuple): 191 | """Results from a single benchmark run.""" 192 | 193 | seq_len: int 194 | is_causal: bool 195 | method: str # 'sdpa' or 'jvp_attn' 196 | time_ms: float 197 | memory_allocated_mb: float 198 | memory_reserved_mb: float 199 | mask_type: str # 'none', 'boolean', or 'additive' 200 | flops: int | None = None 201 | accuracy: AccuracyMetrics | None = None 202 | 203 | 204 | @dataclass 205 | class Args: 206 | """Training arguments.""" 207 | 208 | bsz: int 209 | model_dim: int 210 | head_dim: int 211 | seq_lengths: list[int] = field(default_factory=lambda: [32, 64, 128, 256, 512, 1024, 2048]) 212 | warmup_iters: int = 10 213 | benchmark_iters: int = 100 214 | dtype: str = "float16" 215 | seed: int = 42 216 | test_masks: bool = True 217 | validate_gradients: bool = True 218 | benchmark_performance: bool = True 219 | mask_prob: float = 0.9 # Probability of masking out an attention weight 220 | 221 | @staticmethod 222 | def get_parser() -> ArgumentParser: 223 | """Get the argument parser for training.""" 224 | parser = ArgumentParser() 225 | parser.add_argument("--bsz", 
default=2, type=int) 226 | parser.add_argument("--model-dim", default=768, type=int) 227 | parser.add_argument("--head-dim", default=64, type=int) 228 | parser.add_argument( 229 | "--seq-lengths", nargs="+", type=int, default=[32, 64, 128, 256, 512, 1024, 2048] 230 | ) 231 | parser.add_argument("--warmup-iters", default=10, type=int) 232 | parser.add_argument("--benchmark-iters", default=100, type=int) 233 | parser.add_argument( 234 | "--dtype", default="float16", choices=["float16", "float32", "bfloat16"] 235 | ) 236 | parser.add_argument("--seed", default=42, type=int) 237 | parser.add_argument( 238 | "--no-test-masks", 239 | action="store_true", 240 | help="Skip testing with attention masks", 241 | ) 242 | parser.add_argument( 243 | "--no-validate-gradients", 244 | action="store_true", 245 | help="Skip gradient validation", 246 | ) 247 | parser.add_argument( 248 | "--no-benchmark-performance", 249 | action="store_true", 250 | help="Skip performance benchmarking", 251 | ) 252 | parser.add_argument( 253 | "--mask-prob", 254 | default=0.9, 255 | type=float, 256 | help="Probability of masking out attention weights", 257 | ) 258 | return parser 259 | 260 | @staticmethod 261 | def from_namespace(namespace: Namespace) -> Args: 262 | """Create Args from a namespace.""" 263 | kwargs = vars(namespace) 264 | 265 | test_masks = not kwargs.pop("no_test_masks", False) 266 | validate_gradients = not kwargs.pop("no_validate_gradients", False) 267 | benchmark_performance = not kwargs.pop("no_benchmark_performance", False) 268 | 269 | kwargs["test_masks"] = test_masks 270 | kwargs["validate_gradients"] = validate_gradients 271 | kwargs["benchmark_performance"] = benchmark_performance 272 | 273 | return Args(**kwargs) 274 | 275 | 276 | def create_test_tensors( 277 | args: Args, seq_len: int, device: torch.device, dtype: torch.dtype 278 | ) -> tuple[Tensor, ...]: 279 | """Create test tensors for benchmarking. 280 | 281 | Args: 282 | args: The training arguments. 283 | seq_len: The sequence length. 284 | device: The device to create the tensors on. 285 | dtype: The data type of the tensors. 286 | 287 | Returns: 288 | Tuple of (q_p, q_t, k_p, k_t, v_p, v_t, target) tensors. 289 | """ 290 | gen = torch.Generator(device=device).manual_seed(args.seed) 291 | heads = args.model_dim // args.head_dim 292 | 293 | tensors = tuple( 294 | torch.randn( 295 | args.bsz, 296 | heads, 297 | seq_len, 298 | args.head_dim, 299 | device=device, 300 | dtype=dtype, 301 | generator=gen, 302 | ) 303 | for _ in range(7) 304 | ) 305 | 306 | return tensors 307 | 308 | 309 | def create_attention_mask( 310 | args: Args, 311 | seq_len: int, 312 | device: torch.device, 313 | dtype: torch.dtype, 314 | mask_type: str, 315 | ) -> Tensor | None: 316 | """Create an attention mask for testing. 317 | 318 | Args: 319 | args: The training arguments. 320 | seq_len: The sequence length. 321 | device: The device to create the mask on. 322 | dtype: The data type of the mask. 323 | mask_type: Type of mask ('none', 'boolean', or 'additive'). 324 | 325 | Returns: 326 | The attention mask tensor or None if mask_type is 'none'. 
327 | """ 328 | if mask_type == "none": 329 | return None 330 | 331 | gen = torch.Generator(device=device).manual_seed(args.seed + 1000) # Different seed for masks 332 | heads = args.model_dim // args.head_dim 333 | 334 | if mask_type == "boolean": 335 | # Create a boolean mask where True means "attend" and False means "ignore" 336 | # We'll create a random mask with some positions masked out 337 | mask = ( 338 | torch.rand(args.bsz, heads, seq_len, seq_len, device=device, generator=gen) 339 | > args.mask_prob 340 | ) 341 | # mask[0, :-1, :, :2] = ( 342 | # True # Ensure first two columns of the first batch element (except for its last head) are True 343 | # ) 344 | # mask[1, :-1, :, -2:] = ( 345 | # True # Ensure last two columns of the second batch element (except for its last head) are True 346 | # ) 347 | 348 | # Find completely masked heads 349 | fully_masked = ~mask.view(args.bsz, heads, -1).any(dim=2) 350 | 351 | # For each fully masked head, unmask some random positions 352 | if fully_masked.any(): 353 | print(" ⚠️ Some heads were fully masked; unmasking some positions to avoid this.") 354 | for b in range(args.bsz): 355 | for h in range(heads): 356 | if fully_masked[b, h]: 357 | num_to_unmask = max(1, seq_len * seq_len // 10) 358 | indices = torch.randperm(seq_len * seq_len, device=device, generator=gen)[ 359 | :num_to_unmask 360 | ] 361 | mask[b, h].view(-1)[indices] = True 362 | 363 | return mask 364 | 365 | elif mask_type == "additive": 366 | # Create an additive mask with values to be added to attention scores 367 | # Use -inf (MASK_CONST) for positions to ignore, 0 for positions to attend 368 | rand_mask = torch.rand(args.bsz, heads, seq_len, seq_len, device=device, generator=gen) 369 | mask = torch.where(rand_mask > args.mask_prob, 0.0, MASK_CONST) 370 | # Convert to the target dtype 371 | mask = mask.to(dtype) 372 | # mask[0, :-1, :, :2] = ( 373 | # 0.0 # Ensure first two columns of the first batch element (except for its last head) are zeros 374 | # ) 375 | # mask[1, :-1, :, -2:] = ( 376 | # 0.0 # Ensure last two columns of the second batch element (except for its last head) are zeros 377 | # ) 378 | 379 | # Find completely masked heads 380 | fully_masked = (mask.view(args.bsz, heads, -1) == MASK_CONST).all(dim=2) 381 | 382 | # For each fully masked head, unmask some random positions 383 | if fully_masked.any(): 384 | print(" ⚠️ Some heads were fully masked; unmasking some positions to avoid this.") 385 | for b in range(args.bsz): 386 | for h in range(heads): 387 | if fully_masked[b, h]: 388 | num_to_unmask = max(1, seq_len * seq_len // 10) 389 | indices = torch.randperm(seq_len * seq_len, device=device, generator=gen)[ 390 | :num_to_unmask 391 | ] 392 | mask[b, h].view(-1)[indices] = 0.0 393 | 394 | return mask 395 | 396 | else: 397 | raise ValueError(f"Unknown mask type: {mask_type}") 398 | 399 | 400 | def loss_fn(out: Tensor, target: Tensor) -> Tensor: 401 | """Compute the mean squared error loss. 402 | 403 | Args: 404 | out: The output tensor. 405 | target: The target tensor. 406 | 407 | Returns: 408 | The mean squared error loss. 409 | """ 410 | return (out - target).square().mean() 411 | 412 | 413 | def make_qkv_with_grad( 414 | q_p: Tensor, k_p: Tensor, v_p: Tensor, q_t: Tensor, k_t: Tensor, v_t: Tensor 415 | ) -> QKV: 416 | """Make a QKV tuple with gradients enabled. 417 | 418 | Args: 419 | q_p: The query projection tensor. 420 | k_p: The key projection tensor. 421 | v_p: The value projection tensor. 422 | q_t: The query tangent tensor. 
423 | k_t: The key tangent tensor. 424 | v_t: The value tangent tensor. 425 | 426 | Returns: 427 | A QKV tuple containing the primal and tangent QKV tensors. 428 | """ 429 | # Create dual tensors 430 | q = fwAD.make_dual(q_p, q_t) 431 | k = fwAD.make_dual(k_p, k_t) 432 | v = fwAD.make_dual(v_p, v_t) 433 | 434 | for t in (q, k, v): 435 | t.requires_grad = True 436 | t.retain_grad() 437 | 438 | return QKV(q, k, v) 439 | 440 | 441 | def make_qkv(q_p: Tensor, k_p: Tensor, v_p: Tensor, q_t: Tensor, k_t: Tensor, v_t: Tensor) -> QKV: 442 | """Make a QKV tuple from the given tensors with dual numbers. 443 | 444 | Args: 445 | q_p: The query projection tensor. 446 | k_p: The key projection tensor. 447 | v_p: The value projection tensor. 448 | q_t: The query tangent tensor. 449 | k_t: The key tangent tensor. 450 | v_t: The value tangent tensor. 451 | 452 | Returns: 453 | A QKV tuple containing the primal and tangent QKV tensors. 454 | """ 455 | q = fwAD.make_dual(q_p, q_t) 456 | k = fwAD.make_dual(k_p, k_t) 457 | v = fwAD.make_dual(v_p, v_t) 458 | return QKV(q, k, v) 459 | 460 | 461 | def make_qkv_unpacked( 462 | q_p: Tensor, k_p: Tensor, v_p: Tensor, q_t: Tensor, k_t: Tensor, v_t: Tensor 463 | ) -> UnpackedDualQKV: 464 | """Make an unpacked dual QKV from the given tensors. 465 | 466 | Args: 467 | q_p: The query projection tensor. 468 | k_p: The key projection tensor. 469 | v_p: The value projection tensor. 470 | q_t: The query tangent tensor. 471 | k_t: The key tangent tensor. 472 | v_t: The value tangent tensor. 473 | 474 | Returns: 475 | An unpacked dual QKV containing the primal and tangent QKV tensors. 476 | """ 477 | for t in (q_p, k_p, v_p): 478 | t.requires_grad = True 479 | t.retain_grad() 480 | 481 | return UnpackedDualQKV( 482 | primal=QKV( 483 | q=q_p, 484 | k=k_p, 485 | v=v_p, 486 | ), 487 | tangent=QKV( 488 | q=q_t, 489 | k=k_t, 490 | v=v_t, 491 | ), 492 | ) 493 | 494 | 495 | def compute_absolute_error(*tensors: Tensor) -> float: 496 | """Compute the maximum absolute pairwise error between all tensors. 497 | 498 | Args: 499 | tensors: The input tensors to compare. 500 | 501 | Returns: 502 | The maximum absolute pairwise error. 503 | """ 504 | if len(tensors) < 2: 505 | raise ValueError("At least two tensors are required to compute absolute error.") 506 | max_error = 0.0 507 | for i in range(len(tensors)): 508 | for j in range(i + 1, len(tensors)): 509 | diff = (tensors[i] - tensors[j]).abs().max().item() 510 | if diff > max_error: 511 | max_error = diff 512 | return max_error 513 | 514 | 515 | def validate_accuracy_and_gradients( 516 | q_p: Tensor, 517 | k_p: Tensor, 518 | v_p: Tensor, 519 | q_t: Tensor, 520 | k_t: Tensor, 521 | v_t: Tensor, 522 | target: Tensor, 523 | is_causal: bool, 524 | attn_mask: Tensor | None = None, 525 | tolerance: float = 4e-3, 526 | loss_tolerance: float = 5e-4, 527 | grad_tolerance: float = 5e-4, 528 | ) -> AccuracyMetrics: 529 | """Validate numerical accuracy and gradient matching between SDPA and JVP attention. 530 | 531 | Args: 532 | q_p: The query projection tensor. 533 | k_p: The key projection tensor. 534 | v_p: The value projection tensor. 535 | q_t: The query tangent tensor. 536 | k_t: The key tangent tensor. 537 | v_t: The value tangent tensor. 538 | target: The target tensor. 539 | is_causal: Whether the attention is causal. 540 | attn_mask: Optional attention mask tensor. 541 | tolerance: The tolerance for primal errors. 542 | loss_tolerance: The tolerance for loss errors. 543 | grad_tolerance: The tolerance for gradient errors. 
544 | 545 | Returns: 546 | AccuracyMetrics containing all error measurements. 547 | """ 548 | with sdpa_kernel(SDPBackend.MATH), fwAD.dual_level(), enable_grad(): 549 | # Run SDPA 550 | q0, k0, v0 = make_qkv_with_grad( 551 | q_p.clone(), k_p.clone(), v_p.clone(), q_t.clone(), k_t.clone(), v_t.clone() 552 | ) 553 | 554 | sdpa_out = scaled_dot_product_attention( 555 | q0, k0, v0, attn_mask=attn_mask, is_causal=is_causal 556 | ) 557 | sdpa_out.retain_grad() 558 | sdpa_op, sdpa_ot = fwAD.unpack_dual(sdpa_out) 559 | 560 | loss0 = loss_fn(sdpa_out, target) 561 | loss0.backward() 562 | 563 | assert not any( 564 | t.grad.isnan().any() or t.grad.isinf().any() for t in (q0, k0, v0) 565 | ), "NaN/Inf in SDPA input gradients." 566 | 567 | # Run JVP Attention 568 | q1, k1, v1 = make_qkv_with_grad( 569 | q_p.clone(), k_p.clone(), v_p.clone(), q_t.clone(), k_t.clone(), v_t.clone() 570 | ) 571 | 572 | jvp_out = JVPAttn.fwd_dual(q1, k1, v1, attn_mask=attn_mask, causal=is_causal) 573 | jvp_out.retain_grad() 574 | jvp_op, jvp_ot = fwAD.unpack_dual(jvp_out) 575 | 576 | loss1 = loss_fn(jvp_out, target) 577 | loss1.backward() 578 | 579 | assert not any( 580 | t.grad.isnan().any() or t.grad.isinf().any() for t in (q1, k1, v1) 581 | ), "NaN/Inf in JVP input gradients." 582 | 583 | mse_fn = MSELoss() 584 | with enable_grad(): 585 | # Run JVP Attention with torch.func.jvp 586 | qkv_p, qkv_t = make_qkv_unpacked( 587 | q_p.clone(), 588 | k_p.clone(), 589 | v_p.clone(), 590 | q_t.clone(), 591 | k_t.clone(), 592 | v_t.clone(), 593 | ) 594 | 595 | jvp_func_op, jvp_func_ot = torch.func.jvp( 596 | partial(JVPAttn.fwd_dual, attn_mask=attn_mask, causal=is_causal), qkv_p, qkv_t 597 | ) 598 | jvp_func_op.retain_grad() 599 | 600 | loss2: Tensor = mse_fn(jvp_func_op, target) 601 | loss2.backward() 602 | 603 | q2, k2, v2 = qkv_p 604 | 605 | assert not any( 606 | t.grad.isnan().any() or t.grad.isinf().any() for t in (q2, k2, v2) 607 | ), "NaN/Inf in JVP (func) input gradients." 
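# NOTE: At this point three equivalent computations of the same attention call have been run:
#   (0) PyTorch SDPA under the MATH backend (reference primal, tangent, loss, and input gradients),
#   (1) JVPAttn.fwd_dual called directly on forward-mode dual tensors, and
#   (2) JVPAttn.fwd_dual driven through torch.func.jvp.
# The error metrics computed next take the maximum absolute difference over every pair of the
# three, so the two JVP entry points are also checked against each other, not only against SDPA.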
608 | 609 | # Compute errors 610 | primal_error = compute_absolute_error(jvp_func_op, jvp_op, sdpa_op) 611 | tangent_error = compute_absolute_error(jvp_func_ot, jvp_ot, sdpa_ot) 612 | loss_error = compute_absolute_error(loss2, loss1, loss0) 613 | 614 | # Compute gradient errors 615 | q_grad_error = compute_absolute_error(q2.grad, q1.grad, q0.grad) 616 | k_grad_error = compute_absolute_error(k2.grad, k1.grad, k0.grad) 617 | v_grad_error = compute_absolute_error(v2.grad, v1.grad, v0.grad) 618 | 619 | metrics = AccuracyMetrics( 620 | primal_error=primal_error, 621 | tangent_error=tangent_error, 622 | loss_error=loss_error, 623 | q_grad_error=q_grad_error, 624 | k_grad_error=k_grad_error, 625 | v_grad_error=v_grad_error, 626 | ) 627 | 628 | # Validate using torch.testing.assert_close 629 | try: 630 | torch.testing.assert_close(jvp_op, sdpa_op, atol=tolerance, rtol=1e-5) 631 | torch.testing.assert_close( 632 | # TODO: Improve this (causal) accuracy for longer sequence lengths 633 | jvp_func_op, 634 | sdpa_op, 635 | atol=tolerance, 636 | rtol=1e-5, 637 | ) 638 | 639 | # TODO: Improve these tangent accuracies 640 | torch.testing.assert_close( 641 | jvp_ot, 642 | sdpa_ot, 643 | atol=tolerance, 644 | rtol=1e-5, 645 | ) 646 | torch.testing.assert_close( 647 | jvp_func_ot, 648 | sdpa_ot, 649 | atol=tolerance, 650 | rtol=1e-5, 651 | ) 652 | 653 | torch.testing.assert_close(loss1, loss0, atol=loss_tolerance, rtol=1e-5) 654 | torch.testing.assert_close(loss2, loss0, atol=loss_tolerance, rtol=1e-5) 655 | 656 | torch.testing.assert_close(q1.grad, q0.grad, atol=grad_tolerance, rtol=1e-5) 657 | torch.testing.assert_close(k1.grad, k0.grad, atol=grad_tolerance, rtol=1e-5) 658 | torch.testing.assert_close(v1.grad, v0.grad, atol=grad_tolerance, rtol=1e-5) 659 | 660 | torch.testing.assert_close(q2.grad, q0.grad, atol=grad_tolerance, rtol=1e-5) 661 | torch.testing.assert_close(k2.grad, k0.grad, atol=grad_tolerance, rtol=1e-5) 662 | torch.testing.assert_close(v2.grad, v0.grad, atol=grad_tolerance, rtol=1e-5) 663 | 664 | except AssertionError as e: 665 | print(f" ⚠️ Accuracy validation failed (causal={is_causal}): {e}") 666 | 667 | return metrics 668 | 669 | 670 | def run_benchmark_suite(args: Args) -> list[BenchmarkResult]: 671 | """Run comprehensive benchmarks across different configurations. 672 | 673 | Args: 674 | args: The command-line arguments for the benchmark. 675 | 676 | Returns: 677 | A list of benchmark results. 
678 | """ 679 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 680 | 681 | dtype_map = { 682 | "float16": torch.float16, 683 | "float32": torch.float32, 684 | "bfloat16": torch.bfloat16, 685 | } 686 | dtype = dtype_map[args.dtype] 687 | 688 | tolerance_map = { 689 | "float16": 4e-3, 690 | "float32": 2.35e-2, 691 | "bfloat16": 3.2e-2, 692 | } 693 | tolerance = tolerance_map[args.dtype] 694 | 695 | results = [] 696 | 697 | # Define mask types to test 698 | mask_types = ["none"] 699 | if args.test_masks: 700 | mask_types.extend(["boolean", "additive"]) 701 | 702 | for seq_len in args.seq_lengths: 703 | print(f"\n{'='*60}") 704 | print(f"Benchmarking sequence length: {seq_len}") 705 | print(f"{'='*60}") 706 | 707 | # Create test tensors 708 | q_p, q_t, k_p, k_t, v_p, v_t, target = create_test_tensors(args, seq_len, device, dtype) 709 | 710 | for mask_type in mask_types: 711 | # Create attention mask if needed 712 | attn_mask = create_attention_mask(args, seq_len, device, dtype, mask_type) 713 | 714 | for is_causal in [False, True]: 715 | print(f"\nCausal: {is_causal}, Mask: {mask_type}") 716 | print("-" * 40) 717 | 718 | if is_causal and mask_type != "none": 719 | print(" Skipping invalid combination of causal + mask") 720 | continue 721 | 722 | # Validate accuracy and gradients first 723 | if args.validate_gradients: 724 | print("Validating accuracy and gradients...") 725 | accuracy_metrics = validate_accuracy_and_gradients( 726 | q_p, 727 | k_p, 728 | v_p, 729 | q_t, 730 | k_t, 731 | v_t, 732 | target, 733 | is_causal, 734 | attn_mask=attn_mask, 735 | tolerance=tolerance, 736 | ) 737 | accuracy_metrics.tolerance = tolerance 738 | 739 | print(f" Primal error: {accuracy_metrics.primal_error:.2e}") 740 | print(f" Tangent error: {accuracy_metrics.tangent_error:.2e}") 741 | print(f" Loss error: {accuracy_metrics.loss_error:.2e}") 742 | print(f" Q gradient error: {accuracy_metrics.q_grad_error:.2e}") 743 | print(f" K gradient error: {accuracy_metrics.k_grad_error:.2e}") 744 | print(f" V gradient error: {accuracy_metrics.v_grad_error:.2e}") 745 | 746 | if accuracy_metrics.is_accurate(): 747 | print(" ✓ All accuracy checks passed!") 748 | else: 749 | print(f" ⚠️ Max error {accuracy_metrics.max_error:.2e} exceeds tolerance") 750 | else: 751 | accuracy_metrics = None 752 | 753 | # Benchmark performance 754 | with sdpa_kernel(SDPBackend.MATH), fwAD.dual_level(), enable_grad(): 755 | # Create functions for benchmarking 756 | def run_sdpa(): 757 | """Run SDPA attention.""" 758 | q, k, v = make_qkv( 759 | q_p.clone(), 760 | k_p.clone(), 761 | v_p.clone(), 762 | q_t.clone(), 763 | k_t.clone(), 764 | v_t.clone(), 765 | ) 766 | out = scaled_dot_product_attention( 767 | q, k, v, attn_mask=attn_mask, is_causal=is_causal 768 | ) 769 | 770 | def run_jvp_attn(): 771 | """Run JVP attention.""" 772 | q, k, v = make_qkv( 773 | q_p.clone(), 774 | k_p.clone(), 775 | v_p.clone(), 776 | q_t.clone(), 777 | k_t.clone(), 778 | v_t.clone(), 779 | ) 780 | out = JVPAttn.fwd_dual(q, k, v, attn_mask=attn_mask, causal=is_causal) 781 | 782 | if not args.benchmark_performance: 783 | print(" Skipping performance benchmarking.") 784 | results.append( 785 | BenchmarkResult( 786 | seq_len=seq_len, 787 | is_causal=is_causal, 788 | method="sdpa", 789 | time_ms=np.nan, 790 | memory_allocated_mb=np.nan, 791 | memory_reserved_mb=np.nan, 792 | mask_type=mask_type, 793 | flops=np.nan, 794 | accuracy=None, 795 | ) 796 | ) 797 | results.append( 798 | BenchmarkResult( 799 | seq_len=seq_len, 800 | is_causal=is_causal, 801 
| method="jvp_attn", 802 | time_ms=np.nan, 803 | memory_allocated_mb=np.nan, 804 | memory_reserved_mb=np.nan, 805 | mask_type=mask_type, 806 | flops=np.nan, 807 | accuracy=accuracy_metrics, 808 | ) 809 | ) 810 | continue 811 | 812 | print("\nBenchmarking performance...") 813 | heads = args.model_dim // args.head_dim 814 | 815 | # Measure SDPA performance 816 | sdpa_time = benchmark_function( 817 | run_sdpa, args.warmup_iters, args.benchmark_iters 818 | ) 819 | sdpa_mem_alloc, sdpa_mem_reserved = measure_memory_usage(run_sdpa) 820 | sdpa_flops = get_attention_flop_count( 821 | args.bsz, heads, seq_len, args.head_dim, is_causal, is_jvp=False 822 | ) 823 | 824 | # Measure JVP Attention performance 825 | jvp_time = benchmark_function( 826 | run_jvp_attn, args.warmup_iters, args.benchmark_iters 827 | ) 828 | jvp_mem_alloc, jvp_mem_reserved = measure_memory_usage(run_jvp_attn) 829 | jvp_flops = get_attention_flop_count( 830 | args.bsz, heads, seq_len, args.head_dim, is_causal, is_jvp=True 831 | ) 832 | 833 | # Store results 834 | results.append( 835 | BenchmarkResult( 836 | seq_len=seq_len, 837 | is_causal=is_causal, 838 | method="sdpa", 839 | time_ms=sdpa_time, 840 | memory_allocated_mb=sdpa_mem_alloc, 841 | memory_reserved_mb=sdpa_mem_reserved, 842 | mask_type=mask_type, 843 | flops=sdpa_flops, 844 | accuracy=None, 845 | ) 846 | ) 847 | 848 | results.append( 849 | BenchmarkResult( 850 | seq_len=seq_len, 851 | is_causal=is_causal, 852 | method="jvp_attn", 853 | time_ms=jvp_time, 854 | memory_allocated_mb=jvp_mem_alloc, 855 | memory_reserved_mb=jvp_mem_reserved, 856 | mask_type=mask_type, 857 | flops=jvp_flops, 858 | accuracy=accuracy_metrics, 859 | ) 860 | ) 861 | 862 | # Print results 863 | print("PyTorch SDPA:") 864 | print(f" Time: {sdpa_time:.3f} ms") 865 | print( 866 | f" Memory (alloc/reserved): {sdpa_mem_alloc:.2f}/{sdpa_mem_reserved:.2f} MB" 867 | ) 868 | print(f" FLOPS: {fmt_flops(mpi_to_flops(sdpa_time, sdpa_flops))}") 869 | 870 | print("\nJVP Attention:") 871 | print(f" Time: {jvp_time:.3f} ms") 872 | print( 873 | f" Memory (alloc/reserved): {jvp_mem_alloc:.2f}/{jvp_mem_reserved:.2f} MB" 874 | ) 875 | print(f" FLOPS: {fmt_flops(mpi_to_flops(jvp_time, jvp_flops))}") 876 | 877 | print(f"\nSpeedup: {sdpa_time/jvp_time:.2f}x") 878 | print(f"Memory ratio: {jvp_mem_alloc/sdpa_mem_alloc:.2f}x") 879 | 880 | return results 881 | 882 | 883 | def print_summary_table(results: list[BenchmarkResult]) -> None: 884 | """Print a summary table of benchmark results. 885 | 886 | Args: 887 | results: The list of benchmark results to summarize. 
888 | """ 889 | print("\n" + "=" * 110) 890 | print("BENCHMARK SUMMARY") 891 | print("=" * 110) 892 | 893 | # Group results by seq_len, causal, and mask_type 894 | from collections import defaultdict 895 | 896 | grouped = defaultdict(dict) 897 | 898 | for r in results: 899 | key = (r.seq_len, r.is_causal, r.mask_type) 900 | grouped[key][r.method] = r 901 | 902 | # Print header 903 | print( 904 | f"{'Seq Len':<10} {'Causal':<8} {'Mask':<10} {'Method':<10} " 905 | f"{'Time (ms)':<12} {'Mem (MB)':<12} {'TFLOP/s':<12} " 906 | f"{'Max Error':<12} {'Grad Check':<10}" 907 | ) 908 | print("-" * 110) 909 | 910 | for (seq_len, is_causal, mask_type), methods in sorted(grouped.items()): 911 | for method in ["sdpa", "jvp_attn"]: 912 | if method in methods: 913 | r = methods[method] 914 | flops_str = fmt_flops(mpi_to_flops(r.time_ms, r.flops)) if r.flops else "N/A" 915 | 916 | if r.accuracy: 917 | error_str = f"{r.accuracy.max_error:.2e}" 918 | grad_check = "✓" if r.accuracy.is_accurate() else "✗" 919 | else: 920 | error_str = "baseline" 921 | grad_check = "N/A" 922 | 923 | print( 924 | f"{seq_len:<10} {str(is_causal):<8} {mask_type:<10} {method:<10} " 925 | f"{r.time_ms:<12.3f} {r.memory_allocated_mb:<12.2f} " 926 | f"{flops_str:<12} {error_str:<12} {grad_check:<10}" 927 | ) 928 | print() 929 | 930 | 931 | def print_mask_comparison_table(results: list[BenchmarkResult]) -> None: 932 | """Print a comparison table showing the impact of different mask types. 933 | 934 | Args: 935 | results: The list of benchmark results to analyze. 936 | """ 937 | print("\n" + "=" * 80) 938 | print("MASK TYPE PERFORMANCE COMPARISON") 939 | print("=" * 80) 940 | 941 | # Group by seq_len, is_causal, and method 942 | from collections import defaultdict 943 | 944 | grouped = defaultdict(lambda: defaultdict(dict)) 945 | 946 | for r in results: 947 | grouped[(r.seq_len, r.is_causal, r.method)][r.mask_type] = r 948 | 949 | print( 950 | f"{'Seq Len':<10} {'Causal':<8} {'Method':<10} " 951 | f"{'No Mask':<15} {'Boolean Mask':<15} {'Additive Mask':<15}" 952 | ) 953 | print("-" * 80) 954 | 955 | for (seq_len, is_causal, method), mask_results in sorted(grouped.items()): 956 | if method == "jvp_attn": # Only show JVP results for clarity 957 | none_time = mask_results.get("none", None) 958 | bool_time = mask_results.get("boolean", None) 959 | add_time = mask_results.get("additive", None) 960 | 961 | none_str = f"{none_time.time_ms:.2f} ms" if none_time else "N/A" 962 | bool_str = f"{bool_time.time_ms:.2f} ms" if bool_time else "N/A" 963 | add_str = f"{add_time.time_ms:.2f} ms" if add_time else "N/A" 964 | 965 | # Add relative performance 966 | if none_time and bool_time: 967 | bool_str += f" ({bool_time.time_ms/none_time.time_ms:.2f}x)" 968 | if none_time and add_time: 969 | add_str += f" ({add_time.time_ms/none_time.time_ms:.2f}x)" 970 | 971 | print( 972 | f"{seq_len:<10} {str(is_causal):<8} {method:<10} " 973 | f"{none_str:<15} {bool_str:<15} {add_str:<15}" 974 | ) 975 | 976 | 977 | def plot_benchmark_results( 978 | results: list[BenchmarkResult], args: Args, verbose: bool = False 979 | ) -> None: 980 | """Generate, save, and display plots summarizing benchmark results. 981 | 982 | Args: 983 | results: The list of benchmark results to plot. 984 | args: The command-line arguments, used for filename generation. 985 | verbose: Whether to print verbose output. 986 | """ 987 | if not PLOTTING_AVAILABLE: 988 | print("\nmatplotlib and/or numpy not found. 
Skipping plotting.") 989 | return 990 | 991 | from collections import defaultdict 992 | 993 | # Group results by (is_causal, mask_type) for creating subplots 994 | grouped = defaultdict(list) 995 | for r in results: 996 | key = (r.is_causal, r.mask_type) 997 | grouped[key].append(r) 998 | 999 | num_configs = len(grouped) 1000 | if num_configs == 0: 1001 | return 1002 | 1003 | # --- 1. Create and Populate the Performance Figure --- 1004 | fig_perf, axes_perf = plt.subplots(1, num_configs, figsize=(6 * num_configs, 5), squeeze=False) 1005 | fig_perf.suptitle("Performance Speedup (JVP Attention vs. SDPA)", fontsize=16) 1006 | 1007 | # --- 2. Create and Populate the Memory Figure --- 1008 | fig_mem, axes_mem = plt.subplots(1, num_configs, figsize=(6 * num_configs, 5), squeeze=False) 1009 | fig_mem.suptitle("Peak Allocated Memory Comparison", fontsize=16) 1010 | 1011 | # Iterate through data to draw on the axes of BOTH figures 1012 | for i, ((is_causal, mask_type), config_results) in enumerate(grouped.items()): 1013 | config_results.sort(key=lambda r: r.seq_len) 1014 | 1015 | # Fixed: Use the correct method names "jvp_attn" and "sdpa" 1016 | jvp_results = [r for r in config_results if r.method == "jvp_attn"] 1017 | sdpa_results = [r for r in config_results if r.method == "sdpa"] 1018 | 1019 | if not jvp_results or not sdpa_results: 1020 | # Debug print to help identify issues 1021 | if verbose: 1022 | print( 1023 | f"Warning: Missing results for config (causal={is_causal}, mask={mask_type})" 1024 | ) 1025 | print(f" Available methods: {set(r.method for r in config_results)}") 1026 | continue 1027 | 1028 | seq_lens = [r.seq_len for r in jvp_results] 1029 | jvp_times = np.array([r.time_ms for r in jvp_results]) 1030 | sdpa_times = np.array([r.time_ms for r in sdpa_results]) 1031 | jvp_mems = np.array([r.memory_allocated_mb for r in jvp_results]) 1032 | sdpa_mems = np.array([r.memory_allocated_mb for r in sdpa_results]) 1033 | speedup = sdpa_times / jvp_times 1034 | x = np.arange(len(seq_lens)) 1035 | 1036 | # Draw on the performance subplot 1037 | ax_perf = axes_perf[0, i] 1038 | bar_perf = ax_perf.bar(x, speedup, width=0.5, color="g") 1039 | ax_perf.bar_label(bar_perf, fmt=lambda val: f"{val:.2f}x") 1040 | ax_perf.axhline(1.0, color="grey", linestyle="--") 1041 | ax_perf.set( 1042 | ylabel="Speedup (SDPA Time / JVP Time)", 1043 | xlabel="Sequence Length", 1044 | title=f"Causal={is_causal}, Mask={mask_type}", 1045 | xticks=x, 1046 | xticklabels=seq_lens, 1047 | ylim=(0, max(1.1, np.max(speedup) * 1.15)), 1048 | ) 1049 | 1050 | # Draw on the memory subplot 1051 | ax_mem = axes_mem[0, i] 1052 | width = 0.35 1053 | rects1 = ax_mem.bar(x - width / 2, sdpa_mems, width, label="PyTorch SDPA") 1054 | rects2 = ax_mem.bar(x + width / 2, jvp_mems, width, label="JVP Attention") 1055 | ax_mem.bar_label(rects1, padding=3, fmt=lambda val: f"{val:.1f}") 1056 | ax_mem.bar_label(rects2, padding=3, fmt=lambda val: f"{val:.1f}") 1057 | ax_mem.set( 1058 | ylabel="Peak Allocated Memory (MB)", 1059 | xlabel="Sequence Length", 1060 | title=f"Causal={is_causal}, Mask={mask_type}", 1061 | xticks=x, 1062 | xticklabels=seq_lens, 1063 | ylim=(0, max(np.max(sdpa_mems), np.max(jvp_mems)) * 1.25), 1064 | ) 1065 | ax_mem.legend() 1066 | 1067 | # --- 3. 
Finalize and Save Each Figure Individually --- 1068 | plot_dir = "tests" 1069 | os.makedirs(plot_dir, exist_ok=True) 1070 | mask_suffix = "_with_masks" if args.test_masks else "" 1071 | 1072 | # Finalize and save the performance plot 1073 | perf_plot_path = os.path.join(plot_dir, f"{args.dtype}_jvp_attention_perf{mask_suffix}.png") 1074 | fig_perf.tight_layout(rect=[0, 0.03, 1, 0.95]) 1075 | fig_perf.savefig(perf_plot_path, dpi=150) 1076 | if verbose: 1077 | print(f"Saved performance plot to {perf_plot_path}") 1078 | 1079 | # Finalize and save the memory plot 1080 | mem_plot_path = os.path.join(plot_dir, f"{args.dtype}_jvp_attention_mem{mask_suffix}.png") 1081 | fig_mem.tight_layout(rect=[0, 0.03, 1, 0.95]) 1082 | fig_mem.savefig(mem_plot_path, dpi=150) 1083 | if verbose: 1084 | print(f"Saved memory plot to {mem_plot_path}") 1085 | 1086 | # --- 4. Show and Close --- 1087 | plt.show() 1088 | 1089 | # Explicitly close figures to free memory 1090 | plt.close(fig_perf) 1091 | plt.close(fig_mem) 1092 | 1093 | 1094 | def main(args: Args) -> None: 1095 | """Main benchmarking loop.""" 1096 | print("Flash Attention JVP Kernel Benchmark with Mask Testing") 1097 | print( 1098 | f"Configuration: bsz={args.bsz}, model_dim={args.model_dim}, " 1099 | f"head_dim={args.head_dim}, dtype={args.dtype}" 1100 | ) 1101 | print(f"Mask testing: {'Enabled' if args.test_masks else 'Disabled'}") 1102 | print(f"Gradient validation: {'Enabled' if args.validate_gradients else 'Disabled'}") 1103 | print(f"Performance benchmarking: {'Enabled' if args.benchmark_performance else 'Disabled'}") 1104 | if args.test_masks: 1105 | print(f"Mask probability: {args.mask_prob}") 1106 | 1107 | # Seed everything 1108 | random.seed(args.seed) 1109 | if PLOTTING_AVAILABLE: 1110 | np.random.seed(args.seed) 1111 | torch.manual_seed(args.seed) 1112 | torch.cuda.manual_seed_all(args.seed) 1113 | 1114 | results = run_benchmark_suite(args) 1115 | 1116 | # Print summary tables 1117 | print_summary_table(results) 1118 | 1119 | # If performance was benchmarked, plot benchmarking results 1120 | if args.benchmark_performance: 1121 | plot_benchmark_results(results, args) 1122 | 1123 | # If masks were tested, print comparison table 1124 | if args.test_masks: 1125 | print_mask_comparison_table(results) 1126 | 1127 | # Print statistics 1128 | print("\n" + "=" * 60) 1129 | print("STATISTICS") 1130 | print("=" * 60) 1131 | 1132 | # Calculate average speedup 1133 | speedups = [] 1134 | for i in range(0, len(results), 2): # Assuming pairs of sdpa/jvp_attn 1135 | if i + 1 < len(results): 1136 | sdpa_result = results[i] 1137 | jvp_result = results[i + 1] 1138 | if sdpa_result.method == "sdpa" and jvp_result.method == "jvp_attn": 1139 | speedup = sdpa_result.time_ms / jvp_result.time_ms 1140 | speedups.append(speedup) 1141 | 1142 | if speedups: 1143 | avg_speedup = sum(speedups) / len(speedups) 1144 | min_speedup = min(speedups) 1145 | max_speedup = max(speedups) 1146 | print(f"Average speedup: {avg_speedup:.2f}x") 1147 | print(f"Min speedup: {min_speedup:.2f}x") 1148 | print(f"Max speedup: {max_speedup:.2f}x") 1149 | 1150 | # Calculate accuracy statistics 1151 | accuracy_results = [r for r in results if r.accuracy is not None] 1152 | if accuracy_results: 1153 | all_accurate = all(r.accuracy.is_accurate() for r in accuracy_results) 1154 | num_accurate = sum(1 for r in accuracy_results if r.accuracy.is_accurate()) 1155 | print(f"\nAccuracy: {num_accurate}/{len(accuracy_results)} tests passed") 1156 | if all_accurate: 1157 | print("✓ All accuracy checks 
passed!") 1158 | else: 1159 | print("⚠️ Some accuracy checks failed") 1160 | 1161 | # Show which configurations failed 1162 | failed_configs = [ 1163 | (r.seq_len, r.is_causal, r.mask_type) 1164 | for r in accuracy_results 1165 | if not r.accuracy.is_accurate() 1166 | ] 1167 | if failed_configs: 1168 | print("\nFailed configurations:") 1169 | for seq_len, is_causal, mask_type in failed_configs: 1170 | print(f" - Seq={seq_len}, Causal={is_causal}, Mask={mask_type}") 1171 | 1172 | # Save results to file 1173 | import json 1174 | 1175 | # Convert results to JSON-serializable format 1176 | results_data = [] 1177 | for r in results: 1178 | result_dict = r._asdict() 1179 | if r.accuracy: 1180 | result_dict["accuracy"] = { 1181 | "primal_error": r.accuracy.primal_error, 1182 | "tangent_error": r.accuracy.tangent_error, 1183 | "loss_error": r.accuracy.loss_error, 1184 | "q_grad_error": r.accuracy.q_grad_error, 1185 | "k_grad_error": r.accuracy.k_grad_error, 1186 | "v_grad_error": r.accuracy.v_grad_error, 1187 | "max_error": r.accuracy.max_error, 1188 | "is_accurate": r.accuracy.is_accurate(), 1189 | } 1190 | results_data.append(result_dict) 1191 | 1192 | # Include configuration in filename 1193 | mask_suffix = "_with_masks" if args.test_masks else "" 1194 | output_filepath = os.path.join( 1195 | "tests", f"{args.dtype}_test_jvp_attention_results{mask_suffix}.json" 1196 | ) 1197 | os.makedirs(os.path.dirname(output_filepath), exist_ok=True) 1198 | 1199 | # Save both results and configuration 1200 | output_data = { 1201 | "configuration": { 1202 | "bsz": args.bsz, 1203 | "model_dim": args.model_dim, 1204 | "head_dim": args.head_dim, 1205 | "seq_lengths": args.seq_lengths, 1206 | "dtype": args.dtype, 1207 | "seed": args.seed, 1208 | "test_masks": args.test_masks, 1209 | "validate_gradients": args.validate_gradients, 1210 | "benchmark_performance": args.benchmark_performance, 1211 | "mask_prob": args.mask_prob if args.test_masks else None, 1212 | }, 1213 | "results": results_data, 1214 | "summary": { 1215 | "avg_speedup": avg_speedup if speedups else None, 1216 | "min_speedup": min_speedup if speedups else None, 1217 | "max_speedup": max_speedup if speedups else None, 1218 | "accuracy_rate": ( 1219 | f"{num_accurate}/{len(accuracy_results)}" if accuracy_results else "N/A" 1220 | ), 1221 | }, 1222 | } 1223 | 1224 | with open(output_filepath, "w") as f: 1225 | json.dump(output_data, f, indent=2, default=str) 1226 | 1227 | print(f"\nResults saved to {output_filepath}") 1228 | 1229 | 1230 | if __name__ == "__main__": 1231 | parser = Args.get_parser() 1232 | namespace = parser.parse_args() 1233 | args = Args.from_namespace(namespace) 1234 | main(args) 1235 | -------------------------------------------------------------------------------- /jvp_flash_attention/jvp_attention.py: -------------------------------------------------------------------------------- 1 | """ 2 | Fused Attention 3 | =============== 4 | 5 | This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf) 6 | 7 | Credits: OpenAI kernel team 8 | Extra Credits: 9 | 10 | * Original flash attention paper (https://arxiv.org/abs/2205.14135) 11 | * Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf) 12 | 13 | Plus modifications to support Jacobian-vector products (JVPs) and Hessian-vector products (HVPs): 14 | - Formulation of flash JVP, by Cheng Lu and Yang Song in https://arxiv.org/abs/2410.11081. 15 | - Reference Triton implementation, by Sofian Mejjoute. 
16 | - Reimplementing reference implementation as an autograd function with latest Triton tutorial optimizations, by Alex Birch. 17 | - Support for forward to receive tangents, so as to compute fwd and jvp together; autograd workaround, by Emily (nshepperd). 18 | - Support for function transforms (e.g., torch.func.jvp) via the use of setup_context, by Shih-Ying Yeh. 19 | - Support for sequence lengths 32 & 64; float32 & bfloat16 precision; comprehensive, length and dtype-stratified unit tests; 20 | working backward hook w.r.t. tensor contiguity; HVP stress testing; standardized docstrings/packaging; and masking/dropout, by Alex Morehead. 21 | """ 22 | 23 | from __future__ import annotations 24 | 25 | import os 26 | from typing import Any, Literal, NamedTuple 27 | 28 | import torch 29 | import torch.autograd.forward_ad as fwAD 30 | import triton 31 | import triton.language as tl 32 | from torch import Tensor 33 | from torch.autograd import Function 34 | from torch.autograd.function import FunctionCtx 35 | 36 | # NOTE: Uncomment to turn warnings into errors for debugging 37 | # import warnings 38 | # warnings.filterwarnings("error", category=UserWarning) 39 | # warnings.filterwarnings("error", category=RuntimeWarning) 40 | 41 | try: 42 | from triton.tools.tensor_descriptor import TensorDescriptor 43 | 44 | HAS_TENSOR_DESC = True 45 | except ModuleNotFoundError: 46 | HAS_TENSOR_DESC = False 47 | 48 | MASK_CONST = ( 49 | -1.0e2 50 | ) # Use a large negative value for masking (compatible with float16, bfloat16, and float32) 51 | MIN_SEQUENCE_LENGTH = 32 # NOTE: All sequence lengths must be multiples of 2 >= 32 52 | 53 | 54 | def is_hip(): 55 | """Check if the current device is HIP.""" 56 | try: 57 | return triton.runtime.driver.active.get_current_target().backend == "hip" 58 | except Exception: 59 | return False 60 | 61 | 62 | def is_cuda(): 63 | """Check if the current device is CUDA.""" 64 | try: 65 | return triton.runtime.driver.active.get_current_target().backend == "cuda" 66 | except Exception: 67 | return False 68 | 69 | 70 | def supports_host_descriptor(): 71 | """Check if the current device supports host tensor descriptors.""" 72 | try: 73 | return is_cuda() and torch.cuda.get_device_capability()[0] >= 9 74 | except Exception: 75 | return False 76 | 77 | 78 | def supports_tma(): 79 | """Check if the current device supports Tensor Memory Access (TMA).""" 80 | try: 81 | return HAS_TENSOR_DESC and is_cuda() and torch.cuda.get_device_capability()[0] >= 9 82 | except Exception: 83 | return False 84 | 85 | 86 | def is_blackwell(): 87 | """Check if the current device is Blackwell architecture.""" 88 | try: 89 | return is_cuda() and torch.cuda.get_device_capability()[0] == 10 90 | except Exception: 91 | return False 92 | 93 | 94 | @triton.jit 95 | def create_dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): 96 | """Generate dropout mask using Philox RNG. 97 | 98 | Args: 99 | philox_seed: Seed for Philox RNG. 100 | philox_offset: Offset for Philox RNG. 101 | dropout_p: Dropout probability. 102 | m: Number of rows. 103 | n: Number of columns. 104 | stride: Stride for the output mask. 105 | 106 | Returns: 107 | dropout_mask: A boolean mask indicating which elements to keep (1.0) or drop (0.0). 108 | dropout_scale: Scale factor to apply after dropout. 
109 | """ 110 | ms = tl.arange(0, m) 111 | ns = tl.arange(0, n) 112 | offs = ms[:, None] * stride + ns[None, :] 113 | rng_offs = philox_offset + offs 114 | 115 | # Generate random values using Philox 116 | rand_vals = tl.rand(philox_seed, rng_offs) 117 | 118 | # Create dropout mask (1.0 = keep, 0.0 = drop) 119 | dropout_mask = rand_vals > dropout_p 120 | dropout_scale = 1.0 / (1.0 - dropout_p) if dropout_p < 1.0 else 0.0 121 | 122 | return dropout_mask, dropout_scale 123 | 124 | 125 | @triton.jit 126 | def _attn_fwd_inner( 127 | acc, 128 | g_acc, # 129 | l_i, 130 | m_i, # 131 | mu_i, 132 | p_tv_acc, # 133 | q, 134 | t_q, # 135 | K_block_ptr, 136 | V_block_ptr, # 137 | T_K_block_ptr, 138 | T_V_block_ptr, # 139 | # Mask and dropout parameters 140 | mask_block_ptr, 141 | dropout_p, 142 | philox_seed, 143 | philox_offset_base, 144 | # Other parameters 145 | dtype: tl.constexpr, 146 | start_m, 147 | qk_scale, 148 | sm_scale, # 149 | BLOCK_M: tl.constexpr, 150 | HEAD_DIM: tl.constexpr, 151 | BLOCK_N: tl.constexpr, # 152 | STAGE: tl.constexpr, 153 | offs_m: tl.constexpr, 154 | offs_n: tl.constexpr, # 155 | N_CTX: tl.constexpr, 156 | warp_specialize: tl.constexpr, # 157 | ENABLE_JVP: tl.constexpr, 158 | ENABLE_DROPOUT: tl.constexpr, 159 | MASK_TYPE: tl.constexpr, # 0: no mask, 1: boolean, 2: additive 160 | MASK_CONST: tl.constexpr = MASK_CONST, 161 | ): 162 | """Inner forward pass for attention mechanism. 163 | 164 | Args: 165 | acc: Accumulator tensor. 166 | g_acc: Gradient accumulator tensor. 167 | l_i: Tensor for storing intermediate results. 168 | m_i: Tensor for storing intermediate results. 169 | mu_i: Tensor for storing intermediate results. 170 | p_tv_acc: Tensor for storing intermediate results. 171 | q: Query tensor. 172 | t_q: Tangent of the query tensor. 173 | K_block_ptr: Pointer to the key block. 174 | V_block_ptr: Pointer to the value block. 175 | T_K_block_ptr: Pointer to the tangent key block. 176 | T_V_block_ptr: Pointer to the tangent value block. 177 | mask_block_ptr: Pointer to the attention mask block. 178 | dropout_p: Dropout probability. 179 | philox_seed: Seed for Philox RNG. 180 | philox_offset_base: Base offset for Philox RNG. 181 | dtype: Data type of the tensors. 182 | start_m: Starting index for the current block. 183 | qk_scale: Scale factor for the query-key dot product. 184 | sm_scale: Scale factor for the softmax. 185 | BLOCK_M: Block size for the M dimension. 186 | HEAD_DIM: Dimension of the attention heads. 187 | BLOCK_N: Block size for the N dimension. 188 | STAGE: Current stage of the computation. 189 | offs_m: Offsets for the M dimension. 190 | offs_n: Offsets for the N dimension. 191 | N_CTX: Number of context tokens. 192 | warp_specialize: Whether to apply warp specialization. 193 | ENABLE_JVP: Whether to enable JVP (Jacobian-vector product). 194 | ENABLE_DROPOUT: Whether to enable dropout. 195 | MASK_TYPE: Type of attention mask (0: no mask, 1: boolean, 2: additive). 196 | MASK_CONST: Constant value used for masking. 197 | 198 | Returns: 199 | The output tensors as a tuple. 
200 | """ 201 | # Range of values handled by this stage 202 | if STAGE == 1: 203 | # NOTE: From 0 to the left of the diagonal 204 | lo, hi = 0, start_m * BLOCK_M 205 | elif STAGE == 2: 206 | # NOTE: Used only for the block in which there is transition between non-masked and masked keys 207 | lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M 208 | lo = tl.multiple_of(lo, BLOCK_M) 209 | else: 210 | # NOTE: Only used for non-causal attention 211 | lo, hi = 0, N_CTX 212 | 213 | K_block_ptr = tl.advance(K_block_ptr, (0, lo)) 214 | # NOTE: In fp8 mode, we may want to advance the V_block_ptr differently. 215 | # I did try advancing by (0, lo) instead for fp8, but I got an illegal memory access. 216 | # https://github.com/triton-lang/triton/commit/75d27b0b425329bad8c13b9cd47177d93590ec31 217 | V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) 218 | if ENABLE_JVP: 219 | T_K_block_ptr = tl.advance(T_K_block_ptr, (0, lo)) 220 | T_V_block_ptr = tl.advance(T_V_block_ptr, (lo, 0)) 221 | 222 | if MASK_TYPE > 0: 223 | mask_block_ptr = tl.advance(mask_block_ptr, (0, lo)) 224 | 225 | # Loop over k, v and update accumulator 226 | for start_n in range(lo, hi, BLOCK_N): 227 | # Let the compiler know that start_n is a multiple 228 | # of BLOCK_N, so the compiler can do optimizations 229 | start_n = tl.multiple_of(start_n, BLOCK_N) 230 | 231 | # -- Compute qk -- 232 | k = tl.load(K_block_ptr) 233 | qk = tl.dot(q, k) 234 | if ENABLE_JVP: 235 | t_k = tl.load(T_K_block_ptr) 236 | t_qk = tl.dot(t_q, k) + tl.dot(q, t_k) 237 | 238 | # Load and apply attention mask if provided (before scaling for STAGE != 2) 239 | if MASK_TYPE > 0: 240 | mask = tl.load(mask_block_ptr) 241 | if MASK_TYPE == 1: # Boolean mask 242 | # Convert boolean to additive mask: True (attend) -> 0, False (ignore) -> -inf 243 | qk = qk + tl.where(mask == 1, 0.0, MASK_CONST) 244 | if ENABLE_JVP: 245 | t_qk = tl.where(mask == 1, t_qk, 0.0) 246 | 247 | elif MASK_TYPE == 2: # Additive mask 248 | qk = qk + mask 249 | 250 | m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) 251 | qk = qk * qk_scale - m_ij[:, None] 252 | 253 | # For causal attention (STAGE == 2) 254 | elif STAGE == 2: 255 | mask = offs_m[:, None] >= (start_n + offs_n[None, :]) 256 | qk = qk * qk_scale + tl.where(mask, 0.0, MASK_CONST) 257 | m_ij = tl.maximum(m_i, tl.max(qk, 1)) 258 | qk -= m_ij[:, None] 259 | 260 | # No masking case (MASK_TYPE == 0 and STAGE != 2) 261 | else: 262 | m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) 263 | qk = qk * qk_scale - m_ij[:, None] 264 | 265 | p = tl.math.exp2(qk) 266 | 267 | if MASK_TYPE == 1 or STAGE == 2: 268 | # Account for fully masked sequence blocks 269 | p = tl.where(mask == 1, p, 0.0) 270 | 271 | # Apply dropout if enabled 272 | if ENABLE_DROPOUT: 273 | philox_offset = philox_offset_base + start_m * N_CTX + start_n 274 | dropout_mask, dropout_scale = create_dropout_mask( 275 | philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, N_CTX 276 | ) 277 | p = p * dropout_mask.to(dtype) * dropout_scale 278 | 279 | l_ij = tl.sum(p, 1) 280 | 281 | # -- Update m_i and l_i -- 282 | alpha = tl.math.exp2(m_i - m_ij) 283 | l_i = l_i * alpha + l_ij 284 | 285 | # -- Update output accumulator -- 286 | if warp_specialize and (BLOCK_M == 128 and HEAD_DIM == 128): 287 | BM: tl.constexpr = acc.shape[0] 288 | BN: tl.constexpr = acc.shape[1] 289 | acc0, acc1 = acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split() 290 | acc0 = acc0 * alpha[:, None] 291 | acc1 = acc1 * alpha[:, None] 292 | acc = tl.join(acc0, acc1).permute(0, 2, 1).reshape([BM, BN]) 293 | else: 294 | 
acc = acc * alpha[:, None] 295 | 296 | v = tl.load(V_block_ptr) 297 | # NOTE: We may need to transpose v if dtype == tl.float8e5 298 | # https://github.com/triton-lang/triton/commit/75d27b0b425329bad8c13b9cd47177d93590ec31 299 | p = p.to(dtype) 300 | 301 | if ENABLE_JVP: 302 | p_tqk = p * (t_qk * sm_scale) 303 | 304 | if warp_specialize and (BLOCK_M == 128 and HEAD_DIM == 128): 305 | BM: tl.constexpr = g_acc.shape[0] 306 | BN: tl.constexpr = g_acc.shape[1] 307 | g_acc0, g_acc1 = g_acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split() 308 | g_acc0 = g_acc0 * alpha[:, None] 309 | g_acc1 = g_acc1 * alpha[:, None] 310 | g_acc = tl.join(g_acc0, g_acc1).permute(0, 2, 1).reshape([BM, BN]) 311 | else: 312 | g_acc = g_acc * alpha[:, None] 313 | 314 | g_acc = tl.dot(p_tqk.to(v.dtype), v, g_acc) 315 | mu_ij = tl.sum(p_tqk, 1) 316 | mu_i = mu_i * alpha + mu_ij 317 | t_v = tl.load(T_V_block_ptr) 318 | p_tv_acc = p_tv_acc * alpha[:, None] + tl.dot(p, t_v.to(dtype)).to(t_v.dtype) 319 | T_V_block_ptr = tl.advance(T_V_block_ptr, (BLOCK_N, 0)) 320 | T_K_block_ptr = tl.advance(T_K_block_ptr, (0, BLOCK_N)) 321 | 322 | acc = tl.dot(p, v.to(dtype), acc).to(acc.dtype) 323 | 324 | # -- Update m_i -- 325 | m_i = m_ij 326 | 327 | # -- Move to the next block of K, V, and maybe the mask -- 328 | # NOTE: The fp8 PR made a change to how K and V are advanced here but I believe we already have that. 329 | # https://github.com/triton-lang/triton/commit/75d27b0b425329bad8c13b9cd47177d93590ec31 330 | V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) 331 | K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) 332 | 333 | if MASK_TYPE > 0: 334 | mask_block_ptr = tl.advance(mask_block_ptr, (0, BLOCK_N)) 335 | 336 | return acc, g_acc, l_i, m_i, mu_i, p_tv_acc 337 | 338 | 339 | @triton.jit 340 | def _attn_fwd_inner_tma( 341 | acc, 342 | g_acc, # 343 | l_i, 344 | m_i, # 345 | mu_i, 346 | p_tv_acc, # 347 | q, 348 | t_q, # 349 | desc_k, 350 | desc_v, # 351 | desc_k_t, 352 | desc_v_t, # 353 | offset_y, 354 | # Mask and dropout parameters 355 | mask_block_ptr, 356 | dropout_p, 357 | philox_seed, 358 | philox_offset_base, 359 | # Other parameters 360 | dtype: tl.constexpr, 361 | start_m, 362 | qk_scale, 363 | sm_scale, # 364 | BLOCK_M: tl.constexpr, 365 | HEAD_DIM: tl.constexpr, 366 | BLOCK_N: tl.constexpr, # 367 | STAGE: tl.constexpr, 368 | offs_m: tl.constexpr, 369 | offs_n: tl.constexpr, # 370 | N_CTX: tl.constexpr, 371 | warp_specialize: tl.constexpr, 372 | ENABLE_JVP: tl.constexpr, 373 | ENABLE_DROPOUT: tl.constexpr, 374 | MASK_TYPE: tl.constexpr, # 0: no mask, 1: boolean, 2: additive 375 | MASK_CONST: tl.constexpr = MASK_CONST, 376 | ): 377 | """Inner forward pass for attention mechanism with TMA (Tensor Memory Access) support. 378 | 379 | Args: 380 | acc: Accumulator tensor. 381 | g_acc: Gradient accumulator tensor. 382 | l_i: Running row-wise sum of the unnormalized attention probabilities (softmax denominator). 383 | m_i: Running row-wise maximum of the attention scores. 384 | mu_i: Running row-wise sum of the tangent-weighted probabilities (JVP statistic). 385 | p_tv_acc: Accumulator for the probabilities times the value tangents. 386 | q: Query tensor. 387 | t_q: Tangent of the query tensor. 388 | desc_k: Descriptor for key tensor. 389 | desc_v: Descriptor for value tensor. 390 | desc_k_t: Descriptor for the tangent key tensor. 391 | desc_v_t: Descriptor for the tangent value tensor. 392 | offset_y: Offset for y dimension. 393 | mask_block_ptr: Pointer to the attention mask block. 394 | dropout_p: Dropout probability. 395 | philox_seed: Seed for Philox RNG. 396 | philox_offset_base: Base offset for Philox RNG. 397 | dtype: Data type. 398 | start_m: Start index for m dimension.
399 | qk_scale: Scale factor for qk. 400 | sm_scale: Scale factor for sm. 401 | BLOCK_M: Block size for m dimension. 402 | HEAD_DIM: Head dimension size. 403 | BLOCK_N: Block size for n dimension. 404 | STAGE: Stage of computation. 405 | offs_m: Offset for m dimension. 406 | offs_n: Offset for n dimension. 407 | N_CTX: Context size. 408 | warp_specialize: Flag for warp specialization. 409 | ENABLE_JVP: Flag for enabling JVP. 410 | ENABLE_DROPOUT: Flag for enabling dropout. 411 | MASK_TYPE: Type of attention mask (0: no mask, 1: boolean, 2: additive). 412 | MASK_CONST: Constant value used for masking. 413 | 414 | Returns: 415 | The output tensors as a tuple. 416 | """ 417 | # Range of values handled by this stage 418 | if STAGE == 1: 419 | # NOTE: From 0 to the left of the diagonal 420 | lo, hi = 0, start_m * BLOCK_M 421 | elif STAGE == 2: 422 | # NOTE: Used only for the block in which there is transition between non-masked and masked keys 423 | lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M 424 | lo = tl.multiple_of(lo, BLOCK_M) 425 | else: 426 | # NOTE: Only used for non-causal attention 427 | lo, hi = 0, N_CTX 428 | 429 | offsetk_y = offset_y + lo 430 | if dtype == tl.float8e5: 431 | offsetv_y = offset_y * HEAD_DIM + lo 432 | else: 433 | offsetv_y = offset_y + lo 434 | 435 | if MASK_TYPE > 0: 436 | mask_block_ptr = tl.advance(mask_block_ptr, (0, lo)) 437 | 438 | # Loop over k, v and update accumulator 439 | for start_n in tl.range(lo, hi, BLOCK_N, warp_specialize=warp_specialize): 440 | # Let the compiler know that start_n is a multiple 441 | # of BLOCK_N, so the compiler can do optimizations 442 | start_n = tl.multiple_of(start_n, BLOCK_N) 443 | 444 | # -- Compute qk ---- 445 | k = desc_k.load([offsetk_y, 0]).T 446 | qk = tl.dot(q, k) 447 | if ENABLE_JVP: 448 | t_k = desc_k_t.load([offsetk_y, 0]).T 449 | t_qk = tl.dot(t_q, k) + tl.dot(q, t_k) 450 | 451 | # Load and apply attention mask if provided (before scaling for STAGE != 2) 452 | if MASK_TYPE > 0: 453 | mask = tl.load(mask_block_ptr) 454 | if MASK_TYPE == 1: # Boolean mask 455 | # Convert boolean to additive mask: True (attend) -> 0, False (ignore) -> -inf 456 | qk = qk + tl.where(mask == 1, 0.0, MASK_CONST) 457 | if ENABLE_JVP: 458 | t_qk = tl.where(mask == 1, t_qk, 0.0) 459 | 460 | elif MASK_TYPE == 2: # Additive mask 461 | qk = qk + mask 462 | 463 | m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) 464 | qk = qk * qk_scale - m_ij[:, None] 465 | 466 | # For causal attention (STAGE == 2) 467 | elif STAGE == 2: 468 | mask = offs_m[:, None] >= (start_n + offs_n[None, :]) 469 | qk = qk * qk_scale + tl.where(mask, 0.0, MASK_CONST) 470 | m_ij = tl.maximum(m_i, tl.max(qk, 1)) 471 | qk -= m_ij[:, None] 472 | 473 | # No masking case (MASK_TYPE == 0 and STAGE != 2) 474 | else: 475 | m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) 476 | qk = qk * qk_scale - m_ij[:, None] 477 | 478 | p = tl.math.exp2(qk) 479 | 480 | if MASK_TYPE == 1 or STAGE == 2: 481 | # Account for fully masked sequence blocks 482 | p = tl.where(mask == 1, p, 0.0) 483 | 484 | # Apply dropout if enabled 485 | if ENABLE_DROPOUT: 486 | philox_offset = philox_offset_base + start_m * N_CTX + start_n 487 | dropout_mask, dropout_scale = create_dropout_mask( 488 | philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, N_CTX 489 | ) 490 | p = p * dropout_mask.to(dtype) * dropout_scale 491 | 492 | # -- Compute correction factor 493 | alpha = tl.math.exp2(m_i - m_ij) 494 | l_ij = tl.sum(p, 1) 495 | 496 | # -- Update output accumulator -- 497 | if warp_specialize and (BLOCK_M 
== 128 and HEAD_DIM == 128): 498 | BM: tl.constexpr = acc.shape[0] 499 | BN: tl.constexpr = acc.shape[1] 500 | acc0, acc1 = acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split() 501 | acc0 = acc0 * alpha[:, None] 502 | acc1 = acc1 * alpha[:, None] 503 | acc = tl.join(acc0, acc1).permute(0, 2, 1).reshape([BM, BN]) 504 | else: 505 | acc = acc * alpha[:, None] 506 | 507 | # Prepare p and v for the dot 508 | if dtype == tl.float8e5: 509 | v = desc_v.load([0, offsetv_y]).T 510 | else: 511 | v = desc_v.load([offsetv_y, 0]) 512 | 513 | p = p.to(dtype) 514 | 515 | if ENABLE_JVP: 516 | p_tqk = p * (t_qk * sm_scale) 517 | 518 | # NOT: This non-transposed v for FP8 is presumably only supported on Blackwell 519 | if warp_specialize and (BLOCK_M == 128 and HEAD_DIM == 128): 520 | BM: tl.constexpr = g_acc.shape[0] 521 | BN: tl.constexpr = g_acc.shape[1] 522 | g_acc0, g_acc1 = g_acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split() 523 | g_acc0 = g_acc0 * alpha[:, None] 524 | g_acc1 = g_acc1 * alpha[:, None] 525 | g_acc = tl.join(g_acc0, g_acc1).permute(0, 2, 1).reshape([BM, BN]) 526 | else: 527 | g_acc = g_acc * alpha[:, None] 528 | 529 | g_acc = tl.dot(p_tqk.to(v.dtype), v, g_acc) 530 | mu_ij = tl.sum(p_tqk, 1) 531 | mu_i = mu_i * alpha + mu_ij 532 | t_v = desc_v_t.load([offsetv_y, 0]) 533 | p_tv_acc = p_tv_acc * alpha[:, None] + tl.dot(p, t_v.to(dtype)).to(t_v.dtype) 534 | 535 | # NOTE: This non transposed v for FP8 is only supported on Blackwell 536 | acc = tl.dot(p, v.to(dtype), acc).to(acc.dtype) 537 | 538 | # Update m_i and l_i 539 | # Place this at the end of the loop to reduce register pressure 540 | l_i = l_i * alpha + l_ij 541 | m_i = m_ij 542 | offsetk_y += BLOCK_N 543 | offsetv_y += BLOCK_N 544 | 545 | if MASK_TYPE > 0: 546 | mask_block_ptr = tl.advance(mask_block_ptr, (0, BLOCK_N)) 547 | 548 | return acc, g_acc, l_i, m_i, mu_i, p_tv_acc 549 | 550 | 551 | def _host_descriptor_pre_hook(nargs): 552 | """Pre-hook to set up tensor descriptors for the attention kernel. 553 | 554 | Args: 555 | nargs: A dictionary of kernel arguments. 556 | """ 557 | BLOCK_M = nargs["BLOCK_M"] 558 | BLOCK_N = nargs["BLOCK_N"] 559 | HEAD_DIM = nargs["HEAD_DIM"] 560 | if not supports_tma() or not isinstance(nargs["desc_q"], TensorDescriptor): 561 | return 562 | nargs["desc_q"].block_shape = [BLOCK_M, HEAD_DIM] 563 | if nargs["FP8_OUTPUT"]: 564 | nargs["desc_v"].block_shape = [HEAD_DIM, BLOCK_N] 565 | else: 566 | nargs["desc_v"].block_shape = [BLOCK_N, HEAD_DIM] 567 | nargs["desc_k"].block_shape = [BLOCK_N, HEAD_DIM] 568 | nargs["desc_o"].block_shape = [BLOCK_M, HEAD_DIM] 569 | 570 | 571 | if is_hip(): 572 | NUM_STAGES_OPTIONS = [1] 573 | elif supports_host_descriptor(): 574 | NUM_STAGES_OPTIONS = [2, 3, 4] 575 | else: 576 | NUM_STAGES_OPTIONS = [2, 3, 4] 577 | 578 | configs = [ 579 | triton.Config( 580 | {"BLOCK_M": BM, "BLOCK_N": BN}, 581 | num_stages=s, 582 | num_warps=w, 583 | pre_hook=_host_descriptor_pre_hook, 584 | ) 585 | for BM in [MIN_SEQUENCE_LENGTH, 64, 128] 586 | for BN in [MIN_SEQUENCE_LENGTH, 64, 128] 587 | for s in NUM_STAGES_OPTIONS 588 | for w in [4, 8] 589 | ] 590 | if "PYTEST_VERSION" in os.environ: 591 | # Use a single config in testing for reproducibility 592 | configs = [ 593 | triton.Config( 594 | dict(BLOCK_M=128, BLOCK_N=64), 595 | num_stages=2, 596 | num_warps=4, 597 | pre_hook=_host_descriptor_pre_hook, 598 | ), 599 | ] 600 | 601 | 602 | def keep(conf): 603 | """Keep configurations that meet certain criteria. 604 | 605 | Args: 606 | conf: A configuration object. 
607 | """ 608 | BLOCK_M = conf.kwargs["BLOCK_M"] 609 | BLOCK_N = conf.kwargs["BLOCK_N"] 610 | return not (BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8) 611 | 612 | 613 | def prune_invalid_configs(configs, named_args, **kwargs): 614 | """Prune configurations that are invalid based on certain criteria. 615 | 616 | Args: 617 | configs: A list of configuration objects. 618 | named_args: A dictionary of named arguments. 619 | **kwargs: Additional keyword arguments. 620 | 621 | Returns: 622 | A list of valid configuration objects. 623 | """ 624 | N_CTX = kwargs["N_CTX"] 625 | 626 | if N_CTX == MIN_SEQUENCE_LENGTH: 627 | # Filter out configs where BLOCK_M > MIN_SEQUENCE_LENGTH 628 | return [conf for conf in configs if conf.kwargs.get("BLOCK_M", 0) <= MIN_SEQUENCE_LENGTH] 629 | 630 | # Filter out configs where BLOCK_M > N_CTX or BLOCK_M <= MIN_SEQUENCE_LENGTH, as 631 | # BLOCK_M = MIN_SEQUENCE_LENGTH often leads to reduced numerical accuracy for longer sequences 632 | # TODO: Find out why this occurs 633 | return [ 634 | conf for conf in configs if MIN_SEQUENCE_LENGTH < conf.kwargs.get("BLOCK_M", 0) <= N_CTX 635 | ] 636 | 637 | 638 | @triton.jit 639 | def _maybe_make_tensor_desc(desc_or_ptr, shape, strides, block_shape): 640 | """Maybe make a tensor descriptor from a pointer. 641 | 642 | Args: 643 | desc_or_ptr: The input tensor or pointer. 644 | shape: The shape of the tensor. 645 | strides: The strides of the tensor. 646 | block_shape: The block shape of the tensor. 647 | 648 | Returns: 649 | A tensor descriptor. 650 | """ 651 | if isinstance(desc_or_ptr, tl.tensor_descriptor): 652 | return desc_or_ptr 653 | else: 654 | return tl.make_tensor_descriptor(desc_or_ptr, shape, strides, block_shape) 655 | 656 | 657 | # @triton.autotune( 658 | # configs=list(filter(keep, configs)), 659 | # key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT", "warp_specialize"], 660 | # prune_configs_by={"early_config_prune": prune_invalid_configs}, 661 | # ) 662 | @triton.jit 663 | def _attn_fwd( 664 | Q, 665 | K, 666 | V, 667 | T_Q, 668 | T_K, 669 | T_V, # 670 | sm_scale, 671 | M, 672 | Out, 673 | T_Out, # 674 | Mask, # Mask tensor 675 | dropout_p, # Dropout probability 676 | philox_seed, # RNG seed for dropout 677 | stride_qz, 678 | stride_qh, 679 | stride_qm, 680 | stride_qk, # 681 | stride_kz, 682 | stride_kh, 683 | stride_kn, 684 | stride_kk, # 685 | stride_vz, 686 | stride_vh, 687 | stride_vk, 688 | stride_vn, # 689 | stride_tqz, 690 | stride_tqh, 691 | stride_tqm, 692 | stride_tqk, # 693 | stride_tkz, 694 | stride_tkh, 695 | stride_tkn, 696 | stride_tkk, # 697 | stride_tvz, 698 | stride_tvh, 699 | stride_tvk, 700 | stride_tvn, # 701 | stride_oz, 702 | stride_oh, 703 | stride_om, 704 | stride_on, # 705 | stride_toz, 706 | stride_toh, 707 | stride_tom, 708 | stride_ton, # 709 | stride_mz, # Mask stride 710 | stride_mh, # Mask stride 711 | stride_mm, # Mask stride 712 | stride_mn, # Mask stride 713 | Z, 714 | H, 715 | N_CTX, # 716 | HEAD_DIM: tl.constexpr, # 717 | BLOCK_M: tl.constexpr, # 718 | BLOCK_N: tl.constexpr, # 719 | FP8_OUTPUT: tl.constexpr, # 720 | STAGE: tl.constexpr, # 721 | warp_specialize: tl.constexpr, # 722 | ENABLE_JVP: tl.constexpr, # 723 | ENABLE_DROPOUT: tl.constexpr, # 724 | MASK_TYPE: tl.constexpr, # 725 | ): 726 | """Forward attention computation. 727 | 728 | Args: 729 | Q: Query tensor. 730 | K: Key tensor. 731 | V: Value tensor. 732 | T_Q: Tensor for query. 733 | T_K: Tensor for key. 734 | T_V: Tensor for value. 735 | sm_scale: Scale factor. 736 | M: Number of rows. 737 | Out: Output tensor. 
738 | T_Out: Output tangent tensor. 739 | Mask: Attention mask tensor. 740 | dropout_p: Dropout probability. 741 | philox_seed: Seed for Philox RNG. 742 | stride_qz: Stride for query z dimension. 743 | stride_qh: Stride for query h dimension. 744 | stride_qm: Stride for query m dimension. 745 | stride_qk: Stride for query k dimension. 746 | stride_kz: Stride for key z dimension. 747 | stride_kh: Stride for key h dimension. 748 | stride_kn: Stride for key n dimension. 749 | stride_kk: Stride for key k dimension. 750 | stride_vz: Stride for value z dimension. 751 | stride_vh: Stride for value h dimension. 752 | stride_vk: Stride for value k dimension. 753 | stride_vn: Stride for value n dimension. 754 | stride_tqz: Stride for tangent query z dimension. 755 | stride_tqh: Stride for tangent query h dimension. 756 | stride_tqm: Stride for tangent query m dimension. 757 | stride_tqk: Stride for tangent query k dimension. 758 | stride_tkz: Stride for tangent key z dimension. 759 | stride_tkh: Stride for tangent key h dimension. 760 | stride_tkn: Stride for tangent key n dimension. 761 | stride_tkk: Stride for tangent key k dimension. 762 | stride_tvz: Stride for tangent value z dimension. 763 | stride_tvh: Stride for tangent value h dimension. 764 | stride_tvk: Stride for tangent value k dimension. 765 | stride_tvn: Stride for tangent value n dimension. 766 | stride_oz: Stride for output z dimension. 767 | stride_oh: Stride for output h dimension. 768 | stride_om: Stride for output m dimension. 769 | stride_on: Stride for output n dimension. 770 | stride_toz: Stride for tangent output z dimension. 771 | stride_toh: Stride for tangent output h dimension. 772 | stride_tom: Stride for tangent output m dimension. 773 | stride_ton: Stride for tangent output n dimension. 774 | stride_mz: Stride for mask z dimension. 775 | stride_mh: Stride for mask h dimension. 776 | stride_mm: Stride for mask m dimension. 777 | stride_mn: Stride for mask n dimension. 778 | Z: Batch size (z dimension). 779 | H: Number of attention heads (h dimension). 780 | N_CTX: Sequence (context) length. 781 | HEAD_DIM: Head dimension. 782 | BLOCK_M: Block size for the queries. 783 | BLOCK_N: Block size for the keys/values. 784 | FP8_OUTPUT: FP8 output flag. 785 | STAGE: Stage of the computation (3 if causal, else 1). 786 | warp_specialize: Warp specialization flag. 787 | ENABLE_JVP: Enable JVP flag. 788 | ENABLE_DROPOUT: Enable dropout flag. 789 | MASK_TYPE: Mask type (0: no mask, 1: boolean, 2: additive).
790 | """ 791 | tl.static_assert(BLOCK_N <= HEAD_DIM) # N = KV 792 | 793 | # Prepare metadata and indices 794 | dtype = tl.float8e5 if FP8_OUTPUT else tl.float32 # For dot products 795 | start_m = tl.program_id(0) # Which block (in the input query sequence) to process 796 | off_hz = tl.program_id( 797 | 1 798 | ) # Which head and batch element to process, with a program being a single head of a single batch element 799 | off_z = ( 800 | off_hz // H 801 | ) # Which batch element this program is assigned to (n.b., each batch element has H heads) 802 | off_h = off_hz % H # The position of the head to process in the batch 803 | 804 | # NOTE: This allows one to get the (N_CTX, HEAD_DIM) block in Q, K, V by indexing it by batch and head 805 | qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh 806 | 807 | # Initialize block pointers 808 | Q_block_ptr = tl.make_block_ptr( 809 | base=Q + qvk_offset, 810 | shape=(N_CTX, HEAD_DIM), 811 | strides=(stride_qm, stride_qk), 812 | offsets=(start_m * BLOCK_M, 0), # M = Q 813 | block_shape=(BLOCK_M, HEAD_DIM), 814 | order=(1, 0), 815 | ) 816 | v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0) 817 | V_block_ptr = tl.make_block_ptr( 818 | base=V + qvk_offset, 819 | shape=(N_CTX, HEAD_DIM), 820 | strides=(stride_vk, stride_vn), 821 | offsets=(0, 0), 822 | block_shape=(BLOCK_N, HEAD_DIM), 823 | order=v_order, 824 | ) 825 | K_block_ptr = tl.make_block_ptr( 826 | base=K + qvk_offset, 827 | shape=(HEAD_DIM, N_CTX), 828 | strides=( 829 | stride_kk, 830 | stride_kn, 831 | ), # NOTE: We invert the strides of K to get its matrix transpose K^T 832 | offsets=(0, 0), 833 | block_shape=(HEAD_DIM, BLOCK_N), 834 | order=(0, 1), 835 | ) 836 | O_block_ptr = tl.make_block_ptr( 837 | base=Out + qvk_offset, 838 | shape=(N_CTX, HEAD_DIM), 839 | strides=(stride_om, stride_on), 840 | offsets=(start_m * BLOCK_M, 0), 841 | block_shape=(BLOCK_M, HEAD_DIM), 842 | order=(1, 0), 843 | ) 844 | 845 | # Initialize block pointer for the mask, if provided 846 | if MASK_TYPE > 0: 847 | mask_offset = off_z.to(tl.int64) * stride_mz + off_h.to(tl.int64) * stride_mh 848 | mask_block_ptr = tl.make_block_ptr( 849 | base=Mask + mask_offset, 850 | shape=(N_CTX, N_CTX), 851 | strides=(stride_mm, stride_mn), 852 | offsets=(start_m * BLOCK_M, 0), 853 | block_shape=(BLOCK_M, BLOCK_N), 854 | order=(1, 0), 855 | ) 856 | else: 857 | mask_block_ptr = None 858 | 859 | # Initialize dropout offset for this block 860 | philox_offset_base = off_hz * N_CTX * N_CTX 861 | 862 | # Initialize offsets for the query tokens to process 863 | offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) 864 | offs_n = tl.arange(0, BLOCK_N) 865 | 866 | # Initialize accumulator pointers: 867 | # m, the running maximum (one for each query) 868 | m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") 869 | # l, the running sum (one for each query as we sum the attention scores by rows) 870 | l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 871 | # acc, the output accumulator (one vector for each query) 872 | acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) 873 | 874 | if ENABLE_JVP: 875 | # NOTE: It's extremely likely we could just reuse qvk_offset, but this seems cheap so whatever 876 | t_qvk_offset = off_z.to(tl.int64) * stride_tqz + off_h.to(tl.int64) * stride_tqh 877 | T_Q_block_ptr = tl.make_block_ptr( 878 | base=T_Q + t_qvk_offset, 879 | shape=(N_CTX, HEAD_DIM), 880 | strides=(stride_tqm, stride_tqk), 881 | offsets=(start_m * BLOCK_M, 0), 882 | block_shape=(BLOCK_M, 
HEAD_DIM), 883 | order=(1, 0), 884 | ) 885 | # NOTE: Could probably just reuse v_order here 886 | t_v_order: tl.constexpr = (0, 1) if T_V.dtype.element_ty == tl.float8e5 else (1, 0) 887 | T_V_block_ptr = tl.make_block_ptr( 888 | base=T_V + t_qvk_offset, 889 | shape=(N_CTX, HEAD_DIM), 890 | strides=(stride_tvk, stride_tvn), 891 | offsets=(0, 0), 892 | block_shape=(BLOCK_N, HEAD_DIM), 893 | order=t_v_order, 894 | ) 895 | T_K_block_ptr = tl.make_block_ptr( 896 | base=T_K + t_qvk_offset, 897 | shape=(HEAD_DIM, N_CTX), 898 | strides=( 899 | stride_tkk, 900 | stride_tkn, 901 | ), # NOTE: We invert the strides of tangent K (k_t) to get its matrix transpose K^T 902 | offsets=(0, 0), 903 | block_shape=(HEAD_DIM, BLOCK_N), 904 | order=(0, 1), 905 | ) 906 | T_O_block_ptr = tl.make_block_ptr( 907 | base=T_Out + t_qvk_offset, 908 | shape=(N_CTX, HEAD_DIM), 909 | strides=(stride_tom, stride_ton), 910 | offsets=(start_m * BLOCK_M, 0), 911 | block_shape=(BLOCK_M, HEAD_DIM), 912 | order=(1, 0), 913 | ) 914 | # Load q_t: It will stay in SRAM throughout. 915 | t_q = tl.load(T_Q_block_ptr) 916 | g_acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) 917 | mu_i = tl.zeros([BLOCK_M], dtype=tl.float32) 918 | p_tv_acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) 919 | else: 920 | t_q = None 921 | T_V_block_ptr = None 922 | T_K_block_ptr = None 923 | # Allocate minimal dummy tensors to keep consistent the return signature of _attn_fwd_inner 924 | g_acc = tl.zeros([1, 1], dtype=tl.float32) 925 | mu_i = tl.zeros([1], dtype=tl.float32) 926 | p_tv_acc = tl.zeros([1, 1], dtype=tl.float32) 927 | 928 | # Prepare scales 929 | qk_scale = sm_scale 930 | qk_scale = qk_scale * 1.44269504 # 1/log(2) 931 | 932 | # Load q: It will stay in SRAM throughout. 933 | q = tl.load(Q_block_ptr) 934 | 935 | # Stage: 3 if causal, else 1 936 | if STAGE == 1 or STAGE == 3: 937 | # NOTE: This step runs for non-causal attention or for the 938 | # blocks to the left of the diagonal for causal attention 939 | acc, g_acc, l_i, m_i, mu_i, p_tv_acc = _attn_fwd_inner( 940 | acc, 941 | g_acc, 942 | l_i, 943 | m_i, # 944 | mu_i, 945 | p_tv_acc, # 946 | q, 947 | t_q, # 948 | K_block_ptr, 949 | V_block_ptr, # 950 | T_K_block_ptr, 951 | T_V_block_ptr, # 952 | mask_block_ptr, 953 | dropout_p, 954 | philox_seed, 955 | philox_offset_base, 956 | dtype, 957 | start_m, 958 | qk_scale, 959 | sm_scale, # 960 | BLOCK_M, 961 | HEAD_DIM, 962 | BLOCK_N, # 963 | 4 - STAGE, 964 | offs_m, 965 | offs_n, 966 | N_CTX, # 967 | warp_specialize, 968 | ENABLE_JVP, 969 | ENABLE_DROPOUT, 970 | MASK_TYPE, 971 | ) 972 | 973 | if STAGE == 3: 974 | # NOTE: This step runs for the blocks to the 975 | # right of the diagonal for causal attention 976 | acc, g_acc, l_i, m_i, mu_i, p_tv_acc = _attn_fwd_inner( 977 | acc, 978 | g_acc, # 979 | l_i, 980 | m_i, # 981 | mu_i, 982 | p_tv_acc, # 983 | q, 984 | t_q, # 985 | K_block_ptr, 986 | V_block_ptr, # 987 | T_K_block_ptr, 988 | T_V_block_ptr, # 989 | mask_block_ptr, 990 | dropout_p, 991 | philox_seed, 992 | philox_offset_base, 993 | dtype, 994 | start_m, 995 | qk_scale, 996 | sm_scale, # 997 | BLOCK_M, 998 | HEAD_DIM, 999 | BLOCK_N, # 1000 | 2, 1001 | offs_m, 1002 | offs_n, 1003 | N_CTX, # 1004 | warp_specialize, 1005 | ENABLE_JVP, 1006 | ENABLE_DROPOUT, 1007 | MASK_TYPE, 1008 | ) 1009 | 1010 | # Epilogue 1011 | empty_mask = l_i == 0.0 1012 | if empty_mask.sum() > 0: 1013 | l_i = tl.where( 1014 | empty_mask, 1.0, l_i 1015 | ) # NOTE: This happens if the entire block is masked out. 
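# NOTE: Ignoring dropout, the epilogue below can be read in softmax terms, with P = softmax(S)
# recovered as p / l_i and dS denoting the tangent of the scaled scores:
#   acc / l_i      = P @ V          (the primal output O after normalization),
#   g_acc / l_i    = (P * dS) @ V,
#   mu_i / l_i     = rowsum(P * dS),
#   p_tv_acc / l_i = P @ dV,
# so the stored tangent is t_y_out = (P * dS) @ V - rowsum(P * dS) * O + P @ dV, which is the
# JVP of softmax(S) @ V. M receives the per-row base-2 logsumexp (m_i + log2(l_i)) for reuse
# in the backward pass.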
1016 | 1017 | m_i = m_i + tl.where( 1018 | # NOTE: This is needed to compute the logsumexp for the backward pass. 1019 | empty_mask, 1020 | 0.0, 1021 | tl.math.log2(l_i), 1022 | ) 1023 | 1024 | acc = acc / l_i[:, None] 1025 | m_ptrs = M + off_hz * N_CTX + offs_m 1026 | tl.store(m_ptrs, m_i) 1027 | tl.store(O_block_ptr, acc.to(Out.type.element_ty)) 1028 | 1029 | # If JVP is enabled, compute and store the output tangent 1030 | if ENABLE_JVP: 1031 | t_p_v = g_acc / l_i[:, None] - (mu_i / l_i)[:, None] * acc 1032 | t_y_out = t_p_v + p_tv_acc / l_i[:, None] 1033 | tl.store(T_O_block_ptr, t_y_out.to(T_Out.type.element_ty)) 1034 | 1035 | 1036 | def _tma_pre_hook(nargs): 1037 | """Pre-hook for TMA (Tensor Memory Access) optimization. 1038 | 1039 | Args: 1040 | nargs: A dictionary containing the kernel arguments. 1041 | """ 1042 | BLOCK_M = nargs["BLOCK_M"] 1043 | BLOCK_N = nargs["BLOCK_N"] 1044 | HEAD_DIM = nargs["HEAD_DIM"] 1045 | nargs["desc_q"].block_shape = [BLOCK_M, HEAD_DIM] 1046 | nargs["desc_v"].block_shape = [BLOCK_N, HEAD_DIM] 1047 | nargs["desc_k"].block_shape = [BLOCK_N, HEAD_DIM] 1048 | nargs["desc_o"].block_shape = [BLOCK_M, HEAD_DIM] 1049 | 1050 | 1051 | # We don't run auto-tuning every time to keep the tutorial fast. Keeping 1052 | # the code below and commenting out the equivalent parameters is convenient for 1053 | # re-tuning. 1054 | configs_tma = [ 1055 | triton.Config( 1056 | {"BLOCK_M": BM, "BLOCK_N": BN}, 1057 | num_stages=s, 1058 | num_warps=w, 1059 | pre_hook=_tma_pre_hook, 1060 | ) 1061 | for BM in [MIN_SEQUENCE_LENGTH, 64, 128, 256] 1062 | for BN in [MIN_SEQUENCE_LENGTH, 64, 128] 1063 | for s in [3, 4, 5] 1064 | for w in [4, 8] 1065 | ] 1066 | 1067 | 1068 | def keep_tma(conf): 1069 | """Check if TMA (Tensor Memory Access) optimization should be kept for the given configuration. 1070 | 1071 | Args: 1072 | conf: The configuration to check. 1073 | """ 1074 | BLOCK_M = conf.kwargs["BLOCK_M"] 1075 | BLOCK_N = conf.kwargs["BLOCK_N"] 1076 | return not ( 1077 | is_cuda() 1078 | and torch.cuda.get_device_capability()[0] == 9 1079 | and BLOCK_M * BLOCK_N < 128 * 128 1080 | and conf.num_warps == 8 1081 | ) 1082 | 1083 | 1084 | # @triton.autotune( 1085 | # configs=list(filter(keep_tma, configs_tma)), 1086 | # key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT", "warp_specialize"], 1087 | # prune_configs_by={"early_config_prune": prune_invalid_configs}, 1088 | # ) 1089 | @triton.jit 1090 | def _attn_fwd_tma( 1091 | sm_scale, 1092 | M, # 1093 | Z, 1094 | H, # 1095 | desc_q, 1096 | desc_k, 1097 | desc_v, # 1098 | desc_q_t, 1099 | desc_k_t, 1100 | desc_v_t, # 1101 | desc_o, 1102 | desc_o_t, # 1103 | Mask, # Mask tensor 1104 | dropout_p, # Dropout probability 1105 | philox_seed, # RNG seed for dropout 1106 | stride_mz, # Mask stride 1107 | stride_mh, # Mask stride 1108 | stride_mm, # Mask stride 1109 | stride_mn, # Mask stride 1110 | N_CTX, # 1111 | HEAD_DIM: tl.constexpr, # 1112 | BLOCK_M: tl.constexpr, # 1113 | BLOCK_N: tl.constexpr, # 1114 | FP8_OUTPUT: tl.constexpr, # 1115 | STAGE: tl.constexpr, # 1116 | warp_specialize: tl.constexpr, # 1117 | ENABLE_JVP: tl.constexpr, # 1118 | ENABLE_DROPOUT: tl.constexpr, # 1119 | MASK_TYPE: tl.constexpr, # 1120 | ): 1121 | """Forward attention computation with TMA (Tensor Memory Access) support. 1122 | 1123 | Args: 1124 | sm_scale: Scale factor for the softmax. 1125 | M: Number of rows in the input. 1126 | Z: Number of channels in the input. 1127 | H: Number of heads in the multi-head attention. 1128 | desc_q: Descriptor for the query tensor. 
1129 | desc_k: Descriptor for the key tensor. 1130 | desc_v: Descriptor for the value tensor. 1131 | desc_q_t: Descriptor for the transposed query tensor. 1132 | desc_k_t: Descriptor for the transposed key tensor. 1133 | desc_v_t: Descriptor for the transposed value tensor. 1134 | desc_o: Descriptor for the output tensor. 1135 | desc_o_t: Descriptor for the transposed output tensor. 1136 | Mask: Attention mask tensor. 1137 | dropout_p: Dropout probability. 1138 | philox_seed: Seed for the Philox random number generator. 1139 | stride_mz: Stride for the mask in the z dimension. 1140 | stride_mh: Stride for the mask in the h dimension. 1141 | stride_mm: Stride for the mask in the m dimension. 1142 | stride_mn: Stride for the mask in the n dimension. 1143 | N_CTX: Context length. 1144 | HEAD_DIM: Dimension of each head. 1145 | BLOCK_M: Block size for the queries. 1146 | BLOCK_N: Block size for the keys/values. 1147 | FP8_OUTPUT: Flag indicating if FP8 output is used. 1148 | STAGE: Stage of the computation. 1149 | warp_specialize: Flag indicating if warp specialization is used. 1150 | ENABLE_JVP: Flag indicating if JVP (Jacobian-vector product) is enabled. 1151 | ENABLE_DROPOUT: Flag indicating if dropout is enabled. 1152 | MASK_TYPE: Type of mask used (0: no mask, 1: boolean, 2: additive). 1153 | """ 1154 | tl.static_assert(BLOCK_N <= HEAD_DIM) # N = KV 1155 | 1156 | # Prepare metadata and indices 1157 | dtype = tl.float8e5 if FP8_OUTPUT else tl.float32 # For dot products 1158 | start_m = tl.program_id(0) # Which block (in the input query sequence) to process 1159 | off_hz = tl.program_id( 1160 | 1 1161 | ) # Which head and batch element to process, with a program being a single head of a single batch element 1162 | off_z = ( 1163 | off_hz // H 1164 | ) # Which batch element this program is assigned to (n.b., each batch element has H heads) 1165 | off_h = off_hz % H # The position of the head to process in the batch 1166 | 1167 | # Initialize tensor descriptors 1168 | y_dim = Z * H * N_CTX 1169 | desc_q = _maybe_make_tensor_desc( 1170 | desc_q, 1171 | shape=[y_dim, HEAD_DIM], 1172 | strides=[HEAD_DIM, 1], 1173 | block_shape=[BLOCK_M, HEAD_DIM], # M = Q 1174 | ) 1175 | if FP8_OUTPUT: 1176 | v_shape = [HEAD_DIM, y_dim] 1177 | v_strides = [N_CTX, 1] 1178 | v_block_shape = [HEAD_DIM, BLOCK_N] 1179 | else: 1180 | v_shape = [y_dim, HEAD_DIM] 1181 | v_strides = [HEAD_DIM, 1] 1182 | v_block_shape = [BLOCK_N, HEAD_DIM] 1183 | desc_v = _maybe_make_tensor_desc( 1184 | desc_v, shape=v_shape, strides=v_strides, block_shape=v_block_shape 1185 | ) 1186 | desc_k = _maybe_make_tensor_desc( 1187 | desc_k, 1188 | shape=[y_dim, HEAD_DIM], 1189 | strides=[HEAD_DIM, 1], 1190 | block_shape=[BLOCK_N, HEAD_DIM], 1191 | ) 1192 | desc_o = _maybe_make_tensor_desc( 1193 | desc_o, 1194 | shape=[y_dim, HEAD_DIM], 1195 | strides=[HEAD_DIM, 1], 1196 | block_shape=[BLOCK_M, HEAD_DIM], 1197 | ) 1198 | 1199 | offset_y = off_z * (N_CTX * H) + off_h * N_CTX 1200 | qo_offset_y = offset_y + start_m * BLOCK_M 1201 | 1202 | # Initialize block pointer for the mask, if provided 1203 | if MASK_TYPE > 0: 1204 | mask_offset = off_z * stride_mz + off_h * stride_mh 1205 | mask_block_ptr = tl.make_block_ptr( 1206 | base=Mask + mask_offset, 1207 | shape=(N_CTX, N_CTX), 1208 | strides=(stride_mm, stride_mn), 1209 | offsets=(start_m * BLOCK_M, 0), 1210 | block_shape=(BLOCK_M, BLOCK_N), 1211 | order=(1, 0), 1212 | ) 1213 | else: 1214 | mask_block_ptr = None 1215 | 1216 | # Initialize dropout offset for this block 1217 | philox_offset_base = 
off_hz * N_CTX * N_CTX 1218 | 1219 | # Initialize offsets for the query tokens to process 1220 | offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) 1221 | offs_n = tl.arange(0, BLOCK_N) 1222 | 1223 | # Initialize accumulator pointers: 1224 | # m, the running maximum (one for each query) 1225 | m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") 1226 | # l, the running sum (one for each query as we sum the attention scores by rows) 1227 | l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 1228 | # acc, the output accumulator (one vector for each query) 1229 | acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) 1230 | 1231 | if ENABLE_JVP: 1232 | desc_q_t = _maybe_make_tensor_desc( 1233 | desc_q_t, 1234 | shape=[y_dim, HEAD_DIM], 1235 | strides=[HEAD_DIM, 1], 1236 | block_shape=[BLOCK_M, HEAD_DIM], 1237 | ) 1238 | if FP8_OUTPUT: 1239 | t_v_shape = [HEAD_DIM, y_dim] 1240 | t_v_strides = [N_CTX, 1] 1241 | t_v_block_shape = [HEAD_DIM, BLOCK_N] 1242 | else: 1243 | t_v_shape = [y_dim, HEAD_DIM] 1244 | t_v_strides = [HEAD_DIM, 1] 1245 | t_v_block_shape = [BLOCK_N, HEAD_DIM] 1246 | desc_v_t = _maybe_make_tensor_desc( 1247 | desc_v_t, shape=t_v_shape, strides=t_v_strides, block_shape=t_v_block_shape 1248 | ) 1249 | desc_k_t = _maybe_make_tensor_desc( 1250 | desc_k_t, 1251 | shape=[y_dim, HEAD_DIM], 1252 | strides=[HEAD_DIM, 1], 1253 | block_shape=[BLOCK_N, HEAD_DIM], 1254 | ) 1255 | desc_o_t = _maybe_make_tensor_desc( 1256 | desc_o_t, 1257 | shape=[y_dim, HEAD_DIM], 1258 | strides=[HEAD_DIM, 1], 1259 | block_shape=[BLOCK_M, HEAD_DIM], 1260 | ) 1261 | # Load t_q: It will stay in SRAM throughout. 1262 | t_q = desc_q_t.load([qo_offset_y, 0]) 1263 | g_acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) 1264 | mu_i = tl.zeros([BLOCK_M], dtype=tl.float32) 1265 | p_tv_acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) 1266 | else: 1267 | t_q = None 1268 | desc_k_t = None 1269 | desc_v_t = None 1270 | # Allocate minimal dummy tensors to keep consistent the return signature of _attn_fwd_inner_tma 1271 | g_acc = tl.zeros([1, 1], dtype=tl.float32) 1272 | mu_i = tl.zeros([1], dtype=tl.float32) 1273 | p_tv_acc = tl.zeros([1, 1], dtype=tl.float32) 1274 | 1275 | # Prepare scales 1276 | qk_scale = sm_scale 1277 | qk_scale *= 1.44269504 # 1/log(2) 1278 | 1279 | # Load q: It will stay in SRAM throughout. 
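    # (The TMA descriptors view each of Q/K/V as one flat (Z * H * N_CTX, HEAD_DIM) matrix, so
    # qo_offset_y = off_hz * N_CTX + start_m * BLOCK_M is the first row of this program's query
    # tile; the load below fetches that (BLOCK_M, HEAD_DIM) tile in a single descriptor read.)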
1280 | q = desc_q.load([qo_offset_y, 0]) 1281 | 1282 | # Stage: 3 if causal, else 1 1283 | if STAGE == 1 or STAGE == 3: 1284 | # NOTE: This step runs for non-causal attention or for the 1285 | # blocks to the left of the diagonal for causal attention 1286 | acc, g_acc, l_i, m_i, mu_i, p_tv_acc = _attn_fwd_inner_tma( 1287 | acc, 1288 | g_acc, # 1289 | l_i, 1290 | m_i, # 1291 | mu_i, 1292 | p_tv_acc, # 1293 | q, 1294 | t_q, # 1295 | desc_k, 1296 | desc_v, # 1297 | desc_k_t, 1298 | desc_v_t, # 1299 | offset_y, 1300 | mask_block_ptr, 1301 | dropout_p, 1302 | philox_seed, 1303 | philox_offset_base, 1304 | dtype, 1305 | start_m, 1306 | qk_scale, 1307 | sm_scale, # 1308 | BLOCK_M, 1309 | HEAD_DIM, 1310 | BLOCK_N, # 1311 | 4 - STAGE, 1312 | offs_m, 1313 | offs_n, 1314 | N_CTX, # 1315 | warp_specialize, 1316 | ENABLE_JVP, 1317 | ENABLE_DROPOUT, 1318 | MASK_TYPE, 1319 | ) 1320 | 1321 | if STAGE == 3: 1322 | # NOTE: This step runs for the blocks to the 1323 | # right of the diagonal for causal attention 1324 | acc, g_acc, l_i, m_i, mu_i, p_tv_acc = _attn_fwd_inner_tma( 1325 | acc, 1326 | g_acc, # 1327 | l_i, 1328 | m_i, # 1329 | mu_i, 1330 | p_tv_acc, # 1331 | q, 1332 | t_q, # 1333 | desc_k, 1334 | desc_v, # 1335 | desc_k_t, 1336 | desc_v_t, # 1337 | offset_y, 1338 | mask_block_ptr, 1339 | dropout_p, 1340 | philox_seed, 1341 | philox_offset_base, 1342 | dtype, 1343 | start_m, 1344 | qk_scale, 1345 | sm_scale, # 1346 | BLOCK_M, 1347 | HEAD_DIM, 1348 | BLOCK_N, # 1349 | 2, 1350 | offs_m, 1351 | offs_n, 1352 | N_CTX, # 1353 | warp_specialize, 1354 | ENABLE_JVP, 1355 | ENABLE_DROPOUT, 1356 | MASK_TYPE, 1357 | ) 1358 | 1359 | # Epilogue 1360 | empty_mask = l_i == 0.0 1361 | if empty_mask.sum() > 0: 1362 | l_i = tl.where( 1363 | empty_mask, 1.0, l_i 1364 | ) # NOTE: This happens if the entire block is masked out. 1365 | 1366 | m_i = m_i + tl.where( 1367 | # NOTE: This is needed to compute the logsumexp for the backward pass. 1368 | empty_mask, 1369 | 0.0, 1370 | tl.math.log2(l_i), 1371 | ) 1372 | 1373 | acc = acc / l_i[:, None] 1374 | m_ptrs = M + off_hz * N_CTX + offs_m 1375 | tl.store(m_ptrs, m_i) 1376 | desc_o.store([qo_offset_y, 0], acc.to(desc_o.dtype)) 1377 | 1378 | if ENABLE_JVP: 1379 | t_p_v = g_acc / l_i[:, None] - (mu_i / l_i)[:, None] * acc 1380 | t_y_out = t_p_v + p_tv_acc / l_i[:, None] 1381 | desc_o_t.store([qo_offset_y, 0], t_y_out.to(desc_o_t.dtype)) 1382 | 1383 | 1384 | @triton.jit 1385 | def _attn_bwd_preprocess( 1386 | O, DO, Delta, N_CTX, BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr # noqa: E741 1387 | ): 1388 | """Preprocess output deltas for the backward attention pass. 1389 | 1390 | Args: 1391 | O: Output tensor. 1392 | DO: Gradient of the output tensor. 1393 | Delta: Accumulated gradients. 1394 | N_CTX: Context length. 1395 | BLOCK_M: Block size for M dimension. 1396 | HEAD_DIM: Head dimension size. 1397 | """ 1398 | # Collect sequence, batch, and head indices 1399 | off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) 1400 | off_hz = tl.program_id(1) 1401 | off_n = tl.arange(0, HEAD_DIM) 1402 | 1403 | # Load outputs and gradients 1404 | o = tl.load(O + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]) 1405 | do = tl.load(DO + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]).to( 1406 | tl.float32 1407 | ) 1408 | delta = tl.sum(o * do, axis=1) 1409 | 1410 | # Write-back the intermediate delta results 1411 | tl.store(Delta + off_hz * N_CTX + off_m, delta) 1412 | 1413 | 1414 | # The main inner-loop logic for computing dK and dV. 
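# Roughly, for one key/value block the loop in _attn_bwd_dkdv below accumulates (with P the
# normalized attention probabilities, dO the incoming output gradient, and
# Delta_i = sum_d O_id * dO_id as computed by _attn_bwd_preprocess above):
#     dV += P^T dO
#     dP  = dO V^T
#     dS  = P * (dP - Delta)        # softmax Jacobian applied row-wise
#     dK += dS^T Q
# P is rebuilt on the fly as exp2(k q^T - M) because M holds the base-2 logsumexp saved by the
# forward pass (running max plus log2 of the row sum), so no extra normalization is needed here.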
1415 | @triton.jit 1416 | def _attn_bwd_dkdv( 1417 | dk, 1418 | dv, # 1419 | Q, 1420 | k, 1421 | v, 1422 | DO, # 1423 | M, 1424 | D, # 1425 | # shared by Q/K/V/DO. 1426 | stride_tok, 1427 | stride_d, # 1428 | N_CTX, 1429 | BLOCK_M1: tl.constexpr, # 1430 | BLOCK_N1: tl.constexpr, # 1431 | HEAD_DIM: tl.constexpr, # 1432 | # Filled in by the wrapper. 1433 | start_n, 1434 | start_m, 1435 | num_steps, # 1436 | CAUSAL_MASKING: tl.constexpr, 1437 | # Args for masking/dropout 1438 | mask_ptr, 1439 | mask_stride_tok1, 1440 | mask_stride_tok2, 1441 | MASK_TYPE: tl.constexpr, 1442 | dropout_p, 1443 | philox_seed, 1444 | philox_offset_base, 1445 | ENABLE_DROPOUT: tl.constexpr, 1446 | MASK_CONST: tl.constexpr = MASK_CONST, 1447 | ): 1448 | """The main inner-loop logic for computing dK and dV. 1449 | 1450 | Args: 1451 | dk: Gradient of the key tensor. 1452 | dv: Gradient of the value tensor. 1453 | Q: Query tensor. 1454 | k: Key tensor. 1455 | v: Value tensor. 1456 | DO: Gradient of the output tensor. 1457 | M: Memory tensor. 1458 | D: Delta tensor. 1459 | stride_tok: Stride for the token dimension. 1460 | stride_d: Stride for the head dimension. 1461 | N_CTX: Context length. 1462 | BLOCK_M1: Block size for M dimension. 1463 | BLOCK_N1: Block size for N dimension. 1464 | HEAD_DIM: Head dimension size. 1465 | start_n: Starting index for N dimension. 1466 | start_m: Starting index for M dimension. 1467 | num_steps: Number of steps to unroll. 1468 | CAUSAL_MASKING: Flag for causal masking. 1469 | mask_ptr: Pointer to the mask tensor. 1470 | mask_stride_tok1: Stride for the third (row) dimension of the mask tensor. 1471 | mask_stride_tok2: Stride for the fourth (column) dimension of the mask tensor. 1472 | MASK_TYPE: Type of masking (0: no mask, 1: boolean mask, 1473 | 2: additive mask). 1474 | dropout_p: Dropout probability. 1475 | philox_seed: Seed for Philox RNG. 1476 | philox_offset_base: Base offset for Philox RNG. 1477 | ENABLE_DROPOUT: Flag to enable dropout. 1478 | MASK_CONST: Constant used for masking. 1479 | 1480 | Returns: 1481 | dk: Gradient of the key tensor. 1482 | dv: Gradient of the value tensor. 1483 | """ 1484 | # Initialize pointers for Q and DO 1485 | offs_m = start_m + tl.arange(0, BLOCK_M1) 1486 | offs_n = start_n + tl.arange(0, BLOCK_N1) 1487 | offs_h = tl.arange(0, HEAD_DIM) 1488 | qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_h[:, None] * stride_d 1489 | do_ptrs = DO + offs_m[:, None] * stride_tok + offs_h[None, :] * stride_d 1490 | 1491 | if MASK_TYPE > 0: 1492 | mask_ptr = ( 1493 | mask_ptr + offs_m[None, :] * mask_stride_tok1 + offs_n[:, None] * mask_stride_tok2 1494 | ) 1495 | 1496 | # NOTE: BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. 
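    # (The wrapper calls this kernel with a query step of either BLOCK_M1 or
    # BLOCK_M1 // BLK_SLICE_FACTOR and num_steps chosen so those steps tile the BLOCK_N1-wide
    # key block exactly; that only works when the key-block width is a multiple of the query
    # step, which the assert below checks.)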
1497 | tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) 1498 | curr_m = start_m 1499 | step_m = BLOCK_M1 1500 | dtype = tl.float32 # For dot products 1501 | 1502 | # Iteratively compute dK and dV over the M dimension 1503 | for _ in range(num_steps): 1504 | qT = tl.load(qT_ptrs) 1505 | 1506 | # Load m before computing qk to reduce pipeline stall 1507 | offs_m = curr_m + tl.arange(0, BLOCK_M1) 1508 | m = tl.load(M + offs_m) 1509 | qkT = tl.dot(k, qT) 1510 | 1511 | # Exponentiation 1512 | pT = tl.math.exp2(qkT - m[None, :]) 1513 | 1514 | # External masking after exponentiation 1515 | if MASK_TYPE > 0: 1516 | mask = tl.load(mask_ptr) 1517 | if MASK_TYPE == 1: # Boolean mask 1518 | pT = tl.where(mask == 1, pT, 0.0) 1519 | elif MASK_TYPE == 2: # Additive mask 1520 | # 'mask' is the additive mask loaded above (MASK_CONST not allowed, all other values allowed) 1521 | attend = mask != MASK_CONST 1522 | pT = tl.where(attend, pT, 0.0) 1523 | 1524 | # (or) Causal masking after exponentiation 1525 | elif CAUSAL_MASKING: 1526 | causal_mask = offs_m[None, :] >= offs_n[:, None] 1527 | pT = tl.where(causal_mask, pT, 0.0) 1528 | 1529 | # Dropout after exponentiation 1530 | if ENABLE_DROPOUT: 1531 | philox_offset = philox_offset_base + curr_m * N_CTX + start_n 1532 | dropout_mask, dropout_scale = create_dropout_mask( 1533 | philox_seed, philox_offset, dropout_p, BLOCK_M1, BLOCK_N1, N_CTX 1534 | ) 1535 | pT = pT * dropout_mask.to(pT.dtype) * dropout_scale 1536 | 1537 | # Compute dV 1538 | ppT = pT 1539 | ppT = ppT.to(dtype) 1540 | do = tl.load(do_ptrs) 1541 | dv += tl.dot(ppT, do.to(dtype)).to(do.dtype) 1542 | # NOTE: D (= delta) is pre-divided by ds_scale. 1543 | Di = tl.load(D + offs_m) 1544 | 1545 | # Compute dP and dS to derive dK 1546 | dpT = tl.dot(v, tl.trans(do)).to(tl.float32) 1547 | 1548 | if ENABLE_DROPOUT: # This derivative should be masked with the same dropout mask 1549 | dpT = dpT * dropout_mask.to(dpT.dtype) * dropout_scale 1550 | 1551 | dsT = pT * (dpT - Di[None, :]) 1552 | dsT = dsT.to(dtype) 1553 | dk += tl.dot(dsT, tl.trans(qT).to(dtype)).to(qT.dtype) 1554 | 1555 | # Increment pointers 1556 | curr_m += step_m 1557 | qT_ptrs += step_m * stride_tok 1558 | do_ptrs += step_m * stride_tok 1559 | 1560 | if MASK_TYPE > 0: 1561 | mask_ptr += step_m * mask_stride_tok1 1562 | 1563 | return dk, dv 1564 | 1565 | 1566 | # The main inner-loop logic for computing dQ 1567 | @triton.jit 1568 | def _attn_bwd_dq( 1569 | dq, 1570 | q, 1571 | K, 1572 | V, # 1573 | do, 1574 | m, 1575 | D, 1576 | # shared by Q/K/V/DO. 1577 | stride_tok, 1578 | stride_d, # 1579 | N_CTX, # 1580 | BLOCK_M2: tl.constexpr, # 1581 | BLOCK_N2: tl.constexpr, # 1582 | HEAD_DIM: tl.constexpr, 1583 | # Filled in by the wrapper. 1584 | start_m, 1585 | start_n, 1586 | num_steps, # 1587 | CAUSAL_MASKING: tl.constexpr, 1588 | # Args for masking/dropout 1589 | mask_ptr, 1590 | mask_stride_tok1, 1591 | mask_stride_tok2, 1592 | MASK_TYPE: tl.constexpr, 1593 | dropout_p, 1594 | philox_seed, 1595 | philox_offset_base, 1596 | ENABLE_DROPOUT: tl.constexpr, 1597 | MASK_CONST: tl.constexpr = MASK_CONST, 1598 | ): 1599 | """The main inner-loop logic for computing dQ. 1600 | 1601 | Args: 1602 | dq: Gradient of the query tensor. 1603 | q: Query tensor. 1604 | K: Key tensor. 1605 | V: Value tensor. 1606 | do: Gradient of the output tensor. 1607 | m: Memory tensor. 1608 | D: Delta tensor. 1609 | stride_tok: Stride for the token dimension. 1610 | stride_d: Stride for the head dimension. 1611 | N_CTX: Context length. 1612 | BLOCK_M2: Block size for M dimension. 
1613 | BLOCK_N2: Block size for N dimension. 1614 | HEAD_DIM: Head dimension size. 1615 | start_m: Starting index for M dimension. 1616 | start_n: Starting index for N dimension. 1617 | num_steps: Number of steps to unroll. 1618 | CAUSAL_MASKING: Flag for causal masking. 1619 | mask_ptr: Pointer to the mask tensor. 1620 | mask_stride_tok1: Stride for the third (row) dimension of the mask tensor. 1621 | mask_stride_tok2: Stride for the fourth (column) dimension of the mask tensor. 1622 | MASK_TYPE: Type of masking (0: no mask, 1: boolean mask, 1623 | 2: additive mask). 1624 | dropout_p: Dropout probability. 1625 | philox_seed: Seed for Philox RNG. 1626 | philox_offset_base: Base offset for Philox RNG. 1627 | ENABLE_DROPOUT: Flag to enable dropout. 1628 | MASK_CONST: Constant used for masking. 1629 | 1630 | Returns: 1631 | dq: Gradient of the query tensor. 1632 | """ 1633 | # Initialize pointers for K, V, and DO 1634 | offs_m = start_m + tl.arange(0, BLOCK_M2) 1635 | offs_n = start_n + tl.arange(0, BLOCK_N2) 1636 | offs_h = tl.arange(0, HEAD_DIM) 1637 | kT_ptrs = K + offs_n[None, :] * stride_tok + offs_h[:, None] * stride_d 1638 | vT_ptrs = V + offs_n[None, :] * stride_tok + offs_h[:, None] * stride_d 1639 | 1640 | if MASK_TYPE > 0: 1641 | mask_ptr = ( 1642 | mask_ptr + offs_m[:, None] * mask_stride_tok1 + offs_n[None, :] * mask_stride_tok2 1643 | ) 1644 | 1645 | # NOTE: D (= delta) is pre-divided by ds_scale. 1646 | Di = tl.load(D + offs_m) 1647 | 1648 | # NOTE: BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 1649 | tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) 1650 | curr_n = start_n 1651 | step_n = BLOCK_N2 1652 | dtype = tl.float32 # For dot products 1653 | 1654 | # Iteratively compute dQ over the N dimension 1655 | for _ in range(num_steps): 1656 | offs_n = curr_n + tl.arange(0, BLOCK_N2) 1657 | kT = tl.load(kT_ptrs) 1658 | vT = tl.load(vT_ptrs) 1659 | qk = tl.dot(q, kT) 1660 | 1661 | # Exponentiation 1662 | p = tl.math.exp2(qk - m) 1663 | 1664 | # External masking after exponentiation 1665 | if MASK_TYPE > 0: 1666 | mask = tl.load(mask_ptr) 1667 | if MASK_TYPE == 1: # Boolean mask 1668 | p = tl.where(mask == 1, p, 0.0) 1669 | elif MASK_TYPE == 2: # Additive mask 1670 | attend = mask != MASK_CONST 1671 | p = tl.where(attend, p, 0.0) 1672 | 1673 | # (or) Causal masking after exponentiation 1674 | elif CAUSAL_MASKING: 1675 | causal_mask = offs_m[:, None] >= offs_n[None, :] 1676 | p = tl.where(causal_mask, p, 0.0) 1677 | 1678 | # Dropout after exponentiation 1679 | if ENABLE_DROPOUT: 1680 | philox_offset = philox_offset_base + start_m * N_CTX + curr_n 1681 | dropout_mask, dropout_scale = create_dropout_mask( 1682 | philox_seed, philox_offset, dropout_p, BLOCK_M2, BLOCK_N2, N_CTX 1683 | ) 1684 | p = p * dropout_mask.to(p.dtype) * dropout_scale 1685 | 1686 | # Compute dP and dS 1687 | dp = tl.dot(do, vT).to(tl.float32) 1688 | 1689 | if ENABLE_DROPOUT: # NOTE: This derivative should be masked with the same dropout mask. 1690 | dp = dp * dropout_mask.to(dp.dtype) * dropout_scale 1691 | 1692 | ds = p * (dp - Di[:, None]) 1693 | 1694 | # Compute dQ 1695 | # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. 
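        # (The wrapper passes arg_k = k * sm_scale * log2(e), so the dot product below accumulates
        # dq with an extra log2(e) factor; multiplying dq by LN2 at write-back cancels it and
        # leaves exactly the sm_scale factor that belongs in dQ = sm_scale * dS K.)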
1696 | ds = ds.to(dtype) 1697 | dq += tl.dot(ds, tl.trans(kT).to(dtype)).to(kT.dtype) 1698 | 1699 | # Increment pointers 1700 | curr_n += step_n 1701 | kT_ptrs += step_n * stride_tok 1702 | vT_ptrs += step_n * stride_tok 1703 | 1704 | if MASK_TYPE > 0: 1705 | mask_ptr += step_n * mask_stride_tok2 1706 | 1707 | return dq 1708 | 1709 | 1710 | @triton.jit 1711 | def _attn_bwd_causal( 1712 | Q, 1713 | K, 1714 | V, 1715 | sm_scale, # 1716 | DO, # 1717 | DQ, 1718 | DK, 1719 | DV, # 1720 | M, 1721 | D, 1722 | # Shared by Q/K/V/DO. 1723 | stride_z, 1724 | stride_h, 1725 | stride_tok, 1726 | stride_d, # 1727 | # Used for the mask. 1728 | mask_stride_z, 1729 | mask_stride_h, 1730 | mask_stride_tok1, 1731 | mask_stride_tok2, 1732 | # Dimensions and sizes. 1733 | H, 1734 | N_CTX, # 1735 | BLOCK_M1: tl.constexpr, # 1736 | BLOCK_N1: tl.constexpr, # 1737 | BLOCK_M2: tl.constexpr, # 1738 | BLOCK_N2: tl.constexpr, # 1739 | BLK_SLICE_FACTOR: tl.constexpr, # 1740 | HEAD_DIM: tl.constexpr, 1741 | # Args for masking/dropout. 1742 | mask_ptr, 1743 | MASK_TYPE: tl.constexpr, 1744 | dropout_p, 1745 | philox_seed, 1746 | ENABLE_DROPOUT: tl.constexpr, 1747 | ): 1748 | """The main backward pass for the (causal) attention mechanism. 1749 | 1750 | This computes gradients for only ~N²/2 pairwise token interactions, 1751 | since causal attention already masks out half of the interactions. 1752 | 1753 | Args: 1754 | Q: Query tensor. 1755 | K: Key tensor. 1756 | V: Value tensor. 1757 | sm_scale: Scale factor for the softmax. 1758 | DO: Gradient of the output tensor. 1759 | DQ: Gradient of the query tensor. 1760 | DK: Gradient of the key tensor. 1761 | DV: Gradient of the value tensor. 1762 | M: Memory tensor. 1763 | D: Delta tensor. 1764 | stride_z: Stride for the z dimension. 1765 | stride_h: Stride for the head dimension. 1766 | stride_tok: Stride for the token dimension. 1767 | stride_d: Stride for the head dimension. 1768 | mask_stride_z: Stride for the z dimension in the mask tensor. 1769 | mask_stride_h: Stride for the head dimension in the mask tensor. 1770 | mask_stride_tok1: Stride for the first token (row) dimension in the mask tensor. 1771 | mask_stride_tok2: Stride for the second token (column) dimension in the mask tensor. 1772 | H: Head dimension. 1773 | N_CTX: Context length. 1774 | BLOCK_M1: Block size for M dimension. 1775 | BLOCK_N1: Block size for N dimension. 1776 | BLOCK_M2: Block size for M dimension. 1777 | BLOCK_N2: Block size for N dimension. 1778 | BLK_SLICE_FACTOR: Block slice factor. 1779 | HEAD_DIM: Head dimension size. 1780 | mask_ptr: Pointer to the mask tensor. 1781 | MASK_TYPE: Type of masking (0: no mask, 1: boolean mask, 1782 | 2: additive mask). 1783 | dropout_p: Dropout probability. 1784 | philox_seed: Seed for Philox RNG. 1785 | ENABLE_DROPOUT: Flag to enable dropout. 1786 | """ 1787 | # Constants 1788 | LN2: tl.constexpr = 0.6931471824645996 # = ln(2) 1789 | 1790 | # Collect sequence, batch, and head indices 1791 | start_block_id = tl.program_id(0) # Which block (in the input query sequence) to process 1792 | off_hz = tl.program_id( 1793 | 1 1794 | ) # Which head and batch element to process, with a program being a single head of a single batch element 1795 | off_z = ( 1796 | off_hz // H 1797 | ) # Which batch element this program is assigned to (n.b., each batch element has H heads) 1798 | off_h = off_hz % H # The position of the head to process in the batch 1799 | 1800 | # NOTE: This allows one to get the (N_CTX, HEAD_DIM) block in Q, K, V, etc. 
by indexing it by batch and head 1801 | delta_shared_offset = (off_hz * N_CTX).to(tl.int64) 1802 | qkv_shared_offset = off_z.to(tl.int64) * stride_z + off_h.to(tl.int64) * stride_h 1803 | 1804 | # Offset pointers for batch elements and heads 1805 | Q += qkv_shared_offset 1806 | K += qkv_shared_offset 1807 | V += qkv_shared_offset 1808 | DO += qkv_shared_offset 1809 | DQ += qkv_shared_offset 1810 | DK += qkv_shared_offset 1811 | DV += qkv_shared_offset 1812 | 1813 | M += delta_shared_offset # NOTE: These tensors have fewer dimensions. 1814 | D += delta_shared_offset 1815 | 1816 | # Initialize pointer for the mask, if provided 1817 | if MASK_TYPE > 0: 1818 | mask_offset = off_z.to(tl.int64) * mask_stride_z + off_h.to(tl.int64) * mask_stride_h 1819 | mask_ptr += mask_offset 1820 | 1821 | # Generate philox offset for this block 1822 | philox_offset_base = off_hz * N_CTX * N_CTX 1823 | 1824 | # ====== COMPUTE dK and dV ====== 1825 | # Determine step size for dK and dV computation 1826 | MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR 1827 | 1828 | # Prepare offsets for loading Q/K/V/DO 1829 | start_n = start_block_id * BLOCK_N1 1830 | start_m = start_n 1831 | 1832 | # Load K and V: They will stay in SRAM throughout. 1833 | offs_n = start_n + tl.arange(0, BLOCK_N1) 1834 | offs_h = tl.arange(0, HEAD_DIM) 1835 | 1836 | k = tl.load(K + offs_n[:, None] * stride_tok + offs_h[None, :] * stride_d) 1837 | v = tl.load(V + offs_n[:, None] * stride_tok + offs_h[None, :] * stride_d) 1838 | 1839 | # Initialize dK and dV accumulators 1840 | dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) 1841 | dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) 1842 | 1843 | # Compute dK and dV for (causally) masked blocks 1844 | num_steps = BLOCK_N1 // MASK_BLOCK_M1 1845 | dk, dv = _attn_bwd_dkdv( 1846 | dk, 1847 | dv, # 1848 | Q, 1849 | k, 1850 | v, 1851 | DO, # 1852 | M, 1853 | D, # 1854 | stride_tok, 1855 | stride_d, # 1856 | N_CTX, # 1857 | MASK_BLOCK_M1, 1858 | BLOCK_N1, 1859 | HEAD_DIM, # 1860 | start_n, 1861 | start_m, 1862 | num_steps, # 1863 | CAUSAL_MASKING=True, # 1864 | mask_ptr=mask_ptr, 1865 | mask_stride_tok1=mask_stride_tok1, 1866 | mask_stride_tok2=mask_stride_tok2, 1867 | MASK_TYPE=MASK_TYPE, 1868 | dropout_p=dropout_p, 1869 | philox_seed=philox_seed, 1870 | philox_offset_base=philox_offset_base, 1871 | ENABLE_DROPOUT=ENABLE_DROPOUT, 1872 | ) 1873 | 1874 | start_m += num_steps * MASK_BLOCK_M1 1875 | num_steps = (N_CTX - start_m) // BLOCK_M1 1876 | 1877 | # Compute dK and dV for (causally) non-masked blocks 1878 | dk, dv = _attn_bwd_dkdv( # 1879 | dk, 1880 | dv, # 1881 | Q, 1882 | k, 1883 | v, 1884 | DO, # 1885 | M, 1886 | D, # 1887 | stride_tok, 1888 | stride_d, # 1889 | N_CTX, # 1890 | BLOCK_M1, 1891 | BLOCK_N1, 1892 | HEAD_DIM, # 1893 | start_n, 1894 | start_m, 1895 | num_steps, # 1896 | CAUSAL_MASKING=False, # 1897 | mask_ptr=mask_ptr, 1898 | mask_stride_tok1=mask_stride_tok1, 1899 | mask_stride_tok2=mask_stride_tok2, 1900 | MASK_TYPE=MASK_TYPE, 1901 | dropout_p=dropout_p, 1902 | philox_seed=philox_seed, 1903 | philox_offset_base=philox_offset_base, 1904 | ENABLE_DROPOUT=ENABLE_DROPOUT, 1905 | ) 1906 | 1907 | # Write-back dV 1908 | dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_h[None, :] * stride_d 1909 | tl.store(dv_ptrs, dv) 1910 | 1911 | # Write-back dK (scaled) 1912 | dk *= sm_scale 1913 | dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_h[None, :] * stride_d 1914 | tl.store(dk_ptrs, dk) 1915 | 1916 | # ====== COMPUTE dQ ====== 1917 | # Determine step size for dQ 
computation 1918 | MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR 1919 | 1920 | # Prepare offsets for dQ computation 1921 | start_m = start_block_id * BLOCK_M2 1922 | end_n = start_m + BLOCK_M2 1923 | 1924 | offs_m = start_m + tl.arange(0, BLOCK_M2) 1925 | 1926 | # Load Q, DO, and M: They will stay in SRAM throughout. 1927 | q = tl.load(Q + offs_m[:, None] * stride_tok + offs_h[None, :] * stride_d) 1928 | do = tl.load(DO + offs_m[:, None] * stride_tok + offs_h[None, :] * stride_d) 1929 | 1930 | m = tl.load(M + offs_m) 1931 | m = m[:, None] 1932 | 1933 | # Initialize dQ accumulator 1934 | dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) 1935 | 1936 | # Compute dQ for (causally) masked blocks 1937 | num_steps = BLOCK_M2 // MASK_BLOCK_N2 1938 | dq = _attn_bwd_dq( 1939 | # NOTE: This code scans each row of QK^T backward (from right to left, 1940 | # but inside each call to _attn_bwd_dq, from left to right), but that's 1941 | # not due to anything important. It's just to reuse the loop structure 1942 | # for dK and dV above as much as possible. 1943 | dq, 1944 | q, 1945 | K, 1946 | V, # 1947 | do, 1948 | m, 1949 | D, # 1950 | stride_tok, 1951 | stride_d, # 1952 | N_CTX, # 1953 | BLOCK_M2, 1954 | MASK_BLOCK_N2, 1955 | HEAD_DIM, # 1956 | start_m, 1957 | end_n - num_steps * MASK_BLOCK_N2, 1958 | num_steps, # 1959 | CAUSAL_MASKING=True, # 1960 | mask_ptr=mask_ptr, 1961 | mask_stride_tok1=mask_stride_tok1, 1962 | mask_stride_tok2=mask_stride_tok2, 1963 | MASK_TYPE=MASK_TYPE, 1964 | dropout_p=dropout_p, 1965 | philox_seed=philox_seed, 1966 | philox_offset_base=philox_offset_base, 1967 | ENABLE_DROPOUT=ENABLE_DROPOUT, 1968 | ) 1969 | 1970 | end_n -= num_steps * MASK_BLOCK_N2 1971 | 1972 | # Compute dQ for (causally) non-masked blocks 1973 | num_steps = end_n // BLOCK_N2 1974 | dq = _attn_bwd_dq( 1975 | dq, 1976 | q, 1977 | K, 1978 | V, # 1979 | do, 1980 | m, 1981 | D, # 1982 | stride_tok, 1983 | stride_d, # 1984 | N_CTX, # 1985 | BLOCK_M2, 1986 | BLOCK_N2, 1987 | HEAD_DIM, # 1988 | start_m, 1989 | end_n - num_steps * BLOCK_N2, 1990 | num_steps, # 1991 | CAUSAL_MASKING=False, # 1992 | mask_ptr=mask_ptr, 1993 | mask_stride_tok1=mask_stride_tok1, 1994 | mask_stride_tok2=mask_stride_tok2, 1995 | MASK_TYPE=MASK_TYPE, 1996 | dropout_p=dropout_p, 1997 | philox_seed=philox_seed, 1998 | philox_offset_base=philox_offset_base, 1999 | ENABLE_DROPOUT=ENABLE_DROPOUT, 2000 | ) 2001 | 2002 | # Write-back dQ (scaled) 2003 | dq *= LN2 2004 | dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_h[None, :] * stride_d 2005 | tl.store(dq_ptrs, dq) 2006 | 2007 | 2008 | @triton.jit 2009 | def _attn_bwd( 2010 | Q, 2011 | K, 2012 | V, 2013 | sm_scale, # 2014 | DO, # 2015 | DQ, 2016 | DK, 2017 | DV, # 2018 | M, 2019 | D, 2020 | # Shared by Q/K/V/DO. 2021 | stride_z, 2022 | stride_h, 2023 | stride_tok, 2024 | stride_d, # 2025 | # Used for the mask. 2026 | mask_stride_z, 2027 | mask_stride_h, 2028 | mask_stride_tok1, 2029 | mask_stride_tok2, 2030 | # Dimensions and sizes. 2031 | H, 2032 | N_CTX, # 2033 | BLOCK_M1: tl.constexpr, # 2034 | BLOCK_N1: tl.constexpr, # 2035 | BLOCK_M2: tl.constexpr, # 2036 | BLOCK_N2: tl.constexpr, # 2037 | BLK_SLICE_FACTOR: tl.constexpr, # 2038 | HEAD_DIM: tl.constexpr, 2039 | # Args for masking/dropout. 2040 | mask_ptr, 2041 | MASK_TYPE: tl.constexpr, 2042 | dropout_p, 2043 | philox_seed, 2044 | ENABLE_DROPOUT: tl.constexpr, 2045 | ): 2046 | """The main backward pass for the (non-causal) attention mechanism. 
2047 | 2048 | This computes gradients for all N² pairwise token interactions, 2049 | unlike the causal version which only computes ~N²/2. 2050 | 2051 | Args: 2052 | Q: Query tensor. 2053 | K: Key tensor. 2054 | V: Value tensor. 2055 | sm_scale: Scale factor for the softmax. 2056 | DO: Gradient of the output tensor. 2057 | DQ: Gradient of the query tensor. 2058 | DK: Gradient of the key tensor. 2059 | DV: Gradient of the value tensor. 2060 | M: Memory tensor. 2061 | D: Delta tensor. 2062 | stride_z: Stride for the z dimension. 2063 | stride_h: Stride for the head dimension. 2064 | stride_tok: Stride for the token dimension. 2065 | stride_d: Stride for the head dimension. 2066 | mask_stride_z: Stride for the z dimension in the mask tensor. 2067 | mask_stride_h: Stride for the head dimension in the mask tensor. 2068 | mask_stride_tok1: Stride for the first token (row) dimension in the mask tensor. 2069 | mask_stride_tok2: Stride for the second token (column) dimension in the mask tensor. 2070 | H: Head dimension. 2071 | N_CTX: Context length. 2072 | BLOCK_M1: Block size for M dimension. 2073 | BLOCK_N1: Block size for N dimension. 2074 | BLOCK_M2: Block size for M dimension. 2075 | BLOCK_N2: Block size for N dimension. 2076 | BLK_SLICE_FACTOR: Block slice factor. 2077 | HEAD_DIM: Head dimension size. 2078 | mask_ptr: Pointer to the mask tensor. 2079 | MASK_TYPE: Type of masking (0: no mask, 1: boolean mask, 2080 | 2: additive mask). 2081 | dropout_p: Dropout probability. 2082 | philox_seed: Seed for Philox RNG. 2083 | ENABLE_DROPOUT: Flag to enable dropout. 2084 | """ 2085 | # Constants 2086 | LN2: tl.constexpr = 0.6931471824645996 # = ln(2) 2087 | 2088 | # Collect sequence, batch, and head indices 2089 | start_block_id = tl.program_id(0) # Which block (in the input query sequence) to process 2090 | off_hz = tl.program_id( 2091 | 1 2092 | ) # Which head and batch element to process, with a program being a single head of a single batch element 2093 | off_z = ( 2094 | off_hz // H 2095 | ) # Which batch element this program is assigned to (n.b., each batch element has H heads) 2096 | off_h = off_hz % H # The position of the head to process in the batch 2097 | 2098 | # NOTE: This allows one to get the (N_CTX, HEAD_DIM) block in Q, K, V, etc. by indexing it by batch and head 2099 | delta_shared_offset = (off_hz * N_CTX).to(tl.int64) 2100 | qkv_shared_offset = off_z.to(tl.int64) * stride_z + off_h.to(tl.int64) * stride_h 2101 | 2102 | # Offset pointers for batch elements and heads 2103 | Q += qkv_shared_offset 2104 | K += qkv_shared_offset 2105 | V += qkv_shared_offset 2106 | DO += qkv_shared_offset 2107 | DQ += qkv_shared_offset 2108 | DK += qkv_shared_offset 2109 | DV += qkv_shared_offset 2110 | 2111 | M += delta_shared_offset # NOTE: These tensors have fewer dimensions. 2112 | D += delta_shared_offset 2113 | 2114 | # Initialize pointer for the mask, if provided 2115 | if MASK_TYPE > 0: 2116 | mask_offset = off_z.to(tl.int64) * mask_stride_z + off_h.to(tl.int64) * mask_stride_h 2117 | mask_ptr += mask_offset 2118 | 2119 | # Generate philox offset for this block 2120 | philox_offset_base = off_hz * N_CTX * N_CTX 2121 | 2122 | # ====== COMPUTE dK and dV ====== 2123 | # For non-causal attention, we process ALL query blocks (the entire sequence) 2124 | # This is the key difference from causal: we iterate through all Q positions 2125 | 2126 | # Prepare offsets for loading Q/K/V/DO 2127 | start_n = start_block_id * BLOCK_N1 2128 | 2129 | # Load K and V: They will stay in SRAM throughout. 
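    # (Unlike _attn_bwd_causal above, which starts at the diagonal and splits masked/unmasked
    # phases, this kernel gives each program one BLOCK_N1-wide key/value block and sweeps every
    # query block, i.e. num_steps = N_CTX // BLOCK_M1, since every query may attend to every key.
    # K and V for this block are loaded next and reused across all of those steps.)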
2130 | offs_n = start_n + tl.arange(0, BLOCK_N1) 2131 | offs_h = tl.arange(0, HEAD_DIM) 2132 | 2133 | k = tl.load(K + offs_n[:, None] * stride_tok + offs_h[None, :] * stride_d) 2134 | v = tl.load(V + offs_n[:, None] * stride_tok + offs_h[None, :] * stride_d) 2135 | 2136 | # Initialize dK and dV accumulators 2137 | dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) 2138 | dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) 2139 | 2140 | start_m = 0 # Start from the beginning of the sequence 2141 | num_steps = N_CTX // BLOCK_M1 # Process the entire sequence 2142 | 2143 | dk, dv = _attn_bwd_dkdv( 2144 | dk, 2145 | dv, # 2146 | Q, 2147 | k, 2148 | v, 2149 | DO, # 2150 | M, 2151 | D, # 2152 | stride_tok, 2153 | stride_d, # 2154 | N_CTX, # 2155 | BLOCK_M1, 2156 | BLOCK_N1, 2157 | HEAD_DIM, # 2158 | start_n, 2159 | start_m, 2160 | num_steps, # 2161 | CAUSAL_MASKING=False, 2162 | mask_ptr=mask_ptr, 2163 | mask_stride_tok1=mask_stride_tok1, 2164 | mask_stride_tok2=mask_stride_tok2, 2165 | MASK_TYPE=MASK_TYPE, 2166 | dropout_p=dropout_p, 2167 | philox_seed=philox_seed, 2168 | philox_offset_base=philox_offset_base, 2169 | ENABLE_DROPOUT=ENABLE_DROPOUT, 2170 | ) 2171 | 2172 | # Write-back dV 2173 | dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_h[None, :] * stride_d 2174 | tl.store(dv_ptrs, dv) 2175 | 2176 | # Write-back dK (scaled) 2177 | dk *= sm_scale 2178 | dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_h[None, :] * stride_d 2179 | tl.store(dk_ptrs, dk) 2180 | 2181 | # ====== COMPUTE dQ ====== 2182 | # Prepare offsets for dQ computation 2183 | start_m = start_block_id * BLOCK_M2 2184 | offs_m = start_m + tl.arange(0, BLOCK_M2) 2185 | 2186 | # Load Q, DO, and M: They will stay in SRAM throughout. 2187 | q = tl.load(Q + offs_m[:, None] * stride_tok + offs_h[None, :] * stride_d) 2188 | do = tl.load(DO + offs_m[:, None] * stride_tok + offs_h[None, :] * stride_d) 2189 | 2190 | m = tl.load(M + offs_m) 2191 | m = m[:, None] 2192 | 2193 | # Initialize dQ accumulator 2194 | dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) 2195 | 2196 | # For non-causal attention, we process ALL key/value blocks (the entire sequence) 2197 | # This means each query position can attend to ALL key/value positions 2198 | 2199 | start_n = 0 # Start from the beginning of the sequence 2200 | num_steps = N_CTX // BLOCK_N2 # Process the entire sequence 2201 | 2202 | dq = _attn_bwd_dq( 2203 | dq, 2204 | q, 2205 | K, 2206 | V, # 2207 | do, 2208 | m, 2209 | D, # 2210 | stride_tok, 2211 | stride_d, # 2212 | N_CTX, # 2213 | BLOCK_M2, 2214 | BLOCK_N2, 2215 | HEAD_DIM, # 2216 | start_m, 2217 | start_n, 2218 | num_steps, # 2219 | CAUSAL_MASKING=False, # 2220 | mask_ptr=mask_ptr, 2221 | mask_stride_tok1=mask_stride_tok1, 2222 | mask_stride_tok2=mask_stride_tok2, 2223 | MASK_TYPE=MASK_TYPE, 2224 | dropout_p=dropout_p, 2225 | philox_seed=philox_seed, 2226 | philox_offset_base=philox_offset_base, 2227 | ENABLE_DROPOUT=ENABLE_DROPOUT, 2228 | ) 2229 | 2230 | # Write-back dQ (scaled) 2231 | dq *= LN2 2232 | dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_h[None, :] * stride_d 2233 | tl.store(dq_ptrs, dq) 2234 | 2235 | 2236 | class JVPAttn(Function): 2237 | """JVP (Jacobian-Vector Product) for Attention Mechanism.""" 2238 | 2239 | class Grid(NamedTuple): 2240 | """Grid configuration for JVP Attention.""" 2241 | 2242 | M_BLOCKS: int 2243 | Z_H: int 2244 | ONE: Literal[1] 2245 | 2246 | class FnCtx(FunctionCtx): 2247 | """Function context for JVP Attention.""" 2248 | 2249 | sm_scale: float 2250 | HEAD_DIM_K: int 2251 | causal: 
bool 2252 | grid: JVPAttn.Grid 2253 | mask_tensor: Tensor 2254 | MASK_TYPE: int 2255 | dropout_p: float 2256 | philox_seed: int 2257 | ENABLE_DROPOUT: bool 2258 | 2259 | class FwdOutCtxContrib(NamedTuple): 2260 | """Forward output context contributions for JVP Attention.""" 2261 | 2262 | o_t: Tensor | None 2263 | M: Tensor 2264 | grid: JVPAttn.Grid 2265 | HEAD_DIM_K: int 2266 | sm_scale: float 2267 | mask_tensor: Tensor 2268 | MASK_TYPE: int 2269 | dropout_p: float 2270 | philox_seed: int 2271 | ENABLE_DROPOUT: bool 2272 | 2273 | class FwdOut(NamedTuple): 2274 | """Forward output for JVP Attention.""" 2275 | 2276 | o: Tensor 2277 | ctx: JVPAttn.FwdOutCtxContrib 2278 | 2279 | class JVPOut(NamedTuple): 2280 | """JVP output for JVP Attention.""" 2281 | 2282 | o: Tensor 2283 | ctx: None 2284 | 2285 | class BwdOut(NamedTuple): 2286 | """Backward output for JVP Attention.""" 2287 | 2288 | q: Tensor 2289 | k: Tensor 2290 | v: Tensor 2291 | q_t: None 2292 | k_t: None 2293 | v_t: None 2294 | attn_mask: None 2295 | dropout_p: None 2296 | causal: None 2297 | sm_scale: None 2298 | warp_specialize: None 2299 | USE_TMA: None 2300 | verify_attn_mask: None 2301 | 2302 | class Strides(NamedTuple): 2303 | """Strides for JVP Attention.""" 2304 | 2305 | z: int 2306 | h: int 2307 | n_ctx: int 2308 | head_dim: int 2309 | 2310 | @staticmethod 2311 | def forward( 2312 | q: Tensor, 2313 | k: Tensor, 2314 | v: Tensor, 2315 | q_t: Tensor | None, 2316 | k_t: Tensor | None, 2317 | v_t: Tensor | None, 2318 | attn_mask: Tensor | None = None, 2319 | dropout_p: float = 0.0, 2320 | causal: bool = False, 2321 | sm_scale: float | None = None, 2322 | warp_specialize: bool = True, 2323 | USE_TMA: bool = True, 2324 | verify_attn_mask: bool = True, 2325 | ) -> JVPAttn.FwdOut: 2326 | """Forward pass for JVP Attention. 2327 | 2328 | NOTE: The following warning(s) will be raised if `verify_attn_mask=True` 2329 | and an attention mask with any all-null head is provided: 2330 | `RuntimeWarning: overflow encountered in exp2.` 2331 | 2332 | Args: 2333 | q: Query tensor of shape (Z, H, N_CTX, HEAD_DIM_Q). 2334 | k: Key tensor of shape (Z, H, N_CTX, HEAD_DIM_K). 2335 | v: Value tensor of shape (Z, H, N_CTX, HEAD_DIM_V). 2336 | q_t: Optional tensor for query transpose. 2337 | k_t: Optional tensor for key transpose. 2338 | v_t: Optional tensor for value transpose. 2339 | attn_mask: Optional attention mask of shape (Z, H, N_CTX, N_CTX). 2340 | Two types of masks are supported. A boolean mask where a value 2341 | of True indicates that the element should take part in attention, 2342 | or a float mask of the same type as query, key, value that is added 2343 | to the attention score. The constant `MASK_CONST` is used to 2344 | indicate masked positions in the float mask. All other values 2345 | denote unmasked positions. 2346 | dropout_p: Dropout probability. 2347 | causal: Whether the attention is causal. 2348 | sm_scale: Optional scaling factor for softmax. 2349 | warp_specialize: Whether to use warp specialization. 2350 | USE_TMA: Whether to use TMA. 2351 | verify_attn_mask: Whether to verify the correctness of the provided attention mask. 2352 | 2353 | Returns: 2354 | Outputs of JVP Attention. 2355 | """ 2356 | if dropout_p != 0.0: 2357 | raise NotImplementedError("Dropout is not currently supported in JVP attention.") 2358 | 2359 | # Collect metadata 2360 | Z, H, N_CTX, HEAD_DIM_Q = q.shape 2361 | HEAD_DIM_K = k.shape[-1] 2362 | HEAD_DIM_V = v.shape[-1] # NOTE: When v is in float8_e5m2 it is transposed. 
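        # (For a hypothetical input of shape (2, 8, 1024, 64): Z=2 batch elements, H=8 heads,
        # N_CTX=1024 tokens, HEAD_DIM_*=64. Tangents q_t/k_t/v_t are only non-None when this is
        # invoked through fwd_dual under forward-mode AD, which is what turns ENABLE_JVP on below.)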
2363 | 2364 | STAGE = 3 if causal else 1 2365 | ENABLE_JVP = q_t is not None 2366 | 2367 | assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V, ( 2368 | "JVP attention requires HEAD_DIM_Q == HEAD_DIM_K == HEAD_DIM_V" 2369 | f" but got HEAD_DIM_Q={HEAD_DIM_Q}, HEAD_DIM_K={HEAD_DIM_K}, HEAD_DIM_V={HEAD_DIM_V}" 2370 | ) 2371 | assert HEAD_DIM_K in {16, 32, 64, 128, 256}, ( 2372 | "JVP attention only supports HEAD_DIM_K in {16, 32, 64, 128, 256}," 2373 | f" but got HEAD_DIM_K={HEAD_DIM_K}", 2374 | ) 2375 | 2376 | if causal and attn_mask is not None: 2377 | raise ValueError("Causal attention does not support an attention mask.") 2378 | if attn_mask is not None: 2379 | assert attn_mask.shape == ( 2380 | Z, 2381 | H, 2382 | N_CTX, 2383 | N_CTX, 2384 | ), "The provided attention mask must have 4 dimensions (Z, H, N_CTX, N_CTX)." 2385 | assert attn_mask.dtype in { 2386 | torch.bool, 2387 | q.dtype, 2388 | }, "The attention mask must be of the dtype bool or that of the query tensor." 2389 | 2390 | # Initialize arguments and tensors 2391 | if sm_scale is None: 2392 | sm_scale = HEAD_DIM_K**-0.5 2393 | 2394 | o = torch.empty_like(q) 2395 | o_t: Tensor | None = torch.empty_like(q_t) if ENABLE_JVP else None 2396 | M = torch.empty((Z, H, N_CTX), device=q.device, dtype=torch.float32) 2397 | 2398 | # Tune kernel for custom (e.g., AMD) targets 2399 | extra_kern_args = {} 2400 | 2401 | if is_hip(): 2402 | waves_per_eu = 3 if HEAD_DIM_K <= 64 else 2 2403 | extra_kern_args = {"waves_per_eu": waves_per_eu, "allow_flush_denorm": True} 2404 | 2405 | if is_cuda() and warp_specialize: 2406 | # NOTE: We need more registers if we're doing JVP 2407 | if (HEAD_DIM_K == 128 and q.dtype == torch.float16) or ENABLE_JVP: 2408 | extra_kern_args["maxnreg"] = 168 2409 | else: 2410 | # NOTE: For backward pass with HEAD_DIM_K=128, this is probably too low for H100; register allocation fails. 2411 | extra_kern_args["maxnreg"] = 80 2412 | 2413 | if hasattr(triton, "set_allocator") and is_cuda(): 2414 | 2415 | def alloc_fn(size: int, align: int, _): 2416 | """Custom allocator function for Triton.""" 2417 | return torch.empty(size, dtype=torch.int8, device="cuda") 2418 | 2419 | triton.set_allocator(alloc_fn) 2420 | 2421 | def strides_zhnd(t: Tensor) -> JVPAttn.Strides: 2422 | """Get strides for a tensor with shape (Z, H, N_CTX, HEAD_DIM).""" 2423 | return JVPAttn.Strides(t.stride(0), t.stride(1), t.stride(2), t.stride(3)) 2424 | 2425 | # Determine mask type 2426 | if attn_mask is None: 2427 | MASK_TYPE = 0 2428 | mask_tensor = torch.empty(0, device=q.device, dtype=q.dtype) 2429 | mask_strides = (0, 0, 0, 0) 2430 | elif attn_mask.dtype == torch.bool: 2431 | MASK_TYPE = 1 2432 | mask_tensor = attn_mask.contiguous() 2433 | mask_strides = strides_zhnd(mask_tensor) 2434 | if verify_attn_mask: 2435 | # Check if any head is all False 2436 | assert mask_tensor.any( 2437 | dim=(-1, -2) 2438 | ).all(), "The attention mask cannot be all False for any head." 2439 | else: 2440 | MASK_TYPE = 2 2441 | mask_tensor = attn_mask.to(q.dtype).contiguous() 2442 | mask_strides = strides_zhnd(mask_tensor) 2443 | if verify_attn_mask: 2444 | # Check if the mask contains -inf/inf/NaN or is all (or no) MASK_CONST for any head 2445 | assert not torch.isinf( 2446 | mask_tensor 2447 | ).any(), "The attention mask cannot contain -inf or inf." 2448 | assert not torch.isnan( 2449 | mask_tensor 2450 | ).any(), "The attention mask cannot contain NaNs." 
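            # (For reference, a valid additive mask is ordinary float data with MASK_CONST marking
            # fully masked positions, e.g. a hypothetical "block the last 16 keys" mask:
            #     attn_mask = torch.zeros(Z, H, N_CTX, N_CTX, dtype=q.dtype, device=q.device)
            #     attn_mask[..., -16:] = MASK_CONST
            # The check just below rejects masks in which every entry of some head equals
            # MASK_CONST, since such a head would have no attendable keys.)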
2451 | assert ( 2452 | (mask_tensor != MASK_CONST).any(dim=(-1, -2)).all() 2453 | ), f"The attention mask cannot be all {MASK_CONST} (the masking constant) for any head." 2454 | 2455 | if not (mask_tensor == MASK_CONST).any(): 2456 | raise UserWarning( 2457 | f"The provided floating-point attention mask does not mask out any elements with {MASK_CONST} (the masking constant). Consider using this constant for correct masking behavior." 2458 | ) 2459 | 2460 | # Prepare dropout arguments 2461 | ENABLE_DROPOUT = dropout_p > 0.0 2462 | if ENABLE_DROPOUT: 2463 | philox_seed = torch.randint(0, 2**32, (1,), device=q.device, dtype=torch.int64).item() 2464 | else: 2465 | philox_seed = 0 2466 | 2467 | # Set up grid for kernel launch 2468 | Z_H = Z * H 2469 | 2470 | def grid(META: dict[str, Any]) -> JVPAttn.Grid: 2471 | """Determine grid configuration.""" 2472 | return JVPAttn.Grid(triton.cdiv(N_CTX, META["BLOCK_M"]), Z_H, 1) 2473 | 2474 | if USE_TMA and supports_tma(): 2475 | # NOTE: On Hopper, we cannot perform a FP8 dot with a non-transposed second tensor. 2476 | y_dim = Z_H * N_CTX 2477 | tma_block_shape = [MIN_SEQUENCE_LENGTH, HEAD_DIM_K] 2478 | 2479 | desc_q = TensorDescriptor( 2480 | q, 2481 | shape=[y_dim, HEAD_DIM_K], 2482 | strides=[HEAD_DIM_K, 1], 2483 | block_shape=tma_block_shape, 2484 | ) 2485 | desc_q_t = ( 2486 | desc_q 2487 | if q_t is None 2488 | else TensorDescriptor( 2489 | q_t, 2490 | shape=[y_dim, HEAD_DIM_K], 2491 | strides=[HEAD_DIM_K, 1], 2492 | block_shape=tma_block_shape, 2493 | ) 2494 | ) 2495 | 2496 | if q.dtype == torch.float8_e5m2: 2497 | v_shape = [HEAD_DIM_K, y_dim] 2498 | v_strides = [N_CTX, 1] 2499 | else: 2500 | v_shape = [y_dim, HEAD_DIM_K] 2501 | v_strides = [HEAD_DIM_K, 1] 2502 | desc_v = TensorDescriptor( 2503 | v, shape=v_shape, strides=v_strides, block_shape=tma_block_shape 2504 | ) 2505 | # NOTE: Probably we could share the shape and strides from above, but whatever 2506 | if q_t is not None and q_t.dtype == torch.float8_e5m2: 2507 | t_v_shape = [HEAD_DIM_K, y_dim] 2508 | t_v_strides = [q_t.shape[2], 1] 2509 | else: 2510 | t_v_shape = [y_dim, HEAD_DIM_K] 2511 | t_v_strides = [HEAD_DIM_K, 1] 2512 | desc_v_t = ( 2513 | desc_v 2514 | if v_t is None 2515 | else TensorDescriptor( 2516 | v_t, shape=t_v_shape, strides=t_v_strides, block_shape=tma_block_shape 2517 | ) 2518 | ) 2519 | 2520 | desc_k = TensorDescriptor( 2521 | k, 2522 | shape=[y_dim, HEAD_DIM_K], 2523 | strides=[HEAD_DIM_K, 1], 2524 | block_shape=tma_block_shape, 2525 | ) 2526 | desc_k_t = ( 2527 | desc_k 2528 | if k_t is None 2529 | else TensorDescriptor( 2530 | k_t, 2531 | shape=[y_dim, HEAD_DIM_K], 2532 | strides=[HEAD_DIM_K, 1], 2533 | block_shape=tma_block_shape, 2534 | ) 2535 | ) 2536 | 2537 | desc_o = TensorDescriptor( 2538 | o, 2539 | shape=[y_dim, HEAD_DIM_K], 2540 | strides=[HEAD_DIM_K, 1], 2541 | block_shape=tma_block_shape, 2542 | ) 2543 | desc_o_t = ( 2544 | desc_o 2545 | if o_t is None 2546 | else TensorDescriptor( 2547 | o_t, 2548 | shape=[y_dim, HEAD_DIM_K], 2549 | strides=[HEAD_DIM_K, 1], 2550 | block_shape=tma_block_shape, 2551 | ) 2552 | ) 2553 | 2554 | _attn_fwd_tma[grid]( 2555 | sm_scale, 2556 | M, # 2557 | Z, 2558 | H, # 2559 | desc_q, 2560 | desc_k, 2561 | desc_v, # 2562 | desc_q_t, 2563 | desc_k_t, 2564 | desc_v_t, # 2565 | desc_o, 2566 | desc_o_t, # 2567 | mask_tensor, # 2568 | dropout_p, # 2569 | philox_seed, # 2570 | *mask_strides, # 2571 | N_CTX=N_CTX, # 2572 | HEAD_DIM=HEAD_DIM_K, # 2573 | FP8_OUTPUT=q.dtype == torch.float8_e5m2, # 2574 | STAGE=STAGE, # 2575 | 
warp_specialize=warp_specialize, # 2576 | ENABLE_JVP=ENABLE_JVP, # 2577 | ENABLE_DROPOUT=ENABLE_DROPOUT, 2578 | MASK_TYPE=MASK_TYPE, 2579 | # NOTE: The following are safe (unit-tested) default values 2580 | BLOCK_M=MIN_SEQUENCE_LENGTH, # 2581 | BLOCK_N=MIN_SEQUENCE_LENGTH, # 2582 | num_stages=NUM_STAGES_OPTIONS[0], # 2583 | num_warps=4, # 2584 | **extra_kern_args, 2585 | ) 2586 | 2587 | else: 2588 | _attn_fwd[grid]( 2589 | q, 2590 | k, 2591 | v, 2592 | q_t, 2593 | k_t, 2594 | v_t, # 2595 | sm_scale, 2596 | M, 2597 | o, 2598 | o_t, # 2599 | mask_tensor, # 2600 | dropout_p, # 2601 | philox_seed, # 2602 | *strides_zhnd(q), # 2603 | *strides_zhnd(k), # 2604 | *strides_zhnd(v), # 2605 | *strides_zhnd(q if q_t is None else q_t), # 2606 | *strides_zhnd(k if k_t is None else k_t), # 2607 | *strides_zhnd(v if v_t is None else v_t), # 2608 | *strides_zhnd(o), # 2609 | *strides_zhnd(o if o_t is None else o_t), # 2610 | *mask_strides, # 2611 | Z, 2612 | H, # 2613 | N_CTX=N_CTX, # 2614 | HEAD_DIM=HEAD_DIM_K, # 2615 | FP8_OUTPUT=q.dtype == torch.float8_e5m2, # 2616 | STAGE=STAGE, # 2617 | warp_specialize=warp_specialize, # 2618 | ENABLE_JVP=ENABLE_JVP, # 2619 | ENABLE_DROPOUT=ENABLE_DROPOUT, 2620 | MASK_TYPE=MASK_TYPE, 2621 | # NOTE: The following are safe (unit-tested) default values 2622 | BLOCK_M=MIN_SEQUENCE_LENGTH, # 2623 | BLOCK_N=MIN_SEQUENCE_LENGTH, # 2624 | num_stages=NUM_STAGES_OPTIONS[0], # 2625 | num_warps=4, # 2626 | **extra_kern_args, 2627 | ) 2628 | 2629 | return JVPAttn.FwdOut( 2630 | o, 2631 | JVPAttn.FwdOutCtxContrib( 2632 | o_t, 2633 | M, 2634 | grid, 2635 | HEAD_DIM_K, 2636 | sm_scale, 2637 | mask_tensor, 2638 | MASK_TYPE, 2639 | dropout_p, 2640 | philox_seed, 2641 | ENABLE_DROPOUT, 2642 | ), 2643 | ) 2644 | 2645 | @staticmethod 2646 | def setup_context(ctx: JVPAttn.FnCtx, inputs, outputs: JVPAttn.FwdOut) -> Tensor: 2647 | """Set up the context for JVP Attention. 2648 | 2649 | Args: 2650 | ctx: The context to set up 2651 | inputs: The input tensors 2652 | outputs: The output tensors 2653 | """ 2654 | ( 2655 | q, 2656 | k, 2657 | v, 2658 | q_t, 2659 | k_t, 2660 | v_t, 2661 | attn_mask, 2662 | dropout_p, 2663 | causal, 2664 | sm_scale, 2665 | warp_specialize, 2666 | USE_TMA, 2667 | verify_attn_mask, 2668 | ) = inputs 2669 | 2670 | o, ( 2671 | o_t, 2672 | M, 2673 | grid, 2674 | HEAD_DIM_K, 2675 | sm_scale, 2676 | mask_tensor, 2677 | MASK_TYPE, 2678 | dropout_p, 2679 | philox_seed, 2680 | ENABLE_DROPOUT, 2681 | ) = outputs 2682 | 2683 | ctx.grid = grid 2684 | ctx.save_for_forward(o_t) 2685 | ctx.save_for_backward(q, k, v, o, M) 2686 | 2687 | ctx.sm_scale = sm_scale 2688 | ctx.HEAD_DIM_K = HEAD_DIM_K 2689 | ctx.causal = causal 2690 | ctx.mask_tensor = mask_tensor 2691 | ctx.MASK_TYPE = MASK_TYPE 2692 | ctx.dropout_p = dropout_p 2693 | ctx.philox_seed = philox_seed 2694 | ctx.ENABLE_DROPOUT = ENABLE_DROPOUT 2695 | 2696 | @staticmethod 2697 | def fwd( 2698 | q: Tensor, 2699 | k: Tensor, 2700 | v: Tensor, 2701 | attn_mask: Tensor | None = None, 2702 | dropout_p: float = 0.0, 2703 | causal: bool = False, 2704 | sm_scale: float | None = None, 2705 | warp_specialize: bool = True, 2706 | USE_TMA: bool = True, 2707 | ) -> Tensor: 2708 | """Forward pass for JVP Attention. 2709 | 2710 | NOTE: This is not an autograd convention. It's a workaround to get type-hinting and kwarg support. 2711 | 2712 | NOTE: Calls to `contiguous()` are necessary to ensure the inputs are contiguous in memory 2713 | (e.g., due to an `unbind` call to create `q`, `k`, `v`) but nonetheless may incur a performance cost. 
2714 | 2715 | Args: 2716 | q: Query tensor of shape (Z, H, N_CTX, HEAD_DIM_Q). 2717 | k: Key tensor of shape (Z, H, N_CTX, HEAD_DIM_K). 2718 | v: Value tensor of shape (Z, H, N_CTX, HEAD_DIM_V). 2719 | attn_mask: Optional attention mask of shape (Z, H, N_CTX, N_CTX). Two types of masks are supported. A boolean mask where a value of True indicates that the element should take part in attention, or a float mask of the same type as query, key, value that is added to the attention score. 2720 | dropout_p: Dropout probability. 2721 | causal: Whether to use causal attention. 2722 | sm_scale: The softmax scale factor. 2723 | warp_specialize: Whether to use warp specialization. 2724 | USE_TMA: Whether to use TMA. 2725 | 2726 | Returns: 2727 | The output tensor. 2728 | """ 2729 | if not (q.is_contiguous() and k.is_contiguous() and v.is_contiguous()): 2730 | q, k, v = q.contiguous(), k.contiguous(), v.contiguous() 2731 | 2732 | out: JVPAttn.FwdOut = JVPAttn.apply( 2733 | q, 2734 | k, 2735 | v, 2736 | None, 2737 | None, 2738 | None, 2739 | attn_mask, 2740 | dropout_p, 2741 | causal, 2742 | sm_scale, 2743 | warp_specialize, 2744 | USE_TMA, 2745 | ) 2746 | 2747 | a, _ = out 2748 | return a 2749 | 2750 | @staticmethod 2751 | def fwd_dual( 2752 | q: Tensor, 2753 | k: Tensor, 2754 | v: Tensor, 2755 | attn_mask: Tensor | None = None, 2756 | dropout_p: float = 0.0, 2757 | causal: bool = False, 2758 | sm_scale: float | None = None, 2759 | warp_specialize: bool = True, 2760 | USE_TMA: bool = True, 2761 | ) -> Tensor: 2762 | """Forward pass for JVP Attention with dual tensor inputs. 2763 | 2764 | NOTE: This is not an autograd convention. It's a workaround to get type-hinting and kwarg support. 2765 | 2766 | NOTE: Calls to `contiguous()` are necessary to ensure the inputs are contiguous in memory 2767 | (e.g., due to an `unbind` call to create `q`, `k`, `v`) but nonetheless may incur a performance cost. 2768 | 2769 | Args: 2770 | q: Query tensor of shape (Z, H, N_CTX, HEAD_DIM_Q). 2771 | k: Key tensor of shape (Z, H, N_CTX, HEAD_DIM_K). 2772 | v: Value tensor of shape (Z, H, N_CTX, HEAD_DIM_V). 2773 | attn_mask: Optional attention mask of shape (Z, H, N_CTX, N_CTX). Two types of masks are supported. A boolean mask where a value of True indicates that the element should take part in attention, or a float mask of the same type as query, key, value that is added to the attention score. 2774 | dropout_p: Dropout probability. 2775 | causal: Whether to use causal attention. 2776 | sm_scale: The softmax scale factor. 2777 | warp_specialize: Whether to use warp specialization. 2778 | USE_TMA: Whether to use TMA. 2779 | 2780 | Returns: 2781 | The output tensor. 2782 | """ 2783 | if not (q.is_contiguous() and k.is_contiguous() and v.is_contiguous()): 2784 | q, k, v = q.contiguous(), k.contiguous(), v.contiguous() 2785 | 2786 | q_p, q_t = fwAD.unpack_dual(q) 2787 | k_p, k_t = fwAD.unpack_dual(k) 2788 | v_p, v_t = fwAD.unpack_dual(v) 2789 | 2790 | # NOTE: We pass some dualtensor args to ensure jvp() will be called, 2791 | # but we also pass tangents separately, as forward() demotes dual 2792 | # tensor args to primals for some reason. 
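        # A minimal caller-side usage sketch (assuming CUDA tensors q_p/k_p/v_p with matching
        # tangents q_tan/k_tan/v_tan; names here are illustrative only):
        #
        #     with fwAD.dual_level():
        #         q = fwAD.make_dual(q_p, q_tan)
        #         k = fwAD.make_dual(k_p, k_tan)
        #         v = fwAD.make_dual(v_p, v_tan)
        #         out = JVPAttn.fwd_dual(q, k, v, causal=True)
        #         y, y_tan = fwAD.unpack_dual(out)  # primal output and its JVP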
2793 |         out: JVPAttn.FwdOut = JVPAttn.apply(
2794 |             q,
2795 |             k,
2796 |             v,
2797 |             q_t,
2798 |             k_t,
2799 |             v_t,
2800 |             attn_mask,
2801 |             dropout_p,
2802 |             causal,
2803 |             sm_scale,
2804 |             warp_specialize,
2805 |             USE_TMA,
2806 |         )
2807 |
2808 |         a, _ = out
2809 |         return a
2810 |
2811 |     @staticmethod
2812 |     def jvp(ctx: JVPAttn.FnCtx, gq: Tensor, gk: Tensor, gv: Tensor, *_) -> JVPAttn.JVPOut:
2813 |         """Compute the Jacobian-vector product (JVP) for JVP Attention.
2814 |
2815 |         Args:
2816 |             ctx: The context
2817 |             gq: The tangent of the query tensor
2818 |             gk: The tangent of the key tensor
2819 |             gv: The tangent of the value tensor
2820 |
2821 |         Returns:
2822 |             The JVP output (the output tangent o_t computed during the fused forward pass).
2823 |         """
2824 |         return JVPAttn.JVPOut(ctx.saved_for_forward[0], None)
2825 |
2826 |     @staticmethod
2827 |     def backward(ctx, do, _) -> JVPAttn.BwdOut:
2828 |         """Backward pass for JVP Attention.
2829 |
2830 |         NOTE: A call to `contiguous()` may be needed to ensure the incoming output gradient is
2831 |         contiguous in memory (autograd can deliver non-contiguous gradients), but it may incur a performance cost.
2832 |
2833 |         Args:
2834 |             ctx: The context
2835 |             do: The gradient of the output tensor
2836 |
2837 |         Returns:
2838 |             The backward output.
2839 |         """
2840 |         q, k, v, o, M = ctx.saved_tensors
2841 |
2842 |         # Ensure the inputs/outputs the kernel reads share the same (contiguous) layout
2843 |         if not (
2844 |             q.is_contiguous() and k.is_contiguous() and v.is_contiguous() and o.is_contiguous()
2845 |         ):
2846 |             raise ValueError(
2847 |                 "JVPAttn expected q, k, v, o to be contiguous; got "
2848 |                 f"q.is_contiguous()={q.is_contiguous()}, k.is_contiguous()={k.is_contiguous()}, "
2849 |                 f"v.is_contiguous()={v.is_contiguous()}, o.is_contiguous()={o.is_contiguous()}, "
2850 |                 f"do.is_contiguous()={do.is_contiguous()}"
2851 |             )
2852 |
2853 |         # NOTE: Autograd may deliver a non-contiguous output gradient; if so, normalize it.
2854 |         if not do.is_contiguous():
2855 |             do = do.contiguous()
2856 |
2857 |         # Ensure all inputs/outputs the kernel reads share the same layout
2858 |         assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride(), (
2859 |             "JVPAttn expected q, k, v, o, do to have the same layout; got "
2860 |             f"q.stride()={q.stride()}, k.stride()={k.stride()}, v.stride()={v.stride()}, "
2861 |             f"o.stride()={o.stride()}, do.stride()={do.stride()}"
2862 |         )
2863 |
2864 |         # Initialize tensors for the gradients
2865 |         dq = torch.empty_like(q)
2866 |         dk = torch.empty_like(k)
2867 |         dv = torch.empty_like(v)
2868 |         delta = torch.empty_like(M)
2869 |
2870 |         # Collect metadata
2871 |         Z, H, N_CTX = q.shape[:3]
2872 |
2873 |         BLK_SLICE_FACTOR = 2  # NOTE: This is a safe default value to reduce backward memory usage
2874 |         BLOCK_MIN = MIN_SEQUENCE_LENGTH  # NOTE: Adjust according to the minimum input sequence length
2875 |         BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = BLOCK_MIN, BLOCK_MIN, BLOCK_MIN, BLOCK_MIN
2876 |
2877 |         assert N_CTX % BLOCK_MIN == 0, f"N_CTX must be divisible by BLOCK_MIN={BLOCK_MIN}"
2878 |
2879 |         if not ctx.causal:
2880 |             assert (
2881 |                 BLOCK_M1 == BLOCK_M2 == BLOCK_N1 == BLOCK_N2
2882 |             ), "For non-causal attention, all block sizes must be equal."
2883 |
2884 |         # Scale k by sm_scale / ln(2) to account for the softmax scaling and the
2885 |         # change of base of the exponentiation (exp2).
2886 |         RCP_LN2 = 1.4426950408889634  # = 1.0 / ln(2)
2887 |         arg_k = k
2888 |         arg_k = arg_k * (ctx.sm_scale * RCP_LN2)
2889 |
2890 |         # Determine the mask strides based on the mask type
2891 |         if ctx.MASK_TYPE == 0:
2892 |             mask_strides = (0, 0, 0, 0)
2893 |         else:
2894 |             mask_strides = (
2895 |                 ctx.mask_tensor.stride(0),
2896 |                 ctx.mask_tensor.stride(1),
2897 |                 ctx.mask_tensor.stride(2),
2898 |                 ctx.mask_tensor.stride(3),
2899 |             )
2900 |
2901 |         # Set up the grid for the kernel launch
2902 |         Z_H = Z * H
2903 |
2904 |         # Preprocess the output's deltas
2905 |         pre_grid = (N_CTX // BLOCK_MIN, Z_H)
2906 |         _attn_bwd_preprocess[pre_grid](
2907 |             o,
2908 |             do,  #
2909 |             delta,  #
2910 |             N_CTX,  #
2911 |             BLOCK_M=BLOCK_MIN,
2912 |             HEAD_DIM=ctx.HEAD_DIM_K,  #
2913 |         )
2914 |
2915 |         # Launch the backward kernel, enabling extra pipelining stages on Hopper (compute capability 9.x) GPUs
2916 |         grid = (N_CTX // BLOCK_MIN, Z_H)
2917 |         bwd_kernel = _attn_bwd_causal if ctx.causal else _attn_bwd
2918 |         num_stages = (
2919 |             5
2920 |             if is_cuda() and torch.cuda.get_device_capability()[0] == 9
2921 |             else NUM_STAGES_OPTIONS[0]
2922 |         )
2923 |
2924 |         bwd_kernel[grid](
2925 |             q,
2926 |             arg_k,
2927 |             v,
2928 |             ctx.sm_scale,
2929 |             do,
2930 |             dq,
2931 |             dk,
2932 |             dv,  #
2933 |             M,
2934 |             delta,  #
2935 |             q.stride(0),
2936 |             q.stride(1),
2937 |             q.stride(2),
2938 |             q.stride(3),  #
2939 |             mask_strides[0],
2940 |             mask_strides[1],
2941 |             mask_strides[2],
2942 |             mask_strides[3],  #
2943 |             H,
2944 |             N_CTX,  #
2945 |             BLOCK_M1=BLOCK_M1,
2946 |             BLOCK_N1=BLOCK_N1,  #
2947 |             BLOCK_M2=BLOCK_M2,
2948 |             BLOCK_N2=BLOCK_N2,  #
2949 |             BLK_SLICE_FACTOR=BLK_SLICE_FACTOR,  #
2950 |             HEAD_DIM=ctx.HEAD_DIM_K,  #
2951 |             mask_ptr=ctx.mask_tensor,
2952 |             MASK_TYPE=ctx.MASK_TYPE,
2953 |             dropout_p=ctx.dropout_p,
2954 |             philox_seed=ctx.philox_seed,
2955 |             ENABLE_DROPOUT=ctx.ENABLE_DROPOUT,
2956 |             # NOTE: The following are safe (unit-tested) default values
2957 |             num_stages=num_stages,  #
2958 |             num_warps=4,  #
2959 |         )
2960 |
2961 |         return JVPAttn.BwdOut(
2962 |             dq, dk, dv, None, None, None, None, None, None, None, None, None, None
2963 |         )
2964 |
2965 |
2966 | attention = JVPAttn.fwd
2967 |
--------------------------------------------------------------------------------
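
Usage sketch: the snippet below shows how the exported `attention` wrapper and `JVPAttn.fwd_dual` might be driven under `torch.autograd.forward_ad` to obtain a forward-mode JVP of the attention output. It is a minimal sketch, not a file from the repository: the shapes, dtype, device, and the assumption that N_CTX satisfies the kernel's block-size constraints (e.g., divisibility by MIN_SEQUENCE_LENGTH) are illustrative only.

import torch
import torch.autograd.forward_ad as fwAD

from jvp_flash_attention.jvp_attention import JVPAttn, attention

# Illustrative shapes/dtype; assumed to satisfy the kernel's block-size constraints.
Z, H, N_CTX, HEAD_DIM = 2, 4, 128, 64
q, k, v = (
    torch.randn(Z, H, N_CTX, HEAD_DIM, device="cuda", dtype=torch.float32, requires_grad=True)
    for _ in range(3)
)
tq, tk, tv = torch.randn_like(q), torch.randn_like(k), torch.randn_like(v)

# Reverse mode: the plain wrapper behaves like a standard attention op.
out = attention(q, k, v, causal=True)
out.sum().backward()

# Forward mode: wrap primals and tangents as dual tensors and call fwd_dual;
# the tangent of the output is the JVP computed by the fused kernel.
with fwAD.dual_level():
    q_d = fwAD.make_dual(q.detach(), tq)
    k_d = fwAD.make_dual(k.detach(), tk)
    v_d = fwAD.make_dual(v.detach(), tv)
    out_dual = JVPAttn.fwd_dual(q_d, k_d, v_d, causal=True)
    o, o_t = fwAD.unpack_dual(out_dual)  # o_t is the JVP of the attention output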