├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── csrc ├── grouped_gemm.cu ├── grouped_gemm.h ├── ops.cu ├── permute.cu ├── permute.h ├── sinkhorn.cu └── sinkhorn.h ├── figures ├── figure1.png ├── figure_groupedgemm.png ├── figure_permute.png └── figure_unpermute.png ├── grouped_gemm ├── __init__.py ├── backend.py ├── ops.py ├── ops_test.py ├── permute_test.py └── sinkhorn_test.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | build 3 | *.egg-info 4 | dist 5 | __pycache__ -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third_party/cutlass"] 2 | path = third_party/cutlass 3 | url = https://github.com/NVIDIA/cutlass 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. Apache License 202 | Version 2.0, January 2004 203 | http://www.apache.org/licenses/ 204 | 205 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 206 | 207 | 1. Definitions. 208 | 209 | "License" shall mean the terms and conditions for use, reproduction, 210 | and distribution as defined by Sections 1 through 9 of this document. 211 | 212 | "Licensor" shall mean the copyright owner or entity authorized by 213 | the copyright owner that is granting the License. 
214 | 215 | "Legal Entity" shall mean the union of the acting entity and all 216 | other entities that control, are controlled by, or are under common 217 | control with that entity. For the purposes of this definition, 218 | "control" means (i) the power, direct or indirect, to cause the 219 | direction or management of such entity, whether by contract or 220 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 221 | outstanding shares, or (iii) beneficial ownership of such entity. 222 | 223 | "You" (or "Your") shall mean an individual or Legal Entity 224 | exercising permissions granted by this License. 225 | 226 | "Source" form shall mean the preferred form for making modifications, 227 | including but not limited to software source code, documentation 228 | source, and configuration files. 229 | 230 | "Object" form shall mean any form resulting from mechanical 231 | transformation or translation of a Source form, including but 232 | not limited to compiled object code, generated documentation, 233 | and conversions to other media types. 234 | 235 | "Work" shall mean the work of authorship, whether in Source or 236 | Object form, made available under the License, as indicated by a 237 | copyright notice that is included in or attached to the work 238 | (an example is provided in the Appendix below). 239 | 240 | "Derivative Works" shall mean any work, whether in Source or Object 241 | form, that is based on (or derived from) the Work and for which the 242 | editorial revisions, annotations, elaborations, or other modifications 243 | represent, as a whole, an original work of authorship. For the purposes 244 | of this License, Derivative Works shall not include works that remain 245 | separable from, or merely link (or bind by name) to the interfaces of, 246 | the Work and Derivative Works thereof. 247 | 248 | "Contribution" shall mean any work of authorship, including 249 | the original version of the Work and any modifications or additions 250 | to that Work or Derivative Works thereof, that is intentionally 251 | submitted to Licensor for inclusion in the Work by the copyright owner 252 | or by an individual or Legal Entity authorized to submit on behalf of 253 | the copyright owner. For the purposes of this definition, "submitted" 254 | means any form of electronic, verbal, or written communication sent 255 | to the Licensor or its representatives, including but not limited to 256 | communication on electronic mailing lists, source code control systems, 257 | and issue tracking systems that are managed by, or on behalf of, the 258 | Licensor for the purpose of discussing and improving the Work, but 259 | excluding communication that is conspicuously marked or otherwise 260 | designated in writing by the copyright owner as "Not a Contribution." 261 | 262 | "Contributor" shall mean Licensor and any individual or Legal Entity 263 | on behalf of whom a Contribution has been received by Licensor and 264 | subsequently incorporated within the Work. 265 | 266 | 2. Grant of Copyright License. Subject to the terms and conditions of 267 | this License, each Contributor hereby grants to You a perpetual, 268 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 269 | copyright license to reproduce, prepare Derivative Works of, 270 | publicly display, publicly perform, sublicense, and distribute the 271 | Work and such Derivative Works in Source or Object form. 272 | 273 | 3. Grant of Patent License. 
Subject to the terms and conditions of 274 | this License, each Contributor hereby grants to You a perpetual, 275 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 276 | (except as stated in this section) patent license to make, have made, 277 | use, offer to sell, sell, import, and otherwise transfer the Work, 278 | where such license applies only to those patent claims licensable 279 | by such Contributor that are necessarily infringed by their 280 | Contribution(s) alone or by combination of their Contribution(s) 281 | with the Work to which such Contribution(s) was submitted. If You 282 | institute patent litigation against any entity (including a 283 | cross-claim or counterclaim in a lawsuit) alleging that the Work 284 | or a Contribution incorporated within the Work constitutes direct 285 | or contributory patent infringement, then any patent licenses 286 | granted to You under this License for that Work shall terminate 287 | as of the date such litigation is filed. 288 | 289 | 4. Redistribution. You may reproduce and distribute copies of the 290 | Work or Derivative Works thereof in any medium, with or without 291 | modifications, and in Source or Object form, provided that You 292 | meet the following conditions: 293 | 294 | (a) You must give any other recipients of the Work or 295 | Derivative Works a copy of this License; and 296 | 297 | (b) You must cause any modified files to carry prominent notices 298 | stating that You changed the files; and 299 | 300 | (c) You must retain, in the Source form of any Derivative Works 301 | that You distribute, all copyright, patent, trademark, and 302 | attribution notices from the Source form of the Work, 303 | excluding those notices that do not pertain to any part of 304 | the Derivative Works; and 305 | 306 | (d) If the Work includes a "NOTICE" text file as part of its 307 | distribution, then any Derivative Works that You distribute must 308 | include a readable copy of the attribution notices contained 309 | within such NOTICE file, excluding those notices that do not 310 | pertain to any part of the Derivative Works, in at least one 311 | of the following places: within a NOTICE text file distributed 312 | as part of the Derivative Works; within the Source form or 313 | documentation, if provided along with the Derivative Works; or, 314 | within a display generated by the Derivative Works, if and 315 | wherever such third-party notices normally appear. The contents 316 | of the NOTICE file are for informational purposes only and 317 | do not modify the License. You may add Your own attribution 318 | notices within Derivative Works that You distribute, alongside 319 | or as an addendum to the NOTICE text from the Work, provided 320 | that such additional attribution notices cannot be construed 321 | as modifying the License. 322 | 323 | You may add Your own copyright statement to Your modifications and 324 | may provide additional or different license terms and conditions 325 | for use, reproduction, or distribution of Your modifications, or 326 | for any such Derivative Works as a whole, provided Your use, 327 | reproduction, and distribution of the Work otherwise complies with 328 | the conditions stated in this License. 329 | 330 | 5. Submission of Contributions. Unless You explicitly state otherwise, 331 | any Contribution intentionally submitted for inclusion in the Work 332 | by You to the Licensor shall be under the terms and conditions of 333 | this License, without any additional terms or conditions. 
334 | Notwithstanding the above, nothing herein shall supersede or modify 335 | the terms of any separate license agreement you may have executed 336 | with Licensor regarding such Contributions. 337 | 338 | 6. Trademarks. This License does not grant permission to use the trade 339 | names, trademarks, service marks, or product names of the Licensor, 340 | except as required for reasonable and customary use in describing the 341 | origin of the Work and reproducing the content of the NOTICE file. 342 | 343 | 7. Disclaimer of Warranty. Unless required by applicable law or 344 | agreed to in writing, Licensor provides the Work (and each 345 | Contributor provides its Contributions) on an "AS IS" BASIS, 346 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 347 | implied, including, without limitation, any warranties or conditions 348 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 349 | PARTICULAR PURPOSE. You are solely responsible for determining the 350 | appropriateness of using or redistributing the Work and assume any 351 | risks associated with Your exercise of permissions under this License. 352 | 353 | 8. Limitation of Liability. In no event and under no legal theory, 354 | whether in tort (including negligence), contract, or otherwise, 355 | unless required by applicable law (such as deliberate and grossly 356 | negligent acts) or agreed to in writing, shall any Contributor be 357 | liable to You for damages, including any direct, indirect, special, 358 | incidental, or consequential damages of any character arising as a 359 | result of this License or out of the use or inability to use the 360 | Work (including but not limited to damages for loss of goodwill, 361 | work stoppage, computer failure or malfunction, or any and all 362 | other commercial damages or losses), even if such Contributor 363 | has been advised of the possibility of such damages. 364 | 365 | 9. Accepting Warranty or Additional Liability. While redistributing 366 | the Work or Derivative Works thereof, You may choose to offer, 367 | and charge a fee for, acceptance of support, warranty, indemnity, 368 | or other liability obligations and/or rights consistent with this 369 | License. However, in accepting such obligations, You may act only 370 | on Your own behalf and on Your sole responsibility, not on behalf 371 | of any other Contributor, and only if You agree to indemnify, 372 | defend, and hold each Contributor harmless for any liability 373 | incurred by, or claims asserted against, such Contributor by reason 374 | of your accepting any such warranty or additional liability. 375 | 376 | END OF TERMS AND CONDITIONS 377 | 378 | Copyright 2023 MegaBlocks authors 379 | 380 | Licensed under the Apache License, Version 2.0 (the "License"); 381 | you may not use this file except in compliance with the License. 382 | You may obtain a copy of the License at 383 | 384 | http://www.apache.org/licenses/LICENSE-2.0 385 | 386 | Unless required by applicable law or agreed to in writing, software 387 | distributed under the License is distributed on an "AS IS" BASIS, 388 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 389 | See the License for the specific language governing permissions and 390 | limitations under the License. 391 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 
3 | Grouped GEMM for MoE
4 | ===========================
5 | 

A PyTorch Toolbox for Grouped GEMM in MoE Model Training

6 | 
7 | [![license](https://img.shields.io/badge/license-Apache%202-blue)](./LICENSE)
8 | 
9 | 
10 | 
11 | - [Grouped GEMM for MoE](#grouped-gemm-for-moe)
12 | - [Steps for Using](#steps-for-using)
13 |   - [pip install](#pip-install)
14 |   - [Build from Source](#build-from-source)
15 | - [Support Matrix](#support-matrix)
16 |   - [permute \& unpermute](#permute--unpermute)
17 | - [Ops Usage](#ops-usage)
18 |   - [permute](#permute)
19 |     - [Parameters](#parameters)
20 |   - [unpermute](#unpermute)
21 |     - [Parameters](#parameters-1)
22 |     - [Example](#example)
23 | 
24 | ---
25 | 
26 | # Steps for Using
27 | 
28 | ## pip install
29 | ```bash
30 | pip install --verbose git+https://github.com/fanshiqing/grouped_gemm@main
31 | ```
32 | 
33 | ## Build from Source
34 | ```bash
35 | git submodule update --init --recursive
36 | mkdir build
37 | cd build
38 | cmake ..
39 | make -j
40 | cd ..
41 | 
42 | # GroupedGEMM ops test
43 | python grouped_gemm/ops_test.py
44 | 
45 | # topK permute & unpermute ops test
46 | python grouped_gemm/permute_test.py
47 | 
48 | # sinkhorn kernel test
49 | python grouped_gemm/sinkhorn_test.py
50 | ```
51 | 
52 | # Support Matrix
53 | 
54 | ## permute & unpermute
55 | 
56 | | GPU Arch   | FP32  | FP16  | BF16  | FP8   |
57 | | :--------- | :---: | :---: | :---: | :---: |
58 | | SM 70      | Y     | Y     | .     | Y     |
59 | | SM 75      | Y     | Y     | .     | Y     |
60 | | SM 80      | Y     | Y     | Y     | Y     |
61 | | SM 86      | Y     | Y     | Y     | Y     |
62 | | SM 89      | Y     | Y     | Y     | Y     |
63 | | SM 90      | Y     | Y     | Y     | Y     |
64 | 
65 | # Ops Usage
66 | 
67 | ## permute
68 | 
69 | > ```py
70 | > grouped_gemm.ops.permute(
71 | >     input_act: torch.Tensor,
72 | >     indices: torch.Tensor,
73 | >     num_out_tokens: int = 0,
74 | >     max_token_num: int = 0) -> tuple
75 | > ```
76 | 
77 | The output is a tuple of `(torch.Tensor, torch.Tensor)` that contains two tensors, `permuted_act` and `row_id_map`.
78 | 
79 | * `permuted_act` is the permutation of the original tensor `input_act`, with its first dimension permuted according to `indices`.
80 | * `row_id_map` is the mapping table for the row indices of the input activations before and after `grouped_gemm.ops.permute`, which is used by the following `unpermute` op.
81 | 
82 | ### Parameters
83 | 
84 | * **input_act** (torch.Tensor)
85 |   shape = [tokens_num, hidden_size]
86 |   The input activations; each row (token) is routed to topK experts.
87 | 
88 | * **indices** (torch.Tensor)
89 |   shape = [tokens_num, topK_num]
90 |   The topK expert indices for each row (token) of activations. The `int32` type is recommended.
91 | 
92 | * **num_out_tokens** (int)
93 |   The number of output tokens (rows) to keep; used for the token-drop feature.
94 | 
95 | * **max_token_num** (int)
96 |   The maximum number of tokens (rows), used for workspace pre-allocation. A usage sketch of these two parameters is shown below.
97 | 
98 | 
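To make `num_out_tokens` and `max_token_num` concrete, here is a minimal, hypothetical sketch of the token-drop path (it is not part of the repository's test suite). It assumes a CUDA build of this package, that the two optional arguments can be passed by keyword as named in the signature above, and that dropped copies behave as described in the kernel in `csrc/permute.cu` (rows beyond `num_out_tokens` in the expert-sorted order are marked `-1` in `row_id_map`).

```py
import torch
from grouped_gemm import permute

# 4 tokens, hidden_size = 4, topK = 2  ->  4 * 2 = 8 expanded rows in total.
indices = torch.tensor([[1, 2], [0, 1], [0, 2], [1, 2]], dtype=torch.int32, device='cuda')
input_act = torch.arange(16, dtype=torch.float32, device='cuda').reshape(4, 4)

# Default: no token drop, the permuted tensor has tokens_num * topK_num = 8 rows.
permuted_act, row_id_map = permute(input_act, indices)

# Token drop: keep only the first 6 rows of the expert-sorted expansion.
# Dropped copies are marked with -1 in row_id_map and are skipped by `unpermute`.
# `max_token_num` pre-allocates the sorting workspace for the largest
# tokens_num * topK_num this op will ever be called with.
permuted_act, row_id_map = permute(input_act, indices,
                                   num_out_tokens=6, max_token_num=1024)
print(permuted_act.shape)  # expected: torch.Size([6, 4])
```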

99 | 
100 | ## unpermute
101 | 
102 | > ```py
103 | > grouped_gemm.ops.unpermute(
104 | >     input_act: torch.Tensor,
105 | >     row_id_map: torch.Tensor,
106 | >     probs: torch.Tensor) -> torch.Tensor
107 | > ```
108 | 
109 | The mirror operator of `grouped_gemm.ops.permute`.
110 | 
111 | ### Parameters
112 | 
113 | * **input_act** (torch.Tensor)
114 |   shape = [tokens_num * topK_num, hidden_size]
115 |   The permuted activations produced by `grouped_gemm.ops.permute`.
116 | 
117 | * **row_id_map** (torch.Tensor)
118 |   shape = [tokens_num * topK_num]
119 |   The mapping table for the row indices of the activations before and after `grouped_gemm.ops.permute`; it is the second output tensor of `grouped_gemm.ops.permute`.
120 | 
121 | * **probs** (torch.Tensor)
122 |   shape = [tokens_num, topK_num]
123 |   The weights used to sum up the rows of `input_act` that originate from the same source token but were routed to different experts. A pure-PyTorch sketch of this reduction is given below.
124 | 
125 | 
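For clarity, here is a hedged pure-PyTorch reference of the reduction `unpermute` performs; it is for illustration and testing only, not the library implementation. It assumes the `row_id_map` layout produced by the kernel in `csrc/permute.cu`: a flattened `[topK_num, tokens_num]` table in which entry `k * tokens_num + t` is the row of the permuted tensor holding the copy of token `t` routed to its `k`-th expert, or `-1` if that copy was dropped.

```py
import torch

def unpermute_reference(permuted_act: torch.Tensor,
                        row_id_map: torch.Tensor,
                        probs: torch.Tensor) -> torch.Tensor:
    """Slow reference of grouped_gemm.ops.unpermute for illustration/testing."""
    tokens_num, topK_num = probs.shape
    out = torch.zeros(tokens_num, permuted_act.size(1),
                      dtype=torch.float32, device=permuted_act.device)
    for t in range(tokens_num):
        for k in range(topK_num):
            row = int(row_id_map[k * tokens_num + t])
            if row == -1:  # this (token, expert) copy was dropped by permute
                continue
            # Weight each expert's copy by its routing probability and accumulate.
            out[t] += probs[t, k].float() * permuted_act[row].float()
    return out.to(permuted_act.dtype)
```

With the tensors from the Example below, `unpermute_reference(permuted_inputs, row_id_map, probs)` should match `unpermute(permuted_inputs, row_id_map, probs)` up to rounding.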

126 | 127 | ### Example 128 | 129 | ```py 130 | import torch 131 | from grouped_gemm import permute, unpermute 132 | 133 | indices = torch.tensor([[1, 2], [0, 1], [0, 2], [1, 2]], dtype=torch.int32, device='cuda') 134 | input_act = torch.tensor([[0,0,0,0], [1,1,1,1], [2,2,2,2], [3,3,3,3]], dtype=torch.float32, device='cuda') 135 | probs = torch.ones_like(indices, dtype=torch.float32) 136 | permuted_inputs, row_id_map = permute(input_act, indices) 137 | unpermute_outputs = unpermute(permuted_inputs, row_id_map, probs) 138 | 139 | print(row_id_map) 140 | print(input_act) 141 | print(permuted_inputs) 142 | print(unpermute_outputs) 143 | 144 | # Output 145 | # tensor([2, 0, 1, 4, 5, 3, 6, 7], device='cuda:0', dtype=torch.int32) 146 | # tensor([[0., 0., 0., 0.], 147 | # [1., 1., 1., 1.], 148 | # [2., 2., 2., 2.], 149 | # [3., 3., 3., 3.]], device='cuda:0') 150 | # tensor([[1., 1., 1., 1.], 151 | # [2., 2., 2., 2.], 152 | # [0., 0., 0., 0.], 153 | # [1., 1., 1., 1.], 154 | # [3., 3., 3., 3.], 155 | # [0., 0., 0., 0.], 156 | # [2., 2., 2., 2.], 157 | # [3., 3., 3., 3.]], device='cuda:0') 158 | # tensor([[0., 0., 0., 0.], 159 | # [2., 2., 2., 2.], 160 | # [4., 4., 4., 4.], 161 | # [6., 6., 6., 6.]], device='cuda:0') 162 | ``` 163 | 164 | -------------------------------------------------------------------------------- /csrc/grouped_gemm.cu: -------------------------------------------------------------------------------- 1 | #include "grouped_gemm.h" 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include "cutlass/bfloat16.h" 9 | #include "cutlass/complex.h" 10 | #include "cutlass/gemm/kernel/gemm_grouped.h" 11 | #include "cutlass/gemm/kernel/default_gemm_grouped.h" 12 | #include "cutlass/gemm/device/gemm_grouped.h" 13 | 14 | namespace grouped_gemm { 15 | 16 | #define NUM_STREAM 4 17 | 18 | #define CUDA_CALL(code) \ 19 | do { \ 20 | cudaError_t status = code; \ 21 | std::string err = cudaGetErrorString(status); \ 22 | TORCH_CHECK(status == cudaSuccess, err); \ 23 | } while (0) 24 | 25 | #define CUBLAS_CALL(code) \ 26 | do { \ 27 | cublasStatus_t status = code; \ 28 | TORCH_CHECK(status == CUBLAS_STATUS_SUCCESS, "CuBLAS Error"); \ 29 | } while (0) 30 | 31 | #define GROUPED_GEMM_STRINGIFY_HELPER(x) #x 32 | #define GROUPED_GEMM_STRINGIFY(x) \ 33 | GROUPED_GEMM_STRINGIFY_HELPER(x) 34 | 35 | // TODO(tgale): Update this for SM90 when it's supported by CUTLASS. 36 | using GroupedGemmKernelNN = typename cutlass::gemm::kernel::DefaultGemmGrouped< 37 | // Non-transposed A operand. 38 | ::cutlass::bfloat16_t, 39 | ::cutlass::layout::RowMajor, 40 | ::cutlass::ComplexTransform::kNone, 41 | 8, 42 | // Non-transposed B operand. 43 | ::cutlass::bfloat16_t, 44 | ::cutlass::layout::RowMajor, 45 | ::cutlass::ComplexTransform::kNone, 46 | 8, 47 | // C operand. 48 | ::cutlass::bfloat16_t, 49 | ::cutlass::layout::RowMajor, 50 | float, 51 | ::cutlass::arch::OpClassTensorOp, 52 | ::cutlass::arch::Sm80, 53 | ::cutlass::gemm::GemmShape<128, 128, 32>, 54 | ::cutlass::gemm::GemmShape<64, 64, 32>, 55 | ::cutlass::gemm::GemmShape<16, 8, 16>, 56 | ::cutlass::epilogue::thread::LinearCombination<::cutlass::bfloat16_t, 8, float, float>, 57 | // NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels. 58 | // This parameter is passed in at present to match the APIs of other kernels. The parameter 59 | // is unused within the kernel. 60 | ::cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, 61 | // TODO(tgale): Experiment with GroupScheduleMode. 
62 | // TODO(tgale): Tune this for SM90. 63 | 4>::GemmKernel; 64 | using GemmGroupedNN = ::cutlass::gemm::device::GemmGrouped; 65 | 66 | std::vector MakeProblemSizes(torch::Tensor b, torch::Tensor batch_sizes) { 67 | const size_t num_experts = batch_sizes.size(0); 68 | const size_t k = b.size(1), n = b.size(2); 69 | std::vector problem_sizes(num_experts); 70 | for (int i = 0; i < num_experts; ++i) { 71 | problem_sizes[i] = cutlass::gemm::GemmCoord(batch_sizes.data_ptr()[i], n, k); 72 | } 73 | return problem_sizes; 74 | } 75 | 76 | template 77 | torch::Tensor CopyToDevice(const std::vector &x, const torch::Device &device) { 78 | size_t bytes = x.size() * sizeof(T); 79 | auto options = torch::TensorOptions().dtype(torch::kInt8).device(device); 80 | torch::Tensor out = torch::empty(bytes, options); 81 | 82 | CUDA_CALL(cudaMemcpyAsync(out.data_ptr(), 83 | x.data(), bytes, 84 | cudaMemcpyHostToDevice, 85 | c10::cuda::getCurrentCUDAStream())); 86 | return out; 87 | } 88 | 89 | template 90 | typename Gemm::Arguments MakeArguments(torch::Tensor a, 91 | torch::Tensor b, 92 | torch::Tensor c, 93 | torch::Tensor batch_sizes) { 94 | auto problem_sizes_host = MakeProblemSizes(b, batch_sizes); 95 | 96 | // Calculate the number of threadblocks to use and validate the result. 97 | int64_t num_experts = problem_sizes_host.size(); 98 | 99 | // NOTE: This is borrowed from FasterTransformer. 100 | int threadblock_count = Gemm::sufficient(problem_sizes_host.data(), num_experts); 101 | if (!threadblock_count) { 102 | TORCH_CHECK(false, "Grouped GEMM execution not possible with HW"); 103 | } 104 | 105 | // Create the host arrays of leading dimension data and pointer data. 106 | using LayoutA = typename Gemm::LayoutA; 107 | using LayoutB = typename Gemm::LayoutB; 108 | using LayoutC = typename Gemm::LayoutC; 109 | 110 | std::vector lda_host(num_experts), offsets_a(num_experts); 111 | std::vector ldb_host(num_experts), offsets_b(num_experts); 112 | std::vector ldc_host(num_experts), offsets_c(num_experts); 113 | int64_t elements_a = 0, elements_b = 0, elements_c = 0; 114 | 115 | using ElementA = typename Gemm::ElementA; 116 | using ElementB = typename Gemm::ElementB; 117 | using ElementC = typename Gemm::ElementC; 118 | std::vector ptr_a_host(num_experts); 119 | std::vector ptr_b_host(num_experts); 120 | std::vector ptr_c_host(num_experts); 121 | 122 | for (int i = 0; i < num_experts; ++i) { 123 | auto problem = problem_sizes_host[i]; 124 | lda_host[i] = LayoutA::packed({problem.m(), problem.k()}).stride(0); 125 | ldb_host[i] = LayoutB::packed({problem.k(), problem.n()}).stride(0); 126 | ldc_host[i] = LayoutC::packed({problem.m(), problem.n()}).stride(0); 127 | 128 | offsets_a[i] = elements_a; 129 | offsets_b[i] = elements_b; 130 | offsets_c[i] = elements_c; 131 | 132 | ptr_a_host[i] = (ElementA*)a.data_ptr() + offsets_a[i]; 133 | ptr_b_host[i] = (ElementB*)b.data_ptr() + offsets_b[i]; 134 | ptr_c_host[i] = (ElementC*)c.data_ptr() + offsets_c[i]; 135 | 136 | elements_a += problem.m() * problem.k(); 137 | elements_b += problem.k() * problem.n(); 138 | elements_c += problem.m() * problem.n(); 139 | } 140 | 141 | // Copy the problem sizes, pointers and leading dimension data to the device. 
142 | torch::Tensor lda = CopyToDevice(lda_host, a.device()); 143 | torch::Tensor ldb = CopyToDevice(ldb_host, a.device()); 144 | torch::Tensor ldc = CopyToDevice(ldc_host, a.device()); 145 | torch::Tensor ptr_a = CopyToDevice(ptr_a_host, a.device()); 146 | torch::Tensor ptr_b = CopyToDevice(ptr_b_host, a.device()); 147 | torch::Tensor ptr_c = CopyToDevice(ptr_c_host, a.device()); 148 | torch::Tensor problem_sizes = CopyToDevice(problem_sizes_host, a.device()); 149 | 150 | typename Gemm::EpilogueOutputOp::Params epilogue_op(/*alpha=*/1.0f, /*beta=*/0.0f); 151 | typename Gemm::Arguments arguments((cutlass::gemm::GemmCoord*)problem_sizes.data_ptr(), 152 | (int)num_experts, 153 | (int)threadblock_count, 154 | epilogue_op, 155 | (ElementA**)ptr_a.data_ptr(), 156 | (ElementB**)ptr_b.data_ptr(), 157 | (ElementC**)ptr_c.data_ptr(), 158 | (ElementC**)ptr_c.data_ptr(), 159 | /*lda=*/(int64_t*)lda.data_ptr(), 160 | /*ldb=*/(int64_t*)ldb.data_ptr(), 161 | /*ldc=*/(int64_t*)ldc.data_ptr(), 162 | /*ldd=*/(int64_t*)ldc.data_ptr(), 163 | (cutlass::gemm::GemmCoord*)problem_sizes_host.data()); 164 | return arguments; 165 | } 166 | 167 | torch::Tensor CutlassGroupedGemm(torch::Tensor a, 168 | torch::Tensor b, 169 | torch::Tensor c, 170 | torch::Tensor batch_sizes) { 171 | using Gemm = GemmGroupedNN; 172 | Gemm gemm; 173 | 174 | auto arguments = MakeArguments(a, b, c, batch_sizes); 175 | int64_t workspace_size = gemm.get_workspace_size(arguments); 176 | auto options = torch::TensorOptions().dtype(torch::kInt8).device(a.device()); 177 | torch::Tensor workspace = torch::empty(workspace_size, options); 178 | 179 | // Initialize the kernel. 180 | if(gemm.initialize(arguments, workspace.data_ptr()) != cutlass::Status::kSuccess) { 181 | TORCH_CHECK(false, "Failed to initialize CUTLASS Grouped GEMM"); 182 | } 183 | 184 | // Execute the kernel in the current stream. 185 | if(gemm.run(c10::cuda::getCurrentCUDAStream()) != cutlass::Status::kSuccess) { 186 | TORCH_CHECK(false, "Failed to run CUTLASS Grouped GEMM"); 187 | } 188 | return c; 189 | } 190 | 191 | cublasHandle_t cublas_handle[NUM_STREAM]; 192 | cudaStream_t cublas_stream[NUM_STREAM]; 193 | cudaEvent_t cublas_event[NUM_STREAM]; 194 | bool cublas_init = false; 195 | 196 | void cublas_handle_init() 197 | { 198 | cublas_init = true; 199 | 200 | for (int i = 0; i < NUM_STREAM; i++) 201 | { 202 | cudaStreamCreateWithFlags(&cublas_stream[i], cudaStreamNonBlocking); 203 | cublasCreate(&cublas_handle[i]); 204 | cublasSetStream(cublas_handle[i], cublas_stream[i]); 205 | cudaEventCreate(&cublas_event[i]); 206 | } 207 | } 208 | 209 | inline void cublas_current_wait_streams(cudaStream_t stream) 210 | { 211 | for (int s = 0; s < NUM_STREAM; s++) 212 | { 213 | cudaEventRecord(cublas_event[s], cublas_stream[s]); 214 | } 215 | 216 | for (int s = 0; s < NUM_STREAM; s++) 217 | { 218 | cudaStreamWaitEvent(stream, cublas_event[s]); 219 | } 220 | } 221 | 222 | inline void cublas_streams_wait_current(cudaStream_t stream) 223 | { 224 | cudaEventRecord(cublas_event[0], stream); 225 | 226 | for (int s = 0; s < NUM_STREAM; s++) 227 | { 228 | cudaStreamWaitEvent(cublas_stream[s], cublas_event[0]); 229 | } 230 | } 231 | 232 | void CublasGemm(cublasHandle_t cublas_handle, 233 | c10::BFloat16 *a, int64_t a_rows, int64_t a_cols, bool trans_a, 234 | c10::BFloat16 *b, int64_t b_rows, int64_t b_cols, bool trans_b, 235 | c10::BFloat16 *c, int64_t c_rows, int64_t c_cols) { 236 | int m = trans_b ? b_rows : b_cols; 237 | int k = trans_b ? b_cols : b_rows; 238 | int n = trans_a ? 
a_cols : a_rows; 239 | 240 | int lda = trans_a ? n : k; 241 | int ldb = trans_b ? k : m; 242 | cublasOperation_t transpose_a = trans_a ? CUBLAS_OP_T : CUBLAS_OP_N; 243 | cublasOperation_t transpose_b = trans_b ? CUBLAS_OP_T : CUBLAS_OP_N; 244 | 245 | 246 | float alpha = 1.0, beta = 0.0; 247 | CUBLAS_CALL(cublasGemmEx(cublas_handle, 248 | transpose_b, transpose_a, 249 | m, n, k, &alpha, 250 | b, CUDA_R_16BF, ldb, 251 | a, CUDA_R_16BF, lda, 252 | &beta, 253 | c, CUDA_R_16BF, c_cols, CUDA_R_32F, 254 | CUBLAS_GEMM_DEFAULT)); 255 | } 256 | 257 | void CublasGroupedGemm(torch::Tensor a, 258 | torch::Tensor b, 259 | torch::Tensor c, 260 | torch::Tensor batch_sizes, 261 | bool trans_b) { 262 | if (!cublas_init) 263 | cublas_handle_init(); 264 | 265 | int64_t bs = batch_sizes.size(0), k = a.size(1); 266 | int64_t n = trans_b ? b.size(1) : b.size(2); 267 | int64_t b_rows = b.size(1), b_cols = b.size(2); 268 | c10::BFloat16* a_ptr = a.data_ptr(); 269 | c10::BFloat16* b_ptr = b.data_ptr(); 270 | c10::BFloat16* c_ptr = c.data_ptr(); 271 | 272 | cublas_streams_wait_current(c10::cuda::getCurrentCUDAStream()); 273 | 274 | for (int i = 0; i < bs; ++i) { 275 | 276 | int64_t m = batch_sizes.data_ptr()[i]; 277 | CublasGemm(cublas_handle[i % NUM_STREAM], a_ptr, m, k, /*trans_a=*/false, 278 | b_ptr, b_rows, b_cols, trans_b, 279 | c_ptr, m, n); 280 | a_ptr += m * k; 281 | b_ptr += b_rows * b_cols; 282 | c_ptr += m * n; 283 | } 284 | 285 | cublas_current_wait_streams(c10::cuda::getCurrentCUDAStream()); 286 | } 287 | 288 | void CublasGroupedGemmVariableK(torch::Tensor a, 289 | torch::Tensor b, 290 | torch::Tensor c, 291 | torch::Tensor batch_sizes) { 292 | if (!cublas_init) 293 | cublas_handle_init(); 294 | 295 | int64_t bs = batch_sizes.size(0), m = a.size(1), n = b.size(1); 296 | c10::BFloat16* a_ptr = a.data_ptr(); 297 | c10::BFloat16* b_ptr = b.data_ptr(); 298 | c10::BFloat16* c_ptr = c.data_ptr(); 299 | 300 | cublas_streams_wait_current(c10::cuda::getCurrentCUDAStream()); 301 | 302 | for (int i = 0; i < bs; ++i) { 303 | int64_t k = batch_sizes.data_ptr()[i]; 304 | CublasGemm(cublas_handle[i % NUM_STREAM], a_ptr, k, m, /*trans_a=*/true, 305 | b_ptr, k, n, /*trans_b=*/false, 306 | c_ptr, m, n); 307 | a_ptr += k * m; 308 | b_ptr += k * n; 309 | c_ptr += m * n; 310 | } 311 | 312 | cublas_current_wait_streams(c10::cuda::getCurrentCUDAStream()); 313 | } 314 | 315 | void GroupedGemmVariableK(torch::Tensor a, 316 | torch::Tensor b, 317 | torch::Tensor c, 318 | torch::Tensor batch_sizes) { 319 | // We expected a CUDA tensor with two dimensions and shape 320 | // (tokens, hidden_out) for 'b'. 321 | TORCH_CHECK(b.is_cuda()); 322 | TORCH_CHECK(b.ndimension() == 2); 323 | TORCH_CHECK(b.scalar_type() == torch::kBFloat16); 324 | 325 | // Validate the dimensions. 326 | int64_t tokens = a.size(0), num_experts = batch_sizes.size(0); 327 | int64_t m = a.size(1), n = b.size(1); 328 | 329 | // Validate that we have the same contraction dimension. 330 | TORCH_CHECK(tokens == b.size(0)); 331 | 332 | // Validate the output shape. 333 | TORCH_CHECK(c.is_cuda()); 334 | TORCH_CHECK(c.ndimension() == 3); 335 | TORCH_CHECK(c.scalar_type() == torch::kBFloat16); 336 | TORCH_CHECK(c.size(0) == num_experts); 337 | TORCH_CHECK(c.size(1) == m); 338 | TORCH_CHECK(c.size(2) == n); 339 | 340 | // Run the computation. 341 | CublasGroupedGemmVariableK(a, b, c, batch_sizes); 342 | } 343 | 344 | // NOTE: We only support dynamic group sizes for the 'a' tensor. Tensor 'b' is 345 | // assumed to be batched with fixed sized batches. 
346 | // 347 | // TODO(tgale): Validate alignment is true for every batch element. 348 | void GroupedGemm(torch::Tensor a, 349 | torch::Tensor b, 350 | torch::Tensor c, 351 | torch::Tensor batch_sizes, 352 | bool trans_a, bool trans_b) { 353 | // NOTE: We only support 'trans_a' or 'trans_b', not both. 354 | TORCH_CHECK(!(trans_a && trans_b)); 355 | 356 | // We expect the batch_sizes on CPU. 357 | TORCH_CHECK(batch_sizes.is_cpu()); 358 | TORCH_CHECK(batch_sizes.ndimension() == 1); 359 | TORCH_CHECK(batch_sizes.scalar_type() == torch::kInt64); 360 | 361 | // We expected a CUDA tensor with two dimensions and shape 362 | // (tokens, hidden_in) for 'a'. 363 | TORCH_CHECK(a.is_cuda()); 364 | TORCH_CHECK(a.ndimension() == 2); 365 | TORCH_CHECK(a.scalar_type() == torch::kBFloat16); 366 | 367 | // Defer to the variable 'k' helper for the rest of the op. 368 | if (trans_a) { 369 | GroupedGemmVariableK(a, b, c, batch_sizes); 370 | return; 371 | } 372 | 373 | // We expected a CUDA tensor with three dimensions and shape 374 | // (num_experts, hidden_in, hidden_out) for 'b'. 375 | TORCH_CHECK(b.is_cuda()); 376 | TORCH_CHECK(b.ndimension() == 3); 377 | TORCH_CHECK(b.scalar_type() == torch::kBFloat16); 378 | 379 | // Validate the contraction dimensions match. 380 | int64_t tokens = a.size(0), num_experts = b.size(0); 381 | int64_t hidden_in = trans_b ? b.size(2) : b.size(1); 382 | int64_t hidden_out = trans_b ? b.size(1) : b.size(2); 383 | TORCH_CHECK(hidden_in == a.size(1)); 384 | 385 | // Validate that we have one size per expert. 386 | TORCH_CHECK(batch_sizes.size(0) == num_experts); 387 | 388 | // Validate the output shape. 389 | TORCH_CHECK(c.is_cuda()); 390 | TORCH_CHECK(c.ndimension() == 2); 391 | TORCH_CHECK(c.scalar_type() == torch::kBFloat16); 392 | TORCH_CHECK(c.size(0) == tokens); 393 | TORCH_CHECK(c.size(1) == hidden_out); 394 | 395 | // NOTE: We support transposition through the 'trans_b' flag. 396 | TORCH_CHECK(a.is_contiguous()); 397 | TORCH_CHECK(b.is_contiguous()); 398 | 399 | // NOTE: Use cuBLAS for SM90 until CUTLASS supports SM90-optimized grouped-gemm. 400 | #if !defined(GROUPED_GEMM_DEVICE_CAPABILITY) || GROUPED_GEMM_DEVICE_CAPABILITY != 80 401 | CublasGroupedGemm(a, b, c, batch_sizes, trans_b); 402 | return; 403 | #else 404 | // TODO(tgale): Support transposition with CUTLASS grouped GEMM. 
405 | if (trans_b) { 406 | CublasGroupedGemm(a, b, c, batch_sizes, trans_b); 407 | return; 408 | } 409 | CutlassGroupedGemm(a, b, c, batch_sizes); 410 | return; 411 | #endif 412 | } 413 | 414 | } // namespace grouped_gemm 415 | -------------------------------------------------------------------------------- /csrc/grouped_gemm.h: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | namespace grouped_gemm { 4 | 5 | void GroupedGemm(torch::Tensor a, 6 | torch::Tensor b, 7 | torch::Tensor c, 8 | torch::Tensor batch_sizes, 9 | bool trans_a, bool trans_b); 10 | 11 | } // namespace grouped_gemm 12 | -------------------------------------------------------------------------------- /csrc/ops.cu: -------------------------------------------------------------------------------- 1 | #include "grouped_gemm.h" 2 | #include "permute.h" 3 | #include "sinkhorn.h" 4 | 5 | #include 6 | 7 | namespace grouped_gemm { 8 | 9 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 10 | m.def("gmm", &GroupedGemm, "Grouped GEMM."); 11 | m.def("sinkhorn", &sinkhorn, "Sinkhorn kernel"); 12 | m.def("permute", &moe_permute_topK_op, "Token permutation kernel"); 13 | m.def("unpermute", &moe_recover_topK_op, "Token un-permutation kernel"); 14 | m.def("unpermute_bwd", &moe_recover_topK_bwd_op, "Token un-permutation backward kernel"); 15 | } 16 | 17 | } // namespace grouped_gemm 18 | -------------------------------------------------------------------------------- /csrc/permute.cu: -------------------------------------------------------------------------------- 1 | /************************************************************************* 2 | * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * 4 | * See LICENSE for license information. 
5 | ************************************************************************/ 6 | 7 | #include "permute.h" 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #include "cuda_runtime.h" 14 | #include "device_launch_parameters.h" 15 | 16 | #include "ATen/cuda/CUDAContext.h" 17 | 18 | #include "cutlass/arch/memory.h" 19 | #include "cutlass/arch/cache_operation.h" 20 | #include "cutlass/array.h" 21 | #include "cutlass/numeric_conversion.h" 22 | 23 | 24 | using torch::Tensor; 25 | 26 | namespace grouped_gemm { 27 | 28 | template 29 | inline T *get_ptr(torch::Tensor &t) 30 | { 31 | return reinterpret_cast(t.data_ptr()); 32 | } 33 | 34 | ///////////////////////////////////////////////////////////////////////////////////////////////// 35 | // 36 | // Top K 37 | // 38 | ///////////////////////////////////////////////////////////////////////////////////////////////// 39 | 40 | static __global__ void moe_permute_topK_row_map( 41 | const int *sorted_row_id, 42 | int *row_id_map, 43 | const int num_rows, 44 | const int num_topK, 45 | const int num_out_tokens) 46 | { 47 | // Each block corresponds to one source token 48 | // row_id_map[num_topK][num_rows] 49 | const int bid = blockIdx.x; 50 | const int tid = threadIdx.x; 51 | const int idx = bid * blockDim.x + tid; 52 | 53 | if (idx >= num_rows * num_topK) 54 | return; 55 | 56 | int source_row = sorted_row_id[idx]; 57 | int source_token_id = source_row / num_topK; 58 | int source_topK_id = source_row % num_topK; 59 | 60 | if (idx >= num_out_tokens) 61 | { 62 | row_id_map[source_topK_id * num_rows + source_token_id] = -1; 63 | } 64 | else 65 | { 66 | row_id_map[source_topK_id * num_rows + source_token_id] = idx; 67 | } 68 | } 69 | 70 | template 71 | __global__ void moe_recover_topK_kernel(const T *input, 72 | T *unpermuted_output, 73 | const int *row_id_map, 74 | const float *prob, 75 | const int num_rows, 76 | const int num_topK, 77 | const int num_cols) 78 | { 79 | extern __shared__ int8_t s_mem[]; 80 | TCompute *s_prob = reinterpret_cast(s_mem); 81 | 82 | using FragmentLoadStore = cutlass::Array; 83 | using FragmentCompute = cutlass::Array; 84 | 85 | cutlass::NumericArrayConverter src_converter; 86 | cutlass::NumericArrayConverter dst_converter; 87 | 88 | // each block corresponds to one source token 89 | const int source_token = blockIdx.x; 90 | const int tid = threadIdx.x; 91 | 92 | if (hasProb) 93 | { 94 | for (int i = tid; i < num_topK; i += blockDim.x * blockDim.y) 95 | { 96 | s_prob[i] = TCompute(prob[source_token * num_topK + i]); 97 | } 98 | __syncthreads(); 99 | } 100 | 101 | for (int i = tid * kElementsPerAccess; i < num_cols; i += blockDim.x * kElementsPerAccess) 102 | { 103 | FragmentLoadStore frag_load_store; 104 | FragmentCompute frag_elem; 105 | FragmentCompute frag_sum; 106 | 107 | int source_row = row_id_map[source_token]; 108 | 109 | if (source_row != -1) 110 | { 111 | const T *source_row_ptr = input + source_row * num_cols; 112 | 113 | cutlass::arch::global_load( 114 | frag_load_store, (source_row_ptr + i), true); 115 | frag_sum = src_converter(frag_load_store); 116 | 117 | if (hasProb) 118 | { 119 | frag_sum = frag_sum * s_prob[0]; 120 | } 121 | } 122 | else 123 | { 124 | frag_sum.clear(); 125 | } 126 | 127 | for (int k = 1; k < num_topK; k++) 128 | { 129 | source_row = row_id_map[k * num_rows + source_token]; 130 | 131 | if (source_row == -1) 132 | continue; 133 | 134 | const T *source_row_ptr = input + source_row * num_cols; 135 | 136 | cutlass::arch::global_load( 137 | frag_load_store, (source_row_ptr + i), true); 138 | 
frag_elem = src_converter(frag_load_store); 139 | 140 | if (hasProb) 141 | { 142 | frag_elem = frag_elem * s_prob[k]; 143 | } 144 | 145 | for (int e = 0; e < kElementsPerAccess; e++) 146 | { 147 | frag_sum.at(e) = frag_sum.at(e) + frag_elem.at(e); 148 | } 149 | } 150 | 151 | T *dest_row_ptr = unpermuted_output + source_token * num_cols; 152 | frag_load_store = dst_converter(frag_sum); 153 | *(float4 *)(dest_row_ptr + i) = *(float4 *)(frag_load_store.data()); 154 | } 155 | } 156 | 157 | template 162 | __global__ void moe_permute_topK_kernel(const T *input_bwd, 163 | const T *input_fwd, 164 | T *act_grad, 165 | const float *prob, 166 | float *prob_grad, 167 | const int *row_id_map, 168 | const int num_rows, 169 | const int num_topK, 170 | const int num_cols) 171 | { 172 | extern __shared__ int8_t s_mem[]; 173 | TCompute *s_prob = reinterpret_cast(s_mem); 174 | 175 | using FragmentLoadStore = cutlass::Array; 176 | using FragmentCompute = cutlass::Array; 177 | 178 | cutlass::NumericArrayConverter src_converter; 179 | cutlass::NumericArrayConverter dst_converter; 180 | 181 | const int source_token = blockIdx.x; 182 | const int tid = threadIdx.x; 183 | 184 | if (hasProb) 185 | { 186 | for (int i = tid; i < num_topK; i += blockDim.x) 187 | { 188 | s_prob[i] = TCompute(prob[source_token * num_topK + i]); 189 | } 190 | __syncthreads(); 191 | } 192 | 193 | float accum[topKTile] = {0.0f}; 194 | FragmentLoadStore frag_load_store; 195 | 196 | const T *source_row_ptr = input_bwd + source_token * num_cols; 197 | for (int i = tid * kElementsPerAccess; i < num_cols; i += blockDim.x * kElementsPerAccess) 198 | { 199 | cutlass::arch::global_load( 200 | frag_load_store, (source_row_ptr + i), true); 201 | FragmentCompute frag_src = src_converter(frag_load_store); 202 | 203 | int index = source_token; 204 | 205 | for (int k = 0; k < topKTile; k++) 206 | { 207 | if (k == num_topK) break; 208 | 209 | int dest_row = row_id_map[index]; 210 | index += num_rows; 211 | 212 | if (dest_row == -1) 213 | continue; 214 | 215 | if (hasProb) 216 | { 217 | frag_load_store = dst_converter(frag_src * s_prob[k]); 218 | } 219 | else 220 | { 221 | frag_load_store = dst_converter(frag_src); 222 | } 223 | 224 | T *dest_row_ptr = act_grad + dest_row * num_cols; 225 | *(float4 *)(dest_row_ptr + i) = *(float4 *)(frag_load_store.data()); 226 | 227 | if (hasProb) 228 | { 229 | const T *input_fwd_ptr = input_fwd + dest_row * num_cols; 230 | cutlass::arch::global_load( 231 | frag_load_store, (input_fwd_ptr + i), true); 232 | FragmentCompute frag_input_fwd = src_converter(frag_load_store); 233 | 234 | for (int e = 0; e < kElementsPerAccess; e++) 235 | { 236 | accum[k] += float(frag_src.at(e) * frag_input_fwd.at(e)); 237 | } 238 | } 239 | } 240 | } 241 | 242 | if (hasProb) 243 | { 244 | for (int k = 0; k < topKTile; k++) 245 | { 246 | if (k == num_topK) break; 247 | 248 | for (int mask = 16; mask > 0; mask /= 2) 249 | { 250 | accum[k] = accum[k] + __shfl_xor_sync(0xffffffff, accum[k], mask, 32); 251 | } 252 | } 253 | 254 | if (tid == 0) 255 | { 256 | for (int k = 0; k < topKTile; k++) 257 | { 258 | if (k == num_topK) break; 259 | prob_grad[source_token * num_topK + k] = accum[k]; 260 | } 261 | } 262 | } 263 | } 264 | 265 | 266 | template 267 | void moe_permute_topK_kernel_launcher( 268 | const T *input, 269 | T *output, 270 | const int *sorted_row_id, 271 | int *row_id_map, 272 | const float *prob, 273 | const int num_rows, 274 | const int num_topK, 275 | const int num_cols, 276 | const int num_out_tokens, 277 | cudaStream_t stream, 278 | 
float *prob_grad = nullptr, 279 | const T *input_fwd = nullptr) 280 | { 281 | if (FWD) 282 | { 283 | if (prob_grad == nullptr) 284 | { 285 | // permute_topK fwd 286 | int threads = 64; 287 | int blocks = (num_rows * num_topK + threads - 1) / threads; 288 | moe_permute_topK_row_map<<>>( 289 | sorted_row_id, 290 | row_id_map, 291 | num_rows, 292 | num_topK, 293 | num_out_tokens); 294 | 295 | blocks = num_rows; 296 | threads = std::min(num_cols / kElementsPerAccess, 1024); 297 | moe_permute_topK_kernel<<>>( 298 | input, 299 | nullptr, 300 | output, 301 | nullptr, 302 | nullptr, 303 | row_id_map, 304 | num_rows, 305 | num_topK, 306 | num_cols); 307 | } 308 | else 309 | { 310 | // unpermute_topK bwd 311 | int blocks = num_rows; 312 | int threads = 32; 313 | size_t smem_bytes = num_topK * sizeof(TCompute); 314 | 315 | if (num_topK == 1) 316 | { 317 | moe_permute_topK_kernel<<>>( 318 | input, 319 | input_fwd, 320 | output, 321 | prob, 322 | prob_grad, 323 | row_id_map, 324 | num_rows, 325 | num_topK, 326 | num_cols); 327 | } 328 | else if (num_topK <= 8) 329 | { 330 | moe_permute_topK_kernel<<>>( 331 | input, 332 | input_fwd, 333 | output, 334 | prob, 335 | prob_grad, 336 | row_id_map, 337 | num_rows, 338 | num_topK, 339 | num_cols); 340 | } 341 | else if (num_topK <= 16) 342 | { 343 | moe_permute_topK_kernel<<>>( 344 | input, 345 | input_fwd, 346 | output, 347 | prob, 348 | prob_grad, 349 | row_id_map, 350 | num_rows, 351 | num_topK, 352 | num_cols); 353 | } 354 | else if (num_topK <= 32) 355 | { 356 | moe_permute_topK_kernel<<>>( 357 | input, 358 | input_fwd, 359 | output, 360 | prob, 361 | prob_grad, 362 | row_id_map, 363 | num_rows, 364 | num_topK, 365 | num_cols); 366 | } 367 | else if (num_topK <= 64) 368 | { 369 | moe_permute_topK_kernel<<>>( 370 | input, 371 | input_fwd, 372 | output, 373 | prob, 374 | prob_grad, 375 | row_id_map, 376 | num_rows, 377 | num_topK, 378 | num_cols); 379 | } 380 | else if (num_topK <= 128) 381 | { 382 | moe_permute_topK_kernel<<>>( 383 | input, 384 | input_fwd, 385 | output, 386 | prob, 387 | prob_grad, 388 | row_id_map, 389 | num_rows, 390 | num_topK, 391 | num_cols); 392 | } 393 | else 394 | { 395 | throw std::runtime_error("num_topK cannot exceed 128."); 396 | } 397 | } 398 | } 399 | else 400 | { 401 | int blocks = num_rows; 402 | int threads = std::min(num_cols / kElementsPerAccess, 1024); 403 | size_t smem_bytes = num_topK * sizeof(TCompute); 404 | 405 | 406 | if (num_topK == 1) 407 | { 408 | // permute_topK bwd with topK==1 409 | moe_recover_topK_kernel<<>>( 410 | input, 411 | output, 412 | row_id_map, 413 | prob, 414 | num_rows, 415 | num_topK, 416 | num_cols); 417 | } 418 | else if (prob == nullptr) 419 | { 420 | // permute_topK bwd 421 | moe_recover_topK_kernel<<>>( 422 | input, 423 | output, 424 | row_id_map, 425 | prob, 426 | num_rows, 427 | num_topK, 428 | num_cols); 429 | } 430 | else 431 | { 432 | // unpermute_topK fwd 433 | moe_recover_topK_kernel<<>>( 434 | input, 435 | output, 436 | row_id_map, 437 | prob, 438 | num_rows, 439 | num_topK, 440 | num_cols); 441 | } 442 | } 443 | } 444 | 445 | ///////////////////////////////////////////////////////////////////////////////////////////////// 446 | // 447 | // Permute_topK OP 448 | // 449 | ///////////////////////////////////////////////////////////////////////////////////////////////// 450 | 451 | std::tuple> moe_permute_topK_op( 452 | Tensor input, 453 | Tensor indices, 454 | int64_t num_out_tokens, 455 | std::vector workspace, 456 | int64_t max_expanded_token_num) 457 | { 458 | const int 
num_tokens = input.size(0); 459 | const int num_cols = input.size(1); 460 | const int num_topK = indices.size(1); 461 | 462 | // initialize the workspace on the first run 463 | if (workspace.empty()) { 464 | auto options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false); 465 | 466 | Tensor sorted_indices = torch::empty(max_expanded_token_num, options); 467 | Tensor row_id = torch::range(0, max_expanded_token_num - 1, 1, options); 468 | Tensor sorted_row_id = 469 | torch::empty(max_expanded_token_num, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); 470 | 471 | size_t temp_storage_bytes = 0; 472 | int *temp_ptr = nullptr; 473 | cub::DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, 474 | temp_ptr, temp_ptr, 475 | temp_ptr, temp_ptr, max_expanded_token_num); 476 | Tensor temp_storage = 477 | torch::empty(temp_storage_bytes, torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false)); 478 | 479 | workspace.push_back(sorted_indices); 480 | workspace.push_back(row_id); 481 | workspace.push_back(sorted_row_id); 482 | workspace.push_back(temp_storage); 483 | } 484 | 485 | int *indices_ptr = get_ptr(indices); 486 | int *sorted_indices_ptr = get_ptr(workspace[0]); 487 | int *row_id_ptr = get_ptr(workspace[1]); 488 | int *sorted_row_id_ptr = get_ptr(workspace[2]); 489 | 490 | void *d_temp_storage = get_ptr(workspace[3]); 491 | size_t temp_storage_bytes = std::numeric_limits::max(); 492 | 493 | cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, 494 | indices_ptr, sorted_indices_ptr, 495 | row_id_ptr, sorted_row_id_ptr, num_tokens * num_topK); 496 | 497 | // activations type 498 | const at::ScalarType _st = input.scalar_type(); 499 | 500 | // Output buffer alloc 501 | num_out_tokens = (num_out_tokens > 0) ? 
num_out_tokens : num_tokens * num_topK; 502 | Tensor permuted_output = 503 | torch::empty({num_out_tokens, num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false)); 504 | Tensor row_id_map = 505 | torch::empty({num_tokens * num_topK}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); 506 | 507 | int *row_id_map_ptr = get_ptr(row_id_map); 508 | auto stream = at::cuda::getCurrentCUDAStream().stream(); 509 | 510 | switch (_st) 511 | { 512 | case at::ScalarType::Float: 513 | { 514 | using dType = float; 515 | using dTypeCompute = float; 516 | 517 | dType *input_ptr = get_ptr(input); 518 | dType *permuted_output_ptr = get_ptr(permuted_output); 519 | 520 | moe_permute_topK_kernel_launcher( 521 | input_ptr, 522 | permuted_output_ptr, 523 | sorted_row_id_ptr, 524 | row_id_map_ptr, 525 | nullptr, 526 | num_tokens, 527 | num_topK, 528 | num_cols, 529 | num_out_tokens, 530 | stream); 531 | 532 | break; 533 | } 534 | case at::ScalarType::Half: 535 | { 536 | using dType = cutlass::half_t; 537 | using dTypeCompute = cutlass::half_t; 538 | 539 | dType *input_ptr = get_ptr(input); 540 | dType *permuted_output_ptr = get_ptr(permuted_output); 541 | 542 | moe_permute_topK_kernel_launcher( 543 | input_ptr, 544 | permuted_output_ptr, 545 | sorted_row_id_ptr, 546 | row_id_map_ptr, 547 | nullptr, 548 | num_tokens, 549 | num_topK, 550 | num_cols, 551 | num_out_tokens, 552 | stream); 553 | 554 | break; 555 | } 556 | #ifdef ENABLE_BF16 557 | case at::ScalarType::BFloat16: 558 | { 559 | using dType = cutlass::bfloat16_t; 560 | using dTypeCompute = cutlass::bfloat16_t; 561 | 562 | dType *input_ptr = get_ptr(input); 563 | dType *permuted_output_ptr = get_ptr(permuted_output); 564 | 565 | moe_permute_topK_kernel_launcher( 566 | input_ptr, 567 | permuted_output_ptr, 568 | sorted_row_id_ptr, 569 | row_id_map_ptr, 570 | nullptr, 571 | num_tokens, 572 | num_topK, 573 | num_cols, 574 | num_out_tokens, 575 | stream); 576 | 577 | break; 578 | } 579 | #endif 580 | #ifdef ENABLE_FP8 581 | case at::ScalarType::Float8_e5m2: 582 | { 583 | using dType = cutlass::float_e5m2_t; 584 | using dTypeCompute = cutlass::half_t; 585 | 586 | dType *input_ptr = get_ptr(input); 587 | dType *permuted_output_ptr = get_ptr(permuted_output); 588 | 589 | moe_permute_topK_kernel_launcher( 590 | input_ptr, 591 | permuted_output_ptr, 592 | sorted_row_id_ptr, 593 | row_id_map_ptr, 594 | nullptr, 595 | num_tokens, 596 | num_topK, 597 | num_cols, 598 | num_out_tokens, 599 | stream); 600 | 601 | break; 602 | } 603 | case at::ScalarType::Float8_e4m3fn: 604 | { 605 | using dType = cutlass::float_e4m3_t; 606 | using dTypeCompute = cutlass::half_t; 607 | 608 | dType *input_ptr = get_ptr(input); 609 | dType *permuted_output_ptr = get_ptr(permuted_output); 610 | 611 | moe_permute_topK_kernel_launcher( 612 | input_ptr, 613 | permuted_output_ptr, 614 | sorted_row_id_ptr, 615 | row_id_map_ptr, 616 | nullptr, 617 | num_tokens, 618 | num_topK, 619 | num_cols, 620 | num_out_tokens, 621 | stream); 622 | 623 | break; 624 | } 625 | #endif 626 | default: 627 | throw std::runtime_error("Wrong activation tensor type."); 628 | } 629 | 630 | return std::make_tuple(permuted_output, row_id_map, workspace); 631 | } 632 | 633 | ///////////////////////////////////////////////////////////////////////////////////////////////// 634 | // 635 | // Unpermute_topK OP 636 | // 637 | ///////////////////////////////////////////////////////////////////////////////////////////////// 638 | 639 | Tensor moe_recover_topK_op( 640 | Tensor input, 641 | Tensor 
row_id_map, 642 | Tensor prob, 643 | int64_t num_tokens, 644 | int64_t num_topK) 645 | { 646 | const int num_cols = input.size(1); 647 | 648 | // activations type 649 | const at::ScalarType _st = input.scalar_type(); 650 | 651 | // Output buffer alloc 652 | Tensor unpermuted_output = 653 | torch::empty({num_tokens, num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false)); 654 | 655 | int *row_id_map_ptr = get_ptr(row_id_map); 656 | float *prob_ptr = (prob.defined()) ? get_ptr(prob) : nullptr; 657 | auto stream = at::cuda::getCurrentCUDAStream().stream(); 658 | 659 | switch (_st) 660 | { 661 | case at::ScalarType::Float: 662 | { 663 | using dType = float; 664 | using dTypeCompute = float; 665 | 666 | dType *input_ptr = get_ptr(input); 667 | dType *unpermuted_output_ptr = get_ptr(unpermuted_output); 668 | 669 | moe_permute_topK_kernel_launcher( 670 | input_ptr, 671 | unpermuted_output_ptr, 672 | nullptr, 673 | row_id_map_ptr, 674 | prob_ptr, 675 | num_tokens, 676 | num_topK, 677 | num_cols, 678 | 0, 679 | stream); 680 | 681 | break; 682 | } 683 | case at::ScalarType::Half: 684 | { 685 | using dType = cutlass::half_t; 686 | using dTypeCompute = cutlass::half_t; 687 | 688 | dType *input_ptr = get_ptr(input); 689 | dType *unpermuted_output_ptr = get_ptr(unpermuted_output); 690 | 691 | moe_permute_topK_kernel_launcher( 692 | input_ptr, 693 | unpermuted_output_ptr, 694 | nullptr, 695 | row_id_map_ptr, 696 | prob_ptr, 697 | num_tokens, 698 | num_topK, 699 | num_cols, 700 | 0, 701 | stream); 702 | 703 | break; 704 | } 705 | #ifdef ENABLE_BF16 706 | case at::ScalarType::BFloat16: 707 | { 708 | using dType = cutlass::bfloat16_t; 709 | using dTypeCompute = cutlass::bfloat16_t; 710 | 711 | dType *input_ptr = get_ptr(input); 712 | dType *unpermuted_output_ptr = get_ptr(unpermuted_output); 713 | 714 | moe_permute_topK_kernel_launcher( 715 | input_ptr, 716 | unpermuted_output_ptr, 717 | nullptr, 718 | row_id_map_ptr, 719 | prob_ptr, 720 | num_tokens, 721 | num_topK, 722 | num_cols, 723 | 0, 724 | stream); 725 | 726 | break; 727 | } 728 | #endif 729 | #ifdef ENABLE_FP8 730 | case at::ScalarType::Float8_e5m2: 731 | { 732 | using dType = cutlass::float_e5m2_t; 733 | using dTypeCompute = cutlass::half_t; 734 | 735 | dType *input_ptr = get_ptr(input); 736 | dType *unpermuted_output_ptr = get_ptr(unpermuted_output); 737 | 738 | moe_permute_topK_kernel_launcher( 739 | input_ptr, 740 | unpermuted_output_ptr, 741 | nullptr, 742 | row_id_map_ptr, 743 | prob_ptr, 744 | num_tokens, 745 | num_topK, 746 | num_cols, 747 | 0, 748 | stream); 749 | 750 | break; 751 | } 752 | case at::ScalarType::Float8_e4m3fn: 753 | { 754 | using dType = cutlass::float_e4m3_t; 755 | using dTypeCompute = cutlass::half_t; 756 | 757 | dType *input_ptr = get_ptr(input); 758 | dType *unpermuted_output_ptr = get_ptr(unpermuted_output); 759 | 760 | moe_permute_topK_kernel_launcher( 761 | input_ptr, 762 | unpermuted_output_ptr, 763 | nullptr, 764 | row_id_map_ptr, 765 | prob_ptr, 766 | num_tokens, 767 | num_topK, 768 | num_cols, 769 | 0, 770 | stream); 771 | 772 | break; 773 | } 774 | #endif 775 | default: 776 | throw std::runtime_error("Wrong activation tensor type."); 777 | } 778 | 779 | return unpermuted_output; 780 | } 781 | 782 | std::tuple moe_recover_topK_bwd_op( 783 | Tensor input_bwd, 784 | Tensor input_fwd, 785 | Tensor row_id_map, 786 | Tensor prob) 787 | { 788 | const int num_tokens = prob.size(0); 789 | const int num_topK = prob.size(1); 790 | const int num_cols = input_bwd.size(1); 791 | 792 | int *row_id_map_ptr = 
get_ptr(row_id_map); 793 | float *prob_ptr = get_ptr(prob); 794 | 795 | // activations type 796 | const at::ScalarType _st = input_bwd.scalar_type(); 797 | 798 | // Output buffer alloc 799 | Tensor act_grad = 800 | torch::empty({input_fwd.size(0), num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false)); 801 | Tensor prob_grad = 802 | torch::empty({num_tokens, num_topK}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false)); 803 | float *prob_grad_ptr = get_ptr(prob_grad); 804 | 805 | auto stream = at::cuda::getCurrentCUDAStream().stream(); 806 | 807 | switch (_st) 808 | { 809 | case at::ScalarType::Float: 810 | { 811 | using dType = float; 812 | using dTypeCompute = float; 813 | 814 | dType *input_bwd_ptr = get_ptr(input_bwd); 815 | dType *input_fwd_ptr = get_ptr(input_fwd); 816 | dType *act_grad_ptr = get_ptr(act_grad); 817 | 818 | moe_permute_topK_kernel_launcher( 819 | input_bwd_ptr, 820 | act_grad_ptr, 821 | nullptr, 822 | row_id_map_ptr, 823 | prob_ptr, 824 | num_tokens, 825 | num_topK, 826 | num_cols, 827 | 0, 828 | stream, 829 | prob_grad_ptr, 830 | input_fwd_ptr); 831 | 832 | break; 833 | } 834 | case at::ScalarType::Half: 835 | { 836 | using dType = cutlass::half_t; 837 | using dTypeCompute = cutlass::half_t; 838 | 839 | dType *input_bwd_ptr = get_ptr(input_bwd); 840 | dType *input_fwd_ptr = get_ptr(input_fwd); 841 | dType *act_grad_ptr = get_ptr(act_grad); 842 | 843 | moe_permute_topK_kernel_launcher( 844 | input_bwd_ptr, 845 | act_grad_ptr, 846 | nullptr, 847 | row_id_map_ptr, 848 | prob_ptr, 849 | num_tokens, 850 | num_topK, 851 | num_cols, 852 | 0, 853 | stream, 854 | prob_grad_ptr, 855 | input_fwd_ptr); 856 | 857 | break; 858 | } 859 | #ifdef ENABLE_BF16 860 | case at::ScalarType::BFloat16: 861 | { 862 | using dType = cutlass::bfloat16_t; 863 | using dTypeCompute = cutlass::bfloat16_t; 864 | 865 | dType *input_bwd_ptr = get_ptr(input_bwd); 866 | dType *input_fwd_ptr = get_ptr(input_fwd); 867 | dType *act_grad_ptr = get_ptr(act_grad); 868 | 869 | moe_permute_topK_kernel_launcher( 870 | input_bwd_ptr, 871 | act_grad_ptr, 872 | nullptr, 873 | row_id_map_ptr, 874 | prob_ptr, 875 | num_tokens, 876 | num_topK, 877 | num_cols, 878 | 0, 879 | stream, 880 | prob_grad_ptr, 881 | input_fwd_ptr); 882 | 883 | break; 884 | } 885 | #endif 886 | #ifdef ENABLE_FP8 887 | case at::ScalarType::Float8_e5m2: 888 | { 889 | using dType = cutlass::float_e5m2_t; 890 | using dTypeCompute = cutlass::half_t; 891 | 892 | dType *input_bwd_ptr = get_ptr(input_bwd); 893 | dType *input_fwd_ptr = get_ptr(input_fwd); 894 | dType *act_grad_ptr = get_ptr(act_grad); 895 | 896 | moe_permute_topK_kernel_launcher( 897 | input_bwd_ptr, 898 | act_grad_ptr, 899 | nullptr, 900 | row_id_map_ptr, 901 | prob_ptr, 902 | num_tokens, 903 | num_topK, 904 | num_cols, 905 | 0, 906 | stream, 907 | prob_grad_ptr, 908 | input_fwd_ptr); 909 | 910 | break; 911 | } 912 | case at::ScalarType::Float8_e4m3fn: 913 | { 914 | using dType = cutlass::float_e4m3_t; 915 | using dTypeCompute = cutlass::half_t; 916 | 917 | dType *input_bwd_ptr = get_ptr(input_bwd); 918 | dType *input_fwd_ptr = get_ptr(input_fwd); 919 | dType *act_grad_ptr = get_ptr(act_grad); 920 | 921 | moe_permute_topK_kernel_launcher( 922 | input_bwd_ptr, 923 | act_grad_ptr, 924 | nullptr, 925 | row_id_map_ptr, 926 | prob_ptr, 927 | num_tokens, 928 | num_topK, 929 | num_cols, 930 | 0, 931 | stream, 932 | prob_grad_ptr, 933 | input_fwd_ptr); 934 | 935 | break; 936 | } 937 | #endif 938 | default: 939 | throw std::runtime_error("Wrong activation 
tensor type."); 940 | } 941 | 942 | return std::make_tuple(act_grad, prob_grad); 943 | } 944 | 945 | } // namespace grouped_gemm 946 | -------------------------------------------------------------------------------- /csrc/permute.h: -------------------------------------------------------------------------------- 1 | /************************************************************************* 2 | * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * 4 | * See LICENSE for license information. 5 | ************************************************************************/ 6 | 7 | #pragma once 8 | 9 | #include 10 | 11 | using torch::Tensor; 12 | 13 | namespace grouped_gemm { 14 | 15 | std::tuple> moe_permute_topK_op( 16 | Tensor input, 17 | Tensor indices, 18 | int64_t num_out_tokens, 19 | std::vector workspace, 20 | int64_t max_expanded_token_num); 21 | 22 | torch::Tensor moe_recover_topK_op( 23 | torch::Tensor input, 24 | torch::Tensor row_id_map, 25 | torch::Tensor prob_opt, 26 | int64_t num_tokens, 27 | int64_t num_topK); 28 | 29 | std::tuple moe_recover_topK_bwd_op( 30 | Tensor input_bwd, 31 | Tensor input_fwd, 32 | Tensor row_id_map, 33 | Tensor prob); 34 | 35 | } // namespace grouped_gemm -------------------------------------------------------------------------------- /csrc/sinkhorn.cu: -------------------------------------------------------------------------------- 1 | #include "sinkhorn.h" 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include 8 | 9 | namespace grouped_gemm { 10 | 11 | __global__ void sinkhorn_kernel(float *cost, const int rows, const int cols, float tol) { 12 | assert(rows >= cols && cols < blockDim.x); 13 | 14 | extern __shared__ float shared_memory[]; 15 | float *shared_d0 = shared_memory; // For d0 16 | float *shared_d1 = (float*)&shared_d0[rows]; // For d1 17 | float *shared_d1_old = (float*)&shared_d1[cols]; // For d1_old 18 | float *abs_diff_sum = (float*)&shared_d1_old[cols]; // For sum of absolute differences 19 | 20 | int idx = blockIdx.x * blockDim.x + threadIdx.x; 21 | 22 | // Exponentiate cost matrix 23 | if (idx < rows * cols) { 24 | for (int flat_idx = idx; flat_idx < rows * cols; flat_idx += blockDim.x) 25 | cost[flat_idx] = expf(cost[flat_idx]); 26 | } 27 | 28 | if (idx >= rows) return; 29 | 30 | // Initilization for d0, d1, d1_old and abs_diff_sum vector. 31 | for (int row_idx = idx; row_idx < rows; row_idx += blockDim.x) { 32 | shared_d0[row_idx] = 1.0f; 33 | } 34 | 35 | if (idx < cols) { 36 | shared_d1[idx] = 1.0f; 37 | shared_d1_old[idx] = 1.0f; 38 | abs_diff_sum[idx] = 0.0f; 39 | } 40 | __syncthreads(); 41 | 42 | tol = tol * cols; // error mean check --> error sum check 43 | const float eps = 1e-8; 44 | float local_error = 0.0f; 45 | do { 46 | local_error = 0.0f; 47 | 48 | // Update d0. 49 | for (int row_idx = idx; row_idx < rows; row_idx += blockDim.x) { 50 | float sum = 0.0f; 51 | for (int j = 0; j < cols; ++j) { 52 | sum += shared_d1[j] * cost[row_idx * cols + j]; 53 | } 54 | // Using __fdividef for fast division 55 | shared_d0[row_idx] = __fdividef(1.0f, (sum + eps) * rows); 56 | } 57 | __syncthreads(); 58 | 59 | // Update d1 and calculate absolute differences. 
60 | if (idx < cols) { 61 | float sum = 0.0f; 62 | for (int i = 0; i < rows; ++i) { 63 | sum += shared_d0[i] * cost[i * cols + idx]; 64 | } 65 | float new_d1 = __fdividef(1.0, (sum + eps) * cols); 66 | abs_diff_sum[idx] = fabsf(new_d1 - shared_d1_old[idx]); 67 | shared_d1[idx] = new_d1; 68 | // Update shared_d1_old for the next iteration 69 | shared_d1_old[idx] = new_d1; 70 | } 71 | __syncthreads(); 72 | 73 | // Compute the sum absolute difference error. 74 | for (int i = 0; i < cols; ++i) { 75 | local_error += abs_diff_sum[i]; 76 | } 77 | 78 | } while (local_error > tol); 79 | 80 | // Final multiplication. 81 | for (int row_idx = idx; row_idx < rows; row_idx += blockDim.x) { 82 | for (int j = 0; j < cols; ++j) { 83 | cost[row_idx * cols + j] *= shared_d1[j] * shared_d0[row_idx]; 84 | } 85 | } 86 | } 87 | 88 | void sinkhorn_launch(float *cost, int rows, int cols, float tol) { 89 | int threadsPerBlock = 1024; 90 | int blocksPerGrid = 1; 91 | // Allocate enough shared memory for d0, d1, d1_old and abs_diff_sum 92 | size_t sharedMemSize = (rows + cols * 2 + cols) * sizeof(float); 93 | sinkhorn_kernel<<>>(cost, rows, cols, tol); 94 | // cudaDeviceSynchronize(); 95 | } 96 | 97 | // Wrapper function 98 | torch::Tensor sinkhorn(torch::Tensor cost, const float tol) { 99 | sinkhorn_launch(cost.data_ptr(), cost.size(0), cost.size(1), tol); 100 | 101 | cudaError_t err = cudaGetLastError(); 102 | if (err != cudaSuccess) { 103 | printf("CUDA Error: %s\n", cudaGetErrorString(err)); 104 | } 105 | return cost; 106 | } 107 | 108 | } // namespace grouped_gemm 109 | -------------------------------------------------------------------------------- /csrc/sinkhorn.h: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | namespace grouped_gemm { 4 | 5 | torch::Tensor sinkhorn(torch::Tensor cost, const float tol=0.0001); 6 | 7 | } // namespace grouped_gemm 8 | -------------------------------------------------------------------------------- /figures/figure1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fanshiqing/grouped_gemm/5c1d831ecf91b225abc91689683e7de67fbee7ef/figures/figure1.png -------------------------------------------------------------------------------- /figures/figure_groupedgemm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fanshiqing/grouped_gemm/5c1d831ecf91b225abc91689683e7de67fbee7ef/figures/figure_groupedgemm.png -------------------------------------------------------------------------------- /figures/figure_permute.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fanshiqing/grouped_gemm/5c1d831ecf91b225abc91689683e7de67fbee7ef/figures/figure_permute.png -------------------------------------------------------------------------------- /figures/figure_unpermute.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fanshiqing/grouped_gemm/5c1d831ecf91b225abc91689683e7de67fbee7ef/figures/figure_unpermute.png -------------------------------------------------------------------------------- /grouped_gemm/__init__.py: -------------------------------------------------------------------------------- 1 | import grouped_gemm.ops 2 | import grouped_gemm.backend 3 | -------------------------------------------------------------------------------- /grouped_gemm/backend.py: 
-------------------------------------------------------------------------------- 1 | # NOTE: Torch needs to be imported before the custom 2 | # extensions. Otherwise libc10.so cannot be found. 3 | import torch 4 | 5 | # TODO(tgale): Wrap this in a try-block with better 6 | # error message and instructions for building the 7 | # c++ operations. 8 | import grouped_gemm_backend as backend 9 | 10 | 11 | def _allocate_output(a, b, batch_sizes, trans_a, trans_b): 12 | assert not (trans_a and trans_b) 13 | assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes" 14 | assert a.ndim == 2, "Expected 2d tensor for 'a'" 15 | assert b.ndim == (2 if trans_a else 3) 16 | 17 | shape = ( 18 | (batch_sizes.shape[0], a.shape[1], b.shape[1]) 19 | if trans_a else 20 | (a.shape[0], (b.shape[1] if trans_b else b.shape[2])) 21 | ) 22 | return torch.empty(*shape, device=a.device, dtype=a.dtype) 23 | 24 | def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None): 25 | if c is None: 26 | c = _allocate_output(a, b, batch_sizes, trans_a, trans_b) 27 | backend.gmm(a, b, c, batch_sizes, trans_a, trans_b) 28 | return c 29 | 30 | def sinkhorn(cost, tol=0.0001): 31 | return backend.sinkhorn(cost, tol) 32 | 33 | def permute(input, indices, num_out_tokens, workspace, max_expanded_token_num): 34 | return backend.permute(input, indices, num_out_tokens, workspace, max_expanded_token_num) 35 | 36 | def unpermute(input, row_id_map, prob, max_tokens, num_topK): 37 | return backend.unpermute(input, row_id_map, prob, max_tokens, num_topK) 38 | 39 | def unpermute_bwd(input_bwd, input_fwd, row_id_map, prob): 40 | # TODO: @Jiang fix the case in kernel to allow None probs 41 | if prob is None: 42 | prob = torch.ones([input_bwd.size(0), 1], dtype=torch.float32, device=input_bwd.device) 43 | return backend.unpermute_bwd(input_bwd, input_fwd, row_id_map, prob) 44 | -------------------------------------------------------------------------------- /grouped_gemm/ops.py: -------------------------------------------------------------------------------- 1 | from grouped_gemm import backend 2 | import torch 3 | import warnings 4 | 5 | from sys import stderr 6 | import torch.cuda.nvtx as nvtx 7 | 8 | 9 | # For debug's convenience 10 | ENABLE_NVTX=False 11 | 12 | class GroupedGemm(torch.autograd.Function): 13 | 14 | @staticmethod 15 | def forward(ctx, a, b, batch_sizes, trans_b): 16 | assert torch.count_nonzero(batch_sizes) != 0, "Input batch_sizes should not be all zeros!" 
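        # batch_sizes[i] is the number of rows of `a` routed to expert i; the groups are laid out
        # back to back along dim 0 of `a`, and group i is multiplied by its own matrix b[i]
        # (transposed when trans_b is True).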
17 | ctx.save_for_backward(a, b, batch_sizes) 18 | ctx.trans_b = trans_b 19 | return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b) 20 | 21 | @staticmethod 22 | def backward(ctx, grad): 23 | grad = grad.contiguous() 24 | a, b, batch_sizes = ctx.saved_tensors 25 | trans_b = ctx.trans_b 26 | 27 | agrad = None 28 | if ctx.needs_input_grad[0]: 29 | agrad = backend.gmm( 30 | grad, b, batch_sizes, trans_a=False, trans_b=not trans_b) 31 | 32 | bgrad = None 33 | if ctx.needs_input_grad[1]: 34 | lhs, rhs = (grad, a) if trans_b else (a, grad) 35 | bgrad = backend.gmm( 36 | lhs, rhs, batch_sizes, trans_a=True, trans_b=False) 37 | return agrad, bgrad, None, None 38 | 39 | 40 | def gmm(a, b, batch_sizes, trans_b=False): 41 | return GroupedGemm.apply(a, b, batch_sizes, trans_b) 42 | 43 | def sinkhorn_kernel(cost, tol=0.0001): 44 | return backend.sinkhorn(cost, tol) 45 | 46 | ################################################################################################ 47 | ## 48 | ## PermuteMoE topK 49 | ## 50 | ################################################################################################ 51 | 52 | class PermuteMoE_topK(torch.autograd.Function): 53 | 54 | workspace_fw=None 55 | dtype=None 56 | max_expanded_token_num=0 57 | 58 | @staticmethod 59 | def forward(ctx, 60 | input_act: torch.Tensor, 61 | indices: torch.Tensor, 62 | num_out_tokens: int, 63 | max_token_num: int): 64 | ''' 65 | indices: for topK=1, indices is a 1-d tensor of shape [num_tokens], 66 | otherwise, it's a 2-d tensor of shape [num_tokens, topK] 67 | ''' 68 | if ENABLE_NVTX: 69 | nvtx.range_push("permute_topK forward") 70 | # Empty input check 71 | if not input_act.numel(): 72 | if ENABLE_NVTX: 73 | nvtx.range_pop() 74 | return input_act, None 75 | 76 | # For the top1 case, view the indices as a 2D tensor to unify the shape with the topk>=2 cases. 77 | if indices.dim() == 1: 78 | indices = indices.view(-1, 1) 79 | 80 | # Device check 81 | if input_act.is_cpu: 82 | raise RuntimeError("[Error] The input `input_act` of permute_topK op is on the device: CPU!") 83 | if indices.is_cpu: 84 | warnings.warn("The input `indices` of permute_topK op is on the device: CPU!") 85 | indices = indices.cuda() 86 | 87 | # Shape check 88 | if input_act.size(0) != indices.size(0): 89 | raise RuntimeError(f"[Error] permute_topK op input `indices` shape mismatch! " 90 | f"Expect {input_act.size(0)}, but got {indices.size(0)}.") 91 | 92 | # Data type check 93 | if indices.dtype != torch.int32: 94 | warnings.warn(f"The data type of the input `indices` of permute_topK op is {indices.dtype}!
" 95 | "The recommended type is torch.int32.") 96 | indices = indices.to(torch.int32) 97 | 98 | # Contiguous check 99 | if not input_act.is_contiguous(): 100 | warnings.warn("The input `input_act` of permute_topK op is discontiguous!") 101 | input_act = input_act.contiguous() 102 | if not indices.is_contiguous(): 103 | warnings.warn("The input `indices` of permute_topK op is discontiguous!") 104 | indices = indices.contiguous() 105 | 106 | num_topK = indices.size(1) 107 | 108 | input_max_expanded_token_num = max(max_token_num, input_act.size(0)) * num_topK 109 | if PermuteMoE_topK.max_expanded_token_num < input_max_expanded_token_num: 110 | PermuteMoE_topK.max_expanded_token_num = input_max_expanded_token_num 111 | PermuteMoE_topK.workspace_fw = [] 112 | 113 | if PermuteMoE_topK.dtype != input_act.dtype: 114 | PermuteMoE_topK.dtype = input_act.dtype 115 | PermuteMoE_topK.workspace_fw = [] 116 | 117 | permuted_act, row_id_map, PermuteMoE_topK.workspace_fw = backend.permute( 118 | input_act, 119 | indices, 120 | num_out_tokens, 121 | PermuteMoE_topK.workspace_fw, 122 | PermuteMoE_topK.max_expanded_token_num) 123 | 124 | ctx.row_id_map = row_id_map 125 | ctx.num_tokens = indices.size(0) 126 | ctx.num_topK = num_topK 127 | if ENABLE_NVTX: 128 | nvtx.range_pop() 129 | return permuted_act, row_id_map 130 | 131 | 132 | @staticmethod 133 | def backward(ctx, permuted_act_grad, _): 134 | if ENABLE_NVTX: 135 | nvtx.range_push("permute_topK backward") 136 | # Empty input check 137 | if not permuted_act_grad.numel(): 138 | if ENABLE_NVTX: 139 | nvtx.range_pop() 140 | return permuted_act_grad, None, None, None 141 | 142 | if not permuted_act_grad.is_contiguous(): 143 | permuted_act_grad = permuted_act_grad.contiguous() 144 | 145 | row_id_map = ctx.row_id_map 146 | num_tokens = ctx.num_tokens 147 | num_topK = ctx.num_topK 148 | 149 | unpermuted_act_grad = backend.unpermute( 150 | permuted_act_grad, 151 | row_id_map, 152 | torch.tensor([]), 153 | num_tokens, 154 | num_topK) 155 | if ENABLE_NVTX: 156 | nvtx.range_pop() 157 | return unpermuted_act_grad, None, None, None 158 | 159 | ################################################################################################ 160 | ## 161 | ## UnpermuteMoE topK 162 | ## 163 | ################################################################################################ 164 | 165 | class UnpermuteMoE_topK(torch.autograd.Function): 166 | 167 | @staticmethod 168 | def forward(ctx, 169 | input_act: torch.Tensor, 170 | row_id_map: torch.Tensor, 171 | probs: torch.Tensor = None): 172 | if ENABLE_NVTX: 173 | nvtx.range_push("unpermute_topK forward") 174 | # Empty input check 175 | if not input_act.numel(): 176 | ctx.probs = probs 177 | if ENABLE_NVTX: 178 | nvtx.range_pop() 179 | return input_act 180 | 181 | # Device check 182 | if input_act.is_cpu: 183 | raise RuntimeError("[Error] The input `input_act` of unpermute_topK op is on the device: CPU!") 184 | if row_id_map.is_cpu: 185 | warnings.warn("The input `row_id_map` of unpermute_topK op is on the device: CPU!") 186 | row_id_map = row_id_map.cuda() 187 | if probs is not None and probs.is_cpu: 188 | warnings.warn("The input `probs` of unpermute_topK op is on the device: CPU!") 189 | probs = probs.cuda() 190 | 191 | # Shape check 192 | if probs is not None and row_id_map.size(0) != probs.size(0) * probs.size(1): 193 | raise RuntimeError(f"[Error] unpermute_topK op input `probs` shape mismatch! 
" 194 | f"Expect {row_id_map.size(0)}, but got {probs.size(0) * probs.size(1)}.") 195 | 196 | # Data type check 197 | if row_id_map.dtype != torch.int32: 198 | warnings.warn(f"The data type of the input `row_id_map` of unpermute_topK op is {row_id_map.dtype}! " 199 | "The recommended type is torch.int32.") 200 | row_id_map = row_id_map.to(torch.int32) 201 | if probs is not None and probs.dtype != torch.float32: 202 | warnings.warn(f"The data type of the input `probs` of unpermute_topK op is {probs.dtype}! " 203 | "The recommended type is torch.float32.") 204 | probs = probs.to(torch.float32) 205 | 206 | # Contiguous check 207 | if not input_act.is_contiguous(): 208 | warnings.warn("The input `input_act` of unpermute_topK op is discontiguous!") 209 | input_act = input_act.contiguous() 210 | if not row_id_map.is_contiguous(): 211 | warnings.warn("The input `row_id_map` of unpermute_topK op is discontiguous!") 212 | row_id_map = row_id_map.contiguous() 213 | if probs is not None and not probs.is_contiguous(): 214 | warnings.warn("The input `probs` of unpermute_topK op is discontiguous!") 215 | probs = probs.contiguous() 216 | 217 | num_tokens = probs.size(0) if probs is not None else input_act.size(0) 218 | num_topK = probs.size(1) if probs is not None else 1 219 | 220 | unpermuted_output = backend.unpermute( 221 | input_act, 222 | row_id_map, 223 | probs if probs is not None else torch.tensor([]), 224 | num_tokens, 225 | num_topK) 226 | 227 | ctx.save_for_backward(input_act, row_id_map, probs) 228 | if ENABLE_NVTX: 229 | nvtx.range_pop() 230 | return unpermuted_output 231 | 232 | @staticmethod 233 | def backward(ctx, unpermuted_act_grad): 234 | if ENABLE_NVTX: 235 | nvtx.range_push("unpermute_topK backward") 236 | # Empty input check 237 | if not unpermuted_act_grad.numel(): 238 | if ENABLE_NVTX: 239 | nvtx.range_pop() 240 | return unpermuted_act_grad, None, ctx.probs 241 | 242 | if not unpermuted_act_grad.is_contiguous(): 243 | unpermuted_act_grad = unpermuted_act_grad.contiguous() 244 | 245 | input_act, row_id_map, probs = ctx.saved_tensors 246 | 247 | act_grad = None 248 | if ctx.needs_input_grad[0]: 249 | act_grad, prob_grad = backend.unpermute_bwd( 250 | unpermuted_act_grad, 251 | input_act, 252 | row_id_map, 253 | probs) 254 | 255 | if not ctx.needs_input_grad[2]: 256 | prob_grad = None 257 | if ENABLE_NVTX: 258 | nvtx.range_pop() 259 | return act_grad, None, prob_grad 260 | 261 | def permute(input_act, indices, num_out_tokens=None, max_token_num=0): 262 | num_out_tokens = 0 if num_out_tokens is None else num_out_tokens 263 | return PermuteMoE_topK.apply(input_act, indices, num_out_tokens, max_token_num) 264 | 265 | def unpermute(input_act, row_id_map, probs=None): 266 | return UnpermuteMoE_topK.apply(input_act, row_id_map, probs) 267 | -------------------------------------------------------------------------------- /grouped_gemm/ops_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import itertools 3 | 4 | from absl.testing import parameterized 5 | from grouped_gemm import ops 6 | import numpy as np 7 | import torch 8 | 9 | 10 | def allclose(x, y, pct=2.0): 11 | mask = torch.isclose(x, y, rtol=1e-5) 12 | pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100 13 | if pct_diff > pct: 14 | print(x[torch.logical_not(mask)], y[torch.logical_not(mask)]) 15 | print("{:.2f}% of values not close.".format(pct_diff)) 16 | return False 17 | return True 18 | 19 | 20 | def add_transpose_flags(x): 21 | out = [] 22 | for y in x: 23 
| for f in [(False,), (True,)]: 24 | out.append(y + f) 25 | return out 26 | 27 | 28 | _TEST_PROBLEMS = add_transpose_flags(( 29 | (1, 128, 128, 128), 30 | (8, 128, 128, 128), 31 | (16, 128, 128, 128), 32 | (1, 128, 256, 512), 33 | (8, 128, 256, 512), 34 | (16, 128, 256, 512), 35 | )) 36 | 37 | 38 | def randn(bs, x, y): 39 | out = (torch.rand(bs, x, y) - 0.5 * 2) / (y * x) 40 | return out.cuda().to(torch.bfloat16) 41 | 42 | 43 | def gmm(a, b, batch_sizes, trans_b=False): 44 | batch_sizes = batch_sizes.numpy() 45 | 46 | out = [] 47 | start = 0 48 | for i, size in enumerate(batch_sizes): 49 | rhs = b[i, :, :].t() if trans_b else b[i, :, :] 50 | out.append(a[start:start + size, :] @ rhs) 51 | start += size 52 | return torch.cat(out) 53 | 54 | 55 | @parameterized.parameters(*_TEST_PROBLEMS) 56 | class OpsTest(parameterized.TestCase): 57 | 58 | def testGroupedGemm_FixedSizes(self, z, m, k, n, trans_b): 59 | torch.manual_seed(0) 60 | a = randn(z, m, k).view(-1, k) 61 | b = randn(z, n, k) if trans_b else randn(z, k, n) 62 | batch_sizes = torch.tensor([m] * z) 63 | 64 | a.requires_grad_(True) 65 | b.requires_grad_(True) 66 | a_ref = a.detach().clone().requires_grad_(True) 67 | b_ref = b.detach().clone().requires_grad_(True) 68 | 69 | out = ops.gmm(a, b, batch_sizes, trans_b) 70 | expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b) 71 | self.assertTrue(allclose(out, expected_out)) 72 | 73 | # Check gradients. 74 | out.sum().backward() 75 | expected_out.sum().backward() 76 | self.assertTrue(allclose(a.grad, a_ref.grad)) 77 | self.assertTrue(allclose(b.grad, b_ref.grad)) 78 | 79 | def testGroupedGemm_VariableSizes(self, z, m, k, n, trans_b): 80 | torch.manual_seed(0) 81 | a = randn(z, m, k).view(-1, k) 82 | b = randn(z, n, k) if trans_b else randn(z, k, n) 83 | 84 | dist = torch.rand(z, ) 85 | dist /= dist.sum() 86 | batch_sizes = (dist * m).to(torch.long) 87 | error = m * z - batch_sizes.sum() 88 | batch_sizes[-1] += error 89 | assert batch_sizes.sum() == (m * z) 90 | 91 | a.requires_grad_(True) 92 | b.requires_grad_(True) 93 | a_ref = a.detach().clone().requires_grad_(True) 94 | b_ref = b.detach().clone().requires_grad_(True) 95 | 96 | out = ops.gmm(a, b, batch_sizes, trans_b) 97 | expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b) 98 | self.assertTrue(allclose(out, expected_out)) 99 | 100 | # Check gradients. 101 | out.sum().backward() 102 | expected_out.sum().backward() 103 | self.assertTrue(allclose(a.grad, a_ref.grad)) 104 | self.assertTrue(allclose(b.grad, b_ref.grad)) 105 | 106 | 107 | 108 | if __name__ == '__main__': 109 | unittest.main() 110 | -------------------------------------------------------------------------------- /grouped_gemm/permute_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | import torch 6 | import triton 7 | import torch.cuda.nvtx as nvtx 8 | 9 | try: 10 | from grouped_gemm.ops import permute as permute_topK, unpermute as unpermute_topK 11 | except ImportError: 12 | print("grouped-gemm toolkit is not installed. Fall back to local import.") 13 | # For local debug 14 | import sys 15 | sys.path.append("..") 16 | from ops import permute as permute_topK, unpermute as unpermute_topK 17 | 18 | def permute(tokens, indices, expand_factor: int = 1, is_fp8=False): 19 | """Permute the tokens based on the indices. 20 | 21 | Args: 22 | tokens (torch.Tensor): The input token tensor. 
23 | indices (torch.Tensor): The token2expert indices tensor. 24 | 25 | Returns: 26 | torch.Tensor: The permuted tensor. 27 | """ 28 | expand_factor = indices.size(1) 29 | 30 | flatten_indices = indices.view(-1) 31 | sorted_indices = torch.argsort(flatten_indices, stable=True) 32 | permuted_tokens = tokens.index_select(0, sorted_indices // expand_factor) 33 | return permuted_tokens, sorted_indices 34 | 35 | 36 | def unpermute(permuted_tokens, sorted_indices, probs: torch.Tensor = None, merge_factor: int = 1): 37 | """Unpermute the sorted tokens based on the indices. 38 | 39 | Args: 40 | permuted_tokens (torch.Tensor): The permuted token tensor. 41 | sorted_indices (torch.Tensor): The sorted indices tensor. 42 | probs (torch.Tensor, optional): The probabilities tensor. Defaults to None. 43 | merge_factor (int, optional): The merge factor. Defaults to 1. 44 | 45 | Returns: 46 | torch.Tensor: The unpermuted tensor. 47 | """ 48 | merge_factor = probs.size(1) 49 | 50 | if merge_factor > 1: 51 | assert probs is not None 52 | assert ( 53 | probs.size(0) == permuted_tokens.size(0) // merge_factor 54 | ), f"{probs.size()} {permuted_tokens.size()}" 55 | if probs is not None: 56 | assert probs.size(0) == permuted_tokens.size(0) // merge_factor 57 | assert ( 58 | probs.size(1) == merge_factor 59 | ), f"probs size {probs.size()} merge_factor {merge_factor}" 60 | 61 | # unpermuted_tokens = torch.zeros_like(permuted_tokens) 62 | unpermuted_tokens = permuted_tokens.index_copy(0, sorted_indices, permuted_tokens) 63 | 64 | unpermuted_tokens = unpermuted_tokens.reshape(-1, merge_factor, permuted_tokens.size(-1)) 65 | 66 | if probs is not None: 67 | dtype = unpermuted_tokens.dtype 68 | unpermuted_tokens = unpermuted_tokens * probs.unsqueeze(-1) 69 | unpermuted_tokens = unpermuted_tokens.to(dtype) 70 | unpermuted_tokens = unpermuted_tokens.sum(dim=1) 71 | 72 | return unpermuted_tokens 73 | 74 | def permute_topK_test( 75 | dtype, 76 | num_token, 77 | num_expert, 78 | hidden_size, 79 | num_topK, 80 | PRINT, 81 | BENCHMARK): 82 | 83 | print(f"{dtype} token:{num_token} hidden_size:{hidden_size} expert:{num_expert} topK:{num_topK}") 84 | 85 | is_fp8 = dtype in [torch.float8_e5m2, torch.float8_e4m3fn] 86 | 87 | permute_input = torch.rand((num_token, hidden_size), dtype=torch.float32).cuda() 88 | # for i in range(num_token): 89 | # for j in range(hidden_size): 90 | # permute_input[i][j] = i * 100 + j 91 | permute_input = permute_input.to(dtype) 92 | if is_fp8: 93 | permute_input = permute_input.half() 94 | 95 | permute_input.requires_grad_(True) 96 | 97 | if num_token > 0: 98 | indices = torch.stack([torch.randperm(num_expert)[:num_topK] for _ in range(num_token)]) 99 | else: 100 | indices = torch.empty((num_token, num_topK)) 101 | indices = indices.to(torch.int32).cuda() 102 | 103 | # probs = torch.tensor([[0.1, 0.9], 104 | # [0.2, 0.8], 105 | # [0.3, 0.7]]) 106 | # 0.5 107 | # probs = torch.ones_like(indices) / 2 108 | # rand 109 | probs = torch.rand(num_token, num_topK).cuda() 110 | row_sums = probs.sum(dim=1, keepdim=True) 111 | probs = probs / row_sums 112 | probs.requires_grad_(True) 113 | 114 | if PRINT: 115 | print(permute_input) 116 | print(indices) 117 | print(probs) 118 | 119 | ################################################################################################################################### 120 | # 121 | # PyTorch 122 | # 123 | ################################################################################################################################### 124 | nvtx.range_push("PyTorch 
permute forward") 125 | permute_output, sorted_indices = permute(permute_input, indices, num_topK, is_fp8) 126 | nvtx.range_pop() 127 | 128 | permute_bwd_input = torch.rand_like(permute_output) 129 | # for i in range(num_token * num_topK): 130 | # for j in range(hidden_size): 131 | # permute_bwd_input[i][j] = i * 100 + j 132 | 133 | nvtx.range_push("PyTorch permute backward") 134 | permute_output.backward(permute_bwd_input, retain_graph=True) 135 | nvtx.range_pop() 136 | 137 | unpermute_input = permute_output.detach() 138 | unpermute_input.requires_grad_(True) 139 | 140 | unpermute_output = unpermute( 141 | unpermute_input, sorted_indices, probs=probs, merge_factor=num_topK) 142 | 143 | if PRINT: 144 | print("--------------unpermute fwd permute_input--------------") 145 | print(unpermute_input) 146 | print("--------------unpermute fwd output--------------") 147 | print(unpermute_output) 148 | 149 | unpermute_bwd_input = torch.rand_like(unpermute_output) 150 | # for i in range(num_token): 151 | # for j in range(hidden_size): 152 | # unpermute_bwd_input[i][j] = i * 2000 + j * 20 153 | 154 | if PRINT: 155 | print("--------------unpermute bwd permute_input--------------") 156 | print(unpermute_bwd_input) 157 | 158 | unpermute_output.backward(unpermute_bwd_input, retain_graph=True) 159 | if PRINT: 160 | print("--------------unpermute bwd output act grad--------------") 161 | print(permute_output.grad) 162 | print("--------------unpermute bwd output probs grad--------------") 163 | print(probs.grad) 164 | 165 | ################################################################################################################################### 166 | # 167 | # Mine 168 | # 169 | ################################################################################################################################### 170 | new_permute_input = permute_input.detach().to(dtype) 171 | new_permute_bwd_input = permute_bwd_input.detach().to(dtype) 172 | new_unpermute_bwd_input = unpermute_bwd_input.detach().to(dtype) 173 | new_permute_input.requires_grad_(True) 174 | 175 | new_permute_output, row_id_map = permute_topK(new_permute_input, indices) 176 | 177 | assert torch.allclose(permute_output.float(), new_permute_output.float()) 178 | 179 | if PRINT: 180 | print("--------------row_id_map--------------") 181 | print(row_id_map) 182 | print("--------------new_permute_input--------------") 183 | print(new_permute_input) 184 | print("--------------new_permute_output--------------") 185 | print(new_permute_output) 186 | 187 | new_permute_output.backward(new_permute_bwd_input, retain_graph=True) 188 | 189 | if torch.allclose(permute_input.grad.float(), new_permute_input.grad.float()) == False: 190 | original_inputs = new_permute_input.grad.float().cpu().numpy().flatten() 191 | original_output = permute_input.grad.float().cpu().numpy().flatten() 192 | max_abs_error = abs(original_inputs - original_output).max() 193 | print(f"permute_topK bwd max error (mine vs pytorch): \t\t\t{max_abs_error:.3e} ({dtype})") 194 | 195 | if PRINT: 196 | print(permute_input.grad) 197 | print(new_permute_input.grad) 198 | 199 | new_probs = probs.detach() 200 | new_probs.requires_grad_(True) 201 | new_unpermute_input = new_permute_output.detach() 202 | new_unpermute_input.requires_grad_(True) 203 | 204 | print("new_probs=", new_probs) 205 | new_unpermute_output = unpermute_topK(new_unpermute_input, row_id_map, new_probs) 206 | 207 | if torch.allclose(unpermute_output.float(), new_unpermute_output.float()) == False: 208 | original_inputs = 
unpermute_output.float().cpu().detach().numpy().flatten() 209 | original_output = new_unpermute_output.float().cpu().detach().numpy().flatten() 210 | max_abs_error = abs(original_inputs - original_output).max() 211 | print(f"unpermute_topK fwd max error (mine vs pytorch): \t\t{max_abs_error:.3e} ({dtype})") 212 | 213 | if PRINT: 214 | print(unpermute_output) 215 | print(new_unpermute_output) 216 | 217 | new_unpermute_output.backward(new_unpermute_bwd_input, retain_graph=True) 218 | 219 | if torch.allclose(unpermute_input.grad.float(), new_unpermute_input.grad.float()) == False: 220 | original_inputs = unpermute_input.grad.float().cpu().detach().numpy().flatten() 221 | original_output = new_unpermute_input.grad.float().cpu().detach().numpy().flatten() 222 | max_abs_error = abs(original_inputs - original_output).max() 223 | print(f"unpermute_topK bwd act_grad max error (mine vs pytorch): \t{max_abs_error:.3e} ({dtype})") 224 | if PRINT: 225 | print(new_unpermute_input.grad) 226 | print(unpermute_input.grad) 227 | 228 | if num_topK > 1 and torch.allclose(new_probs.grad, probs.grad) == False: 229 | original_inputs = new_probs.grad.float().cpu().detach().numpy().flatten() 230 | original_output = probs.grad.float().cpu().detach().numpy().flatten() 231 | max_abs_error = abs(original_inputs - original_output).max() 232 | print(f"unpermute_topK bwd prob_grad max error (mine vs pytorch): \t{max_abs_error:.3e} ({dtype})") 233 | if PRINT: 234 | print(new_probs.grad) 235 | print(probs.grad) 236 | 237 | if not permute_input.numel(): 238 | print("Empty permute_input activation test passed.") 239 | return 240 | 241 | ################################################################################################################################### 242 | # 243 | # Benchmark 244 | # 245 | ################################################################################################################################### 246 | def backward_wrapper(act, backward_input, forward_input=[], retain_graph=True, accumulate_grad=False): 247 | # Set forward_input.grad to None to avoid grad accumulation. 
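        # (otherwise every timed backward() call would accumulate into .grad and the benchmark
        # would also measure that extra accumulation add)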
248 | if accumulate_grad == False: 249 | for i in forward_input: 250 | i.grad = None 251 | return act.backward(backward_input, retain_graph=retain_graph) 252 | 253 | if BENCHMARK: 254 | print(f"----permute topK----") 255 | t = perf_test_cuda_kernel(lambda: permute(permute_input, indices, 2)) 256 | print(f"pytorch fwd: {t:.3f} ms") 257 | t = perf_test_cuda_kernel(lambda: permute_topK(new_permute_input, indices)) 258 | print(f"new fwd: {t:.3f} ms") 259 | 260 | t = perf_test_cuda_kernel( 261 | lambda: backward_wrapper(permute_output, permute_bwd_input, forward_input=[permute_input], retain_graph=True, accumulate_grad=False)) 262 | print(f"pytorch bwd: {t:.3f} ms") 263 | t = perf_test_cuda_kernel( 264 | lambda: backward_wrapper(new_permute_output, new_permute_bwd_input, forward_input=[new_permute_input], retain_graph=True, accumulate_grad=False)) 265 | print(f"new bwd: {t:.3f} ms") 266 | 267 | print(f"----unpermute topK----") 268 | t = perf_test_cuda_kernel( 269 | lambda: unpermute(unpermute_input, sorted_indices, probs=probs, merge_factor=num_topK)) 270 | print(f"pytorch fwd: {t:.3f} ms") 271 | t = perf_test_cuda_kernel( 272 | lambda: unpermute_topK(new_unpermute_input, row_id_map, new_probs)) 273 | print(f"new fwd: {t:.3f} ms") 274 | 275 | t = perf_test_cuda_kernel( 276 | lambda: backward_wrapper(unpermute_output, unpermute_bwd_input, forward_input=[unpermute_input, probs], retain_graph=True, accumulate_grad=False)) 277 | print(f"pytorch bwd: {t:.3f} ms") 278 | t = perf_test_cuda_kernel( 279 | lambda: backward_wrapper(new_unpermute_output, new_unpermute_bwd_input, forward_input=[new_unpermute_input, new_probs], retain_graph=True, accumulate_grad=False)) 280 | print(f"new bwd: {t:.3f} ms") 281 | 282 | 283 | def perf_test_cuda_kernel(cuda_kernel_fn): 284 | if torch.cuda.is_available(): 285 | # create CUDA event 286 | start_event = torch.cuda.Event(enable_timing=True) 287 | end_event = torch.cuda.Event(enable_timing=True) 288 | 289 | # warmup 290 | for _ in range(50): 291 | cuda_kernel_fn() 292 | 293 | start_event.record() 294 | for _ in range(100): 295 | cuda_kernel_fn() 296 | end_event.record() 297 | torch.cuda.synchronize() 298 | 299 | elapsed_time_ms = start_event.elapsed_time(end_event) 300 | # print(f"Elapsed Time: {elapsed_time_ms / 100} ms") 301 | return elapsed_time_ms / 100 302 | else: 303 | print("CUDA is not available.") 304 | 305 | def test_permute_topK(): 306 | 307 | torch.manual_seed(1) 308 | 309 | num_token = 4096 * 2 310 | num_expert = 8 311 | hidden_size = 4096 312 | num_topK = 2 313 | 314 | PRINT=False 315 | Benchmark = False 316 | print("GPU:", torch.cuda.get_device_name(0)) 317 | 318 | dtype = torch.float32 319 | permute_topK_test(dtype, num_token, num_expert, 320 | hidden_size, num_topK, PRINT, Benchmark) 321 | dtype = torch.float16 322 | permute_topK_test(dtype, num_token, num_expert, 323 | hidden_size, num_topK, False, Benchmark) 324 | dtype = torch.bfloat16 325 | permute_topK_test(dtype, num_token, num_expert, 326 | hidden_size, num_topK, False, Benchmark) 327 | dtype = torch.float8_e5m2 328 | permute_topK_test(dtype, num_token, num_expert, 329 | hidden_size, num_topK, False, Benchmark) 330 | dtype = torch.float8_e4m3fn 331 | permute_topK_test(dtype, num_token, num_expert, 332 | hidden_size, num_topK, False, Benchmark) 333 | dtype = torch.bfloat16 334 | permute_topK_test(dtype, num_token, 4, hidden_size, 1, False, Benchmark) 335 | permute_topK_test(dtype, num_token, 5, hidden_size, 2, False, Benchmark) 336 | permute_topK_test(dtype, num_token, 6, hidden_size, 3, False, 
Benchmark) 337 | permute_topK_test(dtype, num_token, 7, hidden_size, 4, False, Benchmark) 338 | permute_topK_test(dtype, num_token, 8, hidden_size, 5, False, Benchmark) 339 | num_token = 0 340 | permute_topK_test(dtype, num_token, 8, hidden_size, 5, False, Benchmark) 341 | 342 | if __name__ == "__main__": 343 | test_permute_topK() -------------------------------------------------------------------------------- /grouped_gemm/sinkhorn_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import itertools 3 | 4 | from absl.testing import parameterized 5 | from grouped_gemm import ops, ops_test 6 | import numpy as np 7 | import torch 8 | 9 | # Notes: the source of this func implementation: 10 | # https://github.com/NVIDIA/Megatron-LM/blob/2bc6cd307a11423928c675f741e79e03df23e721/megatron/core/transformer/switch_mlp.py#L17-L31 11 | def baseline_sinkhorn(cost, tol=0.0001): 12 | "Sinkhorn based MoE routing function" 13 | cost = torch.exp(cost) 14 | d0 = torch.ones(cost.size(0), device=cost.device, dtype=cost.dtype) 15 | d1 = torch.ones(cost.size(1), device=cost.device, dtype=cost.dtype) 16 | 17 | eps = 0.00000001 18 | error = 1e9 19 | d1_old = d1 20 | iter_count = 0 21 | while error > tol: 22 | d0 = (1 / d0.size(0)) * 1 / (torch.sum(d1 * cost, 1) + eps) 23 | d1 = (1 / d1.size(0)) * 1 / (torch.sum(d0.unsqueeze(1) * cost, 0) + eps) 24 | error = torch.mean(torch.abs(d1_old - d1)) 25 | d1_old = d1 26 | iter_count = iter_count + 1 27 | return d1 * cost * d0.unsqueeze(1) 28 | 29 | _TEST_PROBLEMS = ( 30 | # (128, 2, 0.1), 31 | # (256, 2, 0.1), 32 | # (1024, 4, 0.01), 33 | # (2048, 8, 0.0001), 34 | (4096, 8, 0.0001), 35 | (8192, 8, 0.0001), 36 | (8192, 16, 0.0001), 37 | ) 38 | 39 | @parameterized.parameters(*_TEST_PROBLEMS) 40 | class OpsTest(parameterized.TestCase): 41 | 42 | def test_sinkhorn_kernel(self, m, k, tol=0.0001): 43 | start = torch.cuda.Event(enable_timing=True) 44 | end = torch.cuda.Event(enable_timing=True) 45 | 46 | torch.manual_seed(0) 47 | cost = torch.rand(m, k, device='cuda', dtype=torch.float32) 48 | 49 | start.record() 50 | expected_out = baseline_sinkhorn(cost, tol) 51 | end.record() 52 | torch.cuda.synchronize() 53 | baseline_time = start.elapsed_time(end) 54 | 55 | start.record() 56 | out = ops.sinkhorn_kernel(cost, tol) 57 | end.record() 58 | torch.cuda.synchronize() 59 | kernel_time = start.elapsed_time(end) 60 | print("===================================") 61 | print("Problem size: [%d]x[%d], kernel speedup: %fX" % (m, k, baseline_time/kernel_time)) 62 | print("===================================") 63 | self.assertTrue(ops_test.allclose(out, expected_out)) 64 | 65 | if __name__ == '__main__': 66 | unittest.main() -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from setuptools import setup, find_packages 4 | import torch 5 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 6 | 7 | 8 | # Supported NVIDIA GPU architectures. 9 | NVIDIA_SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"} 10 | 11 | # TORCH_CUDA_ARCH_LIST can have one or more architectures, 12 | # e.g. "9.0" or "7.0 7.2 7.5 8.0 8.6 8.7 9.0+PTX". Here, 13 | # the "9.0+PTX" option asks the 14 | # compiler to additionally include PTX code that can be runtime-compiled 15 | # and executed on the 8.6 or newer architectures. 
While the PTX code will 16 | # not give the best performance on the newer architectures, it provides 17 | # forward compatibility. 18 | env_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None) 19 | if env_arch_list: 20 | # Let PyTorch builder to choose device to target for. 21 | device_capability = "" 22 | else: 23 | device_capability = torch.cuda.get_device_capability() 24 | device_capability = f"{device_capability[0]}{device_capability[1]}" 25 | 26 | cwd = Path(os.path.dirname(os.path.abspath(__file__))) 27 | 28 | nvcc_flags = [ 29 | "-std=c++17", # NOTE: CUTLASS requires c++17 30 | "-DENABLE_BF16", # Enable BF16 for cuda_version >= 11 31 | # "-DENABLE_FP8", # Enable FP8 for cuda_version >= 11.8 32 | ] 33 | 34 | if device_capability: 35 | nvcc_flags.extend([ 36 | f"--generate-code=arch=compute_{device_capability},code=sm_{device_capability}", 37 | f"-DGROUPED_GEMM_DEVICE_CAPABILITY={device_capability}", 38 | ]) 39 | 40 | ext_modules = [ 41 | CUDAExtension( 42 | "grouped_gemm_backend", 43 | ["csrc/ops.cu", "csrc/grouped_gemm.cu", "csrc/sinkhorn.cu", "csrc/permute.cu"], 44 | include_dirs = [ 45 | f"{cwd}/third_party/cutlass/include/" 46 | ], 47 | extra_compile_args={ 48 | "cxx": [ 49 | "-fopenmp", "-fPIC", "-Wno-strict-aliasing" 50 | ], 51 | "nvcc": nvcc_flags, 52 | } 53 | ) 54 | ] 55 | 56 | setup( 57 | name="grouped_gemm", 58 | version="1.1.4", 59 | author="Trevor Gale, Jiang Shao, Shiqing Fan", 60 | author_email="tgale@stanford.edu, jiangs@nvidia.com, shiqingf@nvidia.com", 61 | description="GEMM Grouped", 62 | url="https://github.com/fanshiqing/grouped_gemm", 63 | classifiers=[ 64 | "Programming Language :: Python :: 3", 65 | "License :: OSI Approved :: BSD License", 66 | "Operating System :: Unix", 67 | ], 68 | packages=find_packages(), 69 | ext_modules=ext_modules, 70 | cmdclass={"build_ext": BuildExtension}, 71 | install_requires=["absl-py", "numpy", "torch"], 72 | ) 73 | --------------------------------------------------------------------------------
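A minimal end-to-end sketch of how the ops above compose into a top-K MoE expert layer. This example is not part of the repository: the token/expert counts and weight shapes are illustrative, `batch_sizes` is kept as a CPU int64 tensor following the convention used in `grouped_gemm/ops_test.py`, and the BF16 activations assume the extension was built with the default `-DENABLE_BF16` flag from setup.py.

import torch
from grouped_gemm import ops

num_tokens, hidden, ffn, num_experts, topK = 16, 128, 256, 4, 2

tokens = torch.randn(num_tokens, hidden, dtype=torch.bfloat16, device="cuda")
# Top-K expert choice per token, shape [num_tokens, topK]; the permute op expects int32 on CUDA.
indices = torch.stack([torch.randperm(num_experts)[:topK] for _ in range(num_tokens)]).int().cuda()
# Router probabilities for the chosen experts, float32, one weight per (token, k) pair.
probs = torch.softmax(torch.rand(num_tokens, topK, device="cuda"), dim=-1)

# 1) Gather each token once per selected expert; moe_permute_topK_op radix-sorts by expert
#    index, so the permuted rows come out grouped by expert id.
permuted, row_id_map = ops.permute(tokens, indices)

# 2) One GEMM per expert via the grouped GEMM; the group sizes follow from the routing.
batch_sizes = torch.bincount(indices.view(-1).long(), minlength=num_experts).cpu()
weights = torch.randn(num_experts, hidden, ffn, dtype=torch.bfloat16, device="cuda")
expert_out = ops.gmm(permuted, weights, batch_sizes, trans_b=False)

# 3) Scatter the expert outputs back to token order, reducing over topK weighted by probs.
out = ops.unpermute(expert_out, row_id_map, probs)
assert out.shape == (num_tokens, ffn)

Because `permute`, `unpermute`, and `gmm` are all implemented as autograd Functions in grouped_gemm/ops.py, the same flow can be dropped into a training step and backpropagated through.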