├── kernels └── hgemm │ ├── tools │ ├── clear.sh │ ├── install.sh │ ├── utils.py │ └── print_swizzle_layout.py │ ├── bench │ ├── NVIDIA_L20.png │ ├── NVIDIA_GeForce_RTX_4090.png │ ├── NVIDIA_GeForce_RTX_3080_Laptop_GPU_WSL2.png │ └── prof.py │ ├── .gitignore │ ├── mma │ ├── basic │ │ ├── .gitignore │ │ ├── hgemm_mma.cu │ │ └── hgemm_mma_stage_tn.cu │ ├── others │ │ └── .gitignore │ ├── swizzle │ │ └── .gitignore │ ├── hgemm_mma.cu │ └── hgemm_mma_stage_tn.cu │ ├── setup.py │ ├── makefile │ ├── pybind │ └── hgemm.cc │ ├── cublas │ └── hgemm_cublas.cu │ ├── utils │ └── utils.h │ ├── README.md │ ├── cutlass │ └── hgemm_mma_stage_tn_cute.cu │ └── hgemm.py ├── .gitmodules ├── .dev └── update_submodules.sh └── README.md /kernels/hgemm/tools/clear.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | rm -rf __pycache__ build dist toy_hgemm.egg-info *.bin 4 | 5 | set +x -------------------------------------------------------------------------------- /kernels/hgemm/bench/NVIDIA_L20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xlite-dev/HGEMM/HEAD/kernels/hgemm/bench/NVIDIA_L20.png -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third-party/cutlass"] 2 | path = third-party/cutlass 3 | url = https://github.com/NVIDIA/cutlass.git 4 | tag = v3.5.1 5 | -------------------------------------------------------------------------------- /kernels/hgemm/bench/NVIDIA_GeForce_RTX_4090.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xlite-dev/HGEMM/HEAD/kernels/hgemm/bench/NVIDIA_GeForce_RTX_4090.png -------------------------------------------------------------------------------- /kernels/hgemm/bench/NVIDIA_GeForce_RTX_3080_Laptop_GPU_WSL2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xlite-dev/HGEMM/HEAD/kernels/hgemm/bench/NVIDIA_GeForce_RTX_3080_Laptop_GPU_WSL2.png -------------------------------------------------------------------------------- /.dev/update_submodules.sh: -------------------------------------------------------------------------------- 1 | # update submodules 2 | set -x 3 | git submodule init 4 | git submodule update --remote # update all submodule 5 | git add . 
6 | git commit -m "Automated cutlass submodule update" 7 | set +x 8 | -------------------------------------------------------------------------------- /kernels/hgemm/tools/install.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | git submodule update --init --recursive --force 4 | python3 -m pip uninstall toy-hgemm -y 5 | python3 setup.py bdist_wheel && cd dist && python3 -m pip install *.whl && cd - 6 | rm -rf toy_hgemm.egg-info __pycache__ 7 | 8 | set +x -------------------------------------------------------------------------------- /kernels/hgemm/.gitignore: -------------------------------------------------------------------------------- 1 | *.so 2 | *.a 3 | *.dylib 4 | *.dll 5 | *.lib 6 | .DS_Store 7 | build 8 | *.whl 9 | tmp 10 | __pycache__ 11 | *.onnx 12 | *.engine 13 | *.pt 14 | *.pth 15 | *.nsys* 16 | *.ncu* 17 | *.sqlite* 18 | *.engine 19 | *.bin 20 | *.out 21 | *bin 22 | bin 23 | output 24 | *.egg-info 25 | *.whl 26 | dist 27 | *.pdf 28 | *.tex 29 | *.log 30 | *.md5 31 | *.aux* 32 | *.dpth 33 | -------------------------------------------------------------------------------- /kernels/hgemm/mma/basic/.gitignore: -------------------------------------------------------------------------------- 1 | *.so 2 | *.a 3 | *.dylib 4 | *.dll 5 | *.lib 6 | .DS_Store 7 | build 8 | *.whl 9 | tmp 10 | __pycache__ 11 | *.onnx 12 | *.engine 13 | *.pt 14 | *.pth 15 | *.nsys* 16 | *.ncu* 17 | *.sqlite* 18 | *.engine 19 | *.bin 20 | *.out 21 | *bin 22 | bin 23 | output 24 | *.egg-info 25 | *.whl 26 | dist 27 | *.pdf 28 | *.tex 29 | *.log 30 | *.md5 31 | *.aux* 32 | *.dpth 33 | -------------------------------------------------------------------------------- /kernels/hgemm/mma/others/.gitignore: -------------------------------------------------------------------------------- 1 | *.so 2 | *.a 3 | *.dylib 4 | *.dll 5 | *.lib 6 | .DS_Store 7 | build 8 | *.whl 9 | tmp 10 | __pycache__ 11 | *.onnx 12 | *.engine 13 | *.pt 14 | *.pth 15 | *.nsys* 16 | *.ncu* 17 | *.sqlite* 18 | *.engine 19 | *.bin 20 | *.out 21 | *bin 22 | bin 23 | output 24 | *.egg-info 25 | *.whl 26 | dist 27 | *.pdf 28 | *.tex 29 | *.log 30 | *.md5 31 | *.aux* 32 | *.dpth 33 | -------------------------------------------------------------------------------- /kernels/hgemm/mma/swizzle/.gitignore: -------------------------------------------------------------------------------- 1 | *.so 2 | *.a 3 | *.dylib 4 | *.dll 5 | *.lib 6 | .DS_Store 7 | build 8 | *.whl 9 | tmp 10 | __pycache__ 11 | *.onnx 12 | *.engine 13 | *.pt 14 | *.pth 15 | *.nsys* 16 | *.ncu* 17 | *.sqlite* 18 | *.engine 19 | *.bin 20 | *.out 21 | *bin 22 | bin 23 | output 24 | *.egg-info 25 | *.whl 26 | dist 27 | *.pdf 28 | *.tex 29 | *.log 30 | *.md5 31 | *.aux* 32 | *.dpth 33 | -------------------------------------------------------------------------------- /kernels/hgemm/setup.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import os 3 | from packaging.version import parse, Version 4 | from pathlib import Path 5 | from setuptools import setup, find_packages 6 | from torch.utils.cpp_extension import ( 7 | BuildExtension, 8 | CUDAExtension, 9 | CUDA_HOME, 10 | ) 11 | from tools.utils import (get_build_sources, get_build_cuda_cflags) 12 | 13 | # package name managed by pip, which can be remove by `pip uninstall toy-hgemm` 14 | PACKAGE_NAME = "toy-hgemm" 15 | 16 | ext_modules = [] 17 | generator_flag = [] 18 | cc_flag = [] 19 | cc_flag.append("-gencode") 20 | 
cc_flag.append("arch=compute_80,code=sm_80") 21 | cc_flag.append("-gencode") 22 | cc_flag.append("arch=compute_89,code=sm_89") 23 | 24 | 25 | # helper function to get cuda version 26 | def get_cuda_bare_metal_version(cuda_dir): 27 | raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) 28 | output = raw_output.split() 29 | release_idx = output.index("release") + 1 30 | bare_metal_version = parse(output[release_idx].split(",")[0]) 31 | 32 | return raw_output, bare_metal_version 33 | 34 | 35 | if CUDA_HOME is not None: 36 | _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) 37 | if bare_metal_version >= Version("11.8"): 38 | cc_flag.append("-gencode") 39 | cc_flag.append("arch=compute_90,code=sm_90") 40 | 41 | # ninja build does not work unless include_dirs are abs path 42 | this_dir = os.path.dirname(os.path.abspath(__file__)) 43 | 44 | # cuda module 45 | # may need export LD_LIBRARY_PATH=PATH-TO/torch/lib:$LD_LIBRARY_PATH 46 | ext_modules.append( 47 | CUDAExtension( 48 | # package name for import 49 | name="toy_hgemm", 50 | sources=get_build_sources(), 51 | extra_compile_args={ 52 | # add c compile flags 53 | "cxx": ["-O3", "-std=c++17"] + generator_flag, 54 | # add nvcc compile flags 55 | "nvcc": get_build_cuda_cflags(build_pkg=True) + generator_flag + cc_flag, 56 | }, 57 | include_dirs=[ 58 | Path(this_dir) / "naive", 59 | Path(this_dir) / "utils", 60 | Path(this_dir) / "wmma", 61 | Path(this_dir) / "mma" , 62 | Path(this_dir) / "cutlass" , 63 | Path(this_dir) / "cublas" , 64 | Path(this_dir) / "pybind" , 65 | ], 66 | ) 67 | ) 68 | 69 | setup( 70 | name=PACKAGE_NAME, 71 | version="0.1.0", 72 | packages=find_packages( 73 | exclude=( 74 | "build", 75 | "naive", 76 | "wmma", 77 | "mma", 78 | "cutlass", 79 | "cublas", 80 | "utils", 81 | "bench", 82 | "pybind", 83 | "tmp", 84 | ) 85 | ), 86 | description="My Toy HGEMM implement by CUDA", 87 | ext_modules=ext_modules, 88 | cmdclass={ "build_ext": BuildExtension}, 89 | python_requires=">=3.10", 90 | install_requires=[ 91 | "torch", 92 | "packaging", 93 | "ninja", 94 | ], 95 | ) 96 | 97 | 98 | 99 | -------------------------------------------------------------------------------- /kernels/hgemm/bench/prof.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | from torch.utils.cpp_extension import load 4 | from functools import partial 5 | from typing import Optional 6 | 7 | torch.set_grad_enabled(False) 8 | 9 | # # Load the CUDA kernel as a python module 10 | # lib = load(name='hgemm_lib', 11 | # sources=['hgemm.cu'], 12 | # extra_cuda_cflags=[ 13 | # "-O3", 14 | # "-U__CUDA_NO_HALF_OPERATORS__", 15 | # "-U__CUDA_NO_HALF_CONVERSIONS__", 16 | # "-U__CUDA_NO_HALF2_OPERATORS__", 17 | # "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", 18 | # "--expt-relaxed-constexpr", 19 | # "--expt-extended-lambda", 20 | # "--use_fast_math" 21 | # ], 22 | # extra_cflags=['-std=c++17']) 23 | 24 | 25 | def run_benchmark(perf_func: callable, 26 | a: torch.Tensor, b: torch.Tensor, 27 | tag: str, out: Optional[torch.Tensor] = None, 28 | warmup: int = 1, iters: int = 10, 29 | show_all: bool = False): 30 | if out is not None: 31 | out.fill_(0) 32 | if out is not None: 33 | for i in range(warmup): 34 | perf_func(a, b, out) 35 | else: 36 | for i in range(warmup): 37 | _ = perf_func(a, b) 38 | 39 | torch.cuda.synchronize() 40 | start = time.time() 41 | # iters 42 | if out is not None: 43 | for i in range(iters): 44 | perf_func(a, b, out) 45 | else: 46 | for i in 
range(iters): 47 | out = perf_func(a, b) 48 | torch.cuda.synchronize() 49 | end = time.time() 50 | total_time = (end - start) * 1000 # ms 51 | mean_time = total_time / iters 52 | out_info = f"out_{tag}" 53 | out_val = out.flatten().detach().cpu().numpy().tolist()[:3] 54 | out_val = [round(v, 8) for v in out_val] 55 | out_val = [f"{v:<12}" for v in out_val] 56 | print(f"{out_info:>32}: {out_val}, time:{mean_time:.6f}ms") 57 | if show_all: print(out) 58 | return out.clone(), mean_time 59 | 60 | 61 | # Ms = [1024, 2048, 4096] 62 | # Ns = [1024, 2048, 4096] 63 | # Ks = [256, 512, 1024] 64 | Ms = [4096] 65 | Ns = [4096] 66 | Ks = [1024] 67 | MNKs = [(M, N, K) for M in Ms for N in Ns for K in Ks] 68 | for (M, N, K) in MNKs: 69 | print("-" * 110) 70 | print(" " * 45 + f"M={M}, N={N}, K={K}") 71 | a = torch.randn((M, K)).cuda().half().contiguous() 72 | b = torch.randn((K, N)).cuda().half().contiguous() 73 | c = torch.randn((M, N)).cuda().half().contiguous() 74 | # run_benchmark(lib.hgemm_naive_f16, a, b, "f16", c) 75 | # run_benchmark(lib.hgemm_sliced_k_f16, a, b, "f16(sk)", c) 76 | # run_benchmark(lib.hgemm_t_4x4_sliced_k_f16x4_pack_bcf, a, b, "f16x4pack(t4x4bcf)", c) 77 | # run_benchmark(lib.hgemm_t_4x4_sliced_k_f16x4_pack_bcf_offset, a, b, "f16x4pack(t4x4offset)", c) 78 | # run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x4, a, b, "f16x4(t8x8sk)", c) 79 | # run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x4_bcf, a, b, "f16x4(t8x8bcf)", c) 80 | # run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x4_pack, a, b, "f16x4pack(t8x8sk)", c) 81 | # run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x4_pack_bcf, a, b, "f16x4pack(bcf)", c) 82 | # run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x4_pack_bcf_offset, a, b, "f16x4pack(bcf+offset)", c) 83 | # run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x8_pack_bcf, a, b, "f16x8pack(bcf)", c) 84 | # run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x8_pack_bcf_offset, a, b, "f16x8pack(bcf+offset)", c) 85 | # run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x8_pack_bcf_dbuf, a, b, "f16x8pack(dbuf)", c) 86 | run_benchmark(partial(torch.matmul, out=c), a, b, "f16_th") 87 | print("-" * 110) 88 | 89 | -------------------------------------------------------------------------------- /kernels/hgemm/makefile: -------------------------------------------------------------------------------- 1 | INCLUDE_DIRS=-I ./utils -I ../../third-party/cutlass/include -I ../../third-party/cutlass/tools/util/include 2 | ARCHS=-gencode arch=compute_80,code=sm_80 -gencode arch=compute_89,code=sm_89 3 | ARCHS_80=-gencode arch=compute_80,code=sm_80 4 | ARCHS_89=-gencode arch=compute_89,code=sm_89 5 | DEFAULT_FLAGS=-O2 $(ARCHS) -std=c++17 $(INCLUDE_DIRS) --expt-relaxed-constexpr -lcublas 6 | DEFAULT_FLAGS_89=-O2 $(ARCHS_89) -std=c++17 $(INCLUDE_DIRS) --expt-relaxed-constexpr -lcublas 7 | DEFAULT_FLAGS_80=-O2 $(ARCHS_80) -std=c++17 $(INCLUDE_DIRS) --expt-relaxed-constexpr -lcublas 8 | 9 | # Default 10 | default: 11 | nvcc cutlass/hgemm_mma_stage_tn_cute.cu -o hgemm_cute.bin $(DEFAULT_FLAGS) 12 | nvcc cublas/hgemm_cublas.cu -o hgemm_cublas.bin $(DEFAULT_FLAGS) 13 | nvcc mma/basic/hgemm_mma_stage.cu -o hgemm_mma_stage.bin $(DEFAULT_FLAGS) 14 | nvcc mma/basic/hgemm_mma_stage_tn.cu -o hgemm_mma_stage_tn.bin $(DEFAULT_FLAGS) 15 | nvcc mma/swizzle/hgemm_mma_stage_swizzle.cu -o hgemm_mma_stage_swizzle.bin $(DEFAULT_FLAGS) 16 | nvcc mma/swizzle/hgemm_mma_stage_tn_swizzle.cu -o hgemm_mma_stage_tn_swizzle.bin $(DEFAULT_FLAGS) 17 | nvcc mma/swizzle/hgemm_mma_stage_tn_swizzle_x4.cu -o hgemm_mma_stage_tn_swizzle_x4.bin $(DEFAULT_FLAGS) 18 | 19 | # SM 89 20 
| cute_89: 21 | nvcc cutlass/hgemm_mma_stage_tn_cute.cu -o hgemm_cute.89.bin $(DEFAULT_FLAGS_89) 22 | cute_89_debug: 23 | nvcc cutlass/hgemm_mma_stage_tn_cute.cu -o hgemm_cute.89.debug.bin $(DEFAULT_FLAGS_89) -DCUTE_HGEMM_DEBUG -Xcompiler "-Wno-format" 24 | # SM 89 NN debug 25 | mma_89: 26 | nvcc mma/basic/hgemm_mma_stage.cu -o hgemm_mma_stage.89.bin $(DEFAULT_FLAGS_89) 27 | mma_89_debug: 28 | nvcc mma/basic/hgemm_mma_stage.cu -o hgemm_mma_stage.89.debug.bin $(DEFAULT_FLAGS_89) -DHGEMM_MMA_DEBUG 29 | mma_89_swizzle: 30 | nvcc mma/swizzle/hgemm_mma_stage_swizzle.cu -o hgemm_mma_stage_swizzle.89.bin $(DEFAULT_FLAGS_89) 31 | mma_89_swizzle_debug: 32 | nvcc mma/swizzle/hgemm_mma_stage_swizzle.cu -o hgemm_mma_stage_swizzle.89.debug.bin $(DEFAULT_FLAGS_89) -DHGEMM_MMA_DEBUG 33 | # SM 89 TN debug 34 | mma_tn_89: 35 | nvcc mma/basic/hgemm_mma_stage_tn.cu -o hgemm_mma_stage_tn.89.bin $(DEFAULT_FLAGS_89) 36 | mma_tn_89_debug: 37 | nvcc mma/basic/hgemm_mma_stage_tn.cu -o hgemm_mma_stage_tn.89.debug.bin $(DEFAULT_FLAGS_89) -DHGEMM_MMA_DEBUG 38 | mma_tn_89_swizzle: 39 | nvcc mma/swizzle/hgemm_mma_stage_tn_swizzle.cu -o hgemm_mma_stage_tn_swizzle.89.bin $(DEFAULT_FLAGS_89) 40 | mma_tn_89_swizzle_debug: 41 | nvcc mma/swizzle/hgemm_mma_stage_tn_swizzle.cu -o hgemm_mma_stage_tn_swizzle.89.debug.bin $(DEFAULT_FLAGS_89) -DHGEMM_MMA_DEBUG 42 | mma_tn_89_swizzle_x2: 43 | nvcc mma/swizzle/hgemm_mma_stage_tn_swizzle_x2.cu -o hgemm_mma_stage_tn_swizzle_x2.89.bin $(DEFAULT_FLAGS_89) 44 | mma_tn_89_swizzle_x2_debug: 45 | nvcc mma/swizzle/hgemm_mma_stage_tn_swizzle_x2.cu -o hgemm_mma_stage_tn_swizzle_x2.89.debug.bin $(DEFAULT_FLAGS_89) -DHGEMM_MMA_DEBUG 46 | mma_tn_89_swizzle_x4: 47 | nvcc mma/swizzle/hgemm_mma_stage_tn_swizzle_x4.cu -o hgemm_mma_stage_tn_swizzle_x4.89.bin $(DEFAULT_FLAGS_89) 48 | mma_tn_89_swizzle_x4_debug: 49 | nvcc mma/swizzle/hgemm_mma_stage_tn_swizzle_x4.cu -o hgemm_mma_stage_tn_swizzle_x4.89.debug.bin $(DEFAULT_FLAGS_89) -DHGEMM_MMA_DEBUG 50 | 51 | # SM 80 52 | cute_80: 53 | nvcc cutlass/hgemm_mma_stage_tn_cute.cu -o hgemm_cute.80.bin $(DEFAULT_FLAGS_80) 54 | cute_80_debug: 55 | nvcc cutlass/hgemm_mma_stage_tn_cute.cu -o hgemm_cute.80.debug.bin $(DEFAULT_FLAGS_80) -DCUTE_HGEMM_DEBUG -Xcompiler "-Wno-format" 56 | # SM 80 TN debug 57 | mma_80: 58 | nvcc mma/basic/hgemm_mma_stage.cu -o hgemm_mma_stage.80.bin $(DEFAULT_FLAGS_80) 59 | mma_80_debug: 60 | nvcc mma/basic/hgemm_mma_stage.cu -o hgemm_mma_stage.80.debug.bin $(DEFAULT_FLAGS_80) -DHGEMM_MMA_DEBUG 61 | mma_80_swizzle: 62 | nvcc mma/swizzle/hgemm_mma_stage_swizzle.cu -o hgemm_mma_stage_swizzle.80.bin $(DEFAULT_FLAGS_80) 63 | mma_80_swizzle_debug: 64 | nvcc mma/swizzle/hgemm_mma_stage_swizzle.cu -o hgemm_mma_stage_swizzle.80.debug.bin $(DEFAULT_FLAGS_80) -DHGEMM_MMA_DEBUG 65 | # SM 80 TN debug 66 | mma_tn_80: 67 | nvcc mma/basic/hgemm_mma_stage_tn.cu -o hgemm_mma_stage_tn.80.bin $(DEFAULT_FLAGS_80) 68 | mma_tn_80_debug: 69 | nvcc mma/basic/hgemm_mma_stage_tn.cu -o hgemm_mma_stage_tn.80.debug.bin $(DEFAULT_FLAGS_80) -DHGEMM_MMA_DEBUG 70 | mma_tn_80_swizzle: 71 | nvcc mma/swizzle/hgemm_mma_stage_tn_swizzle.cu -o hgemm_mma_stage_tn_swizzle.80.bin $(DEFAULT_FLAGS_80) 72 | mma_tn_80_swizzle_debug: 73 | nvcc mma/swizzle/hgemm_mma_stage_tn_swizzle.cu -o hgemm_mma_stage_tn_swizzle.80.debug.bin $(DEFAULT_FLAGS_80) -DHGEMM_MMA_DEBUG 74 | mma_tn_80_swizzle_x4: 75 | nvcc mma/swizzle/hgemm_mma_stage_tn_swizzle_x4.cu -o hgemm_mma_stage_tn_swizzle_x4.80.bin $(DEFAULT_FLAGS_80) 76 | mma_tn_80_swizzle_x4_debug: 77 | nvcc 
mma/swizzle/hgemm_mma_stage_tn_swizzle_x4.cu -o hgemm_mma_stage_tn_swizzle_x4.80.debug.bin $(DEFAULT_FLAGS_80) -DHGEMM_MMA_DEBUG 78 | 79 | clean: 80 | rm -rf *.bin & rm -rf ./bin 81 | -------------------------------------------------------------------------------- /kernels/hgemm/tools/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.cpp_extension import load 4 | 5 | 6 | def get_device_name(): 7 | device_name = torch.cuda.get_device_name(torch.cuda.current_device()) 8 | # since we will run GPU on WSL2, so add WSL2 tag. 9 | if "Laptop" in device_name: 10 | device_name += " WSL2" 11 | return device_name 12 | 13 | 14 | def get_device_capability(): 15 | return torch.cuda.get_device_capability(torch.cuda.current_device()) 16 | 17 | 18 | def get_build_sources(): 19 | build_sources = [] 20 | build_sources.append('naive/hgemm.cu') 21 | build_sources.append('naive/hgemm_async.cu') 22 | build_sources.append('cublas/hgemm_cublas.cu') 23 | build_sources.append('wmma/hgemm_wmma.cu') 24 | build_sources.append('wmma/hgemm_wmma_stage.cu') 25 | build_sources.append('mma/basic/hgemm_mma.cu') 26 | build_sources.append('mma/basic/hgemm_mma_stage.cu') 27 | build_sources.append('mma/basic/hgemm_mma_stage_tn.cu') 28 | build_sources.append('mma/swizzle/hgemm_mma_stage_swizzle.cu') 29 | build_sources.append('mma/swizzle/hgemm_mma_stage_tn_swizzle_x4.cu') 30 | build_sources.append('cutlass/hgemm_mma_stage_tn_cute.cu') 31 | build_sources.append('pybind/hgemm.cc') 32 | return build_sources 33 | 34 | 35 | def get_project_dir(): 36 | return os.path.dirname(os.path.dirname( 37 | os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) 38 | 39 | 40 | def get_build_cuda_cflags(build_pkg: bool = False): 41 | # -Xptxas -v: 42 | # registers, smem, cmem, stack, gmem usage 43 | # registers: 寄存器,访问速度最快。Ada Lovelace架构每个SM的寄存器文件大小 44 | # 为256KB,这相当于65536个32位寄存器,65536/256=256。一个SM可以同时执行多 45 | # 个block,对一个Kernel,同时存在于一个SM中的Block和Warp数量取决于SM中可用 46 | # 且所需的寄存器和共享内存数量。每个Thread需要的寄存器越多,那么SM中的Warp就 47 | # 越少。即减少Thread所需寄存器数量,即可增加SM中的Warp数。每个Block需要的共 48 | # 享内存越多,那么SM中可以被同时处理的Block就会变少。即减少每个Block所需的共 49 | # 享内存,即可同时处理更多Block。SM内的资源没办法处理一个完整Block,Kernel 50 | # 将无法启动。 51 | # cmem: 常量内存,被缓存,访问速度快。 52 | # stack frame: 由于寄存器的数量有限,当需要使用的变量数量超过可用寄存器数量时, 53 | # 编译器会将某些变量从寄存器“溢出”到栈上,这个过程称为spill。访问栈上的数据比 54 | # 访问寄存器慢得多。 55 | # spill stores: 指的是在执行过程中,数据因为寄存器不足而被存储到了栈上。 56 | # spill loads: 则是指将之前溢出到栈上的数据重新加载回寄存器。 57 | # diag 177: variable was declared but never referenced 58 | extra_cuda_cflags = [] 59 | extra_cuda_cflags.append("-O3") 60 | extra_cuda_cflags.append("-std=c++17") 61 | extra_cuda_cflags.append("-U__CUDA_NO_HALF_OPERATORS__") 62 | extra_cuda_cflags.append("-U__CUDA_NO_HALF_CONVERSIONS__") 63 | extra_cuda_cflags.append("-U__CUDA_NO_HALF2_OPERATORS__") 64 | extra_cuda_cflags.append("-U__CUDA_NO_BFLOAT16_CONVERSIONS__") 65 | extra_cuda_cflags.append("--expt-relaxed-constexpr") 66 | extra_cuda_cflags.append("--expt-extended-lambda") 67 | extra_cuda_cflags.append("--use_fast_math") 68 | if not build_pkg: 69 | extra_cuda_cflags.append("-diag-suppress 177") 70 | extra_cuda_cflags.append("-Xptxas -v") 71 | else: 72 | extra_cuda_cflags.append("--ptxas-options=-v") 73 | extra_cuda_cflags.append("--ptxas-options=-O3") 74 | # extra cuda flags for cute hgemm 75 | project_dir = get_project_dir() 76 | extra_cuda_cflags.append('-DNO_MMA_HGEMM_BIN') 77 | extra_cuda_cflags.append('-DNO_WMMA_HGEMM_BIN') 78 | extra_cuda_cflags.append('-DNO_CUTE_HGEMM_BIN') 
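    # note: the -DNO_*_HGEMM_BIN defines drop the standalone benchmark main() entry
    # points -- cublas/hgemm_cublas.cu guards its main() with #ifndef NO_CUBLAS_HGEMM_BIN,
    # and the other binaries appear to follow the same pattern -- so only the PyTorch
    # bindings are compiled into the extension build.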
79 | extra_cuda_cflags.append('-DNO_CUBLAS_HGEMM_BIN') 80 | # add cutlass headers and link cublas. 81 | extra_cuda_cflags.append(f'-I {project_dir}/kernels/hgemm') 82 | extra_cuda_cflags.append(f'-I {project_dir}/kernels/hgemm/utils') 83 | extra_cuda_cflags.append(f'-I {project_dir}/kernels/hgemm/naive') 84 | extra_cuda_cflags.append(f'-I {project_dir}/kernels/hgemm/wmma') 85 | extra_cuda_cflags.append(f'-I {project_dir}/kernels/hgemm/mma/basic') 86 | extra_cuda_cflags.append(f'-I {project_dir}/kernels/hgemm/mma/swizzle') 87 | extra_cuda_cflags.append(f'-I {project_dir}/kernels/hgemm/cutlass') 88 | extra_cuda_cflags.append(f'-I {project_dir}/kernels/hgemm/cublas') 89 | extra_cuda_cflags.append(f'-I {project_dir}/kernels/hgemm/pybind') 90 | extra_cuda_cflags.append(f'-I {project_dir}/third-party/cutlass/include') 91 | extra_cuda_cflags.append(f'-I {project_dir}/third-party/cutlass/tools/util/include') 92 | extra_cuda_cflags.append('-lcublas') 93 | return extra_cuda_cflags 94 | 95 | 96 | def pretty_print_line(m: str = "", sep: str = "-", width: int = 150): 97 | res_len = width - len(m) 98 | left_len = int(res_len / 2) 99 | right_len = res_len - left_len 100 | pretty_line = sep * left_len + m + sep * right_len 101 | print(pretty_line) 102 | 103 | 104 | def build_from_sources(verbose: bool = False): 105 | torch_arch_list_env = os.environ.get("TORCH_CUDA_ARCH_LIST", None) 106 | # Load the CUDA kernel as a python module 107 | pretty_print_line(f"Loading hgemm lib on device: {get_device_name()}, " 108 | f"capability: {get_device_capability()}, " 109 | f"Arch ENV: {torch_arch_list_env}") 110 | return load(name='hgemm_lib', sources=get_build_sources(), 111 | extra_cuda_cflags=get_build_cuda_cflags(), 112 | extra_cflags=['-std=c++17'], 113 | verbose=verbose) 114 | 115 | 116 | def try_load_hgemm_library(force_build: bool = False, verbose: bool = False): 117 | if not force_build: 118 | # check if can import toy_hgemm 119 | try: 120 | import toy_hgemm as hgemm 121 | pretty_print_line(f"Import toy-hgemm library done, use it!") 122 | except Exception: 123 | pretty_print_line(f"Can't import toy-hgemm, force build " 124 | f"from source or run ") 125 | pretty_print_line(f"Also may need export LD_LIBRARY_PATH=" 126 | f"PATH-TO/torch/lib:$LD_LIBRARY_PATH") 127 | hgemm = build_from_sources(verbose=verbose) 128 | else: 129 | pretty_print_line("Force hgemm lib build from sources") 130 | hgemm = build_from_sources(verbose=verbose) 131 | 132 | return hgemm 133 | 134 | 135 | @torch.no_grad 136 | def as_col_major(x: torch.Tensor): 137 | # convert a row major tensor -> col major with contiguous storage 138 | x_trans = x.t() 139 | x_col_major = x_trans.reshape(x.shape) 140 | return x_col_major.contiguous() # must be a contiguous tensor 141 | -------------------------------------------------------------------------------- /kernels/hgemm/pybind/hgemm.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #define STRINGFY(str) #str 5 | #define TORCH_BINDING_COMMON_EXTENSION(func) m.def(STRINGFY(func), &func, STRINGFY(func)); 6 | 7 | // from hgemm.cu 8 | void hgemm_naive_f16(torch::Tensor a, torch::Tensor b, torch::Tensor c); 9 | void hgemm_sliced_k_f16(torch::Tensor a, torch::Tensor b, torch::Tensor c); 10 | void hgemm_t_8x8_sliced_k_f16x4(torch::Tensor a, torch::Tensor b, torch::Tensor c); 11 | void hgemm_t_8x8_sliced_k_f16x4_pack(torch::Tensor a, torch::Tensor b, torch::Tensor c); 12 | void hgemm_t_8x8_sliced_k_f16x4_bcf(torch::Tensor a, torch::Tensor b, 
torch::Tensor c); 13 | void hgemm_t_8x8_sliced_k_f16x4_pack_bcf(torch::Tensor a, torch::Tensor b, torch::Tensor c); 14 | void hgemm_t_8x8_sliced_k_f16x8_pack_bcf(torch::Tensor a, torch::Tensor b, torch::Tensor c); 15 | void hgemm_t_8x8_sliced_k_f16x8_pack_bcf_dbuf(torch::Tensor a, torch::Tensor b, torch::Tensor c); 16 | // from hgemm_async.cu 17 | void hgemm_t_8x8_sliced_k16_f16x8_pack_dbuf(torch::Tensor a, torch::Tensor b, torch::Tensor c); 18 | void hgemm_t_8x8_sliced_k16_f16x8_pack_dbuf_async(torch::Tensor a, torch::Tensor b, torch::Tensor c); 19 | void hgemm_t_8x8_sliced_k32_f16x8_pack_dbuf(torch::Tensor a, torch::Tensor b, torch::Tensor c); 20 | void hgemm_t_8x8_sliced_k32_f16x8_pack_dbuf_async(torch::Tensor a, torch::Tensor b, torch::Tensor c); 21 | void hgemm_t_16x8_sliced_k32_f16x8_pack_dbuf(torch::Tensor a, torch::Tensor b, torch::Tensor c); 22 | void hgemm_t_16x8_sliced_k32_f16x8_pack_dbuf_async(torch::Tensor a, torch::Tensor b, torch::Tensor c); 23 | // from hgemm_cublas.cu 24 | void init_cublas_handle(); 25 | void destroy_cublas_handle(); 26 | void hgemm_cublas_tensor_op_nn(torch::Tensor a, torch::Tensor b, torch::Tensor c); 27 | void hgemm_cublas_tensor_op_tn(torch::Tensor a, torch::Tensor b, torch::Tensor c); 28 | // from hgemm_wmma.cu 29 | void hgemm_wmma_m16n16k16_naive(torch::Tensor a, torch::Tensor b, torch::Tensor c); 30 | void hgemm_wmma_m16n16k16_mma4x2(torch::Tensor a, torch::Tensor b, torch::Tensor c); 31 | void hgemm_wmma_m16n16k16_mma4x2_warp2x4(torch::Tensor a, torch::Tensor b, torch::Tensor c); 32 | void hgemm_wmma_m16n16k16_mma4x2_warp2x4_dbuf_async(torch::Tensor a, torch::Tensor b, torch::Tensor c); 33 | void hgemm_wmma_m32n8k16_mma2x4_warp2x4_dbuf_async(torch::Tensor a, torch::Tensor b, torch::Tensor c); 34 | // from hgemm_wmma_stage.cu 35 | void hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); 36 | void hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); 37 | void hgemm_wmma_m16n16k16_mma4x2_warp4x4_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); 38 | void hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); 39 | // from hgemm_mma.cu 40 | void hgemm_mma_m16n8k16_naive(torch::Tensor a, torch::Tensor b, torch::Tensor c); 41 | void hgemm_mma_m16n8k16_mma2x4_warp4x4(torch::Tensor a, torch::Tensor b, torch::Tensor c); 42 | // from hgemm_mma_stage.cu 43 | void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); 44 | void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); 45 | void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); 46 | void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_x4(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); 47 | void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_rr(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); 48 | // from hgemm_mma_stage_tn.cu 49 | void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn(torch::Tensor a, 
torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); 50 | // from hgemm_mma_stage_tn_cute.cu 51 | void hgemm_mma_stages_block_swizzle_tn_cute(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); 52 | // from hgemm_mma_stage_swizzle.cu 53 | void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); 54 | // from hgemm_mma_stage_tn_swizzle_x4s.cu 55 | void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_tn_swizzle_x4(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); 56 | 57 | 58 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 59 | // CUDA Cores FP16 60 | TORCH_BINDING_COMMON_EXTENSION(hgemm_naive_f16) 61 | TORCH_BINDING_COMMON_EXTENSION(hgemm_sliced_k_f16) 62 | TORCH_BINDING_COMMON_EXTENSION(hgemm_t_8x8_sliced_k_f16x4) 63 | TORCH_BINDING_COMMON_EXTENSION(hgemm_t_8x8_sliced_k_f16x4_pack) 64 | TORCH_BINDING_COMMON_EXTENSION(hgemm_t_8x8_sliced_k_f16x4_bcf) 65 | TORCH_BINDING_COMMON_EXTENSION(hgemm_t_8x8_sliced_k_f16x4_pack_bcf) 66 | TORCH_BINDING_COMMON_EXTENSION(hgemm_t_8x8_sliced_k_f16x8_pack_bcf) 67 | TORCH_BINDING_COMMON_EXTENSION(hgemm_t_8x8_sliced_k_f16x8_pack_bcf_dbuf) 68 | // Copy Async 69 | TORCH_BINDING_COMMON_EXTENSION(hgemm_t_8x8_sliced_k16_f16x8_pack_dbuf) 70 | TORCH_BINDING_COMMON_EXTENSION(hgemm_t_8x8_sliced_k16_f16x8_pack_dbuf_async) 71 | TORCH_BINDING_COMMON_EXTENSION(hgemm_t_8x8_sliced_k32_f16x8_pack_dbuf) 72 | TORCH_BINDING_COMMON_EXTENSION(hgemm_t_8x8_sliced_k32_f16x8_pack_dbuf_async) 73 | TORCH_BINDING_COMMON_EXTENSION(hgemm_t_16x8_sliced_k32_f16x8_pack_dbuf) 74 | TORCH_BINDING_COMMON_EXTENSION(hgemm_t_16x8_sliced_k32_f16x8_pack_dbuf_async) 75 | // cuBLAS Tensor Cores 76 | TORCH_BINDING_COMMON_EXTENSION(init_cublas_handle) 77 | TORCH_BINDING_COMMON_EXTENSION(destroy_cublas_handle) 78 | TORCH_BINDING_COMMON_EXTENSION(hgemm_cublas_tensor_op_nn) 79 | TORCH_BINDING_COMMON_EXTENSION(hgemm_cublas_tensor_op_tn) 80 | // WMMA API Tensor Cores 81 | TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_naive) 82 | TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_mma4x2) 83 | TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_mma4x2_warp2x4) 84 | TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_mma4x2_warp2x4_dbuf_async) 85 | TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m32n8k16_mma2x4_warp2x4_dbuf_async) 86 | // stage, thread block swizzle, dsmem 87 | TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages) 88 | TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem) 89 | TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_mma4x2_warp4x4_stages_dsmem) 90 | TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem) 91 | // MMA API Tensor Cores 92 | TORCH_BINDING_COMMON_EXTENSION(hgemm_mma_m16n8k16_naive) 93 | TORCH_BINDING_COMMON_EXTENSION(hgemm_mma_m16n8k16_mma2x4_warp4x4) 94 | // stage, thread block swizzle, dsmem, reg double buffers 95 | TORCH_BINDING_COMMON_EXTENSION(hgemm_mma_m16n8k16_mma2x4_warp4x4_stages) 96 | TORCH_BINDING_COMMON_EXTENSION(hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem) 97 | TORCH_BINDING_COMMON_EXTENSION(hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem) 98 | TORCH_BINDING_COMMON_EXTENSION(hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_x4) 99 | TORCH_BINDING_COMMON_EXTENSION(hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_rr) 100 | // smem swizzle 101 | 
TORCH_BINDING_COMMON_EXTENSION(hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle) 102 | // TN: A row major MxK, B col major NxK, C row major MxN 103 | TORCH_BINDING_COMMON_EXTENSION(hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn) 104 | TORCH_BINDING_COMMON_EXTENSION(hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_tn_swizzle_x4) 105 | // TN: cute hgemm with smem & block swizzle 106 | TORCH_BINDING_COMMON_EXTENSION(hgemm_mma_stages_block_swizzle_tn_cute) 107 | } 108 | 109 | -------------------------------------------------------------------------------- /kernels/hgemm/cublas/hgemm_cublas.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include "cublas_v2.h" 12 | 13 | static cublasHandle_t g_handle = nullptr; 14 | 15 | void init_cublas_handle() { 16 | if (g_handle == nullptr) { 17 | cublasStatus_t status = cublasCreate(&g_handle); 18 | if (status != CUBLAS_STATUS_SUCCESS) { 19 | printf("Failed to create cuBLAS handle: %d", status); 20 | exit(EXIT_FAILURE); 21 | } 22 | status = cublasSetMathMode(g_handle, CUBLAS_TENSOR_OP_MATH); 23 | if (status != CUBLAS_STATUS_SUCCESS) { 24 | printf("Failed to set cuBLAS Math Mode: %d", status); 25 | exit(EXIT_FAILURE); 26 | } 27 | } 28 | } 29 | 30 | void destroy_cublas_handle() { 31 | if (g_handle != nullptr) { 32 | cublasStatus_t status = cublasDestroy(g_handle); 33 | if (status != CUBLAS_STATUS_SUCCESS) { 34 | printf("Failed to destroy cuBLAS handle: %d", status); 35 | } 36 | g_handle = nullptr; 37 | } 38 | } 39 | 40 | // NN: A/B/C All row major 41 | void cublas_tensor_op_nn(half *A, half *B, half *C, size_t M, size_t N, size_t K) { 42 | 43 | static half alpha = 1.0; 44 | static half beta = 0.0; 45 | 46 | if (g_handle == nullptr) { 47 | init_cublas_handle(); 48 | } 49 | 50 | cublasGemmEx(g_handle, 51 | CUBLAS_OP_N, 52 | CUBLAS_OP_N, 53 | N, M, K, 54 | &alpha, 55 | B, CUDA_R_16F, N, 56 | A, CUDA_R_16F, K, 57 | &beta, 58 | C, CUDA_R_16F, N, 59 | CUBLAS_COMPUTE_16F, 60 | CUBLAS_GEMM_DEFAULT_TENSOR_OP); 61 | } 62 | 63 | // TN: A row major MxK, B col major NxK, C row major MxN 64 | void cublas_tensor_op_tn(half *A, half *B, half *C, size_t M, size_t N, size_t K) { 65 | 66 | static half alpha = 1.0; 67 | static half beta = 0.0; 68 | 69 | if (g_handle == nullptr) { 70 | init_cublas_handle(); 71 | } 72 | 73 | cublasGemmEx(g_handle, 74 | CUBLAS_OP_T, 75 | CUBLAS_OP_N, 76 | N, M, K, 77 | &alpha, 78 | B, CUDA_R_16F, K, 79 | A, CUDA_R_16F, K, 80 | &beta, 81 | C, CUDA_R_16F, N, 82 | CUBLAS_COMPUTE_16F, 83 | CUBLAS_GEMM_DEFAULT_TENSOR_OP); 84 | } 85 | 86 | // build cpp binary 87 | #ifndef NO_CUBLAS_HGEMM_BIN 88 | 89 | // pass the cuBLAS handle from outside to avoid error. 
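// (note) cuBLAS assumes column-major storage, so the row-major C (MxN) is produced
// as its column-major view C^T (NxM) = op(B) * op(A): B goes first with CUBLAS_OP_T
// and ld = K (its data is K-contiguous), A goes second with CUBLAS_OP_N and ld = K
// (row-major MxK re-read as column-major KxM), and C uses ld = N.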
90 | void cublas_tensor_op_tn_v2(cublasHandle_t handle, 91 | half *A, half *B, half *C, 92 | size_t M, size_t N, size_t K) { 93 | half alpha = 1.0; 94 | half beta = 0.0; 95 | 96 | cublasGemmEx(handle, 97 | CUBLAS_OP_T, 98 | CUBLAS_OP_N, 99 | N, M, K, 100 | &alpha, 101 | B, CUDA_R_16F, K, 102 | A, CUDA_R_16F, K, 103 | &beta, 104 | C, CUDA_R_16F, N, 105 | CUBLAS_COMPUTE_16F, 106 | CUBLAS_GEMM_DEFAULT_TENSOR_OP); 107 | } 108 | 109 | float perf_cublas_tn(int M, int N, int K, int repeat) { 110 | size_t size_a = M * K * sizeof(half); 111 | size_t size_b = K * N * sizeof(half); 112 | size_t size_c = M * N * sizeof(half); 113 | 114 | half *d_a, *d_b; 115 | half *d_c; 116 | cudaMalloc(&d_a, size_a); 117 | cudaMalloc(&d_b, size_b); 118 | cudaMalloc(&d_c, size_c); 119 | 120 | cublasHandle_t handle = nullptr; 121 | cublasCreate(&handle); 122 | cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH); 123 | 124 | // warmup 125 | for (int i = 0; i < 10; ++i) { 126 | cublas_tensor_op_tn_v2(handle, d_a, d_b, d_c, M, N, K); 127 | } 128 | cudaDeviceSynchronize(); 129 | 130 | cudaEvent_t start, end; 131 | cudaEventCreate(&start); 132 | cudaEventCreate(&end); 133 | cudaEventRecord(start); 134 | 135 | for (int i = 0; i < repeat; i++) { 136 | cublas_tensor_op_tn_v2(handle, d_a, d_b, d_c, M, N, K); 137 | } 138 | 139 | cudaEventRecord(end); 140 | cudaDeviceSynchronize(); 141 | cudaEventSynchronize(end); 142 | 143 | float msec, sec; 144 | cudaEventElapsedTime(&msec, start, end); 145 | sec = msec / 1000.0 / repeat; 146 | 147 | cudaFree(d_a); 148 | cudaFree(d_b); 149 | cudaFree(d_c); 150 | cudaEventDestroy(start); 151 | cudaEventDestroy(end); 152 | cublasDestroy(handle); 153 | 154 | return sec; 155 | } 156 | 157 | int main(int argc, char *argv[]) { 158 | const int test_num = 64; 159 | int M_list[test_num]; 160 | int N_list[test_num]; 161 | int K_list[test_num]; 162 | 163 | for (int i = 0; i < test_num; i++) { 164 | M_list[i] = (i + 1) * 256; 165 | N_list[i] = (i + 1) * 256; 166 | K_list[i] = (i + 1) * 256; 167 | } 168 | 169 | const int outer_repeat = 10, inner_repeat = 1; 170 | 171 | printf("ALGO = cuBLAS CUBLAS_GEMM_DEFAULT_TENSOR_OP TN\n"); 172 | 173 | for (int j = 0; j < test_num; j++) { 174 | int M = M_list[j], N = N_list[j], K = K_list[j]; 175 | 176 | double max_sec = 0.0; 177 | double min_sec = DBL_MAX; 178 | double total_sec = 0.0; 179 | 180 | for (int k = 0; k < outer_repeat; k++) { 181 | double this_sec = perf_cublas_tn(M, N, K, inner_repeat); 182 | max_sec = max(max_sec, this_sec); 183 | min_sec = min(min_sec, this_sec); 184 | total_sec += this_sec; 185 | } 186 | 187 | // 1 TFLOPS = 10^12 FLOPS 188 | // ref: https://imgtec.eetrend.com/blog/2021/100062210.html. 
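    // an MxNxK GEMM performs M*N*K fused multiply-adds, i.e. 2*M*N*K FLOPs, hence
    // TFLOPS = 2*M*N*K * 1e-12 / seconds (e.g. M=N=K=4096 is ~137.4 GFLOP per call).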
189 | double avg_sec = total_sec / outer_repeat; 190 | double avg_Tflops = ((double)M) * N * K * 2 * 1e-12 / avg_sec; 191 | 192 | printf("M N K = %6d %6d %6d, ", M, N, K); 193 | printf("Time = %12.8lf %12.8lf %12.8lf s, ", min_sec, avg_sec, max_sec); 194 | printf("AVG Performance = %10.4lf Tflops\n", avg_Tflops); 195 | } 196 | 197 | return 0; 198 | } 199 | // build torch python binding 200 | #else 201 | // --------------------- PyTorch bindings for custom kernel ----------------------- 202 | #include 203 | #include 204 | 205 | #define STRINGFY(str) #str 206 | #define TORCH_BINDING_COMMON_EXTENSION(func) \ 207 | m.def(STRINGFY(func), &func, STRINGFY(func)); 208 | 209 | #define CHECK_TORCH_TENSOR_DTYPE(T, th_type) \ 210 | if(((T).options().dtype() != (th_type))) { \ 211 | std::cout << "Tensor Info:" << (T).options() << std::endl; \ 212 | throw std::runtime_error("values must be "#th_type); \ 213 | } 214 | 215 | #define CHECK_TORCH_TENSOR_SHAPE(T, S0, S1) \ 216 | if (((T).size(0) != (S0)) || ((T).size(1) != (S1))) { \ 217 | throw std::runtime_error("Tensor size mismatch!"); \ 218 | } 219 | 220 | // NN: A/B/C All row major 221 | void hgemm_cublas_tensor_op_nn( 222 | torch::Tensor a, torch::Tensor b, torch::Tensor c) { 223 | CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf) 224 | CHECK_TORCH_TENSOR_DTYPE(b, torch::kHalf) 225 | CHECK_TORCH_TENSOR_DTYPE(c, torch::kHalf) 226 | const int M = a.size(0); 227 | const int K = a.size(1); 228 | const int N = b.size(1); 229 | CHECK_TORCH_TENSOR_SHAPE(a, M, K) 230 | CHECK_TORCH_TENSOR_SHAPE(b, K, N) 231 | CHECK_TORCH_TENSOR_SHAPE(c, M, N) 232 | 233 | cublas_tensor_op_nn( 234 | reinterpret_cast(a.data_ptr()), 235 | reinterpret_cast(b.data_ptr()), 236 | reinterpret_cast(c.data_ptr()), 237 | M, N, K 238 | ); 239 | } 240 | 241 | // TN: A row major MxK, B col major KxN, C row major MxN 242 | void hgemm_cublas_tensor_op_tn( 243 | torch::Tensor a, torch::Tensor b, torch::Tensor c) { 244 | CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf) 245 | CHECK_TORCH_TENSOR_DTYPE(b, torch::kHalf) 246 | CHECK_TORCH_TENSOR_DTYPE(c, torch::kHalf) 247 | const int M = a.size(0); 248 | const int K = a.size(1); 249 | const int N = b.size(1); 250 | CHECK_TORCH_TENSOR_SHAPE(a, M, K) 251 | CHECK_TORCH_TENSOR_SHAPE(b, K, N) 252 | CHECK_TORCH_TENSOR_SHAPE(c, M, N) 253 | 254 | cublas_tensor_op_tn( 255 | reinterpret_cast(a.data_ptr()), 256 | reinterpret_cast(b.data_ptr()), 257 | reinterpret_cast(c.data_ptr()), 258 | M, N, K 259 | ); 260 | } 261 | #endif 262 | -------------------------------------------------------------------------------- /kernels/hgemm/utils/utils.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | template 7 | float perf_gemm( 8 | void (*gpu_hgemm) (T *, T *, T *, int, int, int), 9 | int M, int N, int K, int repeat, int warmup = 1) { 10 | 11 | size_t size_a = M * K * sizeof(T); 12 | size_t size_b = K * N * sizeof(T); 13 | size_t size_c = M * N * sizeof(T); 14 | 15 | T *d_a, *d_b; 16 | T *d_c; 17 | cudaMalloc(&d_a, size_a); 18 | cudaMalloc(&d_b, size_b); 19 | cudaMalloc(&d_c, size_c); 20 | 21 | // warmup 22 | for (int i = 0; i < warmup; ++i){ 23 | gpu_hgemm(d_a, d_b, d_c, M, N, K); 24 | } 25 | cudaDeviceSynchronize(); 26 | 27 | cudaEvent_t start, end; 28 | cudaEventCreate(&start); 29 | cudaEventCreate(&end); 30 | cudaEventRecord(start); 31 | for (int i = 0; i < repeat; i++) { 32 | gpu_hgemm(d_a, d_b, d_c, M, N, K); 33 | } 34 | cudaEventRecord(end); 35 | cudaDeviceSynchronize(); 36 | 
cudaEventSynchronize(end); 37 | 38 | float msec, sec; 39 | cudaEventElapsedTime(&msec, start, end); 40 | sec = msec / 1000.0 / repeat; 41 | 42 | cudaFree(d_a); 43 | cudaFree(d_b); 44 | cudaFree(d_c); 45 | cudaEventDestroy(start); 46 | cudaEventDestroy(end); 47 | 48 | return sec; 49 | } 50 | 51 | 52 | template 53 | float perf_gemm_swizzle( 54 | void (*gpu_hgemm) (T *, T *, T *, int, int, int, int), 55 | int M, int N, int K, int swizzle_stride, int repeat, int warmup = 1) { 56 | 57 | size_t size_a = M * K * sizeof(T); 58 | size_t size_b = K * N * sizeof(T); 59 | size_t size_c = M * N * sizeof(T); 60 | 61 | T *d_a, *d_b; 62 | T *d_c; 63 | cudaMalloc(&d_a, size_a); 64 | cudaMalloc(&d_b, size_b); 65 | cudaMalloc(&d_c, size_c); 66 | 67 | // warmup 68 | for (int i = 0; i < warmup; ++i){ 69 | gpu_hgemm(d_a, d_b, d_c, M, N, K, swizzle_stride); 70 | } 71 | cudaDeviceSynchronize(); 72 | 73 | cudaEvent_t start, end; 74 | cudaEventCreate(&start); 75 | cudaEventCreate(&end); 76 | cudaEventRecord(start); 77 | for (int i = 0; i < repeat; i++) { 78 | gpu_hgemm(d_a, d_b, d_c, M, N, K, swizzle_stride); 79 | } 80 | cudaEventRecord(end); 81 | cudaDeviceSynchronize(); 82 | cudaEventSynchronize(end); 83 | 84 | float msec, sec; 85 | cudaEventElapsedTime(&msec, start, end); 86 | sec = msec / 1000.0 / repeat; 87 | 88 | cudaFree(d_a); 89 | cudaFree(d_b); 90 | cudaFree(d_c); 91 | cudaEventDestroy(start); 92 | cudaEventDestroy(end); 93 | 94 | return sec; 95 | } 96 | 97 | 98 | template 99 | float gemm_error_check_tn( 100 | void (*gpu_hgemm) (T *, T *, T *, int, int, int), 101 | int M, int N, int K) { 102 | 103 | size_t size_a = M * K * sizeof(T); 104 | size_t size_b = K * N * sizeof(T); 105 | size_t size_c = M * N * sizeof(T); 106 | 107 | T *h_a, *h_b, *h_c, *h_c_ref; 108 | T *d_a, *d_b, *d_c, *d_c_ref; 109 | 110 | h_a = (T *)malloc(size_a); 111 | h_b = (T *)malloc(size_b); 112 | h_c = (T *)malloc(size_c); 113 | h_c_ref = (T *)malloc(size_c); 114 | 115 | cudaMalloc(&d_a, size_a); 116 | cudaMalloc(&d_b, size_b); 117 | cudaMalloc(&d_c, size_c); 118 | cudaMalloc(&d_c_ref, size_c); 119 | 120 | srand(time(0)); 121 | for (int i = 0; i < M * K; i++) 122 | h_a[i] = (T)((rand() % 200 - 100) * 0.01); // -1 ~ 1 123 | for (int i = 0; i < K * N; i++) 124 | h_b[i] = (T)((rand() % 200 - 100) * 0.01); 125 | 126 | cublasHandle_t handle; 127 | cublasCreate(&handle); 128 | half alpha = 1.f; 129 | half beta = 0.f; 130 | 131 | cudaMemcpy(d_a, h_a, size_a, cudaMemcpyHostToDevice); 132 | cudaMemcpy(d_b, h_b, size_b, cudaMemcpyHostToDevice); 133 | 134 | cublasHgemm(handle, 135 | CUBLAS_OP_T, 136 | CUBLAS_OP_N, 137 | N, M, K, 138 | &alpha, 139 | (half *)d_b, K, 140 | (half *)d_a, K, 141 | &beta, 142 | (half *)d_c_ref, N); 143 | 144 | gpu_hgemm(d_a, d_b, d_c, M, N, K); 145 | 146 | cudaMemcpy(h_c, d_c, size_c, cudaMemcpyDeviceToHost); 147 | cudaMemcpy(h_c_ref, d_c_ref, size_c, cudaMemcpyDeviceToHost); 148 | 149 | float max_error = 0.0; 150 | for (int i = 0; i < M * N; i++) { 151 | float this_error = abs((float)h_c_ref[i] - (float)h_c[i]); 152 | max_error = max(max_error, this_error); 153 | } 154 | 155 | free(h_a); 156 | free(h_b); 157 | free(h_c); 158 | free(h_c_ref); 159 | cudaFree(d_a); 160 | cudaFree(d_b); 161 | cudaFree(d_c); 162 | cudaFree(d_c_ref); 163 | cublasDestroy(handle); 164 | 165 | return max_error; 166 | } 167 | 168 | template 169 | float gemm_error_check_tn_swizzle( 170 | void (*gpu_hgemm) (T *, T *, T *, int, int, int, int), 171 | int M, int N, int K, int swizzle_stride) { 172 | 173 | size_t size_a = M * K * sizeof(T); 174 | 
size_t size_b = K * N * sizeof(T); 175 | size_t size_c = M * N * sizeof(T); 176 | 177 | T *h_a, *h_b, *h_c, *h_c_ref; 178 | T *d_a, *d_b, *d_c, *d_c_ref; 179 | 180 | h_a = (T *)malloc(size_a); 181 | h_b = (T *)malloc(size_b); 182 | h_c = (T *)malloc(size_c); 183 | h_c_ref = (T *)malloc(size_c); 184 | 185 | cudaMalloc(&d_a, size_a); 186 | cudaMalloc(&d_b, size_b); 187 | cudaMalloc(&d_c, size_c); 188 | cudaMalloc(&d_c_ref, size_c); 189 | 190 | srand(time(0)); 191 | for (int i = 0; i < M * K; i++) 192 | h_a[i] = (T)((rand() % 200 - 100) * 0.01); // -1 ~ 1 193 | for (int i = 0; i < K * N; i++) 194 | h_b[i] = (T)((rand() % 200 - 100) * 0.01); 195 | 196 | cublasHandle_t handle; 197 | cublasCreate(&handle); 198 | half alpha = 1.f; 199 | half beta = 0.f; 200 | 201 | cudaMemcpy(d_a, h_a, size_a, cudaMemcpyHostToDevice); 202 | cudaMemcpy(d_b, h_b, size_b, cudaMemcpyHostToDevice); 203 | 204 | cublasHgemm(handle, 205 | CUBLAS_OP_T, 206 | CUBLAS_OP_N, 207 | N, M, K, 208 | &alpha, 209 | (half *)d_b, K, 210 | (half *)d_a, K, 211 | &beta, 212 | (half *)d_c_ref, N); 213 | 214 | gpu_hgemm(d_a, d_b, d_c, M, N, K, swizzle_stride); 215 | 216 | cudaMemcpy(h_c, d_c, size_c, cudaMemcpyDeviceToHost); 217 | cudaMemcpy(h_c_ref, d_c_ref, size_c, cudaMemcpyDeviceToHost); 218 | 219 | float max_error = 0.0; 220 | for (int i = 0; i < M * N; i++) { 221 | float this_error = abs((float)h_c_ref[i] - (float)h_c[i]); 222 | max_error = max(max_error, this_error); 223 | } 224 | 225 | free(h_a); 226 | free(h_b); 227 | free(h_c); 228 | free(h_c_ref); 229 | cudaFree(d_a); 230 | cudaFree(d_b); 231 | cudaFree(d_c); 232 | cudaFree(d_c_ref); 233 | cublasDestroy(handle); 234 | 235 | return max_error; 236 | } 237 | 238 | template 239 | float gemm_error_check_nn( 240 | void (*gpu_hgemm) (T *, T *, T *, int, int, int), 241 | int M, int N, int K) { 242 | 243 | size_t size_a = M * K * sizeof(T); 244 | size_t size_b = K * N * sizeof(T); 245 | size_t size_c = M * N * sizeof(T); 246 | 247 | T *h_a, *h_b, *h_c, *h_c_ref; 248 | T *d_a, *d_b, *d_c, *d_c_ref; 249 | 250 | h_a = (T *)malloc(size_a); 251 | h_b = (T *)malloc(size_b); 252 | h_c = (T *)malloc(size_c); 253 | h_c_ref = (T *)malloc(size_c); 254 | 255 | cudaMalloc(&d_a, size_a); 256 | cudaMalloc(&d_b, size_b); 257 | cudaMalloc(&d_c, size_c); 258 | cudaMalloc(&d_c_ref, size_c); 259 | 260 | srand(time(0)); 261 | for (int i = 0; i < M * K; i++) 262 | h_a[i] = (T)((rand() % 200 - 100) * 0.01); // -1 ~ 1 263 | for (int i = 0; i < K * N; i++) 264 | h_b[i] = (T)((rand() % 200 - 100) * 0.01); 265 | 266 | cublasHandle_t handle; 267 | cublasCreate(&handle); 268 | cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH); 269 | half alpha = 1.f; 270 | half beta = 0.f; 271 | 272 | cudaMemcpy(d_a, h_a, size_a, cudaMemcpyHostToDevice); 273 | cudaMemcpy(d_b, h_b, size_b, cudaMemcpyHostToDevice); 274 | 275 | cublasGemmEx(handle, 276 | CUBLAS_OP_N, 277 | CUBLAS_OP_N, 278 | N, M, K, 279 | &alpha, 280 | (half *)d_b, CUDA_R_16F, N, 281 | (half *)d_a, CUDA_R_16F, K, 282 | &beta, 283 | (half *)d_c_ref, CUDA_R_16F, N, 284 | CUBLAS_COMPUTE_16F, 285 | CUBLAS_GEMM_DEFAULT_TENSOR_OP); 286 | 287 | gpu_hgemm(d_a, d_b, d_c, M, N, K); 288 | 289 | cudaMemcpy(h_c, d_c, size_c, cudaMemcpyDeviceToHost); 290 | cudaMemcpy(h_c_ref, d_c_ref, size_c, cudaMemcpyDeviceToHost); 291 | 292 | float max_error = 0.0; 293 | for (int i = 0; i < M * N; i++) { 294 | float this_error = abs((float)h_c_ref[i] - (float)h_c[i]); 295 | max_error = max(max_error, this_error); 296 | } 297 | 298 | free(h_a); 299 | free(h_b); 300 | free(h_c); 301 | 
free(h_c_ref); 302 | cudaFree(d_a); 303 | cudaFree(d_b); 304 | cudaFree(d_c); 305 | cudaFree(d_c_ref); 306 | cublasDestroy(handle); 307 | 308 | return max_error; 309 | } 310 | -------------------------------------------------------------------------------- /kernels/hgemm/tools/print_swizzle_layout.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def pretty_print_line(m: str = "", 5 | sep: str = "-", 6 | width: int = 130, 7 | return_str: bool = False): 8 | res_len = width - len(m) 9 | left_len = int(res_len / 2) 10 | right_len = res_len - left_len 11 | pretty_line = sep * left_len + m + sep * right_len 12 | if not return_str: 13 | print(pretty_line) 14 | else: 15 | return pretty_line 16 | 17 | 18 | PERMUTED_DOCS_STRING = \ 19 | """---------------------------------------------------------------- 20 | [INFO] Assert smem store layout col_stride <= 16, prefer 16. | 21 | [INFO] For logical_col_stride > 16, we have to permute the | 22 | [INFO] smem store layout using col major ZigZag method: | 23 | [INFO] e.g, --> Q smem logical layout [Br][64]. | 24 | [INFO] --> col major ZigZag permuted --> | 25 | [INFO] --> Q smem store layout [4][Br][16]. | 26 | ----------------------------------------------------------------""" 27 | 28 | def swizzle_permuted_j(i: int, 29 | j: int, 30 | col_stride: int = 16, 31 | num_elems_per_128b: int = 8): 32 | # i: row index; j: col index. col_stride <= 16. 33 | # assert col_stride <= 16, f"col_stride must <= 16, but got {col_stride}" 34 | # for col_stride > 16, we have to permute it using col major ZigZag order. 35 | # e.g, Q smem logical layout [Br,d]=[Br,64] -> store layout [4][Br][16]. 36 | return ( 37 | (int(j / num_elems_per_128b) ^ int(i / 4)) % 38 | (int(col_stride / num_elems_per_128b)) 39 | ) * num_elems_per_128b 40 | 41 | 42 | def print_smem_swizzle_layout(rows: int = 16, 43 | logical_col_stride: int = 16, 44 | num_elems_per_128b: int = 8, 45 | smem_pading: int = 0, 46 | show_logical_col_id: bool = False, 47 | use_logical_col_stride: bool = False): 48 | # ---------------------------------------------------------------- 49 | # [INFO] Assert smem store layout col_stride <= 16, prefer 16. | 50 | # [INFO] For logical_col_stride > 16, we have to permute the | 51 | # [INFO] smem store layout using col major ZigZag method: | 52 | # [INFO] e.g, --> Q smem logical layout [Br][64]. | 53 | # [INFO] --> col major ZigZag permuted --> | 54 | # [INFO] --> Q smem store layout [4][Br][16]. 
| 55 | # ---------------------------------------------------------------- 56 | # ---------------------------------------------------------------- 57 | # -------------------------swizzle layout------------------------- 58 | # --------------------logical col 0~64, step 8-------------------- 59 | # ---------------------smem col 0~16, step 8---------------------- 60 | # ---------------------------------------------------------------- 61 | # |bank |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 | 62 | # |row 0 | 0:0 | 8:8 |16:0 |24:8 |32:0 |40:8 |48:0 |56:8 | 63 | # |bank |b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15| 64 | # |row 1 | 0:0 | 8:8 |16:0 |24:8 |32:0 |40:8 |48:0 |56:8 | 65 | # |bank |b16~19|b20~23|b16~19|b20~23|b16~19|b20~23|b16~19|b20~23| 66 | # |row 2 | 0:0 | 8:8 |16:0 |24:8 |32:0 |40:8 |48:0 |56:8 | 67 | # |bank |b24~27|b28~31|b24~27|b28~31|b24~27|b28~31|b24~27|b28~31| 68 | # |row 3 | 0:0 | 8:8 |16:0 |24:8 |32:0 |40:8 |48:0 |56:8 | 69 | # ---------------------------------------------------------------- 70 | # |bank |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 | 71 | # |row 4 | 0:8 | 8:0 |16:8 |24:0 |32:8 |40:0 |48:8 |56:0 | 72 | # |bank |b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15| 73 | # |row 5 | 0:8 | 8:0 |16:8 |24:0 |32:8 |40:0 |48:8 |56:0 | 74 | # |bank |b16~19|b20~23|b16~19|b20~23|b16~19|b20~23|b16~19|b20~23| 75 | # |row 6 | 0:8 | 8:0 |16:8 |24:0 |32:8 |40:0 |48:8 |56:0 | 76 | # |bank |b24~27|b28~31|b24~27|b28~31|b24~27|b28~31|b24~27|b28~31| 77 | # |row 7 | 0:8 | 8:0 |16:8 |24:0 |32:8 |40:0 |48:8 |56:0 | 78 | # ---------------------------------------------------------------- 79 | # |bank |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 | 80 | # |row 8 | 0:0 | 8:8 |16:0 |24:8 |32:0 |40:8 |48:0 |56:8 | 81 | # |bank |b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15| 82 | # |row 9 | 0:0 | 8:8 |16:0 |24:8 |32:0 |40:8 |48:0 |56:8 | 83 | # |bank |b16~19|b20~23|b16~19|b20~23|b16~19|b20~23|b16~19|b20~23| 84 | # |row 10| 0:0 | 8:8 |16:0 |24:8 |32:0 |40:8 |48:0 |56:8 | 85 | # |bank |b24~27|b28~31|b24~27|b28~31|b24~27|b28~31|b24~27|b28~31| 86 | # |row 11| 0:0 | 8:8 |16:0 |24:8 |32:0 |40:8 |48:0 |56:8 | 87 | # ---------------------------------------------------------------- 88 | # |bank |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 | 89 | # |row 12| 0:8 | 8:0 |16:8 |24:0 |32:8 |40:0 |48:8 |56:0 | 90 | # |bank |b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15| 91 | # |row 13| 0:8 | 8:0 |16:8 |24:0 |32:8 |40:0 |48:8 |56:0 | 92 | # |bank |b16~19|b20~23|b16~19|b20~23|b16~19|b20~23|b16~19|b20~23| 93 | # |row 14| 0:8 | 8:0 |16:8 |24:0 |32:8 |40:0 |48:8 |56:0 | 94 | # |bank |b24~27|b28~31|b24~27|b28~31|b24~27|b28~31|b24~27|b28~31| 95 | # |row 15| 0:8 | 8:0 |16:8 |24:0 |32:8 |40:0 |48:8 |56:0 | 96 | # ---------------------------------------------------------------- 97 | str_len = 0 98 | total_banks = 0 99 | assert smem_pading == 0 or smem_pading == 8, "smem_pading must be 0 or 8" 100 | # 4 bytes per bank 101 | banks_per_col = int((16 * 2) / 4) if logical_col_stride >= 16 else 4 102 | if use_logical_col_stride: 103 | banks_per_col = int((logical_col_stride * 2) / 4) 104 | if logical_col_stride > 16: 105 | print(f"[WARN] col_stride must <= 16, but got {logical_col_stride}") 106 | if smem_pading == 8: 107 | banks_per_col += 4 108 | print(f"[INFO] smem padding 8 half values, 4 banks, banks_per_col: {banks_per_col}") 109 | 110 | banks_per_num_elems_per_128b = int((num_elems_per_128b * 2) / 4) 111 | for i in range(rows): 112 | 
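        # each iteration prints one smem row: first the 4-byte banks the row occupies,
        # then the swizzled column that each num_elems_per_128b-wide (128-bit) chunk
        # maps to, e.g. swizzle_permuted_j(i=4, j=0, col_stride=16) = ((0//8) ^ (4//4)) % 2 * 8 = 8,
        # matching the "row 4 | 0:8" entry in the reference layout above.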
layout_str_len = 0 113 | banks_str_len = 0 114 | 115 | # bank_layout_str 116 | banks_start = total_banks % 32 # 32 banks in total 117 | banks_end = (banks_start + banks_per_col) 118 | bank_layout_str = f"|bank |" 119 | max_bank_str_len = 0 120 | if logical_col_stride >= 16 and (not use_logical_col_stride): 121 | for k in range(int(logical_col_stride / 16)): 122 | for j in range(banks_start, banks_end, banks_per_num_elems_per_128b): 123 | curr_bank_str = f"b{j:>2}~{j + banks_per_num_elems_per_128b - 1:<2}|" 124 | max_bank_str_len = max(max_bank_str_len, len(curr_bank_str)) 125 | bank_layout_str += curr_bank_str 126 | else: 127 | for j in range(banks_start, banks_end, banks_per_num_elems_per_128b): 128 | curr_bank_str = f"b{j:>2}~{j + banks_per_num_elems_per_128b - 1:<2}|" 129 | max_bank_str_len = max(max_bank_str_len, len(curr_bank_str)) 130 | bank_layout_str += curr_bank_str 131 | 132 | # smem_layout_str 133 | logical_col_ids = [] 134 | smem_layout_col_ids = [] 135 | if logical_col_stride >= 16 and (not use_logical_col_stride): 136 | for k in range(int(logical_col_stride / 16)): 137 | for j in range(0, 16, num_elems_per_128b): 138 | layout_j = swizzle_permuted_j(i, j, 16, 139 | num_elems_per_128b) 140 | logical_col_ids.append(k * 16 + j) 141 | smem_layout_col_ids.append(layout_j) 142 | else: 143 | for j in range(0, logical_col_stride, num_elems_per_128b): 144 | layout_j = swizzle_permuted_j(i, j, logical_col_stride, 145 | num_elems_per_128b) 146 | logical_col_ids.append(j) 147 | smem_layout_col_ids.append(layout_j) 148 | 149 | smem_layout_str = f"|row {i:<2}|" 150 | 151 | r = 0 152 | for c, l in zip(logical_col_ids, smem_layout_col_ids): 153 | smem_layout_str += pretty_print_line( 154 | (f"{c:>2}:{l:<2}" if show_logical_col_id else f"{l:<2}"), 155 | sep=" ", 156 | width=(max_bank_str_len-1), 157 | return_str=True 158 | ) + "|" 159 | r += 1 160 | if logical_col_stride >= 16: 161 | if smem_pading == 8 and (r > 1 and r % 2 == 0): 162 | smem_layout_str += pretty_print_line( 163 | (f"pad"), 164 | sep=" ", width=max_bank_str_len-1, 165 | return_str=True 166 | ) + "|" 167 | else: 168 | if smem_pading == 8: 169 | smem_layout_str += pretty_print_line( 170 | (f"pad"), 171 | sep=" ", width=max_bank_str_len-1, 172 | return_str=True 173 | ) + "|" 174 | 175 | layout_str_len = len(smem_layout_str) 176 | str_len = max(layout_str_len, banks_str_len) 177 | 178 | # print banks and smem layout 179 | if (i == 0): 180 | print("-" * str_len) 181 | pretty_print_line(f"swizzle layout", width=str_len) 182 | pretty_print_line(f"logical col 0~{logical_col_stride}, " 183 | f"step {num_elems_per_128b}", 184 | width=str_len) 185 | pretty_print_line(f"smem col 0~16, step {num_elems_per_128b}" 186 | if logical_col_stride >= 16 187 | else f"smem col 0~8, step {num_elems_per_128b}", 188 | width=str_len) 189 | print("-" * str_len) 190 | print(bank_layout_str) 191 | print(smem_layout_str) 192 | if ((i + 1) % 4 == 0 and i != (rows - 1)): 193 | print("-" * str_len) 194 | total_banks += banks_per_col 195 | print("-" * str_len) 196 | 197 | 198 | def get_args(): 199 | parser = argparse.ArgumentParser() 200 | parser.add_argument("--rows", type=int, default=16) 201 | parser.add_argument("--smem-padding", "--pad", type=int, default=0) 202 | parser.add_argument("--num-elems-per-128b", "--num-elems", type=int, default=8) 203 | parser.add_argument("--logical-col-stride", "--logical-col", "--col", type=int, default=64) 204 | parser.add_argument("--use-logical-col-stride", "--use-logical-col", action="store_true") 205 | 
parser.add_argument("--show-logical-col-id", "--show-logical-col", action="store_true") 206 | return parser.parse_args() 207 | 208 | 209 | if __name__ == "__main__": 210 | args = get_args() 211 | print(args) 212 | print(PERMUTED_DOCS_STRING) 213 | print_smem_swizzle_layout(rows=args.rows, 214 | logical_col_stride=args.logical_col_stride, 215 | num_elems_per_128b=args.num_elems_per_128b, 216 | smem_pading=args.smem_padding, 217 | show_logical_col_id=args.show_logical_col_id, 218 | use_logical_col_stride=args.use_logical_col_stride) 219 | 220 | -------------------------------------------------------------------------------- /kernels/hgemm/mma/hgemm_mma.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | using namespace nvcuda; 14 | 15 | #define WARP_SIZE 32 16 | #define DEVICE_INLINE __device__ inline 17 | #define HOST_DEVICE_INLINE __device__ __host__ inline 18 | #define INT4(value) (reinterpret_cast(&(value))[0]) 19 | #define FLOAT4(value) (reinterpret_cast(&(value))[0]) 20 | #define HALF2(value) (reinterpret_cast(&(value))[0]) 21 | #define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0]) 22 | #define LDST32BITS(value) (reinterpret_cast(&(value))[0]) 23 | #define LDST64BITS(value) (reinterpret_cast(&(value))[0]) 24 | #define LDST128BITS(value) (reinterpret_cast(&(value))[0]) 25 | #define CP_ASYNC_COMMIT_GROUP() asm volatile("cp.async.commit_group;\n" ::) 26 | #define CP_ASYNC_WAIT_ALL() asm volatile("cp.async.wait_all;\n" ::) 27 | #define CP_ASYNC_WAIT_GROUP(n) asm volatile("cp.async.wait_group %0;\n" ::"n"(n)) 28 | // ca(cache all, L1 + L2): support 4, 8, 16 bytes, cg(cache global, L2): only support 16 bytes. 29 | #define CP_ASYNC_CA(dst, src, bytes) asm volatile("cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes)) 30 | #define CP_ASYNC_CG(dst, src, bytes) asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes)) 31 | #define LDMATRIX_X1(R, addr) asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n" : "=r"(R) : "r"(addr)) 32 | #define LDMATRIX_X2(R0, R1, addr) asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" : "=r"(R0), "=r"(R1) : "r"(addr)) 33 | #define LDMATRIX_X4(R0, R1, R2, R3, addr) asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(R0), "=r"(R1), "=r"(R2), "=r"(R3) : "r"(addr)) 34 | #define LDMATRIX_X1_T(R, addr) asm volatile("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n" : "=r"(R) : "r"(addr)) 35 | #define LDMATRIX_X2_T(R0, R1, addr) asm volatile("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n" : "=r"(R0), "=r"(R1) : "r"(addr)) 36 | #define LDMATRIX_X4_T(R0, R1, R2, R3, addr) asm volatile("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(R0), "=r"(R1), "=r"(R2), "=r"(R3) : "r"(addr)) 37 | #define HMMA16816(RD0, RD1, RA0, RA1, RA2, RA3, RB0, RB1, RC0, RC1) asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\n" : "=r"(RD0), "=r"(RD1) : "r"(RA0), "r"(RA1), "r"(RA2), "r"(RA3), "r"(RB0), "r"(RB1), "r"(RC0), "r"(RC1)) 38 | 39 | HOST_DEVICE_INLINE 40 | int div_ceil(int a, int b) { return (a % b != 0) ? (a / b + 1) : (a / b); } 41 | 42 | // only 1 warp per block(32 threads), m16n8k16. 
A, B, C: all row_major. 43 | template 44 | __global__ void hgemm_mma_m16n8k16_naive_kernel(half* A, half* B, half* C, 45 | int M, int N, int K) { 46 | const int bx = blockIdx.x; 47 | const int by = blockIdx.y; 48 | const int NUM_K_TILES = div_ceil(K, MMA_K); 49 | constexpr int BM = MMA_M; // 16 50 | constexpr int BN = MMA_N; // 8 51 | constexpr int BK = MMA_K; // 16 52 | 53 | __shared__ half s_a[MMA_M][MMA_K]; // 16x16 54 | __shared__ half s_b[MMA_K][MMA_N]; // 16x8 55 | __shared__ half s_c[MMA_M][MMA_N]; // 16x8 56 | 57 | const int tid = threadIdx.y * blockDim.x + threadIdx.x; // within block 58 | const int lane_id = tid % WARP_SIZE; // 0~31 59 | 60 | // s_a[16][16], 每行16,每线程load 8,需要2线程,共16行,需2x16=32线程 61 | const int load_smem_a_m = tid / 2; // row 0~15 62 | const int load_smem_a_k = (tid % 2) * 8; // col 0,8 63 | // s_b[16][8], 每行8,每线程load 8,需要1线程,共16行,需16线程,只需一半线程加载 64 | const int load_smem_b_k = tid; // row 0~31, but only use 0~15 65 | const int load_smem_b_n = 0; // col 0 66 | const int load_gmem_a_m = by * BM + load_smem_a_m; // global m 67 | const int load_gmem_b_n = bx * BN + load_smem_b_n; // global n 68 | if (load_gmem_a_m >= M && load_gmem_b_n >= N) return; 69 | 70 | uint32_t RC[2] = {0, 0}; 71 | 72 | #pragma unroll 73 | for (int k = 0; k < NUM_K_TILES; ++k) { 74 | // gmem_a -> smem_a 75 | int load_gmem_a_k = k * BK + load_smem_a_k; // global col of a 76 | int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k; 77 | LDST128BITS(s_a[load_smem_a_m][load_smem_a_k]) = ( 78 | LDST128BITS(A[load_gmem_a_addr])); 79 | 80 | // gmem_b -> smem_b 81 | if (lane_id < MMA_K) { 82 | int load_gmem_b_k = k * MMA_K + load_smem_b_k; // global row of b 83 | int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n; 84 | LDST128BITS(s_b[load_smem_b_k][load_smem_b_n]) = ( 85 | LDST128BITS(B[load_gmem_b_addr])); 86 | } 87 | __syncthreads(); 88 | 89 | uint32_t RA[4]; 90 | uint32_t RB[2]; 91 | 92 | // ldmatrix for s_a, ldmatrix.trans for s_b. 93 | // s_a: (0,1)*8 -> 0,8 -> [(0~15),(0,8)] 94 | uint32_t load_smem_a_ptr = __cvta_generic_to_shared( 95 | &s_a[lane_id % 16][(lane_id / 16) * 8]); 96 | LDMATRIX_X4(RA[0], RA[1], RA[2], RA[3], load_smem_a_ptr); 97 | uint32_t load_smem_b_ptr = __cvta_generic_to_shared( 98 | &s_b[lane_id % 16][0]); 99 | LDMATRIX_X2_T(RB[0], RB[1], load_smem_b_ptr); 100 | 101 | HMMA16816(RC[0], RC[1], RA[0], RA[1], RA[2], RA[3], RB[0], RB[1], RC[0], RC[1]); 102 | 103 | __syncthreads(); 104 | } 105 | 106 | // s_c[16][8], https://docs.nvidia.com/cuda/parallel-thread-execution/index.html 107 | // #matrix-fragments-for-mma-m16n8k16-with-floating-point-type 108 | // [0~7][0~3 u32 -> 0~7 f16], [8~15][0~3 u32 -> 0~7 f16] 109 | LDST32BITS(s_c[lane_id / 4 ][(lane_id % 4) * 2]) = LDST32BITS(RC[0]); 110 | LDST32BITS(s_c[lane_id / 4 + 8][(lane_id % 4) * 2]) = LDST32BITS(RC[1]); 111 | 112 | __syncthreads(); 113 | 114 | // store s_c[16][8] 115 | if (lane_id < MMA_M) { 116 | // store 128 bits per memory issue. 
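    // Each lane 0..15 owns one full row of s_c (8 halves = 16 bytes), so a single
    // 128-bit vectorized store per lane writes the whole 16x8 C tile back to gmem.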
117 | int store_gmem_c_m = by * BM + lane_id; 118 | int store_gmem_c_n = bx * BN; 119 | int store_gmem_c_addr = store_gmem_c_m * N + store_gmem_c_n; 120 | LDST128BITS(C[store_gmem_c_addr]) = (LDST128BITS(s_c[lane_id][0])); 121 | } 122 | } 123 | 124 | // 128x128, mma2x4, warp4x4(64,32,16) 125 | template 134 | __global__ void __launch_bounds__(256) 135 | hgemm_mma_m16n8k16_mma2x4_warp4x4_kernel( 136 | half* A, half* B, half* C, int M, int N, int K) { 137 | const int bx = blockIdx.x; 138 | const int by = blockIdx.y; 139 | const int NUM_K_TILES = div_ceil(K, MMA_K); 140 | constexpr int BM = MMA_M * MMA_TILE_M * WARP_TILE_M; // 16*2*4=128 141 | constexpr int BN = MMA_N * MMA_TILE_N * WARP_TILE_N; // 8*4*4=128 142 | constexpr int BK = MMA_K; // 16 143 | 144 | __shared__ half s_a[BM][BK+A_PAD]; // 128*16*2=4KB 145 | __shared__ half s_b[BK][BN+B_PAD]; // 16*128*2=4KB, 16*(128+16)*2=4.5KB 146 | 147 | const int tid = threadIdx.y * blockDim.x + threadIdx.x; // within block 148 | const int warp_id = tid / WARP_SIZE; // 0~7 warp_id within block 149 | const int lane_id = tid % WARP_SIZE; // 0~31 150 | const int warp_m = warp_id % 2; // 0,1 151 | const int warp_n = warp_id / 2; // 0,1,2,3 152 | 153 | // 先计算shared memory中的索引 154 | // tid和需要加载的smem s_a[BM][BK] 之间的索引关系 BM=128 BK=16 按行读取 A行主序 155 | // 对于s_a每行16个数据,每个线程读取8个,需要2个线程;总共128行,需要128x2刚好256线程 156 | int load_smem_a_m = tid / 2; // row 0~127 157 | int load_smem_a_k = (tid % 2 == 0) ? 0 : 8; // col 0,8 158 | // tid和需要加载的smem s_b[BK][BN] 之间的索引关系 BK=16 BN=128 按行读取 B行主序 159 | // 对于s_b每行128个数据,每个线程读8个数据,需要16个线程;总共16行,需要16x16=256个线程 160 | int load_smem_b_k = tid / 16; // row 0~15 161 | int load_smem_b_n = (tid % 16) * 8; // col 0,8,...,120 162 | // 再计算全局内存中的索引 163 | // 要加载到s_a中的元素对应到A全局内存中的行数 每个block负责出C中大小为BM*BN的块 164 | int load_gmem_a_m = by * BM + load_smem_a_m; // global row of a and c 165 | int load_gmem_b_n = bx * BN + load_smem_b_n; // global col of b and c 166 | if (load_gmem_a_m >= M || load_gmem_b_n >= N) return; 167 | 168 | uint32_t RC[WARP_TILE_M][WARP_TILE_N][2]; 169 | #pragma unroll 170 | for (int i = 0; i < WARP_TILE_M; ++i) { 171 | #pragma unroll 172 | for (int j = 0; j < WARP_TILE_N; ++j) { 173 | RC[i][j][0] = 0; 174 | RC[i][j][1] = 0; 175 | } 176 | } 177 | 178 | #pragma unroll 179 | for (int k = 0; k < NUM_K_TILES; ++k) { 180 | // gmem -> smem 181 | int load_gmem_a_k = k * BK + load_smem_a_k; // global col of a 182 | int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k; 183 | int load_gmem_b_k = k * BK + load_smem_b_k; // global row of b 184 | int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n; 185 | LDST128BITS(s_b[load_smem_b_k][load_smem_b_n]) = ( 186 | LDST128BITS(B[load_gmem_b_addr])); 187 | LDST128BITS(s_a[load_smem_a_m][load_smem_a_k]) = ( 188 | LDST128BITS(A[load_gmem_a_addr])); 189 | __syncthreads(); 190 | 191 | // ldmatrix for s_a, ldmatrix.trans for s_b. 
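    // Why .trans for B: mma.m16n8k16 consumes the B fragment in column-major order,
    // but s_b holds row-major data, so ldmatrix.trans transposes each 8x8 tile on load;
    // A is row-major and already matches the expected A fragment layout, so plain ldmatrix is used.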
192 | uint32_t RA[WARP_TILE_M][4]; 193 | uint32_t RB[WARP_TILE_N][2]; 194 | 195 | // smem -> reg 196 | #pragma unroll 197 | for (int i = 0; i < WARP_TILE_M; ++i) { 198 | int warp_smem_a_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M; 199 | int lane_smem_a_m = warp_smem_a_m + lane_id % 16; // 0~15 200 | int lane_smem_a_k = (lane_id / 16) * 8; // 0,8 201 | uint32_t lane_smem_a_ptr = __cvta_generic_to_shared( 202 | &s_a[lane_smem_a_m][lane_smem_a_k]); 203 | LDMATRIX_X4(RA[i][0], RA[i][1], RA[i][2], RA[i][3], lane_smem_a_ptr); 204 | } 205 | 206 | #pragma unroll 207 | for (int j = 0; j < WARP_TILE_N; ++j) { 208 | int warp_smem_b_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N; 209 | int lane_smem_b_k = lane_id % 16; // 0~15 210 | int lane_smem_b_n = warp_smem_b_n; // 0, MMA_N=8 211 | uint32_t lane_smem_b_ptr = __cvta_generic_to_shared( 212 | &s_b[lane_smem_b_k][lane_smem_b_n]); 213 | LDMATRIX_X2_T(RB[j][0], RB[j][1], lane_smem_b_ptr); 214 | } 215 | 216 | // MMA compute 217 | #pragma unroll 218 | for (int i = 0; i < WARP_TILE_M; ++i) { 219 | #pragma unroll 220 | for (int j = 0; j < WARP_TILE_N; ++j) { 221 | HMMA16816(RC[i][j][0], RC[i][j][1], 222 | RA[i][0], RA[i][1], RA[i][2], RA[i][3], 223 | RB[j][0], RB[j][1], 224 | RC[i][j][0], RC[i][j][1]); 225 | } 226 | } 227 | __syncthreads(); 228 | } 229 | 230 | // reg -> gmem, MMA_MxMMA_N=16x8 231 | #pragma unroll 232 | for (int i = 0; i < WARP_TILE_M; ++i) { 233 | #pragma unroll 234 | for (int j = 0; j < WARP_TILE_N; ++j) { 235 | int store_warp_smem_c_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M; 236 | int store_warp_smem_c_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N; 237 | // mapping lane smem index -> global index. 238 | // [16][8], https://docs.nvidia.com/cuda/parallel-thread-execution/index.html 239 | // #matrix-fragments-for-mma-m16n8k16-with-floating-point-type 240 | // [0~7][0~3 u32 -> 0~7 f16], [8~15][0~3 u32 -> 0~7 f16] 241 | int store_lane_gmem_c_m = by * BM + store_warp_smem_c_m + lane_id / 4; 242 | int store_lane_gmem_c_n = bx * BN + store_warp_smem_c_n + (lane_id % 4) * 2; 243 | int store_gmem_c_addr_0 = store_lane_gmem_c_m * N + store_lane_gmem_c_n; 244 | int store_gmem_c_addr_1 = (store_lane_gmem_c_m + 8) * N + store_lane_gmem_c_n; 245 | // TODO: how to use LDST128BITS here ? reverse the loop order ? 246 | LDST32BITS(C[store_gmem_c_addr_0]) = LDST32BITS(RC[i][j][0]); 247 | LDST32BITS(C[store_gmem_c_addr_1]) = LDST32BITS(RC[i][j][1]); 248 | } 249 | } 250 | } 251 | 252 | 253 | // --------------------- PyTorch bindings for custom kernel ----------------------- 254 | #define STRINGFY(str) #str 255 | #define TORCH_BINDING_COMMON_EXTENSION(func) \ 256 | m.def(STRINGFY(func), &func, STRINGFY(func)); 257 | 258 | #define CHECK_TORCH_TENSOR_DTYPE(T, th_type) \ 259 | if(((T).options().dtype() != (th_type))) { \ 260 | std::cout << "Tensor Info:" << (T).options() << std::endl; \ 261 | throw std::runtime_error("values must be "#th_type); \ 262 | } 263 | 264 | #define CHECK_TORCH_TENSOR_SHAPE(T, S0, S1) \ 265 | if (((T).size(0) != (S0)) || ((T).size(1) != (S1))) { \ 266 | throw std::runtime_error("Tensor size mismatch!"); \ 267 | } 268 | 269 | // only 1 warp per block(32 threads), m16n8k16. A, B, C: all row_major. 
270 | void hgemm_mma_m16n8k16_naive( 271 | torch::Tensor a, torch::Tensor b, torch::Tensor c) { 272 | CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf) 273 | CHECK_TORCH_TENSOR_DTYPE(b, torch::kHalf) 274 | CHECK_TORCH_TENSOR_DTYPE(c, torch::kHalf) 275 | const int M = a.size(0); 276 | const int K = a.size(1); 277 | const int N = b.size(1); 278 | CHECK_TORCH_TENSOR_SHAPE(a, M, K) 279 | CHECK_TORCH_TENSOR_SHAPE(b, K, N) 280 | CHECK_TORCH_TENSOR_SHAPE(c, M, N) 281 | constexpr int MMA_M = 16; 282 | constexpr int MMA_N = 8; 283 | constexpr int MMA_K = 16; 284 | 285 | dim3 block(WARP_SIZE); 286 | dim3 grid(div_ceil(N, MMA_N), div_ceil(M, MMA_M)); 287 | 288 | hgemm_mma_m16n8k16_naive_kernel< 289 | MMA_M, MMA_N, MMA_K><<>>( 290 | reinterpret_cast(a.data_ptr()), 291 | reinterpret_cast(b.data_ptr()), 292 | reinterpret_cast(c.data_ptr()), 293 | M, N, K 294 | ); 295 | } 296 | 297 | // 128x128, mma2x4, warp4x4(64,32,16) 298 | void hgemm_mma_m16n8k16_mma2x4_warp4x4( 299 | torch::Tensor a, torch::Tensor b, torch::Tensor c) { 300 | CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf) 301 | CHECK_TORCH_TENSOR_DTYPE(b, torch::kHalf) 302 | CHECK_TORCH_TENSOR_DTYPE(c, torch::kHalf) 303 | const int M = a.size(0); 304 | const int K = a.size(1); 305 | const int N = b.size(1); 306 | CHECK_TORCH_TENSOR_SHAPE(a, M, K) 307 | CHECK_TORCH_TENSOR_SHAPE(b, K, N) 308 | CHECK_TORCH_TENSOR_SHAPE(c, M, N) 309 | constexpr int MMA_M = 16; 310 | constexpr int MMA_N = 8; 311 | constexpr int MMA_K = 16; 312 | constexpr int MMA_TILE_M = 2; 313 | constexpr int MMA_TILE_N = 4; 314 | constexpr int WARP_TILE_M = 4; 315 | constexpr int WARP_TILE_N = 4; 316 | constexpr int A_PAD = 0; 317 | constexpr int B_PAD = 16; 318 | constexpr int NUM_THREADS= ( 319 | MMA_TILE_M * MMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256 320 | 321 | dim3 block(NUM_THREADS); 322 | dim3 grid(div_ceil(N, MMA_N * MMA_TILE_N * WARP_TILE_N), 323 | div_ceil(M, MMA_M * MMA_TILE_M * WARP_TILE_M)); 324 | 325 | hgemm_mma_m16n8k16_mma2x4_warp4x4_kernel< 326 | MMA_M, MMA_N, MMA_K, MMA_TILE_M, MMA_TILE_N, 327 | WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD><<>>( 328 | reinterpret_cast(a.data_ptr()), 329 | reinterpret_cast(b.data_ptr()), 330 | reinterpret_cast(c.data_ptr()), 331 | M, N, K 332 | ); 333 | } 334 | -------------------------------------------------------------------------------- /kernels/hgemm/mma/basic/hgemm_mma.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | using namespace nvcuda; 14 | 15 | #define WARP_SIZE 32 16 | #define DEVICE_INLINE __device__ inline 17 | #define HOST_DEVICE_INLINE __device__ __host__ inline 18 | #define INT4(value) (reinterpret_cast(&(value))[0]) 19 | #define FLOAT4(value) (reinterpret_cast(&(value))[0]) 20 | #define HALF2(value) (reinterpret_cast(&(value))[0]) 21 | #define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0]) 22 | #define LDST32BITS(value) (reinterpret_cast(&(value))[0]) 23 | #define LDST64BITS(value) (reinterpret_cast(&(value))[0]) 24 | #define LDST128BITS(value) (reinterpret_cast(&(value))[0]) 25 | #define CP_ASYNC_COMMIT_GROUP() asm volatile("cp.async.commit_group;\n" ::) 26 | #define CP_ASYNC_WAIT_ALL() asm volatile("cp.async.wait_all;\n" ::) 27 | #define CP_ASYNC_WAIT_GROUP(n) asm volatile("cp.async.wait_group %0;\n" ::"n"(n)) 28 | // ca(cache all, L1 + L2): support 4, 8, 16 bytes, cg(cache global, L2): only support 16 bytes. 
29 | #define CP_ASYNC_CA(dst, src, bytes) asm volatile("cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes)) 30 | #define CP_ASYNC_CG(dst, src, bytes) asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes)) 31 | #define LDMATRIX_X1(R, addr) asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n" : "=r"(R) : "r"(addr)) 32 | #define LDMATRIX_X2(R0, R1, addr) asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" : "=r"(R0), "=r"(R1) : "r"(addr)) 33 | #define LDMATRIX_X4(R0, R1, R2, R3, addr) asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(R0), "=r"(R1), "=r"(R2), "=r"(R3) : "r"(addr)) 34 | #define LDMATRIX_X1_T(R, addr) asm volatile("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n" : "=r"(R) : "r"(addr)) 35 | #define LDMATRIX_X2_T(R0, R1, addr) asm volatile("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n" : "=r"(R0), "=r"(R1) : "r"(addr)) 36 | #define LDMATRIX_X4_T(R0, R1, R2, R3, addr) asm volatile("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(R0), "=r"(R1), "=r"(R2), "=r"(R3) : "r"(addr)) 37 | #define HMMA16816(RD0, RD1, RA0, RA1, RA2, RA3, RB0, RB1, RC0, RC1) asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\n" : "=r"(RD0), "=r"(RD1) : "r"(RA0), "r"(RA1), "r"(RA2), "r"(RA3), "r"(RB0), "r"(RB1), "r"(RC0), "r"(RC1)) 38 | 39 | HOST_DEVICE_INLINE 40 | int div_ceil(int a, int b) { return (a % b != 0) ? (a / b + 1) : (a / b); } 41 | 42 | // only 1 warp per block(32 threads), m16n8k16. A, B, C: all row_major. 43 | template 44 | __global__ void hgemm_mma_m16n8k16_naive_kernel(half* A, half* B, half* C, 45 | int M, int N, int K) { 46 | const int bx = blockIdx.x; 47 | const int by = blockIdx.y; 48 | const int NUM_K_TILES = div_ceil(K, MMA_K); 49 | constexpr int BM = MMA_M; // 16 50 | constexpr int BN = MMA_N; // 8 51 | constexpr int BK = MMA_K; // 16 52 | 53 | __shared__ half s_a[MMA_M][MMA_K]; // 16x16 54 | __shared__ half s_b[MMA_K][MMA_N]; // 16x8 55 | __shared__ half s_c[MMA_M][MMA_N]; // 16x8 56 | 57 | const int tid = threadIdx.y * blockDim.x + threadIdx.x; // within block 58 | const int lane_id = tid % WARP_SIZE; // 0~31 59 | 60 | // s_a[16][16], 每行16,每线程load 8,需要2线程,共16行,需2x16=32线程 61 | const int load_smem_a_m = tid / 2; // row 0~15 62 | const int load_smem_a_k = (tid % 2) * 8; // col 0,8 63 | // s_b[16][8], 每行8,每线程load 8,需要1线程,共16行,需16线程,只需一半线程加载 64 | const int load_smem_b_k = tid; // row 0~31, but only use 0~15 65 | const int load_smem_b_n = 0; // col 0 66 | const int load_gmem_a_m = by * BM + load_smem_a_m; // global m 67 | const int load_gmem_b_n = bx * BN + load_smem_b_n; // global n 68 | if (load_gmem_a_m >= M && load_gmem_b_n >= N) return; 69 | 70 | uint32_t RC[2] = {0, 0}; 71 | 72 | #pragma unroll 73 | for (int k = 0; k < NUM_K_TILES; ++k) { 74 | // gmem_a -> smem_a 75 | int load_gmem_a_k = k * BK + load_smem_a_k; // global col of a 76 | int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k; 77 | LDST128BITS(s_a[load_smem_a_m][load_smem_a_k]) = ( 78 | LDST128BITS(A[load_gmem_a_addr])); 79 | 80 | // gmem_b -> smem_b 81 | if (lane_id < MMA_K) { 82 | int load_gmem_b_k = k * MMA_K + load_smem_b_k; // global row of b 83 | int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n; 84 | LDST128BITS(s_b[load_smem_b_k][load_smem_b_n]) = ( 85 | LDST128BITS(B[load_gmem_b_addr])); 86 | } 87 | __syncthreads(); 88 | 
89 | uint32_t RA[4]; 90 | uint32_t RB[2]; 91 | 92 | // ldmatrix for s_a, ldmatrix.trans for s_b. 93 | // s_a: (0,1)*8 -> 0,8 -> [(0~15),(0,8)] 94 | uint32_t load_smem_a_ptr = __cvta_generic_to_shared( 95 | &s_a[lane_id % 16][(lane_id / 16) * 8]); 96 | LDMATRIX_X4(RA[0], RA[1], RA[2], RA[3], load_smem_a_ptr); 97 | uint32_t load_smem_b_ptr = __cvta_generic_to_shared( 98 | &s_b[lane_id % 16][0]); 99 | LDMATRIX_X2_T(RB[0], RB[1], load_smem_b_ptr); 100 | 101 | HMMA16816(RC[0], RC[1], RA[0], RA[1], RA[2], RA[3], RB[0], RB[1], RC[0], RC[1]); 102 | 103 | __syncthreads(); 104 | } 105 | 106 | // s_c[16][8], https://docs.nvidia.com/cuda/parallel-thread-execution/index.html 107 | // #matrix-fragments-for-mma-m16n8k16-with-floating-point-type 108 | // [0~7][0~3 u32 -> 0~7 f16], [8~15][0~3 u32 -> 0~7 f16] 109 | LDST32BITS(s_c[lane_id / 4 ][(lane_id % 4) * 2]) = LDST32BITS(RC[0]); 110 | LDST32BITS(s_c[lane_id / 4 + 8][(lane_id % 4) * 2]) = LDST32BITS(RC[1]); 111 | 112 | __syncthreads(); 113 | 114 | // store s_c[16][8] 115 | if (lane_id < MMA_M) { 116 | // store 128 bits per memory issue. 117 | int store_gmem_c_m = by * BM + lane_id; 118 | int store_gmem_c_n = bx * BN; 119 | int store_gmem_c_addr = store_gmem_c_m * N + store_gmem_c_n; 120 | LDST128BITS(C[store_gmem_c_addr]) = (LDST128BITS(s_c[lane_id][0])); 121 | } 122 | } 123 | 124 | // 128x128, mma2x4, warp4x4(64,32,16) 125 | template 134 | __global__ void __launch_bounds__(256) 135 | hgemm_mma_m16n8k16_mma2x4_warp4x4_kernel( 136 | half* A, half* B, half* C, int M, int N, int K) { 137 | const int bx = blockIdx.x; 138 | const int by = blockIdx.y; 139 | const int NUM_K_TILES = div_ceil(K, MMA_K); 140 | constexpr int BM = MMA_M * MMA_TILE_M * WARP_TILE_M; // 16*2*4=128 141 | constexpr int BN = MMA_N * MMA_TILE_N * WARP_TILE_N; // 8*4*4=128 142 | constexpr int BK = MMA_K; // 16 143 | 144 | __shared__ half s_a[BM][BK+A_PAD]; // 128*16*2=4KB 145 | __shared__ half s_b[BK][BN+B_PAD]; // 16*128*2=4KB, 16*(128+16)*2=4.5KB 146 | 147 | const int tid = threadIdx.y * blockDim.x + threadIdx.x; // within block 148 | const int warp_id = tid / WARP_SIZE; // 0~7 warp_id within block 149 | const int lane_id = tid % WARP_SIZE; // 0~31 150 | const int warp_m = warp_id % 2; // 0,1 151 | const int warp_n = warp_id / 2; // 0,1,2,3 152 | 153 | // 先计算shared memory中的索引 154 | // tid和需要加载的smem s_a[BM][BK] 之间的索引关系 BM=128 BK=16 按行读取 A行主序 155 | // 对于s_a每行16个数据,每个线程读取8个,需要2个线程;总共128行,需要128x2刚好256线程 156 | int load_smem_a_m = tid / 2; // row 0~127 157 | int load_smem_a_k = (tid % 2 == 0) ? 
0 : 8; // col 0,8 158 | // tid和需要加载的smem s_b[BK][BN] 之间的索引关系 BK=16 BN=128 按行读取 B行主序 159 | // 对于s_b每行128个数据,每个线程读8个数据,需要16个线程;总共16行,需要16x16=256个线程 160 | int load_smem_b_k = tid / 16; // row 0~15 161 | int load_smem_b_n = (tid % 16) * 8; // col 0,8,...,120 162 | // 再计算全局内存中的索引 163 | // 要加载到s_a中的元素对应到A全局内存中的行数 每个block负责出C中大小为BM*BN的块 164 | int load_gmem_a_m = by * BM + load_smem_a_m; // global row of a and c 165 | int load_gmem_b_n = bx * BN + load_smem_b_n; // global col of b and c 166 | if (load_gmem_a_m >= M || load_gmem_b_n >= N) return; 167 | 168 | uint32_t RC[WARP_TILE_M][WARP_TILE_N][2]; 169 | #pragma unroll 170 | for (int i = 0; i < WARP_TILE_M; ++i) { 171 | #pragma unroll 172 | for (int j = 0; j < WARP_TILE_N; ++j) { 173 | RC[i][j][0] = 0; 174 | RC[i][j][1] = 0; 175 | } 176 | } 177 | 178 | #pragma unroll 179 | for (int k = 0; k < NUM_K_TILES; ++k) { 180 | // gmem -> smem 181 | int load_gmem_a_k = k * BK + load_smem_a_k; // global col of a 182 | int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k; 183 | int load_gmem_b_k = k * BK + load_smem_b_k; // global row of b 184 | int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n; 185 | LDST128BITS(s_b[load_smem_b_k][load_smem_b_n]) = ( 186 | LDST128BITS(B[load_gmem_b_addr])); 187 | LDST128BITS(s_a[load_smem_a_m][load_smem_a_k]) = ( 188 | LDST128BITS(A[load_gmem_a_addr])); 189 | __syncthreads(); 190 | 191 | // ldmatrix for s_a, ldmatrix.trans for s_b. 192 | uint32_t RA[WARP_TILE_M][4]; 193 | uint32_t RB[WARP_TILE_N][2]; 194 | 195 | // smem -> reg 196 | #pragma unroll 197 | for (int i = 0; i < WARP_TILE_M; ++i) { 198 | int warp_smem_a_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M; 199 | int lane_smem_a_m = warp_smem_a_m + lane_id % 16; // 0~15 200 | int lane_smem_a_k = (lane_id / 16) * 8; // 0,8 201 | uint32_t lane_smem_a_ptr = __cvta_generic_to_shared( 202 | &s_a[lane_smem_a_m][lane_smem_a_k]); 203 | LDMATRIX_X4(RA[i][0], RA[i][1], RA[i][2], RA[i][3], lane_smem_a_ptr); 204 | } 205 | 206 | #pragma unroll 207 | for (int j = 0; j < WARP_TILE_N; ++j) { 208 | int warp_smem_b_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N; 209 | int lane_smem_b_k = lane_id % 16; // 0~15 210 | int lane_smem_b_n = warp_smem_b_n; // 0, MMA_N=8 211 | uint32_t lane_smem_b_ptr = __cvta_generic_to_shared( 212 | &s_b[lane_smem_b_k][lane_smem_b_n]); 213 | LDMATRIX_X2_T(RB[j][0], RB[j][1], lane_smem_b_ptr); 214 | } 215 | 216 | // MMA compute 217 | #pragma unroll 218 | for (int i = 0; i < WARP_TILE_M; ++i) { 219 | #pragma unroll 220 | for (int j = 0; j < WARP_TILE_N; ++j) { 221 | HMMA16816(RC[i][j][0], RC[i][j][1], 222 | RA[i][0], RA[i][1], RA[i][2], RA[i][3], 223 | RB[j][0], RB[j][1], 224 | RC[i][j][0], RC[i][j][1]); 225 | } 226 | } 227 | __syncthreads(); 228 | } 229 | 230 | // reg -> gmem, MMA_MxMMA_N=16x8 231 | #pragma unroll 232 | for (int i = 0; i < WARP_TILE_M; ++i) { 233 | #pragma unroll 234 | for (int j = 0; j < WARP_TILE_N; ++j) { 235 | int store_warp_smem_c_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M; 236 | int store_warp_smem_c_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N; 237 | // mapping lane smem index -> global index. 
238 | // [16][8], https://docs.nvidia.com/cuda/parallel-thread-execution/index.html 239 | // #matrix-fragments-for-mma-m16n8k16-with-floating-point-type 240 | // [0~7][0~3 u32 -> 0~7 f16], [8~15][0~3 u32 -> 0~7 f16] 241 | int store_lane_gmem_c_m = by * BM + store_warp_smem_c_m + lane_id / 4; 242 | int store_lane_gmem_c_n = bx * BN + store_warp_smem_c_n + (lane_id % 4) * 2; 243 | int store_gmem_c_addr_0 = store_lane_gmem_c_m * N + store_lane_gmem_c_n; 244 | int store_gmem_c_addr_1 = (store_lane_gmem_c_m + 8) * N + store_lane_gmem_c_n; 245 | // TODO: how to use LDST128BITS here ? reverse the loop order ? 246 | LDST32BITS(C[store_gmem_c_addr_0]) = LDST32BITS(RC[i][j][0]); 247 | LDST32BITS(C[store_gmem_c_addr_1]) = LDST32BITS(RC[i][j][1]); 248 | } 249 | } 250 | } 251 | 252 | 253 | // --------------------- PyTorch bindings for custom kernel ----------------------- 254 | #define STRINGFY(str) #str 255 | #define TORCH_BINDING_COMMON_EXTENSION(func) \ 256 | m.def(STRINGFY(func), &func, STRINGFY(func)); 257 | 258 | #define CHECK_TORCH_TENSOR_DTYPE(T, th_type) \ 259 | if(((T).options().dtype() != (th_type))) { \ 260 | std::cout << "Tensor Info:" << (T).options() << std::endl; \ 261 | throw std::runtime_error("values must be "#th_type); \ 262 | } 263 | 264 | #define CHECK_TORCH_TENSOR_SHAPE(T, S0, S1) \ 265 | if (((T).size(0) != (S0)) || ((T).size(1) != (S1))) { \ 266 | throw std::runtime_error("Tensor size mismatch!"); \ 267 | } 268 | 269 | // only 1 warp per block(32 threads), m16n8k16. A, B, C: all row_major. 270 | void hgemm_mma_m16n8k16_naive( 271 | torch::Tensor a, torch::Tensor b, torch::Tensor c) { 272 | CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf) 273 | CHECK_TORCH_TENSOR_DTYPE(b, torch::kHalf) 274 | CHECK_TORCH_TENSOR_DTYPE(c, torch::kHalf) 275 | const int M = a.size(0); 276 | const int K = a.size(1); 277 | const int N = b.size(1); 278 | CHECK_TORCH_TENSOR_SHAPE(a, M, K) 279 | CHECK_TORCH_TENSOR_SHAPE(b, K, N) 280 | CHECK_TORCH_TENSOR_SHAPE(c, M, N) 281 | constexpr int MMA_M = 16; 282 | constexpr int MMA_N = 8; 283 | constexpr int MMA_K = 16; 284 | 285 | dim3 block(WARP_SIZE); 286 | dim3 grid(div_ceil(N, MMA_N), div_ceil(M, MMA_M)); 287 | 288 | hgemm_mma_m16n8k16_naive_kernel< 289 | MMA_M, MMA_N, MMA_K><<>>( 290 | reinterpret_cast(a.data_ptr()), 291 | reinterpret_cast(b.data_ptr()), 292 | reinterpret_cast(c.data_ptr()), 293 | M, N, K 294 | ); 295 | } 296 | 297 | // 128x128, mma2x4, warp4x4(64,32,16) 298 | void hgemm_mma_m16n8k16_mma2x4_warp4x4( 299 | torch::Tensor a, torch::Tensor b, torch::Tensor c) { 300 | CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf) 301 | CHECK_TORCH_TENSOR_DTYPE(b, torch::kHalf) 302 | CHECK_TORCH_TENSOR_DTYPE(c, torch::kHalf) 303 | const int M = a.size(0); 304 | const int K = a.size(1); 305 | const int N = b.size(1); 306 | CHECK_TORCH_TENSOR_SHAPE(a, M, K) 307 | CHECK_TORCH_TENSOR_SHAPE(b, K, N) 308 | CHECK_TORCH_TENSOR_SHAPE(c, M, N) 309 | constexpr int MMA_M = 16; 310 | constexpr int MMA_N = 8; 311 | constexpr int MMA_K = 16; 312 | constexpr int MMA_TILE_M = 2; 313 | constexpr int MMA_TILE_N = 4; 314 | constexpr int WARP_TILE_M = 4; 315 | constexpr int WARP_TILE_N = 4; 316 | // bank conflicts free via pad = 8, reject fantasy, trust the profile. 
317 | // ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld ./hgemm_mma_stage.89.debug.bin 318 | // ncu --metrics sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm ./hgemm_mma_stage.89.debug.bin 319 | constexpr int A_PAD = 8; 320 | constexpr int B_PAD = 8; 321 | constexpr int NUM_THREADS= ( 322 | MMA_TILE_M * MMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256 323 | 324 | dim3 block(NUM_THREADS); 325 | dim3 grid(div_ceil(N, MMA_N * MMA_TILE_N * WARP_TILE_N), 326 | div_ceil(M, MMA_M * MMA_TILE_M * WARP_TILE_M)); 327 | 328 | hgemm_mma_m16n8k16_mma2x4_warp4x4_kernel< 329 | MMA_M, MMA_N, MMA_K, MMA_TILE_M, MMA_TILE_N, 330 | WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD><<<grid, block>>>( 331 | reinterpret_cast<half*>(a.data_ptr()), 332 | reinterpret_cast<half*>(b.data_ptr()), 333 | reinterpret_cast<half*>(c.data_ptr()), 334 | M, N, K 335 | ); 336 | } 337 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | ## ⚡️⚡️Toy-HGEMM: Achieve the 98%~100% TFLOPS of cuBLAS 🎉🎉 3 | 4 | ![toy-hgemm-library](https://github.com/user-attachments/assets/962bda14-b494-4423-b8eb-775da9f5503d) 5 | 6 | [📖Toy-HGEMM Library⚡️⚡️](./kernels/hgemm) is a library that implements many HGEMM kernels from scratch using Tensor Cores with the WMMA, MMA PTX and CuTe APIs, and can therefore achieve `98%~100%` of **cuBLAS** performance. The code here is sourced from 📖[LeetCUDA](https://github.com/xlite-dev/LeetCUDA) ![](https://img.shields.io/github/stars/xlite-dev/LeetCUDA.svg?style=social) and exported as a standalone library; please check out [LeetCUDA](https://github.com/xlite-dev/LeetCUDA) for the latest updates. Welcome to 🌟👆🏻star this repo to support me, many thanks ~ 🎉🎉 7 | 8 |
9 | 10 |
11 | 12 | 13 | 14 |
15 | 16 | 17 | Currently, on NVIDIA L20, RTX 4090 and RTX 3080 Laptop, compared with cuBLAS's default Tensor Cores math algorithm `CUBLAS_GEMM_DEFAULT_TENSOR_OP`, the `HGEMM (WMMA/MMA/CuTe)` implemented in this repo (`blue`🔵) can achieve `98%~100%` of its (`orange`🟠) performance. Please check [toy-hgemm library⚡️⚡️](./kernels/hgemm) for more details. 18 | 19 | |📚Feature |📚Feature |📚Feature |📚Feature| 20 | |:---:|:---:|:---:|:---:| 21 | |✔️CUDA/**Tensor Cores**|✔️Loop over K|✔️Tile Block(BMxBK)|✔️Tile Threads(T 8x8)| 22 | |✔️WMMA(m16n16k16)|✔️MMA(m16n8k16)|✔️Pack LDST(128 bits)|✔️SMEM Padding| 23 | |✔️Copy Async|✔️Tile MMAs|✔️Tile Warps|✔️**Multi Stages(2~4)**| 24 | |✔️Register Double Buffers|✔️**Block Swizzle**|✔️**Warp Swizzle**|✔️**SMEM Swizzle**(CuTe/MMA)| 25 | |✔️Collective Store(Shfl)|✔️Layout NN|✔️Layout TN|✔️SGEMM FP32/TF32| 26 | 27 | ## ©️Citations🎉🎉 28 | 29 | ```BibTeX 30 | @misc{HGEMM@2024, 31 | title={HGEMM: Write HGEMM from scratch using Tensor Cores with WMMA, MMA PTX and CuTe API.}, 32 | url={https://github.com/xlite-dev/HGEMM}, 33 | note={Open-source software available at https://github.com/xlite-dev/HGEMM}, 34 | author={xlite-dev etc}, 35 | year={2024} 36 | } 37 | ``` 38 | 39 | ## 📖 HGEMM CUDA Kernels in Toy-HGEMM Library 🎉🎉 40 | 41 |
42 | 43 | ```C++ 44 | void hgemm_naive_f16(torch::Tensor a, torch::Tensor b, torch::Tensor c); 45 | void hgemm_sliced_k_f16(torch::Tensor a, torch::Tensor b, torch::Tensor c); 46 | void hgemm_t_8x8_sliced_k_f16x4(torch::Tensor a, torch::Tensor b, torch::Tensor c); 47 | void hgemm_t_8x8_sliced_k_f16x4_pack(torch::Tensor a, torch::Tensor b, torch::Tensor c); 48 | void hgemm_t_8x8_sliced_k_f16x4_bcf(torch::Tensor a, torch::Tensor b, torch::Tensor c); 49 | void hgemm_t_8x8_sliced_k_f16x4_pack_bcf(torch::Tensor a, torch::Tensor b, torch::Tensor c); 50 | void hgemm_t_8x8_sliced_k_f16x8_pack_bcf(torch::Tensor a, torch::Tensor b, torch::Tensor c); 51 | void hgemm_t_8x8_sliced_k_f16x8_pack_bcf_dbuf(torch::Tensor a, torch::Tensor b, torch::Tensor c); 52 | void hgemm_t_8x8_sliced_k16_f16x8_pack_dbuf(torch::Tensor a, torch::Tensor b, torch::Tensor c); 53 | void hgemm_t_8x8_sliced_k16_f16x8_pack_dbuf_async(torch::Tensor a, torch::Tensor b, torch::Tensor c); 54 | void hgemm_t_8x8_sliced_k32_f16x8_pack_dbuf(torch::Tensor a, torch::Tensor b, torch::Tensor c); 55 | void hgemm_t_8x8_sliced_k32_f16x8_pack_dbuf_async(torch::Tensor a, torch::Tensor b, torch::Tensor c); 56 | void hgemm_t_16x8_sliced_k32_f16x8_pack_dbuf(torch::Tensor a, torch::Tensor b, torch::Tensor c); 57 | void hgemm_t_16x8_sliced_k32_f16x8_pack_dbuf_async(torch::Tensor a, torch::Tensor b, torch::Tensor c); 58 | void hgemm_cublas_tensor_op_nn(torch::Tensor a, torch::Tensor b, torch::Tensor c); 59 | void hgemm_cublas_tensor_op_tn(torch::Tensor a, torch::Tensor b, torch::Tensor c); 60 | void hgemm_wmma_m16n16k16_naive(torch::Tensor a, torch::Tensor b, torch::Tensor c); 61 | void hgemm_wmma_m16n16k16_mma4x2(torch::Tensor a, torch::Tensor b, torch::Tensor c); 62 | void hgemm_wmma_m16n16k16_mma4x2_warp2x4(torch::Tensor a, torch::Tensor b, torch::Tensor c); 63 | void hgemm_wmma_m16n16k16_mma4x2_warp2x4_dbuf_async(torch::Tensor a, torch::Tensor b, torch::Tensor c); 64 | void hgemm_wmma_m32n8k16_mma2x4_warp2x4_dbuf_async(torch::Tensor a, torch::Tensor b, torch::Tensor c); 65 | void hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); 66 | void hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); 67 | void hgemm_wmma_m16n16k16_mma4x2_warp4x4_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); 68 | void hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); 69 | void hgemm_mma_m16n8k16_naive(torch::Tensor a, torch::Tensor b, torch::Tensor c); 70 | void hgemm_mma_m16n8k16_mma2x4_warp4x4(torch::Tensor a, torch::Tensor b, torch::Tensor c); 71 | void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); 72 | void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); 73 | void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); 74 | void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_x4(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); 75 | void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_rr(torch::Tensor a, 
torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); 76 | void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); 77 | void hgemm_mma_stages_block_swizzle_tn_cute(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); 78 | void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); 79 | void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_tn_swizzle_x4(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); 80 | ``` 81 | 82 | ## 📖 Contents 83 | 84 | - [📖 Prerequisites](#prerequisites) 85 | - [📖 Installation](#install) 86 | - [📖 Python Testing](#test) 87 | - [📖 C++ Testing](#test-cpp) 88 | - [📖 NVIDIA L20 bench](#perf-l20) 89 | - [📖 NVIDIA RTX 4090 bench](#perf-4090) 90 | - [📖 NVIDIA RTX 3080 Laptop bench](#perf-3080) 91 | - [📖 Docs](#opt-docs) 92 | - [📖 References](#ref) 93 | 94 | ## 📖 Prerequisites 95 |
96 | 97 | - PyTorch >= 2.0, CUDA >= 12.0 98 | - Recommended: PyTorch 2.5.1, CUDA 12.5 99 | 100 | ## 📖 Installation 101 | 102 |
103 | 104 | The HGEMM kernels implemented in this repo can be installed as a Python library, namely the `toy-hgemm` library (optional). 105 | ```bash 106 | cd kernels/hgemm 107 | git submodule update --init --recursive --force # Fetch `CUTLASS` submodule, needed 108 | python3 setup.py bdist_wheel && cd dist && python3 -m pip install *.whl # pip uninstall toy-hgemm -y 109 | ``` 110 | 111 | ## 📖 Python Testing 112 | 113 |
114 | 115 | **CUTLASS**: Fetch the `CUTLASS` submodule. Currently, I use `v3.5.1` for the HGEMM CuTe kernel. 116 | ```bash 117 | git submodule update --init --recursive --force 118 | ``` 119 | 120 | You can test many custom HGEMM kernels via the Python script and compare their performance. 121 | 122 | ```bash 123 | # You can build for Ada or Ampere only; Volta, Ampere, Ada, Hopper, ... are also supported 124 | export TORCH_CUDA_ARCH_LIST=Ada # for Ada only 125 | export TORCH_CUDA_ARCH_LIST=Ampere # for Ampere only 126 | python3 hgemm.py --wmma # test default wmma kernels for all MNK 127 | python3 hgemm.py --mma # test default mma kernels for all MNK 128 | python3 hgemm.py --M 16384 --N 16384 --K 8192 --i 10 --wmma # test default wmma kernels for specific MNK 129 | python3 hgemm.py --M 16384 --N 16384 --K 8192 --i 10 --mma # test default mma kernels for specific MNK 130 | python3 hgemm.py --wmma-all # test all wmma kernels for all MNK 131 | python3 hgemm.py --mma-all # test all mma kernels for all MNK 132 | python3 hgemm.py --cuda-all --wmma-all --mma-all # test all kernels for all MNK 133 | python3 hgemm.py --cute-tn --no-default # test cute hgemm kernels with smem swizzle for all MNK 134 | ``` 135 | If you want to draw a TFLOPS curve, install `matplotlib` first and set the `--plot-flops` (or `--plot`) option. 136 | ```bash 137 | python3 -m pip install matplotlib 138 | # Specify topk to plot only the top k kernels with the best performance. 139 | python3 hgemm.py --mma-all --plot --topk 8 140 | # test default mma kernels & cute hgemm kernels with smem swizzle for all MNK 141 | python3 hgemm.py --cute-tn --mma --plot 142 | ``` 143 | 144 | ## 📖 C++ Testing 145 | 146 |
147 | 148 | The HGEMM benchmark also supports C++ testing. Currently, it supports comparisons between the following implementations: 149 | 150 | - MMA HGEMM NN implemented in this repository 151 | - CuTe HGEMM TN implemented in this repository 152 | - cuBLAS HGEMM TN use default Tensor Cores math algorithm 153 | 154 | Performance data obtained from C++ binary tests tend to be slightly better than those from Python tests. This difference may be attributed to additional overhead introduced by the PyTorch Python bindings. 155 | ```bash 156 | make 157 | ./hgemm_mma_stage.bin 158 | # NVIDIA L20 159 | ALGO = MMA16816 HGEMM NN MMA=2x4 WARP=4x4x2 STAGES=2 BLOCK SWIZZLE=2048 160 | M N K = 12544 12544 12544, Time = 0.03445555 0.03446098 0.03447399 s, AVG Performance = 114.5541 Tflops 161 | M N K = 15360 15360 15360, Time = 0.06307226 0.06307789 0.06308864 s, AVG Performance = 114.9017 Tflops 162 | M N K = 15616 15616 15616, Time = 0.06612480 0.06612798 0.06613094 s, AVG Performance = 115.1739 Tflops 163 | M N K = 15872 15872 15872, Time = 0.06969549 0.06970215 0.06971290 s, AVG Performance = 114.7305 Tflops 164 | M N K = 16128 16128 16128, Time = 0.07295078 0.07295406 0.07295693 s, AVG Performance = 115.0064 Tflops 165 | M N K = 16384 16384 16384, Time = 0.07663001 0.07663534 0.07664947 s, AVG Performance = 114.7785 Tflops 166 | 167 | ./hgemm_cute.bin 168 | # NVIDIA L20 169 | ALGO = CuTe HGEMM, TN, STAGES=2, SMEM SWIZZLE=<3, 3, 3>, BLOCK SWIZZLE=2048 170 | M N K = 12544 12544 12544, Time = 0.03413504 0.03414354 0.03415450 s, AVG Performance = 115.6191 Tflops 171 | M N K = 15360 15360 15360, Time = 0.06227354 0.06228111 0.06228992 s, AVG Performance = 116.3717 Tflops 172 | M N K = 15616 15616 15616, Time = 0.06492467 0.06493727 0.06496666 s, AVG Performance = 117.2858 Tflops 173 | M N K = 15872 15872 15872, Time = 0.06843085 0.06843873 0.06844723 s, AVG Performance = 116.8485 Tflops 174 | M N K = 16128 16128 16128, Time = 0.07200256 0.07200881 0.07201792 s, AVG Performance = 116.5161 Tflops 175 | M N K = 16384 16384 16384, Time = 0.07564493 0.07565752 0.07567462 s, AVG Performance = 116.2620 Tflops 176 | 177 | ./hgemm_cublas.bin 178 | # NVIDIA L20 179 | ALGO = cuBLAS CUBLAS_GEMM_DEFAULT_TENSOR_OP TN 180 | M N K = 12544 12544 12544, Time = 0.03472691 0.03472968 0.03473408 s, AVG Performance = 113.6678 Tflops 181 | M N K = 15360 15360 15360, Time = 0.06332416 0.06333143 0.06334157 s, AVG Performance = 114.4417 Tflops 182 | M N K = 15616 15616 15616, Time = 0.06649446 0.06650184 0.06651699 s, AVG Performance = 114.5264 Tflops 183 | M N K = 15872 15872 15872, Time = 0.06977024 0.06977659 0.06978355 s, AVG Performance = 114.6081 Tflops 184 | M N K = 16128 16128 16128, Time = 0.07319142 0.07320709 0.07326925 s, AVG Performance = 114.6089 Tflops 185 | M N K = 16384 16384 16384, Time = 0.07668429 0.07669371 0.07670784 s, AVG Performance = 114.6912 Tflops 186 | ``` 187 | 188 | ## 📖 Benchmark 189 | 190 |
191 | 192 | ### 📖 NVIDIA L20 193 | 196 | The current best implementation, on the L20 (with a theoretical Tensor Cores FP16 performance of 119.5 TFLOPS), achieves performance that is approximately 99~100+% of cuBLAS. 197 | 198 | - Using the WMMA API, it can achieve around 95%~98% of cuBLAS performance (105-113 TFLOPS vs 105-115 TFLOPS). 199 | - Using the MMA API, it can reach 115 TFLOPS, surpassing cuBLAS in some cases. 200 | - The CuTe version of HGEMM implements Block Swizzle (L2 Cache friendly) and SMEM Swizzle (bank conflicts free), achieving the best performance. For large-scale matrix multiplication, it can reach 116-117 TFLOPS, which is approximately 98%~100%+ of cuBLAS performance, and it outperforms cuBLAS in many cases. 201 | 202 | Currently, SMEM Padding and SMEM Swizzle are used to mitigate bank conflicts: 203 | 204 | - For the NN layout, SMEM Padding is used to alleviate bank conflicts. 205 | - For the TN layout, CUTLASS/CuTe's SMEM Swizzle is used to eliminate bank conflicts. 206 | 207 |
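As a rough, self-contained illustration of why the padding above helps (a sketch with illustrative tile sizes, not the exact shapes of every kernel here): shared memory is split into 32 four-byte banks, and padding each row of a `half` tile by a few extra elements rotates which bank a given column lands in from row to row, so lanes that read the same column of different rows stop colliding.

```C++
// Minimal sketch: bank index of element (row, col) in a [rows][ld] half tile.
// 32 banks x 4 bytes; BK and A_PAD below are illustrative values (as in the NN kernels).
#include <cstdio>

int main() {
  constexpr int BK = 16;    // halves per logical row
  constexpr int A_PAD = 8;  // extra halves of padding per row
  auto bank = [](int row, int col, int ld) {
    int byte_offset = (row * ld + col) * 2;  // 2 bytes per half
    return (byte_offset / 4) % 32;           // 4-byte banks, 32 banks
  };
  for (int row = 0; row < 8; ++row)
    printf("row %2d col 0: bank(no pad) = %2d, bank(pad +8) = %2d\n",
           row, bank(row, 0, BK), bank(row, 0, BK + A_PAD));
  return 0;
}
```

Without padding, rows 0/4, 1/5, 2/6 and 3/7 map the same column to the same bank; with an 8-half pad the eight rows land in eight different banks.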
208 | 209 | 210 | ![NVIDIA_L20_NN+TN+v2](https://github.com/user-attachments/assets/71927ac9-72b3-4ce9-b0e2-788b5885bc99) 211 | 212 | 213 | The command for testing all MNK setups (Tip: Performance data for each MNK tested individually is more accurate.) 214 | ```bash 215 | python3 hgemm.py --cute-tn --mma --plot 216 | ``` 217 | 218 | ### 📖 NVIDIA GeForce RTX 4090 219 | 220 |
221 | 222 | 225 | 226 | On the NVIDIA RTX 4090 (with an FP16 Tensor Cores performance of 330 TFLOPS), the WMMA (m16n16k16) implementation shows better performance compared to MMA (m16n8k16). For most MNK configurations, this repository's implementation achieves 95%~99% of cuBLAS performance, and in certain cases, it can surpass cuBLAS. Specifically: 227 | 228 | - For large-scale matrix multiplications (MNK >= 8192), the WMMA implementation performs better. 229 | - For small-scale matrix multiplications, the MMA implementation is more efficient. 230 | 231 | 232 | ![NVIDIA_GeForce_RTX_4090_NN+TN+v4](https://github.com/user-attachments/assets/05ef4f5e-d999-48ea-b58e-782cffb24e85) 233 | 234 | ```bash 235 | python3 hgemm.py --cute-tn --mma --wmma-all --plot 236 | ``` 237 | 238 | ### 📖 NVIDIA GeForce RTX 3080 Laptop 239 | 240 |
241 | 242 | 245 | Testing was conducted on a NVIDIA GeForce RTX 3080 Laptop using the mma4x4_warp4x4 configuration (which includes 16 WMMA m16n16k16 operations with a warp tile size of 64x64) along with Thread block swizzle. In most cases, this setup matches or even exceeds cuBLAS performance. The tests were performed using Windows WSL2 + RTX 3080 Laptop. 246 | 247 | ![image](https://github.com/user-attachments/assets/9472e970-c083-4b31-9252-3eeecc761078) 248 | 249 | ```bash 250 | python3 hgemm.py --wmma-all --plot 251 | ``` 252 | 253 |
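The thread block swizzle mentioned above (listed as **Block Swizzle** in the feature table) reorders which C tile each block works on, so that blocks resident on the GPU at the same time reuse the same slices of A and B from L2. The mapping below is only a sketch of the idea with illustrative names; it is not the exact remapping used by the kernels in this repo.

```C++
// Minimal sketch: remap a linear block id so consecutive blocks sweep a narrow band
// of C columns (band width = swizzle_stride blocks) top-to-bottom before moving right.
// Assumes the grid's x dimension is a multiple of swizzle_stride; names are illustrative.
__device__ __forceinline__ void swizzle_block_xy(int linear_bid, int grid_y,
                                                 int swizzle_stride, int &bx, int &by) {
  int blocks_per_band = swizzle_stride * grid_y;      // one band spans every row of blocks
  int band_id = linear_bid / blocks_per_band;
  int in_band = linear_bid % blocks_per_band;
  by = in_band / swizzle_stride;                      // walk rows first inside the band
  bx = band_id * swizzle_stride + in_band % swizzle_stride;
}
```

The `*_stages*` launchers expose this kind of knob through their `swizzle` and `swizzle_stride` arguments.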
254 | 🔑️ Performance Optimization Notes(TODO) 255 | 256 | ## 📖 Performance Optimization Notes 257 | 258 |
259 | 260 | ### PyTorch HGEMM Profile 261 | 262 | 在Ada架构下,PyTorch 2.4对FP16使用matmul时,会调用: 263 | ```C++ 264 | ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_stages_32x1_nn_kernel 265 | ``` 266 | 内部实际使用HMMA(Tensor Cores)进行计算,在3080上profile发现使用: 267 | ```C++ 268 | sm80_xmma_gemm_f16f16_f16f32_f32_nn_n_tilesize96x64x32_stage3_warpsize2x2x1_tensor16x8x16_kernel 269 | ``` 270 | 因此,只有实现使用Tensor Cores的HGEMM,才有可能接近PyTorch/cuBLAS的性能。 271 | ```bash 272 | ncu -o hgemm.prof -f python3 bench/prof.py 273 | nsys profile --stats=true -t cuda,osrt,nvtx -o hgemm.prof --force-overwrite true python3 prof.py 274 | ``` 275 | - SASS (L20) 276 | 277 | ```C 278 | // ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_stages_32x1_nn_kernel 279 | 310 00007f41 37d5b850 LDSM.16.M88.4 R192, [R169+UR8+0x2000] 280 | 311 00007f41 37d5b860 LDSM.16.M88.4 R196, [R169+UR8+0x2800] 281 | 336 00007f41 37d5b9f0 HMMA.1688.F32 R112, R182, R196, R112 282 | ... 283 | ``` 284 | 285 | ### SMEM Padding 286 | 287 | #### Bank Conflicts的产生 288 | 289 | 含义:在访问shared memory时,因多个线程读写同一个Bank中的不同数据地址时,导致shared memory 并发读写 退化 成顺序读写的现象叫做Bank Conflict; 290 | 291 | ![](https://github.com/PaddleJitLab/CUDATutorial/blob/develop/docs/09_optimize_reduce/02_bank_conflict/images/ef322be7c3e5b6b9be69d2b90e88083f50569a58a97129f348e483b946ab4edf.png) 292 | 293 | SM调度单位为一个warp(一个warp内32个Thread),shared_memory 可以 被一个warp中的所有(32个)线程进行访问,shared_memory 映射到大小相等的32个Bank上,Bank的数据读取带宽为32bit / cycle (4 bytes),因此,主要需要考虑一个Warp内32线程的访问共享内存时的bank冲突。 294 | 对于多个线程读取同一个Bank数据时(不同地址),硬件把内存读写请求,拆分成 conflict-free requests,进行顺序读写,此时将会触发多次内存事务。特别地,当一个warp中的所有线程读写同一个地址时,会触发broadcast机制,此时不会退化成顺序读写。上面提到触发broadcast机制的条件是all threads acess same address,但在翻阅cuda-c-programming-guide以及最新版本的[NVProfGuide](https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html) 时,发现只要是多个thread 读写就会触发broadcast(不需要All)。 295 | 296 | - 多个线程读同一个数据时,仅有一个线程读,然后broadcast到其他线程 297 | - 多个线程写同一个数据时,仅会有一个线程写成功 298 | 299 | NVIDIA的[文章](https://developer.nvidia.com/blog/using-shared-memory-cuda-cc/)中指出,我们还可以通过 `cudaDeviceSetSharedMemConfig()` 函数设置默认Bank Size(默认为4 bytes)来避免bank conflicts,可设置为cudaSharedMemBankSizeFourByte或者cudaSharedMemBankSizeEightByte。对于某些场景来说,设置cudaSharedMemBankSizeEightByte或许更加合适,比如使用double数据类型时。 300 | 301 | ```C 302 | cudaDeviceSetSharedMemConfig(cudaSharedMemBankSizeEightByte); 303 | ``` 304 | 目前通过 SMEM Padding 和 SMEM swizzle的方式缓解bank conflicts。对于 NN layout,使用 SMEM Padding 缓解 bank conflicts;对于 TN layout,通过cutlass cute的 SMEM Swizzle 消除 bank conflicts。 305 | 306 | ### 双缓冲 Double Buffers 307 | 308 | 本仓库实现的HGEMM Double Buffers策略如下:1)主循环从bk = 1 开始,第一次数据加载在主循环之前,最后一次计算在主循环之后,这是pipeline 的特点决定的;2)由于计算和下一次访存使用的Shared Memory不同,因此主循环中每次循环只需要一次__syncthreads()即可,对比非double buffers版本,总共节省了 ((K + BK - 1) / BK) - 1 次block内的同步操作。比如,bk=1时,HFMA计算使用的是s_a[0]和s_b[0],因此,和s_a[1]和s_b[1]的加载是没有依赖关系的。HFMA计算,从global内存到s_a[1]和s_b[1]和HFMA计算可以并行。s_a[1]和s_b[1]用于加载下一块BK需要的数据到共享内存;3)由于GPU不能向CPU那样支持乱序执行,主循环中需要先将下一次循环计算需要的Gloabal Memory中的数据load 到寄存器,然后进行本次计算,之后再将load到寄存器中的数据写到Shared Memory,这样在LDG指令向Global Memory做load时,不会影响后续HFMA及其它运算指令的 launch 执行,也就达到了Double Buffers的目的,具体代码见[hgemm.cu](./hgemm.cu)。 309 | 310 | 311 | ### Tile Block 312 | 313 | TODO 314 | 315 | ### Tile Thread 316 | 317 | TODO 318 | 319 | ### Pack LDST 128 bits 320 | 321 | TODO 322 | 323 | ### Async Copy 324 | 325 | TODO 326 | 327 | ### Multi Stages 328 | 329 | TODO 330 | 331 | ### Tensor Cores(WMMA/MMA) 332 | 333 | TODO 334 | 335 | ### Tile MMA/Warp 336 | 337 | TODO 338 | 339 | ### Thread Block Swizze 340 | 341 | TODO 342 | 343 | ### Warp Swizzle 344 | 345 | TODO 346 | 347 | ### 
Reg Double Buffers 348 | 349 | TODO 350 | 351 | ### Collective Store(Reg Reuse&Warp Shuffle) 352 | 353 | TODO 354 | 355 | ### SMEM Swizzle/Permuted 356 | 357 | TODO 358 | 359 |
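Until this note is filled in, here is a minimal sketch of the idea that `tools/print_swizzle_layout.py` visualizes: the column offset of each 128-bit (8-half) chunk is XOR-permuted with the row index, so the eight rows fetched by one `ldmatrix` land in different bank groups without spending any shared memory on padding. The parameters below are illustrative and the exact permutation differs per kernel (the CuTe TN path, for example, reports `SMEM SWIZZLE=<3, 3, 3>` in the C++ benchmark).

```C++
// Minimal sketch of an XOR-based permuted shared memory layout for half tiles.
// A "chunk" is 8 halves = 16 bytes, i.e. one 128-bit ldmatrix / cp.async unit.
// col_stride must be a multiple of 8.
__device__ __forceinline__ int swizzle_col(int row, int col, int col_stride) {
  const int kChunk = 8;                          // halves per 128-bit access
  int chunk       = col / kChunk;                // which 16-byte chunk within the row
  int num_chunks  = col_stride / kChunk;         // chunks per logical row
  int swizzled    = chunk ^ (row % num_chunks);  // fold the row id into the chunk id
  return swizzled * kChunk + (col % kChunk);     // position inside the chunk is unchanged
}
```

Usage sketch: index `s_a[row][swizzle_col(row, col, BK)]` instead of `s_a[row][col]`, applying the same permutation on both the cp.async store side and the ldmatrix load side.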
360 | 361 | ## 📖 References 362 | 363 |
364 | 365 | - [flash-attention-minimal](https://github.com/tspeterkim/flash-attention-minimal) 366 | - [tiny-flash-attention](https://github.com/66RING/tiny-flash-attention) 367 | - [cute-gemm](https://github.com/reed-lau/cute-gemm) 368 | - [cutlass_flash_atten_fp8](https://github.com/weishengying/cutlass_flash_atten_fp8) 369 | - [cuda_learning](https://github.com/ifromeast/cuda_learning) 370 | - [cuda_hgemm](https://github.com/Bruce-Lee-LY/cuda_hgemm) 371 | - [cuda-tensorcore-hgemm](https://github.com/nicolaswilde/cuda-tensorcore-hgemm) 372 | - [How_to_optimize_in_GPU](https://github.com/Liu-xiandong/How_to_optimize_in_GPU/tree/master/sgemv) 373 | - [cute_gemm](https://github.com/weishengying/cute_gemm) 374 | - [cutlass](https://github.com/NVIDIA/cutlass) 375 | -------------------------------------------------------------------------------- /kernels/hgemm/README.md: -------------------------------------------------------------------------------- 1 | 2 | ## ⚡️⚡️Toy-HGEMM: Achieve the 98%~100% TFLOPS of cuBLAS 🎉🎉 3 | 4 | ![toy-hgemm-library](https://github.com/user-attachments/assets/962bda14-b494-4423-b8eb-775da9f5503d) 5 | 6 | [📖Toy-HGEMM Library⚡️⚡️](./kernels/hgemm) is a library that implements many HGEMM kernels from scratch using Tensor Cores with the WMMA, MMA PTX and CuTe APIs, and can therefore achieve `98%~100%` of **cuBLAS** performance. The code here is sourced from 📖[CUDA-Learn-Notes](https://github.com/DefTruth/CUDA-Learn-Notes) ![](https://img.shields.io/github/stars/DefTruth/CUDA-Learn-Notes.svg?style=social) and exported as a standalone library; please check out [CUDA-Learn-Notes](https://github.com/DefTruth/CUDA-Learn-Notes) for the latest updates. Welcome to 🌟👆🏻star this repo to support me, many thanks ~ 🎉🎉 7 | 8 |
9 | 10 |
11 | 12 | 13 | 14 |
15 | 16 | 17 | Currently, on NVIDIA L20, RTX 4090 and RTX 3080 Laptop, compared with cuBLAS's default Tensor Cores math algorithm `CUBLAS_GEMM_DEFAULT_TENSOR_OP`, the `HGEMM (WMMA/MMA/CuTe)` implemented in this repo (`blue`🔵) can achieve `98%~100%` of its (`orange`🟠) performance. Please check [toy-hgemm library⚡️⚡️](./kernels/hgemm) for more details. 18 | 19 | |📚Feature |📚Feature |📚Feature |📚Feature| 20 | |:---:|:---:|:---:|:---:| 21 | |✔️CUDA/**Tensor Cores**|✔️Loop over K|✔️Tile Block(BMxBK)|✔️Tile Threads(T 8x8)| 22 | |✔️WMMA(m16n16k16)|✔️MMA(m16n8k16)|✔️Pack LDST(128 bits)|✔️SMEM Padding| 23 | |✔️Copy Async|✔️Tile MMAs|✔️Tile Warps|✔️**Multi Stages(2~4)**| 24 | |✔️Register Double Buffers|✔️**Block Swizzle**|✔️**Warp Swizzle**|✔️**SMEM Swizzle**(CuTe/MMA)| 25 | |✔️Collective Store(Shfl)|✔️Layout NN|✔️Layout TN|✔️SGEMM FP32/TF32| 26 | 27 | ## ©️Citations🎉🎉 28 | 29 | ```BibTeX 30 | @misc{hgemm-tensorcores-mma@2024, 31 | title={hgemm-tensorcores-mma: Write HGEMM from scratch using Tensor Cores with WMMA, MMA PTX and CuTe API.}, 32 | url={https://github.com/DefTruth/hgemm-tensorcores-mma}, 33 | note={Open-source software available at https://github.com/DefTruth/hgemm-tensorcores-mma}, 34 | author={DefTruth etc}, 35 | year={2024} 36 | } 37 | ``` 38 | 39 | ## 📖 HGEMM CUDA Kernels in Toy-HGEMM Library 🎉🎉 40 | 41 |
42 | 43 | ```C++ 44 | void hgemm_naive_f16(torch::Tensor a, torch::Tensor b, torch::Tensor c); 45 | void hgemm_sliced_k_f16(torch::Tensor a, torch::Tensor b, torch::Tensor c); 46 | void hgemm_t_8x8_sliced_k_f16x4(torch::Tensor a, torch::Tensor b, torch::Tensor c); 47 | void hgemm_t_8x8_sliced_k_f16x4_pack(torch::Tensor a, torch::Tensor b, torch::Tensor c); 48 | void hgemm_t_8x8_sliced_k_f16x4_bcf(torch::Tensor a, torch::Tensor b, torch::Tensor c); 49 | void hgemm_t_8x8_sliced_k_f16x4_pack_bcf(torch::Tensor a, torch::Tensor b, torch::Tensor c); 50 | void hgemm_t_8x8_sliced_k_f16x8_pack_bcf(torch::Tensor a, torch::Tensor b, torch::Tensor c); 51 | void hgemm_t_8x8_sliced_k_f16x8_pack_bcf_dbuf(torch::Tensor a, torch::Tensor b, torch::Tensor c); 52 | void hgemm_t_8x8_sliced_k16_f16x8_pack_dbuf(torch::Tensor a, torch::Tensor b, torch::Tensor c); 53 | void hgemm_t_8x8_sliced_k16_f16x8_pack_dbuf_async(torch::Tensor a, torch::Tensor b, torch::Tensor c); 54 | void hgemm_t_8x8_sliced_k32_f16x8_pack_dbuf(torch::Tensor a, torch::Tensor b, torch::Tensor c); 55 | void hgemm_t_8x8_sliced_k32_f16x8_pack_dbuf_async(torch::Tensor a, torch::Tensor b, torch::Tensor c); 56 | void hgemm_t_16x8_sliced_k32_f16x8_pack_dbuf(torch::Tensor a, torch::Tensor b, torch::Tensor c); 57 | void hgemm_t_16x8_sliced_k32_f16x8_pack_dbuf_async(torch::Tensor a, torch::Tensor b, torch::Tensor c); 58 | void hgemm_cublas_tensor_op_nn(torch::Tensor a, torch::Tensor b, torch::Tensor c); 59 | void hgemm_cublas_tensor_op_tn(torch::Tensor a, torch::Tensor b, torch::Tensor c); 60 | void hgemm_wmma_m16n16k16_naive(torch::Tensor a, torch::Tensor b, torch::Tensor c); 61 | void hgemm_wmma_m16n16k16_mma4x2(torch::Tensor a, torch::Tensor b, torch::Tensor c); 62 | void hgemm_wmma_m16n16k16_mma4x2_warp2x4(torch::Tensor a, torch::Tensor b, torch::Tensor c); 63 | void hgemm_wmma_m16n16k16_mma4x2_warp2x4_dbuf_async(torch::Tensor a, torch::Tensor b, torch::Tensor c); 64 | void hgemm_wmma_m32n8k16_mma2x4_warp2x4_dbuf_async(torch::Tensor a, torch::Tensor b, torch::Tensor c); 65 | void hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); 66 | void hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); 67 | void hgemm_wmma_m16n16k16_mma4x2_warp4x4_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); 68 | void hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); 69 | void hgemm_mma_m16n8k16_naive(torch::Tensor a, torch::Tensor b, torch::Tensor c); 70 | void hgemm_mma_m16n8k16_mma2x4_warp4x4(torch::Tensor a, torch::Tensor b, torch::Tensor c); 71 | void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); 72 | void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); 73 | void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); 74 | void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_x4(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); 75 | void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_rr(torch::Tensor a, 
torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); 76 | void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); 77 | void hgemm_mma_stages_block_swizzle_tn_cute(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); 78 | void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); 79 | void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_tn_swizzle_x4(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); 80 | ``` 81 | 82 | ## 📖 Contents 83 | 84 | - [📖 Prerequisites](#prerequisites) 85 | - [📖 Installation](#install) 86 | - [📖 Python Testing](#test) 87 | - [📖 C++ Testing](#test-cpp) 88 | - [📖 NVIDIA L20 bench](#perf-l20) 89 | - [📖 NVIDIA RTX 4090 bench](#perf-4090) 90 | - [📖 NVIDIA RTX 3080 Laptop bench](#perf-3080) 91 | - [📖 Docs](#opt-docs) 92 | - [📖 References](#ref) 93 | 94 | ## 📖 Prerequisites 95 |
96 | 97 | - PyTorch >= 2.0, CUDA >= 12.0 98 | - Recommended: PyTorch 2.5.1, CUDA 12.5 99 | 100 | ## 📖 Installation 101 | 102 |
103 | 104 | The HGEMM kernels implemented in this repo can be installed as a Python library, namely the `toy-hgemm` library (optional). 105 | ```bash 106 | cd kernels/hgemm 107 | git submodule update --init --recursive --force # Fetch `CUTLASS` submodule, needed 108 | python3 setup.py bdist_wheel && cd dist && python3 -m pip install *.whl # pip uninstall toy-hgemm -y 109 | ``` 110 | 111 | ## 📖 Python Testing 112 | 113 |
114 | 115 | **CUTLASS**: Fetch the `CUTLASS` submodule. Currently, I use `v3.5.1` for the HGEMM CuTe kernel. 116 | ```bash 117 | git submodule update --init --recursive --force 118 | ``` 119 | 120 | You can test many custom HGEMM kernels via the Python script and compare their performance. 121 | 122 | ```bash 123 | # You can build for Ada or Ampere only; Volta, Ampere, Ada, Hopper, ... are also supported 124 | export TORCH_CUDA_ARCH_LIST=Ada # for Ada only 125 | export TORCH_CUDA_ARCH_LIST=Ampere # for Ampere only 126 | python3 hgemm.py --wmma # test default wmma kernels for all MNK 127 | python3 hgemm.py --mma # test default mma kernels for all MNK 128 | python3 hgemm.py --M 16384 --N 16384 --K 8192 --i 10 --wmma # test default wmma kernels for specific MNK 129 | python3 hgemm.py --M 16384 --N 16384 --K 8192 --i 10 --mma # test default mma kernels for specific MNK 130 | python3 hgemm.py --wmma-all # test all wmma kernels for all MNK 131 | python3 hgemm.py --mma-all # test all mma kernels for all MNK 132 | python3 hgemm.py --cuda-all --wmma-all --mma-all # test all kernels for all MNK 133 | python3 hgemm.py --cute-tn --no-default # test cute hgemm kernels with smem swizzle for all MNK 134 | ``` 135 | If you want to draw a TFLOPS curve, install `matplotlib` first and set the `--plot-flops` (or `--plot`) option. 136 | ```bash 137 | python3 -m pip install matplotlib 138 | # Specify topk to plot only the top k kernels with the best performance. 139 | python3 hgemm.py --mma-all --plot --topk 8 140 | # test default mma kernels & cute hgemm kernels with smem swizzle for all MNK 141 | python3 hgemm.py --cute-tn --mma --plot 142 | ``` 143 | 144 | ## 📖 C++ Testing 145 | 146 |
147 | 148 | The HGEMM benchmark also supports C++ testing. Currently, it supports comparisons between the following implementations: 149 | 150 | - MMA HGEMM NN implemented in this repository 151 | - CuTe HGEMM TN implemented in this repository 152 | - cuBLAS HGEMM TN use default Tensor Cores math algorithm 153 | 154 | Performance data obtained from C++ binary tests tend to be slightly better than those from Python tests. This difference may be attributed to additional overhead introduced by the PyTorch Python bindings. 155 | ```bash 156 | make 157 | ./hgemm_mma_stage.bin 158 | # NVIDIA L20 159 | ALGO = MMA16816 HGEMM NN MMA=2x4 WARP=4x4x2 STAGES=2 BLOCK SWIZZLE=2048 160 | M N K = 12544 12544 12544, Time = 0.03445555 0.03446098 0.03447399 s, AVG Performance = 114.5541 Tflops 161 | M N K = 15360 15360 15360, Time = 0.06307226 0.06307789 0.06308864 s, AVG Performance = 114.9017 Tflops 162 | M N K = 15616 15616 15616, Time = 0.06612480 0.06612798 0.06613094 s, AVG Performance = 115.1739 Tflops 163 | M N K = 15872 15872 15872, Time = 0.06969549 0.06970215 0.06971290 s, AVG Performance = 114.7305 Tflops 164 | M N K = 16128 16128 16128, Time = 0.07295078 0.07295406 0.07295693 s, AVG Performance = 115.0064 Tflops 165 | M N K = 16384 16384 16384, Time = 0.07663001 0.07663534 0.07664947 s, AVG Performance = 114.7785 Tflops 166 | 167 | ./hgemm_cute.bin 168 | # NVIDIA L20 169 | ALGO = CuTe HGEMM, TN, STAGES=2, SMEM SWIZZLE=<3, 3, 3>, BLOCK SWIZZLE=2048 170 | M N K = 12544 12544 12544, Time = 0.03413504 0.03414354 0.03415450 s, AVG Performance = 115.6191 Tflops 171 | M N K = 15360 15360 15360, Time = 0.06227354 0.06228111 0.06228992 s, AVG Performance = 116.3717 Tflops 172 | M N K = 15616 15616 15616, Time = 0.06492467 0.06493727 0.06496666 s, AVG Performance = 117.2858 Tflops 173 | M N K = 15872 15872 15872, Time = 0.06843085 0.06843873 0.06844723 s, AVG Performance = 116.8485 Tflops 174 | M N K = 16128 16128 16128, Time = 0.07200256 0.07200881 0.07201792 s, AVG Performance = 116.5161 Tflops 175 | M N K = 16384 16384 16384, Time = 0.07564493 0.07565752 0.07567462 s, AVG Performance = 116.2620 Tflops 176 | 177 | ./hgemm_cublas.bin 178 | # NVIDIA L20 179 | ALGO = cuBLAS CUBLAS_GEMM_DEFAULT_TENSOR_OP TN 180 | M N K = 12544 12544 12544, Time = 0.03472691 0.03472968 0.03473408 s, AVG Performance = 113.6678 Tflops 181 | M N K = 15360 15360 15360, Time = 0.06332416 0.06333143 0.06334157 s, AVG Performance = 114.4417 Tflops 182 | M N K = 15616 15616 15616, Time = 0.06649446 0.06650184 0.06651699 s, AVG Performance = 114.5264 Tflops 183 | M N K = 15872 15872 15872, Time = 0.06977024 0.06977659 0.06978355 s, AVG Performance = 114.6081 Tflops 184 | M N K = 16128 16128 16128, Time = 0.07319142 0.07320709 0.07326925 s, AVG Performance = 114.6089 Tflops 185 | M N K = 16384 16384 16384, Time = 0.07668429 0.07669371 0.07670784 s, AVG Performance = 114.6912 Tflops 186 | ``` 187 | 188 | ## 📖 Benchmark 189 | 190 |
191 | 
192 | ### 📖 NVIDIA L20
193 | 
196 | On the L20 (with a theoretical Tensor Cores FP16 performance of 119.5 TFLOPS), the current best implementation achieves approximately 99%~100%+ of cuBLAS performance.
197 | 
198 | - Using the WMMA API, it can achieve around 95%~98% of cuBLAS performance (105-113 TFLOPS vs 105-115 TFLOPS).
199 | - Using the MMA API, it can reach 115 TFLOPS, surpassing cuBLAS in some cases.
200 | - The CuTe version of HGEMM implements Block Swizzle (L2 Cache friendly) and SMEM Swizzle (bank-conflict free), achieving the best performance. For large-scale matrix multiplication, it can reach 116-117 TFLOPS, which is approximately 98%~100%+ of cuBLAS performance, and it outperforms cuBLAS in many cases.
201 | 
202 | Currently, SMEM Padding and SMEM Swizzle are used to mitigate bank conflicts:
203 | 
204 | - For the NN layout, SMEM Padding is used to alleviate bank conflicts (a short sketch of this idea follows below).
205 | - For the TN layout, CUTLASS/CuTe's SMEM Swizzle is used to eliminate bank conflicts.
206 | 
207 | 
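The following is a minimal sketch of the padding idea used for the NN layout. The tile sizes and the demo kernel name are hypothetical, chosen only for illustration; this is not the repository's exact kernel code.

```C++
#include <cuda_fp16.h>

// Shared memory is organized as 32 banks of 4 bytes. When several lanes of a
// warp touch different addresses that map to the same bank, the access is
// serialized. Padding the inner dimension of a shared-memory tile changes the
// bank each row starts in, which spreads column-wise accesses over more banks
// and reduces conflicts.
template <int BM = 128, int BK = 16, int A_PAD = 8>
__global__ void padded_smem_tile_demo() {
  __shared__ half s_a[BM][BK + A_PAD];  // row stride 48 B instead of 32 B
  int row = threadIdx.y, col = threadIdx.x;
  if (row < BM && col < BK) {
    s_a[row][col] = __float2half(0.0f);  // the padded columns are never used
  }
}
```

The MMA kernels in this repository expose the same idea through their `A_PAD`/`B_PAD` template parameters, indexing shared memory with a `(BK + A_PAD)` / `(BK + B_PAD)` row stride.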
208 | 
209 | 
210 | ![NVIDIA_L20_NN+TN+v2](https://github.com/user-attachments/assets/71927ac9-72b3-4ce9-b0e2-788b5885bc99)
211 | 
212 | 
213 | The command below tests all MNK setups (tip: performance data measured for each MNK individually is more accurate):
214 | ```bash
215 | python3 hgemm.py --cute-tn --mma --plot
216 | ```
217 | 
218 | ### 📖 NVIDIA GeForce RTX 4090
219 | 
220 | 
221 | 
222 | 
225 | 
226 | On the NVIDIA RTX 4090 (with an FP16 Tensor Cores performance of 330 TFLOPS), the WMMA (m16n16k16) implementation performs better than the MMA (m16n8k16) one. For most MNK configurations, this repository's implementation achieves 95%~99% of cuBLAS performance, and in certain cases it can surpass cuBLAS. Specifically:
227 | 
228 | - For large-scale matrix multiplications (MNK >= 8192), the WMMA implementation performs better.
229 | - For small-scale matrix multiplications, the MMA implementation is more efficient.
230 | 
231 | 
232 | ![NVIDIA_GeForce_RTX_4090_NN+TN+v4](https://github.com/user-attachments/assets/05ef4f5e-d999-48ea-b58e-782cffb24e85)
233 | 
234 | ```bash
235 | python3 hgemm.py --cute-tn --mma --wmma-all --plot
236 | ```
237 | 
238 | ### 📖 NVIDIA GeForce RTX 3080 Laptop
239 | 
240 | 
241 | 
242 | 
245 | Testing was conducted on an NVIDIA GeForce RTX 3080 Laptop using the mma4x4_warp4x4 configuration (16 WMMA m16n16k16 operations with a warp tile size of 64x64) together with thread block swizzle. In most cases, this setup matches or even exceeds cuBLAS performance. The tests were run under Windows WSL2 on the RTX 3080 Laptop.
246 | 
247 | ![image](https://github.com/user-attachments/assets/9472e970-c083-4b31-9252-3eeecc761078)
248 | 
249 | ```bash
250 | python3 hgemm.py --wmma-all --plot
251 | ```
252 | 
253 | 
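For reference, the TFLOPS figures reported above follow the usual GEMM convention used by both the C++ and Python benchmarks in this repository: one MxNxK half-precision GEMM counts as 2*M*N*K floating-point operations. A small helper in the same spirit:

```C++
// TFLOPS as computed by the benchmarks in this repo:
// 2*M*N*K FLOPs (one multiply + one add per MAC), divided by elapsed seconds.
static inline double gemm_tflops(int M, int N, int K, double seconds) {
  return 2.0 * (double)M * (double)N * (double)K * 1e-12 / seconds;
}
```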
254 | 🔑️ Performance Optimization Notes (TODO)
255 | 
256 | ## 📖 Performance Optimization Notes
257 | 
258 | 
259 | 
260 | ### PyTorch HGEMM Profile
261 | 
262 | On the Ada architecture, PyTorch 2.4 calls the following kernel when matmul is used with FP16:
263 | ```C++
264 | ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_stages_32x1_nn_kernel
265 | ```
266 | Internally it uses HMMA (Tensor Cores) for the actual computation; profiling on an RTX 3080 shows the following kernel being used:
267 | ```C++
268 | sm80_xmma_gemm_f16f16_f16f32_f32_nn_n_tilesize96x64x32_stage3_warpsize2x2x1_tensor16x8x16_kernel
269 | ```
270 | Therefore, only an HGEMM implementation that uses Tensor Cores has a chance of approaching PyTorch/cuBLAS performance.
271 | ```bash
272 | ncu -o hgemm.prof -f python3 bench/prof.py
273 | nsys profile --stats=true -t cuda,osrt,nvtx -o hgemm.prof --force-overwrite true python3 prof.py
274 | ```
275 | - SASS (L20)
276 | 
277 | ```C
278 | // ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_stages_32x1_nn_kernel
279 | 310 00007f41 37d5b850 LDSM.16.M88.4 R192, [R169+UR8+0x2000]
280 | 311 00007f41 37d5b860 LDSM.16.M88.4 R196, [R169+UR8+0x2800]
281 | 336 00007f41 37d5b9f0 HMMA.1688.F32 R112, R182, R196, R112
282 | ...
283 | ```
284 | 
285 | ### SMEM Padding
286 | 
287 | #### How Bank Conflicts Arise
288 | 
289 | Definition: when accessing shared memory, if multiple threads read or write different addresses that fall into the same bank, the concurrent shared-memory access degrades into serialized access; this phenomenon is called a bank conflict.
290 | 
291 | ![](https://github.com/PaddleJitLab/CUDATutorial/blob/develop/docs/09_optimize_reduce/02_bank_conflict/images/ef322be7c3e5b6b9be69d2b90e88083f50569a58a97129f348e483b946ab4edf.png)
292 | 
293 | The SM schedules work at warp granularity (32 threads per warp). Shared memory can be accessed by all 32 threads of a warp and is mapped onto 32 equally sized banks, each with a read bandwidth of 32 bits (4 bytes) per cycle. Therefore, the bank conflicts that matter are mainly those among the 32 threads of a single warp when they access shared memory.
294 | When multiple threads access different addresses in the same bank, the hardware splits the request into conflict-free requests that are served sequentially, triggering multiple memory transactions. As a special case, when all threads in a warp read the same address, a broadcast is triggered and the access does not degrade into serialized reads. The condition stated above for broadcast is "all threads access the same address", but the CUDA C Programming Guide and the latest [NVProfGuide](https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html) indicate that broadcast is triggered whenever multiple threads access the same address (it does not require all of them).
295 | 
296 | - When multiple threads read the same data, only one thread performs the read and the value is broadcast to the others.
297 | - When multiple threads write the same data, only one thread's write succeeds.
298 | 
299 | NVIDIA's [article](https://developer.nvidia.com/blog/using-shared-memory-cuda-cc/) points out that bank conflicts can also be avoided by changing the default bank size (4 bytes by default) with `cudaDeviceSetSharedMemConfig()`, which accepts cudaSharedMemBankSizeFourByte or cudaSharedMemBankSizeEightByte. For some scenarios, cudaSharedMemBankSizeEightByte may be the better choice, for example when working with the double data type.
300 | 
301 | ```C
302 | cudaDeviceSetSharedMemConfig(cudaSharedMemBankSizeEightByte);
303 | ```
304 | Currently, SMEM Padding and SMEM Swizzle are used to mitigate bank conflicts: for the NN layout, SMEM Padding alleviates bank conflicts; for the TN layout, CUTLASS/CuTe's SMEM Swizzle eliminates bank conflicts.
305 | 
306 | ### Double Buffers
307 | 
308 | The double-buffering strategy implemented by the HGEMM kernels in this repository is as follows (a simplified sketch is given at the end of these notes): 1) The main loop starts at bk = 1; the first data load happens before the main loop and the last computation happens after it, which follows from the pipelined structure. 2) Because the computation and the next load use different shared-memory buffers, each iteration of the main loop needs only one __syncthreads(). Compared with the non-double-buffered version, this saves ((K + BK - 1) / BK) - 1 intra-block synchronizations in total. For example, when bk = 1, the HFMA computation uses s_a[0] and s_b[0], so it has no dependency on the loads into s_a[1] and s_b[1]; the loads from global memory into s_a[1]/s_b[1] can therefore run in parallel with the HFMA computation, and s_a[1]/s_b[1] stage the data needed for the next BK tile. 3) Since a GPU does not support out-of-order execution the way a CPU does, the main loop first loads the data needed by the next iteration from global memory into registers, then performs the current iteration's computation, and only afterwards writes the staged registers into shared memory. This way, the LDG instructions loading from global memory do not block the issue of the subsequent HFMA and other compute instructions, which is exactly the purpose of double buffering. See [hgemm.cu](./hgemm.cu) for the implementation details.
309 | 
310 | 
311 | ### Tile Block
312 | 
313 | TODO
314 | 
315 | ### Tile Thread
316 | 
317 | TODO
318 | 
319 | ### Pack LDST 128 bits
320 | 
321 | TODO
322 | 
323 | ### Async Copy
324 | 
325 | TODO
326 | 
327 | ### Multi Stages
328 | 
329 | TODO
330 | 
331 | ### Tensor Cores (WMMA/MMA)
332 | 
333 | TODO
334 | 
335 | ### Tile MMA/Warp
336 | 
337 | TODO
338 | 
339 | ### Thread Block Swizzle
340 | 
341 | TODO
342 | 
343 | ### Warp Swizzle
344 | 
345 | TODO
346 | 
347 | ### Reg Double Buffers
348 | 
349 | TODO
350 | 
351 | ### Collective Store (Reg Reuse & Warp Shuffle)
352 | 
353 | TODO
354 | 
355 | ### SMEM Swizzle/Permuted
356 | 
357 | TODO
358 | 
359 | 
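To make the double-buffering pipeline described in the Double Buffers notes above more concrete, here is a heavily simplified sketch. The tile sizes are hypothetical and the body is reduced to comments; the real kernels in this repository use cp.async, ldmatrix and MMA/WMMA instead of scalar code.

```C++
#include <cuda_fp16.h>

// Minimal double-buffer skeleton: two shared-memory buffers per operand,
// one being computed on while the other is filled for the next K tile.
template <int BM = 64, int BN = 64, int BK = 16>
__global__ void double_buffer_sketch(const half* A, const half* B, half* C,
                                     int M, int N, int K) {
  __shared__ half s_a[2][BM][BK];
  __shared__ half s_b[2][BK][BN];
  // (0) prologue: load the first K tile (bk = 0) into buffer 0, then sync.
  __syncthreads();
  for (int bk = 1; bk < (K + BK - 1) / BK; ++bk) {
    int cur = (bk - 1) & 1;  // buffer holding the tile being computed on now
    int nxt = bk & 1;        // buffer being filled for the next iteration
    // (1) issue global loads of tile `bk` into registers (not smem yet),
    // (2) run the FMA/MMA work on s_a[cur] / s_b[cur],
    // (3) write the staged registers into s_a[nxt] / s_b[nxt].
    // Compute and the next fill touch different buffers, so a single
    // __syncthreads() per iteration is enough.
    __syncthreads();
  }
  // (4) epilogue: compute on the last buffered tile and store results to C.
}
```

The `dbuf` CUDA-core kernels exposed by hgemm.py follow this pattern, combined with 128-bit packed loads and stores.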
360 | 361 | ## 📖 References 362 | 363 |
364 | 365 | - [flash-attention-minimal](https://github.com/tspeterkim/flash-attention-minimal) 366 | - [tiny-flash-attention](https://github.com/66RING/tiny-flash-attention) 367 | - [cute-gemm](https://github.com/reed-lau/cute-gemm) 368 | - [cutlass_flash_atten_fp8](https://github.com/weishengying/cutlass_flash_atten_fp8) 369 | - [cuda_learning](https://github.com/ifromeast/cuda_learning) 370 | - [cuda_hgemm](https://github.com/Bruce-Lee-LY/cuda_hgemm) 371 | - [cuda-tensorcore-hgemm](https://github.com/nicolaswilde/cuda-tensorcore-hgemm) 372 | - [How_to_optimize_in_GPU](https://github.com/Liu-xiandong/How_to_optimize_in_GPU/tree/master/sgemv) 373 | - [cute_gemm](https://github.com/weishengying/cute_gemm) 374 | - [cutlass](https://github.com/NVIDIA/cutlass) 375 | -------------------------------------------------------------------------------- /kernels/hgemm/mma/hgemm_mma_stage_tn.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | using namespace nvcuda; 14 | 15 | #define WARP_SIZE 32 16 | #define DEVICE_INLINE __device__ inline 17 | #define HOST_DEVICE_INLINE __device__ __host__ inline 18 | #define INT4(value) (reinterpret_cast(&(value))[0]) 19 | #define FLOAT4(value) (reinterpret_cast(&(value))[0]) 20 | #define HALF2(value) (reinterpret_cast(&(value))[0]) 21 | #define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0]) 22 | #define LDST32BITS(value) (reinterpret_cast(&(value))[0]) 23 | #define LDST64BITS(value) (reinterpret_cast(&(value))[0]) 24 | #define LDST128BITS(value) (reinterpret_cast(&(value))[0]) 25 | // gmem -> smem 26 | #define CP_ASYNC_COMMIT_GROUP() asm volatile("cp.async.commit_group;\n" ::) 27 | #define CP_ASYNC_WAIT_ALL() asm volatile("cp.async.wait_all;\n" ::) 28 | #define CP_ASYNC_WAIT_GROUP(n) asm volatile("cp.async.wait_group %0;\n" ::"n"(n)) 29 | // ca(cache all, L1 + L2): support 4, 8, 16 bytes, cg(cache global, L2): only support 16 bytes. 30 | #define CP_ASYNC_CA(dst, src, bytes) asm volatile("cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes)) 31 | #define CP_ASYNC_CG(dst, src, bytes) asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes)) 32 | // smem -> gmem: requires sm_90 or higher. 
33 | #define CP_ASYNC_BULK_COMMIT_GROUP() asm volatile("cp.async.bulk.commit_group;\n" ::) 34 | #define CP_ASYNC_BULK_WAIT_ALL() asm volatile("cp.async.bulk.wait_all;\n" ::) 35 | #define CP_ASYNC_BULK_WAIT_GROUP(n) asm volatile("cp.async.bulk.wait_group %0;\n" ::"n"(n)) 36 | #define CP_ASYNC_BULK(dst, src, bytes) asm volatile("cp.async.bulk.global.shared::cta.bulk_group.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes)) 37 | // ldmatrix 38 | #define LDMATRIX_X1(R, addr) asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n" : "=r"(R) : "r"(addr)) 39 | #define LDMATRIX_X2(R0, R1, addr) asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" : "=r"(R0), "=r"(R1) : "r"(addr)) 40 | #define LDMATRIX_X4(R0, R1, R2, R3, addr) asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(R0), "=r"(R1), "=r"(R2), "=r"(R3) : "r"(addr)) 41 | #define LDMATRIX_X1_T(R, addr) asm volatile("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n" : "=r"(R) : "r"(addr)) 42 | #define LDMATRIX_X2_T(R0, R1, addr) asm volatile("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n" : "=r"(R0), "=r"(R1) : "r"(addr)) 43 | #define LDMATRIX_X4_T(R0, R1, R2, R3, addr) asm volatile("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(R0), "=r"(R1), "=r"(R2), "=r"(R3) : "r"(addr)) 44 | // stmatrix: requires sm_90 or higher. 45 | #define STMATRIX_X1(addr, R) asm volatile("stmatrix.sync.aligned.x1.m8n8.shared.b16 [%0], {%1};\n" :: "r"(addr), "r"(R)) 46 | #define STMATRIX_X2(addr, R0, R1) asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n" :: "r"(addr), "r"(R0), "r"(R1)) 47 | #define STMATRIX_X4(addr, R0, R1, R2, R3) asm volatile("stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" :: "r"(addr), "r"(R0), "r"(R1), "r"(R2), "r"(R3)) 48 | #define STMATRIX_X1_T(addr, R) asm volatile("stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [%0], {%1};\n" :: "r"(addr), "r"(R)) 49 | #define STMATRIX_X2_T(addr, R0, R1) asm volatile("stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [%0], {%1, %2};\n" :: "r"(addr), "r"(R0), "r"(R1)) 50 | #define STMATRIX_X4_T(addr, R0, R1, R2, R3) asm volatile("stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" :: "r"(addr), "r"(R0), "r"(R1), "r"(R2), "r"(R3)) 51 | // mma m16n8k16 52 | #define HMMA16816(RD0, RD1, RA0, RA1, RA2, RA3, RB0, RB1, RC0, RC1) asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\n" : "=r"(RD0), "=r"(RD1) : "r"(RA0), "r"(RA1), "r"(RA2), "r"(RA3), "r"(RB0), "r"(RB1), "r"(RC0), "r"(RC1)) 53 | 54 | HOST_DEVICE_INLINE 55 | int div_ceil(int a, int b) { return (a % b != 0) ? (a / b + 1) : (a / b); } 56 | 57 | // NN: A/B/C All row major 58 | // TN: A row major MxK, B col major NxK, C row major MxN 59 | // 128x128, mma2x4, warp4x4(64,32,16), stages, block swizzle, dsmem 60 | template 71 | __global__ void __launch_bounds__(256) 72 | hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn_kernel( 73 | half* A, half* B, half* C, int M, int N, int K) { 74 | // BLOCK_SWIZZLE 0/1 control use block swizzle or not. 
75 | const int bx = ((int) BLOCK_SWIZZLE) * blockIdx.z * gridDim.x + blockIdx.x; 76 | const int by = blockIdx.y; 77 | const int NUM_K_TILES = div_ceil(K, MMA_K); 78 | constexpr int BM = MMA_M * MMA_TILE_M * WARP_TILE_M; // 16*2*4=128 79 | constexpr int BN = MMA_N * MMA_TILE_N * WARP_TILE_N; // 8*4*4=128 80 | constexpr int BK = MMA_K; // 16 81 | 82 | extern __shared__ half smem[]; 83 | half* s_a = smem; 84 | half* s_b = smem + K_STAGE * BM * (BK + A_PAD); 85 | constexpr int s_a_stage_offset = BM * (BK + A_PAD); // BMxBK 128*16 86 | constexpr int s_b_stage_offset = BN * (BK + B_PAD); // BNxBK 128*16 87 | 88 | const int tid = threadIdx.y * blockDim.x + threadIdx.x; // within block 89 | const int warp_id = tid / WARP_SIZE; // 0~7 warp_id within block 90 | const int lane_id = tid % WARP_SIZE; // 0~31 91 | const int warp_m = warp_id % 2; // 0,1 92 | const int warp_n = warp_id / 2; // 0,1,2,3 93 | 94 | int load_smem_a_m = tid / 2; // row 0~127 95 | int load_smem_a_k = (tid % 2 == 0) ? 0 : 8; // col 0,8 96 | int load_smem_b_n = tid / 2; // row 0~127 97 | int load_smem_b_k = (tid % 2 == 0) ? 0 : 8; // col 0,8 98 | int load_gmem_a_m = by * BM + load_smem_a_m; // global row of c 99 | int load_gmem_b_n = bx * BN + load_smem_b_n; // global col of c 100 | if (load_gmem_a_m >= M || load_gmem_b_n >= N) return; 101 | 102 | uint32_t RC[WARP_TILE_M][WARP_TILE_N][2]; 103 | #pragma unroll 104 | for (int i = 0; i < WARP_TILE_M; ++i) { 105 | #pragma unroll 106 | for (int j = 0; j < WARP_TILE_N; ++j) { 107 | RC[i][j][0] = 0; 108 | RC[i][j][1] = 0; 109 | } 110 | } 111 | 112 | // may avoid cvta overhead ? only cvta smem base ptr once for cp.async. 113 | uint32_t smem_a_base_ptr = __cvta_generic_to_shared(s_a); 114 | uint32_t smem_b_base_ptr = __cvta_generic_to_shared(s_b); 115 | 116 | #pragma unroll 117 | for (int k = 0; k < (K_STAGE - 1); ++k) { // 0, 1 118 | int load_gmem_a_k = k * BK + load_smem_a_k; // global col of a 119 | int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k; 120 | int load_gmem_b_k = k * BK + load_smem_b_k; // global col of b 121 | int load_gmem_b_addr = load_gmem_b_n * K + load_gmem_b_k; 122 | 123 | uint32_t load_smem_a_ptr = ( 124 | smem_a_base_ptr + (k * s_a_stage_offset + 125 | load_smem_a_m * (BK + A_PAD) + 126 | load_smem_a_k) * sizeof(half) 127 | ); 128 | CP_ASYNC_CG(load_smem_a_ptr, &A[load_gmem_a_addr], 16); 129 | 130 | uint32_t load_smem_b_ptr = ( 131 | smem_b_base_ptr + (k * s_b_stage_offset + 132 | load_smem_b_n * (BK + B_PAD) + 133 | load_smem_b_k) * sizeof(half) 134 | ); 135 | CP_ASYNC_CG(load_smem_b_ptr, &B[load_gmem_b_addr], 16); 136 | 137 | CP_ASYNC_COMMIT_GROUP(); 138 | } 139 | 140 | CP_ASYNC_WAIT_GROUP(K_STAGE-2); // s2->0, s3->1, s4->2 141 | __syncthreads(); 142 | 143 | #pragma unroll 144 | for (int k = (K_STAGE - 1); k < NUM_K_TILES; ++k) { 145 | // gmem -> smem 146 | // s2/4 can use bitwise ops but s3 can not, so, we use mod 147 | // ops for all stages kernel. s2: (k + 1)&1, s4: (k + 1)&3 148 | // s3: (k + 1) % 3 149 | int smem_sel = (k + 1) % K_STAGE; // s3 k 2->0, k 3->1, k 4->2... 150 | int smem_sel_next = k % K_STAGE; // s3 k 2->2, k 3->0, k 4->1... 
151 | 152 | int load_gmem_a_k = k * BK + load_smem_a_k; // global col of a 153 | int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k; 154 | int load_gmem_b_k = k * BK + load_smem_b_k; // global col of b 155 | int load_gmem_b_addr = load_gmem_b_n * K + load_gmem_b_k; 156 | 157 | uint32_t load_smem_a_ptr = ( 158 | smem_a_base_ptr + (smem_sel_next * s_a_stage_offset + 159 | load_smem_a_m * (BK + A_PAD) + 160 | load_smem_a_k) * sizeof(half) 161 | ); 162 | CP_ASYNC_CG(load_smem_a_ptr, &A[load_gmem_a_addr], 16); 163 | 164 | uint32_t load_smem_b_ptr = ( 165 | smem_b_base_ptr + (smem_sel_next * s_b_stage_offset + 166 | load_smem_b_n * (BK + B_PAD) + 167 | load_smem_b_k) * sizeof(half) 168 | ); 169 | CP_ASYNC_CG(load_smem_b_ptr, &B[load_gmem_b_addr], 16); 170 | 171 | CP_ASYNC_COMMIT_GROUP(); 172 | 173 | uint32_t RA[WARP_TILE_M][4]; 174 | uint32_t RB[WARP_TILE_N][2]; 175 | // smem -> reg 176 | #pragma unroll 177 | for (int i = 0; i < WARP_TILE_M; ++i) { 178 | int warp_smem_a_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M; 179 | int lane_smem_a_m = warp_smem_a_m + lane_id % 16; // 0~15 180 | int lane_smem_a_k = (lane_id / 16) * 8; // 0,8 181 | uint32_t lane_smem_a_ptr = ( 182 | smem_a_base_ptr + (smem_sel * s_a_stage_offset + 183 | lane_smem_a_m * (BK + A_PAD) + 184 | lane_smem_a_k) * sizeof(half) 185 | ); 186 | LDMATRIX_X4(RA[i][0], RA[i][1], RA[i][2], RA[i][3], lane_smem_a_ptr); 187 | } 188 | 189 | #pragma unroll 190 | for (int j = 0; j < WARP_TILE_N; ++j) { 191 | int warp_smem_b_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N; 192 | int lane_smem_b_n = warp_smem_b_n + lane_id % 8; // 0~7, MMA_N=8 193 | int lane_smem_b_k = ((lane_id / 8) % 2) * 8; // 0,8 194 | uint32_t lane_smem_b_ptr = ( 195 | smem_b_base_ptr + (smem_sel * s_b_stage_offset + 196 | lane_smem_b_n * (BK + B_PAD) + 197 | lane_smem_b_k) * sizeof(half) 198 | ); 199 | LDMATRIX_X2(RB[j][0], RB[j][1], lane_smem_b_ptr); 200 | } 201 | 202 | // MMA compute 203 | #pragma unroll 204 | for (int i = 0; i < WARP_TILE_M; ++i) { 205 | #pragma unroll 206 | for (int j = 0; j < WARP_TILE_N; ++j) { 207 | HMMA16816(RC[i][j][0], RC[i][j][1], 208 | RA[i][0], RA[i][1], RA[i][2], RA[i][3], 209 | RB[j][0], RB[j][1], 210 | RC[i][j][0], RC[i][j][1]); 211 | } 212 | } 213 | 214 | CP_ASYNC_WAIT_GROUP(K_STAGE-2); 215 | __syncthreads(); 216 | } 217 | 218 | // make sure all memory issues ready. 219 | if ((K_STAGE - 2) > 0) { 220 | CP_ASYNC_WAIT_GROUP(0); 221 | __syncthreads(); 222 | } 223 | 224 | // processing last (K_STAGE-1) k iters. 225 | { 226 | #pragma unroll 227 | for (int k = 0; k < (K_STAGE - 1); k++) { 228 | uint32_t RA[WARP_TILE_M][4]; 229 | uint32_t RB[WARP_TILE_N][2]; 230 | 231 | int stage_sel = ((NUM_K_TILES - (K_STAGE - 1) + k) % K_STAGE); 232 | // ldmatrix for s_a, ldmatrix.trans for s_b. 
233 | // smem -> reg 234 | #pragma unroll 235 | for (int i = 0; i < WARP_TILE_M; ++i) { 236 | int warp_smem_a_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M; 237 | int lane_smem_a_m = warp_smem_a_m + lane_id % 16; // 0~15 238 | int lane_smem_a_k = (lane_id / 16) * 8; // 0,8 239 | uint32_t lane_smem_a_ptr = ( 240 | smem_a_base_ptr + (stage_sel * s_a_stage_offset + 241 | lane_smem_a_m * (BK + A_PAD) + 242 | lane_smem_a_k) * sizeof(half) 243 | ); 244 | LDMATRIX_X4(RA[i][0], RA[i][1], RA[i][2], RA[i][3], lane_smem_a_ptr); 245 | } 246 | 247 | #pragma unroll 248 | for (int j = 0; j < WARP_TILE_N; ++j) { 249 | int warp_smem_b_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N; 250 | int lane_smem_b_n = warp_smem_b_n + lane_id % 8; // 0~7, MMA_N=8 251 | int lane_smem_b_k = ((lane_id / 8) % 2) * 8; // 0,8 252 | uint32_t lane_smem_b_ptr = ( 253 | smem_b_base_ptr + (stage_sel * s_b_stage_offset + 254 | lane_smem_b_n * (BK + B_PAD) + 255 | lane_smem_b_k) * sizeof(half) 256 | ); 257 | LDMATRIX_X2(RB[j][0], RB[j][1], lane_smem_b_ptr); 258 | } 259 | 260 | // MMA compute 261 | #pragma unroll 262 | for (int i = 0; i < WARP_TILE_M; ++i) { 263 | #pragma unroll 264 | for (int j = 0; j < WARP_TILE_N; ++j) { 265 | HMMA16816(RC[i][j][0], RC[i][j][1], 266 | RA[i][0], RA[i][1], RA[i][2], RA[i][3], 267 | RB[j][0], RB[j][1], 268 | RC[i][j][0], RC[i][j][1]); 269 | } 270 | } 271 | } 272 | } 273 | 274 | { 275 | for (int i = 0; i < WARP_TILE_M; ++i) { 276 | // How to use LDST128BITS here? __shfl_sync -> lane 0 -> store 8 half. 277 | // thus, we only need 8 memory issues with 128 bits after shfl_sync. 278 | // may reuse RA[4][4] as RC0 ? only new RC1[4][4]. 279 | uint32_t RC0[WARP_TILE_N][4]; 280 | uint32_t RC1[WARP_TILE_N][4]; 281 | #pragma unroll 282 | for (int j = 0; j < WARP_TILE_N; ++j) { 283 | // How to use LDST128BITS here? __shfl_sync -> lane 0 -> store 8 half. 284 | // thus, we only need 8 memory issues with 128 bits after shfl_sync. 
285 | RC0[j][0] = RC[i][j][0]; 286 | RC1[j][0] = RC[i][j][1]; 287 | RC0[j][1] = __shfl_sync((0xffffffff), RC[i][j][0], lane_id + 1); 288 | RC0[j][2] = __shfl_sync((0xffffffff), RC[i][j][0], lane_id + 2); 289 | RC0[j][3] = __shfl_sync((0xffffffff), RC[i][j][0], lane_id + 3); 290 | RC1[j][1] = __shfl_sync((0xffffffff), RC[i][j][1], lane_id + 1); 291 | RC1[j][2] = __shfl_sync((0xffffffff), RC[i][j][1], lane_id + 2); 292 | RC1[j][3] = __shfl_sync((0xffffffff), RC[i][j][1], lane_id + 3); 293 | } 294 | 295 | if (lane_id % 4 == 0) { 296 | int store_warp_smem_c_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M; 297 | int store_lane_gmem_c_m = by * BM + store_warp_smem_c_m + lane_id / 4; 298 | #pragma unroll 299 | for (int j = 0; j < WARP_TILE_N; ++j) { 300 | int store_warp_smem_c_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N; 301 | int store_lane_gmem_c_n = bx * BN + store_warp_smem_c_n; 302 | int store_gmem_c_addr_0 = store_lane_gmem_c_m * N + store_lane_gmem_c_n; 303 | int store_gmem_c_addr_1 = (store_lane_gmem_c_m + 8) * N + store_lane_gmem_c_n; 304 | LDST128BITS(C[store_gmem_c_addr_0]) = LDST128BITS(RC0[j][0]); 305 | LDST128BITS(C[store_gmem_c_addr_1]) = LDST128BITS(RC1[j][0]); 306 | } 307 | } 308 | } 309 | } 310 | } 311 | 312 | // --------------------- PyTorch bindings for custom kernel ----------------------- 313 | #define STRINGFY(str) #str 314 | #define TORCH_BINDING_COMMON_EXTENSION(func) \ 315 | m.def(STRINGFY(func), &func, STRINGFY(func)); 316 | 317 | #define CHECK_TORCH_TENSOR_DTYPE(T, th_type) \ 318 | if(((T).options().dtype() != (th_type))) { \ 319 | std::cout << "Tensor Info:" << (T).options() << std::endl; \ 320 | throw std::runtime_error("values must be "#th_type); \ 321 | } 322 | 323 | #define CHECK_TORCH_TENSOR_SHAPE(T, S0, S1) \ 324 | if (((T).size(0) != (S0)) || ((T).size(1) != (S1))) { \ 325 | throw std::runtime_error("Tensor size mismatch!"); \ 326 | } 327 | 328 | // 128x128, mma2x4, warp4x4(64,32,16), stages, block swizzle, dsmem, TN 329 | #define LAUNCH_16816_STAGE_SWIZZLE_MMA2x4_WARP4x4_DSMEM_TN_KERNEL(stages, stride) \ 330 | { \ 331 | const int smem_max_size = ( \ 332 | (stages) * BM * (BK + A_PAD) * sizeof(half) + \ 333 | (stages) * BN * (BK + B_PAD) * sizeof(half)); \ 334 | cudaFuncSetAttribute( \ 335 | hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn_kernel< \ 336 | MMA_M, MMA_N, MMA_K, MMA_TILE_M, MMA_TILE_N, \ 337 | WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD, (stages), true>, \ 338 | cudaFuncAttributeMaxDynamicSharedMemorySize, \ 339 | 98304); \ 340 | const int N_SWIZZLE = (N + (stride) - 1) / (stride); \ 341 | dim3 block(NUM_THREADS); \ 342 | dim3 grid((div_ceil(N, BN) + N_SWIZZLE - 1) / N_SWIZZLE, \ 343 | div_ceil(M, BM), \ 344 | N_SWIZZLE); \ 345 | hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn_kernel< \ 346 | MMA_M, MMA_N, MMA_K, MMA_TILE_M, MMA_TILE_N, \ 347 | WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD, (stages), true><<< \ 348 | grid, block, smem_max_size>>>( \ 349 | reinterpret_cast(a.data_ptr()), \ 350 | reinterpret_cast(b.data_ptr()), \ 351 | reinterpret_cast(c.data_ptr()), \ 352 | M, N, K \ 353 | ); \ 354 | } 355 | 356 | #define LAUNCH_16816_STAGE_NO_SWIZZLE_MMA2x4_WARP4x4_DSMEM_TN_KERNEL(stages) \ 357 | { \ 358 | const int smem_max_size = ( \ 359 | (stages) * BM * (BK + A_PAD) * sizeof(half) + \ 360 | (stages) * BN * (BK + B_PAD) * sizeof(half)); \ 361 | cudaFuncSetAttribute( \ 362 | hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn_kernel< \ 363 | MMA_M, MMA_N, MMA_K, MMA_TILE_M, MMA_TILE_N, \ 364 | WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD, (stages), false>, \ 
365 | cudaFuncAttributeMaxDynamicSharedMemorySize, \ 366 | 98304); \ 367 | dim3 block(NUM_THREADS); \ 368 | dim3 grid(div_ceil(N, BN), div_ceil(M, BM)); \ 369 | hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn_kernel< \ 370 | MMA_M, MMA_N, MMA_K, MMA_TILE_M, MMA_TILE_N, \ 371 | WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD, (stages), false><<< \ 372 | grid, block, smem_max_size>>>( \ 373 | reinterpret_cast(a.data_ptr()), \ 374 | reinterpret_cast(b.data_ptr()), \ 375 | reinterpret_cast(c.data_ptr()), \ 376 | M, N, K \ 377 | ); \ 378 | } 379 | 380 | // 128x128, mma2x4, warp4x4(64,32,16), stages, block swizzle, dsmem 381 | void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn( 382 | torch::Tensor a, torch::Tensor b, torch::Tensor c, 383 | int stages, bool swizzle, int swizzle_stride) { 384 | CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf) 385 | CHECK_TORCH_TENSOR_DTYPE(b, torch::kHalf) 386 | CHECK_TORCH_TENSOR_DTYPE(c, torch::kHalf) 387 | const int M = a.size(0); 388 | const int K = a.size(1); 389 | const int N = b.size(1); 390 | CHECK_TORCH_TENSOR_SHAPE(a, M, K) 391 | CHECK_TORCH_TENSOR_SHAPE(b, K, N) 392 | CHECK_TORCH_TENSOR_SHAPE(c, M, N) 393 | constexpr int MMA_M = 16; 394 | constexpr int MMA_N = 8; 395 | constexpr int MMA_K = 16; 396 | constexpr int MMA_TILE_M = 2; 397 | constexpr int MMA_TILE_N = 4; 398 | constexpr int WARP_TILE_M = 4; 399 | constexpr int WARP_TILE_N = 4; 400 | constexpr int A_PAD = 0; // 0,8,16 401 | constexpr int B_PAD = 0; // 0,8,16 402 | constexpr int NUM_THREADS= ( 403 | MMA_TILE_M * MMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256 404 | constexpr int BM = MMA_M * MMA_TILE_M * WARP_TILE_M; 405 | constexpr int BN = MMA_N * MMA_TILE_N * WARP_TILE_N; 406 | constexpr int BK = MMA_K; 407 | 408 | if (swizzle) { 409 | // assert(swizzle_stride % 256 == 0); 410 | switch (stages) 411 | { 412 | case 2: 413 | LAUNCH_16816_STAGE_SWIZZLE_MMA2x4_WARP4x4_DSMEM_TN_KERNEL(2, swizzle_stride); 414 | break; 415 | case 3: 416 | LAUNCH_16816_STAGE_SWIZZLE_MMA2x4_WARP4x4_DSMEM_TN_KERNEL(3, swizzle_stride); 417 | break; 418 | case 4: 419 | LAUNCH_16816_STAGE_SWIZZLE_MMA2x4_WARP4x4_DSMEM_TN_KERNEL(4, swizzle_stride); 420 | break; 421 | case 5: 422 | LAUNCH_16816_STAGE_SWIZZLE_MMA2x4_WARP4x4_DSMEM_TN_KERNEL(5, swizzle_stride); 423 | break; 424 | default: 425 | LAUNCH_16816_STAGE_SWIZZLE_MMA2x4_WARP4x4_DSMEM_TN_KERNEL(2, swizzle_stride); 426 | break; 427 | } 428 | } else { 429 | switch (stages) 430 | { 431 | case 2: 432 | LAUNCH_16816_STAGE_NO_SWIZZLE_MMA2x4_WARP4x4_DSMEM_TN_KERNEL(2); 433 | break; 434 | case 3: 435 | LAUNCH_16816_STAGE_NO_SWIZZLE_MMA2x4_WARP4x4_DSMEM_TN_KERNEL(3); 436 | break; 437 | case 4: 438 | LAUNCH_16816_STAGE_NO_SWIZZLE_MMA2x4_WARP4x4_DSMEM_TN_KERNEL(4); 439 | break; 440 | case 5: 441 | LAUNCH_16816_STAGE_NO_SWIZZLE_MMA2x4_WARP4x4_DSMEM_TN_KERNEL(5); 442 | break; 443 | default: 444 | LAUNCH_16816_STAGE_NO_SWIZZLE_MMA2x4_WARP4x4_DSMEM_TN_KERNEL(2); 445 | break; 446 | } 447 | } 448 | } 449 | -------------------------------------------------------------------------------- /kernels/hgemm/cutlass/hgemm_mma_stage_tn_cute.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | // BlockSwizzle: means apply thread block swizzle across N dim 8 | template < 9 | typename T, 10 | int BM, 11 | int BN, 12 | int BK, 13 | int kStage, 14 | typename TiledMMA, 15 | typename G2SCopyA, 16 | typename G2SCopyB, 17 | typename SmemLayoutA, 18 | typename SmemLayoutB, 19 | typename SmemLayoutC, 20 | typename 
S2RCopyAtomA, 21 | typename S2RCopyAtomB, 22 | typename R2SCopyAtomC, 23 | typename S2GCopyAtomC, 24 | typename S2GCopyC, 25 | const bool BlockSwizzle> 26 | __global__ void hgemm_mma_stages_block_swizzle_tn_cute_kernel( 27 | T *Aptr, T *Bptr, T *Dptr, int m, int n, int k) { 28 | using namespace cute; 29 | // Initilize shared memory 30 | extern __shared__ T shm_data[]; 31 | 32 | T *Ashm = shm_data; 33 | T *Bshm = shm_data + cute::cosize(SmemLayoutA{}); 34 | 35 | // Initilize thread block 36 | int idx = threadIdx.x; 37 | // BlockSwizzle 0/1 control use block swizzle or not. 38 | int ix = ((int) BlockSwizzle) * blockIdx.z * gridDim.x + blockIdx.x; 39 | int iy = blockIdx.y; 40 | 41 | if (iy * BM >= m || ix * BN >= n) return; 42 | 43 | // use Tensor notation to represent device pointer + dimension 44 | Tensor A = make_tensor(make_gmem_ptr(Aptr), make_shape(m, k), make_stride(k, Int<1>{})); 45 | Tensor B = make_tensor(make_gmem_ptr(Bptr), make_shape(n, k), make_stride(k, Int<1>{})); 46 | Tensor D = make_tensor(make_gmem_ptr(Dptr), make_shape(m, n), make_stride(n, Int<1>{})); 47 | 48 | // slice the tensor to small one which is used for current thread block. 49 | Tensor gA = local_tile(A, make_tile(Int{}, Int{}), make_coord(iy, _)); // (BM, BK, num_tile_k) 50 | Tensor gB = local_tile(B, make_tile(Int{}, Int{}), make_coord(ix, _)); // (BN, BK, num_tile_k) 51 | Tensor gD = local_tile(D, make_tile(Int{}, Int{}), make_coord(iy, ix));// (BM, BN) 52 | 53 | // shared memory 54 | auto sA = make_tensor(make_smem_ptr(Ashm), SmemLayoutA{}); // (BM, BK, kStage) 55 | auto sB = make_tensor(make_smem_ptr(Bshm), SmemLayoutB{}); // (BN, BK, kStage) 56 | 57 | // dispatch TileA/TileB/TileC mma tensor into thread fragment via partition 58 | TiledMMA tiled_mma; 59 | auto thr_mma = tiled_mma.get_slice(threadIdx.x); 60 | auto tCgD = thr_mma.partition_C(gD); // (MMA,MMA_M, MMA_N) 61 | 62 | auto tCrA = thr_mma.partition_fragment_A(gA(_, _, 0)); // (MMA, MMA_M, MMA_K) 63 | auto tCrB = thr_mma.partition_fragment_B(gB(_, _, 0)); // (MMA, MMA_N, MMA_K) 64 | auto tCrD = thr_mma.partition_fragment_C(gD); // (MMA, MMA_M, MMA_N) 65 | clear(tCrD); 66 | 67 | // from global memory to shared memory 68 | G2SCopyA g2s_tiled_copy_a; 69 | auto g2s_thr_copy_a = g2s_tiled_copy_a.get_slice(idx); 70 | auto tAgA_copy = g2s_thr_copy_a.partition_S(gA); // (CPY, CPY_M, CPY_K, num_tile_k) 71 | auto tAsA_copy = g2s_thr_copy_a.partition_D(sA); // (CPY, CPY_M, CPY_K, kStage) 72 | #ifdef CUTE_HGEMM_DEBUG 73 | if (thread0()) { 74 | print("\npartition_S(tAgA_copy): \n"); print(tAgA_copy); print("\n"); 75 | print("\nThrCopy(g2s_thr_copy_a): \n"); print(g2s_thr_copy_a); print("\n"); 76 | } 77 | #endif 78 | 79 | G2SCopyB g2s_tiled_copy_b; 80 | auto g2s_thr_copy_b = g2s_tiled_copy_b.get_slice(idx); 81 | auto tBgB_copy = g2s_thr_copy_b.partition_S(gB); // (CPY, CPY_N, CPY_K, num_tile_k) 82 | auto tBsB_copy = g2s_thr_copy_b.partition_D(sB); // (CPY, CPY_N, CPY_K, kStage) 83 | 84 | // from shared memory to register, use tiled_mma to generate tiled_copy 85 | auto s2r_tiled_copy_a = make_tiled_copy_A(S2RCopyAtomA{}, tiled_mma); 86 | auto s2r_thr_copy_a = s2r_tiled_copy_a.get_slice(idx); 87 | auto tAsA = s2r_thr_copy_a.partition_S(sA); // (CPY, CPY_M, CPY_K, kStage) 88 | auto tCrA_view = s2r_thr_copy_a.retile_D(tCrA); // (CPY, CPY_M, CPY_K) 89 | 90 | auto s2r_tiled_copy_b = make_tiled_copy_B(S2RCopyAtomB{}, tiled_mma); 91 | auto s2r_thr_copy_b = s2r_tiled_copy_b.get_slice(idx); 92 | auto tBsB = s2r_thr_copy_b.partition_S(sB); // (CPY, CPY_N, CPY_K, kStage) 93 | 
auto tCrB_view = s2r_thr_copy_b.retile_D(tCrB); // (CPY, CPY_N, CPY_K) 94 | 95 | /* PREFETCH */ 96 | // submit kStage - 1 tile 97 | // gmem -> shm 98 | int itile_to_read = 0; 99 | int ismem_read = 0; 100 | int ismem_write = 0; 101 | 102 | #pragma unroll 103 | for (int istage = 0; istage < kStage - 1; ++istage) { 104 | cute::copy(g2s_tiled_copy_a, tAgA_copy(_, _, _, istage), 105 | tAsA_copy(_, _, _, istage)); 106 | cute::copy(g2s_tiled_copy_b, tBgB_copy(_, _, _, istage), 107 | tBsB_copy(_, _, _, istage)); 108 | cp_async_fence(); 109 | 110 | ++itile_to_read; 111 | ++ismem_write; 112 | } 113 | 114 | // wait one submitted gmem->smem done 115 | cp_async_wait(); 116 | __syncthreads(); 117 | 118 | int ik = 0; 119 | // smem -> reg 120 | // tAsA: (CPY, CPY_M, CPY_K, kStage) tCrA_view: (CPY, CPY_M, CPY_K) 121 | cute::copy(s2r_tiled_copy_a, tAsA(_, _, ik, ismem_read), tCrA_view(_, _, ik)); 122 | cute::copy(s2r_tiled_copy_b, tBsB(_, _, ik, ismem_read), tCrB_view(_, _, ik)); 123 | 124 | // loop over k: i. load tile, ii. mma 125 | int ntile = k / BK; 126 | #pragma unroll 1 127 | for (int itile = 0; itile < ntile; ++itile) { 128 | int nk = size<2>(tCrA); // (MMA, MMA_M, MMA_K) 129 | 130 | #pragma unroll 131 | for (int ik = 0; ik < nk; ++ik) { 132 | int ik_next = (ik + 1) % nk; 133 | 134 | if (ik == nk - 1) { 135 | cp_async_wait(); 136 | __syncthreads(); 137 | 138 | ismem_read = (ismem_read + 1) % kStage; 139 | } 140 | 141 | // shm -> reg s[itile][ik + 1] -> r[ik + 1] 142 | // tAsA: (CPY, CPY_M, CPY_K, kStage), tCrA_view: (CPY, CPY_M, CPY_K) 143 | cute::copy(s2r_tiled_copy_a, tAsA(_, _, ik_next, ismem_read), 144 | tCrA_view(_, _, ik_next)); 145 | // tBsB: (CPY, CPY_M, CPY_K, kStage), tCrB_view: (CPY, CPY_M, CPY_K) 146 | cute::copy(s2r_tiled_copy_b, tBsB(_, _, ik_next, ismem_read), 147 | tCrB_view(_, _, ik_next)); 148 | 149 | if (ik == 0) { 150 | if (itile_to_read < ntile) { 151 | cute::copy(g2s_tiled_copy_a, tAgA_copy(_, _, _, itile_to_read), 152 | tAsA_copy(_, _, _, ismem_write)); 153 | cute::copy(g2s_tiled_copy_b, tBgB_copy(_, _, _, itile_to_read), 154 | tBsB_copy(_, _, _, ismem_write)); 155 | ++itile_to_read; 156 | ismem_write = (ismem_write + 1) % kStage; 157 | } 158 | 159 | cp_async_fence(); 160 | } 161 | 162 | cute::gemm(tiled_mma, tCrD, tCrA(_, _, ik), tCrB(_, _, ik), tCrD); 163 | } // for ik 164 | } 165 | 166 | // use less shared memory as a scratchpad tile to use large wide instuction 167 | // Dreg -> shm -> reg -> global 168 | auto sC = make_tensor(sA(_, _, ismem_read).data(), SmemLayoutC{}); 169 | 170 | auto r2s_tiled_copy_c = make_tiled_copy_C(R2SCopyAtomC{}, tiled_mma); 171 | auto r2s_thr_copy_c = r2s_tiled_copy_c.get_slice(idx); 172 | auto tCrC_r2s = r2s_thr_copy_c.retile_S(tCrD); // (CPY, CPY_M, CPY_N) 173 | auto tCsC_r2s = r2s_thr_copy_c.partition_D(sC); // (CPY, _1, _1, pipe) 174 | 175 | S2GCopyC s2g_tiled_copy_c; 176 | auto s2g_thr_copy_c = s2g_tiled_copy_c.get_thread_slice(idx); 177 | auto tCsC_s2g = s2g_thr_copy_c.partition_S(sC); // (CPY, _1, _1, pipe) 178 | auto tCgC_s2g = s2g_thr_copy_c.partition_D(gD); // (CPY, CPY_M, CPY_N) 179 | 180 | auto tCgC_s2gx = group_modes<1, 3>(tCgC_s2g); // (CPY_, CPY_MN) 181 | auto tCrC_r2sx = group_modes<1, 3>(tCrC_r2s); // (CPY_, CPY_MN) 182 | 183 | int step = size<3>(tCsC_r2s); // pipe 184 | #pragma unroll 185 | for (int i = 0; i < size<1>(tCrC_r2sx); i += step) { 186 | // reg -> shm 187 | #pragma unroll 188 | for (int j = 0; j < step; ++j) { 189 | // we add a temp tensor to cope with accumulator and output data type 190 | // difference 191 | auto t 
= make_tensor_like(tCrC_r2sx(_, i + j)); 192 | cute::copy(tCrC_r2sx(_, i + j), t); 193 | 194 | cute::copy(r2s_tiled_copy_c, t, tCsC_r2s(_, 0, 0, j)); 195 | } 196 | __syncthreads(); 197 | 198 | #pragma unroll 199 | // shm -> global 200 | for (int j = 0; j < step; ++j) { 201 | cute::copy(s2g_tiled_copy_c, tCsC_s2g(_, 0, 0, j), tCgC_s2gx(_, i + j)); 202 | } 203 | __syncthreads(); 204 | } // end for 205 | } 206 | 207 | // For torch binding, need dynamic block swizzle stride 208 | template 209 | void launch_hgemm_mma_stages_block_swizzle_tn_cute(T *a, 210 | T *b, 211 | T *c, 212 | int M, 213 | int N, 214 | int K, 215 | int swizzle_stride) { 216 | // block swizzle_stride: 1024/2048/..., etc. 217 | using namespace cute; 218 | 219 | auto BM = Int<128>{}; 220 | auto BN = Int<256>{}; 221 | auto BK = Int<32>{}; 222 | auto KStage = Int{}; // default 2 223 | auto kSmemLayoutCBatch = Int<4>{}; // namely, stages. 224 | 225 | // Define the smem layouts, Swizzle<3, 3, 3> and 226 | // Swizzle<2, 3, 3> will get the same results. 227 | // reference: https://zhuanlan.zhihu.com/p/671419093 228 | using SmemLayoutAtom = decltype( 229 | composition( 230 | Swizzle<3, 3, 3>{}, 231 | make_layout(make_shape(Int<8>{}, Int{}), 232 | make_stride(Int{}, Int<1>{})) 233 | ) 234 | ); 235 | using SmemLayoutA = decltype( 236 | tile_to_shape(SmemLayoutAtom{}, 237 | make_shape(Int{}, Int{}, Int{})) 238 | ); 239 | using SmemLayoutB = decltype( 240 | tile_to_shape(SmemLayoutAtom{}, 241 | make_shape(Int{}, Int{}, Int{})) 242 | ); // (m,n) -> smem_idx 243 | #ifdef CUTE_HGEMM_DEBUG 244 | print("SmemLayoutA: "); print(SmemLayoutA{}); print("\n"); 245 | print("SmemLayoutB: "); print(SmemLayoutB{}); print("\n"); 246 | print("SmemLayoutB: "); print(SmemLayoutB{}); print("\n"); 247 | print("SmemLayoutAtom A&B Latex: \n"); print_latex(SmemLayoutAtom{}); print("\n"); 248 | #endif 249 | 250 | // mma 251 | using mma_op = SM80_16x8x16_F16F16F16F16_TN; 252 | using mma_traits = MMA_Traits; 253 | using mma_atom = MMA_Atom; 254 | static constexpr int kMmaEURepeatM = 2; // MMA repeat 2 times across M 255 | static constexpr int kMmaEURepeatN = 2; // MMA repeat 2 times across N 256 | static constexpr int kMmaEURepeatK = 1; // MMA no repeat across K 257 | 258 | using mma_atom_shape = mma_traits::Shape_MNK; // M,N,K 16,8,16 259 | static constexpr int kMmaPM = 1 * kMmaEURepeatM * get<0>(mma_atom_shape{}); // 1*2*16=32 260 | static constexpr int kMmaPN = 2 * kMmaEURepeatN * get<1>(mma_atom_shape{}); // 2*2*8 =32 261 | static constexpr int kMmaPK = 1 * kMmaEURepeatK * get<2>(mma_atom_shape{}); // 1*1*16=16 262 | // TiledMMA, more threads, MMAThrLayout(2,2,1), 4 MMA = 4 warps = 32x4 threads. 263 | using MMA_EU_RepeatT = decltype(make_layout(make_shape( 264 | Int{}, Int{}, Int{}))); 265 | // TiledMMA, more values, Permutations(32,32,16) 266 | using MMA_P_T = Tile, Int, Int>; 267 | using MMA = decltype(make_tiled_mma(mma_atom{}, MMA_EU_RepeatT{}, MMA_P_T{})); 268 | #ifdef CUTE_HGEMM_DEBUG 269 | print("MMA: "); print(MMA{}); print("\n"); 270 | print("MMA Latex: \n"); print_latex(MMA{}); print("\n"); 271 | #endif 272 | 273 | // copy from global memory to shared memory 274 | using g2s_copy_op = SM80_CP_ASYNC_CACHEGLOBAL; 275 | using g2s_copy_traits = Copy_Traits; 276 | using g2s_copy_atom = Copy_Atom; 277 | // Make TiledCopy according to ThrLayout and ValLayout. 278 | // 32x4 threads, each thread load 1x8 values (128 bits) once ? 279 | // Produce a TiledCopy from logical thread and values layouts. 
280 | // The thread and value layouts map coordinates to thr_idx and val_idx. 281 | // The product of these layouts is taken to produce the TV layout and the Tiler. 282 | // Useful when threads and values need very specific mappings onto coordinates 283 | // in the target tensors. 284 | using G2SCopyA = 285 | decltype(make_tiled_copy(g2s_copy_atom{}, 286 | make_layout(make_shape(Int<32>{}, Int<4>{}), // Thr layout 32x4 k-major 287 | make_stride(Int<4>{}, Int<1>{})), 288 | make_layout(make_shape(Int<1>{}, Int<8>{})))); // Val layout 1x8 289 | using G2SCopyB = G2SCopyA; 290 | #ifdef CUTE_HGEMM_DEBUG 291 | print("G2SCopyA: "); print(G2SCopyA{}); print("\n"); 292 | print("G2SCopyB: "); print(G2SCopyB{}); print("\n"); 293 | print("G2SCopyA Latex: \n"); print_latex(G2SCopyA{}); print("\n"); 294 | print("G2SCopyB Latex: \n"); print_latex(G2SCopyB{}); print("\n"); 295 | #endif 296 | // copy from shared memory to register 297 | // use mma tiled ,so no tiled here 298 | using s2r_copy_op = SM75_U32x4_LDSM_N; 299 | using s2r_copy_traits = Copy_Traits; 300 | using s2r_copy_atom = Copy_Atom; 301 | using S2RCopyAtomA = s2r_copy_atom; 302 | using S2RCopyAtomB = s2r_copy_atom; 303 | 304 | // epilogue: register to global via shared memory 305 | // Swizzle<3, 3, 3>=BxMxS=(2^3)*(2^3)*(2^3)=512 values=1024 bytes. 306 | // reference: https://zhuanlan.zhihu.com/p/671419093 307 | using SmemLayoutAtomC = decltype( 308 | composition( 309 | Swizzle<3, 3, 3>{}, 310 | make_layout(make_shape(Int{}, Int{}), // 32*32 311 | make_stride(Int{}, Int<1>{}))) 312 | ); 313 | // kSmemLayoutCBatch=4, 32x32x4=4096 values=8192 bytes 314 | using SmemLayoutC = decltype( 315 | tile_to_shape( 316 | SmemLayoutAtomC{}, 317 | make_shape(Int{}, Int{}, Int{}) 318 | ) 319 | ); 320 | 321 | static_assert( 322 | size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) >= size(SmemLayoutC{}), 323 | "C shared memory request is large than A's one pipe" 324 | ); 325 | #ifdef CUTE_HGEMM_DEBUG 326 | print(SmemLayoutC{}); print("\n"); 327 | static constexpr int tmp_sizeC = size(SmemLayoutC{}); 328 | static constexpr int tmp_sizeA_0 = size<0>(SmemLayoutA{}); 329 | static constexpr int tmp_sizeA_1 = size<1>(SmemLayoutA{}); 330 | static constexpr int tmp_sizeA = tmp_sizeA_0 * tmp_sizeA_1; 331 | print("size SmemLayoutC: %d", tmp_sizeC); print("\n"); 332 | print("size SmemLayoutA: %d", tmp_sizeA); print("\n"); 333 | print("size 0 SmemLayoutA: %d", tmp_sizeA_0); print("\n"); 334 | print("size 1 SmemLayoutA: %d", tmp_sizeA_1); print("\n"); 335 | #endif 336 | 337 | using R2SCopyAtomC = Copy_Atom, T>; 338 | 339 | using S2GCopyAtomC = Copy_Atom, T>; 340 | using S2GCopyC = decltype( 341 | make_tiled_copy( 342 | S2GCopyAtomC{}, 343 | make_layout(make_shape(Int<32>{}, Int<4>{}), 344 | make_stride(Int<4>{}, Int<1>{})), 345 | make_layout(make_shape(Int<1>{}, Int<8>{})) 346 | ) 347 | ); 348 | 349 | int BX = (N + BN - 1) / BN; 350 | int BY = (M + BM - 1) / BM; 351 | // NOTE: Apply thread block swizzle across N dim. 352 | int BZ = BlockSwizzle ? (N + (swizzle_stride) - 1) / (swizzle_stride) : 1; 353 | BX = BlockSwizzle ? (BX + BZ - 1) / BZ : BX; 354 | 355 | dim3 block(size(MMA{})); 356 | dim3 grid(BX, BY, BZ); 357 | 358 | // C_shm is shared with A_shm and B_shm 359 | // we don't allocate new smem for C_shm. 
360 | // (128 * 32 * 2) * 2 + (256 * 32 * 2) * 2 = 49152 bytes, stages=2 361 | static constexpr int shm_size_AB = 362 | cute::cosize(SmemLayoutA{}) + cute::cosize(SmemLayoutB{}); 363 | static constexpr int shm_size_C = cute::cosize(SmemLayoutC{}); 364 | static constexpr int kShmSize = 365 | cute::max(shm_size_AB, shm_size_C) * sizeof(T); 366 | 367 | int shm_size = kShmSize; 368 | #ifdef CUTE_HGEMM_DEBUG 369 | print("shm_size: %d bytes, shm_size_AB: %d bytes, shm_size_C: %d bytes\n", 370 | shm_size, shm_size_AB * (int) sizeof(T), shm_size_C * (int) sizeof(T)); 371 | #endif 372 | 373 | cudaFuncSetAttribute( 374 | hgemm_mma_stages_block_swizzle_tn_cute_kernel< 375 | T, 376 | BM, BN, BK, 377 | KStage, 378 | MMA, 379 | G2SCopyA, 380 | G2SCopyB, 381 | SmemLayoutA, 382 | SmemLayoutB, 383 | SmemLayoutC, 384 | S2RCopyAtomA, 385 | S2RCopyAtomB, 386 | R2SCopyAtomC, 387 | S2GCopyAtomC, 388 | S2GCopyC, 389 | BlockSwizzle 390 | >, 391 | cudaFuncAttributeMaxDynamicSharedMemorySize, 392 | shm_size 393 | ); 394 | 395 | hgemm_mma_stages_block_swizzle_tn_cute_kernel< 396 | T, 397 | BM, BN, BK, 398 | KStage, 399 | MMA, 400 | G2SCopyA, 401 | G2SCopyB, 402 | SmemLayoutA, 403 | SmemLayoutB, 404 | SmemLayoutC, 405 | S2RCopyAtomA, 406 | S2RCopyAtomB, 407 | R2SCopyAtomC, 408 | S2GCopyAtomC, 409 | S2GCopyC, 410 | BlockSwizzle 411 | ><<>>(a, b, c, M, N, K); 412 | } 413 | 414 | // build cpp binary 415 | #ifndef NO_CUTE_HGEMM_BIN 416 | 417 | #include "utils.h" 418 | 419 | int main() { 420 | using T = cute::half_t; 421 | using namespace cute; 422 | #ifdef CUTE_HGEMM_DEBUG 423 | const int test_num = 1; 424 | #else 425 | const int test_num = 64; 426 | #endif 427 | int M_list[test_num]; 428 | int N_list[test_num]; 429 | int K_list[test_num]; 430 | 431 | for (int i = 0; i < test_num; i++) { 432 | M_list[i] = (i + 1) * 256; 433 | N_list[i] = (i + 1) * 256; 434 | K_list[i] = (i + 1) * 256; 435 | } 436 | 437 | const int thread_block_swizzle_stride = 2048; // thread block swizzle stride 438 | printf("ALGO = CuTe HGEMM, TN, STAGES=2, SMEM SWIZZLE=<3, 3, 3>, BLOCK SWIZZLE=2048\n"); 439 | int check_num = test_num > 5 ? 5 : 1; 440 | for (int j = 0; j < check_num; j++) { 441 | int M = M_list[j], N = N_list[j], K = K_list[j]; 442 | float max_error = gemm_error_check_tn_swizzle( 443 | launch_hgemm_mma_stages_block_swizzle_tn_cute, 444 | M, N, K, thread_block_swizzle_stride); 445 | printf("M N K = %6d %6d %6d, ", M, N, K); 446 | printf("Max Error = %f\n", max_error); 447 | } 448 | 449 | #ifndef CUTE_HGEMM_DEBUG 450 | const int outer_repeat = 10, inner_repeat = 1; 451 | for (int j = 0; j < test_num; j++) { 452 | int M = M_list[j], N = N_list[j], K = K_list[j]; 453 | 454 | double max_sec = 0.0; 455 | double min_sec = DBL_MAX; 456 | double total_sec = 0.0; 457 | 458 | for (int k = 0; k < outer_repeat; k++) { 459 | double this_sec = perf_gemm_swizzle( 460 | launch_hgemm_mma_stages_block_swizzle_tn_cute, 461 | M, N, K, thread_block_swizzle_stride, inner_repeat); 462 | max_sec = max(max_sec, this_sec); 463 | min_sec = min(min_sec, this_sec); 464 | total_sec += this_sec; 465 | } 466 | 467 | // 1 TFLOPS = 10^12 FLOPS 468 | // ref: https://imgtec.eetrend.com/blog/2021/100062210.html. 
469 | double avg_sec = total_sec / outer_repeat; 470 | double avg_Tflops = ((double)M) * N * K * 2 * 1e-12 / avg_sec; 471 | 472 | printf("M N K = %6d %6d %6d, ", M, N, K); 473 | printf("Time = %12.8lf %12.8lf %12.8lf s, ", min_sec, avg_sec, max_sec); 474 | printf("AVG Performance = %10.4lf Tflops\n", avg_Tflops); 475 | } 476 | #endif 477 | 478 | return 0; 479 | } 480 | 481 | #else 482 | // build torch python binding 483 | 484 | #include 485 | #include 486 | // --------------------- PyTorch bindings for custom kernel ----------------------- 487 | #define STRINGFY(str) #str 488 | #define TORCH_BINDING_COMMON_EXTENSION(func) \ 489 | m.def(STRINGFY(func), &func, STRINGFY(func)); 490 | 491 | #define CHECK_TORCH_TENSOR_DTYPE(T, th_type) \ 492 | if(((T).options().dtype() != (th_type))) { \ 493 | std::cout << "Tensor Info:" << (T).options() << std::endl; \ 494 | throw std::runtime_error("values must be "#th_type); \ 495 | } 496 | 497 | #define CHECK_TORCH_TENSOR_SHAPE(T, S0, S1) \ 498 | if (((T).size(0) != (S0)) || ((T).size(1) != (S1))) { \ 499 | throw std::runtime_error("Tensor size mismatch!"); \ 500 | } 501 | 502 | #define LAUNCH_HGEMM_MMA_STAGES_CUTE_NO_SWIZZLE_TN(stages) \ 503 | launch_hgemm_mma_stages_block_swizzle_tn_cute< \ 504 | half, (stages), false>( \ 505 | reinterpret_cast(a.data_ptr()), \ 506 | reinterpret_cast(b.data_ptr()), \ 507 | reinterpret_cast(c.data_ptr()), \ 508 | M, N, K, 2048 \ 509 | ); 510 | 511 | #define LAUNCH_HGEMM_MMA_STAGES_CUTE_SWIZZLE_TN(stages, stride) \ 512 | launch_hgemm_mma_stages_block_swizzle_tn_cute< \ 513 | half, (stages), true>( \ 514 | reinterpret_cast(a.data_ptr()), \ 515 | reinterpret_cast(b.data_ptr()), \ 516 | reinterpret_cast(c.data_ptr()), \ 517 | M, N, K, (stride) \ 518 | ); 519 | 520 | 521 | // Multi stages CuTe HGEMM with SMEM Swizzle and Block Swizzle. 
522 | void hgemm_mma_stages_block_swizzle_tn_cute( 523 | torch::Tensor a, torch::Tensor b, torch::Tensor c, 524 | int stages, bool swizzle, int swizzle_stride) { 525 | CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf) 526 | CHECK_TORCH_TENSOR_DTYPE(b, torch::kHalf) 527 | CHECK_TORCH_TENSOR_DTYPE(c, torch::kHalf) 528 | const int M = a.size(0); 529 | const int K = a.size(1); 530 | const int N = b.size(1); 531 | CHECK_TORCH_TENSOR_SHAPE(a, M, K) 532 | CHECK_TORCH_TENSOR_SHAPE(b, K, N) 533 | CHECK_TORCH_TENSOR_SHAPE(c, M, N) 534 | 535 | if (swizzle) { 536 | switch (stages) 537 | { 538 | case 2: 539 | LAUNCH_HGEMM_MMA_STAGES_CUTE_SWIZZLE_TN(2, swizzle_stride); 540 | break; 541 | case 3: 542 | LAUNCH_HGEMM_MMA_STAGES_CUTE_SWIZZLE_TN(3, swizzle_stride); 543 | break; 544 | case 4: 545 | LAUNCH_HGEMM_MMA_STAGES_CUTE_SWIZZLE_TN(4, swizzle_stride); 546 | break; 547 | default: 548 | LAUNCH_HGEMM_MMA_STAGES_CUTE_SWIZZLE_TN(2, swizzle_stride); 549 | break; 550 | } 551 | } else { 552 | switch (stages) { 553 | case 2: 554 | LAUNCH_HGEMM_MMA_STAGES_CUTE_NO_SWIZZLE_TN(2) 555 | break; 556 | case 3: 557 | LAUNCH_HGEMM_MMA_STAGES_CUTE_NO_SWIZZLE_TN(3) 558 | break; 559 | case 4: 560 | LAUNCH_HGEMM_MMA_STAGES_CUTE_NO_SWIZZLE_TN(4) 561 | break; 562 | default: 563 | LAUNCH_HGEMM_MMA_STAGES_CUTE_NO_SWIZZLE_TN(2) 564 | break; 565 | } 566 | } 567 | } 568 | #endif 569 | -------------------------------------------------------------------------------- /kernels/hgemm/hgemm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gc 3 | import torch 4 | import time 5 | from functools import partial 6 | from typing import Optional 7 | import argparse 8 | from tools.utils import (get_device_name, 9 | pretty_print_line, 10 | try_load_hgemm_library, 11 | as_col_major) 12 | 13 | torch.set_grad_enabled(False) 14 | 15 | 16 | def get_args(): 17 | parser = argparse.ArgumentParser(description="hgemm benchmark") 18 | parser.add_argument("--M", type=int, default=None, help="Matrix M size") 19 | parser.add_argument("--N", type=int, default=None, help="Matrix N size") 20 | parser.add_argument("--K", type=int, default=None, help="Matrix K size") 21 | parser.add_argument("--MNK", type=int, default=None, help="Matrix M=N=K size") 22 | parser.add_argument("--MMNK", type=int, default=12800, help="Matrix MAX M=M=N=K size") 23 | parser.add_argument("--SEP", '--sep', type=int, default=256, help="Matrix SEP M=M=N=K size") 24 | parser.add_argument("--warmup", "--w", type=int, default=2, help="Warmup iters") 25 | parser.add_argument("--iters", "--i", type=int, default=10, help="Benchmark iters") 26 | parser.add_argument("--verbose", "--v", action="store_true", help="Verbose") 27 | parser.add_argument("--show-matrix", "--show-m", action="store_true", help="Show output matrix values") 28 | parser.add_argument("--show-all-info", "--show-a", action="store_true", help="Show all the profile info") 29 | parser.add_argument("--show-memory", "--show-mm", action="store_true", help="Show gpu memory info") 30 | parser.add_argument("--enable-mma", "--mma", action="store_true", help="Enable MMA kernel tests") 31 | parser.add_argument("--enable-mma-tn", "--mma-tn", action="store_true", help="Enable TN MMA kernel tests") 32 | parser.add_argument("--enable-wmma", "--wmma", action="store_true", help="Enable WMMA kernel tests") 33 | parser.add_argument("--enable-cuda", "--cuda", action="store_true", help="Enable CUDA kernel tests") 34 | parser.add_argument("--enable-mma-all", "--mma-all", action="store_true", 
help="Enable all MMA kernel tests") 35 | parser.add_argument("--enable-wmma-all", "--wmma-all", action="store_true", help="Enable all WMMA kernel tests") 36 | parser.add_argument("--enable-cuda-all", "--cuda-all", action="store_true", help="Enable all CUDA kernel tests") 37 | parser.add_argument("--enable-torch", "--torch", action="store_true", help="Enable torch matmul") 38 | parser.add_argument("--enable-cute-tn", "--cute-tn", action="store_true", help="Enable cute hgemm matmul") 39 | parser.add_argument("--enable-cute", "--cute", action="store_true", help="Enable cute hgemm matmul") 40 | parser.add_argument("--disable-cublas", "--no-cublas", action="store_true", help="Disable cublas hgemm") 41 | parser.add_argument("--disable-cublas-tn", "--no-cublas-tn", action="store_true", help="Disable cublas TN hgemm") 42 | parser.add_argument("--sleep-duration", "--sleep", type=float, default=0.1, help="Sleep duration") 43 | parser.add_argument("--swizzle-factor", "--swizzle", type=float, default=None, help="Swizzle factor") 44 | parser.add_argument("--no-default", action="store_true", help="Disable default tests") 45 | parser.add_argument("--plot-flops", "--plot", action="store_true", help="Plot TFLOPS") 46 | parser.add_argument("--plot-topk", "--topk", type=int, default=8, help="Plot top k TFLOPS") 47 | parser.add_argument("--no-plot-best", "--no-best", action="store_true", help="Not Plot best TFLOPS") 48 | parser.add_argument("--exclude-tags", "--exclude", type=str, default=None, help="Exclude tag for plot, sperated by comma") 49 | parser.add_argument("--save-dir", "--dir", type=str, default="./", help="Save dir for plot") 50 | parser.add_argument("--save-tag", "--tag", type=str, default=None, help="Save name for plot") 51 | parser.add_argument("--force-build", "--build", action="store_true", help="Force build from sources") 52 | return parser.parse_args() 53 | 54 | 55 | args = get_args() 56 | pretty_print_line() 57 | print(args) 58 | pretty_print_line() 59 | 60 | 61 | hgemm = try_load_hgemm_library(force_build=args.force_build, verbose=args.verbose) 62 | 63 | MAX_TFLOPS = -1 64 | STATIS_INFO: dict[str, list[float]] = {} 65 | STATIS_INFO["MNK"] = [] 66 | TOATL_TFLOPS: dict[str, float] = {} 67 | CUBLAS_TOTAL_TFLOPS = 0 68 | CUBLAS_TN_TOTAL_TFLOPS = 0 69 | 70 | 71 | def make_block_swizzle_stride(N: int, K: int, swizzle_factor: float = None): 72 | # make swizzle stride as N/8,N/4,N/2 and multiples of 256 73 | if swizzle_factor is None: 74 | swizzle_factor = 0.5 if N <= 4096 else 0.25 75 | if all((N >= 14848, K > 8192, N % 8 == 0)): 76 | swizzle_factor = 0.125 77 | 78 | swizzle_stride = int(N * swizzle_factor) 79 | swizzle_stride = swizzle_stride if swizzle_stride >= 256 else 1 80 | 81 | return swizzle_stride 82 | 83 | 84 | @torch.no_grad 85 | def run_benchmark(perf_func: callable, 86 | a: torch.Tensor, b: torch.Tensor, 87 | tag: str, out: Optional[torch.Tensor] = None, 88 | stages: int = -1, swizzle: bool = False, 89 | swizzle_stride: int = 1, 90 | warmup: int = args.warmup, 91 | iters: int = args.iters, 92 | show_matrix: bool = args.show_matrix, 93 | only_show_improved: bool = not args.show_all_info): 94 | global MAX_TFLOPS 95 | 96 | M = a.size(0) 97 | K = a.size(1) 98 | N = b.size(1) # TN still has shape [K,N] 99 | if swizzle: 100 | swizzle_stride = make_block_swizzle_stride(N, K, args.swizzle_factor) 101 | swizzle = swizzle if swizzle_stride >= 256 else False 102 | else: 103 | swizzle_stride = 1 # means no thread block swizzle 104 | 105 | if stages: 106 | assert swizzle_stride is not None 107 | 
108 | if out is not None: 109 | out.fill_(0) 110 | 111 | if "cublas" in tag: 112 | hgemm.init_cublas_handle() 113 | 114 | if out is not None: 115 | for i in range(warmup): 116 | if stages > 1: 117 | perf_func(a, b, out, stages, swizzle, swizzle_stride) 118 | else: 119 | perf_func(a, b, out) 120 | else: 121 | for i in range(warmup): 122 | _ = perf_func(a, b) 123 | 124 | torch.cuda.synchronize() 125 | start = time.time() 126 | # iters 127 | if out is not None: 128 | for i in range(iters): 129 | if stages > 1: 130 | perf_func(a, b, out, stages, swizzle, swizzle_stride) 131 | else: 132 | perf_func(a, b, out) 133 | else: 134 | for i in range(iters): 135 | out = perf_func(a, b) 136 | torch.cuda.synchronize() 137 | 138 | end = time.time() 139 | total_time_secs = (end - start) # ms 140 | mean_time_secs = total_time_secs / iters 141 | out_info = f"{tag}" 142 | out_flat = out.flatten() 143 | out_val_first = out_flat[:2].detach().cpu().numpy().tolist() 144 | out_val_last = out_flat[-2:].detach().cpu().numpy().tolist() 145 | out_val = [out_val_first[0], out_val_last[-1]] 146 | out_val = [round(v, 8) for v in out_val] 147 | out_val = [f"{v:<12}"[:10] for v in out_val] 148 | # 1 TFLOPS = 10^12 FLOPS 149 | # ref: https://imgtec.eetrend.com/blog/2021/100062210.html. 150 | TFLOPS = (2 * M * N * K) * 1e-12 / (mean_time_secs) 151 | mean_time_ms = mean_time_secs * 1000 152 | mean_time_ms = str(f"{mean_time_ms:<12}")[:8] # ms 153 | swizzle_stride = 'NOOP' if swizzle_stride == 1 else swizzle_stride 154 | 155 | # caculate TFLOPS improved. 156 | if TFLOPS > MAX_TFLOPS: 157 | if MAX_TFLOPS > 0: 158 | improve = ((TFLOPS - MAX_TFLOPS) / MAX_TFLOPS) * 100 159 | improve = round(improve, 2) 160 | else: 161 | improve = 0 162 | MAX_TFLOPS = TFLOPS 163 | print(f"{out_info:>53}: {out_val}, time:{mean_time_ms}ms, " 164 | f"swizzle: {swizzle_stride:<4}, TFLOPS: {TFLOPS:<6.2f}(+{improve:.2f}%)") 165 | else: 166 | if not only_show_improved or "cublas" in tag: 167 | print(f"{out_info:>53}: {out_val}, time:{mean_time_ms}ms, " 168 | f"swizzle: {swizzle_stride:<4}, TFLOPS: {TFLOPS:<6.2f}") 169 | if show_matrix: print(out) 170 | if args.plot_flops: 171 | STATIS_INFO[tag] = STATIS_INFO.get(tag, []) 172 | STATIS_INFO[tag].append(TFLOPS) 173 | if "cublas" not in tag: 174 | TOATL_TFLOPS[tag] = TOATL_TFLOPS.get(tag, 0) + TFLOPS 175 | else: 176 | global CUBLAS_TOTAL_TFLOPS 177 | global CUBLAS_TN_TOTAL_TFLOPS 178 | if tag == "tn(cublas)": 179 | CUBLAS_TN_TOTAL_TFLOPS += TFLOPS 180 | else: 181 | CUBLAS_TOTAL_TFLOPS += TFLOPS 182 | 183 | torch.cuda.synchronize() 184 | if "cublas" in tag: 185 | hgemm.destroy_cublas_handle() 186 | 187 | del out_flat 188 | out_flat = None 189 | gc.collect() 190 | torch.cuda.empty_cache() 191 | time.sleep(args.sleep_duration) 192 | return out, mean_time_ms 193 | 194 | 195 | def get_topk_tflops(): 196 | topk_tflops = sorted(TOATL_TFLOPS.items(), key=lambda x: x[1], 197 | reverse=True) 198 | pretty_print_line() 199 | pretty_print_line(f"THE TOTAL TFLOPS OF {len(topk_tflops)} HGEMM ALGO ON {get_device_name()} DEVICE", " ") 200 | pretty_print_line() 201 | for tag, tflops in list(topk_tflops)[::-1]: 202 | print(f"{tag:>53}: {tflops:>20.2f} TFLOPS") 203 | if CUBLAS_TN_TOTAL_TFLOPS > 1: 204 | print(f"{'tn(cublas)':>53}: {CUBLAS_TN_TOTAL_TFLOPS:>20.2f} TFLOPS") 205 | if CUBLAS_TOTAL_TFLOPS > 1: 206 | print(f"{'(cublas)':>53}: {CUBLAS_TOTAL_TFLOPS:>20.2f} TFLOPS") 207 | pretty_print_line() 208 | return list(dict(topk_tflops[:args.plot_topk]).keys()) 209 | 210 | 211 | @torch.no_grad 212 | def get_best_tflops(): 213 | 
all_tflops = [] 214 | for tag, tflops in STATIS_INFO.items(): 215 | if "cublas" not in tag and "MNK" not in tag: 216 | all_tflops.append(tflops) 217 | # [N, NUM_MNK], reduce max on N dim 218 | all_tflops = torch.tensor(all_tflops, dtype=torch.float) 219 | best_tflops = torch.max(all_tflops, dim=0, keepdim=False)[0].tolist() 220 | return best_tflops 221 | 222 | 223 | def plot_tflops(): 224 | import matplotlib.pyplot as plt 225 | import numpy as np 226 | ax: plt.Axes = plt.subplots(figsize=(16, 9))[1] # fig, axs 227 | plt.subplots_adjust(left=0.04, right=0.99, top=0.95, bottom=0.05) 228 | ax.set_title(f"My HGEMM vs cuBLAS, {get_device_name()}, Warmup={args.warmup}, Iters={args.iters}") 229 | ax.set_xlabel("M=N=K") 230 | ax.set_ylabel("TFLOPS") 231 | ax.grid(True) 232 | ax.set_xticks(np.arange(0, len(STATIS_INFO["MNK"]), 1)) 233 | ax.set_xticklabels(STATIS_INFO["MNK"], rotation=45, ha='right') 234 | exclude_tags = args.exclude_tags.split(",") if args.exclude_tags else [] 235 | exclude_tags.append("MNK") 236 | exclude_tags = set(exclude_tags) 237 | 238 | topk_tflops = get_topk_tflops() 239 | STATIS_INFO["(best)"] = get_best_tflops() 240 | draw_tags = topk_tflops 241 | draw_tags.append("(cublas)") 242 | draw_tags.append("tn(cublas)") 243 | draw_tags.append("(best)") 244 | 245 | def skip_it(tag: str) -> bool: 246 | for etag in exclude_tags: 247 | if etag in tag: 248 | return True 249 | if tag not in draw_tags: 250 | return True 251 | return False 252 | 253 | for tag, tflops in STATIS_INFO.items(): 254 | if skip_it(tag): 255 | continue 256 | if tag == "(cublas)": 257 | ax.plot(tflops, label=tag, linewidth=3, color='orange') 258 | elif tag == "tn(cublas)": 259 | ax.plot(tflops, label=tag, linewidth=3, color='green') 260 | else: 261 | if "best" in tag and not args.no_plot_best: 262 | ax.plot(tflops, label=tag, linewidth=4, color='blue') 263 | else: 264 | ax.plot(tflops, label=tag, linestyle='--') 265 | 266 | ax.legend() 267 | device_name = get_device_name().replace(" ", "_") 268 | if args.save_tag: 269 | save_path = f"{args.save_dir}/{device_name}_{args.save_tag}.png" 270 | else: 271 | save_path = f"{args.save_dir}/{device_name}.png" 272 | os.makedirs(args.save_dir, exist_ok=True) 273 | plt.savefig(save_path, dpi=300) 274 | pretty_print_line(f"plot hgemm TFLOPS done, saved as {save_path}") 275 | 276 | 277 | def get_mnk(sep: int = args.SEP): 278 | Ms = list(range(sep, args.MMNK + sep, sep)) 279 | Ns = list(range(sep, args.MMNK + sep, sep)) 280 | Ks = list(range(sep, args.MMNK + sep, sep)) 281 | return Ms, Ns, Ks 282 | 283 | 284 | Ms, Ns, Ks = get_mnk() 285 | STATIS_INFO["MNK"] = Ms 286 | if args.MNK: 287 | Ms = [args.MNK] 288 | Ns = [args.MNK] 289 | Ks = [args.MNK] 290 | # prefer different M, N, K 291 | if args.M and args.N and args.K: 292 | Ms = [args.M] 293 | Ns = [args.N] 294 | Ks = [args.K] 295 | MAX_M, MAX_N, MAX_K = max(Ms), max(Ns), max(Ks) 296 | # pre allocate for fast profiling. 
297 | torch.cuda.synchronize()
298 | start = time.time()
299 | pretty_print_line(f"Allocate buffers for fast profiling start, MAX_M={MAX_M}, MAX_N={MAX_N}, MAX_K={MAX_K}")
300 | A = torch.randn((MAX_M, MAX_K), dtype=torch.half, device="cuda").cuda()
301 | B = torch.randn((MAX_K, MAX_N), dtype=torch.half, device="cuda").cuda()
302 | C = torch.randn((MAX_M, MAX_N), dtype=torch.half, device="cuda").cuda()
303 | torch.cuda.synchronize()
304 | end = time.time()
305 | pretty_print_line(f"Allocate buffers for fast profiling done, time: {(end - start) * 1000:.7f} ms")
306 | 
307 | PERF_COUNT = 0
308 | for (M, N, K) in zip(Ms, Ns, Ks):
309 |     MAX_TFLOPS = -1
310 |     PERF_COUNT += 1
311 |     pretty_print_line()
312 |     pretty_print_line(f"M={M}, N={N}, K={K}, Warmup={args.warmup}, Iters={args.iters}, {PERF_COUNT}/{len(Ms)}", sep=" ")
313 |     pretty_print_line()
314 |     a = A[:M, :K].contiguous()
315 |     b = B[:K, :N].contiguous()
316 |     c = C[:M, :N].contiguous()
317 |     b_col_major = as_col_major(b)
318 |     torch.cuda.synchronize()
319 |     # CUDA Cores FP16, NN
320 |     if args.enable_cuda_all:  # more cuda cores kernel tests
321 |         run_benchmark(hgemm.hgemm_naive_f16, a, b, "(naive)", c)
322 |         run_benchmark(hgemm.hgemm_t_8x8_sliced_k_f16x8_pack_bcf, a, b, "(f16x8pack+t8x8+bcf)", c)
323 |     if (args.enable_cuda or args.enable_cuda_all) and (not args.no_default):
324 |         run_benchmark(hgemm.hgemm_t_8x8_sliced_k_f16x8_pack_bcf_dbuf, a, b, "(f16x8pack+t8x8+dbuf)", c)
325 |         run_benchmark(hgemm.hgemm_t_8x8_sliced_k16_f16x8_pack_dbuf, a, b, "(f16x8pack+t8x8+k16+dbuf)", c)
326 |     # WMMA API, stages, dsmem, swizzle, NN
327 |     if (args.enable_wmma or args.enable_wmma_all) and (not args.no_default):
328 |         pretty_print_line("WMMA")
329 |         run_benchmark(hgemm.hgemm_wmma_m16n16k16_mma4x2, a, b, "(wmma4x2)", c)
330 |         run_benchmark(hgemm.hgemm_wmma_m16n16k16_mma4x2_warp2x4, a, b, "(wmma4x2+warp2x4)", c)
331 |         run_benchmark(hgemm.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "(wmma4x2+warp2x4+stage3+dsmem)", c, stages=3)
332 |         run_benchmark(hgemm.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "(wmma4x2+warp2x4+stage2+dsmem)", c, stages=2)
333 |         run_benchmark(hgemm.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "(wmma4x2+warp2x4+stage3+dsmem+swizzle)", c, stages=3, swizzle=True)
334 |         run_benchmark(hgemm.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "(wmma4x2+warp2x4+stage2+dsmem+swizzle)", c, stages=2, swizzle=True)
335 |     if args.enable_wmma_all:  # more wmma kernel tests.
336 |         run_benchmark(hgemm.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "(wmma4x2+warp2x4+stage3)", c, stages=3)
337 |         run_benchmark(hgemm.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "(wmma4x2+warp2x4+stage2)", c, stages=2)
338 |         run_benchmark(hgemm.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "(wmma4x2+warp2x4+stage3+swizzle)", c, stages=3, swizzle=True)
339 |         run_benchmark(hgemm.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "(wmma4x2+warp2x4+stage2+swizzle)", c, stages=2, swizzle=True)
340 |         # Preferred on the NVIDIA RTX 3080 Laptop 16GB GDDR6 device.
341 | run_benchmark(hgemm.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "(wmma4x4+warp4x4+stage3+dsmem)", c, stages=3) 342 | run_benchmark(hgemm.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "(wmma4x4+warp4x4+stage2+dsmem)", c, stages=2) 343 | run_benchmark(hgemm.hgemm_wmma_m16n16k16_mma4x2_warp4x4_stages_dsmem, a, b, "(wmma4x2+warp4x4+stage3+dsmem)", c, stages=3) 344 | run_benchmark(hgemm.hgemm_wmma_m16n16k16_mma4x2_warp4x4_stages_dsmem, a, b, "(wmma4x2+warp4x4+stage2+dsmem)", c, stages=2) 345 | run_benchmark(hgemm.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "(wmma4x4+warp4x4+stage3+dsmem+swizzle)", c, stages=3, swizzle=True) 346 | run_benchmark(hgemm.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "(wmma4x4+warp4x4+stage2+dsmem+swizzle)", c, stages=2, swizzle=True) 347 | run_benchmark(hgemm.hgemm_wmma_m16n16k16_mma4x2_warp4x4_stages_dsmem, a, b, "(wmma4x2+warp4x4+stage3+dsmem+swizzle)", c, stages=3, swizzle=True) 348 | run_benchmark(hgemm.hgemm_wmma_m16n16k16_mma4x2_warp4x4_stages_dsmem, a, b, "(wmma4x2+warp4x4+stage2+dsmem+swizzle)", c, stages=2, swizzle=True) 349 | # MMA API, stages, dsmem, swizzle, NN 350 | if (args.enable_mma or args.enable_mma_all) and (not args.no_default): 351 | pretty_print_line("MMA") 352 | if args.enable_mma_all: # more mma kernel tests. 353 | run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4, a, b, "(mma2x4+warp4x4)", c) 354 | run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages, a, b, "(mma2x4+warp4x4+stage3)", c, stages=3) 355 | run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages, a, b, "(mma2x4+warp4x4+stage2)", c, stages=2) 356 | if (args.enable_mma or args.enable_mma_all) and (not args.no_default): 357 | run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem, a, b, "(mma2x4+warp4x4+stage3+dsmem)", c, stages=3) 358 | run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem, a, b, "(mma2x4+warp4x4+stage2+dsmem)", c, stages=2) 359 | run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem, a, b, "(mma2x4+warp4x4x2+stage4+dsmem)", c, stages=4) 360 | run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem, a, b, "(mma2x4+warp4x4x2+stage3+dsmem)", c, stages=3) 361 | run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem, a, b, "(mma2x4+warp4x4x2+stage2+dsmem)", c, stages=2) 362 | run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle, a, b, "(mma2x4+warp4x4x2+stage4+dsmem+swizzle)", c, stages=4) 363 | run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle, a, b, "(mma2x4+warp4x4x2+stage3+dsmem+swizzle)", c, stages=3) 364 | run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle, a, b, "(mma2x4+warp4x4x2+stage2+dsmem+swizzle)", c, stages=2) 365 | if args.enable_mma_all: # more mma kernel tests. 
366 | run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_rr, a, b, "(mma2x4+warp4x4x2+stage4+dsmem+rr)", c, stages=4) 367 | run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_rr, a, b, "(mma2x4+warp4x4x2+stage3+dsmem+rr)", c, stages=3) 368 | run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_rr, a, b, "(mma2x4+warp4x4x2+stage2+dsmem+rr)", c, stages=2) 369 | run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_x4, a, b, "(mma2x4+warp4x4x2+stage4+dsmem+x4)", c, stages=4) 370 | run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_x4, a, b, "(mma2x4+warp4x4x2+stage3+dsmem+x4)", c, stages=3) 371 | run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_x4, a, b, "(mma2x4+warp4x4x2+stage2+dsmem+x4)", c, stages=2) 372 | if (args.enable_mma or args.enable_mma_all) and (not args.no_default): 373 | # Thread block swizzle 374 | run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem, a, b, "(mma2x4+warp4x4+stage3+dsmem+swizzle)", c, stages=3, swizzle=True) 375 | run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem, a, b, "(mma2x4+warp4x4+stage2+dsmem+swizzle)", c, stages=2, swizzle=True) 376 | run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem, a, b, "(mma2x4+warp4x4x2+stage4+dsmem+swizzle)", c, stages=4, swizzle=True) 377 | run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem, a, b, "(mma2x4+warp4x4x2+stage3+dsmem+swizzle)", c, stages=3, swizzle=True) 378 | run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem, a, b, "(mma2x4+warp4x4x2+stage2+dsmem+swizzle)", c, stages=2, swizzle=True) 379 | run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle, a, b, "(mma2x4+warp4x4x2+stage4+dsmem+swizzle)", c, stages=4, swizzle=True) 380 | run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle, a, b, "(mma2x4+warp4x4x2+stage3+dsmem+swizzle)", c, stages=3, swizzle=True) 381 | run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle, a, b, "(mma2x4+warp4x4x2+stage2+dsmem+swizzle)", c, stages=2, swizzle=True) 382 | if args.enable_mma_all: 383 | run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages, a, b, "(mma2x4+warp4x4+stage3+swizzle)", c, stages=3, swizzle=True) 384 | run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages, a, b, "(mma2x4+warp4x4+stage2+swizzle)", c, stages=2, swizzle=True) 385 | run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_rr, a, b, "(mma2x4+warp4x4x2+stage4+dsmem+swizzle+rr)", c, stages=4, swizzle=True) 386 | run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_rr, a, b, "(mma2x4+warp4x4x2+stage3+dsmem+swizzle+rr)", c, stages=3, swizzle=True) 387 | run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_rr, a, b, "(mma2x4+warp4x4x2+stage2+dsmem+swizzle+rr)", c, stages=2, swizzle=True) 388 | run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_x4, a, b, "(mma2x4+warp4x4x2+stage4+dsmem+swizzle+x4)", c, stages=4, swizzle=True) 389 | run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_x4, a, b, "(mma2x4+warp4x4x2+stage3+dsmem+swizzle+x4)", c, stages=3, swizzle=True) 390 | run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_x4, a, b, "(mma2x4+warp4x4x2+stage2+dsmem+swizzle+x4)", c, stages=2, swizzle=True) 391 | # TN(MMA/CuTe), TN layout: A row major with shape [M,K], B col major with shape [K,N] 392 | if any((args.enable_mma_tn, args.enable_cute_tn)): 
393 | pretty_print_line("TN(MMA/CuTe)") 394 | if args.enable_mma_tn: 395 | run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn, a, b_col_major, "tn(mma2x4+warp4x4+stage3+dsmem)", c, stages=3) 396 | run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn, a, b_col_major, "tn(mma2x4+warp4x4+stage2+dsmem)", c, stages=2) 397 | run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_tn_swizzle_x4, a, b_col_major, "tn(mma2x4+warp4x4x2+stage4+dsmem+swizzle)", c, stages=4) 398 | run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_tn_swizzle_x4, a, b_col_major, "tn(mma2x4+warp4x4x2+stage3+dsmem+swizzle)", c, stages=3) 399 | run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_tn_swizzle_x4, a, b_col_major, "tn(mma2x4+warp4x4x2+stage2+dsmem+swizzle)", c, stages=2) 400 | run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn, a, b_col_major, "tn(mma2x4+warp4x4+stage3+dsmem+swizzle)", c, stages=3, swizzle=True) 401 | run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn, a, b_col_major, "tn(mma2x4+warp4x4+stage2+dsmem+swizzle)", c, stages=2, swizzle=True) 402 | run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_tn_swizzle_x4, a, b_col_major, "tn(mma2x4+warp4x4x2+stage4+dsmem+swizzle)", c, stages=4, swizzle=True) 403 | run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_tn_swizzle_x4, a, b_col_major, "tn(mma2x4+warp4x4x2+stage3+dsmem+swizzle)", c, stages=3, swizzle=True) 404 | run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_tn_swizzle_x4, a, b_col_major, "tn(mma2x4+warp4x4x2+stage2+dsmem+swizzle)", c, stages=2, swizzle=True) 405 | if args.enable_cute_tn: 406 | run_benchmark(hgemm.hgemm_mma_stages_block_swizzle_tn_cute, a, b_col_major, "tn(cute+stage4+swizzle)", c, stages=4) 407 | run_benchmark(hgemm.hgemm_mma_stages_block_swizzle_tn_cute, a, b_col_major, "tn(cute+stage3+swizzle)", c, stages=3) 408 | run_benchmark(hgemm.hgemm_mma_stages_block_swizzle_tn_cute, a, b_col_major, "tn(cute+stage2+swizzle)", c, stages=2) 409 | run_benchmark(hgemm.hgemm_mma_stages_block_swizzle_tn_cute, a, b_col_major, "tn(cute+stage4+swizzle)", c, stages=4, swizzle=True) 410 | run_benchmark(hgemm.hgemm_mma_stages_block_swizzle_tn_cute, a, b_col_major, "tn(cute+stage3+swizzle)", c, stages=3, swizzle=True) 411 | run_benchmark(hgemm.hgemm_mma_stages_block_swizzle_tn_cute, a, b_col_major, "tn(cute+stage2+swizzle)", c, stages=2, swizzle=True) 412 | # TN layout: cublas 413 | if not args.disable_cublas_tn and any((args.enable_mma_tn, args.enable_cute_tn)): 414 | run_benchmark(hgemm.hgemm_cublas_tensor_op_tn, a, b_col_major, "tn(cublas)", c) 415 | # NN layout: cublas/torch 416 | if (not args.disable_cublas) and any(( 417 | args.enable_mma, args.enable_mma_all, args.enable_wmma, args.enable_wmma_all, 418 | args.enable_cuda, args.enable_cuda_all, args.enable_torch)): 419 | run_benchmark(hgemm.hgemm_cublas_tensor_op_nn, a, b, "(cublas)", c) 420 | if args.enable_torch: 421 | run_benchmark(partial(torch.matmul, out=c), a, b, "(torch)") 422 | torch.cuda.synchronize() 423 | # Avoid OOM 424 | del a; a = None 425 | del b; b = None 426 | del c; c = None 427 | del b_col_major; 428 | b_col_major = None 429 | gc.collect() 430 | torch.cuda.empty_cache() 431 | gc.collect() 432 | pretty_print_line() 433 | 434 | if args.show_memory: 435 | pretty_print_line() 436 | print(torch.cuda.memory_summary()) 437 | pretty_print_line() 438 | 439 | if args.plot_flops: 440 | plot_tflops() 441 | 
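A note on the measurement above: run_benchmark launches each kernel `warmup` times untimed, synchronizes, times `iters` launches, and reports TFLOPS as 2*M*N*K divided by the mean launch time (one multiply and one add per inner-product term). Below is a minimal, self-contained sketch of that pattern with torch.matmul standing in for a custom HGEMM kernel; the helper name bench_tflops is illustrative only and is not part of hgemm.py.

    # Minimal sketch of the warmup/iters timing pattern used by run_benchmark.
    # torch.matmul stands in for a custom HGEMM kernel; bench_tflops is a
    # hypothetical helper, not part of this repo.
    import time
    import torch

    def bench_tflops(M: int, N: int, K: int, warmup: int = 5, iters: int = 20) -> float:
        a = torch.randn((M, K), dtype=torch.half, device="cuda")
        b = torch.randn((K, N), dtype=torch.half, device="cuda")
        c = torch.empty((M, N), dtype=torch.half, device="cuda")
        for _ in range(warmup):           # untimed launches: warm clocks and caches
            torch.matmul(a, b, out=c)
        torch.cuda.synchronize()          # drain the queue before starting the clock
        start = time.time()
        for _ in range(iters):
            torch.matmul(a, b, out=c)
        torch.cuda.synchronize()          # wait for all timed launches to finish
        mean_time_secs = (time.time() - start) / iters
        return (2.0 * M * N * K) * 1e-12 / mean_time_secs  # 2*M*N*K FLOPs per GEMM

    if __name__ == "__main__":
        print(f"torch.matmul: {bench_tflops(4096, 4096, 4096):.2f} TFLOPS")
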
--------------------------------------------------------------------------------
/kernels/hgemm/mma/basic/hgemm_mma_stage_tn.cu:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 | #include
5 | #include
6 | #include
7 | #include
8 | #include
9 | #include
10 | #include
11 | #include
12 | using namespace nvcuda;
13 | 
14 | #define WARP_SIZE 32
15 | #define DEVICE_INLINE __device__ inline
16 | #define HOST_DEVICE_INLINE __device__ __host__ inline
17 | #define INT4(value) (reinterpret_cast<int4*>(&(value))[0])
18 | #define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
19 | #define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
20 | #define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
21 | #define LDST32BITS(value) (reinterpret_cast<half2*>(&(value))[0])
22 | #define LDST64BITS(value) (reinterpret_cast<float2*>(&(value))[0])
23 | #define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])
24 | // gmem -> smem
25 | #define CP_ASYNC_COMMIT_GROUP() asm volatile("cp.async.commit_group;\n" ::)
26 | #define CP_ASYNC_WAIT_ALL() asm volatile("cp.async.wait_all;\n" ::)
27 | #define CP_ASYNC_WAIT_GROUP(n) asm volatile("cp.async.wait_group %0;\n" ::"n"(n))
28 | // ca(cache all, L1 + L2): support 4, 8, 16 bytes, cg(cache global, L2): only support 16 bytes.
29 | #define CP_ASYNC_CA(dst, src, bytes) asm volatile("cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes))
30 | #define CP_ASYNC_CG(dst, src, bytes) asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes))
31 | // smem -> gmem: requires sm_90 or higher.
32 | #define CP_ASYNC_BULK_COMMIT_GROUP() asm volatile("cp.async.bulk.commit_group;\n" ::)
33 | #define CP_ASYNC_BULK_WAIT_ALL() asm volatile("cp.async.bulk.wait_all;\n" ::)
34 | #define CP_ASYNC_BULK_WAIT_GROUP(n) asm volatile("cp.async.bulk.wait_group %0;\n" ::"n"(n))
35 | #define CP_ASYNC_BULK(dst, src, bytes) asm volatile("cp.async.bulk.global.shared::cta.bulk_group.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes))
36 | // ldmatrix
37 | #define LDMATRIX_X1(R, addr) asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n" : "=r"(R) : "r"(addr))
38 | #define LDMATRIX_X2(R0, R1, addr) asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" : "=r"(R0), "=r"(R1) : "r"(addr))
39 | #define LDMATRIX_X4(R0, R1, R2, R3, addr) asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(R0), "=r"(R1), "=r"(R2), "=r"(R3) : "r"(addr))
40 | #define LDMATRIX_X1_T(R, addr) asm volatile("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n" : "=r"(R) : "r"(addr))
41 | #define LDMATRIX_X2_T(R0, R1, addr) asm volatile("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n" : "=r"(R0), "=r"(R1) : "r"(addr))
42 | #define LDMATRIX_X4_T(R0, R1, R2, R3, addr) asm volatile("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(R0), "=r"(R1), "=r"(R2), "=r"(R3) : "r"(addr))
43 | // stmatrix: requires sm_90 or higher.
44 | #define STMATRIX_X1(addr, R) asm volatile("stmatrix.sync.aligned.x1.m8n8.shared.b16 [%0], {%1};\n" :: "r"(addr), "r"(R))
45 | #define STMATRIX_X2(addr, R0, R1) asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n" :: "r"(addr), "r"(R0), "r"(R1))
46 | #define STMATRIX_X4(addr, R0, R1, R2, R3) asm volatile("stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" :: "r"(addr), "r"(R0), "r"(R1), "r"(R2), "r"(R3))
47 | #define STMATRIX_X1_T(addr, R) asm volatile("stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [%0], {%1};\n" :: "r"(addr), "r"(R))
48 | #define STMATRIX_X2_T(addr, R0, R1) asm volatile("stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [%0], {%1, %2};\n" :: "r"(addr), "r"(R0), "r"(R1))
49 | #define STMATRIX_X4_T(addr, R0, R1, R2, R3) asm volatile("stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" :: "r"(addr), "r"(R0), "r"(R1), "r"(R2), "r"(R3))
50 | // mma m16n8k16
51 | #define HMMA16816(RD0, RD1, RA0, RA1, RA2, RA3, RB0, RB1, RC0, RC1) asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\n" : "=r"(RD0), "=r"(RD1) : "r"(RA0), "r"(RA1), "r"(RA2), "r"(RA3), "r"(RB0), "r"(RB1), "r"(RC0), "r"(RC1))
52 | 
53 | HOST_DEVICE_INLINE
54 | int div_ceil(int a, int b) { return (a % b != 0) ? (a / b + 1) : (a / b); }
55 | 
56 | // NN: A/B/C All row major
57 | // TN: A row major MxK, B col major NxK, C row major MxN
58 | // 128x128, mma2x4, warp4x4(64,32,16), stages, block swizzle, dsmem
59 | template <const int MMA_M,
60 |           const int MMA_N,
61 |           const int MMA_K,
62 |           const int MMA_TILE_M,
63 |           const int MMA_TILE_N,
64 |           const int WARP_TILE_M,
65 |           const int WARP_TILE_N,
66 |           const int A_PAD,
67 |           const int B_PAD,
68 |           const int K_STAGE,
69 |           const bool BLOCK_SWIZZLE>
70 | __global__ void __launch_bounds__(256)
71 | hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn_kernel(
72 |   half* A, half* B, half* C, int M, int N, int K) {
73 |   // BLOCK_SWIZZLE 0/1 controls whether block swizzling is used.
74 |   const int bx = ((int) BLOCK_SWIZZLE) * blockIdx.z * gridDim.x + blockIdx.x;
75 |   const int by = blockIdx.y;
76 |   const int NUM_K_TILES = div_ceil(K, MMA_K);
77 |   constexpr int BM = MMA_M * MMA_TILE_M * WARP_TILE_M; // 16*2*4=128
78 |   constexpr int BN = MMA_N * MMA_TILE_N * WARP_TILE_N; // 8*4*4=128
79 |   constexpr int BK = MMA_K; // 16
80 | 
81 |   extern __shared__ half smem[];
82 |   half* s_a = smem;
83 |   half* s_b = smem + K_STAGE * BM * (BK + A_PAD);
84 |   constexpr int s_a_stage_offset = BM * (BK + A_PAD); // BMxBK 128*16
85 |   constexpr int s_b_stage_offset = BN * (BK + B_PAD); // BNxBK 128*16
86 | 
87 |   const int tid = threadIdx.y * blockDim.x + threadIdx.x; // within block
88 |   const int warp_id = tid / WARP_SIZE; // 0~7 warp_id within block
89 |   const int lane_id = tid % WARP_SIZE; // 0~31
90 |   const int warp_m = warp_id % 2; // 0,1
91 |   const int warp_n = warp_id / 2; // 0,1,2,3
92 | 
93 |   int load_smem_a_m = tid / 2; // row 0~127
94 |   int load_smem_a_k = (tid % 2 == 0) ? 0 : 8; // col 0,8
95 |   int load_smem_b_n = tid / 2; // row 0~127
96 |   int load_smem_b_k = (tid % 2 == 0) ? 0 : 8; // col 0,8
97 |   int load_gmem_a_m = by * BM + load_smem_a_m; // global row of c
98 |   int load_gmem_b_n = bx * BN + load_smem_b_n; // global col of c
99 |   if (load_gmem_a_m >= M || load_gmem_b_n >= N) return;
100 | 
101 |   uint32_t RC[WARP_TILE_M][WARP_TILE_N][2];
102 |   #pragma unroll
103 |   for (int i = 0; i < WARP_TILE_M; ++i) {
104 |     #pragma unroll
105 |     for (int j = 0; j < WARP_TILE_N; ++j) {
106 |       RC[i][j][0] = 0;
107 |       RC[i][j][1] = 0;
108 |     }
109 |   }
110 | 
111 |   // May avoid cvta overhead: convert the smem base pointers to shared addresses only once for cp.async.
112 | uint32_t smem_a_base_ptr = __cvta_generic_to_shared(s_a); 113 | uint32_t smem_b_base_ptr = __cvta_generic_to_shared(s_b); 114 | 115 | #pragma unroll 116 | for (int k = 0; k < (K_STAGE - 1); ++k) { // 0, 1 117 | int load_gmem_a_k = k * BK + load_smem_a_k; // global col of a 118 | int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k; 119 | int load_gmem_b_k = k * BK + load_smem_b_k; // global col of b 120 | int load_gmem_b_addr = load_gmem_b_n * K + load_gmem_b_k; 121 | 122 | uint32_t load_smem_a_ptr = ( 123 | smem_a_base_ptr + (k * s_a_stage_offset + 124 | load_smem_a_m * (BK + A_PAD) + 125 | load_smem_a_k) * sizeof(half) 126 | ); 127 | CP_ASYNC_CG(load_smem_a_ptr, &A[load_gmem_a_addr], 16); 128 | 129 | uint32_t load_smem_b_ptr = ( 130 | smem_b_base_ptr + (k * s_b_stage_offset + 131 | load_smem_b_n * (BK + B_PAD) + 132 | load_smem_b_k) * sizeof(half) 133 | ); 134 | CP_ASYNC_CG(load_smem_b_ptr, &B[load_gmem_b_addr], 16); 135 | 136 | CP_ASYNC_COMMIT_GROUP(); 137 | } 138 | 139 | CP_ASYNC_WAIT_GROUP(K_STAGE-2); // s2->0, s3->1, s4->2 140 | __syncthreads(); 141 | 142 | #pragma unroll 143 | for (int k = (K_STAGE - 1); k < NUM_K_TILES; ++k) { 144 | // gmem -> smem 145 | // s2/4 can use bitwise ops but s3 can not, so, we use mod 146 | // ops for all stages kernel. s2: (k + 1)&1, s4: (k + 1)&3 147 | // s3: (k + 1) % 3 148 | int smem_sel = (k + 1) % K_STAGE; // s3 k 2->0, k 3->1, k 4->2... 149 | int smem_sel_next = k % K_STAGE; // s3 k 2->2, k 3->0, k 4->1... 150 | 151 | int load_gmem_a_k = k * BK + load_smem_a_k; // global col of a 152 | int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k; 153 | int load_gmem_b_k = k * BK + load_smem_b_k; // global col of b 154 | int load_gmem_b_addr = load_gmem_b_n * K + load_gmem_b_k; 155 | 156 | uint32_t load_smem_a_ptr = ( 157 | smem_a_base_ptr + (smem_sel_next * s_a_stage_offset + 158 | load_smem_a_m * (BK + A_PAD) + 159 | load_smem_a_k) * sizeof(half) 160 | ); 161 | CP_ASYNC_CG(load_smem_a_ptr, &A[load_gmem_a_addr], 16); 162 | 163 | uint32_t load_smem_b_ptr = ( 164 | smem_b_base_ptr + (smem_sel_next * s_b_stage_offset + 165 | load_smem_b_n * (BK + B_PAD) + 166 | load_smem_b_k) * sizeof(half) 167 | ); 168 | CP_ASYNC_CG(load_smem_b_ptr, &B[load_gmem_b_addr], 16); 169 | 170 | CP_ASYNC_COMMIT_GROUP(); 171 | 172 | uint32_t RA[WARP_TILE_M][4]; 173 | uint32_t RB[WARP_TILE_N][2]; 174 | // smem -> reg 175 | #pragma unroll 176 | for (int i = 0; i < WARP_TILE_M; ++i) { 177 | int warp_smem_a_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M; 178 | int lane_smem_a_m = warp_smem_a_m + lane_id % 16; // 0~15 179 | int lane_smem_a_k = (lane_id / 16) * 8; // 0,8 180 | uint32_t lane_smem_a_ptr = ( 181 | smem_a_base_ptr + (smem_sel * s_a_stage_offset + 182 | lane_smem_a_m * (BK + A_PAD) + 183 | lane_smem_a_k) * sizeof(half) 184 | ); 185 | LDMATRIX_X4(RA[i][0], RA[i][1], RA[i][2], RA[i][3], lane_smem_a_ptr); 186 | } 187 | 188 | #pragma unroll 189 | for (int j = 0; j < WARP_TILE_N; ++j) { 190 | int warp_smem_b_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N; 191 | int lane_smem_b_n = warp_smem_b_n + lane_id % 8; // 0~7, MMA_N=8 192 | int lane_smem_b_k = ((lane_id / 8) % 2) * 8; // 0,8 193 | uint32_t lane_smem_b_ptr = ( 194 | smem_b_base_ptr + (smem_sel * s_b_stage_offset + 195 | lane_smem_b_n * (BK + B_PAD) + 196 | lane_smem_b_k) * sizeof(half) 197 | ); 198 | LDMATRIX_X2(RB[j][0], RB[j][1], lane_smem_b_ptr); 199 | } 200 | 201 | // MMA compute 202 | #pragma unroll 203 | for (int i = 0; i < WARP_TILE_M; ++i) { 204 | #pragma unroll 205 | for (int j = 0; j < WARP_TILE_N; 
++j) { 206 | HMMA16816(RC[i][j][0], RC[i][j][1], 207 | RA[i][0], RA[i][1], RA[i][2], RA[i][3], 208 | RB[j][0], RB[j][1], 209 | RC[i][j][0], RC[i][j][1]); 210 | } 211 | } 212 | 213 | CP_ASYNC_WAIT_GROUP(K_STAGE-2); 214 | __syncthreads(); 215 | } 216 | 217 | // make sure all memory issues ready. 218 | if ((K_STAGE - 2) > 0) { 219 | CP_ASYNC_WAIT_GROUP(0); 220 | __syncthreads(); 221 | } 222 | 223 | // processing last (K_STAGE-1) k iters. 224 | { 225 | #pragma unroll 226 | for (int k = 0; k < (K_STAGE - 1); k++) { 227 | uint32_t RA[WARP_TILE_M][4]; 228 | uint32_t RB[WARP_TILE_N][2]; 229 | 230 | int stage_sel = ((NUM_K_TILES - (K_STAGE - 1) + k) % K_STAGE); 231 | // ldmatrix for s_a, ldmatrix.trans for s_b. 232 | // smem -> reg 233 | #pragma unroll 234 | for (int i = 0; i < WARP_TILE_M; ++i) { 235 | int warp_smem_a_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M; 236 | int lane_smem_a_m = warp_smem_a_m + lane_id % 16; // 0~15 237 | int lane_smem_a_k = (lane_id / 16) * 8; // 0,8 238 | uint32_t lane_smem_a_ptr = ( 239 | smem_a_base_ptr + (stage_sel * s_a_stage_offset + 240 | lane_smem_a_m * (BK + A_PAD) + 241 | lane_smem_a_k) * sizeof(half) 242 | ); 243 | LDMATRIX_X4(RA[i][0], RA[i][1], RA[i][2], RA[i][3], lane_smem_a_ptr); 244 | } 245 | 246 | #pragma unroll 247 | for (int j = 0; j < WARP_TILE_N; ++j) { 248 | int warp_smem_b_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N; 249 | int lane_smem_b_n = warp_smem_b_n + lane_id % 8; // 0~7, MMA_N=8 250 | int lane_smem_b_k = ((lane_id / 8) % 2) * 8; // 0,8 251 | uint32_t lane_smem_b_ptr = ( 252 | smem_b_base_ptr + (stage_sel * s_b_stage_offset + 253 | lane_smem_b_n * (BK + B_PAD) + 254 | lane_smem_b_k) * sizeof(half) 255 | ); 256 | LDMATRIX_X2(RB[j][0], RB[j][1], lane_smem_b_ptr); 257 | } 258 | 259 | // MMA compute 260 | #pragma unroll 261 | for (int i = 0; i < WARP_TILE_M; ++i) { 262 | #pragma unroll 263 | for (int j = 0; j < WARP_TILE_N; ++j) { 264 | HMMA16816(RC[i][j][0], RC[i][j][1], 265 | RA[i][0], RA[i][1], RA[i][2], RA[i][3], 266 | RB[j][0], RB[j][1], 267 | RC[i][j][0], RC[i][j][1]); 268 | } 269 | } 270 | } 271 | } 272 | 273 | { 274 | for (int i = 0; i < WARP_TILE_M; ++i) { 275 | // How to use LDST128BITS here? __shfl_sync -> lane 0 -> store 8 half. 276 | // thus, we only need 8 memory issues with 128 bits after shfl_sync. 277 | // may reuse RA[4][4] as RC0 ? only new RC1[4][4]. 278 | uint32_t RC0[WARP_TILE_N][4]; 279 | uint32_t RC1[WARP_TILE_N][4]; 280 | #pragma unroll 281 | for (int j = 0; j < WARP_TILE_N; ++j) { 282 | // How to use LDST128BITS here? __shfl_sync -> lane 0 -> store 8 half. 283 | // thus, we only need 8 memory issues with 128 bits after shfl_sync. 
284 |         RC0[j][0] = RC[i][j][0];
285 |         RC1[j][0] = RC[i][j][1];
286 |         RC0[j][1] = __shfl_sync((0xffffffff), RC[i][j][0], lane_id + 1);
287 |         RC0[j][2] = __shfl_sync((0xffffffff), RC[i][j][0], lane_id + 2);
288 |         RC0[j][3] = __shfl_sync((0xffffffff), RC[i][j][0], lane_id + 3);
289 |         RC1[j][1] = __shfl_sync((0xffffffff), RC[i][j][1], lane_id + 1);
290 |         RC1[j][2] = __shfl_sync((0xffffffff), RC[i][j][1], lane_id + 2);
291 |         RC1[j][3] = __shfl_sync((0xffffffff), RC[i][j][1], lane_id + 3);
292 |       }
293 | 
294 |       if (lane_id % 4 == 0) {
295 |         int store_warp_smem_c_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M;
296 |         int store_lane_gmem_c_m = by * BM + store_warp_smem_c_m + lane_id / 4;
297 |         #pragma unroll
298 |         for (int j = 0; j < WARP_TILE_N; ++j) {
299 |           int store_warp_smem_c_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N;
300 |           int store_lane_gmem_c_n = bx * BN + store_warp_smem_c_n;
301 |           int store_gmem_c_addr_0 = store_lane_gmem_c_m * N + store_lane_gmem_c_n;
302 |           int store_gmem_c_addr_1 = (store_lane_gmem_c_m + 8) * N + store_lane_gmem_c_n;
303 |           LDST128BITS(C[store_gmem_c_addr_0]) = LDST128BITS(RC0[j][0]);
304 |           LDST128BITS(C[store_gmem_c_addr_1]) = LDST128BITS(RC1[j][0]);
305 |         }
306 |       }
307 |     }
308 |   }
309 | }
310 | 
311 | // build cpp binary
312 | #ifndef NO_MMA_HGEMM_BIN
313 | 
314 | #include "utils.h"
315 | 
316 | // 128x128, mma2x4, warp4x4(64,32,16), stages, block swizzle, dsmem, TN
317 | #define LAUNCH_16816_STAGE_SWIZZLE_MMA2x4_WARP4x4_DSMEM_TN_KERNEL(stages, stride) \
318 | { \
319 |   const int smem_max_size = ( \
320 |     (stages) * BM * (BK + A_PAD) * sizeof(half) + \
321 |     (stages) * BN * (BK + B_PAD) * sizeof(half)); \
322 |   cudaFuncSetAttribute( \
323 |     hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn_kernel< \
324 |       MMA_M, MMA_N, MMA_K, MMA_TILE_M, MMA_TILE_N, \
325 |       WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD, (stages), true>, \
326 |     cudaFuncAttributeMaxDynamicSharedMemorySize, \
327 |     98304); \
328 |   const int N_SWIZZLE = (N + (stride) - 1) / (stride); \
329 |   dim3 block(NUM_THREADS); \
330 |   dim3 grid((div_ceil(N, BN) + N_SWIZZLE - 1) / N_SWIZZLE, \
331 |             div_ceil(M, BM), \
332 |             N_SWIZZLE); \
333 |   hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn_kernel< \
334 |     MMA_M, MMA_N, MMA_K, MMA_TILE_M, MMA_TILE_N, \
335 |     WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD, (stages), true><<< \
336 |     grid, block, smem_max_size>>>( \
337 |       a, b, c, \
338 |       M, N, K \
339 |   ); \
340 | }
341 | 
342 | template <const int K_STAGE, const int BLOCK_SWIZZLE_STRIDE>
343 | void launch_hgemm_mma_m16n8k16_tn(
344 |   half* a, half* b, half* c, int M, int N, int K) {
345 |   constexpr int MMA_M = 16;
346 |   constexpr int MMA_N = 8;
347 |   constexpr int MMA_K = 16;
348 |   constexpr int MMA_TILE_M = 2;
349 |   constexpr int MMA_TILE_N = 4;
350 |   constexpr int WARP_TILE_M = 4;
351 |   constexpr int WARP_TILE_N = 4;
352 |   constexpr int A_PAD = 0;
353 |   constexpr int B_PAD = 0;
354 |   constexpr int NUM_THREADS = (
355 |     MMA_TILE_M * MMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256
356 |   constexpr int BM = MMA_M * MMA_TILE_M * WARP_TILE_M;
357 |   constexpr int BN = MMA_N * MMA_TILE_N * WARP_TILE_N;
358 |   constexpr int BK = MMA_K;
359 |   // smem per stage: BM*(BK+A_PAD)*2 + BN*(BK+B_PAD)*2 = 128*16*2 + 128*16*2 = 8KB
360 |   // s2: ~16KB, s3: ~24KB
361 |   // s4: ~32KB, s5: ~40KB
362 |   // (A_PAD=0, B_PAD=0 in this launcher)
363 |   LAUNCH_16816_STAGE_SWIZZLE_MMA2x4_WARP4x4_DSMEM_TN_KERNEL(
364 |     K_STAGE, BLOCK_SWIZZLE_STRIDE);
365 | }
366 | 
367 | #ifdef HGEMM_MMA_DEBUG
368 | #include <string> // std::stoi
369 | #endif
370 | 
371 | 
372 | int main(int argc, char *argv[]) {
373 | #ifdef HGEMM_MMA_DEBUG
374 |   const int test_num = 1;
375 | #else
376 |   const int test_num = 64;
377 | #endif
378 |   int M_list[test_num];
379 |   int N_list[test_num];
380 |   int K_list[test_num];
381 | 
382 |   for (int i = 0; i < test_num; i++) {
383 |     M_list[i] = (i + 1) * 256;
384 |     N_list[i] = (i + 1) * 256;
385 |     K_list[i] = (i + 1) * 256;
386 |   }
387 | 
388 | #ifdef HGEMM_MMA_DEBUG
389 |   if (argc > 1) M_list[0] = std::stoi(argv[1]);
390 |   if (argc > 2) N_list[0] = std::stoi(argv[2]);
391 |   if (argc > 3) K_list[0] = std::stoi(argv[3]);
392 | #endif
393 | 
394 | #ifdef HGEMM_MMA_DEBUG
395 |   int outer_repeat = 1, inner_repeat = 1, warmup = 1;
396 |   if (argc > 4) warmup = std::stoi(argv[4]);
397 |   if (argc > 5) inner_repeat = std::stoi(argv[5]);
398 | #else
399 |   int outer_repeat = 10, inner_repeat = 1, warmup = 1;
400 | #endif
401 | 
402 |   printf("ALGO = MMA16816 HGEMM TN MMA=2x4 WARP=4x4 STAGES=2 BLOCK SWIZZLE=2048\n");
403 | #ifndef HGEMM_MMA_DEBUG
404 |   for (int j = 0; j < 5; j++) {
405 |     int M = M_list[j], N = N_list[j], K = K_list[j];
406 |     float max_error = gemm_error_check_tn(
407 |       launch_hgemm_mma_m16n8k16_tn<2, 2048>,
408 |       M, N, K);
409 |     printf("M N K = %6d %6d %6d, ", M, N, K);
410 |     printf("Max Error = %f\n", max_error);
411 |   }
412 | #endif
413 | 
414 |   for (int j = 0; j < test_num; j++) {
415 |     int M = M_list[j], N = N_list[j], K = K_list[j];
416 | 
417 |     double max_sec = 0.0;
418 |     double min_sec = DBL_MAX;
419 |     double total_sec = 0.0;
420 | 
421 |     for (int k = 0; k < outer_repeat; k++) {
422 |       double this_sec = perf_gemm(
423 |         launch_hgemm_mma_m16n8k16_tn<2, 2048>,
424 |         M, N, K, inner_repeat, warmup);
425 |       max_sec = max(max_sec, this_sec);
426 |       min_sec = min(min_sec, this_sec);
427 |       total_sec += this_sec;
428 |     }
429 | 
430 |     // 1 TFLOPS = 10^12 FLOPS
431 |     // ref: https://imgtec.eetrend.com/blog/2021/100062210.html.
432 |     double avg_sec = total_sec / outer_repeat;
433 |     double avg_Tflops = ((double)M) * N * K * 2 * 1e-12 / avg_sec;
434 | 
435 |     printf("M N K = %6d %6d %6d, W = %1d, R = %2d ", M, N, K, warmup, inner_repeat);
436 |     printf("Time = %12.8lf %12.8lf %12.8lf s, ", min_sec, avg_sec, max_sec);
437 |     printf("AVG Performance = %10.4lf Tflops\n", avg_Tflops);
438 |   }
439 | 
440 |   return 0;
441 | }
442 | 
443 | 
444 | #else
445 | 
446 | // --------------------- PyTorch bindings for custom kernel -----------------------
447 | #include <torch/types.h>
448 | #include <torch/extension.h>
449 | #define STRINGFY(str) #str
450 | #define TORCH_BINDING_COMMON_EXTENSION(func) \
451 |   m.def(STRINGFY(func), &func, STRINGFY(func));
452 | 
453 | #define CHECK_TORCH_TENSOR_DTYPE(T, th_type) \
454 | if(((T).options().dtype() != (th_type))) { \
455 |   std::cout << "Tensor Info:" << (T).options() << std::endl; \
456 |   throw std::runtime_error("values must be "#th_type); \
457 | }
458 | 
459 | #define CHECK_TORCH_TENSOR_SHAPE(T, S0, S1) \
460 | if (((T).size(0) != (S0)) || ((T).size(1) != (S1))) { \
461 |   throw std::runtime_error("Tensor size mismatch!"); \
462 | }
463 | 
464 | // 128x128, mma2x4, warp4x4(64,32,16), stages, block swizzle, dsmem, TN
465 | #define LAUNCH_16816_STAGE_SWIZZLE_MMA2x4_WARP4x4_DSMEM_TN_KERNEL(stages, stride) \
466 | { \
467 |   const int smem_max_size = ( \
468 |     (stages) * BM * (BK + A_PAD) * sizeof(half) + \
469 |     (stages) * BN * (BK + B_PAD) * sizeof(half)); \
470 |   cudaFuncSetAttribute( \
471 |     hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn_kernel< \
472 |       MMA_M, MMA_N, MMA_K, MMA_TILE_M, MMA_TILE_N, \
473 |       WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD, (stages), true>, \
474 |     cudaFuncAttributeMaxDynamicSharedMemorySize, \
475 |     98304); \
476 |   const int N_SWIZZLE = (N + (stride) - 1) / (stride); \
477 |   dim3 block(NUM_THREADS); \
478 |   dim3 grid((div_ceil(N, BN) + N_SWIZZLE - 1) / N_SWIZZLE, \
479 |             div_ceil(M, BM), \
480 |             N_SWIZZLE); \
481 |   hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn_kernel< \
482 |     MMA_M, MMA_N, MMA_K, MMA_TILE_M, MMA_TILE_N, \
483 |     WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD, (stages), true><<< \
484 |     grid, block, smem_max_size>>>( \
485 |       reinterpret_cast<half*>(a.data_ptr()), \
486 |       reinterpret_cast<half*>(b.data_ptr()), \
487 |       reinterpret_cast<half*>(c.data_ptr()), \
488 |       M, N, K \
489 |   ); \
490 | }
491 | 
492 | #define LAUNCH_16816_STAGE_NO_SWIZZLE_MMA2x4_WARP4x4_DSMEM_TN_KERNEL(stages) \
493 | { \
494 |   const int smem_max_size = ( \
495 |     (stages) * BM * (BK + A_PAD) * sizeof(half) + \
496 |     (stages) * BN * (BK + B_PAD) * sizeof(half)); \
497 |   cudaFuncSetAttribute( \
498 |     hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn_kernel< \
499 |       MMA_M, MMA_N, MMA_K, MMA_TILE_M, MMA_TILE_N, \
500 |       WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD, (stages), false>, \
501 |     cudaFuncAttributeMaxDynamicSharedMemorySize, \
502 |     98304); \
503 |   dim3 block(NUM_THREADS); \
504 |   dim3 grid(div_ceil(N, BN), div_ceil(M, BM)); \
505 |   hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn_kernel< \
506 |     MMA_M, MMA_N, MMA_K, MMA_TILE_M, MMA_TILE_N, \
507 |     WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD, (stages), false><<< \
508 |     grid, block, smem_max_size>>>( \
509 |       reinterpret_cast<half*>(a.data_ptr()), \
510 |       reinterpret_cast<half*>(b.data_ptr()), \
511 |       reinterpret_cast<half*>(c.data_ptr()), \
512 |       M, N, K \
513 |   ); \
514 | }
515 | 
516 | // 128x128, mma2x4, warp4x4(64,32,16), stages, block swizzle, dsmem
517 | void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn(
518 |   torch::Tensor a, torch::Tensor b, torch::Tensor c,
519 |   int stages, bool swizzle, int swizzle_stride)
{ 520 | CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf) 521 | CHECK_TORCH_TENSOR_DTYPE(b, torch::kHalf) 522 | CHECK_TORCH_TENSOR_DTYPE(c, torch::kHalf) 523 | const int M = a.size(0); 524 | const int K = a.size(1); 525 | const int N = b.size(1); 526 | CHECK_TORCH_TENSOR_SHAPE(a, M, K) 527 | CHECK_TORCH_TENSOR_SHAPE(b, K, N) 528 | CHECK_TORCH_TENSOR_SHAPE(c, M, N) 529 | constexpr int MMA_M = 16; 530 | constexpr int MMA_N = 8; 531 | constexpr int MMA_K = 16; 532 | constexpr int MMA_TILE_M = 2; 533 | constexpr int MMA_TILE_N = 4; 534 | constexpr int WARP_TILE_M = 4; 535 | constexpr int WARP_TILE_N = 4; 536 | constexpr int A_PAD = 0; // 0,8,16 537 | constexpr int B_PAD = 8; // 0,8,16 538 | constexpr int NUM_THREADS= ( 539 | MMA_TILE_M * MMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256 540 | constexpr int BM = MMA_M * MMA_TILE_M * WARP_TILE_M; 541 | constexpr int BN = MMA_N * MMA_TILE_N * WARP_TILE_N; 542 | constexpr int BK = MMA_K; 543 | 544 | if (swizzle) { 545 | // assert(swizzle_stride % 256 == 0); 546 | switch (stages) 547 | { 548 | case 2: 549 | LAUNCH_16816_STAGE_SWIZZLE_MMA2x4_WARP4x4_DSMEM_TN_KERNEL(2, swizzle_stride); 550 | break; 551 | case 3: 552 | LAUNCH_16816_STAGE_SWIZZLE_MMA2x4_WARP4x4_DSMEM_TN_KERNEL(3, swizzle_stride); 553 | break; 554 | case 4: 555 | LAUNCH_16816_STAGE_SWIZZLE_MMA2x4_WARP4x4_DSMEM_TN_KERNEL(4, swizzle_stride); 556 | break; 557 | case 5: 558 | LAUNCH_16816_STAGE_SWIZZLE_MMA2x4_WARP4x4_DSMEM_TN_KERNEL(5, swizzle_stride); 559 | break; 560 | default: 561 | LAUNCH_16816_STAGE_SWIZZLE_MMA2x4_WARP4x4_DSMEM_TN_KERNEL(2, swizzle_stride); 562 | break; 563 | } 564 | } else { 565 | switch (stages) 566 | { 567 | case 2: 568 | LAUNCH_16816_STAGE_NO_SWIZZLE_MMA2x4_WARP4x4_DSMEM_TN_KERNEL(2); 569 | break; 570 | case 3: 571 | LAUNCH_16816_STAGE_NO_SWIZZLE_MMA2x4_WARP4x4_DSMEM_TN_KERNEL(3); 572 | break; 573 | case 4: 574 | LAUNCH_16816_STAGE_NO_SWIZZLE_MMA2x4_WARP4x4_DSMEM_TN_KERNEL(4); 575 | break; 576 | case 5: 577 | LAUNCH_16816_STAGE_NO_SWIZZLE_MMA2x4_WARP4x4_DSMEM_TN_KERNEL(5); 578 | break; 579 | default: 580 | LAUNCH_16816_STAGE_NO_SWIZZLE_MMA2x4_WARP4x4_DSMEM_TN_KERNEL(2); 581 | break; 582 | } 583 | } 584 | } 585 | 586 | #endif 587 | --------------------------------------------------------------------------------
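The LAUNCH_* macros above size the kernel's dynamic shared memory as stages * BM * (BK + A_PAD) * sizeof(half) for the A tiles plus stages * BN * (BK + B_PAD) * sizeof(half) for the B tiles, and raise cudaFuncAttributeMaxDynamicSharedMemorySize to 98304 bytes so the deeper pipelines fit. A small host-side sketch of that budget follows, using the pybind path's constants (BM=BN=128, BK=16, A_PAD=0, B_PAD=8); the helper name smem_bytes is illustrative only.

    # Sketch of the dynamic shared-memory budget computed by the LAUNCH_* macros:
    # per pipeline stage the kernel keeps one BM x (BK + A_PAD) tile of A and one
    # BN x (BK + B_PAD) tile of B in half precision. smem_bytes is a hypothetical
    # helper name, not part of this repo.
    SIZEOF_HALF = 2

    def smem_bytes(stages: int, BM: int = 128, BN: int = 128, BK: int = 16,
                   A_PAD: int = 0, B_PAD: int = 8) -> int:
        a_tile = BM * (BK + A_PAD) * SIZEOF_HALF  # one A stage: 128*16*2 = 4 KiB
        b_tile = BN * (BK + B_PAD) * SIZEOF_HALF  # one B stage: 128*24*2 = 6 KiB
        return stages * (a_tile + b_tile)

    if __name__ == "__main__":
        LIMIT = 98304  # bytes requested via cudaFuncAttributeMaxDynamicSharedMemorySize
        for stages in (2, 3, 4, 5):
            total = smem_bytes(stages)
            assert total <= LIMIT  # all supported stage counts fit within the cap
            print(f"stages={stages}: {total} B ({total / 1024:.0f} KiB) of {LIMIT // 1024} KiB")
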