├── .github
│ └── workflows
│   └── pylint.yml
├── .gitignore
├── LICENSE
├── README.md
├── benchmarks
│ ├── triton-vs-jax-sdpa-cudnn
│ │ ├── README.md
│ │ ├── b=1-ub=False-hdim=128-qh=16-kvh=8-mode=fwd.csv
│ │ ├── b=1-ub=False-hdim=128-qh=16-kvh=8-mode=fwd.png
│ │ ├── b=1-ub=False-hdim=128-qh=32-kvh=16-mode=fwd.csv
│ │ ├── b=1-ub=False-hdim=128-qh=32-kvh=16-mode=fwd.png
│ │ ├── b=1-ub=False-hdim=256-qh=16-kvh=8-mode=fwd.csv
│ │ ├── b=1-ub=False-hdim=256-qh=16-kvh=8-mode=fwd.png
│ │ ├── b=1-ub=False-hdim=256-qh=32-kvh=16-mode=fwd.csv
│ │ ├── b=1-ub=False-hdim=256-qh=32-kvh=16-mode=fwd.png
│ │ ├── b=1-ub=False-hdim=64-qh=16-kvh=8-mode=fwd.csv
│ │ ├── b=1-ub=False-hdim=64-qh=16-kvh=8-mode=fwd.png
│ │ ├── b=1-ub=False-hdim=64-qh=32-kvh=16-mode=fwd.csv
│ │ ├── b=1-ub=False-hdim=64-qh=32-kvh=16-mode=fwd.png
│ │ ├── b=1-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.csv
│ │ ├── b=1-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.png
│ │ ├── b=1-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.csv
│ │ ├── b=1-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.png
│ │ ├── b=1-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.csv
│ │ ├── b=1-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.png
│ │ ├── b=1-ub=True-hdim=256-qh=32-kvh=16-mode=fwd.csv
│ │ ├── b=1-ub=True-hdim=256-qh=32-kvh=16-mode=fwd.png
│ │ ├── b=1-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.csv
│ │ ├── b=1-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.png
│ │ ├── b=1-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.csv
│ │ ├── b=1-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.png
│ │ ├── b=2-ub=False-hdim=128-qh=16-kvh=8-mode=fwd.csv
│ │ ├── b=2-ub=False-hdim=128-qh=16-kvh=8-mode=fwd.png
│ │ ├── b=2-ub=False-hdim=128-qh=32-kvh=16-mode=fwd.csv
│ │ ├── b=2-ub=False-hdim=128-qh=32-kvh=16-mode=fwd.png
│ │ ├── b=2-ub=False-hdim=256-qh=16-kvh=8-mode=fwd.csv
│ │ ├── b=2-ub=False-hdim=256-qh=16-kvh=8-mode=fwd.png
│ │ ├── b=2-ub=False-hdim=256-qh=32-kvh=16-mode=fwd.csv
│ │ ├── b=2-ub=False-hdim=256-qh=32-kvh=16-mode=fwd.png
│ │ ├── b=2-ub=False-hdim=64-qh=16-kvh=8-mode=fwd.csv
│ │ ├── b=2-ub=False-hdim=64-qh=16-kvh=8-mode=fwd.png
│ │ ├── b=2-ub=False-hdim=64-qh=32-kvh=16-mode=fwd.csv
│ │ ├── b=2-ub=False-hdim=64-qh=32-kvh=16-mode=fwd.png
│ │ ├── b=2-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.csv
│ │ ├── b=2-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.png
│ │ ├── b=2-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.csv
│ │ ├── b=2-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.png
│ │ ├── b=2-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.csv
│ │ ├── b=2-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.png
│ │ ├── b=2-ub=True-hdim=256-qh=32-kvh=16-mode=fwd.csv
│ │ ├── b=2-ub=True-hdim=256-qh=32-kvh=16-mode=fwd.png
│ │ ├── b=2-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.csv
│ │ ├── b=2-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.png
│ │ ├── b=2-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.csv
│ │ ├── b=2-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.png
│ │ ├── b=4-ub=False-hdim=128-qh=16-kvh=8-mode=fwd.csv
│ │ ├── b=4-ub=False-hdim=128-qh=16-kvh=8-mode=fwd.png
│ │ ├── b=4-ub=False-hdim=128-qh=32-kvh=16-mode=fwd.csv
│ │ ├── b=4-ub=False-hdim=128-qh=32-kvh=16-mode=fwd.png
│ │ ├── b=4-ub=False-hdim=256-qh=16-kvh=8-mode=fwd.csv
│ │ ├── b=4-ub=False-hdim=256-qh=16-kvh=8-mode=fwd.png
│ │ ├── b=4-ub=False-hdim=256-qh=32-kvh=16-mode=fwd.csv
│ │ ├── b=4-ub=False-hdim=256-qh=32-kvh=16-mode=fwd.png
│ │ ├── b=4-ub=False-hdim=64-qh=16-kvh=8-mode=fwd.csv
│ │ ├── b=4-ub=False-hdim=64-qh=16-kvh=8-mode=fwd.png
│ │ ├── b=4-ub=False-hdim=64-qh=32-kvh=16-mode=fwd.csv
│ │ ├── b=4-ub=False-hdim=64-qh=32-kvh=16-mode=fwd.png
│ │ ├── b=4-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.csv
│ │ ├── b=4-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.png
│ │ ├── b=4-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.csv
│ │ ├── b=4-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.png
│ │ ├── b=4-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.csv
│ │ ├── b=4-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.png
│ │ ├── b=4-ub=True-hdim=256-qh=32-kvh=16-mode=fwd.csv
│ │ ├── b=4-ub=True-hdim=256-qh=32-kvh=16-mode=fwd.png
│ │ ├── b=4-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.csv
│ │ ├── b=4-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.png
│ │ ├── b=4-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.csv
│ │ ├── b=4-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.png
│ │ └── results.html
│ └── triton-vs-jax-sdpa
│   ├── README.md
│   ├── b=1-ub=False-hdim=128-qh=16-kvh=8-mode=fwd.csv
│   ├── b=1-ub=False-hdim=128-qh=16-kvh=8-mode=fwd.png
│   ├── b=1-ub=False-hdim=128-qh=32-kvh=16-mode=fwd.csv
│   ├── b=1-ub=False-hdim=128-qh=32-kvh=16-mode=fwd.png
│   ├── b=1-ub=False-hdim=256-qh=16-kvh=8-mode=fwd.csv
│   ├── b=1-ub=False-hdim=256-qh=16-kvh=8-mode=fwd.png
│   ├── b=1-ub=False-hdim=256-qh=32-kvh=16-mode=fwd.csv
│   ├── b=1-ub=False-hdim=256-qh=32-kvh=16-mode=fwd.png
│   ├── b=1-ub=False-hdim=64-qh=16-kvh=8-mode=fwd.csv
│   ├── b=1-ub=False-hdim=64-qh=16-kvh=8-mode=fwd.png
│   ├── b=1-ub=False-hdim=64-qh=32-kvh=16-mode=fwd.csv
│   ├── b=1-ub=False-hdim=64-qh=32-kvh=16-mode=fwd.png
│   ├── b=1-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.csv
│   ├── b=1-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.png
│   ├── b=1-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.csv
│   ├── b=1-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.png
│   ├── b=1-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.csv
│   ├── b=1-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.png
│   ├── b=1-ub=True-hdim=256-qh=32-kvh=16-mode=fwd.csv
│   ├── b=1-ub=True-hdim=256-qh=32-kvh=16-mode=fwd.png
│   ├── b=1-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.csv
│   ├── b=1-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.png
│   ├── b=1-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.csv
│   ├── b=1-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.png
│   ├── b=2-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.csv
│   ├── b=2-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.png
│   ├── b=2-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.csv
│   ├── b=2-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.png
│   ├── b=2-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.csv
│   ├── b=2-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.png
│   ├── b=2-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.csv
│   ├── b=2-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.png
│   ├── b=2-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.csv
│   ├── b=2-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.png
│   └── results.html
├── jax_flash_attn2
│ ├── __init__.py
│ ├── flash_attention.py
│ ├── flash_attention_jax
│ │ ├── __init__.py
│ │ ├── _backward_jax.py
│ │ ├── _flash_attention.py
│ │ └── _forward_jax.py
│ ├── flash_attention_triton
│ │ ├── __init__.py
│ │ ├── _backward_triton.py
│ │ ├── _flash_attention.py
│ │ ├── _forward_triton.py
│ │ └── _utils.py
│ ├── refrence_call.py
│ └── utils.py
├── poetry.lock
└── pyproject.toml
/.github/workflows/pylint.yml:
--------------------------------------------------------------------------------
1 | name: Pylint
2 |
3 | on: [push]
4 |
5 | jobs:
6 | build:
7 | runs-on: ubuntu-latest
8 | strategy:
9 | matrix:
10 | python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
11 | steps:
12 | - uses: actions/checkout@v4
13 | - name: Set up Python ${{ matrix.python-version }}
14 | uses: actions/setup-python@v3
15 | with:
16 | python-version: ${{ matrix.python-version }}
17 | - name: Install dependencies
18 | run: |
19 | python -m pip install --upgrade pip
20 | pip install pylint
21 | - name: Analysing the code with pylint
22 | run: |
23 | pylint $(git ls-files '*.py')
24 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 | env.py
9 | exp.py
10 | # Distribution / packaging
11 | .Python
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | cover/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 | db.sqlite3-journal
64 |
65 | # Flask stuff:
66 | instance/
67 | .webassets-cache
68 |
69 | # Scrapy stuff:
70 | .scrapy
71 |
72 | # Sphinx documentation
73 | docs/_build/
74 |
75 | # PyBuilder
76 | .pybuilder/
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | # For a library or package, you might want to ignore these files since the code is
88 | # intended to run in multiple environments; otherwise, check them in:
89 | # .python-version
90 |
91 | # pipenv
92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
95 | # install all needed dependencies.
96 | #Pipfile.lock
97 |
98 | # poetry
99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
100 | # This is especially recommended for binary packages to ensure reproducibility, and is more
101 | # commonly ignored for libraries.
102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
103 | #poetry.lock
104 |
105 | # pdm
106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
107 | #pdm.lock
108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
109 | # in version control.
110 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
111 | .pdm.toml
112 | .pdm-python
113 | .pdm-build/
114 |
115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
116 | __pypackages__/
117 |
118 | # Celery stuff
119 | celerybeat-schedule
120 | celerybeat.pid
121 |
122 | # SageMath parsed files
123 | *.sage.py
124 |
125 | # Environments
126 | .env
127 | .venv
128 | env/
129 | venv/
130 | ENV/
131 | env.bak/
132 | venv.bak/
133 |
134 | # Spyder project settings
135 | .spyderproject
136 | .spyproject
137 |
138 | # Rope project settings
139 | .ropeproject
140 |
141 | # mkdocs documentation
142 | /site
143 |
144 | # mypy
145 | .mypy_cache/
146 | .dmypy.json
147 | dmypy.json
148 |
149 | # Pyre type checker
150 | .pyre/
151 |
152 | # pytype static type analyzer
153 | .pytype/
154 |
155 | # Cython debug symbols
156 | cython_debug/
157 |
158 | # PyCharm
159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
161 | # and can be added to the global gitignore or merged into this file. For a more nuclear
162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
163 | #.idea/
164 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # JAX-Flash-Attention2
2 |
3 | A flexible and efficient implementation of Flash Attention 2.0 for JAX, supporting multiple backends (GPU/TPU/CPU) and platforms (Triton/Pallas/JAX).
4 |
5 | ## Installation
6 |
7 | ```bash
8 | pip install jax-flash-attn2
9 | ```
10 |
11 | ## Basic Usage
12 |
13 | ```python
14 | import jax
15 | import jax.numpy as jnp
16 | import jax_flash_attn2 as jfa
17 |
18 | # Initialize the FlashAttention module with desired configuration
19 | flash_attention = jfa.FlashAttention(
20 | jfa.AttentionConfig(
21 | platform=jfa.Platform.TRITON, # Options: TRITON, PALLAS, JAX
22 | backend=jfa.Backend.GPU, # Options: GPU, TPU, CPU
23 | )
24 | )
25 |
26 | # Create sample inputs (grouped-query attention: the query has 4x as many heads as key/value)
27 | batch_size, num_heads, seq_len, head_dim = 2, 4, 512, 64
28 | query = jax.random.normal(jax.random.PRNGKey(0), (batch_size, num_heads * 4, seq_len, head_dim), jnp.float16)
29 | key = jax.random.normal(jax.random.PRNGKey(1), (batch_size, num_heads, seq_len, head_dim), jnp.float16)
30 | value = jax.random.normal(jax.random.PRNGKey(2), (batch_size, num_heads, seq_len, head_dim), jnp.float16)
31 |
32 | # Compute attention
33 | output = flash_attention(
34 | query=query,
35 | key=key,
36 | value=value,
37 | causal=True # Enable causal masking for decoder-only models
38 | )
39 |
40 | # output shape matches the query: (batch_size, num_heads * 4, seq_len, head_dim)
41 | ```
42 |
43 | ## Advanced Usage
44 |
45 | ### With Attention Mask
46 |
47 | ```python
48 | # Create an attention mask (1 = attend, 0 = mask)
49 | attention_mask = jnp.ones((batch_size, 1, seq_len, seq_len)) # Allow full attention
50 | # For example, mask the first 100 tokens from attending to the last 100 tokens
51 | attention_mask = attention_mask.at[:, :, :100, -100:].set(0)
52 |
53 | output = flash_attention(
54 | query=query,
55 | key=key,
56 | value=value,
57 | attention_mask=attention_mask,
58 | causal=False # Using explicit mask instead of causal
59 | )
60 | ```
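
For reference, a combined padding + causal mask can be built with plain `jnp` ops using the same convention (1 = attend, 0 = mask). This is a sketch; `lengths` is a hypothetical per-sequence count of valid (non-padding) tokens:

```python
# Hypothetical valid lengths for a batch of 2 sequences.
lengths = jnp.array([512, 300])

# (batch, seq): True where the position holds a real token.
padding_mask = jnp.arange(seq_len)[None, :] < lengths[:, None]

# (seq, seq): lower-triangular causal structure.
causal_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))

# Broadcast and combine to (batch, 1, seq, seq): attend only to earlier, non-padding positions.
attention_mask = (padding_mask[:, None, None, :] & causal_mask[None, None, :, :]).astype(jnp.float32)

output = flash_attention(
    query=query,
    key=key,
    value=value,
    attention_mask=attention_mask,
    causal=False,  # causality is already encoded in the mask
)
```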
61 |
62 | ### With Attention Bias
63 |
64 | ```python
65 | # Create an attention bias
66 | bias = jnp.zeros((batch_size, 1, seq_len, seq_len))
67 | # Add a position-dependent bias, bias[..., i, j] = 1 / (1 + |i - j|), built with broadcasting
68 | positions = jnp.arange(seq_len)
69 | distance = jnp.abs(positions[:, None] - positions[None, :])
70 | bias = bias + (1.0 / (1.0 + distance))[None, None, :, :]
71 |
72 | output = flash_attention(
73 | query=query,
74 | key=key,
75 | value=value,
76 | bias=bias
77 | )
78 | ```
79 |
80 | ### With Dropout
81 |
82 | ```python
83 | output = flash_attention(
84 | query=query,
85 | key=key,
86 | value=value,
87 | dropout_prob=0.1,
88 | dropout_seed=42,
89 | causal=True
90 | )
91 | ```
92 |
93 | ## Flax Modules with JFA2
94 |
95 | Here's an example of integrating jax-flash-attn2 into a Transformer attention block implemented with Flax NNX:
96 |
97 | ```python
98 | import typing as tp
99 | from functools import partial
100 |
101 | import chex
102 | import flax.nnx as nn
103 | import jax
104 | import jax.numpy as jnp
105 |
106 | import jax_flash_attn2 as jfa
107 |
108 |
109 | class JFAttention2(nn.Module):
110 | def __init__(
111 | self,
112 | hidden_size: int,
113 | head_dim: int,
114 | num_attention_heads: int,
115 | num_key_value_heads: int,
116 | dtype: jnp.dtype = jnp.float32,
117 | param_dtype: jnp.dtype = jnp.float32,
118 | precision: jax.lax.PrecisionLike = None,
119 | *,
120 | rngs: nn.Rngs = None,
121 | ):
122 | if rngs is None:
123 | rngs = nn.Rngs(0)
124 | self.dtype = dtype
125 | self.param_dtype = param_dtype
126 | self.precision = precision
127 | self.rngs = rngs
128 |
129 | self.hidden_size = hidden_size
130 | self.head_dim = head_dim
131 | self.num_attention_heads = num_attention_heads
132 | self.num_key_value_heads = num_key_value_heads
133 |
134 | self.num_key_value_groups = num_attention_heads // num_key_value_heads
135 |
136 | if self.num_key_value_groups == 1:
137 | assert num_attention_heads == num_key_value_heads
138 |
139 | linear_class = partial(
140 | nn.Linear,
141 | dtype=dtype,
142 | param_dtype=param_dtype,
143 | use_bias=False,
144 | kernel_init=jax.nn.initializers.normal(0.02),
145 | precision=precision,
146 | rngs=rngs,
147 | )
148 | self.q_proj = linear_class(hidden_size, num_attention_heads * self.head_dim)
149 | self.k_proj = linear_class(hidden_size, num_key_value_heads * self.head_dim)
150 | self.v_proj = linear_class(hidden_size, num_key_value_heads * self.head_dim)
151 | self.o_proj = linear_class(num_attention_heads * self.head_dim, hidden_size)
152 |
153 | config = jfa.AttentionConfig(platform=jfa.Platform.TRITON, backend=jfa.Backend.GPU)
154 |
155 | self.jfa2 = jfa.FlashAttention(config)
156 |
157 | def __call__(
158 | self,
159 | hidden_states: chex.Array,
160 | attention_mask: chex.Array,
161 | causal: bool = True,
162 | ) -> chex.Array:
163 | batch_size, sequence_length = hidden_states.shape[:2]
164 | query_states, key_states, value_states = (
165 | self.q_proj(hidden_states),
166 | self.k_proj(hidden_states),
167 | self.v_proj(hidden_states),
168 | )
169 | qshape = (
170 | batch_size,
171 | sequence_length,
172 | self.num_attention_heads,
173 | self.head_dim,
174 | )
175 | kv_shape = (
176 | batch_size,
177 | sequence_length,
178 | self.num_key_value_heads,
179 | self.head_dim,
180 | )
181 | query_states = query_states.reshape(qshape)
182 | key_states = key_states.reshape(kv_shape)
183 | value_states = value_states.reshape(kv_shape)
184 | attn_output = self.jfa2.forward(
185 | query_states.astype(jnp.bfloat16),
186 | key_states.astype(jnp.bfloat16),
187 | value_states.astype(jnp.bfloat16),
188 | jnp.where(attention_mask, 0, jnp.finfo(query_states.dtype).min).astype(jnp.bfloat16),
189 | causal=causal,
190 | )
191 | attn_output = jnp.reshape(attn_output, (batch_size, sequence_length, -1))
192 | attn_output = self.o_proj(attn_output)
193 | return attn_output
194 | ```
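
A minimal smoke-test sketch for the module above. The sizes are arbitrary, and the call relies on the same `jfa.FlashAttention` usage shown inside `__call__`:

```python
import flax.nnx as nn
import jax
import jax.numpy as jnp

# Instantiate the attention block defined above (illustrative sizes only).
attn = JFAttention2(
    hidden_size=512,
    head_dim=64,
    num_attention_heads=8,
    num_key_value_heads=4,
    dtype=jnp.bfloat16,
    param_dtype=jnp.bfloat16,
    rngs=nn.Rngs(0),
)

batch, seq = 2, 128
hidden_states = jax.random.normal(jax.random.PRNGKey(0), (batch, seq, 512), jnp.bfloat16)
# Boolean mask broadcastable to (batch, 1, seq, seq); True means "attend".
attention_mask = jnp.ones((batch, 1, seq, seq), dtype=bool)

out = attn(hidden_states, attention_mask, causal=True)
print(out.shape)  # (2, 128, 512)
```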
195 |
196 | ## Platform-Specific Examples
197 |
198 | ### Using JAX Backend
199 |
200 | ```python
201 | jax_flash_attn = jfa.FlashAttention(
202 | jfa.AttentionConfig(
203 | platform=jfa.Platform.JAX,
204 | backend=jfa.Backend.CPU, # Works on any hardware
205 | )
206 | )
207 |
208 | output = jax_flash_attn(query, key, value)
209 | ```
210 |
211 | ### Using Pallas for TPU
212 |
213 | ```python
214 | tpu_flash_attn = jfa.FlashAttention(
215 | jfa.AttentionConfig(
216 | platform=jfa.Platform.PALLAS,
217 | backend=jfa.Backend.TPU,
218 | )
219 | )
220 |
221 | output = tpu_flash_attn(query, key, value)
222 | ```
223 |
224 | ## Integration with JAX Transformations
225 |
226 | ```python
227 | @jax.jit
228 | def attention_forward(q, k, v, mask=None):
229 | return flash_attention(
230 | query=q,
231 | key=k,
232 | value=v,
233 | attention_mask=mask,
234 | causal=True
235 | )
236 |
237 | # Run the JIT-compiled function
238 | attn_out = attention_forward(query, key, value)
239 |
240 | # With gradient computation
241 | def loss_fn(q, k, v):
242 | attn_output = flash_attention(q, k, v, causal=True)
243 | return jnp.mean(attn_output)
244 |
245 | grads = jax.grad(loss_fn)(query, key, value)
246 | ```
247 |
248 | ## Limitations
249 |
250 | - The Triton platform is only available on NVIDIA GPUs.
251 | - Not every platform/backend pairing is supported (for example, Triton cannot target TPU or CPU); a runtime selection sketch is shown below.
252 | - Custom attention masks are not yet supported (use bias instead).
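
A small helper sketch for picking a configuration at runtime. It only uses the `Platform`/`Backend` members shown in the examples above plus `jax.default_backend()`; treat the pairings as a heuristic, not an exhaustive compatibility matrix:

```python
import jax

import jax_flash_attn2 as jfa


def make_flash_attention() -> jfa.FlashAttention:
    """Pick a platform/backend pairing from the devices JAX actually sees."""
    backend = jax.default_backend()  # "gpu", "tpu", or "cpu"
    if backend == "gpu":
        # Triton kernels are NVIDIA-GPU only.
        config = jfa.AttentionConfig(platform=jfa.Platform.TRITON, backend=jfa.Backend.GPU)
    elif backend == "tpu":
        config = jfa.AttentionConfig(platform=jfa.Platform.PALLAS, backend=jfa.Backend.TPU)
    else:
        # The pure-JAX path runs anywhere.
        config = jfa.AttentionConfig(platform=jfa.Platform.JAX, backend=jfa.Backend.CPU)
    return jfa.FlashAttention(config)
```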
253 |
254 | ## Contributing
255 |
256 | Contributions are welcome! Please feel free to submit a Pull Request.
257 |
258 | ## Citation
259 |
260 | If you use this implementation in your research, please cite:
261 |
262 | ```bibtex
263 | @software{jax_flash_attn2,
264 | title = {JAX Flash Attention 2.0},
265 | year = {2024},
266 | url = {https://github.com/erfanzar/jax-flash-attn2}
267 | }
268 | ```
269 |
270 | ### References
271 |
272 | ```bibtex
273 | @inproceedings{dao2022flashattention,
274 | title={Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
275 | author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
276 | booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
277 | year={2022}
278 | }
279 | @inproceedings{dao2023flashattention2,
280 | title={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
281 | author={Dao, Tri},
282 | booktitle={International Conference on Learning Representations (ICLR)},
283 | year={2024}
284 | }
285 | ```
286 |
287 | ## Acknowledgments and References
288 |
289 | 1. All of the kernels are copied from [`EasyDeL`](https://github.com/erfanzar/Easydel).
290 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/README.md:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=False-hdim=128-qh=16-kvh=8-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.109986,0.260454
3 | 2048.000000,0.277971,0.602729
4 | 4096.000000,0.539475,2.037018
5 | 6144.000000,0.657339,4.449059
6 | 8192.000000,0.536271,7.733691
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=False-hdim=128-qh=16-kvh=8-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=False-hdim=128-qh=16-kvh=8-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=False-hdim=128-qh=32-kvh=16-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.201889,0.316081
3 | 2048.000000,0.277653,1.052060
4 | 4096.000000,0.558586,3.973465
5 | 6144.000000,0.312637,8.737595
6 | 8192.000000,0.001152,15.191124
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=False-hdim=128-qh=32-kvh=16-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=False-hdim=128-qh=32-kvh=16-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=False-hdim=256-qh=16-kvh=8-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.293918,300.000000
3 | 2048.000000,0.564013,300.000000
4 | 4096.000000,0.581458,300.000000
5 | 6144.000000,0.001311,300.000000
6 | 8192.000000,0.001536,300.000000
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=False-hdim=256-qh=16-kvh=8-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=False-hdim=256-qh=16-kvh=8-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=False-hdim=256-qh=32-kvh=16-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.299787,300.000000
3 | 2048.000000,0.549313,300.000000
4 | 4096.000000,0.061952,300.000000
5 | 6144.000000,0.001654,300.000000
6 | 8192.000000,0.001317,300.000000
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=False-hdim=256-qh=32-kvh=16-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=False-hdim=256-qh=32-kvh=16-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=False-hdim=64-qh=16-kvh=8-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.001285,0.052782
3 | 2048.000000,0.145014,0.378316
4 | 4096.000000,0.233479,1.062174
5 | 6144.000000,0.366142,2.297836
6 | 8192.000000,0.392203,3.882547
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=False-hdim=64-qh=16-kvh=8-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=False-hdim=64-qh=16-kvh=8-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=False-hdim=64-qh=32-kvh=16-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.001382,0.184686
3 | 2048.000000,0.179171,0.601632
4 | 4096.000000,0.333080,1.976510
5 | 6144.000000,0.461126,4.308108
6 | 8192.000000,0.357641,7.543623
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=False-hdim=64-qh=32-kvh=16-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=False-hdim=64-qh=32-kvh=16-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.154736,0.244154
3 | 2048.000000,0.337517,0.651588
4 | 4096.000000,0.561307,2.177832
5 | 6144.000000,0.744631,5.053024
6 | 8192.000000,0.475687,8.623260
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.236990,0.358192
3 | 2048.000000,0.419905,1.142686
4 | 4096.000000,0.569661,4.441965
5 | 6144.000000,0.263899,9.707267
6 | 8192.000000,0.001587,17.073843
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.367849,300.000000
3 | 2048.000000,0.695647,300.000000
4 | 4096.000000,0.379563,300.000000
5 | 6144.000000,0.001536,300.000000
6 | 8192.000000,0.001479,300.000000
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=True-hdim=256-qh=32-kvh=16-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.354809,300.000000
3 | 2048.000000,0.535335,300.000000
4 | 4096.000000,0.001536,300.000000
5 | 6144.000000,0.001152,300.000000
6 | 8192.000000,0.002048,300.000000
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=True-hdim=256-qh=32-kvh=16-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=True-hdim=256-qh=32-kvh=16-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.009526,0.232182
3 | 2048.000000,0.256336,0.574287
4 | 4096.000000,0.443419,2.048195
5 | 6144.000000,0.787608,4.677517
6 | 8192.000000,0.911081,11.655725
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.146767,0.336205
3 | 2048.000000,0.258379,1.036054
4 | 4096.000000,0.530867,3.775501
5 | 6144.000000,0.605855,11.867448
6 | 8192.000000,0.293947,24.475136
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=False-hdim=128-qh=16-kvh=8-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.199049,0.342512
3 | 2048.000000,0.302052,1.057982
4 | 4096.000000,0.462467,4.031027
5 | 6144.000000,0.339456,8.867299
6 | 8192.000000,0.001707,15.418736
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=False-hdim=128-qh=16-kvh=8-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=False-hdim=128-qh=16-kvh=8-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=False-hdim=128-qh=32-kvh=16-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.203239,0.592487
3 | 2048.000000,0.401872,2.120210
4 | 4096.000000,0.001365,7.893704
5 | 6144.000000,0.001463,17.337357
6 | 8192.000000,0.001451,30.725855
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=False-hdim=128-qh=32-kvh=16-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=False-hdim=128-qh=32-kvh=16-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=False-hdim=256-qh=16-kvh=8-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.305674,300.000000
3 | 2048.000000,0.442985,300.000000
4 | 4096.000000,0.061477,300.000000
5 | 6144.000000,0.001339,300.000000
6 | 8192.000000,0.001609,300.000000
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=False-hdim=256-qh=16-kvh=8-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=False-hdim=256-qh=16-kvh=8-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=False-hdim=256-qh=32-kvh=16-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.372295,300.000000
3 | 2048.000000,0.363640,300.000000
4 | 4096.000000,0.001463,300.000000
5 | 6144.000000,0.002219,300.000000
6 | 8192.000000,0.001707,300.000000
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=False-hdim=256-qh=32-kvh=16-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=False-hdim=256-qh=32-kvh=16-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=False-hdim=64-qh=16-kvh=8-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.001341,0.119918
3 | 2048.000000,0.190682,0.616036
4 | 4096.000000,0.307017,2.008788
5 | 6144.000000,0.472064,4.379703
6 | 8192.000000,0.386290,7.646230
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=False-hdim=64-qh=16-kvh=8-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=False-hdim=64-qh=16-kvh=8-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=False-hdim=64-qh=32-kvh=16-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.001479,0.268115
3 | 2048.000000,0.189668,1.023319
4 | 4096.000000,0.170710,3.953273
5 | 6144.000000,0.230988,8.635345
6 | 8192.000000,0.001463,15.048271
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=False-hdim=64-qh=32-kvh=16-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=False-hdim=64-qh=32-kvh=16-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.238014,0.355650
3 | 2048.000000,0.426852,1.160171
4 | 4096.000000,0.623411,4.461117
5 | 6144.000000,0.262114,9.963975
6 | 8192.000000,0.001331,17.273287
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.253855,0.626204
3 | 2048.000000,0.489057,2.258021
4 | 4096.000000,0.237148,8.760285
5 | 6144.000000,0.001265,19.611481
6 | 8192.000000,0.001741,34.416176
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.364506,300.000000
3 | 2048.000000,0.547446,300.000000
4 | 4096.000000,0.001536,300.000000
5 | 6144.000000,0.002176,300.000000
6 | 8192.000000,0.001536,300.000000
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=True-hdim=256-qh=32-kvh=16-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.426923,300.000000
3 | 2048.000000,0.206000,300.000000
4 | 4096.000000,0.001479,300.000000
5 | 6144.000000,0.002048,300.000000
6 | 8192.000000,0.001536,300.000000
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=True-hdim=256-qh=32-kvh=16-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=True-hdim=256-qh=32-kvh=16-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.181881,0.348315
3 | 2048.000000,0.248074,1.040480
4 | 4096.000000,0.596546,3.865138
5 | 6144.000000,0.630855,12.445517
6 | 8192.000000,0.288358,24.834625
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.208093,0.579161
3 | 2048.000000,0.392836,1.864016
4 | 4096.000000,0.479936,9.628813
5 | 6144.000000,0.114556,27.601694
6 | 8192.000000,0.001506,51.775490
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=False-hdim=128-qh=16-kvh=8-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.224760,0.575211
3 | 2048.000000,0.371112,2.018634
4 | 4096.000000,0.001536,5.240088
5 | 6144.000000,0.001268,18.134644
6 | 8192.000000,0.001536,32.440353
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=False-hdim=128-qh=16-kvh=8-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=False-hdim=128-qh=16-kvh=8-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=False-hdim=128-qh=32-kvh=16-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.182035,0.957698
3 | 2048.000000,0.024439,2.797037
4 | 4096.000000,0.001769,15.568085
5 | 6144.000000,0.001434,34.589699
6 | 8192.000000,0.001365,61.821152
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=False-hdim=128-qh=32-kvh=16-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=False-hdim=128-qh=32-kvh=16-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=False-hdim=256-qh=16-kvh=8-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.344924,300.000000
3 | 2048.000000,0.330334,300.000000
4 | 4096.000000,0.001497,300.000000
5 | 6144.000000,0.001877,300.000000
6 | 8192.000000,0.001365,300.000000
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=False-hdim=256-qh=16-kvh=8-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=False-hdim=256-qh=16-kvh=8-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=False-hdim=256-qh=32-kvh=16-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.316557,300.000000
3 | 2048.000000,0.001516,300.000000
4 | 4096.000000,0.001536,300.000000
5 | 6144.000000,0.001707,300.000000
6 | 8192.000000,0.002048,300.000000
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=False-hdim=256-qh=32-kvh=16-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=False-hdim=256-qh=32-kvh=16-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=False-hdim=64-qh=16-kvh=8-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.001392,0.309596
3 | 2048.000000,0.193620,1.052123
4 | 4096.000000,0.170950,3.937541
5 | 6144.000000,0.208427,8.745746
6 | 8192.000000,0.001682,15.822118
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=False-hdim=64-qh=16-kvh=8-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=False-hdim=64-qh=16-kvh=8-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=False-hdim=64-qh=32-kvh=16-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.001500,0.512446
3 | 2048.000000,0.157459,2.044994
4 | 4096.000000,0.001593,7.709981
5 | 6144.000000,0.001516,16.856882
6 | 8192.000000,0.001575,30.138208
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=False-hdim=64-qh=32-kvh=16-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=False-hdim=64-qh=32-kvh=16-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.255597,0.589090
3 | 2048.000000,0.440020,2.224123
4 | 4096.000000,0.122880,7.861397
5 | 6144.000000,0.001408,19.485580
6 | 8192.000000,0.001331,34.835968
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.267349,0.912527
3 | 2048.000000,0.251509,3.839468
4 | 4096.000000,0.001325,14.249164
5 | 6144.000000,0.001536,40.990719
6 | 8192.000000,0.001536,72.704002
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.374650,300.000000
3 | 2048.000000,0.082349,300.000000
4 | 4096.000000,0.001792,300.000000
5 | 6144.000000,0.001024,300.000000
6 | 8192.000000,0.002560,300.000000
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=True-hdim=256-qh=32-kvh=16-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.358098,300.000000
3 | 2048.000000,0.001506,300.000000
4 | 4096.000000,0.001280,300.000000
5 | 6144.000000,0.002048,300.000000
6 | 8192.000000,0.001024,300.000000
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=True-hdim=256-qh=32-kvh=16-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=True-hdim=256-qh=32-kvh=16-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.213456,0.571713
3 | 2048.000000,0.361903,1.789920
4 | 4096.000000,0.499891,9.623846
5 | 6144.000000,0.117496,27.389696
6 | 8192.000000,0.001792,51.911678
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.255275,0.832885
3 | 2048.000000,0.391979,3.346892
4 | 4096.000000,0.153389,21.636608
5 | 6144.000000,0.001707,57.332737
6 | 8192.000000,0.001792,105.936035
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa-cudnn/results.html:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa/README.md:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa/b=1-ub=False-hdim=128-qh=16-kvh=8-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.106995,0.689995
3 | 2048.000000,0.275410,2.995743
4 | 4096.000000,0.520256,11.519552
5 | 6144.000000,0.616047,26.424032
6 | 8192.000000,0.524092,67.017731
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa/b=1-ub=False-hdim=128-qh=16-kvh=8-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa/b=1-ub=False-hdim=128-qh=16-kvh=8-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa/b=1-ub=False-hdim=128-qh=32-kvh=16-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.195825,1.440105
3 | 2048.000000,0.323976,5.214518
4 | 4096.000000,0.500160,23.082130
5 | 6144.000000,0.331045,54.701153
6 | 8192.000000,0.001792,136.058044
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa/b=1-ub=False-hdim=128-qh=32-kvh=16-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa/b=1-ub=False-hdim=128-qh=32-kvh=16-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa/b=1-ub=False-hdim=256-qh=16-kvh=8-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.302997,0.836519
3 | 2048.000000,0.544905,3.151774
4 | 4096.000000,0.574020,13.211268
5 | 6144.000000,0.001352,30.355797
6 | 8192.000000,0.001434,74.565636
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa/b=1-ub=False-hdim=256-qh=16-kvh=8-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa/b=1-ub=False-hdim=256-qh=16-kvh=8-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa/b=1-ub=False-hdim=256-qh=32-kvh=16-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.366214,1.688182
3 | 2048.000000,0.452540,6.120752
4 | 4096.000000,0.051639,25.717834
5 | 6144.000000,0.001339,67.074051
6 | 8192.000000,0.001317,146.715714
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa/b=1-ub=False-hdim=256-qh=32-kvh=16-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa/b=1-ub=False-hdim=256-qh=32-kvh=16-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa/b=1-ub=False-hdim=64-qh=16-kvh=8-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.001316,0.670596
3 | 2048.000000,0.142873,2.497885
4 | 4096.000000,0.234355,10.730381
5 | 6144.000000,0.324767,25.309675
6 | 8192.000000,0.424600,66.269180
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa/b=1-ub=False-hdim=64-qh=16-kvh=8-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa/b=1-ub=False-hdim=64-qh=16-kvh=8-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa/b=1-ub=False-hdim=64-qh=32-kvh=16-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.001388,1.238941
3 | 2048.000000,0.182871,4.753040
4 | 4096.000000,0.278805,21.524223
5 | 6144.000000,0.396571,52.595871
6 | 8192.000000,0.326004,132.046021
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa/b=1-ub=False-hdim=64-qh=32-kvh=16-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa/b=1-ub=False-hdim=64-qh=32-kvh=16-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa/b=1-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.170421,1.009647
3 | 2048.000000,0.340471,3.215832
4 | 4096.000000,0.616194,13.106743
5 | 6144.000000,0.728552,30.723488
6 | 8192.000000,0.448755,72.886269
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa/b=1-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa/b=1-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa/b=1-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.246250,2.027551
3 | 2048.000000,0.390695,5.809258
4 | 4096.000000,0.588857,25.905151
5 | 6144.000000,0.246381,61.066338
6 | 8192.000000,0.001229,144.867325
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa/b=1-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa/b=1-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa/b=1-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.331402,1.110649
3 | 2048.000000,0.558512,3.770115
4 | 4096.000000,0.388238,14.670970
5 | 6144.000000,0.001472,34.339840
6 | 8192.000000,0.001820,79.478050
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa/b=1-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa/b=1-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa/b=1-ub=True-hdim=256-qh=32-kvh=16-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.438935,2.225673
3 | 2048.000000,0.560500,6.704792
4 | 4096.000000,0.001479,29.120886
5 | 6144.000000,0.001152,69.249084
6 | 8192.000000,0.001536,158.288132
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa/b=1-ub=True-hdim=256-qh=32-kvh=16-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa/b=1-ub=True-hdim=256-qh=32-kvh=16-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa/b=1-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.001379,0.948490
3 | 2048.000000,0.253063,2.807466
4 | 4096.000000,0.450294,12.528571
5 | 6144.000000,0.832924,29.624063
6 | 8192.000000,0.837758,71.584770
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa/b=1-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa/b=1-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa/b=1-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.152944,1.764381
3 | 2048.000000,0.248449,5.496425
4 | 4096.000000,0.613376,24.820086
5 | 6144.000000,0.583715,60.615711
6 | 8192.000000,0.265728,141.684006
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa/b=1-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa/b=1-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa/b=2-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.236832,1.944196
3 | 2048.000000,0.454954,5.739056
4 | 4096.000000,0.604304,25.787445
5 | 6144.000000,0.281980,61.076481
6 | 8192.000000,0.001434,145.263779
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa/b=2-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa/b=2-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa/b=2-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.259367,3.784511
3 | 2048.000000,0.483416,11.617672
4 | 4096.000000,0.215821,51.328354
5 | 6144.000000,0.001877,100.000000
6 | 8192.000000,0.001536,100.000000
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa/b=2-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa/b=2-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa/b=2-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.395234,2.273197
3 | 2048.000000,0.564041,7.182823
4 | 4096.000000,0.001252,29.376169
5 | 6144.000000,0.001536,68.175011
6 | 8192.000000,0.001280,157.184326
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa/b=2-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa/b=2-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa/b=2-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.147088,1.814945
3 | 2048.000000,0.253838,5.710850
4 | 4096.000000,0.616780,24.963852
5 | 6144.000000,0.614312,60.137470
6 | 8192.000000,0.295790,143.506439
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa/b=2-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa/b=2-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa/b=2-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.csv:
--------------------------------------------------------------------------------
1 | seqlen,Triton,Jax-cudnn
2 | 1024.000000,0.213497,3.775728
3 | 2048.000000,0.390579,12.141111
4 | 4096.000000,0.481676,49.568768
5 | 6144.000000,0.090283,100.000000
6 | 8192.000000,0.001506,100.000000
7 |
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa/b=2-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa/b=2-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.png
--------------------------------------------------------------------------------
/benchmarks/triton-vs-jax-sdpa/results.html:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa/results.html
--------------------------------------------------------------------------------
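Aside (illustrative, not part of the repository): every benchmark CSV above shares the schema `seqlen,Triton,Jax-cudnn`, one row per sequence length, with what appear to be per-call forward timings (the unit is an assumption, not stated in the files). A minimal sketch for re-plotting one of the listed files, assuming the repository is checked out locally and pandas/matplotlib are installed (neither is a repo dependency):

import pandas as pd
import matplotlib.pyplot as plt

# One of the CSV files listed above; any of them has the same three columns.
path = "benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=False-hdim=128-qh=32-kvh=16-mode=fwd.csv"
df = pd.read_csv(path)

plt.plot(df["seqlen"], df["Triton"], marker="o", label="Triton")
plt.plot(df["seqlen"], df["Jax-cudnn"], marker="o", label="Jax-cudnn")
plt.xlabel("sequence length")
plt.ylabel("forward time (assumed ms)")
plt.yscale("log")
plt.legend()
plt.savefig("triton-vs-cudnn.png")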
/jax_flash_attn2/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .flash_attention import (
16 | AttentionConfig,
17 | Backend,
18 | FlashAttention,
19 | Platform,
20 | create_flash_attention,
21 | )
22 | from .flash_attention_jax import jax_flash_attention
23 | from .flash_attention_triton import triton_flash_attention
24 | from .refrence_call import basic_attention_refrence
25 |
26 | __all__ = (
27 | "AttentionConfig",
28 | "Backend",
29 | "FlashAttention",
30 | "Platform",
31 | "create_flash_attention",
32 | "triton_flash_attention",
33 | "jax_flash_attention",
34 | "basic_attention_refrence",
35 | )
36 |
37 | __version__ = "0.0.3"
38 |
--------------------------------------------------------------------------------
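Aside (illustrative sketch, not part of the repository): a minimal end-to-end use of the public API re-exported above, assuming the package and its dependencies are installed. It runs with the "cpu" backend and "jax" platform, which is a valid combination in flash_attention.py, so no GPU/TPU is needed. Shapes follow the (batch, seq_len, num_heads, head_dim) convention of FlashAttention.__call__; the sizes here are arbitrary.

import jax.numpy as jnp
from jax_flash_attn2 import create_flash_attention

# CPU + JAX platform is one of the valid backend/platform combinations.
attn = create_flash_attention(backend="cpu", platform="jax", blocksize_q=64, blocksize_k=64)

q = jnp.ones((1, 128, 8, 64), dtype=jnp.float16)   # 8 query heads
k = jnp.ones((1, 128, 4, 64), dtype=jnp.float16)   # 4 KV heads, repeated internally (GQA)
v = jnp.ones((1, 128, 4, 64), dtype=jnp.float16)

out = attn(query=q, key=k, value=v, causal=False)
print(out.shape)  # (1, 128, 8, 64)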
/jax_flash_attn2/flash_attention.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import os
17 | import typing as tp
18 | import warnings
19 | from dataclasses import dataclass
20 | from enum import Enum
21 | from functools import partial
22 |
23 | import chex
24 | import einops
25 | import jax
26 | import jax.numpy as jnp
27 |
28 | # fmt:off
29 | from jax.experimental.pallas.ops.tpu.flash_attention import BlockSizes as TPUBlockSizes
30 | from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention as pallas_flash_attention_tpu
31 | from jax.extend.backend import get_backend
32 | # fmt:on
33 |
34 | from .flash_attention_jax import jax_flash_attention
35 | from .flash_attention_triton import triton_flash_attention
36 |
37 | AVAILABLE_FLASH_ATTENTION2_PLATFORMS = tp.Literal["triton", "pallas", "jax"]
38 | AVAILABLE_BACKENDS = tp.Literal["gpu", "tpu", "cpu"]
39 |
40 |
41 | def get_device_memory_usage(device: jax.Device) -> float:
42 | """
43 | Get the memory usage for a specific JAX device using local_devices stats.
44 |
45 | Args:
46 | device: JAX device to check
47 | Returns:
48 | float: Memory usage in bytes
49 | """
50 | try:
51 | memory_stats = device.memory_stats()
52 | return memory_stats["bytes_in_use"] if memory_stats else float("inf")
53 | except: # noqa
54 | return float("inf")
55 |
56 |
57 | def free_gpu_in_process() -> int:
58 | """
59 | Returns the index of the GPU with the most available memory using JAX local_devices.
60 |
61 | Returns:
62 | int: Index of the GPU with most free memory
63 | """
64 | devices = jax.local_devices()
65 | gpu_devices = [d for d in devices if d.platform == "gpu"]
66 |
67 | if not gpu_devices:
68 | return 0
69 |
70 | memory_usage = [get_device_memory_usage(device) for device in gpu_devices]
71 | return memory_usage.index(min(memory_usage))
72 |
73 |
74 | class Backend(str, Enum):
75 | """Supported compute backends."""
76 |
77 | GPU = "gpu"
78 | TPU = "tpu"
79 | CPU = "cpu"
80 |
81 |
82 | class Platform(str, Enum):
83 | """Supported Flash Attention platforms."""
84 |
85 | TRITON = "triton"
86 | PALLAS = "pallas"
87 | JAX = "jax"
88 |
89 |
90 | @dataclass
91 | class AttentionConfig:
92 | """Configuration for Flash Attention computation."""
93 |
94 | blocksize_q: int = 128
95 | blocksize_k: int = 128
96 | softmax_scale: tp.Optional[float] = None
97 | backend: tp.Optional[Backend] = None
98 | platform: tp.Optional[Platform] = None
99 |
100 | def __post_init__(self):
101 | if self.backend is None:
102 | self.backend = Backend(get_backend().platform)
103 |
104 | if self.platform is None:
105 | self.platform = self._default_platform()
106 |
107 | def _default_platform(self) -> Platform:
108 | """Determines the default platform based on the backend."""
109 | platform_map = {
110 | Backend.GPU: Platform.TRITON,
111 | Backend.CPU: Platform.JAX,
112 | Backend.TPU: Platform.PALLAS,
113 | }
114 | return platform_map.get(self.backend)
115 |
116 |
117 | class FlashAttention:
118 | """Flash Attention implementation with multiple backend support."""
119 |
120 | def __init__(self, config: tp.Optional[AttentionConfig] = None):
121 | self.config = config or AttentionConfig()
122 | self._validate_config()
123 |
124 | def _validate_config(self):
125 | """Validates the configuration settings."""
126 | valid_combinations = {
127 | (Backend.GPU, Platform.TRITON),
128 | (Backend.GPU, Platform.PALLAS),
129 | (Backend.GPU, Platform.JAX),
130 | (Backend.CPU, Platform.JAX),
131 | (Backend.TPU, Platform.JAX),
132 | (Backend.TPU, Platform.PALLAS),
133 | }
134 |
135 | if (self.config.backend, self.config.platform) not in valid_combinations:
136 | raise ValueError(
137 | f"Invalid backend-platform combination: "
138 | f"{self.config.backend}-{self.config.platform}"
139 | )
140 |
141 | @staticmethod
142 | def repeat_kv_heads(
143 | key: chex.Array, value: chex.Array, num_reps: int
144 | ) -> tp.Tuple[chex.Array, chex.Array]:
145 | """Repeats key and value heads to match query heads."""
146 | return (
147 | einops.repeat(key, "b s h d -> b s (h r) d", r=num_reps),
148 | einops.repeat(value, "b s h d -> b s (h r) d", r=num_reps),
149 | )
150 |
151 | def _handle_bias(
152 | self, bias: chex.Array, num_q_heads: int, num_kv_heads: int
153 | ) -> tp.Optional[chex.Array]:
154 | """Processes attention bias based on head configuration."""
155 | if bias is None:
156 | return None
157 |
158 | if bias.shape[1] == num_q_heads or bias.shape[1] == 1:
159 | return bias
160 |
161 | elif bias.shape[1] == num_kv_heads:
162 | return einops.repeat(
163 | bias, "b h q k -> b (h r) q k", r=num_q_heads // bias.shape[1]
164 | )
165 | else:
166 | raise ValueError(
167 | f"Incompatible bias shape. Got {bias.shape[1]} heads, "
168 | f"expected {num_q_heads}, {num_kv_heads}, or 1"
169 | )
170 |
171 | def __call__(
172 | self,
173 | query: chex.Array,
174 | key: chex.Array,
175 | value: chex.Array,
176 | bias: tp.Optional[chex.Array] = None,
177 | attention_mask: tp.Optional[chex.Array] = None,
178 | dropout_prob: float = 0.0,
179 | causal: bool = False,
180 | dropout_seed: tp.Optional[int] = None,
181 | adjust_sharindgs: bool = False,
182 | ) -> chex.Array:
183 | """
184 | Computes flash attention using the configured backend and platform.
185 | """
186 | num_q_heads = query.shape[2]
187 | num_kv_heads = key.shape[2]
188 |
189 | if num_q_heads % num_kv_heads != 0:
190 | raise ValueError(
191 | f"Query heads ({num_q_heads}) must be divisible by "
192 | f"key/value heads ({num_kv_heads})"
193 | )
194 | if bias is not None:
195 | bias = self._handle_bias(bias, num_q_heads, num_kv_heads)
196 | kw = dict(
197 | query=query,
198 | key=key,
199 | value=value,
200 | bias=bias,
201 | adjust_sharindgs=adjust_sharindgs,
202 | attention_mask=attention_mask,
203 | causal=causal,
204 | dropout_prob=dropout_prob,
205 | dropout_seed=dropout_seed,
206 | )
207 | if self.config.platform == Platform.TRITON:
208 | return self._compute_triton(**kw)
209 | elif self.config.platform == Platform.PALLAS:
210 | return self._compute_pallas(**kw)
211 | else: # Platform.JAX
212 | return self._compute_jax(**kw)
213 |
214 | forward = __call__
215 |
216 | def _compute_triton(
217 | self,
218 | query: chex.Array,
219 | key: chex.Array,
220 | value: chex.Array,
221 | bias: tp.Optional[chex.Array] = None,
222 | attention_mask: tp.Optional[chex.Array] = None,
223 | dropout_prob: float = 0.0,
224 | causal: bool = False,
225 | dropout_seed: tp.Optional[int] = None,
226 | adjust_sharindgs: bool = False,
227 | ) -> chex.Array:
228 | """Computes attention using Triton backend."""
229 | if adjust_sharindgs:
230 | query_sharding = query.sharding if hasattr(query, "sharding") else None
231 | target_gpu_idx = int(os.environ.get("GPU_IDX_FLASH_ATTN", free_gpu_in_process()))
232 | devices = jax.local_devices(process_index=jax.process_index(), backend="gpu")
233 | target_device = devices[target_gpu_idx]
234 | query = jax.device_put(query, target_device)
235 | key = jax.device_put(key, target_device)
236 | value = jax.device_put(value, target_device)
237 | if bias is not None:
238 | bias = jax.device_put(bias, target_device)
239 | if query.shape[1] != key.shape[1]:
240 | if query.shape[1] == 1:
241 | causal = False
242 | attention_mask = None
243 |
244 | attn = triton_flash_attention(
245 | q=query,
246 | k=key,
247 | v=value,
248 | bias=bias,
249 | attention_mask=attention_mask,
250 | dropout_prob=dropout_prob,
251 | causal=causal,
252 | dropout_seed=dropout_seed,
253 | softmax_scale=self.config.softmax_scale,
254 | )
255 |
256 | if adjust_sharindgs and query_sharding is not None:
257 | attn = jax.device_put(attn, query_sharding)
258 | return attn
259 |
260 | def _compute_pallas(
261 | self,
262 | query: chex.Array,
263 | key: chex.Array,
264 | value: chex.Array,
265 | bias: tp.Optional[chex.Array] = None,
266 | attention_mask: tp.Optional[chex.Array] = None,
267 | dropout_prob: float = 0.0,
268 | causal: bool = False,
269 | dropout_seed: tp.Optional[int] = None,
270 | adjust_sharindgs: bool = False,
271 | ) -> chex.Array:
272 | """Computes attention using Pallas backend."""
273 |
274 | if self.config.backend == Backend.GPU:
275 | warnings.warn(
276 | "Pallas-FlashAttention has been deprecated on GPUs (triton backend will be used)",
277 | stacklevel=1,
278 | )
279 | return self._compute_triton(
280 | query=query,
281 | key=key,
282 | value=value,
283 | bias=bias,
284 | adjust_sharindgs=adjust_sharindgs,
285 | attention_mask=attention_mask,
286 | dropout_prob=dropout_prob,
287 | causal=causal,
288 | dropout_seed=dropout_seed,
289 | )
290 |
291 | key, value = self.repeat_kv_heads(key, value, query.shape[2] // key.shape[2])
292 | query_lenght = query.shape[1]
293 | value_lenght = value.shape[1]
294 | if bias is not None:
295 | if bias.shape[1] != value.shape[2]:
296 | bias = jnp.repeat(bias, value.shape[2] // bias.shape[1], 1)
297 | # TPU implementation
298 | block_sizes = TPUBlockSizes(
299 | block_q=min(self.config.blocksize_q, query_lenght),
300 | block_k_major=min(self.config.blocksize_k, value_lenght),
301 | block_k=min(self.config.blocksize_k, value_lenght),
302 | block_b=1,
303 | block_q_major_dkv=min(self.config.blocksize_q, query_lenght),
304 | block_k_major_dkv=min(self.config.blocksize_k, value_lenght),
305 | block_k_dkv=min(self.config.blocksize_k, value_lenght),
306 | block_q_dkv=min(self.config.blocksize_q, query_lenght),
307 | block_k_major_dq=min(self.config.blocksize_k, value_lenght),
308 | block_k_dq=min(self.config.blocksize_k, value_lenght),
309 | block_q_dq=min(self.config.blocksize_q, query_lenght),
310 | )
311 |
312 | return partial(
313 | pallas_flash_attention_tpu,
314 | sm_scale=self.config.softmax_scale,
315 | block_sizes=block_sizes,
316 | causal=causal,
317 | )(
318 | query.transpose(0, 2, 1, 3),
319 | key.transpose(0, 2, 1, 3),
320 | value.transpose(0, 2, 1, 3),
321 | bias,
322 | ).transpose(0, 2, 1, 3)
323 |
324 | def _compute_jax(
325 | self,
326 | query: chex.Array,
327 | key: chex.Array,
328 | value: chex.Array,
329 | bias: tp.Optional[chex.Array] = None,
330 | attention_mask: tp.Optional[chex.Array] = None,
331 | dropout_prob: float = 0.0,
332 | causal: bool = False,
333 | dropout_seed: tp.Optional[int] = None,
334 | adjust_sharindgs: bool = False,
335 | ) -> chex.Array:
336 | """Computes attention using JAX backend."""
337 | key, value = self.repeat_kv_heads(key, value, query.shape[2] // key.shape[2])
338 | return jax_flash_attention(
339 | query_state=query,
340 | key_state=key,
341 | value_state=value,
342 | mask=None,
343 | bias=bias,
344 | blocksize_q=self.config.blocksize_q,
345 | blocksize_k=self.config.blocksize_k,
346 | dtype=query.dtype,
347 | softmax_scale=self.config.softmax_scale,
348 | dropout=dropout_prob,
349 | )
350 |
351 | def __repr__(self):
352 | return (
353 | f"FlashAttention(platform={self.config.platform}, backend={self.config.backend})"
354 | )
355 |
356 | __str__ = __repr__
357 |
358 |
359 | def create_flash_attention(
360 | backend: tp.Optional[tp.Union[Backend, str]] = None,
361 | platform: tp.Optional[tp.Union[Platform, str]] = None,
362 | **kwargs,
363 | ) -> FlashAttention:
364 | """
365 | Factory function to create a FlashAttention instance with the specified configuration.
366 |
367 | Args:
368 | backend: Compute backend to use (GPU, TPU, or CPU)
369 | platform: Platform to use (Triton, Pallas, or JAX)
370 | **kwargs: Additional configuration parameters for AttentionConfig
371 |
372 | Returns:
373 | Configured FlashAttention instance
374 | """
375 | if isinstance(backend, str):
376 | backend = Backend(backend)
377 | if isinstance(platform, str):
378 | platform = Platform(platform)
379 |
380 | config = AttentionConfig(backend=backend, platform=platform, **kwargs)
381 | return FlashAttention(config)
382 |
--------------------------------------------------------------------------------
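Aside (illustrative, not part of the repository): how AttentionConfig resolves its defaults and how _validate_config rejects unsupported pairs. With no arguments, the backend is taken from get_backend().platform and mapped to a platform (gpu -> triton, cpu -> jax, tpu -> pallas); a pair outside the valid set, such as CPU with Triton, raises ValueError.

from jax_flash_attn2 import AttentionConfig, Backend, FlashAttention, Platform

cfg = AttentionConfig()            # backend inferred from the running JAX backend
print(cfg.backend, cfg.platform)   # e.g. Backend.CPU Platform.JAX on a CPU-only host

try:
    # (CPU, TRITON) is not in the valid combinations set, so this raises.
    FlashAttention(AttentionConfig(backend=Backend.CPU, platform=Platform.TRITON))
except ValueError as err:
    print(err)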
/jax_flash_attn2/flash_attention_jax/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from ._flash_attention import flash_attention as jax_flash_attention
16 |
17 | __all__ = ("jax_flash_attention",)
18 |
--------------------------------------------------------------------------------
/jax_flash_attn2/flash_attention_jax/_backward_jax.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | # Implementation based on FlashAttention 2 (https://arxiv.org/pdf/2307.08691) by @erfanzar,
17 | # with a few bug fixes and adjustments.
18 |
19 | import functools
20 | import typing as tp
21 |
22 | import jax
23 | import jax.numpy as jnp
24 | import jax.sharding
25 | from jax import lax
26 |
27 |
28 | @functools.partial(jax.named_call, name="_bwd_flash_attn")
29 | def _bwd_flash_attn(
30 | dropout: float,
31 | inference: bool,
32 | key: tp.Optional[jax.random.PRNGKey],
33 | blocksize_q: int,
34 | blocksize_k: int,
35 | dtype: tp.Optional[jnp.dtype],
36 | precision: lax.PrecisionLike,
37 | residuals,
38 | grad_in: jax.Array,
39 | ):
40 | """Backward pass of FlashAttention."""
41 |
42 | del dtype
43 | (
44 | O, # noqa: E741
45 | L,
46 | query_state,
47 | key_state,
48 | value_state,
49 | mask,
50 | bias,
51 | ) = residuals
52 | dO = grad_in
53 |
54 | b, h, _, d = query_state.shape
55 | q_seq = query_state.shape[2]
56 | k_seq = key_state.shape[2]
57 | assert q_seq % blocksize_q == 0
58 | assert k_seq % blocksize_k == 0
59 | Tr = q_seq // blocksize_q
60 | Tc = k_seq // blocksize_k
61 |
62 | D = jnp.sum(dO * O, axis=-1)
63 |
64 | dQ = (query_state * 0.0).astype(query_state.dtype)
65 | dK = (key_state * 0.0).astype(key_state.dtype)
66 | dV = (value_state * 0.0).astype(value_state.dtype)
67 | global_mask = mask
68 | is_causal = mask is not None
69 |
70 | @jax.jit
71 | @functools.partial(jax.named_call, name="_bwd_flash_attn_call_o")
72 | def call_o(state):
73 | j, dQ, dK, dV = state
74 | k_j = jax.lax.dynamic_slice_in_dim(key_state, j * blocksize_k, blocksize_k, 2)
75 | v_j = jax.lax.dynamic_slice_in_dim(value_state, j * blocksize_k, blocksize_k, 2)
76 |
77 | dK_j = jax.lax.dynamic_slice_in_dim(dK, j * blocksize_k, blocksize_k, 2)
78 | dV_j = jax.lax.dynamic_slice_in_dim(dV, j * blocksize_k, blocksize_k, 2)
79 |
80 | @jax.jit
81 | @functools.partial(jax.named_call, name="_bwd_flash_attn_call_o_call_qk")
82 | def do_inner_block(state):
83 | i, j, dQ, dK_j, dV_j = state
84 | q_i = jax.lax.dynamic_slice_in_dim(query_state, i * blocksize_q, blocksize_q, 2)
85 | dQ_i = jax.lax.dynamic_slice_in_dim(dQ, i * blocksize_q, blocksize_q, 2)
86 | dO_i = jax.lax.dynamic_slice_in_dim(dO, i * blocksize_q, blocksize_q, 2)
87 |
88 | L_i = jax.lax.dynamic_slice_in_dim(L, i * blocksize_q, blocksize_q, 2)
89 | D_i = jax.lax.dynamic_slice_in_dim(D, i * blocksize_q, blocksize_q, 2)
90 | s_ij = q_i @ k_j.transpose(0, 1, 3, 2)
91 | if dropout > 0 and not inference:
92 | rng = jax.random.fold_in(key, i * Tc + j)
93 | keep_prob = 1.0 - dropout
94 | broadcast_shape = list(s_ij.shape)
95 | mask = jax.random.bernoulli(rng, p=keep_prob, shape=broadcast_shape)
96 | mask = jnp.broadcast_to(mask, s_ij.shape)
97 | s_ij = lax.select(mask, s_ij / keep_prob, jnp.zeros_like(s_ij))
98 |
99 | if bias is not None:
100 | b_i = jax.lax.dynamic_slice_in_dim(bias, i * blocksize_q, blocksize_q, 2)
101 | b_ij = jax.lax.dynamic_slice_in_dim(b_i, j * blocksize_k, blocksize_k, 3)
102 | s_ij = s_ij + b_ij
103 |
104 | if global_mask is not None:
105 | ma_i = jax.lax.dynamic_slice_in_dim(
106 | global_mask, i * blocksize_q, blocksize_q, 2
107 | )
108 | ma_ij = jax.lax.dynamic_slice_in_dim(ma_i, j * blocksize_k, blocksize_k, 3)
109 | s_ij = jnp.where(ma_ij, s_ij, -1e10)
110 |
111 | p_ij = jnp.exp(s_ij - jnp.expand_dims(L_i, -1))
112 |
113 | if dropout > 0 and not inference:
114 | rng = jax.random.fold_in(key, i * Tc + j)
115 | keep_prob = 1.0 - dropout
116 | broadcast_shape = list(p_ij.shape)
117 | mask = jax.random.bernoulli(rng, p=keep_prob, shape=broadcast_shape)
118 | mask = jnp.broadcast_to(mask, p_ij.shape)
119 | p_ij = lax.select(mask, p_ij / keep_prob, jnp.zeros_like(p_ij))
120 |
121 | dV_j = dV_j + jnp.matmul(p_ij.transpose(0, 1, 3, 2), dO_i, precision=precision)
122 |
123 | dP_ij = jnp.matmul(dO_i, v_j.transpose(0, 1, 3, 2), precision=precision)
124 |
125 | dS_ij = p_ij * (dP_ij - D_i[..., None])
126 | dQ_i = dQ_i + jnp.matmul(dS_ij, k_j, precision=precision)
127 | dK_j = dK_j + jnp.matmul(dS_ij.transpose(0, 1, 3, 2), q_i, precision=precision)
128 | dQ = jax.lax.dynamic_update_slice_in_dim(
129 | dQ,
130 | dQ_i.astype(dQ.dtype),
131 | i * blocksize_q,
132 | 2,
133 | )
134 | return (
135 | i + 1,
136 | j,
137 | dQ.astype(query_state.dtype),
138 | dK_j.astype(key_state.dtype),
139 | dV_j.astype(value_state.dtype),
140 | )
141 |
142 | i_start = j if is_causal else 0
143 | _, j, dQ, dK_j, dV_j = jax.lax.while_loop(
144 | lambda state: state[0] < Tr,
145 | do_inner_block,
146 | (i_start, j, dQ, dK_j, dV_j),
147 | )
148 |
149 | dK = jax.lax.dynamic_update_slice_in_dim(
150 |       dK, dK_j.astype(dK.dtype), j * blocksize_k, 2
151 | )
152 | dV = jax.lax.dynamic_update_slice_in_dim(
153 |       dV, dV_j.astype(dV.dtype), j * blocksize_k, 2
154 | )
155 |
156 | return j + 1, dQ, dK, dV
157 |
158 | _, dQ, dK, dV = jax.lax.while_loop(
159 | lambda state: state[0] < Tc,
160 | call_o,
161 | (0, dQ, dK, dV),
162 | )
163 |
164 | return dQ, dK, dV, None, None
165 |
--------------------------------------------------------------------------------
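Aside (illustrative sketch, not repository code): the blocked loops above implement the standard attention backward. For a single (batch, head) slice, with L the logsumexp saved by the forward so that P = exp(S - L) recovers the softmax, the per-block updates for dV_j, dP_ij, dS_ij, dQ_i, and dK_j correspond to the dense formulas below. The helper name dense_attention_backward is hypothetical, introduced only for this sketch.

import jax.numpy as jnp

def dense_attention_backward(q, k, v, o, dO, L):
    """Dense reference for one (batch, head) slice; q/k/v/o/dO: (seq, d), L: (seq,)."""
    S = q @ k.T                         # pre-softmax scores (queries already scaled)
    P = jnp.exp(S - L[:, None])         # softmax rebuilt from the saved logsumexp
    dV = P.T @ dO
    dP = dO @ v.T
    D = jnp.sum(dO * o, axis=-1)        # rowwise correction, matches D = sum(dO * O) above
    dS = P * (dP - D[:, None])
    dQ = dS @ k
    dK = dS.T @ q
    return dQ, dK, dV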
/jax_flash_attn2/flash_attention_jax/_flash_attention.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | # Implementation based on FlashAttention 2 (https://arxiv.org/pdf/2307.08691) by @erfanzar,
17 | # with a few bug fixes and adjustments.
18 |
19 | import functools
20 | import math
21 | import typing as tp
22 |
23 | import jax
24 | import jax.numpy as jnp
25 | import jax.sharding
26 | from jax import lax
27 |
28 | from ._backward_jax import _bwd_flash_attn
29 | from ._forward_jax import _fwd_flash_attn
30 |
31 |
32 | @functools.partial(
33 | jax.jit,
34 | static_argnames=[
35 | "dtype",
36 | "precision",
37 | "blocksize_q",
38 | "blocksize_k",
39 | ],
40 | )
41 | def flash_attention(
42 | query_state: jax.Array,
43 | key_state: jax.Array,
44 | value_state: jax.Array,
45 | mask: tp.Optional[jax.Array] = None,
46 | bias: tp.Optional[jax.Array] = None,
47 | *,
48 | dropout: float = 0.0,
49 | inference: bool = True,
50 | key: tp.Optional[jax.random.PRNGKey] = None,
51 | blocksize_q: tp.Optional[int] = None,
52 | blocksize_k: tp.Optional[int] = None,
53 | dtype: tp.Optional[jnp.dtype] = None,
54 | precision: lax.PrecisionLike = None,
55 | head_dim: tp.Optional[int] = None,
56 | softmax_scale: tp.Optional[float] = None,
57 | ) -> jax.Array:
58 | """
59 | Computes multi-head attention using FlashAttention implementation.
60 |
61 | This implementation makes use of the FlashAttention algorithm for faster
62 | and more memory-efficient computation of attention. It is particularly
63 | beneficial for long sequences.
64 |
65 | Args:
66 | query_state: Query, shape (`batch_size`, `q_len`, `num_heads`, `head_dim`).
67 | key_state: Key, shape (`batch_size`, `kv_len`, `num_heads`, `head_dim`).
68 | value_state: Value, shape (`batch_size`, `kv_len`, `num_heads`, `head_dim`).
69 | mask: tp.Optional attention mask. This can be any of the following:
70 |
71 | - No mask (default): All attention weights are computed.
72 | - Boolean mask (2D): shape (`batch_size`, `q_len`), with `True` for
73 | valid and `False` for masked positions.
74 | - Integer mask (2D): shape (`batch_size`, `q_len`), where the value at
75 | each position indicates the length of the sequence to attend to.
76 |     - 4D mask: shape (`batch_size`, `num_heads`, `q_len`, `kv_len`), with `True` for
77 | valid and `False` for masked positions.
78 |
79 | bias: tp.Optional attention bias.
80 | dropout: Dropout rate.
81 | inference: Whether to run in inference mode.
82 | key: PRNG key for dropout.
83 | blocksize_q: Block size for query processing.
84 | blocksize_k: Block size for key/value processing.
85 | dtype: tp.Optional dtype for the output.
86 | precision: tp.Optional precision for matrix multiplication.
87 | head_dim: tp.Optional head dim to be used at `query_state = query_state / math.sqrt(float(head_dim or query_state.shape[-1]))`.
88 |     softmax_scale: tp.Optional softmax_scale to be used for `query_state = query_state * softmax_scale`.
89 |
90 | Returns:
91 | Output of multi-head attention, with shape
92 | (`batch_size`, `q_len`, `num_heads`, `head_dim`).
93 |
94 | Raises:
95 | ValueError: If `dropout` is not in the range [0, 1], or if `key` is not
96 | provided during training when `dropout` > 0.
97 | """
98 | query_state, key_state, value_state = map(
99 | lambda x: x.transpose(0, 2, 1, 3),
100 | [query_state, key_state, value_state],
101 | )
102 | if not inference and dropout > 0 and key is None:
103 | raise ValueError("key must be provided for training")
104 | if dropout < 0 or dropout > 1:
105 | raise ValueError(f"invalid dropout {dropout}")
106 | if dtype is not None:
107 | query_state = query_state.astype(dtype)
108 | key_state = key_state.astype(dtype)
109 |
110 | blocksize_k = min(key_state.shape[2], blocksize_k or 128)
111 | blocksize_q = min(query_state.shape[2], blocksize_q or 128)
112 | if head_dim is not None and softmax_scale is not None:
113 | raise ValueError("you can't pass both `head_dim` and `softmax_scale`.")
114 | if head_dim is not None:
115 | query_state = query_state / math.sqrt(float(head_dim))
116 | elif softmax_scale is not None:
117 | query_state = query_state * softmax_scale
118 | else:
119 | query_state = query_state / math.sqrt(float(query_state.shape[-1]))
120 | return _flash_attn2(
121 | query_state,
122 | key_state,
123 | value_state,
124 | mask,
125 | bias,
126 | dropout,
127 | inference,
128 | key,
129 | blocksize_q,
130 | blocksize_k,
131 | dtype,
132 | precision,
133 | ).transpose(0, 2, 1, 3)
134 |
135 |
136 | @functools.partial(
137 | jax.custom_vjp,
138 | nondiff_argnums=(5, 6, 7, 8, 9, 10, 11),
139 | )
140 | def _flash_attn2(
141 | query_state: jax.Array,
142 | key_state: jax.Array,
143 | value_state: jax.Array,
144 | mask: tp.Optional[jax.Array] = None,
145 | bias: tp.Optional[jax.Array] = None,
146 | dropout: float = 0.0,
147 | inference: bool = False,
148 | key: tp.Optional[jax.random.PRNGKey] = None,
149 | blocksize_q: int = 128,
150 | blocksize_k: int = 128,
151 | dtype: tp.Optional[jnp.dtype] = jnp.float32,
152 | precision: lax.PrecisionLike = None,
153 | ) -> jax.Array:
154 | """Custom VJP-enabled wrapper for FlashAttention forward pass."""
155 | return _fwd_flash_attn(
156 | query_state,
157 | key_state,
158 | value_state,
159 | mask,
160 | bias,
161 | dropout,
162 | inference,
163 | key,
164 | blocksize_q,
165 | blocksize_k,
166 | dtype,
167 | precision,
168 | )[0]
169 |
170 |
171 | _flash_attn2.defvjp(_fwd_flash_attn, _bwd_flash_attn)
172 |
173 |
174 | def fwd_test():
175 | import flax
176 | from jax import random as jrand
177 |
178 | from easydel.utils import GenerateRNG
179 |
180 | rng = GenerateRNG()
181 |
182 | b, h, qs, s, d = 1, 32, 2048, 2048, 128
183 | dtype = jnp.float16
184 |
185 | q = jrand.normal(rng.rng, shape=(b, qs, h, d), dtype=dtype)
186 | k = jrand.normal(rng.rng, shape=(b, s, h, d), dtype=dtype)
187 | v = jrand.normal(rng.rng, shape=(b, s, h, d), dtype=dtype)
188 | b = jnp.where(
189 | jrand.randint(rng.rng, shape=(b, h, qs, s), minval=0, maxval=3) > 1,
190 | 0,
191 | jnp.finfo(dtype).min,
192 | )
193 | excepted_result = flax.nnx.dot_product_attention(
194 | query=q,
195 | key=k,
196 | value=v,
197 | bias=b,
198 | )
199 | result = flash_attention(
200 | query_state=q,
201 | key_state=k,
202 | value_state=v,
203 | bias=b,
204 | dtype=dtype,
205 | blocksize_q=64,
206 | blocksize_k=64,
207 | )
208 |
209 | print(f"PRED : {result[0, 0, 0, :5]}")
210 | print(f"ORGN : {excepted_result[0, 0, 0, :5]}")
211 |
212 | print(jnp.allclose(excepted_result, result, atol=0.125, rtol=0))
213 |
214 |
215 | def bwd_test():
216 | import flax
217 | from jax import random as jrand
218 |
219 | from easydel.utils import GenerateRNG
220 |
221 | rng = GenerateRNG()
222 | b, h, qs, s, d = 2, 32, 64, 64, 64
223 | dtype = jnp.float16
224 |
225 | q = jrand.normal(rng.rng, shape=(b, qs, h, d), dtype=dtype)
226 | k = jrand.normal(rng.rng, shape=(b, s, h, d), dtype=dtype)
227 | v = jrand.normal(rng.rng, shape=(b, s, h, d), dtype=dtype)
228 | b = jnp.where(
229 | jrand.randint(rng.rng, shape=(b, h, qs, s), minval=0, maxval=3) > 1,
230 | 0,
231 | jnp.finfo(dtype).min,
232 | )
233 |
234 | excepted_result = jax.grad(lambda *x: flax.nnx.dot_product_attention(*x).sum())(
235 | q, k, v
236 | )
237 | result = jax.grad(
238 | lambda *x: flash_attention(
239 | *x,
240 | dtype=dtype,
241 | blocksize_q=qs,
242 | blocksize_k=s,
243 | precision=jax.lax.Precision("HIGHEST".lower()),
244 | ).sum()
245 | )(q, k, v)
246 |
247 | print(f"PRED BWD : {result[0, 0, 0, :5]}")
248 | print(f"ORGN BWD : {excepted_result[0, 0, 0, :5]}")
249 |
250 | print(jnp.allclose(excepted_result, result, atol=0.125, rtol=0))
251 |
252 |
253 | jax_flash_attn_2_mu = flash_attention
254 |
255 | __all__ = ["jax_flash_attn_2_mu"]
256 |
257 | if __name__ == "__main__":
258 | # fwd_test()
259 | bwd_test()
260 |
--------------------------------------------------------------------------------
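Aside (illustrative sketch, not part of the repository): calling the JAX kernel above directly, assuming the package and its dependencies are installed. Inputs are (batch, seq_len, num_heads, head_dim), both sequence lengths must be divisible by the chosen block sizes, and head_dim and softmax_scale are mutually exclusive ways to set the query scaling. The tensor sizes are arbitrary.

import jax
import jax.numpy as jnp
from jax_flash_attn2 import jax_flash_attention

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (1, 256, 8, 64), dtype=jnp.float16)
k = jax.random.normal(kk, (1, 256, 8, 64), dtype=jnp.float16)
v = jax.random.normal(kv, (1, 256, 8, 64), dtype=jnp.float16)

out = jax_flash_attention(
    q, k, v,
    blocksize_q=128,                 # 256 % 128 == 0, as the asserts require
    blocksize_k=128,
    softmax_scale=1.0 / 64 ** 0.5,   # same value as the default 1/sqrt(head_dim)
)
print(out.shape)  # (1, 256, 8, 64)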
/jax_flash_attn2/flash_attention_jax/_forward_jax.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import functools
16 | import typing as tp
17 |
18 | import jax
19 | import jax.numpy as jnp
20 | import jax.sharding
21 | from eformer.escale import with_sharding_constraint
22 | from jax import lax
23 |
24 |
25 | @functools.partial(jax.named_call, name="_fwd_flash_attn")
26 | def _fwd_flash_attn(
27 | query_state: jax.Array,
28 | key_state: jax.Array,
29 | value_state: jax.Array,
30 | mask: tp.Optional[jax.Array],
31 | bias: tp.Optional[jax.Array],
32 | dropout: float,
33 | inference: bool,
34 | key: tp.Optional[jax.random.PRNGKey],
35 | blocksize_q: int,
36 | blocksize_k: int,
37 | dtype: tp.Optional[jnp.dtype],
38 | precision: lax.PrecisionLike,
39 | ) -> tuple[jax.Array, tuple[jax.Array, ...]]:
40 | """Forward pass of FlashAttention."""
41 | b, h, _, d = query_state.shape
42 | q_seq = query_state.shape[2]
43 | k_seq = key_state.shape[2]
44 | assert q_seq % blocksize_q == 0, (
45 |     "Query sequence length is not divisible by the query block size"
46 | )
47 |   assert k_seq % blocksize_k == 0, "Key sequence length is not divisible by the key block size"
48 | Tr = q_seq // blocksize_q
49 | Tc = k_seq // blocksize_k
50 | o_shape = jax.eval_shape(
51 | lambda: (query_state @ key_state.transpose(0, 1, 3, 2)) @ value_state
52 | ).shape
53 | o = jnp.zeros(o_shape, dtype=dtype)
54 |
55 | lse = jnp.full((b, h, q_seq), fill_value=-jnp.inf, dtype=jnp.float32)
56 | if hasattr(query_state, "sharding"):
57 | if isinstance(query_state.sharding, jax.sharding.NamedSharding):
58 | with query_state.sharding.mesh:
59 | o = with_sharding_constraint(
60 | arr=o,
61 | sharding=query_state.sharding,
62 | )
63 | lse = with_sharding_constraint(
64 | arr=lse,
65 | sharding=jax.sharding.PartitionSpec(*query_state.sharding.spec[:3]),
66 | )
67 | elif isinstance(
68 | query_state.sharding, jax.sharding.SingleDeviceSharding
69 | ) and hasattr(query_state.sharding, "_device"):
70 | o = jax.device_put(o, query_state.sharding._device)
71 | lse = jax.device_put(lse, query_state.sharding._device)
72 |
73 | global_mask = mask
74 |
75 | @jax.jit
76 | @functools.partial(jax.named_call, name="_fwd_flash_attn_call_o")
77 | def call_o(state):
78 | i, o, lse = state
79 | q_i = jax.lax.dynamic_slice_in_dim(query_state, i * blocksize_q, blocksize_q, 2)
80 | o_i = jax.lax.dynamic_slice_in_dim(o, i * blocksize_q, blocksize_q, 2)
81 | lse_i = jax.lax.dynamic_slice_in_dim(lse, i * blocksize_q, blocksize_q, 2)
82 | m_i = jnp.full((b, h, blocksize_q), fill_value=-jnp.inf, dtype=dtype)
83 |
84 | @jax.jit
85 | @functools.partial(jax.named_call, name="_fwd_flash_attn_call_o_call_qk")
86 | def call_qk(state):
87 | i, j, o_i, q_i, lse_i, m_i = state
88 | k_j = jax.lax.dynamic_slice_in_dim(key_state, j * blocksize_k, blocksize_k, 2)
89 | v_j = jax.lax.dynamic_slice_in_dim(value_state, j * blocksize_k, blocksize_k, 2)
90 |
91 | s_ij = jnp.einsum(
92 | "bhqd,bhdk->bhqk",
93 | q_i,
94 | k_j.transpose(0, 1, 3, 2),
95 | precision=precision,
96 | )
97 |
98 | if bias is not None:
99 | b_i = jax.lax.dynamic_slice_in_dim(bias, i * blocksize_q, blocksize_q, 2)
100 | b_ij = jax.lax.dynamic_slice_in_dim(b_i, j * blocksize_k, blocksize_k, 3)
101 | s_ij = s_ij + b_ij
102 | if global_mask is not None:
103 | ma_i = jax.lax.dynamic_slice_in_dim(
104 | global_mask, i * blocksize_q, blocksize_q, 2
105 | )
106 | ma_ij = jax.lax.dynamic_slice_in_dim(ma_i, j * blocksize_k, blocksize_k, 3)
107 | s_ij = jnp.where(ma_ij, s_ij, -1e10)
108 |
109 | if dropout > 0 and not inference:
110 | rng = jax.random.fold_in(key, i * Tc + j)
111 | keep_prob = 1.0 - dropout
112 | broadcast_shape = list(s_ij.shape)
113 | mask = jax.random.bernoulli(rng, p=keep_prob, shape=broadcast_shape)
114 | mask = jnp.broadcast_to(mask, s_ij.shape)
115 | s_ij = lax.select(mask, s_ij / keep_prob, jnp.zeros_like(s_ij))
116 |
117 | m_ij = jnp.maximum(m_i, jnp.max(s_ij, axis=-1))
118 | p = jnp.exp(s_ij - jnp.expand_dims(m_ij, -1))
119 |
120 | l_ij = jnp.sum(p, -1)
121 |
122 | o_scale = jnp.exp(m_i - m_ij)
123 | o_i = o_i * jnp.expand_dims(o_scale, -1)
124 |
125 | o_i = o_i + jnp.einsum(
126 | "bhqk,bhkd->bhqd",
127 | p,
128 | v_j,
129 | precision=precision,
130 | )
131 |
132 | return (
133 | i,
134 | j + 1,
135 | o_i.astype(dtype),
136 | q_i.astype(dtype),
137 | jnp.log(jnp.exp(lse_i - m_ij) + l_ij) + m_ij,
138 | m_ij.astype(dtype),
139 | )
140 |
141 | j_end = jnp.minimum(i + 1, Tc) if mask is not None else Tc
142 |
143 | _, _, o_i, _, lse_i, m_i = jax.lax.while_loop(
144 | lambda state: state[1] < j_end,
145 | call_qk,
146 | (i, 0, o_i, q_i, lse_i, m_i),
147 | )
148 | o_scale = jnp.exp(m_i - lse_i)
149 | o_i = o_i * jnp.expand_dims(o_scale, -1)
150 |
151 | o = jax.lax.dynamic_update_slice_in_dim(o, o_i.astype(o.dtype), i * blocksize_q, 2)
152 | lse = jax.lax.dynamic_update_slice_in_dim(
153 | lse,
154 | lse_i.astype(lse.dtype),
155 | i * blocksize_q,
156 | 2,
157 | )
158 | return i + 1, o, lse
159 |
160 | _, o, lse = jax.lax.while_loop(lambda state: state[0] < Tr, call_o, (0, o, lse))
161 |
162 | return o, (
163 | o,
164 | lse,
165 | query_state, #: jax.Array
166 | key_state, #: jax.Array
167 | value_state, #: jax.Array
168 | mask,
169 | bias,
170 | )
171 |
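The block loop above folds key/value tiles into the output with the standard online-softmax recurrence: a running max `m_i`, a running log-sum-exp `lse_i`, and a rescale of the partial output each time the max changes. A minimal standalone sketch of that recurrence on toy shapes (illustrative only, not part of this file), checked against a plain softmax:

import jax
import jax.numpy as jnp

def online_softmax_matmul(scores, values, block=2):
    # scores: (q, k); values: (k, d); fold key blocks in one at a time.
    q_len, k_len = scores.shape
    m = jnp.full((q_len,), -jnp.inf)      # running max
    lse = jnp.full((q_len,), -jnp.inf)    # running log-sum-exp
    acc = jnp.zeros((q_len, values.shape[1]))
    for j in range(0, k_len, block):
        s = scores[:, j:j + block]
        m_new = jnp.maximum(m, s.max(axis=-1))
        p = jnp.exp(s - m_new[:, None])
        acc = acc * jnp.exp(m - m_new)[:, None] + p @ values[j:j + block]
        lse = jnp.log(jnp.exp(lse - m_new) + p.sum(-1)) + m_new
        m = m_new
    return acc * jnp.exp(m - lse)[:, None]   # same final rescale as in the kernel above

scores = jnp.arange(12.0).reshape(3, 4) / 7.0
values = jnp.linspace(0.0, 1.0, 8).reshape(4, 2)
assert jnp.allclose(
    online_softmax_matmul(scores, values),
    jax.nn.softmax(scores, axis=-1) @ values,
    atol=1e-5,
)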
--------------------------------------------------------------------------------
/jax_flash_attn2/flash_attention_triton/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from ._flash_attention import flash_attention as triton_flash_attention
16 |
17 | __all__ = ("triton_flash_attention",)
18 |
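A minimal usage sketch for the exported symbol; the shapes and dtypes below are illustrative only, and a CUDA device with Triton installed is assumed:

import jax
import jax.numpy as jnp
from jax_flash_attn2.flash_attention_triton import triton_flash_attention

# (batch, seq, heads, head_dim) layout; query heads must be a multiple of KV heads.
b, s, hq, hkv, d = 2, 1024, 32, 16, 128
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (b, s, hq, d), dtype=jnp.bfloat16)
k = jax.random.normal(kk, (b, s, hkv, d), dtype=jnp.bfloat16)
v = jax.random.normal(kv, (b, s, hkv, d), dtype=jnp.bfloat16)

out = triton_flash_attention(q, k, v, causal=True)  # -> (b, s, hq, d)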
--------------------------------------------------------------------------------
/jax_flash_attn2/flash_attention_triton/_backward_triton.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import math
16 | import typing as tp
17 |
18 | import chex
19 | import jax
20 | import triton
21 | import triton.language as tl
22 | from eformer.callib import triton_call
23 | from jax import numpy as jnp
24 | from triton import Config
25 |
26 | from ._utils import (
27 | dtype_index,
28 | get_sharding,
29 | get_strides,
30 | safe_autotune,
31 | attention_pack_with_static_shape,
32 | attention_unpack_with_static_shape,
33 | calc_bias_strides,
34 | padded_load,
35 | )
36 |
37 |
38 | def config_prune_kernel(
39 | configs: tp.List[Config],
40 | named_args: tp.Dict[str, tp.Any],
41 | **kwargs,
42 | ) -> tp.List[Config]:
43 | kept_configs = []
44 | for config in configs:
45 | largest_m = (
46 | max(
47 | config.kwargs["BLOCK_M1"],
48 | config.kwargs["BLOCK_M2"],
49 | )
50 | > named_args["QSeq"]
51 | )
52 | largest_n = (
53 | max(
54 | config.kwargs["BLOCK_N1"],
55 | config.kwargs["BLOCK_N2"],
56 | )
57 | > named_args["KSeq"]
58 | )
59 | if largest_m or largest_n:
60 | pass
61 | else:
62 | kept_configs.append(config)
63 | if kept_configs:
64 | return kept_configs
65 | return [
66 | Config(
67 | {
68 | "BLOCK_M1": 32,
69 | "BLOCK_N1": 32,
70 | "BLOCK_M2": 32,
71 | "BLOCK_N2": 32,
72 | },
73 | num_warps=4,
74 | num_stages=0,
75 | )
76 | ]
77 |
78 |
79 | @safe_autotune(
80 | configs=[
81 | Config({"BLOCK_M": 16}, num_warps=4, num_stages=0),
82 | Config({"BLOCK_M": 32}, num_warps=4, num_stages=0),
83 | Config({"BLOCK_M": 64}, num_warps=4, num_stages=0),
84 | Config({"BLOCK_M": 128}, num_warps=4, num_stages=0),
85 | ],
86 | key=["CQSeq", "DRuntime"],
87 | )
88 | @triton.jit
89 | def _attn_bwd_preprocess(
90 | Po,
91 | Do,
92 | stride_oz,
93 | stride_om,
94 | stride_oh,
95 | stride_dez,
96 | stride_dem,
97 | stride_deh,
98 | nheads,
99 | QSeq,
100 | max_seqlen_q_rounded,
101 | cum_seqlens_q,
102 | headdim,
103 | CQSeq, # Re-compile argument
104 | DRuntime, # Re-compile argument
105 | Delta,
106 | VARLEN: tl.constexpr,
107 | BLOCK_M: tl.constexpr,
108 | BLOCK_HEADDIM: tl.constexpr,
109 | ):
110 | start_m = tl.program_id(0)
111 | off_zh = tl.program_id(1)
112 | off_z = off_zh // nheads
113 | off_h = off_zh % nheads
114 | offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
115 | offs_d = tl.arange(0, BLOCK_HEADDIM)
116 |
117 | if VARLEN:
118 | start_seqlen_q = tl.load(cum_seqlens_q + off_z)
119 | actual_seqlen_q = tl.load(cum_seqlens_q + off_z + 1) - start_seqlen_q
120 | cu_seq_start_q = tl.load(cum_seqlens_q + off_z)
121 | off_z = 0
122 | else:
123 | actual_seqlen_q = QSeq
124 | cu_seq_start_q = 0
125 |
126 | o_ptrs = (
127 | Po
128 | + off_z * stride_oz
129 | + off_h * stride_oh
130 | + cu_seq_start_q * stride_om
131 | + offs_m[:, None] * stride_om
132 | + offs_d[None, :]
133 | )
134 | do_ptrs = (
135 | Do
136 | + off_z * stride_dez
137 | + off_h * stride_deh
138 | + cu_seq_start_q * stride_dem
139 | + offs_m[:, None] * stride_dem
140 | + offs_d[None, :]
141 | )
142 |
143 | mask = (offs_m[:, None] < actual_seqlen_q) & (offs_d[None, :] < headdim)
144 | o = tl.load(o_ptrs, mask=mask, other=0.0).to(tl.float32)
145 | do = tl.load(do_ptrs, mask=mask, other=0.0).to(tl.float32)
146 | delta = tl.sum(o * do, axis=1)
147 | tl.store(Delta + off_zh * max_seqlen_q_rounded + offs_m, delta)
148 |
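# Note: the preprocess kernel above materializes Delta[i] = sum_d O[i, d] * dO[i, d],
# the row-wise dot product of the forward output with its incoming gradient. It is
# re-read below (as `Di`) when forming ds = p * (dp - Di) in the dq/dk/dv kernels.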
149 |
150 | @triton.jit
151 | def _attn_bwd_dkdv(
152 | index_start_m,
153 | k,
154 | v,
155 | dk,
156 | dv,
157 | M,
158 | D,
159 | offs_m,
160 | offs_n,
161 | offs_d,
162 | q_ptrs,
163 | bias_ptrs,
164 | dropout_offs,
165 | do_ptrs,
166 | softmax_scale,
167 | stride_qm,
168 | stride_bm,
169 | stride_dom,
170 | actual_seqlen_q,
171 | actual_seqlen_k,
172 | fully_masked_lines,
173 | headdim,
174 | MASKED: tl.constexpr,
175 | IS_CAUSAL: tl.constexpr,
176 | BIAS_ON: tl.constexpr,
177 | USE_DROPOUT: tl.constexpr,
178 | PAD_ROWS: tl.constexpr,
179 | PAD_COLS: tl.constexpr,
180 | HEADS_PADDED: tl.constexpr,
181 | ):
182 |     LN2: tl.constexpr = 1.44269504089  # log2(e): lets tl.exp2 reproduce exp(qk * softmax_scale)
183 | q_ptrs = q_ptrs + index_start_m * stride_qm
184 | do_ptrs = do_ptrs + index_start_m * stride_dom
185 | if BIAS_ON:
186 | bias_ptrs = bias_ptrs + index_start_m * stride_bm
187 | if USE_DROPOUT:
188 | dropout_offs += index_start_m * actual_seqlen_k
189 |
190 | offs_m_curr = index_start_m + offs_m
191 |
192 | q = padded_load(
193 | q_ptrs,
194 | offs_m_curr,
195 | offs_d,
196 | PAD_ROWS or HEADS_PADDED,
197 | PAD_ROWS or HEADS_PADDED,
198 | actual_seqlen_q,
199 | headdim,
200 | )
201 | me_i = tl.load(M + offs_m_curr)
202 | if BIAS_ON:
203 | bias = padded_load(
204 | bias_ptrs,
205 | offs_m_curr,
206 | offs_n,
207 | PAD_ROWS or HEADS_PADDED,
208 | PAD_ROWS or HEADS_PADDED,
209 | actual_seqlen_q,
210 | actual_seqlen_k,
211 | )
212 |
213 | qk = tl.dot(q, tl.trans(k))
214 | if BIAS_ON:
215 | qk += bias / softmax_scale
216 |
217 | offs_n_causal = offs_n - actual_seqlen_k + actual_seqlen_q
218 | if MASKED:
219 | if PAD_COLS:
220 | if IS_CAUSAL:
221 | qk = tl.where(
222 | tl.minimum(actual_seqlen_q - 1, offs_m_curr)[:, None]
223 | >= offs_n_causal[None, :],
224 | qk,
225 | float("-inf"),
226 | )
227 | else:
228 | qk = tl.where(actual_seqlen_q - 1 >= offs_n_causal[None, :], qk, float("-inf"))
229 | elif IS_CAUSAL:
230 | qk = tl.where(offs_m_curr[:, None] >= offs_n_causal[None, :], qk, float("-inf"))
231 | tl.debug_barrier()
232 | p = tl.exp2(qk * (softmax_scale * LN2) - me_i[:, None])
233 |
234 | if MASKED:
235 | if fully_masked_lines > 0:
236 | p = tl.where(offs_m_curr[:, None] < fully_masked_lines, 0, p)
237 |
238 | do = padded_load(
239 | do_ptrs,
240 | offs_m_curr,
241 | offs_d,
242 | PAD_ROWS,
243 | HEADS_PADDED,
244 | actual_seqlen_q,
245 | headdim,
246 | )
247 |
248 | dv += tl.dot(tl.trans(p).to(do.dtype), do)
249 | dp = tl.dot(do, tl.trans(v))
250 | Di = tl.load(D + offs_m_curr)
251 | ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype)
252 | dk += tl.dot(tl.trans(ds), q)
253 |
254 | return dk, dv
255 |
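# Note: for one (key, value) tile, _attn_bwd_dkdv accumulates the FlashAttention-style
# gradients dv += p^T @ dO and dk += ds^T @ q with ds = p * (dp - Delta) * softmax_scale,
# where p is recomputed from q, k and the saved row statistics M instead of being stored.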
256 |
257 | @triton.jit
258 | def _attn_bwd_block_dkdv(
259 | index_start_n,
260 | Q,
261 | K,
262 | V,
263 | B,
264 | Dropout,
265 | Do,
266 | Dk,
267 | Dv,
268 | M,
269 | D,
270 | softmax_scale,
271 | stride_qm,
272 | stride_kn,
273 | stride_vn,
274 | stride_bm,
275 | stride_dom,
276 | stride_dkn,
277 | stride_dvn,
278 | actual_seqlen_q,
279 | actual_seqlen_k,
280 | headdim,
281 | IS_CAUSAL: tl.constexpr,
282 | BIAS_ON: tl.constexpr,
283 | USE_DROPOUT: tl.constexpr,
284 | PAD_COLS: tl.constexpr,
285 | HEADS_PADDED: tl.constexpr,
286 | BLOCK_M: tl.constexpr,
287 | BLOCK_N: tl.constexpr,
288 | BLOCK_HEADDIM: tl.constexpr,
289 | ):
290 | index_begin_m = (
291 | max(index_start_n + actual_seqlen_q - actual_seqlen_k, 0) if IS_CAUSAL else 0
292 | )
293 | index_begin_m = (index_begin_m // BLOCK_M) * BLOCK_M
294 | index_end_m = actual_seqlen_q
295 |
296 | fully_masked_lines = (actual_seqlen_q - actual_seqlen_k) if IS_CAUSAL else 0
297 | if (index_begin_m >= actual_seqlen_q) or (index_start_n >= actual_seqlen_k):
298 | return
299 |
300 | offs_n = index_start_n + tl.arange(0, BLOCK_N)
301 | offs_m = tl.arange(0, BLOCK_M)
302 | offs_d = tl.arange(0, BLOCK_HEADDIM)
303 |
304 | q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_d[None, :])
305 | k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :])
306 | v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
307 | dk_ptrs = Dk + (offs_n[:, None] * stride_dkn + offs_d[None, :])
308 | dv_ptrs = Dv + (offs_n[:, None] * stride_dvn + offs_d[None, :])
309 | do_ptrs = Do + (offs_m[:, None] * stride_dom + offs_d[None, :])
310 | if BIAS_ON:
311 | bias_ptrs = B + (offs_m[:, None] * stride_bm + offs_n[None, :])
312 | else:
313 | bias_ptrs = None
314 | if USE_DROPOUT:
315 | dropout_offs = Dropout + offs_m[:, None] * actual_seqlen_k + offs_n[None, :]
316 | else:
317 | dropout_offs = None
318 | dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
319 | dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
320 | k = padded_load(
321 | k_ptrs,
322 | offs_n,
323 | offs_d,
324 | PA0=PAD_COLS,
325 | PA1=HEADS_PADDED,
326 | LA0=actual_seqlen_k,
327 | LA1=headdim,
328 | )
329 | v = padded_load(
330 | v_ptrs,
331 | offs_n,
332 | offs_d,
333 | PA0=PAD_COLS,
334 | PA1=HEADS_PADDED,
335 | LA0=actual_seqlen_k,
336 | LA1=headdim,
337 | )
338 | # fmt:off
339 | fr = max(0, index_start_n + BLOCK_N - 1 + actual_seqlen_q - actual_seqlen_k)
340 | fb = BLOCK_M * ((min(fr, actual_seqlen_q) + BLOCK_M - 1) // BLOCK_M)
341 | num_masked_blocks = (fb - index_begin_m) // BLOCK_M if IS_CAUSAL else 0
342 | index_next_start_m = index_begin_m
343 | # fmt:on
344 |
345 | if num_masked_blocks > 0:
346 | for _ in range(0, num_masked_blocks):
347 | dk, dv = _attn_bwd_dkdv(
348 | index_next_start_m,
349 | k,
350 | v,
351 | dk,
352 | dv,
353 | M,
354 | D,
355 | offs_m,
356 | offs_n,
357 | offs_d,
358 | q_ptrs,
359 | bias_ptrs,
360 | dropout_offs,
361 | do_ptrs,
362 | softmax_scale,
363 | stride_qm,
364 | stride_bm,
365 | stride_dom,
366 | actual_seqlen_q,
367 | actual_seqlen_k,
368 | fully_masked_lines,
369 | headdim,
370 | MASKED=True,
371 | IS_CAUSAL=IS_CAUSAL,
372 | BIAS_ON=BIAS_ON,
373 | USE_DROPOUT=USE_DROPOUT,
374 | PAD_ROWS=True,
375 | PAD_COLS=PAD_COLS,
376 | HEADS_PADDED=HEADS_PADDED,
377 | )
378 | index_next_start_m += BLOCK_M
379 |
380 | if index_next_start_m < index_end_m:
381 | for index_start_m in range(index_next_start_m, index_end_m, BLOCK_M):
382 | dk, dv = _attn_bwd_dkdv(
383 | index_start_m,
384 | k,
385 | v,
386 | dk,
387 | dv,
388 | M,
389 | D,
390 | offs_m,
391 | offs_n,
392 | offs_d,
393 | q_ptrs,
394 | bias_ptrs,
395 | dropout_offs,
396 | do_ptrs,
397 | softmax_scale,
398 | stride_qm,
399 | stride_bm,
400 | stride_dom,
401 | actual_seqlen_q,
402 | actual_seqlen_k,
403 | fully_masked_lines,
404 | headdim,
405 | MASKED=False,
406 | IS_CAUSAL=IS_CAUSAL,
407 | BIAS_ON=BIAS_ON,
408 | USE_DROPOUT=USE_DROPOUT,
409 | PAD_ROWS=True,
410 | PAD_COLS=PAD_COLS,
411 | HEADS_PADDED=HEADS_PADDED,
412 | )
413 |
414 | if HEADS_PADDED:
415 | if PAD_COLS:
416 | tl.store(
417 | dk_ptrs,
418 | dk,
419 | mask=(offs_n[:, None] < actual_seqlen_k) & (offs_d[None, :] < headdim),
420 | )
421 | tl.store(
422 | dv_ptrs,
423 | dv,
424 | mask=(offs_n[:, None] < actual_seqlen_k) & (offs_d[None, :] < headdim),
425 | )
426 | else:
427 | tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
428 | tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
429 | else:
430 | if PAD_COLS:
431 | tl.store(dk_ptrs, dk, mask=offs_n[:, None] < actual_seqlen_k)
432 | tl.store(dv_ptrs, dv, mask=offs_n[:, None] < actual_seqlen_k)
433 | else:
434 | tl.store(dk_ptrs, dk)
435 | tl.store(dv_ptrs, dv)
436 |
437 |
438 | @triton.jit
439 | def _attn_bwd_dq(
440 | index_start_n,
441 | q,
442 | dq,
443 | do,
444 | me_i,
445 | de_i,
446 | offs_m,
447 | offs_n,
448 | offs_d,
449 | k_ptrs,
450 | v_ptrs,
451 | bias_ptrs,
452 | dropout_offs,
453 | softmax_scale,
454 | dropout_prob,
455 | dropout_seed,
456 | stride_kn,
457 | stride_vn,
458 | actual_seqlen_q,
459 | actual_seqlen_k,
460 | headdim,
461 | MASKED: tl.constexpr,
462 | IS_CAUSAL: tl.constexpr,
463 | BIAS_ON: tl.constexpr,
464 | USE_DROPOUT: tl.constexpr,
465 | PAD_COLS: tl.constexpr,
466 | HEADS_PADDED: tl.constexpr,
467 | ):
468 | k_ptrs = k_ptrs + index_start_n * stride_kn
469 | v_ptrs = v_ptrs + index_start_n * stride_vn
470 | offs_n_curr = index_start_n + offs_n
471 | if BIAS_ON:
472 | bias_ptrs += index_start_n
473 | if USE_DROPOUT:
474 | dropout_offs += index_start_n
475 | k = padded_load(
476 | k_ptrs, offs_n_curr, offs_d, PAD_COLS, HEADS_PADDED, actual_seqlen_k, headdim
477 | )
478 | v = padded_load(
479 | v_ptrs, offs_n_curr, offs_d, PAD_COLS, HEADS_PADDED, actual_seqlen_k, headdim
480 | )
481 | if BIAS_ON:
482 | bias = padded_load(
483 | bias_ptrs,
484 | offs_m,
485 | offs_n_curr,
486 | True,
487 | PAD_COLS,
488 | actual_seqlen_q,
489 | actual_seqlen_k,
490 | )
491 | qk = tl.dot(q, tl.trans(k))
492 | if BIAS_ON:
493 | qk += bias / softmax_scale
494 | offs_n_causal = offs_n_curr - actual_seqlen_k + actual_seqlen_q
495 | if MASKED:
496 | if PAD_COLS:
497 | if IS_CAUSAL:
498 | qk = tl.where(
499 | tl.minimum(actual_seqlen_q - 1, offs_m)[:, None] >= offs_n_causal[None, :],
500 | qk,
501 | float("-inf"),
502 | )
503 | else:
504 | qk = tl.where(actual_seqlen_q - 1 >= offs_n_causal[None, :], qk, float("-inf"))
505 | elif IS_CAUSAL:
506 | qk = tl.where(offs_m[:, None] >= offs_n_causal[None, :], qk, float("-inf"))
507 | tl.debug_barrier()
508 |
509 | p = tl.exp2(qk * (softmax_scale * 1.44269504089) - me_i[:, None])
510 | dp = tl.dot(do, tl.trans(v))
511 |
512 | ds = (p * (dp - de_i[:, None]) * softmax_scale).to(q.dtype)
513 |
514 | dq += tl.dot(ds, k)
515 |
516 | return dq
517 |
518 |
519 | @triton.jit
520 | def _attn_bwd_block_dq(
521 | index_start_m,
522 | Q,
523 | K,
524 | V,
525 | B,
526 | Dropout,
527 | Do,
528 | Dq,
529 | M,
530 | D,
531 | softmax_scale,
532 | dropout_prob,
533 | dropout_seed,
534 | stride_qm,
535 | stride_kn,
536 | stride_vn,
537 | stride_bm,
538 | stride_dom,
539 | stride_dqm,
540 | actual_seqlen_q,
541 | actual_seqlen_k,
542 | headdim,
543 | VARLEN: tl.constexpr,
544 | IS_CAUSAL: tl.constexpr,
545 | BIAS_ON: tl.constexpr,
546 | USE_DROPOUT: tl.constexpr,
547 | PAD_ROWS: tl.constexpr,
548 | HEADS_PADDED: tl.constexpr,
549 | BLOCK_M: tl.constexpr,
550 | BLOCK_N: tl.constexpr,
551 | BLOCK_HEADDIM: tl.constexpr,
552 | EVEN_N: tl.constexpr,
553 | ):
554 | if IS_CAUSAL:
555 | index_end_n = min(
556 | actual_seqlen_k - actual_seqlen_q + index_start_m + BLOCK_M,
557 | actual_seqlen_k,
558 | )
559 |
560 | if index_end_n < 0:
561 | return
562 | else:
563 | index_end_n = actual_seqlen_k
564 |
565 | fully_masked_lines = actual_seqlen_q - actual_seqlen_k if IS_CAUSAL else 0
566 | mask_reached = fully_masked_lines >= index_start_m + BLOCK_M
567 | if (index_start_m >= actual_seqlen_q) or mask_reached:
568 | return
569 |
570 | offs_m = tl.arange(0, BLOCK_M) + index_start_m
571 | offs_n = tl.arange(0, BLOCK_N)
572 | offs_d = tl.arange(0, BLOCK_HEADDIM)
573 |
574 | q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_d[None, :])
575 | k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :])
576 | v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
577 |
578 | dq_ptrs = Dq + (offs_m[:, None] * stride_dqm + offs_d[None, :])
579 | do_ptrs = Do + (offs_m[:, None] * stride_dom + offs_d[None, :])
580 |
581 | if BIAS_ON:
582 | bias_ptrs = B + (offs_m[:, None] * stride_bm + offs_n[None, :])
583 | else:
584 | bias_ptrs = None
585 |
586 | if USE_DROPOUT:
587 | dropout_offs = Dropout + (offs_m[:, None] * stride_bm + offs_n[None, :])
588 | else:
589 | dropout_offs = None
590 |
591 | dq = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
592 | q = padded_load(
593 | q_ptrs,
594 | offs_m,
595 | offs_d,
596 | PA0=PAD_ROWS,
597 | PA1=HEADS_PADDED,
598 | LA0=actual_seqlen_q,
599 | LA1=headdim,
600 | )
601 | do = padded_load(
602 | do_ptrs,
603 | offs_m,
604 | offs_d,
605 | PA0=PAD_ROWS,
606 | PA1=HEADS_PADDED,
607 | LA0=actual_seqlen_q,
608 | LA1=headdim,
609 | )
610 | me_i = tl.load(M + offs_m)
611 | de_i = tl.load(D + offs_m)
612 |
613 | uneven_n = actual_seqlen_k % BLOCK_N != 0
614 | attention_padding = VARLEN & uneven_n
615 | if IS_CAUSAL:
616 | first_masked_col = index_start_m + 1 + actual_seqlen_k - actual_seqlen_q
617 | elif attention_padding:
618 | first_masked_col = actual_seqlen_k
619 | else:
620 | first_masked_col = index_end_n
621 | nb_full_blocks = first_masked_col // BLOCK_N
622 |
623 | index_next_start_n = 0
624 | if nb_full_blocks > 0:
625 | for _ in range(0, nb_full_blocks):
626 | index_next_start_n = tl.multiple_of(index_next_start_n, BLOCK_N)
627 | dq = _attn_bwd_dq(
628 | index_next_start_n,
629 | q,
630 | dq,
631 | do,
632 | me_i,
633 | de_i,
634 | offs_m,
635 | offs_n,
636 | offs_d,
637 | k_ptrs,
638 | v_ptrs,
639 | bias_ptrs,
640 | dropout_offs,
641 | softmax_scale,
642 | dropout_prob,
643 | dropout_seed,
644 | stride_kn,
645 | stride_vn,
646 | actual_seqlen_q,
647 | actual_seqlen_k,
648 | headdim,
649 | IS_CAUSAL=IS_CAUSAL,
650 | BIAS_ON=BIAS_ON,
651 | USE_DROPOUT=USE_DROPOUT,
652 | MASKED=False,
653 | PAD_COLS=False,
654 | HEADS_PADDED=HEADS_PADDED,
655 | )
656 | index_next_start_n += BLOCK_N
657 |
658 | if index_next_start_n < index_end_n:
659 | for index_start_n in range(index_next_start_n, index_end_n, BLOCK_N):
660 | pad_cols = (not EVEN_N) or (
661 | VARLEN and (index_start_n + BLOCK_N > actual_seqlen_k)
662 | )
663 | dq = _attn_bwd_dq(
664 | index_start_n,
665 | q,
666 | dq,
667 | do,
668 | me_i,
669 | de_i,
670 | offs_m,
671 | offs_n,
672 | offs_d,
673 | k_ptrs,
674 | v_ptrs,
675 | bias_ptrs,
676 | dropout_offs,
677 | softmax_scale,
678 | dropout_prob,
679 | dropout_seed,
680 | stride_kn,
681 | stride_vn,
682 | actual_seqlen_q,
683 | actual_seqlen_k,
684 | headdim,
685 | IS_CAUSAL=IS_CAUSAL,
686 | BIAS_ON=BIAS_ON,
687 | USE_DROPOUT=USE_DROPOUT,
688 | MASKED=True,
689 | PAD_COLS=pad_cols,
690 | HEADS_PADDED=HEADS_PADDED,
691 | )
692 |
693 | if fully_masked_lines > 0:
694 | dq = tl.where(offs_m[:, None] < fully_masked_lines, 0, dq)
695 |
696 | if HEADS_PADDED:
697 | if PAD_ROWS:
698 | tl.store(
699 | dq_ptrs,
700 | dq,
701 | mask=(offs_m[:, None] < actual_seqlen_q) & (offs_d[None, :] < headdim),
702 | )
703 | else:
704 | tl.store(dq_ptrs, dq, mask=offs_d[None, :] < headdim)
705 | else:
706 | if PAD_ROWS:
707 | tl.store(dq_ptrs, dq, mask=offs_m[:, None] < actual_seqlen_q)
708 | else:
709 | tl.store(dq_ptrs, dq)
710 |
711 |
712 | @safe_autotune(
713 | configs=[
714 | Config(
715 | {"BLOCK_M1": 16, "BLOCK_N1": 16, "BLOCK_M2": 16, "BLOCK_N2": 16},
716 | num_warps=4,
717 | num_stages=0,
718 | ),
719 | Config(
720 | {"BLOCK_M1": 32, "BLOCK_N1": 16, "BLOCK_M2": 16, "BLOCK_N2": 32},
721 | num_warps=4,
722 | num_stages=0,
723 | ),
724 | Config(
725 | {"BLOCK_M1": 32, "BLOCK_N1": 64, "BLOCK_M2": 64, "BLOCK_N2": 32},
726 | num_warps=4,
727 | num_stages=0,
728 | ),
729 | Config(
730 | {"BLOCK_M1": 64, "BLOCK_N1": 64, "BLOCK_M2": 64, "BLOCK_N2": 64},
731 | num_warps=4,
732 | num_stages=0,
733 | ),
734 | Config(
735 | {"BLOCK_M1": 64, "BLOCK_N1": 128, "BLOCK_M2": 128, "BLOCK_N2": 64},
736 | num_warps=4,
737 | num_stages=0,
738 | ),
739 | ],
740 | key=[
741 | "CQSeq",
742 | "CKSeq",
743 | "DRuntime",
744 | "VARLEN",
745 | "USE_DROPOUT",
746 | "IS_CAUSAL",
747 | "BIAS_ON",
748 | "BLOCK_HEADDIM",
749 | ],
750 | prune_configs_by={"early_config_prune": config_prune_kernel},
751 | )
752 | @triton.heuristics(
753 | {
754 | "EVEN_M1": lambda args: args["QSeq"] % args["BLOCK_M1"] == 0,
755 | "EVEN_N1": lambda args: args["KSeq"] % args["BLOCK_N1"] == 0,
756 | "EVEN_M2": lambda args: args["QSeq"] % args["BLOCK_M2"] == 0,
757 | "EVEN_N2": lambda args: args["KSeq"] % args["BLOCK_N2"] == 0,
758 | "HEADS_PADDED": lambda args: args["headdim"] != args["BLOCK_HEADDIM"],
759 | "NUM_BLOCKS_KV": lambda args: math.ceil(args["KSeq"] / args["BLOCK_N1"]),
760 | }
761 | )
762 | @triton.jit
763 | def _attn_bwd(
764 | Q,
765 | K,
766 | V,
767 | B,
768 | Do,
769 | M,
770 | D,
771 | softmax_scale,
772 | dropout_prob,
773 | dropout_seed,
774 | stride_qz,
775 | stride_qm,
776 | stride_qh,
777 | stride_kz,
778 | stride_kn,
779 | stride_kh,
780 | stride_vz,
781 | stride_vn,
782 | stride_vh,
783 | stride_bz,
784 | stride_bm,
785 | stride_bh,
786 | stride_doz,
787 | stride_dom,
788 | stride_doh,
789 | stride_dqz,
790 | stride_dqm,
791 | stride_dqh,
792 | stride_dkz,
793 | stride_dkn,
794 | stride_dkh,
795 | stride_dvz,
796 | stride_dvn,
797 | stride_dvh,
798 | nheads_q,
799 | num_repeats,
800 | QSeq,
801 | cum_seqlens_q,
802 | KSeq,
803 | cum_seqlens_k,
804 | seqlen_q_rounded,
805 | headdim,
806 | CQSeq,
807 | CKSeq,
808 | DRuntime,
809 | Dq,
810 | Dk,
811 | Dv,
812 | VARLEN: tl.constexpr,
813 | IS_CAUSAL: tl.constexpr,
814 | BIAS_ON: tl.constexpr,
815 | USE_DROPOUT: tl.constexpr,
816 | BLOCK_HEADDIM: tl.constexpr,
817 | # Heuristics
818 | EVEN_M1: tl.constexpr,
819 | EVEN_N1: tl.constexpr,
820 | EVEN_M2: tl.constexpr,
821 | EVEN_N2: tl.constexpr,
822 | NUM_BLOCKS_KV: tl.constexpr,
823 | HEADS_PADDED: tl.constexpr,
824 | # AutoTune
825 | BLOCK_M1: tl.constexpr,
826 | BLOCK_N1: tl.constexpr,
827 | BLOCK_M2: tl.constexpr,
828 | BLOCK_N2: tl.constexpr,
829 | ):
830 | pid = tl.program_id(0)
831 | off_zh = tl.program_id(1)
832 | off_z = off_zh // nheads_q
833 | off_head_q = off_zh % nheads_q
834 | off_head_kv = off_head_q // num_repeats
835 |
836 | if VARLEN:
837 | cu_seq_start_q = tl.load(cum_seqlens_q + off_z)
838 | cu_seq_start_k = tl.load(cum_seqlens_k + off_z)
839 | actual_seqlen_q = tl.load(cum_seqlens_q + off_z + 1) - cu_seq_start_q
840 | actual_seqlen_k = tl.load(cum_seqlens_k + off_z + 1) - cu_seq_start_k
841 | off_z = 0
842 | else:
843 | cu_seq_start_q = 0
844 | cu_seq_start_k = 0
845 | actual_seqlen_q = QSeq
846 | actual_seqlen_k = KSeq
847 |
848 | Q += off_z * stride_qz + off_head_q * stride_qh + cu_seq_start_q * stride_qm
849 | K += off_z * stride_kz + off_head_kv * stride_kh + cu_seq_start_k * stride_kn
850 | V += off_z * stride_vz + off_head_kv * stride_vh + cu_seq_start_k * stride_vn
851 |
852 | Do += off_z * stride_doz + off_head_q * stride_doh + cu_seq_start_q * stride_dom
853 | Dq += off_z * stride_dqz + off_head_q * stride_dqh + cu_seq_start_q * stride_dqm
854 | Dk += off_z * stride_dkz + off_head_q * stride_dkh + cu_seq_start_k * stride_dkn
855 | Dv += off_z * stride_dvz + off_head_q * stride_dvh + cu_seq_start_k * stride_dvn
856 |
857 | if BIAS_ON:
858 | B += off_z * stride_bz + off_head_q * stride_bh + cu_seq_start_q * stride_bm
859 | if USE_DROPOUT:
860 | Dropout = actual_seqlen_k * (
861 | cu_seq_start_q + actual_seqlen_q * (off_head_q + nheads_q * off_z)
862 | )
863 | else:
864 | Dropout = None
865 |
866 | D += off_zh * seqlen_q_rounded
867 | M += off_zh * seqlen_q_rounded
868 |
869 | if pid < NUM_BLOCKS_KV:
870 | i_start_n = pid
871 | pad_cols = (not EVEN_N1) or (
872 | VARLEN and ((i_start_n + 1) * BLOCK_N1 > actual_seqlen_k)
873 | )
874 | _attn_bwd_block_dkdv(
875 | i_start_n * BLOCK_N1,
876 | Q,
877 | K,
878 | V,
879 | B,
880 | Dropout,
881 | Do,
882 | Dk,
883 | Dv,
884 | M,
885 | D,
886 | softmax_scale,
887 | stride_qm,
888 | stride_kn,
889 | stride_vn,
890 | stride_bm,
891 | stride_dom,
892 | stride_dkn,
893 | stride_dvn,
894 | actual_seqlen_q,
895 | actual_seqlen_k,
896 | headdim,
897 | IS_CAUSAL=IS_CAUSAL,
898 | BIAS_ON=BIAS_ON,
899 | USE_DROPOUT=USE_DROPOUT,
900 | PAD_COLS=pad_cols,
901 | HEADS_PADDED=HEADS_PADDED,
902 | BLOCK_M=BLOCK_M1,
903 | BLOCK_N=BLOCK_N1,
904 | BLOCK_HEADDIM=BLOCK_HEADDIM,
905 | )
906 |
907 | else:
908 | i_start_m = pid - NUM_BLOCKS_KV
909 | pad_rows = (not EVEN_M2) or (
910 | VARLEN and ((i_start_m + 1) * BLOCK_M2 > actual_seqlen_q)
911 | )
912 | _attn_bwd_block_dq(
913 | i_start_m * BLOCK_M2,
914 | Q,
915 | K,
916 | V,
917 | B,
918 | Dropout,
919 | Do,
920 | Dq,
921 | M,
922 | D,
923 | softmax_scale,
924 | dropout_prob,
925 | dropout_seed,
926 | stride_qm,
927 | stride_kn,
928 | stride_vn,
929 | stride_bm,
930 | stride_dom,
931 | stride_dqm,
932 | actual_seqlen_q,
933 | actual_seqlen_k,
934 | headdim,
935 | VARLEN=VARLEN,
936 | IS_CAUSAL=IS_CAUSAL,
937 | BIAS_ON=BIAS_ON,
938 | USE_DROPOUT=USE_DROPOUT,
939 | PAD_ROWS=pad_rows,
940 | HEADS_PADDED=HEADS_PADDED,
941 | BLOCK_M=BLOCK_M2,
942 | BLOCK_N=BLOCK_N2,
943 | BLOCK_HEADDIM=BLOCK_HEADDIM,
944 | EVEN_N=EVEN_N2,
945 | )
946 |
947 |
948 | def _bwd_attention_kernel_call(
949 | dO: chex.Array,
950 | q: chex.Array,
951 | k: chex.Array,
952 | v: chex.Array,
953 | bias: tp.Optional[chex.Array],
954 | attention_mask: tp.Optional[chex.Array],
955 | o: chex.Array,
956 | M: chex.Array,
957 | dropout_prob: float,
958 | causal: bool,
959 | softmax_scale: tp.Optional[float],
960 | dropout_seed: tp.Optional[int],
961 | varlen_mode: bool,
962 | ):
963 | """Calls the Triton kernel for the backward pass of the attention mechanism.
964 |
965 | Args:
966 | softmax_scale: Scaling factor for the softmax function.
967 | residual: Residual from the forward pass.
968 | Do: Output gradient array.
969 |
970 | Returns:
971 | Tuple of the gradients of the query, key, value, and bias arrays.
972 | """
973 | if attention_mask is not None and varlen_mode:
974 | assert bias is None, (
975 | "Attention mask is not supported along with attention bias. Just use bias instead."
976 | )
977 | assert q.shape[1] == k.shape[1], "Attention mask is not supported with QSeq != KSeq"
978 | varlen_mode = attention_mask.shape[0] > 1
979 | useless_padding = attention_mask.shape[1] - attention_mask.sum(-1).max().item()
980 | if useless_padding > 0:
981 | dO = dO[:, :-useless_padding]
982 | q = q[:, :-useless_padding]
983 | k = k[:, :-useless_padding]
984 | v = v[:, :-useless_padding]
985 | attention_mask = attention_mask[:, :-useless_padding]
986 | o = o[:, :-useless_padding]
987 | else:
988 | varlen_mode = False
989 | useless_padding = 0
990 |
991 | batch_size, QSeq, nheads_q, head_dim = q.shape
992 | _, KSeq, nheads_kv, _ = k.shape
993 | max_seqlen_q_rounded = math.ceil(QSeq / 128) * 128
994 | softmax_scale = 1.0 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
995 |     assert nheads_q % nheads_kv == 0, f"{nheads_q = } is not divisible by {nheads_kv = }"
996 | assert M.shape == (batch_size, nheads_q, max_seqlen_q_rounded)
997 |
998 | if varlen_mode:
999 | cum_seqlens_q = jnp.zeros(shape=(attention_mask.shape[0] + 1,), dtype="i4")
1000 | cum_seqlens_k = jnp.zeros(shape=(attention_mask.shape[0] + 1,), dtype="i4")
1001 | cum_seqlens_k = cum_seqlens_k.at[1:].set(
1002 | jnp.cumsum(
1003 | attention_mask.sum(axis=1, dtype="i4"),
1004 | axis=0,
1005 | dtype="i4",
1006 | )
1007 | )
1008 | cum_seqlens_q = cum_seqlens_q.at[1:].set(
1009 | jnp.cumsum(
1010 | attention_mask.sum(axis=1, dtype="i4"),
1011 | axis=0,
1012 | dtype="i4",
1013 | )
1014 | )
1015 | max_seqlen_q: int = attention_mask.shape[1]
1016 | max_seqlen_k: int = attention_mask.shape[1]
1017 |
1018 | dO = attention_pack_with_static_shape(dO, attention_mask)
1019 |
1020 | q = attention_pack_with_static_shape(q, attention_mask)
1021 | k = attention_pack_with_static_shape(k, attention_mask)
1022 | v = attention_pack_with_static_shape(v, attention_mask)
1023 | o = attention_pack_with_static_shape(o, attention_mask)
1024 | QSeq = q.shape[1]
1025 | KSeq = k.shape[1]
1026 | else:
1027 | cum_seqlens_q = None
1028 | cum_seqlens_k = None
1029 | max_seqlen_q = QSeq
1030 | max_seqlen_k = KSeq
1031 |
1032 | bz, bh, bm = calc_bias_strides(
1033 | bias,
1034 | batch_size,
1035 | nheads_q,
1036 | QSeq,
1037 | KSeq,
1038 | )
1039 |
1040 | num_repeats = nheads_q // nheads_kv
1041 | BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16)
1042 |
1043 | oz, om, oh, _ = get_strides(o)
1044 | doz, dom, doh, _ = get_strides(dO)
1045 | qz, qm, qh, _ = get_strides(q)
1046 | kz, kn, kh, _ = get_strides(k)
1047 | vz, vn, vh, _ = get_strides(v)
1048 |
1049 | (delta,) = triton_call(
1050 | o,
1051 | dO,
1052 | oz,
1053 | om,
1054 | oh,
1055 | doz,
1056 | dom,
1057 | doh,
1058 | nheads_q,
1059 | QSeq,
1060 | max_seqlen_q_rounded,
1061 | cum_seqlens_q if cum_seqlens_q is not None else jnp.array((1,), dtype=jnp.float16),
1062 | head_dim,
1063 | max_seqlen_q // 32,
1064 | dtype_index(q),
1065 | VARLEN=varlen_mode,
1066 | BLOCK_HEADDIM=BLOCK_HEADDIM,
1067 | out_shape=[
1068 | jax.ShapeDtypeStruct(
1069 | shape=M.shape,
1070 | dtype="f4",
1071 | sharding=get_sharding(M),
1072 | )
1073 | ],
1074 | grid=lambda META: (
1075 | triton.cdiv(max_seqlen_q, META["BLOCK_M"]),
1076 | batch_size * nheads_q,
1077 | ),
1078 | kernel=_attn_bwd_preprocess,
1079 | name="triton::ops::_attn_bwd_preprocess",
1080 | )
1081 |
1082 | dq, dk, dv = triton_call(
1083 | q,
1084 | k,
1085 | v,
1086 | bias if bias is not None else jnp.zeros((1,), jnp.float16),
1087 | dO,
1088 | M,
1089 | delta,
1090 | softmax_scale,
1091 | dropout_prob,
1092 | dropout_seed if dropout_seed is not None else jnp.zeros((1,), jnp.float16),
1093 | qz,
1094 | qm,
1095 | qh,
1096 | kz,
1097 | kn,
1098 | kh,
1099 | vz,
1100 | vn,
1101 | vh,
1102 | bz,
1103 | bm,
1104 | bh,
1105 | doz,
1106 | dom,
1107 | doh,
1108 | qz,
1109 | qm,
1110 | qh,
1111 | kz,
1112 | kn,
1113 | kh,
1114 | vz,
1115 | vn,
1116 | vh,
1117 | nheads_q,
1118 | num_repeats,
1119 | QSeq,
1120 | cum_seqlens_q if cum_seqlens_q is not None else jnp.zeros((1,), jnp.float16),
1121 | KSeq,
1122 | cum_seqlens_k if cum_seqlens_k is not None else jnp.zeros((1,), jnp.float16),
1123 | max_seqlen_q_rounded,
1124 | head_dim,
1125 | max_seqlen_q // 32,
1126 | max_seqlen_k // 32,
1127 | dtype_index(q),
1128 | VARLEN=varlen_mode,
1129 | IS_CAUSAL=causal,
1130 | BIAS_ON=(bias is not None),
1131 | USE_DROPOUT=(dropout_prob > 0),
1132 | BLOCK_HEADDIM=BLOCK_HEADDIM,
1133 | kernel=_attn_bwd,
1134 | grid=lambda META: (
1135 | triton.cdiv(KSeq, META["BLOCK_N1"]) + triton.cdiv(QSeq, META["BLOCK_M2"]),
1136 | batch_size * nheads_q,
1137 | ),
1138 | out_shape=[
1139 | jax.ShapeDtypeStruct(
1140 | shape=q.shape,
1141 | dtype="f4",
1142 | sharding=get_sharding(q),
1143 | ),
1144 | jax.ShapeDtypeStruct(
1145 | shape=(k.shape[0], k.shape[1], q.shape[2], k.shape[3]),
1146 | dtype=k.dtype,
1147 | ),
1148 | jax.ShapeDtypeStruct(
1149 | shape=(v.shape[0], v.shape[1], q.shape[2], v.shape[3]),
1150 | dtype=v.dtype,
1151 | ),
1152 | ],
1153 | name="triton::ops::_attn_bwd",
1154 | )
1155 |
1156 | if num_repeats > 1:
1157 | dk = dk.reshape(dk.shape[0], dk.shape[1], nheads_kv, num_repeats, -1)
1158 | dk = jnp.sum(dk, axis=3)
1159 |
1160 | dv = dv.reshape(dv.shape[0], dv.shape[1], nheads_kv, num_repeats, -1)
1161 | dv = jnp.sum(dv, axis=3)
1162 |
1163 | if varlen_mode:
1164 | dq = attention_unpack_with_static_shape(dq, cum_seqlens_q, batch_size, max_seqlen_q)
1165 | dk = attention_unpack_with_static_shape(dk, cum_seqlens_k, batch_size, max_seqlen_k)
1166 | dv = attention_unpack_with_static_shape(dv, cum_seqlens_k, batch_size, max_seqlen_k)
1167 |
1168 | if useless_padding > 0:
1169 |         dq = jnp.pad(dq, ((0, 0), (0, useless_padding), (0, 0), (0, 0)))  # restore seq-axis padding
1170 |         dk = jnp.pad(dk, ((0, 0), (0, useless_padding), (0, 0), (0, 0)))
1171 |         dv = jnp.pad(dv, ((0, 0), (0, useless_padding), (0, 0), (0, 0)))
1172 |
1173 | return dq, dk, dv
1174 |
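When `num_repeats > 1`, the kernel writes one dk/dv slice per query head, and the reshape-and-sum above folds them back onto the shared KV heads (query head h uses KV head h // num_repeats). A standalone toy check of that reduction; names and shapes are illustrative:

import jax.numpy as jnp

b, s, hkv, rep, d = 1, 4, 2, 3, 8
# Pretend per-query-head gradients, laid out so heads sharing a KV head are contiguous.
dk_per_q_head = jnp.arange(b * s * hkv * rep * d, dtype=jnp.float32).reshape(b, s, hkv * rep, d)

dk_kv = dk_per_q_head.reshape(b, s, hkv, rep, d).sum(axis=3)

# Reference: explicitly sum each group of `rep` query heads.
ref = jnp.stack(
    [dk_per_q_head[:, :, g * rep:(g + 1) * rep, :].sum(axis=2) for g in range(hkv)],
    axis=2,
)
assert dk_kv.shape == (b, s, hkv, d) and jnp.allclose(dk_kv, ref)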
--------------------------------------------------------------------------------
/jax_flash_attn2/flash_attention_triton/_flash_attention.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import functools
15 | import typing as tp
16 | import warnings
17 |
18 | import chex
19 | import jax
20 |
21 | from ._backward_triton import _bwd_attention_kernel_call
22 | from ._forward_triton import _fwd_attention_kernel_call
23 |
24 | DEV_MODE = True
25 |
26 |
27 | def _jax_fwd_attention_call(
28 | q: tp.Optional[chex.Array],
29 | k: tp.Optional[chex.Array],
30 | v: tp.Optional[chex.Array],
31 | attention_mask: tp.Optional[chex.Array] = None,
32 | bias: tp.Optional[chex.Array] = None,
33 | softmax_scale: tp.Optional[float] = None,
34 | dropout_prob: float = 0.0,
35 | causal: bool = False,
36 | dropout_seed: tp.Optional[int] = None,
37 | varlen_mode: bool = True,
38 | ):
39 | out, lse = _fwd_attention_kernel_call(
40 | q=q,
41 | k=k,
42 | v=v,
43 | attention_mask=attention_mask,
44 | bias=bias,
45 | softmax_scale=softmax_scale,
46 | dropout_prob=dropout_prob,
47 | causal=causal,
48 | dropout_seed=dropout_seed,
49 | varlen_mode=varlen_mode,
50 | )
51 | residual = (
52 | q,
53 | k,
54 | v,
55 | bias,
56 | attention_mask,
57 | out,
58 | lse,
59 | dropout_seed,
60 | )
61 | return out, residual
62 |
63 |
64 | def _jax_bwd_attention_call(
65 | softmax_scale: tp.Optional[float],
66 | dropout_prob: float,
67 | causal: bool,
68 | varlen_mode: bool,
69 | residual: tp.Tuple[chex.Array],
70 | dO: chex.Array,
71 | ):
72 | q, k, v, bias, attention_mask, out, lse, dropout_seed = residual
73 | dq, dk, dv = _bwd_attention_kernel_call(
74 | dO=dO,
75 | q=q,
76 | k=k,
77 | v=v,
78 | bias=bias,
79 | attention_mask=attention_mask,
80 | o=out,
81 | M=lse,
82 | dropout_prob=dropout_prob,
83 | causal=causal,
84 | dropout_seed=dropout_seed,
85 | softmax_scale=softmax_scale,
86 | varlen_mode=varlen_mode,
87 | )
88 | return dq, dk, dv, None, None, None
89 |
90 |
91 | @functools.partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 9))
92 | @functools.partial(jax.jit, static_argnums=(5, 6, 7, 9))
93 | def flash_attention_call(
94 | q: tp.Optional[chex.Array],
95 | k: tp.Optional[chex.Array],
96 | v: tp.Optional[chex.Array],
97 | attention_mask: tp.Optional[chex.Array] = None,
98 | bias: tp.Optional[chex.Array] = None,
99 | softmax_scale: tp.Optional[float] = None,
100 | dropout_prob: float = 0.0,
101 | causal: bool = False,
102 | dropout_seed: tp.Optional[int] = None,
103 | varlen_mode: bool = True,
104 | ) -> chex.Array:
105 | return _fwd_attention_kernel_call(
106 | q=q,
107 | k=k,
108 | v=v,
109 | attention_mask=attention_mask,
110 | bias=bias,
111 | softmax_scale=softmax_scale,
112 | dropout_prob=dropout_prob,
113 | causal=causal,
114 | dropout_seed=dropout_seed,
115 | varlen_mode=varlen_mode,
116 | )[0]
117 |
118 |
119 | flash_attention_call.defvjp(
120 | _jax_fwd_attention_call,
121 | _jax_bwd_attention_call,
122 | )
123 |
124 |
125 | def flash_attention(
126 | q: tp.Optional[chex.Array],
127 | k: tp.Optional[chex.Array],
128 | v: tp.Optional[chex.Array],
129 | attention_mask: tp.Optional[chex.Array] = None,
130 | bias: tp.Optional[chex.Array] = None,
131 | softmax_scale: tp.Optional[float] = None,
132 | dropout_prob: float = 0.0,
133 | causal: bool = False,
134 | dropout_seed: tp.Optional[int] = None,
135 | varlen_mode: bool = True,
136 | ) -> chex.Array:
137 | # TODO: Debug Varlen Mode
138 | if attention_mask is not None and not DEV_MODE and bias is None:
139 | warnings.warn(
140 | "Varlen Mode and attention mask passing is still under development",
141 | stacklevel=1,
142 | )
143 |
144 | attention_mask = None
145 | out = flash_attention_call(
146 | q=q,
147 | k=k,
148 | v=v,
149 | attention_mask=attention_mask,
150 | bias=bias,
151 | softmax_scale=softmax_scale,
152 | dropout_prob=dropout_prob,
153 | causal=causal,
154 | dropout_seed=dropout_seed,
155 | varlen_mode=varlen_mode,
156 | )
157 | return out
158 |
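The `custom_vjp` wiring above follows the usual JAX pattern: the forward rule returns `(output, residual)`, the backward rule receives the non-differentiable arguments first, then the residual and the cotangent, and returns one tangent per differentiable argument (hence the trailing `None`s for mask, bias and dropout seed). A minimal sketch of that pattern on a toy function, not part of this package:

import functools
import jax

@functools.partial(jax.custom_vjp, nondiff_argnums=(1,))
def scaled_square(x, scale):
    return scale * x * x

def scaled_square_fwd(x, scale):
    return scaled_square(x, scale), x        # residual saved for the backward pass

def scaled_square_bwd(scale, residual, g):   # non-diff args come first, like the flags above
    x = residual
    return (g * 2.0 * scale * x,)            # one tangent per differentiable argument

scaled_square.defvjp(scaled_square_fwd, scaled_square_bwd)
assert jax.grad(scaled_square)(3.0, 2.0) == 12.0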
--------------------------------------------------------------------------------
/jax_flash_attn2/flash_attention_triton/_forward_triton.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import math
16 | import typing as tp
17 |
18 | import chex
19 | import jax
20 | import jax.numpy as jnp
21 | import triton
22 | import triton.language as tl
23 | from eformer.callib import triton_call
24 | from triton import Config
25 |
26 | from ._utils import (
27 | dtype_index,
28 | get_sharding,
29 | get_strides,
30 | safe_autotune,
31 | attention_pack_with_static_shape,
32 | attention_unpack_with_static_shape,
33 | calc_bias_strides,
34 | padded_load,
35 | )
36 |
37 |
38 | def config_prune_kernel(
39 | configs: tp.List[Config],
40 | named_args: tp.Dict[str, tp.Any],
41 | **kwargs,
42 | ) -> tp.List[Config]:
43 | kept_configs = []
44 | for config in configs:
45 |         largest_m = config.kwargs["BLOCK_M"] > named_args["QSeq"]
46 |         largest_n = config.kwargs["BLOCK_N"] > named_args["KSeq"]
47 |         if largest_m or largest_n:
48 | pass
49 | else:
50 | kept_configs.append(config)
51 | if kept_configs:
52 | return kept_configs
53 | return [
54 | Config({"BLOCK_M": 16, "BLOCK_N": 64}, num_warps=4, num_stages=1),
55 | Config({"BLOCK_M": 16, "BLOCK_N": 64}, num_warps=4, num_stages=3),
56 | Config({"BLOCK_M": 16, "BLOCK_N": 64}, num_warps=4, num_stages=5),
57 | ]
58 |
59 |
60 | @triton.jit
61 | def _attn_fwd_inner(
62 | q,
63 | m_i,
64 | me_i,
65 | k_ptrs,
66 | v_ptrs,
67 | bias_ptrs,
68 | acc_o,
69 | offs_m,
70 | offs_n,
71 | offs_d,
72 | softmax_scale,
73 | dropout_prob,
74 | dropout_seed,
75 | dropout_offs,
76 | stride_kn,
77 | stride_vn,
78 | index_start_n,
79 | actual_seqlen_q,
80 | actual_seqlen_k,
81 | headdim,
82 | USE_DROPOUT: tl.constexpr,
83 | IS_CAUSAL: tl.constexpr,
84 | BIAS_ON: tl.constexpr,
85 | MASKED: tl.constexpr,
86 | PADDED_COLS: tl.constexpr,
87 | PADDED_HEADS: tl.constexpr,
88 | BLOCK_M: tl.constexpr,
89 | BLOCK_N: tl.constexpr,
90 | ):
91 |     LN2: tl.constexpr = 1.44269504089  # log2(e), used to keep the bias term in the same base-2 log domain as qk
92 | index_start_n = tl.multiple_of(index_start_n, BLOCK_N)
93 | offset_k_ptrs = k_ptrs + index_start_n * stride_kn
94 | k = padded_load(
95 | offset_k_ptrs,
96 | index_start_n + offs_n,
97 | offs_d,
98 | PA0=PADDED_COLS,
99 | PA1=PADDED_HEADS,
100 | LA0=actual_seqlen_k,
101 | LA1=headdim,
102 | )
103 | if BIAS_ON:
104 | if PADDED_COLS:
105 | bias = tl.load(
106 | bias_ptrs + index_start_n,
107 | mask=(offs_m[:, None] < actual_seqlen_q)
108 | & ((index_start_n + offs_n) < actual_seqlen_k)[None, :],
109 | other=0.0,
110 | )
111 | else:
112 | bias = tl.load(
113 | bias_ptrs + index_start_n,
114 | mask=offs_m[:, None] < actual_seqlen_q,
115 | other=0.0,
116 | )
117 |
118 | qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
119 | qk += tl.dot(q, tl.trans(k))
120 | if PADDED_COLS:
121 | qk += tl.where(
122 | (index_start_n + offs_n)[None, :] < actual_seqlen_k,
123 | 0,
124 | float("-inf"),
125 | )
126 |
127 | if MASKED and IS_CAUSAL:
128 | causal_mask = (
129 | offs_m[:, None]
130 | >= (index_start_n + offs_n - actual_seqlen_k + actual_seqlen_q)[None, :]
131 | )
132 | qk += tl.where(causal_mask, 0, float("-inf"))
133 |
134 | if BIAS_ON:
135 | qk += bias * (LN2 / softmax_scale)
136 | m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, me_i)
137 | P_ij = tl.exp2(qk * softmax_scale - m_ij[:, None])
138 | l_ij = tl.sum(P_ij, 1)
139 |
140 | if USE_DROPOUT:
141 | dropout_offs = dropout_offs + index_start_n
142 | dropout_mask = tl.rand(dropout_seed, dropout_offs) > dropout_prob
143 | P_ij = tl.where(dropout_mask, P_ij, 0.0)
144 |
145 | acc_o_scale = tl.exp2(m_i - m_ij)
146 | acc_o = acc_o * acc_o_scale[:, None]
147 |
148 | offset_v_ptrs = v_ptrs + index_start_n * stride_vn
149 | v = padded_load(
150 | offset_v_ptrs,
151 | index_start_n + offs_n,
152 | offs_d,
153 | PA0=PADDED_COLS,
154 | PA1=PADDED_HEADS,
155 | LA0=actual_seqlen_k,
156 | LA1=headdim,
157 | )
158 |
159 | P_ij = P_ij.to(v.dtype)
160 | acc_o += tl.dot(P_ij, v)
161 | m_i = m_ij
162 | l_i_new = tl.exp2(me_i - m_ij) + l_ij
163 | me_i = m_ij + tl.log2(l_i_new)
164 | return m_i, me_i, acc_o
165 |
166 |
167 | @safe_autotune(
168 | configs=[
169 | triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=4, num_stages=5),
170 | triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=5),
171 | triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=5),
172 | triton.Config({"BLOCK_M": 256, "BLOCK_N": 256}, num_warps=4, num_stages=5),
173 | triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=4, num_stages=3),
174 | triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=3),
175 | triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=3),
176 | triton.Config({"BLOCK_M": 256, "BLOCK_N": 256}, num_warps=4, num_stages=3),
177 | triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=4, num_stages=1),
178 | triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
179 | triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=1),
180 | triton.Config({"BLOCK_M": 256, "BLOCK_N": 256}, num_warps=4, num_stages=1),
181 | ###
182 | triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=8, num_stages=5),
183 | triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=8, num_stages=5),
184 | triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=8, num_stages=5),
185 | triton.Config({"BLOCK_M": 256, "BLOCK_N": 256}, num_warps=8, num_stages=5),
186 | triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=8, num_stages=3),
187 | triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=8, num_stages=3),
188 | triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=8, num_stages=3),
189 | triton.Config({"BLOCK_M": 256, "BLOCK_N": 256}, num_warps=8, num_stages=3),
190 | triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=8, num_stages=1),
191 | triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=8, num_stages=1),
192 | triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=8, num_stages=1),
193 | triton.Config({"BLOCK_M": 256, "BLOCK_N": 256}, num_warps=8, num_stages=1),
194 | ],
195 | key=[
196 | "CKSeq",
197 | "CQSeq",
198 | "DRuntime",
199 | "VARLEN",
200 | "USE_DROPOUT",
201 | "IS_CAUSAL",
202 | "BIAS_ON",
203 | "BLOCK_HEADDIM",
204 | ],
205 | prune_configs_by={"early_config_prune": config_prune_kernel},
206 | )
207 | @triton.heuristics(
208 | {
209 | "EVEN_M": lambda args: args["QSeq"] % args["BLOCK_M"] == 0,
210 | "EVEN_N": lambda args: args["KSeq"] % args["BLOCK_N"] == 0,
211 | }
212 | )
213 | @triton.jit
214 | def _attn_fwd(
215 | q,
216 | k,
217 | v,
218 | B,
219 | softmax_scale,
220 | dropout_prob,
221 | dropout_seed,
222 | stride_qz,
223 | stride_qm,
224 | stride_qh,
225 | stride_kz,
226 | stride_kn,
227 | stride_kh,
228 | stride_vz,
229 | stride_vn,
230 | stride_vh,
231 | stride_oz,
232 | stride_om,
233 | stride_oh,
234 | stride_bz,
235 | stride_bm,
236 | stride_bh,
237 | nheads_q,
238 | num_repeats,
239 | QSeq,
240 | cum_seqlens_q,
241 | KSeq,
242 | max_seqlen_q_rounded,
243 | headdim,
244 | CQSeq,
245 | CKSeq,
246 | DRuntime,
247 | Po,
248 | M,
249 | VARLEN: tl.constexpr,
250 | USE_DROPOUT: tl.constexpr,
251 | IS_CAUSAL: tl.constexpr,
252 | BIAS_ON: tl.constexpr,
253 | BLOCK_HEADDIM: tl.constexpr,
254 | PADDED_HEADS: tl.constexpr,
255 | EVEN_M: tl.constexpr,
256 | EVEN_N: tl.constexpr,
257 | BLOCK_M: tl.constexpr,
258 | BLOCK_N: tl.constexpr,
259 | ):
260 | LN2: tl.constexpr = 1.44269504089
261 | i_start_m = tl.program_id(0)
262 | off_zh = tl.program_id(1)
263 | off_head_q = off_zh % nheads_q
264 | off_head_kv = off_head_q // num_repeats
265 | off_z = off_zh // nheads_q
266 |
267 | if VARLEN:
268 | cu_seq_start_q = tl.load(cum_seqlens_q + off_z)
269 | actual_seqlen_q = tl.load(cum_seqlens_q + off_z + 1) - cu_seq_start_q
270 | if i_start_m * BLOCK_M >= actual_seqlen_q:
271 | return
272 | actual_seqlen_k = actual_seqlen_q
273 | cu_seq_start_k = cu_seq_start_q
274 | off_z = 0
275 | else:
276 | actual_seqlen_q = QSeq
277 | actual_seqlen_k = KSeq
278 | cu_seq_start_q = 0
279 | cu_seq_start_k = 0
280 |
281 | softmax_scale = softmax_scale * LN2
282 |
283 | offs_m = i_start_m * BLOCK_M + tl.arange(0, BLOCK_M)
284 | offs_n = tl.arange(0, BLOCK_N)
285 | offs_d = tl.arange(0, BLOCK_HEADDIM)
286 |
287 | fully_masked_lines = actual_seqlen_q - actual_seqlen_k if IS_CAUSAL else 0
288 | if fully_masked_lines >= (i_start_m + 1) * BLOCK_M:
289 | return
290 |
291 | q_ptrs = (
292 | q
293 | + off_z * stride_qz
294 | + off_head_q * stride_qh
295 | + cu_seq_start_q * stride_qm
296 | + (offs_m[:, None] * stride_qm + offs_d[None, :])
297 | )
298 |
299 | k_ptrs = (
300 | k
301 | + off_z * stride_kz
302 | + off_head_kv * stride_kh
303 | + cu_seq_start_k * stride_kn
304 | + (offs_n[:, None] * stride_kn + offs_d[None, :])
305 | )
306 |
307 | v_ptrs = (
308 | v
309 | + off_z * stride_vz
310 | + off_head_kv * stride_vh
311 | + cu_seq_start_k * stride_vn
312 | + (offs_n[:, None] * stride_vn + offs_d[None, :])
313 | )
314 |
315 | if BIAS_ON:
316 | bias_ptrs = (
317 | B
318 | + off_z * stride_bz
319 | + off_head_kv * stride_bh
320 | + cu_seq_start_q * stride_bm
321 | + (offs_m[:, None] * stride_bm + offs_n[None, :])
322 | )
323 | else:
324 | bias_ptrs = None
325 | if USE_DROPOUT:
326 | dropout_off = actual_seqlen_k * (
327 | cu_seq_start_q + actual_seqlen_q * (off_head_q + nheads_q * off_z)
328 | )
329 | dropout_offs = dropout_off + offs_m[:, None] * actual_seqlen_k + offs_n[None, :]
330 | else:
331 | dropout_offs = None
332 |
333 | me_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
334 | m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
335 | acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
336 |
337 | pad_rows = (not EVEN_M) or (VARLEN and (i_start_m * BLOCK_M > actual_seqlen_q))
338 | q = padded_load(
339 | q_ptrs,
340 | offs_m,
341 | offs_d,
342 | PA0=pad_rows,
343 | PA1=PADDED_HEADS,
344 | LA0=actual_seqlen_q,
345 | LA1=headdim,
346 | )
347 | if IS_CAUSAL:
348 | end_n = min(
349 | actual_seqlen_k - actual_seqlen_q + (i_start_m + 1) * BLOCK_M,
350 | actual_seqlen_k,
351 | )
352 | if end_n < 0:
353 | return
354 | else:
355 | end_n = actual_seqlen_k
356 |
357 | uneven_n = actual_seqlen_k % BLOCK_N != 0
358 | attention_padding = VARLEN & uneven_n
359 | if IS_CAUSAL:
360 | first_masked_col = i_start_m * BLOCK_M + 1 + actual_seqlen_k - actual_seqlen_q
361 | elif attention_padding:
362 | first_masked_col = actual_seqlen_k
363 | else:
364 | first_masked_col = end_n
365 | nb_full_blocks = first_masked_col // BLOCK_N
366 |
367 | next_start_n = 0
368 | if nb_full_blocks > 0:
369 | for _ in range(0, nb_full_blocks):
370 | m_i, me_i, acc_o = _attn_fwd_inner(
371 | q,
372 | m_i,
373 | me_i,
374 | k_ptrs,
375 | v_ptrs,
376 | bias_ptrs,
377 | acc_o,
378 | offs_m,
379 | offs_n,
380 | offs_d,
381 | softmax_scale,
382 | dropout_prob,
383 | dropout_seed,
384 | dropout_offs,
385 | stride_kn,
386 | stride_vn,
387 | next_start_n,
388 | actual_seqlen_q,
389 | actual_seqlen_k,
390 | headdim,
391 | USE_DROPOUT=USE_DROPOUT,
392 | IS_CAUSAL=IS_CAUSAL,
393 | BIAS_ON=BIAS_ON,
394 | MASKED=False,
395 | PADDED_COLS=False,
396 | PADDED_HEADS=PADDED_HEADS,
397 | BLOCK_M=BLOCK_M,
398 | BLOCK_N=BLOCK_N,
399 | )
400 | next_start_n += BLOCK_N
401 | if next_start_n < end_n:
402 | for index_start_n in range(next_start_n, end_n, BLOCK_N):
403 | pad_cols = (not EVEN_N) or VARLEN
404 | m_i, me_i, acc_o = _attn_fwd_inner(
405 | q,
406 | m_i,
407 | me_i,
408 | k_ptrs,
409 | v_ptrs,
410 | bias_ptrs,
411 | acc_o,
412 | offs_m,
413 | offs_n,
414 | offs_d,
415 | softmax_scale,
416 | dropout_prob,
417 | dropout_seed,
418 | dropout_offs,
419 | stride_kn,
420 | stride_vn,
421 | index_start_n,
422 | actual_seqlen_q,
423 | actual_seqlen_k,
424 | headdim,
425 | USE_DROPOUT=USE_DROPOUT,
426 | IS_CAUSAL=IS_CAUSAL,
427 | BIAS_ON=BIAS_ON,
428 | MASKED=True,
429 | PADDED_COLS=pad_cols,
430 | PADDED_HEADS=PADDED_HEADS,
431 | BLOCK_M=BLOCK_M,
432 | BLOCK_N=BLOCK_N,
433 | )
434 |
435 | if USE_DROPOUT:
436 | o_scale = tl.exp2((m_i - me_i) - tl.log2(1 - dropout_prob))
437 | else:
438 | o_scale = tl.exp2(m_i - me_i)
439 | acc_o = acc_o * o_scale[:, None]
440 | if fully_masked_lines > i_start_m * BLOCK_M:
441 | acc_o = tl.where(offs_m[:, None] < fully_masked_lines, 0, acc_o)
442 | i_start_m = tl.program_id(0)
443 | offs_m = i_start_m * BLOCK_M + tl.arange(0, BLOCK_M)
444 | lse_ptrs = M + off_zh * max_seqlen_q_rounded + offs_m
445 | tl.store(lse_ptrs, me_i)
446 | offs_d = tl.arange(0, BLOCK_HEADDIM)
447 | out_ptrs = (
448 | Po
449 | + off_z * stride_oz
450 | + off_head_q * stride_oh
451 | + cu_seq_start_q * stride_om
452 | + (offs_m[:, None] * stride_om + offs_d[None, :])
453 | )
454 |
455 | tl.store(
456 | out_ptrs,
457 | acc_o,
458 | mask=(offs_m[:, None] < actual_seqlen_q) & (offs_d[None, :] < headdim),
459 | )
460 |
461 |
462 | def _fwd_attention_kernel_call(
463 | q: tp.Optional[chex.Array],
464 | k: tp.Optional[chex.Array],
465 | v: tp.Optional[chex.Array],
466 | attention_mask: tp.Optional[chex.Array] = None,
467 | bias: tp.Optional[chex.Array] = None,
468 | softmax_scale: tp.Optional[float] = None,
469 | dropout_prob: float = 0.0,
470 | causal: bool = False,
471 | dropout_seed: tp.Optional[int] = None,
472 | varlen_mode: bool = True,
473 | ):
474 | if attention_mask is not None and varlen_mode:
475 | varlen_mode = attention_mask.shape[0] > 1
476 | assert bias is None, (
477 | "Attention mask is not supported along with attention bias. Just use bias instead."
478 | )
479 | assert q.shape[1] == k.shape[1], "Attention mask is not supported with QSeq != KSeq"
480 | else:
481 | varlen_mode = False
482 | batch, QSeq, nheads_q, head_dim = q.shape
483 | _, KSeq, nheads_kv, _ = k.shape
484 | expected_kv_shape = (batch, KSeq, nheads_kv, head_dim)
485 |
486 | assert k.shape == expected_kv_shape, (
487 | f"key shape is {k.shape = } and we excepted it to be like {expected_kv_shape = }"
488 | )
489 | assert v.shape == expected_kv_shape, (
490 | f"value shape is {v.shape = } and we excepted it to be like {expected_kv_shape = }"
491 | )
492 |
493 |     assert nheads_q % nheads_kv == 0, f"{nheads_q = } is not divisible by {nheads_kv = }"
494 | assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
495 | assert q.dtype in [jnp.float16, jnp.bfloat16], "Only support fp16 and bf16"
496 |
497 | softmax_scale = 1.0 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
498 |
499 | varlen_mode = varlen_mode and (batch > 1)
500 | if varlen_mode:
501 | cum_seqlens_q = jnp.zeros(shape=(attention_mask.shape[0] + 1,), dtype=jnp.int32)
502 | cum_seqlens_q = cum_seqlens_q.at[1:].set(
503 | jnp.cumsum(attention_mask.sum(axis=1, dtype="i4"), axis=0, dtype="i4")
504 | )
505 | max_seqlen_q = attention_mask.shape[1]
506 | max_seqlen_k = attention_mask.shape[1]
507 | q = attention_pack_with_static_shape(q, attention_mask)
508 | k = attention_pack_with_static_shape(k, attention_mask)
509 | v = attention_pack_with_static_shape(v, attention_mask)
510 | QSeq = q.shape[1]
511 | else:
512 | cum_seqlens_q = None
513 | max_seqlen_q = QSeq
514 | max_seqlen_k = KSeq
515 |
516 | bz, bh, bm = calc_bias_strides(
517 | bias,
518 | batch,
519 | nheads_q,
520 | QSeq,
521 | KSeq,
522 | )
523 |
524 | max_seqlen_q_rounded = math.ceil(max_seqlen_q / 128) * 128
525 | BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16)
526 | PADDED_HEADS = BLOCK_HEADDIM > head_dim
527 | num_repeats = nheads_q // nheads_kv
528 |
529 | qz, qm, qh, _ = get_strides(q.shape)
530 | oz, om, oh, _ = get_strides(q.shape)
531 | kz, kn, kh, _ = get_strides(k.shape)
532 | vz, vn, vh, _ = get_strides(v.shape)
533 |
534 | metaparams = dict(
535 | VARLEN=varlen_mode,
536 | USE_DROPOUT=(dropout_prob > 0),
537 | IS_CAUSAL=causal,
538 | BIAS_ON=(bias is not None),
539 | BLOCK_HEADDIM=BLOCK_HEADDIM,
540 | PADDED_HEADS=PADDED_HEADS,
541 | )
542 | out_shape = [
543 | jax.ShapeDtypeStruct(q.shape, q.dtype, sharding=get_sharding(q)),
544 | jax.ShapeDtypeStruct((batch, nheads_q, max_seqlen_q_rounded), jnp.float32),
545 | ]
546 |
547 | out, lse = triton_call(
548 | q,
549 | k,
550 | v,
551 | bias if bias is not None else jnp.zeros((1,), jnp.float16),
552 | softmax_scale,
553 | dropout_prob,
554 | dropout_seed if dropout_seed is not None else jnp.zeros((1,), jnp.float16),
555 | qz,
556 | qm,
557 | qh,
558 | kz,
559 | kn,
560 | kh,
561 | vz,
562 | vn,
563 | vh,
564 | oz,
565 | om,
566 | oh,
567 | bz,
568 | bm,
569 | bh,
570 | nheads_q,
571 | num_repeats,
572 | QSeq,
573 | cum_seqlens_q if cum_seqlens_q is not None else jnp.zeros((1,), jnp.float16),
574 | KSeq,
575 | max_seqlen_q_rounded,
576 | head_dim,
577 | max_seqlen_q // 128,
578 | max_seqlen_k // 128,
579 | dtype_index(q),
580 | kernel=_attn_fwd,
581 | out_shape=out_shape,
582 | grid=lambda META: (
583 | triton.cdiv(max_seqlen_q, META["BLOCK_M"]),
584 | batch * nheads_q,
585 | ),
586 | name="triton::ops::_attn_fwd",
587 | **metaparams,
588 | )
589 |
590 | if varlen_mode:
591 | out = attention_unpack_with_static_shape(out, cum_seqlens_q, *attention_mask.shape)
592 | return out, lse
593 |
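For orientation, a dense reference in plain JAX that the fused kernel above is intended to match numerically (no dropout, no bias; layout `(batch, seq, heads, head_dim)`; written here as an illustrative sketch, not taken from the repository):

import jax
import jax.numpy as jnp

def reference_attention(q, k, v, causal=False, softmax_scale=None):
    scale = softmax_scale if softmax_scale is not None else q.shape[-1] ** -0.5
    rep = q.shape[2] // k.shape[2]                    # grouped-query repeat factor
    k = jnp.repeat(k, rep, axis=2)
    v = jnp.repeat(v, rep, axis=2)
    s = jnp.einsum("bqhd,bkhd->bhqk", q.astype(jnp.float32), k.astype(jnp.float32)) * scale
    if causal:
        qlen, klen = s.shape[-2], s.shape[-1]
        allowed = jnp.tril(jnp.ones((qlen, klen), dtype=bool), k=klen - qlen)
        s = jnp.where(allowed, s, -jnp.inf)
    p = jax.nn.softmax(s, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", p, v.astype(jnp.float32))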
--------------------------------------------------------------------------------
/jax_flash_attn2/flash_attention_triton/_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import math
17 | import typing as tp
18 | from functools import partial
19 |
20 | import chex
21 | import jax
22 | import jax.numpy as jnp
23 | import numpy
24 | import triton
25 | import triton.language as tl
26 |
27 | F = tp.TypeVar("F", bound=tp.Callable[..., tp.Any])
28 |
29 |
30 | def safe_autotune(
31 | configs,
32 | key,
33 | prune_configs_by=None,
34 | reset_to_zero=None,
35 | restore_value=None,
36 | pre_hook=None,
37 | post_hook=None,
38 | warmup=None,
39 | rep=None,
40 | use_cuda_graph=False,
41 | do_bench=None,
42 | ) -> tp.Callable[[F], F]:
43 | """
44 | Applies `triton.autotune` safely. Falls back to the original function if autotuning fails.
45 | """
46 | try:
47 | from triton.runtime.autotuner import Autotuner
48 |
49 | def decorator(fn):
50 | return Autotuner(
51 | fn,
52 | fn.arg_names,
53 | configs,
54 | key,
55 | reset_to_zero,
56 | restore_value,
57 | pre_hook=pre_hook,
58 | post_hook=post_hook,
59 | prune_configs_by=prune_configs_by,
60 | warmup=warmup,
61 | rep=rep,
62 | use_cuda_graph=use_cuda_graph,
63 | )
64 |
65 | return decorator
66 | except Exception as err:
67 |         print(f"Couldn't autotune the given function: {err}")
68 |
69 | def decorator(fn):
70 | return fn
71 |
72 | return decorator
73 |
74 |
75 | def dtype_index(x: jnp.array) -> int:
76 | if x.dtype == jnp.float16:
77 | return 1
78 | if x.dtype == jnp.bfloat16:
79 | return 2
80 | if x.dtype == jnp.float32:
81 | return 3
82 | raise ValueError(x.dtype)
83 |
84 |
85 | def get_sharding(arr: chex.Array):
86 | """Gets the sharding of an array.
87 |
88 | Args:
89 | arr: Array to get sharding from.
90 |
91 | Returns:
92 | Sharding of the array.
93 | """
94 | return getattr(arr, "sharding", None)
95 |
96 |
97 | def get_strides(shape: tuple[int, ...]) -> tuple[int, ...]:
98 | """Calculates strides for a given shape.
99 |
100 | Args:
101 | shape: Shape of the array.
102 |
103 | Returns:
104 | Tuple of strides.
105 | """
106 | if hasattr(shape, "shape"):
107 | shape = shape.shape
108 | size = numpy.prod(shape)
109 | strides = []
110 | for s in shape:
111 | size = int(size // s)
112 | strides.append(size)
113 | return tuple(strides)
114 |
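# For example, get_strides((2, 3, 4)) returns (12, 4, 1): row-major strides counted in
# elements. calc_bias_strides below multiplies these by `itemsize` to recover byte strides
# for arrays (such as JAX arrays) that do not expose a `.strides` attribute.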
115 |
116 | @triton.jit
117 | def padded_load(
118 | ptrs,
119 | offs_a,
120 | offs_b,
121 | PA0: tl.constexpr,
122 | PA1: tl.constexpr,
123 | LA0: tl.constexpr,
124 | LA1: tl.constexpr,
125 | ):
126 | if PA0:
127 | if PA1:
128 | x = tl.load(
129 | ptrs,
130 | mask=(offs_a[:, None] < LA0) & (offs_b[None, :] < LA1),
131 | other=0.0,
132 | )
133 | else:
134 | x = tl.load(
135 | ptrs,
136 | mask=offs_a[:, None] < LA0,
137 | other=0.0,
138 | )
139 | else:
140 | if PA1:
141 | x = tl.load(
142 | ptrs,
143 | mask=offs_b[None, :] < LA1,
144 | other=0.0,
145 | )
146 | else:
147 | x = tl.load(ptrs)
148 | return x
149 |
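# PA0/PA1 flag whether the first/second axis of the tile may run past its logical length
# (LA0/LA1). Out-of-range lanes are masked and filled with 0.0; fully in-bounds tiles take
# the unmasked fast path.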
150 |
151 | def calc_bias_strides(
152 | bias: tp.Optional[jnp.ndarray],
153 | batch: int,
154 | nheads_q: int,
155 | QSeq: int,
156 | KSeq: int,
157 | ) -> tp.Tuple[int, ...]:
158 | if bias is not None:
159 | if not hasattr(bias, "strides"):
160 | strides = tuple(map(lambda x: x * bias.itemsize, get_strides(bias)))
161 | else:
162 | strides = bias.strides
163 | if bias.shape[2] != QSeq or bias.shape[3] != KSeq:
164 | raise ValueError(
165 | f"Bias tensor has incompatible sequence dimensions. "
166 | f"Expected shape [..., {QSeq}, {KSeq}], but got [..., {bias.shape[2]}, {bias.shape[3]}]. "
167 | f"Full bias shape: {bias.shape}"
168 | )
169 | if bias.shape[0] == 1:
170 | stride_bz = 0
171 | elif bias.shape[0] == batch:
172 | stride_bz = strides[0] // bias.itemsize
173 | else:
174 | raise ValueError(
175 | f"Batch dimension mismatch in bias tensor. "
176 | f"Expected either 1 (for broadcasting) or {batch} (batch size), "
177 | f"but got {bias.shape[0]}. Consider reshaping your bias tensor."
178 | )
179 | if bias.shape[1] == 1:
180 | stride_bh = 0
181 | elif bias.shape[1] == nheads_q:
182 | stride_bh = strides[1] // bias.itemsize
183 | else:
184 | raise ValueError(
185 | f"Head dimension mismatch in bias tensor. "
186 | f"Expected either 1 (for broadcasting) or {nheads_q} (number of heads), "
187 | f"but got {bias.shape[1]}. Check that your bias tensor matches the model configuration."
188 | )
189 |
190 | stride_bm = strides[2] // bias.itemsize
191 | else:
192 | stride_bz, stride_bh, stride_bm = 0, 0, 0
193 | return stride_bz, stride_bh, stride_bm
194 |
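# For example (hypothetical shapes): a broadcastable bias of shape (1, 1, QSeq, KSeq)
# gives (stride_bz, stride_bh, stride_bm) = (0, 0, KSeq), while a full
# (batch, nheads_q, QSeq, KSeq) bias gives (nheads_q * QSeq * KSeq, QSeq * KSeq, KSeq).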
195 |
196 | @partial(jax.jit, static_argnames=["max_tokens"])
197 | def attention_pack_with_static_shape(
198 | x: jnp.ndarray,
199 | attention_mask: jnp.ndarray,
200 | max_tokens: tp.Optional[int] = None,
201 | ) -> jnp.ndarray:
202 | """
203 | Pack attention tensor by removing padding based on attention mask.
204 | Uses a static maximum shape to be compatible with JIT.
205 | """
206 | batch_size, seqlen = attention_mask.shape
207 | num_heads, head_dim = x.shape[2], x.shape[3]
208 |
209 | if max_tokens is None:
210 | max_tokens = batch_size * seqlen
211 |
212 | seqlens = jnp.sum(attention_mask, axis=1).astype(jnp.int32)
213 | offsets = jnp.zeros((batch_size,), dtype=jnp.int32)
214 | offsets = offsets.at[1:].set(jnp.cumsum(seqlens[:-1]))
215 | packed = jnp.zeros((1, max_tokens, num_heads, head_dim), dtype=x.dtype)
216 | batch_idx, pos_idx = jnp.meshgrid(
217 | jnp.arange(batch_size), jnp.arange(seqlen), indexing="ij"
218 | )
219 |
220 | batch_idx_flat = batch_idx.reshape(-1)
221 | pos_idx_flat = pos_idx.reshape(-1)
222 |
223 | valid_mask = pos_idx < seqlens[:, None]
224 | target_idx = jnp.where(
225 | valid_mask,
226 | offsets[:, None] + pos_idx,
227 | jnp.zeros_like(pos_idx),
228 | )
229 | target_idx_flat = target_idx.reshape(-1)
230 | valid_mask_flat = valid_mask.reshape(-1)
231 |
232 | def process_token(i, packed_acc):
233 | b = batch_idx_flat[i]
234 | p = pos_idx_flat[i]
235 | t = target_idx_flat[i]
236 | valid = valid_mask_flat[i]
237 | packed_acc = jnp.where(valid, packed_acc.at[0, t].set(x[b, p]), packed_acc)
238 |
239 | return packed_acc
240 |
241 | packed = jax.lax.fori_loop(0, batch_size * seqlen, process_token, packed)
242 | return packed
243 |
244 |
245 | @partial(jax.jit, static_argnames=["seqlen", "batch_size"])
246 | def attention_unpack_with_static_shape(
247 | x: jnp.ndarray,
248 | cum_seqlens: jnp.ndarray,
249 | batch_size: int,
250 | seqlen: int,
251 | ) -> jnp.ndarray:
252 | """
253 | Unpack attention tensor by redistributing the packed values to their original positions.
254 |
255 | Args:
256 | x: Packed tensor of shape [1, packed_tokens, num_heads, head_dim]
257 | cum_seqlens: Cumulative sequence lengths, shape [batch_size+1]
258 | batch_size: Number of batches
259 | seqlen: Maximum sequence length
260 |
261 | Returns:
262 | Unpacked tensor of shape [batch_size, seqlen, num_heads, head_dim]
263 | """
264 | num_heads, head_dim = x.shape[2], x.shape[3]
265 |
266 | # Create output with static shape
267 | unpacked = jnp.zeros((batch_size, seqlen, num_heads, head_dim), dtype=x.dtype)
268 |
269 | # Process each batch
270 | def process_batch(b, unpacked_acc):
271 | start_idx = cum_seqlens[b]
272 | end_idx = cum_seqlens[b + 1]
273 | seq_len = end_idx - start_idx
274 |
275 | # Process each position in the sequence
276 | def process_position(p, acc):
277 | # Only copy if within valid sequence length
278 | valid = p < seq_len
279 | src_idx = start_idx + p
280 |
281 | # Update conditionally
282 | acc = jnp.where(valid, acc.at[b, p].set(x[0, src_idx]), acc)
283 |
284 | return acc
285 |
286 | # Process all positions in this batch
287 | unpacked_acc = jax.lax.fori_loop(0, seqlen, process_position, unpacked_acc)
288 |
289 | return unpacked_acc
290 |
291 | # Process all batches
292 | unpacked = jax.lax.fori_loop(0, batch_size, process_batch, unpacked)
293 |
294 | return unpacked
295 |
296 |
297 | def basic_attention_refrence(
298 | q: jnp.ndarray,
299 | k: jnp.ndarray,
300 | v: jnp.ndarray,
301 | attn_bias: tp.Optional[jnp.ndarray] = None,
302 | query_padding_mask: tp.Optional[jnp.ndarray] = None,
303 | key_padding_mask: tp.Optional[jnp.ndarray] = None,
304 | dropout_prob: float = 0.0,
305 | dropout_key: tp.Optional[jax.random.PRNGKey] = None,
306 | window_size: tp.Tuple[int, int] = (-1, -1),
307 | causal: bool = False,
308 | softcap: float = 0.0,
309 | ):
310 | if causal:
311 | window_size = (window_size[0], 0)
312 | dtype_og = q.dtype
313 | q, k, v = q.astype(jnp.float32), k.astype(jnp.float32), v.astype(jnp.float32)
314 | QSeq, KSeq = q.shape[1], k.shape[1]
315 | repeats = q.shape[2] // k.shape[2]
316 | if repeats > 1:
317 | k = jnp.repeat(k, repeats=repeats, axis=2)
318 | v = jnp.repeat(v, repeats=repeats, axis=2)
319 | d = q.shape[-1]
320 | q_scaled = q / math.sqrt(d)
321 | scores = jnp.einsum("bthd,bshd->bhts", q_scaled, k)
322 | if softcap > 0:
323 | scores = scores / softcap
324 | scores = jnp.tanh(scores)
325 | scores = scores * softcap
326 | if key_padding_mask is not None:
327 | key_mask = (~key_padding_mask).reshape(key_padding_mask.shape[0], 1, 1, KSeq)
328 | scores = jnp.where(key_mask, jnp.finfo(scores.dtype).min, scores)
329 | if window_size[0] >= 0 or window_size[1] >= 0:
330 | row_idx = jnp.arange(QSeq).reshape(-1, 1)
331 | col_idx = jnp.arange(KSeq)
332 | if key_padding_mask is None:
333 | sk = KSeq
334 | else:
335 | sk = jnp.sum(key_padding_mask, axis=-1).reshape(-1, 1, 1, 1)
336 | if query_padding_mask is None:
337 | sq = QSeq
338 | else:
339 | sq = jnp.sum(query_padding_mask, axis=-1).reshape(-1, 1, 1, 1)
340 | if window_size[0] < 0:
341 | local_mask = col_idx > row_idx + sk - sq + window_size[1]
342 | else:
343 | if key_padding_mask is None:
344 | sk_full = jnp.full_like(col_idx, KSeq)
345 | else:
346 | sk_full = sk
347 | local_mask = jnp.logical_or(
348 | col_idx > jnp.minimum(row_idx + sk - sq + window_size[1], sk_full),
349 | col_idx < row_idx + sk - sq - window_size[0],
350 | )
351 | scores = jnp.where(local_mask, jnp.finfo(scores.dtype).min, scores)
352 | if attn_bias is not None:
353 | scores = scores + attn_bias
354 | attention = jax.nn.softmax(scores, axis=-1).astype(v.dtype)
355 | if window_size[0] >= 0 or window_size[1] >= 0:
356 | all_masked = jnp.all(local_mask, axis=-1, keepdims=True)
357 | attention = jnp.where(all_masked, 0.0, attention)
358 | if query_padding_mask is not None:
359 | query_mask = (~query_padding_mask).reshape(query_padding_mask.shape[0], 1, QSeq, 1)
360 | attention = jnp.where(query_mask, 0.0, attention)
361 | dropout_scaling = 1.0 / (1 - dropout_prob)
362 | if dropout_prob > 0 and dropout_key is not None:
363 | dropout_mask = jax.random.bernoulli(
364 | dropout_key, p=1 - dropout_prob, shape=attention.shape
365 | )
366 | attention_drop = attention * dropout_mask * dropout_scaling
367 | else:
368 | attention_drop = attention
369 | output = jnp.einsum("bhts,bshd->bthd", attention_drop, v)
370 | if query_padding_mask is not None:
371 | query_mask_expanded = (~query_padding_mask).reshape(
372 | query_padding_mask.shape[0],
373 | QSeq,
374 | 1,
375 | 1,
376 | )
377 | output = jnp.where(query_mask_expanded, 0.0, output)
378 | return output.astype(dtype_og)
379 |
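A minimal round-trip sketch of the packing helpers above, with hypothetical shapes and right-padded masks assumed:

import jax.numpy as jnp

x = jnp.arange(2 * 4 * 1 * 2, dtype=jnp.float32).reshape(2, 4, 1, 2)  # [batch, seqlen, heads, dim]
mask = jnp.array([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=jnp.int32)       # 3 and 2 valid tokens

packed = attention_pack_with_static_shape(x, mask)                    # [1, 8, 1, 2]
seqlens = jnp.sum(mask, axis=1)
cum_seqlens = jnp.concatenate([jnp.zeros((1,), jnp.int32), jnp.cumsum(seqlens)])
restored = attention_unpack_with_static_shape(packed, cum_seqlens, 2, 4)
# `restored` matches `x` at the valid positions; padded positions come back as zeros.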
--------------------------------------------------------------------------------
/jax_flash_attn2/refrence_call.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import math
17 | import typing as tp
18 |
19 | import jax
20 | import jax.numpy as jnp
21 |
22 |
23 | def basic_attention_refrence(
24 | q: jnp.ndarray,
25 | k: jnp.ndarray,
26 | v: jnp.ndarray,
27 | attn_bias: tp.Optional[jnp.ndarray] = None,
28 | query_padding_mask: tp.Optional[jnp.ndarray] = None,
29 | key_padding_mask: tp.Optional[jnp.ndarray] = None,
30 | dropout_prob: float = 0.0,
31 | dropout_key: tp.Optional[jax.random.PRNGKey] = None,
32 | window_size: tp.Tuple[int, int] = (-1, -1),
33 | causal: bool = False,
34 | softcap: float = 0.0,
35 | ):
36 | if causal:
37 | window_size = (window_size[0], 0)
38 | dtype_og = q.dtype
39 | q, k, v = q.astype(jnp.float32), k.astype(jnp.float32), v.astype(jnp.float32)
40 | QSeq, KSeq = q.shape[1], k.shape[1]
41 | repeats = q.shape[2] // k.shape[2]
42 | if repeats > 1:
43 | k = jnp.repeat(k, repeats=repeats, axis=2)
44 | v = jnp.repeat(v, repeats=repeats, axis=2)
45 | d = q.shape[-1]
46 | q_scaled = q / math.sqrt(d)
47 | scores = jnp.einsum("bthd,bshd->bhts", q_scaled, k)
48 | if softcap > 0:
49 | scores = scores / softcap
50 | scores = jnp.tanh(scores)
51 | scores = scores * softcap
52 | if key_padding_mask is not None:
53 | key_mask = (~key_padding_mask).reshape(key_padding_mask.shape[0], 1, 1, KSeq)
54 | scores = jnp.where(key_mask, jnp.finfo(scores.dtype).min, scores)
55 | if window_size[0] >= 0 or window_size[1] >= 0:
56 | row_idx = jnp.arange(QSeq).reshape(-1, 1)
57 | col_idx = jnp.arange(KSeq)
58 | if key_padding_mask is None:
59 | sk = KSeq
60 | else:
61 | sk = jnp.sum(key_padding_mask, axis=-1).reshape(-1, 1, 1, 1)
62 | if query_padding_mask is None:
63 | sq = QSeq
64 | else:
65 | sq = jnp.sum(query_padding_mask, axis=-1).reshape(-1, 1, 1, 1)
66 | if window_size[0] < 0:
67 | local_mask = col_idx > row_idx + sk - sq + window_size[1]
68 | else:
69 | if key_padding_mask is None:
70 | sk_full = jnp.full_like(col_idx, KSeq)
71 | else:
72 | sk_full = sk
73 | local_mask = jnp.logical_or(
74 | col_idx > jnp.minimum(row_idx + sk - sq + window_size[1], sk_full),
75 | col_idx < row_idx + sk - sq - window_size[0],
76 | )
77 | scores = jnp.where(local_mask, jnp.finfo(scores.dtype).min, scores)
78 | if attn_bias is not None:
79 | scores = scores + attn_bias
80 | attention = jax.nn.softmax(scores, axis=-1).astype(v.dtype)
81 | if window_size[0] >= 0 or window_size[1] >= 0:
82 | all_masked = jnp.all(local_mask, axis=-1, keepdims=True)
83 | attention = jnp.where(all_masked, 0.0, attention)
84 | if query_padding_mask is not None:
85 | query_mask = (~query_padding_mask).reshape(query_padding_mask.shape[0], 1, QSeq, 1)
86 | attention = jnp.where(query_mask, 0.0, attention)
87 | dropout_scaling = 1.0 / (1 - dropout_prob)
88 | if dropout_prob > 0 and dropout_key is not None:
89 | dropout_mask = jax.random.bernoulli(
90 | dropout_key, p=1 - dropout_prob, shape=attention.shape
91 | )
92 | attention_drop = attention * dropout_mask * dropout_scaling
93 | else:
94 | attention_drop = attention
95 | output = jnp.einsum("bhts,bshd->bthd", attention_drop, v)
96 | if query_padding_mask is not None:
97 | query_mask_expanded = (~query_padding_mask).reshape(
98 | query_padding_mask.shape[0],
99 | QSeq,
100 | 1,
101 | 1,
102 | )
103 | output = jnp.where(query_mask_expanded, 0.0, output)
104 | return output.astype(dtype_og)
105 |
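A minimal usage sketch of the reference implementation, with hypothetical shapes; inputs are laid out as [batch, seq, heads, head_dim], and k/v may carry fewer heads than q (they are repeated internally):

import jax
import jax.numpy as jnp

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (1, 8, 4, 16), jnp.float16)  # 4 query heads
k = jax.random.normal(kk, (1, 8, 2, 16), jnp.float16)  # 2 kv heads, repeated to 4 inside
v = jax.random.normal(kv, (1, 8, 2, 16), jnp.float16)

out = basic_attention_refrence(q, k, v, causal=True)   # [1, 8, 4, 16], float16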
--------------------------------------------------------------------------------
/jax_flash_attn2/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from jax import random as jrandom
4 | from jax.interpreters import pxla
5 | from jax.sharding import PartitionSpec
6 |
7 |
8 | class GenerateRNG:
9 | """An infinite generator of JAX PRNGKeys, useful for iterating over seeds."""
10 |
11 | def __init__(self, seed: int = 0):
12 | """Initializes the generator with a starting seed.
13 |
14 | Args:
15 | seed: The seed to use for the initial PRNGKey.
16 | """
17 | self.seed = seed
18 | self._rng = jrandom.PRNGKey(seed)
19 |
20 | def __next__(self) -> jrandom.PRNGKey:
21 | """Generates and returns the next PRNGKey in the sequence.
22 |
23 | Returns:
24 | The next PRNGKey derived from the internal state.
25 | """
26 | self._rng, key = jrandom.split(self._rng)
27 | return key
28 |
29 | @property
30 | def rng(self) -> jrandom.PRNGKey:
31 | """Provides access to the next PRNGKey without advancing the generator.
32 |
33 | Returns:
34 | The next PRNGKey in the sequence.
35 | """
36 | return next(self)
37 |
38 |
39 | def get_logger(name, level: int = logging.INFO) -> logging.Logger:
40 | """
41 | Function to create and configure a logger.
42 | Args:
43 | name: str: The name of the logger.
44 | level: int: The logging level. Defaults to logging.INFO.
45 | Returns:
46 | logging.Logger: The configured logger instance.
47 | """
48 | logger = logging.getLogger(name)
49 | logger.propagate = False
50 | logger.setLevel(level)
51 | console_handler = logging.StreamHandler()
52 | console_handler.setLevel(level)
53 | formatter = logging.Formatter("%(asctime)s %(levelname)-8s [%(name)s] %(message)s")
54 | console_handler.setFormatter(formatter)
55 | logger.addHandler(console_handler)
56 | return logger
57 |
58 |
59 | def names_in_current_mesh(*names: str) -> bool:
60 | """
61 | Check if the given names are present in the current JAX mesh.
62 |
63 | Args:
64 | *names: Variable number of axis names to check.
65 |
66 | Returns:
67 | True if all given names are present in the current mesh, False otherwise.
68 | """
69 | mesh_axis_names = pxla.thread_resources.env.physical_mesh.axis_names
70 | return set(names) <= set(mesh_axis_names)
71 |
72 |
73 | def get_names_from_partition_spec(
74 | partition_specs: dict[str, PartitionSpec],
75 | ) -> list[str]:
76 | """
77 | Extract axis names from a partition specification.
78 |
79 | This function recursively iterates through the provided `partition_specs`
80 | dictionary and extracts all unique axis names used in the sharding specifications.
81 |
82 | Args:
83 | partition_specs: A dictionary mapping parameter names to their respective `PartitionSpec`.
84 |
85 | Returns:
86 | A list of unique axis names used in the partition specs.
87 | """
88 | names = set()
89 | if isinstance(partition_specs, dict):
90 | partition_specs = partition_specs.values()
91 | for item in partition_specs:
92 | if item is None:
93 | continue
94 | elif isinstance(item, str):
95 | names.add(item)
96 | else:
97 | names.update(get_names_from_partition_spec(item))
98 | return list(names)
99 |
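A brief usage sketch of these helpers (hypothetical names and values):

rngs = GenerateRNG(seed=42)
dropout_key = next(rngs)          # a fresh PRNGKey; `rngs.rng` advances the state as well
logger = get_logger("jax_flash_attn2")
logger.info("sampled dropout key %s", dropout_key)
# With no mesh active, names_in_current_mesh("dp", "tp") returns False.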
--------------------------------------------------------------------------------
/poetry.lock:
--------------------------------------------------------------------------------
1 | # This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
2 |
3 | [[package]]
4 | name = "absl-py"
5 | version = "2.1.0"
6 | description = "Abseil Python Common Libraries, see https://github.com/abseil/abseil-py."
7 | optional = false
8 | python-versions = ">=3.7"
9 | files = [
10 | {file = "absl-py-2.1.0.tar.gz", hash = "sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff"},
11 | {file = "absl_py-2.1.0-py3-none-any.whl", hash = "sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308"},
12 | ]
13 |
14 | [[package]]
15 | name = "chex"
16 | version = "0.1.87"
17 | description = "Chex: Testing made fun, in JAX!"
18 | optional = false
19 | python-versions = ">=3.9"
20 | files = [
21 | {file = "chex-0.1.87-py3-none-any.whl", hash = "sha256:ce536475661fd96d21be0c1728ecdbedd03f8ff950c662dfc338c92ea782cb16"},
22 | {file = "chex-0.1.87.tar.gz", hash = "sha256:0096d89cc8d898bb521ef4bfbf5c24549022b0e5b301f529ab57238896fe6c5d"},
23 | ]
24 |
25 | [package.dependencies]
26 | absl-py = ">=0.9.0"
27 | jax = ">=0.4.27"
28 | jaxlib = ">=0.4.27"
29 | numpy = ">=1.24.1"
30 | setuptools = {version = "*", markers = "python_version >= \"3.12\""}
31 | toolz = ">=0.9.0"
32 | typing-extensions = ">=4.2.0"
33 |
34 | [[package]]
35 | name = "einops"
36 | version = "0.8.0"
37 | description = "A new flavour of deep learning operations"
38 | optional = false
39 | python-versions = ">=3.8"
40 | files = [
41 | {file = "einops-0.8.0-py3-none-any.whl", hash = "sha256:9572fb63046264a862693b0a87088af3bdc8c068fde03de63453cbbde245465f"},
42 | {file = "einops-0.8.0.tar.gz", hash = "sha256:63486517fed345712a8385c100cb279108d9d47e6ae59099b07657e983deae85"},
43 | ]
44 |
45 | [[package]]
46 | name = "filelock"
47 | version = "3.16.1"
48 | description = "A platform independent file lock."
49 | optional = false
50 | python-versions = ">=3.8"
51 | files = [
52 | {file = "filelock-3.16.1-py3-none-any.whl", hash = "sha256:2082e5703d51fbf98ea75855d9d5527e33d8ff23099bec374a134febee6946b0"},
53 | {file = "filelock-3.16.1.tar.gz", hash = "sha256:c249fbfcd5db47e5e2d6d62198e565475ee65e4831e2561c8e313fa7eb961435"},
54 | ]
55 |
56 | [package.extras]
57 | docs = ["furo (>=2024.8.6)", "sphinx (>=8.0.2)", "sphinx-autodoc-typehints (>=2.4.1)"]
58 | testing = ["covdefaults (>=2.3)", "coverage (>=7.6.1)", "diff-cover (>=9.2)", "pytest (>=8.3.3)", "pytest-asyncio (>=0.24)", "pytest-cov (>=5)", "pytest-mock (>=3.14)", "pytest-timeout (>=2.3.1)", "virtualenv (>=20.26.4)"]
59 | typing = ["typing-extensions (>=4.12.2)"]
60 |
61 | [[package]]
62 | name = "jax"
63 | version = "0.4.35"
64 | description = "Differentiate, compile, and transform Numpy code."
65 | optional = false
66 | python-versions = ">=3.10"
67 | files = [
68 | {file = "jax-0.4.35-py3-none-any.whl", hash = "sha256:fa99e909a31424abfec750019a6dd36f6acc18a6e7d40e2c0086b932cc351325"},
69 | {file = "jax-0.4.35.tar.gz", hash = "sha256:c0c986993026b10bf6f607fecb7417377460254640766ce40f1fef3fd139c12e"},
70 | ]
71 |
72 | [package.dependencies]
73 | jaxlib = ">=0.4.34,<=0.4.35"
74 | ml-dtypes = ">=0.4.0"
75 | numpy = [
76 | {version = ">=1.26.0", markers = "python_version >= \"3.12\""},
77 | {version = ">=1.24", markers = "python_version < \"3.12\""},
78 | ]
79 | opt-einsum = "*"
80 | scipy = [
81 | {version = ">=1.11.1", markers = "python_version >= \"3.12\""},
82 | {version = ">=1.10", markers = "python_version < \"3.12\""},
83 | ]
84 |
85 | [package.extras]
86 | ci = ["jaxlib (==0.4.34)"]
87 | cuda = ["jax-cuda12-plugin[with-cuda] (>=0.4.34,<=0.4.35)", "jaxlib (==0.4.34)"]
88 | cuda12 = ["jax-cuda12-plugin[with-cuda] (>=0.4.34,<=0.4.35)", "jaxlib (==0.4.34)"]
89 | cuda12-local = ["jax-cuda12-plugin (==0.4.34)", "jaxlib (==0.4.34)"]
90 | cuda12-pip = ["jax-cuda12-plugin[with-cuda] (>=0.4.34,<=0.4.35)", "jaxlib (==0.4.34)"]
91 | k8s = ["kubernetes"]
92 | minimum-jaxlib = ["jaxlib (==0.4.34)"]
93 | tpu = ["jaxlib (>=0.4.34,<=0.4.35)", "libtpu (==0.0.2)", "libtpu-nightly (==0.1.dev20241010+nightly.cleanup)", "requests"]
94 |
95 | [[package]]
96 | name = "jaxlib"
97 | version = "0.4.35"
98 | description = "XLA library for JAX"
99 | optional = false
100 | python-versions = ">=3.10"
101 | files = [
102 | {file = "jaxlib-0.4.35-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:907e548ad6ce53b242a55c5f36c2a2a4c37d38f6cd8c356fc550a2f18ab0e82f"},
103 | {file = "jaxlib-0.4.35-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8f8c499644660aefd0ae2ee31039da6d4df0f26d0ee67ba9fb316183a5304288"},
104 | {file = "jaxlib-0.4.35-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:5d2d8a5b89d334b875ede98d7fcee946bebef1a1b5abd118ff543bcef4ab09f5"},
105 | {file = "jaxlib-0.4.35-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:91a283a72263feebe0d110d1136df96950744e47530f12df42c03f36888c971e"},
106 | {file = "jaxlib-0.4.35-cp310-cp310-win_amd64.whl", hash = "sha256:d210bab7e1ce0b2f2e568548b3903ea6aec349019fc1398cd2a0c069e8342e62"},
107 | {file = "jaxlib-0.4.35-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:7f8bfc90f68857b223b7e38a9bdf466a4f1cb405c9a4aa11698dc9ab7b35c29b"},
108 | {file = "jaxlib-0.4.35-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:261570c94b169dc90f3af903282eeec856b52736c0944d243504ced93d19b217"},
109 | {file = "jaxlib-0.4.35-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:e1cee6dc291251f3fb6b0127fdd96c0439ac1ea97e01571d06910df72d6ac6e1"},
110 | {file = "jaxlib-0.4.35-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:bc9eafba001ff8569cfa252fe7f04ba553622702b4b473b656dd0866edf6b8d4"},
111 | {file = "jaxlib-0.4.35-cp311-cp311-win_amd64.whl", hash = "sha256:0fd990354d5623d3a34493fcd7213493390dbf5039bea19b62e2aaee1049eda9"},
112 | {file = "jaxlib-0.4.35-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:b44f3e6e9fb748bb43df914356cf9d0d0c9a6e446a12c21fe843db25ed0df65f"},
113 | {file = "jaxlib-0.4.35-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:504d0a2e2117724359d99d7e3663022686dcdddd85aa14bdad02008d444481ad"},
114 | {file = "jaxlib-0.4.35-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:187cb6929dc139b75d952d67c33118473c1b4105525a3e5607f064e7b8efdc74"},
115 | {file = "jaxlib-0.4.35-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:04d1db3bf0050d120238bfb9b686b58fefcc4d9dd9e2d96aecd3f68a1f1f5e0a"},
116 | {file = "jaxlib-0.4.35-cp312-cp312-win_amd64.whl", hash = "sha256:dddffce48d7e6057008999aed2d8a9daecc57a48c45a4f8c475e00880eb2e41d"},
117 | {file = "jaxlib-0.4.35-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:14aeac3fea2ca1d5afb1878f72470b159cc89adb2633c5f0686f5d7c39f2ac18"},
118 | {file = "jaxlib-0.4.35-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:e8c9579e20d5ecdc4f61336cdd032710cb8c38d5ae9c4fce0cf9ea031cef21cb"},
119 | {file = "jaxlib-0.4.35-cp313-cp313-manylinux2014_aarch64.whl", hash = "sha256:7b11ad7c13f7f96f36efd303711ecac425f19ca2ddf65cf1be1541167a959ee5"},
120 | {file = "jaxlib-0.4.35-cp313-cp313-manylinux2014_x86_64.whl", hash = "sha256:0be3cf9df879d9ae1b5b92fc281f77d21f522fcbae1a48a02661026bbd9b9309"},
121 | {file = "jaxlib-0.4.35-cp313-cp313-win_amd64.whl", hash = "sha256:330c090bb9af413f552d8a92d097e50baec6b75823430fb2966a49f5298d4c43"},
122 | ]
123 |
124 | [package.dependencies]
125 | ml-dtypes = ">=0.2.0"
126 | numpy = ">=1.24"
127 | scipy = [
128 | {version = ">=1.11.1", markers = "python_version >= \"3.12\""},
129 | {version = ">=1.10", markers = "python_version < \"3.12\""},
130 | ]
131 |
132 | [[package]]
133 | name = "ml-dtypes"
134 | version = "0.5.0"
135 | description = ""
136 | optional = false
137 | python-versions = ">=3.9"
138 | files = [
139 | {file = "ml_dtypes-0.5.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:8c32138975797e681eb175996d64356bcfa124bdbb6a70460b9768c2b35a6fa4"},
140 | {file = "ml_dtypes-0.5.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ab046f2ff789b1f11b2491909682c5d089934835f9a760fafc180e47dcb676b8"},
141 | {file = "ml_dtypes-0.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c7a9152f5876fef565516aa5dd1dccd6fc298a5891b2467973905103eb5c7856"},
142 | {file = "ml_dtypes-0.5.0-cp310-cp310-win_amd64.whl", hash = "sha256:968fede07d1f9b926a63df97d25ac656cac1a57ebd33701734eaf704bc55d8d8"},
143 | {file = "ml_dtypes-0.5.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:60275f2b51b56834e840c4809fca840565f9bf8e9a73f6d8c94f5b5935701215"},
144 | {file = "ml_dtypes-0.5.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76942f6aeb5c40766d5ea62386daa4148e6a54322aaf5b53eae9e7553240222f"},
145 | {file = "ml_dtypes-0.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e7534392682c3098bc7341648c650864207169c654aed83143d7a19c67ae06f"},
146 | {file = "ml_dtypes-0.5.0-cp311-cp311-win_amd64.whl", hash = "sha256:dc74fd9995513d33eac63d64e436240f5494ec74d522a9f0920194942fc3d2d7"},
147 | {file = "ml_dtypes-0.5.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:d4b1a70a3e5219790d6b55b9507606fc4e02911d1497d16c18dd721eb7efe7d0"},
148 | {file = "ml_dtypes-0.5.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a988bac6572630e1e9c2edd9b1277b4eefd1c86209e52b0d061b775ac33902ff"},
149 | {file = "ml_dtypes-0.5.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a38df8df61194aeaae1ab7579075779b4ad32cd1cffd012c28be227fa7f2a70a"},
150 | {file = "ml_dtypes-0.5.0-cp312-cp312-win_amd64.whl", hash = "sha256:afa08343069874a30812871d639f9c02b4158ace065601406a493a8511180c02"},
151 | {file = "ml_dtypes-0.5.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:d3b3db9990c3840986a0e70524e122cfa32b91139c3653df76121ba7776e015f"},
152 | {file = "ml_dtypes-0.5.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e04fde367b2fe901b1d47234426fe8819909bd1dd862a5adb630f27789c20599"},
153 | {file = "ml_dtypes-0.5.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:54415257f00eb44fbcc807454efac3356f75644f1cbfc2d4e5522a72ae1dacab"},
154 | {file = "ml_dtypes-0.5.0-cp313-cp313-win_amd64.whl", hash = "sha256:cb5cc7b25acabd384f75bbd78892d0c724943f3e2e1986254665a1aa10982e07"},
155 | {file = "ml_dtypes-0.5.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:5f2b59233a0dbb6a560b3137ed6125433289ccba2f8d9c3695a52423a369ed15"},
156 | {file = "ml_dtypes-0.5.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:099e09edd54e676903b4538f3815b5ab96f5b119690514602d96bfdb67172cbe"},
157 | {file = "ml_dtypes-0.5.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a03fc861b86cc586728e3d093ba37f0cc05e65330c3ebd7688e7bae8290f8859"},
158 | {file = "ml_dtypes-0.5.0-cp39-cp39-win_amd64.whl", hash = "sha256:7ee9c320bb0f9ffdf9f6fa6a696ef2e005d1f66438d6f1c1457338e00a02e8cf"},
159 | {file = "ml_dtypes-0.5.0.tar.gz", hash = "sha256:3e7d3a380fe73a63c884f06136f8baa7a5249cc8e9fdec677997dd78549f8128"},
160 | ]
161 |
162 | [package.dependencies]
163 | numpy = [
164 | {version = ">=2.1.0", markers = "python_version >= \"3.13\""},
165 | {version = ">=1.26.0", markers = "python_version >= \"3.12\" and python_version < \"3.13\""},
166 | {version = ">=1.23.3", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
167 | {version = ">=1.21.2", markers = "python_version >= \"3.10\" and python_version < \"3.11\""},
168 | ]
169 |
170 | [package.extras]
171 | dev = ["absl-py", "pyink", "pylint (>=2.6.0)", "pytest", "pytest-xdist"]
172 |
173 | [[package]]
174 | name = "numpy"
175 | version = "2.1.2"
176 | description = "Fundamental package for array computing in Python"
177 | optional = false
178 | python-versions = ">=3.10"
179 | files = [
180 | {file = "numpy-2.1.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:30d53720b726ec36a7f88dc873f0eec8447fbc93d93a8f079dfac2629598d6ee"},
181 | {file = "numpy-2.1.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e8d3ca0a72dd8846eb6f7dfe8f19088060fcb76931ed592d29128e0219652884"},
182 | {file = "numpy-2.1.2-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:fc44e3c68ff00fd991b59092a54350e6e4911152682b4782f68070985aa9e648"},
183 | {file = "numpy-2.1.2-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:7c1c60328bd964b53f8b835df69ae8198659e2b9302ff9ebb7de4e5a5994db3d"},
184 | {file = "numpy-2.1.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6cdb606a7478f9ad91c6283e238544451e3a95f30fb5467fbf715964341a8a86"},
185 | {file = "numpy-2.1.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d666cb72687559689e9906197e3bec7b736764df6a2e58ee265e360663e9baf7"},
186 | {file = "numpy-2.1.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c6eef7a2dbd0abfb0d9eaf78b73017dbfd0b54051102ff4e6a7b2980d5ac1a03"},
187 | {file = "numpy-2.1.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:12edb90831ff481f7ef5f6bc6431a9d74dc0e5ff401559a71e5e4611d4f2d466"},
188 | {file = "numpy-2.1.2-cp310-cp310-win32.whl", hash = "sha256:a65acfdb9c6ebb8368490dbafe83c03c7e277b37e6857f0caeadbbc56e12f4fb"},
189 | {file = "numpy-2.1.2-cp310-cp310-win_amd64.whl", hash = "sha256:860ec6e63e2c5c2ee5e9121808145c7bf86c96cca9ad396c0bd3e0f2798ccbe2"},
190 | {file = "numpy-2.1.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b42a1a511c81cc78cbc4539675713bbcf9d9c3913386243ceff0e9429ca892fe"},
191 | {file = "numpy-2.1.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:faa88bc527d0f097abdc2c663cddf37c05a1c2f113716601555249805cf573f1"},
192 | {file = "numpy-2.1.2-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:c82af4b2ddd2ee72d1fc0c6695048d457e00b3582ccde72d8a1c991b808bb20f"},
193 | {file = "numpy-2.1.2-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:13602b3174432a35b16c4cfb5de9a12d229727c3dd47a6ce35111f2ebdf66ff4"},
194 | {file = "numpy-2.1.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1ebec5fd716c5a5b3d8dfcc439be82a8407b7b24b230d0ad28a81b61c2f4659a"},
195 | {file = "numpy-2.1.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e2b49c3c0804e8ecb05d59af8386ec2f74877f7ca8fd9c1e00be2672e4d399b1"},
196 | {file = "numpy-2.1.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:2cbba4b30bf31ddbe97f1c7205ef976909a93a66bb1583e983adbd155ba72ac2"},
197 | {file = "numpy-2.1.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8e00ea6fc82e8a804433d3e9cedaa1051a1422cb6e443011590c14d2dea59146"},
198 | {file = "numpy-2.1.2-cp311-cp311-win32.whl", hash = "sha256:5006b13a06e0b38d561fab5ccc37581f23c9511879be7693bd33c7cd15ca227c"},
199 | {file = "numpy-2.1.2-cp311-cp311-win_amd64.whl", hash = "sha256:f1eb068ead09f4994dec71c24b2844f1e4e4e013b9629f812f292f04bd1510d9"},
200 | {file = "numpy-2.1.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d7bf0a4f9f15b32b5ba53147369e94296f5fffb783db5aacc1be15b4bf72f43b"},
201 | {file = "numpy-2.1.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b1d0fcae4f0949f215d4632be684a539859b295e2d0cb14f78ec231915d644db"},
202 | {file = "numpy-2.1.2-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:f751ed0a2f250541e19dfca9f1eafa31a392c71c832b6bb9e113b10d050cb0f1"},
203 | {file = "numpy-2.1.2-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:bd33f82e95ba7ad632bc57837ee99dba3d7e006536200c4e9124089e1bf42426"},
204 | {file = "numpy-2.1.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1b8cde4f11f0a975d1fd59373b32e2f5a562ade7cde4f85b7137f3de8fbb29a0"},
205 | {file = "numpy-2.1.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d95f286b8244b3649b477ac066c6906fbb2905f8ac19b170e2175d3d799f4df"},
206 | {file = "numpy-2.1.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ab4754d432e3ac42d33a269c8567413bdb541689b02d93788af4131018cbf366"},
207 | {file = "numpy-2.1.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e585c8ae871fd38ac50598f4763d73ec5497b0de9a0ab4ef5b69f01c6a046142"},
208 | {file = "numpy-2.1.2-cp312-cp312-win32.whl", hash = "sha256:9c6c754df29ce6a89ed23afb25550d1c2d5fdb9901d9c67a16e0b16eaf7e2550"},
209 | {file = "numpy-2.1.2-cp312-cp312-win_amd64.whl", hash = "sha256:456e3b11cb79ac9946c822a56346ec80275eaf2950314b249b512896c0d2505e"},
210 | {file = "numpy-2.1.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a84498e0d0a1174f2b3ed769b67b656aa5460c92c9554039e11f20a05650f00d"},
211 | {file = "numpy-2.1.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:4d6ec0d4222e8ffdab1744da2560f07856421b367928026fb540e1945f2eeeaf"},
212 | {file = "numpy-2.1.2-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:259ec80d54999cc34cd1eb8ded513cb053c3bf4829152a2e00de2371bd406f5e"},
213 | {file = "numpy-2.1.2-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:675c741d4739af2dc20cd6c6a5c4b7355c728167845e3c6b0e824e4e5d36a6c3"},
214 | {file = "numpy-2.1.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:05b2d4e667895cc55e3ff2b56077e4c8a5604361fc21a042845ea3ad67465aa8"},
215 | {file = "numpy-2.1.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:43cca367bf94a14aca50b89e9bc2061683116cfe864e56740e083392f533ce7a"},
216 | {file = "numpy-2.1.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:76322dcdb16fccf2ac56f99048af32259dcc488d9b7e25b51e5eca5147a3fb98"},
217 | {file = "numpy-2.1.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:32e16a03138cabe0cb28e1007ee82264296ac0983714094380b408097a418cfe"},
218 | {file = "numpy-2.1.2-cp313-cp313-win32.whl", hash = "sha256:242b39d00e4944431a3cd2db2f5377e15b5785920421993770cddb89992c3f3a"},
219 | {file = "numpy-2.1.2-cp313-cp313-win_amd64.whl", hash = "sha256:f2ded8d9b6f68cc26f8425eda5d3877b47343e68ca23d0d0846f4d312ecaa445"},
220 | {file = "numpy-2.1.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:2ffef621c14ebb0188a8633348504a35c13680d6da93ab5cb86f4e54b7e922b5"},
221 | {file = "numpy-2.1.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:ad369ed238b1959dfbade9018a740fb9392c5ac4f9b5173f420bd4f37ba1f7a0"},
222 | {file = "numpy-2.1.2-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:d82075752f40c0ddf57e6e02673a17f6cb0f8eb3f587f63ca1eaab5594da5b17"},
223 | {file = "numpy-2.1.2-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:1600068c262af1ca9580a527d43dc9d959b0b1d8e56f8a05d830eea39b7c8af6"},
224 | {file = "numpy-2.1.2-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a26ae94658d3ba3781d5e103ac07a876b3e9b29db53f68ed7df432fd033358a8"},
225 | {file = "numpy-2.1.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:13311c2db4c5f7609b462bc0f43d3c465424d25c626d95040f073e30f7570e35"},
226 | {file = "numpy-2.1.2-cp313-cp313t-musllinux_1_1_x86_64.whl", hash = "sha256:2abbf905a0b568706391ec6fa15161fad0fb5d8b68d73c461b3c1bab6064dd62"},
227 | {file = "numpy-2.1.2-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:ef444c57d664d35cac4e18c298c47d7b504c66b17c2ea91312e979fcfbdfb08a"},
228 | {file = "numpy-2.1.2-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:bdd407c40483463898b84490770199d5714dcc9dd9b792f6c6caccc523c00952"},
229 | {file = "numpy-2.1.2-pp310-pypy310_pp73-macosx_14_0_x86_64.whl", hash = "sha256:da65fb46d4cbb75cb417cddf6ba5e7582eb7bb0b47db4b99c9fe5787ce5d91f5"},
230 | {file = "numpy-2.1.2-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1c193d0b0238638e6fc5f10f1b074a6993cb13b0b431f64079a509d63d3aa8b7"},
231 | {file = "numpy-2.1.2-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:a7d80b2e904faa63068ead63107189164ca443b42dd1930299e0d1cb041cec2e"},
232 | {file = "numpy-2.1.2.tar.gz", hash = "sha256:13532a088217fa624c99b843eeb54640de23b3414b14aa66d023805eb731066c"},
233 | ]
234 |
235 | [[package]]
236 | name = "opt-einsum"
237 | version = "3.4.0"
238 | description = "Path optimization of einsum functions."
239 | optional = false
240 | python-versions = ">=3.8"
241 | files = [
242 | {file = "opt_einsum-3.4.0-py3-none-any.whl", hash = "sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd"},
243 | {file = "opt_einsum-3.4.0.tar.gz", hash = "sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac"},
244 | ]
245 |
246 | [[package]]
247 | name = "scipy"
248 | version = "1.13.1"
249 | description = "Fundamental algorithms for scientific computing in Python"
250 | optional = false
251 | python-versions = ">=3.9"
252 | files = [
253 | {file = "scipy-1.13.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:20335853b85e9a49ff7572ab453794298bcf0354d8068c5f6775a0eabf350aca"},
254 | {file = "scipy-1.13.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:d605e9c23906d1994f55ace80e0125c587f96c020037ea6aa98d01b4bd2e222f"},
255 | {file = "scipy-1.13.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cfa31f1def5c819b19ecc3a8b52d28ffdcc7ed52bb20c9a7589669dd3c250989"},
256 | {file = "scipy-1.13.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f26264b282b9da0952a024ae34710c2aff7d27480ee91a2e82b7b7073c24722f"},
257 | {file = "scipy-1.13.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:eccfa1906eacc02de42d70ef4aecea45415f5be17e72b61bafcfd329bdc52e94"},
258 | {file = "scipy-1.13.1-cp310-cp310-win_amd64.whl", hash = "sha256:2831f0dc9c5ea9edd6e51e6e769b655f08ec6db6e2e10f86ef39bd32eb11da54"},
259 | {file = "scipy-1.13.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:27e52b09c0d3a1d5b63e1105f24177e544a222b43611aaf5bc44d4a0979e32f9"},
260 | {file = "scipy-1.13.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:54f430b00f0133e2224c3ba42b805bfd0086fe488835effa33fa291561932326"},
261 | {file = "scipy-1.13.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e89369d27f9e7b0884ae559a3a956e77c02114cc60a6058b4e5011572eea9299"},
262 | {file = "scipy-1.13.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a78b4b3345f1b6f68a763c6e25c0c9a23a9fd0f39f5f3d200efe8feda560a5fa"},
263 | {file = "scipy-1.13.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:45484bee6d65633752c490404513b9ef02475b4284c4cfab0ef946def50b3f59"},
264 | {file = "scipy-1.13.1-cp311-cp311-win_amd64.whl", hash = "sha256:5713f62f781eebd8d597eb3f88b8bf9274e79eeabf63afb4a737abc6c84ad37b"},
265 | {file = "scipy-1.13.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:5d72782f39716b2b3509cd7c33cdc08c96f2f4d2b06d51e52fb45a19ca0c86a1"},
266 | {file = "scipy-1.13.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:017367484ce5498445aade74b1d5ab377acdc65e27095155e448c88497755a5d"},
267 | {file = "scipy-1.13.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:949ae67db5fa78a86e8fa644b9a6b07252f449dcf74247108c50e1d20d2b4627"},
268 | {file = "scipy-1.13.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de3ade0e53bc1f21358aa74ff4830235d716211d7d077e340c7349bc3542e884"},
269 | {file = "scipy-1.13.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2ac65fb503dad64218c228e2dc2d0a0193f7904747db43014645ae139c8fad16"},
270 | {file = "scipy-1.13.1-cp312-cp312-win_amd64.whl", hash = "sha256:cdd7dacfb95fea358916410ec61bbc20440f7860333aee6d882bb8046264e949"},
271 | {file = "scipy-1.13.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:436bbb42a94a8aeef855d755ce5a465479c721e9d684de76bf61a62e7c2b81d5"},
272 | {file = "scipy-1.13.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:8335549ebbca860c52bf3d02f80784e91a004b71b059e3eea9678ba994796a24"},
273 | {file = "scipy-1.13.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d533654b7d221a6a97304ab63c41c96473ff04459e404b83275b60aa8f4b7004"},
274 | {file = "scipy-1.13.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:637e98dcf185ba7f8e663e122ebf908c4702420477ae52a04f9908707456ba4d"},
275 | {file = "scipy-1.13.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a014c2b3697bde71724244f63de2476925596c24285c7a637364761f8710891c"},
276 | {file = "scipy-1.13.1-cp39-cp39-win_amd64.whl", hash = "sha256:392e4ec766654852c25ebad4f64e4e584cf19820b980bc04960bca0b0cd6eaa2"},
277 | {file = "scipy-1.13.1.tar.gz", hash = "sha256:095a87a0312b08dfd6a6155cbbd310a8c51800fc931b8c0b84003014b874ed3c"},
278 | ]
279 |
280 | [package.dependencies]
281 | numpy = ">=1.22.4,<2.3"
282 |
283 | [package.extras]
284 | dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy", "pycodestyle", "pydevtool", "rich-click", "ruff", "types-psutil", "typing_extensions"]
285 | doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.12.0)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0)", "sphinx-design (>=0.4.0)"]
286 | test = ["array-api-strict", "asv", "gmpy2", "hypothesis (>=6.30)", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"]
287 |
288 | [[package]]
289 | name = "setuptools"
290 | version = "75.2.0"
291 | description = "Easily download, build, install, upgrade, and uninstall Python packages"
292 | optional = false
293 | python-versions = ">=3.8"
294 | files = [
295 | {file = "setuptools-75.2.0-py3-none-any.whl", hash = "sha256:a7fcb66f68b4d9e8e66b42f9876150a3371558f98fa32222ffaa5bced76406f8"},
296 | {file = "setuptools-75.2.0.tar.gz", hash = "sha256:753bb6ebf1f465a1912e19ed1d41f403a79173a9acf66a42e7e6aec45c3c16ec"},
297 | ]
298 |
299 | [package.extras]
300 | check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)", "ruff (>=0.5.2)"]
301 | core = ["importlib-metadata (>=6)", "importlib-resources (>=5.10.2)", "jaraco.collections", "jaraco.functools", "jaraco.text (>=3.7)", "more-itertools", "more-itertools (>=8.8)", "packaging", "packaging (>=24)", "platformdirs (>=2.6.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"]
302 | cover = ["pytest-cov"]
303 | doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"]
304 | enabler = ["pytest-enabler (>=2.2)"]
305 | test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"]
306 | type = ["importlib-metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (==1.11.*)", "pytest-mypy"]
307 |
308 | [[package]]
309 | name = "toolz"
310 | version = "1.0.0"
311 | description = "List processing tools and functional utilities"
312 | optional = false
313 | python-versions = ">=3.8"
314 | files = [
315 | {file = "toolz-1.0.0-py3-none-any.whl", hash = "sha256:292c8f1c4e7516bf9086f8850935c799a874039c8bcf959d47b600e4c44a6236"},
316 | {file = "toolz-1.0.0.tar.gz", hash = "sha256:2c86e3d9a04798ac556793bced838816296a2f085017664e4995cb40a1047a02"},
317 | ]
318 |
319 | [[package]]
320 | name = "triton"
321 | version = "3.0.0"
322 | description = "A language and compiler for custom Deep Learning operations"
323 | optional = false
324 | python-versions = "*"
325 | files = [
326 | {file = "triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e1efef76935b2febc365bfadf74bcb65a6f959a9872e5bddf44cc9e0adce1e1a"},
327 | {file = "triton-3.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5ce8520437c602fb633f1324cc3871c47bee3b67acf9756c1a66309b60e3216c"},
328 | {file = "triton-3.0.0-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:34e509deb77f1c067d8640725ef00c5cbfcb2052a1a3cb6a6d343841f92624eb"},
329 | {file = "triton-3.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bcbf3b1c48af6a28011a5c40a5b3b9b5330530c3827716b5fbf6d7adcc1e53e9"},
330 | {file = "triton-3.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6e5727202f7078c56f91ff13ad0c1abab14a0e7f2c87e91b12b6f64f3e8ae609"},
331 | ]
332 |
333 | [package.dependencies]
334 | filelock = "*"
335 |
336 | [package.extras]
337 | build = ["cmake (>=3.20)", "lit"]
338 | tests = ["autopep8", "flake8", "isort", "llnl-hatchet", "numpy", "pytest", "scipy (>=1.7.1)"]
339 | tutorials = ["matplotlib", "pandas", "tabulate"]
340 |
341 | [[package]]
342 | name = "typing-extensions"
343 | version = "4.12.2"
344 | description = "Backported and Experimental Type Hints for Python 3.8+"
345 | optional = false
346 | python-versions = ">=3.8"
347 | files = [
348 | {file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"},
349 | {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"},
350 | ]
351 |
352 | [metadata]
353 | lock-version = "2.0"
354 | python-versions = ">=3.10"
355 | content-hash = "3e16c751bf056781ad6406640f06b72c8d32e81104cb9bd20a300dd34aeb77a5"
356 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "jax-flash-attn2"
3 | version = "0.0.3"
4 | description = "Flash Attention Implementation with Multiple Backend Support and Sharding. This module provides a flexible implementation of Flash Attention with support for different backends (GPU, TPU, CPU) and platforms (Triton, Pallas, JAX)."
5 | authors = ["Erfan Zare Chavoshi "]
6 | license = "Apache-2.0"
7 | readme = "README.md"
8 | homepage = "https://github.com/erfanzar/jax-flash-attn2"
9 | repository = "https://github.com/erfanzar/jax-flash-attn2"
10 | documentation = "https://erfanzar.github.io/jax-flash-attn2"
11 | keywords = ["JAX", "Deep Learning", "Machine Learning", "XLA"]
12 | classifiers = [
13 | "Development Status :: 3 - Alpha",
14 | "Intended Audience :: Developers",
15 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
16 | "License :: OSI Approved :: Apache Software License",
17 | "Programming Language :: Python :: 3",
18 | "Programming Language :: Python :: 3.9",
19 | "Programming Language :: Python :: 3.10",
20 | "Programming Language :: Python :: 3.11",
21 | "Programming Language :: Python :: 3.12",
22 | ]
23 |
24 | [tool.poetry.dependencies]
25 | python = ">=3.10,<3.14"
26 | jax = ">=0.4.36"
27 | jaxlib = ">=0.4.36"
28 | eformer = "0.0.15"
29 | einops = "~0.8.0"
30 | triton = "~3.2.0"
31 |
32 | [tool.ruff.lint]
33 | select = ["E4", "E7", "E9", "F", "B"]
34 | ignore = ["E501", "B905", "B007", "E741"]
35 | unfixable = ["B"]
36 |
37 | [tool.ruff.lint.per-file-ignores]
38 | "__init__.py" = ["E402", "F401"]
39 | "**/{tests,docs,tools}/*" = ["E402"]
40 | "tests/*" = ["E402", "E731"]
41 | "triton_*" = ["E741", "ISC001", "E501", "E731"]
42 | "pallas_*" = ["E741", "ISC001", "E501", "E731"]
43 |
44 | [tool.ruff.format]
45 | quote-style = "double"
46 | indent-style = "tab"
47 | docstring-code-format = true
48 |
49 | [tool.ruff]
50 | target-version = "py311"
51 | line-length = 88
52 | indent-width = 2
53 |
--------------------------------------------------------------------------------