├── .gitignore ├── LICENSE ├── README.md ├── install_dependency.sh ├── native_sparse_attention ├── __init__.py ├── infer │ ├── __init__.py │ ├── inference_func.py │ └── nsa_inference.py ├── model │ ├── README.md │ ├── __init__.py │ ├── toy_llama.py │ └── toy_nsa_llama.py ├── module │ ├── __init__.py │ ├── kv_cache.py │ ├── native_sparse_attention.py │ ├── rope.py │ └── self_attention.py └── ops │ ├── README.md │ ├── __init__.py │ ├── torch │ ├── __init__.py │ ├── compress_key_value.py │ ├── compressed_attention.py │ ├── compressed_attention_decode.py │ └── topk_sparse_attention.py │ └── triton │ ├── __init__.py │ ├── compressed_attention.py │ ├── flash_attention.py │ ├── flash_attention_decode.py │ ├── linear_compress.py │ ├── topk_sparse_attention.py │ ├── topk_sparse_attention_decode.py │ ├── utils.py │ └── weighted_pool.py ├── setup.py └── test ├── test_compress_key_value.py ├── test_compressed_attention.py ├── test_flash_attention.py ├── test_kv_cache.py ├── test_linear_compress.py ├── test_nsa_infer.py ├── test_nsa_model.py ├── test_nsa_module.py ├── test_rope.py └── test_topk_sparse_attention.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 
101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # PyPI configuration file 171 | .pypirc 172 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # Native Sparse Attention Triton 4 | 5 |
6 | 7 | This repository implements the sparse attention mechanism introduced in the paper [Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention](https://arxiv.org/abs/2502.11089) and provides an efficient training implementation based on [Triton](https://github.com/triton-lang/triton). 8 | 9 | 🎉 We now support both training and inference for Native Sparse Attention (variable-length version, including prefilling, decoding, and KV cache management). We have provided a toy model at `model.ToyNSALlama`, which supports a `forward` function for training and a `generate` function for inference. Welcome to try it out! 10 | 11 | ## Requirements 12 | Ensure the following dependencies are installed: 13 | - PyTorch >= 2.1.0 14 | - triton >= 3.0.0 15 | - einops >= 0.7.0 16 | - flash_attn >= 2.6.3 17 | 18 | ## Usage 19 | 20 | ### Notes 21 | 1. PyTorch implementations (`ops.torch`) are intended for debugging only. 22 | 2. For production use, prefer Triton operators (`ops.triton`). 23 | 3. All implementations are based on the varlen approach, similar to `flash_attn_varlen_func`. Please concatenate the inputs of a batch before use (see the packing sketch after the example below). 24 | 4. Only attention head dimensions of at most 128 are supported for now. 25 | 26 | ### Install 27 | 28 | You can install `native_sparse_attention` using pip: 29 | 30 | ```shell 31 | pip install git+https://github.com/XunhaoLai/native-sparse-attention-triton.git 32 | ``` 33 | 34 | ### Functions 35 | 36 | The `ops` module implements several functions required for native sparse attention. For detailed usage instructions, please see [this link](https://github.com/XunhaoLai/native-sparse-attention-triton/tree/main/native_sparse_attention/ops#readme). 37 | 38 | You can import those functions from the `ops` module: 39 | 40 | ```python 41 | import torch 42 | from native_sparse_attention.ops import linear_compress, compressed_attention, topk_sparse_attention 43 | 44 | # input example 45 | num_q_heads = 64 46 | num_kv_heads = 4 47 | head_dim = 128 48 | kernel_size = 32 49 | kernel_stride = 16 50 | block_size = 64 51 | topk = 16 52 | cu_seqlens = torch.Tensor([0, 1024, 8192, 16384]).to(torch.int32).cuda() 53 | query = torch.randn(16384, num_q_heads, head_dim).to(torch.bfloat16).cuda() 54 | key = torch.randn(16384, num_kv_heads, head_dim).to(torch.bfloat16).cuda() 55 | value = torch.randn(16384, num_kv_heads, head_dim).to(torch.bfloat16).cuda() 56 | 57 | # weight example 58 | w = ( 59 | torch.randn(num_kv_heads, kernel_size * head_dim, head_dim) 60 | .to(torch.bfloat16) 61 | .cuda() 62 | ) 63 | pe = torch.randn(num_kv_heads, kernel_size, head_dim).to(torch.bfloat16).cuda() 64 | 65 | # 1. key value compression 66 | compressed_key, compressed_cu_seqlens = linear_compress( 67 | key, w, cu_seqlens, kernel_size, kernel_stride, pe 68 | ) 69 | compressed_value, _ = linear_compress( 70 | value, w, cu_seqlens, kernel_size, kernel_stride, None 71 | ) 72 | 73 | # 2. attention between query and compressed key value 74 | compressed_attn_output, topk_idx = compressed_attention( 75 | query, 76 | compressed_key, 77 | compressed_value, 78 | kernel_size, 79 | kernel_stride, 80 | block_size, 81 | topk, 82 | cu_seqlens, 83 | compressed_cu_seqlens, 84 | init_blocks=1, 85 | local_blocks=2, 86 | ) 87 | 88 | # 3. topk sparse attention 89 | sparse_attn_output = topk_sparse_attention( 90 | query, 91 | key, 92 | value, 93 | topk_idx, 94 | block_size, 95 | cu_seqlens, 96 | ) 97 | ``` 98 |
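As mentioned in the usage notes, every operator expects packed (varlen) inputs plus a `cu_seqlens` tensor. The snippet below is a minimal, illustrative sketch of how such inputs can be built from a right-padded batch; the helper name `pack_varlen` and its arguments are ours and not part of this repository's API.

```python
import torch

def pack_varlen(padded: torch.Tensor, seqlens: torch.Tensor):
    # padded: [batch_size, max_len, num_heads, head_dim]; seqlens: [batch_size] valid lengths
    batch_size = padded.shape[0]
    cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device=padded.device)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    # drop the padding of each sequence and concatenate along the token dimension
    packed = torch.cat([padded[i, : seqlens[i]] for i in range(batch_size)], dim=0)
    return packed, cu_seqlens

# e.g. sequences of lengths [1024, 7168, 8192] give cu_seqlens = [0, 1024, 8192, 16384]
# and packed tensors of shape [16384, num_heads, head_dim], matching the example above.
```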
99 | ### Module 100 | 101 | The `module` directory also provides `torch.nn.Module`-based implementations for easy integration into models. 102 | 103 | ```python 104 | from native_sparse_attention.module import NativeSparseAttention, RopeConfig 105 | 106 | NSA_Layer = NativeSparseAttention( 107 | compress_type="linear", 108 | hidden_size=4096, 109 | num_q_heads=64, 110 | num_kv_heads=4, 111 | head_dim=128, 112 | kernel_size=32, 113 | kernel_stride=16, 114 | block_size=64, 115 | topk=8, 116 | init_blocks=1, 117 | local_blocks=2, 118 | window_size=512, 119 | rope_config=RopeConfig( 120 | max_position_embeddings=32768, 121 | head_dim=128, 122 | rope_theta=500000, 123 | rope_scaling={ 124 | "factor": 4.0, 125 | "high_freq_factor": 4.0, 126 | "low_freq_factor": 1.0, 127 | "original_max_position_embeddings": 8192, 128 | "rope_type": "llama3", 129 | }, 130 | ), 131 | ) 132 | ``` 133 | 134 | ### Model 135 | 136 | We offer two simplified LLaMA models in the `model` directory, featuring self-attention and native sparse attention. For more details on how to use these models, please refer to [this link](https://github.com/XunhaoLai/native-sparse-attention-triton/tree/main/native_sparse_attention/model#readme). 137 | 138 | 139 | ```python 140 | from native_sparse_attention.model import ToyNSALlamaConfig, InferenceConfig, ToyNSALlama 141 | 142 | config = ToyNSALlamaConfig( 143 | hidden_size=4096, 144 | intermediate_size=14336, 145 | num_hidden_layers=8, 146 | num_attention_heads=32, 147 | num_key_value_heads=2, 148 | head_dim=128, 149 | rope_theta=500000.0, 150 | rope_scaling={ 151 | "factor": 8.0, 152 | "high_freq_factor": 4.0, 153 | "low_freq_factor": 1.0, 154 | "original_max_position_embeddings": 8192, 155 | "rope_type": "llama3", 156 | }, 157 | compress_type="weightedpool", 158 | kernel_size=32, 159 | kernel_stride=16, 160 | block_size=64, 161 | topk=8, 162 | init_blocks=1, 163 | local_blocks=2, 164 | window_size=512, 165 | ) 166 | inference_config = InferenceConfig( 167 | max_batch_size=4, 168 | max_length=8192, 169 | max_new_tokens=128, 170 | ) 171 | model = ToyNSALlama(config, inference_config).cuda().bfloat16() 172 | ``` 173 | 174 | ## Testing 175 | 176 | Some test scripts are available in the `test` folder and can be run directly for unit testing.
For example: 177 | 178 | ```bash 179 | python test/test_topk_sparse_attention.py 180 | python test/test_nsa_module.py 181 | python test/test_nsa_model.py 182 | ``` 183 | 184 | ### Benchmarks 185 | 186 | Here are the speed benchmarks conducted on a single NVIDIA A100 GPU or H100 GPU for the `topk_sparse_attention` function: 187 | 188 | A100 GPU speed benchmarks: 189 | ```sh 190 | ** forward with block size 64 **: 191 | N Flash Triton-Flash Triton-Top8 Triton-Top16 192 | 0 2048.0 0.414144 0.635648 0.633440 1.009184 193 | 1 4096.0 1.400304 2.267552 1.179808 1.916736 194 | 2 8192.0 5.223776 8.528160 2.266816 3.723168 195 | 3 16384.0 20.225697 32.745537 4.468128 7.359168 196 | 4 32768.0 79.587715 128.951065 8.517440 14.142848 197 | 5 65536.0 321.240479 511.652100 17.249599 30.991360 198 | 6 131072.0 1349.810425 2063.245605 36.400482 67.884544 199 | 200 | ** backward with block size 64 **: 201 | N Flash Triton-Flash Triton-Top8 Triton-Top16 202 | 0 2048.0 1.315440 2.348560 1.941568 2.691040 203 | 1 4096.0 4.271584 8.553184 3.647744 5.032160 204 | 2 8192.0 15.323984 32.665440 5.650144 9.066112 205 | 3 16384.0 58.753281 127.675964 11.160832 17.113279 206 | 4 32768.0 227.770462 504.572693 21.723392 34.715614 207 | 5 65536.0 899.181274 2059.718506 44.517181 76.309441 208 | 6 131072.0 3587.918701 8530.726562 105.344734 182.970169 209 | ``` 210 | 211 | H100 GPU benchmarks: 212 | ```sh 213 | ** forward with block size 64 **: 214 | N Flash Triton-Flash Triton-Top8 Triton-Top16 215 | 0 2048.0 0.259552 0.293888 0.584544 0.917664 216 | 1 4096.0 0.846848 1.029904 1.094976 1.745136 217 | 2 8192.0 3.043744 3.843392 2.128256 3.396880 218 | 3 16384.0 11.743568 14.791360 4.190528 6.704192 219 | 4 32768.0 45.968513 57.532478 7.614496 12.417440 220 | 5 65536.0 187.234375 228.093948 14.840048 24.511856 221 | 6 131072.0 810.890381 914.693970 29.470400 48.990192 222 | 223 | ** backward with block size 64 **: 224 | N Flash Triton-Flash Triton-Top8 Triton-Top16 225 | 0 2048.0 0.798976 1.096096 1.117312 1.380016 226 | 1 4096.0 2.545680 3.826336 1.669760 2.214880 227 | 2 8192.0 9.029760 14.411633 2.772096 3.947456 228 | 3 16384.0 34.144016 58.945698 5.201344 7.538912 229 | 4 32768.0 135.718369 233.369247 9.968864 15.154192 230 | 5 65536.0 541.053894 929.337646 21.089870 33.818878 231 | 6 131072.0 2139.974854 3785.540527 54.918144 93.750717 232 | ``` 233 | 234 | Here comes another speed benchmark result for testing `compressed_attention` function on a single NVIDIA A100 GPU or H100 GPU: 235 | 236 | A100 GPU speed benchmarks: 237 | ```sh 238 | ** forward with kernel 32 and stride 16 **: 239 | N Flash Triton-Flash Compressed Compressed-wo-Score 240 | 0 2048.0 0.413664 0.635488 0.655024 0.170816 241 | 1 4096.0 1.396416 2.247648 1.132304 0.377152 242 | 2 8192.0 5.234656 8.526400 2.879200 0.977952 243 | 3 16384.0 19.988865 32.755199 9.426448 2.943024 244 | 4 32768.0 79.419907 128.955170 30.284096 9.901120 245 | 5 65536.0 321.590210 511.615509 112.260544 36.001602 246 | 6 131072.0 1346.996338 2069.837891 423.099518 136.820038 247 | 248 | ** backward with kernel 32 and stride 16 **: 249 | N Flash Triton-Flash Compressed 250 | 0 2048.0 1.322560 2.352000 0.486784 251 | 1 4096.0 4.270832 8.552608 0.971392 252 | 2 8192.0 15.515680 32.671329 2.603744 253 | 3 16384.0 59.345055 128.377472 8.499456 254 | 4 32768.0 230.626144 506.581238 30.064833 255 | 5 65536.0 919.260498 2068.642578 113.466560 256 | 6 131072.0 3646.603760 8498.374023 439.623444 257 | ``` 258 | 259 | H100 GPU speed benchmarks: 260 | ```sh 261 | ** forward with kernel 32 
and stride 16 **: 262 | N Flash Triton-Flash Compressed Compressed-wo-Score 263 | 0 2048.0 0.259488 0.297152 0.485920 0.103232 264 | 1 4096.0 0.847376 1.030400 0.710208 0.217760 265 | 2 8192.0 3.044016 3.875840 1.607360 0.516016 266 | 3 16384.0 11.823104 14.829360 4.970272 1.440288 267 | 4 32768.0 46.204750 57.527809 15.004992 4.584736 268 | 5 65536.0 187.324249 227.909958 53.009087 16.134224 269 | 6 131072.0 810.707214 910.106873 191.245728 60.154270 270 | 271 | ** backward with kernel 32 and stride 16 **: 272 | N Flash Triton-Flash Compressed 273 | 0 2048.0 0.797728 1.090640 0.283104 274 | 1 4096.0 2.547088 3.834592 0.550464 275 | 2 8192.0 9.021520 14.421088 1.249184 276 | 3 16384.0 34.159508 58.793377 3.743440 277 | 4 32768.0 136.844070 233.447708 12.640032 278 | 5 65536.0 537.559814 929.360229 46.054817 279 | 6 131072.0 2135.629883 3782.351562 175.587296 280 | ``` 281 | 282 | All the speed benchmarks above were tested with 64 query heads, 4 key/value heads, and a head dimension of 128. 283 | 284 | ## Contributing 285 | Contributions are welcome! Please open an issue to discuss major changes. 286 | 287 | ## Contact 288 | 289 | For any questions or feedback, please feel free to contact laixunhao@pku.edu.cn. 290 | 291 | ## Citations 292 | 293 | ```bibtex 294 | @inproceedings{Yuan2025NativeSA, 295 | title = {Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention}, 296 | author = {Jingyang Yuan and Huazuo Gao and Damai Dai and Junyu Luo and Liang Zhao and Zhengyan Zhang and Zhenda Xie and Y. X. Wei and Lean Wang and Zhiping Xiao and Yuqing Wang and Chong Ruan and Ming Zhang and Wenfeng Liang and Wangding Zeng}, 297 | year = {2025}, 298 | url = {https://api.semanticscholar.org/CorpusID:276408911} 299 | } 300 | ``` 301 | -------------------------------------------------------------------------------- /install_dependency.sh: -------------------------------------------------------------------------------- 1 | pip3 install packaging -i https://pypi.org/simple 2 | pip3 install numpy==1.26.4 -i https://pypi.org/simple 3 | pip3 install torch==2.4.0 -i https://pypi.org/simple 4 | pip3 install triton==3.0.0 -i https://pypi.org/simple 5 | pip3 install transformers==4.44.0 -i https://pypi.org/simple 6 | pip3 install flash_attn==2.6.3 -i https://pypi.org/simple 7 | pip3 install matplotlib==3.9.4 -i https://pypi.org/simple 8 | pip3 install pandas==2.2.3 -i https://pypi.org/simple -------------------------------------------------------------------------------- /native_sparse_attention/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XunhaoLai/native-sparse-attention-triton/9bea856c911ebf263be88d797fb28458f82f1d94/native_sparse_attention/__init__.py -------------------------------------------------------------------------------- /native_sparse_attention/infer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from native_sparse_attention.infer.nsa_inference import nsa_infer 16 | 17 | __all__ = [ 18 | "nsa_infer", 19 | ] 20 | -------------------------------------------------------------------------------- /native_sparse_attention/infer/inference_func.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import torch 15 | from typing import Tuple, Callable, Optional 16 | from flash_attn import flash_attn_varlen_func 17 | from native_sparse_attention.ops import ( 18 | flash_attention_decode, 19 | compressed_attention, 20 | compressed_attention_decode, 21 | topk_sparse_attention, 22 | topk_sparse_attention_decode, 23 | ) 24 | from native_sparse_attention.ops.triton.utils import get_compressed_seqlens 25 | 26 | 27 | def compress_infer( 28 | cu_seqlens: torch.Tensor, 29 | step: int, 30 | key: torch.Tensor, 31 | value: torch.Tensor, 32 | cache, 33 | weight: Tuple[torch.Tensor, torch.Tensor], 34 | compress_func: Tuple[Callable, Callable], 35 | intra_block_pe: Optional[torch.Tensor], 36 | kernel_size: int, 37 | kernel_stride: int, 38 | ): 39 | if step == 0: 40 | key, compress_cu_seqlens = compress_func[0]( 41 | key, 42 | weight[0], 43 | cu_seqlens, 44 | kernel_size, 45 | kernel_stride, 46 | intra_block_pe, 47 | ) 48 | value, _ = compress_func[1]( 49 | value, 50 | weight[1], 51 | cu_seqlens, 52 | kernel_size, 53 | kernel_stride, 54 | ) 55 | else: 56 | batch_size = cu_seqlens.shape[0] - 1 57 | aux_cu_seqlens = ( 58 | torch.arange(batch_size + 1, dtype=torch.int32).to(cu_seqlens.device) 59 | * kernel_size 60 | ) 61 | key, _ = compress_func[0]( 62 | cache.before_compress_kv_cache[0, :batch_size].view( 63 | batch_size * kernel_size, cache.num_kv_heads, cache.head_dim 64 | ), 65 | weight[0], 66 | aux_cu_seqlens, 67 | kernel_size, 68 | kernel_stride, 69 | intra_block_pe, 70 | ) 71 | value, _ = compress_func[1]( 72 | cache.before_compress_kv_cache[1, :batch_size].view( 73 | batch_size * kernel_size, cache.num_kv_heads, cache.head_dim 74 | ), 75 | weight[1], 76 | aux_cu_seqlens, 77 | kernel_size, 78 | kernel_stride, 79 | ) 80 | # return actual compress_cu_seqlens before this token 81 | compress_cu_seqlens = torch.zeros( 82 | batch_size + 1, dtype=torch.int32, device=key.device 83 | ) 84 | compress_cu_seqlens[1:] = torch.cumsum( 85 | cache.compress_kv_len[:batch_size], dim=0 86 | ) 87 | return key, value, compress_cu_seqlens 88 | 89 | 90 | def compressed_attention_infer( 91 | cu_seqlens, 92 | step, 93 | query, 94 | key, 95 | value, 96 | cache, 97 | kernel_size, 98 | kernel_stride, 99 | topk, 100 | block_size, 101 | init_blocks, 102 | local_blocks, 103 | ): 104 | if step == 0: 105 | seqlens = cu_seqlens[1:] - cu_seqlens[:-1] 106 | compress_seqlens, compress_cu_seqlens = get_compressed_seqlens( 107 | cu_seqlens, kernel_size, 
kernel_stride 108 | ) 109 | attn_output, topk_idx = compressed_attention( 110 | query, 111 | key, 112 | value, 113 | kernel_size, 114 | kernel_stride, 115 | block_size, 116 | topk, 117 | cu_seqlens, 118 | compress_cu_seqlens, 119 | seqlens.max().item(), 120 | compress_seqlens.max().item(), 121 | None, 122 | init_blocks, 123 | local_blocks, 124 | ) 125 | else: 126 | batch_size = cu_seqlens.shape[0] - 1 127 | seqlens = cu_seqlens[1:] - cu_seqlens[:-1] + step 128 | attn_output, topk_idx = compressed_attention_decode( 129 | query, 130 | cache.compress_kv_cache[ 131 | 0, :batch_size, : cache.compress_kv_len[:batch_size].max() 132 | ], 133 | cache.compress_kv_cache[ 134 | 1, :batch_size, : cache.compress_kv_len[:batch_size].max() 135 | ], 136 | seqlens, 137 | cache.compress_kv_len[:batch_size], 138 | kernel_size, 139 | kernel_stride, 140 | block_size, 141 | topk, 142 | init_blocks, 143 | local_blocks, 144 | ) 145 | return attn_output, topk_idx 146 | 147 | 148 | def topk_sparse_attention_infer( 149 | cu_seqlens, 150 | step, 151 | query, 152 | key, 153 | value, 154 | cache, 155 | topk_idx, 156 | block_size, 157 | ): 158 | if step == 0: 159 | attn_output = topk_sparse_attention( 160 | query, key, value, topk_idx, block_size, cu_seqlens 161 | ) 162 | else: 163 | batch_size = cu_seqlens.shape[0] - 1 164 | attn_output = topk_sparse_attention_decode( 165 | query, 166 | cache.sparse_kv_cache[0, :batch_size], 167 | cache.sparse_kv_cache[1, :batch_size], 168 | topk_idx, 169 | block_size, 170 | cache.sparse_kv_len[:batch_size], 171 | ) 172 | return attn_output 173 | 174 | 175 | def sliding_window_attention_infer( 176 | cu_seqlens, step, query, key, value, cache, window_size 177 | ): 178 | if step == 0: 179 | seqlens = cu_seqlens[1:] - cu_seqlens[:-1] 180 | attn_output = flash_attn_varlen_func( 181 | query, 182 | key, 183 | value, 184 | cu_seqlens, 185 | cu_seqlens, 186 | seqlens.max().item(), 187 | seqlens.max().item(), 188 | causal=True, 189 | window_size=(window_size, -1), 190 | ) 191 | else: 192 | batch_size = cu_seqlens.shape[0] - 1 193 | attn_output = flash_attention_decode( 194 | query, 195 | cache.sliding_kv_cache[0, :batch_size], 196 | cache.sliding_kv_cache[1, :batch_size], 197 | torch.minimum( 198 | cache.sliding_kv_len, 199 | torch.zeros_like(cache.sliding_kv_len) + window_size, 200 | )[:batch_size], 201 | ) 202 | return attn_output 203 | -------------------------------------------------------------------------------- /native_sparse_attention/infer/nsa_inference.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
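# Illustrative call pattern (editorial comment only, not executed): nsa_infer is called
# once with step=0 to prefill the NSA KV cache from packed varlen tensors, then repeatedly
# with step>0 to decode one token per sequence. All variable names below are placeholders.
#
#   out = nsa_infer(cu_seqlens, 0, query, key, value, gate_value, rope, cache,
#                   (w_key, w_value), (f_key, f_value), intra_block_pe,
#                   kernel_size=32, kernel_stride=16, block_size=64, topk=16,
#                   init_blocks=1, local_blocks=2, window_size=512)
#   for step in range(1, max_new_tokens):  # decode: inputs are [batch_size, heads, head_dim]
#       out = nsa_infer(cu_seqlens, step, query_t, key_t, value_t, gate_t, rope, cache,
#                       (w_key, w_value), (f_key, f_value), intra_block_pe,
#                       kernel_size=32, kernel_stride=16, block_size=64, topk=16,
#                       init_blocks=1, local_blocks=2, window_size=512)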
14 | import torch 15 | from typing import Tuple, Callable, Optional 16 | from native_sparse_attention.infer.inference_func import ( 17 | compress_infer, 18 | compressed_attention_infer, 19 | topk_sparse_attention_infer, 20 | sliding_window_attention_infer, 21 | ) 22 | 23 | 24 | def nsa_infer( 25 | cu_seqlens: torch.Tensor, 26 | step: int, 27 | # qkv for three parts 28 | query: torch.Tensor, 29 | key: torch.Tensor, # prefill: [total_len, num_heads, head_dim], decode: [batch_size, num_heads, head_dim] 30 | value: torch.Tensor, 31 | gate_value: torch.Tensor, # prefill: [total_len, num_heads, 3], decode: [batch_size, num_heads, 3] 32 | # rope and kv cache 33 | rope, 34 | cache, 35 | # weight for nsa compress 36 | compress_weight: Tuple[ 37 | torch.Tensor, torch.Tensor 38 | ], # compress weight for key and value 39 | compress_func: Tuple[Callable, Callable], # compress function for key and value 40 | intra_block_pe: Optional[torch.Tensor], 41 | # nsa parameters 42 | kernel_size: int, 43 | kernel_stride: int, 44 | block_size: int, 45 | topk: int, 46 | init_blocks: int, 47 | local_blocks: int, 48 | window_size: int, 49 | ) -> torch.Tensor: 50 | """Inference function for native sparse attention. Supports prefill and decode with KV cache. 51 | 52 | Args: 53 | cu_seqlens (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_varlen_func. 54 | step (int): current inference step; step == 0 means prefill, step > 0 means a decode step. 55 | query (torch.Tensor): for prefill, shape [total_len, num_q_heads, head_dim]; for decode, shape [batch_size, num_q_heads, head_dim] 56 | key (torch.Tensor): for prefill, shape [total_len, num_kv_heads, head_dim]; for decode, shape [batch_size, num_kv_heads, head_dim] 57 | value (torch.Tensor): for prefill, shape [total_len, num_kv_heads, head_dim]; for decode, shape [batch_size, num_kv_heads, head_dim] 58 | gate_value (torch.Tensor): for prefill, shape [total_len, num_heads, 3]; for decode, shape [batch_size, num_heads, 3] 59 | rope (RotaryEmbedding): rope module, see native_sparse_attention.module.rope.RotaryEmbedding for details 60 | cache (NSACache): kv cache, see native_sparse_attention.module.kv_cache.NSACache for details 61 | compress_weight (Tuple[torch.Tensor, torch.Tensor]): compress weights for key and value respectively 62 | compress_func (Tuple[Callable, Callable]): compress functions for key and value respectively 63 | intra_block_pe (Optional[torch.Tensor]): intra-block positional embedding for compression; set to None if you don't use it 64 | kernel_size (int): kernel size of compression 65 | kernel_stride (int): kernel stride of compression 66 | block_size (int): block size of sparse attention 67 | topk (int): topk of sparse attention 68 | init_blocks (int): number of blocks at the beginning of the sequence; these blocks are forced to be computed in sparse attention 69 | local_blocks (int): number of blocks in the local window of each query; these blocks are forced to be computed in sparse attention 70 | window_size (int): window size for sliding window attention 71 | 72 | Returns: 73 | torch.Tensor: native sparse attention output, same shape as input query 74 | """ 75 | # reset kv cache at the beginning of prefilling 76 | if step == 0: 77 | cache.reset() 78 | # prepare for compress 79 | cache.prepare_compress(cu_seqlens, step, key, value) 80 | # compress key and value before rope 81 | compress_key, compress_value, compress_cu_seqlens = compress_infer( 82 | cu_seqlens, 83 | step, 84 | key, 85 | value, 86 | cache, 87 | compress_weight, 88 |
compress_func, 89 | intra_block_pe, 90 | kernel_size, 91 | kernel_stride, 92 | ) 93 | # do rope 94 | query = rope(query, cu_seqlens, step) 95 | if step == 0: 96 | compress_key = rope( 97 | compress_key, compress_cu_seqlens, step, stride=cache.kernel_stride 98 | ) 99 | else: 100 | compress_key = rope( 101 | compress_key, compress_cu_seqlens, 1, stride=cache.kernel_stride 102 | ) 103 | key = rope(key, cu_seqlens, step) 104 | # update kv cache 105 | cache.update_kv( 106 | cu_seqlens, 107 | step, 108 | compress_key, 109 | compress_value, 110 | key, 111 | value, 112 | key, 113 | value, 114 | ) 115 | # compressed attention 116 | compress_attn_output, topk_idx = compressed_attention_infer( 117 | cu_seqlens, 118 | step, 119 | query, 120 | compress_key, 121 | compress_value, 122 | cache, 123 | kernel_size, 124 | kernel_stride, 125 | topk, 126 | block_size, 127 | init_blocks, 128 | local_blocks, 129 | ) 130 | # topk sparse attention 131 | sparse_attn_output = topk_sparse_attention_infer( 132 | cu_seqlens, 133 | step, 134 | query, 135 | key, 136 | value, 137 | cache, 138 | topk_idx, 139 | block_size, 140 | ) 141 | # sliding window attention 142 | sliding_attn_output = sliding_window_attention_infer( 143 | cu_seqlens, step, query, key, value, cache, window_size 144 | ) 145 | # combine 3 attn output 146 | attn_output = ( 147 | gate_value[..., 0, None] * compress_attn_output 148 | + gate_value[..., 1, None] * sparse_attn_output 149 | + gate_value[..., 2, None] * sliding_attn_output 150 | ) 151 | return attn_output 152 | -------------------------------------------------------------------------------- /native_sparse_attention/model/README.md: -------------------------------------------------------------------------------- 1 | # Guide for the ToyNSALlama Model 2 | 3 | The `ToyNSALlama` model is a custom implementation of a Llama-like transformer architecture featuring a Native Sparse Attention (NSA) module. This guide explains how to integrate the NSA module into your own model. 4 | 5 | ## Overview 6 | 7 | The `ToyNSALlama` model consists of: 8 | - **Configuration**: Defined by `ToyNSALlamaConfig` (model structure parameters) and `InferenceConfig` (inference-specific parameters). 9 | - **Components**: An embedding layer, multiple NativeSparseAttention modules, Feed-Forward Network (FFN) modules, normalization layers, and a language model head. 10 | 11 | ## Step-by-Step Instructions 12 | 13 | ### 1. Import Necessary Modules 14 | ```python 15 | import torch 16 | import torch.nn as nn 17 | from native_sparse_attention.model import ToyNSALlama, ToyNSALlamaConfig, InferenceConfig 18 | ``` 19 | 20 | ### 2. Define Configurations 21 | Create instances of `ToyNSALlamaConfig` and `InferenceConfig` to set model and inference parameters. 22 | 23 | #### Model Configuration 24 | The model configuration aligns with the Transformers Llama model configuration. Adjust the following parameters to control the NSA module’s sparsity: 25 | - `compress_type`: Compression method for keys/values. Supported options: `avgpool`, `weightedpool`, `linear`. 26 | - `kernel_size` & `kernel_stride`: `kernel_size` determines how many tokens are compressed into one; `kernel_stride` sets the stride of the compression sliding window (`kernel_size` must be divisible by `kernel_stride`). 27 | - `block_size`: Block size for sparse attention (recommended: 64 or 128). 28 | - `topk`, `init_blocks`, `local_blocks`: `topk` specifies the number of blocks selected in sparse attention; `init_blocks` and `local_blocks` define the number of initial and local blocks that must be selected.
29 | - `window_size`: Size of the sliding window for attention. 30 | 31 | Example: 32 | ```python 33 | config = ToyNSALlamaConfig( 34 | hidden_size=4096, 35 | intermediate_size=14336, 36 | num_hidden_layers=8, 37 | num_attention_heads=32, 38 | num_key_value_heads=2, 39 | head_dim=128, 40 | vocab_size=128288, 41 | max_position_embeddings=131072, 42 | rope_theta=500000.0, 43 | rope_scaling={ 44 | "factor": 8.0, 45 | "high_freq_factor": 4.0, 46 | "low_freq_factor": 1.0, 47 | "original_max_position_embeddings": 8192, 48 | "rope_type": "llama3", 49 | }, 50 | compress_type="weightedpool", 51 | kernel_size=32, 52 | kernel_stride=16, 53 | block_size=64, 54 | topk=8, 55 | init_blocks=1, 56 | local_blocks=2, 57 | window_size=512, 58 | ) 59 | ``` 60 | 61 | #### Inference Configuration 62 | This configuration applies during inference, initializing the Key-Value (KV) Cache based on these settings. The full KV cache size is calculated as `max_batch_size × max_length × num_kv_heads × head_dim × num_layers × 2 × 2` bytes (two tensors for key and value, two bytes per bfloat16 element); with the example configurations on this page (max_batch_size=4, max_length=8192, num_kv_heads=2, head_dim=128, num_layers=8) this comes to about 256 MiB. Currently, only greedy decoding is supported as an example. 63 | 64 | Example: 65 | ```python 66 | inference_config = InferenceConfig( 67 | max_batch_size=4, 68 | max_length=8192, 69 | max_new_tokens=128, 70 | ) 71 | ``` 72 | 73 | ### 3. Initialize the Model 74 | Instantiate the model and move it to the GPU with the appropriate data type (currently, only `bfloat16` is supported). 75 | 76 | ```python 77 | model = ToyNSALlama(config, inference_config).cuda().to(torch.bfloat16) 78 | ``` 79 | 80 | ### 4. Forward & Generate 81 | The model supports two methods: 82 | - **`forward`**: Accepts `input_ids` and `cu_seqlens`, returning final logits after the language model head. Use this for training or evaluation. 83 | - **`generate`**: Accepts `input_ids` and `cu_seqlens`, generating output tokens via greedy sampling. This demonstrates KV cache usage for token generation (pre-filling and decoding). 84 | 85 | Example: 86 | ```python 87 | # Example input 88 | batch_size = 4 89 | seqlens = torch.randint(0, 4096, (batch_size,), dtype=torch.int32, device="cuda") 90 | cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda") 91 | cu_seqlens[1:] = torch.cumsum(seqlens, dim=0) 92 | input_ids = torch.randint(0, 128288, (cu_seqlens[-1],), dtype=torch.int64, device="cuda") 93 | print(f"\nEXAMPLE INPUT:\ncu_seqlens: {cu_seqlens}\ninput_ids: {input_ids.shape}\n") 94 | 95 | # Example forward 96 | logits = model(input_ids, cu_seqlens) 97 | print(f"\nEXAMPLE OUTPUT:\nlogits: {logits.shape}\n") 98 | 99 | # Example generate 100 | output_tokens = model.generate(input_ids, cu_seqlens) 101 | print(f"\nEXAMPLE GENERATE:\noutput_tokens: {output_tokens}\n") 102 | ``` 103 | 104 | ## Toy Llama Model with Self-Attention 105 | A simpler toy model with the Llama structure is available in `native_sparse_attention/model/toy_llama.py`. Compare `ToyLlama` and `ToyNSALlama` to see how to adapt a self-attention model into an NSA-based model. 106 | 107 | The primary difference lies in replacing the `SelfAttention` module with the `NativeSparseAttention` module, along with updates to the KV cache and inference function. These changes are straightforward and easy to implement. 108 | -------------------------------------------------------------------------------- /native_sparse_attention/model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from native_sparse_attention.model.toy_llama import ( 15 | ToyLlamaConfig, 16 | InferenceConfig, 17 | ToyLlama, 18 | ) 19 | from native_sparse_attention.model.toy_nsa_llama import ( 20 | ToyNSALlamaConfig, 21 | InferenceConfig, 22 | ToyNSALlama, 23 | ) 24 | 25 | __all__ = [ 26 | "ToyLlamaConfig", 27 | "ToyNSALlamaConfig", 28 | "InferenceConfig", 29 | "ToyLlama", 30 | "ToyNSALlama", 31 | ] 32 | -------------------------------------------------------------------------------- /native_sparse_attention/model/toy_llama.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
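# Overview (added editorial comment): this file defines a minimal dense-attention
# Llama-style model used as the baseline for the NSA variant in toy_nsa_llama.py. It
# contains the ToyLlamaConfig / InferenceConfig dataclasses, RMSNorm, a SiLU-gated FFN,
# ToyLlamaLayer (SelfAttention + FFN with residual connections), and ToyLlama with
# forward() for training and inference()/generate() for KV-cache-based greedy decoding.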
14 | from typing import Optional 15 | import torch 16 | import torch.nn as nn 17 | from dataclasses import dataclass, field 18 | from native_sparse_attention.module import SelfAttention, RopeConfig, KVCache 19 | 20 | 21 | @dataclass 22 | class ToyLlamaConfig: 23 | # embedding config 24 | vocab_size: int = 128288 25 | max_position_embeddings: int = 131072 26 | # model config 27 | hidden_size: int = 4096 28 | intermediate_size: int = 14336 29 | num_hidden_layers: int = 32 30 | num_attention_heads: int = 32 31 | num_key_value_heads: int = 2 32 | head_dim: int = 128 33 | # rope config 34 | rope_theta: float = 500000.0 35 | rope_scaling: dict = field( 36 | default_factory=lambda: { 37 | "factor": 8.0, 38 | "high_freq_factor": 4.0, 39 | "low_freq_factor": 1.0, 40 | "original_max_position_embeddings": 8192, 41 | "rope_type": "llama3", 42 | } 43 | ) 44 | 45 | 46 | @dataclass 47 | class InferenceConfig: 48 | max_batch_size: int = 32 49 | max_length: int = 8192 50 | max_new_tokens: int = 128 51 | 52 | 53 | class RMSNorm(nn.Module): 54 | def __init__(self, hidden_size: int, eps: float = 1e-6): 55 | super().__init__() 56 | self.weight = nn.Parameter(torch.ones(hidden_size)) 57 | self.variance_epsilon = eps 58 | 59 | def forward(self, hidden_states: torch.Tensor): 60 | input_dtype = hidden_states.dtype 61 | hidden_states = hidden_states.to(torch.float32) 62 | variance = hidden_states.pow(2).mean(-1, keepdim=True) 63 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) 64 | return self.weight * hidden_states.to(input_dtype) 65 | 66 | 67 | class FFN(nn.Module): 68 | def __init__(self, hidden_size: int, intermediate_size: int): 69 | super().__init__() 70 | self.hidden_size = hidden_size 71 | self.intermediate_size = intermediate_size 72 | self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) 73 | self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) 74 | self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) 75 | self.act_fn = nn.SiLU() 76 | 77 | def forward(self, x): 78 | down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) 79 | return down_proj 80 | 81 | 82 | class ToyLlamaLayer(nn.Module): 83 | def __init__( 84 | self, 85 | hidden_size: int, 86 | intermediate_size: int, 87 | num_q_heads: int, 88 | num_kv_heads: int, 89 | head_dim: int, 90 | rope_config: RopeConfig, 91 | ): 92 | super().__init__() 93 | self.hidden_size = hidden_size 94 | self.intermediate_size = intermediate_size 95 | self.num_q_heads = num_q_heads 96 | self.num_kv_heads = num_kv_heads 97 | self.head_dim = head_dim 98 | self.rope_config = rope_config 99 | self.attn_norm = RMSNorm(self.hidden_size) 100 | self.self_attn = SelfAttention( 101 | hidden_size=self.hidden_size, 102 | num_q_heads=self.num_q_heads, 103 | num_kv_heads=self.num_kv_heads, 104 | head_dim=self.head_dim, 105 | rope_config=rope_config, 106 | ) 107 | self.ffn_norm = RMSNorm(self.hidden_size) 108 | self.ffn = FFN( 109 | hidden_size=self.hidden_size, intermediate_size=self.intermediate_size 110 | ) 111 | 112 | def forward(self, x, cu_seqlens): 113 | x = x + self.self_attn(self.attn_norm(x), cu_seqlens) 114 | x = x + self.ffn(self.ffn_norm(x)) 115 | return x 116 | 117 | @torch.no_grad() 118 | def inference(self, x, cu_seqlens, step, kv_cache): 119 | x = x + self.self_attn.inference(self.attn_norm(x), cu_seqlens, step, kv_cache) 120 | x = x + self.ffn(self.ffn_norm(x)) 121 | return x 122 | 123 | 124 | class ToyLlama(nn.Module): 125 | def __init__( 
126 | self, config: ToyLlamaConfig, inference_config: Optional[InferenceConfig] = None 127 | ): 128 | super().__init__() 129 | self.config = config 130 | self.embedding = nn.Embedding(self.config.vocab_size, self.config.hidden_size) 131 | self.rope_config = RopeConfig( 132 | head_dim=self.config.head_dim, 133 | rope_theta=self.config.rope_theta, 134 | rope_scaling=self.config.rope_scaling, 135 | ) 136 | self.layers = nn.ModuleList( 137 | [ 138 | ToyLlamaLayer( 139 | hidden_size=self.config.hidden_size, 140 | intermediate_size=self.config.intermediate_size, 141 | num_q_heads=self.config.num_attention_heads, 142 | num_kv_heads=self.config.num_key_value_heads, 143 | head_dim=self.config.head_dim, 144 | rope_config=RopeConfig( 145 | self.config.max_position_embeddings, 146 | self.config.head_dim, 147 | self.config.rope_theta, 148 | self.config.rope_scaling, 149 | ), 150 | ) 151 | for _ in range(self.config.num_hidden_layers) 152 | ] 153 | ) 154 | self.norm = RMSNorm(self.config.hidden_size) 155 | self.lm_head = nn.Linear( 156 | self.config.hidden_size, self.config.vocab_size, bias=False 157 | ) 158 | 159 | # inference config and kv cache 160 | self.inference_config = inference_config 161 | self.kv_cache = None 162 | 163 | def forward( 164 | self, 165 | input_ids: torch.LongTensor, # shape: [total_length, ] 166 | cu_seqlens: torch.LongTensor, # shape: [batch_size + 1, ] 167 | ): 168 | # embedding 169 | x = self.embedding(input_ids).to(torch.bfloat16) 170 | # layers 171 | for layer in self.layers: 172 | x = layer(x, cu_seqlens) 173 | # final norm 174 | x = self.norm(x) 175 | # lanugauge head 176 | x = self.lm_head(x).to(torch.float32) # [total_len, vocab_size] 177 | return x 178 | 179 | @torch.no_grad() 180 | def inference( 181 | self, 182 | input_ids: torch.LongTensor, # prefill shape: [total_length, ]; decode shape: [batch_size, ] 183 | cu_seqlens: torch.LongTensor, # shape: [batch_size + 1, ] 184 | step: int, 185 | ): 186 | # set kv cache if self.kv_cache is None 187 | if self.kv_cache is None: 188 | self.kv_cache = [ 189 | KVCache( 190 | max_batch_size=self.inference_config.max_batch_size, 191 | max_length=self.inference_config.max_length, 192 | num_kv_heads=self.config.num_key_value_heads, 193 | head_dim=self.config.head_dim, 194 | dtype=torch.bfloat16, 195 | device="cuda", 196 | ) 197 | for _ in range(self.config.num_hidden_layers) 198 | ] 199 | # embedding 200 | x = self.embedding(input_ids).to(torch.bfloat16) 201 | # layers 202 | for i, layer in enumerate(self.layers): 203 | x = layer.inference(x, cu_seqlens, step, self.kv_cache[i]) 204 | # final norm 205 | x = self.norm(x) 206 | # lanugauge head 207 | if step == 0: 208 | x = x[cu_seqlens[1:] - 1, :] 209 | x = self.lm_head(x).to(torch.float32) # [total_len, vocab_size] 210 | return x 211 | 212 | def generate( 213 | self, 214 | input_ids: torch.LongTensor, 215 | cu_seqlens: torch.LongTensor, 216 | max_new_tokens: int = -1, 217 | ): 218 | output_tokens = [] 219 | if max_new_tokens <= 0: 220 | max_new_tokens = self.inference_config.max_new_tokens 221 | for step in range(max_new_tokens): 222 | logits = self.inference( 223 | input_ids, cu_seqlens, step 224 | ) # shape: [batch_size, vocab_size] 225 | next_token = torch.argmax(logits, dim=-1) # shape: [batch_size, ] 226 | input_ids = next_token 227 | output_tokens.append(next_token) 228 | output_tokens = torch.stack( 229 | output_tokens, dim=1 230 | ) # shape: [batch_size, max_new_tokens] 231 | return output_tokens 232 | 233 | 234 | if __name__ == "__main__": 235 | torch.manual_seed(42) 236 | # 
initialize model 237 | config = ToyLlamaConfig( 238 | hidden_size=4096, 239 | intermediate_size=14336, 240 | num_hidden_layers=8, 241 | num_attention_heads=32, 242 | num_key_value_heads=2, 243 | head_dim=128, 244 | rope_theta=500000.0, 245 | rope_scaling={ 246 | "factor": 8.0, 247 | "high_freq_factor": 4.0, 248 | "low_freq_factor": 1.0, 249 | "original_max_position_embeddings": 8192, 250 | "rope_type": "llama3", 251 | }, 252 | ) 253 | inference_config = InferenceConfig( 254 | max_batch_size=4, 255 | max_length=8192, 256 | max_new_tokens=128, 257 | ) 258 | model = ToyLlama(config, inference_config).cuda().bfloat16() 259 | print(f"\nMODEL CONFIG:\n{config}\n") 260 | print(f"\nINFERENCE CONFIG:\n{inference_config}\n") 261 | print(f"\nMODEL:\n{model}\n") 262 | 263 | # example input 264 | batch_size = 4 265 | seqlens = torch.randint(0, 4096, (batch_size,), dtype=torch.int32, device="cuda") 266 | cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda") 267 | cu_seqlens[1:] = torch.cumsum(seqlens, dim=0) 268 | input_ids = torch.randint( 269 | 0, 128288, (cu_seqlens[-1],), dtype=torch.int64, device="cuda" 270 | ) 271 | print(f"\nEXAMPLE INPUT:\ncu_seqlens: {cu_seqlens}\ninput_ids: {input_ids.shape}\n") 272 | 273 | # example output 274 | logits = model(input_ids, cu_seqlens) 275 | print(f"\nEXAMPLE OUTPUT:\nlogits: {logits.shape}\n") 276 | 277 | # example generate 278 | output_tokens = model.generate(input_ids, cu_seqlens, 64) 279 | print(f"\nEXAMPLE GENERATE:\noutput_tokens: {output_tokens}\n") 280 | -------------------------------------------------------------------------------- /native_sparse_attention/model/toy_nsa_llama.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
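# Overview (added editorial comment): this file mirrors toy_llama.py but swaps the dense
# SelfAttention layer for NativeSparseAttention and the plain KVCache for NSACache.
# ToyNSALlamaConfig adds the NSA-specific fields (compress_type, kernel_size, kernel_stride,
# block_size, topk, init_blocks, local_blocks, window_size); forward() trains on packed
# varlen inputs, and inference()/generate() run prefill and greedy decoding on the NSA cache.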
14 | from typing import Optional 15 | import torch 16 | import torch.nn as nn 17 | from dataclasses import dataclass, field 18 | from native_sparse_attention.module import NativeSparseAttention, RopeConfig, NSACache 19 | 20 | 21 | @dataclass 22 | class ToyNSALlamaConfig: 23 | # embedding config 24 | vocab_size: int = 128288 25 | max_position_embeddings: int = 131072 26 | # model config 27 | hidden_size: int = 4096 28 | intermediate_size: int = 14336 29 | num_hidden_layers: int = 32 30 | num_attention_heads: int = 32 31 | num_key_value_heads: int = 2 32 | head_dim: int = 128 33 | # rope config 34 | rope_theta: float = 500000.0 35 | rope_scaling: dict = field( 36 | default_factory=lambda: { 37 | "factor": 8.0, 38 | "high_freq_factor": 4.0, 39 | "low_freq_factor": 1.0, 40 | "original_max_position_embeddings": 8192, 41 | "rope_type": "llama3", 42 | } 43 | ) 44 | # nsa config 45 | compress_type: str = "weightedpool" 46 | kernel_size: int = 32 47 | kernel_stride: int = 16 48 | block_size: int = 64 49 | topk: int = 16 50 | init_blocks: int = 1 51 | local_blocks: int = 2 52 | window_size: int = 512 53 | 54 | 55 | @dataclass 56 | class InferenceConfig: 57 | max_batch_size: int = 32 58 | max_length: int = 8192 59 | max_new_tokens: int = 128 60 | 61 | 62 | class RMSNorm(nn.Module): 63 | def __init__(self, hidden_size: int, eps: float = 1e-6): 64 | super().__init__() 65 | self.weight = nn.Parameter(torch.ones(hidden_size)) 66 | self.variance_epsilon = eps 67 | 68 | def forward(self, hidden_states: torch.Tensor): 69 | input_dtype = hidden_states.dtype 70 | hidden_states = hidden_states.to(torch.float32) 71 | variance = hidden_states.pow(2).mean(-1, keepdim=True) 72 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) 73 | return self.weight * hidden_states.to(input_dtype) 74 | 75 | 76 | class FFN(nn.Module): 77 | def __init__(self, hidden_size: int, intermediate_size: int): 78 | super().__init__() 79 | self.hidden_size = hidden_size 80 | self.intermediate_size = intermediate_size 81 | self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) 82 | self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) 83 | self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) 84 | self.act_fn = nn.SiLU() 85 | 86 | def forward(self, x): 87 | down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) 88 | return down_proj 89 | 90 | 91 | class ToyNSALlamaLayer(nn.Module): 92 | def __init__( 93 | self, 94 | hidden_size: int, 95 | intermediate_size: int, 96 | num_q_heads: int, 97 | num_kv_heads: int, 98 | head_dim: int, 99 | compress_type: str, 100 | kernel_size: int, 101 | kernel_stride: int, 102 | block_size: int, 103 | topk: int, 104 | init_blocks: int, 105 | local_blocks: int, 106 | window_size: int, 107 | rope_config: RopeConfig, 108 | ): 109 | super().__init__() 110 | self.hidden_size = hidden_size 111 | self.intermediate_size = intermediate_size 112 | self.num_q_heads = num_q_heads 113 | self.num_kv_heads = num_kv_heads 114 | self.head_dim = head_dim 115 | self.compress_type = compress_type 116 | self.kernel_size = kernel_size 117 | self.kernel_stride = kernel_stride 118 | self.block_size = block_size 119 | self.topk = topk 120 | self.init_blocks = init_blocks 121 | self.local_blocks = local_blocks 122 | self.window_size = window_size 123 | self.rope_config = rope_config 124 | self.attn_norm = RMSNorm(self.hidden_size) 125 | self.nsa = NativeSparseAttention( 126 | hidden_size=self.hidden_size, 127 | 
num_q_heads=self.num_q_heads, 128 | num_kv_heads=self.num_kv_heads, 129 | head_dim=self.head_dim, 130 | compress_type=self.compress_type, 131 | kernel_size=self.kernel_size, 132 | kernel_stride=self.kernel_stride, 133 | block_size=self.block_size, 134 | topk=self.topk, 135 | init_blocks=self.init_blocks, 136 | local_blocks=self.local_blocks, 137 | window_size=self.window_size, 138 | rope_config=rope_config, 139 | ) 140 | self.ffn_norm = RMSNorm(self.hidden_size) 141 | self.ffn = FFN( 142 | hidden_size=self.hidden_size, intermediate_size=self.intermediate_size 143 | ) 144 | 145 | def forward(self, x, cu_seqlens): 146 | x = x + self.nsa(self.attn_norm(x), cu_seqlens) 147 | x = x + self.ffn(self.ffn_norm(x)) 148 | return x 149 | 150 | @torch.no_grad() 151 | def inference(self, x, cu_seqlens, step, kv_cache): 152 | x = x + self.nsa.inference(self.attn_norm(x), cu_seqlens, step, kv_cache) 153 | x = x + self.ffn(self.ffn_norm(x)) 154 | return x 155 | 156 | 157 | class ToyNSALlama(nn.Module): 158 | def __init__( 159 | self, 160 | config: ToyNSALlamaConfig, 161 | inference_config: Optional[InferenceConfig] = None, 162 | ): 163 | super().__init__() 164 | self.config = config 165 | self.embedding = nn.Embedding(self.config.vocab_size, self.config.hidden_size) 166 | self.rope_config = RopeConfig( 167 | head_dim=self.config.head_dim, 168 | rope_theta=self.config.rope_theta, 169 | rope_scaling=self.config.rope_scaling, 170 | ) 171 | self.layers = nn.ModuleList( 172 | [ 173 | ToyNSALlamaLayer( 174 | hidden_size=self.config.hidden_size, 175 | intermediate_size=self.config.intermediate_size, 176 | num_q_heads=self.config.num_attention_heads, 177 | num_kv_heads=self.config.num_key_value_heads, 178 | head_dim=self.config.head_dim, 179 | compress_type=self.config.compress_type, 180 | kernel_size=self.config.kernel_size, 181 | kernel_stride=self.config.kernel_stride, 182 | block_size=self.config.block_size, 183 | topk=self.config.topk, 184 | init_blocks=self.config.init_blocks, 185 | local_blocks=self.config.local_blocks, 186 | window_size=self.config.window_size, 187 | rope_config=RopeConfig( 188 | self.config.max_position_embeddings, 189 | self.config.head_dim, 190 | self.config.rope_theta, 191 | self.config.rope_scaling, 192 | ), 193 | ) 194 | for _ in range(self.config.num_hidden_layers) 195 | ] 196 | ) 197 | self.norm = RMSNorm(self.config.hidden_size) 198 | self.lm_head = nn.Linear( 199 | self.config.hidden_size, self.config.vocab_size, bias=False 200 | ) 201 | 202 | # inference config and kv cache 203 | self.inference_config = inference_config 204 | self.kv_cache = None 205 | 206 | def forward( 207 | self, 208 | input_ids: torch.LongTensor, # shape: [batch_size, max_length] 209 | cu_seqlens: torch.LongTensor, # shape: [batch_size + 1, ] 210 | ): 211 | # embedding 212 | x = self.embedding(input_ids).to(torch.bfloat16) 213 | # layers 214 | for layer in self.layers: 215 | x = layer(x, cu_seqlens) 216 | # final norm 217 | x = self.norm(x) 218 | # lanugauge head 219 | x = self.lm_head(x).to(torch.float32) # [total_len, vocab_size] 220 | return x 221 | 222 | @torch.no_grad() 223 | def inference( 224 | self, 225 | input_ids: torch.LongTensor, # prefill shape: [total_length, ]; decode shape: [batch_size, ] 226 | cu_seqlens: torch.LongTensor, # shape: [batch_size + 1, ] 227 | step: int, 228 | ): 229 | # set kv cache if self.kv_cache is None 230 | if self.kv_cache is None: 231 | self.kv_cache = [ 232 | NSACache( 233 | max_batch_size=self.inference_config.max_batch_size, 234 | 
max_length=self.inference_config.max_length, 235 | num_kv_heads=self.config.num_key_value_heads, 236 | head_dim=self.config.head_dim, 237 | kernel_size=self.config.kernel_size, 238 | kernel_stride=self.config.kernel_stride, 239 | window_size=self.config.window_size, 240 | dtype=torch.bfloat16, 241 | device="cuda", 242 | ) 243 | for _ in range(self.config.num_hidden_layers) 244 | ] 245 | # embedding 246 | x = self.embedding(input_ids).to(torch.bfloat16) 247 | # layers 248 | for i, layer in enumerate(self.layers): 249 | x = layer.inference(x, cu_seqlens, step, self.kv_cache[i]) 250 | # final norm 251 | x = self.norm(x) 252 | # lanugauge head 253 | if step == 0: 254 | x = x[cu_seqlens[1:] - 1, :] 255 | x = self.lm_head(x).to(torch.float32) # [total_len, vocab_size] 256 | return x 257 | 258 | def generate( 259 | self, 260 | input_ids: torch.LongTensor, 261 | cu_seqlens: torch.LongTensor, 262 | max_new_tokens: int = -1, 263 | ): 264 | output_tokens = [] 265 | if max_new_tokens <= 0: 266 | max_new_tokens = self.inference_config.max_new_tokens 267 | for step in range(max_new_tokens): 268 | logits = self.inference( 269 | input_ids, cu_seqlens, step 270 | ) # shape: [batch_size, vocab_size] 271 | next_token = torch.argmax(logits, dim=-1) # shape: [batch_size, ] 272 | input_ids = next_token 273 | output_tokens.append(next_token) 274 | output_tokens = torch.stack( 275 | output_tokens, dim=1 276 | ) # shape: [batch_size, max_new_tokens] 277 | return output_tokens 278 | 279 | 280 | if __name__ == "__main__": 281 | torch.manual_seed(42) 282 | # initialize model 283 | config = ToyNSALlamaConfig( 284 | hidden_size=4096, 285 | intermediate_size=14336, 286 | num_hidden_layers=8, 287 | num_attention_heads=32, 288 | num_key_value_heads=2, 289 | head_dim=128, 290 | rope_theta=500000.0, 291 | rope_scaling={ 292 | "factor": 8.0, 293 | "high_freq_factor": 4.0, 294 | "low_freq_factor": 1.0, 295 | "original_max_position_embeddings": 8192, 296 | "rope_type": "llama3", 297 | }, 298 | compress_type="weightedpool", 299 | kernel_size=32, 300 | kernel_stride=16, 301 | block_size=64, 302 | topk=8, 303 | init_blocks=1, 304 | local_blocks=2, 305 | window_size=512, 306 | ) 307 | inference_config = InferenceConfig( 308 | max_batch_size=4, 309 | max_length=8192, 310 | max_new_tokens=128, 311 | ) 312 | model = ToyNSALlama(config, inference_config).cuda().bfloat16() 313 | print(f"\nMODEL CONFIG:\n{config}\n") 314 | print(f"\nINFERENCE CONFIG:\n{inference_config}\n") 315 | print(f"\nMODEL:\n{model}\n") 316 | 317 | # example input 318 | batch_size = 4 319 | seqlens = torch.randint(0, 4096, (batch_size,), dtype=torch.int32, device="cuda") 320 | cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda") 321 | cu_seqlens[1:] = torch.cumsum(seqlens, dim=0) 322 | input_ids = torch.randint( 323 | 0, 128288, (cu_seqlens[-1],), dtype=torch.int64, device="cuda" 324 | ) 325 | print(f"\nEXAMPLE INPUT:\ncu_seqlens: {cu_seqlens}\ninput_ids: {input_ids.shape}\n") 326 | 327 | # example output 328 | logits = model(input_ids, cu_seqlens) 329 | print(f"\nEXAMPLE OUTPUT:\nlogits: {logits.shape}\n") 330 | 331 | # example generate 332 | output_tokens = model.generate(input_ids, cu_seqlens, 64) 333 | print(f"\nEXAMPLE GENERATE:\noutput_tokens: {output_tokens}\n") 334 | -------------------------------------------------------------------------------- /native_sparse_attention/module/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from native_sparse_attention.module.native_sparse_attention import NativeSparseAttention 15 | from native_sparse_attention.module.self_attention import SelfAttention 16 | from native_sparse_attention.module.rope import RotaryEmbedding, RopeConfig 17 | from native_sparse_attention.module.kv_cache import NSACache, KVCache 18 | 19 | __all__ = [ 20 | "SelfAttention", 21 | "NativeSparseAttention", 22 | "RotaryEmbedding", 23 | "RopeConfig", 24 | "NSACache", 25 | "KVCache", 26 | ] 27 | -------------------------------------------------------------------------------- /native_sparse_attention/module/native_sparse_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
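# A hedged usage sketch for the module defined below, importable via the package
# exports listed above; the values are borrowed from ToyNSALlamaConfig and are not a
# prescribed configuration:
#
#     from native_sparse_attention.module import NativeSparseAttention, RopeConfig
#     nsa = NativeSparseAttention(
#         compress_type="weightedpool", hidden_size=4096,
#         num_q_heads=32, num_kv_heads=2, head_dim=128,
#         kernel_size=32, kernel_stride=16, block_size=64, topk=16,
#         init_blocks=1, local_blocks=2, window_size=512,
#         rope_config=RopeConfig(head_dim=128),
#     ).cuda().bfloat16()
#     out = nsa(x, cu_seqlens)  # x: [total_len, 4096] in bfloat16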
14 | import torch 15 | from flash_attn import flash_attn_varlen_func 16 | from native_sparse_attention.ops import ( 17 | compressed_attention, 18 | topk_sparse_attention, 19 | avgpool_compress, 20 | weightedpool_compress, 21 | linear_compress, 22 | ) 23 | from einops import rearrange 24 | from native_sparse_attention.module.rope import RopeConfig, RotaryEmbedding 25 | from native_sparse_attention.infer import nsa_infer 26 | from native_sparse_attention.module.kv_cache import NSACache 27 | 28 | COMPRESS_TYPE_TO_FUNC = { 29 | "avgpool": avgpool_compress, 30 | "weightedpool": weightedpool_compress, 31 | "linear": linear_compress, 32 | } 33 | 34 | COMPRESS_TYPE_TO_WEIGHT = { 35 | "avgpool": lambda num_heads, head_dim, kernel_size: None, 36 | "weightedpool": lambda num_heads, head_dim, kernel_size: torch.nn.Parameter( 37 | torch.zeros(num_heads, kernel_size) 38 | ), 39 | "linear": lambda num_heads, head_dim, kernel_size: torch.nn.Parameter( 40 | torch.zeros(num_heads, head_dim * kernel_size, head_dim) 41 | ), 42 | } 43 | 44 | 45 | class NativeSparseAttention(torch.nn.Module): 46 | """Native sparse attention module, supports training and inference 47 | 48 | Args: 49 | compress_type (str): key value compression type, currently supports ['linear', 'avgpool', 'weightedpool'] 50 | hidden_size (int): hidden dimension 51 | num_q_heads (int): number of query heads 52 | num_kv_heads (int): number of key/value heads; must divide num_q_heads 53 | head_dim (int): head dim 54 | kernel_size (int): kernel size of compression 55 | kernel_stride (int): kernel stride for compression 56 | block_size (int): block size of sparse attention 57 | topk (int): topk of sparse attention 58 | init_blocks (int): number of blocks at the beginning of the sequence; these blocks are forced to be computed in sparse attention 59 | local_blocks (int): number of blocks in the local window of each query; these blocks are forced to be computed in sparse attention 60 | window_size (int): window size for sliding window attention 61 | rope_config (RopeConfig): config for rotary embedding, see native_sparse_attention.module.rope.RopeConfig for details 62 | rope_device (str): device used to store rope freqs 63 | """ 64 | 65 | def __init__( 66 | self, 67 | compress_type: str, 68 | hidden_size: int, 69 | num_q_heads: int, 70 | num_kv_heads: int, 71 | head_dim: int, 72 | kernel_size: int, 73 | kernel_stride: int, 74 | block_size: int, 75 | topk: int, 76 | init_blocks: int, 77 | local_blocks: int, 78 | window_size: int, 79 | rope_config: RopeConfig, 80 | rope_device: str = "cuda", 81 | ): 82 | super().__init__() 83 | # configs 84 | self.compress_type = compress_type 85 | self.hidden_size = hidden_size 86 | self.num_q_heads = num_q_heads 87 | self.num_kv_heads = num_kv_heads 88 | self.head_dim = head_dim 89 | self.kernel_size = kernel_size 90 | self.kernel_stride = kernel_stride 91 | self.block_size = block_size 92 | self.topk = topk 93 | self.init_blocks = init_blocks 94 | self.local_blocks = local_blocks 95 | self.window_size = window_size 96 | self.rope_config = rope_config 97 | assert self.head_dim == self.rope_config.head_dim 98 | 99 | # qkv proj and o proj 100 | self.proj_q = torch.nn.Linear( 101 | self.hidden_size, self.num_q_heads * self.head_dim, bias=False 102 | ) 103 | self.proj_k = torch.nn.Linear( 104 | self.hidden_size, self.num_kv_heads * self.head_dim, bias=False 105 | ) 106 | self.proj_v = torch.nn.Linear( 107 | self.hidden_size, self.num_kv_heads * self.head_dim, bias=False 108 | ) 109 | self.proj_o = torch.nn.Linear( 110
| self.num_q_heads * self.head_dim, self.hidden_size, bias=False 111 | ) 112 | 113 | # nsa compress func 114 | self.compress_func = COMPRESS_TYPE_TO_FUNC[self.compress_type] 115 | 116 | # nsa parameteres 117 | self.compress_key = COMPRESS_TYPE_TO_WEIGHT[self.compress_type]( 118 | num_kv_heads, head_dim, kernel_size 119 | ) 120 | 121 | self.compress_value = COMPRESS_TYPE_TO_WEIGHT[self.compress_type]( 122 | num_kv_heads, head_dim, kernel_size 123 | ) 124 | self.intra_block_pe = torch.nn.Parameter( 125 | torch.zeros(self.num_kv_heads, self.kernel_size, self.head_dim) 126 | ) 127 | 128 | # gate function 129 | self.gate = torch.nn.Sequential( 130 | torch.nn.Linear(self.hidden_size, self.num_q_heads * 3, bias=False), 131 | torch.nn.Sigmoid(), 132 | ) 133 | 134 | # rope 135 | self.rope = RotaryEmbedding(self.rope_config, device=rope_device) 136 | 137 | # init parameters 138 | self.init_params() 139 | 140 | def init_params(self): 141 | for p in self.parameters(): 142 | if len(p.shape) > 1: 143 | torch.nn.init.xavier_uniform_(p) 144 | 145 | def forward( 146 | self, 147 | x: torch.Tensor, # shape: [total_len, hidden_size] 148 | cu_seqlens: torch.Tensor, # shape: [batch_size + 1] 149 | ): 150 | # dtype and shape check 151 | assert x.dtype == torch.bfloat16 or x.dtype == torch.float16 152 | assert x.shape[-1] == self.hidden_size 153 | cu_seqlens = cu_seqlens.to(torch.int32) 154 | seqlens = cu_seqlens[1:] - cu_seqlens[:-1] 155 | 156 | # qkv proj 157 | q = self.proj_q(x).view(-1, self.num_q_heads, self.head_dim) 158 | k = self.proj_k(x).view(-1, self.num_kv_heads, self.head_dim) 159 | v = self.proj_v(x).view(-1, self.num_kv_heads, self.head_dim) 160 | 161 | # compressed key and value before rope 162 | compressed_k, compressed_cu_seqlens = self.compress_func( 163 | k, 164 | self.compress_key, 165 | cu_seqlens, 166 | self.kernel_size, 167 | self.kernel_stride, 168 | self.intra_block_pe, 169 | ) 170 | compressed_v, _ = self.compress_func( 171 | v, 172 | self.compress_value, 173 | cu_seqlens, 174 | self.kernel_size, 175 | self.kernel_stride, 176 | None, 177 | ) 178 | 179 | # do rope for query and compressed key 180 | q = self.rope(q, cu_seqlens) 181 | compressed_k = self.rope( 182 | compressed_k, compressed_cu_seqlens, stride=self.kernel_stride 183 | ) 184 | 185 | # attention between query and compressed key value 186 | compressed_seqlens = compressed_cu_seqlens[1:] - compressed_cu_seqlens[:-1] 187 | compressed_attn_output, topk_idx = compressed_attention( 188 | q, 189 | compressed_k, 190 | compressed_v, 191 | self.kernel_size, 192 | self.kernel_stride, 193 | self.block_size, 194 | self.topk, 195 | cu_seqlens, 196 | compressed_cu_seqlens, 197 | seqlens.max().item(), 198 | compressed_seqlens.max().item(), 199 | None, 200 | self.init_blocks, 201 | self.local_blocks, 202 | ) 203 | 204 | # do rope for original key 205 | k = self.rope(k, cu_seqlens) 206 | 207 | # topk sparse attention 208 | sparse_attn_output = topk_sparse_attention( 209 | q, k, v, topk_idx, self.block_size, cu_seqlens, None 210 | ) 211 | 212 | # sliding window attention 213 | sliding_attn_output = flash_attn_varlen_func( 214 | q, 215 | k, 216 | v, 217 | cu_seqlens, 218 | cu_seqlens, 219 | seqlens.max().item(), 220 | seqlens.max().item(), 221 | causal=True, 222 | window_size=(self.window_size, -1), 223 | ) 224 | 225 | # gate average 226 | gate = self.gate(x) 227 | gate = rearrange(gate, "n (h g) -> n h g", g=3) 228 | attn_output = ( 229 | gate[..., 0:1] * compressed_attn_output 230 | + gate[..., 1:2] * sparse_attn_output 231 | + gate[..., 2:3] 
* sliding_attn_output 232 | ) 233 | 234 | # rearrange and output proj 235 | attn_output = rearrange(attn_output, "n h d -> n (h d)") 236 | attn_output = self.proj_o(attn_output) 237 | 238 | return attn_output 239 | 240 | @torch.no_grad() 241 | def inference( 242 | self, 243 | x: torch.Tensor, # shape: [total_len, hidden_size] 244 | cu_seqlens: torch.Tensor, # shape: [batch_size + 1] 245 | step: int, 246 | cache: NSACache, 247 | ): 248 | # dtype and shape check 249 | assert x.dtype == torch.bfloat16 or x.dtype == torch.float16 250 | assert x.shape[-1] == self.hidden_size 251 | cu_seqlens = cu_seqlens.to(torch.int32) 252 | assert step >= 0 253 | if step == 0: 254 | assert x.shape[0] == cu_seqlens[-1] 255 | else: 256 | assert x.shape[0] == cu_seqlens.shape[0] - 1 257 | # qkv proj 258 | q = self.proj_q(x).view(-1, self.num_q_heads, self.head_dim) 259 | k = self.proj_k(x).view(-1, self.num_kv_heads, self.head_dim) 260 | v = self.proj_v(x).view(-1, self.num_kv_heads, self.head_dim) 261 | # gate proj 262 | gate = self.gate(x) 263 | gate = rearrange(gate, "n (h g) -> n h g", g=3) 264 | # nsa infer 265 | output = nsa_infer( 266 | cu_seqlens, 267 | step, 268 | q, 269 | k, 270 | v, 271 | gate, 272 | self.rope, 273 | cache, 274 | [self.compress_key, self.compress_value], 275 | [self.compress_func, self.compress_func], 276 | self.intra_block_pe, 277 | self.kernel_size, 278 | self.kernel_stride, 279 | self.block_size, 280 | self.topk, 281 | self.init_blocks, 282 | self.local_blocks, 283 | self.window_size, 284 | ) 285 | # output proj 286 | output = rearrange(output, "n h d -> n (h d)") 287 | output = self.proj_o(output) 288 | return output 289 | -------------------------------------------------------------------------------- /native_sparse_attention/module/rope.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu. 2 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 5 | # and OPT implementations in this library. It has been modified from its 6 | # original forms to accommodate minor architectural differences compared 7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 
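# Note on the `stride` argument used by the NSA module above: compressed tokens are
# emitted every `kernel_stride` original positions, so their rotary phase is taken at
# positions 0, stride, 2*stride, ... rather than 0, 1, 2, ... (see `generate_pos_embs`
# below). A small illustration, assuming stride = 16:
#
#     compressed token index:  0,  1,  2,  3, ...
#     rotary position used:    0, 16, 32, 48, ...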
20 | import torch 21 | from dataclasses import dataclass, field 22 | from torch import nn 23 | from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS 24 | 25 | 26 | # default to llama3.1 rope config 27 | @dataclass 28 | class RopeConfig: 29 | """Config for RotaryEmbedding, similar to transformers llama.""" 30 | 31 | max_position_embeddings: int = 131072 32 | head_dim: int = 128 33 | rope_theta: float = 500000 34 | rope_scaling: dict = field( 35 | default_factory=lambda: { 36 | "factor": 8.0, 37 | "high_freq_factor": 4.0, 38 | "low_freq_factor": 1.0, 39 | "original_max_position_embeddings": 8192, 40 | "rope_type": "llama3", 41 | } 42 | ) 43 | # unused, kept only for compatibility; please use head_dim instead 44 | hidden_size: int = 1 45 | num_attention_heads: int = 1 46 | 47 | def __post_init__(self): 48 | self.num_attention_heads = 1 49 | self.hidden_size = self.head_dim 50 | 51 | 52 | # Copied from transformers.models.llama.modeling_llama.rotate_half 53 | def rotate_half(x): 54 | """Rotates half the hidden dims of the input.""" 55 | x1 = x[..., : x.shape[-1] // 2] 56 | x2 = x[..., x.shape[-1] // 2 :] 57 | return torch.cat((-x2, x1), dim=-1) 58 | 59 | 60 | # copied and modified from huggingface transformers 61 | # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py 62 | class RotaryEmbedding(nn.Module): 63 | """Rotary embedding 64 | 65 | Args: 66 | config (RopeConfig): config for rotary embedding, see native_sparse_attention.module.rope.RopeConfig for details 67 | device (str): defaults to 'cuda' 68 | """ 69 | 70 | cos = None 71 | sin = None 72 | 73 | def __init__( 74 | self, config: RopeConfig, device=torch.device(torch.cuda.current_device()) 75 | ): 76 | super().__init__() 77 | # BC: "rope_type" was originally "type" 78 | if hasattr(config, "rope_scaling") and config.rope_scaling is not None: 79 | self.rope_type = config.rope_scaling.get( 80 | "rope_type", config.rope_scaling.get("type") 81 | ) 82 | else: 83 | self.rope_type = "default" 84 | self.max_seq_len_cached = config.max_position_embeddings 85 | self.original_max_seq_len = config.max_position_embeddings 86 | 87 | self.config = config 88 | self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] 89 | 90 | inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) 91 | self.register_buffer("inv_freq", inv_freq, persistent=False) 92 | self.original_inv_freq = self.inv_freq 93 | 94 | def _dynamic_frequency_update(self, position_ids, device): 95 | """ 96 | dynamic RoPE layers should recompute `inv_freq` in the following situations: 97 | 1 - growing beyond the cached sequence length (allow scaling) 98 | 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) 99 | """ 100 | seq_len = torch.max(position_ids) + 1 101 | if seq_len > self.max_seq_len_cached: # growth 102 | inv_freq, self.attention_scaling = self.rope_init_fn( 103 | self.config, device, seq_len=seq_len 104 | ) 105 | self.register_buffer( 106 | "inv_freq", inv_freq, persistent=False 107 | ) # TODO joao: may break with compilation 108 | self.max_seq_len_cached = seq_len 109 | 110 | if ( 111 | seq_len < self.original_max_seq_len 112 | and self.max_seq_len_cached > self.original_max_seq_len 113 | ): # reset 114 | # This .to() is needed if the model has been moved to a device after being initialized (because 115 | # the buffer is automatically moved, but not the original copy) 116 | self.original_inv_freq = self.original_inv_freq.to(device) 117 |
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) 118 | self.max_seq_len_cached = self.original_max_seq_len 119 | 120 | @torch.no_grad() 121 | def generate_cos_sin(self, x: torch.Tensor, position_ids): 122 | if "dynamic" in self.rope_type: 123 | self._dynamic_frequency_update(position_ids, device=x.device) 124 | 125 | # Core RoPE block 126 | inv_freq_expanded = ( 127 | self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) 128 | ) 129 | position_ids_expanded = position_ids[:, None, :].float() 130 | # Force float32 (see https://github.com/huggingface/transformers/pull/29285) 131 | device_type = x.device.type 132 | device_type = ( 133 | device_type 134 | if isinstance(device_type, str) and device_type != "mps" 135 | else "cpu" 136 | ) 137 | with torch.autocast(device_type=device_type, enabled=False): 138 | freqs = ( 139 | inv_freq_expanded.float() @ position_ids_expanded.float() 140 | ).transpose(1, 2) 141 | # # donot use this if use flash_attn 142 | # emb = torch.cat((freqs, freqs), dim=-1) 143 | emb = freqs 144 | cos = emb.cos() 145 | sin = emb.sin() 146 | 147 | # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention 148 | cos = (cos * self.attention_scaling).to(dtype=x.dtype).squeeze(0) 149 | sin = (sin * self.attention_scaling).to(dtype=x.dtype).squeeze(0) 150 | 151 | # save cos sin 152 | RotaryEmbedding.cos = torch.cat([cos, cos], dim=-1) 153 | RotaryEmbedding.sin = torch.cat([sin, sin], dim=-1) 154 | 155 | return RotaryEmbedding.cos, RotaryEmbedding.sin 156 | 157 | @torch.no_grad() 158 | def generate_pos_embs( 159 | self, 160 | x: torch.Tensor, 161 | cu_seqlens: torch.Tensor, 162 | seqlens: torch.Tensor, 163 | step: int = 0, 164 | stride: int = 1, 165 | ): 166 | if ( 167 | RotaryEmbedding.cos is None 168 | or seqlens.max() + step > RotaryEmbedding.cos.shape[0] 169 | ): 170 | self.generate_cos_sin( 171 | x, torch.arange(seqlens.max() + step).to(x.device)[None, :] 172 | ) 173 | 174 | cos_embs = [] 175 | sin_embs = [] 176 | bsz = len(cu_seqlens) - 1 177 | 178 | for i in range(bsz): 179 | if step == 0: # prefilling 180 | r = cu_seqlens[i + 1] - cu_seqlens[i] 181 | cos_emb, sin_emb = ( 182 | RotaryEmbedding.cos[: r * stride : stride], 183 | RotaryEmbedding.sin[: r * stride : stride], 184 | ) 185 | elif step > 0: # decoding 186 | r = cu_seqlens[i + 1] - cu_seqlens[i] + step - 1 187 | cos_emb, sin_emb = ( 188 | RotaryEmbedding.cos[r * stride : r * stride + 1], 189 | RotaryEmbedding.sin[r * stride : r * stride + 1], 190 | ) 191 | cos_embs.append(cos_emb) 192 | sin_embs.append(sin_emb) 193 | 194 | cos_embs = torch.cat(cos_embs, dim=0) 195 | sin_embs = torch.cat(sin_embs, dim=0) 196 | return cos_embs, sin_embs 197 | 198 | def forward(self, x, cu_seqlens, step=0, stride=1): 199 | seqlens = cu_seqlens[1:] - cu_seqlens[:-1] 200 | cos_embs, sin_embs = self.generate_pos_embs( 201 | x, 202 | cu_seqlens, 203 | seqlens, 204 | step=step, 205 | stride=stride, 206 | ) 207 | N, H, D = x.shape[0], x.shape[-2], x.shape[-1] # H: number of heads 208 | x = x * cos_embs.view(N, 1, D) + rotate_half(x) * sin_embs.view(N, 1, D) 209 | return x 210 | -------------------------------------------------------------------------------- /native_sparse_attention/module/self_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import torch 15 | from flash_attn import flash_attn_varlen_func 16 | from einops import rearrange 17 | from native_sparse_attention.module.rope import RopeConfig, RotaryEmbedding 18 | from native_sparse_attention.module.kv_cache import KVCache 19 | from native_sparse_attention.ops import flash_attention_decode 20 | 21 | 22 | class SelfAttention(torch.nn.Module): 23 | """self attention module 24 | 25 | Args: 26 | hidden_size (int): hidden dimension 27 | num_q_heads (int): number of query heads 28 | num_kv_heads (int): number of key/value heads, must be divisible by num_q_heads 29 | head_dim (int): head dim 30 | rope_config (RopeConfig): config for rotary embedding, see native_sparse_attention.module.rope.RopeConfig for details 31 | """ 32 | 33 | def __init__( 34 | self, 35 | hidden_size: int, 36 | num_q_heads: int, 37 | num_kv_heads: int, 38 | head_dim: int, 39 | rope_config: RopeConfig, 40 | rope_device: str = "cuda", 41 | ): 42 | super().__init__() 43 | # configs 44 | self.hidden_size = hidden_size 45 | self.num_q_heads = num_q_heads 46 | self.num_kv_heads = num_kv_heads 47 | self.head_dim = head_dim 48 | self.rope_config = rope_config 49 | assert self.head_dim == self.rope_config.head_dim 50 | 51 | # qkv proj and o proj 52 | self.proj_q = torch.nn.Linear( 53 | self.hidden_size, self.num_q_heads * self.head_dim, bias=False 54 | ) 55 | self.proj_k = torch.nn.Linear( 56 | self.hidden_size, self.num_kv_heads * self.head_dim, bias=False 57 | ) 58 | self.proj_v = torch.nn.Linear( 59 | self.hidden_size, self.num_kv_heads * self.head_dim, bias=False 60 | ) 61 | self.proj_o = torch.nn.Linear( 62 | self.num_q_heads * self.head_dim, self.hidden_size, bias=False 63 | ) 64 | # rope 65 | self.rope = RotaryEmbedding(self.rope_config, device=rope_device) 66 | 67 | # init parameters 68 | self.init_params() 69 | 70 | def init_params(self): 71 | for p in self.parameters(): 72 | torch.nn.init.xavier_uniform_(p) 73 | 74 | def forward( 75 | self, 76 | x: torch.Tensor, # shape: [total_len, hidden_size] 77 | cu_seqlens: torch.Tensor, # shape: [batch_size + 1] 78 | ): 79 | # dtype and shape check 80 | assert x.dtype == torch.bfloat16 or x.dtype == torch.float16 81 | assert x.shape[-1] == self.hidden_size 82 | cu_seqlens = cu_seqlens.to(torch.int32) 83 | seqlens = cu_seqlens[1:] - cu_seqlens[:-1] 84 | 85 | # qkv proj 86 | q = self.proj_q(x).view(-1, self.num_q_heads, self.head_dim) 87 | k = self.proj_k(x).view(-1, self.num_kv_heads, self.head_dim) 88 | v = self.proj_v(x).view(-1, self.num_kv_heads, self.head_dim) 89 | 90 | # do rope for query and compressed key 91 | q = self.rope(q, cu_seqlens) 92 | k = self.rope(k, cu_seqlens) 93 | 94 | # self attention 95 | attn_output = flash_attn_varlen_func( 96 | q, 97 | k, 98 | v, 99 | cu_seqlens, 100 | cu_seqlens, 101 | seqlens.max().item(), 102 | seqlens.max().item(), 103 | causal=True, 104 | ) 105 | 106 | # rearrange and output proj 107 | attn_output = rearrange(attn_output, "n h d 
-> n (h d)") 108 | attn_output = self.proj_o(attn_output) 109 | 110 | return attn_output 111 | 112 | @torch.no_grad() 113 | def inference( 114 | self, 115 | x: torch.Tensor, # shape: [total_len, hidden_size] 116 | cu_seqlens: torch.Tensor, # shape: [batch_size + 1] 117 | step: int, 118 | cache: KVCache, 119 | ): 120 | # dtype and shape check 121 | assert x.dtype == torch.bfloat16 or x.dtype == torch.float16 122 | assert x.shape[-1] == self.hidden_size 123 | cu_seqlens = cu_seqlens.to(torch.int32) 124 | seqlens = cu_seqlens[1:] - cu_seqlens[:-1] 125 | assert step >= 0 126 | if step == 0: 127 | assert x.shape[0] == cu_seqlens[-1] 128 | else: 129 | assert x.shape[0] == cu_seqlens.shape[0] - 1 130 | batch_size = cu_seqlens.shape[0] - 1 131 | # qkv proj 132 | q = self.proj_q(x).view(-1, self.num_q_heads, self.head_dim) 133 | k = self.proj_k(x).view(-1, self.num_kv_heads, self.head_dim) 134 | v = self.proj_v(x).view(-1, self.num_kv_heads, self.head_dim) 135 | # do rope for query and compressed key 136 | q = self.rope(q, cu_seqlens, step) 137 | k = self.rope(k, cu_seqlens, step) 138 | # reset and update kv cache 139 | if step == 0: 140 | cache.reset() 141 | cache.update_kv(cu_seqlens, step, k, v) 142 | # self attention 143 | if step == 0: 144 | cu_seqlens_q = cu_seqlens_k = cu_seqlens 145 | max_seqlen_in_batch_q = max_seqlen_in_batch_k = seqlens.max().item() 146 | output = flash_attn_varlen_func( 147 | q, 148 | k, 149 | v, 150 | cu_seqlens_q=cu_seqlens_q, 151 | cu_seqlens_k=cu_seqlens_k, 152 | max_seqlen_q=max_seqlen_in_batch_q, 153 | max_seqlen_k=max_seqlen_in_batch_k, 154 | causal=True, 155 | ) 156 | else: 157 | output = flash_attention_decode( 158 | q, 159 | cache.kv_cache[0, :batch_size], 160 | cache.kv_cache[1, :batch_size], 161 | cache.kv_len[:batch_size], 162 | ) 163 | # rearrange and output proj 164 | output = rearrange(output, "n h d -> n (h d)") 165 | output = self.proj_o(output) 166 | return output 167 | -------------------------------------------------------------------------------- /native_sparse_attention/ops/README.md: -------------------------------------------------------------------------------- 1 | # Triton Functions for Native Sparse Attention 2 | 3 | This folder provides efficient Triton-based implementations of components for Native Sparse Attention. This README introduces the available functions, explains how to set them up, and offers guidance on their usage. 4 | 5 | --- 6 | 7 | ## Overview of Functions 8 | 9 | The functions are organized into two main categories: 10 | 11 | 1. **Compression Methods**: Techniques for compressing key and value tensors. 12 | 2. **Attention Mechanisms**: Methods for computing attention between queries and compressed key/value tensors, including top-k sparse attention. 13 | 14 | --- 15 | 16 | ## Function Descriptions 17 | 18 | ### Compression Methods 19 | 20 | These functions compress key and value tensors using a sliding window approach. Within each window, `kernel_size` tokens are compressed into a single token, with a stride of `kernel_stride`. All compression functions share similar input parameters and return formats. 
21 | 22 | **Parameters:** 23 | - `x`: Input tensor (`total_len, num_heads, head_dim`) 24 | - `w`: Weight tensor (shape varies by compression method) 25 | - `cu_seqlens`: Cumulative sequence lengths (`batch_size + 1`) 26 | - `kernel_size`: Size of the compression window 27 | - `kernel_stride`: Stride of the compression window 28 | - `pe`: Optional positional embedding (`num_heads, kernel_size, head_dim`) 29 | 30 | **Returns:** 31 | - Compressed tensor (`total_compress_len, num_heads, head_dim`) 32 | - Cumulative sequence lengths (`com_cu_seqlens`) for the compressed tensor 33 | 34 | #### `weightedpool_compress` 35 | Compresses the input tensor using weighted pooling, applying a weighted sum over each block: 36 | $\hat{k} = w_1 k_1 + \dots + w_m k_m$ 37 | - **Weight shape**: `(num_heads, kernel_size)` 38 | 39 | #### `avgpool_compress` 40 | Compresses the input tensor using average pooling: 41 | $\hat{k} = (k_1 + \dots + k_m) / m$ 42 | - **Weight**: Must be `None` 43 | 44 | #### `linear_compress` 45 | Compresses the input tensor via linear projection, mapping each block to a single vector using learned weights: 46 | $\hat{k} = \text{cat}(k_1, \dots, k_m) W$ 47 | - **Weight shape**: `(num_heads, kernel_size * head_dim, head_dim)` 48 | 49 | --- 50 | 51 | ### Attention Mechanisms 52 | 53 | These functions compute attention using either full or sparse mechanisms. 54 | 55 | #### `flash_attention_varlen` 56 | A variable-length implementation of flash attention, similar to `flash_attn_varlen_func` from the `flash_attn` package. 57 | 58 | **Parameters:** 59 | - `q`, `k`, `v`: Query, key, and value tensors (`total_len, num_heads, head_dim`) 60 | - `cu_seqlens_q`, `cu_seqlens_k`: Cumulative sequence lengths for queries and keys 61 | - `max_seqlen_q`, `max_seqlen_k`: Maximum sequence lengths in the batch 62 | - `causal`: Apply causal masking (default: `False`) 63 | - `sm_scale`: Softmax scale (default: `1 / sqrt(head_dim)`) 64 | 65 | **Returns:** 66 | - Attention output tensor (`total_q_len, num_q_heads, head_dim`) 67 | 68 | #### `compressed_attention` 69 | Computes attention between a query and compressed key/value tensors, identifying the top-k blocks for sparse attention. 70 | 71 | **Parameters:** 72 | - `q`: Query tensor (`total_len, num_heads, head_dim`) 73 | - `k`, `v`: Compressed key and value tensors (`total_compress_len, num_heads, head_dim`) 74 | - `kernel_size`, `kernel_stride`: Compression parameters 75 | - `block_size`: Size of blocks for sparse attention 76 | - `topk`: Number of top blocks to select 77 | - `cu_seqlens_q`, `cu_seqlens_k`: Cumulative sequence lengths for query and compressed key/value 78 | - `max_seqlen_q`, `max_seqlen_k`: Maximum sequence lengths for query and compressed key/value 79 | - `sm_scale`: Softmax scale (default: `1 / sqrt(head_dim)`) 80 | - `init_blocks`: Number of initial blocks forced to be selected (default: `1`) 81 | - `local_blocks`: Number of local blocks forced to be selected (default: `2`) 82 | 83 | **Returns:** 84 | - Tuple containing: 85 | - Attention output tensor 86 | - Top-k block indices 87 | 88 | #### `topk_sparse_attention` 89 | Performs sparse attention using precomputed top-k block indices. If a query attends to fewer than `topk` key/value blocks, the `topk_idx` should be padded with `-1` on the right. 
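For example, with `topk = 4`, a query whose causal prefix covers only three key/value blocks would carry an index row like the one below (illustrative values only; real indices come from `compressed_attention`):

```python
import torch

# one row of topk_idx for a single (kv_head, query) pair,
# right-padded with -1 because only 3 blocks are selectable
topk_idx_row = torch.tensor([0, 1, 2, -1], dtype=torch.int32)
```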
90 | 91 | **Parameters:** 92 | - `q`, `k`, `v`: Query, key, and value tensors (`total_len, num_heads, head_dim`) 93 | - `topk_idx`: Precomputed top-k indices (`num_kv_heads, total_len, topk`) 94 | - `block_size`: Block size for sparse attention (recommended: `64` or `128`) 95 | - `cu_seqlens`: Cumulative sequence lengths 96 | - `softmax_scale`: Softmax scale (default: `1 / sqrt(head_dim)`) 97 | 98 | **Returns:** 99 | - Attention output tensor (`total_len, num_q_heads, head_dim`) 100 | 101 | --- 102 | 103 | ## Usage Example 104 | 105 | Below is a typical workflow demonstrating how to combine these sparse attention functions: 106 | 107 | ```python 108 | import torch 109 | from native_sparse_attention.ops import linear_compress, compressed_attention, topk_sparse_attention 110 | 111 | # Example input setup 112 | num_q_heads = 64 113 | num_kv_heads = 4 114 | head_dim = 128 115 | cu_seqlens = torch.tensor([0, 1024, 8192, 16384], dtype=torch.int32).cuda() 116 | 117 | # Query, key, and value tensors 118 | query = torch.randn(16384, num_q_heads, head_dim, dtype=torch.bfloat16).cuda() 119 | key = torch.randn(16384, num_kv_heads, head_dim, dtype=torch.bfloat16).cuda() 120 | value = torch.randn(16384, num_kv_heads, head_dim, dtype=torch.bfloat16).cuda() 121 | 122 | # Compression weights and positional embeddings 123 | kernel_size = 32 124 | kernel_stride = 16 125 | wk = torch.randn(num_kv_heads, kernel_size * head_dim, head_dim, dtype=torch.bfloat16).cuda() 126 | wv = torch.randn_like(wk) 127 | pe = torch.randn(num_kv_heads, kernel_size, head_dim, dtype=torch.bfloat16).cuda() 128 | 129 | # Parameters for top-k sparse attention 130 | block_size = 64 131 | topk = 16 132 | 133 | # 1. Compress key and value tensors 134 | compressed_key, compressed_cu_seqlens = linear_compress( 135 | key, wk, cu_seqlens, kernel_size, kernel_stride, pe 136 | ) 137 | compressed_value, _ = linear_compress( 138 | value, wv, cu_seqlens, kernel_size, kernel_stride, None 139 | ) 140 | 141 | # 2. Compute attention with compressed key/value and get top-k indices 142 | compressed_attn_output, topk_idx = compressed_attention( 143 | query, 144 | compressed_key, 145 | compressed_value, 146 | kernel_size, 147 | kernel_stride, 148 | block_size, 149 | topk, 150 | cu_seqlens, 151 | compressed_cu_seqlens, 152 | init_blocks=1, 153 | local_blocks=2, 154 | ) 155 | 156 | # 3. Perform top-k sparse attention 157 | sparse_attn_output = topk_sparse_attention( 158 | query, 159 | key, 160 | value, 161 | topk_idx, 162 | block_size, 163 | cu_seqlens, 164 | ) 165 | 166 | # 4. Combine attention outputs (e.g., average) 167 | attn_output = (compressed_attn_output + sparse_attn_output) / 2 168 | ``` 169 | 170 | For a complete implementation of the Native Sparse Attention module, see `native_sparse_attention/module/native_sparse_attention.py`. 171 | -------------------------------------------------------------------------------- /native_sparse_attention/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # compress method 16 | from native_sparse_attention.ops.triton.weighted_pool import ( 17 | weightedpool_compress, 18 | avgpool_compress, 19 | ) 20 | from native_sparse_attention.ops.triton.linear_compress import linear_compress 21 | 22 | # prefill attention 23 | from native_sparse_attention.ops.triton.flash_attention import flash_attention_varlen 24 | from native_sparse_attention.ops.triton.compressed_attention import compressed_attention 25 | from native_sparse_attention.ops.triton.topk_sparse_attention import ( 26 | topk_sparse_attention, 27 | ) 28 | 29 | # decode attention 30 | from native_sparse_attention.ops.triton.flash_attention_decode import ( 31 | flash_attention_decode, 32 | ) 33 | from native_sparse_attention.ops.torch.compressed_attention_decode import ( 34 | compressed_attention_decode, 35 | ) 36 | from native_sparse_attention.ops.triton.topk_sparse_attention_decode import ( 37 | topk_sparse_attention_decode, 38 | ) 39 | 40 | __all__ = [ 41 | # compress method 42 | "avgpool_compress", 43 | "weightedpool_compress", 44 | "linear_compress", 45 | # prefill attention, trainable 46 | "flash_attention_varlen", 47 | "compressed_attention", 48 | "topk_sparse_attention", 49 | # decode attention, no grad 50 | "flash_attention_decode", 51 | "compressed_attention_decode", 52 | "topk_sparse_attention_decode", 53 | ] 54 | -------------------------------------------------------------------------------- /native_sparse_attention/ops/torch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XunhaoLai/native-sparse-attention-triton/9bea856c911ebf263be88d797fb28458f82f1d94/native_sparse_attention/ops/torch/__init__.py -------------------------------------------------------------------------------- /native_sparse_attention/ops/torch/compress_key_value.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import torch 15 | from typing import Optional 16 | from einops import rearrange, einsum 17 | 18 | 19 | def avgpool_compress_torch( 20 | x: torch.Tensor, 21 | w: torch.Tensor, 22 | cu_seqlens, 23 | kernel_size: int, 24 | kernel_stride: int, 25 | pe: Optional[torch.Tensor] = None, 26 | ): 27 | """Compress key and value tensor with kernel_size and kernel_stride. 
28 | 29 | Args: 30 | x (torch.Tensor): key_states or value_states, shape (total_len, num_heads, head_dim) 31 | w (torch.Tensor): no weight for avgpool, must be None. 32 | cu_seqlens (_type_): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen. 33 | kernel_size (int): kernel_size, each (kernel_size, head_dim) blocks will be compressed to (1, head_dim) 34 | kernel_stride (int): stride for each compress kernel 35 | pe (Optional[torch.Tensor], optional): intra-block positional embedding with shape (num_heads, kernel_size, head_dim). Defaults to None. 36 | 37 | Returns: 38 | Tuple[torch.Tensor, torch.Tensor]: compressed states and corresponding cu_seqlens. 39 | """ 40 | # dtype check 41 | assert x.dtype == torch.float16 or x.dtype == torch.bfloat16 42 | assert cu_seqlens.dtype == torch.int32 43 | assert x.dtype == pe.dtype if pe is not None else True 44 | 45 | # shape check 46 | total_len, num_heads, head_dim = x.shape 47 | batch_size = cu_seqlens.shape[0] - 1 48 | assert w is None, "don't need additional weight for avgpool" 49 | assert kernel_size % kernel_stride == 0 50 | assert kernel_size in {16, 32, 64, 128} 51 | 52 | # compute seqlens after compression 53 | seqlens = cu_seqlens[1:] - cu_seqlens[:-1] 54 | y_seqlens = torch.floor((seqlens - kernel_size) / kernel_stride).to(torch.int32) + 1 55 | # corner case, if sequence_length < kernel_size, no compression for this sequence 56 | y_seqlens[seqlens < kernel_size] = 0 57 | y_cu_seqlens = torch.cat( 58 | [ 59 | torch.zeros(1, dtype=torch.int32, device="cuda"), 60 | torch.cumsum(y_seqlens, dim=0), 61 | ], 62 | dim=0, 63 | ).to(torch.int32) 64 | 65 | # pad and rearrange x 66 | x = rearrange(x, "n h d -> n (h d)") 67 | splited_x = torch.split(x, seqlens.tolist(), 0) 68 | x = torch.nn.utils.rnn.pad_sequence(splited_x, batch_first=True) 69 | x = rearrange(x, "b n d -> b d n") 70 | # avgpool 71 | y = torch.nn.functional.avg_pool1d(x, kernel_size=kernel_size, stride=kernel_stride) 72 | y = rearrange(y, "b (h d) n -> b n h d", h=num_heads) 73 | # only keep useful part 74 | y = torch.cat([y[i, : y_seqlens[i]] for i in range(batch_size)], dim=0) 75 | 76 | # position embedding as a bias 77 | if pe is not None: 78 | bias = torch.mean(pe, dim=1) 79 | y = y + bias.unsqueeze(0) 80 | 81 | return y, y_cu_seqlens 82 | 83 | 84 | def weightedpool_compress_torch( 85 | x: torch.Tensor, 86 | w: torch.Tensor, # [num_heads, kernel_size] 87 | cu_seqlens, 88 | kernel_size: int, 89 | kernel_stride: int, 90 | pe: Optional[torch.Tensor] = None, 91 | ): 92 | """Compress key and value tensor with kernel_size and kernel_stride. 93 | 94 | Args: 95 | x (torch.Tensor): key_states or value_states, shape (total_len, num_heads, head_dim) 96 | w (torch.Tensor): weight for each head, shape (num_heads, kernel_size) 97 | cu_seqlens (_type_): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen. 98 | kernel_size (int): kernel_size, each (kernel_size, head_dim) blocks will be compressed to (1, head_dim) 99 | kernel_stride (int): stride for each compress kernel 100 | pe (Optional[torch.Tensor], optional): intra-block positional embedding with shape (num_heads, kernel_size, head_dim). Defaults to None. 101 | 102 | Returns: 103 | Tuple[torch.Tensor, torch.Tensor]: compressed states and corresponding cu_seqlens. 
104 | """ 105 | # dtype check 106 | assert x.dtype == torch.float16 or x.dtype == torch.bfloat16 107 | assert x.dtype == w.dtype 108 | assert x.dtype == pe.dtype if pe is not None else True 109 | assert cu_seqlens.dtype == torch.int32 110 | # shape check 111 | total_len, num_heads, head_dim = x.shape 112 | batch_size = cu_seqlens.shape[0] - 1 113 | assert w.shape[0] == num_heads 114 | assert w.shape[1] == kernel_size 115 | assert kernel_size % kernel_stride == 0 116 | assert kernel_size in {16, 32, 64, 128} 117 | # compute seqlens after compression 118 | seqlens = cu_seqlens[1:] - cu_seqlens[:-1] 119 | y_seqlens = torch.floor((seqlens - kernel_size) / kernel_stride).to(torch.int32) + 1 120 | # corner case, if sequence_length < kernel_size, no compression for this sequence 121 | y_seqlens[seqlens < kernel_size] = 0 122 | y_cu_seqlens = torch.cat( 123 | [ 124 | torch.zeros(1, dtype=torch.int32, device="cuda"), 125 | torch.cumsum(y_seqlens, dim=0), 126 | ], 127 | dim=0, 128 | ).to(torch.int32) 129 | # pad and rearrange x 130 | x = rearrange(x, "n h d -> n (h d)") 131 | splited_x = torch.split(x, seqlens.tolist(), 0) 132 | x = torch.nn.utils.rnn.pad_sequence(splited_x, batch_first=True) 133 | x = rearrange(x, "b n (h d) -> b h n d", h=num_heads) 134 | x = x.as_strided( 135 | size=(batch_size, num_heads, y_seqlens.max().item(), kernel_size, head_dim), 136 | stride=( 137 | x.stride(0), 138 | x.stride(1), 139 | kernel_stride * x.stride(2), 140 | x.stride(2), 141 | x.stride(3), 142 | ), 143 | ) 144 | y = einsum(x, w, "b h n k d, h k -> b n h d") 145 | # only keep useful part 146 | y = torch.cat([y[i, : y_seqlens[i]] for i in range(batch_size)], dim=0) 147 | 148 | # position embedding as a bias 149 | if pe is not None: 150 | bias = einsum(pe, w, "h k d, h k -> h d") 151 | y = y + bias.unsqueeze(0) 152 | 153 | return y, y_cu_seqlens 154 | 155 | 156 | def linear_compress_torch( 157 | x: torch.Tensor, 158 | w: torch.Tensor, # [num_heads, kernel_size * head_dim, head_dim] 159 | cu_seqlens, 160 | kernel_size: int, 161 | kernel_stride: int, 162 | pe: Optional[torch.Tensor] = None, 163 | ): 164 | """Compress key and value tensor with kernel_size and kernel_stride. Similar to conv_compress. 165 | 166 | Args: 167 | x (torch.Tensor): key_states or value_states, shape (total_len, num_heads, head_dim) 168 | w (torch.Tensor): weight for each head, shape (num_heads, kernel_size * head_dim, head_dim) 169 | cu_seqlens (_type_): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen. 170 | kernel_size (int): kernel_size, each (kernel_size, head_dim) blocks will be compressed to (1, head_dim) 171 | kernel_stride (int): stride for each compress kernel 172 | pe (Optional[torch.Tensor], optional): intra-block positional embedding with shape (num_heads, kernel_size, head_dim). Defaults to None. 173 | 174 | Returns: 175 | Tuple[torch.Tensor, torch.Tensor]: compressed states and corresponding cu_seqlens. 
176 | """ 177 | # dtype check 178 | assert x.dtype == torch.float16 or x.dtype == torch.bfloat16 179 | assert x.dtype == w.dtype 180 | assert x.dtype == pe.dtype if pe is not None else True 181 | assert cu_seqlens.dtype == torch.int32 182 | # shape check 183 | total_len, num_heads, head_dim = x.shape 184 | batch_size = cu_seqlens.shape[0] - 1 185 | assert w.shape[0] == num_heads 186 | assert w.shape[1] == kernel_size * head_dim 187 | assert w.shape[2] == head_dim 188 | assert kernel_size % kernel_stride == 0 189 | assert kernel_size in {16, 32, 64, 128} 190 | # compute seqlens after compression 191 | seqlens = cu_seqlens[1:] - cu_seqlens[:-1] 192 | y_seqlens = torch.floor((seqlens - kernel_size) / kernel_stride).to(torch.int32) + 1 193 | # corner case, if sequence_length < kernel_size, no compression for this sequence 194 | y_seqlens[seqlens < kernel_size] = 0 195 | y_cu_seqlens = torch.cat( 196 | [ 197 | torch.zeros(1, dtype=torch.int32, device="cuda"), 198 | torch.cumsum(y_seqlens, dim=0), 199 | ], 200 | dim=0, 201 | ).to(torch.int32) 202 | # pad and rearrange x 203 | x = rearrange(x, "n h d -> n (h d)") 204 | splited_x = torch.split(x, seqlens.tolist(), 0) 205 | x = torch.nn.utils.rnn.pad_sequence(splited_x, batch_first=True) 206 | x = rearrange(x, "b n (h d) -> b h n d", h=num_heads) 207 | x = x.as_strided( 208 | size=(batch_size, num_heads, y_seqlens.max().item(), kernel_size, head_dim), 209 | stride=( 210 | x.stride(0), 211 | x.stride(1), 212 | kernel_stride * x.stride(2), 213 | x.stride(2), 214 | x.stride(3), 215 | ), 216 | ) 217 | y = einsum( 218 | x, 219 | rearrange(w, "h (k d) D -> h k d D", k=kernel_size), 220 | "b h n k d, h k d D -> b n h D", 221 | ) 222 | # only keep useful part 223 | y = torch.cat([y[i, : y_seqlens[i]] for i in range(batch_size)], dim=0) 224 | 225 | # position embedding as a bias 226 | if pe is not None: 227 | pe = rearrange(pe, "h k d -> h (k d)") 228 | bias = einsum(pe, w, "h D, h D d -> h d") 229 | y = y + bias.unsqueeze(0) 230 | 231 | return y, y_cu_seqlens 232 | -------------------------------------------------------------------------------- /native_sparse_attention/ops/torch/compressed_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import torch 15 | import math 16 | from typing import Tuple 17 | from collections import Counter 18 | from einops import rearrange 19 | 20 | 21 | def transform_score( 22 | score: torch.Tensor, 23 | kernel_size: int, 24 | kernel_stride: int, 25 | block_size: int, 26 | cu_seqlens_q: torch.Tensor, 27 | cu_seqlens_k: torch.Tensor, 28 | max_seqlen_q: int, 29 | max_seqlen_k: int, 30 | init_blocks: int = 1, 31 | local_blocks: int = 2, 32 | ) -> torch.Tensor: 33 | num_k_heads, total_query_len, _ = score.shape 34 | pad_len = kernel_size // kernel_stride - 1 35 | score = torch.nn.functional.pad(score, (pad_len, pad_len), value=0) 36 | max_blocks = math.ceil(max_seqlen_q / block_size) 37 | full_blocks = max_seqlen_q // block_size 38 | block_score = torch.zeros( 39 | num_k_heads, 40 | total_query_len, 41 | max_blocks, 42 | dtype=torch.float32, 43 | device=score.device, 44 | ) 45 | offs = ( 46 | torch.arange(kernel_size // kernel_stride)[:, None] 47 | + torch.arange(block_size // kernel_stride)[None, :] 48 | ).view(-1) 49 | offs = dict(Counter(offs.tolist())) 50 | for k, v in offs.items(): 51 | block_score[..., :full_blocks] += ( 52 | v * score[..., k :: block_size // kernel_stride][..., :full_blocks] 53 | ) 54 | # set init block and local block score 55 | batch_size = cu_seqlens_q.shape[0] - 1 56 | q_idx = torch.cat( 57 | [ 58 | torch.arange(cu_seqlens_q[i + 1] - cu_seqlens_q[i], device=score.device) 59 | for i in range(batch_size) 60 | ], 61 | dim=0, 62 | ) 63 | q_idx = q_idx // block_size 64 | b_idx = torch.arange(max_blocks, device=score.device) 65 | block_score[..., :init_blocks] = torch.inf 66 | local_mask = (q_idx[:, None] >= b_idx[None, :]) & ( 67 | q_idx[:, None] < b_idx[None, :] + local_blocks 68 | ) 69 | local_mask = local_mask.unsqueeze(0).expand(num_k_heads, -1, -1) 70 | block_score[local_mask] = torch.inf 71 | return block_score 72 | 73 | 74 | def compressed_attention_torch( 75 | q: torch.Tensor, # [total_query_len, num_q_heads, head_dim] 76 | k: torch.Tensor, # [total_key_len, num_k_heads, head_dim] 77 | v: torch.Tensor, # [total_key_len, num_k_heads, head_dim] 78 | kernel_size: int, 79 | kernel_stride: int, 80 | block_size: int, 81 | topk: int, 82 | cu_seqlens_q: torch.Tensor, 83 | cu_seqlens_k: torch.Tensor, 84 | max_seqlen_q: int, 85 | max_seqlen_k: int, 86 | sm_scale: float = None, 87 | init_blocks: int = 1, 88 | local_blocks: int = 2, 89 | ) -> Tuple[torch.Tensor, torch.Tensor]: 90 | """Attention between query and compressed key and value. Implemented with torch, only for debug. 91 | 92 | Args: 93 | q (torch.Tensor): shape [total_q_len, num_q_heads, head_dim] 94 | k (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim] 95 | v (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim] 96 | kernel_size (int): kernel size in compress_key_value 97 | kernel_stride (int): stride of compress_key_value 98 | block_size (int): key value block size for topk sparse attention. 99 | topk (int): number of blocks for each query. 100 | cu_seqlens_q (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen. 101 | cu_seqlens_k (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_k in flash_attn_func_varlen. 102 | max_seqlen_q (int): max q len of the batch. 103 | max_seqlen_k (int): max k len of the batch. 104 | sm_scale (float, optional): softmax scale. Defaults to None, means 1/sqrt(head_dim). 105 | init_blocks (int, optional): Number of init blocks for each query. Defaults to 1. 
106 | local_blocks (int, optional): Number of local blocks for each query. Defaults to 2. 107 | 108 | Returns: 109 | Tuple[torch.Tensor, torch.Tensor]: attention output and topk_idx used in topk_sparse_attention 110 | """ 111 | assert block_size % kernel_size == 0 and kernel_size % kernel_stride == 0 112 | total_query_len, num_q_heads, head_dim = q.shape 113 | total_key_len, num_k_heads, _ = k.shape 114 | num_share_q_heads = num_q_heads // num_k_heads 115 | batch_size = cu_seqlens_q.shape[0] - 1 116 | if sm_scale is None: 117 | sm_scale = 1.0 / math.sqrt(head_dim) 118 | # get mask 119 | mask = torch.zeros( 120 | (total_query_len, total_key_len), dtype=torch.bool, device=q.device 121 | ) 122 | for b in range(batch_size): 123 | q_len, k_len = ( 124 | cu_seqlens_q[b + 1] - cu_seqlens_q[b], 125 | cu_seqlens_k[b + 1] - cu_seqlens_k[b], 126 | ) 127 | k_max_ids = ( 128 | torch.arange(k_len, device=q.device) * kernel_stride + kernel_size - 1 129 | ) 130 | q_ids = torch.arange(q_len, device=q.device) 131 | mask[ 132 | cu_seqlens_q[b] : cu_seqlens_q[b + 1], cu_seqlens_k[b] : cu_seqlens_k[b + 1] 133 | ] = (q_ids[:, None] >= k_max_ids[None, :]) 134 | # attention 135 | qk = ( 136 | torch.einsum("qhd,khd->hqk", q, k.repeat_interleave(num_share_q_heads, 1)) 137 | * sm_scale 138 | ) 139 | qk = qk.masked_fill_(~mask[None, ...], -torch.inf) 140 | # query from beginning of the sequence can't attend to any compressed key 141 | qk = qk.softmax(dim=-1, dtype=torch.float32) 142 | qk = qk.nan_to_num(0) 143 | attn_output = torch.einsum( 144 | "hqk,khd->qhd", qk.to(v.dtype), v.repeat_interleave(num_share_q_heads, 1) 145 | ) 146 | with torch.no_grad(): 147 | # get avg score over gqa heads 148 | # qk shape [num_k_heads, total_q_len, total_k_len] 149 | score = torch.zeros( 150 | num_k_heads, 151 | cu_seqlens_q[-1], 152 | max_seqlen_k, 153 | dtype=torch.float32, 154 | device=q.device, 155 | ) 156 | qk = rearrange(qk, "(h g) q k -> h g q k", h=num_k_heads).sum(1) 157 | for b in range(batch_size): 158 | score[ 159 | :, 160 | cu_seqlens_q[b] : cu_seqlens_q[b + 1], 161 | : cu_seqlens_k[b + 1] - cu_seqlens_k[b], 162 | ] = qk[ 163 | :, 164 | cu_seqlens_q[b] : cu_seqlens_q[b + 1], 165 | cu_seqlens_k[b] : cu_seqlens_k[b + 1], 166 | ] 167 | # transform score to block-wise score 168 | score = transform_score( 169 | score, 170 | kernel_size, 171 | kernel_stride, 172 | block_size, 173 | cu_seqlens_q, 174 | cu_seqlens_k, 175 | max_seqlen_q, 176 | max_seqlen_k, 177 | init_blocks, 178 | local_blocks, 179 | ) 180 | # get topk 181 | batch_size = cu_seqlens_q.shape[0] - 1 182 | q_idx = torch.cat( 183 | [ 184 | torch.arange(cu_seqlens_q[i + 1] - cu_seqlens_q[i], device=q.device) 185 | for i in range(batch_size) 186 | ], 187 | dim=0, 188 | ) 189 | q_idx = q_idx // block_size 190 | topk = min(topk, score.shape[-1]) 191 | topk_idx = score.topk(topk, dim=-1).indices.sort(-1).values 192 | topk_idx[topk_idx > q_idx[None, :, None]] = -1 193 | topk_idx = topk_idx.to(torch.int32) 194 | return attn_output, topk_idx 195 | -------------------------------------------------------------------------------- /native_sparse_attention/ops/torch/compressed_attention_decode.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import torch 15 | import math 16 | from typing import Tuple, Optional 17 | from collections import Counter 18 | from einops import rearrange 19 | 20 | 21 | def transform_score( 22 | score: torch.Tensor, 23 | seqlens: torch.Tensor, 24 | kernel_size: int, 25 | kernel_stride: int, 26 | block_size: int, 27 | init_blocks: int = 1, 28 | local_blocks: int = 2, 29 | ) -> torch.Tensor: 30 | num_k_heads, batch_size, kv_len = score.shape 31 | pad_len = kernel_size // kernel_stride - 1 32 | score = torch.nn.functional.pad(score, (pad_len, pad_len), value=0) 33 | max_seqlen = seqlens.max().item() 34 | max_blocks = math.ceil(max_seqlen / block_size) 35 | full_blocks = max_seqlen // block_size 36 | block_score = torch.zeros( 37 | num_k_heads, 38 | batch_size, 39 | max_blocks, 40 | dtype=torch.float32, 41 | device=score.device, 42 | ) 43 | offs = ( 44 | torch.arange(kernel_size // kernel_stride)[:, None] 45 | + torch.arange(block_size // kernel_stride)[None, :] 46 | ).view(-1) 47 | offs = dict(Counter(offs.tolist())) 48 | for k, v in offs.items(): 49 | block_score[..., :full_blocks] += ( 50 | v * score[..., k :: block_size // kernel_stride][..., :full_blocks] 51 | ) 52 | # set init block and local block score 53 | q_idx = (seqlens - 1) // block_size 54 | b_idx = torch.arange(max_blocks, device=score.device) 55 | block_score[..., :init_blocks] = torch.inf 56 | local_mask = (q_idx[:, None] >= b_idx[None, :]) & ( 57 | q_idx[:, None] < b_idx[None, :] + local_blocks 58 | ) 59 | local_mask = local_mask.unsqueeze(0).expand(num_k_heads, -1, -1) 60 | block_score[local_mask] = torch.inf 61 | block_score = block_score.nan_to_num(0, torch.inf, -torch.inf) 62 | return block_score 63 | 64 | 65 | def compressed_attention_decode( 66 | q: torch.Tensor, 67 | k: torch.Tensor, 68 | v: torch.Tensor, 69 | seqlens: torch.Tensor, 70 | compress_seqlens: torch.Tensor, 71 | kernel_size: int, 72 | kernel_stride: int, 73 | block_size: int, 74 | topk: int, 75 | init_blocks: int = 1, 76 | local_blocks: int = 2, 77 | sm_scale: Optional[float] = None, 78 | ) -> Tuple[torch.Tensor, torch.Tensor]: 79 | """_summary_ 80 | 81 | Args: 82 | q (torch.Tensor): shape [batch_size, num_q_heads, head_dim] 83 | k (torch.Tensor): shape [batch_size, kv_len, num_kv_heads, head_dim] 84 | v (torch.Tensor): shape [batch_size, kv_len, num_kv_heads, head_dim] 85 | seqlens (torch.Tensor): original kv length for each sequence 86 | compress_seqlens (torch.Tensor): kv length for each sequence after compression 87 | kernel_size (int): kernel size in compress_key_value 88 | kernel_stride (int): stride of compress_key_value 89 | block_size (int): key value block size for topk sparse attention. 90 | topk (int): number of blocks for each query. 91 | init_blocks (int, optional): Number of init blocks for each query. Defaults to 1. 92 | local_blocks (int, optional): Number of local blocks for each query. Defaults to 2. 93 | sm_scale (float, optional): softmax scale. Defaults to None, means 1/sqrt(head_dim). 
94 | 95 | Returns: 96 | Tuple[torch.Tensor, torch.Tensor]: attention output and topk_idx used in topk_sparse_attention_decode 97 | """ 98 | assert block_size % kernel_size == 0 and kernel_size % kernel_stride == 0 99 | batch_size, num_q_heads, head_dim = q.shape 100 | batch_size, kv_len, num_k_heads, _ = k.shape 101 | num_share_q_heads = num_q_heads // num_k_heads 102 | if sm_scale is None: 103 | sm_scale = 1.0 / math.sqrt(head_dim) 104 | # input is too short to have a valid block 105 | if kv_len == 0: 106 | return torch.zeros_like(q), torch.zeros( 107 | num_k_heads, batch_size, 1, device=q.device, dtype=torch.int32 108 | ) 109 | # get mask 110 | mask = ( 111 | compress_seqlens[:, None] 112 | > torch.arange( 113 | kv_len, device=compress_seqlens.device, dtype=compress_seqlens.dtype 114 | )[None, :] 115 | ) 116 | # attention 117 | qk = ( 118 | torch.einsum( 119 | "bihgd, bjhgd -> bhgij", 120 | rearrange(q, "b (h g) d -> b 1 h g d", g=num_share_q_heads), 121 | rearrange(k, "b j h d -> b j h 1 d"), 122 | ) 123 | * sm_scale 124 | ) 125 | qk = qk.masked_fill_(~mask[:, None, None, None, :], -torch.inf) 126 | qk = qk.softmax(dim=-1, dtype=torch.float32) 127 | qk = qk.nan_to_num_(0) # qk is nan when seqlen == 0 128 | attn_output = torch.einsum( 129 | "bhgij, bjhgd -> bihgd", 130 | qk.to(v.dtype), 131 | rearrange(v, "b k h d -> b k h 1 d"), 132 | ) 133 | attn_output = rearrange(attn_output, "b 1 h g d -> b (h g) d") 134 | 135 | # get score 136 | score = rearrange(qk.sum(2).squeeze(2), "b h j -> h b j") 137 | # transform score to block-wise score 138 | score = transform_score( 139 | score, 140 | seqlens, 141 | kernel_size, 142 | kernel_stride, 143 | block_size, 144 | init_blocks, 145 | local_blocks, 146 | ) 147 | # get topk 148 | q_idx = (seqlens - 1) // block_size 149 | topk = min(topk, score.shape[-1]) 150 | topk_idx = score.topk(topk, dim=-1).indices.sort(-1).values 151 | topk_idx[topk_idx > q_idx[None, :, None]] = -1 152 | topk_idx = topk_idx.to(torch.int32) 153 | return attn_output, topk_idx 154 | -------------------------------------------------------------------------------- /native_sparse_attention/ops/torch/topk_sparse_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import torch 15 | import math 16 | from typing import Optional 17 | 18 | 19 | def topk_sparse_attention_torch( 20 | q: torch.Tensor, 21 | k: torch.Tensor, 22 | v: torch.Tensor, 23 | topk_idx: torch.Tensor, 24 | block_size_k: int, 25 | cu_seqlens: torch.Tensor, 26 | softmax_scale: Optional[float] = None, 27 | block_size_q: int = 1, 28 | ) -> torch.Tensor: 29 | """Simple topk sparse attention varlen version implemented in torch. Extremly slow, only for debugging. 
30 | 31 | Args: 32 | q (torch.Tensor): shape [total_len, num_q_heads, head_dim] 33 | k (torch.Tensor): shape [total_len, num_kv_heads, head_dim] 34 | v (torch.Tensor): shape [total_len, num_kv_heads, head_dim] 35 | topk_idx (torch.Tensor): topk block idx for each query, shape [num_kv_heads, total_len, topk]. -1 means padding. 36 | block_size_q (int): query block size. 37 | block_size_k (int): key value block size. 38 | cu_seqlens (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens in flash_attn_func_varlen. 39 | softmax_scale (Optional[float], optional): Defaults to None, means 1/sqrt(head_dim). 40 | 41 | Returns: 42 | torch.Tensor: attention output, shape [total_len, num_q_heads, head_dim] 43 | """ 44 | total_seqlen, num_q_heads, head_dim = q.shape 45 | total_seqlen, num_kv_heads, head_dim = k.shape 46 | num_share_q_heads = num_q_heads // num_kv_heads 47 | batch_size = cu_seqlens.shape[0] - 1 48 | topk = topk_idx.shape[-1] 49 | seqlens = cu_seqlens[1:] - cu_seqlens[:-1] 50 | seqblocks_q = torch.ceil(seqlens / block_size_q).to(torch.int32) 51 | cu_seqblocks_q = torch.nn.functional.pad(seqblocks_q.cumsum(0), (1, 0), value=0) 52 | if softmax_scale is None: 53 | softmax_scale = 1.0 / math.sqrt(head_dim) 54 | # get mask 55 | mask = torch.zeros( 56 | (num_kv_heads, total_seqlen, total_seqlen), dtype=torch.bool, device=q.device 57 | ) 58 | for i in range(batch_size): 59 | num_q_blocks = math.ceil(seqlens[i] / block_size_q) 60 | num_kv_blocks = math.ceil(seqlens[i] / block_size_k) 61 | for h in range(num_kv_heads): 62 | temp_mask = torch.zeros( 63 | num_q_blocks, num_kv_blocks, dtype=torch.bool, device=q.device 64 | ) 65 | temp_idx = topk_idx[h, cu_seqblocks_q[i] : cu_seqblocks_q[i + 1]].clone() 66 | temp_idx[temp_idx < 0] = 0 67 | temp_mask[torch.arange(num_q_blocks).to(q.device)[:, None], temp_idx] = True 68 | temp_mask = torch.repeat_interleave(temp_mask, block_size_q, dim=0) 69 | temp_mask = torch.repeat_interleave(temp_mask, block_size_k, dim=1) 70 | temp_mask = temp_mask[: seqlens[i], : seqlens[i]] 71 | mask[ 72 | h, cu_seqlens[i] : cu_seqlens[i + 1], cu_seqlens[i] : cu_seqlens[i + 1] 73 | ] = temp_mask 74 | mask = torch.tril(mask).repeat_interleave(num_share_q_heads, 0) 75 | # qk attn 76 | qk = ( 77 | torch.einsum("qhd,khd->hqk", q, k.repeat_interleave(num_share_q_heads, 1)) 78 | * softmax_scale 79 | ) 80 | qk = torch.masked_fill(qk, ~mask, -torch.inf) 81 | qk = torch.softmax(qk, dim=-1, dtype=torch.float32).to(q.dtype) 82 | o = torch.einsum("hqk,khd->qhd", qk, v.repeat_interleave(num_share_q_heads, 1)) 83 | return o 84 | -------------------------------------------------------------------------------- /native_sparse_attention/ops/triton/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XunhaoLai/native-sparse-attention-triton/9bea856c911ebf263be88d797fb28458f82f1d94/native_sparse_attention/ops/triton/__init__.py -------------------------------------------------------------------------------- /native_sparse_attention/ops/triton/flash_attention_decode.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import math 16 | import torch 17 | import triton 18 | import triton.language as tl 19 | from typing import Optional 20 | 21 | 22 | @triton.jit 23 | def decode_kernel( 24 | q_ptr, # Q: b x h x d 25 | k_ptr, # K: b x n x h x d 26 | v_ptr, # V: b x n x h x d 27 | o_ptr, # O: b x h x d 28 | seqlens, 29 | # shape 30 | BATCH_SIZE, 31 | NUM_SHARE_Q_HEADS, 32 | HEAD_DIM, 33 | # sm_scale 34 | sm_scale, 35 | # stride 36 | stride_qb, 37 | stride_qh, 38 | stride_qd, 39 | stride_kb, 40 | stride_kn, 41 | stride_kh, 42 | stride_kd, 43 | stride_vb, 44 | stride_vn, 45 | stride_vh, 46 | stride_vd, 47 | stride_ob, 48 | stride_oh, 49 | stride_od, 50 | # META parameters 51 | BLOCK_SIZE_B: tl.constexpr, 52 | BLOCK_SIZE_K: tl.constexpr, 53 | BLOCK_SIZE_D: tl.constexpr, 54 | ): 55 | qk_scale = sm_scale * 1.44269504 56 | # get batch id and head id 57 | pid_h = tl.program_id(0) 58 | pid_b = tl.program_id(1) 59 | pid_kh = pid_h // NUM_SHARE_Q_HEADS 60 | # get q k start and len after rmpad 61 | off_b = tl.arange(0, BLOCK_SIZE_B) 62 | kv_len = tl.load( 63 | seqlens + pid_b * BLOCK_SIZE_B + off_b, 64 | mask=pid_b * BLOCK_SIZE_B + off_b < BATCH_SIZE, 65 | other=0, 66 | ) 67 | max_kv_len = tl.max(kv_len) 68 | # init qkv pointer 69 | q_ptrs = tl.make_block_ptr( 70 | base=q_ptr + pid_h * stride_qh, 71 | shape=(BATCH_SIZE, HEAD_DIM), 72 | strides=(stride_qb, stride_qd), 73 | offsets=(pid_b * BLOCK_SIZE_B, 0), 74 | block_shape=(BLOCK_SIZE_B, BLOCK_SIZE_D), 75 | order=(1, 0), 76 | ) 77 | k_ptrs = tl.make_block_ptr( 78 | base=k_ptr + pid_kh * stride_kh, 79 | shape=(BATCH_SIZE, max_kv_len, HEAD_DIM), 80 | strides=(stride_kb, stride_kn, stride_kd), 81 | offsets=(pid_b * BLOCK_SIZE_B, 0, 0), 82 | block_shape=(BLOCK_SIZE_B, BLOCK_SIZE_K, BLOCK_SIZE_D), 83 | order=(2, 1, 0), 84 | ) 85 | v_ptrs = tl.make_block_ptr( 86 | base=v_ptr + pid_kh * stride_vh, 87 | shape=(BATCH_SIZE, max_kv_len, HEAD_DIM), 88 | strides=(stride_vb, stride_vn, stride_vd), 89 | offsets=(pid_b * BLOCK_SIZE_B, 0, 0), 90 | block_shape=(BLOCK_SIZE_B, BLOCK_SIZE_K, BLOCK_SIZE_D), 91 | order=(2, 1, 0), 92 | ) 93 | # load q 94 | q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero") 95 | # init statistics 96 | off_k = tl.arange(0, BLOCK_SIZE_K) 97 | m_i = tl.full((BLOCK_SIZE_B,), float("-inf"), dtype=tl.float32) 98 | lse_i = tl.full((BLOCK_SIZE_B,), float("-inf"), dtype=tl.float32) 99 | acc_o = tl.full((BLOCK_SIZE_B, BLOCK_SIZE_D), 0, dtype=tl.float32) 100 | # full attention or causal attention 101 | for i in range(0, max_kv_len, BLOCK_SIZE_K): 102 | i = tl.multiple_of(i, BLOCK_SIZE_K) 103 | # load k 104 | k = tl.load(k_ptrs, boundary_check=(0, 1, 2), padding_option="zero") 105 | # compute qk 106 | qk = tl.zeros((BLOCK_SIZE_B, BLOCK_SIZE_K), dtype=tl.float32) 107 | qk += tl.where(off_k[None, :] + i < kv_len[:, None], 0, float("-inf")) 108 | # [B, D], [B, K, D] -> [B, K] 109 | qk += tl.sum(q[:, None, :] * k, axis=2) * qk_scale 110 | # compute m_ij and l_ij 111 | m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) 112 | p = tl.math.exp2(qk - m_ij[:, None]) 113 | l_ij = tl.sum(p, axis=1) 114 | # scale acc_o 115 | 
acc_o_scale = tl.math.exp2(m_i - m_ij) 116 | acc_o = acc_o * acc_o_scale[:, None] 117 | # load v and update acc_o 118 | v = tl.load(v_ptrs, boundary_check=(0, 1, 2), padding_option="zero") 119 | p = p.to(v.dtype) 120 | # [B, K], [B, K, D] -> [B, D] 121 | acc_o += tl.sum(p[:, :, None] * v, axis=1) 122 | # update statistics 123 | m_i = m_ij 124 | lse_i = m_ij + tl.math.log2(tl.math.exp2(lse_i - m_ij) + l_ij) 125 | # update ptrs 126 | k_ptrs = tl.advance(k_ptrs, (0, BLOCK_SIZE_K, 0)) 127 | v_ptrs = tl.advance(v_ptrs, (0, BLOCK_SIZE_K, 0)) 128 | # final scale 129 | acc_o = acc_o * tl.math.exp2(m_i - lse_i)[:, None] 130 | # save output 131 | o_ptrs = tl.make_block_ptr( 132 | base=o_ptr + pid_h * stride_oh, 133 | shape=(BATCH_SIZE, HEAD_DIM), 134 | strides=(stride_ob, stride_od), 135 | offsets=(pid_b * BLOCK_SIZE_B, 0), 136 | block_shape=(BLOCK_SIZE_B, BLOCK_SIZE_D), 137 | order=(1, 0), 138 | ) 139 | tl.store(o_ptrs, acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1)) 140 | 141 | 142 | def flash_attention_decode( 143 | q: torch.Tensor, # [batch_size, num_heads, head_dim] 144 | k: torch.Tensor, # [batch_size, max_len, num_heads, head_dim] 145 | v: torch.Tensor, 146 | seqlens: torch.Tensor, # [batch_size, ] 147 | sm_scale: Optional[float] = None, 148 | ) -> torch.Tensor: 149 | """flash attention for decode. 150 | 151 | Args: 152 | q (torch.Tensor): query, shape [batch_size, num_q_heads, head_dim] 153 | k (torch.Tensor): key, shape [batch_size, kv_len, num_kv_heads, head_dim] 154 | v (torch.Tensor): value, shape [batch_size, kv_len, num_kv_heads, head_dim] 155 | seqlens (torch.Tensor): kv length for each sequence 156 | sm_scale (Optional[float]): softmax scale, default to 1/sqrt(head_dim) 157 | 158 | Returns: 159 | torch.Tensor: attention output 160 | """ 161 | # dtype check 162 | assert q.dtype == torch.bfloat16 or q.dtype == torch.float16 163 | assert k.dtype == q.dtype and v.dtype == q.dtype 164 | assert seqlens.dtype == torch.int32 165 | # shape 166 | batch_size, num_q_heads, head_dim = q.shape 167 | _, k_len, num_k_heads, head_dim = k.shape 168 | _, v_len, num_v_heads, head_dim = v.shape 169 | assert k_len == v_len and batch_size == seqlens.shape[0] 170 | # gqa 171 | assert num_k_heads == num_v_heads 172 | assert num_q_heads % num_k_heads == 0 173 | num_share_q_heads = num_q_heads // num_k_heads 174 | # sm scale 175 | if sm_scale is None: 176 | sm_scale = 1 / math.sqrt(head_dim) 177 | # output tensor 178 | o = torch.zeros_like(q) 179 | # launch kernel 180 | num_warps = 4 if head_dim <= 64 else 8 181 | num_stages = 3 182 | # there is a bug for triton 3.0.0 if BLOCK_SIZE_B > 16 183 | BLOCK_SIZE_B = min(16, triton.next_power_of_2(batch_size)) 184 | BLOCK_SIZE_K = 128 185 | BLOCK_SIZE_D = triton.next_power_of_2(head_dim) 186 | grid = (num_q_heads, triton.cdiv(batch_size, BLOCK_SIZE_B)) 187 | decode_kernel[grid]( 188 | q, 189 | k, 190 | v, 191 | o, 192 | seqlens, 193 | batch_size, 194 | num_share_q_heads, 195 | head_dim, 196 | sm_scale, 197 | q.stride(0), 198 | q.stride(1), 199 | q.stride(2), 200 | k.stride(0), 201 | k.stride(1), 202 | k.stride(2), 203 | k.stride(3), 204 | v.stride(0), 205 | v.stride(1), 206 | v.stride(2), 207 | v.stride(3), 208 | o.stride(0), 209 | o.stride(1), 210 | o.stride(2), 211 | BLOCK_SIZE_B=BLOCK_SIZE_B, 212 | BLOCK_SIZE_K=BLOCK_SIZE_K, 213 | BLOCK_SIZE_D=BLOCK_SIZE_D, 214 | num_warps=num_warps, 215 | num_stages=num_stages, 216 | ) 217 | return o 218 | 219 | 220 | def torch_attention_decode( 221 | q: torch.Tensor, # [batch_size, num_heads, head_dim] 222 | k: 
torch.Tensor, # [batch_size, max_len, num_heads, head_dim] 223 | v: torch.Tensor, 224 | seqlens: torch.Tensor, # [batch_size, ] 225 | sm_scale: Optional[float] = None, 226 | ): 227 | # dtype check 228 | assert q.dtype == torch.bfloat16 or q.dtype == torch.float16 229 | assert k.dtype == q.dtype and v.dtype == q.dtype 230 | assert seqlens.dtype == torch.int32 231 | # shape 232 | batch_size, num_q_heads, head_dim = q.shape 233 | _, k_len, num_k_heads, head_dim = k.shape 234 | _, v_len, num_v_heads, head_dim = v.shape 235 | assert k_len == v_len and batch_size == seqlens.shape[0] 236 | # gqa 237 | assert num_k_heads == num_v_heads 238 | assert num_q_heads % num_k_heads == 0 239 | num_share_q_heads = num_q_heads // num_k_heads 240 | # sm scale 241 | if sm_scale is None: 242 | sm_scale = 1 / math.sqrt(head_dim) 243 | # attention 244 | attn = ( 245 | torch.einsum( 246 | "bqhd,bkhd->bhqk", 247 | q.unsqueeze(1), 248 | k.repeat_interleave(num_share_q_heads, dim=2), 249 | ) 250 | * sm_scale 251 | ) 252 | mask = torch.arange(k_len, device=q.device)[None, :] < seqlens[:, None] 253 | attn = attn.masked_fill(~mask[:, None, None, :], -torch.inf) 254 | attn = torch.softmax(attn, dim=-1, dtype=torch.float32).to(q.dtype) 255 | out = torch.einsum( 256 | "bhqk,bkhd->bqhd", attn, v.repeat_interleave(num_share_q_heads, dim=2) 257 | ).squeeze(1) 258 | return out 259 | 260 | 261 | if __name__ == "__main__": 262 | torch.manual_seed(42) 263 | batch_size = 76 264 | max_length = 8192 265 | seqlens = torch.arange(batch_size, dtype=torch.int32).cuda() * 128 + 1 266 | seqlens[seqlens > max_length] = max_length 267 | seqlens = seqlens[torch.randn_like(seqlens, dtype=torch.float32).argsort(-1)] 268 | q = ( 269 | torch.empty(batch_size, 32, 128, device="cuda") 270 | .uniform_(-1, 1) 271 | .to(torch.bfloat16) 272 | ) 273 | k = ( 274 | torch.empty(batch_size, max_length, 4, 128, device="cuda") 275 | .uniform_(-1, 1) 276 | .to(torch.bfloat16) 277 | ) 278 | v = ( 279 | torch.empty(batch_size, max_length, 4, 128, device="cuda") 280 | .uniform_(-1, 1) 281 | .to(torch.bfloat16) 282 | ) 283 | 284 | o1 = torch_attention_decode(q, k, v, seqlens) 285 | o2 = flash_attention_decode(q, k, v, seqlens) 286 | 287 | print(torch.allclose(o1, o2, atol=1e-2, rtol=1e-2)) 288 | -------------------------------------------------------------------------------- /native_sparse_attention/ops/triton/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
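#
# Worked example for get_compressed_seqlens below (illustrative values): with
# kernel_size=32 and kernel_stride=16, a sequence of length 1000 is compressed to
# floor((1000 - 32) / 16) + 1 = 61 tokens, and any sequence shorter than
# kernel_size contributes 0 compressed tokens (the corner case handled explicitly).
#
#     cu_seqlens = torch.tensor([0, 1000, 3000], dtype=torch.int32)
#     y_seqlens, y_cu_seqlens = get_compressed_seqlens(cu_seqlens, 32, 16)
#     # y_seqlens    -> [61, 124]
#     # y_cu_seqlens -> [0, 61, 185]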
14 | import torch 15 | 16 | 17 | def is_hopper_gpu(): 18 | if torch.cuda.is_available(): 19 | device_capability = torch.cuda.get_device_capability() 20 | major, minor = device_capability 21 | return major == 9 22 | return False 23 | 24 | 25 | def get_compressed_seqlens( 26 | cu_seqlens: torch.Tensor, kernel_size: int, kernel_stride: int 27 | ): 28 | # compute seqlens after compression 29 | seqlens = cu_seqlens[1:] - cu_seqlens[:-1] 30 | y_seqlens = torch.floor((seqlens - kernel_size) / kernel_stride).to(torch.int32) + 1 31 | # corner case, if sequence_length < kernel_size, no compression for this sequence 32 | y_seqlens[seqlens < kernel_size] = 0 33 | y_cu_seqlens = torch.zeros( 34 | y_seqlens.shape[0] + 1, dtype=torch.int32, device=cu_seqlens.device 35 | ) 36 | y_cu_seqlens[1:] = torch.cumsum(y_seqlens, dim=0) 37 | return y_seqlens, y_cu_seqlens 38 | 39 | 40 | def get_num_warps_stages(head_dim, block_size, is_hopper_gpu): 41 | """ 42 | Returns recommended num_warps and num_stages for a Sparse Attention kernel in Triton. 43 | 44 | Args: 45 | head_dim (int): Size of the head dimension. 46 | block_size (int): Size of the block in the attention matrix. 47 | is_hopper_gpu (bool): True if Hopper GPU, False if Ampere GPU. 48 | 49 | Returns: 50 | tuple: (num_warps, num_stages) recommended values. 51 | """ 52 | # Determine if head_dim and block_size exceed 64 53 | head_large = head_dim > 64 54 | block_large = block_size > 64 55 | 56 | if is_hopper_gpu: 57 | # Hopper GPU recommendations 58 | if head_large and block_large: 59 | num_warps = 8 60 | num_stages = 3 61 | elif head_large or block_large: 62 | num_warps = 4 63 | num_stages = 3 64 | else: 65 | num_warps = 2 66 | num_stages = 2 67 | else: 68 | # Ampere GPU recommendations 69 | if head_large and block_large: 70 | num_warps = 8 71 | num_stages = 3 72 | elif head_large or block_large: 73 | num_warps = 8 74 | num_stages = 3 75 | else: 76 | num_warps = 2 77 | num_stages = 2 78 | if head_dim > 128: 79 | num_stages = 2 80 | return num_warps, num_stages 81 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
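#
# Typical installation (illustrative): from the repository root run
#
#     pip install -e .
#
# which performs an editable install and pulls in the torch / triton / einops /
# flash-attn / transformers requirements declared below.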
14 | 15 | from setuptools import setup, find_packages 16 | 17 | # Read the README.md file for the long description 18 | with open("README.md", "r", encoding="utf-8") as fh: 19 | long_description = fh.read() 20 | 21 | # Define the setup configuration 22 | setup( 23 | name="native-sparse-attention-triton", 24 | version="0.1.0", 25 | description="An efficient implementation of Native Sparse Attention using Triton", 26 | long_description=long_description, 27 | long_description_content_type="text/markdown", 28 | author="XunhaoLai", 29 | author_email="laixunhao@pku.edu.cn", # Replace with your actual email 30 | url="https://github.com/XunhaoLai/native-sparse-attention-triton", 31 | packages=find_packages(), 32 | install_requires=[ 33 | "torch>=2.1.0", 34 | "triton>=3.0.0", 35 | "einops>=0.7.0", 36 | "flash-attn>=2.6.3", 37 | "transformers>=4.44.0", 38 | ], 39 | classifiers=[ 40 | "Programming Language :: Python :: 3", 41 | "License :: OSI Approved :: Apache Software License", 42 | "Operating System :: OS Independent", 43 | ], 44 | python_requires=">=3.9", 45 | ) 46 | -------------------------------------------------------------------------------- /test/test_compress_key_value.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific 13 | import torch 14 | import triton 15 | from native_sparse_attention.ops import linear_compress 16 | 17 | 18 | if __name__ == "__main__": 19 | torch.manual_seed(42) 20 | num_heads = 4 21 | head_dim = 192 22 | kernel_size = 32 23 | kernel_stride = 16 24 | seqlens = torch.LongTensor([1000, 2000, 4096]).int().cuda() 25 | cu_seqlens = torch.cat( 26 | [ 27 | torch.zeros(1, dtype=torch.int32, device="cuda"), 28 | torch.cumsum(seqlens, dim=0), 29 | ], 30 | dim=0, 31 | ).to(torch.int32) 32 | 33 | x = ( 34 | torch.zeros(cu_seqlens[-1], num_heads, head_dim) 35 | .uniform_(-1, 1) 36 | .cuda() 37 | .bfloat16() 38 | .requires_grad_() 39 | ) 40 | w = ( 41 | torch.zeros(num_heads, kernel_size * head_dim, head_dim) 42 | .uniform_(-1, 1) 43 | .cuda() 44 | .bfloat16() 45 | .requires_grad_() 46 | ) 47 | pe = ( 48 | torch.zeros(num_heads, kernel_size, head_dim) 49 | .uniform_(-1, 1) 50 | .cuda() 51 | .bfloat16() 52 | .requires_grad_() 53 | ) 54 | 55 | y, y_cu_seqlens = linear_compress(x, w, cu_seqlens, kernel_size, kernel_stride, pe) 56 | 57 | loss = (y * torch.randn_like(y)).mean() 58 | loss.backward() 59 | 60 | print(y.shape, y_cu_seqlens) 61 | print(y.norm(), x.grad.norm()) 62 | print( 63 | w.grad.norm() if w.grad is not None else None, 64 | pe.grad.norm() if pe.grad is not None else None, 65 | ) 66 | 67 | # benchmark 68 | @triton.testing.perf_report( 69 | triton.testing.Benchmark( 70 | x_names=["N"], 71 | x_vals=[1024 * 2**i for i in range(1, 6)], 72 | line_arg="provider", 73 | line_vals=["batch1", "batch8", "batch32"], 74 | line_names=["batch1", "batch8", "batch32"], 75 | styles=[("green", "-"), ("blue", "-"), ("blue", "--")], 76 | ylabel="ms", 77 | plot_name="** forward **", 78 | args={"H": 4, "D": 128}, 79 | ) 80 | ) 81 | def benchmark(N, H, D, provider): 82 | K, S = 32, 16 83 | x = torch.zeros(N, H, D, device="cuda", dtype=torch.bfloat16).uniform_(-1, 1) 84 | w = torch.zeros(H, K * D, D, device="cuda", dtype=torch.bfloat16).uniform_( 85 | -1, 1 86 | ) 87 | pe = torch.zeros(H, K, D, device="cuda", dtype=torch.bfloat16).uniform_(-1, 1) 88 | cu_seqlens_b1 = torch.LongTensor([0, N]).int().cuda() 89 | cu_seqlens_b8 = ( 90 | torch.LongTensor([N // 8 if i > 0 else 0 for i in range(9)]).int().cuda() 91 | ) 92 | cu_seqlens_b32 = ( 93 | torch.LongTensor([N // 32 if i > 0 else 0 for i in range(33)]).int().cuda() 94 | ) 95 | cu_seqlens_b1 = cu_seqlens_b1.cumsum(0).to(torch.int32) 96 | cu_seqlens_b8 = cu_seqlens_b8.cumsum(0).to(torch.int32) 97 | cu_seqlens_b32 = cu_seqlens_b32.cumsum(0).to(torch.int32) 98 | 99 | quantiles = [0.5, 0.2, 0.8] 100 | if provider == "batch1": 101 | ms, min_ms, max_ms = triton.testing.do_bench( 102 | lambda: linear_compress(x, w, cu_seqlens_b1, K, S, pe), 103 | quantiles=quantiles, 104 | ) 105 | if provider == "batch8": 106 | ms, min_ms, max_ms = triton.testing.do_bench( 107 | lambda: linear_compress(x, w, cu_seqlens_b8, K, S, pe), 108 | quantiles=quantiles, 109 | ) 110 | if provider == "batch32": 111 | ms, min_ms, max_ms = triton.testing.do_bench( 112 | lambda: linear_compress(x, w, cu_seqlens_b32, K, S, pe), 113 | quantiles=quantiles, 114 | ) 115 | return ms, min_ms, max_ms 116 | 117 | benchmark.run(show_plots=True, print_data=True) 118 | -------------------------------------------------------------------------------- /test/test_compressed_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific 13 | import torch 14 | import triton 15 | import math 16 | from native_sparse_attention.ops.torch.compressed_attention import ( 17 | compressed_attention_torch, 18 | ) 19 | from native_sparse_attention.ops.triton.compressed_attention import ( 20 | compressed_attention, 21 | _compressed_attention_bwd, 22 | ) 23 | from native_sparse_attention.ops import avgpool_compress 24 | from native_sparse_attention.ops.triton.flash_attention import ( 25 | flash_attention_varlen, 26 | _flash_attention_bwd, 27 | ) 28 | from flash_attn import flash_attn_varlen_func 29 | from flash_attn.flash_attn_interface import _flash_attn_varlen_backward 30 | 31 | 32 | if __name__ == "__main__": 33 | torch.manual_seed(42) 34 | num_heads = 32 35 | head_dim = 96 36 | kernel_size = 32 37 | kernel_stride = 16 38 | block_size = 64 39 | topk = 16 40 | seqlens = torch.LongTensor([1000, 4000, 8192]).int().cuda() 41 | cu_seqlens = torch.cat( 42 | [ 43 | torch.zeros(1, dtype=torch.int32, device="cuda"), 44 | torch.cumsum(seqlens, dim=0), 45 | ], 46 | dim=0, 47 | ).to(torch.int32) 48 | max_seqlen = seqlens.max().item() 49 | q = ( 50 | torch.empty(cu_seqlens[-1], num_heads, head_dim, device="cuda") 51 | .uniform_(-1, 1) 52 | .to(torch.float16) 53 | ) 54 | k = ( 55 | torch.empty(cu_seqlens[-1], num_heads // 4, head_dim, device="cuda") 56 | .uniform_(-1, 1) 57 | .to(torch.float16) 58 | ) 59 | v = ( 60 | torch.empty(cu_seqlens[-1], num_heads // 4, head_dim, device="cuda") 61 | .uniform_(-1, 1) 62 | .to(torch.float16) 63 | ) 64 | q.requires_grad = True 65 | k.requires_grad = True 66 | v.requires_grad = True 67 | 68 | ck, ck_cu_seqlens = avgpool_compress( 69 | k, None, cu_seqlens, kernel_size, kernel_stride 70 | ) 71 | 72 | ck = torch.empty_like(ck).uniform_(-1, 1) 73 | cv = torch.empty_like(ck).uniform_(-1, 1) 74 | ck.requires_grad = True 75 | cv.requires_grad = True 76 | 77 | ck_seqlens = ck_cu_seqlens[1:] - ck_cu_seqlens[:-1] 78 | ck_max_seqlen = ck_seqlens.max().item() 79 | 80 | o, topk_idx = compressed_attention_torch( 81 | q, 82 | ck, 83 | cv, 84 | kernel_size, 85 | kernel_stride, 86 | block_size, 87 | topk, 88 | cu_seqlens, 89 | ck_cu_seqlens, 90 | max_seqlen, 91 | ck_max_seqlen, 92 | ) 93 | 94 | randn = torch.randn_like(o) 95 | loss = (o * randn).sum() 96 | loss.backward() 97 | 98 | torch.manual_seed(42) 99 | 100 | q1 = q.detach().clone().requires_grad_() 101 | ck1 = ck.detach().clone().requires_grad_() 102 | cv1 = cv.detach().clone().requires_grad_() 103 | 104 | o1, topk_idx1 = compressed_attention( 105 | q1, 106 | ck1, 107 | cv1, 108 | kernel_size, 109 | kernel_stride, 110 | block_size, 111 | topk, 112 | cu_seqlens, 113 | ck_cu_seqlens, 114 | max_seqlen, 115 | ck_max_seqlen, 116 | ) 117 | randn1 = randn.clone().detach() 118 | loss1 = (o1 * randn1).sum() 119 | loss1.backward() 120 | 121 | print("Same Output:", torch.allclose(o, o1, atol=0.01, rtol=0.01)) 122 | print("Max Error:", (o - o1).abs().max().item()) 123 | print() 124 | print("Same Query Gradient:", torch.allclose(q.grad, q1.grad, atol=0.01, rtol=0.01)) 125 | 
print("Max Query Gradient Error:", (q.grad - q1.grad).abs().max().item()) 126 | print() 127 | print("Same Key Gradient:", torch.allclose(ck.grad, ck1.grad, atol=0.01, rtol=0.01)) 128 | print("Max Key Gradient Error:", (ck.grad - ck1.grad).abs().max().item()) 129 | print() 130 | print( 131 | "Same Value Gradient:", torch.allclose(cv.grad, cv1.grad, atol=0.01, rtol=0.01) 132 | ) 133 | print("Max Value Gradient Error:", (cv.grad - cv1.grad).abs().max().item()) 134 | print() 135 | 136 | # There are some discrepancies in the topk indices (about 3%). These might be due to bugs and will be addressed later. 137 | all_num = 0 138 | err_num = 0 139 | for h in range(topk_idx.shape[0]): 140 | for i in range(topk_idx.shape[1]): 141 | s = set(topk_idx[h, i][topk_idx[h, i] >= 0].tolist()) 142 | s1 = set(topk_idx1[h, i][topk_idx1[h, i] >= 0].tolist()) 143 | all_num += len(s) 144 | err_num += len(s) - len(s1 & s) 145 | print("Topk Idx Error Rate:", err_num / all_num) 146 | 147 | # benchmark 148 | @triton.testing.perf_report( 149 | triton.testing.Benchmark( 150 | x_names=["N"], 151 | x_vals=[1024 * 2**i for i in range(1, 8)], 152 | line_arg="provider", 153 | line_vals=[ 154 | "flash", 155 | "triton-flash", 156 | "triton-compressed", 157 | "triton-compressed-wo-score", 158 | ], 159 | line_names=[ 160 | "Flash", 161 | "Triton-Flash", 162 | "Compressed", 163 | "Compressed-wo-Score", 164 | ], 165 | styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")], 166 | ylabel="ms", 167 | plot_name="** forward speed for compressed attention (kernel 32 stride 16) **", 168 | args={"H": 64, "D": 128}, 169 | ) 170 | ) 171 | def benchmark(N, H, D, provider): 172 | q = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16) 173 | k = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16) 174 | v = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16) 175 | cu_seqlens = torch.tensor([0, N], device="cuda", dtype=torch.int32) 176 | sm_scale = 1 / math.sqrt(D) 177 | com_k, com_cu_seqlens = avgpool_compress(k, None, cu_seqlens, 32, 16, None) 178 | com_v, com_cu_seqlens = avgpool_compress(v, None, cu_seqlens, 32, 16, None) 179 | M = (com_cu_seqlens[1:] - com_cu_seqlens[:-1]).max().item() 180 | 181 | quantiles = [0.5, 0.2, 0.8] 182 | if provider == "flash": 183 | ms, min_ms, max_ms = triton.testing.do_bench( 184 | lambda: flash_attn_varlen_func( 185 | q, 186 | k, 187 | v, 188 | cu_seqlens, 189 | cu_seqlens, 190 | N, 191 | N, 192 | dropout_p=0.0, 193 | causal=True, 194 | softmax_scale=sm_scale, 195 | ), 196 | quantiles=quantiles, 197 | ) 198 | if provider == "triton-flash": 199 | ms, min_ms, max_ms = triton.testing.do_bench( 200 | lambda: flash_attention_varlen( 201 | q, k, v, cu_seqlens, cu_seqlens, N, N, True, sm_scale 202 | ), 203 | quantiles=quantiles, 204 | ) 205 | if provider == "triton-compressed": 206 | ms, min_ms, max_ms = triton.testing.do_bench( 207 | lambda: compressed_attention( 208 | q, 209 | com_k, 210 | com_v, 211 | 32, 212 | 16, 213 | 64, 214 | 16, 215 | cu_seqlens, 216 | com_cu_seqlens, 217 | N, 218 | M, 219 | sm_scale, 220 | ), 221 | quantiles=quantiles, 222 | ) 223 | if provider == "triton-compressed-wo-score": 224 | ms, min_ms, max_ms = triton.testing.do_bench( 225 | lambda: compressed_attention( 226 | q, 227 | com_k, 228 | com_v, 229 | 32, 230 | 16, 231 | 64, 232 | -1, 233 | cu_seqlens, 234 | com_cu_seqlens, 235 | N, 236 | M, 237 | sm_scale, 238 | ), 239 | quantiles=quantiles, 240 | ) 241 | return ms, min_ms, max_ms 242 | 243 | benchmark.run(show_plots=True, 
print_data=True) 244 | 245 | # benchmark 246 | @triton.testing.perf_report( 247 | triton.testing.Benchmark( 248 | x_names=["N"], 249 | x_vals=[1024 * 2**i for i in range(1, 8)], 250 | line_arg="provider", 251 | line_vals=[ 252 | "flash", 253 | "triton-flash", 254 | "triton-compressed", 255 | ], 256 | line_names=[ 257 | "Flash", 258 | "Triton-Flash", 259 | "Compressed", 260 | ], 261 | styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")], 262 | ylabel="ms", 263 | plot_name="** backward speed for compressed attention (kernel 32 stride 16) **", 264 | args={"H": 64, "D": 128}, 265 | ) 266 | ) 267 | def benchmark(N, H, D, provider): 268 | q = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16) 269 | k = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16) 270 | v = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16) 271 | o = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16) 272 | do = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16) 273 | lse = torch.randn((H, N), device="cuda", dtype=torch.float32) 274 | sm_scale = 1 / math.sqrt(D) 275 | cu_seqlens = torch.tensor([0, N], device="cuda", dtype=torch.int32) 276 | dq = torch.zeros_like(q) 277 | dk = torch.zeros_like(k) 278 | dv = torch.zeros_like(v) 279 | 280 | com_k, com_cu_seqlens = avgpool_compress(k, None, cu_seqlens, 32, 16, None) 281 | com_v, com_cu_seqlens = avgpool_compress(v, None, cu_seqlens, 32, 16, None) 282 | M = (com_cu_seqlens[1:] - com_cu_seqlens[:-1]).max().item() 283 | 284 | quantiles = [0.5, 0.2, 0.8] 285 | if provider == "flash": 286 | ms, min_ms, max_ms = triton.testing.do_bench( 287 | lambda: _flash_attn_varlen_backward( 288 | do, 289 | q, 290 | k, 291 | v, 292 | o, 293 | lse.transpose(0, 1), 294 | dq, 295 | dk, 296 | dv, 297 | cu_seqlens, 298 | cu_seqlens, 299 | N, 300 | N, 301 | dropout_p=0.0, 302 | causal=True, 303 | softmax_scale=sm_scale, 304 | window_size=(-1, -1), 305 | softcap=0.0, 306 | alibi_slopes=None, 307 | deterministic=False, 308 | ), 309 | quantiles=quantiles, 310 | ) 311 | if provider == "triton-flash": 312 | ms, min_ms, max_ms = triton.testing.do_bench( 313 | lambda: _flash_attention_bwd( 314 | o, do, lse, q, k, v, cu_seqlens, cu_seqlens, N, N, True, sm_scale 315 | ), 316 | quantiles=quantiles, 317 | ) 318 | if provider == "triton-compressed": 319 | ms, min_ms, max_ms = triton.testing.do_bench( 320 | lambda: _compressed_attention_bwd( 321 | o, 322 | do, 323 | lse, 324 | q, 325 | com_k, 326 | com_v, 327 | 32, 328 | 16, 329 | cu_seqlens, 330 | com_cu_seqlens, 331 | N, 332 | M, 333 | sm_scale, 334 | ), 335 | quantiles=quantiles, 336 | ) 337 | return ms, min_ms, max_ms 338 | 339 | benchmark.run(show_plots=True, print_data=True) 340 | -------------------------------------------------------------------------------- /test/test_flash_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific 13 | import torch 14 | import triton 15 | import math 16 | from native_sparse_attention.ops.triton.flash_attention import ( 17 | flash_attention_varlen, 18 | _flash_attention_fwd, 19 | _flash_attention_bwd, 20 | ) 21 | from flash_attn import flash_attn_varlen_func 22 | from flash_attn.flash_attn_interface import ( 23 | _flash_attn_varlen_forward, 24 | _flash_attn_varlen_backward, 25 | ) 26 | 27 | 28 | if __name__ == "__main__": 29 | for causal in [False, True]: 30 | # triton flash attention 31 | torch.manual_seed(42) 32 | q = torch.randn( 33 | 1000, 32, 128, dtype=torch.float16, device="cuda", requires_grad=True 34 | ) 35 | k = torch.randn( 36 | 1000, 16, 128, dtype=torch.float16, device="cuda", requires_grad=True 37 | ) 38 | v = torch.randn( 39 | 1000, 16, 128, dtype=torch.float16, device="cuda", requires_grad=True 40 | ) 41 | cu_seqlens_q = torch.Tensor([0, 100, 384, 1000]).cuda().to(torch.int32) 42 | cu_seqlens_k = torch.Tensor([0, 100, 384, 1000]).cuda().to(torch.int32) 43 | max_seqlen_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max() 44 | max_seqlen_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max() 45 | o = flash_attn_varlen_func( 46 | q, 47 | k, 48 | v, 49 | cu_seqlens_q, 50 | cu_seqlens_k, 51 | max_seqlen_q, 52 | max_seqlen_k, 53 | causal=causal, 54 | ) 55 | randn = torch.randn_like(o) 56 | loss = (o * randn).sum() 57 | loss.backward() 58 | 59 | # flash attention 60 | torch.manual_seed(42) 61 | q1 = q.clone().detach().requires_grad_() 62 | k1 = k.clone().detach().requires_grad_() 63 | v1 = v.clone().detach().requires_grad_() 64 | cu_seqlens_q1 = cu_seqlens_q.clone().detach() 65 | cu_seqlens_k1 = cu_seqlens_k.clone().detach() 66 | max_seqlen_q1 = (cu_seqlens_q1[1:] - cu_seqlens_q1[:-1]).max() 67 | max_seqlen_k1 = (cu_seqlens_k1[1:] - cu_seqlens_k1[:-1]).max() 68 | o1 = flash_attention_varlen( 69 | q1, 70 | k1, 71 | v1, 72 | cu_seqlens_q1, 73 | cu_seqlens_k1, 74 | max_seqlen_q1, 75 | max_seqlen_k1, 76 | causal=causal, 77 | ) 78 | randn2 = randn.clone().detach() 79 | loss2 = (o1 * randn2).sum() 80 | loss2.backward() 81 | 82 | # diff 83 | print( 84 | f"=== Flash Attention Backward Test ({'causal' if causal else 'full'}) ===" 85 | ) 86 | print("Same Output:", torch.allclose(o, o1, atol=0.01, rtol=0.01)) 87 | print("Max Error:", (o - o1).abs().max().item()) 88 | print() 89 | print( 90 | "Same Query Gradient:", 91 | torch.allclose(q.grad, q1.grad, atol=0.01, rtol=0.01), 92 | ) 93 | print("Max Query Gradient Error:", (q.grad - q1.grad).abs().max().item()) 94 | print() 95 | print( 96 | "Same Key Gradient:", torch.allclose(k.grad, k1.grad, atol=0.01, rtol=0.01) 97 | ) 98 | print("Max Key Gradient Error:", (k.grad - k1.grad).abs().max().item()) 99 | print() 100 | print( 101 | "Same Value Gradient:", 102 | torch.allclose(v.grad, v1.grad, atol=0.01, rtol=0.01), 103 | ) 104 | print("Max Value Gradient Error:", (v.grad - v1.grad).abs().max().item()) 105 | print() 106 | 107 | # benchmark 108 | @triton.testing.perf_report( 109 | triton.testing.Benchmark( 110 | x_names=["N"], 111 | x_vals=[1024 * 2**i for i in range(1, 6)], 112 | line_arg="provider", 113 | line_vals=["flash", "triton-flash"], 114 | line_names=[ 115 | "Flash", 116 | "Triton-Flash", 117 | ], 118 | styles=[("green", "-"), ("green", "--")], 119 | ylabel="ms", 120 | plot_name="** forward **", 121 | args={"H": 64, "D": 128}, 122 | ) 123 | ) 124 | def benchmark(N, H, D, provider): 125 | q = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16) 126 | k = torch.randn((N, H // 16, D), device="cuda", 
dtype=torch.bfloat16) 127 | v = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16) 128 | cu_seqlens = torch.tensor([0, N], device="cuda", dtype=torch.int32) 129 | sm_scale = 1 / math.sqrt(D) 130 | 131 | quantiles = [0.5, 0.2, 0.8] 132 | if provider == "flash": 133 | ms, min_ms, max_ms = triton.testing.do_bench( 134 | lambda: _flash_attn_varlen_forward( 135 | q, 136 | k, 137 | v, 138 | cu_seqlens, 139 | cu_seqlens, 140 | N, 141 | N, 142 | dropout_p=0.0, 143 | causal=True, 144 | softmax_scale=sm_scale, 145 | ), 146 | quantiles=quantiles, 147 | ) 148 | if provider == "triton-flash": 149 | ms, min_ms, max_ms = triton.testing.do_bench( 150 | lambda: _flash_attention_fwd( 151 | q, k, v, cu_seqlens, cu_seqlens, N, N, True, sm_scale 152 | ), 153 | quantiles=quantiles, 154 | ) 155 | return ms, min_ms, max_ms 156 | 157 | benchmark.run(show_plots=True, print_data=True) 158 | 159 | # benchmark 160 | @triton.testing.perf_report( 161 | triton.testing.Benchmark( 162 | x_names=["N"], 163 | x_vals=[1024 * 2**i for i in range(1, 6)], 164 | line_arg="provider", 165 | line_vals=["flash", "triton-flash"], 166 | line_names=[ 167 | "Flash", 168 | "Triton-Flash", 169 | ], 170 | styles=[("green", "-"), ("green", "--")], 171 | ylabel="ms", 172 | plot_name="** backward **", 173 | args={"H": 64, "D": 128}, 174 | ) 175 | ) 176 | def benchmark(N, H, D, provider): 177 | q = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16) 178 | k = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16) 179 | v = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16) 180 | o = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16) 181 | do = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16) 182 | lse = torch.randn((H, N), device="cuda", dtype=torch.float32) 183 | sm_scale = 1 / math.sqrt(D) 184 | cu_seqlens = torch.tensor([0, N], device="cuda", dtype=torch.int32) 185 | dq = torch.zeros_like(q) 186 | dk = torch.zeros_like(k) 187 | dv = torch.zeros_like(v) 188 | 189 | quantiles = [0.5, 0.2, 0.8] 190 | if provider == "flash": 191 | ms, min_ms, max_ms = triton.testing.do_bench( 192 | lambda: _flash_attn_varlen_backward( 193 | do, 194 | q, 195 | k, 196 | v, 197 | o, 198 | lse.transpose(0, 1), 199 | dq, 200 | dk, 201 | dv, 202 | cu_seqlens, 203 | cu_seqlens, 204 | N, 205 | N, 206 | dropout_p=0.0, 207 | causal=True, 208 | softmax_scale=sm_scale, 209 | window_size=(-1, -1), 210 | softcap=0.0, 211 | alibi_slopes=None, 212 | deterministic=False, 213 | ), 214 | quantiles=quantiles, 215 | ) 216 | if provider == "triton-flash": 217 | ms, min_ms, max_ms = triton.testing.do_bench( 218 | lambda: _flash_attention_bwd( 219 | o, do, lse, q, k, v, cu_seqlens, cu_seqlens, N, N, True, sm_scale 220 | ), 221 | quantiles=quantiles, 222 | ) 223 | return ms, min_ms, max_ms 224 | 225 | benchmark.run(show_plots=True, print_data=True) 226 | -------------------------------------------------------------------------------- /test/test_kv_cache.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import torch 15 | from native_sparse_attention.module.kv_cache import NSACache 16 | 17 | 18 | if __name__ == "__main__": 19 | from native_sparse_attention.ops import avgpool_compress 20 | 21 | torch.manual_seed(42) 22 | 23 | num_heads = 4 24 | head_dim = 128 25 | seqlens = torch.tensor([12, 576, 12000]).to(torch.int32).cuda() 26 | batch_size = seqlens.shape[0] 27 | cu_seqlens = torch.zeros(seqlens.shape[0] + 1, dtype=torch.int32, device="cuda") 28 | cu_seqlens[1:] = seqlens.cumsum(0) 29 | 30 | # init cache 31 | cache = NSACache(4, 16384, num_heads, head_dim, 32, 16, 512, torch.bfloat16, "cuda") 32 | 33 | # test prefill 34 | step = 0 35 | k = torch.randn(cu_seqlens[-1], num_heads, head_dim).cuda().bfloat16() 36 | v = torch.randn_like(k) 37 | ck, _ = avgpool_compress(k, None, cu_seqlens, 32, 16, None) 38 | cv, _ = avgpool_compress(v, None, cu_seqlens, 32, 16, None) 39 | cache.prepare_compress(cu_seqlens, step, k, v) 40 | cache.update_kv(cu_seqlens, step, ck, cv, k, v, k, v) 41 | 42 | # test decode 43 | step = 1 44 | k = torch.randn(batch_size, num_heads, head_dim).cuda().bfloat16() 45 | v = torch.randn_like(k) 46 | ck = torch.randn_like(k) 47 | cv = torch.randn_like(v) 48 | cache.prepare_compress(cu_seqlens, step, k, v) 49 | cache.update_kv(cu_seqlens, step, ck, cv, k, v, k, v) 50 | -------------------------------------------------------------------------------- /test/test_linear_compress.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import triton 17 | from native_sparse_attention.ops.torch.compress_key_value import linear_compress_torch 18 | from native_sparse_attention.ops.triton.linear_compress import linear_compress 19 | 20 | 21 | def test_linear_compress( 22 | batch_size: int = 1, 23 | num_heads: int = 1, 24 | head_dim: int = 32, 25 | max_seqlen: int = 32, 26 | kernel_sizes: list = [16, 32], 27 | kernel_strides: list = [8, 16], 28 | use_pe: bool = True, 29 | dtype: torch.dtype = torch.float32, 30 | device: str = "cuda", 31 | ): 32 | """ 33 | Test both PyTorch and Triton implementations of linear_compress for equivalence, 34 | including forward and backward passes. 
35 | 36 | Args: 37 | batch_size: Number of sequences in the batch 38 | num_heads: Number of attention heads 39 | head_dim: Dimension of each attention head 40 | max_seqlen: Maximum sequence length 41 | kernel_sizes: List of kernel sizes to test 42 | kernel_strides: List of kernel strides to test 43 | use_pe: Whether to test with positional encoding 44 | dtype: Data type for tensors 45 | device: Device to run the test on 46 | """ 47 | torch.manual_seed(42) 48 | 49 | # Generate random sequence lengths for each batch 50 | 51 | seqlens = torch.randint( 52 | low=kernel_sizes[0], # minimum length should be at least kernel_size 53 | high=max_seqlen + 1, 54 | size=(batch_size,), 55 | device=device, 56 | ) 57 | # seqlens[:] = max_seqlen 58 | cu_seqlens = torch.cat( 59 | [ 60 | torch.tensor([0], device=device, dtype=torch.int32), 61 | torch.cumsum(seqlens, dim=0).to(torch.int32), 62 | ] 63 | ) 64 | 65 | total_len = cu_seqlens[-1].item() 66 | 67 | for kernel_size, kernel_stride in zip(kernel_sizes, kernel_strides): 68 | print(f"\nTesting kernel_size={kernel_size}, kernel_stride={kernel_stride}") 69 | 70 | # Create input tensors with requires_grad=True 71 | x_torch = torch.zeros( 72 | (total_len, num_heads, head_dim), 73 | dtype=dtype, 74 | device=device, 75 | ).uniform_(-1, 1) 76 | x_torch.requires_grad_(True) 77 | 78 | x_triton = x_torch.clone().detach().requires_grad_(True) 79 | 80 | w_torch = ( 81 | torch.ones( 82 | (num_heads, kernel_size * head_dim, head_dim), 83 | dtype=dtype, 84 | device=device, 85 | ) 86 | / kernel_size 87 | ) 88 | w_torch.requires_grad_(True) 89 | 90 | w_triton = w_torch.clone().detach().requires_grad_(True) 91 | 92 | pe_torch = None 93 | pe_triton = None 94 | if use_pe: 95 | pe_torch = torch.randn( 96 | (num_heads, kernel_size, head_dim), 97 | dtype=dtype, 98 | device=device, 99 | requires_grad=True, 100 | ) 101 | pe_triton = pe_torch.clone().detach().requires_grad_(True) 102 | 103 | # Run forward passes 104 | y_torch, y_cu_seqlens_torch = linear_compress_torch( 105 | x=x_torch, 106 | w=w_torch, 107 | cu_seqlens=cu_seqlens, 108 | kernel_size=kernel_size, 109 | kernel_stride=kernel_stride, 110 | pe=pe_torch, 111 | ) 112 | 113 | y_triton, y_cu_seqlens_triton = linear_compress( 114 | x=x_triton, 115 | w=w_triton, 116 | cu_seqlens=cu_seqlens, 117 | kernel_size=kernel_size, 118 | kernel_stride=kernel_stride, 119 | pe=pe_triton, 120 | ) 121 | 122 | # Check forward pass numerical equivalence 123 | atol, rtol = 1e-2, 1e-2 124 | values_match = torch.allclose(y_torch, y_triton, atol=atol, rtol=rtol) 125 | print( 126 | f"Forward pass - Output values match (atol={atol}, rtol={rtol}): {values_match}" 127 | ) 128 | if not values_match: 129 | max_diff = (y_torch - y_triton).abs().max().item() 130 | print(f"Forward pass - Maximum difference: {max_diff}") 131 | print("\nSample values (first batch, first head):") 132 | print("Torch:", y_torch[0, 0, :5]) 133 | print("Triton:", y_triton[0, 0, :5]) 134 | 135 | # Create random output gradients for backward pass 136 | grad_output = torch.randn_like(y_torch) 137 | 138 | # Run backward passes 139 | y_torch.backward(grad_output) 140 | y_triton.backward(grad_output) 141 | 142 | # Check gradient equivalence 143 | print("\nTesting backward pass:") 144 | 145 | # Check x gradients 146 | x_grads_match = torch.allclose( 147 | x_torch.grad, x_triton.grad, atol=atol, rtol=rtol 148 | ) 149 | print(f"x gradients match (atol={atol}, rtol={rtol}): {x_grads_match}") 150 | if not x_grads_match: 151 | max_diff = (x_torch.grad - x_triton.grad).abs().max().item() 152 
| print(f"x gradients - Maximum difference: {max_diff}") 153 | print("\nSample x gradients (first batch, first head):") 154 | print("Torch:", x_torch.grad[0, 0, :5]) 155 | print("Triton:", x_triton.grad[0, 0, :5]) 156 | 157 | # Check w gradients 158 | w_grads_match = torch.allclose( 159 | w_torch.grad, w_triton.grad, atol=atol, rtol=rtol 160 | ) 161 | print(f"w gradients match (atol={atol}, rtol={rtol}): {w_grads_match}") 162 | if not w_grads_match: 163 | max_diff = (w_torch.grad - w_triton.grad).abs().max().item() 164 | print(f"w gradients - Maximum difference: {max_diff}") 165 | print("\nSample w gradients (first head):") 166 | print("Torch:", w_torch.grad[0, :5, 0]) 167 | print("Triton:", w_triton.grad[0, :5, 0]) 168 | 169 | # Check pe gradients if used 170 | if use_pe: 171 | pe_grads_match = torch.allclose( 172 | pe_torch.grad, pe_triton.grad, atol=atol, rtol=rtol 173 | ) 174 | print(f"pe gradients match (atol={atol}, rtol={rtol}): {pe_grads_match}") 175 | if not pe_grads_match: 176 | max_diff = (pe_torch.grad - pe_triton.grad).abs().max().item() 177 | print(f"pe gradients - Maximum difference: {max_diff}") 178 | print("\nSample pe gradients (first head):") 179 | print("Torch:", pe_torch.grad[0, :5, 0]) 180 | print("Triton:", pe_triton.grad[0, :5, 0]) 181 | 182 | # Clean up gradients for next iteration 183 | x_torch.grad = None 184 | x_triton.grad = None 185 | w_torch.grad = None 186 | w_triton.grad = None 187 | if use_pe: 188 | pe_torch.grad = None 189 | pe_triton.grad = None 190 | 191 | 192 | if __name__ == "__main__": 193 | # Run tests 194 | test_linear_compress( 195 | batch_size=16, 196 | num_heads=8, 197 | head_dim=128, 198 | max_seqlen=2048, 199 | kernel_sizes=[32], 200 | kernel_strides=[16], 201 | use_pe=False, 202 | dtype=torch.float16, 203 | device="cuda", 204 | ) 205 | 206 | # benchmark 207 | @triton.testing.perf_report( 208 | triton.testing.Benchmark( 209 | x_names=["N"], 210 | x_vals=[1024 * 2**i for i in range(1, 8)], 211 | line_arg="provider", 212 | line_vals=["torch", "triton"], 213 | line_names=["torch", "triton"], 214 | styles=[("green", "-"), ("blue", "-")], 215 | ylabel="ms", 216 | plot_name="** forward + backward **", 217 | args={"H": 4, "D": 64}, 218 | ) 219 | ) 220 | def benchmark_fwdbwd(N, H, D, provider): 221 | K, S = 32, 16 222 | # Input tensors 223 | x = torch.zeros(N, H, D, device="cuda", dtype=torch.bfloat16).uniform_(-1, 1) 224 | x.requires_grad = True 225 | w = torch.zeros(H, K * D, D, device="cuda", dtype=torch.bfloat16).uniform_( 226 | -1, 1 227 | ) 228 | w.requires_grad = True 229 | pe = torch.zeros(H, K, D, device="cuda", dtype=torch.bfloat16).uniform_(-1, 1) 230 | cu_seqlens_b32 = ( 231 | torch.LongTensor( 232 | [0 if i == 0 else 32 if i > 1 else N - 32 * 31 for i in range(33)] 233 | ) 234 | .int() 235 | .cuda() 236 | ) 237 | cu_seqlens_b32 = cu_seqlens_b32.cumsum(0).to(torch.int32) 238 | 239 | quantiles = [0.5, 0.2, 0.8] 240 | 241 | def fwd_bwd(): 242 | if provider == "torch": 243 | out, _ = linear_compress_torch(x, w, cu_seqlens_b32, K, S, pe) 244 | else: 245 | out, _ = linear_compress(x, w, cu_seqlens_b32, K, S, pe) 246 | out.backward(out) # Using output as gradient for simplicity 247 | return out 248 | 249 | ms, min_ms, max_ms = triton.testing.do_bench(fwd_bwd, quantiles=quantiles) 250 | return ms, min_ms, max_ms 251 | 252 | benchmark_fwdbwd.run(show_plots=True, print_data=True) 253 | -------------------------------------------------------------------------------- /test/test_nsa_infer.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import torch 15 | from native_sparse_attention.ops import linear_compress, weightedpool_compress 16 | from native_sparse_attention.module import NSACache, RotaryEmbedding, RopeConfig 17 | from native_sparse_attention.infer import nsa_infer 18 | 19 | 20 | if __name__ == "__main__": 21 | torch.manual_seed(42) 22 | 23 | num_heads = 4 24 | head_dim = 128 25 | kernel_size = 32 26 | kernel_stride = 16 27 | block_size = 64 28 | window_size = 512 29 | topk = 16 30 | init_blocks = 1 31 | local_blocks = 2 32 | 33 | # init seqlens 34 | seqlens = torch.tensor([12, 576, 12000]).to(torch.int32).cuda() 35 | batch_size = seqlens.shape[0] 36 | cu_seqlens = torch.zeros(seqlens.shape[0] + 1, dtype=torch.int32, device="cuda") 37 | cu_seqlens[1:] = seqlens.cumsum(0) 38 | step = 0 39 | 40 | # init cache and weight and rope 41 | cache = NSACache(4, 16384, num_heads, head_dim, 32, 16, 512, torch.bfloat16, "cuda") 42 | compress_weight = [ 43 | torch.ones(num_heads, kernel_size * head_dim, head_dim).cuda().bfloat16() 44 | / (kernel_size * head_dim), 45 | torch.ones(num_heads, kernel_size).cuda().bfloat16() / kernel_size, 46 | ] 47 | compress_func = [linear_compress, weightedpool_compress] 48 | rope = RotaryEmbedding( 49 | RopeConfig( 50 | max_position_embeddings=131072, 51 | head_dim=128, 52 | rope_theta=500000, 53 | rope_scaling={ 54 | "factor": 8.0, 55 | "high_freq_factor": 4.0, 56 | "low_freq_factor": 1.0, 57 | "original_max_position_embeddings": 8192, 58 | "rope_type": "llama3", 59 | }, 60 | ) 61 | ) 62 | 63 | # test prefill 64 | q = torch.randn(cu_seqlens[-1], num_heads * 16, head_dim).cuda().bfloat16() 65 | k = torch.randn(cu_seqlens[-1], num_heads, head_dim).cuda().bfloat16() 66 | v = torch.randn_like(k) 67 | g = torch.rand(cu_seqlens[-1], num_heads * 16, 3).cuda().bfloat16() 68 | o = nsa_infer( 69 | cu_seqlens, 70 | step, 71 | q, 72 | k, 73 | v, 74 | g, 75 | rope, 76 | cache, 77 | compress_weight, 78 | compress_func, 79 | None, 80 | kernel_size, 81 | kernel_stride, 82 | block_size, 83 | topk, 84 | init_blocks, 85 | local_blocks, 86 | window_size, 87 | ) 88 | print(o.shape, o.norm()) 89 | 90 | # test decode 91 | q = torch.randn(cu_seqlens.shape[0] - 1, num_heads * 16, head_dim).cuda().bfloat16() 92 | k = torch.randn(cu_seqlens.shape[0] - 1, num_heads, head_dim).cuda().bfloat16() 93 | v = torch.randn_like(k) 94 | g = torch.rand(cu_seqlens.shape[0] - 1, num_heads * 16, 3).cuda().bfloat16() 95 | step = 1 96 | o = nsa_infer( 97 | cu_seqlens, 98 | step, 99 | q, 100 | k, 101 | v, 102 | g, 103 | rope, 104 | cache, 105 | compress_weight, 106 | compress_func, 107 | None, 108 | kernel_size, 109 | kernel_stride, 110 | block_size, 111 | topk, 112 | init_blocks, 113 | local_blocks, 114 | window_size, 115 | ) 116 | print(o.shape, o.norm()) 117 | 
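Note on the shape convention exercised above: with step == 0 (prefill), q, k, v, and g are packed varlen tensors with one row per token of every sequence in the batch, while with step >= 1 (decode) they carry exactly one row per sequence (the newly generated token) and the same cu_seqlens still describes the prefilled context. A minimal sketch of that bookkeeping, using the same example lengths (illustrative only, not part of the test file):

import torch

seqlens = torch.tensor([12, 576, 12000], dtype=torch.int32)
cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0), (1, 0)).to(torch.int32)
total_tokens = int(cu_seqlens[-1])    # 12588 rows in each prefill q/k/v/g tensor
batch_size = cu_seqlens.shape[0] - 1  # 3 rows in each decode-step q/k/v/g tensor
# tokens of sequence i occupy rows cu_seqlens[i]:cu_seqlens[i + 1] of the packed tensors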
-------------------------------------------------------------------------------- /test/test_nsa_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import torch 15 | from native_sparse_attention.model import ( 16 | ToyNSALlamaConfig, 17 | InferenceConfig, 18 | ToyNSALlama, 19 | ) 20 | 21 | 22 | if __name__ == "__main__": 23 | torch.manual_seed(42) 24 | # initialize model 25 | config = ToyNSALlamaConfig( 26 | hidden_size=4096, 27 | intermediate_size=14336, 28 | num_hidden_layers=8, 29 | num_attention_heads=32, 30 | num_key_value_heads=2, 31 | head_dim=128, 32 | rope_theta=500000.0, 33 | rope_scaling={ 34 | "factor": 8.0, 35 | "high_freq_factor": 4.0, 36 | "low_freq_factor": 1.0, 37 | "original_max_position_embeddings": 8192, 38 | "rope_type": "llama3", 39 | }, 40 | compress_type="weightedpool", 41 | kernel_size=32, 42 | kernel_stride=16, 43 | block_size=64, 44 | topk=8, 45 | init_blocks=1, 46 | local_blocks=2, 47 | window_size=512, 48 | ) 49 | inference_config = InferenceConfig( 50 | max_batch_size=4, 51 | max_length=8192, 52 | max_new_tokens=128, 53 | ) 54 | model = ToyNSALlama(config, inference_config).cuda().bfloat16() 55 | print(f"\nMODEL CONFIG:\n{config}\n") 56 | print(f"\nINFERENCE CONFIG:\n{inference_config}\n") 57 | print(f"\nMODEL:\n{model}\n") 58 | 59 | # example input 60 | batch_size = 4 61 | seqlens = torch.randint(0, 4096, (batch_size,), dtype=torch.int32, device="cuda") 62 | cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda") 63 | cu_seqlens[1:] = torch.cumsum(seqlens, dim=0) 64 | input_ids = torch.randint( 65 | 0, 128288, (cu_seqlens[-1],), dtype=torch.int64, device="cuda" 66 | ) 67 | print(f"\nEXAMPLE INPUT:\ncu_seqlens: {cu_seqlens}\ninput_ids: {input_ids.shape}\n") 68 | 69 | # example output 70 | logits = model(input_ids, cu_seqlens) 71 | print(f"\nEXAMPLE OUTPUT:\nlogits: {logits.shape}\n") 72 | 73 | # example generate 74 | output_tokens = model.generate(input_ids, cu_seqlens, 64) 75 | print(f"\nEXAMPLE GENERATE:\noutput_tokens: {output_tokens}\n") 76 | -------------------------------------------------------------------------------- /test/test_nsa_module.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and limitations under the License.
13 | import torch
14 | import triton
15 | from native_sparse_attention.module import (
16 |     SelfAttention,
17 |     NativeSparseAttention,
18 |     RopeConfig,
19 | )
20 | 
21 | 
22 | if __name__ == "__main__":
23 |     torch.manual_seed(42)
24 |     NSA = (
25 |         NativeSparseAttention(
26 |             compress_type="avgpool",
27 |             hidden_size=8192,
28 |             num_q_heads=64,
29 |             num_kv_heads=4,
30 |             head_dim=128,
31 |             kernel_size=32,
32 |             kernel_stride=16,
33 |             block_size=64,
34 |             topk=16,
35 |             init_blocks=1,
36 |             local_blocks=2,
37 |             window_size=512,
38 |             rope_config=RopeConfig(
39 |                 max_position_embeddings=131072,
40 |                 head_dim=128,
41 |                 rope_theta=500000,
42 |                 rope_scaling={
43 |                     "factor": 8.0,
44 |                     "high_freq_factor": 4.0,
45 |                     "low_freq_factor": 1.0,
46 |                     "original_max_position_embeddings": 8192,
47 |                     "rope_type": "llama3",
48 |                 },
49 |             ),
50 |         )
51 |         .cuda()
52 |         .to(torch.bfloat16)
53 |     )
54 |     print("======= Init Module: Native Sparse Attention =======\n")
55 |     for name, param in NSA.named_parameters():
56 |         print(f"NSA Parameters, {name}, shape: {param.shape}\n")
57 | 
58 |     # random input
59 |     seqlens = torch.LongTensor([4000, 8192, 16384]).int().cuda()
60 |     cu_seqlens = torch.cat(
61 |         [
62 |             torch.zeros(1, dtype=torch.int32, device="cuda"),
63 |             torch.cumsum(seqlens, dim=0),
64 |         ],
65 |         dim=0,
66 |     ).to(torch.int32)
67 |     x = torch.zeros(cu_seqlens[-1], 8192, device="cuda", dtype=torch.bfloat16).uniform_(
68 |         -1, 1
69 |     )
70 | 
71 |     # forward test
72 |     print("======= NSA Forward & Backward Test =======\n")
73 |     y = NSA(x, cu_seqlens)
74 |     print(f"Forward, output shape: {y.shape}, output norm: {y.norm()}\n")
75 | 
76 |     # backward test
77 |     loss = (y * torch.randn_like(y)).sum(-1).mean()
78 |     loss.backward()
79 |     for name, param in NSA.named_parameters():
80 |         print(
81 |             f"Backward, {name}, grad shape: {param.grad.shape}, grad norm: {param.grad.norm()}\n"
82 |         )
83 | 
84 |     # speed benchmark
85 |     SelfAttn = (
86 |         SelfAttention(
87 |             hidden_size=8192,
88 |             num_q_heads=64,
89 |             num_kv_heads=4,
90 |             head_dim=128,
91 |             rope_config=RopeConfig(
92 |                 max_position_embeddings=131072,
93 |                 head_dim=128,
94 |                 rope_theta=500000,
95 |                 rope_scaling={
96 |                     "factor": 8.0,
97 |                     "high_freq_factor": 4.0,
98 |                     "low_freq_factor": 1.0,
99 |                     "original_max_position_embeddings": 8192,
100 |                     "rope_type": "llama3",
101 |                 },
102 |             ),
103 |         )
104 |         .cuda()
105 |         .to(torch.bfloat16)
106 |     )
107 | 
108 |     @triton.testing.perf_report(
109 |         triton.testing.Benchmark(
110 |             x_names=["N"],
111 |             x_vals=[1024 * 2**i for i in range(1, 8)],
112 |             line_arg="provider",
113 |             line_vals=["Self-Attention", "Native-Sparse-Attention"],
114 |             line_names=["Self-Attention", "Native-Sparse-Attention"],
115 |             styles=[("green", "-"), ("blue", "-")],
116 |             ylabel="ms",
117 |             plot_name="** NSA forward speed benchmark **",
118 |             args={},
119 |         )
120 |     )
121 |     def benchmark(N, provider):
122 |         x = torch.randn(N, 8192, device="cuda", dtype=torch.bfloat16)
123 |         cu_seqlens = torch.tensor([0, N], device="cuda", dtype=torch.int32)
124 |         quantiles = [0.5, 0.2, 0.8]
125 |         with torch.no_grad():
126 |             if provider == "Self-Attention":
127 |                 ms, min_ms, max_ms = triton.testing.do_bench(
128 |                     lambda: SelfAttn(x, cu_seqlens),
129 |                     quantiles=quantiles,
130 |                 )
131 |             if provider == "Native-Sparse-Attention":
132 |                 ms, min_ms, max_ms = triton.testing.do_bench(
133 |                     lambda: NSA(x, cu_seqlens),
134 |                     quantiles=quantiles,
135 |                 )
136 |         return ms, min_ms, max_ms
137 | 
138 |     benchmark.run(show_plots=True, print_data=True)
139 | 
140 | 
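    # The backward-speed benchmark below times only the backward pass: the loss (and
    # its autograd graph) is built once outside the timed lambda, and do_bench then
    # replays loss.backward(retain_graph=True) on each iteration. retain_graph=True
    # keeps the graph alive across repeated backward calls (gradients simply keep
    # accumulating, which does not affect the timing).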
@triton.testing.perf_report( 141 | triton.testing.Benchmark( 142 | x_names=["N"], 143 | x_vals=[1024 * 2**i for i in range(1, 8)], 144 | line_arg="provider", 145 | line_vals=["Self-Attention", "Native-Sparse-Attention"], 146 | line_names=["Self-Attention", "Native-Sparse-Attention"], 147 | styles=[("green", "-"), ("blue", "-")], 148 | ylabel="ms", 149 | plot_name="** NSA backward speed benchmark **", 150 | args={}, 151 | ) 152 | ) 153 | def benchmark(N, provider): 154 | x = torch.randn(N, 8192, device="cuda", dtype=torch.bfloat16) 155 | cu_seqlens = torch.tensor([0, N], device="cuda", dtype=torch.int32) 156 | quantiles = [0.5, 0.2, 0.8] 157 | if provider == "Self-Attention": 158 | loss = SelfAttn(x.clone().detach().requires_grad_(), cu_seqlens).mean() 159 | ms, min_ms, max_ms = triton.testing.do_bench( 160 | lambda: loss.backward(retain_graph=True), 161 | quantiles=quantiles, 162 | ) 163 | elif provider == "Native-Sparse-Attention": 164 | loss = NSA(x.clone().detach().requires_grad_(), cu_seqlens).mean() 165 | ms, min_ms, max_ms = triton.testing.do_bench( 166 | lambda: loss.backward(retain_graph=True), 167 | quantiles=quantiles, 168 | ) 169 | return ms, min_ms, max_ms 170 | 171 | benchmark.run(show_plots=True, print_data=True) 172 | -------------------------------------------------------------------------------- /test/test_rope.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import torch 15 | from native_sparse_attention.module import RopeConfig, RotaryEmbedding 16 | 17 | if __name__ == "__main__": 18 | rope_config = RopeConfig( 19 | max_position_embeddings=131072, 20 | head_dim=128, 21 | rope_theta=500000, 22 | rope_scaling={ 23 | "factor": 8.0, 24 | "high_freq_factor": 4.0, 25 | "low_freq_factor": 1.0, 26 | "original_max_position_embeddings": 8192, 27 | "rope_type": "llama3", 28 | }, 29 | ) 30 | rope = RotaryEmbedding(rope_config, "cuda") 31 | 32 | # random input 33 | torch.manual_seed(42) 34 | seqlens = torch.LongTensor([1000, 2000, 4096]).int().cuda() 35 | cu_seqlens = torch.cat( 36 | [ 37 | torch.zeros(1, dtype=torch.int32, device="cuda"), 38 | torch.cumsum(seqlens, dim=0), 39 | ], 40 | dim=0, 41 | ).to(torch.int32) 42 | x = torch.zeros( 43 | cu_seqlens[-1], 32, 128, device="cuda", dtype=torch.bfloat16 44 | ).uniform_(-1, 1) 45 | y = rope(x, cu_seqlens) 46 | -------------------------------------------------------------------------------- /test/test_topk_sparse_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and limitations under the License.
13 | import torch
14 | import triton
15 | import math
16 | from native_sparse_attention.ops.torch.topk_sparse_attention import (
17 |     topk_sparse_attention_torch,
18 | )
19 | from native_sparse_attention.ops.triton.topk_sparse_attention import (
20 |     topk_sparse_attention,
21 |     _topk_sparse_attention_fwd,
22 |     _topk_sparse_attention_bwd,
23 | )
24 | from native_sparse_attention.ops.triton.flash_attention import (
25 |     _flash_attention_fwd,
26 |     _flash_attention_bwd,
27 | )
28 | from flash_attn.flash_attn_interface import (
29 |     _flash_attn_varlen_forward,
30 |     _flash_attn_varlen_backward,
31 | )
32 | 
33 | 
34 | def generate_topk_idx_example(
35 |     seqlens: torch.Tensor,
36 |     block_size_k: int,
37 |     topk: int,
38 |     num_heads: int,
39 |     block_size_q: int = 1,
40 | ) -> torch.Tensor:
41 |     """Generate a random topk_idx example for testing.
42 | 
43 |     Args:
44 |         seqlens (torch.Tensor): shape [batch_size], per-sequence lengths (not the cumulative cu_seqlens used by flash_attn_varlen_func).
45 |         block_size_k (int): key value block size
46 |         topk (int): number of selected key value blocks per query
47 |         num_heads (int): number of key value heads
48 |         block_size_q (int): query block size (one topk row is kept per query block)
49 | 
50 |     Returns:
51 |         torch.Tensor: shape [num_heads, total_seqlen, topk], topk key value block idx for each query. -1 means padding.
52 |     """
53 |     batch_size = seqlens.shape[0]
54 |     num_blocks = torch.ceil(seqlens / block_size_k).to(torch.int32)
55 |     topk_idx_all_heads = []
56 |     cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0), pad=(1, 0), value=0)
57 |     for _ in range(num_heads):
58 |         topk_idx = [
59 |             torch.randn(seqlens[i], num_blocks[i], device="cuda")
60 |             .topk(min(topk, num_blocks[i]), dim=-1)
61 |             .indices.to(torch.int32)
62 |             for i in range(batch_size)
63 |         ]
64 |         topk_idx = [
65 |             torch.nn.functional.pad(
66 |                 topk_idx[i], (0, topk - topk_idx[i].shape[-1]), value=topk
67 |             )
68 |             for i in range(batch_size)
69 |         ]
70 |         topk_idx = torch.cat(topk_idx, dim=0)
71 |         topk_idx = torch.sort(topk_idx, dim=1).values
72 |         topk_idx[:, 0] = 0
73 |         q_idx = torch.cat(
74 |             [torch.arange(seqlens[i], device="cuda") for i in range(batch_size)], dim=0
75 |         )
76 |         topk_idx[topk_idx > (q_idx // block_size_k)[:, None]] = -1  # -1 means padding
77 |         topk_idx = torch.cat(
78 |             [
79 |                 topk_idx[cu_seqlens[i] : cu_seqlens[i + 1]][0::block_size_q]
80 |                 for i in range(batch_size)
81 |             ],
82 |             dim=0,
83 |         )
84 |         topk_idx_all_heads.append(topk_idx)
85 |     topk_idx = torch.stack(topk_idx_all_heads, dim=0)
86 |     return topk_idx
87 | 
88 | 
89 | if __name__ == "__main__":
90 |     torch.manual_seed(42)
91 |     batch_size = 3
92 |     seqlens = torch.LongTensor([1000, 2000, 4096]).int().cuda()
93 |     cu_seqlens = torch.cat(
94 |         [
95 |             torch.zeros(1, dtype=torch.int32, device="cuda"),
96 |             torch.cumsum(seqlens, dim=0),
97 |         ],
98 |         dim=0,
99 |     ).to(torch.int32)
100 |     max_seqlen = seqlens.max().item()
101 |     q = (
102 |         torch.empty(cu_seqlens[-1], 64, 96, device="cuda")
103 |         .uniform_(-1, 1)
104 |         .to(torch.float16)
105 |     )
106 |     k = (
107 |         torch.empty(cu_seqlens[-1], 8, 96, device="cuda")
108 |         .uniform_(-1, 1)
109 |         .to(torch.float16)
110 |     )
111 |     v = (
112 |         torch.empty(cu_seqlens[-1], 8, 96, device="cuda")
113 |         .uniform_(-1, 1)
114 |         .to(torch.float16)
115
| ) 116 | q.requires_grad = True 117 | k.requires_grad = True 118 | v.requires_grad = True 119 | block_size = 64 120 | topk = 5 121 | topk_idx = generate_topk_idx_example(seqlens, block_size, topk, 8) 122 | 123 | o = topk_sparse_attention_torch(q, k, v, topk_idx, block_size, cu_seqlens) 124 | 125 | randn = torch.randn_like(o) 126 | loss = (o * randn).sum() 127 | loss.backward() 128 | 129 | torch.manual_seed(42) 130 | q1 = q.clone().detach().requires_grad_() 131 | k1 = k.clone().detach().requires_grad_() 132 | v1 = v.clone().detach().requires_grad_() 133 | topk_idx1 = topk_idx.clone().detach() 134 | cu_seqlens1 = cu_seqlens.clone().detach() 135 | 136 | o1 = topk_sparse_attention(q1, k1, v1, topk_idx, block_size, cu_seqlens) 137 | 138 | randn2 = randn.clone().detach() 139 | loss2 = (o1 * randn2).sum() 140 | loss2.backward() 141 | 142 | print("Same Output:", torch.allclose(o, o1, atol=0.01, rtol=0.01)) 143 | print("Max Error:", (o - o1).abs().max().item()) 144 | print() 145 | print("Same Query Gradient:", torch.allclose(q.grad, q1.grad, atol=0.01, rtol=0.01)) 146 | print("Max Query Gradient Error:", (q.grad - q1.grad).abs().max().item()) 147 | print() 148 | print("Same Key Gradient:", torch.allclose(k.grad, k1.grad, atol=0.01, rtol=0.01)) 149 | print("Max Key Gradient Error:", (k.grad - k1.grad).abs().max().item()) 150 | print() 151 | print("Same Value Gradient:", torch.allclose(v.grad, v1.grad, atol=0.01, rtol=0.01)) 152 | print("Max Value Gradient Error:", (v.grad - v1.grad).abs().max().item()) 153 | print() 154 | 155 | # benchmark 156 | @triton.testing.perf_report( 157 | triton.testing.Benchmark( 158 | x_names=["N"], 159 | x_vals=[1024 * 2**i for i in range(1, 8)], 160 | line_arg="provider", 161 | line_vals=["flash", "triton-flash", "triton-top8", "triton-top16"], 162 | line_names=[ 163 | "Flash", 164 | "Triton-Flash", 165 | "Triton-Top8", 166 | "Triton-Top16", 167 | ], 168 | styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")], 169 | ylabel="ms", 170 | plot_name="** forward with block size 64 **", 171 | args={"H": 64, "D": 128, "K": 64}, 172 | ) 173 | ) 174 | def benchmark(N, H, D, K, provider): 175 | q = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16) 176 | k = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16) 177 | v = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16) 178 | cu_seqlens = torch.tensor([0, N], device="cuda", dtype=torch.int32) 179 | sm_scale = 1 / math.sqrt(D) 180 | 181 | top8_idx = generate_topk_idx_example(cu_seqlens[1:], K, 8, H // 16) 182 | top16_idx = generate_topk_idx_example(cu_seqlens[1:], K, 16, H // 16) 183 | 184 | quantiles = [0.5, 0.2, 0.8] 185 | if provider == "flash": 186 | ms, min_ms, max_ms = triton.testing.do_bench( 187 | lambda: _flash_attn_varlen_forward( 188 | q, 189 | k, 190 | v, 191 | cu_seqlens, 192 | cu_seqlens, 193 | N, 194 | N, 195 | dropout_p=0.0, 196 | causal=True, 197 | softmax_scale=sm_scale, 198 | ), 199 | quantiles=quantiles, 200 | ) 201 | if provider == "triton-flash": 202 | ms, min_ms, max_ms = triton.testing.do_bench( 203 | lambda: _flash_attention_fwd( 204 | q, k, v, cu_seqlens, cu_seqlens, N, N, True, sm_scale 205 | ), 206 | quantiles=quantiles, 207 | ) 208 | if provider == "triton-top8": 209 | ms, min_ms, max_ms = triton.testing.do_bench( 210 | lambda: _topk_sparse_attention_fwd( 211 | q, k, v, top8_idx, K, cu_seqlens, cu_seqlens, N, N, sm_scale 212 | ), 213 | quantiles=quantiles, 214 | ) 215 | if provider == "triton-top16": 216 | ms, min_ms, max_ms = 
triton.testing.do_bench( 217 | lambda: _topk_sparse_attention_fwd( 218 | q, k, v, top16_idx, K, cu_seqlens, cu_seqlens, N, N, sm_scale 219 | ), 220 | quantiles=quantiles, 221 | ) 222 | return ms, min_ms, max_ms 223 | 224 | benchmark.run(show_plots=True, print_data=True) 225 | 226 | # benchmark 227 | @triton.testing.perf_report( 228 | triton.testing.Benchmark( 229 | x_names=["N"], 230 | x_vals=[1024 * 2**i for i in range(1, 8)], 231 | line_arg="provider", 232 | line_vals=["flash", "triton-flash", "triton-top8", "triton-top16"], 233 | line_names=[ 234 | "Flash", 235 | "Triton-Flash", 236 | "Triton-Top8", 237 | "Triton-Top16", 238 | ], 239 | styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")], 240 | ylabel="ms", 241 | plot_name="** backward with block size 64 **", 242 | args={"H": 64, "D": 128, "K": 64}, 243 | ) 244 | ) 245 | def benchmark(N, H, D, K, provider): 246 | q = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16) 247 | k = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16) 248 | v = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16) 249 | o = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16) 250 | do = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16) 251 | lse = torch.randn((H, N), device="cuda", dtype=torch.float32) 252 | sm_scale = 1 / math.sqrt(D) 253 | cu_seqlens = torch.tensor([0, N], device="cuda", dtype=torch.int32) 254 | top8_idx = generate_topk_idx_example(cu_seqlens[1:], K, 8, H // 16) 255 | top16_idx = generate_topk_idx_example(cu_seqlens[1:], K, 16, H // 16) 256 | dq = torch.zeros_like(q) 257 | dk = torch.zeros_like(k) 258 | dv = torch.zeros_like(v) 259 | 260 | quantiles = [0.5, 0.2, 0.8] 261 | if provider == "flash": 262 | ms, min_ms, max_ms = triton.testing.do_bench( 263 | lambda: _flash_attn_varlen_backward( 264 | do, 265 | q, 266 | k, 267 | v, 268 | o, 269 | lse.transpose(0, 1), 270 | dq, 271 | dk, 272 | dv, 273 | cu_seqlens, 274 | cu_seqlens, 275 | N, 276 | N, 277 | dropout_p=0.0, 278 | causal=True, 279 | softmax_scale=sm_scale, 280 | window_size=(-1, -1), 281 | softcap=0.0, 282 | alibi_slopes=None, 283 | deterministic=False, 284 | ), 285 | quantiles=quantiles, 286 | ) 287 | if provider == "triton-flash": 288 | ms, min_ms, max_ms = triton.testing.do_bench( 289 | lambda: _flash_attention_bwd( 290 | o, do, lse, q, k, v, cu_seqlens, cu_seqlens, N, N, True, sm_scale 291 | ), 292 | quantiles=quantiles, 293 | ) 294 | if provider == "triton-top8": 295 | ms, min_ms, max_ms = triton.testing.do_bench( 296 | lambda: _topk_sparse_attention_bwd( 297 | o, 298 | do, 299 | lse, 300 | q, 301 | k, 302 | v, 303 | top8_idx, 304 | K, 305 | cu_seqlens, 306 | cu_seqlens, 307 | N, 308 | N, 309 | sm_scale, 310 | ), 311 | quantiles=quantiles, 312 | ) 313 | if provider == "triton-top16": 314 | ms, min_ms, max_ms = triton.testing.do_bench( 315 | lambda: _topk_sparse_attention_bwd( 316 | o, 317 | do, 318 | lse, 319 | q, 320 | k, 321 | v, 322 | top16_idx, 323 | K, 324 | cu_seqlens, 325 | cu_seqlens, 326 | N, 327 | N, 328 | sm_scale, 329 | ), 330 | quantiles=quantiles, 331 | ) 332 | return ms, min_ms, max_ms 333 | 334 | benchmark.run(show_plots=True, print_data=True) 335 | --------------------------------------------------------------------------------
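As a closing illustration of the topk_idx layout documented in generate_topk_idx_example above (shape [num_kv_heads, total_seqlen, topk], block indices sorted ascending, -1 marking padded slots), here is a minimal standalone sketch of the torch reference path. It assumes a CUDA device, assumes generate_topk_idx_example from test/test_topk_sparse_attention.py is in scope, and uses illustrative sizes rather than the ones in the test:

import torch
from native_sparse_attention.ops.torch.topk_sparse_attention import (
    topk_sparse_attention_torch,
)

seqlens = torch.tensor([128, 200], dtype=torch.int32, device="cuda")
cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0), (1, 0)).to(torch.int32)
total_len = int(cu_seqlens[-1])
q = torch.randn(total_len, 16, 64, device="cuda", dtype=torch.float16)  # 16 query heads
k = torch.randn(total_len, 2, 64, device="cuda", dtype=torch.float16)   # 2 key/value heads (GQA ratio 8)
v = torch.randn_like(k)
block_size, topk = 64, 4
topk_idx = generate_topk_idx_example(seqlens, block_size, topk, num_heads=2)  # [2, 328, 4]
out = topk_sparse_attention_torch(q, k, v, topk_idx, block_size, cu_seqlens)
print(out.shape)  # torch.Size([328, 16, 64])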