├── utils ├── python │ └── __init__.py └── cpp │ ├── cutlass │ ├── traits_base.cuh │ ├── compute.cuh │ ├── convert.cuh │ └── copy.cuh │ ├── cuda_info.cuh │ └── cuda_utils.cuh ├── benchs ├── python │ ├── gemm │ │ ├── cutlass │ │ │ ├── __init__.py │ │ │ ├── .gitignore │ │ │ ├── entry.py │ │ │ ├── gemm.py │ │ │ ├── compile.py │ │ │ └── cutlass_gemm.cuh │ │ ├── tiledcuda │ │ │ ├── __init__.py │ │ │ ├── .gitignore │ │ │ ├── gemm.py │ │ │ ├── entry.py │ │ │ └── compile.py │ │ ├── cuBLAS │ │ │ ├── .gitignore │ │ │ ├── README.md │ │ │ ├── src │ │ │ │ ├── Makefile │ │ │ │ ├── CMakeLists.txt │ │ │ │ ├── bind.cu │ │ │ │ └── cublas_gemm.cuh │ │ │ └── __init__.py │ │ ├── gemm_bench_NVIDIA_A100_80GB_PCIe.csv │ │ ├── triton │ │ │ ├── test.py │ │ │ └── gemm.py │ │ └── bench.py │ ├── batched_gemm │ │ └── cutlass │ │ │ ├── entry.py │ │ │ ├── batched_gemm.py │ │ │ ├── test.py │ │ │ ├── compile.py │ │ │ └── cutlass_batched_gemm.cuh │ ├── fused_gemm │ │ ├── cutlass │ │ │ ├── entry.py │ │ │ ├── fused_gemm.py │ │ │ ├── test.py │ │ │ ├── compile.py │ │ │ └── cutlass_fused_gemm.cuh │ │ └── tiledcuda │ │ │ ├── fused_gemm.py │ │ │ ├── entry.py │ │ │ └── compile.py │ └── lstm │ │ └── cutlass │ │ ├── entry.py │ │ ├── lstm.py │ │ ├── test.py │ │ ├── compile.py │ │ └── cutlass_lstm.cuh └── cpp │ ├── copy │ ├── Makefile │ ├── bench_NVIDIA_A100_80GB_PCIe_copy.tsv │ ├── CMakeLists.txt │ ├── bench.cu │ ├── cutlass │ │ └── cutlass_copy.cuh │ └── tiledcuda │ │ └── tiledcuda_copy.cuh │ └── gemm │ ├── Makefile │ ├── CMakeLists.txt │ ├── figures │ └── bench_NVIDIA_A100_80GB_PCIe_gemm.tsv │ ├── util.cuh │ ├── tiledcuda_gemm.cuh │ ├── cutlass_gemm.cuh │ └── bench.cu ├── .gitignore ├── .gitmodules ├── README.md ├── scripts ├── clang_format.hook └── cmake │ └── generic.cmake ├── .clang-format ├── LICENSE └── .pre-commit-config.yaml /utils/python/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /benchs/python/gemm/cutlass/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /benchs/python/gemm/tiledcuda/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /benchs/python/gemm/cuBLAS/.gitignore: -------------------------------------------------------------------------------- 1 | src/build 2 | -------------------------------------------------------------------------------- /benchs/python/gemm/cutlass/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | tmp 3 | -------------------------------------------------------------------------------- /benchs/python/gemm/tiledcuda/.gitignore: -------------------------------------------------------------------------------- 1 | tmp 2 | __pycache__ 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/* 2 | **/build/* 3 | **/__pycache__/* 4 | **/tmp/* 5 | .DS_Store 6 | **/.DS_Store 7 | -------------------------------------------------------------------------------- /benchs/python/gemm/cuBLAS/README.md: -------------------------------------------------------------------------------- 1 | # Run the test 2 | 3 | ```bash 4 | cd src & make 
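# this builds src/build/libcublas_gemm.so, which the Python wrapper in __init__.py loads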
5 | cd ../ 6 | python3 test.py 7 | ``` 8 | -------------------------------------------------------------------------------- /benchs/cpp/copy/Makefile: -------------------------------------------------------------------------------- 1 | BUILD_DIR := build 2 | 3 | .PHONY: build clean 4 | 5 | build: 6 | @mkdir -p $(BUILD_DIR) 7 | @cd $(BUILD_DIR) && cmake .. && make -j$(proc) 8 | 9 | clean: 10 | @rm -rf $(BUILD_DIR) 11 | -------------------------------------------------------------------------------- /benchs/cpp/gemm/Makefile: -------------------------------------------------------------------------------- 1 | BUILD_DIR := build 2 | 3 | .PHONY: build clean 4 | 5 | build: 6 | @mkdir -p $(BUILD_DIR) 7 | @cd $(BUILD_DIR) && cmake .. && make -j$(proc) 8 | 9 | clean: 10 | @rm -rf $(BUILD_DIR) 11 | -------------------------------------------------------------------------------- /benchs/python/gemm/cuBLAS/src/Makefile: -------------------------------------------------------------------------------- 1 | BUILD_DIR := build 2 | 3 | .PHONY: build clean 4 | 5 | build: 6 | @mkdir -p $(BUILD_DIR) 7 | @cd $(BUILD_DIR) && cmake .. && make -j$(proc) 8 | 9 | clean: 10 | @rm -rf $(BUILD_DIR) 11 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "3rd-party/TiledCUDA"] 2 | path = 3rd-party/TiledCUDA 3 | url = git@github.com:TiledTensor/TiledCUDA.git 4 | [submodule "3rd-party/cutlass"] 5 | path = 3rd-party/cutlass 6 | url = https://github.com/NVIDIA/cutlass 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # benchmarks 2 | Benchmark tests supporting the TiledCUDA library. 3 | 4 | ## Download 5 | 6 | ```bash 7 | git clone https://github.com/TiledTensor/benchmarks.git 8 | cd benchmarks && git submodule update --init --recursive 9 | ``` 10 | -------------------------------------------------------------------------------- /benchs/cpp/copy/bench_NVIDIA_A100_80GB_PCIe_copy.tsv: -------------------------------------------------------------------------------- 1 | Copy Type [M, N, K] [kTM, kTN, kTK] [kWarpPerRow, kWarpPerCol] CutlassTime(ms) TiledCUDATime(ms) Ratio 2 | Whole [4096, 4096, 2048] [64, 32, 32] [2, 2] 0.5805 1.0871 1.8729 3 | G2S [4096, 4096, 2048] [64, 32, 32] [2, 2] 0.5806 1.0877 1.8735 4 | -------------------------------------------------------------------------------- /scripts/clang_format.hook: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | readonly VERSION="12.0.0" 5 | 6 | version=$(clang-format -version) 7 | 8 | if ! [[ $version == *"$VERSION"* ]]; then 9 | echo "clang-format version check failed." 10 | echo "a version contains '$VERSION' is needed, but get '$version'" 11 | echo "you can install the right version, and make an soft-link to '\$PATH' env" 12 | exit -1 13 | fi 14 | 15 | clang-format $@ 16 | -------------------------------------------------------------------------------- /utils/cpp/cutlass/traits_base.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace benchmarks { 4 | namespace cutlass_wrapper { 5 | 6 | template 7 | struct TraitsBase { 8 | // the maximal width of vectorized access. 
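// For 16-bit half elements the constants below give kNumPerAccess = 128 / 16 = 8,
// i.e. a single 128-bit vectorized load/store moves 8 elements per thread.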
9 | static constexpr int kAccessInBits = 128; 10 | static constexpr int kElmentBits = cutlass::sizeof_bits::value; 11 | static constexpr int kNumPerAccess = kAccessInBits / kElmentBits; 12 | }; 13 | 14 | } // namespace cutlass_wrapper 15 | } // namespace benchmarks 16 | -------------------------------------------------------------------------------- /benchs/python/gemm/cutlass/entry.py: -------------------------------------------------------------------------------- 1 | entry = """#include "../cutlass_gemm.cuh" 2 | 3 | extern "C" int kernel_entry(const __half* dA, const __half* dB, __half* dC) {{ 4 | using DType = cutlass::half_t; 5 | 6 | auto* A = reinterpret_cast(dA); 7 | auto* B = reinterpret_cast(dB); 8 | auto* C = reinterpret_cast(dC); 9 | 10 | cute_gemm(A, B, C); 12 | return 0; 13 | }} 14 | """ 15 | -------------------------------------------------------------------------------- /benchs/python/batched_gemm/cutlass/entry.py: -------------------------------------------------------------------------------- 1 | entry = """#include "../cutlass_batched_gemm.cuh" 2 | 3 | extern "C" int kernel_entry(const __half* dA, const __half* dB, __half* dC) {{ 4 | using DType = cutlass::half_t; 5 | 6 | auto* A = reinterpret_cast(dA); 7 | auto* B = reinterpret_cast(dB); 8 | auto* C = reinterpret_cast(dC); 9 | 10 | cute_batched_gemm(A, B, C); 11 | return 0; 12 | }} 13 | """ 14 | -------------------------------------------------------------------------------- /benchs/python/gemm/cuBLAS/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | import os 4 | 5 | dyn_lib_path = os.path.join(os.path.dirname(__file__), "src/build/libcublas_gemm.so") 6 | 7 | torch.ops.load_library(dyn_lib_path) 8 | 9 | def cublas_gemm( 10 | m: int, 11 | n: int, 12 | k: int, 13 | a: Tensor, 14 | b: Tensor, 15 | c: Tensor, 16 | elapsed_time: Tensor, 17 | iters: int = 0, 18 | warmup: int = 0 19 | ): 20 | torch.ops.cublas_gemm.gemm(m, n, k, a, b, c, elapsed_time, iters, warmup) 21 | -------------------------------------------------------------------------------- /benchs/python/fused_gemm/cutlass/entry.py: -------------------------------------------------------------------------------- 1 | entry = """#include "../cutlass_fused_gemm.cuh" 2 | 3 | extern "C" int kernel_entry(const __half* dA, const __half* dB, const __half* dC, __half* dD) {{ 4 | using DType = cutlass::half_t; 5 | 6 | auto* A = reinterpret_cast(dA); 7 | auto* B = reinterpret_cast(dB); 8 | auto* C = reinterpret_cast(dC); 9 | auto* D = reinterpret_cast(dD); 10 | 11 | cute_fused_gemm(A, B, C, D); 13 | return 0; 14 | }} 15 | """ 16 | -------------------------------------------------------------------------------- /benchs/python/gemm/cuBLAS/src/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.25 FATAL_ERROR) 2 | project(cublas_gemm LANGUAGES CXX CUDA) 3 | 4 | set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} 5 | "${PROJECT_SOURCE_DIR}/../../../../../scripts/cmake") 6 | include(generic) 7 | 8 | set(THIRD_PARTY_DIR "${PROJECT_SOURCE_DIR}/../../../../../3rd-party") 9 | include_directories("${THIRD_PARTY_DIR}/cutlass/include") 10 | include_directories("${THIRD_PARTY_DIR}/TiledCUDA/include") 11 | 12 | cuda_add_library(cublas_gemm SHARED bind.cu) 13 | target_link_libraries(cublas_gemm ${TORCH_LIBRARIES} ${CUDA_CUBLAS_LIBRARIES}) 14 | -------------------------------------------------------------------------------- 
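The cuBLAS wrapper above exposes the compiled kernel through torch.ops. A minimal call sketch follows; it assumes src/build/libcublas_gemm.so has already been built, that the script runs from benchs/python/gemm so the cuBLAS package resolves, and that B is provided as an [N, K] half tensor (column-major [K, N], as documented in cublas_gemm.cuh). The elapsed-time tensor stays on the CPU because bind.cu writes the averaged time into it from host code.

import torch
from cuBLAS import cublas_gemm  # wrapper defined in __init__.py above

M, N, K = 4096, 4096, 2048
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(N, K, device="cuda", dtype=torch.float16)  # K-contiguous, i.e. column-major [K, N]
c = torch.zeros(M, N, device="cuda", dtype=torch.float16)
elapsed = torch.zeros(1, dtype=torch.float32)  # host tensor; receives the mean time in ms

cublas_gemm(M, N, K, a, b, c, elapsed, iters=20, warmup=5)
print(f"cuBLAS average latency: {elapsed.item():.4f} ms")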
/benchs/cpp/gemm/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.25 FATAL_ERROR) 2 | project(gemm_bench LANGUAGES C CXX CUDA) 3 | 4 | set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} 5 | "${PROJECT_SOURCE_DIR}/../../../scripts/cmake") 6 | set(THIRD_PARTY_DIR "${PROJECT_SOURCE_DIR}/../../../3rd-party") 7 | set(ROOT_DIR "${PROJECT_SOURCE_DIR}/../../../") 8 | include(generic) 9 | 10 | include_directories("${THIRD_PARTY_DIR}/cutlass/include") 11 | include_directories("${THIRD_PARTY_DIR}/TiledCUDA/include") 12 | include_directories("${ROOT_DIR}") 13 | 14 | add_executable(bench_gemm bench.cu) 15 | target_link_libraries(bench_gemm ${CUDA_CUBLAS_LIBRARIES}) 16 | -------------------------------------------------------------------------------- /.clang-format: -------------------------------------------------------------------------------- 1 | # Run manually to reformat a file: 2 | # clang-format -i --style=file 3 | BasedOnStyle: Google 4 | ColumnLimit: 80 5 | IndentWidth: 4 6 | AccessModifierOffset: -2 7 | DerivePointerAlignment: false 8 | KeepEmptyLinesAtTheStartOfBlocks: false 9 | SortIncludes: true 10 | IncludeBlocks: Regroup 11 | IncludeCategories: 12 | - Regex: '<([A-Za-z0-9\Q/-_\E])+>' 13 | Priority: 4 14 | - Regex: '<(catch2|boost)\/' 15 | Priority: 3 16 | - Regex: '<([A-Za-z0-9.\Q/-_\E])+>' 17 | Priority: 2 18 | - Regex: '"([A-Za-z0-9.\Q/-_\E])+"' 19 | Priority: 1 20 | 21 | AllowShortLoopsOnASingleLine: true 22 | AllowShortIfStatementsOnASingleLine: true 23 | Cpp11BracedListStyle: true 24 | -------------------------------------------------------------------------------- /benchs/cpp/copy/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.25 FATAL_ERROR) 2 | project(gemm_bench LANGUAGES C CXX CUDA) 3 | 4 | set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} 5 | "${PROJECT_SOURCE_DIR}/../../../scripts/cmake") 6 | include(generic) 7 | 8 | set(THIRD_PARTY_DIR "${PROJECT_SOURCE_DIR}/../../../3rd-party") 9 | set(ROOT_DIR "${PROJECT_SOURCE_DIR}/../../../") 10 | 11 | include_directories("${THIRD_PARTY_DIR}/cutlass/include") 12 | include_directories("${THIRD_PARTY_DIR}/TiledCUDA/include") 13 | 14 | include_directories("${ROOT_DIR}/") 15 | include_directories("${PROJECT_SOURCE_DIR}/cutlass") 16 | 17 | add_executable(bench_copy bench.cu) 18 | target_link_libraries(bench_copy ${CUDA_CUBLAS_LIBRARIES}) 19 | -------------------------------------------------------------------------------- /utils/cpp/cuda_info.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | 8 | namespace benchmarks { 9 | 10 | // Returns the name of the device. 
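// Spaces in the device name are replaced with underscores so the result can be
// embedded in benchmark file names, e.g. "NVIDIA A100 80GB PCIe" becomes
// "NVIDIA_A100_80GB_PCIe".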
11 | std::string get_device_name() { 12 | cudaDeviceProp prop; 13 | cudaGetDeviceProperties(&prop, 0); 14 | 15 | std::stringstream ss(prop.name); 16 | const char delim = ' '; 17 | 18 | std::string s; 19 | std::vector out; 20 | 21 | while (std::getline(ss, s, delim)) { 22 | out.push_back(s); 23 | } 24 | 25 | std::stringstream out_ss; 26 | int i = 0; 27 | for (; i < static_cast(out.size()) - 1; ++i) out_ss << out[i] << "_"; 28 | out_ss << out[i]; 29 | return out_ss.str(); 30 | } 31 | } // namespace benchmarks 32 | -------------------------------------------------------------------------------- /utils/cpp/cutlass/compute.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace benchmarks { 8 | namespace cutlass_wrapper { 9 | 10 | template 11 | DEVICE void cute_tanh(cute::Tensor& tensor) { 12 | #pragma unroll 13 | for (int i = 0; i < size(tensor); ++i) { 14 | tensor(i) = tanh(tensor(i)); 15 | } 16 | } 17 | 18 | template 19 | DEVICE void cute_sigmoid(cute::Tensor& tensor) { 20 | #pragma unroll 21 | for (int i = 0; i < size(tensor); ++i) { 22 | tensor(i) = 1.0 / (1.0 + exp(-tensor(i))); 23 | } 24 | } 25 | } // namespace cutlass_wrapper 26 | } // namespace benchmarks 27 | -------------------------------------------------------------------------------- /benchs/cpp/gemm/figures/bench_NVIDIA_A100_80GB_PCIe_gemm.tsv: -------------------------------------------------------------------------------- 1 | [M, N, K] [kTM, kTN, kTK] kRK Warp Layout cuBLAS(ms) cutlass(ms) TiledCUDA(ms) 2 | [4096, 4096, 4096] [128, 128, 64] 16 [2, 4] 0.5164 0.6647 (1.29) 1.0760 (2.08) 3 | [4096, 4096, 4096] [128, 128, 64] 32 [2, 4] 0.5167 0.6715 (1.30) 1.0930 (2.12) 4 | [4096, 4096, 4096] [128, 128, 64] 64 [2, 4] 0.5166 0.6647 (1.29) 1.6878 (3.27) 5 | [4096, 4096, 4096] [64, 128, 128] 32 [2, 4] 0.5174 0.8806 (1.70) 1.4246 (2.75) 6 | [4096, 4096, 4096] [128, 64, 64] 32 [2, 4] 0.5167 0.8623 (1.67) 1.4854 (2.87) 7 | [4096, 4096, 4096] [128, 128, 64] 32 [2, 2] 0.5171 0.6435 (1.24) 1.1199 (2.17) 8 | [4096, 4096, 4096] [64, 128, 128] 32 [2, 2] 0.5173 0.8126 (1.57) 1.4209 (2.75) 9 | [4096, 4096, 4096] [64, 64, 64] 32 [2, 2] 0.5174 0.9200 (1.78) 1.8451 (3.57) 10 | -------------------------------------------------------------------------------- /benchs/python/gemm/cuBLAS/src/bind.cu: -------------------------------------------------------------------------------- 1 | #include "cublas_gemm.cuh" 2 | 3 | #include 4 | 5 | void gemm_op(int64_t m, int64_t n, int64_t k, const torch::Tensor& A, 6 | const torch::Tensor& B, torch::Tensor& C, torch::Tensor& time, 7 | int64_t iters = 20, int64_t warm_up = 5) { 8 | using namespace benchmarks; 9 | using DType = __half; 10 | 11 | auto* dA = reinterpret_cast(A.data_ptr()); 12 | auto* dB = reinterpret_cast(B.data_ptr()); 13 | auto* dC = reinterpret_cast(C.data_ptr()); 14 | auto* time_data = reinterpret_cast(time.data_ptr()); 15 | 16 | cublas_hgemm(m, n, k, dA, dB, dC, time_data, iters, warm_up); 17 | } 18 | 19 | TORCH_LIBRARY(cublas_gemm, t) { t.def("gemm", &gemm_op); }; 20 | -------------------------------------------------------------------------------- /benchs/python/lstm/cutlass/entry.py: -------------------------------------------------------------------------------- 1 | entry = """#include "../cutlass_lstm.cuh" 2 | 3 | extern "C" int kernel_entry(const __half* dW, const __half* dX, const __half* dU, 4 | const __half* dC, const __half* dH, __half* dCO, __half* dHO) {{ 5 | using DType = 
cutlass::half_t; 6 | 7 | auto* W = reinterpret_cast(dW); 8 | auto* X = reinterpret_cast(dX); 9 | auto* U = reinterpret_cast(dU); 10 | auto* C = reinterpret_cast(dC); 11 | auto* H = reinterpret_cast(dH); 12 | auto* CO = reinterpret_cast(dCO); 13 | auto* HO = reinterpret_cast(dHO); 14 | 15 | cute_lstm_cell(W, X, U, C, H, CO, HO); 17 | return 0; 18 | }} 19 | """ 20 | -------------------------------------------------------------------------------- /benchs/python/gemm/gemm_bench_NVIDIA_A100_80GB_PCIe.csv: -------------------------------------------------------------------------------- 1 | M,N,K,kTM,kTN,kTK,cuBLAS(ms),Cutlass(ms),TiledCUDA(ms) 2 | 4096,4096,2048,128,256,64,0.2777,1.5857,nan 3 | 4096,4096,2048,64,256,32,0.2787,1.5351,1.5729 4 | 4096,4096,2048,128,128,32,0.2790,1.5618,1.5439 5 | 4096,4096,2048,128,64,32,0.2777,1.5506,1.5677 6 | 4096,4096,2048,64,128,32,0.2781,1.5840,1.5887 7 | 4096,4096,2048,128,32,32,0.2778,1.5643,1.5810 8 | 4096,4096,2048,32,64,32,0.2794,3.9620,1.5433 9 | 4096,4096,2048,128,256,128,0.2805,1.5647,nan 10 | 4096,4096,2048,256,128,128,0.2790,1.5799,2.5801 11 | 4096,4096,2048,64,256,128,0.2780,1.5490,1.5441 12 | 4096,4096,2048,128,128,128,0.2791,1.5612,1.5834 13 | 4096,4096,2048,128,64,64,0.2791,1.5799,1.5626 14 | 4096,4096,2048,64,128,64,0.2779,1.5899,1.5814 15 | 4096,4096,32,64,32,32,0.0297,1.5930,1.5719 16 | 4096,4096,64,128,128,64,0.0372,1.6003,1.5975 17 | -------------------------------------------------------------------------------- /benchs/python/gemm/cutlass/gemm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | import os 5 | 6 | from .compile import Compile 7 | 8 | __all__ = [ 9 | "gemm_func", 10 | ] 11 | 12 | 13 | class GemmFunc(torch.autograd.Function): 14 | 15 | @staticmethod 16 | def forward( 17 | ctx, 18 | A: Tensor, 19 | B: Tensor, 20 | C: Tensor, 21 | M: int, 22 | N: int, 23 | K: int, 24 | kTM: int, 25 | kTN: int, 26 | kTK: int, 27 | warp_per_row: int, 28 | warp_per_col: int, 29 | ) -> Tensor: 30 | tmp_dir = os.path.join(os.path.dirname(__file__), 'tmp') 31 | builder = Compile(file_prefix="gemm", tmp_dir=tmp_dir) 32 | lib_name = builder.compile(M, N, K, kTM, kTN, kTK, warp_per_row, 33 | warp_per_col) 34 | 35 | if lib_name is None: 36 | raise RuntimeError("Failed to compile the library.") 37 | 38 | builder.apply(lib_name, [A, B, C], device=0) 39 | return C 40 | 41 | 42 | gemm_func = GemmFunc.apply 43 | -------------------------------------------------------------------------------- /benchs/python/batched_gemm/cutlass/batched_gemm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | from compile import Compile 5 | 6 | __all__ = [ 7 | "batched_gemm_func", 8 | ] 9 | 10 | 11 | class BatchedGemmFunc(torch.autograd.Function): 12 | 13 | @staticmethod 14 | def forward( 15 | ctx, 16 | A: Tensor, 17 | B: Tensor, 18 | C: Tensor, 19 | M: int, 20 | N: int, 21 | K: int, 22 | BatchCount: int, 23 | kTM: int, 24 | kTN: int, 25 | kTK: int, 26 | warp_per_row: int, 27 | warp_per_col: int, 28 | ) -> Tensor: 29 | builder = Compile(file_prefix="batched_gemm", tmp_dir="tmp") 30 | lib_name = builder.compile(M, N, K, BatchCount, kTM, kTN, kTK, warp_per_row, 31 | warp_per_col) 32 | 33 | if lib_name is None: 34 | raise RuntimeError("Failed to compile the library.") 35 | 36 | builder.apply(lib_name, [A, B, C], device=0) 37 | return C 38 | 39 | 40 | batched_gemm_func = BatchedGemmFunc.apply 41 | 
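Both autograd wrappers above follow the same pattern: forward() JIT-compiles a kernel specialized for the given problem and tile sizes, caches the resulting shared library under a tmp/ directory, and launches it on the supplied tensors. A minimal call sketch for the plain GEMM wrapper, assuming it is imported as the cutlass package from benchs/python/gemm and that, as in the other tests in this repository, B is stored as an [N, K] tensor:

import torch
from cutlass.gemm import gemm_func  # the GemmFunc wrapper defined above

M, N, K = 4096, 4096, 2048
kTM, kTN, kTK = 128, 128, 32       # CTA tile shape
warp_per_row, warp_per_col = 2, 2  # 2 x 2 warps per CTA

a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(N, K, device="cuda", dtype=torch.float16)
c = torch.zeros(M, N, device="cuda", dtype=torch.float16)

gemm_func(a, b, c, M, N, K, kTM, kTN, kTK, warp_per_row, warp_per_col)
avg_diff = (torch.abs(a @ b.t() - c).sum() / (M * N)).item()  # same check the other tests use
print(f"average abs diff vs. torch.mm: {avg_diff:.4e}")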
-------------------------------------------------------------------------------- /benchs/python/lstm/cutlass/lstm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | from compile import Compile 5 | 6 | __all__ = [ 7 | "lstm_func", 8 | ] 9 | 10 | 11 | class LstmFunc(torch.autograd.Function): 12 | 13 | @staticmethod 14 | def forward( 15 | ctx, 16 | W: Tensor, 17 | X: Tensor, 18 | U: Tensor, 19 | C: Tensor, 20 | H: Tensor, 21 | CO: Tensor, 22 | HO: Tensor, 23 | M: int, 24 | N: int, 25 | K: int, 26 | kTM: int, 27 | kTN: int, 28 | kTK: int, 29 | warp_per_row: int, 30 | warp_per_col: int, 31 | ) -> Tensor: 32 | builder = Compile(file_prefix="lstm", tmp_dir="tmp") 33 | lib_name = builder.compile(M, N, K, kTM, kTN, kTK, warp_per_row, 34 | warp_per_col) 35 | 36 | if lib_name is None: 37 | raise RuntimeError("Failed to compile the library.") 38 | 39 | builder.apply(lib_name, [W, X, U, C, H, CO, HO], device=0) 40 | return C 41 | 42 | 43 | lstm_func = LstmFunc.apply 44 | -------------------------------------------------------------------------------- /benchs/python/gemm/tiledcuda/gemm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | import os 5 | 6 | from .compile import Compile 7 | 8 | __all__ = [ 9 | "gemm_func", 10 | ] 11 | 12 | 13 | class GemmFunc(torch.autograd.Function): 14 | 15 | @staticmethod 16 | def forward( 17 | ctx, 18 | A: Tensor, 19 | B: Tensor, 20 | C: Tensor, 21 | M: int, 22 | N: int, 23 | K: int, 24 | kTM: int, 25 | kTN: int, 26 | kTK: int, 27 | kRK: int, 28 | warp_per_row: int, 29 | warp_per_col: int, 30 | ) -> Tensor: 31 | tmp_dir = os.path.join(os.path.dirname(__file__), 'tmp') 32 | builder = Compile(file_prefix="gemm", tmp_dir=tmp_dir) 33 | lib_name = builder.compile(M, N, K, kTM, kTN, kTK, kRK, warp_per_row, 34 | warp_per_col) 35 | 36 | if lib_name is None: 37 | raise RuntimeError("Failed to compile the library.") 38 | 39 | builder.apply(lib_name, [A, B, C], device=0) 40 | return C 41 | 42 | 43 | gemm_func = GemmFunc.apply 44 | -------------------------------------------------------------------------------- /benchs/python/fused_gemm/cutlass/fused_gemm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | from compile import Compile 5 | 6 | __all__ = [ 7 | "fused_gemm_func", 8 | ] 9 | 10 | 11 | class FusedGemmFunc(torch.autograd.Function): 12 | 13 | @staticmethod 14 | def forward( 15 | ctx, 16 | A: Tensor, 17 | B: Tensor, 18 | C: Tensor, 19 | D: Tensor, 20 | M: int, 21 | N: int, 22 | K: int, 23 | P: int, 24 | kTM: int, 25 | kTN: int, 26 | kTK: int, 27 | kTP: int, 28 | warp_per_row: int, 29 | warp_per_col: int, 30 | ) -> Tensor: 31 | builder = Compile(file_prefix="fused_gemm", tmp_dir="tmp") 32 | lib_name = builder.compile(M, N, K, P, kTM, kTN, kTK, kTP, warp_per_row, 33 | warp_per_col) 34 | 35 | if lib_name is None: 36 | raise RuntimeError("Failed to compile the library.") 37 | 38 | 39 | builder.apply(lib_name, [A, B, C, D], device=0) 40 | return D 41 | 42 | 43 | fused_gemm_func = FusedGemmFunc.apply 44 | -------------------------------------------------------------------------------- /benchs/python/fused_gemm/tiledcuda/fused_gemm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | from compile import Compile 5 | 6 | __all__ = [ 7 | 
"fused_gemm_func", 8 | ] 9 | 10 | 11 | class FusedGemmFunc(torch.autograd.Function): 12 | 13 | @staticmethod 14 | def forward( 15 | ctx, 16 | A: Tensor, 17 | B: Tensor, 18 | C: Tensor, 19 | D: Tensor, 20 | M: int, 21 | N: int, 22 | K: int, 23 | P: int, 24 | kTM: int, 25 | kTN: int, 26 | kTK: int, 27 | kTP: int, 28 | kRK: int, 29 | warp_per_row: int, 30 | warp_per_col: int, 31 | ) -> Tensor: 32 | builder = Compile(file_prefix="fused_gemm", tmp_dir="tmp") 33 | lib_name = builder.compile(M, N, K, P, kTM, kTN, kTK, kTP, kRK, warp_per_row, 34 | warp_per_col) 35 | 36 | if lib_name is None: 37 | raise RuntimeError("Failed to compile the library.") 38 | 39 | builder.apply(lib_name, [A, B, C, D], device=0) 40 | return D 41 | 42 | 43 | fused_gemm_func = FusedGemmFunc.apply 44 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 TiledTensor 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/Lucas-C/pre-commit-hooks.git 3 | rev: v1.5.5 4 | hooks: 5 | - id: remove-crlf 6 | files: (?!.*third_party)^.*$ | (?!.*book)^.*$ 7 | - repo: https://github.com/pre-commit/mirrors-yapf.git 8 | rev: v0.32.0 9 | hooks: 10 | - id: yapf 11 | files: (.*\.(py|bzl)|BUILD|.*\.BUILD|WORKSPACE)$ 12 | - repo: https://github.com/pre-commit/pre-commit-hooks 13 | rev: v4.6.0 14 | hooks: 15 | - id: check-added-large-files 16 | - id: check-merge-conflict 17 | - id: check-symlinks 18 | - id: detect-private-key 19 | files: (?!.*third_party)^.*$ | (?!.*book)^.*$ 20 | - id: end-of-file-fixer 21 | - id: check-yaml 22 | - id: check-toml 23 | - id: check-ast 24 | - id: check-executables-have-shebangs 25 | - id: check-shebang-scripts-are-executable 26 | - id: detect-private-key 27 | - id: debug-statements 28 | - repo: local 29 | hooks: 30 | - id: clang-format-with-version-check 31 | name: clang-format 32 | description: Format files with ClangFormat. 
33 | entry: bash ./scripts/clang_format.hook -i 34 | language: system 35 | files: \.(c|cc|cxx|cpp|cu|h|cuh|hpp|hxx|proto)$ 36 | - repo: https://github.com/iconmaster5326/cmake-format-pre-commit-hook 37 | rev: v0.6.9 38 | hooks: 39 | - id: cmake-format 40 | -------------------------------------------------------------------------------- /benchs/python/gemm/triton/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | 4 | os.environ["TRITON_INTERPRET"] = '1' 5 | 6 | import triton 7 | 8 | import gemm 9 | 10 | def run_unittest( 11 | M: int, 12 | N: int, 13 | K: int, 14 | debug_print=False, 15 | epsilon: float = 5e-2 16 | ): 17 | torch.manual_seed(1234) 18 | a = torch.randn(M, K, device = 'cuda', dtype = torch.float16) 19 | b = torch.randn(K, N, device = 'cuda', dtype = torch.float16) 20 | 21 | triton_c = gemm.gemm(a, b) 22 | torch_c = torch.mm(a, b) 23 | 24 | if debug_print: 25 | print("Result:") 26 | print(triton_c) 27 | 28 | print("\nReference:") 29 | print(torch_c) 30 | 31 | avg_diff = (torch.sum(torch.abs(triton_c.half() - torch_c) / (M * N))).item() 32 | 33 | if avg_diff > epsilon: 34 | return False 35 | else: 36 | return True 37 | 38 | def bench( 39 | M: int, 40 | N: int, 41 | K: int 42 | ): 43 | torch.manual_seed(1234) 44 | 45 | a = torch.randn(M, K, device = 'cuda', dtype=torch.float16) 46 | b = torch.randn(K, N, device = 'cuda', dtype=torch.float16) 47 | 48 | warmup = 5 49 | iters = 20 50 | 51 | ms = triton.testing.do_bench(lambda: gemm.gemm(a, b), warmup=warmup, rep=iters) 52 | 53 | return ms 54 | 55 | 56 | if __name__ == '__main__': 57 | M = 4096 58 | N = 4096 59 | K = 2048 60 | 61 | if run_unittest(M, N, K, True): 62 | print("Unittest passed") 63 | else: 64 | print("Unittest failed") 65 | 66 | time = bench(M, N, K) 67 | print("Elapsed time: {:.4f} ms".format(time)) 68 | -------------------------------------------------------------------------------- /benchs/python/gemm/cuBLAS/src/cublas_gemm.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "util/cuda_timer.hpp" 3 | 4 | #include 5 | 6 | namespace benchmarks { 7 | // In this implementation, A and C is laid out in row-major, B, is laid out 8 | // in column-major: C[m, n] = A[m, k] @ B[k, n] 9 | void cublas_hgemm(int64_t kM, int64_t kN, int64_t kK, // problem shape 10 | const __half* A, const __half* B, __half* C, float* time, 11 | int64_t iters = 20, int64_t warm_up = 5) { 12 | cublasHandle_t handle; 13 | cublasCreate(&handle); 14 | 15 | __half alf = static_cast<__half>(1.); 16 | __half bet = static_cast<__half>(0.); 17 | 18 | if (iters) { // measure time 19 | for (int i = 0; i < warm_up; ++i) { 20 | cublasHgemm(handle, CUBLAS_OP_T /* transb*/, CUBLAS_OP_N, kN, kM, 21 | kK, &alf, B, kK, A, kK, &bet, C, kN); 22 | } 23 | cudaDeviceSynchronize(); 24 | 25 | tiledcuda::CudaTimer timer; 26 | timer.start(); 27 | for (int i = 0; i < iters; ++i) { 28 | cublasHgemm(handle, CUBLAS_OP_T /* transb*/, CUBLAS_OP_N, kN, kM, 29 | kK, &alf, B, kK, A, kK, &bet, C, kN); 30 | } 31 | cudaDeviceSynchronize(); 32 | time[0] = timer.stop() / iters; 33 | } else { 34 | // C = A @ B, but in cuBLAS, matrix is by default laid out in 35 | // column-major, therefore we compute: 36 | // C^T = B^T @ A^T [n, m] = [n, k] @ [k, m] 37 | cublasHgemm(handle, CUBLAS_OP_T /* transb*/, CUBLAS_OP_N, kN, kM, kK, 38 | &alf, B, kK, A, kK, &bet, C, kN); 39 | cudaDeviceSynchronize(); 40 | } 41 | 42 | cublasDestroy(handle); 43 | } 44 | } // namespace 
benchmarks 45 | -------------------------------------------------------------------------------- /scripts/cmake/generic.cmake: -------------------------------------------------------------------------------- 1 | set(CMAKE_BUILD_TYPE Release) 2 | 3 | set(CMAKE_CXX_STANDARD 4 | 20 5 | CACHE STRING "The C++ standard whoese features are requested." FORCE) 6 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 7 | 8 | set(CMAKE_CUDA_STANDARD 9 | 20 10 | CACHE STRING "The CUDA standard whose features are requested." FORCE) 11 | set(CMAKE_CUDA_STANDARD_REQUIRED ON) 12 | 13 | # Set host compiler flags. Enable all warnings and treat them as errors 14 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wall") 15 | 16 | find_package(CUDAToolkit QUIET REQUIRED) 17 | enable_language(CUDA) 18 | set(CMAKE_CUDA on) 19 | 20 | find_package(Python3 REQUIRED COMPONENTS Interpreter) 21 | 22 | set(TORCH_LIB_PREFIX "${Python3_SITEARCH}/torch") 23 | if(NOT EXISTS ${TORCH_LIB_PREFIX}) 24 | message(FATAL_ERROR "Torch library is not installed.") 25 | else() 26 | list(APPEND CMAKE_PREFIX_PATH "${TORCH_LIB_PREFIX}/share/cmake/Torch") 27 | endif() 28 | find_package(Torch REQUIRED) 29 | 30 | # let cmake automatically detect the current CUDA architecture to avoid 31 | # generating device codes for all possible architectures 32 | set(CMAKE_CUDA_ARCHITECTURES OFF) 33 | set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --Werror all-warnings") 34 | # Set the CUDA_PROPAGATE_HOST_FLAGS to OFF to avoid passing host compiler flags 35 | # to the device compiler 36 | set(CUDA_PROPAGATE_HOST_FLAGS OFF) 37 | 38 | # FIXME(haruhi): -std=c++20 has to be set explicitly here, Otherwise, linking 39 | # against torchlibs will raise errors. it seems that the host compilation 40 | # options are not passed to torchlibs. 41 | set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} -std=c++20) 42 | set(CUDA_NVCC_FLAGS_DEBUG ${CUDA_NVCC_FLAGS_DEBUG} -std=c++20 -O0) 43 | set(CUDA_NVCC_FLAGS_RELEASE ${CUDA_NVCC_FLAGS_RELEASE} -std=c++20 -O3) 44 | 45 | message(STATUS "CUDA detected: " ${CUDA_VERSION}) 46 | message(STATUS "CUDA nvcc is: " ${CUDA_NVCC_EXECUTABLE}) 47 | message(STATUS "CUDA toolkit directory: " ${CUDA_TOOLKIT_ROOT_DIR}) 48 | 49 | if(ENABLE_DEBUG) 50 | message(STATUS "TiledCUDA: Debug mode enabled") 51 | set(CMAKE_BUILD_TYPE Debug) 52 | set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -DDEBUG") 53 | set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -DDEBUG") 54 | endif() 55 | -------------------------------------------------------------------------------- /utils/cpp/cutlass/convert.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "cuda_utils.hpp" 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | namespace benchmarks { 9 | namespace cutlass_wrapper { 10 | 11 | using namespace cute; 12 | 13 | namespace { 14 | template 15 | DEVICE auto convert_type(cute::Tensor const& tensor) { 16 | using From_type = typename Engine::value_type; 17 | constexpr int numel = decltype(size(tensor))::value; 18 | cutlass::NumericArrayConverter convert_op; 19 | // HACK: this requires tensor to be "contiguous" 20 | auto frag = 21 | convert_op(*reinterpret_cast*>( 22 | tensor.data())); 23 | 24 | return make_tensor(make_rmem_ptr(&frag), tensor.layout()); 25 | } 26 | 27 | template 28 | struct IndexedTensor_ { 29 | DEVICE IndexedTensor_(Tensor& tensor) : tensor_(tensor) {} 30 | 31 | DEVICE const auto operator[](int idx) { return tensor_(_, _, idx); } 32 | 33 | private: 34 | Tensor& tensor_; 35 | }; 36 | } // namespace 37 | 38 | // 
Convert acc_layout from (MMA=4, MMA_M, MMA_N) to 39 | // ((4, 2), MMA_M, MMA_N / 2) if using m16n8k16, or to (4, MMA_M, MMA_N) if 40 | // using m16n8k8. 41 | template 42 | DEVICE auto convert_layout(const Tensor& acc) { 43 | auto acc_layout = acc.layout(); 44 | 45 | using X = Underscore; 46 | static_assert(decltype(size<0>(acc_layout))::value == 4); 47 | static_assert(decltype(cute::rank(acc_layout))::value == 3); 48 | 49 | constexpr int mma_shape_K = cute::get<2>(typename MMA::Shape_MNK{}); 50 | static_assert(mma_shape_K == 8 || mma_shape_K == 16); 51 | 52 | if constexpr (mma_shape_K == 8) { 53 | IndexedTensor_ indexed_tensor(acc); 54 | return indexed_tensor; 55 | } else { 56 | // (4, MMA_M, (2, MMA_N / 2))) 57 | auto l = cute::logical_divide(acc_layout, Shape{}); 58 | auto new_layout = make_layout(make_layout(get<0>(l), get<2, 0>(l)), 59 | get<1>(l), get<2, 1>(l)); 60 | auto new_tensor = make_tensor(acc.data(), new_layout); 61 | 62 | IndexedTensor_ indexed_tensor(new_tensor); 63 | return indexed_tensor; 64 | } 65 | }; 66 | } // namespace cutlass_wrapper 67 | } // namespace benchmarks 68 | -------------------------------------------------------------------------------- /benchs/python/gemm/tiledcuda/entry.py: -------------------------------------------------------------------------------- 1 | config = """#include "gemm.hpp" 2 | 3 | static constexpr int kWarpPerRow = {kWarpPerRow}; 4 | static constexpr int kWarpPerCol = {kWarpPerCol}; 5 | 6 | static constexpr int kM = {kM}; 7 | static constexpr int kN = {kN}; 8 | static constexpr int kK = {kK}; 9 | 10 | static constexpr int kTM = {kTM}; 11 | static constexpr int kTN = {kTN}; 12 | static constexpr int kTK = {kTK}; 13 | 14 | static constexpr int kRK = {kRK}; 15 | """ 16 | 17 | kernel_entry = """ 18 | extern "C" int kernel_entry(const __half* A, const __half* B, float* C) { 19 | using InType = __half; 20 | using AccType = float; 21 | 22 | using WholeShape = GemmShape; 23 | using CtaTileShape = GemmShape; 24 | using WarpLayout = tl::RowMajor; 25 | 26 | using Config = KeGemmTraits; 28 | 29 | auto kernel = 30 | &gemm; 40 | 41 | static constexpr int smem_size_inputs = kTK * (kTN + kTM) * sizeof(InType); 42 | static constexpr int smem_size_accumulators = kTM * kTN * sizeof(AccType); 43 | static constexpr int smem_size = smem_size_inputs > smem_size_accumulators 44 | ? 
smem_size_inputs 45 | : smem_size_accumulators; 46 | 47 | const int kMaxSmemPerBlock = 48 * 1024; 48 | if (smem_size > kMaxSmemPerBlock) { 49 | cudaFuncSetAttribute( 50 | kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); 51 | } 52 | 53 | int block_x = CeilDiv; 54 | int block_y = CeilDiv; 55 | 56 | dim3 dim_grid(block_x, block_y, 1); 57 | dim3 dim_block(Config::kThreads, 1, 1); 58 | 59 | kernel<<>>(A, B, C); 60 | 61 | return 0; 62 | } 63 | """ 64 | -------------------------------------------------------------------------------- /benchs/python/batched_gemm/cutlass/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from typing import Tuple 4 | 5 | from batched_gemm import batched_gemm_func as cutlass_batched_gemm 6 | 7 | 8 | def run_unittest(a: Tensor, 9 | b: Tensor, 10 | c: Tensor, 11 | M: int, 12 | N: int, 13 | K: int, 14 | BatchCount: int, 15 | kTM: int, 16 | kTN: int, 17 | kTK: int, 18 | warp_layout: Tuple, 19 | debug_print=False, 20 | epsilon: float = 5e-2): 21 | cutlass_batched_gemm(a, b, c, M, N, K, BatchCount, kTM, kTN, kTK, *warp_layout) 22 | ref_c = torch.bmm(a.view(BatchCount, M, K), b.transpose(1, 2).view(BatchCount, K, N)) 23 | 24 | if debug_print: 25 | print("Result:") 26 | print(c) 27 | 28 | print("\nReference:") 29 | print(ref_c) 30 | 31 | avg_diff = (torch.sum(torch.abs(ref_c - c) / (M * N))).item() 32 | 33 | if avg_diff > epsilon: 34 | return False 35 | else: 36 | return True 37 | 38 | 39 | def run_test( 40 | M: int, 41 | N: int, 42 | K: int, 43 | BatchCount: int, 44 | kTM: int, 45 | kTN: int, 46 | kTK: int, 47 | warp_layout: Tuple, 48 | ): 49 | device = torch.device("cuda") 50 | dtype = torch.float16 51 | 52 | torch.manual_seed(1234) 53 | 54 | a = torch.randn(BatchCount, M, K, device=device, dtype=dtype) 55 | b = torch.randn(BatchCount, N, K, device=device, dtype=dtype) 56 | c = torch.zeros(BatchCount, M, N, device=device, dtype=dtype) 57 | 58 | if run_unittest(a, b, c, M, N, K, BatchCount, kTM, kTN, kTK, warp_layout, debug_print=True): 59 | print("Unittest passed") 60 | else: 61 | raise ValueError("Unittest failed") 62 | 63 | start_event = torch.cuda.Event(enable_timing=True) 64 | end_event = torch.cuda.Event(enable_timing=True) 65 | 66 | iters = 50 67 | start_event.record() 68 | for _ in range(iters): 69 | cutlass_batched_gemm(a, b, c, M, N, K, BatchCount, kTM, kTN, kTK, *warp_layout) 70 | end_event.record() 71 | torch.cuda.synchronize() 72 | 73 | time = start_event.elapsed_time(end_event) / iters 74 | 75 | return time 76 | 77 | 78 | if __name__ == "__main__": 79 | kM = 256 80 | kN = 256 81 | kK = 256 82 | BatchCount = 10 83 | 84 | kTM = 32 85 | kTN = 32 86 | kTK = 32 87 | 88 | time = run_test(kM, kN, kK, BatchCount, kTM, kTN, kTK, (2, 2)) 89 | 90 | print("Elapsed time: {:.4f} ms".format(time)) 91 | -------------------------------------------------------------------------------- /utils/cpp/cuda_utils.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | namespace benchmarks { 10 | 11 | template 12 | inline constexpr int CeilDiv = (a + b - 1) / b; // for compile-time values 13 | 14 | #if defined(__CUDA_ARCH__) 15 | #define HOST_DEVICE __forceinline__ __host__ __device__ 16 | #define DEVICE __forceinline__ __device__ 17 | #define HOST __forceinline__ __host__ 18 | #else 19 | #define HOST_DEVICE inline 20 | #define DEVICE inline 21 | #define HOST 
inline 22 | #endif 23 | 24 | #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) 25 | #define CP_ASYNC_SM80_ENABLED 26 | #endif 27 | 28 | const char* cublasGetErrorString(cublasStatus_t status) { 29 | switch (status) { 30 | case CUBLAS_STATUS_SUCCESS: 31 | return "CUBLAS_STATUS_SUCCESS"; 32 | case CUBLAS_STATUS_NOT_INITIALIZED: 33 | return "CUBLAS_STATUS_NOT_INITIALIZED"; 34 | case CUBLAS_STATUS_ALLOC_FAILED: 35 | return "CUBLAS_STATUS_ALLOC_FAILED"; 36 | case CUBLAS_STATUS_INVALID_VALUE: 37 | return "CUBLAS_STATUS_INVALID_VALUE"; 38 | case CUBLAS_STATUS_ARCH_MISMATCH: 39 | return "CUBLAS_STATUS_ARCH_MISMATCH"; 40 | case CUBLAS_STATUS_MAPPING_ERROR: 41 | return "CUBLAS_STATUS_MAPPING_ERROR"; 42 | case CUBLAS_STATUS_EXECUTION_FAILED: 43 | return "CUBLAS_STATUS_EXECUTION_FAILED"; 44 | case CUBLAS_STATUS_INTERNAL_ERROR: 45 | return "CUBLAS_STATUS_INTERNAL_ERROR"; 46 | case CUBLAS_STATUS_NOT_SUPPORTED: 47 | return "CUBLAS_STATUS_NOT_SUPPORTED"; 48 | case CUBLAS_STATUS_LICENSE_ERROR: 49 | return "CUBLAS_STATUS_LICENSE_ERROR"; 50 | } 51 | return "unknown error"; 52 | } 53 | 54 | inline void __cublasCheck(const cublasStatus_t err, const char* file, 55 | int line) { 56 | if (err != CUBLAS_STATUS_SUCCESS) { 57 | fprintf(stderr, "%s(%d): Cublas error: %s.\n", file, line, 58 | cublasGetErrorString(err)); 59 | exit(EXIT_FAILURE); 60 | } 61 | } 62 | #define CublasCheck(call) __cublasCheck(call, __FILE__, __LINE__) 63 | 64 | inline void __cudaCheck(const cudaError err, const char* file, int line) { 65 | if (err != cudaSuccess) { 66 | fprintf(stderr, "%s(%d): CUDA error: %s.\n", file, line, 67 | cudaGetErrorString(err)); 68 | exit(EXIT_FAILURE); 69 | } 70 | } 71 | #define CudaCheck(call) __cudaCheck(call, __FILE__, __LINE__) 72 | 73 | } // namespace benchmarks 74 | -------------------------------------------------------------------------------- /benchs/python/fused_gemm/tiledcuda/entry.py: -------------------------------------------------------------------------------- 1 | config = """#include "fused_gemm.hpp" 2 | 3 | static constexpr int kWarpPerRow = {kWarpPerRow}; 4 | static constexpr int kWarpPerCol = {kWarpPerCol}; 5 | 6 | static constexpr int kM = {kM}; 7 | static constexpr int kN = {kN}; 8 | static constexpr int kK = {kK}; 9 | static constexpr int kP = {kP}; 10 | 11 | static constexpr int kTM = {kTM}; 12 | static constexpr int kTN = {kTN}; 13 | static constexpr int kTK = {kTK}; 14 | static constexpr int kTP = {kTP}; 15 | 16 | static constexpr int kRK = {kRK}; 17 | """ 18 | 19 | kernel_entry = """ 20 | extern "C" int kernel_entry(const __half* A, const __half* B, const __half* C, float* D) { 21 | using InType = __half; 22 | using AccType = float; 23 | 24 | using WholeShape = GemmShape; 25 | using CtaTileShape = GemmShape; 26 | using WarpLayout = tl::RowMajor; 27 | 28 | using Config = KeGemmTraits; 30 | 31 | auto kernel = 32 | &gemm; 42 | 43 | static constexpr int smem_size_inputs = kTK * (kTN + kTM) * sizeof(InType); 44 | static constexpr int smem_size_accumulators = kTM * kTN * sizeof(AccType); 45 | static constexpr int smem_size = smem_size_inputs > smem_size_accumulators 46 | ? 
smem_size_inputs 47 | : smem_size_accumulators; 48 | 49 | const int kMaxSmemPerBlock = 48 * 1024; 50 | if (smem_size > kMaxSmemPerBlock) { 51 | cudaFuncSetAttribute( 52 | kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); 53 | } 54 | 55 | int block_x = CeilDiv; 56 | int block_y = CeilDiv; 57 | 58 | dim3 dim_grid(block_x, block_y, 1); 59 | dim3 dim_block(Config::kThreads, 1, 1); 60 | 61 | kernel<<>>(A, B, C); 62 | 63 | return 0; 64 | } 65 | """ 66 | -------------------------------------------------------------------------------- /benchs/python/fused_gemm/cutlass/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from typing import Tuple 4 | 5 | from fused_gemm import fused_gemm_func as cutlass_fused_gemm 6 | 7 | 8 | def run_unittest(a: Tensor, 9 | b: Tensor, 10 | c: Tensor, 11 | d: Tensor, 12 | M: int, 13 | N: int, 14 | K: int, 15 | P: int, 16 | kTM: int, 17 | kTN: int, 18 | kTK: int, 19 | kTP: int, 20 | warp_layout: Tuple, 21 | debug_print=False, 22 | epsilon: float = 5e-2): 23 | cutlass_fused_gemm(a, b, c, d, M, N, K, P, kTM, kTN, kTK, kTP, *warp_layout) 24 | 25 | ref_acc = a @ b.t() 26 | ref_d = ref_acc @ c.t() 27 | 28 | if debug_print: 29 | print("Result:") 30 | print(d) 31 | 32 | print("\nReference:") 33 | print(ref_d) 34 | 35 | avg_diff = (torch.sum(torch.abs(ref_d - d) / (M * N))).item() 36 | 37 | if avg_diff > epsilon: 38 | return False 39 | else: 40 | return True 41 | 42 | 43 | def run_test( 44 | M: int, 45 | N: int, 46 | K: int, 47 | P: int, 48 | kTM: int, 49 | kTN: int, 50 | kTK: int, 51 | kTP: int, 52 | warp_layout: Tuple, 53 | ): 54 | device = torch.device("cuda") 55 | dtype = torch.float16 56 | 57 | torch.manual_seed(1234) 58 | 59 | a = torch.randn(M, K, device=device, dtype=dtype) 60 | b = torch.randn(N, K, device=device, dtype=dtype) 61 | c = torch.randn(P, N, device=device, dtype=dtype) 62 | d = torch.zeros(M, P, device=device, dtype=dtype) 63 | 64 | if run_unittest(a, b, c, d, M, N, K, P, kTM, kTN, kTK, kTP, warp_layout, debug_print=True): 65 | print("Unittest passed") 66 | else: 67 | raise ValueError("Unittest failed") 68 | 69 | start_event = torch.cuda.Event(enable_timing=True) 70 | end_event = torch.cuda.Event(enable_timing=True) 71 | 72 | iters = 50 73 | start_event.record() 74 | for _ in range(iters): 75 | cutlass_fused_gemm(a, b, c, d, M, N, K, P, kTM, kTN, kTK, kTP, *warp_layout) 76 | end_event.record() 77 | torch.cuda.synchronize() 78 | 79 | time = start_event.elapsed_time(end_event) / iters 80 | 81 | return time 82 | 83 | 84 | if __name__ == "__main__": 85 | kM = 1024 86 | kN = 1024 87 | kK = 1024 88 | kP = 1024 89 | 90 | kTM = 64 91 | kTN = 64 92 | kTK = 64 93 | kTP = 64 94 | 95 | time = run_test(kM, kN, kK, kP, kTM, kTN, kTK, kTP, (1, 1)) 96 | 97 | print("Elapsed time: {:.4f} ms".format(time)) 98 | -------------------------------------------------------------------------------- /benchs/cpp/gemm/util.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cutlass_gemm.cuh" 4 | #include "tiledcuda_gemm.cuh" 5 | #include "util/cuda_timer.hpp" 6 | 7 | #include 8 | #include 9 | 10 | #include 11 | 12 | using namespace benchmarks; 13 | 14 | float rand_float(float a = 1e-4, float b = 5e-3) { 15 | float random = ((float)rand()) / (float)RAND_MAX; 16 | float diff = b - a; 17 | float r = random * diff; 18 | return a + r; 19 | } 20 | 21 | namespace { 22 | bool check_results_impl(const __half* values1, const 
float* values2, 23 | int numel) { 24 | bool passed = true; 25 | const float epsilon = 1e-3; 26 | 27 | double total_diff = 0.; 28 | double max_abs_diff = FLT_MIN; 29 | double diff = 0.; 30 | 31 | #ifdef DEBUG 32 | int cut_off = 128; 33 | printf("ground truth:\n"); 34 | for (int i = 0; i < cut_off; ++i) { 35 | printf("%.5f, ", __half2float(values1[i])); 36 | if (i && (i + 1) % 16 == 0) printf("\n"); 37 | } 38 | printf("\ncomputed values:\n"); 39 | for (int i = 0; i < cut_off; ++i) { 40 | printf("%.5f, ", values2[i]); 41 | if (i && (i + 1) % 16 == 0) printf("\n"); 42 | } 43 | #endif 44 | 45 | for (int i = 0; i < numel; ++i) { 46 | float v1 = __half2float(values1[i]); 47 | float v2 = values2[i]; 48 | 49 | diff = fabs(v1 - v2); 50 | max_abs_diff = max_abs_diff < diff ? diff : max_abs_diff; 51 | total_diff += diff; 52 | 53 | #ifdef DEBUG 54 | if (diff > epsilon) { 55 | printf("the %d-th value differs (%.4f): %.4f vs. %.4f\n", i, diff, 56 | v1, v2); 57 | } 58 | #endif 59 | } 60 | 61 | double avg_diff = total_diff / numel; 62 | if (avg_diff > epsilon) passed = false; 63 | 64 | return passed; 65 | } 66 | } // namespace 67 | 68 | template 69 | bool check_results(const T* values1_, const float* values2, int numel); 70 | 71 | template <> 72 | bool check_results(const cutlass::half_t* values1_, const float* values2, 73 | int numel) { 74 | const __half* values1 = reinterpret_cast(values1_); 75 | return check_results_impl(values1, values2, numel); 76 | } 77 | 78 | template <> 79 | bool check_results(const __half* values1, const float* values2, int numel) { 80 | return check_results_impl(values1, values2, numel); 81 | } 82 | 83 | float cublas_hgemm(int64_t kM, int64_t kN, int64_t kK, const __half* A, 84 | const __half* B, __half* C, bool timeit = false, 85 | int warm_up = 5, int iters = 20) { 86 | cublasHandle_t handle; 87 | cublasCreate(&handle); 88 | 89 | __half alf = static_cast<__half>(1.); 90 | __half bet = static_cast<__half>(0.); 91 | 92 | float elapsed = 0.; 93 | 94 | if (timeit) { 95 | for (int i = 0; i < warm_up; ++i) { 96 | cublasHgemm(handle, CUBLAS_OP_T /* transb*/, CUBLAS_OP_N, kN, kM, 97 | kK, &alf, B, kK, A, kK, &bet, C, kN); 98 | } 99 | cudaDeviceSynchronize(); 100 | 101 | CudaTimer timer; 102 | timer.start(); 103 | for (int i = 0; i < iters; ++i) { 104 | cublasHgemm(handle, CUBLAS_OP_T /* transb*/, CUBLAS_OP_N, kN, kM, 105 | kK, &alf, B, kK, A, kK, &bet, C, kN); 106 | } 107 | cudaDeviceSynchronize(); 108 | elapsed = timer.stop() / iters; 109 | } else { 110 | cublasHgemm(handle, CUBLAS_OP_T /* transb*/, CUBLAS_OP_N, kN, kM, kK, 111 | &alf, B, kK, A, kK, &bet, C, kN); 112 | } 113 | cudaDeviceSynchronize(); 114 | 115 | cublasDestroy(handle); 116 | return elapsed; 117 | } 118 | -------------------------------------------------------------------------------- /benchs/python/lstm/cutlass/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from typing import Tuple 4 | 5 | from lstm import lstm_func as cutlass_lstm 6 | 7 | 8 | def run_unittest(w: Tensor, 9 | x: Tensor, 10 | u: Tensor, 11 | c0: Tensor, 12 | h0: Tensor, 13 | c1: Tensor, 14 | h1: Tensor, 15 | M: int, 16 | N: int, 17 | K: int, 18 | kTM: int, 19 | kTN: int, 20 | kTK: int, 21 | warp_layout: Tuple, 22 | debug_print=False, 23 | epsilon: float = 5e-2): 24 | wdata = w.flatten() 25 | xdata = x.t().flatten() 26 | udata = u.flatten() 27 | c0data = c0.flatten() 28 | h0data = h0.t().flatten() 29 | c1data = c1.flatten() 30 | h1data = h1.flatten() 31 | 32 | 
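    # kernel_entry takes raw device pointers, so every operand is flattened first
    # (x and h0 are transposed so their storage matches the layout the kernel expects).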
cutlass_lstm(wdata, xdata, udata, c0data, h0data, c1data, h1data, M, N, K, kTM, kTN, kTK, *warp_layout) 33 | 34 | # Input Gate 35 | i = torch.sigmoid( 36 | w[0] @ x + u[0] @ h0 37 | ) 38 | # Forget Gate 39 | f = torch.sigmoid( 40 | w[1] @ x + u[1] @ h0 41 | ) 42 | # Output Gate 43 | o = torch.sigmoid( 44 | w[2] @ x + u[2] @ h0 45 | ) 46 | # Cell Gate 47 | c = torch.tanh( 48 | w[3] @ x + u[3] @ h0 49 | ) 50 | 51 | ref_c1 = f * c0 + i * c 52 | ref_h1 = o * torch.tanh(c1) 53 | c1 = c1.view(hidden_size, batch_size) 54 | h1 = h1.view(hidden_size, batch_size) 55 | 56 | if debug_print: 57 | print("Result:") 58 | print("c: ", c1) 59 | print("h: ", h1) 60 | 61 | print("\nReference:") 62 | print("c: ", ref_c1) 63 | print("h: ", ref_h1) 64 | 65 | avg_diff = (torch.sum(torch.abs(ref_c1 - c1) / (M * N))).item() 66 | 67 | if avg_diff > epsilon: 68 | print(f"Average difference: {avg_diff}") 69 | return False 70 | else: 71 | return True 72 | 73 | 74 | def run_test( 75 | hidden_size: int, 76 | batch_size: int, 77 | kTM: int, 78 | kTN: int, 79 | kTK: int, 80 | warp_layout: Tuple, 81 | ): 82 | 83 | 84 | device = torch.device("cuda") 85 | dtype = torch.float16 86 | 87 | torch.manual_seed(1234) 88 | 89 | M = 4 * hidden_size 90 | N = batch_size 91 | K = hidden_size 92 | 93 | w = torch.randn(4, hidden_size, hidden_size, device=device, dtype=dtype) 94 | x = torch.randn(hidden_size, batch_size, device=device, dtype=dtype) 95 | u = torch.randn(4, hidden_size, hidden_size, device=device, dtype=dtype) 96 | c0 = torch.randn(hidden_size, batch_size, device=device, dtype=dtype) 97 | h0 = torch.randn(hidden_size, batch_size, device=device, dtype=dtype) 98 | c1 = torch.empty(hidden_size, batch_size, device=device, dtype=dtype) 99 | h1 = torch.empty(hidden_size, batch_size, device=device, dtype=dtype) 100 | 101 | if run_unittest(w, x, u, c0, h0, c1, h1, M, N, K, kTM, kTN, kTK, warp_layout, debug_print=True): 102 | print("Unittest passed") 103 | else: 104 | raise ValueError("Unittest failed") 105 | 106 | start_event = torch.cuda.Event(enable_timing=True) 107 | end_event = torch.cuda.Event(enable_timing=True) 108 | 109 | iters = 50 110 | start_event.record() 111 | for _ in range(iters): 112 | cutlass_lstm(w, x, u, c0, h0, c1, h1, M, N, K, kTM, kTN, kTK, *warp_layout) 113 | end_event.record() 114 | torch.cuda.synchronize() 115 | 116 | time = start_event.elapsed_time(end_event) / iters 117 | 118 | return time 119 | 120 | 121 | if __name__ == "__main__": 122 | hidden_size = 1024 123 | batch_size = 256 124 | 125 | kTM = 32 126 | kTN = 32 127 | kTK = 32 128 | 129 | time = run_test(hidden_size, batch_size, kTM, kTN, kTK, (2, 2)) 130 | 131 | print("Elapsed time: {:.4f} ms".format(time)) 132 | -------------------------------------------------------------------------------- /benchs/python/gemm/triton/gemm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | 4 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1' 5 | os.environ["TRITON_INTERPRET"] = '1' 6 | 7 | import triton 8 | import triton.language as tl 9 | 10 | 11 | @triton.autotune( 12 | configs=[ 13 | triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8), 14 | triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32}, num_stages=3, num_warps=8), 15 | triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=3, num_warps=8), 16 | triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=3, num_warps=8), 17 | triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 
'BLOCK_K': 32}, num_stages=3, num_warps=8), 18 | triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32}, num_stages=3, num_warps=8), 19 | triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32}, num_stages=3, num_warps=8), 20 | triton.Config({'BLOCK_M': 32, 'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=3, num_warps=8), 21 | triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128}, num_stages=3, num_warps=8), 22 | triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128}, num_stages=3, num_warps=8), 23 | triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128}, num_stages=3, num_warps=8), 24 | triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128}, num_stages=3, num_warps=8), 25 | triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128}, num_stages=3, num_warps=8), 26 | triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=3, num_warps=8), 27 | triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8), 28 | triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8), 29 | triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64}, num_stages=3, num_warps=8) 30 | ], 31 | key = ['M', 'N', 'K'] 32 | ) 33 | 34 | @triton.jit 35 | def _gemm_kernel( 36 | a_ptr, b_ptr, c_ptr, 37 | M, N, K, 38 | stride_am, stride_ak, 39 | stride_bk, stride_bn, 40 | stride_cm, stride_cn, 41 | BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, 42 | ): 43 | 44 | pid_m = tl.program_id(0) 45 | pid_n = tl.program_id(1) 46 | 47 | offset_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M 48 | offset_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N 49 | offset_k = tl.arange(0, BLOCK_K) 50 | 51 | a_ptrs = a_ptr + (offset_am[:, None] * stride_am + offset_k[None, :] * stride_ak) 52 | b_ptrs = b_ptr + (offset_k[:, None] * stride_bk + offset_bn[None, :] * stride_bn) 53 | 54 | acc = tl.zeros((BLOCK_M, BLOCK_N), dtype = tl.float32) 55 | for k in range(0, tl.cdiv(K, BLOCK_K)): 56 | a = tl.load(a_ptrs, mask = offset_k[None, :] < K - k * BLOCK_K) 57 | b = tl.load(b_ptrs, mask = offset_k[:, None] < K - k * BLOCK_K) 58 | 59 | acc = tl.dot(a, b, acc) 60 | 61 | a_ptrs += BLOCK_K * stride_ak 62 | b_ptrs += BLOCK_K * stride_bk 63 | 64 | 65 | offset_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) 66 | offset_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) 67 | c_ptrs = c_ptr + (offset_cm[:, None] * stride_cm + offset_cn[None, :] * stride_cn) 68 | c_mask = (offset_cm[:, None] < M) & (offset_cn[None, :] < N) 69 | 70 | tl.store(c_ptrs, acc, mask = c_mask) 71 | 72 | def gemm(a, b): 73 | assert a.shape[1] == b.shape[0], "shape mismatch" 74 | assert a.is_contiguous() and b.is_contiguous(), "input must be contiguous" 75 | 76 | M, K = a.shape 77 | K, N = b.shape 78 | 79 | c = torch.empty((M, N), device = a.device, dtype = torch.float16) 80 | 81 | def grid(META): 82 | return (tl.cdiv(M, META['BLOCK_M']), tl.cdiv(N, META['BLOCK_N']), 1) 83 | 84 | _gemm_kernel[grid]( 85 | a, b, c, 86 | M, N, K, 87 | a.stride(0), a.stride(1), 88 | b.stride(0), b.stride(1), 89 | c.stride(0), c.stride(1) 90 | ) 91 | 92 | return c 93 | -------------------------------------------------------------------------------- /benchs/python/lstm/cutlass/compile.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import importlib.util 4 | 5 | from collections import defaultdict 6 | 7 | import ctypes 8 | import torch 9 | 10 | __all__ = ["Compile"] 11 | 12 | cutlass_include_dir = 
os.path.join(os.path.dirname(__file__), 13 | "../../../3rd-party/cutlass/include") 14 | tiledcuda_include_dir = os.path.join(os.path.dirname(__file__), 15 | "../../../3rd-party/TiledCUDA/include/") 16 | utils_include_dir = os.path.join(os.path.dirname(__file__), "../../../") 17 | 18 | 19 | class Compile: 20 | 21 | def __init__(self, file_prefix, tmp_dir): 22 | self.tmp_dir = tmp_dir 23 | self.file_prefix = file_prefix 24 | 25 | if not os.path.exists(self.tmp_dir): 26 | os.makedirs(self.tmp_dir) 27 | 28 | compute_capability = torch.cuda.get_device_capability() 29 | self.cc = f"{compute_capability[0]}{compute_capability[1]}" 30 | 31 | self.nvcc_path = self._find_nvcc_path() 32 | 33 | def _find_nvcc_path(self): 34 | 35 | def py_str(x): 36 | return x.decode('utf-8') 37 | 38 | if "CUDA_PATH" in os.environ: 39 | return os.environ["CUDA_PATH"] 40 | 41 | cmd = ["which", "nvcc"] 42 | proc = subprocess.Popen(cmd, 43 | stdout=subprocess.PIPE, 44 | stderr=subprocess.STDOUT) 45 | (out, _) = proc.communicate() 46 | 47 | if proc.returncode == 0: 48 | return py_str(out.strip()) 49 | else: 50 | raise RuntimeError("Cannot find cuda path") 51 | 52 | def _create_entry_code(self, M: int, N: int, K: int, kTM: int, kTN: int, 53 | kTK: int, warp_per_row: int, warp_per_col: int): 54 | entry_code_path = "entry.py" 55 | spec = importlib.util.spec_from_file_location("binding", 56 | entry_code_path) 57 | foo = importlib.util.module_from_spec(spec) 58 | spec.loader.exec_module(foo) 59 | 60 | shape = defaultdict(int) 61 | shape["WarpPerRow"] = warp_per_row 62 | shape["WarpPerCol"] = warp_per_col 63 | shape["kM"] = M 64 | shape["kN"] = N 65 | shape["kK"] = K 66 | shape["kTM"] = kTM 67 | shape["kTN"] = kTN 68 | shape["kTK"] = kTK 69 | 70 | return foo.entry.format_map(shape) 71 | 72 | def compile(self, 73 | M: int, 74 | N: int, 75 | K: int, 76 | kTM: int, 77 | kTN: int, 78 | kTK: int, 79 | warp_per_row: int, 80 | warp_per_col: int, 81 | timeout: float = None): 82 | 83 | temp_dir = self.tmp_dir 84 | 85 | file_name = (f"{self.file_prefix}_{M}_{N}_{K}" 86 | f"_{kTM}_{kTN}_{kTK}_{warp_per_row}_{warp_per_col}") 87 | lib_path = os.path.join(temp_dir, f"{file_name}.so") 88 | 89 | if os.path.exists(lib_path): 90 | return lib_path 91 | 92 | entry_code = self._create_entry_code(M, N, K, kTM, kTN, kTK, 93 | warp_per_row, warp_per_col) 94 | 95 | source_path = os.path.join(temp_dir, "bind.cu") 96 | with open(source_path, "w") as f: 97 | f.write(entry_code) 98 | 99 | command = [ 100 | self.nvcc_path, "-std=c++20", "-O3", "--use_fast_math", 101 | "--expt-relaxed-constexpr", "--disable-warnings", 102 | "--compiler-options", "'-fPIC'", "--shared", source_path, "-lcuda", 103 | f"-gencode=arch=compute_{self.cc},code=sm_{self.cc}", 104 | f"-I{cutlass_include_dir}", f"-I{tiledcuda_include_dir}", 105 | f"-I{utils_include_dir}", "-o", lib_path 106 | ] 107 | try: 108 | ret = subprocess.run(command, timeout=timeout) 109 | except subprocess.TimeoutExpired: 110 | return None 111 | if ret.returncode == 0: 112 | return lib_path 113 | else: 114 | raise RuntimeError("Compilation failed") 115 | 116 | def apply(self, lib_path, torch_array: list, device: int): 117 | lib = ctypes.CDLL(lib_path) 118 | 119 | lib.kernel_entry.restype = ctypes.c_int 120 | torch.cuda.set_device(device) 121 | 122 | ret = lib.kernel_entry( 123 | *[ctypes.c_void_p(arr.data_ptr()) for arr in torch_array]) 124 | return ret 125 | -------------------------------------------------------------------------------- /benchs/python/gemm/cutlass/compile.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import importlib.util 4 | 5 | from collections import defaultdict 6 | 7 | import ctypes 8 | import torch 9 | 10 | __all__ = ["Compile"] 11 | 12 | cutlass_include_dir = os.path.join(os.path.dirname(__file__), 13 | "../../../../3rd-party/cutlass/include") 14 | tiledcuda_include_dir = os.path.join(os.path.dirname(__file__), 15 | "../../../../3rd-party/TiledCUDA/include/") 16 | utils_include_dir = os.path.join(os.path.dirname(__file__), "../../../../") 17 | 18 | 19 | class Compile: 20 | 21 | def __init__(self, file_prefix, tmp_dir): 22 | self.tmp_dir = tmp_dir 23 | self.file_prefix = file_prefix 24 | 25 | if not os.path.exists(self.tmp_dir): 26 | os.makedirs(self.tmp_dir) 27 | 28 | compute_capability = torch.cuda.get_device_capability() 29 | self.cc = f"{compute_capability[0]}{compute_capability[1]}" 30 | 31 | self.nvcc_path = self._find_nvcc_path() 32 | 33 | def _find_nvcc_path(self): 34 | 35 | def py_str(x): 36 | return x.decode('utf-8') 37 | 38 | if "CUDA_PATH" in os.environ: 39 | return os.environ["CUDA_PATH"] 40 | 41 | cmd = ["which", "nvcc"] 42 | proc = subprocess.Popen(cmd, 43 | stdout=subprocess.PIPE, 44 | stderr=subprocess.STDOUT) 45 | (out, _) = proc.communicate() 46 | 47 | if proc.returncode == 0: 48 | return py_str(out.strip()) 49 | else: 50 | raise RuntimeError("Cannot find cuda path") 51 | 52 | def _create_entry_code(self, M: int, N: int, K: int, kTM: int, kTN: int, 53 | kTK: int, warp_per_row: int, warp_per_col: int): 54 | entry_code_path = os.path.join(os.path.dirname(__file__), "entry.py") 55 | spec = importlib.util.spec_from_file_location("binding", 56 | entry_code_path) 57 | foo = importlib.util.module_from_spec(spec) 58 | spec.loader.exec_module(foo) 59 | 60 | shape = defaultdict(int) 61 | shape["WarpPerRow"] = warp_per_row 62 | shape["WarpPerCol"] = warp_per_col 63 | shape["kM"] = M 64 | shape["kN"] = N 65 | shape["kK"] = K 66 | shape["kTM"] = kTM 67 | shape["kTN"] = kTN 68 | shape["kTK"] = kTK 69 | 70 | return foo.entry.format_map(shape) 71 | 72 | def compile(self, 73 | M: int, 74 | N: int, 75 | K: int, 76 | kTM: int, 77 | kTN: int, 78 | kTK: int, 79 | warp_per_row: int, 80 | warp_per_col: int, 81 | timeout: float = None): 82 | 83 | temp_dir = self.tmp_dir 84 | 85 | file_name = (f"{self.file_prefix}_{M}_{N}_{K}" 86 | f"_{kTM}_{kTN}_{kTK}_{warp_per_row}_{warp_per_col}") 87 | lib_path = os.path.join(temp_dir, f"{file_name}.so") 88 | 89 | if os.path.exists(lib_path): 90 | return lib_path 91 | 92 | entry_code = self._create_entry_code(M, N, K, kTM, kTN, kTK, 93 | warp_per_row, warp_per_col) 94 | 95 | source_path = os.path.join(temp_dir, "bind.cu") 96 | with open(source_path, "w") as f: 97 | f.write(entry_code) 98 | 99 | command = [ 100 | self.nvcc_path, "-std=c++20", "-O3", "--use_fast_math", 101 | "--expt-relaxed-constexpr", "--disable-warnings", 102 | "--compiler-options", "'-fPIC'", "--shared", source_path, "-lcuda", 103 | f"-gencode=arch=compute_{self.cc},code=sm_{self.cc}", 104 | f"-I{cutlass_include_dir}", f"-I{tiledcuda_include_dir}", 105 | f"-I{utils_include_dir}", "-o", lib_path 106 | ] 107 | try: 108 | ret = subprocess.run(command, timeout=timeout) 109 | except subprocess.TimeoutExpired: 110 | return None 111 | if ret.returncode == 0: 112 | return lib_path 113 | else: 114 | raise RuntimeError("Compilation failed") 115 | 116 | def apply(self, lib_path, torch_array: list, device: int): 117 | lib = ctypes.CDLL(lib_path) 118 | 119 | 
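        # The shared library built by compile() exposes a C symbol named
        # `kernel_entry`; its return type is declared before the call and each
        # tensor in `torch_array` is passed as a raw device pointer, so every
        # tensor is assumed to already live on the selected CUDA device.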
lib.kernel_entry.restype = ctypes.c_int 120 | torch.cuda.set_device(device) 121 | 122 | ret = lib.kernel_entry( 123 | *[ctypes.c_void_p(arr.data_ptr()) for arr in torch_array]) 124 | return ret 125 | -------------------------------------------------------------------------------- /benchs/python/batched_gemm/cutlass/compile.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import importlib.util 4 | 5 | from collections import defaultdict 6 | 7 | import ctypes 8 | import torch 9 | 10 | __all__ = ["Compile"] 11 | 12 | cutlass_include_dir = os.path.join(os.path.dirname(__file__), 13 | "../../../3rd-party/cutlass/include") 14 | tiledcuda_include_dir = os.path.join(os.path.dirname(__file__), 15 | "../../../3rd-party/TiledCUDA/include/") 16 | utils_include_dir = os.path.join(os.path.dirname(__file__), "../../../") 17 | 18 | 19 | class Compile: 20 | 21 | def __init__(self, file_prefix, tmp_dir): 22 | self.tmp_dir = tmp_dir 23 | self.file_prefix = file_prefix 24 | 25 | if not os.path.exists(self.tmp_dir): 26 | os.makedirs(self.tmp_dir) 27 | 28 | compute_capability = torch.cuda.get_device_capability() 29 | self.cc = f"{compute_capability[0]}{compute_capability[1]}" 30 | 31 | self.nvcc_path = self._find_nvcc_path() 32 | 33 | def _find_nvcc_path(self): 34 | 35 | def py_str(x): 36 | return x.decode('utf-8') 37 | 38 | if "CUDA_PATH" in os.environ: 39 | return os.environ["CUDA_PATH"] 40 | 41 | cmd = ["which", "nvcc"] 42 | proc = subprocess.Popen(cmd, 43 | stdout=subprocess.PIPE, 44 | stderr=subprocess.STDOUT) 45 | (out, _) = proc.communicate() 46 | 47 | if proc.returncode == 0: 48 | return py_str(out.strip()) 49 | else: 50 | raise RuntimeError("Cannot find cuda path") 51 | 52 | def _create_entry_code(self, M: int, N: int, K: int, BatchCount: int, 53 | kTM: int, kTN: int, kTK: int, warp_per_row: int, 54 | warp_per_col: int): 55 | entry_code_path = "entry.py" 56 | spec = importlib.util.spec_from_file_location("binding", 57 | entry_code_path) 58 | foo = importlib.util.module_from_spec(spec) 59 | spec.loader.exec_module(foo) 60 | 61 | shape = defaultdict(int) 62 | shape["WarpPerRow"] = warp_per_row 63 | shape["WarpPerCol"] = warp_per_col 64 | shape["kM"] = M 65 | shape["kN"] = N 66 | shape["kK"] = K 67 | shape["BatchCount"] = BatchCount 68 | shape["kTM"] = kTM 69 | shape["kTN"] = kTN 70 | shape["kTK"] = kTK 71 | 72 | return foo.entry.format_map(shape) 73 | 74 | def compile(self, 75 | M: int, 76 | N: int, 77 | K: int, 78 | BatchCount: int, 79 | kTM: int, 80 | kTN: int, 81 | kTK: int, 82 | warp_per_row: int, 83 | warp_per_col: int, 84 | timeout: float = None): 85 | 86 | temp_dir = self.tmp_dir 87 | 88 | file_name = (f"{self.file_prefix}_{M}_{N}_{K}_{BatchCount}" 89 | f"_{kTM}_{kTN}_{kTK}_{warp_per_row}_{warp_per_col}") 90 | lib_path = os.path.join(temp_dir, f"{file_name}.so") 91 | 92 | if os.path.exists(lib_path): 93 | return lib_path 94 | 95 | entry_code = self._create_entry_code(M, N, K, BatchCount, kTM, kTN, kTK, 96 | warp_per_row, warp_per_col) 97 | 98 | source_path = os.path.join(temp_dir, "bind.cu") 99 | with open(source_path, "w") as f: 100 | f.write(entry_code) 101 | 102 | command = [ 103 | self.nvcc_path, "-std=c++20", "-O3", "--use_fast_math", 104 | "--expt-relaxed-constexpr", "--disable-warnings", 105 | "--compiler-options", "'-fPIC'", "--shared", source_path, "-lcuda", 106 | f"-gencode=arch=compute_{self.cc},code=sm_{self.cc}", 107 | f"-I{cutlass_include_dir}", f"-I{tiledcuda_include_dir}", 108 | f"-I{utils_include_dir}", 
"-o", lib_path 109 | ] 110 | try: 111 | ret = subprocess.run(command, timeout=timeout) 112 | except subprocess.TimeoutExpired: 113 | return None 114 | if ret.returncode == 0: 115 | return lib_path 116 | else: 117 | raise RuntimeError("Compilation failed") 118 | 119 | def apply(self, lib_path, torch_array: list, device: int): 120 | lib = ctypes.CDLL(lib_path) 121 | 122 | lib.kernel_entry.restype = ctypes.c_int 123 | torch.cuda.set_device(device) 124 | 125 | ret = lib.kernel_entry( 126 | *[ctypes.c_void_p(arr.data_ptr()) for arr in torch_array]) 127 | return ret 128 | -------------------------------------------------------------------------------- /benchs/python/fused_gemm/cutlass/compile.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import importlib.util 4 | 5 | from collections import defaultdict 6 | 7 | import ctypes 8 | import torch 9 | 10 | __all__ = ["Compile"] 11 | 12 | cutlass_include_dir = os.path.join(os.path.dirname(__file__), 13 | "../../../3rd-party/cutlass/include") 14 | tiledcuda_include_dir = os.path.join(os.path.dirname(__file__), 15 | "../../../3rd-party/TiledCUDA/include/") 16 | utils_include_dir = os.path.join(os.path.dirname(__file__), "../../../") 17 | 18 | 19 | class Compile: 20 | 21 | def __init__(self, file_prefix, tmp_dir): 22 | self.tmp_dir = tmp_dir 23 | self.file_prefix = file_prefix 24 | 25 | if not os.path.exists(self.tmp_dir): 26 | os.makedirs(self.tmp_dir) 27 | 28 | compute_capability = torch.cuda.get_device_capability() 29 | self.cc = f"{compute_capability[0]}{compute_capability[1]}" 30 | 31 | self.nvcc_path = self._find_nvcc_path() 32 | 33 | def _find_nvcc_path(self): 34 | 35 | def py_str(x): 36 | return x.decode('utf-8') 37 | 38 | if "CUDA_PATH" in os.environ: 39 | return os.environ["CUDA_PATH"] 40 | 41 | cmd = ["which", "nvcc"] 42 | proc = subprocess.Popen(cmd, 43 | stdout=subprocess.PIPE, 44 | stderr=subprocess.STDOUT) 45 | (out, _) = proc.communicate() 46 | 47 | if proc.returncode == 0: 48 | return py_str(out.strip()) 49 | else: 50 | raise RuntimeError("Cannot find cuda path") 51 | 52 | def _create_entry_code(self, M: int, N: int, K: int, P: int, kTM: int, kTN: int, 53 | kTK: int, kTP: int, warp_per_row: int, warp_per_col: int): 54 | entry_code_path = "entry.py" 55 | spec = importlib.util.spec_from_file_location("binding", 56 | entry_code_path) 57 | foo = importlib.util.module_from_spec(spec) 58 | spec.loader.exec_module(foo) 59 | 60 | shape = defaultdict(int) 61 | shape["WarpPerRow"] = warp_per_row 62 | shape["WarpPerCol"] = warp_per_col 63 | shape["kM"] = M 64 | shape["kN"] = N 65 | shape["kK"] = K 66 | shape["kP"] = P 67 | shape["kTM"] = kTM 68 | shape["kTN"] = kTN 69 | shape["kTK"] = kTK 70 | shape["kTP"] = kTP 71 | 72 | return foo.entry.format_map(shape) 73 | 74 | def compile(self, 75 | M: int, 76 | N: int, 77 | K: int, 78 | P: int, 79 | kTM: int, 80 | kTN: int, 81 | kTK: int, 82 | kTP: int, 83 | warp_per_row: int, 84 | warp_per_col: int, 85 | timeout: float = None): 86 | 87 | temp_dir = self.tmp_dir 88 | 89 | file_name = (f"{self.file_prefix}_{M}_{N}_{K}_{P}" 90 | f"_{kTM}_{kTN}_{kTK}_{kTP}_{warp_per_row}_{warp_per_col}") 91 | lib_path = os.path.join(temp_dir, f"{file_name}.so") 92 | 93 | if os.path.exists(lib_path): 94 | return lib_path 95 | 96 | entry_code = self._create_entry_code(M, N, K, P, kTM, kTN, kTK, kTP, 97 | warp_per_row, warp_per_col) 98 | 99 | source_path = os.path.join(temp_dir, "bind.cu") 100 | with open(source_path, "w") as f: 101 | f.write(entry_code) 
102 | 103 | command = [ 104 | self.nvcc_path, "-std=c++20", "-O3", "--use_fast_math", 105 | "--expt-relaxed-constexpr", "--disable-warnings", 106 | "--compiler-options", "'-fPIC'", "--shared", source_path, "-lcuda", 107 | f"-gencode=arch=compute_{self.cc},code=sm_{self.cc}", 108 | f"-I{cutlass_include_dir}", f"-I{tiledcuda_include_dir}", 109 | f"-I{utils_include_dir}", "-o", lib_path 110 | ] 111 | try: 112 | ret = subprocess.run(command, timeout=timeout) 113 | except subprocess.TimeoutExpired: 114 | return None 115 | if ret.returncode == 0: 116 | return lib_path 117 | else: 118 | raise RuntimeError("Compilation failed") 119 | 120 | def apply(self, lib_path, torch_array: list, device: int): 121 | lib = ctypes.CDLL(lib_path) 122 | 123 | lib.kernel_entry.restype = ctypes.c_int 124 | torch.cuda.set_device(device) 125 | 126 | ret = lib.kernel_entry( 127 | *[ctypes.c_void_p(arr.data_ptr()) for arr in torch_array]) 128 | return ret 129 | -------------------------------------------------------------------------------- /benchs/python/gemm/tiledcuda/compile.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import importlib.util 4 | 5 | from collections import defaultdict 6 | 7 | import ctypes 8 | import torch 9 | 10 | __all__ = ["Compile"] 11 | 12 | cutlass_include_dir = os.path.join(os.path.dirname(__file__), 13 | "../../../../3rd-party/cutlass/include") 14 | tiledcuda_include_dir = os.path.join(os.path.dirname(__file__), 15 | "../../../../3rd-party/TiledCUDA/include/") 16 | utils_include_dir = os.path.join(os.path.dirname(__file__), 17 | ("../../../../3rd-party/TiledCUDA/examples/" 18 | "cpp/01_gemm/02_gemm_all_mem")) 19 | 20 | 21 | class Compile: 22 | 23 | def __init__(self, file_prefix, tmp_dir): 24 | self.tmp_dir = tmp_dir 25 | self.file_prefix = file_prefix 26 | 27 | if not os.path.exists(self.tmp_dir): 28 | os.makedirs(self.tmp_dir) 29 | 30 | compute_capability = torch.cuda.get_device_capability() 31 | self.cc = f"{compute_capability[0]}{compute_capability[1]}" 32 | 33 | self.nvcc_path = self._find_nvcc_path() 34 | 35 | def _find_nvcc_path(self): 36 | 37 | def py_str(x): 38 | return x.decode('utf-8') 39 | 40 | if "CUDA_PATH" in os.environ: 41 | return os.environ["CUDA_PATH"] 42 | 43 | cmd = ["which", "nvcc"] 44 | proc = subprocess.Popen(cmd, 45 | stdout=subprocess.PIPE, 46 | stderr=subprocess.STDOUT) 47 | (out, _) = proc.communicate() 48 | 49 | if proc.returncode == 0: 50 | return py_str(out.strip()) 51 | else: 52 | raise RuntimeError("Cannot find cuda path") 53 | 54 | def _create_entry_code(self, M: int, N: int, K: int, kTM: int, kTN: int, 55 | kTK: int, kRK: int, warp_per_row: int, 56 | warp_per_col: int): 57 | entry_code_path = os.path.join(os.path.dirname(__file__), "entry.py") 58 | spec = importlib.util.spec_from_file_location("binding", 59 | entry_code_path) 60 | foo = importlib.util.module_from_spec(spec) 61 | spec.loader.exec_module(foo) 62 | 63 | shape = defaultdict(int) 64 | shape["kWarpPerRow"] = warp_per_row 65 | shape["kWarpPerCol"] = warp_per_col 66 | shape["kM"] = M 67 | shape["kN"] = N 68 | shape["kK"] = K 69 | shape["kTM"] = kTM 70 | shape["kTN"] = kTN 71 | shape["kTK"] = kTK 72 | shape["kRK"] = kRK 73 | 74 | return foo.config.format_map(shape) + foo.kernel_entry 75 | 76 | def compile(self, 77 | M: int, 78 | N: int, 79 | K: int, 80 | kTM: int, 81 | kTN: int, 82 | kTK: int, 83 | kRK: int, 84 | warp_per_row: int, 85 | warp_per_col: int, 86 | timeout: float = None): 87 | 88 | temp_dir = self.tmp_dir 89 | 
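        # The .so name encodes the full problem and tile configuration
        # (M, N, K, kTM, kTN, kTK, kRK and the warp layout), so a kernel that
        # was already built for the same configuration is reused from
        # `tmp_dir` instead of being recompiled.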
90 | file_name = (f"{self.file_prefix}_{M}_{N}_{K}" 91 | f"_{kTM}_{kTN}_{kTK}_{kRK}_{warp_per_row}_{warp_per_col}") 92 | lib_path = os.path.join(temp_dir, f"{file_name}.so") 93 | 94 | if os.path.exists(lib_path): 95 | return lib_path 96 | 97 | entry_code = self._create_entry_code(M, N, K, kTM, kTN, kTK, kRK, 98 | warp_per_row, warp_per_col) 99 | 100 | source_path = os.path.join(temp_dir, "bind.cu") 101 | with open(source_path, "w") as f: 102 | f.write(entry_code) 103 | 104 | command = [ 105 | self.nvcc_path, "-std=c++20", "-O3", "--use_fast_math", 106 | "--expt-relaxed-constexpr", "--disable-warnings", 107 | "--compiler-options", "'-fPIC'", "--shared", source_path, "-lcuda", 108 | f"-gencode=arch=compute_{self.cc},code=sm_{self.cc}", 109 | f"-I{cutlass_include_dir}", f"-I{tiledcuda_include_dir}", 110 | f"-I{utils_include_dir}", "-o", lib_path 111 | ] 112 | try: 113 | ret = subprocess.run(command, timeout=timeout) 114 | except subprocess.TimeoutExpired: 115 | return None 116 | if ret.returncode == 0: 117 | return lib_path 118 | else: 119 | raise RuntimeError("Compilation failed") 120 | 121 | def apply(self, lib_path, torch_array: list, device: int): 122 | lib = ctypes.CDLL(lib_path) 123 | 124 | lib.kernel_entry.restype = ctypes.c_int 125 | torch.cuda.set_device(device) 126 | 127 | ret = lib.kernel_entry( 128 | *[ctypes.c_void_p(arr.data_ptr()) for arr in torch_array]) 129 | return ret 130 | -------------------------------------------------------------------------------- /benchs/python/fused_gemm/tiledcuda/compile.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import importlib.util 4 | 5 | from collections import defaultdict 6 | 7 | import ctypes 8 | import torch 9 | 10 | __all__ = ["Compile"] 11 | 12 | cutlass_include_dir = os.path.join(os.path.dirname(__file__), 13 | "../../../3rd-party/cutlass/include") 14 | tiledcuda_include_dir = os.path.join(os.path.dirname(__file__), 15 | "../../../3rd-party/TiledCUDA/include/") 16 | utils_include_dir = os.path.join(os.path.dirname(__file__), 17 | ("../../../3rd-party/TiledCUDA/examples/" 18 | "cpp/01_gemm/02_gemm_all_mem")) 19 | 20 | 21 | class Compile: 22 | 23 | def __init__(self, file_prefix, tmp_dir): 24 | self.tmp_dir = tmp_dir 25 | self.file_prefix = file_prefix 26 | 27 | if not os.path.exists(self.tmp_dir): 28 | os.makedirs(self.tmp_dir) 29 | 30 | compute_capability = torch.cuda.get_device_capability() 31 | self.cc = f"{compute_capability[0]}{compute_capability[1]}" 32 | 33 | self.nvcc_path = self._find_nvcc_path() 34 | 35 | def _find_nvcc_path(self): 36 | 37 | def py_str(x): 38 | return x.decode('utf-8') 39 | 40 | if "CUDA_PATH" in os.environ: 41 | return os.environ["CUDA_PATH"] 42 | 43 | cmd = ["which", "nvcc"] 44 | proc = subprocess.Popen(cmd, 45 | stdout=subprocess.PIPE, 46 | stderr=subprocess.STDOUT) 47 | (out, _) = proc.communicate() 48 | 49 | if proc.returncode == 0: 50 | return py_str(out.strip()) 51 | else: 52 | raise RuntimeError("Cannot find cuda path") 53 | 54 | def _create_entry_code(self, M: int, N: int, K: int, P: int, 55 | kTM: int, kTN: int, kTK: int, kTP: int, 56 | kRK: int, warp_per_row: int, 57 | warp_per_col: int): 58 | entry_code_path = "entry.py" 59 | spec = importlib.util.spec_from_file_location("binding", 60 | entry_code_path) 61 | foo = importlib.util.module_from_spec(spec) 62 | spec.loader.exec_module(foo) 63 | 64 | shape = defaultdict(int) 65 | shape["kWarpPerRow"] = warp_per_row 66 | shape["kWarpPerCol"] = warp_per_col 67 | shape["kM"] = M 
68 | shape["kN"] = N 69 | shape["kK"] = K 70 | shape["kP"] = P 71 | shape["kTM"] = kTM 72 | shape["kTN"] = kTN 73 | shape["kTK"] = kTK 74 | shape["kTP"] = kTP 75 | shape["kRK"] = kRK 76 | 77 | return foo.config.format_map(shape) + foo.kernel_entry 78 | 79 | def compile(self, 80 | M: int, 81 | N: int, 82 | K: int, 83 | P: int, 84 | kTM: int, 85 | kTN: int, 86 | kTK: int, 87 | kTP: int, 88 | kRK: int, 89 | warp_per_row: int, 90 | warp_per_col: int, 91 | timeout: float = None): 92 | 93 | temp_dir = self.tmp_dir 94 | 95 | file_name = (f"{self.file_prefix}_{M}_{N}_{K}_{P}" 96 | f"_{kTM}_{kTN}_{kTK}_{kTP}_{kRK}_{warp_per_row}_{warp_per_col}") 97 | lib_path = os.path.join(temp_dir, f"{file_name}.so") 98 | 99 | if os.path.exists(lib_path): 100 | return lib_path 101 | 102 | entry_code = self._create_entry_code(M, N, K, kTM, kTN, kTK, kRK, 103 | warp_per_row, warp_per_col) 104 | 105 | source_path = os.path.join(temp_dir, "bind.cu") 106 | with open(source_path, "w") as f: 107 | f.write(entry_code) 108 | 109 | command = [ 110 | self.nvcc_path, "-std=c++20", "-O3", "--use_fast_math", 111 | "--expt-relaxed-constexpr", "--disable-warnings", 112 | "--compiler-options", "'-fPIC'", "--shared", source_path, "-lcuda", 113 | f"-gencode=arch=compute_{self.cc},code=sm_{self.cc}", 114 | f"-I{cutlass_include_dir}", f"-I{tiledcuda_include_dir}", 115 | f"-I{utils_include_dir}", "-o", lib_path 116 | ] 117 | try: 118 | ret = subprocess.run(command, timeout=timeout) 119 | except subprocess.TimeoutExpired: 120 | return None 121 | if ret.returncode == 0: 122 | return lib_path 123 | else: 124 | raise RuntimeError("Compilation failed") 125 | 126 | def apply(self, lib_path, torch_array: list, device: int): 127 | lib = ctypes.CDLL(lib_path) 128 | 129 | lib.kernel_entry.restype = ctypes.c_int 130 | torch.cuda.set_device(device) 131 | 132 | ret = lib.kernel_entry( 133 | *[ctypes.c_void_p(arr.data_ptr()) for arr in torch_array]) 134 | return ret 135 | -------------------------------------------------------------------------------- /benchs/cpp/gemm/tiledcuda_gemm.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "cell/mod.hpp" 3 | #include "types/mod.hpp" 4 | 5 | using namespace tiledcuda; 6 | using namespace tiledcuda::cell; 7 | using namespace tiledcuda::cell::copy; 8 | 9 | namespace tl = tile_layout; 10 | 11 | template 12 | using GemmShape = TileShape; 13 | 14 | template 16 | struct KeGemmTraits { 17 | using BaseShape = traits::BaseTileShape; 18 | 19 | static constexpr int kThreads = tl::get_numel * 32; 20 | static constexpr int kWarpPerRow = tl::num_rows; 21 | static constexpr int kWarpPerCol = tl::num_cols; 22 | 23 | static constexpr int kM = dim_size<0, WholeShape>; 24 | static constexpr int kN = dim_size<1, WholeShape>; 25 | static constexpr int kK = dim_size<2, WholeShape>; 26 | 27 | static constexpr int kTM = dim_size<0, CtaTileShape>; 28 | static constexpr int kTN = dim_size<1, CtaTileShape>; 29 | static constexpr int kTK = dim_size<2, CtaTileShape>; 30 | 31 | static const bool kSwizzled = true; 32 | 33 | // Total data access for operand A in global memory 34 | using GlobalA = GlobalTile>; 35 | // Access a single global tile for operand A 36 | using GIteratorA = GTileIterator>; 37 | 38 | // Shared Tile for operand A 39 | using SharedA = SharedTile, kSwizzled>; 40 | // Access a single register tile for operand A 41 | using SIteratorA = STileIterator>; 42 | 43 | // Register tile for a single thread of operand A 44 | static constexpr int kAMs = kTM / 
kWarpPerRow / BaseShape::kTileSize; 45 | static constexpr int kAKs = kRK / BaseShape::kTileSize; 46 | using RegA = RegTile, tl::RowMajor>; 47 | 48 | // Loaders for operand A 49 | using G2SLoaderA = GlobalToSharedLoader; 50 | using S2RLoaderA = 51 | SharedToRegLoader; 52 | 53 | // Total data access for operand B in global memory 54 | using GlobalB = GlobalTile>; 55 | // Access a single global tile for operand B 56 | using GIteratorB = GTileIterator>; 57 | 58 | // Shared Tile for operand B 59 | using SharedB = SharedTile, kSwizzled>; 60 | // Access a single register tile for operand B 61 | using SIteratorB = STileIterator>; 62 | 63 | static_assert(GIteratorA::sc1 == GIteratorB::sc0, 64 | "mismatched K dimension!"); 65 | static_assert(SIteratorA::sc1 == SIteratorB::sc0, 66 | "mismatched K dimension!"); 67 | 68 | // Register tile for a single thread of operand A 69 | static constexpr int kBKs = kRK / BaseShape::kTileSize; 70 | static constexpr int kBNs = kTN / kWarpPerCol / BaseShape::kTileSize; 71 | using RegB = RegTile, tl::ColMajor>; 72 | 73 | using G2SLoaderB = GlobalToSharedLoader; 74 | using S2RLoaderB = 75 | SharedToRegLoader; 76 | 77 | // Global Tile for output C 78 | using GlobalC = GlobalTile>; 79 | // Shared Tile for output C 80 | using SharedC = SharedTile, kSwizzled>; 81 | 82 | // Register Tile for output C 83 | static constexpr int kCMs = kTM / kWarpPerRow / BaseShape::kTileSize; 84 | static constexpr int kCNs = kTN / kWarpPerCol / BaseShape::kTileSize; 85 | using RegC = RegTile, tl::RowMajor>; 86 | 87 | using R2SStorerC = RegToSharedStorer; 88 | using S2GStorerC = SharedToGlobalStorer; 89 | }; 90 | 91 | template 102 | __global__ void gemm(const InType* dA, const InType* dB, AccType* dC) { 103 | int offset_a = blockIdx.x * kTM * kK; 104 | int offset_b = blockIdx.y * kTN * kK; 105 | int offset_c = blockIdx.x * kTM * kN + blockIdx.y * kTN; 106 | 107 | extern __shared__ __align__(sizeof(double)) unsigned char buf[]; 108 | InType* sA_ptr = reinterpret_cast(buf); 109 | InType* sB_ptr = sA_ptr + SIteratorA::Tile::kNumel; 110 | AccType* sC_ptr = reinterpret_cast(buf); 111 | 112 | // declare tiles, iterators and loaders 113 | GIteratorA gAs(dA + offset_a); 114 | SIteratorA sAs(sA_ptr); 115 | 116 | GIteratorB gBs(dB + offset_b); 117 | SIteratorB sBs(sB_ptr); 118 | 119 | SharedA sA(sA_ptr); 120 | RegA rA; 121 | 122 | SharedB sB(sB_ptr); 123 | RegB rB; 124 | 125 | RegC acc; 126 | SharedC sC(sC_ptr); 127 | GlobalC gC(dC + offset_c); 128 | 129 | G2SLoaderA g2s_a; 130 | S2RLoaderA s2r_a; 131 | 132 | G2SLoaderB g2s_b; 133 | S2RLoaderB s2r_b; 134 | 135 | R2SStorerC r2s_c; 136 | S2GStorerC s2g_c; 137 | 138 | for (int k1 = 0; k1 < GIteratorA::sc1; ++k1) { 139 | g2s_a(gAs(k1), sA); 140 | g2s_b(gBs(k1), sB); 141 | __copy_async(); 142 | __syncthreads(); 143 | 144 | for (int k2 = 0; k2 < SIteratorA::sc1; ++k2) { 145 | s2r_a(sAs(k2), rA); 146 | s2r_b(sBs(k2), rB); 147 | 148 | compute::gemm(rA, rB, acc); 149 | } 150 | } 151 | r2s_c(acc, sC); 152 | __syncthreads(); 153 | s2g_c(sC, gC); 154 | } 155 | -------------------------------------------------------------------------------- /benchs/cpp/copy/bench.cu: -------------------------------------------------------------------------------- 1 | #include "cutlass/cutlass_copy.cuh" 2 | #include "tiledcuda/tiledcuda_copy.cuh" 3 | #include "utils/cpp/cuda_info.cuh" 4 | #include "utils/cpp/cuda_timer.cuh" 5 | 6 | #include 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | float rand_float(float a = 1e-4, 
float b = 1e-2) { 17 | float random = ((float)rand()) / (float)RAND_MAX; 18 | float diff = b - a; 19 | float r = random * diff; 20 | return a + r; 21 | } 22 | 23 | void run_test(std::ofstream& fout) { 24 | using Element = cutlass::half_t; 25 | using InType = __half; 26 | using AccType = float; 27 | 28 | static constexpr int kM = 4096; 29 | static constexpr int kN = 4096; 30 | static constexpr int kK = 2048; 31 | 32 | static constexpr int kTM = 64; 33 | static constexpr int kTN = 32; 34 | static constexpr int kTK = 32; 35 | 36 | static constexpr int kWarpPerRow = 2; 37 | static constexpr int kWarpPerCol = 2; 38 | 39 | thrust::host_vector h_a(kM * kK); 40 | for (int i = 0; i < h_a.size(); ++i) 41 | h_a[i] = static_cast(rand_float()); 42 | 43 | thrust::host_vector h_b(kK * kN); 44 | for (int i = 0; i < h_b.size(); ++i) 45 | h_b[i] = static_cast(rand_float()); 46 | 47 | thrust::host_vector h_c(kM * kN); 48 | thrust::fill(h_c.begin(), h_c.end(), 0.); 49 | 50 | thrust::host_vector h_c2(kM * kN); 51 | thrust::fill(h_c2.begin(), h_c2.end(), 0.); 52 | 53 | thrust::device_vector d_a = h_a; 54 | thrust::device_vector d_b = h_b; 55 | thrust::device_vector d_c = h_c; 56 | thrust::device_vector d_c2 = h_c2; 57 | 58 | const Element* dA = thrust::raw_pointer_cast(d_a.data()); 59 | const Element* dB = thrust::raw_pointer_cast(d_b.data()); 60 | Element* dC = thrust::raw_pointer_cast(d_c.data()); 61 | 62 | const InType* dA2 = reinterpret_cast(dA); 63 | const InType* dB2 = reinterpret_cast(dB); 64 | AccType* dC2 = thrust::raw_pointer_cast(d_c2.data()); 65 | 66 | auto cute_shared_copy_kernel = 67 | &cute_shared_copy; 69 | 70 | auto cute_copy_kernel = &cute_copy; 72 | 73 | auto tiledcuda_shared_copy_kernel = 74 | &tiledcuda_shared_copy; 76 | 77 | auto tiledcuda_copy_kernel = 78 | &tiledcuda_copy; 80 | 81 | const int warm_up = 5; 82 | const int iters = 20; 83 | 84 | benchmarks::CudaTimer timer; 85 | 86 | for (int i = 0; i < warm_up; ++i) { 87 | cute_shared_copy_kernel(dA, dB, dC); 88 | } 89 | 90 | cudaDeviceSynchronize(); 91 | 92 | timer.start(); 93 | for (int i = 0; i < iters; ++i) { 94 | cute_shared_copy_kernel(dA, dB, dC); 95 | } 96 | cudaDeviceSynchronize(); 97 | float cutlass_time = timer.stop() / iters; 98 | 99 | for (int i = 0; i < warm_up; ++i) { 100 | cute_copy_kernel(dA, dB, dC); 101 | } 102 | 103 | cudaDeviceSynchronize(); 104 | 105 | timer.start(); 106 | for (int i = 0; i < iters; ++i) { 107 | cute_copy_kernel(dA, dB, dC); 108 | } 109 | cudaDeviceSynchronize(); 110 | float cutlass_time2 = timer.stop() / iters; 111 | 112 | for (int i = 0; i < warm_up; ++i) { 113 | tiledcuda_shared_copy_kernel(dA2, dB2, dC2); 114 | } 115 | 116 | cudaDeviceSynchronize(); 117 | 118 | timer.start(); 119 | for (int i = 0; i < iters; ++i) { 120 | tiledcuda_shared_copy_kernel(dA2, dB2, dC2); 121 | } 122 | cudaDeviceSynchronize(); 123 | float tiledcuda_time = timer.stop() / iters; 124 | 125 | for (int i = 0; i < warm_up; ++i) { 126 | tiledcuda_copy_kernel(dA2, dB2, dC2); 127 | } 128 | 129 | cudaDeviceSynchronize(); 130 | 131 | timer.start(); 132 | for (int i = 0; i < iters; ++i) { 133 | tiledcuda_copy_kernel(dA2, dB2, dC2); 134 | } 135 | 136 | cudaDeviceSynchronize(); 137 | float tiledcuda_time2 = timer.stop() / iters; 138 | 139 | std::cout << "Whole\t" << "[" << kM << ", " << kN << ", " << kK << "]\t[" 140 | << kTM << ", " << kTN << ", " << kTK << "]\t[" << kWarpPerRow 141 | << ", " << kWarpPerCol << "]\t" << cutlass_time2 << "\t" 142 | << tiledcuda_time2 << "\t" << tiledcuda_time2 / cutlass_time2 143 | << std::endl; 144 | 
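    // "Whole" reports the full-path copy kernels (cute_copy vs.
    // tiledcuda_copy), while "G2S" reports the *_shared_copy variants, which
    // appear to exercise only the global-to-shared stage. The last column is
    // the TiledCUDA/CUTLASS time ratio, matching the "Ratio" header written
    // in main().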
145 | std::cout << "G2S\t" << "[" << kM << ", " << kN << ", " << kK << "]\t[" 146 | << kTM << ", " << kTN << ", " << kTK << "]\t[" << kWarpPerRow 147 | << ", " << kWarpPerCol << "]\t" << cutlass_time << "\t" 148 | << tiledcuda_time << "\t" << tiledcuda_time / cutlass_time 149 | << std::endl; 150 | 151 | fout << "Whole\t" << "[" << kM << ", " << kN << ", " << kK << "]\t[" << kTM 152 | << ", " << kTN << ", " << kTK << "]\t[" << kWarpPerRow << ", " 153 | << kWarpPerCol << "]\t" << cutlass_time2 << "\t" << tiledcuda_time2 154 | << "\t" << tiledcuda_time2 / cutlass_time2 << std::endl; 155 | 156 | fout << "G2S\t" << "[" << kM << ", " << kN << ", " << kK << "]\t[" << kTM 157 | << ", " << kTN << ", " << kTK << "]\t[" << kWarpPerRow << ", " 158 | << kWarpPerCol << "]\t" << cutlass_time << "\t" << tiledcuda_time 159 | << "\t" << tiledcuda_time / cutlass_time << std::endl; 160 | } 161 | 162 | int main() { 163 | std::ofstream fout; 164 | fout.setf(std::ios::fixed); 165 | fout.precision(4); 166 | 167 | auto dev_name = benchmarks::get_device_name(); 168 | std::stringstream file_name; 169 | file_name << "bench_" << dev_name << "_copy.tsv"; 170 | fout.open(file_name.str(), std::ios::out); 171 | 172 | fout << "Copy Type\t" 173 | << "[M, N, K]\t[kTM, kTN, kTK]\t[kWarpPerRow, " 174 | "kWarpPerCol]\tCutlassTime(ms)\tTiledCUDATime(ms)\tRatio" 175 | << std::endl; 176 | 177 | std::cout << "Copy Type\t" 178 | << "[M, N, K]\t[kTM, kTN, kTK]\t[kWarpPerRow, " 179 | "kWarpPerCol]\tCutlassTime(ms)\tTiledCUDATime(ms)\tRatio" 180 | << std::endl; 181 | 182 | run_test(fout); 183 | 184 | return 0; 185 | } 186 | -------------------------------------------------------------------------------- /benchs/cpp/gemm/cutlass_gemm.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "utils/cpp/cuda_utils.cuh" 4 | #include "utils/cpp/cutlass/copy.cuh" 5 | #include "utils/cpp/cutlass/traits_base.cuh" 6 | 7 | #include 8 | 9 | namespace benchmarks { 10 | namespace cutlass_wrapper { 11 | 12 | using namespace cute; 13 | 14 | template > 19 | struct GemmTraits : public Base { 20 | using Element = Element_; 21 | 22 | static_assert(kTM % kWarpPerRow == 0, 23 | "the M dimension of the CTA tile should be divisible by the " 24 | "number of warps along that that dimension."); 25 | static_assert(kTN % kWarpPerCol == 0, 26 | "the N dimension of the CTA tile should be divisible by the " 27 | "number of warps along that that dimension."); 28 | 29 | // declare global to shared memory copy layout. 
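    // Presumably these are row-major CTA-tile views: A as (kTM x kTK) with
    // row stride kK, B as (kTN x kTK) with row stride kK, and C as
    // (kTM x kTN) with row stride kN, matching how gemm_kernel below
    // advances its global pointers each kTK step.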
30 | using GmemLayoutA = Layout, Int>, Stride, _1>>; 31 | using GmemLayoutB = Layout, Int>, Stride, _1>>; 32 | using GmemLayoutC = Layout, Int>, Stride, _1>>; 33 | 34 | using TiledMma = 35 | TiledMMA, // for ampere 36 | Layout, Int, _1>>, 37 | Tile, Int<16 * kWarpPerCol>, _16>>; 38 | 39 | static constexpr int kThreads = size(TiledMma{}); 40 | static_assert(kThreads == kWarpPerRow * kWarpPerCol * 32); 41 | 42 | static constexpr int kNumPerAccess = Base::kNumPerAccess; 43 | static constexpr int kThreadsPerCol = CeilDiv; 44 | static constexpr int kThreadsPerRow = CeilDiv; 45 | 46 | using SmemLayoutAtom = decltype(composition( 47 | Swizzle<2, 3, 3>{}, Layout>, 48 | Stride, _1>>{})); 49 | 50 | using SmemLayoutA = 51 | decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); 52 | using SmemLayoutB = 53 | decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); 54 | using SmemLayoutC = 55 | decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); 56 | 57 | #ifdef CP_ASYNC_SM80_ENABLED 58 | using CopyInstG2S = 59 | Copy_Atom, Element>; 60 | #else 61 | using CopyInstG2S = Copy_Atom; 62 | #endif 63 | 64 | using TiledCopyG2S = decltype(make_tiled_copy( 65 | CopyInstG2S{}, 66 | Layout, Int>, 67 | Stride, _1>>{}, 68 | Layout>>{})); 69 | 70 | using TiledCopyS2G = decltype(make_tiled_copy( 71 | Copy_Atom{}, 72 | Layout, Int>, 73 | Stride, _1>>{}, 74 | Layout>>{})); 75 | using StoreC_R2S = R2SCopy2D; 76 | }; 77 | 78 | template 80 | __global__ void gemm_kernel(const Element* dA, const Element* dB, Element* dC) { 81 | extern __shared__ __align__(sizeof(double)) unsigned char buf_[]; 82 | auto* buf = reinterpret_cast(buf_); 83 | 84 | // Advance to the global data tile to the current CTA. 85 | Element* gA_ptr = const_cast(dA) + blockIdx.x * kK * kTM; 86 | Element* gB_ptr = const_cast(dB) + blockIdx.y * kK * kTN; 87 | Element* gC_ptr = dC + blockIdx.x * kTM * kN + blockIdx.y * kTN; 88 | 89 | // pointers to shared memory tiles 90 | Element* sA_ptr = buf; 91 | Element* sB_ptr = buf + kTM * kTK; 92 | Element* sC_ptr = buf; 93 | 94 | typename KeTraits::TiledMma mma; 95 | typename KeTraits::TiledCopyG2S tiled_copy; 96 | 97 | auto rA = make_s2rA(sA_ptr, typename KeTraits::SmemLayoutA{}, mma); 98 | auto rB = make_s2rB(sB_ptr, typename KeTraits::SmemLayoutB{}, mma); 99 | auto acc = get_acc(mma); 100 | 101 | for (int k = 0; k < kK; k += kTK) { 102 | copy_tile_g2s(gA_ptr, sA_ptr, typename KeTraits::GmemLayoutA{}, 103 | typename KeTraits::SmemLayoutA{}, tiled_copy); 104 | copy_tile_g2s(gB_ptr, sB_ptr, typename KeTraits::GmemLayoutB{}, 105 | typename KeTraits::SmemLayoutB{}, tiled_copy); 106 | __copy_async(); 107 | __syncthreads(); 108 | 109 | for (int i = 0; i < rA.get_iters(); ++i) { 110 | rA.copy(i); // load A register tile from shared memory 111 | rB.copy(i); // load B register tile from shared memory 112 | 113 | gemm(mma, rA[i], rB[i], acc); 114 | } 115 | 116 | gA_ptr += kTK; 117 | gB_ptr += kTK; 118 | } 119 | 120 | typename KeTraits::StoreC_R2S sC; // declare register to shared store plan 121 | sC.copy(acc, buf); // store register tile to shared memory 122 | __syncthreads(); 123 | 124 | // store shared memory tile to global memory 125 | copy_tile_s2g(sC_ptr, gC_ptr, typename KeTraits::SmemLayoutC{}, 126 | typename KeTraits::GmemLayoutC{}, 127 | typename KeTraits::TiledCopyS2G{}); 128 | } 129 | } // namespace cutlass_wrapper 130 | } // namespace benchmarks 131 | 132 | template 136 | void cute_gemm(const Element* dA, const Element* dB, Element* dC) { 137 | using namespace benchmarks::cutlass_wrapper; 138 | 139 | 
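    // Dynamic shared memory is sized as max(input tiles, output tile)
    // because the C tile reuses the same buffer as the A/B tiles; sizes
    // above the 48 KB static limit require the cudaFuncSetAttribute opt-in
    // below. The grid launches one CTA per (kTM x kTN) tile of C.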
using KeTraits = GemmTraits; 141 | 142 | static constexpr int smem_size = 143 | std::max(kTK * (kTN + kTM), kTM * kTN) * sizeof(Element); 144 | 145 | auto kernel = &gemm_kernel; 146 | 147 | // maximal statically allocated smem per block 148 | const int kMaxSmemPerBlock = 48 * 1024; 149 | if (smem_size > kMaxSmemPerBlock) { 150 | cudaFuncSetAttribute( 151 | kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); 152 | } 153 | 154 | const int block_m = (kM + kTM - 1) / kTM; 155 | const int block_n = (kN + kTN - 1) / kTN; 156 | 157 | const int kThreads = KeTraits::kThreads; 158 | 159 | dim3 gridDim(block_m, block_n); 160 | dim3 blockDim(kThreads, 1, 1); 161 | 162 | kernel<<>>(dA, dB, dC); 163 | } 164 | -------------------------------------------------------------------------------- /benchs/python/gemm/cutlass/cutlass_gemm.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "utils/cpp/cuda_utils.cuh" 4 | #include "utils/cpp/cutlass/copy.cuh" 5 | #include "utils/cpp/cutlass/traits_base.cuh" 6 | 7 | #include 8 | 9 | namespace benchmarks { 10 | namespace cutlass_wrapper { 11 | 12 | using namespace cute; 13 | 14 | template > 19 | struct GemmTraits : public Base { 20 | using Element = Element_; 21 | 22 | static_assert(kTM % kWarpPerRow == 0, 23 | "the M dimension of the CTA tile should be divisible by the " 24 | "number of warps along that that dimension."); 25 | static_assert(kTN % kWarpPerCol == 0, 26 | "the N dimension of the CTA tile should be divisible by the " 27 | "number of warps along that that dimension."); 28 | 29 | // declare global to shared memory copy layout. 30 | using GmemLayoutA = Layout, Int>, Stride, _1>>; 31 | using GmemLayoutB = Layout, Int>, Stride, _1>>; 32 | using GmemLayoutC = Layout, Int>, Stride, _1>>; 33 | 34 | using TiledMma = 35 | TiledMMA, // for ampere 36 | Layout, Int, _1>>, 37 | Tile, Int<16 * kWarpPerCol>, _16>>; 38 | 39 | static constexpr int kThreads = size(TiledMma{}); 40 | static_assert(kThreads == kWarpPerRow * kWarpPerCol * 32); 41 | 42 | static constexpr int kNumPerAccess = Base::kNumPerAccess; 43 | static constexpr int kThreadsPerCol = CeilDiv; 44 | static constexpr int kThreadsPerRow = CeilDiv; 45 | 46 | using SmemLayoutAtom = decltype(composition( 47 | Swizzle<2, 3, 3>{}, Layout>, 48 | Stride, _1>>{})); 49 | 50 | using SmemLayoutA = 51 | decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); 52 | using SmemLayoutB = 53 | decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); 54 | using SmemLayoutC = 55 | decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); 56 | 57 | #ifdef CP_ASYNC_SM80_ENABLED 58 | using CopyInstG2S = 59 | Copy_Atom, Element>; 60 | #else 61 | using CopyInstG2S = Copy_Atom; 62 | #endif 63 | 64 | using TiledCopyG2S = decltype(make_tiled_copy( 65 | CopyInstG2S{}, 66 | Layout, Int>, 67 | Stride, _1>>{}, 68 | Layout>>{})); 69 | 70 | using TiledCopyS2G = decltype(make_tiled_copy( 71 | Copy_Atom{}, 72 | Layout, Int>, 73 | Stride, _1>>{}, 74 | Layout>>{})); 75 | using StoreC_R2S = R2SCopy2D; 76 | }; 77 | 78 | template 80 | __global__ void gemm_kernel(const Element* dA, const Element* dB, Element* dC) { 81 | extern __shared__ __align__(sizeof(double)) unsigned char buf_[]; 82 | auto* buf = reinterpret_cast(buf_); 83 | 84 | // Advance to the global data tile to the current CTA. 
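    // blockIdx.x selects a kTM-row strip of A (stored M x K, K-contiguous),
    // blockIdx.y selects a kTN-row strip of B (stored N x K, K-contiguous),
    // and together they address the (kTM x kTN) tile of the row-major C.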
85 | Element* gA_ptr = const_cast(dA) + blockIdx.x * kK * kTM; 86 | Element* gB_ptr = const_cast(dB) + blockIdx.y * kK * kTN; 87 | Element* gC_ptr = dC + blockIdx.x * kTM * kN + blockIdx.y * kTN; 88 | 89 | // pointers to shared memory tiles 90 | Element* sA_ptr = buf; 91 | Element* sB_ptr = buf + kTM * kTK; 92 | Element* sC_ptr = buf; 93 | 94 | typename KeTraits::TiledMma mma; 95 | typename KeTraits::TiledCopyG2S tiled_copy; 96 | 97 | auto rA = make_s2rA(sA_ptr, typename KeTraits::SmemLayoutA{}, mma); 98 | auto rB = make_s2rB(sB_ptr, typename KeTraits::SmemLayoutB{}, mma); 99 | auto acc = get_acc(mma); 100 | 101 | for (int k = 0; k < kK; k += kTK) { 102 | copy_tile_g2s(gA_ptr, sA_ptr, typename KeTraits::GmemLayoutA{}, 103 | typename KeTraits::SmemLayoutA{}, tiled_copy); 104 | copy_tile_g2s(gB_ptr, sB_ptr, typename KeTraits::GmemLayoutB{}, 105 | typename KeTraits::SmemLayoutB{}, tiled_copy); 106 | __copy_async(); 107 | __syncthreads(); 108 | 109 | for (int i = 0; i < rA.get_iters(); ++i) { 110 | rA.copy(i); // load A register tile from shared memory 111 | rB.copy(i); // load B register tile from shared memory 112 | 113 | gemm(mma, rA[i], rB[i], acc); 114 | } 115 | __syncthreads(); 116 | 117 | gA_ptr += kTK; 118 | gB_ptr += kTK; 119 | } 120 | 121 | typename KeTraits::StoreC_R2S sC; // declare register to shared store plan 122 | sC.copy(acc, buf); // store register tile to shared memory 123 | __syncthreads(); 124 | 125 | // store shared memory tile to global memory 126 | copy_tile_s2g(sC_ptr, gC_ptr, typename KeTraits::SmemLayoutC{}, 127 | typename KeTraits::GmemLayoutC{}, 128 | typename KeTraits::TiledCopyS2G{}); 129 | } 130 | } // namespace cutlass_wrapper 131 | } // namespace benchmarks 132 | 133 | template 137 | void cute_gemm(const Element* dA, const Element* dB, Element* dC) { 138 | using namespace benchmarks::cutlass_wrapper; 139 | 140 | using KeTraits = GemmTraits; 142 | 143 | static constexpr int smem_size = 144 | std::max(kTK * (kTN + kTM), kTM * kTN) * sizeof(Element); 145 | 146 | auto kernel = &gemm_kernel; 147 | 148 | // maximal statically allocated smem per block 149 | const int kMaxSmemPerBlock = 48 * 1024; 150 | if (smem_size > kMaxSmemPerBlock) { 151 | cudaFuncSetAttribute( 152 | kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); 153 | } 154 | 155 | const int block_m = (kM + kTM - 1) / kTM; 156 | const int block_n = (kN + kTN - 1) / kTN; 157 | 158 | const int kThreads = KeTraits::kThreads; 159 | 160 | dim3 gridDim(block_m, block_n); 161 | dim3 blockDim(kThreads, 1, 1); 162 | 163 | kernel<<>>(dA, dB, dC); 164 | } 165 | -------------------------------------------------------------------------------- /benchs/cpp/gemm/bench.cu: -------------------------------------------------------------------------------- 1 | #include "util.cuh" 2 | #include "utils/cpp/cuda_info.cuh" 3 | 4 | #include 5 | 6 | #include 7 | #include 8 | 9 | //// =============== Test Config=============== //// 10 | static const int kWarpPerRow = 2; 11 | static const int kWarpPerCol = 4; 12 | using WholeShape = GemmShape<4096, 4096, 4096>; 13 | using CtaTileShape = GemmShape<128, 128, 64>; 14 | using WarpLayout = tl::RowMajor; 15 | static constexpr int kRK = 16; 16 | 17 | void run_test(std::ofstream& fout) { 18 | //// =============== Declaration =============== //// 19 | static constexpr int kM = dim_size<0, WholeShape>; 20 | static constexpr int kN = dim_size<1, WholeShape>; 21 | static constexpr int kK = dim_size<2, WholeShape>; 22 | 23 | static constexpr int kTM = dim_size<0, CtaTileShape>; 24 | 
static constexpr int kTN = dim_size<1, CtaTileShape>; 25 | static constexpr int kTK = dim_size<2, CtaTileShape>; 26 | 27 | using InType = cutlass::half_t; 28 | using AccType = float; 29 | 30 | using Config = KeGemmTraits; 32 | auto tiledcuda_gemm = 33 | &gemm; 43 | 44 | static constexpr int inputs = kTK * (kTN + kTM) * sizeof(InType); 45 | static constexpr int accumulators = kTM * kTN * sizeof(AccType); 46 | static constexpr int smem_size = 47 | inputs > accumulators ? inputs : accumulators; 48 | 49 | const int kMaxSmemPerBlock = 48 * 1024; 50 | if (smem_size > kMaxSmemPerBlock) { 51 | cudaFuncSetAttribute(tiledcuda_gemm, 52 | cudaFuncAttributeMaxDynamicSharedMemorySize, 53 | smem_size); 54 | } 55 | int block_x = benchmarks::CeilDiv; 56 | int block_y = benchmarks::CeilDiv; 57 | dim3 dim_grid(block_x, block_y, 1); 58 | dim3 dim_block(Config::kThreads, 1, 1); 59 | 60 | auto cutlass_gemm = 61 | &cute_gemm; 62 | 63 | //// =============== Prepare data =============== //// 64 | // input matrix A 65 | thrust::host_vector h_a(kM * kK); 66 | for (int i = 0; i < h_a.size(); ++i) 67 | h_a[i] = static_cast(rand_float()); 68 | thrust::device_vector d_a = h_a; 69 | const InType* dA = thrust::raw_pointer_cast(d_a.data()); 70 | 71 | // input matrix B 72 | thrust::host_vector h_b(kK * kN); 73 | for (int i = 0; i < h_b.size(); ++i) 74 | h_b[i] = static_cast(rand_float()); 75 | thrust::device_vector d_b = h_b; 76 | const InType* dB = thrust::raw_pointer_cast(d_b.data()); 77 | 78 | // output matrix C for cutlass GEMM kernel 79 | thrust::device_vector d_c(kM * kN); 80 | thrust::fill(d_c.begin(), d_c.end(), static_cast(0.)); 81 | InType* dC = thrust::raw_pointer_cast(d_c.data()); 82 | 83 | cutlass_gemm(dA, dB, dC); 84 | cudaDeviceSynchronize(); 85 | thrust::host_vector h_c = d_c; 86 | 87 | // tiled cuda gemm kernel 88 | thrust::device_vector d_c2(kM * kN); 89 | thrust::fill(d_c2.begin(), d_c2.end(), static_cast(0.)); 90 | AccType* dC2 = thrust::raw_pointer_cast(d_c2.data()); 91 | 92 | tiledcuda_gemm<<>>(dA, dB, dC2); 93 | cudaDeviceSynchronize(); 94 | thrust::host_vector h_c2 = d_c2; 95 | 96 | bool passed1 = check_results( 97 | thrust::raw_pointer_cast(h_c.data()) /*cutlass*/, 98 | thrust::raw_pointer_cast(h_c2.data()) /*tiled cuda*/, kM * kN); 99 | 100 | // cublas 101 | const __half* dA2 = reinterpret_cast(dA); 102 | const __half* dB2 = reinterpret_cast(dB); 103 | thrust::device_vector<__half> d_c3(kM * kN); 104 | thrust::fill(d_c3.begin(), d_c3.end(), static_cast<__half>(0.)); 105 | 106 | cublas_hgemm(kM, kN, kK, dA2, dB2, thrust::raw_pointer_cast(d_c3.data()), 107 | false /*timeit*/); 108 | thrust::host_vector<__half> h_c3 = d_c3; 109 | 110 | bool passed2 = check_results( 111 | thrust::raw_pointer_cast(h_c3.data()) /*cutlass*/, 112 | thrust::raw_pointer_cast(h_c2.data()) /*tiled cuda*/, kM * kN); 113 | 114 | if (!(passed1 && passed2)) { 115 | std::cerr << "Test failed" << std::endl; 116 | return; 117 | } 118 | 119 | //// =============== Timing =============== //// 120 | thrust::fill(d_c.begin(), d_c.end(), static_cast(0.)); 121 | thrust::fill(d_c2.begin(), d_c2.end(), static_cast(0.)); 122 | thrust::fill(d_c3.begin(), d_c3.end(), static_cast<__half>(0.)); 123 | // cublas 124 | float cublas_time = 125 | cublas_hgemm(kM, kN, kK, dA2, dB2, 126 | thrust::raw_pointer_cast(d_c3.data()), true /*timeit*/); 127 | 128 | const int warm_up = 5; 129 | const int iters = 20; 130 | for (int i = 0; i < warm_up; ++i) { 131 | cutlass_gemm(dA, dB, dC); 132 | tiledcuda_gemm<<>>(dA, dB, dC2); 133 | } 134 | 
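    // The warm-up launches above keep one-time launch and caching overheads
    // out of the timed loops; each kernel is then run `iters` times and the
    // reported figure is the per-launch average.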
cudaDeviceSynchronize(); 135 | 136 | CudaTimer timer; 137 | timer.start(); 138 | for (int i = 0; i < iters; ++i) { 139 | cutlass_gemm(dA, dB, dC); 140 | } 141 | cudaDeviceSynchronize(); 142 | float cutlass_time = timer.stop() / iters; 143 | 144 | timer.start(); 145 | for (int i = 0; i < iters; ++i) { 146 | tiledcuda_gemm<<>>(dA, dB, dC2); 147 | } 148 | cudaDeviceSynchronize(); 149 | float tiledcuda_time = timer.stop() / iters; 150 | 151 | fout << "[" << kM << ", " << kN << ", " << kK << "]\t[" << kTM << ", " 152 | << kTN << ", " << kTK << "]\t" << kRK << "\t[" << kWarpPerRow << ", " 153 | << kWarpPerCol << "]\t" << cublas_time << "\t" << cutlass_time << " (" 154 | << std::setprecision(2) << cutlass_time / cublas_time << ")" 155 | << "\t" << std::setprecision(4) << tiledcuda_time << " (" 156 | << std::setprecision(2) << tiledcuda_time / cublas_time << ")" 157 | << std::endl; 158 | } 159 | 160 | int main() { 161 | std::ofstream fout; 162 | fout.setf(std::ios::fixed); 163 | fout.precision(4); 164 | 165 | auto dev_name = benchmarks::get_device_name(); 166 | std::stringstream file_name; 167 | file_name << "figures/bench_" << dev_name << "_gemm.tsv"; 168 | fout.open(file_name.str(), std::ios::out); 169 | 170 | fout << "[M, N, K]\t[kTM, kTN, kTK]\tkRK\tWarp Layout\t" 171 | "cuBLAS(ms)\tcutlass(ms)\tTiledCUDA(ms)" 172 | << std::endl; 173 | 174 | run_test(fout); 175 | 176 | return 0; 177 | } 178 | -------------------------------------------------------------------------------- /utils/cpp/cutlass/copy.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | namespace benchmarks { 7 | namespace cutlass_wrapper { 8 | 9 | using namespace cute; 10 | 11 | #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) 12 | #define CP_ASYNC_SM80_ENABLED 13 | #endif 14 | 15 | template 16 | DEVICE void wait_group() { 17 | #if defined(CP_ASYNC_SM80_ENABLED) 18 | asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); 19 | #endif 20 | } 21 | 22 | DEVICE void commit_copy_group() { 23 | #if defined(CP_ASYNC_SM80_ENABLED) 24 | cute::cp_async_fence(); 25 | #endif 26 | } 27 | 28 | DEVICE void __copy_async() { 29 | commit_copy_group(); 30 | wait_group<0>(); 31 | } 32 | 33 | // Copy a 2d data tile from global memory to shared memory 34 | template 36 | DEVICE void copy_tile_g2s(const Element* src_data, Element* dst_data, 37 | SrcLayout src_layout, DstLayout dst_layout, 38 | TiledCopy tiled_copy) { 39 | int tid = threadIdx.x; 40 | 41 | auto gtile = make_tensor(make_gmem_ptr(src_data), src_layout); 42 | auto stile = make_tensor(make_smem_ptr(dst_data), dst_layout); 43 | 44 | auto loader = tiled_copy.get_thread_slice(tid); 45 | 46 | auto src = loader.partition_S(gtile); 47 | auto dst = loader.partition_D(stile); 48 | 49 | #pragma unroll 50 | for (int i = 0; i < int(size<1>(src)); ++i) 51 | #pragma unroll 52 | for (int j = 0; j < int(size<2>(src)); ++j) 53 | cute::copy(tiled_copy, src(_, i, j), dst(_, i, j)); 54 | } 55 | 56 | // Copy a tensor from shared memory to global memory 57 | template 59 | DEVICE void copy_tile_s2g(const Element* src_data, Element* dst_data, 60 | SrcLayout src_layout, DstLayout dst_layout, 61 | TiledCopy tiled_copy) { 62 | int tid = threadIdx.x; 63 | 64 | auto stile = make_tensor(make_smem_ptr(src_data), src_layout); 65 | auto gtile = make_tensor(make_gmem_ptr(dst_data), dst_layout); 66 | 67 | auto loader = tiled_copy.get_thread_slice(tid); 68 | 69 | auto src = loader.partition_S(stile); 70 | auto dst = 
loader.partition_D(gtile); 71 | 72 | #pragma unroll 73 | for (int i = 0; i < int(size<1>(src)); ++i) 74 | #pragma unroll 75 | for (int j = 0; j < int(size<2>(src)); ++j) 76 | cute::copy(tiled_copy, src(_, i, j), dst(_, i, j)); 77 | } 78 | 79 | template 80 | struct R2SCopy2D { 81 | using TiledMma = TiledMma_; 82 | using Dstlayout_ = DstLayout; 83 | using CopyAtom = Copy_Atom; 84 | 85 | public: 86 | template 87 | __device__ void copy(cute::Tensor const& acc, 88 | Element* dst_data) { 89 | int tid = threadIdx.x; 90 | 91 | // FIXME(haruhi): This implementation is specifically designed 92 | // for tcu WMMA and assumes that the ACC value has a 93 | // floating-point precision. The code converts the ACC value 94 | // to half-precision. 95 | auto src_tensor = convert_type(acc); 96 | auto dst_tensor = make_tensor(make_smem_ptr(dst_data), DstLayout{}); 97 | 98 | auto tiled_copy = make_tiled_copy_C(CopyAtom{}, TiledMma{}); 99 | auto thrd_copy = tiled_copy.get_thread_slice(tid); 100 | 101 | auto src = thrd_copy.retile_S(src_tensor); 102 | auto dst = thrd_copy.partition_D(dst_tensor); 103 | cute::copy(tiled_copy, src, dst); 104 | } 105 | 106 | private: 107 | template 108 | DEVICE auto convert_type(cute::Tensor const& tensor) { 109 | using From_type = typename Engine::value_type; 110 | constexpr int numel = decltype(size(tensor))::value; 111 | cutlass::NumericArrayConverter convert_op; 112 | // HACK: this requires tensor to be "contiguous" 113 | auto frag = convert_op( 114 | *reinterpret_cast*>( 115 | tensor.data())); 116 | return make_tensor(make_rmem_ptr(&frag), tensor.layout()); 117 | } 118 | }; 119 | 120 | template 122 | struct Shm2RegLoad { 123 | public: 124 | DEVICE Shm2RegLoad(TiledCopy& copy, const STensor& src, DTensor& dst, 125 | DTensorView& dst_view) 126 | : tiled_copy_(copy), src_(src), dst_(dst), dst_view_(dst_view) {} 127 | 128 | DEVICE void copy(int pos) { 129 | cute::copy(tiled_copy_, src_(_, _, pos), dst_view_(_, _, pos)); 130 | } 131 | 132 | DEVICE int get_iters() { return size<2>(dst_); } 133 | 134 | DEVICE const auto operator[](int idx) { return dst_(_, _, idx); } 135 | 136 | private: 137 | TiledCopy& tiled_copy_; 138 | const STensor& src_; 139 | DTensor& dst_; 140 | DTensorView& dst_view_; 141 | }; 142 | 143 | template 144 | DEVICE auto get_acc(const TiledMma& tiled_mma) { 145 | auto acc = partition_fragment_C(tiled_mma, Shape, Int>{}); 146 | clear(acc); 147 | 148 | return acc; 149 | } 150 | 151 | template 152 | DEVICE auto make_s2rA(const Element* data, const Layout& layout, 153 | const TiledMma& tiled_mma) { 154 | int tid = threadIdx.x; 155 | 156 | auto tensor = cute::make_tensor(make_smem_ptr(data), layout); 157 | 158 | using SmemLoadAtom = Copy_Atom; 159 | auto tiled_copy = make_tiled_copy_A(SmemLoadAtom{}, tiled_mma); 160 | 161 | auto thrd_copy = tiled_copy.get_thread_slice(tid); 162 | auto src = thrd_copy.partition_S(tensor); 163 | 164 | // partition register 165 | auto thr_mma = tiled_mma.get_thread_slice(tid); 166 | auto dst = thr_mma.partition_fragment_A(tensor); 167 | auto dst_view = thrd_copy.retile_D(dst); 168 | 169 | Shm2RegLoad loader(tiled_copy, src, dst, dst_view); 170 | return loader; 171 | } 172 | 173 | // FIXIME(haruhi): the current implementation is for fast experiment, 174 | // it is coupled shared memory layout with the register layout 175 | template 176 | DEVICE auto make_s2rB(const Element* data, const Layout& layout, 177 | const TiledMma& tiled_mma) { 178 | int tid = threadIdx.x; 179 | 180 | using SmemLoadAtom = Copy_Atom; 181 | auto tiled_copy = 
make_tiled_copy_B(SmemLoadAtom{}, tiled_mma); 182 | auto thrd_copy = tiled_copy.get_thread_slice(tid); 183 | 184 | auto tensor = make_tensor(make_smem_ptr(data), layout); 185 | auto src = thrd_copy.partition_S(tensor); 186 | 187 | // partition register 188 | auto thr_mma = tiled_mma.get_thread_slice(tid); 189 | auto dst = thr_mma.partition_fragment_B(tensor); 190 | auto dst_view = thrd_copy.retile_D(dst); 191 | 192 | Shm2RegLoad loader(tiled_copy, src, dst, dst_view); 193 | return loader; 194 | } 195 | 196 | } // namespace cutlass_wrapper 197 | } // namespace benchmarks 198 | -------------------------------------------------------------------------------- /benchs/python/batched_gemm/cutlass/cutlass_batched_gemm.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "utils/cpp/cuda_utils.cuh" 4 | #include "utils/cpp/cutlass/copy.cuh" 5 | #include "utils/cpp/cutlass/traits_base.cuh" 6 | 7 | #include 8 | 9 | namespace benchmarks { 10 | namespace cutlass_wrapper { 11 | 12 | using namespace cute; 13 | 14 | template > 19 | struct BatchedGemmTraits : public Base { 20 | using Element = Element_; 21 | 22 | static_assert(kTM % kWarpPerRow == 0, 23 | "the M dimension of the CTA tile should be divisible by the " 24 | "number of warps along that that dimension."); 25 | static_assert(kTN % kWarpPerCol == 0, 26 | "the N dimension of the CTA tile should be divisible by the " 27 | "number of warps along that that dimension."); 28 | 29 | // declare global to shared memory copy layout. 30 | using GmemLayoutA = Layout, Int>, Stride, _1>>; 31 | using GmemLayoutB = Layout, Int>, Stride, _1>>; 32 | using GmemLayoutC = Layout, Int>, Stride, _1>>; 33 | 34 | // TODO(haruhi): The current implementation uses ldmatrix.x4 35 | // instruction which requires the TileMMA configuration to be 36 | // fixed as follows. Make it able to be tuned by policy in 37 | // future implementation. 38 | using TiledMma = 39 | TiledMMA, // for ampere 40 | Layout, Int, _1>>, 41 | Tile, Int<16 * kWarpPerCol>, _16>>; 42 | static constexpr int kThreads = size(TiledMma{}); 43 | static_assert(kThreads == kWarpPerRow * kWarpPerCol * 32); 44 | 45 | static constexpr int kNumPerAccess = Base::kNumPerAccess; 46 | using SmemLayoutAtom = 47 | decltype(composition(Swizzle<2, 3, 3>{}, 48 | 49 | Layout>, 50 | Stride, _1>>{})); 51 | 52 | static constexpr int kThreadsPerCol = CeilDiv; 53 | static constexpr int kThreadsPerRow = CeilDiv; 54 | #ifdef CP_ASYNC_SM80_ENABLED 55 | using CopyInstG2S = 56 | Copy_Atom, Element>; 57 | #else 58 | using CopyInstG2S = Copy_Atom; 59 | #endif 60 | using TiledCopyG2S = decltype(make_tiled_copy( 61 | CopyInstG2S{}, 62 | Layout, Int>, 63 | Stride, _1>>{}, 64 | Layout>>{})); 65 | 66 | using SmemLayoutA = 67 | decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); 68 | using SmemLayoutB = 69 | decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); 70 | 71 | using TiledCopyS2G = decltype(make_tiled_copy( 72 | Copy_Atom{}, 73 | Layout, Int>, 74 | Stride, _1>>{}, 75 | Layout>>{})); 76 | using SmemLayoutC = 77 | decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); 78 | using StoreC_R2S = R2SCopy2D; 79 | }; 80 | 81 | template 83 | __global__ void batched_gemm_kernel(const Element* dA, const Element* dB, 84 | Element* dC) { 85 | extern __shared__ __align__(sizeof(double)) unsigned char buf_[]; 86 | auto* buf = reinterpret_cast(buf_); 87 | 88 | // Advance to the global data tile to the current CTA. 
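    // blockIdx.x / blockIdx.y pick the (kTM x kTN) output tile exactly as in
    // the plain GEMM kernel, while the blockIdx.z * kK * kM, kK * kN and
    // kM * kN terms step to the current batch's A, B and C, i.e. the batched
    // operands are laid out back-to-back in memory.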
89 | Element* gA_ptr = 90 | const_cast(dA) + blockIdx.x * kK * kTM + blockIdx.z * kK * kM; 91 | Element* gB_ptr = 92 | const_cast(dB) + blockIdx.y * kK * kTN + blockIdx.z * kK * kN; 93 | Element* gC_ptr = 94 | dC + blockIdx.x * kTM * kN + blockIdx.y * kTN + blockIdx.z * kM * kN; 95 | 96 | // pointers to shared memory tiles 97 | Element* sA_ptr = buf; 98 | Element* sB_ptr = buf + kTM * kTK; 99 | Element* sC_ptr = buf; 100 | 101 | typename KeTraits::TiledMma mma; 102 | typename KeTraits::TiledCopyG2S tiled_copy; 103 | 104 | auto rA = make_s2rA(sA_ptr, typename KeTraits::SmemLayoutA{}, mma); 105 | auto rB = make_s2rB(sB_ptr, typename KeTraits::SmemLayoutB{}, mma); 106 | auto acc = get_acc(mma); 107 | 108 | for (int k = 0; k < kK; k += kTK) { 109 | copy_tile_g2s(gA_ptr, sA_ptr, typename KeTraits::GmemLayoutA{}, 110 | typename KeTraits::SmemLayoutA{}, tiled_copy); 111 | copy_tile_g2s(gB_ptr, sB_ptr, typename KeTraits::GmemLayoutB{}, 112 | typename KeTraits::SmemLayoutB{}, tiled_copy); 113 | __copy_async(); 114 | __syncthreads(); 115 | 116 | for (int i = 0; i < rA.get_iters(); ++i) { 117 | rA.copy(i); // load A register tile from shared memory 118 | rB.copy(i); // load B register tile from shared memory 119 | 120 | gemm(mma, rA[i], rB[i], acc); 121 | } 122 | __syncthreads(); 123 | 124 | gA_ptr += kTK; 125 | gB_ptr += kTK; 126 | } 127 | 128 | typename KeTraits::StoreC_R2S sC; // declare register to shared store plan 129 | sC.copy(acc, buf); // store register tile to shared memory 130 | __syncthreads(); 131 | 132 | // store shared memory tile to global memory 133 | copy_tile_s2g(sC_ptr, gC_ptr, typename KeTraits::SmemLayoutC{}, 134 | typename KeTraits::GmemLayoutC{}, 135 | typename KeTraits::TiledCopyS2G{}); 136 | } 137 | } // namespace cutlass_wrapper 138 | } // namespace benchmarks 139 | 140 | template 145 | void cute_batched_gemm(const Element* dA, const Element* dB, Element* dC) { 146 | using namespace benchmarks::cutlass_wrapper; 147 | 148 | using KeTraits = BatchedGemmTraits; 150 | 151 | static constexpr int smem_size = 152 | std::max(kTK * (kTN + kTM), kTM * kTN) * sizeof(Element); 153 | 154 | auto kernel = 155 | &batched_gemm_kernel; 156 | 157 | // maximal statically allocated smem per block 158 | const int kMaxSmemPerBlock = 48 * 1024; 159 | if (smem_size > kMaxSmemPerBlock) { 160 | cudaFuncSetAttribute( 161 | kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); 162 | } 163 | 164 | const int block_m = (kM + kTM - 1) / kTM; 165 | const int block_n = (kN + kTN - 1) / kTN; 166 | 167 | const int kThreads = KeTraits::kThreads; 168 | 169 | dim3 gridDim(block_m, block_n, BatchCount); 170 | dim3 blockDim(kThreads, 1, 1); 171 | 172 | kernel<<>>(dA, dB, dC); 173 | 174 | cudaDeviceSynchronize(); 175 | } 176 | -------------------------------------------------------------------------------- /benchs/python/gemm/bench.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from typing import Tuple 4 | import sys 5 | import os 6 | import csv 7 | 8 | cutlass_dir = os.path.join(os.path.dirname(__file__), 'cutlass') 9 | sys.path.insert(0, cutlass_dir) 10 | 11 | tiledcuda_dir = os.path.join(os.path.dirname(__file__), 'tiledcuda') 12 | sys.path.insert(0, tiledcuda_dir) 13 | 14 | from cutlass.gemm import gemm_func as cutlass_gemm 15 | from tiledcuda.gemm import gemm_func as tiledcuda_gemm 16 | from cuBLAS import cublas_gemm 17 | 18 | def run_cublas_unittest( 19 | a: Tensor, 20 | b: Tensor, 21 | c: Tensor, 22 | M: int, 23 
| N: int, 24 | K: int, 25 | debug_print=False, 26 | epsilon: float = 5e-2 27 | ): 28 | time = torch.zeros(1, device=torch.device("cpu"), dtype=torch.float32) 29 | cublas_gemm(M, N, K, a, b, c, time) 30 | ref_c = a @ b.t() 31 | 32 | if debug_print: 33 | print("Result:") 34 | print(c) 35 | 36 | print("\nReference:") 37 | print(ref_c) 38 | 39 | avg_diff = (torch.sum(torch.abs(ref_c - c) / (M * N))).item() 40 | 41 | if avg_diff > epsilon: 42 | return False 43 | else: 44 | return True 45 | 46 | def run_tiledcuda_unittest( 47 | a: Tensor, 48 | b: Tensor, 49 | c: Tensor, 50 | M: int, 51 | N: int, 52 | K: int, 53 | kTM: int, 54 | kTN: int, 55 | kTK: int, 56 | kRK: int, 57 | warp_layout: Tuple, 58 | debug_print=False, 59 | epsilon: float = 5e-2 60 | ): 61 | tiledcuda_gemm(a, b, c, M, N, K, kTM, kTN, kTK, kRK, *warp_layout) 62 | ref_c = a @ b.t() 63 | 64 | if debug_print: 65 | print("Result:") 66 | print(c) 67 | 68 | print("\nReference:") 69 | print(ref_c) 70 | 71 | avg_diff = (torch.sum(torch.abs(ref_c - c.half()) / (M * N))).item() 72 | 73 | if avg_diff > epsilon: 74 | return False 75 | else: 76 | return True 77 | 78 | def run_cutlass_unittest( 79 | a: Tensor, 80 | b: Tensor, 81 | c: Tensor, 82 | M: int, 83 | N: int, 84 | K: int, 85 | kTM: int, 86 | kTN: int, 87 | kTK: int, 88 | warp_layout: Tuple, 89 | debug_print=False, 90 | epsilon: float = 5e-2 91 | ): 92 | cutlass_gemm(a, b, c, M, N, K, kTM, kTN, kTK, *warp_layout) 93 | ref_c = a @ b.t() 94 | 95 | if debug_print: 96 | print("Result:") 97 | print(c) 98 | 99 | print("\nReference:") 100 | print(ref_c) 101 | 102 | avg_diff = (torch.sum(torch.abs(ref_c - c) / (M * N))).item() 103 | 104 | if avg_diff > epsilon: 105 | return False 106 | else: 107 | return True 108 | 109 | def run_cublas_bench( 110 | a: Tensor, 111 | b: Tensor, 112 | c: Tensor, 113 | M: int, 114 | N: int, 115 | K: int, 116 | time: Tensor 117 | ): 118 | if run_cublas_unittest(a, b, c, M, N, K): 119 | pass 120 | else: 121 | print("Run cuBLAS unittest failed") 122 | return float("NaN") 123 | 124 | iters = 50 125 | warmup = 10 126 | 127 | cublas_gemm(M, N, K, a, b, c, time, iters, warmup) 128 | 129 | return time.item() 130 | 131 | def run_cutlass_bench( 132 | a: Tensor, 133 | b: Tensor, 134 | c: Tensor, 135 | M: int, 136 | N: int, 137 | K: int, 138 | kTM: int, 139 | kTN: int, 140 | kTK: int, 141 | warp_layout: Tuple, 142 | ): 143 | if run_cutlass_unittest(a, b, c, M, N, K, kTM, kTN, kTK, warp_layout): 144 | pass 145 | else: 146 | print("Run Cutlass unittest failed") 147 | return float("NaN") 148 | 149 | warmup = 10 150 | iters = 50 151 | 152 | for _ in range(warmup): 153 | cutlass_gemm(a, b, c, M, N, K, kTM, kTN, kTK, *warp_layout) 154 | 155 | 156 | start_event = torch.cuda.Event(enable_timing=True) 157 | end_event = torch.cuda.Event(enable_timing=True) 158 | 159 | start_event.record() 160 | for _ in range(iters): 161 | cutlass_gemm(a, b, c, M, N, K, kTM, kTN, kTK, *warp_layout) 162 | end_event.record() 163 | torch.cuda.synchronize() 164 | 165 | time = start_event.elapsed_time(end_event) / iters 166 | 167 | return time 168 | 169 | 170 | def run_tiledcuda_bench( 171 | a: Tensor, 172 | b: Tensor, 173 | c: Tensor, 174 | M: int, 175 | N: int, 176 | K: int, 177 | kTM: int, 178 | kTN: int, 179 | kTK: int, 180 | kRK: int, 181 | warp_layout: Tuple, 182 | ): 183 | if run_tiledcuda_unittest(a, b, c, M, N, K, kTM, kTN, kTK, kRK, warp_layout): 184 | pass 185 | else: 186 | print("Run TiledCUDA unittest failed") 187 | return float('NaN') 188 | 189 | warmup = 10 190 | iters = 50 191 | 192 | for _ in 
range(warmup): 193 | tiledcuda_gemm(a, b, c, M, N, K, kTM, kTN, kTK, kRK, *warp_layout) 194 | 195 | start_event = torch.cuda.Event(enable_timing=True) 196 | end_event = torch.cuda.Event(enable_timing=True) 197 | 198 | start_event.record() 199 | for _ in range(iters): 200 | tiledcuda_gemm(a, b, c, M, N, K, kTM, kTN, kTK, kRK, *warp_layout) 201 | end_event.record() 202 | torch.cuda.synchronize() 203 | 204 | time = start_event.elapsed_time(end_event) / iters 205 | 206 | return time 207 | 208 | def run_bench( 209 | M: int, 210 | N: int, 211 | K: int, 212 | kTM: int, 213 | kTN: int, 214 | kTK: int, 215 | kRK: int, 216 | warp_layout: Tuple, 217 | record_csv = None 218 | ): 219 | torch.manual_seed(1234) 220 | 221 | a = torch.randn(M, K, device = 'cuda', dtype = torch.float16) 222 | b = torch.randn(N, K, device = 'cuda', dtype = torch.float16) 223 | c = torch.zeros(M, N, device = 'cuda', dtype = torch.float32) 224 | half_c = torch.zeros(M, N, device = 'cuda', dtype = torch.float16) 225 | 226 | cublas_time_tensor = torch.zeros(1, device=torch.device("cpu"), dtype=torch.float32) 227 | 228 | cublas_time = run_cublas_bench(a, b, half_c, M, N, K, cublas_time_tensor) 229 | cutlass_time = run_cutlass_bench(a, b, half_c, M, N, K, kTM, kTN, kTK, warp_layout) 230 | tiledcuda_time = run_tiledcuda_bench(a, b, c, M, N, K, kTM, kTN, kTK, kRK, warp_layout) 231 | 232 | print("(M, N, K) (kTM, kTN, kTK)") 233 | print("({}, {}, {}) ({}, {}, {})".format(M, N, K, kTM, kTN, kTK)) 234 | print("cublas_time: {:.4f} ms, cutlass_time: {:.4f} ms, tiledcuda_time: {:.4f} ms".format(cublas_time, cutlass_time, tiledcuda_time)) 235 | 236 | csv.writer(record_csv).writerow([M, N, K, kTM, kTN, kTK, "{:.4f}".format(cublas_time), "{:.4f}".format(cutlass_time), "{:.4f}".format(tiledcuda_time)]) 237 | 238 | 239 | if __name__ == "__main__": 240 | kRK = 32 241 | 242 | device_id = torch.cuda.current_device() 243 | device_name = torch.cuda.get_device_name(device_id).replace(" ", "_") 244 | 245 | record = 'gemm_bench_{}.csv'.format(device_name) 246 | record_csv = open(record, 'w', newline='') 247 | 248 | csv.writer(record_csv).writerow(["M", "N", "K", "kTM", "kTN", "kTK", "cuBLAS(ms)", "Cutlass(ms)", "TiledCUDA(ms)"]) 249 | 250 | run_bench(4096, 4096, 2048, 128, 256, 64, kRK, (2, 2), record_csv) 251 | run_bench(4096, 4096, 2048, 64, 256, 32, kRK, (2, 2), record_csv) 252 | run_bench(4096, 4096, 2048, 128, 128, 32, kRK, (2, 2), record_csv) 253 | run_bench(4096, 4096, 2048, 128, 64, 32, kRK, (2, 2), record_csv) 254 | run_bench(4096, 4096, 2048, 64, 128, 32, kRK, (2, 2), record_csv) 255 | run_bench(4096, 4096, 2048, 128, 32, 32, kRK, (2, 2), record_csv) 256 | run_bench(4096, 4096, 2048, 32, 64, 32, kRK, (2, 2), record_csv) 257 | run_bench(4096, 4096, 2048, 128, 256, 128, kRK, (2, 2), record_csv) 258 | run_bench(4096, 4096, 2048, 256, 128, 128, kRK, (2, 2), record_csv) 259 | # Cutlass RuntimeError: CUDA error: an illegal memory access was encountered. 260 | # run_bench(4096, 4096, 2048, 256, 64, 128, kRK, (2, 2), record_csv) 261 | run_bench(4096, 4096, 2048, 64, 256, 128, kRK, (2, 2), record_csv) 262 | run_bench(4096, 4096, 2048, 128, 128, 128, kRK, (2, 2), record_csv) 263 | run_bench(4096, 4096, 2048, 128, 64, 64, kRK, (2, 2), record_csv) 264 | run_bench(4096, 4096, 2048, 64, 128, 64, kRK, (2, 2), record_csv) 265 | # Cutlass RuntimeError: CUDA error: an illegal memory access was encountered. 
266 | # run_bench(4096, 4096, 2048, 128, 32, 64, kRK, (2, 2), record_csv) 267 | 268 | run_bench(4096, 4096, 32, 64, 32, 32, kRK, (2, 2), record_csv) 269 | 270 | run_bench(4096, 4096, 64, 128, 128, 64, kRK, (2, 2), record_csv) 271 | 272 | 273 | record_csv.close() 274 | -------------------------------------------------------------------------------- /benchs/python/fused_gemm/cutlass/cutlass_fused_gemm.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "utils/cpp/cuda_utils.cuh" 4 | #include "utils/cpp/cutlass/convert.cuh" 5 | #include "utils/cpp/cutlass/copy.cuh" 6 | #include "utils/cpp/cutlass/traits_base.cuh" 7 | 8 | #include 9 | 10 | namespace benchmarks { 11 | namespace cutlass_wrapper { 12 | 13 | using namespace cute; 14 | 15 | template > 20 | struct FusedGemmTraits : public Base { 21 | using Element = Element_; 22 | 23 | static_assert(kTK == kTN && kTN == kTP, 24 | "Fused GEMM requires kTK == kTN == kTP."); 25 | static_assert(kWarpPerCol == 1, 26 | "The Fused GEMM requires a single warp along CTA tile."); 27 | 28 | using GmemLayoutA = Layout, Int>, Stride, _1>>; 29 | using GmemLayoutB = Layout, Int>, Stride, _1>>; 30 | using GmemLayoutC = Layout, Int>, Stride, _1>>; 31 | using GmemLayoutD = Layout, Int>, Stride, _1>>; 32 | 33 | // TODO(haruhi): The current implementation uses ldmatrix.x4 34 | // instruction which requires the TileMMA configuration to be 35 | // fixed as follows. Make it able to be tuned by policy in 36 | // future implementation. 37 | using TiledMma = 38 | TiledMMA, // for ampere 39 | Layout, Int, _1>>, 40 | Tile, Int<16 * kWarpPerCol>, _16>>; 41 | static constexpr int kThreads = size(TiledMma{}); 42 | static_assert(kThreads == kWarpPerRow * kWarpPerCol * 32); 43 | 44 | static constexpr int kNumPerAccess = Base::kNumPerAccess; 45 | static constexpr int kThreadsPerCol = CeilDiv; 46 | static constexpr int kThreadsPerRow = CeilDiv; 47 | 48 | static constexpr int kSwizzle = (kTK == 32 ? 2 : 3); 49 | using SmemLayoutAtom = decltype(composition( 50 | Swizzle{}, 51 | Layout>, Stride, _1>>{})); 52 | 53 | using SmemLayoutA = 54 | decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); 55 | 56 | // The current implementation requires B are laid out in column 57 | // major. a [kTK, kTN] matrix in column major can be interpreted 58 | // as a [kTN, kTK] matrix in row major. 59 | using SmemLayoutB = 60 | decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); 61 | // a [kTN, kTP] matrix in column major fashion, 62 | // can be interpreted as a [kTP, kTN] matrix in row major fashion. 63 | using SmemLayoutC = 64 | decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); 65 | 66 | #ifdef CP_ASYNC_SM80_ENABLED 67 | using CopyInstG2S = 68 | Copy_Atom, Element>; 69 | #else 70 | using CopyInstG2S = Copy_Atom; 71 | #endif 72 | using TiledCopyG2S = decltype(make_tiled_copy( 73 | CopyInstG2S{}, 74 | Layout, Int>, 75 | Stride, _1>>{}, 76 | Layout>>{})); 77 | 78 | using TiledCopyS2G = decltype(make_tiled_copy( 79 | Copy_Atom{}, 80 | Layout, Int>, 81 | Stride, _1>>{}, 82 | Layout>>{})); 83 | using SmemLayoutD = 84 | decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); 85 | 86 | using StoreD_R2S = R2SCopy2D; 87 | }; 88 | 89 | template 92 | __global__ void fused_gemm_kernel(const Element* dA, const Element* dB, 93 | const Element* dC, Element* dD) { 94 | // Advance to the global data tile to the current CTA. 
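// The fused kernel evaluates D = (A @ B) @ C tile by tile: blockIdx.x picks a
// kTM-row tile of A (and of D), while blockIdx.y picks a kTP-column tile of D,
// which is a kTP-row tile of C as stored (C is laid out as a [kP, kN]
// row-major matrix; see the layout comments in FusedGemmTraits above). The kN
// dimension is reduced entirely inside the kernel by the outer n-loop, so it
// needs no grid dimension.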
95 | Element* A = const_cast(dA) + blockIdx.x * (kTM * kK); 96 | Element* B = const_cast(dB); 97 | Element* gC_ptr = const_cast(dC) + blockIdx.y * (kTP * kN); 98 | Element* gD_ptr = dD + blockIdx.x * (kTM * kP) + (blockIdx.y * kTP); 99 | 100 | Element* gA_ptr; 101 | Element* gB_ptr; 102 | 103 | extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[]; 104 | auto* shm = reinterpret_cast(shared_buf); 105 | // pointers to shared memory tiles 106 | Element* sA_ptr = shm; 107 | Element* sB_ptr = shm + kTM * kTK; 108 | Element* sC_ptr = shm + kTM * kTK + kTK * kTN; 109 | Element* sD_ptr = shm; 110 | 111 | typename KeTraits::TiledMma mma; // for shared memory to register copy 112 | typename KeTraits::TiledCopyG2S tiled_copy; 113 | 114 | auto rA = make_s2rA(sA_ptr, typename KeTraits::SmemLayoutA{}, mma); 115 | auto rB = make_s2rB(sB_ptr, typename KeTraits::SmemLayoutB{}, mma); 116 | auto acc1 = get_acc(mma); // accumulator for the 1st gemm 117 | 118 | auto rC = make_s2rB(sC_ptr, typename KeTraits::SmemLayoutC{}, mma); 119 | auto acc2 = get_acc(mma); // accumulator for the 2nd gemm 120 | 121 | typename KeTraits::StoreD_R2S sD; // declare register to shared store plan 122 | 123 | for (int n = 0; n < kN; n += kTN) { // iterate over N 124 | gA_ptr = A; // A tile is repeated loaded 125 | gB_ptr = B + n * kK; 126 | for (int k = 0; k < kK; k += kTK) { // iterate over K 127 | copy_tile_g2s(gA_ptr, sA_ptr, typename KeTraits::GmemLayoutA{}, 128 | typename KeTraits::SmemLayoutA{}, tiled_copy); 129 | copy_tile_g2s(gB_ptr, sB_ptr, typename KeTraits::GmemLayoutB{}, 130 | typename KeTraits::SmemLayoutB{}, tiled_copy); 131 | __copy_async(); 132 | __syncthreads(); 133 | 134 | // iterate over the register tiles along the kTK dimension 135 | for (int i = 0; i < rA.get_iters(); ++i) { 136 | rA.copy(i); // load A register tile from shared memory 137 | rB.copy(i); // load B register tile from shared memory 138 | cute::gemm(mma, rA[i], rB[i], acc1); // compute 139 | } 140 | __syncthreads(); 141 | 142 | gA_ptr += kTK; 143 | gB_ptr += kTK; 144 | } 145 | 146 | // The output type of the first tensor core matrix multiplication is 147 | // float32. However, before the second GEMM operation, the output 148 | // needs to be converted to half precision. 
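// convert_type() narrows the fp32 accumulator fragment to half precision
// element by element; convert_layout() (defined in
// utils/cpp/cutlass/convert.cuh, which is not shown in this section) then
// re-tiles that fragment so it can be fed to the same TiledMMA as the A
// operand of the second GEMM below.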
149 | auto acc_half = convert_type(acc1); 150 | auto rA2 = convert_layout(acc_half); 151 | 152 | // load C tile from global to shared memory 153 | copy_tile_g2s(gC_ptr, sC_ptr, typename KeTraits::GmemLayoutC{}, 154 | typename KeTraits::SmemLayoutC{}, tiled_copy); 155 | __copy_async(); 156 | __syncthreads(); 157 | 158 | // iterate over register tiles along the kTN dimension 159 | for (int i = 0; i < rC.get_iters(); ++i) { 160 | rC.copy(i); // load C tile from shared memory to register 161 | cute::gemm(mma, rA2[i], rC[i], acc2); // compute 162 | } 163 | __syncthreads(); 164 | 165 | clear(acc1); 166 | gC_ptr += kTN; 167 | } 168 | 169 | // store register tile to shared memory 170 | sD.copy(acc2, shm); 171 | __syncthreads(); 172 | 173 | copy_tile_s2g(sD_ptr, gD_ptr, typename KeTraits::SmemLayoutD{}, 174 | typename KeTraits::GmemLayoutD{}, 175 | typename KeTraits::TiledCopyS2G{}); 176 | } 177 | 178 | } // namespace cutlass_wrapper 179 | } // namespace benchmarks 180 | 181 | template 185 | void cute_fused_gemm(const Element* dA, const Element* dB, const Element* dC, 186 | Element* dD) { 187 | using namespace benchmarks::cutlass_wrapper; 188 | 189 | using KeTraits = FusedGemmTraits; 191 | 192 | auto kernel = &fused_gemm_kernel; 194 | 195 | int shm_input = (kTM * kTK + kTK * kTN + kTN * kTP); 196 | int shm_output = kTM * kTP; 197 | int shm_size = shm_input < shm_output ? shm_output * sizeof(Element) 198 | : shm_input * sizeof(Element); 199 | 200 | // maximal statically allocated smem per block 201 | const int kMaxSmemPerBlock = 48 * 1024; 202 | if (shm_size > kMaxSmemPerBlock) { 203 | cudaFuncSetAttribute( 204 | kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shm_size); 205 | } 206 | 207 | // blocks are launched along the M and P dimensions. 208 | int block_x = (kM + kTM - 1) / kTM; 209 | int block_y = (kP + kTP - 1) / kTP; 210 | const int kThreads = KeTraits::kThreads; 211 | 212 | dim3 gridDim(block_x, block_y, 1); 213 | dim3 blockDim(kThreads, 1, 1); 214 | 215 | kernel<<>>(dA, dB, dC, dD); 216 | } 217 | -------------------------------------------------------------------------------- /benchs/cpp/copy/cutlass/cutlass_copy.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "utils/cpp/cuda_utils.cuh" 4 | #include "utils/cpp/cutlass/copy.cuh" 5 | #include "utils/cpp/cutlass/traits_base.cuh" 6 | 7 | #include 8 | 9 | namespace benchmarks { 10 | namespace cutlass_wrapper { 11 | 12 | using namespace cute; 13 | 14 | template > 19 | struct GemmTraits : public Base { 20 | using Element = Element_; 21 | 22 | static_assert(kTM % kWarpPerRow == 0, 23 | "the M dimension of the CTA tile should be divisible by the " 24 | "number of warps along that that dimension."); 25 | static_assert(kTN % kWarpPerCol == 0, 26 | "the N dimension of the CTA tile should be divisible by the " 27 | "number of warps along that that dimension."); 28 | 29 | // declare global to shared memory copy layout. 
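// Each global tile below is a row-major layout whose leading stride is the
// full matrix extent (kK for the A and B tiles, kN for the C tile; the exact
// template arguments were truncated in this dump), so the k-loop can slide the
// same layout along K simply by advancing the raw pointer by kTK.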
30 | using GmemLayoutA = Layout, Int>, Stride, _1>>; 31 | using GmemLayoutB = Layout, Int>, Stride, _1>>; 32 | using GmemLayoutC = Layout, Int>, Stride, _1>>; 33 | 34 | using TiledMma = 35 | TiledMMA, // for ampere 36 | Layout, Int, _1>>, 37 | Tile, Int<16 * kWarpPerCol>, _16>>; 38 | 39 | static constexpr int kThreads = size(TiledMma{}); 40 | static_assert(kThreads == kWarpPerRow * kWarpPerCol * 32); 41 | 42 | static constexpr int kNumPerAccess = Base::kNumPerAccess; 43 | static constexpr int kThreadsPerCol = CeilDiv; 44 | static constexpr int kThreadsPerRow = CeilDiv; 45 | 46 | using SmemLayoutAtom = decltype(composition( 47 | Swizzle<2, 3, 3>{}, Layout>, 48 | Stride, _1>>{})); 49 | 50 | using SmemLayoutA = 51 | decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); 52 | using SmemLayoutB = 53 | decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); 54 | using SmemLayoutC = 55 | decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); 56 | 57 | #ifdef CP_ASYNC_SM80_ENABLED 58 | using CopyInstG2S = 59 | Copy_Atom, Element>; 60 | #else 61 | using CopyInstG2S = Copy_Atom; 62 | #endif 63 | 64 | using TiledCopyG2S = decltype(make_tiled_copy( 65 | CopyInstG2S{}, 66 | Layout, Int>, 67 | Stride, _1>>{}, 68 | Layout>>{})); 69 | 70 | using TiledCopyS2G = decltype(make_tiled_copy( 71 | Copy_Atom{}, 72 | Layout, Int>, 73 | Stride, _1>>{}, 74 | Layout>>{})); 75 | using StoreC_R2S = R2SCopy2D; 76 | }; 77 | 78 | template 80 | __global__ void copy_shared_kernel(const Element* dA, const Element* dB, 81 | Element* dC) { 82 | extern __shared__ __align__(sizeof(double)) unsigned char buf_[]; 83 | auto* buf = reinterpret_cast(buf_); 84 | 85 | // Advance to the global data tile to the current CTA. 86 | Element* gA_ptr = const_cast(dA) + blockIdx.x * kK * kTM; 87 | Element* gB_ptr = const_cast(dB) + blockIdx.y * kK * kTN; 88 | Element* gC_ptr = dC + blockIdx.x * kTM * kN + blockIdx.y * kTN; 89 | 90 | // pointers to shared memory tiles 91 | Element* sA_ptr = buf; 92 | Element* sB_ptr = buf + kTM * kTK; 93 | Element* sC_ptr = buf; 94 | 95 | typename KeTraits::TiledMma mma; 96 | typename KeTraits::TiledCopyG2S tiled_copy; 97 | 98 | auto rA = make_s2rA(sA_ptr, typename KeTraits::SmemLayoutA{}, mma); 99 | auto rB = make_s2rB(sB_ptr, typename KeTraits::SmemLayoutB{}, mma); 100 | auto acc = get_acc(mma); 101 | 102 | for (int k = 0; k < kK; k += kTK) { 103 | copy_tile_g2s(gA_ptr, sA_ptr, typename KeTraits::GmemLayoutA{}, 104 | typename KeTraits::SmemLayoutA{}, tiled_copy); 105 | copy_tile_g2s(gB_ptr, sB_ptr, typename KeTraits::GmemLayoutB{}, 106 | typename KeTraits::SmemLayoutB{}, tiled_copy); 107 | __copy_async(); 108 | __syncthreads(); 109 | 110 | gA_ptr += kTK; 111 | gB_ptr += kTK; 112 | } 113 | 114 | // store shared memory tile to global memory 115 | copy_tile_s2g(sC_ptr, gC_ptr, typename KeTraits::SmemLayoutC{}, 116 | typename KeTraits::GmemLayoutC{}, 117 | typename KeTraits::TiledCopyS2G{}); 118 | } 119 | 120 | template 122 | __global__ void copy_kernel(const Element* dA, const Element* dB, Element* dC) { 123 | extern __shared__ __align__(sizeof(double)) unsigned char buf_[]; 124 | auto* buf = reinterpret_cast(buf_); 125 | 126 | // Advance to the global data tile to the current CTA. 
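// Unlike copy_shared_kernel above, this kernel also drains the shared-memory
// tiles into registers (the rA.copy()/rB.copy() loop below) and writes a dummy
// cleared accumulator tile back out through shared memory, so comparing the
// two kernels separates the global-to-shared cost from the shared-to-register
// cost.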
127 | Element* gA_ptr = const_cast(dA) + blockIdx.x * kK * kTM; 128 | Element* gB_ptr = const_cast(dB) + blockIdx.y * kK * kTN; 129 | Element* gC_ptr = dC + blockIdx.x * kTM * kN + blockIdx.y * kTN; 130 | 131 | // pointers to shared memory tiles 132 | Element* sA_ptr = buf; 133 | Element* sB_ptr = buf + kTM * kTK; 134 | Element* sC_ptr = buf; 135 | 136 | typename KeTraits::TiledMma mma; 137 | typename KeTraits::TiledCopyG2S tiled_copy; 138 | 139 | auto rA = make_s2rA(sA_ptr, typename KeTraits::SmemLayoutA{}, mma); 140 | auto rB = make_s2rB(sB_ptr, typename KeTraits::SmemLayoutB{}, mma); 141 | auto acc = get_acc(mma); 142 | 143 | for (int k = 0; k < kK; k += kTK) { 144 | copy_tile_g2s(gA_ptr, sA_ptr, typename KeTraits::GmemLayoutA{}, 145 | typename KeTraits::SmemLayoutA{}, tiled_copy); 146 | copy_tile_g2s(gB_ptr, sB_ptr, typename KeTraits::GmemLayoutB{}, 147 | typename KeTraits::SmemLayoutB{}, tiled_copy); 148 | __copy_async(); 149 | __syncthreads(); 150 | 151 | for (int i = 0; i < rA.get_iters(); ++i) { 152 | rA.copy(i); // load A register tile from shared memory 153 | rB.copy(i); // load B register tile from shared memory 154 | } 155 | __syncthreads(); 156 | 157 | gA_ptr += kTK; 158 | gB_ptr += kTK; 159 | } 160 | 161 | typename KeTraits::StoreC_R2S sC; // declare register to shared store plan 162 | sC.copy(acc, buf); // store register tile to shared memory 163 | __syncthreads(); 164 | 165 | // store shared memory tile to global memory 166 | copy_tile_s2g(sC_ptr, gC_ptr, typename KeTraits::SmemLayoutC{}, 167 | typename KeTraits::GmemLayoutC{}, 168 | typename KeTraits::TiledCopyS2G{}); 169 | } 170 | 171 | } // namespace cutlass_wrapper 172 | } // namespace benchmarks 173 | 174 | template 178 | void cute_shared_copy(const Element* dA, const Element* dB, Element* dC) { 179 | using namespace benchmarks::cutlass_wrapper; 180 | 181 | using KeTraits = GemmTraits; 183 | 184 | static constexpr int smem_size = 185 | std::max(kTK * (kTN + kTM), kTM * kTN) * sizeof(Element); 186 | 187 | auto kernel = 188 | ©_shared_kernel; 189 | 190 | // maximal statically allocated smem per block 191 | const int kMaxSmemPerBlock = 48 * 1024; 192 | if (smem_size > kMaxSmemPerBlock) { 193 | cudaFuncSetAttribute( 194 | kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); 195 | } 196 | 197 | const int block_m = (kM + kTM - 1) / kTM; 198 | const int block_n = (kN + kTN - 1) / kTN; 199 | 200 | const int kThreads = KeTraits::kThreads; 201 | 202 | dim3 gridDim(block_m, block_n); 203 | dim3 blockDim(kThreads, 1, 1); 204 | 205 | kernel<<>>(dA, dB, dC); 206 | } 207 | 208 | template 212 | void cute_copy(const Element* dA, const Element* dB, Element* dC) { 213 | using namespace benchmarks::cutlass_wrapper; 214 | 215 | using KeTraits = GemmTraits; 217 | 218 | static constexpr int smem_size = 219 | std::max(kTK * (kTN + kTM), kTM * kTN) * sizeof(Element); 220 | 221 | auto kernel = ©_kernel; 222 | 223 | // maximal statically allocated smem per block 224 | const int kMaxSmemPerBlock = 48 * 1024; 225 | if (smem_size > kMaxSmemPerBlock) { 226 | cudaFuncSetAttribute( 227 | kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); 228 | } 229 | 230 | const int block_m = (kM + kTM - 1) / kTM; 231 | const int block_n = (kN + kTN - 1) / kTN; 232 | 233 | const int kThreads = KeTraits::kThreads; 234 | 235 | dim3 gridDim(block_m, block_n); 236 | dim3 blockDim(kThreads, 1, 1); 237 | 238 | kernel<<>>(dA, dB, dC); 239 | } 240 | -------------------------------------------------------------------------------- 
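For reference, here is a minimal host-side sketch of how the two copy benchmarks above might be driven with CUDA events. The template parameter lists of `cute_copy` and `cute_shared_copy` were truncated in this dump, so the ordering `<Element, kM, kN, kK, kTM, kTN, kTK, kWarpPerRow, kWarpPerCol>` used below, the include path, and the problem sizes are assumptions; `benchs/cpp/copy/bench.cu` is the authoritative driver.

```cpp
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <cstdio>

#include "cutlass/cutlass_copy.cuh"  // relative to benchs/cpp/copy/ (assumed path)

int main() {
    using Element = __half;
    constexpr int kM = 4096, kN = 4096, kK = 2048;  // problem shape (assumed)
    constexpr int kTM = 128, kTN = 128, kTK = 64;   // CTA tile (assumed)

    // Contents are irrelevant for a copy benchmark; only the traffic matters.
    Element *dA, *dB, *dC;
    cudaMalloc(&dA, sizeof(Element) * kM * kK);
    cudaMalloc(&dB, sizeof(Element) * kN * kK);
    cudaMalloc(&dC, sizeof(Element) * kM * kN);

    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);

    const int iters = 50;
    cudaEventRecord(start);
    for (int i = 0; i < iters; ++i) {
        // Global -> shared -> register copy path only; no MMA is issued.
        cute_copy<Element, kM, kN, kK, kTM, kTN, kTK, 2, 2>(dA, dB, dC);
    }
    cudaEventRecord(stop);
    cudaEventSynchronize(stop);

    float ms = 0.f;
    cudaEventElapsedTime(&ms, start, stop);
    std::printf("cute_copy: %.4f ms per iteration\n", ms / iters);

    cudaEventDestroy(start);
    cudaEventDestroy(stop);
    cudaFree(dA);
    cudaFree(dB);
    cudaFree(dC);
    return 0;
}
```

Timing `cute_shared_copy` with the same harness and subtracting the two averages gives a rough estimate of the shared-to-register portion of the copy cost.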
/benchs/python/lstm/cutlass/cutlass_lstm.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "utils/cpp/cuda_utils.cuh" 4 | #include "utils/cpp/cutlass/compute.cuh" 5 | #include "utils/cpp/cutlass/copy.cuh" 6 | #include "utils/cpp/cutlass/traits_base.cuh" 7 | 8 | #include 9 | 10 | namespace benchmarks { 11 | namespace cutlass_wrapper { 12 | 13 | using namespace cute; 14 | 15 | template > 20 | struct LstmTraits : public Base { 21 | using Element = Element_; 22 | 23 | static_assert(kTM % kWarpPerRow == 0, 24 | "the M dimension of the CTA tile should be divisible by the " 25 | "number of warps along that that dimension."); 26 | static_assert(kTN % kWarpPerCol == 0, 27 | "the N dimension of the CTA tile should be divisible by the " 28 | "number of warps along that that dimension."); 29 | 30 | // declare global to shared memory copy layout. 31 | using GmemLayoutA = Layout, Int>, Stride, _1>>; 32 | using GmemLayoutB = Layout, Int>, Stride, _1>>; 33 | using GmemLayoutC = Layout, Int>, Stride, _1>>; 34 | using GmemLayoutD = Layout, Int>, Stride, _1>>; 35 | using GmemLayoutE = Layout, Int>, Stride, _1>>; 36 | 37 | using TiledMma = 38 | TiledMMA, // for ampere 39 | Layout, Int, _1>>, 40 | Tile, Int<16 * kWarpPerCol>, _16>>; 41 | 42 | static constexpr int kThreads = size(TiledMma{}); 43 | static_assert(kThreads == kWarpPerRow * kWarpPerCol * 32); 44 | 45 | static constexpr int kNumPerAccess = Base::kNumPerAccess; 46 | static constexpr int kThreadsPerCol = CeilDiv; 47 | static constexpr int kThreadsPerRow = CeilDiv; 48 | 49 | using SmemLayoutAtom = decltype(composition( 50 | Swizzle<2, 3, 3>{}, Layout>, 51 | Stride, _1>>{})); 52 | 53 | using SmemLayoutA = 54 | decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); 55 | using SmemLayoutB = 56 | decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); 57 | using SmemLayoutC = 58 | decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); 59 | using SmemLayoutD = 60 | decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); 61 | using SmemLayoutE = 62 | decltype(tile_to_shape(SmemLayoutAtom{}, Shape, Int>{})); 63 | 64 | #ifdef CP_ASYNC_SM80_ENABLED 65 | using CopyInstG2S = 66 | Copy_Atom, Element>; 67 | #else 68 | using CopyInstG2S = Copy_Atom; 69 | #endif 70 | 71 | using TiledCopyG2S = decltype(make_tiled_copy( 72 | CopyInstG2S{}, 73 | Layout, Int>, 74 | Stride, _1>>{}, 75 | Layout>>{})); 76 | 77 | using TiledCopyS2G = decltype(make_tiled_copy( 78 | Copy_Atom{}, 79 | Layout, Int>, 80 | Stride, _1>>{}, 81 | Layout>>{})); 82 | 83 | using StoreE_R2S = R2SCopy2D; 84 | }; 85 | 86 | template 88 | __global__ void lstm_gate_kernel(const Element* ws, const Element* us, 89 | const Element* xs, const Element* hs, 90 | Element* ts) { 91 | extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[]; 92 | auto* shm = reinterpret_cast(shared_buf); 93 | 94 | // Advance to the global data tile to the current CTA. 
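// The gate kernel computes t = W @ x + U @ h for all four LSTM gates in one
// launch: the per-gate weight matrices are stacked along the M dimension, so
// blockIdx.x walks over row tiles of the stacked weights while blockIdx.y
// walks over column (batch) tiles. The activation applied at the end depends
// on which quarter of the row blocks this CTA lands in: sigmoid for the i, f
// and o gates, tanh for the cell candidate (cute_lstm_cell below splits t
// into those four slices).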
95 | Element* gxs_ptr = const_cast(xs) + blockIdx.y * kK * kTN; 96 | Element* ghs_ptr = const_cast(hs) + blockIdx.y * kK * kTN; 97 | Element* gws_ptr = const_cast(ws) + blockIdx.x * kK * kTM; 98 | Element* gus_ptr = const_cast(us) + blockIdx.x * kK * kTM; 99 | Element* gts_ptr = ts + blockIdx.x * kTM * kN + blockIdx.y * kTN; 100 | 101 | int total_block_x = gridDim.x; 102 | int current_block_x = blockIdx.x; 103 | 104 | // pointers to shared memory tiles 105 | Element* sws_ptr = shm; 106 | Element* sxs_ptr = shm + kTM * kTK; 107 | Element* sus_ptr = shm + kTM * kTK + kTK * kTN; 108 | Element* shs_ptr = shm + kTM * kTK + kTK * kTN + kTM * kTK; 109 | Element* sts_ptr = shm; 110 | 111 | // declare shared memory to register file copy plan. 112 | // tcu's wmma instruction prescribes a strict data to thread 113 | // mapping, in the current implementation, the shm-2-reg copy 114 | // plan is related to mma. 115 | typename KeTraits::TiledMma mma; 116 | typename KeTraits::TiledCopyG2S tiled_copy; 117 | 118 | auto rws = make_s2rA(sws_ptr, typename KeTraits::SmemLayoutA{}, mma); 119 | auto rxs = make_s2rB(sxs_ptr, typename KeTraits::SmemLayoutB{}, mma); 120 | auto rus = make_s2rA(sus_ptr, typename KeTraits::SmemLayoutC{}, mma); 121 | auto rhs = make_s2rB(shs_ptr, typename KeTraits::SmemLayoutD{}, mma); 122 | 123 | auto acc1 = get_acc(mma); 124 | auto acc2 = get_acc(mma); 125 | 126 | typename KeTraits::StoreE_R2S sts; // declare register to shared store 127 | 128 | for (int k = 0; k < kK; k += kTK) { 129 | copy_tile_g2s(gws_ptr, sws_ptr, typename KeTraits::GmemLayoutA{}, 130 | typename KeTraits::SmemLayoutA{}, tiled_copy); 131 | copy_tile_g2s(gxs_ptr, sxs_ptr, typename KeTraits::GmemLayoutB{}, 132 | typename KeTraits::SmemLayoutB{}, tiled_copy); 133 | copy_tile_g2s(gus_ptr, sus_ptr, typename KeTraits::GmemLayoutC{}, 134 | typename KeTraits::SmemLayoutC{}, tiled_copy); 135 | copy_tile_g2s(ghs_ptr, shs_ptr, typename KeTraits::GmemLayoutD{}, 136 | typename KeTraits::SmemLayoutD{}, tiled_copy); 137 | 138 | __copy_async(); 139 | __syncthreads(); 140 | 141 | for (int i = 0; i < rws.get_iters(); i++) { 142 | rws.copy(i); 143 | rxs.copy(i); 144 | gemm(mma, rws[i], rxs[i], acc1); 145 | } 146 | 147 | for (int i = 0; i < rus.get_iters(); i++) { 148 | rus.copy(i); 149 | rhs.copy(i); 150 | gemm(mma, rus[i], rhs[i], acc2); 151 | } 152 | 153 | __syncthreads(); 154 | gws_ptr += kTK; 155 | gxs_ptr += kTK; 156 | gus_ptr += kTK; 157 | ghs_ptr += kTK; 158 | } 159 | 160 | __syncthreads(); 161 | cute::axpby(1.0, acc1, 1.0, acc2); 162 | 163 | __syncthreads(); 164 | if (current_block_x < total_block_x * 3 / 4) { 165 | cute_sigmoid(acc2); 166 | } else { 167 | cute_tanh(acc2); 168 | } 169 | __syncthreads(); 170 | 171 | sts.copy(acc2, shm); 172 | 173 | __syncthreads(); 174 | 175 | copy_tile_s2g(sts_ptr, gts_ptr, typename KeTraits::SmemLayoutE{}, 176 | typename KeTraits::GmemLayoutE{}, 177 | typename KeTraits::TiledCopyS2G{}); 178 | } 179 | 180 | template 181 | __global__ void lstm_element_wise(const Element* i, const Element* f, 182 | const Element* o, const Element* c_candidate, 183 | const Element* c, Element* c_out, 184 | Element* h_out, const int block_size, 185 | int size) { 186 | int index = blockIdx.x * block_size + threadIdx.x; 187 | if (index < size) { 188 | // TODO: Loading data into shared memory and computing, versus 189 | // computing directly in global memory, does not seem to make a 190 | // difference. 
This seems to require further optimization, such as 191 | // redistributing data to different threads and performing 192 | // vectorized loads and stores. 193 | 194 | // A naive variant that loads data into shared memory and 195 | // then performs the computation has been temporarily commented out. 196 | 197 | c_out[index] = f[index] * c[index] + i[index] * c_candidate[index]; 198 | 199 | __syncthreads(); 200 | 201 | h_out[index] = o[index] * tanh(c_out[index]); 202 | } 203 | } 204 | 205 | template 209 | void lstm_gate(const Element* w, const Element* x, const Element* u, 210 | const Element* h, Element* t) { 211 | using KeTraits = LstmTraits; 213 | 214 | static constexpr int smem_size = 215 | std::max(kTK * (kTN + kTM) * 2, kTM * kTN) * sizeof(Element); 216 | 217 | auto kernel = 218 | &lstm_gate_kernel; 219 | 220 | // maximal statically allocated smem per block 221 | const int kMaxSmemPerBlock = 48 * 1024; 222 | if (smem_size > kMaxSmemPerBlock) { 223 | cudaFuncSetAttribute( 224 | kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); 225 | } 226 | 227 | const int block_m = (kM + kTM - 1) / kTM; 228 | const int block_n = (kN + kTN - 1) / kTN; 229 | 230 | const int kThreads = KeTraits::kThreads; 231 | 232 | dim3 gridDim(block_m, block_n, 1); 233 | dim3 blockDim(kThreads, 1, 1); 234 | 235 | kernel<<>>(w, u, x, h, t); 236 | } 237 | 238 | } // namespace cutlass_wrapper 239 | } // namespace benchmarks 240 | 241 | template 245 | void cute_lstm_cell(const Element* w, const Element* x, const Element* u, 246 | const Element* c, const Element* h, Element* c_out, 247 | Element* h_out) { 248 | static const int M = kM / 4; 249 | static const int N = kN; 250 | 251 | // Cuda malloc for output 252 | Element* t; 253 | benchmarks::CudaCheck(cudaMalloc(&t, kM * kN * sizeof(Element))); 254 | 255 | benchmarks::cutlass_wrapper::lstm_gate(w, x, u, 257 | h, t); 258 | 259 | const Element* i = t; 260 | const Element* f = t + M * N; 261 | const Element* o = t + 2 * M * N; 262 | const Element* c_candidate = t + 3 * M * N; 263 | 264 | auto element_wise = 265 | &benchmarks::cutlass_wrapper::lstm_element_wise; 266 | 267 | /* 268 | TODO: Using `kMaxThreads` causes a runtime error: 269 | ``` 270 | RuntimeError: CUDA error: invalid configuration argument 271 | CUDA kernel errors might be asynchronously reported at some other API call, 272 | so the stacktrace below might be incorrect. For debugging consider passing 273 | CUDA_LAUNCH_BLOCKING=1. Compile with `TORCH_USE_CUDA_DSA` to enable 274 | device-side assertions.
275 | ``` 276 | */ 277 | // int kMaxThreads = GetGPUMaxThreadsPerMultiProcessor(0); 278 | int size = M * N; 279 | const int block_threads = 512; 280 | int block_size = (size + block_threads - 1) / block_threads; 281 | dim3 element_wise_grid_dim(block_size, 1, 1); 282 | dim3 element_wise_block_dim(block_threads, 1, 1); 283 | 284 | element_wise<<>>( 285 | i, f, o, c_candidate, c, c_out, h_out, block_threads, size); 286 | 287 | benchmarks::CudaCheck(cudaFree(t)); 288 | } 289 | -------------------------------------------------------------------------------- /benchs/cpp/copy/tiledcuda/tiledcuda_copy.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cell/mod.hpp" 4 | #include "types/mod.hpp" 5 | 6 | using namespace tiledcuda; 7 | using namespace tiledcuda::cell; 8 | using namespace tiledcuda::cell::copy; 9 | 10 | namespace tl = tile_layout; 11 | 12 | template 13 | using GemmShape = TileShape; 14 | 15 | template 17 | struct KeGemmTraits { 18 | using BaseShape = traits::BaseTileShape; 19 | 20 | static constexpr int kThreads = tl::get_numel * 32; 21 | static constexpr int kWarpPerRow = tl::num_rows; 22 | static constexpr int kWarpPerCol = tl::num_cols; 23 | 24 | static constexpr int kM = dim_size<0, WholeShape>; 25 | static constexpr int kN = dim_size<1, WholeShape>; 26 | static constexpr int kK = dim_size<2, WholeShape>; 27 | 28 | static constexpr int kTM = dim_size<0, CtaTileShape>; 29 | static constexpr int kTN = dim_size<1, CtaTileShape>; 30 | static constexpr int kTK = dim_size<2, CtaTileShape>; 31 | 32 | static const bool kSwizzled = true; 33 | 34 | // Total data access for operand A in global memory 35 | using GlobalA = GlobalTile>; 36 | // Access a single global tile for operand A 37 | using GIteratorA = GTileIterator>; 38 | 39 | // Shared Tile for operand A 40 | using SharedA = SharedTile, kSwizzled>; 41 | // Access a single register tile for operand A 42 | using SIteratorA = STileIterator>; 43 | 44 | // Register tile for a single thread of operand A 45 | static constexpr int kAMs = kTM / kWarpPerRow / BaseShape::kTileSize; 46 | static constexpr int kAKs = kRK / BaseShape::kTileSize; 47 | using RegA = RegTile, tl::RowMajor>; 48 | 49 | // Loaders for operand A 50 | using G2SLoaderA = GlobalToSharedLoader; 51 | using S2RLoaderA = 52 | SharedToRegLoader; 53 | 54 | // Total data access for operand B in global memory 55 | using GlobalB = GlobalTile>; 56 | // Access a single global tile for operand B 57 | using GIteratorB = GTileIterator>; 58 | 59 | // Shared Tile for operand B 60 | using SharedB = SharedTile, kSwizzled>; 61 | // Access a single register tile for operand B 62 | using SIteratorB = STileIterator>; 63 | 64 | static_assert(GIteratorA::sc1 == GIteratorB::sc0, 65 | "mismatched K dimension!"); 66 | static_assert(SIteratorA::sc1 == SIteratorB::sc0, 67 | "mismatched K dimension!"); 68 | 69 | // Register tile for a single thread of operand A 70 | static constexpr int kBKs = kRK / BaseShape::kTileSize; 71 | static constexpr int kBNs = kTN / kWarpPerCol / BaseShape::kTileSize; 72 | using RegB = RegTile, tl::ColMajor>; 73 | 74 | using G2SLoaderB = GlobalToSharedLoader; 75 | using S2RLoaderB = 76 | SharedToRegLoader; 77 | 78 | // Global Tile for output C 79 | using GlobalC = GlobalTile>; 80 | // Shared Tile for output C 81 | using SharedC = SharedTile, kSwizzled>; 82 | 83 | // Register Tile for output C 84 | static constexpr int kCMs = kTM / kWarpPerRow / BaseShape::kTileSize; 85 | static constexpr int kCNs = kTN / 
kWarpPerCol / BaseShape::kTileSize; 86 | using RegC = RegTile, tl::RowMajor>; 87 | 88 | using R2SStorerC = RegToSharedStorer; 89 | using S2GStorerC = SharedToGlobalStorer; 90 | }; 91 | 92 | template 103 | __global__ void copy_shared_kernel(const InType* dA, const InType* dB, 104 | AccType* dC) { 105 | int offset_a = blockIdx.x * kTM * kK; 106 | int offset_b = blockIdx.y * kTN * kK; 107 | int offset_c = blockIdx.x * kTM * kN + blockIdx.y * kTN; 108 | 109 | extern __shared__ __align__(sizeof(double)) unsigned char buf[]; 110 | InType* sA_ptr = reinterpret_cast(buf); 111 | InType* sB_ptr = sA_ptr + SIteratorA::Tile::kNumel; 112 | AccType* sC_ptr = reinterpret_cast(buf); 113 | 114 | // declare tiles, iterators and loaders 115 | GIteratorA gAs(dA + offset_a); 116 | SIteratorA sAs(sA_ptr); 117 | 118 | GIteratorB gBs(dB + offset_b); 119 | SIteratorB sBs(sB_ptr); 120 | 121 | SharedA sA(sA_ptr); 122 | RegA rA; 123 | 124 | SharedB sB(sB_ptr); 125 | RegB rB; 126 | 127 | RegC acc; 128 | SharedC sC(sC_ptr); 129 | GlobalC gC(dC + offset_c); 130 | 131 | G2SLoaderA g2s_a; 132 | S2RLoaderA s2r_a; 133 | 134 | G2SLoaderB g2s_b; 135 | S2RLoaderB s2r_b; 136 | 137 | R2SStorerC r2s_c; 138 | S2GStorerC s2g_c; 139 | 140 | for (int k1 = 0; k1 < GIteratorA::sc1; ++k1) { 141 | g2s_a(gAs(k1), sA); 142 | g2s_b(gBs(k1), sB); 143 | __copy_async(); 144 | __syncthreads(); 145 | } 146 | s2g_c(sC, gC); 147 | } 148 | 149 | template 160 | __global__ void copy_kernel(const InType* dA, const InType* dB, AccType* dC) { 161 | int offset_a = blockIdx.x * kTM * kK; 162 | int offset_b = blockIdx.y * kTN * kK; 163 | int offset_c = blockIdx.x * kTM * kN + blockIdx.y * kTN; 164 | 165 | extern __shared__ __align__(sizeof(double)) unsigned char buf[]; 166 | InType* sA_ptr = reinterpret_cast(buf); 167 | InType* sB_ptr = sA_ptr + SIteratorA::Tile::kNumel; 168 | AccType* sC_ptr = reinterpret_cast(buf); 169 | 170 | // declare tiles, iterators and loaders 171 | GIteratorA gAs(dA + offset_a); 172 | SIteratorA sAs(sA_ptr); 173 | 174 | GIteratorB gBs(dB + offset_b); 175 | SIteratorB sBs(sB_ptr); 176 | 177 | SharedA sA(sA_ptr); 178 | RegA rA; 179 | 180 | SharedB sB(sB_ptr); 181 | RegB rB; 182 | 183 | RegC acc; 184 | SharedC sC(sC_ptr); 185 | GlobalC gC(dC + offset_c); 186 | 187 | G2SLoaderA g2s_a; 188 | S2RLoaderA s2r_a; 189 | 190 | G2SLoaderB g2s_b; 191 | S2RLoaderB s2r_b; 192 | 193 | R2SStorerC r2s_c; 194 | S2GStorerC s2g_c; 195 | 196 | for (int k1 = 0; k1 < GIteratorA::sc1; ++k1) { 197 | g2s_a(gAs(k1), sA); 198 | g2s_b(gBs(k1), sB); 199 | __copy_async(); 200 | __syncthreads(); 201 | 202 | for (int k2 = 0; k2 < SIteratorA::sc1; ++k2) { 203 | s2r_a(sAs(k2), rA); 204 | s2r_b(sBs(k2), rB); 205 | } 206 | } 207 | r2s_c(acc, sC); 208 | __syncthreads(); 209 | s2g_c(sC, gC); 210 | } 211 | 212 | template 216 | void tiledcuda_shared_copy(const InType* dA, const InType* dB, AccType* dC) { 217 | using WholeShape = GemmShape; 218 | using CtaTileShape = GemmShape; 219 | using WarpLayout = tl::RowMajor; 220 | 221 | static constexpr int kRK = 32; 222 | 223 | using Config = KeGemmTraits; 225 | 226 | auto kernel = ©_shared_kernel< 227 | InType, AccType, kM, kN, kK, kTM, kTN, kTK, typename Config::GIteratorA, 228 | typename Config::SIteratorA, typename Config::SharedA, 229 | typename Config::RegA, typename Config::G2SLoaderA, 230 | typename Config::S2RLoaderA, typename Config::GIteratorB, 231 | typename Config::SIteratorB, typename Config::SharedB, 232 | typename Config::RegB, typename Config::G2SLoaderB, 233 | typename Config::S2RLoaderB, typename 
Config::GlobalC, 234 | typename Config::SharedC, typename Config::RegC, 235 | typename Config::R2SStorerC, typename Config::S2GStorerC>; 236 | 237 | static constexpr int smem_size_inputs = kTK * (kTN + kTM) * sizeof(InType); 238 | static constexpr int smem_size_accumulators = kTM * kTN * sizeof(AccType); 239 | static constexpr int smem_size = smem_size_inputs > smem_size_accumulators 240 | ? smem_size_inputs 241 | : smem_size_accumulators; 242 | 243 | const int kMaxSmemPerBlock = 48 * 1024; 244 | if (smem_size > kMaxSmemPerBlock) { 245 | cudaFuncSetAttribute( 246 | kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); 247 | } 248 | 249 | int block_x = (kM + kTM - 1) / kTM; 250 | int block_y = (kN + kTN - 1) / kTN; 251 | 252 | dim3 grid(block_x, block_y, 1); 253 | dim3 block(Config::kThreads, 1, 1); 254 | 255 | kernel<<>>(dA, dB, dC); 256 | } 257 | 258 | template 262 | void tiledcuda_copy(const InType* dA, const InType* dB, AccType* dC) { 263 | using WholeShape = GemmShape; 264 | using CtaTileShape = GemmShape; 265 | using WarpLayout = tl::RowMajor; 266 | 267 | static constexpr int kRK = 32; 268 | 269 | using Config = KeGemmTraits; 271 | 272 | auto kernel = 273 | ©_kernel; 283 | 284 | static constexpr int smem_size_inputs = kTK * (kTN + kTM) * sizeof(InType); 285 | static constexpr int smem_size_accumulators = kTM * kTN * sizeof(AccType); 286 | static constexpr int smem_size = smem_size_inputs > smem_size_accumulators 287 | ? smem_size_inputs 288 | : smem_size_accumulators; 289 | 290 | const int kMaxSmemPerBlock = 48 * 1024; 291 | if (smem_size > kMaxSmemPerBlock) { 292 | cudaFuncSetAttribute( 293 | kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); 294 | } 295 | 296 | int block_x = (kM + kTM - 1) / kTM; 297 | int block_y = (kN + kTN - 1) / kTN; 298 | 299 | dim3 grid(block_x, block_y, 1); 300 | dim3 block(Config::kThreads, 1, 1); 301 | 302 | kernel<<>>(dA, dB, dC); 303 | } 304 | --------------------------------------------------------------------------------
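To make the tile bookkeeping in `KeGemmTraits` above concrete, the following self-contained sketch reproduces the arithmetic for one assumed configuration. The value `kBaseTile = 16` is an assumption about `traits::BaseTileShape` and should be checked against the TiledCUDA headers before relying on these numbers.

```cpp
#include <cuda_fp16.h>

// Assumed configuration: CtaTileShape = (128, 128, 64), WarpLayout = RowMajor<2, 2>,
// kRK = 32 (the value hard-coded in tiledcuda_copy above), kBaseTile = 16 (assumed).
constexpr int kTM = 128, kTN = 128, kTK = 64, kRK = 32;
constexpr int kWarpPerRow = 2, kWarpPerCol = 2;
constexpr int kBaseTile = 16;  // assumed BaseTileShape extent

// One warp per entry of the 2 x 2 warp layout.
constexpr int kThreads = kWarpPerRow * kWarpPerCol * 32;  // 128

// Per-thread register tiles (as the comments above call them), in base tiles.
constexpr int kAMs = kTM / kWarpPerRow / kBaseTile;  // 4
constexpr int kAKs = kRK / kBaseTile;                // 2 -> RegA holds 4 x 2 base tiles
constexpr int kBKs = kRK / kBaseTile;                // 2
constexpr int kBNs = kTN / kWarpPerCol / kBaseTile;  // 4 -> RegB holds 2 x 4 base tiles
constexpr int kCMs = kTM / kWarpPerRow / kBaseTile;  // 4
constexpr int kCNs = kTN / kWarpPerCol / kBaseTile;  // 4 -> RegC holds 4 x 4 base tiles

// Dynamic shared memory request, mirroring the launchers above: the larger of
// the input staging area (half precision) and the accumulator area (float).
constexpr int kSmemInputs = kTK * (kTN + kTM) * static_cast<int>(sizeof(__half));  // 32 KiB
constexpr int kSmemAcc = kTM * kTN * static_cast<int>(sizeof(float));              // 64 KiB

static_assert(kThreads == 128, "2 x 2 warps -> 128 threads per CTA");
static_assert(kAMs == 4 && kAKs == 2 && kBKs == 2 && kBNs == 4, "operand register tiling");
static_assert(kCMs == 4 && kCNs == 4, "accumulator register tiling");
static_assert(kSmemInputs < kSmemAcc, "the accumulator area dominates here");
static_assert(kSmemAcc > 48 * 1024,
              "exceeds the 48 KiB static limit, so the launcher must raise "
              "cudaFuncAttributeMaxDynamicSharedMemorySize");

int main() { return 0; }
```

With these numbers, the launchers above tile the 4096 x 4096 output into a 32 x 32 grid of 128-thread CTAs.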