├── .gitignore ├── CMakeLists.txt ├── include ├── cuda_gemm_utils.hpp ├── cuda_gemm.hpp ├── profile_utils.cuh └── cuda_gemm_utils.cuh ├── src ├── cuda_gemm_utils.cu ├── CMakeLists.txt ├── 01_coalesced_global_memory_access.cu ├── 00_non_coalesced_global_memory_access.cu ├── profile_cuda_gemm_fp32.cu ├── profile_cuda_gemm_fp16.cu ├── 02_2d_block_tiling_vectorized_memory_access.cu ├── 02_2d_block_tiling.cu ├── 03_2d_block_tiling_1d_thread_tiling_vectorized_memory_access.cu ├── 03_2d_block_tiling_1d_thread_tiling.cu ├── 05_2d_block_tiling_2d_thread_tiling_matrix_transpose.cu ├── 04_2d_block_tiling_2d_thread_tiling.cu ├── 05_2d_block_tiling_2d_thread_tiling_matrix_transpose_vectorized_memory_access.cu ├── 04_2d_block_tiling_2d_thread_tiling_vectorized_memory_access.cu ├── 07_2d_block_tiling_2d_warp_tiling_2d_thread_tiling_matrix_transpose_wmma_vectorized_memory_access.cu ├── 07_2d_block_tiling_2d_warp_tiling_2d_thread_tiling_matrix_transpose_wmma.cu ├── 06_2d_block_tiling_2d_warp_tiling_2d_thread_tiling_matrix_transpose.cu ├── 06_2d_block_tiling_2d_warp_tiling_2d_thread_tiling_matrix_transpose_vectorized_memory_access.cu └── 07_2d_block_tiling_2d_warp_tiling_2d_thread_tiling_matrix_transpose_wmma_vectorized_memory_access_double_buffered.cu ├── docker └── gemm-cuda.Dockerfile ├── LICENSE ├── .clang-format └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | *.i 2 | *.ii 3 | *.gpu 4 | *.ptx 5 | *.cubin 6 | *.fatbin 7 | 8 | # Build files 9 | build/ 10 | 11 | .vscode/ -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.28) 2 | 3 | project(CUDA-GEMM-Optimization VERSION 0.0.1 LANGUAGES CXX CUDA) 4 | 5 | set(CMAKE_CXX_STANDARD 14) 6 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 7 | 8 | add_subdirectory(src) 9 | -------------------------------------------------------------------------------- /include/cuda_gemm_utils.hpp: -------------------------------------------------------------------------------- 1 | #ifndef CUDA_GEMM_UTILS_HPP 2 | #define CUDA_GEMM_UTILS_HPP 3 | 4 | #include 5 | 6 | #define CHECK_CUDA_ERROR(val) check_cuda((val), #val, __FILE__, __LINE__) 7 | void check_cuda(cudaError_t err, const char* const func, const char* const file, 8 | const int line); 9 | 10 | #define CHECK_LAST_CUDA_ERROR() check_cuda_last(__FILE__, __LINE__) 11 | void check_cuda_last(const char* const file, const int line); 12 | 13 | #endif // CUDA_GEMM_UTILS_HPP -------------------------------------------------------------------------------- /src/cuda_gemm_utils.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | 5 | #include "cuda_gemm_utils.hpp" 6 | 7 | void check_cuda(cudaError_t err, const char* const func, const char* const file, 8 | const int line) 9 | { 10 | if (err != cudaSuccess) 11 | { 12 | std::cerr << "CUDA Runtime Error at: " << file << ":" << line 13 | << std::endl; 14 | std::cerr << cudaGetErrorString(err) << " " << func << std::endl; 15 | std::exit(EXIT_FAILURE); 16 | } 17 | } 18 | 19 | void check_cuda_last(const char* const file, const int line) 20 | { 21 | cudaError_t const err{cudaGetLastError()}; 22 | if (err != cudaSuccess) 23 | { 24 | std::cerr << "CUDA Runtime Error at: " << file << ":" << line 25 | << std::endl; 26 | std::cerr << cudaGetErrorString(err) << std::endl; 27 | std::exit(EXIT_FAILURE); 28 | } 29 | } 30 | 
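The CHECK_CUDA_ERROR and CHECK_LAST_CUDA_ERROR macros above are the error-handling convention used throughout the kernels below: the former wraps any CUDA runtime call that returns a cudaError_t, while the latter is placed after kernel launches, which only report errors through cudaGetLastError. A minimal usage sketch (the buffer allocation shown here is illustrative and not taken from any file in the repository):

```cpp
#include <cstddef>

#include <cuda_runtime.h>

#include "cuda_gemm_utils.hpp"

void example_error_checking(size_t num_elements)
{
    float* d_buffer{nullptr};
    // CHECK_CUDA_ERROR stringifies the wrapped call and prints the file and
    // line if the returned error code is not cudaSuccess, then exits.
    CHECK_CUDA_ERROR(cudaMalloc(&d_buffer, num_elements * sizeof(float)));
    CHECK_CUDA_ERROR(cudaMemset(d_buffer, 0, num_elements * sizeof(float)));
    // A kernel launch returns no error code, so a launch is followed by
    // CHECK_LAST_CUDA_ERROR(), exactly as the launch_gemm_kernel_* functions
    // below do after their <<<...>>> launches.
    CHECK_LAST_CUDA_ERROR();
    CHECK_CUDA_ERROR(cudaFree(d_buffer));
}
```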
-------------------------------------------------------------------------------- /docker/gemm-cuda.Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/cuda:12.2.2-devel-ubuntu22.04 2 | 3 | ARG CMAKE_VERSION=3.28.0 4 | ARG NUM_JOBS=8 5 | 6 | ENV DEBIAN_FRONTEND=noninteractive 7 | 8 | # Install package dependencies 9 | RUN apt-get update && apt-get install -y --no-install-recommends \ 10 | build-essential \ 11 | software-properties-common \ 12 | autoconf \ 13 | automake \ 14 | libtool \ 15 | pkg-config \ 16 | ca-certificates \ 17 | locales \ 18 | locales-all \ 19 | wget \ 20 | git && \ 21 | apt-get clean 22 | 23 | # System locale 24 | # Important for UTF-8 25 | ENV LC_ALL=en_US.UTF-8 26 | ENV LANG=en_US.UTF-8 27 | ENV LANGUAGE=en_US.UTF-8 28 | 29 | # Install CMake 30 | RUN cd /tmp && \ 31 | wget https://github.com/Kitware/CMake/releases/download/v${CMAKE_VERSION}/cmake-${CMAKE_VERSION}-linux-x86_64.sh && \ 32 | bash cmake-${CMAKE_VERSION}-linux-x86_64.sh --prefix=/usr/local --exclude-subdir --skip-license 33 | RUN rm -rf /tmp/* 34 | 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Lei Mao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /src/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.28) 2 | 3 | project(CUDA-GEMM-Algorithms VERSION 0.0.1 LANGUAGES CXX CUDA) 4 | 5 | set(CMAKE_CXX_STANDARD 14) 6 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 7 | 8 | # Find CUDA Toolkit 9 | find_package(CUDAToolkit REQUIRED) 10 | 11 | find_path(CUDA_GEMM_INCLUDE_DIRS cuda_gemm_utils.hpp HINTS ${CMAKE_SOURCE_DIR}/include) 12 | file(GLOB CUDA_GEMM_HEADERS ${CMAKE_SOURCE_DIR}/include/*.hpp ${CMAKE_SOURCE_DIR}/include/*.cuh) 13 | 14 | # Add all the source files in the current directory to build the library 15 | add_library( 16 | cuda_gemm 17 | SHARED 18 | cuda_gemm_utils.cu 19 | 00_non_coalesced_global_memory_access.cu 20 | 01_coalesced_global_memory_access.cu 21 | 02_2d_block_tiling.cu 22 | 02_2d_block_tiling_vectorized_memory_access.cu 23 | 03_2d_block_tiling_1d_thread_tiling.cu 24 | 03_2d_block_tiling_1d_thread_tiling_vectorized_memory_access.cu 25 | 04_2d_block_tiling_2d_thread_tiling.cu 26 | 04_2d_block_tiling_2d_thread_tiling_vectorized_memory_access.cu 27 | 05_2d_block_tiling_2d_thread_tiling_matrix_transpose.cu 28 | 05_2d_block_tiling_2d_thread_tiling_matrix_transpose_vectorized_memory_access.cu 29 | 06_2d_block_tiling_2d_warp_tiling_2d_thread_tiling_matrix_transpose.cu 30 | 06_2d_block_tiling_2d_warp_tiling_2d_thread_tiling_matrix_transpose_vectorized_memory_access.cu 31 | 06_2d_block_tiling_2d_warp_tiling_2d_thread_tiling_matrix_transpose_vectorized_memory_access_double_buffered.cu 32 | 07_2d_block_tiling_2d_warp_tiling_2d_thread_tiling_matrix_transpose_wmma.cu 33 | 07_2d_block_tiling_2d_warp_tiling_2d_thread_tiling_matrix_transpose_wmma_vectorized_memory_access.cu 34 | 07_2d_block_tiling_2d_warp_tiling_2d_thread_tiling_matrix_transpose_wmma_vectorized_memory_access_double_buffered.cu 35 | ) 36 | 37 | # Add the include directory of the library to the include directories of the project 38 | target_include_directories(cuda_gemm PUBLIC ${CUDA_GEMM_INCLUDE_DIRS}) 39 | 40 | # Set the CUDA architecture to compile the code for 41 | # https://cmake.org/cmake/help/latest/prop_tgt/CUDA_ARCHITECTURES.html 42 | set_target_properties(cuda_gemm PROPERTIES CUDA_ARCHITECTURES native) 43 | install(TARGETS cuda_gemm DESTINATION lib) 44 | install(FILES ${CUDA_GEMM_HEADERS} DESTINATION include) 45 | 46 | add_executable(profile_cuda_gemm_fp32 profile_cuda_gemm_fp32.cu) 47 | target_link_libraries(profile_cuda_gemm_fp32 cuda_gemm CUDA::cublas) 48 | set_target_properties(profile_cuda_gemm_fp32 PROPERTIES CUDA_ARCHITECTURES native) 49 | 50 | add_executable(profile_cuda_gemm_fp16 profile_cuda_gemm_fp16.cu) 51 | target_link_libraries(profile_cuda_gemm_fp16 cuda_gemm CUDA::cublas) 52 | set_target_properties(profile_cuda_gemm_fp16 PROPERTIES CUDA_ARCHITECTURES native) -------------------------------------------------------------------------------- /src/01_coalesced_global_memory_access.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "cuda_gemm.hpp" 4 | #include "cuda_gemm_utils.hpp" 5 | 6 | // GEMM kernel v01. 7 | // Coalesced read and write from global memory. 8 | template 9 | __global__ void gemm_v01(size_t m, size_t n, size_t k, T alpha, T const* A, 10 | size_t lda, T const* B, size_t ldb, T beta, T* C, 11 | size_t ldc) 12 | { 13 | // Compute the row and column of C that this thread is responsible for. 
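    // Coalescing note: threadIdx.x is mapped to the column index of C below,
    // so with the 32 x 32 thread block used by launch_gemm_kernel_v01 the 32
    // threads of a warp share one row index and cover 32 consecutive columns.
    // Their reads of B[k_idx * ldb + C_col_idx] and their read-modify-writes
    // of C[C_row_idx * ldc + C_col_idx] hit consecutive addresses and coalesce
    // into a few wide transactions, while A[C_row_idx * lda + k_idx] is a
    // single address per warp and is served by a broadcast.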
14 | size_t const C_col_idx{blockIdx.x * blockDim.x + threadIdx.x}; 15 | size_t const C_row_idx{blockIdx.y * blockDim.y + threadIdx.y}; 16 | 17 | // Each thread compute 18 | // C[C_row_idx, C_col_idx] = alpha * A[C_row_idx, :] * B[:, C_col_idx] + 19 | // beta * C[C_row_idx, C_col_idx]. 20 | if (C_row_idx < m && C_col_idx < n) 21 | { 22 | T sum{static_cast(0)}; 23 | for (size_t k_idx{0U}; k_idx < k; ++k_idx) 24 | { 25 | sum += A[C_row_idx * lda + k_idx] * B[k_idx * ldb + C_col_idx]; 26 | } 27 | C[C_row_idx * ldc + C_col_idx] = 28 | alpha * sum + beta * C[C_row_idx * ldc + C_col_idx]; 29 | } 30 | } 31 | 32 | template 33 | void launch_gemm_kernel_v01(size_t m, size_t n, size_t k, T const* alpha, 34 | T const* A, size_t lda, T const* B, size_t ldb, 35 | T const* beta, T* C, size_t ldc, 36 | cudaStream_t stream) 37 | { 38 | dim3 const block_dim{32U, 32U, 1U}; 39 | dim3 const grid_dim{ 40 | (static_cast(n) + block_dim.x - 1U) / block_dim.x, 41 | (static_cast(m) + block_dim.y - 1U) / block_dim.y, 1U}; 42 | gemm_v01<<>>(m, n, k, *alpha, A, lda, B, 43 | ldb, *beta, C, ldc); 44 | CHECK_LAST_CUDA_ERROR(); 45 | } 46 | 47 | // Explicit instantiation. 48 | template void launch_gemm_kernel_v01(size_t m, size_t n, size_t k, 49 | float const* alpha, float const* A, 50 | size_t lda, float const* B, 51 | size_t ldb, float const* beta, 52 | float* C, size_t ldc, 53 | cudaStream_t stream); 54 | template void launch_gemm_kernel_v01(size_t m, size_t n, size_t k, 55 | double const* alpha, 56 | double const* A, size_t lda, 57 | double const* B, size_t ldb, 58 | double const* beta, double* C, 59 | size_t ldc, cudaStream_t stream); 60 | template void launch_gemm_kernel_v01<__half>(size_t m, size_t n, size_t k, 61 | __half const* alpha, 62 | __half const* A, size_t lda, 63 | __half const* B, size_t ldb, 64 | __half const* beta, __half* C, 65 | size_t ldc, cudaStream_t stream); -------------------------------------------------------------------------------- /src/00_non_coalesced_global_memory_access.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "cuda_gemm.hpp" 4 | #include "cuda_gemm_utils.hpp" 5 | 6 | // GEMM kernel v00. 7 | // Non-coalesced read and write from global memory. 8 | template 9 | __global__ void gemm_v00(size_t m, size_t n, size_t k, T alpha, T const* A, 10 | size_t lda, T const* B, size_t ldb, T beta, T* C, 11 | size_t ldc) 12 | { 13 | // Compute the row and column of C that this thread is responsible for. 14 | size_t const C_row_idx{blockIdx.x * blockDim.x + threadIdx.x}; 15 | size_t const C_col_idx{blockIdx.y * blockDim.y + threadIdx.y}; 16 | 17 | // Each thread compute 18 | // C[C_row_idx, C_col_idx] = alpha * A[C_row_idx, :] * B[:, C_col_idx] + 19 | // beta * C[C_row_idx, C_col_idx]. 
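    // Coalescing note: threadIdx.x was mapped to the row index above, so the
    // 32 threads of a warp (launch_gemm_kernel_v00 also uses a 32 x 32 thread
    // block) access 32 different rows. The reads of A[C_row_idx * lda + k_idx]
    // and the accesses to C[C_row_idx * ldc + C_col_idx] below are therefore
    // strided by lda and ldc elements and cannot be coalesced, which is the
    // main reason this kernel is the slowest variant in the README tables;
    // only B[k_idx * ldb + C_col_idx] degenerates to a single broadcast
    // address per warp.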
20 | if (C_row_idx < m && C_col_idx < n) 21 | { 22 | T sum{static_cast(0)}; 23 | for (size_t k_idx{0U}; k_idx < k; ++k_idx) 24 | { 25 | sum += A[C_row_idx * lda + k_idx] * B[k_idx * ldb + C_col_idx]; 26 | } 27 | C[C_row_idx * ldc + C_col_idx] = 28 | alpha * sum + beta * C[C_row_idx * ldc + C_col_idx]; 29 | } 30 | } 31 | 32 | template 33 | void launch_gemm_kernel_v00(size_t m, size_t n, size_t k, T const* alpha, 34 | T const* A, size_t lda, T const* B, size_t ldb, 35 | T const* beta, T* C, size_t ldc, 36 | cudaStream_t stream) 37 | { 38 | dim3 const block_dim{32U, 32U, 1U}; 39 | dim3 const grid_dim{ 40 | (static_cast(m) + block_dim.x - 1U) / block_dim.x, 41 | (static_cast(n) + block_dim.y - 1U) / block_dim.y, 1U}; 42 | gemm_v00<<>>(m, n, k, *alpha, A, lda, B, 43 | ldb, *beta, C, ldc); 44 | CHECK_LAST_CUDA_ERROR(); 45 | } 46 | 47 | // Explicit instantiation. 48 | template void launch_gemm_kernel_v00(size_t m, size_t n, size_t k, 49 | float const* alpha, float const* A, 50 | size_t lda, float const* B, 51 | size_t ldb, float const* beta, 52 | float* C, size_t ldc, 53 | cudaStream_t stream); 54 | template void launch_gemm_kernel_v00(size_t m, size_t n, size_t k, 55 | double const* alpha, 56 | double const* A, size_t lda, 57 | double const* B, size_t ldb, 58 | double const* beta, double* C, 59 | size_t ldc, cudaStream_t stream); 60 | template void launch_gemm_kernel_v00<__half>(size_t m, size_t n, size_t k, 61 | __half const* alpha, 62 | __half const* A, size_t lda, 63 | __half const* B, size_t ldb, 64 | __half const* beta, __half* C, 65 | size_t ldc, cudaStream_t stream); -------------------------------------------------------------------------------- /src/profile_cuda_gemm_fp32.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "cuda_gemm.hpp" 5 | #include "profile_utils.cuh" 6 | 7 | int main() 8 | { 9 | print_device_info(); 10 | 11 | constexpr size_t num_repeats{1U}; 12 | constexpr size_t num_warmups{1U}; 13 | 14 | float const fp32_abs_tol{1.0e-3f}; 15 | double const fp32_rel_tol{0.0e-4f}; 16 | 17 | constexpr size_t m{4096U}; 18 | constexpr size_t k{4096U}; 19 | constexpr size_t n{4096U}; 20 | 21 | constexpr size_t lda{(k + 16U - 1U) / 16U * 16U}; 22 | constexpr size_t ldb{(n + 16U - 1U) / 16U * 16U}; 23 | constexpr size_t ldc{(n + 16U - 1U) / 16U * 16U}; 24 | 25 | static_assert(lda >= k); 26 | static_assert(ldb >= n); 27 | static_assert(ldc >= n); 28 | 29 | std::cout << "Matrix Size: " << "M = " << m << " N = " << n << " K = " << k 30 | << std::endl; 31 | std::cout << "Matrix A: " << m << " x " << k 32 | << " Leading Dimension Size = " << lda << std::endl; 33 | std::cout << "Matrix B: " << k << " x " << n 34 | << " Leading Dimension Size = " << ldb << std::endl; 35 | std::cout << "Matrix C: " << m << " x " << n 36 | << " Leading Dimension Size = " << ldc << std::endl; 37 | std::cout << std::endl; 38 | 39 | // Define all the GEMM kernel launch functions to be profiled. 
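    // Each entry below pairs a display name with a launch function. All the
    // launchers share the cuBLAS-style
    // (m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, stream) signature
    // declared in cuda_gemm.hpp, which lets profile_gemm (defined in
    // profile_utils.cuh, not shown in this listing) presumably time every
    // kernel the same way and check its output against a reference using the
    // absolute and relative tolerances defined above.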
40 | std::vector>> const 45 | gemm_kernel_launch_functions{ 46 | {"Custom GEMM Kernel V00", launch_gemm_kernel_v00}, 47 | {"Custom GEMM Kernel V01", launch_gemm_kernel_v01}, 48 | {"Custom GEMM Kernel V02", launch_gemm_kernel_v02}, 49 | {"Custom GEMM Kernel V02 Vectorized", 50 | launch_gemm_kernel_v02_vectorized}, 51 | {"Custom GEMM Kernel V03", launch_gemm_kernel_v03}, 52 | {"Custom GEMM Kernel V03 Vectorized", 53 | launch_gemm_kernel_v03_vectorized}, 54 | {"Custom GEMM Kernel V04", launch_gemm_kernel_v04}, 55 | {"Custom GEMM Kernel V04 Vectorized", 56 | launch_gemm_kernel_v04_vectorized}, 57 | {"Custom GEMM Kernel V05", launch_gemm_kernel_v05}, 58 | {"Custom GEMM Kernel V05 Vectorized", 59 | launch_gemm_kernel_v05_vectorized}, 60 | {"Custom GEMM Kernel V06", launch_gemm_kernel_v06}, 61 | {"Custom GEMM Kernel V06 Vectorized", 62 | launch_gemm_kernel_v06_vectorized}, 63 | {"Custom GEMM Kernel V06 Vectorized Double Buffered", 64 | launch_gemm_kernel_v06_vectorized_double_buffered}, 65 | }; 66 | 67 | for (auto const& gemm_kernel_launch_function : gemm_kernel_launch_functions) 68 | { 69 | std::cout << gemm_kernel_launch_function.first << std::endl; 70 | std::pair const gemm_kernel_profile_result{ 71 | profile_gemm( 72 | m, n, k, lda, ldb, ldc, gemm_kernel_launch_function.second, 73 | fp32_abs_tol, fp32_rel_tol, num_repeats, num_warmups)}; 74 | std::cout << std::endl; 75 | } 76 | 77 | return 0; 78 | } -------------------------------------------------------------------------------- /.clang-format: -------------------------------------------------------------------------------- 1 | --- 2 | Language: Cpp 3 | # BasedOnStyle: LLVM 4 | AccessModifierOffset: -4 # -2 5 | AlignAfterOpenBracket: Align 6 | AlignConsecutiveAssignments: false 7 | AlignConsecutiveDeclarations: false 8 | AlignEscapedNewlines: Right 9 | AlignOperands: true 10 | AlignTrailingComments: true 11 | AllowAllParametersOfDeclarationOnNextLine: true 12 | AllowShortBlocksOnASingleLine: false 13 | AllowShortCaseLabelsOnASingleLine: false 14 | AllowShortFunctionsOnASingleLine: All 15 | AllowShortIfStatementsOnASingleLine: false 16 | AllowShortLoopsOnASingleLine: false 17 | AlwaysBreakAfterDefinitionReturnType: None 18 | AlwaysBreakAfterReturnType: None 19 | AlwaysBreakBeforeMultilineStrings: false 20 | AlwaysBreakTemplateDeclarations: true # false 21 | BinPackArguments: true 22 | BinPackParameters: true 23 | BraceWrapping: 24 | AfterClass: true # false 25 | AfterControlStatement: true # false 26 | AfterEnum: true # false 27 | AfterFunction: true # false 28 | AfterNamespace: true # false 29 | AfterObjCDeclaration: true # false 30 | AfterStruct: true # false 31 | AfterUnion: true # false 32 | AfterExternBlock: true # false 33 | BeforeCatch: true # false 34 | BeforeElse: true # false 35 | IndentBraces: true # false 36 | SplitEmptyFunction: true 37 | SplitEmptyRecord: true 38 | SplitEmptyNamespace: true 39 | BreakBeforeBinaryOperators: None 40 | BreakBeforeBraces: Allman # Attach 41 | BreakBeforeInheritanceComma: false 42 | BreakBeforeTernaryOperators: true 43 | BreakConstructorInitializersBeforeComma: false 44 | BreakConstructorInitializers: BeforeColon 45 | BreakAfterJavaFieldAnnotations: false 46 | BreakStringLiterals: true 47 | ColumnLimit: 80 48 | CommentPragmas: '^ IWYU pragma:' 49 | CompactNamespaces: false 50 | ConstructorInitializerAllOnOneLineOrOnePerLine: false 51 | ConstructorInitializerIndentWidth: 4 52 | ContinuationIndentWidth: 4 53 | Cpp11BracedListStyle: true 54 | DerivePointerAlignment: false 55 | DisableFormat: false 56 | 
ExperimentalAutoDetectBinPacking: false 57 | FixNamespaceComments: true 58 | ForEachMacros: 59 | - foreach 60 | - Q_FOREACH 61 | - BOOST_FOREACH 62 | IncludeBlocks: Preserve 63 | IncludeCategories: 64 | - Regex: '^"(llvm|llvm-c|clang|clang-c)/' 65 | Priority: 2 66 | - Regex: '^(<|"(gtest|gmock|isl|json)/)' 67 | Priority: 3 68 | - Regex: '.*' 69 | Priority: 1 70 | IncludeIsMainRegex: '(Test)?$' 71 | IndentCaseLabels: true # false 72 | IndentPPDirectives: None 73 | IndentWidth: 4 #2 74 | IndentWrappedFunctionNames: false 75 | JavaScriptQuotes: Leave 76 | JavaScriptWrapImports: true 77 | KeepEmptyLinesAtTheStartOfBlocks: true 78 | MacroBlockBegin: '' 79 | MacroBlockEnd: '' 80 | MaxEmptyLinesToKeep: 1 81 | NamespaceIndentation: None 82 | ObjCBlockIndentWidth: 2 83 | ObjCSpaceAfterProperty: false 84 | ObjCSpaceBeforeProtocolList: true 85 | PenaltyBreakAssignment: 2 86 | PenaltyBreakBeforeFirstCallParameter: 19 87 | PenaltyBreakComment: 300 88 | PenaltyBreakFirstLessLess: 120 89 | PenaltyBreakString: 1000 90 | PenaltyExcessCharacter: 1000000 91 | PenaltyReturnTypeOnItsOwnLine: 60 92 | PointerAlignment: Left # Right 93 | ReflowComments: true 94 | SortIncludes: true 95 | SortUsingDeclarations: true 96 | SpaceAfterCStyleCast: false 97 | SpaceAfterTemplateKeyword: true 98 | SpaceBeforeAssignmentOperators: true 99 | SpaceBeforeParens: ControlStatements 100 | SpaceInEmptyParentheses: false 101 | SpacesBeforeTrailingComments: 1 102 | SpacesInAngles: false 103 | SpacesInContainerLiterals: true 104 | SpacesInCStyleCastParentheses: false 105 | SpacesInParentheses: false 106 | SpacesInSquareBrackets: false 107 | Standard: Cpp11 108 | TabWidth: 8 109 | UseTab: Never 110 | ... 111 | -------------------------------------------------------------------------------- /src/profile_cuda_gemm_fp16.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "cuda_gemm.hpp" 5 | #include "profile_utils.cuh" 6 | 7 | int main() 8 | { 9 | print_device_info(); 10 | 11 | constexpr size_t num_repeats{1U}; 12 | constexpr size_t num_warmups{1U}; 13 | 14 | __half const fp16_abs_tol{__float2half(5.0e-2f)}; 15 | double const fp16_rel_tol{1.0e-1f}; 16 | 17 | __half const fp16_tensor_core_abs_tol{__float2half(5.0e-2f)}; 18 | double const fp16_tensor_core_rel_tol{1.0e-2f}; 19 | 20 | constexpr size_t m{4096U}; 21 | constexpr size_t k{4096U}; 22 | constexpr size_t n{4096U}; 23 | 24 | constexpr size_t lda{(k + 16U - 1U) / 16U * 16U}; 25 | constexpr size_t ldb{(n + 16U - 1U) / 16U * 16U}; 26 | constexpr size_t ldc{(n + 16U - 1U) / 16U * 16U}; 27 | 28 | static_assert(lda >= k); 29 | static_assert(ldb >= n); 30 | static_assert(ldc >= n); 31 | 32 | std::cout << "Matrix Size: " << "M = " << m << " N = " << n << " K = " << k 33 | << std::endl; 34 | std::cout << "Matrix A: " << m << " x " << k 35 | << " Leading Dimension Size = " << lda << std::endl; 36 | std::cout << "Matrix B: " << k << " x " << n 37 | << " Leading Dimension Size = " << ldb << std::endl; 38 | std::cout << "Matrix C: " << m << " x " << n 39 | << " Leading Dimension Size = " << ldc << std::endl; 40 | std::cout << std::endl; 41 | 42 | // Define all the GEMM kernel launch functions to be profiled. 
43 | std::vector>> const 48 | gemm_fp16_kernel_launch_functions{ 49 | {"Custom GEMM Kernel V00", launch_gemm_kernel_v00<__half>}, 50 | {"Custom GEMM Kernel V01", launch_gemm_kernel_v01<__half>}, 51 | {"Custom GEMM Kernel V02", launch_gemm_kernel_v02<__half>}, 52 | {"Custom GEMM Kernel V02 Vectorized", 53 | launch_gemm_kernel_v02_vectorized<__half>}, 54 | {"Custom GEMM Kernel V03", launch_gemm_kernel_v03<__half>}, 55 | {"Custom GEMM Kernel V03 Vectorized", 56 | launch_gemm_kernel_v03_vectorized<__half>}, 57 | {"Custom GEMM Kernel V04", launch_gemm_kernel_v04<__half>}, 58 | {"Custom GEMM Kernel V04 Vectorized", 59 | launch_gemm_kernel_v04_vectorized<__half>}, 60 | {"Custom GEMM Kernel V05", launch_gemm_kernel_v05<__half>}, 61 | {"Custom GEMM Kernel V05 Vectorized", 62 | launch_gemm_kernel_v05_vectorized<__half>}, 63 | {"Custom GEMM Kernel V06", launch_gemm_kernel_v06<__half>}, 64 | {"Custom GEMM Kernel V06 Vectorized", 65 | launch_gemm_kernel_v06_vectorized<__half>}, 66 | {"Custom GEMM Kernel V06 Vectorized Double Buffered", 67 | launch_gemm_kernel_v06_vectorized_double_buffered<__half>}, 68 | }; 69 | 70 | for (auto const& gemm_fp16_kernel_launch_function : 71 | gemm_fp16_kernel_launch_functions) 72 | { 73 | std::cout << gemm_fp16_kernel_launch_function.first << std::endl; 74 | std::pair<__half, __half> const gemm_kernel_profile_result{ 75 | profile_gemm<__half>( 76 | m, n, k, lda, ldb, ldc, gemm_fp16_kernel_launch_function.second, 77 | fp16_abs_tol, fp16_rel_tol, num_repeats, num_warmups)}; 78 | std::cout << std::endl; 79 | } 80 | 81 | std::vector>> const 86 | gemm_fp16_tensor_core_kernel_launch_functions{ 87 | {"Custom GEMM Kernel V07", launch_gemm_kernel_v07<__half>}, 88 | {"Custom GEMM Kernel V07 Vectorized", 89 | launch_gemm_kernel_v07_vectorized<__half>}, 90 | {"Custom GEMM Kernel V07 Vectorized Double Buffered", 91 | launch_gemm_kernel_v07_vectorized_double_buffered<__half>}, 92 | }; 93 | 94 | for (auto const& gemm_fp16_tensor_core_kernel_launch_function : 95 | gemm_fp16_tensor_core_kernel_launch_functions) 96 | { 97 | std::cout << gemm_fp16_tensor_core_kernel_launch_function.first 98 | << std::endl; 99 | std::pair<__half, __half> const gemm_kernel_profile_result{ 100 | profile_gemm<__half>( 101 | m, n, k, lda, ldb, ldc, 102 | gemm_fp16_tensor_core_kernel_launch_function.second, 103 | fp16_tensor_core_abs_tol, fp16_tensor_core_rel_tol, num_repeats, 104 | num_warmups)}; 105 | std::cout << std::endl; 106 | } 107 | 108 | return 0; 109 | } -------------------------------------------------------------------------------- /include/cuda_gemm.hpp: -------------------------------------------------------------------------------- 1 | #ifndef CUDA_GEMM_HPP 2 | #define CUDA_GEMM_HPP 3 | 4 | #include 5 | 6 | template 7 | void launch_gemm_kernel_v00(size_t m, size_t n, size_t k, T const* alpha, 8 | T const* A, size_t lda, T const* B, size_t ldb, 9 | T const* beta, T* C, size_t ldc, 10 | cudaStream_t stream); 11 | 12 | template 13 | void launch_gemm_kernel_v01(size_t m, size_t n, size_t k, T const* alpha, 14 | T const* A, size_t lda, T const* B, size_t ldb, 15 | T const* beta, T* C, size_t ldc, 16 | cudaStream_t stream); 17 | 18 | template 19 | void launch_gemm_kernel_v02(size_t m, size_t n, size_t k, T const* alpha, 20 | T const* A, size_t lda, T const* B, size_t ldb, 21 | T const* beta, T* C, size_t ldc, 22 | cudaStream_t stream); 23 | 24 | template 25 | void launch_gemm_kernel_v02_vectorized(size_t m, size_t n, size_t k, 26 | T const* alpha, T const* A, size_t lda, 27 | T const* B, size_t ldb, T 
const* beta, 28 | T* C, size_t ldc, cudaStream_t stream); 29 | 30 | template 31 | void launch_gemm_kernel_v03(size_t m, size_t n, size_t k, T const* alpha, 32 | T const* A, size_t lda, T const* B, size_t ldb, 33 | T const* beta, T* C, size_t ldc, 34 | cudaStream_t stream); 35 | 36 | template 37 | void launch_gemm_kernel_v03_vectorized(size_t m, size_t n, size_t k, 38 | T const* alpha, T const* A, size_t lda, 39 | T const* B, size_t ldb, T const* beta, 40 | T* C, size_t ldc, cudaStream_t stream); 41 | template 42 | void launch_gemm_kernel_v04(size_t m, size_t n, size_t k, T const* alpha, 43 | T const* A, size_t lda, T const* B, size_t ldb, 44 | T const* beta, T* C, size_t ldc, 45 | cudaStream_t stream); 46 | 47 | template 48 | void launch_gemm_kernel_v04_vectorized(size_t m, size_t n, size_t k, 49 | T const* alpha, T const* A, size_t lda, 50 | T const* B, size_t ldb, T const* beta, 51 | T* C, size_t ldc, cudaStream_t stream); 52 | 53 | template 54 | void launch_gemm_kernel_v05(size_t m, size_t n, size_t k, T const* alpha, 55 | T const* A, size_t lda, T const* B, size_t ldb, 56 | T const* beta, T* C, size_t ldc, 57 | cudaStream_t stream); 58 | template 59 | void launch_gemm_kernel_v05_vectorized(size_t m, size_t n, size_t k, 60 | T const* alpha, T const* A, size_t lda, 61 | T const* B, size_t ldb, T const* beta, 62 | T* C, size_t ldc, cudaStream_t stream); 63 | 64 | template 65 | void launch_gemm_kernel_v06(size_t m, size_t n, size_t k, T const* alpha, 66 | T const* A, size_t lda, T const* B, size_t ldb, 67 | T const* beta, T* C, size_t ldc, 68 | cudaStream_t stream); 69 | template 70 | void launch_gemm_kernel_v06_vectorized(size_t m, size_t n, size_t k, 71 | T const* alpha, T const* A, size_t lda, 72 | T const* B, size_t ldb, T const* beta, 73 | T* C, size_t ldc, cudaStream_t stream); 74 | 75 | template 76 | void launch_gemm_kernel_v06_vectorized_double_buffered( 77 | size_t m, size_t n, size_t k, T const* alpha, T const* A, size_t lda, 78 | T const* B, size_t ldb, T const* beta, T* C, size_t ldc, 79 | cudaStream_t stream); 80 | 81 | template 82 | void launch_gemm_kernel_v07(size_t m, size_t n, size_t k, T const* alpha, 83 | T const* A, size_t lda, T const* B, size_t ldb, 84 | T const* beta, T* C, size_t ldc, 85 | cudaStream_t stream); 86 | 87 | template 88 | void launch_gemm_kernel_v07_vectorized(size_t m, size_t n, size_t k, 89 | T const* alpha, T const* A, size_t lda, 90 | T const* B, size_t ldb, T const* beta, 91 | T* C, size_t ldc, cudaStream_t stream); 92 | 93 | template 94 | void launch_gemm_kernel_v07_vectorized_double_buffered( 95 | size_t m, size_t n, size_t k, T const* alpha, T const* A, size_t lda, 96 | T const* B, size_t ldb, T const* beta, T* C, size_t ldc, 97 | cudaStream_t stream); 98 | #endif -------------------------------------------------------------------------------- /src/02_2d_block_tiling_vectorized_memory_access.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "cuda_gemm.hpp" 4 | #include "cuda_gemm_utils.cuh" 5 | #include "cuda_gemm_utils.hpp" 6 | 7 | // GEMM kernel v02. 8 | // Coalesced read and write from global memory. 9 | // We guarantee that matrix A, B, and C are 32 byte aligned. 10 | // This implementation is slower because we waste a lot of threads. 
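// The "vectorized" staging below is performed by
// load_data_from_global_memory_to_shared_memory_vectorized, which comes from
// cuda_gemm_utils.cuh (not shown in this listing) and presumably copies
// several elements per thread with wide loads instead of one scalar element
// at a time; the alignment guarantee stated above is what makes such wide
// accesses legal.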
11 | template 13 | __global__ void gemm_v02_vectorized(size_t m, size_t n, size_t k, T alpha, 14 | T const* A, size_t lda, T const* B, 15 | size_t ldb, T beta, T* C, size_t ldc) 16 | { 17 | // Avoid using blockDim.x * blockDim.y as the number of threads per block. 18 | // Because it is a runtime constant and the compiler cannot optimize the 19 | // loop unrolling based on that. 20 | // Use a compile time constant instead. 21 | constexpr size_t NUM_THREADS{BLOCK_TILE_SIZE_X * BLOCK_TILE_SIZE_Y}; 22 | size_t const thread_linear_idx{threadIdx.y * blockDim.x + threadIdx.x}; 23 | 24 | // Compute the row and column of C that this thread is responsible for. 25 | size_t const C_col_idx{blockIdx.x * blockDim.x + threadIdx.x}; 26 | size_t const C_row_idx{blockIdx.y * blockDim.y + threadIdx.y}; 27 | 28 | // Cache a tile of A and B in shared memory for data reuse. 29 | __shared__ T A_thread_block_tile[BLOCK_TILE_SIZE_Y][BLOCK_TILE_SIZE_K]; 30 | __shared__ T B_thread_block_tile[BLOCK_TILE_SIZE_K][BLOCK_TILE_SIZE_X]; 31 | 32 | size_t const num_thread_block_tiles{(k + BLOCK_TILE_SIZE_K - 1) / 33 | BLOCK_TILE_SIZE_K}; 34 | 35 | T sum{static_cast(0)}; 36 | for (size_t thread_block_tile_idx{0U}; 37 | thread_block_tile_idx < num_thread_block_tiles; 38 | ++thread_block_tile_idx) 39 | { 40 | load_data_from_global_memory_to_shared_memory_vectorized< 41 | T, BLOCK_TILE_SIZE_X, BLOCK_TILE_SIZE_Y, BLOCK_TILE_SIZE_K, 42 | NUM_THREADS>(A, lda, B, ldb, A_thread_block_tile, 43 | B_thread_block_tile, thread_block_tile_idx, 44 | thread_linear_idx, m, n, k); 45 | __syncthreads(); 46 | 47 | #pragma unroll 48 | for (size_t k_i{0U}; k_i < BLOCK_TILE_SIZE_K; ++k_i) 49 | { 50 | // Doing this results in 2 TOPS. 51 | // Suppose blockDim.x = blockDim.y = 32. 52 | // Effectively, for a warp, in one iteration, we read the value from 53 | // A_thread_block_tile at the same location on the shared memory 54 | // resulting in a broadcast, we also read 32 values that have no 55 | // bank conflicts from B_thread_block_tile. Even with that, all the 56 | // values have to be read from the shared memory and consequence is 57 | // the shared memory instruction runs very intensively just to 58 | // compute a small number of values using simple arithmetic 59 | // instructions, which is not efficient. 60 | sum += A_thread_block_tile[threadIdx.y][k_i] * 61 | B_thread_block_tile[k_i][threadIdx.x]; 62 | } 63 | __syncthreads(); 64 | } 65 | if (C_row_idx < m && C_col_idx < n) 66 | { 67 | C[C_row_idx * ldc + C_col_idx] = 68 | alpha * sum + beta * C[C_row_idx * ldc + C_col_idx]; 69 | } 70 | } 71 | 72 | template 73 | void launch_gemm_kernel_v02_vectorized(size_t m, size_t n, size_t k, 74 | T const* alpha, T const* A, size_t lda, 75 | T const* B, size_t ldb, T const* beta, 76 | T* C, size_t ldc, cudaStream_t stream) 77 | { 78 | // Feel free to play with the block tile sizes. 79 | // The algorithm correctness should always be guaranteed. 
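    // The two shared-memory tiles declared in the kernel occupy
    // (BLOCK_TILE_SIZE_Y * BLOCK_TILE_SIZE_K +
    //  BLOCK_TILE_SIZE_K * BLOCK_TILE_SIZE_X) * sizeof(T) bytes per block, so
    // larger tiles buy more data reuse at the cost of occupancy. The
    // static_asserts below additionally require each tile's element count to
    // be divisible by NUM_THREADS so that the cooperative loads assign the
    // same number of elements to every thread.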
80 | constexpr unsigned int BLOCK_TILE_SIZE_X{32U}; 81 | constexpr unsigned int BLOCK_TILE_SIZE_Y{32U}; 82 | constexpr unsigned int BLOCK_TILE_SIZE_K{32U}; 83 | constexpr unsigned int NUM_THREADS{BLOCK_TILE_SIZE_X * BLOCK_TILE_SIZE_Y}; 84 | static_assert(BLOCK_TILE_SIZE_K * BLOCK_TILE_SIZE_Y % NUM_THREADS == 0U); 85 | static_assert(BLOCK_TILE_SIZE_X * BLOCK_TILE_SIZE_K % NUM_THREADS == 0U); 86 | dim3 const block_dim{BLOCK_TILE_SIZE_X, BLOCK_TILE_SIZE_Y, 1U}; 87 | dim3 const grid_dim{ 88 | (static_cast(n) + block_dim.x - 1U) / block_dim.x, 89 | (static_cast(m) + block_dim.y - 1U) / block_dim.y, 1U}; 90 | gemm_v02_vectorized<<>>( 92 | m, n, k, *alpha, A, lda, B, ldb, *beta, C, ldc); 93 | CHECK_LAST_CUDA_ERROR(); 94 | } 95 | 96 | // Explicit instantiation. 97 | template void launch_gemm_kernel_v02_vectorized( 98 | size_t m, size_t n, size_t k, float const* alpha, float const* A, 99 | size_t lda, float const* B, size_t ldb, float const* beta, float* C, 100 | size_t ldc, cudaStream_t stream); 101 | template void launch_gemm_kernel_v02_vectorized( 102 | size_t m, size_t n, size_t k, double const* alpha, double const* A, 103 | size_t lda, double const* B, size_t ldb, double const* beta, double* C, 104 | size_t ldc, cudaStream_t stream); 105 | template void launch_gemm_kernel_v02_vectorized<__half>( 106 | size_t m, size_t n, size_t k, __half const* alpha, __half const* A, 107 | size_t lda, __half const* B, size_t ldb, __half const* beta, __half* C, 108 | size_t ldc, cudaStream_t stream); -------------------------------------------------------------------------------- /src/02_2d_block_tiling.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "cuda_gemm.hpp" 4 | #include "cuda_gemm_utils.cuh" 5 | #include "cuda_gemm_utils.hpp" 6 | 7 | // GEMM kernel v02. 8 | // Coalesced read and write from global memory. 9 | template 11 | __global__ void gemm_v02(size_t m, size_t n, size_t k, T alpha, T const* A, 12 | size_t lda, T const* B, size_t ldb, T beta, T* C, 13 | size_t ldc) 14 | { 15 | // Avoid using blockDim.x * blockDim.y as the number of threads per block. 16 | // Because it is a runtime constant and the compiler cannot optimize the 17 | // loop unrolling based on that. 18 | // Use a compile time constant instead. 19 | constexpr size_t NUM_THREADS{BLOCK_TILE_SIZE_X * BLOCK_TILE_SIZE_Y}; 20 | size_t const thread_linear_idx{threadIdx.y * blockDim.x + threadIdx.x}; 21 | 22 | // Compute the row and column of C that this thread is responsible for. 23 | size_t const C_col_idx{blockIdx.x * blockDim.x + threadIdx.x}; 24 | size_t const C_row_idx{blockIdx.y * blockDim.y + threadIdx.y}; 25 | 26 | // Cache a tile of A and B in shared memory for data reuse. 
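    // Data-reuse arithmetic for the 2D block tiling: every element of the A
    // tile staged below is read from shared memory by the BLOCK_TILE_SIZE_X
    // threads that share its row of C, and every element of the B tile by the
    // BLOCK_TILE_SIZE_Y threads that share its column. Each output element
    // therefore costs roughly k / BLOCK_TILE_SIZE_X + k / BLOCK_TILE_SIZE_Y
    // global memory reads instead of the 2 * k reads per output of gemm_v00
    // and gemm_v01.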
27 | __shared__ T A_thread_block_tile[BLOCK_TILE_SIZE_Y][BLOCK_TILE_SIZE_K]; 28 | __shared__ T B_thread_block_tile[BLOCK_TILE_SIZE_K][BLOCK_TILE_SIZE_X]; 29 | 30 | size_t const num_thread_block_tiles{(k + BLOCK_TILE_SIZE_K - 1) / 31 | BLOCK_TILE_SIZE_K}; 32 | 33 | T sum{static_cast(0)}; 34 | for (size_t thread_block_tile_idx{0U}; 35 | thread_block_tile_idx < num_thread_block_tiles; 36 | ++thread_block_tile_idx) 37 | { 38 | load_data_from_global_memory_to_shared_memory< 39 | T, BLOCK_TILE_SIZE_X, BLOCK_TILE_SIZE_Y, BLOCK_TILE_SIZE_K, 40 | NUM_THREADS>(A, lda, B, ldb, A_thread_block_tile, 41 | B_thread_block_tile, thread_block_tile_idx, 42 | thread_linear_idx, m, n, k); 43 | __syncthreads(); 44 | 45 | #pragma unroll 46 | for (size_t k_i{0U}; k_i < BLOCK_TILE_SIZE_K; ++k_i) 47 | { 48 | // Doing this results in 2 TOPS. 49 | // Suppose blockDim.x = blockDim.y = 32. 50 | // Effectively, for a warp, in one iteration, we read the value from 51 | // A_thread_block_tile at the same location on the shared memory 52 | // resulting in a broadcast, we also read 32 values that have no 53 | // bank conflicts from B_thread_block_tile. Even with that, all the 54 | // values have to be read from the shared memory and consequence is 55 | // the shared memory instruction runs very intensively just to 56 | // compute a small number of values using simple arithmetic 57 | // instructions, which is not efficient. 58 | sum += A_thread_block_tile[threadIdx.y][k_i] * 59 | B_thread_block_tile[k_i][threadIdx.x]; 60 | } 61 | __syncthreads(); 62 | } 63 | if (C_row_idx < m && C_col_idx < n) 64 | { 65 | C[C_row_idx * ldc + C_col_idx] = 66 | alpha * sum + beta * C[C_row_idx * ldc + C_col_idx]; 67 | } 68 | } 69 | 70 | template 71 | void launch_gemm_kernel_v02(size_t m, size_t n, size_t k, T const* alpha, 72 | T const* A, size_t lda, T const* B, size_t ldb, 73 | T const* beta, T* C, size_t ldc, 74 | cudaStream_t stream) 75 | { 76 | // Feel free to play with the block tile sizes. 77 | // The algorithm correctness should always be guaranteed. 78 | constexpr unsigned int BLOCK_TILE_SIZE_X{32U}; 79 | constexpr unsigned int BLOCK_TILE_SIZE_Y{32U}; 80 | constexpr unsigned int BLOCK_TILE_SIZE_K{32U}; 81 | constexpr unsigned int NUM_THREADS{BLOCK_TILE_SIZE_X * BLOCK_TILE_SIZE_Y}; 82 | static_assert(BLOCK_TILE_SIZE_K * BLOCK_TILE_SIZE_Y % NUM_THREADS == 0U); 83 | static_assert(BLOCK_TILE_SIZE_X * BLOCK_TILE_SIZE_K % NUM_THREADS == 0U); 84 | dim3 const block_dim{BLOCK_TILE_SIZE_X, BLOCK_TILE_SIZE_Y, 1U}; 85 | dim3 const grid_dim{ 86 | (static_cast(n) + block_dim.x - 1U) / block_dim.x, 87 | (static_cast(m) + block_dim.y - 1U) / block_dim.y, 1U}; 88 | gemm_v02 89 | <<>>(m, n, k, *alpha, A, lda, B, ldb, 90 | *beta, C, ldc); 91 | CHECK_LAST_CUDA_ERROR(); 92 | } 93 | 94 | // Explicit instantiation. 
95 | template void launch_gemm_kernel_v02(size_t m, size_t n, size_t k, 96 | float const* alpha, float const* A, 97 | size_t lda, float const* B, 98 | size_t ldb, float const* beta, 99 | float* C, size_t ldc, 100 | cudaStream_t stream); 101 | template void launch_gemm_kernel_v02(size_t m, size_t n, size_t k, 102 | double const* alpha, 103 | double const* A, size_t lda, 104 | double const* B, size_t ldb, 105 | double const* beta, double* C, 106 | size_t ldc, cudaStream_t stream); 107 | template void launch_gemm_kernel_v02<__half>(size_t m, size_t n, size_t k, 108 | __half const* alpha, 109 | __half const* A, size_t lda, 110 | __half const* B, size_t ldb, 111 | __half const* beta, __half* C, 112 | size_t ldc, cudaStream_t stream); -------------------------------------------------------------------------------- /src/03_2d_block_tiling_1d_thread_tiling_vectorized_memory_access.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "cuda_gemm.hpp" 4 | #include "cuda_gemm_utils.cuh" 5 | #include "cuda_gemm_utils.hpp" 6 | 7 | // GEMM kernel v03. 8 | // Coalesced read and write from global memory. 9 | template 11 | __global__ void gemm_v03_vectorized(size_t m, size_t n, size_t k, T alpha, 12 | T const* A, size_t lda, T const* B, 13 | size_t ldb, T beta, T* C, size_t ldc) 14 | { 15 | // Avoid using blockDim.x * blockDim.y as the number of threads per block. 16 | // Because it is a runtime constant and the compiler cannot optimize the 17 | // loop unrolling based on that. 18 | // Use a compile time constant instead. 19 | constexpr size_t NUM_THREADS{BLOCK_TILE_SIZE_X * BLOCK_TILE_SIZE_Y / 20 | THREAD_TILE_SIZE_Y}; 21 | size_t const thread_linear_idx{threadIdx.y * blockDim.x + threadIdx.x}; 22 | 23 | // Cache a tile of A and B in shared memory for data reuse. 24 | __shared__ T A_thread_block_tile[BLOCK_TILE_SIZE_Y][BLOCK_TILE_SIZE_K]; 25 | __shared__ T B_thread_block_tile[BLOCK_TILE_SIZE_K][BLOCK_TILE_SIZE_X]; 26 | 27 | size_t const num_thread_block_tiles{(k + BLOCK_TILE_SIZE_K - 1) / 28 | BLOCK_TILE_SIZE_K}; 29 | 30 | // Each thread in the block processes BLOCK_TILE_SIZE_Y output values. 31 | // Specifically, these values corresponds to 32 | // C[blockIdx.y * BLOCK_TILE_SIZE_Y + threadIdx.x / BLOCK_TILE_SIZE_X * 33 | // THREAD_TILE_SIZE_Y : blockIdx.y * BLOCK_TILE_SIZE_Y + (threadIdx.x / 34 | // BLOCK_TILE_SIZE_X + 1) * THREAD_TILE_SIZE_Y][blockIdx.x * 35 | // BLOCK_TILE_SIZE_X + threadIdx.x % BLOCK_TILE_SIZE_X] 36 | T C_thread_results[THREAD_TILE_SIZE_Y] = {static_cast(0)}; 37 | 38 | for (size_t thread_block_tile_idx{0U}; 39 | thread_block_tile_idx < num_thread_block_tiles; 40 | ++thread_block_tile_idx) 41 | { 42 | load_data_from_global_memory_to_shared_memory_vectorized< 43 | T, BLOCK_TILE_SIZE_X, BLOCK_TILE_SIZE_Y, BLOCK_TILE_SIZE_K, 44 | NUM_THREADS>(A, lda, B, ldb, A_thread_block_tile, 45 | B_thread_block_tile, thread_block_tile_idx, 46 | thread_linear_idx, m, n, k); 47 | __syncthreads(); 48 | 49 | #pragma unroll 50 | for (size_t k_i{0U}; k_i < BLOCK_TILE_SIZE_K; ++k_i) 51 | { 52 | size_t const B_thread_block_tile_row_idx{k_i}; 53 | // B_val is cached in the register to alleviate the pressure on the 54 | // shared memory access. 
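            // All THREAD_TILE_SIZE_Y partial results computed below reuse this
            // single B value, so keeping it in a register cuts the
            // shared-memory reads per multiply-add from two (one A element and
            // one B element) to roughly 1 + 1 / THREAD_TILE_SIZE_Y.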
55 | T const B_val{ 56 | B_thread_block_tile[B_thread_block_tile_row_idx] 57 | [thread_linear_idx % BLOCK_TILE_SIZE_X]}; 58 | #pragma unroll 59 | for (size_t thread_tile_row_idx{0U}; 60 | thread_tile_row_idx < THREAD_TILE_SIZE_Y; 61 | ++thread_tile_row_idx) 62 | { 63 | size_t const A_thread_block_tile_row_idx{ 64 | thread_linear_idx / BLOCK_TILE_SIZE_X * THREAD_TILE_SIZE_Y + 65 | thread_tile_row_idx}; 66 | size_t const A_thread_block_tile_col_idx{k_i}; 67 | T const A_val{A_thread_block_tile[A_thread_block_tile_row_idx] 68 | [A_thread_block_tile_col_idx]}; 69 | C_thread_results[thread_tile_row_idx] += A_val * B_val; 70 | } 71 | } 72 | __syncthreads(); 73 | } 74 | 75 | // Write the results to DRAM. 76 | // Cannot vectorized the write to DRAM because we are writting to a column 77 | // instead of a row in C. 78 | #pragma unroll 79 | for (size_t thread_tile_row_idx{0U}; 80 | thread_tile_row_idx < THREAD_TILE_SIZE_Y; ++thread_tile_row_idx) 81 | { 82 | size_t const C_row_idx{blockIdx.y * BLOCK_TILE_SIZE_Y + 83 | thread_linear_idx / BLOCK_TILE_SIZE_X * 84 | THREAD_TILE_SIZE_Y + 85 | thread_tile_row_idx}; 86 | size_t const C_col_idx{blockIdx.x * BLOCK_TILE_SIZE_X + 87 | thread_linear_idx % BLOCK_TILE_SIZE_X}; 88 | if (C_row_idx < m && C_col_idx < n) 89 | { 90 | C[C_row_idx * ldc + C_col_idx] = 91 | alpha * C_thread_results[thread_tile_row_idx] + 92 | beta * C[C_row_idx * ldc + C_col_idx]; 93 | } 94 | } 95 | } 96 | 97 | template 98 | void launch_gemm_kernel_v03_vectorized(size_t m, size_t n, size_t k, 99 | T const* alpha, T const* A, size_t lda, 100 | T const* B, size_t ldb, T const* beta, 101 | T* C, size_t ldc, cudaStream_t stream) 102 | { 103 | // Feel free to play with the block tile sizes. 104 | // The algorithm correctness should always be guaranteed. 105 | constexpr unsigned int BLOCK_TILE_SIZE_X{64U}; 106 | constexpr unsigned int BLOCK_TILE_SIZE_Y{64U}; 107 | constexpr unsigned int BLOCK_TILE_SIZE_K{8U}; 108 | // Each thread computes THREAD_TILE_SIZE_Y values of C. 109 | constexpr unsigned int THREAD_TILE_SIZE_Y{8U}; 110 | constexpr unsigned int NUM_THREADS_PER_BLOCK{ 111 | BLOCK_TILE_SIZE_X * BLOCK_TILE_SIZE_Y / THREAD_TILE_SIZE_Y}; 112 | static_assert(BLOCK_TILE_SIZE_Y % THREAD_TILE_SIZE_Y == 0U); 113 | static_assert(NUM_THREADS_PER_BLOCK % BLOCK_TILE_SIZE_K == 0U); 114 | static_assert(NUM_THREADS_PER_BLOCK % BLOCK_TILE_SIZE_X == 0U); 115 | dim3 const block_dim{NUM_THREADS_PER_BLOCK, 1U, 1U}; 116 | dim3 const grid_dim{ 117 | (static_cast(n) + BLOCK_TILE_SIZE_X - 1U) / 118 | BLOCK_TILE_SIZE_X, 119 | (static_cast(m) + BLOCK_TILE_SIZE_Y - 1U) / 120 | BLOCK_TILE_SIZE_Y, 121 | 1U}; 122 | gemm_v03_vectorized 124 | <<>>(m, n, k, *alpha, A, lda, B, ldb, 125 | *beta, C, ldc); 126 | CHECK_LAST_CUDA_ERROR(); 127 | } 128 | 129 | // Explicit instantiation. 
130 | template void launch_gemm_kernel_v03_vectorized( 131 | size_t m, size_t n, size_t k, float const* alpha, float const* A, 132 | size_t lda, float const* B, size_t ldb, float const* beta, float* C, 133 | size_t ldc, cudaStream_t stream); 134 | template void launch_gemm_kernel_v03_vectorized( 135 | size_t m, size_t n, size_t k, double const* alpha, double const* A, 136 | size_t lda, double const* B, size_t ldb, double const* beta, double* C, 137 | size_t ldc, cudaStream_t stream); 138 | template void launch_gemm_kernel_v03_vectorized<__half>( 139 | size_t m, size_t n, size_t k, __half const* alpha, __half const* A, 140 | size_t lda, __half const* B, size_t ldb, __half const* beta, __half* C, 141 | size_t ldc, cudaStream_t stream); -------------------------------------------------------------------------------- /src/03_2d_block_tiling_1d_thread_tiling.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "cuda_gemm.hpp" 4 | #include "cuda_gemm_utils.cuh" 5 | #include "cuda_gemm_utils.hpp" 6 | 7 | // GEMM kernel v03. 8 | // Coalesced read and write from global memory. 9 | template 11 | __global__ void gemm_v03(size_t m, size_t n, size_t k, T alpha, T const* A, 12 | size_t lda, T const* B, size_t ldb, T beta, T* C, 13 | size_t ldc) 14 | { 15 | // Avoid using blockDim.x * blockDim.y as the number of threads per block. 16 | // Because it is a runtime constant and the compiler cannot optimize the 17 | // loop unrolling based on that. 18 | // Use a compile time constant instead. 19 | constexpr size_t NUM_THREADS{BLOCK_TILE_SIZE_X * BLOCK_TILE_SIZE_Y / 20 | THREAD_TILE_SIZE_Y}; 21 | size_t const thread_linear_idx{threadIdx.y * blockDim.x + threadIdx.x}; 22 | 23 | // Cache a tile of A and B in shared memory for data reuse. 24 | __shared__ T A_thread_block_tile[BLOCK_TILE_SIZE_Y][BLOCK_TILE_SIZE_K]; 25 | __shared__ T B_thread_block_tile[BLOCK_TILE_SIZE_K][BLOCK_TILE_SIZE_X]; 26 | 27 | size_t const num_thread_block_tiles{(k + BLOCK_TILE_SIZE_K - 1) / 28 | BLOCK_TILE_SIZE_K}; 29 | 30 | // Each thread in the block processes BLOCK_TILE_SIZE_Y output values. 31 | // Specifically, these values corresponds to 32 | // C[blockIdx.y * BLOCK_TILE_SIZE_Y + threadIdx.x / BLOCK_TILE_SIZE_X * 33 | // THREAD_TILE_SIZE_Y : blockIdx.y * BLOCK_TILE_SIZE_Y + (threadIdx.x / 34 | // BLOCK_TILE_SIZE_X + 1) * THREAD_TILE_SIZE_Y][blockIdx.x * 35 | // BLOCK_TILE_SIZE_X + threadIdx.x % BLOCK_TILE_SIZE_X] 36 | T C_thread_results[THREAD_TILE_SIZE_Y] = {static_cast(0)}; 37 | 38 | for (size_t thread_block_tile_idx{0U}; 39 | thread_block_tile_idx < num_thread_block_tiles; 40 | ++thread_block_tile_idx) 41 | { 42 | load_data_from_global_memory_to_shared_memory< 43 | T, BLOCK_TILE_SIZE_X, BLOCK_TILE_SIZE_Y, BLOCK_TILE_SIZE_K, 44 | NUM_THREADS>(A, lda, B, ldb, A_thread_block_tile, 45 | B_thread_block_tile, thread_block_tile_idx, 46 | thread_linear_idx, m, n, k); 47 | __syncthreads(); 48 | 49 | #pragma unroll 50 | for (size_t k_i{0U}; k_i < BLOCK_TILE_SIZE_K; ++k_i) 51 | { 52 | size_t const B_thread_block_tile_row_idx{k_i}; 53 | // B_val is cached in the register to alleviate the pressure on the 54 | // shared memory access. 
55 | T const B_val{ 56 | B_thread_block_tile[B_thread_block_tile_row_idx] 57 | [thread_linear_idx % BLOCK_TILE_SIZE_X]}; 58 | #pragma unroll 59 | for (size_t thread_tile_row_idx{0U}; 60 | thread_tile_row_idx < THREAD_TILE_SIZE_Y; 61 | ++thread_tile_row_idx) 62 | { 63 | size_t const A_thread_block_tile_row_idx{ 64 | thread_linear_idx / BLOCK_TILE_SIZE_X * THREAD_TILE_SIZE_Y + 65 | thread_tile_row_idx}; 66 | size_t const A_thread_block_tile_col_idx{k_i}; 67 | T const A_val{A_thread_block_tile[A_thread_block_tile_row_idx] 68 | [A_thread_block_tile_col_idx]}; 69 | C_thread_results[thread_tile_row_idx] += A_val * B_val; 70 | } 71 | } 72 | __syncthreads(); 73 | } 74 | 75 | // Write the results to DRAM. 76 | #pragma unroll 77 | for (size_t thread_tile_row_idx{0U}; 78 | thread_tile_row_idx < THREAD_TILE_SIZE_Y; ++thread_tile_row_idx) 79 | { 80 | size_t const C_row_idx{blockIdx.y * BLOCK_TILE_SIZE_Y + 81 | thread_linear_idx / BLOCK_TILE_SIZE_X * 82 | THREAD_TILE_SIZE_Y + 83 | thread_tile_row_idx}; 84 | size_t const C_col_idx{blockIdx.x * BLOCK_TILE_SIZE_X + 85 | thread_linear_idx % BLOCK_TILE_SIZE_X}; 86 | if (C_row_idx < m && C_col_idx < n) 87 | { 88 | C[C_row_idx * ldc + C_col_idx] = 89 | alpha * C_thread_results[thread_tile_row_idx] + 90 | beta * C[C_row_idx * ldc + C_col_idx]; 91 | } 92 | } 93 | } 94 | 95 | template 96 | void launch_gemm_kernel_v03(size_t m, size_t n, size_t k, T const* alpha, 97 | T const* A, size_t lda, T const* B, size_t ldb, 98 | T const* beta, T* C, size_t ldc, 99 | cudaStream_t stream) 100 | { 101 | // Feel free to play with the block tile sizes. 102 | // The algorithm correctness should always be guaranteed. 103 | constexpr unsigned int BLOCK_TILE_SIZE_X{64U}; 104 | constexpr unsigned int BLOCK_TILE_SIZE_Y{64U}; 105 | constexpr unsigned int BLOCK_TILE_SIZE_K{8U}; 106 | // Each thread computes THREAD_TILE_SIZE_Y values of C. 107 | constexpr unsigned int THREAD_TILE_SIZE_Y{8U}; 108 | constexpr unsigned int NUM_THREADS_PER_BLOCK{ 109 | BLOCK_TILE_SIZE_X * BLOCK_TILE_SIZE_Y / THREAD_TILE_SIZE_Y}; 110 | static_assert(BLOCK_TILE_SIZE_Y % THREAD_TILE_SIZE_Y == 0U); 111 | static_assert(NUM_THREADS_PER_BLOCK % BLOCK_TILE_SIZE_K == 0U); 112 | static_assert(NUM_THREADS_PER_BLOCK % BLOCK_TILE_SIZE_X == 0U); 113 | dim3 const block_dim{NUM_THREADS_PER_BLOCK, 1U, 1U}; 114 | dim3 const grid_dim{ 115 | (static_cast(n) + BLOCK_TILE_SIZE_X - 1U) / 116 | BLOCK_TILE_SIZE_X, 117 | (static_cast(m) + BLOCK_TILE_SIZE_Y - 1U) / 118 | BLOCK_TILE_SIZE_Y, 119 | 1U}; 120 | gemm_v03<<>>( 122 | m, n, k, *alpha, A, lda, B, ldb, *beta, C, ldc); 123 | CHECK_LAST_CUDA_ERROR(); 124 | } 125 | 126 | // Explicit instantiation. 
127 | template void launch_gemm_kernel_v03<float>(size_t m, size_t n, size_t k, 128 | float const* alpha, float const* A, 129 | size_t lda, float const* B, 130 | size_t ldb, float const* beta, 131 | float* C, size_t ldc, 132 | cudaStream_t stream); 133 | template void launch_gemm_kernel_v03<double>(size_t m, size_t n, size_t k, 134 | double const* alpha, 135 | double const* A, size_t lda, 136 | double const* B, size_t ldb, 137 | double const* beta, double* C, 138 | size_t ldc, cudaStream_t stream); 139 | template void launch_gemm_kernel_v03<__half>(size_t m, size_t n, size_t k, 140 | __half const* alpha, 141 | __half const* A, size_t lda, 142 | __half const* B, size_t ldb, 143 | __half const* beta, __half* C, 144 | size_t ldc, cudaStream_t stream); -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CUDA GEMM Optimization 2 | 3 | ## Introduction 4 | 5 | This repository contains the CUDA kernels for general matrix-matrix multiplication (GEMM) and the corresponding performance analysis. The correctness of the CUDA kernels is guaranteed for any matrix size. The parameters of the CUDA kernels are slightly tuned for GEMM 4096 x 4096 x 4096 on an NVIDIA GeForce RTX 3090 GPU. The CUDA kernels should be compatible with any NVIDIA GPUs with compute capability 7.0 or higher. 6 | 7 | ## Usages 8 | 9 | Docker is used to build and run the CUDA kernels. The custom Docker container is built based on the [NVIDIA NGC CUDA](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/cuda) 12.2.2 Docker container. 10 | 11 | Please adjust the base Docker container CUDA version if the host computer has a different CUDA version. Otherwise, weird compilation errors and runtime errors may occur. 12 | 13 | ### Build Docker Images 14 | 15 | To build the custom Docker image, please run the following command. 16 | 17 | ```bash 18 | $ docker build -f docker/gemm-cuda.Dockerfile --no-cache --tag=gemm-cuda:12.2.2 . 19 | ``` 20 | 21 | ### Run Docker Container 22 | 23 | To run the custom Docker container, please run the following command. 24 | 25 | ```bash 26 | $ docker run -it --rm --gpus device=0 -v $(pwd):/mnt gemm-cuda:12.2.2 27 | ``` 28 | 29 | If we want to profile the CUDA kernels using [NVIDIA Nsight Compute](https://leimao.github.io/blog/Docker-Nsight-Compute/), we need to add additional flags `--cap-add=SYS_ADMIN` and `--security-opt seccomp=unconfined` when we run the Docker container. 30 | 31 | ### Build CUDA Kernels 32 | 33 | To build the CUDA kernels, please run the following commands inside the Docker container. 34 | 35 | ```bash 36 | $ cmake -B build 37 | $ cmake --build build --config Release --parallel 38 | $ cmake --install build 39 | ``` 40 | 41 | ### Run CUDA Kernels 42 | 43 | To run the FP32 and FP16 GEMM CUDA kernels, please run the following commands inside the Docker container. 44 | 45 | ```bash 46 | $ ./build/src/profile_cuda_gemm_fp32 47 | $ ./build/src/profile_cuda_gemm_fp16 48 | ``` 49 | 50 | ## Performances 51 | 52 | All the experiments are conducted on a single NVIDIA [GeForce RTX 3090 GPU](https://www.nvidia.com/content/PDF/nvidia-ampere-ga-102-gpu-architecture-whitepaper-v2.1.pdf). The performance can vary, sometimes up to 25%, from one measurement to another. 53 | 54 | ### FP32 GEMM 55 | 56 | None of the FP32 GEMM kernels can utilize the NVIDIA Tensor Cores.
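The TFLOPS figures below are presumably derived from the measured kernel latency using the conventional 2 x M x N x K floating-point operation count for GEMM. A minimal sketch of that conversion (the helper name and the millisecond latency unit are assumptions for illustration, not taken from `profile_utils.cuh`):

```cpp
#include <cstddef>

// Hypothetical helper: convert a measured GEMM latency into effective TFLOPS,
// assuming the conventional 2 * m * n * k FLOP count for
// C = alpha * A * B + beta * C.
double effective_tflops(size_t m, size_t n, size_t k, double latency_ms)
{
    double const num_flops{2.0 * static_cast<double>(m) *
                           static_cast<double>(n) * static_cast<double>(k)};
    return num_flops / (latency_ms * 1.0e-3) / 1.0e12;
}
```

For the 4096 x 4096 x 4096 problem profiled here, that amounts to roughly 137.4 GFLOP of work per GEMM call.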
57 | 58 | | GEMM Kernel | TFLOPS | Kernel Description | 59 | | :-------------------------------- | -------- | ---------------------------------------------------------------------------------------------------------: | 60 | | cuBLAS GEMM Kernel | 24.5971 | cuBLAS implementation | 61 | | Custom GEMM Kernel V00 | 0.278129 | Non-coalesced global memory access | 62 | | Custom GEMM Kernel V01 | 1.7218 | Coalesced global memory access | 63 | | Custom GEMM Kernel V02 | 2.66157 | 2D block tiling | 64 | | Custom GEMM Kernel V02 Vectorized | 1.90514 | 2D block tiling with vectorized memory access | 65 | | Custom GEMM Kernel V03 | 8.91318 | 2D block tiling and 1D thread tiling | 66 | | Custom GEMM Kernel V03 Vectorized | 4.04796 | 2D block tiling and 1D thread tiling with vectorized memory access | 67 | | Custom GEMM Kernel V04 | 13.0247 | 2D block tiling and 2D thread tiling | 68 | | Custom GEMM Kernel V04 Vectorized | 15.027 | 2D block tiling and 2D thread tiling with vectorized memory access | 69 | | Custom GEMM Kernel V05 | 11.1448 | 2D block tiling and 2D thread tiling and matrix transpose | 70 | | Custom GEMM Kernel V05 Vectorized | 19.6688 | 2D block tiling and 2D thread tiling and matrix transpose with vectorized memory access | 71 | | Custom GEMM Kernel V06 | 11.0703 | 2D block tiling and 2D warp tiling and 2D thread tiling and matrix transpose | 72 | | Custom GEMM Kernel V06 Vectorized | 20.1649 | 2D block tiling and 2D warp tiling and 2D thread tiling and matrix transpose with vectorized memory access | 73 | 74 | ### FP16 GEMM 75 | 76 | The FP16 custom GEMM kernels V00 to V06 do not utilize the NVIDIA Tensor Cores. The FP16 cuBLAS GEMM kernel and custom GEMM kernels V07 utilize the NVIDIA Tensor Cores. 77 | 78 | | GEMM Kernel | TFLOPS | Kernel Description | 79 | | :-------------------------------- | -------- | ---------------------------------------------------------------------------------------------------------: | 80 | | cuBLAS GEMM Kernel | 138.955 | cuBLAS implementation | 81 | | Custom GEMM Kernel V00 | 0.284095 | Non-coalesced global memory access | 82 | | Custom GEMM Kernel V01 | 1.7316 | Coalesced global memory access | 83 | | Custom GEMM Kernel V02 | 2.46677 | 2D block tiling GEMM | 84 | | Custom GEMM Kernel V02 Vectorized | 1.93088 | 2D block tiling with vectorized memory access | 85 | | Custom GEMM Kernel V03 | 8.67563 | 2D block tiling and 1D thread tiling GEMM | 86 | | Custom GEMM Kernel V03 Vectorized | 2.14047 | 2D block tiling and 1D thread tiling with vectorized memory access | 87 | | Custom GEMM Kernel V04 | 20.2746 | 2D block tiling and 2D thread tiling GEMM | 88 | | Custom GEMM Kernel V04 Vectorized | 22.9001 | 2D block tiling and 2D thread tiling with vectorized memory access | 89 | | Custom GEMM Kernel V05 | 18.3736 | 2D block tiling and 2D thread tiling and matrix transpose GEMM | 90 | | Custom GEMM Kernel V05 Vectorized | 27.962 | 2D block tiling and 2D thread tiling and matrix transpose with vectorized memory access | 91 | | Custom GEMM Kernel V06 | 14.7622 | 2D block tiling and 2D warp tiling and 2D thread tiling and matrix transpose GEMM | 92 | | Custom GEMM Kernel V06 Vectorized | 28.4588 | 2D block tiling and 2D warp tiling and 2D thread tiling and matrix transpose with vectorized memory access | 93 | | Custom GEMM Kernel V07 | 35.2312 | 2D block tiling and 2D warp tiling and WMMA and matrix transpose | 94 | | Custom GEMM Kernel V07 Vectorized | 55.0298 | 2D block tiling and 2D warp tiling and WMMA and matrix transpose and vectorized memory access. 
| 95 | 96 | ## References 97 | 98 | - [CUDA Matrix Multiplication Optimization](https://leimao.github.io/article/CUDA-Matrix-Multiplication-Optimization/) 99 | -------------------------------------------------------------------------------- /src/05_2d_block_tiling_2d_thread_tiling_matrix_transpose.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "cuda_gemm.hpp" 4 | #include "cuda_gemm_utils.cuh" 5 | #include "cuda_gemm_utils.hpp" 6 | 7 | // GEMM kernel v05. 8 | // Coalesced read and write from global memory. 9 | template 12 | __global__ void gemm_v05(size_t m, size_t n, size_t k, T alpha, T const* A, 13 | size_t lda, T const* B, size_t ldb, T beta, T* C, 14 | size_t ldc) 15 | { 16 | // Avoid using blockDim.x * blockDim.y as the number of threads per block. 17 | // Because it is a runtime constant and the compiler cannot optimize the 18 | // loop unrolling based on that. 19 | // Use a compile time constant instead. 20 | constexpr size_t NUM_THREADS{BLOCK_TILE_SIZE_X * BLOCK_TILE_SIZE_Y / 21 | (THREAD_TILE_SIZE_X * THREAD_TILE_SIZE_Y)}; 22 | size_t const thread_linear_idx{threadIdx.y * blockDim.x + threadIdx.x}; 23 | 24 | // Cache a tile of A and B in shared memory for data reuse. 25 | __shared__ T 26 | A_thread_block_tile_transposed[BLOCK_TILE_SIZE_K][BLOCK_TILE_SIZE_Y]; 27 | __shared__ T B_thread_block_tile[BLOCK_TILE_SIZE_K][BLOCK_TILE_SIZE_X]; 28 | 29 | size_t const num_thread_block_tiles{(k + BLOCK_TILE_SIZE_K - 1) / 30 | BLOCK_TILE_SIZE_K}; 31 | 32 | // Each thread in the block processes BLOCK_TILE_SIZE_Y output values. 33 | // Specifically, these values corresponds to 34 | // C[blockIdx.y * BLOCK_TILE_SIZE_Y + threadIdx.x / BLOCK_TILE_SIZE_X * 35 | // THREAD_TILE_SIZE_Y : blockIdx.y * BLOCK_TILE_SIZE_Y + (threadIdx.x / 36 | // BLOCK_TILE_SIZE_X + 1) * THREAD_TILE_SIZE_Y][blockIdx.x * 37 | // BLOCK_TILE_SIZE_X + threadIdx.x % BLOCK_TILE_SIZE_X * 38 | // THREAD_TILE_SIZE_X : blockIdx.x * BLOCK_TILE_SIZE_X + (threadIdx.x % 39 | // BLOCK_TILE_SIZE_X + 1) * THREAD_TILE_SIZE_X] 40 | T C_thread_results[THREAD_TILE_SIZE_Y][THREAD_TILE_SIZE_X] = { 41 | static_cast(0)}; 42 | // A_vals is cached in the register. 43 | T A_vals[THREAD_TILE_SIZE_Y] = {static_cast(0)}; 44 | // B_vals is cached in the register. 
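    // With both fragments held in registers, the k_i loop below performs a
    // THREAD_TILE_SIZE_Y x THREAD_TILE_SIZE_X outer product entirely out of
    // registers, so every A value fetched from shared memory is reused
    // THREAD_TILE_SIZE_X times and every B value THREAD_TILE_SIZE_Y times.
    // Storing the A tile transposed (BLOCK_TILE_SIZE_K x BLOCK_TILE_SIZE_Y
    // instead of BLOCK_TILE_SIZE_Y x BLOCK_TILE_SIZE_K) makes the
    // THREAD_TILE_SIZE_Y A values a thread needs at a fixed k_i contiguous in
    // shared memory, which the vectorized variant of this kernel can then
    // fetch with wide loads.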
45 | T B_vals[THREAD_TILE_SIZE_X] = {static_cast(0)}; 46 | 47 | for (size_t thread_block_tile_idx{0U}; 48 | thread_block_tile_idx < num_thread_block_tiles; 49 | ++thread_block_tile_idx) 50 | { 51 | 52 | load_data_from_global_memory_to_shared_memory_transposed< 53 | T, BLOCK_TILE_SIZE_X, BLOCK_TILE_SIZE_Y, BLOCK_TILE_SIZE_K, 54 | NUM_THREADS>(A, lda, B, ldb, A_thread_block_tile_transposed, 55 | B_thread_block_tile, thread_block_tile_idx, 56 | thread_linear_idx, m, n, k); 57 | __syncthreads(); 58 | 59 | #pragma unroll 60 | for (size_t k_i{0U}; k_i < BLOCK_TILE_SIZE_K; ++k_i) 61 | { 62 | size_t const A_thread_block_tile_row_idx{ 63 | thread_linear_idx / (BLOCK_TILE_SIZE_X / THREAD_TILE_SIZE_X) * 64 | THREAD_TILE_SIZE_Y}; 65 | size_t const A_thread_block_tile_col_idx{k_i}; 66 | 67 | #pragma unroll 68 | for (size_t thread_tile_row_idx{0U}; 69 | thread_tile_row_idx < THREAD_TILE_SIZE_Y; 70 | ++thread_tile_row_idx) 71 | { 72 | A_vals[thread_tile_row_idx] = 73 | A_thread_block_tile_transposed[A_thread_block_tile_col_idx] 74 | [A_thread_block_tile_row_idx + 75 | thread_tile_row_idx]; 76 | } 77 | 78 | size_t const B_thread_block_tile_row_idx{k_i}; 79 | size_t const B_thread_block_tile_col_idx{ 80 | thread_linear_idx % (BLOCK_TILE_SIZE_X / THREAD_TILE_SIZE_X) * 81 | THREAD_TILE_SIZE_X}; 82 | #pragma unroll 83 | for (size_t thread_tile_col_idx{0U}; 84 | thread_tile_col_idx < THREAD_TILE_SIZE_X; 85 | ++thread_tile_col_idx) 86 | { 87 | B_vals[thread_tile_col_idx] = 88 | B_thread_block_tile[B_thread_block_tile_row_idx] 89 | [B_thread_block_tile_col_idx + 90 | thread_tile_col_idx]; 91 | } 92 | 93 | for (size_t thread_tile_row_idx{0U}; 94 | thread_tile_row_idx < THREAD_TILE_SIZE_Y; 95 | ++thread_tile_row_idx) 96 | { 97 | for (size_t thread_tile_col_idx{0U}; 98 | thread_tile_col_idx < THREAD_TILE_SIZE_X; 99 | ++thread_tile_col_idx) 100 | { 101 | C_thread_results[thread_tile_row_idx] 102 | [thread_tile_col_idx] += 103 | A_vals[thread_tile_row_idx] * 104 | B_vals[thread_tile_col_idx]; 105 | } 106 | } 107 | } 108 | __syncthreads(); 109 | } 110 | 111 | // Write the results to DRAM. 112 | for (size_t thread_tile_row_idx{0U}; 113 | thread_tile_row_idx < THREAD_TILE_SIZE_Y; ++thread_tile_row_idx) 114 | { 115 | for (size_t thread_tile_col_idx{0U}; 116 | thread_tile_col_idx < THREAD_TILE_SIZE_X; ++thread_tile_col_idx) 117 | { 118 | size_t const C_row_idx{ 119 | blockIdx.y * BLOCK_TILE_SIZE_Y + 120 | threadIdx.x / (BLOCK_TILE_SIZE_X / THREAD_TILE_SIZE_X) * 121 | THREAD_TILE_SIZE_Y + 122 | thread_tile_row_idx}; 123 | size_t const C_col_idx{ 124 | blockIdx.x * BLOCK_TILE_SIZE_X + 125 | threadIdx.x % (BLOCK_TILE_SIZE_X / THREAD_TILE_SIZE_X) * 126 | THREAD_TILE_SIZE_X + 127 | thread_tile_col_idx}; 128 | if (C_row_idx < m && C_col_idx < n) 129 | { 130 | C[C_row_idx * ldc + C_col_idx] = 131 | alpha * C_thread_results[thread_tile_row_idx] 132 | [thread_tile_col_idx] + 133 | beta * C[C_row_idx * ldc + C_col_idx]; 134 | } 135 | } 136 | } 137 | } 138 | 139 | template 140 | void launch_gemm_kernel_v05(size_t m, size_t n, size_t k, T const* alpha, 141 | T const* A, size_t lda, T const* B, size_t ldb, 142 | T const* beta, T* C, size_t ldc, 143 | cudaStream_t stream) 144 | { 145 | // Feel free to play with the block tile sizes. 146 | // The algorithm correctness should always be guaranteed. 147 | constexpr unsigned int BLOCK_TILE_SIZE_X{128U}; 148 | constexpr unsigned int BLOCK_TILE_SIZE_Y{128U}; 149 | constexpr unsigned int BLOCK_TILE_SIZE_K{16U}; 150 | // Each thread computes THREAD_TILE_SIZE_X * THREAD_TILE_SIZE_Y values of C. 
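// For example (a sketch of this default configuration, not a requirement):
// a 128 x 128 block tile with an 8 x 8 thread tile yields
// 128 * 128 / (8 * 8) = 256 threads per block, each owning 64 elements
// of the output block tile.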
151 | constexpr unsigned int THREAD_TILE_SIZE_X{8U}; 152 | constexpr unsigned int THREAD_TILE_SIZE_Y{8U}; 153 | constexpr unsigned int NUM_THREADS_PER_BLOCK{ 154 | BLOCK_TILE_SIZE_X * BLOCK_TILE_SIZE_Y / 155 | (THREAD_TILE_SIZE_X * THREAD_TILE_SIZE_Y)}; 156 | static_assert(BLOCK_TILE_SIZE_X % THREAD_TILE_SIZE_X == 0U); 157 | static_assert(BLOCK_TILE_SIZE_Y % THREAD_TILE_SIZE_Y == 0U); 158 | static_assert(NUM_THREADS_PER_BLOCK % BLOCK_TILE_SIZE_K == 0U); 159 | static_assert(NUM_THREADS_PER_BLOCK % BLOCK_TILE_SIZE_X == 0U); 160 | static_assert( 161 | BLOCK_TILE_SIZE_X * BLOCK_TILE_SIZE_K % NUM_THREADS_PER_BLOCK == 0U); 162 | static_assert( 163 | BLOCK_TILE_SIZE_K * BLOCK_TILE_SIZE_Y % NUM_THREADS_PER_BLOCK == 0U); 164 | dim3 const block_dim{NUM_THREADS_PER_BLOCK, 1U, 1U}; 165 | dim3 const grid_dim{ 166 | (static_cast(n) + BLOCK_TILE_SIZE_X - 1U) / 167 | BLOCK_TILE_SIZE_X, 168 | (static_cast(m) + BLOCK_TILE_SIZE_Y - 1U) / 169 | BLOCK_TILE_SIZE_Y, 170 | 1U}; 171 | gemm_v05 173 | <<>>(m, n, k, *alpha, A, lda, B, ldb, 174 | *beta, C, ldc); 175 | CHECK_LAST_CUDA_ERROR(); 176 | } 177 | 178 | // Explicit instantiation. 179 | template void launch_gemm_kernel_v05(size_t m, size_t n, size_t k, 180 | float const* alpha, float const* A, 181 | size_t lda, float const* B, 182 | size_t ldb, float const* beta, 183 | float* C, size_t ldc, 184 | cudaStream_t stream); 185 | template void launch_gemm_kernel_v05(size_t m, size_t n, size_t k, 186 | double const* alpha, 187 | double const* A, size_t lda, 188 | double const* B, size_t ldb, 189 | double const* beta, double* C, 190 | size_t ldc, cudaStream_t stream); 191 | template void launch_gemm_kernel_v05<__half>(size_t m, size_t n, size_t k, 192 | __half const* alpha, 193 | __half const* A, size_t lda, 194 | __half const* B, size_t ldb, 195 | __half const* beta, __half* C, 196 | size_t ldc, cudaStream_t stream); -------------------------------------------------------------------------------- /src/04_2d_block_tiling_2d_thread_tiling.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "cuda_gemm.hpp" 4 | #include "cuda_gemm_utils.cuh" 5 | #include "cuda_gemm_utils.hpp" 6 | 7 | // GEMM kernel v04. 8 | // Coalesced read and write from global memory. 9 | template 12 | __global__ void gemm_v04(size_t m, size_t n, size_t k, T alpha, T const* A, 13 | size_t lda, T const* B, size_t ldb, T beta, T* C, 14 | size_t ldc) 15 | { 16 | // Avoid using blockDim.x * blockDim.y as the number of threads per block. 17 | // Because it is a runtime constant and the compiler cannot optimize the 18 | // loop unrolling based on that. 19 | // Use a compile time constant instead. 20 | constexpr size_t NUM_THREADS{BLOCK_TILE_SIZE_X * BLOCK_TILE_SIZE_Y / 21 | (THREAD_TILE_SIZE_X * THREAD_TILE_SIZE_Y)}; 22 | size_t const thread_linear_idx{threadIdx.y * blockDim.x + threadIdx.x}; 23 | 24 | // Cache a tile of A and B in shared memory for data reuse. 25 | __shared__ T A_thread_block_tile[BLOCK_TILE_SIZE_Y][BLOCK_TILE_SIZE_K]; 26 | __shared__ T B_thread_block_tile[BLOCK_TILE_SIZE_K][BLOCK_TILE_SIZE_X]; 27 | 28 | size_t const num_thread_block_tiles{(k + BLOCK_TILE_SIZE_K - 1) / 29 | BLOCK_TILE_SIZE_K}; 30 | 31 | // Each thread in the block processes BLOCK_TILE_SIZE_Y output values. 
32 | // Specifically, these values corresponds to 33 | // C[blockIdx.y * BLOCK_TILE_SIZE_Y + threadIdx.x / BLOCK_TILE_SIZE_X * 34 | // THREAD_TILE_SIZE_Y : blockIdx.y * BLOCK_TILE_SIZE_Y + (threadIdx.x / 35 | // BLOCK_TILE_SIZE_X + 1) * THREAD_TILE_SIZE_Y][blockIdx.x * 36 | // BLOCK_TILE_SIZE_X + threadIdx.x % BLOCK_TILE_SIZE_X * 37 | // THREAD_TILE_SIZE_X : blockIdx.x * BLOCK_TILE_SIZE_X + (threadIdx.x % 38 | // BLOCK_TILE_SIZE_X + 1) * THREAD_TILE_SIZE_X] 39 | T C_thread_results[THREAD_TILE_SIZE_Y][THREAD_TILE_SIZE_X] = { 40 | static_cast(0)}; 41 | // A_vals is cached in the register. 42 | T A_vals[THREAD_TILE_SIZE_Y] = {static_cast(0)}; 43 | // B_vals is cached in the register. 44 | T B_vals[THREAD_TILE_SIZE_X] = {static_cast(0)}; 45 | 46 | for (size_t thread_block_tile_idx{0U}; 47 | thread_block_tile_idx < num_thread_block_tiles; 48 | ++thread_block_tile_idx) 49 | { 50 | 51 | load_data_from_global_memory_to_shared_memory< 52 | T, BLOCK_TILE_SIZE_X, BLOCK_TILE_SIZE_Y, BLOCK_TILE_SIZE_K, 53 | NUM_THREADS>(A, lda, B, ldb, A_thread_block_tile, 54 | B_thread_block_tile, thread_block_tile_idx, 55 | thread_linear_idx, m, n, k); 56 | __syncthreads(); 57 | 58 | #pragma unroll 59 | for (size_t k_i{0U}; k_i < BLOCK_TILE_SIZE_K; ++k_i) 60 | { 61 | size_t const A_thread_block_tile_row_idx{ 62 | thread_linear_idx / (BLOCK_TILE_SIZE_X / THREAD_TILE_SIZE_X) * 63 | THREAD_TILE_SIZE_Y}; 64 | size_t const A_thread_block_tile_col_idx{k_i}; 65 | 66 | #pragma unroll 67 | for (size_t thread_tile_row_idx{0U}; 68 | thread_tile_row_idx < THREAD_TILE_SIZE_Y; 69 | ++thread_tile_row_idx) 70 | { 71 | // There will be shared memory bank conflicts accessing the 72 | // values from A_thread_block_tile. We can do it better by 73 | // transposing the A_thread_block_tile when we load the data 74 | // from DRAM. 75 | A_vals[thread_tile_row_idx] = 76 | A_thread_block_tile[A_thread_block_tile_row_idx + 77 | thread_tile_row_idx] 78 | [A_thread_block_tile_col_idx]; 79 | } 80 | 81 | size_t const B_thread_block_tile_row_idx{k_i}; 82 | size_t const B_thread_block_tile_col_idx{ 83 | thread_linear_idx % (BLOCK_TILE_SIZE_X / THREAD_TILE_SIZE_X) * 84 | THREAD_TILE_SIZE_X}; 85 | #pragma unroll 86 | for (size_t thread_tile_col_idx{0U}; 87 | thread_tile_col_idx < THREAD_TILE_SIZE_X; 88 | ++thread_tile_col_idx) 89 | { 90 | B_vals[thread_tile_col_idx] = 91 | B_thread_block_tile[B_thread_block_tile_row_idx] 92 | [B_thread_block_tile_col_idx + 93 | thread_tile_col_idx]; 94 | } 95 | 96 | for (size_t thread_tile_row_idx{0U}; 97 | thread_tile_row_idx < THREAD_TILE_SIZE_Y; 98 | ++thread_tile_row_idx) 99 | { 100 | for (size_t thread_tile_col_idx{0U}; 101 | thread_tile_col_idx < THREAD_TILE_SIZE_X; 102 | ++thread_tile_col_idx) 103 | { 104 | C_thread_results[thread_tile_row_idx] 105 | [thread_tile_col_idx] += 106 | A_vals[thread_tile_row_idx] * 107 | B_vals[thread_tile_col_idx]; 108 | } 109 | } 110 | } 111 | __syncthreads(); 112 | } 113 | 114 | // Write the results to DRAM. 
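// Each thread writes back its THREAD_TILE_SIZE_Y x THREAD_TILE_SIZE_X
// accumulator, applying C = alpha * acc + beta * C element-wise.
// Illustrative index math (assuming BLOCK_TILE_SIZE_X = 128 and
// THREAD_TILE_SIZE_X = THREAD_TILE_SIZE_Y = 8): threadIdx.x / 16 * 8
// gives the thread tile's row origin and threadIdx.x % 16 * 8 its
// column origin inside the 128 x 128 block tile of C.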
115 | for (size_t thread_tile_row_idx{0U}; 116 | thread_tile_row_idx < THREAD_TILE_SIZE_Y; ++thread_tile_row_idx) 117 | { 118 | for (size_t thread_tile_col_idx{0U}; 119 | thread_tile_col_idx < THREAD_TILE_SIZE_X; ++thread_tile_col_idx) 120 | { 121 | size_t const C_row_idx{ 122 | blockIdx.y * BLOCK_TILE_SIZE_Y + 123 | threadIdx.x / (BLOCK_TILE_SIZE_X / THREAD_TILE_SIZE_X) * 124 | THREAD_TILE_SIZE_Y + 125 | thread_tile_row_idx}; 126 | size_t const C_col_idx{ 127 | blockIdx.x * BLOCK_TILE_SIZE_X + 128 | threadIdx.x % (BLOCK_TILE_SIZE_X / THREAD_TILE_SIZE_X) * 129 | THREAD_TILE_SIZE_X + 130 | thread_tile_col_idx}; 131 | if (C_row_idx < m && C_col_idx < n) 132 | { 133 | C[C_row_idx * ldc + C_col_idx] = 134 | alpha * C_thread_results[thread_tile_row_idx] 135 | [thread_tile_col_idx] + 136 | beta * C[C_row_idx * ldc + C_col_idx]; 137 | } 138 | } 139 | } 140 | } 141 | 142 | template 143 | void launch_gemm_kernel_v04(size_t m, size_t n, size_t k, T const* alpha, 144 | T const* A, size_t lda, T const* B, size_t ldb, 145 | T const* beta, T* C, size_t ldc, 146 | cudaStream_t stream) 147 | { 148 | // Feel free to play with the block tile sizes. 149 | // The algorithm correctness should always be guaranteed. 150 | constexpr unsigned int BLOCK_TILE_SIZE_X{128U}; 151 | constexpr unsigned int BLOCK_TILE_SIZE_Y{128U}; 152 | constexpr unsigned int BLOCK_TILE_SIZE_K{16U}; 153 | // Each thread computes THREAD_TILE_SIZE_X * THREAD_TILE_SIZE_Y values of C. 154 | constexpr unsigned int THREAD_TILE_SIZE_X{8U}; 155 | constexpr unsigned int THREAD_TILE_SIZE_Y{8U}; 156 | constexpr unsigned int NUM_THREADS_PER_BLOCK{ 157 | BLOCK_TILE_SIZE_X * BLOCK_TILE_SIZE_Y / 158 | (THREAD_TILE_SIZE_X * THREAD_TILE_SIZE_Y)}; 159 | static_assert(BLOCK_TILE_SIZE_X % THREAD_TILE_SIZE_X == 0U); 160 | static_assert(BLOCK_TILE_SIZE_Y % THREAD_TILE_SIZE_Y == 0U); 161 | static_assert(NUM_THREADS_PER_BLOCK % BLOCK_TILE_SIZE_K == 0U); 162 | static_assert(NUM_THREADS_PER_BLOCK % BLOCK_TILE_SIZE_X == 0U); 163 | static_assert( 164 | BLOCK_TILE_SIZE_X * BLOCK_TILE_SIZE_K % NUM_THREADS_PER_BLOCK == 0U); 165 | static_assert( 166 | BLOCK_TILE_SIZE_K * BLOCK_TILE_SIZE_Y % NUM_THREADS_PER_BLOCK == 0U); 167 | dim3 const block_dim{NUM_THREADS_PER_BLOCK, 1U, 1U}; 168 | dim3 const grid_dim{ 169 | (static_cast(n) + BLOCK_TILE_SIZE_X - 1U) / 170 | BLOCK_TILE_SIZE_X, 171 | (static_cast(m) + BLOCK_TILE_SIZE_Y - 1U) / 172 | BLOCK_TILE_SIZE_Y, 173 | 1U}; 174 | gemm_v04 176 | <<>>(m, n, k, *alpha, A, lda, B, ldb, 177 | *beta, C, ldc); 178 | CHECK_LAST_CUDA_ERROR(); 179 | } 180 | 181 | // Explicit instantiation. 
182 | template void launch_gemm_kernel_v04(size_t m, size_t n, size_t k, 183 | float const* alpha, float const* A, 184 | size_t lda, float const* B, 185 | size_t ldb, float const* beta, 186 | float* C, size_t ldc, 187 | cudaStream_t stream); 188 | template void launch_gemm_kernel_v04(size_t m, size_t n, size_t k, 189 | double const* alpha, 190 | double const* A, size_t lda, 191 | double const* B, size_t ldb, 192 | double const* beta, double* C, 193 | size_t ldc, cudaStream_t stream); 194 | template void launch_gemm_kernel_v04<__half>(size_t m, size_t n, size_t k, 195 | __half const* alpha, 196 | __half const* A, size_t lda, 197 | __half const* B, size_t ldb, 198 | __half const* beta, __half* C, 199 | size_t ldc, cudaStream_t stream); -------------------------------------------------------------------------------- /src/05_2d_block_tiling_2d_thread_tiling_matrix_transpose_vectorized_memory_access.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "cuda_gemm.hpp" 4 | #include "cuda_gemm_utils.cuh" 5 | #include "cuda_gemm_utils.hpp" 6 | 7 | // GEMM kernel v05. 8 | // Coalesced read and write from global memory. 9 | template 12 | __global__ void gemm_v05_vectorized(size_t m, size_t n, size_t k, T alpha, 13 | T const* A, size_t lda, T const* B, 14 | size_t ldb, T beta, T* C, size_t ldc) 15 | { 16 | // Avoid using blockDim.x * blockDim.y as the number of threads per block. 17 | // Because it is a runtime constant and the compiler cannot optimize the 18 | // loop unrolling based on that. 19 | // Use a compile time constant instead. 20 | constexpr size_t NUM_THREADS{BLOCK_TILE_SIZE_X * BLOCK_TILE_SIZE_Y / 21 | (THREAD_TILE_SIZE_X * THREAD_TILE_SIZE_Y)}; 22 | size_t const thread_linear_idx{threadIdx.y * blockDim.x + threadIdx.x}; 23 | 24 | // Cache a tile of A and B in shared memory for data reuse. 25 | __shared__ T 26 | A_thread_block_tile_transposed[BLOCK_TILE_SIZE_K][BLOCK_TILE_SIZE_Y]; 27 | __shared__ T B_thread_block_tile[BLOCK_TILE_SIZE_K][BLOCK_TILE_SIZE_X]; 28 | 29 | size_t const num_thread_block_tiles{(k + BLOCK_TILE_SIZE_K - 1) / 30 | BLOCK_TILE_SIZE_K}; 31 | 32 | // Each thread in the block processes BLOCK_TILE_SIZE_Y output values. 33 | // Specifically, these values corresponds to 34 | // C[blockIdx.y * BLOCK_TILE_SIZE_Y + threadIdx.x / BLOCK_TILE_SIZE_X * 35 | // THREAD_TILE_SIZE_Y : blockIdx.y * BLOCK_TILE_SIZE_Y + (threadIdx.x / 36 | // BLOCK_TILE_SIZE_X + 1) * THREAD_TILE_SIZE_Y][blockIdx.x * 37 | // BLOCK_TILE_SIZE_X + threadIdx.x % BLOCK_TILE_SIZE_X * 38 | // THREAD_TILE_SIZE_X : blockIdx.x * BLOCK_TILE_SIZE_X + (threadIdx.x % 39 | // BLOCK_TILE_SIZE_X + 1) * THREAD_TILE_SIZE_X] 40 | T C_thread_results[THREAD_TILE_SIZE_Y][THREAD_TILE_SIZE_X] = { 41 | static_cast(0)}; 42 | // A_vals is cached in the register. 43 | T A_vals[THREAD_TILE_SIZE_Y] = {static_cast(0)}; 44 | // B_vals is cached in the register. 
45 | T B_vals[THREAD_TILE_SIZE_X] = {static_cast(0)}; 46 | 47 | constexpr size_t NUM_VECTOR_UNITS{sizeof(int4) / sizeof(T)}; 48 | static_assert(sizeof(int4) % sizeof(T) == 0U); 49 | static_assert(BLOCK_TILE_SIZE_K % NUM_VECTOR_UNITS == 0U); 50 | static_assert(BLOCK_TILE_SIZE_X % NUM_VECTOR_UNITS == 0U); 51 | constexpr size_t VECTORIZED_THREAD_TILE_SIZE_X{THREAD_TILE_SIZE_X / 52 | NUM_VECTOR_UNITS}; 53 | static_assert(THREAD_TILE_SIZE_X % NUM_VECTOR_UNITS == 0U); 54 | 55 | for (size_t thread_block_tile_idx{0U}; 56 | thread_block_tile_idx < num_thread_block_tiles; 57 | ++thread_block_tile_idx) 58 | { 59 | load_data_from_global_memory_to_shared_memory_transposed_vectorized< 60 | T, BLOCK_TILE_SIZE_X, BLOCK_TILE_SIZE_Y, BLOCK_TILE_SIZE_K, 61 | NUM_THREADS>(A, lda, B, ldb, A_thread_block_tile_transposed, 62 | B_thread_block_tile, thread_block_tile_idx, 63 | thread_linear_idx, m, n, k); 64 | __syncthreads(); 65 | 66 | #pragma unroll 67 | for (size_t k_i{0U}; k_i < BLOCK_TILE_SIZE_K; ++k_i) 68 | { 69 | size_t const A_thread_block_tile_row_idx{ 70 | thread_linear_idx / (BLOCK_TILE_SIZE_X / THREAD_TILE_SIZE_X) * 71 | THREAD_TILE_SIZE_Y}; 72 | size_t const A_thread_block_tile_col_idx{k_i}; 73 | 74 | #pragma unroll 75 | for (size_t thread_tile_row_idx{0U}; 76 | thread_tile_row_idx < THREAD_TILE_SIZE_Y; 77 | ++thread_tile_row_idx) 78 | { 79 | A_vals[thread_tile_row_idx] = 80 | A_thread_block_tile_transposed[A_thread_block_tile_col_idx] 81 | [A_thread_block_tile_row_idx + 82 | thread_tile_row_idx]; 83 | } 84 | 85 | size_t const B_thread_block_tile_row_idx{k_i}; 86 | size_t const B_thread_block_tile_col_idx{ 87 | thread_linear_idx % (BLOCK_TILE_SIZE_X / THREAD_TILE_SIZE_X) * 88 | THREAD_TILE_SIZE_X}; 89 | // Although the read from A_thread_block_tile cannot be vectorized, the read 90 | // from B_thread_block_tile can be vectorized. 91 | #pragma unroll 92 | for (size_t thread_tile_col_vector_idx{0U}; 93 | thread_tile_col_vector_idx < VECTORIZED_THREAD_TILE_SIZE_X; 94 | ++thread_tile_col_vector_idx) 95 | { 96 | *reinterpret_cast( 97 | &B_vals[thread_tile_col_vector_idx * NUM_VECTOR_UNITS]) = 98 | *reinterpret_cast( 99 | &B_thread_block_tile[B_thread_block_tile_row_idx] 100 | [B_thread_block_tile_col_idx + 101 | thread_tile_col_vector_idx * 102 | NUM_VECTOR_UNITS]); 103 | } 104 | 105 | for (size_t thread_tile_row_idx{0U}; 106 | thread_tile_row_idx < THREAD_TILE_SIZE_Y; 107 | ++thread_tile_row_idx) 108 | { 109 | for (size_t thread_tile_col_idx{0U}; 110 | thread_tile_col_idx < THREAD_TILE_SIZE_X; 111 | ++thread_tile_col_idx) 112 | { 113 | C_thread_results[thread_tile_row_idx] 114 | [thread_tile_col_idx] += 115 | A_vals[thread_tile_row_idx] * 116 | B_vals[thread_tile_col_idx]; 117 | } 118 | } 119 | } 120 | __syncthreads(); 121 | } 122 | 123 | // Vectorized writing the results to DRAM. 124 | for (size_t thread_tile_row_idx{0U}; 125 | thread_tile_row_idx < THREAD_TILE_SIZE_Y; ++thread_tile_row_idx) 126 | { 127 | for (size_t thread_tile_col_vector_idx{0U}; 128 | thread_tile_col_vector_idx < VECTORIZED_THREAD_TILE_SIZE_X; 129 | ++thread_tile_col_vector_idx) 130 | { 131 | size_t const C_row_idx{ 132 | blockIdx.y * BLOCK_TILE_SIZE_Y + 133 | thread_linear_idx / (BLOCK_TILE_SIZE_X / THREAD_TILE_SIZE_X) * 134 | THREAD_TILE_SIZE_Y + 135 | thread_tile_row_idx}; 136 | size_t const C_col_idx{ 137 | blockIdx.x * BLOCK_TILE_SIZE_X + 138 | thread_linear_idx % (BLOCK_TILE_SIZE_X / THREAD_TILE_SIZE_X) * 139 | THREAD_TILE_SIZE_X + 140 | thread_tile_col_vector_idx * NUM_VECTOR_UNITS}; 141 | // Vectorized read from C. 
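// One int4 access moves 16 bytes, i.e. NUM_VECTOR_UNITS elements
// (4 floats, 8 halves, or 2 doubles), so each thread tile row needs
// only VECTORIZED_THREAD_TILE_SIZE_X 128-bit transactions instead of
// THREAD_TILE_SIZE_X scalar ones.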
142 | int4 C_row_vector_vals{*reinterpret_cast( 143 | &C[C_row_idx * ldc + C_col_idx])}; 144 | // Vectorized read from C_thread_results. 145 | int4 const C_thread_results_row_vector_vals{ 146 | *reinterpret_cast( 147 | &C_thread_results[thread_tile_row_idx] 148 | [thread_tile_col_vector_idx * 149 | NUM_VECTOR_UNITS])}; 150 | // Update the values in C_row_vector_vals 151 | for (size_t i{0U}; i < NUM_VECTOR_UNITS; ++i) 152 | { 153 | reinterpret_cast(&C_row_vector_vals)[i] = 154 | alpha * reinterpret_cast( 155 | &C_thread_results_row_vector_vals)[i] + 156 | beta * reinterpret_cast(&C_row_vector_vals)[i]; 157 | } 158 | // Vectorized write to C. 159 | if (C_row_idx < m && C_col_idx < n) 160 | { 161 | // No need to mask out the out-of-bound invalid elements, 162 | // because the row of C matrix is 32-byte aligned. 163 | *reinterpret_cast(&C[C_row_idx * ldc + C_col_idx]) = 164 | C_row_vector_vals; 165 | } 166 | } 167 | } 168 | } 169 | 170 | template 171 | void launch_gemm_kernel_v05_vectorized(size_t m, size_t n, size_t k, 172 | T const* alpha, T const* A, size_t lda, 173 | T const* B, size_t ldb, T const* beta, 174 | T* C, size_t ldc, cudaStream_t stream) 175 | { 176 | // Feel free to play with the block tile sizes. 177 | // The algorithm correctness should always be guaranteed. 178 | constexpr unsigned int BLOCK_TILE_SIZE_X{128U}; 179 | constexpr unsigned int BLOCK_TILE_SIZE_Y{128U}; 180 | constexpr unsigned int BLOCK_TILE_SIZE_K{16U}; 181 | // Each thread computes THREAD_TILE_SIZE_X * THREAD_TILE_SIZE_Y values of C. 182 | constexpr unsigned int THREAD_TILE_SIZE_X{8U}; 183 | constexpr unsigned int THREAD_TILE_SIZE_Y{8U}; 184 | constexpr unsigned int NUM_THREADS_PER_BLOCK{ 185 | BLOCK_TILE_SIZE_X * BLOCK_TILE_SIZE_Y / 186 | (THREAD_TILE_SIZE_X * THREAD_TILE_SIZE_Y)}; 187 | static_assert(BLOCK_TILE_SIZE_X % THREAD_TILE_SIZE_X == 0U); 188 | static_assert(BLOCK_TILE_SIZE_Y % THREAD_TILE_SIZE_Y == 0U); 189 | static_assert(NUM_THREADS_PER_BLOCK % BLOCK_TILE_SIZE_K == 0U); 190 | static_assert(NUM_THREADS_PER_BLOCK % BLOCK_TILE_SIZE_X == 0U); 191 | static_assert( 192 | BLOCK_TILE_SIZE_X * BLOCK_TILE_SIZE_K % NUM_THREADS_PER_BLOCK == 0U); 193 | static_assert( 194 | BLOCK_TILE_SIZE_K * BLOCK_TILE_SIZE_Y % NUM_THREADS_PER_BLOCK == 0U); 195 | dim3 const block_dim{NUM_THREADS_PER_BLOCK, 1U, 1U}; 196 | dim3 const grid_dim{ 197 | (static_cast(n) + BLOCK_TILE_SIZE_X - 1U) / 198 | BLOCK_TILE_SIZE_X, 199 | (static_cast(m) + BLOCK_TILE_SIZE_Y - 1U) / 200 | BLOCK_TILE_SIZE_Y, 201 | 1U}; 202 | gemm_v05_vectorized 205 | <<>>(m, n, k, *alpha, A, lda, B, ldb, 206 | *beta, C, ldc); 207 | CHECK_LAST_CUDA_ERROR(); 208 | } 209 | 210 | // Explicit instantiation. 
211 | template void launch_gemm_kernel_v05_vectorized( 212 | size_t m, size_t n, size_t k, float const* alpha, float const* A, 213 | size_t lda, float const* B, size_t ldb, float const* beta, float* C, 214 | size_t ldc, cudaStream_t stream); 215 | template void launch_gemm_kernel_v05_vectorized( 216 | size_t m, size_t n, size_t k, double const* alpha, double const* A, 217 | size_t lda, double const* B, size_t ldb, double const* beta, double* C, 218 | size_t ldc, cudaStream_t stream); 219 | template void launch_gemm_kernel_v05_vectorized<__half>( 220 | size_t m, size_t n, size_t k, __half const* alpha, __half const* A, 221 | size_t lda, __half const* B, size_t ldb, __half const* beta, __half* C, 222 | size_t ldc, cudaStream_t stream); -------------------------------------------------------------------------------- /src/04_2d_block_tiling_2d_thread_tiling_vectorized_memory_access.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "cuda_gemm.hpp" 4 | #include "cuda_gemm_utils.cuh" 5 | #include "cuda_gemm_utils.hpp" 6 | 7 | // GEMM kernel v04. 8 | // Coalesced read and write from global memory. 9 | template 12 | __global__ void gemm_v04_vectorized(size_t m, size_t n, size_t k, T alpha, 13 | T const* A, size_t lda, T const* B, 14 | size_t ldb, T beta, T* C, size_t ldc) 15 | { 16 | // Avoid using blockDim.x * blockDim.y as the number of threads per block. 17 | // Because it is a runtime constant and the compiler cannot optimize the 18 | // loop unrolling based on that. 19 | // Use a compile time constant instead. 20 | constexpr size_t NUM_THREADS{BLOCK_TILE_SIZE_X * BLOCK_TILE_SIZE_Y / 21 | (THREAD_TILE_SIZE_X * THREAD_TILE_SIZE_Y)}; 22 | size_t const thread_linear_idx{threadIdx.y * blockDim.x + threadIdx.x}; 23 | 24 | // Cache a tile of A and B in shared memory for data reuse. 25 | __shared__ T A_thread_block_tile[BLOCK_TILE_SIZE_Y][BLOCK_TILE_SIZE_K]; 26 | __shared__ T B_thread_block_tile[BLOCK_TILE_SIZE_K][BLOCK_TILE_SIZE_X]; 27 | 28 | size_t const num_thread_block_tiles{(k + BLOCK_TILE_SIZE_K - 1) / 29 | BLOCK_TILE_SIZE_K}; 30 | 31 | // Each thread in the block processes BLOCK_TILE_SIZE_Y output values. 32 | // Specifically, these values corresponds to 33 | // C[blockIdx.y * BLOCK_TILE_SIZE_Y + threadIdx.x / BLOCK_TILE_SIZE_X * 34 | // THREAD_TILE_SIZE_Y : blockIdx.y * BLOCK_TILE_SIZE_Y + (threadIdx.x / 35 | // BLOCK_TILE_SIZE_X + 1) * THREAD_TILE_SIZE_Y][blockIdx.x * 36 | // BLOCK_TILE_SIZE_X + threadIdx.x % BLOCK_TILE_SIZE_X * 37 | // THREAD_TILE_SIZE_X : blockIdx.x * BLOCK_TILE_SIZE_X + (threadIdx.x % 38 | // BLOCK_TILE_SIZE_X + 1) * THREAD_TILE_SIZE_X] 39 | T C_thread_results[THREAD_TILE_SIZE_Y][THREAD_TILE_SIZE_X] = { 40 | static_cast(0)}; 41 | // A_vals is cached in the register. 42 | T A_vals[THREAD_TILE_SIZE_Y] = {static_cast(0)}; 43 | // B_vals is cached in the register. 
44 | T B_vals[THREAD_TILE_SIZE_X] = {static_cast(0)}; 45 | 46 | constexpr size_t NUM_VECTOR_UNITS{sizeof(int4) / sizeof(T)}; 47 | static_assert(sizeof(int4) % sizeof(T) == 0U); 48 | static_assert(BLOCK_TILE_SIZE_K % NUM_VECTOR_UNITS == 0U); 49 | static_assert(BLOCK_TILE_SIZE_X % NUM_VECTOR_UNITS == 0U); 50 | constexpr size_t VECTORIZED_THREAD_TILE_SIZE_X{THREAD_TILE_SIZE_X / 51 | NUM_VECTOR_UNITS}; 52 | static_assert(THREAD_TILE_SIZE_X % NUM_VECTOR_UNITS == 0U); 53 | 54 | for (size_t thread_block_tile_idx{0U}; 55 | thread_block_tile_idx < num_thread_block_tiles; 56 | ++thread_block_tile_idx) 57 | { 58 | load_data_from_global_memory_to_shared_memory< 59 | T, BLOCK_TILE_SIZE_X, BLOCK_TILE_SIZE_Y, BLOCK_TILE_SIZE_K, 60 | NUM_THREADS>(A, lda, B, ldb, A_thread_block_tile, 61 | B_thread_block_tile, thread_block_tile_idx, 62 | thread_linear_idx, m, n, k); 63 | __syncthreads(); 64 | 65 | #pragma unroll 66 | for (size_t k_i{0U}; k_i < BLOCK_TILE_SIZE_K; ++k_i) 67 | { 68 | size_t const A_thread_block_tile_row_idx{ 69 | thread_linear_idx / (BLOCK_TILE_SIZE_X / THREAD_TILE_SIZE_X) * 70 | THREAD_TILE_SIZE_Y}; 71 | size_t const A_thread_block_tile_col_idx{k_i}; 72 | 73 | #pragma unroll 74 | for (size_t thread_tile_row_idx{0U}; 75 | thread_tile_row_idx < THREAD_TILE_SIZE_Y; 76 | ++thread_tile_row_idx) 77 | { 78 | // There will be shared memory bank conflicts accessing the 79 | // values from A_thread_block_tile. We can do it better by 80 | // transposing the A_thread_block_tile when we load the data 81 | // from DRAM. 82 | A_vals[thread_tile_row_idx] = 83 | A_thread_block_tile[A_thread_block_tile_row_idx + 84 | thread_tile_row_idx] 85 | [A_thread_block_tile_col_idx]; 86 | } 87 | 88 | size_t const B_thread_block_tile_row_idx{k_i}; 89 | size_t const B_thread_block_tile_col_idx{ 90 | thread_linear_idx % (BLOCK_TILE_SIZE_X / THREAD_TILE_SIZE_X) * 91 | THREAD_TILE_SIZE_X}; 92 | // Although the read from A_thread_block_tile cannot be vectorized, the read 93 | // from B_thread_block_tile can be vectorized. 94 | #pragma unroll 95 | for (size_t thread_tile_col_vector_idx{0U}; 96 | thread_tile_col_vector_idx < VECTORIZED_THREAD_TILE_SIZE_X; 97 | ++thread_tile_col_vector_idx) 98 | { 99 | *reinterpret_cast( 100 | &B_vals[thread_tile_col_vector_idx * NUM_VECTOR_UNITS]) = 101 | *reinterpret_cast( 102 | &B_thread_block_tile[B_thread_block_tile_row_idx] 103 | [B_thread_block_tile_col_idx + 104 | thread_tile_col_vector_idx * 105 | NUM_VECTOR_UNITS]); 106 | } 107 | 108 | for (size_t thread_tile_row_idx{0U}; 109 | thread_tile_row_idx < THREAD_TILE_SIZE_Y; 110 | ++thread_tile_row_idx) 111 | { 112 | for (size_t thread_tile_col_idx{0U}; 113 | thread_tile_col_idx < THREAD_TILE_SIZE_X; 114 | ++thread_tile_col_idx) 115 | { 116 | C_thread_results[thread_tile_row_idx] 117 | [thread_tile_col_idx] += 118 | A_vals[thread_tile_row_idx] * 119 | B_vals[thread_tile_col_idx]; 120 | } 121 | } 122 | } 123 | __syncthreads(); 124 | } 125 | 126 | // Vectorized writing the results to DRAM. 
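// Note: this 16-byte store path implicitly assumes ldc (and the
// profiled n) is a multiple of NUM_VECTOR_UNITS, so every int4 write
// stays aligned and inside the current row of C; a fully general
// kernel would need a scalar tail path for other shapes.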
127 | for (size_t thread_tile_row_idx{0U}; 128 | thread_tile_row_idx < THREAD_TILE_SIZE_Y; ++thread_tile_row_idx) 129 | { 130 | for (size_t thread_tile_col_vector_idx{0U}; 131 | thread_tile_col_vector_idx < VECTORIZED_THREAD_TILE_SIZE_X; 132 | ++thread_tile_col_vector_idx) 133 | { 134 | size_t const C_row_idx{ 135 | blockIdx.y * BLOCK_TILE_SIZE_Y + 136 | thread_linear_idx / (BLOCK_TILE_SIZE_X / THREAD_TILE_SIZE_X) * 137 | THREAD_TILE_SIZE_Y + 138 | thread_tile_row_idx}; 139 | size_t const C_col_idx{ 140 | blockIdx.x * BLOCK_TILE_SIZE_X + 141 | thread_linear_idx % (BLOCK_TILE_SIZE_X / THREAD_TILE_SIZE_X) * 142 | THREAD_TILE_SIZE_X + 143 | thread_tile_col_vector_idx * NUM_VECTOR_UNITS}; 144 | // Vectorized read from C. 145 | int4 C_row_vector_vals{*reinterpret_cast( 146 | &C[C_row_idx * ldc + C_col_idx])}; 147 | // Vectorized read from C_thread_results. 148 | int4 const C_thread_results_row_vector_vals{ 149 | *reinterpret_cast( 150 | &C_thread_results[thread_tile_row_idx] 151 | [thread_tile_col_vector_idx * 152 | NUM_VECTOR_UNITS])}; 153 | // Update the values in C_row_vector_vals 154 | for (size_t i{0U}; i < NUM_VECTOR_UNITS; ++i) 155 | { 156 | reinterpret_cast(&C_row_vector_vals)[i] = 157 | alpha * reinterpret_cast( 158 | &C_thread_results_row_vector_vals)[i] + 159 | beta * reinterpret_cast(&C_row_vector_vals)[i]; 160 | } 161 | // Vectorized write to C. 162 | if (C_row_idx < m && C_col_idx < n) 163 | { 164 | // No need to mask out the out-of-bound invalid elements, 165 | // because the row of C matrix is 32-byte aligned. 166 | *reinterpret_cast(&C[C_row_idx * ldc + C_col_idx]) = 167 | C_row_vector_vals; 168 | } 169 | } 170 | } 171 | } 172 | 173 | template 174 | void launch_gemm_kernel_v04_vectorized(size_t m, size_t n, size_t k, 175 | T const* alpha, T const* A, size_t lda, 176 | T const* B, size_t ldb, T const* beta, 177 | T* C, size_t ldc, cudaStream_t stream) 178 | { 179 | // Feel free to play with the block tile sizes. 180 | // The algorithm correctness should always be guaranteed. 181 | constexpr unsigned int BLOCK_TILE_SIZE_X{128U}; 182 | constexpr unsigned int BLOCK_TILE_SIZE_Y{128U}; 183 | constexpr unsigned int BLOCK_TILE_SIZE_K{16U}; 184 | // Each thread computes THREAD_TILE_SIZE_X * THREAD_TILE_SIZE_Y values of C. 185 | constexpr unsigned int THREAD_TILE_SIZE_X{8U}; 186 | constexpr unsigned int THREAD_TILE_SIZE_Y{8U}; 187 | constexpr unsigned int NUM_THREADS_PER_BLOCK{ 188 | BLOCK_TILE_SIZE_X * BLOCK_TILE_SIZE_Y / 189 | (THREAD_TILE_SIZE_X * THREAD_TILE_SIZE_Y)}; 190 | static_assert(BLOCK_TILE_SIZE_X % THREAD_TILE_SIZE_X == 0U); 191 | static_assert(BLOCK_TILE_SIZE_Y % THREAD_TILE_SIZE_Y == 0U); 192 | static_assert(NUM_THREADS_PER_BLOCK % BLOCK_TILE_SIZE_K == 0U); 193 | static_assert(NUM_THREADS_PER_BLOCK % BLOCK_TILE_SIZE_X == 0U); 194 | static_assert( 195 | BLOCK_TILE_SIZE_X * BLOCK_TILE_SIZE_K % NUM_THREADS_PER_BLOCK == 0U); 196 | static_assert( 197 | BLOCK_TILE_SIZE_K * BLOCK_TILE_SIZE_Y % NUM_THREADS_PER_BLOCK == 0U); 198 | dim3 const block_dim{NUM_THREADS_PER_BLOCK, 1U, 1U}; 199 | dim3 const grid_dim{ 200 | (static_cast(n) + BLOCK_TILE_SIZE_X - 1U) / 201 | BLOCK_TILE_SIZE_X, 202 | (static_cast(m) + BLOCK_TILE_SIZE_Y - 1U) / 203 | BLOCK_TILE_SIZE_Y, 204 | 1U}; 205 | gemm_v04_vectorized 208 | <<>>(m, n, k, *alpha, A, lda, B, ldb, 209 | *beta, C, ldc); 210 | CHECK_LAST_CUDA_ERROR(); 211 | } 212 | 213 | // Explicit instantiation. 
214 | template void launch_gemm_kernel_v04_vectorized( 215 | size_t m, size_t n, size_t k, float const* alpha, float const* A, 216 | size_t lda, float const* B, size_t ldb, float const* beta, float* C, 217 | size_t ldc, cudaStream_t stream); 218 | template void launch_gemm_kernel_v04_vectorized( 219 | size_t m, size_t n, size_t k, double const* alpha, double const* A, 220 | size_t lda, double const* B, size_t ldb, double const* beta, double* C, 221 | size_t ldc, cudaStream_t stream); 222 | template void launch_gemm_kernel_v04_vectorized<__half>( 223 | size_t m, size_t n, size_t k, __half const* alpha, __half const* A, 224 | size_t lda, __half const* B, size_t ldb, __half const* beta, __half* C, 225 | size_t ldc, cudaStream_t stream); -------------------------------------------------------------------------------- /src/07_2d_block_tiling_2d_warp_tiling_2d_thread_tiling_matrix_transpose_wmma_vectorized_memory_access.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "cuda_gemm.hpp" 5 | #include "cuda_gemm_utils.cuh" 6 | #include "cuda_gemm_utils.hpp" 7 | 8 | // https://developer.nvidia.com/blog/cutlass-linear-algebra-cuda/ 9 | // https://github.com/NVIDIA/cutlass/blob/b7508e337938137a699e486d8997646980acfc58/media/docs/programming_guidelines.md 10 | 11 | // GEMM kernel v07. 12 | // Each thread in the block processes THREAD_TILE_SIZE_Y * 13 | // THREAD_TILE_SIZE_X output values. Number of threads BLOCK_TILE_SIZE_Y * 14 | // BLOCK_TILE_SIZE_X / (THREAD_TILE_SIZE_Y * THREAD_TILE_SIZE_X) 15 | template 20 | __global__ void gemm_v07_vectorized(size_t m, size_t n, size_t k, T alpha, 21 | T const* A, size_t lda, T const* B, 22 | size_t ldb, T beta, T* C, size_t ldc) 23 | { 24 | constexpr size_t NUM_WARPS_X{BLOCK_TILE_SIZE_X / WARP_TILE_SIZE_X}; 25 | static_assert(BLOCK_TILE_SIZE_X % WARP_TILE_SIZE_X == 0U); 26 | static_assert(BLOCK_TILE_SIZE_Y % WARP_TILE_SIZE_Y == 0U); 27 | 28 | // Cache a tile of A and B in shared memory for data reuse. 29 | __shared__ T A_thread_block_tile_transposed[BLOCK_TILE_SIZE_K] 30 | [BLOCK_TILE_SIZE_Y + 31 | BLOCK_TILE_SKEW_SIZE_Y]; 32 | __shared__ T B_thread_block_tile[BLOCK_TILE_SIZE_K][BLOCK_TILE_SIZE_X + 33 | BLOCK_TILE_SKEW_SIZE_X]; 34 | 35 | constexpr size_t NUM_WMMA_TILES_X{WARP_TILE_SIZE_X / WMMA_TILE_SIZE_X}; 36 | static_assert(WARP_TILE_SIZE_X % WMMA_TILE_SIZE_X == 0U); 37 | constexpr size_t NUM_WMMA_TILES_Y{WARP_TILE_SIZE_Y / WMMA_TILE_SIZE_Y}; 38 | static_assert(WARP_TILE_SIZE_Y % WMMA_TILE_SIZE_Y == 0U); 39 | constexpr size_t NUM_WMMA_TILES_K{BLOCK_TILE_SIZE_K / WMMA_TILE_SIZE_K}; 40 | static_assert(BLOCK_TILE_SIZE_K % WMMA_TILE_SIZE_K == 0U); 41 | 42 | // Declare the fragments. 43 | nvcuda::wmma::fragment 46 | a_frags[NUM_WMMA_TILES_Y]; 47 | nvcuda::wmma::fragment 50 | b_frags[NUM_WMMA_TILES_X]; 51 | nvcuda::wmma::fragment 53 | acc_frags[NUM_WMMA_TILES_Y][NUM_WMMA_TILES_X]; 54 | nvcuda::wmma::fragment 56 | c_frag; 57 | 58 | // Make sure the accumulator starts from 0. 
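// With the configuration assumed by the launcher below (64 x 32 warp
// tile, 16 x 16 x 16 WMMA tile), each warp owns
// NUM_WMMA_TILES_Y x NUM_WMMA_TILES_X = 4 x 2 accumulator fragments;
// fill_fragment() zeroes every element of each fragment before the
// K loop accumulates into them.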
59 | #pragma unroll 60 | for (size_t wmma_tile_row_idx{0U}; wmma_tile_row_idx < NUM_WMMA_TILES_Y; 61 | ++wmma_tile_row_idx) 62 | { 63 | for (size_t wmma_tile_col_idx{0U}; wmma_tile_col_idx < NUM_WMMA_TILES_X; 64 | ++wmma_tile_col_idx) 65 | { 66 | nvcuda::wmma::fill_fragment( 67 | acc_frags[wmma_tile_row_idx][wmma_tile_col_idx], 68 | static_cast(0)); 69 | } 70 | } 71 | 72 | size_t const thread_linear_idx{threadIdx.y * blockDim.x + threadIdx.x}; 73 | size_t const warp_linear_idx{thread_linear_idx / 32U}; 74 | size_t const warp_row_idx{warp_linear_idx / NUM_WARPS_X}; 75 | size_t const warp_col_idx{warp_linear_idx % NUM_WARPS_X}; 76 | 77 | // Number of outer loops to perform the sum of inner products. 78 | // C_thread_block_tile = 79 | // \sigma_{thread_block_tile_idx=0}^{num_thread_block_tiles-1} A[:, 80 | // thread_block_tile_idx:BLOCK_TILE_SIZE_K] * 81 | // B[thread_block_tile_idx:BLOCK_TILE_SIZE_K, :] 82 | size_t const num_thread_block_tiles{(k + BLOCK_TILE_SIZE_K - 1) / 83 | BLOCK_TILE_SIZE_K}; 84 | 85 | for (size_t thread_block_tile_idx{0U}; 86 | thread_block_tile_idx < num_thread_block_tiles; 87 | ++thread_block_tile_idx) 88 | { 89 | load_data_from_global_memory_to_shared_memory_transposed_vectorized< 90 | T, BLOCK_TILE_SIZE_X, BLOCK_TILE_SIZE_Y, BLOCK_TILE_SIZE_K, 91 | NUM_THREADS, BLOCK_TILE_SKEW_SIZE_X, BLOCK_TILE_SKEW_SIZE_Y>( 92 | A, lda, B, ldb, A_thread_block_tile_transposed, B_thread_block_tile, 93 | thread_block_tile_idx, thread_linear_idx, m, n, k); 94 | __syncthreads(); 95 | 96 | // Perform A[:, thread_block_tile_idx:BLOCK_TILE_SIZE_K] * 97 | // B[thread_block_tile_idx:BLOCK_TILE_SIZE_K, :] where A[:, 98 | // thread_block_tile_idx:BLOCK_TILE_SIZE_K] and 99 | // B[thread_block_tile_idx:BLOCK_TILE_SIZE_K, :] are cached in the 100 | // shared memory as A_thread_block_tile and B_thread_block_tile, 101 | // respectively. This inner product is further decomposed to 102 | // BLOCK_TILE_SIZE_K outer products. A_thread_block_tile * 103 | // B_thread_block_tile = \sigma_{k_i=0}^{BLOCK_TILE_SIZE_K-1} 104 | // A_thread_block_tile[:, k_i] @ B_thread_block_tile[k_i, :] Note that 105 | // both A_thread_block_tile and B_thread_block_tile can be cached in the 106 | // register. 107 | #pragma unroll 108 | for (size_t k_i{0U}; k_i < NUM_WMMA_TILES_K; ++k_i) 109 | { 110 | #pragma unroll 111 | for (size_t wmma_tile_row_idx{0U}; 112 | wmma_tile_row_idx < NUM_WMMA_TILES_Y; ++wmma_tile_row_idx) 113 | { 114 | nvcuda::wmma::load_matrix_sync( 115 | a_frags[wmma_tile_row_idx], 116 | &A_thread_block_tile_transposed[k_i * WMMA_TILE_SIZE_K] 117 | [warp_row_idx * 118 | WARP_TILE_SIZE_Y + 119 | wmma_tile_row_idx * 120 | WMMA_TILE_SIZE_Y], 121 | BLOCK_TILE_SIZE_Y + BLOCK_TILE_SKEW_SIZE_Y); 122 | } 123 | #pragma unroll 124 | for (size_t wmma_tile_col_idx{0U}; 125 | wmma_tile_col_idx < NUM_WMMA_TILES_X; ++wmma_tile_col_idx) 126 | { 127 | nvcuda::wmma::load_matrix_sync( 128 | b_frags[wmma_tile_col_idx], 129 | &B_thread_block_tile[k_i * WMMA_TILE_SIZE_K] 130 | [warp_col_idx * WARP_TILE_SIZE_X + 131 | wmma_tile_col_idx * WMMA_TILE_SIZE_X], 132 | BLOCK_TILE_SIZE_X + BLOCK_TILE_SKEW_SIZE_X); 133 | } 134 | #pragma unroll 135 | for (size_t wmma_tile_row_idx{0U}; 136 | wmma_tile_row_idx < NUM_WMMA_TILES_Y; ++wmma_tile_row_idx) 137 | { 138 | #pragma unroll 139 | for (size_t wmma_tile_col_idx{0U}; 140 | wmma_tile_col_idx < NUM_WMMA_TILES_X; ++wmma_tile_col_idx) 141 | { 142 | // Perform the matrix multiplication. 
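// Each mma_sync call computes acc = a_frag * b_frag + acc for one
// WMMA tile (16 x 16 x 16 with the configuration assumed below) on the
// Tensor Cores, replacing the per-thread scalar FMA loops of the
// earlier kernels.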
143 | nvcuda::wmma::mma_sync( 144 | acc_frags[wmma_tile_row_idx][wmma_tile_col_idx], 145 | a_frags[wmma_tile_row_idx], b_frags[wmma_tile_col_idx], 146 | acc_frags[wmma_tile_row_idx][wmma_tile_col_idx]); 147 | } 148 | } 149 | } 150 | __syncthreads(); 151 | } 152 | 153 | // Write the results to DRAM. 154 | #pragma unroll 155 | for (size_t wmma_tile_row_idx{0U}; wmma_tile_row_idx < NUM_WMMA_TILES_Y; 156 | ++wmma_tile_row_idx) 157 | { 158 | #pragma unroll 159 | for (size_t wmma_tile_col_idx{0U}; wmma_tile_col_idx < NUM_WMMA_TILES_X; 160 | ++wmma_tile_col_idx) 161 | { 162 | // Load the fragment from global memory. 163 | nvcuda::wmma::load_matrix_sync( 164 | c_frag, 165 | &C[(blockIdx.y * BLOCK_TILE_SIZE_Y + 166 | warp_row_idx * WARP_TILE_SIZE_Y + 167 | wmma_tile_row_idx * WMMA_TILE_SIZE_Y) * 168 | n + 169 | blockIdx.x * BLOCK_TILE_SIZE_X + 170 | warp_col_idx * WARP_TILE_SIZE_X + 171 | wmma_tile_col_idx * WMMA_TILE_SIZE_X], 172 | n, nvcuda::wmma::mem_row_major); 173 | // Perform scaling and addition. 174 | for (size_t i{0}; i < c_frag.num_elements; ++i) 175 | { 176 | c_frag.x[i] = 177 | alpha * 178 | acc_frags[wmma_tile_row_idx][wmma_tile_col_idx].x[i] + 179 | beta * c_frag.x[i]; 180 | } 181 | // Store the fragment back to global memory. 182 | nvcuda::wmma::store_matrix_sync( 183 | &C[(blockIdx.y * BLOCK_TILE_SIZE_Y + 184 | warp_row_idx * WARP_TILE_SIZE_Y + 185 | wmma_tile_row_idx * WMMA_TILE_SIZE_Y) * 186 | n + 187 | blockIdx.x * BLOCK_TILE_SIZE_X + 188 | warp_col_idx * WARP_TILE_SIZE_X + 189 | wmma_tile_col_idx * WMMA_TILE_SIZE_X], 190 | c_frag, n, nvcuda::wmma::mem_row_major); 191 | } 192 | } 193 | } 194 | 195 | template 196 | void launch_gemm_kernel_v07_vectorized(size_t m, size_t n, size_t k, 197 | T const* alpha, T const* A, size_t lda, 198 | T const* B, size_t ldb, T const* beta, 199 | T* C, size_t ldc, cudaStream_t stream) 200 | { 201 | // Feel free to play with the block tile sizes. 202 | // The algorithm correctness should always be guaranteed. 203 | constexpr unsigned int BLOCK_TILE_SIZE_X{128U}; 204 | constexpr unsigned int BLOCK_TILE_SIZE_Y{128U}; 205 | constexpr unsigned int BLOCK_TILE_SIZE_K{16U}; 206 | 207 | // The skew size is used to avoid bank conflicts in shared memory. 208 | constexpr size_t BLOCK_TILE_SKEW_SIZE_X{16U}; 209 | constexpr size_t BLOCK_TILE_SKEW_SIZE_Y{16U}; 210 | 211 | constexpr unsigned int WARP_TILE_SIZE_X{32U}; 212 | constexpr unsigned int WARP_TILE_SIZE_Y{64U}; 213 | constexpr unsigned int NUM_WARPS_X{BLOCK_TILE_SIZE_X / WARP_TILE_SIZE_X}; 214 | constexpr unsigned int NUM_WARPS_Y{BLOCK_TILE_SIZE_Y / WARP_TILE_SIZE_Y}; 215 | static_assert(BLOCK_TILE_SIZE_X % WARP_TILE_SIZE_X == 0U); 216 | static_assert(BLOCK_TILE_SIZE_Y % WARP_TILE_SIZE_Y == 0U); 217 | 218 | constexpr unsigned int WMMA_TILE_SIZE_X{16U}; 219 | constexpr unsigned int WMMA_TILE_SIZE_Y{16U}; 220 | constexpr unsigned int WMMA_TILE_SIZE_K{16U}; 221 | 222 | constexpr unsigned int NUM_THREADS_PER_BLOCK{NUM_WARPS_X * NUM_WARPS_Y * 223 | 32U}; 224 | 225 | dim3 const block_dim{NUM_THREADS_PER_BLOCK, 1U, 1U}; 226 | dim3 const grid_dim{ 227 | (static_cast(n) + BLOCK_TILE_SIZE_X - 1U) / 228 | BLOCK_TILE_SIZE_X, 229 | (static_cast(m) + BLOCK_TILE_SIZE_Y - 1U) / 230 | BLOCK_TILE_SIZE_Y, 231 | 1U}; 232 | gemm_v07_vectorized 237 | <<>>(m, n, k, *alpha, A, lda, B, ldb, 238 | *beta, C, ldc); 239 | CHECK_LAST_CUDA_ERROR(); 240 | } 241 | 242 | // Explicit instantiation. 
243 | template void launch_gemm_kernel_v07_vectorized<__half>( 244 | size_t m, size_t n, size_t k, __half const* alpha, __half const* A, 245 | size_t lda, __half const* B, size_t ldb, __half const* beta, __half* C, 246 | size_t ldc, cudaStream_t stream); -------------------------------------------------------------------------------- /src/07_2d_block_tiling_2d_warp_tiling_2d_thread_tiling_matrix_transpose_wmma.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "cuda_gemm.hpp" 5 | #include "cuda_gemm_utils.cuh" 6 | #include "cuda_gemm_utils.hpp" 7 | 8 | // https://developer.nvidia.com/blog/cutlass-linear-algebra-cuda/ 9 | // https://github.com/NVIDIA/cutlass/blob/b7508e337938137a699e486d8997646980acfc58/media/docs/programming_guidelines.md 10 | 11 | // GEMM kernel v07. 12 | // Each thread in the block processes THREAD_TILE_SIZE_Y * 13 | // THREAD_TILE_SIZE_X output values. Number of threads BLOCK_TILE_SIZE_Y * 14 | // BLOCK_TILE_SIZE_X / (THREAD_TILE_SIZE_Y * THREAD_TILE_SIZE_X) 15 | template 20 | __global__ void gemm_v07(size_t m, size_t n, size_t k, T alpha, T const* A, 21 | size_t lda, T const* B, size_t ldb, T beta, T* C, 22 | size_t ldc) 23 | { 24 | constexpr size_t NUM_WARPS_X{BLOCK_TILE_SIZE_X / WARP_TILE_SIZE_X}; 25 | static_assert(BLOCK_TILE_SIZE_X % WARP_TILE_SIZE_X == 0U); 26 | static_assert(BLOCK_TILE_SIZE_Y % WARP_TILE_SIZE_Y == 0U); 27 | 28 | // Cache a tile of A and B in shared memory for data reuse. 29 | __shared__ T A_thread_block_tile_transposed[BLOCK_TILE_SIZE_K] 30 | [BLOCK_TILE_SIZE_Y + 31 | BLOCK_TILE_SKEW_SIZE_Y]; 32 | __shared__ T B_thread_block_tile[BLOCK_TILE_SIZE_K][BLOCK_TILE_SIZE_X + 33 | BLOCK_TILE_SKEW_SIZE_X]; 34 | 35 | constexpr size_t NUM_WMMA_TILES_X{WARP_TILE_SIZE_X / WMMA_TILE_SIZE_X}; 36 | static_assert(WARP_TILE_SIZE_X % WMMA_TILE_SIZE_X == 0U); 37 | constexpr size_t NUM_WMMA_TILES_Y{WARP_TILE_SIZE_Y / WMMA_TILE_SIZE_Y}; 38 | static_assert(WARP_TILE_SIZE_Y % WMMA_TILE_SIZE_Y == 0U); 39 | constexpr size_t NUM_WMMA_TILES_K{BLOCK_TILE_SIZE_K / WMMA_TILE_SIZE_K}; 40 | static_assert(BLOCK_TILE_SIZE_K % WMMA_TILE_SIZE_K == 0U); 41 | 42 | // Declare the fragments. 43 | nvcuda::wmma::fragment 46 | a_frags[NUM_WMMA_TILES_Y]; 47 | nvcuda::wmma::fragment 50 | b_frags[NUM_WMMA_TILES_X]; 51 | nvcuda::wmma::fragment 53 | acc_frags[NUM_WMMA_TILES_Y][NUM_WMMA_TILES_X]; 54 | nvcuda::wmma::fragment 56 | c_frag; 57 | 58 | // Make sure the accumulator starts from 0. 59 | #pragma unroll 60 | for (size_t wmma_tile_row_idx{0U}; wmma_tile_row_idx < NUM_WMMA_TILES_Y; 61 | ++wmma_tile_row_idx) 62 | { 63 | for (size_t wmma_tile_col_idx{0U}; wmma_tile_col_idx < NUM_WMMA_TILES_X; 64 | ++wmma_tile_col_idx) 65 | { 66 | nvcuda::wmma::fill_fragment( 67 | acc_frags[wmma_tile_row_idx][wmma_tile_col_idx], 68 | static_cast(0)); 69 | } 70 | } 71 | 72 | size_t const thread_linear_idx{threadIdx.y * blockDim.x + threadIdx.x}; 73 | size_t const warp_linear_idx{thread_linear_idx / 32U}; 74 | size_t const warp_row_idx{warp_linear_idx / NUM_WARPS_X}; 75 | size_t const warp_col_idx{warp_linear_idx % NUM_WARPS_X}; 76 | 77 | // Number of outer loops to perform the sum of inner products. 
78 | // C_thread_block_tile = 79 | // \sigma_{thread_block_tile_idx=0}^{num_thread_block_tiles-1} A[:, 80 | // thread_block_tile_idx:BLOCK_TILE_SIZE_K] * 81 | // B[thread_block_tile_idx:BLOCK_TILE_SIZE_K, :] 82 | size_t const num_thread_block_tiles{(k + BLOCK_TILE_SIZE_K - 1) / 83 | BLOCK_TILE_SIZE_K}; 84 | 85 | for (size_t thread_block_tile_idx{0U}; 86 | thread_block_tile_idx < num_thread_block_tiles; 87 | ++thread_block_tile_idx) 88 | { 89 | load_data_from_global_memory_to_shared_memory_transposed< 90 | T, BLOCK_TILE_SIZE_X, BLOCK_TILE_SIZE_Y, BLOCK_TILE_SIZE_K, 91 | NUM_THREADS, BLOCK_TILE_SKEW_SIZE_X, BLOCK_TILE_SKEW_SIZE_Y>( 92 | A, lda, B, ldb, A_thread_block_tile_transposed, B_thread_block_tile, 93 | thread_block_tile_idx, thread_linear_idx, m, n, k); 94 | __syncthreads(); 95 | 96 | // Perform A[:, thread_block_tile_idx:BLOCK_TILE_SIZE_K] * 97 | // B[thread_block_tile_idx:BLOCK_TILE_SIZE_K, :] where A[:, 98 | // thread_block_tile_idx:BLOCK_TILE_SIZE_K] and 99 | // B[thread_block_tile_idx:BLOCK_TILE_SIZE_K, :] are cached in the 100 | // shared memory as A_thread_block_tile and B_thread_block_tile, 101 | // respectively. This inner product is further decomposed to 102 | // BLOCK_TILE_SIZE_K outer products. A_thread_block_tile * 103 | // B_thread_block_tile = \sigma_{k_i=0}^{BLOCK_TILE_SIZE_K-1} 104 | // A_thread_block_tile[:, k_i] @ B_thread_block_tile[k_i, :] Note that 105 | // both A_thread_block_tile and B_thread_block_tile can be cached in the 106 | // register. 107 | #pragma unroll 108 | for (size_t k_i{0U}; k_i < NUM_WMMA_TILES_K; ++k_i) 109 | { 110 | #pragma unroll 111 | for (size_t wmma_tile_row_idx{0U}; 112 | wmma_tile_row_idx < NUM_WMMA_TILES_Y; ++wmma_tile_row_idx) 113 | { 114 | nvcuda::wmma::load_matrix_sync( 115 | a_frags[wmma_tile_row_idx], 116 | &A_thread_block_tile_transposed[k_i * WMMA_TILE_SIZE_K] 117 | [warp_row_idx * 118 | WARP_TILE_SIZE_Y + 119 | wmma_tile_row_idx * 120 | WMMA_TILE_SIZE_Y], 121 | BLOCK_TILE_SIZE_Y + BLOCK_TILE_SKEW_SIZE_Y); 122 | } 123 | #pragma unroll 124 | for (size_t wmma_tile_col_idx{0U}; 125 | wmma_tile_col_idx < NUM_WMMA_TILES_X; ++wmma_tile_col_idx) 126 | { 127 | nvcuda::wmma::load_matrix_sync( 128 | b_frags[wmma_tile_col_idx], 129 | &B_thread_block_tile[k_i * WMMA_TILE_SIZE_K] 130 | [warp_col_idx * WARP_TILE_SIZE_X + 131 | wmma_tile_col_idx * WMMA_TILE_SIZE_X], 132 | BLOCK_TILE_SIZE_X + BLOCK_TILE_SKEW_SIZE_X); 133 | } 134 | #pragma unroll 135 | for (size_t wmma_tile_row_idx{0U}; 136 | wmma_tile_row_idx < NUM_WMMA_TILES_Y; ++wmma_tile_row_idx) 137 | { 138 | #pragma unroll 139 | for (size_t wmma_tile_col_idx{0U}; 140 | wmma_tile_col_idx < NUM_WMMA_TILES_X; ++wmma_tile_col_idx) 141 | { 142 | // Perform the matrix multiplication. 143 | nvcuda::wmma::mma_sync( 144 | acc_frags[wmma_tile_row_idx][wmma_tile_col_idx], 145 | a_frags[wmma_tile_row_idx], b_frags[wmma_tile_col_idx], 146 | acc_frags[wmma_tile_row_idx][wmma_tile_col_idx]); 147 | } 148 | } 149 | } 150 | __syncthreads(); 151 | } 152 | 153 | // Write the results to DRAM. 154 | #pragma unroll 155 | for (size_t wmma_tile_row_idx{0U}; wmma_tile_row_idx < NUM_WMMA_TILES_Y; 156 | ++wmma_tile_row_idx) 157 | { 158 | #pragma unroll 159 | for (size_t wmma_tile_col_idx{0U}; wmma_tile_col_idx < NUM_WMMA_TILES_X; 160 | ++wmma_tile_col_idx) 161 | { 162 | // Load the fragment from global memory. 
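// Epilogue for one WMMA tile: load the existing 16 x 16 tile of C into
// c_frag, blend it with the accumulator as alpha * acc + beta * C in
// registers, then store it back in row-major order with leading
// dimension n.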
163 | nvcuda::wmma::load_matrix_sync( 164 | c_frag, 165 | &C[(blockIdx.y * BLOCK_TILE_SIZE_Y + 166 | warp_row_idx * WARP_TILE_SIZE_Y + 167 | wmma_tile_row_idx * WMMA_TILE_SIZE_Y) * 168 | n + 169 | blockIdx.x * BLOCK_TILE_SIZE_X + 170 | warp_col_idx * WARP_TILE_SIZE_X + 171 | wmma_tile_col_idx * WMMA_TILE_SIZE_X], 172 | n, nvcuda::wmma::mem_row_major); 173 | // Perform scaling and addition. 174 | for (size_t i{0}; i < c_frag.num_elements; ++i) 175 | { 176 | c_frag.x[i] = 177 | alpha * 178 | acc_frags[wmma_tile_row_idx][wmma_tile_col_idx].x[i] + 179 | beta * c_frag.x[i]; 180 | } 181 | // Store the fragment back to global memory. 182 | nvcuda::wmma::store_matrix_sync( 183 | &C[(blockIdx.y * BLOCK_TILE_SIZE_Y + 184 | warp_row_idx * WARP_TILE_SIZE_Y + 185 | wmma_tile_row_idx * WMMA_TILE_SIZE_Y) * 186 | n + 187 | blockIdx.x * BLOCK_TILE_SIZE_X + 188 | warp_col_idx * WARP_TILE_SIZE_X + 189 | wmma_tile_col_idx * WMMA_TILE_SIZE_X], 190 | c_frag, n, nvcuda::wmma::mem_row_major); 191 | } 192 | } 193 | } 194 | 195 | template 196 | void launch_gemm_kernel_v07(size_t m, size_t n, size_t k, T const* alpha, 197 | T const* A, size_t lda, T const* B, size_t ldb, 198 | T const* beta, T* C, size_t ldc, 199 | cudaStream_t stream) 200 | { 201 | // Feel free to play with the block tile sizes. 202 | // The algorithm correctness should always be guaranteed. 203 | constexpr unsigned int BLOCK_TILE_SIZE_X{128U}; 204 | constexpr unsigned int BLOCK_TILE_SIZE_Y{128U}; 205 | constexpr unsigned int BLOCK_TILE_SIZE_K{16U}; 206 | 207 | constexpr unsigned int WARP_TILE_SIZE_X{32U}; 208 | constexpr unsigned int WARP_TILE_SIZE_Y{64U}; 209 | constexpr unsigned int NUM_WARPS_X{BLOCK_TILE_SIZE_X / WARP_TILE_SIZE_X}; 210 | constexpr unsigned int NUM_WARPS_Y{BLOCK_TILE_SIZE_Y / WARP_TILE_SIZE_Y}; 211 | static_assert(BLOCK_TILE_SIZE_X % WARP_TILE_SIZE_X == 0U); 212 | static_assert(BLOCK_TILE_SIZE_Y % WARP_TILE_SIZE_Y == 0U); 213 | 214 | // The skew size is used to avoid bank conflicts in shared memory. 215 | constexpr size_t BLOCK_TILE_SKEW_SIZE_X{16U}; 216 | constexpr size_t BLOCK_TILE_SKEW_SIZE_Y{16U}; 217 | 218 | constexpr unsigned int WMMA_TILE_SIZE_X{16U}; 219 | constexpr unsigned int WMMA_TILE_SIZE_Y{16U}; 220 | constexpr unsigned int WMMA_TILE_SIZE_K{16U}; 221 | 222 | constexpr unsigned int NUM_THREADS_PER_BLOCK{NUM_WARPS_X * NUM_WARPS_Y * 223 | 32U}; 224 | 225 | dim3 const block_dim{NUM_THREADS_PER_BLOCK, 1U, 1U}; 226 | dim3 const grid_dim{ 227 | (static_cast(n) + BLOCK_TILE_SIZE_X - 1U) / 228 | BLOCK_TILE_SIZE_X, 229 | (static_cast(m) + BLOCK_TILE_SIZE_Y - 1U) / 230 | BLOCK_TILE_SIZE_Y, 231 | 1U}; 232 | gemm_v07 236 | <<>>(m, n, k, *alpha, A, lda, B, ldb, 237 | *beta, C, ldc); 238 | CHECK_LAST_CUDA_ERROR(); 239 | } 240 | 241 | // Explicit instantiation. 
242 | template void launch_gemm_kernel_v07<__half>(size_t m, size_t n, size_t k, 243 | __half const* alpha, 244 | __half const* A, size_t lda, 245 | __half const* B, size_t ldb, 246 | __half const* beta, __half* C, 247 | size_t ldc, cudaStream_t stream); -------------------------------------------------------------------------------- /include/profile_utils.cuh: -------------------------------------------------------------------------------- 1 | #ifndef PROFILE_UTILS_CUH 2 | #define PROFILE_UTILS_CUH 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include "cuda_gemm.hpp" 11 | #include "cuda_gemm_utils.cuh" 12 | 13 | #include 14 | #include 15 | #include 16 | 17 | template 18 | float measure_performance(std::function bound_function, 19 | cudaStream_t stream, size_t num_repeats = 100, 20 | size_t num_warmups = 100) 21 | { 22 | cudaEvent_t start, stop; 23 | float time; 24 | 25 | CHECK_CUDA_ERROR(cudaEventCreate(&start)); 26 | CHECK_CUDA_ERROR(cudaEventCreate(&stop)); 27 | 28 | for (size_t i{0}; i < num_warmups; ++i) 29 | { 30 | bound_function(stream); 31 | } 32 | 33 | CHECK_CUDA_ERROR(cudaStreamSynchronize(stream)); 34 | 35 | CHECK_CUDA_ERROR(cudaEventRecord(start, stream)); 36 | for (size_t i{0}; i < num_repeats; ++i) 37 | { 38 | bound_function(stream); 39 | } 40 | CHECK_CUDA_ERROR(cudaEventRecord(stop, stream)); 41 | CHECK_CUDA_ERROR(cudaEventSynchronize(stop)); 42 | CHECK_LAST_CUDA_ERROR(); 43 | CHECK_CUDA_ERROR(cudaEventElapsedTime(&time, start, stop)); 44 | CHECK_CUDA_ERROR(cudaEventDestroy(start)); 45 | CHECK_CUDA_ERROR(cudaEventDestroy(stop)); 46 | 47 | float const latency{time / num_repeats}; 48 | 49 | return latency; 50 | } 51 | 52 | #define CHECK_CUBLASS_ERROR(val) check_cublass((val), #val, __FILE__, __LINE__) 53 | void check_cublass(cublasStatus_t err, const char* const func, 54 | const char* const file, const int line) 55 | { 56 | if (err != CUBLAS_STATUS_SUCCESS) 57 | { 58 | std::cerr << "cuBLAS Error at: " << file << ":" << line << std::endl; 59 | std::cerr << cublasGetStatusString(err) << std::endl; 60 | std::exit(EXIT_FAILURE); 61 | } 62 | } 63 | 64 | // Determine CUDA data type from type. 65 | template ::value || 67 | std::is_same::value || 68 | std::is_same::value, 69 | bool>::type = true> 70 | constexpr cudaDataType_t cuda_data_type_trait() 71 | { 72 | if (std::is_same::value) 73 | { 74 | return CUDA_R_32F; 75 | } 76 | else if (std::is_same::value) 77 | { 78 | return CUDA_R_64F; 79 | } 80 | else if (std::is_same::value) 81 | { 82 | return CUDA_R_16F; 83 | } 84 | else 85 | { 86 | throw std::runtime_error("Unsupported data type."); 87 | } 88 | } 89 | 90 | template ::value || 92 | std::is_same::value || 93 | std::is_same::value, 94 | bool>::type = true> 95 | void launch_gemm_cublas(size_t m, size_t n, size_t k, T const* alpha, 96 | T const* A, size_t lda, T const* B, size_t ldb, 97 | T const* beta, T* C, size_t ldc, cublasHandle_t handle) 98 | { 99 | // Non-TensorCore algorithm? 100 | constexpr cublasGemmAlgo_t algo{CUBLAS_GEMM_DEFAULT}; 101 | constexpr cudaDataType_t data_type{cuda_data_type_trait()}; 102 | // All the matrix are in row-major order. 
103 | // https://docs.nvidia.com/cuda/cublas/#cublasgemmex 104 | // A: m x k row-major -> A: k x m column-major non-transposed 105 | // B: k x n row-major -> B: n x k column-major non-transposed 106 | // C: m x n row-major -> C: n x m column-major non-transposed 107 | // Thus, without padding, the leading dimension of the matrix in row-major 108 | // order is the number of columns, i.e., k for A, n for B, and n for C. 109 | // Row-major order: C = AB + C 110 | // Column-major order: C = BA + C 111 | // The cuBLAS API requires the leading dimension of the matrix in 112 | // column-major order. This API call looks non-intuitive, but it is correct. 113 | CHECK_CUBLASS_ERROR(cublasGemmEx( 114 | handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, alpha, B, data_type, ldb, A, 115 | data_type, lda, beta, C, data_type, ldc, data_type, algo)); 116 | } 117 | 118 | template ::value || 120 | std::is_same::value, 121 | bool>::type = true> 122 | void launch_gemm_cpu(size_t m, size_t n, size_t k, T const* alpha, T const* A, 123 | size_t lda, T const* B, size_t ldb, T const* beta, T* C, 124 | size_t ldc) 125 | { 126 | // Compute GEMM using CPU. 127 | for (size_t i{0U}; i < m; ++i) 128 | { 129 | for (size_t j{0U}; j < n; ++j) 130 | { 131 | T sum{static_cast(0)}; 132 | for (size_t l{0U}; l < k; ++l) 133 | { 134 | sum += A[i * lda + l] * B[l * ldb + j]; 135 | } 136 | C[i * ldc + j] = (*alpha) * sum + (*beta) * C[i * ldc + j]; 137 | } 138 | } 139 | } 140 | 141 | // Many different implementations have been tried for FP16 GEMM on CPU. 142 | // There is always a discrepancy between the results from CPU and GPU (cuBLAS or 143 | // custom kernel). 144 | template ::value, 145 | bool>::type = true> 146 | void launch_gemm_cpu(size_t m, size_t n, size_t k, T const* alpha, T const* A, 147 | size_t lda, T const* B, size_t ldb, T const* beta, T* C, 148 | size_t ldc) 149 | { 150 | // Compute GEMM using CPU. 
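// This reference accumulates each dot product in float (via the
// __half2float conversions below) and only rounds back to __half at
// the end, which reduces, but does not eliminate, the CPU/GPU
// discrepancy mentioned above.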
151 | for (size_t i{0U}; i < m; ++i) 152 | { 153 | for (size_t j{0U}; j < n; ++j) 154 | { 155 | float sum{0.0f}; 156 | for (size_t l{0U}; l < k; ++l) 157 | { 158 | sum += __half2float(__hmul(A[i * lda + l], B[l * ldb + j])); 159 | } 160 | C[i * ldc + j] = __float2half(__half2float(*alpha) * sum + 161 | __half2float(*beta) * 162 | __half2float(C[i * ldc + j])); 163 | } 164 | } 165 | } 166 | 167 | template 168 | bool all_close(T const* C, T const* C_ref, size_t m, size_t n, size_t ldc, 169 | T abs_tol, double rel_tol) 170 | { 171 | bool status{true}; 172 | for (size_t i{0U}; i < m; ++i) 173 | { 174 | for (size_t j{0U}; j < n; ++j) 175 | { 176 | double const C_val{static_cast(C[i * ldc + j])}; 177 | double const C_ref_val{static_cast(C_ref[i * ldc + j])}; 178 | double const diff{C_val - C_ref_val}; 179 | double const diff_val{std::abs(diff)}; 180 | if (diff_val > 181 | std::max(static_cast(abs_tol), 182 | static_cast(std::abs(C_ref_val)) * rel_tol)) 183 | { 184 | std::cout << "C[" << i << ", " << j << "] = " << C_val 185 | << " C_ref[" << i << ", " << j << "] = " << C_ref_val 186 | << " Abs Diff: " << diff_val 187 | << " Abs Diff Threshold: " 188 | << static_cast(abs_tol) 189 | << " Rel->Abs Diff Threshold: " 190 | << static_cast( 191 | static_cast(std::abs(C_ref_val)) * 192 | rel_tol) 193 | << std::endl; 194 | status = false; 195 | return status; 196 | } 197 | } 198 | } 199 | return status; 200 | } 201 | 202 | void print_device_info() 203 | { 204 | int device_id{0}; 205 | cudaGetDevice(&device_id); 206 | cudaDeviceProp device_prop; 207 | cudaGetDeviceProperties(&device_prop, device_id); 208 | std::cout << "Device Name: " << device_prop.name << std::endl; 209 | float const memory_size{static_cast(device_prop.totalGlobalMem) / 210 | (1 << 30)}; 211 | std::cout << "Memory Size: " << memory_size << " GB" << std::endl; 212 | float const peak_bandwidth{ 213 | static_cast(2.0f * device_prop.memoryClockRate * 214 | (device_prop.memoryBusWidth / 8) / 1.0e6)}; 215 | std::cout << "Peak Bandwitdh: " << peak_bandwidth << " GB/s" << std::endl; 216 | std::cout << std::endl; 217 | } 218 | 219 | template 220 | float compute_effective_bandwidth(size_t m, size_t n, size_t k, float latency) 221 | { 222 | return ((m * k + k * n + m * n) * sizeof(T)) / (latency * 1e-3) / 1e9; 223 | } 224 | 225 | float compute_effective_tflops(size_t m, size_t n, size_t k, float latency) 226 | { 227 | return (2.0 * m * k * n) / (latency * 1e-3) / 1e12; 228 | } 229 | 230 | template ::value || 232 | std::is_same::value || 233 | std::is_same::value, 234 | bool>::type = true> 235 | void random_initialize_matrix(T* A, size_t m, size_t n, size_t lda, 236 | unsigned int seed = 0U) 237 | { 238 | std::default_random_engine eng(seed); 239 | // The best way to verify is to use integer values. 
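// Small integers (0 to 5 here) are exactly representable in __half,
// float, and double, so as long as the accumulated dot products stay
// within the type's exact integer range, the custom kernels and cuBLAS
// can match exactly and the all_close check stays robust.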
240 | std::uniform_int_distribution dis(0, 5); 241 | // std::uniform_real_distribution dis(-1.0f, 1.0f); 242 | auto const rand = [&dis, &eng]() { return dis(eng); }; 243 | for (size_t i{0U}; i < m; ++i) 244 | { 245 | for (size_t j{0U}; j < n; ++j) 246 | { 247 | A[i * lda + j] = static_cast(rand()); 248 | } 249 | } 250 | } 251 | 252 | void print_performance_result(size_t m, size_t n, size_t k, float latency) 253 | { 254 | float const effective_bandwidth{ 255 | compute_effective_bandwidth(m, n, k, latency)}; 256 | float const effective_tflops{compute_effective_tflops(m, n, k, latency)}; 257 | 258 | std::cout << "Latency: " << latency << " ms" << std::endl; 259 | std::cout << "Effective Bandwidth: " << effective_bandwidth << " GB/s" 260 | << std::endl; 261 | std::cout << "Effective TFLOPS: " << effective_tflops << " TFLOPS" 262 | << std::endl; 263 | } 264 | 265 | template ::value || 267 | std::is_same::value || 268 | std::is_same::value, 269 | bool>::type = true> 270 | std::pair profile_gemm( 271 | size_t m, size_t n, size_t k, size_t lda, size_t ldb, size_t ldc, 272 | std::function 274 | gemm_kernel_launch_function, 275 | T abs_tol, double rel_tol, size_t num_repeats = 10, size_t num_warmups = 10, 276 | unsigned int seed = 0U) 277 | { 278 | T const alpha{static_cast(1.0)}; 279 | T const beta{static_cast(0.0)}; 280 | 281 | // Create CUDA stream. 282 | cudaStream_t stream; 283 | CHECK_CUDA_ERROR(cudaStreamCreate(&stream)); 284 | 285 | // Allocate memory on host. 286 | T* A_host{nullptr}; 287 | T* B_host{nullptr}; 288 | T* C_host{nullptr}; 289 | T* C_host_ref{nullptr}; 290 | T* C_host_from_device{nullptr}; 291 | CHECK_CUDA_ERROR(cudaMallocHost(&A_host, m * lda * sizeof(T))); 292 | CHECK_CUDA_ERROR(cudaMallocHost(&B_host, k * ldb * sizeof(T))); 293 | CHECK_CUDA_ERROR(cudaMallocHost(&C_host, m * ldc * sizeof(T))); 294 | CHECK_CUDA_ERROR(cudaMallocHost(&C_host_ref, m * ldc * sizeof(T))); 295 | CHECK_CUDA_ERROR(cudaMallocHost(&C_host_from_device, m * ldc * sizeof(T))); 296 | 297 | // Initialize matrix A and B. 298 | random_initialize_matrix(A_host, m, k, lda); 299 | random_initialize_matrix(B_host, k, n, ldb); 300 | random_initialize_matrix(C_host, m, n, ldc); 301 | 302 | // Allocate memory on device. 303 | T* A_device{nullptr}; 304 | T* B_device{nullptr}; 305 | T* C_device{nullptr}; 306 | CHECK_CUDA_ERROR(cudaMalloc(&A_device, m * lda * sizeof(T))); 307 | CHECK_CUDA_ERROR(cudaMalloc(&B_device, k * ldb * sizeof(T))); 308 | CHECK_CUDA_ERROR(cudaMalloc(&C_device, m * ldc * sizeof(T))); 309 | 310 | // Copy matrix A and B from host to device. 311 | CHECK_CUDA_ERROR(cudaMemcpy(A_device, A_host, m * lda * sizeof(T), 312 | cudaMemcpyHostToDevice)); 313 | CHECK_CUDA_ERROR(cudaMemcpy(B_device, B_host, k * ldb * sizeof(T), 314 | cudaMemcpyHostToDevice)); 315 | CHECK_CUDA_ERROR(cudaMemcpy(C_device, C_host, m * ldc * sizeof(T), 316 | cudaMemcpyHostToDevice)); 317 | CHECK_CUDA_ERROR(cudaMemcpy(C_host_ref, C_host, m * ldc * sizeof(T), 318 | cudaMemcpyHostToHost)); 319 | 320 | // Create cuBLAS handle. 321 | cublasHandle_t handle; 322 | CHECK_CUBLASS_ERROR(cublasCreate(&handle)); 323 | CHECK_CUBLASS_ERROR(cublasSetStream(handle, stream)); 324 | 325 | // Compute reference output using cuBLAS. 326 | launch_gemm_cublas(m, n, k, &alpha, A_device, lda, B_device, ldb, &beta, 327 | C_device, ldc, handle); 328 | CHECK_CUDA_ERROR(cudaStreamSynchronize(stream)); 329 | 330 | // Copy matrix C from device to host. 
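// After this copy, C_host_ref holds the cuBLAS output, which serves as the
// ground truth for the all_close check of the custom kernel further below.
// The CPU reference (commented out underneath) is kept only as a fallback: it
// is far slower and, for FP16, never matches the GPU results exactly.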
331 | CHECK_CUDA_ERROR(cudaMemcpy(C_host_ref, C_device, m * ldc * sizeof(T), 332 | cudaMemcpyDeviceToHost)); 333 | 334 | // // Compute reference output using CPU. 335 | // std::cout << "Computing reference output using CPU..." << std::endl; 336 | // launch_gemm_cpu(m, n, k, &alpha, A_host, lda, B_host, ldb, &beta, 337 | // C_host_ref, ldc); 338 | // std::cout << "Done." << std::endl; 339 | 340 | // Launch CUDA GEMM. 341 | CHECK_CUDA_ERROR(cudaMemcpy(C_device, C_host, m * ldc * sizeof(T), 342 | cudaMemcpyHostToDevice)); 343 | // Verify the correctness of CUDA GEMM. 344 | gemm_kernel_launch_function(m, n, k, &alpha, A_device, lda, B_device, ldb, 345 | &beta, C_device, ldc, stream); 346 | 347 | // launch_gemm_cublas(m, n, k, &alpha, A_device, lda, B_device, ldb, 348 | // &beta, 349 | // C_device, ldc, handle); 350 | 351 | CHECK_CUDA_ERROR(cudaStreamSynchronize(stream)); 352 | CHECK_CUDA_ERROR(cudaMemcpy(C_host_from_device, C_device, 353 | m * ldc * sizeof(T), cudaMemcpyDeviceToHost)); 354 | assert(all_close(C_host_from_device, C_host_ref, m, n, ldc, abs_tol, 355 | rel_tol)); 356 | 357 | // Launch cuBLAS GEMM. 358 | float const latency_cublas{measure_performance( 359 | [&](cudaStream_t stream) 360 | { 361 | launch_gemm_cublas(m, n, k, &alpha, A_device, lda, B_device, ldb, 362 | &beta, C_device, ldc, handle); 363 | return; 364 | }, 365 | stream, num_repeats, num_warmups)}; 366 | 367 | float const latency_cuda_gemm{measure_performance( 368 | [&](cudaStream_t stream) 369 | { 370 | gemm_kernel_launch_function(m, n, k, &alpha, A_device, lda, 371 | B_device, ldb, &beta, C_device, ldc, 372 | stream); 373 | return; 374 | }, 375 | stream, num_repeats, num_warmups)}; 376 | 377 | // Release resources. 378 | CHECK_CUDA_ERROR(cudaFree(A_device)); 379 | CHECK_CUDA_ERROR(cudaFree(B_device)); 380 | CHECK_CUDA_ERROR(cudaFree(C_device)); 381 | CHECK_CUDA_ERROR(cudaFreeHost(A_host)); 382 | CHECK_CUDA_ERROR(cudaFreeHost(B_host)); 383 | CHECK_CUDA_ERROR(cudaFreeHost(C_host)); 384 | CHECK_CUDA_ERROR(cudaFreeHost(C_host_ref)); 385 | CHECK_CUDA_ERROR(cudaFreeHost(C_host_from_device)); 386 | CHECK_CUBLASS_ERROR(cublasDestroy(handle)); 387 | CHECK_CUDA_ERROR(cudaStreamDestroy(stream)); 388 | 389 | std::cout << "cuBLAS GEMM Kernel Performance" << std::endl; 390 | print_performance_result(m, n, k, latency_cublas); 391 | std::cout << "Custom GEMM Kernel Performance" << std::endl; 392 | print_performance_result(m, n, k, latency_cuda_gemm); 393 | std::cout << "Custom GEMM VS cuBLAS GEMM Performance: " 394 | << latency_cublas / latency_cuda_gemm * 100.0f << "%" 395 | << std::endl; 396 | 397 | return std::pair{latency_cublas, latency_cuda_gemm}; 398 | } 399 | 400 | #endif // PROFILE_UTILS_CUH -------------------------------------------------------------------------------- /src/06_2d_block_tiling_2d_warp_tiling_2d_thread_tiling_matrix_transpose.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "cuda_gemm.hpp" 4 | #include "cuda_gemm_utils.cuh" 5 | #include "cuda_gemm_utils.hpp" 6 | 7 | template 9 | __device__ void load_data_from_shared_memory_to_register_file( 10 | T const thread_block_tile[BLOCK_TILE_SIZE], 11 | T register_values[NUM_THREAD_TILES_PER_WARP][THREAD_TILE_SIZE], 12 | size_t warp_idx, size_t thread_idx) 13 | { 14 | static_assert(BLOCK_TILE_SIZE % THREAD_TILE_SIZE == 0U); 15 | #pragma unroll 16 | for (size_t thread_tile_repeat_idx{0U}; 17 | thread_tile_repeat_idx < NUM_THREAD_TILES_PER_WARP; 18 | ++thread_tile_repeat_idx) 19 | { 20 | size_t const 
thread_block_tile_idx{ 21 | warp_idx * WARP_TILE_SIZE + 22 | thread_tile_repeat_idx * 23 | (WARP_TILE_SIZE / NUM_THREAD_TILES_PER_WARP) + 24 | thread_idx * THREAD_TILE_SIZE}; 25 | #pragma unroll 26 | for (size_t thread_tile_idx{0U}; thread_tile_idx < THREAD_TILE_SIZE; 27 | ++thread_tile_idx) 28 | { 29 | register_values[thread_tile_repeat_idx][thread_tile_idx] = 30 | thread_block_tile[thread_block_tile_idx + thread_tile_idx]; 31 | } 32 | } 33 | } 34 | 35 | template 38 | __device__ void compute_thread_tile_results( 39 | T const A_vals[NUM_THREAD_TILES_PER_WARP_Y][THREAD_TILE_SIZE_Y], 40 | T const B_vals[NUM_THREAD_TILES_PER_WARP_X][THREAD_TILE_SIZE_X], 41 | T C_thread_results[NUM_THREAD_TILES_PER_WARP_Y][NUM_THREAD_TILES_PER_WARP_X] 42 | [THREAD_TILE_SIZE_Y][THREAD_TILE_SIZE_X]) 43 | { 44 | // Compute NUM_THREAD_TILES_PER_WARP_Y * NUM_THREAD_TILES_PER_WARP_X outer 45 | // products. 46 | #pragma unroll 47 | for (size_t thread_tile_repeat_row_idx{0U}; 48 | thread_tile_repeat_row_idx < NUM_THREAD_TILES_PER_WARP_Y; 49 | ++thread_tile_repeat_row_idx) 50 | { 51 | #pragma unroll 52 | for (size_t thread_tile_repeat_col_idx{0U}; 53 | thread_tile_repeat_col_idx < NUM_THREAD_TILES_PER_WARP_X; 54 | ++thread_tile_repeat_col_idx) 55 | { 56 | #pragma unroll 57 | for (size_t thread_tile_y_idx{0U}; 58 | thread_tile_y_idx < THREAD_TILE_SIZE_Y; ++thread_tile_y_idx) 59 | { 60 | #pragma unroll 61 | for (size_t thread_tile_x_idx{0U}; 62 | thread_tile_x_idx < THREAD_TILE_SIZE_X; 63 | ++thread_tile_x_idx) 64 | { 65 | C_thread_results[thread_tile_repeat_row_idx] 66 | [thread_tile_repeat_col_idx] 67 | [thread_tile_y_idx][thread_tile_x_idx] += 68 | A_vals[thread_tile_repeat_row_idx][thread_tile_y_idx] * 69 | B_vals[thread_tile_repeat_col_idx][thread_tile_x_idx]; 70 | } 71 | } 72 | } 73 | } 74 | } 75 | 76 | template 81 | __device__ void write_results_from_register_file_to_global_memory( 82 | T const C_thread_results[NUM_THREAD_TILES_PER_WARP_Y] 83 | [NUM_THREAD_TILES_PER_WARP_X][THREAD_TILE_SIZE_Y] 84 | [THREAD_TILE_SIZE_X], 85 | T alpha, T beta, T* C, size_t ldc, size_t m, size_t n, size_t block_row_idx, 86 | size_t block_col_idx, size_t warp_row_idx, size_t warp_col_idx, 87 | size_t thread_row_idx_in_warp, size_t thread_col_idx_in_warp) 88 | { 89 | // Write the results to DRAM. 
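// Each output coordinate below is assembled from five nested offsets:
//   C_row_idx = block-tile origin (block_row_idx * BLOCK_TILE_SIZE_Y)
//             + warp-tile origin (warp_row_idx * WARP_TILE_SIZE_Y)
//             + repeat offset within the warp tile
//             + thread offset (thread_row_idx_in_warp * THREAD_TILE_SIZE_Y)
//             + element offset within the thread tile (thread_tile_y_idx)
// and analogously for C_col_idx. The bounds check against m and n keeps the
// kernel correct when the problem size is not a multiple of the tile sizes.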
90 | #pragma unroll 91 | for (size_t thread_tile_repeat_row_idx{0U}; 92 | thread_tile_repeat_row_idx < NUM_THREAD_TILES_PER_WARP_Y; 93 | ++thread_tile_repeat_row_idx) 94 | { 95 | #pragma unroll 96 | for (size_t thread_tile_repeat_col_idx{0U}; 97 | thread_tile_repeat_col_idx < NUM_THREAD_TILES_PER_WARP_X; 98 | ++thread_tile_repeat_col_idx) 99 | { 100 | #pragma unroll 101 | for (size_t thread_tile_y_idx{0U}; 102 | thread_tile_y_idx < THREAD_TILE_SIZE_Y; ++thread_tile_y_idx) 103 | { 104 | #pragma unroll 105 | for (size_t thread_tile_x_idx{0U}; 106 | thread_tile_x_idx < THREAD_TILE_SIZE_X; 107 | ++thread_tile_x_idx) 108 | { 109 | size_t const C_row_idx{ 110 | block_row_idx * BLOCK_TILE_SIZE_Y + 111 | warp_row_idx * WARP_TILE_SIZE_Y + 112 | thread_tile_repeat_row_idx * 113 | (WARP_TILE_SIZE_Y / NUM_THREAD_TILES_PER_WARP_Y) + 114 | thread_row_idx_in_warp * THREAD_TILE_SIZE_Y + 115 | thread_tile_y_idx}; 116 | size_t const C_col_idx{ 117 | block_col_idx * BLOCK_TILE_SIZE_X + 118 | warp_col_idx * WARP_TILE_SIZE_X + 119 | thread_tile_repeat_col_idx * 120 | (WARP_TILE_SIZE_X / NUM_THREAD_TILES_PER_WARP_X) + 121 | thread_col_idx_in_warp * THREAD_TILE_SIZE_X + 122 | thread_tile_x_idx}; 123 | if (C_row_idx < m && C_col_idx < n) 124 | { 125 | C[C_row_idx * ldc + C_col_idx] = 126 | alpha * C_thread_results[thread_tile_repeat_row_idx] 127 | [thread_tile_repeat_col_idx] 128 | [thread_tile_y_idx] 129 | [thread_tile_x_idx] + 130 | beta * C[C_row_idx * ldc + C_col_idx]; 131 | } 132 | } 133 | } 134 | } 135 | } 136 | } 137 | 138 | // GEMM kernel v06. 139 | // Each thread in the block processes THREAD_TILE_SIZE_Y * 140 | // THREAD_TILE_SIZE_X output values. Number of threads BLOCK_TILE_SIZE_Y * 141 | // BLOCK_TILE_SIZE_X / (THREAD_TILE_SIZE_Y * THREAD_TILE_SIZE_X) 142 | template 147 | __global__ void gemm_v06(size_t m, size_t n, size_t k, T alpha, T const* A, 148 | size_t lda, T const* B, size_t ldb, T beta, T* C, 149 | size_t ldc) 150 | { 151 | static_assert(NUM_THREADS_PER_WARP_X * NUM_THREADS_PER_WARP_Y == 32U); 152 | constexpr size_t NUM_WARPS_X{BLOCK_TILE_SIZE_X / WARP_TILE_SIZE_X}; 153 | static_assert(BLOCK_TILE_SIZE_X % WARP_TILE_SIZE_X == 0U); 154 | constexpr size_t NUM_WARPS_Y{BLOCK_TILE_SIZE_Y / WARP_TILE_SIZE_Y}; 155 | static_assert(BLOCK_TILE_SIZE_Y % WARP_TILE_SIZE_Y == 0U); 156 | constexpr unsigned int NUM_THREAD_TILES_PER_WARP_X{ 157 | WARP_TILE_SIZE_X / (THREAD_TILE_SIZE_X * NUM_THREADS_PER_WARP_X)}; 158 | constexpr unsigned int NUM_THREAD_TILES_PER_WARP_Y{ 159 | WARP_TILE_SIZE_Y / (THREAD_TILE_SIZE_Y * NUM_THREADS_PER_WARP_Y)}; 160 | static_assert( 161 | WARP_TILE_SIZE_X % (THREAD_TILE_SIZE_X * NUM_THREADS_PER_WARP_X) == 0U); 162 | static_assert( 163 | WARP_TILE_SIZE_Y % (THREAD_TILE_SIZE_Y * NUM_THREADS_PER_WARP_Y) == 0U); 164 | 165 | constexpr unsigned int NUM_THREADS_X{NUM_WARPS_X * NUM_THREADS_PER_WARP_X}; 166 | constexpr unsigned int NUM_THREADS_Y{NUM_WARPS_Y * NUM_THREADS_PER_WARP_Y}; 167 | // Avoid using blockDim.x * blockDim.y as the number of threads per block. 168 | // Because it is a runtime constant and the compiler cannot optimize the 169 | // loop unrolling based on that. 170 | // Use a compile time constant instead. 171 | constexpr size_t NUM_THREADS{NUM_THREADS_X * NUM_THREADS_Y}; 172 | 173 | // Cache a tile of A and B in shared memory for data reuse. 174 | __shared__ T 175 | A_thread_block_tile_transposed[BLOCK_TILE_SIZE_K][BLOCK_TILE_SIZE_Y]; 176 | __shared__ T B_thread_block_tile[BLOCK_TILE_SIZE_K][BLOCK_TILE_SIZE_X]; 177 | 178 | // A_vals is cached in the register. 
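// Per thread, the data kept in registers for the inner loop is small:
// A_vals holds NUM_THREAD_TILES_PER_WARP_Y * THREAD_TILE_SIZE_Y elements,
// B_vals holds NUM_THREAD_TILES_PER_WARP_X * THREAD_TILE_SIZE_X elements, and
// C_thread_results (declared further below) holds their product. With the
// configuration chosen in launch_gemm_kernel_v06 (warp tile 32x64, thread tile
// 8x8, 4x8 threads per warp) that is 8 + 8 + 64 values of type T per thread.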
179 | T A_vals[NUM_THREAD_TILES_PER_WARP_Y][THREAD_TILE_SIZE_Y] = { 180 | static_cast(0)}; 181 | // B_vals is cached in the register. 182 | T B_vals[NUM_THREAD_TILES_PER_WARP_X][THREAD_TILE_SIZE_X] = { 183 | static_cast(0)}; 184 | 185 | size_t const thread_linear_idx{threadIdx.y * blockDim.x + threadIdx.x}; 186 | size_t const warp_linear_idx{thread_linear_idx / 32U}; 187 | size_t const warp_row_idx{warp_linear_idx / NUM_WARPS_X}; 188 | size_t const warp_col_idx{warp_linear_idx % NUM_WARPS_X}; 189 | size_t const thread_linear_idx_in_warp{thread_linear_idx % 32U}; 190 | size_t const thread_linear_row_idx_in_warp{thread_linear_idx_in_warp / 191 | NUM_THREADS_PER_WARP_X}; 192 | size_t const thread_linear_col_idx_in_warp{thread_linear_idx_in_warp % 193 | NUM_THREADS_PER_WARP_X}; 194 | 195 | // Number of outer loops to perform the sum of inner products. 196 | // C_thread_block_tile = 197 | // \sigma_{thread_block_tile_idx=0}^{num_thread_block_tiles-1} A[:, 198 | // thread_block_tile_idx:BLOCK_TILE_SIZE_K] * 199 | // B[thread_block_tile_idx:BLOCK_TILE_SIZE_K, :] 200 | size_t const num_thread_block_tiles{(k + BLOCK_TILE_SIZE_K - 1) / 201 | BLOCK_TILE_SIZE_K}; 202 | // Each thread in the block processes NUM_THREAD_TILES_PER_WARP_Y * 203 | // NUM_THREAD_TILES_PER_WARP_X * THREAD_TILE_SIZE_Y * 204 | // THREAD_TILE_SIZE_X output values. 205 | T C_thread_results[NUM_THREAD_TILES_PER_WARP_Y][NUM_THREAD_TILES_PER_WARP_X] 206 | [THREAD_TILE_SIZE_Y][THREAD_TILE_SIZE_X] = { 207 | static_cast(0)}; 208 | 209 | for (size_t thread_block_tile_idx{0U}; 210 | thread_block_tile_idx < num_thread_block_tiles; 211 | ++thread_block_tile_idx) 212 | { 213 | load_data_from_global_memory_to_shared_memory_transposed< 214 | T, BLOCK_TILE_SIZE_X, BLOCK_TILE_SIZE_Y, BLOCK_TILE_SIZE_K, 215 | NUM_THREADS>(A, lda, B, ldb, A_thread_block_tile_transposed, 216 | B_thread_block_tile, thread_block_tile_idx, 217 | thread_linear_idx, m, n, k); 218 | __syncthreads(); 219 | 220 | // Perform A[:, thread_block_tile_idx:BLOCK_TILE_SIZE_K] * 221 | // B[thread_block_tile_idx:BLOCK_TILE_SIZE_K, :] where A[:, 222 | // thread_block_tile_idx:BLOCK_TILE_SIZE_K] and 223 | // B[thread_block_tile_idx:BLOCK_TILE_SIZE_K, :] are cached in the 224 | // shared memory as A_thread_block_tile and B_thread_block_tile, 225 | // respectively. This inner product is further decomposed to 226 | // BLOCK_TILE_SIZE_K outer products. A_thread_block_tile * 227 | // B_thread_block_tile = \sigma_{k_i=0}^{BLOCK_TILE_SIZE_K-1} 228 | // A_thread_block_tile[:, k_i] @ B_thread_block_tile[k_i, :] Note that 229 | // both A_thread_block_tile and B_thread_block_tile can be cached in the 230 | // register. 231 | #pragma unroll 232 | for (size_t k_i{0U}; k_i < BLOCK_TILE_SIZE_K; ++k_i) 233 | { 234 | // Load data from shared memory to register file for A. 235 | load_data_from_shared_memory_to_register_file< 236 | T, BLOCK_TILE_SIZE_Y, WARP_TILE_SIZE_Y, NUM_THREADS_PER_WARP_Y, 237 | THREAD_TILE_SIZE_Y>(A_thread_block_tile_transposed[k_i], A_vals, 238 | warp_row_idx, 239 | thread_linear_row_idx_in_warp); 240 | // Load data from shared memory to register file for B. 241 | load_data_from_shared_memory_to_register_file< 242 | T, BLOCK_TILE_SIZE_X, WARP_TILE_SIZE_X, NUM_THREADS_PER_WARP_X, 243 | THREAD_TILE_SIZE_X>(B_thread_block_tile[k_i], B_vals, 244 | warp_col_idx, 245 | thread_linear_col_idx_in_warp); 246 | // Compute NUM_THREAD_TILES_PER_WARP_Y * NUM_THREAD_TILES_PER_WARP_X 247 | // outer products. 
248 | compute_thread_tile_results( 251 | A_vals, B_vals, C_thread_results); 252 | } 253 | __syncthreads(); 254 | } 255 | 256 | // Write the results to DRAM. 257 | write_results_from_register_file_to_global_memory< 258 | T, BLOCK_TILE_SIZE_X, BLOCK_TILE_SIZE_Y, WARP_TILE_SIZE_X, 259 | WARP_TILE_SIZE_Y, THREAD_TILE_SIZE_X, THREAD_TILE_SIZE_Y, 260 | NUM_THREAD_TILES_PER_WARP_X, NUM_THREAD_TILES_PER_WARP_Y>( 261 | C_thread_results, alpha, beta, C, ldc, m, n, blockIdx.y, blockIdx.x, 262 | warp_row_idx, warp_col_idx, thread_linear_row_idx_in_warp, 263 | thread_linear_col_idx_in_warp); 264 | } 265 | 266 | template 267 | void launch_gemm_kernel_v06(size_t m, size_t n, size_t k, T const* alpha, 268 | T const* A, size_t lda, T const* B, size_t ldb, 269 | T const* beta, T* C, size_t ldc, 270 | cudaStream_t stream) 271 | { 272 | // Feel free to play with the block tile sizes. 273 | // The algorithm correctness should always be guaranteed. 274 | constexpr unsigned int BLOCK_TILE_SIZE_X{128U}; 275 | constexpr unsigned int BLOCK_TILE_SIZE_Y{128U}; 276 | constexpr unsigned int BLOCK_TILE_SIZE_K{16U}; 277 | 278 | constexpr unsigned int WARP_TILE_SIZE_X{32U}; 279 | constexpr unsigned int WARP_TILE_SIZE_Y{64U}; 280 | constexpr unsigned int NUM_WARPS_X{BLOCK_TILE_SIZE_X / WARP_TILE_SIZE_X}; 281 | constexpr unsigned int NUM_WARPS_Y{BLOCK_TILE_SIZE_Y / WARP_TILE_SIZE_Y}; 282 | static_assert(BLOCK_TILE_SIZE_X % WARP_TILE_SIZE_X == 0U); 283 | static_assert(BLOCK_TILE_SIZE_Y % WARP_TILE_SIZE_Y == 0U); 284 | 285 | constexpr unsigned int THREAD_TILE_SIZE_X{8U}; 286 | constexpr unsigned int THREAD_TILE_SIZE_Y{8U}; 287 | 288 | constexpr unsigned int NUM_THREADS_PER_WARP_X{4U}; 289 | constexpr unsigned int NUM_THREADS_PER_WARP_Y{8U}; 290 | static_assert(NUM_THREADS_PER_WARP_X * NUM_THREADS_PER_WARP_Y == 32U); 291 | static_assert( 292 | WARP_TILE_SIZE_X % (THREAD_TILE_SIZE_X * NUM_THREADS_PER_WARP_X) == 0U); 293 | static_assert( 294 | WARP_TILE_SIZE_Y % (THREAD_TILE_SIZE_Y * NUM_THREADS_PER_WARP_Y) == 0U); 295 | 296 | constexpr unsigned int NUM_THREADS_X{NUM_WARPS_X * NUM_THREADS_PER_WARP_X}; 297 | constexpr unsigned int NUM_THREADS_Y{NUM_WARPS_Y * NUM_THREADS_PER_WARP_Y}; 298 | 299 | constexpr unsigned int NUM_THREADS_PER_BLOCK{NUM_THREADS_X * NUM_THREADS_Y}; 300 | 301 | dim3 const block_dim{NUM_THREADS_PER_BLOCK, 1U, 1U}; 302 | dim3 const grid_dim{ 303 | (static_cast(n) + BLOCK_TILE_SIZE_X - 1U) / 304 | BLOCK_TILE_SIZE_X, 305 | (static_cast(m) + BLOCK_TILE_SIZE_Y - 1U) / 306 | BLOCK_TILE_SIZE_Y, 307 | 1U}; 308 | gemm_v06 311 | <<>>(m, n, k, *alpha, A, lda, B, ldb, 312 | *beta, C, ldc); 313 | CHECK_LAST_CUDA_ERROR(); 314 | } 315 | 316 | // Explicit instantiation. 
317 | template void launch_gemm_kernel_v06(size_t m, size_t n, size_t k, 318 | float const* alpha, float const* A, 319 | size_t lda, float const* B, 320 | size_t ldb, float const* beta, 321 | float* C, size_t ldc, 322 | cudaStream_t stream); 323 | template void launch_gemm_kernel_v06(size_t m, size_t n, size_t k, 324 | double const* alpha, 325 | double const* A, size_t lda, 326 | double const* B, size_t ldb, 327 | double const* beta, double* C, 328 | size_t ldc, cudaStream_t stream); 329 | template void launch_gemm_kernel_v06<__half>(size_t m, size_t n, size_t k, 330 | __half const* alpha, 331 | __half const* A, size_t lda, 332 | __half const* B, size_t ldb, 333 | __half const* beta, __half* C, 334 | size_t ldc, cudaStream_t stream); -------------------------------------------------------------------------------- /src/06_2d_block_tiling_2d_warp_tiling_2d_thread_tiling_matrix_transpose_vectorized_memory_access.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "cuda_gemm.hpp" 4 | #include "cuda_gemm_utils.cuh" 5 | #include "cuda_gemm_utils.hpp" 6 | 7 | template 9 | __device__ void load_data_from_shared_memory_to_register_file_vectorized( 10 | T const thread_block_tile[BLOCK_TILE_SIZE], 11 | T register_values[NUM_THREAD_TILES_PER_WARP][THREAD_TILE_SIZE], 12 | size_t warp_idx, size_t thread_idx) 13 | { 14 | static_assert(BLOCK_TILE_SIZE % THREAD_TILE_SIZE == 0U); 15 | constexpr size_t NUM_VECTOR_UNITS{sizeof(int4) / sizeof(T)}; 16 | static_assert(sizeof(int4) % sizeof(T) == 0U); 17 | constexpr size_t VECTORIZED_THREAD_TILE_SIZE{THREAD_TILE_SIZE / 18 | NUM_VECTOR_UNITS}; 19 | static_assert(THREAD_TILE_SIZE % NUM_VECTOR_UNITS == 0U); 20 | 21 | #pragma unroll 22 | for (size_t thread_tile_repeat_row_idx{0U}; 23 | thread_tile_repeat_row_idx < NUM_THREAD_TILES_PER_WARP; 24 | ++thread_tile_repeat_row_idx) 25 | { 26 | size_t const thread_block_tile_row_idx{ 27 | warp_idx * WARP_TILE_SIZE + 28 | thread_tile_repeat_row_idx * 29 | (WARP_TILE_SIZE / NUM_THREAD_TILES_PER_WARP) + 30 | thread_idx * THREAD_TILE_SIZE}; 31 | #pragma unroll 32 | for (size_t thread_tile_vector_idx{0U}; 33 | thread_tile_vector_idx < VECTORIZED_THREAD_TILE_SIZE; 34 | ++thread_tile_vector_idx) 35 | { 36 | *reinterpret_cast( 37 | ®ister_values[thread_tile_repeat_row_idx] 38 | [thread_tile_vector_idx * NUM_VECTOR_UNITS]) = 39 | *reinterpret_cast( 40 | &thread_block_tile[thread_block_tile_row_idx + 41 | thread_tile_vector_idx * 42 | NUM_VECTOR_UNITS]); 43 | } 44 | } 45 | } 46 | 47 | template 50 | __device__ void compute_thread_tile_results( 51 | T const A_vals[NUM_THREAD_TILES_PER_WARP_Y][THREAD_TILE_SIZE_Y], 52 | T const B_vals[NUM_THREAD_TILES_PER_WARP_X][THREAD_TILE_SIZE_X], 53 | T C_thread_results[NUM_THREAD_TILES_PER_WARP_Y][NUM_THREAD_TILES_PER_WARP_X] 54 | [THREAD_TILE_SIZE_Y][THREAD_TILE_SIZE_X]) 55 | { 56 | // Compute NUM_THREAD_TILES_PER_WARP_Y * NUM_THREAD_TILES_PER_WARP_X outer 57 | // products. 
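// The four fully unrolled loops below expand into straight-line FMA code: for
// every (y, x) pair, C_thread_results[...][y][x] += A_vals[...][y] *
// B_vals[...][x]. Each value read from shared memory is therefore reused
// THREAD_TILE_SIZE_X times (for A) or THREAD_TILE_SIZE_Y times (for B) from
// registers, which is what shifts the kernel from bandwidth-bound toward
// compute-bound.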
58 | #pragma unroll 59 | for (size_t thread_tile_repeat_row_idx{0U}; 60 | thread_tile_repeat_row_idx < NUM_THREAD_TILES_PER_WARP_Y; 61 | ++thread_tile_repeat_row_idx) 62 | { 63 | #pragma unroll 64 | for (size_t thread_tile_repeat_col_idx{0U}; 65 | thread_tile_repeat_col_idx < NUM_THREAD_TILES_PER_WARP_X; 66 | ++thread_tile_repeat_col_idx) 67 | { 68 | #pragma unroll 69 | for (size_t thread_tile_y_idx{0U}; 70 | thread_tile_y_idx < THREAD_TILE_SIZE_Y; ++thread_tile_y_idx) 71 | { 72 | #pragma unroll 73 | for (size_t thread_tile_x_idx{0U}; 74 | thread_tile_x_idx < THREAD_TILE_SIZE_X; 75 | ++thread_tile_x_idx) 76 | { 77 | C_thread_results[thread_tile_repeat_row_idx] 78 | [thread_tile_repeat_col_idx] 79 | [thread_tile_y_idx][thread_tile_x_idx] += 80 | A_vals[thread_tile_repeat_row_idx][thread_tile_y_idx] * 81 | B_vals[thread_tile_repeat_col_idx][thread_tile_x_idx]; 82 | } 83 | } 84 | } 85 | } 86 | } 87 | 88 | template 93 | __device__ void write_results_from_register_file_to_global_memory_vectorized( 94 | T const C_thread_results[NUM_THREAD_TILES_PER_WARP_Y] 95 | [NUM_THREAD_TILES_PER_WARP_X][THREAD_TILE_SIZE_Y] 96 | [THREAD_TILE_SIZE_X], 97 | T alpha, T beta, T* C, size_t ldc, size_t m, size_t n, size_t block_row_idx, 98 | size_t block_col_idx, size_t warp_row_idx, size_t warp_col_idx, 99 | size_t thread_row_idx_in_warp, size_t thread_col_idx_in_warp) 100 | { 101 | constexpr size_t NUM_VECTOR_UNITS{sizeof(int4) / sizeof(T)}; 102 | static_assert(sizeof(int4) % sizeof(T) == 0U); 103 | static_assert(BLOCK_TILE_SIZE_X % NUM_VECTOR_UNITS == 0U); 104 | constexpr size_t VECTORIZED_THREAD_TILE_SIZE_X{THREAD_TILE_SIZE_X / 105 | NUM_VECTOR_UNITS}; 106 | static_assert(THREAD_TILE_SIZE_X % NUM_VECTOR_UNITS == 0U); 107 | 108 | // Write the results to DRAM. 109 | #pragma unroll 110 | for (size_t thread_tile_repeat_row_idx{0U}; 111 | thread_tile_repeat_row_idx < NUM_THREAD_TILES_PER_WARP_Y; 112 | ++thread_tile_repeat_row_idx) 113 | { 114 | #pragma unroll 115 | for (size_t thread_tile_repeat_col_idx{0U}; 116 | thread_tile_repeat_col_idx < NUM_THREAD_TILES_PER_WARP_X; 117 | ++thread_tile_repeat_col_idx) 118 | { 119 | #pragma unroll 120 | for (size_t thread_tile_y_idx{0U}; 121 | thread_tile_y_idx < THREAD_TILE_SIZE_Y; ++thread_tile_y_idx) 122 | { 123 | #pragma unroll 124 | for (size_t thread_tile_x_vector_idx{0U}; 125 | thread_tile_x_vector_idx < VECTORIZED_THREAD_TILE_SIZE_X; 126 | ++thread_tile_x_vector_idx) 127 | { 128 | size_t const C_row_idx{ 129 | blockIdx.y * BLOCK_TILE_SIZE_Y + 130 | warp_row_idx * WARP_TILE_SIZE_Y + 131 | thread_tile_repeat_row_idx * 132 | (WARP_TILE_SIZE_Y / NUM_THREAD_TILES_PER_WARP_Y) + 133 | thread_row_idx_in_warp * THREAD_TILE_SIZE_Y + 134 | thread_tile_y_idx}; 135 | size_t const C_col_idx{ 136 | blockIdx.x * BLOCK_TILE_SIZE_X + 137 | warp_col_idx * WARP_TILE_SIZE_X + 138 | thread_tile_repeat_col_idx * 139 | (WARP_TILE_SIZE_X / NUM_THREAD_TILES_PER_WARP_X) + 140 | thread_col_idx_in_warp * THREAD_TILE_SIZE_X + 141 | thread_tile_x_vector_idx * NUM_VECTOR_UNITS}; 142 | 143 | if (C_row_idx < m && C_col_idx < n) 144 | { 145 | int4 C_vals{*reinterpret_cast( 146 | &C[C_row_idx * ldc + C_col_idx])}; 147 | #pragma unroll 148 | for (size_t i{0U}; i < NUM_VECTOR_UNITS; ++i) 149 | { 150 | reinterpret_cast(&C_vals)[i] = 151 | alpha * 152 | C_thread_results[thread_tile_repeat_row_idx] 153 | [thread_tile_repeat_col_idx] 154 | [thread_tile_y_idx] 155 | [thread_tile_x_vector_idx * 156 | NUM_VECTOR_UNITS + 157 | i] + 158 | beta * reinterpret_cast(&C_vals)[i]; 159 | } 160 | *reinterpret_cast( 161 | 
&C[C_row_idx * ldc + C_col_idx]) = C_vals; 162 | } 163 | } 164 | } 165 | } 166 | } 167 | } 168 | 169 | // GEMM kernel v06. 170 | // Each thread in the block processes THREAD_TILE_SIZE_Y * 171 | // THREAD_TILE_SIZE_X output values. Number of threads BLOCK_TILE_SIZE_Y * 172 | // BLOCK_TILE_SIZE_X / (THREAD_TILE_SIZE_Y * THREAD_TILE_SIZE_X) 173 | template 178 | __global__ void gemm_v06_vectorized(size_t m, size_t n, size_t k, T alpha, 179 | T const* A, size_t lda, T const* B, 180 | size_t ldb, T beta, T* C, size_t ldc) 181 | { 182 | static_assert(NUM_THREADS_PER_WARP_X * NUM_THREADS_PER_WARP_Y == 32U); 183 | constexpr size_t NUM_WARPS_X{BLOCK_TILE_SIZE_X / WARP_TILE_SIZE_X}; 184 | static_assert(BLOCK_TILE_SIZE_X % WARP_TILE_SIZE_X == 0U); 185 | constexpr size_t NUM_WARPS_Y{BLOCK_TILE_SIZE_Y / WARP_TILE_SIZE_Y}; 186 | static_assert(BLOCK_TILE_SIZE_Y % WARP_TILE_SIZE_Y == 0U); 187 | constexpr unsigned int NUM_THREAD_TILES_PER_WARP_X{ 188 | WARP_TILE_SIZE_X / (THREAD_TILE_SIZE_X * NUM_THREADS_PER_WARP_X)}; 189 | constexpr unsigned int NUM_THREAD_TILES_PER_WARP_Y{ 190 | WARP_TILE_SIZE_Y / (THREAD_TILE_SIZE_Y * NUM_THREADS_PER_WARP_Y)}; 191 | static_assert( 192 | WARP_TILE_SIZE_X % (THREAD_TILE_SIZE_X * NUM_THREADS_PER_WARP_X) == 0U); 193 | static_assert( 194 | WARP_TILE_SIZE_Y % (THREAD_TILE_SIZE_Y * NUM_THREADS_PER_WARP_Y) == 0U); 195 | 196 | constexpr unsigned int NUM_THREADS_X{NUM_WARPS_X * NUM_THREADS_PER_WARP_X}; 197 | constexpr unsigned int NUM_THREADS_Y{NUM_WARPS_Y * NUM_THREADS_PER_WARP_Y}; 198 | // Avoid using blockDim.x * blockDim.y as the number of threads per block. 199 | // Because it is a runtime constant and the compiler cannot optimize the 200 | // loop unrolling based on that. 201 | // Use a compile time constant instead. 202 | constexpr size_t NUM_THREADS{NUM_THREADS_X * NUM_THREADS_Y}; 203 | 204 | // Cache a tile of A and B in shared memory for data reuse. 205 | __shared__ T 206 | A_thread_block_tile_transposed[BLOCK_TILE_SIZE_K][BLOCK_TILE_SIZE_Y]; 207 | __shared__ T B_thread_block_tile[BLOCK_TILE_SIZE_K][BLOCK_TILE_SIZE_X]; 208 | 209 | // A_vals is cached in the register. 210 | T A_vals[NUM_THREAD_TILES_PER_WARP_Y][THREAD_TILE_SIZE_Y] = { 211 | static_cast(0)}; 212 | // B_vals is cached in the register. 213 | T B_vals[NUM_THREAD_TILES_PER_WARP_X][THREAD_TILE_SIZE_X] = { 214 | static_cast(0)}; 215 | 216 | size_t const thread_linear_idx{threadIdx.y * blockDim.x + threadIdx.x}; 217 | size_t const warp_linear_idx{thread_linear_idx / 32U}; 218 | size_t const warp_row_idx{warp_linear_idx / NUM_WARPS_X}; 219 | size_t const warp_col_idx{warp_linear_idx % NUM_WARPS_X}; 220 | size_t const thread_linear_idx_in_warp{thread_linear_idx % 32U}; 221 | size_t const thread_linear_row_idx_in_warp{thread_linear_idx_in_warp / 222 | NUM_THREADS_PER_WARP_X}; 223 | size_t const thread_linear_col_idx_in_warp{thread_linear_idx_in_warp % 224 | NUM_THREADS_PER_WARP_X}; 225 | 226 | // Number of outer loops to perform the sum of inner products. 227 | // C_thread_block_tile = 228 | // \sigma_{thread_block_tile_idx=0}^{num_thread_block_tiles-1} A[:, 229 | // thread_block_tile_idx:BLOCK_TILE_SIZE_K] * 230 | // B[thread_block_tile_idx:BLOCK_TILE_SIZE_K, :] 231 | size_t const num_thread_block_tiles{(k + BLOCK_TILE_SIZE_K - 1) / 232 | BLOCK_TILE_SIZE_K}; 233 | // Each thread in the block processes NUM_THREAD_TILES_PER_WARP_Y * 234 | // NUM_THREAD_TILES_PER_WARP_X * THREAD_TILE_SIZE_Y * 235 | // THREAD_TILE_SIZE_X output values. 
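// The accumulator layout is identical to the non-vectorized gemm_v06 kernel:
// vectorization only changes how data moves (128-bit int4 transactions between
// global memory, shared memory, and registers), not the arithmetic performed
// on it.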
236 | T C_thread_results[NUM_THREAD_TILES_PER_WARP_Y][NUM_THREAD_TILES_PER_WARP_X] 237 | [THREAD_TILE_SIZE_Y][THREAD_TILE_SIZE_X] = { 238 | static_cast(0)}; 239 | 240 | for (size_t thread_block_tile_idx{0U}; 241 | thread_block_tile_idx < num_thread_block_tiles; 242 | ++thread_block_tile_idx) 243 | { 244 | load_data_from_global_memory_to_shared_memory_transposed_vectorized< 245 | T, BLOCK_TILE_SIZE_X, BLOCK_TILE_SIZE_Y, BLOCK_TILE_SIZE_K, 246 | NUM_THREADS>(A, lda, B, ldb, A_thread_block_tile_transposed, 247 | B_thread_block_tile, thread_block_tile_idx, 248 | thread_linear_idx, m, n, k); 249 | __syncthreads(); 250 | 251 | // Perform A[:, thread_block_tile_idx:BLOCK_TILE_SIZE_K] * 252 | // B[thread_block_tile_idx:BLOCK_TILE_SIZE_K, :] where A[:, 253 | // thread_block_tile_idx:BLOCK_TILE_SIZE_K] and 254 | // B[thread_block_tile_idx:BLOCK_TILE_SIZE_K, :] are cached in the 255 | // shared memory as A_thread_block_tile and B_thread_block_tile, 256 | // respectively. This inner product is further decomposed to 257 | // BLOCK_TILE_SIZE_K outer products. A_thread_block_tile * 258 | // B_thread_block_tile = \sigma_{k_i=0}^{BLOCK_TILE_SIZE_K-1} 259 | // A_thread_block_tile[:, k_i] @ B_thread_block_tile[k_i, :] Note that 260 | // both A_thread_block_tile and B_thread_block_tile can be cached in the 261 | // register. 262 | #pragma unroll 263 | for (size_t k_i{0U}; k_i < BLOCK_TILE_SIZE_K; ++k_i) 264 | { 265 | // Load data from shared memory to register file for A. 266 | load_data_from_shared_memory_to_register_file_vectorized< 267 | T, BLOCK_TILE_SIZE_Y, WARP_TILE_SIZE_Y, NUM_THREADS_PER_WARP_Y, 268 | THREAD_TILE_SIZE_Y>(A_thread_block_tile_transposed[k_i], A_vals, 269 | warp_row_idx, 270 | thread_linear_row_idx_in_warp); 271 | // Load data from shared memory to register file for B. 272 | load_data_from_shared_memory_to_register_file_vectorized< 273 | T, BLOCK_TILE_SIZE_X, WARP_TILE_SIZE_X, NUM_THREADS_PER_WARP_X, 274 | THREAD_TILE_SIZE_X>(B_thread_block_tile[k_i], B_vals, 275 | warp_col_idx, 276 | thread_linear_col_idx_in_warp); 277 | 278 | // Compute NUM_THREAD_TILES_PER_WARP_Y * NUM_THREAD_TILES_PER_WARP_X 279 | // outer products. 280 | compute_thread_tile_results( 283 | A_vals, B_vals, C_thread_results); 284 | } 285 | __syncthreads(); 286 | } 287 | 288 | // Write the results to DRAM. 289 | write_results_from_register_file_to_global_memory_vectorized< 290 | T, BLOCK_TILE_SIZE_X, BLOCK_TILE_SIZE_Y, WARP_TILE_SIZE_X, 291 | WARP_TILE_SIZE_Y, THREAD_TILE_SIZE_X, THREAD_TILE_SIZE_Y, 292 | NUM_THREAD_TILES_PER_WARP_X, NUM_THREAD_TILES_PER_WARP_Y>( 293 | C_thread_results, alpha, beta, C, ldc, m, n, blockIdx.y, blockIdx.x, 294 | warp_row_idx, warp_col_idx, thread_linear_row_idx_in_warp, 295 | thread_linear_col_idx_in_warp); 296 | } 297 | 298 | template 299 | void launch_gemm_kernel_v06_vectorized(size_t m, size_t n, size_t k, 300 | T const* alpha, T const* A, size_t lda, 301 | T const* B, size_t ldb, T const* beta, 302 | T* C, size_t ldc, cudaStream_t stream) 303 | { 304 | // Feel free to play with the block tile sizes. 305 | // The algorithm correctness should always be guaranteed. 
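// "Guaranteed" here still assumes the static_asserts below hold: warp tiles
// must divide the block tiles evenly, THREAD_TILE_SIZE_* times
// NUM_THREADS_PER_WARP_* must divide the warp tiles, and, because this variant
// moves data in 128-bit int4 transactions, the tile dimensions involved in
// vectorized transfers must be multiples of sizeof(int4) / sizeof(T)
// (4 for float, 2 for double, 8 for __half).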
306 | constexpr unsigned int BLOCK_TILE_SIZE_X{128U}; 307 | constexpr unsigned int BLOCK_TILE_SIZE_Y{128U}; 308 | constexpr unsigned int BLOCK_TILE_SIZE_K{16U}; 309 | 310 | constexpr unsigned int WARP_TILE_SIZE_X{32U}; 311 | constexpr unsigned int WARP_TILE_SIZE_Y{64U}; 312 | constexpr unsigned int NUM_WARPS_X{BLOCK_TILE_SIZE_X / WARP_TILE_SIZE_X}; 313 | constexpr unsigned int NUM_WARPS_Y{BLOCK_TILE_SIZE_Y / WARP_TILE_SIZE_Y}; 314 | static_assert(BLOCK_TILE_SIZE_X % WARP_TILE_SIZE_X == 0U); 315 | static_assert(BLOCK_TILE_SIZE_Y % WARP_TILE_SIZE_Y == 0U); 316 | 317 | constexpr unsigned int THREAD_TILE_SIZE_X{8U}; 318 | constexpr unsigned int THREAD_TILE_SIZE_Y{8U}; 319 | 320 | constexpr unsigned int NUM_THREADS_PER_WARP_X{4U}; 321 | constexpr unsigned int NUM_THREADS_PER_WARP_Y{8U}; 322 | static_assert(NUM_THREADS_PER_WARP_X * NUM_THREADS_PER_WARP_Y == 32U); 323 | static_assert( 324 | WARP_TILE_SIZE_X % (THREAD_TILE_SIZE_X * NUM_THREADS_PER_WARP_X) == 0U); 325 | static_assert( 326 | WARP_TILE_SIZE_Y % (THREAD_TILE_SIZE_Y * NUM_THREADS_PER_WARP_Y) == 0U); 327 | 328 | constexpr unsigned int NUM_THREADS_X{NUM_WARPS_X * NUM_THREADS_PER_WARP_X}; 329 | constexpr unsigned int NUM_THREADS_Y{NUM_WARPS_Y * NUM_THREADS_PER_WARP_Y}; 330 | 331 | constexpr unsigned int NUM_THREADS_PER_BLOCK{NUM_THREADS_X * NUM_THREADS_Y}; 332 | 333 | dim3 const block_dim{NUM_THREADS_PER_BLOCK, 1U, 1U}; 334 | dim3 const grid_dim{ 335 | (static_cast(n) + BLOCK_TILE_SIZE_X - 1U) / 336 | BLOCK_TILE_SIZE_X, 337 | (static_cast(m) + BLOCK_TILE_SIZE_Y - 1U) / 338 | BLOCK_TILE_SIZE_Y, 339 | 1U}; 340 | gemm_v06_vectorized 344 | <<>>(m, n, k, *alpha, A, lda, B, ldb, 345 | *beta, C, ldc); 346 | CHECK_LAST_CUDA_ERROR(); 347 | } 348 | 349 | // Explicit instantiation. 350 | template void launch_gemm_kernel_v06_vectorized( 351 | size_t m, size_t n, size_t k, float const* alpha, float const* A, 352 | size_t lda, float const* B, size_t ldb, float const* beta, float* C, 353 | size_t ldc, cudaStream_t stream); 354 | template void launch_gemm_kernel_v06_vectorized( 355 | size_t m, size_t n, size_t k, double const* alpha, double const* A, 356 | size_t lda, double const* B, size_t ldb, double const* beta, double* C, 357 | size_t ldc, cudaStream_t stream); 358 | template void launch_gemm_kernel_v06_vectorized<__half>( 359 | size_t m, size_t n, size_t k, __half const* alpha, __half const* A, 360 | size_t lda, __half const* B, size_t ldb, __half const* beta, __half* C, 361 | size_t ldc, cudaStream_t stream); -------------------------------------------------------------------------------- /src/07_2d_block_tiling_2d_warp_tiling_2d_thread_tiling_matrix_transpose_wmma_vectorized_memory_access_double_buffered.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "cuda_gemm.hpp" 5 | #include "cuda_gemm_utils.cuh" 6 | #include "cuda_gemm_utils.hpp" 7 | 8 | // https://developer.nvidia.com/blog/cutlass-linear-algebra-cuda/ 9 | // https://github.com/NVIDIA/cutlass/blob/b7508e337938137a699e486d8997646980acfc58/media/docs/programming_guidelines.md 10 | 11 | template < 12 | typename T, size_t BLOCK_TILE_SIZE_X, size_t BLOCK_TILE_SIZE_Y, 13 | size_t BLOCK_TILE_SIZE_K, size_t WARP_TILE_SIZE_X, size_t WARP_TILE_SIZE_Y, 14 | size_t WMMA_TILE_SIZE_X, size_t WMMA_TILE_SIZE_Y, size_t WMMA_TILE_SIZE_K, 15 | size_t NUM_WMMA_TILES_X, size_t NUM_WMMA_TILES_Y, size_t NUM_WMMA_TILES_K, 16 | size_t BLOCK_TILE_SKEW_SIZE_X, size_t BLOCK_TILE_SKEW_SIZE_Y> 17 | __device__ void 
process_data_from_shared_memory_using_wmma( 18 | nvcuda::wmma::fragment 21 | a_frags[NUM_WMMA_TILES_Y], 22 | nvcuda::wmma::fragment 25 | b_frags[NUM_WMMA_TILES_X], 26 | nvcuda::wmma::fragment 28 | acc_frags[NUM_WMMA_TILES_Y][NUM_WMMA_TILES_X], 29 | T const A_thread_block_tile_transposed[BLOCK_TILE_SIZE_K] 30 | [BLOCK_TILE_SIZE_Y + 31 | BLOCK_TILE_SKEW_SIZE_Y], 32 | T const B_thread_block_tile[BLOCK_TILE_SIZE_K] 33 | [BLOCK_TILE_SIZE_X + BLOCK_TILE_SKEW_SIZE_X], 34 | size_t warp_row_idx, size_t warp_col_idx) 35 | { 36 | #pragma unroll 37 | for (size_t k_i{0U}; k_i < NUM_WMMA_TILES_K; ++k_i) 38 | { 39 | #pragma unroll 40 | for (size_t wmma_tile_row_idx{0U}; wmma_tile_row_idx < NUM_WMMA_TILES_Y; 41 | ++wmma_tile_row_idx) 42 | { 43 | nvcuda::wmma::load_matrix_sync( 44 | a_frags[wmma_tile_row_idx], 45 | &A_thread_block_tile_transposed[k_i * WMMA_TILE_SIZE_K] 46 | [warp_row_idx * 47 | WARP_TILE_SIZE_Y + 48 | wmma_tile_row_idx * 49 | WMMA_TILE_SIZE_Y], 50 | BLOCK_TILE_SIZE_Y + BLOCK_TILE_SKEW_SIZE_Y); 51 | } 52 | #pragma unroll 53 | for (size_t wmma_tile_col_idx{0U}; wmma_tile_col_idx < NUM_WMMA_TILES_X; 54 | ++wmma_tile_col_idx) 55 | { 56 | nvcuda::wmma::load_matrix_sync( 57 | b_frags[wmma_tile_col_idx], 58 | &B_thread_block_tile[k_i * WMMA_TILE_SIZE_K] 59 | [warp_col_idx * WARP_TILE_SIZE_X + 60 | wmma_tile_col_idx * WMMA_TILE_SIZE_X], 61 | BLOCK_TILE_SIZE_X + BLOCK_TILE_SKEW_SIZE_X); 62 | } 63 | #pragma unroll 64 | for (size_t wmma_tile_row_idx{0U}; wmma_tile_row_idx < NUM_WMMA_TILES_Y; 65 | ++wmma_tile_row_idx) 66 | { 67 | #pragma unroll 68 | for (size_t wmma_tile_col_idx{0U}; 69 | wmma_tile_col_idx < NUM_WMMA_TILES_X; ++wmma_tile_col_idx) 70 | { 71 | // Perform the matrix multiplication. 72 | nvcuda::wmma::mma_sync( 73 | acc_frags[wmma_tile_row_idx][wmma_tile_col_idx], 74 | a_frags[wmma_tile_row_idx], b_frags[wmma_tile_col_idx], 75 | acc_frags[wmma_tile_row_idx][wmma_tile_col_idx]); 76 | } 77 | } 78 | } 79 | } 80 | 81 | // GEMM kernel v07. 82 | // Each thread in the block processes THREAD_TILE_SIZE_Y * 83 | // THREAD_TILE_SIZE_X output values. Number of threads BLOCK_TILE_SIZE_Y * 84 | // BLOCK_TILE_SIZE_X / (THREAD_TILE_SIZE_Y * THREAD_TILE_SIZE_X) 85 | template 90 | __global__ void 91 | gemm_v07_vectorized_double_buffered(size_t m, size_t n, size_t k, T alpha, 92 | T const* A, size_t lda, T const* B, 93 | size_t ldb, T beta, T* C, size_t ldc) 94 | { 95 | constexpr size_t NUM_WARPS_X{BLOCK_TILE_SIZE_X / WARP_TILE_SIZE_X}; 96 | constexpr size_t NUM_WARPS_Y{BLOCK_TILE_SIZE_Y / WARP_TILE_SIZE_Y}; 97 | static_assert(BLOCK_TILE_SIZE_X % WARP_TILE_SIZE_X == 0U); 98 | static_assert(BLOCK_TILE_SIZE_Y % WARP_TILE_SIZE_Y == 0U); 99 | 100 | constexpr size_t NUM_WMMA_TILES_X{WARP_TILE_SIZE_X / WMMA_TILE_SIZE_X}; 101 | static_assert(WARP_TILE_SIZE_X % WMMA_TILE_SIZE_X == 0U); 102 | constexpr size_t NUM_WMMA_TILES_Y{WARP_TILE_SIZE_Y / WMMA_TILE_SIZE_Y}; 103 | static_assert(WARP_TILE_SIZE_Y % WMMA_TILE_SIZE_Y == 0U); 104 | constexpr size_t NUM_WMMA_TILES_K{BLOCK_TILE_SIZE_K / WMMA_TILE_SIZE_K}; 105 | static_assert(BLOCK_TILE_SIZE_K % WMMA_TILE_SIZE_K == 0U); 106 | 107 | constexpr size_t NUM_PIPELINES{2U}; 108 | // Only double buffer is supported in the implementation. 109 | // But even more number of pipelines can be supported if the implementation 110 | // is modified. 
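// Schedule implemented by the main loop below, with the warps split evenly
// into two pipelines and k-tiles consumed two at a time:
//   stage 1: pipeline 0 computes on buffer 0, pipeline 1 loads tile i+1 into buffer 1
//   stage 2: pipeline 0 computes on buffer 1, pipeline 1 computes on buffer 0
//   stage 3: pipeline 0 loads tile i+2 into buffer 0, pipeline 1 computes on buffer 1
// Global-memory loads for one buffer thus overlap with tensor-core work on the
// other buffer, at the cost of the __syncthreads() barriers between stages.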
111 | static_assert(NUM_PIPELINES == 2U); 112 | static_assert((NUM_WARPS_X * NUM_WARPS_Y) % NUM_PIPELINES == 0U); 113 | static_assert(NUM_THREADS % NUM_PIPELINES == 0U); 114 | constexpr size_t NUM_THREADS_PER_PIPELINE{NUM_THREADS / NUM_PIPELINES}; 115 | constexpr size_t NUM_WARPS_PER_PIPELINE{(NUM_WARPS_X * NUM_WARPS_Y) / 116 | NUM_PIPELINES}; 117 | 118 | // Cache a tile of A and B in shared memory for data reuse. 119 | __shared__ T 120 | A_thread_block_tile_transposed[NUM_PIPELINES][BLOCK_TILE_SIZE_K] 121 | [BLOCK_TILE_SIZE_Y + 122 | BLOCK_TILE_SKEW_SIZE_Y]; 123 | __shared__ T 124 | B_thread_block_tile[NUM_PIPELINES][BLOCK_TILE_SIZE_K] 125 | [BLOCK_TILE_SIZE_X + BLOCK_TILE_SKEW_SIZE_X]; 126 | 127 | // Declare the fragments. 128 | nvcuda::wmma::fragment 131 | a_frags[NUM_WMMA_TILES_Y]; 132 | nvcuda::wmma::fragment 135 | b_frags[NUM_WMMA_TILES_X]; 136 | nvcuda::wmma::fragment 138 | acc_frags[NUM_WMMA_TILES_Y][NUM_WMMA_TILES_X]; 139 | nvcuda::wmma::fragment 141 | c_frag; 142 | 143 | // Make sure the accumulator starts from 0. 144 | #pragma unroll 145 | for (size_t wmma_tile_row_idx{0U}; wmma_tile_row_idx < NUM_WMMA_TILES_Y; 146 | ++wmma_tile_row_idx) 147 | { 148 | for (size_t wmma_tile_col_idx{0U}; wmma_tile_col_idx < NUM_WMMA_TILES_X; 149 | ++wmma_tile_col_idx) 150 | { 151 | nvcuda::wmma::fill_fragment( 152 | acc_frags[wmma_tile_row_idx][wmma_tile_col_idx], 153 | static_cast(0)); 154 | } 155 | } 156 | 157 | size_t const thread_linear_idx{threadIdx.y * blockDim.x + threadIdx.x}; 158 | size_t const warp_linear_idx{thread_linear_idx / 32U}; 159 | size_t const warp_row_idx{warp_linear_idx / NUM_WARPS_X}; 160 | size_t const warp_col_idx{warp_linear_idx % NUM_WARPS_X}; 161 | // Separate the warps to different pipelines. 162 | size_t const pipeline_index{warp_linear_idx / NUM_WARPS_PER_PIPELINE}; 163 | 164 | // Number of outer loops to perform the sum of inner products. 165 | // C_thread_block_tile = 166 | // \sigma_{thread_block_tile_idx=0}^{num_thread_block_tiles-1} A[:, 167 | // thread_block_tile_idx:BLOCK_TILE_SIZE_K] * 168 | // B[thread_block_tile_idx:BLOCK_TILE_SIZE_K, :] 169 | size_t const num_thread_block_tiles{(k + BLOCK_TILE_SIZE_K - 1) / 170 | BLOCK_TILE_SIZE_K}; 171 | 172 | if (pipeline_index == 0U) 173 | { 174 | // Pipeline 0 warps load buffer 0. 175 | load_data_from_global_memory_to_shared_memory_transposed_vectorized< 176 | T, BLOCK_TILE_SIZE_X, BLOCK_TILE_SIZE_Y, BLOCK_TILE_SIZE_K, 177 | NUM_THREADS_PER_PIPELINE, BLOCK_TILE_SKEW_SIZE_X, 178 | BLOCK_TILE_SKEW_SIZE_Y>( 179 | A, lda, B, ldb, A_thread_block_tile_transposed[pipeline_index], 180 | B_thread_block_tile[pipeline_index], 0U, 181 | thread_linear_idx - pipeline_index * NUM_THREADS_PER_PIPELINE, m, n, 182 | k); 183 | } 184 | __syncthreads(); 185 | 186 | for (size_t thread_block_tile_idx{0U}; 187 | thread_block_tile_idx < num_thread_block_tiles; 188 | thread_block_tile_idx += NUM_PIPELINES) 189 | { 190 | if (pipeline_index == 0U) 191 | { 192 | // Pipeline 0 warps process buffer 0. 
193 | process_data_from_shared_memory_using_wmma< 194 | T, BLOCK_TILE_SIZE_X, BLOCK_TILE_SIZE_Y, BLOCK_TILE_SIZE_K, 195 | WARP_TILE_SIZE_X, WARP_TILE_SIZE_Y, WMMA_TILE_SIZE_X, 196 | WMMA_TILE_SIZE_Y, WMMA_TILE_SIZE_K, NUM_WMMA_TILES_X, 197 | NUM_WMMA_TILES_Y, NUM_WMMA_TILES_K, BLOCK_TILE_SKEW_SIZE_X, 198 | BLOCK_TILE_SKEW_SIZE_Y>( 199 | a_frags, b_frags, acc_frags, 200 | A_thread_block_tile_transposed[pipeline_index], 201 | B_thread_block_tile[pipeline_index], warp_row_idx, 202 | warp_col_idx); 203 | __syncthreads(); 204 | 205 | // Pipeline 0 warps process buffer 1. 206 | if (thread_block_tile_idx + 1U < num_thread_block_tiles) 207 | { 208 | process_data_from_shared_memory_using_wmma< 209 | T, BLOCK_TILE_SIZE_X, BLOCK_TILE_SIZE_Y, BLOCK_TILE_SIZE_K, 210 | WARP_TILE_SIZE_X, WARP_TILE_SIZE_Y, WMMA_TILE_SIZE_X, 211 | WMMA_TILE_SIZE_Y, WMMA_TILE_SIZE_K, NUM_WMMA_TILES_X, 212 | NUM_WMMA_TILES_Y, NUM_WMMA_TILES_K, BLOCK_TILE_SKEW_SIZE_X, 213 | BLOCK_TILE_SKEW_SIZE_Y>( 214 | a_frags, b_frags, acc_frags, 215 | A_thread_block_tile_transposed[pipeline_index + 1], 216 | B_thread_block_tile[pipeline_index + 1], warp_row_idx, 217 | warp_col_idx); 218 | } 219 | __syncthreads(); 220 | 221 | // Pipeline 0 warps load buffer 0. 222 | if (thread_block_tile_idx + 2U < num_thread_block_tiles) 223 | { 224 | load_data_from_global_memory_to_shared_memory_transposed_vectorized< 225 | T, BLOCK_TILE_SIZE_X, BLOCK_TILE_SIZE_Y, BLOCK_TILE_SIZE_K, 226 | NUM_THREADS_PER_PIPELINE, BLOCK_TILE_SKEW_SIZE_X, 227 | BLOCK_TILE_SKEW_SIZE_Y>( 228 | A, lda, B, ldb, 229 | A_thread_block_tile_transposed[pipeline_index], 230 | B_thread_block_tile[pipeline_index], 231 | thread_block_tile_idx + 2, 232 | thread_linear_idx - 233 | pipeline_index * NUM_THREADS_PER_PIPELINE, 234 | m, n, k); 235 | } 236 | __syncthreads(); 237 | } 238 | else 239 | { 240 | // Pipeline 1 warps load buffer 1. 241 | if (thread_block_tile_idx + 1U < num_thread_block_tiles) 242 | { 243 | load_data_from_global_memory_to_shared_memory_transposed_vectorized< 244 | T, BLOCK_TILE_SIZE_X, BLOCK_TILE_SIZE_Y, BLOCK_TILE_SIZE_K, 245 | NUM_THREADS_PER_PIPELINE, BLOCK_TILE_SKEW_SIZE_X, 246 | BLOCK_TILE_SKEW_SIZE_Y>( 247 | A, lda, B, ldb, 248 | A_thread_block_tile_transposed[pipeline_index], 249 | B_thread_block_tile[pipeline_index], 250 | thread_block_tile_idx + 1, 251 | thread_linear_idx - 252 | pipeline_index * NUM_THREADS_PER_PIPELINE, 253 | m, n, k); 254 | } 255 | __syncthreads(); 256 | 257 | // Pipeline 1 warps process buffer 0. 258 | process_data_from_shared_memory_using_wmma< 259 | T, BLOCK_TILE_SIZE_X, BLOCK_TILE_SIZE_Y, BLOCK_TILE_SIZE_K, 260 | WARP_TILE_SIZE_X, WARP_TILE_SIZE_Y, WMMA_TILE_SIZE_X, 261 | WMMA_TILE_SIZE_Y, WMMA_TILE_SIZE_K, NUM_WMMA_TILES_X, 262 | NUM_WMMA_TILES_Y, NUM_WMMA_TILES_K, BLOCK_TILE_SKEW_SIZE_X, 263 | BLOCK_TILE_SKEW_SIZE_Y>( 264 | a_frags, b_frags, acc_frags, 265 | A_thread_block_tile_transposed[pipeline_index - 1], 266 | B_thread_block_tile[pipeline_index - 1], warp_row_idx, 267 | warp_col_idx); 268 | __syncthreads(); 269 | 270 | // Pipeline 1 warps process buffer 1. 
271 | if (thread_block_tile_idx + 1U < num_thread_block_tiles) 272 | { 273 | process_data_from_shared_memory_using_wmma< 274 | T, BLOCK_TILE_SIZE_X, BLOCK_TILE_SIZE_Y, BLOCK_TILE_SIZE_K, 275 | WARP_TILE_SIZE_X, WARP_TILE_SIZE_Y, WMMA_TILE_SIZE_X, 276 | WMMA_TILE_SIZE_Y, WMMA_TILE_SIZE_K, NUM_WMMA_TILES_X, 277 | NUM_WMMA_TILES_Y, NUM_WMMA_TILES_K, BLOCK_TILE_SKEW_SIZE_X, 278 | BLOCK_TILE_SKEW_SIZE_Y>( 279 | a_frags, b_frags, acc_frags, 280 | A_thread_block_tile_transposed[pipeline_index], 281 | B_thread_block_tile[pipeline_index], warp_row_idx, 282 | warp_col_idx); 283 | } 284 | __syncthreads(); 285 | } 286 | } 287 | 288 | // Write the results to DRAM. 289 | #pragma unroll 290 | for (size_t wmma_tile_row_idx{0U}; wmma_tile_row_idx < NUM_WMMA_TILES_Y; 291 | ++wmma_tile_row_idx) 292 | { 293 | #pragma unroll 294 | for (size_t wmma_tile_col_idx{0U}; wmma_tile_col_idx < NUM_WMMA_TILES_X; 295 | ++wmma_tile_col_idx) 296 | { 297 | // Load the fragment from global memory. 298 | nvcuda::wmma::load_matrix_sync( 299 | c_frag, 300 | &C[(blockIdx.y * BLOCK_TILE_SIZE_Y + 301 | warp_row_idx * WARP_TILE_SIZE_Y + 302 | wmma_tile_row_idx * WMMA_TILE_SIZE_Y) * 303 | n + 304 | blockIdx.x * BLOCK_TILE_SIZE_X + 305 | warp_col_idx * WARP_TILE_SIZE_X + 306 | wmma_tile_col_idx * WMMA_TILE_SIZE_X], 307 | n, nvcuda::wmma::mem_row_major); 308 | // Perform scaling and addition. 309 | for (size_t i{0}; i < c_frag.num_elements; ++i) 310 | { 311 | c_frag.x[i] = 312 | alpha * 313 | acc_frags[wmma_tile_row_idx][wmma_tile_col_idx].x[i] + 314 | beta * c_frag.x[i]; 315 | } 316 | // Store the fragment back to global memory. 317 | nvcuda::wmma::store_matrix_sync( 318 | &C[(blockIdx.y * BLOCK_TILE_SIZE_Y + 319 | warp_row_idx * WARP_TILE_SIZE_Y + 320 | wmma_tile_row_idx * WMMA_TILE_SIZE_Y) * 321 | n + 322 | blockIdx.x * BLOCK_TILE_SIZE_X + 323 | warp_col_idx * WARP_TILE_SIZE_X + 324 | wmma_tile_col_idx * WMMA_TILE_SIZE_X], 325 | c_frag, n, nvcuda::wmma::mem_row_major); 326 | } 327 | } 328 | } 329 | 330 | template 331 | void launch_gemm_kernel_v07_vectorized_double_buffered( 332 | size_t m, size_t n, size_t k, T const* alpha, T const* A, size_t lda, 333 | T const* B, size_t ldb, T const* beta, T* C, size_t ldc, 334 | cudaStream_t stream) 335 | { 336 | // Feel free to play with the block tile sizes. 337 | // The algorithm correctness should always be guaranteed. 338 | constexpr unsigned int BLOCK_TILE_SIZE_X{128U}; 339 | constexpr unsigned int BLOCK_TILE_SIZE_Y{128U}; 340 | constexpr unsigned int BLOCK_TILE_SIZE_K{16U}; 341 | 342 | // The skew size is used to avoid bank conflicts in shared memory. 
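// The usual motivation for the skew: padding each shared-memory row shifts the
// bank that a given column falls into from row to row, so the column-wise
// accesses issued by wmma::load_matrix_sync are spread across the 32 banks
// instead of piling onto a few of them. The skew must also preserve 16-byte
// row alignment for the int4 stores into shared memory; with __half,
// (128 + 16) * sizeof(__half) = 288 bytes per row, a multiple of 16, so the
// static_asserts in the vectorized loader remain satisfied.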
343 | constexpr size_t BLOCK_TILE_SKEW_SIZE_X{16U}; 344 | constexpr size_t BLOCK_TILE_SKEW_SIZE_Y{16U}; 345 | 346 | constexpr unsigned int WARP_TILE_SIZE_X{32U}; 347 | constexpr unsigned int WARP_TILE_SIZE_Y{64U}; 348 | constexpr unsigned int NUM_WARPS_X{BLOCK_TILE_SIZE_X / WARP_TILE_SIZE_X}; 349 | constexpr unsigned int NUM_WARPS_Y{BLOCK_TILE_SIZE_Y / WARP_TILE_SIZE_Y}; 350 | static_assert(BLOCK_TILE_SIZE_X % WARP_TILE_SIZE_X == 0U); 351 | static_assert(BLOCK_TILE_SIZE_Y % WARP_TILE_SIZE_Y == 0U); 352 | 353 | constexpr unsigned int WMMA_TILE_SIZE_X{16U}; 354 | constexpr unsigned int WMMA_TILE_SIZE_Y{16U}; 355 | constexpr unsigned int WMMA_TILE_SIZE_K{16U}; 356 | 357 | constexpr unsigned int NUM_THREADS_PER_BLOCK{NUM_WARPS_X * NUM_WARPS_Y * 358 | 32U}; 359 | 360 | dim3 const block_dim{NUM_THREADS_PER_BLOCK, 1U, 1U}; 361 | dim3 const grid_dim{ 362 | (static_cast(n) + BLOCK_TILE_SIZE_X - 1U) / 363 | BLOCK_TILE_SIZE_X, 364 | (static_cast(m) + BLOCK_TILE_SIZE_Y - 1U) / 365 | BLOCK_TILE_SIZE_Y, 366 | 1U}; 367 | gemm_v07_vectorized_double_buffered< 368 | T, BLOCK_TILE_SIZE_X, BLOCK_TILE_SIZE_Y, BLOCK_TILE_SIZE_K, 369 | BLOCK_TILE_SKEW_SIZE_X, BLOCK_TILE_SKEW_SIZE_Y, WARP_TILE_SIZE_X, 370 | WARP_TILE_SIZE_Y, WMMA_TILE_SIZE_X, WMMA_TILE_SIZE_Y, WMMA_TILE_SIZE_K, 371 | NUM_THREADS_PER_BLOCK><<>>( 372 | m, n, k, *alpha, A, lda, B, ldb, *beta, C, ldc); 373 | CHECK_LAST_CUDA_ERROR(); 374 | } 375 | 376 | // Explicit instantiation. 377 | template void launch_gemm_kernel_v07_vectorized_double_buffered<__half>( 378 | size_t m, size_t n, size_t k, __half const* alpha, __half const* A, 379 | size_t lda, __half const* B, size_t ldb, __half const* beta, __half* C, 380 | size_t ldc, cudaStream_t stream); -------------------------------------------------------------------------------- /include/cuda_gemm_utils.cuh: -------------------------------------------------------------------------------- 1 | #ifndef CUDA_GEMM_UTILS_CUH 2 | #define CUDA_GEMM_UTILS_CUH 3 | 4 | #include 5 | 6 | #include "cuda_gemm_utils.hpp" 7 | 8 | template 12 | __device__ void load_data_from_global_memory_to_shared_memory( 13 | T const* A, size_t lda, T const* B, size_t ldb, 14 | T A_thread_block_tile[BLOCK_TILE_SIZE_Y] 15 | [BLOCK_TILE_SIZE_K + BLOCK_TILE_SKEW_SIZE_K], 16 | T B_thread_block_tile[BLOCK_TILE_SIZE_K] 17 | [BLOCK_TILE_SIZE_X + BLOCK_TILE_SKEW_SIZE_X], 18 | size_t thread_block_tile_idx, size_t thread_linear_idx, size_t m, size_t n, 19 | size_t k) 20 | { 21 | // Load data from A on DRAM to A_thread_block_tile on shared memory. 22 | #pragma unroll 23 | for (size_t load_idx{0U}; 24 | load_idx < (BLOCK_TILE_SIZE_Y * BLOCK_TILE_SIZE_K + NUM_THREADS - 1U) / 25 | NUM_THREADS; 26 | ++load_idx) 27 | { 28 | size_t const A_thread_block_tile_row_idx{ 29 | (thread_linear_idx + load_idx * NUM_THREADS) / BLOCK_TILE_SIZE_K}; 30 | size_t const A_thread_block_tile_col_idx{ 31 | (thread_linear_idx + load_idx * NUM_THREADS) % BLOCK_TILE_SIZE_K}; 32 | size_t const A_row_idx{blockIdx.y * BLOCK_TILE_SIZE_Y + 33 | A_thread_block_tile_row_idx}; 34 | size_t const A_col_idx{thread_block_tile_idx * BLOCK_TILE_SIZE_K + 35 | A_thread_block_tile_col_idx}; 36 | 37 | // These boundary checks might slow down the kernel to some extent. 38 | // But they guarantee the correctness of the kernel for all 39 | // different GEMM configurations. 40 | T val{static_cast(0)}; 41 | if (A_row_idx < m && A_col_idx < k) 42 | { 43 | val = A[A_row_idx * lda + A_col_idx]; 44 | } 45 | // This if will slow down the kernel. 
46 | // Add static asserts from the host code to guarantee this if is 47 | // always true. 48 | static_assert(BLOCK_TILE_SIZE_K * BLOCK_TILE_SIZE_Y % NUM_THREADS == 49 | 0U); 50 | // if (A_thread_block_tile_row_idx < BLOCK_TILE_SIZE_Y && 51 | // A_thread_block_tile_col_idx < BLOCK_TILE_SIZE_K) 52 | // { 53 | // A_thread_block_tile[A_thread_block_tile_row_idx] 54 | // [A_thread_block_tile_col_idx] = val; 55 | // } 56 | A_thread_block_tile[A_thread_block_tile_row_idx] 57 | [A_thread_block_tile_col_idx] = val; 58 | } 59 | // Load data from B on DRAM to B_thread_block_tile on shared memory. 60 | #pragma unroll 61 | for (size_t load_idx{0U}; 62 | load_idx < (BLOCK_TILE_SIZE_K * BLOCK_TILE_SIZE_X + NUM_THREADS - 1U) / 63 | NUM_THREADS; 64 | ++load_idx) 65 | { 66 | size_t const B_thread_block_tile_row_idx{ 67 | (thread_linear_idx + load_idx * NUM_THREADS) / BLOCK_TILE_SIZE_X}; 68 | size_t const B_thread_block_tile_col_idx{ 69 | (thread_linear_idx + load_idx * NUM_THREADS) % BLOCK_TILE_SIZE_X}; 70 | size_t const B_row_idx{thread_block_tile_idx * BLOCK_TILE_SIZE_K + 71 | B_thread_block_tile_row_idx}; 72 | size_t const B_col_idx{blockIdx.x * BLOCK_TILE_SIZE_X + 73 | B_thread_block_tile_col_idx}; 74 | 75 | // These boundary checks might slow down the kernel to some extent. 76 | // But they guarantee the correctness of the kernel for all 77 | // different GEMM configurations. 78 | T val{static_cast(0)}; 79 | if (B_row_idx < k && B_col_idx < n) 80 | { 81 | val = B[B_row_idx * ldb + B_col_idx]; 82 | } 83 | // This if will slow down the kernel. 84 | // Add static asserts from the host code to guarantee this if is 85 | // always true. 86 | static_assert(BLOCK_TILE_SIZE_X * BLOCK_TILE_SIZE_K % NUM_THREADS == 87 | 0U); 88 | // if (B_thread_block_tile_row_idx < BLOCK_TILE_SIZE_K && 89 | // B_thread_block_tile_col_idx < BLOCK_TILE_SIZE_X) 90 | // { 91 | // B_thread_block_tile[B_thread_block_tile_row_idx] 92 | // [B_thread_block_tile_col_idx] = val; 93 | // } 94 | B_thread_block_tile[B_thread_block_tile_row_idx] 95 | [B_thread_block_tile_col_idx] = val; 96 | } 97 | } 98 | 99 | template 103 | __device__ void load_data_from_global_memory_to_shared_memory_transposed( 104 | T const* A, size_t lda, T const* B, size_t ldb, 105 | T A_thread_block_tile_transposed[BLOCK_TILE_SIZE_K][BLOCK_TILE_SIZE_Y + 106 | BLOCK_TILE_SKEW_SIZE_Y], 107 | T B_thread_block_tile[BLOCK_TILE_SIZE_K] 108 | [BLOCK_TILE_SIZE_X + BLOCK_TILE_SKEW_SIZE_X], 109 | size_t thread_block_tile_idx, size_t thread_linear_idx, size_t m, size_t n, 110 | size_t k) 111 | { 112 | // Load data from A on DRAM to A_thread_block_tile on shared memory. 113 | #pragma unroll 114 | for (size_t load_idx{0U}; 115 | load_idx < (BLOCK_TILE_SIZE_Y * BLOCK_TILE_SIZE_K + NUM_THREADS - 1U) / 116 | NUM_THREADS; 117 | ++load_idx) 118 | { 119 | size_t const A_thread_block_tile_row_idx{ 120 | (thread_linear_idx + load_idx * NUM_THREADS) / BLOCK_TILE_SIZE_K}; 121 | size_t const A_thread_block_tile_col_idx{ 122 | (thread_linear_idx + load_idx * NUM_THREADS) % BLOCK_TILE_SIZE_K}; 123 | size_t const A_row_idx{blockIdx.y * BLOCK_TILE_SIZE_Y + 124 | A_thread_block_tile_row_idx}; 125 | size_t const A_col_idx{thread_block_tile_idx * BLOCK_TILE_SIZE_K + 126 | A_thread_block_tile_col_idx}; 127 | 128 | // These boundary checks might slow down the kernel to some extent. 129 | // But they guarantee the correctness of the kernel for all 130 | // different GEMM configurations. 
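// Out-of-range elements are written as zero instead of being skipped: zeros in
// the A or B tiles contribute nothing to the accumulated dot products, so the
// padded portion of the final partial k-tile is harmless and the compute loops
// need no extra bounds handling.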
131 | T val{static_cast(0)}; 132 | if (A_row_idx < m && A_col_idx < k) 133 | { 134 | val = A[A_row_idx * lda + A_col_idx]; 135 | } 136 | // Removing the if will give another ~2 FLOPs performance on RTX 137 | // 3090. But it will make the kernel incorrect for some GEMM 138 | // configurations. T val{A[A_row_idx * lda + A_col_idx]}; This if 139 | // will slow down the kernel. Add static asserts from the host code 140 | // to guarantee this if is always true. 141 | static_assert(BLOCK_TILE_SIZE_K * BLOCK_TILE_SIZE_Y % NUM_THREADS == 142 | 0U); 143 | // if (A_thread_block_tile_row_idx < BLOCK_TILE_SIZE_Y && 144 | // A_thread_block_tile_col_idx < BLOCK_TILE_SIZE_K) 145 | // { 146 | // A_thread_block_tile[A_thread_block_tile_row_idx] 147 | // [A_thread_block_tile_col_idx] = val; 148 | // } 149 | A_thread_block_tile_transposed[A_thread_block_tile_col_idx] 150 | [A_thread_block_tile_row_idx] = val; 151 | } 152 | // Load data from B on DRAM to B_thread_block_tile on shared memory. 153 | #pragma unroll 154 | for (size_t load_idx{0U}; 155 | load_idx < (BLOCK_TILE_SIZE_K * BLOCK_TILE_SIZE_X + NUM_THREADS - 1U) / 156 | NUM_THREADS; 157 | ++load_idx) 158 | { 159 | size_t const B_thread_block_tile_row_idx{ 160 | (thread_linear_idx + load_idx * NUM_THREADS) / BLOCK_TILE_SIZE_X}; 161 | size_t const B_thread_block_tile_col_idx{ 162 | (thread_linear_idx + load_idx * NUM_THREADS) % BLOCK_TILE_SIZE_X}; 163 | size_t const B_row_idx{thread_block_tile_idx * BLOCK_TILE_SIZE_K + 164 | B_thread_block_tile_row_idx}; 165 | size_t const B_col_idx{blockIdx.x * BLOCK_TILE_SIZE_X + 166 | B_thread_block_tile_col_idx}; 167 | 168 | // These boundary checks might slow down the kernel to some extent. 169 | // But they guarantee the correctness of the kernel for all 170 | // different GEMM configurations. 171 | T val{static_cast(0)}; 172 | if (B_row_idx < k && B_col_idx < n) 173 | { 174 | val = B[B_row_idx * ldb + B_col_idx]; 175 | } 176 | // Removing the if will give another ~2 FLOPs performance on RTX 177 | // 3090. But it will make the kernel incorrect for some GEMM 178 | // configurations. T val{B[B_row_idx * ldb + B_col_idx]}; This if 179 | // will slow down the kernel. Add static asserts from the host code 180 | // to guarantee this if is always true. 
181 | static_assert(BLOCK_TILE_SIZE_X * BLOCK_TILE_SIZE_K % NUM_THREADS == 182 | 0U); 183 | // if (B_thread_block_tile_row_idx < BLOCK_TILE_SIZE_K && 184 | // B_thread_block_tile_col_idx < BLOCK_TILE_SIZE_X) 185 | // { 186 | // B_thread_block_tile[B_thread_block_tile_row_idx] 187 | // [B_thread_block_tile_col_idx] = val; 188 | // } 189 | B_thread_block_tile[B_thread_block_tile_row_idx] 190 | [B_thread_block_tile_col_idx] = val; 191 | } 192 | } 193 | 194 | template 198 | __device__ void load_data_from_global_memory_to_shared_memory_vectorized( 199 | T const* A, size_t lda, T const* B, size_t ldb, 200 | T A_thread_block_tile[BLOCK_TILE_SIZE_Y] 201 | [BLOCK_TILE_SIZE_K + BLOCK_TILE_SKEW_SIZE_K], 202 | T B_thread_block_tile[BLOCK_TILE_SIZE_K] 203 | [BLOCK_TILE_SIZE_X + BLOCK_TILE_SKEW_SIZE_X], 204 | size_t thread_block_tile_idx, size_t thread_linear_idx, size_t m, size_t n, 205 | size_t k) 206 | { 207 | constexpr size_t NUM_VECTOR_UNITS{sizeof(VECTOR_TYPE) / sizeof(T)}; 208 | static_assert(sizeof(VECTOR_TYPE) % sizeof(T) == 0U); 209 | static_assert(BLOCK_TILE_SIZE_K % NUM_VECTOR_UNITS == 0U); 210 | static_assert(BLOCK_TILE_SIZE_X % NUM_VECTOR_UNITS == 0U); 211 | constexpr size_t VECTORIZED_BLOCK_TILE_SIZE_K{BLOCK_TILE_SIZE_K / 212 | NUM_VECTOR_UNITS}; 213 | static_assert(BLOCK_TILE_SIZE_K % NUM_VECTOR_UNITS == 0U); 214 | constexpr size_t VECTORIZED_BLOCK_TILE_SIZE_X{BLOCK_TILE_SIZE_X / 215 | NUM_VECTOR_UNITS}; 216 | static_assert(BLOCK_TILE_SIZE_X % NUM_VECTOR_UNITS == 0U); 217 | 218 | // The skew size could affect the data alignment in shared memory when we 219 | // use vectorized load. We need to make sure the data alignment is correct. 220 | static_assert((BLOCK_TILE_SIZE_K) * sizeof(T) % sizeof(VECTOR_TYPE) == 0U); 221 | static_assert((BLOCK_TILE_SIZE_X) * sizeof(T) % sizeof(VECTOR_TYPE) == 0U); 222 | static_assert((BLOCK_TILE_SIZE_K + BLOCK_TILE_SKEW_SIZE_K) * sizeof(T) % 223 | sizeof(VECTOR_TYPE) == 224 | 0U); 225 | static_assert((BLOCK_TILE_SIZE_X + BLOCK_TILE_SKEW_SIZE_X) * sizeof(T) % 226 | sizeof(VECTOR_TYPE) == 227 | 0U); 228 | 229 | // Load data from A on DRAM to A_thread_block_tile on shared memory. 230 | #pragma unroll 231 | for (size_t load_idx{0U}; 232 | load_idx < 233 | (BLOCK_TILE_SIZE_Y * VECTORIZED_BLOCK_TILE_SIZE_K + NUM_THREADS - 1U) / 234 | NUM_THREADS; 235 | ++load_idx) 236 | { 237 | size_t const A_thread_block_tile_row_idx{ 238 | (thread_linear_idx + load_idx * NUM_THREADS) / 239 | VECTORIZED_BLOCK_TILE_SIZE_K}; 240 | size_t const A_thread_block_tile_col_idx{ 241 | (thread_linear_idx + load_idx * NUM_THREADS) % 242 | VECTORIZED_BLOCK_TILE_SIZE_K * NUM_VECTOR_UNITS}; 243 | size_t const A_row_idx{blockIdx.y * BLOCK_TILE_SIZE_Y + 244 | A_thread_block_tile_row_idx}; 245 | size_t const A_col_idx{thread_block_tile_idx * BLOCK_TILE_SIZE_K + 246 | A_thread_block_tile_col_idx}; 247 | 248 | // These boundary checks might slow down the kernel to some extent. 249 | // But they guarantee the correctness of the kernel for all 250 | // different GEMM configurations. 251 | VECTOR_TYPE A_row_vector_vals{0, 0, 0, 0}; 252 | if (A_row_idx < m && A_col_idx < k) 253 | { 254 | A_row_vector_vals = *reinterpret_cast( 255 | &A[A_row_idx * lda + A_col_idx]); 256 | } 257 | if (A_col_idx + NUM_VECTOR_UNITS > k) 258 | { 259 | // Number of invalid elements in the last vector. 260 | size_t const num_invalid_elements{A_col_idx + NUM_VECTOR_UNITS - k}; 261 | // Mask out the invalid elements. 
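// When k is not a multiple of NUM_VECTOR_UNITS, the last int4 read of a row
// picks up as many as NUM_VECTOR_UNITS - 1 elements that lie beyond column k.
// Zeroing those trailing lanes restores the zero-padding convention of the
// scalar loaders, so the stray values cannot leak into the results.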
            T* const A_row_vector_vals_ptr{
                reinterpret_cast<T*>(&A_row_vector_vals)};
            for (size_t i{0U}; i < num_invalid_elements; ++i)
            {
                A_row_vector_vals_ptr[NUM_VECTOR_UNITS - 1U - i] =
                    static_cast<T>(0);
            }
        }
        // If this static assertion holds, the following if can be removed.
        // static_assert(VECTORIZED_BLOCK_TILE_SIZE_K * BLOCK_TILE_SIZE_Y %
        //                   NUM_THREADS == 0U);
        if (A_thread_block_tile_row_idx < BLOCK_TILE_SIZE_Y &&
            A_thread_block_tile_col_idx < BLOCK_TILE_SIZE_K)
        {
            *reinterpret_cast<VECTOR_TYPE*>(
                &A_thread_block_tile[A_thread_block_tile_row_idx]
                                    [A_thread_block_tile_col_idx]) =
                A_row_vector_vals;
        }
    }
    // Load data from B on DRAM to B_thread_block_tile on shared memory.
#pragma unroll
    for (size_t load_idx{0U};
         load_idx <
         (BLOCK_TILE_SIZE_K * VECTORIZED_BLOCK_TILE_SIZE_X + NUM_THREADS - 1U) /
             NUM_THREADS;
         ++load_idx)
    {
        size_t const B_thread_block_tile_row_idx{
            (thread_linear_idx + load_idx * NUM_THREADS) /
            VECTORIZED_BLOCK_TILE_SIZE_X};
        size_t const B_thread_block_tile_col_idx{
            (thread_linear_idx + load_idx * NUM_THREADS) %
            VECTORIZED_BLOCK_TILE_SIZE_X * NUM_VECTOR_UNITS};
        size_t const B_row_idx{thread_block_tile_idx * BLOCK_TILE_SIZE_K +
                               B_thread_block_tile_row_idx};
        size_t const B_col_idx{blockIdx.x * BLOCK_TILE_SIZE_X +
                               B_thread_block_tile_col_idx};

        // These boundary checks might slow down the kernel to some extent,
        // but they guarantee the correctness of the kernel for all GEMM
        // configurations.
        VECTOR_TYPE B_row_vector_vals{0, 0, 0, 0};
        if (B_row_idx < k && B_col_idx < n)
        {
            B_row_vector_vals = *reinterpret_cast<VECTOR_TYPE const*>(
                &B[B_row_idx * ldb + B_col_idx]);
        }
        if (B_col_idx + NUM_VECTOR_UNITS > n)
        {
            // Number of invalid elements in the last vector.
            size_t const num_invalid_elements{B_col_idx + NUM_VECTOR_UNITS - n};
            // Mask out the invalid elements.
            T* const B_row_vector_vals_ptr{
                reinterpret_cast<T*>(&B_row_vector_vals)};
            for (size_t i{0U}; i < num_invalid_elements; ++i)
            {
                B_row_vector_vals_ptr[NUM_VECTOR_UNITS - 1U - i] =
                    static_cast<T>(0);
            }
        }
        // If this static assertion holds, the following if can be removed.
        // static_assert(VECTORIZED_BLOCK_TILE_SIZE_X * BLOCK_TILE_SIZE_K %
        //                   NUM_THREADS ==
        //               0U);
        if (B_thread_block_tile_row_idx < BLOCK_TILE_SIZE_K &&
            B_thread_block_tile_col_idx < BLOCK_TILE_SIZE_X)
        {
            *reinterpret_cast<VECTOR_TYPE*>(
                &B_thread_block_tile[B_thread_block_tile_row_idx]
                                    [B_thread_block_tile_col_idx]) =
                B_row_vector_vals;
        }
    }
}

template <typename T, size_t BLOCK_TILE_SIZE_X, size_t BLOCK_TILE_SIZE_Y,
          size_t BLOCK_TILE_SIZE_K, size_t NUM_THREADS,
          size_t BLOCK_TILE_SKEW_SIZE_X = 0U,
          size_t BLOCK_TILE_SKEW_SIZE_Y = 0U, typename VECTOR_TYPE = int4>
__device__ void
load_data_from_global_memory_to_shared_memory_transposed_vectorized(
    T const* A, size_t lda, T const* B, size_t ldb,
    T A_thread_block_tile_transposed[BLOCK_TILE_SIZE_K][BLOCK_TILE_SIZE_Y +
                                                        BLOCK_TILE_SKEW_SIZE_Y],
    T B_thread_block_tile[BLOCK_TILE_SIZE_K]
                         [BLOCK_TILE_SIZE_X + BLOCK_TILE_SKEW_SIZE_X],
    size_t thread_block_tile_idx, size_t thread_linear_idx, size_t m, size_t n,
    size_t k)
{
    constexpr size_t NUM_VECTOR_UNITS{sizeof(VECTOR_TYPE) / sizeof(T)};
    static_assert(sizeof(VECTOR_TYPE) % sizeof(T) == 0U);
    static_assert(BLOCK_TILE_SIZE_K % NUM_VECTOR_UNITS == 0U);
    static_assert(BLOCK_TILE_SIZE_X % NUM_VECTOR_UNITS == 0U);
    constexpr size_t VECTORIZED_BLOCK_TILE_SIZE_K{BLOCK_TILE_SIZE_K /
                                                  NUM_VECTOR_UNITS};
    constexpr size_t VECTORIZED_BLOCK_TILE_SIZE_X{BLOCK_TILE_SIZE_X /
                                                  NUM_VECTOR_UNITS};

    // The skew size could affect the data alignment in shared memory when we
    // use vectorized load. We need to make sure the data alignment is correct.
    static_assert((BLOCK_TILE_SIZE_Y) * sizeof(T) % sizeof(VECTOR_TYPE) == 0U);
    static_assert((BLOCK_TILE_SIZE_X) * sizeof(T) % sizeof(VECTOR_TYPE) == 0U);
    static_assert((BLOCK_TILE_SIZE_Y + BLOCK_TILE_SKEW_SIZE_Y) * sizeof(T) %
                      sizeof(VECTOR_TYPE) ==
                  0U);
    static_assert((BLOCK_TILE_SIZE_X + BLOCK_TILE_SKEW_SIZE_X) * sizeof(T) %
                      sizeof(VECTOR_TYPE) ==
                  0U);

    // Load data from A on DRAM to A_thread_block_tile on shared memory.
#pragma unroll
    for (size_t load_idx{0U};
         load_idx <
         (BLOCK_TILE_SIZE_Y * VECTORIZED_BLOCK_TILE_SIZE_K + NUM_THREADS - 1U) /
             NUM_THREADS;
         ++load_idx)
    {
        size_t const A_thread_block_tile_row_idx{
            (thread_linear_idx + load_idx * NUM_THREADS) /
            VECTORIZED_BLOCK_TILE_SIZE_K};
        size_t const A_thread_block_tile_col_idx{
            (thread_linear_idx + load_idx * NUM_THREADS) %
            VECTORIZED_BLOCK_TILE_SIZE_K * NUM_VECTOR_UNITS};
        size_t const A_row_idx{blockIdx.y * BLOCK_TILE_SIZE_Y +
                               A_thread_block_tile_row_idx};
        size_t const A_col_idx{thread_block_tile_idx * BLOCK_TILE_SIZE_K +
                               A_thread_block_tile_col_idx};

        // These boundary checks might slow down the kernel to some extent,
        // but they guarantee the correctness of the kernel for all GEMM
        // configurations.
        int4 A_row_vector_vals{0, 0, 0, 0};
        if (A_row_idx < m && A_col_idx < k)
        {
            A_row_vector_vals =
                *reinterpret_cast<int4 const*>(&A[A_row_idx * lda + A_col_idx]);
        }
        if (A_col_idx + NUM_VECTOR_UNITS > k)
        {
            // Number of invalid elements in the last vector.
            size_t const num_invalid_elements{A_col_idx + NUM_VECTOR_UNITS - k};
            // Mask out the invalid elements.
            T* const A_row_vector_vals_ptr{
                reinterpret_cast<T*>(&A_row_vector_vals)};
            for (size_t i{0U}; i < num_invalid_elements; ++i)
            {
                A_row_vector_vals_ptr[NUM_VECTOR_UNITS - 1U - i] =
                    static_cast<T>(0);
            }
        }
        // If this static assertion holds, the following if can be removed.
        // static_assert(VECTORIZED_BLOCK_TILE_SIZE_K * BLOCK_TILE_SIZE_Y %
        //                   NUM_THREADS ==
        //               0U);
        if (A_thread_block_tile_row_idx < BLOCK_TILE_SIZE_Y &&
            A_thread_block_tile_col_idx < BLOCK_TILE_SIZE_K)
        {
            // Transpose while writing to shared memory: the vector loaded from
            // a row of A is scattered element by element into a column of the
            // transposed tile.
            for (size_t i{0U}; i < NUM_VECTOR_UNITS; ++i)
            {
                A_thread_block_tile_transposed[A_thread_block_tile_col_idx +
                                               i][A_thread_block_tile_row_idx] =
                    reinterpret_cast<T const*>(&A_row_vector_vals)[i];
            }
        }
    }
    // Load data from B on DRAM to B_thread_block_tile on shared memory.
#pragma unroll
    for (size_t load_idx{0U};
         load_idx <
         (BLOCK_TILE_SIZE_K * VECTORIZED_BLOCK_TILE_SIZE_X + NUM_THREADS - 1U) /
             NUM_THREADS;
         ++load_idx)
    {
        size_t const B_thread_block_tile_row_idx{
            (thread_linear_idx + load_idx * NUM_THREADS) /
            VECTORIZED_BLOCK_TILE_SIZE_X};
        size_t const B_thread_block_tile_col_idx{
            (thread_linear_idx + load_idx * NUM_THREADS) %
            VECTORIZED_BLOCK_TILE_SIZE_X * NUM_VECTOR_UNITS};
        size_t const B_row_idx{thread_block_tile_idx * BLOCK_TILE_SIZE_K +
                               B_thread_block_tile_row_idx};
        size_t const B_col_idx{blockIdx.x * BLOCK_TILE_SIZE_X +
                               B_thread_block_tile_col_idx};

        // These boundary checks might slow down the kernel to some extent,
        // but they guarantee the correctness of the kernel for all GEMM
        // configurations.
        int4 B_row_vector_vals{0, 0, 0, 0};
        if (B_row_idx < k && B_col_idx < n)
        {
            B_row_vector_vals =
                *reinterpret_cast<int4 const*>(&B[B_row_idx * ldb + B_col_idx]);
        }
        if (B_col_idx + NUM_VECTOR_UNITS > n)
        {
            // Number of invalid elements in the last vector.
            size_t const num_invalid_elements{B_col_idx + NUM_VECTOR_UNITS - n};
            // Mask out the invalid elements.
            T* const B_row_vector_vals_ptr{
                reinterpret_cast<T*>(&B_row_vector_vals)};
            for (size_t i{0U}; i < num_invalid_elements; ++i)
            {
                B_row_vector_vals_ptr[NUM_VECTOR_UNITS - 1U - i] =
                    static_cast<T>(0);
            }
        }
        // If this static assertion holds, the following if can be removed.
        // static_assert(VECTORIZED_BLOCK_TILE_SIZE_X * BLOCK_TILE_SIZE_K %
        //                   NUM_THREADS ==
        //               0U);
        if (B_thread_block_tile_row_idx < BLOCK_TILE_SIZE_K &&
            B_thread_block_tile_col_idx < BLOCK_TILE_SIZE_X)
        {
            *reinterpret_cast<int4*>(
                &B_thread_block_tile[B_thread_block_tile_row_idx]
                                    [B_thread_block_tile_col_idx]) =
                B_row_vector_vals;
        }
    }
}

#endif // CUDA_GEMM_UTILS_CUH
--------------------------------------------------------------------------------
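Usage note (editor's sketch, not part of the repository): the loaders above are cooperative device helpers, meant to be called by every thread of a block once per K-slice of the block tile, with __syncthreads() separating the load phase from the compute phase. The kernel name, tile sizes, thread count, and the elided compute step below are illustrative assumptions only; they rely on the default skew sizes and VECTOR_TYPE = int4 and assume cuda_gemm_utils.cuh has been included.

// Hypothetical call site for the transposed vectorized loader.
template <typename T>
__global__ void gemm_kernel_sketch(size_t m, size_t n, size_t k, T const* A,
                                   size_t lda, T const* B, size_t ldb)
{
    // Assumed tile configuration; real kernels in this repository pick their
    // own values.
    constexpr size_t BLOCK_TILE_SIZE_X{128U};
    constexpr size_t BLOCK_TILE_SIZE_Y{128U};
    constexpr size_t BLOCK_TILE_SIZE_K{16U};
    constexpr size_t NUM_THREADS{256U};

    // Shared-memory tiles; A is stored transposed so the inner-product loop
    // reads both tiles along contiguous shared-memory rows.
    __shared__ T A_thread_block_tile_transposed[BLOCK_TILE_SIZE_K]
                                               [BLOCK_TILE_SIZE_Y];
    __shared__ T B_thread_block_tile[BLOCK_TILE_SIZE_K][BLOCK_TILE_SIZE_X];

    // Assuming a 1D thread block of NUM_THREADS threads.
    size_t const thread_linear_idx{threadIdx.x};
    size_t const num_thread_block_tiles{(k + BLOCK_TILE_SIZE_K - 1U) /
                                        BLOCK_TILE_SIZE_K};

    for (size_t thread_block_tile_idx{0U};
         thread_block_tile_idx < num_thread_block_tiles;
         ++thread_block_tile_idx)
    {
        // Cooperative, vectorized load of the current K-slice of A and B into
        // shared memory.
        load_data_from_global_memory_to_shared_memory_transposed_vectorized<
            T, BLOCK_TILE_SIZE_X, BLOCK_TILE_SIZE_Y, BLOCK_TILE_SIZE_K,
            NUM_THREADS>(A, lda, B, ldb, A_thread_block_tile_transposed,
                         B_thread_block_tile, thread_block_tile_idx,
                         thread_linear_idx, m, n, k);
        __syncthreads();
        // ... per-thread tile computation using the shared-memory tiles ...
        __syncthreads();
    }
}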