├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── csrc
│   ├── grouped_gemm.cu
│   ├── grouped_gemm.h
│   ├── ops.cu
│   ├── permute.cu
│   ├── permute.h
│   ├── sinkhorn.cu
│   └── sinkhorn.h
├── figures
│   ├── figure1.png
│   ├── figure_groupedgemm.png
│   ├── figure_permute.png
│   └── figure_unpermute.png
├── grouped_gemm
│   ├── __init__.py
│   ├── backend.py
│   ├── ops.py
│   ├── ops_test.py
│   ├── permute_test.py
│   └── sinkhorn_test.py
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *~
2 | build
3 | *.egg-info
4 | dist
5 | __pycache__
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "third_party/cutlass"]
2 | path = third_party/cutlass
3 | url = https://github.com/NVIDIA/cutlass
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License. Apache License
202 | Version 2.0, January 2004
203 | http://www.apache.org/licenses/
204 |
205 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
206 |
207 | 1. Definitions.
208 |
209 | "License" shall mean the terms and conditions for use, reproduction,
210 | and distribution as defined by Sections 1 through 9 of this document.
211 |
212 | "Licensor" shall mean the copyright owner or entity authorized by
213 | the copyright owner that is granting the License.
214 |
215 | "Legal Entity" shall mean the union of the acting entity and all
216 | other entities that control, are controlled by, or are under common
217 | control with that entity. For the purposes of this definition,
218 | "control" means (i) the power, direct or indirect, to cause the
219 | direction or management of such entity, whether by contract or
220 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
221 | outstanding shares, or (iii) beneficial ownership of such entity.
222 |
223 | "You" (or "Your") shall mean an individual or Legal Entity
224 | exercising permissions granted by this License.
225 |
226 | "Source" form shall mean the preferred form for making modifications,
227 | including but not limited to software source code, documentation
228 | source, and configuration files.
229 |
230 | "Object" form shall mean any form resulting from mechanical
231 | transformation or translation of a Source form, including but
232 | not limited to compiled object code, generated documentation,
233 | and conversions to other media types.
234 |
235 | "Work" shall mean the work of authorship, whether in Source or
236 | Object form, made available under the License, as indicated by a
237 | copyright notice that is included in or attached to the work
238 | (an example is provided in the Appendix below).
239 |
240 | "Derivative Works" shall mean any work, whether in Source or Object
241 | form, that is based on (or derived from) the Work and for which the
242 | editorial revisions, annotations, elaborations, or other modifications
243 | represent, as a whole, an original work of authorship. For the purposes
244 | of this License, Derivative Works shall not include works that remain
245 | separable from, or merely link (or bind by name) to the interfaces of,
246 | the Work and Derivative Works thereof.
247 |
248 | "Contribution" shall mean any work of authorship, including
249 | the original version of the Work and any modifications or additions
250 | to that Work or Derivative Works thereof, that is intentionally
251 | submitted to Licensor for inclusion in the Work by the copyright owner
252 | or by an individual or Legal Entity authorized to submit on behalf of
253 | the copyright owner. For the purposes of this definition, "submitted"
254 | means any form of electronic, verbal, or written communication sent
255 | to the Licensor or its representatives, including but not limited to
256 | communication on electronic mailing lists, source code control systems,
257 | and issue tracking systems that are managed by, or on behalf of, the
258 | Licensor for the purpose of discussing and improving the Work, but
259 | excluding communication that is conspicuously marked or otherwise
260 | designated in writing by the copyright owner as "Not a Contribution."
261 |
262 | "Contributor" shall mean Licensor and any individual or Legal Entity
263 | on behalf of whom a Contribution has been received by Licensor and
264 | subsequently incorporated within the Work.
265 |
266 | 2. Grant of Copyright License. Subject to the terms and conditions of
267 | this License, each Contributor hereby grants to You a perpetual,
268 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
269 | copyright license to reproduce, prepare Derivative Works of,
270 | publicly display, publicly perform, sublicense, and distribute the
271 | Work and such Derivative Works in Source or Object form.
272 |
273 | 3. Grant of Patent License. Subject to the terms and conditions of
274 | this License, each Contributor hereby grants to You a perpetual,
275 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
276 | (except as stated in this section) patent license to make, have made,
277 | use, offer to sell, sell, import, and otherwise transfer the Work,
278 | where such license applies only to those patent claims licensable
279 | by such Contributor that are necessarily infringed by their
280 | Contribution(s) alone or by combination of their Contribution(s)
281 | with the Work to which such Contribution(s) was submitted. If You
282 | institute patent litigation against any entity (including a
283 | cross-claim or counterclaim in a lawsuit) alleging that the Work
284 | or a Contribution incorporated within the Work constitutes direct
285 | or contributory patent infringement, then any patent licenses
286 | granted to You under this License for that Work shall terminate
287 | as of the date such litigation is filed.
288 |
289 | 4. Redistribution. You may reproduce and distribute copies of the
290 | Work or Derivative Works thereof in any medium, with or without
291 | modifications, and in Source or Object form, provided that You
292 | meet the following conditions:
293 |
294 | (a) You must give any other recipients of the Work or
295 | Derivative Works a copy of this License; and
296 |
297 | (b) You must cause any modified files to carry prominent notices
298 | stating that You changed the files; and
299 |
300 | (c) You must retain, in the Source form of any Derivative Works
301 | that You distribute, all copyright, patent, trademark, and
302 | attribution notices from the Source form of the Work,
303 | excluding those notices that do not pertain to any part of
304 | the Derivative Works; and
305 |
306 | (d) If the Work includes a "NOTICE" text file as part of its
307 | distribution, then any Derivative Works that You distribute must
308 | include a readable copy of the attribution notices contained
309 | within such NOTICE file, excluding those notices that do not
310 | pertain to any part of the Derivative Works, in at least one
311 | of the following places: within a NOTICE text file distributed
312 | as part of the Derivative Works; within the Source form or
313 | documentation, if provided along with the Derivative Works; or,
314 | within a display generated by the Derivative Works, if and
315 | wherever such third-party notices normally appear. The contents
316 | of the NOTICE file are for informational purposes only and
317 | do not modify the License. You may add Your own attribution
318 | notices within Derivative Works that You distribute, alongside
319 | or as an addendum to the NOTICE text from the Work, provided
320 | that such additional attribution notices cannot be construed
321 | as modifying the License.
322 |
323 | You may add Your own copyright statement to Your modifications and
324 | may provide additional or different license terms and conditions
325 | for use, reproduction, or distribution of Your modifications, or
326 | for any such Derivative Works as a whole, provided Your use,
327 | reproduction, and distribution of the Work otherwise complies with
328 | the conditions stated in this License.
329 |
330 | 5. Submission of Contributions. Unless You explicitly state otherwise,
331 | any Contribution intentionally submitted for inclusion in the Work
332 | by You to the Licensor shall be under the terms and conditions of
333 | this License, without any additional terms or conditions.
334 | Notwithstanding the above, nothing herein shall supersede or modify
335 | the terms of any separate license agreement you may have executed
336 | with Licensor regarding such Contributions.
337 |
338 | 6. Trademarks. This License does not grant permission to use the trade
339 | names, trademarks, service marks, or product names of the Licensor,
340 | except as required for reasonable and customary use in describing the
341 | origin of the Work and reproducing the content of the NOTICE file.
342 |
343 | 7. Disclaimer of Warranty. Unless required by applicable law or
344 | agreed to in writing, Licensor provides the Work (and each
345 | Contributor provides its Contributions) on an "AS IS" BASIS,
346 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
347 | implied, including, without limitation, any warranties or conditions
348 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
349 | PARTICULAR PURPOSE. You are solely responsible for determining the
350 | appropriateness of using or redistributing the Work and assume any
351 | risks associated with Your exercise of permissions under this License.
352 |
353 | 8. Limitation of Liability. In no event and under no legal theory,
354 | whether in tort (including negligence), contract, or otherwise,
355 | unless required by applicable law (such as deliberate and grossly
356 | negligent acts) or agreed to in writing, shall any Contributor be
357 | liable to You for damages, including any direct, indirect, special,
358 | incidental, or consequential damages of any character arising as a
359 | result of this License or out of the use or inability to use the
360 | Work (including but not limited to damages for loss of goodwill,
361 | work stoppage, computer failure or malfunction, or any and all
362 | other commercial damages or losses), even if such Contributor
363 | has been advised of the possibility of such damages.
364 |
365 | 9. Accepting Warranty or Additional Liability. While redistributing
366 | the Work or Derivative Works thereof, You may choose to offer,
367 | and charge a fee for, acceptance of support, warranty, indemnity,
368 | or other liability obligations and/or rights consistent with this
369 | License. However, in accepting such obligations, You may act only
370 | on Your own behalf and on Your sole responsibility, not on behalf
371 | of any other Contributor, and only if You agree to indemnify,
372 | defend, and hold each Contributor harmless for any liability
373 | incurred by, or claims asserted against, such Contributor by reason
374 | of your accepting any such warranty or additional liability.
375 |
376 | END OF TERMS AND CONDITIONS
377 |
378 | Copyright 2023 MegaBlocks authors
379 |
380 | Licensed under the Apache License, Version 2.0 (the "License");
381 | you may not use this file except in compliance with the License.
382 | You may obtain a copy of the License at
383 |
384 | http://www.apache.org/licenses/LICENSE-2.0
385 |
386 | Unless required by applicable law or agreed to in writing, software
387 | distributed under the License is distributed on an "AS IS" BASIS,
388 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
389 | See the License for the specific language governing permissions and
390 | limitations under the License.
391 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | Grouped GEMM for MoE
4 | ===========================
5 | A PyTorch Toolbox for Grouped GEMM in MoE Model Training
6 |
7 | [license](./LICENSE)
8 |
9 |
10 |
11 | - [Grouped GEMM for MoE](#grouped-gemm-for-moe)
12 | - [Steps for Using](#steps-for-using)
13 | - [pip install](#pip-install)
14 | - [Build from Source](#build-from-source)
15 | - [Support Matrix](#support-matrix)
16 | - [permute \& unpermute](#permute--unpermute)
17 | - [Ops Usage](#ops-usage)
18 | - [permute](#permute)
19 | - [Parameters](#parameters)
20 | - [unpermute](#unpermute)
21 | - [Parameters](#parameters-1)
22 | - [Example](#example)
23 |
24 | ---
25 |
26 | # Steps for Using
27 |
28 | ## pip install
29 | ```bash
30 | pip install --verbose git+https://github.com/fanshiqing/grouped_gemm@main
31 | ```
32 |
33 | ## Build from Source
34 | ```bash
35 | git submodule update --init --recursive
36 | mkdir build
37 | cd build
38 | cmake ..
39 | make -j
40 | cd ..
41 |
42 | # GroupedGEMM ops test
43 | python grouped_gemm/ops_test.py
44 |
45 | # topK permute & unpermute ops test
46 | python grouped_gemm/permute_test.py
47 |
48 | # sinkhorn kernel test
49 | python grouped_gemm/sinkhorn_test.py
50 | ```
51 |
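After installing (or building) the package, a quick smoke test confirms that the CUDA extension loads and that the documented ops are exposed. This is a minimal sketch, assuming a CUDA-capable GPU is visible and the package imports as `grouped_gemm`:

```py
# Minimal sanity check after installation (sketch; op names follow the sections below).
import torch
from grouped_gemm import ops

assert torch.cuda.is_available(), "a CUDA device is required for these ops"
print(hasattr(ops, "permute"), hasattr(ops, "unpermute"))  # expected: True True
```
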
52 | # Support Matrix
53 |
54 | ## permute & unpermute
55 |
56 | | GPU Arch | FP32 | FP16 | BF16 | FP8 |
57 | | :--------- | :---: | :---: | :---: | :---: |
58 | | SM 70 | Y | Y | . | Y |
59 | | SM 75 | Y | Y | . | Y |
60 | | SM 80 | Y | Y | Y | Y |
61 | | SM 86 | Y | Y | Y | Y |
62 | | SM 89 | Y | Y | Y | Y |
63 | | SM 90 | Y | Y | Y | Y |
64 |
65 | # Ops Usage
66 |
67 | ## permute
68 |
69 | > ```py
70 | > grouped_gemm.ops.permute(
71 | > input_act: torch.Tensor,
72 | > indices: torch.Tensor,
73 | > num_out_tokens: int = 0,
74 | >     max_token_num: int = 0) -> tuple
75 | > ```
76 |
77 | The op returns a tuple of two tensors, `permuted_act` and `row_id_map`.
78 |
79 | * `permuted_act` is the permutation of the original tensor `input_act` with its first dimension permuted according to `indices`.
80 | * `row_id_map` is the mapping table for the row indices of the input activations before and after `grouped_gemm.ops.permute`, which is used for the following `unpermute` op.
81 |
82 | ### Parameters
83 |
84 | * **input_act** (torch.Tensor)
85 | shape = [tokens_num, hidden_size]
86 | The input activations; each row (token) is dispatched to `topK` experts.
87 |
88 | * **indices** (torch.Tensor)
89 | shape = [tokens_num, topK_num]
90 | The topK expert indices for each row (token) of activations. The `int32` type is recommended.
91 |
92 | * **num_out_tokens** (int)
93 | The number of output tokens (rows) to keep; used for the token-drop feature. A value of 0 keeps all `tokens_num * topK_num` rows.
94 |
95 | * **max_token_num** (int)
96 | The maximum number of tokens (rows) used for workspace pre-allocation.
97 |
98 |
99 |
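As a concrete illustration of the two optional arguments above, here is a hedged sketch of a capacity-limited (token-drop) call; the sizes are made up for the example and the keyword names follow the signature shown at the top of this section:

```py
# Token-drop sketch: keep at most `capacity` of the tokens_num * topK_num expanded rows.
import torch
from grouped_gemm import ops

tokens_num, topK_num, hidden_size = 8, 2, 16
capacity = 12  # num_out_tokens: rows beyond this are dropped (0 would keep all 16 rows)

input_act = torch.randn(tokens_num, hidden_size, dtype=torch.float16, device='cuda')
indices = torch.randint(0, 4, (tokens_num, topK_num), dtype=torch.int32, device='cuda')

permuted_act, row_id_map = ops.permute(
    input_act,
    indices,
    num_out_tokens=capacity,
    max_token_num=tokens_num * topK_num)  # workspace pre-allocated for the full expansion
print(permuted_act.shape)  # torch.Size([12, 16]) when token drop is active
```
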
100 | ## unpermute
101 |
102 | > ```py
103 | > grouped_gemm.ops.unpermute(
104 | > input_act: torch.Tensor,
105 | > row_id_map: torch.Tensor,
106 | >     probs: torch.Tensor) -> torch.Tensor
107 | > ```
108 |
109 | The mirror operator of `grouped_gemm.ops.permute`.
110 |
111 | ### Parameters
112 |
113 | * **input_act** (torch.Tensor)
114 | shape = [tokens_num * topK_num, hidden_size]
115 | The permuted activations produced by `grouped_gemm.ops.permute`.
116 |
117 | * **row_id_map** (torch.Tensor)
118 | shape = [tokens_num * topK_num]
119 | The mapping table for the row indices of the activations before and after `grouped_gemm.ops.permute`. The second output tensor of `grouped_gemm.ops.permute`.
120 |
121 | * **probs** (torch.Tensor)
122 | shape = [tokens_num, topK_num]
123 | The weights used to sum the rows that originate from the same source token but were routed to different experts.
124 |
125 |
126 |
127 | ### Example
128 |
129 | ```py
130 | import torch
131 | from grouped_gemm import permute, unpermute
132 |
133 | indices = torch.tensor([[1, 2], [0, 1], [0, 2], [1, 2]], dtype=torch.int32, device='cuda')
134 | input_act = torch.tensor([[0,0,0,0], [1,1,1,1], [2,2,2,2], [3,3,3,3]], dtype=torch.float32, device='cuda')
135 | probs = torch.ones_like(indices, dtype=torch.float32)
136 | permuted_inputs, row_id_map = permute(input_act, indices)
137 | unpermute_outputs = unpermute(permuted_inputs, row_id_map, probs)
138 |
139 | print(row_id_map)
140 | print(input_act)
141 | print(permuted_inputs)
142 | print(unpermute_outputs)
143 |
144 | # Output
145 | # tensor([2, 0, 1, 4, 5, 3, 6, 7], device='cuda:0', dtype=torch.int32)
146 | # tensor([[0., 0., 0., 0.],
147 | # [1., 1., 1., 1.],
148 | # [2., 2., 2., 2.],
149 | # [3., 3., 3., 3.]], device='cuda:0')
150 | # tensor([[1., 1., 1., 1.],
151 | # [2., 2., 2., 2.],
152 | # [0., 0., 0., 0.],
153 | # [1., 1., 1., 1.],
154 | # [3., 3., 3., 3.],
155 | # [0., 0., 0., 0.],
156 | # [2., 2., 2., 2.],
157 | # [3., 3., 3., 3.]], device='cuda:0')
158 | # tensor([[0., 0., 0., 0.],
159 | # [2., 2., 2., 2.],
160 | # [4., 4., 4., 4.],
161 | # [6., 6., 6., 6.]], device='cuda:0')
162 | ```
163 |
164 |
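Beyond permute/unpermute, the grouped GEMM op itself is what `grouped_gemm/ops_test.py` exercises. The sketch below assumes it is exposed as `ops.gmm` with the layout expected by the C++ binding (`csrc/ops.cu` / `csrc/grouped_gemm.cu`): bf16 inputs, one weight matrix per expert, and per-expert token counts passed as an int64 CPU tensor; check `grouped_gemm/ops.py` for the authoritative Python signature.

```py
# Grouped GEMM sketch: one [m_i, k] x [k, n] GEMM per expert, m_i given by batch_sizes.
import torch
from grouped_gemm import ops

num_experts, hidden_in, hidden_out = 4, 16, 32
batch_sizes = torch.tensor([3, 1, 5, 7])  # int64 on CPU, one token count per expert

a = torch.randn(int(batch_sizes.sum()), hidden_in, dtype=torch.bfloat16, device='cuda')
b = torch.randn(num_experts, hidden_in, hidden_out, dtype=torch.bfloat16, device='cuda')

c = ops.gmm(a, b, batch_sizes)  # expected shape: [sum(batch_sizes), hidden_out]
print(c.shape)                  # torch.Size([16, 32])
```
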
--------------------------------------------------------------------------------
/csrc/grouped_gemm.cu:
--------------------------------------------------------------------------------
1 | #include "grouped_gemm.h"
2 |
3 | #include <c10/cuda/CUDAStream.h>
4 | #include <cublas_v2.h>
5 | #include <torch/extension.h>
6 | #include <vector>
7 |
8 | #include "cutlass/bfloat16.h"
9 | #include "cutlass/complex.h"
10 | #include "cutlass/gemm/kernel/gemm_grouped.h"
11 | #include "cutlass/gemm/kernel/default_gemm_grouped.h"
12 | #include "cutlass/gemm/device/gemm_grouped.h"
13 |
14 | namespace grouped_gemm {
15 |
16 | #define NUM_STREAM 4
17 |
18 | #define CUDA_CALL(code) \
19 | do { \
20 | cudaError_t status = code; \
21 | std::string err = cudaGetErrorString(status); \
22 | TORCH_CHECK(status == cudaSuccess, err); \
23 | } while (0)
24 |
25 | #define CUBLAS_CALL(code) \
26 | do { \
27 | cublasStatus_t status = code; \
28 | TORCH_CHECK(status == CUBLAS_STATUS_SUCCESS, "CuBLAS Error"); \
29 | } while (0)
30 |
31 | #define GROUPED_GEMM_STRINGIFY_HELPER(x) #x
32 | #define GROUPED_GEMM_STRINGIFY(x) \
33 | GROUPED_GEMM_STRINGIFY_HELPER(x)
34 |
35 | // TODO(tgale): Update this for SM90 when it's supported by CUTLASS.
36 | using GroupedGemmKernelNN = typename cutlass::gemm::kernel::DefaultGemmGrouped<
37 | // Non-transposed A operand.
38 | ::cutlass::bfloat16_t,
39 | ::cutlass::layout::RowMajor,
40 | ::cutlass::ComplexTransform::kNone,
41 | 8,
42 | // Non-transposed B operand.
43 | ::cutlass::bfloat16_t,
44 | ::cutlass::layout::RowMajor,
45 | ::cutlass::ComplexTransform::kNone,
46 | 8,
47 | // C operand.
48 | ::cutlass::bfloat16_t,
49 | ::cutlass::layout::RowMajor,
50 | float,
51 | ::cutlass::arch::OpClassTensorOp,
52 | ::cutlass::arch::Sm80,
53 | ::cutlass::gemm::GemmShape<128, 128, 32>,
54 | ::cutlass::gemm::GemmShape<64, 64, 32>,
55 | ::cutlass::gemm::GemmShape<16, 8, 16>,
56 | ::cutlass::epilogue::thread::LinearCombination<::cutlass::bfloat16_t, 8, float, float>,
57 | // NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels.
58 | // This parameter is passed in at present to match the APIs of other kernels. The parameter
59 | // is unused within the kernel.
60 | ::cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
61 | // TODO(tgale): Experiment with GroupScheduleMode.
62 | // TODO(tgale): Tune this for SM90.
63 | 4>::GemmKernel;
64 | using GemmGroupedNN = ::cutlass::gemm::device::GemmGrouped<GroupedGemmKernelNN>;
65 |
66 | std::vector<cutlass::gemm::GemmCoord> MakeProblemSizes(torch::Tensor b, torch::Tensor batch_sizes) {
67 | const size_t num_experts = batch_sizes.size(0);
68 | const size_t k = b.size(1), n = b.size(2);
69 | std::vector<cutlass::gemm::GemmCoord> problem_sizes(num_experts);
70 | for (int i = 0; i < num_experts; ++i) {
71 | problem_sizes[i] = cutlass::gemm::GemmCoord(batch_sizes.data_ptr<int64_t>()[i], n, k);
72 | }
73 | return problem_sizes;
74 | }
75 |
76 | template <typename T>
77 | torch::Tensor CopyToDevice(const std::vector<T> &x, const torch::Device &device) {
78 | size_t bytes = x.size() * sizeof(T);
79 | auto options = torch::TensorOptions().dtype(torch::kInt8).device(device);
80 | torch::Tensor out = torch::empty(bytes, options);
81 |
82 | CUDA_CALL(cudaMemcpyAsync(out.data_ptr(),
83 | x.data(), bytes,
84 | cudaMemcpyHostToDevice,
85 | c10::cuda::getCurrentCUDAStream()));
86 | return out;
87 | }
88 |
89 | template <typename Gemm>
90 | typename Gemm::Arguments MakeArguments(torch::Tensor a,
91 | torch::Tensor b,
92 | torch::Tensor c,
93 | torch::Tensor batch_sizes) {
94 | auto problem_sizes_host = MakeProblemSizes(b, batch_sizes);
95 |
96 | // Calculate the number of threadblocks to use and validate the result.
97 | int64_t num_experts = problem_sizes_host.size();
98 |
99 | // NOTE: This is borrowed from FasterTransformer.
100 | int threadblock_count = Gemm::sufficient(problem_sizes_host.data(), num_experts);
101 | if (!threadblock_count) {
102 | TORCH_CHECK(false, "Grouped GEMM execution not possible with HW");
103 | }
104 |
105 | // Create the host arrays of leading dimension data and pointer data.
106 | using LayoutA = typename Gemm::LayoutA;
107 | using LayoutB = typename Gemm::LayoutB;
108 | using LayoutC = typename Gemm::LayoutC;
109 |
110 | std::vector<int64_t> lda_host(num_experts), offsets_a(num_experts);
111 | std::vector<int64_t> ldb_host(num_experts), offsets_b(num_experts);
112 | std::vector<int64_t> ldc_host(num_experts), offsets_c(num_experts);
113 | int64_t elements_a = 0, elements_b = 0, elements_c = 0;
114 |
115 | using ElementA = typename Gemm::ElementA;
116 | using ElementB = typename Gemm::ElementB;
117 | using ElementC = typename Gemm::ElementC;
118 | std::vector<ElementA *> ptr_a_host(num_experts);
119 | std::vector<ElementB *> ptr_b_host(num_experts);
120 | std::vector<ElementC *> ptr_c_host(num_experts);
121 |
122 | for (int i = 0; i < num_experts; ++i) {
123 | auto problem = problem_sizes_host[i];
124 | lda_host[i] = LayoutA::packed({problem.m(), problem.k()}).stride(0);
125 | ldb_host[i] = LayoutB::packed({problem.k(), problem.n()}).stride(0);
126 | ldc_host[i] = LayoutC::packed({problem.m(), problem.n()}).stride(0);
127 |
128 | offsets_a[i] = elements_a;
129 | offsets_b[i] = elements_b;
130 | offsets_c[i] = elements_c;
131 |
132 | ptr_a_host[i] = (ElementA*)a.data_ptr() + offsets_a[i];
133 | ptr_b_host[i] = (ElementB*)b.data_ptr() + offsets_b[i];
134 | ptr_c_host[i] = (ElementC*)c.data_ptr() + offsets_c[i];
135 |
136 | elements_a += problem.m() * problem.k();
137 | elements_b += problem.k() * problem.n();
138 | elements_c += problem.m() * problem.n();
139 | }
140 |
141 | // Copy the problem sizes, pointers and leading dimension data to the device.
142 | torch::Tensor lda = CopyToDevice(lda_host, a.device());
143 | torch::Tensor ldb = CopyToDevice(ldb_host, a.device());
144 | torch::Tensor ldc = CopyToDevice(ldc_host, a.device());
145 | torch::Tensor ptr_a = CopyToDevice(ptr_a_host, a.device());
146 | torch::Tensor ptr_b = CopyToDevice(ptr_b_host, a.device());
147 | torch::Tensor ptr_c = CopyToDevice(ptr_c_host, a.device());
148 | torch::Tensor problem_sizes = CopyToDevice(problem_sizes_host, a.device());
149 |
150 | typename Gemm::EpilogueOutputOp::Params epilogue_op(/*alpha=*/1.0f, /*beta=*/0.0f);
151 | typename Gemm::Arguments arguments((cutlass::gemm::GemmCoord*)problem_sizes.data_ptr(),
152 | (int)num_experts,
153 | (int)threadblock_count,
154 | epilogue_op,
155 | (ElementA**)ptr_a.data_ptr(),
156 | (ElementB**)ptr_b.data_ptr(),
157 | (ElementC**)ptr_c.data_ptr(),
158 | (ElementC**)ptr_c.data_ptr(),
159 | /*lda=*/(int64_t*)lda.data_ptr(),
160 | /*ldb=*/(int64_t*)ldb.data_ptr(),
161 | /*ldc=*/(int64_t*)ldc.data_ptr(),
162 | /*ldd=*/(int64_t*)ldc.data_ptr(),
163 | (cutlass::gemm::GemmCoord*)problem_sizes_host.data());
164 | return arguments;
165 | }
166 |
167 | torch::Tensor CutlassGroupedGemm(torch::Tensor a,
168 | torch::Tensor b,
169 | torch::Tensor c,
170 | torch::Tensor batch_sizes) {
171 | using Gemm = GemmGroupedNN;
172 | Gemm gemm;
173 |
174 | auto arguments = MakeArguments<Gemm>(a, b, c, batch_sizes);
175 | int64_t workspace_size = gemm.get_workspace_size(arguments);
176 | auto options = torch::TensorOptions().dtype(torch::kInt8).device(a.device());
177 | torch::Tensor workspace = torch::empty(workspace_size, options);
178 |
179 | // Initialize the kernel.
180 | if(gemm.initialize(arguments, workspace.data_ptr()) != cutlass::Status::kSuccess) {
181 | TORCH_CHECK(false, "Failed to initialize CUTLASS Grouped GEMM");
182 | }
183 |
184 | // Execute the kernel in the current stream.
185 | if(gemm.run(c10::cuda::getCurrentCUDAStream()) != cutlass::Status::kSuccess) {
186 | TORCH_CHECK(false, "Failed to run CUTLASS Grouped GEMM");
187 | }
188 | return c;
189 | }
190 |
191 | cublasHandle_t cublas_handle[NUM_STREAM];
192 | cudaStream_t cublas_stream[NUM_STREAM];
193 | cudaEvent_t cublas_event[NUM_STREAM];
194 | bool cublas_init = false;
195 |
196 | void cublas_handle_init()
197 | {
198 | cublas_init = true;
199 |
200 | for (int i = 0; i < NUM_STREAM; i++)
201 | {
202 | cudaStreamCreateWithFlags(&cublas_stream[i], cudaStreamNonBlocking);
203 | cublasCreate(&cublas_handle[i]);
204 | cublasSetStream(cublas_handle[i], cublas_stream[i]);
205 | cudaEventCreate(&cublas_event[i]);
206 | }
207 | }
208 |
209 | inline void cublas_current_wait_streams(cudaStream_t stream)
210 | {
211 | for (int s = 0; s < NUM_STREAM; s++)
212 | {
213 | cudaEventRecord(cublas_event[s], cublas_stream[s]);
214 | }
215 |
216 | for (int s = 0; s < NUM_STREAM; s++)
217 | {
218 | cudaStreamWaitEvent(stream, cublas_event[s]);
219 | }
220 | }
221 |
222 | inline void cublas_streams_wait_current(cudaStream_t stream)
223 | {
224 | cudaEventRecord(cublas_event[0], stream);
225 |
226 | for (int s = 0; s < NUM_STREAM; s++)
227 | {
228 | cudaStreamWaitEvent(cublas_stream[s], cublas_event[0]);
229 | }
230 | }
231 |
232 | void CublasGemm(cublasHandle_t cublas_handle,
233 | c10::BFloat16 *a, int64_t a_rows, int64_t a_cols, bool trans_a,
234 | c10::BFloat16 *b, int64_t b_rows, int64_t b_cols, bool trans_b,
235 | c10::BFloat16 *c, int64_t c_rows, int64_t c_cols) {
236 | int m = trans_b ? b_rows : b_cols;
237 | int k = trans_b ? b_cols : b_rows;
238 | int n = trans_a ? a_cols : a_rows;
239 |
240 | int lda = trans_a ? n : k;
241 | int ldb = trans_b ? k : m;
242 | cublasOperation_t transpose_a = trans_a ? CUBLAS_OP_T : CUBLAS_OP_N;
243 | cublasOperation_t transpose_b = trans_b ? CUBLAS_OP_T : CUBLAS_OP_N;
244 |
245 |
246 | float alpha = 1.0, beta = 0.0;
247 | CUBLAS_CALL(cublasGemmEx(cublas_handle,
248 | transpose_b, transpose_a,
249 | m, n, k, &alpha,
250 | b, CUDA_R_16BF, ldb,
251 | a, CUDA_R_16BF, lda,
252 | &beta,
253 | c, CUDA_R_16BF, c_cols, CUDA_R_32F,
254 | CUBLAS_GEMM_DEFAULT));
255 | }
256 |
257 | void CublasGroupedGemm(torch::Tensor a,
258 | torch::Tensor b,
259 | torch::Tensor c,
260 | torch::Tensor batch_sizes,
261 | bool trans_b) {
262 | if (!cublas_init)
263 | cublas_handle_init();
264 |
265 | int64_t bs = batch_sizes.size(0), k = a.size(1);
266 | int64_t n = trans_b ? b.size(1) : b.size(2);
267 | int64_t b_rows = b.size(1), b_cols = b.size(2);
268 | c10::BFloat16* a_ptr = a.data_ptr<c10::BFloat16>();
269 | c10::BFloat16* b_ptr = b.data_ptr<c10::BFloat16>();
270 | c10::BFloat16* c_ptr = c.data_ptr<c10::BFloat16>();
271 |
272 | cublas_streams_wait_current(c10::cuda::getCurrentCUDAStream());
273 |
274 | for (int i = 0; i < bs; ++i) {
275 |
276 | int64_t m = batch_sizes.data_ptr<int64_t>()[i];
277 | CublasGemm(cublas_handle[i % NUM_STREAM], a_ptr, m, k, /*trans_a=*/false,
278 | b_ptr, b_rows, b_cols, trans_b,
279 | c_ptr, m, n);
280 | a_ptr += m * k;
281 | b_ptr += b_rows * b_cols;
282 | c_ptr += m * n;
283 | }
284 |
285 | cublas_current_wait_streams(c10::cuda::getCurrentCUDAStream());
286 | }
287 |
288 | void CublasGroupedGemmVariableK(torch::Tensor a,
289 | torch::Tensor b,
290 | torch::Tensor c,
291 | torch::Tensor batch_sizes) {
292 | if (!cublas_init)
293 | cublas_handle_init();
294 |
295 | int64_t bs = batch_sizes.size(0), m = a.size(1), n = b.size(1);
296 | c10::BFloat16* a_ptr = a.data_ptr<c10::BFloat16>();
297 | c10::BFloat16* b_ptr = b.data_ptr<c10::BFloat16>();
298 | c10::BFloat16* c_ptr = c.data_ptr<c10::BFloat16>();
299 |
300 | cublas_streams_wait_current(c10::cuda::getCurrentCUDAStream());
301 |
302 | for (int i = 0; i < bs; ++i) {
303 | int64_t k = batch_sizes.data_ptr<int64_t>()[i];
304 | CublasGemm(cublas_handle[i % NUM_STREAM], a_ptr, k, m, /*trans_a=*/true,
305 | b_ptr, k, n, /*trans_b=*/false,
306 | c_ptr, m, n);
307 | a_ptr += k * m;
308 | b_ptr += k * n;
309 | c_ptr += m * n;
310 | }
311 |
312 | cublas_current_wait_streams(c10::cuda::getCurrentCUDAStream());
313 | }
314 |
315 | void GroupedGemmVariableK(torch::Tensor a,
316 | torch::Tensor b,
317 | torch::Tensor c,
318 | torch::Tensor batch_sizes) {
319 | // We expect a CUDA tensor with two dimensions and shape
320 | // (tokens, hidden_out) for 'b'.
321 | TORCH_CHECK(b.is_cuda());
322 | TORCH_CHECK(b.ndimension() == 2);
323 | TORCH_CHECK(b.scalar_type() == torch::kBFloat16);
324 |
325 | // Validate the dimensions.
326 | int64_t tokens = a.size(0), num_experts = batch_sizes.size(0);
327 | int64_t m = a.size(1), n = b.size(1);
328 |
329 | // Validate that we have the same contraction dimension.
330 | TORCH_CHECK(tokens == b.size(0));
331 |
332 | // Validate the output shape.
333 | TORCH_CHECK(c.is_cuda());
334 | TORCH_CHECK(c.ndimension() == 3);
335 | TORCH_CHECK(c.scalar_type() == torch::kBFloat16);
336 | TORCH_CHECK(c.size(0) == num_experts);
337 | TORCH_CHECK(c.size(1) == m);
338 | TORCH_CHECK(c.size(2) == n);
339 |
340 | // Run the computation.
341 | CublasGroupedGemmVariableK(a, b, c, batch_sizes);
342 | }
343 |
344 | // NOTE: We only support dynamic group sizes for the 'a' tensor. Tensor 'b' is
345 | // assumed to be batched with fixed sized batches.
346 | //
347 | // TODO(tgale): Validate alignment is true for every batch element.
348 | void GroupedGemm(torch::Tensor a,
349 | torch::Tensor b,
350 | torch::Tensor c,
351 | torch::Tensor batch_sizes,
352 | bool trans_a, bool trans_b) {
353 | // NOTE: We only support 'trans_a' or 'trans_b', not both.
354 | TORCH_CHECK(!(trans_a && trans_b));
355 |
356 | // We expect the batch_sizes on CPU.
357 | TORCH_CHECK(batch_sizes.is_cpu());
358 | TORCH_CHECK(batch_sizes.ndimension() == 1);
359 | TORCH_CHECK(batch_sizes.scalar_type() == torch::kInt64);
360 |
361 | // We expect a CUDA tensor with two dimensions and shape
362 | // (tokens, hidden_in) for 'a'.
363 | TORCH_CHECK(a.is_cuda());
364 | TORCH_CHECK(a.ndimension() == 2);
365 | TORCH_CHECK(a.scalar_type() == torch::kBFloat16);
366 |
367 | // Defer to the variable 'k' helper for the rest of the op.
368 | if (trans_a) {
369 | GroupedGemmVariableK(a, b, c, batch_sizes);
370 | return;
371 | }
372 |
373 | // We expect a CUDA tensor with three dimensions and shape
374 | // (num_experts, hidden_in, hidden_out) for 'b'.
375 | TORCH_CHECK(b.is_cuda());
376 | TORCH_CHECK(b.ndimension() == 3);
377 | TORCH_CHECK(b.scalar_type() == torch::kBFloat16);
378 |
379 | // Validate the contraction dimensions match.
380 | int64_t tokens = a.size(0), num_experts = b.size(0);
381 | int64_t hidden_in = trans_b ? b.size(2) : b.size(1);
382 | int64_t hidden_out = trans_b ? b.size(1) : b.size(2);
383 | TORCH_CHECK(hidden_in == a.size(1));
384 |
385 | // Validate that we have one size per expert.
386 | TORCH_CHECK(batch_sizes.size(0) == num_experts);
387 |
388 | // Validate the output shape.
389 | TORCH_CHECK(c.is_cuda());
390 | TORCH_CHECK(c.ndimension() == 2);
391 | TORCH_CHECK(c.scalar_type() == torch::kBFloat16);
392 | TORCH_CHECK(c.size(0) == tokens);
393 | TORCH_CHECK(c.size(1) == hidden_out);
394 |
395 | // NOTE: We support transposition through the 'trans_b' flag.
396 | TORCH_CHECK(a.is_contiguous());
397 | TORCH_CHECK(b.is_contiguous());
398 |
399 | // NOTE: Use cuBLAS for SM90 until CUTLASS supports SM90-optimized grouped-gemm.
400 | #if !defined(GROUPED_GEMM_DEVICE_CAPABILITY) || GROUPED_GEMM_DEVICE_CAPABILITY != 80
401 | CublasGroupedGemm(a, b, c, batch_sizes, trans_b);
402 | return;
403 | #else
404 | // TODO(tgale): Support transposition with CUTLASS grouped GEMM.
405 | if (trans_b) {
406 | CublasGroupedGemm(a, b, c, batch_sizes, trans_b);
407 | return;
408 | }
409 | CutlassGroupedGemm(a, b, c, batch_sizes);
410 | return;
411 | #endif
412 | }
413 |
414 | } // namespace grouped_gemm
415 |
--------------------------------------------------------------------------------
/csrc/grouped_gemm.h:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 | namespace grouped_gemm {
4 |
5 | void GroupedGemm(torch::Tensor a,
6 | torch::Tensor b,
7 | torch::Tensor c,
8 | torch::Tensor batch_sizes,
9 | bool trans_a, bool trans_b);
10 |
11 | } // namespace grouped_gemm
12 |
--------------------------------------------------------------------------------
/csrc/ops.cu:
--------------------------------------------------------------------------------
1 | #include "grouped_gemm.h"
2 | #include "permute.h"
3 | #include "sinkhorn.h"
4 |
5 | #include <torch/extension.h>
6 |
7 | namespace grouped_gemm {
8 |
9 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
10 | m.def("gmm", &GroupedGemm, "Grouped GEMM.");
11 | m.def("sinkhorn", &sinkhorn, "Sinkhorn kernel");
12 | m.def("permute", &moe_permute_topK_op, "Token permutation kernel");
13 | m.def("unpermute", &moe_recover_topK_op, "Token un-permutation kernel");
14 | m.def("unpermute_bwd", &moe_recover_topK_bwd_op, "Token un-permutation backward kernel");
15 | }
16 |
17 | } // namespace grouped_gemm
18 |
--------------------------------------------------------------------------------
/csrc/permute.cu:
--------------------------------------------------------------------------------
1 | /*************************************************************************
2 | * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | *
4 | * See LICENSE for license information.
5 | ************************************************************************/
6 |
7 | #include "permute.h"
8 |
9 | #include <torch/extension.h>
10 | #include <cub/cub.cuh>
11 | #include <limits>
12 |
13 | #include "cuda_runtime.h"
14 | #include "device_launch_parameters.h"
15 |
16 | #include "ATen/cuda/CUDAContext.h"
17 |
18 | #include "cutlass/arch/memory.h"
19 | #include "cutlass/arch/cache_operation.h"
20 | #include "cutlass/array.h"
21 | #include "cutlass/numeric_conversion.h"
22 |
23 |
24 | using torch::Tensor;
25 |
26 | namespace grouped_gemm {
27 |
28 | template <typename T>
29 | inline T *get_ptr(torch::Tensor &t)
30 | {
31 | return reinterpret_cast<T *>(t.data_ptr());
32 | }
33 |
34 | /////////////////////////////////////////////////////////////////////////////////////////////////
35 | //
36 | // Top K
37 | //
38 | /////////////////////////////////////////////////////////////////////////////////////////////////
39 |
40 | static __global__ void moe_permute_topK_row_map(
41 | const int *sorted_row_id,
42 | int *row_id_map,
43 | const int num_rows,
44 | const int num_topK,
45 | const int num_out_tokens)
46 | {
47 |     // Each thread handles one (token, topK) entry of the expanded input
48 | // row_id_map[num_topK][num_rows]
49 | const int bid = blockIdx.x;
50 | const int tid = threadIdx.x;
51 | const int idx = bid * blockDim.x + tid;
52 |
53 | if (idx >= num_rows * num_topK)
54 | return;
55 |
56 | int source_row = sorted_row_id[idx];
57 | int source_token_id = source_row / num_topK;
58 | int source_topK_id = source_row % num_topK;
59 |
60 | if (idx >= num_out_tokens)
61 | {
62 | row_id_map[source_topK_id * num_rows + source_token_id] = -1;
63 | }
64 | else
65 | {
66 | row_id_map[source_topK_id * num_rows + source_token_id] = idx;
67 | }
68 | }
69 |
70 | template <typename T, typename TCompute, int kElementsPerAccess, bool hasProb>
71 | __global__ void moe_recover_topK_kernel(const T *input,
72 | T *unpermuted_output,
73 | const int *row_id_map,
74 | const float *prob,
75 | const int num_rows,
76 | const int num_topK,
77 | const int num_cols)
78 | {
79 | extern __shared__ int8_t s_mem[];
80 |     TCompute *s_prob = reinterpret_cast<TCompute *>(s_mem);
81 |
82 |     using FragmentLoadStore = cutlass::Array<T, kElementsPerAccess>;
83 |     using FragmentCompute = cutlass::Array<TCompute, kElementsPerAccess>;
84 | 
85 |     cutlass::NumericArrayConverter<TCompute, T, kElementsPerAccess> src_converter;
86 |     cutlass::NumericArrayConverter<T, TCompute, kElementsPerAccess> dst_converter;
87 |
88 | // each block corresponds to one source token
89 | const int source_token = blockIdx.x;
90 | const int tid = threadIdx.x;
91 |
92 | if (hasProb)
93 | {
94 | for (int i = tid; i < num_topK; i += blockDim.x * blockDim.y)
95 | {
96 | s_prob[i] = TCompute(prob[source_token * num_topK + i]);
97 | }
98 | __syncthreads();
99 | }
100 |
101 | for (int i = tid * kElementsPerAccess; i < num_cols; i += blockDim.x * kElementsPerAccess)
102 | {
103 | FragmentLoadStore frag_load_store;
104 | FragmentCompute frag_elem;
105 | FragmentCompute frag_sum;
106 |
107 | int source_row = row_id_map[source_token];
108 |
109 | if (source_row != -1)
110 | {
111 | const T *source_row_ptr = input + source_row * num_cols;
112 |
113 | cutlass::arch::global_load(
114 | frag_load_store, (source_row_ptr + i), true);
115 | frag_sum = src_converter(frag_load_store);
116 |
117 | if (hasProb)
118 | {
119 | frag_sum = frag_sum * s_prob[0];
120 | }
121 | }
122 | else
123 | {
124 | frag_sum.clear();
125 | }
126 |
127 | for (int k = 1; k < num_topK; k++)
128 | {
129 | source_row = row_id_map[k * num_rows + source_token];
130 |
131 | if (source_row == -1)
132 | continue;
133 |
134 | const T *source_row_ptr = input + source_row * num_cols;
135 |
136 | cutlass::arch::global_load(
137 | frag_load_store, (source_row_ptr + i), true);
138 | frag_elem = src_converter(frag_load_store);
139 |
140 | if (hasProb)
141 | {
142 | frag_elem = frag_elem * s_prob[k];
143 | }
144 |
145 | for (int e = 0; e < kElementsPerAccess; e++)
146 | {
147 | frag_sum.at(e) = frag_sum.at(e) + frag_elem.at(e);
148 | }
149 | }
150 |
151 | T *dest_row_ptr = unpermuted_output + source_token * num_cols;
152 | frag_load_store = dst_converter(frag_sum);
153 | *(float4 *)(dest_row_ptr + i) = *(float4 *)(frag_load_store.data());
154 | }
155 | }
156 |
157 | template <typename T,
158 |           typename TCompute,
159 |           int kElementsPerAccess,
160 |           int topKTile,
161 |           bool hasProb>
162 | __global__ void moe_permute_topK_kernel(const T *input_bwd,
163 | const T *input_fwd,
164 | T *act_grad,
165 | const float *prob,
166 | float *prob_grad,
167 | const int *row_id_map,
168 | const int num_rows,
169 | const int num_topK,
170 | const int num_cols)
171 | {
172 | extern __shared__ int8_t s_mem[];
173 |     TCompute *s_prob = reinterpret_cast<TCompute *>(s_mem);
174 |
175 |     using FragmentLoadStore = cutlass::Array<T, kElementsPerAccess>;
176 |     using FragmentCompute = cutlass::Array<TCompute, kElementsPerAccess>;
177 | 
178 |     cutlass::NumericArrayConverter<TCompute, T, kElementsPerAccess> src_converter;
179 |     cutlass::NumericArrayConverter<T, TCompute, kElementsPerAccess> dst_converter;
180 |
181 | const int source_token = blockIdx.x;
182 | const int tid = threadIdx.x;
183 |
184 | if (hasProb)
185 | {
186 | for (int i = tid; i < num_topK; i += blockDim.x)
187 | {
188 | s_prob[i] = TCompute(prob[source_token * num_topK + i]);
189 | }
190 | __syncthreads();
191 | }
192 |
193 | float accum[topKTile] = {0.0f};
194 | FragmentLoadStore frag_load_store;
195 |
196 | const T *source_row_ptr = input_bwd + source_token * num_cols;
197 | for (int i = tid * kElementsPerAccess; i < num_cols; i += blockDim.x * kElementsPerAccess)
198 | {
199 | cutlass::arch::global_load(
200 | frag_load_store, (source_row_ptr + i), true);
201 | FragmentCompute frag_src = src_converter(frag_load_store);
202 |
203 | int index = source_token;
204 |
205 | for (int k = 0; k < topKTile; k++)
206 | {
207 | if (k == num_topK) break;
208 |
209 | int dest_row = row_id_map[index];
210 | index += num_rows;
211 |
212 | if (dest_row == -1)
213 | continue;
214 |
215 | if (hasProb)
216 | {
217 | frag_load_store = dst_converter(frag_src * s_prob[k]);
218 | }
219 | else
220 | {
221 | frag_load_store = dst_converter(frag_src);
222 | }
223 |
224 | T *dest_row_ptr = act_grad + dest_row * num_cols;
225 | *(float4 *)(dest_row_ptr + i) = *(float4 *)(frag_load_store.data());
226 |
227 | if (hasProb)
228 | {
229 | const T *input_fwd_ptr = input_fwd + dest_row * num_cols;
230 | cutlass::arch::global_load(
231 | frag_load_store, (input_fwd_ptr + i), true);
232 | FragmentCompute frag_input_fwd = src_converter(frag_load_store);
233 |
234 | for (int e = 0; e < kElementsPerAccess; e++)
235 | {
236 | accum[k] += float(frag_src.at(e) * frag_input_fwd.at(e));
237 | }
238 | }
239 | }
240 | }
241 |
242 | if (hasProb)
243 | {
244 | for (int k = 0; k < topKTile; k++)
245 | {
246 | if (k == num_topK) break;
247 |
248 | for (int mask = 16; mask > 0; mask /= 2)
249 | {
250 | accum[k] = accum[k] + __shfl_xor_sync(0xffffffff, accum[k], mask, 32);
251 | }
252 | }
253 |
254 | if (tid == 0)
255 | {
256 | for (int k = 0; k < topKTile; k++)
257 | {
258 | if (k == num_topK) break;
259 | prob_grad[source_token * num_topK + k] = accum[k];
260 | }
261 | }
262 | }
263 | }
264 |
265 |
266 | template <typename T, typename TCompute, bool FWD, int kElementsPerAccess>
267 | void moe_permute_topK_kernel_launcher(
268 | const T *input,
269 | T *output,
270 | const int *sorted_row_id,
271 | int *row_id_map,
272 | const float *prob,
273 | const int num_rows,
274 | const int num_topK,
275 | const int num_cols,
276 | const int num_out_tokens,
277 | cudaStream_t stream,
278 | float *prob_grad = nullptr,
279 | const T *input_fwd = nullptr)
280 | {
281 | if (FWD)
282 | {
283 | if (prob_grad == nullptr)
284 | {
285 | // permute_topK fwd
286 | int threads = 64;
287 | int blocks = (num_rows * num_topK + threads - 1) / threads;
288 |         moe_permute_topK_row_map<<<blocks, threads, 0, stream>>>(
289 | sorted_row_id,
290 | row_id_map,
291 | num_rows,
292 | num_topK,
293 | num_out_tokens);
294 |
295 | blocks = num_rows;
296 | threads = std::min(num_cols / kElementsPerAccess, 1024);
297 |         moe_permute_topK_kernel<T, TCompute, kElementsPerAccess, 128, false><<<blocks, threads, 0, stream>>>(
298 | input,
299 | nullptr,
300 | output,
301 | nullptr,
302 | nullptr,
303 | row_id_map,
304 | num_rows,
305 | num_topK,
306 | num_cols);
307 | }
308 | else
309 | {
310 | // unpermute_topK bwd
311 | int blocks = num_rows;
312 | int threads = 32;
313 | size_t smem_bytes = num_topK * sizeof(TCompute);
314 |
315 | if (num_topK == 1)
316 | {
317 |             moe_permute_topK_kernel<T, TCompute, kElementsPerAccess, 1, true><<<blocks, threads, smem_bytes, stream>>>(
318 | input,
319 | input_fwd,
320 | output,
321 | prob,
322 | prob_grad,
323 | row_id_map,
324 | num_rows,
325 | num_topK,
326 | num_cols);
327 | }
328 | else if (num_topK <= 8)
329 | {
330 |             moe_permute_topK_kernel<T, TCompute, kElementsPerAccess, 8, true><<<blocks, threads, smem_bytes, stream>>>(
331 | input,
332 | input_fwd,
333 | output,
334 | prob,
335 | prob_grad,
336 | row_id_map,
337 | num_rows,
338 | num_topK,
339 | num_cols);
340 | }
341 | else if (num_topK <= 16)
342 | {
343 |             moe_permute_topK_kernel<T, TCompute, kElementsPerAccess, 16, true><<<blocks, threads, smem_bytes, stream>>>(
344 | input,
345 | input_fwd,
346 | output,
347 | prob,
348 | prob_grad,
349 | row_id_map,
350 | num_rows,
351 | num_topK,
352 | num_cols);
353 | }
354 | else if (num_topK <= 32)
355 | {
356 |             moe_permute_topK_kernel<T, TCompute, kElementsPerAccess, 32, true><<<blocks, threads, smem_bytes, stream>>>(
357 | input,
358 | input_fwd,
359 | output,
360 | prob,
361 | prob_grad,
362 | row_id_map,
363 | num_rows,
364 | num_topK,
365 | num_cols);
366 | }
367 | else if (num_topK <= 64)
368 | {
369 |             moe_permute_topK_kernel<T, TCompute, kElementsPerAccess, 64, true><<<blocks, threads, smem_bytes, stream>>>(
370 | input,
371 | input_fwd,
372 | output,
373 | prob,
374 | prob_grad,
375 | row_id_map,
376 | num_rows,
377 | num_topK,
378 | num_cols);
379 | }
380 | else if (num_topK <= 128)
381 | {
382 |             moe_permute_topK_kernel<T, TCompute, kElementsPerAccess, 128, true><<<blocks, threads, smem_bytes, stream>>>(
383 | input,
384 | input_fwd,
385 | output,
386 | prob,
387 | prob_grad,
388 | row_id_map,
389 | num_rows,
390 | num_topK,
391 | num_cols);
392 | }
393 | else
394 | {
395 | throw std::runtime_error("num_topK cannot exceed 128.");
396 | }
397 | }
398 | }
399 | else
400 | {
401 | int blocks = num_rows;
402 | int threads = std::min(num_cols / kElementsPerAccess, 1024);
403 | size_t smem_bytes = num_topK * sizeof(TCompute);
404 |
405 |
406 | if (num_topK == 1)
407 | {
408 | // permute_topK bwd with topK==1
409 |             moe_recover_topK_kernel<T, TCompute, kElementsPerAccess, false><<<blocks, threads, smem_bytes, stream>>>(
410 | input,
411 | output,
412 | row_id_map,
413 | prob,
414 | num_rows,
415 | num_topK,
416 | num_cols);
417 | }
418 | else if (prob == nullptr)
419 | {
420 | // permute_topK bwd
421 |             moe_recover_topK_kernel<T, TCompute, kElementsPerAccess, false><<<blocks, threads, smem_bytes, stream>>>(
422 | input,
423 | output,
424 | row_id_map,
425 | prob,
426 | num_rows,
427 | num_topK,
428 | num_cols);
429 | }
430 | else
431 | {
432 | // unpermute_topK fwd
433 |             moe_recover_topK_kernel<T, TCompute, kElementsPerAccess, true><<<blocks, threads, smem_bytes, stream>>>(
434 | input,
435 | output,
436 | row_id_map,
437 | prob,
438 | num_rows,
439 | num_topK,
440 | num_cols);
441 | }
442 | }
443 | }
444 |
445 | /////////////////////////////////////////////////////////////////////////////////////////////////
446 | //
447 | // Permute_topK OP
448 | //
449 | /////////////////////////////////////////////////////////////////////////////////////////////////
450 |
451 | std::tuple<Tensor, Tensor, std::vector<Tensor>> moe_permute_topK_op(
452 | Tensor input,
453 | Tensor indices,
454 | int64_t num_out_tokens,
455 |     std::vector<Tensor> workspace,
456 | int64_t max_expanded_token_num)
457 | {
458 | const int num_tokens = input.size(0);
459 | const int num_cols = input.size(1);
460 | const int num_topK = indices.size(1);
461 |
462 | // initialize the workspace on the first run
463 | if (workspace.empty()) {
464 | auto options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false);
465 |
466 | Tensor sorted_indices = torch::empty(max_expanded_token_num, options);
467 | Tensor row_id = torch::range(0, max_expanded_token_num - 1, 1, options);
468 | Tensor sorted_row_id =
469 | torch::empty(max_expanded_token_num, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false));
470 |
471 | size_t temp_storage_bytes = 0;
472 | int *temp_ptr = nullptr;
473 | cub::DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes,
474 | temp_ptr, temp_ptr,
475 | temp_ptr, temp_ptr, max_expanded_token_num);
476 | Tensor temp_storage =
477 | torch::empty(temp_storage_bytes, torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
478 |
479 | workspace.push_back(sorted_indices);
480 | workspace.push_back(row_id);
481 | workspace.push_back(sorted_row_id);
482 | workspace.push_back(temp_storage);
483 | }
484 |
485 | int *indices_ptr = get_ptr<int>(indices);
486 | int *sorted_indices_ptr = get_ptr<int>(workspace[0]);
487 | int *row_id_ptr = get_ptr<int>(workspace[1]);
488 | int *sorted_row_id_ptr = get_ptr<int>(workspace[2]);
489 |
490 | void *d_temp_storage = get_ptr<int8_t>(workspace[3]);
491 | size_t temp_storage_bytes = std::numeric_limits<size_t>::max();
492 |
493 | cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes,
494 | indices_ptr, sorted_indices_ptr,
495 | row_id_ptr, sorted_row_id_ptr, num_tokens * num_topK);
496 |
497 | // activations type
498 | const at::ScalarType _st = input.scalar_type();
499 |
500 | // Output buffer alloc
501 | num_out_tokens = (num_out_tokens > 0) ? num_out_tokens : num_tokens * num_topK;
502 | Tensor permuted_output =
503 | torch::empty({num_out_tokens, num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false));
504 | Tensor row_id_map =
505 | torch::empty({num_tokens * num_topK}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false));
506 |
507 | int *row_id_map_ptr = get_ptr<int>(row_id_map);
508 | auto stream = at::cuda::getCurrentCUDAStream().stream();
509 |
510 | switch (_st)
511 | {
512 | case at::ScalarType::Float:
513 | {
514 | using dType = float;
515 | using dTypeCompute = float;
516 |
517 | dType *input_ptr = get_ptr<dType>(input);
518 | dType *permuted_output_ptr = get_ptr<dType>(permuted_output);
519 |
520 | moe_permute_topK_kernel_launcher<dType, dTypeCompute>(
521 | input_ptr,
522 | permuted_output_ptr,
523 | sorted_row_id_ptr,
524 | row_id_map_ptr,
525 | nullptr,
526 | num_tokens,
527 | num_topK,
528 | num_cols,
529 | num_out_tokens,
530 | stream);
531 |
532 | break;
533 | }
534 | case at::ScalarType::Half:
535 | {
536 | using dType = cutlass::half_t;
537 | using dTypeCompute = cutlass::half_t;
538 |
539 | dType *input_ptr = get_ptr<dType>(input);
540 | dType *permuted_output_ptr = get_ptr<dType>(permuted_output);
541 |
542 | moe_permute_topK_kernel_launcher<dType, dTypeCompute>(
543 | input_ptr,
544 | permuted_output_ptr,
545 | sorted_row_id_ptr,
546 | row_id_map_ptr,
547 | nullptr,
548 | num_tokens,
549 | num_topK,
550 | num_cols,
551 | num_out_tokens,
552 | stream);
553 |
554 | break;
555 | }
556 | #ifdef ENABLE_BF16
557 | case at::ScalarType::BFloat16:
558 | {
559 | using dType = cutlass::bfloat16_t;
560 | using dTypeCompute = cutlass::bfloat16_t;
561 |
562 | dType *input_ptr = get_ptr<dType>(input);
563 | dType *permuted_output_ptr = get_ptr<dType>(permuted_output);
564 |
565 | moe_permute_topK_kernel_launcher<dType, dTypeCompute>(
566 | input_ptr,
567 | permuted_output_ptr,
568 | sorted_row_id_ptr,
569 | row_id_map_ptr,
570 | nullptr,
571 | num_tokens,
572 | num_topK,
573 | num_cols,
574 | num_out_tokens,
575 | stream);
576 |
577 | break;
578 | }
579 | #endif
580 | #ifdef ENABLE_FP8
581 | case at::ScalarType::Float8_e5m2:
582 | {
583 | using dType = cutlass::float_e5m2_t;
584 | using dTypeCompute = cutlass::half_t;
585 |
586 | dType *input_ptr = get_ptr<dType>(input);
587 | dType *permuted_output_ptr = get_ptr<dType>(permuted_output);
588 |
589 | moe_permute_topK_kernel_launcher<dType, dTypeCompute>(
590 | input_ptr,
591 | permuted_output_ptr,
592 | sorted_row_id_ptr,
593 | row_id_map_ptr,
594 | nullptr,
595 | num_tokens,
596 | num_topK,
597 | num_cols,
598 | num_out_tokens,
599 | stream);
600 |
601 | break;
602 | }
603 | case at::ScalarType::Float8_e4m3fn:
604 | {
605 | using dType = cutlass::float_e4m3_t;
606 | using dTypeCompute = cutlass::half_t;
607 |
608 | dType *input_ptr = get_ptr<dType>(input);
609 | dType *permuted_output_ptr = get_ptr<dType>(permuted_output);
610 |
611 | moe_permute_topK_kernel_launcher<dType, dTypeCompute>(
612 | input_ptr,
613 | permuted_output_ptr,
614 | sorted_row_id_ptr,
615 | row_id_map_ptr,
616 | nullptr,
617 | num_tokens,
618 | num_topK,
619 | num_cols,
620 | num_out_tokens,
621 | stream);
622 |
623 | break;
624 | }
625 | #endif
626 | default:
627 | throw std::runtime_error("Wrong activation tensor type.");
628 | }
629 |
630 | return std::make_tuple(permuted_output, row_id_map, workspace);
631 | }
632 |
633 | /////////////////////////////////////////////////////////////////////////////////////////////////
634 | //
635 | // Unpermute_topK OP
636 | //
637 | /////////////////////////////////////////////////////////////////////////////////////////////////
638 |
639 | Tensor moe_recover_topK_op(
640 | Tensor input,
641 | Tensor row_id_map,
642 | Tensor prob,
643 | int64_t num_tokens,
644 | int64_t num_topK)
645 | {
646 | const int num_cols = input.size(1);
647 |
648 | // activations type
649 | const at::ScalarType _st = input.scalar_type();
650 |
651 | // Output buffer alloc
652 | Tensor unpermuted_output =
653 | torch::empty({num_tokens, num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false));
654 |
655 | int *row_id_map_ptr = get_ptr<int>(row_id_map);
656 | float *prob_ptr = (prob.defined()) ? get_ptr<float>(prob) : nullptr;
657 | auto stream = at::cuda::getCurrentCUDAStream().stream();
658 |
659 | switch (_st)
660 | {
661 | case at::ScalarType::Float:
662 | {
663 | using dType = float;
664 | using dTypeCompute = float;
665 |
666 | dType *input_ptr = get_ptr<dType>(input);
667 | dType *unpermuted_output_ptr = get_ptr<dType>(unpermuted_output);
668 |
669 | moe_permute_topK_kernel_launcher<dType, dTypeCompute>(
670 | input_ptr,
671 | unpermuted_output_ptr,
672 | nullptr,
673 | row_id_map_ptr,
674 | prob_ptr,
675 | num_tokens,
676 | num_topK,
677 | num_cols,
678 | 0,
679 | stream);
680 |
681 | break;
682 | }
683 | case at::ScalarType::Half:
684 | {
685 | using dType = cutlass::half_t;
686 | using dTypeCompute = cutlass::half_t;
687 |
688 | dType *input_ptr = get_ptr<dType>(input);
689 | dType *unpermuted_output_ptr = get_ptr<dType>(unpermuted_output);
690 |
691 | moe_permute_topK_kernel_launcher<dType, dTypeCompute>(
692 | input_ptr,
693 | unpermuted_output_ptr,
694 | nullptr,
695 | row_id_map_ptr,
696 | prob_ptr,
697 | num_tokens,
698 | num_topK,
699 | num_cols,
700 | 0,
701 | stream);
702 |
703 | break;
704 | }
705 | #ifdef ENABLE_BF16
706 | case at::ScalarType::BFloat16:
707 | {
708 | using dType = cutlass::bfloat16_t;
709 | using dTypeCompute = cutlass::bfloat16_t;
710 |
711 | dType *input_ptr = get_ptr<dType>(input);
712 | dType *unpermuted_output_ptr = get_ptr<dType>(unpermuted_output);
713 |
714 | moe_permute_topK_kernel_launcher<dType, dTypeCompute>(
715 | input_ptr,
716 | unpermuted_output_ptr,
717 | nullptr,
718 | row_id_map_ptr,
719 | prob_ptr,
720 | num_tokens,
721 | num_topK,
722 | num_cols,
723 | 0,
724 | stream);
725 |
726 | break;
727 | }
728 | #endif
729 | #ifdef ENABLE_FP8
730 | case at::ScalarType::Float8_e5m2:
731 | {
732 | using dType = cutlass::float_e5m2_t;
733 | using dTypeCompute = cutlass::half_t;
734 |
735 | dType *input_ptr = get_ptr<dType>(input);
736 | dType *unpermuted_output_ptr = get_ptr<dType>(unpermuted_output);
737 |
738 | moe_permute_topK_kernel_launcher<dType, dTypeCompute>(
739 | input_ptr,
740 | unpermuted_output_ptr,
741 | nullptr,
742 | row_id_map_ptr,
743 | prob_ptr,
744 | num_tokens,
745 | num_topK,
746 | num_cols,
747 | 0,
748 | stream);
749 |
750 | break;
751 | }
752 | case at::ScalarType::Float8_e4m3fn:
753 | {
754 | using dType = cutlass::float_e4m3_t;
755 | using dTypeCompute = cutlass::half_t;
756 |
757 | dType *input_ptr = get_ptr<dType>(input);
758 | dType *unpermuted_output_ptr = get_ptr<dType>(unpermuted_output);
759 |
760 | moe_permute_topK_kernel_launcher<dType, dTypeCompute>(
761 | input_ptr,
762 | unpermuted_output_ptr,
763 | nullptr,
764 | row_id_map_ptr,
765 | prob_ptr,
766 | num_tokens,
767 | num_topK,
768 | num_cols,
769 | 0,
770 | stream);
771 |
772 | break;
773 | }
774 | #endif
775 | default:
776 | throw std::runtime_error("Wrong activation tensor type.");
777 | }
778 |
779 | return unpermuted_output;
780 | }
781 |
782 | std::tuple moe_recover_topK_bwd_op(
783 | Tensor input_bwd,
784 | Tensor input_fwd,
785 | Tensor row_id_map,
786 | Tensor prob)
787 | {
788 | const int num_tokens = prob.size(0);
789 | const int num_topK = prob.size(1);
790 | const int num_cols = input_bwd.size(1);
791 |
792 | int *row_id_map_ptr = get_ptr<int>(row_id_map);
793 | float *prob_ptr = get_ptr<float>(prob);
794 |
795 | // activations type
796 | const at::ScalarType _st = input_bwd.scalar_type();
797 |
798 | // Output buffer alloc
799 | Tensor act_grad =
800 | torch::empty({input_fwd.size(0), num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false));
801 | Tensor prob_grad =
802 | torch::empty({num_tokens, num_topK}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false));
803 | float *prob_grad_ptr = get_ptr<float>(prob_grad);
804 |
805 | auto stream = at::cuda::getCurrentCUDAStream().stream();
806 |
807 | switch (_st)
808 | {
809 | case at::ScalarType::Float:
810 | {
811 | using dType = float;
812 | using dTypeCompute = float;
813 |
814 | dType *input_bwd_ptr = get_ptr<dType>(input_bwd);
815 | dType *input_fwd_ptr = get_ptr<dType>(input_fwd);
816 | dType *act_grad_ptr = get_ptr<dType>(act_grad);
817 |
818 | moe_permute_topK_kernel_launcher<dType, dTypeCompute>(
819 | input_bwd_ptr,
820 | act_grad_ptr,
821 | nullptr,
822 | row_id_map_ptr,
823 | prob_ptr,
824 | num_tokens,
825 | num_topK,
826 | num_cols,
827 | 0,
828 | stream,
829 | prob_grad_ptr,
830 | input_fwd_ptr);
831 |
832 | break;
833 | }
834 | case at::ScalarType::Half:
835 | {
836 | using dType = cutlass::half_t;
837 | using dTypeCompute = cutlass::half_t;
838 |
839 | dType *input_bwd_ptr = get_ptr<dType>(input_bwd);
840 | dType *input_fwd_ptr = get_ptr<dType>(input_fwd);
841 | dType *act_grad_ptr = get_ptr<dType>(act_grad);
842 |
843 | moe_permute_topK_kernel_launcher<dType, dTypeCompute>(
844 | input_bwd_ptr,
845 | act_grad_ptr,
846 | nullptr,
847 | row_id_map_ptr,
848 | prob_ptr,
849 | num_tokens,
850 | num_topK,
851 | num_cols,
852 | 0,
853 | stream,
854 | prob_grad_ptr,
855 | input_fwd_ptr);
856 |
857 | break;
858 | }
859 | #ifdef ENABLE_BF16
860 | case at::ScalarType::BFloat16:
861 | {
862 | using dType = cutlass::bfloat16_t;
863 | using dTypeCompute = cutlass::bfloat16_t;
864 |
865 | dType *input_bwd_ptr = get_ptr<dType>(input_bwd);
866 | dType *input_fwd_ptr = get_ptr<dType>(input_fwd);
867 | dType *act_grad_ptr = get_ptr<dType>(act_grad);
868 |
869 | moe_permute_topK_kernel_launcher<dType, dTypeCompute>(
870 | input_bwd_ptr,
871 | act_grad_ptr,
872 | nullptr,
873 | row_id_map_ptr,
874 | prob_ptr,
875 | num_tokens,
876 | num_topK,
877 | num_cols,
878 | 0,
879 | stream,
880 | prob_grad_ptr,
881 | input_fwd_ptr);
882 |
883 | break;
884 | }
885 | #endif
886 | #ifdef ENABLE_FP8
887 | case at::ScalarType::Float8_e5m2:
888 | {
889 | using dType = cutlass::float_e5m2_t;
890 | using dTypeCompute = cutlass::half_t;
891 |
892 | dType *input_bwd_ptr = get_ptr<dType>(input_bwd);
893 | dType *input_fwd_ptr = get_ptr<dType>(input_fwd);
894 | dType *act_grad_ptr = get_ptr<dType>(act_grad);
895 |
896 | moe_permute_topK_kernel_launcher<dType, dTypeCompute>(
897 | input_bwd_ptr,
898 | act_grad_ptr,
899 | nullptr,
900 | row_id_map_ptr,
901 | prob_ptr,
902 | num_tokens,
903 | num_topK,
904 | num_cols,
905 | 0,
906 | stream,
907 | prob_grad_ptr,
908 | input_fwd_ptr);
909 |
910 | break;
911 | }
912 | case at::ScalarType::Float8_e4m3fn:
913 | {
914 | using dType = cutlass::float_e4m3_t;
915 | using dTypeCompute = cutlass::half_t;
916 |
917 | dType *input_bwd_ptr = get_ptr<dType>(input_bwd);
918 | dType *input_fwd_ptr = get_ptr<dType>(input_fwd);
919 | dType *act_grad_ptr = get_ptr<dType>(act_grad);
920 |
921 | moe_permute_topK_kernel_launcher<dType, dTypeCompute>(
922 | input_bwd_ptr,
923 | act_grad_ptr,
924 | nullptr,
925 | row_id_map_ptr,
926 | prob_ptr,
927 | num_tokens,
928 | num_topK,
929 | num_cols,
930 | 0,
931 | stream,
932 | prob_grad_ptr,
933 | input_fwd_ptr);
934 |
935 | break;
936 | }
937 | #endif
938 | default:
939 | throw std::runtime_error("Wrong activation tensor type.");
940 | }
941 |
942 | return std::make_tuple(act_grad, prob_grad);
943 | }
944 |
945 | } // namespace grouped_gemm
946 |
--------------------------------------------------------------------------------
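The dtype dispatch in csrc/permute.cu above ultimately performs the same row gather as a stable argsort of the flattened expert indices, which is also the reference used in grouped_gemm/permute_test.py further below. A minimal sketch of that equivalence through the public ops API, assuming a CUDA build of the extension is installed (shapes are illustrative):

import torch
from grouped_gemm import ops

tokens = torch.randn(4, 64, dtype=torch.float16, device="cuda")
indices = torch.tensor([[0, 2], [1, 3], [2, 0], [3, 1]], dtype=torch.int32, device="cuda")

# Custom op: each token row is replicated top-K times and the copies are grouped by expert id.
permuted, row_id_map = ops.permute(tokens, indices)

# Pure-PyTorch reference: stable argsort of the flattened indices, then gather
# the source row for each expanded row.
sorted_pos = torch.argsort(indices.view(-1), stable=True)
reference = tokens.index_select(0, sorted_pos // indices.size(1))

assert torch.allclose(permuted.float(), reference.float())
--------------------------------------------------------------------------------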
/csrc/permute.h:
--------------------------------------------------------------------------------
1 | /*************************************************************************
2 | * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | *
4 | * See LICENSE for license information.
5 | ************************************************************************/
6 |
7 | #pragma once
8 |
9 | #include <torch/extension.h>
10 |
11 | using torch::Tensor;
12 |
13 | namespace grouped_gemm {
14 |
15 | std::tuple<Tensor, Tensor, std::vector<Tensor>> moe_permute_topK_op(
16 | Tensor input,
17 | Tensor indices,
18 | int64_t num_out_tokens,
19 | std::vector<Tensor> workspace,
20 | int64_t max_expanded_token_num);
21 |
22 | torch::Tensor moe_recover_topK_op(
23 | torch::Tensor input,
24 | torch::Tensor row_id_map,
25 | torch::Tensor prob_opt,
26 | int64_t num_tokens,
27 | int64_t num_topK);
28 |
29 | std::tuple moe_recover_topK_bwd_op(
30 | Tensor input_bwd,
31 | Tensor input_fwd,
32 | Tensor row_id_map,
33 | Tensor prob);
34 |
35 | } // namespace grouped_gemm
--------------------------------------------------------------------------------
/csrc/sinkhorn.cu:
--------------------------------------------------------------------------------
1 | #include "sinkhorn.h"
2 |
3 | #include <cassert>
4 | #include <cstdio>
5 | #include <cuda_runtime.h>
6 |
7 | #include <torch/extension.h>
8 |
9 | namespace grouped_gemm {
10 |
11 | __global__ void sinkhorn_kernel(float *cost, const int rows, const int cols, float tol) {
12 | assert(rows >= cols && cols < blockDim.x);
13 |
14 | extern __shared__ float shared_memory[];
15 | float *shared_d0 = shared_memory; // For d0
16 | float *shared_d1 = (float*)&shared_d0[rows]; // For d1
17 | float *shared_d1_old = (float*)&shared_d1[cols]; // For d1_old
18 | float *abs_diff_sum = (float*)&shared_d1_old[cols]; // For sum of absolute differences
19 |
20 | int idx = blockIdx.x * blockDim.x + threadIdx.x;
21 |
22 | // Exponentiate cost matrix
23 | if (idx < rows * cols) {
24 | for (int flat_idx = idx; flat_idx < rows * cols; flat_idx += blockDim.x)
25 | cost[flat_idx] = expf(cost[flat_idx]);
26 | }
27 |
28 | if (idx >= rows) return;
29 |
30 | // Initialization of the d0, d1, d1_old and abs_diff_sum vectors.
31 | for (int row_idx = idx; row_idx < rows; row_idx += blockDim.x) {
32 | shared_d0[row_idx] = 1.0f;
33 | }
34 |
35 | if (idx < cols) {
36 | shared_d1[idx] = 1.0f;
37 | shared_d1_old[idx] = 1.0f;
38 | abs_diff_sum[idx] = 0.0f;
39 | }
40 | __syncthreads();
41 |
42 | tol = tol * cols; // error mean check --> error sum check
43 | const float eps = 1e-8;
44 | float local_error = 0.0f;
45 | do {
46 | local_error = 0.0f;
47 |
48 | // Update d0.
49 | for (int row_idx = idx; row_idx < rows; row_idx += blockDim.x) {
50 | float sum = 0.0f;
51 | for (int j = 0; j < cols; ++j) {
52 | sum += shared_d1[j] * cost[row_idx * cols + j];
53 | }
54 | // Using __fdividef for fast division
55 | shared_d0[row_idx] = __fdividef(1.0f, (sum + eps) * rows);
56 | }
57 | __syncthreads();
58 |
59 | // Update d1 and calculate absolute differences.
60 | if (idx < cols) {
61 | float sum = 0.0f;
62 | for (int i = 0; i < rows; ++i) {
63 | sum += shared_d0[i] * cost[i * cols + idx];
64 | }
65 | float new_d1 = __fdividef(1.0, (sum + eps) * cols);
66 | abs_diff_sum[idx] = fabsf(new_d1 - shared_d1_old[idx]);
67 | shared_d1[idx] = new_d1;
68 | // Update shared_d1_old for the next iteration
69 | shared_d1_old[idx] = new_d1;
70 | }
71 | __syncthreads();
72 |
73 | // Compute the sum absolute difference error.
74 | for (int i = 0; i < cols; ++i) {
75 | local_error += abs_diff_sum[i];
76 | }
77 |
78 | } while (local_error > tol);
79 |
80 | // Final multiplication.
81 | for (int row_idx = idx; row_idx < rows; row_idx += blockDim.x) {
82 | for (int j = 0; j < cols; ++j) {
83 | cost[row_idx * cols + j] *= shared_d1[j] * shared_d0[row_idx];
84 | }
85 | }
86 | }
87 |
88 | void sinkhorn_launch(float *cost, int rows, int cols, float tol) {
89 | int threadsPerBlock = 1024;
90 | int blocksPerGrid = 1;
91 | // Allocate enough shared memory for d0, d1, d1_old and abs_diff_sum
92 | size_t sharedMemSize = (rows + cols * 2 + cols) * sizeof(float);
93 | sinkhorn_kernel<<>>(cost, rows, cols, tol);
94 | // cudaDeviceSynchronize();
95 | }
96 |
97 | // Wrapper function
98 | torch::Tensor sinkhorn(torch::Tensor cost, const float tol) {
99 | sinkhorn_launch(cost.data_ptr(), cost.size(0), cost.size(1), tol);
100 |
101 | cudaError_t err = cudaGetLastError();
102 | if (err != cudaSuccess) {
103 | printf("CUDA Error: %s\n", cudaGetErrorString(err));
104 | }
105 | return cost;
106 | }
107 |
108 | } // namespace grouped_gemm
109 |
--------------------------------------------------------------------------------
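Note that sinkhorn_launch above writes the balanced matrix back into the cost buffer it receives, so the torch-level op effectively works in place and returns the tensor it was given. A small usage sketch, assuming the extension is built; the shape and tolerance are illustrative:

import torch
from grouped_gemm import ops

logits = torch.randn(4096, 8, device="cuda", dtype=torch.float32)

# The kernel overwrites its input, so clone first if the raw logits are still needed.
routing = ops.sinkhorn_kernel(logits.clone(), tol=1e-4)

# After balancing, each expert column receives roughly the same total mass.
print(routing.sum(dim=0))
--------------------------------------------------------------------------------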
/csrc/sinkhorn.h:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 | namespace grouped_gemm {
4 |
5 | torch::Tensor sinkhorn(torch::Tensor cost, const float tol=0.0001);
6 |
7 | } // namespace grouped_gemm
8 |
--------------------------------------------------------------------------------
/figures/figure1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanshiqing/grouped_gemm/5c1d831ecf91b225abc91689683e7de67fbee7ef/figures/figure1.png
--------------------------------------------------------------------------------
/figures/figure_groupedgemm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanshiqing/grouped_gemm/5c1d831ecf91b225abc91689683e7de67fbee7ef/figures/figure_groupedgemm.png
--------------------------------------------------------------------------------
/figures/figure_permute.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanshiqing/grouped_gemm/5c1d831ecf91b225abc91689683e7de67fbee7ef/figures/figure_permute.png
--------------------------------------------------------------------------------
/figures/figure_unpermute.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanshiqing/grouped_gemm/5c1d831ecf91b225abc91689683e7de67fbee7ef/figures/figure_unpermute.png
--------------------------------------------------------------------------------
/grouped_gemm/__init__.py:
--------------------------------------------------------------------------------
1 | import grouped_gemm.ops
2 | import grouped_gemm.backend
3 |
--------------------------------------------------------------------------------
/grouped_gemm/backend.py:
--------------------------------------------------------------------------------
1 | # NOTE: Torch needs to be imported before the custom
2 | # extensions. Otherwise libc10.so cannot be found.
3 | import torch
4 |
5 | # TODO(tgale): Wrap this in a try-block with better
6 | # error message and instructions for building the
7 | # c++ operations.
8 | import grouped_gemm_backend as backend
9 |
10 |
11 | def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
12 | assert not (trans_a and trans_b)
13 | assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
14 | assert a.ndim == 2, "Expected 2d tensor for 'a'"
15 | assert b.ndim == (2 if trans_a else 3)
16 |
17 | shape = (
18 | (batch_sizes.shape[0], a.shape[1], b.shape[1])
19 | if trans_a else
20 | (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
21 | )
22 | return torch.empty(*shape, device=a.device, dtype=a.dtype)
23 |
24 | def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
25 | if c is None:
26 | c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
27 | backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
28 | return c
29 |
30 | def sinkhorn(cost, tol=0.0001):
31 | return backend.sinkhorn(cost, tol)
32 |
33 | def permute(input, indices, num_out_tokens, workspace, max_expanded_token_num):
34 | return backend.permute(input, indices, num_out_tokens, workspace, max_expanded_token_num)
35 |
36 | def unpermute(input, row_id_map, prob, max_tokens, num_topK):
37 | return backend.unpermute(input, row_id_map, prob, max_tokens, num_topK)
38 |
39 | def unpermute_bwd(input_bwd, input_fwd, row_id_map, prob):
40 | # TODO: @Jiang fix the case in kernel to allow None probs
41 | if prob is None:
42 | prob = torch.ones([input_bwd.size(0), 1], dtype=torch.float32, device=input_bwd.device)
43 | return backend.unpermute_bwd(input_bwd, input_fwd, row_id_map, prob)
44 |
--------------------------------------------------------------------------------
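backend.permute threads a reusable workspace through successive calls: an empty list on the first call makes the op allocate its radix-sort buffers (sized for max_expanded_token_num), and the returned list is handed back on later calls; PermuteMoE_topK in ops.py caches it exactly this way. A minimal sketch of that pattern against the raw backend, with hypothetical sizes:

import torch
from grouped_gemm import backend

hidden, top_k, max_tokens = 64, 2, 1024

acts = torch.randn(8, hidden, device="cuda", dtype=torch.float16)
idx = torch.randint(0, 4, (8, top_k), dtype=torch.int32, device="cuda")

workspace = []  # empty on the first call; the op fills it with its sort buffers
permuted, row_id_map, workspace = backend.permute(acts, idx, 0, workspace, max_tokens * top_k)

# Later calls reuse the same buffers; ops.py re-allocates them only when the
# requested max_expanded_token_num (or the activation dtype) changes.
permuted, row_id_map, workspace = backend.permute(acts, idx, 0, workspace, max_tokens * top_k)
--------------------------------------------------------------------------------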
/grouped_gemm/ops.py:
--------------------------------------------------------------------------------
1 | from grouped_gemm import backend
2 | import torch
3 | import warnings
4 |
5 | from sys import stderr
6 | import torch.cuda.nvtx as nvtx
7 |
8 |
9 | # For debug's convenience
10 | ENABLE_NVTX=False
11 |
12 | class GroupedGemm(torch.autograd.Function):
13 |
14 | @staticmethod
15 | def forward(ctx, a, b, batch_sizes, trans_b):
16 | assert torch.count_nonzero(batch_sizes) != 0, "Input batch_sizes should not be all zeros!"
17 | ctx.save_for_backward(a, b, batch_sizes)
18 | ctx.trans_b = trans_b
19 | return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
20 |
21 | @staticmethod
22 | def backward(ctx, grad):
23 | grad = grad.contiguous()
24 | a, b, batch_sizes = ctx.saved_tensors
25 | trans_b = ctx.trans_b
26 |
27 | agrad = None
28 | if ctx.needs_input_grad[0]:
29 | agrad = backend.gmm(
30 | grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
31 |
32 | bgrad = None
33 | if ctx.needs_input_grad[1]:
34 | lhs, rhs = (grad, a) if trans_b else (a, grad)
35 | bgrad = backend.gmm(
36 | lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
37 | return agrad, bgrad, None, None
38 |
39 |
40 | def gmm(a, b, batch_sizes, trans_b=False):
41 | return GroupedGemm.apply(a, b, batch_sizes, trans_b)
42 |
43 | def sinkhorn_kernel(cost, tol=0.0001):
44 | return backend.sinkhorn(cost, tol)
45 |
46 | ################################################################################################
47 | ##
48 | ## PermuteMoE topK
49 | ##
50 | ################################################################################################
51 |
52 | class PermuteMoE_topK(torch.autograd.Function):
53 |
54 | workspace_fw=None
55 | dtype=None
56 | max_expanded_token_num=0
57 |
58 | @staticmethod
59 | def forward(ctx,
60 | input_act: torch.Tensor,
61 | indices: torch.Tensor,
62 | num_out_tokens: int,
63 | max_token_num: int):
64 | '''
65 | indices: for topK=1, indices is a 1-d tensor of shape [num_tokens];
66 | otherwise, it is a 2-d tensor of shape [num_tokens, topK].
67 | '''
68 | if ENABLE_NVTX:
69 | nvtx.range_push("permute_topK forward")
70 | # Empty input check
71 | if not input_act.numel():
72 | if ENABLE_NVTX:
73 | nvtx.range_pop()
74 | return input_act, None
75 |
76 | # For top1 case, view the indices as 2D tensor to unify the shape for topk>=2 cases.
77 | if indices.dim() == 1:
78 | indices = indices.view(-1, 1)
79 |
80 | # Device check
81 | if input_act.is_cpu:
82 | raise RuntimeError("[Error] The input `input_act` of permute_topK op is on the device: CPU!")
83 | if indices.is_cpu:
84 | warnings.warn("The input `indices` of permute_topK op is on the device: CPU!")
85 | indices = indices.cuda()
86 |
87 | # Shape check
88 | if input_act.size(0) != indices.size(0):
89 | raise RuntimeError(f"[Error] permute_topK op input `indices` shape mismatch! "
90 | f"Expect {input_act.size(0)}, but got {indices.size(0)}.")
91 |
92 | # Data type check
93 | if indices.dtype != torch.int32:
94 | warnings.warn(f"The data type of the input `indices` of permute_topK op is {indices.dtype}! "
95 | "The recommended type is torch.int32.")
96 | indices = indices.to(torch.int32)
97 |
98 | # Contiguous check
99 | if not input_act.is_contiguous():
100 | warnings.warn("The input `input_act` of permute_topK op is discontiguous!")
101 | input_act = input_act.contiguous()
102 | if not indices.is_contiguous():
103 | warnings.warn("The input `indices` of permute_topK op is discontiguous!")
104 | indices = indices.contiguous()
105 |
106 | num_topK = indices.size(1)
107 |
108 | input_max_expanded_token_num = max(max_token_num, input_act.size(0)) * num_topK
109 | if PermuteMoE_topK.max_expanded_token_num < input_max_expanded_token_num:
110 | PermuteMoE_topK.max_expanded_token_num = input_max_expanded_token_num
111 | PermuteMoE_topK.workspace_fw = []
112 |
113 | if PermuteMoE_topK.dtype != input_act.dtype:
114 | PermuteMoE_topK.dtype = input_act.dtype
115 | PermuteMoE_topK.workspace_fw = []
116 |
117 | permuted_act, row_id_map, PermuteMoE_topK.workspace_fw = backend.permute(
118 | input_act,
119 | indices,
120 | num_out_tokens,
121 | PermuteMoE_topK.workspace_fw,
122 | PermuteMoE_topK.max_expanded_token_num)
123 |
124 | ctx.row_id_map = row_id_map
125 | ctx.num_tokens = indices.size(0)
126 | ctx.num_topK = num_topK
127 | if ENABLE_NVTX:
128 | nvtx.range_pop()
129 | return permuted_act, row_id_map
130 |
131 |
132 | @staticmethod
133 | def backward(ctx, permuted_act_grad, _):
134 | if ENABLE_NVTX:
135 | nvtx.range_push("permute_topK backward")
136 | # Empty input check
137 | if not permuted_act_grad.numel():
138 | if ENABLE_NVTX:
139 | nvtx.range_pop()
140 | return permuted_act_grad, None, None, None
141 |
142 | if not permuted_act_grad.is_contiguous():
143 | permuted_act_grad = permuted_act_grad.contiguous()
144 |
145 | row_id_map = ctx.row_id_map
146 | num_tokens = ctx.num_tokens
147 | num_topK = ctx.num_topK
148 |
149 | unpermuted_act_grad = backend.unpermute(
150 | permuted_act_grad,
151 | row_id_map,
152 | torch.tensor([]),
153 | num_tokens,
154 | num_topK)
155 | if ENABLE_NVTX:
156 | nvtx.range_pop()
157 | return unpermuted_act_grad, None, None, None
158 |
159 | ################################################################################################
160 | ##
161 | ## UnpermuteMoE topK
162 | ##
163 | ################################################################################################
164 |
165 | class UnpermuteMoE_topK(torch.autograd.Function):
166 |
167 | @staticmethod
168 | def forward(ctx,
169 | input_act: torch.Tensor,
170 | row_id_map: torch.Tensor,
171 | probs: torch.Tensor = None):
172 | if ENABLE_NVTX:
173 | nvtx.range_push("unpermute_topK forward")
174 | # Empty input check
175 | if not input_act.numel():
176 | ctx.probs = probs
177 | if ENABLE_NVTX:
178 | nvtx.range_pop()
179 | return input_act
180 |
181 | # Device check
182 | if input_act.is_cpu:
183 | raise RuntimeError("[Error] The input `input_act` of unpermute_topK op is on the device: CPU!")
184 | if row_id_map.is_cpu:
185 | warnings.warn("The input `row_id_map` of unpermute_topK op is on the device: CPU!")
186 | row_id_map = row_id_map.cuda()
187 | if probs is not None and probs.is_cpu:
188 | warnings.warn("The input `probs` of unpermute_topK op is on the device: CPU!")
189 | probs = probs.cuda()
190 |
191 | # Shape check
192 | if probs is not None and row_id_map.size(0) != probs.size(0) * probs.size(1):
193 | raise RuntimeError(f"[Error] unpermute_topK op input `probs` shape mismatch! "
194 | f"Expect {row_id_map.size(0)}, but got {probs.size(0) * probs.size(1)}.")
195 |
196 | # Data type check
197 | if row_id_map.dtype != torch.int32:
198 | warnings.warn(f"The data type of the input `row_id_map` of unpermute_topK op is {row_id_map.dtype}! "
199 | "The recommended type is torch.int32.")
200 | row_id_map = row_id_map.to(torch.int32)
201 | if probs is not None and probs.dtype != torch.float32:
202 | warnings.warn(f"The data type of the input `probs` of unpermute_topK op is {probs.dtype}! "
203 | "The recommended type is torch.float32.")
204 | probs = probs.to(torch.float32)
205 |
206 | # Contiguous check
207 | if not input_act.is_contiguous():
208 | warnings.warn("The input `input_act` of unpermute_topK op is discontiguous!")
209 | input_act = input_act.contiguous()
210 | if not row_id_map.is_contiguous():
211 | warnings.warn("The input `row_id_map` of unpermute_topK op is discontiguous!")
212 | row_id_map = row_id_map.contiguous()
213 | if probs is not None and not probs.is_contiguous():
214 | warnings.warn("The input `probs` of unpermute_topK op is discontiguous!")
215 | probs = probs.contiguous()
216 |
217 | num_tokens = probs.size(0) if probs is not None else input_act.size(0)
218 | num_topK = probs.size(1) if probs is not None else 1
219 |
220 | unpermuted_output = backend.unpermute(
221 | input_act,
222 | row_id_map,
223 | probs if probs is not None else torch.tensor([]),
224 | num_tokens,
225 | num_topK)
226 |
227 | ctx.save_for_backward(input_act, row_id_map, probs)
228 | if ENABLE_NVTX:
229 | nvtx.range_pop()
230 | return unpermuted_output
231 |
232 | @staticmethod
233 | def backward(ctx, unpermuted_act_grad):
234 | if ENABLE_NVTX:
235 | nvtx.range_push("unpermute_topK backward")
236 | # Empty input check
237 | if not unpermuted_act_grad.numel():
238 | if ENABLE_NVTX:
239 | nvtx.range_pop()
240 | return unpermuted_act_grad, None, ctx.probs
241 |
242 | if not unpermuted_act_grad.is_contiguous():
243 | unpermuted_act_grad = unpermuted_act_grad.contiguous()
244 |
245 | input_act, row_id_map, probs = ctx.saved_tensors
246 |
247 | act_grad, prob_grad = None, None
248 | if ctx.needs_input_grad[0]:
249 | act_grad, prob_grad = backend.unpermute_bwd(
250 | unpermuted_act_grad,
251 | input_act,
252 | row_id_map,
253 | probs)
254 |
255 | if not ctx.needs_input_grad[2]:
256 | prob_grad = None
257 | if ENABLE_NVTX:
258 | nvtx.range_pop()
259 | return act_grad, None, prob_grad
260 |
261 | def permute(input_act, indices, num_out_tokens=None, max_token_num=0):
262 | num_out_tokens = 0 if num_out_tokens is None else num_out_tokens
263 | return PermuteMoE_topK.apply(input_act, indices, num_out_tokens, max_token_num)
264 |
265 | def unpermute(input_act, row_id_map, probs=None):
266 | return UnpermuteMoE_topK.apply(input_act, row_id_map, probs)
267 |
--------------------------------------------------------------------------------
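Taken together, permute, gmm and unpermute cover the usual MoE dispatch path: group the token copies by expert, run a single grouped GEMM over the per-expert row blocks, then scatter the results back to token order weighted by the router probabilities. A minimal end-to-end sketch, assuming a CUDA build of grouped_gemm; the shapes, routing and weights are illustrative:

import torch
from grouped_gemm import ops

num_tokens, hidden, ffn, num_experts, top_k = 16, 64, 128, 4, 2

x = torch.randn(num_tokens, hidden, device="cuda", dtype=torch.bfloat16)
w = torch.randn(num_experts, hidden, ffn, device="cuda", dtype=torch.bfloat16)
indices = torch.stack(
    [torch.randperm(num_experts, device="cuda")[:top_k] for _ in range(num_tokens)]
).to(torch.int32)
probs = torch.softmax(torch.randn(num_tokens, top_k, device="cuda"), dim=-1)

# 1) Group the (token, expert) copies by expert id.
permuted, row_id_map = ops.permute(x, indices)

# 2) One grouped GEMM over the per-expert row blocks.
tokens_per_expert = torch.bincount(indices.view(-1).long(), minlength=num_experts).cpu()
expert_out = ops.gmm(permuted, w, tokens_per_expert, trans_b=False)

# 3) Scatter rows back to token order and mix the top-K outputs with probs.
y = ops.unpermute(expert_out, row_id_map, probs)  # [num_tokens, ffn]
--------------------------------------------------------------------------------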
/grouped_gemm/ops_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import itertools
3 |
4 | from absl.testing import parameterized
5 | from grouped_gemm import ops
6 | import numpy as np
7 | import torch
8 |
9 |
10 | def allclose(x, y, pct=2.0):
11 | mask = torch.isclose(x, y, rtol=1e-5)
12 | pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
13 | if pct_diff > pct:
14 | print(x[torch.logical_not(mask)], y[torch.logical_not(mask)])
15 | print("{:.2f}% of values not close.".format(pct_diff))
16 | return False
17 | return True
18 |
19 |
20 | def add_transpose_flags(x):
21 | out = []
22 | for y in x:
23 | for f in [(False,), (True,)]:
24 | out.append(y + f)
25 | return out
26 |
27 |
28 | _TEST_PROBLEMS = add_transpose_flags((
29 | (1, 128, 128, 128),
30 | (8, 128, 128, 128),
31 | (16, 128, 128, 128),
32 | (1, 128, 256, 512),
33 | (8, 128, 256, 512),
34 | (16, 128, 256, 512),
35 | ))
36 |
37 |
38 | def randn(bs, x, y):
39 | out = (torch.rand(bs, x, y) - 0.5 * 2) / (y * x)
40 | return out.cuda().to(torch.bfloat16)
41 |
42 |
43 | def gmm(a, b, batch_sizes, trans_b=False):
44 | batch_sizes = batch_sizes.numpy()
45 |
46 | out = []
47 | start = 0
48 | for i, size in enumerate(batch_sizes):
49 | rhs = b[i, :, :].t() if trans_b else b[i, :, :]
50 | out.append(a[start:start + size, :] @ rhs)
51 | start += size
52 | return torch.cat(out)
53 |
54 |
55 | @parameterized.parameters(*_TEST_PROBLEMS)
56 | class OpsTest(parameterized.TestCase):
57 |
58 | def testGroupedGemm_FixedSizes(self, z, m, k, n, trans_b):
59 | torch.manual_seed(0)
60 | a = randn(z, m, k).view(-1, k)
61 | b = randn(z, n, k) if trans_b else randn(z, k, n)
62 | batch_sizes = torch.tensor([m] * z)
63 |
64 | a.requires_grad_(True)
65 | b.requires_grad_(True)
66 | a_ref = a.detach().clone().requires_grad_(True)
67 | b_ref = b.detach().clone().requires_grad_(True)
68 |
69 | out = ops.gmm(a, b, batch_sizes, trans_b)
70 | expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b)
71 | self.assertTrue(allclose(out, expected_out))
72 |
73 | # Check gradients.
74 | out.sum().backward()
75 | expected_out.sum().backward()
76 | self.assertTrue(allclose(a.grad, a_ref.grad))
77 | self.assertTrue(allclose(b.grad, b_ref.grad))
78 |
79 | def testGroupedGemm_VariableSizes(self, z, m, k, n, trans_b):
80 | torch.manual_seed(0)
81 | a = randn(z, m, k).view(-1, k)
82 | b = randn(z, n, k) if trans_b else randn(z, k, n)
83 |
84 | dist = torch.rand(z, )
85 | dist /= dist.sum()
86 | batch_sizes = (dist * m).to(torch.long)
87 | error = m * z - batch_sizes.sum()
88 | batch_sizes[-1] += error
89 | assert batch_sizes.sum() == (m * z)
90 |
91 | a.requires_grad_(True)
92 | b.requires_grad_(True)
93 | a_ref = a.detach().clone().requires_grad_(True)
94 | b_ref = b.detach().clone().requires_grad_(True)
95 |
96 | out = ops.gmm(a, b, batch_sizes, trans_b)
97 | expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b)
98 | self.assertTrue(allclose(out, expected_out))
99 |
100 | # Check gradients.
101 | out.sum().backward()
102 | expected_out.sum().backward()
103 | self.assertTrue(allclose(a.grad, a_ref.grad))
104 | self.assertTrue(allclose(b.grad, b_ref.grad))
105 |
106 |
107 |
108 | if __name__ == '__main__':
109 | unittest.main()
110 |
--------------------------------------------------------------------------------
/grouped_gemm/permute_test.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # See LICENSE for license information.
4 |
5 | import torch
6 | import triton
7 | import torch.cuda.nvtx as nvtx
8 |
9 | try:
10 | from grouped_gemm.ops import permute as permute_topK, unpermute as unpermute_topK
11 | except ImportError:
12 | print("grouped-gemm toolkit is not installed. Fall back to local import.")
13 | # For local debug
14 | import sys
15 | sys.path.append("..")
16 | from ops import permute as permute_topK, unpermute as unpermute_topK
17 |
18 | def permute(tokens, indices, expand_factor: int = 1, is_fp8=False):
19 | """Permute the tokens based on the indices.
20 |
21 | Args:
22 | tokens (torch.Tensor): The input token tensor.
23 | indices (torch.Tensor): The token2expert indices tensor.
24 |
25 | Returns:
26 | torch.Tensor: The permuted tensor.
27 | """
28 | expand_factor = indices.size(1)
29 |
30 | flatten_indices = indices.view(-1)
31 | sorted_indices = torch.argsort(flatten_indices, stable=True)
32 | permuted_tokens = tokens.index_select(0, sorted_indices // expand_factor)
33 | return permuted_tokens, sorted_indices
34 |
35 |
36 | def unpermute(permuted_tokens, sorted_indices, probs: torch.Tensor = None, merge_factor: int = 1):
37 | """Unpermute the sorted tokens based on the indices.
38 |
39 | Args:
40 | permuted_tokens (torch.Tensor): The permuted token tensor.
41 | sorted_indices (torch.Tensor): The sorted indices tensor.
42 | probs (torch.Tensor, optional): The probabilities tensor. Defaults to None.
43 | merge_factor (int, optional): The merge factor. Defaults to 1.
44 |
45 | Returns:
46 | torch.Tensor: The unpermuted tensor.
47 | """
48 | merge_factor = probs.size(1)
49 |
50 | if merge_factor > 1:
51 | assert probs is not None
52 | assert (
53 | probs.size(0) == permuted_tokens.size(0) // merge_factor
54 | ), f"{probs.size()} {permuted_tokens.size()}"
55 | if probs is not None:
56 | assert probs.size(0) == permuted_tokens.size(0) // merge_factor
57 | assert (
58 | probs.size(1) == merge_factor
59 | ), f"probs size {probs.size()} merge_factor {merge_factor}"
60 |
61 | # unpermuted_tokens = torch.zeros_like(permuted_tokens)
62 | unpermuted_tokens = permuted_tokens.index_copy(0, sorted_indices, permuted_tokens)
63 |
64 | unpermuted_tokens = unpermuted_tokens.reshape(-1, merge_factor, permuted_tokens.size(-1))
65 |
66 | if probs is not None:
67 | dtype = unpermuted_tokens.dtype
68 | unpermuted_tokens = unpermuted_tokens * probs.unsqueeze(-1)
69 | unpermuted_tokens = unpermuted_tokens.to(dtype)
70 | unpermuted_tokens = unpermuted_tokens.sum(dim=1)
71 |
72 | return unpermuted_tokens
73 |
74 | def permute_topK_test(
75 | dtype,
76 | num_token,
77 | num_expert,
78 | hidden_size,
79 | num_topK,
80 | PRINT,
81 | BENCHMARK):
82 |
83 | print(f"{dtype} token:{num_token} hidden_size:{hidden_size} expert:{num_expert} topK:{num_topK}")
84 |
85 | is_fp8 = dtype in [torch.float8_e5m2, torch.float8_e4m3fn]
86 |
87 | permute_input = torch.rand((num_token, hidden_size), dtype=torch.float32).cuda()
88 | # for i in range(num_token):
89 | # for j in range(hidden_size):
90 | # permute_input[i][j] = i * 100 + j
91 | permute_input = permute_input.to(dtype)
92 | if is_fp8:
93 | permute_input = permute_input.half()
94 |
95 | permute_input.requires_grad_(True)
96 |
97 | if num_token > 0:
98 | indices = torch.stack([torch.randperm(num_expert)[:num_topK] for _ in range(num_token)])
99 | else:
100 | indices = torch.empty((num_token, num_topK))
101 | indices = indices.to(torch.int32).cuda()
102 |
103 | # probs = torch.tensor([[0.1, 0.9],
104 | # [0.2, 0.8],
105 | # [0.3, 0.7]])
106 | # 0.5
107 | # probs = torch.ones_like(indices) / 2
108 | # rand
109 | probs = torch.rand(num_token, num_topK).cuda()
110 | row_sums = probs.sum(dim=1, keepdim=True)
111 | probs = probs / row_sums
112 | probs.requires_grad_(True)
113 |
114 | if PRINT:
115 | print(permute_input)
116 | print(indices)
117 | print(probs)
118 |
119 | ###################################################################################################################################
120 | #
121 | # PyTorch
122 | #
123 | ###################################################################################################################################
124 | nvtx.range_push("PyTorch permute forward")
125 | permute_output, sorted_indices = permute(permute_input, indices, num_topK, is_fp8)
126 | nvtx.range_pop()
127 |
128 | permute_bwd_input = torch.rand_like(permute_output)
129 | # for i in range(num_token * num_topK):
130 | # for j in range(hidden_size):
131 | # permute_bwd_input[i][j] = i * 100 + j
132 |
133 | nvtx.range_push("PyTorch permute backward")
134 | permute_output.backward(permute_bwd_input, retain_graph=True)
135 | nvtx.range_pop()
136 |
137 | unpermute_input = permute_output.detach()
138 | unpermute_input.requires_grad_(True)
139 |
140 | unpermute_output = unpermute(
141 | unpermute_input, sorted_indices, probs=probs, merge_factor=num_topK)
142 |
143 | if PRINT:
144 | print("--------------unpermute fwd permute_input--------------")
145 | print(unpermute_input)
146 | print("--------------unpermute fwd output--------------")
147 | print(unpermute_output)
148 |
149 | unpermute_bwd_input = torch.rand_like(unpermute_output)
150 | # for i in range(num_token):
151 | # for j in range(hidden_size):
152 | # unpermute_bwd_input[i][j] = i * 2000 + j * 20
153 |
154 | if PRINT:
155 | print("--------------unpermute bwd permute_input--------------")
156 | print(unpermute_bwd_input)
157 |
158 | unpermute_output.backward(unpermute_bwd_input, retain_graph=True)
159 | if PRINT:
160 | print("--------------unpermute bwd output act grad--------------")
161 | print(permute_output.grad)
162 | print("--------------unpermute bwd output probs grad--------------")
163 | print(probs.grad)
164 |
165 | ###################################################################################################################################
166 | #
167 | # Mine
168 | #
169 | ###################################################################################################################################
170 | new_permute_input = permute_input.detach().to(dtype)
171 | new_permute_bwd_input = permute_bwd_input.detach().to(dtype)
172 | new_unpermute_bwd_input = unpermute_bwd_input.detach().to(dtype)
173 | new_permute_input.requires_grad_(True)
174 |
175 | new_permute_output, row_id_map = permute_topK(new_permute_input, indices)
176 |
177 | assert torch.allclose(permute_output.float(), new_permute_output.float())
178 |
179 | if PRINT:
180 | print("--------------row_id_map--------------")
181 | print(row_id_map)
182 | print("--------------new_permute_input--------------")
183 | print(new_permute_input)
184 | print("--------------new_permute_output--------------")
185 | print(new_permute_output)
186 |
187 | new_permute_output.backward(new_permute_bwd_input, retain_graph=True)
188 |
189 | if torch.allclose(permute_input.grad.float(), new_permute_input.grad.float()) == False:
190 | original_inputs = new_permute_input.grad.float().cpu().numpy().flatten()
191 | original_output = permute_input.grad.float().cpu().numpy().flatten()
192 | max_abs_error = abs(original_inputs - original_output).max()
193 | print(f"permute_topK bwd max error (mine vs pytorch): \t\t\t{max_abs_error:.3e} ({dtype})")
194 |
195 | if PRINT:
196 | print(permute_input.grad)
197 | print(new_permute_input.grad)
198 |
199 | new_probs = probs.detach()
200 | new_probs.requires_grad_(True)
201 | new_unpermute_input = new_permute_output.detach()
202 | new_unpermute_input.requires_grad_(True)
203 |
204 | print("new_probs=", new_probs)
205 | new_unpermute_output = unpermute_topK(new_unpermute_input, row_id_map, new_probs)
206 |
207 | if torch.allclose(unpermute_output.float(), new_unpermute_output.float()) == False:
208 | original_inputs = unpermute_output.float().cpu().detach().numpy().flatten()
209 | original_output = new_unpermute_output.float().cpu().detach().numpy().flatten()
210 | max_abs_error = abs(original_inputs - original_output).max()
211 | print(f"unpermute_topK fwd max error (mine vs pytorch): \t\t{max_abs_error:.3e} ({dtype})")
212 |
213 | if PRINT:
214 | print(unpermute_output)
215 | print(new_unpermute_output)
216 |
217 | new_unpermute_output.backward(new_unpermute_bwd_input, retain_graph=True)
218 |
219 | if torch.allclose(unpermute_input.grad.float(), new_unpermute_input.grad.float()) == False:
220 | original_inputs = unpermute_input.grad.float().cpu().detach().numpy().flatten()
221 | original_output = new_unpermute_input.grad.float().cpu().detach().numpy().flatten()
222 | max_abs_error = abs(original_inputs - original_output).max()
223 | print(f"unpermute_topK bwd act_grad max error (mine vs pytorch): \t{max_abs_error:.3e} ({dtype})")
224 | if PRINT:
225 | print(new_unpermute_input.grad)
226 | print(unpermute_input.grad)
227 |
228 | if num_topK > 1 and torch.allclose(new_probs.grad, probs.grad) == False:
229 | original_inputs = new_probs.grad.float().cpu().detach().numpy().flatten()
230 | original_output = probs.grad.float().cpu().detach().numpy().flatten()
231 | max_abs_error = abs(original_inputs - original_output).max()
232 | print(f"unpermute_topK bwd prob_grad max error (mine vs pytorch): \t{max_abs_error:.3e} ({dtype})")
233 | if PRINT:
234 | print(new_probs.grad)
235 | print(probs.grad)
236 |
237 | if not permute_input.numel():
238 | print("Empty permute_input activation test passed.")
239 | return
240 |
241 | ###################################################################################################################################
242 | #
243 | # Benchmark
244 | #
245 | ###################################################################################################################################
246 | def backward_wrapper(act, backward_input, forward_input=[], retain_graph=True, accumulate_grad=False):
247 | # Set forward_input.grad to None to avoid grad accumulation.
248 | if accumulate_grad == False:
249 | for i in forward_input:
250 | i.grad = None
251 | return act.backward(backward_input, retain_graph=retain_graph)
252 |
253 | if BENCHMARK:
254 | print(f"----permute topK----")
255 | t = perf_test_cuda_kernel(lambda: permute(permute_input, indices, 2))
256 | print(f"pytorch fwd: {t:.3f} ms")
257 | t = perf_test_cuda_kernel(lambda: permute_topK(new_permute_input, indices))
258 | print(f"new fwd: {t:.3f} ms")
259 |
260 | t = perf_test_cuda_kernel(
261 | lambda: backward_wrapper(permute_output, permute_bwd_input, forward_input=[permute_input], retain_graph=True, accumulate_grad=False))
262 | print(f"pytorch bwd: {t:.3f} ms")
263 | t = perf_test_cuda_kernel(
264 | lambda: backward_wrapper(new_permute_output, new_permute_bwd_input, forward_input=[new_permute_input], retain_graph=True, accumulate_grad=False))
265 | print(f"new bwd: {t:.3f} ms")
266 |
267 | print(f"----unpermute topK----")
268 | t = perf_test_cuda_kernel(
269 | lambda: unpermute(unpermute_input, sorted_indices, probs=probs, merge_factor=num_topK))
270 | print(f"pytorch fwd: {t:.3f} ms")
271 | t = perf_test_cuda_kernel(
272 | lambda: unpermute_topK(new_unpermute_input, row_id_map, new_probs))
273 | print(f"new fwd: {t:.3f} ms")
274 |
275 | t = perf_test_cuda_kernel(
276 | lambda: backward_wrapper(unpermute_output, unpermute_bwd_input, forward_input=[unpermute_input, probs], retain_graph=True, accumulate_grad=False))
277 | print(f"pytorch bwd: {t:.3f} ms")
278 | t = perf_test_cuda_kernel(
279 | lambda: backward_wrapper(new_unpermute_output, new_unpermute_bwd_input, forward_input=[new_unpermute_input, new_probs], retain_graph=True, accumulate_grad=False))
280 | print(f"new bwd: {t:.3f} ms")
281 |
282 |
283 | def perf_test_cuda_kernel(cuda_kernel_fn):
284 | if torch.cuda.is_available():
285 | # create CUDA event
286 | start_event = torch.cuda.Event(enable_timing=True)
287 | end_event = torch.cuda.Event(enable_timing=True)
288 |
289 | # warmup
290 | for _ in range(50):
291 | cuda_kernel_fn()
292 |
293 | start_event.record()
294 | for _ in range(100):
295 | cuda_kernel_fn()
296 | end_event.record()
297 | torch.cuda.synchronize()
298 |
299 | elapsed_time_ms = start_event.elapsed_time(end_event)
300 | # print(f"Elapsed Time: {elapsed_time_ms / 100} ms")
301 | return elapsed_time_ms / 100
302 | else:
303 | print("CUDA is not available.")
304 |
305 | def test_permute_topK():
306 |
307 | torch.manual_seed(1)
308 |
309 | num_token = 4096 * 2
310 | num_expert = 8
311 | hidden_size = 4096
312 | num_topK = 2
313 |
314 | PRINT=False
315 | Benchmark = False
316 | print("GPU:", torch.cuda.get_device_name(0))
317 |
318 | dtype = torch.float32
319 | permute_topK_test(dtype, num_token, num_expert,
320 | hidden_size, num_topK, PRINT, Benchmark)
321 | dtype = torch.float16
322 | permute_topK_test(dtype, num_token, num_expert,
323 | hidden_size, num_topK, False, Benchmark)
324 | dtype = torch.bfloat16
325 | permute_topK_test(dtype, num_token, num_expert,
326 | hidden_size, num_topK, False, Benchmark)
327 | dtype = torch.float8_e5m2
328 | permute_topK_test(dtype, num_token, num_expert,
329 | hidden_size, num_topK, False, Benchmark)
330 | dtype = torch.float8_e4m3fn
331 | permute_topK_test(dtype, num_token, num_expert,
332 | hidden_size, num_topK, False, Benchmark)
333 | dtype = torch.bfloat16
334 | permute_topK_test(dtype, num_token, 4, hidden_size, 1, False, Benchmark)
335 | permute_topK_test(dtype, num_token, 5, hidden_size, 2, False, Benchmark)
336 | permute_topK_test(dtype, num_token, 6, hidden_size, 3, False, Benchmark)
337 | permute_topK_test(dtype, num_token, 7, hidden_size, 4, False, Benchmark)
338 | permute_topK_test(dtype, num_token, 8, hidden_size, 5, False, Benchmark)
339 | num_token = 0
340 | permute_topK_test(dtype, num_token, 8, hidden_size, 5, False, Benchmark)
341 |
342 | if __name__ == "__main__":
343 | test_permute_topK()
--------------------------------------------------------------------------------
/grouped_gemm/sinkhorn_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import itertools
3 |
4 | from absl.testing import parameterized
5 | from grouped_gemm import ops, ops_test
6 | import numpy as np
7 | import torch
8 |
9 | # Notes: the source of this func implementation:
10 | # https://github.com/NVIDIA/Megatron-LM/blob/2bc6cd307a11423928c675f741e79e03df23e721/megatron/core/transformer/switch_mlp.py#L17-L31
11 | def baseline_sinkhorn(cost, tol=0.0001):
12 | "Sinkhorn based MoE routing function"
13 | cost = torch.exp(cost)
14 | d0 = torch.ones(cost.size(0), device=cost.device, dtype=cost.dtype)
15 | d1 = torch.ones(cost.size(1), device=cost.device, dtype=cost.dtype)
16 |
17 | eps = 0.00000001
18 | error = 1e9
19 | d1_old = d1
20 | iter_count = 0
21 | while error > tol:
22 | d0 = (1 / d0.size(0)) * 1 / (torch.sum(d1 * cost, 1) + eps)
23 | d1 = (1 / d1.size(0)) * 1 / (torch.sum(d0.unsqueeze(1) * cost, 0) + eps)
24 | error = torch.mean(torch.abs(d1_old - d1))
25 | d1_old = d1
26 | iter_count = iter_count + 1
27 | return d1 * cost * d0.unsqueeze(1)
28 |
29 | _TEST_PROBLEMS = (
30 | # (128, 2, 0.1),
31 | # (256, 2, 0.1),
32 | # (1024, 4, 0.01),
33 | # (2048, 8, 0.0001),
34 | (4096, 8, 0.0001),
35 | (8192, 8, 0.0001),
36 | (8192, 16, 0.0001),
37 | )
38 |
39 | @parameterized.parameters(*_TEST_PROBLEMS)
40 | class OpsTest(parameterized.TestCase):
41 |
42 | def test_sinkhorn_kernel(self, m, k, tol=0.0001):
43 | start = torch.cuda.Event(enable_timing=True)
44 | end = torch.cuda.Event(enable_timing=True)
45 |
46 | torch.manual_seed(0)
47 | cost = torch.rand(m, k, device='cuda', dtype=torch.float32)
48 |
49 | start.record()
50 | expected_out = baseline_sinkhorn(cost, tol)
51 | end.record()
52 | torch.cuda.synchronize()
53 | baseline_time = start.elapsed_time(end)
54 |
55 | start.record()
56 | out = ops.sinkhorn_kernel(cost, tol)
57 | end.record()
58 | torch.cuda.synchronize()
59 | kernel_time = start.elapsed_time(end)
60 | print("===================================")
61 | print("Problem size: [%d]x[%d], kernel speedup: %fX" % (m, k, baseline_time/kernel_time))
62 | print("===================================")
63 | self.assertTrue(ops_test.allclose(out, expected_out))
64 |
65 | if __name__ == '__main__':
66 | unittest.main()
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | from setuptools import setup, find_packages
4 | import torch
5 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
6 |
7 |
8 | # Supported NVIDIA GPU architectures.
9 | NVIDIA_SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"}
10 |
11 | # TORCH_CUDA_ARCH_LIST can have one or more architectures,
12 | # e.g. "9.0" or "7.0 7.2 7.5 8.0 8.6 8.7 9.0+PTX". Here, the
13 | # "9.0+PTX" option asks the compiler to additionally include PTX
14 | # code that can be runtime-compiled and executed on 9.0 or newer
15 | # architectures. While the PTX code will not give the best
16 | # performance on those newer architectures, it provides forward
17 | # compatibility.
18 | env_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
19 | if env_arch_list:
20 | # Let the PyTorch build system pick the target devices from the env list.
21 | device_capability = ""
22 | else:
23 | device_capability = torch.cuda.get_device_capability()
24 | device_capability = f"{device_capability[0]}{device_capability[1]}"
25 |
26 | cwd = Path(os.path.dirname(os.path.abspath(__file__)))
27 |
28 | nvcc_flags = [
29 | "-std=c++17", # NOTE: CUTLASS requires c++17
30 | "-DENABLE_BF16", # Enable BF16 for cuda_version >= 11
31 | # "-DENABLE_FP8", # Enable FP8 for cuda_version >= 11.8
32 | ]
33 |
34 | if device_capability:
35 | nvcc_flags.extend([
36 | f"--generate-code=arch=compute_{device_capability},code=sm_{device_capability}",
37 | f"-DGROUPED_GEMM_DEVICE_CAPABILITY={device_capability}",
38 | ])
39 |
40 | ext_modules = [
41 | CUDAExtension(
42 | "grouped_gemm_backend",
43 | ["csrc/ops.cu", "csrc/grouped_gemm.cu", "csrc/sinkhorn.cu", "csrc/permute.cu"],
44 | include_dirs = [
45 | f"{cwd}/third_party/cutlass/include/"
46 | ],
47 | extra_compile_args={
48 | "cxx": [
49 | "-fopenmp", "-fPIC", "-Wno-strict-aliasing"
50 | ],
51 | "nvcc": nvcc_flags,
52 | }
53 | )
54 | ]
55 |
56 | setup(
57 | name="grouped_gemm",
58 | version="1.1.4",
59 | author="Trevor Gale, Jiang Shao, Shiqing Fan",
60 | author_email="tgale@stanford.edu, jiangs@nvidia.com, shiqingf@nvidia.com",
61 | description="GEMM Grouped",
62 | url="https://github.com/fanshiqing/grouped_gemm",
63 | classifiers=[
64 | "Programming Language :: Python :: 3",
65 | "License :: OSI Approved :: BSD License",
66 | "Operating System :: Unix",
67 | ],
68 | packages=find_packages(),
69 | ext_modules=ext_modules,
70 | cmdclass={"build_ext": BuildExtension},
71 | install_requires=["absl-py", "numpy", "torch"],
72 | )
73 |
--------------------------------------------------------------------------------
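When TORCH_CUDA_ARCH_LIST is unset, setup.py probes the GPU visible at build time and emits a single --generate-code flag plus the GROUPED_GEMM_DEVICE_CAPABILITY define for that capability. A few lines of Python mirroring that logic can preview the flags (illustrative only):

import torch

major, minor = torch.cuda.get_device_capability()
cap = f"{major}{minor}"
print(f"--generate-code=arch=compute_{cap},code=sm_{cap}")
print(f"-DGROUPED_GEMM_DEVICE_CAPABILITY={cap}")
--------------------------------------------------------------------------------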