├── .github └── workflows │ ├── build-kernels.yml │ └── release.yml ├── .gitignore ├── .gitmodules ├── LICENSE ├── MANIFEST.in ├── README.md ├── SECURITY.md ├── build.sh ├── builder ├── __init__.py ├── builder.py ├── ft_gemm.py └── inf_flash_attn.py ├── dskernels ├── __init__.py ├── ft_gemm │ └── gemm_variants │ │ ├── CMakeLists.txt │ │ ├── LICENSE │ │ ├── build_ft_kernels.sh │ │ ├── cutlass_extensions │ │ ├── arch │ │ │ └── mma.h │ │ ├── compute_occupancy.h │ │ ├── epilogue │ │ │ ├── epilogue_quant_helper.h │ │ │ ├── thread │ │ │ │ └── ft_fused_activations.h │ │ │ └── threadblock │ │ │ │ ├── epilogue_per_row_per_col_scale.h │ │ │ │ └── epilogue_tensor_op_int32.h │ │ ├── epilogue_helpers.h │ │ ├── ft_gemm_configs.h │ │ ├── gemm │ │ │ ├── kernel │ │ │ │ ├── default_fpA_intB_traits.h │ │ │ │ ├── fpA_intB_gemm.h │ │ │ │ ├── gemm_moe_problem_visitor.h │ │ │ │ ├── gemm_with_epilogue_visitor.h │ │ │ │ ├── mixed_gemm_B_layout.h │ │ │ │ ├── moe_cutlass_kernel.h │ │ │ │ └── moe_problem_visitor.h │ │ │ ├── threadblock │ │ │ │ ├── default_dq_mma.h │ │ │ │ ├── default_dq_mma_multistage.h │ │ │ │ ├── default_dq_mma_pipelined.h │ │ │ │ ├── default_mma.h │ │ │ │ ├── default_mma_bf16.h │ │ │ │ ├── dq_mma_base.h │ │ │ │ ├── dq_mma_multistage.h │ │ │ │ └── dq_mma_pipelined.h │ │ │ └── warp │ │ │ │ ├── default_mma_tensor_op.h │ │ │ │ ├── mma_tensorop_compute_B_with_f16.h │ │ │ │ └── mma_tensorop_dequantizer.h │ │ ├── interleaved_numeric_conversion.h │ │ └── tile_interleaved_layout.h │ │ ├── fpA_intB_gemm │ │ ├── fpA_intB_gemm.h │ │ ├── fpA_intB_gemm_bf16_uint4.cu │ │ ├── fpA_intB_gemm_bf16_uint8.cu │ │ ├── fpA_intB_gemm_fp16_int4.cu │ │ ├── fpA_intB_gemm_fp16_int8.cu │ │ └── fpA_intB_gemm_template.h │ │ ├── moe_gemm │ │ ├── moe_gemm_kernels.h │ │ ├── moe_gemm_kernels_bf16_bf16.cu │ │ ├── moe_gemm_kernels_bf16_uint4.cu │ │ ├── moe_gemm_kernels_bf16_uint8.cu │ │ ├── moe_gemm_kernels_fp16_fp16.cu │ │ ├── moe_gemm_kernels_fp16_uint4.cu │ │ ├── moe_gemm_kernels_fp16_uint8.cu │ │ └── moe_gemm_kernels_template.h │ │ └── utils │ │ ├── INIReader.h │ │ ├── activation_type.h │ │ ├── cuda_utils.h │ │ ├── cutlass_heuristic.cc │ │ ├── cutlass_heuristic.h │ │ ├── cutlass_preprocessors.cc │ │ ├── cutlass_preprocessors.h │ │ ├── string_utils.h │ │ └── weight_variant.h └── inf_flash_attn │ └── blocked_flash │ ├── CMakeLists.txt │ ├── LICENSE │ ├── Makefile │ ├── attention_atom.h │ ├── build_blocked_flash.sh │ ├── flash.h │ ├── flash_api.cu │ ├── flash_fwd_hdim128_bf16_sm80.cu │ ├── flash_fwd_hdim128_fp16_sm80.cu │ ├── flash_fwd_hdim160_bf16_sm80.cu │ ├── flash_fwd_hdim160_fp16_sm80.cu │ ├── flash_fwd_hdim192_bf16_sm80.cu │ ├── flash_fwd_hdim192_fp16_sm80.cu │ ├── flash_fwd_hdim224_bf16_sm80.cu │ ├── flash_fwd_hdim224_fp16_sm80.cu │ ├── flash_fwd_hdim256_bf16_sm80.cu │ ├── flash_fwd_hdim256_fp16_sm80.cu │ ├── flash_fwd_hdim32_bf16_sm80.cu │ ├── flash_fwd_hdim32_fp16_sm80.cu │ ├── flash_fwd_hdim64_bf16_sm80.cu │ ├── flash_fwd_hdim64_fp16_sm80.cu │ ├── flash_fwd_hdim96_bf16_sm80.cu │ ├── flash_fwd_hdim96_fp16_sm80.cu │ ├── flash_fwd_kernel.h │ ├── flash_fwd_launch_template.h │ ├── kernel_traits.h │ ├── softmax.h │ ├── static_switch.h │ └── utils.h ├── fetch.sh ├── pyproject.toml ├── release ├── bump_patch_version.py ├── check_release_version.py └── release.sh ├── requirements ├── requirements-dev.txt └── requirements.txt ├── setup.py └── version.txt /.github/workflows/build-kernels.yml: -------------------------------------------------------------------------------- 1 | name: Build Kernels 2 | 3 | on: 4 | pull_request: 5 | 6 | 
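# The job below builds a single wheel and renames its tag from the CPython-specific
# form (e.g. cp310-cp310, linux) to py3-none / manylinux1. The compiled .so binaries
# are torch- and python-agnostic (the bindings are JIT-compiled later by DeepSpeed's
# op builder, see README.md), so one wheel can serve all supported Python versions.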
jobs: 7 | build: 8 | 9 | runs-on: [self-hosted, cpu] 10 | 11 | steps: 12 | - uses: actions/checkout@v4 13 | with: 14 | submodules: 'recursive' 15 | 16 | - name: environment 17 | run: | 18 | python --version 19 | nvcc --version 20 | python -c "import torch; print('torch:', torch.__version__, torch)" 21 | python -c "import torch; print('CUDA available:', torch.cuda.is_available())" 22 | 23 | - name: build kernels 24 | run: | 25 | pip install build 26 | ts=$(date +%s) 27 | DS_KERNELS_MAKE_JOBS=10 DS_KERNELS_BUILD_STRING=".dev${ts}" CUDA_ARCH_LIST="80;86;89;90" python -m build --wheel 28 | fname=$(ls dist) 29 | nname=$(echo $fname | sed 's/cp[0-9]\+-cp[0-9]\+/py3-none/' | sed 's/linux/manylinux1/') 30 | mv "dist/$fname" "dist/$nname" 31 | ls -al 32 | 33 | - uses: actions/upload-artifact@v3 34 | with: 35 | name: deepspeed-kernels-whl 36 | path: dist/*.whl 37 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Build and publish DeepSpeed-Kernels release 2 | 3 | on: 4 | push: 5 | tags: 6 | - 'v*.*.*' 7 | 8 | jobs: 9 | deploy: 10 | runs-on: [self-hosted, cpu] 11 | environment: release-env 12 | 13 | steps: 14 | - uses: actions/checkout@v4 15 | with: 16 | submodules: 'recursive' 17 | 18 | - name: Check environment 19 | run: | 20 | which python 21 | python --version 22 | which nvcc 23 | nvcc --version 24 | python -c "import torch; print('torch:', torch.__version__, torch)" 25 | python -c "import torch; print('CUDA available:', torch.cuda.is_available())" 26 | - name: Get release version from tag 27 | run: | 28 | echo "RELEASE_VERSION=${GITHUB_REF#refs/*/v}" >> $GITHUB_ENV 29 | - name: Check release version 30 | run: | 31 | pip install packaging 32 | python release/check_release_version.py --release_version ${{ env.RELEASE_VERSION }} 33 | - name: Build DeepSpeed-Kernels 34 | run: | 35 | pip install build 36 | ts=$(date +%s) 37 | DS_KERNELS_MAKE_JOBS=10 DS_KERNELS_BUILD_STRING=".dev${ts}" CUDA_ARCH_LIST="80;86;89;90" python -m build --wheel --no-build-isolation 38 | fname=$(ls dist) 39 | nname=$(echo $fname | sed 's/cp[0-9]\+-cp[0-9]\+/py3-none/' | sed 's/linux/manylinux1/') 40 | mv "dist/$fname" "dist/$nname" 41 | ls -al 42 | - name: Publish to PyPI 43 | uses: pypa/gh-action-pypi-publish@release/v1 44 | with: 45 | password: ${{ secrets.PYPI_API_TOKEN }} 46 | repository-url: https://upload.pypi.org/legacy/ 47 | - name: Bump version 48 | run: | 49 | python release/bump_patch_version.py --current_version ${{ env.RELEASE_VERSION }} 50 | - name: Create Pull Request 51 | uses: peter-evans/create-pull-request@v6 52 | with: 53 | token: ${{ secrets.GH_PAT }} 54 | add-paths: | 55 | version.txt 56 | body: | 57 | **Auto-generated PR to update version.txt after a DeepSpeed release** 58 | Released version - ${{ env.RELEASE_VERSION }} 59 | Author - @${{ github.actor }} 60 | branch: AutoPR/${{ env.RELEASE_VERSION }} 61 | assignees: ${{ github.actor }} 62 | title: "Update version.txt after ${{ env.RELEASE_VERSION }} release" 63 | author: ${{ github.actor }} <${{ github.actor }}@users.noreply.github.com> 64 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | 3 | **/build/** 4 | **/lib/** 5 | 6 | dskernels/version.py 7 | 8 | # Dev/IDE data 9 | .vscode 10 | 
-------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "ft_gemm/third_party/cutlass"] 2 | path = dskernels/ft_gemm/third_party/cutlass 3 | url = https://github.com/NVIDIA/cutlass.git 4 | [submodule "inf_flash_attn/third_party/cutlass"] 5 | path = dskernels/inf_flash_attn/third_party/cutlass 6 | url = https://github.com/NVIDIA/cutlass.git 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include dskernels/*.so 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![License Apache 2.0](https://badgen.net/badge/license/apache2.0/blue)](https://github.com/deepspeedai/DeepSpeed/blob/master/LICENSE) 2 | [![PyPI version](https://badge.fury.io/py/deepspeed-kernels.svg)](https://pypi.org/project/deepspeed-kernels/) 3 | 4 | # DeepSpeed Kernels 5 | 6 | DeepSpeed-Kernels is a backend library used to power [DeepSpeed-FastGen](https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/deepspeed-fastgen) to achieve accelerated text-generation inference through [DeepSpeed-MII](https://github.com/deepspeedai/DeepSpeed-mii). This library is not intended to be an independent user package, but is open-source to benefit the community and show how DeepSpeed accelerates text generation. 7 | 8 | The binaries compiled from this repo and included in the PyPI release are torch- and python-agnostic. This allows the core backend to be as portable as possible and leaves the task of compiling the torch and python bindings to DeepSpeed itself via its [JIT op builder](https://github.com/deepspeedai/DeepSpeed/tree/master/op_builder). 9 | 10 | # Installation 11 | 12 | ## PyPI 13 | 14 | If your environment supports it, you can quickly install DeepSpeed-Kernels from [PyPI](https://pypi.org/project/deepspeed-kernels/) (see below). We've tested the portability of the PyPI release on A100, A6000, and H100. 15 | 16 | The release on PyPI should work with the following assumptions about your environment: 17 | * NVIDIA GPU(s) with compute capability of: 8.0, 8.6, 8.9, 9.0 18 | * CUDA 11.6+ 19 | * Ubuntu 20+ 20 | 21 | ```bash 22 | pip install deepspeed-kernels 23 | ``` 24 | 25 | ## Source 26 | If the PyPI release does not work for you, we recommend installing from source, which can take several minutes: 27 | ```bash 28 | pip install -v . 29 | ``` 30 | 31 | ## Advanced 32 | 33 | You can create a pre-compiled portable wheel that supports multiple CUDA architectures via the `CUDA_ARCH_LIST` environment variable. By default the kernels are compiled for the `native` compute capability; to target more than one architecture, set `CUDA_ARCH_LIST` explicitly. We currently only support Ampere and newer architectures (i.e., compute capability 8.0+). See the example below to build for GPUs like the A100 and A6000: 34 | ```bash 35 | CUDA_ARCH_LIST="80;86;89;90" python -m build --wheel 36 | ``` 37 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin). 
6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /build.sh: -------------------------------------------------------------------------------- 1 | set -ex 2 | 3 | # set CUDA_ARCH_LIST if desired, otherwise falls back to native 4 | CUDA_ARCH_LIST="80;86" python -m build --wheel 5 | -------------------------------------------------------------------------------- /builder/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | # DeepSpeed Team 5 | -------------------------------------------------------------------------------- /builder/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | # DeepSpeed Team 5 | 6 | import os 7 | import subprocess 8 | from setuptools import Extension 9 | from setuptools.command.build_ext import build_ext 10 | from packaging import version as pkg_version 11 | 12 | 13 | def installed_cuda_version(): 14 | cuda_home = os.environ.get("CUDA_HOME", None) 15 | if cuda_home is None: 16 | import torch.utils.cpp_extension 17 | cuda_home = torch.utils.cpp_extension.CUDA_HOME 18 | assert cuda_home is not None, "CUDA_HOME does not exist, unable to compile CUDA op(s)" 19 | output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"], 20 | universal_newlines=True) 21 | output_split = output.split() 22 | release_idx = output_split.index("release") 23 | release = output_split[release_idx + 1].replace(',', '') 24 | return pkg_version.parse(release) 25 | 26 | 27 | def validate_arch_list(cuda_arch_list: str): 28 | cuda_version = installed_cuda_version() 29 | cuda_arch_list = cuda_arch_list.split(';') 30 | for arch in cuda_arch_list: 31 | if arch == 'native': 32 | continue 33 | try: 34 | arch = int(arch) 35 | except ValueError: 36 | raise ValueError( 37 | f"Invalid CUDA_ARCH_LIST: {cuda_arch_list}. " 38 | "CUDA_ARCH_LIST must be a list of integers or 'native'.") 39 | 40 | # [error] if cuda < 11.8 and arch is >= 89 41 | if cuda_version < pkg_version.parse('11.8'): 42 | assert arch < 89, f"Compute capability of {arch} is not supported for CUDA {cuda_version}" 43 | 44 | # [error] if cuda == 11.8 and arch > 90 45 | if cuda_version == pkg_version.parse('11.8'): 46 | assert arch <= 90, f"Compute capability of {arch} is not supported for CUDA {cuda_version}" 47 | 48 | # [error] min arch is 80 49 | assert arch >= 80, f"Compute capability less than 80 is not supported for DeepSpeed kernels" 50 | 51 | 52 | class CMakeExtension(Extension): 53 | def __init__(self, name, sources=[]): 54 | super().__init__(name, sources) 55 | 56 | 57 | class CMakeBuild(build_ext): 58 | def run(self): 59 | for ext in self.extensions: 60 | self.build_extension(ext) 61 | 62 | def build_extension(self, ext): 63 | # Build in kernel unique sub-directory inside temp build directory 64 | abs_build_temp = os.path.abspath(self.build_temp) 65 | abs_build_temp = os.path.join(abs_build_temp, ext.name) 66 | 67 | # Pass through CUDA_ARCH_LIST if defined, otherwise fall back to native 68 | cuda_arch_list = os.environ.get('CUDA_ARCH_LIST', 'native') 69 | validate_arch_list(cuda_arch_list) 70 | 71 | # Destination path for final binaries 72 | abs_build_lib = os.path.join(os.path.abspath(self.build_lib), 73 | "dskernels") 74 | 75 | subprocess.check_call(['cmake', '-B', abs_build_temp, 76 | f'-DLIB_OUTPUT_DIR={abs_build_lib}', 77 | f'-DCUDA_ARCH_LIST={cuda_arch_list}'], 78 | cwd=ext.source) 79 | 80 | # Allow user to specify degree of make parallelism 81 | make_jobs = os.environ.get('DS_KERNELS_MAKE_JOBS', None) 82 | make_cmd = f"make -j {make_jobs}" if make_jobs is not None else "make -j" 83 | subprocess.check_call(make_cmd.split(" "), cwd=abs_build_temp) 84 | -------------------------------------------------------------------------------- /builder/ft_gemm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | # DeepSpeed Team 5 | 6 | from .builder import CMakeExtension 7 | 8 | 9 | class FTGemmBuilder(CMakeExtension): 10 | def __init__(self, name, sources=[]): 11 | super().__init__(name, sources) 12 | 13 | @property 14 | def source(self): 15 | return "dskernels/ft_gemm/gemm_variants/" 16 | -------------------------------------------------------------------------------- /builder/inf_flash_attn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | # DeepSpeed Team 5 | 6 | from .builder import CMakeExtension 7 | 8 | 9 | class BlockedFlashBuilder(CMakeExtension): 10 | def __init__(self, name, sources=[]): 11 | super().__init__(name, sources) 12 | 13 | @property 14 | def source(self): 15 | return "dskernels/inf_flash_attn/blocked_flash/" 16 | -------------------------------------------------------------------------------- /dskernels/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | # DeepSpeed Team 5 | 6 | import os 7 | 8 | __version__ = "0.0.0" 9 | try: 10 | from .version import __version__ 11 | except ImportError: 12 | pass 13 | 14 | 15 | def library_path(): 16 | return os.path.dirname(__file__) 17 | -------------------------------------------------------------------------------- /dskernels/ft_gemm/gemm_variants/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | cmake_minimum_required(VERSION 3.11) 16 | project(DeepSpeedFTKernels CXX CUDA) 17 | 18 | Set(CMAKE_CXX_STANDARD 17) 19 | Set(CMAKE_CUDA_STANDARD 17) 20 | 21 | if (NOT CMAKE_BUILD_TYPE) 22 | set(CMAKE_BUILD_TYPE Release) 23 | endif() 24 | 25 | list(APPEND CMAKE_MODULE_PATH ${CMAKE_BINARY_DIR}) 26 | list(APPEND CMAKE_PREFIX_PATH ${CMAKE_BINARY_DIR}) 27 | 28 | if (NOT EXISTS "${CMAKE_BINARY_DIR}/conan.cmake") 29 | message(STATUS "Downloading conan.cmake from https://github.com/conan-io/cmake-conan") 30 | file(DOWNLOAD "https://raw.githubusercontent.com/conan-io/cmake-conan/v0.16.1/conan.cmake" 31 | "${CMAKE_BINARY_DIR}/conan.cmake" 32 | EXPECTED_HASH SHA256=396e16d0f5eabdc6a14afddbcfff62a54a7ee75c6da23f32f7a31bc85db23484 33 | TLS_VERIFY ON) 34 | endif() 35 | 36 | find_package(CUDA) 37 | 38 | if (NOT WIN32) 39 | list(APPEND CMAKE_CXX_FLAGS "-fmax-errors=1 -Wfatal-errors") 40 | Set(LIB_NAME "deepspeedft") 41 | else() 42 | Set(LIB_NAME "libdeepspeedft") 43 | endif() 44 | 45 | list(APPEND NVCC_FLAGS "-O3") 46 | list(APPEND NVCC_FLAGS "-U__CUDA_NO_HALF_OPERATORS__") 47 | list(APPEND NVCC_FLAGS "-U__CUDA_NO_HALF_CONVERSIONS__") 48 | list(APPEND NVCC_FLAGS "-U__CUDA_NO_HALF2_OPERATORS__") 49 | list(APPEND NVCC_FLAGS "-U__CUDA_NO_BFLOAT16_CONVERSIONS__") 50 | list(APPEND NVCC_FLAGS "-U__CUDA_NO_BFLOAT16_OPERATORS__") 51 | list(APPEND NVCC_FLAGS "-expt-relaxed-constexpr") 52 | list(APPEND NVCC_FLAGS "--use_fast_math") 53 | add_definitions(-DENABLE_BF16) 54 | add_definitions(-DBUILD_CUTLASS_MOE) 55 | add_definitions(-DBUILD_CUTLASS_MIXED_GEMM) 56 | 57 | add_library(cutlass_heuristic STATIC utils/cutlass_heuristic.cc) 58 | target_include_directories(cutlass_heuristic PRIVATE 59 | ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} 60 | ${CMAKE_CURRENT_SOURCE_DIR} 61 | ${CMAKE_CURRENT_SOURCE_DIR}/../third_party/cutlass/include) 62 | set_property(TARGET cutlass_heuristic PROPERTY POSITION_INDEPENDENT_CODE ON) 63 | target_compile_options(cutlass_heuristic PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${NVCC_FLAGS}>) 64 | 65 | add_library(cutlass_preprocessors STATIC utils/cutlass_preprocessors.cc) 66 | target_include_directories(cutlass_preprocessors PRIVATE 67 | ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} 68 | ${CMAKE_CURRENT_SOURCE_DIR} 69 | ${CMAKE_CURRENT_SOURCE_DIR}/../third_party/cutlass/include) 70 | set_property(TARGET cutlass_preprocessors PROPERTY POSITION_INDEPENDENT_CODE ON) 71 | target_compile_options(cutlass_preprocessors PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${NVCC_FLAGS}>) 72 | 73 | set(VERBOSE_BUILD 0) 74 | 75 | if (VERBOSE_BUILD) 76 | list(APPEND NVCC_FLAGS "--ptxas-options=-v") 77 | endif() 78 | 79 | add_library(${LIB_NAME} SHARED) 80 | 81 | set(moe_gemm_files "") 82 | file(GLOB moe_gemm_files ${moe_gemm_files} moe_gemm/*.cu) 83 | set(fpA_intB_files "") 84 | file(GLOB fpA_intB_files ${fpA_intB_files} fpA_intB_gemm/*.cu) 85 | 86 | set(ALL_SRCS ${moe_gemm_files} ${fpA_intB_files}) 87 | 88 | target_sources(${LIB_NAME} PRIVATE ${ALL_SRCS}) 89 | target_include_directories(${LIB_NAME} PRIVATE 90 | ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} 91 | ${CMAKE_CURRENT_SOURCE_DIR} 92 | ${CMAKE_CURRENT_SOURCE_DIR}/../third_party/cutlass/include) 93 | set_property(TARGET ${LIB_NAME} PROPERTY POSITION_INDEPENDENT_CODE ON) 94 | set_property(TARGET ${LIB_NAME} PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) 95 | target_link_libraries(${LIB_NAME} PRIVATE cutlass_heuristic) 96 | 97 | target_compile_options(${LIB_NAME} PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${NVCC_FLAGS}>) 98 | set_target_properties(${LIB_NAME} PROPERTIES CUDA_ARCHITECTURES "${CUDA_ARCH_LIST}") 99 | set_target_properties(${LIB_NAME} 
PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${LIB_OUTPUT_DIR}) 100 | 101 | # Show timings? 102 | set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "${CMAKE_COMMAND} -E time") 103 | -------------------------------------------------------------------------------- /dskernels/ft_gemm/gemm_variants/LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. 
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /dskernels/ft_gemm/gemm_variants/build_ft_kernels.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ ! -d "build" ]; then 4 | mkdir build 5 | fi 6 | 7 | cd build 8 | cmake .. 
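# The bare `cmake ..` above configures with CMake defaults. The Python builder
# (builder/builder.py) instead passes the architecture list and output directory
# explicitly; a hypothetical manual equivalent (assumed, not part of CI) would be:
#   cmake .. -DCUDA_ARCH_LIST="80;86" -DLIB_OUTPUT_DIR="$PWD/../lib"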
9 | make -j 10 | -------------------------------------------------------------------------------- /dskernels/ft_gemm/gemm_variants/cutlass_extensions/arch/mma.h: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | * 30 | **************************************************************************************************/ 31 | /*! \file 32 | \brief Templates exposing architecture support for multiply-add operations 33 | */ 34 | 35 | #pragma once 36 | 37 | ///////////////////////////////////////////////////////////////////////////////////////////////// 38 | 39 | namespace cutlass { 40 | namespace arch { 41 | 42 | // Tag which triggers MMA which will trigger 43 | struct OpMultiplyAddDequantizeInterleavedBToA; 44 | 45 | } // namespace arch 46 | } // namespace cutlass -------------------------------------------------------------------------------- /dskernels/ft_gemm/gemm_variants/cutlass_extensions/compute_occupancy.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | #pragma once 17 | 18 | #include <cuda_runtime_api.h> 19 | 20 | #include "cutlass/device_kernel.h" 21 | #include "utils/cuda_utils.h" 22 | 23 | namespace fastertransformer { 24 | 25 | template<typename GemmKernel> 26 | inline int compute_occupancy_for_kernel() 27 | { 28 | 29 | int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); 30 | 31 | if (smem_size > (48 << 10)) { 32 | cudaError_t status = 33 | cudaFuncSetAttribute(cutlass::Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); 34 | if (status == cudaError::cudaErrorInvalidValue) { 35 | // Clear the error bit since we can ignore this. 36 | // This should mean that smem_size > cudaDevAttrMaxSharedMemoryPerBlockOptin. In that case, we return an 37 | // occupancy of 0. This will cause the heuristic to ignore this configuration. 38 | status = cudaGetLastError(); 39 | return 0; 40 | } 41 | check_cuda_error(status); 42 | } 43 | 44 | int max_active_blocks = -1; 45 | check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor( 46 | &max_active_blocks, cutlass::Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_size)); 47 | 48 | return max_active_blocks; 49 | } 50 | 51 | } // namespace fastertransformer -------------------------------------------------------------------------------- /dskernels/ft_gemm/gemm_variants/cutlass_extensions/epilogue/epilogue_quant_helper.h: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
29 | * 30 | **************************************************************************************************/ 31 | 32 | #pragma once 33 | 34 | ///////////////////////////////////////////////////////////////////////////////////////////////// 35 | 36 | namespace cutlass { 37 | namespace epilogue { 38 | 39 | // define scaling mode 40 | enum class QuantMode { 41 | PerTensorQuant, 42 | PerTokenQuant, 43 | PerChannelQuant, 44 | PerTokenChannelQuant 45 | }; 46 | 47 | } // namespace epilogue 48 | } // namespace cutlass 49 | -------------------------------------------------------------------------------- /dskernels/ft_gemm/gemm_variants/cutlass_extensions/epilogue/thread/ft_fused_activations.h: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | * 30 | **************************************************************************************************/ 31 | /*! \file 32 | \brief Functor performing linear combination with a maximum operation used by epilogues. 
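       The GELU_taylor<float> specialization below evaluates the tanh-based GELU approximation
       GELU(z) ~= 0.5 * z * (1 + tanh(k0 * z * (1 + k1 * z * z))), with k0 = sqrt(2/pi) ~= 0.7978845608
       and k1 = 0.044715, routing tanh through the fast tanh_opt() helper on supported architectures.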
33 | */ 34 | 35 | #pragma once 36 | 37 | #include "cutlass/array.h" 38 | #include "cutlass/cutlass.h" 39 | #include "cutlass/epilogue/thread/activation.h" 40 | #include "cutlass/epilogue/thread/scale_type.h" 41 | #include "cutlass/functional.h" 42 | #include "cutlass/half.h" 43 | #include "cutlass/numeric_conversion.h" 44 | #include "cutlass/numeric_types.h" 45 | 46 | ///////////////////////////////////////////////////////////////////////////////////////////////// 47 | 48 | namespace cutlass { 49 | namespace epilogue { 50 | namespace thread { 51 | 52 | ///////////////////////////////////////////////////////////////////////////////////////////////// 53 | 54 | __forceinline__ __device__ float copysignf_pos(float a, float b) 55 | { 56 | float r; 57 | r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000)); 58 | return r; 59 | } 60 | 61 | __forceinline__ __device__ float tanh_opt(float x) 62 | { 63 | #if (__CUDACC_VER_MAJOR__ < 11) || (__CUDA_ARCH__ < 750) 64 | const float exp_val = -1.f * fabs(2 * x); 65 | return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x); 66 | #else 67 | return fast_tanh(x); 68 | #endif 69 | } 70 | 71 | ///////////////////////////////////////////////////////////////////////////////////////////////// 72 | template<> 73 | struct GELU_taylor<float> { 74 | static const bool kIsHeavy = true; 75 | CUTLASS_DEVICE 76 | float operator()(float const& z) const 77 | { 78 | 79 | float k0 = float(0.7978845608028654); 80 | float k1 = float(0.044715); 81 | 82 | return float( 83 | cutlass::constants::half<float>() * z 84 | * (cutlass::constants::one<float>() + tanh_opt(k0 * z * (cutlass::constants::one<float>() + k1 * z * z)))); 85 | } 86 | 87 | using Params = LinearCombinationGenericParams<float>; 88 | 89 | CUTLASS_DEVICE 90 | float operator()(float const& scalar, Params const& params_) const 91 | { 92 | return this->operator()(scalar); 93 | } 94 | }; 95 | 96 | } // namespace thread 97 | } // namespace epilogue 98 | } // namespace cutlass 99 | 100 | ///////////////////////////////////////////////////////////////////////////////////////////////// 101 | -------------------------------------------------------------------------------- /dskernels/ft_gemm/gemm_variants/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 
18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | * 30 | **************************************************************************************************/ 31 | /*! \file 32 | \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. 33 | 34 | The epilogue rearranges the result of a matrix product through shared memory to match canonical 35 | tensor layouts in global memory. Epilogues support conversion and reduction operations. 36 | 37 | original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h 38 | 39 | */ 40 | 41 | #pragma once 42 | 43 | #include "cutlass/array.h" 44 | #include "cutlass/cutlass.h" 45 | #include "cutlass/numeric_types.h" 46 | 47 | #include "cutlass/platform/platform.h" 48 | 49 | #include "cutlass/gemm/gemm.h" 50 | 51 | #include "cutlass/epilogue/thread/linear_combination.h" 52 | #include "cutlass/epilogue/thread/linear_combination_clamp.h" 53 | #include "cutlass/epilogue/thread/linear_combination_gelu.h" 54 | #include "cutlass/epilogue/thread/linear_combination_hardswish.h" 55 | #include "cutlass/epilogue/thread/linear_combination_planar_complex.h" 56 | #include "cutlass/epilogue/thread/linear_combination_relu.h" 57 | #include "cutlass/epilogue/thread/linear_combination_relu0.h" 58 | #include "cutlass/epilogue/thread/linear_combination_sigmoid.h" 59 | 60 | #include "cutlass/epilogue/thread/conversion_op.h" 61 | #include "cutlass/epilogue/thread/reduction_op.h" 62 | 63 | #include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h" 64 | 65 | #include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h" 66 | #include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" 67 | #include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h" 68 | #include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h" 69 | #include "cutlass/epilogue/threadblock/shared_load_iterator.h" 70 | #include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h" 71 | #include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h" 72 | #include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h" 73 | #include "cutlass/epilogue/warp/tile_iterator_tensor_op.h" 74 | #include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h" 75 | 76 | #include "cutlass/epilogue/threadblock/epilogue.h" 77 | #include "cutlass/epilogue/threadblock/interleaved_epilogue.h" 78 | 79 | #include "cutlass/layout/permute.h" 80 | 81 | //////////////////////////////////////////////////////////////////////////////// 82 | 83 | namespace cutlass { 84 | namespace epilogue { 85 | namespace threadblock { 86 | 87 | //////////////////////////////////////////////////////////////////////////////// 88 | 89 | namespace detail { 90 | 91 | /// Partial 
specialization for half <= int32_t x 8 epilogues avoids shared memory bank conflicts. 92 | template 93 | struct DefaultIteratorsTensorOp { 94 | 95 | using WarpTileIterator = 96 | cutlass::epilogue::warp::TileIteratorTensorOp; 97 | 98 | using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator; 99 | 100 | static int const kFragmentsPerIteration = 1; 101 | }; 102 | 103 | /// Partial specialization for bfloat16_t <= int32_t x 8 epilogues avoids shared memory bank conflicts. 104 | template 105 | struct DefaultIteratorsTensorOp { 112 | 113 | using WarpTileIterator = 114 | cutlass::epilogue::warp::TileIteratorTensorOp; 115 | 116 | using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator; 117 | 118 | static int const kFragmentsPerIteration = 1; 119 | }; 120 | 121 | ///////////////////////////////////////////////////////////////////////////////////////////////// 122 | 123 | } // namespace detail 124 | 125 | ///////////////////////////////////////////////////////////////////////////////////////////////// 126 | 127 | /// Tile iterator used to load output tile from shared memory in epilogue. 128 | /// 129 | /// Satisfies: ReadableTileIterator 130 | /// 131 | template 133 | class SharedLoadIteratorMixed { 134 | public: 135 | using ThreadMap = ThreadMap_; 136 | using Shape = typename ThreadMap::Shape; 137 | 138 | using Element = int32_t; 139 | 140 | using Layout = layout::RowMajor; 141 | using TensorRef = TensorRef; 142 | using ConstTensorRef = typename TensorRef::ConstTensorRef; 143 | 144 | using Index = typename Layout::Index; 145 | using LongIndex = typename Layout::LongIndex; 146 | using TensorCoord = MatrixCoord; 147 | 148 | static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; 149 | 150 | static int const kAlignment = ThreadMap::kElementsPerAccess * sizeof_bits::value / 8; 151 | 152 | static int const kThreads = ThreadMap::kThreads; 153 | 154 | /// Fragment object 155 | using Fragment = Array; 158 | 159 | /// Memory access size 160 | using AccessType = AlignedArray; 161 | 162 | /// Vector type used for SMEM loads 163 | using LoadType = AlignedArray::value, ThreadMap::kElementsPerAccess), 165 | const_min(16, kAlignment)>; 166 | 167 | static int const kLoadsPerAccess = AccessType::kElements / LoadType::kElements; 168 | 169 | private: 170 | // 171 | // Data members 172 | // 173 | 174 | /// Byte-level pointer 175 | LoadType const* pointers_[kLoadsPerAccess]; 176 | 177 | /// Stride along adjacent rows in units of LoadType 178 | int stride_; 179 | 180 | public: 181 | // 182 | // Methods 183 | // 184 | 185 | /// Constructor 186 | CUTLASS_DEVICE 187 | SharedLoadIteratorMixed(TensorRef ref, int thread_idx): stride_((ref.stride(0) / LoadType::kElements)) 188 | { 189 | 190 | TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx); 191 | 192 | // Initialize pointers 193 | CUTLASS_PRAGMA_UNROLL 194 | for (int i = 0; i < kLoadsPerAccess; ++i) { 195 | pointers_[i] = reinterpret_cast(ref.data()); 196 | 197 | int col_idx = (thread_offset.column() / kElementsPerAccess) * kLoadsPerAccess; 198 | int bank_offset = (col_idx * int(sizeof(LoadType)) / 128) % kLoadsPerAccess; 199 | 200 | col_idx += (bank_offset + i) % kLoadsPerAccess; 201 | 202 | pointers_[i] += thread_offset.row() * stride_ + col_idx; 203 | } 204 | } 205 | 206 | /// Adds a pointer offset in units of Element 207 | CUTLASS_HOST_DEVICE 208 | void add_pointer_offset(LongIndex pointer_offset) 209 | { 210 | CUTLASS_PRAGMA_UNROLL 211 | for (int i = 0; i < kLoadsPerAccess; ++i) { 212 | 
pointers_[i] += pointer_offset / LoadType::kElements; 213 | } 214 | } 215 | 216 | CUTLASS_DEVICE 217 | void add_tile_offset(TensorCoord const& offset) 218 | { 219 | CUTLASS_PRAGMA_UNROLL 220 | for (int i = 0; i < kLoadsPerAccess; ++i) { 221 | pointers_[i] += 222 | offset.row() * Shape::kRow * stride_ + offset.column() * Shape::kColumn / LoadType::kElements; 223 | } 224 | } 225 | 226 | /// Loads a fragment from memory 227 | CUTLASS_DEVICE 228 | void load_with_pointer_offset(Fragment& frag, Index pointer_offset) const 229 | { 230 | 231 | CUTLASS_PRAGMA_UNROLL 232 | for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { 233 | 234 | CUTLASS_PRAGMA_UNROLL 235 | for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { 236 | 237 | CUTLASS_PRAGMA_UNROLL 238 | for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { 239 | 240 | int row_ptr_offset = 241 | row * ThreadMap::Delta::kRow * stride_ + group * ThreadMap::Delta::kGroup * stride_ 242 | + cluster * ThreadMap::Delta::kCluster * stride_ + pointer_offset / LoadType::kElements; 243 | 244 | int frag_row_idx = 245 | (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); 246 | 247 | LoadType* frag_ptr = reinterpret_cast(&frag); 248 | 249 | CUTLASS_PRAGMA_UNROLL 250 | for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { 251 | 252 | int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn + column; 253 | 254 | CUTLASS_PRAGMA_UNROLL 255 | for (int v = 0; v < kLoadsPerAccess; ++v) { 256 | 257 | int vector_idx = 258 | (column * ThreadMap::Delta::kColumn / kElementsPerAccess * kLoadsPerAccess); 259 | 260 | LoadType const* memory_pointer = pointers_[v] + row_ptr_offset; 261 | 262 | frag_ptr[frag_idx * kLoadsPerAccess + v] = memory_pointer[vector_idx]; 263 | } 264 | } 265 | } 266 | } 267 | } 268 | } 269 | 270 | /// Loads a fragment 271 | CUTLASS_DEVICE 272 | void load(Fragment& frag) const 273 | { 274 | 275 | load_with_pointer_offset(frag, 0); 276 | } 277 | }; 278 | 279 | ///////////////////////////////////////////////////////////////////////////////////////////////// 280 | 281 | } // namespace threadblock 282 | } // namespace epilogue 283 | } // namespace cutlass 284 | 285 | //////////////////////////////////////////////////////////////////////////////// 286 | -------------------------------------------------------------------------------- /dskernels/ft_gemm/gemm_variants/cutlass_extensions/epilogue_helpers.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @file epilogue_helpers.h 3 | * 4 | * This file includes types for the epilogues. The empty structs exist so we can signal to template 5 | * code the type of epilogue we want to run, and let the underlying code specify the details such as 6 | * element types, accumulator type and elements per vector access. 
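 *
 * Illustrative usage (an assumption about the intended call pattern, not code from
 * this header): a kernel template would typically resolve its epilogue functor as
 *
 *     using EpilogueOp = typename fastertransformer::Epilogue<cutlass::half_t,  // element type
 *                                                             8,                // elements per access
 *                                                             cutlass::half_t,  // accumulator type
 *                                                             EpilogueOpBiasFtGelu>::Op;
 *
 * and hand EpilogueOp to the CUTLASS GEMM; the empty tag struct alone selects the
 * matching specialization below.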
7 | * 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "cutlass/epilogue/thread/linear_combination.h" 13 | #include "cutlass/epilogue/thread/linear_combination_generic.h" 14 | #include "cutlass/epilogue/thread/linear_combination_relu.h" 15 | #include "cutlass/epilogue/thread/linear_combination_silu.h" 16 | #include "cutlass_extensions/epilogue/thread/ft_fused_activations.h" 17 | 18 | namespace fastertransformer { 19 | 20 | struct EpilogueOpBiasSilu {}; 21 | 22 | struct EpilogueOpBiasReLU {}; 23 | 24 | struct EpilogueOpBiasFtGelu {}; 25 | 26 | struct EpilogueOpBias {}; 27 | 28 | struct EpilogueOpNoBias {}; 29 | 30 | template 31 | struct Epilogue { 32 | }; 33 | 34 | template 35 | struct Epilogue { 36 | using Op = cutlass::epilogue::thread::LinearCombinationSilu; 41 | }; 42 | 43 | template 44 | struct Epilogue { 45 | using Op = cutlass::epilogue::thread::LinearCombinationRelu; 50 | }; 51 | 52 | template 53 | struct Epilogue { 54 | using Op = cutlass::epilogue::thread::LinearCombinationGeneric; 62 | }; 63 | 64 | template 65 | struct Epilogue { 66 | using Op = cutlass::epilogue::thread::LinearCombination; 71 | }; 72 | 73 | template 74 | struct Epilogue { 75 | using Op = cutlass::epilogue::thread::LinearCombination; 80 | }; 81 | 82 | } // namespace fastertransformer 83 | -------------------------------------------------------------------------------- /dskernels/ft_gemm/gemm_variants/cutlass_extensions/ft_gemm_configs.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #pragma once 18 | 19 | namespace fastertransformer { 20 | 21 | // Note: The shapes are in the format MxNxK. The K shape of the runtime config MUST match the K shape 22 | // in the kernel layout details when doing weight only quantization. 
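// For example (illustrative): every TensorCore tile config below uses CTA_K = 64,
// matching the ThreadblockK = 64 that mixed_gemm_B_layout.h requires for the
// weight-only kernels; a hypothetical config with CTA_K = 32 would violate the
// constraint stated above and be rejected.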
23 | enum class CutlassTileConfig {
24 |     // Signals that the config is undefined / unset.
25 |     Undefined,
26 | 
27 |     // Signals that we should run heuristics to choose a config.
28 |     ChooseWithHeuristic,
29 | 
30 |     // SIMT config
31 |     CtaShape128x128x8_WarpShape64x64x8,
32 | 
33 |     // TensorCore configs CTA_N = 128, CTA_K = 64
34 |     // Warp configs for M=32
35 |     CtaShape32x128x64_WarpShape32x32x64,
36 | 
37 |     // Warp configs for M=64
38 |     CtaShape64x128x64_WarpShape32x64x64,
39 |     CtaShape64x128x64_WarpShape64x32x64,
40 | 
41 |     // Warp configs for M=128
42 |     CtaShape128x128x64_WarpShape64x32x64,
43 |     CtaShape128x128x64_WarpShape128x32x64
44 | };
45 | 
46 | enum class SplitKStyle {
47 |     NO_SPLIT_K,
48 |     SPLIT_K_SERIAL,
49 |     // SPLIT_K_PARALLEL // Not supported yet
50 | };
51 | 
52 | struct CutlassGemmConfig {
53 |     CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic;
54 |     SplitKStyle split_k_style = SplitKStyle::NO_SPLIT_K;
55 |     int split_k_factor = -1;
56 |     int stages = -1;
57 | };
58 | 
59 | } // namespace fastertransformer
60 | 
--------------------------------------------------------------------------------
/dskernels/ft_gemm/gemm_variants/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | 
3 | #include "cutlass/arch/arch.h"
4 | #include "cutlass/arch/mma.h"
5 | #include "cutlass/bfloat16.h"
6 | #include "cutlass/cutlass.h"
7 | #include "cutlass/gemm/gemm.h"
8 | #include "cutlass/layout/matrix.h"
9 | 
10 | #include "cutlass_extensions/arch/mma.h"
11 | #include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
12 | 
13 | namespace cutlass {
14 | namespace gemm {
15 | namespace kernel {
16 | 
17 | template<typename TypeA, typename TypeB, typename arch, typename Enable = void>
18 | struct MixedGemmArchTraits {
19 | };
20 | 
21 | template<typename arch>
22 | struct MixedGemmArchTraits<float, float, arch> {
23 |     static constexpr int Stages = 2;
24 |     using OperatorClass = cutlass::arch::OpClassSimt;
25 |     using AccType = float;
26 |     using LayoutB = cutlass::layout::RowMajor;
27 | 
28 |     static constexpr int ElementsPerAccessA = 1;
29 |     static constexpr int ElementsPerAccessB = 1;
30 |     static constexpr int ElementsPerAccessC = 1;
31 |     static constexpr int ThreadblockK = 8;
32 |     using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
33 | 
34 |     using Operator = cutlass::arch::OpMultiplyAdd;
35 | };
36 | 
37 | // ========================= Volta Traits ===========================
38 | // Volta will always dequantize after the global memory load.
39 | // This will instantiate any HMMA tensorcore kernels for Volta.
40 | // Note that Volta does not have native bfloat16 support, so weights and activations are cast to fp16;
41 | // compute happens in fp16 and the result is converted back for bf16 output.
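// (Clarifying note: "dequantize after the global memory load" means the integer ->
// fp16 converter runs on the LDG result before the data is staged to shared memory,
// so shared memory already holds fp16 on Volta; see SetConverters in
// cutlass_extensions/gemm/threadblock/default_dq_mma.h later in this tree.)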
42 | template 43 | struct MixedGemmArchTraits< 44 | TypeA, 45 | TypeB, 46 | cutlass::arch::Sm70, 47 | typename cutlass::platform::enable_if::value 48 | || cutlass::platform::is_same::value>::type> { 49 | private: 50 | using LayoutDetails = LayoutDetailsB; 51 | 52 | public: 53 | static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; 54 | 55 | using OperatorClass = cutlass::arch::OpClassTensorOp; 56 | using AccType = float; 57 | using LayoutB = typename LayoutDetails::Layout; 58 | 59 | static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; 60 | static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; 61 | static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; 62 | using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; 63 | 64 | using Operator = typename LayoutDetails::Operator; 65 | }; 66 | 67 | // ======================= Turing Traits ============================== 68 | // Note that turing does not have native bfloat support so weights and activations will be casted to fp16 69 | // and compute will happen in fp16 then will be converted for bf16 output. 70 | template 71 | struct MixedGemmArchTraits< 72 | TypeA, 73 | TypeB, 74 | cutlass::arch::Sm75, 75 | typename cutlass::platform::enable_if::value 76 | || cutlass::platform::is_same::value>::type> { 77 | private: 78 | using LayoutDetails = LayoutDetailsB; 79 | 80 | public: 81 | static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; 82 | 83 | using OperatorClass = cutlass::arch::OpClassTensorOp; 84 | using AccType = float; 85 | using LayoutB = typename LayoutDetails::Layout; 86 | 87 | static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; 88 | static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; 89 | static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; 90 | using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; 91 | 92 | using Operator = typename LayoutDetails::Operator; 93 | }; 94 | 95 | // ======================= Ampere Traits ============================== 96 | template 97 | struct MixedGemmArchTraits< 98 | TypeA, 99 | TypeB, 100 | cutlass::arch::Sm80, 101 | typename cutlass::platform::enable_if::value 102 | || cutlass::platform::is_same::value>::type> { 103 | private: 104 | using LayoutDetails = LayoutDetailsB; 105 | 106 | public: 107 | static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; 108 | 109 | using OperatorClass = cutlass::arch::OpClassTensorOp; 110 | using AccType = float; 111 | using LayoutB = typename LayoutDetails::Layout; 112 | 113 | static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; 114 | static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; 115 | static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; 116 | using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; 117 | 118 | using Operator = typename LayoutDetails::Operator; 119 | }; 120 | 121 | } // namespace kernel 122 | } // namespace gemm 123 | } // namespace cutlass -------------------------------------------------------------------------------- /dskernels/ft_gemm/gemm_variants/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | * 30 | **************************************************************************************************/ 31 | 32 | /*! \file 33 | \brief Scheduler for grouped GEMM 34 | */ 35 | 36 | #pragma once 37 | 38 | #include "cutlass/cutlass.h" 39 | #include "cutlass/gemm/gemm.h" 40 | #include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" 41 | #include "cutlass/matrix_coord.h" 42 | 43 | #include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h" 44 | #include "cutlass_extensions/gemm/kernel/moe_problem_visitor.h" 45 | 46 | ///////////////////////////////////////////////////////////////////////////////////////////////// 47 | 48 | namespace cutlass { 49 | namespace gemm { 50 | namespace kernel { 51 | 52 | /// Visitor class to abstract away the algorithm for iterating over tiles 53 | template 58 | struct GemmMoeProblemVisitor: 59 | public MoeProblemVisitor, 60 | ThreadblockShape, 61 | GroupScheduleMode_, 62 | PrefetchTileCount, 63 | ThreadCount> { 64 | 65 | static bool const kTransposed = Transposed; 66 | 67 | using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper; 68 | using Base = 69 | MoeProblemVisitor; 70 | using Params = typename Base::Params; 71 | using SharedStorage = typename Base::SharedStorage; 72 | 73 | // 74 | // Methods 75 | // 76 | CUTLASS_DEVICE 77 | GemmMoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx): 78 | Base(params_, shared_storage_, block_idx) 79 | { 80 | } 81 | }; 82 | 83 | ///////////////////////////////////////////////////////////////////////////////////////////////// 84 | 85 | } // namespace kernel 86 | } // namespace gemm 87 | } // namespace cutlass 88 | 89 | ///////////////////////////////////////////////////////////////////////////////////////////////// -------------------------------------------------------------------------------- 
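A minimal host-side sketch (illustrative; not part of this repository) of the mapping the visitor above computes on-device. Each expert e owns the row range [total_rows_before_expert[e-1], total_rows_before_expert[e]) of the concatenated activation matrix, and every expert shares the same gemm_n and gemm_k; total_rows_before_expert is the same prefix array accepted by the MoE GEMM runners later in this tree.

    #include <cstdint>
    #include <vector>

    struct ProblemSize { int64_t m, n, k; };

    // Derive per-expert GEMM problem sizes from the running row-prefix array.
    std::vector<ProblemSize> expert_problem_sizes(const int64_t* total_rows_before_expert,
                                                  int num_experts,
                                                  int64_t gemm_n,
                                                  int64_t gemm_k)
    {
        std::vector<ProblemSize> problems;
        problems.reserve(num_experts);
        int64_t prev_rows = 0;
        for (int e = 0; e < num_experts; ++e) {
            const int64_t rows = total_rows_before_expert[e] - prev_rows;  // rows routed to expert e
            problems.push_back({rows, gemm_n, gemm_k});
            prev_rows = total_rows_before_expert[e];
        }
        return problems;
    }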
/dskernels/ft_gemm/gemm_variants/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h: -------------------------------------------------------------------------------- 1 | /* 2 | This file exists so that we use the same weight layout for MoE grouped gemm and regular gemm when the weight is 3 | quantized. The preprocessing code reads this template to know how to organize the quantized weight matrices 4 | to be consumed by CUTLASS. 5 | 6 | Note that for int4, ThreadBlockK MUST be 64. 7 | 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "cutlass/layout/matrix.h" 13 | #include "cutlass/numeric_types.h" 14 | 15 | #include "cutlass/arch/arch.h" 16 | #include "cutlass/arch/mma.h" 17 | #include "cutlass/platform/platform.h" 18 | 19 | #include "cutlass_extensions/arch/mma.h" 20 | #include "cutlass_extensions/tile_interleaved_layout.h" 21 | 22 | namespace cutlass { 23 | namespace gemm { 24 | namespace kernel { 25 | 26 | template 27 | struct LayoutDetailsB { 28 | }; 29 | 30 | // Volta specialiations. Volta will dequantize before STS, so we need a different operator 31 | template 32 | struct LayoutDetailsB { 33 | static constexpr int ThreadblockK = 64; 34 | using Layout = layout::RowMajor; 35 | static constexpr int ElementsPerAccess = 8; 36 | using Operator = cutlass::arch::OpMultiplyAdd; 37 | }; 38 | 39 | // Specializations for Turing+ when B is FP16. These are currently only used for MoE networks. 40 | // TODO - Switch this to column major for weights since gemms should be more performant. 41 | template 42 | struct LayoutDetailsB= 75>::type> { 43 | static constexpr int ThreadblockK = 64; 44 | using Layout = layout::RowMajor; 45 | static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; 46 | using Operator = cutlass::arch::OpMultiplyAdd; 47 | }; 48 | 49 | template 50 | struct LayoutDetailsB= 75>::type> { 51 | static constexpr int ThreadblockK = 64; 52 | using Layout = layout::RowMajor; 53 | static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; 54 | using Operator = cutlass::arch::OpMultiplyAdd; 55 | }; 56 | 57 | // Specializations for Turing+ when B is quantized. These can use the operator OpMultiplyAddDequantizeInterleavedBToA, 58 | // which signals that we want to dequantize after loading from smem. 
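// For example (from the arithmetic below): with uint8_t weights,
// ElementsPerCacheLine = 128 * 8 / 8 = 128, so ColumnsInterleaved = 128 / 64 = 2;
// with uint4b_t weights, ElementsPerCacheLine = 256 and ColumnsInterleaved = 4.
// Each 128-byte cache line thus packs a ThreadblockK-deep slice of 2 (int8) or
// 4 (int4) adjacent B columns.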
59 | template 60 | struct LayoutDetailsB= 75>::type> { 61 | static constexpr int ThreadblockK = 64; 62 | 63 | private: 64 | static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; 65 | static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; 66 | 67 | public: 68 | using Layout = layout::ColumnMajorTileInterleave; 69 | static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; 70 | using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; 71 | }; 72 | 73 | template 74 | struct LayoutDetailsB= 75>::type> { 75 | static constexpr int ThreadblockK = 64; 76 | 77 | private: 78 | static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; 79 | static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; 80 | 81 | public: 82 | using Layout = layout::ColumnMajorTileInterleave; 83 | static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; 84 | using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; 85 | }; 86 | 87 | } // namespace kernel 88 | } // namespace gemm 89 | } // namespace cutlass -------------------------------------------------------------------------------- /dskernels/ft_gemm/gemm_variants/cutlass_extensions/gemm/threadblock/default_dq_mma.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cutlass_extensions/arch/mma.h" 4 | #include "cutlass_extensions/interleaved_numeric_conversion.h" 5 | 6 | namespace cutlass { 7 | namespace gemm { 8 | namespace threadblock { 9 | //////////////////////////////////////////////////////////////////////////////// 10 | 11 | // We need to distinguish here, since we want volta support. It is too much effort 12 | // to write shared memory iterators that are probably needed for volta to function 13 | // properly. As a result, we allow converters both after the LDG (for volta) and after 14 | // the LDS for Turing+. 
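// For example (restating the specializations below): on the Volta path
// (OpMultiplyAdd), TransformAfterLDG is the FastInterleavedAndBiasedNumericArrayConverter
// (integer -> fp16 dequant applied to the LDG result) and TransformAfterLDS is a
// plain NumericArrayConverter; on Turing+ (OpMultiplyAddDequantizeInterleavedBToA)
// the roles are swapped, so weights stay packed in shared memory and are
// dequantized only as warp fragments are loaded for the tensor-core MMA.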
15 | template< 16 | /// Iterator for B matrix in global memory 17 | typename IteratorB, 18 | /// Warp level Mma 19 | typename MmaOperator, 20 | /// Math operation perform by warp level operator 21 | typename MathOperator> 22 | struct SetConverters { 23 | }; 24 | 25 | // Dequantize after LDG, so set transforms accordingly 26 | template< 27 | /// Iterator for B matrix in global memory 28 | typename IteratorB, 29 | /// Mma Policy 30 | typename MmaOperator> 31 | struct SetConverters { 32 | using TransformAfterLDG = 33 | FastInterleavedAndBiasedNumericArrayConverter; 36 | 37 | using TransformAfterLDS = NumericArrayConverter; 40 | }; 41 | 42 | // Dequantize after LDS, so set transforms accordingly 43 | 44 | template< 45 | /// Iterator for B matrix in global memory 46 | typename IteratorB, 47 | /// Mma Policy 48 | typename MmaOperator> 49 | struct SetConverters { 50 | using TransformAfterLDG = 51 | NumericArrayConverter; 52 | 53 | using TransformAfterLDS = 54 | FastInterleavedAndBiasedNumericArrayConverter; 57 | }; 58 | 59 | //////////////////////////////////////////////////////////////////////////////// 60 | 61 | template< 62 | /// Element type for A matrix operand 63 | typename ElementA_, 64 | /// Layout type for A matrix operand 65 | typename LayoutA_, 66 | /// Access granularity of A matrix in units of elements 67 | int kAlignmentA, 68 | /// Element type for B matrix operand 69 | typename ElementB_, 70 | /// Layout type for B matrix operand 71 | typename LayoutB_, 72 | /// Access granularity of B matrix in units of elements 73 | int kAlignmentB, 74 | /// Element type for the input scale 75 | typename ElementScale_, 76 | /// Layout for the scale operand 77 | typename LayoutScale_, 78 | /// Access granularity of Scales in unit of elements 79 | int kAlignmentScale, 80 | /// Element type for internal accumulation 81 | typename ElementAccumulator_, 82 | /// Layout type for C and D matrix operands 83 | typename LayoutC_, 84 | /// Operator class tag 85 | typename OperatorClass_, 86 | /// Tag indicating architecture to tune for 87 | typename ArchTag_, 88 | /// Threadblock-level tile size (concept: GemmShape) 89 | typename ThreadblockShape_, 90 | /// Warp-level tile size (concept: GemmShape) 91 | typename WarpShape_, 92 | /// Instruction-level tile size (concept: GemmShape) 93 | typename InstructionShape_, 94 | /// Number of stages used in the pipelined mainloop 95 | int Stages, 96 | /// Operation performed by GEMM 97 | typename Operator_, 98 | /// Use zfill or predicate for out-of-bound cp.async 99 | SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, 100 | /// 101 | typename Enable = void> 102 | struct DqMma; 103 | 104 | } // namespace threadblock 105 | } // namespace gemm 106 | } // namespace cutlass -------------------------------------------------------------------------------- /dskernels/ft_gemm/gemm_variants/cutlass_extensions/gemm/threadblock/dq_mma_base.h: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 
10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | * 30 | **************************************************************************************************/ 31 | /*! \file 32 | \brief Template for a double-buffered threadblock-scoped GEMM kernel. 33 | */ 34 | 35 | #pragma once 36 | 37 | #include "cutlass/aligned_buffer.h" 38 | #include "cutlass/arch/memory.h" 39 | #include "cutlass/array.h" 40 | #include "cutlass/cutlass.h" 41 | #include "cutlass/gemm/gemm.h" 42 | #include "cutlass/gemm/threadblock/mma_base.h" 43 | #include "cutlass/matrix_shape.h" 44 | #include "cutlass/numeric_types.h" 45 | 46 | //////////////////////////////////////////////////////////////////////////////// 47 | 48 | namespace cutlass { 49 | namespace gemm { 50 | namespace threadblock { 51 | 52 | //////////////////////////////////////////////////////////////////////////////// 53 | // SFINAE trick so I can keep the same loop code for Volta and dispatch to the 54 | // correct warp level mma. On volta, all data is stored to shared memory as FP16. 55 | template 56 | CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, 57 | typename WarpMma::FragmentC& D, 58 | typename WarpMma::FragmentA const& A, 59 | typename WarpMma::FragmentB const& B, 60 | typename WarpMma::FragmentC const& C, 61 | const int warp_tileB_k_offset) 62 | { 63 | warp_mma(D, A, B, C); 64 | } 65 | 66 | template 67 | CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, 68 | typename WarpMma::FragmentC& D, 69 | typename WarpMma::TransformedFragmentA const& A, 70 | typename WarpMma::TransformedFragmentB const& B, 71 | typename WarpMma::FragmentC const& C, 72 | const int warp_tileB_k_offset) 73 | { 74 | warp_mma(D, A, B, C, warp_tileB_k_offset); 75 | } 76 | //////////////////////////////////////////////////////////////////////////////// 77 | 78 | /// Structure to compute the matrix product targeting CUDA cores and SIMT math 79 | /// instructions. 
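/// (Worked example, illustrative: with a 64x64x64 warp-level GemmShape and a
/// 16x8x16 tensor-op MMA instruction, kWarpGemmIterations below is 64 / 16 = 4;
/// if each warp-level B load spans two MMA K-steps, kNumKIterationsPerWarpBLoad
/// is 2 and kWarpGemmIterationsForB is 4 / 2 = 2.)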
80 | template< 81 | /// Size of the Gemm problem - concept: gemm::GemmShape<> 82 | typename Shape_, 83 | /// Policy describing tuning details (concept: MmaPolicy) 84 | typename Policy_, 85 | /// The type of the scales 86 | typename ElementScale_, 87 | /// Number of stages, 88 | int Stages, 89 | /// Used for partial specialization 90 | typename Enable = bool> 91 | class DqMmaBase { 92 | public: 93 | ///< Size of the Gemm problem - concept: gemm::GemmShape<> 94 | using Shape = Shape_; 95 | 96 | ///< Policy describing tuning details 97 | using Policy = Policy_; 98 | 99 | ///< Type of the scale to be loaded 100 | using ElementScale = ElementScale_; 101 | 102 | // 103 | // Dependent types 104 | // 105 | 106 | /// Warp-level Mma 107 | using Operator = typename Policy::Operator; 108 | 109 | /// Shape describing the overall GEMM computed from shared memory 110 | /// by each warp. 111 | using WarpGemm = typename Policy::Operator::Shape; 112 | 113 | /// Shape describing the number of warps filling the CTA 114 | using WarpCount = GemmShape; 115 | 116 | /// Number of warp-level GEMM oeprations 117 | static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK); 118 | 119 | static constexpr int kNumKIterationsPerWarpBLoad = 120 | Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK; 121 | 122 | static_assert(!(kWarpGemmIterations % kNumKIterationsPerWarpBLoad), ""); 123 | static constexpr int kWarpGemmIterationsForB = kWarpGemmIterations / kNumKIterationsPerWarpBLoad; 124 | 125 | /// Number of stages 126 | static int const kStages = Stages; 127 | 128 | /// Tensor reference to the A operand 129 | using TensorRefA = TensorRef; 130 | 131 | /// Tensor reference to the B operand 132 | using TensorRefB = TensorRef; 133 | 134 | // 135 | // Nested structs 136 | // 137 | 138 | /// Shared storage object needed by threadblock-scoped GEMM 139 | class SharedStorage { 140 | public: 141 | // 142 | // Type definitions 143 | // 144 | 145 | /// Shape of the A matrix operand in shared memory 146 | using ShapeA = 147 | MatrixShape; 148 | 149 | /// Shape of the B matrix operand in shared memory 150 | using ShapeB = 151 | MatrixShape; 152 | 153 | public: 154 | // 155 | // Data members 156 | // 157 | 158 | /// Buffer for A operand 159 | AlignedBuffer operand_A; 160 | 161 | /// Buffer for B operand 162 | AlignedBuffer operand_B; 163 | 164 | /// Buffer to hold scales for threadblock 165 | AlignedBuffer operand_scale; 166 | 167 | public: 168 | // 169 | // Methods 170 | // 171 | 172 | /// Returns a layout object for the A matrix 173 | CUTLASS_DEVICE 174 | static typename Operator::LayoutA LayoutA() 175 | { 176 | return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); 177 | } 178 | 179 | /// Returns a layout object for the B matrix 180 | CUTLASS_HOST_DEVICE 181 | static typename Operator::LayoutB LayoutB() 182 | { 183 | return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); 184 | } 185 | 186 | /// Returns a TensorRef to the A operand 187 | CUTLASS_HOST_DEVICE 188 | TensorRefA operand_A_ref() 189 | { 190 | return TensorRefA{operand_A.data(), LayoutA()}; 191 | } 192 | 193 | /// Returns a TensorRef to the B operand 194 | CUTLASS_HOST_DEVICE 195 | TensorRefB operand_B_ref() 196 | { 197 | return TensorRefB{operand_B.data(), LayoutB()}; 198 | } 199 | }; 200 | 201 | protected: 202 | // 203 | // Data members 204 | // 205 | 206 | /// Iterator to load a warp-scoped tile of A operand from shared memory 207 | typename Operator::IteratorA warp_tile_iterator_A_; 208 | 209 
| /// Iterator to load a warp-scoped tile of B operand from shared memory 210 | typename Operator::IteratorB warp_tile_iterator_B_; 211 | 212 | public: 213 | /// Construct from tensor references 214 | CUTLASS_DEVICE 215 | DqMmaBase( 216 | ///< Shared storage needed for internal use by threadblock-scoped GEMM 217 | SharedStorage& shared_storage, 218 | ///< ID within the threadblock 219 | int thread_idx, 220 | ///< ID of warp 221 | int warp_idx, 222 | ///< ID of each thread within a warp 223 | int lane_idx): 224 | warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), 225 | warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) 226 | { 227 | } 228 | }; 229 | 230 | ///////////////////////////////////////////////////////////////////////////////////////////////// 231 | 232 | } // namespace threadblock 233 | } // namespace gemm 234 | } // namespace cutlass 235 | 236 | ///////////////////////////////////////////////////////////////////////////////////////////////// 237 | -------------------------------------------------------------------------------- /dskernels/ft_gemm/gemm_variants/cutlass_extensions/gemm/warp/default_mma_tensor_op.h: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | * 30 | **************************************************************************************************/ 31 | /*! \file 32 | \brief Default warp-level GEMM operators selected by data type, size, and layouts of operands. 
33 | */ 34 | 35 | #pragma once 36 | 37 | #include "cutlass/cutlass.h" 38 | #include "cutlass/gemm/warp/default_mma_tensor_op.h" 39 | #include "cutlass/gemm/warp/mma_tensor_op.h" 40 | 41 | #include "cutlass_extensions/arch/mma.h" 42 | #include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" 43 | 44 | namespace cutlass { 45 | namespace gemm { 46 | namespace warp { 47 | 48 | ///////////////////////////////////////////////////////////////////////////////////////////////// 49 | 50 | /// Partial specialization for m-by-n-by-kgroup 51 | template< 52 | /// Shape of one matrix production operation (concept: GemmShape) 53 | typename WarpShape_, 54 | /// Shape of one matrix production operation (concept: GemmShape) 55 | typename InstructionShape_, 56 | /// Data type of A elements, 57 | typename ElementA, 58 | /// Layout of A matrix (concept: MatrixLayout) 59 | typename LayoutA, 60 | /// Data type of B elements 61 | typename ElementB, 62 | /// Layout of B matrix (concept: MatrixLayout) 63 | typename LayoutB, 64 | /// Element type of C matrix 65 | typename ElementC, 66 | /// Layout of C matrix (concept: MatrixLayout) 67 | typename LayoutC, 68 | /// Number of partitions along K dimension 69 | int PartitionsK, 70 | /// Store the accumulators in row major or column major. Row major is used 71 | /// when output layout is interleaved. 72 | bool AccumulatorsInRowMajor> 73 | struct DefaultMmaTensorOp { 84 | 85 | private: 86 | // Shape for computing the FP16s 87 | using ComputeInstructionShape = InstructionShape_; 88 | 89 | // Chosen so we get K=16 for int8 and K=32 for int4. 90 | static constexpr int LoadInstructionK = 8 * sizeof_bits::value / sizeof_bits::value; 91 | 92 | // Shape for loading the narrow data type from shared memory 93 | using LoadInstructionShape = GemmShape; 94 | 95 | public: 96 | using Policy = cutlass::gemm::warp::MmaTensorOpPolicy, 105 | cutlass::MatrixShape<1, 1>>; 106 | 107 | // Define the warp-level tensor op 108 | using Type = cutlass::gemm::warp::MmaTensorOpComputeBWithF16; 119 | }; 120 | 121 | ///////////////////////////////////////////////////////////////////////////////////////////////// 122 | 123 | } // namespace warp 124 | } // namespace gemm 125 | } // namespace cutlass 126 | 127 | ///////////////////////////////////////////////////////////////////////////////////////////////// 128 | -------------------------------------------------------------------------------- /dskernels/ft_gemm/gemm_variants/cutlass_extensions/tile_interleaved_layout.h: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, this 9 | * list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. 
Neither the name of the copyright holder nor the names of its
16 |  * contributors may be used to endorse or promote products derived from
17 |  * this software without specific prior written permission.
18 |  *
19 |  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 |  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 |  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 |  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 |  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 |  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 |  * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 |  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 |  * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 |  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 |  *
30 |  **************************************************************************************************/
31 | /*! \file
32 |     \brief Defines new layouts needed for MoE
33 | */
34 | #pragma once
35 | 
36 | #include "cutlass/cutlass.h"
37 | #include "cutlass/fast_math.h"
38 | #include "cutlass/matrix_coord.h"
39 | #include "cutlass/pitch_linear_coord.h"
40 | 
41 | namespace cutlass {
42 | namespace layout {
43 | 
44 | template<int RowsPerTile, int ColumnsInterleaved>
45 | class ColumnMajorTileInterleave {
46 |     static constexpr int kRowsPerTile = RowsPerTile;
47 |     static constexpr int kColumnsInterleaved = ColumnsInterleaved;
48 | };
49 | 
50 | template<class T>
51 | struct IsColumnMajorTileInterleave {
52 |     static constexpr bool value = false;
53 | };
54 | 
55 | template<int RowsPerTile, int ColumnsInterleaved>
56 | struct IsColumnMajorTileInterleave<ColumnMajorTileInterleave<RowsPerTile, ColumnsInterleaved>> {
57 |     static constexpr bool value = true;
58 | };
59 | 
60 | } // namespace layout
61 | } // namespace cutlass
62 | 
--------------------------------------------------------------------------------
/dskernels/ft_gemm/gemm_variants/fpA_intB_gemm/fpA_intB_gemm.h:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
3 |  *
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  * http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 |  */
16 | 
17 | #pragma once
18 | 
19 | #include "cutlass_extensions/ft_gemm_configs.h"
20 | #include "utils/activation_type.h"
21 | #include <cuda_runtime_api.h>
22 | #include "utils/weight_variant.h"
23 | 
24 | namespace fastertransformer {
25 | 
26 | /*
27 | This runner only supports:
28 | T in {half, __nv_bfloat16} WeightType in {int8_t, cutlass::uint4b_t}
29 | 
30 | Activations, biases, scales and outputs are all assumed to be row-major.
31 | 
32 | However, it is assumed that B is in a special format governed by cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.
33 | In this case, B must be preprocessed using the cutlass weight only quant preprocessors.
The weight preprocessor 34 | will instantiate the layout and preprocess based on the instantiation, so layout changes should only require 35 | modifications to mix_gemm_B_layout.h. 36 | */ 37 | 38 | template 39 | class CutlassFpAIntBGemmRunner { 40 | public: 41 | using WeightType = typename WeightStorageType::type; 42 | 43 | CutlassFpAIntBGemmRunner(); 44 | ~CutlassFpAIntBGemmRunner(); 45 | 46 | void gemm(const T* A, 47 | const char* B, 48 | const T* weight_scales, 49 | T* C, 50 | int m, 51 | int n, 52 | int k, 53 | char* workspace_ptr, 54 | const size_t workspace_bytes, 55 | cudaStream_t stream); 56 | 57 | void gemm_bias_act(const T* A, 58 | const char* B, 59 | const T* weight_scales, 60 | const T* biases, 61 | T* C, 62 | int m, 63 | int n, 64 | int k, 65 | ActivationType activation_type, 66 | char* workspace_ptr, 67 | const size_t workspace_bytes, 68 | cudaStream_t stream); 69 | 70 | // Returns desired workspace size in bytes. 71 | int getWorkspaceSize(const int m, const int n, const int k); 72 | 73 | private: 74 | template 75 | void dispatch_to_arch(const T* A, 76 | const WeightType* B, 77 | const T* weight_scales, 78 | const T* biases, 79 | T* C, 80 | int m, 81 | int n, 82 | int k, 83 | CutlassGemmConfig gemm_config, 84 | char* workspace_ptr, 85 | const size_t workspace_bytes, 86 | cudaStream_t stream, 87 | int* occupancy = nullptr); 88 | 89 | template 90 | void run_gemm(const T* A, 91 | const WeightType* B, 92 | const T* weight_scales, 93 | const T* biases, 94 | T* C, 95 | int m, 96 | int n, 97 | int k, 98 | char* workspace_ptr, 99 | const size_t workspace_bytes, 100 | cudaStream_t stream); 101 | 102 | private: 103 | static constexpr int split_k_limit = 7; 104 | 105 | int sm_; 106 | int multi_processor_count_; 107 | }; 108 | 109 | } // namespace fastertransformer 110 | -------------------------------------------------------------------------------- /dskernels/ft_gemm/gemm_variants/fpA_intB_gemm/fpA_intB_gemm_bf16_uint4.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #include "fpA_intB_gemm/fpA_intB_gemm_template.h" 18 | 19 | namespace fastertransformer { 20 | #ifdef ENABLE_BF16 21 | template class CutlassFpAIntBGemmRunner<__nv_bfloat16, WeightVariant::kFP4>; 22 | #endif 23 | } // namespace fastertransformer 24 | -------------------------------------------------------------------------------- /dskernels/ft_gemm/gemm_variants/fpA_intB_gemm/fpA_intB_gemm_bf16_uint8.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #include "fpA_intB_gemm/fpA_intB_gemm_template.h" 18 | 19 | namespace fastertransformer { 20 | #ifdef ENABLE_BF16 21 | template class CutlassFpAIntBGemmRunner<__nv_bfloat16, WeightVariant::kFP8>; 22 | #endif 23 | } // namespace fastertransformer -------------------------------------------------------------------------------- /dskernels/ft_gemm/gemm_variants/fpA_intB_gemm/fpA_intB_gemm_fp16_int4.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #include "fpA_intB_gemm/fpA_intB_gemm_template.h" 18 | 19 | namespace fastertransformer { 20 | template class CutlassFpAIntBGemmRunner; 21 | } // namespace fastertransformer -------------------------------------------------------------------------------- /dskernels/ft_gemm/gemm_variants/fpA_intB_gemm/fpA_intB_gemm_fp16_int8.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #include "fpA_intB_gemm/fpA_intB_gemm_template.h" 18 | 19 | namespace fastertransformer { 20 | template class CutlassFpAIntBGemmRunner; 21 | } // namespace fastertransformer -------------------------------------------------------------------------------- /dskernels/ft_gemm/gemm_variants/moe_gemm/moe_gemm_kernels.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #pragma once 18 | #include "cutlass_extensions/ft_gemm_configs.h" 19 | #include "utils/activation_type.h" 20 | #include 21 | #include "utils/weight_variant.h" 22 | 23 | namespace fastertransformer { 24 | 25 | template 27 | class MoeGemmRunner { 28 | public: 29 | using WeightType = typename WeightStorageType::type; 30 | 31 | MoeGemmRunner(); 32 | 33 | void moe_gemm_bias_act(const T* A, 34 | const char* B, 35 | const T* weight_scales, 36 | const T* biases, 37 | T* C, 38 | int64_t* total_rows_before_expert, 39 | int64_t total_rows, 40 | int64_t gemm_n, 41 | int64_t gemm_k, 42 | int num_experts, 43 | ActivationType activation_type, 44 | cudaStream_t stream); 45 | 46 | void moe_gemm(const T* A, 47 | const char* B, 48 | const T* weight_scales, 49 | T* C, 50 | int64_t* total_rows_before_expert, 51 | int64_t total_rows, 52 | int64_t gemm_n, 53 | int64_t gemm_k, 54 | int num_experts, 55 | cudaStream_t stream); 56 | 57 | private: 58 | template 59 | void dispatch_to_arch(const T* A, 60 | const WeightType* B, 61 | const T* weight_scales, 62 | const T* biases, 63 | T* C, 64 | int64_t* total_rows_before_expert, 65 | int64_t total_rows, 66 | int64_t gemm_n, 67 | int64_t gemm_k, 68 | int num_experts, 69 | CutlassGemmConfig gemm_config, 70 | cudaStream_t stream, 71 | int* occupancy = nullptr); 72 | 73 | template 74 | void run_gemm(const T* A, 75 | const WeightType* B, 76 | const T* weight_scales, 77 | const T* biases, 78 | T* C, 79 | int64_t* total_rows_before_expert, 80 | int64_t total_rows, 81 | int64_t gemm_n, 82 | int64_t gemm_k, 83 | int num_experts, 84 | cudaStream_t stream); 85 | 86 | private: 87 | int sm_; 88 | int multi_processor_count_; 89 | }; 90 | 91 | } // namespace fastertransformer -------------------------------------------------------------------------------- /dskernels/ft_gemm/gemm_variants/moe_gemm/moe_gemm_kernels_bf16_bf16.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | #include "moe_gemm/moe_gemm_kernels_template.h" 18 | 19 | namespace fastertransformer { 20 | #ifdef ENABLE_BF16 21 | template class MoeGemmRunner<__nv_bfloat16, WeightVariant::kBF16>; 22 | #endif 23 | } // namespace fastertransformer -------------------------------------------------------------------------------- /dskernels/ft_gemm/gemm_variants/moe_gemm/moe_gemm_kernels_bf16_uint4.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #include "moe_gemm/moe_gemm_kernels_template.h" 18 | 19 | namespace fastertransformer { 20 | #ifdef ENABLE_BF16 21 | template class MoeGemmRunner<__nv_bfloat16, WeightVariant::kFP4>; 22 | #endif 23 | } // namespace fastertransformer -------------------------------------------------------------------------------- /dskernels/ft_gemm/gemm_variants/moe_gemm/moe_gemm_kernels_bf16_uint8.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #include "moe_gemm/moe_gemm_kernels_template.h" 18 | 19 | namespace fastertransformer { 20 | #ifdef ENABLE_BF16 21 | template class MoeGemmRunner<__nv_bfloat16, WeightVariant::kFP8>; 22 | #endif 23 | } // namespace fastertransformer -------------------------------------------------------------------------------- /dskernels/ft_gemm/gemm_variants/moe_gemm/moe_gemm_kernels_fp16_fp16.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 |  */
16 | 
17 | #include "moe_gemm/moe_gemm_kernels_template.h"
18 | 
19 | namespace fastertransformer {
20 | template class MoeGemmRunner<half, WeightVariant::kFP16>;
21 | }
-------------------------------------------------------------------------------- /dskernels/ft_gemm/gemm_variants/moe_gemm/moe_gemm_kernels_fp16_uint4.cu: --------------------------------------------------------------------------------
1 | /*
2 |  * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
3 |  *
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  *     http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 |  */
16 | 
17 | #include "moe_gemm/moe_gemm_kernels_template.h"
18 | 
19 | namespace fastertransformer {
20 | template class MoeGemmRunner<half, WeightVariant::kFP4>;
21 | }
-------------------------------------------------------------------------------- /dskernels/ft_gemm/gemm_variants/moe_gemm/moe_gemm_kernels_fp16_uint8.cu: --------------------------------------------------------------------------------
1 | /*
2 |  * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
3 |  *
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  *     http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 |  */
16 | 
17 | #include "moe_gemm/moe_gemm_kernels_template.h"
18 | 
19 | namespace fastertransformer {
20 | template class MoeGemmRunner<half, WeightVariant::kFP8>;
21 | }
-------------------------------------------------------------------------------- /dskernels/ft_gemm/gemm_variants/utils/activation_type.h: --------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation.
2 | // SPDX-License-Identifier: Apache-2.0
3 | 
4 | // DeepSpeed Team
5 | 
6 | #pragma once
7 | 
8 | /*
9 | NOTE(cmikeh2): This needs to match the equivalent file in deepspeed/csrc/includes
10 | exactly to ensure coherence.
11 | */
12 | 
13 | enum ActivationType {
14 |     GELU = 0,
15 |     RELU = 1,
16 |     SILU = 2,
17 |     GEGLU = 3,
18 |     ReGLU = 4,
19 |     SiGLU = 5,
20 |     IDENTITY = 6,
21 |     InvalidType = -1
22 | };
23 | 
24 | inline bool isGatedActivation(ActivationType activation_type)
25 | {
26 |     return activation_type == ActivationType::GEGLU || activation_type == ActivationType::ReGLU
27 |            || activation_type == ActivationType::SiGLU;
28 | }
-------------------------------------------------------------------------------- /dskernels/ft_gemm/gemm_variants/utils/cutlass_heuristic.cc: --------------------------------------------------------------------------------
1 | /*
2 |  * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
3 |  *
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  *     http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 |  */
16 | 
17 | #include "utils/cutlass_heuristic.h"
18 | #include "cuda_bf16.h"
19 | 
20 | #pragma GCC diagnostic push
21 | #pragma GCC diagnostic ignored "-Wstrict-aliasing"
22 | #include "cutlass/gemm/gemm.h"
23 | #include "cutlass/numeric_types.h"
24 | #pragma GCC diagnostic pop
25 | 
26 | #include <stdexcept>
27 | #include <vector>
28 | 
29 | namespace fastertransformer {
30 | 
31 | struct TileShape {
32 |     int m;
33 |     int n;
34 | };
35 | 
36 | TileShape get_cta_shape_for_config(CutlassTileConfig tile_config)
37 | {
38 |     switch (tile_config) {
39 |         case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
40 |             return TileShape{32, 128};
41 |         case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64:
42 |         case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64:
43 |             return TileShape{64, 128};
44 |         case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8:
45 |         case CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64:
46 |         case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64:
47 |             return TileShape{128, 128};
48 |         default:
49 |             throw std::runtime_error("[FT Error][get_grid_shape_for_config] Invalid config");
50 |     }
51 | }
52 | 
53 | bool is_valid_split_k_factor(const int64_t m,
54 |                              const int64_t n,
55 |                              const int64_t k,
56 |                              const TileShape tile_shape,
57 |                              const int split_k_factor,
58 |                              const size_t workspace_bytes,
59 |                              const bool is_weight_only)
60 | {
61 | 
62 |     // All tile sizes have a k_tile of 64.
63 |     static constexpr int k_tile = 64;
64 | 
65 |     // For weight-only quant, we need k and k_elements_per_split to be a multiple of cta_k
66 |     if (is_weight_only) {
67 |         if ((k % k_tile) != 0) {
68 |             return false;
69 |         }
70 | 
71 |         if ((k % split_k_factor) != 0) {
72 |             return false;
73 |         }
74 | 
75 |         const int k_elements_per_split = k / split_k_factor;
76 |         if ((k_elements_per_split % k_tile) != 0) {
77 |             return false;
78 |         }
79 |     }
80 | 
81 |     // Check that the workspace has sufficient space for this split-k factor
82 |     const int ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m;
83 |     const int ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n;
84 |     const int required_ws_bytes = split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim;
85 | 
86 |     if (required_ws_bytes > workspace_bytes) {
87 |         return false;
88 |     }
89 | 
90 |     return true;
91 | }
92 | 
93 | std::vector<CutlassTileConfig> get_candidate_tiles(const bool is_weight_only, const bool simt_configs_only)
94 | {
95 | 
96 |     std::vector<CutlassTileConfig> simt_configs{CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8};
97 | 
98 |     std::vector<CutlassTileConfig> square_configs{CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
99 |                                                   CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64,
100 |                                                  CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64};
101 | 
102 |     std::vector<CutlassTileConfig> quant_B_configs{CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
103 |                                                    CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
104 |                                                    CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64};
105 | 
106 |     const std::vector<CutlassTileConfig> allowed_configs = is_weight_only ? quant_B_configs : square_configs;
107 |     return simt_configs_only ? simt_configs : allowed_configs;
108 | }
109 | 
110 | std::vector<CutlassGemmConfig> get_candidate_configs(int sm, const bool is_weight_only, const bool simt_configs_only)
111 | {
112 |     std::vector<CutlassTileConfig> tiles = get_candidate_tiles(is_weight_only, simt_configs_only);
113 | 
114 |     std::vector<CutlassGemmConfig> candidate_configs;
115 |     const int min_stages = 2;
116 |     const int max_stages = sm >= 80 ? 4 : 2;
117 | 
118 |     for (const auto& tile_config : tiles) {
119 |         for (int stages = min_stages; stages <= max_stages; ++stages) {
120 |             CutlassGemmConfig config{tile_config, SplitKStyle::NO_SPLIT_K, 1, stages};
121 |             candidate_configs.push_back(config);
122 |         }
123 |     }
124 | 
125 |     return candidate_configs;
126 | }
127 | 
128 | CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<CutlassGemmConfig>& candidate_configs,
129 |                                                         const std::vector<int>& occupancies,
130 |                                                         const int64_t m,
131 |                                                         const int64_t n,
132 |                                                         const int64_t k,
133 |                                                         const int64_t num_experts,
134 |                                                         const int split_k_limit,
135 |                                                         const size_t workspace_bytes,
136 |                                                         const int multi_processor_count,
137 |                                                         const int is_weight_only)
138 | {
139 | 
140 |     if (occupancies.size() != candidate_configs.size()) {
141 |         throw std::runtime_error("[FT Error][estimate_best_config_from_occupancies] occupancies and "
142 |                                  "candidate configs vectors must have equal length.");
143 |     }
144 | 
145 |     CutlassGemmConfig best_config;
146 |     // Score will be [0, 1]. The objective is to minimize this score.
147 |     // It represents the fraction of SM resources unused in the last wave.
148 |     float config_score = 1.0f;
149 |     int config_waves = INT_MAX;
150 |     int current_m_tile = 0;
151 | 
152 |     const int max_split_k = n >= multi_processor_count * 256 ? 1 : split_k_limit;
153 |     for (int ii = 0; ii < candidate_configs.size(); ++ii) {
154 |         CutlassGemmConfig candidate_config = candidate_configs[ii];
155 |         TileShape tile_shape = get_cta_shape_for_config(candidate_config.tile_config);
156 |         int occupancy = occupancies[ii];
157 | 
158 |         if (occupancy == 0) {
159 |             continue;
160 |         }
161 | 
162 |         // Keep small tile sizes when possible.
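        // Worked example of the wave-quantization score computed below (illustrative
        // numbers only, not measured on any particular GPU): with m = n = 4096 and a
        // 128x128 tile, the problem needs 32 * 32 = 1024 CTAs. At occupancy 2 on
        // 108 SMs, ctas_per_wave = 216, so num_waves_total = ceil(1024 / 216) = 5
        // while the fractional wave count is ~4.74. The score 5 - 4.74 = 0.26 is the
        // share of the last wave left idle; configs that fill their final wave score lower.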
163 |         if (best_config.tile_config != CutlassTileConfig::ChooseWithHeuristic && m < current_m_tile
164 |             && current_m_tile < tile_shape.m) {
165 |             continue;
166 |         }
167 | 
168 |         const int ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m;
169 |         const int ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n;
170 | 
171 |         for (int split_k_factor = 1; split_k_factor <= max_split_k; ++split_k_factor) {
172 |             if (is_valid_split_k_factor(m, n, k, tile_shape, split_k_factor, workspace_bytes, is_weight_only)) {
173 |                 const int ctas_per_wave = occupancy * multi_processor_count;
174 |                 const int ctas_for_problem = ctas_in_m_dim * ctas_in_n_dim * split_k_factor;
175 | 
176 |                 const int num_waves_total = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave;
177 |                 const float num_waves_fractional = ctas_for_problem / float(ctas_per_wave);
178 |                 const float current_score = float(num_waves_total) - num_waves_fractional;
179 | 
180 |                 const float score_slack = 0.1f;
181 |                 if (current_score < config_score
182 |                     || ((config_waves > num_waves_total) && (current_score < config_score + score_slack))) {
183 |                     config_score = current_score;
184 |                     config_waves = num_waves_total;
185 |                     SplitKStyle split_style =
186 |                         split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K;
187 |                     best_config = CutlassGemmConfig{
188 |                         candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages};
189 |                     current_m_tile = tile_shape.m;
190 |                 }
191 |                 else if (current_score == config_score
192 |                          && (best_config.stages < candidate_config.stages || split_k_factor < best_config.split_k_factor
193 |                              || current_m_tile < tile_shape.m)) {
194 |                     // Prefer deeper pipeline or smaller split-k
195 |                     SplitKStyle split_style =
196 |                         split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K;
197 |                     best_config = CutlassGemmConfig{
198 |                         candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages};
199 |                     current_m_tile = tile_shape.m;
200 |                     config_waves = num_waves_total;
201 |                 }
202 |             }
203 |         }
204 |     }
205 | 
206 |     if (best_config.tile_config == CutlassTileConfig::ChooseWithHeuristic) {
207 |         throw std::runtime_error("[FT Error] Heuristic failed to find a valid config.");
208 |     }
209 | 
210 |     return best_config;
211 | }
212 | 
213 | }  // namespace fastertransformer
-------------------------------------------------------------------------------- /dskernels/ft_gemm/gemm_variants/utils/cutlass_heuristic.h: --------------------------------------------------------------------------------
1 | /*
2 |  * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
3 |  *
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  *     http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 |  */
16 | 
17 | #pragma once
18 | 
19 | #include "cutlass_extensions/ft_gemm_configs.h"
20 | #include "utils/cuda_utils.h"
21 | 
22 | namespace fastertransformer {
23 | 
24 | std::vector<CutlassGemmConfig> get_candidate_configs(int sm, const bool is_weight_only, const bool simt_configs_only);
25 | 
26 | CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<CutlassGemmConfig>& candidate_configs,
27 |                                                         const std::vector<int>& occupancies,
28 |                                                         const int64_t m,
29 |                                                         const int64_t n,
30 |                                                         const int64_t k,
31 |                                                         const int64_t num_experts,
32 |                                                         const int split_k_limit,
33 |                                                         const size_t workspace_bytes,
34 |                                                         const int multi_processor_count,
35 |                                                         const int is_weight_only);
36 | 
37 | }  // namespace fastertransformer
-------------------------------------------------------------------------------- /dskernels/ft_gemm/gemm_variants/utils/cutlass_preprocessors.h: --------------------------------------------------------------------------------
1 | /*
2 |  * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
3 |  *
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  *     http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 |  */
16 | 
17 | #pragma once
18 | 
19 | #include <cstddef>
20 | #include <cstdint>
21 | #include <vector>
22 | 
23 | #include "utils/cuda_utils.h"
24 | 
25 | namespace fastertransformer {
26 | enum class QuantType {
27 |     INT8_WEIGHT_ONLY,
28 |     PACKED_INT4_WEIGHT_ONLY
29 | };
30 | int get_bits_in_quant_type(QuantType quant_type);
31 | 
32 | // Shapes here can be 2 or 3D. 2-D shapes are [num_rows, num_cols]
33 | // 3-D shapes are [num_experts, num_rows, num_cols]
34 | void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor,
35 |                                    const int8_t* quantized_tensor,
36 |                                    const std::vector<size_t>& shape,
37 |                                    QuantType quant_type,
38 |                                    const int64_t arch_version);
39 | 
40 | void subbyte_transpose(int8_t* transposed_quantized_tensor,
41 |                        const int8_t* quantized_tensor,
42 |                        const std::vector<size_t>& shape,
43 |                        QuantType quant_type);
44 | 
45 | void add_bias_and_interleave_quantized_tensor_inplace(int8_t* tensor, const size_t num_elts, QuantType quant_type);
46 | 
47 | void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight,
48 |                                        const int8_t* row_major_quantized_weight,
49 |                                        const std::vector<size_t>& shape,
50 |                                        QuantType quant_type);
51 | 
52 | template <typename ComputeType, typename WeightType>
53 | void symmetric_quantize(int8_t* processed_quantized_weight,
54 |                         ComputeType* scale_ptr,
55 |                         const WeightType* input_weight_ptr,
56 |                         const std::vector<size_t>& shape,
57 |                         QuantType quant_type);
58 | 
59 | // This is exposed so that we can write tests that use the processed weights for CUTLASS but the unprocessed weight
60 | // to implement a simple reference implementation.
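// A minimal usage sketch of the overload above (hypothetical buffer names and
// shapes, not part of this header): quantizing a row-major [1024, 4096] fp16
// weight matrix to int8, with a scale buffer sized to the number of columns,
// would look roughly like
//
//     std::vector<size_t> shape{1024, 4096};
//     std::vector<int8_t> processed(1024 * 4096);
//     std::vector<half> scales(4096);
//     symmetric_quantize<half, half>(processed.data(), scales.data(),
//                                    fp16_weights, shape, QuantType::INT8_WEIGHT_ONLY);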
61 | template <typename ComputeType, typename WeightType>
62 | void symmetric_quantize(int8_t* processed_quantized_weight,
63 |                         int8_t* unprocessed_quantized_weight,
64 |                         ComputeType* scale_ptr,
65 |                         const WeightType* input_weight_ptr,
66 |                         const std::vector<size_t>& shape,
67 |                         QuantType quant_type);
68 | 
69 | }  // namespace fastertransformer
-------------------------------------------------------------------------------- /dskernels/ft_gemm/gemm_variants/utils/string_utils.h: --------------------------------------------------------------------------------
1 | /*
2 |  * Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
3 |  *
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  *     http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 |  */
16 | 
17 | #pragma once
18 | 
19 | #include <memory>   // std::make_unique
20 | #include <sstream>  // std::stringstream
21 | #include <string>
22 | #include <vector>
23 | 
24 | namespace fastertransformer {
25 | 
26 | template <typename... Args>
27 | inline std::string fmtstr(const std::string& format, Args... args)
28 | {
29 |     // This function came from a code snippet in stackoverflow under cc-by-1.0
30 |     // https://stackoverflow.com/questions/2342162/stdstring-formatting-like-sprintf
31 | 
32 |     // Disable format-security warning in this function.
33 | #if defined(_MSC_VER)  // for visual studio
34 | #pragma warning(push)
35 | #pragma warning(disable : 4996)
36 | #elif defined(__GNUC__) || defined(__clang__)  // for gcc or clang
37 | #pragma GCC diagnostic push
38 | #pragma GCC diagnostic ignored "-Wformat-security"
39 | #endif
40 |     int size_s = std::snprintf(nullptr, 0, format.c_str(), args...) + 1;  // Extra space for '\0'
41 |     if (size_s <= 0) {
42 |         throw std::runtime_error("Error during formatting.");
43 |     }
44 |     auto size = static_cast<size_t>(size_s);
45 |     auto buf = std::make_unique<char[]>(size);
46 |     std::snprintf(buf.get(), size, format.c_str(), args...);
47 | #if defined(_MSC_VER)
48 | #pragma warning(pop)
49 | #elif defined(__GNUC__) || defined(__clang__)
50 | #pragma GCC diagnostic pop
51 | #endif
52 |     return std::string(buf.get(), buf.get() + size - 1);  // We don't want the '\0' inside
53 | }
54 | 
55 | template <typename T>
56 | inline std::string vec2str(std::vector<T> vec)
57 | {
58 |     std::stringstream ss;
59 |     ss << "(";
60 |     if (!vec.empty()) {
61 |         for (size_t i = 0; i < vec.size() - 1; ++i) {
62 |             ss << vec[i] << ", ";
63 |         }
64 |         ss << vec.back();
65 |     }
66 |     ss << ")";
67 |     return ss.str();
68 | }
69 | 
70 | template <typename T>
71 | inline std::string arr2str(T* arr, size_t size)
72 | {
73 |     std::stringstream ss;
74 |     ss << "(";
75 |     for (size_t i = 0; i + 1 < size; ++i) {  // i + 1 < size avoids unsigned wrap-around when size == 0
76 |         ss << arr[i] << ", ";
77 |     }
78 |     if (size > 0) {
79 |         ss << arr[size - 1];
80 |     }
81 |     ss << ")";
82 |     return ss.str();
83 | }
84 | }  // namespace fastertransformer
85 | 
-------------------------------------------------------------------------------- /dskernels/ft_gemm/gemm_variants/utils/weight_variant.h: --------------------------------------------------------------------------------
1 | 
2 | 
3 | enum WeightVariant { kFP16, kBF16, kFP8, kFP4 };
4 | 
5 | template <WeightVariant V>
6 | class WeightStorageType;
7 | 
8 | template <>
9 | class WeightStorageType<WeightVariant::kFP16> {
10 | public:
11 |     using type = half;
12 | };
13 | 
14 | template <>
15 | class WeightStorageType<WeightVariant::kBF16> {
16 | public:
17 |     using type = nv_bfloat16;
18 | };
19 | 
20 | template <>
21 | class WeightStorageType<WeightVariant::kFP8> {
22 | public:
23 |     using type = uint8_t;
24 | };
25 | 
26 | template <>
27 | class WeightStorageType<WeightVariant::kFP4> {
28 | public:
29 |     using type = cutlass::uint4b_t;
30 | };
-------------------------------------------------------------------------------- /dskernels/inf_flash_attn/blocked_flash/CMakeLists.txt: --------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.11)
2 | project(DeepSpeedBlockedFlash CXX CUDA)
3 | 
4 | set(CMAKE_CXX_STANDARD 17)
5 | set(CMAKE_CUDA_STANDARD 17)
6 | 
7 | if (NOT CMAKE_BUILD_TYPE)
8 |     set(CMAKE_BUILD_TYPE Release)
9 | endif()
10 | 
11 | list(APPEND CMAKE_MODULE_PATH ${CMAKE_BINARY_DIR})
12 | list(APPEND CMAKE_PREFIX_PATH ${CMAKE_BINARY_DIR})
13 | 
14 | if (NOT EXISTS "${CMAKE_BINARY_DIR}/conan.cmake")
15 |     message(STATUS "Downloading conan.cmake from https://github.com/conan-io/cmake-conan")
16 |     file(DOWNLOAD "https://raw.githubusercontent.com/conan-io/cmake-conan/v0.16.1/conan.cmake"
17 |          "${CMAKE_BINARY_DIR}/conan.cmake"
18 |          EXPECTED_HASH SHA256=396e16d0f5eabdc6a14afddbcfff62a54a7ee75c6da23f32f7a31bc85db23484
19 |          TLS_VERIFY ON)
20 | endif()
21 | 
22 | find_package(CUDA)
23 | 
24 | if (NOT WIN32)
25 |     list(APPEND CMAKE_CXX_FLAGS "-fmax-errors=1 -Wfatal-errors")
26 |     set(LIB_NAME "blockedflash")
27 | else()
28 |     set(LIB_NAME "libblockedflash")
29 | endif()
30 | 
31 | add_library(${LIB_NAME} SHARED)
32 | 
33 | target_include_directories(${LIB_NAME} PRIVATE
34 |     ${CMAKE_CURRENT_SOURCE_DIR}
35 |     ${CMAKE_CURRENT_SOURCE_DIR}/../third_party/cutlass/include
36 | )
37 | 
38 | list(APPEND CMAKE_CXX_FLAGS ${TORCH_CXX_FLAGS})
39 | 
40 | list(APPEND NVCC_FLAGS "-O3")
41 | list(APPEND NVCC_FLAGS "-U__CUDA_NO_HALF_OPERATORS__")
42 | list(APPEND NVCC_FLAGS "-U__CUDA_NO_HALF_CONVERSIONS__")
43 | list(APPEND NVCC_FLAGS "-U__CUDA_NO_HALF2_OPERATORS__")
"-U__CUDA_NO_HALF2_OPERATORS__") 44 | list(APPEND NVCC_FLAGS "-U__CUDA_NO_BFLOAT16_CONVERSIONS__") 45 | list(APPEND NVCC_FLAGS "-expt-relaxed-constexpr") 46 | list(APPEND NVCC_FLAGS "--use_fast_math") 47 | 48 | set(VERBOSE_BUILD 0) 49 | 50 | if (VERBOSE_BUILD) 51 | list(APPEND NVCC_FLAGS "--ptxas-options=-v") 52 | endif() 53 | 54 | file(GLOB SRC_FILES *.cu) 55 | 56 | target_sources(${LIB_NAME} PRIVATE ${SRC_FILES}) 57 | target_link_libraries(${LIB_NAME} "${TORCH_LIBRARIES}") 58 | target_compile_options(${LIB_NAME} PRIVATE $<$:${NVCC_FLAGS}>) 59 | set_target_properties(${LIB_NAME} PROPERTIES CUDA_ARCHITECTURES "${CUDA_ARCH_LIST}") 60 | set_target_properties(${LIB_NAME} PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${LIB_OUTPUT_DIR}) 61 | 62 | 63 | # Show timings? 64 | set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "${CMAKE_COMMAND} -E time") 65 | -------------------------------------------------------------------------------- /dskernels/inf_flash_attn/blocked_flash/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file. 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /dskernels/inf_flash_attn/blocked_flash/Makefile: -------------------------------------------------------------------------------- 1 | # CMAKE generated file: DO NOT EDIT! 2 | # Generated by "Unix Makefiles" Generator, CMake Version 3.25 3 | 4 | # Default target executed when no arguments are given to make. 5 | default_target: all 6 | .PHONY : default_target 7 | 8 | # Allow only one "make -f Makefile2" at a time, but pass parallelism. 9 | .NOTPARALLEL: 10 | 11 | #============================================================================= 12 | # Special targets provided by cmake. 13 | 14 | # Disable implicit rules so canonical targets will work. 15 | .SUFFIXES: 16 | 17 | # Disable VCS-based implicit rules. 
18 | % : %,v 19 | 20 | # Disable VCS-based implicit rules. 21 | % : RCS/% 22 | 23 | # Disable VCS-based implicit rules. 24 | % : RCS/%,v 25 | 26 | # Disable VCS-based implicit rules. 27 | % : SCCS/s.% 28 | 29 | # Disable VCS-based implicit rules. 30 | % : s.% 31 | 32 | .SUFFIXES: .hpux_make_needs_suffix_list 33 | 34 | # Command-line flag to silence nested $(MAKE). 35 | $(VERBOSE)MAKESILENT = -s 36 | 37 | #Suppress display of executed commands. 38 | $(VERBOSE).SILENT: 39 | 40 | # A target that is always out of date. 41 | cmake_force: 42 | .PHONY : cmake_force 43 | 44 | #============================================================================= 45 | # Set environment variables for the build. 46 | 47 | # The shell in which to execute make rules. 48 | SHELL = /bin/sh 49 | 50 | # The CMake executable. 51 | CMAKE_COMMAND = /home/deepspeed/.local/lib/python3.8/site-packages/cmake/data/bin/cmake 52 | 53 | # The command to remove a file. 54 | RM = /home/deepspeed/.local/lib/python3.8/site-packages/cmake/data/bin/cmake -E rm -f 55 | 56 | # Escaping for special characters. 57 | EQUALS = = 58 | 59 | # The top-level source directory on which CMake was run. 60 | CMAKE_SOURCE_DIR = /data/private_dev/DeepSpeed-Kernels/inf_flash_attn/blocked_flash 61 | 62 | # The top-level build directory on which CMake was run. 63 | CMAKE_BINARY_DIR = /data/private_dev/DeepSpeed-Kernels/inf_flash_attn/blocked_flash 64 | 65 | #============================================================================= 66 | # Targets provided globally by CMake. 67 | 68 | # Special rule for the target edit_cache 69 | edit_cache: 70 | @$(CMAKE_COMMAND) -E cmake_echo_color --switch=$(COLOR) --cyan "No interactive CMake dialog available..." 71 | /home/deepspeed/.local/lib/python3.8/site-packages/cmake/data/bin/cmake -E echo No\ interactive\ CMake\ dialog\ available. 72 | .PHONY : edit_cache 73 | 74 | # Special rule for the target edit_cache 75 | edit_cache/fast: edit_cache 76 | .PHONY : edit_cache/fast 77 | 78 | # Special rule for the target rebuild_cache 79 | rebuild_cache: 80 | @$(CMAKE_COMMAND) -E cmake_echo_color --switch=$(COLOR) --cyan "Running CMake to regenerate build system..." 81 | /home/deepspeed/.local/lib/python3.8/site-packages/cmake/data/bin/cmake --regenerate-during-build -S$(CMAKE_SOURCE_DIR) -B$(CMAKE_BINARY_DIR) 82 | .PHONY : rebuild_cache 83 | 84 | # Special rule for the target rebuild_cache 85 | rebuild_cache/fast: rebuild_cache 86 | .PHONY : rebuild_cache/fast 87 | 88 | # The main all target 89 | all: cmake_check_build_system 90 | $(CMAKE_COMMAND) -E cmake_progress_start /data/private_dev/DeepSpeed-Kernels/inf_flash_attn/blocked_flash/CMakeFiles /data/private_dev/DeepSpeed-Kernels/inf_flash_attn/blocked_flash//CMakeFiles/progress.marks 91 | $(MAKE) $(MAKESILENT) -f CMakeFiles/Makefile2 all 92 | $(CMAKE_COMMAND) -E cmake_progress_start /data/private_dev/DeepSpeed-Kernels/inf_flash_attn/blocked_flash/CMakeFiles 0 93 | .PHONY : all 94 | 95 | # The main clean target 96 | clean: 97 | $(MAKE) $(MAKESILENT) -f CMakeFiles/Makefile2 clean 98 | .PHONY : clean 99 | 100 | # The main clean target 101 | clean/fast: clean 102 | .PHONY : clean/fast 103 | 104 | # Prepare targets for installation. 105 | preinstall: all 106 | $(MAKE) $(MAKESILENT) -f CMakeFiles/Makefile2 preinstall 107 | .PHONY : preinstall 108 | 109 | # Prepare targets for installation. 
110 | preinstall/fast: 111 | $(MAKE) $(MAKESILENT) -f CMakeFiles/Makefile2 preinstall 112 | .PHONY : preinstall/fast 113 | 114 | # clear depends 115 | depend: 116 | $(CMAKE_COMMAND) -S$(CMAKE_SOURCE_DIR) -B$(CMAKE_BINARY_DIR) --check-build-system CMakeFiles/Makefile.cmake 1 117 | .PHONY : depend 118 | 119 | #============================================================================= 120 | # Target rules for targets named gemm 121 | 122 | # Build rule for target. 123 | gemm: cmake_check_build_system 124 | $(MAKE) $(MAKESILENT) -f CMakeFiles/Makefile2 gemm 125 | .PHONY : gemm 126 | 127 | # fast build rule for target. 128 | gemm/fast: 129 | $(MAKE) $(MAKESILENT) -f CMakeFiles/gemm.dir/build.make CMakeFiles/gemm.dir/build 130 | .PHONY : gemm/fast 131 | 132 | flash_fwd_hdim32_bf16_sm80.o: flash_fwd_hdim32_bf16_sm80.cu.o 133 | .PHONY : flash_fwd_hdim32_bf16_sm80.o 134 | 135 | # target to build an object file 136 | flash_fwd_hdim32_bf16_sm80.cu.o: 137 | $(MAKE) $(MAKESILENT) -f CMakeFiles/gemm.dir/build.make CMakeFiles/gemm.dir/flash_fwd_hdim32_bf16_sm80.cu.o 138 | .PHONY : flash_fwd_hdim32_bf16_sm80.cu.o 139 | 140 | flash_fwd_hdim32_bf16_sm80.i: flash_fwd_hdim32_bf16_sm80.cu.i 141 | .PHONY : flash_fwd_hdim32_bf16_sm80.i 142 | 143 | # target to preprocess a source file 144 | flash_fwd_hdim32_bf16_sm80.cu.i: 145 | $(MAKE) $(MAKESILENT) -f CMakeFiles/gemm.dir/build.make CMakeFiles/gemm.dir/flash_fwd_hdim32_bf16_sm80.cu.i 146 | .PHONY : flash_fwd_hdim32_bf16_sm80.cu.i 147 | 148 | flash_fwd_hdim32_bf16_sm80.s: flash_fwd_hdim32_bf16_sm80.cu.s 149 | .PHONY : flash_fwd_hdim32_bf16_sm80.s 150 | 151 | # target to generate assembly for a file 152 | flash_fwd_hdim32_bf16_sm80.cu.s: 153 | $(MAKE) $(MAKESILENT) -f CMakeFiles/gemm.dir/build.make CMakeFiles/gemm.dir/flash_fwd_hdim32_bf16_sm80.cu.s 154 | .PHONY : flash_fwd_hdim32_bf16_sm80.cu.s 155 | 156 | # Help Target 157 | help: 158 | @echo "The following are some of the valid targets for this Makefile:" 159 | @echo "... all (the default if no target is provided)" 160 | @echo "... clean" 161 | @echo "... depend" 162 | @echo "... edit_cache" 163 | @echo "... rebuild_cache" 164 | @echo "... gemm" 165 | @echo "... flash_fwd_hdim32_bf16_sm80.o" 166 | @echo "... flash_fwd_hdim32_bf16_sm80.i" 167 | @echo "... flash_fwd_hdim32_bf16_sm80.s" 168 | .PHONY : help 169 | 170 | 171 | 172 | #============================================================================= 173 | # Special targets to cleanup operation of make. 174 | 175 | # Special rule to run CMake to check the build system integrity. 176 | # No rule that depends on this can have commands that come from listfiles 177 | # because they might be regenerated. 
178 | cmake_check_build_system:
179 | 	$(CMAKE_COMMAND) -S$(CMAKE_SOURCE_DIR) -B$(CMAKE_BINARY_DIR) --check-build-system CMakeFiles/Makefile.cmake 0
180 | .PHONY : cmake_check_build_system
181 | 
182 | 
-------------------------------------------------------------------------------- /dskernels/inf_flash_attn/blocked_flash/attention_atom.h: --------------------------------------------------------------------------------
1 | 
2 | #pragma once
3 | 
4 | #include <cstdint>
5 | #include "cuda.h"
6 | #include "cute/pointer.hpp"
7 | 
8 | struct __align__(32) AttentionAtom {
9 |     // Indices of the KV-cache blocks this atom attends over.
10 |     int32_t* block_idx_list;
11 | 
12 |     int32_t q_start_idx;
13 |     int32_t q_len;
14 |     int32_t kv_blocks;
15 |     int32_t total_extent;
16 |     int32_t global_q_idx;
17 |     int32_t unused;  // Explicit padding out to the 32-byte alignment.
18 | 
19 |     // Cooperatively stage the block index list into shared memory.
20 |     template <int threads>
21 |     __device__ void load_kv_block_idxs(cute::smem_ptr<int32_t> block_idx_list_shr, int tidx) const
22 |     {
23 |         for (int i = tidx; i < kv_blocks; i += threads) { block_idx_list_shr[i] = block_idx_list[i]; }
24 |         // Aggressive (but safe) sync
25 |         __syncthreads();
26 |     }
27 | };
-------------------------------------------------------------------------------- /dskernels/inf_flash_attn/blocked_flash/build_blocked_flash.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | if [ ! -d "build" ]; then
4 | 	mkdir build
5 | fi
6 | 
7 | cd build
8 | cmake ..
9 | make -j
-------------------------------------------------------------------------------- /dskernels/inf_flash_attn/blocked_flash/flash.h: --------------------------------------------------------------------------------
1 | /******************************************************************************
2 |  * Copyright (c) 2023, Tri Dao.
3 |  ******************************************************************************/
4 | 
5 | #pragma once
6 | 
7 | #include <cuda.h>
8 | #include <vector>
9 | 
10 | #include "attention_atom.h"
11 | 
12 | constexpr int TOTAL_DIM = 0;
13 | constexpr int H_DIM = 1;
14 | constexpr int D_DIM = 2;
15 | 
16 | ////////////////////////////////////////////////////////////////////////////////////////////////////
17 | 
18 | struct Qkv_params {
19 |     using index_t = uint32_t;
20 |     // The QKV matrices.
21 |     void* __restrict__ q_ptr;
22 |     void* __restrict__ k_ptr;
23 |     void* __restrict__ v_ptr;
24 | 
25 |     // The stride between rows of the Q, K and V matrices.
26 |     index_t q_row_stride;
27 |     index_t k_row_stride;
28 |     index_t v_row_stride;
29 |     index_t q_head_stride;
30 |     index_t k_head_stride;
31 |     index_t v_head_stride;
32 | 
33 |     // The number of heads.
34 |     int h, h_k;
35 |     // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be
36 |     // different from nheads (query).
37 |     int h_h_k_ratio;  // precompute h / h_k,
38 | };
39 | 
40 | ////////////////////////////////////////////////////////////////////////////////////////////////////
41 | 
42 | struct Flash_fwd_params : public Qkv_params {
43 |     // The O matrix (output).
44 |     void* __restrict__ o_ptr;
45 | 
46 |     // The attention metadata
47 |     AttentionAtom* __restrict__ atoms;
48 | 
49 |     // Total attention atoms
50 |     int num_atoms;
51 | 
52 |     // The stride between rows of O.
53 |     index_t o_row_stride;
54 |     index_t o_head_stride;
55 | 
56 |     // The dimensions
57 |     int d, d_rounded;
58 | 
59 |     // The scaling factors for the kernel.
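    // scale_softmax_log2 is presumably scale_softmax * log2(e): the kernels exponentiate
    // with exp2f (see softmax.h), so folding log2(e) into the scale turns exp(x * scale)
    // into exp2(x * scale * log2(e)) and lets the compiler fuse the multiply and subtract
    // into a single ffma.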
60 |     float scale_softmax;
61 |     float scale_softmax_log2;
62 | 
63 |     bool is_bf16;
64 |     bool is_causal;
65 | };
66 | 
67 | ////////////////////////////////////////////////////////////////////////////////////////////////////
68 | 
69 | template <typename T, int Headdim>
70 | void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream);
-------------------------------------------------------------------------------- /dskernels/inf_flash_attn/blocked_flash/flash_api.cu: --------------------------------------------------------------------------------
1 | 
2 | #include "flash.h"
3 | #include "static_switch.h"
4 | 
5 | void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream)
6 | {
7 |     FP16_SWITCH(!params.is_bf16, [&] {
8 |         FWD_HEADDIM_SWITCH(params.d, [&] { run_mha_fwd_<elem_type, kHeadDim>(params, stream); });
9 |     });
10 | }
-------------------------------------------------------------------------------- /dskernels/inf_flash_attn/blocked_flash/flash_fwd_hdim128_bf16_sm80.cu: --------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Tri Dao.
2 | 
3 | // Splitting the different head dimensions to different files to speed up compilation.
4 | 
5 | #include "flash_fwd_launch_template.h"
6 | 
7 | // template<>
8 | // void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
9 | //     using elem_type = cutlass::bfloat16_t;
10 | //     if (params.p_dropout == 1.f) {
11 | //         run_flash_fwd<Flash_fwd_kernel_traits<128, /* tile shape elided */ elem_type>,
12 | //                       false>(params, stream);
13 | //     } else {
14 | //         run_flash_fwd<Flash_fwd_kernel_traits<128, /* tile shape elided */ elem_type>,
15 | //                       true>(params, stream);
16 | //     }
17 | // }
18 | template <>
19 | void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params& params, cudaStream_t stream)
20 | {
21 |     run_mha_fwd_hdim128<cutlass::bfloat16_t>(params, stream);
22 | }
-------------------------------------------------------------------------------- /dskernels/inf_flash_attn/blocked_flash/flash_fwd_hdim128_fp16_sm80.cu: --------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Tri Dao.
2 | 
3 | // Splitting the different head dimensions to different files to speed up compilation.
4 | 
5 | #include "flash_fwd_launch_template.h"
6 | 
7 | // template<>
8 | // void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
9 | //     using elem_type = cutlass::half_t;
10 | //     (Several alternative Flash_fwd_kernel_traits<128, ...> configurations, with and
11 | //     without dropout, are elided here. Upstream notes: using 8 warps (128 x 128 and
12 | //     256 x 64) is 28% slower for seqlen=2k; the 1st configs are good for H100 and
13 | //     A100, the 2nd one is good for A6000 because it gets slightly better occupancy.)
14 | // }
15 | 
16 | template <>
17 | void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params& params, cudaStream_t stream)
18 | {
19 |     run_mha_fwd_hdim128<cutlass::half_t>(params, stream);
20 | }
-------------------------------------------------------------------------------- /dskernels/inf_flash_attn/blocked_flash/flash_fwd_hdim160_bf16_sm80.cu: --------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Tri Dao.
2 | 
3 | // Splitting the different head dimensions to different files to speed up compilation.
4 | 
5 | #include "flash_fwd_launch_template.h"
6 | 
7 | // template<>
8 | // void run_mha_fwd_<cutlass::bfloat16_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
9 | //     using elem_type = cutlass::bfloat16_t;
10 | //     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
11 | //         run_flash_fwd<Flash_fwd_kernel_traits<160, /* tile shape elided */ elem_type>,
12 | //                       Is_dropout>(params, stream);
13 | //     });
14 | // }
15 | template <>
16 | void run_mha_fwd_<cutlass::bfloat16_t, 160>(Flash_fwd_params& params, cudaStream_t stream)
17 | {
18 |     run_mha_fwd_hdim160<cutlass::bfloat16_t>(params, stream);
19 | }
-------------------------------------------------------------------------------- /dskernels/inf_flash_attn/blocked_flash/flash_fwd_hdim160_fp16_sm80.cu: --------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Tri Dao.
2 | 
3 | // Splitting the different head dimensions to different files to speed up compilation.
4 | 
5 | #include "flash_fwd_launch_template.h"
6 | 
7 | // template<>
8 | // void run_mha_fwd_<cutlass::half_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
9 | //     using elem_type = cutlass::half_t;
10 | //     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
11 | //         run_flash_fwd<Flash_fwd_kernel_traits<160, /* tile shape elided */ elem_type>,
12 | //                       Is_dropout>(params, stream);
13 | //         (further run_flash_fwd variants with other tile shapes elided)
14 | //         For A6000, no-causal, 1st is fastest. causal, 4th is fastest.
15 | //         For A100, H100, 1st is fastest.
16 | //     });
17 | // }
18 | template <>
19 | void run_mha_fwd_<cutlass::half_t, 160>(Flash_fwd_params& params, cudaStream_t stream)
20 | {
21 |     run_mha_fwd_hdim160<cutlass::half_t>(params, stream);
22 | }
-------------------------------------------------------------------------------- /dskernels/inf_flash_attn/blocked_flash/flash_fwd_hdim192_bf16_sm80.cu: --------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Tri Dao.
2 | 
3 | // Splitting the different head dimensions to different files to speed up compilation.
4 | 
5 | #include "flash_fwd_launch_template.h"
6 | 
7 | // template<>
8 | // void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
9 | //     using elem_type = cutlass::bfloat16_t;
10 | //     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
11 | //         run_flash_fwd<Flash_fwd_kernel_traits<192, /* tile shape elided */ elem_type>,
12 | //                       Is_dropout>(params, stream);
13 | //     });
14 | // }
15 | template <>
16 | void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params& params, cudaStream_t stream)
17 | {
18 |     run_mha_fwd_hdim192<cutlass::bfloat16_t>(params, stream);
19 | }
-------------------------------------------------------------------------------- /dskernels/inf_flash_attn/blocked_flash/flash_fwd_hdim192_fp16_sm80.cu: --------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Tri Dao.
2 | 
3 | // Splitting the different head dimensions to different files to speed up compilation.
4 | 
5 | #include "flash_fwd_launch_template.h"
6 | 
7 | // template<>
8 | // void run_mha_fwd_<cutlass::half_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
9 | //     using elem_type = cutlass::half_t;
10 | //     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
11 | //         run_flash_fwd<Flash_fwd_kernel_traits<192, /* tile shape elided */ elem_type>,
12 | //                       Is_dropout>(params, stream);
13 | //         (further run_flash_fwd variants with other tile shapes elided)
14 | //         This one is slightly faster for causal?
15 | //     });
16 | //     For A100 H100, 1st is faster with dropout, 3rd is faster without dropout.
17 | //     For A6000, 1st is faster when causal, 3rd is faster when not causal.
18 | // }
19 | template <>
20 | void run_mha_fwd_<cutlass::half_t, 192>(Flash_fwd_params& params, cudaStream_t stream)
21 | {
22 |     run_mha_fwd_hdim192<cutlass::half_t>(params, stream);
23 | }
-------------------------------------------------------------------------------- /dskernels/inf_flash_attn/blocked_flash/flash_fwd_hdim224_bf16_sm80.cu: --------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Tri Dao.
2 | 
3 | // Splitting the different head dimensions to different files to speed up compilation.
4 | 
5 | #include "flash_fwd_launch_template.h"
6 | 
7 | template <>
8 | void run_mha_fwd_<cutlass::bfloat16_t, 224>(Flash_fwd_params& params, cudaStream_t stream)
9 | {
10 |     run_mha_fwd_hdim224<cutlass::bfloat16_t>(params, stream);
11 | }
-------------------------------------------------------------------------------- /dskernels/inf_flash_attn/blocked_flash/flash_fwd_hdim224_fp16_sm80.cu: --------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Tri Dao.
2 | 
3 | // Splitting the different head dimensions to different files to speed up compilation.
4 | 
5 | #include "flash_fwd_launch_template.h"
6 | 
7 | template <>
8 | void run_mha_fwd_<cutlass::half_t, 224>(Flash_fwd_params& params, cudaStream_t stream)
9 | {
10 |     run_mha_fwd_hdim224<cutlass::half_t>(params, stream);
11 | }
-------------------------------------------------------------------------------- /dskernels/inf_flash_attn/blocked_flash/flash_fwd_hdim256_bf16_sm80.cu: --------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Tri Dao.
2 | 
3 | // Splitting the different head dimensions to different files to speed up compilation.
4 | 
5 | #include "flash_fwd_launch_template.h"
6 | 
7 | template <>
8 | void run_mha_fwd_<cutlass::bfloat16_t, 256>(Flash_fwd_params& params, cudaStream_t stream)
9 | {
10 |     run_mha_fwd_hdim256<cutlass::bfloat16_t>(params, stream);
11 | }
-------------------------------------------------------------------------------- /dskernels/inf_flash_attn/blocked_flash/flash_fwd_hdim256_fp16_sm80.cu: --------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Tri Dao.
2 | 
3 | // Splitting the different head dimensions to different files to speed up compilation.
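// A note on the dispatch pattern (inferred from flash_api.cu above, not stated in
// this file): every one of these per-head-dim translation units defines one explicit
// specialization of run_mha_fwd_<elem_type, kHeadDim>, and the FP16_SWITCH /
// FWD_HEADDIM_SWITCH macros in flash_api.cu select the matching specialization from
// the runtime dtype and head dimension. Splitting them up means a change affecting
// one head dimension only rebuilds that one .cu file.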
4 | 
5 | #include "flash_fwd_launch_template.h"
6 | 
7 | template <>
8 | void run_mha_fwd_<cutlass::half_t, 256>(Flash_fwd_params& params, cudaStream_t stream)
9 | {
10 |     run_mha_fwd_hdim256<cutlass::half_t>(params, stream);
11 | }
-------------------------------------------------------------------------------- /dskernels/inf_flash_attn/blocked_flash/flash_fwd_hdim32_bf16_sm80.cu: --------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Tri Dao.
2 | 
3 | // Splitting the different head dimensions to different files to speed up compilation.
4 | 
5 | #include "flash_fwd_launch_template.h"
6 | 
7 | template <>
8 | void run_mha_fwd_<cutlass::bfloat16_t, 32>(Flash_fwd_params& params, cudaStream_t stream)
9 | {
10 |     run_mha_fwd_hdim32<cutlass::bfloat16_t>(params, stream);
11 | }
-------------------------------------------------------------------------------- /dskernels/inf_flash_attn/blocked_flash/flash_fwd_hdim32_fp16_sm80.cu: --------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Tri Dao.
2 | 
3 | // Splitting the different head dimensions to different files to speed up compilation.
4 | 
5 | #include "flash_fwd_launch_template.h"
6 | 
7 | // template<>
8 | // void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
9 | //     using elem_type = cutlass::half_t;
10 | //     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
11 | //         run_flash_fwd<Flash_fwd_kernel_traits<32, /* tile shape elided */ elem_type>,
12 | //                       Is_dropout>(params, stream);
13 | //         (further run_flash_fwd variants with other tile shapes elided)
14 | //         For dropout there might be a lot of register spilling?
15 | //         These two are very slow due to register spilling.
16 | //         This one is slightly slower.
17 | //     });
18 | // }
19 | template <>
20 | void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params& params, cudaStream_t stream)
21 | {
22 |     run_mha_fwd_hdim32<cutlass::half_t>(params, stream);
23 | }
-------------------------------------------------------------------------------- /dskernels/inf_flash_attn/blocked_flash/flash_fwd_hdim64_bf16_sm80.cu: --------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Tri Dao.
2 | 
3 | // Splitting the different head dimensions to different files to speed up compilation.
4 | 
5 | #include "flash_fwd_launch_template.h"
6 | 
7 | // template<>
8 | // void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
9 | //     using elem_type = cutlass::bfloat16_t;
10 | //     if (params.p_dropout == 1.f) {
11 | //         run_flash_fwd<Flash_fwd_kernel_traits<64, /* tile shape elided */ elem_type>,
12 | //                       false>(params, stream);
13 | //     } else {
14 | //         run_flash_fwd<Flash_fwd_kernel_traits<64, /* tile shape elided */ elem_type>,
15 | //                       true>(params, stream);
16 | //     }
17 | // }
18 | template <>
19 | void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params& params, cudaStream_t stream)
20 | {
21 |     run_mha_fwd_hdim64<cutlass::bfloat16_t>(params, stream);
22 | }
-------------------------------------------------------------------------------- /dskernels/inf_flash_attn/blocked_flash/flash_fwd_hdim64_fp16_sm80.cu: --------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Tri Dao.
2 | 
3 | // Splitting the different head dimensions to different files to speed up compilation.
4 | 
5 | #include "flash_fwd_launch_template.h"
6 | 
7 | // template<>
8 | // void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
9 | //     using elem_type = cutlass::half_t;
10 | //     (Alternative configurations with and without dropout are elided here.
11 | //     Upstream notes: using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower.
12 | //     Using block size (64 x 256) is 27% slower for seqlen=2k.
13 | //     Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling.)
14 | // }
15 | template <>
16 | void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params& params, cudaStream_t stream)
17 | {
18 |     run_mha_fwd_hdim64<cutlass::half_t>(params, stream);
19 | }
-------------------------------------------------------------------------------- /dskernels/inf_flash_attn/blocked_flash/flash_fwd_hdim96_bf16_sm80.cu: --------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Tri Dao.
2 | 
3 | // Splitting the different head dimensions to different files to speed up compilation.
4 | 
5 | #include "flash_fwd_launch_template.h"
6 | 
7 | // template<>
8 | // void run_mha_fwd_<cutlass::bfloat16_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
9 | //     using elem_type = cutlass::bfloat16_t;
10 | //     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
11 | //         run_flash_fwd<Flash_fwd_kernel_traits<96, /* tile shape elided */ elem_type>,
12 | //                       Is_dropout>(params, stream);
13 | //     });
14 | // }
15 | template <>
16 | void run_mha_fwd_<cutlass::bfloat16_t, 96>(Flash_fwd_params& params, cudaStream_t stream)
17 | {
18 |     run_mha_fwd_hdim96<cutlass::bfloat16_t>(params, stream);
19 | }
-------------------------------------------------------------------------------- /dskernels/inf_flash_attn/blocked_flash/flash_fwd_hdim96_fp16_sm80.cu: --------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Tri Dao.
2 | 
3 | // Splitting the different head dimensions to different files to speed up compilation.
4 | 
5 | #include "flash_fwd_launch_template.h"
6 | 
7 | // template<>
8 | // void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
9 | //     using elem_type = cutlass::half_t;
10 | //     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
11 | //         run_flash_fwd<Flash_fwd_kernel_traits<96, /* tile shape elided */ elem_type>,
12 | //                       Is_dropout>(params, stream);
13 | //         (further run_flash_fwd variants with other tile shapes elided)
14 | //         This 3rd one is good for H100, and A100, A6000.
15 | //         These two are always slower.
16 | //     });
17 | // }
18 | template <>
19 | void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params& params, cudaStream_t stream)
20 | {
21 |     run_mha_fwd_hdim96<cutlass::half_t>(params, stream);
22 | }
-------------------------------------------------------------------------------- /dskernels/inf_flash_attn/blocked_flash/kernel_traits.h: --------------------------------------------------------------------------------
1 | /******************************************************************************
2 |  * Copyright (c) 2023, Tri Dao.
3 |  ******************************************************************************/
4 | 
5 | #pragma once
6 | 
7 | #include "cute/algorithm/copy.hpp"
8 | 
9 | #include <cutlass/numeric_types.h>
10 | #include "cutlass/cutlass.h"
11 | #include "cutlass/layout/layout.h"
12 | 
13 | using namespace cute;
14 | 
15 | template <int kHeadDim_,
16 |           int kBlockM_,
17 |           int kBlockN_,
18 |           int kNWarps_,
19 |           typename elem_type = cutlass::half_t>
20 | struct Flash_kernel_traits {
21 | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
22 |     using Element = elem_type;
23 |     static constexpr bool Has_cp_async = true;
24 | #else
25 |     using Element = cutlass::half_t;
26 |     static constexpr bool Has_cp_async = false;
27 | #endif
28 | 
29 |     using ElementAccum = float;
30 |     using index_t = uint32_t;
31 | 
32 | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
33 |     using MMA_Atom_Arch = std::conditional_t<std::is_same_v<elem_type, cutlass::half_t>,
34 |                                              MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
35 |                                              MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>>;
36 |     using ValLayoutMNK = Layout<Shape<_1, _2, _1>>;
37 | #else
38 |     using MMA_Atom_Arch = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
39 |     using ValLayoutMNK = Layout<Shape<_1, _2, _2>>;
40 | #endif
41 | 
42 | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
43 |     using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, elem_type>;
44 |     using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, elem_type>;
45 | #else
46 |     using SmemCopyAtom = Copy_Atom<DefaultCopy, elem_type>;
47 |     using SmemCopyAtomTransposed = Copy_Atom<DefaultCopy, elem_type>;
48 | #endif
49 | };
50 | 
51 | // If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true
52 | template <int kHeadDim_,
53 |           int kBlockM_,
54 |           int kBlockN_,
55 |           int kNWarps_,
56 |           bool Is_Q_in_regs_ = false,
57 |           bool Share_Q_K_smem_ = false,
58 |           typename elem_type = cutlass::half_t,
59 |           typename Base = Flash_kernel_traits<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type>>
60 | struct Flash_fwd_kernel_traits : public Base {
61 |     using Element = typename Base::Element;
62 |     using ElementAccum = typename Base::ElementAccum;
63 |     using index_t = typename Base::index_t;
64 |     static constexpr bool Has_cp_async = Base::Has_cp_async;
65 |     using SmemCopyAtom = typename Base::SmemCopyAtom;
66 |     using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;
67 | 
68 |     static constexpr bool Share_Q_K_smem = Share_Q_K_smem_;
69 |     static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem;
70 | 
71 |     // The number of threads.
72 |     static constexpr int kNWarps = kNWarps_;
73 |     static constexpr int kNThreads = kNWarps * 32;
74 | 
75 |     static constexpr int kBlockM = kBlockM_;
76 |     static constexpr int kBlockN = kBlockN_;
77 |     static constexpr int kHeadDim = kHeadDim_;
78 |     static_assert(kHeadDim % 32 == 0);
79 |     static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
80 |     static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
81 |     static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
82 | 
83 |     using TiledMma = TiledMMA<typename Base::MMA_Atom_Arch,
84 |                               Layout<Shape<Int<kNWarps>, _1, _1>>,  // 4x1x1 or 8x1x1 thread group
85 |                               typename Base::ValLayoutMNK>;  // 1x2x1 or 1x2x2 value group for
86 |                                                              // 16x16x16 MMA and LDSM
87 | 
88 |     using SmemLayoutAtomQ = decltype(composition(
89 |         Swizzle<kSwizzle, 3, 3>{},
90 |         // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128
91 |         Layout<Shape<_8, Int<kBlockKSmem>>, Stride<Int<kBlockKSmem>, _1>>{}));
92 |     using SmemLayoutQ =
93 |         decltype(tile_to_shape(SmemLayoutAtomQ{}, Shape<Int<kBlockM>, Int<kHeadDim>>{}));
94 | 
95 |     using SmemLayoutKV =
96 |         decltype(tile_to_shape(SmemLayoutAtomQ{}, Shape<Int<kBlockN>, Int<kHeadDim>>{}));
97 | 
98 |     // This has to be kBlockN and not 8, otherwise we get wrong results for d=128
99 |     using SmemLayoutAtomVtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
100 |                                                       Stride<_1, Int<kBlockKSmem>>>;
101 |     using SmemLayoutAtomVtransposed = decltype(
102 |         composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomVtransposedNoSwizzle{}));
103 |     using SmemLayoutVtransposed = decltype(tile_to_shape(
104 |         SmemLayoutAtomVtransposed{},
105 |         Shape<Int<kHeadDim>, Int<kBlockN>>{}));
106 |     // Maybe the VtransposeNoSwizzle just needs to have the right shape
107 |     // And the strides don't matter?
108 |     using SmemLayoutVtransposedNoSwizzle = decltype(tile_to_shape(
109 |         SmemLayoutAtomVtransposedNoSwizzle{},
110 |         Shape<Int<kHeadDim>, Int<kBlockN>>{}));
111 |     // using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());
112 | 
113 |     using SmemLayoutAtomO = decltype(composition(
114 |         Swizzle<kSwizzle, 3, 3>{},
115 |         Layout<Shape<Int<8>, Int<kBlockKSmem>>, Stride<Int<kBlockKSmem>, _1>>{}));
116 |     using SmemLayoutO =
117 |         decltype(tile_to_shape(SmemLayoutAtomO{}, Shape<Int<kBlockM>, Int<kHeadDim>>{}));
118 |     using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>;
119 | 
120 |     static constexpr int kMaxBlocks = 256;
121 | 
122 |     static constexpr int kSmemQCount = size(SmemLayoutQ{});
123 |     static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
124 |     static constexpr int kSmemQSize = kSmemQCount * sizeof(Element);
125 |     static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
126 |     static constexpr int kSmemBlockSize = kMaxBlocks * sizeof(int32_t);
127 |     static constexpr int kSmemFlashSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize)
128 |                                                          : kSmemQSize + kSmemKVSize;
129 |     static constexpr int kSmemSize = kSmemFlashSize + kSmemBlockSize;
130 | 
131 |     static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
132 |     static_assert(kHeadDim % kGmemElemsPerLoad == 0,
133 |                   "kHeadDim must be a multiple of kGmemElemsPerLoad");
134 |     // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts.
135 |     // For example, for d=128, smem is split into 2 "pages", each page takes care of columns
136 |     // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem,
137 |     // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page,
138 |     // to the same banks.
139 |     static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
140 |     static_assert(kNThreads % kGmemThreadsPerRow == 0,
141 |                   "kNThreads must be a multiple of kGmemThreadsPerRow");
142 |     using GmemLayoutAtom =
143 |         Layout<Shape<Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
144 |                Stride<Int<kGmemThreadsPerRow>, _1>>;
145 | 
146 |     // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
147 |     // from the same address by the same threadblock. This is slightly faster.
148 |     using Gmem_copy_struct =
149 |         std::conditional_t<Has_cp_async, SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, DefaultCopy>;
150 |     using GmemTiledCopyQKV =
151 |         decltype(make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
152 |                                  GmemLayoutAtom{},
153 |                                  Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per read
154 |     using GmemTiledCopyO =
155 |         decltype(make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
156 |                                  GmemLayoutAtom{},
157 |                                  Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store
158 |     static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad;
159 |     static_assert(kNThreads % kGmemThreadsPerRowP == 0,
160 |                   "kNThreads must be a multiple of kGmemThreadsPerRowP");
161 |     using GmemLayoutAtomP =
162 |         Layout<Shape<Int<kNThreads / kGmemThreadsPerRowP>, Int<kGmemThreadsPerRowP>>,
163 |                Stride<Int<kGmemThreadsPerRowP>, _1>>;
164 | 
165 |     using GmemTiledCopyP =
166 |         decltype(make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
167 |                                  GmemLayoutAtomP{},
168 |                                  Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store
169 | };
170 | 
171 | ////////////////////////////////////////////////////////////////////////////////////////////////////
172 | 
-------------------------------------------------------------------------------- /dskernels/inf_flash_attn/blocked_flash/softmax.h: --------------------------------------------------------------------------------
1 | /******************************************************************************
2 |  * Copyright (c) 2023, Tri Dao.
170 | 
171 | ////////////////////////////////////////////////////////////////////////////////////////////////////
172 | 
--------------------------------------------------------------------------------
/dskernels/inf_flash_attn/blocked_flash/softmax.h:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 |  * Copyright (c) 2023, Tri Dao.
3 |  ******************************************************************************/
4 | 
5 | #pragma once
6 | 
7 | #include <cmath>
8 | 
9 | #include <cute/tensor.hpp>
10 | 
11 | #include <cutlass/numeric_types.h>
12 | #include <cutlass/numeric_conversion.h>
13 | 
14 | #include "utils.h"
15 | 
16 | namespace flash {
17 | 
18 | using namespace cute;
19 | 
20 | ////////////////////////////////////////////////////////////////////////////////////////////////////
21 | 
22 | template <bool zero_init = true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
23 | __device__ inline void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
24 |     static_assert(Layout0::rank == 2, "Only support 2D Tensor");
25 |     static_assert(Layout1::rank == 1, "Only support 1D Tensor");
26 |     CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
27 | #pragma unroll
28 |     for (int mi = 0; mi < size<0>(tensor); mi++) {
29 |         summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
30 | #pragma unroll
31 |         for (int ni = 1; ni < size<1>(tensor); ni++) {
32 |             summary(mi) = op(summary(mi), tensor(mi, ni));
33 |         }
34 |     }
35 | }
36 | 
37 | template <typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
38 | __device__ inline void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
39 |     CUTE_STATIC_ASSERT_V(size(dst) == size(src));
40 | #pragma unroll
41 |     for (int i = 0; i < size(dst); i++){
42 |         dst(i) = Allreduce<4>::run(src(i), op);
43 |     }
44 | }
45 | 
46 | template <bool zero_init = true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
47 | __device__ inline void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
48 |     thread_reduce_<zero_init>(tensor, summary, op);
49 |     quad_allreduce_(summary, summary, op);
50 | }
51 | 
52 | template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
53 | __device__ inline void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
54 |     MaxOp<float> max_op;
55 |     reduce_(tensor, max, max_op);
56 | }
57 | 
58 | template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
59 | __device__ inline void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
60 |     SumOp<float> sum_op;
61 |     reduce_(tensor, sum, sum_op);
62 | }
63 | 
64 | // Apply the exp to all the elements.
65 | template <bool Scale_max = true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
66 | inline __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
67 |     static_assert(Layout0::rank == 2, "Only support 2D Tensor");
68 |     static_assert(Layout1::rank == 1, "Only support 1D Tensor");
69 |     CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
70 | #pragma unroll
71 |     for (int mi = 0; mi < size<0>(tensor); ++mi) {
72 |         // If max is -inf, then all elements must have been -inf (possibly due to masking).
73 |         // We don't want (-inf - (-inf)) since that would give NaN.
74 |         // If we don't have float around M_LOG2E the multiplication is done in fp64.
75 |         const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E));
76 | #pragma unroll
77 |         for (int ni = 0; ni < size<1>(tensor); ++ni) {
78 |             // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
79 |             // max * log_2(e)) This allows the compiler to use the ffma
80 |             // instruction instead of fadd and fmul separately.
81 |             tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
82 |         }
83 |     }
84 | }
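// Informal note on the exp2 rewrite used above and in max_scale_exp2_sum below: since
// exp(x - m) == exp2(x * log2(e) - m * log2(e)), folding log2(e) into `scale` up front
// lets each element be computed as exp2f(fma(x, scale, -max_scaled)) -- one FFMA plus
// one EX2 -- instead of a separate subtract, multiply, and exp.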
85 | 
86 | // Compute the row max, exponentiate all the elements, and accumulate the row sum.
87 | template <bool zero_init = true,
88 |           typename Engine0,
89 |           typename Layout0,
90 |           typename Engine1,
91 |           typename Layout1>
92 | inline __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0>& tensor,
93 |                                           Tensor<Engine1, Layout1>& max,
94 |                                           Tensor<Engine1, Layout1>& sum,
95 |                                           const float scale)
96 | {
97 |     static_assert(Layout0::rank == 2, "Only support 2D Tensor");
98 |     static_assert(Layout1::rank == 1, "Only support 1D Tensor");
99 |     CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
100 | #pragma unroll
101 |     for (int mi = 0; mi < size<0>(tensor); ++mi) {
102 |         MaxOp<float> max_op;
103 |         max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0));
104 | #pragma unroll
105 |         for (int ni = 1; ni < size<1>(tensor); ni++) { max(mi) = max_op(max(mi), tensor(mi, ni)); }
106 |         max(mi) = Allreduce<4>::run(max(mi), max_op);
107 |         // If max is -inf, then all elements must have been -inf (possibly due to masking).
108 |         // We don't want (-inf - (-inf)) since that would give NaN.
109 |         const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale;
110 |         sum(mi) = 0;
111 | #pragma unroll
112 |         for (int ni = 0; ni < size<1>(tensor); ++ni) {
113 |             // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
114 |             // max * log_2(e)) This allows the compiler to use the ffma
115 |             // instruction instead of fadd and fmul separately.
116 |             tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
117 |             sum(mi) += tensor(mi, ni);
118 |         }
119 |         SumOp<float> sum_op;
120 |         sum(mi) = Allreduce<4>::run(sum(mi), sum_op);
121 |     }
122 | }
123 | 
124 | template <typename Engine, typename Layout>
125 | inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
126 |                                   const int col_idx_offset_ = 0) {
127 |     // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
128 |     static_assert(Layout::rank == 2, "Only support 2D Tensor");
129 |     const int lane_id = threadIdx.x % 32;
130 |     const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
131 | #pragma unroll
132 |     for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
133 |         const int col_idx_base = col_idx_offset + nj * 8;
134 | #pragma unroll
135 |         for (int j = 0; j < size<1, 0>(tensor); ++j) {
136 |             const int col_idx = col_idx_base + j;
137 |             if (col_idx >= max_seqlen_k) {
138 |                 // Without the "make_coord" we get wrong results
139 | #pragma unroll
140 |                 for (int mi = 0; mi < size<0>(tensor); ++mi) {
141 |                     tensor(mi, make_coord(j, nj)) = -INFINITY;
142 |                 }
143 |             }
144 |         }
145 |     }
146 | }
147 | 
148 | template <typename Engine, typename Layout>
149 | inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
150 |                                          const int max_seqlen_k, const int row_idx_offset_,
151 |                                          const int max_seqlen_q, const int warp_row_stride) {
152 |     // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
153 |     static_assert(Layout::rank == 2, "Only support 2D Tensor");
154 |     const int lane_id = threadIdx.x % 32;
155 |     // const int row_idx_offset = row_idx_offset_ + lane_id / 4;
156 |     const int row_idx_offset = row_idx_offset_;
157 |     const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
158 | #pragma unroll
159 |     for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
160 |         const int row_idx_base = row_idx_offset + mi * warp_row_stride;
161 | #pragma unroll
162 |         for (int i = 0; i < size<0, 0>(tensor); ++i) {
163 |             const int row_idx = row_idx_base + i * 8;
164 |             const int col_idx_limit = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q);
165 | #pragma unroll
166 |             for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
167 |                 const int col_idx_base = col_idx_offset + nj * 8;
168 | #pragma unroll
169 |                 for (int j = 0; j < size<1, 0>(tensor); ++j) {
170 |                     const int col_idx = col_idx_base + j;
171 |                     if (col_idx >= col_idx_limit) {
172 |                         tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
173 |                     }
174 |                 }
175 |             }
176 |             // if (cute::thread0()) {
177 |             //     printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k);
178 |             //     print(tensor(make_coord(i, mi), _));
179 |             //     // print(tensor(_, j + nj * size<1, 0>(tensor)));
180 |             // }
181 |         }
182 |     }
183 | }
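// Informal reading of the causal-mask index math above: each quad of 4 lanes owns two
// adjacent columns per 8-column MMA tile, so lane l covers columns
// col_idx_offset_ + (l % 4) * 2 + nj * 8 + j. An element is set to -INFINITY once its
// column index passes row_idx + max_seqlen_k - max_seqlen_q, i.e. a causal diagonal
// aligned to the bottom-right corner of the score matrix.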
184 | 
185 | template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
186 | inline __device__ void apply_mask_causal_w_idx(Tensor<Engine0, Layout0>& tensor,
187 |                                                Tensor<Engine1, Layout1> const& idx_rowcol,
188 |                                                const int32_t col_idx_offset_,
189 |                                                const int32_t max_seqlen_k,
190 |                                                const int32_t row_idx_offset_)
191 | {
192 |     // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
193 |     static_assert(Layout0::rank == 2, "Only support 2D Tensor");
194 |     static_assert(Layout1::rank == 2, "Only support 2D Tensor");
195 |     CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol));
196 |     CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol));
197 | #pragma unroll
198 |     for (int mi = 0; mi < size<0>(tensor); ++mi) {
199 |         const int32_t col_idx_limit =
200 |             std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0)));
201 | #pragma unroll
202 |         for (int ni = 0; ni < size<1, 1>(tensor); ++ni) {
203 |             if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) {
204 |                 tensor(mi, ni) = -INFINITY;
205 |             }
206 |         }
207 |         // if (cute::thread0()) {
208 |         //     printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx,
209 |         //     max_seqlen_k); print(tensor(_, make_coord(j, ni)));
210 |         //     // print(tensor(_, j + ni * size<1, 0>(tensor)));
211 |         // }
212 |     }
213 | }
214 | 
215 | }  // namespace flash
216 | 
--------------------------------------------------------------------------------
/dskernels/inf_flash_attn/blocked_flash/static_switch.h:
--------------------------------------------------------------------------------
1 | // Inspired by
2 | // https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
3 | // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
4 | 
5 | #pragma once
6 | 
7 | /// @param COND       - a boolean expression to switch by
8 | /// @param CONST_NAME - a name given for the constexpr bool variable.
9 | /// @param ...        - code to execute for true and false
10 | ///
11 | /// Usage:
12 | /// ```
13 | /// BOOL_SWITCH(flag, BoolConst, [&] {
14 | ///     some_function(...);
15 | /// });
16 | /// ```
17 | #define BOOL_SWITCH(COND, CONST_NAME, ...)                \
18 |     [&] {                                                 \
19 |         if (COND) {                                       \
20 |             constexpr static bool CONST_NAME = true;      \
21 |             return __VA_ARGS__();                         \
22 |         } else {                                          \
23 |             constexpr static bool CONST_NAME = false;     \
24 |             return __VA_ARGS__();                         \
25 |         }                                                 \
26 |     }()
27 | 
28 | #define FP16_SWITCH(COND, ...)                            \
29 |     [&] {                                                 \
30 |         if (COND) {                                       \
31 |             using elem_type = cutlass::half_t;            \
32 |             return __VA_ARGS__();                         \
33 |         } else {                                          \
34 |             using elem_type = cutlass::bfloat16_t;        \
35 |             return __VA_ARGS__();                         \
36 |         }                                                 \
37 |     }()
38 | 
39 | #define FWD_HEADDIM_SWITCH(HEADDIM, ...)                  \
40 |     [&] {                                                 \
41 |         if (HEADDIM <= 32) {                              \
42 |             constexpr static int kHeadDim = 32;           \
43 |             return __VA_ARGS__();                         \
44 |         } else if (HEADDIM <= 64) {                       \
45 |             constexpr static int kHeadDim = 64;           \
46 |             return __VA_ARGS__();                         \
47 |         } else if (HEADDIM <= 96) {                       \
48 |             constexpr static int kHeadDim = 96;           \
49 |             return __VA_ARGS__();                         \
50 |         } else if (HEADDIM <= 128) {                      \
51 |             constexpr static int kHeadDim = 128;          \
52 |             return __VA_ARGS__();                         \
53 |         } else if (HEADDIM <= 160) {                      \
54 |             constexpr static int kHeadDim = 160;          \
55 |             return __VA_ARGS__();                         \
56 |         } else if (HEADDIM <= 192) {                      \
57 |             constexpr static int kHeadDim = 192;          \
58 |             return __VA_ARGS__();                         \
59 |         } else if (HEADDIM <= 224) {                      \
60 |             constexpr static int kHeadDim = 224;          \
61 |             return __VA_ARGS__();                         \
62 |         } else if (HEADDIM <= 256) {                      \
63 |             constexpr static int kHeadDim = 256;          \
64 |             return __VA_ARGS__();                         \
65 |         }                                                 \
66 |     }()
--------------------------------------------------------------------------------
/dskernels/inf_flash_attn/blocked_flash/utils.h:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 |  * Copyright (c) 2023, Tri Dao.
3 |  ******************************************************************************/
4 | 
5 | #pragma once
6 | 
7 | #include <assert.h>
8 | #include <stdint.h>
9 | #include <stdlib.h>
10 | 
11 | #include <cuda_fp16.h>
12 | 
13 | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
14 | #include <cuda_bf16.h>
15 | #endif
16 | 
17 | #include <cute/algorithm/copy.hpp>
18 | #include <cute/algorithm/gemm.hpp>
19 | 
20 | #include <cutlass/array.h>
21 | #include <cutlass/cutlass.h>
22 | #include <cutlass/numeric_conversion.h>
23 | #include <cutlass/numeric_types.h>
24 | 
25 | ////////////////////////////////////////////////////////////////////////////////////////////////////
26 | 
27 | namespace flash {
28 | 
29 | ////////////////////////////////////////////////////////////////////////////////////////////////////
30 | 
31 | template <typename T>
32 | struct MaxOp {
33 |     __device__ inline T operator()(T const& x, T const& y) { return x > y ? x : y; }
34 | };
35 | 
36 | template <>
37 | struct MaxOp<float> {
38 |     // This is slightly faster
39 |     __device__ inline float operator()(float const& x, float const& y) { return max(x, y); }
40 | };
41 | 
42 | ////////////////////////////////////////////////////////////////////////////////////////////////////
43 | 
44 | template <typename T>
45 | struct SumOp {
46 |     __device__ inline T operator()(T const& x, T const& y) { return x + y; }
47 | };
48 | 
49 | ////////////////////////////////////////////////////////////////////////////////////////////////////
50 | 
51 | template <int THREADS>
52 | struct Allreduce {
53 |     static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
54 |     template <typename T, typename Operator>
55 |     static __device__ inline T run(T x, Operator& op)
56 |     {
57 |         constexpr int OFFSET = THREADS / 2;
58 |         x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
59 |         return Allreduce<OFFSET>::run(x, op);
60 |     }
61 | };
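// Informal sketch of how the recursion above unrolls for the quad reduction used in
// softmax.h: Allreduce<4>::run shuffles with offset 2, then the Allreduce<2>
// specialization below shuffles with offset 1, so all 4 lanes of a quad (the owners of
// one row fragment of the MMA tile) converge to the same reduced value.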
62 | 
63 | ////////////////////////////////////////////////////////////////////////////////////////////////////
64 | 
65 | template <>
66 | struct Allreduce<2> {
67 |     template <typename T, typename Operator>
68 |     static __device__ inline T run(T x, Operator& op)
69 |     {
70 |         x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
71 |         return x;
72 |     }
73 | };
74 | 
75 | ////////////////////////////////////////////////////////////////////////////////////////////////////
76 | 
77 | template <bool A_in_regs = false,
78 |           bool B_in_regs = false,
79 |           typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3, typename Tensor4,
80 |           typename TiledMma, typename TiledCopyA, typename TiledCopyB, typename ThrCopyA, typename ThrCopyB>
81 | inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
82 |                             Tensor4 const& tCsB, TiledMma tiled_mma,
83 |                             TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B,
84 |                             ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) {
85 |     CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));   // MMA_M
86 |     CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));   // MMA_N
87 |     CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));  // MMA_K
88 |     Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA);
89 |     CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view));  // M
90 |     Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
91 |     CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));  // N
92 |     if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); }
93 |     if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); }
94 | #pragma unroll
95 |     for (int i = 0; i < size<2>(tCrA); ++i) {
96 |         if (i < size<2>(tCrA) - 1) {
97 |             if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); }
98 |             if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); }
99 |         }
100 |         cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
101 |     }
102 | }
103 | 
104 | ////////////////////////////////////////////////////////////////////////////////////////////////////
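// Informal note on the loop above: the smem->register copy for K-slice i+1 is issued
// before the MMA on slice i, so LDSM traffic overlaps with tensor-core work -- a
// one-slice-deep software pipeline, with either operand skippable when it is already
// held in registers (A_in_regs / B_in_regs).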
105 | 
106 | template <typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
107 |           typename TiledMma, typename TiledCopy, typename ThrCopy>
108 | inline __device__ void gemm_A_in_regs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
109 |                                       TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
110 |                                       ThrCopy smem_thr_copy_B) {
111 |     CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));   // MMA_M
112 |     CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));   // MMA_N
113 |     CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));  // MMA_K
114 |     Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
115 |     CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));  // N
116 |     cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
117 | #pragma unroll
118 |     for (int i = 0; i < size<2>(tCrA); ++i) {
119 |         if (i < size<2>(tCrA) - 1) {
120 |             cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
121 |         }
122 |         cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
123 |     }
124 | }
125 | 
126 | ////////////////////////////////////////////////////////////////////////////////////////////////////
127 | 
128 | // Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
129 | template <typename Layout>
130 | inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
131 |     static_assert(decltype(size<0>(acc_layout))::value == 4);
132 |     static_assert(decltype(rank(acc_layout))::value == 3);
133 |     auto l = logical_divide(acc_layout, Shape<_2>{});  // ((2, 2), MMA_M, MMA_N)
134 |     // TD [2023-08-13]: Idk why but get<0, 1>(l) doesn't work for Cutlass 3.2, I'm getting
135 |     // "int_tuple.hpp(74): error: conversion to inaccessible base class"
136 |     return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
137 |     // return make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)), make_layout(get<0>(get<0>(l)), get<2>(l)));
138 | };
139 | 
140 | ////////////////////////////////////////////////////////////////////////////////////////////////////
141 | 
142 | // Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
143 | // if using m16n8k16, or to ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
144 | template <typename MMA_traits, typename Layout>
145 | inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
146 |     using X = Underscore;
147 |     static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2);
148 |     static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2);
149 |     constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});
150 |     static_assert(mma_shape_K == 8 || mma_shape_K == 16);
151 |     constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2;
152 |     auto l = logical_divide(rowcol_layout, Shape<X, Shape<X, Int<MMA_N_divisor>>>{});  // ((2, MMA_M), (2, (2, MMA_N / 2)))
153 |     // TD [2023-08-13]: Same error as above on Cutlass 3.2
154 |     return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)),
155 |                        get<0, 1>(l),
156 |                        get<1, 1, 1>(l));
157 |     // return make_layout(make_layout(get<0>(get<1>(l)), get<0>(get<0>(l)), get<0>(get<1>(get<1>(l)))),
158 |     //                    get<1>(get<0>(l)),
159 |     //                    get<1>(get<1>(get<1>(l))));
160 | };
161 | 
162 | ////////////////////////////////////////////////////////////////////////////////////////////////////
163 | 
164 | template <typename To_type, typename Engine, typename Layout>
165 | inline __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
166 |     using From_type = typename Engine::value_type;
167 |     constexpr int numel = decltype(size(tensor))::value;
168 |     cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
169 |     // HACK: this requires tensor to be "contiguous"
170 |     auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
171 |     return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
172 | }
173 | 
174 | ////////////////////////////////////////////////////////////////////////////////////////////////////
175 | 
176 | // Blocks until all but N previous cp.async.commit_group operations have committed.
177 | // This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all
178 | // (which is equivalent to commit_group then wait_group 0).
179 | // Instead we just call cp.async.wait_group 0, which is slightly faster.
180 | // https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113
181 | template <int N>
182 | CUTE_HOST_DEVICE void cp_async_wait()
183 | {
184 | #if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
185 |     asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
186 | #endif
187 | }
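// Usage sketch (informal, mirroring the pattern used by the forward kernel): after
// issuing cp.async copies for the next K/V tile,
//     cute::cp_async_fence();     // commit the outstanding copies as one group
//     flash::cp_async_wait<0>();  // block until every committed group has landed
//     __syncthreads();
// With N > 0, up to N groups may remain in flight, which is what enables multi-stage
// gmem->smem pipelines.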
188 | 
189 | ////////////////////////////////////////////////////////////////////////////////////////////////////
190 | 
191 | template <bool Is_even_MN = true, bool Is_even_K = true, bool Clear_OOB_MN = false, bool Clear_OOB_K = true,
192 |           typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
193 |           typename Engine2, typename Layout2, typename Engine3, typename Layout3>
194 | inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
195 |                             Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
196 |                             Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
197 |     CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
198 |     CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
199 |     CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));  // MMA
200 |     CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));  // MMA_M
201 |     CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));  // MMA_K
202 |     // There's no case where !Clear_OOB_K && Clear_OOB_MN
203 |     static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
204 | #pragma unroll
205 |     for (int m = 0; m < size<1>(S); ++m) {
206 |         if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
207 | #pragma unroll
208 |             for (int k = 0; k < size<2>(S); ++k) {
209 |                 if (Is_even_K || predicate_K(k)) {
210 |                     cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
211 |                 } else if (Clear_OOB_K) {
212 |                     cute::clear(D(_, m, k));
213 |                 }
214 |             }
215 |         } else if (Clear_OOB_MN) {
216 |             cute::clear(D(_, m, _));
217 |         }
218 |     }
219 | }
220 | 
221 | ////////////////////////////////////////////////////////////////////////////////////////////////////
222 | 
223 | }  // namespace flash
224 | 
--------------------------------------------------------------------------------
/fetch.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 | git submodule update --init --recursive
3 | 
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = [
3 |     "setuptools>=64",
4 |     "wheel",
5 |     "packaging"
6 | ]
7 | # Use legacy backend to import local packages in setup.py
8 | build-backend = "setuptools.build_meta:__legacy__"
9 | 
--------------------------------------------------------------------------------
/release/bump_patch_version.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # SPDX-License-Identifier: Apache-2.0
3 | 
4 | # DeepSpeed Team
5 | 
6 | import argparse
7 | from packaging import version as pkg_version
8 | 
9 | parser = argparse.ArgumentParser()
10 | 
11 | parser.add_argument("--current_version",
12 |                     type=str,
13 |                     help="The current version being published to help set the next version.")
14 | 
15 | args = parser.parse_args()
16 | 
17 | current_version = pkg_version.parse(args.current_version)
18 | 
19 | with open('./version.txt', 'w') as fd:
20 |     fd.write(f'{current_version.major}.{current_version.minor}.{current_version.micro + 1}\n')
21 | 
22 | print(f'{current_version} -> {current_version.major}.{current_version.minor}.{current_version.micro + 1}')
23 | 
--------------------------------------------------------------------------------
/release/check_release_version.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | # DeepSpeed Team 5 | 6 | import argparse 7 | from packaging import version as pkg_version 8 | 9 | parser = argparse.ArgumentParser() 10 | 11 | parser.add_argument("--release_version", type=str, help="The new version being published.") 12 | 13 | args = parser.parse_args() 14 | 15 | release_version = pkg_version.parse(args.release_version) 16 | 17 | with open('./version.txt') as fd: 18 | repo_version = pkg_version.parse(fd.read()) 19 | 20 | assert repo_version == release_version, f"{repo_version=} does not match {release_version=}, unable to proceed" 21 | -------------------------------------------------------------------------------- /release/release.sh: -------------------------------------------------------------------------------- 1 | set -ex 2 | 3 | rm -rf dist 4 | 5 | # enable to reduce overall memory consumption if running on a small VM 6 | #export DS_KERNELS_MAKE_JOBS=10 7 | 8 | ts=$(date +%s) 9 | DS_KERNELS_BUILD_STRING=".dev${ts}" CUDA_ARCH_LIST="80;86" python setup.py bdist_wheel 10 | 11 | # rename whl to ensure portability 12 | fname=$(ls dist) 13 | nname=$(echo $fname | sed 's/cp[0-9]\+-cp[0-9]\+/py3-none/' | sed 's/linux/manylinux1/') 14 | mv "dist/$fname" "dist/$nname" 15 | 16 | twine upload dist/*.whl 17 | -------------------------------------------------------------------------------- /requirements/requirements-dev.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepspeedai/DeepSpeed-Kernels/d5dde7b1deead286b5bd54f437edde0fccfaee2f/requirements/requirements-dev.txt -------------------------------------------------------------------------------- /requirements/requirements.txt: -------------------------------------------------------------------------------- 1 | cmake>=3.24 2 | packaging 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | # DeepSpeed Team 5 | 6 | import os 7 | import sys 8 | import subprocess 9 | from setuptools import setup, find_packages 10 | 11 | 12 | def fetch_requirements(path): 13 | with open(path, 'r') as fd: 14 | return [r.strip() for r in fd.readlines()] 15 | 16 | 17 | install_requires = fetch_requirements('requirements/requirements.txt') 18 | extras_require = { 19 | "dev": fetch_requirements('requirements/requirements-dev.txt') 20 | } 21 | 22 | 23 | def command_exists(cmd): 24 | if sys.platform == "win32": 25 | result = subprocess.Popen(f'{cmd}', stdout=subprocess.PIPE, shell=True) 26 | return result.wait() == 1 27 | else: 28 | result = subprocess.Popen(f'type {cmd}', 29 | stdout=subprocess.PIPE, 30 | shell=True) 31 | return result.wait() == 0 32 | 33 | 34 | # Write out version/git info 35 | git_hash_cmd = "git rev-parse --short HEAD" 36 | git_branch_cmd = "git rev-parse --abbrev-ref HEAD" 37 | if command_exists('git') and 'DS_KERNELS_BUILD_STRING' not in os.environ: 38 | try: 39 | result = subprocess.check_output(git_hash_cmd, shell=True) 40 | git_hash = result.decode('utf-8').strip() 41 | result = subprocess.check_output(git_branch_cmd, shell=True) 42 | git_branch = result.decode('utf-8').strip() 43 | except subprocess.CalledProcessError: 44 | git_hash = "unknown" 45 | git_branch = "unknown" 46 | else: 47 | git_hash = "unknown" 48 | git_branch = "unknown" 49 | 50 | # Ensure all submodules have been pulled in 51 | git_submodules = "git submodule update --init --recursive" 52 | if command_exists('git'): 53 | try: 54 | result = subprocess.check_output(git_submodules, shell=True) 55 | except subprocess.CalledProcessError: 56 | pass 57 | 58 | # Parse the ds-kernels version string from version.txt 59 | version_str = open('version.txt', 'r').read().strip() 60 | 61 | # Build specifiers like .devX can be added at install time. Otherwise, add the git hash. 
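# (Illustrative walkthrough.) With version.txt containing 0.1.0, the branches below yield:
#   DS_KERNELS_BUILD_STRING=".dev123" set -> 0.1.0.dev123, and build.txt is written
#   build.txt present, env var unset      -> 0.1.0 + contents of build.txt
#   neither                               -> 0.1.0+<short git hash>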
62 | # example: DS_KERNELS_BUILD_STRING=".dev20201022" python -m build --sdist --wheel
63 | 
64 | BUILD_STRING = 'DS_KERNELS_BUILD_STRING'
65 | BUILD_FILE = 'build.txt'
66 | build_string = os.environ.get(BUILD_STRING)
67 | 
68 | # Building wheel for distribution, update version file
69 | if build_string:
70 |     # Build string env specified, probably building for distribution
71 |     with open(BUILD_FILE, 'w') as fd:
72 |         fd.write(build_string)
73 |     version_str += build_string
74 | elif os.path.isfile(BUILD_FILE):
75 |     # build.txt exists, probably installing from distribution
76 |     with open(BUILD_FILE, 'r') as fd:
77 |         version_str += fd.read().strip()
78 | else:
79 |     # None of the above, probably installing from source
80 |     version_str += f'+{git_hash}'
81 | 
82 | # write out installed version
83 | with open("dskernels/version.py", 'w') as fd:
84 |     fd.write(f"__version__ = '{version_str}'\n")
85 | 
86 | from builder.builder import CMakeBuild
87 | from builder.ft_gemm import FTGemmBuilder
88 | from builder.inf_flash_attn import BlockedFlashBuilder
89 | 
90 | ext_modules = []
91 | build_ext = {'build_ext': CMakeBuild}
92 | 
93 | ext_modules.append(FTGemmBuilder(name="deepspeed_ft_gemm"))
94 | ext_modules.append(BlockedFlashBuilder(name="deepspeed_blocked_flash"))
95 | 
96 | setup(name="deepspeed-kernels",
97 |       version=version_str,
98 |       description='deepspeed kernels',
99 |       author='DeepSpeed Team',
100 |       author_email='deepspeed@microsoft.com',
101 |       url='http://deepspeed.ai',
102 |       project_urls={
103 |           'Documentation': 'https://github.com/deepspeedai/DeepSpeed-Kernels',
104 |           'Source': 'https://github.com/deepspeedai/DeepSpeed-Kernels',
105 |       },
106 |       install_requires=install_requires,
107 |       extras_require=extras_require,
108 |       ext_modules=ext_modules,
109 |       cmdclass=build_ext,
110 |       include_package_data=True,
111 |       packages=find_packages(include=['dskernels']),
112 |       classifiers=[
113 |           'Programming Language :: Python :: 3.9',
114 |           'Programming Language :: Python :: 3.10',
115 |           'Programming Language :: Python :: 3.11'
116 |       ])
117 | 
--------------------------------------------------------------------------------
/version.txt:
--------------------------------------------------------------------------------
1 | 0.1.0
2 | 
--------------------------------------------------------------------------------