├── .clang-format ├── .gitignore ├── .gitmodules ├── .pre-commit-config.yaml ├── CMakeLists.txt ├── LICENSE ├── Makefile ├── README.md ├── docs ├── README.md └── _static │ └── TiledCUDA_overview.png ├── examples ├── README.md ├── cpp │ ├── 01_gemm │ │ ├── .gitignore │ │ ├── 01_gemm_global_reg │ │ │ ├── gemm.hpp │ │ │ └── main.cu │ │ ├── 02_gemm_all_mem │ │ │ ├── gemm.hpp │ │ │ └── main.cu │ │ ├── CMakeLists.txt │ │ ├── Makefile │ │ └── util.hpp │ ├── 02_fused_two_gemms │ │ ├── .gitignore │ │ ├── CMakeLists.txt │ │ ├── Makefile │ │ ├── fused_gemm.cu │ │ ├── fused_gemm.hpp │ │ └── util.hpp │ └── 03_flash_attention │ │ ├── .gitignore │ │ ├── CMakeLists.txt │ │ ├── flash_attn.cu │ │ ├── flash_attn.hpp │ │ ├── flash_attn_cpu.hpp │ │ └── util.hpp └── python │ ├── gemm │ ├── .gitignore │ ├── README.md │ ├── compile.py │ ├── csrc │ │ └── kernel.h │ ├── entry.py │ ├── gemm.py │ └── main.py │ └── scatter_nd.py ├── include ├── cell │ ├── acc.hpp │ ├── compute │ │ ├── broadcast.hpp │ │ ├── gemm.hpp │ │ ├── map.hpp │ │ ├── math_functor.hpp │ │ ├── mod.hpp │ │ ├── reduce.hpp │ │ └── softmax.hpp │ ├── convert.hpp │ ├── copy │ │ ├── constants.hpp │ │ ├── copy_atom.hpp │ │ ├── global_to_register.hpp │ │ ├── global_to_shared.hpp │ │ ├── mod.hpp │ │ ├── register.hpp │ │ ├── shared_to_register.hpp │ │ └── warp.hpp │ ├── mod.hpp │ ├── sync.hpp │ ├── traits │ │ └── base.hpp │ └── warp.hpp ├── config.hpp ├── cuda_info.hpp ├── cuda_utils.hpp ├── errors.hpp ├── kernels │ ├── flash_attn.hpp │ ├── mod.hpp │ └── scatter_nd.hpp ├── types │ ├── global.hpp │ ├── global_tile_iterator.hpp │ ├── layout.hpp │ ├── mod.hpp │ ├── register.hpp │ ├── shared.hpp │ ├── shared_tile_iterator.hpp │ └── tile_shape.hpp └── util │ ├── cuda_timer.hpp │ ├── debug.hpp │ └── print.hpp ├── pytiledcuda ├── .gitignore └── __init__.py ├── scripts ├── clang_format.hook ├── cmake │ ├── dependencies.cmake │ ├── external │ │ └── glog.cmake │ ├── generic.cmake │ └── public │ │ └── glog.cmake └── unittests │ ├── python.sh │ └── run_all_cpp_tests.sh ├── src ├── CMakeLists.txt ├── cuda_info.cc ├── cuda_utils.cc ├── kernels │ ├── flash_attn.cu │ └── scatter_nd.cu └── torch_bind.cc └── tests ├── cpp ├── CMakeLists.txt ├── cell │ ├── test_broadcast.cu │ ├── test_flash_attn.cu │ ├── test_g2r_copy.cu │ ├── test_g2s_load.cu │ ├── test_gemm.cu │ ├── test_global_tile_iterator.cu │ ├── test_layout.cu │ ├── test_reduce.cu │ ├── test_s2r_copy.cu │ ├── test_single_wmma.cu │ ├── test_softmax.cu │ └── test_swizzled_copy.cu ├── common │ ├── test_utils.cc │ └── test_utils.hpp ├── kernels │ └── test_scatter_nd.cu └── test_unit.cc └── python ├── context.py ├── test_flash_attn.py └── test_scatter_nd.py /.clang-format: -------------------------------------------------------------------------------- 1 | # Run manually to reformat a file: 2 | # clang-format -i --style=file 3 | BasedOnStyle: Google 4 | ColumnLimit: 80 5 | IndentWidth: 4 6 | AccessModifierOffset: -2 7 | DerivePointerAlignment: false 8 | KeepEmptyLinesAtTheStartOfBlocks: false 9 | SortIncludes: true 10 | IncludeBlocks: Regroup 11 | IncludeCategories: 12 | - Regex: '<([A-Za-z0-9\Q/-_\E])+>' 13 | Priority: 4 14 | - Regex: '<(catch2|boost)\/' 15 | Priority: 3 16 | - Regex: '<([A-Za-z0-9.\Q/-_\E])+>' 17 | Priority: 2 18 | - Regex: '"([A-Za-z0-9.\Q/-_\E])+"' 19 | Priority: 1 20 | 21 | AllowShortLoopsOnASingleLine: true 22 | AllowShortIfStatementsOnASingleLine: true 23 | Cpp11BracedListStyle: true 24 | -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | .vscode/* 2 | build/* 3 | **/__pycache__/* 4 | .DS_Store 5 | **/.DS_Store 6 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "3rd-party/cutlass"] 2 | path = 3rd-party/cutlass 3 | url = git@github.com:NVIDIA/cutlass.git 4 | [submodule "3rd-party/googletest"] 5 | path = 3rd-party/googletest 6 | url = git@github.com:google/googletest.git 7 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/Lucas-C/pre-commit-hooks.git 3 | rev: v1.5.5 4 | hooks: 5 | - id: remove-crlf 6 | files: (?!.*third_party)^.*$ | (?!.*book)^.*$ 7 | - repo: https://github.com/pre-commit/mirrors-yapf.git 8 | rev: v0.32.0 9 | hooks: 10 | - id: yapf 11 | files: (.*\.(py|bzl)|BUILD|.*\.BUILD|WORKSPACE)$ 12 | - repo: https://github.com/pre-commit/pre-commit-hooks 13 | rev: v4.6.0 14 | hooks: 15 | - id: check-added-large-files 16 | - id: check-merge-conflict 17 | - id: check-symlinks 18 | - id: detect-private-key 19 | files: (?!.*third_party)^.*$ | (?!.*book)^.*$ 20 | - id: end-of-file-fixer 21 | - id: check-yaml 22 | - id: check-toml 23 | - id: check-ast 24 | - id: check-executables-have-shebangs 25 | - id: check-shebang-scripts-are-executable 26 | - id: detect-private-key 27 | - id: debug-statements 28 | - repo: local 29 | hooks: 30 | - id: clang-format-with-version-check 31 | name: clang-format 32 | description: Format files with ClangFormat. 33 | entry: bash ./scripts/clang_format.hook -i 34 | language: system 35 | files: \.(c|cc|cxx|cpp|cu|h|cuh|hpp|hxx|proto)$ 36 | - repo: https://github.com/iconmaster5326/cmake-format-pre-commit-hook 37 | rev: v0.6.9 38 | hooks: 39 | - id: cmake-format 40 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # CMake 3.25 is required for CUDA 20. 2 | cmake_minimum_required(VERSION 3.25 FATAL_ERROR) 3 | project(tiledcuda LANGUAGES C CXX CUDA) 4 | 5 | # Prohibit in-source builds 6 | if(${CMAKE_SOURCE_DIR} STREQUAL ${CMAKE_BINARY_DIR}) 7 | message(FATAL_ERROR "In-source build are not supported") 8 | endif() 9 | 10 | set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} 11 | "${PROJECT_SOURCE_DIR}/scripts/cmake") 12 | 13 | option(WITH_TESTING "Build with CTests" ON) 14 | if(WITH_TESTING) 15 | enable_testing() 16 | endif() 17 | 18 | option(ENABLE_DEBUG "Enable debug mode" OFF) 19 | 20 | # this is to be compatible with the latest glog. DO NOT remove it. 
21 | add_compile_definitions(GLOG_USE_GLOG_EXPORT) 22 | 23 | include(generic) 24 | include(dependencies) 25 | 26 | include_directories(include) 27 | add_subdirectory(src) 28 | add_subdirectory(tests/cpp) 29 | 30 | option(BUILD_EXAMPLES "Build TiledCUDA with examples" ON) 31 | if(BUILD_EXAMPLES) 32 | set(EXAMPLES_DIR "${CMAKE_CURRENT_SOURCE_DIR}/examples/cpp") 33 | file(GLOB SUBDIRS "${EXAMPLES_DIR}/*") 34 | 35 | foreach(SUBDIR ${SUBDIRS}) 36 | message(STATUS "Add Example: ${SUBDIR}") 37 | if(IS_DIRECTORY ${SUBDIR}) 38 | message(STATUS "Adding example: ${EXAMPLE_DIR}") 39 | add_subdirectory(${SUBDIR}) 40 | endif() 41 | endforeach() 42 | endif() 43 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 TiledTensor 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | EXAMPLE_DIR := examples 2 | TEST_DIR := tests/python 3 | UNIT_TEST ?= test_lstm_cell 4 | CPP_UT ?= test_copy 5 | CPP_UTS := scripts/unittests/run_all_cpp_tests.sh 6 | 7 | PY_EXAMPLE ?= $(EXAMPLE_DIR)/python/scatter_nd.py 8 | CPP_EXAMPLE ?= $(EXAMPLE_DIR)/cpp/b2b_gemm/b2b_gemm 9 | UNIT ?= $(TEST_DIR)/$(UNIT_TEST).py 10 | 11 | WITH_TEST ?= ON 12 | 13 | BUILD_DIR := build 14 | DYNAMIC_LIB := $(BUILD_DIR)/libtiledcuda.so 15 | 16 | .PHONY: build example unit_test clean 17 | 18 | build: 19 | @mkdir -p build 20 | @cd build && cmake -DWITH_TESTING=$(WITH_TEST) .. && make -j$(proc) 21 | 22 | $(DYNAMIC_LIB): build 23 | 24 | py_example: $(DYNAMIC_LIB) 25 | @python3 $(PY_EXAMPLE) 26 | 27 | cpp_example: $(DYNAMIC_LIB) 28 | @./$(BUILD_DIR)/$(CPP_EXAMPLE) 29 | 30 | unit_test: $(DYNAMIC_LIB) 31 | @python3 $(UNIT) 32 | 33 | unit_test_cpp: $(DYNAMIC_LIB) 34 | @cd $(BUILD_DIR) && ctest -R $(CPP_UT) -V 35 | 36 | unit_test_cpps: $(DYNAMIC_LIB) 37 | @sh $(CPP_UTS) 38 | 39 | clean: 40 | @rm -rf build 41 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TiledCUDA 2 | 3 | **TiledCUDA** is a highly efficient kernel template library designed to elevate CUDA C’s level of abstraction for processing tiles. 
It is designed to be: 4 | 5 | - **Higher-Level Programming**: TiledCUDA offers a set of device kernels for transferring tiles between the CUDA device's three memory hierarchies and for computing tiles. 6 | - **Modularity**: TiledCUDA enables users to construct their applications by processing larger tiles in time and space using the provided BaseTiles. 7 | - **Efficiency**: TiledCUDA offers highly efficient implementations of these device kernels. 8 | 9 | TiledCUDA adopts a hardware bottom-up approach by building kernels around the core concept of the **BaseTile**. The shapes of these BaseTiles align with TensorCore's instruction shape and encapsulate hardware-dependent performance parameters to optimally utilize TensorCore's capabilities. Serving as building blocks, these BaseTiles are then combined to construct larger tiles in both temporal and spatial dimensions, enabling users to process larger tiles composed of BaseTiles for their applications. 10 | 11 | ## Latest News 🔥 12 | 13 | - [2025/01/28] :bangbang: We have **migrated the latest developments and updates** of the TiledCUDA project to a new repository. We invite you to visit and follow our new repository at [microsoft/TileFusion](https://github.com/microsoft/TileFusion). For any questions or further assistance, please feel free to reach out to us at the new repository. 14 | 15 | - [2024/8/30] TiledCUDA supported FlashAttention-v2, [FlashAttention-v2 Example](src/kernels/flash_attn.cu). 16 | 17 | 18 | ## Example 19 | 20 | TiledCUDA implements `GlobalTile`, `SharedTile` and `RegTile` to customize the shape and layout of tiles located in the GPU's three memory hierarchies. Here's an example of a simple GEMM kernel written in TiledCUDA (the complete example can be found in [this directory](https://github.com/TiledTensor/TiledCUDA/tree/master/examples/cpp/gemm)): 21 | 22 |
25 | 26 | (*To simplify the demonstration, this example only involves two memory levels: global memory and registers. TiledCUDA also applies a similar concept to shared memory*.) 27 | 28 | ```cpp 29 | template 32 | __global__ void simple_gemm(const InType* dA, const InType* dB, AccType* dC) { 33 | IteratorA gAs(dA); 34 | RegA rA; 35 | LoaderA loader_a; 36 | 37 | IteratorB gBs(dB); 38 | RegB rB; 39 | LoaderB loader_b; 40 | 41 | RegC acc; 42 | 43 | for (int k = 0; k < IteratorA::sc1; ++k) { 44 | loader_a(gAs(k), rA); 45 | loader_b(gBs(k), rB); 46 | __syncthreads(); 47 | 48 | gemm(rA, rB, acc); 49 | } 50 | __syncthreads(); 51 | 52 | GlobalC gC(dC); 53 | CStorer storer_c; 54 | storer_c(acc, gC); 55 | } 56 | ``` 57 | - The `TileIterator` is used to divide the `GlobalTile` into smaller sub-tiles and iterate over them. Various warp reuse methods are provided to support efficient repeated loading of data by warps within a thread block. TiledCUDA provides efficient loading and storing methods that transfer data between memory hierarchies by utilizing specialized hardware-accelerated instructions. Tiles of data are then cooperatively loaded into the `RegTile`, which is stored in each thread's local register file. 58 | 59 | - Once the data is loaded into a thread's local register file, `gemm` performs matrix multiplication using TensorCore's warp-level matrix multiply-and-accumulate (wmma) instruction on the `BaseTile`s. The specialized data distribution required by TensorCore is automatically maintained by TiledCUDA's `RegTile` layout. 60 | 61 | - After the `gemm` operation is completed, data in the `RegTile` is cooperatively stored back from registers to global memory using the `RegToGlobalStorer`. 62 | 63 | Here is how to declare the `Tile` at each level of memory, use `TileIterator` to chunk large tiles into sub-tiles, and declare loaders and storers to transfer tiles between memory hierarchies. 64 | 65 | ```cpp 66 | using WarpLayout = RowMajor<2, 2>; 67 | 68 | // operand A 69 | using GlobalA = GlobalTile>; 70 | using IteratorA = TileIterator>; 71 | using RegA = RegTile, RowMajor<8, 8>>; 72 | using ALoader = GlobalToRegLoader; 73 | 74 | // operand B 75 | using GlobalB = GlobalTile>; 76 | using IteratorB = TileIterator>; 77 | using RegB = RegTile, ColMajor<8, 4>>; 78 | using BLoader = GlobalToRegLoader; 79 | 80 | // output C 81 | using GlobalC = GlobalTile>; 82 | using RegC = RegTile, RowMajor<8, 8>>; 83 | using CStorer = RegToGlobalStorer; 84 | ``` 85 | 86 | ## Quick Start 87 | 88 | ### Download 89 | 90 | ```bash 91 | git clone git@github.com:TiledTensor/TiledCUDA.git 92 | cd TiledCUDA && git submodule update --init --recursive 93 | ``` 94 | 95 | ### Installation 96 | 97 | TiledCUDA requires a C++20 host compiler, CUDA 12.0 or later, and GCC version 10.0 or higher to support C++20 features. 
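A minimal build sketch, based on the root `Makefile` and `CMakeLists.txt` shown above (the exact CMake flags may need adjusting for your environment):

```bash
# Configure and build through the provided Makefile target
# (roughly: mkdir -p build && cd build && cmake -DWITH_TESTING=ON .. && make -j).
make build

# Or drive CMake directly, e.g. to disable tests and examples:
mkdir -p build && cd build
cmake -DWITH_TESTING=OFF -DBUILD_EXAMPLES=OFF ..
make -j$(nproc)
```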
98 | 99 | ### Unit Test 100 | 101 | - **Run a single unit test**: `make unit_test UNIT_TEST=test_scatter_nd.py` 102 | - **Run all unit tests**: `./scripts/unittests/python.sh` 103 | - **Run a single cpp unit test**: `make unit_test_cpp CPP_UT=test_copy` 104 | - **Run all cpp unit tests**: `make unit_test_cpps` 105 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TiledTensor/TiledCUDA/1abaa17846b5a0cbcddf2613230e6447ce63d62f/docs/README.md -------------------------------------------------------------------------------- /docs/_static/TiledCUDA_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TiledTensor/TiledCUDA/1abaa17846b5a0cbcddf2613230e6447ce63d62f/docs/_static/TiledCUDA_overview.png -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | [TBD] 2 | -------------------------------------------------------------------------------- /examples/cpp/01_gemm/.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | -------------------------------------------------------------------------------- /examples/cpp/01_gemm/01_gemm_global_reg/gemm.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cell/mod.hpp" 4 | #include "types/mod.hpp" 5 | 6 | using namespace tiledcuda; 7 | using namespace tiledcuda::cell; 8 | namespace tl = tile_layout; 9 | 10 | template 11 | using GemmShape = TileShape; 12 | 13 | template 15 | struct GemmTraits { 16 | using BaseShape = traits::BaseTileShape; 17 | static constexpr int kChunkK = 64; 18 | 19 | static constexpr int kThreads = tl::get_numel * 32; 20 | static constexpr int kWarpPerRow = tl::num_rows; 21 | static constexpr int kWarpPerCol = tl::num_cols; 22 | 23 | static constexpr int kM = dim_size<0, WholeShape>; 24 | static constexpr int kN = dim_size<1, WholeShape>; 25 | static constexpr int kK = dim_size<2, WholeShape>; 26 | 27 | // operand A 28 | using GlobalA = GlobalTile>; 29 | using IteratorA = GTileIterator>; 30 | 31 | static constexpr int kAMs = kTM / kWarpPerRow / BaseShape::kTileSize; 32 | static constexpr int kAKs = kChunkK / BaseShape::kTileSize; 33 | using RegA = RegTile, tl::RowMajor>; 34 | 35 | using ALoader = copy::GlobalToRegLoader; 37 | 38 | // operand B 39 | using GlobalB = GlobalTile>; 40 | using IteratorB = GTileIterator>; 41 | 42 | static constexpr int kBKs = kChunkK / BaseShape::kTileSize; 43 | static constexpr int kBNs = kTN / kWarpPerCol / BaseShape::kTileSize; 44 | using RegB = RegTile, tl::ColMajor>; 45 | 46 | using BLoader = copy::GlobalToRegLoader; 48 | 49 | // output C 50 | using GlobalC = GlobalTile>; 51 | 52 | static constexpr int kCMs = kTM / kWarpPerRow / BaseShape::kTileSize; 53 | static constexpr int kCNs = kTN / kWarpPerCol / BaseShape::kTileSize; 54 | using RegC = RegTile, tl::RowMajor>; 55 | 56 | using CStorer = copy::RegToGlobalStorer; 57 | }; 58 | 59 | template 65 | __global__ void simple_gemm(const InType* dA, const InType* dB, AccType* dC) { 66 | int offset_a = blockIdx.x * kTM * kK; 67 | int offset_b = blockIdx.y * kTN * kK; 68 | int offset_c = blockIdx.x * kTM * kN + blockIdx.y * kTN; 69 | 70 | IteratorA gAs(dA + offset_a); 71 | RegA rA; 
72 | ALoader loader_a; 73 | 74 | IteratorB gBs(dB + offset_b); 75 | RegB rB; 76 | BLoader loader_b; 77 | 78 | RegC acc; 79 | GlobalC gC(dC + offset_c); 80 | CStorer storer_c; 81 | 82 | for (int k = 0; k < IteratorA::sc1; ++k) { 83 | loader_a(gAs(k), rA); 84 | loader_b(gBs(k), rB); 85 | __syncthreads(); 86 | 87 | compute::gemm(rA, rB, acc); 88 | } 89 | __syncthreads(); 90 | 91 | storer_c(acc, gC); 92 | } 93 | -------------------------------------------------------------------------------- /examples/cpp/01_gemm/01_gemm_global_reg/main.cu: -------------------------------------------------------------------------------- 1 | #include "gemm.hpp" 2 | #include "util.hpp" 3 | 4 | template 5 | int run_test() { 6 | using InType = __half; 7 | using AccType = float; 8 | 9 | static constexpr int kM = dim_size<0, WholeShape>; 10 | static constexpr int kN = dim_size<1, WholeShape>; 11 | static constexpr int kK = dim_size<2, WholeShape>; 12 | 13 | static constexpr int kTM = dim_size<0, CtaTileShape>; 14 | static constexpr int kTN = dim_size<1, CtaTileShape>; 15 | 16 | thrust::host_vector h_a(kM * kK); 17 | for (int i = 0; i < h_a.size(); ++i) 18 | h_a[i] = static_cast(rand_float()); 19 | 20 | thrust::host_vector h_b(kK * kN); 21 | for (int i = 0; i < h_b.size(); ++i) 22 | h_b[i] = static_cast(rand_float()); 23 | 24 | thrust::host_vector h_c(kM * kN); 25 | thrust::fill(h_c.begin(), h_c.end(), 0.); 26 | 27 | thrust::device_vector d_a = h_a; 28 | thrust::device_vector d_b = h_b; 29 | thrust::device_vector d_c = h_c; 30 | 31 | const InType* A = thrust::raw_pointer_cast(d_a.data()); 32 | const InType* B = thrust::raw_pointer_cast(d_b.data()); 33 | AccType* C = thrust::raw_pointer_cast(d_c.data()); 34 | 35 | using Config = 36 | GemmTraits; 37 | 38 | using RegA = typename Config::RegA; 39 | using RegB = typename Config::RegB; 40 | using RegC = typename Config::RegC; 41 | 42 | using IteratorA = typename Config::IteratorA; 43 | using IteratorB = typename Config::IteratorB; 44 | 45 | int block_x = CeilDiv; 46 | int block_y = CeilDiv; 47 | 48 | std::cout << "kThreads: " << Config::kThreads << std::endl 49 | << "RegA: " << RegA{} << std::endl 50 | << "RegB: " << RegB{} << std::endl 51 | << "RegC: " << RegC{} << std::endl 52 | << "IteratorA: " << IteratorA{} << std::endl 53 | << "IteratorB: " << IteratorB{} << std::endl 54 | << "blocks: [" << block_x << ", " << block_y << "]" << std::endl 55 | << std::endl; 56 | 57 | dim3 dim_grid(block_x, block_y, 1); 58 | dim3 dim_block(Config::kThreads, 1, 1); 59 | simple_gemm 63 | <<>>(A, B, C); 64 | cudaDeviceSynchronize(); 65 | h_c = d_c; 66 | 67 | // check correctness 68 | thrust::device_vector d_c2(kM * kN); 69 | thrust::fill(d_c2.begin(), d_c2.end(), 0.); 70 | 71 | cublas_hgemm(kM, kN, kK, thrust::raw_pointer_cast(d_a.data()), 72 | thrust::raw_pointer_cast(d_b.data()), 73 | thrust::raw_pointer_cast(d_c2.data()), false /*timeit*/); 74 | thrust::host_vector h_c2 = d_c2; 75 | 76 | bool passed = check_results(thrust::raw_pointer_cast(h_c.data()), 77 | thrust::raw_pointer_cast(h_c2.data()), kM * kN); 78 | 79 | if (passed) { 80 | std::cout << "Test passed." << std::endl; 81 | 82 | CudaTimer timer; 83 | timer.start(); 84 | int iters = 20; 85 | for (int i = 0; i < iters; ++i) { 86 | simple_gemm 90 | <<>>(A, B, C); 91 | } 92 | cudaDeviceSynchronize(); 93 | 94 | float time = timer.stop(); 95 | std::cout << std::setprecision(4) << "elapsed time: " << time / iters 96 | << " ms" << std::endl; 97 | 98 | } else 99 | std::cerr << "Test failed." 
<< std::endl; 100 | 101 | return 0; 102 | } 103 | 104 | int main(int argc, char* argv[]) { 105 | run_test, GemmShape<256, 128, 64>, 106 | tl::RowMajor<2, 2>>(); 107 | 108 | return 0; 109 | } 110 | -------------------------------------------------------------------------------- /examples/cpp/01_gemm/02_gemm_all_mem/gemm.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cell/mod.hpp" 4 | #include "types/mod.hpp" 5 | 6 | using namespace tiledcuda; 7 | using namespace tiledcuda::cell; 8 | using namespace tiledcuda::cell::copy; 9 | 10 | namespace tl = tile_layout; 11 | 12 | template 13 | using GemmShape = TileShape; 14 | 15 | template 17 | struct KeGemmTraits { 18 | using BaseShape = traits::BaseTileShape; 19 | 20 | static constexpr int kThreads = tl::get_numel * 32; 21 | static constexpr int kWarpPerRow = tl::num_rows; 22 | static constexpr int kWarpPerCol = tl::num_cols; 23 | 24 | static constexpr int kM = dim_size<0, WholeShape>; 25 | static constexpr int kN = dim_size<1, WholeShape>; 26 | static constexpr int kK = dim_size<2, WholeShape>; 27 | 28 | static constexpr int kTM = dim_size<0, CtaTileShape>; 29 | static constexpr int kTN = dim_size<1, CtaTileShape>; 30 | static constexpr int kTK = dim_size<2, CtaTileShape>; 31 | 32 | static const bool kSwizzled = true; 33 | 34 | // Total data access for operand A in global memory 35 | using GlobalA = GlobalTile>; 36 | // Access a single global tile for operand A 37 | using GIteratorA = GTileIterator>; 38 | 39 | // Shared Tile for operand A 40 | using SharedA = SharedTile, kSwizzled>; 41 | // Access a single register tile for operand A 42 | using SIteratorA = STileIterator>; 43 | 44 | // Register tile for a single thread of operand A 45 | static constexpr int kAMs = kTM / kWarpPerRow / BaseShape::kTileSize; 46 | static constexpr int kAKs = kRK / BaseShape::kTileSize; 47 | using RegA = RegTile, tl::RowMajor>; 48 | 49 | // Loaders for operand A 50 | using G2SLoaderA = GlobalToSharedLoader; 51 | using S2RLoaderA = 52 | SharedToRegLoader; 53 | 54 | // Total data access for operand B in global memory 55 | using GlobalB = GlobalTile>; 56 | // Access a single global tile for operand B 57 | using GIteratorB = GTileIterator>; 58 | 59 | // Shared Tile for operand B 60 | using SharedB = SharedTile, kSwizzled>; 61 | // Access a single register tile for operand B 62 | using SIteratorB = STileIterator>; 63 | 64 | static_assert(GIteratorA::sc1 == GIteratorB::sc0, 65 | "mismatched K dimension!"); 66 | static_assert(SIteratorA::sc1 == SIteratorB::sc0, 67 | "mismatched K dimension!"); 68 | 69 | // Register tile for a single thread of operand A 70 | static constexpr int kBKs = kRK / BaseShape::kTileSize; 71 | static constexpr int kBNs = kTN / kWarpPerCol / BaseShape::kTileSize; 72 | using RegB = RegTile, tl::ColMajor>; 73 | 74 | using G2SLoaderB = GlobalToSharedLoader; 75 | using S2RLoaderB = 76 | SharedToRegLoader; 77 | 78 | // Global Tile for output C 79 | using GlobalC = GlobalTile>; 80 | // Shared Tile for output C 81 | using SharedC = SharedTile, kSwizzled>; 82 | 83 | // Register Tile for output C 84 | static constexpr int kCMs = kTM / kWarpPerRow / BaseShape::kTileSize; 85 | static constexpr int kCNs = kTN / kWarpPerCol / BaseShape::kTileSize; 86 | using RegC = RegTile, tl::RowMajor>; 87 | 88 | using R2SStorerC = RegToSharedStorer; 89 | using S2GStorerC = SharedToGlobalStorer; 90 | }; 91 | 92 | template 103 | __global__ void gemm(const InType* dA, const InType* dB, AccType* dC) { 104 | int 
offset_a = blockIdx.x * kTM * kK; 105 | int offset_b = blockIdx.y * kTN * kK; 106 | int offset_c = blockIdx.x * kTM * kN + blockIdx.y * kTN; 107 | 108 | extern __shared__ __align__(sizeof(double)) unsigned char buf[]; 109 | InType* sA_ptr = reinterpret_cast(buf); 110 | InType* sB_ptr = sA_ptr + SIteratorA::Tile::kNumel; 111 | AccType* sC_ptr = reinterpret_cast(buf); 112 | 113 | // declare tiles, iterators and loaders 114 | GIteratorA gAs(dA + offset_a); 115 | SIteratorA sAs(sA_ptr); 116 | 117 | GIteratorB gBs(dB + offset_b); 118 | SIteratorB sBs(sB_ptr); 119 | 120 | SharedA sA(sA_ptr); 121 | RegA rA; 122 | 123 | SharedB sB(sB_ptr); 124 | RegB rB; 125 | 126 | RegC acc; 127 | SharedC sC(sC_ptr); 128 | GlobalC gC(dC + offset_c); 129 | 130 | G2SLoaderA g2s_a; 131 | S2RLoaderA s2r_a; 132 | 133 | G2SLoaderB g2s_b; 134 | S2RLoaderB s2r_b; 135 | 136 | R2SStorerC r2s_c; 137 | S2GStorerC s2g_c; 138 | 139 | for (int k1 = 0; k1 < GIteratorA::sc1; ++k1) { 140 | g2s_a(gAs(k1), sA); 141 | g2s_b(gBs(k1), sB); 142 | __copy_async(); 143 | __syncthreads(); 144 | 145 | for (int k2 = 0; k2 < SIteratorA::sc1; ++k2) { 146 | s2r_a(sAs(k2), rA); 147 | s2r_b(sBs(k2), rB); 148 | 149 | compute::gemm(rA, rB, acc); 150 | } 151 | } 152 | r2s_c(acc, sC); 153 | __syncthreads(); 154 | s2g_c(sC, gC); 155 | } 156 | -------------------------------------------------------------------------------- /examples/cpp/01_gemm/02_gemm_all_mem/main.cu: -------------------------------------------------------------------------------- 1 | #include "gemm.hpp" 2 | #include "util.hpp" 3 | 4 | void run_test() { 5 | using WholeShape = GemmShape<4096, 4096, 4096>; 6 | using CtaTileShape = GemmShape<64, 256, 32>; 7 | using WarpLayout = tl::RowMajor<2, 2>; 8 | static constexpr int kRK = 32; 9 | 10 | using InType = __half; 11 | using AccType = float; 12 | 13 | static constexpr int kM = dim_size<0, WholeShape>; 14 | static constexpr int kN = dim_size<1, WholeShape>; 15 | static constexpr int kK = dim_size<2, WholeShape>; 16 | 17 | static constexpr int kTM = dim_size<0, CtaTileShape>; 18 | static constexpr int kTN = dim_size<1, CtaTileShape>; 19 | static constexpr int kTK = dim_size<2, CtaTileShape>; 20 | 21 | thrust::host_vector h_a(kM * kK); 22 | for (int i = 0; i < h_a.size(); ++i) 23 | h_a[i] = static_cast(rand_float()); 24 | 25 | thrust::host_vector h_b(kK * kN); 26 | for (int i = 0; i < h_b.size(); ++i) 27 | h_b[i] = static_cast(rand_float()); 28 | 29 | thrust::host_vector h_c(kM * kN); 30 | thrust::fill(h_c.begin(), h_c.end(), 0.); 31 | 32 | thrust::device_vector d_a = h_a; 33 | thrust::device_vector d_b = h_b; 34 | thrust::device_vector d_c = h_c; 35 | 36 | const InType* A = thrust::raw_pointer_cast(d_a.data()); 37 | const InType* B = thrust::raw_pointer_cast(d_b.data()); 38 | AccType* C = thrust::raw_pointer_cast(d_c.data()); 39 | 40 | using Config = KeGemmTraits; 42 | auto kernel = 43 | &gemm; 53 | 54 | static constexpr int smem_size_inputs = kTK * (kTN + kTM) * sizeof(InType); 55 | static constexpr int smem_size_accumulators = kTM * kTN * sizeof(AccType); 56 | static constexpr int smem_size = smem_size_inputs > smem_size_accumulators 57 | ? 
smem_size_inputs 58 | : smem_size_accumulators; 59 | 60 | const int kMaxSmemPerBlock = 48 * 1024; 61 | if (smem_size > kMaxSmemPerBlock) { 62 | cudaFuncSetAttribute( 63 | kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); 64 | } 65 | 66 | int block_x = CeilDiv; 67 | int block_y = CeilDiv; 68 | 69 | dim3 dim_grid(block_x, block_y, 1); 70 | dim3 dim_block(Config::kThreads, 1, 1); 71 | 72 | kernel<<>>(A, B, C); 73 | cudaDeviceSynchronize(); 74 | h_c = d_c; 75 | 76 | // check correctness 77 | thrust::device_vector d_c2(kM * kN); 78 | thrust::fill(d_c2.begin(), d_c2.end(), 0.); 79 | 80 | cublas_hgemm(kM, kN, kK, thrust::raw_pointer_cast(d_a.data()), 81 | thrust::raw_pointer_cast(d_b.data()), 82 | thrust::raw_pointer_cast(d_c2.data()), false /*timeit*/); 83 | thrust::host_vector h_c2 = d_c2; 84 | 85 | bool passed = check_results(thrust::raw_pointer_cast(h_c.data()), 86 | thrust::raw_pointer_cast(h_c2.data()), kM * kN); 87 | 88 | if (passed) { 89 | std::cout << "Test passed." << std::endl; 90 | 91 | int warm_up = 10; 92 | for (int i = 0; i < warm_up; ++i) 93 | kernel<<>>(A, B, C); 94 | cudaDeviceSynchronize(); 95 | 96 | CudaTimer timer; 97 | timer.start(); 98 | int iters = 20; 99 | for (int i = 0; i < iters; ++i) { 100 | kernel<<>>(A, B, C); 101 | } 102 | cudaDeviceSynchronize(); 103 | float time2 = timer.stop() / iters; 104 | 105 | float time1 = cublas_hgemm( 106 | kM, kN, kK, thrust::raw_pointer_cast(d_a.data()), 107 | thrust::raw_pointer_cast(d_b.data()), 108 | thrust::raw_pointer_cast(d_c2.data()), true /*timeit*/); 109 | 110 | std::cout << "cuBLAS\tTiledCUDA\tRatio" << std::endl; 111 | std::cout << std::setprecision(4) << time1 << "\t" << time2 << "\t" 112 | << time2 / time1 << std::endl; 113 | } else { 114 | std::cerr << "Test failed." << std::endl; 115 | } 116 | } 117 | 118 | int main(int argc, char* argv[]) { 119 | run_test(); 120 | return 0; 121 | } 122 | -------------------------------------------------------------------------------- /examples/cpp/01_gemm/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.18 FATAL_ERROR) 2 | project(gemm_examples LANGUAGES C CXX CUDA) 3 | 4 | set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} 5 | "${PROJECT_SOURCE_DIR}/../../../scripts/cmake") 6 | include(generic) 7 | 8 | set(THIRD_PARTY_DIR "${PROJECT_SOURCE_DIR}/../../../3rd-party") 9 | include_directories("${THIRD_PARTY_DIR}/cutlass/include") 10 | include_directories("${PROJECT_SOURCE_DIR}") 11 | include_directories("${PROJECT_SOURCE_DIR}/../../../include") 12 | 13 | add_executable(01_gemm_global_reg 01_gemm_global_reg/main.cu) 14 | target_link_libraries(01_gemm_global_reg ${CUDA_CUBLAS_LIBRARIES}) 15 | 16 | add_executable(02_gemm_all_mem 02_gemm_all_mem/main.cu) 17 | target_link_libraries(02_gemm_all_mem ${CUDA_CUBLAS_LIBRARIES}) 18 | -------------------------------------------------------------------------------- /examples/cpp/01_gemm/Makefile: -------------------------------------------------------------------------------- 1 | BUILD_DIR := build 2 | 3 | .PHONY: build clean 4 | 5 | build: 6 | @mkdir -p $(BUILD_DIR) 7 | @cd $(BUILD_DIR) && cmake .. 
&& make -j$(proc) 8 | 9 | clean: 10 | @rm -rf $(BUILD_DIR) 11 | -------------------------------------------------------------------------------- /examples/cpp/01_gemm/util.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "util/cuda_timer.hpp" 4 | 5 | #include 6 | #include 7 | 8 | #include 9 | 10 | float rand_float(float a = 1e-4, float b = 1e-2) { 11 | float random = ((float)rand()) / (float)RAND_MAX; 12 | float diff = b - a; 13 | float r = random * diff; 14 | return a + r; 15 | } 16 | 17 | float cublas_hgemm(int64_t kM, int64_t kN, int64_t kK, // problem shape 18 | const __half* A, const __half* B, __half* C, 19 | bool timeit = false, int warm_up = 5, int iters = 20) { 20 | cublasHandle_t handle; 21 | cublasCreate(&handle); 22 | 23 | __half alf = static_cast<__half>(1.); 24 | __half bet = static_cast<__half>(0.); 25 | 26 | float elapsed = 0.; 27 | 28 | if (timeit) { 29 | for (int i = 0; i < warm_up; ++i) { 30 | cublasHgemm(handle, CUBLAS_OP_T /* transb*/, CUBLAS_OP_N, kN, kM, 31 | kK, &alf, B, kK, A, kK, &bet, C, kN); 32 | } 33 | cudaDeviceSynchronize(); 34 | 35 | CudaTimer timer; 36 | timer.start(); 37 | for (int i = 0; i < iters; ++i) { 38 | cublasHgemm(handle, CUBLAS_OP_T /* transb*/, CUBLAS_OP_N, kN, kM, 39 | kK, &alf, B, kK, A, kK, &bet, C, kN); 40 | } 41 | cudaDeviceSynchronize(); 42 | elapsed = timer.stop() / iters; 43 | } else { 44 | cublasHgemm(handle, CUBLAS_OP_T /* transb*/, CUBLAS_OP_N, kN, kM, kK, 45 | &alf, B, kK, A, kK, &bet, C, kN); 46 | } 47 | cudaDeviceSynchronize(); 48 | 49 | cublasDestroy(handle); 50 | return elapsed; 51 | } 52 | 53 | bool check_results(const float* values1, const __half* values2, int numel) { 54 | bool passed = true; 55 | const float epsilon = 1e-3; 56 | 57 | double total_diff = 0.; 58 | double max_abs_diff = FLT_MIN; 59 | double diff = 0.; 60 | 61 | #ifdef DEBUG 62 | int cut_off = 128; 63 | printf("ground truth:\n"); 64 | for (int i = 0; i < cut_off; ++i) { 65 | printf("%.3f, ", __half2float(values2[i])); 66 | if (i && (i + 1) % 16 == 0) printf("\n"); 67 | } 68 | printf("\ncomputed values:\n"); 69 | for (int i = 0; i < cut_off; ++i) { 70 | printf("%.3f, ", values1[i]); 71 | if (i && (i + 1) % 16 == 0) printf("\n"); 72 | } 73 | #endif 74 | 75 | for (int i = 0; i < numel; ++i) { 76 | float v1 = values1[i]; 77 | float v2 = __half2float(values2[i]); 78 | diff = fabs(v1 - v2); 79 | max_abs_diff = max_abs_diff < diff ? diff : max_abs_diff; 80 | total_diff += diff; 81 | 82 | #ifdef DEBUG 83 | if (diff > epsilon) { 84 | printf("the %d-th value differs (%.4f): %.4f vs. 
%.4f\n", i, diff, 85 | v1, v2); 86 | } 87 | #endif 88 | } 89 | 90 | double avg_diff = total_diff / numel; 91 | if (avg_diff > epsilon) passed = false; 92 | 93 | return passed; 94 | } 95 | -------------------------------------------------------------------------------- /examples/cpp/02_fused_two_gemms/.gitignore: -------------------------------------------------------------------------------- 1 | build/ 2 | -------------------------------------------------------------------------------- /examples/cpp/02_fused_two_gemms/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.18 FATAL_ERROR) 2 | project(b2b_gemm_example LANGUAGES C CXX CUDA) 3 | 4 | set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} 5 | "${PROJECT_SOURCE_DIR}/../../../scripts/cmake") 6 | include(generic) 7 | 8 | set(THIRD_PARTY_DIR "${PROJECT_SOURCE_DIR}/../../../3rd-party") 9 | include_directories(${THIRD_PARTY_DIR}/cutlass/include) 10 | include_directories("${PROJECT_SOURCE_DIR}/../../../include") 11 | 12 | add_executable(fused_gemms fused_gemm.cu) 13 | target_link_libraries(fused_gemms ${CUDA_CUBLAS_LIBRARIES}) 14 | -------------------------------------------------------------------------------- /examples/cpp/02_fused_two_gemms/Makefile: -------------------------------------------------------------------------------- 1 | BUILD_DIR := build 2 | 3 | .PHONY: build clean 4 | 5 | build: 6 | @mkdir -p $(BUILD_DIR) 7 | @cd $(BUILD_DIR) && cmake .. && make -j$(proc) 8 | 9 | clean: 10 | @rm -rf $(BUILD_DIR) 11 | -------------------------------------------------------------------------------- /examples/cpp/02_fused_two_gemms/fused_gemm.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cell/compute/mod.hpp" 4 | #include "cell/mod.hpp" 5 | #include "types/mod.hpp" 6 | 7 | using namespace tiledcuda; 8 | using namespace tiledcuda::cell; 9 | using namespace tiledcuda::cell::copy; 10 | namespace tl = tile_layout; 11 | 12 | template 14 | struct FusedGemmTraits { 15 | using BaseShape = traits::BaseTileShape; 16 | 17 | static constexpr int kWarpPerRow = tl::num_rows; 18 | static constexpr int kWarpPerCol = tl::num_cols; 19 | static_assert(kWarpPerCol == 1, "WarpPerCol must be 1"); 20 | 21 | static constexpr int kThreads = tl::get_numel * 32; 22 | 23 | static constexpr int kM = dim_size<0, WholeShape>; 24 | static constexpr int kN = dim_size<1, WholeShape>; 25 | static constexpr int kK = dim_size<2, WholeShape>; 26 | static constexpr int kP = dim_size<3, WholeShape>; 27 | 28 | static constexpr int kTM = dim_size<0, CtaTileShape>; 29 | static constexpr int kTN = dim_size<1, CtaTileShape>; 30 | static constexpr int kTK = dim_size<2, CtaTileShape>; 31 | static constexpr int kTP = dim_size<3, CtaTileShape>; 32 | 33 | // operand A 34 | using GlobalA = GlobalTile>; 35 | // chunk the K dimension to fit into shared memory 36 | using GIteratorA = GTileIterator>; 37 | 38 | static const bool kUseSwizzling = true; 39 | 40 | using SharedA = SharedTile, kUseSwizzling>; 41 | 42 | static constexpr int kAMs = kTM / kWarpPerRow / BaseShape::kTileSize; 43 | static constexpr int kAKs = kTK / BaseShape::kTileSize; 44 | using RegA = RegTile, tl::RowMajor>; 45 | 46 | using SharedALoader = GlobalToSharedLoader; 47 | using RegALoader = 48 | SharedToRegLoader; 49 | 50 | // operand B 51 | using GlobalB = GlobalTile>; 52 | using GIteratorB = GTileIterator>; 53 | using SharedB = SharedTile, kUseSwizzling>; 54 | 55 | static constexpr 
int kBKs = kTK / BaseShape::kTileSize; 56 | static constexpr int kBNs = kTN / kWarpPerCol / BaseShape::kTileSize; 57 | using RegB = RegTile, tl::ColMajor>; 58 | 59 | using SharedBLoader = GlobalToSharedLoader; 60 | using RegBLoader = 61 | SharedToRegLoader; 62 | 63 | // operand C 64 | using GlobalC = GlobalTile>; 65 | // chunk the N dimension to fit into shared memory 66 | using GIteratorC = GTileIterator>; 67 | using SharedC = SharedTile, kUseSwizzling>; 68 | 69 | static constexpr int kCNs = kTN / BaseShape::kTileSize; 70 | static constexpr int kCPs = kTP / kWarpPerCol / BaseShape::kTileSize; 71 | using RegC = RegTile, tl::ColMajor>; 72 | 73 | using SharedCLoader = GlobalToSharedLoader; 74 | using RegCLoader = 75 | SharedToRegLoader; 76 | 77 | // output D 78 | using GlobalD = GlobalTile>; 79 | 80 | static constexpr int kDMs = kTM / kWarpPerRow / BaseShape::kTileSize; 81 | static constexpr int kDPs = kTP / kWarpPerCol / BaseShape::kTileSize; 82 | using RegD = RegTile, tl::RowMajor>; 83 | using DStorer = copy::RegToGlobalStorer; 84 | 85 | static constexpr int kAccMs = kTM / kWarpPerRow / BaseShape::kTileSize; 86 | static constexpr int kAccNs = kTN / kWarpPerCol / BaseShape::kTileSize; 87 | 88 | // Reg Acc 89 | using RegAcc = 90 | RegTile, tl::RowMajor>; 91 | using RegAccCast = 92 | RegTile, tl::RowMajor>; 93 | 94 | // Convert the accumulator to half 95 | using ConvertHalf = compute::RegTileConvert; 96 | }; 97 | 98 | template 107 | __global__ void KeFusedGemm(const InType* dA, const InType* dB, 108 | const InType* dC, AccType* dD, int kM, int kN, 109 | int kK, int kP, int kTM, int kTN, int kTK, 110 | int kTP) { 111 | // Advance to the global data tile to the current CTA. 112 | const InType* A = dA + blockIdx.z * (kM * kK) + blockIdx.x * (kTM * kK); 113 | const InType* B = dB + blockIdx.z * (kK * kN); 114 | const InType* gC_ptr = 115 | dC + blockIdx.z * (kN * kP) + blockIdx.y * (kTP * kN); 116 | 117 | AccType* gD_ptr = dD + blockIdx.z * (kM * kP) + blockIdx.x * (kTM * kP) + 118 | (blockIdx.y * kTP); 119 | 120 | extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[]; 121 | auto* shm = reinterpret_cast(shared_buf); 122 | 123 | InType* sA_ptr = shm; 124 | InType* sB_ptr = shm + SharedA::kNumel; 125 | InType* sC_ptr = shm + SharedA::kNumel + SharedB::kNumel; 126 | 127 | GIteratorA gAs(A); 128 | SharedA sA(sA_ptr); 129 | RegA rA; 130 | 131 | SharedALoader load_sa; 132 | RegALoader load_ra; 133 | 134 | GIteratorB gBs(B); 135 | SharedB sB(sB_ptr); 136 | RegB rB; 137 | 138 | SharedBLoader load_sb; 139 | RegBLoader load_rb; 140 | 141 | GIteratorC gCs(gC_ptr); 142 | SharedC sC(sC_ptr); 143 | 144 | SharedCLoader load_sc; 145 | RegCLoader load_rc; 146 | RegC rC; 147 | 148 | RegD rD; 149 | RegAcc acc; 150 | RegAccCast acc_half; 151 | 152 | for (int n = 0; n < GIteratorC::sc0; ++n) { 153 | load_sc(gCs(n), sC); 154 | 155 | for (int k = 0; k < GIteratorA::sc1; ++k) { 156 | load_sa(gAs(k), sA); 157 | load_sb(gBs(k, n), sB); 158 | __copy_async(); 159 | __syncthreads(); 160 | 161 | load_ra(sA, rA); 162 | load_rb(sB, rB); 163 | __syncthreads(); 164 | 165 | compute::gemm(rA, rB, acc); 166 | } 167 | load_rc(sC, rC); 168 | __syncthreads(); 169 | 170 | ConvertAcc cast_acc; // Convert acc to half precision 171 | cast_acc(acc, acc_half); 172 | 173 | compute::gemm(acc_half, rC, rD); 174 | acc.clear(); 175 | } 176 | __syncthreads(); 177 | 178 | GlobalD gD(gD_ptr); 179 | DStorer storer_d; // Store D tile from register to global. 
180 | storer_d(rD, gD); 181 | } 182 | -------------------------------------------------------------------------------- /examples/cpp/02_fused_two_gemms/util.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "util/cuda_timer.hpp" 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | template 11 | using FusedGemmShape = TileShape; 12 | 13 | float rand_float(float a = 1e-1, float b = 5e-2) { 14 | float random = ((float)rand()) / (float)RAND_MAX; 15 | float diff = b - a; 16 | float r = random * diff; 17 | return a + r; 18 | } 19 | 20 | /* In this implementation, A and D are interpreted as being laid out in 21 | row-major, and B, C is interpreted as being laid out in column-major. 22 | 23 | A and D are laid out in row-major fashion 24 | B and C are laid out in column-major fashion 25 | 26 | acc[m, n] = A[m, k] @ B[k, n] 27 | D[m, p] = acc[m, n] @ C[n, p] 28 | */ 29 | float cublas_two_gemms(int kM, int kN, int kK, int kP, int kBatch, 30 | const __half* As, const __half* Bs, const __half* Cs, 31 | __half* Ds, __half* accs, bool timeit = false, 32 | int warmup = 10, int iters = 20) { 33 | cublasHandle_t handle; 34 | cublasCreate(&handle); 35 | 36 | __half alf = static_cast<__half>(1.); 37 | __half bet = static_cast<__half>(0.); 38 | 39 | const __half* A = As; 40 | const __half* B = Bs; 41 | const __half* C = Cs; 42 | __half* acc = accs; 43 | __half* D = Ds; 44 | 45 | for (int b = 0; b < kBatch; ++b) { 46 | A += b * kM * kK; 47 | B += b * kK * kN; 48 | C += b * kM * kN; 49 | acc += b * kM * kN; 50 | D += b * kM * kP; 51 | 52 | // acc = A @ B 53 | // acc^T = B^T @ A^T 54 | // [n, m] = [n, k] @ [k, m] 55 | cublasHgemm(handle, CUBLAS_OP_T, CUBLAS_OP_N /* transb*/, kN, kM, kK, 56 | &alf, B, kK, A, kK, &bet, acc, kN); 57 | 58 | // D and acc are laid out in row-major fashion, while C is in column 59 | // major fashion. Operands of cuBLAS is by default in column fashion. 
60 | // D = acc @ C 61 | // D^T = C^T @ acc^T; [p, m] = [p, n] @ [n, m] 62 | cublasHgemm(handle, CUBLAS_OP_T, CUBLAS_OP_N /* transb*/, kP, kM, kN, 63 | &alf, C, kN, acc, kN, &bet, D, kP); 64 | } 65 | 66 | float elapsed = 0.; 67 | if (timeit) { 68 | for (int i = 0; i < warmup; ++i) { 69 | A = As; 70 | B = Bs; 71 | C = Cs; 72 | acc = accs; 73 | D = Ds; 74 | for (int b = 0; b < kBatch; ++b) { 75 | A += b * kM * kK; 76 | B += b * kK * kN; 77 | C += b * kM * kN; 78 | acc += b * kM * kN; 79 | D += b * kM * kP; 80 | 81 | cublasHgemm(handle, CUBLAS_OP_T, CUBLAS_OP_N /* transb*/, kN, 82 | kM, kK, &alf, B, kK, A, kK, &bet, acc, kN); 83 | 84 | cublasHgemm(handle, CUBLAS_OP_T, CUBLAS_OP_N /* transb*/, kP, 85 | kM, kN, &alf, C, kN, acc, kN, &bet, D, kP); 86 | } 87 | } 88 | cudaDeviceSynchronize(); 89 | 90 | CudaTimer timer; 91 | timer.start(); 92 | for (int i = 0; i < iters; ++i) { 93 | A = As; 94 | B = Bs; 95 | C = Cs; 96 | acc = accs; 97 | D = Ds; 98 | for (int b = 0; b < kBatch; ++b) { 99 | A += b * kM * kK; 100 | B += b * kK * kN; 101 | C += b * kM * kN; 102 | acc += b * kM * kN; 103 | D += b * kM * kP; 104 | 105 | cublasHgemm(handle, CUBLAS_OP_T, CUBLAS_OP_N /* transb*/, kN, 106 | kM, kK, &alf, B, kK, A, kK, &bet, acc, kN); 107 | 108 | cublasHgemm(handle, CUBLAS_OP_T, CUBLAS_OP_N /* transb*/, kP, 109 | kM, kN, &alf, C, kN, acc, kN, &bet, D, kP); 110 | } 111 | } 112 | cudaDeviceSynchronize(); 113 | elapsed = timer.stop() / iters; 114 | } 115 | 116 | cublasDestroy(handle); 117 | 118 | return elapsed; 119 | } 120 | 121 | bool check_results(const float* values1, const __half* values2, int numel, 122 | float epsilon) { 123 | bool passed = true; 124 | 125 | float v2 = 0.; 126 | 127 | double total_diff = 0.; 128 | double max_abs_diff = FLT_MIN; 129 | double diff = 0.; 130 | 131 | for (int i = 0; i < numel; ++i) { 132 | v2 = __half2float(values2[i]); 133 | diff = abs(values1[i] - v2); 134 | max_abs_diff = max_abs_diff < diff ? diff : max_abs_diff; 135 | total_diff += diff; 136 | 137 | #ifdef DEBUG 138 | if (diff > epsilon) { 139 | printf("%d-th value has large differences: %.3f vs. %.3f\n", i, 140 | values1[i], v2); 141 | } 142 | #endif 143 | } 144 | 145 | double avg_diff = total_diff / numel; 146 | if (avg_diff > epsilon) passed = false; 147 | 148 | return passed; 149 | } 150 | -------------------------------------------------------------------------------- /examples/cpp/03_flash_attention/.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | -------------------------------------------------------------------------------- /examples/cpp/03_flash_attention/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.18 FATAL_ERROR) 2 | project(flash_attn_example LANGUAGES C CXX CUDA) 3 | 4 | set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} 5 | "${PROJECT_SOURCE_DIR}/../../../scripts/cmake") 6 | include(generic) 7 | 8 | set(THIRD_PARTY_DIR "${PROJECT_SOURCE_DIR}/../../../3rd-party") 9 | include_directories(${THIRD_PARTY_DIR}/cutlass/include) 10 | include_directories("${PROJECT_SOURCE_DIR}/../../../include") 11 | 12 | add_executable(flash_attn flash_attn.cu) 13 | -------------------------------------------------------------------------------- /examples/cpp/03_flash_attention/flash_attn_cpu.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cuda_utils.hpp" 4 | 5 | __half max(__half a, __half b) { return a > b ? 
a : b; } 6 | 7 | __half exp(__half x) { return __float2half(exp(__half2float(x))); } 8 | 9 | void host_flash_attn(int kM, int kN, int kK, int kP, int kBatch, 10 | const __half* Q, const __half* K, const __half* V, 11 | __half* O, __half* acc, __half* exp_values, 12 | __half* cur_row_max, __half* prev_row_max, 13 | __half* new_row_max, __half* prev_norm_vec, 14 | __half* new_norm_vec, __half* prev_sums, __half* cur_sums, 15 | __half* new_sums) { 16 | #pragma omp parallel for 17 | for (int b = 0; b < kBatch; ++b) { 18 | const __half* q_ = Q + b * kM * kK; 19 | const __half* k_ = K + b * kK * kN; 20 | const __half* v_ = V + b * kN * kP; 21 | __half* o_ = O + b * kM * kP; 22 | 23 | // Compute attention scores: Q * K^T 24 | for (int i = 0; i < kM; ++i) { 25 | for (int j = 0; j < kN; ++j) { 26 | __half s = 0.; 27 | for (int k = 0; k < kK; ++k) { 28 | s += q_[i * kK + k] * k_[k + kK * j]; 29 | } 30 | acc[i * kN + j] = s; 31 | } 32 | } 33 | 34 | // Compute row max of attention scores 35 | for (int i = 0; i < kM; ++i) { 36 | __half max_val = 0.; 37 | for (int j = 0; j < kN; ++j) { 38 | max_val = max(max_val, acc[i * kN + j]); 39 | } 40 | cur_row_max[i] = max_val; 41 | } 42 | 43 | // Broadcast sub row max to attention scores 44 | for (int i = 0; i < kM; ++i) { 45 | for (int j = 0; j < kN; ++j) { 46 | acc[i * kN + j] = exp(acc[i * kN + j] - cur_row_max[i]); 47 | } 48 | } 49 | 50 | // Compute reduce sum for each row and store into `cur_sums`. 51 | for (int i = 0; i < kM; ++i) { 52 | __half sum = 0.; 53 | for (int j = 0; j < kN; ++j) { 54 | sum += acc[i * kN + j]; 55 | } 56 | cur_sums[i] = sum; 57 | } 58 | 59 | // Compare cur row max with prev row max and compute new row max 60 | for (int i = 0; i < kM; ++i) { 61 | new_row_max[i] = max(cur_row_max[i], prev_row_max[i]); 62 | } 63 | 64 | for (int i = 0; i < kM; ++i) { 65 | // Compute remormalization factor for the previous block. 66 | prev_norm_vec[i] = exp(prev_row_max[i] - new_row_max[i]); 67 | 68 | // Compute remormalization factor for the current block. 69 | new_norm_vec[i] = exp(cur_row_max[i] - new_row_max[i]); 70 | } 71 | 72 | // Update normalization factor l(x). 73 | for (int i = 0; i < kM; ++i) { 74 | new_sums[i] = 75 | prev_norm_vec[i] * prev_sums[i] + new_norm_vec[i] * cur_sums[i]; 76 | } 77 | 78 | // Compute unnormalizatied attention score @ values 79 | for (int i = 0; i < kM; ++i) { 80 | for (int j = 0; j < kP; ++j) { 81 | __half s = 0.; 82 | for (int k = 0; k < kN; ++k) { 83 | s += acc[i * kN + k] * v_[k + kN * j]; 84 | } 85 | exp_values[i * kP + j] = s; 86 | } 87 | } 88 | 89 | // Compute O = (O * prev_sums * prev_norm_vec + new_norm_vec * 90 | // exp_values) / new_sums. 
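        // In other words: o_ holds the running output normalized by the
        // previous running sum, so it is first un-normalized (* prev_sums[i])
        // and rescaled to the new running max (* prev_norm_vec[i]); the
        // current block's exp_values is rescaled by new_norm_vec[i]; their
        // sum is then re-normalized by the updated running sum new_sums[i].
        // This is the flash-attention style online-softmax update of the
        // output.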
91 | for (int i = 0; i < kM; ++i) { 92 | for (int j = 0; j < kP; ++j) { 93 | o_[i * kP + j] = 94 | (o_[i * kP + j] * prev_sums[i] * prev_norm_vec[i] + 95 | new_norm_vec[i] * exp_values[i * kP + j]) / 96 | new_sums[i]; 97 | } 98 | } 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /examples/cpp/03_flash_attention/util.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "flash_attn_cpu.hpp" 4 | #include "util/debug.hpp" 5 | 6 | #include 7 | #include 8 | 9 | float rand_float(float a = 1e-1, float b = 5e-2) { 10 | float random = ((float)rand()) / (float)RAND_MAX; 11 | float diff = b - a; 12 | float r = random * diff; 13 | return a + r; 14 | } 15 | 16 | bool check_results(const __half* values1, const __half* values2, int numel) { 17 | bool passed = true; 18 | const float epsilon = 1e-1; 19 | 20 | for (int i = 0; i < numel; ++i) { 21 | if (fabs(__half2float(values1[i]) - __half2float(values2[i])) > 22 | epsilon) { 23 | printf("%d-th value differs: %.3f vs. %.3f\n", i, 24 | __half2float(values1[i]), __half2float(values2[i])); 25 | passed = false; 26 | break; 27 | } 28 | } 29 | return passed; 30 | } 31 | -------------------------------------------------------------------------------- /examples/python/gemm/.gitignore: -------------------------------------------------------------------------------- 1 | tmp 2 | build 3 | __pycache__ 4 | -------------------------------------------------------------------------------- /examples/python/gemm/README.md: -------------------------------------------------------------------------------- 1 | This example demonstrates how to use TiledCuda's macro kernels to compose a simple GEMM and auto-tune some performance-critical parameters. 2 | 3 | In this simple GEMM implementation, data tiles are loaded from global memory directly into a thread's local registers. TensorCore's WMMA is then used to compute GEMM on the registers, and finally, the results are stored back to global memory. 
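As a rough usage sketch (see `gemm.py` and `main.py` in this directory for the complete version), the auto-compiled kernel is invoked through the `gemm_func` wrapper; the tile sizes and warp layout chosen here are illustrative placeholders:

```python
import torch

from gemm import gemm_func

M, N, K = 4096, 4096, 4096
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(N, K, device="cuda", dtype=torch.float16)  # reference is a @ b.t()
c = torch.zeros(M, N, device="cuda", dtype=torch.float32)

# (TM, TN) is the CTA tile, kChunkK the K chunk, (1, 2) the warp layout.
TM, TN, kChunkK, warp_layout = 128, 128, 64, (1, 2)
gemm_func(a, b, c, M, N, K, TM, TN, kChunkK, *warp_layout)

avg_diff = (torch.abs(c - (a @ b.t()).float()).sum() / (M * N)).item()
print(avg_diff)
```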
4 | 5 | > [!Note] 6 | > *This example is for demonstration purposes and does not leverage shared memory.* 7 | 8 | To execute the example, run: 9 | 10 | ```bash 11 | python3 main.py 2>&1 | tee log.tsv 12 | ``` 13 | -------------------------------------------------------------------------------- /examples/python/gemm/compile.py: -------------------------------------------------------------------------------- 1 | import os 2 | import importlib.util 3 | import shutil 4 | from collections import defaultdict 5 | 6 | import subprocess 7 | import ctypes 8 | import torch 9 | 10 | __all__ = [ 11 | "Compile", 12 | ] 13 | 14 | cutlass_include_dir = os.path.join(os.path.dirname(__file__), 15 | "../../../3rd-party/cutlass/include") 16 | tiledcuda_include_dir = os.path.join(os.path.dirname(__file__), 17 | "../../../include/") 18 | csrc_include_dir = os.path.join(os.path.dirname(__file__), "csrc") 19 | 20 | 21 | class Compile: 22 | 23 | def __init__(self, file_prefix, tmp_dir): 24 | self.tmp_dir = tmp_dir 25 | self.file_prefix = file_prefix 26 | 27 | if not os.path.exists(self.tmp_dir): 28 | os.makedirs(self.tmp_dir) 29 | 30 | compute_capability = torch.cuda.get_device_capability() 31 | self.cc = f"{compute_capability[0]}{compute_capability[1]}" 32 | 33 | self.nvcc_path = self._find_nvcc_path() 34 | 35 | def _find_nvcc_path(self): 36 | 37 | def py_str(x): 38 | return x.decode('utf-8') 39 | 40 | if "CUDA_PATH" in os.environ: 41 | return os.environ["CUDA_PATH"] 42 | 43 | cmd = ["which", "nvcc"] 44 | proc = subprocess.Popen(cmd, 45 | stdout=subprocess.PIPE, 46 | stderr=subprocess.STDOUT) 47 | (out, _) = proc.communicate() 48 | 49 | if proc.returncode == 0: 50 | return py_str(out.strip()) 51 | else: 52 | raise RuntimeError("Cannot find cuda path") 53 | 54 | def _create_entry_code( 55 | self, 56 | M: int, 57 | N: int, 58 | K: int, 59 | TM: int, 60 | TN: int, 61 | kChunkK: int, 62 | warp_per_row: int, 63 | warp_per_col: int, 64 | ): 65 | entry_code_path = "entry.py" 66 | spec = importlib.util.spec_from_file_location("entry_code", 67 | entry_code_path) 68 | foo = importlib.util.module_from_spec(spec) 69 | spec.loader.exec_module(foo) 70 | 71 | shape = defaultdict(int) 72 | shape["kM"] = M 73 | shape["kN"] = N 74 | shape["kK"] = K 75 | shape["kTM"] = TM 76 | shape["kTN"] = TN 77 | shape["kChunkK"] = kChunkK 78 | shape["warp_per_row"] = warp_per_row 79 | shape["warp_per_col"] = warp_per_col 80 | 81 | return foo.types.format_map(shape) + foo.entry 82 | 83 | def compile(self, 84 | M: int, 85 | N: int, 86 | K: int, 87 | TM: int, 88 | TN: int, 89 | kChunkK: int, 90 | warp_per_row: int, 91 | warp_per_col: int, 92 | timeout: float = None): 93 | temp_dir = self.tmp_dir 94 | 95 | file_name = (f"{self.file_prefix}_{M}_{N}_{K}" 96 | f"_{TM}_{TN}_{warp_per_row}_{warp_per_col}") 97 | lib_path = os.path.join(temp_dir, f"{file_name}.so") 98 | 99 | if os.path.exists(lib_path): 100 | return lib_path 101 | 102 | entry_code = self._create_entry_code(M, N, K, TM, TN, kChunkK, 103 | warp_per_row, warp_per_col) 104 | 105 | source_path = os.path.join(temp_dir, f"{file_name}.cu") 106 | with open(source_path, "w") as f: 107 | f.write(entry_code) 108 | 109 | if os.path.exists(lib_path): 110 | return lib_path 111 | 112 | command = [ 113 | self.nvcc_path, "-std=c++20", "-O3", "--use_fast_math", 114 | "--expt-relaxed-constexpr", "--disable-warnings", 115 | "--compiler-options", "'-fPIC'", "--shared", source_path, "-lcuda", 116 | f"-gencode=arch=compute_{self.cc},code=sm_{self.cc}", 117 | f"-I{cutlass_include_dir}", 
f"-I{tiledcuda_include_dir}", 118 | f"-I{csrc_include_dir}", "-o", lib_path 119 | ] 120 | try: 121 | ret = subprocess.run(command, timeout=timeout) 122 | except subprocess.TimeoutExpired: 123 | return None 124 | if ret.returncode == 0: 125 | return lib_path 126 | else: 127 | raise RuntimeError("Compilation failed") 128 | 129 | def apply(self, lib_path, torch_array: list, device: int): 130 | lib = ctypes.CDLL(lib_path) 131 | 132 | lib.kernel_entry.restype = ctypes.c_int 133 | torch.cuda.set_device(device) 134 | 135 | ret = lib.kernel_entry( 136 | *[ctypes.c_void_p(arr.data_ptr()) for arr in torch_array]) 137 | return ret 138 | -------------------------------------------------------------------------------- /examples/python/gemm/csrc/kernel.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | using namespace tiledcuda; 4 | using namespace tiledcuda::cell; 5 | using namespace tiledcuda::cell::compute; 6 | 7 | template 13 | __global__ void gemm(const InType* dA, const InType* dB, AccType* dC) { 14 | int offset_a = blockIdx.x * kTM * kK; 15 | int offset_b = blockIdx.y * kTN * kK; 16 | int offset_c = blockIdx.x * kTM * kN + blockIdx.y * kTN; 17 | 18 | IteratorA gAs(dA + offset_a); 19 | RegA rA; 20 | ALoader loader_a; 21 | 22 | IteratorB gBs(dB + offset_b); 23 | RegB rB; 24 | BLoader loader_b; 25 | 26 | RegC acc; 27 | GlobalC gC(dC + offset_c); 28 | CStorer storer_c; 29 | 30 | for (int k = 0; k < IteratorA::sc1; ++k) { 31 | loader_a(gAs(k), rA); 32 | loader_b(gBs(k), rB); 33 | __syncthreads(); 34 | 35 | compute::gemm_(rA, rB, acc); 36 | } 37 | __syncthreads(); 38 | 39 | storer_c(acc, gC); 40 | } 41 | -------------------------------------------------------------------------------- /examples/python/gemm/entry.py: -------------------------------------------------------------------------------- 1 | types = """#include "cell/mod.hpp" 2 | #include "types/mod.hpp" 3 | #include "kernel.h" 4 | 5 | using InType = __half; 6 | using AccType = float; 7 | 8 | static constexpr int kM = {kM}; 9 | static constexpr int kN = {kN}; 10 | static constexpr int kK = {kK}; 11 | 12 | static constexpr int kTM = {kTM}; 13 | static constexpr int kTN = {kTN}; 14 | 15 | using WarpLayout = tl::RowMajor<{warp_per_row}, {warp_per_col}>; 16 | 17 | using BaseShape = traits::BaseTileShape; 18 | static constexpr int kChunkK = {kChunkK}; 19 | 20 | static constexpr int kThreads = tl::get_numel * 32; 21 | static constexpr int kWarpPerRow = tl::num_rows; 22 | static constexpr int kWarpPerCol = tl::num_cols; 23 | 24 | using GlobalA = GlobalTile>; 25 | using IteratorA = GTileIterator>; 26 | 27 | static constexpr int kAMs = kTM / kWarpPerRow / BaseShape::kTileSize; 28 | static constexpr int kAKs = kChunkK / BaseShape::kTileSize; 29 | using RegA = RegTile, tl::RowMajor>; 30 | 31 | using ALoader = copy::GlobalToRegLoader; 33 | 34 | using GlobalB = GlobalTile>; 35 | using IteratorB = GTileIterator>; 36 | 37 | static constexpr int kBKs = kChunkK / BaseShape::kTileSize; 38 | static constexpr int kBNs = kTN / kWarpPerCol / BaseShape::kTileSize; 39 | using RegB = RegTile, tl::ColMajor>; 40 | 41 | using BLoader = copy::GlobalToRegLoader; 43 | 44 | using GlobalC = GlobalTile>; 45 | 46 | static constexpr int kCMs = kTM / kWarpPerRow / BaseShape::kTileSize; 47 | static constexpr int kCNs = kTN / kWarpPerCol / BaseShape::kTileSize; 48 | using RegC = RegTile, tl::RowMajor>; 49 | 50 | using CStorer = copy::RegToGlobalStorer; 51 | 52 | int block_x = CeilDiv; 53 | int block_y = CeilDiv; 54 | """ 55 | 
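# `types` above is a template: Compile._create_entry_code() in compile.py fills
# the {kM}/{kN}/{kK}/{kTM}/{kTN}/{kChunkK}/{warp_per_row}/{warp_per_col}
# placeholders via str.format_map() and concatenates the result with `entry`
# below, producing the .cu source that nvcc compiles into a shared library.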
56 | entry = """ 57 | extern "C" int kernel_entry(__half* parameter1, __half* parameter2, 58 | float* paramter3) { 59 | auto kernel = 60 | &gemm; 62 | 63 | kernel<<>>( 64 | parameter1, parameter2, paramter3); 65 | 66 | return 0; 67 | } 68 | """ 69 | -------------------------------------------------------------------------------- /examples/python/gemm/gemm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | from compile import Compile 5 | 6 | __all__ = [ 7 | "gemm_func", 8 | ] 9 | 10 | 11 | class GemmFunc(torch.autograd.Function): 12 | 13 | @staticmethod 14 | def forward( 15 | ctx, 16 | A: Tensor, 17 | B: Tensor, 18 | C: Tensor, 19 | M: int, 20 | N: int, 21 | K: int, 22 | kM: int, 23 | kN: int, 24 | kChunkK: int, 25 | warp_per_row: int, 26 | warp_per_col: int, 27 | ) -> Tensor: 28 | builder = Compile(file_prefix="gemm", tmp_dir="tmp") 29 | lib_name = builder.compile(M, N, K, kM, kN, kChunkK, warp_per_row, 30 | warp_per_col) 31 | 32 | if lib_name is None: 33 | raise RuntimeError("Failed to compile the library.") 34 | 35 | builder.apply(lib_name, [A, B, C], device=0) 36 | return C 37 | 38 | 39 | gemm_func = GemmFunc.apply 40 | -------------------------------------------------------------------------------- /examples/python/gemm/main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from typing import Tuple 4 | 5 | from gemm import gemm_func 6 | 7 | 8 | def run_unittest(a: Tensor, 9 | b: Tensor, 10 | c: Tensor, 11 | M: int, 12 | N: int, 13 | K: int, 14 | TM: int, 15 | TN: int, 16 | kChunkK: int, 17 | warp_layout: Tuple, 18 | epsilon: float = 5e-2, 19 | debug_print=False): 20 | gemm_func(a, b, c, M, N, K, TM, TN, kChunkK, *warp_layout) 21 | ref_c = a @ b.t() 22 | 23 | if debug_print: 24 | print("Result:") 25 | print(c) 26 | 27 | print("\nReference:") 28 | print(ref_c) 29 | 30 | avg_diff = (torch.sum(torch.abs(ref_c - c)) / (M * N)).item() 31 | if avg_diff > epsilon: 32 | return False 33 | else: 34 | return True 35 | 36 | 37 | def run_test( 38 | M: int, 39 | N: int, 40 | K: int, 41 | TM: int, 42 | TN: int, 43 | kChunkK: int, 44 | warp_layout: Tuple, 45 | ): 46 | device = torch.device("cuda") 47 | dtype = torch.float16 48 | 49 | a = torch.randn(M, K, device=device, dtype=dtype) 50 | b = torch.randn(N, K, device=device, dtype=dtype) 51 | c = torch.zeros(M, N, device=device, dtype=torch.float32) 52 | 53 | if not run_unittest(a, b, c, M, N, K, TM, TN, kChunkK, warp_layout): 54 | raise RuntimeError("Failed unittest.") 55 | 56 | for _ in range(5): # warm up 57 | gemm_func(a, b, c, M, N, K, TM, TN, kChunkK, *warp_layout) 58 | ref_c = a @ b.t() 59 | 60 | start_event = torch.cuda.Event(enable_timing=True) 61 | end_event = torch.cuda.Event(enable_timing=True) 62 | 63 | iters = 50 64 | start_event.record() 65 | for i in range(iters): 66 | gemm_func(a, b, c, M, N, K, TM, TN, kChunkK, *warp_layout) 67 | end_event.record() 68 | torch.cuda.synchronize() 69 | 70 | time1 = start_event.elapsed_time(end_event) / iters 71 | 72 | start_event.record() 73 | for i in range(iters): 74 | ref_c = a @ b.t() 75 | end_event.record() 76 | torch.cuda.synchronize() 77 | 78 | time2 = start_event.elapsed_time(end_event) / iters 79 | return time1, time2 80 | 81 | 82 | if __name__ == "__main__": 83 | M = 4096 84 | N = 4096 85 | K = 4096 86 | 87 | print(("Whole Shape\tBlock Shape\tthreads" 88 | "\ttiledcuda(ms)\tcublass(ms)\tRatio")) 89 | 90 | warp_layout = (1, 2) 
91 | threads = warp_layout[0] * warp_layout[1] * 32 92 | for TM in [64, 128]: 93 | for TN in [64, 128]: 94 | for kChunkK in [32, 64, 128]: 95 | time1, time2 = run_test(M, N, K, TM, TN, kChunkK, warp_layout) 96 | print(("[{}, {}, {}]\t[{}, {}, {}]" 97 | "\t{}\t{:.4f}\t{:.4f}\t{:.3f}").format( 98 | M, N, K, TM, TN, kChunkK, threads, time1, time2, 99 | time1 / time2)) 100 | 101 | for warp_layout in [(2, 2), (2, 4)]: 102 | threads = warp_layout[0] * warp_layout[1] * 32 103 | 104 | for TM in [64, 128, 256]: 105 | for TN in [64, 128, 256]: 106 | for kChunkK in [32, 64, 128]: 107 | time1, time2 = run_test(M, N, K, TM, TN, kChunkK, 108 | warp_layout) 109 | print(("[{}, {}, {}]\t[{}, {}, {}]" 110 | "\t{}\t{:.4f}\t{:.4f}\t{:.3f}").format( 111 | M, N, K, TM, TN, kChunkK, threads, time1, time2, 112 | time1 / time2)) 113 | -------------------------------------------------------------------------------- /examples/python/scatter_nd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import random 4 | from functools import reduce 5 | from operator import mul 6 | 7 | torch.ops.load_library("build/libtiledcuda.so") 8 | 9 | 10 | def compute_output_shape(index_dims, input_dims): 11 | end_size = index_dims[-1] 12 | out_shape = index_dims[:-1] 13 | for i in range(len(input_dims) - end_size): 14 | out_shape.append(input_dims[len(index_dims) + i]) 15 | return out_shape 16 | 17 | 18 | def test_scatter_nd(): 19 | data_shape = [7, 8, 9, 10] 20 | data_numel = reduce(mul, data_shape) 21 | data = torch.empty(data_shape, dtype=torch.float32, 22 | device='cuda').fill_(5.0) 23 | scatter_data = data.flatten() 24 | 25 | indices_shape = [5, 2] 26 | indices_numel = reduce(mul, indices_shape) 27 | indices = torch.empty(indices_shape, dtype=torch.int64, device='cuda') 28 | 29 | for i in range(indices_shape[0]): 30 | # indices[i * indices_shape[1]] = random.randint(0, data_shape[0] - 1) 31 | # indices[i * indices_shape[1] + 32 | # 1] = random.randint(0, data_shape[1] - 1) 33 | indices[i][0] = random.randint(0, data_shape[0] - 1) 34 | indices[i][1] = random.randint(0, data_shape[1] - 1) 35 | 36 | scatter_indices = indices.flatten() 37 | 38 | slice_size = 1 39 | end_size = indices_shape[-1] 40 | for i in range(end_size, len(data_shape)): 41 | slice_size *= data_shape[i] 42 | 43 | update_shape = compute_output_shape(indices_shape, data_shape) 44 | # update_numel = reduce(mul, update_shape) 45 | updates = torch.empty(update_shape, dtype=torch.float32, 46 | device='cuda').fill_(10.0) 47 | scatter_updates = updates.flatten() 48 | 49 | torch.ops.tiledcuda.scatter_nd(scatter_data, scatter_updates, 50 | scatter_indices) 51 | 52 | return scatter_data 53 | 54 | 55 | if __name__ == "__main__": 56 | print(torch.ops.tiledcuda.scatter_nd) 57 | data = test_scatter_nd() 58 | # Print data 59 | print(data) 60 | -------------------------------------------------------------------------------- /include/cell/acc.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "cuda_utils.hpp" 3 | 4 | #include 5 | 6 | namespace tiledcuda::cell { 7 | using namespace cute; 8 | 9 | template 10 | DEVICE auto get_acc(const TiledMma& tiled_mma) { 11 | auto acc = partition_fragment_C(tiled_mma, Shape, Int>{}); 12 | clear(acc); 13 | 14 | return acc; 15 | } 16 | } // namespace tiledcuda::cell 17 | -------------------------------------------------------------------------------- /include/cell/compute/broadcast.hpp: 
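The header that follows broadcasts a per-row (or per-column) value, typically the output of a reduction, into every element of each BaseTile of a destination register tile; the fused variants apply a binary functor instead of plain assignment. As a rough usage sketch, and not code from this repository (the alias names RegAcc and RowTile, the variables acc and max_tile, and the template-parameter order are all assumptions), a row-wise max reduction combined with a fused broadcast-subtract gives the usual stabilization step of a softmax:

// Hedged usage sketch. RegAcc is an assumed alias for a RegTile of float
// BaseTiles and RowTile for its per-row companion tile.
compute::MaxReduce<RegAcc, tl::Layout::kRowMajor> row_max;
row_max(acc, max_tile);              // per-row maxima of acc, written row-wise

compute::BroadcastSub<RowTile, RegAcc, tl::Layout::kRowMajor> sub_max;
sub_max(max_tile, acc);              // acc(i, j) becomes acc(i, j) - row max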
-------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cell/compute/math_functor.hpp" 4 | #include "cuda_utils.hpp" 5 | #include "types/layout.hpp" 6 | #include "types/tile_shape.hpp" 7 | 8 | namespace tiledcuda::cell::compute { 9 | namespace tl = tile_layout; 10 | 11 | template 12 | struct Broadcast { 13 | using DType = typename DstTile::DType::DType; 14 | 15 | static constexpr int kRows = DstTile::kRows; 16 | static constexpr int kCols = DstTile::kCols; 17 | 18 | DEVICE void operator()(const SrcTile& src, DstTile& dst) {} 19 | }; 20 | 21 | template 22 | struct Broadcast { 23 | using DType = typename DstTile::DType::DType; 24 | 25 | static constexpr int kRows = DstTile::kRows; 26 | static constexpr int kCols = DstTile::kCols; 27 | 28 | DEVICE void operator()(const SrcTile& src, DstTile& dst) { 29 | #pragma unroll 30 | for (int i = 0; i < kRows; ++i) { 31 | DType top_row = src(i, 0); 32 | DType bottom_row = src(i, 1); 33 | #pragma unroll 34 | for (int j = 0; j < kCols; ++j) { 35 | dst(i, j)(0, 0) = top_row; 36 | dst(i, j)(0, 1) = top_row; 37 | dst(i, j)(1, 0) = top_row; 38 | dst(i, j)(1, 1) = top_row; 39 | 40 | dst(i, j)(0, 2) = bottom_row; 41 | dst(i, j)(0, 3) = bottom_row; 42 | dst(i, j)(1, 2) = bottom_row; 43 | dst(i, j)(1, 3) = bottom_row; 44 | } 45 | } 46 | } 47 | }; 48 | 49 | template 50 | struct Broadcast { 51 | using DType = typename DstTile::DType::DType; 52 | 53 | static constexpr int kRows = DstTile::kRows; 54 | static constexpr int kCols = DstTile::kCols; 55 | 56 | DEVICE void operator()(const SrcTile& src, DstTile& dst) { 57 | #pragma unroll 58 | for (int j = 0; j < kCols; ++j) { 59 | DType top_col = src(0, j); 60 | DType bottom_col = src(1, j); 61 | #pragma unroll 62 | for (int i = 0; i < kRows; ++i) { 63 | dst(i, j)(0, 0) = top_col; 64 | dst(i, j)(1, 0) = top_col; 65 | dst(i, j)(0, 1) = top_col; 66 | dst(i, j)(1, 1) = top_col; 67 | 68 | dst(i, j)(2, 0) = bottom_col; 69 | dst(i, j)(3, 0) = bottom_col; 70 | dst(i, j)(2, 1) = bottom_col; 71 | dst(i, j)(3, 1) = bottom_col; 72 | } 73 | } 74 | } 75 | }; 76 | 77 | template 79 | struct BroadcastFuse { 80 | using DType = typename DstTile::DType::DType; 81 | 82 | static constexpr int kRows = DstTile::kRows; 83 | static constexpr int kCols = DstTile::kCols; 84 | 85 | DEVICE void operator()(const SrcTile& src, DstTile& dst) {} 86 | }; 87 | 88 | template 89 | struct BroadcastFuse { 90 | using DType = typename DstTile::DType::DType; 91 | 92 | static constexpr int kRows = DstTile::kRows; 93 | static constexpr int kCols = DstTile::kCols; 94 | 95 | DEVICE void operator()(const SrcTile& src, DstTile& dst) { 96 | Functor f; 97 | #pragma unroll 98 | for (int i = 0; i < kRows; ++i) { 99 | DType top_row = src(i, 0); 100 | DType bottom_row = src(i, 1); 101 | #pragma unroll 102 | for (int j = 0; j < kCols; ++j) { 103 | f(dst(i, j)(0, 0), top_row, dst(i, j)(0, 0)); 104 | f(dst(i, j)(0, 1), top_row, dst(i, j)(0, 1)); 105 | f(dst(i, j)(1, 0), top_row, dst(i, j)(1, 0)); 106 | f(dst(i, j)(1, 1), top_row, dst(i, j)(1, 1)); 107 | 108 | f(dst(i, j)(0, 2), bottom_row, dst(i, j)(0, 2)); 109 | f(dst(i, j)(0, 3), bottom_row, dst(i, j)(0, 3)); 110 | f(dst(i, j)(1, 2), bottom_row, dst(i, j)(1, 2)); 111 | f(dst(i, j)(1, 3), bottom_row, dst(i, j)(1, 3)); 112 | } 113 | } 114 | } 115 | }; 116 | 117 | template 118 | struct BroadcastFuse { 119 | using DType = typename DstTile::DType::DType; 120 | 121 | static constexpr int kRows = DstTile::kRows; 122 | static constexpr int kCols = DstTile::kCols; 123 | 124 | 
DEVICE void operator()(const SrcTile& src, DstTile& dst) { 125 | Functor f; 126 | #pragma unroll 127 | for (int j = 0; j < kCols; ++j) { 128 | DType top_col = src(0, j); 129 | DType bottom_col = src(1, j); 130 | #pragma unroll 131 | for (int i = 0; i < kRows; ++i) { 132 | f(dst(i, j)(0, 0), top_col, dst(i, j)(0, 0)); 133 | f(dst(i, j)(1, 0), top_col, dst(i, j)(1, 0)); 134 | f(dst(i, j)(0, 1), top_col, dst(i, j)(0, 1)); 135 | f(dst(i, j)(1, 1), top_col, dst(i, j)(1, 1)); 136 | 137 | f(dst(i, j)(2, 0), bottom_col, dst(i, j)(2, 0)); 138 | f(dst(i, j)(3, 0), bottom_col, dst(i, j)(3, 0)); 139 | f(dst(i, j)(2, 1), bottom_col, dst(i, j)(2, 1)); 140 | f(dst(i, j)(3, 1), bottom_col, dst(i, j)(3, 1)); 141 | } 142 | } 143 | } 144 | }; 145 | 146 | template 147 | using BroadcastAdd = 148 | BroadcastFuse, kLayout>; 149 | 150 | template 151 | using BroadcastSub = 152 | BroadcastFuse, kLayout>; 153 | 154 | template 155 | using BroadcastMul = 156 | BroadcastFuse, kLayout>; 157 | 158 | template 159 | using BroadcastDiv = 160 | BroadcastFuse, kLayout>; 161 | 162 | } // namespace tiledcuda::cell::compute 163 | -------------------------------------------------------------------------------- /include/cell/compute/gemm.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cell/traits/base.hpp" 4 | #include "cuda_utils.hpp" 5 | #include "types/layout.hpp" 6 | #include "types/tile_shape.hpp" 7 | 8 | #include 9 | 10 | namespace tiledcuda::cell::compute { 11 | namespace tl = tile_layout; 12 | 13 | namespace detail { 14 | 15 | /// @brief: Functor to warp wmma PTX instruction. See the below document for 16 | /// various choices and detailed parameters of the wmma PTX instruction. 17 | /// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-mma 18 | template 19 | struct Gemm { 20 | using InTypeA = typename RegTileA::DType::DType; 21 | using InTypeB = typename RegTileB::DType::DType; 22 | using OutType = typename RegTileC::DType::DType; 23 | 24 | using BaseShape = traits::BaseTileShape; 25 | 26 | static_assert(std::is_same_v || 27 | std::is_same_v, 28 | "This GEMM implementation supports only half-precision as " 29 | "the input element type."); 30 | static_assert(std::is_same_v, 31 | "The output type must be float."); 32 | static_assert(std::is_same_v, 33 | "Mismatched data type for operand A and B."); 34 | static_assert(RegTileB::kRows == RegTileA::kCols, 35 | "Mismatched k-dimension for operand A and B."); 36 | 37 | static constexpr int kMs = RegTileA::kRows; 38 | static constexpr int kNs = RegTileB::kCols; 39 | static constexpr int kKs = RegTileA::kCols; 40 | static_assert(kMs && kNs && kKs, "Invalid tile shapes for GEMM."); 41 | 42 | DEVICE void operator()(const RegTileA& a, const RegTileB& b, RegTileC& c) { 43 | for (int i = 0; i < kMs; ++i) { 44 | for (int j = 0; j < kNs; ++j) { 45 | #pragma unroll 46 | for (int k = 0; k < kKs; ++k) { 47 | tile_wmma(a(i, k).data(), b(k, j).data(), 48 | c(i, j).mutable_data()); 49 | } 50 | } 51 | } 52 | } 53 | 54 | private: 55 | DEVICE void tile_wmma(const InTypeA* ra, const InTypeB* rb, OutType* rc) { 56 | const uint32_t* A = reinterpret_cast(ra); 57 | const uint32_t* B = reinterpret_cast(rb); 58 | float* C = static_cast(rc); 59 | 60 | asm volatile( 61 | "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " 62 | "{%0, %1, %2, %3}," 63 | "{%4, %5, %6, %7}," 64 | "{%8, %9}," 65 | "{%10, %11, %12, %13};\n" 66 | : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) 67 | : "r"(A[0]), 
"r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[2]), 68 | "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); 69 | 70 | asm volatile( 71 | "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " 72 | "{%0, %1, %2, %3}," 73 | "{%4, %5, %6, %7}," 74 | "{%8, %9}," 75 | "{%10, %11, %12, %13};\n" 76 | : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) 77 | : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[1]), "r"(B[3]), 78 | "f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7])); 79 | } 80 | }; 81 | 82 | } // namespace detail 83 | 84 | template 85 | DEVICE void gemm(const RegTileA& a, const RegTileB& b, RegTileC& c) { 86 | detail::Gemm gemm; 87 | gemm(a, b, c); 88 | } 89 | 90 | } // namespace tiledcuda::cell::compute 91 | -------------------------------------------------------------------------------- /include/cell/compute/map.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cell/compute/math_functor.hpp" 4 | #include "cell/warp.hpp" 5 | #include "cuda_utils.hpp" 6 | #include "types/layout.hpp" 7 | 8 | namespace tiledcuda::cell::compute { 9 | 10 | namespace tl = tile_layout; 11 | 12 | namespace detail { 13 | 14 | // TODO(KuangjuX): Distinguish whether the `Layout` is Row Major or Column 15 | // Major. Different Layouts have different directions of memory continuity, 16 | // which will affect the memory access performance. 17 | template 18 | struct ElementWise { 19 | using DType = typename RegTile::DType; 20 | 21 | static constexpr int kRows = RegTile::kRows; 22 | static constexpr int kCols = RegTile::kCols; 23 | 24 | DEVICE void operator()(const RegTile& src, RegTile& dst) { 25 | Functor f; 26 | #pragma unroll 27 | for (int i = 0; i < kRows; ++i) { 28 | #pragma unroll 29 | for (int j = 0; j < kCols; ++j) { 30 | f(src(i, j), dst(i, j)); 31 | } 32 | } 33 | } 34 | }; 35 | 36 | // TODO(KuangjuX): Distinguish whether the `Layout` is Row Major or Column 37 | // Major. Different Layouts have different directions of memory continuity, 38 | // which will affect the memory access performance. 39 | template 40 | struct ElementWise2 { 41 | static constexpr int kRows = SrcRegTile::kRows; 42 | static constexpr int kCols = SrcRegTile::kCols; 43 | 44 | static_assert(kRows == DstRegTile::kRows, "kRows must be equal"); 45 | static_assert(kCols == DstRegTile::kCols, "kCols must be equal"); 46 | 47 | DEVICE void operator()(const SrcRegTile& src, DstRegTile& dst) { 48 | Functor f; 49 | #pragma unroll 50 | for (int i = 0; i < kRows; ++i) { 51 | #pragma unroll 52 | for (int j = 0; j < kCols; ++j) { 53 | f(src(i, j), dst(i, j)); 54 | } 55 | } 56 | } 57 | }; 58 | 59 | // TODO(KuangjuX): Distinguish whether the `Layout` is Row Major or Column 60 | // Major. Different Layouts have different directions of memory continuity, 61 | // which will affect the memory access performance. 
62 | template 63 | struct Binary { 64 | using DType = typename RegTile::DType; 65 | 66 | static constexpr int kRows = RegTile::kRows; 67 | static constexpr int kCols = RegTile::kCols; 68 | 69 | DEVICE void operator()(const RegTile& lhs, const RegTile& rhs, 70 | RegTile& dst) { 71 | Functor f; 72 | #pragma unroll 73 | for (int i = 0; i < kRows; ++i) { 74 | #pragma unroll 75 | for (int j = 0; j < kCols; ++j) { 76 | f(lhs(i, j), rhs(i, j), dst(i, j)); 77 | } 78 | } 79 | } 80 | }; 81 | 82 | } // namespace detail 83 | 84 | template 85 | using BaseTileExp = detail::ElementWise>; 86 | template 87 | using RegTileExp = 88 | detail::ElementWise>; 89 | 90 | template 91 | using BaseTileLog = detail::ElementWise>; 92 | template 93 | using RegTileLog = 94 | detail::ElementWise>; 95 | 96 | template 97 | using BaseTileConvert = detail::ElementWise2< 98 | SrcRegTile, DstRegTile, 99 | Convert>; 100 | template 101 | using RegTileConvert = detail::ElementWise2< 102 | SrcRegTile, DstRegTile, 103 | BaseTileConvert>; 104 | 105 | template 106 | using BaseTileAdd = detail::Binary>; 107 | template 108 | using RegTileAdd = 109 | detail::Binary>; 110 | 111 | template 112 | using BaseTileSub = detail::Binary>; 113 | template 114 | using RegTileSub = 115 | detail::Binary>; 116 | 117 | template 118 | using BaseTileMul = detail::Binary>; 119 | template 120 | using RegTileMul = 121 | detail::Binary>; 122 | 123 | template 124 | using BaseTileDiv = detail::Binary>; 125 | template 126 | using RegTileDiv = 127 | detail::Binary>; 128 | 129 | template 130 | using BaseTileMax = detail::Binary>; 131 | template 132 | using RegTileMax = 133 | detail::Binary>; 134 | 135 | } // namespace tiledcuda::cell::compute 136 | -------------------------------------------------------------------------------- /include/cell/compute/math_functor.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cuda_utils.hpp" 4 | 5 | namespace tiledcuda::cell::compute { 6 | 7 | template 8 | struct Add { 9 | DEVICE Element operator()(Element a, Element b) const { return a + b; } 10 | 11 | DEVICE void operator()(const Element& lhs, const Element& rhs, 12 | Element& dst) { 13 | dst = lhs + rhs; 14 | } 15 | }; 16 | 17 | template 18 | struct Sub { 19 | DEVICE Element operator()(Element a, Element b) const { return a - b; } 20 | 21 | DEVICE void operator()(const Element& lhs, const Element& rhs, 22 | Element& dst) { 23 | dst = lhs - rhs; 24 | } 25 | }; 26 | 27 | template 28 | struct Mul { 29 | DEVICE Element operator()(Element a, Element b) const { return a * b; } 30 | 31 | DEVICE void operator()(const Element& lhs, const Element& rhs, 32 | Element& dst) { 33 | dst = lhs * rhs; 34 | } 35 | }; 36 | 37 | template 38 | struct Div { 39 | DEVICE Element operator()(Element a, Element b) const { return a / b; } 40 | 41 | DEVICE void operator()(const Element& lhs, const Element& rhs, 42 | Element& dst) { 43 | dst = lhs / rhs; 44 | } 45 | }; 46 | 47 | template 48 | struct Max { 49 | DEVICE Element operator()(Element a, Element b) const { 50 | return a > b ? a : b; 51 | } 52 | 53 | DEVICE void operator()(const Element& lhs, const Element& rhs, 54 | Element& dst) { 55 | dst = lhs > rhs ? lhs : rhs; 56 | } 57 | }; 58 | 59 | template 60 | struct Min { 61 | DEVICE Element operator()(Element a, Element b) const { 62 | return a < b ? a : b; 63 | } 64 | 65 | DEVICE void operator()(const Element& lhs, const Element& rhs, 66 | Element& dst) { 67 | dst = lhs < rhs ? 
lhs : rhs; 68 | } 69 | }; 70 | 71 | template 72 | struct Exp { 73 | DEVICE Element operator()(Element a) const { return exp(a); } 74 | 75 | DEVICE void operator()(const Element& src, Element& dst) { dst = exp(src); } 76 | }; 77 | 78 | #if defined(__CUDA_ARCH__) 79 | template <> 80 | struct Exp { 81 | DEVICE float operator()(float a) const { return __expf(a); } 82 | 83 | DEVICE void operator()(const float& src, float& dst) { dst = __expf(src); } 84 | }; 85 | 86 | template <> 87 | struct Exp<__half> { 88 | DEVICE __half operator()(__half a) const { return hexp(a); } 89 | 90 | DEVICE void operator()(const __half& src, __half& dst) { dst = hexp(src); } 91 | }; 92 | #endif 93 | 94 | template 95 | struct Log { 96 | DEVICE Element operator()(Element a) const { return log(a); } 97 | 98 | DEVICE void operator()(const Element& src, Element& dst) { dst = log(src); } 99 | }; 100 | 101 | #if defined(__CUDA_ARCH__) 102 | template <> 103 | struct Log { 104 | DEVICE float operator()(float a) const { return __logf(a); } 105 | 106 | DEVICE void operator()(const float& src, float& dst) { dst = __logf(src); } 107 | }; 108 | 109 | template <> 110 | struct Log<__half> { 111 | DEVICE __half operator()(__half a) const { return hlog(a); } 112 | 113 | DEVICE void operator()(const __half& src, __half& dst) { dst = hlog(src); } 114 | }; 115 | #endif 116 | 117 | template 118 | struct Relu { 119 | DEVICE Element operator()(Element a) const { return a > 0 ? a : 0; } 120 | 121 | DEVICE void operator()(const Element& src, Element& dst) { 122 | dst = src > 0 ? src : 0; 123 | } 124 | }; 125 | 126 | #if defined(__CUDA_ARCH__) 127 | template <> 128 | struct Relu { 129 | DEVICE float operator()(float a) const { return max(a, 0.f); } 130 | 131 | DEVICE void operator()(const float& src, float& dst) { 132 | dst = max(src, 0.f); 133 | } 134 | }; 135 | 136 | template <> 137 | struct Relu<__half> { 138 | DEVICE __half operator()(__half a) const { return __hmax(a, 0); } 139 | 140 | DEVICE void operator()(const __half& src, __half& dst) { 141 | dst = __hmax(src, 0); 142 | } 143 | }; 144 | #endif 145 | 146 | template 147 | struct Convert { 148 | DEVICE DstType operator()(SrcType a) const { 149 | return static_cast(a); 150 | } 151 | 152 | DEVICE void operator()(const SrcType& src, DstType& dst) { 153 | dst = static_cast(src); 154 | } 155 | }; 156 | 157 | #if defined(__CUDA_ARCH__) 158 | 159 | template <> 160 | struct Convert { 161 | DEVICE __half operator()(float a) const { return __float2half(a); } 162 | 163 | DEVICE void operator()(const float& src, __half& dst) { 164 | dst = __float2half(src); 165 | } 166 | }; 167 | 168 | template <> 169 | struct Convert<__half, float> { 170 | DEVICE float operator()(__half a) const { return __half2float(a); } 171 | 172 | DEVICE void operator()(const __half& src, float& dst) { 173 | dst = __half2float(src); 174 | } 175 | }; 176 | #endif 177 | 178 | } // namespace tiledcuda::cell::compute 179 | -------------------------------------------------------------------------------- /include/cell/compute/mod.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cell/compute/broadcast.hpp" 4 | #include "cell/compute/gemm.hpp" 5 | #include "cell/compute/map.hpp" 6 | #include "cell/compute/math_functor.hpp" 7 | #include "cell/compute/reduce.hpp" 8 | #include "cell/compute/softmax.hpp" 9 | -------------------------------------------------------------------------------- /include/cell/compute/reduce.hpp: 
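The reducers that follow operate on the MMA accumulator layout, where four consecutive lanes of a warp cooperate on the same rows of a 16x16 BaseTile. After each lane folds its private elements, the remaining work is two shuffle-down steps inside the four-lane group plus a broadcast from the group leader. A standalone CUDA sketch of that pattern, with addition standing in for the generic reduce functor and an invented function name:

// Minimal sketch (not repository code) of the intra-warp step used below:
// two shuffle-down steps fold the four partial values of a lane group, and a
// final shuffle broadcasts the result from the group leader to all four lanes.
__device__ float reduce_group_of_four(float partial) {
    constexpr unsigned kMaskAll = 0xFFFFFFFFu;
    partial += __shfl_down_sync(kMaskAll, partial, 2);  // fold in lane + 2
    partial += __shfl_down_sync(kMaskAll, partial, 1);  // fold in lane + 1
    // threadIdx.x & 0x1C clears the low two bits, i.e. the group leader lane.
    return __shfl_sync(kMaskAll, partial, threadIdx.x & 0x1C);
}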
-------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cell/compute/math_functor.hpp" 4 | #include "cell/traits/base.hpp" 5 | #include "cell/warp.hpp" 6 | #include "cuda_utils.hpp" 7 | #include "types/layout.hpp" 8 | #include "types/tile_shape.hpp" 9 | 10 | namespace tiledcuda::cell::compute { 11 | 12 | namespace tl = tile_layout; 13 | 14 | namespace detail { 15 | 16 | template 17 | struct Reduce { 18 | using DType = typename RegTile::DType::DType; 19 | using BaseShape = traits::BaseTileShape; 20 | 21 | static constexpr int kRows = RegTile::kRows; 22 | static constexpr int kCols = RegTile::kCols; 23 | 24 | template 25 | DEVICE void operator()(const RegTile& src, DstTile& dst, Reduce reduce) {} 26 | }; 27 | 28 | template 29 | struct Reduce { 30 | using DType = typename RegTile::DType::DType; 31 | using BaseShape = traits::BaseTileShape; 32 | 33 | static constexpr int kRows = RegTile::kRows; 34 | static constexpr int kCols = RegTile::kCols; 35 | 36 | template 37 | DEVICE void operator()(const RegTile& src, DstTile& dst, Reduce reduce) { 38 | const int leader = threadIdx.x & 0x1C; 39 | #pragma unroll 40 | for (int i = 0; i < kRows; ++i) { 41 | DType top_rows[kCols]; 42 | DType bottom_rows[kCols]; 43 | #pragma unroll 44 | for (int j = 0; j < kCols; ++j) { 45 | auto base_tile = src(i, j); 46 | DType top_row_0 = reduce(base_tile(0, 0), base_tile(0, 1)); 47 | DType top_row_1 = reduce(base_tile(1, 0), base_tile(1, 1)); 48 | top_rows[j] = reduce(top_row_0, top_row_1); 49 | 50 | DType bottom_row_0 = reduce(base_tile(0, 2), base_tile(0, 3)); 51 | DType bottom_row_1 = reduce(base_tile(1, 2), base_tile(1, 3)); 52 | bottom_rows[j] = reduce(bottom_row_0, bottom_row_1); 53 | } 54 | 55 | DType top_row = top_rows[0]; 56 | DType bottom_row = bottom_rows[0]; 57 | 58 | // Compute the reduction of the top and bottom rows. 59 | #pragma unroll 60 | for (int j = 1; j < kCols; ++j) { 61 | top_row = reduce(top_row, top_rows[j]); 62 | bottom_row = reduce(bottom_row, bottom_rows[j]); 63 | } 64 | 65 | // Shuffle the results to the leader thread. 66 | top_row = reduce(top_row, shuffle_down_sync(MASK_ALL, top_row, 2)); 67 | top_row = reduce(top_row, shuffle_down_sync(MASK_ALL, top_row, 1)); 68 | 69 | bottom_row = 70 | reduce(bottom_row, shuffle_down_sync(MASK_ALL, bottom_row, 2)); 71 | bottom_row = 72 | reduce(bottom_row, shuffle_down_sync(MASK_ALL, bottom_row, 1)); 73 | 74 | // Group the threads into groups of four, and broadcast the data 75 | // from the first thread in each group to the other three threads. 76 | top_row = shuffle_sync(MASK_ALL, top_row, leader); 77 | bottom_row = shuffle_sync(MASK_ALL, bottom_row, leader); 78 | 79 | // Store the results to the destination tile. 
80 | dst(i, 0) = top_row; 81 | dst(i, 1) = bottom_row; 82 | } 83 | } 84 | }; 85 | 86 | template 87 | struct Reduce { 88 | using DType = typename RegTile::DType::DType; 89 | using BaseShape = traits::BaseTileShape; 90 | 91 | static constexpr int kRows = RegTile::kRows; 92 | static constexpr int kCols = RegTile::kCols; 93 | 94 | template 95 | DEVICE void operator()(const RegTile& tile, DstTile& dst, Reduce reduce) { 96 | const int leader = threadIdx.x & 0x1C; 97 | 98 | #pragma unroll 99 | for (int i = 0; i < kCols; ++i) { 100 | DType top_cols[kRows]; 101 | DType bottom_cols[kRows]; 102 | #pragma unroll 103 | for (int j = 0; j < kRows; ++j) { 104 | auto base_tile = tile(j, i); 105 | DType top_col_0 = reduce(base_tile(0, 0), base_tile(1, 0)); 106 | DType top_col_1 = reduce(base_tile(0, 1), base_tile(1, 1)); 107 | top_cols[j] = reduce(top_col_0, top_col_1); 108 | 109 | DType bottom_col_0 = reduce(base_tile(2, 0), base_tile(3, 0)); 110 | DType bottom_col_1 = reduce(base_tile(2, 1), base_tile(3, 1)); 111 | bottom_cols[j] = reduce(bottom_col_0, bottom_col_1); 112 | } 113 | 114 | DType top_col = top_cols[0]; 115 | DType bottom_col = bottom_cols[0]; 116 | 117 | // Compute the reduction of the top and bottom columns. 118 | #pragma unroll 119 | for (int j = 1; j < kRows; ++j) { 120 | top_col = reduce(top_col, top_cols[j]); 121 | bottom_col = reduce(bottom_col, bottom_cols[j]); 122 | } 123 | 124 | // Shuffle the results to the leader thread. 125 | top_col = reduce(top_col, shuffle_down_sync(MASK_ALL, top_col, 2)); 126 | top_col = reduce(top_col, shuffle_down_sync(MASK_ALL, top_col, 1)); 127 | bottom_col = 128 | reduce(bottom_col, shuffle_down_sync(MASK_ALL, bottom_col, 2)); 129 | bottom_col = 130 | reduce(bottom_col, shuffle_down_sync(MASK_ALL, bottom_col, 1)); 131 | 132 | // Group the threads into groups of four, and broadcast the data 133 | // from the first thread in each group to the other three threads. 134 | top_col = shuffle_sync(MASK_ALL, top_col, leader); 135 | bottom_col = shuffle_sync(MASK_ALL, bottom_col, leader); 136 | 137 | // Store the results to the destination tile. 138 | dst(0, i) = top_col; 139 | dst(1, i) = bottom_col; 140 | } 141 | } 142 | }; 143 | 144 | } // namespace detail 145 | 146 | template 147 | struct SumReduce { 148 | using DType = typename RegTile::DType::DType; 149 | 150 | template 151 | DEVICE void operator()(const RegTile& src, DstTile& dst) { 152 | detail::Reduce row_sum; 153 | row_sum(src, dst, Add{}); 154 | } 155 | }; 156 | 157 | template 158 | struct MaxReduce { 159 | using DType = typename RegTile::DType::DType; 160 | 161 | template 162 | DEVICE void operator()(const RegTile& src, DstTile& dst) { 163 | detail::Reduce row_max; 164 | row_max(src, dst, Max{}); 165 | } 166 | }; 167 | 168 | } // namespace tiledcuda::cell::compute 169 | -------------------------------------------------------------------------------- /include/cell/compute/softmax.hpp: -------------------------------------------------------------------------------- 1 | /** 2 | * TODO(KuangjuX): This version is numerically unstable and has not been tested 3 | * with larger inputs using more warps, which makes it generally unsuitable for 4 | * practical use. 
5 | */ 6 | #pragma once 7 | 8 | #include "cell/compute/reduce.hpp" 9 | #include "cuda_utils.hpp" 10 | #include "types/layout.hpp" 11 | #include "types/mod.hpp" 12 | #include "types/tile_shape.hpp" 13 | 14 | namespace tiledcuda::cell::compute { 15 | 16 | namespace tl = tile_layout; 17 | 18 | template 19 | struct Softmax { 20 | using DType = typename RegTile::DType::DType; 21 | using BaseShape = traits::BaseTileShape; 22 | 23 | static constexpr int kRows = RegTile::kRows; 24 | static constexpr int kCols = RegTile::kCols; 25 | }; 26 | 27 | template 28 | struct Softmax { 29 | using DType = typename RegTile::DType::DType; 30 | using BaseShape = traits::BaseTileShape; 31 | 32 | static constexpr int kRows = RegTile::kRows; 33 | static constexpr int kCols = RegTile::kCols; 34 | 35 | template 36 | DEVICE void operator()(RegTile& tile, ReduceTile& reduce_tile) { 37 | #pragma unroll 38 | for (int i = 0; i < kRows; ++i) { 39 | #pragma unroll 40 | for (int j = 0; j < kCols; ++j) { 41 | auto base_tile = tile(i, j); 42 | 43 | tile(i, j)(0, 0) = exp(base_tile(0, 0)); 44 | tile(i, j)(0, 1) = exp(base_tile(0, 1)); 45 | tile(i, j)(1, 0) = exp(base_tile(1, 0)); 46 | tile(i, j)(1, 1) = exp(base_tile(1, 1)); 47 | tile(i, j)(0, 2) = exp(base_tile(0, 2)); 48 | tile(i, j)(0, 3) = exp(base_tile(0, 3)); 49 | tile(i, j)(1, 2) = exp(base_tile(1, 2)); 50 | tile(i, j)(1, 3) = exp(base_tile(1, 3)); 51 | } 52 | } 53 | 54 | compute::SumReduce row_sum; 55 | row_sum(tile, reduce_tile); 56 | 57 | #pragma unroll 58 | for (int i = 0; i < kRows; ++i) { 59 | DType top_row_sum = reduce_tile(i, 0); 60 | DType bottom_row_sum = reduce_tile(i, 1); 61 | 62 | #pragma unroll 63 | for (int j = 0; j < kCols; ++j) { 64 | auto base_tile = tile(i, j); 65 | 66 | tile(i, j)(0, 0) = base_tile(0, 0) / top_row_sum; 67 | tile(i, j)(0, 1) = base_tile(0, 1) / top_row_sum; 68 | tile(i, j)(1, 0) = base_tile(1, 0) / top_row_sum; 69 | tile(i, j)(1, 1) = base_tile(1, 1) / top_row_sum; 70 | 71 | tile(i, j)(0, 2) = base_tile(0, 2) / bottom_row_sum; 72 | tile(i, j)(0, 3) = base_tile(0, 3) / bottom_row_sum; 73 | tile(i, j)(1, 2) = base_tile(1, 2) / bottom_row_sum; 74 | tile(i, j)(1, 3) = base_tile(1, 3) / bottom_row_sum; 75 | } 76 | } 77 | } 78 | }; 79 | 80 | template 81 | struct Softmax { 82 | using DType = typename RegTile::DType::DType; 83 | using BaseShape = traits::BaseTileShape; 84 | 85 | static constexpr int kRows = RegTile::kRows; 86 | static constexpr int kCols = RegTile::kCols; 87 | }; 88 | 89 | } // namespace tiledcuda::cell::compute 90 | -------------------------------------------------------------------------------- /include/cell/convert.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "cuda_utils.hpp" 3 | 4 | #include 5 | #include 6 | 7 | namespace tiledcuda::cell { 8 | 9 | namespace { 10 | template 11 | DEVICE auto convert_type(cute::Tensor const& tensor) { 12 | using From_type = typename Engine::value_type; 13 | constexpr int numel = decltype(size(tensor))::value; 14 | cutlass::NumericArrayConverter convert_op; 15 | // HACK: this requires tensor to be "contiguous" 16 | auto frag = 17 | convert_op(*reinterpret_cast*>( 18 | tensor.data())); 19 | 20 | return make_tensor(make_rmem_ptr(&frag), tensor.layout()); 21 | } 22 | 23 | template 24 | struct IndexedTensor_ { 25 | DEVICE IndexedTensor_(Tensor& tensor) : tensor_(tensor) {} 26 | 27 | DEVICE const auto operator[](int idx) { return tensor_(_, _, idx); } 28 | 29 | private: 30 | Tensor& tensor_; 31 | }; 32 | } // namespace 33 | 34 | 
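The anonymous-namespace helper above converts a register-resident cute tensor element-wise to a new value type while keeping its layout, and the convert_layout function that follows re-groups the accumulator modes so the result can be indexed per step and fed into a subsequent MMA. A rough usage sketch (not repository code; the position of the destination-type template argument is inferred from the NumericArrayConverter call and may differ):

// Hedged illustration: down-convert a float accumulator fragment held in
// registers to half precision, e.g. before it is re-used as an operand of a
// second GEMM in a fused kernel. `acc` is assumed to be a contiguous
// register-resident cute tensor.
auto acc_half = convert_type<cutlass::half_t>(acc);  // same layout, half data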
// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to 35 | // ((4, 2), MMA_M, MMA_N / 2) if using m16n8k16, or to (4, MMA_M, MMA_N) if 36 | // using m16n8k8. 37 | template 38 | DEVICE auto convert_layout(const Tensor& acc) { 39 | auto acc_layout = acc.layout(); 40 | 41 | using X = Underscore; 42 | static_assert(decltype(size<0>(acc_layout))::value == 4); 43 | static_assert(decltype(cute::rank(acc_layout))::value == 3); 44 | 45 | constexpr int mma_shape_K = cute::get<2>(typename MMA::Shape_MNK{}); 46 | static_assert(mma_shape_K == 8 || mma_shape_K == 16); 47 | 48 | if constexpr (mma_shape_K == 8) { 49 | IndexedTensor_ indexed_tensor(acc); 50 | return indexed_tensor; 51 | } else { 52 | // (4, MMA_M, (2, MMA_N / 2))) 53 | auto l = cute::logical_divide(acc_layout, Shape{}); 54 | auto new_layout = make_layout(make_layout(get<0>(l), get<2, 0>(l)), 55 | get<1>(l), get<2, 1>(l)); 56 | auto new_tensor = make_tensor(acc.data(), new_layout); 57 | 58 | IndexedTensor_ indexed_tensor(new_tensor); 59 | return indexed_tensor; 60 | } 61 | }; 62 | } // namespace tiledcuda::cell 63 | -------------------------------------------------------------------------------- /include/cell/copy/constants.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "config.hpp" 4 | 5 | namespace tiledcuda::cell::copy { 6 | 7 | enum class CopyInst { 8 | kLoadMat = 0, // ldmatrix for loading data from shared memory to register. 9 | kStoreMat = 1, // stmatrix for storing data from register to shared memory. 10 | kLoadShared32 = 2, // ldsm32 for loading 32-bit data from shared memory. 11 | kLoadShared128 = 3 // ldsm128 for loading 128-bit data from shared memory. 12 | }; 13 | 14 | enum class WarpReuse { 15 | // TODO(haruhi): It seems that Cir/RowReuseCir/ColReuseCir are not ncessary, 16 | // thus the reuse mode can be simplified. 17 | // data are evenly partitioned to be loaded by warps. 
18 | kCont = 0, // all warps continuously load data, no reuse 19 | kCir = 1, // all warps circularly load data, no reuse 20 | kRowReuseCont = 2, // Row-wise even reuse, warps in the same row 21 | // repeatedly load the same data 22 | kRowReuseCir = 3, // Row-wise circular reuse 23 | kColReuseCont = 4, // Column-wise even reuse, warps in the same column 24 | // repeatedly load the same data 25 | kColReuseCir = 5 // Column-wise circular reuse 26 | }; 27 | 28 | } // namespace tiledcuda::cell::copy 29 | -------------------------------------------------------------------------------- /include/cell/copy/mod.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cell/copy/constants.hpp" 4 | #include "cell/copy/copy_atom.hpp" 5 | #include "cell/copy/global_to_register.hpp" 6 | #include "cell/copy/global_to_shared.hpp" 7 | #include "cell/copy/register.hpp" 8 | #include "cell/copy/shared_to_register.hpp" 9 | #include "cell/copy/warp.hpp" 10 | -------------------------------------------------------------------------------- /include/cell/copy/register.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cuda_utils.hpp" 4 | 5 | namespace tiledcuda::cell::copy { 6 | 7 | namespace detail { 8 | template 9 | struct DataCopy { 10 | DEVICE void operator()(const Element& src, Element& dst) { dst = src; } 11 | }; 12 | 13 | template 14 | struct RegCopy { 15 | using DType = typename RegTile::DType; 16 | 17 | static constexpr int kRows = RegTile::kRows; 18 | static constexpr int kCols = RegTile::kCols; 19 | 20 | DEVICE void operator()(const RegTile& src, RegTile& dst) { 21 | Copy c; 22 | #pragma unroll 23 | for (int i = 0; i < kRows; ++i) { 24 | #pragma unroll 25 | for (int j = 0; j < kCols; ++j) { 26 | c(src(i, j), dst(i, j)); 27 | } 28 | } 29 | } 30 | }; 31 | } // namespace detail 32 | 33 | template 34 | using BaseTileCopy = 35 | detail::RegCopy>; 36 | template 37 | using RegTileCopy = 38 | detail::RegCopy>; 39 | 40 | } // namespace tiledcuda::cell::copy 41 | -------------------------------------------------------------------------------- /include/cell/mod.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cell/acc.hpp" 4 | #include "cell/compute/mod.hpp" 5 | #include "cell/convert.hpp" 6 | #include "cell/copy/mod.hpp" 7 | #include "cell/sync.hpp" 8 | #include "cell/traits/base.hpp" 9 | #include "cell/warp.hpp" 10 | #include "types/mod.hpp" 11 | -------------------------------------------------------------------------------- /include/cell/sync.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cuda_utils.hpp" 4 | 5 | namespace tiledcuda::cell { 6 | 7 | template 8 | DEVICE void wait_group() { 9 | #if defined(CP_ASYNC_SM80_ENABLED) 10 | asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); 11 | #endif 12 | } 13 | 14 | DEVICE void commit_copy_group() { 15 | #if defined(CP_ASYNC_SM80_ENABLED) 16 | cute::cp_async_fence(); 17 | #endif 18 | } 19 | 20 | DEVICE void __copy_async() { 21 | commit_copy_group(); 22 | wait_group<0>(); 23 | } 24 | 25 | } // namespace tiledcuda::cell 26 | -------------------------------------------------------------------------------- /include/cell/traits/base.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #include 7 | 8 | namespace 
tiledcuda::cell::traits { 9 | 10 | template 11 | concept BaseType = std::is_same_v || 12 | std::is_same_v || std::is_same_v; 13 | 14 | /// @brief Architecture-specific magic numbers. 15 | /// @tparam Element: the data type of the elements. 16 | template 17 | struct TraitsBase { 18 | // the maximal width of vectorized access. 19 | static constexpr int kAccessInBits = 128; 20 | static constexpr int kElmentBits = cutlass::sizeof_bits::value; 21 | static constexpr int kNumPerAccess = kAccessInBits / kElmentBits; 22 | }; 23 | 24 | template 25 | requires BaseType 26 | struct BaseTileShape { 27 | using DType = Element; 28 | 29 | static constexpr int kTileSize = 16; 30 | static constexpr int kRows = kTileSize; 31 | static constexpr int kCols = kTileSize; 32 | static constexpr int kNumel = kRows * kCols; 33 | static constexpr int kNumelPerThread = kNumel / 32; // 8 34 | static constexpr int kPackedPerThread = kNumelPerThread / 2; // 4 35 | 36 | // 4 registers used in half / bf16, 8 registers used in float. 37 | static constexpr int kRegsPerThread = 38 | sizeof(DType) * kNumelPerThread / sizeof(uint32_t); 39 | }; 40 | 41 | } // namespace tiledcuda::cell::traits 42 | -------------------------------------------------------------------------------- /include/cell/warp.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cuda_utils.hpp" 4 | 5 | namespace tiledcuda::cell { 6 | 7 | static constexpr uint32_t MASK_ALL = 0xFFFFFFFF; 8 | 9 | /** 10 | * @brief `shuffle_sync` provides a way of moving a value from one thread 11 | * to other threads in the warp in one instruction. 12 | */ 13 | template 14 | DEVICE Element shuffle_sync(uint32_t mask, Element value, int src_lane) { 15 | return __shfl_sync(mask, value, src_lane); 16 | } 17 | 18 | /** 19 | * @brief `shuffle_down_sync` enable you to shift data within a warp in 20 | * one instruction. So basically the `value` is shiffted down `delta` 21 | * lanes. 22 | */ 23 | template 24 | DEVICE Element shuffle_down_sync(uint32_t mask, Element value, int delta) { 25 | return __shfl_down_sync(mask, value, delta); 26 | } 27 | 28 | } // namespace tiledcuda::cell 29 | -------------------------------------------------------------------------------- /include/config.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #if defined(__CUDA_ARCH__) 4 | #define HOST_DEVICE __forceinline__ __host__ __device__ 5 | #define DEVICE __forceinline__ __device__ 6 | #define HOST __forceinline__ __host__ 7 | #else 8 | #define HOST_DEVICE inline 9 | #define DEVICE inline 10 | #define HOST inline 11 | #endif 12 | 13 | #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) 14 | #define CP_ASYNC_SM80_ENABLED 15 | #endif 16 | -------------------------------------------------------------------------------- /include/cuda_info.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cuda_utils.hpp" 4 | 5 | #include 6 | 7 | namespace tiledcuda { 8 | // Returns the number of GPUs. 9 | int GetGPUDeviceCount(); 10 | 11 | // Returns the compute capability of the given GPU. 12 | int GetGPUComputeCapability(int id); 13 | 14 | // Returns the number of multiprocessors for the given GPU. 15 | int GetGPUMultiProcessors(int id); 16 | 17 | // Returns the maximum number of threads per multiprocessor for the given 18 | // GPU. 
19 | int GetGPUMaxThreadsPerMultiProcessor(int id); 20 | 21 | // Returns the maximum number of threads per block for the given GPU. 22 | int GetGPUMaxThreadsPerBlock(int id); 23 | 24 | // Returns the maximum grid size for the given GPU. 25 | dim3 GetGpuMaxGridDimSize(int id); 26 | 27 | // Returns the name of the device. 28 | std::string GetDeviceName(); 29 | 30 | } // namespace tiledcuda 31 | -------------------------------------------------------------------------------- /include/cuda_utils.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "config.hpp" 4 | #include "types/layout.hpp" 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | namespace tiledcuda { 13 | namespace tl = cell::tile_layout; 14 | 15 | template 16 | inline constexpr int CeilDiv = (a + b - 1) / b; // for compile-time values 17 | 18 | const char* cublasGetErrorString(cublasStatus_t status); 19 | 20 | inline void __cudaCheck(const cudaError err, const char* file, int line) { 21 | if (err != cudaSuccess) { 22 | fprintf(stderr, "%s(%d): CUDA error: %s.\n", file, line, 23 | cudaGetErrorString(err)); 24 | exit(EXIT_FAILURE); 25 | } 26 | } 27 | #define CudaCheck(call) __cudaCheck(call, __FILE__, __LINE__) 28 | 29 | inline void __checkLast(const char* const file, const int line) { 30 | cudaError_t const err{cudaGetLastError()}; 31 | if (err != cudaSuccess) { 32 | fprintf(stderr, "%s(%d): CUDA Runtime Error at: %s.\n", file, line, 33 | cudaGetErrorString(err)); 34 | exit(EXIT_FAILURE); 35 | } 36 | } 37 | #define CudaCheckLastError() __checkLast(__FILE__, __LINE__) 38 | 39 | inline void __cublasCheck(const cublasStatus_t err, const char* file, 40 | int line) { 41 | if (err != CUBLAS_STATUS_SUCCESS) { 42 | fprintf(stderr, "%s(%d): Cublas error: %s.\n", file, line, 43 | cublasGetErrorString(err)); 44 | exit(EXIT_FAILURE); 45 | } 46 | } 47 | #define CublasCheck(call) __cublasCheck(call, __FILE__, __LINE__) 48 | 49 | HOST_DEVICE 50 | const char* layout_type_to_str(tl::Layout type) { 51 | switch (type) { 52 | case tl::Layout::kRowMajor: 53 | return "RowMajor"; 54 | case tl::Layout::kColMajor: 55 | return "ColMajor"; 56 | case tl::Layout::kSwizzledRowMajor: 57 | return "SwizzledRowMajor"; 58 | case tl::Layout::kSwizzledColMajor: 59 | return "SwizzledColMajor"; 60 | } 61 | return "UnsupportedLayout"; 62 | } 63 | } // namespace tiledcuda 64 | -------------------------------------------------------------------------------- /include/errors.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | namespace tiledcuda::errors { 7 | 8 | class NotImplementedException : public std::exception { 9 | public: 10 | NotImplementedException(const char* error = "Not yet implemented!") { 11 | errorMessage = error; 12 | } 13 | 14 | // Provided for compatibility with std::exception. 
15 | const char* what() const noexcept { return errorMessage.c_str(); } 16 | 17 | private: 18 | std::string errorMessage; 19 | }; 20 | 21 | } // namespace tiledcuda::errors 22 | -------------------------------------------------------------------------------- /include/kernels/flash_attn.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cuda_utils.hpp" 4 | 5 | #include 6 | 7 | namespace tiledcuda::kernels { 8 | 9 | template 25 | __global__ void flash_attention(const InType* dQ, const InType* dK, 26 | const InType* dV, InType* dO, int kM, int kN, 27 | int kK, int kP, int kTM, int kTN, int kTK, 28 | int kTP); 29 | 30 | template 32 | void run_flash_attention(const InType* dQ, const InType* dK, const InType* dV, 33 | OutType* dO); 34 | 35 | void custom_flash_attention_op(const torch::Tensor& Q, const torch::Tensor& K, 36 | const torch::Tensor& V, torch::Tensor& O, 37 | int64_t m, int64_t n, int64_t k, int64_t p); 38 | 39 | } // namespace tiledcuda::kernels 40 | -------------------------------------------------------------------------------- /include/kernels/mod.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "kernels/flash_attn.hpp" 4 | #include "kernels/scatter_nd.hpp" 5 | -------------------------------------------------------------------------------- /include/kernels/scatter_nd.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cuda_utils.hpp" 4 | 5 | #include 6 | 7 | #include 8 | 9 | namespace tiledcuda::kernels { 10 | 11 | // reference: 12 | // https://github.com/InfiniTensor/RefactorGraph/blob/master/src/04kernel/cuda/src/scatter_nd.cu#L7 13 | // TODO: optimize the kernel by increasing the number of threads to perform 14 | // `atomic_add` operations under `slice_size`. 15 | /** 16 | * @brief The ScatterNdkernel updates the content of `updates` into `data` based 17 | * on the index information provided in the given `indices`. 18 | * 19 | * @param in The input tensor `updates`. 20 | * @param out The output tensor `data`. 21 | * @param indices The indices tensor. 22 | * @param strides record the stride information between different dimensions in 23 | * the `data` tensor. 24 | * @param n The number of indices. 25 | * @param rank The last dimension of `indices`. 26 | * @param slice_size The length of the slice to be updated. Specifically, it is 27 | * the product of the difference between the rank of `data` and the last 28 | * dimension of `indices` along the memory dimensions of `data`. 
29 | */ 30 | template 31 | __global__ void scatter_nd_kernel(const T* in, T* out, const int64_t* indices, 32 | unsigned int const* __restrict__ strides, 33 | size_t n, size_t rank, size_t slice_size); 34 | 35 | template 36 | void scatter_nd(torch::Tensor& data, const torch::Tensor& updates, 37 | const torch::Tensor& indices); 38 | 39 | void custom_scatter_op(torch::Tensor& data, const torch::Tensor& updates, 40 | const torch::Tensor& indices); 41 | 42 | } // namespace tiledcuda::kernels 43 | -------------------------------------------------------------------------------- /include/types/global.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "types/layout.hpp" 4 | #include "util/print.hpp" 5 | 6 | namespace tiledcuda::cell { 7 | namespace tl = tile_layout; 8 | 9 | template 10 | struct GlobalTile { 11 | using DType = Element_; 12 | using Layout = Layout_; 13 | 14 | static constexpr int kNumel = tl::get_numel; 15 | 16 | static constexpr int kRows = tl::num_rows; 17 | static constexpr int kCols = tl::num_cols; 18 | 19 | static constexpr int kRowStride = tl::row_stride; 20 | static constexpr int kColStride = tl::col_stride; 21 | 22 | static constexpr tl::Layout kType = tl::layout_type; 23 | 24 | DEVICE GlobalTile(DType* data) : data_(data), layout_(Layout{}) {} 25 | 26 | DEVICE GlobalTile(const DType* data) 27 | : data_(const_cast(data)), layout_(Layout{}) {} 28 | 29 | DEVICE DType* mutable_data() { return data_; } 30 | 31 | DEVICE const DType* data() const { return data_; } 32 | 33 | HOST_DEVICE const Layout& layout() const { return layout_; } 34 | 35 | // for write access 36 | DEVICE DType& operator()(int x, int y) { return data_[layout_(x, y)]; } 37 | 38 | // for read access 39 | DEVICE 40 | const DType& operator()(int x, int y) const { return data_[layout_(x, y)]; } 41 | 42 | DEVICE void dump_value() { print_tile(data_, layout_); } 43 | 44 | private: 45 | DType* data_; 46 | Layout layout_; 47 | }; 48 | } // namespace tiledcuda::cell 49 | -------------------------------------------------------------------------------- /include/types/global_tile_iterator.hpp: -------------------------------------------------------------------------------- 1 | 2 | #pragma once 3 | 4 | #include "types/global.hpp" 5 | #include "types/tile_shape.hpp" 6 | 7 | namespace tiledcuda::cell { 8 | namespace tl = tile_layout; 9 | 10 | namespace detail { 11 | /// @brief Helper for pretty printing a tile iterator's static shape-related 12 | /// information. This printer works ONLY on the host. 13 | struct GTileIteratorPrettyPrinter { 14 | template 15 | static HOST void print(std::ostream& out, const TileIterator& itr) { 16 | out << "numel = " << TileIterator::Tile::kNumel << ", ChunkShape[" 17 | << dim_size<0, typename TileIterator::ChunkShape> << ", " 18 | << dim_size<1, typename TileIterator::ChunkShape> << "], sc0 = " 19 | << TileIterator::sc0 << ", sc1 = " << TileIterator::sc1; 20 | } 21 | }; 22 | } // namespace detail 23 | 24 | /// @brief `SharedTileIterator` chunks a shared memory tile into smaller tiles 25 | /// and iterates over these smaller sub-tiles. 26 | /// @tparam Tile_: The type of the large tile to chunk. 27 | /// @tparam ChunkShape_: The shape of the smaller tiles into which the large 28 | /// tile is partitioned (chunk shape). 
29 | template 30 | class GTileIterator { 31 | public: 32 | using Tile = Tile_; 33 | using DType = Tile::DType; 34 | using ChunkShape = ChunkShape_; 35 | 36 | static_assert(Tile::kRows >= dim_size<0, ChunkShape>, 37 | "Tile::kRows must be >= dim_size<0, ChunkShape>"); 38 | static_assert(Tile::kCols >= dim_size<1, ChunkShape>, 39 | "Tile::kCols must be >= dim_size<1, ChunkShape>"); 40 | 41 | static constexpr int kStride0 = dim_size<0, ChunkShape>; 42 | static constexpr int kStride1 = dim_size<1, ChunkShape>; 43 | 44 | static constexpr int sc0 = Tile::kRows / kStride0; 45 | static constexpr int sc1 = Tile::kCols / kStride1; 46 | 47 | HOST_DEVICE GTileIterator() : data_(nullptr) {} 48 | 49 | DEVICE GTileIterator(DType* data) : data_(data) {} 50 | 51 | DEVICE GTileIterator(const DType* data) : data_(const_cast(data)) {} 52 | 53 | // Since a Tile is considered to be at most a 2D array, the iterator 54 | // traverses over these two dimensions. The current rules are: 55 | // 1. If the index is a 2D integer, this access is considered to be a 56 | // single tile, hence it returns a Tile. 57 | // 2. If any part of the index is an underscore, this access is 58 | // considered to be a slice, naturally it returns a TileIterator. 59 | DEVICE auto operator()(int i) { 60 | assert(data_); // The iterator is not initialized. 61 | static_assert(sc0 == 1 || sc1 == 1, 62 | "A single index is supported only when the strip count " 63 | "of one of the iterator's dimensions is 1."); 64 | 65 | int x = sc0 == 1 ? 0 : i; 66 | int y = sc0 == 1 ? i : 0; 67 | 68 | using TileLayout = 69 | decltype(tl::make_tile_layout()); 71 | using NewTile = GlobalTile; 72 | 73 | int offset = Tile::kType == tl::Layout::kRowMajor 74 | ? x * (kStride0 * Tile::kRowStride) + y * kStride1 75 | : x * kStride0 + y * (Tile::kColStride * kStride1); 76 | 77 | NewTile tile(data_ + offset); 78 | 79 | return tile; 80 | } 81 | 82 | DEVICE auto operator()(int x, int y) { 83 | assert(data_); // The iterator is not initialized. 84 | assert(x < sc0 && y < sc1); // indices must be within the strip count. 85 | 86 | using TileLayout = 87 | decltype(tl::make_tile_layout()); 89 | using NewTile = GlobalTile; 90 | 91 | int offset = Tile::kType == tl::Layout::kRowMajor 92 | ? x * (kStride0 * Tile::kCols) + y * kStride1 93 | : x * kStride0 + y * (Tile::kRows * kStride1); 94 | NewTile tile(data_ + offset); 95 | 96 | return tile; 97 | } 98 | 99 | DEVICE auto operator()(int x, const Underscore& y) { 100 | assert(data_); // The iterator is not initialized. 101 | assert(x < sc0); // index must be within the strip count. 102 | 103 | // Updated the layout for sub-tiles accessed by the sliced iterator. 104 | // Note: Only the shape changes; the stride remains the same. 105 | using TileLayout = decltype(tl::make_tile_layout()); 108 | 109 | using NewTile = GlobalTile; 110 | using Iter = GTileIterator; 111 | static_assert(Iter::sc0 == 1); 112 | 113 | // advance pointer to the correct start position 114 | int offset = Tile::kType == tl::Layout::kRowMajor 115 | ? x * (kStride0 * Tile::kCols) 116 | : x * kStride0; 117 | 118 | Iter iter(data_ + offset); 119 | return iter; 120 | } 121 | 122 | DEVICE auto operator()(const Underscore& x, int y) { 123 | assert(data_); // The iterator is not initialized. 124 | assert(y < sc1); // index must be within the strip count. 125 | 126 | // Updated the layout for sub-tiles accessed by the sliced iterator. 127 | // Note: Only the shape changes; the stride remains the same. 
128 | using TileLayout = decltype(tl::make_tile_layout()); 131 | 132 | using NewTile = GlobalTile; 133 | using Iter = GTileIterator; 134 | static_assert(Iter::sc1 == 1); 135 | 136 | // advance pointer to the correct start position 137 | int offset = Tile::kType == tl::Layout::kRowMajor 138 | ? y * kStride1 139 | : y * (Tile::kRows * kStride1); 140 | 141 | Iter iter(data_ + offset); 142 | return iter; 143 | } 144 | 145 | DEVICE auto to_tile() { 146 | Tile tile(data_); 147 | return tile; 148 | } 149 | 150 | private: 151 | DType* data_; 152 | }; 153 | 154 | /// @brief Pretty printer for the static shape information of a TileIterator. 155 | /// Note: This printer function works ONLY on the host. 156 | template 157 | static HOST std::ostream& operator<<( 158 | std::ostream& out, const GTileIterator& itr) { 159 | detail::GTileIteratorPrettyPrinter::print(out, itr); 160 | return out; 161 | } 162 | 163 | } // namespace tiledcuda::cell 164 | -------------------------------------------------------------------------------- /include/types/mod.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "types/global.hpp" 4 | #include "types/global_tile_iterator.hpp" 5 | #include "types/layout.hpp" 6 | #include "types/register.hpp" 7 | #include "types/shared.hpp" 8 | #include "types/shared_tile_iterator.hpp" 9 | #include "types/tile_shape.hpp" 10 | -------------------------------------------------------------------------------- /include/types/register.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cuda_utils.hpp" 4 | #include "types/layout.hpp" 5 | #include "util/print.hpp" 6 | 7 | namespace tiledcuda::cell { 8 | namespace tl = tile_layout; 9 | 10 | namespace detail { 11 | 12 | namespace { 13 | template 14 | constexpr int get_rows = DType::kRows; 15 | 16 | template <> 17 | constexpr int get_rows = 1; 18 | 19 | template <> 20 | constexpr int get_rows<__half> = 1; 21 | 22 | template <> 23 | constexpr int get_rows = 1; 24 | 25 | template 26 | constexpr int get_cols = DType::kCols; 27 | 28 | template <> 29 | constexpr int get_cols = 1; 30 | 31 | template <> 32 | constexpr int get_cols<__half> = 1; 33 | 34 | template <> 35 | constexpr int get_cols = 1; 36 | } // namespace 37 | 38 | /// @brief Helper for pretty printing a register tile's static shape 39 | /// information. This printer works ONLY on the host. 40 | struct RegTilePrettyPrinter { 41 | template 42 | static HOST void print(std::ostream& out, const Tile& tile) { 43 | out << layout_type_to_str(Tile::kType) << "[" 44 | << Tile::kRows * get_rows << ", " 45 | << Tile::kCols * get_cols << "]"; 46 | } 47 | }; 48 | 49 | DEVICE void clear(float* data, int numel) { 50 | memset((void*)data, 0, sizeof(float) * numel); 51 | } 52 | 53 | DEVICE void clear(__half* data, int numel) { 54 | memset((void*)data, 0, sizeof(__half) * numel); 55 | } 56 | 57 | template 58 | DEVICE void clear(DType* data, int numel) { 59 | for (int i = 0; i < numel; ++i) { 60 | clear(data[i].mutable_data(), 8); 61 | } 62 | } 63 | } // namespace detail 64 | 65 | template 66 | class RegTile { 67 | public: 68 | using DType = Element_; 69 | using Layout = Layout_; 70 | 71 | static constexpr int kNumel = tl::get_numel; 72 | static constexpr int kRows = tl::num_rows; 73 | static constexpr int kCols = tl::num_cols; 74 | 75 | // FIXME(haruhi): this is a hack to fix the layout type deduction for when 76 | // the shape is 1x1. This is a workaround. 
Fix this to be more robust. 77 | static constexpr tl::Layout kType = tl::layout_type; 78 | 79 | DEVICE RegTile() : layout_(Layout{}) { 80 | memset((void*)data_, 0, sizeof(data_)); 81 | } 82 | 83 | DEVICE DType* mutable_data() { return (DType*)data_; } 84 | 85 | DEVICE const DType* data() const { return (DType*)data_; } 86 | 87 | DEVICE const Layout& layout() const { return layout_; } 88 | 89 | // for write access 90 | DEVICE DType& operator()(int x, int y) { return data_[layout_(x, y)]; } 91 | 92 | // for read access 93 | DEVICE const DType& operator()(int x, int y) const { 94 | return data_[layout_(x, y)]; 95 | } 96 | 97 | DEVICE void dump_value() const { 98 | print_tile(const_cast(data_), layout_); 99 | } 100 | 101 | DEVICE void clear() { detail::clear(data_, kNumel); } 102 | 103 | private: 104 | DType data_[kNumel]; 105 | Layout layout_; 106 | }; 107 | 108 | template 109 | using BaseTileRowMajor = RegTile>; 110 | 111 | template 112 | using BaseTileColMajor = RegTile>; 113 | 114 | /// @brief Pretty printer for the static shape information of a register tile. 115 | /// Note: This printer function works ONLY on the host. The current 116 | /// implementation prints a flattened layout and only displays the outer 117 | /// name of the tile layout. 118 | /// @tparam T: element type, which must be a `RegTile` rather than a basic 119 | /// element type like float 120 | /// @tparam Layout: tile layout 121 | template 122 | static HOST std::ostream& operator<<(std::ostream& out, 123 | const RegTile& tile) { 124 | detail::RegTilePrettyPrinter::print(out, tile); 125 | return out; 126 | } 127 | 128 | } // namespace tiledcuda::cell 129 | -------------------------------------------------------------------------------- /include/types/shared.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "types/layout.hpp" 4 | #include "util/print.hpp" 5 | 6 | namespace tiledcuda::cell { 7 | namespace tl = tile_layout; 8 | 9 | template 10 | class SharedTile { 11 | public: 12 | using DType = Element_; 13 | using Layout = Layout_; 14 | 15 | static constexpr int kNumel = tl::get_numel; 16 | 17 | static constexpr int kRows = tl::num_rows; 18 | static constexpr int kCols = tl::num_cols; 19 | 20 | static constexpr int kRowStride = tl::row_stride; 21 | static constexpr int kColStride = tl::col_stride; 22 | 23 | static constexpr tl::Layout kType = tl::layout_type; 24 | static constexpr bool kSwizzled = kSwizzled_; 25 | 26 | DEVICE SharedTile(DType* data) : data_(data), layout_(Layout{}) {} 27 | 28 | DEVICE DType* mutable_data() { return data_; } 29 | 30 | DEVICE const DType* data() const { return data_; } 31 | 32 | HOST_DEVICE const Layout& layout() const { return layout_; } 33 | 34 | // for write access 35 | DEVICE DType& operator()(int x, int y) { return data_[layout_(x, y)]; } 36 | 37 | // for read access 38 | DEVICE 39 | const DType& operator()(int x, int y) const { return data_[layout_(x, y)]; } 40 | 41 | DEVICE void dump_value() { print_tile(data_, layout_); } 42 | 43 | private: 44 | DType* data_; 45 | Layout layout_; 46 | }; 47 | } // namespace tiledcuda::cell 48 | -------------------------------------------------------------------------------- /include/types/shared_tile_iterator.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cell/traits/base.hpp" 4 | #include "types/shared.hpp" 5 | #include "types/tile_shape.hpp" 6 | 7 | namespace tiledcuda::cell { 8 | namespace tl = 
tile_layout; 9 | 10 | using namespace cute; 11 | 12 | namespace detail { 13 | /// @brief Helper for pretty printing a tile iterator's static shape-related 14 | /// information. This printer works ONLY on the host. 15 | struct STileIteratorPrettyPrinter { 16 | template 17 | static HOST void print(std::ostream& out, const TileIterator& itr) { 18 | out << "numel = " << TileIterator::Tile::kNumel << ", ChunkShape[" 19 | << dim_size<0, typename TileIterator::ChunkShape> << ", " 20 | << dim_size<1, typename TileIterator::ChunkShape> << "], sc0 = " 21 | << TileIterator::sc0 << ", sc1 = " << TileIterator::sc1; 22 | } 23 | }; 24 | } // namespace detail 25 | 26 | /// @brief `SharedTileIterator` chunks a shared memory tile into smaller tiles 27 | /// and iterates over these smaller sub-tiles. 28 | /// @tparam Tile_: The type of the large tile to chunk. 29 | /// @tparam ChunkShape_: The shape of the smaller tiles into which the large 30 | /// tile is partitioned (chunk shape). 31 | template 32 | class STileIterator { 33 | public: 34 | using Tile = Tile_; 35 | using DType = Tile::DType; 36 | using ChunkShape = ChunkShape_; 37 | using BaseShape = traits::BaseTileShape; 38 | 39 | static_assert(Tile::kRows >= dim_size<0, ChunkShape>, 40 | "Tile::kRows must be >= dim_size<0, ChunkShape>"); 41 | static_assert(Tile::kCols >= dim_size<1, ChunkShape>, 42 | "Tile::kCols must be >= dim_size<1, ChunkShape>"); 43 | 44 | static constexpr int kChunkRow = dim_size<0, ChunkShape>; 45 | static constexpr int kChunkCol = dim_size<1, ChunkShape>; 46 | 47 | static constexpr int sc0 = Tile::kRows / kChunkRow; 48 | static constexpr int sc1 = Tile::kCols / kChunkCol; 49 | 50 | HOST_DEVICE STileIterator() : data_(nullptr) {} 51 | 52 | DEVICE STileIterator(DType* data) : data_(data) {} 53 | 54 | DEVICE STileIterator(const DType* data) : data_(const_cast(data)) {} 55 | 56 | // Since a Tile is considered to be at most a 2D array, the iterator 57 | // traverses over these two dimensions. The current rules are: 58 | // 1. If the index is a 2D integer, this access is considered to be a 59 | // single tile, hence it returns a Tile. 60 | // 2. If any part of the index is an underscore, this access is 61 | // considered to be a slice, naturally it returns a TileIterator. 62 | DEVICE auto operator()(int i) { 63 | assert(data_); // The iterator is not initialized. 64 | static_assert(sc0 == 1 || sc1 == 1, 65 | "A single index is supported only when the strip count " 66 | "of one of the iterator's dimensions is 1."); 67 | 68 | int x = sc0 == 1 ? 0 : i; 69 | int y = sc0 == 1 ? i : 0; 70 | 71 | using TileLayout = 72 | decltype(tl::make_shared_tile_layout()); 75 | 76 | using NewTile = SharedTile; 77 | 78 | int offset1 = x * (kChunkRow * Tile::kRowStride) + 79 | y * kTilePerChunkCol * BaseShape::kNumel; 80 | int offset2 = x * kTilePerChunkRow * BaseShape::kNumel + 81 | y * (Tile::kColStride * kChunkCol); 82 | int offset = Tile::kType == tl::Layout::kRowMajor ? 
offset1 : offset2; 83 | 84 | NewTile tile(data_ + offset); 85 | return tile; 86 | } 87 | 88 | DEVICE auto operator()(int x, int y) { 89 | assert(false && "Not implemented yet."); 90 | return 0; 91 | } 92 | 93 | DEVICE auto operator()(int x, const Underscore& y) { 94 | assert(false && "Not implemented yet."); 95 | return 0; 96 | } 97 | 98 | DEVICE auto operator()(const Underscore& x, int y) { 99 | assert(false && "Not implemented yet."); 100 | return 0; 101 | } 102 | 103 | DEVICE auto to_tile() { 104 | Tile tile(data_); 105 | return tile; 106 | } 107 | 108 | private: 109 | static constexpr int kTilePerRow = Tile::kRows / BaseShape::kRows; 110 | static constexpr int kTilePerCol = Tile::kCols / BaseShape::kCols; 111 | 112 | static constexpr int kTilePerChunkRow = kChunkRow / BaseShape::kRows; 113 | static constexpr int kTilePerChunkCol = kChunkCol / BaseShape::kCols; 114 | 115 | // The shared memory tile iterator creates a sub-tile that spans multiple 116 | // `BaseTile`s. The row and column strides are used to address a single 117 | // `BaseTile`. DO NOT modify these unless you fully understand how this 118 | // layout is used with the Shared to Register loader, as changes might 119 | // cause significant errors. 120 | static constexpr int kTileRowStride = Tile::kType == tl::Layout::kRowMajor 121 | ? kTilePerCol * BaseShape::kNumel 122 | : BaseShape::kNumel; 123 | 124 | static constexpr int kTileColStride = Tile::kType == tl::Layout::kRowMajor 125 | ? BaseShape::kNumel 126 | : kTilePerRow * BaseShape::kNumel; 127 | 128 | DType* data_; 129 | }; 130 | 131 | /// @brief Pretty printer for the static shape information of a TileIterator. 132 | /// Note: This printer function works ONLY on the host. 133 | template 134 | static HOST std::ostream& operator<<( 135 | std::ostream& out, const STileIterator& itr) { 136 | detail::STileIteratorPrettyPrinter::print(out, itr); 137 | return out; 138 | } 139 | 140 | } // namespace tiledcuda::cell 141 | -------------------------------------------------------------------------------- /include/types/tile_shape.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | namespace tiledcuda::cell { 7 | 8 | template 9 | struct TileShape { 10 | static constexpr cute::array shape = {Ns...}; 11 | 12 | static constexpr size_t get_numel() { 13 | size_t product = 1; 14 | for (size_t n : shape) product *= n; 15 | return product; 16 | } 17 | 18 | static constexpr size_t kNumel = get_numel(); 19 | }; 20 | 21 | template 22 | inline static constexpr int64_t get_numel = TileShape::kNumel; 23 | 24 | template 25 | inline static constexpr size_t dim_size = cute::get(TileShape::shape); 26 | 27 | struct Underscore {}; // dummy type for underscore 28 | static const __device__ Underscore _; // for slicing 29 | } // namespace tiledcuda::cell 30 | -------------------------------------------------------------------------------- /include/util/cuda_timer.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cuda_utils.hpp" 4 | 5 | namespace tiledcuda { 6 | 7 | /// @brief: Cuda timer to measure the time taken by a kernel. 8 | /// Usage: 9 | /// CudaTimer timer; 10 | /// timer.start(); 11 | /// ... 
12 | /// float time = timer.stop(); 13 | class CudaTimer { 14 | public: 15 | CudaTimer() { 16 | CudaCheck(cudaEventCreate(&start_event)); 17 | CudaCheck(cudaEventCreate(&stop_event)); 18 | } 19 | 20 | ~CudaTimer() { 21 | CudaCheck(cudaEventDestroy(start_event)); 22 | CudaCheck(cudaEventDestroy(stop_event)); 23 | } 24 | 25 | void start(cudaStream_t st = 0) { 26 | stream = st; 27 | CudaCheck(cudaEventRecord(start_event, stream)); 28 | } 29 | 30 | float stop() { 31 | float milliseconds = 0.; 32 | CudaCheck(cudaEventRecord(stop_event, stream)); 33 | CudaCheck(cudaEventSynchronize(stop_event)); 34 | CudaCheck(cudaEventElapsedTime(&milliseconds, start_event, stop_event)); 35 | return milliseconds; 36 | } 37 | 38 | private: 39 | cudaEvent_t start_event; 40 | cudaEvent_t stop_event; 41 | cudaStream_t stream; 42 | }; 43 | 44 | } // namespace tiledcuda 45 | -------------------------------------------------------------------------------- /include/util/debug.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "config.hpp" 4 | 5 | #include 6 | 7 | namespace tiledcuda { 8 | 9 | DEVICE bool block(int bid) { 10 | int id = blockIdx.x + blockIdx.y * gridDim.x + 11 | blockIdx.z * gridDim.x * gridDim.y; 12 | return id == bid; 13 | } 14 | 15 | DEVICE bool thread(int tid, int bid) { 16 | int id = threadIdx.x + threadIdx.y * blockDim.x + 17 | threadIdx.z * blockDim.x * blockDim.y; 18 | return id == tid && block(bid); 19 | } 20 | 21 | // usage, e.g. 22 | // if (thread(0, 0)) { ... } 23 | // if (thread(37)) { ... } 24 | // if (block(0)) { ... } 25 | 26 | DEVICE bool thread(int tid) { return thread(tid, 0); } 27 | 28 | DEVICE bool thread0() { return thread(0, 0); } 29 | 30 | DEVICE bool block0() { return block(0); } 31 | 32 | } // namespace tiledcuda 33 | -------------------------------------------------------------------------------- /include/util/print.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "types/layout.hpp" 4 | #include "util/debug.hpp" 5 | 6 | namespace tiledcuda::cell { 7 | namespace tl = tile_layout; 8 | 9 | /// @brief Print a tile of single-precision floating point numbers. NOTE: when 10 | // use print in the device function, do add (if(thread0())) to avoid printing 11 | // multiple times by multiple threads. usage: 12 | // if(thread0()) { 13 | // print_tile(data, layout); 14 | // } 15 | template 16 | DEVICE void print_tile(const float* data, const Layout& layout) { 17 | for (int i = 0; i < tl::num_rows; ++i) { 18 | for (int j = 0; j < tl::num_cols; ++j) { 19 | printf("%.3f, ", data[layout(i, j)]); 20 | } 21 | printf("\n"); 22 | 23 | if (i && (i + 1) % 16 == 0) printf("\n"); 24 | } 25 | } 26 | 27 | /// @brief Print a tile of half-precision floating point numbers. 28 | template 29 | DEVICE void print_tile(const cutlass::half_t* data, const Layout& layout) { 30 | const half* data_ = reinterpret_cast(data); 31 | 32 | for (int i = 0; i < tl::num_rows; ++i) { 33 | for (int j = 0; j < tl::num_cols; ++j) { 34 | printf("%.3f, ", __half2float(data_[layout(i, j)])); 35 | } 36 | printf("\n"); 37 | 38 | if (i && (i + 1) % 16 == 0) printf("\n"); 39 | } 40 | } 41 | 42 | /// @brief Print a tile of half-precision floating point numbers. 
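/// This overload mirrors the `cutlass::half_t` version above, but takes raw
/// `__half` data. A hedged usage sketch (the tile object is illustrative and
/// not defined in this header):
///   if (thread0()) { print_tile(tile.data(), tile.layout()); }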
43 | template 44 | DEVICE void print_tile(const __half* data, const Layout& layout) { 45 | for (int i = 0; i < tl::num_rows; ++i) { 46 | for (int j = 0; j < tl::num_cols; ++j) { 47 | printf("%.3f, ", __half2float(data[layout(i, j)])); 48 | } 49 | printf("\n"); 50 | 51 | if (i && (i + 1) % 16 == 0) printf("\n"); 52 | } 53 | } 54 | 55 | /// @brief Print a register tile. Since register tile is a nested array-like 56 | /// structure. printing resigter tile hits this function. 57 | template 58 | DEVICE void print_tile(const DType* data, const Layout& layout) { 59 | for (int i = 0; i < tl::num_rows; ++i) { 60 | for (int j = 0; j < tl::num_cols; ++j) { 61 | auto tile = data[layout(i, j)]; 62 | print_tile(tile.data(), tile.layout()); 63 | } 64 | } 65 | } 66 | 67 | template 68 | struct RegVecPrinter { 69 | static constexpr int kRows = RegTile::kRows; 70 | 71 | DEVICE void operator()(const RegTile& tile, int tid) { 72 | int lane_id = tid % 32; 73 | for (int i = 0; i < kRows; ++i) { 74 | if (lane_id % 4 == 0) { 75 | printf("%.3f, ", __half2float(tile(i, 0))); 76 | } 77 | 78 | #if defined(__CUDA_ARCH__) 79 | // Sync Threads to print in-order data. 80 | __syncthreads(); 81 | #endif 82 | if (lane_id % 4 == 0) { 83 | printf("%.3f, ", __half2float(tile(i, 1))); 84 | } 85 | } 86 | 87 | if (lane_id == 0) printf("\n"); 88 | } 89 | }; 90 | 91 | template 92 | struct RegTilePrinter { 93 | constexpr static int kRows = RegTile::kRows; 94 | constexpr static int kCols = RegTile::kCols; 95 | 96 | void operator()(const RegTile& tile, int tid) {} 97 | }; 98 | 99 | template 100 | struct RegTilePrinter { 101 | constexpr static int kRows = RegTile::kRows; 102 | constexpr static int kCols = RegTile::kCols; 103 | 104 | using DType = typename RegTile::DType::DType; 105 | 106 | DEVICE void print_tile_col(const RegTile& tile, int lane_id, int row_num, 107 | bool is_top) { 108 | for (int col_num = 0; col_num < kCols; ++col_num) { 109 | if (is_top) { 110 | printf("%.3f, %.3f, ", 111 | __half2float(tile(row_num, col_num)(0, 0)), 112 | __half2float(tile(row_num, col_num)(0, 1))); 113 | printf("%.3f, %.3f, ", 114 | __half2float(tile(row_num, col_num)(1, 0)), 115 | __half2float(tile(row_num, col_num)(1, 1))); 116 | } else { 117 | printf("%.3f, %.3f, ", 118 | __half2float(tile(row_num, col_num)(0, 2)), 119 | __half2float(tile(row_num, col_num)(0, 3))); 120 | printf("%.3f, %.3f, ", 121 | __half2float(tile(row_num, col_num)(1, 2)), 122 | __half2float(tile(row_num, col_num)(1, 3))); 123 | } 124 | } 125 | if (lane_id % 4 == 0) printf("\n"); 126 | } 127 | 128 | DEVICE void operator()(const RegTile& tile, int tid) { 129 | // BaseTile base_tile; 130 | int lane_id = tid % 32; 131 | for (int i = 0; i < kRows; ++i) { 132 | // Print top row. 
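            // Note: every branch below issues the same call; the per-lane-group
            // branching appears intended only to make the 4-lane grouping
            // explicit. Assuming the usual MMA accumulator fragment layout,
            // elements (r, 0)/(r, 1) of each thread's fragment come from the
            // upper eight rows of a 16x16 BaseTile ("top") and (r, 2)/(r, 3)
            // from the lower eight rows ("bottom"); this is our reading of the
            // layout, not something asserted by this code.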
133 | if (lane_id >= 0 && lane_id <= 3) 134 | print_tile_col(tile, lane_id, i, true); 135 | else if (lane_id >= 4 && lane_id <= 7) 136 | print_tile_col(tile, lane_id, i, true); 137 | else if (lane_id >= 8 && lane_id <= 11) 138 | print_tile_col(tile, lane_id, i, true); 139 | else if (lane_id >= 12 && lane_id <= 15) 140 | print_tile_col(tile, lane_id, i, true); 141 | else if (lane_id >= 16 && lane_id <= 19) 142 | print_tile_col(tile, lane_id, i, true); 143 | else if (lane_id >= 20 && lane_id <= 23) 144 | print_tile_col(tile, lane_id, i, true); 145 | else if (lane_id >= 24 && lane_id <= 27) 146 | print_tile_col(tile, lane_id, i, true); 147 | else if (lane_id >= 28 && lane_id <= 31) 148 | print_tile_col(tile, lane_id, i, true); 149 | 150 | #if defined(__CUDA_ARCH__) 151 | // Sync Threads to print in-order data. 152 | __syncthreads(); 153 | #endif 154 | 155 | // Print bottom row. 156 | if (lane_id >= 0 && lane_id <= 3) 157 | print_tile_col(tile, lane_id, i, false); 158 | else if (lane_id >= 4 && lane_id <= 7) 159 | print_tile_col(tile, lane_id, i, false); 160 | else if (lane_id >= 8 && lane_id <= 11) 161 | print_tile_col(tile, lane_id, i, false); 162 | else if (lane_id >= 12 && lane_id <= 15) 163 | print_tile_col(tile, lane_id, i, false); 164 | else if (lane_id >= 16 && lane_id <= 19) 165 | print_tile_col(tile, lane_id, i, false); 166 | else if (lane_id >= 20 && lane_id <= 23) 167 | print_tile_col(tile, lane_id, i, false); 168 | else if (lane_id >= 24 && lane_id <= 27) 169 | print_tile_col(tile, lane_id, i, false); 170 | else if (lane_id >= 28 && lane_id <= 31) 171 | print_tile_col(tile, lane_id, i, false); 172 | } 173 | if (lane_id == 0) printf("\n"); 174 | } 175 | }; 176 | 177 | } // namespace tiledcuda::cell 178 | -------------------------------------------------------------------------------- /pytiledcuda/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/* 2 | -------------------------------------------------------------------------------- /pytiledcuda/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | torch.ops.load_library("build/src/libtiledcuda.so") 4 | 5 | 6 | def scatter_nd(scatter_data, scatter_indices, scatter_updates): 7 | torch.ops.tiledcuda.scatter_nd(scatter_data, scatter_updates, 8 | scatter_indices) 9 | 10 | 11 | def flash_attention_fwd(Q, K, V, O, m, n, k, p): 12 | torch.ops.tiledcuda.flash_attention_fwd(Q, K, V, O, m, n, k, p) 13 | 14 | 15 | class TiledFlashAttention(): 16 | 17 | def __init__(self, query, key, value): 18 | self.m, self.k = query.size(-2), query.size(-1) 19 | self.n, self.p = value.size(-2), value.size(-1) 20 | 21 | self.query = query.half().flatten() 22 | # TODO(KuangjuX): To simplify the usage of the kernel, 23 | # we treat K as k.Transpose. 
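        # (Both key and value are converted to half precision, transposed, and
        #  flattened to 1-D tensors before being handed to the kernel; query is
        #  only flattened, and the flat half output buffer is reshaped to
        #  (m, p) in forward().)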
24 | self.key = key.half().t().flatten() 25 | self.value = value.half().t().flatten() 26 | 27 | self.output = torch.empty(self.m, 28 | self.p, 29 | dtype=torch.half, 30 | device='cuda').flatten() 31 | 32 | def forward(self) -> torch.Tensor: 33 | flash_attention_fwd(self.query, self.key, self.value, self.output, 34 | self.m, self.n, self.k, self.p) 35 | 36 | return self.output.view(self.m, self.p) 37 | -------------------------------------------------------------------------------- /scripts/clang_format.hook: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | readonly VERSION="12.0.0" 5 | 6 | version=$(clang-format -version) 7 | 8 | if ! [[ $version == *"$VERSION"* ]]; then 9 | echo "clang-format version check failed." 10 | echo "a version contains '$VERSION' is needed, but get '$version'" 11 | echo "you can install the right version, and make an soft-link to '\$PATH' env" 12 | exit -1 13 | fi 14 | 15 | clang-format $@ 16 | -------------------------------------------------------------------------------- /scripts/cmake/dependencies.cmake: -------------------------------------------------------------------------------- 1 | # set the third party directory for dependencies that do not need a build 2 | set(THIRD_PARTY_DIR 3 | "${PROJECT_SOURCE_DIR}/3rd-party" 4 | CACHE STRING 5 | "A path that specifies the directory for third-party downloads.") 6 | 7 | # set the third party build directory for dependencies that need a build 8 | set(THIRD_PARTY_BUILD_DIR 9 | "${CMAKE_BINARY_DIR}/3rd-party" 10 | CACHE STRING 11 | "A path that specifies the directory for third-party build & install." 12 | ) 13 | 14 | set(THIRD_PARTY_BUILD_TYPE Release) 15 | 16 | # add cutlass into dependence 17 | include_directories(${THIRD_PARTY_DIR}/cutlass/include) 18 | 19 | # add googletest into dependence 20 | set(INSTALL_GTEST 21 | OFF 22 | CACHE BOOL "Install gtest." FORCE) 23 | add_subdirectory(${THIRD_PARTY_DIR}/googletest) 24 | include_directories(BEFORE SYSTEM 25 | ${THIRD_PARTY_DIR}/googletest/googletest/include) 26 | 27 | # add glog into dependence 28 | include(${PROJECT_SOURCE_DIR}/scripts/cmake/public/glog.cmake) 29 | -------------------------------------------------------------------------------- /scripts/cmake/external/glog.cmake: -------------------------------------------------------------------------------- 1 | include(ExternalProject) 2 | 3 | set(GLOG_REPOSITORY https://github.com/google/glog.git) 4 | set(GLOG_TAG v0.7.1) 5 | 6 | set(GLOG_PREFIX_DIR ${THIRD_PARTY_BUILD_DIR}/glog) 7 | set(GLOG_SOURCE_DIRS ${GLOG_PREFIX_DIR}/src/extern_glog) 8 | set(GLOG_INSTALL_DIR ${THIRD_PARTY_BUILD_DIR}/install/glog) 9 | 10 | set(GLOG_INCLUDE_DIRS 11 | "${GLOG_INSTALL_DIR}/include" 12 | CACHE PATH "glog include directory." FORCE) 13 | set(GLOG_LIBRARIES 14 | "${GLOG_INSTALL_DIR}/lib/libglog.so" 15 | CACHE FILEPATH "glog library." 
FORCE) 16 | 17 | ExternalProject_Add( 18 | extern_glog 19 | GIT_REPOSITORY ${GLOG_REPOSITORY} 20 | GIT_TAG ${GLOG_TAG} 21 | GIT_SHALLOW TRUE 22 | PREFIX ${GLOG_PREFIX_DIR} 23 | SOURCE_DIR ${GLOG_SOURCE_DIR} 24 | INSTALL_DIR ${GLOG_INSTALL_DIR} 25 | BUILD_IN_SOURCE 1 26 | UPDATE_COMMAND "" 27 | CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${GLOG_INSTALL_DIR} 28 | -DCMAKE_INSTALL_LIBDIR=${GLOG_INSTALL_DIR}/lib 29 | -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE} 30 | -DCMAKE_POSITION_INDEPENDENT_CODE=ON 31 | -DWITH_GFLAGS=OFF 32 | -DWITH_GTEST=OFF 33 | -DBUILD_TESTING=OFF) 34 | 35 | add_library(glog::glog INTERFACE IMPORTED) 36 | include_directories(glog::glog BEFORE SYSTEM ${GLOG_INCLUDE_DIRS}) 37 | target_link_libraries(glog::glog INTERFACE ${GLOG_LIBRARIES}) 38 | add_dependencies(glog::glog extern_glog) 39 | -------------------------------------------------------------------------------- /scripts/cmake/generic.cmake: -------------------------------------------------------------------------------- 1 | set(CMAKE_BUILD_TYPE Release) 2 | 3 | set(CMAKE_CXX_STANDARD 4 | 20 5 | CACHE STRING "The C++ standard whoese features are requested." FORCE) 6 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 7 | 8 | set(CMAKE_CUDA_STANDARD 9 | 20 10 | CACHE STRING "The CUDA standard whose features are requested." FORCE) 11 | set(CMAKE_CUDA_STANDARD_REQUIRED ON) 12 | 13 | # Set host compiler flags. Enable all warnings and treat them as errors 14 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wall") 15 | 16 | find_package(CUDAToolkit QUIET REQUIRED) 17 | enable_language(CUDA) 18 | set(CMAKE_CUDA on) 19 | 20 | find_package(Python3 REQUIRED COMPONENTS Interpreter) 21 | message(STATUS "Python interpreter path: ${Python3_EXECUTABLE}") 22 | 23 | set(TORCH_LIB_PREFIX "${Python3_SITEARCH}/torch") 24 | if(NOT EXISTS ${TORCH_LIB_PREFIX}) 25 | message(FATAL_ERROR "Torch library is not installed.") 26 | else() 27 | list(APPEND CMAKE_PREFIX_PATH "${TORCH_LIB_PREFIX}/share/cmake/Torch") 28 | endif() 29 | find_package(Torch REQUIRED) 30 | 31 | # let cmake automatically detect the current CUDA architecture to avoid 32 | # generating device codes for all possible architectures 33 | set(CMAKE_CUDA_ARCHITECTURES OFF) 34 | set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --Werror all-warnings") 35 | # Set the CUDA_PROPAGATE_HOST_FLAGS to OFF to avoid passing host compiler flags 36 | # to the device compiler 37 | set(CUDA_PROPAGATE_HOST_FLAGS OFF) 38 | 39 | # FIXME(haruhi): -std=c++20 has to be set explicitly here, Otherwise, linking 40 | # against torchlibs will raise errors. it seems that the host compilation 41 | # options are not passed to torchlibs. 
42 | set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} -std=c++20) 43 | set(CUDA_NVCC_FLAGS_DEBUG ${CUDA_NVCC_FLAGS_DEBUG} -std=c++20 -O0) 44 | set(CUDA_NVCC_FLAGS_RELEASE ${CUDA_NVCC_FLAGS_RELEASE} -std=c++20 -O3) 45 | 46 | message(STATUS "TiledCUDA: CUDA detected: " ${CUDA_VERSION}) 47 | message(STATUS "TiledCUDA: CUDA nvcc is: " ${CUDA_NVCC_EXECUTABLE}) 48 | message(STATUS "TiledCUDA: CUDA toolkit directory: " ${CUDA_TOOLKIT_ROOT_DIR}) 49 | 50 | if(ENABLE_DEBUG) 51 | message(STATUS "TiledCUDA: Debug mode enabled") 52 | set(CMAKE_BUILD_TYPE Debug) 53 | set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -DDEBUG") 54 | set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -DDEBUG") 55 | endif() 56 | 57 | function(cuda_test TARGET_NAME) 58 | set(oneValueArgs "") 59 | set(multiValueArgs SRCS DEPS) 60 | cmake_parse_arguments(cuda_test "${options}" "${oneValueArgs}" 61 | "${multiValueArgs}" ${ARGN}) 62 | 63 | list(APPEND UT_SRCS "${PROJECT_SOURCE_DIR}/tests/cpp/test_unit.cc" 64 | "${PROJECT_SOURCE_DIR}/src/cuda_utils.cc" 65 | "${PROJECT_SOURCE_DIR}/tests/cpp/common/test_utils.cc" ${cuda_test_SRCS}) 66 | 67 | cuda_add_executable(${TARGET_NAME} ${UT_SRCS}) 68 | target_link_libraries(${TARGET_NAME} ${cuda_test_DEPS} gtest glog::glog) 69 | add_dependencies(${TARGET_NAME} gtest glog::glog) 70 | 71 | # add a unittest into ctest with the same name as the target 72 | add_test(${TARGET_NAME} ${TARGET_NAME}) 73 | endfunction(cuda_test) 74 | -------------------------------------------------------------------------------- /scripts/cmake/public/glog.cmake: -------------------------------------------------------------------------------- 1 | set(GLOG_FIND_REQUIRED ON) 2 | 3 | # find locally installed glog 4 | find_package(glog CONFIG QUIET) # try to use the config mode first 5 | if(NOT TARGET glog::glog) 6 | find_package(glog MODULE QUIET) 7 | endif() 8 | 9 | if(TARGET glog::glog) 10 | set(GLOG_FOUND TRUE) 11 | message(STATUS "Found Glog at ${glog_DIR}") 12 | message(STATUS " Target : glog::glog") 13 | else() 14 | message( 15 | STATUS "Could not find Glog automatically with new-style glog target, " 16 | "use legacy find.") 17 | # Try to find glog manually. Older versions of glog do not include a 18 | # find_package configuration. Therefore, we must use custom logic to locate 19 | # the library and remap it to an imported target. 
20 | 21 | include(FindPackageHandleStandardArgs) 22 | 23 | # set path search hints 24 | list(APPEND GLOG_CHECK_INCLUDE_DIRS /usr/local/include /opt/local/include 25 | /usr/include) 26 | list(APPEND GLOG_CHECK_PATH_SUFFIXES glog/include glog/Include Glog/include 27 | Glog/Include) 28 | 29 | list(APPEND GLOG_CHECK_LIBRARY_DIRS /usr/local/lib /opt/local/lib /usr/lib) 30 | list( 31 | APPEND 32 | GLOG_CHECK_LIBRARY_SUFFIXES 33 | glog/lib 34 | glog/Lib 35 | glog/lib64 36 | Glog/lib 37 | Glog/Lib 38 | x64/Release) 39 | 40 | find_path( 41 | GLOG_INCLUDE_DIRS 42 | NAMES glog/logging.h 43 | PATHS ${GLOG_INCLUDE_DIR_HINTS} ${GLOG_CHECK_INCLUDE_DIRS} 44 | PATH_SUFFIXES ${GLOG_CHECK_PATH_SUFFIXES}) 45 | 46 | find_library( 47 | GLOG_LIBRARIES 48 | NAMES glog libglog 49 | PATHS ${GLOG_LIBRARY_DIR_HINTS} ${GLOG_CHECK_LIBRARY_DIRS} 50 | PATH_SUFFIXES ${GLOG_CHECK_LIBRARY_SUFFIXES}) 51 | 52 | if(GLOG_INCLUDE_DIRS AND GLOG_LIBRARIES) 53 | set(GLOG_FOUND TRUE) 54 | message(STATUS "Found Glog") 55 | message(STATUS " Includes : ${GLOG_INCLUDE_DIRS}") 56 | message(STATUS " Libraries : ${GLOG_LIBRARIES}") 57 | 58 | add_library(glog::glog INTERFACE IMPORTED) 59 | target_include_directories(glog::glog INTERFACE ${GLOG_INCLUDE_DIRS}) 60 | target_link_libraries(glog::glog INTERFACE ${GLOG_LIBRARIES}) 61 | endif() 62 | endif() 63 | 64 | if(NOT GLOG_FOUND AND GLOG_FIND_REQUIRED) 65 | # If glog is not installed locally, download and build it using 66 | # ExternalProject. 67 | include(external/glog) 68 | endif() 69 | -------------------------------------------------------------------------------- /scripts/unittests/python.sh: -------------------------------------------------------------------------------- 1 | # for file in find(); do 2 | # unit = $(basename $file .py) 3 | # make unit_test UNIT_TEST=$unit 4 | # done 5 | 6 | for file in $(find tests/python -name *.py); do 7 | unit=$(basename $file .py) 8 | make unit_test UNIT_TEST=$unit 9 | done 10 | -------------------------------------------------------------------------------- /scripts/unittests/run_all_cpp_tests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | BUILD_DIR="build" 4 | TESTS_DIR="$(pwd)/tests/cpp" 5 | 6 | if [ ! -d $BUILD_DIR ]; then 7 | echo "This script should be run from the root of the project." 8 | exit 0 9 | fi 10 | 11 | if [ $(find "$BUILD_DIR/tests/cpp" -name "test_*" | wc -l) -eq 0 ]; then 12 | echo "No cpp tests are found." 
13 | exit 0 14 | fi 15 | 16 | cd $BUILD_DIR 17 | 18 | for file in $(find "$TESTS_DIR/cell/" -name "test_*.cu"); do 19 | test_name=$(basename $file .cu) 20 | echo "Running test: $test_name" 21 | ctest -R $test_name 22 | done 23 | 24 | for file in $(find "$TESTS_DIR/kernels/" -name "test_*.cu"); do 25 | test_name=$(basename $file .cu) 26 | echo "Running test: $test_name" 27 | ctest -R $test_name 28 | done 29 | -------------------------------------------------------------------------------- /src/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | set(TARGET "tiledcuda") 2 | 3 | file(GLOB_RECURSE SOURCES "kernels/*.cu" "*.cc") 4 | 5 | # Define our library target 6 | cuda_add_library(${TARGET} SHARED ${SOURCES}) 7 | 8 | set_target_properties( 9 | ${TARGET} 10 | PROPERTIES CXX_STANDARD 20 11 | CXX_STANDARD_REQUIRED ON 12 | CXX_EXTENSIONS OFF 13 | CUDA_STANDARD 20 14 | CUDA_STANDARD_REQUIRED ON 15 | CUDA_EXTENSIONS OFF 16 | CUDA_RESOLVE_DEVICE_SYMBOLS ON 17 | CUDA_SEPARABLE_COMPILATION ON) 18 | 19 | target_compile_options( 20 | ${TARGET} PUBLIC $<$: -Werror,-Wall -rdc=true 21 | -std=c++20 -fconcepts -fpermissive>) 22 | target_compile_features(${TARGET} PUBLIC cxx_std_20 cuda_std_20) 23 | target_link_libraries(${TARGET} "${TORCH_LIBRARIES}") 24 | -------------------------------------------------------------------------------- /src/cuda_info.cc: -------------------------------------------------------------------------------- 1 | #include "cuda_info.hpp" 2 | 3 | #include 4 | #include 5 | 6 | namespace tiledcuda { 7 | // Returns the number of GPUs. 8 | int GetGPUDeviceCount() { 9 | int deviceCount = 0; 10 | CudaCheck(cudaGetDeviceCount(&deviceCount)); 11 | return deviceCount; 12 | } 13 | 14 | // Returns the compute capability of the given GPU. 15 | int GetGPUComputeCapability(int id) { 16 | int major, minor; 17 | CudaCheck( 18 | cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, id)); 19 | CudaCheck( 20 | cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, id)); 21 | return major * 10 + minor; 22 | } 23 | 24 | // Returns the number of multiprocessors for the given GPU. 25 | int GetGPUMultiProcessors(int id) { 26 | int count; 27 | CudaCheck( 28 | cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, id)); 29 | return count; 30 | } 31 | 32 | // Returns the maximum number of threads per multiprocessor for the given GPU. 33 | int GetGPUMaxThreadsPerMultiProcessor(int id) { 34 | int count; 35 | CudaCheck(cudaDeviceGetAttribute( 36 | &count, cudaDevAttrMaxThreadsPerMultiProcessor, id)); 37 | return count; 38 | } 39 | 40 | // Returns the maximum number of threads per block for the given GPU. 41 | int GetGPUMaxThreadsPerBlock(int id) { 42 | int count; 43 | CudaCheck( 44 | cudaDeviceGetAttribute(&count, cudaDevAttrMaxThreadsPerBlock, id)); 45 | return count; 46 | } 47 | 48 | // Returns the maximum grid size for the given GPU. 49 | dim3 GetGpuMaxGridDimSize(int id) { 50 | dim3 grid_size; 51 | 52 | int size; 53 | CudaCheck(cudaDeviceGetAttribute(&size, cudaDevAttrMaxGridDimX, id)); 54 | grid_size.x = size; 55 | 56 | CudaCheck(cudaDeviceGetAttribute(&size, cudaDevAttrMaxGridDimY, id)); 57 | grid_size.y = size; 58 | 59 | CudaCheck(cudaDeviceGetAttribute(&size, cudaDevAttrMaxGridDimZ, id)); 60 | grid_size.z = size; 61 | return grid_size; 62 | } 63 | 64 | // Returns the name of the device. 
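// Spaces in the reported device name are replaced with underscores, presumably
// so the result can be used directly in file or identifier names, e.g.
// "NVIDIA GeForce RTX 4090" becomes "NVIDIA_GeForce_RTX_4090" (the device name
// here is illustrative).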
65 | std::string GetDeviceName() { 66 | cudaDeviceProp prop; 67 | cudaGetDeviceProperties(&prop, 0); 68 | 69 | std::stringstream ss(prop.name); 70 | const char delim = ' '; 71 | 72 | std::string s; 73 | std::vector out; 74 | 75 | while (std::getline(ss, s, delim)) { 76 | out.push_back(s); 77 | } 78 | 79 | std::stringstream out_ss; 80 | int i = 0; 81 | for (; i < static_cast(out.size()) - 1; ++i) out_ss << out[i] << "_"; 82 | out_ss << out[i]; 83 | return out_ss.str(); 84 | } 85 | } // namespace tiledcuda 86 | -------------------------------------------------------------------------------- /src/cuda_utils.cc: -------------------------------------------------------------------------------- 1 | #include "cuda_utils.hpp" 2 | 3 | namespace tiledcuda { 4 | const char* cublasGetErrorString(cublasStatus_t status) { 5 | switch (status) { 6 | case CUBLAS_STATUS_SUCCESS: 7 | return "CUBLAS_STATUS_SUCCESS"; 8 | case CUBLAS_STATUS_NOT_INITIALIZED: 9 | return "CUBLAS_STATUS_NOT_INITIALIZED"; 10 | case CUBLAS_STATUS_ALLOC_FAILED: 11 | return "CUBLAS_STATUS_ALLOC_FAILED"; 12 | case CUBLAS_STATUS_INVALID_VALUE: 13 | return "CUBLAS_STATUS_INVALID_VALUE"; 14 | case CUBLAS_STATUS_ARCH_MISMATCH: 15 | return "CUBLAS_STATUS_ARCH_MISMATCH"; 16 | case CUBLAS_STATUS_MAPPING_ERROR: 17 | return "CUBLAS_STATUS_MAPPING_ERROR"; 18 | case CUBLAS_STATUS_EXECUTION_FAILED: 19 | return "CUBLAS_STATUS_EXECUTION_FAILED"; 20 | case CUBLAS_STATUS_INTERNAL_ERROR: 21 | return "CUBLAS_STATUS_INTERNAL_ERROR"; 22 | case CUBLAS_STATUS_NOT_SUPPORTED: 23 | return "CUBLAS_STATUS_NOT_SUPPORTED"; 24 | case CUBLAS_STATUS_LICENSE_ERROR: 25 | return "CUBLAS_STATUS_LICENSE_ERROR"; 26 | } 27 | return "unknown error"; 28 | } 29 | } // namespace tiledcuda 30 | -------------------------------------------------------------------------------- /src/kernels/scatter_nd.cu: -------------------------------------------------------------------------------- 1 | #include "kernels/scatter_nd.hpp" 2 | 3 | #include 4 | 5 | namespace tiledcuda::kernels { 6 | // reference: 7 | // https://github.com/InfiniTensor/RefactorGraph/blob/master/src/04kernel/cuda/src/scatter_nd.cu#L7 8 | // TODO: optimize the kernel by increasing the number of threads to perform 9 | // `atomic_add` operations under `slice_size`. 10 | /** 11 | * @brief The ScatterNdkernel updates the content of `updates` into `data` based 12 | * on the index information provided in the given `indices`. 13 | * 14 | * @param in The input tensor `updates`. 15 | * @param out The output tensor `data`. 16 | * @param indices The indices tensor. 17 | * @param strides record the stride information between different dimensions in 18 | * the `data` tensor. 19 | * @param n The number of indices. 20 | * @param rank The last dimension of `indices`. 21 | * @param slice_size The length of the slice to be updated. Specifically, it is 22 | * the product of the difference between the rank of `data` and the last 23 | * dimension of `indices` along the memory dimensions of `data`. 24 | */ 25 | template 26 | __global__ void scatter_nd_kernel(const T* in, T* out, const int64_t* indices, 27 | unsigned int const* __restrict__ strides, 28 | size_t n, size_t rank, size_t slice_size) { 29 | for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x, 30 | step = blockDim.x * gridDim.x; 31 | tid < n; tid += step) { 32 | if (tid < n) { 33 | // tid = indices_index 34 | unsigned int out_index = 0; 35 | // the rank of `data`. 36 | auto i = indices + tid * rank; 37 | // Compute the offset in the output. 
38 | // j = i[0] * strides[0] + i[1] * strides[1] + ... + i[k] * 39 | // strides[k] 40 | 41 | for (auto k = 0; k < rank; ++k) { 42 | out_index += i[k] * __ldg(strides + k); 43 | }; 44 | for (size_t offset = 0; offset < slice_size; ++offset) { 45 | atomicAdd(out + out_index + offset, 46 | in[tid * slice_size + offset]); 47 | } 48 | } 49 | } 50 | } 51 | 52 | template 53 | void scatter_nd(torch::Tensor& data, const torch::Tensor& updates, 54 | const torch::Tensor& indices) { 55 | auto data_dims = data.sizes(); 56 | auto update_dims = updates.sizes(); 57 | auto indices_dims = indices.sizes(); 58 | 59 | // k is the last dimension of indices. 60 | int64_t k = indices_dims[indices_dims.size() - 1]; 61 | 62 | // the rank of data. 63 | size_t rank = data_dims.size(); 64 | 65 | unsigned int* strides = new unsigned int[rank]; 66 | strides[rank - 1] = 1; 67 | 68 | for (int64_t i = rank - 2; i >= 0; --i) { 69 | strides[i] = strides[i + 1] * data_dims[i + 1]; 70 | } 71 | 72 | unsigned int* device_strides; 73 | CudaCheck(cudaMalloc(&device_strides, rank * sizeof(unsigned int))); 74 | CudaCheck(cudaMemcpy(device_strides, strides, rank * sizeof(unsigned int), 75 | cudaMemcpyHostToDevice)); 76 | 77 | // `n` is the product of all dimensions excluding the innermost 78 | // dimension of `indices`. 79 | size_t n = indices.numel() / k; 80 | 81 | size_t slice_size = 1; 82 | for (size_t i = k; i < rank; ++i) { 83 | slice_size *= data_dims[i]; 84 | } 85 | 86 | size_t data_size = data.numel(); 87 | 88 | // #ifdef DEBUG 89 | // for (int i = rank - 1; i >= 0; --i) { 90 | // std::cout << "strides[" << i << "]: " << strides[i] << std::endl; 91 | // } 92 | // for (int i = rank - 1; i >= 0; --i) { 93 | // std::cout << "data_dims[" << i << "]: " << data_dims[i] << 94 | // std::endl; 95 | // } 96 | // std::cout << "k: " << k << ", rank: " << rank << std::endl; 97 | // std::cout << "n: " << n << ", slice_size: " << slice_size << 98 | // std::endl; std::cout << "data_size: " << data_size << std::endl; 99 | // #endif 100 | 101 | // TODO: Add some assertion checks. 
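    // Worked example (ours, for illustration only): for `data` of shape
    // [2, 3, 4] and `indices` whose last dimension is k = 2, we get rank = 3,
    // strides = {12, 4, 1} and slice_size = 4. An index pair (1, 2) then
    // accumulates its 4-element update slice starting at
    // out_index = 1 * 12 + 2 * 4 = 20.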
102 | 103 | int64_t block = 256; 104 | int64_t grid = (n + block - 1) / block; 105 | 106 | scatter_nd_kernel<<>>( 107 | reinterpret_cast(indices.const_data_ptr()), 108 | reinterpret_cast(data.mutable_data_ptr()), 109 | reinterpret_cast(indices.const_data_ptr()), 110 | reinterpret_cast(device_strides), n, k, 111 | slice_size); 112 | } 113 | 114 | void custom_scatter_op(torch::Tensor& data, const torch::Tensor& updates, 115 | const torch::Tensor& indices) { 116 | auto dtype = data.dtype(); 117 | if (dtype == torch::kFloat32) { 118 | scatter_nd(data, updates, indices); 119 | } else if (dtype == torch::kHalf) { 120 | scatter_nd<__half>(data, updates, indices); 121 | } 122 | } 123 | 124 | } // namespace tiledcuda::kernels 125 | -------------------------------------------------------------------------------- /src/torch_bind.cc: -------------------------------------------------------------------------------- 1 | #include "kernels/mod.hpp" 2 | 3 | #include 4 | 5 | namespace tiledcuda { 6 | using namespace tiledcuda::kernels; 7 | 8 | TORCH_LIBRARY(tiledcuda, t) { 9 | t.def("scatter_nd", &custom_scatter_op); 10 | t.def("flash_attention_fwd", &custom_flash_attention_op); 11 | }; 12 | 13 | } // namespace tiledcuda 14 | -------------------------------------------------------------------------------- /tests/cpp/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | if(WITH_TESTING) 2 | include_directories(${CMAKE_CURRENT_SOURCE_DIR}) 3 | file( 4 | GLOB_RECURSE UNIT_TESTS 5 | LIST_DIRECTORIES FALSE 6 | RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" 7 | "test_*.cu") 8 | 9 | foreach(FILE_PATH ${UNIT_TESTS}) 10 | string(REGEX REPLACE ".+/(.+)\\..*" "\\1" FILE_NAME ${FILE_PATH}) 11 | string(REPLACE ".cu" "" TEST_NAME "${FILE_NAME}") 12 | 13 | if("${TEST_NAME}" STREQUAL "test_gemm") 14 | continue() # the unittest for gemm requires extra dependencies 15 | endif() 16 | 17 | cuda_test(${TEST_NAME} SRCS "${CMAKE_CURRENT_SOURCE_DIR}/${FILE_PATH}") 18 | endforeach() 19 | 20 | cuda_test(test_gemm SRCS "${CMAKE_CURRENT_SOURCE_DIR}/cell/test_gemm.cu" 21 | "${PROJECT_SOURCE_DIR}/src/cuda_utils.cc" DEPS 22 | ${CUDA_CUBLAS_LIBRARIES}) 23 | endif() 24 | -------------------------------------------------------------------------------- /tests/cpp/cell/test_broadcast.cu: -------------------------------------------------------------------------------- 1 | #include "cell/compute/broadcast.hpp" 2 | #include "cell/compute/reduce.hpp" 3 | #include "cell/copy/constants.hpp" 4 | #include "cell/copy/mod.hpp" 5 | #include "common/test_utils.hpp" 6 | #include "types/mod.hpp" 7 | #include "util/debug.hpp" 8 | 9 | #include 10 | #include 11 | 12 | namespace tiledcuda::testing { 13 | using namespace cell; 14 | 15 | template 19 | __global__ void reg_broadcast(Element* src) { 20 | using SrcLoadTile = GlobalTile; 21 | using DstLoadTile = RegTile; 22 | using SrcReduceTile = DstLoadTile; 23 | using DstReduceTile = RegTile>; 24 | using SrcBroadcastTile = DstReduceTile; 25 | using DstBroadcastTile = SrcReduceTile; 26 | 27 | SrcLoadTile src_load_tile(src); 28 | DstLoadTile dst_load_tile; 29 | DstReduceTile dst_reduce_tile; 30 | DstBroadcastTile dst_broadcast_tile; 31 | 32 | // Load data from global memory to register file 33 | copy::GlobalToRegLoader loader; 34 | loader(src_load_tile, dst_load_tile); 35 | __syncthreads(); 36 | 37 | // Execute reduce operation. 
38 | compute::MaxReduce row_max; 39 | row_max(dst_load_tile, dst_reduce_tile); 40 | 41 | __syncthreads(); 42 | 43 | compute::Broadcast 44 | broadcast_reduce; 45 | 46 | broadcast_reduce(dst_reduce_tile, dst_broadcast_tile); 47 | 48 | __syncthreads(); 49 | 50 | if (thread(0)) { 51 | printf("Row Max:\n"); 52 | printf("Thread 0:\n"); 53 | dst_broadcast_tile.dump_value(); 54 | } 55 | } 56 | 57 | template 60 | void run_row_major_reg_broadcast() { 61 | int kNumel = 16 * 16 * kHeight * kWidth; 62 | int kWarpSize = tl::get_numel; 63 | 64 | using ReduceLayout = tl::RowMajor; 65 | 66 | thrust::host_vector h_src(kNumel); 67 | for (int i = 0; i < kNumel; ++i) { 68 | h_src[i] = (Element)i; 69 | } 70 | 71 | thrust::device_vector d_src = h_src; 72 | 73 | reg_broadcast 75 | <<<1, 32 * kWarpSize>>>(thrust::raw_pointer_cast(d_src.data())); 76 | } 77 | 78 | TEST(TestRegBroadcast, row_major_reg_broadcast_0) { 79 | const int kHeight = 1; 80 | const int kWidth = 1; 81 | using Element = float; 82 | using WarpLayout = tl::RowMajor<1, 1>; 83 | using RegLayout = tl::RowMajor; 84 | 85 | const copy::WarpReuse kMode = copy::WarpReuse::kCont; 86 | 87 | using GlobalLayout = tl::RowMajor<16 * kHeight, 16 * kWidth>; 88 | 89 | run_row_major_reg_broadcast< 90 | Element, RegLayout, GlobalLayout, BaseTileRowMajor, WarpLayout, 91 | tl::Layout::kRowMajor, kMode, kHeight, kWidth>(); 92 | } 93 | 94 | TEST(TestRegBroadcast, row_major_reg_broadcast_1) { 95 | const int kHeight = 2; 96 | const int kWidth = 2; 97 | using Element = float; 98 | using WarpLayout = tl::RowMajor<1, 1>; 99 | using RegLayout = tl::RowMajor; 100 | 101 | const copy::WarpReuse kMode = copy::WarpReuse::kCont; 102 | 103 | using GlobalLayout = tl::RowMajor<16 * kHeight, 16 * kWidth>; 104 | 105 | run_row_major_reg_broadcast< 106 | Element, RegLayout, GlobalLayout, BaseTileRowMajor, WarpLayout, 107 | tl::Layout::kRowMajor, kMode, kHeight, kWidth>(); 108 | } 109 | 110 | } // namespace tiledcuda::testing 111 | -------------------------------------------------------------------------------- /tests/cpp/cell/test_flash_attn.cu: -------------------------------------------------------------------------------- 1 | #include "cell/compute/broadcast.hpp" 2 | #include "cell/compute/map.hpp" 3 | #include "cell/compute/reduce.hpp" 4 | #include "cell/copy/constants.hpp" 5 | #include "cell/copy/mod.hpp" 6 | #include "common/test_utils.hpp" 7 | #include "types/mod.hpp" 8 | #include "util/debug.hpp" 9 | 10 | #include 11 | #include 12 | 13 | namespace tiledcuda::testing { 14 | using namespace cell; 15 | 16 | /** 17 | * @brief Reduce/Map operation for the flash attention. 
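 *
 * The sequence of operations below follows the standard online-softmax
 * recurrence (the notation is ours, not defined in this file): with S the
 * incoming attention block,
 *   m_new = max(m_old, rowmax(S))
 *   p     = exp(S - m_new)
 *   l_new = exp(m_old - m_new) * l_old + rowsum(p)
 * and the factor used to rescale previously accumulated results is
 * exp(m_old - m_new) * l_old / l_new.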
18 | */ 19 | template 23 | __global__ void flash_attn_reg_reduce(Element* src) { 24 | using SrcLoadTile = GlobalTile; 25 | using DstLoadTile = RegTile; 26 | using SrcReduceTile = DstLoadTile; 27 | using DstReduceTile = RegTile>; 28 | using SrcBroadcastTile = DstReduceTile; 29 | using DstBroadcastTile = SrcReduceTile; 30 | 31 | SrcLoadTile src_load_tile(src); 32 | DstLoadTile attn_block; 33 | DstReduceTile last_max_vec; 34 | DstReduceTile max_vec; 35 | DstReduceTile last_norm_vec; 36 | DstReduceTile norm_vec; 37 | DstBroadcastTile max_broadcast_tile; 38 | DstBroadcastTile norm_broadcast_tile; 39 | 40 | // Load data from global memory to register file 41 | copy::GlobalToRegLoader loader; 42 | loader(src_load_tile, attn_block); 43 | 44 | // Copy `max_vec` into `last_max_vec` 45 | copy::BaseTileCopy copy_max_reg; 46 | copy_max_reg(max_vec, last_max_vec); 47 | // Copy `norm_vec` into `last_norm_vec` 48 | copy::BaseTileCopy copy_norm_reg; 49 | copy_norm_reg(norm_vec, last_norm_vec); 50 | 51 | // Execute reduce operation. 52 | compute::MaxReduce row_max; 53 | // accumulate onto the max_vec 54 | row_max(attn_block, max_vec); 55 | 56 | compute::Broadcast 57 | broadcast_max; 58 | 59 | broadcast_max(max_vec, max_broadcast_tile); 60 | 61 | if (thread(0)) { 62 | printf("Thread 0:\n"); 63 | max_vec.dump_value(); 64 | max_broadcast_tile.dump_value(); 65 | attn_block.dump_value(); 66 | } 67 | 68 | // subtract max from attention -- now all <= 0. 69 | compute::RegTileSub sub_row_max; 70 | sub_row_max(attn_block, max_broadcast_tile, attn_block); 71 | 72 | if (thread(0)) { 73 | printf("Thread 0:\n"); 74 | attn_block.dump_value(); 75 | } 76 | 77 | // exponentiate the block in-place. 78 | compute::RegTileExp exp_attn; 79 | exp_attn(attn_block, attn_block); 80 | 81 | if (thread(0)) { 82 | printf("Thread 0:\n"); 83 | attn_block.dump_value(); 84 | } 85 | 86 | // subtract new max from old max to find the new normalization. 87 | compute::BaseTileSub sub_new_max; 88 | sub_new_max(last_max_vec, max_vec, last_max_vec); 89 | 90 | // exponentiate this vector -- this is what we need to normalize by. 91 | compute::BaseTileExp exp_max; 92 | exp_max(last_max_vec, last_max_vec); 93 | 94 | // and the norm vec is now normalized. 95 | compute::BaseTileMul mul_norm; 96 | mul_norm(last_max_vec, norm_vec, norm_vec); 97 | 98 | // Accumulate the new attention block onto the now-rescaled norm-vec. 99 | // Reduce Sum + Add 100 | DstReduceTile sum_vec; 101 | compute::SumReduce row_sum; 102 | row_sum(attn_block, sum_vec); 103 | compute::BaseTileAdd add_sum; 104 | add_sum(sum_vec, norm_vec, norm_vec); 105 | 106 | // Now the attention block is correctly normalized. 107 | // Broadcast + Divide 108 | compute::Broadcast 109 | broadcast_norm; 110 | broadcast_norm(norm_vec, norm_broadcast_tile); 111 | compute::RegTileDiv div_norm; 112 | div_norm(attn_block, norm_broadcast_tile, attn_block); 113 | 114 | // Normalize the previous norm vec accorfing to the new max. 115 | compute::BaseTileMul mul_norm_new; 116 | mul_norm_new(last_max_vec, last_norm_vec, last_norm_vec); 117 | 118 | // Normalize the previous norm vec according to the new norm. 
119 | compute::BaseTileDiv div_norm_new; 120 | div_norm_new(last_norm_vec, norm_vec, last_norm_vec); 121 | } 122 | 123 | template 126 | void run_row_major_reg_flash_attn() { 127 | int kNumel = 16 * 16 * kHeight * kWidth; 128 | int kWarpSize = tl::get_numel; 129 | 130 | using ReduceLayout = tl::RowMajor; 131 | 132 | thrust::host_vector h_src(kNumel); 133 | for (int i = 0; i < kNumel; ++i) { 134 | h_src[i] = (Element)i; 135 | } 136 | 137 | thrust::device_vector d_src = h_src; 138 | 139 | flash_attn_reg_reduce 141 | <<<1, 32 * kWarpSize>>>(thrust::raw_pointer_cast(d_src.data())); 142 | } 143 | 144 | TEST(TestRegBroadcast, row_major_reg_flash_attn_0) { 145 | const int kHeight = 1; 146 | const int kWidth = 1; 147 | using Element = float; 148 | using WarpLayout = tl::RowMajor<1, 1>; 149 | using RegLayout = tl::RowMajor; 150 | 151 | const copy::WarpReuse kMode = copy::WarpReuse::kCont; 152 | 153 | using GlobalLayout = tl::RowMajor<16 * kHeight, 16 * kWidth>; 154 | 155 | run_row_major_reg_flash_attn< 156 | Element, RegLayout, GlobalLayout, BaseTileRowMajor, WarpLayout, 157 | tl::Layout::kRowMajor, kMode, kHeight, kWidth>(); 158 | } 159 | 160 | } // namespace tiledcuda::testing 161 | -------------------------------------------------------------------------------- /tests/cpp/cell/test_g2s_load.cu: -------------------------------------------------------------------------------- 1 | #include "cell/copy/mod.hpp" 2 | #include "cell/sync.hpp" 3 | #include "common/test_utils.hpp" 4 | #include "types/mod.hpp" 5 | 6 | #include 7 | #include 8 | 9 | namespace tiledcuda::testing { 10 | using namespace cell; 11 | 12 | namespace { 13 | template 15 | __global__ void copy_g2s(const Element* src_ptr, Element* dst_ptr, 16 | Loader& loader, Storer& storer) { 17 | extern __shared__ __align__(sizeof(double)) unsigned char buf_[]; 18 | auto* buf = reinterpret_cast(buf_); 19 | 20 | SrcTile src(src_ptr); // global memory tile 21 | DstTile inter(buf); // shared memory tile 22 | SrcTile dst(dst_ptr); // global memory tile 23 | 24 | loader(src, inter); 25 | __copy_async(); 26 | __syncthreads(); 27 | 28 | storer(inter, dst); 29 | __syncthreads(); 30 | } 31 | 32 | template 34 | void run_test_row_major() { 35 | static const int kThreads = tl::get_numel * 32; 36 | 37 | int numel = kRows * kCols; 38 | thrust::host_vector h_A(numel); 39 | for (int i = 0; i < h_A.size(); ++i) 40 | h_A[i] = static_cast(i % 2048); 41 | 42 | thrust::device_vector d_B(numel); 43 | thrust::fill(d_B.begin(), d_B.end(), static_cast(0.)); 44 | thrust::device_vector d_A = h_A; 45 | 46 | static const bool kSwizzled = false; 47 | 48 | using SrcTile = GlobalTile>; 49 | using DstTile = SharedTile, kSwizzled>; 50 | 51 | using Loader = copy::GlobalToSharedLoader; 52 | Loader loader; 53 | 54 | using Storer = copy::SharedToGlobalStorer; 55 | Storer storer; 56 | 57 | dim3 dim_grid(1, 1); 58 | dim3 dim_block(kThreads); 59 | 60 | copy_g2s 61 | <<>>( 62 | thrust::raw_pointer_cast(d_A.data()), 63 | thrust::raw_pointer_cast(d_B.data()), loader, storer); 64 | cudaDeviceSynchronize(); 65 | 66 | thrust::host_vector h_B(numel); 67 | h_B = d_B; 68 | 69 | assert_equal( 70 | reinterpret_cast(thrust::raw_pointer_cast(h_A.data())), 71 | reinterpret_cast(thrust::raw_pointer_cast(h_B.data())), numel, 72 | 1e-5); 73 | } 74 | 75 | template 77 | void run_test_col_major() { 78 | static const int kThreads = tl::get_numel * 32; 79 | 80 | int numel = kRows * kCols; 81 | thrust::host_vector h_A(numel); 82 | for (int i = 0; i < h_A.size(); ++i) 83 | h_A[i] = static_cast(i % 2048); 84 | 85 | 
thrust::device_vector d_B(numel); 86 | thrust::fill(d_B.begin(), d_B.end(), static_cast(0.)); 87 | thrust::device_vector d_A = h_A; 88 | 89 | static const bool kSwizzled = false; 90 | using SrcTile = GlobalTile>; 91 | using DstTile = SharedTile, kSwizzled>; 92 | 93 | using Loader = copy::GlobalToSharedLoader; 94 | Loader loader; 95 | 96 | using Storer = copy::SharedToGlobalStorer; 97 | Storer storer; 98 | 99 | dim3 dim_grid(1, 1); 100 | dim3 dim_block(kThreads); 101 | 102 | copy_g2s 103 | <<>>( 104 | thrust::raw_pointer_cast(d_A.data()), 105 | thrust::raw_pointer_cast(d_B.data()), loader, storer); 106 | cudaDeviceSynchronize(); 107 | 108 | thrust::host_vector h_B(numel); 109 | h_B = d_B; 110 | 111 | assert_equal( 112 | reinterpret_cast(thrust::raw_pointer_cast(h_A.data())), 113 | reinterpret_cast(thrust::raw_pointer_cast(h_B.data())), numel, 114 | 1e-5); 115 | } 116 | } // namespace 117 | 118 | TEST(GlobalToSharedLoad, test_row_major_load) { 119 | run_test_row_major<__half, tl::RowMajor<1, 1>, 16, 32>(); 120 | run_test_row_major<__half, tl::RowMajor<1, 4>, 32, 128>(); 121 | run_test_row_major<__half, tl::RowMajor<4, 1>, 192, 32>(); 122 | run_test_row_major<__half, tl::RowMajor<2, 2>, 64, 128>(); 123 | run_test_row_major<__half, tl::RowMajor<2, 4>, 96, 128>(); 124 | 125 | run_test_row_major, 16, 16>(); 126 | run_test_row_major, 32, 128>(); 127 | run_test_row_major, 192, 32>(); 128 | run_test_row_major, 64, 128>(); 129 | run_test_row_major, 96, 128>(); 130 | } 131 | 132 | TEST(GlobalToSharedLoad, test_col_major_load) { 133 | run_test_col_major<__half, tl::RowMajor<1, 1>, 16, 16>(); 134 | run_test_col_major<__half, tl::RowMajor<1, 4>, 32, 128>(); 135 | run_test_col_major<__half, tl::RowMajor<4, 1>, 192, 32>(); 136 | run_test_col_major<__half, tl::RowMajor<2, 2>, 64, 128>(); 137 | run_test_col_major<__half, tl::RowMajor<2, 4>, 96, 128>(); 138 | 139 | run_test_col_major, 16, 16>(); 140 | run_test_col_major, 32, 128>(); 141 | run_test_col_major, 192, 32>(); 142 | run_test_col_major, 64, 128>(); 143 | run_test_col_major, 96, 128>(); 144 | } 145 | 146 | } // namespace tiledcuda::testing 147 | -------------------------------------------------------------------------------- /tests/cpp/cell/test_global_tile_iterator.cu: -------------------------------------------------------------------------------- 1 | #include "common/test_utils.hpp" 2 | #include "types/mod.hpp" 3 | 4 | #include 5 | 6 | namespace tiledcuda::testing { 7 | 8 | using namespace cell; 9 | namespace tl = tile_layout; 10 | 11 | namespace { 12 | template 13 | struct GTileIteratorTester; 14 | 15 | template 16 | struct GTileIteratorTester { 17 | using Element = float; 18 | using Layout = Layout_; 19 | 20 | static constexpr int kRows = Layout::kRows; 21 | static constexpr int kCols = Layout::kCols; 22 | 23 | static constexpr int kStride0 = dim_size<0, ChunkShape>; 24 | static constexpr int kStride1 = dim_size<1, ChunkShape>; 25 | 26 | const int kTileRowStride = kStride0 * Layout::kRowStride; 27 | const int kTileColStride = kStride1; 28 | 29 | static_assert(kRows % kStride0 == 0, "kRows must be divisible by kStride0"); 30 | static_assert(kCols % kStride1 == 0, "kCols must be divisible by kStride1"); 31 | 32 | using Tile = GlobalTile>; 33 | using Iterator = GTileIterator; 34 | 35 | void operator()() { 36 | int numel = kRows * kCols; 37 | thrust::host_vector data(numel); 38 | 39 | Layout layout; 40 | Element* ptr = data.data(); 41 | int count = 0; 42 | for (int i = 0; i < kRows; ++i) 43 | for (int j = 0; j < kCols; ++j) ptr[count++] = layout(i, 
j); 44 | 45 | #if defined(DEBUG_PRINT) 46 | Tile gtile(ptr); 47 | gtile.dump_value(); 48 | #endif 49 | 50 | EXPECT_EQ(Iterator::sc0, kRows / kStride0); 51 | EXPECT_EQ(Iterator::sc1, kCols / kStride1); 52 | 53 | Iterator iter(data.data()); 54 | 55 | for (int i = 0; i < Iterator::sc0; ++i) { 56 | for (int j = 0; j < Iterator::sc1; ++j) { 57 | int start_n = i * kTileRowStride + j * kTileColStride; 58 | auto tile = iter(i, j); 59 | for (int m = 0; m < kStride0; ++m) { 60 | for (int n = 0; n < kStride1; ++n) { 61 | int v1 = int(tile(m, n)); 62 | int v2 = start_n + m * Layout::kRowStride + n; 63 | EXPECT_EQ(v1, v2); 64 | } 65 | } 66 | 67 | #if defined(DEBUG_PRINT) 68 | printf("\nIteration-[%d, %d]:\n", i, j); 69 | iter(i, j).dump_value(); 70 | printf("\n"); 71 | #endif 72 | } 73 | } 74 | } 75 | }; 76 | 77 | template 78 | struct GTileIteratorTester { 79 | using Element = float; 80 | using Layout = Layout_; 81 | 82 | static constexpr int kRows = Layout::kRows; 83 | static constexpr int kCols = Layout::kCols; 84 | 85 | static constexpr int kStride0 = dim_size<0, ChunkShape>; 86 | static constexpr int kStride1 = dim_size<1, ChunkShape>; 87 | 88 | const int kTileRowStride = kStride0; 89 | const int kTileColStride = kStride1 * Layout::kColStride; 90 | 91 | static_assert(kRows % kStride0 == 0, "kRows must be divisible by kStride0"); 92 | static_assert(kCols % kStride1 == 0, "kCols must be divisible by kStride1"); 93 | 94 | using Tile = GlobalTile>; 95 | using Iterator = GTileIterator; 96 | 97 | void operator()() { 98 | int numel = kRows * kCols; 99 | thrust::host_vector data(numel); 100 | 101 | Layout layout; 102 | Element* ptr = data.data(); 103 | int count = 0; 104 | for (int i = 0; i < kRows; ++i) 105 | for (int j = 0; j < kCols; ++j) ptr[count++] = layout(i, j); 106 | 107 | #if defined(DEBUG_PRINT) 108 | Tile gtile(ptr); 109 | gtile.dump_value(); 110 | #endif 111 | 112 | EXPECT_EQ(Iterator::sc0, kRows / kStride0); 113 | EXPECT_EQ(Iterator::sc1, kCols / kStride1); 114 | 115 | Iterator iter(data.data()); 116 | 117 | for (int i = 0; i < Iterator::sc0; ++i) { 118 | for (int j = 0; j < Iterator::sc1; ++j) { 119 | int start_n = i * kTileRowStride + j * kTileColStride; 120 | 121 | auto tile = iter(i, j); 122 | for (int m = 0; m < kStride0; ++m) { 123 | for (int n = 0; n < kStride1; ++n) { 124 | int v1 = int(tile(m, n)); 125 | int v2 = start_n + m + n * Layout::kColStride; 126 | 127 | EXPECT_EQ(v1, v2); 128 | } 129 | } 130 | 131 | #if defined(DEBUG_PRINT) 132 | printf("\nIteration-[%d, %d]:\n", i, j); 133 | iter(i, j).dump_value(); 134 | printf("\n"); 135 | #endif 136 | } 137 | } 138 | } 139 | }; 140 | } // namespace 141 | 142 | TEST(TestGTileIterator, test_row_major) { 143 | using Tester = GTileIteratorTester, TileShape<2, 3>, 144 | tl::Layout::kRowMajor>; 145 | Tester tester; 146 | tester(); 147 | } 148 | 149 | TEST(TestGTileIterator, col_major) { 150 | using Tester = GTileIteratorTester, TileShape<2, 3>, 151 | tl::Layout::kColMajor>; 152 | Tester tester; 153 | tester(); 154 | } 155 | } // namespace tiledcuda::testing 156 | -------------------------------------------------------------------------------- /tests/cpp/cell/test_layout.cu: -------------------------------------------------------------------------------- 1 | #include "cell/copy/global_to_shared.hpp" 2 | #include "common/test_utils.hpp" 3 | #include "types/mod.hpp" 4 | 5 | #include 6 | #include 7 | 8 | namespace tiledcuda::testing { 9 | using namespace cell; 10 | using namespace cute; 11 | namespace tl = tile_layout; 12 | 13 | namespace { 14 | 
15 | template <typename Element> 16 | void test_swizzled_function(); 17 | 18 | template <> 19 | void test_swizzled_function<__half>() { 20 | using Element = __half; 21 | static constexpr int kBits = 16 * 8; 22 | 23 | const int kRows = 16; 24 | const int kCols = 32; 25 | 26 | thrust::host_vector<Element> data(kRows * kCols); 27 | for (int i = 0; i < data.size(); ++i) { 28 | data[i] = static_cast<Element>(i % 2048); 29 | } 30 | 31 | using RowMajor = tl::RowMajor<kRows, kCols>; 32 | RowMajor layout1; 33 | 34 | // only swizzle the first [16x16] half of the [kRows, kCols] matrix 35 | using Swizzled = tl::detail::SwizzledRowMajor; 36 | Swizzled layout2; 37 | 38 | Element* ptr = thrust::raw_pointer_cast(data.data()); 39 | 40 | printf("\nnon-swizzled:\n"); 41 | for (int i = 0; i < RowMajor::kRows; ++i) { 42 | for (int j = 0; j < RowMajor::kCols; ++j) { 43 | printf("%.0f, ", __half2float(ptr[layout1(i, j)])); 44 | } 45 | printf("\n"); 46 | } 47 | 48 | printf("\nswizzled:\n"); 49 | for (int i = 0; i < kRows; ++i) { 50 | for (int j = 0; j < 16; ++j) { 51 | printf("%.0f, ", __half2float(ptr[layout2(i, j)])); 52 | } 53 | printf("\n"); 54 | } 55 | } 56 | 57 | template <> 58 | void test_swizzled_function<float>() { 59 | using Element = float; 60 | static constexpr int kBits = 32 * 2; 61 | 62 | const int kRows = 16; 63 | const int kCols = 16; 64 | 65 | thrust::host_vector<Element> data(kRows * kCols); 66 | for (int i = 0; i < data.size(); ++i) { 67 | data[i] = static_cast<Element>(i % 2048); 68 | } 69 | 70 | using RowMajor = tl::RowMajor<kRows, kCols>; 71 | RowMajor layout1; 72 | 73 | // only swizzle the first [16x16] half of the [kRows, kCols] matrix 74 | using Swizzled = tl::detail::SwizzledRowMajor; 75 | Swizzled layout2; 76 | 77 | for (int i = 0; i < RowMajor::kRows; ++i) { 78 | for (int j = 0; j < RowMajor::kCols; ++j) { 79 | printf("[%d:%d], ", layout1(i, j), layout2(i, j)); 80 | } 81 | printf("\n"); 82 | } 83 | 84 | Element* ptr = thrust::raw_pointer_cast(data.data()); 85 | 86 | printf("\nnon-swizzled:\n"); 87 | for (int i = 0; i < RowMajor::kRows; ++i) { 88 | for (int j = 0; j < RowMajor::kCols; ++j) { 89 | printf("%.0f, ", ptr[layout1(i, j)]); 90 | } 91 | printf("\n"); 92 | } 93 | 94 | printf("\nswizzled:\n"); 95 | for (int i = 0; i < kRows; ++i) { 96 | for (int j = 0; j < 16; ++j) { 97 | printf("%.0f, ", ptr[layout2(i, j)]); 98 | } 99 | printf("\n"); 100 | } 101 | } 102 | 103 | } // namespace 104 | 105 | TEST(TestLayout, test_layout) { 106 | using Element = cutlass::half_t; 107 | 108 | using Layout1 = tl::RowMajor<4, 7>; 109 | EXPECT_EQ(tl::num_rows<Layout1>, 4); 110 | EXPECT_EQ(tl::num_cols<Layout1>, 7); 111 | EXPECT_EQ(tl::get_numel<Layout1>, 28); 112 | EXPECT_EQ(tl::row_stride<Layout1>, 7); 113 | EXPECT_EQ(tl::col_stride<Layout1>, 1); 114 | 115 | tl::Layout type1 = tl::layout_type<Layout1>; 116 | EXPECT_EQ(type1, tl::Layout::kRowMajor); 117 | auto layout_name1 = layout_type_to_str(type1); 118 | EXPECT_EQ(layout_name1, "RowMajor"); 119 | 120 | using Layout2 = tl::ColMajor<4, 7>; 121 | EXPECT_EQ(tl::num_rows<Layout2>, 4); 122 | EXPECT_EQ(tl::num_cols<Layout2>, 7); 123 | EXPECT_EQ(tl::get_numel<Layout2>, 28); 124 | EXPECT_EQ(tl::row_stride<Layout2>, 1); 125 | EXPECT_EQ(tl::col_stride<Layout2>, 4); 126 | 127 | tl::Layout type2 = tl::layout_type<Layout2>; 128 | EXPECT_EQ(type2, tl::Layout::kColMajor); 129 | auto layout_name2 = layout_type_to_str(type2); 130 | EXPECT_EQ(layout_name2, "ColMajor"); 131 | } 132 | 133 | TEST(TestLayout, test_swizzled_layout_half) { 134 | test_swizzled_function<__half>(); 135 | test_swizzled_function<float>(); 136 | } 137 | 138 | } // namespace tiledcuda::testing 139 | --------------------------------------------------------------------------------
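A note on the swizzled layouts exercised in tests/cpp/cell/test_layout.cu above: the two test_swizzled_function specializations only dump plain and swizzled offsets for visual inspection. As a rough standalone illustration of what such a layout does (this is not TiledCUDA's actual tl::detail::SwizzledRowMajor, whose template arguments are elided in this dump; the XOR pattern and the kRows/kCols names below are a hypothetical simplification), a swizzled row-major layout keeps every element inside its original row but permutes the column index, typically by XOR-ing in low bits of the row index so that successive rows map to different shared-memory banks:

```cpp
#include <cstdio>

// Toy 16x16 row-major layout with an XOR-permuted column index. Each row of
// the swizzled map is a permutation of the same 16 offsets of that row.
constexpr int kRows = 16;
constexpr int kCols = 16;

int row_major(int i, int j) { return i * kCols + j; }
int swizzled(int i, int j) { return i * kCols + (j ^ (i % 8)); }

int main() {
    // Mirrors the "[%d:%d], " dump in the float specialization above.
    for (int i = 0; i < kRows; ++i) {
        for (int j = 0; j < kCols; ++j)
            printf("[%d:%d], ", row_major(i, j), swizzled(i, j));
        printf("\n");
    }
    return 0;
}
```

Since XOR with a fixed value is a bijection on [0, 16), each row remains a permutation of its own offsets; that row-local reshuffling is the property the printouts above are meant to make visible.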
/tests/cpp/cell/test_reduce.cu: -------------------------------------------------------------------------------- 1 | #include "cell/compute/reduce.hpp" 2 | #include "cell/copy/constants.hpp" 3 | #include "cell/copy/mod.hpp" 4 | #include "common/test_utils.hpp" 5 | #include "types/mod.hpp" 6 | #include "util/debug.hpp" 7 | 8 | #include 9 | #include 10 | 11 | namespace tiledcuda::testing { 12 | using namespace cell; 13 | 14 | template 18 | __global__ void reg_reduce(Element* src) { 19 | using SrcLoadTile = GlobalTile; 20 | using DstLoadTile = RegTile; 21 | using SrcReduceTile = DstLoadTile; 22 | using DstReduceTile = RegTile>; 23 | 24 | SrcLoadTile src_load_tile(src); 25 | DstLoadTile dst_load_tile; 26 | DstReduceTile dst_reduce_tile; 27 | 28 | // Load data from global memory to register file 29 | copy::GlobalToRegLoader loader; 30 | loader(src_load_tile, dst_load_tile); 31 | __syncthreads(); 32 | 33 | // Execute reduce operation. 34 | compute::MaxReduce row_max; 35 | row_max(dst_load_tile, dst_reduce_tile); 36 | 37 | __syncthreads(); 38 | 39 | if (thread(0)) { 40 | printf("Row Max:\n"); 41 | printf("Thread 0:\n"); 42 | dst_reduce_tile.dump_value(); 43 | } 44 | 45 | if (thread(1)) { 46 | printf("Thread 1:\n"); 47 | dst_reduce_tile.dump_value(); 48 | } 49 | 50 | if (thread(4)) { 51 | printf("Thread 4:\n"); 52 | dst_reduce_tile.dump_value(); 53 | } 54 | 55 | if (thread(8)) { 56 | printf("Thread 8:\n"); 57 | dst_reduce_tile.dump_value(); 58 | } 59 | 60 | __syncthreads(); 61 | 62 | compute::SumReduce row_sum; 63 | row_sum(dst_load_tile, dst_reduce_tile); 64 | 65 | __syncthreads(); 66 | 67 | if (thread(0)) { 68 | printf("Row Sum:\n"); 69 | printf("Thread 0:\n"); 70 | dst_reduce_tile.dump_value(); 71 | } 72 | 73 | if (thread(1)) { 74 | printf("Thread 1:\n"); 75 | dst_reduce_tile.dump_value(); 76 | } 77 | 78 | if (thread(4)) { 79 | printf("Thread 4:\n"); 80 | dst_reduce_tile.dump_value(); 81 | } 82 | 83 | if (thread(8)) { 84 | printf("Thread 8:\n"); 85 | dst_reduce_tile.dump_value(); 86 | } 87 | } 88 | 89 | template 92 | void run_row_major_reg_reduce() { 93 | int kNumel = 16 * 16 * kHeight * kWidth; 94 | int kWarpSize = tl::get_numel; 95 | 96 | using ReduceLayout = tl::RowMajor; 97 | 98 | thrust::host_vector h_src(kNumel); 99 | for (int i = 0; i < kNumel; ++i) { 100 | h_src[i] = (Element)i; 101 | } 102 | 103 | thrust::device_vector d_src = h_src; 104 | 105 | reg_reduce 107 | <<<1, 32 * kWarpSize>>>(thrust::raw_pointer_cast(d_src.data())); 108 | } 109 | 110 | template 113 | void run_col_major_reg_reduce() { 114 | int kNumel = 16 * 16 * kHeight * kWidth; 115 | int kWarpSize = tl::get_numel; 116 | 117 | using ReduceLayout = tl::ColMajor<2, kWidth>; 118 | 119 | thrust::host_vector h_src(kNumel); 120 | for (int i = 0; i < kNumel; ++i) { 121 | h_src[i] = (Element)i; 122 | } 123 | 124 | thrust::device_vector d_src = h_src; 125 | 126 | reg_reduce 128 | <<<1, 32 * kWarpSize>>>(thrust::raw_pointer_cast(d_src.data())); 129 | } 130 | 131 | TEST(TestRegReduce, row_major_reg_reduce_0) { 132 | const int kHeight = 1; 133 | const int kWidth = 1; 134 | using Element = float; 135 | using WarpLayout = tl::RowMajor<1, 1>; 136 | using RegLayout = tl::RowMajor; 137 | 138 | const copy::WarpReuse kMode = copy::WarpReuse::kCont; 139 | 140 | using GlobalLayout = tl::RowMajor<16 * kHeight, 16 * kWidth>; 141 | 142 | run_row_major_reg_reduce, WarpLayout, 144 | tl::Layout::kRowMajor, kMode, kHeight, kWidth>(); 145 | } 146 | 147 | TEST(TestRegReduce, row_major_reg_reduce_1) { 148 | const int kHeight = 2; 149 | const int kWidth = 
2; 150 | using Element = float; 151 | using WarpLayout = tl::RowMajor<1, 1>; 152 | using RegLayout = tl::RowMajor; 153 | 154 | const copy::WarpReuse kMode = copy::WarpReuse::kCont; 155 | 156 | using GlobalLayout = tl::RowMajor<16 * kHeight, 16 * kWidth>; 157 | 158 | run_row_major_reg_reduce, WarpLayout, 160 | tl::Layout::kRowMajor, kMode, kHeight, kWidth>(); 161 | } 162 | 163 | TEST(TestRegReduce, col_major_reg_reduce_0) { 164 | const int kHeight = 1; 165 | const int kWidth = 1; 166 | using Element = float; 167 | using WarpLayout = tl::ColMajor<1, 1>; 168 | using RegLayout = tl::ColMajor; 169 | 170 | const copy::WarpReuse kMode = copy::WarpReuse::kCont; 171 | 172 | using GlobalLayout = tl::ColMajor<16 * kHeight, 16 * kWidth>; 173 | 174 | run_col_major_reg_reduce, WarpLayout, 176 | tl::Layout::kColMajor, kMode, kHeight, kWidth>(); 177 | } 178 | 179 | TEST(TestRegReduce, col_major_reg_reduce_1) { 180 | const int kHeight = 2; 181 | const int kWidth = 2; 182 | using Element = float; 183 | using WarpLayout = tl::ColMajor<1, 1>; 184 | using RegLayout = tl::ColMajor; 185 | 186 | const copy::WarpReuse kMode = copy::WarpReuse::kCont; 187 | 188 | using GlobalLayout = tl::ColMajor<16 * kHeight, 16 * kWidth>; 189 | 190 | run_col_major_reg_reduce, WarpLayout, 192 | tl::Layout::kColMajor, kMode, kHeight, kWidth>(); 193 | } 194 | 195 | } // namespace tiledcuda::testing 196 | -------------------------------------------------------------------------------- /tests/cpp/cell/test_softmax.cu: -------------------------------------------------------------------------------- 1 | #include "cell/compute/softmax.hpp" 2 | #include "cell/copy/constants.hpp" 3 | #include "cell/copy/mod.hpp" 4 | #include "common/test_utils.hpp" 5 | #include "types/mod.hpp" 6 | #include "util/debug.hpp" 7 | 8 | #include 9 | #include 10 | 11 | namespace tiledcuda::testing { 12 | using namespace cell; 13 | 14 | template 15 | void cpu_softmax(Element* src, Element* dst, int rows, int cols) { 16 | for (int i = 0; i < rows; ++i) { 17 | Element sum = 0; 18 | for (int j = 0; j < cols; ++j) { 19 | sum += exp(src[i * cols + j]); 20 | } 21 | for (int j = 0; j < cols; ++j) { 22 | dst[i * cols + j] = exp(src[i * cols + j]) / sum; 23 | } 24 | } 25 | } 26 | 27 | template 30 | __global__ void reg_softmax(Element* src, Element* dst) { 31 | using SrcLoadTile = GlobalTile; 32 | using DstLoadTile = RegTile; 33 | using SrcReduceTile = DstLoadTile; 34 | using DstReduceTile = RegTile>; 35 | using SrcStoreTile = DstLoadTile; 36 | using DstStoreTile = GlobalTile; 37 | 38 | SrcLoadTile src_load_tile(src); 39 | DstLoadTile dst_load_tile; 40 | DstReduceTile dst_reduce_tile; 41 | DstStoreTile dst_store_tile(dst); 42 | 43 | // Load data from global memory to register file 44 | copy::GlobalToRegLoader loader; 45 | loader(src_load_tile, dst_load_tile); 46 | __syncthreads(); 47 | 48 | // Execute softmax. 
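// Added note (not part of the original source): compute::Softmax is assumed
// here to perform a row-wise softmax over the register tile, i.e. each row is
// exponentiated and normalized by its row sum (typically after subtracting the
// row max for numerical stability), matching the cpu_softmax reference above.
// Worked example for one row [1, 2]: exp -> [2.718, 7.389], sum = 10.107,
// softmax ≈ [0.269, 0.731].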
49 | compute::Softmax row_softmax; 50 | row_softmax(dst_load_tile, dst_reduce_tile); 51 | 52 | __syncthreads(); 53 | 54 | if (thread(0)) { 55 | printf("Thread 0:\n"); 56 | dst_load_tile.dump_value(); 57 | } 58 | 59 | if (thread(1)) { 60 | printf("Thread 1:\n"); 61 | dst_load_tile.dump_value(); 62 | } 63 | 64 | if (thread(4)) { 65 | printf("Thread 4:\n"); 66 | dst_load_tile.dump_value(); 67 | } 68 | 69 | if (thread(8)) { 70 | printf("Thread 8:\n"); 71 | dst_load_tile.dump_value(); 72 | } 73 | 74 | __syncthreads(); 75 | copy::RegToGlobalStorer storer; 76 | storer(dst_load_tile, dst_store_tile); 77 | } 78 | 79 | template 82 | void run_reg_softmax() { 83 | int kNumel = 16 * 16 * kHeight * kWidth; 84 | int kWarpSize = tl::get_numel; 85 | 86 | srand(42); 87 | 88 | thrust::host_vector h_src(kNumel); 89 | for (int i = 0; i < kNumel; ++i) { 90 | h_src[i] = (Element)(10 * (rand() / float(RAND_MAX)) - 5); 91 | } 92 | 93 | thrust::device_vector d_src = h_src; 94 | thrust::device_vector d_dst(kNumel); 95 | thrust::fill(d_dst.begin(), d_dst.end(), static_cast(0.)); 96 | 97 | thrust::host_vector h_dst_ref(kNumel); 98 | thrust::fill(h_dst_ref.begin(), h_dst_ref.end(), static_cast(0.)); 99 | 100 | reg_softmax 102 | <<<1, 32 * kWarpSize>>>(thrust::raw_pointer_cast(d_src.data()), 103 | thrust::raw_pointer_cast(d_dst.data())); 104 | 105 | thrust::host_vector h_dst = d_dst; 106 | 107 | cpu_softmax(thrust::raw_pointer_cast(h_src.data()), 108 | thrust::raw_pointer_cast(h_dst_ref.data()), 16 * kHeight, 109 | 16 * kWidth); 110 | 111 | // Check the result 112 | for (int i = 0; i < kNumel; ++i) { 113 | EXPECT_NEAR(h_dst[i], h_dst_ref[i], 1e-3); 114 | } 115 | } 116 | 117 | TEST(TestRegSoftmax, row_major_reg_softmax_0) { 118 | const int kHeight = 1; 119 | const int kWidth = 1; 120 | using Element = float; 121 | using WarpLayout = tl::RowMajor<1, 1>; 122 | using RegLayout = tl::RowMajor; 123 | 124 | const copy::WarpReuse kMode = copy::WarpReuse::kCont; 125 | 126 | using GlobalLayout = tl::RowMajor<16 * kHeight, 16 * kWidth>; 127 | 128 | run_reg_softmax, 129 | WarpLayout, tl::Layout::kRowMajor, kMode, kHeight, 130 | kWidth>(); 131 | } 132 | 133 | TEST(TestRegSoftmax, row_major_reg_softmax_1) { 134 | const int kHeight = 2; 135 | const int kWidth = 2; 136 | using Element = float; 137 | using WarpLayout = tl::RowMajor<1, 1>; 138 | using RegLayout = tl::RowMajor; 139 | 140 | const copy::WarpReuse kMode = copy::WarpReuse::kCont; 141 | 142 | using GlobalLayout = tl::RowMajor<16 * kHeight, 16 * kWidth>; 143 | 144 | run_reg_softmax, 145 | WarpLayout, tl::Layout::kRowMajor, kMode, kHeight, 146 | kWidth>(); 147 | } 148 | 149 | TEST(TestRegSoftmax, row_major_reg_softmax_2) { 150 | const int kHeight = 4; 151 | const int kWidth = 1; 152 | using Element = float; 153 | using WarpLayout = tl::RowMajor; 154 | using RegLayout = tl::RowMajor<1, kWidth>; 155 | 156 | const copy::WarpReuse kMode = copy::WarpReuse::kCont; 157 | 158 | using GlobalLayout = tl::RowMajor<16 * kHeight, 16 * kWidth>; 159 | 160 | run_reg_softmax, 161 | WarpLayout, tl::Layout::kRowMajor, kMode, kHeight, 162 | kWidth>(); 163 | } 164 | 165 | } // namespace tiledcuda::testing 166 | -------------------------------------------------------------------------------- /tests/cpp/common/test_utils.cc: -------------------------------------------------------------------------------- 1 | #include "common/test_utils.hpp" 2 | 3 | namespace tiledcuda::testing { 4 | 5 | // FIXME(haruhi): A quick implementation to compare two __half arrays. 
Refine 6 | // the implementation of necessary unittest utilities. 7 | template <> 8 | void assert_equal(const __half* v1, const __half* v2, int64_t numel, 9 | float epsilon) { 10 | float a = 0.f; 11 | float b = 0.f; 12 | for (int i = 0; i < numel; ++i) { 13 | a = __half2float(v1[i]); 14 | b = __half2float(v2[i]); 15 | 16 | EXPECT_NEAR(a, b, epsilon) << "v1[" << i << "] vs. v2[" << i 17 | << "] = " << a << " vs. " << b << std::endl; 18 | } 19 | } 20 | 21 | template <> 22 | void assert_equal(const float* v1, const float* v2, int64_t numel, 23 | float epsilon) { 24 | for (int i = 0; i < numel; ++i) 25 | EXPECT_NEAR(v1[i], v2[i], epsilon) 26 | << "v1[" << i << "] vs. v2[" << i << "] = " << v1[i] << " vs. " 27 | << v2[i] << std::endl; 28 | } 29 | 30 | } // namespace tiledcuda::testing 31 | -------------------------------------------------------------------------------- /tests/cpp/common/test_utils.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "config.hpp" 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | namespace tiledcuda::testing { 10 | 11 | template 12 | void assert_equal(const T* v1, const T* v2, int64_t numel, float epsilon); 13 | 14 | } // namespace tiledcuda::testing 15 | -------------------------------------------------------------------------------- /tests/cpp/kernels/test_scatter_nd.cu: -------------------------------------------------------------------------------- 1 | #include "common/test_utils.hpp" 2 | 3 | namespace tiledcuda { 4 | namespace testing { 5 | 6 | TEST(TestScatter, test) {} 7 | 8 | } // namespace testing 9 | } // namespace tiledcuda 10 | -------------------------------------------------------------------------------- /tests/cpp/test_unit.cc: -------------------------------------------------------------------------------- 1 | #include "common/test_utils.hpp" 2 | 3 | int main(int argc, char** argv) { 4 | FLAGS_alsologtostderr = 1; // redirect log to stderr 5 | google::InitGoogleLogging(argv[0]); 6 | 7 | testing::InitGoogleTest(&argc, argv); 8 | 9 | return RUN_ALL_TESTS(); 10 | } 11 | -------------------------------------------------------------------------------- /tests/python/context.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.insert( 5 | 0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) 6 | -------------------------------------------------------------------------------- /tests/python/test_flash_attn.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | 5 | import context 6 | from pytiledcuda import TiledFlashAttention 7 | 8 | 9 | class FlashAttention: 10 | 11 | def __init__(self, query, key, value, M, N, K, P, kTM, kTN, kTK, kTP): 12 | self.M = M 13 | self.N = N 14 | self.K = K 15 | self.P = P 16 | 17 | self.kTM = kTM 18 | self.kTN = kTN 19 | self.kTK = kTK 20 | self.kTP = kTP 21 | 22 | self.query = query 23 | self.key = key 24 | self.value = value 25 | self.output = torch.empty(M, P, device='cpu') 26 | 27 | def forward(self): 28 | iter_n = self.N // self.kTN 29 | 30 | prev_maxes = torch.zeros(self.M, 1, device='cpu') 31 | prev_sums = torch.zeros(self.M, 1, device='cpu') 32 | 33 | output = self.output.view(self.M, self.P) 34 | 35 | dK = self.key.view(self.K, self.N) 36 | dV = self.value.view(self.N, self.P) 37 | 38 | ks = torch.chunk(dK, iter_n, dim=-1) 39 | vs = torch.chunk(dV, iter_n, dim=-2) 40 | 41 | for n in 
range(iter_n): 42 | q = self.query.view(self.M, self.K) # m * k 43 | 44 | k = ks[n] 45 | v = vs[n] 46 | 47 | attn_weights = q @ k # m * ktn 48 | 49 | # reduce maxes 50 | cur_maxes, _ = torch.max(attn_weights, dim=-1, keepdim=True) 51 | exp_weights = torch.exp(attn_weights - cur_maxes) 52 | # unnormalized attention score @ values 53 | exp_values = exp_weights @ v 54 | # move the normalization step to the very end of the attention computation. 55 | cur_sums = torch.sum(exp_weights, dim=-1, keepdim=True) # l(x_cur) 56 | 57 | # ======================= renormalization ======================# 58 | new_maxes = torch.max(cur_maxes, prev_maxes) # update m(x) 59 | # renormalization factor for the previous block 60 | renorm_prev = torch.exp(prev_maxes - new_maxes) 61 | # renormalization factor for the current block 62 | renorm_cur = torch.exp(cur_maxes - new_maxes) 63 | 64 | # update normalization factor l(x) 65 | new_sums = renorm_prev * prev_sums + renorm_cur * cur_sums 66 | 67 | output = (output * prev_sums * renorm_prev + 68 | renorm_cur * exp_values) / new_sums 69 | 70 | prev_sums = new_sums 71 | prev_maxes = new_maxes 72 | 73 | self.output = output 74 | 75 | return self.output 76 | 77 | 78 | class TestFlashAttention(unittest.TestCase): 79 | 80 | def setUp(self): 81 | torch.manual_seed(1234) 82 | 83 | def run_flash_attention(self, m, n, k, p, kTM, kTN, kTK, kTP): 84 | 85 | Q = torch.randn(m, k, device='cpu') 86 | K = torch.randn(k, n, device='cpu') 87 | V = torch.randn(n, p, device='cpu') 88 | O = torch.empty(m, p, device='cpu') 89 | 90 | flash_attn = FlashAttention(Q.half().flatten(), 91 | K.half().flatten(), 92 | V.half().flatten(), m, n, k, p, kTM, kTN, 93 | kTK, kTP) 94 | 95 | ref_o = flash_attn.forward().half() 96 | 97 | CUDA_Q = Q.cuda() 98 | CUDA_K = K.cuda() 99 | CUDA_V = V.cuda() 100 | 101 | tiled_flash_attention = TiledFlashAttention(CUDA_Q, CUDA_K, CUDA_V) 102 | O = tiled_flash_attention.forward() 103 | 104 | print('CPU Reference O: ', ref_o) 105 | print('TiledCUDA O: ', O) 106 | 107 | hO = O.cpu() 108 | 109 | passed = True 110 | 111 | # Compare elements one by one and print the different numbers. 
112 | for i in range(m): 113 | for j in range(p): 114 | if abs(hO[i][j] - ref_o[i][j]) > 8e-2: 115 | print('(', i, ', ', j, ')') 116 | print('TiledCUDA O: ', hO[i][j]) 117 | print('CPU Reference O: ', ref_o[i][j]) 118 | 119 | passed = False 120 | break 121 | 122 | assert passed 123 | 124 | def test_flash_attention_v0(self): 125 | M = 64 126 | N = 64 127 | K = 128 128 | P = 128 129 | 130 | kTM = 64 131 | kTN = 64 132 | kTK = 128 133 | kTP = 128 134 | 135 | self.run_flash_attention(M, N, K, P, kTM, kTN, kTK, kTP) 136 | 137 | def test_flash_attention_v1(self): 138 | M = 64 139 | N = 128 140 | K = 128 141 | P = 128 142 | 143 | kTM = 64 144 | kTN = 64 145 | kTK = 128 146 | kTP = 128 147 | 148 | self.run_flash_attention(M, N, K, P, kTM, kTN, kTK, kTP) 149 | 150 | def test_flash_attention_v2(self): 151 | M = 64 152 | N = 256 153 | K = 128 154 | P = 128 155 | 156 | kTM = 64 157 | kTN = 64 158 | kTK = 128 159 | kTP = 128 160 | 161 | self.run_flash_attention(M, N, K, P, kTM, kTN, kTK, kTP) 162 | 163 | 164 | if __name__ == "__main__": 165 | unittest.main() 166 | -------------------------------------------------------------------------------- /tests/python/test_scatter_nd.py: -------------------------------------------------------------------------------- 1 | import random 2 | import unittest 3 | 4 | import torch 5 | 6 | import context 7 | from pytiledcuda import scatter_nd 8 | 9 | 10 | class TestGemm(unittest.TestCase): 11 | 12 | def _compute_output_shape(self, index_dims, input_dims): 13 | end_size = index_dims[-1] 14 | out_shape = index_dims[:-1] 15 | for i in range(len(input_dims) - end_size): 16 | out_shape.append(input_dims[len(index_dims) + i]) 17 | return out_shape 18 | 19 | def setUp(self): 20 | torch.manual_seed(1234) 21 | 22 | def test_scatter_nd(self): 23 | data_shape = [7, 8, 9, 10] 24 | data = torch.empty(data_shape, dtype=torch.float32, 25 | device='cuda').fill_(5.0) 26 | scatter_data = data.flatten() 27 | 28 | indices_shape = [5, 2] 29 | indices = torch.empty(indices_shape, dtype=torch.int64, device='cuda') 30 | 31 | for i in range(indices_shape[0]): 32 | indices[i][0] = random.randint(0, data_shape[0] - 1) 33 | indices[i][1] = random.randint(0, data_shape[1] - 1) 34 | 35 | scatter_indices = indices.flatten() 36 | 37 | update_shape = self._compute_output_shape(indices_shape, data_shape) 38 | updates = torch.empty(update_shape, dtype=torch.float32, 39 | device='cuda').fill_(10.0) 40 | scatter_updates = updates.flatten() 41 | 42 | # import pytiledcuda 43 | scatter_nd(scatter_data, scatter_indices, scatter_updates) 44 | 45 | # Implement `scatter_nd` in Python. 46 | data[indices[:, 0], indices[:, 1]] = updates 47 | 48 | flattened_data = data.flatten() 49 | 50 | # Print data 51 | print(scatter_data) 52 | print(flattened_data) 53 | 54 | assert torch.allclose(scatter_data, flattened_data) 55 | 56 | 57 | if __name__ == "__main__": 58 | unittest.main() 59 | --------------------------------------------------------------------------------
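A closing note on the online-softmax renormalization implemented by the FlashAttention CPU reference in tests/python/test_flash_attn.py above. Writing m, l and O for that code's prev_maxes, prev_sums and output (these symbols are mine, introduced only to mirror the Python), each key/value chunk updates the running state as

$$
\begin{aligned}
m_{\text{cur}} &= \operatorname{rowmax}(S), \qquad \tilde{P} = \exp(S - m_{\text{cur}}), \qquad \ell_{\text{cur}} = \operatorname{rowsum}(\tilde{P}),\\
m_{\text{new}} &= \max(m_{\text{prev}}, m_{\text{cur}}), \qquad \alpha = e^{m_{\text{prev}} - m_{\text{new}}}, \qquad \beta = e^{m_{\text{cur}} - m_{\text{new}}},\\
\ell_{\text{new}} &= \alpha\,\ell_{\text{prev}} + \beta\,\ell_{\text{cur}},\\
O_{\text{new}} &= \frac{\alpha\,\ell_{\text{prev}}\,O_{\text{prev}} + \beta\,\tilde{P}\,V}{\ell_{\text{new}}},
\end{aligned}
$$

where S = q @ k for the current key chunk and V is the matching value chunk. This is the rescaling carried out by the renorm_prev and renorm_cur variables in the loop, and it is why the reference needs no separate normalization pass after the final chunk.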