├── .gitignore ├── .gitmodules ├── CMakeLists.txt ├── LICENSE ├── README.md ├── benchmark.py ├── gemm_fp8 ├── __init__.py └── kernels │ ├── bindings.cpp │ ├── gemm.cu │ └── include │ ├── common.h │ └── gemm.h ├── pyproject.toml └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.cmake 2 | Makefile 3 | CMakeCache.txt 4 | CMakeFiles/* 5 | *.so 6 | gemm_fp8.egg-info/ 7 | __pycache__/ 8 | *.pyc 9 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "cutlass"] 2 | path = cutlass 3 | url = https://github.com/NVIDIA/cutlass.git 4 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.11) 2 | project(gemm_fp8 LANGUAGES CXX) 3 | 4 | find_package(Git REQUIRED) 5 | if(GIT_FOUND AND EXISTS "${PROJECT_SOURCE_DIR}/.git") 6 | message(STATUS "Populating Git submodule.") 7 | execute_process(COMMAND ${GIT_EXECUTABLE} submodule update --init --recursive 8 | WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} 9 | RESULT_VARIABLE GIT_SUBMOD_RESULT) 10 | if(NOT GIT_SUBMOD_RESULT EQUAL "0") 11 | message(FATAL_ERROR 12 | "git submodule update --init --recursive failed with ${GIT_SUBMOD_RESULT}.") 13 | endif() 14 | endif() 15 | 16 | 17 | 18 | 19 | set(_saved_CMAKE_MESSAGE_LOG_LEVEL ${CMAKE_MESSAGE_LOG_LEVEL}) 20 | set(CMAKE_MESSAGE_LOG_LEVEL ERROR) 21 | add_subdirectory(cutlass) 22 | set(CMAKE_MESSAGE_LOG_LEVEL ${_saved_CMAKE_MESSAGE_LOG_LEVEL}) 23 | 24 | include_directories("${CMAKE_SOURCE_DIR}") 25 | include_directories(cutlass/tools/util/include) 26 | include_directories(cutlass/include) 27 | include_directories(gemm_fp8/kernels/include) 28 | 29 | get_property(dirs DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY INCLUDE_DIRECTORIES) 30 | foreach(dir ${dirs}) 31 | message(STATUS "dir='${dir}'") 32 | endforeach() 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 IST Austria Distributed Algorithms and Systems Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FP8 GEMM with PyTorch Interface 2 | 3 | ## Usage 4 | 5 | Install the kernels using the following commands: 6 | 7 | ```bash 8 | git clone https://github.com/IST-DASLab/gemm_fp8.git 9 | cd gemm_fp8 10 | pip install -e . # or pip install . 11 | ``` 12 | 13 | Then, the kernel can be used as follows: 14 | 15 | ```python 16 | import torch 17 | import gemm_fp8 18 | y = gemm_fp8.matmul(a, b, alpha=1.0) 19 | ``` 20 | 21 | where `a` (of shape `(M, K)`) and `b` (of shape `(N, K)`) are the input matrices in `torch.float8_e4m3fn` format and `alpha` is a `float` scaling factor; the result is `(a @ b.t()) * alpha`, returned in `torch.bfloat16`. 22 | 23 | ## Benchmark 24 | 25 | Run the following command to benchmark the kernel: 26 | 27 | ```bash 28 | python benchmark.py 29 | ``` -------------------------------------------------------------------------------- /benchmark.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import gemm_fp8 3 | import time 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | import seaborn as sns 7 | from typing import Callable, Iterable, List, Tuple 8 | 9 | sns.set() 10 | 11 | iters = 10 12 | warmup = 3 13 | 14 | 15 | def to_fp8(tensor: torch.Tensor) -> torch.Tensor: 16 | finfo = torch.finfo(torch.float8_e4m3fn) 17 | return torch.round(tensor.clamp( 18 | min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) 19 | 20 | 21 | 22 | def make_rand_tensors(dtype: torch.dtype, m: int, n: int, 23 | k: int) -> Tuple[torch.Tensor, torch.Tensor]: 24 | a = torch.randn((m, k), device='cuda').contiguous() * 5 25 | b = torch.randn((n, k), device='cuda').contiguous() * 5 26 | 27 | if dtype == torch.float8_e4m3fn: 28 | return to_fp8(a), to_fp8(b) 29 | if dtype == torch.bfloat16: 30 | return a.to(torch.bfloat16), b.to(torch.bfloat16) 31 | if dtype == torch.float16: 32 | return a.half(), b.half() 33 | if dtype == torch.float32: 34 | return a.float(), b.float() 35 | 36 | raise ValueError("unsupported dtype") 37 | 38 | 39 | # bench 40 | def bench_fn(fn: Callable, *args, **kwargs) -> Tuple: 41 | 42 | times_ = [] 43 | for i in range(warmup): 44 | fn(*args, **kwargs) 45 | torch.cuda.synchronize() 46 | 47 | for _ in range(10): 48 | start = time.time() 49 | for i in range(iters): 50 | fn(*args, **kwargs) 51 | torch.cuda.synchronize() 52 | times_.append((time.time() - start) * 1000 / iters) 53 | 54 | return np.mean(np.array(times_)), np.std(np.array(times_)) 55 | 56 | 57 | 58 | 59 | K_lists = [64, 128, 256, 512, 1024, 2048, 4096, 8192, 11008] 60 | last_dim = 4096 61 | 62 | token_dim = [512, 1024] 63 | dim_lists = [ #from LLaMa3-8B 64 | [4096, 4096], 65 | [4096, 14336], 66 | [14336, 4096] 67 | ] 68 | 69 | 70 | plt.figure(figsize=(15, 10)) 71 | 72 | for token in token_dim: 73 | x_labels = [] 74 | fp8_fast_acc_speedups = [] 75 | fp8_speedups = [] 76 | 77 | for k_, n_ in dim_lists: 78 | m_ = token 79 | 80 | x_labels.append(f"{k_}x{n_}") 81 | 82 | a, b = make_rand_tensors(torch.bfloat16, m_, n_, k_) 83 | a_fp8, b_fp8 = to_fp8(a), to_fp8(b) 84 | 85 | print("---- m: ", m_, "k: ", k_, "n: ", n_, "----") 86 | 87 | bf16_times, bf16_times_std = bench_fn(torch.matmul, a, b.t()) 88 | cutlass_times_fastAcc, cutlass_times_fastAcc_std = bench_fn(gemm_fp8.matmul, a_fp8, b_fp8, 1.0, True) 89 | cutlass_times, cutlass_times_std = bench_fn(gemm_fp8.matmul, a_fp8, b_fp8, 1.0, False) 90 | 91 | fp8_fast_acc_speedups.append(bf16_times/cutlass_times_fastAcc) 92 |
fp8_speedups.append(bf16_times/cutlass_times) 93 | print(f"Speedup (FP8): {(bf16_times/cutlass_times):.2f}x") 94 | print(f"Speedup (FP8 FastAcc): {(bf16_times/cutlass_times_fastAcc):.2f}x") 95 | 96 | plt.plot(x_labels, fp8_fast_acc_speedups, 'o-', label=f"M={token} FP8 FastAcc") 97 | plt.plot(x_labels, fp8_speedups, 'o--', label=f"M={token} FP8") 98 | 99 | 100 | plt.axhline(1, color='black', linestyle='--') 101 | plt.xlabel("KxN") 102 | plt.ylabel("Speedup") 103 | plt.title(f"Speedup of FP8 over BF16 ({torch.cuda.get_device_name(0)})") 104 | plt.legend() 105 | plt.savefig("benchmark_fp8.png") 106 | 107 | 108 | 109 | 110 | -------------------------------------------------------------------------------- /gemm_fp8/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import gemm_fp8._CUDA 3 | 4 | 5 | __all__ = [ 6 | "matmul" 7 | ] 8 | 9 | def matmul(x: torch.Tensor, 10 | y: torch.Tensor, 11 | alpha: float = 1.0, 12 | fastAcc: bool = True) -> torch.Tensor: 13 | ''' 14 | Matrix-matrix multiplication for FP8 inputs, computing (x @ y.t()) * alpha. 15 | The output is BF16 data type. TODO: support arbitrary output dtype! 16 | Arguments: 17 | x: torch.Tensor, shape (M, K) 18 | y: torch.Tensor, shape (N, K) 19 | alpha: float, which is multiplied by the output (default=1.0) 20 | fastAcc: bool, whether to use the fast-accumulation kernel (default=True) 21 | ''' 22 | if fastAcc: 23 | return gemm_fp8._CUDA.fp8_matmul_fastAcc(x, y, alpha) 24 | else: 25 | return gemm_fp8._CUDA.fp8_matmul(x, y, alpha) 26 | -------------------------------------------------------------------------------- /gemm_fp8/kernels/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include <utility> // For std::pair 9 | 10 | torch::Tensor fp8_matmul( 11 | const torch::Tensor &X, const torch::Tensor &Y, const float alpha 12 | ) 13 | { 14 | torch::checkAllContiguous("fp8_matmul", {{X, "X", 0}, 15 | {Y, "Y", 1}}); 16 | torch::checkDeviceType("fp8_matmul", {X, Y}, at::DeviceType::CUDA); 17 | 18 | torch::checkAllSameGPU("fp8_matmul", {{X, "X", 0}, 19 | { Y, "Y", 1}}); 20 | uint32_t M = X.size(0); 21 | uint32_t N = Y.size(0); 22 | auto OUT = torch::empty({M, N}, torch::dtype(torch::kBFloat16).device(X.device())); 23 | 24 | fp8_matmul_host(OUT, X, Y, alpha); 25 | 26 | return OUT; 27 | } 28 | 29 | torch::Tensor fp8_matmul_fastAcc( 30 | const torch::Tensor &X, const torch::Tensor &Y, const float alpha 31 | ) 32 | { 33 | torch::checkAllContiguous("fp8_matmul_fastAcc", {{X, "X", 0}, 34 | {Y, "Y", 1}}); 35 | torch::checkDeviceType("fp8_matmul_fastAcc", {X, Y}, at::DeviceType::CUDA); 36 | 37 | torch::checkAllSameGPU("fp8_matmul_fastAcc", {{X, "X", 0}, 38 | { Y, "Y", 1}}); 39 | uint32_t M = X.size(0); 40 | uint32_t N = Y.size(0); 41 | auto OUT = torch::empty({M, N}, torch::dtype(torch::kBFloat16).device(X.device())); 42 | 43 | fp8_matmul_fastAcc_host(OUT, X, Y, alpha); 44 | 45 | return OUT; 46 | } 47 | 48 | //====== pybind ====== 49 | 50 | #define DEFINE_pybind(name) m.def(#name, &name, #name); 51 | 52 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m 53 | ) 54 | { 55 | 56 | m.def("fp8_matmul", &fp8_matmul, 57 | "fp8_matmul"); 58 | 59 | m.def("fp8_matmul_fastAcc", &fp8_matmul_fastAcc, 60 | "fp8_matmul_fastAcc"); 61 | } -------------------------------------------------------------------------------- /gemm_fp8/kernels/gemm.cu: -------------------------------------------------------------------------------- 1 |
/*! \file 2 | \brief Example of running an Ada FP8 GEMM. 3 | 4 | In addition to using FP8 Tensor Core instructions, the Ada FP8 GEMM uses a distinct epilogue 5 | that enables additional scaling of operands/outputs, storing a pre-activation-function output 6 | tensor (called the "auxiliary" output), and computing the absolute maximum value of the 7 | outputs. 8 | 9 | Pseudocode for this epilogue is as follows: 10 | 11 | Aux = ((alpha * scale_a * scale_b) * accumulator) + ((beta * scale_c) * source) + bias 12 | D = activation(Aux) 13 | 14 | if Aux is fp8 type: 15 | abs_max_output = max( abs(aux) | (for every aux in Aux)) 16 | Aux = scale_aux * Aux 17 | endif 18 | 19 | if D is fp8 type: 20 | abs_max_output = max( abs(d) | (for every d in D)) 21 | D = scale_d * D 22 | endif 23 | 24 | Parameter Aux is optionally stored to global memory 25 | */ 26 | 27 | #include 28 | #include 29 | #include 30 | 31 | #include "cutlass/cutlass.h" 32 | #include "cutlass/numeric_conversion.h" 33 | #include "cutlass/util/command_line.h" 34 | #include "cutlass/util/host_tensor.h" 35 | #include "cutlass/util/reference/host/gemm_complex.h" 36 | #include "cutlass/util/tensor_view_io.h" 37 | #include "cutlass/util/distribution.h" 38 | #include "cutlass/util/reference/host/tensor_fill.h" 39 | #include "cutlass/util/reference/host/tensor_copy.h" 40 | #include "cutlass/util/reference/host/tensor_compare.h" 41 | #include "cutlass/util/reference/host/tensor_norm.h" 42 | #include "cutlass/util/reference/host/gemm.h" 43 | 44 | #include "cutlass/epilogue/thread/activation.h" 45 | #include "cutlass/epilogue/thread/linear_combination_generic_with_scaling.h" 46 | #include "cutlass/gemm/device/gemm_universal_with_absmax.h" 47 | 48 | #include "cutlass/layout/matrix.h" 49 | #include "cutlass/matrix_coord.h" 50 | #include "cutlass/gemm/device/gemm_universal_adapter.h" 51 | 52 | #include 53 | 54 | using ElementA = cutlass::float_e4m3_t; 55 | using ElementB = cutlass::float_e4m3_t; 56 | using ElementOutput = cutlass::bfloat16_t; 57 | using ElementAuxOutput = ElementOutput; 58 | using ElementAccumulator = float; 59 | using LayoutA = cutlass::layout::RowMajor; 60 | using LayoutB = cutlass::layout::ColumnMajor; 61 | using LayoutC = cutlass::layout::RowMajor; 62 | static int const kStages = 3; 63 | static int const kAlignmentA = 16; 64 | static int const kAlignmentB = 16; 65 | 66 | using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationGenericWithScalingAndAbsMax< 67 | cutlass::epilogue::thread::Identity, 68 | ElementOutput, 69 | ElementAuxOutput, 70 | 128 / cutlass::sizeof_bits<ElementOutput>::value, 71 | //8, 72 | ElementAccumulator, 73 | ElementAccumulator 74 | >; 75 | 76 | template <typename TileShape, typename WarpShape, typename MathOperator> 77 | using Gemm_ = cutlass::gemm::device::GemmUniversalWithAbsMax< 78 | ElementA, 79 | LayoutA, // Row-major 80 | ElementB, 81 | LayoutB, // Column-major 82 | ElementOutput, 83 | LayoutC, // Row-major 84 | ElementAccumulator, // float 85 | cutlass::arch::OpClassTensorOp, 86 | cutlass::arch::Sm89, 87 | TileShape, 88 | WarpShape, 89 | cutlass::gemm::GemmShape<16, 8, 32>, 90 | EpilogueOutputOp, 91 | cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 92 | kStages, 93 | kAlignmentA, 94 | kAlignmentB, 95 | MathOperator 96 | >; 97 | 98 | using ElementAbsmax = typename EpilogueOutputOp::ElementAbsmax; 99 | 100 | 101 | // Command line options parsing 102 | struct Options { 103 | 104 | cutlass::gemm::GemmCoord problem_size; 105 | 106 | float alpha; 107 | float beta; 108 | 109 | Options(int M, int N, int K, float scale=1.f): 110 | beta(0.f) 111 | { 112 |
problem_size = cutlass::gemm::GemmCoord{M, N, K}; 113 | alpha = scale; 114 | } 115 | 116 | /// Compute performance in GFLOP/s 117 | float gflops(float runtime_s) const { 118 | // Two flops per multiply-add 119 | return 2.0f * float(problem_size.product()) / float(1.0e9) / runtime_s; 120 | } 121 | }; 122 | 123 | /// Helper class to run the kernel 124 | template <typename Gemm> 125 | struct TestbedRunner { 126 | 127 | using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; 128 | 129 | uint64_t seed; 130 | 131 | 132 | // 133 | // Methods 134 | // 135 | 136 | TestbedRunner() { } 137 | 138 | 139 | bool run( 140 | Options& options, 141 | torch::Tensor out, // FP32/FP16/BF16 (TODO) 142 | torch::Tensor x, // float_e4m3_t 143 | torch::Tensor y // float_e4m3_t 144 | ) 145 | { 146 | 147 | 148 | // 149 | // Initialize the GEMM operator 150 | // 151 | 152 | typename Gemm::EpilogueOutputOp::Params::ActivationParams activation_params{ 153 | ElementCompute(options.alpha), 154 | ElementCompute(options.beta) 155 | }; 156 | 157 | typename Gemm::EpilogueOutputOp::Params epilogue_params{ 158 | activation_params, 159 | nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 160 | }; 161 | 162 | typename Gemm::Arguments arguments{ 163 | cutlass::gemm::GemmUniversalMode::kGemm, 164 | options.problem_size, 165 | /* batch_count = */ 1, 166 | epilogue_params, 167 | 168 | reinterpret_cast<ElementA *>(x.data_ptr()), 169 | reinterpret_cast<ElementB *>(y.data_ptr()), 170 | reinterpret_cast<ElementOutput *>(out.data_ptr()), 171 | reinterpret_cast<ElementOutput *>(out.data_ptr()), 172 | 173 | nullptr, 174 | nullptr, 175 | 176 | options.problem_size.m() * options.problem_size.k(), 177 | options.problem_size.n() * options.problem_size.k(), 178 | options.problem_size.m() * options.problem_size.n(), 179 | options.problem_size.m() * options.problem_size.n(), 180 | (int)options.problem_size.m(), // Batch stride vector 181 | 182 | x.stride(0), 183 | y.stride(0), 184 | out.stride(0), 185 | out.stride(0), 186 | (int64_t)0 // Leading dimension of vector.
This must be 0 187 | }; 188 | 189 | Gemm gemm_op; 190 | 191 | cutlass::Status status = gemm_op.can_implement(arguments); 192 | if (status != cutlass::Status::kSuccess) { 193 | std::cerr << "Gemm::can_implement() failed" << std::endl; 194 | return false; 195 | } 196 | 197 | size_t workspace_size = Gemm::get_workspace_size(arguments); 198 | cutlass::device_memory::allocation<uint8_t> workspace(workspace_size); 199 | 200 | status = gemm_op.initialize(arguments, workspace.get()); 201 | if (status != cutlass::Status::kSuccess) { 202 | std::cerr << "Gemm::initialize() failed" << std::endl; 203 | return false; 204 | } 205 | 206 | // 207 | // Run the GEMM 208 | // 209 | 210 | status = gemm_op(); 211 | 212 | if (status != cutlass::Status::kSuccess) { 213 | std::cerr << "Gemm::run() failed" << std::endl; 214 | return false; 215 | } 216 | 217 | return true; 218 | } 219 | 220 | }; 221 | 222 | ///////////////////////////////////////////////////////////////////////////////////////////////// 223 | 224 | bool fp8_matmul_host( 225 | torch::Tensor out, // BF16 226 | torch::Tensor x, // float_e4m3_t 227 | torch::Tensor y, // float_e4m3_t 228 | float alpha 229 | ){ 230 | auto M = x.size(0); 231 | auto N = y.size(0); 232 | auto K = x.size(1); 233 | 234 | Options options(M, N, K, alpha); 235 | 236 | if (K==4096 && N==4096){ 237 | using TileShape = typename cutlass::gemm::GemmShape<128, 64, 128>; 238 | using WarpShape = typename cutlass::gemm::GemmShape<64, 32, 128>; 239 | TestbedRunner<Gemm_<TileShape, WarpShape, cutlass::arch::OpMultiplyAdd>> testbed_fast_accum; 240 | return testbed_fast_accum.run(options, out, x, y); 241 | } else { 242 | using TileShape = typename cutlass::gemm::GemmShape<64, 128, 64>; 243 | using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>; 244 | TestbedRunner<Gemm_<TileShape, WarpShape, cutlass::arch::OpMultiplyAdd>> testbed_fast_accum; 245 | return testbed_fast_accum.run(options, out, x, y); 246 | } 247 | } 248 | 249 | 250 | bool fp8_matmul_fastAcc_host( 251 | torch::Tensor out, // BF16 252 | torch::Tensor x, // float_e4m3_t 253 | torch::Tensor y, // float_e4m3_t 254 | float alpha 255 | ){ 256 | auto M = x.size(0); 257 | auto N = y.size(0); 258 | auto K = x.size(1); 259 | 260 | Options options(M, N, K, alpha); 261 | 262 | 263 | if (K==4096 && N==4096){ 264 | using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>; 265 | using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; 266 | TestbedRunner<Gemm_<TileShape, WarpShape, cutlass::arch::OpMultiplyAddFastAccum>> testbed_fast_accum; 267 | return testbed_fast_accum.run(options, out, x, y); 268 | } else { 269 | using TileShape = typename cutlass::gemm::GemmShape<64, 128, 64>; 270 | using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>; 271 | TestbedRunner<Gemm_<TileShape, WarpShape, cutlass::arch::OpMultiplyAddFastAccum>> testbed_fast_accum; 272 | return testbed_fast_accum.run(options, out, x, y); 273 | } 274 | } -------------------------------------------------------------------------------- /gemm_fp8/kernels/include/common.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | #include 10 | 11 | 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | #include "cutlass/cutlass.h" 18 | #include 19 | 20 | /** 21 | * Helper function for checking CUTLASS errors 22 | */ 23 | #define CUTLASS_CHECK(status) \ 24 | { \ 25 | TORCH_CHECK(status == cutlass::Status::kSuccess, \ 26 | cutlassGetStatusString(status)) \ 27 | } 28 | 29 | inline uint32_t next_pow_2(uint32_t const num) { 30 | if (num <= 1) return num; 31 | return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); 32 | } 33 | 34 | inline int
get_cuda_max_shared_memory_per_block_opt_in(int const device) { 35 | int max_shared_mem_per_block_opt_in = 0; 36 | cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in, 37 | cudaDevAttrMaxSharedMemoryPerBlockOptin, 38 | device); 39 | return max_shared_mem_per_block_opt_in; 40 | } -------------------------------------------------------------------------------- /gemm_fp8/kernels/include/gemm.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | bool fp8_matmul_fastAcc_host( 6 | torch::Tensor out, // BF16 7 | torch::Tensor x, // float_e4m3_t 8 | torch::Tensor y, // float_e4m3_t 9 | float alpha 10 | ); 11 | 12 | bool fp8_matmul_host( 13 | torch::Tensor out, // BF16 14 | torch::Tensor x, // float_e4m3_t 15 | torch::Tensor y, // float_e4m3_t 16 | float alpha 17 | ); -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel", "torch", "numpy", "cmake"] # Specify build dependencies here 3 | build-backend = "setuptools.build_meta" 4 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | import torch.utils.cpp_extension as torch_cpp_ext 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | import os 5 | import pathlib, torch 6 | setup_dir = os.path.dirname(os.path.realpath(__file__)) 7 | HERE = pathlib.Path(__file__).absolute().parent 8 | torch_version = torch.__version__ 9 | 10 | def remove_unwanted_pytorch_nvcc_flags(): 11 | REMOVE_NVCC_FLAGS = [ 12 | '-D__CUDA_NO_HALF_OPERATORS__', 13 | '-D__CUDA_NO_HALF_CONVERSIONS__', 14 | '-D__CUDA_NO_BFLOAT16_CONVERSIONS__', 15 | '-D__CUDA_NO_HALF2_OPERATORS__', 16 | ] 17 | for flag in REMOVE_NVCC_FLAGS: 18 | try: 19 | torch_cpp_ext.COMMON_NVCC_FLAGS.remove(flag) 20 | except ValueError: 21 | pass 22 | 23 | def get_cuda_arch_flags(): 24 | return [ 25 | '-gencode', 'arch=compute_80,code=sm_80', # Ampere 26 | '-gencode', 'arch=compute_89,code=sm_89', # Ada 27 | '--expt-relaxed-constexpr' 28 | ] 29 | 30 | def third_party_cmake(): 31 | import subprocess, sys, shutil 32 | 33 | cmake = shutil.which('cmake') 34 | if cmake is None: 35 | raise RuntimeError('Cannot find CMake executable.') 36 | 37 | retcode = subprocess.call([cmake, HERE]) 38 | if retcode != 0: 39 | sys.stderr.write("Error: CMake configuration failed.\n") 40 | sys.exit(1) 41 | 42 | if __name__ == '__main__': 43 | 44 | assert torch.cuda.is_available(), "CUDA is not available!" 45 | device = torch.cuda.current_device() 46 | print(f"Current device: {torch.cuda.get_device_name(device)}") 47 | print(f"Current CUDA capability: {torch.cuda.get_device_capability(device)}") 48 | assert torch.cuda.get_device_capability(device)[0] >= 8, f"CUDA capability must be >= 8.0, yours is {torch.cuda.get_device_capability(device)}" 49 | 50 | 51 | # Check if version is higher than 2.0 52 | print(f"PyTorch version: {torch_version}") 53 | assert int(torch_version.split('.')[0]) >= 2, "Torch version should be higher than 2!" 
54 | 55 | 56 | third_party_cmake() 57 | remove_unwanted_pytorch_nvcc_flags() 58 | setup( 59 | name='gemm_fp8', 60 | ext_modules=[ 61 | CUDAExtension( 62 | name='gemm_fp8._CUDA', 63 | sources=[ 64 | 'gemm_fp8/kernels/bindings.cpp', 65 | 'gemm_fp8/kernels/gemm.cu', 66 | ], 67 | include_dirs=[ 68 | os.path.join(setup_dir, 'gemm_fp8/kernels/include'), 69 | os.path.join(setup_dir, 'cutlass/include'), 70 | os.path.join(setup_dir, 'cutlass/tools/util/include') 71 | ], 72 | extra_compile_args={ 73 | 'cxx': [], 74 | 'nvcc': get_cuda_arch_flags(), 75 | } 76 | ) 77 | ], 78 | cmdclass={ 79 | 'build_ext': BuildExtension 80 | } 81 | ) 82 | --------------------------------------------------------------------------------
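
A minimal sketch of how the installed extension can be sanity-checked against a plain BF16 `torch.matmul` reference. The shapes below are illustrative assumptions, and the snippet presumes an FP8-capable GPU (compute capability 8.9 or newer) and a PyTorch build that supports casting to `torch.float8_e4m3fn`:

```python
# Sanity-check sketch (illustrative; not part of the repository sources).
import torch
import gemm_fp8

M, N, K = 512, 4096, 4096  # assumed shapes; any FP8-alignment-friendly sizes work

# Build FP8 operands: a is (M, K), b is (N, K); the kernel computes a @ b.t().
a = (torch.randn(M, K, device="cuda") * 0.1).to(torch.float8_e4m3fn)
b = (torch.randn(N, K, device="cuda") * 0.1).to(torch.float8_e4m3fn)

# FP8 kernel output, returned in torch.bfloat16.
out_fp8 = gemm_fp8.matmul(a, b, alpha=1.0, fastAcc=True)

# BF16 reference computed from the same (already quantized) operands.
ref = a.to(torch.bfloat16) @ b.to(torch.bfloat16).t()

rel_err = (out_fp8.float() - ref.float()).abs().max() / ref.float().abs().max()
print(f"max relative error vs. BF16 reference: {rel_err.item():.3e}")
```

Passing `fastAcc=False` instead dispatches to the kernel variant built without CUTLASS's fast-accumulation math operator; `benchmark.py` above reports speedups for both variants.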