├── .gitmodules ├── .gitignore ├── csrc └── kernels │ ├── include │ ├── gemm.h │ └── common.h │ ├── bindings.cpp │ └── gemm.cu ├── test.py ├── LICENSE ├── pyproject.toml ├── gemm_int8 └── __init__.py ├── setup.py ├── .github └── workflows │ └── build-and-release.yml ├── benchmark.py ├── README.md └── CMakeLists.txt /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "cutlass"] 2 | path = cutlass 3 | url = https://github.com/NVIDIA/cutlass 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.cmake 2 | Makefile 3 | *.so 4 | *.log 5 | *.a 6 | *.o 7 | *.obj 8 | *.dll 9 | *.dylib 10 | *.exe 11 | *.out 12 | *.bin 13 | linux-x86_64-cpython-310/* 14 | *.ninja_deps 15 | *.ninja_log 16 | *.ninja* 17 | CMakeFiles/ 18 | int8_ada.egg-info/ 19 | *.PKG-INFO 20 | __pycache__/ 21 | CMakeCache* 22 | gemm_int8.egg-info/ 23 | build/* 24 | dist/* 25 | *.cmake -------------------------------------------------------------------------------- /csrc/kernels/include/gemm.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | 6 | torch::Tensor int8_matmul_host(torch::Tensor input, // INT8 7 | torch::Tensor weight, // INT8 8 | torch::Tensor out, // BF16 9 | float alpha // FP32 10 | ); 11 | 12 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import gemm_int8 3 | 4 | 5 | x = torch.rand(32, 32).cuda() * 100 6 | x = x.to(torch.int8) 7 | 8 | y_torch = x.bfloat16() @ x.bfloat16().t() 9 | y_int8 = gemm_int8.matmul(x, x, 1.0) 10 | 11 | print("Testing opcheck...") 12 | torch.library.opcheck(torch.ops.gemm_int8_CUDA.int8_matmul.default, (x, x, 1.0)) 13 | 14 | print("Testing assert_close of torch matmul vs gemm_int8...") 15 | torch.testing.assert_close(y_torch, y_int8) 16 | 17 | 18 | @torch.compile(dynamic=True) 19 | def test_gemm_int8(x, y, alpha): 20 | return gemm_int8.matmul(x, y, alpha) 21 | 22 | 23 | y_int8 = test_gemm_int8(x, x, 1.0) 24 | print("Testing compile of gemm_int8...") 25 | torch.testing.assert_close(y_torch, y_int8) 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 IST Austria Distributed Algorithms and Systems Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /csrc/kernels/include/common.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #pragma once 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | #include 10 | 11 | 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | #include "cutlass/cutlass.h" 18 | #include 19 | 20 | /** 21 | * Helper function for checking CUTLASS errors 22 | */ 23 | #define CUTLASS_CHECK(status) \ 24 | { \ 25 | TORCH_CHECK(status == cutlass::Status::kSuccess, \ 26 | cutlassGetStatusString(status)) \ 27 | } 28 | 29 | inline uint32_t next_pow_2(uint32_t const num) { 30 | if (num <= 1) return num; 31 | return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); 32 | } 33 | 34 | inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) { 35 | int max_shared_mem_per_block_opt_in = 0; 36 | cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in, 37 | cudaDevAttrMaxSharedMemoryPerBlockOptin, 38 | device); 39 | return max_shared_mem_per_block_opt_in; 40 | } -------------------------------------------------------------------------------- /csrc/kernels/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include // For std::pair 9 | 10 | torch::Tensor int8_matmul(const torch::Tensor &A, 11 | const torch::Tensor &B, 12 | double alpha) 13 | { 14 | float alpha_f = static_cast(alpha); 15 | torch::checkAllContiguous("int8_matmul", {{A, "A", 0}, 16 | {B, "B", 1}}); 17 | torch::checkDeviceType("int8_matmul", {A, B}, at::DeviceType::CUDA); 18 | 19 | torch::checkAllSameGPU("int8_matmul", {{A, "A", 0}, 20 | {B, "B", 1}}); 21 | uint32_t M = A.size(0); 22 | uint32_t N = B.size(0); 23 | auto C = torch::empty({M, N}, torch::dtype(torch::kBFloat16).device(A.device())); 24 | 25 | return int8_matmul_host(A, B, C, alpha_f); 26 | } 27 | 28 | //====== pybind ====== 29 | 30 | TORCH_LIBRARY(gemm_int8_CUDA, m) 31 | { 32 | m.def("int8_matmul(Tensor A, Tensor B, float alpha) -> Tensor"); 33 | } 34 | 35 | TORCH_LIBRARY_IMPL(gemm_int8_CUDA, CUDA, m) 36 | { 37 | m.impl("int8_matmul", &int8_matmul); 38 | } 39 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools>=45", 4 | "wheel", 5 | "torch>=2.0.0", 6 | "cmake>=3.18.0", 7 | "ninja", 8 | "numpy" 9 | ] 10 | build-backend = "setuptools.build_meta" 11 | 12 | [project] 13 | name = "gemm_int8" 14 | version = "1.0.0" 15 | description = "High-performance INT8 matrix multiplication CUDA extension for PyTorch" 16 | readme = "README.md" 17 | authors = [ 18 | {name = "Rush Tabesh", email = "soroushtabesh@gmail.com"} 19 | ] 20 | requires-python = ">=3.9" 21 | classifiers = [ 22 | "Programming Language :: Python :: 3", 23 | "Programming Language :: Python :: 3.9", 24 | "Programming Language :: Python :: 3.10", 25 | "License :: OSI Approved :: MIT License", 26 | "Operating System :: POSIX :: Linux", 27 | "Topic :: Scientific/Engineering :: Artificial 
Intelligence", 28 | "Intended Audience :: Science/Research", 29 | "Development Status :: 4 - Beta", 30 | ] 31 | dependencies = [ 32 | "torch>=2.0.0", 33 | ] 34 | 35 | [project.urls] 36 | Homepage = "https://github.com/IST-DASLab/gemm-int8" 37 | "Bug Tracker" = "https://github.com/IST-DASLab/gemm-int8/issues" 38 | Documentation = "https://github.com/IST-DASLab/gemm-int8#readme" 39 | 40 | [project.optional-dependencies] 41 | build = [ 42 | "cmake>=3.18.0", 43 | "ninja", 44 | ] 45 | 46 | [tool.setuptools] 47 | packages = ["gemm_int8"] 48 | -------------------------------------------------------------------------------- /gemm_int8/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import glob 4 | 5 | package_dir = os.path.dirname(os.path.abspath(__file__)) 6 | 7 | lib_pattern = os.path.join(package_dir, "gemm_int8_CUDA*.so") 8 | lib_files = glob.glob(lib_pattern) 9 | if not lib_files: 10 | raise ImportError(f"Could not find compiled CUDA extension in {package_dir}") 11 | 12 | for lib_file in lib_files: 13 | torch.ops.load_library(lib_file) 14 | 15 | 16 | @torch.library.register_fake("gemm_int8_CUDA::int8_matmul") 17 | def _(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0): 18 | torch._check(x.device.type == "cuda", "x must be a CUDA tensor") 19 | torch._check(y.device.type == "cuda", "y must be a CUDA tensor") 20 | torch._check(x.dtype == torch.int8, "x must be an int8 tensor") 21 | torch._check(y.dtype == torch.int8, "y must be an int8 tensor") 22 | torch._check(len(x.shape) == 2, "x must be a 2D tensor") 23 | torch._check(len(y.shape) == 2, "y must be a 2D tensor") 24 | torch._check(x.shape[1] == y.shape[1], "x.shape[1] must be equal to y.shape[1]") 25 | return torch.empty(x.shape[0], y.shape[0], device=x.device, dtype=torch.bfloat16) 26 | 27 | 28 | def matmul(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0): 29 | """ 30 | Matrix-Matrix Multiplication for INT8 data type in the form of (x @ y.t())*alpha. 31 | The output is BF16 data type. todo: support arbitrary output dtype! 
32 | Arguments: 33 | x: torch.Tensor, shape (M, K) 34 | y: torch.Tensor, shape (N, K) 35 | alpha: float, which is multiplied by the output (default=1.0) 36 | """ 37 | return torch.ops.gemm_int8_CUDA.int8_matmul(x, y, alpha) 38 | 39 | 40 | __all__ = ["matmul"] 41 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | import os 3 | import platform 4 | import pathlib 5 | import torch 6 | import sys 7 | from setuptools.command.bdist_wheel import bdist_wheel 8 | 9 | setup_dir = os.path.dirname(os.path.realpath(__file__)) 10 | HERE = pathlib.Path(__file__).absolute().parent 11 | 12 | min_cuda_version = (11, 8) 13 | 14 | def check_cuda_version(): 15 | """Verify CUDA compatibility before building.""" 16 | print(f"CUDA version: {torch.version.cuda}") 17 | cuda_version = tuple(map(int, torch.version.cuda.split("."))) 18 | assert cuda_version >= min_cuda_version, ( 19 | f"CUDA version must be >= {min_cuda_version}, yours is {torch.version.cuda}" 20 | ) 21 | 22 | def get_platform_tag(architecture=None): 23 | """Determine the platform tag for the wheel.""" 24 | if architecture is None: 25 | architecture = platform.machine() 26 | 27 | system = platform.system() 28 | 29 | if system == "Linux": 30 | tag = "manylinux_2_24_x86_64" if architecture == "x86_64" else "manylinux_2_24_aarch64" 31 | elif system == "Darwin": 32 | tag = "macosx_13_1_x86_64" if architecture == "x86_64" else "macosx_13_1_arm64" 33 | elif system == "Windows": 34 | tag = "win_amd64" if architecture == "x86_64" else "win_arm64" 35 | else: 36 | raise ValueError(f"Unsupported system: {system}") 37 | 38 | return tag 39 | 40 | class BdistWheelCommand(bdist_wheel): 41 | """Custom wheel building command to set platform tags correctly.""" 42 | def finalize_options(self): 43 | bdist_wheel.finalize_options(self) 44 | # Mark the wheel as platform-specific (not "any") 45 | self.root_is_pure = False 46 | 47 | def get_tag(self): 48 | python_tag = "py3" 49 | 50 | platform_tag = get_platform_tag() 51 | 52 | # Force the ABI tag to be 'none' since we're not using Python C API directly 53 | # (PyTorch's C++ extensions handle this for us) 54 | abi_tag = 'none' 55 | 56 | return python_tag, abi_tag, platform_tag 57 | 58 | if __name__ == "__main__": 59 | # Read README for the long description 60 | with open("README.md", "r", encoding="utf-8") as fh: 61 | long_description = fh.read() 62 | 63 | check_cuda_version() 64 | 65 | print(f"Building wheel with platform tag: {get_platform_tag()}") 66 | 67 | # The actual setup call without ext_modules 68 | setup( 69 | # All package configuration is now in pyproject.toml 70 | package_data={"gemm_int8": ["*.so"]}, # Include compiled libraries 71 | cmdclass={ 72 | 'bdist_wheel': BdistWheelCommand, 73 | }, 74 | ) 75 | -------------------------------------------------------------------------------- /.github/workflows/build-and-release.yml: -------------------------------------------------------------------------------- 1 | name: Build and Release 2 | 3 | on: 4 | push: 5 | branches: [main, ci] 6 | pull_request: 7 | branches: [main] 8 | workflow_dispatch: {} # Allow manual trigger 9 | 10 | concurrency: 11 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} 12 | cancel-in-progress: true 13 | 14 | jobs: 15 | build-with-cuda: 16 | runs-on: ubuntu-latest 17 | container: 18 | image: pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel 19 | steps: 20 | -
name: Install build dependencies 21 | run: | 22 | apt-get update && apt-get install -y --no-install-recommends \ 23 | cmake ninja-build git 24 | 25 | - name: Checkout code 26 | uses: actions/checkout@v4 27 | with: 28 | submodules: recursive 29 | 30 | - name: Set CUDA_HOME 31 | run: | 32 | # Find nvcc location 33 | NVCC_PATH=$(which nvcc) 34 | # Extract CUDA installation directory (remove /bin/nvcc from path) 35 | export CUDA_HOME=$(dirname $(dirname $NVCC_PATH)) 36 | echo "CUDA_HOME=${CUDA_HOME}" >> $GITHUB_ENV 37 | echo "Found CUDA installation at: ${CUDA_HOME}" 38 | echo "${CUDA_HOME}/bin" >> $GITHUB_PATH 39 | 40 | - name: Verify CUDA installation 41 | run: | 42 | nvcc -V 43 | echo "CUDA_HOME: ${CUDA_HOME}" 44 | ls -la ${CUDA_HOME}/bin 45 | echo "PATH: $PATH" 46 | pwd 47 | ls . -alh 48 | ls cutlass -alh 49 | ls gemm_int8 -alh 50 | 51 | - name: Build C++/CUDA (CMake) 52 | run: | 53 | chmod +x build.sh 54 | ./build.sh 55 | env: 56 | CUDA_PATH: ${CUDA_HOME} 57 | 58 | - name: Build wheel 59 | run: ./build.sh --wheel 60 | env: 61 | CUDA_PATH: ${CUDA_HOME} 62 | 63 | - name: Upload build artifact 64 | uses: actions/upload-artifact@v4 65 | with: 66 | name: wheel 67 | path: dist/*.whl 68 | retention-days: 7 69 | 70 | publish-release: 71 | needs: build-with-cuda 72 | runs-on: ubuntu-latest 73 | permissions: 74 | contents: write 75 | steps: 76 | - name: Checkout code to get version 77 | uses: actions/checkout@v4 78 | 79 | - name: Extract version 80 | id: extract_version 81 | run: | 82 | VERSION=$(grep version pyproject.toml | head -n1 | awk -F'"' '{print $2}') 83 | echo "Package version: $VERSION" 84 | echo "version=$VERSION" >> $GITHUB_OUTPUT 85 | 86 | - name: Download wheel artifacts 87 | uses: actions/download-artifact@v4 88 | with: 89 | name: wheel 90 | path: wheels/ 91 | 92 | - name: List wheels 93 | run: ls -la wheels/ 94 | 95 | - name: Create/Update Release 96 | uses: softprops/action-gh-release@v2.0.8 97 | with: 98 | files: wheels/*.whl 99 | prerelease: false 100 | name: "v${{ steps.extract_version.outputs.version }}" 101 | tag_name: "v${{ steps.extract_version.outputs.version }}" 102 | make_latest: true 103 | draft: false 104 | target_commitish: ${{ github.sha }} -------------------------------------------------------------------------------- /benchmark.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import gemm_int8 3 | import time 4 | import numpy as np 5 | import matplotlib 6 | import matplotlib.pyplot as plt 7 | import seaborn as sns 8 | from typing import Callable, Iterable, List, Tuple 9 | 10 | matplotlib.rcParams["lines.linewidth"] = 2 * matplotlib.rcParams["lines.linewidth"] 11 | matplotlib.rcParams["lines.markersize"] = 2 * matplotlib.rcParams["lines.markersize"] 12 | matplotlib.rcParams.update({"font.size": 2 * matplotlib.rcParams["font.size"]}) 13 | 14 | iters = 100 15 | warmup = 10 16 | 17 | def to_int8(tensor: torch.Tensor) -> torch.Tensor: 18 | return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) 19 | 20 | def make_rand_tensors(dtype: torch.dtype, m: int, n: int, 21 | k: int) -> Tuple[torch.Tensor, torch.Tensor]: 22 | a = torch.randn((m, k), device='cuda').contiguous() * 5 23 | b = torch.randn((n, k), device='cuda').contiguous() * 5 24 | 25 | if dtype == torch.int8: 26 | return to_int8(a), to_int8(b) 27 | if dtype == torch.bfloat16: 28 | return a.to(torch.bfloat16), b.to(torch.bfloat16) 29 | if dtype == torch.float16: 30 | return a.half(), b.half() 31 | if dtype == torch.float32: 32 | return 
a.float(), b.float() 33 | 34 | raise ValueError("unsupported dtype") 35 | 36 | 37 | 38 | # bench 39 | def bench_fn(fn: Callable, *args, **kwargs) -> Tuple: 40 | 41 | times_ = [] 42 | for i in range(warmup): 43 | fn(*args, **kwargs) 44 | torch.cuda.synchronize() 45 | 46 | for _ in range(10): 47 | start = time.time() 48 | for i in range(iters): 49 | fn(*args, **kwargs) 50 | torch.cuda.synchronize() 51 | times_.append((time.time() - start) * 1000 / iters) 52 | 53 | return np.mean(np.array(times_)), np.std(np.array(times_)) 54 | 55 | 56 | matrix_sizes = [ 57 | (4096, 4096), 58 | (14336, 4096), 59 | (4096, 14336), 60 | ] 61 | 62 | tokens = [512, 1024, 2048] 63 | 64 | x_labels = [] 65 | bf16_runtimes = [] 66 | int8_runtimes = [] 67 | 68 | for token in tokens: 69 | print('------------------') 70 | print(f"Token: {token}") 71 | for (n, k) in matrix_sizes: 72 | print(f"Matrix size: {k}x{n}") 73 | x_labels.append(f"{k}x{n}") 74 | a, b = make_rand_tensors(torch.bfloat16, token, n, k) 75 | a_int8, b_int8 = make_rand_tensors(torch.int8, token, n, k) 76 | 77 | bf16_times, bf16_times_std = bench_fn(torch.matmul, a, b.t()) 78 | v_1_times, v_1_times_std = bench_fn(gemm_int8.matmul, a_int8, b_int8, 1.0) 79 | 80 | print(f'Speedup: {bf16_times/v_1_times:.2f}x') 81 | 82 | int8_runtimes.append(v_1_times.item()) 83 | bf16_runtimes.append(bf16_times.item()) 84 | 85 | print(bf16_runtimes) 86 | print(int8_runtimes) 87 | 88 | """ for layer in range(len(matrix_sizes)): 89 | plt.plot( 90 | x_labels[(layer*len(tokens)):(layer*len(tokens))+len(tokens)], 91 | np.array(bf16_runtimes[(layer*len(tokens)):(layer*len(tokens))+len(tokens)])/np.array(int8_runtimes[(layer*len(tokens)):(layer*len(tokens))+len(tokens)]), 92 | 'o-', label=f"Layer shape: {matrix_sizes[layer]}") 93 | 94 | plt.axhline(1, color='black', linestyle='--') 95 | plt.ylabel("Speedup (over BF16)") 96 | plt.xlabel("M-dim") 97 | plt.title(f'{torch.cuda.get_device_name()}') 98 | plt.legend() 99 | plt.savefig("int8_bf16_benchmark.png") """ 100 | 101 | sns.set() 102 | plt.figure(figsize=(15, 10)) 103 | for token_id in range(len(tokens)): 104 | plt.plot( 105 | x_labels[(token_id*len(matrix_sizes)):(token_id*len(matrix_sizes))+len(matrix_sizes)], 106 | np.array(bf16_runtimes[(token_id*len(matrix_sizes)):(token_id*len(matrix_sizes))+len(matrix_sizes)])/np.array(int8_runtimes[(token_id*len(matrix_sizes)):(token_id*len(matrix_sizes))+len(matrix_sizes)]), 107 | 'o-', label=f"Token Dim: {tokens[token_id]}") 108 | plt.plot(x_labels, np.ones(len(x_labels))*4, "k") 109 | plt.axhline(1, color='black', linestyle='--') 110 | plt.ylabel("Speedup (over BF16)") 111 | plt.xlabel("Matrix Dimensions (k x n)") 112 | plt.title(f'{torch.cuda.get_device_name()}') 113 | plt.legend() 114 | plt.yticks(np.arange(1, 4.1, 0.25)) 115 | 116 | plt.tight_layout() 117 | plt.savefig("benchmark_int8.png") -------------------------------------------------------------------------------- /csrc/kernels/gemm.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "cutlass/float8.h" 4 | 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | 13 | #include 14 | #include 15 | #include 16 | 17 | #include 18 | #include 19 | #include 20 | 21 | 22 | template 23 | torch::Tensor int8_matmul( 24 | torch::Tensor input, // INT8 25 | torch::Tensor weight, // INT8 26 | torch::Tensor out, // BF16 27 | float alpha // FP32 28 | ){ 29 | auto M = input.size(0); 30 | auto N = weight.size(0); 31 | auto K = input.size(1); 32 | 33 | 
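  // The CUTLASS GEMM configured below computes D = alpha * (input @ weight^T):
  //   A = input  : int8, row-major, M x K
  //   B = weight : int8, the N x K buffer interpreted as a K x N column-major matrix
  //   C = D = out: bfloat16, row-major, M x N, accumulated in int32 and scaled by
  //                `alpha` in a float LinearCombination epilogue (beta = 0).
  // It targets Sm80 tensor cores with the 16x8x32 integer MMA instruction shape;
  // TileShape, WarpShape and kStages select the threadblock tile, warp tile and
  // software-pipeline depth, chosen per problem size by int8_matmul_host below.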
using ElementOutput = cutlass::bfloat16_t; 34 | using ElementAccumulator = int32_t; 35 | using ElementComputeEpilogue = float; 36 | using ElementInputA = int8_t; // <- data type of elements in input matrix A 37 | using ElementInputB = int8_t; // <- data type of elements in input matrix B 38 | 39 | // The code section below describes matrix layout of input and output 40 | // matrices. Column Major for Matrix A, Row Major for Matrix B and Row Major 41 | // for Matrix C 42 | using LayoutInputA = cutlass::layout::RowMajor; 43 | using LayoutInputB = cutlass::layout::ColumnMajor; 44 | using LayoutOutput = cutlass::layout::RowMajor; 45 | 46 | using Gemm = cutlass::gemm::device::Gemm< 47 | int8_t, 48 | cutlass::layout::RowMajor, 49 | int8_t, 50 | cutlass::layout::ColumnMajor, 51 | ElementOutput, 52 | cutlass::layout::RowMajor, 53 | ElementAccumulator, 54 | cutlass::arch::OpClassTensorOp, 55 | cutlass::arch::Sm80, 56 | TileShape, 57 | WarpShape, 58 | cutlass::gemm::GemmShape<16, 8, 32>, 59 | cutlass::epilogue::thread::LinearCombination< 60 | ElementOutput, 61 | 128 / cutlass::sizeof_bits::value, 62 | ElementAccumulator, 63 | ElementComputeEpilogue>, 64 | cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 65 | kStages>; 66 | 67 | auto input_size = cutlass::MatrixCoord(M, K); 68 | auto weight_size = cutlass::MatrixCoord(K, N); 69 | auto output_size = cutlass::MatrixCoord(M, N); 70 | 71 | auto device = input.device(); 72 | 73 | cutlass::gemm::GemmCoord problem_size(M, N, K); 74 | 75 | 76 | cutlass::TensorRef input_ref( 77 | static_cast(input.data_ptr()), 78 | LayoutInputA::packed(input_size)); 79 | 80 | cutlass::TensorRef weight_ref( 81 | //weight.data_ptr(), 82 | static_cast(weight.data_ptr()), 83 | LayoutInputB::packed(weight_size)); 84 | 85 | cutlass::TensorRef out_ref( 86 | //out.data_ptr(), 87 | static_cast(out.data_ptr()), 88 | LayoutOutput::packed(output_size)); 89 | 90 | typename Gemm::Arguments arguments{ 91 | problem_size, // <- problem size of matrix multiplication 92 | input_ref, // <- reference to matrix A on device 93 | weight_ref, // <- reference to matrix B on device 94 | out_ref, // <- reference to matrix C on device 95 | out_ref, // <- reference to matrix D on device 96 | {alpha, 0.0}, 1}; 97 | Gemm gemm_op; 98 | 99 | // Using the arguments, query for extra workspace required for matrix 100 | // multiplication computation 101 | size_t workspace_size = Gemm::get_workspace_size(arguments); 102 | 103 | // Allocate workspace memory 104 | cutlass::device_memory::allocation workspace(workspace_size); 105 | 106 | // Check the problem size is supported or not 107 | cutlass::Status status = gemm_op.can_implement(arguments); 108 | if (status != cutlass::Status::kSuccess) { 109 | throw std::runtime_error("cutlass cannot implement"); 110 | } 111 | 112 | // Initialize CUTLASS kernel with arguments and workspace pointer 113 | status = gemm_op.initialize(arguments, workspace.get()); 114 | if (status != cutlass::Status::kSuccess) { 115 | throw std::runtime_error("cutlass cannot initialize"); 116 | } 117 | 118 | status = gemm_op(); 119 | if (status != cutlass::Status::kSuccess) { 120 | throw std::runtime_error("cutlass cannot run"); 121 | } 122 | 123 | return out; 124 | } 125 | 126 | torch::Tensor int8_matmul_host( 127 | torch::Tensor input, // INT8 128 | torch::Tensor weight, // INT8 129 | torch::Tensor out, // BF16 130 | float alpha // FP32 131 | ){ 132 | auto M = input.size(0); 133 | auto N = weight.size(0); 134 | auto K = input.size(1); 135 | 136 | if (M==512 && N==4096 && K==4096){ 
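      // Each branch of this dispatch picks a threadblock tile (TileShape), warp tile
      // (WarpShape) and pipeline depth (kStages) that appear to be tuned for common
      // transformer GEMM shapes; the final else branch is the generic fallback.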
137 | using TileShape = typename cutlass::gemm::GemmShape<128, 128, 128>; 138 | using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 128>; 139 | static const int kStages = 3; 140 | return int8_matmul(input, weight, out, alpha); 141 | } else if (M==512 && N==4096 && K==14336){ 142 | using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>; 143 | using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; 144 | static const int kStages = 4; 145 | return int8_matmul(input, weight, out, alpha); 146 | } else if (K==4096 && N==4096){ 147 | using TileShape = typename cutlass::gemm::GemmShape<256, 128, 64>; 148 | using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; 149 | static const int kStages = 3; 150 | return int8_matmul(input, weight, out, alpha); 151 | } else if (M==1024 && N==14336 && K==4096){ 152 | using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>; 153 | using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; 154 | static const int kStages = 3; 155 | return int8_matmul(input, weight, out, alpha); 156 | } else { 157 | using TileShape = typename cutlass::gemm::GemmShape<256, 128, 64>; 158 | using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; 159 | static const int kStages = 3; 160 | return int8_matmul(input, weight, out, alpha); 161 | } 162 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # INT8 GEMM with PyTorch Interface 2 | 3 | 4 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 5 | [![Python 3.9+](https://img.shields.io/badge/python-3.9+-blue.svg)](https://www.python.org/downloads/) 6 | [![CUDA 11.8+](https://img.shields.io/badge/CUDA-11.8%2B-green.svg)](https://developer.nvidia.com/cuda-toolkit) 7 | 8 | 9 | 10 | A PyTorch CUDA extension providing high-performance INT8 matrix multiplication operations utilizing CUTLASS iterators. Specifically optimized for modern NVIDIA GPUs including Ada Lovelace and Hopper architectures, this library offers measurable performance improvements over standard BF16 matrix multiplication in deep learning applications. 
(It was originally used in [HALO: Hadamard-Assisted Low-Precision Optimization and Training method for finetuning LLMs](https://github.com/IST-DASLab/HALO)) 11 | 12 | ## Features 13 | 14 | - INT8 matrix multiplication with PyTorch integration, providing up to 4x speedup on RTX 4090 GPUs 15 | - Compatible with PyTorch's torch.compile (autograd not supported) 16 | - Optimized CUDA kernels for compute capabilities 89-100 (Ada Lovelace, Hopper) 17 | - Tuned kernel configurations for common matrix dimensions in transformer models 18 | - Direct integration with existing PyTorch workflows 19 | 20 | ## Quick Start 21 | 22 | ```bash 23 | # Install from GitHub releases 24 | pip install https://github.com/IST-DASLab/gemm-int8/releases/download/latest/gemm_int8-1.0.0-py3-none-manylinux_2_24_x86_64.whl 25 | ``` 26 | 27 | ```python 28 | import torch 29 | import gemm_int8 30 | 31 | # Create input tensors 32 | a = torch.randint(-128, 127, (1024, 4096), device='cuda', dtype=torch.int8) 33 | b = torch.randint(-128, 127, (4096, 4096), device='cuda', dtype=torch.int8) 34 | 35 | # Perform INT8 matrix multiplication (compute a @ b.t()) 36 | result = gemm_int8.matmul(a, b, alpha=1.0) # Returns bfloat16 tensor of (a @ b.t()) * alpha 37 | ``` 38 | 39 | Performs matrix multiplication in the form of `(x @ y.t()) * alpha`. 40 | 41 | **Parameters:** 42 | - `x` (torch.Tensor): Input matrix of shape (M, K) with dtype torch.int8 43 | - `y` (torch.Tensor): Input matrix of shape (N, K) with dtype torch.int8 44 | - `alpha` (float, optional): Scaling factor applied to the output. Default: 1.0 45 | 46 | **Returns:** 47 | - torch.Tensor: Result matrix of shape (M, N) with dtype torch.bfloat16 48 | 49 | ## Requirements 50 | 51 | - Python 3.9+ 52 | - PyTorch 2.0.0+ 53 | - CUDA 11.8+ 54 | - NVIDIA GPU with Compute Capability 70 or higher 55 | - Linux with x86_64 architecture (primary platform) 56 | 57 | ## Installation 58 | 59 | ### Option 1: From PyPI (Coming Soon) 60 | 61 | ```bash 62 | pip install gemm-int8 63 | ``` 64 | 65 | ### Option 2: From GitHub Release 66 | 67 | Download pre-built wheels directly from the GitHub releases page: 68 | 69 | ```bash 70 | pip install https://github.com/IST-DASLab/gemm-int8/releases/download/v$(VERSION)/gemm_int8-$(VERSION)-py3-none-$(PLATFORM_TAG).whl 71 | ``` 72 | 73 | Where: 74 | - `$(VERSION)` is the package version (e.g., "1.0.0") 75 | - `$(PLATFORM_TAG)` is your platform tag (e.g., "manylinux_2_24_x86_64") 76 | 77 | Or to install the latest build from the main branch: 78 | 79 | ```bash 80 | pip install https://github.com/IST-DASLab/gemm-int8/releases/download/latest/gemm_int8-$(VERSION)-py3-none-$(PLATFORM_TAG).whl 81 | ``` 82 | 83 | ### Option 3: Build From Source 84 | 85 | Building from source requires additional development tools: 86 | 87 | ```bash 88 | # Clone the repository with submodules 89 | git clone --recursive https://github.com/IST-DASLab/gemm-int8.git 90 | cd gemm-int8 91 | 92 | # Make sure CUDA toolkit is properly installed and CUDA_HOME is set 93 | echo $CUDA_HOME # Should point to your CUDA installation directory 94 | # If not set, you may need to run: export CUDA_HOME=/usr/local/cuda 95 | 96 | # Also make sure you have cmake and ninja installed in your environment. 97 | pip install cmake ninja 98 | 99 | # Build and install 100 | ./build.sh 101 | pip install . 102 | 103 | # Alternatively, for development installation 104 | pip install -e .
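# To produce a distributable wheel instead, the CI workflow runs the same build
# script with a --wheel flag and picks the result up from dist/ (assuming the flag
# behaves the same locally as it does in CI):
#   ./build.sh --wheel
#   pip install dist/*.whl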
105 | ``` 106 | 107 | 108 | ### Integration with torch.compile 109 | 110 | The library is compatible with PyTorch's `torch.compile` i.e. if this code is used within a compiled scope: 111 | 112 | ```python 113 | import torch 114 | import gemm_int8 115 | 116 | @torch.compile(dynamic=True) 117 | def compiled_matmul_routine(x, y, alpha): 118 | # ... some pytorch operations 119 | res = gemm_int8.matmul(x, y, alpha) 120 | # ... some pytorch operations 121 | return res 122 | 123 | # Use the compiled function 124 | result = compiled_matmul_routine(a, b, 1.0) 125 | ``` 126 | 127 | Note that compile won't optimize this kernel and it's only compatible in the sense that torch compile backend will recognize it as an operator and can be compiled along other operations in a routine. 128 | 129 | ## Benchmarks 130 | 131 | You can run the benchmark script to compare performance: 132 | 133 | ```bash 134 | python benchmark.py 135 | ``` 136 | 137 | This will generate a benchmark report and a visualization showing the speedup compared to BF16 matrix multiplication across different matrix sizes and token dimensions. 138 | 139 | Typical speedups range from 2x to 4x depending on the matrix dimensions and hardware. 140 | 141 | ## Performance Tips 142 | 143 | - For best performance, ensure your tensors are contiguous in memory 144 | - The library is optimized for large matrix sizes commonly found in transformer models 145 | - Performance benefits are most significant for matrix dimensions commonly used in LLM inference 146 | 147 | ## License 148 | 149 | This project is licensed under the MIT License - see the LICENSE file for details. 150 | 151 | ## Citation 152 | 153 | If you use this library in your research, please cite: 154 | 155 | ```bibtex 156 | @software{gemm_int8, 157 | author = {Roberto L. Castro and Saleh Ashkboos and Soroush Tabesh}, 158 | title = {INT8 GEMM with PyTorch Interface}, 159 | url = {https://github.com/IST-DASLab/gemm-int8}, 160 | year = {2024}, 161 | } 162 | ``` 163 | 164 | ```bibtex 165 | @article{halo2025, 166 | title={HALO: Hadamard-Assisted Lower-Precision Optimization for LLMs}, 167 | author={Saleh Ashkboos and Mahdi Nikdan and Soroush Tabesh and Roberto L. Castro and Torsten Hoefler and Dan Alistarh}, 168 | year={2025}, 169 | eprint={2501.02625}, 170 | archivePrefix={arXiv}, 171 | primaryClass={cs.LG}, 172 | url={https://arxiv.org/abs/2501.02625}, 173 | } 174 | ``` 175 | 176 | ## Acknowledgements 177 | 178 | This project uses [CUTLASS](https://github.com/NVIDIA/cutlass) for optimized CUDA kernels. 179 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.18) 2 | project(gemm_int8 LANGUAGES CXX) 3 | 4 | # Set default build type to Release 5 | if(NOT CMAKE_BUILD_TYPE) 6 | set(CMAKE_BUILD_TYPE Release) 7 | endif() 8 | message(STATUS "Build type: ${CMAKE_BUILD_TYPE}") 9 | 10 | # Set output directories for all build artifacts 11 | set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) 12 | set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) 13 | set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) 14 | 15 | # Find Python executable 16 | if(NOT DEFINED Python3_EXECUTABLE) 17 | find_program(Python3_EXECUTABLE NAMES python3 python) 18 | if(NOT Python3_EXECUTABLE) 19 | message(FATAL_ERROR "Python3 executable not found. 
Please specify with -DPython3_EXECUTABLE=path/to/python") 20 | endif() 21 | endif() 22 | message(STATUS "Using Python executable: ${Python3_EXECUTABLE}") 23 | 24 | # Find Python package 25 | find_package(Python3 COMPONENTS Development REQUIRED) 26 | message(STATUS "Python3_INCLUDE_DIRS: ${Python3_INCLUDE_DIRS}") 27 | 28 | # Get Python include directories 29 | execute_process( 30 | COMMAND ${Python3_EXECUTABLE} -c "import sysconfig; print(sysconfig.get_path('include'))" 31 | OUTPUT_VARIABLE PYTHON_INCLUDE_DIR 32 | OUTPUT_STRIP_TRAILING_WHITESPACE 33 | ) 34 | message(STATUS "Python include directory: ${PYTHON_INCLUDE_DIR}") 35 | 36 | # Find PyTorch 37 | execute_process( 38 | COMMAND ${Python3_EXECUTABLE} -c "import torch; print(torch.utils.cmake_prefix_path)" 39 | RESULT_VARIABLE PYTORCH_RESULT 40 | OUTPUT_VARIABLE TORCH_PREFIX_PATH 41 | OUTPUT_STRIP_TRAILING_WHITESPACE 42 | ) 43 | if(NOT PYTORCH_RESULT EQUAL 0) 44 | message(FATAL_ERROR "PyTorch not found. Please install PyTorch first.") 45 | endif() 46 | list(APPEND CMAKE_PREFIX_PATH ${TORCH_PREFIX_PATH}) 47 | 48 | # Enable CUDA 49 | if(NOT DEFINED BUILD_CUDA) 50 | set(BUILD_CUDA ON) 51 | endif() 52 | 53 | if(BUILD_CUDA) 54 | # NVCC compatibility check for newer MSVC compilers 55 | if(MSVC AND MSVC_VERSION VERSION_GREATER_EQUAL 1940) 56 | string(APPEND CMAKE_CUDA_FLAGS " --allow-unsupported-compiler") 57 | endif() 58 | 59 | enable_language(CUDA) 60 | find_package(CUDAToolkit REQUIRED) 61 | 62 | # Convert the CUDA version from X.Y.z to XY 63 | string(REGEX MATCH "^[0-9]+.[0-9]+" _CUDA_VERSION_FIRST_TWO "${CMAKE_CUDA_COMPILER_VERSION}") 64 | string(REPLACE "." "" CUDA_VERSION_SHORT "${_CUDA_VERSION_FIRST_TWO}") 65 | 66 | message(STATUS "CUDA Version: ${CUDA_VERSION_SHORT} (${CMAKE_CUDA_COMPILER_VERSION})") 67 | message(STATUS "CUDA Compiler: ${CMAKE_CUDA_COMPILER}") 68 | 69 | # IMPORTANT: This is the key change - disable PyTorch's architecture detection 70 | set(TORCH_CUDA_ARCH_LIST "") 71 | 72 | # Default architectures if not provided 73 | if(NOT DEFINED COMPUTE_CAPABILITY) 74 | set(COMPUTE_CAPABILITY "70;75;80;86;89;90;90a" CACHE STRING "CUDA Compute Capabilities") 75 | endif() 76 | 77 | message(STATUS "CUDA Capabilities Selected: ${COMPUTE_CAPABILITY}") 78 | 79 | # Configure architectures for compilation - explicitly set with our choices 80 | set(CMAKE_CUDA_ARCHITECTURES ${COMPUTE_CAPABILITY}) 81 | 82 | # Set explicit NVCC flags to override any auto-detection 83 | set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr --use_fast_math") 84 | 85 | # Add explicit architecture flags to NVCC 86 | foreach(ARCH ${COMPUTE_CAPABILITY}) 87 | string(APPEND CMAKE_CUDA_FLAGS " -gencode=arch=compute_${ARCH},code=sm_${ARCH}") 88 | endforeach() 89 | 90 | # For the latest architecture, also add PTX 91 | list(GET COMPUTE_CAPABILITY -1 LATEST_ARCH) 92 | string(APPEND CMAKE_CUDA_FLAGS " -gencode=arch=compute_${LATEST_ARCH},code=compute_${LATEST_ARCH}") 93 | 94 | message(STATUS "CUDA Flags: ${CMAKE_CUDA_FLAGS}") 95 | 96 | # Set C++ standard for CUDA 97 | set(CMAKE_CUDA_STANDARD 17) 98 | set(CMAKE_CUDA_STANDARD_REQUIRED ON) 99 | 100 | # Define that we're building with CUDA 101 | add_compile_definitions(BUILD_CUDA) 102 | endif() 103 | 104 | # Set C++ standard 105 | set(CMAKE_CXX_STANDARD 17) 106 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 107 | 108 | # Include CUTLASS headers (without building the entire library) 109 | include_directories(${CMAKE_SOURCE_DIR}/cutlass/include) 110 | include_directories(${CMAKE_SOURCE_DIR}/cutlass/tools/util/include) 111 | 
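# Illustrative out-of-source configure/build using the cache variables defined above
# (the exact flags are an assumption; the README's ./build.sh is the supported entry point):
#   cmake -B build -G Ninja -DCOMPUTE_CAPABILITY="89;90" -DPython3_EXECUTABLE=$(which python3)
#   cmake --build build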
112 | # Setup include directories 113 | include_directories(${CMAKE_SOURCE_DIR}) 114 | include_directories(${CMAKE_SOURCE_DIR}/csrc/kernels/include) 115 | include_directories(${Python3_INCLUDE_DIRS}) 116 | include_directories(${PYTHON_INCLUDE_DIR}) 117 | 118 | # Find PyTorch - IMPORTANT: Do this after setting TORCH_CUDA_ARCH_LIST 119 | find_package(Torch REQUIRED) 120 | message(STATUS "Found PyTorch: ${TORCH_INCLUDE_DIRS}") 121 | 122 | # Create source files list 123 | set(CPP_FILES csrc/kernels/bindings.cpp) 124 | set(CUDA_FILES csrc/kernels/gemm.cu) 125 | 126 | # Add source files based on backend 127 | if(BUILD_CUDA) 128 | set(SRC_FILES ${CPP_FILES} ${CUDA_FILES}) 129 | set(OUTPUT_NAME "gemm_int8_CUDA") 130 | else() 131 | set(SRC_FILES ${CPP_FILES}) 132 | set(OUTPUT_NAME "gemm_int8_CPU") 133 | endif() 134 | 135 | # Create the extension library 136 | add_library(gemm_int8 SHARED ${SRC_FILES}) 137 | 138 | # Link dependencies 139 | if(BUILD_CUDA) 140 | target_link_libraries(gemm_int8 PRIVATE 141 | "${TORCH_LIBRARIES}" 142 | Python3::Python 143 | CUDA::cudart 144 | CUDA::cublas 145 | ) 146 | else() 147 | target_link_libraries(gemm_int8 PRIVATE 148 | "${TORCH_LIBRARIES}" 149 | Python3::Python 150 | ) 151 | endif() 152 | 153 | target_include_directories(gemm_int8 PRIVATE 154 | ${TORCH_INCLUDE_DIRS} 155 | ${Python3_INCLUDE_DIRS} 156 | ${PYTHON_INCLUDE_DIR} 157 | ) 158 | 159 | # Set output properties 160 | set_target_properties(gemm_int8 PROPERTIES 161 | OUTPUT_NAME "${OUTPUT_NAME}" 162 | PREFIX "" 163 | ) 164 | 165 | # Configure output directories based on platform 166 | if(WIN32) 167 | # Windows-specific settings 168 | set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON) 169 | 170 | if(MSVC) 171 | set_target_properties(gemm_int8 PROPERTIES 172 | RUNTIME_OUTPUT_DIRECTORY_RELEASE "${CMAKE_SOURCE_DIR}/gemm_int8" 173 | RUNTIME_OUTPUT_DIRECTORY_DEBUG "${CMAKE_SOURCE_DIR}/gemm_int8" 174 | ) 175 | endif() 176 | else() 177 | # Linux/macOS settings 178 | set_target_properties(gemm_int8 PROPERTIES 179 | LIBRARY_OUTPUT_DIRECTORY "${CMAKE_SOURCE_DIR}/gemm_int8" 180 | ) 181 | endif() 182 | 183 | # Make a custom command to copy the built library to the Python package 184 | add_custom_command( 185 | TARGET gemm_int8 186 | POST_BUILD 187 | COMMAND ${CMAKE_COMMAND} -E copy_if_different 188 | $ 189 | "${CMAKE_SOURCE_DIR}/gemm_int8/$" 190 | COMMENT "Copying library to Python package directory" 191 | ) 192 | 193 | # Debug info 194 | message(STATUS "Source files: ${SRC_FILES}") 195 | message(STATUS "Library will be copied to: ${CMAKE_SOURCE_DIR}/gemm_int8/$") 196 | 197 | # Print architecture settings again at the end to confirm 198 | if(BUILD_CUDA) 199 | message(STATUS "Final CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}") 200 | message(STATUS "Final CUDA flags: ${CMAKE_CUDA_FLAGS}") 201 | endif() --------------------------------------------------------------------------------
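
As a closing illustration of how the `alpha` argument composes with symmetric INT8 quantization, the sketch below uses only the documented `gemm_int8.matmul(x, y, alpha)` API; the quantization helper, tensor shapes and the BF16 comparison are our own choices and are not part of the library.

```python
import torch
import gemm_int8


def quantize_per_tensor(t: torch.Tensor):
    # Hypothetical helper (not part of gemm_int8): symmetric per-tensor INT8 quantization.
    scale = (t.abs().max() / 127.0).clamp(min=1e-8)
    q = torch.clamp(torch.round(t / scale), -128, 127).to(torch.int8)
    return q, scale.item()


# Activations are (M, K); weights are (N, K), matching the (x @ y.t()) convention.
x = torch.randn(512, 4096, device="cuda")
w = torch.randn(4096, 4096, device="cuda")

qx, sx = quantize_per_tensor(x)
qw, sw = quantize_per_tensor(w)

# x @ w.t() ~= (sx * qx) @ (sw * qw).t() = (sx * sw) * (qx @ qw.t()),
# so both scales fold into the single alpha factor of the INT8 kernel.
y_int8 = gemm_int8.matmul(qx, qw, alpha=sx * sw)  # bfloat16, shape (M, N)

# BF16 reference; the gap reflects quantization error, not a kernel defect.
y_ref = x.bfloat16() @ w.bfloat16().t()
print((y_int8.float() - y_ref.float()).abs().mean())
```

Note that only scalar scales can be folded into `alpha`; per-row or per-channel scales would require a separate rescaling step applied to the BF16 output after the matmul.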