├── .gitignore ├── container.def ├── LICENSE ├── rowsum.h ├── setup.py ├── mem.h ├── dispatch.h ├── readme.md ├── pytorch_custom_mma_cuda.cu ├── benchmark.py └── MMA.h /.gitignore: -------------------------------------------------------------------------------- 1 | .eggs/ 2 | build/ 3 | dist/ 4 | pytorch_custom_mma_cuda.egg-info/ 5 | container.sif -------------------------------------------------------------------------------- /container.def: -------------------------------------------------------------------------------- 1 | Bootstrap: docker 2 | From: nvidia/cuda:11.6.1-cudnn8-devel-ubuntu20.04 3 | 4 | %post 5 | # Downloads the latest package lists (important). 6 | apt-get update -y 7 | 8 | # Runs apt-get while ensuring that there are no user prompts that would 9 | # cause the build process to hang. 10 | DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ 11 | python3 \ 12 | python3-pip \ 13 | python3-setuptools \ 14 | python3-dev 15 | 16 | # Reduce the size of the image by deleting the package lists we downloaded, 17 | # which are useless now. 18 | rm -rf /var/lib/apt/lists/* 19 | 20 | # Install Python modules. 21 | pip3 install numpy 22 | pip3 install torch torchtyping --extra-index-url https://download.pytorch.org/whl/cu116 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Arthur Hennequin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
--------------------------------------------------------------------------------
/rowsum.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | 
3 | template <typename scalar_t, typename warp_tile_t>
4 | struct rowsum_accumulator {
5 |     static constexpr int N_tile = warp_tile_t::N_tile;
6 |     static constexpr int M_tile = warp_tile_t::M_tile;
7 | 
8 |     float acc;
9 | 
10 |     __device__ void zero() {
11 |         acc = 0;
12 |     }
13 | 
14 |     template <typename shared_fragment>
15 |     __device__ void add(shared_fragment& smem) {
16 |         if (threadIdx.x < N_tile) {
17 |             #pragma unroll
18 |             for (int i = 0; i < M_tile; i++) {
19 |                 acc += smem(threadIdx.x, i);
20 |             }
21 |         }
22 |     }
23 | 
24 |     __device__ void divide(scalar_t* smem, warp_tile_t& mma) {
25 |         if (threadIdx.x < N_tile) smem[threadIdx.x] = 1.f / acc;
26 |         __syncthreads();
27 | 
28 |         mma.pointwise([&](scalar_t el, int, int y) {
29 |             return el * smem[y];
30 |         });
31 |         __syncthreads();
32 |     }
33 | 
34 |     template <typename accessor>
35 |     __device__ void store(accessor gmem, int tile_y) {
36 |         if (threadIdx.x < N_tile) {
37 |             gmem[threadIdx.x + tile_y] = acc;
38 |         }
39 |     }
40 | };
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from functools import lru_cache
3 | from subprocess import DEVNULL, call
4 | from setuptools import setup, find_packages
5 | 
6 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension
7 | 
8 | @lru_cache(None)
9 | def cuda_toolkit_available():
10 |     try:
11 |         call(["nvcc"], stdout = DEVNULL, stderr = DEVNULL)
12 |         return True
13 |     except FileNotFoundError:
14 |         return False
15 | 
16 | def compile_args():
17 |     args = ["-fopenmp", "-ffast-math"]
18 |     if sys.platform == "darwin":
19 |         args = ["-Xpreprocessor", *args]
20 |     return args
21 | 
22 | def ext_modules():
23 |     if not cuda_toolkit_available():
24 |         return []
25 | 
26 |     return [
27 |         CUDAExtension(
28 |             "pytorch_custom_mma_cuda",
29 |             sources = ["pytorch_custom_mma_cuda.cu"]
30 |         )
31 |     ]
32 | 
33 | # main setup code
34 | 
35 | setup(
36 |     name = 'pytorch_custom_mma_cuda',
37 |     packages = find_packages(exclude=[]),
38 |     version = '0.0.3',
39 |     license='MIT',
40 |     install_requires=[
41 |         'torch>=1.10',
42 |         'torchtyping'
43 |     ],
44 |     setup_requires=[
45 |         'pytest-runner',
46 |     ],
47 |     tests_require=[
48 |         'pytest'
49 |     ],
50 |     ext_modules = ext_modules(),
51 |     cmdclass = {"build_ext": BuildExtension},
52 |     include_package_data = True,
53 | )
54 | 
--------------------------------------------------------------------------------
/mem.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | 
3 | namespace mem {
4 | template <typename T, int N_tile, int M_tile>
5 | struct shared_fragment {
6 |     static constexpr int N = N_tile;
7 |     static constexpr int M = M_tile;
8 |     static constexpr int stride = M + (sizeof(T) == 2 ? 8 : 1);
9 |     static constexpr int size = N * stride;
10 | 
11 |     T* smem;
12 | 
13 |     __device__ shared_fragment(char* shared_base)
14 |         : smem(reinterpret_cast<T*>(shared_base)) { }
15 | 
16 |     template <typename accessor>
17 |     __device__ void load(accessor gmem, int tile_x, int tile_y) {
18 |         for (int i = threadIdx.x; i < N * M; i += blockDim.x) {
19 |             int x = i % M;
20 |             int y = i / M;
21 |             smem[y * stride + x] = gmem[y + tile_y][x + tile_x];
22 |         }
23 |     }
24 | 
25 |     template <typename accessor>
26 |     __device__ void load_transpose(accessor gmem, int tile_x, int tile_y) {
27 |         for (int i = threadIdx.x; i < N * M; i += blockDim.x) {
28 |             int x = i % N;
29 |             int y = i / N;
30 |             smem[x * stride + y] = gmem[y + tile_y][x + tile_x];
31 |         }
32 |     }
33 | 
34 |     __device__ T& operator()(int x, int y) {
35 |         return smem[y * stride + x];
36 |     }
37 | 
38 |     template <typename accessor>
39 |     __device__ void store(accessor gmem, int tile_x, int tile_y) {
40 |         for (int i = threadIdx.x; i < N * M; i += blockDim.x) {
41 |             int x = i % M;
42 |             int y = i / M;
43 |             gmem[y + tile_y][x + tile_x] = smem[y * stride + x];
44 |         }
45 |     }
46 | 
47 |     __device__ char* next() {
48 |         return reinterpret_cast<char*>(smem + size);
49 |     }
50 | };
51 | } // namespace mem
--------------------------------------------------------------------------------
/dispatch.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | 
3 | #include <ATen/Dispatch.h>
4 | 
5 | // Custom dispatch inspired from
6 | // https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
7 | // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
8 | // https://github.com/swansontec/map-macro
9 | 
10 | // Macro utilities:
11 | #define REMOVE_PAREN_IMPL(...) __VA_ARGS__
12 | #define REMOVE_PAREN(args) REMOVE_PAREN_IMPL args
13 | 
14 | #define EVAL0(...) __VA_ARGS__
15 | #define EVAL1(...) EVAL0(EVAL0(EVAL0(__VA_ARGS__)))
16 | #define EVAL2(...) EVAL1(EVAL1(EVAL1(__VA_ARGS__)))
17 | #define EVAL3(...) EVAL2(EVAL2(EVAL2(__VA_ARGS__)))
18 | #define EVAL4(...) EVAL3(EVAL3(EVAL3(__VA_ARGS__)))
19 | #define EVAL(...) EVAL4(EVAL4(EVAL4(__VA_ARGS__)))
20 | 
21 | #define MAP_END(...)
22 | #define MAP_OUT
23 | 
24 | #define MAP_GET_END2() 0, MAP_END
25 | #define MAP_GET_END1(...) MAP_GET_END2
26 | #define MAP_GET_END(...) MAP_GET_END1
27 | #define MAP_NEXT0(test, next, ...) next MAP_OUT
28 | #define MAP_NEXT1(test, next) MAP_NEXT0(test, next, 0)
29 | #define MAP_NEXT(test, next) MAP_NEXT1(MAP_GET_END test, next)
30 | 
31 | #define MAP0(f, TYPE_NAME, CASE_CODE, x, peek, ...) f(TYPE_NAME, CASE_CODE, x) MAP_NEXT(peek, MAP1)(f, TYPE_NAME, CASE_CODE, peek, __VA_ARGS__)
32 | #define MAP1(f, TYPE_NAME, CASE_CODE, x, peek, ...) f(TYPE_NAME, CASE_CODE, x) MAP_NEXT(peek, MAP0)(f, TYPE_NAME, CASE_CODE, peek, __VA_ARGS__)
33 | #define MAP(f, TYPE_NAME, CASE_CODE, ...) EVAL(MAP1(f, TYPE_NAME, CASE_CODE, __VA_ARGS__, ()()(), ()()(), ()()(), 0))
34 | 
35 | // Type dispatch
36 | #define AT_TYPE_DISPATCH_CASE(TYPE_NAME, CASE_CODE, x) \
37 |     case x: { \
38 |         using TYPE_NAME C10_UNUSED_DISPATCH_CUDA_WORKAROUND = \
39 |             typename c10::impl::ScalarTypeToCPPType<x>::type; \
40 |         REMOVE_PAREN(CASE_CODE) \
41 |         break; \
42 |     }
43 | 
44 | #define AT_TYPE_DISPATCH_SWITCH(TYPE, TYPE_NAME, TYPES, CASE_CODE, DEFAULT_CODE) \
45 |     { \
46 |         switch (TYPE) { \
47 |             MAP(AT_TYPE_DISPATCH_CASE, TYPE_NAME, CASE_CODE, REMOVE_PAREN(TYPES)) \
48 |             default: { \
49 |                 REMOVE_PAREN(DEFAULT_CODE) \
50 |             } \
51 |         } \
52 |     }
53 | 
54 | // Value dispatch
55 | #define VALUE_DISPATCH_CASE(VALUE_NAME, CASE_CODE, x) \
56 |     case x: { \
57 |         constexpr const auto VALUE_NAME = x; \
58 |         REMOVE_PAREN(CASE_CODE) \
59 |         break; \
60 |     }
61 | 
62 | #define VALUE_DISPATCH_SWITCH(VALUE, VALUE_NAME, VALUES, CASE_CODE, DEFAULT_CODE) \
63 |     { \
64 |         switch (VALUE) { \
65 |             MAP(VALUE_DISPATCH_CASE, VALUE_NAME, CASE_CODE, REMOVE_PAREN(VALUES)) \
66 |             default: { \
67 |                 REMOVE_PAREN(DEFAULT_CODE) \
68 |             } \
69 |         } \
70 |     }
71 | 
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | ## Experiments with flash cosine similarity attention
2 | 
3 | Main repo: https://github.com/lucidrains/flash-cosine-sim-attention
4 | 
5 | ```python
6 | # O = softmax(scale * (Q * K^T) - scale) * V
7 | def plain_impl(Q: TensorType['b', 'i', 'd'],
8 |                K: TensorType['b', 'j', 'd'],
9 |                V: TensorType['b', 'j', 'd'],
10 |                scale=8) -> TensorType['b', 'i', 'd']:
11 |     C = einsum('... i d, ... j d -> ... i j', Q, K)
12 |     C = (C * scale - scale).softmax(dim = -1)
13 |     O = einsum('... i j, ... j d -> ... i d', C, V)
14 |     return O
15 | ```
16 | 
17 | Implements ideas from https://developer.nvidia.com/blog/cutlass-linear-algebra-cuda/
18 | 
19 | The goal is to match PyTorch performance while keeping the CUDA code simple to understand.
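
Why no running maximum is needed: Q and K are l2-normalized before the kernel is called (see `l2norm` in `benchmark.py`), so every entry of `Q K^T` lies in `[-1, 1]`, the exponent `scale * qk - scale` lies in `[-2 * scale, 0]`, and `exp(...)` stays in `(0, 1]`. The kernel can therefore accumulate raw exponentials plus a plain row sum `l` and divide once at the end, instead of tracking a running maximum the way standard flash attention does. A quick sanity check of that identity in plain PyTorch (the variable names below are illustrative, not part of this repo):

```python
import torch
import torch.nn.functional as F

b, n, d, scale = 2, 128, 64, 8
Q = F.normalize(torch.randn(b, n, d), dim = -1)  # unit-norm rows, as in benchmark.py
K = F.normalize(torch.randn(b, n, d), dim = -1)
V = torch.randn(b, n, d)

logits = scale * torch.einsum('... i d, ... j d -> ... i j', Q, K) - scale
assert logits.max() <= 0 and logits.min() >= -2 * scale  # bounded exponent

weights = logits.exp()                     # every entry in (0, 1], cannot overflow
l = weights.sum(dim = -1, keepdim = True)  # row sums, the `l` the kernel returns
O = (weights @ V) / l                      # accumulate first, normalize once at the end

assert torch.allclose(O, logits.softmax(dim = -1) @ V, atol = 1e-5)
```
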
20 | 21 | Results on RTX2070: 22 | ``` 23 | -------------------------------------------------------------------------------- 24 | batch: 32 head dim: 64 V dim: 64 dtype: torch.float32 25 | -------------------------------------------------------------------------------- 26 | seq_len: 64 slower: 0.64x kernel: 0.086ms baseline: 0.134ms 27 | seq_len: 128 slower: 1.02x kernel: 0.142ms baseline: 0.140ms 28 | seq_len: 256 slower: 1.13x kernel: 0.390ms baseline: 0.347ms 29 | seq_len: 512 slower: 1.18x kernel: 1.432ms baseline: 1.209ms 30 | seq_len: 1024 slower: 1.02x kernel: 5.140ms baseline: 5.042ms 31 | seq_len: 2048 slower: 0.94x kernel: 17.211ms baseline: 18.235ms 32 | seq_len: 4096 slower: 0.00x kernel: 69.789ms baseline: OOM 33 | seq_len: 8192 slower: 0.00x kernel: 279.815ms baseline: OOM 34 | -------------------------------------------------------------------------------- 35 | batch: 32 head dim: 64 V dim: 64 dtype: torch.float16 36 | -------------------------------------------------------------------------------- 37 | seq_len: 64 slower: 0.56x kernel: 0.078ms baseline: 0.140ms 38 | seq_len: 128 slower: 0.59x kernel: 0.097ms baseline: 0.164ms 39 | seq_len: 256 slower: 1.17x kernel: 0.209ms baseline: 0.178ms 40 | seq_len: 512 slower: 1.02x kernel: 0.562ms baseline: 0.552ms 41 | seq_len: 1024 slower: 0.95x kernel: 1.931ms baseline: 2.032ms 42 | seq_len: 2048 slower: 0.80x kernel: 7.147ms baseline: 8.928ms 43 | seq_len: 4096 slower: 0.82x kernel: 27.829ms baseline: 34.134ms 44 | seq_len: 8192 slower: 0.00x kernel: 109.877ms baseline: OOM 45 | ``` 46 | 47 | Results on A100: 48 | ``` 49 | -------------------------------------------------------------------------------- 50 | batch: 32 head dim: 64 V dim: 64 dtype: torch.float32 51 | -------------------------------------------------------------------------------- 52 | seq_len: 64 slower: 0.50x kernel: 0.082ms baseline: 0.165ms 53 | seq_len: 128 slower: 0.56x kernel: 0.092ms baseline: 0.164ms 54 | seq_len: 256 slower: 0.88x kernel: 0.160ms baseline: 0.182ms 55 | seq_len: 512 slower: 0.65x kernel: 0.325ms baseline: 0.496ms 56 | seq_len: 1024 slower: 0.73x kernel: 1.084ms baseline: 1.489ms 57 | seq_len: 2048 slower: 0.63x kernel: 3.362ms baseline: 5.371ms 58 | seq_len: 4096 slower: 0.57x kernel: 12.065ms baseline: 21.270ms 59 | seq_len: 8192 slower: 0.00x kernel: 46.413ms baseline: OOM 60 | seq_len: 16384 slower: 0.00x kernel: 180.894ms baseline: OOM 61 | seq_len: 32768 slower: 0.00x kernel: 744.898ms baseline: OOM 62 | -------------------------------------------------------------------------------- 63 | batch: 32 head dim: 64 V dim: 64 dtype: torch.float16 64 | -------------------------------------------------------------------------------- 65 | seq_len: 64 slower: 0.41x kernel: 0.066ms baseline: 0.160ms 66 | seq_len: 128 slower: 0.38x kernel: 0.077ms baseline: 0.201ms 67 | seq_len: 256 slower: 0.63x kernel: 0.102ms baseline: 0.161ms 68 | seq_len: 512 slower: 1.05x kernel: 0.170ms baseline: 0.162ms 69 | seq_len: 1024 slower: 0.71x kernel: 0.372ms baseline: 0.521ms 70 | seq_len: 2048 slower: 0.77x kernel: 1.427ms baseline: 1.851ms 71 | seq_len: 4096 slower: 0.69x kernel: 4.944ms baseline: 7.167ms 72 | seq_len: 8192 slower: 0.63x kernel: 18.202ms baseline: 29.042ms 73 | seq_len: 16384 slower: 0.00x kernel: 68.439ms baseline: OOM 74 | seq_len: 32768 slower: 0.00x kernel: 262.238ms baseline: OOM 75 | ``` 76 | 77 | ## Building the image to run in HPC cluster 78 | 79 | On local computer (need root access): 80 | 81 | ```bash 82 | sudo singularity build 
container.sif container.def
83 | ```
84 | 
85 | Run an interactive shell:
86 | ```bash
87 | singularity shell --nv container.sif
88 | ```
89 | 
90 | Install & run:
91 | ```bash
92 | python3 setup.py install --user
93 | python3 benchmark.py
94 | ```
--------------------------------------------------------------------------------
/pytorch_custom_mma_cuda.cu:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <c10/cuda/CUDAGuard.h>
3 | #include <cuda.h>
4 | #include <cuda_runtime.h>
5 | #include <iostream>
6 | 
7 | #include <vector>
8 | 
9 | #include "dispatch.h"
10 | #include "MMA.h"
11 | #include "mem.h"
12 | #include "rowsum.h"
13 | 
14 | #define CHECK_LAST_CUDA_ERROR() check(__FILE__, __LINE__)
15 | void check(const char* file, const int line) {
16 |     cudaError_t err = cudaGetLastError();
17 | 
18 |     if (err != cudaSuccess) {
19 |         std::cerr << "CUDA Error at: " << file << ":" << line << std::endl;
20 |         std::cerr << cudaGetErrorString(err) << std::endl;
21 |     }
22 | }
23 | 
24 | #define ACCESSOR(x, n, type) x.packed_accessor32<type, n, torch::RestrictPtrTraits>()
25 | 
26 | // type alias
27 | 
28 | template <typename scalar_t, int dims>
29 | using PackedAccessor = torch::PackedTensorAccessor32<scalar_t, dims, torch::RestrictPtrTraits>;
30 | 
31 | template <typename scalar_t, int out_dim>
32 | __global__ void forward_kernel(
33 |     const PackedAccessor<scalar_t, 3> Q,
34 |     const PackedAccessor<scalar_t, 3> K,
35 |     const PackedAccessor<scalar_t, 3> V,
36 |     PackedAccessor<scalar_t, 3> O,
37 |     PackedAccessor<scalar_t, 2> l,
38 |     const float scale
39 | ) {
40 |     const int batch = blockIdx.y;
41 | 
42 |     const int N = Q.size(1);
43 |     const int M = K.size(1);
44 |     const int QK_dim = Q.size(2);
45 | 
46 |     using QK_mma_t = mma::warp_tile<scalar_t, 64, 64>;
47 |     using out_mma_t = mma::warp_tile<scalar_t, 64, out_dim>;
48 | 
49 |     using Q_sm_t = mem::shared_fragment<scalar_t, 16, 64>;
50 |     using K_sm_t = mem::shared_fragment<scalar_t, 16, 64>;
51 |     using C_sm_t = mem::shared_fragment<scalar_t, 64, 64>;
52 | 
53 |     const int tile_y = blockIdx.x * QK_mma_t::N_tile;
54 | 
55 |     __shared__ scalar_t _shared_mem[Q_sm_t::size + K_sm_t::size + C_sm_t::size];
56 | 
57 |     QK_mma_t QK_mma; // 32x16 tile per warp in registers -> process 64x64 with the block
58 |     out_mma_t out_mma;
59 |     rowsum_accumulator<scalar_t, QK_mma_t> L_acc;
60 | 
61 |     Q_sm_t Q_sm{reinterpret_cast<char*>(_shared_mem)};
62 |     K_sm_t K_sm{Q_sm.next()};
63 |     C_sm_t C_sm{K_sm.next()};
64 | 
65 |     out_mma.zero();
66 |     L_acc.zero();
67 | 
68 |     for (int tile_x = 0; tile_x < M; tile_x += QK_mma_t::M_tile) {
69 |         QK_mma.zero();
70 | 
71 |         for (int k = 0; k < QK_dim; k += K_sm_t::N) {
72 |             Q_sm.load_transpose(Q[batch], k, tile_y); // TODO: reload only if needed (if head_dim=16, no reload)
73 |             K_sm.load_transpose(K[batch], k, tile_x);
74 |             __syncthreads();
75 | 
76 |             QK_mma.mma(Q_sm, K_sm, 0, 0, K_sm_t::N);
77 |             __syncthreads();
78 |         }
79 | 
80 |         QK_mma.pointwise([&](scalar_t el, int, int) -> scalar_t {
81 |             return expf(scale * el - scale);
82 |         });
83 | 
84 |         QK_mma.store_transpose(C_sm);
85 |         __syncthreads();
86 | 
87 |         L_acc.add(C_sm);
88 | 
89 |         // Second matmul:
90 |         for (int k = 0; k < QK_mma_t::M_tile; k += K_sm_t::N) {
91 |             K_sm.load(V[batch], 0, tile_x + k); // reuse K shared mem for V
92 |             __syncthreads();
93 | 
94 |             out_mma.mma(C_sm, K_sm, k, 0, K_sm_t::N);
95 |             __syncthreads();
96 |         }
97 |     }
98 | 
99 |     L_acc.store(l[batch], tile_y);
100 |     L_acc.divide(C_sm.smem, out_mma);
101 | 
102 |     out_mma.store(O[batch], C_sm, 0, tile_y);
103 | }
104 | 
105 | std::vector<torch::Tensor> mma_forward(
106 |     torch::Tensor Q,
107 |     torch::Tensor K,
108 |     torch::Tensor V,
109 |     float scale
110 | ) {
111 |     const at::cuda::OptionalCUDAGuard device_guard(device_of(Q));
112 | 
113 |     const int batch = Q.size(0);
114 |     const int N = Q.size(1);
115 |     const int M = K.size(1);
116 |     const int QK_dim = Q.size(2);
117 |     const int V_dim = V.size(2);
118 | 
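    // Allocate the outputs on Q's device and dtype: O is the attention output
    // [batch, N, V_dim]; l is the per-row sum of exp(scale * qk - scale), i.e. the
    // softmax denominator [batch, N], returned alongside O (benchmark.py saves it
    // for a future backward pass).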
119 |     auto options = torch::TensorOptions().device(device_of(Q)).dtype(Q.scalar_type());
120 |     auto O = at::empty({batch, N, V_dim}, options);
121 |     auto l = at::empty({batch, N}, options);
122 | 
123 |     const dim3 threads_per_block(256);
124 | 
125 |     AT_TYPE_DISPATCH_SWITCH(Q.scalar_type(), scalar_t, (at::ScalarType::Float, at::ScalarType::Half), (
126 |         VALUE_DISPATCH_SWITCH(V_dim, out_dim, (64), (
127 |             const int N_tile = 64;
128 |             const dim3 blocks(N / N_tile, batch);
129 |             forward_kernel<scalar_t, out_dim><<<blocks, threads_per_block>>>(
130 |                 ACCESSOR(Q, 3, scalar_t),
131 |                 ACCESSOR(K, 3, scalar_t),
132 |                 ACCESSOR(V, 3, scalar_t),
133 |                 ACCESSOR(O, 3, scalar_t),
134 |                 ACCESSOR(l, 2, scalar_t),
135 |                 scale
136 |             );
137 |         ), ())
138 |     ), ())
139 | 
140 |     // handle error
141 |     cudaDeviceSynchronize();
142 |     CHECK_LAST_CUDA_ERROR();
143 | 
144 |     return { O, l };
145 | }
146 | 
147 | // bind
148 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
149 |     m.def("forward", &mma_forward, "MMA Forward");
150 |     // m.def("backward", &mma_backward, "MMA Backward");
151 | }
152 | 
--------------------------------------------------------------------------------
/benchmark.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | from torchtyping import TensorType
4 | from torch import einsum
5 | 
6 | from torch.cuda import synchronize, Event
7 | from functools import wraps, partial
8 | import torch.nn.functional as F
9 | 
10 | # argparse
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument('--only-forwards', default = True, action = 'store_true')
13 | parser.add_argument('--only-backwards', default = False, action = 'store_true')
14 | args = parser.parse_args()
15 | 
16 | assert not (args.only_forwards and args.only_backwards)
17 | 
18 | torch.manual_seed(0)
19 | 
20 | # constants
21 | TEST_SEQUENCE_LENGTHS = [64, 128, 256, 512, 1024, 2048, 4096, 8192]
22 | 
23 | TEST_FORWARDS = not args.only_backwards
24 | TEST_BACKWARDS = not args.only_forwards
25 | 
26 | timer = partial(Event, enable_timing = True)
27 | 
28 | def benchmark(
29 |     fn,
30 |     *,
31 |     num_times = 10,
32 |     warmup_iters = 10,
33 |     forwards = True,
34 |     backwards = False
35 | ):
36 |     assert forwards or backwards
37 | 
38 |     @wraps(fn)
39 |     def inner(*args, **kwargs):
40 |         # warmup
41 |         for _ in range(warmup_iters):
42 |             loss = fn(*args, **kwargs).sum()
43 |             #loss.backward()
44 | 
45 |         # average across number of function calls
46 |         all_measured_times_ms = 0.
47 | 
48 |         for _ in range(num_times):
49 |             start_event = timer()
50 |             end_event = timer()
51 | 
52 |             if forwards:
53 |                 start_event.record()
54 | 
55 |             o = fn(*args, **kwargs)
56 | 
57 |             if not backwards:
58 |                 end_event.record()
59 | 
60 |             if not forwards:
61 |                 start_event.record()
62 | 
63 |             if backwards:
64 |                 loss = o.sum()
65 |                 loss.backward()
66 |                 end_event.record()
67 | 
68 |             synchronize()
69 | 
70 |             elapsed_time_ms = start_event.elapsed_time(end_event)
71 |             all_measured_times_ms += elapsed_time_ms
72 | 
73 |         return all_measured_times_ms / num_times
74 | 
75 |     return inner
76 | 
77 | # O = softmax(scale * (Q * K^T) - scale) * V
78 | def plain_impl(Q: TensorType['b', 'i', 'd'],
79 |                K: TensorType['b', 'j', 'd'],
80 |                V: TensorType['b', 'j', 'd'],
81 |                scale=8) -> TensorType['b', 'i', 'd']:
82 |     C = einsum('... i d, ... j d -> ... i j', Q, K)
83 |     C = (C * scale - scale).softmax(dim = -1)
84 |     O = einsum('... i j, ... j d -> ... i d', C, V)
85 |     return O
86 | 
87 | import pytorch_custom_mma_cuda
88 | 
89 | class MMACudaFunction(torch.autograd.Function):
90 |     @staticmethod
91 |     def forward(ctx, Q, K, V, scale=8):
92 |         C, l = pytorch_custom_mma_cuda.forward(Q, K, V, scale)
93 |         ctx.save_for_backward(Q, K, V, C, l)
94 |         return C
95 |     @staticmethod
96 |     def backward(ctx, grad_c):
97 |         #Q, K, V, C, l = ctx.saved_tensors # TODO
98 |         return None, None, None, None # one gradient per forward input (Q, K, V, scale)
99 | 
100 | # O = softmax(scale * (Q * K^T) - scale) * V
101 | def cuda_impl(Q: TensorType['b', 'i', 'd'],
102 |               K: TensorType['b', 'j', 'd'],
103 |               V: TensorType['b', 'j', 'd'],
104 |               scale=8) -> TensorType['b', 'i', 'd']:
105 |     return MMACudaFunction.apply(Q, K, V, scale)
106 | 
107 | plain_fn = benchmark(
108 |     plain_impl,
109 |     forwards = TEST_FORWARDS,
110 |     backwards = TEST_BACKWARDS
111 | )
112 | 
113 | cuda_fn = benchmark(
114 |     cuda_impl,
115 |     forwards = TEST_FORWARDS,
116 |     backwards = TEST_BACKWARDS
117 | )
118 | 
119 | def allclose(a, b, atol = 1e-2):
120 |     diff = (a - b).abs().amax()
121 |     return diff <= atol
122 | 
123 | def l2norm(t):
124 |     return F.normalize(t, dim = -1)
125 | 
126 | def bench(batch_size=32, head_dim=64, v_dim=64, dtype=torch.float32):
127 |     print("-" * 80)
128 |     print(f'batch: {batch_size}\thead dim: {head_dim}\tV dim: {v_dim}\t\tdtype: {dtype}')
129 |     print("-" * 80)
130 |     for seq_len in TEST_SEQUENCE_LENGTHS:
131 |         Q = torch.randn(batch_size, seq_len, head_dim, dtype=dtype).cuda().requires_grad_()
132 |         K = torch.randn(batch_size, seq_len, head_dim, dtype=dtype).cuda().requires_grad_()
133 |         V = torch.randn(batch_size, seq_len, v_dim, dtype=dtype).cuda().requires_grad_()
134 |         #V = torch.ones(batch_size, seq_len, v_dim, dtype=dtype).cuda().requires_grad_()
135 | 
136 |         Q, K = map(l2norm, (Q, K))
137 | 
138 |         if (seq_len <= 2048):
139 |             # assert correctness
140 |             C_plain = plain_impl(Q, K, V)
141 |             C_cuda = cuda_impl(Q, K, V)
142 | 
143 |             #print(C_plain, C_plain.shape)
144 |             #torch.set_printoptions(profile="full")
145 |             #print(C_cuda, C_cuda.shape)
146 |             #torch.set_printoptions(profile="default") # reset
147 |             assert allclose(C_plain, C_cuda)
148 | 
149 |         # benchmark
150 |         fused_time = cuda_fn(Q, K, V)
151 |         try:
152 |             baseline_time = plain_fn(Q, K, V)
153 |         except RuntimeError: # baseline runs out of memory at long sequence lengths
154 |             torch.cuda.empty_cache()
155 |             baseline_time = -1
156 | 
157 |         times_slower = (fused_time / baseline_time) if baseline_time != -1 else 0.
158 |         baseline_time_str = ' OOM' if baseline_time == -1 else f"{baseline_time:7.3f}ms"
159 | 
160 |         print(f'seq_len: {seq_len}\tslower: {times_slower:.2f}x\tkernel: {fused_time:7.3f}ms\tbaseline: {baseline_time_str}')
161 | 
162 | bench(dtype=torch.float32)
163 | bench(dtype=torch.float16)
--------------------------------------------------------------------------------
/MMA.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | 
3 | #include <mma.h>
4 | 
5 | namespace mma {
6 | template <typename T, int N_tile_, int M_tile_>
7 | struct warp_tile {
8 |     // Dimensions of the tile, in threads:
9 |     static constexpr int N_tile = N_tile_;
10 |     static constexpr int M_tile = M_tile_;
11 |     static constexpr int K_tile = 1;
12 | 
13 |     // Warp layout within a block:
14 |     static constexpr int N_block = 2;
15 |     static constexpr int M_block = 4;
16 | 
17 |     // Thread layout within a warp:
18 |     static constexpr int N_warp = 8;
19 |     static constexpr int M_warp = 4;
20 | 
21 |     // How much data is processed by a single thread:
22 |     static constexpr int N_thread = N_tile / (N_warp * N_block);
23 |     static constexpr int M_thread = M_tile / (M_warp * M_block);
24 | 
25 |     static_assert(N_warp * N_block * N_thread == N_tile);
26 |     static_assert(M_warp * M_block * M_thread == M_tile);
27 |     static_assert(N_warp * M_warp == 32);
28 |     static_assert(N_block * M_block * N_warp * M_warp == 256); // blockDim.x
29 | 
30 |     // Registers:
31 |     float C_frag[N_thread * M_thread]; // N x M fragment
32 | 
33 |     int warp_x;   // x offset of the warp within the block tile
34 |     int warp_y;   // y offset of the warp within the block tile
35 |     int thread_x; // x offset of the thread within the warp tile
36 |     int thread_y; // y offset of the thread within the warp tile
37 | 
38 |     __device__ warp_tile() {
39 |         int warp_id = threadIdx.x / 32;
40 |         warp_x = (warp_id % M_block);
41 |         warp_y = (warp_id / M_block);
42 | 
43 |         int lane_id = threadIdx.x % 32;
44 |         thread_x = warp_x * M_warp * M_thread + lane_id % M_warp;
45 |         thread_y = warp_y * N_warp * N_thread + lane_id / M_warp;
46 |     }
47 | 
48 |     // Initialize C to all zeros
49 |     __device__ void zero() {
50 |         #pragma unroll
51 |         for (int i = 0; i < N_thread * M_thread; i++) {
52 |             C_frag[i] = 0.f;
53 |         }
54 |     }
55 | 
56 |     // Performs C = A * B + C
57 |     template <typename fragA, typename fragB>
58 |     __device__ void mma(fragA& A_sm, fragB& B_sm, int ka0, int kb0, int D) {
59 |         float A_frag[N_thread]; // N x 1 fragment
60 |         float B_frag[M_thread]; // 1 x M fragment
61 | 
62 |         for (int k = 0; k < D; k += K_tile) {
63 |             // Load a N x 1 fragment of A from shared memory to registers:
64 |             #pragma unroll
65 |             for (int i = 0; i < N_thread; i++) {
66 |                 A_frag[i] = A_sm(i * N_warp + thread_y, ka0 + k);
67 |             }
68 | 
69 |             // Load a 1 x M fragment of B from shared memory to registers:
70 |             #pragma unroll
71 |             for (int i = 0; i < M_thread; i++) {
72 |                 B_frag[i] = B_sm(i * M_warp + thread_x, kb0 + k);
73 |             }
74 | 
75 |             // Compute:
76 |             #pragma unroll
77 |             for (int i = 0; i < N_thread; i++) {
78 |                 #pragma unroll
79 |                 for (int j = 0; j < M_thread ; j++) {
80 |                     C_frag[i * M_thread + j] += A_frag[i] * B_frag[j];
81 |                 }
82 |             }
83 |         }
84 |     }
85 | 
86 |     // Perform a pointwise operation, specified by the given lambda, on C
87 |     template <typename F>
88 |     __device__ void pointwise(F&& op) {
89 |         #pragma unroll
90 |         for (int i = 0; i < N_thread; i++) {
91 |             int row = i * N_warp + thread_y;
92 |             #pragma unroll
93 |             for (int j = 0; j < M_thread; j++) {
94 |                 int col = j * M_warp + thread_x;
95 |                 C_frag[i * M_thread + j] = op(C_frag[i * M_thread + j], col, row);
96 |             }
97 |         }
98 |     }
99 | 
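    // Worked example of the thread mapping used above, for the 64x64 block tile used
    // by forward_kernel (N_tile = M_tile = 64, so N_thread = M_thread = 4): the 8 warps
    // form a 2x4 (N_block x M_block) grid and each warp owns a 32x16 sub-tile.
    // Thread threadIdx.x = 37 has warp_id = 1 (warp_x = 1, warp_y = 0) and lane_id = 5,
    // giving thread_x = 1*4*4 + 5%4 = 17 and thread_y = 0 + 5/4 = 1, so mma() and
    // pointwise() touch the C elements at rows {1, 9, 17, 25} and columns {17, 21, 25, 29}.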
100 |     // Copy C from registers to shared memory
101 |     template <typename shared_fragment>
102 |     __device__ void store(shared_fragment& C_sm) {
103 |         #pragma unroll
104 |         for (int i = 0; i < N_thread; i++) {
105 |             #pragma unroll
106 |             for (int j = 0; j < M_thread ; j++) {
107 |                 C_sm(j * M_warp + thread_x, i * N_warp + thread_y) = C_frag[i * M_thread + j];
108 |             }
109 |         }
110 |     }
111 | 
112 |     template <typename shared_fragment>
113 |     __device__ void store_transpose(shared_fragment& C_sm) {
114 |         #pragma unroll
115 |         for (int i = 0; i < N_thread; i++) {
116 |             #pragma unroll
117 |             for (int j = 0; j < M_thread ; j++) {
118 |                 C_sm(i * N_warp + thread_y, j * M_warp + thread_x) = C_frag[i * M_thread + j];
119 |             }
120 |         }
121 |     }
122 | 
123 |     // Stream C from registers to global memory using temporary shared memory buffer
124 |     template <typename accessor, typename shared_fragment>
125 |     __device__ void store(accessor gmem, shared_fragment& smem, int tile_x, int tile_y) {
126 |         store(smem);
127 |         __syncthreads();
128 |         smem.store(gmem, tile_x, tile_y);
129 |     }
130 | };
131 | 
132 | using namespace nvcuda;
133 | template <int N_tile_, int M_tile_>
134 | struct warp_tile<c10::Half, N_tile_, M_tile_> {
135 |     // Dimensions of the tile, in threads:
136 |     static constexpr int N_tile = N_tile_;
137 |     static constexpr int M_tile = M_tile_;
138 |     static constexpr int K_tile = 16;
139 | 
140 |     // Warp layout within a block:
141 |     static constexpr int N_block = 2;
142 |     static constexpr int M_block = 4;
143 | 
144 |     // Thread layout within a warp:
145 |     static constexpr int N_warp = 16;
146 |     static constexpr int M_warp = 16;
147 | 
148 |     // How much data is processed by a single thread:
149 |     static constexpr int N_thread = N_tile / (N_warp * N_block);
150 |     static constexpr int M_thread = M_tile / (M_warp * M_block);
151 | 
152 |     static_assert(N_warp * N_block * N_thread == N_tile);
153 |     static_assert(M_warp * M_block * M_thread == M_tile);
154 | 
155 |     using output_t = float; // TODO: make this a template parameter
156 | 
157 |     // Registers:
158 |     wmma::fragment<wmma::accumulator, 16, 16, 16, output_t> C_frag[N_thread * M_thread];
159 | 
160 |     int warp_x; // x offset of the warp within the block tile
161 |     int warp_y; // y offset of the warp within the block tile
162 | 
163 |     __device__ warp_tile() {
164 |         int warp_id = threadIdx.x / 32;
165 |         warp_x = (warp_id % M_block);
166 |         warp_y = (warp_id / M_block);
167 |     }
168 | 
169 |     // Initialize C to all zeros
170 |     __device__ void zero() {
171 |         #pragma unroll
172 |         for (int i = 0; i < N_thread; i++) {
173 |             #pragma unroll
174 |             for (int j = 0; j < M_thread; j++) {
175 |                 #pragma unroll
176 |                 for (int k = 0; k < C_frag[i * M_thread + j].num_elements; k++) {
177 |                     C_frag[i * M_thread + j].x[k] = (c10::Half) 0.f;
178 |                 }
179 |             }
180 |         }
181 |     }
182 | 
183 |     // Performs C = A * B + C
184 |     template <typename fragA, typename fragB>
185 |     __device__ void mma(fragA& A_sm, fragB& B_sm, int ka0, int kb0, int D) {
186 |         wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> A_frag;
187 |         wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> B_frag;
188 | 
189 |         for (int k = 0; k < D; k += K_tile) {
190 |             #pragma unroll
191 |             for (int j = 0; j < M_thread; j++) {
192 |                 // Load a 1 x M fragment of B from shared memory to registers:
193 |                 int x = (warp_x * M_thread + j) * M_warp;
194 |                 wmma::load_matrix_sync(B_frag, reinterpret_cast<half*>(&B_sm(x, kb0 + k)), B_sm.stride);
195 | 
196 |                 #pragma unroll
197 |                 for (int i = 0; i < N_thread; i++) {
198 |                     // Load a N x 1 fragment of A from shared memory to registers:
199 |                     int y = (warp_y * N_thread + i) * N_warp;
200 |                     wmma::load_matrix_sync(A_frag, reinterpret_cast<half*>(&A_sm(y, ka0 + k)), A_sm.stride);
201 | 
202 |                     // Compute:
203 |                     wmma::mma_sync(C_frag[i * M_thread + j], A_frag, B_frag, C_frag[i * M_thread + j]);
204 |                 }
205 |             }
206 |         }
207 |     }
208 | 
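    // getWarpRow / getWarpCol recover the (row, col) position, inside this warp's
    // 16x16 wmma accumulator tile, of element k of a thread's fragment. The wmma API
    // does not expose that mapping, and it is architecture dependent (Volta / sm_70
    // lays the accumulator out differently from sm_75 and later), hence the branches
    // below. pointwise() and the store routines rely on these helpers to address
    // individual accumulator elements.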
209 |     __device__ int getWarpRow(int i) {
210 |         int tid = threadIdx.x % 32;
211 | #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 700)
212 |         if (std::is_same<output_t, c10::Half>::value) {
213 |             return (tid & 3) + ((tid & 4) << 1) + ((tid & 16) >> 2);
214 |         } else {
215 |             return (tid & 16) / 4 + 2 * (tid & 4) + (tid & 1) + (i & 2);
216 |         }
217 | #else
218 |         return (i & 2) * 4 + tid / 4;
219 | #endif
220 |     }
221 | 
222 |     __device__ int getWarpCol(int i) {
223 |         int tid = threadIdx.x % 32;
224 | #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 700)
225 |         if (std::is_same<output_t, c10::Half>::value) {
226 |             return (i & 7) + (tid & 8);
227 |         } else {
228 |             return (tid & 10) + (i & 5);
229 |         }
230 | #else
231 |         return (tid % 4) * 2 + i % 2 + (i & 4) * 2;
232 | #endif
233 |     }
234 | 
235 |     // Perform a pointwise operation, specified by the given lambda, on C
236 |     template <typename F>
237 |     __device__ void pointwise(F&& op) {
238 |         #pragma unroll
239 |         for (int i = 0; i < N_thread; i++) {
240 |             #pragma unroll
241 |             for (int j = 0; j < M_thread; j++) {
242 |                 #pragma unroll
243 |                 for (int k = 0; k < C_frag[i * M_thread + j].num_elements; k++) {
244 |                     int col = getWarpCol(k) + (warp_x * M_thread + j) * M_warp;
245 |                     int row = getWarpRow(k) + (warp_y * N_thread + i) * N_warp;
246 |                     C_frag[i * M_thread + j].x[k] = op(C_frag[i * M_thread + j].x[k], col, row);
247 |                 }
248 |             }
249 |         }
250 |     }
251 | 
252 |     // Copy C from registers to shared memory
253 |     template <typename shared_fragment>
254 |     __device__ void store(shared_fragment& C_sm) {
255 |         #pragma unroll
256 |         for (int i = 0; i < N_thread; i++) {
257 |             #pragma unroll
258 |             for (int j = 0; j < M_thread; j++) {
259 |                 #pragma unroll
260 |                 for (int k = 0; k < C_frag[i * M_thread + j].num_elements; k++) {
261 |                     int col = getWarpCol(k) + (warp_x * M_thread + j) * M_warp;
262 |                     int row = getWarpRow(k) + (warp_y * N_thread + i) * N_warp;
263 |                     C_sm(col, row) = C_frag[i * M_thread + j].x[k];
264 |                 }
265 |             }
266 |         }
267 |     }
268 | 
269 |     template <typename shared_fragment>
270 |     __device__ void store_transpose(shared_fragment& C_sm) {
271 |         #pragma unroll
272 |         for (int i = 0; i < N_thread; i++) {
273 |             #pragma unroll
274 |             for (int j = 0; j < M_thread; j++) {
275 |                 #pragma unroll
276 |                 for (int k = 0; k < C_frag[i * M_thread + j].num_elements; k++) {
277 |                     int col = getWarpCol(k) + (warp_x * M_thread + j) * M_warp;
278 |                     int row = getWarpRow(k) + (warp_y * N_thread + i) * N_warp;
279 |                     C_sm(row, col) = C_frag[i * M_thread + j].x[k];
280 |                 }
281 |             }
282 |         }
283 |     }
284 | 
285 |     // Stream C from registers to global memory using temporary shared memory buffer
286 |     template <typename accessor, typename shared_fragment>
287 |     __device__ void store(accessor gmem, shared_fragment& smem, int tile_x, int tile_y) {
288 |         store(smem);
289 |         __syncthreads();
290 |         smem.store(gmem, tile_x, tile_y);
291 |     }
292 | };
293 | } // namespace mma
--------------------------------------------------------------------------------