├── .gitignore
├── LICENSE
├── README.md
├── install_dependency.sh
├── native_sparse_attention
│   ├── __init__.py
│   ├── infer
│   │   ├── __init__.py
│   │   ├── inference_func.py
│   │   └── nsa_inference.py
│   ├── model
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── toy_llama.py
│   │   └── toy_nsa_llama.py
│   ├── module
│   │   ├── __init__.py
│   │   ├── kv_cache.py
│   │   ├── native_sparse_attention.py
│   │   ├── rope.py
│   │   └── self_attention.py
│   └── ops
│       ├── README.md
│       ├── __init__.py
│       ├── torch
│       │   ├── __init__.py
│       │   ├── compress_key_value.py
│       │   ├── compressed_attention.py
│       │   ├── compressed_attention_decode.py
│       │   └── topk_sparse_attention.py
│       └── triton
│           ├── __init__.py
│           ├── compressed_attention.py
│           ├── flash_attention.py
│           ├── flash_attention_decode.py
│           ├── linear_compress.py
│           ├── topk_sparse_attention.py
│           ├── topk_sparse_attention_decode.py
│           ├── utils.py
│           └── weighted_pool.py
├── setup.py
└── test
    ├── test_compress_key_value.py
    ├── test_compressed_attention.py
    ├── test_flash_attention.py
    ├── test_kv_cache.py
    ├── test_linear_compress.py
    ├── test_nsa_infer.py
    ├── test_nsa_model.py
    ├── test_nsa_module.py
    ├── test_rope.py
    └── test_topk_sparse_attention.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # UV
98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | #uv.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 |
110 | # pdm
111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112 | #pdm.lock
113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114 | # in version control.
115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116 | .pdm.toml
117 | .pdm-python
118 | .pdm-build/
119 |
120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 |
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 |
127 | # SageMath parsed files
128 | *.sage.py
129 |
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 |
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 |
143 | # Rope project settings
144 | .ropeproject
145 |
146 | # mkdocs documentation
147 | /site
148 |
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 |
154 | # Pyre type checker
155 | .pyre/
156 |
157 | # pytype static type analyzer
158 | .pytype/
159 |
160 | # Cython debug symbols
161 | cython_debug/
162 |
163 | # PyCharm
164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | # and can be added to the global gitignore or merged into this file. For a more nuclear
167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | #.idea/
169 |
170 | # PyPI configuration file
171 | .pypirc
172 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Native Sparse Attention Triton
4 |
5 |
6 |
7 | This repository implements the sparse attention mechanism introduced in the paper [Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention](https://arxiv.org/abs/2502.11089) and provides an efficient training implementation based on [Triton](https://github.com/triton-lang/triton).
8 |
9 | 🎉 We now support both training and inference for Native Sparse Attention (variable-length version, including prefilling, decoding, and KV cache management). We provide a toy model, `model.ToyNSALlama`, which supports a `forward` function for training and a `generate` function for inference. Feel free to try it out!
10 |
11 | ## Requirements
12 | Ensure the following dependencies are installed:
13 | - PyTorch >= 2.1.0
14 | - triton >= 3.0.0
15 | - einops >= 0.7.0
16 | - flash_attn >= 2.6.3
17 |
18 | ## Usage
19 |
20 | ### Notes
21 | 1. PyTorch implementations (`ops.torch`) are intended for debugging only.
22 | 2. For production use, prefer Triton operators (`ops.triton`).
23 | 3. All implementations are based on the varlen approach, similar to `flash_attn_varlen_func`. Please concatenate the inputs of a batch before use (see the packing sketch after this list).
24 | 4. Only attention head dimensions of at most 128 are supported for now.
25 |
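A minimal packing sketch (the shapes below are hypothetical; only the `cu_seqlens` convention matters):

```python
import torch

# three sequences of different lengths, each of shape [seq_len, num_heads, head_dim]
seqs = [
    torch.randn(n, 4, 128, dtype=torch.bfloat16, device="cuda")
    for n in (512, 1024, 2048)
]

# concatenate along the token dimension -> [total_len, num_heads, head_dim]
packed = torch.cat(seqs, dim=0)

# cu_seqlens holds the cumulative sequence offsets, starting at 0 and ending at total_len
seqlens = torch.tensor([s.shape[0] for s in seqs], dtype=torch.int32, device="cuda")
cu_seqlens = torch.zeros(len(seqs) + 1, dtype=torch.int32, device="cuda")
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
# cu_seqlens -> tensor([0, 512, 1536, 3584], dtype=torch.int32)
```
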
26 | ### Install
27 |
28 | You can install `native_sparse_attention` using pip:
29 |
30 | ```shell
31 | pip install git+https://github.com/XunhaoLai/native-sparse-attention-triton.git
32 | ```
33 |
34 | ### Functions
35 |
36 | The `ops` module has implemented several functions required for native sparse attention. For detailed usage instructions, please see [this link](https://github.com/XunhaoLai/native-sparse-attention-triton/tree/main/native_sparse_attention/ops#readme).
37 |
38 | You can import those functions from the `ops` module:
39 |
40 | ```python
41 | import torch
42 | from native_sparse_attention.ops import linear_compress, compressed_attention, topk_sparse_attention
43 |
44 | # input example
45 | num_q_heads = 64
46 | num_kv_heads = 4
47 | head_dim = 128
48 | kernel_size = 32
49 | kernel_stride = 16
50 | block_size = 64
51 | topk = 16
52 | cu_seqlens = torch.Tensor([0, 1024, 8192, 16384]).to(torch.int32).cuda()
53 | query = torch.randn(16384, num_q_heads, head_dim).to(torch.bfloat16).cuda()
54 | key = torch.randn(16384, num_kv_heads, head_dim).to(torch.bfloat16).cuda()
55 | value = torch.randn(16384, num_kv_heads, head_dim).to(torch.bfloat16).cuda()
56 |
57 | # weight example
58 | w = (
59 | torch.randn(num_kv_heads, kernel_size * head_dim, head_dim)
60 | .to(torch.bfloat16)
61 | .cuda()
62 | )
63 | pe = torch.randn(num_kv_heads, kernel_size, head_dim).to(torch.bfloat16).cuda()
64 |
65 | # 1. key value compression
66 | compressed_key, compressed_cu_seqlens = linear_compress(
67 | key, w, cu_seqlens, kernel_size, kernel_stride, pe
68 | )
69 | compressed_value, _ = linear_compress(
70 | value, w, cu_seqlens, kernel_size, kernel_stride, None
71 | )
72 |
73 | # 2. attention between query and compressed key value
74 | compressed_attn_output, topk_idx = compressed_attention(
75 | query,
76 | compressed_key,
77 | compressed_value,
78 | kernel_size,
79 | kernel_stride,
80 | block_size,
81 | topk,
82 | cu_seqlens,
83 | compressed_cu_seqlens,
84 | init_blocks=1,
85 | local_blocks=2,
86 | )
87 |
88 | # 3. topk sparse attention
89 | sparse_attn_output = topk_sparse_attention(
90 | query,
91 | key,
92 | value,
93 | topk_idx,
94 | block_size,
95 | cu_seqlens,
96 | )
97 | ```
98 |
99 | ### Module
100 |
101 | The `module` directory also provides implementations based on `torch.nn.Module` for easy integration into models.
102 |
103 | ```python
104 | from native_sparse_attention.module import NativeSparseAttention, RopeConfig
105 |
106 | NSA_Layer = NativeSparseAttention(
107 | compress_type="linear",
108 | hidden_size=4096,
109 | num_q_heads=64,
110 | num_kv_heads=4,
111 | head_dim=128,
112 | kernel_size=32,
113 | kernel_stride=16,
114 | block_size=64,
115 | topk=8,
116 | init_blocks=1,
117 | local_blocks=2,
118 | window_size=512,
119 | rope_config=RopeConfig(
120 | max_position_embeddings=32768,
121 | head_dim=128,
122 | rope_theta=500000,
123 | rope_scaling={
124 | "factor": 4.0,
125 | "high_freq_factor": 4.0,
126 | "low_freq_factor": 1.0,
127 | "original_max_position_embeddings": 8192,
128 | "rope_type": "llama3",
129 | },
130 | ),
131 | )
132 | ```
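
A minimal call sketch, assuming (as in the toy models in this repository) that the module takes packed varlen hidden states of shape `[total_len, hidden_size]` together with `cu_seqlens`:

```python
import torch

# two sequences of lengths 1024 and 2048, packed into one tensor
cu_seqlens = torch.tensor([0, 1024, 3072], dtype=torch.int32, device="cuda")
hidden_states = torch.randn(3072, 4096, dtype=torch.bfloat16, device="cuda")

NSA_Layer = NSA_Layer.cuda().bfloat16()
attn_output = NSA_Layer(hidden_states, cu_seqlens)  # [total_len, hidden_size]
```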
133 |
134 | ### Model
135 |
136 | We offer two simplified LLaMA models in the `model` directory, featuring self-attention and native sparse attention. For more details on how to use these models, please refer to [this link](https://github.com/XunhaoLai/native-sparse-attention-triton/tree/main/native_sparse_attention/model#readme).
137 |
138 |
139 | ```python
140 | from native_sparse_attention.model import ToyNSALlamaConfig, InferenceConfig, ToyNSALlama
141 |
142 | config = ToyNSALlamaConfig(
143 | hidden_size=4096,
144 | intermediate_size=14336,
145 | num_hidden_layers=8,
146 | num_attention_heads=32,
147 | num_key_value_heads=2,
148 | head_dim=128,
149 | rope_theta=500000.0,
150 | rope_scaling={
151 | "factor": 8.0,
152 | "high_freq_factor": 4.0,
153 | "low_freq_factor": 1.0,
154 | "original_max_position_embeddings": 8192,
155 | "rope_type": "llama3",
156 | },
157 | compress_type="weightedpool",
158 | kernel_size=32,
159 | kernel_stride=16,
160 | block_size=64,
161 | topk=8,
162 | init_blocks=1,
163 | local_blocks=2,
164 | window_size=512,
165 | )
166 | inference_config = InferenceConfig(
167 | max_batch_size=4,
168 | max_length=8192,
169 | max_new_tokens=128,
170 | )
171 | model = ToyNSALlama(config, inference_config).cuda().bfloat16()
172 | ```
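
A quick smoke-test sketch, adapted from the model guide under `native_sparse_attention/model`:

```python
import torch

# pack a random batch of 4 variable-length sequences in varlen form
batch_size = 4
seqlens = torch.randint(1, 4096, (batch_size,), dtype=torch.int32, device="cuda")
cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
input_ids = torch.randint(0, 128288, (cu_seqlens[-1],), dtype=torch.int64, device="cuda")

# training / evaluation: returns logits for every packed token
logits = model(input_ids, cu_seqlens)

# inference: greedy decoding with the built-in KV cache
output_tokens = model.generate(input_ids, cu_seqlens)
```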
173 |
174 | ## Testing
175 |
176 | Some test scripts are available in the `test` folder and can be run directly for unit testing. For example:
177 |
178 | ```bash
179 | python test/test_topk_sparse_attention.py
180 | python test/test_nsa_module.py
181 | python test/test_nsa_model.py
182 | ```
183 |
184 | ### Benchmarks
185 |
186 | Here are speed benchmarks for the `topk_sparse_attention` function, conducted on a single NVIDIA A100 or H100 GPU:
187 |
188 | A100 GPU speed benchmarks:
189 | ```sh
190 | ** forward with block size 64 **:
191 | N Flash Triton-Flash Triton-Top8 Triton-Top16
192 | 0 2048.0 0.414144 0.635648 0.633440 1.009184
193 | 1 4096.0 1.400304 2.267552 1.179808 1.916736
194 | 2 8192.0 5.223776 8.528160 2.266816 3.723168
195 | 3 16384.0 20.225697 32.745537 4.468128 7.359168
196 | 4 32768.0 79.587715 128.951065 8.517440 14.142848
197 | 5 65536.0 321.240479 511.652100 17.249599 30.991360
198 | 6 131072.0 1349.810425 2063.245605 36.400482 67.884544
199 |
200 | ** backward with block size 64 **:
201 | N Flash Triton-Flash Triton-Top8 Triton-Top16
202 | 0 2048.0 1.315440 2.348560 1.941568 2.691040
203 | 1 4096.0 4.271584 8.553184 3.647744 5.032160
204 | 2 8192.0 15.323984 32.665440 5.650144 9.066112
205 | 3 16384.0 58.753281 127.675964 11.160832 17.113279
206 | 4 32768.0 227.770462 504.572693 21.723392 34.715614
207 | 5 65536.0 899.181274 2059.718506 44.517181 76.309441
208 | 6 131072.0 3587.918701 8530.726562 105.344734 182.970169
209 | ```
210 |
211 | H100 GPU benchmarks:
212 | ```sh
213 | ** forward with block size 64 **:
214 | N Flash Triton-Flash Triton-Top8 Triton-Top16
215 | 0 2048.0 0.259552 0.293888 0.584544 0.917664
216 | 1 4096.0 0.846848 1.029904 1.094976 1.745136
217 | 2 8192.0 3.043744 3.843392 2.128256 3.396880
218 | 3 16384.0 11.743568 14.791360 4.190528 6.704192
219 | 4 32768.0 45.968513 57.532478 7.614496 12.417440
220 | 5 65536.0 187.234375 228.093948 14.840048 24.511856
221 | 6 131072.0 810.890381 914.693970 29.470400 48.990192
222 |
223 | ** backward with block size 64 **:
224 | N Flash Triton-Flash Triton-Top8 Triton-Top16
225 | 0 2048.0 0.798976 1.096096 1.117312 1.380016
226 | 1 4096.0 2.545680 3.826336 1.669760 2.214880
227 | 2 8192.0 9.029760 14.411633 2.772096 3.947456
228 | 3 16384.0 34.144016 58.945698 5.201344 7.538912
229 | 4 32768.0 135.718369 233.369247 9.968864 15.154192
230 | 5 65536.0 541.053894 929.337646 21.089870 33.818878
231 | 6 131072.0 2139.974854 3785.540527 54.918144 93.750717
232 | ```
233 |
234 | Below is another speed benchmark, for the `compressed_attention` function, also on a single NVIDIA A100 or H100 GPU:
235 |
236 | A100 GPU speed benchmarks:
237 | ```sh
238 | ** forward with kernel 32 and stride 16 **:
239 | N Flash Triton-Flash Compressed Compressed-wo-Score
240 | 0 2048.0 0.413664 0.635488 0.655024 0.170816
241 | 1 4096.0 1.396416 2.247648 1.132304 0.377152
242 | 2 8192.0 5.234656 8.526400 2.879200 0.977952
243 | 3 16384.0 19.988865 32.755199 9.426448 2.943024
244 | 4 32768.0 79.419907 128.955170 30.284096 9.901120
245 | 5 65536.0 321.590210 511.615509 112.260544 36.001602
246 | 6 131072.0 1346.996338 2069.837891 423.099518 136.820038
247 |
248 | ** backward with kernel 32 and stride 16 **:
249 | N Flash Triton-Flash Compressed
250 | 0 2048.0 1.322560 2.352000 0.486784
251 | 1 4096.0 4.270832 8.552608 0.971392
252 | 2 8192.0 15.515680 32.671329 2.603744
253 | 3 16384.0 59.345055 128.377472 8.499456
254 | 4 32768.0 230.626144 506.581238 30.064833
255 | 5 65536.0 919.260498 2068.642578 113.466560
256 | 6 131072.0 3646.603760 8498.374023 439.623444
257 | ```
258 |
259 | H100 GPU speed benchmarks:
260 | ```sh
261 | ** forward with kernel 32 and stride 16 **:
262 | N Flash Triton-Flash Compressed Compressed-wo-Score
263 | 0 2048.0 0.259488 0.297152 0.485920 0.103232
264 | 1 4096.0 0.847376 1.030400 0.710208 0.217760
265 | 2 8192.0 3.044016 3.875840 1.607360 0.516016
266 | 3 16384.0 11.823104 14.829360 4.970272 1.440288
267 | 4 32768.0 46.204750 57.527809 15.004992 4.584736
268 | 5 65536.0 187.324249 227.909958 53.009087 16.134224
269 | 6 131072.0 810.707214 910.106873 191.245728 60.154270
270 |
271 | ** backward with kernel 32 and stride 16 **:
272 | N Flash Triton-Flash Compressed
273 | 0 2048.0 0.797728 1.090640 0.283104
274 | 1 4096.0 2.547088 3.834592 0.550464
275 | 2 8192.0 9.021520 14.421088 1.249184
276 | 3 16384.0 34.159508 58.793377 3.743440
277 | 4 32768.0 136.844070 233.447708 12.640032
278 | 5 65536.0 537.559814 929.360229 46.054817
279 | 6 131072.0 2135.629883 3782.351562 175.587296
280 | ```
281 |
282 | All the speed benchmarks above were tested with 64 query heads, 4 key/value heads, and a head dimension of 128.
283 |
284 | ## Contributing
285 | Contributions are welcome! Please open an issue to discuss major changes.
286 |
287 | ## Contact
288 |
289 | For any questions or feedback, please feel free to contact laixunhao@pku.edu.cn.
290 |
291 | ## Citations
292 |
293 | ```bibtex
294 | @inproceedings{Yuan2025NativeSA,
295 | title = {Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention},
296 | author = {Jingyang Yuan and Huazuo Gao and Damai Dai and Junyu Luo and Liang Zhao and Zhengyan Zhang and Zhenda Xie and Y. X. Wei and Lean Wang and Zhiping Xiao and Yuqing Wang and Chong Ruan and Ming Zhang and Wenfeng Liang and Wangding Zeng},
297 | year = {2025},
298 | url = {https://api.semanticscholar.org/CorpusID:276408911}
299 | }
300 | ```
301 |
--------------------------------------------------------------------------------
/install_dependency.sh:
--------------------------------------------------------------------------------
1 | pip3 install packaging -i https://pypi.org/simple
2 | pip3 install numpy==1.26.4 -i https://pypi.org/simple
3 | pip3 install torch==2.4.0 -i https://pypi.org/simple
4 | pip3 install triton==3.0.0 -i https://pypi.org/simple
5 | pip3 install transformers==4.44.0 -i https://pypi.org/simple
6 | pip3 install flash_attn==2.6.3 -i https://pypi.org/simple
7 | pip3 install matplotlib==3.9.4 -i https://pypi.org/simple
8 | pip3 install pandas==2.2.3 -i https://pypi.org/simple
--------------------------------------------------------------------------------
/native_sparse_attention/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XunhaoLai/native-sparse-attention-triton/9bea856c911ebf263be88d797fb28458f82f1d94/native_sparse_attention/__init__.py
--------------------------------------------------------------------------------
/native_sparse_attention/infer/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from native_sparse_attention.infer.nsa_inference import nsa_infer
16 |
17 | __all__ = [
18 | "nsa_infer",
19 | ]
20 |
--------------------------------------------------------------------------------
/native_sparse_attention/infer/inference_func.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import torch
15 | from typing import Tuple, Callable, Optional
16 | from flash_attn import flash_attn_varlen_func
17 | from native_sparse_attention.ops import (
18 | flash_attention_decode,
19 | compressed_attention,
20 | compressed_attention_decode,
21 | topk_sparse_attention,
22 | topk_sparse_attention_decode,
23 | )
24 | from native_sparse_attention.ops.triton.utils import get_compressed_seqlens
25 |
26 |
27 | def compress_infer(
28 | cu_seqlens: torch.Tensor,
29 | step: int,
30 | key: torch.Tensor,
31 | value: torch.Tensor,
32 | cache,
33 | weight: Tuple[torch.Tensor, torch.Tensor],
34 | compress_func: Tuple[Callable, Callable],
35 | intra_block_pe: Optional[torch.Tensor],
36 | kernel_size: int,
37 | kernel_stride: int,
38 | ):
39 | if step == 0:
40 | key, compress_cu_seqlens = compress_func[0](
41 | key,
42 | weight[0],
43 | cu_seqlens,
44 | kernel_size,
45 | kernel_stride,
46 | intra_block_pe,
47 | )
48 | value, _ = compress_func[1](
49 | value,
50 | weight[1],
51 | cu_seqlens,
52 | kernel_size,
53 | kernel_stride,
54 | )
55 | else:
56 | batch_size = cu_seqlens.shape[0] - 1
57 | aux_cu_seqlens = (
58 | torch.arange(batch_size + 1, dtype=torch.int32).to(cu_seqlens.device)
59 | * kernel_size
60 | )
61 | key, _ = compress_func[0](
62 | cache.before_compress_kv_cache[0, :batch_size].view(
63 | batch_size * kernel_size, cache.num_kv_heads, cache.head_dim
64 | ),
65 | weight[0],
66 | aux_cu_seqlens,
67 | kernel_size,
68 | kernel_stride,
69 | intra_block_pe,
70 | )
71 | value, _ = compress_func[1](
72 | cache.before_compress_kv_cache[1, :batch_size].view(
73 | batch_size * kernel_size, cache.num_kv_heads, cache.head_dim
74 | ),
75 | weight[1],
76 | aux_cu_seqlens,
77 | kernel_size,
78 | kernel_stride,
79 | )
80 | # return actual compress_cu_seqlens before this token
81 | compress_cu_seqlens = torch.zeros(
82 | batch_size + 1, dtype=torch.int32, device=key.device
83 | )
84 | compress_cu_seqlens[1:] = torch.cumsum(
85 | cache.compress_kv_len[:batch_size], dim=0
86 | )
87 | return key, value, compress_cu_seqlens
88 |
89 |
90 | def compressed_attention_infer(
91 | cu_seqlens,
92 | step,
93 | query,
94 | key,
95 | value,
96 | cache,
97 | kernel_size,
98 | kernel_stride,
99 | topk,
100 | block_size,
101 | init_blocks,
102 | local_blocks,
103 | ):
104 | if step == 0:
105 | seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
106 | compress_seqlens, compress_cu_seqlens = get_compressed_seqlens(
107 | cu_seqlens, kernel_size, kernel_stride
108 | )
109 | attn_output, topk_idx = compressed_attention(
110 | query,
111 | key,
112 | value,
113 | kernel_size,
114 | kernel_stride,
115 | block_size,
116 | topk,
117 | cu_seqlens,
118 | compress_cu_seqlens,
119 | seqlens.max().item(),
120 | compress_seqlens.max().item(),
121 | None,
122 | init_blocks,
123 | local_blocks,
124 | )
125 | else:
126 | batch_size = cu_seqlens.shape[0] - 1
127 | seqlens = cu_seqlens[1:] - cu_seqlens[:-1] + step
128 | attn_output, topk_idx = compressed_attention_decode(
129 | query,
130 | cache.compress_kv_cache[
131 | 0, :batch_size, : cache.compress_kv_len[:batch_size].max()
132 | ],
133 | cache.compress_kv_cache[
134 | 1, :batch_size, : cache.compress_kv_len[:batch_size].max()
135 | ],
136 | seqlens,
137 | cache.compress_kv_len[:batch_size],
138 | kernel_size,
139 | kernel_stride,
140 | block_size,
141 | topk,
142 | init_blocks,
143 | local_blocks,
144 | )
145 | return attn_output, topk_idx
146 |
147 |
148 | def topk_sparse_attention_infer(
149 | cu_seqlens,
150 | step,
151 | query,
152 | key,
153 | value,
154 | cache,
155 | topk_idx,
156 | block_size,
157 | ):
158 | if step == 0:
159 | attn_output = topk_sparse_attention(
160 | query, key, value, topk_idx, block_size, cu_seqlens
161 | )
162 | else:
163 | batch_size = cu_seqlens.shape[0] - 1
164 | attn_output = topk_sparse_attention_decode(
165 | query,
166 | cache.sparse_kv_cache[0, :batch_size],
167 | cache.sparse_kv_cache[1, :batch_size],
168 | topk_idx,
169 | block_size,
170 | cache.sparse_kv_len[:batch_size],
171 | )
172 | return attn_output
173 |
174 |
175 | def sliding_window_attention_infer(
176 | cu_seqlens, step, query, key, value, cache, window_size
177 | ):
178 | if step == 0:
179 | seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
180 | attn_output = flash_attn_varlen_func(
181 | query,
182 | key,
183 | value,
184 | cu_seqlens,
185 | cu_seqlens,
186 | seqlens.max().item(),
187 | seqlens.max().item(),
188 | causal=True,
189 | window_size=(window_size, -1),
190 | )
191 | else:
192 | batch_size = cu_seqlens.shape[0] - 1
193 | attn_output = flash_attention_decode(
194 | query,
195 | cache.sliding_kv_cache[0, :batch_size],
196 | cache.sliding_kv_cache[1, :batch_size],
197 | torch.minimum(
198 | cache.sliding_kv_len,
199 | torch.zeros_like(cache.sliding_kv_len) + window_size,
200 | )[:batch_size],
201 | )
202 | return attn_output
203 |
--------------------------------------------------------------------------------
/native_sparse_attention/infer/nsa_inference.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import torch
15 | from typing import Tuple, Callable, Optional
16 | from native_sparse_attention.infer.inference_func import (
17 | compress_infer,
18 | compressed_attention_infer,
19 | topk_sparse_attention_infer,
20 | sliding_window_attention_infer,
21 | )
22 |
23 |
24 | def nsa_infer(
25 | cu_seqlens: torch.Tensor,
26 | step: int,
27 | # qkv for three parts
28 | query: torch.Tensor,
29 | key: torch.Tensor, # prefill: [total_len, num_heads, head_dim], decode: [batch_size, num_heads, head_dim]
30 | value: torch.Tensor,
31 | gate_value: torch.Tensor, # prefill: [total_len, num_heads, 3], decode: [batch_size, num_heads, 3]
32 | # rope and kv cache
33 | rope,
34 | cache,
35 | # weight for nsa compress
36 | compress_weight: Tuple[
37 | torch.Tensor, torch.Tensor
38 | ], # compress weight for key and value
39 | compress_func: Tuple[Callable, Callable], # compress function for key and value
40 | intra_block_pe: Optional[torch.Tensor],
41 | # nsa parameters
42 | kernel_size: int,
43 | kernel_stride: int,
44 | block_size: int,
45 | topk: int,
46 | init_blocks: int,
47 | local_blocks: int,
48 | window_size: int,
49 | ) -> torch.Tensor:
50 | """Inference function for native sparse attention. Support prefill and decode with kv cache.
51 |
52 | Args:
53 |         cu_seqlens (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_varlen_func.
54 | step (int): current inference step, step == 0 means prefill, step > 0 means decode step.
55 | query (torch.Tensor): for prefill, shape [total_len, num_q_heads, head_dim]; for decode, shape [batch_size, num_q_heads, head_dim]
56 | key (torch.Tensor): for prefill, shape [total_len, num_kv_heads, head_dim]; for decode, shape [batch_size, num_kv_heads, head_dim]
57 | value (torch.Tensor): for prefill, shape [total_len, num_kv_heads, head_dim]; for decode, shape [batch_size, num_kv_heads, head_dim]
58 | gate_value (torch.Tensor): for prefill, shape [total_len, num_heads, 3]; for decode, shape [batch_size, num_heads, 3]
59 | rope (RotaryEmbedding): rope module, see native_sparse_attention.module.rope.RotaryEmbedding for details
60 |         cache (NSACache): kv cache, see native_sparse_attention.module.kv_cache.NSACache for details
61 | compress_weight (Tuple[torch.Tensor, torch.Tensor]): compress weight of key and value respectively
62 | compress_func (Tuple[Callable, Callable]): compress functions for key and value respectively
63 |         intra_block_pe (Optional[torch.Tensor]): intra-block positional embedding for compression; set to None to disable it
64 | kernel_size (int): kernel size of compression
65 |         kernel_stride (int): kernel stride of compression
66 | block_size (int): block size of sparse attention
67 | topk (int): topk of sparse attention
68 |         init_blocks (int): number of blocks at the beginning of the sequence; these blocks are forced to be computed in sparse attention
69 |         local_blocks (int): number of blocks in the local window of each query; these blocks are forced to be computed in sparse attention
70 | window_size (int): window size for sliding window attention
71 |
72 | Returns:
73 | torch.Tensor: native sparse attention output, same shape as input query
74 | """
75 |     # reset kv cache at the beginning of prefilling
76 | if step == 0:
77 | cache.reset()
78 | # prepare for compress
79 | cache.prepare_compress(cu_seqlens, step, key, value)
80 | # compressed key and value before rope
81 | compress_key, compress_value, compress_cu_seqlens = compress_infer(
82 | cu_seqlens,
83 | step,
84 | key,
85 | value,
86 | cache,
87 | compress_weight,
88 | compress_func,
89 | intra_block_pe,
90 | kernel_size,
91 | kernel_stride,
92 | )
93 | # do rope
94 | query = rope(query, cu_seqlens, step)
95 | if step == 0:
96 | compress_key = rope(
97 | compress_key, compress_cu_seqlens, step, stride=cache.kernel_stride
98 | )
99 | else:
100 | compress_key = rope(
101 | compress_key, compress_cu_seqlens, 1, stride=cache.kernel_stride
102 | )
103 | key = rope(key, cu_seqlens, step)
104 | # update kv cache
105 | cache.update_kv(
106 | cu_seqlens,
107 | step,
108 | compress_key,
109 | compress_value,
110 | key,
111 | value,
112 | key,
113 | value,
114 | )
115 | # compressed attention
116 | compress_attn_output, topk_idx = compressed_attention_infer(
117 | cu_seqlens,
118 | step,
119 | query,
120 | compress_key,
121 | compress_value,
122 | cache,
123 | kernel_size,
124 | kernel_stride,
125 | topk,
126 | block_size,
127 | init_blocks,
128 | local_blocks,
129 | )
130 | # topk sparse attention
131 | sparse_attn_output = topk_sparse_attention_infer(
132 | cu_seqlens,
133 | step,
134 | query,
135 | key,
136 | value,
137 | cache,
138 | topk_idx,
139 | block_size,
140 | )
141 | # sliding window attention
142 | sliding_attn_output = sliding_window_attention_infer(
143 | cu_seqlens, step, query, key, value, cache, window_size
144 | )
145 | # combine 3 attn output
146 | attn_output = (
147 | gate_value[..., 0, None] * compress_attn_output
148 | + gate_value[..., 1, None] * sparse_attn_output
149 | + gate_value[..., 2, None] * sliding_attn_output
150 | )
151 | return attn_output
152 |
--------------------------------------------------------------------------------
/native_sparse_attention/model/README.md:
--------------------------------------------------------------------------------
1 | # Guide for the ToyNSALlama Model
2 |
3 | The `ToyNSALlama` model is a custom implementation of a Llama-like transformer architecture featuring a Native Sparse Attention (NSA) module. This guide explains how to integrate the NSA module into your own model.
4 |
5 | ## Overview
6 |
7 | The `ToyNSALlama` model consists of:
8 | - **Configuration**: Defined by `ToyNSALlamaConfig` (model structure parameters) and `InferenceConfig` (inference-specific parameters).
9 | - **Components**: An embedding layer, multiple NativeSparseAttention modules, Feed-Forward Network (FFN) modules, normalization layers, and a language model head.
10 |
11 | ## Step-by-Step Instructions
12 |
13 | ### 1. Import Necessary Modules
14 | ```python
15 | import torch
16 | import torch.nn as nn
17 | from native_sparse_attention.model import ToyNSALlama, ToyNSALlamaConfig, InferenceConfig
18 | ```
19 |
20 | ### 2. Define Configurations
21 | Create instances of `ToyNSALlamaConfig` and `InferenceConfig` to set model and inference parameters.
22 |
23 | #### Model Configuration
24 | The model configuration aligns with the Transformers Llama model configuration. Adjust the following parameters to control the NSA module’s sparsity:
25 | - `compress_type`: Compression method for keys/values. Supported options: `avgpool`, `weightedpool`, `linear`.
26 | - `kernel_size` & `kernel_stride`: `kernel_size` determines how many tokens are compressed into one; `kernel_stride` sets the stride of the compression sliding window (`kernel_size` must be divisible by `kernel_stride`).
27 | - `block_size`: Block size for sparse attention (recommended: 64 or 128).
28 | - `topk`, `init_blocks`, `local_blocks`: `topk` specifies the number of blocks selected in sparse attention; `init_blocks` and `local_blocks` define the number of initial and local blocks that must be selected.
29 | - `window_size`: Size of the sliding window for attention.
30 |
31 | Example:
32 | ```python
33 | config = ToyNSALlamaConfig(
34 | hidden_size=4096,
35 | intermediate_size=14336,
36 | num_hidden_layers=8,
37 | num_attention_heads=32,
38 | num_key_value_heads=2,
39 | head_dim=128,
40 | vocab_size=128288,
41 | max_position_embeddings=131072,
42 | rope_theta=500000.0,
43 | rope_scaling={
44 | "factor": 8.0,
45 | "high_freq_factor": 4.0,
46 | "low_freq_factor": 1.0,
47 | "original_max_position_embeddings": 8192,
48 | "rope_type": "llama3",
49 | },
50 | compress_type="weightedpool",
51 | kernel_size=32,
52 | kernel_stride=16,
53 | block_size=64,
54 | topk=8,
55 | init_blocks=1,
56 | local_blocks=2,
57 | window_size=512,
58 | )
59 | ```
60 |
61 | #### Inference Configuration
62 | This configuration applies during inference and initializes the Key-Value (KV) Cache based on these settings. The plain key/value portion of the cache takes roughly `max_batch_size × max_length × num_kv_heads × head_dim × num_layers × 2 × 2` bytes (keys and values, 2 bytes each in `bfloat16`). Currently, only greedy decoding is supported as an example.
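As a rough worked example with the sample configurations in this guide (batch size 4, max length 8192, 2 KV heads, head dim 128, 8 layers), that is about `4 × 8192 × 2 × 128 × 8 × 2 × 2 ≈ 268 MB`; the NSA cache additionally holds compressed and sliding-window entries, so the actual footprint is somewhat larger.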
63 |
64 | Example:
65 | ```python
66 | inference_config = InferenceConfig(
67 | max_batch_size=4,
68 | max_length=8192,
69 | max_new_tokens=128,
70 | )
71 | ```
72 |
73 | ### 3. Initialize the Model
74 | Instantiate the model and move it to the GPU with the appropriate data type (currently, only `bfloat16` is supported).
75 |
76 | ```python
77 | model = ToyNSALlama(config, inference_config).cuda().to(torch.bfloat16)
78 | ```
79 |
80 | ### 4. Forward & Generate
81 | The model supports two methods:
82 | - **`forward`**: Accepts `input_ids` and `cu_seqlens`, returning final logits after the language model head. Use this for training or evaluation.
83 | - **`generate`**: Accepts `input_ids` and `cu_seqlens`, generating output tokens via greedy sampling. This demonstrates KV cache usage for token generation (pre-filling and decoding).
84 |
85 | Example:
86 | ```python
87 | # Example input
88 | batch_size = 4
89 | seqlens = torch.randint(0, 4096, (batch_size,), dtype=torch.int32, device="cuda")
90 | cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda")
91 | cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
92 | input_ids = torch.randint(0, 128288, (cu_seqlens[-1],), dtype=torch.int64, device="cuda")
93 | print(f"\nEXAMPLE INPUT:\ncu_seqlens: {cu_seqlens}\ninput_ids: {input_ids.shape}\n")
94 |
95 | # Example forward
96 | logits = model(input_ids, cu_seqlens)
97 | print(f"\nEXAMPLE OUTPUT:\nlogits: {logits.shape}\n")
98 |
99 | # Example generate
100 | output_tokens = model.generate(input_ids, cu_seqlens)
101 | print(f"\nEXAMPLE GENERATE:\noutput_tokens: {output_tokens}\n")
102 | ```
103 |
104 | ## Toy Llama Model with Self-Attention
105 | A simpler toy model with the Llama structure is available in `native_sparse_attention/model/toy_llama.py`. Compare `ToyLlama` and `ToyNSALlama` to see how to adapt a self-attention model into an NSA-based model.
106 |
107 | The primary difference lies in replacing the `SelfAttention` module with the `NativeSparseAttention` module, along with updates to the KV cache and inference function. These changes are straightforward and easy to implement.
108 |
--------------------------------------------------------------------------------
/native_sparse_attention/model/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from native_sparse_attention.model.toy_llama import (
15 | ToyLlamaConfig,
16 | InferenceConfig,
17 | ToyLlama,
18 | )
19 | from native_sparse_attention.model.toy_nsa_llama import (
20 | ToyNSALlamaConfig,
21 | InferenceConfig,
22 | ToyNSALlama,
23 | )
24 |
25 | __all__ = [
26 | "ToyLlamaConfig",
27 | "ToyNSALlamaConfig",
28 | "InferenceConfig",
29 | "ToyLlama",
30 | "ToyNSALlama",
31 | ]
32 |
--------------------------------------------------------------------------------
/native_sparse_attention/model/toy_llama.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from typing import Optional
15 | import torch
16 | import torch.nn as nn
17 | from dataclasses import dataclass, field
18 | from native_sparse_attention.module import SelfAttention, RopeConfig, KVCache
19 |
20 |
21 | @dataclass
22 | class ToyLlamaConfig:
23 | # embedding config
24 | vocab_size: int = 128288
25 | max_position_embeddings: int = 131072
26 | # model config
27 | hidden_size: int = 4096
28 | intermediate_size: int = 14336
29 | num_hidden_layers: int = 32
30 | num_attention_heads: int = 32
31 | num_key_value_heads: int = 2
32 | head_dim: int = 128
33 | # rope config
34 | rope_theta: float = 500000.0
35 | rope_scaling: dict = field(
36 | default_factory=lambda: {
37 | "factor": 8.0,
38 | "high_freq_factor": 4.0,
39 | "low_freq_factor": 1.0,
40 | "original_max_position_embeddings": 8192,
41 | "rope_type": "llama3",
42 | }
43 | )
44 |
45 |
46 | @dataclass
47 | class InferenceConfig:
48 | max_batch_size: int = 32
49 | max_length: int = 8192
50 | max_new_tokens: int = 128
51 |
52 |
53 | class RMSNorm(nn.Module):
54 | def __init__(self, hidden_size: int, eps: float = 1e-6):
55 | super().__init__()
56 | self.weight = nn.Parameter(torch.ones(hidden_size))
57 | self.variance_epsilon = eps
58 |
59 | def forward(self, hidden_states: torch.Tensor):
60 | input_dtype = hidden_states.dtype
61 | hidden_states = hidden_states.to(torch.float32)
62 | variance = hidden_states.pow(2).mean(-1, keepdim=True)
63 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
64 | return self.weight * hidden_states.to(input_dtype)
65 |
66 |
67 | class FFN(nn.Module):
68 | def __init__(self, hidden_size: int, intermediate_size: int):
69 | super().__init__()
70 | self.hidden_size = hidden_size
71 | self.intermediate_size = intermediate_size
72 | self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
73 | self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
74 | self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
75 | self.act_fn = nn.SiLU()
76 |
77 | def forward(self, x):
78 | down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
79 | return down_proj
80 |
81 |
82 | class ToyLlamaLayer(nn.Module):
83 | def __init__(
84 | self,
85 | hidden_size: int,
86 | intermediate_size: int,
87 | num_q_heads: int,
88 | num_kv_heads: int,
89 | head_dim: int,
90 | rope_config: RopeConfig,
91 | ):
92 | super().__init__()
93 | self.hidden_size = hidden_size
94 | self.intermediate_size = intermediate_size
95 | self.num_q_heads = num_q_heads
96 | self.num_kv_heads = num_kv_heads
97 | self.head_dim = head_dim
98 | self.rope_config = rope_config
99 | self.attn_norm = RMSNorm(self.hidden_size)
100 | self.self_attn = SelfAttention(
101 | hidden_size=self.hidden_size,
102 | num_q_heads=self.num_q_heads,
103 | num_kv_heads=self.num_kv_heads,
104 | head_dim=self.head_dim,
105 | rope_config=rope_config,
106 | )
107 | self.ffn_norm = RMSNorm(self.hidden_size)
108 | self.ffn = FFN(
109 | hidden_size=self.hidden_size, intermediate_size=self.intermediate_size
110 | )
111 |
112 | def forward(self, x, cu_seqlens):
113 | x = x + self.self_attn(self.attn_norm(x), cu_seqlens)
114 | x = x + self.ffn(self.ffn_norm(x))
115 | return x
116 |
117 | @torch.no_grad()
118 | def inference(self, x, cu_seqlens, step, kv_cache):
119 | x = x + self.self_attn.inference(self.attn_norm(x), cu_seqlens, step, kv_cache)
120 | x = x + self.ffn(self.ffn_norm(x))
121 | return x
122 |
123 |
124 | class ToyLlama(nn.Module):
125 | def __init__(
126 | self, config: ToyLlamaConfig, inference_config: Optional[InferenceConfig] = None
127 | ):
128 | super().__init__()
129 | self.config = config
130 | self.embedding = nn.Embedding(self.config.vocab_size, self.config.hidden_size)
131 | self.rope_config = RopeConfig(
132 | head_dim=self.config.head_dim,
133 | rope_theta=self.config.rope_theta,
134 | rope_scaling=self.config.rope_scaling,
135 | )
136 | self.layers = nn.ModuleList(
137 | [
138 | ToyLlamaLayer(
139 | hidden_size=self.config.hidden_size,
140 | intermediate_size=self.config.intermediate_size,
141 | num_q_heads=self.config.num_attention_heads,
142 | num_kv_heads=self.config.num_key_value_heads,
143 | head_dim=self.config.head_dim,
144 | rope_config=RopeConfig(
145 | self.config.max_position_embeddings,
146 | self.config.head_dim,
147 | self.config.rope_theta,
148 | self.config.rope_scaling,
149 | ),
150 | )
151 | for _ in range(self.config.num_hidden_layers)
152 | ]
153 | )
154 | self.norm = RMSNorm(self.config.hidden_size)
155 | self.lm_head = nn.Linear(
156 | self.config.hidden_size, self.config.vocab_size, bias=False
157 | )
158 |
159 | # inference config and kv cache
160 | self.inference_config = inference_config
161 | self.kv_cache = None
162 |
163 | def forward(
164 | self,
165 | input_ids: torch.LongTensor, # shape: [total_length, ]
166 | cu_seqlens: torch.LongTensor, # shape: [batch_size + 1, ]
167 | ):
168 | # embedding
169 | x = self.embedding(input_ids).to(torch.bfloat16)
170 | # layers
171 | for layer in self.layers:
172 | x = layer(x, cu_seqlens)
173 | # final norm
174 | x = self.norm(x)
175 | # lanugauge head
176 | x = self.lm_head(x).to(torch.float32) # [total_len, vocab_size]
177 | return x
178 |
179 | @torch.no_grad()
180 | def inference(
181 | self,
182 | input_ids: torch.LongTensor, # prefill shape: [total_length, ]; decode shape: [batch_size, ]
183 | cu_seqlens: torch.LongTensor, # shape: [batch_size + 1, ]
184 | step: int,
185 | ):
186 | # set kv cache if self.kv_cache is None
187 | if self.kv_cache is None:
188 | self.kv_cache = [
189 | KVCache(
190 | max_batch_size=self.inference_config.max_batch_size,
191 | max_length=self.inference_config.max_length,
192 | num_kv_heads=self.config.num_key_value_heads,
193 | head_dim=self.config.head_dim,
194 | dtype=torch.bfloat16,
195 | device="cuda",
196 | )
197 | for _ in range(self.config.num_hidden_layers)
198 | ]
199 | # embedding
200 | x = self.embedding(input_ids).to(torch.bfloat16)
201 | # layers
202 | for i, layer in enumerate(self.layers):
203 | x = layer.inference(x, cu_seqlens, step, self.kv_cache[i])
204 | # final norm
205 | x = self.norm(x)
206 |         # language head
207 | if step == 0:
208 | x = x[cu_seqlens[1:] - 1, :]
209 | x = self.lm_head(x).to(torch.float32) # [total_len, vocab_size]
210 | return x
211 |
212 | def generate(
213 | self,
214 | input_ids: torch.LongTensor,
215 | cu_seqlens: torch.LongTensor,
216 | max_new_tokens: int = -1,
217 | ):
218 | output_tokens = []
219 | if max_new_tokens <= 0:
220 | max_new_tokens = self.inference_config.max_new_tokens
221 | for step in range(max_new_tokens):
222 | logits = self.inference(
223 | input_ids, cu_seqlens, step
224 | ) # shape: [batch_size, vocab_size]
225 | next_token = torch.argmax(logits, dim=-1) # shape: [batch_size, ]
226 | input_ids = next_token
227 | output_tokens.append(next_token)
228 | output_tokens = torch.stack(
229 | output_tokens, dim=1
230 | ) # shape: [batch_size, max_new_tokens]
231 | return output_tokens
232 |
233 |
234 | if __name__ == "__main__":
235 | torch.manual_seed(42)
236 | # initialize model
237 | config = ToyLlamaConfig(
238 | hidden_size=4096,
239 | intermediate_size=14336,
240 | num_hidden_layers=8,
241 | num_attention_heads=32,
242 | num_key_value_heads=2,
243 | head_dim=128,
244 | rope_theta=500000.0,
245 | rope_scaling={
246 | "factor": 8.0,
247 | "high_freq_factor": 4.0,
248 | "low_freq_factor": 1.0,
249 | "original_max_position_embeddings": 8192,
250 | "rope_type": "llama3",
251 | },
252 | )
253 | inference_config = InferenceConfig(
254 | max_batch_size=4,
255 | max_length=8192,
256 | max_new_tokens=128,
257 | )
258 | model = ToyLlama(config, inference_config).cuda().bfloat16()
259 | print(f"\nMODEL CONFIG:\n{config}\n")
260 | print(f"\nINFERENCE CONFIG:\n{inference_config}\n")
261 | print(f"\nMODEL:\n{model}\n")
262 |
263 | # example input
264 | batch_size = 4
265 | seqlens = torch.randint(0, 4096, (batch_size,), dtype=torch.int32, device="cuda")
266 | cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda")
267 | cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
268 | input_ids = torch.randint(
269 | 0, 128288, (cu_seqlens[-1],), dtype=torch.int64, device="cuda"
270 | )
271 | print(f"\nEXAMPLE INPUT:\ncu_seqlens: {cu_seqlens}\ninput_ids: {input_ids.shape}\n")
272 |
273 | # example output
274 | logits = model(input_ids, cu_seqlens)
275 | print(f"\nEXAMPLE OUTPUT:\nlogits: {logits.shape}\n")
276 |
277 | # example generate
278 | output_tokens = model.generate(input_ids, cu_seqlens, 64)
279 | print(f"\nEXAMPLE GENERATE:\noutput_tokens: {output_tokens}\n")
280 |
--------------------------------------------------------------------------------
/native_sparse_attention/model/toy_nsa_llama.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from typing import Optional
15 | import torch
16 | import torch.nn as nn
17 | from dataclasses import dataclass, field
18 | from native_sparse_attention.module import NativeSparseAttention, RopeConfig, NSACache
19 |
20 |
21 | @dataclass
22 | class ToyNSALlamaConfig:
23 | # embedding config
24 | vocab_size: int = 128288
25 | max_position_embeddings: int = 131072
26 | # model config
27 | hidden_size: int = 4096
28 | intermediate_size: int = 14336
29 | num_hidden_layers: int = 32
30 | num_attention_heads: int = 32
31 | num_key_value_heads: int = 2
32 | head_dim: int = 128
33 | # rope config
34 | rope_theta: float = 500000.0
35 | rope_scaling: dict = field(
36 | default_factory=lambda: {
37 | "factor": 8.0,
38 | "high_freq_factor": 4.0,
39 | "low_freq_factor": 1.0,
40 | "original_max_position_embeddings": 8192,
41 | "rope_type": "llama3",
42 | }
43 | )
44 | # nsa config
45 | compress_type: str = "weightedpool"
46 | kernel_size: int = 32
47 | kernel_stride: int = 16
48 | block_size: int = 64
49 | topk: int = 16
50 | init_blocks: int = 1
51 | local_blocks: int = 2
52 | window_size: int = 512
53 |
54 |
55 | @dataclass
56 | class InferenceConfig:
57 | max_batch_size: int = 32
58 | max_length: int = 8192
59 | max_new_tokens: int = 128
60 |
61 |
62 | class RMSNorm(nn.Module):
63 | def __init__(self, hidden_size: int, eps: float = 1e-6):
64 | super().__init__()
65 | self.weight = nn.Parameter(torch.ones(hidden_size))
66 | self.variance_epsilon = eps
67 |
68 | def forward(self, hidden_states: torch.Tensor):
69 | input_dtype = hidden_states.dtype
70 | hidden_states = hidden_states.to(torch.float32)
71 | variance = hidden_states.pow(2).mean(-1, keepdim=True)
72 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
73 | return self.weight * hidden_states.to(input_dtype)
74 |
75 |
76 | class FFN(nn.Module):
77 | def __init__(self, hidden_size: int, intermediate_size: int):
78 | super().__init__()
79 | self.hidden_size = hidden_size
80 | self.intermediate_size = intermediate_size
81 | self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
82 | self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
83 | self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
84 | self.act_fn = nn.SiLU()
85 |
86 | def forward(self, x):
87 | down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
88 | return down_proj
89 |
90 |
91 | class ToyNSALlamaLayer(nn.Module):
92 | def __init__(
93 | self,
94 | hidden_size: int,
95 | intermediate_size: int,
96 | num_q_heads: int,
97 | num_kv_heads: int,
98 | head_dim: int,
99 | compress_type: str,
100 | kernel_size: int,
101 | kernel_stride: int,
102 | block_size: int,
103 | topk: int,
104 | init_blocks: int,
105 | local_blocks: int,
106 | window_size: int,
107 | rope_config: RopeConfig,
108 | ):
109 | super().__init__()
110 | self.hidden_size = hidden_size
111 | self.intermediate_size = intermediate_size
112 | self.num_q_heads = num_q_heads
113 | self.num_kv_heads = num_kv_heads
114 | self.head_dim = head_dim
115 | self.compress_type = compress_type
116 | self.kernel_size = kernel_size
117 | self.kernel_stride = kernel_stride
118 | self.block_size = block_size
119 | self.topk = topk
120 | self.init_blocks = init_blocks
121 | self.local_blocks = local_blocks
122 | self.window_size = window_size
123 | self.rope_config = rope_config
124 | self.attn_norm = RMSNorm(self.hidden_size)
125 | self.nsa = NativeSparseAttention(
126 | hidden_size=self.hidden_size,
127 | num_q_heads=self.num_q_heads,
128 | num_kv_heads=self.num_kv_heads,
129 | head_dim=self.head_dim,
130 | compress_type=self.compress_type,
131 | kernel_size=self.kernel_size,
132 | kernel_stride=self.kernel_stride,
133 | block_size=self.block_size,
134 | topk=self.topk,
135 | init_blocks=self.init_blocks,
136 | local_blocks=self.local_blocks,
137 | window_size=self.window_size,
138 | rope_config=rope_config,
139 | )
140 | self.ffn_norm = RMSNorm(self.hidden_size)
141 | self.ffn = FFN(
142 | hidden_size=self.hidden_size, intermediate_size=self.intermediate_size
143 | )
144 |
145 | def forward(self, x, cu_seqlens):
146 | x = x + self.nsa(self.attn_norm(x), cu_seqlens)
147 | x = x + self.ffn(self.ffn_norm(x))
148 | return x
149 |
150 | @torch.no_grad()
151 | def inference(self, x, cu_seqlens, step, kv_cache):
152 | x = x + self.nsa.inference(self.attn_norm(x), cu_seqlens, step, kv_cache)
153 | x = x + self.ffn(self.ffn_norm(x))
154 | return x
155 |
156 |
157 | class ToyNSALlama(nn.Module):
158 | def __init__(
159 | self,
160 | config: ToyNSALlamaConfig,
161 | inference_config: Optional[InferenceConfig] = None,
162 | ):
163 | super().__init__()
164 | self.config = config
165 | self.embedding = nn.Embedding(self.config.vocab_size, self.config.hidden_size)
166 | self.rope_config = RopeConfig(
167 | head_dim=self.config.head_dim,
168 | rope_theta=self.config.rope_theta,
169 | rope_scaling=self.config.rope_scaling,
170 | )
171 | self.layers = nn.ModuleList(
172 | [
173 | ToyNSALlamaLayer(
174 | hidden_size=self.config.hidden_size,
175 | intermediate_size=self.config.intermediate_size,
176 | num_q_heads=self.config.num_attention_heads,
177 | num_kv_heads=self.config.num_key_value_heads,
178 | head_dim=self.config.head_dim,
179 | compress_type=self.config.compress_type,
180 | kernel_size=self.config.kernel_size,
181 | kernel_stride=self.config.kernel_stride,
182 | block_size=self.config.block_size,
183 | topk=self.config.topk,
184 | init_blocks=self.config.init_blocks,
185 | local_blocks=self.config.local_blocks,
186 | window_size=self.config.window_size,
187 | rope_config=RopeConfig(
188 | self.config.max_position_embeddings,
189 | self.config.head_dim,
190 | self.config.rope_theta,
191 | self.config.rope_scaling,
192 | ),
193 | )
194 | for _ in range(self.config.num_hidden_layers)
195 | ]
196 | )
197 | self.norm = RMSNorm(self.config.hidden_size)
198 | self.lm_head = nn.Linear(
199 | self.config.hidden_size, self.config.vocab_size, bias=False
200 | )
201 |
202 | # inference config and kv cache
203 | self.inference_config = inference_config
204 | self.kv_cache = None
205 |
206 | def forward(
207 | self,
208 | input_ids: torch.LongTensor, # shape: [total_length, ]
209 | cu_seqlens: torch.LongTensor, # shape: [batch_size + 1, ]
210 | ):
211 | # embedding
212 | x = self.embedding(input_ids).to(torch.bfloat16)
213 | # layers
214 | for layer in self.layers:
215 | x = layer(x, cu_seqlens)
216 | # final norm
217 | x = self.norm(x)
218 | # language head
219 | x = self.lm_head(x).to(torch.float32) # [total_len, vocab_size]
220 | return x
221 |
222 | @torch.no_grad()
223 | def inference(
224 | self,
225 | input_ids: torch.LongTensor, # prefill shape: [total_length, ]; decode shape: [batch_size, ]
226 | cu_seqlens: torch.LongTensor, # shape: [batch_size + 1, ]
227 | step: int,
228 | ):
229 | # set kv cache if self.kv_cache is None
230 | if self.kv_cache is None:
231 | self.kv_cache = [
232 | NSACache(
233 | max_batch_size=self.inference_config.max_batch_size,
234 | max_length=self.inference_config.max_length,
235 | num_kv_heads=self.config.num_key_value_heads,
236 | head_dim=self.config.head_dim,
237 | kernel_size=self.config.kernel_size,
238 | kernel_stride=self.config.kernel_stride,
239 | window_size=self.config.window_size,
240 | dtype=torch.bfloat16,
241 | device="cuda",
242 | )
243 | for _ in range(self.config.num_hidden_layers)
244 | ]
245 | # embedding
246 | x = self.embedding(input_ids).to(torch.bfloat16)
247 | # layers
248 | for i, layer in enumerate(self.layers):
249 | x = layer.inference(x, cu_seqlens, step, self.kv_cache[i])
250 | # final norm
251 | x = self.norm(x)
252 | # language head
253 | if step == 0:
254 | x = x[cu_seqlens[1:] - 1, :]
255 | x = self.lm_head(x).to(torch.float32) # [batch_size, vocab_size]
256 | return x
257 |
258 | def generate(
259 | self,
260 | input_ids: torch.LongTensor,
261 | cu_seqlens: torch.LongTensor,
262 | max_new_tokens: int = -1,
263 | ):
264 | output_tokens = []
265 | if max_new_tokens <= 0:
266 | max_new_tokens = self.inference_config.max_new_tokens
267 | for step in range(max_new_tokens):
268 | logits = self.inference(
269 | input_ids, cu_seqlens, step
270 | ) # shape: [batch_size, vocab_size]
271 | next_token = torch.argmax(logits, dim=-1) # shape: [batch_size, ]
272 | input_ids = next_token
273 | output_tokens.append(next_token)
274 | output_tokens = torch.stack(
275 | output_tokens, dim=1
276 | ) # shape: [batch_size, max_new_tokens]
277 | return output_tokens
278 |
279 |
280 | if __name__ == "__main__":
281 | torch.manual_seed(42)
282 | # initialize model
283 | config = ToyNSALlamaConfig(
284 | hidden_size=4096,
285 | intermediate_size=14336,
286 | num_hidden_layers=8,
287 | num_attention_heads=32,
288 | num_key_value_heads=2,
289 | head_dim=128,
290 | rope_theta=500000.0,
291 | rope_scaling={
292 | "factor": 8.0,
293 | "high_freq_factor": 4.0,
294 | "low_freq_factor": 1.0,
295 | "original_max_position_embeddings": 8192,
296 | "rope_type": "llama3",
297 | },
298 | compress_type="weightedpool",
299 | kernel_size=32,
300 | kernel_stride=16,
301 | block_size=64,
302 | topk=8,
303 | init_blocks=1,
304 | local_blocks=2,
305 | window_size=512,
306 | )
307 | inference_config = InferenceConfig(
308 | max_batch_size=4,
309 | max_length=8192,
310 | max_new_tokens=128,
311 | )
312 | model = ToyNSALlama(config, inference_config).cuda().bfloat16()
313 | print(f"\nMODEL CONFIG:\n{config}\n")
314 | print(f"\nINFERENCE CONFIG:\n{inference_config}\n")
315 | print(f"\nMODEL:\n{model}\n")
316 |
317 | # example input
318 | batch_size = 4
319 | seqlens = torch.randint(0, 4096, (batch_size,), dtype=torch.int32, device="cuda")
320 | cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda")
321 | cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
322 | input_ids = torch.randint(
323 | 0, 128288, (cu_seqlens[-1],), dtype=torch.int64, device="cuda"
324 | )
325 | print(f"\nEXAMPLE INPUT:\ncu_seqlens: {cu_seqlens}\ninput_ids: {input_ids.shape}\n")
326 |
327 | # example output
328 | logits = model(input_ids, cu_seqlens)
329 | print(f"\nEXAMPLE OUTPUT:\nlogits: {logits.shape}\n")
330 |
331 | # example generate
332 | output_tokens = model.generate(input_ids, cu_seqlens, 64)
333 | print(f"\nEXAMPLE GENERATE:\noutput_tokens: {output_tokens}\n")
334 |
--------------------------------------------------------------------------------
/native_sparse_attention/module/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from native_sparse_attention.module.native_sparse_attention import NativeSparseAttention
15 | from native_sparse_attention.module.self_attention import SelfAttention
16 | from native_sparse_attention.module.rope import RotaryEmbedding, RopeConfig
17 | from native_sparse_attention.module.kv_cache import NSACache, KVCache
18 |
19 | __all__ = [
20 | "SelfAttention",
21 | "NativeSparseAttention",
22 | "RotaryEmbedding",
23 | "RopeConfig",
24 | "NSACache",
25 | "KVCache",
26 | ]
27 |
--------------------------------------------------------------------------------
/native_sparse_attention/module/native_sparse_attention.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import torch
15 | from flash_attn import flash_attn_varlen_func
16 | from native_sparse_attention.ops import (
17 | compressed_attention,
18 | topk_sparse_attention,
19 | avgpool_compress,
20 | weightedpool_compress,
21 | linear_compress,
22 | )
23 | from einops import rearrange
24 | from native_sparse_attention.module.rope import RopeConfig, RotaryEmbedding
25 | from native_sparse_attention.infer import nsa_infer
26 | from native_sparse_attention.module.kv_cache import NSACache
27 |
28 | COMPRESS_TYPE_TO_FUNC = {
29 | "avgpool": avgpool_compress,
30 | "weightedpool": weightedpool_compress,
31 | "linear": linear_compress,
32 | }
33 |
34 | COMPRESS_TYPE_TO_WEIGHT = {
35 | "avgpool": lambda num_heads, head_dim, kernel_size: None,
36 | "weightedpool": lambda num_heads, head_dim, kernel_size: torch.nn.Parameter(
37 | torch.zeros(num_heads, kernel_size)
38 | ),
39 | "linear": lambda num_heads, head_dim, kernel_size: torch.nn.Parameter(
40 | torch.zeros(num_heads, head_dim * kernel_size, head_dim)
41 | ),
42 | }
43 |
44 |
45 | class NativeSparseAttention(torch.nn.Module):
46 | """Native sparse attention module, supports training and inference
47 |
48 | Args:
49 | compress_type (str): key value compression type, currently supports ['linear', 'avgpool', 'weightedpool']
50 | hidden_size (int): hidden dimension
51 | num_q_heads (int): number of query heads
52 | num_kv_heads (int): number of key/value heads; num_q_heads must be divisible by num_kv_heads
53 | head_dim (int): head dim
54 | kernel_size (int): kernel size of compression
55 | kernel_stride (int): kernel stride for compression
56 | block_size (int): block size of sparse attention
57 | topk (int): topk of sparse attention
58 | init_blocks (int): number of blocks at the beginning of the sequence; these blocks are forced to be computed in sparse attention
59 | local_blocks (int): number of blocks in the local window of each query; these blocks are forced to be computed in sparse attention
60 | window_size (int): window size for sliding window attention
61 | rope_config (RopeConfig): config for rotary embedding, see native_sparse_attention.module.rope.RopeConfig for details
62 | rope_device (str): device used to store rope freqs
63 | """
64 |
65 | def __init__(
66 | self,
67 | compress_type: str,
68 | hidden_size: int,
69 | num_q_heads: int,
70 | num_kv_heads: int,
71 | head_dim: int,
72 | kernel_size: int,
73 | kernel_stride: int,
74 | block_size: int,
75 | topk: int,
76 | init_blocks: int,
77 | local_blocks: int,
78 | window_size: int,
79 | rope_config: RopeConfig,
80 | rope_device: str = "cuda",
81 | ):
82 | super().__init__()
83 | # configs
84 | self.compress_type = compress_type
85 | self.hidden_size = hidden_size
86 | self.num_q_heads = num_q_heads
87 | self.num_kv_heads = num_kv_heads
88 | self.head_dim = head_dim
89 | self.kernel_size = kernel_size
90 | self.kernel_stride = kernel_stride
91 | self.block_size = block_size
92 | self.topk = topk
93 | self.init_blocks = init_blocks
94 | self.local_blocks = local_blocks
95 | self.window_size = window_size
96 | self.rope_config = rope_config
97 | assert self.head_dim == self.rope_config.head_dim
98 |
99 | # qkv proj and o proj
100 | self.proj_q = torch.nn.Linear(
101 | self.hidden_size, self.num_q_heads * self.head_dim, bias=False
102 | )
103 | self.proj_k = torch.nn.Linear(
104 | self.hidden_size, self.num_kv_heads * self.head_dim, bias=False
105 | )
106 | self.proj_v = torch.nn.Linear(
107 | self.hidden_size, self.num_kv_heads * self.head_dim, bias=False
108 | )
109 | self.proj_o = torch.nn.Linear(
110 | self.num_q_heads * self.head_dim, self.hidden_size, bias=False
111 | )
112 |
113 | # nsa compress func
114 | self.compress_func = COMPRESS_TYPE_TO_FUNC[self.compress_type]
115 |
116 | # nsa parameters
117 | self.compress_key = COMPRESS_TYPE_TO_WEIGHT[self.compress_type](
118 | num_kv_heads, head_dim, kernel_size
119 | )
120 |
121 | self.compress_value = COMPRESS_TYPE_TO_WEIGHT[self.compress_type](
122 | num_kv_heads, head_dim, kernel_size
123 | )
124 | self.intra_block_pe = torch.nn.Parameter(
125 | torch.zeros(self.num_kv_heads, self.kernel_size, self.head_dim)
126 | )
127 |
128 | # gate function
129 | self.gate = torch.nn.Sequential(
130 | torch.nn.Linear(self.hidden_size, self.num_q_heads * 3, bias=False),
131 | torch.nn.Sigmoid(),
132 | )
133 |
134 | # rope
135 | self.rope = RotaryEmbedding(self.rope_config, device=rope_device)
136 |
137 | # init parameters
138 | self.init_params()
139 |
140 | def init_params(self):
141 | for p in self.parameters():
142 | if len(p.shape) > 1:
143 | torch.nn.init.xavier_uniform_(p)
144 |
145 | def forward(
146 | self,
147 | x: torch.Tensor, # shape: [total_len, hidden_size]
148 | cu_seqlens: torch.Tensor, # shape: [batch_size + 1]
149 | ):
150 | # dtype and shape check
151 | assert x.dtype == torch.bfloat16 or x.dtype == torch.float16
152 | assert x.shape[-1] == self.hidden_size
153 | cu_seqlens = cu_seqlens.to(torch.int32)
154 | seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
155 |
156 | # qkv proj
157 | q = self.proj_q(x).view(-1, self.num_q_heads, self.head_dim)
158 | k = self.proj_k(x).view(-1, self.num_kv_heads, self.head_dim)
159 | v = self.proj_v(x).view(-1, self.num_kv_heads, self.head_dim)
160 |
161 | # compress key and value before rope
162 | compressed_k, compressed_cu_seqlens = self.compress_func(
163 | k,
164 | self.compress_key,
165 | cu_seqlens,
166 | self.kernel_size,
167 | self.kernel_stride,
168 | self.intra_block_pe,
169 | )
170 | compressed_v, _ = self.compress_func(
171 | v,
172 | self.compress_value,
173 | cu_seqlens,
174 | self.kernel_size,
175 | self.kernel_stride,
176 | None,
177 | )
178 |
179 | # do rope for query and compressed key
180 | q = self.rope(q, cu_seqlens)
181 | compressed_k = self.rope(
182 | compressed_k, compressed_cu_seqlens, stride=self.kernel_stride
183 | )
184 |
185 | # attention between query and compressed key value
186 | compressed_seqlens = compressed_cu_seqlens[1:] - compressed_cu_seqlens[:-1]
187 | compressed_attn_output, topk_idx = compressed_attention(
188 | q,
189 | compressed_k,
190 | compressed_v,
191 | self.kernel_size,
192 | self.kernel_stride,
193 | self.block_size,
194 | self.topk,
195 | cu_seqlens,
196 | compressed_cu_seqlens,
197 | seqlens.max().item(),
198 | compressed_seqlens.max().item(),
199 | None,
200 | self.init_blocks,
201 | self.local_blocks,
202 | )
203 |
204 | # do rope for original key
205 | k = self.rope(k, cu_seqlens)
206 |
207 | # topk sparse attention
208 | sparse_attn_output = topk_sparse_attention(
209 | q, k, v, topk_idx, self.block_size, cu_seqlens, None
210 | )
211 |
212 | # sliding window attention
213 | sliding_attn_output = flash_attn_varlen_func(
214 | q,
215 | k,
216 | v,
217 | cu_seqlens,
218 | cu_seqlens,
219 | seqlens.max().item(),
220 | seqlens.max().item(),
221 | causal=True,
222 | window_size=(self.window_size, -1),
223 | )
224 |
225 | # gate average
226 | gate = self.gate(x)
227 | gate = rearrange(gate, "n (h g) -> n h g", g=3)
228 | attn_output = (
229 | gate[..., 0:1] * compressed_attn_output
230 | + gate[..., 1:2] * sparse_attn_output
231 | + gate[..., 2:3] * sliding_attn_output
232 | )
233 |
234 | # rearrange and output proj
235 | attn_output = rearrange(attn_output, "n h d -> n (h d)")
236 | attn_output = self.proj_o(attn_output)
237 |
238 | return attn_output
239 |
240 | @torch.no_grad()
241 | def inference(
242 | self,
243 | x: torch.Tensor, # shape: [total_len, hidden_size]
244 | cu_seqlens: torch.Tensor, # shape: [batch_size + 1]
245 | step: int,
246 | cache: NSACache,
247 | ):
248 | # dtype and shape check
249 | assert x.dtype == torch.bfloat16 or x.dtype == torch.float16
250 | assert x.shape[-1] == self.hidden_size
251 | cu_seqlens = cu_seqlens.to(torch.int32)
252 | assert step >= 0
253 | if step == 0:
254 | assert x.shape[0] == cu_seqlens[-1]
255 | else:
256 | assert x.shape[0] == cu_seqlens.shape[0] - 1
257 | # qkv proj
258 | q = self.proj_q(x).view(-1, self.num_q_heads, self.head_dim)
259 | k = self.proj_k(x).view(-1, self.num_kv_heads, self.head_dim)
260 | v = self.proj_v(x).view(-1, self.num_kv_heads, self.head_dim)
261 | # gate proj
262 | gate = self.gate(x)
263 | gate = rearrange(gate, "n (h g) -> n h g", g=3)
264 | # nsa infer
265 | output = nsa_infer(
266 | cu_seqlens,
267 | step,
268 | q,
269 | k,
270 | v,
271 | gate,
272 | self.rope,
273 | cache,
274 | [self.compress_key, self.compress_value],
275 | [self.compress_func, self.compress_func],
276 | self.intra_block_pe,
277 | self.kernel_size,
278 | self.kernel_stride,
279 | self.block_size,
280 | self.topk,
281 | self.init_blocks,
282 | self.local_blocks,
283 | self.window_size,
284 | )
285 | # output proj
286 | output = rearrange(output, "n h d -> n (h d)")
287 | output = self.proj_o(output)
288 | return output
289 |
--------------------------------------------------------------------------------
/native_sparse_attention/module/rope.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu.
2 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3 | #
4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5 | # and OPT implementations in this library. It has been modified from its
6 | # original forms to accommodate minor architectural differences compared
7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8 | #
9 | # Licensed under the Apache License, Version 2.0 (the "License");
10 | # you may not use this file except in compliance with the License.
11 | # You may obtain a copy of the License at
12 | #
13 | # http://www.apache.org/licenses/LICENSE-2.0
14 | #
15 | # Unless required by applicable law or agreed to in writing, software
16 | # distributed under the License is distributed on an "AS IS" BASIS,
17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18 | # See the License for the specific language governing permissions and
19 | # limitations under the License.
20 | import torch
21 | from dataclasses import dataclass, field
22 | from torch import nn
23 | from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
24 |
25 |
26 | # default to llama3.1 rope config
27 | @dataclass
28 | class RopeConfig:
29 | """Config for RotaryEmbedding, similar to transformers llama."""
30 |
31 | max_position_embeddings: int = 131072
32 | head_dim: int = 128
33 | rope_theta: float = 500000
34 | rope_scaling: dict = field(
35 | default_factory=lambda: {
36 | "factor": 8.0,
37 | "high_freq_factor": 4.0,
38 | "low_freq_factor": 1.0,
39 | "original_max_position_embeddings": 8192,
40 | "rope_type": "llama3",
41 | }
42 | )
43 | # unused, kept only for compatibility; please use head_dim instead
44 | hidden_size: int = 1
45 | num_attention_heads: int = 1
46 |
47 | def __post_init__(self):
48 | self.num_attention_heads = 1
49 | self.hidden_size = self.head_dim
50 |
51 |
52 | # Copied from transformers.models.llama.modeling_llama.rotate_half
53 | def rotate_half(x):
54 | """Rotates half the hidden dims of the input."""
55 | x1 = x[..., : x.shape[-1] // 2]
56 | x2 = x[..., x.shape[-1] // 2 :]
57 | return torch.cat((-x2, x1), dim=-1)
58 |
59 |
60 | # copied and modified from huggingface transformers
61 | # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
62 | class RotaryEmbedding(nn.Module):
63 | """Rotary embedding
64 |
65 | Args:
66 | config (RopeConfig): config for rotary embedding, see native_sparse_attention.module.rope.RopeConfig for details
67 | device (str): device used to store rope freqs, defaults to the current cuda device
68 | """
69 |
70 | cos = None
71 | sin = None
72 |
73 | def __init__(
74 | self, config: RopeConfig, device=torch.device(torch.cuda.current_device())
75 | ):
76 | super().__init__()
77 | # BC: "rope_type" was originally "type"
78 | if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
79 | self.rope_type = config.rope_scaling.get(
80 | "rope_type", config.rope_scaling.get("type")
81 | )
82 | else:
83 | self.rope_type = "default"
84 | self.max_seq_len_cached = config.max_position_embeddings
85 | self.original_max_seq_len = config.max_position_embeddings
86 |
87 | self.config = config
88 | self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
89 |
90 | inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
91 | self.register_buffer("inv_freq", inv_freq, persistent=False)
92 | self.original_inv_freq = self.inv_freq
93 |
94 | def _dynamic_frequency_update(self, position_ids, device):
95 | """
96 | dynamic RoPE layers should recompute `inv_freq` in the following situations:
97 | 1 - growing beyond the cached sequence length (allow scaling)
98 | 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
99 | """
100 | seq_len = torch.max(position_ids) + 1
101 | if seq_len > self.max_seq_len_cached: # growth
102 | inv_freq, self.attention_scaling = self.rope_init_fn(
103 | self.config, device, seq_len=seq_len
104 | )
105 | self.register_buffer(
106 | "inv_freq", inv_freq, persistent=False
107 | ) # TODO joao: may break with compilation
108 | self.max_seq_len_cached = seq_len
109 |
110 | if (
111 | seq_len < self.original_max_seq_len
112 | and self.max_seq_len_cached > self.original_max_seq_len
113 | ): # reset
114 | # This .to() is needed if the model has been moved to a device after being initialized (because
115 | # the buffer is automatically moved, but not the original copy)
116 | self.original_inv_freq = self.original_inv_freq.to(device)
117 | self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
118 | self.max_seq_len_cached = self.original_max_seq_len
119 |
120 | @torch.no_grad()
121 | def generate_cos_sin(self, x: torch.Tensor, position_ids):
122 | if "dynamic" in self.rope_type:
123 | self._dynamic_frequency_update(position_ids, device=x.device)
124 |
125 | # Core RoPE block
126 | inv_freq_expanded = (
127 | self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
128 | )
129 | position_ids_expanded = position_ids[:, None, :].float()
130 | # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
131 | device_type = x.device.type
132 | device_type = (
133 | device_type
134 | if isinstance(device_type, str) and device_type != "mps"
135 | else "cpu"
136 | )
137 | with torch.autocast(device_type=device_type, enabled=False):
138 | freqs = (
139 | inv_freq_expanded.float() @ position_ids_expanded.float()
140 | ).transpose(1, 2)
141 | # # do not use this if using flash_attn
142 | # emb = torch.cat((freqs, freqs), dim=-1)
143 | emb = freqs
144 | cos = emb.cos()
145 | sin = emb.sin()
146 |
147 | # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
148 | cos = (cos * self.attention_scaling).to(dtype=x.dtype).squeeze(0)
149 | sin = (sin * self.attention_scaling).to(dtype=x.dtype).squeeze(0)
150 |
151 | # save cos sin
152 | RotaryEmbedding.cos = torch.cat([cos, cos], dim=-1)
153 | RotaryEmbedding.sin = torch.cat([sin, sin], dim=-1)
154 |
155 | return RotaryEmbedding.cos, RotaryEmbedding.sin
156 |
157 | @torch.no_grad()
158 | def generate_pos_embs(
159 | self,
160 | x: torch.Tensor,
161 | cu_seqlens: torch.Tensor,
162 | seqlens: torch.Tensor,
163 | step: int = 0,
164 | stride: int = 1,
165 | ):
166 | if (
167 | RotaryEmbedding.cos is None
168 | or seqlens.max() + step > RotaryEmbedding.cos.shape[0]
169 | ):
170 | self.generate_cos_sin(
171 | x, torch.arange(seqlens.max() + step).to(x.device)[None, :]
172 | )
173 |
174 | cos_embs = []
175 | sin_embs = []
176 | bsz = len(cu_seqlens) - 1
177 |
178 | for i in range(bsz):
179 | if step == 0: # prefilling
180 | r = cu_seqlens[i + 1] - cu_seqlens[i]
181 | cos_emb, sin_emb = (
182 | RotaryEmbedding.cos[: r * stride : stride],
183 | RotaryEmbedding.sin[: r * stride : stride],
184 | )
185 | elif step > 0: # decoding
186 | r = cu_seqlens[i + 1] - cu_seqlens[i] + step - 1
187 | cos_emb, sin_emb = (
188 | RotaryEmbedding.cos[r * stride : r * stride + 1],
189 | RotaryEmbedding.sin[r * stride : r * stride + 1],
190 | )
191 | cos_embs.append(cos_emb)
192 | sin_embs.append(sin_emb)
193 |
194 | cos_embs = torch.cat(cos_embs, dim=0)
195 | sin_embs = torch.cat(sin_embs, dim=0)
196 | return cos_embs, sin_embs
197 |
198 | def forward(self, x, cu_seqlens, step=0, stride=1):
199 | seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
200 | cos_embs, sin_embs = self.generate_pos_embs(
201 | x,
202 | cu_seqlens,
203 | seqlens,
204 | step=step,
205 | stride=stride,
206 | )
207 | N, H, D = x.shape[0], x.shape[-2], x.shape[-1] # H: number of heads
208 | x = x * cos_embs.view(N, 1, D) + rotate_half(x) * sin_embs.view(N, 1, D)
209 | return x
210 |
--------------------------------------------------------------------------------
/native_sparse_attention/module/self_attention.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import torch
15 | from flash_attn import flash_attn_varlen_func
16 | from einops import rearrange
17 | from native_sparse_attention.module.rope import RopeConfig, RotaryEmbedding
18 | from native_sparse_attention.module.kv_cache import KVCache
19 | from native_sparse_attention.ops import flash_attention_decode
20 |
21 |
22 | class SelfAttention(torch.nn.Module):
23 | """self attention module
24 |
25 | Args:
26 | hidden_size (int): hidden dimension
27 | num_q_heads (int): number of query heads
28 | num_kv_heads (int): number of key/value heads; num_q_heads must be divisible by num_kv_heads
29 | head_dim (int): head dim
30 | rope_config (RopeConfig): config for rotary embedding, see native_sparse_attention.module.rope.RopeConfig for details
31 | """
32 |
33 | def __init__(
34 | self,
35 | hidden_size: int,
36 | num_q_heads: int,
37 | num_kv_heads: int,
38 | head_dim: int,
39 | rope_config: RopeConfig,
40 | rope_device: str = "cuda",
41 | ):
42 | super().__init__()
43 | # configs
44 | self.hidden_size = hidden_size
45 | self.num_q_heads = num_q_heads
46 | self.num_kv_heads = num_kv_heads
47 | self.head_dim = head_dim
48 | self.rope_config = rope_config
49 | assert self.head_dim == self.rope_config.head_dim
50 |
51 | # qkv proj and o proj
52 | self.proj_q = torch.nn.Linear(
53 | self.hidden_size, self.num_q_heads * self.head_dim, bias=False
54 | )
55 | self.proj_k = torch.nn.Linear(
56 | self.hidden_size, self.num_kv_heads * self.head_dim, bias=False
57 | )
58 | self.proj_v = torch.nn.Linear(
59 | self.hidden_size, self.num_kv_heads * self.head_dim, bias=False
60 | )
61 | self.proj_o = torch.nn.Linear(
62 | self.num_q_heads * self.head_dim, self.hidden_size, bias=False
63 | )
64 | # rope
65 | self.rope = RotaryEmbedding(self.rope_config, device=rope_device)
66 |
67 | # init parameters
68 | self.init_params()
69 |
70 | def init_params(self):
71 | for p in self.parameters():
72 | torch.nn.init.xavier_uniform_(p)
73 |
74 | def forward(
75 | self,
76 | x: torch.Tensor, # shape: [total_len, hidden_size]
77 | cu_seqlens: torch.Tensor, # shape: [batch_size + 1]
78 | ):
79 | # dtype and shape check
80 | assert x.dtype == torch.bfloat16 or x.dtype == torch.float16
81 | assert x.shape[-1] == self.hidden_size
82 | cu_seqlens = cu_seqlens.to(torch.int32)
83 | seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
84 |
85 | # qkv proj
86 | q = self.proj_q(x).view(-1, self.num_q_heads, self.head_dim)
87 | k = self.proj_k(x).view(-1, self.num_kv_heads, self.head_dim)
88 | v = self.proj_v(x).view(-1, self.num_kv_heads, self.head_dim)
89 |
90 | # do rope for query and key
91 | q = self.rope(q, cu_seqlens)
92 | k = self.rope(k, cu_seqlens)
93 |
94 | # self attention
95 | attn_output = flash_attn_varlen_func(
96 | q,
97 | k,
98 | v,
99 | cu_seqlens,
100 | cu_seqlens,
101 | seqlens.max().item(),
102 | seqlens.max().item(),
103 | causal=True,
104 | )
105 |
106 | # rearrange and output proj
107 | attn_output = rearrange(attn_output, "n h d -> n (h d)")
108 | attn_output = self.proj_o(attn_output)
109 |
110 | return attn_output
111 |
112 | @torch.no_grad()
113 | def inference(
114 | self,
115 | x: torch.Tensor, # shape: [total_len, hidden_size]
116 | cu_seqlens: torch.Tensor, # shape: [batch_size + 1]
117 | step: int,
118 | cache: KVCache,
119 | ):
120 | # dtype and shape check
121 | assert x.dtype == torch.bfloat16 or x.dtype == torch.float16
122 | assert x.shape[-1] == self.hidden_size
123 | cu_seqlens = cu_seqlens.to(torch.int32)
124 | seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
125 | assert step >= 0
126 | if step == 0:
127 | assert x.shape[0] == cu_seqlens[-1]
128 | else:
129 | assert x.shape[0] == cu_seqlens.shape[0] - 1
130 | batch_size = cu_seqlens.shape[0] - 1
131 | # qkv proj
132 | q = self.proj_q(x).view(-1, self.num_q_heads, self.head_dim)
133 | k = self.proj_k(x).view(-1, self.num_kv_heads, self.head_dim)
134 | v = self.proj_v(x).view(-1, self.num_kv_heads, self.head_dim)
135 | # do rope for query and key
136 | q = self.rope(q, cu_seqlens, step)
137 | k = self.rope(k, cu_seqlens, step)
138 | # reset and update kv cache
139 | if step == 0:
140 | cache.reset()
141 | cache.update_kv(cu_seqlens, step, k, v)
142 | # self attention
143 | if step == 0:
144 | cu_seqlens_q = cu_seqlens_k = cu_seqlens
145 | max_seqlen_in_batch_q = max_seqlen_in_batch_k = seqlens.max().item()
146 | output = flash_attn_varlen_func(
147 | q,
148 | k,
149 | v,
150 | cu_seqlens_q=cu_seqlens_q,
151 | cu_seqlens_k=cu_seqlens_k,
152 | max_seqlen_q=max_seqlen_in_batch_q,
153 | max_seqlen_k=max_seqlen_in_batch_k,
154 | causal=True,
155 | )
156 | else:
157 | output = flash_attention_decode(
158 | q,
159 | cache.kv_cache[0, :batch_size],
160 | cache.kv_cache[1, :batch_size],
161 | cache.kv_len[:batch_size],
162 | )
163 | # rearrange and output proj
164 | output = rearrange(output, "n h d -> n (h d)")
165 | output = self.proj_o(output)
166 | return output
167 |
--------------------------------------------------------------------------------
/native_sparse_attention/ops/README.md:
--------------------------------------------------------------------------------
1 | # Triton Functions for Native Sparse Attention
2 |
3 | This folder provides efficient Triton-based implementations of components for Native Sparse Attention. This README introduces the available functions, explains how to set them up, and offers guidance on their usage.
4 |
5 | ---
6 |
7 | ## Overview of Functions
8 |
9 | The functions are organized into two main categories:
10 |
11 | 1. **Compression Methods**: Techniques for compressing key and value tensors.
12 | 2. **Attention Mechanisms**: Methods for computing attention between queries and compressed key/value tensors, including top-k sparse attention.
13 |
14 | ---
15 |
16 | ## Function Descriptions
17 |
18 | ### Compression Methods
19 |
20 | These functions compress key and value tensors using a sliding window approach. Within each window, `kernel_size` tokens are compressed into a single token, with a stride of `kernel_stride`. All compression functions share similar input parameters and return formats.
21 |
22 | **Parameters:**
23 | - `x`: Input tensor (`total_len, num_heads, head_dim`)
24 | - `w`: Weight tensor (shape varies by compression method)
25 | - `cu_seqlens`: Cumulative sequence lengths (`batch_size + 1`)
26 | - `kernel_size`: Size of the compression window
27 | - `kernel_stride`: Stride of the compression window
28 | - `pe`: Optional positional embedding (`num_heads, kernel_size, head_dim`)
29 |
30 | **Returns:**
31 | - Compressed tensor (`total_compress_len, num_heads, head_dim`)
32 | - Cumulative sequence lengths (`com_cu_seqlens`) for the compressed tensor
33 |
34 | #### `weightedpool_compress`
35 | Compresses the input tensor using weighted pooling, applying a weighted sum over each block:
36 | $\hat{k} = w_1 k_1 + \dots + w_m k_m$
37 | - **Weight shape**: `(num_heads, kernel_size)`
38 |
39 | #### `avgpool_compress`
40 | Compresses the input tensor using average pooling:
41 | $\hat{k} = (k_1 + \dots + k_m) / m$
42 | - **Weight**: Must be `None`
43 |
44 | #### `linear_compress`
45 | Compresses the input tensor via linear projection, mapping each block to a single vector using learned weights:
46 | $\hat{k} = \text{cat}(k_1, \dots, k_m) W$
47 | - **Weight shape**: `(num_heads, kernel_size * head_dim, head_dim)`
48 |
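To make the formulas above concrete, here is a minimal PyTorch sketch of how one window of `kernel_size` tokens collapses into a single compressed token under each scheme. It is an illustration only: the tensor names are made up, and the actual Triton kernels operate on the full variable-length batch and also fold in the optional intra-block positional embedding `pe`.

```python
import torch

num_heads, head_dim, kernel_size = 4, 128, 32

# one compression window of kernel_size tokens: [kernel_size, num_heads, head_dim]
block = torch.randn(kernel_size, num_heads, head_dim)

# avgpool: plain mean over the window
avg_token = block.mean(dim=0)  # [num_heads, head_dim]

# weightedpool: learned per-head weights over the window positions
w_pool = torch.randn(num_heads, kernel_size)
wp_token = torch.einsum("khd,hk->hd", block, w_pool)  # [num_heads, head_dim]

# linear: concatenate the window per head, then project back to head_dim
w_lin = torch.randn(num_heads, kernel_size * head_dim, head_dim)
flat = block.permute(1, 0, 2).reshape(num_heads, kernel_size * head_dim)
lin_token = torch.einsum("hf,hfd->hd", flat, w_lin)  # [num_heads, head_dim]
```
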
49 | ---
50 |
51 | ### Attention Mechanisms
52 |
53 | These functions compute attention using either full or sparse mechanisms.
54 |
55 | #### `flash_attention_varlen`
56 | A variable-length implementation of flash attention, similar to `flash_attn_varlen_func` from the `flash_attn` package.
57 |
58 | **Parameters:**
59 | - `q`, `k`, `v`: Query, key, and value tensors (`total_len, num_heads, head_dim`)
60 | - `cu_seqlens_q`, `cu_seqlens_k`: Cumulative sequence lengths for queries and keys
61 | - `max_seqlen_q`, `max_seqlen_k`: Maximum sequence lengths in the batch
62 | - `causal`: Apply causal masking (default: `False`)
63 | - `sm_scale`: Softmax scale (default: `1 / sqrt(head_dim)`)
64 |
65 | **Returns:**
66 | - Attention output tensor (`total_q_len, num_q_heads, head_dim`)
67 |
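A minimal call sketch, assuming the keyword names follow the parameter list above (all tensors here are random placeholders):

```python
import torch
from native_sparse_attention.ops import flash_attention_varlen

num_q_heads, num_kv_heads, head_dim = 64, 4, 128
cu_seqlens = torch.tensor([0, 512, 1536], dtype=torch.int32, device="cuda")
total_len = int(cu_seqlens[-1])
max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())

q = torch.randn(total_len, num_q_heads, head_dim, dtype=torch.bfloat16, device="cuda")
k = torch.randn(total_len, num_kv_heads, head_dim, dtype=torch.bfloat16, device="cuda")
v = torch.randn(total_len, num_kv_heads, head_dim, dtype=torch.bfloat16, device="cuda")

attn_out = flash_attention_varlen(
    q, k, v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max_seqlen,
    max_seqlen_k=max_seqlen,
    causal=True,
)  # [total_len, num_q_heads, head_dim]
```
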
68 | #### `compressed_attention`
69 | Computes attention between a query and compressed key/value tensors, identifying the top-k blocks for sparse attention.
70 |
71 | **Parameters:**
72 | - `q`: Query tensor (`total_len, num_heads, head_dim`)
73 | - `k`, `v`: Compressed key and value tensors (`total_compress_len, num_heads, head_dim`)
74 | - `kernel_size`, `kernel_stride`: Compression parameters
75 | - `block_size`: Size of blocks for sparse attention
76 | - `topk`: Number of top blocks to select
77 | - `cu_seqlens_q`, `cu_seqlens_k`: Cumulative sequence lengths for query and compressed key/value
78 | - `max_seqlen_q`, `max_seqlen_k`: Maximum sequence lengths for query and compressed key/value
79 | - `sm_scale`: Softmax scale (default: `1 / sqrt(head_dim)`)
80 | - `init_blocks`: Number of initial blocks forced to be selected (default: `1`)
81 | - `local_blocks`: Number of local blocks forced to be selected (default: `2`)
82 |
83 | **Returns:**
84 | - Tuple containing:
85 | - Attention output tensor
86 | - Top-k block indices
87 |
88 | #### `topk_sparse_attention`
89 | Performs sparse attention using precomputed top-k block indices. If a query attends to fewer than `topk` key/value blocks, the `topk_idx` should be padded with `-1` on the right.
90 |
91 | **Parameters:**
92 | - `q`, `k`, `v`: Query, key, and value tensors (`total_len, num_heads, head_dim`)
93 | - `topk_idx`: Precomputed top-k indices (`num_kv_heads, total_len, topk`)
94 | - `block_size`: Block size for sparse attention (recommended: `64` or `128`)
95 | - `cu_seqlens`: Cumulative sequence lengths
96 | - `softmax_scale`: Softmax scale (default: `1 / sqrt(head_dim)`)
97 |
98 | **Returns:**
99 | - Attention output tensor (`total_len, num_q_heads, head_dim`)
100 |
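As noted above, a query that can see fewer than `topk` key/value blocks needs `topk_idx` right-padded with `-1`. A minimal sketch of that padding convention (the indices here are placeholders, not the output of `compressed_attention`):

```python
import torch
import torch.nn.functional as F

num_kv_heads, total_len, topk = 4, 8, 16

# suppose only 5 key/value blocks exist for these (short) queries
visible_idx = torch.arange(5, dtype=torch.int32, device="cuda").repeat(num_kv_heads, total_len, 1)

# right-pad to the full topk with -1 so the padded slots are skipped
topk_idx = F.pad(visible_idx, (0, topk - visible_idx.shape[-1]), value=-1)
# topk_idx shape: [num_kv_heads, total_len, topk]
```
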
101 | ---
102 |
103 | ## Usage Example
104 |
105 | Below is a typical workflow demonstrating how to combine these sparse attention functions:
106 |
107 | ```python
108 | import torch
109 | from native_sparse_attention.ops import linear_compress, compressed_attention, topk_sparse_attention
110 |
111 | # Example input setup
112 | num_q_heads = 64
113 | num_kv_heads = 4
114 | head_dim = 128
115 | cu_seqlens = torch.tensor([0, 1024, 8192, 16384], dtype=torch.int32).cuda()
116 |
117 | # Query, key, and value tensors
118 | query = torch.randn(16384, num_q_heads, head_dim, dtype=torch.bfloat16).cuda()
119 | key = torch.randn(16384, num_kv_heads, head_dim, dtype=torch.bfloat16).cuda()
120 | value = torch.randn(16384, num_kv_heads, head_dim, dtype=torch.bfloat16).cuda()
121 |
122 | # Compression weights and positional embeddings
123 | kernel_size = 32
124 | kernel_stride = 16
125 | wk = torch.randn(num_kv_heads, kernel_size * head_dim, head_dim, dtype=torch.bfloat16).cuda()
126 | wv = torch.randn_like(wk)
127 | pe = torch.randn(num_kv_heads, kernel_size, head_dim, dtype=torch.bfloat16).cuda()
128 |
129 | # Parameters for top-k sparse attention
130 | block_size = 64
131 | topk = 16
132 |
133 | # 1. Compress key and value tensors
134 | compressed_key, compressed_cu_seqlens = linear_compress(
135 | key, wk, cu_seqlens, kernel_size, kernel_stride, pe
136 | )
137 | compressed_value, _ = linear_compress(
138 | value, wv, cu_seqlens, kernel_size, kernel_stride, None
139 | )
140 |
141 | # 2. Compute attention with compressed key/value and get top-k indices
142 | compressed_attn_output, topk_idx = compressed_attention(
143 | query,
144 | compressed_key,
145 | compressed_value,
146 | kernel_size,
147 | kernel_stride,
148 | block_size,
149 | topk,
150 | cu_seqlens,
151 | compressed_cu_seqlens,
152 | init_blocks=1,
153 | local_blocks=2,
154 | )
155 |
156 | # 3. Perform top-k sparse attention
157 | sparse_attn_output = topk_sparse_attention(
158 | query,
159 | key,
160 | value,
161 | topk_idx,
162 | block_size,
163 | cu_seqlens,
164 | )
165 |
166 | # 4. Combine attention outputs (e.g., average)
167 | attn_output = (compressed_attn_output + sparse_attn_output) / 2
168 | ```
169 |
170 | For a complete implementation of the Native Sparse Attention module, see `native_sparse_attention/module/native_sparse_attention.py`.
171 |
--------------------------------------------------------------------------------
/native_sparse_attention/ops/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # compress method
16 | from native_sparse_attention.ops.triton.weighted_pool import (
17 | weightedpool_compress,
18 | avgpool_compress,
19 | )
20 | from native_sparse_attention.ops.triton.linear_compress import linear_compress
21 |
22 | # prefill attention
23 | from native_sparse_attention.ops.triton.flash_attention import flash_attention_varlen
24 | from native_sparse_attention.ops.triton.compressed_attention import compressed_attention
25 | from native_sparse_attention.ops.triton.topk_sparse_attention import (
26 | topk_sparse_attention,
27 | )
28 |
29 | # decode attention
30 | from native_sparse_attention.ops.triton.flash_attention_decode import (
31 | flash_attention_decode,
32 | )
33 | from native_sparse_attention.ops.torch.compressed_attention_decode import (
34 | compressed_attention_decode,
35 | )
36 | from native_sparse_attention.ops.triton.topk_sparse_attention_decode import (
37 | topk_sparse_attention_decode,
38 | )
39 |
40 | __all__ = [
41 | # compress method
42 | "avgpool_compress",
43 | "weightedpool_compress",
44 | "linear_compress",
45 | # prefill attention, trainable
46 | "flash_attention_varlen",
47 | "compressed_attention",
48 | "topk_sparse_attention",
49 | # decode attention, no grad
50 | "flash_attention_decode",
51 | "compressed_attention_decode",
52 | "topk_sparse_attention_decode",
53 | ]
54 |
--------------------------------------------------------------------------------
/native_sparse_attention/ops/torch/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XunhaoLai/native-sparse-attention-triton/9bea856c911ebf263be88d797fb28458f82f1d94/native_sparse_attention/ops/torch/__init__.py
--------------------------------------------------------------------------------
/native_sparse_attention/ops/torch/compress_key_value.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import torch
15 | from typing import Optional
16 | from einops import rearrange, einsum
17 |
18 |
19 | def avgpool_compress_torch(
20 | x: torch.Tensor,
21 | w: torch.Tensor,
22 | cu_seqlens,
23 | kernel_size: int,
24 | kernel_stride: int,
25 | pe: Optional[torch.Tensor] = None,
26 | ):
27 | """Compress key and value tensor with kernel_size and kernel_stride.
28 |
29 | Args:
30 | x (torch.Tensor): key_states or value_states, shape (total_len, num_heads, head_dim)
31 | w (torch.Tensor): no weight for avgpool, must be None.
32 | cu_seqlens (_type_): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen.
33 | kernel_size (int): kernel_size, each (kernel_size, head_dim) blocks will be compressed to (1, head_dim)
34 | kernel_stride (int): stride for each compress kernel
35 | pe (Optional[torch.Tensor], optional): intra-block positional embedding with shape (num_heads, kernel_size, head_dim). Defaults to None.
36 |
37 | Returns:
38 | Tuple[torch.Tensor, torch.Tensor]: compressed states and corresponding cu_seqlens.
39 | """
40 | # dtype check
41 | assert x.dtype == torch.float16 or x.dtype == torch.bfloat16
42 | assert cu_seqlens.dtype == torch.int32
43 | assert x.dtype == pe.dtype if pe is not None else True
44 |
45 | # shape check
46 | total_len, num_heads, head_dim = x.shape
47 | batch_size = cu_seqlens.shape[0] - 1
48 | assert w is None, "don't need additional weight for avgpool"
49 | assert kernel_size % kernel_stride == 0
50 | assert kernel_size in {16, 32, 64, 128}
51 |
52 | # compute seqlens after compression
53 | seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
54 | y_seqlens = torch.floor((seqlens - kernel_size) / kernel_stride).to(torch.int32) + 1
55 | # corner case, if sequence_length < kernel_size, no compression for this sequence
56 | y_seqlens[seqlens < kernel_size] = 0
57 | y_cu_seqlens = torch.cat(
58 | [
59 | torch.zeros(1, dtype=torch.int32, device="cuda"),
60 | torch.cumsum(y_seqlens, dim=0),
61 | ],
62 | dim=0,
63 | ).to(torch.int32)
64 |
65 | # pad and rearrange x
66 | x = rearrange(x, "n h d -> n (h d)")
67 | splited_x = torch.split(x, seqlens.tolist(), 0)
68 | x = torch.nn.utils.rnn.pad_sequence(splited_x, batch_first=True)
69 | x = rearrange(x, "b n d -> b d n")
70 | # avgpool
71 | y = torch.nn.functional.avg_pool1d(x, kernel_size=kernel_size, stride=kernel_stride)
72 | y = rearrange(y, "b (h d) n -> b n h d", h=num_heads)
73 | # only keep useful part
74 | y = torch.cat([y[i, : y_seqlens[i]] for i in range(batch_size)], dim=0)
75 |
76 | # position embedding as a bias
77 | if pe is not None:
78 | bias = torch.mean(pe, dim=1)
79 | y = y + bias.unsqueeze(0)
80 |
81 | return y, y_cu_seqlens
82 |
83 |
84 | def weightedpool_compress_torch(
85 | x: torch.Tensor,
86 | w: torch.Tensor, # [num_heads, kernel_size]
87 | cu_seqlens,
88 | kernel_size: int,
89 | kernel_stride: int,
90 | pe: Optional[torch.Tensor] = None,
91 | ):
92 | """Compress key and value tensor with kernel_size and kernel_stride.
93 |
94 | Args:
95 | x (torch.Tensor): key_states or value_states, shape (total_len, num_heads, head_dim)
96 | w (torch.Tensor): weight for each head, shape (num_heads, kernel_size)
97 | cu_seqlens (_type_): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen.
98 | kernel_size (int): kernel_size, each (kernel_size, head_dim) blocks will be compressed to (1, head_dim)
99 | kernel_stride (int): stride for each compress kernel
100 | pe (Optional[torch.Tensor], optional): intra-block positional embedding with shape (num_heads, kernel_size, head_dim). Defaults to None.
101 |
102 | Returns:
103 | Tuple[torch.Tensor, torch.Tensor]: compressed states and corresponding cu_seqlens.
104 | """
105 | # dtype check
106 | assert x.dtype == torch.float16 or x.dtype == torch.bfloat16
107 | assert x.dtype == w.dtype
108 | assert x.dtype == pe.dtype if pe is not None else True
109 | assert cu_seqlens.dtype == torch.int32
110 | # shape check
111 | total_len, num_heads, head_dim = x.shape
112 | batch_size = cu_seqlens.shape[0] - 1
113 | assert w.shape[0] == num_heads
114 | assert w.shape[1] == kernel_size
115 | assert kernel_size % kernel_stride == 0
116 | assert kernel_size in {16, 32, 64, 128}
117 | # compute seqlens after compression
118 | seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
119 | y_seqlens = torch.floor((seqlens - kernel_size) / kernel_stride).to(torch.int32) + 1
120 | # corner case, if sequence_length < kernel_size, no compression for this sequence
121 | y_seqlens[seqlens < kernel_size] = 0
122 | y_cu_seqlens = torch.cat(
123 | [
124 | torch.zeros(1, dtype=torch.int32, device="cuda"),
125 | torch.cumsum(y_seqlens, dim=0),
126 | ],
127 | dim=0,
128 | ).to(torch.int32)
129 | # pad and rearrange x
130 | x = rearrange(x, "n h d -> n (h d)")
131 | splited_x = torch.split(x, seqlens.tolist(), 0)
132 | x = torch.nn.utils.rnn.pad_sequence(splited_x, batch_first=True)
133 | x = rearrange(x, "b n (h d) -> b h n d", h=num_heads)
134 | x = x.as_strided(
135 | size=(batch_size, num_heads, y_seqlens.max().item(), kernel_size, head_dim),
136 | stride=(
137 | x.stride(0),
138 | x.stride(1),
139 | kernel_stride * x.stride(2),
140 | x.stride(2),
141 | x.stride(3),
142 | ),
143 | )
144 | y = einsum(x, w, "b h n k d, h k -> b n h d")
145 | # only keep useful part
146 | y = torch.cat([y[i, : y_seqlens[i]] for i in range(batch_size)], dim=0)
147 |
148 | # position embedding as a bias
149 | if pe is not None:
150 | bias = einsum(pe, w, "h k d, h k -> h d")
151 | y = y + bias.unsqueeze(0)
152 |
153 | return y, y_cu_seqlens
154 |
155 |
156 | def linear_compress_torch(
157 | x: torch.Tensor,
158 | w: torch.Tensor, # [num_heads, kernel_size * head_dim, head_dim]
159 | cu_seqlens,
160 | kernel_size: int,
161 | kernel_stride: int,
162 | pe: Optional[torch.Tensor] = None,
163 | ):
164 | """Compress key and value tensor with kernel_size and kernel_stride. Similar to conv_compress.
165 |
166 | Args:
167 | x (torch.Tensor): key_states or value_states, shape (total_len, num_heads, head_dim)
168 | w (torch.Tensor): weight for each head, shape (num_heads, kernel_size * head_dim, head_dim)
169 | cu_seqlens (_type_): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen.
170 | kernel_size (int): kernel_size, each (kernel_size, head_dim) blocks will be compressed to (1, head_dim)
171 | kernel_stride (int): stride for each compress kernel
172 | pe (Optional[torch.Tensor], optional): intra-block positional embedding with shape (num_heads, kernel_size, head_dim). Defaults to None.
173 |
174 | Returns:
175 | Tuple[torch.Tensor, torch.Tensor]: compressed states and corresponding cu_seqlens.
176 | """
177 | # dtype check
178 | assert x.dtype == torch.float16 or x.dtype == torch.bfloat16
179 | assert x.dtype == w.dtype
180 | assert x.dtype == pe.dtype if pe is not None else True
181 | assert cu_seqlens.dtype == torch.int32
182 | # shape check
183 | total_len, num_heads, head_dim = x.shape
184 | batch_size = cu_seqlens.shape[0] - 1
185 | assert w.shape[0] == num_heads
186 | assert w.shape[1] == kernel_size * head_dim
187 | assert w.shape[2] == head_dim
188 | assert kernel_size % kernel_stride == 0
189 | assert kernel_size in {16, 32, 64, 128}
190 | # compute seqlens after compression
191 | seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
192 | y_seqlens = torch.floor((seqlens - kernel_size) / kernel_stride).to(torch.int32) + 1
193 | # corner case, if sequence_length < kernel_size, no compression for this sequence
194 | y_seqlens[seqlens < kernel_size] = 0
195 | y_cu_seqlens = torch.cat(
196 | [
197 | torch.zeros(1, dtype=torch.int32, device="cuda"),
198 | torch.cumsum(y_seqlens, dim=0),
199 | ],
200 | dim=0,
201 | ).to(torch.int32)
202 | # pad and rearrange x
203 | x = rearrange(x, "n h d -> n (h d)")
204 | splited_x = torch.split(x, seqlens.tolist(), 0)
205 | x = torch.nn.utils.rnn.pad_sequence(splited_x, batch_first=True)
206 | x = rearrange(x, "b n (h d) -> b h n d", h=num_heads)
207 | x = x.as_strided(
208 | size=(batch_size, num_heads, y_seqlens.max().item(), kernel_size, head_dim),
209 | stride=(
210 | x.stride(0),
211 | x.stride(1),
212 | kernel_stride * x.stride(2),
213 | x.stride(2),
214 | x.stride(3),
215 | ),
216 | )
217 | y = einsum(
218 | x,
219 | rearrange(w, "h (k d) D -> h k d D", k=kernel_size),
220 | "b h n k d, h k d D -> b n h D",
221 | )
222 | # only keep useful part
223 | y = torch.cat([y[i, : y_seqlens[i]] for i in range(batch_size)], dim=0)
224 |
225 | # position embedding as a bias
226 | if pe is not None:
227 | pe = rearrange(pe, "h k d -> h (k d)")
228 | bias = einsum(pe, w, "h D, h D d -> h d")
229 | y = y + bias.unsqueeze(0)
230 |
231 | return y, y_cu_seqlens
232 |
--------------------------------------------------------------------------------
/native_sparse_attention/ops/torch/compressed_attention.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import torch
15 | import math
16 | from typing import Tuple
17 | from collections import Counter
18 | from einops import rearrange
19 |
20 |
21 | def transform_score(
22 | score: torch.Tensor,
23 | kernel_size: int,
24 | kernel_stride: int,
25 | block_size: int,
26 | cu_seqlens_q: torch.Tensor,
27 | cu_seqlens_k: torch.Tensor,
28 | max_seqlen_q: int,
29 | max_seqlen_k: int,
30 | init_blocks: int = 1,
31 | local_blocks: int = 2,
32 | ) -> torch.Tensor:
33 | num_k_heads, total_query_len, _ = score.shape
34 | pad_len = kernel_size // kernel_stride - 1
35 | score = torch.nn.functional.pad(score, (pad_len, pad_len), value=0)
36 | max_blocks = math.ceil(max_seqlen_q / block_size)
37 | full_blocks = max_seqlen_q // block_size
38 | block_score = torch.zeros(
39 | num_k_heads,
40 | total_query_len,
41 | max_blocks,
42 | dtype=torch.float32,
43 | device=score.device,
44 | )
45 | offs = (
46 | torch.arange(kernel_size // kernel_stride)[:, None]
47 | + torch.arange(block_size // kernel_stride)[None, :]
48 | ).view(-1)
49 | offs = dict(Counter(offs.tolist()))
50 | for k, v in offs.items():
51 | block_score[..., :full_blocks] += (
52 | v * score[..., k :: block_size // kernel_stride][..., :full_blocks]
53 | )
54 | # set init block and local block score
55 | batch_size = cu_seqlens_q.shape[0] - 1
56 | q_idx = torch.cat(
57 | [
58 | torch.arange(cu_seqlens_q[i + 1] - cu_seqlens_q[i], device=score.device)
59 | for i in range(batch_size)
60 | ],
61 | dim=0,
62 | )
63 | q_idx = q_idx // block_size
64 | b_idx = torch.arange(max_blocks, device=score.device)
65 | block_score[..., :init_blocks] = torch.inf
66 | local_mask = (q_idx[:, None] >= b_idx[None, :]) & (
67 | q_idx[:, None] < b_idx[None, :] + local_blocks
68 | )
69 | local_mask = local_mask.unsqueeze(0).expand(num_k_heads, -1, -1)
70 | block_score[local_mask] = torch.inf
71 | return block_score
72 |
73 |
74 | def compressed_attention_torch(
75 | q: torch.Tensor, # [total_query_len, num_q_heads, head_dim]
76 | k: torch.Tensor, # [total_key_len, num_k_heads, head_dim]
77 | v: torch.Tensor, # [total_key_len, num_k_heads, head_dim]
78 | kernel_size: int,
79 | kernel_stride: int,
80 | block_size: int,
81 | topk: int,
82 | cu_seqlens_q: torch.Tensor,
83 | cu_seqlens_k: torch.Tensor,
84 | max_seqlen_q: int,
85 | max_seqlen_k: int,
86 | sm_scale: float = None,
87 | init_blocks: int = 1,
88 | local_blocks: int = 2,
89 | ) -> Tuple[torch.Tensor, torch.Tensor]:
90 | """Attention between query and compressed key and value. Implemented in torch, for debugging only.
91 |
92 | Args:
93 | q (torch.Tensor): shape [total_q_len, num_q_heads, head_dim]
94 | k (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]
95 | v (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]
96 | kernel_size (int): kernel size in compress_key_value
97 | kernel_stride (int): stride of compress_key_value
98 | block_size (int): key value block size for topk sparse attention.
99 | topk (int): number of blocks for each query.
100 | cu_seqlens_q (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen.
101 | cu_seqlens_k (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_k in flash_attn_func_varlen.
102 | max_seqlen_q (int): max q len of the batch.
103 | max_seqlen_k (int): max k len of the batch.
104 | sm_scale (float, optional): softmax scale. Defaults to None, means 1/sqrt(head_dim).
105 | init_blocks (int, optional): Number of init blocks for each query. Defaults to 1.
106 | local_blocks (int, optional): Number of local blocks for each query. Defaults to 2.
107 |
108 | Returns:
109 | Tuple[torch.Tensor, torch.Tensor]: attention output and topk_idx used in topk_sparse_attention
110 | """
111 | assert block_size % kernel_size == 0 and kernel_size % kernel_stride == 0
112 | total_query_len, num_q_heads, head_dim = q.shape
113 | total_key_len, num_k_heads, _ = k.shape
114 | num_share_q_heads = num_q_heads // num_k_heads
115 | batch_size = cu_seqlens_q.shape[0] - 1
116 | if sm_scale is None:
117 | sm_scale = 1.0 / math.sqrt(head_dim)
118 | # get mask
119 | mask = torch.zeros(
120 | (total_query_len, total_key_len), dtype=torch.bool, device=q.device
121 | )
122 | for b in range(batch_size):
123 | q_len, k_len = (
124 | cu_seqlens_q[b + 1] - cu_seqlens_q[b],
125 | cu_seqlens_k[b + 1] - cu_seqlens_k[b],
126 | )
127 | k_max_ids = (
128 | torch.arange(k_len, device=q.device) * kernel_stride + kernel_size - 1
129 | )
130 | q_ids = torch.arange(q_len, device=q.device)
131 | mask[
132 | cu_seqlens_q[b] : cu_seqlens_q[b + 1], cu_seqlens_k[b] : cu_seqlens_k[b + 1]
133 | ] = (q_ids[:, None] >= k_max_ids[None, :])
134 | # attention
135 | qk = (
136 | torch.einsum("qhd,khd->hqk", q, k.repeat_interleave(num_share_q_heads, 1))
137 | * sm_scale
138 | )
139 | qk = qk.masked_fill_(~mask[None, ...], -torch.inf)
140 | # queries at the beginning of a sequence cannot attend to any compressed key; their all -inf rows become nan after softmax and are zeroed below
141 | qk = qk.softmax(dim=-1, dtype=torch.float32)
142 | qk = qk.nan_to_num(0)
143 | attn_output = torch.einsum(
144 | "hqk,khd->qhd", qk.to(v.dtype), v.repeat_interleave(num_share_q_heads, 1)
145 | )
146 | with torch.no_grad():
147 | # get avg score over gqa heads
148 | # qk shape [num_k_heads, total_q_len, total_k_len]
149 | score = torch.zeros(
150 | num_k_heads,
151 | cu_seqlens_q[-1],
152 | max_seqlen_k,
153 | dtype=torch.float32,
154 | device=q.device,
155 | )
156 | qk = rearrange(qk, "(h g) q k -> h g q k", h=num_k_heads).sum(1)
157 | for b in range(batch_size):
158 | score[
159 | :,
160 | cu_seqlens_q[b] : cu_seqlens_q[b + 1],
161 | : cu_seqlens_k[b + 1] - cu_seqlens_k[b],
162 | ] = qk[
163 | :,
164 | cu_seqlens_q[b] : cu_seqlens_q[b + 1],
165 | cu_seqlens_k[b] : cu_seqlens_k[b + 1],
166 | ]
167 | # transform score to block-wise score
168 | score = transform_score(
169 | score,
170 | kernel_size,
171 | kernel_stride,
172 | block_size,
173 | cu_seqlens_q,
174 | cu_seqlens_k,
175 | max_seqlen_q,
176 | max_seqlen_k,
177 | init_blocks,
178 | local_blocks,
179 | )
180 | # get topk
181 | batch_size = cu_seqlens_q.shape[0] - 1
182 | q_idx = torch.cat(
183 | [
184 | torch.arange(cu_seqlens_q[i + 1] - cu_seqlens_q[i], device=q.device)
185 | for i in range(batch_size)
186 | ],
187 | dim=0,
188 | )
189 | q_idx = q_idx // block_size
190 | topk = min(topk, score.shape[-1])
191 | topk_idx = score.topk(topk, dim=-1).indices.sort(-1).values
192 | topk_idx[topk_idx > q_idx[None, :, None]] = -1
193 | topk_idx = topk_idx.to(torch.int32)
194 | return attn_output, topk_idx
195 |
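196 | # Usage sketch (illustrative only; shapes are arbitrary): q attends to compressed
197 | # key/value states (here assumed to come from a compress op with kernel_size 32 and
198 | # kernel_stride 16), and the returned topk_idx can feed topk_sparse_attention_torch.
199 | #
200 | #     cu_seqlens_q = torch.tensor([0, 1024], dtype=torch.int32, device="cuda")
201 | #     cu_seqlens_k = torch.tensor([0, 63], dtype=torch.int32, device="cuda")  # (1024 - 32) // 16 + 1
202 | #     q = torch.randn(1024, 32, 128, device="cuda", dtype=torch.float16)
203 | #     ck = torch.randn(63, 4, 128, device="cuda", dtype=torch.float16)
204 | #     cv = torch.randn(63, 4, 128, device="cuda", dtype=torch.float16)
205 | #     out, topk_idx = compressed_attention_torch(
206 | #         q, ck, cv, 32, 16, 64, 16,
207 | #         cu_seqlens_q, cu_seqlens_k, 1024, 63,
208 | #     )
209 | #     # out: [1024, 32, 128]; topk_idx: [4, 1024, 16], where -1 marks blocks beyond the query's causal range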
--------------------------------------------------------------------------------
/native_sparse_attention/ops/torch/compressed_attention_decode.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import torch
15 | import math
16 | from typing import Tuple, Optional
17 | from collections import Counter
18 | from einops import rearrange
19 |
20 |
21 | def transform_score(
22 | score: torch.Tensor,
23 | seqlens: torch.Tensor,
24 | kernel_size: int,
25 | kernel_stride: int,
26 | block_size: int,
27 | init_blocks: int = 1,
28 | local_blocks: int = 2,
29 | ) -> torch.Tensor:
30 | num_k_heads, batch_size, kv_len = score.shape
31 | pad_len = kernel_size // kernel_stride - 1
32 | score = torch.nn.functional.pad(score, (pad_len, pad_len), value=0)
33 | max_seqlen = seqlens.max().item()
34 | max_blocks = math.ceil(max_seqlen / block_size)
35 | full_blocks = max_seqlen // block_size
36 | block_score = torch.zeros(
37 | num_k_heads,
38 | batch_size,
39 | max_blocks,
40 | dtype=torch.float32,
41 | device=score.device,
42 | )
43 | offs = (
44 | torch.arange(kernel_size // kernel_stride)[:, None]
45 | + torch.arange(block_size // kernel_stride)[None, :]
46 | ).view(-1)
47 | offs = dict(Counter(offs.tolist()))
48 | for k, v in offs.items():
49 | block_score[..., :full_blocks] += (
50 | v * score[..., k :: block_size // kernel_stride][..., :full_blocks]
51 | )
52 | # set init block and local block score
53 | q_idx = (seqlens - 1) // block_size
54 | b_idx = torch.arange(max_blocks, device=score.device)
55 | block_score[..., :init_blocks] = torch.inf
56 | local_mask = (q_idx[:, None] >= b_idx[None, :]) & (
57 | q_idx[:, None] < b_idx[None, :] + local_blocks
58 | )
59 | local_mask = local_mask.unsqueeze(0).expand(num_k_heads, -1, -1)
60 | block_score[local_mask] = torch.inf
61 | block_score = block_score.nan_to_num(0, torch.inf, -torch.inf)
62 | return block_score
63 |
64 |
65 | def compressed_attention_decode(
66 | q: torch.Tensor,
67 | k: torch.Tensor,
68 | v: torch.Tensor,
69 | seqlens: torch.Tensor,
70 | compress_seqlens: torch.Tensor,
71 | kernel_size: int,
72 | kernel_stride: int,
73 | block_size: int,
74 | topk: int,
75 | init_blocks: int = 1,
76 | local_blocks: int = 2,
77 | sm_scale: Optional[float] = None,
78 | ) -> Tuple[torch.Tensor, torch.Tensor]:
79 | """Attention between query and compressed key and value during decoding. Implemented in torch, for debugging only.
80 |
81 | Args:
82 | q (torch.Tensor): shape [batch_size, num_q_heads, head_dim]
83 | k (torch.Tensor): shape [batch_size, kv_len, num_kv_heads, head_dim]
84 | v (torch.Tensor): shape [batch_size, kv_len, num_kv_heads, head_dim]
85 | seqlens (torch.Tensor): original kv length for each sequence
86 | compress_seqlens (torch.Tensor): kv length for each sequence after compression
87 | kernel_size (int): kernel size in compress_key_value
88 | kernel_stride (int): stride of compress_key_value
89 | block_size (int): key value block size for topk sparse attention.
90 | topk (int): number of blocks for each query.
91 | init_blocks (int, optional): Number of init blocks for each query. Defaults to 1.
92 | local_blocks (int, optional): Number of local blocks for each query. Defaults to 2.
93 | sm_scale (float, optional): softmax scale. Defaults to None, means 1/sqrt(head_dim).
94 |
95 | Returns:
96 | Tuple[torch.Tensor, torch.Tensor]: attention output and topk_idx used in topk_sparse_attention_decode
97 | """
98 | assert block_size % kernel_size == 0 and kernel_size % kernel_stride == 0
99 | batch_size, num_q_heads, head_dim = q.shape
100 | batch_size, kv_len, num_k_heads, _ = k.shape
101 | num_share_q_heads = num_q_heads // num_k_heads
102 | if sm_scale is None:
103 | sm_scale = 1.0 / math.sqrt(head_dim)
104 | # input is too short to have a valid block
105 | if kv_len == 0:
106 | return torch.zeros_like(q), torch.zeros(
107 | num_k_heads, batch_size, 1, device=q.device, dtype=torch.int32
108 | )
109 | # get mask
110 | mask = (
111 | compress_seqlens[:, None]
112 | > torch.arange(
113 | kv_len, device=compress_seqlens.device, dtype=compress_seqlens.dtype
114 | )[None, :]
115 | )
116 | # attention
117 | qk = (
118 | torch.einsum(
119 | "bihgd, bjhgd -> bhgij",
120 | rearrange(q, "b (h g) d -> b 1 h g d", g=num_share_q_heads),
121 | rearrange(k, "b j h d -> b j h 1 d"),
122 | )
123 | * sm_scale
124 | )
125 | qk = qk.masked_fill_(~mask[:, None, None, None, :], -torch.inf)
126 | qk = qk.softmax(dim=-1, dtype=torch.float32)
127 | qk = qk.nan_to_num_(0) # qk is nan when seqlen == 0
128 | attn_output = torch.einsum(
129 | "bhgij, bjhgd -> bihgd",
130 | qk.to(v.dtype),
131 | rearrange(v, "b k h d -> b k h 1 d"),
132 | )
133 | attn_output = rearrange(attn_output, "b 1 h g d -> b (h g) d")
134 |
135 | # get score
136 | score = rearrange(qk.sum(2).squeeze(2), "b h j -> h b j")
137 | # transform score to block-wise score
138 | score = transform_score(
139 | score,
140 | seqlens,
141 | kernel_size,
142 | kernel_stride,
143 | block_size,
144 | init_blocks,
145 | local_blocks,
146 | )
147 | # get topk
148 | q_idx = (seqlens - 1) // block_size
149 | topk = min(topk, score.shape[-1])
150 | topk_idx = score.topk(topk, dim=-1).indices.sort(-1).values
151 | topk_idx[topk_idx > q_idx[None, :, None]] = -1
152 | topk_idx = topk_idx.to(torch.int32)
153 | return attn_output, topk_idx
154 |
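155 | # Usage sketch (illustrative only; shapes are arbitrary): during decoding, q holds one
156 | # token per sequence, while k/v hold the padded compressed cache, here assumed to be
157 | # built with kernel_size 32 and kernel_stride 16.
158 | #
159 | #     seqlens = torch.tensor([300, 1024], dtype=torch.int32, device="cuda")
160 | #     compress_seqlens = (seqlens - 32) // 16 + 1
161 | #     q = torch.randn(2, 32, 128, device="cuda", dtype=torch.float16)
162 | #     k = torch.randn(2, compress_seqlens.max().item(), 4, 128, device="cuda", dtype=torch.float16)
163 | #     v = torch.randn_like(k)
164 | #     out, topk_idx = compressed_attention_decode(
165 | #         q, k, v, seqlens, compress_seqlens, 32, 16, 64, 16,
166 | #     )
167 | #     # out: [2, 32, 128]; topk_idx: [4, 2, 16] with -1 marking blocks beyond the causal range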
--------------------------------------------------------------------------------
/native_sparse_attention/ops/torch/topk_sparse_attention.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import torch
15 | import math
16 | from typing import Optional
17 |
18 |
19 | def topk_sparse_attention_torch(
20 | q: torch.Tensor,
21 | k: torch.Tensor,
22 | v: torch.Tensor,
23 | topk_idx: torch.Tensor,
24 | block_size_k: int,
25 | cu_seqlens: torch.Tensor,
26 | softmax_scale: Optional[float] = None,
27 | block_size_q: int = 1,
28 | ) -> torch.Tensor:
29 | """Simple topk sparse attention varlen version implemented in torch. Extremely slow, only for debugging.
30 |
31 | Args:
32 | q (torch.Tensor): shape [total_len, num_q_heads, head_dim]
33 | k (torch.Tensor): shape [total_len, num_kv_heads, head_dim]
34 | v (torch.Tensor): shape [total_len, num_kv_heads, head_dim]
35 | topk_idx (torch.Tensor): topk block idx for each query, shape [num_kv_heads, total_len, topk]. -1 means padding.
36 | block_size_k (int): key value block size.
37 | cu_seqlens (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens in flash_attn_func_varlen.
38 | softmax_scale (Optional[float], optional): softmax scale. Defaults to None, means 1/sqrt(head_dim).
39 | block_size_q (int, optional): query block size. Defaults to 1.
40 |
41 | Returns:
42 | torch.Tensor: attention output, shape [total_len, num_q_heads, head_dim]
43 | """
44 | total_seqlen, num_q_heads, head_dim = q.shape
45 | total_seqlen, num_kv_heads, head_dim = k.shape
46 | num_share_q_heads = num_q_heads // num_kv_heads
47 | batch_size = cu_seqlens.shape[0] - 1
48 | topk = topk_idx.shape[-1]
49 | seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
50 | seqblocks_q = torch.ceil(seqlens / block_size_q).to(torch.int32)
51 | cu_seqblocks_q = torch.nn.functional.pad(seqblocks_q.cumsum(0), (1, 0), value=0)
52 | if softmax_scale is None:
53 | softmax_scale = 1.0 / math.sqrt(head_dim)
54 | # get mask
55 | mask = torch.zeros(
56 | (num_kv_heads, total_seqlen, total_seqlen), dtype=torch.bool, device=q.device
57 | )
58 | for i in range(batch_size):
59 | num_q_blocks = math.ceil(seqlens[i] / block_size_q)
60 | num_kv_blocks = math.ceil(seqlens[i] / block_size_k)
61 | for h in range(num_kv_heads):
62 | temp_mask = torch.zeros(
63 | num_q_blocks, num_kv_blocks, dtype=torch.bool, device=q.device
64 | )
65 | temp_idx = topk_idx[h, cu_seqblocks_q[i] : cu_seqblocks_q[i + 1]].clone()
66 | temp_idx[temp_idx < 0] = 0
67 | temp_mask[torch.arange(num_q_blocks).to(q.device)[:, None], temp_idx] = True
68 | temp_mask = torch.repeat_interleave(temp_mask, block_size_q, dim=0)
69 | temp_mask = torch.repeat_interleave(temp_mask, block_size_k, dim=1)
70 | temp_mask = temp_mask[: seqlens[i], : seqlens[i]]
71 | mask[
72 | h, cu_seqlens[i] : cu_seqlens[i + 1], cu_seqlens[i] : cu_seqlens[i + 1]
73 | ] = temp_mask
74 | mask = torch.tril(mask).repeat_interleave(num_share_q_heads, 0)
75 | # qk attn
76 | qk = (
77 | torch.einsum("qhd,khd->hqk", q, k.repeat_interleave(num_share_q_heads, 1))
78 | * softmax_scale
79 | )
80 | qk = torch.masked_fill(qk, ~mask, -torch.inf)
81 | qk = torch.softmax(qk, dim=-1, dtype=torch.float32).to(q.dtype)
82 | o = torch.einsum("hqk,khd->qhd", qk, v.repeat_interleave(num_share_q_heads, 1))
83 | return o
84 |
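85 | # Usage sketch (illustrative only; shapes are arbitrary): topk_idx would normally come
86 | # from compressed_attention_torch with a matching block size; a random index tensor is
87 | # used here only to show the expected layout [num_kv_heads, total_len, topk].
88 | #
89 | #     cu_seqlens = torch.tensor([0, 1024], dtype=torch.int32, device="cuda")
90 | #     q = torch.randn(1024, 32, 128, device="cuda", dtype=torch.float16)
91 | #     k = torch.randn(1024, 4, 128, device="cuda", dtype=torch.float16)
92 | #     v = torch.randn(1024, 4, 128, device="cuda", dtype=torch.float16)
93 | #     topk_idx = torch.randint(0, 16, (4, 1024, 16), device="cuda", dtype=torch.int32)
94 | #     o = topk_sparse_attention_torch(q, k, v, topk_idx, 64, cu_seqlens)
95 | #     # o: [1024, 32, 128]; causality is enforced by the final torch.tril mask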
--------------------------------------------------------------------------------
/native_sparse_attention/ops/triton/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XunhaoLai/native-sparse-attention-triton/9bea856c911ebf263be88d797fb28458f82f1d94/native_sparse_attention/ops/triton/__init__.py
--------------------------------------------------------------------------------
/native_sparse_attention/ops/triton/flash_attention_decode.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import math
16 | import torch
17 | import triton
18 | import triton.language as tl
19 | from typing import Optional
20 |
21 |
22 | @triton.jit
23 | def decode_kernel(
24 | q_ptr, # Q: b x h x d
25 | k_ptr, # K: b x n x h x d
26 | v_ptr, # V: b x n x h x d
27 | o_ptr, # O: b x h x d
28 | seqlens,
29 | # shape
30 | BATCH_SIZE,
31 | NUM_SHARE_Q_HEADS,
32 | HEAD_DIM,
33 | # sm_scale
34 | sm_scale,
35 | # stride
36 | stride_qb,
37 | stride_qh,
38 | stride_qd,
39 | stride_kb,
40 | stride_kn,
41 | stride_kh,
42 | stride_kd,
43 | stride_vb,
44 | stride_vn,
45 | stride_vh,
46 | stride_vd,
47 | stride_ob,
48 | stride_oh,
49 | stride_od,
50 | # META parameters
51 | BLOCK_SIZE_B: tl.constexpr,
52 | BLOCK_SIZE_K: tl.constexpr,
53 | BLOCK_SIZE_D: tl.constexpr,
54 | ):
55 | qk_scale = sm_scale * 1.44269504
56 | # get batch id and head id
57 | pid_h = tl.program_id(0)
58 | pid_b = tl.program_id(1)
59 | pid_kh = pid_h // NUM_SHARE_Q_HEADS
60 | # load kv length for each sequence in this batch block
61 | off_b = tl.arange(0, BLOCK_SIZE_B)
62 | kv_len = tl.load(
63 | seqlens + pid_b * BLOCK_SIZE_B + off_b,
64 | mask=pid_b * BLOCK_SIZE_B + off_b < BATCH_SIZE,
65 | other=0,
66 | )
67 | max_kv_len = tl.max(kv_len)
68 | # init qkv pointer
69 | q_ptrs = tl.make_block_ptr(
70 | base=q_ptr + pid_h * stride_qh,
71 | shape=(BATCH_SIZE, HEAD_DIM),
72 | strides=(stride_qb, stride_qd),
73 | offsets=(pid_b * BLOCK_SIZE_B, 0),
74 | block_shape=(BLOCK_SIZE_B, BLOCK_SIZE_D),
75 | order=(1, 0),
76 | )
77 | k_ptrs = tl.make_block_ptr(
78 | base=k_ptr + pid_kh * stride_kh,
79 | shape=(BATCH_SIZE, max_kv_len, HEAD_DIM),
80 | strides=(stride_kb, stride_kn, stride_kd),
81 | offsets=(pid_b * BLOCK_SIZE_B, 0, 0),
82 | block_shape=(BLOCK_SIZE_B, BLOCK_SIZE_K, BLOCK_SIZE_D),
83 | order=(2, 1, 0),
84 | )
85 | v_ptrs = tl.make_block_ptr(
86 | base=v_ptr + pid_kh * stride_vh,
87 | shape=(BATCH_SIZE, max_kv_len, HEAD_DIM),
88 | strides=(stride_vb, stride_vn, stride_vd),
89 | offsets=(pid_b * BLOCK_SIZE_B, 0, 0),
90 | block_shape=(BLOCK_SIZE_B, BLOCK_SIZE_K, BLOCK_SIZE_D),
91 | order=(2, 1, 0),
92 | )
93 | # load q
94 | q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
95 | # init statistics
96 | off_k = tl.arange(0, BLOCK_SIZE_K)
97 | m_i = tl.full((BLOCK_SIZE_B,), float("-inf"), dtype=tl.float32)
98 | lse_i = tl.full((BLOCK_SIZE_B,), float("-inf"), dtype=tl.float32)
99 | acc_o = tl.full((BLOCK_SIZE_B, BLOCK_SIZE_D), 0, dtype=tl.float32)
100 | # iterate over kv blocks (full attention over the cached keys and values)
101 | for i in range(0, max_kv_len, BLOCK_SIZE_K):
102 | i = tl.multiple_of(i, BLOCK_SIZE_K)
103 | # load k
104 | k = tl.load(k_ptrs, boundary_check=(0, 1, 2), padding_option="zero")
105 | # compute qk
106 | qk = tl.zeros((BLOCK_SIZE_B, BLOCK_SIZE_K), dtype=tl.float32)
107 | qk += tl.where(off_k[None, :] + i < kv_len[:, None], 0, float("-inf"))
108 | # [B, D], [B, K, D] -> [B, K]
109 | qk += tl.sum(q[:, None, :] * k, axis=2) * qk_scale
110 | # compute m_ij and l_ij
111 | m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
112 | p = tl.math.exp2(qk - m_ij[:, None])
113 | l_ij = tl.sum(p, axis=1)
114 | # scale acc_o
115 | acc_o_scale = tl.math.exp2(m_i - m_ij)
116 | acc_o = acc_o * acc_o_scale[:, None]
117 | # load v and update acc_o
118 | v = tl.load(v_ptrs, boundary_check=(0, 1, 2), padding_option="zero")
119 | p = p.to(v.dtype)
120 | # [B, K], [B, K, D] -> [B, D]
121 | acc_o += tl.sum(p[:, :, None] * v, axis=1)
122 | # update statistics
123 | m_i = m_ij
124 | lse_i = m_ij + tl.math.log2(tl.math.exp2(lse_i - m_ij) + l_ij)
125 | # update ptrs
126 | k_ptrs = tl.advance(k_ptrs, (0, BLOCK_SIZE_K, 0))
127 | v_ptrs = tl.advance(v_ptrs, (0, BLOCK_SIZE_K, 0))
128 | # final scale
129 | acc_o = acc_o * tl.math.exp2(m_i - lse_i)[:, None]
130 | # save output
131 | o_ptrs = tl.make_block_ptr(
132 | base=o_ptr + pid_h * stride_oh,
133 | shape=(BATCH_SIZE, HEAD_DIM),
134 | strides=(stride_ob, stride_od),
135 | offsets=(pid_b * BLOCK_SIZE_B, 0),
136 | block_shape=(BLOCK_SIZE_B, BLOCK_SIZE_D),
137 | order=(1, 0),
138 | )
139 | tl.store(o_ptrs, acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1))
140 |
141 |
142 | def flash_attention_decode(
143 | q: torch.Tensor, # [batch_size, num_heads, head_dim]
144 | k: torch.Tensor, # [batch_size, max_len, num_heads, head_dim]
145 | v: torch.Tensor,
146 | seqlens: torch.Tensor, # [batch_size, ]
147 | sm_scale: Optional[float] = None,
148 | ) -> torch.Tensor:
149 | """Flash attention for decoding (one query token per sequence).
150 |
151 | Args:
152 | q (torch.Tensor): query, shape [batch_size, num_q_heads, head_dim]
153 | k (torch.Tensor): key, shape [batch_size, kv_len, num_kv_heads, head_dim]
154 | v (torch.Tensor): value, shape [batch_size, kv_len, num_kv_heads, head_dim]
155 | seqlens (torch.Tensor): kv length for each sequence
156 | sm_scale (Optional[float]): softmax scale, default to 1/sqrt(head_dim)
157 |
158 | Returns:
159 | torch.Tensor: attention output
160 | """
161 | # dtype check
162 | assert q.dtype == torch.bfloat16 or q.dtype == torch.float16
163 | assert k.dtype == q.dtype and v.dtype == q.dtype
164 | assert seqlens.dtype == torch.int32
165 | # shape
166 | batch_size, num_q_heads, head_dim = q.shape
167 | _, k_len, num_k_heads, head_dim = k.shape
168 | _, v_len, num_v_heads, head_dim = v.shape
169 | assert k_len == v_len and batch_size == seqlens.shape[0]
170 | # gqa
171 | assert num_k_heads == num_v_heads
172 | assert num_q_heads % num_k_heads == 0
173 | num_share_q_heads = num_q_heads // num_k_heads
174 | # sm scale
175 | if sm_scale is None:
176 | sm_scale = 1 / math.sqrt(head_dim)
177 | # output tensor
178 | o = torch.zeros_like(q)
179 | # launch kernel
180 | num_warps = 4 if head_dim <= 64 else 8
181 | num_stages = 3
182 | # there is a bug in triton 3.0.0 if BLOCK_SIZE_B > 16
183 | BLOCK_SIZE_B = min(16, triton.next_power_of_2(batch_size))
184 | BLOCK_SIZE_K = 128
185 | BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
186 | grid = (num_q_heads, triton.cdiv(batch_size, BLOCK_SIZE_B))
187 | decode_kernel[grid](
188 | q,
189 | k,
190 | v,
191 | o,
192 | seqlens,
193 | batch_size,
194 | num_share_q_heads,
195 | head_dim,
196 | sm_scale,
197 | q.stride(0),
198 | q.stride(1),
199 | q.stride(2),
200 | k.stride(0),
201 | k.stride(1),
202 | k.stride(2),
203 | k.stride(3),
204 | v.stride(0),
205 | v.stride(1),
206 | v.stride(2),
207 | v.stride(3),
208 | o.stride(0),
209 | o.stride(1),
210 | o.stride(2),
211 | BLOCK_SIZE_B=BLOCK_SIZE_B,
212 | BLOCK_SIZE_K=BLOCK_SIZE_K,
213 | BLOCK_SIZE_D=BLOCK_SIZE_D,
214 | num_warps=num_warps,
215 | num_stages=num_stages,
216 | )
217 | return o
218 |
219 |
220 | def torch_attention_decode(
221 | q: torch.Tensor, # [batch_size, num_heads, head_dim]
222 | k: torch.Tensor, # [batch_size, max_len, num_heads, head_dim]
223 | v: torch.Tensor,
224 | seqlens: torch.Tensor, # [batch_size, ]
225 | sm_scale: Optional[float] = None,
226 | ):
227 | # dtype check
228 | assert q.dtype == torch.bfloat16 or q.dtype == torch.float16
229 | assert k.dtype == q.dtype and v.dtype == q.dtype
230 | assert seqlens.dtype == torch.int32
231 | # shape
232 | batch_size, num_q_heads, head_dim = q.shape
233 | _, k_len, num_k_heads, head_dim = k.shape
234 | _, v_len, num_v_heads, head_dim = v.shape
235 | assert k_len == v_len and batch_size == seqlens.shape[0]
236 | # gqa
237 | assert num_k_heads == num_v_heads
238 | assert num_q_heads % num_k_heads == 0
239 | num_share_q_heads = num_q_heads // num_k_heads
240 | # sm scale
241 | if sm_scale is None:
242 | sm_scale = 1 / math.sqrt(head_dim)
243 | # attention
244 | attn = (
245 | torch.einsum(
246 | "bqhd,bkhd->bhqk",
247 | q.unsqueeze(1),
248 | k.repeat_interleave(num_share_q_heads, dim=2),
249 | )
250 | * sm_scale
251 | )
252 | mask = torch.arange(k_len, device=q.device)[None, :] < seqlens[:, None]
253 | attn = attn.masked_fill(~mask[:, None, None, :], -torch.inf)
254 | attn = torch.softmax(attn, dim=-1, dtype=torch.float32).to(q.dtype)
255 | out = torch.einsum(
256 | "bhqk,bkhd->bqhd", attn, v.repeat_interleave(num_share_q_heads, dim=2)
257 | ).squeeze(1)
258 | return out
259 |
260 |
261 | if __name__ == "__main__":
262 | torch.manual_seed(42)
263 | batch_size = 76
264 | max_length = 8192
265 | seqlens = torch.arange(batch_size, dtype=torch.int32).cuda() * 128 + 1
266 | seqlens[seqlens > max_length] = max_length
267 | seqlens = seqlens[torch.randn_like(seqlens, dtype=torch.float32).argsort(-1)]
268 | q = (
269 | torch.empty(batch_size, 32, 128, device="cuda")
270 | .uniform_(-1, 1)
271 | .to(torch.bfloat16)
272 | )
273 | k = (
274 | torch.empty(batch_size, max_length, 4, 128, device="cuda")
275 | .uniform_(-1, 1)
276 | .to(torch.bfloat16)
277 | )
278 | v = (
279 | torch.empty(batch_size, max_length, 4, 128, device="cuda")
280 | .uniform_(-1, 1)
281 | .to(torch.bfloat16)
282 | )
283 |
284 | o1 = torch_attention_decode(q, k, v, seqlens)
285 | o2 = flash_attention_decode(q, k, v, seqlens)
286 |
287 | print(torch.allclose(o1, o2, atol=1e-2, rtol=1e-2))
288 |
--------------------------------------------------------------------------------
/native_sparse_attention/ops/triton/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import torch
15 |
16 |
17 | def is_hopper_gpu():
18 | if torch.cuda.is_available():
19 | device_capability = torch.cuda.get_device_capability()
20 | major, minor = device_capability
21 | return major == 9
22 | return False
23 |
24 |
25 | def get_compressed_seqlens(
26 | cu_seqlens: torch.Tensor, kernel_size: int, kernel_stride: int
27 | ):
28 | # compute seqlens after compression
29 | seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
30 | y_seqlens = torch.floor((seqlens - kernel_size) / kernel_stride).to(torch.int32) + 1
31 | # corner case: if sequence_length < kernel_size, no compression is applied to this sequence
32 | y_seqlens[seqlens < kernel_size] = 0
33 | y_cu_seqlens = torch.zeros(
34 | y_seqlens.shape[0] + 1, dtype=torch.int32, device=cu_seqlens.device
35 | )
36 | y_cu_seqlens[1:] = torch.cumsum(y_seqlens, dim=0)
37 | return y_seqlens, y_cu_seqlens
38 |
39 |
40 | def get_num_warps_stages(head_dim, block_size, is_hopper_gpu):
41 | """
42 | Returns recommended num_warps and num_stages for a Sparse Attention kernel in Triton.
43 |
44 | Args:
45 | head_dim (int): Size of the head dimension.
46 | block_size (int): Size of the block in the attention matrix.
47 | is_hopper_gpu (bool): True if Hopper GPU, False if Ampere GPU.
48 |
49 | Returns:
50 | tuple: (num_warps, num_stages) recommended values.
51 | """
52 | # Determine if head_dim and block_size exceed 64
53 | head_large = head_dim > 64
54 | block_large = block_size > 64
55 |
56 | if is_hopper_gpu:
57 | # Hopper GPU recommendations
58 | if head_large and block_large:
59 | num_warps = 8
60 | num_stages = 3
61 | elif head_large or block_large:
62 | num_warps = 4
63 | num_stages = 3
64 | else:
65 | num_warps = 2
66 | num_stages = 2
67 | else:
68 | # Ampere GPU recommendations
69 | if head_large and block_large:
70 | num_warps = 8
71 | num_stages = 3
72 | elif head_large or block_large:
73 | num_warps = 8
74 | num_stages = 3
75 | else:
76 | num_warps = 2
77 | num_stages = 2
78 | if head_dim > 128:
79 | num_stages = 2
80 | return num_warps, num_stages
81 |
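82 | # Usage sketch (illustrative only; values are arbitrary): these helpers are typically
83 | # combined when configuring the Triton kernels in this package.
84 | #
85 | #     cu_seqlens = torch.tensor([0, 1000, 3000], dtype=torch.int32, device="cuda")
86 | #     y_seqlens, y_cu_seqlens = get_compressed_seqlens(cu_seqlens, kernel_size=32, kernel_stride=16)
87 | #     num_warps, num_stages = get_num_warps_stages(
88 | #         head_dim=128, block_size=64, is_hopper_gpu=is_hopper_gpu()
89 | #     )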
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from setuptools import setup, find_packages
16 |
17 | # Read the README.md file for the long description
18 | with open("README.md", "r", encoding="utf-8") as fh:
19 | long_description = fh.read()
20 |
21 | # Define the setup configuration
22 | setup(
23 | name="native-sparse-attention-triton",
24 | version="0.1.0",
25 | description="An efficient implementation of Native Sparse Attention using Triton",
26 | long_description=long_description,
27 | long_description_content_type="text/markdown",
28 | author="XunhaoLai",
29 | author_email="laixunhao@pku.edu.cn",
30 | url="https://github.com/XunhaoLai/native-sparse-attention-triton",
31 | packages=find_packages(),
32 | install_requires=[
33 | "torch>=2.1.0",
34 | "triton>=3.0.0",
35 | "einops>=0.7.0",
36 | "flash-attn>=2.6.3",
37 | "transformers>=4.44.0",
38 | ],
39 | classifiers=[
40 | "Programming Language :: Python :: 3",
41 | "License :: OSI Approved :: Apache Software License",
42 | "Operating System :: OS Independent",
43 | ],
44 | python_requires=">=3.9",
45 | )
46 |
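47 | # Local installation sketch (editable install; a CUDA toolchain is assumed, since
48 | # flash-attn builds against it):
49 | #
50 | #     pip install -e .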
--------------------------------------------------------------------------------
/test/test_compress_key_value.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and limitations under the License.
13 | import torch
14 | import triton
15 | from native_sparse_attention.ops import linear_compress
16 |
17 |
18 | if __name__ == "__main__":
19 | torch.manual_seed(42)
20 | num_heads = 4
21 | head_dim = 192
22 | kernel_size = 32
23 | kernel_stride = 16
24 | seqlens = torch.LongTensor([1000, 2000, 4096]).int().cuda()
25 | cu_seqlens = torch.cat(
26 | [
27 | torch.zeros(1, dtype=torch.int32, device="cuda"),
28 | torch.cumsum(seqlens, dim=0),
29 | ],
30 | dim=0,
31 | ).to(torch.int32)
32 |
33 | x = (
34 | torch.zeros(cu_seqlens[-1], num_heads, head_dim)
35 | .uniform_(-1, 1)
36 | .cuda()
37 | .bfloat16()
38 | .requires_grad_()
39 | )
40 | w = (
41 | torch.zeros(num_heads, kernel_size * head_dim, head_dim)
42 | .uniform_(-1, 1)
43 | .cuda()
44 | .bfloat16()
45 | .requires_grad_()
46 | )
47 | pe = (
48 | torch.zeros(num_heads, kernel_size, head_dim)
49 | .uniform_(-1, 1)
50 | .cuda()
51 | .bfloat16()
52 | .requires_grad_()
53 | )
54 |
55 | y, y_cu_seqlens = linear_compress(x, w, cu_seqlens, kernel_size, kernel_stride, pe)
56 |
57 | loss = (y * torch.randn_like(y)).mean()
58 | loss.backward()
59 |
60 | print(y.shape, y_cu_seqlens)
61 | print(y.norm(), x.grad.norm())
62 | print(
63 | w.grad.norm() if w.grad is not None else None,
64 | pe.grad.norm() if pe.grad is not None else None,
65 | )
66 |
67 | # benchmark
68 | @triton.testing.perf_report(
69 | triton.testing.Benchmark(
70 | x_names=["N"],
71 | x_vals=[1024 * 2**i for i in range(1, 6)],
72 | line_arg="provider",
73 | line_vals=["batch1", "batch8", "batch32"],
74 | line_names=["batch1", "batch8", "batch32"],
75 | styles=[("green", "-"), ("blue", "-"), ("blue", "--")],
76 | ylabel="ms",
77 | plot_name="** forward **",
78 | args={"H": 4, "D": 128},
79 | )
80 | )
81 | def benchmark(N, H, D, provider):
82 | K, S = 32, 16
83 | x = torch.zeros(N, H, D, device="cuda", dtype=torch.bfloat16).uniform_(-1, 1)
84 | w = torch.zeros(H, K * D, D, device="cuda", dtype=torch.bfloat16).uniform_(
85 | -1, 1
86 | )
87 | pe = torch.zeros(H, K, D, device="cuda", dtype=torch.bfloat16).uniform_(-1, 1)
88 | cu_seqlens_b1 = torch.LongTensor([0, N]).int().cuda()
89 | cu_seqlens_b8 = (
90 | torch.LongTensor([N // 8 if i > 0 else 0 for i in range(9)]).int().cuda()
91 | )
92 | cu_seqlens_b32 = (
93 | torch.LongTensor([N // 32 if i > 0 else 0 for i in range(33)]).int().cuda()
94 | )
95 | cu_seqlens_b1 = cu_seqlens_b1.cumsum(0).to(torch.int32)
96 | cu_seqlens_b8 = cu_seqlens_b8.cumsum(0).to(torch.int32)
97 | cu_seqlens_b32 = cu_seqlens_b32.cumsum(0).to(torch.int32)
98 |
99 | quantiles = [0.5, 0.2, 0.8]
100 | if provider == "batch1":
101 | ms, min_ms, max_ms = triton.testing.do_bench(
102 | lambda: linear_compress(x, w, cu_seqlens_b1, K, S, pe),
103 | quantiles=quantiles,
104 | )
105 | if provider == "batch8":
106 | ms, min_ms, max_ms = triton.testing.do_bench(
107 | lambda: linear_compress(x, w, cu_seqlens_b8, K, S, pe),
108 | quantiles=quantiles,
109 | )
110 | if provider == "batch32":
111 | ms, min_ms, max_ms = triton.testing.do_bench(
112 | lambda: linear_compress(x, w, cu_seqlens_b32, K, S, pe),
113 | quantiles=quantiles,
114 | )
115 | return ms, min_ms, max_ms
116 |
117 | benchmark.run(show_plots=True, print_data=True)
118 |
--------------------------------------------------------------------------------
/test/test_compressed_attention.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and limitations under the License.
13 | import torch
14 | import triton
15 | import math
16 | from native_sparse_attention.ops.torch.compressed_attention import (
17 | compressed_attention_torch,
18 | )
19 | from native_sparse_attention.ops.triton.compressed_attention import (
20 | compressed_attention,
21 | _compressed_attention_bwd,
22 | )
23 | from native_sparse_attention.ops import avgpool_compress
24 | from native_sparse_attention.ops.triton.flash_attention import (
25 | flash_attention_varlen,
26 | _flash_attention_bwd,
27 | )
28 | from flash_attn import flash_attn_varlen_func
29 | from flash_attn.flash_attn_interface import _flash_attn_varlen_backward
30 |
31 |
32 | if __name__ == "__main__":
33 | torch.manual_seed(42)
34 | num_heads = 32
35 | head_dim = 96
36 | kernel_size = 32
37 | kernel_stride = 16
38 | block_size = 64
39 | topk = 16
40 | seqlens = torch.LongTensor([1000, 4000, 8192]).int().cuda()
41 | cu_seqlens = torch.cat(
42 | [
43 | torch.zeros(1, dtype=torch.int32, device="cuda"),
44 | torch.cumsum(seqlens, dim=0),
45 | ],
46 | dim=0,
47 | ).to(torch.int32)
48 | max_seqlen = seqlens.max().item()
49 | q = (
50 | torch.empty(cu_seqlens[-1], num_heads, head_dim, device="cuda")
51 | .uniform_(-1, 1)
52 | .to(torch.float16)
53 | )
54 | k = (
55 | torch.empty(cu_seqlens[-1], num_heads // 4, head_dim, device="cuda")
56 | .uniform_(-1, 1)
57 | .to(torch.float16)
58 | )
59 | v = (
60 | torch.empty(cu_seqlens[-1], num_heads // 4, head_dim, device="cuda")
61 | .uniform_(-1, 1)
62 | .to(torch.float16)
63 | )
64 | q.requires_grad = True
65 | k.requires_grad = True
66 | v.requires_grad = True
67 |
68 | ck, ck_cu_seqlens = avgpool_compress(
69 | k, None, cu_seqlens, kernel_size, kernel_stride
70 | )
71 |
72 | ck = torch.empty_like(ck).uniform_(-1, 1)
73 | cv = torch.empty_like(ck).uniform_(-1, 1)
74 | ck.requires_grad = True
75 | cv.requires_grad = True
76 |
77 | ck_seqlens = ck_cu_seqlens[1:] - ck_cu_seqlens[:-1]
78 | ck_max_seqlen = ck_seqlens.max().item()
79 |
80 | o, topk_idx = compressed_attention_torch(
81 | q,
82 | ck,
83 | cv,
84 | kernel_size,
85 | kernel_stride,
86 | block_size,
87 | topk,
88 | cu_seqlens,
89 | ck_cu_seqlens,
90 | max_seqlen,
91 | ck_max_seqlen,
92 | )
93 |
94 | randn = torch.randn_like(o)
95 | loss = (o * randn).sum()
96 | loss.backward()
97 |
98 | torch.manual_seed(42)
99 |
100 | q1 = q.detach().clone().requires_grad_()
101 | ck1 = ck.detach().clone().requires_grad_()
102 | cv1 = cv.detach().clone().requires_grad_()
103 |
104 | o1, topk_idx1 = compressed_attention(
105 | q1,
106 | ck1,
107 | cv1,
108 | kernel_size,
109 | kernel_stride,
110 | block_size,
111 | topk,
112 | cu_seqlens,
113 | ck_cu_seqlens,
114 | max_seqlen,
115 | ck_max_seqlen,
116 | )
117 | randn1 = randn.clone().detach()
118 | loss1 = (o1 * randn1).sum()
119 | loss1.backward()
120 |
121 | print("Same Output:", torch.allclose(o, o1, atol=0.01, rtol=0.01))
122 | print("Max Error:", (o - o1).abs().max().item())
123 | print()
124 | print("Same Query Gradient:", torch.allclose(q.grad, q1.grad, atol=0.01, rtol=0.01))
125 | print("Max Query Gradient Error:", (q.grad - q1.grad).abs().max().item())
126 | print()
127 | print("Same Key Gradient:", torch.allclose(ck.grad, ck1.grad, atol=0.01, rtol=0.01))
128 | print("Max Key Gradient Error:", (ck.grad - ck1.grad).abs().max().item())
129 | print()
130 | print(
131 | "Same Value Gradient:", torch.allclose(cv.grad, cv1.grad, atol=0.01, rtol=0.01)
132 | )
133 | print("Max Value Gradient Error:", (cv.grad - cv1.grad).abs().max().item())
134 | print()
135 |
136 | # There are some discrepancies in the topk indices (about 3%). These might be due to bugs and will be addressed later.
137 | all_num = 0
138 | err_num = 0
139 | for h in range(topk_idx.shape[0]):
140 | for i in range(topk_idx.shape[1]):
141 | s = set(topk_idx[h, i][topk_idx[h, i] >= 0].tolist())
142 | s1 = set(topk_idx1[h, i][topk_idx1[h, i] >= 0].tolist())
143 | all_num += len(s)
144 | err_num += len(s) - len(s1 & s)
145 | print("Topk Idx Error Rate:", err_num / all_num)
146 |
147 | # benchmark
148 | @triton.testing.perf_report(
149 | triton.testing.Benchmark(
150 | x_names=["N"],
151 | x_vals=[1024 * 2**i for i in range(1, 8)],
152 | line_arg="provider",
153 | line_vals=[
154 | "flash",
155 | "triton-flash",
156 | "triton-compressed",
157 | "triton-compressed-wo-score",
158 | ],
159 | line_names=[
160 | "Flash",
161 | "Triton-Flash",
162 | "Compressed",
163 | "Compressed-wo-Score",
164 | ],
165 | styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")],
166 | ylabel="ms",
167 | plot_name="** forward speed for compressed attention (kernel 32 stride 16) **",
168 | args={"H": 64, "D": 128},
169 | )
170 | )
171 | def benchmark(N, H, D, provider):
172 | q = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16)
173 | k = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16)
174 | v = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16)
175 | cu_seqlens = torch.tensor([0, N], device="cuda", dtype=torch.int32)
176 | sm_scale = 1 / math.sqrt(D)
177 | com_k, com_cu_seqlens = avgpool_compress(k, None, cu_seqlens, 32, 16, None)
178 | com_v, com_cu_seqlens = avgpool_compress(v, None, cu_seqlens, 32, 16, None)
179 | M = (com_cu_seqlens[1:] - com_cu_seqlens[:-1]).max().item()
180 |
181 | quantiles = [0.5, 0.2, 0.8]
182 | if provider == "flash":
183 | ms, min_ms, max_ms = triton.testing.do_bench(
184 | lambda: flash_attn_varlen_func(
185 | q,
186 | k,
187 | v,
188 | cu_seqlens,
189 | cu_seqlens,
190 | N,
191 | N,
192 | dropout_p=0.0,
193 | causal=True,
194 | softmax_scale=sm_scale,
195 | ),
196 | quantiles=quantiles,
197 | )
198 | if provider == "triton-flash":
199 | ms, min_ms, max_ms = triton.testing.do_bench(
200 | lambda: flash_attention_varlen(
201 | q, k, v, cu_seqlens, cu_seqlens, N, N, True, sm_scale
202 | ),
203 | quantiles=quantiles,
204 | )
205 | if provider == "triton-compressed":
206 | ms, min_ms, max_ms = triton.testing.do_bench(
207 | lambda: compressed_attention(
208 | q,
209 | com_k,
210 | com_v,
211 | 32,
212 | 16,
213 | 64,
214 | 16,
215 | cu_seqlens,
216 | com_cu_seqlens,
217 | N,
218 | M,
219 | sm_scale,
220 | ),
221 | quantiles=quantiles,
222 | )
223 | if provider == "triton-compressed-wo-score":
224 | ms, min_ms, max_ms = triton.testing.do_bench(
225 | lambda: compressed_attention(
226 | q,
227 | com_k,
228 | com_v,
229 | 32,
230 | 16,
231 | 64,
232 | -1,
233 | cu_seqlens,
234 | com_cu_seqlens,
235 | N,
236 | M,
237 | sm_scale,
238 | ),
239 | quantiles=quantiles,
240 | )
241 | return ms, min_ms, max_ms
242 |
243 | benchmark.run(show_plots=True, print_data=True)
244 |
245 | # benchmark
246 | @triton.testing.perf_report(
247 | triton.testing.Benchmark(
248 | x_names=["N"],
249 | x_vals=[1024 * 2**i for i in range(1, 8)],
250 | line_arg="provider",
251 | line_vals=[
252 | "flash",
253 | "triton-flash",
254 | "triton-compressed",
255 | ],
256 | line_names=[
257 | "Flash",
258 | "Triton-Flash",
259 | "Compressed",
260 | ],
261 | styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")],
262 | ylabel="ms",
263 | plot_name="** backward speed for compressed attention (kernel 32 stride 16) **",
264 | args={"H": 64, "D": 128},
265 | )
266 | )
267 | def benchmark(N, H, D, provider):
268 | q = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16)
269 | k = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16)
270 | v = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16)
271 | o = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16)
272 | do = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16)
273 | lse = torch.randn((H, N), device="cuda", dtype=torch.float32)
274 | sm_scale = 1 / math.sqrt(D)
275 | cu_seqlens = torch.tensor([0, N], device="cuda", dtype=torch.int32)
276 | dq = torch.zeros_like(q)
277 | dk = torch.zeros_like(k)
278 | dv = torch.zeros_like(v)
279 |
280 | com_k, com_cu_seqlens = avgpool_compress(k, None, cu_seqlens, 32, 16, None)
281 | com_v, com_cu_seqlens = avgpool_compress(v, None, cu_seqlens, 32, 16, None)
282 | M = (com_cu_seqlens[1:] - com_cu_seqlens[:-1]).max().item()
283 |
284 | quantiles = [0.5, 0.2, 0.8]
285 | if provider == "flash":
286 | ms, min_ms, max_ms = triton.testing.do_bench(
287 | lambda: _flash_attn_varlen_backward(
288 | do,
289 | q,
290 | k,
291 | v,
292 | o,
293 | lse.transpose(0, 1),
294 | dq,
295 | dk,
296 | dv,
297 | cu_seqlens,
298 | cu_seqlens,
299 | N,
300 | N,
301 | dropout_p=0.0,
302 | causal=True,
303 | softmax_scale=sm_scale,
304 | window_size=(-1, -1),
305 | softcap=0.0,
306 | alibi_slopes=None,
307 | deterministic=False,
308 | ),
309 | quantiles=quantiles,
310 | )
311 | if provider == "triton-flash":
312 | ms, min_ms, max_ms = triton.testing.do_bench(
313 | lambda: _flash_attention_bwd(
314 | o, do, lse, q, k, v, cu_seqlens, cu_seqlens, N, N, True, sm_scale
315 | ),
316 | quantiles=quantiles,
317 | )
318 | if provider == "triton-compressed":
319 | ms, min_ms, max_ms = triton.testing.do_bench(
320 | lambda: _compressed_attention_bwd(
321 | o,
322 | do,
323 | lse,
324 | q,
325 | com_k,
326 | com_v,
327 | 32,
328 | 16,
329 | cu_seqlens,
330 | com_cu_seqlens,
331 | N,
332 | M,
333 | sm_scale,
334 | ),
335 | quantiles=quantiles,
336 | )
337 | return ms, min_ms, max_ms
338 |
339 | benchmark.run(show_plots=True, print_data=True)
340 |
--------------------------------------------------------------------------------
/test/test_flash_attention.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and limitations under the License.
13 | import torch
14 | import triton
15 | import math
16 | from native_sparse_attention.ops.triton.flash_attention import (
17 | flash_attention_varlen,
18 | _flash_attention_fwd,
19 | _flash_attention_bwd,
20 | )
21 | from flash_attn import flash_attn_varlen_func
22 | from flash_attn.flash_attn_interface import (
23 | _flash_attn_varlen_forward,
24 | _flash_attn_varlen_backward,
25 | )
26 |
27 |
28 | if __name__ == "__main__":
29 | for causal in [False, True]:
30 | # triton flash attention
31 | torch.manual_seed(42)
32 | q = torch.randn(
33 | 1000, 32, 128, dtype=torch.float16, device="cuda", requires_grad=True
34 | )
35 | k = torch.randn(
36 | 1000, 16, 128, dtype=torch.float16, device="cuda", requires_grad=True
37 | )
38 | v = torch.randn(
39 | 1000, 16, 128, dtype=torch.float16, device="cuda", requires_grad=True
40 | )
41 | cu_seqlens_q = torch.Tensor([0, 100, 384, 1000]).cuda().to(torch.int32)
42 | cu_seqlens_k = torch.Tensor([0, 100, 384, 1000]).cuda().to(torch.int32)
43 | max_seqlen_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max()
44 | max_seqlen_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max()
45 | o = flash_attn_varlen_func(
46 | q,
47 | k,
48 | v,
49 | cu_seqlens_q,
50 | cu_seqlens_k,
51 | max_seqlen_q,
52 | max_seqlen_k,
53 | causal=causal,
54 | )
55 | randn = torch.randn_like(o)
56 | loss = (o * randn).sum()
57 | loss.backward()
58 |
59 | # flash attention
60 | torch.manual_seed(42)
61 | q1 = q.clone().detach().requires_grad_()
62 | k1 = k.clone().detach().requires_grad_()
63 | v1 = v.clone().detach().requires_grad_()
64 | cu_seqlens_q1 = cu_seqlens_q.clone().detach()
65 | cu_seqlens_k1 = cu_seqlens_k.clone().detach()
66 | max_seqlen_q1 = (cu_seqlens_q1[1:] - cu_seqlens_q1[:-1]).max()
67 | max_seqlen_k1 = (cu_seqlens_k1[1:] - cu_seqlens_k1[:-1]).max()
68 | o1 = flash_attention_varlen(
69 | q1,
70 | k1,
71 | v1,
72 | cu_seqlens_q1,
73 | cu_seqlens_k1,
74 | max_seqlen_q1,
75 | max_seqlen_k1,
76 | causal=causal,
77 | )
78 | randn2 = randn.clone().detach()
79 | loss2 = (o1 * randn2).sum()
80 | loss2.backward()
81 |
82 | # diff
83 | print(
84 | f"=== Flash Attention Backward Test ({'causal' if causal else 'full'}) ==="
85 | )
86 | print("Same Output:", torch.allclose(o, o1, atol=0.01, rtol=0.01))
87 | print("Max Error:", (o - o1).abs().max().item())
88 | print()
89 | print(
90 | "Same Query Gradient:",
91 | torch.allclose(q.grad, q1.grad, atol=0.01, rtol=0.01),
92 | )
93 | print("Max Query Gradient Error:", (q.grad - q1.grad).abs().max().item())
94 | print()
95 | print(
96 | "Same Key Gradient:", torch.allclose(k.grad, k1.grad, atol=0.01, rtol=0.01)
97 | )
98 | print("Max Key Gradient Error:", (k.grad - k1.grad).abs().max().item())
99 | print()
100 | print(
101 | "Same Value Gradient:",
102 | torch.allclose(v.grad, v1.grad, atol=0.01, rtol=0.01),
103 | )
104 | print("Max Value Gradient Error:", (v.grad - v1.grad).abs().max().item())
105 | print()
106 |
107 | # benchmark
108 | @triton.testing.perf_report(
109 | triton.testing.Benchmark(
110 | x_names=["N"],
111 | x_vals=[1024 * 2**i for i in range(1, 6)],
112 | line_arg="provider",
113 | line_vals=["flash", "triton-flash"],
114 | line_names=[
115 | "Flash",
116 | "Triton-Flash",
117 | ],
118 | styles=[("green", "-"), ("green", "--")],
119 | ylabel="ms",
120 | plot_name="** forward **",
121 | args={"H": 64, "D": 128},
122 | )
123 | )
124 | def benchmark(N, H, D, provider):
125 | q = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16)
126 | k = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16)
127 | v = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16)
128 | cu_seqlens = torch.tensor([0, N], device="cuda", dtype=torch.int32)
129 | sm_scale = 1 / math.sqrt(D)
130 |
131 | quantiles = [0.5, 0.2, 0.8]
132 | if provider == "flash":
133 | ms, min_ms, max_ms = triton.testing.do_bench(
134 | lambda: _flash_attn_varlen_forward(
135 | q,
136 | k,
137 | v,
138 | cu_seqlens,
139 | cu_seqlens,
140 | N,
141 | N,
142 | dropout_p=0.0,
143 | causal=True,
144 | softmax_scale=sm_scale,
145 | ),
146 | quantiles=quantiles,
147 | )
148 | if provider == "triton-flash":
149 | ms, min_ms, max_ms = triton.testing.do_bench(
150 | lambda: _flash_attention_fwd(
151 | q, k, v, cu_seqlens, cu_seqlens, N, N, True, sm_scale
152 | ),
153 | quantiles=quantiles,
154 | )
155 | return ms, min_ms, max_ms
156 |
157 | benchmark.run(show_plots=True, print_data=True)
158 |
159 | # benchmark
160 | @triton.testing.perf_report(
161 | triton.testing.Benchmark(
162 | x_names=["N"],
163 | x_vals=[1024 * 2**i for i in range(1, 6)],
164 | line_arg="provider",
165 | line_vals=["flash", "triton-flash"],
166 | line_names=[
167 | "Flash",
168 | "Triton-Flash",
169 | ],
170 | styles=[("green", "-"), ("green", "--")],
171 | ylabel="ms",
172 | plot_name="** backward **",
173 | args={"H": 64, "D": 128},
174 | )
175 | )
176 | def benchmark(N, H, D, provider):
177 | q = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16)
178 | k = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16)
179 | v = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16)
180 | o = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16)
181 | do = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16)
182 | lse = torch.randn((H, N), device="cuda", dtype=torch.float32)
183 | sm_scale = 1 / math.sqrt(D)
184 | cu_seqlens = torch.tensor([0, N], device="cuda", dtype=torch.int32)
185 | dq = torch.zeros_like(q)
186 | dk = torch.zeros_like(k)
187 | dv = torch.zeros_like(v)
188 |
189 | quantiles = [0.5, 0.2, 0.8]
190 | if provider == "flash":
191 | ms, min_ms, max_ms = triton.testing.do_bench(
192 | lambda: _flash_attn_varlen_backward(
193 | do,
194 | q,
195 | k,
196 | v,
197 | o,
198 | lse.transpose(0, 1),
199 | dq,
200 | dk,
201 | dv,
202 | cu_seqlens,
203 | cu_seqlens,
204 | N,
205 | N,
206 | dropout_p=0.0,
207 | causal=True,
208 | softmax_scale=sm_scale,
209 | window_size=(-1, -1),
210 | softcap=0.0,
211 | alibi_slopes=None,
212 | deterministic=False,
213 | ),
214 | quantiles=quantiles,
215 | )
216 | if provider == "triton-flash":
217 | ms, min_ms, max_ms = triton.testing.do_bench(
218 | lambda: _flash_attention_bwd(
219 | o, do, lse, q, k, v, cu_seqlens, cu_seqlens, N, N, True, sm_scale
220 | ),
221 | quantiles=quantiles,
222 | )
223 | return ms, min_ms, max_ms
224 |
225 | benchmark.run(show_plots=True, print_data=True)
226 |
--------------------------------------------------------------------------------
/test/test_kv_cache.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import torch
15 | from native_sparse_attention.module.kv_cache import NSACache
16 |
17 |
18 | if __name__ == "__main__":
19 | from native_sparse_attention.ops import avgpool_compress
20 |
21 | torch.manual_seed(42)
22 |
23 | num_heads = 4
24 | head_dim = 128
25 | seqlens = torch.tensor([12, 576, 12000]).to(torch.int32).cuda()
26 | batch_size = seqlens.shape[0]
27 | cu_seqlens = torch.zeros(seqlens.shape[0] + 1, dtype=torch.int32, device="cuda")
28 | cu_seqlens[1:] = seqlens.cumsum(0)
29 |
30 | # init cache
31 | cache = NSACache(4, 16384, num_heads, head_dim, 32, 16, 512, torch.bfloat16, "cuda")
32 |
33 | # test prefill
34 | step = 0
35 | k = torch.randn(cu_seqlens[-1], num_heads, head_dim).cuda().bfloat16()
36 | v = torch.randn_like(k)
37 | ck, _ = avgpool_compress(k, None, cu_seqlens, 32, 16, None)
38 | cv, _ = avgpool_compress(v, None, cu_seqlens, 32, 16, None)
39 | cache.prepare_compress(cu_seqlens, step, k, v)
40 | cache.update_kv(cu_seqlens, step, ck, cv, k, v, k, v)
41 |
42 | # test decode
43 | step = 1
44 | k = torch.randn(batch_size, num_heads, head_dim).cuda().bfloat16()
45 | v = torch.randn_like(k)
46 | ck = torch.randn_like(k)
47 | cv = torch.randn_like(v)
48 | cache.prepare_compress(cu_seqlens, step, k, v)
49 | cache.update_kv(cu_seqlens, step, ck, cv, k, v, k, v)
50 |
--------------------------------------------------------------------------------
/test/test_linear_compress.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 | import triton
17 | from native_sparse_attention.ops.torch.compress_key_value import linear_compress_torch
18 | from native_sparse_attention.ops.triton.linear_compress import linear_compress
19 |
20 |
21 | def test_linear_compress(
22 | batch_size: int = 1,
23 | num_heads: int = 1,
24 | head_dim: int = 32,
25 | max_seqlen: int = 32,
26 | kernel_sizes: list = [16, 32],
27 | kernel_strides: list = [8, 16],
28 | use_pe: bool = True,
29 | dtype: torch.dtype = torch.float32,
30 | device: str = "cuda",
31 | ):
32 | """
33 | Test both PyTorch and Triton implementations of linear_compress for equivalence,
34 | including forward and backward passes.
35 |
36 | Args:
37 | batch_size: Number of sequences in the batch
38 | num_heads: Number of attention heads
39 | head_dim: Dimension of each attention head
40 | max_seqlen: Maximum sequence length
41 | kernel_sizes: List of kernel sizes to test
42 | kernel_strides: List of kernel strides to test
43 | use_pe: Whether to test with positional encoding
44 | dtype: Data type for tensors
45 | device: Device to run the test on
46 | """
47 | torch.manual_seed(42)
48 |
49 | # Generate random sequence lengths for each batch
50 |
51 | seqlens = torch.randint(
52 | low=kernel_sizes[0], # minimum length should be at least kernel_size
53 | high=max_seqlen + 1,
54 | size=(batch_size,),
55 | device=device,
56 | )
57 | # seqlens[:] = max_seqlen
58 | cu_seqlens = torch.cat(
59 | [
60 | torch.tensor([0], device=device, dtype=torch.int32),
61 | torch.cumsum(seqlens, dim=0).to(torch.int32),
62 | ]
63 | )
64 |
65 | total_len = cu_seqlens[-1].item()
66 |
67 | for kernel_size, kernel_stride in zip(kernel_sizes, kernel_strides):
68 | print(f"\nTesting kernel_size={kernel_size}, kernel_stride={kernel_stride}")
69 |
70 | # Create input tensors with requires_grad=True
71 | x_torch = torch.zeros(
72 | (total_len, num_heads, head_dim),
73 | dtype=dtype,
74 | device=device,
75 | ).uniform_(-1, 1)
76 | x_torch.requires_grad_(True)
77 |
78 | x_triton = x_torch.clone().detach().requires_grad_(True)
79 |
80 | w_torch = (
81 | torch.ones(
82 | (num_heads, kernel_size * head_dim, head_dim),
83 | dtype=dtype,
84 | device=device,
85 | )
86 | / kernel_size
87 | )
88 | w_torch.requires_grad_(True)
89 |
90 | w_triton = w_torch.clone().detach().requires_grad_(True)
91 |
92 | pe_torch = None
93 | pe_triton = None
94 | if use_pe:
95 | pe_torch = torch.randn(
96 | (num_heads, kernel_size, head_dim),
97 | dtype=dtype,
98 | device=device,
99 | requires_grad=True,
100 | )
101 | pe_triton = pe_torch.clone().detach().requires_grad_(True)
102 |
103 | # Run forward passes
104 | y_torch, y_cu_seqlens_torch = linear_compress_torch(
105 | x=x_torch,
106 | w=w_torch,
107 | cu_seqlens=cu_seqlens,
108 | kernel_size=kernel_size,
109 | kernel_stride=kernel_stride,
110 | pe=pe_torch,
111 | )
112 |
113 | y_triton, y_cu_seqlens_triton = linear_compress(
114 | x=x_triton,
115 | w=w_triton,
116 | cu_seqlens=cu_seqlens,
117 | kernel_size=kernel_size,
118 | kernel_stride=kernel_stride,
119 | pe=pe_triton,
120 | )
121 |
122 | # Check forward pass numerical equivalence
123 | atol, rtol = 1e-2, 1e-2
124 | values_match = torch.allclose(y_torch, y_triton, atol=atol, rtol=rtol)
125 | print(
126 | f"Forward pass - Output values match (atol={atol}, rtol={rtol}): {values_match}"
127 | )
128 | if not values_match:
129 | max_diff = (y_torch - y_triton).abs().max().item()
130 | print(f"Forward pass - Maximum difference: {max_diff}")
131 | print("\nSample values (first batch, first head):")
132 | print("Torch:", y_torch[0, 0, :5])
133 | print("Triton:", y_triton[0, 0, :5])
134 |
135 | # Create random output gradients for backward pass
136 | grad_output = torch.randn_like(y_torch)
137 |
138 | # Run backward passes
139 | y_torch.backward(grad_output)
140 | y_triton.backward(grad_output)
141 |
142 | # Check gradient equivalence
143 | print("\nTesting backward pass:")
144 |
145 | # Check x gradients
146 | x_grads_match = torch.allclose(
147 | x_torch.grad, x_triton.grad, atol=atol, rtol=rtol
148 | )
149 | print(f"x gradients match (atol={atol}, rtol={rtol}): {x_grads_match}")
150 | if not x_grads_match:
151 | max_diff = (x_torch.grad - x_triton.grad).abs().max().item()
152 | print(f"x gradients - Maximum difference: {max_diff}")
153 | print("\nSample x gradients (first batch, first head):")
154 | print("Torch:", x_torch.grad[0, 0, :5])
155 | print("Triton:", x_triton.grad[0, 0, :5])
156 |
157 | # Check w gradients
158 | w_grads_match = torch.allclose(
159 | w_torch.grad, w_triton.grad, atol=atol, rtol=rtol
160 | )
161 | print(f"w gradients match (atol={atol}, rtol={rtol}): {w_grads_match}")
162 | if not w_grads_match:
163 | max_diff = (w_torch.grad - w_triton.grad).abs().max().item()
164 | print(f"w gradients - Maximum difference: {max_diff}")
165 | print("\nSample w gradients (first head):")
166 | print("Torch:", w_torch.grad[0, :5, 0])
167 | print("Triton:", w_triton.grad[0, :5, 0])
168 |
169 | # Check pe gradients if used
170 | if use_pe:
171 | pe_grads_match = torch.allclose(
172 | pe_torch.grad, pe_triton.grad, atol=atol, rtol=rtol
173 | )
174 | print(f"pe gradients match (atol={atol}, rtol={rtol}): {pe_grads_match}")
175 | if not pe_grads_match:
176 | max_diff = (pe_torch.grad - pe_triton.grad).abs().max().item()
177 | print(f"pe gradients - Maximum difference: {max_diff}")
178 | print("\nSample pe gradients (first head):")
179 | print("Torch:", pe_torch.grad[0, :5, 0])
180 | print("Triton:", pe_triton.grad[0, :5, 0])
181 |
182 | # Clean up gradients for next iteration
183 | x_torch.grad = None
184 | x_triton.grad = None
185 | w_torch.grad = None
186 | w_triton.grad = None
187 | if use_pe:
188 | pe_torch.grad = None
189 | pe_triton.grad = None
190 |
191 |
192 | if __name__ == "__main__":
193 | # Run tests
194 | test_linear_compress(
195 | batch_size=16,
196 | num_heads=8,
197 | head_dim=128,
198 | max_seqlen=2048,
199 | kernel_sizes=[32],
200 | kernel_strides=[16],
201 | use_pe=False,
202 | dtype=torch.float16,
203 | device="cuda",
204 | )
205 |
206 | # benchmark
207 | @triton.testing.perf_report(
208 | triton.testing.Benchmark(
209 | x_names=["N"],
210 | x_vals=[1024 * 2**i for i in range(1, 8)],
211 | line_arg="provider",
212 | line_vals=["torch", "triton"],
213 | line_names=["torch", "triton"],
214 | styles=[("green", "-"), ("blue", "-")],
215 | ylabel="ms",
216 | plot_name="** forward + backward **",
217 | args={"H": 4, "D": 64},
218 | )
219 | )
220 | def benchmark_fwdbwd(N, H, D, provider):
221 | K, S = 32, 16
222 | # Input tensors
223 | x = torch.zeros(N, H, D, device="cuda", dtype=torch.bfloat16).uniform_(-1, 1)
224 | x.requires_grad = True
225 | w = torch.zeros(H, K * D, D, device="cuda", dtype=torch.bfloat16).uniform_(
226 | -1, 1
227 | )
228 | w.requires_grad = True
229 | pe = torch.zeros(H, K, D, device="cuda", dtype=torch.bfloat16).uniform_(-1, 1)
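   | # The list below expands to [0, N - 32 * 31, 32, 32, ...]; after cumsum it becomes
   | # cu_seqlens for 32 sequences (one long sequence plus 31 sequences of length 32)
   | # that together cover all N tokens.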
230 | cu_seqlens_b32 = (
231 | torch.LongTensor(
232 | [0 if i == 0 else 32 if i > 1 else N - 32 * 31 for i in range(33)]
233 | )
234 | .int()
235 | .cuda()
236 | )
237 | cu_seqlens_b32 = cu_seqlens_b32.cumsum(0).to(torch.int32)
238 |
239 | quantiles = [0.5, 0.2, 0.8]
240 |
241 | def fwd_bwd():
242 | if provider == "torch":
243 | out, _ = linear_compress_torch(x, w, cu_seqlens_b32, K, S, pe)
244 | else:
245 | out, _ = linear_compress(x, w, cu_seqlens_b32, K, S, pe)
246 | out.backward(out) # Using output as gradient for simplicity
247 | return out
248 |
249 | ms, min_ms, max_ms = triton.testing.do_bench(fwd_bwd, quantiles=quantiles)
250 | return ms, min_ms, max_ms
251 |
252 | benchmark_fwdbwd.run(show_plots=True, print_data=True)
253 |
--------------------------------------------------------------------------------
/test/test_nsa_infer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import torch
15 | from native_sparse_attention.ops import linear_compress, weightedpool_compress
16 | from native_sparse_attention.module import NSACache, RotaryEmbedding, RopeConfig
17 | from native_sparse_attention.infer import nsa_infer
18 |
19 |
20 | if __name__ == "__main__":
21 | torch.manual_seed(42)
22 |
23 | num_heads = 4
24 | head_dim = 128
25 | kernel_size = 32
26 | kernel_stride = 16
27 | block_size = 64
28 | window_size = 512
29 | topk = 16
30 | init_blocks = 1
31 | local_blocks = 2
32 |
33 | # init seqlens
34 | seqlens = torch.tensor([12, 576, 12000]).to(torch.int32).cuda()
35 | batch_size = seqlens.shape[0]
36 | cu_seqlens = torch.zeros(seqlens.shape[0] + 1, dtype=torch.int32, device="cuda")
37 | cu_seqlens[1:] = seqlens.cumsum(0)
38 | step = 0
39 |
40 | # init cache and weight and rope
41 | cache = NSACache(4, 16384, num_heads, head_dim, 32, 16, 512, torch.bfloat16, "cuda")
42 | compress_weight = [
43 | torch.ones(num_heads, kernel_size * head_dim, head_dim).cuda().bfloat16()
44 | / (kernel_size * head_dim),
45 | torch.ones(num_heads, kernel_size).cuda().bfloat16() / kernel_size,
46 | ]
47 | compress_func = [linear_compress, weightedpool_compress]
48 | rope = RotaryEmbedding(
49 | RopeConfig(
50 | max_position_embeddings=131072,
51 | head_dim=128,
52 | rope_theta=500000,
53 | rope_scaling={
54 | "factor": 8.0,
55 | "high_freq_factor": 4.0,
56 | "low_freq_factor": 1.0,
57 | "original_max_position_embeddings": 8192,
58 | "rope_type": "llama3",
59 | },
60 | )
61 | )
62 |
63 | # test prefill
64 | q = torch.randn(cu_seqlens[-1], num_heads * 16, head_dim).cuda().bfloat16()
65 | k = torch.randn(cu_seqlens[-1], num_heads, head_dim).cuda().bfloat16()
66 | v = torch.randn_like(k)
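   | # g presumably holds the per-query-head gate weights for NSA's three branches
   | # (compressed, selected and sliding-window attention), hence the trailing dim of 3.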
67 | g = torch.rand(cu_seqlens[-1], num_heads * 16, 3).cuda().bfloat16()
68 | o = nsa_infer(
69 | cu_seqlens,
70 | step,
71 | q,
72 | k,
73 | v,
74 | g,
75 | rope,
76 | cache,
77 | compress_weight,
78 | compress_func,
79 | None,
80 | kernel_size,
81 | kernel_stride,
82 | block_size,
83 | topk,
84 | init_blocks,
85 | local_blocks,
86 | window_size,
87 | )
88 | print(o.shape, o.norm())
89 |
90 | # test decode
91 | q = torch.randn(cu_seqlens.shape[0] - 1, num_heads * 16, head_dim).cuda().bfloat16()
92 | k = torch.randn(cu_seqlens.shape[0] - 1, num_heads, head_dim).cuda().bfloat16()
93 | v = torch.randn_like(k)
94 | g = torch.rand(cu_seqlens.shape[0] - 1, num_heads * 16, 3).cuda().bfloat16()
95 | step = 1
96 | o = nsa_infer(
97 | cu_seqlens,
98 | step,
99 | q,
100 | k,
101 | v,
102 | g,
103 | rope,
104 | cache,
105 | compress_weight,
106 | compress_func,
107 | None,
108 | kernel_size,
109 | kernel_stride,
110 | block_size,
111 | topk,
112 | init_blocks,
113 | local_blocks,
114 | window_size,
115 | )
116 | print(o.shape, o.norm())
117 |
--------------------------------------------------------------------------------
/test/test_nsa_model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import torch
15 | from native_sparse_attention.model import (
16 | ToyNSALlamaConfig,
17 | InferenceConfig,
18 | ToyNSALlama,
19 | )
20 |
21 |
22 | if __name__ == "__main__":
23 | torch.manual_seed(42)
24 | # initialize model
25 | config = ToyNSALlamaConfig(
26 | hidden_size=4096,
27 | intermediate_size=14336,
28 | num_hidden_layers=8,
29 | num_attention_heads=32,
30 | num_key_value_heads=2,
31 | head_dim=128,
32 | rope_theta=500000.0,
33 | rope_scaling={
34 | "factor": 8.0,
35 | "high_freq_factor": 4.0,
36 | "low_freq_factor": 1.0,
37 | "original_max_position_embeddings": 8192,
38 | "rope_type": "llama3",
39 | },
40 | compress_type="weightedpool",
41 | kernel_size=32,
42 | kernel_stride=16,
43 | block_size=64,
44 | topk=8,
45 | init_blocks=1,
46 | local_blocks=2,
47 | window_size=512,
48 | )
49 | inference_config = InferenceConfig(
50 | max_batch_size=4,
51 | max_length=8192,
52 | max_new_tokens=128,
53 | )
54 | model = ToyNSALlama(config, inference_config).cuda().bfloat16()
55 | print(f"\nMODEL CONFIG:\n{config}\n")
56 | print(f"\nINFERENCE CONFIG:\n{inference_config}\n")
57 | print(f"\nMODEL:\n{model}\n")
58 |
59 | # example input
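   | # Inputs follow the varlen convention: input_ids is a flat token stream and
   | # cu_seqlens marks the per-sequence boundaries.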
60 | batch_size = 4
61 | seqlens = torch.randint(1, 4096, (batch_size,), dtype=torch.int32, device="cuda")  # low=1 avoids zero-length sequences
62 | cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda")
63 | cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
64 | input_ids = torch.randint(
65 | 0, 128288, (cu_seqlens[-1],), dtype=torch.int64, device="cuda"
66 | )
67 | print(f"\nEXAMPLE INPUT:\ncu_seqlens: {cu_seqlens}\ninput_ids: {input_ids.shape}\n")
68 |
69 | # example output
70 | logits = model(input_ids, cu_seqlens)
71 | print(f"\nEXAMPLE OUTPUT:\nlogits: {logits.shape}\n")
72 |
73 | # example generate
74 | output_tokens = model.generate(input_ids, cu_seqlens, 64)
75 | print(f"\nEXAMPLE GENERATE:\noutput_tokens: {output_tokens}\n")
76 |
--------------------------------------------------------------------------------
/test/test_nsa_module.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and limitations under the License.
13 | import torch
14 | import triton
15 | from native_sparse_attention.module import (
16 | SelfAttention,
17 | NativeSparseAttention,
18 | RopeConfig,
19 | )
20 |
21 |
22 | if __name__ == "__main__":
23 | torch.manual_seed(42)
24 | NSA = (
25 | NativeSparseAttention(
26 | compress_type="avgpool",
27 | hidden_size=8192,
28 | num_q_heads=64,
29 | num_kv_heads=4,
30 | head_dim=128,
31 | kernel_size=32,
32 | kernel_stride=16,
33 | block_size=64,
34 | topk=16,
35 | init_blocks=1,
36 | local_blocks=2,
37 | window_size=512,
38 | rope_config=RopeConfig(
39 | max_position_embeddings=131072,
40 | head_dim=128,
41 | rope_theta=500000,
42 | rope_scaling={
43 | "factor": 8.0,
44 | "high_freq_factor": 4.0,
45 | "low_freq_factor": 1.0,
46 | "original_max_position_embeddings": 8192,
47 | "rope_type": "llama3",
48 | },
49 | ),
50 | )
51 | .cuda()
52 | .to(torch.bfloat16)
53 | )
54 | print("======= Init Moduel: Native Sparse Attention =======\n")
55 | for name, param in NSA.named_parameters():
56 | print(f"NSA Parameters, {name}, shape: {param.shape}\n")
57 |
58 | # random input
59 | seqlens = torch.LongTensor([4000, 8192, 16384]).int().cuda()
60 | cu_seqlens = torch.cat(
61 | [
62 | torch.zeros(1, dtype=torch.int32, device="cuda"),
63 | torch.cumsum(seqlens, dim=0),
64 | ],
65 | dim=0,
66 | ).to(torch.int32)
67 | x = torch.zeros(cu_seqlens[-1], 8192, device="cuda", dtype=torch.bfloat16).uniform_(
68 | -1, 1
69 | )
70 |
71 | # forward test
72 | print("======= NSA Forward & Backward Test =======\n")
73 | y = NSA(x, cu_seqlens)
74 | print(f"Forward, output shape: {y.shape}, output norm: {y.norm()}\n")
75 |
76 | # backward test
77 | loss = (y * torch.randn_like(y)).sum(-1).mean()
78 | loss.backward()
79 | for name, param in NSA.named_parameters():
80 | print(
81 | f"Backward, {name}, grad shape: {param.grad.shape}, grad norm: {param.grad.norm()}\n"
82 | )
83 |
84 | # speed benchmark
85 | SelfAttn = (
86 | SelfAttention(
87 | hidden_size=8192,
88 | num_q_heads=64,
89 | num_kv_heads=4,
90 | head_dim=128,
91 | rope_config=RopeConfig(
92 | max_position_embeddings=131072,
93 | head_dim=128,
94 | rope_theta=500000,
95 | rope_scaling={
96 | "factor": 8.0,
97 | "high_freq_factor": 4.0,
98 | "low_freq_factor": 1.0,
99 | "original_max_position_embeddings": 8192,
100 | "rope_type": "llama3",
101 | },
102 | ),
103 | )
104 | .cuda()
105 | .to(torch.bfloat16)
106 | )
107 |
108 | @triton.testing.perf_report(
109 | triton.testing.Benchmark(
110 | x_names=["N"],
111 | x_vals=[1024 * 2**i for i in range(1, 8)],
112 | line_arg="provider",
113 | line_vals=["Self-Attention", "Native-Sparse-Attention"],
114 | line_names=["Self-Attention", "Native-Sparse-Attention"],
115 | styles=[("green", "-"), ("blue", "-")],
116 | ylabel="ms",
117 | plot_name="** NSA forward speed benchmark **",
118 | args={},
119 | )
120 | )
121 | def benchmark_fwd(N, provider):
122 | x = torch.randn(N, 8192, device="cuda", dtype=torch.bfloat16)
123 | cu_seqlens = torch.tensor([0, N], device="cuda", dtype=torch.int32)
124 | quantiles = [0.5, 0.2, 0.8]
125 | with torch.no_grad():
126 | if provider == "Self-Attention":
127 | ms, min_ms, max_ms = triton.testing.do_bench(
128 | lambda: SelfAttn(x, cu_seqlens),
129 | quantiles=quantiles,
130 | )
131 | if provider == "Native-Sparse-Attention":
132 | ms, min_ms, max_ms = triton.testing.do_bench(
133 | lambda: NSA(x, cu_seqlens),
134 | quantiles=quantiles,
135 | )
136 | return ms, min_ms, max_ms
137 |
138 | benchmark_fwd.run(show_plots=True, print_data=True)
139 |
140 | @triton.testing.perf_report(
141 | triton.testing.Benchmark(
142 | x_names=["N"],
143 | x_vals=[1024 * 2**i for i in range(1, 8)],
144 | line_arg="provider",
145 | line_vals=["Self-Attention", "Native-Sparse-Attention"],
146 | line_names=["Self-Attention", "Native-Sparse-Attention"],
147 | styles=[("green", "-"), ("blue", "-")],
148 | ylabel="ms",
149 | plot_name="** NSA backward speed benchmark **",
150 | args={},
151 | )
152 | )
153 | def benchmark_bwd(N, provider):
154 | x = torch.randn(N, 8192, device="cuda", dtype=torch.bfloat16)
155 | cu_seqlens = torch.tensor([0, N], device="cuda", dtype=torch.int32)
156 | quantiles = [0.5, 0.2, 0.8]
157 | if provider == "Self-Attention":
158 | loss = SelfAttn(x.clone().detach().requires_grad_(), cu_seqlens).mean()
159 | ms, min_ms, max_ms = triton.testing.do_bench(
160 | lambda: loss.backward(retain_graph=True),
161 | quantiles=quantiles,
162 | )
163 | elif provider == "Native-Sparse-Attention":
164 | loss = NSA(x.clone().detach().requires_grad_(), cu_seqlens).mean()
165 | ms, min_ms, max_ms = triton.testing.do_bench(
166 | lambda: loss.backward(retain_graph=True),
167 | quantiles=quantiles,
168 | )
169 | return ms, min_ms, max_ms
170 |
171 | benchmark_bwd.run(show_plots=True, print_data=True)
172 |
--------------------------------------------------------------------------------
/test/test_rope.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import torch
15 | from native_sparse_attention.module import RopeConfig, RotaryEmbedding
16 |
17 | if __name__ == "__main__":
18 | rope_config = RopeConfig(
19 | max_position_embeddings=131072,
20 | head_dim=128,
21 | rope_theta=500000,
22 | rope_scaling={
23 | "factor": 8.0,
24 | "high_freq_factor": 4.0,
25 | "low_freq_factor": 1.0,
26 | "original_max_position_embeddings": 8192,
27 | "rope_type": "llama3",
28 | },
29 | )
30 | rope = RotaryEmbedding(rope_config, "cuda")
31 |
32 | # random input
33 | torch.manual_seed(42)
34 | seqlens = torch.LongTensor([1000, 2000, 4096]).int().cuda()
35 | cu_seqlens = torch.cat(
36 | [
37 | torch.zeros(1, dtype=torch.int32, device="cuda"),
38 | torch.cumsum(seqlens, dim=0),
39 | ],
40 | dim=0,
41 | ).to(torch.int32)
42 | x = torch.zeros(
43 | cu_seqlens[-1], 32, 128, device="cuda", dtype=torch.bfloat16
44 | ).uniform_(-1, 1)
45 | y = rope(x, cu_seqlens)
46 |
--------------------------------------------------------------------------------
/test/test_topk_sparse_attention.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Xunhao Lai & Jianqiao Lu.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and limitations under the License.
13 | import torch
14 | import triton
15 | import math
16 | from native_sparse_attention.ops.torch.topk_sparse_attention import (
17 | topk_sparse_attention_torch,
18 | )
19 | from native_sparse_attention.ops.triton.topk_sparse_attention import (
20 | topk_sparse_attention,
21 | _topk_sparse_attention_fwd,
22 | _topk_sparse_attention_bwd,
23 | )
24 | from native_sparse_attention.ops.triton.flash_attention import (
25 | _flash_attention_fwd,
26 | _flash_attention_bwd,
27 | )
28 | from flash_attn.flash_attn_interface import (
29 | _flash_attn_varlen_forward,
30 | _flash_attn_varlen_backward,
31 | )
32 |
33 |
34 | def generate_topk_idx_example(
35 | seqlens: torch.Tensor,
36 | block_size_k: int,
37 | topk: int,
38 | num_heads: int,
39 | block_size_q: int = 1,
40 | ) -> torch.Tensor:
41 | """Generate topk idx example for test.
42 |
43 | Args:
44 | seqlens (torch.Tensor): shape [batch_size], length of each sequence in the batch (not cumulative offsets like cu_seqlens).
45 | block_size_k (int): key value block size
46 | topk (int): number of key value blocks selected per query
47 | num_heads (int): number of key value heads
48 | block_size_q (int): query block size, defaults to 1
49 |
50 | Returns:
51 | torch.Tensor: shape [num_heads, total_seqlen, topk], topk key value block idx for each query. -1 means padding.
52 | """
53 | batch_size = seqlens.shape[0]
54 | num_blocks = torch.ceil(seqlens / block_size_k).to(torch.int32)
55 | topk_idx_all_heads = []
56 | cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0), pad=(1, 0), value=0)
57 | for _ in range(num_heads):
58 | topk_idx = [
59 | torch.randn(seqlens[i], num_blocks[i], device="cuda")
60 | .topk(min(topk, num_blocks[i]), dim=-1)
61 | .indices.to(torch.int32)
62 | for i in range(batch_size)
63 | ]
64 | topk_idx = [
65 | torch.nn.functional.pad(
66 | topk_idx[i], (0, topk - topk_idx[i].shape[-1]), value=topk
67 | )
68 | for i in range(batch_size)
69 | ]
70 | topk_idx = torch.cat(topk_idx, dim=0)
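   | # Sort the selected block ids, force block 0 to always be selected, then mask out
   | # blocks beyond each query's causal position with -1 (treated as padding).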
71 | topk_idx = torch.sort(topk_idx, dim=1).values
72 | topk_idx[:, 0] = 0
73 | q_idx = torch.cat(
74 | [torch.arange(seqlens[i], device="cuda") for i in range(batch_size)], dim=0
75 | )
76 | topk_idx[topk_idx > (q_idx // block_size_k)[:, None]] = -1 # -1 means padding
77 | topk_idx = torch.cat(
78 | [
79 | topk_idx[cu_seqlens[i] : cu_seqlens[i + 1]][0::block_size_q]
80 | for i in range(batch_size)
81 | ],
82 | dim=0,
83 | )
84 | topk_idx_all_heads.append(topk_idx)
85 | topk_idx = torch.stack(topk_idx_all_heads, dim=0)
86 | return topk_idx
87 |
88 |
89 | if __name__ == "__main__":
90 | torch.manual_seed(42)
91 | batch_size = 3
92 | seqlens = torch.LongTensor([1000, 2000, 4096]).int().cuda()
93 | cu_seqlens = torch.cat(
94 | [
95 | torch.zeros(1, dtype=torch.int32, device="cuda"),
96 | torch.cumsum(seqlens, dim=0),
97 | ],
98 | dim=0,
99 | ).to(torch.int32)
100 | max_seqlen = seqlens.max().item()
101 | q = (
102 | torch.empty(cu_seqlens[-1], 64, 96, device="cuda")
103 | .uniform_(-1, 1)
104 | .to(torch.float16)
105 | )
106 | k = (
107 | torch.empty(cu_seqlens[-1], 8, 96, device="cuda")
108 | .uniform_(-1, 1)
109 | .to(torch.float16)
110 | )
111 | v = (
112 | torch.empty(cu_seqlens[-1], 8, 96, device="cuda")
113 | .uniform_(-1, 1)
114 | .to(torch.float16)
115 | )
116 | q.requires_grad = True
117 | k.requires_grad = True
118 | v.requires_grad = True
119 | block_size = 64
120 | topk = 5
121 | topk_idx = generate_topk_idx_example(seqlens, block_size, topk, 8)
122 |
123 | o = topk_sparse_attention_torch(q, k, v, topk_idx, block_size, cu_seqlens)
124 |
125 | randn = torch.randn_like(o)
126 | loss = (o * randn).sum()
127 | loss.backward()
128 |
129 | torch.manual_seed(42)
130 | q1 = q.clone().detach().requires_grad_()
131 | k1 = k.clone().detach().requires_grad_()
132 | v1 = v.clone().detach().requires_grad_()
133 | topk_idx1 = topk_idx.clone().detach()
134 | cu_seqlens1 = cu_seqlens.clone().detach()
135 |
136 | o1 = topk_sparse_attention(q1, k1, v1, topk_idx1, block_size, cu_seqlens1)
137 |
138 | randn2 = randn.clone().detach()
139 | loss2 = (o1 * randn2).sum()
140 | loss2.backward()
141 |
142 | print("Same Output:", torch.allclose(o, o1, atol=0.01, rtol=0.01))
143 | print("Max Error:", (o - o1).abs().max().item())
144 | print()
145 | print("Same Query Gradient:", torch.allclose(q.grad, q1.grad, atol=0.01, rtol=0.01))
146 | print("Max Query Gradient Error:", (q.grad - q1.grad).abs().max().item())
147 | print()
148 | print("Same Key Gradient:", torch.allclose(k.grad, k1.grad, atol=0.01, rtol=0.01))
149 | print("Max Key Gradient Error:", (k.grad - k1.grad).abs().max().item())
150 | print()
151 | print("Same Value Gradient:", torch.allclose(v.grad, v1.grad, atol=0.01, rtol=0.01))
152 | print("Max Value Gradient Error:", (v.grad - v1.grad).abs().max().item())
153 | print()
154 |
155 | # benchmark
156 | @triton.testing.perf_report(
157 | triton.testing.Benchmark(
158 | x_names=["N"],
159 | x_vals=[1024 * 2**i for i in range(1, 8)],
160 | line_arg="provider",
161 | line_vals=["flash", "triton-flash", "triton-top8", "triton-top16"],
162 | line_names=[
163 | "Flash",
164 | "Triton-Flash",
165 | "Triton-Top8",
166 | "Triton-Top16",
167 | ],
168 | styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")],
169 | ylabel="ms",
170 | plot_name="** forward with block size 64 **",
171 | args={"H": 64, "D": 128, "K": 64},
172 | )
173 | )
174 | def benchmark_fwd(N, H, D, K, provider):
175 | q = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16)
176 | k = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16)
177 | v = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16)
178 | cu_seqlens = torch.tensor([0, N], device="cuda", dtype=torch.int32)
179 | sm_scale = 1 / math.sqrt(D)
180 |
181 | top8_idx = generate_topk_idx_example(cu_seqlens[1:], K, 8, H // 16)
182 | top16_idx = generate_topk_idx_example(cu_seqlens[1:], K, 16, H // 16)
183 |
184 | quantiles = [0.5, 0.2, 0.8]
185 | if provider == "flash":
186 | ms, min_ms, max_ms = triton.testing.do_bench(
187 | lambda: _flash_attn_varlen_forward(
188 | q,
189 | k,
190 | v,
191 | cu_seqlens,
192 | cu_seqlens,
193 | N,
194 | N,
195 | dropout_p=0.0,
196 | causal=True,
197 | softmax_scale=sm_scale,
198 | ),
199 | quantiles=quantiles,
200 | )
201 | if provider == "triton-flash":
202 | ms, min_ms, max_ms = triton.testing.do_bench(
203 | lambda: _flash_attention_fwd(
204 | q, k, v, cu_seqlens, cu_seqlens, N, N, True, sm_scale
205 | ),
206 | quantiles=quantiles,
207 | )
208 | if provider == "triton-top8":
209 | ms, min_ms, max_ms = triton.testing.do_bench(
210 | lambda: _topk_sparse_attention_fwd(
211 | q, k, v, top8_idx, K, cu_seqlens, cu_seqlens, N, N, sm_scale
212 | ),
213 | quantiles=quantiles,
214 | )
215 | if provider == "triton-top16":
216 | ms, min_ms, max_ms = triton.testing.do_bench(
217 | lambda: _topk_sparse_attention_fwd(
218 | q, k, v, top16_idx, K, cu_seqlens, cu_seqlens, N, N, sm_scale
219 | ),
220 | quantiles=quantiles,
221 | )
222 | return ms, min_ms, max_ms
223 |
224 | benchmark_fwd.run(show_plots=True, print_data=True)
225 |
226 | # benchmark
227 | @triton.testing.perf_report(
228 | triton.testing.Benchmark(
229 | x_names=["N"],
230 | x_vals=[1024 * 2**i for i in range(1, 8)],
231 | line_arg="provider",
232 | line_vals=["flash", "triton-flash", "triton-top8", "triton-top16"],
233 | line_names=[
234 | "Flash",
235 | "Triton-Flash",
236 | "Triton-Top8",
237 | "Triton-Top16",
238 | ],
239 | styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")],
240 | ylabel="ms",
241 | plot_name="** backward with block size 64 **",
242 | args={"H": 64, "D": 128, "K": 64},
243 | )
244 | )
245 | def benchmark_bwd(N, H, D, K, provider):
246 | q = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16)
247 | k = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16)
248 | v = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16)
249 | o = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16)
250 | do = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16)
251 | lse = torch.randn((H, N), device="cuda", dtype=torch.float32)
252 | sm_scale = 1 / math.sqrt(D)
253 | cu_seqlens = torch.tensor([0, N], device="cuda", dtype=torch.int32)
254 | top8_idx = generate_topk_idx_example(cu_seqlens[1:], K, 8, H // 16)
255 | top16_idx = generate_topk_idx_example(cu_seqlens[1:], K, 16, H // 16)
256 | dq = torch.zeros_like(q)
257 | dk = torch.zeros_like(k)
258 | dv = torch.zeros_like(v)
259 |
260 | quantiles = [0.5, 0.2, 0.8]
261 | if provider == "flash":
262 | ms, min_ms, max_ms = triton.testing.do_bench(
263 | lambda: _flash_attn_varlen_backward(
264 | do,
265 | q,
266 | k,
267 | v,
268 | o,
269 | lse.transpose(0, 1),
270 | dq,
271 | dk,
272 | dv,
273 | cu_seqlens,
274 | cu_seqlens,
275 | N,
276 | N,
277 | dropout_p=0.0,
278 | causal=True,
279 | softmax_scale=sm_scale,
280 | window_size=(-1, -1),
281 | softcap=0.0,
282 | alibi_slopes=None,
283 | deterministic=False,
284 | ),
285 | quantiles=quantiles,
286 | )
287 | if provider == "triton-flash":
288 | ms, min_ms, max_ms = triton.testing.do_bench(
289 | lambda: _flash_attention_bwd(
290 | o, do, lse, q, k, v, cu_seqlens, cu_seqlens, N, N, True, sm_scale
291 | ),
292 | quantiles=quantiles,
293 | )
294 | if provider == "triton-top8":
295 | ms, min_ms, max_ms = triton.testing.do_bench(
296 | lambda: _topk_sparse_attention_bwd(
297 | o,
298 | do,
299 | lse,
300 | q,
301 | k,
302 | v,
303 | top8_idx,
304 | K,
305 | cu_seqlens,
306 | cu_seqlens,
307 | N,
308 | N,
309 | sm_scale,
310 | ),
311 | quantiles=quantiles,
312 | )
313 | if provider == "triton-top16":
314 | ms, min_ms, max_ms = triton.testing.do_bench(
315 | lambda: _topk_sparse_attention_bwd(
316 | o,
317 | do,
318 | lse,
319 | q,
320 | k,
321 | v,
322 | top16_idx,
323 | K,
324 | cu_seqlens,
325 | cu_seqlens,
326 | N,
327 | N,
328 | sm_scale,
329 | ),
330 | quantiles=quantiles,
331 | )
332 | return ms, min_ms, max_ms
333 |
334 | benchmark_bwd.run(show_plots=True, print_data=True)
335 |
--------------------------------------------------------------------------------