├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── assets └── images │ ├── dev-discuss-asynctp │ ├── Figure_1.png │ ├── Figure_10.png │ ├── Figure_11.png │ ├── Figure_2.png │ ├── Figure_3.png │ ├── Figure_4.png │ ├── Figure_5.png │ ├── Figure_6.png │ ├── Figure_7.png │ ├── Figure_8.png │ ├── Figure_9.png │ └── readme.md │ ├── moe_gemm_h100.png │ ├── readme.md │ └── softmax_fused.png ├── dev ├── sr │ ├── .gitignore │ ├── build │ │ ├── lib.linux-x86_64-cpython-311 │ │ │ └── stochastic_rounding_cuda.cpython-311-x86_64-linux-gnu.so │ │ └── temp.linux-x86_64-cpython-311 │ │ │ ├── .ninja_deps │ │ │ └── .ninja_log │ ├── dist │ │ └── stochastic_rounding_cuda-0.0.0-py3.11-linux-x86_64.egg │ ├── readme.md │ ├── setup.py │ ├── src │ │ ├── stochastic_rounding.cu │ │ ├── stochastic_rounding.hpp │ │ └── stochastic_rounding_cuda.cu │ ├── test.md │ ├── tests │ │ ├── benchmark.py │ │ └── core_unit_tests.py │ ├── usage.py │ └── usage2.py └── triton_groupGEMM │ ├── groupgemm.py │ ├── testing │ ├── base_testing.py │ └── unit_tests.py │ ├── tma_utils.py │ └── triton_tutorial_groupgemm.py ├── kernels ├── MoE │ └── group_GEMM │ │ └── triton │ │ ├── readme.md │ │ ├── testing │ │ ├── fast_verification.py │ │ └── pytorch_reference_backwards.py │ │ ├── tgroup_gemm_backwards.py │ │ ├── tgroup_gemm_forward.py │ │ └── utils │ │ └── tma_utils.py ├── blackwell │ ├── cute_gemm_01 │ │ ├── Makefile │ │ ├── build │ │ │ ├── lib.linux-x86_64-cpython-312 │ │ │ │ └── sm100_gemm.cpython-312-x86_64-linux-gnu.so │ │ │ └── temp.linux-x86_64-cpython-312 │ │ │ │ ├── .ninja_deps │ │ │ │ ├── .ninja_log │ │ │ │ ├── build.ninja │ │ │ │ ├── sm100_gemm.o │ │ │ │ └── sm100_gemm_pytorch.o │ │ ├── dist │ │ │ └── sm100_gemm-0.0.0-py3.12-linux-x86_64.egg │ │ ├── driver.py │ │ ├── setup.py │ │ ├── sm100_gemm.cu │ │ ├── sm100_gemm.egg-info │ │ │ ├── PKG-INFO │ │ │ ├── SOURCES.txt │ │ │ ├── dependency_links.txt │ │ │ ├── not-zip-safe │ │ │ ├── requires.txt │ │ │ └── top_level.txt │ │ ├── sm100_gemm.h │ │ └── sm100_gemm_pytorch.cpp │ └── cute_gemm_02_tma │ │ ├── build │ │ ├── lib.linux-x86_64-cpython-312 │ │ │ └── sm100_gemm.cpython-312-x86_64-linux-gnu.so │ │ └── temp.linux-x86_64-cpython-312 │ │ │ ├── .ninja_deps │ │ │ ├── .ninja_log │ │ │ ├── build.ninja │ │ │ ├── sm100_gemm.o │ │ │ └── sm100_gemm_pytorch.o │ │ ├── dist │ │ └── sm100_gemm-0.0.0-py3.12-linux-x86_64.egg │ │ ├── driver.py │ │ ├── setup.py │ │ ├── sm100_gemm.cu │ │ ├── sm100_gemm.egg-info │ │ ├── PKG-INFO │ │ ├── SOURCES.txt │ │ ├── dependency_links.txt │ │ ├── not-zip-safe │ │ ├── requires.txt │ │ └── top_level.txt │ │ ├── sm100_gemm.h │ │ └── sm100_gemm_pytorch.cpp ├── cuda │ ├── cutlass_gemm │ │ ├── broadcast_load_epilogue_c3x.hpp │ │ ├── common.hpp │ │ ├── cutlass.cpp │ │ ├── cutlass_kernel.cu │ │ ├── readme.md │ │ ├── setup.py │ │ └── test_cutlass_gemm.py │ ├── inference │ │ ├── README.md │ │ └── hadamard_transform │ │ │ ├── hadamard_transform.cpp │ │ │ ├── hadamard_transform_cuda.cu │ │ │ ├── setup.py │ │ │ └── test.py │ ├── training │ │ └── README.md │ └── tutorials │ │ ├── README.md │ │ └── flash2.cu ├── needs_perf_help │ ├── fp8_gemm_bench.py │ └── fp8_rowwise_tma_persistent.py └── triton │ ├── inference │ ├── .DS_Store │ ├── README.md │ ├── col_major_moe_gemm │ │ ├── README.md │ │ ├── perf_test_moe.py │ │ ├── profile_moe.py │ │ ├── results.html │ │ ├── test.csv │ │ ├── test.png │ │ ├── test_moe_gemm.py │ │ ├── v0_moe_fused.py │ │ ├── v1_moe_fused.py │ │ └── v2_moe_fused.py │ ├── flash_attention │ │ └── stay_attention.py │ ├── fp8 │ │ ├── 
float8_groupwise_quant.py │ │ ├── scaled_fp8_gemm.py │ │ ├── splitk_gemm_fp8.py │ │ └── tma_gemm.py │ ├── gptq │ │ ├── a100_qlinear.py │ │ ├── benchmark.py │ │ ├── h100_qlinear.py │ │ ├── mixtral │ │ │ ├── test_dequant_moe_gemm.py │ │ │ └── w4a16_fused_dequant_gemm.py │ │ ├── small_benchmark_cuda_graphs.py │ │ └── splitk_dequant_gemm.py │ ├── mamba │ │ └── causal_1d_conv │ │ │ ├── causal_1d_conv │ │ │ └── causal_1d_conv.py │ │ │ └── tests │ │ │ └── test_causal_1d_conv.py │ ├── paged_attention │ │ └── attention_triton.py │ └── torch_compile │ │ └── flash_backward.py │ ├── training │ ├── README.md │ ├── fused_softmax │ │ ├── README.md │ │ └── softmax.py │ └── rms_norm │ │ └── fused_rms_norm.py │ └── tutorials │ └── README.md ├── readme.md └── tutorials └── triton ├── kernels ├── __init__.py ├── flash_attention_fwd.py ├── fused_softmax.py ├── readme.md └── vector_add.py └── tests ├── test_softmax.py └── test_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | **/.ipynb_checkpoints 3 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 
50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Applied AI 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Our Development Process 6 | ... (in particular how this is synced with internal changes to the project) 7 | 8 | ## Pull Requests 9 | We actively welcome your pull requests. 10 | 11 | 1. Fork the repo and create your branch from `main`. 12 | 2. If you've added code that should be tested, add tests. 13 | 3. If you've changed APIs, update the documentation. 14 | 4. Ensure the test suite passes. 15 | 5. Make sure your code lints. 16 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 17 | 18 | ## Contributor License Agreement ("CLA") 19 | In order to accept your pull request, we need you to submit a CLA. You only need 20 | to do this once to work on any of Meta's open source projects. 21 | 22 | Complete your CLA here: 23 | 24 | ## Issues 25 | We use GitHub issues to track public bugs. Please ensure your description is 26 | clear and has sufficient instructions to be able to reproduce the issue. 27 | 28 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe 29 | disclosure of security bugs. In those cases, please go through the process 30 | outlined on that page and do not file a public issue. 31 | 32 | ## Coding Style 33 | * 2 spaces for indentation rather than tabs 34 | * 80 character line length 35 | * ... 36 | 37 | ## License 38 | By contributing to applied-ai, you agree that your contributions will be licensed 39 | under the LICENSE file in the root directory of this source tree. 
40 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2024 Meta 2 | 3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 4 | 5 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 6 | 7 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 8 | 9 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 10 | 11 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 12 | -------------------------------------------------------------------------------- /assets/images/dev-discuss-asynctp/Figure_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch-labs/applied-ai/4f8c39eb70d92e3848c65347ca077c8e723c74fd/assets/images/dev-discuss-asynctp/Figure_1.png -------------------------------------------------------------------------------- /assets/images/dev-discuss-asynctp/Figure_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch-labs/applied-ai/4f8c39eb70d92e3848c65347ca077c8e723c74fd/assets/images/dev-discuss-asynctp/Figure_10.png -------------------------------------------------------------------------------- /assets/images/dev-discuss-asynctp/Figure_11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch-labs/applied-ai/4f8c39eb70d92e3848c65347ca077c8e723c74fd/assets/images/dev-discuss-asynctp/Figure_11.png -------------------------------------------------------------------------------- /assets/images/dev-discuss-asynctp/Figure_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch-labs/applied-ai/4f8c39eb70d92e3848c65347ca077c8e723c74fd/assets/images/dev-discuss-asynctp/Figure_2.png -------------------------------------------------------------------------------- /assets/images/dev-discuss-asynctp/Figure_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch-labs/applied-ai/4f8c39eb70d92e3848c65347ca077c8e723c74fd/assets/images/dev-discuss-asynctp/Figure_3.png -------------------------------------------------------------------------------- 
/assets/images/dev-discuss-asynctp/Figure_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch-labs/applied-ai/4f8c39eb70d92e3848c65347ca077c8e723c74fd/assets/images/dev-discuss-asynctp/Figure_4.png -------------------------------------------------------------------------------- /assets/images/dev-discuss-asynctp/Figure_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch-labs/applied-ai/4f8c39eb70d92e3848c65347ca077c8e723c74fd/assets/images/dev-discuss-asynctp/Figure_5.png -------------------------------------------------------------------------------- /assets/images/dev-discuss-asynctp/Figure_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch-labs/applied-ai/4f8c39eb70d92e3848c65347ca077c8e723c74fd/assets/images/dev-discuss-asynctp/Figure_6.png -------------------------------------------------------------------------------- /assets/images/dev-discuss-asynctp/Figure_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch-labs/applied-ai/4f8c39eb70d92e3848c65347ca077c8e723c74fd/assets/images/dev-discuss-asynctp/Figure_7.png -------------------------------------------------------------------------------- /assets/images/dev-discuss-asynctp/Figure_8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch-labs/applied-ai/4f8c39eb70d92e3848c65347ca077c8e723c74fd/assets/images/dev-discuss-asynctp/Figure_8.png -------------------------------------------------------------------------------- /assets/images/dev-discuss-asynctp/Figure_9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch-labs/applied-ai/4f8c39eb70d92e3848c65347ca077c8e723c74fd/assets/images/dev-discuss-asynctp/Figure_9.png -------------------------------------------------------------------------------- /assets/images/dev-discuss-asynctp/readme.md: -------------------------------------------------------------------------------- 1 | This folder is for hosting the images for the AsyncTP public post at: 2 | [https://discuss.pytorch.org/t/distributed-w-torchtitan-introducing-async-tensor-parallelism-in-pytorch/209487](https://discuss.pytorch.org/t/distributed-w-torchtitan-introducing-async-tensor-parallelism-in-pytorch/209487) 3 | -------------------------------------------------------------------------------- /assets/images/moe_gemm_h100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch-labs/applied-ai/4f8c39eb70d92e3848c65347ca077c8e723c74fd/assets/images/moe_gemm_h100.png -------------------------------------------------------------------------------- /assets/images/readme.md: -------------------------------------------------------------------------------- 1 | Folder for housing images for the readmes. 
2 | -------------------------------------------------------------------------------- /assets/images/softmax_fused.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch-labs/applied-ai/4f8c39eb70d92e3848c65347ca077c8e723c74fd/assets/images/softmax_fused.png -------------------------------------------------------------------------------- /dev/sr/.gitignore: -------------------------------------------------------------------------------- 1 | *.o 2 | *.ninja 3 | *.txt 4 | *.egg-info 5 | *.ninja-deps 6 | *.ninja-log/ 7 | *.so 8 | dist/ 9 | build/ 10 | -------------------------------------------------------------------------------- /dev/sr/build/lib.linux-x86_64-cpython-311/stochastic_rounding_cuda.cpython-311-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch-labs/applied-ai/4f8c39eb70d92e3848c65347ca077c8e723c74fd/dev/sr/build/lib.linux-x86_64-cpython-311/stochastic_rounding_cuda.cpython-311-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /dev/sr/build/temp.linux-x86_64-cpython-311/.ninja_deps: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch-labs/applied-ai/4f8c39eb70d92e3848c65347ca077c8e723c74fd/dev/sr/build/temp.linux-x86_64-cpython-311/.ninja_deps -------------------------------------------------------------------------------- /dev/sr/build/temp.linux-x86_64-cpython-311/.ninja_log: -------------------------------------------------------------------------------- 1 | # ninja log v5 2 | 0 35090 1739831912112986452 /data/users/less/applied-ai/dev/sr/build/temp.linux-x86_64-cpython-311/src/stochastic_rounding_cuda.o a0935ce4b779ec1a 3 | 0 38411 1739831915434998711 /data/users/less/applied-ai/dev/sr/build/temp.linux-x86_64-cpython-311/src/stochastic_rounding.o 2e266ac627e9fc88 4 | 17 35160 1739833265433057983 /data/users/less/applied-ai/dev/sr/build/temp.linux-x86_64-cpython-311/src/stochastic_rounding_cuda.o a0935ce4b779ec1a 5 | 17 34989 1739834841156897402 /data/users/less/applied-ai/dev/sr/build/temp.linux-x86_64-cpython-311/src/stochastic_rounding_cuda.o a0935ce4b779ec1a 6 | 17 35109 1739836443932926241 /data/users/less/applied-ai/dev/sr/build/temp.linux-x86_64-cpython-311/src/stochastic_rounding_cuda.o a0935ce4b779ec1a 7 | 0 38226 1739931693048424587 /data/users/less/applied-ai/dev/sr/build/temp.linux-x86_64-cpython-311/src/stochastic_rounding.o 2e266ac627e9fc88 8 | 0 35010 1740337115885126351 /data/users/less/applied-ai/dev/sr/build/temp.linux-x86_64-cpython-311/src/stochastic_rounding_cuda.o a0935ce4b779ec1a 9 | 0 38581 1740337119456139855 /data/users/less/applied-ai/dev/sr/build/temp.linux-x86_64-cpython-311/src/stochastic_rounding.o 2e266ac627e9fc88 10 | 14 38912 1740337736128509775 /data/users/less/applied-ai/dev/sr/build/temp.linux-x86_64-cpython-311/src/stochastic_rounding.o 2e266ac627e9fc88 11 | 8 38980 1740338025101622234 /data/users/less/applied-ai/dev/sr/build/temp.linux-x86_64-cpython-311/src/stochastic_rounding.o 2e266ac627e9fc88 12 | 9 38555 1740338553641634401 /data/users/less/applied-ai/dev/sr/build/temp.linux-x86_64-cpython-311/src/stochastic_rounding.o 2e266ac627e9fc88 13 | 16 38682 1740343653861211571 /data/users/less/applied-ai/dev/sr/build/temp.linux-x86_64-cpython-311/src/stochastic_rounding.o 2e266ac627e9fc88 14 | 8 34981 1740343690705355179 
/data/users/less/applied-ai/dev/sr/build/temp.linux-x86_64-cpython-311/src/stochastic_rounding_cuda.o a0935ce4b779ec1a 15 | 9 38456 1740343918575244375 /data/users/less/applied-ai/dev/sr/build/temp.linux-x86_64-cpython-311/src/stochastic_rounding.o 2e266ac627e9fc88 16 | 8 38453 1740345709608098871 /data/users/less/applied-ai/dev/sr/build/temp.linux-x86_64-cpython-311/src/stochastic_rounding.o 2e266ac627e9fc88 17 | 17 38035 1740347710429769430 /data/users/less/applied-ai/dev/sr/build/temp.linux-x86_64-cpython-311/src/stochastic_rounding.o 2e266ac627e9fc88 18 | 9 38013 1740347825697202212 /data/users/less/applied-ai/dev/sr/build/temp.linux-x86_64-cpython-311/src/stochastic_rounding.o 2e266ac627e9fc88 19 | 9 34149 1740347987565817403 /data/users/less/applied-ai/dev/sr/build/temp.linux-x86_64-cpython-311/src/stochastic_rounding_cuda.o a0935ce4b779ec1a 20 | 9 37820 1740347991234831597 /data/users/less/applied-ai/dev/sr/build/temp.linux-x86_64-cpython-311/src/stochastic_rounding.o 2e266ac627e9fc88 21 | 9 38024 1740348243785799422 /data/users/less/applied-ai/dev/sr/build/temp.linux-x86_64-cpython-311/src/stochastic_rounding.o 2e266ac627e9fc88 22 | 17 37664 1740348867502214064 /data/users/less/applied-ai/dev/sr/build/temp.linux-x86_64-cpython-311/src/stochastic_rounding.o 2e266ac627e9fc88 23 | 9 38031 1740349024425819484 /data/users/less/applied-ai/dev/sr/build/temp.linux-x86_64-cpython-311/src/stochastic_rounding.o 2e266ac627e9fc88 24 | 9 37981 1740349398778267158 /data/users/less/applied-ai/dev/sr/build/temp.linux-x86_64-cpython-311/src/stochastic_rounding.o 2e266ac627e9fc88 25 | 9 37728 1740349508335689703 /data/users/less/applied-ai/dev/sr/build/temp.linux-x86_64-cpython-311/src/stochastic_rounding.o 2e266ac627e9fc88 26 | -------------------------------------------------------------------------------- /dev/sr/dist/stochastic_rounding_cuda-0.0.0-py3.11-linux-x86_64.egg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch-labs/applied-ai/4f8c39eb70d92e3848c65347ca077c8e723c74fd/dev/sr/dist/stochastic_rounding_cuda-0.0.0-py3.11-linux-x86_64.egg -------------------------------------------------------------------------------- /dev/sr/readme.md: -------------------------------------------------------------------------------- 1 | Branch for stochastic rounding kernel 2 | Currently processes 4 elements per thread to leverage rand4 3 | -------------------------------------------------------------------------------- /dev/sr/setup.py: -------------------------------------------------------------------------------- 1 | 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | 5 | setup( 6 | name='stochastic_rounding_cuda', 7 | version='0.1.021825', 8 | ext_modules=[ 9 | CUDAExtension('stochastic_rounding_cuda', [ 10 | 'src/stochastic_rounding.cu', 11 | 'src/stochastic_rounding_cuda.cu' 12 | ], 13 | extra_compile_args={ 14 | 'cxx': ['-O3'], 15 | 'nvcc': [ 16 | '-O3', 17 | '--expt-relaxed-constexpr', # better template support 18 | #'-gencode=arch=compute_70,code=sm_70', # Volta 19 | #'-gencode=arch=compute_75,code=sm_75', # Turing 20 | #'-gencode=arch=compute_80,code=sm_80' # Amper 21 | #'-gencode=arch=compute_86,code=sm_86' # Ampere 22 | '-gencode=arch=compute_90,code=sm_90', # Hopper 23 | ] 24 | }) 25 | ], 26 | cmdclass={ 27 | 'build_ext': BuildExtension 28 | } 29 | ) 30 | -------------------------------------------------------------------------------- 
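The CUDA sources that follow implement the rounding rule: keep the high 16 bits of the float32 bit pattern (which is exactly the bfloat16 representation) and add one ulp with probability equal to the discarded low 16 bits divided by 2^16 (`float_to_bf16_stochastic`), with each thread consuming one `uint4` of Philox output to round 4 elements, as noted in the readme above. As an illustration only, here is a minimal PyTorch sketch of that rule; `stochastic_round_bf16_reference` is a name chosen for this example and is not part of the extension.

```python
import torch

def stochastic_round_bf16_reference(x: torch.Tensor) -> torch.Tensor:
    """Bit-level sketch of the rule in float_to_bf16_stochastic:
    keep the high 16 bits of the float32 pattern, and add one bf16 ulp
    with probability (low 16 bits) / 2**16."""
    bits = x.float().contiguous().view(torch.int32)   # reinterpret fp32 bits
    low = bits & 0xFFFF                                # discarded mantissa bits
    rand = torch.randint(0, 1 << 16, bits.shape,
                         device=x.device, dtype=torch.int32)
    bump = (rand < low).to(torch.int32) * 0x10000      # stochastic round-up
    out_bits = (bits & -0x10000) + bump                # -0x10000 == ~0xFFFF
    return out_bits.view(torch.float32).to(torch.bfloat16)
```

Because the truncated bit pattern is exactly representable in bfloat16, the final cast is lossless; all of the randomness is confined to the single-ulp bump.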
/dev/sr/src/stochastic_rounding.cu: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | #include "stochastic_rounding.hpp" 4 | #include 5 | 6 | namespace py = pybind11; 7 | 8 | __host__ int getOptimalBlockSize() { 9 | cudaDeviceProp prop; 10 | cudaGetDeviceProperties(&prop, 0); 11 | return std::min(prop.maxThreadsPerBlock, 256); 12 | } 13 | 14 | torch::Tensor stochastic_round_bf16_cuda(torch::Tensor input, bool requires_grad) { 15 | TORCH_CHECK(input.is_cuda(), "Input tensor must be on CUDA device"); 16 | TORCH_CHECK(input.is_contiguous(), "Input tensor must be contiguous"); 17 | TORCH_CHECK(input.scalar_type() == torch::kFloat32, "Input tensor must be float32"); 18 | 19 | const int threads_per_block = 256; 20 | const int num_elements = input.numel(); 21 | const int elements_per_thread = 4; 22 | 23 | const int min_blocks = (num_elements + elements_per_thread * threads_per_block - 1) / 24 | (elements_per_thread * threads_per_block); 25 | 26 | cudaDeviceProp prop; 27 | cudaGetDeviceProperties(&prop, 0); 28 | const int blocks_per_sm = 4; 29 | const int min_blocks_for_sms = prop.multiProcessorCount * blocks_per_sm; 30 | const int num_blocks = std::max(min_blocks, min_blocks_for_sms); 31 | 32 | auto options = torch::TensorOptions() 33 | .dtype(torch::kBFloat16) 34 | .device(input.device()) 35 | .requires_grad(requires_grad); 36 | auto output = torch::empty_like(input, options); 37 | 38 | std::random_device rd; 39 | std::mt19937_64 gen(rd()); 40 | std::uniform_int_distribution dis; 41 | const unsigned long long seed = dis(gen); 42 | 43 | stochastic_round_bf16<<>>( 44 | input.data_ptr(), 45 | reinterpret_cast<__nv_bfloat16*>(output.data_ptr()), 46 | num_elements, 47 | seed); 48 | 49 | cudaError_t err = cudaGetLastError(); 50 | TORCH_CHECK(err == cudaSuccess, 51 | "CUDA kernel execution failed: ", cudaGetErrorString(err)); 52 | 53 | return output; 54 | } 55 | 56 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 57 | m.def("stochastic_round_bf16", 58 | static_cast(&stochastic_round_bf16_cuda), 59 | "Stochastic rounding to BFloat16", 60 | py::arg("input"), 61 | py::arg("requires_grad") = false); 62 | } 63 | -------------------------------------------------------------------------------- /dev/sr/src/stochastic_rounding.hpp: -------------------------------------------------------------------------------- 1 | 2 | #pragma once 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | namespace philox { 10 | constexpr unsigned int W32_0 = 0x9E3779B9; 11 | constexpr unsigned int W32_1 = 0xBB67AE85; 12 | constexpr unsigned int M0 = 0xD2511F53; 13 | constexpr unsigned int M1 = 0xCD9E8D57; 14 | constexpr int ROUNDS = 7; 15 | } 16 | 17 | // Forward declarations 18 | class PhiloxGenerator { 19 | public: 20 | __device__ __forceinline__ PhiloxGenerator(); 21 | __device__ __forceinline__ void init(const unsigned long long seed, const unsigned int thread_id); 22 | __device__ __forceinline__ uint4 next(); 23 | private: 24 | uint2 key; 25 | uint4 counter; 26 | static __device__ __forceinline__ uint2 mulhilo(const unsigned int a, const unsigned int b); 27 | static __device__ __forceinline__ uint4 round(uint4 ctr, uint2 key); 28 | }; 29 | 30 | // CUDA kernel declaration 31 | __global__ void stochastic_round_bf16( 32 | float *__restrict__ input, 33 | __nv_bfloat16 *__restrict__ output, 34 | const int size, 35 | const unsigned long long seed); 36 | 37 | // Host functions 38 | __host__ int getOptimalBlockSize(); 39 | torch::Tensor 
stochastic_round_bf16_cuda(torch::Tensor input, bool requires_grad = false); 40 | -------------------------------------------------------------------------------- /dev/sr/src/stochastic_rounding_cuda.cu: -------------------------------------------------------------------------------- 1 | #include "stochastic_rounding.hpp" 2 | #include 3 | 4 | // Philox RNG implementation 5 | 6 | __device__ __forceinline__ PhiloxGenerator::PhiloxGenerator() : 7 | key(make_uint2(0, 0)), 8 | counter(make_uint4(0, 0, 0, 0)) {} 9 | 10 | __device__ __forceinline__ void PhiloxGenerator::init(const unsigned long long seed, const unsigned int thread_id) { 11 | key.x = static_cast(seed); 12 | key.y = static_cast(seed >> 32); 13 | counter = make_uint4(thread_id, 0, 0, 0); 14 | __threadfence_block(); 15 | } 16 | 17 | __device__ __forceinline__ uint2 PhiloxGenerator::mulhilo(const unsigned int a, const unsigned int b) { 18 | uint2 result; 19 | unsigned long long prod; 20 | asm("mul.wide.u32 %0, %1, %2;" : "=l"(prod) : "r"(a), "r"(b)); 21 | result.x = static_cast(prod); 22 | result.y = static_cast(prod >> 32); 23 | return result; 24 | } 25 | 26 | __device__ __forceinline__ uint4 PhiloxGenerator::round(uint4 ctr, uint2 key) { 27 | const uint2 mul0 = mulhilo(philox::M0, ctr.x); 28 | const uint2 mul1 = mulhilo(philox::M1, ctr.z); 29 | 30 | return make_uint4( 31 | mul1.y ^ ctr.y ^ key.x, 32 | mul1.x, 33 | mul0.y ^ ctr.w ^ key.y, 34 | mul0.x 35 | ); 36 | } 37 | 38 | __device__ __forceinline__ uint4 PhiloxGenerator::next() { 39 | uint4 ctr = counter; 40 | uint2 k = key; 41 | 42 | #pragma unroll 43 | for (int i = 0; i < philox::ROUNDS; ++i) { 44 | ctr = round(ctr, k); 45 | k.x += philox::W32_0; 46 | k.y += philox::W32_1; 47 | } 48 | 49 | counter.x += 4; 50 | return ctr; 51 | } 52 | 53 | __device__ __forceinline__ __nv_bfloat16 float_to_bf16_stochastic(const float value, const uint32_t rand) { 54 | const uint32_t val_bits = __float_as_uint(value); 55 | const uint32_t rounding_bits = val_bits & 0xFFFF; 56 | uint32_t result = val_bits & 0xFFFF0000u; 57 | result += (rand & 0xFFFF) < rounding_bits ? 
0x10000u : 0; 58 | return __float2bfloat16(__uint_as_float(result)); 59 | } 60 | 61 | __device__ __forceinline__ void float4_to_bf16_stochastic( 62 | const float4& values, 63 | uint4& rand_vals, 64 | __nv_bfloat16* output) { 65 | 66 | float vals[4] = {values.x, values.y, values.z, values.w}; 67 | uint32_t rands[4] = {rand_vals.x, rand_vals.y, rand_vals.z, rand_vals.w}; 68 | 69 | #pragma unroll 70 | for (int i = 0; i < 4; i++) { 71 | output[i] = float_to_bf16_stochastic(vals[i], rands[i]); 72 | } 73 | } 74 | 75 | __global__ void stochastic_round_bf16( 76 | float *__restrict__ input, 77 | __nv_bfloat16 *__restrict__ output, 78 | const int size, 79 | const unsigned long long seed) { 80 | 81 | PhiloxGenerator rng; 82 | rng.init(seed, blockIdx.x * blockDim.x + threadIdx.x); 83 | 84 | int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4; 85 | int stride = blockDim.x * gridDim.x * 4; 86 | 87 | float4 values; 88 | __nv_bfloat16 local_output[4]; 89 | 90 | // Process full vectors of 4 elements 91 | for (; idx <= size - 4; idx += stride) { 92 | values = *reinterpret_cast(&input[idx]); 93 | uint4 rand = rng.next(); 94 | float4_to_bf16_stochastic(values, rand, local_output); 95 | 96 | for (int j = 0; j < 4; j++) { 97 | output[idx + j] = local_output[j]; 98 | } 99 | } 100 | 101 | // Handle remaining elements 102 | if (idx < size) { 103 | float remaining_values[4] = {0.0f, 0.0f, 0.0f, 0.0f}; 104 | int remainder = size - idx; 105 | 106 | for (int j = 0; j < remainder; j++) { 107 | remaining_values[j] = input[idx + j]; 108 | } 109 | 110 | values.x = remaining_values[0]; 111 | values.y = remaining_values[1]; 112 | values.z = remaining_values[2]; 113 | values.w = remaining_values[3]; 114 | 115 | uint4 rand = rng.next(); 116 | float4_to_bf16_stochastic(values, rand, local_output); 117 | 118 | for (int j = 0; j < remainder; j++) { 119 | output[idx + j] = local_output[j]; 120 | } 121 | } 122 | } 123 | -------------------------------------------------------------------------------- /dev/sr/test.md: -------------------------------------------------------------------------------- 1 | (tkdev11) [less@devgpu115.cco2 ~/local/applied-ai/dev/sr (sr_kernel)]$ python usage.py 2 | Launching kernel with blocks=1, threads_per_block=256, num_elements=12 3 | Input tensor: tensor([ 0.3282, -0.4513, -1.0612, 0.1446, -0.8440, -1.4669, -0.7135, -0.6183, 4 | -2.2411, 2.1464, 1.4772, -1.3564], device='cuda:0') 5 | Output tensor: tensor([ 0.3281, -0.4512, -1.0625, 0.1445, -0.8438, -1.4688, -0.7109, -0.6172, 6 | -2.2344, 2.1406, 1.4766, -1.3516], device='cuda:0', 7 | dtype=torch.bfloat16) 8 | Output tensor dtype: torch.bfloat16 9 | Success! 
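The transcript above shows a single pass. The statistical property checked by the unit tests below — the probability of rounding up equals the normalized distance to the lower bfloat16 neighbour — also makes the rounding unbiased in expectation, unlike a plain round-to-nearest cast. A quick, illustrative check (the input value is taken from `core_unit_tests.py`; the expected outputs in the comments are approximate):

```python
import torch
import stochastic_rounding_cuda

# 2.1999969482421875 sits between the bf16 neighbours 2.1875 and 2.203125.
x = torch.full((100_000,), 2.1999969482421875, device="cuda")

rounded = stochastic_rounding_cuda.stochastic_round_bf16(x).float()
print(rounded.unique())         # tensor([2.1875, 2.2031], ...)
print(rounded.mean().item())    # ~2.1999..., close to the input on average

print(x.to(torch.bfloat16).float().mean().item())  # 2.203125: the plain cast is biased here
```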
10 | -------------------------------------------------------------------------------- /dev/sr/tests/benchmark.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import stochastic_rounding_cuda 3 | import numpy as np 4 | import time 5 | from tabulate import tabulate 6 | import argparse 7 | 8 | def measure_performance(func, input_tensor, warmup=0, repeats=1): 9 | """Measure performance of a function with proper CUDA synchronization""" 10 | # Warmup 11 | for _ in range(warmup): 12 | output = func(input_tensor) 13 | 14 | torch.cuda.synchronize() 15 | start = time.perf_counter() 16 | 17 | for _ in range(repeats): 18 | output = func(input_tensor) 19 | 20 | torch.cuda.synchronize() 21 | end = time.perf_counter() 22 | 23 | avg_time = (end - start) / repeats 24 | elements_per_second = input_tensor.numel() / avg_time 25 | return avg_time, elements_per_second 26 | 27 | def benchmark_sizes(sizes= [1000, 10000, 100000, 1000000, 10000000, (10000000*10), (10000000*100)]): 28 | #[ 50,000,000]): # 29 | """Benchmark different input sizes""" 30 | results = [] 31 | 32 | for size in sizes: 33 | # Create input tensor 34 | x = torch.randn(size, device='cuda') 35 | 36 | # Measure stochastic rounding 37 | time_stoch, throughput_stoch = measure_performance( 38 | stochastic_rounding_cuda.stochastic_round_bf16, x) 39 | 40 | # Measure regular BF16 casting 41 | time_regular, throughput_regular = measure_performance( 42 | lambda t: t.to(torch.bfloat16), x) 43 | 44 | results.append([ 45 | size, 46 | time_stoch * 1000, # convert to ms 47 | throughput_stoch / 1e6, # convert to GElements/s 48 | time_regular * 1000, 49 | throughput_regular / 1e6, 50 | throughput_regular / throughput_stoch # speedup 51 | ]) 52 | 53 | print("\nSize Comparison:") 54 | print(tabulate(results, 55 | headers=['Size', 'Stoch Time (ms)', 'Stoch ME/s', 56 | 'Regular Time (ms)', 'Regular ME/s', 'Casting faster by'], 57 | floatfmt='.3f')) 58 | 59 | def benchmark_shapes(total_size=1000000): 60 | """Benchmark different tensor shapes with same total size""" 61 | shapes = [ 62 | (total_size,), # 1D 63 | (1000, total_size//1000), # 2D 64 | (100, 100, total_size//10000), # 3D 65 | ] 66 | 67 | results = [] 68 | for shape in shapes: 69 | x = torch.randn(*shape, device='cuda') 70 | time_stoch, throughput_stoch = measure_performance( 71 | stochastic_rounding_cuda.stochastic_round_bf16, x) 72 | 73 | results.append([ 74 | 'x'.join(str(d) for d in shape), 75 | time_stoch * 1000, 76 | throughput_stoch / 1e9 77 | ]) 78 | 79 | print("\nShape Comparison (same total size):") 80 | print(tabulate(results, 81 | headers=['Shape', 'Time (ms)', 'GElements/s'], 82 | floatfmt='.3f')) 83 | 84 | def stress_test(duration=10): 85 | """Run a stress test for specified duration""" 86 | print(f"\nRunning stress test for {duration} seconds...") 87 | 88 | size = 1000000 89 | x = torch.randn(size, device='cuda') 90 | start_time = time.time() 91 | iterations = 0 92 | 93 | while time.time() - start_time < duration: 94 | stochastic_rounding_cuda.stochastic_round_bf16(x) 95 | iterations += 1 96 | 97 | print(f"Completed {iterations} iterations without errors") 98 | print(f"Average throughput: {(iterations * size) / (duration * 1e9):.2f} GElements/s") 99 | 100 | def memory_test(max_size=1e9): 101 | """Test memory scaling""" 102 | sizes = np.logspace(3, min(9, np.log10(max_size)), num=7, dtype=int) 103 | results = [] 104 | 105 | for size in sizes: 106 | try: 107 | torch.cuda.empty_cache() 108 | x = torch.randn(size, device='cuda') 109 | 
torch.cuda.synchronize() 110 | 111 | # Measure peak memory during operation 112 | torch.cuda.reset_peak_memory_stats() 113 | _ = stochastic_rounding_cuda.stochastic_round_bf16(x) 114 | torch.cuda.synchronize() 115 | 116 | peak_memory = torch.cuda.max_memory_allocated() / 1e6 # MB 117 | results.append([size, peak_memory]) 118 | 119 | except RuntimeError as e: 120 | print(f"Out of memory at size {size}") 121 | break 122 | 123 | print("\nMemory Usage:") 124 | print(tabulate(results, 125 | headers=['Size', 'Peak Memory (MB)'], 126 | floatfmt='.2f')) 127 | 128 | def main(): 129 | parser = argparse.ArgumentParser(description='Benchmark stochastic rounding') 130 | parser.add_argument('--sizes', action='store_true', help='Run size benchmarks') 131 | parser.add_argument('--shapes', action='store_true', help='Run shape benchmarks') 132 | parser.add_argument('--stress', action='store_true', help='Run stress test') 133 | parser.add_argument('--memory', action='store_true', help='Run memory test') 134 | parser.add_argument('--all', action='store_true', help='Run all benchmarks') 135 | 136 | args = parser.parse_args() 137 | 138 | # Print device information 139 | device = torch.cuda.get_device_properties(0) 140 | print(f"\nRunning on: {device.name}") 141 | print(f"Compute Capability: {device.major}.{device.minor}") 142 | 143 | 144 | if args.all or args.sizes: 145 | benchmark_sizes() 146 | 147 | if args.all or args.shapes: 148 | benchmark_shapes() 149 | 150 | if args.all or args.stress: 151 | stress_test() 152 | 153 | if args.all or args.memory: 154 | memory_test() 155 | 156 | if __name__ == '__main__': 157 | main() 158 | -------------------------------------------------------------------------------- /dev/sr/tests/core_unit_tests.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from collections import Counter 4 | import unittest 5 | import stochastic_rounding_cuda 6 | import time 7 | 8 | class TestStochasticRounding(unittest.TestCase): 9 | def setup(self): 10 | # Ensure deterministic behavior for some tests 11 | torch.manual_seed(42) 12 | np.random.seed(42) 13 | 14 | def _test_rounding_statistics_helper(self, value, lower_value, upper_value, tensor_size=10000, rounds=100): 15 | """Helper method for testing stochastic rounding statistics""" 16 | print(f"\nInput value: {value}") 17 | MAX_VARIANCE = 0.03 18 | x = torch.full((tensor_size,), value, device='cuda') 19 | torch.cuda.manual_seed(42) 20 | 21 | # Single round test - isolate and show the round up and round down values 22 | single_result = stochastic_rounding_cuda.stochastic_round_bf16(x) 23 | print(f"Possible rounded values: {torch.unique(single_result)}") 24 | 25 | # Multiple rounds 26 | results = torch.empty((rounds, tensor_size), device='cuda', dtype=torch.bfloat16) 27 | for i in range(rounds): 28 | results[i] = stochastic_rounding_cuda.stochastic_round_bf16(x) 29 | 30 | prob_up = (results == upper_value).float().mean().item() 31 | print(f"Kernel's probability of rounding up: {prob_up:.4f}") 32 | 33 | distance_to_lower = abs(value - lower_value) 34 | total_distance = upper_value - lower_value 35 | expected_prob = distance_to_lower / total_distance 36 | print(f"Expected probability: {expected_prob:.4f}") 37 | 38 | self.assertTrue(abs(prob_up - expected_prob) < MAX_VARIANCE) 39 | 40 | def test_special_values(self): 41 | """Test handling of special values like inf, -inf, nan""" 42 | special_values = torch.tensor([float('inf'), float('-inf'), float('nan'), 0.0, -0.0], 43 | 
device='cuda') 44 | rounded = stochastic_rounding_cuda.stochastic_round_bf16(special_values) 45 | 46 | # Check inf and -inf are preserved 47 | self.assertTrue(torch.isinf(rounded[0])) 48 | self.assertTrue(torch.isinf(rounded[1])) 49 | self.assertTrue(rounded[0] > 0) 50 | self.assertTrue(rounded[1] < 0) 51 | 52 | # Check nan is preserved 53 | self.assertTrue(torch.isnan(rounded[2])) 54 | 55 | # Check zeros are preserved 56 | self.assertEqual(rounded[3].item(), 0.0) 57 | self.assertEqual(rounded[4].item(), 0.0) 58 | 59 | def test_small_values(self): 60 | """Test handling of small values near zero""" 61 | small_values = torch.tensor([1e-38, -1e-38, 1e-20, -1e-20], device='cuda') 62 | rounded = stochastic_rounding_cuda.stochastic_round_bf16(small_values) 63 | 64 | # Check that very small values are handled properly 65 | self.assertTrue(torch.all(torch.isfinite(rounded))) 66 | 67 | def test_vectorized_loading(self): 68 | """Test if vectorized loading works correctly for different tensor sizes""" 69 | sizes = [4, 8, 9, 16, 32, 100] # Test various sizes including non-aligned 70 | 71 | for size in sizes: 72 | x = torch.linspace(1, size, size, device='cuda') 73 | rounded = stochastic_rounding_cuda.stochastic_round_bf16(x) 74 | 75 | # Check output size matches input 76 | self.assertEqual(rounded.size(0), size) 77 | 78 | # Check dtype 79 | self.assertEqual(rounded.dtype, torch.bfloat16) 80 | 81 | def test_large_values(self): 82 | """Test handling of large values""" 83 | large_values = torch.tensor([1e38, -1e38, 1e20, -1e20], device='cuda') 84 | rounded = stochastic_rounding_cuda.stochastic_round_bf16(large_values) 85 | 86 | # Values should be preserved approximately in BF16 range 87 | self.assertTrue(torch.all(torch.isfinite(rounded))) 88 | 89 | def test_rounding_statistics(self): 90 | """Test if rounding probabilities match expected distribution""" 91 | self._test_rounding_statistics_helper(2.1999969482421875, 2.1875, 2.2031) 92 | 93 | def test_rounding_statistics_2(self): 94 | """Test stochastic rounding with different BF16 boundary values""" 95 | self._test_rounding_statistics_helper(1.7999992370605469, 1.7969, 1.8047) 96 | 97 | def test_rounding_statistics_small(self): 98 | """Test stochastic rounding for number between 0 and 1""" 99 | self._test_rounding_statistics_helper(0.7499847412109375, 0.7480, 0.7500) 100 | 101 | def test_rounding_statistics_large(self): 102 | """Test stochastic rounding for large number, over 100""" 103 | self._test_rounding_statistics_helper(128.99998474121094, 128.875, 129.000) 104 | 105 | 106 | 107 | if __name__ == '__main__': 108 | unittest.main(verbosity=2) 109 | -------------------------------------------------------------------------------- /dev/sr/usage.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import stochastic_rounding_cuda 3 | 4 | # Create input tensor 5 | input_tensor = torch.randn(12, device='cuda', dtype=torch.float32) 6 | 7 | # Apply stochastic rounding 8 | output_tensor = stochastic_rounding_cuda.stochastic_round_bf16(input_tensor) 9 | print(f"Input tensor: {input_tensor}") 10 | print(f"Output tensor: {output_tensor}") 11 | print(f"Output tensor dtype: {output_tensor.dtype}") 12 | print(f"Success!") 13 | 14 | ''' 15 | # Test tensor 16 | x = torch.tensor([9.8751e-01, -8.5288e-01, 1.6775e+00, -1.3683e+00, 17 | 4.0467e-01, 1.0759e-03, 2.8418e-01, -4.9392e-01, 18 | 8.7239e-01, -9.0545e-01, 1.1134e+00, 0], # -2.6872e+00 19 | device='cuda') 20 | 21 | # Convert to BF16 22 | y = 
stochastic_rounding_cuda.stochastic_round_bf16(x) 23 | print(f"Input: {x}") 24 | print(f"Output: {y}") 25 | ''' 26 | -------------------------------------------------------------------------------- /dev/sr/usage2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import stochastic_rounding_cuda 3 | 4 | # Test tensor 5 | x = torch.tensor([9.8751e-01, -8.5288e-01, 1.6775e+00], device='cuda') 6 | 7 | # Compare with regular rounding 8 | y_normal = x.to(torch.bfloat16) 9 | y_stochastic = stochastic_rounding_cuda.stochastic_round_bf16(x) 10 | 11 | print(f"Input: {x}") 12 | print(f"Normal BF16: {y_normal}") 13 | print(f"Stochastic BF16: {y_stochastic}") 14 | -------------------------------------------------------------------------------- /dev/triton_groupGEMM/testing/base_testing.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | import logging 10 | 11 | # Configure logging to print to stdout 12 | logging.basicConfig( 13 | level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" 14 | ) 15 | 16 | import os 17 | import sys 18 | import unittest 19 | from typing import Tuple 20 | 21 | import torch 22 | 23 | # Add parent directory to path 24 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 25 | 26 | 27 | if torch.cuda.is_available(): 28 | # from fp8_gemm import quantize_fp8_row 29 | from groupgemm import grouped_gemm # , grouped_gemm_fp8_rowwise 30 | from tma_utils import HAS_TMA_DESC 31 | 32 | 33 | @unittest.skipIf( 34 | not torch.cuda.is_available() 35 | or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9 36 | or not HAS_TMA_DESC, 37 | "Skip when H100 or TMA is not available", 38 | ) 39 | class TestGroupedGEMM(unittest.TestCase): 40 | def setUp(self) -> None: 41 | torch.manual_seed(0) 42 | 43 | """def test_grouped_gemm_fp8_rowwise(self) -> None: 44 | def _test_grouped_gemm_fp8_rowwise( 45 | shape: Tuple[int, int, int, int], 46 | device: torch.device, 47 | ) -> None: 48 | G, M, N, K = shape 49 | a = torch.randn(M, K, dtype=torch.bfloat16, device=device) 50 | b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device) 51 | m_ends, _ = torch.sort( 52 | torch.randint( 53 | low=0, high=M, size=[G - 1], device=device, dtype=torch.int32 54 | ) 55 | ) 56 | m_ends = m_ends.tolist() 57 | m_starts = [0] + m_ends 58 | m_ends = m_ends + [M] 59 | m_sizes = torch.tensor( 60 | [m_ends[i] - m_starts[i] for i in range(G)], device=device 61 | ).to(torch.int32) 62 | 63 | a_fp8, a_scale = quantize_fp8_row(a) 64 | b_fp8, b_scale = quantize_fp8_row(b) 65 | 66 | result = grouped_gemm_fp8_rowwise( 67 | a_fp8, 68 | b_fp8, 69 | m_sizes, 70 | a_scale, 71 | b_scale, 72 | ) 73 | self.assertTrue(result.shape == (M, N)) 74 | 75 | expected_result = torch.zeros(M, N, dtype=torch.bfloat16, device=device) 76 | # Running baseline with quantization to exclude quantization error from the test as it has nothing to do with the correctness of the kernel implementation. 
77 | for g in range(G): 78 | m_start = m_starts[g] 79 | m_end = m_ends[g] 80 | n_start = g * N 81 | n_end = (g + 1) * N 82 | 83 | expected_result[m_start:m_end, :] = ( 84 | a_fp8[m_start:m_end, :].to(torch.float32) 85 | @ b_fp8[n_start:n_end, :].to(torch.float32).T 86 | * a_scale[m_start:m_end][:, None] 87 | * b_scale[n_start:n_end][None, :] 88 | ).to(torch.bfloat16) 89 | 90 | torch.testing.assert_close(result, expected_result, atol=2e-2, rtol=1.6e-2) 91 | 92 | for G in (1, 4, 16): 93 | for M in (64, 512): 94 | logging.info(f"Testing FP8 GMM with G={G}, M={M}") 95 | _test_grouped_gemm_fp8_rowwise((G, M, 256, 256), torch.device("cuda")) 96 | """ 97 | 98 | def test_grouped_gemm_bf16(self) -> None: 99 | def _test_grouped_gemm_bf16( 100 | shape: Tuple[int, int, int, int], 101 | device: torch.device, 102 | ) -> None: 103 | G, M, N, K = shape 104 | a = torch.randn(M, K, dtype=torch.bfloat16, device=device) 105 | b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device) 106 | m_ends, _ = torch.sort( 107 | torch.randint( 108 | low=0, high=M, size=[G - 1], device=device, dtype=torch.int32 109 | ) 110 | ) 111 | m_ends = m_ends.tolist() 112 | m_starts = [0] + m_ends 113 | m_ends = m_ends + [M] 114 | m_sizes = torch.tensor( 115 | [m_ends[i] - m_starts[i] for i in range(G)], device=device 116 | ).to(torch.int32) 117 | 118 | result = grouped_gemm( 119 | a, 120 | b, 121 | m_sizes, 122 | ) 123 | self.assertTrue(result.shape == (M, N)) 124 | 125 | expected_result = torch.zeros(M, N, dtype=torch.bfloat16, device=device) 126 | for g in range(G): 127 | m_start = m_starts[g] 128 | m_end = m_ends[g] 129 | expected_result[m_start:m_end, :] = ( 130 | a[m_start:m_end, :] @ b[g * N : (g + 1) * N, :].T 131 | ) 132 | 133 | torch.testing.assert_close(result, expected_result, atol=1e-5, rtol=1.6e-2) 134 | 135 | for G in (1, 4, 16): 136 | for M in (64, 512): 137 | logging.info(f"Testing BF16 GMM with G={G}, M={M}") 138 | _test_grouped_gemm_bf16((G, M, 256, 256), torch.device("cuda")) 139 | 140 | 141 | if __name__ == "__main__": 142 | unittest.main(exit=False) 143 | -------------------------------------------------------------------------------- /dev/triton_groupGEMM/tma_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-unsafe 8 | # This code is derived from: https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu/experimental/gemm/triton_gemm 9 | 10 | import sys 11 | 12 | import torch 13 | import triton # @manual 14 | 15 | import triton.language as tl # @manual 16 | 17 | 18 | def map_dtype_to_triton(dtype: torch.dtype) -> tl.dtype: 19 | """ 20 | Maps torch dtype to triton dtype. 21 | 22 | Args: 23 | dtype (torch.dtype): input dtype. 24 | 25 | Returns: 26 | tl.dtype: triton dtype. 27 | """ 28 | if dtype == torch.float16: 29 | return tl.float16 30 | elif dtype == torch.bfloat16: 31 | return tl.bfloat16 32 | elif dtype == torch.float32: 33 | return tl.float32 34 | elif dtype == torch.int32: 35 | return tl.int32 36 | elif dtype == torch.float8_e4m3fn and torch.version.hip is None: 37 | return tl.float8e4nv 38 | else: 39 | raise ValueError(f"Unsupported dtype {dtype}") 40 | 41 | 42 | # check if we have the TMA version in Triton PR #4498 (https://github.com/triton-lang/triton/pull/4498). 
43 | HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl) 44 | 45 | if HAS_TMA_DESC: 46 | print( 47 | "TMA benchmarks will be running with experimental grid constant TMA descriptor.", 48 | file=sys.stderr, 49 | ) 50 | else: 51 | print( 52 | "Missing: This group gemm code will not run without TMA descriptor support....", 53 | file=sys.stderr, 54 | ) 55 | raise NotImplementedError("grouped Gemm without TMA is not supported") 56 | 57 | 58 | class TmaAutoTuneHelper: 59 | 60 | # duck typing wrapper to implement the same interface as TmaDescKernelParam in Triton PR #4498 61 | class KernelParamWrapper: 62 | def __init__(self, desc): 63 | self.desc = desc 64 | 65 | def tma_desc_cpu_ptr(self): 66 | return self.desc.data_ptr() 67 | 68 | TMA_SIZE = 128 69 | 70 | def __init__(self): 71 | self.fill_1d_tma_descriptor_inner = ( 72 | triton.runtime.driver.active.utils.fill_1d_tma_descriptor 73 | ) 74 | self.fill_2d_tma_descriptor_inner = ( 75 | triton.runtime.driver.active.utils.fill_2d_tma_descriptor 76 | ) 77 | if HAS_TMA_DESC: 78 | self.descriptors = {} 79 | else: 80 | self.cuda_descriptors = {} 81 | 82 | # Call this method outside of the lambda function for grid size 83 | def init_tma_descriptor(self, name): 84 | if HAS_TMA_DESC: 85 | self.descriptors[name] = torch.empty( 86 | TmaAutoTuneHelper.TMA_SIZE, device="cpu", dtype=torch.int8 87 | ) 88 | else: 89 | self.cuda_descriptors[name] = torch.empty( 90 | TmaAutoTuneHelper.TMA_SIZE, device="cuda", dtype=torch.int8 91 | ) 92 | 93 | # Call this method inside the lambda function for grid size 94 | def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size): 95 | if HAS_TMA_DESC: 96 | desc_x = self.descriptors[name] 97 | assert desc_x.data_ptr() % 64 == 0 98 | self.fill_1d_tma_descriptor_inner( 99 | ptr, dim, block_dim, element_size, desc_x.data_ptr() 100 | ) 101 | else: 102 | desc_x = self.cuda_descriptors[name] 103 | buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True) 104 | self.fill_1d_tma_descriptor_inner( 105 | ptr, dim, block_dim, element_size, buf_x.data_ptr() 106 | ) 107 | desc_x.copy_(buf_x, non_blocking=True) 108 | 109 | # Call this method inside the lambda function for grid size 110 | def fill_2d_tma_descriptor( 111 | self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size 112 | ): 113 | if HAS_TMA_DESC: 114 | desc_x = self.descriptors[name] 115 | assert desc_x.data_ptr() % 64 == 0 116 | self.fill_2d_tma_descriptor_inner( 117 | ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr() 118 | ) 119 | else: 120 | desc_x = self.cuda_descriptors[name] 121 | buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True) 122 | self.fill_2d_tma_descriptor_inner( 123 | ptr, dim1, dim0, block_dim1, block_dim0, element_size, buf_x.data_ptr() 124 | ) 125 | desc_x.copy_(buf_x, non_blocking=True) 126 | 127 | def get_tma_descriptor_kernel_param(self, name): 128 | if HAS_TMA_DESC: 129 | assert self.descriptors[name] is not None 130 | return self.KernelParamWrapper(self.descriptors[name]) 131 | else: 132 | assert self.cuda_descriptors[name] is not None 133 | return self.cuda_descriptors[name] 134 | -------------------------------------------------------------------------------- /kernels/MoE/group_GEMM/triton/readme.md: -------------------------------------------------------------------------------- 1 | ## Experimental 2 | 3 | Triton Group GEMM for supporting MoE training. 
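For reference, the forward semantics exercised by the tests below: token rows are stacked into `x` of shape `(M, K)` and partitioned by per-group row counts `m_sizes`; expert weights are stacked into `w` of shape `(G*N, K)`; group `g`'s rows are multiplied only by its own weight slice, filling the corresponding `N`-column block of an `(M, G*N)` output (mirroring the reference construction in `fast_verification.py`). The sketch below is illustrative only; `grouped_gemm_reference` is not a function exported by these kernels.

```python
import torch

def grouped_gemm_reference(x: torch.Tensor, w: torch.Tensor,
                           m_sizes: torch.Tensor) -> torch.Tensor:
    """Eager sketch of the grouped GEMM semantics used in the tests:
    rows of x are partitioned by m_sizes; group g uses w[g*N:(g+1)*N]."""
    G = m_sizes.numel()
    M, K = x.shape
    N = w.shape[0] // G
    out = torch.zeros(M, G * N, dtype=x.dtype, device=x.device)
    m_start = 0
    for g in range(G):
        m_end = m_start + int(m_sizes[g])
        if m_end > m_start:
            out[m_start:m_end, g * N:(g + 1) * N] = (
                x[m_start:m_end] @ w[g * N:(g + 1) * N].T
            )
        m_start = m_end
    return out
```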
4 | -------------------------------------------------------------------------------- /kernels/MoE/group_GEMM/triton/testing/fast_verification.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | 5 | # Configure logging 6 | logging.basicConfig( 7 | level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" 8 | ) 9 | 10 | # import the reference implementations 11 | from pytorch_reference_backwards import ( 12 | _compute_grad_w_pytorch, 13 | _compute_grad_x_pytorch, 14 | _pytorch_fallback_backward, 15 | _pytorch_reference_backward, 16 | ) 17 | 18 | # Import the grouped GEMM modules 19 | from tgrouped_gemm_backwards import grouped_gemm_backward 20 | from tgrouped_gemm_forward import grouped_gemm_forward as grouped_gemm 21 | 22 | 23 | def test_backward_pass(): 24 | """ 25 | A simple test for the grouped GEMM backward pass with detailed error handling. 26 | """ 27 | try: 28 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 29 | 30 | # Test parameters 31 | G = 20 # Number of groups 32 | M = 1024 # Input dimension 33 | N = 512 # Output dimension per group 34 | K = 256 # Hidden dimension 35 | 36 | # Create input and weight tensors 37 | x = torch.randn(M, K, dtype=torch.bfloat16, device=device, requires_grad=True) 38 | w = torch.randn( 39 | N * G, K, dtype=torch.bfloat16, device=device, requires_grad=True 40 | ) 41 | 42 | # Create group sizes 43 | m_sizes = torch.zeros(G, device=device, dtype=torch.int32) 44 | base_size = M // G 45 | remainder = M % G 46 | 47 | for i in range(G): 48 | m_sizes[i] = base_size + (1 if i < remainder else 0) 49 | 50 | # Log the setup 51 | print(f"Test setup - G: {G}, M: {M}, N: {N}, K: {K}") 52 | print(f"Input x shape: {x.shape}") 53 | logging.info(f"Weight w shape: {w.shape}") 54 | logging.info(f"Group sizes: {m_sizes}") 55 | 56 | # Step 1: Run forward pass 57 | logging.info("Running forward pass") 58 | result = grouped_gemm(x, w, m_sizes) 59 | logging.info(f"Forward result shape: {result.shape}") 60 | 61 | # Create a gradient for backpropagation 62 | grad_output = torch.randn_like(result) 63 | logging.info(f"Created gradient with shape: {grad_output.shape}") 64 | 65 | # Step 2: Run backward pass directly 66 | logging.info("Running backward pass directly") 67 | grad_x, grad_w = grouped_gemm_backward(grad_output, x, w, m_sizes) 68 | 69 | # Verify gradient shapes 70 | logging.info( 71 | f"Gradient shapes - grad_x: {grad_x.shape}, grad_w: {grad_w.shape}" 72 | ) 73 | 74 | # Step 3: Verify gradient computation using PyTorch's autograd 75 | # First create autograd-enabled tensors 76 | x_autograd = x.detach().clone().requires_grad_(True) 77 | w_autograd = w.detach().clone().requires_grad_(True) 78 | 79 | # Create a PyTorch reference implementation to compare against 80 | logging.info("Running PyTorch reference implementation") 81 | 82 | # Compute reference result 83 | reference_result = torch.zeros_like(result) 84 | m_start = 0 85 | for g in range(G): 86 | m_size = m_sizes[g].item() 87 | m_end = m_start + m_size 88 | n_start = g * N 89 | n_end = (g + 1) * N 90 | 91 | if m_size > 0: 92 | reference_result[m_start:m_end, n_start:n_end] = ( 93 | x_autograd[m_start:m_end, :] @ w_autograd[n_start:n_end, :].T 94 | ) 95 | 96 | m_start = m_end 97 | 98 | # Backpropagate using PyTorch 99 | reference_result.backward(grad_output) 100 | 101 | # Compare gradients 102 | logging.info("Comparing gradients with PyTorch reference") 103 | grad_x_error = (grad_x - 
x_autograd.grad).abs().max().item() 104 | grad_w_error = (grad_w - w_autograd.grad).abs().max().item() 105 | 106 | logging.info( 107 | f"Maximum gradient error - grad_x: {grad_x_error}, grad_w: {grad_w_error}" 108 | ) 109 | 110 | # Check if gradients are close using allclose 111 | rtol = 1e-2 # Relative tolerance for bfloat16 112 | atol = 1e-2 # Absolute tolerance for bfloat16 113 | 114 | grad_x_close = torch.allclose(grad_x, x_autograd.grad, rtol=rtol, atol=atol) 115 | if not grad_x_close: 116 | logging.warning("FAILED: Gradient mismatch detected in grad_x") 117 | else: 118 | logging.info( 119 | "✓ SUCCESS! grad_X matches the PyTorch reference (allclose check passed)" 120 | ) 121 | 122 | grad_w_close = torch.allclose(grad_w, w_autograd.grad, rtol=rtol, atol=atol) 123 | if not grad_w_close: 124 | logging.warning("FAILED: Gradient mismatch detected in grad_w") 125 | else: 126 | logging.info( 127 | "✓ SUCCESS! grad_W matches the PyTorch reference (allclose check passed)" 128 | ) 129 | 130 | logging.info( 131 | f"Gradients allclose check - grad_x: {grad_x_close}, grad_w: {grad_w_close}" 132 | ) 133 | 134 | if grad_x_close and grad_w_close: 135 | logging.info( 136 | "✓ SUCCESS: Gradients match the PyTorch reference (allclose check passed)" 137 | ) 138 | else: 139 | logging.error("✗ FAILURE: Gradient mismatch detected in allclose check") 140 | 141 | # Additional diagnostics (for failed cases or in general) 142 | if True: # not grad_x_close: 143 | # Find where the largest differences are 144 | diff_x = (grad_x - x_autograd.grad).abs() 145 | max_idx_x = diff_x.argmax().item() 146 | flat_idx_x = max_idx_x 147 | idx_x = np.unravel_index(flat_idx_x, grad_x.shape) 148 | logging.error( 149 | f"Largest grad_x difference at {idx_x}: " 150 | f"{grad_x[idx_x].item()} vs {x_autograd.grad[idx_x].item()}" 151 | ) 152 | # Count zeros 153 | zeros_grad_x = (grad_x == 0).sum().item() 154 | zeros_autograd_x = (x_autograd.grad == 0).sum().item() 155 | logging.error( 156 | f"Zeros in grad_x: {zeros_grad_x}/{grad_x.numel()} ({zeros_grad_x/grad_x.numel()*100:.2f}%)" 157 | ) 158 | logging.error( 159 | f"Zeros in x_autograd.grad: {zeros_autograd_x}/{x_autograd.grad.numel()} ({zeros_autograd_x/x_autograd.grad.numel()*100:.2f}%)" 160 | ) 161 | 162 | if True: # not grad_w_close: 163 | # Find where the largest differences are 164 | diff_w = (grad_w - w_autograd.grad).abs() 165 | max_idx_w = diff_w.argmax().item() 166 | flat_idx_w = max_idx_w 167 | idx_w = np.unravel_index(flat_idx_w, grad_w.shape) 168 | logging.error( 169 | f"Largest grad_w difference at {idx_w}: " 170 | f"{grad_w[idx_w].item()} vs {w_autograd.grad[idx_w].item()}" 171 | ) 172 | # Count zeros 173 | zeros_grad_w = (grad_w == 0).sum().item() 174 | zeros_autograd_w = (w_autograd.grad == 0).sum().item() 175 | logging.error( 176 | f"Zeros in grad_w: {zeros_grad_w}/{grad_w.numel()} ({zeros_grad_w/grad_w.numel()*100:.2f}%)" 177 | ) 178 | logging.error( 179 | f"Zeros in w_autograd.grad: {zeros_autograd_w}/{w_autograd.grad.numel()} ({zeros_autograd_w/w_autograd.grad.numel()*100:.2f}%)" 180 | ) 181 | 182 | return grad_x_close and grad_w_close 183 | 184 | except Exception as e: 185 | logging.error(f"Test failed with error: {e}") 186 | import traceback 187 | 188 | logging.error(traceback.format_exc()) 189 | return False 190 | 191 | 192 | if __name__ == "__main__": 193 | print("Running test_backward_pass") 194 | logging.debug("Running test_backward_pass") 195 | # Add numpy import for unravel_index 196 | import numpy as np 197 | 198 | success = test_backward_pass() 199 | 
logging.info(f"Test {'succeeded' if success else 'failed'}") 200 | -------------------------------------------------------------------------------- /kernels/MoE/group_GEMM/triton/testing/pytorch_reference_backwards.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import torch 9 | 10 | # This is a series of helper functions for grouped GEMM backward that compute the gradients 11 | # using eager PyTorch operations. They are used as a verification reference for the Triton kernels. 12 | # They can also used as a fallback when the Triton kernels cannot be used, though lets hope that is not needed. 13 | 14 | 15 | def _compute_grad_x_pytorch(grad_output, w, m_sizes, grad_x): 16 | """ 17 | Compute grad_x using pure PyTorch operations with FP32 precision 18 | """ 19 | G = m_sizes.shape[0] 20 | M, K = grad_x.shape 21 | N = w.shape[0] // G 22 | 23 | # Zero out the output tensor first 24 | grad_x.zero_() 25 | 26 | # Store original dtype and convert to float32 for computation 27 | orig_dtype = grad_x.dtype 28 | grad_output_fp32 = grad_output.float() 29 | w_fp32 = w.float() 30 | grad_x_fp32 = torch.zeros_like(grad_x, dtype=torch.float32) 31 | 32 | # Process each group separately 33 | m_start = 0 34 | for g in range(G): 35 | m_size = m_sizes[g].item() 36 | if m_size > 0: 37 | m_end = m_start + m_size 38 | n_start = g * N 39 | n_end = (g + 1) * N 40 | 41 | # Get slices for this group 42 | grad_output_slice = grad_output_fp32[m_start:m_end, n_start:n_end] 43 | w_slice = w_fp32[n_start:n_end] 44 | 45 | # Process in chunks for better precision on large matrices 46 | CHUNK_SIZE = 256 47 | for chunk_start in range(0, m_size, CHUNK_SIZE): 48 | chunk_end = min(chunk_start + CHUNK_SIZE, m_size) 49 | chunk_size = chunk_end - chunk_start 50 | 51 | # Compute matrix multiplication with higher precision 52 | grad_output_chunk = grad_output_slice[chunk_start:chunk_end] 53 | result_chunk = torch.matmul( 54 | grad_output_chunk.double(), w_slice.double() 55 | ) 56 | 57 | # Store the result 58 | grad_x_fp32[m_start + chunk_start : m_start + chunk_end].copy_( 59 | result_chunk.float() 60 | ) 61 | 62 | m_start = m_end 63 | 64 | # Convert back to original dtype 65 | grad_x.copy_(grad_x_fp32.to(orig_dtype)) 66 | 67 | 68 | def _compute_grad_w_pytorch(grad_output, x, m_sizes, grad_w): 69 | """ 70 | Compute grad_w using pure PyTorch operations with FP64 precision for better accuracy. 
71 | """ 72 | G = m_sizes.shape[0] 73 | N_times_G, K = grad_w.shape 74 | N = N_times_G // G 75 | 76 | # Zero out the output tensor first 77 | grad_w.zero_() 78 | 79 | # Store original dtype and convert to float32 for computation 80 | orig_dtype = grad_w.dtype 81 | grad_output_fp32 = grad_output.float() 82 | x_fp32 = x.float() 83 | grad_w_fp32 = torch.zeros_like(grad_w, dtype=torch.float32) 84 | 85 | # Handle potential K dimension mismatches 86 | K_x = x.shape[1] 87 | min_K = min(K, K_x) 88 | 89 | # Process each group separately 90 | m_start = 0 91 | for g in range(G): 92 | m_size = m_sizes[g].item() 93 | if m_size > 0: 94 | m_end = m_start + m_size 95 | n_start = g * N 96 | n_end = (g + 1) * N 97 | 98 | # Get slices for this group 99 | grad_output_slice = grad_output_fp32[m_start:m_end, n_start:n_end] 100 | x_slice = x_fp32[m_start:m_end, :min_K] 101 | 102 | # Process in chunks for better precision 103 | CHUNK_SIZE = 32 104 | result = torch.zeros( 105 | (grad_output_slice.shape[1], min_K), 106 | dtype=torch.float64, 107 | device=grad_output_slice.device, 108 | ) 109 | 110 | for chunk_start in range(0, m_size, CHUNK_SIZE): 111 | chunk_end = min(chunk_start + CHUNK_SIZE, m_size) 112 | 113 | # Get chunks 114 | grad_output_chunk = grad_output_slice[chunk_start:chunk_end].double() 115 | x_chunk = x_slice[chunk_start:chunk_end].double() 116 | 117 | # Matrix multiplication in FP64 118 | chunk_result = torch.matmul(grad_output_chunk.t(), x_chunk) 119 | result += chunk_result 120 | 121 | # Handle K dimension padding if needed 122 | if K > min_K: 123 | temp_result = torch.zeros( 124 | (grad_output_slice.shape[1], K), 125 | dtype=torch.float32, 126 | device=grad_output_slice.device, 127 | ) 128 | temp_result[:, :min_K] = result.float() 129 | grad_w_fp32[n_start:n_end].copy_(temp_result) 130 | else: 131 | grad_w_fp32[n_start:n_end].copy_(result.float()) 132 | 133 | m_start = m_end 134 | 135 | # Convert back to original dtype 136 | grad_w.copy_(grad_w_fp32.to(orig_dtype)) 137 | 138 | 139 | def _pytorch_fallback_backward(grad_output, x, w, m_sizes): 140 | """ 141 | Pure PyTorch implementation of grouped GEMM backward with high precision. 142 | Used as a fallback when the Triton kernels cannot be used. 143 | """ 144 | logging.info( 145 | "WARNING: Using PyTorch fallback for grouped GEMM backward with high precision" 146 | ) 147 | 148 | # Ensure inputs are contiguous 149 | x = x.contiguous() 150 | w = w.contiguous() 151 | grad_output = grad_output.contiguous() 152 | m_sizes = m_sizes.contiguous() 153 | 154 | # Allocate output tensors 155 | grad_x = torch.zeros_like(x) 156 | grad_w = torch.zeros_like(w) 157 | 158 | # Compute gradients using the helper functions 159 | _compute_grad_x_pytorch(grad_output, w, m_sizes, grad_x) 160 | _compute_grad_w_pytorch(grad_output, x, m_sizes, grad_w) 161 | 162 | return grad_x, grad_w 163 | 164 | 165 | def _pytorch_reference_backward(grad_output, x, w, m_sizes): 166 | """ 167 | Pure PyTorch implementation of grouped GEMM backward for validation. 168 | Simple version that's easy to verify but may be less numerically accurate 169 | for large matrices. 
170 | """ 171 | # Create output gradients 172 | grad_x = torch.zeros_like(x) 173 | grad_w = torch.zeros_like(w) 174 | 175 | # Compute group-by-group 176 | G = m_sizes.shape[0] 177 | N = w.shape[0] // G 178 | 179 | m_start = 0 180 | for g in range(G): 181 | m_size = m_sizes[g].item() 182 | if m_size > 0: 183 | m_end = m_start + m_size 184 | n_start = g * N 185 | n_end = (g + 1) * N 186 | 187 | # Compute gradients 188 | grad_x[m_start:m_end] = torch.matmul( 189 | grad_output[m_start:m_end, n_start:n_end], w[n_start:n_end] 190 | ) 191 | grad_w[n_start:n_end] = torch.matmul( 192 | grad_output[m_start:m_end, n_start:n_end].t(), x[m_start:m_end] 193 | ) 194 | 195 | m_start += m_size 196 | 197 | return grad_x, grad_w 198 | 199 | 200 | # ========== End helper functions ========== 201 | -------------------------------------------------------------------------------- /kernels/MoE/group_GEMM/triton/utils/tma_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-unsafe 8 | # This code is derived from: https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu/experimental/gemm/triton_gemm 9 | 10 | import sys 11 | 12 | import torch 13 | import triton # @manual 14 | 15 | import triton.language as tl # @manual 16 | 17 | 18 | def map_dtype_to_triton(dtype: torch.dtype) -> tl.dtype: 19 | """ 20 | Maps torch dtype to triton dtype. 21 | 22 | Args: 23 | dtype (torch.dtype): input dtype. 24 | 25 | Returns: 26 | tl.dtype: triton dtype. 27 | """ 28 | if dtype == torch.float16: 29 | return tl.float16 30 | elif dtype == torch.bfloat16: 31 | return tl.bfloat16 32 | elif dtype == torch.float32: 33 | return tl.float32 34 | elif dtype == torch.int32: 35 | return tl.int32 36 | elif dtype == torch.float8_e4m3fn and torch.version.hip is None: 37 | return tl.float8e4nv 38 | else: 39 | raise ValueError(f"Unsupported dtype {dtype}") 40 | 41 | 42 | # check if we have the TMA version in Triton PR #4498 (https://github.com/triton-lang/triton/pull/4498). 
43 | HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl) 44 | 45 | if HAS_TMA_DESC: 46 | print( 47 | "TMA benchmarks will be running with experimental grid constant TMA descriptor.", 48 | file=sys.stderr, 49 | ) 50 | else: 51 | print( 52 | "Missing: This group gemm code will not run without TMA descriptor support....", 53 | file=sys.stderr, 54 | ) 55 | raise NotImplementedError("grouped Gemm without TMA is not supported") 56 | 57 | 58 | class TmaAutoTuneHelper: 59 | 60 | # duck typing wrapper to implement the same interface as TmaDescKernelParam in Triton PR #4498 61 | class KernelParamWrapper: 62 | def __init__(self, desc): 63 | self.desc = desc 64 | 65 | def tma_desc_cpu_ptr(self): 66 | return self.desc.data_ptr() 67 | 68 | TMA_SIZE = 128 69 | 70 | def __init__(self): 71 | self.fill_1d_tma_descriptor_inner = ( 72 | triton.runtime.driver.active.utils.fill_1d_tma_descriptor 73 | ) 74 | self.fill_2d_tma_descriptor_inner = ( 75 | triton.runtime.driver.active.utils.fill_2d_tma_descriptor 76 | ) 77 | if HAS_TMA_DESC: 78 | self.descriptors = {} 79 | else: 80 | self.cuda_descriptors = {} 81 | 82 | # Call this method outside of the lambda function for grid size 83 | def init_tma_descriptor(self, name): 84 | if HAS_TMA_DESC: 85 | self.descriptors[name] = torch.empty( 86 | TmaAutoTuneHelper.TMA_SIZE, device="cpu", dtype=torch.int8 87 | ) 88 | else: 89 | self.cuda_descriptors[name] = torch.empty( 90 | TmaAutoTuneHelper.TMA_SIZE, device="cuda", dtype=torch.int8 91 | ) 92 | 93 | # Call this method inside the lambda function for grid size 94 | def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size): 95 | if HAS_TMA_DESC: 96 | desc_x = self.descriptors[name] 97 | assert desc_x.data_ptr() % 64 == 0 98 | self.fill_1d_tma_descriptor_inner( 99 | ptr, dim, block_dim, element_size, desc_x.data_ptr() 100 | ) 101 | else: 102 | desc_x = self.cuda_descriptors[name] 103 | buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True) 104 | self.fill_1d_tma_descriptor_inner( 105 | ptr, dim, block_dim, element_size, buf_x.data_ptr() 106 | ) 107 | desc_x.copy_(buf_x, non_blocking=True) 108 | 109 | # Call this method inside the lambda function for grid size 110 | def fill_2d_tma_descriptor( 111 | self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size 112 | ): 113 | if HAS_TMA_DESC: 114 | desc_x = self.descriptors[name] 115 | assert desc_x.data_ptr() % 64 == 0 116 | self.fill_2d_tma_descriptor_inner( 117 | ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr() 118 | ) 119 | else: 120 | desc_x = self.cuda_descriptors[name] 121 | buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True) 122 | self.fill_2d_tma_descriptor_inner( 123 | ptr, dim1, dim0, block_dim1, block_dim0, element_size, buf_x.data_ptr() 124 | ) 125 | desc_x.copy_(buf_x, non_blocking=True) 126 | 127 | def get_tma_descriptor_kernel_param(self, name): 128 | if HAS_TMA_DESC: 129 | assert self.descriptors[name] is not None 130 | return self.KernelParamWrapper(self.descriptors[name]) 131 | else: 132 | assert self.cuda_descriptors[name] is not None 133 | return self.cuda_descriptors[name] 134 | -------------------------------------------------------------------------------- /kernels/blackwell/cute_gemm_01/Makefile: -------------------------------------------------------------------------------- 1 | 2 | # Makefile for SM100 GEMM PyTorch Extension 3 | 4 | # Set these paths according to your installation 5 | CUTLASS_PATH ?= /path/to/cutlass 6 | CUDA_HOME ?= $(shell python -c "import torch; 
print(torch.utils.cpp_extension.CUDA_HOME)") 7 | 8 | # Build the extension 9 | build: 10 | CUTLASS_PATH=$(CUTLASS_PATH) python setup.py build_ext --inplace 11 | 12 | # Install the extension 13 | install: 14 | CUTLASS_PATH=$(CUTLASS_PATH) pip install . 15 | 16 | # Clean build artifacts 17 | clean: 18 | rm -rf build/ dist/ *.egg-info/ sm100_gemm*.so 19 | 20 | # Test the installation 21 | test: 22 | python python_interface.py 23 | 24 | # Check CUDA device capability 25 | check_device: 26 | python -c "import torch; print(f'CUDA device: {torch.cuda.get_device_name()}, Compute capability: {torch.cuda.get_device_capability()}')" 27 | 28 | .PHONY: build install clean test check_device 29 | -------------------------------------------------------------------------------- /kernels/blackwell/cute_gemm_01/build/lib.linux-x86_64-cpython-312/sm100_gemm.cpython-312-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch-labs/applied-ai/4f8c39eb70d92e3848c65347ca077c8e723c74fd/kernels/blackwell/cute_gemm_01/build/lib.linux-x86_64-cpython-312/sm100_gemm.cpython-312-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /kernels/blackwell/cute_gemm_01/build/temp.linux-x86_64-cpython-312/.ninja_deps: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch-labs/applied-ai/4f8c39eb70d92e3848c65347ca077c8e723c74fd/kernels/blackwell/cute_gemm_01/build/temp.linux-x86_64-cpython-312/.ninja_deps -------------------------------------------------------------------------------- /kernels/blackwell/cute_gemm_01/build/temp.linux-x86_64-cpython-312/.ninja_log: -------------------------------------------------------------------------------- 1 | # ninja log v5 2 | 0 15279 1748131038212164071 /data/users/less/applied-ai/kernels/blackwell/cute_gemm/build/temp.linux-x86_64-cpython-312/sm100_gemm_pytorch.o 1163be77f63db063 3 | 6 13596 1748131241209889865 /data/users/less/applied-ai/kernels/blackwell/cute_gemm/build/temp.linux-x86_64-cpython-312/sm100_gemm.o 79aa61597088743a 4 | 8 13684 1748132015451659084 /data/users/less/applied-ai/kernels/blackwell/cute_gemm/build/temp.linux-x86_64-cpython-312/sm100_gemm.o 89ead7aaccf82852 5 | -------------------------------------------------------------------------------- /kernels/blackwell/cute_gemm_01/build/temp.linux-x86_64-cpython-312/build.ninja: -------------------------------------------------------------------------------- 1 | ninja_required_version = 1.3 2 | cxx = c++ 3 | nvcc = /usr/local/cuda-12.8/bin/nvcc 4 | 5 | cflags = -pthread -B /home/less/.conda/envs/pycutlass/compiler_compat -fno-strict-overflow -Wsign-compare -DNDEBUG -O2 -Wall -fPIC -O2 -isystem /home/less/.conda/envs/pycutlass/include -fPIC -O2 -isystem /home/less/.conda/envs/pycutlass/include -fPIC -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/home/less/local/cutlass40/include -I/home/less/local/cutlass40/tools/util/include -I/usr/local/cuda-12.8/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda-12.8/include -I/home/less/.conda/envs/pycutlass/include/python3.12 -c 6 | post_cflags = -O3 -std=c++17 -DCUTLASS_ARCH_MMA_SM100_SUPPORTED 
-DCUTE_SM100_ENABLED -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1018"' -DTORCH_EXTENSION_NAME=sm100_gemm 7 | cuda_cflags = -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/home/less/local/cutlass40/include -I/home/less/local/cutlass40/tools/util/include -I/usr/local/cuda-12.8/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda-12.8/include -I/home/less/.conda/envs/pycutlass/include/python3.12 -c 8 | cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -std=c++17 --expt-relaxed-constexpr --expt-extended-lambda -gencode=arch=compute_100a,code=sm_100a -DCUTLASS_ARCH_MMA_SM100_SUPPORTED -DCUTE_SM100_ENABLED --use_fast_math -Xcompiler=-fPIC -DCUTE_ARCH_TCGEN05_TMEM_ENABLED=1 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1018"' -DTORCH_EXTENSION_NAME=sm100_gemm 9 | cuda_dlink_post_cflags = 10 | sycl_dlink_post_cflags = 11 | ldflags = 12 | 13 | rule compile 14 | command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags 15 | depfile = $out.d 16 | deps = gcc 17 | 18 | rule cuda_compile 19 | depfile = $out.d 20 | deps = gcc 21 | command = $nvcc --generate-dependencies-with-compile --dependency-output $out.d $cuda_cflags -c $in -o $out $cuda_post_cflags 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | build /data/users/less/applied-ai/kernels/blackwell/cute_gemm/build/temp.linux-x86_64-cpython-312/sm100_gemm.o: cuda_compile /data/users/less/applied-ai/kernels/blackwell/cute_gemm/sm100_gemm.cu 30 | build /data/users/less/applied-ai/kernels/blackwell/cute_gemm/build/temp.linux-x86_64-cpython-312/sm100_gemm_pytorch.o: compile /data/users/less/applied-ai/kernels/blackwell/cute_gemm/sm100_gemm_pytorch.cpp 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /kernels/blackwell/cute_gemm_01/build/temp.linux-x86_64-cpython-312/sm100_gemm.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch-labs/applied-ai/4f8c39eb70d92e3848c65347ca077c8e723c74fd/kernels/blackwell/cute_gemm_01/build/temp.linux-x86_64-cpython-312/sm100_gemm.o -------------------------------------------------------------------------------- /kernels/blackwell/cute_gemm_01/build/temp.linux-x86_64-cpython-312/sm100_gemm_pytorch.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch-labs/applied-ai/4f8c39eb70d92e3848c65347ca077c8e723c74fd/kernels/blackwell/cute_gemm_01/build/temp.linux-x86_64-cpython-312/sm100_gemm_pytorch.o -------------------------------------------------------------------------------- /kernels/blackwell/cute_gemm_01/dist/sm100_gemm-0.0.0-py3.12-linux-x86_64.egg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch-labs/applied-ai/4f8c39eb70d92e3848c65347ca077c8e723c74fd/kernels/blackwell/cute_gemm_01/dist/sm100_gemm-0.0.0-py3.12-linux-x86_64.egg 
-------------------------------------------------------------------------------- /kernels/blackwell/cute_gemm_01/driver.py: -------------------------------------------------------------------------------- 1 | # ============================================================================== 2 | # python_interface.py - High-level Python interface 3 | # ============================================================================== 4 | 5 | 6 | import torch 7 | 8 | try: 9 | import sm100_gemm # The compiled extension - this has to go after import torch...but auto-formatting is blocking 10 | except ImportError: 11 | print("❌ SM100 not ready!") 12 | raise ImportError( 13 | "SM100 not ready! Please build the extension using `python setup.py install`" 14 | ) 15 | 16 | 17 | def sm100_gemm_f16(A, B, C=None, alpha=1.0, beta=0.0): 18 | """ 19 | Perform GEMM using SM100 optimized kernel: D = alpha * A @ B^T + beta * C 20 | 21 | Args: 22 | A (torch.Tensor): Input tensor A of shape (M, K), dtype=torch.float16 23 | B (torch.Tensor): Input tensor B of shape (N, K), dtype=torch.float16 24 | C (torch.Tensor, optional): Input tensor C of shape (M, N), dtype=torch.float32 25 | If None, creates zero tensor 26 | alpha (float): Scaling factor for A @ B^T 27 | beta (float): Scaling factor for C 28 | 29 | Returns: 30 | torch.Tensor: Output tensor D of shape (M, N), dtype=torch.float32 31 | 32 | Note: 33 | - A and B are K-major (transposed in BLAS terms) 34 | - C and D are N-major (row-major) 35 | - All tensors must be on CUDA 36 | - M must be multiple of 128, N multiple of 256, K multiple of 64 37 | """ 38 | 39 | # Input validation 40 | assert A.dtype == torch.float16, f"A must be float16, got {A.dtype}" 41 | assert B.dtype == torch.float16, f"B must be float16, got {B.dtype}" 42 | assert A.is_cuda and B.is_cuda, "A and B must be on CUDA" 43 | assert A.is_contiguous() and B.is_contiguous(), "A and B must be contiguous" 44 | 45 | M, K = A.shape 46 | N, K_B = B.shape 47 | assert K == K_B, f"Inner dimensions must match: A.shape[1]={K}, B.shape[1]={K_B}" 48 | 49 | # Check alignment requirements 50 | assert M % 128 == 0, f"M={M} must be multiple of 128" 51 | assert N % 256 == 0, f"N={N} must be multiple of 256" 52 | assert K % 64 == 0, f"K={K} must be multiple of 64" 53 | 54 | # Create C if not provided 55 | if C is None: 56 | C = torch.zeros(M, N, dtype=torch.float32, device=A.device) 57 | else: 58 | assert C.dtype == torch.float32, f"C must be float32, got {C.dtype}" 59 | assert C.is_cuda, "C must be on CUDA" 60 | assert C.is_contiguous(), "C must be contiguous" 61 | assert C.shape == ( 62 | M, 63 | N, 64 | ), f"C shape {C.shape} must match output shape ({M}, {N})" 65 | 66 | # Call the extension 67 | return sm100_gemm.sm100_gemm_f16(A, B, C, alpha, beta) 68 | 69 | 70 | def benchmark_sm100_vs_torch( 71 | M=1024, N=2048, K=256, num_warmup=1, num_trials=10 72 | ): # M=512, N=1024, K=256, num_warmup=10, num_trials=100): 73 | """ 74 | Benchmark SM100 GEMM against PyTorch's native GEMM 75 | """ 76 | # Ensure dimensions are aligned 77 | M = ((M + 127) // 128) * 128 78 | N = ((N + 255) // 256) * 256 79 | K = ((K + 63) // 64) * 64 80 | 81 | print(f"Benchmarking GEMM with shape: ({M}, {N}, {K})") 82 | 83 | # Create test tensors 84 | A = torch.randn(M, K, dtype=torch.float16, device="cuda") 85 | B = torch.randn(N, K, dtype=torch.float16, device="cuda") 86 | C = torch.randn(M, N, dtype=torch.float16, device="cuda") 87 | C32 = C.to(torch.float32).clone() 88 | 89 | # Keep A and B as FP16 for PyTorch 90 | A_fp16 = A 91 | B_fp16 = B 92 
| 93 |     # Warmup 94 |     for _ in range(num_warmup): 95 |         # PyTorch GEMM (using FP16) 96 |         torch_result = torch.addmm(C, A_fp16, B_fp16.T) 97 | 98 |         # SM100 GEMM 99 |         sm100_result = sm100_gemm_f16(A, B, C32, beta=1.0)  # beta=1.0 so both paths compute C + A @ B^T (addmm applies beta=1 implicitly) 100 | 101 |     torch.cuda.synchronize() 102 | 103 |     # Benchmark PyTorch 104 |     torch.cuda.synchronize() 105 |     start = torch.cuda.Event(enable_timing=True) 106 |     end = torch.cuda.Event(enable_timing=True) 107 | 108 |     start.record() 109 |     for _ in range(num_trials): 110 |         torch_result = torch.addmm(C, A_fp16, B_fp16.T) 111 |     end.record() 112 |     torch.cuda.synchronize() 113 |     torch_time = start.elapsed_time(end) / num_trials 114 | 115 |     # Benchmark SM100 116 |     start.record() 117 |     for _ in range(num_trials): 118 |         sm100_result = sm100_gemm_f16(A, B, C32, beta=1.0) 119 |     end.record() 120 |     torch.cuda.synchronize() 121 |     sm100_time = start.elapsed_time(end) / num_trials 122 | 123 |     # Check correctness 124 |     max_diff = torch.max(torch.abs(torch_result - sm100_result.to(torch.float16))) 125 |     rel_error = max_diff / torch.max(torch.abs(torch_result)) 126 | 127 |     # Calculate FLOPS 128 |     flops = 2 * M * N * K  # Multiply-add operations 129 |     torch_tflops = flops / (torch_time * 1e-3) / 1e12 130 |     sm100_tflops = flops / (sm100_time * 1e-3) / 1e12 131 | 132 |     print(f"PyTorch time: {torch_time:.3f} ms ({torch_tflops:.2f} TFLOPS)") 133 |     print(f"SM100 time: {sm100_time:.3f} ms ({sm100_tflops:.2f} TFLOPS)") 134 |     print(f"Speedup: {torch_time/sm100_time:.2f}x") 135 |     print(f"Max difference: {max_diff:.6f}") 136 |     print(f"Relative error: {rel_error:.6f}") 137 | 138 |     return { 139 |         "torch_time": torch_time, 140 |         "sm100_time": sm100_time, 141 |         "speedup": torch_time / sm100_time, 142 |         "torch_tflops": torch_tflops, 143 |         "sm100_tflops": sm100_tflops, 144 |         "max_diff": max_diff.item(), 145 |         "rel_error": rel_error.item(), 146 |     } 147 | 148 | 149 | # Example usage and test 150 | if __name__ == "__main__": 151 |     # Test basic functionality 152 |     print("Testing SM100 GEMM...") 153 | 154 |     M, N, K = 512, 1024, 256 155 |     A = torch.randn(M, K, dtype=torch.float16, device="cuda") 156 |     B = torch.randn(N, K, dtype=torch.float16, device="cuda") 157 |     C = torch.randn(M, N, dtype=torch.float32, device="cuda") 158 | 159 |     # Test the GEMM 160 |     result = sm100_gemm_f16(A, B, C, alpha=1.0, beta=0.5) 161 |     print(f"Result shape: {result.shape}, dtype: {result.dtype}") 162 | 163 |     # Run benchmark 164 |     print("\nRunning benchmark...") 165 |     benchmark_results = benchmark_sm100_vs_torch(M, N, K) 166 | 167 | # ============================================================================== 168 | # Makefile for easy building 169 | # ============================================================================== 170 | ''' 171 | MAKEFILE_CONTENT = """ 172 | # Makefile for SM100 GEMM PyTorch Extension 173 | 174 | # Set these paths according to your installation 175 | CUTLASS_PATH ?= /path/to/cutlass 176 | CUDA_HOME ?= $(shell python -c "import torch; print(torch.utils.cpp_extension.CUDA_HOME)") 177 | 178 | # Build the extension 179 | build: 180 | CUTLASS_PATH=$(CUTLASS_PATH) python setup.py build_ext --inplace 181 | 182 | # Install the extension 183 | install: 184 | CUTLASS_PATH=$(CUTLASS_PATH) pip install .
185 | 186 | # Clean build artifacts 187 | clean: 188 | rm -rf build/ dist/ *.egg-info/ sm100_gemm*.so 189 | 190 | # Test the installation 191 | test: 192 | python python_interface.py 193 | 194 | # Check CUDA device capability 195 | check_device: 196 | python -c "import torch; print(f'CUDA device: {torch.cuda.get_device_name()}, Compute capability: {torch.cuda.get_device_capability()}')" 197 | 198 | .PHONY: build install clean test check_device 199 | """ 200 | 201 | # Write Makefile 202 | with open("Makefile", "w") as f: 203 | f.write(MAKEFILE_CONTENT) 204 | 205 | print("Setup files created!") 206 | print("To build:") 207 | print("1. Set CUTLASS_PATH environment variable to your CUTLASS installation") 208 | print("2. Run: make build") 209 | print("3. Test: make test") 210 | ''' 211 | -------------------------------------------------------------------------------- /kernels/blackwell/cute_gemm_01/setup.py: -------------------------------------------------------------------------------- 1 | # setup.py 2 | import os 3 | 4 | import pybind11 5 | import torch 6 | from pybind11 import get_cmake_dir 7 | from pybind11.setup_helpers import build_ext, Pybind11Extension 8 | from setuptools import Extension, setup 9 | from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension 10 | 11 | # IMPORTANT: The following two lines are the only ones you need to change 12 | # Get CUTLASS path (you'll need to set this to your CUTLASS installation) 13 | CUTLASS_PATH = os.environ.get("CUTLASS_PATH", "/home/less/local/cutlas40") 14 | 15 | # CUDA and PyTorch paths 16 | cuda_home = torch.utils.cpp_extension.CUDA_HOME 17 | pytorch_includes = torch.utils.cpp_extension.include_paths() 18 | 19 | ext_modules = [ 20 | CUDAExtension( 21 | name="sm100_gemm", 22 | sources=[ 23 | "sm100_gemm_pytorch.cpp", # PyTorch bindings (C++) 24 | "sm100_gemm.cu", # CUDA kernel implementation 25 | ], 26 | include_dirs=[ 27 | # PyTorch includes 28 | *pytorch_includes, 29 | # CUTLASS includes 30 | f"{CUTLASS_PATH}/include", 31 | f"{CUTLASS_PATH}/tools/util/include", 32 | # CUDA includes 33 | f"{cuda_home}/include", 34 | ], 35 | library_dirs=[ 36 | f"{cuda_home}/lib64", 37 | ], 38 | libraries=["cuda", "cudart"], 39 | extra_compile_args={ 40 | "cxx": [ 41 | "-O3", 42 | "-std=c++17", 43 | "-DCUTLASS_ARCH_MMA_SM100_SUPPORTED", 44 | "-DCUTE_SM100_ENABLED", 45 | ], 46 | "nvcc": [ 47 | "-O3", 48 | "-std=c++17", 49 | "--expt-relaxed-constexpr", 50 | "--expt-extended-lambda", 51 | "-gencode=arch=compute_100a,code=sm_100a", # SM100 architecture 52 | "-DCUTLASS_ARCH_MMA_SM100_SUPPORTED", 53 | "-DCUTE_SM100_ENABLED", 54 | "--use_fast_math", 55 | "-Xcompiler=-fPIC", 56 | "-DCUTE_ARCH_TCGEN05_TMEM_ENABLED=1", # Enable TCGEN05_TMEM 57 | ], 58 | }, 59 | extra_link_args=["-lcuda", "-lcudart"], 60 | language="c++", 61 | ) 62 | ] 63 | 64 | setup( 65 | name="sm100_gemm", 66 | ext_modules=ext_modules, 67 | cmdclass={"build_ext": BuildExtension}, 68 | zip_safe=False, 69 | python_requires=">=3.8", 70 | install_requires=["torch>=1.12.0"], 71 | ) 72 | -------------------------------------------------------------------------------- /kernels/blackwell/cute_gemm_01/sm100_gemm.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.4 2 | Name: sm100_gemm 3 | Version: 0.0.0 4 | Requires-Python: >=3.8 5 | Requires-Dist: torch>=1.12.0 6 | Dynamic: requires-dist 7 | Dynamic: requires-python 8 | -------------------------------------------------------------------------------- 
/kernels/blackwell/cute_gemm_01/sm100_gemm.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | setup.py 2 | sm100_gemm.cu 3 | sm100_gemm_pytorch.cpp 4 | sm100_gemm.egg-info/PKG-INFO 5 | sm100_gemm.egg-info/SOURCES.txt 6 | sm100_gemm.egg-info/dependency_links.txt 7 | sm100_gemm.egg-info/not-zip-safe 8 | sm100_gemm.egg-info/requires.txt 9 | sm100_gemm.egg-info/top_level.txt -------------------------------------------------------------------------------- /kernels/blackwell/cute_gemm_01/sm100_gemm.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /kernels/blackwell/cute_gemm_01/sm100_gemm.egg-info/not-zip-safe: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /kernels/blackwell/cute_gemm_01/sm100_gemm.egg-info/requires.txt: -------------------------------------------------------------------------------- 1 | torch>=1.12.0 2 | -------------------------------------------------------------------------------- /kernels/blackwell/cute_gemm_01/sm100_gemm.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | sm100_gemm 2 | -------------------------------------------------------------------------------- /kernels/blackwell/cute_gemm_01/sm100_gemm.h: -------------------------------------------------------------------------------- 1 | // sm100_gemm_kernel.h - Header file for CUDA kernel 2 | #pragma once 3 | 4 | #include 5 | 6 | #ifdef __cplusplus 7 | extern "C" { 8 | #endif 9 | 10 | /** 11 | * Launch SM100 GEMM kernel: D = alpha * A @ B^T + beta * C 12 | * 13 | * @param d_A Pointer to matrix A in device memory (M x K, FP16, K-major) 14 | * @param d_B Pointer to matrix B in device memory (N x K, FP16, K-major) 15 | * @param d_C Pointer to matrix C in device memory (M x N, FP32, N-major) 16 | * @param d_D Pointer to matrix D in device memory (M x N, FP32, N-major) 17 | * @param M Number of rows in A and C/D 18 | * @param N Number of rows in B and columns in C/D 19 | * @param K Number of columns in A and B 20 | * @param alpha Scaling factor for A @ B^T 21 | * @param beta Scaling factor for C 22 | * @param stream CUDA stream (currently unused, for future async support) 23 | * 24 | * @return cudaSuccess on success, error code otherwise 25 | * 26 | * Requirements: 27 | * - M must be multiple of 128 28 | * - N must be multiple of 256 29 | * - K must be multiple of 64 30 | * - All pointers must be valid device memory 31 | * - Tensors must be contiguous with specified layouts 32 | */ 33 | cudaError_t launch_sm100_gemm_f16(void *d_A, void *d_B, void *d_C, void *d_D, 34 | int M, int N, int K, float alpha, float beta, 35 | cudaStream_t stream = 0); 36 | 37 | #ifdef __cplusplus 38 | } 39 | #endif 40 | -------------------------------------------------------------------------------- /kernels/blackwell/cute_gemm_01/sm100_gemm_pytorch.cpp: -------------------------------------------------------------------------------- 1 | // sm100_gemm_pytorch.cpp - PyTorch C++ extension (no CUDA code) 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include "sm100_gemm.h" 9 | 10 | // Check if SM100 support is available at compile time 11 | bool is_sm100_supported() { 12 | #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) 13 | return true; 14 
| #else 15 | return false; 16 | #endif 17 | } 18 | 19 | // Check if current GPU supports SM100 at runtime 20 | bool check_sm100_device() { 21 | int device; 22 | cudaGetDevice(&device); 23 | 24 | cudaDeviceProp props; 25 | cudaError_t error = cudaGetDeviceProperties(&props, device); 26 | if (error != cudaSuccess) { 27 | return false; 28 | } 29 | 30 | // Check for SM100 architecture (compute capability 10.0a) 31 | return (props.major == 10 && props.minor == 0); 32 | } 33 | 34 | torch::Tensor sm100_gemm_f16(const torch::Tensor &A, const torch::Tensor &B, 35 | const torch::Tensor &C, float alpha = 1.0f, 36 | float beta = 0.0f) { 37 | 38 | // Check compile-time support 39 | TORCH_CHECK( 40 | is_sm100_supported(), 41 | "SM100 support not compiled. Requires CUTLASS_ARCH_MMA_SM100_SUPPORTED"); 42 | 43 | // Check runtime device support 44 | TORCH_CHECK(check_sm100_device(), 45 | "Current GPU does not support SM100 architecture (requires " 46 | "compute capability 10.0a)"); 47 | 48 | // Input validation 49 | TORCH_CHECK(A.device().is_cuda(), "A must be a CUDA tensor"); 50 | TORCH_CHECK(B.device().is_cuda(), "B must be a CUDA tensor"); 51 | TORCH_CHECK(C.device().is_cuda(), "C must be a CUDA tensor"); 52 | TORCH_CHECK(A.dtype() == torch::kFloat16, "A must be float16"); 53 | TORCH_CHECK(B.dtype() == torch::kFloat16, "B must be float16"); 54 | TORCH_CHECK(C.dtype() == torch::kFloat32, "C must be float32"); 55 | TORCH_CHECK(A.is_contiguous(), "A must be contiguous"); 56 | TORCH_CHECK(B.is_contiguous(), "B must be contiguous"); 57 | TORCH_CHECK(C.is_contiguous(), "C must be contiguous"); 58 | TORCH_CHECK(A.dim() == 2, "A must be 2D"); 59 | TORCH_CHECK(B.dim() == 2, "B must be 2D"); 60 | TORCH_CHECK(C.dim() == 2, "C must be 2D"); 61 | 62 | // Get dimensions 63 | int64_t M = A.size(0); 64 | int64_t K = A.size(1); 65 | int64_t N = B.size(0); 66 | int64_t K_B = B.size(1); 67 | 68 | TORCH_CHECK(K == K_B, "Inner dimensions must match: A.shape[1]=", K, 69 | ", B.shape[1]=", K_B); 70 | TORCH_CHECK(C.size(0) == M && C.size(1) == N, "C dimensions (", C.size(0), 71 | ", ", C.size(1), ") must match output shape (", M, ", ", N, ")"); 72 | 73 | // Check alignment requirements for SM100 74 | TORCH_CHECK(M % 128 == 0, "M=", M, " must be multiple of 128"); 75 | TORCH_CHECK(N % 256 == 0, "N=", N, " must be multiple of 256"); 76 | TORCH_CHECK(K % 64 == 0, "K=", K, " must be multiple of 64"); 77 | 78 | // Check size limits (avoid overflow in int conversion) 79 | TORCH_CHECK(M <= INT_MAX && N <= INT_MAX && K <= INT_MAX, 80 | "Dimensions too large for int conversion"); 81 | 82 | // Create output tensor 83 | auto D = torch::empty_like(C); 84 | 85 | // Set CUDA device guard 86 | const auto device = A.device(); 87 | c10::cuda::CUDAGuard device_guard(device); 88 | 89 | // Get current CUDA stream 90 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(device.index()).stream(); 91 | 92 | // Launch the kernel 93 | cudaError_t error = launch_sm100_gemm_f16( 94 | A.data_ptr(), B.data_ptr(), C.data_ptr(), D.data_ptr(), 95 | static_cast(M), static_cast(N), static_cast(K), alpha, 96 | beta, stream); 97 | 98 | // Check for launch errors 99 | TORCH_CHECK(error == cudaSuccess, 100 | "SM100 GEMM kernel launch failed: ", cudaGetErrorString(error)); 101 | 102 | // Check for kernel execution errors 103 | C10_CUDA_CHECK(cudaGetLastError()); 104 | 105 | return D; 106 | } 107 | 108 | // Utility functions for debugging and information 109 | torch::Tensor get_device_info() { 110 | int device; 111 | cudaGetDevice(&device); 112 | 113 | 
cudaDeviceProp props; 114 | cudaGetDeviceProperties(&props, device); 115 | 116 | // Return device info as a tensor (for easy Python access) 117 | auto info = torch::zeros({4}, torch::kInt32); 118 | auto accessor = info.accessor(); 119 | 120 | accessor[0] = props.major; // Compute capability major 121 | accessor[1] = props.minor; // Compute capability minor 122 | accessor[2] = is_sm100_supported(); // Compile-time support 123 | accessor[3] = check_sm100_device(); // Runtime device support 124 | 125 | return info; 126 | } 127 | 128 | std::vector get_aligned_shape(int64_t M, int64_t N, int64_t K) { 129 | // Return properly aligned dimensions for SM100 130 | int64_t aligned_M = ((M + 127) / 128) * 128; 131 | int64_t aligned_N = ((N + 255) / 256) * 256; 132 | int64_t aligned_K = ((K + 63) / 64) * 64; 133 | 134 | return {aligned_M, aligned_N, aligned_K}; 135 | } 136 | 137 | torch::Tensor create_aligned_tensor(const std::vector &shape, 138 | torch::ScalarType dtype, 139 | torch::Device device) { 140 | // Create a tensor with SM100-aligned dimensions 141 | TORCH_CHECK(shape.size() == 2, "Shape must be 2D"); 142 | 143 | auto aligned_shape = 144 | get_aligned_shape(shape[0], shape[1], shape.size() > 2 ? shape[2] : 64); 145 | 146 | if (shape.size() == 2) { 147 | return torch::zeros({aligned_shape[0], aligned_shape[1]}, 148 | torch::TensorOptions().dtype(dtype).device(device)); 149 | } else { 150 | return torch::zeros({aligned_shape[0], aligned_shape[2]}, 151 | torch::TensorOptions().dtype(dtype).device(device)); 152 | } 153 | } 154 | 155 | // Python bindings 156 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 157 | m.doc() = "SM100 GEMM PyTorch Extension"; 158 | 159 | // Main GEMM function 160 | m.def("sm100_gemm_f16", &sm100_gemm_f16, 161 | "SM100 GEMM with FP16 inputs and FP32 output: D = alpha * A @ B^T + " 162 | "beta * C", 163 | py::arg("A"), py::arg("B"), py::arg("C"), py::arg("alpha") = 1.0f, 164 | py::arg("beta") = 0.0f); 165 | 166 | // Utility functions 167 | m.def("is_sm100_supported", &is_sm100_supported, 168 | "Check if SM100 support was compiled in"); 169 | 170 | m.def("check_sm100_device", &check_sm100_device, 171 | "Check if current GPU supports SM100 architecture"); 172 | 173 | m.def("get_device_info", &get_device_info, 174 | "Get device compute capability and SM100 support info"); 175 | 176 | m.def("get_aligned_shape", &get_aligned_shape, 177 | "Get SM100-aligned dimensions for given shape", py::arg("M"), 178 | py::arg("N"), py::arg("K")); 179 | 180 | m.def("create_aligned_tensor", &create_aligned_tensor, 181 | "Create tensor with SM100-aligned dimensions", py::arg("shape"), 182 | py::arg("dtype"), py::arg("device")); 183 | 184 | // Constants for alignment requirements 185 | m.attr("MMA_TILE_M") = 128; 186 | m.attr("MMA_TILE_N") = 256; 187 | m.attr("MMA_TILE_K") = 64; 188 | } 189 | -------------------------------------------------------------------------------- /kernels/blackwell/cute_gemm_02_tma/build/lib.linux-x86_64-cpython-312/sm100_gemm.cpython-312-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch-labs/applied-ai/4f8c39eb70d92e3848c65347ca077c8e723c74fd/kernels/blackwell/cute_gemm_02_tma/build/lib.linux-x86_64-cpython-312/sm100_gemm.cpython-312-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/.ninja_deps: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch-labs/applied-ai/4f8c39eb70d92e3848c65347ca077c8e723c74fd/kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/.ninja_deps -------------------------------------------------------------------------------- /kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/.ninja_log: -------------------------------------------------------------------------------- 1 | # ninja log v5 2 | 1 15202 1748185895110710199 /data/users/less/applied-ai/kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/sm100_gemm_pytorch.o 342153d32d365f0b 3 | 7 78 1748186494782816813 /data/users/less/applied-ai/kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/sm100_gemm_pytorch.o 342153d32d365f0b 4 | 6 15086 1748186805607894090 /data/users/less/applied-ai/kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/sm100_gemm_pytorch.o 342153d32d365f0b 5 | 6 14058 1748187024415643408 /data/users/less/applied-ai/kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/sm100_gemm.o 6c5f77cfca7cfb81 6 | -------------------------------------------------------------------------------- /kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/build.ninja: -------------------------------------------------------------------------------- 1 | ninja_required_version = 1.3 2 | cxx = c++ 3 | nvcc = /usr/local/cuda-12.8/bin/nvcc 4 | 5 | cflags = -pthread -B /home/less/.conda/envs/pycutlass/compiler_compat -fno-strict-overflow -Wsign-compare -DNDEBUG -O2 -Wall -fPIC -O2 -isystem /home/less/.conda/envs/pycutlass/include -fPIC -O2 -isystem /home/less/.conda/envs/pycutlass/include -fPIC -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/home/less/local/cutlass40/include -I/home/less/local/cutlass40/tools/util/include -I/usr/local/cuda-12.8/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda-12.8/include -I/home/less/.conda/envs/pycutlass/include/python3.12 -c 6 | post_cflags = -O3 -std=c++17 -DCUTLASS_ARCH_MMA_SM100_SUPPORTED -DCUTE_SM100_ENABLED -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1018"' -DTORCH_EXTENSION_NAME=sm100_gemm 7 | cuda_cflags = -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/home/less/local/cutlass40/include -I/home/less/local/cutlass40/tools/util/include -I/usr/local/cuda-12.8/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include -I/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda-12.8/include -I/home/less/.conda/envs/pycutlass/include/python3.12 -c 8 | cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -std=c++17 --expt-relaxed-constexpr --expt-extended-lambda -gencode=arch=compute_100a,code=sm_100a -DCUTLASS_ARCH_MMA_SM100_SUPPORTED 
-DCUTE_SM100_ENABLED --use_fast_math -Xcompiler=-fPIC -DCUTE_ARCH_TCGEN05_TMEM_ENABLED=1 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1018"' -DTORCH_EXTENSION_NAME=sm100_gemm 9 | cuda_dlink_post_cflags = 10 | sycl_dlink_post_cflags = 11 | ldflags = 12 | 13 | rule compile 14 | command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags 15 | depfile = $out.d 16 | deps = gcc 17 | 18 | rule cuda_compile 19 | depfile = $out.d 20 | deps = gcc 21 | command = $nvcc --generate-dependencies-with-compile --dependency-output $out.d $cuda_cflags -c $in -o $out $cuda_post_cflags 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | build /data/users/less/applied-ai/kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/sm100_gemm.o: cuda_compile /data/users/less/applied-ai/kernels/blackwell/cute_gemm_02_tma/sm100_gemm.cu 30 | build /data/users/less/applied-ai/kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/sm100_gemm_pytorch.o: compile /data/users/less/applied-ai/kernels/blackwell/cute_gemm_02_tma/sm100_gemm_pytorch.cpp 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/sm100_gemm.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch-labs/applied-ai/4f8c39eb70d92e3848c65347ca077c8e723c74fd/kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/sm100_gemm.o -------------------------------------------------------------------------------- /kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/sm100_gemm_pytorch.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch-labs/applied-ai/4f8c39eb70d92e3848c65347ca077c8e723c74fd/kernels/blackwell/cute_gemm_02_tma/build/temp.linux-x86_64-cpython-312/sm100_gemm_pytorch.o -------------------------------------------------------------------------------- /kernels/blackwell/cute_gemm_02_tma/dist/sm100_gemm-0.0.0-py3.12-linux-x86_64.egg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch-labs/applied-ai/4f8c39eb70d92e3848c65347ca077c8e723c74fd/kernels/blackwell/cute_gemm_02_tma/dist/sm100_gemm-0.0.0-py3.12-linux-x86_64.egg -------------------------------------------------------------------------------- /kernels/blackwell/cute_gemm_02_tma/setup.py: -------------------------------------------------------------------------------- 1 | # setup.py 2 | import os 3 | 4 | import pybind11 5 | import torch 6 | from pybind11 import get_cmake_dir 7 | from pybind11.setup_helpers import build_ext, Pybind11Extension 8 | from setuptools import Extension, setup 9 | from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension 10 | 11 | # IMPORTANT: The following two lines are the only ones you need to change 12 | # Get CUTLASS path (you'll need to set this to your CUTLASS installation) 13 | CUTLASS_PATH = os.environ.get("CUTLASS_PATH", "/home/less/local/cutlas40") 14 | 15 | # CUDA and PyTorch paths 16 | cuda_home = torch.utils.cpp_extension.CUDA_HOME 17 | pytorch_includes = torch.utils.cpp_extension.include_paths() 18 | 19 | ext_modules = [ 20 | CUDAExtension( 21 | name="sm100_gemm", 22 | sources=[ 23 | "sm100_gemm_pytorch.cpp", # PyTorch bindings (C++) 24 | 
"sm100_gemm.cu", # CUDA kernel implementation 25 | ], 26 | include_dirs=[ 27 | # PyTorch includes 28 | *pytorch_includes, 29 | # CUTLASS includes 30 | f"{CUTLASS_PATH}/include", 31 | f"{CUTLASS_PATH}/tools/util/include", 32 | # CUDA includes 33 | f"{cuda_home}/include", 34 | ], 35 | library_dirs=[ 36 | f"{cuda_home}/lib64", 37 | ], 38 | libraries=["cuda", "cudart"], 39 | extra_compile_args={ 40 | "cxx": [ 41 | "-O3", 42 | "-std=c++17", 43 | "-DCUTLASS_ARCH_MMA_SM100_SUPPORTED", 44 | "-DCUTE_SM100_ENABLED", 45 | ], 46 | "nvcc": [ 47 | "-O3", 48 | "-std=c++17", 49 | "--expt-relaxed-constexpr", 50 | "--expt-extended-lambda", 51 | "-gencode=arch=compute_100a,code=sm_100a", # SM100 architecture 52 | "-DCUTLASS_ARCH_MMA_SM100_SUPPORTED", 53 | "-DCUTE_SM100_ENABLED", 54 | "--use_fast_math", 55 | "-Xcompiler=-fPIC", 56 | "-DCUTE_ARCH_TCGEN05_TMEM_ENABLED=1", # Enable TCGEN05_TMEM 57 | ], 58 | }, 59 | extra_link_args=["-lcuda", "-lcudart"], 60 | language="c++", 61 | ) 62 | ] 63 | 64 | setup( 65 | name="sm100_gemm", 66 | ext_modules=ext_modules, 67 | cmdclass={"build_ext": BuildExtension}, 68 | zip_safe=False, 69 | python_requires=">=3.8", 70 | install_requires=["torch>=1.12.0"], 71 | ) 72 | -------------------------------------------------------------------------------- /kernels/blackwell/cute_gemm_02_tma/sm100_gemm.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.4 2 | Name: sm100_gemm 3 | Version: 0.0.0 4 | Requires-Python: >=3.8 5 | Requires-Dist: torch>=1.12.0 6 | Dynamic: requires-dist 7 | Dynamic: requires-python 8 | -------------------------------------------------------------------------------- /kernels/blackwell/cute_gemm_02_tma/sm100_gemm.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | setup.py 2 | sm100_gemm.cu 3 | sm100_gemm_pytorch.cpp 4 | sm100_gemm.egg-info/PKG-INFO 5 | sm100_gemm.egg-info/SOURCES.txt 6 | sm100_gemm.egg-info/dependency_links.txt 7 | sm100_gemm.egg-info/not-zip-safe 8 | sm100_gemm.egg-info/requires.txt 9 | sm100_gemm.egg-info/top_level.txt -------------------------------------------------------------------------------- /kernels/blackwell/cute_gemm_02_tma/sm100_gemm.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /kernels/blackwell/cute_gemm_02_tma/sm100_gemm.egg-info/not-zip-safe: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /kernels/blackwell/cute_gemm_02_tma/sm100_gemm.egg-info/requires.txt: -------------------------------------------------------------------------------- 1 | torch>=1.12.0 2 | -------------------------------------------------------------------------------- /kernels/blackwell/cute_gemm_02_tma/sm100_gemm.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | sm100_gemm 2 | -------------------------------------------------------------------------------- /kernels/blackwell/cute_gemm_02_tma/sm100_gemm.h: -------------------------------------------------------------------------------- 1 | // sm100_gemm_kernel.h - Header file for CUDA kernel 2 | #pragma once 3 | 4 | #include 5 | 6 | #ifdef __cplusplus 7 | extern "C" { 8 | #endif 9 | 10 | /** 11 | * Launch SM100 GEMM kernel: D = alpha * A @ B^T + beta 
* C 12 | * 13 | * @param d_A Pointer to matrix A in device memory (M x K, FP16, K-major) 14 | * @param d_B Pointer to matrix B in device memory (N x K, FP16, K-major) 15 | * @param d_C Pointer to matrix C in device memory (M x N, FP32, N-major) 16 | * @param d_D Pointer to matrix D in device memory (M x N, FP32, N-major) 17 | * @param M Number of rows in A and C/D 18 | * @param N Number of rows in B and columns in C/D 19 | * @param K Number of columns in A and B 20 | * @param alpha Scaling factor for A @ B^T 21 | * @param beta Scaling factor for C 22 | * @param stream CUDA stream (currently unused, for future async support) 23 | * 24 | * @return cudaSuccess on success, error code otherwise 25 | * 26 | * Requirements: 27 | * - M must be multiple of 128 28 | * - N must be multiple of 256 29 | * - K must be multiple of 64 30 | * - All pointers must be valid device memory 31 | * - Tensors must be contiguous with specified layouts 32 | */ 33 | cudaError_t launch_sm100_gemm_f16_tma(void *d_A, void *d_B, void *d_C, 34 | void *d_D, int M, int N, int K, 35 | float alpha, float beta, 36 | cudaStream_t stream = 0); 37 | 38 | #ifdef __cplusplus 39 | } 40 | #endif 41 | -------------------------------------------------------------------------------- /kernels/blackwell/cute_gemm_02_tma/sm100_gemm_pytorch.cpp: -------------------------------------------------------------------------------- 1 | // sm100_gemm_pytorch.cpp - PyTorch C++ extension (no CUDA code) 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include "sm100_gemm.h" 9 | 10 | // Check if SM100 support is available at compile time 11 | bool is_sm100_supported() { 12 | #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) 13 | return true; 14 | #else 15 | return false; 16 | #endif 17 | } 18 | 19 | // Check if current GPU supports SM100 at runtime 20 | bool check_sm100_device() { 21 | int device; 22 | cudaGetDevice(&device); 23 | 24 | cudaDeviceProp props; 25 | cudaError_t error = cudaGetDeviceProperties(&props, device); 26 | if (error != cudaSuccess) { 27 | return false; 28 | } 29 | 30 | // Check for SM100 architecture (compute capability 10.0a) 31 | return (props.major == 10 && props.minor == 0); 32 | } 33 | 34 | torch::Tensor sm100_gemm_f16(const torch::Tensor &A, const torch::Tensor &B, 35 | const torch::Tensor &C, float alpha = 1.0f, 36 | float beta = 0.0f) { 37 | 38 | // Check compile-time support 39 | TORCH_CHECK( 40 | is_sm100_supported(), 41 | "SM100 support not compiled. 
Requires CUTLASS_ARCH_MMA_SM100_SUPPORTED"); 42 | 43 | // Check runtime device support 44 | TORCH_CHECK(check_sm100_device(), 45 | "Current GPU does not support SM100 architecture (requires " 46 | "compute capability 10.0a)"); 47 | 48 | // Input validation 49 | TORCH_CHECK(A.device().is_cuda(), "A must be a CUDA tensor"); 50 | TORCH_CHECK(B.device().is_cuda(), "B must be a CUDA tensor"); 51 | TORCH_CHECK(C.device().is_cuda(), "C must be a CUDA tensor"); 52 | TORCH_CHECK(A.dtype() == torch::kFloat16, "A must be float16"); 53 | TORCH_CHECK(B.dtype() == torch::kFloat16, "B must be float16"); 54 | TORCH_CHECK(C.dtype() == torch::kFloat32, "C must be float32"); 55 | TORCH_CHECK(A.is_contiguous(), "A must be contiguous"); 56 | TORCH_CHECK(B.is_contiguous(), "B must be contiguous"); 57 | TORCH_CHECK(C.is_contiguous(), "C must be contiguous"); 58 | TORCH_CHECK(A.dim() == 2, "A must be 2D"); 59 | TORCH_CHECK(B.dim() == 2, "B must be 2D"); 60 | TORCH_CHECK(C.dim() == 2, "C must be 2D"); 61 | 62 | // Get dimensions 63 | int64_t M = A.size(0); 64 | int64_t K = A.size(1); 65 | int64_t N = B.size(0); 66 | int64_t K_B = B.size(1); 67 | 68 | TORCH_CHECK(K == K_B, "Inner dimensions must match: A.shape[1]=", K, 69 | ", B.shape[1]=", K_B); 70 | TORCH_CHECK(C.size(0) == M && C.size(1) == N, "C dimensions (", C.size(0), 71 | ", ", C.size(1), ") must match output shape (", M, ", ", N, ")"); 72 | 73 | // Check alignment requirements for SM100 74 | TORCH_CHECK(M % 128 == 0, "M=", M, " must be multiple of 128"); 75 | TORCH_CHECK(N % 256 == 0, "N=", N, " must be multiple of 256"); 76 | TORCH_CHECK(K % 64 == 0, "K=", K, " must be multiple of 64"); 77 | 78 | // Check size limits (avoid overflow in int conversion) 79 | TORCH_CHECK(M <= INT_MAX && N <= INT_MAX && K <= INT_MAX, 80 | "Dimensions too large for int conversion"); 81 | 82 | // Create output tensor 83 | auto D = torch::empty_like(C); 84 | 85 | // Set CUDA device guard 86 | const auto device = A.device(); 87 | c10::cuda::CUDAGuard device_guard(device); 88 | 89 | // Get current CUDA stream 90 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(device.index()).stream(); 91 | 92 | // Launch the kernel 93 | cudaError_t error = launch_sm100_gemm_f16_tma( 94 | A.data_ptr(), B.data_ptr(), C.data_ptr(), D.data_ptr(), 95 | static_cast(M), static_cast(N), static_cast(K), alpha, 96 | beta, stream); 97 | 98 | // Check for launch errors 99 | TORCH_CHECK(error == cudaSuccess, 100 | "SM100 GEMM kernel launch failed: ", cudaGetErrorString(error)); 101 | 102 | // Check for kernel execution errors 103 | C10_CUDA_CHECK(cudaGetLastError()); 104 | 105 | return D; 106 | } 107 | 108 | // Utility functions for debugging and information 109 | torch::Tensor get_device_info() { 110 | int device; 111 | cudaGetDevice(&device); 112 | 113 | cudaDeviceProp props; 114 | cudaGetDeviceProperties(&props, device); 115 | 116 | // Return device info as a tensor (for easy Python access) 117 | auto info = torch::zeros({4}, torch::kInt32); 118 | auto accessor = info.accessor(); 119 | 120 | accessor[0] = props.major; // Compute capability major 121 | accessor[1] = props.minor; // Compute capability minor 122 | accessor[2] = is_sm100_supported(); // Compile-time support 123 | accessor[3] = check_sm100_device(); // Runtime device support 124 | 125 | return info; 126 | } 127 | 128 | std::vector get_aligned_shape(int64_t M, int64_t N, int64_t K) { 129 | // Return properly aligned dimensions for SM100 130 | int64_t aligned_M = ((M + 127) / 128) * 128; 131 | int64_t aligned_N = ((N + 255) / 256) * 256; 
132 | int64_t aligned_K = ((K + 63) / 64) * 64; 133 | 134 | return {aligned_M, aligned_N, aligned_K}; 135 | } 136 | 137 | torch::Tensor create_aligned_tensor(const std::vector &shape, 138 | torch::ScalarType dtype, 139 | torch::Device device) { 140 | // Create a tensor with SM100-aligned dimensions 141 | TORCH_CHECK(shape.size() == 2, "Shape must be 2D"); 142 | 143 | auto aligned_shape = 144 | get_aligned_shape(shape[0], shape[1], shape.size() > 2 ? shape[2] : 64); 145 | 146 | if (shape.size() == 2) { 147 | return torch::zeros({aligned_shape[0], aligned_shape[1]}, 148 | torch::TensorOptions().dtype(dtype).device(device)); 149 | } else { 150 | return torch::zeros({aligned_shape[0], aligned_shape[2]}, 151 | torch::TensorOptions().dtype(dtype).device(device)); 152 | } 153 | } 154 | 155 | // Python bindings 156 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 157 | m.doc() = "SM100 GEMM PyTorch Extension"; 158 | 159 | // Main GEMM function 160 | m.def("sm100_gemm_f16", &sm100_gemm_f16, 161 | "SM100 GEMM with FP16 inputs and FP32 output: D = alpha * A @ B^T + " 162 | "beta * C", 163 | py::arg("A"), py::arg("B"), py::arg("C"), py::arg("alpha") = 1.0f, 164 | py::arg("beta") = 0.0f); 165 | 166 | // Utility functions 167 | m.def("is_sm100_supported", &is_sm100_supported, 168 | "Check if SM100 support was compiled in"); 169 | 170 | m.def("check_sm100_device", &check_sm100_device, 171 | "Check if current GPU supports SM100 architecture"); 172 | 173 | m.def("get_device_info", &get_device_info, 174 | "Get device compute capability and SM100 support info"); 175 | 176 | m.def("get_aligned_shape", &get_aligned_shape, 177 | "Get SM100-aligned dimensions for given shape", py::arg("M"), 178 | py::arg("N"), py::arg("K")); 179 | 180 | m.def("create_aligned_tensor", &create_aligned_tensor, 181 | "Create tensor with SM100-aligned dimensions", py::arg("shape"), 182 | py::arg("dtype"), py::arg("device")); 183 | 184 | // Constants for alignment requirements 185 | m.attr("MMA_TILE_M") = 128; 186 | m.attr("MMA_TILE_N") = 256; 187 | m.attr("MMA_TILE_K") = 64; 188 | } 189 | -------------------------------------------------------------------------------- /kernels/cuda/cutlass_gemm/common.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cutlass/cutlass.h" 4 | #include 5 | 6 | /** 7 | * Helper function for checking CUTLASS errors 8 | */ 9 | #define CUTLASS_CHECK(status) \ 10 | { \ 11 | TORCH_CHECK(status == cutlass::Status::kSuccess, \ 12 | cutlassGetStatusString(status)) \ 13 | } 14 | 15 | inline uint32_t next_pow_2(uint32_t const num) { 16 | if (num <= 1) return num; 17 | return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); 18 | } 19 | 20 | inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) { 21 | int max_shared_mem_per_block_opt_in = 0; 22 | cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in, 23 | cudaDevAttrMaxSharedMemoryPerBlockOptin, 24 | device); 25 | return max_shared_mem_per_block_opt_in; 26 | } 27 | -------------------------------------------------------------------------------- /kernels/cuda/cutlass_gemm/cutlass.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | void cutlass_scaled_mm_sm90(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, 5 | torch::Tensor const& a_scales, torch::Tensor const& b_scales); 6 | 7 | torch::Tensor cutlass_scaled_mm(torch::Tensor a, torch::Tensor b, torch::Tensor a_scales, torch::Tensor 
b_scales) { 8 | 9 | auto acc_dtype = torch::kFloat16; 10 | auto options = torch::TensorOptions().dtype(acc_dtype).device(a.device()); 11 | torch::Tensor out = torch::empty({a.size(0), b.size(1)}, options); 12 | 13 | cutlass_scaled_mm_sm90(out, a, b, a_scales, b_scales); 14 | return out; 15 | } 16 | 17 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 18 | m.def("cutlass_scaled_mm", &cutlass_scaled_mm, "CUTLASS Scaled Matrix Multiplication"); 19 | } -------------------------------------------------------------------------------- /kernels/cuda/cutlass_gemm/readme.md: -------------------------------------------------------------------------------- 1 | Currently the CPP extension builds with Cutlass 3.5.1 (credit to @SamirMoustafa for the update). 2 | 3.6 will fail atm due to a refactor in the TMA descriptor. 3 | -------------------------------------------------------------------------------- /kernels/cuda/cutlass_gemm/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='cutlass_gemm', 6 | ext_modules=[ 7 | CUDAExtension( 8 | name='pingpong_gemm', 9 | sources=['cutlass.cpp', 'cutlass_kernel.cu'], 10 | extra_compile_args={ 11 | 'nvcc': [ 12 | '-DNDEBUG', 13 | '-O3', 14 | '-g', 15 | '-lineinfo', 16 | '--keep', 17 | '--ptxas-options=--warn-on-local-memory-usage', 18 | '--ptxas-options=--warn-on-spills', 19 | '--resource-usage', 20 | '--source-in-ptx', 21 | '-DCUTLASS_DEBUG_TRACE_LEVEL=1', 22 | '-gencode=arch=compute_90a, code=sm_90a', 23 | ] 24 | }, 25 | include_dirs=[ 26 | '/home/adhoq26/cutlass/include', 27 | '/home/adhoq26/cutlass/tools/util/include', 28 | ], 29 | libraries=['cuda'], 30 | library_dirs=['/usr/local/cuda-12.4/lib64'], 31 | ) 32 | ], 33 | cmdclass={ 34 | 'build_ext': BuildExtension 35 | } 36 | ) -------------------------------------------------------------------------------- /kernels/cuda/cutlass_gemm/test_cutlass_gemm.py: -------------------------------------------------------------------------------- 1 | from pingpong_gemm import cutlass_scaled_mm 2 | import torch 3 | 4 | m, k, n = 16, 4096, 4096 5 | dtype = torch.float8_e4m3fn 6 | out_dtype = torch.float16 7 | 8 | a = torch.empty(m, k).normal_(mean=0.0, std=0.5).to(dtype=dtype, device='cuda') 9 | bt = torch.empty(n, k).normal_(mean=0.0, std=0.5).to(dtype=dtype, device='cuda').t() 10 | scale_a = torch.ones((1,)).to(dtype=torch.float32, device='cuda') 11 | scale_b = torch.ones((1,)).to(dtype=torch.float32, device='cuda') 12 | y = cutlass_scaled_mm(a, bt, scale_a, scale_b) 13 | print(y) -------------------------------------------------------------------------------- /kernels/cuda/inference/README.md: -------------------------------------------------------------------------------- 1 | cuda kernels 2 | -------------------------------------------------------------------------------- /kernels/cuda/inference/hadamard_transform/hadamard_transform.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | using namespace torch::indexing; 7 | 8 | template 9 | void run_fht(void* a, void* out, uint32_t numel, uint32_t had_size, cudaStream_t stream); 10 | 11 | constexpr bool is_power_of_two(uint32_t x) { 12 | return x && !(x & (x - 1)); 13 | } 14 | 15 | torch::Tensor hadamard_transform(at::Tensor& in, bool inplace) { 16 | auto dtype = in.scalar_type(); 17 | TORCH_CHECK(dtype == 
torch::ScalarType::Half || dtype == torch::ScalarType::BFloat16, "Only fp16 and bf16 supported currently"); 18 | TORCH_CHECK(in.is_cuda()); 19 | 20 | const int had_size = in.size(-1); 21 | TORCH_CHECK(is_power_of_two(had_size) && (had_size <= (1U << 15)), 22 | "Only power of two Hadamard sizes up to 2^15 are supported, got ", had_size); 23 | 24 | const auto res_shape = in.sizes(); 25 | torch::Tensor x = in.reshape({-1, had_size}); 26 | 27 | auto numel = in.numel(); 28 | if (numel % 256 != 0) { 29 | x = torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions({0, 0, 0, (256 - numel % 256) / had_size})); 30 | } 31 | 32 | if (x.stride(-1) != 1) { 33 | x = x.contiguous(); 34 | } 35 | torch::Tensor out = inplace ? x : torch::empty_like(x); 36 | 37 | at::cuda::CUDAGuard device_guard{(char)x.get_device()}; 38 | auto stream = at::cuda::getCurrentCUDAStream().stream(); 39 | 40 | if (dtype == torch::ScalarType::Half) { 41 | run_fht(x.data_ptr(), out.data_ptr(), x.numel(), had_size, stream); 42 | } else { 43 | run_fht(x.data_ptr(), out.data_ptr(), x.numel(), had_size, stream); 44 | } 45 | 46 | if (numel % 256 != 0) { 47 | out = out.index({Slice(0, numel / had_size)}); 48 | } 49 | 50 | if (inplace && out.data_ptr() != in.data_ptr()) { 51 | in.copy_(out.view(res_shape)); 52 | return in; 53 | } 54 | return out.reshape(res_shape); 55 | } 56 | 57 | namespace py = pybind11; 58 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 59 | m.def("hadamard_transform", &hadamard_transform, "A function to perform a fast Hadamard transform", py::arg("x"), py::arg("inplace")=false); 60 | } -------------------------------------------------------------------------------- /kernels/cuda/inference/hadamard_transform/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | versions = [ 5 | "-gencode", 6 | "arch=compute_80,code=sm_80", 7 | "-gencode", 8 | "arch=compute_89,code=sm_89", 9 | "-gencode", 10 | "arch=compute_90,code=sm_90", 11 | ] # TODO: assumes installed CUDA toolkit supports sm_80 to sm_90 12 | 13 | setup( 14 | name='faster_hadamard_transform', 15 | ext_modules=[ 16 | CUDAExtension( 17 | name="faster_hadamard_transform", 18 | sources=[ 19 | "hadamard_transform.cpp", 20 | "hadamard_transform_cuda.cu", 21 | ], 22 | extra_compile_args={ 23 | "cxx": ["-O3"], 24 | "nvcc": [ 25 | "-O3", 26 | "-lineinfo", 27 | '--ptxas-options=--warn-on-local-memory-usage', 28 | '--ptxas-options=--warn-on-spills', 29 | ] + versions 30 | } 31 | ), 32 | ], 33 | cmdclass={ 34 | 'build_ext': BuildExtension 35 | } 36 | ) -------------------------------------------------------------------------------- /kernels/cuda/inference/hadamard_transform/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import faster_hadamard_transform 3 | import scipy.linalg 4 | import math 5 | 6 | # set to false to check performance 7 | correctness_check = True 8 | # set to warmup count + 1 to check performance 9 | # for quick testing, 2 is good. 
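# An optional way to time a single transform with CUDA events instead of Nsight
# Compute (a minimal sketch; `x_fp16` stands for any fp16 CUDA tensor, not a name
# used elsewhere in this script):
#   start_evt = torch.cuda.Event(enable_timing=True)
#   end_evt = torch.cuda.Event(enable_timing=True)
#   start_evt.record()
#   faster_hadamard_transform.hadamard_transform(x_fp16, inplace=True)
#   end_evt.record()
#   torch.cuda.synchronize()
#   print("elapsed ms:", start_evt.elapsed_time(end_evt))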
10 | runs_per_size = 2 11 | 12 | # hadamard sizes 13 | test_sizes_m = [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768] 14 | 15 | test_elem_counts = [1 << i for i in range(9, 26, 1)] # 32MB # 64MB # 2**28 = 256M 16 | 17 | print("test_sizes_m: ", test_sizes_m) 18 | print("test_elem_counts: ", test_elem_counts) 19 | 20 | test_count = len(test_sizes_m) * len(test_elem_counts) 21 | tests_done = 0 22 | failed_tests = 0 23 | 24 | def get_scale(size): 25 | return math.sqrt(1 / size) 26 | 27 | truth_hadamards = [torch.tensor(scipy.linalg.hadamard(size), device='cuda', dtype=torch.float32) * get_scale(size) for size in test_sizes_m] 28 | truth_hadamards = [(x.to(torch.float16), x.to(torch.bfloat16)) for x in truth_hadamards] 29 | truth_hadamards_fp16, truth_hadamards_bf16 = zip(*truth_hadamards) 30 | truth_hadamards_fp16 = list(truth_hadamards_fp16) 31 | truth_hadamards_bf16 = list(truth_hadamards_bf16) 32 | del truth_hadamards 33 | 34 | def truth_hadamard_transform_inplace(a: torch.Tensor, truth_hadamards): 35 | target_index = -1 36 | for i in range(len(test_sizes_m)): 37 | if test_sizes_m[i] == a.shape[1]: 38 | target_index = i 39 | break 40 | return a @ truth_hadamards[int(target_index)] 41 | 42 | def test_hadamard_transform_inplace_rowmajor(a: torch.Tensor): 43 | faster_hadamard_transform.hadamard_transform(a, inplace=True) 44 | return a 45 | 46 | torch.manual_seed(0) 47 | 48 | def check_correctness(m, elem_c, a, result, truth, atol=1e-2, rtol=0): 49 | success = torch.allclose(truth, result, atol=atol, rtol=rtol) 50 | 51 | if not success: 52 | torch.set_printoptions(threshold=100) 53 | print(f'Failed test: {m}x{elem_c // m}') 54 | print(f'Input:') 55 | print(a) 56 | print(f'Expected:') 57 | print(truth) 58 | print(f'Got:') 59 | print(result) 60 | # worst element 61 | diff = torch.abs(truth - result) 62 | max_diff = torch.max(diff) 63 | print(f'Max diff: {max_diff}') 64 | print(f'Max diff index: {torch.argmax(diff)}') 65 | diff_input = torch.abs(a - result) 66 | max_diff_input = torch.max(diff_input) 67 | print(f'Max diff input: {max_diff_input}') 68 | print('') 69 | exit(1) 70 | 71 | for m in test_sizes_m: 72 | for elem_c in test_elem_counts: 73 | if elem_c < m: 74 | tests_done += runs_per_size 75 | if tests_done % 100 == 0 or tests_done == test_count: 76 | print(f'{tests_done}/{test_count} tests done') 77 | continue 78 | print(f'Testing size {m}x{elem_c // m}') 79 | 80 | a = torch.randn((elem_c // m, m), device='cuda', dtype=torch.float32) 81 | # a = torch.zeros((m, elem_c // m), device='cuda', dtype=torch.float16) 82 | # for i in range(min(a.shape[0], a.shape[1])): 83 | # a[i, i] = 1.0 84 | if correctness_check: 85 | for i in range(runs_per_size): 86 | # run test here 87 | a_result_fp16 = a.clone().to(torch.float16) 88 | a_truth_fp16 = a.clone().to(torch.float16) 89 | result_fp16 = test_hadamard_transform_inplace_rowmajor(a_result_fp16) 90 | truth_fp16 = truth_hadamard_transform_inplace(a_truth_fp16, truth_hadamards_fp16) 91 | check_correctness(m, elem_c, a, result_fp16, truth_fp16, atol=1e-2) # TODO: NOTE: we are not accurate down to 3 decimal places (atol) 92 | 93 | a_result_bf16 = a.clone().to(torch.bfloat16) 94 | a_truth_bf16 = a.clone().to(torch.bfloat16) 95 | result_bf16 = test_hadamard_transform_inplace_rowmajor(a_result_bf16) 96 | truth_bf16 = truth_hadamard_transform_inplace(a_truth_bf16, truth_hadamards_bf16) 97 | check_correctness(m, elem_c, a, result_bf16, truth_bf16, atol=5e-2) # TODO: NOTE: need 5x atol to pass for bf16 98 | else: 99 | # run in a row so 
that warmup is valid
100 |             a_result = a.to(torch.float16)  # we can clobber the result cause we are only interested in timing (the kernel path requires fp16/bf16)
101 |             for i in range(runs_per_size):
102 |                 a_result = test_hadamard_transform_inplace_rowmajor(a_result)
103 |             a_truth = a.to(torch.float16)
104 |             for i in range(runs_per_size):
105 |                 a_truth = truth_hadamard_transform_inplace(a_truth, truth_hadamards_fp16)
106 |             a_memcpy = a
107 |             # also can compare timing to memcpy
108 |             temp = torch.empty_like(a)
109 |             for i in range(runs_per_size):
110 |                 temp.copy_(a_memcpy)
111 |             # do nothing with results since we are only interested in timing
112 |             # NOTE: make sure to disable clearing cache in Nsight Compute
113 |
114 |         tests_done += 1
115 |         if tests_done % 100 == 0 or tests_done == test_count:
116 |             print(f'{tests_done}/{test_count} size tests done')
--------------------------------------------------------------------------------
/kernels/cuda/training/README.md:
--------------------------------------------------------------------------------
1 | kernels with backward pass support
2 |
--------------------------------------------------------------------------------
/kernels/cuda/tutorials/README.md:
--------------------------------------------------------------------------------
1 | CUDA tutorials
2 |
--------------------------------------------------------------------------------
/kernels/cuda/tutorials/flash2.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) Meta Platforms, Inc. and affiliates.
2 | // All rights reserved.
3 | //
4 | // This source code is licensed under the BSD-style license found in the
5 | // LICENSE file in the root directory of this source tree.
6 |
7 |
8 | // flash2
9 |
10 | __global__
11 | void forward_kernel(const float* Q, const float* K, const float* V, const int N, const int d,
12 |                     const int Tc, const int Tr, const int Bc, const int Br, const float sm_scale,
13 |                     float* l, float* m, float* O)
14 | {
15 |     int tidx = threadIdx.x;
16 |     int bidx = blockIdx.x;  // batch index
17 |     int bidy = blockIdx.y;  // head index
18 |
19 |     int qkv_offset = (bidx * gridDim.y * N * d) + (bidy * N * d);
20 |     int lm_offset = (bidx * gridDim.y * N) + (bidy * N);  // l and m offset
21 |
22 |     extern __shared__ float sram[];
23 |     int tile_size = Bc * d;  // size of Qi, Kj, Vj
24 |
25 |     float* Qi = sram;
26 |     float* Kj = &sram[tile_size];
27 |     float* Vj = &sram[tile_size * 2];
28 |     float* S = &sram[tile_size * 3];
29 |
30 |     for (int j = 0; j < Tc; j++) {
31 |
32 |         // Load Kj, Vj to SRAM
33 |         for (int x = 0; x < d; x++) {
34 |             Kj[(tidx * d) + x] = K[qkv_offset + (tile_size * j) + (tidx * d) + x];
35 |             Vj[(tidx * d) + x] = V[qkv_offset + (tile_size * j) + (tidx * d) + x];
36 |         }
37 |         __syncthreads();  // such that the inner loop can use the correct Kj, Vj
38 |
39 |     }
40 | }
--------------------------------------------------------------------------------
/kernels/needs_perf_help/fp8_gemm_bench.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
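# The benchmarks below report TFLOPS computed as (2 * m * n * k) / elapsed time,
# comparing a bf16 torch.matmul baseline against rowwise / blockwise fp8
# quantize + GEMM variants from the fp8_gemm_rowwise module (assumed to be
# importable from this directory).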
6 | 7 | # pyre-strict 8 | 9 | from typing import Callable, Tuple 10 | 11 | #import click 12 | 13 | import torch 14 | import triton # @manual 15 | 16 | from fp8_gemm_rowwise import ( 17 | matmul_fp8_block, 18 | matmul_fp8_row, 19 | quantize_fp8_block, 20 | quantize_fp8_row, 21 | ) 22 | from torch._tensor import Tensor 23 | 24 | 25 | #@click.command() 26 | #@click.option("--cuda-graph", type=bool, default=True) 27 | #@click.option("--rowwise-tma", is_flag=True, default=False) 28 | def bench(cuda_graph: bool, rowwise_tma: bool=True) -> None: 29 | """Benchmark bf16 vs scale/cast + fp8.""" 30 | 31 | def _run_benchmark( 32 | bench_factory: Callable[ 33 | [torch.Tensor, torch.Tensor], Callable[[], torch.Tensor] 34 | ], 35 | shape: Tuple[int, int, int] = (1024, 1024, 1024), 36 | tag: str = "", 37 | ) -> None: 38 | # Benchmarks the function returned by bench_factory. 39 | # Any pre-processing that should not be benchmarked can occur inside bench_factory. 40 | m, n, k = shape 41 | 42 | input_shape = (m, k) 43 | weight_shape = (n, k) 44 | 45 | base_dtype = torch.bfloat16 46 | input_ = torch.randn(input_shape, device="cuda", dtype=base_dtype) 47 | weight_ = torch.randn(weight_shape, device="cuda", dtype=base_dtype) 48 | 49 | gemm_fn = bench_factory(input_, weight_) 50 | 51 | if cuda_graph: 52 | bench_stream = torch.cuda.Stream() 53 | with torch.cuda.stream(bench_stream): 54 | ms = triton.testing.do_bench_cudagraph( 55 | lambda: gemm_fn(), 56 | rep=100, 57 | ) 58 | else: 59 | ms = triton.testing.do_bench( 60 | lambda: gemm_fn(), 61 | warmup=25, 62 | rep=100, 63 | ) 64 | 65 | tflops = (2 * m * n * k) / 1e12 66 | sec = ms / 1e3 67 | perf_str = f"{tflops / sec:.2f}" 68 | print( 69 | f"{(tag + ':').ljust(40)}\tshape {str(shape):<25} tflops {perf_str:<8} ms {ms:.3f}" 70 | ) 71 | 72 | shapes = [ 73 | (8192, 8192, 512), 74 | (8192, 8192, 8192), 75 | (65536, 8192, 7168), 76 | (65536, 3584, 8192), 77 | (8192, 14336, 4096), 78 | ] 79 | for shape in shapes: 80 | _run_benchmark(bf16_bench, shape=shape, tag="bf16") 81 | _run_benchmark(scale_row_bench, shape=shape, tag="fp8 scale + row gemm") 82 | _run_benchmark(scale_block_bench, shape=shape, tag="fp8 scale + block gemm") 83 | _run_benchmark( 84 | row_gemm_bench, 85 | shape=shape, 86 | tag="fp8 row gemm only | fp8_fast_accum=True", 87 | ) 88 | _run_benchmark( 89 | row_gemm_bench_no_fast_acc, 90 | shape=shape, 91 | tag="fp8 row gemm only | fp8_fast_accum=False", 92 | ) 93 | _run_benchmark( 94 | row_gemm_bench_imprecise_acc, 95 | shape=shape, 96 | tag="fp8 row gemm only | max_num_imprecise_acc=32", 97 | ) 98 | _run_benchmark(block_gemm_bench, shape=shape, tag="fp8 block gemm only") 99 | if rowwise_tma: 100 | _run_benchmark( 101 | row_gemm_bench_tma, 102 | shape=shape, 103 | tag="fp8 row gemm only | fp8_fast_accum=True | tma_persistent=True", 104 | ) 105 | 106 | 107 | def bf16_bench(x: Tensor, w: Tensor) -> Callable[[], Tensor]: 108 | def gemm_fn() -> Tensor: 109 | return torch.matmul(x, w.T) 110 | 111 | return gemm_fn 112 | 113 | 114 | def scale_row_bench(x: Tensor, w: Tensor) -> Callable[[], Tensor]: 115 | # Benchmark quantize(x) + gemm for inference. 
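    # Note: unlike row_gemm_bench below, the per-row quantization of x and w happens
    # inside run_gemm here, so the measured time includes the quantize_fp8_row
    # overhead as well as the GEMM itself.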
116 | def run_gemm() -> Tensor: 117 | x_fp8: Tensor 118 | w_fp8: Tensor 119 | x_scale: Tensor 120 | w_scale: Tensor 121 | x_fp8, x_scale = quantize_fp8_row(x) 122 | w_fp8, w_scale = quantize_fp8_row(w) 123 | return matmul_fp8_row( 124 | x_fp8, 125 | w_fp8, 126 | x_scale, 127 | w_scale, 128 | dot_out_dtype=torch.float32, 129 | allow_tf32=True, 130 | fp8_fast_accum=True, 131 | ) 132 | 133 | return run_gemm 134 | 135 | 136 | def row_gemm_bench(x: Tensor, w: Tensor) -> Callable[[], Tensor]: 137 | # Benchmark only row-wise gemm, caching scaling. 138 | x_fp8: Tensor 139 | w_fp8: Tensor 140 | x_scale: Tensor 141 | w_scale: Tensor 142 | x_fp8, x_scale = quantize_fp8_row(x) 143 | w_fp8, w_scale = quantize_fp8_row(w) 144 | 145 | def run_gemm() -> Tensor: 146 | return matmul_fp8_row( 147 | x_fp8, 148 | w_fp8, 149 | x_scale, 150 | w_scale, 151 | dot_out_dtype=torch.float32, 152 | allow_tf32=True, 153 | fp8_fast_accum=True, 154 | ) 155 | 156 | return run_gemm 157 | 158 | 159 | def row_gemm_bench_tma(x: Tensor, w: Tensor) -> Callable[[], Tensor]: 160 | # Benchmark only row-wise gemm with TMA persistent 161 | x_fp8: Tensor 162 | w_fp8: Tensor 163 | x_scale: Tensor 164 | w_scale: Tensor 165 | x_fp8, x_scale = quantize_fp8_row(x) 166 | w_fp8, w_scale = quantize_fp8_row(w) 167 | 168 | def run_gemm() -> Tensor: 169 | return matmul_fp8_row( 170 | x_fp8, 171 | w_fp8, 172 | x_scale, 173 | w_scale, 174 | dot_out_dtype=torch.float32, 175 | allow_tf32=True, 176 | fp8_fast_accum=True, 177 | tma_persistent=True, 178 | ) 179 | 180 | return run_gemm 181 | -------------------------------------------------------------------------------- /kernels/triton/inference/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch-labs/applied-ai/4f8c39eb70d92e3848c65347ca077c8e723c74fd/kernels/triton/inference/.DS_Store -------------------------------------------------------------------------------- /kernels/triton/inference/README.md: -------------------------------------------------------------------------------- 1 | Triton Inference kernels 2 | -------------------------------------------------------------------------------- /kernels/triton/inference/col_major_moe_gemm/README.md: -------------------------------------------------------------------------------- 1 | 2 | **MoE (Mixture of Experts) GEMM Kernels** 3 | 4 | 5 | Triton kernel supporting and accelerating MoE inference (Mixtral). 6 | This kernel was contributed by IBM Research. 7 | 8 | This kernel showcases the following optimizations: 9 | 10 | * Column-Major Launch Schedule (L2 Cache Optimization) 11 | * SplitK Work Decomposition (Parallel Work Strategy Optimization) 12 | 13 | See blog post: https://pytorch.org/blog/accelerating-moe-model/ 14 | 15 | 16 | * v0 = grouped MM 17 | * v1 = SplitK MM 18 | * v2 = Col Major MM 19 | 20 | This requires vLLM to be installed to run. 21 | -------------------------------------------------------------------------------- /kernels/triton/inference/col_major_moe_gemm/perf_test_moe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
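# NOTE: per the README in this directory, vLLM must be installed so that the
# fused_moe and SiluAndMul imports below resolve.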
6 |
7 | import pytest
8 | import torch
9 | import triton
10 | from vllm.model_executor.layers.fused_moe import fused_moe
11 | from vllm.model_executor.layers.activation import SiluAndMul
12 | from v0_moe_fused import fused_moe as fused_moe_grouped
13 | from v2_moe_fused import fused_moe as fused_moe_col
14 | import time
15 |
16 | def torch_moe(a, w1, w2, topk_weight, topk_ids):
17 |     B, D = a.shape
18 |     a = a.view(B, -1, D).repeat(1, topk_ids.shape[1], 1).reshape(-1, D)
19 |     out = torch.zeros(B * topk_ids.shape[1],
20 |                       w2.shape[1],
21 |                       dtype=a.dtype,
22 |                       device=a.device)
23 |
24 |     topk_ids = topk_ids.view(-1)
25 |     topk_weight = topk_weight.view(-1)
26 |     for i in range(w1.shape[0]):
27 |         mask = topk_ids == i
28 |         if mask.sum():
29 |             out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
30 |     return (out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1)).sum(dim=1)
31 |
32 |
33 | def test_fused_moe(
34 |     m: int,
35 |     n: int,
36 |     k: int,
37 |     e: int,
38 |     topk: int,
39 |     dtype: torch.dtype,
40 | ):
41 |     torch.cuda.manual_seed(3227)
42 |
43 |     a = torch.randn((m, k), device='cuda', dtype=dtype) / 10
44 |     w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10
45 |     w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10
46 |
47 |     score = torch.randn((m, e), device='cuda', dtype=dtype)
48 |     score = torch.softmax(score, dim=-1)
49 |
50 |     topk_weight, topk_ids = torch.topk(score, topk)
51 |
52 |     start = time.time()
53 |     triton_output_gl = fused_moe_grouped(a, w1, w2, topk_weight, topk_ids, False)
54 |     end = time.time()
55 |     gl_time = end - start
56 |     gl_time = gl_time * 1000
57 |     print("Grouped Launch Time (ms): ", gl_time)
58 |
59 |     start = time.time()
60 |     triton_output_cm = fused_moe_col(a, w1, w2, topk_weight, topk_ids, False)
61 |     end = time.time()
62 |     cm_major_time = end - start
63 |     cm_major_time = cm_major_time * 1000
64 |     print("Column Major Time (ms): ", cm_major_time)
65 |
66 |     torch_base = torch_moe(a, w1, w2, topk_weight, topk_ids)
67 |     torch.testing.assert_close(triton_output_cm, torch_base, atol=1e-2, rtol=0)
68 |
69 |     # print(f"{triton_output_cm=}\n")
70 |     # print(f"{triton_output_gl=}\n")
71 |
72 |     print(f"Col Major Speedup: {((gl_time - cm_major_time) / gl_time) * 100:.1f}%")
73 |
74 |
75 | if __name__ == '__main__':
76 |
77 |
78 |     # test_fused_moe(512, 14336//2, 4096, 8, 2, torch.float16)
79 |
80 |     @triton.testing.perf_report(
81 |         triton.testing.Benchmark(
82 |             x_names=['m'],  # Argument names to use as an x-axis for the plot
83 |             x_vals=[
84 |                 2**i for i in range(0, 10)
85 |             ],  # Different possible values for `x_name`
86 |             line_arg='provider',  # Argument name whose value corresponds to a different line in the plot
87 |             # Possible values for `line_arg`
88 |             line_vals=['cm', 'gl'],
89 |             # Label name for the lines
90 |             line_names=["Fused MoE GEMM Kernel - Column Major", "vLLM MoE GEMM Kernel"],
91 |
92 |             # Line styles
93 |             styles=[('blue', '-'), ('green', '-')],
94 |             ylabel="TFLOPS",  # Label name for the y-axis
95 |             plot_name="test",  # Name for the plot, used also as a file name for saving the plot.
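            # The sweep runs m = 1..512 (powers of two) at fixed n = 14336//2, k = 4096,
            # e = 8 experts and top-2 routing, and converts measured latency to TFLOPS
            # via 2*m*n*k / time (see `perf` below).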
96 | args={}, 97 | ) 98 | ) 99 | def benchmark(m, provider): 100 | 101 | m = m 102 | n = 14336//2 103 | k = 4096 104 | e = 8 105 | topk = 2 106 | 107 | torch.cuda.manual_seed(3227) 108 | dtype = torch.float16 109 | 110 | a = torch.randn((m, k), device='cuda', dtype=dtype) / 10 111 | w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10 112 | w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10 113 | 114 | score = torch.randn((m, e), device='cuda', dtype=dtype) 115 | score = torch.softmax(score, dim=-1) 116 | topk_weight, topk_ids = torch.topk(score, topk) 117 | 118 | quantiles = [0.5, 0.2, 0.8] 119 | if provider == 'cm': 120 | ms, min_ms, max_ms = triton.testing.do_bench(lambda: fused_moe_col(a, w1, w2, topk_weight, topk_ids, False), quantiles=quantiles) 121 | if provider == 'gl': 122 | ms, min_ms, max_ms = triton.testing.do_bench(lambda: fused_moe_grouped(a, w1, w2, topk_weight, topk_ids, False), quantiles=quantiles) 123 | perf = lambda ms: 2 * m * n * k * 1e-12 / (ms * 1e-3) 124 | return perf(ms), perf(max_ms), perf(min_ms) 125 | 126 | benchmark.run(show_plots=True, print_data=True, save_path='./') 127 | -------------------------------------------------------------------------------- /kernels/triton/inference/col_major_moe_gemm/profile_moe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import pytest 8 | import torch 9 | from vllm.model_executor.layers.fused_moe import fused_moe 10 | from vllm.model_executor.layers.activation import SiluAndMul 11 | from v0_moe_fused import fused_moe as fused_moe_base 12 | from triton.kernels.mixtral.v1_moe_fused import fused_moe 13 | import time 14 | 15 | def torch_moe(a, w1, w2, topk_weight, topk_ids): 16 | B, D = a.shape 17 | a = a.view(B, -1, D).repeat(1, topk_ids.shape[1], 1).reshape(-1, D) 18 | out = torch.zeros(B * topk_ids.shape[1], 19 | w2.shape[1], 20 | dtype=a.dtype, 21 | device=a.device) 22 | 23 | topk_ids = topk_ids.view(-1) 24 | topk_weight = topk_weight.view(-1) 25 | for i in range(w1.shape[0]): 26 | mask = topk_ids == i 27 | if mask.sum(): 28 | out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) 29 | return (out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1)).sum(dim=1) 30 | 31 | 32 | def test_fused_moe( 33 | m: int, 34 | n: int, 35 | k: int, 36 | e: int, 37 | topk: int, 38 | dtype: torch.dtype, 39 | ): 40 | 41 | a = torch.randn((m, k), device='cuda', dtype=dtype) / 10 42 | w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10 43 | w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10 44 | 45 | score = torch.randn((m, e), device='cuda', dtype=dtype) 46 | score = torch.softmax(score, dim=-1) 47 | topk_weight, topk_ids = torch.topk(score, topk) 48 | 49 | triton_output_splitk = fused_moe(a, w1, w2, topk_weight, topk_ids, False) 50 | triton_output_base = fused_moe_base(a, w1, w2, topk_weight, topk_ids, False) 51 | 52 | 53 | if __name__ == '__main__': 54 | 55 | test_fused_moe(2, 14336//2, 4096, 8, 2, torch.float16) 56 | -------------------------------------------------------------------------------- /kernels/triton/inference/col_major_moe_gemm/results.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 
-------------------------------------------------------------------------------- /kernels/triton/inference/col_major_moe_gemm/test.csv: -------------------------------------------------------------------------------- 1 | m,Fused MoE GEMM Kernel - Column Major,vLLM MoE GEMM Kernel 2 | 1.000000,0.412454,0.259585 3 | 2.000000,0.883064,0.269004 4 | 4.000000,1.751380,0.447645 5 | 8.000000,2.106783,0.571765 6 | 16.000000,4.121877,1.002326 7 | 32.000000,8.259988,1.991226 8 | 64.000000,16.105391,3.879061 9 | 128.000000,29.356460,7.191373 10 | 256.000000,50.550095,12.524316 11 | 512.000000,72.862390,19.934314 12 | -------------------------------------------------------------------------------- /kernels/triton/inference/col_major_moe_gemm/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch-labs/applied-ai/4f8c39eb70d92e3848c65347ca077c8e723c74fd/kernels/triton/inference/col_major_moe_gemm/test.png -------------------------------------------------------------------------------- /kernels/triton/inference/col_major_moe_gemm/test_moe_gemm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import pytest 8 | import torch 9 | from vllm.model_executor.layers.fused_moe import fused_moe 10 | from vllm.model_executor.layers.activation import SiluAndMul 11 | from v0_moe_fused import fused_moe as fused_moe_v0 12 | from v1_moe_fused import fused_moe as fused_moe_v1 13 | from splitk_moe_fused import fused_moe 14 | import time 15 | 16 | def torch_moe(a, w1, w2, topk_weight, topk_ids): 17 | B, D = a.shape 18 | a = a.view(B, -1, D).repeat(1, topk_ids.shape[1], 1).reshape(-1, D) 19 | out = torch.zeros(B * topk_ids.shape[1], 20 | w2.shape[1], 21 | dtype=a.dtype, 22 | device=a.device) 23 | 24 | topk_ids = topk_ids.view(-1) 25 | topk_weight = topk_weight.view(-1) 26 | for i in range(w1.shape[0]): 27 | mask = topk_ids == i 28 | if mask.sum(): 29 | out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) 30 | return (out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1)).sum(dim=1) 31 | 32 | 33 | @pytest.mark.parametrize("m", [2, 4, 8, 16, 32, 64, 128, 512, 1024, 2048]) 34 | @pytest.mark.parametrize("n", [14336//2]) 35 | @pytest.mark.parametrize("k", [4096]) 36 | @pytest.mark.parametrize("e", [8]) 37 | @pytest.mark.parametrize("topk", [2]) 38 | @pytest.mark.parametrize("dtype", [torch.float16]) 39 | def test_fused_moe( 40 | m: int, 41 | n: int, 42 | k: int, 43 | e: int, 44 | topk: int, 45 | dtype: torch.dtype, 46 | ): 47 | 48 | torch.cuda.manual_seed(3227) 49 | a = torch.randn((m, k), device='cuda', dtype=dtype) / 10 50 | w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10 51 | 52 | w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10 53 | 54 | score = torch.randn((m, e), device='cuda', dtype=dtype) 55 | score = torch.softmax(score, dim=-1) 56 | 57 | topk_weight, topk_ids = torch.topk(score, topk) 58 | 59 | start = time.time() 60 | triton_output_gl = fused_moe_v0(a, w1, w2, topk_weight, topk_ids, False) 61 | end = time.time() 62 | 63 | gl_time = end - start 64 | gl_time = gl_time * 1000 65 | print("Grouped Launch Time (us): \n", gl_time) 66 | 67 | 68 | start = time.time() 69 | triton_output_cm = fused_moe_v1(a, w1, w2, 
topk_weight, topk_ids, False) 70 | end = time.time() 71 | cm_major_time = end - start 72 | cm_major_time = cm_major_time * 1000 73 | print("Columm Major Time (us): \n", cm_major_time) 74 | 75 | 76 | torch_base = torch_moe(a, w1, w2, topk_weight, topk_ids) 77 | 78 | assert torch.allclose(triton_output_cm, torch_base, atol=1e-2, rtol=0) 79 | assert torch.allclose(triton_output_cm, triton_output_gl, atol=1e-2, rtol=0) 80 | 81 | # print(f"{triton_output_cm=}\n") 82 | # print(f"{triton_output_gl=}\n") 83 | # print(f"{torch_base=}\n") 84 | 85 | print(f"Col Major Speedup: {((gl_time/cm_major_time))} x\n") 86 | -------------------------------------------------------------------------------- /kernels/triton/inference/flash_attention/stay_attention.py: -------------------------------------------------------------------------------- 1 | import triton.language as tl 2 | import triton 3 | import torch 4 | 5 | 6 | @triton.jit() 7 | def stay_attention( 8 | q_ptr, k_ptr, v_ptr, o_ptr, 9 | stride_b, stride_nh, 10 | stride_qs, stride_qh, 11 | stride_ks, stride_kh, 12 | stride_vs, stride_vh, 13 | stride_os, stride_oh, 14 | seq_len, head_dim, 15 | sm_scale, 16 | BLOCK_SEQ: tl.constexpr, 17 | BLOCK_HD: tl.constexpr, 18 | NUM_SM: tl.constexpr, 19 | ): 20 | 21 | pid_b = tl.program_id(0) 22 | pid_h = tl.program_id(1) 23 | pid = tl.program_id(2) 24 | 25 | qkv_offset = pid_b*stride_b + pid_h*stride_nh 26 | num_tiles_seq_len = tl.cdiv(seq_len, BLOCK_SEQ) 27 | 28 | tiles_per_SM = num_tiles_seq_len // NUM_SM 29 | if pid < num_tiles_seq_len % NUM_SM: 30 | tiles_per_SM += 1 31 | 32 | tile_id = pid - NUM_SM 33 | si = -1 34 | 35 | pid_seq_m = 0 36 | pid_seq_n = 0 37 | 38 | offs_seq_m = tl.arange(0, BLOCK_SEQ) 39 | offs_seq_n = tl.arange(0, BLOCK_SEQ) 40 | offs_head = tl.arange(0, BLOCK_HD) 41 | 42 | q_ptrs = q_ptr + qkv_offset + offs_seq_n[:, None]*stride_qs + offs_head[None, :]*stride_qh 43 | 44 | # initialize pointer to m and l 45 | m_i = tl.zeros([BLOCK_SEQ], dtype=tl.float32) - float("inf") 46 | l_i = tl.zeros([BLOCK_SEQ], dtype=tl.float32) 47 | qk_scale = sm_scale * 1.44269504 48 | 49 | q = tl.load(q_ptrs) 50 | q = (q * qk_scale) 51 | 52 | pv = tl.zeros([BLOCK_SEQ, BLOCK_HD], dtype=tl.float32) 53 | for _ in range(0, num_tiles_seq_len * tiles_per_SM): 54 | 55 | si = tl.where(si == num_tiles_seq_len - 1, 0, si + 1) 56 | 57 | if si == 0: 58 | 59 | tile_id += NUM_SM 60 | 61 | pid_seq_m = pid // num_tiles_seq_len 62 | pid_seq_n = pid % num_tiles_seq_len 63 | 64 | offs_seq_m = pid_seq_m*BLOCK_SEQ + tl.arange(0, BLOCK_SEQ) 65 | offs_seq_n = pid_seq_n*BLOCK_SEQ + tl.arange(0, BLOCK_SEQ) 66 | offs_head = tl.arange(0, BLOCK_HD) 67 | 68 | q_ptrs = q_ptr + qkv_offset + offs_seq_n[:, None]*stride_qs + offs_head[None, :]*stride_qh 69 | 70 | qk_scale = sm_scale * 1.44269504 71 | q = tl.load(q_ptrs) 72 | q = (q * qk_scale) 73 | 74 | offs_seq_m = si*BLOCK_SEQ + tl.arange(0, BLOCK_SEQ) 75 | offs_head = tl.arange(0, BLOCK_HD) 76 | 77 | k_ptrs = k_ptr + qkv_offset + offs_seq_m[:, None]*stride_ks + offs_head[None, :]*stride_kh 78 | v_ptrs = v_ptr + qkv_offset + offs_seq_m[:, None]*stride_vs + offs_head[None, :]*stride_vh 79 | 80 | k = tl.load(k_ptrs) 81 | v = tl.load(v_ptrs) 82 | 83 | qk = tl.dot(q.to(tl.float16), k.T, out_dtype=tl.float32) 84 | 85 | # -- compute scaling constant --- 86 | m_i_new = tl.maximum(m_i, tl.max(qk, 1)) 87 | alpha = tl.math.exp2(m_i - m_i_new) 88 | p = tl.math.exp2(qk - m_i_new[:, None]) 89 | 90 | # -- scale and update acc -- 91 | pv *= alpha[:, None] 92 | pv += tl.dot(p.to(tl.float16), v, out_dtype=tl.float32) 
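        # (alpha = exp2(m_i - m_i_new) above rescales both the running accumulator pv
        #  and the running denominator l_i from the previous row max to the new one --
        #  the standard online-softmax update used by flash attention.)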
93 | 94 | # -- update m_i and l_i -- 95 | l_i = l_i * alpha + tl.sum(p, 1) 96 | m_i = m_i_new 97 | 98 | if si == num_tiles_seq_len - 1: 99 | 100 | offs_seq_n = pid_seq_n*BLOCK_SEQ + tl.arange(0, BLOCK_SEQ) 101 | pv = pv / l_i[:, None] 102 | o_ptrs = o_ptr + qkv_offset + offs_seq_n[:, None]*stride_os + offs_head[None, :]*stride_oh 103 | tl.store(o_ptrs, pv) 104 | pv = tl.zeros([BLOCK_SEQ, BLOCK_HD], dtype=tl.float32) 105 | 106 | 107 | def flash_fn(q, k, v): 108 | 109 | batch, num_heads, seq_len, head_dim = q.shape 110 | 111 | sm_scale = 0.5 112 | BLOCK_SEQ = 64 113 | BLOCK_HD = 128 114 | 115 | NUM_SM = torch.cuda.get_device_properties("cuda").multi_processor_count 116 | grid = (batch, num_heads, NUM_SM) 117 | o = torch.zeros(batch, num_heads, seq_len, head_dim, dtype=torch.float16, device='cuda') 118 | stay_attention[grid](q, k, v, o, 119 | q.stride(0), q.stride(1), 120 | q.stride(2), q.stride(3), 121 | k.stride(2), k.stride(3), 122 | v.stride(2), v.stride(3), 123 | o.stride(2), o.stride(3), 124 | seq_len, head_dim, 125 | sm_scale, 126 | BLOCK_SEQ, BLOCK_HD, NUM_SM) 127 | return o 128 | 129 | 130 | if __name__ == '__main__': 131 | 132 | torch.manual_seed(0) 133 | 134 | batch, num_heads, seq_len, head_dim = 1, 32, 4096, 128 135 | 136 | q = torch.randn(batch, num_heads, seq_len, head_dim, dtype=torch.float16, device='cuda') // 10 137 | k = torch.randn(batch, num_heads, seq_len, head_dim, dtype=torch.float16, device='cuda') // 10 138 | v = torch.randn(batch, num_heads, seq_len, head_dim, dtype=torch.float16, device='cuda') // 10 139 | 140 | sm_scale = 0.5 141 | p = (q @ k.transpose(2, 3)) * sm_scale 142 | p = torch.softmax(p.float(), dim=-1) 143 | o_torch = torch.matmul(p.to(torch.float16), v) 144 | 145 | o_triton = flash_fn(q, k, v) 146 | 147 | print(f"{o_triton=}") 148 | print(f"{o_torch=}") 149 | 150 | torch.testing.assert_close(o_triton, o_torch, atol=1e-2, rtol=0) 151 | 152 | -------------------------------------------------------------------------------- /kernels/triton/inference/fp8/float8_groupwise_quant.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | from typing import Tuple 9 | 10 | import torch 11 | import triton 12 | import triton.language as tl 13 | from triton import Config 14 | 15 | # global constants 16 | FP8_MAX: tl.constexpr = 448.0 17 | EPSILON: tl.constexpr = 1e-12 18 | 19 | 20 | @triton.jit 21 | def _float8_groupwise_quant_kernel( 22 | in_ptr, out_ptr, scale_ptr, BLOCK_SIZE: tl.constexpr 23 | ): 24 | """ 25 | Quantizes the input tensor via BLOCK_SIZE groupwise scaling (i.e. 1x 128). 
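    Each program handles one contiguous group of BLOCK_SIZE elements: it takes the
    group's absolute max, derives scale = max(|x|, EPSILON) / 448, and writes
    x / scale (clamped to +/-448) as float8_e4m3fn plus one float32 scale per group.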
26 | 27 | Results: 28 | Stores 29 | 1 - float8_e4m3fn result in `out_ptr` 30 | 2 - scaling factor in `scale_ptr` 31 | 32 | """ 33 | pid = tl.program_id(axis=0) 34 | 35 | # load inputs 36 | offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 37 | x_vec = tl.load(in_ptr + offsets).to(tl.float32) 38 | 39 | # calc max and scale 40 | max_val = tl.max(tl.abs(x_vec)) 41 | safe_scale = tl.maximum(max_val, EPSILON) / FP8_MAX 42 | y_vec = x_vec / safe_scale 43 | 44 | # quantize 45 | y_clamped = tl.minimum(tl.maximum(y_vec, -FP8_MAX), FP8_MAX) 46 | y_fp8 = y_clamped.to(out_ptr.dtype.element_ty) 47 | 48 | # store quantized values and scale 49 | tl.store(out_ptr + offsets, y_fp8) 50 | tl.store(scale_ptr + pid, safe_scale) 51 | 52 | 53 | def float8_groupwise_quantize(x: torch.Tensor, block_size=128): 54 | """ 55 | Quantizes the input tensor via block_size groupwise scaling (i.e. 1x 128) 56 | to torch.float8_e4m3fn format. 57 | 58 | Results: 59 | Stores 60 | 1 - float8_e4m3fn result in `out_ptr` 61 | 2 - scaling factor in `scale_ptr` 62 | 63 | """ 64 | # verify input tensor 65 | x_last_dim_size = x.size(-1) 66 | 67 | # evenly divisible? 68 | if x_last_dim_size % block_size != 0: 69 | raise ValueError( 70 | f"Input tensor must have a last dimension that is a multiple of {block_size}" 71 | ) 72 | # contiguous? 73 | if x.stride(-1) != 1: 74 | x = x.contiguous() 75 | 76 | # allocate output tensors 77 | output = torch.empty_like(x, dtype=torch.float8_e4m3fn) 78 | scales = x.new_empty( 79 | *x.size()[:-1], x_last_dim_size // block_size, dtype=torch.float32 80 | ) 81 | print(f"{scales.size()=}") 82 | 83 | grid = lambda meta: (x.numel() // block_size,) 84 | _float8_groupwise_quant_kernel[grid]( 85 | in_ptr=x, 86 | out_ptr=output, 87 | scale_ptr=scales, 88 | BLOCK_SIZE=block_size, 89 | ) 90 | 91 | return output, scales 92 | -------------------------------------------------------------------------------- /kernels/triton/inference/fp8/scaled_fp8_gemm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import triton 3 | import triton.language as tl 4 | import time 5 | import os 6 | os.environ['ENABLE_TMA'] = '1' 7 | 8 | 9 | @triton.jit 10 | def grouped_launch(pid, 11 | m, n, 12 | block_m: tl.constexpr, block_n: tl.constexpr, group_m: tl.constexpr): 13 | 14 | grid_m = tl.cdiv(m, block_m) 15 | grid_n = tl.cdiv(n, block_n) 16 | 17 | width = group_m * grid_n 18 | group_id = pid // width 19 | group_size = tl.minimum(grid_m - group_id * group_m, group_m) 20 | 21 | pid_m = group_id * group_m + (pid % group_size) 22 | pid_n = (pid % width) // group_size 23 | 24 | return pid_m, pid_n 25 | 26 | @triton.jit() 27 | def column_major(pid, 28 | m, n, 29 | block_m: tl.constexpr, block_n: tl.constexpr): 30 | 31 | grid_m = tl.cdiv(m, block_m) 32 | 33 | pid_m = pid % grid_m 34 | pid_n = pid // grid_m 35 | 36 | return pid_m, pid_n 37 | 38 | @triton.jit 39 | def scaled_gemm_splitk(a_ptr, b_ptr, c_ptr, 40 | stride_am, stride_ak, 41 | stride_bk, stride_bn, 42 | stride_cm, stride_cn, 43 | scale_a, scale_b, 44 | m, n, k, 45 | block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr, 46 | split_k: tl.constexpr, group_m: tl.constexpr): 47 | 48 | pid = tl.program_id(0) 49 | pid_k = tl.program_id(1) 50 | grid_k = tl.cdiv(k, block_k*split_k) 51 | 52 | # Column Major produces speedup over Grouped Launch for small-to-medium M 53 | pid_m, pid_n = column_major(pid, 54 | m, n, 55 | block_m, block_n) 56 | 57 | 58 | offs_m = pid_m*block_m + tl.arange(0, block_m) 59 | offs_n = 
pid_n*block_n + tl.arange(0, block_n) 60 | offs_k = pid_k*block_k + tl.arange(0, block_k) 61 | 62 | offs_am = tl.max_contiguous(tl.multiple_of(offs_m, block_m), block_m) 63 | offs_bn = tl.max_contiguous(tl.multiple_of(offs_n, block_n), block_n) 64 | 65 | a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) 66 | b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) 67 | 68 | 69 | acc = tl.zeros((block_m, block_n), dtype=tl.float32) 70 | for k_ in range(0, grid_k): 71 | 72 | k_remaining = k - k_ * (block_k * split_k) 73 | 74 | a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0) 75 | b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0) 76 | 77 | acc = tl.dot(a, b, acc, out_dtype=tl.float32) 78 | 79 | a_ptrs += block_k * split_k * stride_ak 80 | b_ptrs += block_k * split_k * stride_bk 81 | 82 | # Scaled in SRAM before write back to DRAM 83 | acc = scale_a * scale_b * acc 84 | acc.to(tl.float16) 85 | 86 | offs_m = pid_m*block_m + tl.arange(0, block_m) 87 | offs_n = pid_n*block_n + tl.arange(0, block_n) 88 | 89 | c_ptrs = c_ptr + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn) 90 | mask = (offs_m < m)[:, None] & (offs_n < n)[None, :] 91 | 92 | tl.atomic_add(c_ptrs, acc, mask=mask) 93 | 94 | def scaled_mm_splitk(a, b, scale_a: float=1.0, scale_b: float=1.0): 95 | assert a.shape[1] == b.shape[0] 96 | m, k = a.shape 97 | _, n = b.shape 98 | 99 | block_m = 64 100 | block_n = 64 101 | block_k = 256 102 | num_stages = 3 103 | num_warps = 8 104 | split_k = 4 105 | group_m = 8 106 | 107 | total_blocks_m = triton.cdiv(m, block_m) 108 | total_blocks_n = triton.cdiv(n, block_n) 109 | total_programs_mn = total_blocks_m * total_blocks_n 110 | total_programs_k = split_k 111 | 112 | grid = (total_programs_mn, total_programs_k) 113 | 114 | c = torch.zeros((m, n), device=a.device, dtype=torch.float16) 115 | k = scaled_gemm_splitk[grid](a, b, c, 116 | a.stride(0), a.stride(1), 117 | b.stride(0), b.stride(1), 118 | c.stride(0), c.stride(1), 119 | scale_a, scale_b, 120 | m, n, k, 121 | block_m, block_n, block_k, 122 | split_k, group_m, num_stages=num_stages, num_warps=num_warps) 123 | 124 | return c -------------------------------------------------------------------------------- /kernels/triton/inference/fp8/splitk_gemm_fp8.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import triton 3 | import triton.language as tl 4 | import time 5 | import os 6 | os.environ['ENABLE_TMA'] = '1' 7 | 8 | @triton.jit 9 | def grouped_launch(pid, 10 | m, n, 11 | block_m: tl.constexpr, block_n: tl.constexpr, group_m: tl.constexpr): 12 | 13 | grid_m = tl.cdiv(m, block_m) 14 | grid_n = tl.cdiv(n, block_n) 15 | 16 | width = group_m * grid_n 17 | group_id = pid // width 18 | group_size = tl.minimum(grid_m - group_id * group_m, group_m) 19 | 20 | pid_m = group_id * group_m + (pid % group_size) 21 | pid_n = (pid % width) // group_size 22 | 23 | return pid_m, pid_n 24 | 25 | 26 | @triton.jit() 27 | def col_major(pid, 28 | m, n, 29 | block_m: tl.constexpr, block_n: tl.constexpr): 30 | 31 | grid_m = tl.cdiv(m, block_m) 32 | 33 | pid_m = pid % grid_m 34 | pid_n = pid // grid_m 35 | 36 | return pid_m, pid_n 37 | 38 | 39 | @triton.jit 40 | def gemm_split_k_kernel(a_ptr, b_ptr, c_ptr, 41 | stride_am, stride_ak, 42 | stride_bk, stride_bn, 43 | stride_cm, stride_cn, 44 | m, n, k, 45 | block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr, 46 | split_k: tl.constexpr, group_m: tl.constexpr): 
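    # Split-K: the second grid axis launches `split_k` programs per output tile, each
    # accumulating a disjoint slice of the K dimension (stride block_k * split_k); the
    # partial results are combined with tl.atomic_add below, which is why the caller
    # zero-initializes C.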
47 | 48 | pid = tl.program_id(0) 49 | pid_k = tl.program_id(1) 50 | grid_k = tl.cdiv(k, block_k*split_k) 51 | 52 | pid_m, pid_n = grouped_launch(pid, 53 | m, n, 54 | block_m, block_n, group_m) 55 | 56 | offs_m = pid_m*block_m + tl.arange(0, block_m) 57 | offs_n = pid_n*block_n + tl.arange(0, block_n) 58 | offs_k = pid_k*block_k + tl.arange(0, block_k) 59 | 60 | offs_am = tl.max_contiguous(tl.multiple_of(offs_m, block_m), block_m) 61 | offs_bn = tl.max_contiguous(tl.multiple_of(offs_n, block_n), block_n) 62 | 63 | a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) 64 | b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) 65 | 66 | 67 | acc = tl.zeros((block_m, block_n), dtype=tl.float32) 68 | for k_ in range(0, grid_k): 69 | 70 | k_remaining = k - k_ * (block_k * split_k) 71 | 72 | a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0) 73 | b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0) 74 | 75 | acc = tl.dot(a, b, acc, out_dtype=tl.float32) 76 | 77 | a_ptrs += block_k * split_k * stride_ak 78 | b_ptrs += block_k * split_k * stride_bk 79 | 80 | acc.to(tl.float16) 81 | 82 | offs_m = pid_m*block_m + tl.arange(0, block_m) 83 | offs_n = pid_n*block_n + tl.arange(0, block_n) 84 | 85 | c_ptrs = c_ptr + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn) 86 | mask = (offs_m < m)[:, None] & (offs_n < n)[None, :] 87 | 88 | tl.atomic_add(c_ptrs, acc, mask=mask) 89 | 90 | def gemm_split_k(a, b): 91 | 92 | m, k = a.shape 93 | _, n = b.shape 94 | 95 | block_m = 64 96 | block_n = 64 97 | block_k = 512 98 | num_stages = 3 99 | num_warps = 8 100 | split_k = 4 101 | group_m = 8 102 | 103 | total_blocks_m = triton.cdiv(m, block_m) 104 | total_blocks_n = triton.cdiv(n, block_n) 105 | total_programs_mn = total_blocks_m * total_blocks_n 106 | total_programs_k = split_k 107 | 108 | grid = (total_programs_mn, total_programs_k) 109 | 110 | # print(f"problem m size: {m}, tile size m: {block_m}, total blocks m: {total_blocks_m}") 111 | # print(f"problem n size: {n}, tile size n: {block_n}, total blocks n: {total_blocks_n}") 112 | # print(f"problem k size: {k}, tile size k: {block_k}, total thread blocks k: {split_k}") 113 | 114 | # print(f"total thread blocks k: {k}, total thread blocks m and total thread blocks n = {total_blocks_m=} x {total_blocks_n} = {total_programs_mn}") 115 | # print(f"{total_programs_mn=}, {total_programs_k=}") 116 | 117 | c = torch.zeros((m, n), device=a.device, dtype=torch.float16) 118 | k = gemm_split_k_kernel[grid](a, b, c, 119 | a.stride(0), a.stride(1), 120 | b.stride(0), b.stride(1), 121 | c.stride(0), c.stride(1), 122 | m, n, k, 123 | block_m, block_n, block_k, 124 | split_k, group_m, num_stages=num_stages, num_warps=num_warps) 125 | 126 | # print(f"{k.n_regs} registers used, {k.n_spills} spills, {k.shared/1000} kB shared memory\n") 127 | 128 | # with open('matmul_split_k.txt', 'w') as f: 129 | 130 | # print(f"{k.n_regs} registers used, {k.n_spills} spills, {k.shared/1000} kB shared memory\n", file=f) 131 | # print("IR", k.asm['ttir'], file=f) 132 | # print("TTGIR", k.asm['ttgir'], file=f) 133 | # print("PTX", k.asm['ptx'], file=f) 134 | # print(f"{k.n_regs} registers used, {k.n_spills} spills, {k.shared/1000} kB shared memory\n", file=f) 135 | 136 | return c 137 | 138 | 139 | 140 | 141 | 142 | -------------------------------------------------------------------------------- /kernels/triton/inference/fp8/tma_gemm.py: 
-------------------------------------------------------------------------------- 1 | import triton 2 | import triton.language as tl 3 | import numpy as np 4 | import torch 5 | 6 | @triton.jit 7 | def gemm_kernel_tma(a_desc_ptr, b_desc_ptr, c_desc_ptr, # 8 | prob_m, prob_n, prob_k, block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr): 9 | 10 | pid = tl.program_id(axis=0) 11 | num_pid_m = tl.cdiv(prob_m, block_m) 12 | num_pid_k = tl.cdiv(prob_k, block_k) 13 | pid_m = pid % num_pid_m 14 | pid_n = pid // num_pid_m 15 | offs_am = pid_m * block_m 16 | offs_bn = pid_n * block_n 17 | offs_k = 0 18 | 19 | accumulator = tl.zeros((block_m, block_n), dtype=tl.float32) 20 | for kk in range(0, num_pid_k): 21 | 22 | a = tl._experimental_descriptor_load(a_desc_ptr, [offs_am, offs_k], [block_m, block_k], tl.float8e4nv) 23 | b = tl._experimental_descriptor_load(b_desc_ptr, [offs_bn, offs_k], [block_n, block_k], tl.float8e4nv) 24 | 25 | accumulator = tl.dot(a, b.T, acc=accumulator, out_dtype=tl.float32) 26 | offs_k += block_k 27 | 28 | accumulator = accumulator.to(tl.float16) 29 | tl._experimental_descriptor_store(c_desc_ptr, accumulator, [offs_am, offs_bn]) 30 | 31 | 32 | def matmul(a, b, config=None): 33 | 34 | m, _ = a.shape 35 | n, k = b.shape 36 | 37 | if config: 38 | block_m = config["block_m"] 39 | block_n = config["block_n"] 40 | block_k = config["block_k"] 41 | num_warps = config["num_warps"] 42 | num_stages = config["num_stages"] 43 | 44 | block_m = 64 45 | block_n = 64 46 | block_k = 256 47 | num_warps = 4 48 | num_stages = 4 49 | TMA_SIZE = 512 50 | 51 | desc_a = np.empty(TMA_SIZE, dtype=np.int8) 52 | desc_b = np.empty(TMA_SIZE, dtype=np.int8) 53 | desc_c = np.empty(TMA_SIZE, dtype=np.int8) 54 | 55 | c = torch.empty((m, n), dtype=torch.float16, device='cuda') 56 | triton.runtime.driver.active.utils.fill_2d_tma_descriptor(a.data_ptr(), m, k, block_m, block_k, a.element_size(), 57 | desc_a) 58 | triton.runtime.driver.active.utils.fill_2d_tma_descriptor(b.data_ptr(), n, k, block_n, block_k, b.element_size(), 59 | desc_b) 60 | triton.runtime.driver.active.utils.fill_2d_tma_descriptor(c.data_ptr(), m, n, block_m, block_n, c.element_size(), 61 | desc_c) 62 | desc_a = torch.tensor(desc_a, device='cuda') 63 | desc_b = torch.tensor(desc_b, device='cuda') 64 | desc_c = torch.tensor(desc_c, device='cuda') 65 | 66 | total_blocks_m = triton.cdiv(m, block_m) 67 | total_blocks_n = triton.cdiv(n, block_n) 68 | 69 | grid = (total_blocks_m * total_blocks_n, 1, 1) 70 | k = gemm_kernel_tma[grid]( 71 | desc_a, desc_b, desc_c, 72 | m, n, k, 73 | block_m, 74 | block_n, 75 | block_k, 76 | num_warps=num_warps, 77 | num_stages=num_stages, 78 | ) 79 | 80 | # with open('tma_fp8.ttgir', 'w') as f: 81 | # print(k.asm['ttgir'], file=f) 82 | 83 | # with open('tma_fp8.ptx', 'w') as f: 84 | # print(k.asm['ptx'], file=f) 85 | 86 | return c 87 | 88 | 89 | if __name__ == '__main__': 90 | 91 | M = 128 92 | N = 4096 93 | K = 4096 94 | 95 | a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn) 96 | b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn) 97 | b = b.T.contiguous() 98 | 99 | c = matmul(a, b) 100 | -------------------------------------------------------------------------------- /kernels/triton/inference/gptq/a100_qlinear.py: -------------------------------------------------------------------------------- 1 | import triton 2 | import triton.language as tl 3 | import torch 4 | 5 | @triton.jit() 6 | def _a100_quantized_matmul(a_ptr, b_ptr, c_ptr, 
scales_ptr, zeros_ptr, 7 | stride_am, stride_ak, 8 | stride_bk, stride_bn, 9 | stride_cm, stride_cn, 10 | stride_scales_g, stride_scales_n, 11 | stride_zeros_g, stride_zeros_n, 12 | groupsize, 13 | m, n, k, 14 | block_size_m: tl.constexpr, block_size_n: tl.constexpr, block_size_k: tl.constexpr, 15 | group_size_m: tl.constexpr, 16 | ): 17 | 18 | pid = tl.program_id(0) 19 | 20 | total_blocks_m = tl.cdiv(m, block_size_m) 21 | total_blocks_n = tl.cdiv(n, block_size_n) 22 | total_blocks_k = tl.cdiv(k, block_size_k) 23 | 24 | num_blocks_in_group = group_size_m * total_blocks_n 25 | group_id = pid // num_blocks_in_group 26 | group_size = min(total_blocks_m - group_id * group_size_m, group_size_m) 27 | 28 | pid_m = group_id * group_size_m + (pid % group_size) 29 | pid_n = (pid % num_blocks_in_group) // (group_size) 30 | 31 | offs_m = (pid_m * block_size_m + tl.arange(0, block_size_m)) % m 32 | offs_n = (pid_n * block_size_n + tl.arange(0, block_size_n)) % n 33 | 34 | offs_am = tl.max_contiguous(tl.multiple_of(offs_m, block_size_m), block_size_m) 35 | offs_bn = tl.max_contiguous(tl.multiple_of(offs_n, block_size_n), block_size_n) 36 | offs_k = tl.arange(0, block_size_k) 37 | 38 | a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) 39 | b_ptrs = b_ptr + ((offs_k[:, None] // 8) * stride_bk + offs_bn[None, :] * stride_bn) 40 | 41 | scales_ptrs = scales_ptr + offs_bn * stride_scales_n 42 | zeros_ptrs = zeros_ptr + ((offs_bn // 8) * stride_zeros_n) 43 | 44 | shifter = (offs_k % 8) * 4 45 | zeros_shifter = (offs_bn % 8) * 4 46 | 47 | 48 | output = tl.zeros((block_size_m, block_size_n), dtype=tl.float32) 49 | for k in range(0, total_blocks_k): 50 | 51 | a = tl.load(a_ptrs) 52 | b = tl.load(b_ptrs) 53 | g_id = k // (groupsize // block_size_k) 54 | 55 | ptr = scales_ptrs + g_id * stride_scales_g 56 | scales = tl.load(ptr) 57 | 58 | ptr = zeros_ptrs + g_id * stride_zeros_g 59 | zeros = tl.load(ptr) 60 | 61 | zeros = (zeros >> zeros_shifter) & 0xF 62 | zeros = (zeros + 1) * scales 63 | 64 | b = (b >> shifter[:, None]) & 0xF # b -> int32 65 | b = b * scales[None, :] - zeros[None, :] # b -> fp16 66 | 67 | output += tl.dot(a, b) 68 | a_ptrs += stride_ak * block_size_k 69 | b_ptrs += (block_size_k//8) * stride_bk 70 | 71 | output.to(tl.float16) 72 | offs_cm = pid_m * block_size_m + tl.arange(0, block_size_m) 73 | offs_cn = pid_n * block_size_n + tl.arange(0, block_size_n) 74 | c_ptrs = c_ptr + (offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn) 75 | tl.store(c_ptrs, output) 76 | 77 | class a100_qlinear(torch.autograd.Function): 78 | def forward(ctx, a, b, scales, zeros): 79 | 80 | m, k = a.shape 81 | _, n = b.shape 82 | 83 | quant_groupsize = 128 84 | block_size_m = 16 85 | block_size_n = 32 # [N = 4096 // 32] = 128 blocks 86 | block_size_k = 256 87 | group_size_m = 8 88 | num_warps = 4 89 | num_stages = 8 90 | total_blocks_m = triton.cdiv(m, block_size_m) 91 | total_blocks_n = triton.cdiv(n, block_size_n) 92 | total_programs = total_blocks_m * total_blocks_n 93 | grid = (total_programs, 1) 94 | 95 | c = torch.zeros((m, n), device=b.device, dtype=torch.float16) 96 | k = _a100_quantized_matmul[grid]( 97 | a, b, c, scales, zeros, 98 | a.stride(0), a.stride(1), 99 | b.stride(0), b.stride(1), 100 | c.stride(0), c.stride(1), 101 | scales.stride(0), scales.stride(1), 102 | zeros.stride(0), zeros.stride(1), 103 | quant_groupsize, 104 | m, n, k, 105 | block_size_m, block_size_n, block_size_k, group_size_m, 106 | num_warps = num_warps, num_stages = num_stages, 107 | ) 108 | 109 | 
print(f"{k.n_regs} registers used, {k.n_spills} spills, {k.shared/1000} kB shared memory\n") 110 | 111 | with open('dequant_simple.txt', 'w') as f: 112 | 113 | print(f"{k.n_regs} registers used, {k.n_spills} spills, {k.shared/1000} kB shared memory\n", file=f) 114 | print("IR", k.asm['ttir'], file=f) 115 | print("TTGIR", k.asm['ttgir'], file=f) 116 | print("PTX", k.asm['ptx'], file=f) 117 | print(f"{k.n_regs} registers used, {k.n_spills} spills, {k.shared/1000} kB shared memory\n", file=f) 118 | 119 | print(f"{total_blocks_m=} x {total_blocks_n=} = {total_programs=}") 120 | return c 121 | 122 | 123 | a100_qlinear = a100_qlinear.apply -------------------------------------------------------------------------------- /kernels/triton/inference/gptq/benchmark.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import logging 4 | from tqdm import tqdm 5 | import torch 6 | from transformers import AutoTokenizer 7 | from auto_gptq import AutoGPTQForCausalLM 8 | 9 | # Configure logging 10 | logger = logging.getLogger(__name__) 11 | logging.basicConfig(level=logging.INFO) 12 | 13 | 14 | 15 | def benchmark_generation_speed(model, tokenizer, prompt, batch_size, device, num_passes=5): 16 | 17 | token_dict = tokenizer([prompt] * batch_size, return_tensors="pt", padding="longest").to(device) 18 | 19 | total_generation_time = 0 20 | total_num_generated_tokens = 0 21 | 22 | # Warmup 23 | logger.info("Starting warmup...") 24 | for _ in tqdm(range(4), desc="Warmup", leave=False): 25 | with torch.inference_mode(): 26 | _ = model.generate(**token_dict, min_length=30, max_length=30) 27 | 28 | logger.info("Starting benchmark...") 29 | with tqdm(range(num_passes), desc="Benchmark Passes") as pbar: 30 | for pass_num in pbar: 31 | token_dict = tokenizer([prompt] * batch_size, return_tensors="pt", padding="longest").to(device) 32 | 33 | start = time.time() 34 | with torch.inference_mode(): 35 | outputs_ids = model.generate(**token_dict, min_length=30, max_length=30) 36 | end = time.time() 37 | 38 | generation_time = end - start 39 | num_generated_tokens = sum(len(output_ids) for output_ids in outputs_ids) - batch_size * len(token_dict['input_ids'][0]) 40 | tokens_per_second = num_generated_tokens / generation_time 41 | 42 | total_generation_time += generation_time 43 | total_num_generated_tokens += num_generated_tokens 44 | 45 | # Update tqdm post-fix with current iteration results 46 | pbar.set_postfix({"Time (s)": f"{generation_time:.2f}", "Tokens/s": f"{tokens_per_second:.2f}"}) 47 | 48 | # Calculate average statistics 49 | avg_generation_time = total_generation_time / num_passes 50 | avg_tokens_per_second = total_num_generated_tokens / total_generation_time 51 | avg_num_generated_tokens = total_num_generated_tokens / num_passes 52 | 53 | # Log average statistics 54 | logger.info(f"Batch size: {batch_size}, Avg Time: {avg_generation_time:.2f}s, Avg Tokens/s: {avg_tokens_per_second:.2f}, Avg Total tokens: {avg_num_generated_tokens}") 55 | return avg_generation_time, avg_tokens_per_second, avg_num_generated_tokens 56 | 57 | 58 | 59 | def main(): 60 | parser = argparse.ArgumentParser(description='Benchmark Llama-70B') 61 | parser.add_argument('--use_triton', type=lambda x: (str(x).lower() == 'true'), help='use Triton Kernel') 62 | parser.add_argument('--batch_size', type=int, required=True, help='Batch size for the benchmark') 63 | args = parser.parse_args() 64 | 65 | device = "cuda:5" 66 | quantized_model_dir = 
'/net/storage149/autofs/css22/ccyang/fm-models/llama-gptq/gptq_output_act0_grp128_bluewiki' 67 | 68 | tokenizer = AutoTokenizer.from_pretrained(quantized_model_dir, use_fast=True) 69 | tokenizer.pad_token = tokenizer.eos_token 70 | 71 | tokenizer.padding_side = "left" 72 | 73 | if args.use_triton: 74 | torch.cuda.empty_cache() 75 | model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, device=device, inject_fused_attention=False, inject_fused_mlp=False, 76 | use_triton=args.use_triton, disable_exllamaV2=True, low_cpu_mem_usage=True, warmup_triton=False) 77 | else: 78 | model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, device=device, inject_fused_attention=False, inject_fused_mlp=False, 79 | use_triton=False, disable_exllamaV2=False, low_cpu_mem_usage=True, warmup_triton=False) 80 | 81 | model = torch.compile(model, mode="reduce-overhead") 82 | benchmark_generation_speed(model, tokenizer, "auto-gptq is a", args.batch_size, device) 83 | 84 | if __name__ == "__main__": 85 | main() 86 | -------------------------------------------------------------------------------- /kernels/triton/inference/gptq/h100_qlinear.py: -------------------------------------------------------------------------------- 1 | import triton 2 | import triton.language as tl 3 | import torch 4 | 5 | 6 | @triton.jit() 7 | def _h100_quantized_matmul(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, 8 | stride_am, stride_ak, 9 | stride_bk, stride_bn, 10 | stride_cm, stride_cn, 11 | stride_scales_g, stride_scales_n, 12 | stride_zeros_g, stride_zeros_n, 13 | groupsize, 14 | m, n, k, 15 | block_size_m: tl.constexpr, block_size_n: tl.constexpr, block_size_k: tl.constexpr, 16 | group_size_m: tl.constexpr, 17 | fp8_fast_accum: tl.constexpr,): 18 | 19 | pid = tl.program_id(0) 20 | 21 | total_blocks_m = tl.cdiv(m, block_size_m) 22 | total_blocks_n = tl.cdiv(n, block_size_n) 23 | total_blocks_k = tl.cdiv(k, block_size_k) 24 | 25 | num_blocks_in_group = group_size_m * total_blocks_n 26 | group_id = pid // num_blocks_in_group 27 | group_size = min(total_blocks_m - group_id * group_size_m, group_size_m) 28 | 29 | pid_m = group_id * group_size_m + (pid % group_size) 30 | pid_n = (pid % num_blocks_in_group) // (group_size) 31 | 32 | offs_n = pid_n * block_size_n + tl.arange(0, block_size_n) 33 | offs_bn = tl.max_contiguous(tl.multiple_of(offs_n, block_size_n), block_size_n) 34 | offs_k = tl.arange(0, block_size_k) 35 | 36 | a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(m,k), strides=(stride_am, stride_ak), 37 | offsets=(pid_m*block_size_m, 0), block_shape=(block_size_m, block_size_k), 38 | order =(1,0)) 39 | 40 | 41 | b_ptrs = b_ptr + ((offs_k[:, None] // 8) * stride_bk + offs_bn[None, :] * stride_bn) 42 | scales_ptrs = scales_ptr + offs_bn * stride_scales_n 43 | zeros_ptrs = zeros_ptr + ((offs_bn // 8) * stride_zeros_n) 44 | 45 | shifter = (offs_k % 8) * 4 46 | zeros_shifter = (offs_bn % 8) * 4 47 | 48 | acc = tl.zeros((block_size_m, block_size_n), dtype=tl.float32) 49 | for k in range(0, total_blocks_k): 50 | 51 | a = tl.load(a_block_ptr, boundary_check=(0,1)) 52 | b = tl.load(b_ptrs) 53 | g_id = k // (groupsize // block_size_k) 54 | 55 | ptr = scales_ptrs + g_id * stride_scales_g 56 | 57 | scales = tl.load(ptr) 58 | ptr = zeros_ptrs + g_id * stride_zeros_g 59 | zeros = tl.load(ptr) 60 | 61 | zeros = (zeros >> zeros_shifter) & 0xF 62 | zeros = (zeros + 1) * scales 63 | 64 | b = (b >> shifter[:, None]) & 0xF 65 | b = b * scales[None, :] - zeros[None, :] 66 | 67 | if fp8_fast_accum: 68 | acc = tl.dot(a.to(tl.float), 
b.to(tl.float8e4nv), acc) 69 | else: 70 | acc += tl.dot(a,b) 71 | 72 | a_block_ptr = tl.advance(a_block_ptr, (0, block_size_k)) 73 | b_ptrs += (block_size_k//8) * stride_bk 74 | 75 | acc.to(tl.float16) 76 | offs_cm = pid_m * block_size_m + tl.arange(0, block_size_m) 77 | offs_cn = pid_n * block_size_n + tl.arange(0, block_size_n) 78 | 79 | c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] 80 | c_mask = (offs_cm[:, None] < n) & (offs_cn[None, :] < n) 81 | tl.store(c_ptrs, acc, mask=c_mask) 82 | 83 | 84 | 85 | 86 | 87 | class h100_qlinear(torch.autograd.Function): 88 | def forward(ctx, a, b, scales, zeros): 89 | 90 | m, k = a.shape 91 | _, n = b.shape 92 | 93 | quant_groupsize = 128 94 | block_size_m = 16 95 | block_size_n = 32 96 | block_size_k = 256 97 | group_size_m = 8 98 | num_warps = 4 99 | num_stages = 4 100 | total_blocks_m = triton.cdiv(m, block_size_m) 101 | total_blocks_n = triton.cdiv(n, block_size_n) 102 | total_programs = total_blocks_m * total_blocks_n 103 | grid = (total_programs, 1) 104 | fp8_fast_accum = False 105 | 106 | c = torch.zeros((m, n), device=a.device, dtype=a.dtype) 107 | k = _h100_quantized_matmul[grid]( 108 | a, b, c, scales, zeros, 109 | a.stride(0), a.stride(1), 110 | b.stride(0), b.stride(1), 111 | c.stride(0), c.stride(1), 112 | scales.stride(0), scales.stride(1), 113 | zeros.stride(0), zeros.stride(1), 114 | quant_groupsize, 115 | m, n, k, 116 | block_size_m, block_size_n, block_size_k, group_size_m, fp8_fast_accum = fp8_fast_accum, 117 | num_warps = num_warps, num_stages = num_stages, 118 | ) 119 | 120 | print(f"{total_blocks_m=} x {total_blocks_n=} = {total_programs=}") 121 | return c 122 | 123 | 124 | h100_qlinear = h100_qlinear.apply -------------------------------------------------------------------------------- /kernels/triton/inference/gptq/mixtral/test_dequant_moe_gemm.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from vllm.model_executor.layers.fused_moe import fused_moe 4 | from vllm.model_executor.layers.activation import SiluAndMul 5 | from triton.kernels.gptq.mixtral.w4a16_fused_dequant_gemm import dequant_gemm_moe 6 | from v0_moe_fused import fused_moe as fused_moe_base 7 | import time 8 | 9 | def torch_moe(a, w1, w2, topk_weight, topk_ids): 10 | B, D = a.shape 11 | a = a.view(B, -1, D).repeat(1, topk_ids.shape[1], 1).reshape(-1, D) 12 | out = torch.zeros(B * topk_ids.shape[1], 13 | w2.shape[1], 14 | dtype=a.dtype, 15 | device=a.device) 16 | 17 | topk_ids = topk_ids.view(-1) 18 | topk_weight = topk_weight.view(-1) 19 | for i in range(w1.shape[0]): 20 | mask = topk_ids == i 21 | if mask.sum(): 22 | out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) 23 | return (out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1)).sum(dim=1) 24 | 25 | def test_dequant_moe( 26 | m: int, 27 | n: int, 28 | k: int, 29 | e: int, 30 | topk: int, 31 | ): 32 | m = m 33 | n = n 34 | k = k 35 | e = e 36 | topk = topk 37 | groupsize = 128 38 | packed_k_dim = k // 8 39 | packed_n_dim = n // 8 40 | g = k // groupsize 41 | topk = 2 42 | 43 | a = torch.randn((m, k), dtype=torch.float16, device='cuda') 44 | qw1 = torch.randint(0, 5, (e, packed_k_dim, n), device='cuda', dtype=torch.int32) 45 | qw2 = torch.randint(0, 5, (e, 2*n, packed_k_dim), device='cuda', dtype=torch.int32) 46 | qw1_zeros = torch.randint(0, 5, (e, g, packed_n_dim), device='cuda', dtype=torch.int32) 47 | qw2_zeros = torch.randint(0, 5, (e, g, packed_n_dim), 
device='cuda', dtype=torch.int32) 48 | qw1_scales = torch.randn((e, g, n), dtype=torch.float16, device='cuda') 49 | qw2_scales = torch.randn((e, g, n), dtype=torch.float16, device='cuda') 50 | score = torch.randn((m, e), device='cuda', dtype=torch.float16) 51 | score = torch.softmax(score, dim=-1) 52 | _, topk_ids = torch.topk(score, topk) 53 | 54 | 55 | # dtype = torch.float16 56 | # a = torch.randn((m, k), device='cuda', dtype=dtype) / 10 57 | # w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10 58 | # w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10 59 | 60 | 61 | # score = torch.randn((m, e), device='cuda', dtype=dtype) 62 | # score = torch.softmax(score, dim=-1) 63 | # topk_weight, topk_ids = torch.topk(score, topk) 64 | 65 | # triton_output_base = fused_moe_base(a, w1, w2, topk_weight, topk_ids, False) 66 | 67 | # print(triton_output_base) 68 | 69 | # breakpoint() 70 | c = dequant_gemm_moe(a, 71 | qw1, 72 | qw2, 73 | qw1_scales, 74 | qw2_scales, 75 | qw1_zeros, 76 | qw2_zeros, 77 | topk_ids, 78 | ) 79 | # print(c) 80 | # assert torch.allclose(triton_output_splitk, torch_output, atol=1e-1, rtol=0) 81 | 82 | if __name__ == '__main__': 83 | 84 | test_dequant_moe(2, 14336//2, 4096, 8, 2) -------------------------------------------------------------------------------- /kernels/triton/inference/gptq/splitk_dequant_gemm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import triton 3 | from triton import language as tl 4 | # from actual_base_gptq_4 import triton_matmul4 5 | 6 | @triton.jit() 7 | def swizzle_tile(pid, 8 | m, n, 9 | block_m: tl.constexpr, block_n: tl.constexpr, group_m: tl.constexpr): 10 | 11 | grid_m = tl.cdiv(m, block_m) 12 | grid_n = tl.cdiv(n, block_n) 13 | 14 | width = group_m * grid_n 15 | group_id = pid // width 16 | group_size = tl.minimum(grid_m - group_id * group_m, group_m) 17 | 18 | pid_m = group_id * group_m + (pid % group_size) 19 | pid_n = (pid % width) // group_size 20 | 21 | return pid_m, pid_n 22 | 23 | @triton.jit() 24 | def matmul_split_k_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, 25 | stride_am, stride_ak, 26 | stride_bk, stride_bn, 27 | stride_cm, stride_cn, 28 | stride_scales_g, stride_scales_n, 29 | stride_zeros_g, stride_zeros_n, 30 | groupsize, 31 | m, n, k, 32 | block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr, 33 | group_m: tl.constexpr, split_k: tl.constexpr): 34 | 35 | pid = tl.program_id(0) 36 | pid_k = tl.program_id(1) 37 | total_blocks_k = tl.cdiv(k, block_k*split_k) 38 | 39 | pid_m, pid_n = swizzle_tile(pid, 40 | m, n, 41 | block_m, block_n, group_m) 42 | 43 | offs_m = pid_m*block_m + tl.arange(0, block_m) 44 | offs_n = pid_n*block_n + tl.arange(0, block_n) 45 | offs_k = pid_k*block_k + tl.arange(0, block_k) 46 | 47 | offs_am = tl.max_contiguous(tl.multiple_of(offs_m, block_m), block_m) 48 | offs_bn = tl.max_contiguous(tl.multiple_of(offs_n, block_n), block_n) 49 | 50 | a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) 51 | b_ptrs = b_ptr + ((offs_k[:, None] // 8) * stride_bk + offs_bn[None, :] * stride_bn) 52 | 53 | scales_ptrs = scales_ptr + offs_bn * stride_scales_n 54 | zeros_ptrs = zeros_ptr + ((offs_bn // 8) * stride_zeros_n) 55 | 56 | shifter = (offs_k % 8) * 4 57 | zeros_shifter = (offs_bn % 8) * 4 58 | 59 | acc = tl.zeros((block_m, block_n), dtype=tl.float32) 60 | for k in range(0, total_blocks_k): 61 | 62 | a = tl.load(a_ptrs) 63 | b = tl.load(b_ptrs) 64 | 65 | g_id = (k * split_k + pid_k) // 
(groupsize // block_k) 66 | 67 | ptr = scales_ptrs + g_id * stride_scales_g 68 | scales = tl.load(ptr) 69 | 70 | ptr = zeros_ptrs + g_id * stride_zeros_g 71 | zeros = tl.load(ptr) 72 | 73 | zeros = (zeros >> zeros_shifter) & 0xF 74 | zeros = (zeros + 1) * scales 75 | 76 | b = (b >> shifter[:, None]) & 0xF 77 | b = b * scales[None, :] - zeros[None, :] 78 | 79 | acc += tl.dot(a, b) 80 | a_ptrs += block_k * split_k * stride_ak 81 | b_ptrs += (block_k // 8) * split_k * stride_bk 82 | 83 | acc.to(tl.float16) 84 | 85 | offs_m = pid_m*block_m + tl.arange(0, block_m) 86 | offs_n = pid_n*block_n + tl.arange(0, block_n) 87 | 88 | c_ptrs = c_ptr + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn) 89 | tl.atomic_add(c_ptrs, acc, sem='release') 90 | 91 | def matmul_split_k(a, b, scales, zeros): 92 | 93 | m, k = a.shape 94 | _, n = b.shape 95 | 96 | quant_groupsize = 128 97 | block_m = 16 98 | block_n = 32 99 | block_k = 128 100 | group_m = 8 101 | num_stages = 3 102 | num_warps = 4 103 | split_k = 4 104 | 105 | total_blocks_m = triton.cdiv(m, block_m) 106 | total_blocks_n = triton.cdiv(n, block_n) 107 | total_programs_mn = total_blocks_m * total_blocks_n 108 | total_programs_k = split_k 109 | 110 | grid = (total_programs_mn, total_programs_k) 111 | 112 | print(f"problem m size: {m}, tile size m: {block_m}, total blocks m: {total_blocks_m}") 113 | print(f"problem n size: {n}, tile size n: {block_n}, total blocks n: {total_blocks_n}") 114 | print(f"problem k size: {k}, tile size k: {block_k}, total thread blocks k: {split_k}") 115 | 116 | print(f"total thread blocks k: {k}, total thread blocks m and total thread blocks n = {total_blocks_m=} x {total_blocks_n} = {total_programs_mn}") 117 | print(f"{total_programs_mn=}, {total_programs_k=}") 118 | 119 | c = torch.zeros((m, n), device=a.device, dtype=torch.float16) 120 | k = matmul_split_k_kernel[grid](a, b, c, scales, zeros, 121 | a.stride(0), a.stride(1), 122 | b.stride(0), b.stride(1), 123 | c.stride(0), c.stride(1), 124 | scales.stride(0), scales.stride(1), 125 | zeros.stride(0), zeros.stride(1), 126 | quant_groupsize, 127 | m, n, k, 128 | block_m, block_n, block_k, 129 | group_m, split_k, num_stages=num_stages, num_warps=num_warps) 130 | 131 | print(f"{k.n_regs} registers used, {k.n_spills} spills, {k.shared/1000} kB shared memory\n") 132 | 133 | with open('matmul_split_k.txt', 'w') as f: 134 | 135 | print(f"{k.n_regs} registers used, {k.n_spills} spills, {k.shared/1000} kB shared memory\n", file=f) 136 | print("IR", k.asm['ttir'], file=f) 137 | print("TTGIR", k.asm['ttgir'], file=f) 138 | print("PTX", k.asm['ptx'], file=f) 139 | print(f"{k.n_regs} registers used, {k.n_spills} spills, {k.shared/1000} kB shared memory\n", file=f) 140 | 141 | return c 142 | 143 | def make_tensor(M, N, dtype): 144 | if dtype == torch.int32: 145 | # Fill with random integers for int32 type 146 | res = torch.randint(low=-2147483648, high=2147483647, size=(M, N), dtype=dtype, device="cuda") 147 | else: 148 | # Fill with normally distributed random values for other types 149 | res = torch.empty((M, N), dtype=dtype, device="cuda") 150 | res.normal_(mean=0.0, std=0.5) 151 | return res 152 | 153 | 154 | if __name__ == '__main__': 155 | 156 | m = 16 157 | k = 4096 158 | n = 4096 159 | groupsize = 128 160 | g = k // groupsize 161 | 162 | a = make_tensor(m, k, dtype=torch.float16) 163 | b = make_tensor(k//8, n, dtype=torch.int32) 164 | c = make_tensor(m, n, dtype=torch.float16) 165 | zeros = make_tensor(g, n//8, torch.int32) 166 | scales = make_tensor(g, n, torch.float16) 
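    # The shapes built above match the packed W4A16 layout the kernel's pointer arithmetic
    # expects (values here are random placeholders, not real quantized weights):
    #   a      : (m, k)        fp16 activations
    #   b      : (k // 8, n)   int32, eight 4-bit weights packed per int32 along k
    #   zeros  : (g, n // 8)   int32, eight 4-bit zero points packed per int32 along n
    #   scales : (g, n)        fp16, one row per quantization group (groupsize = 128, g = k // 128)
    # Each 4-bit weight q is dequantized in-kernel as scales * q - (zeros + 1) * scales before
    # the tl.dot. With split_k = 4, four programs each reduce a quarter of the k dimension for
    # the same (m, n) tile and combine their partial sums into c via tl.atomic_add.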
167 | 168 | # base = no_autotune(groupsize, a, b, scales, zeros) 169 | # print(f"{base.shape=}, {base[0][0:4]}") 170 | 171 | # c = custom_qlinear(a, b, scales, zeros) 172 | # print(f"{c.shape=}, {c[0][0:4]}") 173 | 174 | 175 | split_k_output = matmul_split_k(a, b, scales, zeros) 176 | print(f"{split_k_output.shape=}, {split_k_output[0][0:4]}") 177 | 178 | 179 | -------------------------------------------------------------------------------- /kernels/triton/inference/mamba/causal_1d_conv/tests/test_causal_1d_conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2025, IBM Research. 2 | # python -m pytest tests/test_causal_conv1d.py 3 | 4 | import sys 5 | from einops import rearrange 6 | import pytest 7 | import torch.nn.functional as F 8 | import torch 9 | import math 10 | 11 | import os 12 | from pathlib import Path 13 | 14 | base_path = Path(os.path.abspath(os.path.dirname(os.path.realpath(__file__)))) 15 | 16 | sys.path.insert(0, str(base_path / "../causal_1d_conv")) 17 | 18 | try: 19 | from causal_1d_conv import causal_conv1d_fn 20 | except ImportError: 21 | raise 22 | 23 | 24 | def _undecorated_test_causal_conv1d( 25 | batch, 26 | dim, 27 | seqlen, 28 | width, 29 | has_bias, 30 | silu_activation, 31 | itype, 32 | channel_last, 33 | has_initial_states, 34 | return_final_states, 35 | check_backward, 36 | ): 37 | if not channel_last and (has_initial_states or return_final_states): 38 | pytest.skip("Only channel_last support initial_states or return_final_states") 39 | device = "cuda" 40 | rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) 41 | if itype == torch.bfloat16: 42 | rtol, atol = 1e-2, 5e-2 43 | rtolw, atolw = (1e-3, 1e-3) 44 | # set seed 45 | torch.random.manual_seed(0) 46 | if not channel_last: 47 | x = torch.randn(batch, 4096 + dim + 64, seqlen, device=device, dtype=itype)[ 48 | :, 4096: 4096 + dim, : 49 | ].requires_grad_() 50 | else: 51 | x = rearrange( 52 | torch.randn(batch, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096: 4096 + dim], 53 | "b s d -> b d s", 54 | ).requires_grad_() 55 | weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True) 56 | if has_bias: 57 | bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) 58 | else: 59 | bias = None 60 | if has_initial_states: 61 | initial_states = torch.randn(batch, width - 1, dim, device=device, dtype=itype).transpose(1, 2).requires_grad_() 62 | else: 63 | initial_states = None 64 | x_ref = x.detach().clone().requires_grad_() 65 | weight_ref = weight.detach().clone().requires_grad_() 66 | bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None 67 | initial_states_ref = initial_states.detach().clone().requires_grad_() if initial_states is not None else None 68 | activation = None if not silu_activation else "silu" 69 | out = causal_conv1d_fn( 70 | x, weight, bias, initial_states=initial_states, return_final_states=return_final_states, activation=activation 71 | ) 72 | out_ref = causal_conv1d_ref( 73 | x_ref, 74 | weight_ref, 75 | bias_ref, 76 | initial_states=initial_states_ref, 77 | return_final_states=return_final_states, 78 | activation=activation, 79 | ) 80 | if return_final_states: 81 | out, final_states = out 82 | out_ref, final_states_ref = out_ref 83 | print(f"Final states max diff: {(final_states - final_states_ref).abs().max().item()}") 84 | print(f"Final states mean diff: {(final_states - final_states_ref).abs().mean().item()}") 85 | assert 
torch.allclose(final_states, final_states_ref, rtol=rtol, atol=atol) 86 | 87 | print(f"Output max diff: {(out - out_ref).abs().max().item()}") 88 | print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") 89 | assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) 90 | 91 | if return_final_states: 92 | out += F.sigmoid(final_states).sum(dim=-1, keepdim=True) 93 | out_ref += F.sigmoid(final_states_ref).sum(dim=-1, keepdim=True) 94 | 95 | if check_backward: 96 | g = torch.randn_like(out) 97 | out.backward(g) 98 | out_ref.backward(g) 99 | 100 | print(f"dx max diff: {(x.grad - x_ref.grad).abs().max().item()}") 101 | print(f"dweight max diff: {(weight.grad - weight_ref.grad).abs().max().item()}") 102 | if has_bias: 103 | print(f"dbias max diff: {(bias.grad - bias_ref.grad).abs().max().item()}") 104 | if has_initial_states: 105 | print(f"dinitial_states max diff: {(initial_states.grad - initial_states_ref.grad).abs().max().item()}") 106 | 107 | assert torch.allclose(x.grad, x_ref.grad.to(dtype=itype), rtol=rtol, atol=atol) 108 | assert torch.allclose(weight.grad, weight_ref.grad, rtol=rtolw, atol=atolw) 109 | if has_bias: 110 | assert torch.allclose(bias.grad, bias_ref.grad, rtol=rtolw, atol=atolw) 111 | if has_initial_states: 112 | assert torch.allclose(initial_states.grad, initial_states_ref.grad.to(dtype=itype), rtol=rtol, atol=atol) 113 | torch.cuda.empty_cache() 114 | del x_ref, x, weight, weight_ref, bias, bias_ref, out, out_ref 115 | 116 | 117 | def causal_conv1d_ref( 118 | x, 119 | weight, 120 | bias=None, 121 | initial_states=None, 122 | return_final_states=False, 123 | final_states_out=None, 124 | activation=None, 125 | ): 126 | """[copied from causal_conv1d/causal_conv1d_interface.py] 127 | x: (batch, dim, seqlen) 128 | weight: (dim, width) 129 | bias: (dim,) 130 | initial_states: (batch, dim, width - 1) 131 | final_states_out: (batch, dim, width - 1) 132 | 133 | out: (batch, dim, seqlen) 134 | """ 135 | if activation not in [None, "silu", "swish"]: 136 | raise NotImplementedError("activation must be None, silu, or swish") 137 | dtype_in = x.dtype 138 | x = x.to(weight.dtype) 139 | seqlen = x.shape[-1] 140 | dim, width = weight.shape 141 | if initial_states is None: 142 | out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) 143 | else: 144 | x = torch.cat([initial_states, x], dim=-1) 145 | out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) 146 | out = out[..., :seqlen] 147 | if return_final_states: 148 | final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(dtype_in) # (batch, dim, width - 1) 149 | if final_states_out is not None: 150 | final_states_out.copy_(final_states) 151 | else: 152 | final_states_out = final_states 153 | out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) 154 | return out if not return_final_states else (out, final_states_out) 155 | 156 | 157 | @pytest.mark.parametrize("batch", [1, 2, 3, 8, 16, 32, 64]) # END-GOAL 158 | # @pytest.mark.parametrize("batch", [2]) 159 | @pytest.mark.parametrize("dim", [64, 4096 + 32]) # END-GOAL 160 | # @pytest.mark.parametrize('dim', [64]) 161 | # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096]) 162 | # @pytest.mark.parametrize('seqlen', [128]) 163 | @pytest.mark.parametrize( 164 | "seqlen", [1, 2, 8, 16, 32, 64, 128, 129, 130, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096] 165 | ) # END-GOAL 166 | @pytest.mark.parametrize("width", [2, 3, 4, 5]) # END-GOAL 167 | # @pytest.mark.parametrize('width', [3]) 168 | 
@pytest.mark.parametrize("has_bias", [False, True]) # END-GOAL 169 | # @pytest.mark.parametrize('has_bias', [True]) 170 | # @pytest.mark.parametrize('has_bias', [False]) 171 | @pytest.mark.parametrize("silu_activation", [False, True]) # END-GOAL 172 | # @pytest.mark.parametrize("silu_activation", [True]) 173 | @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) 174 | # @pytest.mark.parametrize('itype', [torch.float16]) 175 | # @pytest.mark.parametrize("channel_last", [False, True]) 176 | @pytest.mark.parametrize("channel_last", [True]) # END-GOAL 177 | @pytest.mark.parametrize("has_initial_states", [False, True]) # END-GOAL 178 | # @pytest.mark.parametrize("has_initial_states", [False]) 179 | # @pytest.mark.parametrize("return_final_states", [False, True]) # END-GOAL 180 | @pytest.mark.parametrize("return_final_states", [False]) 181 | # @pytest.mark.parametrize('check_backward', [True]) # END-GOAL 182 | @pytest.mark.parametrize("check_backward", [False]) 183 | def test_causal_conv1d( 184 | batch, 185 | dim, 186 | seqlen, 187 | width, 188 | has_bias, 189 | silu_activation, 190 | itype, 191 | channel_last, 192 | has_initial_states, 193 | return_final_states, 194 | check_backward, 195 | ): 196 | return _undecorated_test_causal_conv1d( 197 | batch, 198 | dim, 199 | seqlen, 200 | width, 201 | has_bias, 202 | silu_activation, 203 | itype, 204 | channel_last, 205 | has_initial_states, 206 | return_final_states, 207 | check_backward, 208 | ) 209 | -------------------------------------------------------------------------------- /kernels/triton/training/README.md: -------------------------------------------------------------------------------- 1 | Triton training kernels 2 | -------------------------------------------------------------------------------- /kernels/triton/training/fused_softmax/README.md: -------------------------------------------------------------------------------- 1 | Fused Softmax in Triton, supporting both inference (fwd) and training (fwd/backward). 2 | 3 | Perf testing on A100: 4 | 5 | fused_softmax_a100 6 | -------------------------------------------------------------------------------- /kernels/triton/training/fused_softmax/softmax.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # ---- Fused Softmax written in Triton ------ 8 | # Extra Credits: 9 | # Triton Softmax Tutorial 10 | # LucidRains Triton_Transformers 11 | 12 | import torch 13 | import triton 14 | import triton.language as tl 15 | 16 | from torch import autograd 17 | 18 | def _get_num_warps(block_size: int)-> int: 19 | num_warps = 4 20 | if block_size > 2047: 21 | num_warps = 8 22 | if block_size > 4095: 23 | num_warps=16 24 | return num_warps 25 | 26 | @triton.jit 27 | def _softmax_kernel_fwd( 28 | output_ptr, 29 | output_row_stride, 30 | input_ptr, 31 | input_row_stride, 32 | n_cols, 33 | block_size: tl.constexpr, 34 | ): 35 | # setup input location 36 | row_index = tl.program_id(0) 37 | input_row_ptr = input_ptr + (row_index * input_row_stride) 38 | col_offsets = tl.arange(0, block_size) 39 | input_ptrs = input_row_ptr + col_offsets 40 | rw_mask = col_offsets < n_cols 41 | row = tl.load(input_ptrs, mask = rw_mask, other=float("-inf")) 42 | 43 | # safe softmax proper 44 | safe_row = row - tl.max(row, axis=0) 45 | numerator = tl.exp(safe_row) 46 | denom = tl.sum(numerator, axis=0) 47 | sm_out = numerator / denom 48 | 49 | # write results to HBM 50 | out_row_ptr = output_ptr + (row_index * output_row_stride) 51 | out_row_ptrs = out_row_ptr + col_offsets 52 | tl.store(out_row_ptrs, sm_out, mask = rw_mask) 53 | 54 | 55 | @triton.jit 56 | def _softmax_kernel_bwd( 57 | output_ptr, 58 | stride_output_row, 59 | grad_ptr, 60 | stride_grad_row, 61 | input_ptr, 62 | stride_input_row, 63 | n_cols, 64 | block_size: tl.constexpr, 65 | 66 | ): 67 | # setup input locations - need both grad and input access 68 | row_index = tl.program_id(0) 69 | 70 | input_row_ptr = input_ptr + (row_index * stride_input_row) 71 | grad_row_ptr = grad_ptr + (row_index * stride_grad_row) 72 | 73 | col_offsets = tl.arange(0,block_size) 74 | rw_mask = col_offsets < n_cols 75 | 76 | input_row_ptrs = input_row_ptr + col_offsets 77 | grad_row_ptrs = grad_row_ptr + col_offsets 78 | 79 | 80 | probs_row =tl.load(input_row_ptrs, mask=rw_mask, other = 0) 81 | grads_row = tl.load(grad_row_ptrs, mask = rw_mask, other=0) 82 | 83 | # compute derivatives 84 | dx = probs_row * grads_row 85 | dsm_out = dx - probs_row * (tl.sum(dx, axis=0)) 86 | 87 | # write to HBM 88 | output_row_ptr = output_ptr + (row_index * stride_output_row) 89 | output_ptrs = output_row_ptr + col_offsets 90 | tl.store(output_ptrs, dsm_out, mask=rw_mask) 91 | 92 | 93 | class triton_softmax(autograd.Function): 94 | @staticmethod 95 | def forward(ctx, x): 96 | orig_shape = x.shape 97 | x = x.view(-1, orig_shape[-1]) 98 | nrows, ncols = x.shape 99 | 100 | block_size = triton.next_power_of_2(ncols) 101 | num_warps = _get_num_warps(block_size) 102 | 103 | res = torch.empty_like(x) 104 | grid = (nrows,) 105 | 106 | _softmax_kernel_fwd[grid]( 107 | res, 108 | res.stride(0), 109 | x, 110 | x.stride(0), 111 | ncols, 112 | block_size=block_size, 113 | num_warps=num_warps, 114 | 115 | ) 116 | 117 | if x.requires_grad: 118 | ctx.save_for_backward(res) 119 | return res.view(*orig_shape) 120 | 121 | @staticmethod 122 | def backward(ctx, grad_probs): 123 | orig_shape = grad_probs.shape 124 | probs, = ctx.saved_tensors 125 | 126 | grad_probs = grad_probs.view(-1, orig_shape[-1]) 127 | nrows, ncols = grad_probs.shape 128 | 129 | block_size = triton.next_power_of_2(ncols) 130 | num_warps = _get_num_warps(block_size) 131 | 132 | dx = torch.empty_like(probs) 133 | grid = (nrows,) 134 | 135 | _softmax_kernel_bwd[grid]( 136 | dx, 137 | dx.stride(0), 138 | probs, 139 | probs.stride(0), 140 | 
grad_probs, 141 | grad_probs.stride(0), 142 | ncols, 143 | block_size=block_size, 144 | num_warps=num_warps, 145 | 146 | ) 147 | return dx.view(*orig_shape), None 148 | 149 | fused_softmax = triton_softmax.apply 150 | 151 | if __name__ == '__main__': 152 | sample = torch.tensor([[1,2,3,4,5], [5,4,3,2,1]], dtype = torch.float32, device="cuda", requires_grad=True) 153 | from torch.nn.functional import softmax as torch_softmax 154 | res_torch = torch_softmax(sample, dim=1) 155 | res_triton = fused_softmax(sample) 156 | 157 | torch.testing.assert_close(res_torch, res_triton, rtol=0, atol=1e-4) 158 | 159 | # backward 160 | dout = torch.randn_like(sample) 161 | bwd_torch = res_torch.backward(dout) 162 | bwd_triton = res_triton.backward(dout) 163 | 164 | torch.testing.assert_close(bwd_triton, bwd_torch, rtol=0, atol=1e-4) 165 | -------------------------------------------------------------------------------- /kernels/triton/tutorials/README.md: -------------------------------------------------------------------------------- 1 | Triton tutorials 2 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | 2 | ### Applied AI repo 3 | For experiments and research on Applied AI. 4 | 5 | ### Projects 6 | 7 | #### Kernels 8 | 9 | Housing a variety of Triton and CUDA kernels for training and inference. 10 | 11 | Inference kernels = no backward pass support. 12 | 13 | ##### Triton Kernels 14 | 15 | #### 1 - Triton - MoE (Mixtral) GEMM for accelerating inference. Uses col major access pattern to increase locality. 16 | 17 | moe_gemm_a100 18 | 19 | 20 | #### 2 - Triton - Fused Softmax for both training and inference. 21 | 22 | softmax_fused 23 | 24 | #### 3 - Triton - Fused RMSNorm for both training and inference. 25 | [Fused RMSNorm Kernel](https://github.com/pytorch-labs/applied-ai/blob/main/kernels/triton/training/rms_norm/fused_rms_norm.py) 26 | 27 | #### Other projects from Applied AI 28 | 29 | 1. [CUDA Mode](https://github.com/cuda-mode) - Reading group for learning CUDA programming - ([Discord](https://discord.gg/cudamode), [Lecture Materials](https://github.com/cuda-mode/lectures), [Lecture recordings](https://www.youtube.com/@CUDAMODE)) 30 | 2. [llama-recipes](https://github.com/meta-llama/llama-recipes) - Recipes for fine-tuning and inference for Llama model series 31 | 3. NeurIPS'23 [LLM Efficiency Challenge](https://llm-efficiency-challenge.github.io/) - 1LLM + 1GPU + 1Day competition - ([website](https://llm-efficiency-challenge.github.io/), [code](https://github.com/llm-efficiency-challenge), [NeurIPS Workshop recordings](https://neurips.cc/virtual/2023/competition/66594)) 32 | 33 | ### Papers and Publications 34 | 35 | 1. PyTorch 2: Faster Machine Learning Through Dynamic Python Bytecode Transformation and Graph Compilation [paper](https://pytorch.org/assets/pytorch2-2.pdf) 36 | 2. Accelerating a Triton Fused Kernel for W4A16 Quantized Inference with SplitK Work Decomposition [paper](https://ai.meta.com/research/publications/accelerating-a-triton-fused-kernel-for-w4a16-quantized-inference-with-splitk-work-decomposition/) 37 | 3. PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel [paper](https://arxiv.org/abs/2304.11277) 38 | 4. Sustainable AI: Environmental Implications, Challenges and Opportunities [paper](https://arxiv.org/abs/2111.00364) 39 | 40 | 41 | 42 | ### License 43 | The applied-ai repo is released under the [BSD 3](LICENSE) license. 
44 | -------------------------------------------------------------------------------- /tutorials/triton/kernels/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /tutorials/triton/kernels/flash_attention_fwd.py: -------------------------------------------------------------------------------- 1 | # flash forward v2 2 | -------------------------------------------------------------------------------- /tutorials/triton/kernels/fused_softmax.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # ---- Fused Softmax written in Triton ------ 5 | # Extra Credits: 6 | # Triton Softmax Tutorial 7 | # LucidRains Triton_Transformers 8 | 9 | import torch 10 | import triton 11 | import triton.language as tl 12 | 13 | from torch import autograd 14 | 15 | def _get_num_warps(block_size: int)-> int: 16 | num_warps = 4 17 | if block_size > 2047: 18 | num_warps = 8 19 | if block_size > 4095: 20 | num_warps=16 21 | return num_warps 22 | 23 | @triton.jit 24 | def _softmax_kernel_fwd( 25 | output_ptr, 26 | output_row_stride, 27 | input_ptr, 28 | input_row_stride, 29 | n_cols, 30 | block_size: tl.constexpr, 31 | ): 32 | # setup input location 33 | row_index = tl.program_id(0) 34 | input_row_ptr = input_ptr + (row_index * input_row_stride) 35 | col_offsets = tl.arange(0, block_size) 36 | input_ptrs = input_row_ptr + col_offsets 37 | rw_mask = col_offsets < n_cols 38 | row = tl.load(input_ptrs, mask = rw_mask, other=float("-inf")) 39 | 40 | # safe softmax proper 41 | safe_row = row - tl.max(row, axis=0) 42 | numerator = tl.exp(safe_row) 43 | denom = tl.sum(numerator, axis=0) 44 | sm_out = numerator / denom 45 | 46 | # write results to HBM 47 | out_row_ptr = output_ptr + (row_index * output_row_stride) 48 | out_row_ptrs = out_row_ptr + col_offsets 49 | tl.store(out_row_ptrs, sm_out, mask = rw_mask) 50 | 51 | 52 | @triton.jit 53 | def _softmax_kernel_bwd( 54 | output_ptr, 55 | stride_output_row, 56 | grad_ptr, 57 | stride_grad_row, 58 | input_ptr, 59 | stride_input_row, 60 | n_cols, 61 | block_size: tl.constexpr, 62 | 63 | ): 64 | # setup input locations - need both grad and input access 65 | row_index = tl.program_id(0) 66 | 67 | input_row_ptr = input_ptr + (row_index * stride_input_row) 68 | grad_row_ptr = grad_ptr + (row_index * stride_grad_row) 69 | 70 | col_offsets = tl.arange(0,block_size) 71 | rw_mask = col_offsets < n_cols 72 | 73 | input_row_ptrs = input_row_ptr + col_offsets 74 | grad_row_ptrs = grad_row_ptr + col_offsets 75 | 76 | 77 | probs_row =tl.load(input_row_ptrs, mask=rw_mask, other = 0) 78 | grads_row = tl.load(grad_row_ptrs, mask = rw_mask, other=0) 79 | 80 | # compute derivatives 81 | dx = probs_row * grads_row 82 | dsm_out = dx - probs_row * (tl.sum(dx, axis=0)) 83 | 84 | # write to HBM 85 | output_row_ptr = output_ptr + (row_index * stride_output_row) 86 | output_ptrs = output_row_ptr + col_offsets 87 | tl.store(output_ptrs, dsm_out, mask=rw_mask) 88 | 89 | 90 | class triton_softmax(autograd.Function): 91 | @staticmethod 92 | def forward(ctx, x): 93 | orig_shape = x.shape 94 | x = x.view(-1, orig_shape[-1]) 95 | nrows, ncols = x.shape 96 | 97 | block_size = triton.next_power_of_2(ncols) 98 | num_warps = _get_num_warps(block_size) 99 | 100 | res = torch.empty_like(x) 101 | grid = (nrows,) 102 | 103 | _softmax_kernel_fwd[grid]( 104 | res, 105 | 
res.stride(0), 106 | x, 107 | x.stride(0), 108 | ncols, 109 | block_size=block_size, 110 | num_warps=num_warps, 111 | 112 | ) 113 | 114 | if x.requires_grad: 115 | ctx.save_for_backward(res) 116 | return res.view(*orig_shape) 117 | 118 | @staticmethod 119 | def backward(ctx, grad_probs): 120 | orig_shape = grad_probs.shape 121 | probs, = ctx.saved_tensors 122 | 123 | grad_probs = grad_probs.view(-1, orig_shape[-1]) 124 | nrows, ncols = grad_probs.shape 125 | 126 | block_size = triton.next_power_of_2(ncols) 127 | num_warps = _get_num_warps(block_size) 128 | 129 | dx = torch.empty_like(probs) 130 | grid = (nrows,) 131 | 132 | _softmax_kernel_bwd[grid]( 133 | dx, 134 | dx.stride(0), 135 | probs, 136 | probs.stride(0), 137 | grad_probs, 138 | grad_probs.stride(0), 139 | ncols, 140 | block_size=block_size, 141 | num_warps=num_warps, 142 | 143 | ) 144 | return dx.view(*orig_shape), None 145 | 146 | fused_softmax = triton_softmax.apply 147 | 148 | if __name__ == '__main__': 149 | sample = torch.tensor([[1,2,3,4,5], [5,4,3,2,1]], dtype = torch.float32, device="cuda", requires_grad=True) 150 | from torch.nn.functional import softmax as torch_softmax 151 | res_torch = torch_softmax(sample, dim=1) 152 | res_triton = fused_softmax(sample) 153 | 154 | torch.testing.assert_close(res_torch, res_triton, rtol=0, atol=1e-4) 155 | 156 | # backward 157 | dout = torch.randn_like(sample) 158 | bwd_torch = res_torch.backward(dout) 159 | bwd_triton = res_triton.backward(dout) 160 | 161 | torch.testing.assert_close(bwd_triton, bwd_torch, rtol=0, atol=1e-4) 162 | -------------------------------------------------------------------------------- /tutorials/triton/kernels/readme.md: -------------------------------------------------------------------------------- 1 | Triton tutorials 2 | 3 | 1 - Vector Add - Starting tutorial on simple first kernel 4 | 2 - Fused Softmax - Full fused softmax with both forward and backward (training ready) 5 | -------------------------------------------------------------------------------- /tutorials/triton/kernels/vector_add.py: -------------------------------------------------------------------------------- 1 | # coding up a Triton vector addition kernel 2 | # links to 3 | 4 | import triton 5 | import triton.language as tl 6 | import torch 7 | 8 | @triton.jit 9 | def kernel_vector_addition(a_ptr, b_ptr, out_ptr, 10 | num_elems: tl.constexpr, 11 | block_size: tl.constexpr): 12 | 13 | pid = tl.program_id(axis = 0) 14 | 15 | block_start = pid * block_size # 0 * 2 = 0, 1 * 2 = 2, 16 | thread_offsets = block_start + tl.arange(0, block_size) 17 | mask = thread_offsets < num_elems 18 | a_pointers = tl.load(a_ptr + thread_offsets, mask = mask) 19 | b_pointers = tl.load(b_ptr + thread_offsets, mask = mask) 20 | res = a_pointers + b_pointers 21 | tl.store(out_ptr + thread_offsets, res, mask=mask) 22 | 23 | 24 | def ceil_div(x: int, y: int)-> int: 25 | return ((x+y-1)// y) 26 | 27 | def vector_addition(a: torch.tensor, b: torch.tensor)-> torch.tensor: 28 | output_buffer = torch.empty_like(a) 29 | assert a.is_cuda() and b.is_cuda() 30 | num_elems = a.numel() 31 | assert num_elems == b.numel() # todo - handle mismatched sizes 32 | 33 | block_size = 128 34 | grid_size = ceil_div(num_elems, block_size) 35 | grid = (grid_size,) 36 | 37 | k2 = kernel_vector_addition[grid](a, b, output_buffer, 38 | num_elems, 39 | block_size) 40 | 41 | return output_buffer 42 | -------------------------------------------------------------------------------- /tutorials/triton/tests/test_softmax.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | import pytest 5 | import torch 6 | import sys 7 | sys.path.append('..') 8 | from triton_kernels.softmax import fused_softmax 9 | 10 | from test_utils import assert_expected, set_rng_seed, gpu_test 11 | 12 | @pytest.fixture(autouse=True) 13 | def set_seed(): 14 | set_rng_seed(2020) 15 | 16 | 17 | @gpu_test() 18 | class TestForwardSoftMax: 19 | 20 | def test_forward_2D_float32(self,): 21 | # float32 22 | seq_len = 768 23 | 24 | sample_constant_float32 = torch.ones((seq_len, seq_len), dtype=torch.float32, device='cuda') 25 | sample_random_float32 = torch.randn_like(sample_constant_float32) 26 | 27 | expected_out_constant32 = torch.softmax(sample_constant_float32, dim=1) 28 | expected_out_random32 = torch.softmax(sample_random_float32, dim=1) 29 | 30 | triton_out_c32 = fused_softmax(sample_constant_float32) 31 | triton_out_random32 = fused_softmax(sample_random_float32) 32 | 33 | assert_expected(triton_out_c32, expected_out_constant32 ) 34 | assert_expected(triton_out_random32, expected_out_random32) 35 | 36 | def test_forward_2D_bfloat16(self,): 37 | # bfloat16 38 | seq_len = 2048 39 | sample_constant_bf16 = torch.ones((seq_len, seq_len), dtype=torch.bfloat16, device='cuda') 40 | sample_random_bf16 = torch.randn_like(sample_constant_bf16) 41 | 42 | expected_out_c_bf16 = torch.softmax(sample_constant_bf16, dim=1) 43 | expected_out_rand_bf16 = torch.softmax(sample_random_bf16, dim=1) 44 | 45 | triton_out_c_bf16 = fused_softmax(sample_constant_bf16) 46 | triton_out_rand_bf16 = fused_softmax(sample_random_bf16) 47 | 48 | assert_expected(triton_out_c_bf16, expected_out_c_bf16 ) 49 | assert_expected(triton_out_rand_bf16, expected_out_rand_bf16) 50 | 51 | def test_forward_3D_bfloat16(self,): 52 | # bfloat16 53 | seq_len = 2048 54 | batch = 12 55 | 56 | sample_constant_bf16 = torch.ones((batch, seq_len, seq_len), dtype=torch.bfloat16, device='cuda') 57 | sample_random_bf16 = torch.randn_like(sample_constant_bf16) 58 | 59 | expected_out_c_bf16 = torch.softmax(sample_constant_bf16, dim=1) 60 | expected_out_rand_bf16 = torch.softmax(sample_random_bf16, dim=1) 61 | 62 | triton_out_c_bf16 = fused_softmax(sample_constant_bf16) 63 | triton_out_rand_bf16 = fused_softmax(sample_random_bf16) 64 | 65 | assert_expected(triton_out_c_bf16, expected_out_c_bf16, atol=1e-2 ) 66 | assert_expected(triton_out_rand_bf16, expected_out_rand_bf16, atol=1e-2) 67 | 68 | 69 | @gpu_test() 70 | class TestBackwardSoftMax: 71 | 72 | def test_backward_2D(self,): 73 | seq_len = 1024 74 | 75 | sample_constant_float32 = torch.ones((seq_len, seq_len), dtype=torch.float32, device='cuda', requires_grad=True) 76 | sample_random_float32 = torch.randn_like(sample_constant_float32, requires_grad=True) 77 | 78 | expected_fwd_constant32 = torch.softmax(sample_constant_float32, dim=1) 79 | expected_fwd_random32 = torch.softmax(sample_random_float32, dim=1) 80 | 81 | triton_fwd_c32 = fused_softmax(sample_constant_float32) 82 | triton_fwd_random32 = fused_softmax(sample_random_float32) 83 | 84 | dout = torch.randn_like(sample_constant_float32) 85 | 86 | expected_bwd_c32 = expected_fwd_constant32.backward(dout) 87 | expected_bwd_r32 = expected_fwd_random32.backward(dout) 88 | 89 | triton_bwd_c32 = triton_fwd_c32.backward(dout) 90 | triton_bwd_r32 = triton_fwd_random32.backward(dout) 91 | 92 | 93 | assert_expected(triton_bwd_c32, expected_bwd_c32 ) 94 | 
assert_expected(triton_bwd_r32, expected_bwd_r32) 95 | 96 | def test_bwd_3D(self,): 97 | seq_len = 2048 98 | batch = 4 99 | 100 | sample_constant_float32 = torch.ones((batch, seq_len, seq_len), dtype=torch.float32, device='cuda', requires_grad=True) 101 | sample_random_float32 = torch.randn_like(sample_constant_float32, requires_grad=True) 102 | 103 | expected_fwd_constant32 = torch.softmax(sample_constant_float32, dim=1) 104 | expected_fwd_random32 = torch.softmax(sample_random_float32, dim=1) 105 | 106 | triton_fwd_c32 = fused_softmax(sample_constant_float32) 107 | triton_fwd_random32 = fused_softmax(sample_random_float32) 108 | 109 | dout = torch.randn_like(sample_constant_float32) 110 | 111 | expected_bwd_c32 = expected_fwd_constant32.backward(dout) 112 | expected_bwd_r32 = expected_fwd_random32.backward(dout) 113 | 114 | triton_bwd_c32 = triton_fwd_c32.backward(dout) 115 | triton_bwd_r32 = triton_fwd_random32.backward(dout) 116 | 117 | 118 | assert_expected(triton_bwd_c32, expected_bwd_c32 ) 119 | assert_expected(triton_bwd_r32, expected_bwd_r32) 120 | -------------------------------------------------------------------------------- /tutorials/triton/tests/test_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any, Dict, NamedTuple, Optional, Tuple, Union 3 | 4 | import pytest 5 | import torch 6 | import torch.distributed as dist 7 | from torch import Tensor, nn 8 | 9 | 10 | def assert_expected( 11 | actual: Any, 12 | expected: Any, 13 | rtol: Optional[float] = 0, 14 | atol: Optional[float] = 1e-4, 15 | check_device=True, 16 | ): 17 | torch.testing.assert_close( 18 | actual, 19 | expected, 20 | rtol=rtol, 21 | atol=atol, 22 | check_device=check_device, 23 | msg=f"actual: {actual}, expected: {expected}", 24 | ) 25 | 26 | def set_rng_seed(seed): 27 | """Sets the seed for pytorch random number generators""" 28 | torch.manual_seed(seed) 29 | 30 | 31 | def gpu_test(gpu_count: int = 1): 32 | """ 33 | Annotation for GPU tests, skipping the test if the 34 | required amount of GPU is not available 35 | """ 36 | message = f"Not enough GPUs to run the test: required {gpu_count}" 37 | return pytest.mark.skipif(torch.cuda.device_count() < gpu_count, reason=message) 38 | --------------------------------------------------------------------------------
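A minimal, self-contained sketch of how the fused softmax tutorial kernel above could be timed against torch.softmax using Triton's triton.testing.do_bench helper. The import path and the 2x-bytes bandwidth model (one read plus one write of the input) are illustrative assumptions, not part of the repo.

import torch
import triton

# Assumed import path; adjust to wherever tutorials/triton/kernels/fused_softmax.py sits on your PYTHONPATH.
from fused_softmax import fused_softmax


def bench_softmax(n_rows: int = 4096, n_cols: int = 4096) -> None:
    x = torch.randn((n_rows, n_cols), device="cuda", dtype=torch.float16)

    # do_bench repeatedly calls the closure and returns a representative time in milliseconds.
    ms_triton = triton.testing.do_bench(lambda: fused_softmax(x))
    ms_torch = triton.testing.do_bench(lambda: torch.softmax(x, dim=-1))

    # Rough bandwidth model: softmax reads the input once and writes the output once.
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    print(f"triton fused_softmax: {ms_triton:.3f} ms  ({gbps(ms_triton):.0f} GB/s)")
    print(f"torch.softmax       : {ms_torch:.3f} ms  ({gbps(ms_torch):.0f} GB/s)")


if __name__ == "__main__":
    bench_softmax()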