├── .github └── workflows │ └── pylint.yml ├── .gitignore ├── LICENSE ├── README.md ├── benchmarks ├── triton-vs-jax-sdpa-cudnn │ ├── README.md │ ├── b=1-ub=False-hdim=128-qh=16-kvh=8-mode=fwd.csv │ ├── b=1-ub=False-hdim=128-qh=16-kvh=8-mode=fwd.png │ ├── b=1-ub=False-hdim=128-qh=32-kvh=16-mode=fwd.csv │ ├── b=1-ub=False-hdim=128-qh=32-kvh=16-mode=fwd.png │ ├── b=1-ub=False-hdim=256-qh=16-kvh=8-mode=fwd.csv │ ├── b=1-ub=False-hdim=256-qh=16-kvh=8-mode=fwd.png │ ├── b=1-ub=False-hdim=256-qh=32-kvh=16-mode=fwd.csv │ ├── b=1-ub=False-hdim=256-qh=32-kvh=16-mode=fwd.png │ ├── b=1-ub=False-hdim=64-qh=16-kvh=8-mode=fwd.csv │ ├── b=1-ub=False-hdim=64-qh=16-kvh=8-mode=fwd.png │ ├── b=1-ub=False-hdim=64-qh=32-kvh=16-mode=fwd.csv │ ├── b=1-ub=False-hdim=64-qh=32-kvh=16-mode=fwd.png │ ├── b=1-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.csv │ ├── b=1-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.png │ ├── b=1-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.csv │ ├── b=1-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.png │ ├── b=1-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.csv │ ├── b=1-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.png │ ├── b=1-ub=True-hdim=256-qh=32-kvh=16-mode=fwd.csv │ ├── b=1-ub=True-hdim=256-qh=32-kvh=16-mode=fwd.png │ ├── b=1-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.csv │ ├── b=1-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.png │ ├── b=1-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.csv │ ├── b=1-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.png │ ├── b=2-ub=False-hdim=128-qh=16-kvh=8-mode=fwd.csv │ ├── b=2-ub=False-hdim=128-qh=16-kvh=8-mode=fwd.png │ ├── b=2-ub=False-hdim=128-qh=32-kvh=16-mode=fwd.csv │ ├── b=2-ub=False-hdim=128-qh=32-kvh=16-mode=fwd.png │ ├── b=2-ub=False-hdim=256-qh=16-kvh=8-mode=fwd.csv │ ├── b=2-ub=False-hdim=256-qh=16-kvh=8-mode=fwd.png │ ├── b=2-ub=False-hdim=256-qh=32-kvh=16-mode=fwd.csv │ ├── b=2-ub=False-hdim=256-qh=32-kvh=16-mode=fwd.png │ ├── b=2-ub=False-hdim=64-qh=16-kvh=8-mode=fwd.csv │ ├── b=2-ub=False-hdim=64-qh=16-kvh=8-mode=fwd.png │ ├── b=2-ub=False-hdim=64-qh=32-kvh=16-mode=fwd.csv │ ├── b=2-ub=False-hdim=64-qh=32-kvh=16-mode=fwd.png │ ├── b=2-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.csv │ ├── b=2-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.png │ ├── b=2-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.csv │ ├── b=2-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.png │ ├── b=2-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.csv │ ├── b=2-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.png │ ├── b=2-ub=True-hdim=256-qh=32-kvh=16-mode=fwd.csv │ ├── b=2-ub=True-hdim=256-qh=32-kvh=16-mode=fwd.png │ ├── b=2-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.csv │ ├── b=2-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.png │ ├── b=2-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.csv │ ├── b=2-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.png │ ├── b=4-ub=False-hdim=128-qh=16-kvh=8-mode=fwd.csv │ ├── b=4-ub=False-hdim=128-qh=16-kvh=8-mode=fwd.png │ ├── b=4-ub=False-hdim=128-qh=32-kvh=16-mode=fwd.csv │ ├── b=4-ub=False-hdim=128-qh=32-kvh=16-mode=fwd.png │ ├── b=4-ub=False-hdim=256-qh=16-kvh=8-mode=fwd.csv │ ├── b=4-ub=False-hdim=256-qh=16-kvh=8-mode=fwd.png │ ├── b=4-ub=False-hdim=256-qh=32-kvh=16-mode=fwd.csv │ ├── b=4-ub=False-hdim=256-qh=32-kvh=16-mode=fwd.png │ ├── b=4-ub=False-hdim=64-qh=16-kvh=8-mode=fwd.csv │ ├── b=4-ub=False-hdim=64-qh=16-kvh=8-mode=fwd.png │ ├── b=4-ub=False-hdim=64-qh=32-kvh=16-mode=fwd.csv │ ├── b=4-ub=False-hdim=64-qh=32-kvh=16-mode=fwd.png │ ├── b=4-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.csv │ ├── b=4-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.png │ ├── b=4-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.csv │ ├── b=4-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.png │ ├── 
b=4-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.csv │ ├── b=4-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.png │ ├── b=4-ub=True-hdim=256-qh=32-kvh=16-mode=fwd.csv │ ├── b=4-ub=True-hdim=256-qh=32-kvh=16-mode=fwd.png │ ├── b=4-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.csv │ ├── b=4-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.png │ ├── b=4-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.csv │ ├── b=4-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.png │ └── results.html └── triton-vs-jax-sdpa │ ├── README.md │ ├── b=1-ub=False-hdim=128-qh=16-kvh=8-mode=fwd.csv │ ├── b=1-ub=False-hdim=128-qh=16-kvh=8-mode=fwd.png │ ├── b=1-ub=False-hdim=128-qh=32-kvh=16-mode=fwd.csv │ ├── b=1-ub=False-hdim=128-qh=32-kvh=16-mode=fwd.png │ ├── b=1-ub=False-hdim=256-qh=16-kvh=8-mode=fwd.csv │ ├── b=1-ub=False-hdim=256-qh=16-kvh=8-mode=fwd.png │ ├── b=1-ub=False-hdim=256-qh=32-kvh=16-mode=fwd.csv │ ├── b=1-ub=False-hdim=256-qh=32-kvh=16-mode=fwd.png │ ├── b=1-ub=False-hdim=64-qh=16-kvh=8-mode=fwd.csv │ ├── b=1-ub=False-hdim=64-qh=16-kvh=8-mode=fwd.png │ ├── b=1-ub=False-hdim=64-qh=32-kvh=16-mode=fwd.csv │ ├── b=1-ub=False-hdim=64-qh=32-kvh=16-mode=fwd.png │ ├── b=1-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.csv │ ├── b=1-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.png │ ├── b=1-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.csv │ ├── b=1-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.png │ ├── b=1-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.csv │ ├── b=1-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.png │ ├── b=1-ub=True-hdim=256-qh=32-kvh=16-mode=fwd.csv │ ├── b=1-ub=True-hdim=256-qh=32-kvh=16-mode=fwd.png │ ├── b=1-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.csv │ ├── b=1-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.png │ ├── b=1-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.csv │ ├── b=1-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.png │ ├── b=2-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.csv │ ├── b=2-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.png │ ├── b=2-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.csv │ ├── b=2-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.png │ ├── b=2-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.csv │ ├── b=2-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.png │ ├── b=2-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.csv │ ├── b=2-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.png │ ├── b=2-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.csv │ ├── b=2-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.png │ └── results.html ├── jax_flash_attn2 ├── __init__.py ├── flash_attention.py ├── flash_attention_jax │ ├── __init__.py │ ├── _backward_jax.py │ ├── _flash_attention.py │ └── _forward_jax.py ├── flash_attention_triton │ ├── __init__.py │ ├── _backward_triton.py │ ├── _flash_attention.py │ ├── _forward_triton.py │ └── _utils.py ├── refrence_call.py └── utils.py ├── poetry.lock └── pyproject.toml /.github/workflows/pylint.yml: -------------------------------------------------------------------------------- 1 | name: Pylint 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build: 7 | runs-on: ubuntu-latest 8 | strategy: 9 | matrix: 10 | python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] 11 | steps: 12 | - uses: actions/checkout@v4 13 | - name: Set up Python ${{ matrix.python-version }} 14 | uses: actions/setup-python@v3 15 | with: 16 | python-version: ${{ matrix.python-version }} 17 | - name: Install dependencies 18 | run: | 19 | python -m pip install --upgrade pip 20 | pip install pylint 21 | - name: Analysing the code with pylint 22 | run: | 23 | pylint $(git ls-files '*.py') 24 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL 
files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | env.py 9 | exp.py 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # poetry 99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 100 | # This is especially recommended for binary packages to ensure reproducibility, and is more 101 | # commonly ignored for libraries. 102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 103 | #poetry.lock 104 | 105 | # pdm 106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 107 | #pdm.lock 108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 109 | # in version control. 110 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 111 | .pdm.toml 112 | .pdm-python 113 | .pdm-build/ 114 | 115 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 116 | __pypackages__/ 117 | 118 | # Celery stuff 119 | celerybeat-schedule 120 | celerybeat.pid 121 | 122 | # SageMath parsed files 123 | *.sage.py 124 | 125 | # Environments 126 | .env 127 | .venv 128 | env/ 129 | venv/ 130 | ENV/ 131 | env.bak/ 132 | venv.bak/ 133 | 134 | # Spyder project settings 135 | .spyderproject 136 | .spyproject 137 | 138 | # Rope project settings 139 | .ropeproject 140 | 141 | # mkdocs documentation 142 | /site 143 | 144 | # mypy 145 | .mypy_cache/ 146 | .dmypy.json 147 | dmypy.json 148 | 149 | # Pyre type checker 150 | .pyre/ 151 | 152 | # pytype static type analyzer 153 | .pytype/ 154 | 155 | # Cython debug symbols 156 | cython_debug/ 157 | 158 | # PyCharm 159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 161 | # and can be added to the global gitignore or merged into this file. For a more nuclear 162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 163 | #.idea/ 164 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # JAX-Flash-Attention2 2 | 3 | A flexible and efficient implementation of Flash Attention 2.0 for JAX, supporting multiple backends (GPU/TPU/CPU) and platforms (Triton/Pallas/JAX). 
4 | 5 | ## Installation 6 | 7 | ```bash 8 | pip install jax-flash-attn2 9 | ``` 10 | 11 | ## Basic Usage 12 | 13 | ```python 14 | import jax 15 | import jax.numpy as jnp 16 | import jax_flash_attn2 as jfa 17 | 18 | # Initialize the FlashAttention module with the desired configuration 19 | flash_attention = jfa.FlashAttention( 20 | jfa.AttentionConfig( 21 | platform=jfa.Platform.TRITON, # Options: TRITON, PALLAS, JAX 22 | backend=jfa.Backend.GPU, # Options: GPU, TPU, CPU 23 | ) 24 | ) 25 | 26 | # Create sample inputs (grouped-query attention: the query has 4x more heads than key/value) 27 | batch_size, num_heads, seq_len, head_dim = 2, 4, 512, 64 28 | query = jax.random.normal(jax.random.PRNGKey(0), (batch_size, num_heads * 4, seq_len, head_dim), jnp.float16) 29 | key = jax.random.normal(jax.random.PRNGKey(1), (batch_size, num_heads, seq_len, head_dim), jnp.float16) 30 | value = jax.random.normal(jax.random.PRNGKey(2), (batch_size, num_heads, seq_len, head_dim), jnp.float16) 31 | 32 | # Compute attention 33 | output = flash_attention( 34 | query=query, 35 | key=key, 36 | value=value, 37 | causal=True # Enable causal masking for decoder-only models 38 | ) 39 | 40 | # output shape matches the query: (batch_size, num_heads * 4, seq_len, head_dim) 41 | ``` 42 | 43 | ## Advanced Usage 44 | 45 | ### With Attention Mask 46 | 47 | ```python 48 | # Create an attention mask (1 = attend, 0 = mask) 49 | attention_mask = jnp.ones((batch_size, 1, seq_len, seq_len)) # Allow full attention 50 | # For example, mask the first 100 tokens from attending to the last 100 tokens 51 | attention_mask = attention_mask.at[:, :, :100, -100:].set(0) 52 | 53 | output = flash_attention( 54 | query=query, 55 | key=key, 56 | value=value, 57 | attention_mask=attention_mask, 58 | causal=False # Using explicit mask instead of causal 59 | ) 60 | ``` 61 | 62 | ### With Attention Bias 63 | 64 | ```python 65 | # Create an attention bias 66 | bias = jnp.zeros((batch_size, 1, seq_len, seq_len)) 67 | # Add a position-dependent bias that decays with the distance |i - j| 68 | positions = jnp.arange(seq_len) 69 | distance = jnp.abs(positions[:, None] - positions[None, :]) 70 | bias = bias + 1.0 / (1.0 + distance) 71 | 72 | output = flash_attention( 73 | query=query, 74 | key=key, 75 | value=value, 76 | bias=bias 77 | ) 78 | ``` 79 | 80 | ### With Dropout 81 | 82 | ```python 83 | output = flash_attention( 84 | query=query, 85 | key=key, 86 | value=value, 87 | dropout_prob=0.1, 88 | dropout_seed=42, 89 | causal=True 90 | ) 91 | ``` 92 | 93 | ## Flax Modules with JFA2 94 | 95 | Here's an example of integrating jax-flash-attn2 within a Transformer attention block implemented with Flax NNX: 96 | 97 | ```python 98 | import typing as tp 99 | from functools import partial 100 | 101 | import chex 102 | import flax.nnx as nn 103 | import jax 104 | import jax.numpy as jnp 105 | 106 | import jax_flash_attn2 as jfa 107 | 108 | 109 | class JFAttention2(nn.Module): 110 | def __init__( 111 | self, 112 | hidden_size: int, 113 | head_dim: int, 114 | num_attention_heads: int, 115 | num_key_value_heads: int, 116 | dtype: jnp.dtype = jnp.float32, 117 | param_dtype: jnp.dtype = jnp.float32, 118 | precision: jax.lax.PrecisionLike = None, 119 | *, 120 | rngs: nn.Rngs = None, 121 | ): 122 | if rngs is None: 123 | rngs = nn.Rngs(0) 124 | self.dtype = dtype 125 | self.param_dtype = param_dtype 126 | self.precision = precision 127 | self.rngs = rngs 128 | 129 | self.hidden_size = hidden_size 130 | self.head_dim = head_dim 131 | self.num_attention_heads = num_attention_heads 132 | self.num_key_value_heads = num_key_value_heads 133 | 134 | self.num_key_value_groups = num_attention_heads // num_key_value_heads 135 | 136 | if self.num_key_value_groups == 1:
137 | assert num_attention_heads == num_key_value_heads 138 | 139 | linear_class = partial( 140 | nn.Linear, 141 | dtype=dtype, 142 | param_dtype=param_dtype, 143 | use_bias=False, 144 | kernel_init=jax.nn.initializers.normal(0.02), 145 | precision=precision, 146 | rngs=rngs, 147 | ) 148 | self.q_proj = linear_class(hidden_size, num_attention_heads * self.head_dim) 149 | self.k_proj = linear_class(hidden_size, num_key_value_heads * self.head_dim) 150 | self.v_proj = linear_class(hidden_size, num_key_value_heads * self.head_dim) 151 | self.o_proj = linear_class(num_attention_heads * self.head_dim, hidden_size) 152 | 153 | config = jfa.AttentionConfig(platform=jfa.Platform.TRITON, backend=jfa.Backend.GPU) 154 | 155 | self.jfa2 = jfa.FlashAttention(config) 156 | 157 | def __call__( 158 | self, 159 | hidden_states: chex.Array, 160 | attention_mask: chex.Array, 161 | causal: bool = True, 162 | ) -> chex.Array: 163 | batch_size, sequence_length = hidden_states.shape[:2] 164 | query_states, key_states, value_states = ( 165 | self.q_proj(hidden_states), 166 | self.k_proj(hidden_states), 167 | self.v_proj(hidden_states), 168 | ) 169 | qshape = ( 170 | batch_size, 171 | sequence_length, 172 | self.num_attention_heads, 173 | self.head_dim, 174 | ) 175 | kv_shape = ( 176 | batch_size, 177 | sequence_length, 178 | self.num_key_value_heads, 179 | self.head_dim, 180 | ) 181 | query_states = query_states.reshape(qshape) 182 | key_states = key_states.reshape(kv_shape) 183 | value_states = value_states.reshape(kv_shape) 184 | attn_output = self.jfa2.forward( 185 | query_states.astype(jnp.bfloat16), 186 | key_states.astype(jnp.bfloat16), 187 | value_states.astype(jnp.bfloat16), 188 | jnp.where(attention_mask, 0, jnp.finfo(query_states.dtype).min).astype(jnp.bfloat16), 189 | causal=causal, 190 | ) 191 | attn_output = jnp.reshape(attn_output, (batch_size, sequence_length, -1)) 192 | attn_output = self.o_proj(attn_output) 193 | return attn_output 194 | ``` 195 | 196 | ## Platform-Specific Examples 197 | 198 | ### Using JAX Backend 199 | 200 | ```python 201 | jax_flash_attn = jfa.FlashAttention( 202 | jfa.AttentionConfig( 203 | platform=jfa.Platform.JAX, 204 | backend=jfa.Backend.CPU, # Works on any hardware 205 | ) 206 | ) 207 | 208 | output = jax_flash_attn(query, key, value) 209 | ``` 210 | 211 | ### Using Pallas for TPU 212 | 213 | ```python 214 | tpu_flash_attn = jfa.FlashAttention( 215 | jfa.AttentionConfig( 216 | platform=jfa.Platform.PALLAS, 217 | backend=jfa.Backend.TPU, 218 | ) 219 | ) 220 | 221 | output = tpu_flash_attn(query, key, value) 222 | ``` 223 | 224 | ## Integration with JAX Transformations 225 | 226 | ```python 227 | @jax.jit 228 | def attention_forward(q, k, v, mask=None): 229 | return flash_attention( 230 | query=q, 231 | key=k, 232 | value=v, 233 | attention_mask=mask, 234 | causal=True 235 | ) 236 | 237 | # Run the JIT-compiled function (returns the attention output) 238 | fast_attention = attention_forward(query, key, value) 239 | 240 | # With gradient computation 241 | def loss_fn(q, k, v): 242 | attn_output = flash_attention(q, k, v, causal=True) 243 | return jnp.mean(attn_output) 244 | 245 | grads = jax.grad(loss_fn)(query, key, value) 246 | ``` 247 | 248 | ## Limitations 249 | 250 | - The Triton platform is only available on NVIDIA GPUs. 251 | - Some platform-backend combinations are not supported (for example, Triton requires the GPU backend). 252 | - Custom attention masks are not yet supported (convert them to an additive bias instead; see the sketch below). 253 | 254 | ## Contributing 255 | 256 | Contributions are welcome! Please feel free to submit a Pull Request.
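As a practical note on the mask limitation above: an explicit 0/1 attention mask can be folded into an additive bias, mirroring what the Flax example does with `jnp.where`. The helper below is only an illustrative sketch (it is not part of the jax-flash-attn2 API), assuming a mask of shape `(batch, 1, q_len, kv_len)` where 1 means "attend":

```python
import jax.numpy as jnp


def mask_to_bias(attention_mask, dtype=jnp.bfloat16):
    # Keep attended positions at 0.0 and push masked positions to a large
    # negative value so they vanish in the softmax inside the kernel.
    return jnp.where(attention_mask, 0.0, jnp.finfo(dtype).min).astype(dtype)


# Hypothetical usage with the mask built in the "With Attention Mask" example:
# output = flash_attention(query=query, key=key, value=value, bias=mask_to_bias(attention_mask))
```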
257 | 258 | ## Citation 259 | 260 | If you use this implementation in your research, please cite: 261 | 262 | ```bibtex 263 | @software{jax_flash_attn2, 264 | title = {JAX Flash Attention 2.0}, 265 | year = {2024}, 266 | url = {https://github.com/erfanzar/jax-flash-attn2} 267 | } 268 | ``` 269 | 270 | ### Reference citations 271 | 272 | ```bibtex 273 | @inproceedings{dao2022flashattention, 274 | title={Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness}, 275 | author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher}, 276 | booktitle={Advances in Neural Information Processing Systems (NeurIPS)}, 277 | year={2022} 278 | } 279 | @inproceedings{dao2023flashattention2, 280 | title={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning}, 281 | author={Dao, Tri}, 282 | booktitle={International Conference on Learning Representations (ICLR)}, 283 | year={2024} 284 | } 285 | ``` 286 | 287 | ## Acknowledgments and References 288 | 289 | 1. All of the kernels are copied from [`EasyDeL`](https://github.com/erfanzar/Easydel) 290 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=False-hdim=128-qh=16-kvh=8-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.109986,0.260454 3 | 2048.000000,0.277971,0.602729 4 | 4096.000000,0.539475,2.037018 5 | 6144.000000,0.657339,4.449059 6 | 8192.000000,0.536271,7.733691 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=False-hdim=128-qh=16-kvh=8-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=False-hdim=128-qh=16-kvh=8-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=False-hdim=128-qh=32-kvh=16-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.201889,0.316081 3 | 2048.000000,0.277653,1.052060 4 | 4096.000000,0.558586,3.973465 5 | 6144.000000,0.312637,8.737595 6 | 8192.000000,0.001152,15.191124 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=False-hdim=128-qh=32-kvh=16-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=False-hdim=128-qh=32-kvh=16-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=False-hdim=256-qh=16-kvh=8-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 |
1024.000000,0.293918,300.000000 3 | 2048.000000,0.564013,300.000000 4 | 4096.000000,0.581458,300.000000 5 | 6144.000000,0.001311,300.000000 6 | 8192.000000,0.001536,300.000000 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=False-hdim=256-qh=16-kvh=8-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=False-hdim=256-qh=16-kvh=8-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=False-hdim=256-qh=32-kvh=16-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.299787,300.000000 3 | 2048.000000,0.549313,300.000000 4 | 4096.000000,0.061952,300.000000 5 | 6144.000000,0.001654,300.000000 6 | 8192.000000,0.001317,300.000000 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=False-hdim=256-qh=32-kvh=16-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=False-hdim=256-qh=32-kvh=16-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=False-hdim=64-qh=16-kvh=8-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.001285,0.052782 3 | 2048.000000,0.145014,0.378316 4 | 4096.000000,0.233479,1.062174 5 | 6144.000000,0.366142,2.297836 6 | 8192.000000,0.392203,3.882547 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=False-hdim=64-qh=16-kvh=8-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=False-hdim=64-qh=16-kvh=8-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=False-hdim=64-qh=32-kvh=16-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.001382,0.184686 3 | 2048.000000,0.179171,0.601632 4 | 4096.000000,0.333080,1.976510 5 | 6144.000000,0.461126,4.308108 6 | 8192.000000,0.357641,7.543623 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=False-hdim=64-qh=32-kvh=16-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=False-hdim=64-qh=32-kvh=16-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 
1024.000000,0.154736,0.244154 3 | 2048.000000,0.337517,0.651588 4 | 4096.000000,0.561307,2.177832 5 | 6144.000000,0.744631,5.053024 6 | 8192.000000,0.475687,8.623260 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.236990,0.358192 3 | 2048.000000,0.419905,1.142686 4 | 4096.000000,0.569661,4.441965 5 | 6144.000000,0.263899,9.707267 6 | 8192.000000,0.001587,17.073843 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.367849,300.000000 3 | 2048.000000,0.695647,300.000000 4 | 4096.000000,0.379563,300.000000 5 | 6144.000000,0.001536,300.000000 6 | 8192.000000,0.001479,300.000000 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=True-hdim=256-qh=32-kvh=16-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.354809,300.000000 3 | 2048.000000,0.535335,300.000000 4 | 4096.000000,0.001536,300.000000 5 | 6144.000000,0.001152,300.000000 6 | 8192.000000,0.002048,300.000000 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=True-hdim=256-qh=32-kvh=16-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=True-hdim=256-qh=32-kvh=16-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 
1024.000000,0.009526,0.232182 3 | 2048.000000,0.256336,0.574287 4 | 4096.000000,0.443419,2.048195 5 | 6144.000000,0.787608,4.677517 6 | 8192.000000,0.911081,11.655725 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.146767,0.336205 3 | 2048.000000,0.258379,1.036054 4 | 4096.000000,0.530867,3.775501 5 | 6144.000000,0.605855,11.867448 6 | 8192.000000,0.293947,24.475136 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=1-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=False-hdim=128-qh=16-kvh=8-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.199049,0.342512 3 | 2048.000000,0.302052,1.057982 4 | 4096.000000,0.462467,4.031027 5 | 6144.000000,0.339456,8.867299 6 | 8192.000000,0.001707,15.418736 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=False-hdim=128-qh=16-kvh=8-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=False-hdim=128-qh=16-kvh=8-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=False-hdim=128-qh=32-kvh=16-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.203239,0.592487 3 | 2048.000000,0.401872,2.120210 4 | 4096.000000,0.001365,7.893704 5 | 6144.000000,0.001463,17.337357 6 | 8192.000000,0.001451,30.725855 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=False-hdim=128-qh=32-kvh=16-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=False-hdim=128-qh=32-kvh=16-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=False-hdim=256-qh=16-kvh=8-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.305674,300.000000 3 | 
2048.000000,0.442985,300.000000 4 | 4096.000000,0.061477,300.000000 5 | 6144.000000,0.001339,300.000000 6 | 8192.000000,0.001609,300.000000 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=False-hdim=256-qh=16-kvh=8-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=False-hdim=256-qh=16-kvh=8-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=False-hdim=256-qh=32-kvh=16-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.372295,300.000000 3 | 2048.000000,0.363640,300.000000 4 | 4096.000000,0.001463,300.000000 5 | 6144.000000,0.002219,300.000000 6 | 8192.000000,0.001707,300.000000 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=False-hdim=256-qh=32-kvh=16-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=False-hdim=256-qh=32-kvh=16-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=False-hdim=64-qh=16-kvh=8-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.001341,0.119918 3 | 2048.000000,0.190682,0.616036 4 | 4096.000000,0.307017,2.008788 5 | 6144.000000,0.472064,4.379703 6 | 8192.000000,0.386290,7.646230 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=False-hdim=64-qh=16-kvh=8-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=False-hdim=64-qh=16-kvh=8-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=False-hdim=64-qh=32-kvh=16-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.001479,0.268115 3 | 2048.000000,0.189668,1.023319 4 | 4096.000000,0.170710,3.953273 5 | 6144.000000,0.230988,8.635345 6 | 8192.000000,0.001463,15.048271 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=False-hdim=64-qh=32-kvh=16-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=False-hdim=64-qh=32-kvh=16-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.238014,0.355650 3 | 
2048.000000,0.426852,1.160171 4 | 4096.000000,0.623411,4.461117 5 | 6144.000000,0.262114,9.963975 6 | 8192.000000,0.001331,17.273287 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.253855,0.626204 3 | 2048.000000,0.489057,2.258021 4 | 4096.000000,0.237148,8.760285 5 | 6144.000000,0.001265,19.611481 6 | 8192.000000,0.001741,34.416176 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.364506,300.000000 3 | 2048.000000,0.547446,300.000000 4 | 4096.000000,0.001536,300.000000 5 | 6144.000000,0.002176,300.000000 6 | 8192.000000,0.001536,300.000000 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=True-hdim=256-qh=32-kvh=16-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.426923,300.000000 3 | 2048.000000,0.206000,300.000000 4 | 4096.000000,0.001479,300.000000 5 | 6144.000000,0.002048,300.000000 6 | 8192.000000,0.001536,300.000000 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=True-hdim=256-qh=32-kvh=16-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=True-hdim=256-qh=32-kvh=16-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.181881,0.348315 3 | 
2048.000000,0.248074,1.040480 4 | 4096.000000,0.596546,3.865138 5 | 6144.000000,0.630855,12.445517 6 | 8192.000000,0.288358,24.834625 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.208093,0.579161 3 | 2048.000000,0.392836,1.864016 4 | 4096.000000,0.479936,9.628813 5 | 6144.000000,0.114556,27.601694 6 | 8192.000000,0.001506,51.775490 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=2-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=False-hdim=128-qh=16-kvh=8-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.224760,0.575211 3 | 2048.000000,0.371112,2.018634 4 | 4096.000000,0.001536,5.240088 5 | 6144.000000,0.001268,18.134644 6 | 8192.000000,0.001536,32.440353 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=False-hdim=128-qh=16-kvh=8-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=False-hdim=128-qh=16-kvh=8-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=False-hdim=128-qh=32-kvh=16-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.182035,0.957698 3 | 2048.000000,0.024439,2.797037 4 | 4096.000000,0.001769,15.568085 5 | 6144.000000,0.001434,34.589699 6 | 8192.000000,0.001365,61.821152 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=False-hdim=128-qh=32-kvh=16-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=False-hdim=128-qh=32-kvh=16-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=False-hdim=256-qh=16-kvh=8-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.344924,300.000000 3 | 2048.000000,0.330334,300.000000 
4 | 4096.000000,0.001497,300.000000 5 | 6144.000000,0.001877,300.000000 6 | 8192.000000,0.001365,300.000000 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=False-hdim=256-qh=16-kvh=8-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=False-hdim=256-qh=16-kvh=8-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=False-hdim=256-qh=32-kvh=16-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.316557,300.000000 3 | 2048.000000,0.001516,300.000000 4 | 4096.000000,0.001536,300.000000 5 | 6144.000000,0.001707,300.000000 6 | 8192.000000,0.002048,300.000000 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=False-hdim=256-qh=32-kvh=16-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=False-hdim=256-qh=32-kvh=16-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=False-hdim=64-qh=16-kvh=8-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.001392,0.309596 3 | 2048.000000,0.193620,1.052123 4 | 4096.000000,0.170950,3.937541 5 | 6144.000000,0.208427,8.745746 6 | 8192.000000,0.001682,15.822118 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=False-hdim=64-qh=16-kvh=8-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=False-hdim=64-qh=16-kvh=8-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=False-hdim=64-qh=32-kvh=16-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.001500,0.512446 3 | 2048.000000,0.157459,2.044994 4 | 4096.000000,0.001593,7.709981 5 | 6144.000000,0.001516,16.856882 6 | 8192.000000,0.001575,30.138208 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=False-hdim=64-qh=32-kvh=16-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=False-hdim=64-qh=32-kvh=16-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.255597,0.589090 3 | 2048.000000,0.440020,2.224123 4 | 
4096.000000,0.122880,7.861397 5 | 6144.000000,0.001408,19.485580 6 | 8192.000000,0.001331,34.835968 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.267349,0.912527 3 | 2048.000000,0.251509,3.839468 4 | 4096.000000,0.001325,14.249164 5 | 6144.000000,0.001536,40.990719 6 | 8192.000000,0.001536,72.704002 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.374650,300.000000 3 | 2048.000000,0.082349,300.000000 4 | 4096.000000,0.001792,300.000000 5 | 6144.000000,0.001024,300.000000 6 | 8192.000000,0.002560,300.000000 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=True-hdim=256-qh=32-kvh=16-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.358098,300.000000 3 | 2048.000000,0.001506,300.000000 4 | 4096.000000,0.001280,300.000000 5 | 6144.000000,0.002048,300.000000 6 | 8192.000000,0.001024,300.000000 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=True-hdim=256-qh=32-kvh=16-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=True-hdim=256-qh=32-kvh=16-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.213456,0.571713 3 | 2048.000000,0.361903,1.789920 4 | 
4096.000000,0.499891,9.623846 5 | 6144.000000,0.117496,27.389696 6 | 8192.000000,0.001792,51.911678 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.255275,0.832885 3 | 2048.000000,0.391979,3.346892 4 | 4096.000000,0.153389,21.636608 5 | 6144.000000,0.001707,57.332737 6 | 8192.000000,0.001792,105.936035 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa-cudnn/b=4-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa-cudnn/results.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa/README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa/b=1-ub=False-hdim=128-qh=16-kvh=8-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.106995,0.689995 3 | 2048.000000,0.275410,2.995743 4 | 4096.000000,0.520256,11.519552 5 | 6144.000000,0.616047,26.424032 6 | 8192.000000,0.524092,67.017731 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa/b=1-ub=False-hdim=128-qh=16-kvh=8-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa/b=1-ub=False-hdim=128-qh=16-kvh=8-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa/b=1-ub=False-hdim=128-qh=32-kvh=16-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.195825,1.440105 3 | 2048.000000,0.323976,5.214518 4 | 4096.000000,0.500160,23.082130 5 | 6144.000000,0.331045,54.701153 6 | 8192.000000,0.001792,136.058044 7 | -------------------------------------------------------------------------------- 
/benchmarks/triton-vs-jax-sdpa/b=1-ub=False-hdim=128-qh=32-kvh=16-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa/b=1-ub=False-hdim=128-qh=32-kvh=16-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa/b=1-ub=False-hdim=256-qh=16-kvh=8-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.302997,0.836519 3 | 2048.000000,0.544905,3.151774 4 | 4096.000000,0.574020,13.211268 5 | 6144.000000,0.001352,30.355797 6 | 8192.000000,0.001434,74.565636 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa/b=1-ub=False-hdim=256-qh=16-kvh=8-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa/b=1-ub=False-hdim=256-qh=16-kvh=8-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa/b=1-ub=False-hdim=256-qh=32-kvh=16-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.366214,1.688182 3 | 2048.000000,0.452540,6.120752 4 | 4096.000000,0.051639,25.717834 5 | 6144.000000,0.001339,67.074051 6 | 8192.000000,0.001317,146.715714 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa/b=1-ub=False-hdim=256-qh=32-kvh=16-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa/b=1-ub=False-hdim=256-qh=32-kvh=16-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa/b=1-ub=False-hdim=64-qh=16-kvh=8-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.001316,0.670596 3 | 2048.000000,0.142873,2.497885 4 | 4096.000000,0.234355,10.730381 5 | 6144.000000,0.324767,25.309675 6 | 8192.000000,0.424600,66.269180 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa/b=1-ub=False-hdim=64-qh=16-kvh=8-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa/b=1-ub=False-hdim=64-qh=16-kvh=8-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa/b=1-ub=False-hdim=64-qh=32-kvh=16-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.001388,1.238941 3 | 2048.000000,0.182871,4.753040 4 | 4096.000000,0.278805,21.524223 5 | 6144.000000,0.396571,52.595871 6 | 8192.000000,0.326004,132.046021 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa/b=1-ub=False-hdim=64-qh=32-kvh=16-mode=fwd.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa/b=1-ub=False-hdim=64-qh=32-kvh=16-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa/b=1-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.170421,1.009647 3 | 2048.000000,0.340471,3.215832 4 | 4096.000000,0.616194,13.106743 5 | 6144.000000,0.728552,30.723488 6 | 8192.000000,0.448755,72.886269 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa/b=1-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa/b=1-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa/b=1-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.246250,2.027551 3 | 2048.000000,0.390695,5.809258 4 | 4096.000000,0.588857,25.905151 5 | 6144.000000,0.246381,61.066338 6 | 8192.000000,0.001229,144.867325 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa/b=1-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa/b=1-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa/b=1-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.331402,1.110649 3 | 2048.000000,0.558512,3.770115 4 | 4096.000000,0.388238,14.670970 5 | 6144.000000,0.001472,34.339840 6 | 8192.000000,0.001820,79.478050 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa/b=1-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa/b=1-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa/b=1-ub=True-hdim=256-qh=32-kvh=16-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.438935,2.225673 3 | 2048.000000,0.560500,6.704792 4 | 4096.000000,0.001479,29.120886 5 | 6144.000000,0.001152,69.249084 6 | 8192.000000,0.001536,158.288132 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa/b=1-ub=True-hdim=256-qh=32-kvh=16-mode=fwd.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa/b=1-ub=True-hdim=256-qh=32-kvh=16-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa/b=1-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.001379,0.948490 3 | 2048.000000,0.253063,2.807466 4 | 4096.000000,0.450294,12.528571 5 | 6144.000000,0.832924,29.624063 6 | 8192.000000,0.837758,71.584770 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa/b=1-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa/b=1-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa/b=1-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.152944,1.764381 3 | 2048.000000,0.248449,5.496425 4 | 4096.000000,0.613376,24.820086 5 | 6144.000000,0.583715,60.615711 6 | 8192.000000,0.265728,141.684006 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa/b=1-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa/b=1-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa/b=2-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.236832,1.944196 3 | 2048.000000,0.454954,5.739056 4 | 4096.000000,0.604304,25.787445 5 | 6144.000000,0.281980,61.076481 6 | 8192.000000,0.001434,145.263779 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa/b=2-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa/b=2-ub=True-hdim=128-qh=16-kvh=8-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa/b=2-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.259367,3.784511 3 | 2048.000000,0.483416,11.617672 4 | 4096.000000,0.215821,51.328354 5 | 6144.000000,0.001877,100.000000 6 | 8192.000000,0.001536,100.000000 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa/b=2-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa/b=2-ub=True-hdim=128-qh=32-kvh=16-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa/b=2-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.395234,2.273197 3 | 2048.000000,0.564041,7.182823 4 | 4096.000000,0.001252,29.376169 5 | 6144.000000,0.001536,68.175011 6 | 8192.000000,0.001280,157.184326 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa/b=2-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa/b=2-ub=True-hdim=256-qh=16-kvh=8-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa/b=2-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.147088,1.814945 3 | 2048.000000,0.253838,5.710850 4 | 4096.000000,0.616780,24.963852 5 | 6144.000000,0.614312,60.137470 6 | 8192.000000,0.295790,143.506439 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa/b=2-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa/b=2-ub=True-hdim=64-qh=16-kvh=8-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa/b=2-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.csv: -------------------------------------------------------------------------------- 1 | seqlen,Triton,Jax-cudnn 2 | 1024.000000,0.213497,3.775728 3 | 2048.000000,0.390579,12.141111 4 | 4096.000000,0.481676,49.568768 5 | 6144.000000,0.090283,100.000000 6 | 8192.000000,0.001506,100.000000 7 | -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa/b=2-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa/b=2-ub=True-hdim=64-qh=32-kvh=16-mode=fwd.png -------------------------------------------------------------------------------- /benchmarks/triton-vs-jax-sdpa/results.html: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erfanzar/jax-flash-attn2/af278984a44b2bd9b00b9962926e978fa4f63d00/benchmarks/triton-vs-jax-sdpa/results.html -------------------------------------------------------------------------------- /jax_flash_attn2/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi). 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .flash_attention import ( 16 | AttentionConfig, 17 | Backend, 18 | FlashAttention, 19 | Platform, 20 | create_flash_attention, 21 | ) 22 | from .flash_attention_jax import jax_flash_attention 23 | from .flash_attention_triton import triton_flash_attention 24 | from .refrence_call import basic_attention_refrence 25 | 26 | __all__ = ( 27 | "AttentionConfig", 28 | "Backend", 29 | "FlashAttention", 30 | "Platform", 31 | "create_flash_attention", 32 | "triton_flash_attention", 33 | "jax_flash_attention", 34 | "basic_attention_refrence", 35 | ) 36 | 37 | __version__ = "0.0.3" 38 | -------------------------------------------------------------------------------- /jax_flash_attn2/flash_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi). 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import os 17 | import typing as tp 18 | import warnings 19 | from dataclasses import dataclass 20 | from enum import Enum 21 | from functools import partial 22 | 23 | import chex 24 | import einops 25 | import jax 26 | import jax.numpy as jnp 27 | 28 | # fmt:off 29 | from jax.experimental.pallas.ops.tpu.flash_attention import BlockSizes as TPUBlockSizes 30 | from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention as pallas_flash_attention_tpu 31 | from jax.extend.backend import get_backend 32 | # fmt:on 33 | 34 | from .flash_attention_jax import jax_flash_attention 35 | from .flash_attention_triton import triton_flash_attention 36 | 37 | AVAILABLE_FLASH_ATTENTION2_PLATFORMS = tp.Literal["triton", "pallas", "jax"] 38 | AVAILABLE_BACKENDS = tp.Literal["gpu", "tpu", "cpu"] 39 | 40 | 41 | def get_device_memory_usage(device: jax.Device) -> float: 42 | """ 43 | Get the memory usage for a specific JAX device using local_devices stats. 44 | 45 | Args: 46 | device: JAX device to check 47 | Returns: 48 | float: Memory usage in bytes 49 | """ 50 | try: 51 | memory_stats = device.memory_stats() 52 | return memory_stats["bytes_in_use"] if memory_stats else float("inf") 53 | except: # noqa 54 | return float("inf") 55 | 56 | 57 | def free_gpu_in_process() -> int: 58 | """ 59 | Returns the index of the GPU with the most available memory using JAX local_devices. 
60 | 61 | Returns: 62 | int: Index of the GPU with most free memory 63 | """ 64 | devices = jax.local_devices() 65 | gpu_devices = [d for d in devices if d.platform == "gpu"] 66 | 67 | if not gpu_devices: 68 | return 0 69 | 70 | memory_usage = [get_device_memory_usage(device) for device in gpu_devices] 71 | return memory_usage.index(min(memory_usage)) 72 | 73 | 74 | class Backend(str, Enum): 75 | """Supported compute backends.""" 76 | 77 | GPU = "gpu" 78 | TPU = "tpu" 79 | CPU = "cpu" 80 | 81 | 82 | class Platform(str, Enum): 83 | """Supported Flash Attention platforms.""" 84 | 85 | TRITON = "triton" 86 | PALLAS = "pallas" 87 | JAX = "jax" 88 | 89 | 90 | @dataclass 91 | class AttentionConfig: 92 | """Configuration for Flash Attention computation.""" 93 | 94 | blocksize_q: int = 128 95 | blocksize_k: int = 128 96 | softmax_scale: tp.Optional[float] = None 97 | backend: tp.Optional[Backend] = None 98 | platform: tp.Optional[Platform] = None 99 | 100 | def __post_init__(self): 101 | if self.backend is None: 102 | self.backend = Backend(get_backend().platform) 103 | 104 | if self.platform is None: 105 | self.platform = self._default_platform() 106 | 107 | def _default_platform(self) -> Platform: 108 | """Determines the default platform based on the backend.""" 109 | platform_map = { 110 | Backend.GPU: Platform.TRITON, 111 | Backend.CPU: Platform.JAX, 112 | Backend.TPU: Platform.PALLAS, 113 | } 114 | return platform_map.get(self.backend) 115 | 116 | 117 | class FlashAttention: 118 | """Flash Attention implementation with multiple backend support.""" 119 | 120 | def __init__(self, config: tp.Optional[AttentionConfig] = None): 121 | self.config = config or AttentionConfig() 122 | self._validate_config() 123 | 124 | def _validate_config(self): 125 | """Validates the configuration settings.""" 126 | valid_combinations = { 127 | (Backend.GPU, Platform.TRITON), 128 | (Backend.GPU, Platform.PALLAS), 129 | (Backend.GPU, Platform.JAX), 130 | (Backend.CPU, Platform.JAX), 131 | (Backend.TPU, Platform.JAX), 132 | (Backend.TPU, Platform.PALLAS), 133 | } 134 | 135 | if (self.config.backend, self.config.platform) not in valid_combinations: 136 | raise ValueError( 137 | f"Invalid backend-platform combination: " 138 | f"{self.config.backend}-{self.config.platform}" 139 | ) 140 | 141 | @staticmethod 142 | def repeat_kv_heads( 143 | key: chex.Array, value: chex.Array, num_reps: int 144 | ) -> tp.Tuple[chex.Array, chex.Array]: 145 | """Repeats key and value heads to match query heads.""" 146 | return ( 147 | einops.repeat(key, "b s h d -> b s (h r) d", r=num_reps), 148 | einops.repeat(value, "b s h d -> b s (h r) d", r=num_reps), 149 | ) 150 | 151 | def _handle_bias( 152 | self, bias: chex.Array, num_q_heads: int, num_kv_heads: int 153 | ) -> tp.Optional[chex.Array]: 154 | """Processes attention bias based on head configuration.""" 155 | if bias is None: 156 | return None 157 | 158 | if bias.shape[1] == num_q_heads or bias.shape[1] == 1: 159 | return bias 160 | 161 | elif bias.shape[1] == num_kv_heads: 162 | return einops.repeat( 163 | bias, "b h q k -> b (h r) q k", r=num_q_heads // bias.shape[1] 164 | ) 165 | else: 166 | raise ValueError( 167 | f"Incompatible bias shape. 
Got {bias.shape[1]} heads, " 168 | f"expected {num_q_heads}, {num_kv_heads}, or 1" 169 | ) 170 | 171 | def __call__( 172 | self, 173 | query: chex.Array, 174 | key: chex.Array, 175 | value: chex.Array, 176 | bias: tp.Optional[chex.Array] = None, 177 | attention_mask: tp.Optional[chex.Array] = None, 178 | dropout_prob: float = 0.0, 179 | causal: bool = False, 180 | dropout_seed: tp.Optional[int] = None, 181 | adjust_sharindgs: bool = False, 182 | ) -> chex.Array: 183 | """ 184 | Computes flash attention using the configured backend and platform. 185 | """ 186 | num_q_heads = query.shape[2] 187 | num_kv_heads = key.shape[2] 188 | 189 | if num_q_heads % num_kv_heads != 0: 190 | raise ValueError( 191 | f"Query heads ({num_q_heads}) must be divisible by " 192 | f"key/value heads ({num_kv_heads})" 193 | ) 194 | if bias is not None: 195 | bias = self._handle_bias(bias, num_q_heads, num_kv_heads) 196 | kw = dict( 197 | query=query, 198 | key=key, 199 | value=value, 200 | bias=bias, 201 | adjust_sharindgs=adjust_sharindgs, 202 | attention_mask=attention_mask, 203 | causal=causal, 204 | dropout_prob=dropout_prob, 205 | dropout_seed=dropout_seed, 206 | ) 207 | if self.config.platform == Platform.TRITON: 208 | return self._compute_triton(**kw) 209 | elif self.config.platform == Platform.PALLAS: 210 | return self._compute_pallas(**kw) 211 | else: # Platform.JAX 212 | return self._compute_jax(**kw) 213 | 214 | forward = __call__ 215 | 216 | def _compute_triton( 217 | self, 218 | query: chex.Array, 219 | key: chex.Array, 220 | value: chex.Array, 221 | bias: tp.Optional[chex.Array] = None, 222 | attention_mask: tp.Optional[chex.Array] = None, 223 | dropout_prob: float = 0.0, 224 | causal: bool = False, 225 | dropout_seed: tp.Optional[int] = None, 226 | adjust_sharindgs: bool = False, 227 | ) -> chex.Array: 228 | """Computes attention using Triton backend.""" 229 | if adjust_sharindgs: 230 | query_sharding = query.sharding if hasattr(query, "sharding") else None 231 | target_gpu_idx = int(os.environ.get("GPU_IDX_FLASH_ATTN", free_gpu_in_process())) 232 | devices = jax.local_devices(process_index=jax.process_index(), backend="gpu") 233 | target_device = devices[target_gpu_idx] 234 | query = jax.device_put(query, target_device) 235 | key = jax.device_put(key, target_device) 236 | value = jax.device_put(value, target_device) 237 | if bias is not None: 238 | bias = jax.device_put(bias, target_device) 239 | if query.shape[1] != key.shape[1]: 240 | if query.shape[1] == 1: 241 | causal = False 242 | attention_mask = None 243 | 244 | attn = triton_flash_attention( 245 | q=query, 246 | k=key, 247 | v=value, 248 | bias=bias, 249 | attention_mask=attention_mask, 250 | dropout_prob=dropout_prob, 251 | causal=causal, 252 | dropout_seed=dropout_seed, 253 | softmax_scale=self.config.softmax_scale, 254 | ) 255 | 256 | if adjust_sharindgs and query_sharding is not None: 257 | attn = jax.device_put(attn, query_sharding) 258 | return attn 259 | 260 | def _compute_pallas( 261 | self, 262 | query: chex.Array, 263 | key: chex.Array, 264 | value: chex.Array, 265 | bias: tp.Optional[chex.Array] = None, 266 | attention_mask: tp.Optional[chex.Array] = None, 267 | dropout_prob: float = 0.0, 268 | causal: bool = False, 269 | dropout_seed: tp.Optional[int] = None, 270 | adjust_sharindgs: bool = False, 271 | ) -> chex.Array: 272 | """Computes attention using Pallas backend.""" 273 | 274 | if self.config.backend == Backend.GPU: 275 | warnings.warn( 276 | "Pallas-FlashAttention has been deprecated on GPUs (triton backend will be used)", 
277 | stacklevel=1, 278 | ) 279 | return self._compute_triton( 280 | query=query, 281 | key=key, 282 | value=value, 283 | bias=bias, 284 | adjust_sharindgs=adjust_sharindgs, 285 | attention_mask=attention_mask, 286 | dropout_prob=dropout_prob, 287 | causal=causal, 288 | dropout_seed=dropout_seed, 289 | ) 290 | 291 | key, value = self.repeat_kv_heads(key, value, query.shape[2] // key.shape[2]) 292 | query_lenght = query.shape[1] 293 | value_lenght = value.shape[1] 294 | if bias is not None: 295 | if bias.shape[1] != value.shape[2]: 296 | bias = jnp.repeat(bias, value.shape[2] // bias.shape[1], 1) 297 | # TPU implementation 298 | block_sizes = TPUBlockSizes( 299 | block_q=min(self.config.blocksize_q, query_lenght), 300 | block_k_major=min(self.config.blocksize_k, value_lenght), 301 | block_k=min(self.config.blocksize_k, value_lenght), 302 | block_b=1, 303 | block_q_major_dkv=min(self.config.blocksize_q, query_lenght), 304 | block_k_major_dkv=min(self.config.blocksize_k, value_lenght), 305 | block_k_dkv=min(self.config.blocksize_k, value_lenght), 306 | block_q_dkv=min(self.config.blocksize_q, query_lenght), 307 | block_k_major_dq=min(self.config.blocksize_k, value_lenght), 308 | block_k_dq=min(self.config.blocksize_k, value_lenght), 309 | block_q_dq=min(self.config.blocksize_q, query_lenght), 310 | ) 311 | 312 | return partial( 313 | pallas_flash_attention_tpu, 314 | sm_scale=self.config.softmax_scale, 315 | block_sizes=block_sizes, 316 | causal=causal, 317 | )( 318 | query.transpose(0, 2, 1, 3), 319 | key.transpose(0, 2, 1, 3), 320 | value.transpose(0, 2, 1, 3), 321 | bias, 322 | ).transpose(0, 2, 1, 3) 323 | 324 | def _compute_jax( 325 | self, 326 | query: chex.Array, 327 | key: chex.Array, 328 | value: chex.Array, 329 | bias: tp.Optional[chex.Array] = None, 330 | attention_mask: tp.Optional[chex.Array] = None, 331 | dropout_prob: float = 0.0, 332 | causal: bool = False, 333 | dropout_seed: tp.Optional[int] = None, 334 | adjust_sharindgs: bool = False, 335 | ) -> chex.Array: 336 | """Computes attention using JAX backend.""" 337 | key, value = self.repeat_kv_heads(key, value, query.shape[2] // key.shape[2]) 338 | return jax_flash_attention( 339 | query_state=query, 340 | key_state=key, 341 | value_state=value, 342 | mask=None, 343 | bias=bias, 344 | blocksize_q=self.config.blocksize_q, 345 | blocksize_k=self.config.blocksize_k, 346 | dtype=query.dtype, 347 | softmax_scale=self.config.softmax_scale, 348 | dropout=dropout_prob, 349 | ) 350 | 351 | def __repr__(self): 352 | return ( 353 | f"FlashAttention(platform={self.config.platform}, backend={self.config.backend})" 354 | ) 355 | 356 | __str__ = __repr__ 357 | 358 | 359 | def create_flash_attention( 360 | backend: tp.Optional[tp.Union[Backend, str]] = None, 361 | platform: tp.Optional[tp.Union[Platform, str]] = None, 362 | **kwargs, 363 | ) -> FlashAttention: 364 | """ 365 | Factory function to create a FlashAttention instance with the specified configuration. 
366 | 367 | Args: 368 | backend: Compute backend to use (GPU, TPU, or CPU) 369 | platform: Platform to use (Triton, Pallas, or JAX) 370 | **kwargs: Additional configuration parameters for AttentionConfig 371 | 372 | Returns: 373 | Configured FlashAttention instance 374 | """ 375 | if isinstance(backend, str): 376 | backend = Backend(backend) 377 | if isinstance(platform, str): 378 | platform = Platform(platform) 379 | 380 | config = AttentionConfig(backend=backend, platform=platform, **kwargs) 381 | return FlashAttention(config) 382 | -------------------------------------------------------------------------------- /jax_flash_attn2/flash_attention_jax/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi). 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from ._flash_attention import flash_attention as jax_flash_attention 16 | 17 | __all__ = ("jax_flash_attention",) 18 | -------------------------------------------------------------------------------- /jax_flash_attn2/flash_attention_jax/_backward_jax.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi). 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | # Implementation based on FlashAttention 2 (https://arxiv.org/pdf/2307.08691) by @erfanzar, 17 | # with a few bug fixes and adjustments. 
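The factory and class defined in jax_flash_attn2/flash_attention.py above can be exercised end to end as in the following sketch. This is an illustrative usage example, not part of the repository: the shapes, dtype, and the cpu/jax backend-platform pairing are assumptions chosen so the snippet runs without a GPU (that combination is listed as valid in FlashAttention._validate_config).

import jax
import jax.numpy as jnp

from jax_flash_attn2 import create_flash_attention

# Layout is (batch, seq_len, num_heads, head_dim); query heads must be a
# multiple of key/value heads, as enforced in FlashAttention.__call__ (GQA).
b, s, qh, kvh, d = 1, 1024, 8, 4, 64
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (b, s, qh, d), dtype=jnp.float16)
k = jax.random.normal(kk, (b, s, kvh, d), dtype=jnp.float16)
v = jax.random.normal(kv, (b, s, kvh, d), dtype=jnp.float16)

# cpu + jax keeps the example hardware-agnostic; on a GPU one would typically
# pass backend="gpu", platform="triton" instead.
attention = create_flash_attention(
    backend="cpu",
    platform="jax",
    blocksize_q=128,  # 1024 % 128 == 0, satisfying the kernel's divisibility asserts
    blocksize_k=128,
)
out = attention(query=q, key=k, value=v)
print(out.shape)  # (1, 1024, 8, 64)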
18 | 19 | import functools 20 | import typing as tp 21 | 22 | import jax 23 | import jax.numpy as jnp 24 | import jax.sharding 25 | from jax import lax 26 | 27 | 28 | @functools.partial(jax.named_call, name="_bwd_flash_attn") 29 | def _bwd_flash_attn( 30 | dropout: float, 31 | inference: bool, 32 | key: tp.Optional[jax.random.PRNGKey], 33 | blocksize_q: int, 34 | blocksize_k: int, 35 | dtype: tp.Optional[jnp.dtype], 36 | precision: lax.PrecisionLike, 37 | residuals, 38 | grad_in: jax.Array, 39 | ): 40 | """Backward pass of FlashAttention.""" 41 | 42 | del dtype 43 | ( 44 | O, # noqa: E741 45 | L, 46 | query_state, 47 | key_state, 48 | value_state, 49 | mask, 50 | bias, 51 | ) = residuals 52 | dO = grad_in 53 | 54 | b, h, _, d = query_state.shape 55 | q_seq = query_state.shape[2] 56 | k_seq = key_state.shape[2] 57 | assert q_seq % blocksize_q == 0 58 | assert k_seq % blocksize_k == 0 59 | Tr = q_seq // blocksize_q 60 | Tc = k_seq // blocksize_k 61 | 62 | D = jnp.sum(dO * O, axis=-1) 63 | 64 | dQ = (query_state * 0.0).astype(query_state.dtype) 65 | dK = (key_state * 0.0).astype(key_state.dtype) 66 | dV = (value_state * 0.0).astype(value_state.dtype) 67 | global_mask = mask 68 | is_causal = mask is not None 69 | 70 | @jax.jit 71 | @functools.partial(jax.named_call, name="_bwd_flash_attn_call_o") 72 | def call_o(state): 73 | j, dQ, dK, dV = state 74 | k_j = jax.lax.dynamic_slice_in_dim(key_state, j * blocksize_k, blocksize_k, 2) 75 | v_j = jax.lax.dynamic_slice_in_dim(value_state, j * blocksize_k, blocksize_k, 2) 76 | 77 | dK_j = jax.lax.dynamic_slice_in_dim(dK, j * blocksize_k, blocksize_k, 2) 78 | dV_j = jax.lax.dynamic_slice_in_dim(dV, j * blocksize_k, blocksize_k, 2) 79 | 80 | @jax.jit 81 | @functools.partial(jax.named_call, name="_bwd_flash_attn_call_o_call_qk") 82 | def do_inner_block(state): 83 | i, j, dQ, dK_j, dV_j = state 84 | q_i = jax.lax.dynamic_slice_in_dim(query_state, i * blocksize_q, blocksize_q, 2) 85 | dQ_i = jax.lax.dynamic_slice_in_dim(dQ, i * blocksize_q, blocksize_q, 2) 86 | dO_i = jax.lax.dynamic_slice_in_dim(dO, i * blocksize_q, blocksize_q, 2) 87 | 88 | L_i = jax.lax.dynamic_slice_in_dim(L, i * blocksize_q, blocksize_q, 2) 89 | D_i = jax.lax.dynamic_slice_in_dim(D, i * blocksize_q, blocksize_q, 2) 90 | s_ij = q_i @ k_j.transpose(0, 1, 3, 2) 91 | if dropout > 0 and not inference: 92 | rng = jax.random.fold_in(key, i * Tc + j) 93 | keep_prob = 1.0 - dropout 94 | broadcast_shape = list(s_ij.shape) 95 | mask = jax.random.bernoulli(rng, p=keep_prob, shape=broadcast_shape) 96 | mask = jnp.broadcast_to(mask, s_ij.shape) 97 | s_ij = lax.select(mask, s_ij / keep_prob, jnp.zeros_like(s_ij)) 98 | 99 | if bias is not None: 100 | b_i = jax.lax.dynamic_slice_in_dim(bias, i * blocksize_q, blocksize_q, 2) 101 | b_ij = jax.lax.dynamic_slice_in_dim(b_i, j * blocksize_k, blocksize_k, 3) 102 | s_ij = s_ij + b_ij 103 | 104 | if global_mask is not None: 105 | ma_i = jax.lax.dynamic_slice_in_dim( 106 | global_mask, i * blocksize_q, blocksize_q, 2 107 | ) 108 | ma_ij = jax.lax.dynamic_slice_in_dim(ma_i, j * blocksize_k, blocksize_k, 3) 109 | s_ij = jnp.where(ma_ij, s_ij, -1e10) 110 | 111 | p_ij = jnp.exp(s_ij - jnp.expand_dims(L_i, -1)) 112 | 113 | if dropout > 0 and not inference: 114 | rng = jax.random.fold_in(key, i * Tc + j) 115 | keep_prob = 1.0 - dropout 116 | broadcast_shape = list(p_ij.shape) 117 | mask = jax.random.bernoulli(rng, p=keep_prob, shape=broadcast_shape) 118 | mask = jnp.broadcast_to(mask, p_ij.shape) 119 | p_ij = lax.select(mask, p_ij / keep_prob, jnp.zeros_like(p_ij)) 
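# The accumulation on the following lines is the FlashAttention-2 block-wise
# backward update, with D_i = rowsum(dO_i * O_i) precomputed before the loops:
#   dV_j += P_ij^T @ dO_i          (value gradient)
#   dP_ij = dO_i @ V_j^T           (probability gradient)
#   dS_ij = P_ij * (dP_ij - D_i)   (score gradient; softmax scale is folded into Q upstream)
#   dQ_i += dS_ij @ K_j
#   dK_j += dS_ij^T @ Q_i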
120 | 121 | dV_j = dV_j + jnp.matmul(p_ij.transpose(0, 1, 3, 2), dO_i, precision=precision) 122 | 123 | dP_ij = jnp.matmul(dO_i, v_j.transpose(0, 1, 3, 2), precision=precision) 124 | 125 | dS_ij = p_ij * (dP_ij - D_i[..., None]) 126 | dQ_i = dQ_i + jnp.matmul(dS_ij, k_j, precision=precision) 127 | dK_j = dK_j + jnp.matmul(dS_ij.transpose(0, 1, 3, 2), q_i, precision=precision) 128 | dQ = jax.lax.dynamic_update_slice_in_dim( 129 | dQ, 130 | dQ_i.astype(dQ.dtype), 131 | i * blocksize_q, 132 | 2, 133 | ) 134 | return ( 135 | i + 1, 136 | j, 137 | dQ.astype(query_state.dtype), 138 | dK_j.astype(key_state.dtype), 139 | dV_j.astype(value_state.dtype), 140 | ) 141 | 142 | i_start = j if is_causal else 0 143 | _, j, dQ, dK_j, dV_j = jax.lax.while_loop( 144 | lambda state: state[0] < Tr, 145 | do_inner_block, 146 | (i_start, j, dQ, dK_j, dV_j), 147 | ) 148 | 149 | dK = jax.lax.dynamic_update_slice_in_dim( 150 | dK, dK_j.astype(dK.dtype), j * blocksize_q, 2 151 | ) 152 | dV = jax.lax.dynamic_update_slice_in_dim( 153 | dV, dV_j.astype(dV.dtype), j * blocksize_q, 2 154 | ) 155 | 156 | return j + 1, dQ, dK, dV 157 | 158 | _, dQ, dK, dV = jax.lax.while_loop( 159 | lambda state: state[0] < Tc, 160 | call_o, 161 | (0, dQ, dK, dV), 162 | ) 163 | 164 | return dQ, dK, dV, None, None 165 | -------------------------------------------------------------------------------- /jax_flash_attn2/flash_attention_jax/_flash_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi). 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | # Implementation based on FlashAttention 2 (https://arxiv.org/pdf/2307.08691) by @erfanzar, 17 | # with a few bug fixes and adjustments. 18 | 19 | import functools 20 | import math 21 | import typing as tp 22 | 23 | import jax 24 | import jax.numpy as jnp 25 | import jax.sharding 26 | from jax import lax 27 | 28 | from ._backward_jax import _bwd_flash_attn 29 | from ._forward_jax import _fwd_flash_attn 30 | 31 | 32 | @functools.partial( 33 | jax.jit, 34 | static_argnames=[ 35 | "dtype", 36 | "precision", 37 | "blocksize_q", 38 | "blocksize_k", 39 | ], 40 | ) 41 | def flash_attention( 42 | query_state: jax.Array, 43 | key_state: jax.Array, 44 | value_state: jax.Array, 45 | mask: tp.Optional[jax.Array] = None, 46 | bias: tp.Optional[jax.Array] = None, 47 | *, 48 | dropout: float = 0.0, 49 | inference: bool = True, 50 | key: tp.Optional[jax.random.PRNGKey] = None, 51 | blocksize_q: tp.Optional[int] = None, 52 | blocksize_k: tp.Optional[int] = None, 53 | dtype: tp.Optional[jnp.dtype] = None, 54 | precision: lax.PrecisionLike = None, 55 | head_dim: tp.Optional[int] = None, 56 | softmax_scale: tp.Optional[float] = None, 57 | ) -> jax.Array: 58 | """ 59 | Computes multi-head attention using FlashAttention implementation. 
60 | 61 | This implementation makes use of the FlashAttention algorithm for faster 62 | and more memory-efficient computation of attention. It is particularly 63 | beneficial for long sequences. 64 | 65 | Args: 66 | query_state: Query, shape (`batch_size`, `q_len`, `num_heads`, `head_dim`). 67 | key_state: Key, shape (`batch_size`, `kv_len`, `num_heads`, `head_dim`). 68 | value_state: Value, shape (`batch_size`, `kv_len`, `num_heads`, `head_dim`). 69 | mask: tp.Optional attention mask. This can be any of the following: 70 | 71 | - No mask (default): All attention weights are computed. 72 | - Boolean mask (2D): shape (`batch_size`, `q_len`), with `True` for 73 | valid and `False` for masked positions. 74 | - Integer mask (2D): shape (`batch_size`, `q_len`), where the value at 75 | each position indicates the length of the sequence to attend to. 76 | - 4D mask: shape (`batch_size`, `q_len`, `kv_len`), with `True` for 77 | valid and `False` for masked positions. 78 | 79 | bias: tp.Optional attention bias. 80 | dropout: Dropout rate. 81 | inference: Whether to run in inference mode. 82 | key: PRNG key for dropout. 83 | blocksize_q: Block size for query processing. 84 | blocksize_k: Block size for key/value processing. 85 | dtype: tp.Optional dtype for the output. 86 | precision: tp.Optional precision for matrix multiplication. 87 | head_dim: tp.Optional head dim to be used at `query_state = query_state / math.sqrt(float(head_dim or query_state.shape[-1]))`. 88 | softmax_scale tp.Optional softmax_scale to be used for `query_state = query_state * softmax_scale` 89 | 90 | Returns: 91 | Output of multi-head attention, with shape 92 | (`batch_size`, `q_len`, `num_heads`, `head_dim`). 93 | 94 | Raises: 95 | ValueError: If `dropout` is not in the range [0, 1], or if `key` is not 96 | provided during training when `dropout` > 0. 
97 | """ 98 | query_state, key_state, value_state = map( 99 | lambda x: x.transpose(0, 2, 1, 3), 100 | [query_state, key_state, value_state], 101 | ) 102 | if not inference and dropout > 0 and key is None: 103 | raise ValueError("key must be provided for training") 104 | if dropout < 0 or dropout > 1: 105 | raise ValueError(f"invalid dropout {dropout}") 106 | if dtype is not None: 107 | query_state = query_state.astype(dtype) 108 | key_state = key_state.astype(dtype) 109 | 110 | blocksize_k = min(key_state.shape[2], blocksize_k or 128) 111 | blocksize_q = min(query_state.shape[2], blocksize_q or 128) 112 | if head_dim is not None and softmax_scale is not None: 113 | raise ValueError("you can't pass both `head_dim` and `softmax_scale`.") 114 | if head_dim is not None: 115 | query_state = query_state / math.sqrt(float(head_dim)) 116 | elif softmax_scale is not None: 117 | query_state = query_state * softmax_scale 118 | else: 119 | query_state = query_state / math.sqrt(float(query_state.shape[-1])) 120 | return _flash_attn2( 121 | query_state, 122 | key_state, 123 | value_state, 124 | mask, 125 | bias, 126 | dropout, 127 | inference, 128 | key, 129 | blocksize_q, 130 | blocksize_k, 131 | dtype, 132 | precision, 133 | ).transpose(0, 2, 1, 3) 134 | 135 | 136 | @functools.partial( 137 | jax.custom_vjp, 138 | nondiff_argnums=(5, 6, 7, 8, 9, 10, 11), 139 | ) 140 | def _flash_attn2( 141 | query_state: jax.Array, 142 | key_state: jax.Array, 143 | value_state: jax.Array, 144 | mask: tp.Optional[jax.Array] = None, 145 | bias: tp.Optional[jax.Array] = None, 146 | dropout: float = 0.0, 147 | inference: bool = False, 148 | key: tp.Optional[jax.random.PRNGKey] = None, 149 | blocksize_q: int = 128, 150 | blocksize_k: int = 128, 151 | dtype: tp.Optional[jnp.dtype] = jnp.float32, 152 | precision: lax.PrecisionLike = None, 153 | ) -> jax.Array: 154 | """Custom VJP-enabled wrapper for FlashAttention forward pass.""" 155 | return _fwd_flash_attn( 156 | query_state, 157 | key_state, 158 | value_state, 159 | mask, 160 | bias, 161 | dropout, 162 | inference, 163 | key, 164 | blocksize_q, 165 | blocksize_k, 166 | dtype, 167 | precision, 168 | )[0] 169 | 170 | 171 | _flash_attn2.defvjp(_fwd_flash_attn, _bwd_flash_attn) 172 | 173 | 174 | def fwd_test(): 175 | import flax 176 | from jax import random as jrand 177 | 178 | from easydel.utils import GenerateRNG 179 | 180 | rng = GenerateRNG() 181 | 182 | b, h, qs, s, d = 1, 32, 2048, 2048, 128 183 | dtype = jnp.float16 184 | 185 | q = jrand.normal(rng.rng, shape=(b, qs, h, d), dtype=dtype) 186 | k = jrand.normal(rng.rng, shape=(b, s, h, d), dtype=dtype) 187 | v = jrand.normal(rng.rng, shape=(b, s, h, d), dtype=dtype) 188 | b = jnp.where( 189 | jrand.randint(rng.rng, shape=(b, h, qs, s), minval=0, maxval=3) > 1, 190 | 0, 191 | jnp.finfo(dtype).min, 192 | ) 193 | excepted_result = flax.nnx.dot_product_attention( 194 | query=q, 195 | key=k, 196 | value=v, 197 | bias=b, 198 | ) 199 | result = flash_attention( 200 | query_state=q, 201 | key_state=k, 202 | value_state=v, 203 | bias=b, 204 | dtype=dtype, 205 | blocksize_q=64, 206 | blocksize_k=64, 207 | ) 208 | 209 | print(f"PRED : {result[0, 0, 0, :5]}") 210 | print(f"ORGN : {excepted_result[0, 0, 0, :5]}") 211 | 212 | print(jnp.allclose(excepted_result, result, atol=0.125, rtol=0)) 213 | 214 | 215 | def bwd_test(): 216 | import flax 217 | from jax import random as jrand 218 | 219 | from easydel.utils import GenerateRNG 220 | 221 | rng = GenerateRNG() 222 | b, h, qs, s, d = 2, 32, 64, 64, 64 223 | dtype = jnp.float16 224 | 225 
| q = jrand.normal(rng.rng, shape=(b, qs, h, d), dtype=dtype) 226 | k = jrand.normal(rng.rng, shape=(b, s, h, d), dtype=dtype) 227 | v = jrand.normal(rng.rng, shape=(b, s, h, d), dtype=dtype) 228 | b = jnp.where( 229 | jrand.randint(rng.rng, shape=(b, h, qs, s), minval=0, maxval=3) > 1, 230 | 0, 231 | jnp.finfo(dtype).min, 232 | ) 233 | 234 | excepted_result = jax.grad(lambda *x: flax.nnx.dot_product_attention(*x).sum())( 235 | q, k, v 236 | ) 237 | result = jax.grad( 238 | lambda *x: flash_attention( 239 | *x, 240 | dtype=dtype, 241 | blocksize_q=qs, 242 | blocksize_k=s, 243 | precision=jax.lax.Precision("HIGHEST".lower()), 244 | ).sum() 245 | )(q, k, v) 246 | 247 | print(f"PRED BWD : {result[0, 0, 0, :5]}") 248 | print(f"ORGN BWD : {excepted_result[0, 0, 0, :5]}") 249 | 250 | print(jnp.allclose(excepted_result, result, atol=0.125, rtol=0)) 251 | 252 | 253 | jax_flash_attn_2_mu = flash_attention 254 | 255 | __all__ = ["jax_flash_attn_2_mu"] 256 | 257 | if __name__ == "__main__": 258 | # fwd_test() 259 | bwd_test() 260 | -------------------------------------------------------------------------------- /jax_flash_attn2/flash_attention_jax/_forward_jax.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi). 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
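The fwd_test/bwd_test helpers in _flash_attention.py above rely on easydel's GenerateRNG. A dependency-free call of the functional JAX kernel (exported as jax_flash_attention) might look like the sketch below; the shapes and block sizes are illustrative assumptions and only need to satisfy the divisibility asserts in the forward pass.

import jax
import jax.numpy as jnp

from jax_flash_attn2 import jax_flash_attention

b, s, h, d = 1, 512, 8, 64
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (b, s, h, d), dtype=jnp.float16)
k = jax.random.normal(kk, (b, s, h, d), dtype=jnp.float16)
v = jax.random.normal(kv, (b, s, h, d), dtype=jnp.float16)

# inference defaults to True and dropout to 0.0, so no PRNG key is required.
out = jax_flash_attention(
    query_state=q,
    key_state=k,
    value_state=v,
    blocksize_q=128,  # 512 % 128 == 0
    blocksize_k=128,
    dtype=jnp.float16,
)
print(out.shape)  # (1, 512, 8, 64)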
14 | 15 | import functools 16 | import typing as tp 17 | 18 | import jax 19 | import jax.numpy as jnp 20 | import jax.sharding 21 | from eformer.escale import with_sharding_constraint 22 | from jax import lax 23 | 24 | 25 | @functools.partial(jax.named_call, name="_fwd_flash_attn") 26 | def _fwd_flash_attn( 27 | query_state: jax.Array, 28 | key_state: jax.Array, 29 | value_state: jax.Array, 30 | mask: tp.Optional[jax.Array], 31 | bias: tp.Optional[jax.Array], 32 | dropout: float, 33 | inference: bool, 34 | key: tp.Optional[jax.random.PRNGKey], 35 | blocksize_q: int, 36 | blocksize_k: int, 37 | dtype: tp.Optional[jnp.dtype], 38 | precision: lax.PrecisionLike, 39 | ) -> tuple[jax.Array, tuple[jax.Array, ...]]: 40 | """Forward pass of FlashAttention.""" 41 | b, h, _, d = query_state.shape 42 | q_seq = query_state.shape[2] 43 | k_seq = key_state.shape[2] 44 | assert q_seq % blocksize_q == 0, ( 45 | "Query sequence length is not visible by queryblock size" 46 | ) 47 | assert k_seq % blocksize_k == 0, "Key sequence length is not visible by keyblock size" 48 | Tr = q_seq // blocksize_q 49 | Tc = k_seq // blocksize_k 50 | o_shape = jax.eval_shape( 51 | lambda: (query_state @ key_state.transpose(0, 1, 3, 2)) @ value_state 52 | ).shape 53 | o = jnp.zeros(o_shape, dtype=dtype) 54 | 55 | lse = jnp.full((b, h, q_seq), fill_value=-jnp.inf, dtype=jnp.float32) 56 | if hasattr(query_state, "sharding"): 57 | if isinstance(query_state.sharding, jax.sharding.NamedSharding): 58 | with query_state.sharding.mesh: 59 | o = with_sharding_constraint( 60 | arr=o, 61 | sharding=query_state.sharding, 62 | ) 63 | lse = with_sharding_constraint( 64 | arr=lse, 65 | sharding=jax.sharding.PartitionSpec(*query_state.sharding.spec[:3]), 66 | ) 67 | elif isinstance( 68 | query_state.sharding, jax.sharding.SingleDeviceSharding 69 | ) and hasattr(query_state.sharding, "_device"): 70 | o = jax.device_put(o, query_state.sharding._device) 71 | lse = jax.device_put(lse, query_state.sharding._device) 72 | 73 | global_mask = mask 74 | 75 | @jax.jit 76 | @functools.partial(jax.named_call, name="_fwd_flash_attn_call_o") 77 | def call_o(state): 78 | i, o, lse = state 79 | q_i = jax.lax.dynamic_slice_in_dim(query_state, i * blocksize_q, blocksize_q, 2) 80 | o_i = jax.lax.dynamic_slice_in_dim(o, i * blocksize_q, blocksize_q, 2) 81 | lse_i = jax.lax.dynamic_slice_in_dim(lse, i * blocksize_q, blocksize_q, 2) 82 | m_i = jnp.full((b, h, blocksize_q), fill_value=-jnp.inf, dtype=dtype) 83 | 84 | @jax.jit 85 | @functools.partial(jax.named_call, name="_fwd_flash_attn_call_o_call_qk") 86 | def call_qk(state): 87 | i, j, o_i, q_i, lse_i, m_i = state 88 | k_j = jax.lax.dynamic_slice_in_dim(key_state, j * blocksize_k, blocksize_k, 2) 89 | v_j = jax.lax.dynamic_slice_in_dim(value_state, j * blocksize_k, blocksize_k, 2) 90 | 91 | s_ij = jnp.einsum( 92 | "bhqd,bhdk->bhqk", 93 | q_i, 94 | k_j.transpose(0, 1, 3, 2), 95 | precision=precision, 96 | ) 97 | 98 | if bias is not None: 99 | b_i = jax.lax.dynamic_slice_in_dim(bias, i * blocksize_q, blocksize_q, 2) 100 | b_ij = jax.lax.dynamic_slice_in_dim(b_i, j * blocksize_k, blocksize_k, 3) 101 | s_ij = s_ij + b_ij 102 | if global_mask is not None: 103 | ma_i = jax.lax.dynamic_slice_in_dim( 104 | global_mask, i * blocksize_q, blocksize_q, 2 105 | ) 106 | ma_ij = jax.lax.dynamic_slice_in_dim(ma_i, j * blocksize_k, blocksize_k, 3) 107 | s_ij = jnp.where(ma_ij, s_ij, -1e10) 108 | 109 | if dropout > 0 and not inference: 110 | rng = jax.random.fold_in(key, i * Tc + j) 111 | keep_prob = 1.0 - dropout 112 | 
broadcast_shape = list(s_ij.shape) 113 | mask = jax.random.bernoulli(rng, p=keep_prob, shape=broadcast_shape) 114 | mask = jnp.broadcast_to(mask, s_ij.shape) 115 | s_ij = lax.select(mask, s_ij / keep_prob, jnp.zeros_like(s_ij)) 116 | 117 | m_ij = jnp.maximum(m_i, jnp.max(s_ij, axis=-1)) 118 | p = jnp.exp(s_ij - jnp.expand_dims(m_ij, -1)) 119 | 120 | l_ij = jnp.sum(p, -1) 121 | 122 | o_scale = jnp.exp(m_i - m_ij) 123 | o_i = o_i * jnp.expand_dims(o_scale, -1) 124 | 125 | o_i = o_i + jnp.einsum( 126 | "bhqk,bhkd->bhqd", 127 | p, 128 | v_j, 129 | precision=precision, 130 | ) 131 | 132 | return ( 133 | i, 134 | j + 1, 135 | o_i.astype(dtype), 136 | q_i.astype(dtype), 137 | jnp.log(jnp.exp(lse_i - m_ij) + l_ij) + m_ij, 138 | m_ij.astype(dtype), 139 | ) 140 | 141 | j_end = jnp.minimum(i + 1, Tc) if mask is not None else Tc 142 | 143 | _, _, o_i, _, lse_i, m_i = jax.lax.while_loop( 144 | lambda state: state[1] < j_end, 145 | call_qk, 146 | (i, 0, o_i, q_i, lse_i, m_i), 147 | ) 148 | o_scale = jnp.exp(m_i - lse_i) 149 | o_i = o_i * jnp.expand_dims(o_scale, -1) 150 | 151 | o = jax.lax.dynamic_update_slice_in_dim(o, o_i.astype(o.dtype), i * blocksize_q, 2) 152 | lse = jax.lax.dynamic_update_slice_in_dim( 153 | lse, 154 | lse_i.astype(lse.dtype), 155 | i * blocksize_q, 156 | 2, 157 | ) 158 | return i + 1, o, lse 159 | 160 | _, o, lse = jax.lax.while_loop(lambda state: state[0] < Tr, call_o, (0, o, lse)) 161 | 162 | return o, ( 163 | o, 164 | lse, 165 | query_state, #: jax.Array 166 | key_state, #: jax.Array 167 | value_state, #: jax.Array 168 | mask, 169 | bias, 170 | ) 171 | -------------------------------------------------------------------------------- /jax_flash_attn2/flash_attention_triton/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi). 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from ._flash_attention import flash_attention as triton_flash_attention 16 | 17 | __all__ = ("triton_flash_attention",) 18 | -------------------------------------------------------------------------------- /jax_flash_attn2/flash_attention_triton/_backward_triton.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi). 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
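For reference, the inner loops of _forward_jax.py above (call_o / call_qk) maintain the standard FlashAttention-2 online-softmax state. With S_ij denoting the scaled, biased, and masked score block for query block i and key block j, the per-iteration update they implement is:

\begin{aligned}
m_i^{(j)} &= \max\!\big(m_i^{(j-1)},\ \operatorname{rowmax}(S_{ij})\big) \\
\tilde P_{ij} &= \exp\!\big(S_{ij} - m_i^{(j)}\big), \qquad \ell_{ij} = \operatorname{rowsum}(\tilde P_{ij}) \\
O_i^{(j)} &= e^{\,m_i^{(j-1)} - m_i^{(j)}}\, O_i^{(j-1)} + \tilde P_{ij} V_j \\
L_i^{(j)} &= m_i^{(j)} + \log\!\big(e^{\,L_i^{(j-1)} - m_i^{(j)}} + \ell_{ij}\big)
\end{aligned}

After the last key block the output is rescaled once, O_i \leftarrow e^{\,m_i - L_i}\, O_i, which corresponds to the final o_scale = jnp.exp(m_i - lse_i) step in call_o.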
14 | 15 | import math 16 | import typing as tp 17 | 18 | import chex 19 | import jax 20 | import triton 21 | import triton.language as tl 22 | from eformer.callib import triton_call 23 | from jax import numpy as jnp 24 | from triton import Config 25 | 26 | from ._utils import ( 27 | dtype_index, 28 | get_sharding, 29 | get_strides, 30 | safe_autotune, 31 | attention_pack_with_static_shape, 32 | attention_unpack_with_static_shape, 33 | calc_bias_strides, 34 | padded_load, 35 | ) 36 | 37 | 38 | def config_prune_kernel( 39 | configs: tp.List[Config], 40 | named_args: tp.Dict[str, tp.Any], 41 | **kwargs, 42 | ) -> tp.List[Config]: 43 | kept_configs = [] 44 | for config in configs: 45 | largest_m = ( 46 | max( 47 | config.kwargs["BLOCK_M1"], 48 | config.kwargs["BLOCK_M2"], 49 | ) 50 | > named_args["QSeq"] 51 | ) 52 | largest_n = ( 53 | max( 54 | config.kwargs["BLOCK_N1"], 55 | config.kwargs["BLOCK_N2"], 56 | ) 57 | > named_args["KSeq"] 58 | ) 59 | if largest_m or largest_n: 60 | pass 61 | else: 62 | kept_configs.append(config) 63 | if kept_configs: 64 | return kept_configs 65 | return [ 66 | Config( 67 | { 68 | "BLOCK_M1": 32, 69 | "BLOCK_N1": 32, 70 | "BLOCK_M2": 32, 71 | "BLOCK_N2": 32, 72 | }, 73 | num_warps=4, 74 | num_stages=0, 75 | ) 76 | ] 77 | 78 | 79 | @safe_autotune( 80 | configs=[ 81 | Config({"BLOCK_M": 16}, num_warps=4, num_stages=0), 82 | Config({"BLOCK_M": 32}, num_warps=4, num_stages=0), 83 | Config({"BLOCK_M": 64}, num_warps=4, num_stages=0), 84 | Config({"BLOCK_M": 128}, num_warps=4, num_stages=0), 85 | ], 86 | key=["CQSeq", "DRuntime"], 87 | ) 88 | @triton.jit 89 | def _attn_bwd_preprocess( 90 | Po, 91 | Do, 92 | stride_oz, 93 | stride_om, 94 | stride_oh, 95 | stride_dez, 96 | stride_dem, 97 | stride_deh, 98 | nheads, 99 | QSeq, 100 | max_seqlen_q_rounded, 101 | cum_seqlens_q, 102 | headdim, 103 | CQSeq, # Re-compile argument 104 | DRuntime, # Re-compile argument 105 | Delta, 106 | VARLEN: tl.constexpr, 107 | BLOCK_M: tl.constexpr, 108 | BLOCK_HEADDIM: tl.constexpr, 109 | ): 110 | start_m = tl.program_id(0) 111 | off_zh = tl.program_id(1) 112 | off_z = off_zh // nheads 113 | off_h = off_zh % nheads 114 | offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) 115 | offs_d = tl.arange(0, BLOCK_HEADDIM) 116 | 117 | if VARLEN: 118 | start_seqlen_q = tl.load(cum_seqlens_q + off_z) 119 | actual_seqlen_q = tl.load(cum_seqlens_q + off_z + 1) - start_seqlen_q 120 | cu_seq_start_q = tl.load(cum_seqlens_q + off_z) 121 | off_z = 0 122 | else: 123 | actual_seqlen_q = QSeq 124 | cu_seq_start_q = 0 125 | 126 | o_ptrs = ( 127 | Po 128 | + off_z * stride_oz 129 | + off_h * stride_oh 130 | + cu_seq_start_q * stride_om 131 | + offs_m[:, None] * stride_om 132 | + offs_d[None, :] 133 | ) 134 | do_ptrs = ( 135 | Do 136 | + off_z * stride_dez 137 | + off_h * stride_deh 138 | + cu_seq_start_q * stride_dem 139 | + offs_m[:, None] * stride_dem 140 | + offs_d[None, :] 141 | ) 142 | 143 | mask = (offs_m[:, None] < actual_seqlen_q) & (offs_d[None, :] < headdim) 144 | o = tl.load(o_ptrs, mask=mask, other=0.0).to(tl.float32) 145 | do = tl.load(do_ptrs, mask=mask, other=0.0).to(tl.float32) 146 | delta = tl.sum(o * do, axis=1) 147 | tl.store(Delta + off_zh * max_seqlen_q_rounded + offs_m, delta) 148 | 149 | 150 | @triton.jit 151 | def _attn_bwd_dkdv( 152 | index_start_m, 153 | k, 154 | v, 155 | dk, 156 | dv, 157 | M, 158 | D, 159 | offs_m, 160 | offs_n, 161 | offs_d, 162 | q_ptrs, 163 | bias_ptrs, 164 | dropout_offs, 165 | do_ptrs, 166 | softmax_scale, 167 | stride_qm, 168 | stride_bm, 169 | stride_dom, 170 | 
actual_seqlen_q, 171 | actual_seqlen_k, 172 | fully_masked_lines, 173 | headdim, 174 | MASKED: tl.constexpr, 175 | IS_CAUSAL: tl.constexpr, 176 | BIAS_ON: tl.constexpr, 177 | USE_DROPOUT: tl.constexpr, 178 | PAD_ROWS: tl.constexpr, 179 | PAD_COLS: tl.constexpr, 180 | HEADS_PADDED: tl.constexpr, 181 | ): 182 | LN2: tl.constexpr = 1.44269504089 183 | q_ptrs = q_ptrs + index_start_m * stride_qm 184 | do_ptrs = do_ptrs + index_start_m * stride_dom 185 | if BIAS_ON: 186 | bias_ptrs = bias_ptrs + index_start_m * stride_bm 187 | if USE_DROPOUT: 188 | dropout_offs += index_start_m * actual_seqlen_k 189 | 190 | offs_m_curr = index_start_m + offs_m 191 | 192 | q = padded_load( 193 | q_ptrs, 194 | offs_m_curr, 195 | offs_d, 196 | PAD_ROWS or HEADS_PADDED, 197 | PAD_ROWS or HEADS_PADDED, 198 | actual_seqlen_q, 199 | headdim, 200 | ) 201 | me_i = tl.load(M + offs_m_curr) 202 | if BIAS_ON: 203 | bias = padded_load( 204 | bias_ptrs, 205 | offs_m_curr, 206 | offs_n, 207 | PAD_ROWS or HEADS_PADDED, 208 | PAD_ROWS or HEADS_PADDED, 209 | actual_seqlen_q, 210 | actual_seqlen_k, 211 | ) 212 | 213 | qk = tl.dot(q, tl.trans(k)) 214 | if BIAS_ON: 215 | qk += bias / softmax_scale 216 | 217 | offs_n_causal = offs_n - actual_seqlen_k + actual_seqlen_q 218 | if MASKED: 219 | if PAD_COLS: 220 | if IS_CAUSAL: 221 | qk = tl.where( 222 | tl.minimum(actual_seqlen_q - 1, offs_m_curr)[:, None] 223 | >= offs_n_causal[None, :], 224 | qk, 225 | float("-inf"), 226 | ) 227 | else: 228 | qk = tl.where(actual_seqlen_q - 1 >= offs_n_causal[None, :], qk, float("-inf")) 229 | elif IS_CAUSAL: 230 | qk = tl.where(offs_m_curr[:, None] >= offs_n_causal[None, :], qk, float("-inf")) 231 | tl.debug_barrier() 232 | p = tl.exp2(qk * (softmax_scale * LN2) - me_i[:, None]) 233 | 234 | if MASKED: 235 | if fully_masked_lines > 0: 236 | p = tl.where(offs_m_curr[:, None] < fully_masked_lines, 0, p) 237 | 238 | do = padded_load( 239 | do_ptrs, 240 | offs_m_curr, 241 | offs_d, 242 | PAD_ROWS, 243 | HEADS_PADDED, 244 | actual_seqlen_q, 245 | headdim, 246 | ) 247 | 248 | dv += tl.dot(tl.trans(p).to(do.dtype), do) 249 | dp = tl.dot(do, tl.trans(v)) 250 | Di = tl.load(D + offs_m_curr) 251 | ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype) 252 | dk += tl.dot(tl.trans(ds), q) 253 | 254 | return dk, dv 255 | 256 | 257 | @triton.jit 258 | def _attn_bwd_block_dkdv( 259 | index_start_n, 260 | Q, 261 | K, 262 | V, 263 | B, 264 | Dropout, 265 | Do, 266 | Dk, 267 | Dv, 268 | M, 269 | D, 270 | softmax_scale, 271 | stride_qm, 272 | stride_kn, 273 | stride_vn, 274 | stride_bm, 275 | stride_dom, 276 | stride_dkn, 277 | stride_dvn, 278 | actual_seqlen_q, 279 | actual_seqlen_k, 280 | headdim, 281 | IS_CAUSAL: tl.constexpr, 282 | BIAS_ON: tl.constexpr, 283 | USE_DROPOUT: tl.constexpr, 284 | PAD_COLS: tl.constexpr, 285 | HEADS_PADDED: tl.constexpr, 286 | BLOCK_M: tl.constexpr, 287 | BLOCK_N: tl.constexpr, 288 | BLOCK_HEADDIM: tl.constexpr, 289 | ): 290 | index_begin_m = ( 291 | max(index_start_n + actual_seqlen_q - actual_seqlen_k, 0) if IS_CAUSAL else 0 292 | ) 293 | index_begin_m = (index_begin_m // BLOCK_M) * BLOCK_M 294 | index_end_m = actual_seqlen_q 295 | 296 | fully_masked_lines = (actual_seqlen_q - actual_seqlen_k) if IS_CAUSAL else 0 297 | if (index_begin_m >= actual_seqlen_q) or (index_start_n >= actual_seqlen_k): 298 | return 299 | 300 | offs_n = index_start_n + tl.arange(0, BLOCK_N) 301 | offs_m = tl.arange(0, BLOCK_M) 302 | offs_d = tl.arange(0, BLOCK_HEADDIM) 303 | 304 | q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_d[None, :]) 305 | k_ptrs = K 
+ (offs_n[:, None] * stride_kn + offs_d[None, :]) 306 | v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :]) 307 | dk_ptrs = Dk + (offs_n[:, None] * stride_dkn + offs_d[None, :]) 308 | dv_ptrs = Dv + (offs_n[:, None] * stride_dvn + offs_d[None, :]) 309 | do_ptrs = Do + (offs_m[:, None] * stride_dom + offs_d[None, :]) 310 | if BIAS_ON: 311 | bias_ptrs = B + (offs_m[:, None] * stride_bm + offs_n[None, :]) 312 | else: 313 | bias_ptrs = None 314 | if USE_DROPOUT: 315 | dropout_offs = Dropout + offs_m[:, None] * actual_seqlen_k + offs_n[None, :] 316 | else: 317 | dropout_offs = None 318 | dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) 319 | dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) 320 | k = padded_load( 321 | k_ptrs, 322 | offs_n, 323 | offs_d, 324 | PA0=PAD_COLS, 325 | PA1=HEADS_PADDED, 326 | LA0=actual_seqlen_k, 327 | LA1=headdim, 328 | ) 329 | v = padded_load( 330 | v_ptrs, 331 | offs_n, 332 | offs_d, 333 | PA0=PAD_COLS, 334 | PA1=HEADS_PADDED, 335 | LA0=actual_seqlen_k, 336 | LA1=headdim, 337 | ) 338 | # fmt:off 339 | fr = max(0, index_start_n + BLOCK_N - 1 + actual_seqlen_q - actual_seqlen_k) 340 | fb = BLOCK_M * ((min(fr, actual_seqlen_q) + BLOCK_M - 1) // BLOCK_M) 341 | num_masked_blocks = (fb - index_begin_m) // BLOCK_M if IS_CAUSAL else 0 342 | index_next_start_m = index_begin_m 343 | # fmt:on 344 | 345 | if num_masked_blocks > 0: 346 | for _ in range(0, num_masked_blocks): 347 | dk, dv = _attn_bwd_dkdv( 348 | index_next_start_m, 349 | k, 350 | v, 351 | dk, 352 | dv, 353 | M, 354 | D, 355 | offs_m, 356 | offs_n, 357 | offs_d, 358 | q_ptrs, 359 | bias_ptrs, 360 | dropout_offs, 361 | do_ptrs, 362 | softmax_scale, 363 | stride_qm, 364 | stride_bm, 365 | stride_dom, 366 | actual_seqlen_q, 367 | actual_seqlen_k, 368 | fully_masked_lines, 369 | headdim, 370 | MASKED=True, 371 | IS_CAUSAL=IS_CAUSAL, 372 | BIAS_ON=BIAS_ON, 373 | USE_DROPOUT=USE_DROPOUT, 374 | PAD_ROWS=True, 375 | PAD_COLS=PAD_COLS, 376 | HEADS_PADDED=HEADS_PADDED, 377 | ) 378 | index_next_start_m += BLOCK_M 379 | 380 | if index_next_start_m < index_end_m: 381 | for index_start_m in range(index_next_start_m, index_end_m, BLOCK_M): 382 | dk, dv = _attn_bwd_dkdv( 383 | index_start_m, 384 | k, 385 | v, 386 | dk, 387 | dv, 388 | M, 389 | D, 390 | offs_m, 391 | offs_n, 392 | offs_d, 393 | q_ptrs, 394 | bias_ptrs, 395 | dropout_offs, 396 | do_ptrs, 397 | softmax_scale, 398 | stride_qm, 399 | stride_bm, 400 | stride_dom, 401 | actual_seqlen_q, 402 | actual_seqlen_k, 403 | fully_masked_lines, 404 | headdim, 405 | MASKED=False, 406 | IS_CAUSAL=IS_CAUSAL, 407 | BIAS_ON=BIAS_ON, 408 | USE_DROPOUT=USE_DROPOUT, 409 | PAD_ROWS=True, 410 | PAD_COLS=PAD_COLS, 411 | HEADS_PADDED=HEADS_PADDED, 412 | ) 413 | 414 | if HEADS_PADDED: 415 | if PAD_COLS: 416 | tl.store( 417 | dk_ptrs, 418 | dk, 419 | mask=(offs_n[:, None] < actual_seqlen_k) & (offs_d[None, :] < headdim), 420 | ) 421 | tl.store( 422 | dv_ptrs, 423 | dv, 424 | mask=(offs_n[:, None] < actual_seqlen_k) & (offs_d[None, :] < headdim), 425 | ) 426 | else: 427 | tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim) 428 | tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim) 429 | else: 430 | if PAD_COLS: 431 | tl.store(dk_ptrs, dk, mask=offs_n[:, None] < actual_seqlen_k) 432 | tl.store(dv_ptrs, dv, mask=offs_n[:, None] < actual_seqlen_k) 433 | else: 434 | tl.store(dk_ptrs, dk) 435 | tl.store(dv_ptrs, dv) 436 | 437 | 438 | @triton.jit 439 | def _attn_bwd_dq( 440 | index_start_n, 441 | q, 442 | dq, 443 | do, 444 | me_i, 445 | de_i, 446 | offs_m, 447 | offs_n, 
448 | offs_d, 449 | k_ptrs, 450 | v_ptrs, 451 | bias_ptrs, 452 | dropout_offs, 453 | softmax_scale, 454 | dropout_prob, 455 | dropout_seed, 456 | stride_kn, 457 | stride_vn, 458 | actual_seqlen_q, 459 | actual_seqlen_k, 460 | headdim, 461 | MASKED: tl.constexpr, 462 | IS_CAUSAL: tl.constexpr, 463 | BIAS_ON: tl.constexpr, 464 | USE_DROPOUT: tl.constexpr, 465 | PAD_COLS: tl.constexpr, 466 | HEADS_PADDED: tl.constexpr, 467 | ): 468 | k_ptrs = k_ptrs + index_start_n * stride_kn 469 | v_ptrs = v_ptrs + index_start_n * stride_vn 470 | offs_n_curr = index_start_n + offs_n 471 | if BIAS_ON: 472 | bias_ptrs += index_start_n 473 | if USE_DROPOUT: 474 | dropout_offs += index_start_n 475 | k = padded_load( 476 | k_ptrs, offs_n_curr, offs_d, PAD_COLS, HEADS_PADDED, actual_seqlen_k, headdim 477 | ) 478 | v = padded_load( 479 | v_ptrs, offs_n_curr, offs_d, PAD_COLS, HEADS_PADDED, actual_seqlen_k, headdim 480 | ) 481 | if BIAS_ON: 482 | bias = padded_load( 483 | bias_ptrs, 484 | offs_m, 485 | offs_n_curr, 486 | True, 487 | PAD_COLS, 488 | actual_seqlen_q, 489 | actual_seqlen_k, 490 | ) 491 | qk = tl.dot(q, tl.trans(k)) 492 | if BIAS_ON: 493 | qk += bias / softmax_scale 494 | offs_n_causal = offs_n_curr - actual_seqlen_k + actual_seqlen_q 495 | if MASKED: 496 | if PAD_COLS: 497 | if IS_CAUSAL: 498 | qk = tl.where( 499 | tl.minimum(actual_seqlen_q - 1, offs_m)[:, None] >= offs_n_causal[None, :], 500 | qk, 501 | float("-inf"), 502 | ) 503 | else: 504 | qk = tl.where(actual_seqlen_q - 1 >= offs_n_causal[None, :], qk, float("-inf")) 505 | elif IS_CAUSAL: 506 | qk = tl.where(offs_m[:, None] >= offs_n_causal[None, :], qk, float("-inf")) 507 | tl.debug_barrier() 508 | 509 | p = tl.exp2(qk * (softmax_scale * 1.44269504089) - me_i[:, None]) 510 | dp = tl.dot(do, tl.trans(v)) 511 | 512 | ds = (p * (dp - de_i[:, None]) * softmax_scale).to(q.dtype) 513 | 514 | dq += tl.dot(ds, k) 515 | 516 | return dq 517 | 518 | 519 | @triton.jit 520 | def _attn_bwd_block_dq( 521 | index_start_m, 522 | Q, 523 | K, 524 | V, 525 | B, 526 | Dropout, 527 | Do, 528 | Dq, 529 | M, 530 | D, 531 | softmax_scale, 532 | dropout_prob, 533 | dropout_seed, 534 | stride_qm, 535 | stride_kn, 536 | stride_vn, 537 | stride_bm, 538 | stride_dom, 539 | stride_dqm, 540 | actual_seqlen_q, 541 | actual_seqlen_k, 542 | headdim, 543 | VARLEN: tl.constexpr, 544 | IS_CAUSAL: tl.constexpr, 545 | BIAS_ON: tl.constexpr, 546 | USE_DROPOUT: tl.constexpr, 547 | PAD_ROWS: tl.constexpr, 548 | HEADS_PADDED: tl.constexpr, 549 | BLOCK_M: tl.constexpr, 550 | BLOCK_N: tl.constexpr, 551 | BLOCK_HEADDIM: tl.constexpr, 552 | EVEN_N: tl.constexpr, 553 | ): 554 | if IS_CAUSAL: 555 | index_end_n = min( 556 | actual_seqlen_k - actual_seqlen_q + index_start_m + BLOCK_M, 557 | actual_seqlen_k, 558 | ) 559 | 560 | if index_end_n < 0: 561 | return 562 | else: 563 | index_end_n = actual_seqlen_k 564 | 565 | fully_masked_lines = actual_seqlen_q - actual_seqlen_k if IS_CAUSAL else 0 566 | mask_reached = fully_masked_lines >= index_start_m + BLOCK_M 567 | if (index_start_m >= actual_seqlen_q) or mask_reached: 568 | return 569 | 570 | offs_m = tl.arange(0, BLOCK_M) + index_start_m 571 | offs_n = tl.arange(0, BLOCK_N) 572 | offs_d = tl.arange(0, BLOCK_HEADDIM) 573 | 574 | q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_d[None, :]) 575 | k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :]) 576 | v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :]) 577 | 578 | dq_ptrs = Dq + (offs_m[:, None] * stride_dqm + offs_d[None, :]) 579 | do_ptrs = Do + (offs_m[:, None] * stride_dom 
+ offs_d[None, :]) 580 | 581 | if BIAS_ON: 582 | bias_ptrs = B + (offs_m[:, None] * stride_bm + offs_n[None, :]) 583 | else: 584 | bias_ptrs = None 585 | 586 | if USE_DROPOUT: 587 | dropout_offs = Dropout + (offs_m[:, None] * stride_bm + offs_n[None, :]) 588 | else: 589 | dropout_offs = None 590 | 591 | dq = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) 592 | q = padded_load( 593 | q_ptrs, 594 | offs_m, 595 | offs_d, 596 | PA0=PAD_ROWS, 597 | PA1=HEADS_PADDED, 598 | LA0=actual_seqlen_q, 599 | LA1=headdim, 600 | ) 601 | do = padded_load( 602 | do_ptrs, 603 | offs_m, 604 | offs_d, 605 | PA0=PAD_ROWS, 606 | PA1=HEADS_PADDED, 607 | LA0=actual_seqlen_q, 608 | LA1=headdim, 609 | ) 610 | me_i = tl.load(M + offs_m) 611 | de_i = tl.load(D + offs_m) 612 | 613 | uneven_n = actual_seqlen_k % BLOCK_N != 0 614 | attention_padding = VARLEN & uneven_n 615 | if IS_CAUSAL: 616 | first_masked_col = index_start_m + 1 + actual_seqlen_k - actual_seqlen_q 617 | elif attention_padding: 618 | first_masked_col = actual_seqlen_k 619 | else: 620 | first_masked_col = index_end_n 621 | nb_full_blocks = first_masked_col // BLOCK_N 622 | 623 | index_next_start_n = 0 624 | if nb_full_blocks > 0: 625 | for _ in range(0, nb_full_blocks): 626 | index_next_start_n = tl.multiple_of(index_next_start_n, BLOCK_N) 627 | dq = _attn_bwd_dq( 628 | index_next_start_n, 629 | q, 630 | dq, 631 | do, 632 | me_i, 633 | de_i, 634 | offs_m, 635 | offs_n, 636 | offs_d, 637 | k_ptrs, 638 | v_ptrs, 639 | bias_ptrs, 640 | dropout_offs, 641 | softmax_scale, 642 | dropout_prob, 643 | dropout_seed, 644 | stride_kn, 645 | stride_vn, 646 | actual_seqlen_q, 647 | actual_seqlen_k, 648 | headdim, 649 | IS_CAUSAL=IS_CAUSAL, 650 | BIAS_ON=BIAS_ON, 651 | USE_DROPOUT=USE_DROPOUT, 652 | MASKED=False, 653 | PAD_COLS=False, 654 | HEADS_PADDED=HEADS_PADDED, 655 | ) 656 | index_next_start_n += BLOCK_N 657 | 658 | if index_next_start_n < index_end_n: 659 | for index_start_n in range(index_next_start_n, index_end_n, BLOCK_N): 660 | pad_cols = (not EVEN_N) or ( 661 | VARLEN and (index_start_n + BLOCK_N > actual_seqlen_k) 662 | ) 663 | dq = _attn_bwd_dq( 664 | index_start_n, 665 | q, 666 | dq, 667 | do, 668 | me_i, 669 | de_i, 670 | offs_m, 671 | offs_n, 672 | offs_d, 673 | k_ptrs, 674 | v_ptrs, 675 | bias_ptrs, 676 | dropout_offs, 677 | softmax_scale, 678 | dropout_prob, 679 | dropout_seed, 680 | stride_kn, 681 | stride_vn, 682 | actual_seqlen_q, 683 | actual_seqlen_k, 684 | headdim, 685 | IS_CAUSAL=IS_CAUSAL, 686 | BIAS_ON=BIAS_ON, 687 | USE_DROPOUT=USE_DROPOUT, 688 | MASKED=True, 689 | PAD_COLS=pad_cols, 690 | HEADS_PADDED=HEADS_PADDED, 691 | ) 692 | 693 | if fully_masked_lines > 0: 694 | dq = tl.where(offs_m[:, None] < fully_masked_lines, 0, dq) 695 | 696 | if HEADS_PADDED: 697 | if PAD_ROWS: 698 | tl.store( 699 | dq_ptrs, 700 | dq, 701 | mask=(offs_m[:, None] < actual_seqlen_q) & (offs_d[None, :] < headdim), 702 | ) 703 | else: 704 | tl.store(dq_ptrs, dq, mask=offs_d[None, :] < headdim) 705 | else: 706 | if PAD_ROWS: 707 | tl.store(dq_ptrs, dq, mask=offs_m[:, None] < actual_seqlen_q) 708 | else: 709 | tl.store(dq_ptrs, dq) 710 | 711 | 712 | @safe_autotune( 713 | configs=[ 714 | Config( 715 | {"BLOCK_M1": 16, "BLOCK_N1": 16, "BLOCK_M2": 16, "BLOCK_N2": 16}, 716 | num_warps=4, 717 | num_stages=0, 718 | ), 719 | Config( 720 | {"BLOCK_M1": 32, "BLOCK_N1": 16, "BLOCK_M2": 16, "BLOCK_N2": 32}, 721 | num_warps=4, 722 | num_stages=0, 723 | ), 724 | Config( 725 | {"BLOCK_M1": 32, "BLOCK_N1": 64, "BLOCK_M2": 64, "BLOCK_N2": 32}, 726 | num_warps=4, 727 | 
num_stages=0, 728 | ), 729 | Config( 730 | {"BLOCK_M1": 64, "BLOCK_N1": 64, "BLOCK_M2": 64, "BLOCK_N2": 64}, 731 | num_warps=4, 732 | num_stages=0, 733 | ), 734 | Config( 735 | {"BLOCK_M1": 64, "BLOCK_N1": 128, "BLOCK_M2": 128, "BLOCK_N2": 64}, 736 | num_warps=4, 737 | num_stages=0, 738 | ), 739 | ], 740 | key=[ 741 | "CQSeq", 742 | "CKSeq", 743 | "DRuntime", 744 | "VARLEN", 745 | "USE_DROPOUT", 746 | "IS_CAUSAL", 747 | "BIAS_ON", 748 | "BLOCK_HEADDIM", 749 | ], 750 | prune_configs_by={"early_config_prune": config_prune_kernel}, 751 | ) 752 | @triton.heuristics( 753 | { 754 | "EVEN_M1": lambda args: args["QSeq"] % args["BLOCK_M1"] == 0, 755 | "EVEN_N1": lambda args: args["KSeq"] % args["BLOCK_N1"] == 0, 756 | "EVEN_M2": lambda args: args["QSeq"] % args["BLOCK_M2"] == 0, 757 | "EVEN_N2": lambda args: args["KSeq"] % args["BLOCK_N2"] == 0, 758 | "HEADS_PADDED": lambda args: args["headdim"] != args["BLOCK_HEADDIM"], 759 | "NUM_BLOCKS_KV": lambda args: math.ceil(args["KSeq"] / args["BLOCK_N1"]), 760 | } 761 | ) 762 | @triton.jit 763 | def _attn_bwd( 764 | Q, 765 | K, 766 | V, 767 | B, 768 | Do, 769 | M, 770 | D, 771 | softmax_scale, 772 | dropout_prob, 773 | dropout_seed, 774 | stride_qz, 775 | stride_qm, 776 | stride_qh, 777 | stride_kz, 778 | stride_kn, 779 | stride_kh, 780 | stride_vz, 781 | stride_vn, 782 | stride_vh, 783 | stride_bz, 784 | stride_bm, 785 | stride_bh, 786 | stride_doz, 787 | stride_dom, 788 | stride_doh, 789 | stride_dqz, 790 | stride_dqm, 791 | stride_dqh, 792 | stride_dkz, 793 | stride_dkn, 794 | stride_dkh, 795 | stride_dvz, 796 | stride_dvn, 797 | stride_dvh, 798 | nheads_q, 799 | num_repeats, 800 | QSeq, 801 | cum_seqlens_q, 802 | KSeq, 803 | cum_seqlens_k, 804 | seqlen_q_rounded, 805 | headdim, 806 | CQSeq, 807 | CKSeq, 808 | DRuntime, 809 | Dq, 810 | Dk, 811 | Dv, 812 | VARLEN: tl.constexpr, 813 | IS_CAUSAL: tl.constexpr, 814 | BIAS_ON: tl.constexpr, 815 | USE_DROPOUT: tl.constexpr, 816 | BLOCK_HEADDIM: tl.constexpr, 817 | # Heuristics 818 | EVEN_M1: tl.constexpr, 819 | EVEN_N1: tl.constexpr, 820 | EVEN_M2: tl.constexpr, 821 | EVEN_N2: tl.constexpr, 822 | NUM_BLOCKS_KV: tl.constexpr, 823 | HEADS_PADDED: tl.constexpr, 824 | # AutoTune 825 | BLOCK_M1: tl.constexpr, 826 | BLOCK_N1: tl.constexpr, 827 | BLOCK_M2: tl.constexpr, 828 | BLOCK_N2: tl.constexpr, 829 | ): 830 | pid = tl.program_id(0) 831 | off_zh = tl.program_id(1) 832 | off_z = off_zh // nheads_q 833 | off_head_q = off_zh % nheads_q 834 | off_head_kv = off_head_q // num_repeats 835 | 836 | if VARLEN: 837 | cu_seq_start_q = tl.load(cum_seqlens_q + off_z) 838 | cu_seq_start_k = tl.load(cum_seqlens_k + off_z) 839 | actual_seqlen_q = tl.load(cum_seqlens_q + off_z + 1) - cu_seq_start_q 840 | actual_seqlen_k = tl.load(cum_seqlens_k + off_z + 1) - cu_seq_start_k 841 | off_z = 0 842 | else: 843 | cu_seq_start_q = 0 844 | cu_seq_start_k = 0 845 | actual_seqlen_q = QSeq 846 | actual_seqlen_k = KSeq 847 | 848 | Q += off_z * stride_qz + off_head_q * stride_qh + cu_seq_start_q * stride_qm 849 | K += off_z * stride_kz + off_head_kv * stride_kh + cu_seq_start_k * stride_kn 850 | V += off_z * stride_vz + off_head_kv * stride_vh + cu_seq_start_k * stride_vn 851 | 852 | Do += off_z * stride_doz + off_head_q * stride_doh + cu_seq_start_q * stride_dom 853 | Dq += off_z * stride_dqz + off_head_q * stride_dqh + cu_seq_start_q * stride_dqm 854 | Dk += off_z * stride_dkz + off_head_q * stride_dkh + cu_seq_start_k * stride_dkn 855 | Dv += off_z * stride_dvz + off_head_q * stride_dvh + cu_seq_start_k * stride_dvn 856 | 857 | if BIAS_ON: 
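# Re-base the bias pointer to this (batch, query-head) slice; in varlen mode off_z
# was zeroed above, so only the packed-row offset cu_seq_start_q * stride_bm applies.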
858 | B += off_z * stride_bz + off_head_q * stride_bh + cu_seq_start_q * stride_bm 859 | if USE_DROPOUT: 860 | Dropout = actual_seqlen_k * ( 861 | cu_seq_start_q + actual_seqlen_q * (off_head_q + nheads_q * off_z) 862 | ) 863 | else: 864 | Dropout = None 865 | 866 | D += off_zh * seqlen_q_rounded 867 | M += off_zh * seqlen_q_rounded 868 | 869 | if pid < NUM_BLOCKS_KV: 870 | i_start_n = pid 871 | pad_cols = (not EVEN_N1) or ( 872 | VARLEN and ((i_start_n + 1) * BLOCK_N1 > actual_seqlen_k) 873 | ) 874 | _attn_bwd_block_dkdv( 875 | i_start_n * BLOCK_N1, 876 | Q, 877 | K, 878 | V, 879 | B, 880 | Dropout, 881 | Do, 882 | Dk, 883 | Dv, 884 | M, 885 | D, 886 | softmax_scale, 887 | stride_qm, 888 | stride_kn, 889 | stride_vn, 890 | stride_bm, 891 | stride_dom, 892 | stride_dkn, 893 | stride_dvn, 894 | actual_seqlen_q, 895 | actual_seqlen_k, 896 | headdim, 897 | IS_CAUSAL=IS_CAUSAL, 898 | BIAS_ON=BIAS_ON, 899 | USE_DROPOUT=USE_DROPOUT, 900 | PAD_COLS=pad_cols, 901 | HEADS_PADDED=HEADS_PADDED, 902 | BLOCK_M=BLOCK_M1, 903 | BLOCK_N=BLOCK_N1, 904 | BLOCK_HEADDIM=BLOCK_HEADDIM, 905 | ) 906 | 907 | else: 908 | i_start_m = pid - NUM_BLOCKS_KV 909 | pad_rows = (not EVEN_M2) or ( 910 | VARLEN and ((i_start_m + 1) * BLOCK_M2 > actual_seqlen_q) 911 | ) 912 | _attn_bwd_block_dq( 913 | i_start_m * BLOCK_M2, 914 | Q, 915 | K, 916 | V, 917 | B, 918 | Dropout, 919 | Do, 920 | Dq, 921 | M, 922 | D, 923 | softmax_scale, 924 | dropout_prob, 925 | dropout_seed, 926 | stride_qm, 927 | stride_kn, 928 | stride_vn, 929 | stride_bm, 930 | stride_dom, 931 | stride_dqm, 932 | actual_seqlen_q, 933 | actual_seqlen_k, 934 | headdim, 935 | VARLEN=VARLEN, 936 | IS_CAUSAL=IS_CAUSAL, 937 | BIAS_ON=BIAS_ON, 938 | USE_DROPOUT=USE_DROPOUT, 939 | PAD_ROWS=pad_rows, 940 | HEADS_PADDED=HEADS_PADDED, 941 | BLOCK_M=BLOCK_M2, 942 | BLOCK_N=BLOCK_N2, 943 | BLOCK_HEADDIM=BLOCK_HEADDIM, 944 | EVEN_N=EVEN_N2, 945 | ) 946 | 947 | 948 | def _bwd_attention_kernel_call( 949 | dO: chex.Array, 950 | q: chex.Array, 951 | k: chex.Array, 952 | v: chex.Array, 953 | bias: tp.Optional[chex.Array], 954 | attention_mask: tp.Optional[chex.Array], 955 | o: chex.Array, 956 | M: chex.Array, 957 | dropout_prob: float, 958 | causal: bool, 959 | softmax_scale: tp.Optional[float], 960 | dropout_seed: tp.Optional[int], 961 | varlen_mode: bool, 962 | ): 963 | """Calls the Triton kernel for the backward pass of the attention mechanism. 964 | 965 | Args: 966 | softmax_scale: Scaling factor for the softmax function. 967 | residual: Residual from the forward pass. 968 | Do: Output gradient array. 969 | 970 | Returns: 971 | Tuple of the gradients of the query, key, value, and bias arrays. 972 | """ 973 | if attention_mask is not None and varlen_mode: 974 | assert bias is None, ( 975 | "Attention mask is not supported along with attention bias. Just use bias instead." 
976 | ) 977 | assert q.shape[1] == k.shape[1], "Attention mask is not supported with QSeq != KSeq" 978 | varlen_mode = attention_mask.shape[0] > 1 979 | useless_padding = attention_mask.shape[1] - attention_mask.sum(-1).max().item() 980 | if useless_padding > 0: 981 | dO = dO[:, :-useless_padding] 982 | q = q[:, :-useless_padding] 983 | k = k[:, :-useless_padding] 984 | v = v[:, :-useless_padding] 985 | attention_mask = attention_mask[:, :-useless_padding] 986 | o = o[:, :-useless_padding] 987 | else: 988 | varlen_mode = False 989 | useless_padding = 0 990 | 991 | batch_size, QSeq, nheads_q, head_dim = q.shape 992 | _, KSeq, nheads_kv, _ = k.shape 993 | max_seqlen_q_rounded = math.ceil(QSeq / 128) * 128 994 | softmax_scale = 1.0 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale 995 | assert nheads_q % nheads_kv == 0, f"{nheads_q = } is not divisible by {nheads_kv =}" 996 | assert M.shape == (batch_size, nheads_q, max_seqlen_q_rounded) 997 | 998 | if varlen_mode: 999 | cum_seqlens_q = jnp.zeros(shape=(attention_mask.shape[0] + 1,), dtype="i4") 1000 | cum_seqlens_k = jnp.zeros(shape=(attention_mask.shape[0] + 1,), dtype="i4") 1001 | cum_seqlens_k = cum_seqlens_k.at[1:].set( 1002 | jnp.cumsum( 1003 | attention_mask.sum(axis=1, dtype="i4"), 1004 | axis=0, 1005 | dtype="i4", 1006 | ) 1007 | ) 1008 | cum_seqlens_q = cum_seqlens_q.at[1:].set( 1009 | jnp.cumsum( 1010 | attention_mask.sum(axis=1, dtype="i4"), 1011 | axis=0, 1012 | dtype="i4", 1013 | ) 1014 | ) 1015 | max_seqlen_q: int = attention_mask.shape[1] 1016 | max_seqlen_k: int = attention_mask.shape[1] 1017 | 1018 | dO = attention_pack_with_static_shape(dO, attention_mask) 1019 | 1020 | q = attention_pack_with_static_shape(q, attention_mask) 1021 | k = attention_pack_with_static_shape(k, attention_mask) 1022 | v = attention_pack_with_static_shape(v, attention_mask) 1023 | o = attention_pack_with_static_shape(o, attention_mask) 1024 | QSeq = q.shape[1] 1025 | KSeq = k.shape[1] 1026 | else: 1027 | cum_seqlens_q = None 1028 | cum_seqlens_k = None 1029 | max_seqlen_q = QSeq 1030 | max_seqlen_k = KSeq 1031 | 1032 | bz, bh, bm = calc_bias_strides( 1033 | bias, 1034 | batch_size, 1035 | nheads_q, 1036 | QSeq, 1037 | KSeq, 1038 | ) 1039 | 1040 | num_repeats = nheads_q // nheads_kv 1041 | BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16) 1042 | 1043 | oz, om, oh, _ = get_strides(o) 1044 | doz, dom, doh, _ = get_strides(dO) 1045 | qz, qm, qh, _ = get_strides(q) 1046 | kz, kn, kh, _ = get_strides(k) 1047 | vz, vn, vh, _ = get_strides(v) 1048 | 1049 | (delta,) = triton_call( 1050 | o, 1051 | dO, 1052 | oz, 1053 | om, 1054 | oh, 1055 | doz, 1056 | dom, 1057 | doh, 1058 | nheads_q, 1059 | QSeq, 1060 | max_seqlen_q_rounded, 1061 | cum_seqlens_q if cum_seqlens_q is not None else jnp.array((1,), dtype=jnp.float16), 1062 | head_dim, 1063 | max_seqlen_q // 32, 1064 | dtype_index(q), 1065 | VARLEN=varlen_mode, 1066 | BLOCK_HEADDIM=BLOCK_HEADDIM, 1067 | out_shape=[ 1068 | jax.ShapeDtypeStruct( 1069 | shape=M.shape, 1070 | dtype="f4", 1071 | sharding=get_sharding(M), 1072 | ) 1073 | ], 1074 | grid=lambda META: ( 1075 | triton.cdiv(max_seqlen_q, META["BLOCK_M"]), 1076 | batch_size * nheads_q, 1077 | ), 1078 | kernel=_attn_bwd_preprocess, 1079 | name="triton::ops::_attn_bwd_preprocess", 1080 | ) 1081 | 1082 | dq, dk, dv = triton_call( 1083 | q, 1084 | k, 1085 | v, 1086 | bias if bias is not None else jnp.zeros((1,), jnp.float16), 1087 | dO, 1088 | M, 1089 | delta, 1090 | softmax_scale, 1091 | dropout_prob, 1092 | dropout_seed if 
dropout_seed is not None else jnp.zeros((1,), jnp.float16), 1093 | qz, 1094 | qm, 1095 | qh, 1096 | kz, 1097 | kn, 1098 | kh, 1099 | vz, 1100 | vn, 1101 | vh, 1102 | bz, 1103 | bm, 1104 | bh, 1105 | doz, 1106 | dom, 1107 | doh, 1108 | qz, 1109 | qm, 1110 | qh, 1111 | kz, 1112 | kn, 1113 | kh, 1114 | vz, 1115 | vn, 1116 | vh, 1117 | nheads_q, 1118 | num_repeats, 1119 | QSeq, 1120 | cum_seqlens_q if cum_seqlens_q is not None else jnp.zeros((1,), jnp.float16), 1121 | KSeq, 1122 | cum_seqlens_k if cum_seqlens_k is not None else jnp.zeros((1,), jnp.float16), 1123 | max_seqlen_q_rounded, 1124 | head_dim, 1125 | max_seqlen_q // 32, 1126 | max_seqlen_k // 32, 1127 | dtype_index(q), 1128 | VARLEN=varlen_mode, 1129 | IS_CAUSAL=causal, 1130 | BIAS_ON=(bias is not None), 1131 | USE_DROPOUT=(dropout_prob > 0), 1132 | BLOCK_HEADDIM=BLOCK_HEADDIM, 1133 | kernel=_attn_bwd, 1134 | grid=lambda META: ( 1135 | triton.cdiv(KSeq, META["BLOCK_N1"]) + triton.cdiv(QSeq, META["BLOCK_M2"]), 1136 | batch_size * nheads_q, 1137 | ), 1138 | out_shape=[ 1139 | jax.ShapeDtypeStruct( 1140 | shape=q.shape, 1141 | dtype="f4", 1142 | sharding=get_sharding(q), 1143 | ), 1144 | jax.ShapeDtypeStruct( 1145 | shape=(k.shape[0], k.shape[1], q.shape[2], k.shape[3]), 1146 | dtype=k.dtype, 1147 | ), 1148 | jax.ShapeDtypeStruct( 1149 | shape=(v.shape[0], v.shape[1], q.shape[2], v.shape[3]), 1150 | dtype=v.dtype, 1151 | ), 1152 | ], 1153 | name="triton::ops::_attn_bwd", 1154 | ) 1155 | 1156 | if num_repeats > 1: 1157 | dk = dk.reshape(dk.shape[0], dk.shape[1], nheads_kv, num_repeats, -1) 1158 | dk = jnp.sum(dk, axis=3) 1159 | 1160 | dv = dv.reshape(dv.shape[0], dv.shape[1], nheads_kv, num_repeats, -1) 1161 | dv = jnp.sum(dv, axis=3) 1162 | 1163 | if varlen_mode: 1164 | dq = attention_unpack_with_static_shape(dq, cum_seqlens_q, batch_size, max_seqlen_q) 1165 | dk = attention_unpack_with_static_shape(dk, cum_seqlens_k, batch_size, max_seqlen_k) 1166 | dv = attention_unpack_with_static_shape(dv, cum_seqlens_k, batch_size, max_seqlen_k) 1167 | 1168 | if useless_padding > 0: 1169 | dq = jnp.pad(dq, ((0, useless_padding), (0, 0), (0, 0))) 1170 | dk = jnp.pad(dk, ((0, useless_padding), (0, 0), (0, 0))) 1171 | dv = jnp.pad(dv, ((0, useless_padding), (0, 0), (0, 0))) 1172 | 1173 | return dq, dk, dv 1174 | -------------------------------------------------------------------------------- /jax_flash_attn2/flash_attention_triton/_flash_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi). 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
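A detail worth noting from the tail of `_bwd_attention_kernel_call` above: with grouped-query attention (`nheads_q > nheads_kv`) the kernel writes one dK/dV slice per query head, and the `num_repeats` slices that share a KV head are then summed back together. A small standalone illustration of that reduction (editor's sketch with hypothetical shapes, not part of the library API):

import jax.numpy as jnp

batch, kseq, nheads_q, head_dim = 2, 16, 8, 64
nheads_kv = 4
num_repeats = nheads_q // nheads_kv  # consecutive query heads share one KV head

# Per-query-head gradients, as produced by the kernel: [batch, kseq, nheads_q, head_dim]
dk_per_q_head = jnp.ones((batch, kseq, nheads_q, head_dim))

# Mirror the reshape/sum in the caller above: fold out the repeat axis and sum over it.
dk = dk_per_q_head.reshape(batch, kseq, nheads_kv, num_repeats, head_dim).sum(axis=3)
assert dk.shape == (batch, kseq, nheads_kv, head_dim)

The `(nheads_kv, num_repeats)` grouping matches `off_head_kv = off_head_q // num_repeats` in the kernel, i.e. consecutive query heads map onto the same KV head.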
14 | import functools 15 | import typing as tp 16 | import warnings 17 | 18 | import chex 19 | import jax 20 | 21 | from ._backward_triton import _bwd_attention_kernel_call 22 | from ._forward_triton import _fwd_attention_kernel_call 23 | 24 | DEV_MODE = True 25 | 26 | 27 | def _jax_fwd_attention_call( 28 | q: tp.Optional[chex.Array], 29 | k: tp.Optional[chex.Array], 30 | v: tp.Optional[chex.Array], 31 | attention_mask: tp.Optional[chex.Array] = None, 32 | bias: tp.Optional[chex.Array] = None, 33 | softmax_scale: tp.Optional[float] = None, 34 | dropout_prob: float = 0.0, 35 | causal: bool = False, 36 | dropout_seed: tp.Optional[int] = None, 37 | varlen_mode: bool = True, 38 | ): 39 | out, lse = _fwd_attention_kernel_call( 40 | q=q, 41 | k=k, 42 | v=v, 43 | attention_mask=attention_mask, 44 | bias=bias, 45 | softmax_scale=softmax_scale, 46 | dropout_prob=dropout_prob, 47 | causal=causal, 48 | dropout_seed=dropout_seed, 49 | varlen_mode=varlen_mode, 50 | ) 51 | residual = ( 52 | q, 53 | k, 54 | v, 55 | bias, 56 | attention_mask, 57 | out, 58 | lse, 59 | dropout_seed, 60 | ) 61 | return out, residual 62 | 63 | 64 | def _jax_bwd_attention_call( 65 | softmax_scale: tp.Optional[float], 66 | dropout_prob: float, 67 | causal: bool, 68 | varlen_mode: bool, 69 | residual: tp.Tuple[chex.Array], 70 | dO: chex.Array, 71 | ): 72 | q, k, v, bias, attention_mask, out, lse, dropout_seed = residual 73 | dq, dk, dv = _bwd_attention_kernel_call( 74 | dO=dO, 75 | q=q, 76 | k=k, 77 | v=v, 78 | bias=bias, 79 | attention_mask=attention_mask, 80 | o=out, 81 | M=lse, 82 | dropout_prob=dropout_prob, 83 | causal=causal, 84 | dropout_seed=dropout_seed, 85 | softmax_scale=softmax_scale, 86 | varlen_mode=varlen_mode, 87 | ) 88 | return dq, dk, dv, None, None, None 89 | 90 | 91 | @functools.partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 9)) 92 | @functools.partial(jax.jit, static_argnums=(5, 6, 7, 9)) 93 | def flash_attention_call( 94 | q: tp.Optional[chex.Array], 95 | k: tp.Optional[chex.Array], 96 | v: tp.Optional[chex.Array], 97 | attention_mask: tp.Optional[chex.Array] = None, 98 | bias: tp.Optional[chex.Array] = None, 99 | softmax_scale: tp.Optional[float] = None, 100 | dropout_prob: float = 0.0, 101 | causal: bool = False, 102 | dropout_seed: tp.Optional[int] = None, 103 | varlen_mode: bool = True, 104 | ) -> chex.Array: 105 | return _fwd_attention_kernel_call( 106 | q=q, 107 | k=k, 108 | v=v, 109 | attention_mask=attention_mask, 110 | bias=bias, 111 | softmax_scale=softmax_scale, 112 | dropout_prob=dropout_prob, 113 | causal=causal, 114 | dropout_seed=dropout_seed, 115 | varlen_mode=varlen_mode, 116 | )[0] 117 | 118 | 119 | flash_attention_call.defvjp( 120 | _jax_fwd_attention_call, 121 | _jax_bwd_attention_call, 122 | ) 123 | 124 | 125 | def flash_attention( 126 | q: tp.Optional[chex.Array], 127 | k: tp.Optional[chex.Array], 128 | v: tp.Optional[chex.Array], 129 | attention_mask: tp.Optional[chex.Array] = None, 130 | bias: tp.Optional[chex.Array] = None, 131 | softmax_scale: tp.Optional[float] = None, 132 | dropout_prob: float = 0.0, 133 | causal: bool = False, 134 | dropout_seed: tp.Optional[int] = None, 135 | varlen_mode: bool = True, 136 | ) -> chex.Array: 137 | # TODO: Debug Varlen Mode 138 | if attention_mask is not None and not DEV_MODE and bias is None: 139 | warnings.warn( 140 | "Varlen Mode and attention mask passing is still under development", 141 | stacklevel=1, 142 | ) 143 | 144 | attention_mask = None 145 | out = flash_attention_call( 146 | q=q, 147 | k=k, 148 | v=v, 149 | 
attention_mask=attention_mask, 150 | bias=bias, 151 | softmax_scale=softmax_scale, 152 | dropout_prob=dropout_prob, 153 | causal=causal, 154 | dropout_seed=dropout_seed, 155 | varlen_mode=varlen_mode, 156 | ) 157 | return out 158 | -------------------------------------------------------------------------------- /jax_flash_attn2/flash_attention_triton/_forward_triton.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi). 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import math 16 | import typing as tp 17 | 18 | import chex 19 | import jax 20 | import jax.numpy as jnp 21 | import triton 22 | import triton.language as tl 23 | from eformer.callib import triton_call 24 | from triton import Config 25 | 26 | from ._utils import ( 27 | dtype_index, 28 | get_sharding, 29 | get_strides, 30 | safe_autotune, 31 | attention_pack_with_static_shape, 32 | attention_unpack_with_static_shape, 33 | calc_bias_strides, 34 | padded_load, 35 | ) 36 | 37 | 38 | def config_prune_kernel( 39 | configs: tp.List[Config], 40 | named_args: tp.Dict[str, tp.Any], 41 | **kwargs, 42 | ) -> tp.List[Config]: 43 | kept_configs = [] 44 | for config in configs: 45 | largerst_m = config.kwargs["BLOCK_M"] > named_args["QSeq"] 46 | largerst_n = config.kwargs["BLOCK_N"] > named_args["KSeq"] 47 | if largerst_m or largerst_n: 48 | pass 49 | else: 50 | kept_configs.append(config) 51 | if kept_configs: 52 | return kept_configs 53 | return [ 54 | Config({"BLOCK_M": 16, "BLOCK_N": 64}, num_warps=4, num_stages=1), 55 | Config({"BLOCK_M": 16, "BLOCK_N": 64}, num_warps=4, num_stages=3), 56 | Config({"BLOCK_M": 16, "BLOCK_N": 64}, num_warps=4, num_stages=5), 57 | ] 58 | 59 | 60 | @triton.jit 61 | def _attn_fwd_inner( 62 | q, 63 | m_i, 64 | me_i, 65 | k_ptrs, 66 | v_ptrs, 67 | bias_ptrs, 68 | acc_o, 69 | offs_m, 70 | offs_n, 71 | offs_d, 72 | softmax_scale, 73 | dropout_prob, 74 | dropout_seed, 75 | dropout_offs, 76 | stride_kn, 77 | stride_vn, 78 | index_start_n, 79 | actual_seqlen_q, 80 | actual_seqlen_k, 81 | headdim, 82 | USE_DROPOUT: tl.constexpr, 83 | IS_CAUSAL: tl.constexpr, 84 | BIAS_ON: tl.constexpr, 85 | MASKED: tl.constexpr, 86 | PADDED_COLS: tl.constexpr, 87 | PADDED_HEADS: tl.constexpr, 88 | BLOCK_M: tl.constexpr, 89 | BLOCK_N: tl.constexpr, 90 | ): 91 | LN2: tl.constexpr = 1.44269504089 92 | index_start_n = tl.multiple_of(index_start_n, BLOCK_N) 93 | offset_k_ptrs = k_ptrs + index_start_n * stride_kn 94 | k = padded_load( 95 | offset_k_ptrs, 96 | index_start_n + offs_n, 97 | offs_d, 98 | PA0=PADDED_COLS, 99 | PA1=PADDED_HEADS, 100 | LA0=actual_seqlen_k, 101 | LA1=headdim, 102 | ) 103 | if BIAS_ON: 104 | if PADDED_COLS: 105 | bias = tl.load( 106 | bias_ptrs + index_start_n, 107 | mask=(offs_m[:, None] < actual_seqlen_q) 108 | & ((index_start_n + offs_n) < actual_seqlen_k)[None, :], 109 | other=0.0, 110 | ) 111 | else: 112 | bias = tl.load( 113 | bias_ptrs + index_start_n, 114 | 
mask=offs_m[:, None] < actual_seqlen_q, 115 | other=0.0, 116 | ) 117 | 118 | qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) 119 | qk += tl.dot(q, tl.trans(k)) 120 | if PADDED_COLS: 121 | qk += tl.where( 122 | (index_start_n + offs_n)[None, :] < actual_seqlen_k, 123 | 0, 124 | float("-inf"), 125 | ) 126 | 127 | if MASKED and IS_CAUSAL: 128 | causal_mask = ( 129 | offs_m[:, None] 130 | >= (index_start_n + offs_n - actual_seqlen_k + actual_seqlen_q)[None, :] 131 | ) 132 | qk += tl.where(causal_mask, 0, float("-inf")) 133 | 134 | if BIAS_ON: 135 | qk += bias * (LN2 / softmax_scale) 136 | m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, me_i) 137 | P_ij = tl.exp2(qk * softmax_scale - m_ij[:, None]) 138 | l_ij = tl.sum(P_ij, 1) 139 | 140 | if USE_DROPOUT: 141 | dropout_offs = dropout_offs + index_start_n 142 | dropout_mask = tl.rand(dropout_seed, dropout_offs) > dropout_prob 143 | P_ij = tl.where(dropout_mask, P_ij, 0.0) 144 | 145 | acc_o_scale = tl.exp2(m_i - m_ij) 146 | acc_o = acc_o * acc_o_scale[:, None] 147 | 148 | offset_v_ptrs = v_ptrs + index_start_n * stride_vn 149 | v = padded_load( 150 | offset_v_ptrs, 151 | index_start_n + offs_n, 152 | offs_d, 153 | PA0=PADDED_COLS, 154 | PA1=PADDED_HEADS, 155 | LA0=actual_seqlen_k, 156 | LA1=headdim, 157 | ) 158 | 159 | P_ij = P_ij.to(v.dtype) 160 | acc_o += tl.dot(P_ij, v) 161 | m_i = m_ij 162 | l_i_new = tl.exp2(me_i - m_ij) + l_ij 163 | me_i = m_ij + tl.log2(l_i_new) 164 | return m_i, me_i, acc_o 165 | 166 | 167 | @safe_autotune( 168 | configs=[ 169 | triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=4, num_stages=5), 170 | triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=5), 171 | triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=5), 172 | triton.Config({"BLOCK_M": 256, "BLOCK_N": 256}, num_warps=4, num_stages=5), 173 | triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=4, num_stages=3), 174 | triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=3), 175 | triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=3), 176 | triton.Config({"BLOCK_M": 256, "BLOCK_N": 256}, num_warps=4, num_stages=3), 177 | triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=4, num_stages=1), 178 | triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1), 179 | triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=1), 180 | triton.Config({"BLOCK_M": 256, "BLOCK_N": 256}, num_warps=4, num_stages=1), 181 | ### 182 | triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=8, num_stages=5), 183 | triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=8, num_stages=5), 184 | triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=8, num_stages=5), 185 | triton.Config({"BLOCK_M": 256, "BLOCK_N": 256}, num_warps=8, num_stages=5), 186 | triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=8, num_stages=3), 187 | triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=8, num_stages=3), 188 | triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=8, num_stages=3), 189 | triton.Config({"BLOCK_M": 256, "BLOCK_N": 256}, num_warps=8, num_stages=3), 190 | triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=8, num_stages=1), 191 | triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=8, num_stages=1), 192 | triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=8, num_stages=1), 193 | triton.Config({"BLOCK_M": 256, "BLOCK_N": 256}, num_warps=8, num_stages=1), 194 | ], 195 | key=[ 196 | "CKSeq", 197 | "CQSeq", 198 | "DRuntime", 199 | 
"VARLEN", 200 | "USE_DROPOUT", 201 | "IS_CAUSAL", 202 | "BIAS_ON", 203 | "BLOCK_HEADDIM", 204 | ], 205 | prune_configs_by={"early_config_prune": config_prune_kernel}, 206 | ) 207 | @triton.heuristics( 208 | { 209 | "EVEN_M": lambda args: args["QSeq"] % args["BLOCK_M"] == 0, 210 | "EVEN_N": lambda args: args["KSeq"] % args["BLOCK_N"] == 0, 211 | } 212 | ) 213 | @triton.jit 214 | def _attn_fwd( 215 | q, 216 | k, 217 | v, 218 | B, 219 | softmax_scale, 220 | dropout_prob, 221 | dropout_seed, 222 | stride_qz, 223 | stride_qm, 224 | stride_qh, 225 | stride_kz, 226 | stride_kn, 227 | stride_kh, 228 | stride_vz, 229 | stride_vn, 230 | stride_vh, 231 | stride_oz, 232 | stride_om, 233 | stride_oh, 234 | stride_bz, 235 | stride_bm, 236 | stride_bh, 237 | nheads_q, 238 | num_repeats, 239 | QSeq, 240 | cum_seqlens_q, 241 | KSeq, 242 | max_seqlen_q_rounded, 243 | headdim, 244 | CQSeq, 245 | CKSeq, 246 | DRuntime, 247 | Po, 248 | M, 249 | VARLEN: tl.constexpr, 250 | USE_DROPOUT: tl.constexpr, 251 | IS_CAUSAL: tl.constexpr, 252 | BIAS_ON: tl.constexpr, 253 | BLOCK_HEADDIM: tl.constexpr, 254 | PADDED_HEADS: tl.constexpr, 255 | EVEN_M: tl.constexpr, 256 | EVEN_N: tl.constexpr, 257 | BLOCK_M: tl.constexpr, 258 | BLOCK_N: tl.constexpr, 259 | ): 260 | LN2: tl.constexpr = 1.44269504089 261 | i_start_m = tl.program_id(0) 262 | off_zh = tl.program_id(1) 263 | off_head_q = off_zh % nheads_q 264 | off_head_kv = off_head_q // num_repeats 265 | off_z = off_zh // nheads_q 266 | 267 | if VARLEN: 268 | cu_seq_start_q = tl.load(cum_seqlens_q + off_z) 269 | actual_seqlen_q = tl.load(cum_seqlens_q + off_z + 1) - cu_seq_start_q 270 | if i_start_m * BLOCK_M >= actual_seqlen_q: 271 | return 272 | actual_seqlen_k = actual_seqlen_q 273 | cu_seq_start_k = cu_seq_start_q 274 | off_z = 0 275 | else: 276 | actual_seqlen_q = QSeq 277 | actual_seqlen_k = KSeq 278 | cu_seq_start_q = 0 279 | cu_seq_start_k = 0 280 | 281 | softmax_scale = softmax_scale * LN2 282 | 283 | offs_m = i_start_m * BLOCK_M + tl.arange(0, BLOCK_M) 284 | offs_n = tl.arange(0, BLOCK_N) 285 | offs_d = tl.arange(0, BLOCK_HEADDIM) 286 | 287 | fully_masked_lines = actual_seqlen_q - actual_seqlen_k if IS_CAUSAL else 0 288 | if fully_masked_lines >= (i_start_m + 1) * BLOCK_M: 289 | return 290 | 291 | q_ptrs = ( 292 | q 293 | + off_z * stride_qz 294 | + off_head_q * stride_qh 295 | + cu_seq_start_q * stride_qm 296 | + (offs_m[:, None] * stride_qm + offs_d[None, :]) 297 | ) 298 | 299 | k_ptrs = ( 300 | k 301 | + off_z * stride_kz 302 | + off_head_kv * stride_kh 303 | + cu_seq_start_k * stride_kn 304 | + (offs_n[:, None] * stride_kn + offs_d[None, :]) 305 | ) 306 | 307 | v_ptrs = ( 308 | v 309 | + off_z * stride_vz 310 | + off_head_kv * stride_vh 311 | + cu_seq_start_k * stride_vn 312 | + (offs_n[:, None] * stride_vn + offs_d[None, :]) 313 | ) 314 | 315 | if BIAS_ON: 316 | bias_ptrs = ( 317 | B 318 | + off_z * stride_bz 319 | + off_head_kv * stride_bh 320 | + cu_seq_start_q * stride_bm 321 | + (offs_m[:, None] * stride_bm + offs_n[None, :]) 322 | ) 323 | else: 324 | bias_ptrs = None 325 | if USE_DROPOUT: 326 | dropout_off = actual_seqlen_k * ( 327 | cu_seq_start_q + actual_seqlen_q * (off_head_q + nheads_q * off_z) 328 | ) 329 | dropout_offs = dropout_off + offs_m[:, None] * actual_seqlen_k + offs_n[None, :] 330 | else: 331 | dropout_offs = None 332 | 333 | me_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") 334 | m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") 335 | acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) 336 | 337 | pad_rows = (not 
EVEN_M) or (VARLEN and (i_start_m * BLOCK_M > actual_seqlen_q)) 338 | q = padded_load( 339 | q_ptrs, 340 | offs_m, 341 | offs_d, 342 | PA0=pad_rows, 343 | PA1=PADDED_HEADS, 344 | LA0=actual_seqlen_q, 345 | LA1=headdim, 346 | ) 347 | if IS_CAUSAL: 348 | end_n = min( 349 | actual_seqlen_k - actual_seqlen_q + (i_start_m + 1) * BLOCK_M, 350 | actual_seqlen_k, 351 | ) 352 | if end_n < 0: 353 | return 354 | else: 355 | end_n = actual_seqlen_k 356 | 357 | uneven_n = actual_seqlen_k % BLOCK_N != 0 358 | attention_padding = VARLEN & uneven_n 359 | if IS_CAUSAL: 360 | first_masked_col = i_start_m * BLOCK_M + 1 + actual_seqlen_k - actual_seqlen_q 361 | elif attention_padding: 362 | first_masked_col = actual_seqlen_k 363 | else: 364 | first_masked_col = end_n 365 | nb_full_blocks = first_masked_col // BLOCK_N 366 | 367 | next_start_n = 0 368 | if nb_full_blocks > 0: 369 | for _ in range(0, nb_full_blocks): 370 | m_i, me_i, acc_o = _attn_fwd_inner( 371 | q, 372 | m_i, 373 | me_i, 374 | k_ptrs, 375 | v_ptrs, 376 | bias_ptrs, 377 | acc_o, 378 | offs_m, 379 | offs_n, 380 | offs_d, 381 | softmax_scale, 382 | dropout_prob, 383 | dropout_seed, 384 | dropout_offs, 385 | stride_kn, 386 | stride_vn, 387 | next_start_n, 388 | actual_seqlen_q, 389 | actual_seqlen_k, 390 | headdim, 391 | USE_DROPOUT=USE_DROPOUT, 392 | IS_CAUSAL=IS_CAUSAL, 393 | BIAS_ON=BIAS_ON, 394 | MASKED=False, 395 | PADDED_COLS=False, 396 | PADDED_HEADS=PADDED_HEADS, 397 | BLOCK_M=BLOCK_M, 398 | BLOCK_N=BLOCK_N, 399 | ) 400 | next_start_n += BLOCK_N 401 | if next_start_n < end_n: 402 | for index_start_n in range(next_start_n, end_n, BLOCK_N): 403 | pad_cols = (not EVEN_N) or VARLEN 404 | m_i, me_i, acc_o = _attn_fwd_inner( 405 | q, 406 | m_i, 407 | me_i, 408 | k_ptrs, 409 | v_ptrs, 410 | bias_ptrs, 411 | acc_o, 412 | offs_m, 413 | offs_n, 414 | offs_d, 415 | softmax_scale, 416 | dropout_prob, 417 | dropout_seed, 418 | dropout_offs, 419 | stride_kn, 420 | stride_vn, 421 | index_start_n, 422 | actual_seqlen_q, 423 | actual_seqlen_k, 424 | headdim, 425 | USE_DROPOUT=USE_DROPOUT, 426 | IS_CAUSAL=IS_CAUSAL, 427 | BIAS_ON=BIAS_ON, 428 | MASKED=True, 429 | PADDED_COLS=pad_cols, 430 | PADDED_HEADS=PADDED_HEADS, 431 | BLOCK_M=BLOCK_M, 432 | BLOCK_N=BLOCK_N, 433 | ) 434 | 435 | if USE_DROPOUT: 436 | o_scale = tl.exp2((m_i - me_i) - tl.log2(1 - dropout_prob)) 437 | else: 438 | o_scale = tl.exp2(m_i - me_i) 439 | acc_o = acc_o * o_scale[:, None] 440 | if fully_masked_lines > i_start_m * BLOCK_M: 441 | acc_o = tl.where(offs_m[:, None] < fully_masked_lines, 0, acc_o) 442 | i_start_m = tl.program_id(0) 443 | offs_m = i_start_m * BLOCK_M + tl.arange(0, BLOCK_M) 444 | lse_ptrs = M + off_zh * max_seqlen_q_rounded + offs_m 445 | tl.store(lse_ptrs, me_i) 446 | offs_d = tl.arange(0, BLOCK_HEADDIM) 447 | out_ptrs = ( 448 | Po 449 | + off_z * stride_oz 450 | + off_head_q * stride_oh 451 | + cu_seq_start_q * stride_om 452 | + (offs_m[:, None] * stride_om + offs_d[None, :]) 453 | ) 454 | 455 | tl.store( 456 | out_ptrs, 457 | acc_o, 458 | mask=(offs_m[:, None] < actual_seqlen_q) & (offs_d[None, :] < headdim), 459 | ) 460 | 461 | 462 | def _fwd_attention_kernel_call( 463 | q: tp.Optional[chex.Array], 464 | k: tp.Optional[chex.Array], 465 | v: tp.Optional[chex.Array], 466 | attention_mask: tp.Optional[chex.Array] = None, 467 | bias: tp.Optional[chex.Array] = None, 468 | softmax_scale: tp.Optional[float] = None, 469 | dropout_prob: float = 0.0, 470 | causal: bool = False, 471 | dropout_seed: tp.Optional[int] = None, 472 | varlen_mode: bool = True, 473 | ): 474 | if 
attention_mask is not None and varlen_mode: 475 | varlen_mode = attention_mask.shape[0] > 1 476 | assert bias is None, ( 477 | "Attention mask is not supported along with attention bias. Just use bias instead." 478 | ) 479 | assert q.shape[1] == k.shape[1], "Attention mask is not supported with QSeq != KSeq" 480 | else: 481 | varlen_mode = False 482 | batch, QSeq, nheads_q, head_dim = q.shape 483 | _, KSeq, nheads_kv, _ = k.shape 484 | expected_kv_shape = (batch, KSeq, nheads_kv, head_dim) 485 | 486 | assert k.shape == expected_kv_shape, ( 487 | f"key shape is {k.shape = } and we expected it to be {expected_kv_shape = }" 488 | ) 489 | assert v.shape == expected_kv_shape, ( 490 | f"value shape is {v.shape = } and we expected it to be {expected_kv_shape = }" 491 | ) 492 | 493 | assert nheads_q % nheads_kv == 0, f"{nheads_q = } is not divisible by {nheads_kv =}" 494 | assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype" 495 | assert q.dtype in [jnp.float16, jnp.bfloat16], "Only fp16 and bf16 are supported" 496 | 497 | softmax_scale = 1.0 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale 498 | 499 | varlen_mode = varlen_mode and (batch > 1) 500 | if varlen_mode: 501 | cum_seqlens_q = jnp.zeros(shape=(attention_mask.shape[0] + 1,), dtype=jnp.int32) 502 | cum_seqlens_q = cum_seqlens_q.at[1:].set( 503 | jnp.cumsum(attention_mask.sum(axis=1, dtype="i4"), axis=0, dtype="i4") 504 | ) 505 | max_seqlen_q = attention_mask.shape[1] 506 | max_seqlen_k = attention_mask.shape[1] 507 | q = attention_pack_with_static_shape(q, attention_mask) 508 | k = attention_pack_with_static_shape(k, attention_mask) 509 | v = attention_pack_with_static_shape(v, attention_mask) 510 | QSeq = q.shape[1] 511 | else: 512 | cum_seqlens_q = None 513 | max_seqlen_q = QSeq 514 | max_seqlen_k = KSeq 515 | 516 | bz, bh, bm = calc_bias_strides( 517 | bias, 518 | batch, 519 | nheads_q, 520 | QSeq, 521 | KSeq, 522 | ) 523 | 524 | max_seqlen_q_rounded = math.ceil(max_seqlen_q / 128) * 128 525 | BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16) 526 | PADDED_HEADS = BLOCK_HEADDIM > head_dim 527 | num_repeats = nheads_q // nheads_kv 528 | 529 | qz, qm, qh, _ = get_strides(q.shape) 530 | oz, om, oh, _ = get_strides(q.shape) 531 | kz, kn, kh, _ = get_strides(k.shape) 532 | vz, vn, vh, _ = get_strides(v.shape) 533 | 534 | metaparams = dict( 535 | VARLEN=varlen_mode, 536 | USE_DROPOUT=(dropout_prob > 0), 537 | IS_CAUSAL=causal, 538 | BIAS_ON=(bias is not None), 539 | BLOCK_HEADDIM=BLOCK_HEADDIM, 540 | PADDED_HEADS=PADDED_HEADS, 541 | ) 542 | out_shape = [ 543 | jax.ShapeDtypeStruct(q.shape, q.dtype, sharding=get_sharding(q)), 544 | jax.ShapeDtypeStruct((batch, nheads_q, max_seqlen_q_rounded), jnp.float32), 545 | ] 546 | 547 | out, lse = triton_call( 548 | q, 549 | k, 550 | v, 551 | bias if bias is not None else jnp.zeros((1,), jnp.float16), 552 | softmax_scale, 553 | dropout_prob, 554 | dropout_seed if dropout_seed is not None else jnp.zeros((1,), jnp.float16), 555 | qz, 556 | qm, 557 | qh, 558 | kz, 559 | kn, 560 | kh, 561 | vz, 562 | vn, 563 | vh, 564 | oz, 565 | om, 566 | oh, 567 | bz, 568 | bm, 569 | bh, 570 | nheads_q, 571 | num_repeats, 572 | QSeq, 573 | cum_seqlens_q if cum_seqlens_q is not None else jnp.zeros((1,), jnp.float16), 574 | KSeq, 575 | max_seqlen_q_rounded, 576 | head_dim, 577 | max_seqlen_q // 128, 578 | max_seqlen_k // 128, 579 | dtype_index(q), 580 | kernel=_attn_fwd, 581 | out_shape=out_shape, 582 | grid=lambda META: ( 583 | triton.cdiv(max_seqlen_q, META["BLOCK_M"]),
584 | batch * nheads_q, 585 | ), 586 | name="triton::ops::_attn_fwd", 587 | **metaparams, 588 | ) 589 | 590 | if varlen_mode: 591 | out = attention_unpack_with_static_shape(out, cum_seqlens_q, *attention_mask.shape) 592 | return out, lse 593 | -------------------------------------------------------------------------------- /jax_flash_attn2/flash_attention_triton/_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi). 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import math 17 | import typing as tp 18 | from functools import partial 19 | 20 | import chex 21 | import jax 22 | import jax.numpy as jnp 23 | import numpy 24 | import triton 25 | import triton.language as tl 26 | 27 | F = tp.TypeVar("F", bound=tp.Callable[..., tp.Any]) 28 | 29 | 30 | def safe_autotune( 31 | configs, 32 | key, 33 | prune_configs_by=None, 34 | reset_to_zero=None, 35 | restore_value=None, 36 | pre_hook=None, 37 | post_hook=None, 38 | warmup=None, 39 | rep=None, 40 | use_cuda_graph=False, 41 | do_bench=None, 42 | ) -> tp.Callable[[F], F]: 43 | """ 44 | Applies `triton.autotune` safely. Falls back to the original function if autotuning fails. 45 | """ 46 | try: 47 | from triton.runtime.autotuner import Autotuner 48 | 49 | def decorator(fn): 50 | return Autotuner( 51 | fn, 52 | fn.arg_names, 53 | configs, 54 | key, 55 | reset_to_zero, 56 | restore_value, 57 | pre_hook=pre_hook, 58 | post_hook=post_hook, 59 | prune_configs_by=prune_configs_by, 60 | warmup=warmup, 61 | rep=rep, 62 | use_cuda_graph=use_cuda_graph, 63 | ) 64 | 65 | return decorator 66 | except Exception as err: 67 | print(f"Couldn't autotune given function due to {err}") 68 | 69 | def decorator(fn): 70 | return fn 71 | 72 | return decorator 73 | 74 | 75 | def dtype_index(x: jnp.array) -> int: 76 | if x.dtype == jnp.float16: 77 | return 1 78 | if x.dtype == jnp.bfloat16: 79 | return 2 80 | if x.dtype == jnp.float32: 81 | return 3 82 | raise ValueError(x.dtype) 83 | 84 | 85 | def get_sharding(arr: chex.Array): 86 | """Gets the sharding of an array. 87 | 88 | Args: 89 | arr: Array to get sharding from. 90 | 91 | Returns: 92 | Sharding of the array. 93 | """ 94 | return getattr(arr, "sharding", None) 95 | 96 | 97 | def get_strides(shape: tuple[int, ...]) -> tuple[int, ...]: 98 | """Calculates strides for a given shape. 99 | 100 | Args: 101 | shape: Shape of the array. 102 | 103 | Returns: 104 | Tuple of strides. 
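    Example (element strides of a C-contiguous array):
        get_strides((2, 4, 8)) -> (32, 8, 1)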
105 | """ 106 | if hasattr(shape, "shape"): 107 | shape = shape.shape 108 | size = numpy.prod(shape) 109 | strides = [] 110 | for s in shape: 111 | size = int(size // s) 112 | strides.append(size) 113 | return tuple(strides) 114 | 115 | 116 | @triton.jit 117 | def padded_load( 118 | ptrs, 119 | offs_a, 120 | offs_b, 121 | PA0: tl.constexpr, 122 | PA1: tl.constexpr, 123 | LA0: tl.constexpr, 124 | LA1: tl.constexpr, 125 | ): 126 | if PA0: 127 | if PA1: 128 | x = tl.load( 129 | ptrs, 130 | mask=(offs_a[:, None] < LA0) & (offs_b[None, :] < LA1), 131 | other=0.0, 132 | ) 133 | else: 134 | x = tl.load( 135 | ptrs, 136 | mask=offs_a[:, None] < LA0, 137 | other=0.0, 138 | ) 139 | else: 140 | if PA1: 141 | x = tl.load( 142 | ptrs, 143 | mask=offs_b[None, :] < LA1, 144 | other=0.0, 145 | ) 146 | else: 147 | x = tl.load(ptrs) 148 | return x 149 | 150 | 151 | def calc_bias_strides( 152 | bias: tp.Optional[jnp.ndarray], 153 | batch: int, 154 | nheads_q: int, 155 | QSeq: int, 156 | KSeq: int, 157 | ) -> tp.Tuple[int, ...]: 158 | if bias is not None: 159 | if not hasattr(bias, "strides"): 160 | strides = tuple(map(lambda x: x * bias.itemsize, get_strides(bias))) 161 | else: 162 | strides = bias.strides 163 | if bias.shape[2] != QSeq or bias.shape[3] != KSeq: 164 | raise ValueError( 165 | f"Bias tensor has incompatible sequence dimensions. " 166 | f"Expected shape [..., {QSeq}, {KSeq}], but got [..., {bias.shape[2]}, {bias.shape[3]}]. " 167 | f"Full bias shape: {bias.shape}" 168 | ) 169 | if bias.shape[0] == 1: 170 | stride_bz = 0 171 | elif bias.shape[0] == batch: 172 | stride_bz = strides[0] // bias.itemsize 173 | else: 174 | raise ValueError( 175 | f"Batch dimension mismatch in bias tensor. " 176 | f"Expected either 1 (for broadcasting) or {batch} (batch size), " 177 | f"but got {bias.shape[0]}. Consider reshaping your bias tensor." 178 | ) 179 | if bias.shape[1] == 1: 180 | stride_bh = 0 181 | elif bias.shape[1] == nheads_q: 182 | stride_bh = strides[1] // bias.itemsize 183 | else: 184 | raise ValueError( 185 | f"Head dimension mismatch in bias tensor. " 186 | f"Expected either 1 (for broadcasting) or {nheads_q} (number of heads), " 187 | f"but got {bias.shape[1]}. Check that your bias tensor matches the model configuration." 188 | ) 189 | 190 | stride_bm = strides[2] // bias.itemsize 191 | else: 192 | stride_bz, stride_bh, stride_bm = 0, 0, 0 193 | return stride_bz, stride_bh, stride_bm 194 | 195 | 196 | @partial(jax.jit, static_argnames=["max_tokens"]) 197 | def attention_pack_with_static_shape( 198 | x: jnp.ndarray, 199 | attention_mask: jnp.ndarray, 200 | max_tokens: int = None, 201 | ) -> jnp.ndarray: 202 | """ 203 | Pack attention tensor by removing padding based on attention mask. 204 | Uses a static maximum shape to be compatible with JIT. 
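    Args:
        x: Tensor of shape [batch_size, seqlen, num_heads, head_dim].
        attention_mask: Mask of shape [batch_size, seqlen] marking valid tokens.
        max_tokens: Static size of the packed token axis (defaults to batch_size * seqlen).

    Returns:
        Packed tensor of shape [1, max_tokens, num_heads, head_dim].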
205 | """ 206 | batch_size, seqlen = attention_mask.shape 207 | num_heads, head_dim = x.shape[2], x.shape[3] 208 | 209 | if max_tokens is None: 210 | max_tokens = batch_size * seqlen 211 | 212 | seqlens = jnp.sum(attention_mask, axis=1).astype(jnp.int32) 213 | offsets = jnp.zeros((batch_size,), dtype=jnp.int32) 214 | offsets = offsets.at[1:].set(jnp.cumsum(seqlens[:-1])) 215 | packed = jnp.zeros((1, max_tokens, num_heads, head_dim), dtype=x.dtype) 216 | batch_idx, pos_idx = jnp.meshgrid( 217 | jnp.arange(batch_size), jnp.arange(seqlen), indexing="ij" 218 | ) 219 | 220 | batch_idx_flat = batch_idx.reshape(-1) 221 | pos_idx_flat = pos_idx.reshape(-1) 222 | 223 | valid_mask = pos_idx < seqlens[:, None] 224 | target_idx = jnp.where( 225 | valid_mask, 226 | offsets[:, None] + pos_idx, 227 | jnp.zeros_like(pos_idx), 228 | ) 229 | target_idx_flat = target_idx.reshape(-1) 230 | valid_mask_flat = valid_mask.reshape(-1) 231 | 232 | def process_token(i, packed_acc): 233 | b = batch_idx_flat[i] 234 | p = pos_idx_flat[i] 235 | t = target_idx_flat[i] 236 | valid = valid_mask_flat[i] 237 | packed_acc = jnp.where(valid, packed_acc.at[0, t].set(x[b, p]), packed_acc) 238 | 239 | return packed_acc 240 | 241 | packed = jax.lax.fori_loop(0, batch_size * seqlen, process_token, packed) 242 | return packed 243 | 244 | 245 | @partial(jax.jit, static_argnames=["seqlen", "batch_size"]) 246 | def attention_unpack_with_static_shape( 247 | x: jnp.ndarray, 248 | cum_seqlens: jnp.ndarray, 249 | batch_size: int, 250 | seqlen: int, 251 | ) -> jnp.ndarray: 252 | """ 253 | Unpack attention tensor by redistributing the packed values to their original positions. 254 | 255 | Args: 256 | x: Packed tensor of shape [1, packed_tokens, num_heads, head_dim] 257 | cum_seqlens: Cumulative sequence lengths, shape [batch_size+1] 258 | batch_size: Number of batches 259 | seqlen: Maximum sequence length 260 | 261 | Returns: 262 | Unpacked tensor of shape [batch_size, seqlen, num_heads, head_dim] 263 | """ 264 | num_heads, head_dim = x.shape[2], x.shape[3] 265 | 266 | # Create output with static shape 267 | unpacked = jnp.zeros((batch_size, seqlen, num_heads, head_dim), dtype=x.dtype) 268 | 269 | # Process each batch 270 | def process_batch(b, unpacked_acc): 271 | start_idx = cum_seqlens[b] 272 | end_idx = cum_seqlens[b + 1] 273 | seq_len = end_idx - start_idx 274 | 275 | # Process each position in the sequence 276 | def process_position(p, acc): 277 | # Only copy if within valid sequence length 278 | valid = p < seq_len 279 | src_idx = start_idx + p 280 | 281 | # Update conditionally 282 | acc = jnp.where(valid, acc.at[b, p].set(x[0, src_idx]), acc) 283 | 284 | return acc 285 | 286 | # Process all positions in this batch 287 | unpacked_acc = jax.lax.fori_loop(0, seqlen, process_position, unpacked_acc) 288 | 289 | return unpacked_acc 290 | 291 | # Process all batches 292 | unpacked = jax.lax.fori_loop(0, batch_size, process_batch, unpacked) 293 | 294 | return unpacked 295 | 296 | 297 | def basic_attention_refrence( 298 | q: jnp.ndarray, 299 | k: jnp.ndarray, 300 | v: jnp.ndarray, 301 | attn_bias: tp.Optional[jnp.ndarray] = None, 302 | query_padding_mask: tp.Optional[jnp.ndarray] = None, 303 | key_padding_mask: tp.Optional[jnp.ndarray] = None, 304 | dropout_prob: float = 0.0, 305 | dropout_key: tp.Optional[jax.random.PRNGKey] = None, 306 | window_size: tp.Tuple[int, int] = (-1, -1), 307 | causal: bool = False, 308 | softcap: float = 0.0, 309 | ): 310 | if causal: 311 | window_size = (window_size[0], 0) 312 | dtype_og = q.dtype 313 | q, k, v 
= q.astype(jnp.float32), k.astype(jnp.float32), v.astype(jnp.float32) 314 | QSeq, KSeq = q.shape[1], k.shape[1] 315 | repeats = q.shape[2] // k.shape[2] 316 | if repeats > 1: 317 | k = jnp.repeat(k, repeats=repeats, axis=2) 318 | v = jnp.repeat(v, repeats=repeats, axis=2) 319 | d = q.shape[-1] 320 | q_scaled = q / math.sqrt(d) 321 | scores = jnp.einsum("bthd,bshd->bhts", q_scaled, k) 322 | if softcap > 0: 323 | scores = scores / softcap 324 | scores = jnp.tanh(scores) 325 | scores = scores * softcap 326 | if key_padding_mask is not None: 327 | key_mask = (~key_padding_mask).reshape(key_padding_mask.shape[0], 1, 1, KSeq) 328 | scores = jnp.where(key_mask, jnp.finfo(scores.dtype).min, scores) 329 | if window_size[0] >= 0 or window_size[1] >= 0: 330 | row_idx = jnp.arange(QSeq).reshape(-1, 1) 331 | col_idx = jnp.arange(KSeq) 332 | if key_padding_mask is None: 333 | sk = KSeq 334 | else: 335 | sk = jnp.sum(key_padding_mask, axis=-1).reshape(-1, 1, 1, 1) 336 | if query_padding_mask is None: 337 | sq = QSeq 338 | else: 339 | sq = jnp.sum(query_padding_mask, axis=-1).reshape(-1, 1, 1, 1) 340 | if window_size[0] < 0: 341 | local_mask = col_idx > row_idx + sk - sq + window_size[1] 342 | else: 343 | if key_padding_mask is None: 344 | sk_full = jnp.full_like(col_idx, KSeq) 345 | else: 346 | sk_full = sk 347 | local_mask = jnp.logical_or( 348 | col_idx > jnp.minimum(row_idx + sk - sq + window_size[1], sk_full), 349 | col_idx < row_idx + sk - sq - window_size[0], 350 | ) 351 | scores = jnp.where(local_mask, jnp.finfo(scores.dtype).min, scores) 352 | if attn_bias is not None: 353 | scores = scores + attn_bias 354 | attention = jax.nn.softmax(scores, axis=-1).astype(v.dtype) 355 | if window_size[0] >= 0 or window_size[1] >= 0: 356 | all_masked = jnp.all(local_mask, axis=-1, keepdims=True) 357 | attention = jnp.where(all_masked, 0.0, attention) 358 | if query_padding_mask is not None: 359 | query_mask = (~query_padding_mask).reshape(query_padding_mask.shape[0], 1, QSeq, 1) 360 | attention = jnp.where(query_mask, 0.0, attention) 361 | dropout_scaling = 1.0 / (1 - dropout_prob) 362 | if dropout_prob > 0 and dropout_key is not None: 363 | dropout_mask = jax.random.bernoulli( 364 | dropout_key, p=1 - dropout_prob, shape=attention.shape 365 | ) 366 | attention_drop = attention * dropout_mask * dropout_scaling 367 | else: 368 | attention_drop = attention 369 | output = jnp.einsum("bhts,bshd->bthd", attention_drop, v) 370 | if query_padding_mask is not None: 371 | query_mask_expanded = (~query_padding_mask).reshape( 372 | query_padding_mask.shape[0], 373 | QSeq, 374 | 1, 375 | 1, 376 | ) 377 | output = jnp.where(query_mask_expanded, 0.0, output) 378 | return output.astype(dtype_og) 379 | -------------------------------------------------------------------------------- /jax_flash_attn2/refrence_call.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi). 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import math 17 | import typing as tp 18 | 19 | import jax 20 | import jax.numpy as jnp 21 | 22 | 23 | def basic_attention_refrence( 24 | q: jnp.ndarray, 25 | k: jnp.ndarray, 26 | v: jnp.ndarray, 27 | attn_bias: tp.Optional[jnp.ndarray] = None, 28 | query_padding_mask: tp.Optional[jnp.ndarray] = None, 29 | key_padding_mask: tp.Optional[jnp.ndarray] = None, 30 | dropout_prob: float = 0.0, 31 | dropout_key: tp.Optional[jax.random.PRNGKey] = None, 32 | window_size: tp.Tuple[int, int] = (-1, -1), 33 | causal: bool = False, 34 | softcap: float = 0.0, 35 | ): 36 | if causal: 37 | window_size = (window_size[0], 0) 38 | dtype_og = q.dtype 39 | q, k, v = q.astype(jnp.float32), k.astype(jnp.float32), v.astype(jnp.float32) 40 | QSeq, KSeq = q.shape[1], k.shape[1] 41 | repeats = q.shape[2] // k.shape[2] 42 | if repeats > 1: 43 | k = jnp.repeat(k, repeats=repeats, axis=2) 44 | v = jnp.repeat(v, repeats=repeats, axis=2) 45 | d = q.shape[-1] 46 | q_scaled = q / math.sqrt(d) 47 | scores = jnp.einsum("bthd,bshd->bhts", q_scaled, k) 48 | if softcap > 0: 49 | scores = scores / softcap 50 | scores = jnp.tanh(scores) 51 | scores = scores * softcap 52 | if key_padding_mask is not None: 53 | key_mask = (~key_padding_mask).reshape(key_padding_mask.shape[0], 1, 1, KSeq) 54 | scores = jnp.where(key_mask, jnp.finfo(scores.dtype).min, scores) 55 | if window_size[0] >= 0 or window_size[1] >= 0: 56 | row_idx = jnp.arange(QSeq).reshape(-1, 1) 57 | col_idx = jnp.arange(KSeq) 58 | if key_padding_mask is None: 59 | sk = KSeq 60 | else: 61 | sk = jnp.sum(key_padding_mask, axis=-1).reshape(-1, 1, 1, 1) 62 | if query_padding_mask is None: 63 | sq = QSeq 64 | else: 65 | sq = jnp.sum(query_padding_mask, axis=-1).reshape(-1, 1, 1, 1) 66 | if window_size[0] < 0: 67 | local_mask = col_idx > row_idx + sk - sq + window_size[1] 68 | else: 69 | if key_padding_mask is None: 70 | sk_full = jnp.full_like(col_idx, KSeq) 71 | else: 72 | sk_full = sk 73 | local_mask = jnp.logical_or( 74 | col_idx > jnp.minimum(row_idx + sk - sq + window_size[1], sk_full), 75 | col_idx < row_idx + sk - sq - window_size[0], 76 | ) 77 | scores = jnp.where(local_mask, jnp.finfo(scores.dtype).min, scores) 78 | if attn_bias is not None: 79 | scores = scores + attn_bias 80 | attention = jax.nn.softmax(scores, axis=-1).astype(v.dtype) 81 | if window_size[0] >= 0 or window_size[1] >= 0: 82 | all_masked = jnp.all(local_mask, axis=-1, keepdims=True) 83 | attention = jnp.where(all_masked, 0.0, attention) 84 | if query_padding_mask is not None: 85 | query_mask = (~query_padding_mask).reshape(query_padding_mask.shape[0], 1, QSeq, 1) 86 | attention = jnp.where(query_mask, 0.0, attention) 87 | dropout_scaling = 1.0 / (1 - dropout_prob) 88 | if dropout_prob > 0 and dropout_key is not None: 89 | dropout_mask = jax.random.bernoulli( 90 | dropout_key, p=1 - dropout_prob, shape=attention.shape 91 | ) 92 | attention_drop = attention * dropout_mask * dropout_scaling 93 | else: 94 | attention_drop = attention 95 | output = jnp.einsum("bhts,bshd->bthd", attention_drop, v) 96 | if query_padding_mask is not None: 97 | query_mask_expanded = (~query_padding_mask).reshape( 98 | query_padding_mask.shape[0], 99 | QSeq, 100 | 1, 101 | 1, 102 | ) 103 | output = jnp.where(query_mask_expanded, 0.0, output) 104 | return output.astype(dtype_og) 105 | -------------------------------------------------------------------------------- /jax_flash_attn2/utils.py: 
-------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from jax import random as jrandom 4 | from jax.interpreters import pxla 5 | from jax.sharding import PartitionSpec 6 | 7 | 8 | class GenerateRNG: 9 | """An infinite generator of JAX PRNGKeys, useful for iterating over seeds.""" 10 | 11 | def __init__(self, seed: int = 0): 12 | """Initializes the generator with a starting seed. 13 | 14 | Args: 15 | seed: The seed to use for the initial PRNGKey. 16 | """ 17 | self.seed = seed 18 | self._rng = jrandom.PRNGKey(seed) 19 | 20 | def __next__(self) -> jrandom.PRNGKey: 21 | """Generates and returns the next PRNGKey in the sequence. 22 | 23 | Returns: 24 | The next PRNGKey derived from the internal state. 25 | """ 26 | self._rng, key = jrandom.split(self._rng) 27 | return key 28 | 29 | @property 30 | def rng(self) -> jrandom.PRNGKey: 31 | """Provides access to the next PRNGKey without advancing the generator. 32 | 33 | Returns: 34 | The next PRNGKey in the sequence. 35 | """ 36 | return next(self) 37 | 38 | 39 | def get_logger(name, level: int = logging.INFO) -> logging.Logger: 40 | """ 41 | Function to create and configure a logger. 42 | Args: 43 | name: str: The name of the logger. 44 | level: int: The logging level. Defaults to logging.INFO. 45 | Returns: 46 | logging.Logger: The configured logger instance. 47 | """ 48 | logger = logging.getLogger(name) 49 | logger.propagate = False 50 | logger.setLevel(level) 51 | console_handler = logging.StreamHandler() 52 | console_handler.setLevel(level) 53 | formatter = logging.Formatter("%(asctime)s %(levelname)-8s [%(name)s] %(message)s") 54 | console_handler.setFormatter(formatter) 55 | logger.addHandler(console_handler) 56 | return logger 57 | 58 | 59 | def names_in_current_mesh(*names: str) -> bool: 60 | """ 61 | Check if the given names are present in the current JAX mesh. 62 | 63 | Args: 64 | *names: Variable number of axis names to check. 65 | 66 | Returns: 67 | True if all given names are present in the current mesh, False otherwise. 68 | """ 69 | mesh_axis_names = pxla.thread_resources.env.physical_mesh.axis_names 70 | return set(names) <= set(mesh_axis_names) 71 | 72 | 73 | def get_names_from_partition_spec( 74 | partition_specs: dict[str, PartitionSpec], 75 | ) -> list[str]: 76 | """ 77 | Extract axis names from a partition specification. 78 | 79 | This function recursively iterates through the provided `partition_specs` 80 | dictionary and extracts all unique axis names used in the sharding specifications. 81 | 82 | Args: 83 | partition_specs: A dictionary mapping parameter names to their respective `PartitionSpec`. 84 | 85 | Returns: 86 | A list of unique axis names used in the partition specs. 87 | """ 88 | names = set() 89 | if isinstance(partition_specs, dict): 90 | partition_specs = partition_specs.values() 91 | for item in partition_specs: 92 | if item is None: 93 | continue 94 | elif isinstance(item, str): 95 | names.add(item) 96 | else: 97 | names.update(get_names_from_partition_spec(item)) 98 | return list(names) 99 | -------------------------------------------------------------------------------- /poetry.lock: -------------------------------------------------------------------------------- 1 | # This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. 2 | 3 | [[package]] 4 | name = "absl-py" 5 | version = "2.1.0" 6 | description = "Abseil Python Common Libraries, see https://github.com/abseil/abseil-py." 
7 | optional = false 8 | python-versions = ">=3.7" 9 | files = [ 10 | {file = "absl-py-2.1.0.tar.gz", hash = "sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff"}, 11 | {file = "absl_py-2.1.0-py3-none-any.whl", hash = "sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308"}, 12 | ] 13 | 14 | [[package]] 15 | name = "chex" 16 | version = "0.1.87" 17 | description = "Chex: Testing made fun, in JAX!" 18 | optional = false 19 | python-versions = ">=3.9" 20 | files = [ 21 | {file = "chex-0.1.87-py3-none-any.whl", hash = "sha256:ce536475661fd96d21be0c1728ecdbedd03f8ff950c662dfc338c92ea782cb16"}, 22 | {file = "chex-0.1.87.tar.gz", hash = "sha256:0096d89cc8d898bb521ef4bfbf5c24549022b0e5b301f529ab57238896fe6c5d"}, 23 | ] 24 | 25 | [package.dependencies] 26 | absl-py = ">=0.9.0" 27 | jax = ">=0.4.27" 28 | jaxlib = ">=0.4.27" 29 | numpy = ">=1.24.1" 30 | setuptools = {version = "*", markers = "python_version >= \"3.12\""} 31 | toolz = ">=0.9.0" 32 | typing-extensions = ">=4.2.0" 33 | 34 | [[package]] 35 | name = "einops" 36 | version = "0.8.0" 37 | description = "A new flavour of deep learning operations" 38 | optional = false 39 | python-versions = ">=3.8" 40 | files = [ 41 | {file = "einops-0.8.0-py3-none-any.whl", hash = "sha256:9572fb63046264a862693b0a87088af3bdc8c068fde03de63453cbbde245465f"}, 42 | {file = "einops-0.8.0.tar.gz", hash = "sha256:63486517fed345712a8385c100cb279108d9d47e6ae59099b07657e983deae85"}, 43 | ] 44 | 45 | [[package]] 46 | name = "filelock" 47 | version = "3.16.1" 48 | description = "A platform independent file lock." 49 | optional = false 50 | python-versions = ">=3.8" 51 | files = [ 52 | {file = "filelock-3.16.1-py3-none-any.whl", hash = "sha256:2082e5703d51fbf98ea75855d9d5527e33d8ff23099bec374a134febee6946b0"}, 53 | {file = "filelock-3.16.1.tar.gz", hash = "sha256:c249fbfcd5db47e5e2d6d62198e565475ee65e4831e2561c8e313fa7eb961435"}, 54 | ] 55 | 56 | [package.extras] 57 | docs = ["furo (>=2024.8.6)", "sphinx (>=8.0.2)", "sphinx-autodoc-typehints (>=2.4.1)"] 58 | testing = ["covdefaults (>=2.3)", "coverage (>=7.6.1)", "diff-cover (>=9.2)", "pytest (>=8.3.3)", "pytest-asyncio (>=0.24)", "pytest-cov (>=5)", "pytest-mock (>=3.14)", "pytest-timeout (>=2.3.1)", "virtualenv (>=20.26.4)"] 59 | typing = ["typing-extensions (>=4.12.2)"] 60 | 61 | [[package]] 62 | name = "jax" 63 | version = "0.4.35" 64 | description = "Differentiate, compile, and transform Numpy code." 
65 | optional = false 66 | python-versions = ">=3.10" 67 | files = [ 68 | {file = "jax-0.4.35-py3-none-any.whl", hash = "sha256:fa99e909a31424abfec750019a6dd36f6acc18a6e7d40e2c0086b932cc351325"}, 69 | {file = "jax-0.4.35.tar.gz", hash = "sha256:c0c986993026b10bf6f607fecb7417377460254640766ce40f1fef3fd139c12e"}, 70 | ] 71 | 72 | [package.dependencies] 73 | jaxlib = ">=0.4.34,<=0.4.35" 74 | ml-dtypes = ">=0.4.0" 75 | numpy = [ 76 | {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, 77 | {version = ">=1.24", markers = "python_version < \"3.12\""}, 78 | ] 79 | opt-einsum = "*" 80 | scipy = [ 81 | {version = ">=1.11.1", markers = "python_version >= \"3.12\""}, 82 | {version = ">=1.10", markers = "python_version < \"3.12\""}, 83 | ] 84 | 85 | [package.extras] 86 | ci = ["jaxlib (==0.4.34)"] 87 | cuda = ["jax-cuda12-plugin[with-cuda] (>=0.4.34,<=0.4.35)", "jaxlib (==0.4.34)"] 88 | cuda12 = ["jax-cuda12-plugin[with-cuda] (>=0.4.34,<=0.4.35)", "jaxlib (==0.4.34)"] 89 | cuda12-local = ["jax-cuda12-plugin (==0.4.34)", "jaxlib (==0.4.34)"] 90 | cuda12-pip = ["jax-cuda12-plugin[with-cuda] (>=0.4.34,<=0.4.35)", "jaxlib (==0.4.34)"] 91 | k8s = ["kubernetes"] 92 | minimum-jaxlib = ["jaxlib (==0.4.34)"] 93 | tpu = ["jaxlib (>=0.4.34,<=0.4.35)", "libtpu (==0.0.2)", "libtpu-nightly (==0.1.dev20241010+nightly.cleanup)", "requests"] 94 | 95 | [[package]] 96 | name = "jaxlib" 97 | version = "0.4.35" 98 | description = "XLA library for JAX" 99 | optional = false 100 | python-versions = ">=3.10" 101 | files = [ 102 | {file = "jaxlib-0.4.35-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:907e548ad6ce53b242a55c5f36c2a2a4c37d38f6cd8c356fc550a2f18ab0e82f"}, 103 | {file = "jaxlib-0.4.35-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8f8c499644660aefd0ae2ee31039da6d4df0f26d0ee67ba9fb316183a5304288"}, 104 | {file = "jaxlib-0.4.35-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:5d2d8a5b89d334b875ede98d7fcee946bebef1a1b5abd118ff543bcef4ab09f5"}, 105 | {file = "jaxlib-0.4.35-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:91a283a72263feebe0d110d1136df96950744e47530f12df42c03f36888c971e"}, 106 | {file = "jaxlib-0.4.35-cp310-cp310-win_amd64.whl", hash = "sha256:d210bab7e1ce0b2f2e568548b3903ea6aec349019fc1398cd2a0c069e8342e62"}, 107 | {file = "jaxlib-0.4.35-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:7f8bfc90f68857b223b7e38a9bdf466a4f1cb405c9a4aa11698dc9ab7b35c29b"}, 108 | {file = "jaxlib-0.4.35-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:261570c94b169dc90f3af903282eeec856b52736c0944d243504ced93d19b217"}, 109 | {file = "jaxlib-0.4.35-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:e1cee6dc291251f3fb6b0127fdd96c0439ac1ea97e01571d06910df72d6ac6e1"}, 110 | {file = "jaxlib-0.4.35-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:bc9eafba001ff8569cfa252fe7f04ba553622702b4b473b656dd0866edf6b8d4"}, 111 | {file = "jaxlib-0.4.35-cp311-cp311-win_amd64.whl", hash = "sha256:0fd990354d5623d3a34493fcd7213493390dbf5039bea19b62e2aaee1049eda9"}, 112 | {file = "jaxlib-0.4.35-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:b44f3e6e9fb748bb43df914356cf9d0d0c9a6e446a12c21fe843db25ed0df65f"}, 113 | {file = "jaxlib-0.4.35-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:504d0a2e2117724359d99d7e3663022686dcdddd85aa14bdad02008d444481ad"}, 114 | {file = "jaxlib-0.4.35-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:187cb6929dc139b75d952d67c33118473c1b4105525a3e5607f064e7b8efdc74"}, 115 | {file = "jaxlib-0.4.35-cp312-cp312-manylinux2014_x86_64.whl", hash = 
"sha256:04d1db3bf0050d120238bfb9b686b58fefcc4d9dd9e2d96aecd3f68a1f1f5e0a"}, 116 | {file = "jaxlib-0.4.35-cp312-cp312-win_amd64.whl", hash = "sha256:dddffce48d7e6057008999aed2d8a9daecc57a48c45a4f8c475e00880eb2e41d"}, 117 | {file = "jaxlib-0.4.35-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:14aeac3fea2ca1d5afb1878f72470b159cc89adb2633c5f0686f5d7c39f2ac18"}, 118 | {file = "jaxlib-0.4.35-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:e8c9579e20d5ecdc4f61336cdd032710cb8c38d5ae9c4fce0cf9ea031cef21cb"}, 119 | {file = "jaxlib-0.4.35-cp313-cp313-manylinux2014_aarch64.whl", hash = "sha256:7b11ad7c13f7f96f36efd303711ecac425f19ca2ddf65cf1be1541167a959ee5"}, 120 | {file = "jaxlib-0.4.35-cp313-cp313-manylinux2014_x86_64.whl", hash = "sha256:0be3cf9df879d9ae1b5b92fc281f77d21f522fcbae1a48a02661026bbd9b9309"}, 121 | {file = "jaxlib-0.4.35-cp313-cp313-win_amd64.whl", hash = "sha256:330c090bb9af413f552d8a92d097e50baec6b75823430fb2966a49f5298d4c43"}, 122 | ] 123 | 124 | [package.dependencies] 125 | ml-dtypes = ">=0.2.0" 126 | numpy = ">=1.24" 127 | scipy = [ 128 | {version = ">=1.11.1", markers = "python_version >= \"3.12\""}, 129 | {version = ">=1.10", markers = "python_version < \"3.12\""}, 130 | ] 131 | 132 | [[package]] 133 | name = "ml-dtypes" 134 | version = "0.5.0" 135 | description = "" 136 | optional = false 137 | python-versions = ">=3.9" 138 | files = [ 139 | {file = "ml_dtypes-0.5.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:8c32138975797e681eb175996d64356bcfa124bdbb6a70460b9768c2b35a6fa4"}, 140 | {file = "ml_dtypes-0.5.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ab046f2ff789b1f11b2491909682c5d089934835f9a760fafc180e47dcb676b8"}, 141 | {file = "ml_dtypes-0.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c7a9152f5876fef565516aa5dd1dccd6fc298a5891b2467973905103eb5c7856"}, 142 | {file = "ml_dtypes-0.5.0-cp310-cp310-win_amd64.whl", hash = "sha256:968fede07d1f9b926a63df97d25ac656cac1a57ebd33701734eaf704bc55d8d8"}, 143 | {file = "ml_dtypes-0.5.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:60275f2b51b56834e840c4809fca840565f9bf8e9a73f6d8c94f5b5935701215"}, 144 | {file = "ml_dtypes-0.5.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76942f6aeb5c40766d5ea62386daa4148e6a54322aaf5b53eae9e7553240222f"}, 145 | {file = "ml_dtypes-0.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e7534392682c3098bc7341648c650864207169c654aed83143d7a19c67ae06f"}, 146 | {file = "ml_dtypes-0.5.0-cp311-cp311-win_amd64.whl", hash = "sha256:dc74fd9995513d33eac63d64e436240f5494ec74d522a9f0920194942fc3d2d7"}, 147 | {file = "ml_dtypes-0.5.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:d4b1a70a3e5219790d6b55b9507606fc4e02911d1497d16c18dd721eb7efe7d0"}, 148 | {file = "ml_dtypes-0.5.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a988bac6572630e1e9c2edd9b1277b4eefd1c86209e52b0d061b775ac33902ff"}, 149 | {file = "ml_dtypes-0.5.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a38df8df61194aeaae1ab7579075779b4ad32cd1cffd012c28be227fa7f2a70a"}, 150 | {file = "ml_dtypes-0.5.0-cp312-cp312-win_amd64.whl", hash = "sha256:afa08343069874a30812871d639f9c02b4158ace065601406a493a8511180c02"}, 151 | {file = "ml_dtypes-0.5.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:d3b3db9990c3840986a0e70524e122cfa32b91139c3653df76121ba7776e015f"}, 152 | {file = 
"ml_dtypes-0.5.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e04fde367b2fe901b1d47234426fe8819909bd1dd862a5adb630f27789c20599"}, 153 | {file = "ml_dtypes-0.5.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:54415257f00eb44fbcc807454efac3356f75644f1cbfc2d4e5522a72ae1dacab"}, 154 | {file = "ml_dtypes-0.5.0-cp313-cp313-win_amd64.whl", hash = "sha256:cb5cc7b25acabd384f75bbd78892d0c724943f3e2e1986254665a1aa10982e07"}, 155 | {file = "ml_dtypes-0.5.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:5f2b59233a0dbb6a560b3137ed6125433289ccba2f8d9c3695a52423a369ed15"}, 156 | {file = "ml_dtypes-0.5.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:099e09edd54e676903b4538f3815b5ab96f5b119690514602d96bfdb67172cbe"}, 157 | {file = "ml_dtypes-0.5.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a03fc861b86cc586728e3d093ba37f0cc05e65330c3ebd7688e7bae8290f8859"}, 158 | {file = "ml_dtypes-0.5.0-cp39-cp39-win_amd64.whl", hash = "sha256:7ee9c320bb0f9ffdf9f6fa6a696ef2e005d1f66438d6f1c1457338e00a02e8cf"}, 159 | {file = "ml_dtypes-0.5.0.tar.gz", hash = "sha256:3e7d3a380fe73a63c884f06136f8baa7a5249cc8e9fdec677997dd78549f8128"}, 160 | ] 161 | 162 | [package.dependencies] 163 | numpy = [ 164 | {version = ">=2.1.0", markers = "python_version >= \"3.13\""}, 165 | {version = ">=1.26.0", markers = "python_version >= \"3.12\" and python_version < \"3.13\""}, 166 | {version = ">=1.23.3", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, 167 | {version = ">=1.21.2", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, 168 | ] 169 | 170 | [package.extras] 171 | dev = ["absl-py", "pyink", "pylint (>=2.6.0)", "pytest", "pytest-xdist"] 172 | 173 | [[package]] 174 | name = "numpy" 175 | version = "2.1.2" 176 | description = "Fundamental package for array computing in Python" 177 | optional = false 178 | python-versions = ">=3.10" 179 | files = [ 180 | {file = "numpy-2.1.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:30d53720b726ec36a7f88dc873f0eec8447fbc93d93a8f079dfac2629598d6ee"}, 181 | {file = "numpy-2.1.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e8d3ca0a72dd8846eb6f7dfe8f19088060fcb76931ed592d29128e0219652884"}, 182 | {file = "numpy-2.1.2-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:fc44e3c68ff00fd991b59092a54350e6e4911152682b4782f68070985aa9e648"}, 183 | {file = "numpy-2.1.2-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:7c1c60328bd964b53f8b835df69ae8198659e2b9302ff9ebb7de4e5a5994db3d"}, 184 | {file = "numpy-2.1.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6cdb606a7478f9ad91c6283e238544451e3a95f30fb5467fbf715964341a8a86"}, 185 | {file = "numpy-2.1.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d666cb72687559689e9906197e3bec7b736764df6a2e58ee265e360663e9baf7"}, 186 | {file = "numpy-2.1.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c6eef7a2dbd0abfb0d9eaf78b73017dbfd0b54051102ff4e6a7b2980d5ac1a03"}, 187 | {file = "numpy-2.1.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:12edb90831ff481f7ef5f6bc6431a9d74dc0e5ff401559a71e5e4611d4f2d466"}, 188 | {file = "numpy-2.1.2-cp310-cp310-win32.whl", hash = "sha256:a65acfdb9c6ebb8368490dbafe83c03c7e277b37e6857f0caeadbbc56e12f4fb"}, 189 | {file = "numpy-2.1.2-cp310-cp310-win_amd64.whl", hash = "sha256:860ec6e63e2c5c2ee5e9121808145c7bf86c96cca9ad396c0bd3e0f2798ccbe2"}, 190 | {file = 
"numpy-2.1.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b42a1a511c81cc78cbc4539675713bbcf9d9c3913386243ceff0e9429ca892fe"}, 191 | {file = "numpy-2.1.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:faa88bc527d0f097abdc2c663cddf37c05a1c2f113716601555249805cf573f1"}, 192 | {file = "numpy-2.1.2-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:c82af4b2ddd2ee72d1fc0c6695048d457e00b3582ccde72d8a1c991b808bb20f"}, 193 | {file = "numpy-2.1.2-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:13602b3174432a35b16c4cfb5de9a12d229727c3dd47a6ce35111f2ebdf66ff4"}, 194 | {file = "numpy-2.1.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1ebec5fd716c5a5b3d8dfcc439be82a8407b7b24b230d0ad28a81b61c2f4659a"}, 195 | {file = "numpy-2.1.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e2b49c3c0804e8ecb05d59af8386ec2f74877f7ca8fd9c1e00be2672e4d399b1"}, 196 | {file = "numpy-2.1.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:2cbba4b30bf31ddbe97f1c7205ef976909a93a66bb1583e983adbd155ba72ac2"}, 197 | {file = "numpy-2.1.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8e00ea6fc82e8a804433d3e9cedaa1051a1422cb6e443011590c14d2dea59146"}, 198 | {file = "numpy-2.1.2-cp311-cp311-win32.whl", hash = "sha256:5006b13a06e0b38d561fab5ccc37581f23c9511879be7693bd33c7cd15ca227c"}, 199 | {file = "numpy-2.1.2-cp311-cp311-win_amd64.whl", hash = "sha256:f1eb068ead09f4994dec71c24b2844f1e4e4e013b9629f812f292f04bd1510d9"}, 200 | {file = "numpy-2.1.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d7bf0a4f9f15b32b5ba53147369e94296f5fffb783db5aacc1be15b4bf72f43b"}, 201 | {file = "numpy-2.1.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b1d0fcae4f0949f215d4632be684a539859b295e2d0cb14f78ec231915d644db"}, 202 | {file = "numpy-2.1.2-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:f751ed0a2f250541e19dfca9f1eafa31a392c71c832b6bb9e113b10d050cb0f1"}, 203 | {file = "numpy-2.1.2-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:bd33f82e95ba7ad632bc57837ee99dba3d7e006536200c4e9124089e1bf42426"}, 204 | {file = "numpy-2.1.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1b8cde4f11f0a975d1fd59373b32e2f5a562ade7cde4f85b7137f3de8fbb29a0"}, 205 | {file = "numpy-2.1.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d95f286b8244b3649b477ac066c6906fbb2905f8ac19b170e2175d3d799f4df"}, 206 | {file = "numpy-2.1.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ab4754d432e3ac42d33a269c8567413bdb541689b02d93788af4131018cbf366"}, 207 | {file = "numpy-2.1.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e585c8ae871fd38ac50598f4763d73ec5497b0de9a0ab4ef5b69f01c6a046142"}, 208 | {file = "numpy-2.1.2-cp312-cp312-win32.whl", hash = "sha256:9c6c754df29ce6a89ed23afb25550d1c2d5fdb9901d9c67a16e0b16eaf7e2550"}, 209 | {file = "numpy-2.1.2-cp312-cp312-win_amd64.whl", hash = "sha256:456e3b11cb79ac9946c822a56346ec80275eaf2950314b249b512896c0d2505e"}, 210 | {file = "numpy-2.1.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a84498e0d0a1174f2b3ed769b67b656aa5460c92c9554039e11f20a05650f00d"}, 211 | {file = "numpy-2.1.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:4d6ec0d4222e8ffdab1744da2560f07856421b367928026fb540e1945f2eeeaf"}, 212 | {file = "numpy-2.1.2-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:259ec80d54999cc34cd1eb8ded513cb053c3bf4829152a2e00de2371bd406f5e"}, 213 | {file = "numpy-2.1.2-cp313-cp313-macosx_14_0_x86_64.whl", hash = 
"sha256:675c741d4739af2dc20cd6c6a5c4b7355c728167845e3c6b0e824e4e5d36a6c3"}, 214 | {file = "numpy-2.1.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:05b2d4e667895cc55e3ff2b56077e4c8a5604361fc21a042845ea3ad67465aa8"}, 215 | {file = "numpy-2.1.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:43cca367bf94a14aca50b89e9bc2061683116cfe864e56740e083392f533ce7a"}, 216 | {file = "numpy-2.1.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:76322dcdb16fccf2ac56f99048af32259dcc488d9b7e25b51e5eca5147a3fb98"}, 217 | {file = "numpy-2.1.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:32e16a03138cabe0cb28e1007ee82264296ac0983714094380b408097a418cfe"}, 218 | {file = "numpy-2.1.2-cp313-cp313-win32.whl", hash = "sha256:242b39d00e4944431a3cd2db2f5377e15b5785920421993770cddb89992c3f3a"}, 219 | {file = "numpy-2.1.2-cp313-cp313-win_amd64.whl", hash = "sha256:f2ded8d9b6f68cc26f8425eda5d3877b47343e68ca23d0d0846f4d312ecaa445"}, 220 | {file = "numpy-2.1.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:2ffef621c14ebb0188a8633348504a35c13680d6da93ab5cb86f4e54b7e922b5"}, 221 | {file = "numpy-2.1.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:ad369ed238b1959dfbade9018a740fb9392c5ac4f9b5173f420bd4f37ba1f7a0"}, 222 | {file = "numpy-2.1.2-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:d82075752f40c0ddf57e6e02673a17f6cb0f8eb3f587f63ca1eaab5594da5b17"}, 223 | {file = "numpy-2.1.2-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:1600068c262af1ca9580a527d43dc9d959b0b1d8e56f8a05d830eea39b7c8af6"}, 224 | {file = "numpy-2.1.2-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a26ae94658d3ba3781d5e103ac07a876b3e9b29db53f68ed7df432fd033358a8"}, 225 | {file = "numpy-2.1.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:13311c2db4c5f7609b462bc0f43d3c465424d25c626d95040f073e30f7570e35"}, 226 | {file = "numpy-2.1.2-cp313-cp313t-musllinux_1_1_x86_64.whl", hash = "sha256:2abbf905a0b568706391ec6fa15161fad0fb5d8b68d73c461b3c1bab6064dd62"}, 227 | {file = "numpy-2.1.2-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:ef444c57d664d35cac4e18c298c47d7b504c66b17c2ea91312e979fcfbdfb08a"}, 228 | {file = "numpy-2.1.2-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:bdd407c40483463898b84490770199d5714dcc9dd9b792f6c6caccc523c00952"}, 229 | {file = "numpy-2.1.2-pp310-pypy310_pp73-macosx_14_0_x86_64.whl", hash = "sha256:da65fb46d4cbb75cb417cddf6ba5e7582eb7bb0b47db4b99c9fe5787ce5d91f5"}, 230 | {file = "numpy-2.1.2-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1c193d0b0238638e6fc5f10f1b074a6993cb13b0b431f64079a509d63d3aa8b7"}, 231 | {file = "numpy-2.1.2-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:a7d80b2e904faa63068ead63107189164ca443b42dd1930299e0d1cb041cec2e"}, 232 | {file = "numpy-2.1.2.tar.gz", hash = "sha256:13532a088217fa624c99b843eeb54640de23b3414b14aa66d023805eb731066c"}, 233 | ] 234 | 235 | [[package]] 236 | name = "opt-einsum" 237 | version = "3.4.0" 238 | description = "Path optimization of einsum functions." 
239 | optional = false 240 | python-versions = ">=3.8" 241 | files = [ 242 | {file = "opt_einsum-3.4.0-py3-none-any.whl", hash = "sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd"}, 243 | {file = "opt_einsum-3.4.0.tar.gz", hash = "sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac"}, 244 | ] 245 | 246 | [[package]] 247 | name = "scipy" 248 | version = "1.13.1" 249 | description = "Fundamental algorithms for scientific computing in Python" 250 | optional = false 251 | python-versions = ">=3.9" 252 | files = [ 253 | {file = "scipy-1.13.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:20335853b85e9a49ff7572ab453794298bcf0354d8068c5f6775a0eabf350aca"}, 254 | {file = "scipy-1.13.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:d605e9c23906d1994f55ace80e0125c587f96c020037ea6aa98d01b4bd2e222f"}, 255 | {file = "scipy-1.13.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cfa31f1def5c819b19ecc3a8b52d28ffdcc7ed52bb20c9a7589669dd3c250989"}, 256 | {file = "scipy-1.13.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f26264b282b9da0952a024ae34710c2aff7d27480ee91a2e82b7b7073c24722f"}, 257 | {file = "scipy-1.13.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:eccfa1906eacc02de42d70ef4aecea45415f5be17e72b61bafcfd329bdc52e94"}, 258 | {file = "scipy-1.13.1-cp310-cp310-win_amd64.whl", hash = "sha256:2831f0dc9c5ea9edd6e51e6e769b655f08ec6db6e2e10f86ef39bd32eb11da54"}, 259 | {file = "scipy-1.13.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:27e52b09c0d3a1d5b63e1105f24177e544a222b43611aaf5bc44d4a0979e32f9"}, 260 | {file = "scipy-1.13.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:54f430b00f0133e2224c3ba42b805bfd0086fe488835effa33fa291561932326"}, 261 | {file = "scipy-1.13.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e89369d27f9e7b0884ae559a3a956e77c02114cc60a6058b4e5011572eea9299"}, 262 | {file = "scipy-1.13.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a78b4b3345f1b6f68a763c6e25c0c9a23a9fd0f39f5f3d200efe8feda560a5fa"}, 263 | {file = "scipy-1.13.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:45484bee6d65633752c490404513b9ef02475b4284c4cfab0ef946def50b3f59"}, 264 | {file = "scipy-1.13.1-cp311-cp311-win_amd64.whl", hash = "sha256:5713f62f781eebd8d597eb3f88b8bf9274e79eeabf63afb4a737abc6c84ad37b"}, 265 | {file = "scipy-1.13.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:5d72782f39716b2b3509cd7c33cdc08c96f2f4d2b06d51e52fb45a19ca0c86a1"}, 266 | {file = "scipy-1.13.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:017367484ce5498445aade74b1d5ab377acdc65e27095155e448c88497755a5d"}, 267 | {file = "scipy-1.13.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:949ae67db5fa78a86e8fa644b9a6b07252f449dcf74247108c50e1d20d2b4627"}, 268 | {file = "scipy-1.13.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de3ade0e53bc1f21358aa74ff4830235d716211d7d077e340c7349bc3542e884"}, 269 | {file = "scipy-1.13.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2ac65fb503dad64218c228e2dc2d0a0193f7904747db43014645ae139c8fad16"}, 270 | {file = "scipy-1.13.1-cp312-cp312-win_amd64.whl", hash = "sha256:cdd7dacfb95fea358916410ec61bbc20440f7860333aee6d882bb8046264e949"}, 271 | {file = "scipy-1.13.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:436bbb42a94a8aeef855d755ce5a465479c721e9d684de76bf61a62e7c2b81d5"}, 272 | {file = 
"scipy-1.13.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:8335549ebbca860c52bf3d02f80784e91a004b71b059e3eea9678ba994796a24"}, 273 | {file = "scipy-1.13.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d533654b7d221a6a97304ab63c41c96473ff04459e404b83275b60aa8f4b7004"}, 274 | {file = "scipy-1.13.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:637e98dcf185ba7f8e663e122ebf908c4702420477ae52a04f9908707456ba4d"}, 275 | {file = "scipy-1.13.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a014c2b3697bde71724244f63de2476925596c24285c7a637364761f8710891c"}, 276 | {file = "scipy-1.13.1-cp39-cp39-win_amd64.whl", hash = "sha256:392e4ec766654852c25ebad4f64e4e584cf19820b980bc04960bca0b0cd6eaa2"}, 277 | {file = "scipy-1.13.1.tar.gz", hash = "sha256:095a87a0312b08dfd6a6155cbbd310a8c51800fc931b8c0b84003014b874ed3c"}, 278 | ] 279 | 280 | [package.dependencies] 281 | numpy = ">=1.22.4,<2.3" 282 | 283 | [package.extras] 284 | dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy", "pycodestyle", "pydevtool", "rich-click", "ruff", "types-psutil", "typing_extensions"] 285 | doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.12.0)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0)", "sphinx-design (>=0.4.0)"] 286 | test = ["array-api-strict", "asv", "gmpy2", "hypothesis (>=6.30)", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] 287 | 288 | [[package]] 289 | name = "setuptools" 290 | version = "75.2.0" 291 | description = "Easily download, build, install, upgrade, and uninstall Python packages" 292 | optional = false 293 | python-versions = ">=3.8" 294 | files = [ 295 | {file = "setuptools-75.2.0-py3-none-any.whl", hash = "sha256:a7fcb66f68b4d9e8e66b42f9876150a3371558f98fa32222ffaa5bced76406f8"}, 296 | {file = "setuptools-75.2.0.tar.gz", hash = "sha256:753bb6ebf1f465a1912e19ed1d41f403a79173a9acf66a42e7e6aec45c3c16ec"}, 297 | ] 298 | 299 | [package.extras] 300 | check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)", "ruff (>=0.5.2)"] 301 | core = ["importlib-metadata (>=6)", "importlib-resources (>=5.10.2)", "jaraco.collections", "jaraco.functools", "jaraco.text (>=3.7)", "more-itertools", "more-itertools (>=8.8)", "packaging", "packaging (>=24)", "platformdirs (>=2.6.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"] 302 | cover = ["pytest-cov"] 303 | doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"] 304 | enabler = ["pytest-enabler (>=2.2)"] 305 | test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"] 306 | type = ["importlib-metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (==1.11.*)", "pytest-mypy"] 307 | 308 | [[package]] 309 | name = "toolz" 310 | version = "1.0.0" 311 | description = "List processing tools and functional utilities" 312 | optional 
= false 313 | python-versions = ">=3.8" 314 | files = [ 315 | {file = "toolz-1.0.0-py3-none-any.whl", hash = "sha256:292c8f1c4e7516bf9086f8850935c799a874039c8bcf959d47b600e4c44a6236"}, 316 | {file = "toolz-1.0.0.tar.gz", hash = "sha256:2c86e3d9a04798ac556793bced838816296a2f085017664e4995cb40a1047a02"}, 317 | ] 318 | 319 | [[package]] 320 | name = "triton" 321 | version = "3.0.0" 322 | description = "A language and compiler for custom Deep Learning operations" 323 | optional = false 324 | python-versions = "*" 325 | files = [ 326 | {file = "triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e1efef76935b2febc365bfadf74bcb65a6f959a9872e5bddf44cc9e0adce1e1a"}, 327 | {file = "triton-3.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5ce8520437c602fb633f1324cc3871c47bee3b67acf9756c1a66309b60e3216c"}, 328 | {file = "triton-3.0.0-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:34e509deb77f1c067d8640725ef00c5cbfcb2052a1a3cb6a6d343841f92624eb"}, 329 | {file = "triton-3.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bcbf3b1c48af6a28011a5c40a5b3b9b5330530c3827716b5fbf6d7adcc1e53e9"}, 330 | {file = "triton-3.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6e5727202f7078c56f91ff13ad0c1abab14a0e7f2c87e91b12b6f64f3e8ae609"}, 331 | ] 332 | 333 | [package.dependencies] 334 | filelock = "*" 335 | 336 | [package.extras] 337 | build = ["cmake (>=3.20)", "lit"] 338 | tests = ["autopep8", "flake8", "isort", "llnl-hatchet", "numpy", "pytest", "scipy (>=1.7.1)"] 339 | tutorials = ["matplotlib", "pandas", "tabulate"] 340 | 341 | [[package]] 342 | name = "typing-extensions" 343 | version = "4.12.2" 344 | description = "Backported and Experimental Type Hints for Python 3.8+" 345 | optional = false 346 | python-versions = ">=3.8" 347 | files = [ 348 | {file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"}, 349 | {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, 350 | ] 351 | 352 | [metadata] 353 | lock-version = "2.0" 354 | python-versions = ">=3.10" 355 | content-hash = "3e16c751bf056781ad6406640f06b72c8d32e81104cb9bd20a300dd34aeb77a5" 356 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "jax-flash-attn2" 3 | version = "0.0.3" 4 | description = "Flash Attention Implementation with Multiple Backend Support and Sharding This module provides a flexible implementation of Flash Attention with support for different backends (GPU, TPU, CPU) and platforms (Triton, Pallas, JAX)." 
5 | authors = ["Erfan Zare Chavoshi "] 6 | license = "Apache-2.0" 7 | readme = "README.md" 8 | homepage = "https://github.com/erfanzar/jax-flash-attn2" 9 | repository = "https://github.com/erfanzar/jax-flash-attn2" 10 | documentation = "https://erfanzar.github.io/jax-flash-attn2" 11 | keywords = ["JAX", "Deep Learning", "Machine Learning", "XLA"] 12 | classifiers = [ 13 | "Development Status :: 3 - Alpha", 14 | "Intended Audience :: Developers", 15 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 16 | "License :: OSI Approved :: Apache Software License", 17 | "Programming Language :: Python :: 3", 18 | "Programming Language :: Python :: 3.9", 19 | "Programming Language :: Python :: 3.10", 20 | "Programming Language :: Python :: 3.11", 21 | "Programming Language :: Python :: 3.12", 22 | ] 23 | 24 | [tool.poetry.dependencies] 25 | python = ">=3.10,<3.14" 26 | jax = ">=0.4.36" 27 | jaxlib = ">=0.4.36" 28 | eformer = "0.0.15" 29 | einops = "~0.8.0" 30 | triton = "~3.2.0" 31 | 32 | [tool.ruff.lint] 33 | select = ["E4", "E7", "E9", "F", "B"] 34 | ignore = ["E501", "B905", "B007", "E741"] 35 | unfixable = ["B"] 36 | 37 | [tool.ruff.lint.per-file-ignores] 38 | "__init__.py" = ["E402", "F401"] 39 | "**/{tests,docs,tools}/*" = ["E402"] 40 | "tests/*" = ["E402", "E731"] 41 | "triton_*" = ["E741", "ISC001", "E501", "E731"] 42 | "pallas_*" = ["E741", "ISC001", "E501", "E731"] 43 | 44 | [tool.ruff.format] 45 | quote-style = "double" 46 | indent-style = "tab" 47 | docstring-code-format = true 48 | 49 | [tool.ruff] 50 | target-version = "py311" 51 | line-length = 88 52 | indent-width = 2 53 | --------------------------------------------------------------------------------