├── tests
│   ├── __init__.py
│   └── test_jvp_attention.py
├── jvp_flash_attention
│   ├── __init__.py
│   └── jvp_attention.py
├── main.png
├── float32_mem_scaling.png
├── float32_time_scaling.png
├── CITATION.cff
├── .github
│   └── workflows
│       └── publish.yaml
├── LICENSE
├── pyproject.toml
├── .pre-commit-config.yaml
├── .gitignore
└── README.md
/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/jvp_flash_attention/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/main.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amorehead/jvp_flash_attention/HEAD/main.png
--------------------------------------------------------------------------------
/float32_mem_scaling.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amorehead/jvp_flash_attention/HEAD/float32_mem_scaling.png
--------------------------------------------------------------------------------
/float32_time_scaling.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amorehead/jvp_flash_attention/HEAD/float32_time_scaling.png
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | message: "If you want to cite the kernel, feel free to use this (but only if you loved it 😊)"
3 | title: "JVP Flash Attention"
4 | abstract: "A Flash Attention Triton kernel with support for second-order derivatives, such as Jacobian-Vector Products (JVPs) and Hessian-Vector Products (HVPs)."
5 | date-released: 2025-09-05
6 | authors:
7 | - family-names: "Morehead"
8 | given-names: "Alex"
9 | version: 0.10.0
10 | doi: 10.5281/zenodo.17050188
11 | license: "MIT"
12 | url: "https://zenodo.org/records/17050188"
13 | repository-code: "https://github.com/amorehead/jvp_flash_attention"
14 | keywords:
15 | - artificial intelligence
16 | - deep learning
17 |
--------------------------------------------------------------------------------
/.github/workflows/publish.yaml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Upload Python Package
10 |
11 | on:
12 | release:
13 | types: [published]
14 |
15 | jobs:
16 | deploy:
17 | runs-on: ubuntu-latest
18 |
19 | steps:
20 | - uses: actions/checkout@v4
21 | - name: Set up Python 3.10
22 | uses: actions/setup-python@v5
23 | with:
24 | python-version: "3.10"
25 | - name: Install dependencies
26 | run: |
27 | python -m pip install --upgrade pip
28 | pip install build
29 | - name: Build package
30 | run: python -m build
31 | - name: Publish package
32 | uses: pypa/gh-action-pypi-publish@v1.12.4
33 | with:
34 | user: __token__
35 | password: ${{ secrets.PYPI_API_TOKEN }}
36 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | JVP Flash Attention (jvp_flash_attention) Copyright (c) 2025, The Regents of the University of California, through Lawrence Berkeley National Laboratory (subject to receipt of any required approvals from the U.S. Dept. of Energy). All rights reserved.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights to use,
8 | copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
9 | Software, and to permit persons to whom the Software is furnished to do so,
10 | subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "jvp_flash_attention"
3 | version = "0.10.0"
4 | description = "Flash Attention Triton kernel with support for second-order derivatives"
5 | authors = [
6 | { name = "Alex Morehead", email = "alex.morehead@gmail.com" }
7 | ]
8 | readme = "README.md"
9 | requires-python = ">=3.10"
10 | license = "MIT"
11 | keywords = [
12 | 'artificial intelligence',
13 | 'deep learning',
14 | ]
15 |
16 | classifiers=[
17 | 'Development Status :: 4 - Beta',
18 | 'Intended Audience :: Developers',
19 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
20 | 'License :: OSI Approved :: MIT License',
21 | 'Programming Language :: Python :: 3.10',
22 | ]
23 |
24 | dependencies = [
25 | # --------- pytorch --------- #
26 | "torch>=2.8.0",
27 | "torchvision>=0.23",
28 |
29 | # --------- others --------- #
30 | # "numpy>=2.2.6",
31 | ]
32 |
33 | [project.urls]
34 | Homepage = "https://pypi.org/project/jvp_flash_attention/"
35 | Repository = "https://github.com/amorehead/jvp_flash_attention"
36 |
37 | [project.optional-dependencies]
38 | lint = ["pre-commit>=4.3.0"]
39 |
40 | [build-system]
41 | requires = ["hatchling"]
42 | build-backend = "hatchling.build"
43 |
44 | [tool.pytest.ini_options]
45 | pythonpath = [
46 | "."
47 | ]
48 | addopts = [
49 | "--color=yes",
50 | "--durations=0",
51 | "--strict-markers",
52 | "--doctest-modules",
53 | ]
54 | filterwarnings = [
55 | "ignore::DeprecationWarning",
56 | "ignore::UserWarning",
57 | ]
58 | log_cli = "True"
59 | markers = [
60 | "slow: slow tests",
61 | ]
62 | minversion = "6.0"
63 | testpaths = "tests/"
64 |
65 | # Assuming you're developing for Python 3.10
66 | target-version = "py310"
67 |
68 | [tool.hatch.metadata]
69 | allow-direct-references = true
70 |
71 | [tool.hatch.build.targets.wheel]
72 | packages = ["jvp_flash_attention"]
73 |
74 | [tool.coverage.report]
75 | exclude_lines = [
76 | "pragma: nocover",
77 | "raise NotImplementedError",
78 | "raise NotImplementedError()",
79 | "if __name__ == .__main__.:",
80 | ]
81 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | default_language_version:
2 | python: python3
3 |
4 | repos:
5 | - repo: https://github.com/pre-commit/pre-commit-hooks
6 | rev: v6.0.0
7 | hooks:
8 | # list of supported hooks: https://pre-commit.com/hooks.html
9 | - id: trailing-whitespace
10 | - id: end-of-file-fixer
11 | - id: check-docstring-first
12 | - id: check-yaml
13 | - id: debug-statements
14 | - id: detect-private-key
15 | - id: check-executables-have-shebangs
16 | - id: check-toml
17 | - id: check-case-conflict
18 | - id: check-added-large-files
19 | args: ["--maxkb=5000"]
20 |
21 | # python code formatting
22 | - repo: https://github.com/psf/black
23 | rev: 25.1.0
24 | hooks:
25 | - id: black
26 | args: [--line-length, "99"]
27 |
28 | # python import sorting
29 | - repo: https://github.com/PyCQA/isort
30 | rev: 6.0.1
31 | hooks:
32 | - id: isort
33 | args: ["--profile", "black", "--filter-files"]
34 |
35 | # python upgrading syntax to newer version
36 | - repo: https://github.com/asottile/pyupgrade
37 | rev: v3.20.0
38 | hooks:
39 | - id: pyupgrade
40 | args: [--py38-plus]
41 |
42 | # python docstring formatting
43 | # address runtime issue: https://github.com/PyCQA/docformatter/pull/287
44 | - repo: local
45 | hooks:
46 | - id: docformatter
47 | name: docformatter
48 | description: Formats docstrings to follow Google style.
49 | entry: python -Im docformatter
50 | additional_dependencies:
51 | - docformatter == 1.7.5
52 | args:
53 | [
54 | --in-place,
55 | --wrap-summaries=99,
56 | --wrap-descriptions=99,
57 | --style=google,
58 | --black,
59 | ]
60 | language: python
61 | types:
62 | - python
63 |
64 | # python docstring coverage checking
65 | - repo: https://github.com/econchick/interrogate
66 | rev: 1.7.0 # or master if you're bold
67 | hooks:
68 | - id: interrogate
69 | args:
70 | [
71 | --verbose,
72 | --fail-under=80,
73 | --ignore-init-module,
74 | --ignore-init-method,
75 | --ignore-module,
76 | --ignore-nested-functions,
77 | -vv,
78 | ]
79 |
80 | # python check (PEP8), programming errors and code complexity
81 | - repo: https://github.com/PyCQA/flake8
82 | rev: 7.3.0
83 | hooks:
84 | - id: flake8
85 | args:
86 | [
87 | "--extend-ignore",
88 | "E203,E402,E501,F401,F841,RST2,RST301",
89 | "--exclude",
90 | "logs/*,data/*",
91 | ]
92 | additional_dependencies: [flake8-rst-docstrings==0.3.0]
93 |
94 | # python security linter
95 | - repo: https://github.com/PyCQA/bandit
96 | rev: "1.8.6"
97 | hooks:
98 | - id: bandit
99 | args: ["-s", "B101"]
100 |
101 | # yaml formatting
102 | - repo: https://github.com/pre-commit/mirrors-prettier
103 | rev: v4.0.0-alpha.8
104 | hooks:
105 | - id: prettier
106 | types: [yaml]
107 | exclude: "environment.yaml"
108 |
109 | # shell scripts linter
110 | - repo: https://github.com/shellcheck-py/shellcheck-py
111 | rev: v0.11.0.1
112 | hooks:
113 | - id: shellcheck
114 |
115 | # md formatting
116 | - repo: https://github.com/executablebooks/mdformat
117 | rev: 0.7.22
118 | hooks:
119 | - id: mdformat
120 | args: ["--number"]
121 | additional_dependencies:
122 | - mdformat-gfm
123 | - mdformat-tables
124 | - mdformat_frontmatter
125 | # - mdformat-toc
126 | # - mdformat-black
127 |
128 | # word spelling linter
129 | - repo: https://github.com/codespell-project/codespell
130 | rev: v2.4.1
131 | hooks:
132 | - id: codespell
133 | args:
134 | - --skip=logs/**,data/**,*.ipynb
135 | # - --ignore-words-list=abc,def
136 |
137 | # jupyter notebook cell output clearing
138 | - repo: https://github.com/kynan/nbstripout
139 | rev: 0.8.1
140 | hooks:
141 | - id: nbstripout
142 |
143 | # jupyter notebook linting
144 | - repo: https://github.com/nbQA-dev/nbQA
145 | rev: 1.9.1
146 | hooks:
147 | - id: nbqa-black
148 | args: ["--line-length=99"]
149 | - id: nbqa-isort
150 | args: ["--profile=black"]
151 | - id: nbqa-flake8
152 | args:
153 | [
154 | "--extend-ignore=E203,E402,E501,F401,F841",
155 | "--exclude=logs/*,data/*",
156 | ]
157 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[codz]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py.cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # UV
98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | #uv.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 | #poetry.toml
110 |
111 | # pdm
112 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
113 | # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
114 | # https://pdm-project.org/en/latest/usage/project/#working-with-version-control
115 | #pdm.lock
116 | #pdm.toml
117 | .pdm-python
118 | .pdm-build/
119 |
120 | # pixi
121 | # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
122 | #pixi.lock
123 | # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
124 | # in the .venv directory. It is recommended not to include this directory in version control.
125 | .pixi
126 |
127 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
128 | __pypackages__/
129 |
130 | # Celery stuff
131 | celerybeat-schedule
132 | celerybeat.pid
133 |
134 | # SageMath parsed files
135 | *.sage.py
136 |
137 | # Environments
138 | .env
139 | .envrc
140 | .venv
141 | env/
142 | venv/
143 | ENV/
144 | env.bak/
145 | venv.bak/
146 |
147 | # Spyder project settings
148 | .spyderproject
149 | .spyproject
150 |
151 | # Rope project settings
152 | .ropeproject
153 |
154 | # mkdocs documentation
155 | /site
156 |
157 | # mypy
158 | .mypy_cache/
159 | .dmypy.json
160 | dmypy.json
161 |
162 | # Pyre type checker
163 | .pyre/
164 |
165 | # pytype static type analyzer
166 | .pytype/
167 |
168 | # Cython debug symbols
169 | cython_debug/
170 |
171 | # PyCharm
172 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
173 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
174 | # and can be added to the global gitignore or merged into this file. For a more nuclear
175 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
176 | #.idea/
177 |
178 | # Abstra
179 | # Abstra is an AI-powered process automation framework.
180 | # Ignore directories containing user credentials, local state, and settings.
181 | # Learn more at https://abstra.io/docs
182 | .abstra/
183 |
184 | # Visual Studio Code
185 | # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
186 | # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
187 | # and can be added to the global gitignore or merged into this file. However, if you prefer,
188 | # you could uncomment the following to ignore the entire vscode folder
189 | # .vscode/
190 |
191 | # Ruff stuff:
192 | .ruff_cache/
193 |
194 | # PyPI configuration file
195 | .pypirc
196 |
197 | # Cursor
198 | # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
199 | # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
200 | # refer to https://docs.cursor.com/context/ignore-files
201 | .cursorignore
202 | .cursorindexingignore
203 |
204 | # Marimo
205 | marimo/_static/
206 | marimo/_lsp/
207 | __marimo__/
208 |
209 | # Visual Studio Code
210 | .vscode/
211 |
212 | # Unit tests
213 | /tests/*.json
214 | /tests/*.png
215 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # JVP Flash Attention
4 |
5 |
6 | [](https://doi.org/10.5281/zenodo.17050188)
7 | [](https://badge.fury.io/py/jvp_flash_attention)
8 | [](https://www.repostatus.org/#active)
9 |

10 | [](https://opensource.org/licenses/MIT)
11 |
12 |
13 | 
14 |
15 |
16 | ## Description
17 |
18 | A Flash Attention Triton kernel with support for forward-mode differentiation (Jacobian-Vector Products, JVPs), which in turn enables second-order derivatives such as Hessian-Vector Products (HVPs).
19 |
20 | ## Installation
21 |
22 | Using `pip`, one can install `jvp_flash_attention` as follows.
23 |
24 | ```bash
25 | # Install package
26 | pip install jvp_flash_attention
27 |
28 | # [OPTIONAL, for development] Install the package with lint extras and set up pre-commit hooks
29 | pip install -e ".[lint]"
30 | pre-commit install
31 | ```
32 |
33 | ## Usage
34 |
35 | Once installed, one can use `jvp_flash_attention` in place of PyTorch's `scaled_dot_product_attention` as follows.
36 |
37 | ```python
38 | import torch.nn.functional as F
39 |
40 | from torch.nn.attention import SDPBackend, sdpa_kernel
41 | from jvp_flash_attention.jvp_attention import JVPAttn, attention as jvp_attention
42 |
43 | with sdpa_kernel(SDPBackend.MATH):
44 | # Regular (quadratic) attention
45 | x = F.scaled_dot_product_attention(
46 | q,
47 | k,
48 | v,
49 | attn_mask=attn_mask,
50 | dropout_p=attn_dropout_p if self.training else 0.0,
51 | )
52 |
53 | # JVP flash attention
54 | x = jvp_attention(
55 | q,
56 | k,
57 | v,
58 | attn_mask=attn_mask,
59 | # dropout_p=attn_dropout_p if self.training else 0.0, # NOTE: Attention dropout is currently unsupported
60 | )
61 | ```
62 |
63 | Anecdotally, one can also swap out `F.scaled_dot_product_attention` with `jvp_attention` **even for pretrained models** with minimal impact on numerical accuracy.
64 |
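For illustration, here is a minimal sketch of such a swap inside a typical self-attention block. The module, its projection layers, and the tensor shapes are hypothetical; only the `jvp_attention` call itself mirrors the usage shown above.

```python
import torch

from jvp_flash_attention.jvp_attention import attention as jvp_attention


class SelfAttention(torch.nn.Module):
    """Toy self-attention block whose attention call has been swapped to `jvp_attention`."""

    def __init__(self, dim: int = 768, heads: int = 12):
        super().__init__()
        self.heads, self.head_dim = heads, dim // heads
        self.to_qkv = torch.nn.Linear(dim, 3 * dim)
        self.to_out = torch.nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor | None = None) -> torch.Tensor:
        b, n, d = x.shape
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        # Reshape to (batch, heads, seq_len, head_dim), the layout both attention calls expect
        q, k, v = (t.reshape(b, n, self.heads, self.head_dim).transpose(1, 2) for t in (q, k, v))
        # x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)  # original call
        x = jvp_attention(q, k, v, attn_mask=attn_mask)  # drop-in replacement
        return self.to_out(x.transpose(1, 2).reshape(b, n, d))


# attn = SelfAttention().cuda()
# out = attn(torch.randn(2, 128, 768, device="cuda"))
```
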
65 | > Note: If calling `torch.func.jvp` manually in your model's forward pass, e.g.,
66 | > `pred, df = torch.func.jvp(lambda x_jvp: model(x_jvp), (x,), (gt,))`,
67 | > make sure to use JVP Flash Attention in your model as `model = lambda q, k, v: JVPAttn.fwd_dual(q, k, v)` rather than as `model = lambda q, k, v: jvp_attention(q, k, v)`, so that each input's tangent vectors are computed [prior](https://github.com/amorehead/jvp_flash_attention/issues/10) to running PyTorch's `autograd` engine. Models that rely on `torch.autograd.grad` to compute higher-order derivatives in their forward pass (e.g., energy-based models) should not require this change.
68 |
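For reference, here is a minimal sketch of this `torch.func.jvp` pattern, adapted from `tests/test_jvp_attention.py` (the tensor shapes and the CUDA device are illustrative assumptions):

```python
import torch

from jvp_flash_attention.jvp_attention import JVPAttn

bsz, heads, seq_len, head_dim = 2, 12, 256, 64
device = "cuda"  # the Triton kernel expects a GPU device

# Primal inputs and their tangent (direction) vectors
q, k, v = (torch.randn(bsz, heads, seq_len, head_dim, device=device) for _ in range(3))
q_t, k_t, v_t = (torch.randn_like(t) for t in (q, k, v))

# One call returns both the attention output and its directional (JVP) derivative
out, out_tangent = torch.func.jvp(
    lambda q_, k_, v_: JVPAttn.fwd_dual(q_, k_, v_, causal=False),
    (q, k, v),
    (q_t, k_t, v_t),
)
```
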
69 | Contributions or enhancements are welcome!
70 |
71 | ## Results
72 |
73 | ### Loss matching
74 |
75 | Model training with either `F.scaled_dot_product_attention` or `JVPAttn.fwd_dual` produces the same loss trajectory.
76 |
77 |
78 |
79 | ### Speed matching
80 |
81 | Model training with either `F.scaled_dot_product_attention` or `JVPAttn.fwd_dual` achieves the same iteration speed.
82 |
83 |
84 |
85 | > Note: The following results can be reproduced (for `float32` precision) by running `python tests/test_jvp_attention.py --dtype float32`.
86 |
87 | ### Time scaling
88 |
89 | `jvp_attention` scales substantially better in runtime than (`SDPBackend.MATH`-based) `F.scaled_dot_product_attention` when calculating second-order derivatives.
90 |
91 |
92 |
93 | 
94 |
95 |
96 |
97 | ### Memory scaling
98 |
99 | `jvp_attention` scales substantially better in memory usage than (`SDPBackend.MATH`-based) `F.scaled_dot_product_attention` when calculating second-order derivatives.
100 |
101 |
102 |
103 | 
104 |
105 |
106 |
107 | ## Tests
108 |
109 | If you want to run all the unit tests verifying the correctness of the JVP Flash Attention Triton kernel, run the following command(s).
110 |
111 | ```bash
112 | python tests/test_jvp_attention.py --dtype {float16,bfloat16,float32}
113 | ```
114 |
115 | In principle, the kernel should also support ROCm systems, though it has not yet been tested on them. macOS is currently unsupported, except via a CPU-only backend.
116 |
117 | Full results for `float16`:
118 |
119 | ```
120 | ==============================================================================================================
121 | BENCHMARK SUMMARY
122 | ==============================================================================================================
123 | Seq Len Causal Mask Method Time (ms) Mem (MB) TFLOP/s Max Error Grad Check
124 | --------------------------------------------------------------------------------------------------------------
125 | 32 False additive sdpa 0.821 3.09 0.0 TFLOP/s baseline N/A
126 | 32 False additive jvp_attn 0.723 1.08 0.0 TFLOP/s 1.83e+01 ✗
127 |
128 | 32 False boolean sdpa 0.961 3.14 0.0 TFLOP/s baseline N/A
129 | 32 False boolean jvp_attn 0.504 1.03 0.0 TFLOP/s 3.91e-03 ✓
130 |
131 | 32 False none sdpa 0.576 3.09 0.0 TFLOP/s baseline N/A
132 | 32 False none jvp_attn 0.447 1.03 0.0 TFLOP/s 1.95e-03 ✓
133 |
134 | 32 True none sdpa 0.934 3.10 0.0 TFLOP/s baseline N/A
135 | 32 True none jvp_attn 0.458 1.03 0.0 TFLOP/s 3.91e-03 ✓
136 |
137 | 64 False additive sdpa 0.860 6.75 0.0 TFLOP/s baseline N/A
138 | 64 False additive jvp_attn 0.847 2.26 0.1 TFLOP/s 2.23e+00 ✗
139 |
140 | 64 False boolean sdpa 0.908 6.94 0.0 TFLOP/s baseline N/A
141 | 64 False boolean jvp_attn 0.521 2.07 0.1 TFLOP/s 3.91e-03 ✓
142 |
143 | 64 False none sdpa 0.542 6.75 0.0 TFLOP/s baseline N/A
144 | 64 False none jvp_attn 0.414 2.07 0.1 TFLOP/s 1.95e-03 ✓
145 |
146 | 64 True none sdpa 0.888 6.77 0.0 TFLOP/s baseline N/A
147 | 64 True none jvp_attn 0.437 2.07 0.1 TFLOP/s 2.20e-03 ✓
148 |
149 | 128 False additive sdpa 0.834 16.51 0.1 TFLOP/s baseline N/A
150 | 128 False additive jvp_attn 0.750 4.89 0.3 TFLOP/s 3.91e-03 ✓
151 |
152 | 128 False boolean sdpa 0.840 17.26 0.1 TFLOP/s baseline N/A
153 | 128 False boolean jvp_attn 0.520 4.14 0.4 TFLOP/s 3.91e-03 ✓
154 |
155 | 128 False none sdpa 0.610 16.51 0.2 TFLOP/s baseline N/A
156 | 128 False none jvp_attn 0.459 4.14 0.4 TFLOP/s 9.77e-04 ✓
157 |
158 | 128 True none sdpa 1.053 16.57 0.0 TFLOP/s baseline N/A
159 | 128 True none jvp_attn 0.438 4.14 0.2 TFLOP/s 2.44e-03 ✓
160 |
161 | 256 False additive sdpa 0.829 47.77 0.5 TFLOP/s baseline N/A
162 | 256 False additive jvp_attn 0.738 12.02 1.1 TFLOP/s 3.91e-03 ✓
163 |
164 | 256 False boolean sdpa 0.872 50.77 0.5 TFLOP/s baseline N/A
165 | 256 False boolean jvp_attn 0.482 8.27 1.7 TFLOP/s 3.91e-03 ✓
166 |
167 | 256 False none sdpa 0.812 47.27 0.5 TFLOP/s baseline N/A
168 | 256 False none jvp_attn 0.460 8.27 1.8 TFLOP/s 9.77e-04 ✓
169 |
170 | 256 True none sdpa 0.964 47.52 0.2 TFLOP/s baseline N/A
171 | 256 True none jvp_attn 0.436 8.27 0.9 TFLOP/s 3.91e-03 ✓
172 |
173 | 512 False additive sdpa 1.416 153.55 1.2 TFLOP/s baseline N/A
174 | 512 False additive jvp_attn 0.715 30.55 4.6 TFLOP/s 1.95e-03 ✓
175 |
176 | 512 False boolean sdpa 1.441 165.05 1.1 TFLOP/s baseline N/A
177 | 512 False boolean jvp_attn 0.500 16.55 6.6 TFLOP/s 1.95e-03 ✓
178 |
179 | 512 False none sdpa 1.374 153.05 1.2 TFLOP/s baseline N/A
180 | 512 False none jvp_attn 0.407 16.55 8.1 TFLOP/s 4.88e-04 ✓
181 |
182 | 512 True none sdpa 1.402 154.05 0.6 TFLOP/s baseline N/A
183 | 512 True none jvp_attn 0.460 16.55 3.6 TFLOP/s 2.93e-03 ✓
184 |
185 | 1024 False additive sdpa 4.963 546.84 1.3 TFLOP/s baseline N/A
186 | 1024 False additive jvp_attn 1.183 96.84 11.1 TFLOP/s 1.95e-03 ✓
187 |
188 | 1024 False boolean sdpa 4.991 594.84 1.3 TFLOP/s baseline N/A
189 | 1024 False boolean jvp_attn 0.622 33.84 21.1 TFLOP/s 1.95e-03 ✓
190 |
191 | 1024 False none sdpa 4.227 546.84 1.6 TFLOP/s baseline N/A
192 | 1024 False none jvp_attn 0.420 33.84 31.3 TFLOP/s 4.88e-04 ✓
193 |
194 | 1024 True none sdpa 4.861 550.84 0.7 TFLOP/s baseline N/A
195 | 1024 True none jvp_attn 0.469 33.84 14.0 TFLOP/s 3.91e-03 ✓
196 |
197 | 2048 False additive sdpa 18.773 2052.19 1.4 TFLOP/s baseline N/A
198 | 2048 False additive jvp_attn 3.379 336.19 15.6 TFLOP/s 1.95e-03 ✓
199 |
200 | 2048 False boolean sdpa 18.815 2244.19 1.4 TFLOP/s baseline N/A
201 | 2048 False boolean jvp_attn 1.674 66.19 31.4 TFLOP/s 1.95e-03 ✓
202 |
203 | 2048 False none sdpa 16.156 2052.19 1.6 TFLOP/s baseline N/A
204 | 2048 False none jvp_attn 1.186 66.19 44.3 TFLOP/s 4.88e-04 ✓
205 |
206 | 2048 True none sdpa 18.587 2068.19 0.7 TFLOP/s baseline N/A
207 | 2048 True none jvp_attn 0.720 66.19 36.5 TFLOP/s 1.95e-03 ✓
208 |
209 |
210 | ================================================================================
211 | MASK TYPE PERFORMANCE COMPARISON
212 | ================================================================================
213 | Seq Len Causal Method No Mask Boolean Mask Additive Mask
214 | --------------------------------------------------------------------------------
215 | 32 False jvp_attn 0.45 ms 0.50 ms (1.13x) 0.72 ms (1.62x)
216 | 32 True jvp_attn 0.46 ms N/A N/A
217 | 64 False jvp_attn 0.41 ms 0.52 ms (1.26x) 0.85 ms (2.05x)
218 | 64 True jvp_attn 0.44 ms N/A N/A
219 | 128 False jvp_attn 0.46 ms 0.52 ms (1.13x) 0.75 ms (1.63x)
220 | 128 True jvp_attn 0.44 ms N/A N/A
221 | 256 False jvp_attn 0.46 ms 0.48 ms (1.05x) 0.74 ms (1.60x)
222 | 256 True jvp_attn 0.44 ms N/A N/A
223 | 512 False jvp_attn 0.41 ms 0.50 ms (1.23x) 0.72 ms (1.76x)
224 | 512 True jvp_attn 0.46 ms N/A N/A
225 | 1024 False jvp_attn 0.42 ms 0.62 ms (1.48x) 1.18 ms (2.82x)
226 | 1024 True jvp_attn 0.47 ms N/A N/A
227 | 2048 False jvp_attn 1.19 ms 1.67 ms (1.41x) 3.38 ms (2.85x)
228 | 2048 True jvp_attn 0.72 ms N/A N/A
229 |
230 | ============================================================
231 | STATISTICS
232 | ============================================================
233 | Average speedup: 4.50x
234 | Min speedup: 1.02x
235 | Max speedup: 25.82x
236 |
237 | Accuracy: 26/28 tests passed
238 | ⚠️ Some accuracy checks failed
239 |
240 | Failed configurations:
241 | - Seq=32, Causal=False, Mask=additive
242 | - Seq=64, Causal=False, Mask=additive
243 | ```
244 |
245 | Full results for `bfloat16`:
246 |
247 | ```
248 | ==============================================================================================================
249 | BENCHMARK SUMMARY
250 | ==============================================================================================================
251 | Seq Len Causal Mask Method Time (ms) Mem (MB) TFLOP/s Max Error Grad Check
252 | --------------------------------------------------------------------------------------------------------------
253 | 32 False additive sdpa 0.864 3.09 0.0 TFLOP/s baseline N/A
254 | 32 False additive jvp_attn 0.773 1.08 0.0 TFLOP/s 1.84e+01 ✗
255 |
256 | 32 False boolean sdpa 0.949 3.14 0.0 TFLOP/s baseline N/A
257 | 32 False boolean jvp_attn 0.569 1.03 0.0 TFLOP/s 3.12e-02 ✓
258 |
259 | 32 False none sdpa 0.662 3.09 0.0 TFLOP/s baseline N/A
260 | 32 False none jvp_attn 0.447 1.03 0.0 TFLOP/s 1.56e-02 ✓
261 |
262 | 32 True none sdpa 0.945 3.10 0.0 TFLOP/s baseline N/A
263 | 32 True none jvp_attn 0.469 1.03 0.0 TFLOP/s 3.12e-02 ✓
264 |
265 | 64 False additive sdpa 0.923 6.75 0.0 TFLOP/s baseline N/A
266 | 64 False additive jvp_attn 1.149 2.26 0.0 TFLOP/s 2.23e+00 ✗
267 |
268 | 64 False boolean sdpa 0.910 6.94 0.0 TFLOP/s baseline N/A
269 | 64 False boolean jvp_attn 0.518 2.07 0.1 TFLOP/s 3.12e-02 ✓
270 |
271 | 64 False none sdpa 0.554 6.75 0.0 TFLOP/s baseline N/A
272 | 64 False none jvp_attn 0.427 2.07 0.1 TFLOP/s 1.56e-02 ✓
273 |
274 | 64 True none sdpa 0.886 6.77 0.0 TFLOP/s baseline N/A
275 | 64 True none jvp_attn 0.458 2.07 0.1 TFLOP/s 3.12e-02 ✓
276 |
277 | 128 False additive sdpa 0.860 16.51 0.1 TFLOP/s baseline N/A
278 | 128 False additive jvp_attn 0.896 4.89 0.2 TFLOP/s 1.56e-02 ✓
279 |
280 | 128 False boolean sdpa 0.891 17.26 0.1 TFLOP/s baseline N/A
281 | 128 False boolean jvp_attn 0.771 4.14 0.3 TFLOP/s 1.56e-02 ✓
282 |
283 | 128 False none sdpa 0.578 16.51 0.2 TFLOP/s baseline N/A
284 | 128 False none jvp_attn 0.467 4.14 0.4 TFLOP/s 7.81e-03 ✓
285 |
286 | 128 True none sdpa 0.917 16.57 0.1 TFLOP/s baseline N/A
287 | 128 True none jvp_attn 0.447 4.14 0.2 TFLOP/s 3.12e-02 ✓
288 |
289 | 256 False additive sdpa 0.822 47.77 0.5 TFLOP/s baseline N/A
290 | 256 False additive jvp_attn 0.734 12.02 1.1 TFLOP/s 1.56e-02 ✓
291 |
292 | 256 False boolean sdpa 0.880 50.77 0.5 TFLOP/s baseline N/A
293 | 256 False boolean jvp_attn 0.532 8.27 1.5 TFLOP/s 1.56e-02 ✓
294 |
295 | 256 False none sdpa 0.597 47.27 0.7 TFLOP/s baseline N/A
296 | 256 False none jvp_attn 0.441 8.27 1.9 TFLOP/s 7.81e-03 ✓
297 |
298 | 256 True none sdpa 0.869 47.52 0.2 TFLOP/s baseline N/A
299 | 256 True none jvp_attn 0.469 8.27 0.9 TFLOP/s 1.56e-02 ✓
300 |
301 | 512 False additive sdpa 1.429 153.55 1.1 TFLOP/s baseline N/A
302 | 512 False additive jvp_attn 0.710 30.55 4.6 TFLOP/s 1.56e-02 ✓
303 |
304 | 512 False boolean sdpa 1.714 165.05 1.0 TFLOP/s baseline N/A
305 | 512 False boolean jvp_attn 0.552 16.55 5.9 TFLOP/s 1.56e-02 ✓
306 |
307 | 512 False none sdpa 1.314 153.05 1.2 TFLOP/s baseline N/A
308 | 512 False none jvp_attn 0.403 16.55 8.2 TFLOP/s 7.81e-03 ✓
309 |
310 | 512 True none sdpa 1.788 154.05 0.5 TFLOP/s baseline N/A
311 | 512 True none jvp_attn 0.432 16.55 3.8 TFLOP/s 3.12e-02 ✓
312 |
313 | 1024 False additive sdpa 5.720 546.84 1.1 TFLOP/s baseline N/A
314 | 1024 False additive jvp_attn 1.133 96.84 11.6 TFLOP/s 1.56e-02 ✓
315 |
316 | 1024 False boolean sdpa 5.376 594.84 1.2 TFLOP/s baseline N/A
317 | 1024 False boolean jvp_attn 0.634 33.84 20.7 TFLOP/s 1.56e-02 ✓
318 |
319 | 1024 False none sdpa 4.646 546.84 1.4 TFLOP/s baseline N/A
320 | 1024 False none jvp_attn 0.423 33.84 31.1 TFLOP/s 3.91e-03 ✓
321 |
322 | 1024 True none sdpa 5.566 550.84 0.6 TFLOP/s baseline N/A
323 | 1024 True none jvp_attn 0.466 33.84 14.1 TFLOP/s 1.56e-02 ✓
324 |
325 | 2048 False additive sdpa 21.231 2052.19 1.2 TFLOP/s baseline N/A
326 | 2048 False additive jvp_attn 3.735 336.19 14.1 TFLOP/s 1.56e-02 ✓
327 |
328 | 2048 False boolean sdpa 21.626 2244.19 1.2 TFLOP/s baseline N/A
329 | 2048 False boolean jvp_attn 1.926 66.19 27.3 TFLOP/s 1.56e-02 ✓
330 |
331 | 2048 False none sdpa 18.311 2052.19 1.4 TFLOP/s baseline N/A
332 | 2048 False none jvp_attn 1.139 66.19 46.1 TFLOP/s 3.91e-03 ✓
333 |
334 | 2048 True none sdpa 20.748 2068.19 0.6 TFLOP/s baseline N/A
335 | 2048 True none jvp_attn 0.750 66.19 35.0 TFLOP/s 3.12e-02 ✓
336 |
337 |
338 | ================================================================================
339 | MASK TYPE PERFORMANCE COMPARISON
340 | ================================================================================
341 | Seq Len Causal Method No Mask Boolean Mask Additive Mask
342 | --------------------------------------------------------------------------------
343 | 32 False jvp_attn 0.45 ms 0.57 ms (1.27x) 0.77 ms (1.73x)
344 | 32 True jvp_attn 0.47 ms N/A N/A
345 | 64 False jvp_attn 0.43 ms 0.52 ms (1.21x) 1.15 ms (2.69x)
346 | 64 True jvp_attn 0.46 ms N/A N/A
347 | 128 False jvp_attn 0.47 ms 0.77 ms (1.65x) 0.90 ms (1.92x)
348 | 128 True jvp_attn 0.45 ms N/A N/A
349 | 256 False jvp_attn 0.44 ms 0.53 ms (1.21x) 0.73 ms (1.66x)
350 | 256 True jvp_attn 0.47 ms N/A N/A
351 | 512 False jvp_attn 0.40 ms 0.55 ms (1.37x) 0.71 ms (1.76x)
352 | 512 True jvp_attn 0.43 ms N/A N/A
353 | 1024 False jvp_attn 0.42 ms 0.63 ms (1.50x) 1.13 ms (2.68x)
354 | 1024 True jvp_attn 0.47 ms N/A N/A
355 | 2048 False jvp_attn 1.14 ms 1.93 ms (1.69x) 3.74 ms (3.28x)
356 | 2048 True jvp_attn 0.75 ms N/A N/A
357 |
358 | ============================================================
359 | STATISTICS
360 | ============================================================
361 | Average speedup: 4.75x
362 | Min speedup: 0.80x
363 | Max speedup: 27.65x
364 |
365 | Accuracy: 26/28 tests passed
366 | ⚠️ Some accuracy checks failed
367 |
368 | Failed configurations:
369 | - Seq=32, Causal=False, Mask=additive
370 | - Seq=64, Causal=False, Mask=additive
371 | ```
372 |
373 | Full results for `float32`:
374 |
375 | ```
376 | ==============================================================================================================
377 | BENCHMARK SUMMARY
378 | ==============================================================================================================
379 | Seq Len Causal Mask Method Time (ms) Mem (MB) TFLOP/s Max Error Grad Check
380 | --------------------------------------------------------------------------------------------------------------
381 | 32 False additive sdpa 0.770 2.44 0.0 TFLOP/s baseline N/A
382 | 32 False additive jvp_attn 0.812 2.16 0.0 TFLOP/s 2.31e-02 ✓
383 |
384 | 32 False boolean sdpa 0.830 2.53 0.0 TFLOP/s baseline N/A
385 | 32 False boolean jvp_attn 0.575 2.07 0.0 TFLOP/s 9.16e-03 ✓
386 |
387 | 32 False none sdpa 0.491 2.44 0.0 TFLOP/s baseline N/A
388 | 32 False none jvp_attn 0.528 2.07 0.0 TFLOP/s 7.83e-03 ✓
389 |
390 | 32 True none sdpa 0.831 2.44 0.0 TFLOP/s baseline N/A
391 | 32 True none jvp_attn 0.457 2.07 0.0 TFLOP/s 8.60e-03 ✓
392 |
393 | 64 False additive sdpa 0.859 5.25 0.0 TFLOP/s baseline N/A
394 | 64 False additive jvp_attn 0.793 4.51 0.1 TFLOP/s 1.24e-02 ✓
395 |
396 | 64 False boolean sdpa 0.778 5.62 0.0 TFLOP/s baseline N/A
397 | 64 False boolean jvp_attn 0.522 4.13 0.1 TFLOP/s 1.23e-02 ✓
398 |
399 | 64 False none sdpa 0.501 5.25 0.1 TFLOP/s baseline N/A
400 | 64 False none jvp_attn 0.437 4.13 0.1 TFLOP/s 7.03e-03 ✓
401 |
402 | 64 True none sdpa 0.810 5.27 0.0 TFLOP/s baseline N/A
403 | 64 True none jvp_attn 0.450 4.13 0.1 TFLOP/s 1.05e-02 ✓
404 |
405 | 128 False additive sdpa 0.869 13.51 0.1 TFLOP/s baseline N/A
406 | 128 False additive jvp_attn 0.697 9.76 0.3 TFLOP/s 9.14e-03 ✓
407 |
408 | 128 False boolean sdpa 0.832 15.76 0.1 TFLOP/s baseline N/A
409 | 128 False boolean jvp_attn 0.527 8.26 0.4 TFLOP/s 8.91e-03 ✓
410 |
411 | 128 False none sdpa 0.458 14.26 0.2 TFLOP/s baseline N/A
412 | 128 False none jvp_attn 0.610 8.26 0.3 TFLOP/s 5.07e-03 ✓
413 |
414 | 128 True none sdpa 0.817 14.32 0.1 TFLOP/s baseline N/A
415 | 128 True none jvp_attn 0.478 8.26 0.2 TFLOP/s 1.05e-02 ✓
416 |
417 | 256 False additive sdpa 0.786 43.27 0.5 TFLOP/s baseline N/A
418 | 256 False additive jvp_attn 0.689 23.77 1.2 TFLOP/s 9.98e-03 ✓
419 |
420 | 256 False boolean sdpa 0.754 48.52 0.5 TFLOP/s baseline N/A
421 | 256 False boolean jvp_attn 0.514 17.02 1.6 TFLOP/s 9.93e-03 ✓
422 |
423 | 256 False none sdpa 0.596 43.27 0.7 TFLOP/s baseline N/A
424 | 256 False none jvp_attn 0.461 17.77 1.8 TFLOP/s 4.03e-03 ✓
425 |
426 | 256 True none sdpa 0.837 43.52 0.2 TFLOP/s baseline N/A
427 | 256 True none jvp_attn 0.424 17.77 1.0 TFLOP/s 9.61e-03 ✓
428 |
429 | 512 False additive sdpa 1.383 144.80 1.2 TFLOP/s baseline N/A
430 | 512 False additive jvp_attn 0.793 57.80 4.1 TFLOP/s 7.29e-03 ✓
431 |
432 | 512 False boolean sdpa 1.342 168.80 1.2 TFLOP/s baseline N/A
433 | 512 False boolean jvp_attn 0.792 33.80 4.1 TFLOP/s 7.22e-03 ✓
434 |
435 | 512 False none sdpa 1.479 144.80 1.1 TFLOP/s baseline N/A
436 | 512 False none jvp_attn 0.453 33.80 7.2 TFLOP/s 3.94e-03 ✓
437 |
438 | 512 True none sdpa 1.582 145.80 0.5 TFLOP/s baseline N/A
439 | 512 True none jvp_attn 0.501 33.80 3.3 TFLOP/s 6.56e-03 ✓
440 |
441 | 1024 False additive sdpa 5.448 528.09 1.2 TFLOP/s baseline N/A
442 | 1024 False additive jvp_attn 2.148 168.09 6.1 TFLOP/s 6.90e-03 ✓
443 |
444 | 1024 False boolean sdpa 5.131 624.09 1.3 TFLOP/s baseline N/A
445 | 1024 False boolean jvp_attn 1.139 66.09 11.5 TFLOP/s 6.85e-03 ✓
446 |
447 | 1024 False none sdpa 4.339 528.09 1.5 TFLOP/s baseline N/A
448 | 1024 False none jvp_attn 0.538 66.09 24.4 TFLOP/s 2.92e-03 ✓
449 |
450 | 1024 True none sdpa 5.262 532.09 0.6 TFLOP/s baseline N/A
451 | 1024 True none jvp_attn 0.437 66.09 15.0 TFLOP/s 8.65e-03 ✓
452 |
453 | 2048 False additive sdpa 19.849 2016.19 1.3 TFLOP/s baseline N/A
454 | 2048 False additive jvp_attn 6.468 576.19 8.1 TFLOP/s 7.22e-03 ✓
455 |
456 | 2048 False boolean sdpa 20.017 2400.19 1.3 TFLOP/s baseline N/A
457 | 2048 False boolean jvp_attn 3.247 132.19 16.2 TFLOP/s 7.16e-03 ✓
458 |
459 | 2048 False none sdpa 16.573 2016.19 1.6 TFLOP/s baseline N/A
460 | 2048 False none jvp_attn 1.883 132.19 27.9 TFLOP/s 2.61e-03 ✓
461 |
462 | 2048 True none sdpa 19.577 2032.19 0.7 TFLOP/s baseline N/A
463 | 2048 True none jvp_attn 1.114 132.19 23.6 TFLOP/s 7.71e-03 ✓
464 |
465 |
466 | ================================================================================
467 | MASK TYPE PERFORMANCE COMPARISON
468 | ================================================================================
469 | Seq Len Causal Method No Mask Boolean Mask Additive Mask
470 | --------------------------------------------------------------------------------
471 | 32 False jvp_attn 0.53 ms 0.57 ms (1.09x) 0.81 ms (1.54x)
472 | 32 True jvp_attn 0.46 ms N/A N/A
473 | 64 False jvp_attn 0.44 ms 0.52 ms (1.20x) 0.79 ms (1.81x)
474 | 64 True jvp_attn 0.45 ms N/A N/A
475 | 128 False jvp_attn 0.61 ms 0.53 ms (0.87x) 0.70 ms (1.14x)
476 | 128 True jvp_attn 0.48 ms N/A N/A
477 | 256 False jvp_attn 0.46 ms 0.51 ms (1.11x) 0.69 ms (1.49x)
478 | 256 True jvp_attn 0.42 ms N/A N/A
479 | 512 False jvp_attn 0.45 ms 0.79 ms (1.75x) 0.79 ms (1.75x)
480 | 512 True jvp_attn 0.50 ms N/A N/A
481 | 1024 False jvp_attn 0.54 ms 1.14 ms (2.12x) 2.15 ms (3.99x)
482 | 1024 True jvp_attn 0.44 ms N/A N/A
483 | 2048 False jvp_attn 1.88 ms 3.25 ms (1.72x) 6.47 ms (3.43x)
484 | 2048 True jvp_attn 1.11 ms N/A N/A
485 |
486 | ============================================================
487 | STATISTICS
488 | ============================================================
489 | Average speedup: 3.37x
490 | Min speedup: 0.75x
491 | Max speedup: 17.57x
492 |
493 | Accuracy: 28/28 tests passed
494 | ✓ All accuracy checks passed!
495 | ```
496 |
497 | Note: Based on these results, for all precision types it is recommended to provide a boolean `attn_mask` (rather than an additive one) to `jvp_attention()` where possible, as sketched below.
498 |
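For example, a boolean mask (where `True` means "attend" and `False` means "ignore") can be constructed and passed as follows; the shapes and masking probability are illustrative, mirroring `tests/test_jvp_attention.py`:

```python
import torch

from jvp_flash_attention.jvp_attention import attention as jvp_attention

bsz, heads, seq_len, head_dim = 2, 12, 512, 64
q, k, v = (torch.randn(bsz, heads, seq_len, head_dim, device="cuda") for _ in range(3))

# Keep roughly 10% of the attention weights; True means "attend", False means "ignore"
attn_mask = torch.rand(bsz, heads, seq_len, seq_len, device="cuda") > 0.9

out = jvp_attention(q, k, v, attn_mask=attn_mask)
```
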
499 | ## License
500 |
501 | This project is covered under the **MIT License**.
502 |
503 | ## Copyright
504 |
505 | JVP Flash Attention (jvp_flash_attention) Copyright (c) 2025, The Regents of the University of California, through Lawrence Berkeley National Laboratory (subject to receipt of any required approvals from the U.S. Dept. of Energy). All rights reserved.
506 |
507 | If you have questions about your rights to use or distribute this software,
508 | please contact Berkeley Lab's Intellectual Property Office at
509 | IPO@lbl.gov.
510 |
511 | **NOTICE.** This Software was developed under funding from the U.S. Department
512 | of Energy and the U.S. Government consequently retains certain rights. As
513 | such, the U.S. Government has been granted for itself and others acting on
514 | its behalf a paid-up, nonexclusive, irrevocable, worldwide license in the
515 | Software to reproduce, distribute copies to the public, prepare derivative
516 | works, and perform publicly and display publicly, and to permit others to do so.
517 |
518 | ## Citing this work
519 |
520 | If you use the code associated with this package or otherwise find this work useful, please use GitHub's `Cite this repository` feature or the BibTeX below.
521 |
522 | ```bibtex
523 | @software{Morehead_JVP_Flash_Attention_2025,
524 | author = {Morehead, Alex},
525 | doi = {10.5281/zenodo.17050188},
526 | license = {MIT},
527 | month = sep,
528 | title = {{JVP Flash Attention}},
529 | url = {https://github.com/amorehead/jvp_flash_attention},
530 | version = {0.10.0},
531 | year = {2025}
532 | }
533 | ```
534 |
535 | ## Acknowledgements
536 |
537 | `jvp_flash_attention` builds upon the contributions and insights from the following sources:
538 |
539 | - [flash-attention](https://github.com/Dao-AILab/flash-attention)
540 | - [JVP Triton kernel thread](https://github.com/Dao-AILab/flash-attention/issues/1672)
541 | - [benjamin-dinkelmann](https://gist.github.com/benjamin-dinkelmann)
542 | - *[Birch-san](https://github.com/Birch-san)*
543 | - [dabeschte](https://github.com/dabeschte)
544 | - [IsaacYQH](https://gist.github.com/IsaacYQH)
545 | - [KohakuBlueleaf](https://github.com/KohakuBlueleaf)
546 | - [leon](https://github.com/leon532)
547 | - [limsanky](https://github.com/limsanky)
548 | - [lucidrains](https://github.com/lucidrains)
549 | - [Peterande](https://gist.github.com/Peterande)
550 | - *[Ryu1845](https://github.com/Ryu1845)*
551 | - [tridao](https://github.com/tridao)
552 |
553 | Thank you to each and every contributor!
554 |
--------------------------------------------------------------------------------
/tests/test_jvp_attention.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import gc
4 | import os
5 | import random
6 | import time
7 | from argparse import ArgumentParser, Namespace
8 | from dataclasses import dataclass, field
9 | from functools import partial
10 | from typing import Any, Callable, NamedTuple
11 |
12 | import torch
13 | import torch.autograd.forward_ad as fwAD
14 | from torch import Tensor, enable_grad
15 | from torch.nn import MSELoss
16 | from torch.nn.attention import SDPBackend, sdpa_kernel
17 | from torch.nn.functional import scaled_dot_product_attention
18 |
19 | try:
20 | import matplotlib.pyplot as plt
21 | import numpy as np
22 |
23 | PLOTTING_AVAILABLE = True
24 | except ImportError:
25 | PLOTTING_AVAILABLE = False
26 |
27 | from jvp_flash_attention.jvp_attention import MASK_CONST, JVPAttn
28 |
29 |
30 | def mpi_to_flops(ms_per_iter: float, flop_count: int) -> float:
31 | """Convert milliseconds per iteration to FLOPS.
32 |
33 | Args:
34 | ms_per_iter: Milliseconds per iteration.
35 | flop_count: Number of floating point operations.
36 |
37 | Returns:
38 | The number of FLOPS.
39 | """
40 | iters_per_second = 1e3 / ms_per_iter
41 | return iters_per_second * flop_count
42 |
43 |
44 | def fmt_flops(flops: int) -> str:
45 | """Return a string representation of FLOPS in TFLOP/s."""
46 | return f"{flops / 1e12:5.1f} TFLOP/s"
47 |
48 |
49 | def get_attention_flop_count(
50 | batch_size: int,
51 | num_heads: int,
52 | seq_len: int,
53 | head_dim: int,
54 | is_causal: bool,
55 | is_jvp: bool = False,
56 | ) -> int:
57 | """Calculate FLOPs for attention operations.
58 |
59 | Args:
60 | batch_size: Batch size.
61 | num_heads: Number of attention heads.
62 | seq_len: Sequence length.
63 | head_dim: Dimension of each attention head.
64 | is_causal: Whether the attention is causal.
65 | is_jvp: Whether to include JVP (Jacobian-vector product) FLOPs.
66 |
67 | Returns:
68 | The total FLOPs for the attention operation.
69 | """
70 | # Base attention FLOPs
71 | qk_flops = 2 * batch_size * num_heads * seq_len * seq_len * head_dim
72 | softmax_flops = 5 * batch_size * num_heads * seq_len * seq_len
73 | av_flops = 2 * batch_size * num_heads * seq_len * seq_len * head_dim
74 |
75 | total_flops = qk_flops + softmax_flops + av_flops
76 |
77 | if is_causal:
78 | total_flops = total_flops // 2
79 |
80 | if is_jvp:
81 | total_flops = total_flops * 2
82 |
83 | return total_flops
84 |
85 |
86 | def measure_memory_usage(f: Callable[[], Any]) -> tuple[float, float]:
87 | """Measure GPU memory usage of a function.
88 |
89 | Args:
90 | f: The function to measure.
91 |
92 | Returns:
93 | Tuple of (allocated_mb, reserved_mb) memory in megabytes.
94 | """
95 | torch.cuda.synchronize()
96 | torch.cuda.reset_peak_memory_stats()
97 | gc.collect()
98 | torch.cuda.empty_cache()
99 |
100 | initial_allocated = torch.cuda.memory_allocated()
101 | initial_reserved = torch.cuda.memory_reserved()
102 |
103 | f()
104 |
105 | torch.cuda.synchronize()
106 |
107 | peak_allocated = torch.cuda.max_memory_allocated()
108 | peak_reserved = torch.cuda.max_memory_reserved()
109 |
110 | allocated_mb = (peak_allocated - initial_allocated) / (1024 * 1024)
111 | reserved_mb = (peak_reserved - initial_reserved) / (1024 * 1024)
112 |
113 | return allocated_mb, reserved_mb
114 |
115 |
116 | def benchmark_function(
117 | f: Callable[[], Any], warmup_iters: int = 10, benchmark_iters: int = 100
118 | ) -> float:
119 | """Benchmark a function's execution time.
120 |
121 | Args:
122 | f: The function to benchmark.
123 | warmup_iters: Number of warmup iterations.
124 | benchmark_iters: Number of benchmark iterations.
125 |
126 | Returns:
127 | Average time per iteration in milliseconds.
128 | """
129 | # Warmup
130 | for _ in range(warmup_iters):
131 | f()
132 |
133 | torch.cuda.synchronize()
134 | start_time = time.perf_counter()
135 |
136 | for _ in range(benchmark_iters):
137 | f()
138 |
139 | torch.cuda.synchronize()
140 | end_time = time.perf_counter()
141 |
142 | avg_time_ms = (end_time - start_time) * 1000 / benchmark_iters
143 | return avg_time_ms
144 |
145 |
146 | class QKV(NamedTuple):
147 | """Query, Key, Value tensors."""
148 |
149 | q: Tensor
150 | k: Tensor
151 | v: Tensor
152 |
153 |
154 | class UnpackedDualQKV(NamedTuple):
155 | """Unpacked dual Query, Key, Value tensors."""
156 |
157 | primal: QKV
158 | tangent: QKV
159 |
160 |
161 | @dataclass
162 | class AccuracyMetrics:
163 | """Accuracy metrics for numerical validation."""
164 |
165 | primal_error: float
166 | tangent_error: float
167 | loss_error: float
168 | q_grad_error: float
169 | k_grad_error: float
170 | v_grad_error: float
171 | tolerance: float = 5e-3
172 |
173 | @property
174 | def max_error(self) -> float:
175 | """Return the maximum error across all metrics."""
176 | return max(
177 | self.primal_error,
178 | self.tangent_error,
179 | self.loss_error,
180 | self.q_grad_error,
181 | self.k_grad_error,
182 | self.v_grad_error,
183 | )
184 |
185 | def is_accurate(self) -> bool:
186 | """Check if all errors are within tolerance."""
187 | return self.max_error < self.tolerance
188 |
189 |
190 | class BenchmarkResult(NamedTuple):
191 | """Results from a single benchmark run."""
192 |
193 | seq_len: int
194 | is_causal: bool
195 | method: str # 'sdpa' or 'jvp_attn'
196 | time_ms: float
197 | memory_allocated_mb: float
198 | memory_reserved_mb: float
199 | mask_type: str # 'none', 'boolean', or 'additive'
200 | flops: int | None = None
201 | accuracy: AccuracyMetrics | None = None
202 |
203 |
204 | @dataclass
205 | class Args:
206 | """Training arguments."""
207 |
208 | bsz: int
209 | model_dim: int
210 | head_dim: int
211 | seq_lengths: list[int] = field(default_factory=lambda: [32, 64, 128, 256, 512, 1024, 2048])
212 | warmup_iters: int = 10
213 | benchmark_iters: int = 100
214 | dtype: str = "float16"
215 | seed: int = 42
216 | test_masks: bool = True
217 | validate_gradients: bool = True
218 | benchmark_performance: bool = True
219 | mask_prob: float = 0.9 # Probability of masking out an attention weight
220 |
221 | @staticmethod
222 | def get_parser() -> ArgumentParser:
223 | """Get the argument parser for training."""
224 | parser = ArgumentParser()
225 | parser.add_argument("--bsz", default=2, type=int)
226 | parser.add_argument("--model-dim", default=768, type=int)
227 | parser.add_argument("--head-dim", default=64, type=int)
228 | parser.add_argument(
229 | "--seq-lengths", nargs="+", type=int, default=[32, 64, 128, 256, 512, 1024, 2048]
230 | )
231 | parser.add_argument("--warmup-iters", default=10, type=int)
232 | parser.add_argument("--benchmark-iters", default=100, type=int)
233 | parser.add_argument(
234 | "--dtype", default="float16", choices=["float16", "float32", "bfloat16"]
235 | )
236 | parser.add_argument("--seed", default=42, type=int)
237 | parser.add_argument(
238 | "--no-test-masks",
239 | action="store_true",
240 | help="Skip testing with attention masks",
241 | )
242 | parser.add_argument(
243 | "--no-validate-gradients",
244 | action="store_true",
245 | help="Skip gradient validation",
246 | )
247 | parser.add_argument(
248 | "--no-benchmark-performance",
249 | action="store_true",
250 | help="Skip performance benchmarking",
251 | )
252 | parser.add_argument(
253 | "--mask-prob",
254 | default=0.9,
255 | type=float,
256 | help="Probability of masking out attention weights",
257 | )
258 | return parser
259 |
260 | @staticmethod
261 | def from_namespace(namespace: Namespace) -> Args:
262 | """Create Args from a namespace."""
263 | kwargs = vars(namespace)
264 |
265 | test_masks = not kwargs.pop("no_test_masks", False)
266 | validate_gradients = not kwargs.pop("no_validate_gradients", False)
267 | benchmark_performance = not kwargs.pop("no_benchmark_performance", False)
268 |
269 | kwargs["test_masks"] = test_masks
270 | kwargs["validate_gradients"] = validate_gradients
271 | kwargs["benchmark_performance"] = benchmark_performance
272 |
273 | return Args(**kwargs)
274 |
275 |
276 | def create_test_tensors(
277 | args: Args, seq_len: int, device: torch.device, dtype: torch.dtype
278 | ) -> tuple[Tensor, ...]:
279 | """Create test tensors for benchmarking.
280 |
281 | Args:
282 | args: The training arguments.
283 | seq_len: The sequence length.
284 | device: The device to create the tensors on.
285 | dtype: The data type of the tensors.
286 |
287 | Returns:
288 | Tuple of (q_p, q_t, k_p, k_t, v_p, v_t, target) tensors.
289 | """
290 | gen = torch.Generator(device=device).manual_seed(args.seed)
291 | heads = args.model_dim // args.head_dim
292 |
293 | tensors = tuple(
294 | torch.randn(
295 | args.bsz,
296 | heads,
297 | seq_len,
298 | args.head_dim,
299 | device=device,
300 | dtype=dtype,
301 | generator=gen,
302 | )
303 | for _ in range(7)
304 | )
305 |
306 | return tensors
307 |
308 |
309 | def create_attention_mask(
310 | args: Args,
311 | seq_len: int,
312 | device: torch.device,
313 | dtype: torch.dtype,
314 | mask_type: str,
315 | ) -> Tensor | None:
316 | """Create an attention mask for testing.
317 |
318 | Args:
319 | args: The training arguments.
320 | seq_len: The sequence length.
321 | device: The device to create the mask on.
322 | dtype: The data type of the mask.
323 | mask_type: Type of mask ('none', 'boolean', or 'additive').
324 |
325 | Returns:
326 | The attention mask tensor or None if mask_type is 'none'.
327 | """
328 | if mask_type == "none":
329 | return None
330 |
331 | gen = torch.Generator(device=device).manual_seed(args.seed + 1000) # Different seed for masks
332 | heads = args.model_dim // args.head_dim
333 |
334 | if mask_type == "boolean":
335 | # Create a boolean mask where True means "attend" and False means "ignore"
336 | # We'll create a random mask with some positions masked out
337 | mask = (
338 | torch.rand(args.bsz, heads, seq_len, seq_len, device=device, generator=gen)
339 | > args.mask_prob
340 | )
341 | # mask[0, :-1, :, :2] = (
342 | # True # Ensure first two columns of the first batch element (except for its last head) are True
343 | # )
344 | # mask[1, :-1, :, -2:] = (
345 | # True # Ensure last two columns of the second batch element (except for its last head) are True
346 | # )
347 |
348 | # Find completely masked heads
349 | fully_masked = ~mask.view(args.bsz, heads, -1).any(dim=2)
350 |
351 | # For each fully masked head, unmask some random positions
352 | if fully_masked.any():
353 | print(" ⚠️ Some heads were fully masked; unmasking some positions to avoid this.")
354 | for b in range(args.bsz):
355 | for h in range(heads):
356 | if fully_masked[b, h]:
357 | num_to_unmask = max(1, seq_len * seq_len // 10)
358 | indices = torch.randperm(seq_len * seq_len, device=device, generator=gen)[
359 | :num_to_unmask
360 | ]
361 | mask[b, h].view(-1)[indices] = True
362 |
363 | return mask
364 |
365 | elif mask_type == "additive":
366 | # Create an additive mask with values to be added to attention scores
367 | # Use -inf (MASK_CONST) for positions to ignore, 0 for positions to attend
368 | rand_mask = torch.rand(args.bsz, heads, seq_len, seq_len, device=device, generator=gen)
369 | mask = torch.where(rand_mask > args.mask_prob, 0.0, MASK_CONST)
370 | # Convert to the target dtype
371 | mask = mask.to(dtype)
372 | # mask[0, :-1, :, :2] = (
373 | # 0.0 # Ensure first two columns of the first batch element (except for its last head) are zeros
374 | # )
375 | # mask[1, :-1, :, -2:] = (
376 | # 0.0 # Ensure last two columns of the second batch element (except for its last head) are zeros
377 | # )
378 |
379 | # Find completely masked heads
380 | fully_masked = (mask.view(args.bsz, heads, -1) == MASK_CONST).all(dim=2)
381 |
382 | # For each fully masked head, unmask some random positions
383 | if fully_masked.any():
384 | print(" ⚠️ Some heads were fully masked; unmasking some positions to avoid this.")
385 | for b in range(args.bsz):
386 | for h in range(heads):
387 | if fully_masked[b, h]:
388 | num_to_unmask = max(1, seq_len * seq_len // 10)
389 | indices = torch.randperm(seq_len * seq_len, device=device, generator=gen)[
390 | :num_to_unmask
391 | ]
392 | mask[b, h].view(-1)[indices] = 0.0
393 |
394 | return mask
395 |
396 | else:
397 | raise ValueError(f"Unknown mask type: {mask_type}")
398 |
399 |
400 | def loss_fn(out: Tensor, target: Tensor) -> Tensor:
401 | """Compute the mean squared error loss.
402 |
403 | Args:
404 | out: The output tensor.
405 | target: The target tensor.
406 |
407 | Returns:
408 | The mean squared error loss.
409 | """
410 | return (out - target).square().mean()
411 |
412 |
413 | def make_qkv_with_grad(
414 | q_p: Tensor, k_p: Tensor, v_p: Tensor, q_t: Tensor, k_t: Tensor, v_t: Tensor
415 | ) -> QKV:
416 | """Make a QKV tuple with gradients enabled.
417 |
418 | Args:
419 | q_p: The query projection tensor.
420 | k_p: The key projection tensor.
421 | v_p: The value projection tensor.
422 | q_t: The query tangent tensor.
423 | k_t: The key tangent tensor.
424 | v_t: The value tangent tensor.
425 |
426 | Returns:
427 | A QKV tuple containing the primal and tangent QKV tensors.
428 | """
429 | # Create dual tensors
430 | q = fwAD.make_dual(q_p, q_t)
431 | k = fwAD.make_dual(k_p, k_t)
432 | v = fwAD.make_dual(v_p, v_t)
433 |
434 | for t in (q, k, v):
435 | t.requires_grad = True
436 | t.retain_grad()
437 |
438 | return QKV(q, k, v)
439 |
440 |
441 | def make_qkv(q_p: Tensor, k_p: Tensor, v_p: Tensor, q_t: Tensor, k_t: Tensor, v_t: Tensor) -> QKV:
442 | """Make a QKV tuple from the given tensors with dual numbers.
443 |
444 | Args:
445 | q_p: The query projection tensor.
446 | k_p: The key projection tensor.
447 | v_p: The value projection tensor.
448 | q_t: The query tangent tensor.
449 | k_t: The key tangent tensor.
450 | v_t: The value tangent tensor.
451 |
452 | Returns:
453 | A QKV tuple containing the primal and tangent QKV tensors.
454 | """
455 | q = fwAD.make_dual(q_p, q_t)
456 | k = fwAD.make_dual(k_p, k_t)
457 | v = fwAD.make_dual(v_p, v_t)
458 | return QKV(q, k, v)
459 |
460 |
461 | def make_qkv_unpacked(
462 | q_p: Tensor, k_p: Tensor, v_p: Tensor, q_t: Tensor, k_t: Tensor, v_t: Tensor
463 | ) -> UnpackedDualQKV:
464 | """Make an unpacked dual QKV from the given tensors.
465 |
466 | Args:
467 | q_p: The query projection tensor.
468 | k_p: The key projection tensor.
469 | v_p: The value projection tensor.
470 | q_t: The query tangent tensor.
471 | k_t: The key tangent tensor.
472 | v_t: The value tangent tensor.
473 |
474 | Returns:
475 | An unpacked dual QKV containing the primal and tangent QKV tensors.
476 | """
477 | for t in (q_p, k_p, v_p):
478 | t.requires_grad = True
479 | t.retain_grad()
480 |
481 | return UnpackedDualQKV(
482 | primal=QKV(
483 | q=q_p,
484 | k=k_p,
485 | v=v_p,
486 | ),
487 | tangent=QKV(
488 | q=q_t,
489 | k=k_t,
490 | v=v_t,
491 | ),
492 | )
493 |
494 |
495 | def compute_absolute_error(*tensors: Tensor) -> float:
496 | """Compute the maximum absolute pairwise error between all tensors.
497 |
498 | Args:
499 | tensors: The input tensors to compare.
500 |
501 | Returns:
502 | The maximum absolute pairwise error.
503 | """
504 | if len(tensors) < 2:
505 | raise ValueError("At least two tensors are required to compute absolute error.")
506 | max_error = 0.0
507 | for i in range(len(tensors)):
508 | for j in range(i + 1, len(tensors)):
509 | diff = (tensors[i] - tensors[j]).abs().max().item()
510 | if diff > max_error:
511 | max_error = diff
512 | return max_error
513 |
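# NOTE: For example, compute_absolute_error(a, b, c) returns the largest entry of
# |a - b|, |a - c|, and |b - c| taken over all elements.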
514 |
515 | def validate_accuracy_and_gradients(
516 | q_p: Tensor,
517 | k_p: Tensor,
518 | v_p: Tensor,
519 | q_t: Tensor,
520 | k_t: Tensor,
521 | v_t: Tensor,
522 | target: Tensor,
523 | is_causal: bool,
524 | attn_mask: Tensor | None = None,
525 | tolerance: float = 4e-3,
526 | loss_tolerance: float = 5e-4,
527 | grad_tolerance: float = 5e-4,
528 | ) -> AccuracyMetrics:
529 | """Validate numerical accuracy and gradient matching between SDPA and JVP attention.
530 |
531 | Args:
532 |         q_p: The query primal tensor.
533 |         k_p: The key primal tensor.
534 |         v_p: The value primal tensor.
535 | q_t: The query tangent tensor.
536 | k_t: The key tangent tensor.
537 | v_t: The value tangent tensor.
538 | target: The target tensor.
539 | is_causal: Whether the attention is causal.
540 | attn_mask: Optional attention mask tensor.
541 | tolerance: The tolerance for primal errors.
542 | loss_tolerance: The tolerance for loss errors.
543 | grad_tolerance: The tolerance for gradient errors.
544 |
545 | Returns:
546 | AccuracyMetrics containing all error measurements.
547 | """
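    # NOTE: Three evaluation paths are compared and should agree:
    #   (1) PyTorch SDPA (math backend) evaluated under fwAD.dual_level(),
    #   (2) JVPAttn.fwd_dual evaluated the same way, and
    #   (3) JVPAttn.fwd_dual driven through torch.func.jvp.
    # Primals, tangents, losses, and input gradients are all cross-checked.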
548 | with sdpa_kernel(SDPBackend.MATH), fwAD.dual_level(), enable_grad():
549 | # Run SDPA
550 | q0, k0, v0 = make_qkv_with_grad(
551 | q_p.clone(), k_p.clone(), v_p.clone(), q_t.clone(), k_t.clone(), v_t.clone()
552 | )
553 |
554 | sdpa_out = scaled_dot_product_attention(
555 | q0, k0, v0, attn_mask=attn_mask, is_causal=is_causal
556 | )
557 | sdpa_out.retain_grad()
558 | sdpa_op, sdpa_ot = fwAD.unpack_dual(sdpa_out)
559 |
560 | loss0 = loss_fn(sdpa_out, target)
561 | loss0.backward()
562 |
563 | assert not any(
564 | t.grad.isnan().any() or t.grad.isinf().any() for t in (q0, k0, v0)
565 | ), "NaN/Inf in SDPA input gradients."
566 |
567 | # Run JVP Attention
568 | q1, k1, v1 = make_qkv_with_grad(
569 | q_p.clone(), k_p.clone(), v_p.clone(), q_t.clone(), k_t.clone(), v_t.clone()
570 | )
571 |
572 | jvp_out = JVPAttn.fwd_dual(q1, k1, v1, attn_mask=attn_mask, causal=is_causal)
573 | jvp_out.retain_grad()
574 | jvp_op, jvp_ot = fwAD.unpack_dual(jvp_out)
575 |
576 | loss1 = loss_fn(jvp_out, target)
577 | loss1.backward()
578 |
579 | assert not any(
580 | t.grad.isnan().any() or t.grad.isinf().any() for t in (q1, k1, v1)
581 | ), "NaN/Inf in JVP input gradients."
582 |
583 | mse_fn = MSELoss()
584 | with enable_grad():
585 | # Run JVP Attention with torch.func.jvp
586 | qkv_p, qkv_t = make_qkv_unpacked(
587 | q_p.clone(),
588 | k_p.clone(),
589 | v_p.clone(),
590 | q_t.clone(),
591 | k_t.clone(),
592 | v_t.clone(),
593 | )
594 |
595 | jvp_func_op, jvp_func_ot = torch.func.jvp(
596 | partial(JVPAttn.fwd_dual, attn_mask=attn_mask, causal=is_causal), qkv_p, qkv_t
597 | )
598 | jvp_func_op.retain_grad()
599 |
600 | loss2: Tensor = mse_fn(jvp_func_op, target)
601 | loss2.backward()
602 |
603 | q2, k2, v2 = qkv_p
604 |
605 | assert not any(
606 | t.grad.isnan().any() or t.grad.isinf().any() for t in (q2, k2, v2)
607 | ), "NaN/Inf in JVP (func) input gradients."
608 |
609 | # Compute errors
610 | primal_error = compute_absolute_error(jvp_func_op, jvp_op, sdpa_op)
611 | tangent_error = compute_absolute_error(jvp_func_ot, jvp_ot, sdpa_ot)
612 | loss_error = compute_absolute_error(loss2, loss1, loss0)
613 |
614 | # Compute gradient errors
615 | q_grad_error = compute_absolute_error(q2.grad, q1.grad, q0.grad)
616 | k_grad_error = compute_absolute_error(k2.grad, k1.grad, k0.grad)
617 | v_grad_error = compute_absolute_error(v2.grad, v1.grad, v0.grad)
618 |
619 | metrics = AccuracyMetrics(
620 | primal_error=primal_error,
621 | tangent_error=tangent_error,
622 | loss_error=loss_error,
623 | q_grad_error=q_grad_error,
624 | k_grad_error=k_grad_error,
625 | v_grad_error=v_grad_error,
626 | )
627 |
628 | # Validate using torch.testing.assert_close
629 | try:
630 | torch.testing.assert_close(jvp_op, sdpa_op, atol=tolerance, rtol=1e-5)
631 | torch.testing.assert_close(
632 | # TODO: Improve this (causal) accuracy for longer sequence lengths
633 | jvp_func_op,
634 | sdpa_op,
635 | atol=tolerance,
636 | rtol=1e-5,
637 | )
638 |
639 | # TODO: Improve these tangent accuracies
640 | torch.testing.assert_close(
641 | jvp_ot,
642 | sdpa_ot,
643 | atol=tolerance,
644 | rtol=1e-5,
645 | )
646 | torch.testing.assert_close(
647 | jvp_func_ot,
648 | sdpa_ot,
649 | atol=tolerance,
650 | rtol=1e-5,
651 | )
652 |
653 | torch.testing.assert_close(loss1, loss0, atol=loss_tolerance, rtol=1e-5)
654 | torch.testing.assert_close(loss2, loss0, atol=loss_tolerance, rtol=1e-5)
655 |
656 | torch.testing.assert_close(q1.grad, q0.grad, atol=grad_tolerance, rtol=1e-5)
657 | torch.testing.assert_close(k1.grad, k0.grad, atol=grad_tolerance, rtol=1e-5)
658 | torch.testing.assert_close(v1.grad, v0.grad, atol=grad_tolerance, rtol=1e-5)
659 |
660 | torch.testing.assert_close(q2.grad, q0.grad, atol=grad_tolerance, rtol=1e-5)
661 | torch.testing.assert_close(k2.grad, k0.grad, atol=grad_tolerance, rtol=1e-5)
662 | torch.testing.assert_close(v2.grad, v0.grad, atol=grad_tolerance, rtol=1e-5)
663 |
664 | except AssertionError as e:
665 | print(f" ⚠️ Accuracy validation failed (causal={is_causal}): {e}")
666 |
667 | return metrics
668 |
669 |
670 | def run_benchmark_suite(args: Args) -> list[BenchmarkResult]:
671 | """Run comprehensive benchmarks across different configurations.
672 |
673 | Args:
674 | args: The command-line arguments for the benchmark.
675 |
676 | Returns:
677 | A list of benchmark results.
678 | """
679 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
680 |
681 | dtype_map = {
682 | "float16": torch.float16,
683 | "float32": torch.float32,
684 | "bfloat16": torch.bfloat16,
685 | }
686 | dtype = dtype_map[args.dtype]
687 |
688 | tolerance_map = {
689 | "float16": 4e-3,
690 | "float32": 2.35e-2,
691 | "bfloat16": 3.2e-2,
692 | }
693 | tolerance = tolerance_map[args.dtype]
694 |
695 | results = []
696 |
697 | # Define mask types to test
698 | mask_types = ["none"]
699 | if args.test_masks:
700 | mask_types.extend(["boolean", "additive"])
701 |
702 | for seq_len in args.seq_lengths:
703 | print(f"\n{'='*60}")
704 | print(f"Benchmarking sequence length: {seq_len}")
705 | print(f"{'='*60}")
706 |
707 | # Create test tensors
708 | q_p, q_t, k_p, k_t, v_p, v_t, target = create_test_tensors(args, seq_len, device, dtype)
709 |
710 | for mask_type in mask_types:
711 | # Create attention mask if needed
712 | attn_mask = create_attention_mask(args, seq_len, device, dtype, mask_type)
713 |
714 | for is_causal in [False, True]:
715 | print(f"\nCausal: {is_causal}, Mask: {mask_type}")
716 | print("-" * 40)
717 |
718 | if is_causal and mask_type != "none":
719 | print(" Skipping invalid combination of causal + mask")
720 | continue
721 |
722 | # Validate accuracy and gradients first
723 | if args.validate_gradients:
724 | print("Validating accuracy and gradients...")
725 | accuracy_metrics = validate_accuracy_and_gradients(
726 | q_p,
727 | k_p,
728 | v_p,
729 | q_t,
730 | k_t,
731 | v_t,
732 | target,
733 | is_causal,
734 | attn_mask=attn_mask,
735 | tolerance=tolerance,
736 | )
737 | accuracy_metrics.tolerance = tolerance
738 |
739 | print(f" Primal error: {accuracy_metrics.primal_error:.2e}")
740 | print(f" Tangent error: {accuracy_metrics.tangent_error:.2e}")
741 | print(f" Loss error: {accuracy_metrics.loss_error:.2e}")
742 | print(f" Q gradient error: {accuracy_metrics.q_grad_error:.2e}")
743 | print(f" K gradient error: {accuracy_metrics.k_grad_error:.2e}")
744 | print(f" V gradient error: {accuracy_metrics.v_grad_error:.2e}")
745 |
746 | if accuracy_metrics.is_accurate():
747 | print(" ✓ All accuracy checks passed!")
748 | else:
749 | print(f" ⚠️ Max error {accuracy_metrics.max_error:.2e} exceeds tolerance")
750 | else:
751 | accuracy_metrics = None
752 |
753 | # Benchmark performance
754 | with sdpa_kernel(SDPBackend.MATH), fwAD.dual_level(), enable_grad():
755 | # Create functions for benchmarking
756 | def run_sdpa():
757 | """Run SDPA attention."""
758 | q, k, v = make_qkv(
759 | q_p.clone(),
760 | k_p.clone(),
761 | v_p.clone(),
762 | q_t.clone(),
763 | k_t.clone(),
764 | v_t.clone(),
765 | )
766 | out = scaled_dot_product_attention(
767 | q, k, v, attn_mask=attn_mask, is_causal=is_causal
768 | )
769 |
770 | def run_jvp_attn():
771 | """Run JVP attention."""
772 | q, k, v = make_qkv(
773 | q_p.clone(),
774 | k_p.clone(),
775 | v_p.clone(),
776 | q_t.clone(),
777 | k_t.clone(),
778 | v_t.clone(),
779 | )
780 | out = JVPAttn.fwd_dual(q, k, v, attn_mask=attn_mask, causal=is_causal)
781 |
782 | if not args.benchmark_performance:
783 | print(" Skipping performance benchmarking.")
784 | results.append(
785 | BenchmarkResult(
786 | seq_len=seq_len,
787 | is_causal=is_causal,
788 | method="sdpa",
789 | time_ms=np.nan,
790 | memory_allocated_mb=np.nan,
791 | memory_reserved_mb=np.nan,
792 | mask_type=mask_type,
793 | flops=np.nan,
794 | accuracy=None,
795 | )
796 | )
797 | results.append(
798 | BenchmarkResult(
799 | seq_len=seq_len,
800 | is_causal=is_causal,
801 | method="jvp_attn",
802 | time_ms=np.nan,
803 | memory_allocated_mb=np.nan,
804 | memory_reserved_mb=np.nan,
805 | mask_type=mask_type,
806 | flops=np.nan,
807 | accuracy=accuracy_metrics,
808 | )
809 | )
810 | continue
811 |
812 | print("\nBenchmarking performance...")
813 | heads = args.model_dim // args.head_dim
814 |
815 | # Measure SDPA performance
816 | sdpa_time = benchmark_function(
817 | run_sdpa, args.warmup_iters, args.benchmark_iters
818 | )
819 | sdpa_mem_alloc, sdpa_mem_reserved = measure_memory_usage(run_sdpa)
820 | sdpa_flops = get_attention_flop_count(
821 | args.bsz, heads, seq_len, args.head_dim, is_causal, is_jvp=False
822 | )
823 |
824 | # Measure JVP Attention performance
825 | jvp_time = benchmark_function(
826 | run_jvp_attn, args.warmup_iters, args.benchmark_iters
827 | )
828 | jvp_mem_alloc, jvp_mem_reserved = measure_memory_usage(run_jvp_attn)
829 | jvp_flops = get_attention_flop_count(
830 | args.bsz, heads, seq_len, args.head_dim, is_causal, is_jvp=True
831 | )
832 |
833 | # Store results
834 | results.append(
835 | BenchmarkResult(
836 | seq_len=seq_len,
837 | is_causal=is_causal,
838 | method="sdpa",
839 | time_ms=sdpa_time,
840 | memory_allocated_mb=sdpa_mem_alloc,
841 | memory_reserved_mb=sdpa_mem_reserved,
842 | mask_type=mask_type,
843 | flops=sdpa_flops,
844 | accuracy=None,
845 | )
846 | )
847 |
848 | results.append(
849 | BenchmarkResult(
850 | seq_len=seq_len,
851 | is_causal=is_causal,
852 | method="jvp_attn",
853 | time_ms=jvp_time,
854 | memory_allocated_mb=jvp_mem_alloc,
855 | memory_reserved_mb=jvp_mem_reserved,
856 | mask_type=mask_type,
857 | flops=jvp_flops,
858 | accuracy=accuracy_metrics,
859 | )
860 | )
861 |
862 | # Print results
863 | print("PyTorch SDPA:")
864 | print(f" Time: {sdpa_time:.3f} ms")
865 | print(
866 | f" Memory (alloc/reserved): {sdpa_mem_alloc:.2f}/{sdpa_mem_reserved:.2f} MB"
867 | )
868 | print(f" FLOPS: {fmt_flops(mpi_to_flops(sdpa_time, sdpa_flops))}")
869 |
870 | print("\nJVP Attention:")
871 | print(f" Time: {jvp_time:.3f} ms")
872 | print(
873 | f" Memory (alloc/reserved): {jvp_mem_alloc:.2f}/{jvp_mem_reserved:.2f} MB"
874 | )
875 | print(f" FLOPS: {fmt_flops(mpi_to_flops(jvp_time, jvp_flops))}")
876 |
877 | print(f"\nSpeedup: {sdpa_time/jvp_time:.2f}x")
878 | print(f"Memory ratio: {jvp_mem_alloc/sdpa_mem_alloc:.2f}x")
879 |
880 | return results
881 |
882 |
883 | def print_summary_table(results: list[BenchmarkResult]) -> None:
884 | """Print a summary table of benchmark results.
885 |
886 | Args:
887 | results: The list of benchmark results to summarize.
888 | """
889 | print("\n" + "=" * 110)
890 | print("BENCHMARK SUMMARY")
891 | print("=" * 110)
892 |
893 | # Group results by seq_len, causal, and mask_type
894 | from collections import defaultdict
895 |
896 | grouped = defaultdict(dict)
897 |
898 | for r in results:
899 | key = (r.seq_len, r.is_causal, r.mask_type)
900 | grouped[key][r.method] = r
901 |
902 | # Print header
903 | print(
904 | f"{'Seq Len':<10} {'Causal':<8} {'Mask':<10} {'Method':<10} "
905 | f"{'Time (ms)':<12} {'Mem (MB)':<12} {'TFLOP/s':<12} "
906 | f"{'Max Error':<12} {'Grad Check':<10}"
907 | )
908 | print("-" * 110)
909 |
910 | for (seq_len, is_causal, mask_type), methods in sorted(grouped.items()):
911 | for method in ["sdpa", "jvp_attn"]:
912 | if method in methods:
913 | r = methods[method]
914 | flops_str = fmt_flops(mpi_to_flops(r.time_ms, r.flops)) if r.flops else "N/A"
915 |
916 | if r.accuracy:
917 | error_str = f"{r.accuracy.max_error:.2e}"
918 | grad_check = "✓" if r.accuracy.is_accurate() else "✗"
919 | else:
920 | error_str = "baseline"
921 | grad_check = "N/A"
922 |
923 | print(
924 | f"{seq_len:<10} {str(is_causal):<8} {mask_type:<10} {method:<10} "
925 | f"{r.time_ms:<12.3f} {r.memory_allocated_mb:<12.2f} "
926 | f"{flops_str:<12} {error_str:<12} {grad_check:<10}"
927 | )
928 | print()
929 |
930 |
931 | def print_mask_comparison_table(results: list[BenchmarkResult]) -> None:
932 | """Print a comparison table showing the impact of different mask types.
933 |
934 | Args:
935 | results: The list of benchmark results to analyze.
936 | """
937 | print("\n" + "=" * 80)
938 | print("MASK TYPE PERFORMANCE COMPARISON")
939 | print("=" * 80)
940 |
941 | # Group by seq_len, is_causal, and method
942 | from collections import defaultdict
943 |
944 | grouped = defaultdict(lambda: defaultdict(dict))
945 |
946 | for r in results:
947 | grouped[(r.seq_len, r.is_causal, r.method)][r.mask_type] = r
948 |
949 | print(
950 | f"{'Seq Len':<10} {'Causal':<8} {'Method':<10} "
951 | f"{'No Mask':<15} {'Boolean Mask':<15} {'Additive Mask':<15}"
952 | )
953 | print("-" * 80)
954 |
955 | for (seq_len, is_causal, method), mask_results in sorted(grouped.items()):
956 | if method == "jvp_attn": # Only show JVP results for clarity
957 | none_time = mask_results.get("none", None)
958 | bool_time = mask_results.get("boolean", None)
959 | add_time = mask_results.get("additive", None)
960 |
961 | none_str = f"{none_time.time_ms:.2f} ms" if none_time else "N/A"
962 | bool_str = f"{bool_time.time_ms:.2f} ms" if bool_time else "N/A"
963 | add_str = f"{add_time.time_ms:.2f} ms" if add_time else "N/A"
964 |
965 | # Add relative performance
966 | if none_time and bool_time:
967 | bool_str += f" ({bool_time.time_ms/none_time.time_ms:.2f}x)"
968 | if none_time and add_time:
969 | add_str += f" ({add_time.time_ms/none_time.time_ms:.2f}x)"
970 |
971 | print(
972 | f"{seq_len:<10} {str(is_causal):<8} {method:<10} "
973 | f"{none_str:<15} {bool_str:<15} {add_str:<15}"
974 | )
975 |
976 |
977 | def plot_benchmark_results(
978 | results: list[BenchmarkResult], args: Args, verbose: bool = False
979 | ) -> None:
980 | """Generate, save, and display plots summarizing benchmark results.
981 |
982 | Args:
983 | results: The list of benchmark results to plot.
984 | args: The command-line arguments, used for filename generation.
985 | verbose: Whether to print verbose output.
986 | """
987 | if not PLOTTING_AVAILABLE:
988 | print("\nmatplotlib and/or numpy not found. Skipping plotting.")
989 | return
990 |
991 | from collections import defaultdict
992 |
993 | # Group results by (is_causal, mask_type) for creating subplots
994 | grouped = defaultdict(list)
995 | for r in results:
996 | key = (r.is_causal, r.mask_type)
997 | grouped[key].append(r)
998 |
999 | num_configs = len(grouped)
1000 | if num_configs == 0:
1001 | return
1002 |
1003 | # --- 1. Create and Populate the Performance Figure ---
1004 | fig_perf, axes_perf = plt.subplots(1, num_configs, figsize=(6 * num_configs, 5), squeeze=False)
1005 | fig_perf.suptitle("Performance Speedup (JVP Attention vs. SDPA)", fontsize=16)
1006 |
1007 | # --- 2. Create and Populate the Memory Figure ---
1008 | fig_mem, axes_mem = plt.subplots(1, num_configs, figsize=(6 * num_configs, 5), squeeze=False)
1009 | fig_mem.suptitle("Peak Allocated Memory Comparison", fontsize=16)
1010 |
1011 | # Iterate through data to draw on the axes of BOTH figures
1012 | for i, ((is_causal, mask_type), config_results) in enumerate(grouped.items()):
1013 | config_results.sort(key=lambda r: r.seq_len)
1014 |
1015 |         # Use the method names "jvp_attn" and "sdpa" recorded in BenchmarkResult
1016 | jvp_results = [r for r in config_results if r.method == "jvp_attn"]
1017 | sdpa_results = [r for r in config_results if r.method == "sdpa"]
1018 |
1019 | if not jvp_results or not sdpa_results:
1020 | # Debug print to help identify issues
1021 | if verbose:
1022 | print(
1023 | f"Warning: Missing results for config (causal={is_causal}, mask={mask_type})"
1024 | )
1025 | print(f" Available methods: {set(r.method for r in config_results)}")
1026 | continue
1027 |
1028 | seq_lens = [r.seq_len for r in jvp_results]
1029 | jvp_times = np.array([r.time_ms for r in jvp_results])
1030 | sdpa_times = np.array([r.time_ms for r in sdpa_results])
1031 | jvp_mems = np.array([r.memory_allocated_mb for r in jvp_results])
1032 | sdpa_mems = np.array([r.memory_allocated_mb for r in sdpa_results])
1033 | speedup = sdpa_times / jvp_times
1034 | x = np.arange(len(seq_lens))
1035 |
1036 | # Draw on the performance subplot
1037 | ax_perf = axes_perf[0, i]
1038 | bar_perf = ax_perf.bar(x, speedup, width=0.5, color="g")
1039 | ax_perf.bar_label(bar_perf, fmt=lambda val: f"{val:.2f}x")
1040 | ax_perf.axhline(1.0, color="grey", linestyle="--")
1041 | ax_perf.set(
1042 | ylabel="Speedup (SDPA Time / JVP Time)",
1043 | xlabel="Sequence Length",
1044 | title=f"Causal={is_causal}, Mask={mask_type}",
1045 | xticks=x,
1046 | xticklabels=seq_lens,
1047 | ylim=(0, max(1.1, np.max(speedup) * 1.15)),
1048 | )
1049 |
1050 | # Draw on the memory subplot
1051 | ax_mem = axes_mem[0, i]
1052 | width = 0.35
1053 | rects1 = ax_mem.bar(x - width / 2, sdpa_mems, width, label="PyTorch SDPA")
1054 | rects2 = ax_mem.bar(x + width / 2, jvp_mems, width, label="JVP Attention")
1055 | ax_mem.bar_label(rects1, padding=3, fmt=lambda val: f"{val:.1f}")
1056 | ax_mem.bar_label(rects2, padding=3, fmt=lambda val: f"{val:.1f}")
1057 | ax_mem.set(
1058 | ylabel="Peak Allocated Memory (MB)",
1059 | xlabel="Sequence Length",
1060 | title=f"Causal={is_causal}, Mask={mask_type}",
1061 | xticks=x,
1062 | xticklabels=seq_lens,
1063 | ylim=(0, max(np.max(sdpa_mems), np.max(jvp_mems)) * 1.25),
1064 | )
1065 | ax_mem.legend()
1066 |
1067 | # --- 3. Finalize and Save Each Figure Individually ---
1068 | plot_dir = "tests"
1069 | os.makedirs(plot_dir, exist_ok=True)
1070 | mask_suffix = "_with_masks" if args.test_masks else ""
1071 |
1072 | # Finalize and save the performance plot
1073 | perf_plot_path = os.path.join(plot_dir, f"{args.dtype}_jvp_attention_perf{mask_suffix}.png")
1074 | fig_perf.tight_layout(rect=[0, 0.03, 1, 0.95])
1075 | fig_perf.savefig(perf_plot_path, dpi=150)
1076 | if verbose:
1077 | print(f"Saved performance plot to {perf_plot_path}")
1078 |
1079 | # Finalize and save the memory plot
1080 | mem_plot_path = os.path.join(plot_dir, f"{args.dtype}_jvp_attention_mem{mask_suffix}.png")
1081 | fig_mem.tight_layout(rect=[0, 0.03, 1, 0.95])
1082 | fig_mem.savefig(mem_plot_path, dpi=150)
1083 | if verbose:
1084 | print(f"Saved memory plot to {mem_plot_path}")
1085 |
1086 | # --- 4. Show and Close ---
1087 | plt.show()
1088 |
1089 | # Explicitly close figures to free memory
1090 | plt.close(fig_perf)
1091 | plt.close(fig_mem)
1092 |
1093 |
1094 | def main(args: Args) -> None:
1095 | """Main benchmarking loop."""
1096 | print("Flash Attention JVP Kernel Benchmark with Mask Testing")
1097 | print(
1098 | f"Configuration: bsz={args.bsz}, model_dim={args.model_dim}, "
1099 | f"head_dim={args.head_dim}, dtype={args.dtype}"
1100 | )
1101 | print(f"Mask testing: {'Enabled' if args.test_masks else 'Disabled'}")
1102 | print(f"Gradient validation: {'Enabled' if args.validate_gradients else 'Disabled'}")
1103 | print(f"Performance benchmarking: {'Enabled' if args.benchmark_performance else 'Disabled'}")
1104 | if args.test_masks:
1105 | print(f"Mask probability: {args.mask_prob}")
1106 |
1107 | # Seed everything
1108 | random.seed(args.seed)
1109 | if PLOTTING_AVAILABLE:
1110 | np.random.seed(args.seed)
1111 | torch.manual_seed(args.seed)
1112 | torch.cuda.manual_seed_all(args.seed)
1113 |
1114 | results = run_benchmark_suite(args)
1115 |
1116 | # Print summary tables
1117 | print_summary_table(results)
1118 |
1119 | # If performance was benchmarked, plot benchmarking results
1120 | if args.benchmark_performance:
1121 | plot_benchmark_results(results, args)
1122 |
1123 | # If masks were tested, print comparison table
1124 | if args.test_masks:
1125 | print_mask_comparison_table(results)
1126 |
1127 | # Print statistics
1128 | print("\n" + "=" * 60)
1129 | print("STATISTICS")
1130 | print("=" * 60)
1131 |
1132 | # Calculate average speedup
1133 | speedups = []
1134 | for i in range(0, len(results), 2): # Assuming pairs of sdpa/jvp_attn
1135 | if i + 1 < len(results):
1136 | sdpa_result = results[i]
1137 | jvp_result = results[i + 1]
1138 | if sdpa_result.method == "sdpa" and jvp_result.method == "jvp_attn":
1139 | speedup = sdpa_result.time_ms / jvp_result.time_ms
1140 | speedups.append(speedup)
1141 |
1142 | if speedups:
1143 | avg_speedup = sum(speedups) / len(speedups)
1144 | min_speedup = min(speedups)
1145 | max_speedup = max(speedups)
1146 | print(f"Average speedup: {avg_speedup:.2f}x")
1147 | print(f"Min speedup: {min_speedup:.2f}x")
1148 | print(f"Max speedup: {max_speedup:.2f}x")
1149 |
1150 | # Calculate accuracy statistics
1151 | accuracy_results = [r for r in results if r.accuracy is not None]
1152 | if accuracy_results:
1153 | all_accurate = all(r.accuracy.is_accurate() for r in accuracy_results)
1154 | num_accurate = sum(1 for r in accuracy_results if r.accuracy.is_accurate())
1155 | print(f"\nAccuracy: {num_accurate}/{len(accuracy_results)} tests passed")
1156 | if all_accurate:
1157 | print("✓ All accuracy checks passed!")
1158 | else:
1159 | print("⚠️ Some accuracy checks failed")
1160 |
1161 | # Show which configurations failed
1162 | failed_configs = [
1163 | (r.seq_len, r.is_causal, r.mask_type)
1164 | for r in accuracy_results
1165 | if not r.accuracy.is_accurate()
1166 | ]
1167 | if failed_configs:
1168 | print("\nFailed configurations:")
1169 | for seq_len, is_causal, mask_type in failed_configs:
1170 | print(f" - Seq={seq_len}, Causal={is_causal}, Mask={mask_type}")
1171 |
1172 | # Save results to file
1173 | import json
1174 |
1175 | # Convert results to JSON-serializable format
1176 | results_data = []
1177 | for r in results:
1178 | result_dict = r._asdict()
1179 | if r.accuracy:
1180 | result_dict["accuracy"] = {
1181 | "primal_error": r.accuracy.primal_error,
1182 | "tangent_error": r.accuracy.tangent_error,
1183 | "loss_error": r.accuracy.loss_error,
1184 | "q_grad_error": r.accuracy.q_grad_error,
1185 | "k_grad_error": r.accuracy.k_grad_error,
1186 | "v_grad_error": r.accuracy.v_grad_error,
1187 | "max_error": r.accuracy.max_error,
1188 | "is_accurate": r.accuracy.is_accurate(),
1189 | }
1190 | results_data.append(result_dict)
1191 |
1192 | # Include configuration in filename
1193 | mask_suffix = "_with_masks" if args.test_masks else ""
1194 | output_filepath = os.path.join(
1195 | "tests", f"{args.dtype}_test_jvp_attention_results{mask_suffix}.json"
1196 | )
1197 | os.makedirs(os.path.dirname(output_filepath), exist_ok=True)
1198 |
1199 | # Save both results and configuration
1200 | output_data = {
1201 | "configuration": {
1202 | "bsz": args.bsz,
1203 | "model_dim": args.model_dim,
1204 | "head_dim": args.head_dim,
1205 | "seq_lengths": args.seq_lengths,
1206 | "dtype": args.dtype,
1207 | "seed": args.seed,
1208 | "test_masks": args.test_masks,
1209 | "validate_gradients": args.validate_gradients,
1210 | "benchmark_performance": args.benchmark_performance,
1211 | "mask_prob": args.mask_prob if args.test_masks else None,
1212 | },
1213 | "results": results_data,
1214 | "summary": {
1215 | "avg_speedup": avg_speedup if speedups else None,
1216 | "min_speedup": min_speedup if speedups else None,
1217 | "max_speedup": max_speedup if speedups else None,
1218 | "accuracy_rate": (
1219 | f"{num_accurate}/{len(accuracy_results)}" if accuracy_results else "N/A"
1220 | ),
1221 | },
1222 | }
1223 |
1224 | with open(output_filepath, "w") as f:
1225 | json.dump(output_data, f, indent=2, default=str)
1226 |
1227 | print(f"\nResults saved to {output_filepath}")
1228 |
1229 |
1230 | if __name__ == "__main__":
1231 | parser = Args.get_parser()
1232 | namespace = parser.parse_args()
1233 | args = Args.from_namespace(namespace)
1234 | main(args)
1235 |
--------------------------------------------------------------------------------
/jvp_flash_attention/jvp_attention.py:
--------------------------------------------------------------------------------
1 | """
2 | Fused Attention
3 | ===============
4 |
5 | This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf)
6 |
7 | Credits: OpenAI kernel team
8 | Extra Credits:
9 |
10 | * Original flash attention paper (https://arxiv.org/abs/2205.14135)
11 | * Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf)
12 |
13 | Plus modifications to support Jacobian-vector products (JVPs) and Hessian-vector products (HVPs):
14 | - Formulation of flash JVP, by Cheng Lu and Yang Song in https://arxiv.org/abs/2410.11081.
15 | - Reference Triton implementation, by Sofian Mejjoute.
16 | - Reimplementing reference implementation as an autograd function with latest Triton tutorial optimizations, by Alex Birch.
17 | - Support for forward to receive tangents, so as to compute fwd and jvp together; autograd workaround, by Emily (nshepperd).
18 | - Support for function transforms (e.g., torch.func.jvp) via the use of setup_context, by Shih-Ying Yeh.
19 | - Support for sequence lengths 32 & 64; float32 & bfloat16 precision; comprehensive, length and dtype-stratified unit tests;
20 | working backward hook w.r.t. tensor contiguity; HVP stress testing; standardized docstrings/packaging; and masking/dropout, by Alex Morehead.
21 | """
22 |
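# NOTE: A rough usage sketch (the shapes below are illustrative only; they follow the
# (batch, heads, seq_len, head_dim) layout this kernel assumes):
#
#     import torch
#     import torch.autograd.forward_ad as fwAD
#     from jvp_flash_attention.jvp_attention import JVPAttn
#
#     q_p, k_p, v_p, q_t, k_t, v_t = (torch.randn(2, 8, 128, 64, device="cuda") for _ in range(6))
#     with fwAD.dual_level():
#         q = fwAD.make_dual(q_p, q_t)
#         k = fwAD.make_dual(k_p, k_t)
#         v = fwAD.make_dual(v_p, v_t)
#         out = JVPAttn.fwd_dual(q, k, v, attn_mask=None, causal=False)
#         primal, tangent = fwAD.unpack_dual(out)  # tangent is the JVP of attention
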
23 | from __future__ import annotations
24 |
25 | import os
26 | from typing import Any, Literal, NamedTuple
27 |
28 | import torch
29 | import torch.autograd.forward_ad as fwAD
30 | import triton
31 | import triton.language as tl
32 | from torch import Tensor
33 | from torch.autograd import Function
34 | from torch.autograd.function import FunctionCtx
35 |
36 | # NOTE: Uncomment to turn warnings into errors for debugging
37 | # import warnings
38 | # warnings.filterwarnings("error", category=UserWarning)
39 | # warnings.filterwarnings("error", category=RuntimeWarning)
40 |
41 | try:
42 | from triton.tools.tensor_descriptor import TensorDescriptor
43 |
44 | HAS_TENSOR_DESC = True
45 | except ModuleNotFoundError:
46 | HAS_TENSOR_DESC = False
47 |
48 | MASK_CONST = (
49 | -1.0e2
50 | ) # Use a large negative value for masking (compatible with float16, bfloat16, and float32)
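# NOTE: A finite constant is used here (rather than -inf), presumably so that fully
# masked rows stay NaN-free (e.g., no inf - inf in the running-max update) while still
# being representable in float16.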
51 | MIN_SEQUENCE_LENGTH = 32  # NOTE: All sequence lengths must be multiples of 2 and at least 32
52 |
53 |
54 | def is_hip():
55 | """Check if the current device is HIP."""
56 | try:
57 | return triton.runtime.driver.active.get_current_target().backend == "hip"
58 | except Exception:
59 | return False
60 |
61 |
62 | def is_cuda():
63 | """Check if the current device is CUDA."""
64 | try:
65 | return triton.runtime.driver.active.get_current_target().backend == "cuda"
66 | except Exception:
67 | return False
68 |
69 |
70 | def supports_host_descriptor():
71 | """Check if the current device supports host tensor descriptors."""
72 | try:
73 | return is_cuda() and torch.cuda.get_device_capability()[0] >= 9
74 | except Exception:
75 | return False
76 |
77 |
78 | def supports_tma():
79 |     """Check if the current device supports the Tensor Memory Accelerator (TMA)."""
80 | try:
81 | return HAS_TENSOR_DESC and is_cuda() and torch.cuda.get_device_capability()[0] >= 9
82 | except Exception:
83 | return False
84 |
85 |
86 | def is_blackwell():
87 | """Check if the current device is Blackwell architecture."""
88 | try:
89 | return is_cuda() and torch.cuda.get_device_capability()[0] == 10
90 | except Exception:
91 | return False
92 |
93 |
94 | @triton.jit
95 | def create_dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):
96 | """Generate dropout mask using Philox RNG.
97 |
98 | Args:
99 | philox_seed: Seed for Philox RNG.
100 | philox_offset: Offset for Philox RNG.
101 | dropout_p: Dropout probability.
102 | m: Number of rows.
103 | n: Number of columns.
104 | stride: Stride for the output mask.
105 |
106 | Returns:
107 | dropout_mask: A boolean mask indicating which elements to keep (1.0) or drop (0.0).
108 | dropout_scale: Scale factor to apply after dropout.
109 | """
110 | ms = tl.arange(0, m)
111 | ns = tl.arange(0, n)
112 | offs = ms[:, None] * stride + ns[None, :]
113 | rng_offs = philox_offset + offs
114 |
115 | # Generate random values using Philox
116 | rand_vals = tl.rand(philox_seed, rng_offs)
117 |
118 | # Create dropout mask (1.0 = keep, 0.0 = drop)
119 | dropout_mask = rand_vals > dropout_p
120 | dropout_scale = 1.0 / (1.0 - dropout_p) if dropout_p < 1.0 else 0.0
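    # NOTE: This is the usual "inverted dropout" convention: scaling the kept entries
    # by 1 / (1 - dropout_p) preserves the expected value of the attention weights.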
121 |
122 | return dropout_mask, dropout_scale
123 |
124 |
125 | @triton.jit
126 | def _attn_fwd_inner(
127 | acc,
128 | g_acc, #
129 | l_i,
130 | m_i, #
131 | mu_i,
132 | p_tv_acc, #
133 | q,
134 | t_q, #
135 | K_block_ptr,
136 | V_block_ptr, #
137 | T_K_block_ptr,
138 | T_V_block_ptr, #
139 | # Mask and dropout parameters
140 | mask_block_ptr,
141 | dropout_p,
142 | philox_seed,
143 | philox_offset_base,
144 | # Other parameters
145 | dtype: tl.constexpr,
146 | start_m,
147 | qk_scale,
148 | sm_scale, #
149 | BLOCK_M: tl.constexpr,
150 | HEAD_DIM: tl.constexpr,
151 | BLOCK_N: tl.constexpr, #
152 | STAGE: tl.constexpr,
153 | offs_m: tl.constexpr,
154 | offs_n: tl.constexpr, #
155 | N_CTX: tl.constexpr,
156 | warp_specialize: tl.constexpr, #
157 | ENABLE_JVP: tl.constexpr,
158 | ENABLE_DROPOUT: tl.constexpr,
159 | MASK_TYPE: tl.constexpr, # 0: no mask, 1: boolean, 2: additive
160 | MASK_CONST: tl.constexpr = MASK_CONST,
161 | ):
162 | """Inner forward pass for attention mechanism.
163 |
164 | Args:
165 |         acc: Running (unnormalized) output accumulator.
166 |         g_acc: JVP accumulator for (P * t_S) @ V.
167 |         l_i: Running softmax denominator (row-wise sum of exponentials).
168 |         m_i: Running row-wise maximum of the scaled attention scores.
169 |         mu_i: Running row-wise sum of P * t_S (softmax-JVP correction term).
170 |         p_tv_acc: Accumulator for P @ t_V.
171 | q: Query tensor.
172 | t_q: Tangent of the query tensor.
173 | K_block_ptr: Pointer to the key block.
174 | V_block_ptr: Pointer to the value block.
175 | T_K_block_ptr: Pointer to the tangent key block.
176 | T_V_block_ptr: Pointer to the tangent value block.
177 | mask_block_ptr: Pointer to the attention mask block.
178 | dropout_p: Dropout probability.
179 | philox_seed: Seed for Philox RNG.
180 | philox_offset_base: Base offset for Philox RNG.
181 | dtype: Data type of the tensors.
182 | start_m: Starting index for the current block.
183 | qk_scale: Scale factor for the query-key dot product.
184 | sm_scale: Scale factor for the softmax.
185 | BLOCK_M: Block size for the M dimension.
186 | HEAD_DIM: Dimension of the attention heads.
187 | BLOCK_N: Block size for the N dimension.
188 | STAGE: Current stage of the computation.
189 | offs_m: Offsets for the M dimension.
190 | offs_n: Offsets for the N dimension.
191 | N_CTX: Number of context tokens.
192 | warp_specialize: Whether to apply warp specialization.
193 | ENABLE_JVP: Whether to enable JVP (Jacobian-vector product).
194 | ENABLE_DROPOUT: Whether to enable dropout.
195 | MASK_TYPE: Type of attention mask (0: no mask, 1: boolean, 2: additive).
196 | MASK_CONST: Constant value used for masking.
197 |
198 | Returns:
199 | The output tensors as a tuple.
200 | """
201 | # Range of values handled by this stage
202 | if STAGE == 1:
203 | # NOTE: From 0 to the left of the diagonal
204 | lo, hi = 0, start_m * BLOCK_M
205 | elif STAGE == 2:
206 | # NOTE: Used only for the block in which there is transition between non-masked and masked keys
207 | lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
208 | lo = tl.multiple_of(lo, BLOCK_M)
209 | else:
210 | # NOTE: Only used for non-causal attention
211 | lo, hi = 0, N_CTX
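    # NOTE: For example, with BLOCK_M == 64 and start_m == 2, STAGE == 1 covers key
    # positions [0, 128) (fully visible under the causal mask), STAGE == 2 covers
    # [128, 192) (the block straddling the diagonal), and the non-causal path covers
    # the full [0, N_CTX) range.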
212 |
213 | K_block_ptr = tl.advance(K_block_ptr, (0, lo))
214 | # NOTE: In fp8 mode, we may want to advance the V_block_ptr differently.
215 | # I did try advancing by (0, lo) instead for fp8, but I got an illegal memory access.
216 | # https://github.com/triton-lang/triton/commit/75d27b0b425329bad8c13b9cd47177d93590ec31
217 | V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
218 | if ENABLE_JVP:
219 | T_K_block_ptr = tl.advance(T_K_block_ptr, (0, lo))
220 | T_V_block_ptr = tl.advance(T_V_block_ptr, (lo, 0))
221 |
222 | if MASK_TYPE > 0:
223 | mask_block_ptr = tl.advance(mask_block_ptr, (0, lo))
224 |
225 | # Loop over k, v and update accumulator
226 | for start_n in range(lo, hi, BLOCK_N):
227 | # Let the compiler know that start_n is a multiple
228 | # of BLOCK_N, so the compiler can do optimizations
229 | start_n = tl.multiple_of(start_n, BLOCK_N)
230 |
231 | # -- Compute qk --
232 | k = tl.load(K_block_ptr)
233 | qk = tl.dot(q, k)
234 | if ENABLE_JVP:
235 | t_k = tl.load(T_K_block_ptr)
236 | t_qk = tl.dot(t_q, k) + tl.dot(q, t_k)
237 |
238 | # Load and apply attention mask if provided (before scaling for STAGE != 2)
239 | if MASK_TYPE > 0:
240 | mask = tl.load(mask_block_ptr)
241 | if MASK_TYPE == 1: # Boolean mask
242 |                 # Convert boolean to additive mask: True (attend) -> 0, False (ignore) -> MASK_CONST
243 | qk = qk + tl.where(mask == 1, 0.0, MASK_CONST)
244 | if ENABLE_JVP:
245 | t_qk = tl.where(mask == 1, t_qk, 0.0)
246 |
247 | elif MASK_TYPE == 2: # Additive mask
248 | qk = qk + mask
249 |
250 | m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
251 | qk = qk * qk_scale - m_ij[:, None]
252 |
253 | # For causal attention (STAGE == 2)
254 | elif STAGE == 2:
255 | mask = offs_m[:, None] >= (start_n + offs_n[None, :])
256 | qk = qk * qk_scale + tl.where(mask, 0.0, MASK_CONST)
257 | m_ij = tl.maximum(m_i, tl.max(qk, 1))
258 | qk -= m_ij[:, None]
259 |
260 | # No masking case (MASK_TYPE == 0 and STAGE != 2)
261 | else:
262 | m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
263 | qk = qk * qk_scale - m_ij[:, None]
264 |
265 | p = tl.math.exp2(qk)
266 |
267 | if MASK_TYPE == 1 or STAGE == 2:
268 | # Account for fully masked sequence blocks
269 | p = tl.where(mask == 1, p, 0.0)
270 |
271 | # Apply dropout if enabled
272 | if ENABLE_DROPOUT:
273 | philox_offset = philox_offset_base + start_m * N_CTX + start_n
274 | dropout_mask, dropout_scale = create_dropout_mask(
275 | philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, N_CTX
276 | )
277 | p = p * dropout_mask.to(dtype) * dropout_scale
278 |
279 | l_ij = tl.sum(p, 1)
280 |
281 | # -- Update m_i and l_i --
282 | alpha = tl.math.exp2(m_i - m_ij)
283 | l_i = l_i * alpha + l_ij
284 |
285 | # -- Update output accumulator --
286 | if warp_specialize and (BLOCK_M == 128 and HEAD_DIM == 128):
287 | BM: tl.constexpr = acc.shape[0]
288 | BN: tl.constexpr = acc.shape[1]
289 | acc0, acc1 = acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split()
290 | acc0 = acc0 * alpha[:, None]
291 | acc1 = acc1 * alpha[:, None]
292 | acc = tl.join(acc0, acc1).permute(0, 2, 1).reshape([BM, BN])
293 | else:
294 | acc = acc * alpha[:, None]
295 |
296 | v = tl.load(V_block_ptr)
297 | # NOTE: We may need to transpose v if dtype == tl.float8e5
298 | # https://github.com/triton-lang/triton/commit/75d27b0b425329bad8c13b9cd47177d93590ec31
299 | p = p.to(dtype)
300 |
301 | if ENABLE_JVP:
302 | p_tqk = p * (t_qk * sm_scale)
303 |
304 | if warp_specialize and (BLOCK_M == 128 and HEAD_DIM == 128):
305 | BM: tl.constexpr = g_acc.shape[0]
306 | BN: tl.constexpr = g_acc.shape[1]
307 | g_acc0, g_acc1 = g_acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split()
308 | g_acc0 = g_acc0 * alpha[:, None]
309 | g_acc1 = g_acc1 * alpha[:, None]
310 | g_acc = tl.join(g_acc0, g_acc1).permute(0, 2, 1).reshape([BM, BN])
311 | else:
312 | g_acc = g_acc * alpha[:, None]
313 |
314 | g_acc = tl.dot(p_tqk.to(v.dtype), v, g_acc)
315 | mu_ij = tl.sum(p_tqk, 1)
316 | mu_i = mu_i * alpha + mu_ij
317 | t_v = tl.load(T_V_block_ptr)
318 | p_tv_acc = p_tv_acc * alpha[:, None] + tl.dot(p, t_v.to(dtype)).to(t_v.dtype)
319 | T_V_block_ptr = tl.advance(T_V_block_ptr, (BLOCK_N, 0))
320 | T_K_block_ptr = tl.advance(T_K_block_ptr, (0, BLOCK_N))
321 |
322 | acc = tl.dot(p, v.to(dtype), acc).to(acc.dtype)
323 |
324 | # -- Update m_i --
325 | m_i = m_ij
326 |
327 | # -- Move to the next block of K, V, and maybe the mask --
328 | # NOTE: The fp8 PR made a change to how K and V are advanced here but I believe we already have that.
329 | # https://github.com/triton-lang/triton/commit/75d27b0b425329bad8c13b9cd47177d93590ec31
330 | V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
331 | K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
332 |
333 | if MASK_TYPE > 0:
334 | mask_block_ptr = tl.advance(mask_block_ptr, (0, BLOCK_N))
335 |
336 | return acc, g_acc, l_i, m_i, mu_i, p_tv_acc
337 |
338 |
339 | @triton.jit
340 | def _attn_fwd_inner_tma(
341 | acc,
342 | g_acc, #
343 | l_i,
344 | m_i, #
345 | mu_i,
346 | p_tv_acc, #
347 | q,
348 | t_q, #
349 | desc_k,
350 | desc_v, #
351 | desc_k_t,
352 | desc_v_t, #
353 | offset_y,
354 | # Mask and dropout parameters
355 | mask_block_ptr,
356 | dropout_p,
357 | philox_seed,
358 | philox_offset_base,
359 | # Other parameters
360 | dtype: tl.constexpr,
361 | start_m,
362 | qk_scale,
363 | sm_scale, #
364 | BLOCK_M: tl.constexpr,
365 | HEAD_DIM: tl.constexpr,
366 | BLOCK_N: tl.constexpr, #
367 | STAGE: tl.constexpr,
368 | offs_m: tl.constexpr,
369 | offs_n: tl.constexpr, #
370 | N_CTX: tl.constexpr,
371 | warp_specialize: tl.constexpr,
372 | ENABLE_JVP: tl.constexpr,
373 | ENABLE_DROPOUT: tl.constexpr,
374 | MASK_TYPE: tl.constexpr, # 0: no mask, 1: boolean, 2: additive
375 | MASK_CONST: tl.constexpr = MASK_CONST,
376 | ):
377 |     """Inner forward pass for attention mechanism with TMA (Tensor Memory Accelerator) support.
378 |
379 | Args:
380 |         acc: Running (unnormalized) output accumulator.
381 |         g_acc: JVP accumulator for (P * t_S) @ V.
382 |         l_i: Running softmax denominator (row-wise sum of exponentials).
383 |         m_i: Running row-wise maximum of the scaled attention scores.
384 |         mu_i: Running row-wise sum of P * t_S (softmax-JVP correction term).
385 |         p_tv_acc: Accumulator for P @ t_V.
386 |         q: Query tensor.
387 |         t_q: Tangent of the query tensor.
388 |         desc_k: Descriptor for the key tensor.
389 |         desc_v: Descriptor for the value tensor.
390 |         desc_k_t: Descriptor for the tangent key tensor.
391 |         desc_v_t: Descriptor for the tangent value tensor.
392 | offset_y: Offset for y dimension.
393 | mask_block_ptr: Pointer to the attention mask block.
394 | dropout_p: Dropout probability.
395 | philox_seed: Seed for Philox RNG.
396 | philox_offset_base: Base offset for Philox RNG.
397 | dtype: Data type.
398 | start_m: Start index for m dimension.
399 |         qk_scale: Scale factor for the query-key dot product.
400 |         sm_scale: Softmax scale factor.
401 | BLOCK_M: Block size for m dimension.
402 | HEAD_DIM: Head dimension size.
403 | BLOCK_N: Block size for n dimension.
404 | STAGE: Stage of computation.
405 | offs_m: Offset for m dimension.
406 | offs_n: Offset for n dimension.
407 | N_CTX: Context size.
408 | warp_specialize: Flag for warp specialization.
409 | ENABLE_JVP: Flag for enabling JVP.
410 | ENABLE_DROPOUT: Flag for enabling dropout.
411 | MASK_TYPE: Type of attention mask (0: no mask, 1: boolean, 2: additive).
412 | MASK_CONST: Constant value used for masking.
413 |
414 | Returns:
415 | The output tensors as a tuple.
416 | """
417 | # Range of values handled by this stage
418 | if STAGE == 1:
419 | # NOTE: From 0 to the left of the diagonal
420 | lo, hi = 0, start_m * BLOCK_M
421 | elif STAGE == 2:
422 | # NOTE: Used only for the block in which there is transition between non-masked and masked keys
423 | lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
424 | lo = tl.multiple_of(lo, BLOCK_M)
425 | else:
426 | # NOTE: Only used for non-causal attention
427 | lo, hi = 0, N_CTX
428 |
429 | offsetk_y = offset_y + lo
430 | if dtype == tl.float8e5:
431 | offsetv_y = offset_y * HEAD_DIM + lo
432 | else:
433 | offsetv_y = offset_y + lo
434 |
435 | if MASK_TYPE > 0:
436 | mask_block_ptr = tl.advance(mask_block_ptr, (0, lo))
437 |
438 | # Loop over k, v and update accumulator
439 | for start_n in tl.range(lo, hi, BLOCK_N, warp_specialize=warp_specialize):
440 | # Let the compiler know that start_n is a multiple
441 | # of BLOCK_N, so the compiler can do optimizations
442 | start_n = tl.multiple_of(start_n, BLOCK_N)
443 |
444 | # -- Compute qk ----
445 | k = desc_k.load([offsetk_y, 0]).T
446 | qk = tl.dot(q, k)
447 | if ENABLE_JVP:
448 | t_k = desc_k_t.load([offsetk_y, 0]).T
449 | t_qk = tl.dot(t_q, k) + tl.dot(q, t_k)
450 |
451 | # Load and apply attention mask if provided (before scaling for STAGE != 2)
452 | if MASK_TYPE > 0:
453 | mask = tl.load(mask_block_ptr)
454 | if MASK_TYPE == 1: # Boolean mask
455 |                 # Convert boolean to additive mask: True (attend) -> 0, False (ignore) -> MASK_CONST
456 | qk = qk + tl.where(mask == 1, 0.0, MASK_CONST)
457 | if ENABLE_JVP:
458 | t_qk = tl.where(mask == 1, t_qk, 0.0)
459 |
460 | elif MASK_TYPE == 2: # Additive mask
461 | qk = qk + mask
462 |
463 | m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
464 | qk = qk * qk_scale - m_ij[:, None]
465 |
466 | # For causal attention (STAGE == 2)
467 | elif STAGE == 2:
468 | mask = offs_m[:, None] >= (start_n + offs_n[None, :])
469 | qk = qk * qk_scale + tl.where(mask, 0.0, MASK_CONST)
470 | m_ij = tl.maximum(m_i, tl.max(qk, 1))
471 | qk -= m_ij[:, None]
472 |
473 | # No masking case (MASK_TYPE == 0 and STAGE != 2)
474 | else:
475 | m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
476 | qk = qk * qk_scale - m_ij[:, None]
477 |
478 | p = tl.math.exp2(qk)
479 |
480 | if MASK_TYPE == 1 or STAGE == 2:
481 | # Account for fully masked sequence blocks
482 | p = tl.where(mask == 1, p, 0.0)
483 |
484 | # Apply dropout if enabled
485 | if ENABLE_DROPOUT:
486 | philox_offset = philox_offset_base + start_m * N_CTX + start_n
487 | dropout_mask, dropout_scale = create_dropout_mask(
488 | philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, N_CTX
489 | )
490 | p = p * dropout_mask.to(dtype) * dropout_scale
491 |
492 | # -- Compute correction factor
493 | alpha = tl.math.exp2(m_i - m_ij)
494 | l_ij = tl.sum(p, 1)
495 |
496 | # -- Update output accumulator --
497 | if warp_specialize and (BLOCK_M == 128 and HEAD_DIM == 128):
498 | BM: tl.constexpr = acc.shape[0]
499 | BN: tl.constexpr = acc.shape[1]
500 | acc0, acc1 = acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split()
501 | acc0 = acc0 * alpha[:, None]
502 | acc1 = acc1 * alpha[:, None]
503 | acc = tl.join(acc0, acc1).permute(0, 2, 1).reshape([BM, BN])
504 | else:
505 | acc = acc * alpha[:, None]
506 |
507 | # Prepare p and v for the dot
508 | if dtype == tl.float8e5:
509 | v = desc_v.load([0, offsetv_y]).T
510 | else:
511 | v = desc_v.load([offsetv_y, 0])
512 |
513 | p = p.to(dtype)
514 |
515 | if ENABLE_JVP:
516 | p_tqk = p * (t_qk * sm_scale)
517 |
518 |             # NOTE: This non-transposed v for FP8 is presumably only supported on Blackwell
519 | if warp_specialize and (BLOCK_M == 128 and HEAD_DIM == 128):
520 | BM: tl.constexpr = g_acc.shape[0]
521 | BN: tl.constexpr = g_acc.shape[1]
522 | g_acc0, g_acc1 = g_acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split()
523 | g_acc0 = g_acc0 * alpha[:, None]
524 | g_acc1 = g_acc1 * alpha[:, None]
525 | g_acc = tl.join(g_acc0, g_acc1).permute(0, 2, 1).reshape([BM, BN])
526 | else:
527 | g_acc = g_acc * alpha[:, None]
528 |
529 | g_acc = tl.dot(p_tqk.to(v.dtype), v, g_acc)
530 | mu_ij = tl.sum(p_tqk, 1)
531 | mu_i = mu_i * alpha + mu_ij
532 | t_v = desc_v_t.load([offsetv_y, 0])
533 | p_tv_acc = p_tv_acc * alpha[:, None] + tl.dot(p, t_v.to(dtype)).to(t_v.dtype)
534 |
535 |         # NOTE: This non-transposed v for FP8 is only supported on Blackwell
536 | acc = tl.dot(p, v.to(dtype), acc).to(acc.dtype)
537 |
538 | # Update m_i and l_i
539 | # Place this at the end of the loop to reduce register pressure
540 | l_i = l_i * alpha + l_ij
541 | m_i = m_ij
542 | offsetk_y += BLOCK_N
543 | offsetv_y += BLOCK_N
544 |
545 | if MASK_TYPE > 0:
546 | mask_block_ptr = tl.advance(mask_block_ptr, (0, BLOCK_N))
547 |
548 | return acc, g_acc, l_i, m_i, mu_i, p_tv_acc
549 |
550 |
551 | def _host_descriptor_pre_hook(nargs):
552 | """Pre-hook to set up tensor descriptors for the attention kernel.
553 |
554 | Args:
555 | nargs: A dictionary of kernel arguments.
556 | """
557 | BLOCK_M = nargs["BLOCK_M"]
558 | BLOCK_N = nargs["BLOCK_N"]
559 | HEAD_DIM = nargs["HEAD_DIM"]
560 | if not supports_tma() or not isinstance(nargs["desc_q"], TensorDescriptor):
561 | return
562 | nargs["desc_q"].block_shape = [BLOCK_M, HEAD_DIM]
563 | if nargs["FP8_OUTPUT"]:
564 | nargs["desc_v"].block_shape = [HEAD_DIM, BLOCK_N]
565 | else:
566 | nargs["desc_v"].block_shape = [BLOCK_N, HEAD_DIM]
567 | nargs["desc_k"].block_shape = [BLOCK_N, HEAD_DIM]
568 | nargs["desc_o"].block_shape = [BLOCK_M, HEAD_DIM]
569 |
570 |
571 | if is_hip():
572 | NUM_STAGES_OPTIONS = [1]
573 | elif supports_host_descriptor():
574 | NUM_STAGES_OPTIONS = [2, 3, 4]
575 | else:
576 | NUM_STAGES_OPTIONS = [2, 3, 4]
577 |
578 | configs = [
579 | triton.Config(
580 | {"BLOCK_M": BM, "BLOCK_N": BN},
581 | num_stages=s,
582 | num_warps=w,
583 | pre_hook=_host_descriptor_pre_hook,
584 | )
585 | for BM in [MIN_SEQUENCE_LENGTH, 64, 128]
586 | for BN in [MIN_SEQUENCE_LENGTH, 64, 128]
587 | for s in NUM_STAGES_OPTIONS
588 | for w in [4, 8]
589 | ]
590 | if "PYTEST_VERSION" in os.environ:
591 | # Use a single config in testing for reproducibility
592 | configs = [
593 | triton.Config(
594 | dict(BLOCK_M=128, BLOCK_N=64),
595 | num_stages=2,
596 | num_warps=4,
597 | pre_hook=_host_descriptor_pre_hook,
598 | ),
599 | ]
600 |
601 |
602 | def keep(conf):
603 |     """Decide whether to keep an autotune configuration (returns True to keep it).
604 |
605 | Args:
606 | conf: A configuration object.
607 | """
608 | BLOCK_M = conf.kwargs["BLOCK_M"]
609 | BLOCK_N = conf.kwargs["BLOCK_N"]
610 | return not (BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8)
611 |
612 |
613 | def prune_invalid_configs(configs, named_args, **kwargs):
614 | """Prune configurations that are invalid based on certain criteria.
615 |
616 | Args:
617 | configs: A list of configuration objects.
618 | named_args: A dictionary of named arguments.
619 | **kwargs: Additional keyword arguments.
620 |
621 | Returns:
622 | A list of valid configuration objects.
623 | """
624 | N_CTX = kwargs["N_CTX"]
625 |
626 | if N_CTX == MIN_SEQUENCE_LENGTH:
627 | # Filter out configs where BLOCK_M > MIN_SEQUENCE_LENGTH
628 | return [conf for conf in configs if conf.kwargs.get("BLOCK_M", 0) <= MIN_SEQUENCE_LENGTH]
629 |
630 | # Filter out configs where BLOCK_M > N_CTX or BLOCK_M <= MIN_SEQUENCE_LENGTH, as
631 | # BLOCK_M = MIN_SEQUENCE_LENGTH often leads to reduced numerical accuracy for longer sequences
632 | # TODO: Find out why this occurs
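    # For example, with N_CTX == 128 this keeps only configs whose BLOCK_M is 64 or 128.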
633 | return [
634 | conf for conf in configs if MIN_SEQUENCE_LENGTH < conf.kwargs.get("BLOCK_M", 0) <= N_CTX
635 | ]
636 |
637 |
638 | @triton.jit
639 | def _maybe_make_tensor_desc(desc_or_ptr, shape, strides, block_shape):
640 | """Maybe make a tensor descriptor from a pointer.
641 |
642 | Args:
643 | desc_or_ptr: The input tensor or pointer.
644 | shape: The shape of the tensor.
645 | strides: The strides of the tensor.
646 | block_shape: The block shape of the tensor.
647 |
648 | Returns:
649 | A tensor descriptor.
650 | """
651 | if isinstance(desc_or_ptr, tl.tensor_descriptor):
652 | return desc_or_ptr
653 | else:
654 | return tl.make_tensor_descriptor(desc_or_ptr, shape, strides, block_shape)
655 |
656 |
657 | # @triton.autotune(
658 | # configs=list(filter(keep, configs)),
659 | # key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT", "warp_specialize"],
660 | # prune_configs_by={"early_config_prune": prune_invalid_configs},
661 | # )
662 | @triton.jit
663 | def _attn_fwd(
664 | Q,
665 | K,
666 | V,
667 | T_Q,
668 | T_K,
669 | T_V, #
670 | sm_scale,
671 | M,
672 | Out,
673 | T_Out, #
674 | Mask, # Mask tensor
675 | dropout_p, # Dropout probability
676 | philox_seed, # RNG seed for dropout
677 | stride_qz,
678 | stride_qh,
679 | stride_qm,
680 | stride_qk, #
681 | stride_kz,
682 | stride_kh,
683 | stride_kn,
684 | stride_kk, #
685 | stride_vz,
686 | stride_vh,
687 | stride_vk,
688 | stride_vn, #
689 | stride_tqz,
690 | stride_tqh,
691 | stride_tqm,
692 | stride_tqk, #
693 | stride_tkz,
694 | stride_tkh,
695 | stride_tkn,
696 | stride_tkk, #
697 | stride_tvz,
698 | stride_tvh,
699 | stride_tvk,
700 | stride_tvn, #
701 | stride_oz,
702 | stride_oh,
703 | stride_om,
704 | stride_on, #
705 | stride_toz,
706 | stride_toh,
707 | stride_tom,
708 | stride_ton, #
709 | stride_mz, # Mask stride
710 | stride_mh, # Mask stride
711 | stride_mm, # Mask stride
712 | stride_mn, # Mask stride
713 | Z,
714 | H,
715 | N_CTX, #
716 | HEAD_DIM: tl.constexpr, #
717 | BLOCK_M: tl.constexpr, #
718 | BLOCK_N: tl.constexpr, #
719 | FP8_OUTPUT: tl.constexpr, #
720 | STAGE: tl.constexpr, #
721 | warp_specialize: tl.constexpr, #
722 | ENABLE_JVP: tl.constexpr, #
723 | ENABLE_DROPOUT: tl.constexpr, #
724 | MASK_TYPE: tl.constexpr, #
725 | ):
726 | """Forward attention computation.
727 |
728 | Args:
729 | Q: Query tensor.
730 | K: Key tensor.
731 | V: Value tensor.
732 |         T_Q: Tangent of the query tensor.
733 |         T_K: Tangent of the key tensor.
734 |         T_V: Tangent of the value tensor.
735 |         sm_scale: Softmax scale factor applied to the query-key dot products.
736 |         M: Output buffer for the per-row log-sum-exp statistics (saved for the backward pass).
737 |         Out: Output tensor.
738 |         T_Out: Tangent of the output tensor.
739 | Mask: Attention mask tensor.
740 | dropout_p: Dropout probability.
741 | philox_seed: Seed for Philox RNG.
742 | stride_qz: Stride for query z dimension.
743 | stride_qh: Stride for query h dimension.
744 | stride_qm: Stride for query m dimension.
745 | stride_qk: Stride for query k dimension.
746 | stride_kz: Stride for key z dimension.
747 | stride_kh: Stride for key h dimension.
748 | stride_kn: Stride for key n dimension.
749 | stride_kk: Stride for key k dimension.
750 | stride_vz: Stride for value z dimension.
751 | stride_vh: Stride for value h dimension.
752 | stride_vk: Stride for value k dimension.
753 | stride_vn: Stride for value n dimension.
754 |         stride_tqz: Stride for tangent query z dimension.
755 |         stride_tqh: Stride for tangent query h dimension.
756 |         stride_tqm: Stride for tangent query m dimension.
757 |         stride_tqk: Stride for tangent query k dimension.
758 |         stride_tkz: Stride for tangent key z dimension.
759 |         stride_tkh: Stride for tangent key h dimension.
760 |         stride_tkn: Stride for tangent key n dimension.
761 |         stride_tkk: Stride for tangent key k dimension.
762 |         stride_tvz: Stride for tangent value z dimension.
763 |         stride_tvh: Stride for tangent value h dimension.
764 |         stride_tvk: Stride for tangent value k dimension.
765 |         stride_tvn: Stride for tangent value n dimension.
766 | stride_oz: Stride for output z dimension.
767 | stride_oh: Stride for output h dimension.
768 | stride_om: Stride for output m dimension.
769 | stride_on: Stride for output n dimension.
770 |         stride_toz: Stride for tangent output z dimension.
771 |         stride_toh: Stride for tangent output h dimension.
772 |         stride_tom: Stride for tangent output m dimension.
773 |         stride_ton: Stride for tangent output n dimension.
774 | stride_mz: Stride for mask z dimension.
775 | stride_mh: Stride for mask h dimension.
776 | stride_mm: Stride for mask m dimension.
777 | stride_mn: Stride for mask n dimension.
778 |         Z: Batch size.
779 |         H: Number of attention heads.
780 |         N_CTX: Sequence length (number of context tokens).
781 | HEAD_DIM: Head dimension.
782 | BLOCK_M: Block size for the queries.
783 | BLOCK_N: Block size for the keys/values.
784 | FP8_OUTPUT: FP8 output flag.
785 |         STAGE: Computation stage (3 for causal attention, 1 otherwise).
786 | warp_specialize: Warp specialization flag.
787 | ENABLE_JVP: Enable JVP flag.
788 | ENABLE_DROPOUT: Enable dropout flag.
789 | MASK_TYPE: Mask type (0: no mask, 1: boolean, 2: additive).
790 | """
791 | tl.static_assert(BLOCK_N <= HEAD_DIM) # N = KV
792 |
793 | # Prepare metadata and indices
794 | dtype = tl.float8e5 if FP8_OUTPUT else tl.float32 # For dot products
795 | start_m = tl.program_id(0) # Which block (in the input query sequence) to process
796 | off_hz = tl.program_id(
797 | 1
798 | ) # Which head and batch element to process, with a program being a single head of a single batch element
799 | off_z = (
800 | off_hz // H
801 | ) # Which batch element this program is assigned to (n.b., each batch element has H heads)
802 | off_h = off_hz % H # The position of the head to process in the batch
803 |
804 | # NOTE: This allows one to get the (N_CTX, HEAD_DIM) block in Q, K, V by indexing it by batch and head
805 | qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
806 |
807 | # Initialize block pointers
808 | Q_block_ptr = tl.make_block_ptr(
809 | base=Q + qvk_offset,
810 | shape=(N_CTX, HEAD_DIM),
811 | strides=(stride_qm, stride_qk),
812 | offsets=(start_m * BLOCK_M, 0), # M = Q
813 | block_shape=(BLOCK_M, HEAD_DIM),
814 | order=(1, 0),
815 | )
816 | v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0)
817 | V_block_ptr = tl.make_block_ptr(
818 | base=V + qvk_offset,
819 | shape=(N_CTX, HEAD_DIM),
820 | strides=(stride_vk, stride_vn),
821 | offsets=(0, 0),
822 | block_shape=(BLOCK_N, HEAD_DIM),
823 | order=v_order,
824 | )
825 | K_block_ptr = tl.make_block_ptr(
826 | base=K + qvk_offset,
827 | shape=(HEAD_DIM, N_CTX),
828 | strides=(
829 | stride_kk,
830 | stride_kn,
831 | ), # NOTE: We invert the strides of K to get its matrix transpose K^T
832 | offsets=(0, 0),
833 | block_shape=(HEAD_DIM, BLOCK_N),
834 | order=(0, 1),
835 | )
836 | O_block_ptr = tl.make_block_ptr(
837 | base=Out + qvk_offset,
838 | shape=(N_CTX, HEAD_DIM),
839 | strides=(stride_om, stride_on),
840 | offsets=(start_m * BLOCK_M, 0),
841 | block_shape=(BLOCK_M, HEAD_DIM),
842 | order=(1, 0),
843 | )
844 |
845 | # Initialize block pointer for the mask, if provided
846 | if MASK_TYPE > 0:
847 | mask_offset = off_z.to(tl.int64) * stride_mz + off_h.to(tl.int64) * stride_mh
848 | mask_block_ptr = tl.make_block_ptr(
849 | base=Mask + mask_offset,
850 | shape=(N_CTX, N_CTX),
851 | strides=(stride_mm, stride_mn),
852 | offsets=(start_m * BLOCK_M, 0),
853 | block_shape=(BLOCK_M, BLOCK_N),
854 | order=(1, 0),
855 | )
856 | else:
857 | mask_block_ptr = None
858 |
859 | # Initialize dropout offset for this block
860 | philox_offset_base = off_hz * N_CTX * N_CTX
861 |
862 | # Initialize offsets for the query tokens to process
863 | offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
864 | offs_n = tl.arange(0, BLOCK_N)
865 |
866 | # Initialize accumulator pointers:
867 | # m, the running maximum (one for each query)
868 | m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
869 | # l, the running sum (one for each query as we sum the attention scores by rows)
870 | l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
871 | # acc, the output accumulator (one vector for each query)
872 | acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
873 |
874 | if ENABLE_JVP:
875 | # NOTE: It's extremely likely we could just reuse qvk_offset, but this seems cheap so whatever
876 | t_qvk_offset = off_z.to(tl.int64) * stride_tqz + off_h.to(tl.int64) * stride_tqh
877 | T_Q_block_ptr = tl.make_block_ptr(
878 | base=T_Q + t_qvk_offset,
879 | shape=(N_CTX, HEAD_DIM),
880 | strides=(stride_tqm, stride_tqk),
881 | offsets=(start_m * BLOCK_M, 0),
882 | block_shape=(BLOCK_M, HEAD_DIM),
883 | order=(1, 0),
884 | )
885 | # NOTE: Could probably just reuse v_order here
886 | t_v_order: tl.constexpr = (0, 1) if T_V.dtype.element_ty == tl.float8e5 else (1, 0)
887 | T_V_block_ptr = tl.make_block_ptr(
888 | base=T_V + t_qvk_offset,
889 | shape=(N_CTX, HEAD_DIM),
890 | strides=(stride_tvk, stride_tvn),
891 | offsets=(0, 0),
892 | block_shape=(BLOCK_N, HEAD_DIM),
893 | order=t_v_order,
894 | )
895 | T_K_block_ptr = tl.make_block_ptr(
896 | base=T_K + t_qvk_offset,
897 | shape=(HEAD_DIM, N_CTX),
898 | strides=(
899 | stride_tkk,
900 | stride_tkn,
901 | ), # NOTE: We invert the strides of tangent K (k_t) to get its matrix transpose K^T
902 | offsets=(0, 0),
903 | block_shape=(HEAD_DIM, BLOCK_N),
904 | order=(0, 1),
905 | )
906 | T_O_block_ptr = tl.make_block_ptr(
907 | base=T_Out + t_qvk_offset,
908 | shape=(N_CTX, HEAD_DIM),
909 | strides=(stride_tom, stride_ton),
910 | offsets=(start_m * BLOCK_M, 0),
911 | block_shape=(BLOCK_M, HEAD_DIM),
912 | order=(1, 0),
913 | )
914 | # Load q_t: It will stay in SRAM throughout.
915 | t_q = tl.load(T_Q_block_ptr)
916 | g_acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
917 | mu_i = tl.zeros([BLOCK_M], dtype=tl.float32)
918 | p_tv_acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
919 | else:
920 | t_q = None
921 | T_V_block_ptr = None
922 | T_K_block_ptr = None
923 |         # Allocate minimal dummy tensors to keep the return signature of _attn_fwd_inner consistent
924 | g_acc = tl.zeros([1, 1], dtype=tl.float32)
925 | mu_i = tl.zeros([1], dtype=tl.float32)
926 | p_tv_acc = tl.zeros([1, 1], dtype=tl.float32)
927 |
928 | # Prepare scales
929 | qk_scale = sm_scale
930 |     qk_scale = qk_scale * 1.44269504 # log2(e) = 1/ln(2), so exp(x) can be computed as exp2(x * log2(e))
931 |
932 | # Load q: It will stay in SRAM throughout.
933 | q = tl.load(Q_block_ptr)
934 |
935 | # Stage: 3 if causal, else 1
936 | if STAGE == 1 or STAGE == 3:
937 | # NOTE: This step runs for non-causal attention or for the
938 | # blocks to the left of the diagonal for causal attention
939 | acc, g_acc, l_i, m_i, mu_i, p_tv_acc = _attn_fwd_inner(
940 | acc,
941 | g_acc,
942 | l_i,
943 | m_i, #
944 | mu_i,
945 | p_tv_acc, #
946 | q,
947 | t_q, #
948 | K_block_ptr,
949 | V_block_ptr, #
950 | T_K_block_ptr,
951 | T_V_block_ptr, #
952 | mask_block_ptr,
953 | dropout_p,
954 | philox_seed,
955 | philox_offset_base,
956 | dtype,
957 | start_m,
958 | qk_scale,
959 | sm_scale, #
960 | BLOCK_M,
961 | HEAD_DIM,
962 | BLOCK_N, #
963 | 4 - STAGE,
964 | offs_m,
965 | offs_n,
966 | N_CTX, #
967 | warp_specialize,
968 | ENABLE_JVP,
969 | ENABLE_DROPOUT,
970 | MASK_TYPE,
971 | )
972 |
973 | if STAGE == 3:
974 |         # NOTE: This step runs for the blocks on the
975 |         # diagonal for causal attention
976 | acc, g_acc, l_i, m_i, mu_i, p_tv_acc = _attn_fwd_inner(
977 | acc,
978 | g_acc, #
979 | l_i,
980 | m_i, #
981 | mu_i,
982 | p_tv_acc, #
983 | q,
984 | t_q, #
985 | K_block_ptr,
986 | V_block_ptr, #
987 | T_K_block_ptr,
988 | T_V_block_ptr, #
989 | mask_block_ptr,
990 | dropout_p,
991 | philox_seed,
992 | philox_offset_base,
993 | dtype,
994 | start_m,
995 | qk_scale,
996 | sm_scale, #
997 | BLOCK_M,
998 | HEAD_DIM,
999 | BLOCK_N, #
1000 | 2,
1001 | offs_m,
1002 | offs_n,
1003 | N_CTX, #
1004 | warp_specialize,
1005 | ENABLE_JVP,
1006 | ENABLE_DROPOUT,
1007 | MASK_TYPE,
1008 | )
1009 |
1010 | # Epilogue
1011 | empty_mask = l_i == 0.0
1012 | if empty_mask.sum() > 0:
1013 | l_i = tl.where(
1014 | empty_mask, 1.0, l_i
1015 | ) # NOTE: This happens if the entire block is masked out.
1016 |
1017 | m_i = m_i + tl.where(
1018 | # NOTE: This is needed to compute the logsumexp for the backward pass.
1019 | empty_mask,
1020 | 0.0,
1021 | tl.math.log2(l_i),
1022 | )
1023 |
1024 | acc = acc / l_i[:, None]
1025 | m_ptrs = M + off_hz * N_CTX + offs_m
1026 | tl.store(m_ptrs, m_i)
1027 | tl.store(O_block_ptr, acc.to(Out.type.element_ty))
1028 |
1029 | # If JVP is enabled, compute and store the output tangent
1030 | if ENABLE_JVP:
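     |         # With P = softmax(S) and O = P @ V, the output tangent is t_O = t_P @ V + P @ t_V,
     |         # where t_P = P * (t_S - rowsum(P * t_S)). g_acc, mu_i, and p_tv_acc hold the
     |         # corresponding unnormalized accumulations, so dividing by l_i below recovers
     |         # t_P @ V + P @ t_V.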
1031 | t_p_v = g_acc / l_i[:, None] - (mu_i / l_i)[:, None] * acc
1032 | t_y_out = t_p_v + p_tv_acc / l_i[:, None]
1033 | tl.store(T_O_block_ptr, t_y_out.to(T_Out.type.element_ty))
1034 |
1035 |
1036 | def _tma_pre_hook(nargs):
1037 |     """Pre-hook for TMA (Tensor Memory Accelerator) optimization.
1038 |
1039 | Args:
1040 | nargs: A dictionary containing the kernel arguments.
1041 | """
1042 | BLOCK_M = nargs["BLOCK_M"]
1043 | BLOCK_N = nargs["BLOCK_N"]
1044 | HEAD_DIM = nargs["HEAD_DIM"]
1045 | nargs["desc_q"].block_shape = [BLOCK_M, HEAD_DIM]
1046 | nargs["desc_v"].block_shape = [BLOCK_N, HEAD_DIM]
1047 | nargs["desc_k"].block_shape = [BLOCK_N, HEAD_DIM]
1048 | nargs["desc_o"].block_shape = [BLOCK_M, HEAD_DIM]
1049 |
1050 |
1051 | # Auto-tuning is disabled by default to keep compile times short. Keeping the
1052 | # config list below (and commenting out the fixed launch parameters) makes
1053 | # re-tuning convenient.
1054 | configs_tma = [
1055 | triton.Config(
1056 | {"BLOCK_M": BM, "BLOCK_N": BN},
1057 | num_stages=s,
1058 | num_warps=w,
1059 | pre_hook=_tma_pre_hook,
1060 | )
1061 | for BM in [MIN_SEQUENCE_LENGTH, 64, 128, 256]
1062 | for BN in [MIN_SEQUENCE_LENGTH, 64, 128]
1063 | for s in [3, 4, 5]
1064 | for w in [4, 8]
1065 | ]
1066 |
1067 |
1068 | def keep_tma(conf):
1069 |     """Check whether to keep a TMA (Tensor Memory Accelerator) auto-tuning configuration.
1070 | 
1071 |     Args:
1072 |         conf: The configuration to check.
     | 
     |     Returns:
     |         True if the configuration should be kept on this device, False otherwise.
1073 |     """
1074 | BLOCK_M = conf.kwargs["BLOCK_M"]
1075 | BLOCK_N = conf.kwargs["BLOCK_N"]
1076 | return not (
1077 | is_cuda()
1078 | and torch.cuda.get_device_capability()[0] == 9
1079 | and BLOCK_M * BLOCK_N < 128 * 128
1080 | and conf.num_warps == 8
1081 | )
1082 |
1083 |
1084 | # @triton.autotune(
1085 | # configs=list(filter(keep_tma, configs_tma)),
1086 | # key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT", "warp_specialize"],
1087 | # prune_configs_by={"early_config_prune": prune_invalid_configs},
1088 | # )
1089 | @triton.jit
1090 | def _attn_fwd_tma(
1091 | sm_scale,
1092 | M, #
1093 | Z,
1094 | H, #
1095 | desc_q,
1096 | desc_k,
1097 | desc_v, #
1098 | desc_q_t,
1099 | desc_k_t,
1100 | desc_v_t, #
1101 | desc_o,
1102 | desc_o_t, #
1103 | Mask, # Mask tensor
1104 | dropout_p, # Dropout probability
1105 | philox_seed, # RNG seed for dropout
1106 | stride_mz, # Mask stride
1107 | stride_mh, # Mask stride
1108 | stride_mm, # Mask stride
1109 | stride_mn, # Mask stride
1110 | N_CTX, #
1111 | HEAD_DIM: tl.constexpr, #
1112 | BLOCK_M: tl.constexpr, #
1113 | BLOCK_N: tl.constexpr, #
1114 | FP8_OUTPUT: tl.constexpr, #
1115 | STAGE: tl.constexpr, #
1116 | warp_specialize: tl.constexpr, #
1117 | ENABLE_JVP: tl.constexpr, #
1118 | ENABLE_DROPOUT: tl.constexpr, #
1119 | MASK_TYPE: tl.constexpr, #
1120 | ):
1121 |     """Forward attention computation with TMA (Tensor Memory Accelerator) support.
1122 |
1123 | Args:
1124 | sm_scale: Scale factor for the softmax.
1125 |         M: Output buffer for the per-query softmax statistics (log-sum-exp), used by the backward pass.
1126 |         Z: Batch size.
1127 | H: Number of heads in the multi-head attention.
1128 | desc_q: Descriptor for the query tensor.
1129 | desc_k: Descriptor for the key tensor.
1130 | desc_v: Descriptor for the value tensor.
1131 |         desc_q_t: Descriptor for the query tangent tensor (JVP).
1132 |         desc_k_t: Descriptor for the key tangent tensor (JVP).
1133 |         desc_v_t: Descriptor for the value tangent tensor (JVP).
1134 | desc_o: Descriptor for the output tensor.
1135 |         desc_o_t: Descriptor for the output tangent tensor (JVP).
1136 | Mask: Attention mask tensor.
1137 | dropout_p: Dropout probability.
1138 | philox_seed: Seed for the Philox random number generator.
1139 | stride_mz: Stride for the mask in the z dimension.
1140 | stride_mh: Stride for the mask in the h dimension.
1141 | stride_mm: Stride for the mask in the m dimension.
1142 | stride_mn: Stride for the mask in the n dimension.
1143 | N_CTX: Context length.
1144 | HEAD_DIM: Dimension of each head.
1145 | BLOCK_M: Block size for the queries.
1146 | BLOCK_N: Block size for the keys/values.
1147 | FP8_OUTPUT: Flag indicating if FP8 output is used.
1148 | STAGE: Stage of the computation.
1149 | warp_specialize: Flag indicating if warp specialization is used.
1150 | ENABLE_JVP: Flag indicating if JVP (Jacobian-vector product) is enabled.
1151 | ENABLE_DROPOUT: Flag indicating if dropout is enabled.
1152 | MASK_TYPE: Type of mask used (0: no mask, 1: boolean, 2: additive).
1153 | """
1154 | tl.static_assert(BLOCK_N <= HEAD_DIM) # N = KV
1155 |
1156 | # Prepare metadata and indices
1157 | dtype = tl.float8e5 if FP8_OUTPUT else tl.float32 # For dot products
1158 | start_m = tl.program_id(0) # Which block (in the input query sequence) to process
1159 | off_hz = tl.program_id(
1160 | 1
1161 | ) # Which head and batch element to process, with a program being a single head of a single batch element
1162 | off_z = (
1163 | off_hz // H
1164 | ) # Which batch element this program is assigned to (n.b., each batch element has H heads)
1165 | off_h = off_hz % H # The position of the head to process in the batch
1166 |
1167 | # Initialize tensor descriptors
1168 | y_dim = Z * H * N_CTX
1169 | desc_q = _maybe_make_tensor_desc(
1170 | desc_q,
1171 | shape=[y_dim, HEAD_DIM],
1172 | strides=[HEAD_DIM, 1],
1173 | block_shape=[BLOCK_M, HEAD_DIM], # M = Q
1174 | )
1175 | if FP8_OUTPUT:
1176 | v_shape = [HEAD_DIM, y_dim]
1177 | v_strides = [N_CTX, 1]
1178 | v_block_shape = [HEAD_DIM, BLOCK_N]
1179 | else:
1180 | v_shape = [y_dim, HEAD_DIM]
1181 | v_strides = [HEAD_DIM, 1]
1182 | v_block_shape = [BLOCK_N, HEAD_DIM]
1183 | desc_v = _maybe_make_tensor_desc(
1184 | desc_v, shape=v_shape, strides=v_strides, block_shape=v_block_shape
1185 | )
1186 | desc_k = _maybe_make_tensor_desc(
1187 | desc_k,
1188 | shape=[y_dim, HEAD_DIM],
1189 | strides=[HEAD_DIM, 1],
1190 | block_shape=[BLOCK_N, HEAD_DIM],
1191 | )
1192 | desc_o = _maybe_make_tensor_desc(
1193 | desc_o,
1194 | shape=[y_dim, HEAD_DIM],
1195 | strides=[HEAD_DIM, 1],
1196 | block_shape=[BLOCK_M, HEAD_DIM],
1197 | )
1198 |
1199 | offset_y = off_z * (N_CTX * H) + off_h * N_CTX
1200 | qo_offset_y = offset_y + start_m * BLOCK_M
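     |     # NOTE: The TMA descriptors view Q/K/V/O as (Z * H * N_CTX, HEAD_DIM) matrices; offset_y selects
     |     # this program's (batch, head) slab of N_CTX rows, and qo_offset_y adds the query-block row offset.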
1201 |
1202 | # Initialize block pointer for the mask, if provided
1203 | if MASK_TYPE > 0:
1204 | mask_offset = off_z * stride_mz + off_h * stride_mh
1205 | mask_block_ptr = tl.make_block_ptr(
1206 | base=Mask + mask_offset,
1207 | shape=(N_CTX, N_CTX),
1208 | strides=(stride_mm, stride_mn),
1209 | offsets=(start_m * BLOCK_M, 0),
1210 | block_shape=(BLOCK_M, BLOCK_N),
1211 | order=(1, 0),
1212 | )
1213 | else:
1214 | mask_block_ptr = None
1215 |
1216 | # Initialize dropout offset for this block
1217 | philox_offset_base = off_hz * N_CTX * N_CTX
1218 |
1219 | # Initialize offsets for the query tokens to process
1220 | offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
1221 | offs_n = tl.arange(0, BLOCK_N)
1222 |
1223 | # Initialize accumulator pointers:
1224 | # m, the running maximum (one for each query)
1225 | m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
1226 | # l, the running sum (one for each query as we sum the attention scores by rows)
1227 | l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
1228 | # acc, the output accumulator (one vector for each query)
1229 | acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
1230 |
1231 | if ENABLE_JVP:
1232 | desc_q_t = _maybe_make_tensor_desc(
1233 | desc_q_t,
1234 | shape=[y_dim, HEAD_DIM],
1235 | strides=[HEAD_DIM, 1],
1236 | block_shape=[BLOCK_M, HEAD_DIM],
1237 | )
1238 | if FP8_OUTPUT:
1239 | t_v_shape = [HEAD_DIM, y_dim]
1240 | t_v_strides = [N_CTX, 1]
1241 | t_v_block_shape = [HEAD_DIM, BLOCK_N]
1242 | else:
1243 | t_v_shape = [y_dim, HEAD_DIM]
1244 | t_v_strides = [HEAD_DIM, 1]
1245 | t_v_block_shape = [BLOCK_N, HEAD_DIM]
1246 | desc_v_t = _maybe_make_tensor_desc(
1247 | desc_v_t, shape=t_v_shape, strides=t_v_strides, block_shape=t_v_block_shape
1248 | )
1249 | desc_k_t = _maybe_make_tensor_desc(
1250 | desc_k_t,
1251 | shape=[y_dim, HEAD_DIM],
1252 | strides=[HEAD_DIM, 1],
1253 | block_shape=[BLOCK_N, HEAD_DIM],
1254 | )
1255 | desc_o_t = _maybe_make_tensor_desc(
1256 | desc_o_t,
1257 | shape=[y_dim, HEAD_DIM],
1258 | strides=[HEAD_DIM, 1],
1259 | block_shape=[BLOCK_M, HEAD_DIM],
1260 | )
1261 | # Load t_q: It will stay in SRAM throughout.
1262 | t_q = desc_q_t.load([qo_offset_y, 0])
1263 | g_acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
1264 | mu_i = tl.zeros([BLOCK_M], dtype=tl.float32)
1265 | p_tv_acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
1266 | else:
1267 | t_q = None
1268 | desc_k_t = None
1269 | desc_v_t = None
1270 |         # Allocate minimal dummy tensors to keep the return signature of _attn_fwd_inner_tma consistent
1271 | g_acc = tl.zeros([1, 1], dtype=tl.float32)
1272 | mu_i = tl.zeros([1], dtype=tl.float32)
1273 | p_tv_acc = tl.zeros([1, 1], dtype=tl.float32)
1274 |
1275 | # Prepare scales
1276 | qk_scale = sm_scale
1277 |     qk_scale *= 1.44269504 # log2(e) = 1/ln(2), so exp(x) can be computed as exp2(x * log2(e))
1278 |
1279 | # Load q: It will stay in SRAM throughout.
1280 | q = desc_q.load([qo_offset_y, 0])
1281 |
1282 | # Stage: 3 if causal, else 1
1283 | if STAGE == 1 or STAGE == 3:
1284 | # NOTE: This step runs for non-causal attention or for the
1285 | # blocks to the left of the diagonal for causal attention
1286 | acc, g_acc, l_i, m_i, mu_i, p_tv_acc = _attn_fwd_inner_tma(
1287 | acc,
1288 | g_acc, #
1289 | l_i,
1290 | m_i, #
1291 | mu_i,
1292 | p_tv_acc, #
1293 | q,
1294 | t_q, #
1295 | desc_k,
1296 | desc_v, #
1297 | desc_k_t,
1298 | desc_v_t, #
1299 | offset_y,
1300 | mask_block_ptr,
1301 | dropout_p,
1302 | philox_seed,
1303 | philox_offset_base,
1304 | dtype,
1305 | start_m,
1306 | qk_scale,
1307 | sm_scale, #
1308 | BLOCK_M,
1309 | HEAD_DIM,
1310 | BLOCK_N, #
1311 | 4 - STAGE,
1312 | offs_m,
1313 | offs_n,
1314 | N_CTX, #
1315 | warp_specialize,
1316 | ENABLE_JVP,
1317 | ENABLE_DROPOUT,
1318 | MASK_TYPE,
1319 | )
1320 |
1321 | if STAGE == 3:
1322 |         # NOTE: This step runs for the blocks on the
1323 |         # diagonal for causal attention
1324 | acc, g_acc, l_i, m_i, mu_i, p_tv_acc = _attn_fwd_inner_tma(
1325 | acc,
1326 | g_acc, #
1327 | l_i,
1328 | m_i, #
1329 | mu_i,
1330 | p_tv_acc, #
1331 | q,
1332 | t_q, #
1333 | desc_k,
1334 | desc_v, #
1335 | desc_k_t,
1336 | desc_v_t, #
1337 | offset_y,
1338 | mask_block_ptr,
1339 | dropout_p,
1340 | philox_seed,
1341 | philox_offset_base,
1342 | dtype,
1343 | start_m,
1344 | qk_scale,
1345 | sm_scale, #
1346 | BLOCK_M,
1347 | HEAD_DIM,
1348 | BLOCK_N, #
1349 | 2,
1350 | offs_m,
1351 | offs_n,
1352 | N_CTX, #
1353 | warp_specialize,
1354 | ENABLE_JVP,
1355 | ENABLE_DROPOUT,
1356 | MASK_TYPE,
1357 | )
1358 |
1359 | # Epilogue
1360 | empty_mask = l_i == 0.0
1361 | if empty_mask.sum() > 0:
1362 | l_i = tl.where(
1363 | empty_mask, 1.0, l_i
1364 | ) # NOTE: This happens if the entire block is masked out.
1365 |
1366 | m_i = m_i + tl.where(
1367 | # NOTE: This is needed to compute the logsumexp for the backward pass.
1368 | empty_mask,
1369 | 0.0,
1370 | tl.math.log2(l_i),
1371 | )
1372 |
1373 | acc = acc / l_i[:, None]
1374 | m_ptrs = M + off_hz * N_CTX + offs_m
1375 | tl.store(m_ptrs, m_i)
1376 | desc_o.store([qo_offset_y, 0], acc.to(desc_o.dtype))
1377 |
1378 | if ENABLE_JVP:
1379 | t_p_v = g_acc / l_i[:, None] - (mu_i / l_i)[:, None] * acc
1380 | t_y_out = t_p_v + p_tv_acc / l_i[:, None]
1381 | desc_o_t.store([qo_offset_y, 0], t_y_out.to(desc_o_t.dtype))
1382 |
1383 |
1384 | @triton.jit
1385 | def _attn_bwd_preprocess(
1386 | O, DO, Delta, N_CTX, BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr # noqa: E741
1387 | ):
1388 | """Preprocess output deltas for the backward attention pass.
1389 |
1390 | Args:
1391 | O: Output tensor.
1392 | DO: Gradient of the output tensor.
1393 | Delta: Accumulated gradients.
1394 | N_CTX: Context length.
1395 | BLOCK_M: Block size for M dimension.
1396 | HEAD_DIM: Head dimension size.
1397 | """
1398 | # Collect sequence, batch, and head indices
1399 | off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
1400 | off_hz = tl.program_id(1)
1401 | off_n = tl.arange(0, HEAD_DIM)
1402 |
1403 | # Load outputs and gradients
1404 | o = tl.load(O + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :])
1405 | do = tl.load(DO + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]).to(
1406 | tl.float32
1407 | )
1408 | delta = tl.sum(o * do, axis=1)
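     |     # NOTE: delta_i = sum_d O[i, d] * dO[i, d]; the dK/dV and dQ loops use it as the
     |     # row-wise correction term in dS = P * (dP - delta).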
1409 |
1410 | # Write-back the intermediate delta results
1411 | tl.store(Delta + off_hz * N_CTX + off_m, delta)
1412 |
1413 |
1414 | # The main inner-loop logic for computing dK and dV.
1415 | @triton.jit
1416 | def _attn_bwd_dkdv(
1417 | dk,
1418 | dv, #
1419 | Q,
1420 | k,
1421 | v,
1422 | DO, #
1423 | M,
1424 | D, #
1425 | # shared by Q/K/V/DO.
1426 | stride_tok,
1427 | stride_d, #
1428 | N_CTX,
1429 | BLOCK_M1: tl.constexpr, #
1430 | BLOCK_N1: tl.constexpr, #
1431 | HEAD_DIM: tl.constexpr, #
1432 | # Filled in by the wrapper.
1433 | start_n,
1434 | start_m,
1435 | num_steps, #
1436 | CAUSAL_MASKING: tl.constexpr,
1437 | # Args for masking/dropout
1438 | mask_ptr,
1439 | mask_stride_tok1,
1440 | mask_stride_tok2,
1441 | MASK_TYPE: tl.constexpr,
1442 | dropout_p,
1443 | philox_seed,
1444 | philox_offset_base,
1445 | ENABLE_DROPOUT: tl.constexpr,
1446 | MASK_CONST: tl.constexpr = MASK_CONST,
1447 | ):
1448 | """The main inner-loop logic for computing dK and dV.
1449 |
1450 | Args:
1451 | dk: Gradient of the key tensor.
1452 | dv: Gradient of the value tensor.
1453 | Q: Query tensor.
1454 | k: Key tensor.
1455 | v: Value tensor.
1456 | DO: Gradient of the output tensor.
1457 |         M: Per-query softmax statistics (log-sum-exp) from the forward pass.
1458 |         D: Delta tensor (row-wise sum of O * dO).
1459 | stride_tok: Stride for the token dimension.
1460 | stride_d: Stride for the head dimension.
1461 | N_CTX: Context length.
1462 | BLOCK_M1: Block size for M dimension.
1463 | BLOCK_N1: Block size for N dimension.
1464 | HEAD_DIM: Head dimension size.
1465 | start_n: Starting index for N dimension.
1466 | start_m: Starting index for M dimension.
1467 | num_steps: Number of steps to unroll.
1468 | CAUSAL_MASKING: Flag for causal masking.
1469 | mask_ptr: Pointer to the mask tensor.
1470 | mask_stride_tok1: Stride for the third (row) dimension of the mask tensor.
1471 | mask_stride_tok2: Stride for the fourth (column) dimension of the mask tensor.
1472 | MASK_TYPE: Type of masking (0: no mask, 1: boolean mask,
1473 | 2: additive mask).
1474 | dropout_p: Dropout probability.
1475 | philox_seed: Seed for Philox RNG.
1476 | philox_offset_base: Base offset for Philox RNG.
1477 | ENABLE_DROPOUT: Flag to enable dropout.
1478 | MASK_CONST: Constant used for masking.
1479 |
1480 | Returns:
1481 | dk: Gradient of the key tensor.
1482 | dv: Gradient of the value tensor.
1483 | """
1484 | # Initialize pointers for Q and DO
1485 | offs_m = start_m + tl.arange(0, BLOCK_M1)
1486 | offs_n = start_n + tl.arange(0, BLOCK_N1)
1487 | offs_h = tl.arange(0, HEAD_DIM)
1488 | qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_h[:, None] * stride_d
1489 | do_ptrs = DO + offs_m[:, None] * stride_tok + offs_h[None, :] * stride_d
1490 |
1491 | if MASK_TYPE > 0:
1492 | mask_ptr = (
1493 | mask_ptr + offs_m[None, :] * mask_stride_tok1 + offs_n[:, None] * mask_stride_tok2
1494 | )
1495 |
1496 | # NOTE: BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
1497 | tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
1498 | curr_m = start_m
1499 | step_m = BLOCK_M1
1500 | dtype = tl.float32 # For dot products
1501 |
1502 | # Iteratively compute dK and dV over the M dimension
1503 | for _ in range(num_steps):
1504 | qT = tl.load(qT_ptrs)
1505 |
1506 | # Load m before computing qk to reduce pipeline stall
1507 | offs_m = curr_m + tl.arange(0, BLOCK_M1)
1508 | m = tl.load(M + offs_m)
1509 | qkT = tl.dot(k, qT)
1510 |
1511 | # Exponentiation
1512 | pT = tl.math.exp2(qkT - m[None, :])
1513 |
1514 | # External masking after exponentiation
1515 | if MASK_TYPE > 0:
1516 | mask = tl.load(mask_ptr)
1517 | if MASK_TYPE == 1: # Boolean mask
1518 | pT = tl.where(mask == 1, pT, 0.0)
1519 | elif MASK_TYPE == 2: # Additive mask
1520 |                 # 'mask' is the additive mask loaded above; entries equal to MASK_CONST are masked out, all others attend
1521 | attend = mask != MASK_CONST
1522 | pT = tl.where(attend, pT, 0.0)
1523 |
1524 | # (or) Causal masking after exponentiation
1525 | elif CAUSAL_MASKING:
1526 | causal_mask = offs_m[None, :] >= offs_n[:, None]
1527 | pT = tl.where(causal_mask, pT, 0.0)
1528 |
1529 | # Dropout after exponentiation
1530 | if ENABLE_DROPOUT:
1531 | philox_offset = philox_offset_base + curr_m * N_CTX + start_n
1532 | dropout_mask, dropout_scale = create_dropout_mask(
1533 | philox_seed, philox_offset, dropout_p, BLOCK_M1, BLOCK_N1, N_CTX
1534 | )
1535 | pT = pT * dropout_mask.to(pT.dtype) * dropout_scale
1536 |
1537 | # Compute dV
1538 | ppT = pT
1539 | ppT = ppT.to(dtype)
1540 | do = tl.load(do_ptrs)
1541 | dv += tl.dot(ppT, do.to(dtype)).to(do.dtype)
1542 | # NOTE: D (= delta) is pre-divided by ds_scale.
1543 | Di = tl.load(D + offs_m)
1544 |
1545 | # Compute dP and dS to derive dK
1546 | dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
1547 |
1548 | if ENABLE_DROPOUT: # This derivative should be masked with the same dropout mask
1549 | dpT = dpT * dropout_mask.to(dpT.dtype) * dropout_scale
1550 |
1551 | dsT = pT * (dpT - Di[None, :])
1552 | dsT = dsT.to(dtype)
1553 | dk += tl.dot(dsT, tl.trans(qT).to(dtype)).to(qT.dtype)
1554 |
1555 | # Increment pointers
1556 | curr_m += step_m
1557 | qT_ptrs += step_m * stride_tok
1558 | do_ptrs += step_m * stride_tok
1559 |
1560 | if MASK_TYPE > 0:
1561 | mask_ptr += step_m * mask_stride_tok1
1562 |
1563 | return dk, dv
1564 |
1565 |
1566 | # The main inner-loop logic for computing dQ
1567 | @triton.jit
1568 | def _attn_bwd_dq(
1569 | dq,
1570 | q,
1571 | K,
1572 | V, #
1573 | do,
1574 | m,
1575 | D,
1576 | # shared by Q/K/V/DO.
1577 | stride_tok,
1578 | stride_d, #
1579 | N_CTX, #
1580 | BLOCK_M2: tl.constexpr, #
1581 | BLOCK_N2: tl.constexpr, #
1582 | HEAD_DIM: tl.constexpr,
1583 | # Filled in by the wrapper.
1584 | start_m,
1585 | start_n,
1586 | num_steps, #
1587 | CAUSAL_MASKING: tl.constexpr,
1588 | # Args for masking/dropout
1589 | mask_ptr,
1590 | mask_stride_tok1,
1591 | mask_stride_tok2,
1592 | MASK_TYPE: tl.constexpr,
1593 | dropout_p,
1594 | philox_seed,
1595 | philox_offset_base,
1596 | ENABLE_DROPOUT: tl.constexpr,
1597 | MASK_CONST: tl.constexpr = MASK_CONST,
1598 | ):
1599 | """The main inner-loop logic for computing dQ.
1600 |
1601 | Args:
1602 | dq: Gradient of the query tensor.
1603 | q: Query tensor.
1604 | K: Key tensor.
1605 | V: Value tensor.
1606 | do: Gradient of the output tensor.
1607 |         m: Per-query softmax statistics (log-sum-exp) from the forward pass.
1608 |         D: Delta tensor (row-wise sum of O * dO).
1609 | stride_tok: Stride for the token dimension.
1610 | stride_d: Stride for the head dimension.
1611 | N_CTX: Context length.
1612 | BLOCK_M2: Block size for M dimension.
1613 | BLOCK_N2: Block size for N dimension.
1614 | HEAD_DIM: Head dimension size.
1615 | start_m: Starting index for M dimension.
1616 | start_n: Starting index for N dimension.
1617 | num_steps: Number of steps to unroll.
1618 | CAUSAL_MASKING: Flag for causal masking.
1619 | mask_ptr: Pointer to the mask tensor.
1620 | mask_stride_tok1: Stride for the third (row) dimension of the mask tensor.
1621 | mask_stride_tok2: Stride for the fourth (column) dimension of the mask tensor.
1622 | MASK_TYPE: Type of masking (0: no mask, 1: boolean mask,
1623 | 2: additive mask).
1624 | dropout_p: Dropout probability.
1625 | philox_seed: Seed for Philox RNG.
1626 | philox_offset_base: Base offset for Philox RNG.
1627 | ENABLE_DROPOUT: Flag to enable dropout.
1628 | MASK_CONST: Constant used for masking.
1629 |
1630 | Returns:
1631 | dq: Gradient of the query tensor.
1632 | """
1633 | # Initialize pointers for K, V, and DO
1634 | offs_m = start_m + tl.arange(0, BLOCK_M2)
1635 | offs_n = start_n + tl.arange(0, BLOCK_N2)
1636 | offs_h = tl.arange(0, HEAD_DIM)
1637 | kT_ptrs = K + offs_n[None, :] * stride_tok + offs_h[:, None] * stride_d
1638 | vT_ptrs = V + offs_n[None, :] * stride_tok + offs_h[:, None] * stride_d
1639 |
1640 | if MASK_TYPE > 0:
1641 | mask_ptr = (
1642 | mask_ptr + offs_m[:, None] * mask_stride_tok1 + offs_n[None, :] * mask_stride_tok2
1643 | )
1644 |
1645 | # NOTE: D (= delta) is pre-divided by ds_scale.
1646 | Di = tl.load(D + offs_m)
1647 |
1648 | # NOTE: BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
1649 | tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
1650 | curr_n = start_n
1651 | step_n = BLOCK_N2
1652 | dtype = tl.float32 # For dot products
1653 |
1654 | # Iteratively compute dQ over the N dimension
1655 | for _ in range(num_steps):
1656 | offs_n = curr_n + tl.arange(0, BLOCK_N2)
1657 | kT = tl.load(kT_ptrs)
1658 | vT = tl.load(vT_ptrs)
1659 | qk = tl.dot(q, kT)
1660 |
1661 | # Exponentiation
1662 | p = tl.math.exp2(qk - m)
1663 |
1664 | # External masking after exponentiation
1665 | if MASK_TYPE > 0:
1666 | mask = tl.load(mask_ptr)
1667 | if MASK_TYPE == 1: # Boolean mask
1668 | p = tl.where(mask == 1, p, 0.0)
1669 | elif MASK_TYPE == 2: # Additive mask
1670 | attend = mask != MASK_CONST
1671 | p = tl.where(attend, p, 0.0)
1672 |
1673 | # (or) Causal masking after exponentiation
1674 | elif CAUSAL_MASKING:
1675 | causal_mask = offs_m[:, None] >= offs_n[None, :]
1676 | p = tl.where(causal_mask, p, 0.0)
1677 |
1678 | # Dropout after exponentiation
1679 | if ENABLE_DROPOUT:
1680 | philox_offset = philox_offset_base + start_m * N_CTX + curr_n
1681 | dropout_mask, dropout_scale = create_dropout_mask(
1682 | philox_seed, philox_offset, dropout_p, BLOCK_M2, BLOCK_N2, N_CTX
1683 | )
1684 | p = p * dropout_mask.to(p.dtype) * dropout_scale
1685 |
1686 | # Compute dP and dS
1687 | dp = tl.dot(do, vT).to(tl.float32)
1688 |
1689 | if ENABLE_DROPOUT: # NOTE: This derivative should be masked with the same dropout mask.
1690 | dp = dp * dropout_mask.to(dp.dtype) * dropout_scale
1691 |
1692 | ds = p * (dp - Di[:, None])
1693 |
1694 | # Compute dQ
1695 | # NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
1696 | ds = ds.to(dtype)
1697 | dq += tl.dot(ds, tl.trans(kT).to(dtype)).to(kT.dtype)
1698 |
1699 | # Increment pointers
1700 | curr_n += step_n
1701 | kT_ptrs += step_n * stride_tok
1702 | vT_ptrs += step_n * stride_tok
1703 |
1704 | if MASK_TYPE > 0:
1705 | mask_ptr += step_n * mask_stride_tok2
1706 |
1707 | return dq
1708 |
1709 |
1710 | @triton.jit
1711 | def _attn_bwd_causal(
1712 | Q,
1713 | K,
1714 | V,
1715 | sm_scale, #
1716 | DO, #
1717 | DQ,
1718 | DK,
1719 | DV, #
1720 | M,
1721 | D,
1722 | # Shared by Q/K/V/DO.
1723 | stride_z,
1724 | stride_h,
1725 | stride_tok,
1726 | stride_d, #
1727 | # Used for the mask.
1728 | mask_stride_z,
1729 | mask_stride_h,
1730 | mask_stride_tok1,
1731 | mask_stride_tok2,
1732 | # Dimensions and sizes.
1733 | H,
1734 | N_CTX, #
1735 | BLOCK_M1: tl.constexpr, #
1736 | BLOCK_N1: tl.constexpr, #
1737 | BLOCK_M2: tl.constexpr, #
1738 | BLOCK_N2: tl.constexpr, #
1739 | BLK_SLICE_FACTOR: tl.constexpr, #
1740 | HEAD_DIM: tl.constexpr,
1741 | # Args for masking/dropout.
1742 | mask_ptr,
1743 | MASK_TYPE: tl.constexpr,
1744 | dropout_p,
1745 | philox_seed,
1746 | ENABLE_DROPOUT: tl.constexpr,
1747 | ):
1748 | """The main backward pass for the (causal) attention mechanism.
1749 |
1750 | This computes gradients for only ~N²/2 pairwise token interactions,
1751 | since causal attention already masks out half of the interactions.
1752 |
1753 | Args:
1754 | Q: Query tensor.
1755 | K: Key tensor.
1756 | V: Value tensor.
1757 | sm_scale: Scale factor for the softmax.
1758 | DO: Gradient of the output tensor.
1759 | DQ: Gradient of the query tensor.
1760 | DK: Gradient of the key tensor.
1761 | DV: Gradient of the value tensor.
1762 |         M: Per-query softmax statistics (log-sum-exp) from the forward pass.
1763 |         D: Delta tensor (row-wise sum of O * dO).
1764 | stride_z: Stride for the z dimension.
1765 | stride_h: Stride for the head dimension.
1766 | stride_tok: Stride for the token dimension.
1767 |         stride_d: Stride for the per-head feature (HEAD_DIM) dimension.
1768 | mask_stride_z: Stride for the z dimension in the mask tensor.
1769 | mask_stride_h: Stride for the head dimension in the mask tensor.
1770 | mask_stride_tok1: Stride for the first token (row) dimension in the mask tensor.
1771 | mask_stride_tok2: Stride for the second token (column) dimension in the mask tensor.
1772 |         H: Number of attention heads.
1773 | N_CTX: Context length.
1774 | BLOCK_M1: Block size for M dimension.
1775 | BLOCK_N1: Block size for N dimension.
1776 | BLOCK_M2: Block size for M dimension.
1777 | BLOCK_N2: Block size for N dimension.
1778 | BLK_SLICE_FACTOR: Block slice factor.
1779 | HEAD_DIM: Head dimension size.
1780 | mask_ptr: Pointer to the mask tensor.
1781 | MASK_TYPE: Type of masking (0: no mask, 1: boolean mask,
1782 | 2: additive mask).
1783 | dropout_p: Dropout probability.
1784 | philox_seed: Seed for Philox RNG.
1785 | ENABLE_DROPOUT: Flag to enable dropout.
1786 | """
1787 | # Constants
1788 | LN2: tl.constexpr = 0.6931471824645996 # = ln(2)
1789 |
1790 | # Collect sequence, batch, and head indices
1791 | start_block_id = tl.program_id(0) # Which block (in the input query sequence) to process
1792 | off_hz = tl.program_id(
1793 | 1
1794 | ) # Which head and batch element to process, with a program being a single head of a single batch element
1795 | off_z = (
1796 | off_hz // H
1797 | ) # Which batch element this program is assigned to (n.b., each batch element has H heads)
1798 | off_h = off_hz % H # The position of the head to process in the batch
1799 |
1800 | # NOTE: This allows one to get the (N_CTX, HEAD_DIM) block in Q, K, V, etc. by indexing it by batch and head
1801 | delta_shared_offset = (off_hz * N_CTX).to(tl.int64)
1802 | qkv_shared_offset = off_z.to(tl.int64) * stride_z + off_h.to(tl.int64) * stride_h
1803 |
1804 | # Offset pointers for batch elements and heads
1805 | Q += qkv_shared_offset
1806 | K += qkv_shared_offset
1807 | V += qkv_shared_offset
1808 | DO += qkv_shared_offset
1809 | DQ += qkv_shared_offset
1810 | DK += qkv_shared_offset
1811 | DV += qkv_shared_offset
1812 |
1813 | M += delta_shared_offset # NOTE: These tensors have fewer dimensions.
1814 | D += delta_shared_offset
1815 |
1816 | # Initialize pointer for the mask, if provided
1817 | if MASK_TYPE > 0:
1818 | mask_offset = off_z.to(tl.int64) * mask_stride_z + off_h.to(tl.int64) * mask_stride_h
1819 | mask_ptr += mask_offset
1820 |
1821 | # Generate philox offset for this block
1822 | philox_offset_base = off_hz * N_CTX * N_CTX
1823 |
1824 | # ====== COMPUTE dK and dV ======
1825 | # Determine step size for dK and dV computation
1826 | MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
1827 |
1828 | # Prepare offsets for loading Q/K/V/DO
1829 | start_n = start_block_id * BLOCK_N1
1830 | start_m = start_n
1831 |
1832 | # Load K and V: They will stay in SRAM throughout.
1833 | offs_n = start_n + tl.arange(0, BLOCK_N1)
1834 | offs_h = tl.arange(0, HEAD_DIM)
1835 |
1836 | k = tl.load(K + offs_n[:, None] * stride_tok + offs_h[None, :] * stride_d)
1837 | v = tl.load(V + offs_n[:, None] * stride_tok + offs_h[None, :] * stride_d)
1838 |
1839 | # Initialize dK and dV accumulators
1840 | dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
1841 | dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
1842 |
1843 | # Compute dK and dV for (causally) masked blocks
1844 | num_steps = BLOCK_N1 // MASK_BLOCK_M1
1845 | dk, dv = _attn_bwd_dkdv(
1846 | dk,
1847 | dv, #
1848 | Q,
1849 | k,
1850 | v,
1851 | DO, #
1852 | M,
1853 | D, #
1854 | stride_tok,
1855 | stride_d, #
1856 | N_CTX, #
1857 | MASK_BLOCK_M1,
1858 | BLOCK_N1,
1859 | HEAD_DIM, #
1860 | start_n,
1861 | start_m,
1862 | num_steps, #
1863 | CAUSAL_MASKING=True, #
1864 | mask_ptr=mask_ptr,
1865 | mask_stride_tok1=mask_stride_tok1,
1866 | mask_stride_tok2=mask_stride_tok2,
1867 | MASK_TYPE=MASK_TYPE,
1868 | dropout_p=dropout_p,
1869 | philox_seed=philox_seed,
1870 | philox_offset_base=philox_offset_base,
1871 | ENABLE_DROPOUT=ENABLE_DROPOUT,
1872 | )
1873 |
1874 | start_m += num_steps * MASK_BLOCK_M1
1875 | num_steps = (N_CTX - start_m) // BLOCK_M1
1876 |
1877 | # Compute dK and dV for (causally) non-masked blocks
1878 | dk, dv = _attn_bwd_dkdv( #
1879 | dk,
1880 | dv, #
1881 | Q,
1882 | k,
1883 | v,
1884 | DO, #
1885 | M,
1886 | D, #
1887 | stride_tok,
1888 | stride_d, #
1889 | N_CTX, #
1890 | BLOCK_M1,
1891 | BLOCK_N1,
1892 | HEAD_DIM, #
1893 | start_n,
1894 | start_m,
1895 | num_steps, #
1896 | CAUSAL_MASKING=False, #
1897 | mask_ptr=mask_ptr,
1898 | mask_stride_tok1=mask_stride_tok1,
1899 | mask_stride_tok2=mask_stride_tok2,
1900 | MASK_TYPE=MASK_TYPE,
1901 | dropout_p=dropout_p,
1902 | philox_seed=philox_seed,
1903 | philox_offset_base=philox_offset_base,
1904 | ENABLE_DROPOUT=ENABLE_DROPOUT,
1905 | )
1906 |
1907 | # Write-back dV
1908 | dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_h[None, :] * stride_d
1909 | tl.store(dv_ptrs, dv)
1910 |
1911 | # Write-back dK (scaled)
1912 | dk *= sm_scale
1913 | dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_h[None, :] * stride_d
1914 | tl.store(dk_ptrs, dk)
1915 |
1916 | # ====== COMPUTE dQ ======
1917 | # Determine step size for dQ computation
1918 | MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
1919 |
1920 | # Prepare offsets for dQ computation
1921 | start_m = start_block_id * BLOCK_M2
1922 | end_n = start_m + BLOCK_M2
1923 |
1924 | offs_m = start_m + tl.arange(0, BLOCK_M2)
1925 |
1926 | # Load Q, DO, and M: They will stay in SRAM throughout.
1927 | q = tl.load(Q + offs_m[:, None] * stride_tok + offs_h[None, :] * stride_d)
1928 | do = tl.load(DO + offs_m[:, None] * stride_tok + offs_h[None, :] * stride_d)
1929 |
1930 | m = tl.load(M + offs_m)
1931 | m = m[:, None]
1932 |
1933 | # Initialize dQ accumulator
1934 | dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32)
1935 |
1936 | # Compute dQ for (causally) masked blocks
1937 | num_steps = BLOCK_M2 // MASK_BLOCK_N2
1938 | dq = _attn_bwd_dq(
1939 | # NOTE: This code scans each row of QK^T backward (from right to left,
1940 | # but inside each call to _attn_bwd_dq, from left to right), but that's
1941 | # not due to anything important. It's just to reuse the loop structure
1942 | # for dK and dV above as much as possible.
1943 | dq,
1944 | q,
1945 | K,
1946 | V, #
1947 | do,
1948 | m,
1949 | D, #
1950 | stride_tok,
1951 | stride_d, #
1952 | N_CTX, #
1953 | BLOCK_M2,
1954 | MASK_BLOCK_N2,
1955 | HEAD_DIM, #
1956 | start_m,
1957 | end_n - num_steps * MASK_BLOCK_N2,
1958 | num_steps, #
1959 | CAUSAL_MASKING=True, #
1960 | mask_ptr=mask_ptr,
1961 | mask_stride_tok1=mask_stride_tok1,
1962 | mask_stride_tok2=mask_stride_tok2,
1963 | MASK_TYPE=MASK_TYPE,
1964 | dropout_p=dropout_p,
1965 | philox_seed=philox_seed,
1966 | philox_offset_base=philox_offset_base,
1967 | ENABLE_DROPOUT=ENABLE_DROPOUT,
1968 | )
1969 |
1970 | end_n -= num_steps * MASK_BLOCK_N2
1971 |
1972 | # Compute dQ for (causally) non-masked blocks
1973 | num_steps = end_n // BLOCK_N2
1974 | dq = _attn_bwd_dq(
1975 | dq,
1976 | q,
1977 | K,
1978 | V, #
1979 | do,
1980 | m,
1981 | D, #
1982 | stride_tok,
1983 | stride_d, #
1984 | N_CTX, #
1985 | BLOCK_M2,
1986 | BLOCK_N2,
1987 | HEAD_DIM, #
1988 | start_m,
1989 | end_n - num_steps * BLOCK_N2,
1990 | num_steps, #
1991 | CAUSAL_MASKING=False, #
1992 | mask_ptr=mask_ptr,
1993 | mask_stride_tok1=mask_stride_tok1,
1994 | mask_stride_tok2=mask_stride_tok2,
1995 | MASK_TYPE=MASK_TYPE,
1996 | dropout_p=dropout_p,
1997 | philox_seed=philox_seed,
1998 | philox_offset_base=philox_offset_base,
1999 | ENABLE_DROPOUT=ENABLE_DROPOUT,
2000 | )
2001 |
2002 | # Write-back dQ (scaled)
2003 | dq *= LN2
2004 | dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_h[None, :] * stride_d
2005 | tl.store(dq_ptrs, dq)
2006 |
2007 |
2008 | @triton.jit
2009 | def _attn_bwd(
2010 | Q,
2011 | K,
2012 | V,
2013 | sm_scale, #
2014 | DO, #
2015 | DQ,
2016 | DK,
2017 | DV, #
2018 | M,
2019 | D,
2020 | # Shared by Q/K/V/DO.
2021 | stride_z,
2022 | stride_h,
2023 | stride_tok,
2024 | stride_d, #
2025 | # Used for the mask.
2026 | mask_stride_z,
2027 | mask_stride_h,
2028 | mask_stride_tok1,
2029 | mask_stride_tok2,
2030 | # Dimensions and sizes.
2031 | H,
2032 | N_CTX, #
2033 | BLOCK_M1: tl.constexpr, #
2034 | BLOCK_N1: tl.constexpr, #
2035 | BLOCK_M2: tl.constexpr, #
2036 | BLOCK_N2: tl.constexpr, #
2037 | BLK_SLICE_FACTOR: tl.constexpr, #
2038 | HEAD_DIM: tl.constexpr,
2039 | # Args for masking/dropout.
2040 | mask_ptr,
2041 | MASK_TYPE: tl.constexpr,
2042 | dropout_p,
2043 | philox_seed,
2044 | ENABLE_DROPOUT: tl.constexpr,
2045 | ):
2046 | """The main backward pass for the (non-causal) attention mechanism.
2047 |
2048 | This computes gradients for all N² pairwise token interactions,
2049 | unlike the causal version which only computes ~N²/2.
2050 |
2051 | Args:
2052 | Q: Query tensor.
2053 | K: Key tensor.
2054 | V: Value tensor.
2055 | sm_scale: Scale factor for the softmax.
2056 | DO: Gradient of the output tensor.
2057 | DQ: Gradient of the query tensor.
2058 | DK: Gradient of the key tensor.
2059 | DV: Gradient of the value tensor.
2060 |         M: Per-query softmax statistics (log-sum-exp) from the forward pass.
2061 |         D: Delta tensor (row-wise sum of O * dO).
2062 | stride_z: Stride for the z dimension.
2063 | stride_h: Stride for the head dimension.
2064 | stride_tok: Stride for the token dimension.
2065 |         stride_d: Stride for the per-head feature (HEAD_DIM) dimension.
2066 | mask_stride_z: Stride for the z dimension in the mask tensor.
2067 | mask_stride_h: Stride for the head dimension in the mask tensor.
2068 | mask_stride_tok1: Stride for the first token (row) dimension in the mask tensor.
2069 | mask_stride_tok2: Stride for the second token (column) dimension in the mask tensor.
2070 |         H: Number of attention heads.
2071 | N_CTX: Context length.
2072 | BLOCK_M1: Block size for M dimension.
2073 | BLOCK_N1: Block size for N dimension.
2074 | BLOCK_M2: Block size for M dimension.
2075 | BLOCK_N2: Block size for N dimension.
2076 | BLK_SLICE_FACTOR: Block slice factor.
2077 | HEAD_DIM: Head dimension size.
2078 | mask_ptr: Pointer to the mask tensor.
2079 | MASK_TYPE: Type of masking (0: no mask, 1: boolean mask,
2080 | 2: additive mask).
2081 | dropout_p: Dropout probability.
2082 | philox_seed: Seed for Philox RNG.
2083 | ENABLE_DROPOUT: Flag to enable dropout.
2084 | """
2085 | # Constants
2086 | LN2: tl.constexpr = 0.6931471824645996 # = ln(2)
2087 |
2088 | # Collect sequence, batch, and head indices
2089 | start_block_id = tl.program_id(0) # Which block (in the input query sequence) to process
2090 | off_hz = tl.program_id(
2091 | 1
2092 | ) # Which head and batch element to process, with a program being a single head of a single batch element
2093 | off_z = (
2094 | off_hz // H
2095 | ) # Which batch element this program is assigned to (n.b., each batch element has H heads)
2096 | off_h = off_hz % H # The position of the head to process in the batch
2097 |
2098 | # NOTE: This allows one to get the (N_CTX, HEAD_DIM) block in Q, K, V, etc. by indexing it by batch and head
2099 | delta_shared_offset = (off_hz * N_CTX).to(tl.int64)
2100 | qkv_shared_offset = off_z.to(tl.int64) * stride_z + off_h.to(tl.int64) * stride_h
2101 |
2102 | # Offset pointers for batch elements and heads
2103 | Q += qkv_shared_offset
2104 | K += qkv_shared_offset
2105 | V += qkv_shared_offset
2106 | DO += qkv_shared_offset
2107 | DQ += qkv_shared_offset
2108 | DK += qkv_shared_offset
2109 | DV += qkv_shared_offset
2110 |
2111 | M += delta_shared_offset # NOTE: These tensors have fewer dimensions.
2112 | D += delta_shared_offset
2113 |
2114 | # Initialize pointer for the mask, if provided
2115 | if MASK_TYPE > 0:
2116 | mask_offset = off_z.to(tl.int64) * mask_stride_z + off_h.to(tl.int64) * mask_stride_h
2117 | mask_ptr += mask_offset
2118 |
2119 | # Generate philox offset for this block
2120 | philox_offset_base = off_hz * N_CTX * N_CTX
2121 |
2122 | # ====== COMPUTE dK and dV ======
2123 | # For non-causal attention, we process ALL query blocks (the entire sequence)
2124 | # This is the key difference from causal: we iterate through all Q positions
2125 |
2126 | # Prepare offsets for loading Q/K/V/DO
2127 | start_n = start_block_id * BLOCK_N1
2128 |
2129 | # Load K and V: They will stay in SRAM throughout.
2130 | offs_n = start_n + tl.arange(0, BLOCK_N1)
2131 | offs_h = tl.arange(0, HEAD_DIM)
2132 |
2133 | k = tl.load(K + offs_n[:, None] * stride_tok + offs_h[None, :] * stride_d)
2134 | v = tl.load(V + offs_n[:, None] * stride_tok + offs_h[None, :] * stride_d)
2135 |
2136 | # Initialize dK and dV accumulators
2137 | dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
2138 | dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
2139 |
2140 | start_m = 0 # Start from the beginning of the sequence
2141 | num_steps = N_CTX // BLOCK_M1 # Process the entire sequence
2142 |
2143 | dk, dv = _attn_bwd_dkdv(
2144 | dk,
2145 | dv, #
2146 | Q,
2147 | k,
2148 | v,
2149 | DO, #
2150 | M,
2151 | D, #
2152 | stride_tok,
2153 | stride_d, #
2154 | N_CTX, #
2155 | BLOCK_M1,
2156 | BLOCK_N1,
2157 | HEAD_DIM, #
2158 | start_n,
2159 | start_m,
2160 | num_steps, #
2161 | CAUSAL_MASKING=False,
2162 | mask_ptr=mask_ptr,
2163 | mask_stride_tok1=mask_stride_tok1,
2164 | mask_stride_tok2=mask_stride_tok2,
2165 | MASK_TYPE=MASK_TYPE,
2166 | dropout_p=dropout_p,
2167 | philox_seed=philox_seed,
2168 | philox_offset_base=philox_offset_base,
2169 | ENABLE_DROPOUT=ENABLE_DROPOUT,
2170 | )
2171 |
2172 | # Write-back dV
2173 | dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_h[None, :] * stride_d
2174 | tl.store(dv_ptrs, dv)
2175 |
2176 | # Write-back dK (scaled)
2177 | dk *= sm_scale
2178 | dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_h[None, :] * stride_d
2179 | tl.store(dk_ptrs, dk)
2180 |
2181 | # ====== COMPUTE dQ ======
2182 | # Prepare offsets for dQ computation
2183 | start_m = start_block_id * BLOCK_M2
2184 | offs_m = start_m + tl.arange(0, BLOCK_M2)
2185 |
2186 | # Load Q, DO, and M: They will stay in SRAM throughout.
2187 | q = tl.load(Q + offs_m[:, None] * stride_tok + offs_h[None, :] * stride_d)
2188 | do = tl.load(DO + offs_m[:, None] * stride_tok + offs_h[None, :] * stride_d)
2189 |
2190 | m = tl.load(M + offs_m)
2191 | m = m[:, None]
2192 |
2193 | # Initialize dQ accumulator
2194 | dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32)
2195 |
2196 | # For non-causal attention, we process ALL key/value blocks (the entire sequence)
2197 | # This means each query position can attend to ALL key/value positions
2198 |
2199 | start_n = 0 # Start from the beginning of the sequence
2200 | num_steps = N_CTX // BLOCK_N2 # Process the entire sequence
2201 |
2202 | dq = _attn_bwd_dq(
2203 | dq,
2204 | q,
2205 | K,
2206 | V, #
2207 | do,
2208 | m,
2209 | D, #
2210 | stride_tok,
2211 | stride_d, #
2212 | N_CTX, #
2213 | BLOCK_M2,
2214 | BLOCK_N2,
2215 | HEAD_DIM, #
2216 | start_m,
2217 | start_n,
2218 | num_steps, #
2219 | CAUSAL_MASKING=False, #
2220 | mask_ptr=mask_ptr,
2221 | mask_stride_tok1=mask_stride_tok1,
2222 | mask_stride_tok2=mask_stride_tok2,
2223 | MASK_TYPE=MASK_TYPE,
2224 | dropout_p=dropout_p,
2225 | philox_seed=philox_seed,
2226 | philox_offset_base=philox_offset_base,
2227 | ENABLE_DROPOUT=ENABLE_DROPOUT,
2228 | )
2229 |
2230 | # Write-back dQ (scaled)
2231 | dq *= LN2
2232 | dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_h[None, :] * stride_d
2233 | tl.store(dq_ptrs, dq)
2234 |
2235 |
2236 | class JVPAttn(Function):
2237 | """JVP (Jacobian-Vector Product) for Attention Mechanism."""
2238 |
2239 | class Grid(NamedTuple):
2240 | """Grid configuration for JVP Attention."""
2241 |
2242 | M_BLOCKS: int
2243 | Z_H: int
2244 | ONE: Literal[1]
2245 |
2246 | class FnCtx(FunctionCtx):
2247 | """Function context for JVP Attention."""
2248 |
2249 | sm_scale: float
2250 | HEAD_DIM_K: int
2251 | causal: bool
2252 | grid: JVPAttn.Grid
2253 | mask_tensor: Tensor
2254 | MASK_TYPE: int
2255 | dropout_p: float
2256 | philox_seed: int
2257 | ENABLE_DROPOUT: bool
2258 |
2259 | class FwdOutCtxContrib(NamedTuple):
2260 | """Forward output context contributions for JVP Attention."""
2261 |
2262 | o_t: Tensor | None
2263 | M: Tensor
2264 | grid: JVPAttn.Grid
2265 | HEAD_DIM_K: int
2266 | sm_scale: float
2267 | mask_tensor: Tensor
2268 | MASK_TYPE: int
2269 | dropout_p: float
2270 | philox_seed: int
2271 | ENABLE_DROPOUT: bool
2272 |
2273 | class FwdOut(NamedTuple):
2274 | """Forward output for JVP Attention."""
2275 |
2276 | o: Tensor
2277 | ctx: JVPAttn.FwdOutCtxContrib
2278 |
2279 | class JVPOut(NamedTuple):
2280 | """JVP output for JVP Attention."""
2281 |
2282 | o: Tensor
2283 | ctx: None
2284 |
2285 | class BwdOut(NamedTuple):
2286 | """Backward output for JVP Attention."""
2287 |
2288 | q: Tensor
2289 | k: Tensor
2290 | v: Tensor
2291 | q_t: None
2292 | k_t: None
2293 | v_t: None
2294 | attn_mask: None
2295 | dropout_p: None
2296 | causal: None
2297 | sm_scale: None
2298 | warp_specialize: None
2299 | USE_TMA: None
2300 | verify_attn_mask: None
2301 |
2302 | class Strides(NamedTuple):
2303 | """Strides for JVP Attention."""
2304 |
2305 | z: int
2306 | h: int
2307 | n_ctx: int
2308 | head_dim: int
2309 |
2310 | @staticmethod
2311 | def forward(
2312 | q: Tensor,
2313 | k: Tensor,
2314 | v: Tensor,
2315 | q_t: Tensor | None,
2316 | k_t: Tensor | None,
2317 | v_t: Tensor | None,
2318 | attn_mask: Tensor | None = None,
2319 | dropout_p: float = 0.0,
2320 | causal: bool = False,
2321 | sm_scale: float | None = None,
2322 | warp_specialize: bool = True,
2323 | USE_TMA: bool = True,
2324 | verify_attn_mask: bool = True,
2325 | ) -> JVPAttn.FwdOut:
2326 | """Forward pass for JVP Attention.
2327 |
2328 | NOTE: The following warning(s) will be raised if `verify_attn_mask=True`
2329 | and an attention mask with any all-null head is provided:
2330 | `RuntimeWarning: overflow encountered in exp2.`
2331 |
2332 | Args:
2333 | q: Query tensor of shape (Z, H, N_CTX, HEAD_DIM_Q).
2334 | k: Key tensor of shape (Z, H, N_CTX, HEAD_DIM_K).
2335 | v: Value tensor of shape (Z, H, N_CTX, HEAD_DIM_V).
2336 |             q_t: Optional JVP tangent of the query tensor.
2337 |             k_t: Optional JVP tangent of the key tensor.
2338 |             v_t: Optional JVP tangent of the value tensor.
2339 | attn_mask: Optional attention mask of shape (Z, H, N_CTX, N_CTX).
2340 | Two types of masks are supported. A boolean mask where a value
2341 | of True indicates that the element should take part in attention,
2342 | or a float mask of the same type as query, key, value that is added
2343 | to the attention score. The constant `MASK_CONST` is used to
2344 | indicate masked positions in the float mask. All other values
2345 | denote unmasked positions.
2346 | dropout_p: Dropout probability.
2347 | causal: Whether the attention is causal.
2348 | sm_scale: Optional scaling factor for softmax.
2349 | warp_specialize: Whether to use warp specialization.
2350 | USE_TMA: Whether to use TMA.
2351 | verify_attn_mask: Whether to verify the correctness of the provided attention mask.
2352 |
2353 | Returns:
2354 | Outputs of JVP Attention.
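     | 
     |         Example (a minimal sketch; the shapes and the use of MASK_CONST below are illustrative)::
     | 
     |             # Boolean mask: True marks (query, key) pairs that may attend
     |             bool_mask = torch.ones(Z, H, N_CTX, N_CTX, dtype=torch.bool, device=q.device)
     | 
     |             # Additive float mask: entries equal to MASK_CONST are masked out; all other
     |             # values are added to the attention scores
     |             add_mask = torch.zeros(Z, H, N_CTX, N_CTX, dtype=q.dtype, device=q.device)
     |             add_mask[..., :, N_CTX // 2:] = MASK_CONST # hypothetical: mask out the second half of the keys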
2355 | """
2356 | if dropout_p != 0.0:
2357 | raise NotImplementedError("Dropout is not currently supported in JVP attention.")
2358 |
2359 | # Collect metadata
2360 | Z, H, N_CTX, HEAD_DIM_Q = q.shape
2361 | HEAD_DIM_K = k.shape[-1]
2362 | HEAD_DIM_V = v.shape[-1] # NOTE: When v is in float8_e5m2 it is transposed.
2363 |
2364 | STAGE = 3 if causal else 1
2365 | ENABLE_JVP = q_t is not None
2366 |
2367 | assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V, (
2368 | "JVP attention requires HEAD_DIM_Q == HEAD_DIM_K == HEAD_DIM_V"
2369 | f" but got HEAD_DIM_Q={HEAD_DIM_Q}, HEAD_DIM_K={HEAD_DIM_K}, HEAD_DIM_V={HEAD_DIM_V}"
2370 | )
2371 | assert HEAD_DIM_K in {16, 32, 64, 128, 256}, (
2372 | "JVP attention only supports HEAD_DIM_K in {16, 32, 64, 128, 256},"
2373 |             f" but got HEAD_DIM_K={HEAD_DIM_K}"
2374 | )
2375 |
2376 | if causal and attn_mask is not None:
2377 | raise ValueError("Causal attention does not support an attention mask.")
2378 | if attn_mask is not None:
2379 | assert attn_mask.shape == (
2380 | Z,
2381 | H,
2382 | N_CTX,
2383 | N_CTX,
2384 | ), "The provided attention mask must have 4 dimensions (Z, H, N_CTX, N_CTX)."
2385 | assert attn_mask.dtype in {
2386 | torch.bool,
2387 | q.dtype,
2388 | }, "The attention mask must be of the dtype bool or that of the query tensor."
2389 |
2390 | # Initialize arguments and tensors
2391 | if sm_scale is None:
2392 | sm_scale = HEAD_DIM_K**-0.5
2393 |
2394 | o = torch.empty_like(q)
2395 | o_t: Tensor | None = torch.empty_like(q_t) if ENABLE_JVP else None
2396 | M = torch.empty((Z, H, N_CTX), device=q.device, dtype=torch.float32)
2397 |
2398 | # Tune kernel for custom (e.g., AMD) targets
2399 | extra_kern_args = {}
2400 |
2401 | if is_hip():
2402 | waves_per_eu = 3 if HEAD_DIM_K <= 64 else 2
2403 | extra_kern_args = {"waves_per_eu": waves_per_eu, "allow_flush_denorm": True}
2404 |
2405 | if is_cuda() and warp_specialize:
2406 | # NOTE: We need more registers if we're doing JVP
2407 | if (HEAD_DIM_K == 128 and q.dtype == torch.float16) or ENABLE_JVP:
2408 | extra_kern_args["maxnreg"] = 168
2409 | else:
2410 | # NOTE: For backward pass with HEAD_DIM_K=128, this is probably too low for H100; register allocation fails.
2411 | extra_kern_args["maxnreg"] = 80
2412 |
2413 | if hasattr(triton, "set_allocator") and is_cuda():
2414 |
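     |             # NOTE: Newer Triton releases can request device-side scratch memory (e.g., for
     |             # on-device TMA descriptor staging) through a user-provided allocator; here we
     |             # register a simple torch-backed one.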
2415 | def alloc_fn(size: int, align: int, _):
2416 | """Custom allocator function for Triton."""
2417 | return torch.empty(size, dtype=torch.int8, device="cuda")
2418 |
2419 | triton.set_allocator(alloc_fn)
2420 |
2421 | def strides_zhnd(t: Tensor) -> JVPAttn.Strides:
2422 | """Get strides for a tensor with shape (Z, H, N_CTX, HEAD_DIM)."""
2423 | return JVPAttn.Strides(t.stride(0), t.stride(1), t.stride(2), t.stride(3))
2424 |
2425 | # Determine mask type
2426 | if attn_mask is None:
2427 | MASK_TYPE = 0
2428 | mask_tensor = torch.empty(0, device=q.device, dtype=q.dtype)
2429 | mask_strides = (0, 0, 0, 0)
2430 | elif attn_mask.dtype == torch.bool:
2431 | MASK_TYPE = 1
2432 | mask_tensor = attn_mask.contiguous()
2433 | mask_strides = strides_zhnd(mask_tensor)
2434 | if verify_attn_mask:
2435 | # Check if any head is all False
2436 | assert mask_tensor.any(
2437 | dim=(-1, -2)
2438 | ).all(), "The attention mask cannot be all False for any head."
2439 | else:
2440 | MASK_TYPE = 2
2441 | mask_tensor = attn_mask.to(q.dtype).contiguous()
2442 | mask_strides = strides_zhnd(mask_tensor)
2443 | if verify_attn_mask:
2444 | # Check if the mask contains -inf/inf/NaN or is all (or no) MASK_CONST for any head
2445 | assert not torch.isinf(
2446 | mask_tensor
2447 | ).any(), "The attention mask cannot contain -inf or inf."
2448 | assert not torch.isnan(
2449 | mask_tensor
2450 | ).any(), "The attention mask cannot contain NaNs."
2451 | assert (
2452 | (mask_tensor != MASK_CONST).any(dim=(-1, -2)).all()
2453 | ), f"The attention mask cannot be all {MASK_CONST} (the masking constant) for any head."
2454 |
2455 | if not (mask_tensor == MASK_CONST).any():
2456 | raise UserWarning(
2457 | f"The provided floating-point attention mask does not mask out any elements with {MASK_CONST} (the masking constant). Consider using this constant for correct masking behavior."
2458 | )
2459 |
2460 | # Prepare dropout arguments
2461 | ENABLE_DROPOUT = dropout_p > 0.0
2462 | if ENABLE_DROPOUT:
2463 | philox_seed = torch.randint(0, 2**32, (1,), device=q.device, dtype=torch.int64).item()
2464 | else:
2465 | philox_seed = 0
2466 |
2467 | # Set up grid for kernel launch
2468 | Z_H = Z * H
2469 |
2470 | def grid(META: dict[str, Any]) -> JVPAttn.Grid:
2471 | """Determine grid configuration."""
2472 | return JVPAttn.Grid(triton.cdiv(N_CTX, META["BLOCK_M"]), Z_H, 1)
2473 |
2474 | if USE_TMA and supports_tma():
2475 | # NOTE: On Hopper, we cannot perform a FP8 dot with a non-transposed second tensor.
2476 | y_dim = Z_H * N_CTX
2477 | tma_block_shape = [MIN_SEQUENCE_LENGTH, HEAD_DIM_K]
2478 |
2479 | desc_q = TensorDescriptor(
2480 | q,
2481 | shape=[y_dim, HEAD_DIM_K],
2482 | strides=[HEAD_DIM_K, 1],
2483 | block_shape=tma_block_shape,
2484 | )
2485 | desc_q_t = (
2486 | desc_q
2487 | if q_t is None
2488 | else TensorDescriptor(
2489 | q_t,
2490 | shape=[y_dim, HEAD_DIM_K],
2491 | strides=[HEAD_DIM_K, 1],
2492 | block_shape=tma_block_shape,
2493 | )
2494 | )
2495 |
2496 | if q.dtype == torch.float8_e5m2:
2497 | v_shape = [HEAD_DIM_K, y_dim]
2498 | v_strides = [N_CTX, 1]
2499 | else:
2500 | v_shape = [y_dim, HEAD_DIM_K]
2501 | v_strides = [HEAD_DIM_K, 1]
2502 | desc_v = TensorDescriptor(
2503 | v, shape=v_shape, strides=v_strides, block_shape=tma_block_shape
2504 | )
2505 |             # NOTE: We could likely share the shape and strides from above, but recomputing them is harmless.
2506 | if q_t is not None and q_t.dtype == torch.float8_e5m2:
2507 | t_v_shape = [HEAD_DIM_K, y_dim]
2508 | t_v_strides = [q_t.shape[2], 1]
2509 | else:
2510 | t_v_shape = [y_dim, HEAD_DIM_K]
2511 | t_v_strides = [HEAD_DIM_K, 1]
2512 | desc_v_t = (
2513 | desc_v
2514 | if v_t is None
2515 | else TensorDescriptor(
2516 | v_t, shape=t_v_shape, strides=t_v_strides, block_shape=tma_block_shape
2517 | )
2518 | )
2519 |
2520 | desc_k = TensorDescriptor(
2521 | k,
2522 | shape=[y_dim, HEAD_DIM_K],
2523 | strides=[HEAD_DIM_K, 1],
2524 | block_shape=tma_block_shape,
2525 | )
2526 | desc_k_t = (
2527 | desc_k
2528 | if k_t is None
2529 | else TensorDescriptor(
2530 | k_t,
2531 | shape=[y_dim, HEAD_DIM_K],
2532 | strides=[HEAD_DIM_K, 1],
2533 | block_shape=tma_block_shape,
2534 | )
2535 | )
2536 |
2537 | desc_o = TensorDescriptor(
2538 | o,
2539 | shape=[y_dim, HEAD_DIM_K],
2540 | strides=[HEAD_DIM_K, 1],
2541 | block_shape=tma_block_shape,
2542 | )
2543 | desc_o_t = (
2544 | desc_o
2545 | if o_t is None
2546 | else TensorDescriptor(
2547 | o_t,
2548 | shape=[y_dim, HEAD_DIM_K],
2549 | strides=[HEAD_DIM_K, 1],
2550 | block_shape=tma_block_shape,
2551 | )
2552 | )
2553 |
2554 | _attn_fwd_tma[grid](
2555 | sm_scale,
2556 | M, #
2557 | Z,
2558 | H, #
2559 | desc_q,
2560 | desc_k,
2561 | desc_v, #
2562 | desc_q_t,
2563 | desc_k_t,
2564 | desc_v_t, #
2565 | desc_o,
2566 | desc_o_t, #
2567 | mask_tensor, #
2568 | dropout_p, #
2569 | philox_seed, #
2570 | *mask_strides, #
2571 | N_CTX=N_CTX, #
2572 | HEAD_DIM=HEAD_DIM_K, #
2573 | FP8_OUTPUT=q.dtype == torch.float8_e5m2, #
2574 | STAGE=STAGE, #
2575 | warp_specialize=warp_specialize, #
2576 | ENABLE_JVP=ENABLE_JVP, #
2577 | ENABLE_DROPOUT=ENABLE_DROPOUT,
2578 | MASK_TYPE=MASK_TYPE,
2579 | # NOTE: The following are safe (unit-tested) default values
2580 | BLOCK_M=MIN_SEQUENCE_LENGTH, #
2581 | BLOCK_N=MIN_SEQUENCE_LENGTH, #
2582 | num_stages=NUM_STAGES_OPTIONS[0], #
2583 | num_warps=4, #
2584 | **extra_kern_args,
2585 | )
2586 |
2587 | else:
2588 | _attn_fwd[grid](
2589 | q,
2590 | k,
2591 | v,
2592 | q_t,
2593 | k_t,
2594 | v_t, #
2595 | sm_scale,
2596 | M,
2597 | o,
2598 | o_t, #
2599 | mask_tensor, #
2600 | dropout_p, #
2601 | philox_seed, #
2602 | *strides_zhnd(q), #
2603 | *strides_zhnd(k), #
2604 | *strides_zhnd(v), #
2605 | *strides_zhnd(q if q_t is None else q_t), #
2606 | *strides_zhnd(k if k_t is None else k_t), #
2607 | *strides_zhnd(v if v_t is None else v_t), #
2608 | *strides_zhnd(o), #
2609 | *strides_zhnd(o if o_t is None else o_t), #
2610 | *mask_strides, #
2611 | Z,
2612 | H, #
2613 | N_CTX=N_CTX, #
2614 | HEAD_DIM=HEAD_DIM_K, #
2615 | FP8_OUTPUT=q.dtype == torch.float8_e5m2, #
2616 | STAGE=STAGE, #
2617 | warp_specialize=warp_specialize, #
2618 | ENABLE_JVP=ENABLE_JVP, #
2619 | ENABLE_DROPOUT=ENABLE_DROPOUT,
2620 | MASK_TYPE=MASK_TYPE,
2621 | # NOTE: The following are safe (unit-tested) default values
2622 | BLOCK_M=MIN_SEQUENCE_LENGTH, #
2623 | BLOCK_N=MIN_SEQUENCE_LENGTH, #
2624 | num_stages=NUM_STAGES_OPTIONS[0], #
2625 | num_warps=4, #
2626 | **extra_kern_args,
2627 | )
2628 |
2629 | return JVPAttn.FwdOut(
2630 | o,
2631 | JVPAttn.FwdOutCtxContrib(
2632 | o_t,
2633 | M,
2634 | grid,
2635 | HEAD_DIM_K,
2636 | sm_scale,
2637 | mask_tensor,
2638 | MASK_TYPE,
2639 | dropout_p,
2640 | philox_seed,
2641 | ENABLE_DROPOUT,
2642 | ),
2643 | )
2644 |
2645 | @staticmethod
2646 |     def setup_context(ctx: JVPAttn.FnCtx, inputs, outputs: JVPAttn.FwdOut) -> None:
2647 | """Set up the context for JVP Attention.
2648 |
2649 | Args:
2650 | ctx: The context to set up
2651 | inputs: The input tensors
2652 | outputs: The output tensors
2653 | """
2654 | (
2655 | q,
2656 | k,
2657 | v,
2658 | q_t,
2659 | k_t,
2660 | v_t,
2661 | attn_mask,
2662 | dropout_p,
2663 | causal,
2664 | sm_scale,
2665 | warp_specialize,
2666 | USE_TMA,
2667 | verify_attn_mask,
2668 | ) = inputs
2669 |
2670 | o, (
2671 | o_t,
2672 | M,
2673 | grid,
2674 | HEAD_DIM_K,
2675 | sm_scale,
2676 | mask_tensor,
2677 | MASK_TYPE,
2678 | dropout_p,
2679 | philox_seed,
2680 | ENABLE_DROPOUT,
2681 | ) = outputs
2682 |
2683 | ctx.grid = grid
2684 | ctx.save_for_forward(o_t)
2685 | ctx.save_for_backward(q, k, v, o, M)
2686 |
2687 | ctx.sm_scale = sm_scale
2688 | ctx.HEAD_DIM_K = HEAD_DIM_K
2689 | ctx.causal = causal
2690 | ctx.mask_tensor = mask_tensor
2691 | ctx.MASK_TYPE = MASK_TYPE
2692 | ctx.dropout_p = dropout_p
2693 | ctx.philox_seed = philox_seed
2694 | ctx.ENABLE_DROPOUT = ENABLE_DROPOUT
2695 |
2696 | @staticmethod
2697 | def fwd(
2698 | q: Tensor,
2699 | k: Tensor,
2700 | v: Tensor,
2701 | attn_mask: Tensor | None = None,
2702 | dropout_p: float = 0.0,
2703 | causal: bool = False,
2704 | sm_scale: float | None = None,
2705 | warp_specialize: bool = True,
2706 | USE_TMA: bool = True,
2707 | ) -> Tensor:
2708 | """Forward pass for JVP Attention.
2709 |
2710 |         NOTE: This is not part of the standard autograd.Function interface; it is a convenience wrapper that adds type hints and keyword-argument support.
2711 |
2712 | NOTE: Calls to `contiguous()` are necessary to ensure the inputs are contiguous in memory
2713 | (e.g., due to an `unbind` call to create `q`, `k`, `v`) but nonetheless may incur a performance cost.
2714 |
2715 | Args:
2716 | q: Query tensor of shape (Z, H, N_CTX, HEAD_DIM_Q).
2717 | k: Key tensor of shape (Z, H, N_CTX, HEAD_DIM_K).
2718 | v: Value tensor of shape (Z, H, N_CTX, HEAD_DIM_V).
2719 |             attn_mask: Optional attention mask of shape (Z, H, N_CTX, N_CTX).
     |                 Two types of masks are supported. A boolean mask where a value of True
     |                 indicates that the element should take part in attention, or a float mask
     |                 of the same type as query, key, value that is added to the attention score.
2720 | dropout_p: Dropout probability.
2721 | causal: Whether to use causal attention.
2722 | sm_scale: The softmax scale factor.
2723 | warp_specialize: Whether to use warp specialization.
2724 | USE_TMA: Whether to use TMA.
2725 |
2726 | Returns:
2727 | The output tensor.
2728 | """
2729 | if not (q.is_contiguous() and k.is_contiguous() and v.is_contiguous()):
2730 | q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
2731 |
2732 | out: JVPAttn.FwdOut = JVPAttn.apply(
2733 | q,
2734 | k,
2735 | v,
2736 | None,
2737 | None,
2738 | None,
2739 | attn_mask,
2740 | dropout_p,
2741 | causal,
2742 | sm_scale,
2743 | warp_specialize,
2744 | USE_TMA,
2745 | )
2746 |
2747 | a, _ = out
2748 | return a
2749 |
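# A minimal usage sketch (illustrative shapes and device; any (Z, H, N_CTX, HEAD_DIM)
# layout that meets the kernel's size constraints is handled the same way):
#
#   q = torch.randn(2, 4, 128, 64, device="cuda")
#   k = torch.randn(2, 4, 128, 64, device="cuda")
#   v = torch.randn(2, 4, 128, 64, device="cuda")
#   out = JVPAttn.fwd(q, k, v, causal=True)  # equivalently: attention(q, k, v, causal=True)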
2750 | @staticmethod
2751 | def fwd_dual(
2752 | q: Tensor,
2753 | k: Tensor,
2754 | v: Tensor,
2755 | attn_mask: Tensor | None = None,
2756 | dropout_p: float = 0.0,
2757 | causal: bool = False,
2758 | sm_scale: float | None = None,
2759 | warp_specialize: bool = True,
2760 | USE_TMA: bool = True,
2761 | ) -> Tensor:
2762 | """Forward pass for JVP Attention with dual tensor inputs.
2763 |
2764 |         NOTE: This method is not part of the autograd.Function interface; it is a thin wrapper that provides type hinting and keyword-argument support.
2765 |
2766 |         NOTE: Calls to `contiguous()` are needed to ensure the inputs are contiguous in memory
2767 |         (e.g., when `q`, `k`, `v` were produced by an `unbind` call), but they may incur a performance cost.
2768 |
2769 | Args:
2770 | q: Query tensor of shape (Z, H, N_CTX, HEAD_DIM_Q).
2771 | k: Key tensor of shape (Z, H, N_CTX, HEAD_DIM_K).
2772 | v: Value tensor of shape (Z, H, N_CTX, HEAD_DIM_V).
2773 |             attn_mask: Optional attention mask of shape (Z, H, N_CTX, N_CTX). Two types of mask are supported: a boolean mask, where True indicates that an element should take part in attention, or a float mask of the same dtype as the query, key, and value tensors that is added to the attention scores.
2774 | dropout_p: Dropout probability.
2775 | causal: Whether to use causal attention.
2776 | sm_scale: The softmax scale factor.
2777 | warp_specialize: Whether to use warp specialization.
2778 | USE_TMA: Whether to use TMA.
2779 |
2780 | Returns:
2781 |             The output tensor (a dual tensor whose tangent carries the JVP when called under `fwAD.dual_level()`).
2782 | """
2783 | if not (q.is_contiguous() and k.is_contiguous() and v.is_contiguous()):
2784 | q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
2785 |
2786 | q_p, q_t = fwAD.unpack_dual(q)
2787 | k_p, k_t = fwAD.unpack_dual(k)
2788 | v_p, v_t = fwAD.unpack_dual(v)
2789 |
2790 |         # NOTE: We pass dual-tensor args to ensure jvp() will be called, but we also
2791 |         # pass the tangents separately, since forward() appears to demote dual-tensor
2792 |         # args to their primals.
2793 | out: JVPAttn.FwdOut = JVPAttn.apply(
2794 | q,
2795 | k,
2796 | v,
2797 | q_t,
2798 | k_t,
2799 | v_t,
2800 | attn_mask,
2801 | dropout_p,
2802 | causal,
2803 | sm_scale,
2804 | warp_specialize,
2805 | USE_TMA,
2806 | )
2807 |
2808 | a, _ = out
2809 | return a
2810 |
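# A minimal forward-mode sketch (illustrative; assumes `q`, `k`, `v` and tangent tensors
# of matching shape already exist): wrap primals and tangents as dual tensors, call
# `fwd_dual` inside `fwAD.dual_level()`, and unpack the JVP from the output.
#
#   with fwAD.dual_level():
#       q_dual = fwAD.make_dual(q, q_tangent)
#       k_dual = fwAD.make_dual(k, k_tangent)
#       v_dual = fwAD.make_dual(v, v_tangent)
#       out = JVPAttn.fwd_dual(q_dual, k_dual, v_dual, causal=True)
#       o, o_tangent = fwAD.unpack_dual(out)  # o_tangent is the JVP of the attention output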
2811 | @staticmethod
2812 | def jvp(ctx: JVPAttn.FnCtx, gq: Tensor, gk: Tensor, gv: Tensor, *_) -> JVPAttn.JVPOut:
2813 | """Compute the Jacobian-vector product (JVP) for JVP Attention.
2814 |
2815 | Args:
2816 | ctx: The context
2817 |             gq: The tangent associated with the query tensor
2818 |             gk: The tangent associated with the key tensor
2819 |             gv: The tangent associated with the value tensor
2820 |
2821 | Returns:
2822 | The JVP output.
2823 | """
2824 | return JVPAttn.JVPOut(ctx.saved_for_forward[0], None)
2825 |
2826 | @staticmethod
2827 | def backward(ctx, do, _) -> JVPAttn.BwdOut:
2828 | """Backward pass for JVP Attention.
2829 |
2830 |         NOTE: A call to `contiguous()` may be necessary because autograd does not guarantee that the
2831 |         incoming output gradient is contiguous in memory, and normalizing it may incur a performance cost.
2832 |
2833 | Args:
2834 | ctx: The context
2835 | do: The gradient of the output tensor
2836 |
2837 | Returns:
2838 | The backward output.
2839 | """
2840 | q, k, v, o, M = ctx.saved_tensors
2841 |
2842 |         # Require q, k, v, and o (saved from the forward pass) to already be contiguous
2843 | if not (
2844 | q.is_contiguous() and k.is_contiguous() and v.is_contiguous() and o.is_contiguous()
2845 | ):
2846 | raise ValueError(
2847 | "JVPAttn expected q, k, v, o to be contiguous; got "
2848 | f"q.is_contiguous()={q.is_contiguous()}, k.is_contiguous()={k.is_contiguous()}, "
2849 | f"v.is_contiguous()={v.is_contiguous()}, o.is_contiguous()={o.is_contiguous()}, "
2850 | f"do.is_contiguous()={do.is_contiguous()}"
2851 | )
2852 |
2853 | # NOTE: Autograd may deliver a non-contiguous output gradient; if so, normalize it.
2854 | if not do.is_contiguous():
2855 | do = do.contiguous()
2856 |
2857 | # Ensure all inputs/outputs the kernel reads share the same layout
2858 | assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride(), (
2859 | "JVPAttn expected q, k, v, o, do to have the same layout; got "
2860 | f"q.stride()={q.stride()}, k.stride()={k.stride()}, v.stride()={v.stride()}, "
2861 | f"o.stride()={o.stride()}, do.stride()={do.stride()}"
2862 | )
2863 |
2864 | # Initialize tensors for gradients
2865 | dq = torch.empty_like(q)
2866 | dk = torch.empty_like(k)
2867 | dv = torch.empty_like(v)
2868 | delta = torch.empty_like(M)
2869 |
2870 | # Collect metadata
2871 | Z, H, N_CTX = q.shape[:3]
2872 |
2873 | BLK_SLICE_FACTOR = 2 # NOTE: This is a safe default value to reduce backward memory usage
2874 | BLOCK_MIN = MIN_SEQUENCE_LENGTH # NOTE: Adjust according to minimum input sequence length
2875 | BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = BLOCK_MIN, BLOCK_MIN, BLOCK_MIN, BLOCK_MIN
2876 |
2877 | assert N_CTX % BLOCK_MIN == 0, f"N_CTX must be divisible by BLOCK_MIN={BLOCK_MIN}"
2878 |
2879 | if not ctx.causal:
2880 | assert (
2881 | BLOCK_M1 == BLOCK_M2 == BLOCK_N1 == BLOCK_N2
2882 | ), "For non-causal attention, all block sizes must be equal."
2883 |
2884 | # Scale k by sm_scale / ln(2) to account for softmax scaling and
2885 | # change-of-base of exponentiation (exp2).
2886 | RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2)
2887 | arg_k = k
2888 | arg_k = arg_k * (ctx.sm_scale * RCP_LN2)
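# i.e., softmax uses exp(s * sm_scale) = exp2(s * sm_scale * log2(e)), and since
# log2(e) = 1 / ln(2) = RCP_LN2, folding the factor into k lets the kernel use the
# faster exp2 instruction without changing the result.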
2889 |
2890 |         # Set the mask strides according to the mask type
2891 | if ctx.MASK_TYPE == 0:
2892 | mask_strides = (0, 0, 0, 0)
2893 | else:
2894 | mask_strides = (
2895 | ctx.mask_tensor.stride(0),
2896 | ctx.mask_tensor.stride(1),
2897 | ctx.mask_tensor.stride(2),
2898 | ctx.mask_tensor.stride(3),
2899 | )
2900 |
2901 | # Set up grid for kernel launch
2902 | Z_H = Z * H
2903 |
2904 |         # Preprocess: compute per-row deltas from the output and its gradient
2905 | pre_grid = (N_CTX // BLOCK_MIN, Z_H)
2906 | _attn_bwd_preprocess[pre_grid](
2907 | o,
2908 | do, #
2909 | delta, #
2910 | N_CTX, #
2911 | BLOCK_M=BLOCK_MIN,
2912 | HEAD_DIM=ctx.HEAD_DIM_K, #
2913 | )
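# In the standard FlashAttention backward, the per-row delta is
# delta[i] = sum_d o[i, d] * do[i, d]; the main backward kernel below reuses these
# values when assembling dq, dk, and dv.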
2914 |
2915 |         # Launch the backward kernel, using a deeper pipeline (more stages) on Hopper (compute capability 9.x) GPUs
2916 | grid = (N_CTX // BLOCK_MIN, Z_H)
2917 | bwd_kernel = _attn_bwd_causal if ctx.causal else _attn_bwd
2918 | num_stages = (
2919 | 5
2920 | if is_cuda() and torch.cuda.get_device_capability()[0] == 9
2921 | else NUM_STAGES_OPTIONS[0]
2922 | )
2923 |
2924 | bwd_kernel[grid](
2925 | q,
2926 | arg_k,
2927 | v,
2928 | ctx.sm_scale,
2929 | do,
2930 | dq,
2931 | dk,
2932 | dv, #
2933 | M,
2934 | delta, #
2935 | q.stride(0),
2936 | q.stride(1),
2937 | q.stride(2),
2938 | q.stride(3), #
2939 | mask_strides[0],
2940 | mask_strides[1],
2941 | mask_strides[2],
2942 | mask_strides[3], #
2943 | H,
2944 | N_CTX, #
2945 | BLOCK_M1=BLOCK_M1,
2946 | BLOCK_N1=BLOCK_N1, #
2947 | BLOCK_M2=BLOCK_M2,
2948 | BLOCK_N2=BLOCK_N2, #
2949 | BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, #
2950 | HEAD_DIM=ctx.HEAD_DIM_K, #
2951 | mask_ptr=ctx.mask_tensor,
2952 | MASK_TYPE=ctx.MASK_TYPE,
2953 | dropout_p=ctx.dropout_p,
2954 | philox_seed=ctx.philox_seed,
2955 | ENABLE_DROPOUT=ctx.ENABLE_DROPOUT,
2956 | # NOTE: The following are safe (unit-tested) default values
2957 | num_stages=num_stages, #
2958 | num_warps=4, #
2959 | )
2960 |
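# One gradient slot per forward input: only q, k, and v receive gradients; the remaining
# inputs (tangents, mask, and configuration arguments) get None.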
2961 | return JVPAttn.BwdOut(
2962 | dq, dk, dv, None, None, None, None, None, None, None, None, None, None
2963 | )
2964 |
2965 |
2966 | attention = JVPAttn.fwd
2967 |
--------------------------------------------------------------------------------