├── matvec ├── benchmarks │ ├── kernel_1_vs_cublas-gflops.txt │ ├── kernel_1_vs_cublas-memory.txt │ ├── kernel_2_vs_cublas-gflops.txt │ ├── kernel_2_vs_cublas-memory.txt │ ├── kernel_3_vs_cublas-gflops.txt │ ├── kernel_3_vs_cublas-memory.txt │ ├── kernel_4_vs_cublas-gflops.txt │ ├── kernel_4_vs_cublas-memory.txt │ ├── kernel_custom_vs_cublas-gflops.txt │ ├── kernel_custom_vs_cublas-memory.txt │ └── plot.py ├── media │ ├── coalesced-access.png │ ├── benchmark_results.png │ └── sgemv-computation.png ├── include │ ├── cublas_0.cuh │ ├── naive_1.cuh │ ├── vectorized_4.cuh │ ├── coalesced_warp_2.cuh │ ├── coalesced_warpblock_3.cuh │ └── utils.cuh ├── Makefile ├── kernels │ ├── cublas_0.cu │ ├── naive_1.cu │ ├── utils.cu │ ├── coalesced_warp_2.cu │ ├── coalesced_warpblock_3.cu │ └── vectorized_4.cu ├── src.cu └── README.md ├── softmax ├── .vscode │ └── settings.json ├── softmax ├── media │ ├── max_reduction.png │ ├── naive_thread_mapping.png │ └── threads_collab_load.png ├── benchmarks │ ├── benchmark_1.png │ ├── exec_time_ms_cuda.txt │ ├── exec_time_ms_torch.txt │ └── plot.py ├── include │ ├── shfl_3.cuh │ ├── naive_0.cuh │ ├── online_1.cuh │ ├── sharedmem_2.cuh │ ├── vectorized_4.cuh │ ├── blocktiling_5.cuh │ └── cuda_utils.cuh ├── Makefile ├── torch_benchmark.py ├── kernels │ ├── naive_0.cu │ ├── online_1.cu │ ├── sharedmem_2.cu │ ├── blocktiling_5.cu │ ├── vectorized_4.cu │ └── shfl_3.cu ├── test.cu └── bench.cu ├── .gitignore ├── attention ├── execution_speed_comparison.png ├── execution_speed_comparison_tensorcores.png ├── build.cpp ├── pycublas.py ├── test.py ├── bench.py └── attn.cu ├── flash-attention ├── fa1 │ ├── build.cpp │ ├── smolattn.py │ └── flash-attn-1.cu ├── fa2 │ ├── build.cpp │ ├── smolattn2.py │ ├── smolattn2_fp16.py │ ├── flash-attn-2.cu │ └── flash2_fwd_fp16.cu └── triton_attn.py ├── matmul ├── torch_benchmark.py └── main.cu ├── README.md ├── query-device └── main.cu └── LICENSE /matvec/benchmarks/kernel_1_vs_cublas-gflops.txt: -------------------------------------------------------------------------------- 1 | 4096 8.105875 43.895515 2 | -------------------------------------------------------------------------------- /matvec/benchmarks/kernel_1_vs_cublas-memory.txt: -------------------------------------------------------------------------------- 1 | 4096 14.463547 78.324036 2 | -------------------------------------------------------------------------------- /matvec/benchmarks/kernel_2_vs_cublas-gflops.txt: -------------------------------------------------------------------------------- 1 | 4096 27.420919 43.778225 2 | -------------------------------------------------------------------------------- /matvec/benchmarks/kernel_2_vs_cublas-memory.txt: -------------------------------------------------------------------------------- 1 | 4096 48.927948 78.114754 2 | -------------------------------------------------------------------------------- /matvec/benchmarks/kernel_3_vs_cublas-gflops.txt: -------------------------------------------------------------------------------- 1 | 4096 41.744339 44.043011 2 | -------------------------------------------------------------------------------- /matvec/benchmarks/kernel_3_vs_cublas-memory.txt: -------------------------------------------------------------------------------- 1 | 4096 74.485626 78.587219 2 | -------------------------------------------------------------------------------- /matvec/benchmarks/kernel_4_vs_cublas-gflops.txt: -------------------------------------------------------------------------------- 1 | 4096 49.461132 
44.086529 2 | -------------------------------------------------------------------------------- /matvec/benchmarks/kernel_4_vs_cublas-memory.txt: -------------------------------------------------------------------------------- 1 | 4096 88.254936 78.664871 2 | -------------------------------------------------------------------------------- /softmax/.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "makefile.configureOnOpen": true 3 | } -------------------------------------------------------------------------------- /softmax/softmax: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Maharshi-Pandya/cudacodes/HEAD/softmax/softmax -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.exe 2 | *.exp 3 | *.lib 4 | */build/* 5 | __pycache__ 6 | scratchpad.md 7 | .vscode 8 | *.ptx -------------------------------------------------------------------------------- /softmax/media/max_reduction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Maharshi-Pandya/cudacodes/HEAD/softmax/media/max_reduction.png -------------------------------------------------------------------------------- /matvec/media/coalesced-access.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Maharshi-Pandya/cudacodes/HEAD/matvec/media/coalesced-access.png -------------------------------------------------------------------------------- /matvec/media/benchmark_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Maharshi-Pandya/cudacodes/HEAD/matvec/media/benchmark_results.png -------------------------------------------------------------------------------- /matvec/media/sgemv-computation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Maharshi-Pandya/cudacodes/HEAD/matvec/media/sgemv-computation.png -------------------------------------------------------------------------------- /softmax/benchmarks/benchmark_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Maharshi-Pandya/cudacodes/HEAD/softmax/benchmarks/benchmark_1.png -------------------------------------------------------------------------------- /softmax/media/naive_thread_mapping.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Maharshi-Pandya/cudacodes/HEAD/softmax/media/naive_thread_mapping.png -------------------------------------------------------------------------------- /softmax/media/threads_collab_load.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Maharshi-Pandya/cudacodes/HEAD/softmax/media/threads_collab_load.png -------------------------------------------------------------------------------- /attention/execution_speed_comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Maharshi-Pandya/cudacodes/HEAD/attention/execution_speed_comparison.png -------------------------------------------------------------------------------- 
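The benchmark text files under matvec/benchmarks/ and softmax/benchmarks/ are plain whitespace-separated columns: the first value on each row is the problem size, and in the kernel_*_vs_cublas files the next two values are the custom kernel and cuBLAS measurements. A minimal sketch of loading one of these files instead of hard-coding the numbers in plot.py; `load_benchmark` is a hypothetical helper and not a file in the repository:

```python
# Hypothetical helper (not part of the repo): reads a kernel_*_vs_cublas file,
# where each row is "<size> <custom kernel value> <cuBLAS value>".
import matplotlib.pyplot as plt

def load_benchmark(path):
    sizes, custom, cublas = [], [], []
    with open(path) as f:
        for line in f:
            parts = line.split()
            if len(parts) < 3:
                continue  # skip rows that don't carry both measurements
            sizes.append(int(parts[0]))
            custom.append(float(parts[1]))
            cublas.append(float(parts[2]))
    return sizes, custom, cublas

if __name__ == "__main__":
    sizes, custom, cublas = load_benchmark(
        "matvec/benchmarks/kernel_custom_vs_cublas-gflops.txt")
    plt.plot(sizes, custom, marker="o", label="Custom CUDA kernel")
    plt.plot(sizes, cublas, marker="s", label="cuBLAS")
    plt.xlabel("Matrix size (M)")
    plt.ylabel("GFLOPS")
    plt.legend()
    plt.show()
```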
/attention/execution_speed_comparison_tensorcores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Maharshi-Pandya/cudacodes/HEAD/attention/execution_speed_comparison_tensorcores.png -------------------------------------------------------------------------------- /softmax/benchmarks/exec_time_ms_cuda.txt: -------------------------------------------------------------------------------- 1 | 1024 0.829120 2 | 1024 0.138336 3 | 1024 0.231264 4 | 1024 0.520096 5 | 1024 1.112768 6 | 1024 2.646880 7 | 1024 5.287200 8 | -------------------------------------------------------------------------------- /softmax/benchmarks/exec_time_ms_torch.txt: -------------------------------------------------------------------------------- 1 | 1024 0.293731689453125 2 | 1024 0.3722667694091797 3 | 1024 0.7652759552001953 4 | 1024 1.2854576110839844 5 | 1024 2.215290069580078 6 | 1024 10.414981842041016 7 | 1024 9.735441207885742 8 | -------------------------------------------------------------------------------- /matvec/benchmarks/kernel_custom_vs_cublas-gflops.txt: -------------------------------------------------------------------------------- 1 | 128 1.306123 1.230769 2 | 256 11.283747 17.102297 3 | 512 25.225557 34.565399 4 | 1024 45.637882 46.022472 5 | 2048 45.893555 49.804123 6 | 4096 50.646061 42.666664 7 | 8192 50.786800 50.383244 8 | -------------------------------------------------------------------------------- /matvec/benchmarks/kernel_custom_vs_cublas-memory.txt: -------------------------------------------------------------------------------- 1 | 128 2.357000 2.221019 2 | 256 20.244474 30.683693 3 | 512 45.126034 61.834099 4 | 1024 81.522423 82.209412 5 | 2048 81.919182 88.899467 6 | 4096 90.369240 76.131371 7 | 8192 90.603767 89.883820 8 | -------------------------------------------------------------------------------- /flash-attention/fa1/build.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | torch::Tensor fa_forward(torch::Tensor q, torch::Tensor k, torch::Tensor v); 4 | 5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 6 | m.def("fa_forward", torch::wrap_pybind_function(fa_forward), "fa_forward"); 7 | } -------------------------------------------------------------------------------- /flash-attention/fa2/build.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | torch::Tensor fa2_forward(torch::Tensor q, torch::Tensor k, torch::Tensor v); 4 | 5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 6 | m.def("fa2_forward", torch::wrap_pybind_function(fa2_forward), "fa2_forward"); 7 | } -------------------------------------------------------------------------------- /softmax/include/shfl_3.cuh: -------------------------------------------------------------------------------- 1 | #ifndef SHFL_SOFTMAX 2 | #define SHFL_SOFTMAX 3 | 4 | __global__ void softmax_kernel_3(float* __restrict__ matd, float* __restrict__ resd, int M, int N); 5 | 6 | void run_kernel_3(float* __restrict__ matd, float* __restrict__ resd, int M, int N); 7 | 8 | #endif // SHFL_SOFTMAX -------------------------------------------------------------------------------- /matvec/include/cublas_0.cuh: -------------------------------------------------------------------------------- 1 | #ifndef CUBLAS_SGEMV 2 | #define CUBLAS_SGEMV 3 | 4 | float run_kernel_cublas_sgemv(float* __restrict__ matd, float* __restrict__ vecd, float* __restrict__ resd, int M, int N, float 
THEORETICAL_MAX_GFLOPS, float THEORETICAL_MAX_MEMORY_BANDWIDTH); 5 | 6 | #endif // CUBLAS_SGEMV 7 | -------------------------------------------------------------------------------- /softmax/include/naive_0.cuh: -------------------------------------------------------------------------------- 1 | #ifndef NAIVE_SOFTMAX 2 | #define NAIVE_SOFTMAX 3 | 4 | __global__ void softmax_kernel_0(float* __restrict__ matd, float* __restrict__ resd, int M, int N); 5 | 6 | void run_kernel_0(float* __restrict__ matd, float* __restrict__ resd, int M, int N); 7 | 8 | #endif // NAIVE_SOFTMAX -------------------------------------------------------------------------------- /softmax/include/online_1.cuh: -------------------------------------------------------------------------------- 1 | #ifndef ONLINE_SOFTMAX 2 | #define ONLINE_SOFTMAX 3 | 4 | __global__ void softmax_kernel_1(float* __restrict__ matd, float* __restrict__ resd, int M, int N); 5 | 6 | void run_kernel_1(float* __restrict__ matd, float* __restrict__ resd, int M, int N); 7 | 8 | #endif // ONLINE_SOFTMAX -------------------------------------------------------------------------------- /softmax/include/sharedmem_2.cuh: -------------------------------------------------------------------------------- 1 | #ifndef SHAREDMEM_SOFTMAX 2 | #define SHAREDMEM_SOFTMAX 3 | 4 | __global__ void softmax_kernel_2(float* __restrict__ matd, float* __restrict__ resd, int M, int N); 5 | 6 | void run_kernel_2(float* __restrict__ matd, float* __restrict__ resd, int M, int N); 7 | 8 | #endif // SHAREDMEM_SOFTMAX -------------------------------------------------------------------------------- /softmax/include/vectorized_4.cuh: -------------------------------------------------------------------------------- 1 | #ifndef VECTORIZED_SOFTMAX 2 | #define VECTORIZED_SOFTMAX 3 | 4 | __global__ void softmax_kernel_4(float* __restrict__ matd, float* __restrict__ resd, int M, int N); 5 | 6 | float run_kernel_4(float* __restrict__ matd, float* __restrict__ resd, int M, int N); 7 | 8 | #endif // VECTORIZED_SOFTMAX -------------------------------------------------------------------------------- /softmax/include/blocktiling_5.cuh: -------------------------------------------------------------------------------- 1 | #ifndef BLOCKTILING_SOFTMAX 2 | #define BLOCKTILING_SOFTMAX 3 | 4 | __global__ void softmax_kernel_5(float* __restrict__ matd, float* __restrict__ resd, int M, int N); 5 | 6 | void run_kernel_5(float* __restrict__ matd, float* __restrict__ resd, int M, int N); 7 | 8 | #endif // BLOCKTILING_SOFTMAX -------------------------------------------------------------------------------- /attention/build.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | torch::Tensor attention_forward(uint64_t handle, torch::Tensor Q, torch::Tensor K, torch::Tensor V); 5 | 6 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 7 | m.def("attention_forward", torch::wrap_pybind_function(attention_forward), "attention_forward"); 8 | } 9 | -------------------------------------------------------------------------------- /matvec/include/naive_1.cuh: -------------------------------------------------------------------------------- 1 | #ifndef NAIVE_SGEMV 2 | #define NAIVE_SGEMV 3 | 4 | __global__ void naive_sgemv_kernel(float* __restrict__ matd, float* __restrict__ vecd, float* __restrict__ resd, int M, int N); 5 | 6 | float run_kernel_naive_sgemv(float* __restrict__ matd, float* __restrict__ vecd, float* __restrict__ resd, int M, int N, float 
THEORETICAL_MAX_GFLOPS, float THEORETICAL_MAX_MEMORY_BANDWIDTH); 7 | 8 | #endif // NAIVE_SGEMV 9 | -------------------------------------------------------------------------------- /matvec/include/vectorized_4.cuh: -------------------------------------------------------------------------------- 1 | #ifndef VECTORIZED_SGEMV 2 | #define VECTORIZED_SGEMV 3 | 4 | __global__ void vectorized_sgemv_kernel(float* __restrict__ matd, float* __restrict__ vecd, float* __restrict__ resd, int M, int N); 5 | 6 | float run_kernel_vectorized_sgmev(float* __restrict__ matd, float* __restrict__ vecd, float* __restrict__ resd, int M, int N, float THEORETICAL_MAX_GFLOPS, float THEORETICAL_MAX_MEMORY_BANDWIDTH); 7 | 8 | #endif // VECTORIZED_SGEMV 9 | -------------------------------------------------------------------------------- /matvec/include/coalesced_warp_2.cuh: -------------------------------------------------------------------------------- 1 | #ifndef COALWARP_SGEMV 2 | #define COALWARP_SGEMV 3 | 4 | __global__ void coalesced_warp_sgmev_kernel(float* __restrict__ matd, float* __restrict__ vecd, float* __restrict__ resd, int M, int N); 5 | 6 | float run_kernel_coalesced_warp_sgmev(float* __restrict__ matd, float* __restrict__ vecd, float* __restrict__ resd, int M, int N, float THEORETICAL_MAX_GFLOPS, float THEORETICAL_MAX_MEMORY_BANDWIDTH); 7 | 8 | #endif // COALWARP_SGEMV 9 | -------------------------------------------------------------------------------- /matvec/include/coalesced_warpblock_3.cuh: -------------------------------------------------------------------------------- 1 | #ifndef COALWARPBLOCK_SGEMV 2 | #define COALWARPBLOCK_SGEMV 3 | 4 | __global__ void coalesced_warpblock_sgmev_kernel(float* __restrict__ matd, float* __restrict__ vecd, float* __restrict__ resd, int M, int N); 5 | 6 | float run_kernel_coalesced_warpblock_sgmev(float* __restrict__ matd, float* __restrict__ vecd, float* __restrict__ resd, int M, int N, float THEORETICAL_MAX_GFLOPS, float THEORETICAL_MAX_MEMORY_BANDWIDTH); 7 | 8 | #endif // COALWARPBLOCK_SGEMV 9 | -------------------------------------------------------------------------------- /matmul/torch_benchmark.py: -------------------------------------------------------------------------------- 1 | # simple benchmark 2 | 3 | import torch 4 | import time 5 | 6 | matrix_size = 4096 # (1024 x 1024) 7 | max_val = 10 8 | min_val = -10 9 | 10 | # Initialize random tensors with a normal distribution and clamp values 11 | A = torch.randn((matrix_size, matrix_size)).clamp(min=min_val, max=max_val) 12 | B = torch.randn((matrix_size, matrix_size)).clamp(min=min_val, max=max_val) 13 | 14 | A, B = A.cuda(), B.cuda() 15 | 16 | print(f">> Benchmarking torch.matmul for {matrix_size} x {matrix_size} matrices...") 17 | start_time = time.time() 18 | C = torch.matmul(A, B) 19 | end_time = time.time() 20 | 21 | elapsed_time_ms = (end_time - start_time) * 1000 22 | 23 | print(f">> Matrix multiplication completed in {elapsed_time_ms:.3f} ms.") -------------------------------------------------------------------------------- /attention/pycublas.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | 3 | CUBLAS_STATUS_SUCCESS = 0 4 | 5 | # load the cublas shared library 6 | cublas = ctypes.cdll.LoadLibrary("libcublas.so") 7 | 8 | cublas.cublasCreate_v2.restype = ctypes.c_int 9 | cublas.cublasCreate_v2.argtypes = [ctypes.POINTER(ctypes.c_void_p)] 10 | 11 | cublas.cublasDestroy_v2.restype = ctypes.c_int 12 | cublas.cublasDestroy_v2.argtypes = 
[ctypes.c_void_p] 13 | 14 | 15 | def cublas_create_handle(): 16 | handle = ctypes.c_void_p() 17 | status = cublas.cublasCreate_v2(ctypes.byref(handle)) 18 | if status != CUBLAS_STATUS_SUCCESS: 19 | raise RuntimeError(f"cublasCreate failed with status {status}") 20 | 21 | return handle 22 | 23 | def cublas_destroy_handle(handle): 24 | status = cublas.cublasDestroy_v2(handle) 25 | if status != CUBLAS_STATUS_SUCCESS: 26 | raise RuntimeError(f"cublasDestroy failed with status {status}") 27 | -------------------------------------------------------------------------------- /softmax/benchmarks/plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | plt.style.use('fivethirtyeight') 4 | 5 | matrix_sizes = [2048, 4096, 8192, 16384, 32768, 65536, 131072] 6 | custom_cuda_time = [0.739136, 0.732832, 1.107200, 2.074880, 4.109728, 8.162368, 16.282017] 7 | torch_time = [3.1140804290771484, 3.124523162841797, 3.125, 3.1251907348632812, 8 | 6.25004768371582, 12.506628036499023, 25.00143051147461] 9 | 10 | plt.figure(figsize=(12, 7)) 11 | plt.plot(matrix_sizes, custom_cuda_time, label='Custom CUDA softmax', 12 | marker='o', linestyle='-', linewidth=2, color='#2ecc71') 13 | plt.plot(matrix_sizes, torch_time, label='PyTorch softmax', 14 | marker='s', linestyle='--', linewidth=2, color='#e74c3c') 15 | 16 | plt.title('Softmax: Execution Time Comparison', fontsize=16, pad=20) 17 | plt.xlabel('Matrix Column Size (N)', fontsize=14) 18 | plt.ylabel('Execution Time (ms)', fontsize=14) 19 | plt.xticks(matrix_sizes, rotation=90, fontsize=10) 20 | plt.yticks(fontsize=12) 21 | 22 | plt.grid(True, alpha=0.9) 23 | plt.legend(fontsize=12, framealpha=0.8) 24 | plt.tight_layout() 25 | plt.show() -------------------------------------------------------------------------------- /matvec/Makefile: -------------------------------------------------------------------------------- 1 | # Compiler and flags 2 | NVCC = nvcc 3 | CFLAGS = -std=c++17 -allow-unsupported-compiler -lcublas 4 | 5 | # Directories 6 | INCLUDE_DIR = include 7 | KERNELS_DIR = kernels 8 | BUILD_DIR = build 9 | 10 | # Automatically detect all .cu files 11 | KERNEL_SOURCES = $(wildcard $(KERNELS_DIR)/*.cu) 12 | ROOT_SOURCES = src.cu 13 | CUDA_SOURCES = $(KERNEL_SOURCES) $(ROOT_SOURCES) 14 | 15 | # Generate object file names 16 | CUDA_OBJECTS = $(patsubst $(KERNELS_DIR)/%.cu,$(BUILD_DIR)/%.obj,$(KERNEL_SOURCES)) $(BUILD_DIR)/src.obj 17 | 18 | # Output executable 19 | OUTPUT = matvec.exe 20 | 21 | # Rules 22 | all: $(OUTPUT) 23 | 24 | $(OUTPUT): $(CUDA_OBJECTS) 25 | $(NVCC) $(CFLAGS) -o $@ $^ 26 | 27 | # Rules for compiling kernels directory .cu files 28 | $(BUILD_DIR)/%.obj: $(KERNELS_DIR)/%.cu | $(BUILD_DIR) 29 | $(NVCC) $(CFLAGS) -I$(INCLUDE_DIR) -c $< -o $@ 30 | 31 | # Rule for compiling root directory .cu files 32 | $(BUILD_DIR)/src.obj: src.cu | $(BUILD_DIR) 33 | $(NVCC) $(CFLAGS) -I$(INCLUDE_DIR) -c $< -o $@ 34 | 35 | # Create the build directory 36 | $(BUILD_DIR): 37 | mkdir $(BUILD_DIR) 38 | 39 | # Clean rule 40 | clean: 41 | rm $(BUILD_DIR)/*.obj *.exe *.exp *.lib 42 | 43 | .PHONY: all clean 44 | -------------------------------------------------------------------------------- /matvec/kernels/cublas_0.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "utils.cuh" 7 | 8 | /* 9 | CuBLAS matrix vector multiplication for the baseline scores. 
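Note on layout: cuBLAS assumes column-major storage, while matd holds the (M, N)
matrix row-major. Viewed column-major, that same memory is an (N, M) matrix, so the
call below passes dimensions (N, M) with leading dimension N and CUBLAS_OP_T, which
makes cublasSgemv compute the intended y = A * x.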
10 | We simply run the Sgemv function that cuBLAS provides. 11 | */ 12 | float run_kernel_cublas_sgemv(float* __restrict__ matd, float* __restrict__ vecd, float* __restrict__ resd, int M, int N, float THEORETICAL_MAX_GFLOPS, float THEORETICAL_MAX_MEMORY_BANDWIDTH) { 13 | cudaEvent_t start, stop; 14 | CUDA_CHECK(cudaEventCreate(&start)); 15 | CUDA_CHECK(cudaEventCreate(&stop)); 16 | float ms = 0.0f; 17 | 18 | // create cublas handle 19 | cublasHandle_t handle; 20 | cublasCreate(&handle); 21 | 22 | // Sgemv: y = (alpha * A * x) + (beta * y) 23 | float alpha = 1.0f, beta = 0.0f; 24 | cudaEventRecord(start); 25 | cublasSgemv(handle, CUBLAS_OP_T, N, M, &alpha, matd, N, vecd, 1, &beta, resd, 1); 26 | cudaEventRecord(stop); 27 | cudaEventSynchronize(stop); 28 | cudaEventElapsedTime(&ms, start, stop); 29 | printf("------- cuBLAS sgmev kernel ---------\n"); 30 | print_kernel_essentials(M, N, ms, THEORETICAL_MAX_GFLOPS, THEORETICAL_MAX_MEMORY_BANDWIDTH); 31 | printf("---------------------------\n"); 32 | 33 | cublasDestroy(handle); 34 | CUDA_CHECK(cudaEventDestroy(start)); 35 | CUDA_CHECK(cudaEventDestroy(stop)); 36 | 37 | return ms; 38 | } 39 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CUDA Code Explorations 2 | 3 | This repository showcases my journey learning and experimenting with CUDA, primarily using C/C++. Inside, you'll find code examples, insights, and optimizations related to various CUDA concepts. Each directory focuses on a specific topic, with extensively commented code to guide understanding. 4 | 5 | ## What's Inside? 6 | 7 | This repo contains: 8 | 9 | * **Practical Examples:** Hands-on CUDA code for common operations. 10 | * **Optimizations:** Explores different optimization techniques to improve performance. 11 | * **Well-Commented Code:** Every line is explained, making it easy to follow along. 12 | * **Learning Resource:** A place to learn from and improve your own CUDA skills. 13 | 14 | ## Quick Navigation 15 | 16 | * **`softmax/`**: Explores various softmax implementations, from naive to optimized. 17 | * **`matmul/`**: Demonstrates matrix multiplication with tiling strategies. 18 | * **`matvec/`**: Iteratively optimizes matrix-vector multiplication (SGEMV) to achieve cuBLAS-like performance. 19 | * **`query-device/`**: A simple tool to get device information of GPU. 20 | * **`flash-attention/`**: Exploration of flash attention algorithms (Work in progress). 21 | 22 | ## Getting Started 23 | 24 | 1. Clone the repository: 25 | ``` 26 | git clone https://github.com/Maharshi-Pandya/cudacodes.git 27 | ``` 28 | 2. Navigate to the directory of interest (e.g., `cd matmul`) 29 | 3. Check the `README.md` inside the directory to learn how to compile and run the examples. 
30 | -------------------------------------------------------------------------------- /softmax/Makefile: -------------------------------------------------------------------------------- 1 | # Detect OS 2 | ifeq ($(OS),Windows_NT) 3 | EXEC_EXT = .exe 4 | OBJ_EXT = .obj 5 | RM_CMD = rm -f 6 | MKDIR_CMD = mkdir 7 | PATH_SEP = \\ 8 | else 9 | EXEC_EXT = 10 | OBJ_EXT = .o 11 | RM_CMD = rm -f 12 | MKDIR_CMD = mkdir -p 13 | PATH_SEP = / 14 | endif 15 | 16 | # Compiler and flags 17 | NVCC = nvcc 18 | CFLAGS = -std=c++17 --extended-lambda 19 | 20 | # Directories 21 | INCLUDE_DIR = include 22 | KERNELS_DIR = kernels 23 | BUILD_DIR = build 24 | 25 | # Automatically detect all .cu files 26 | KERNEL_SOURCES = $(wildcard $(KERNELS_DIR)/*.cu) 27 | ROOT_SOURCES = bench.cu 28 | CUDA_SOURCES = $(KERNEL_SOURCES) $(ROOT_SOURCES) 29 | 30 | # Generate object file names 31 | CUDA_OBJECTS = $(patsubst $(KERNELS_DIR)/%.cu,$(BUILD_DIR)/%$(OBJ_EXT),$(KERNEL_SOURCES)) $(BUILD_DIR)/bench$(OBJ_EXT) 32 | 33 | # Output executable 34 | OUTPUT = softmax$(EXEC_EXT) 35 | 36 | # Rules 37 | all: $(OUTPUT) 38 | 39 | $(OUTPUT): $(CUDA_OBJECTS) 40 | $(NVCC) $(CFLAGS) -o $@ $^ 41 | 42 | # Rules for compiling kernels directory .cu files 43 | $(BUILD_DIR)/%$(OBJ_EXT): $(KERNELS_DIR)/%.cu | $(BUILD_DIR) 44 | $(NVCC) $(CFLAGS) -I$(INCLUDE_DIR) -c $< -o $@ 45 | 46 | # Rule for compiling root directory .cu files 47 | $(BUILD_DIR)/bench$(OBJ_EXT): bench.cu | $(BUILD_DIR) 48 | $(NVCC) $(CFLAGS) -I$(INCLUDE_DIR) -c $< -o $@ 49 | 50 | # Create the build directory 51 | $(BUILD_DIR): 52 | $(MKDIR_CMD) $(BUILD_DIR) 53 | 54 | # Clean rule 55 | clean: 56 | ifeq ($(OS),Windows_NT) 57 | if exist $(BUILD_DIR) $(RM_CMD) $(BUILD_DIR)\*$(OBJ_EXT) 58 | if exist $(OUTPUT) $(RM_CMD) $(OUTPUT) 59 | if exist *.exp $(RM_CMD) *.exp 60 | if exist *.lib $(RM_CMD) *.lib 61 | else 62 | $(RM_CMD) $(BUILD_DIR)/*$(OBJ_EXT) $(OUTPUT) 63 | endif 64 | 65 | .PHONY: all clean -------------------------------------------------------------------------------- /matvec/kernels/naive_1.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "utils.cuh" 7 | 8 | /* 9 | Naive Sgemv kernel 10 | 11 | - Each thread calculates one element of the output vector 12 | - The row index is calculated using block index and thread index 13 | - Uses linearized indexing 14 | - Memory accesses are not coalesced 15 | */ 16 | __global__ void naive_sgemv_kernel(float* __restrict__ matd, float* __restrict__ vecd, float* __restrict__ resd, int M, int N) { 17 | int row = blockDim.x * blockIdx.x + threadIdx.x; 18 | 19 | if (row < M) { 20 | float sum = 0.0f; 21 | for (int col = 0; col < N; col++) { 22 | sum += matd[row * N + col] * vecd[col]; 23 | } 24 | resd[row] = sum; 25 | } 26 | } 27 | 28 | /* 29 | Runs the naive Sgemv kernel. 
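Launch configuration: one thread per output row, in blocks of 1024 threads.
The loads are uncoalesced because at a given loop iteration `col`, thread `row`
reads matd[row * N + col]; consecutive threads in a warp therefore touch addresses
N floats apart instead of adjacent ones, wasting most of each memory transaction.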
30 | */ 31 | float run_kernel_naive_sgemv(float* __restrict__ matd, float* __restrict__ vecd, float* __restrict__ resd, int M, int N, float THEORETICAL_MAX_GFLOPS, float THEORETICAL_MAX_MEMORY_BANDWIDTH) { 32 | dim3 block_size(1024); 33 | dim3 grid_size(CEIL_DIV(M, block_size.x)); 34 | 35 | cudaEvent_t start, stop; 36 | CUDA_CHECK(cudaEventCreate(&start)); 37 | CUDA_CHECK(cudaEventCreate(&stop)); 38 | float ms = 0.f; 39 | 40 | CUDA_CHECK(cudaEventRecord(start)); 41 | naive_sgemv_kernel<<>>(matd, vecd, resd, M, N); 42 | CUDA_CHECK(cudaEventRecord(stop)); 43 | CUDA_CHECK(cudaEventSynchronize(stop)); 44 | CUDA_CHECK(cudaEventElapsedTime(&ms, start, stop)); 45 | printf("------- Naive sgmev kernel ---------\n"); 46 | print_kernel_essentials(M, N, ms, THEORETICAL_MAX_GFLOPS, THEORETICAL_MAX_MEMORY_BANDWIDTH); 47 | printf("---------------------------\n"); 48 | 49 | CUDA_CHECK(cudaEventDestroy(start)); 50 | CUDA_CHECK(cudaEventDestroy(stop)); 51 | 52 | return ms; 53 | } 54 | -------------------------------------------------------------------------------- /query-device/main.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | int main() { 6 | int dev_count; 7 | cudaDeviceProp prop; 8 | 9 | cudaGetDeviceCount(&dev_count); 10 | cudaGetDeviceProperties(&prop, 0); 11 | 12 | printf(">> CUDA enabled devices in the system: %d\n", dev_count); 13 | printf(">> Compute capability: %d.%d\n", prop.major, prop.minor); 14 | 15 | printf(">> Max grid size: (%d, %d, %d)\n", prop.maxGridSize[0], prop.maxGridSize[1], prop.maxGridSize[2]); 16 | printf(">> Max block size: %d\n", prop.maxThreadsPerBlock); 17 | 18 | printf(">> Number of SMs: %d\n", prop.multiProcessorCount); 19 | printf(">> Clock rate of the SMs (in kHz): %d\n", prop.clockRate); 20 | 21 | printf(">> Max threads dimension: (%d, %d, %d)\n", prop.maxThreadsDim[0], prop.maxThreadsDim[1], prop.maxThreadsDim[2]); 22 | printf(">> Max threads per SM: %d\n", prop.maxThreadsPerMultiProcessor); 23 | 24 | printf(">> Registers available per block: %d\n", prop.regsPerBlock); 25 | printf(">> Registers available per SM: %d\n", prop.regsPerMultiprocessor); 26 | 27 | printf(">> Warp size (threads per warp): %d\n", prop.warpSize); 28 | printf(">> Shared memory size per block: %zd bytes\n", prop.sharedMemPerBlock); 29 | printf(">> Shared memory size per SM: %zd bytes\n", prop.sharedMemPerMultiprocessor); 30 | 31 | printf(">> L2 cache size: %d bytes\n", prop.l2CacheSize); 32 | 33 | printf(">> Memory bus width: %d bits\n", prop.memoryBusWidth); 34 | printf(">> Memory clock rate: %d KHz\n", prop.memoryClockRate); 35 | 36 | int cudaCores = prop.multiProcessorCount * 128; 37 | float clockGHz = prop.clockRate / 1e6; 38 | float gflops = cudaCores * clockGHz * 2; 39 | 40 | printf(">> Theoretical Max GFLOPS: %.2f\n", gflops); 41 | 42 | float memoryBandwidth = (2 * prop.memoryClockRate * prop.memoryBusWidth) / (8.0 * 1e6); 43 | printf(">> Maximum Memory Bandwidth: %.2f GB/s\n", memoryBandwidth); 44 | } -------------------------------------------------------------------------------- /matvec/kernels/utils.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "utils.cuh" 4 | 5 | /* 6 | Helper function to generate a clamped random number sampled from a 7 | normal distribution with mean 0 and std 1 8 | */ 9 | float random_normal_clamped(float min, float max) { 10 | float u1 = (float)rand() / RAND_MAX; 11 | float u2 = (float)rand() / RAND_MAX; 12 | 
float num = sqrtf(-2.0f * logf(u1)) * cosf(2.0f * M_PI * u2); 13 | if (num < min) 14 | return min; 15 | if (num > max) 16 | return max; 17 | return num; 18 | } 19 | 20 | float compute_gflops(int M, int N, float ms) { 21 | return (2 * M * N) / (ms * 1e6); 22 | } 23 | 24 | float compute_peak_gflops(float gflops, float THEORETICAL_MAX_GFLOPS) { 25 | cudaDeviceProp prop; 26 | cudaGetDeviceProperties(&prop, 0); 27 | 28 | return (gflops / THEORETICAL_MAX_GFLOPS) * 100; 29 | } 30 | 31 | float compute_peak_memory_bandwidth(int M, int N, float ms, float THEORETICAL_MAX_MEMORY_BANDWIDTH) { 32 | cudaDeviceProp prop; 33 | cudaGetDeviceProperties(&prop, 0); 34 | 35 | size_t totalFloats = (size_t)(M * N + N + M); 36 | float totalBytes = (float)totalFloats * sizeof(float); 37 | 38 | float secs = ms / 1000.0f; 39 | float gbPerSec = (totalBytes / secs) / 1.0e9; 40 | 41 | return (gbPerSec / THEORETICAL_MAX_MEMORY_BANDWIDTH) * 100; 42 | } 43 | 44 | void print_kernel_essentials(int M, int N, float ms, float THEORETICAL_MAX_GFLOPS, float THEORETICAL_MAX_MEMORY_BANDWIDTH) { 45 | float gflops = compute_gflops(M, N, ms); 46 | printf(">> Execution time: %f ms\n", ms); 47 | printf(">> Achieved (GFLOPS): %f\n", gflops); 48 | printf(">> Theoretical max (GFLOPS): %f\n", THEORETICAL_MAX_GFLOPS); 49 | printf(">> Maximum memory bandwidth: %f GB/s\n", THEORETICAL_MAX_MEMORY_BANDWIDTH); 50 | printf(">> Achieves %f %% of peak GFLOPS\n", compute_peak_gflops(gflops, THEORETICAL_MAX_GFLOPS)); 51 | printf(">> Achieves %f %% of peak Memory Bandwidth\n", compute_peak_memory_bandwidth(M, N, ms, THEORETICAL_MAX_MEMORY_BANDWIDTH)); 52 | } 53 | -------------------------------------------------------------------------------- /matvec/benchmarks/plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | 4 | # Data 5 | matrix_sizes = [128, 256, 512, 1024, 2048, 4096, 8192] 6 | 7 | custom_cuda_gflops = [1.018906, 19.980488, 26.554296, 40.169170, 47.766766, 49.741516, 50.872116] 8 | cublas_gflops = [1.280000, 17.102297, 32.347481, 46.022472, 49.813587, 44.747833, 50.334869] 9 | 10 | custom_cuda_bandwidth = [1.846255, 35.926109, 47.561153, 71.806648, 85.304932, 88.788231, 90.784180] 11 | cublas_bandwidth = [2.319358, 30.750952, 57.937275, 82.270042, 88.960274, 79.874550, 89.825439] 12 | 13 | # Create subplots 14 | fig, axes = plt.subplots(1, 2, figsize=(12, 6)) 15 | 16 | # Plot GFLOPS 17 | axes[0].plot(matrix_sizes, custom_cuda_gflops, label='Custom CUDA Kernel', marker='o', linestyle='-', linewidth=2) 18 | axes[0].plot(matrix_sizes, cublas_gflops, label='cuBLAS', marker='s', linestyle='--', linewidth=2) 19 | axes[0].set_title('SGEMV: GFLOPS Comparison', fontsize=16) 20 | axes[0].set_xlabel('Matrix Size (M)', fontsize=14) 21 | axes[0].set_ylabel('Achieved GFLOPS', fontsize=14) 22 | axes[0].set_xticks(matrix_sizes) 23 | axes[0].tick_params(axis='x', labelrotation=90, labelsize=10) 24 | axes[0].tick_params(axis='y', labelsize=12) 25 | axes[0].grid(True, which='both', linestyle='--', linewidth=0.5) 26 | axes[0].legend(fontsize=12) 27 | 28 | # Plot Bandwidth 29 | axes[1].plot(matrix_sizes, custom_cuda_bandwidth, label='Custom CUDA Kernel', marker='o', linestyle='-', linewidth=2) 30 | axes[1].plot(matrix_sizes, cublas_bandwidth, label='cuBLAS', marker='s', linestyle='--', linewidth=2) 31 | axes[1].set_title('SGEMV: Memory Bandwidth Comparison', fontsize=16) 32 | axes[1].set_xlabel('Matrix Size (M)', fontsize=14) 33 | axes[1].set_ylabel('Achieved Memory 
Bandwidth (%)', fontsize=14) 34 | axes[1].set_xticks(matrix_sizes) 35 | axes[1].tick_params(axis='x', labelrotation=90, labelsize=10) 36 | axes[1].tick_params(axis='y', labelsize=12) 37 | axes[1].grid(True, which='both', linestyle='--', linewidth=0.5) 38 | axes[1].legend(fontsize=12) 39 | 40 | # Adjust layout and show plot 41 | plt.tight_layout() 42 | plt.show() -------------------------------------------------------------------------------- /softmax/torch_benchmark.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | import numpy as np 4 | import os 5 | 6 | 7 | def random_normal_clamped(min_val, max_val, size): 8 | return np.clip(np.random.normal(0, 1, size), min_val, max_val) 9 | 10 | 11 | def benchmark_kernel_for_sizes(min_n, max_n): 12 | os.makedirs('benchmarks', exist_ok=True) 13 | n_iters = 5 14 | 15 | with open('benchmarks/exec_time_ms_torch.txt', 'w') as exec_time_file: 16 | N = min_n 17 | while N < max_n: 18 | M = 1024 # matrix size (M, N) 19 | print(f'------------ Running PyTorch softmax benchmark for MxN = ({M}, {N}) -------------') 20 | 21 | # Generate random data 22 | mat = torch.tensor(random_normal_clamped(-10, 10, (M, N)), dtype=torch.float32, device="cuda") 23 | torch.cuda.synchronize() 24 | 25 | # Warmup 26 | for _ in range(5): 27 | _ = torch.nn.functional.softmax(mat, dim=-1) 28 | torch.cuda.synchronize() 29 | 30 | total_time = 0 31 | # Run softmax kernel 32 | for i in range(n_iters): 33 | # Measure time 34 | torch.cuda.synchronize() # Ensure all CUDA operations are finished 35 | start = time.time() 36 | _ = torch.nn.functional.softmax(mat, dim=-1) 37 | torch.cuda.synchronize() # Synchronize again 38 | end = time.time() 39 | 40 | total_time += (end - start) * 1000 41 | 42 | exec_time = (total_time/n_iters) 43 | print(f'>> Kernel execution time: {exec_time:.3f} ms') 44 | 45 | # Write execution time to file 46 | exec_time_file.write(f'{M} {exec_time}\n') 47 | 48 | # Clear GPU memory 49 | del mat 50 | torch.cuda.empty_cache() 51 | torch.cuda.ipc_collect() 52 | torch.cuda.synchronize() 53 | 54 | N *= 2 55 | time.sleep(1) 56 | 57 | 58 | if __name__ == '__main__': 59 | benchmark_kernel_for_sizes(2048, 262144) 60 | -------------------------------------------------------------------------------- /softmax/kernels/naive_0.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "cuda_utils.cuh" 6 | 7 | /* 8 | This kernel implements a naive softmax operation on a matrix of size (M, N). 9 | The softmax operation is performed on the last dimension of the matrix. 10 | 11 | How this works: 12 | One thread processes one entire row, and thus this kernel will be the slowest 13 | since we aren't exploiting parallelism capabilities of GPUs that much. 14 | We are only parallelizing over the rows. 
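For reference, the row-wise softmax computed here (with the row maximum subtracted
so that expf never overflows) is

    softmax(x_i) = expf(x_i - m) / sum_j expf(x_j - m),  where m = max_j x_j

which is why the kernel needs one pass for m, one for the denominator L,
and one for the normalized outputs.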
15 | */ 16 | __global__ void softmax_kernel_0(float* __restrict__ matd, float* __restrict__ resd, int M, int N) { 17 | int row = blockDim.x * blockIdx.x + threadIdx.x; 18 | 19 | if (row < M) { 20 | // max 21 | float m = -1 * INFINITY; 22 | // norm factor 23 | float L = 0.0f; 24 | 25 | // 3 passes (not optimal) 26 | for (int col = 0; col < N; col++) { 27 | int i = row * N + col; 28 | m = max(m, matd[i]); 29 | } 30 | for (int col = 0; col < N; col++) { 31 | int i = row * N + col; 32 | L += expf(matd[i] - m); 33 | } 34 | for (int col = 0; col < N; col++) { 35 | int i = row * N + col; 36 | resd[i] = expf(matd[i] - m) / L; 37 | } 38 | } 39 | } 40 | 41 | /* 42 | Runs the naive softmax kernel: `id = 0` 43 | */ 44 | void run_kernel_0(float* __restrict__ matd, float* __restrict__ resd, int M, int N) { 45 | // grid size and block size for this kernel 46 | // change as necessary 47 | dim3 block_size(1024); 48 | dim3 grid_size(CEIL_DIV(M, block_size.x)); 49 | 50 | cudaEvent_t start, stop; 51 | CUDA_CHECK(cudaEventCreate(&start)); 52 | CUDA_CHECK(cudaEventCreate(&stop)); 53 | float ms = 0.f; 54 | 55 | CUDA_CHECK(cudaEventRecord(start)); 56 | softmax_kernel_0<<>>(matd, resd, M, N); 57 | CUDA_CHECK(cudaEventRecord(stop)); 58 | CUDA_CHECK(cudaEventSynchronize(stop)); 59 | CUDA_CHECK(cudaEventElapsedTime(&ms, start, stop)); 60 | printf(">> Kernel execution time: %f ms\n", ms); 61 | 62 | CUDA_CHECK(cudaEventDestroy(start)); 63 | CUDA_CHECK(cudaEventDestroy(stop)); 64 | } -------------------------------------------------------------------------------- /softmax/kernels/online_1.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "cuda_utils.cuh" 6 | 7 | /* 8 | This kernel implements an online softmax operation on a matrix of size (M, N). 9 | The softmax operation is performed on the last dimension of the matrix. 10 | 11 | How this works: 12 | One thread processes one entire row, but instead of 3 passes we do only 2 passes. 13 | This is possible due to the property of exponentials. 14 | We are parallelizing over the rows. 
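The rescaling trick: the running sum L is always expressed relative to the current
maximum m. When a larger element `curr` appears, every term expf(x - m) already
accumulated in L must become expf(x - curr), and since
expf(x - curr) = expf(x - m) * expf(m - curr), multiplying L once by expf(m - curr)
fixes all accumulated terms before m is updated.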
15 | */ 16 | __global__ void softmax_kernel_1(float* __restrict__ matd, float* __restrict__ resd, int M, int N) { 17 | int row = blockDim.x * blockIdx.x + threadIdx.x; 18 | 19 | if (row < M) { 20 | float m = -1 * INFINITY; 21 | float L = 0.0f; 22 | 23 | // compute max and norm factor in one pass only 24 | // by exploiting the property of exponentials 25 | for (int col = 0; col < N; col++) { 26 | int i = row * N + col; 27 | float curr = matd[i]; 28 | if (curr > m) { 29 | L = L * expf(m - curr); 30 | m = curr; 31 | } 32 | L += expf(curr - m); 33 | } 34 | for (int col = 0; col < N; col++) { 35 | int i = row * N + col; 36 | resd[i] = expf(matd[i] - m) / L; 37 | } 38 | } 39 | } 40 | 41 | /* 42 | Runs the online softmax kernel: `id = 1` 43 | */ 44 | void run_kernel_1(float* __restrict__ matd, float* __restrict__ resd, int M, int N) { 45 | // grid size and block size for this kernel 46 | // change as necessary 47 | dim3 block_size(1024); 48 | dim3 grid_size(CEIL_DIV(M, block_size.x)); 49 | 50 | cudaEvent_t start, stop; 51 | CUDA_CHECK(cudaEventCreate(&start)); 52 | CUDA_CHECK(cudaEventCreate(&stop)); 53 | float ms = 0.f; 54 | 55 | CUDA_CHECK(cudaEventRecord(start)); 56 | softmax_kernel_1<<>>(matd, resd, M, N); 57 | CUDA_CHECK(cudaEventRecord(stop)); 58 | CUDA_CHECK(cudaEventSynchronize(stop)); 59 | CUDA_CHECK(cudaEventElapsedTime(&ms, start, stop)); 60 | printf(">> Kernel execution time: %f ms\n", ms); 61 | 62 | CUDA_CHECK(cudaEventDestroy(start)); 63 | CUDA_CHECK(cudaEventDestroy(stop)); 64 | } -------------------------------------------------------------------------------- /matvec/kernels/coalesced_warp_2.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "utils.cuh" 7 | 8 | /* 9 | Coalesced Warp Sgemv kernel 10 | 11 | - Each block is assigned to a row of the matrix A 12 | - Each block calculates one output element of y 13 | - The columns are accessed in coalesced manner by threads 14 | - Performs warp level sum reduction only 15 | - Block size must be equal to number of threads 16 | */ 17 | __global__ void coalesced_warp_sgmev_kernel(float* __restrict__ matd, float* __restrict__ vecd, float* __restrict__ resd, int M, int N) { 18 | assert(blockDim.x == warpSize); 19 | 20 | int bid = blockIdx.x; 21 | if (bid >= M) return; 22 | 23 | int tid = threadIdx.x; 24 | // each thread calculates its own partial output 25 | float partial_sum = 0.f; 26 | for (int col = tid; col < N; col += blockDim.x) { 27 | partial_sum += matd[bid * N + col] * vecd[col]; 28 | } 29 | 30 | // warp level sum reduction 31 | // only first thread writes the output to global memory 32 | float sum = warpReduceSum(partial_sum); 33 | if (tid == 0) { 34 | resd[bid] = sum; 35 | } 36 | } 37 | 38 | /* 39 | Runs the coalesced warp sgemv kernel. 
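Launch configuration: a grid of M blocks with exactly warpSize (32) threads each,
so a block is a single warp and one shuffle-based warpReduceSum produces the row's
dot product without any shared memory.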
40 | */ 41 | float run_kernel_coalesced_warp_sgmev(float* __restrict__ matd, float* __restrict__ vecd, float* __restrict__ resd, int M, int N, float THEORETICAL_MAX_GFLOPS, float THEORETICAL_MAX_MEMORY_BANDWIDTH) { 42 | int NUM_THREADS = 32; // = warpSize of the GPU 43 | 44 | dim3 block_size(NUM_THREADS); 45 | dim3 grid_size(M); 46 | 47 | cudaEvent_t start, stop; 48 | CUDA_CHECK(cudaEventCreate(&start)); 49 | CUDA_CHECK(cudaEventCreate(&stop)); 50 | float ms = 0.f; 51 | 52 | CUDA_CHECK(cudaEventRecord(start)); 53 | coalesced_warp_sgmev_kernel<<>>(matd, vecd, resd, M, N); 54 | CUDA_CHECK(cudaEventRecord(stop)); 55 | CUDA_CHECK(cudaEventSynchronize(stop)); 56 | CUDA_CHECK(cudaEventElapsedTime(&ms, start, stop)); 57 | printf("------- Coalesced warp sgmev kernel ---------\n"); 58 | print_kernel_essentials(M, N, ms, THEORETICAL_MAX_GFLOPS, THEORETICAL_MAX_MEMORY_BANDWIDTH); 59 | printf("---------------------------\n"); 60 | 61 | CUDA_CHECK(cudaEventDestroy(start)); 62 | CUDA_CHECK(cudaEventDestroy(stop)); 63 | 64 | return ms; 65 | } 66 | -------------------------------------------------------------------------------- /softmax/test.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "blocktiling_5.cuh" 6 | #include "cuda_utils.cuh" 7 | #include "naive_0.cuh" 8 | #include "online_1.cuh" 9 | #include "sharedmem_2.cuh" 10 | #include "shfl_3.cuh" 11 | #include "vectorized_4.cuh" 12 | 13 | /* 14 | Helper function to generate a clamped random number sampled from a 15 | normal distribution with mean 0 and std 1 16 | */ 17 | float random_normal_clamped(float min, float max) { 18 | float u1 = (float)rand() / RAND_MAX; 19 | float u2 = (float)rand() / RAND_MAX; 20 | float num = sqrtf(-2.0f * logf(u1)) * cosf(2.0f * M_PI * u2); 21 | if (num < min) 22 | return min; 23 | if (num > max) 24 | return max; 25 | return num; 26 | } 27 | 28 | int main() { 29 | int M = 4096; 30 | int N = 4096; 31 | int matsize = M * N; 32 | int totalsize = matsize * sizeof(float); 33 | 34 | // allocate and initialize host matrix 35 | float* mat = (float*)malloc(totalsize); 36 | float* res = (float*)malloc(totalsize); 37 | for (int i = 0; i < matsize; i++) { 38 | mat[i] = random_normal_clamped(-10, 10); 39 | } 40 | 41 | float *matd, *resd; 42 | 43 | cudaEvent_t start, stop; 44 | CUDA_CHECK(cudaEventCreate(&start)); 45 | CUDA_CHECK(cudaEventCreate(&stop)); 46 | float ms = 0.0f; 47 | 48 | cudaEventRecord(start); 49 | CUDA_CHECK(cudaMalloc(&matd, totalsize)); 50 | CUDA_CHECK(cudaMalloc(&resd, totalsize)); 51 | cudaEventRecord(stop); 52 | cudaEventSynchronize(stop); 53 | cudaEventElapsedTime(&ms, start, stop); 54 | printf(">> GPU allocation time: %f ms\n", ms); 55 | 56 | cudaEventRecord(start); 57 | CUDA_CHECK(cudaMemcpy(matd, mat, totalsize, cudaMemcpyHostToDevice)); 58 | cudaEventRecord(stop); 59 | cudaEventSynchronize(stop); 60 | cudaEventElapsedTime(&ms, start, stop); 61 | printf(">> Host to device transfer time: %f ms\n", ms); 62 | 63 | run_kernel_4(matd, resd, M, N); 64 | 65 | cudaEventRecord(start); 66 | CUDA_CHECK(cudaMemcpy(res, resd, totalsize, cudaMemcpyDeviceToHost)); 67 | cudaEventRecord(stop); 68 | cudaEventSynchronize(stop); 69 | cudaEventElapsedTime(&ms, start, stop); 70 | printf(">> Device to host transfer time: %f ms\n", ms); 71 | 72 | free(mat); 73 | free(res); 74 | cudaFree(matd); 75 | cudaFree(resd); 76 | } -------------------------------------------------------------------------------- /matvec/include/utils.cuh: 
-------------------------------------------------------------------------------- 1 | #ifndef CUDA_UTILS_CUH 2 | #define CUDA_UTILS_CUH 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #define CUDA_CHECK(ans) \ 9 | { \ 10 | cudaAssert((ans), __FILE__, __LINE__); \ 11 | } 12 | inline void cudaAssert(cudaError_t code, const char *file, int line) { 13 | if (code != cudaSuccess) { 14 | fprintf(stderr, "CUDA error %s: %s at %s: %d\n", 15 | cudaGetErrorName(code), cudaGetErrorString(code), 16 | file, line); 17 | exit(code); 18 | } 19 | } 20 | #define CEIL_DIV(x, y) ((x) >= 0 ? (((x) + (y) - 1) / (y)) : ((x) / (y))) 21 | #define M_PI 3.14159265f 22 | 23 | float random_normal_clamped(float min, float max); 24 | 25 | float compute_gflops(int M, int N, float ms); 26 | 27 | float compute_peak_gflops(float gflops, float THEORETICAL_MAX_GFLOPS); 28 | 29 | float compute_peak_memory_bandwidth(int M, int N, float ms, float THEORETICAL_MAX_MEMORY_BANDWIDTH); 30 | 31 | void print_kernel_essentials(int M, int N, float ms, float THEORETICAL_MAX_GFLOPS, float THEORETICAL_MAX_MEMORY_BANDWIDTH); 32 | 33 | /* 34 | Reduction functions on device. These will be inline: 35 | The compiler will replace the call with the code instead of calling the function (overhead) 36 | */ 37 | /* 38 | Utility warp level sum reduction with shuffle instructions 39 | */ 40 | 41 | __device__ __forceinline__ float warpReduceSum(float val) { 42 | for (int offset = warpSize / 2; offset > 0; offset /= 2) { 43 | val += __shfl_down_sync(0xffffffff, val, offset); 44 | } 45 | 46 | return val; 47 | } 48 | 49 | __device__ __forceinline__ void blockReduceSum(float val, float *smem, int tid, int blockDimX) { 50 | // 1. do warpReduce sum 51 | val = warpReduceSum(val); 52 | 53 | // 2. do blockReduce sum 54 | if (blockDimX > warpSize) { 55 | int lane = tid % warpSize; 56 | int wid = tid / warpSize; 57 | if (lane == 0) { 58 | smem[wid] = val; 59 | } 60 | __syncthreads(); 61 | 62 | if (tid < warpSize) { 63 | val = tid < CEIL_DIV(blockDimX, warpSize) ? 
smem[tid] : 0.0f; 64 | val = warpReduceSum(val); 65 | if (tid == 0) smem[0] = val; 66 | } 67 | } else { 68 | if (tid == 0) smem[0] = val; 69 | } 70 | // __syncthreads(); 71 | // sync not needed because only thread 0 reads from smem[0] 72 | } 73 | 74 | #endif // CUDA_UTILS_CUH -------------------------------------------------------------------------------- /matvec/kernels/coalesced_warpblock_3.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "utils.cuh" 7 | 8 | /* 9 | Coalesced Warp Block Sgemv kernel 10 | 11 | - Each block is assigned to a row of the matrix A 12 | - Each block calculates one output element of y 13 | - The columns are accessed in coalesced manner by threads 14 | - Performs warp level + block level sum reduction 15 | */ 16 | __global__ void coalesced_warpblock_sgmev_kernel(float* __restrict__ matd, float* __restrict__ vecd, float* __restrict__ resd, int M, int N) { 17 | extern __shared__ float smem[]; 18 | 19 | int bid = blockIdx.x; 20 | if (bid >= M) return; 21 | 22 | int tid = threadIdx.x; 23 | // each thread calculates its own partial output 24 | float partial_sum = 0.f; 25 | for (int col = tid; col < N; col += blockDim.x) { 26 | partial_sum += matd[bid * N + col] * vecd[col]; 27 | } 28 | 29 | // block level sum reduction 30 | // only first thread reads the first location in shared memory 31 | // only first thread writes the output to global memory 32 | blockReduceSum(partial_sum, smem, tid, blockDim.x); 33 | if (tid == 0) { 34 | float sum = smem[0]; 35 | resd[bid] = sum; 36 | } 37 | } 38 | 39 | /* 40 | Runs the coalesced warp sgemv kernel. 41 | */ 42 | float run_kernel_coalesced_warpblock_sgmev(float* __restrict__ matd, float* __restrict__ vecd, float* __restrict__ resd, int M, int N, float THEORETICAL_MAX_GFLOPS, float THEORETICAL_MAX_MEMORY_BANDWIDTH) { 43 | int NUM_THREADS = 64; 44 | int warp_size = 32; 45 | 46 | dim3 block_size(NUM_THREADS); 47 | dim3 grid_size(M); 48 | size_t shared_mem_size = CEIL_DIV(block_size.x, warp_size) * sizeof(float); 49 | 50 | cudaEvent_t start, stop; 51 | CUDA_CHECK(cudaEventCreate(&start)); 52 | CUDA_CHECK(cudaEventCreate(&stop)); 53 | float ms = 0.f; 54 | 55 | CUDA_CHECK(cudaEventRecord(start)); 56 | coalesced_warpblock_sgmev_kernel<<>>(matd, vecd, resd, M, N); 57 | CUDA_CHECK(cudaEventRecord(stop)); 58 | CUDA_CHECK(cudaEventSynchronize(stop)); 59 | CUDA_CHECK(cudaEventElapsedTime(&ms, start, stop)); 60 | printf("------- Coalesced warp-block sgmev kernel ---------\n"); 61 | print_kernel_essentials(M, N, ms, THEORETICAL_MAX_GFLOPS, THEORETICAL_MAX_MEMORY_BANDWIDTH); 62 | printf("---------------------------\n"); 63 | 64 | CUDA_CHECK(cudaEventDestroy(start)); 65 | CUDA_CHECK(cudaEventDestroy(stop)); 66 | 67 | return ms; 68 | } 69 | -------------------------------------------------------------------------------- /flash-attention/fa1/smolattn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import math 5 | from torch.nn import functional as F 6 | from torch.utils.cpp_extension import load 7 | 8 | # Set CUDA architecture for RTX 4090 9 | os.environ["TORCH_CUDA_ARCH_LIST"] = "8.9" 10 | 11 | # Load CUDA kernel 12 | smolattn = load(name='smolattn', sources=['build.cpp', 'flash-attn-1.cu'], extra_cuda_cflags=['-O3']) 13 | 14 | # Model parameters 15 | batch_size = 16 16 | n_head = 8 17 | seq_len = 512 18 | head_embd = 64 19 | 20 | q = torch.randn((batch_size, 
n_head, seq_len, head_embd), device="cuda") 21 | k = torch.randn((batch_size, n_head, seq_len, head_embd), device="cuda") 22 | v = torch.randn((batch_size, n_head, seq_len, head_embd), device="cuda") 23 | 24 | # Manual attention function 25 | def manual_attn(q, k, v): 26 | att = (q @ k.transpose(-2, -1) * (1.0 / math.sqrt(k.size(-1)))) 27 | att = F.softmax(att, dim=-1) 28 | y = att @ v 29 | return y 30 | 31 | # Function to measure execution time 32 | def benchmark(func, *args, name="Function"): 33 | torch.cuda.synchronize() # Ensure GPU is idle 34 | start = torch.cuda.Event(enable_timing=True) 35 | end = torch.cuda.Event(enable_timing=True) 36 | 37 | start.record() 38 | result = func(*args) 39 | end.record() 40 | 41 | torch.cuda.synchronize() # Wait for event completion 42 | elapsed_time = start.elapsed_time(end) # Time in ms 43 | print(f"{name} execution time: {elapsed_time:.3f} ms") 44 | 45 | return result, elapsed_time 46 | 47 | 48 | print(f"Batch size: {batch_size}, Num heads: {n_head}, Sequence length: {seq_len}, Head dims: {head_embd}\n") 49 | 50 | # Benchmarking manual attention 51 | print("=== Benchmarking Manual Attention ===") 52 | manual_result, manual_time = benchmark(manual_attn, q, k, v, name="Manual Attention") 53 | 54 | # Benchmarking smolattn implementation 55 | print("\n=== Benchmarking SmolAttn Implementation ===") 56 | minimal_result, smolattn_time = benchmark(smolattn.fa_forward, q, k, v, name="SmolAttn") 57 | 58 | # Print speedup 59 | speedup = manual_time / smolattn_time 60 | print(f"\nSmolAttn is {speedup:.2f}x faster than Manual Attention.") 61 | 62 | # Sanity check for correctness 63 | print("\n=== Accuracy Check ===") 64 | tolerance = 1e-2 65 | allclose = torch.allclose(minimal_result, manual_result, rtol=0, atol=tolerance) 66 | print(f"Attn values match within tolerance ({tolerance}): {allclose}") 67 | 68 | # Compute absolute differences 69 | diff = torch.abs(minimal_result - manual_result) 70 | diff_indices = torch.nonzero(diff > tolerance, as_tuple=True) 71 | 72 | # Print the top mismatches 73 | if diff_indices[0].numel() > 0: 74 | print("\nTop mismatches:") 75 | for idx in zip(*diff_indices[:4]): # Print first 4 mismatches 76 | print(f"Index {idx}: manual={manual_result[idx].item():.6f}, minimal={minimal_result[idx].item():.6f}, diff={diff[idx].item():.6f}") 77 | else: 78 | print("\nNo significant mismatches found.") 79 | -------------------------------------------------------------------------------- /flash-attention/fa2/smolattn2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import math 5 | from torch.nn import functional as F 6 | from torch.utils.cpp_extension import load 7 | 8 | # Set CUDA architecture for RTX 4090 9 | os.environ["TORCH_CUDA_ARCH_LIST"] = "8.9" 10 | 11 | # Load CUDA kernel 12 | smolattn = load(name='smolattn', sources=['build.cpp', 'flash-attn-2.cu'], extra_cuda_cflags=['-O3']) 13 | 14 | # Model parameters 15 | batch_size = 16 16 | n_head = 8 17 | seq_len = 512 18 | head_embd = 64 19 | 20 | q = torch.randn((batch_size, n_head, seq_len, head_embd), device="cuda") 21 | k = torch.randn((batch_size, n_head, seq_len, head_embd), device="cuda") 22 | v = torch.randn((batch_size, n_head, seq_len, head_embd), device="cuda") 23 | 24 | # Manual attention function 25 | def manual_attn(q, k, v): 26 | att = (q @ k.transpose(-2, -1) * (1.0 / math.sqrt(k.size(-1)))) 27 | att = F.softmax(att, dim=-1) 28 | y = att @ v 29 | return y 30 | 31 | # Function to measure 
execution time 32 | def benchmark(func, *args, name="Function"): 33 | torch.cuda.synchronize() # Ensure GPU is idle 34 | start = torch.cuda.Event(enable_timing=True) 35 | end = torch.cuda.Event(enable_timing=True) 36 | 37 | start.record() 38 | result = func(*args) 39 | end.record() 40 | 41 | torch.cuda.synchronize() # Wait for event completion 42 | elapsed_time = start.elapsed_time(end) # Time in ms 43 | print(f"{name} execution time: {elapsed_time:.3f} ms") 44 | 45 | return result, elapsed_time 46 | 47 | 48 | print(f"Batch size: {batch_size}, Num heads: {n_head}, Sequence length: {seq_len}, Head dims: {head_embd}\n") 49 | 50 | # Benchmarking manual attention 51 | print("=== Benchmarking Manual Attention ===") 52 | manual_result, manual_time = benchmark(manual_attn, q, k, v, name="Manual Attention") 53 | 54 | # Benchmarking smolattn implementation 55 | print("\n=== Benchmarking SmolAttn Implementation ===") 56 | minimal_result, smolattn_time = benchmark(smolattn.fa2_forward, q, k, v, name="SmolAttn2") 57 | 58 | # Print speedup 59 | speedup = manual_time / smolattn_time 60 | print(f"\nSmolAttn2 is {speedup:.2f}x faster than Manual Attention.") 61 | 62 | # Sanity check for correctness 63 | print("\n=== Accuracy Check ===") 64 | tolerance = 1e-2 65 | allclose = torch.allclose(minimal_result, manual_result, rtol=0, atol=tolerance) 66 | print(f"Attn values match within tolerance ({tolerance}): {allclose}") 67 | 68 | # Compute absolute differences 69 | diff = torch.abs(minimal_result - manual_result) 70 | diff_indices = torch.nonzero(diff > tolerance, as_tuple=True) 71 | 72 | # Print the top mismatches 73 | if diff_indices[0].numel() > 0: 74 | print("\nTop mismatches:") 75 | for idx in zip(*diff_indices[:4]): # Print first 4 mismatches 76 | print(f"Index {idx}: manual={manual_result[idx].item():.6f}, minimal={minimal_result[idx].item():.6f}, diff={diff[idx].item():.6f}") 77 | else: 78 | print("\nNo significant mismatches found.") 79 | -------------------------------------------------------------------------------- /flash-attention/fa2/smolattn2_fp16.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import math 5 | from torch.nn import functional as F 6 | from torch.utils.cpp_extension import load 7 | 8 | # Set CUDA architecture for RTX 4090 9 | os.environ["TORCH_CUDA_ARCH_LIST"] = "8.9" 10 | 11 | # Load CUDA kernel 12 | smolattn = load(name='smolattn', sources=['build.cpp', 'flash2_fwd_fp16.cu'], extra_cuda_cflags=['-O3']) 13 | 14 | # Model parameters 15 | batch_size = 16 16 | n_head = 8 17 | seq_len = 512 18 | head_embd = 64 19 | 20 | q = torch.randn((batch_size, n_head, seq_len, head_embd), dtype=torch.float16, device="cuda") 21 | k = torch.randn((batch_size, n_head, seq_len, head_embd), dtype=torch.float16, device="cuda") 22 | v = torch.randn((batch_size, n_head, seq_len, head_embd), dtype=torch.float16, device="cuda") 23 | 24 | # Manual attention function 25 | def manual_attn(q, k, v): 26 | att = (q @ k.transpose(-2, -1) * (1.0 / math.sqrt(k.size(-1)))) 27 | att = F.softmax(att, dim=-1) 28 | y = att @ v 29 | return y 30 | 31 | # Function to measure execution time 32 | def benchmark(func, *args, name="Function"): 33 | torch.cuda.synchronize() # Ensure GPU is idle 34 | start = torch.cuda.Event(enable_timing=True) 35 | end = torch.cuda.Event(enable_timing=True) 36 | 37 | start.record() 38 | result = func(*args) 39 | end.record() 40 | 41 | torch.cuda.synchronize() # Wait for event completion 42 | elapsed_time = 
start.elapsed_time(end) # Time in ms 43 | print(f"{name} execution time: {elapsed_time:.3f} ms") 44 | 45 | return result, elapsed_time 46 | 47 | 48 | print(f"Batch size: {batch_size}, Num heads: {n_head}, Sequence length: {seq_len}, Head dims: {head_embd}\n") 49 | 50 | # Benchmarking manual attention 51 | print("=== Benchmarking Manual Attention in FP16 ===") 52 | manual_result, manual_time = benchmark(manual_attn, q, k, v, name="Manual Attention") 53 | 54 | # Benchmarking smolattn implementation 55 | print("\n=== Benchmarking SmolAttn2 FP16 Implementation ===") 56 | minimal_result, smolattn_time = benchmark(smolattn.fa2_forward, q, k, v, name="SmolAttn2") 57 | 58 | # Print speedup 59 | speedup = manual_time / smolattn_time 60 | print(f"\nSmolAttn2 FP16 is {speedup:.2f}x faster than Manual Attention.") 61 | 62 | # Sanity check for correctness 63 | print("\n=== Accuracy Check ===") 64 | tolerance = 1e-2 65 | allclose = torch.allclose(minimal_result, manual_result, rtol=0, atol=tolerance) 66 | print(f"Attn values match within tolerance ({tolerance}): {allclose}") 67 | 68 | # Compute absolute differences 69 | diff = torch.abs(minimal_result - manual_result) 70 | diff_indices = torch.nonzero(diff > tolerance, as_tuple=True) 71 | 72 | # Print the top mismatches 73 | if diff_indices[0].numel() > 0: 74 | print("\nTop mismatches:") 75 | for idx in zip(*diff_indices[:4]): # Print first 4 mismatches 76 | print(f"Index {idx}: manual={manual_result[idx].item():.6f}, minimal={minimal_result[idx].item():.6f}, diff={diff[idx].item():.6f}") 77 | else: 78 | print("\nNo significant mismatches found.") 79 | -------------------------------------------------------------------------------- /softmax/bench.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "cuda_utils.cuh" 7 | #include "vectorized_4.cuh" 8 | 9 | /* 10 | Helper function to generate a clamped random number sampled from a 11 | normal distribution with mean 0 and std 1 12 | */ 13 | float random_normal_clamped(float min, float max) { 14 | float u1 = (float)rand() / RAND_MAX; 15 | float u2 = (float)rand() / RAND_MAX; 16 | float num = sqrtf(-2.0f * logf(u1)) * cosf(2.0f * M_PI * u2); 17 | if (num < min) 18 | return min; 19 | if (num > max) 20 | return max; 21 | return num; 22 | } 23 | 24 | /* 25 | Benchmarks a kernel for different sizes 26 | */ 27 | void benchmark_kernel_for_sizes(int minN, int maxN) { 28 | FILE *exec_time_file = fopen("benchmarks/exec_time_ms_cuda.txt", "w"); 29 | 30 | if (exec_time_file == NULL) { 31 | perror("Error opening the file for GFLOPS.\n"); 32 | } 33 | 34 | for (int N = minN; N < maxN; N *= 2) { 35 | int M = 1024; // matrix size (M, N) 36 | 37 | printf("------------ Running CUDA softmax benchmark for MxN = (%d, %d) -------------\n", M, N); 38 | 39 | int matsize = M * N; 40 | int totalsize = matsize * sizeof(float); 41 | 42 | // allocate and initialize host matrix 43 | float *mat = (float *)malloc(totalsize); 44 | float *res = (float *)malloc(totalsize); 45 | for (int i = 0; i < matsize; i++) { 46 | mat[i] = random_normal_clamped(-10, 10); 47 | } 48 | 49 | float *matd, *resd; 50 | 51 | cudaEvent_t start, stop; 52 | CUDA_CHECK(cudaEventCreate(&start)); 53 | CUDA_CHECK(cudaEventCreate(&stop)); 54 | float ms = 0.0f; 55 | 56 | cudaEventRecord(start); 57 | CUDA_CHECK(cudaMalloc(&matd, totalsize)); 58 | CUDA_CHECK(cudaMalloc(&resd, totalsize)); 59 | cudaEventRecord(stop); 60 | cudaEventSynchronize(stop); 61 | 
cudaEventElapsedTime(&ms, start, stop); 62 | printf(">> GPU allocation time: %f ms\n", ms); 63 | 64 | cudaEventRecord(start); 65 | CUDA_CHECK(cudaMemcpy(matd, mat, totalsize, cudaMemcpyHostToDevice)); 66 | cudaEventRecord(stop); 67 | cudaEventSynchronize(stop); 68 | cudaEventElapsedTime(&ms, start, stop); 69 | printf(">> Host to device transfer time: %f ms\n", ms); 70 | 71 | // run softmax kernel 72 | ms = run_kernel_4(matd, resd, M, N); 73 | 74 | fprintf(exec_time_file, "%d %f\n", M, ms); 75 | 76 | cudaEventRecord(start); 77 | CUDA_CHECK(cudaMemcpy(res, resd, totalsize, cudaMemcpyDeviceToHost)); 78 | cudaEventRecord(stop); 79 | cudaEventSynchronize(stop); 80 | cudaEventElapsedTime(&ms, start, stop); 81 | printf(">> Device to host transfer time: %f ms\n", ms); 82 | 83 | free(mat); 84 | free(res); 85 | cudaFree(matd); 86 | cudaFree(resd); 87 | } 88 | 89 | fclose(exec_time_file); 90 | } 91 | 92 | int main() { 93 | benchmark_kernel_for_sizes(2048, 262144); 94 | } -------------------------------------------------------------------------------- /matvec/kernels/vectorized_4.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "utils.cuh" 7 | 8 | /* 9 | Vectorized Sgemv kernel 10 | 11 | - Each block is assigned to a row of the matrix A 12 | - Each block calculates one output element of y 13 | - The columns are accessed in coalesced manner by threads 14 | - Vectorized loads are done for efficient memory bandwidth 15 | - Performs warp level + block level sum reduction 16 | */ 17 | __global__ void vectorized_sgemv_kernel(float* __restrict__ matd, float* __restrict__ vecd, float* __restrict__ resd, int M, int N) { 18 | extern __shared__ float smem[]; 19 | 20 | int bid = blockIdx.x; 21 | if (bid >= M) return; 22 | 23 | int tid = threadIdx.x; 24 | int n_float4s = N / 4; 25 | 26 | // cast the matrix and vector as float4 27 | // float4 holds multiple values (x, y, z, w) 28 | float4* mat_row = reinterpret_cast(matd + bid * N); 29 | float4* vec = reinterpret_cast(vecd); 30 | 31 | // each thread calculates its own partial output 32 | float partial_sum = 0.f; 33 | 34 | // manual loop unrolling with a factor of 4 35 | #pragma unroll 4 36 | for (int col = tid; col < n_float4s; col += blockDim.x) { 37 | float4 matval = mat_row[col]; 38 | float4 vecval = vec[col]; 39 | 40 | partial_sum += (matval.x * vecval.x + 41 | matval.y * vecval.y + 42 | matval.z * vecval.z + 43 | matval.w * vecval.w); 44 | } 45 | 46 | // block level sum reduction 47 | // only first thread reads the first location in shared memory 48 | // only first thread writes the output to global memory 49 | blockReduceSum(partial_sum, smem, tid, blockDim.x); 50 | if (tid == 0) { 51 | float sum = smem[0]; 52 | resd[bid] = sum; 53 | } 54 | } 55 | 56 | /* 57 | Runs the vectorized sgemv kernel. 
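Before the launch wrapper below, a possible host-side reference for spot-checking the kernel's output against a plain CPU loop, in the same spirit as the CPU comparison in matmul/main.cu. This is only a sketch: the host-array names (mat, vec, res) are assumptions, not symbols from this file.

// Sketch of a CPU reference for y = A * x with a row-major (M, N) matrix.
// Compare `res` element-wise (within a small tolerance) against the buffer
// copied back from `resd` after the kernel runs.
void sgemv_cpu_reference(const float* mat, const float* vec, float* res, int M, int N) {
    for (int row = 0; row < M; row++) {
        float acc = 0.0f;
        for (int col = 0; col < N; col++) {
            acc += mat[row * N + col] * vec[col];
        }
        res[row] = acc;
    }
}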
58 | */ 59 | float run_kernel_vectorized_sgmev(float* __restrict__ matd, float* __restrict__ vecd, float* __restrict__ resd, int M, int N, float THEORETICAL_MAX_GFLOPS, float THEORETICAL_MAX_MEMORY_BANDWIDTH) { 60 | int NUM_THREADS = 64; 61 | int warp_size = 32; 62 | 63 | dim3 block_size(NUM_THREADS); 64 | dim3 grid_size(M); 65 | size_t shared_mem_size = CEIL_DIV(block_size.x, warp_size) * sizeof(float); 66 | 67 | cudaEvent_t start, stop; 68 | CUDA_CHECK(cudaEventCreate(&start)); 69 | CUDA_CHECK(cudaEventCreate(&stop)); 70 | float ms = 0.f; 71 | 72 | CUDA_CHECK(cudaEventRecord(start)); 73 | vectorized_sgemv_kernel<<>>(matd, vecd, resd, M, N); 74 | CUDA_CHECK(cudaEventRecord(stop)); 75 | CUDA_CHECK(cudaEventSynchronize(stop)); 76 | CUDA_CHECK(cudaEventElapsedTime(&ms, start, stop)); 77 | printf("------- Vectorized sgmev kernel ---------\n"); 78 | print_kernel_essentials(M, N, ms, THEORETICAL_MAX_GFLOPS, THEORETICAL_MAX_MEMORY_BANDWIDTH); 79 | printf("---------------------------\n"); 80 | 81 | CUDA_CHECK(cudaEventDestroy(start)); 82 | CUDA_CHECK(cudaEventDestroy(stop)); 83 | 84 | return ms; 85 | } 86 | -------------------------------------------------------------------------------- /attention/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import math 5 | from torch.nn import functional as F 6 | from torch.utils.cpp_extension import load 7 | from pycublas import cublas_create_handle, cublas_destroy_handle 8 | 9 | # Set CUDA architecture for RTX 4090 10 | os.environ["TORCH_CUDA_ARCH_LIST"] = "8.9" 11 | 12 | handle = cublas_create_handle() 13 | 14 | # Load CUDA kernel 15 | smolattn = load(name='smolattn', sources=['build.cpp', 'attn.cu'], extra_cuda_cflags=['-O3', '-arch=sm_89', '-lcublas']) 16 | 17 | # Model parameters 18 | batch_size = 32 19 | n_head = 8 20 | seq_len = 1024 21 | head_embd = 256 22 | 23 | q = torch.randn((batch_size, n_head, seq_len, head_embd), device="cuda") 24 | k = torch.randn((batch_size, n_head, seq_len, head_embd), device="cuda") 25 | v = torch.randn((batch_size, n_head, seq_len, head_embd), device="cuda") 26 | 27 | # Manual attention function 28 | def manual_attn(q, k, v): 29 | att = (q @ k.transpose(-2, -1) * (1.0 / math.sqrt(k.size(-1)))) 30 | att = F.softmax(att, dim=-1) 31 | y = att @ v 32 | return y 33 | 34 | # Function to measure execution time 35 | def benchmark(func, *args, name="Function"): 36 | torch.cuda.synchronize() # Ensure GPU is idle 37 | start = torch.cuda.Event(enable_timing=True) 38 | end = torch.cuda.Event(enable_timing=True) 39 | 40 | start.record() 41 | result = func(*args) 42 | end.record() 43 | 44 | torch.cuda.synchronize() # Wait for event completion 45 | elapsed_time = start.elapsed_time(end) # Time in ms 46 | print(f"{name} execution time: {elapsed_time:.3f} ms") 47 | 48 | return result, elapsed_time 49 | 50 | 51 | # ------------------------- 52 | # Warm-up Step 53 | # ------------------------- 54 | # Run each function a few times to "warm up" the GPU 55 | warmup_iters = 1 56 | 57 | print("=== Warming up the GPU ===") 58 | for _ in range(warmup_iters): 59 | _ = manual_attn(q, k, v) 60 | 61 | for _ in range(warmup_iters): 62 | _ = smolattn.attention_forward(handle.value, q, k, v) 63 | 64 | # After warm-up, synchronize again 65 | torch.cuda.synchronize() 66 | 67 | 68 | print(f"Batch size: {batch_size}, Num heads: {n_head}, Sequence length: {seq_len}, Head dims: {head_embd}\n") 69 | 70 | # Benchmarking manual attention 71 | print("=== Benchmarking 
Manual Attention ===") 72 | manual_result, manual_time = benchmark(manual_attn, q, k, v, name="Manual Attention") 73 | 74 | # Benchmarking smolattn implementation 75 | print("\n=== Benchmarking SmolAttn Implementation ===") 76 | minimal_result, smolattn_time = benchmark(smolattn.attention_forward, handle.value, q, k, v, name="SmolAttn") 77 | 78 | # Sanity check for correctness 79 | print("\n=== Accuracy Check ===") 80 | tolerance = 1e-2 81 | allclose = torch.allclose(minimal_result, manual_result, rtol=0, atol=tolerance) 82 | print(f"Attn values match within tolerance ({tolerance}): {allclose}") 83 | 84 | # Compute absolute differences 85 | diff = torch.abs(minimal_result - manual_result) 86 | diff_indices = torch.nonzero(diff > tolerance, as_tuple=True) 87 | 88 | # # Print the top mismatches 89 | # if diff_indices[0].numel() > 0: 90 | # print("\nTop mismatches:") 91 | # for idx in zip(*diff_indices[:4]): # Print first 4 mismatches 92 | # print(f"Index {idx}: manual={manual_result[idx].item():.6f}, minimal={minimal_result[idx].item():.6f}, diff={diff[idx].item():.6f}") 93 | # else: 94 | # print("\nNo significant mismatches found.") 95 | 96 | 97 | cublas_destroy_handle(handle) 98 | -------------------------------------------------------------------------------- /softmax/include/cuda_utils.cuh: -------------------------------------------------------------------------------- 1 | #ifndef CUDA_UTILS_CUH 2 | #define CUDA_UTILS_CUH 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #define CUDA_CHECK(ans) \ 10 | { \ 11 | cudaAssert((ans), __FILE__, __LINE__); \ 12 | } 13 | inline void cudaAssert(cudaError_t code, const char *file, int line) { 14 | if (code != cudaSuccess) { 15 | fprintf(stderr, "CUDA error %s: %s at %s: %d\n", 16 | cudaGetErrorName(code), cudaGetErrorString(code), 17 | file, line); 18 | exit(code); 19 | } 20 | } 21 | #define CEIL_DIV(x, y) ((x) >= 0 ? (((x) + (y) - 1) / (y)) : ((x) / (y))) 22 | 23 | #define M_PI 3.14159265f 24 | 25 | #ifndef WARP_SIZE 26 | #define WARP_SIZE 32 27 | #endif 28 | 29 | 30 | // warp reduce for any Op 31 | template 32 | __device__ __forceinline__ T warpReduce(T val, Op op, unsigned int mask = 0xffffffffu) { 33 | for(int offset = WARP_SIZE >> 1; offset > 0; offset >>= 1) 34 | val = op(val, __shfl_down_sync(mask, val, offset)); 35 | return val; 36 | } 37 | 38 | // warp reduce for sum and max using lambda functions 39 | template 40 | __device__ __forceinline__ T warpReduceSum(T val) { 41 | return warpReduce(val, []__device__(T a, T b) { return a + b; }); 42 | } 43 | 44 | template 45 | __device__ __forceinline__ T warpReduceMax(T val) { 46 | return warpReduce(val, []__device__(T a, T b) { return a > b ? a : b; }); 47 | } 48 | 49 | template 50 | __device__ __forceinline__ void blockReduce(T val, T *smem, T identity, Op op) { 51 | int tx = threadIdx.x; 52 | int wid = tx / WARP_SIZE; 53 | int lane = tx % WARP_SIZE; 54 | 55 | val = warpReduce(val, op); 56 | 57 | // when blockDim is greater than 32, we need to do a block level reduction 58 | // AFTER warp level reductions since we have the 8 maximum values that needs to be reduced again 59 | // the global max will be stored in the first warp 60 | if (blockDim.x > WARP_SIZE) { 61 | if (lane == 0) { 62 | // which warp are we at? 
63 | // store the value in its first thread index 64 | smem[wid] = val; 65 | } 66 | __syncthreads(); 67 | 68 | // first warp will do global reduction only 69 | // this is possible because we stored the values in the shared memory 70 | // so the threads in the first warp will read from it and then reduce 71 | if (tx < WARP_SIZE) { 72 | val = (tx < CEIL_DIV(blockDim.x, WARP_SIZE)) ? smem[tx] : identity; 73 | val = warpReduce(val, op); 74 | if (tx == 0) smem[0] = val; 75 | } 76 | } else { 77 | // this is for when the number of threads in a block are not 78 | // greater than the warp size, in that case we already reduced 79 | // so we can store the value 80 | if (tx == 0) smem[0] = val; 81 | } 82 | } 83 | 84 | template 85 | __device__ __forceinline__ void blockReduceSum(T val, T *smem, T identity) { 86 | return blockReduce( 87 | val, smem, identity, []__device__(T a, T b) { return a + b; } 88 | ); 89 | } 90 | 91 | 92 | template 93 | __device__ __forceinline__ void blockReduceMax(T val, T *smem, T identity) { 94 | return blockReduce( 95 | val, smem, identity, []__device__(T a, T b) { return a > b ? a : b; } 96 | ); 97 | } 98 | 99 | #endif // CUDA_UTILS_CUH -------------------------------------------------------------------------------- /softmax/kernels/sharedmem_2.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "cuda_utils.cuh" 6 | 7 | /* 8 | This kernel implements an online softmax operation on a matrix of size (M, N). 9 | The softmax operation is performed on the last dimension of the matrix. 10 | 11 | How this works: 12 | In this, we handle each row with a block where the threads within one block work together 13 | to process one row (max and norm factor). Each thread will process some elements 14 | and will contains its local max and local norm in shared memory. Then, we perform reduction 15 | operations to compute the final max and norm factor. Also, we compute maxes and norms 16 | in one pass itself. 
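To make the "one pass" idea concrete, here is a minimal scalar sketch of the same update each thread performs below (illustrative only; `row` and `N` stand for one input row and its length):

// Online max/norm in a single sweep: whenever a larger element shows up,
// the norm accumulated so far is rescaled by expf(old_max - new_max), so at
// the end running_norm == sum_i expf(row[i] - running_max), exactly what a
// separate max pass followed by a sum pass would give.
float running_max = -INFINITY;
float running_norm = 0.0f;
for (int i = 0; i < N; i++) {
    float x = row[i];
    if (x > running_max) {
        running_norm *= expf(running_max - x);
        running_max = x;
    }
    running_norm += expf(x - running_max);
}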
17 | */ 18 | __global__ void softmax_kernel_2(float* __restrict__ xd, float* __restrict__ resd, int M, int N) { 19 | // max and norm reduction will happen in shared memory (static) 20 | __shared__ float smem[1024]; 21 | 22 | int row = blockIdx.x; 23 | int tid = threadIdx.x; 24 | 25 | // edge condition (we don't process further) 26 | if (row >= M) return; 27 | 28 | float* input_row = xd + row * N; 29 | float* output_row = resd + row * N; 30 | float local_max = -INFINITY; 31 | float local_norm = 0.0f; 32 | 33 | // compute local max and norm for each thread 34 | // and then finally have a sync barrier before moving on 35 | for (int i = tid; i < N; i += blockDim.x) { 36 | float x = input_row[i]; 37 | if (x > local_max) { 38 | local_norm *= expf(local_max - x); 39 | local_max = x; 40 | } 41 | local_norm += expf(x - local_max); 42 | } 43 | __syncthreads(); 44 | 45 | // each thread will have its own local max 46 | // we store it in the tid of the shared memory 47 | smem[tid] = local_max; 48 | __syncthreads(); 49 | 50 | // block-level reduction in O(log(N)) time over all threads 51 | // is faster than linear reduction over all threads 52 | for (int stride = blockDim.x / 2; stride > 0; stride /= 2) { 53 | if (tid < stride) { 54 | smem[tid] = max(smem[tid], smem[tid + stride]); 55 | } 56 | // sync barrier before next iteration to ensure correctness 57 | __syncthreads(); 58 | } 59 | 60 | // the first element after max reduction from all threads 61 | // will contain the global max for the row 62 | float row_max = smem[0]; 63 | __syncthreads(); 64 | 65 | // each thread will have its own local norm 66 | // we will store the corrected local norm in the shared memory 67 | // again, exploits property of exponentials 68 | smem[tid] = local_norm * expf(local_max - row_max); 69 | __syncthreads(); 70 | 71 | // sum reduction similar to above for global norm factor 72 | for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) { 73 | if (tid < stride) { 74 | smem[tid] += smem[tid + stride]; 75 | } 76 | __syncthreads(); 77 | } 78 | float row_norm = smem[0]; 79 | __syncthreads(); 80 | 81 | // finally, compute softmax 82 | for (int i = tid; i < N; i += blockDim.x) { 83 | output_row[i] = expf(input_row[i] - row_max) / row_norm; 84 | } 85 | } 86 | 87 | /* 88 | Runs the online softmax kernel: `id = 2` 89 | */ 90 | void run_kernel_2(float* __restrict__ matd, float* __restrict__ resd, int M, int N) { 91 | // grid size and block size for this kernel 92 | // change as necessary 93 | dim3 block_size(1024); 94 | dim3 grid_size(M); 95 | 96 | cudaEvent_t start, stop; 97 | CUDA_CHECK(cudaEventCreate(&start)); 98 | CUDA_CHECK(cudaEventCreate(&stop)); 99 | float ms = 0.f; 100 | 101 | CUDA_CHECK(cudaEventRecord(start)); 102 | softmax_kernel_2<<>>(matd, resd, M, N); 103 | CUDA_CHECK(cudaEventRecord(stop)); 104 | CUDA_CHECK(cudaEventSynchronize(stop)); 105 | CUDA_CHECK(cudaEventElapsedTime(&ms, start, stop)); 106 | printf(">> Kernel execution time: %f ms\n", ms); 107 | 108 | CUDA_CHECK(cudaEventDestroy(start)); 109 | CUDA_CHECK(cudaEventDestroy(stop)); 110 | } -------------------------------------------------------------------------------- /attention/bench.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import math 4 | import torch 5 | torch.backends.cuda.matmul.allow_tf32 = True 6 | from torch.nn import functional as F 7 | from torch.utils.cpp_extension import load 8 | from pycublas import cublas_create_handle, cublas_destroy_handle 9 | 10 | import 
matplotlib.pyplot as plt 11 | import pandas as pd 12 | import numpy as np 13 | 14 | # on H100 15 | os.environ["TORCH_CUDA_ARCH_LIST"] = "8.9" 16 | plt.style.use('Solarize_Light2') 17 | 18 | handle = cublas_create_handle() 19 | smolattn = load(name='smolattn', sources=['build.cpp', 'attn.cu'], 20 | extra_cuda_cflags=['-O3', '-arch=sm_89', '-lcublas']) 21 | 22 | # Fixed model parameters 23 | batch_size = 2 24 | n_head = 16 25 | head_embd = 64 26 | 27 | # Manual attention function 28 | def manual_attn(q, k, v): 29 | att = (q @ k.transpose(-2, -1) * (1.0 / math.sqrt(k.size(-1)))) 30 | att = F.softmax(att, dim=-1) 31 | y = att @ v 32 | return y 33 | 34 | # Benchmarking function to measure execution time 35 | def benchmark(func, *args, name="Function"): 36 | torch.cuda.synchronize() # Ensure GPU is idle 37 | start = torch.cuda.Event(enable_timing=True) 38 | end = torch.cuda.Event(enable_timing=True) 39 | 40 | start.record() 41 | result = func(*args) 42 | end.record() 43 | 44 | torch.cuda.synchronize() # Wait for event completion 45 | elapsed_time = start.elapsed_time(end) # Time in ms 46 | print(f"{name} execution time: {elapsed_time:.3f} ms") 47 | return result, elapsed_time 48 | 49 | # List of sequence lengths to evaluate 50 | seq_lens = [512, 1024, 2048, 4096, 8192] 51 | 52 | # Lists to store execution times for each implementation 53 | manual_times = [] 54 | smolattn_times = [] 55 | 56 | # Number of warmup iterations 57 | warmup_iters = 3 58 | 59 | print("=== Benchmarking across different sequence lengths ===") 60 | for seq_len in seq_lens: 61 | print(f"\n--- Sequence length: {seq_len} ---") 62 | # Create random Q, K, V tensors for the current sequence length 63 | q = torch.randn((batch_size, n_head, seq_len, head_embd), device="cuda") 64 | k = torch.randn((batch_size, n_head, seq_len, head_embd), device="cuda") 65 | v = torch.randn((batch_size, n_head, seq_len, head_embd), device="cuda") 66 | 67 | # Warm-up for manual attention 68 | for _ in range(warmup_iters): 69 | _ = manual_attn(q, k, v) 70 | # Warm-up for smolattn attention 71 | for _ in range(warmup_iters): 72 | _ = smolattn.attention_forward(handle.value, q, k, v) 73 | torch.cuda.synchronize() 74 | 75 | # Benchmark manual attention 76 | _, m_time = benchmark(manual_attn, q, k, v, name="Manual Attention") 77 | manual_times.append(m_time) 78 | 79 | # Benchmark smolattn implementation 80 | minimal_result, s_time = benchmark(smolattn.attention_forward, handle.value, q, k, v, name="SmolAttn") 81 | smolattn_times.append(s_time) 82 | 83 | # Sanity check for correctness (optional) 84 | tolerance = 1e-2 85 | allclose = torch.allclose(minimal_result, manual_attn(q, k, v), rtol=0, atol=tolerance) 86 | print(f"Attention values match within tolerance ({tolerance}): {allclose}") 87 | 88 | # Cleanup the cuBLAS handle 89 | cublas_destroy_handle(handle) 90 | 91 | # Create a DataFrame for plotting 92 | data = { 93 | "Sequence Length": seq_lens * 2, 94 | "Time (ms)": manual_times + smolattn_times, 95 | "Implementation": ["Manual"] * len(seq_lens) + ["SmolAttn"] * len(seq_lens) 96 | } 97 | df = pd.DataFrame(data) 98 | df.to_csv("benchmark.csv") 99 | 100 | plt.figure(figsize=(12, 8)) 101 | 102 | # # Plot the manual times 103 | # plt.plot(seq_lens, manual_times, marker='o', label="Manual", color="#FDB813") # a golden tone 104 | 105 | # # Plot the smolattn times 106 | # plt.plot(seq_lens, smolattn_times, marker='o', label="SmolAttn", color="#D95F02") # a contrasting color 107 | 108 | # Create positions for grouped bars 109 | x = np.arange(len(seq_lens)) 
110 | width = 0.35 111 | 112 | # Plot the bars for Manual and SmolAttn times 113 | plt.bar(x - width/2, smolattn_times, width, label="SmolAttn", color="#D95F02") 114 | plt.bar(x + width/2, manual_times, width, label="Manual", color="#FDB813") 115 | 116 | 117 | plt.title("Execution Speed: Manual vs SmolAttn") 118 | plt.xlabel("Sequence Length") 119 | plt.ylabel("Execution Time (ms)") 120 | plt.xticks(x, seq_lens) 121 | plt.legend(title="Implementation") 122 | plt.tight_layout() 123 | 124 | # Save the figure to a file instead of showing it 125 | plt.savefig("execution_speed_comparison.png") 126 | print("Figure saved as execution_speed_comparison.png") 127 | -------------------------------------------------------------------------------- /softmax/kernels/blocktiling_5.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "cuda_utils.cuh" 7 | 8 | const int TILE_SIZE = 4; 9 | 10 | 11 | /* 12 | Takes in an array of size `TILE_SIZE` and reduces it as warp-wide sum. 13 | The first element in the array will contain the reduced sum. 14 | */ 15 | __device__ __forceinline__ float warpReduceSum(float val) { 16 | for(int offset = warpSize / 2; offset > 0; offset /= 2) { 17 | val += __shfl_down_sync(0xffffffff, val, offset); 18 | } 19 | return val; 20 | } 21 | 22 | /* 23 | Takes in an array of size `TILE_SIZE` and reduces it warp-wide max. 24 | The first element in the array will contain the reduced max. 25 | */ 26 | __device__ __forceinline__ float warpReduceMax(float val) { 27 | for (int offset = warpSize / 2; offset > 0; offset /= 2) { 28 | val = fmaxf(val, __shfl_down_sync(0xffffffff, val, offset)); 29 | } 30 | } 31 | 32 | // __device__ __forceinline__ void blockReduceSum(volatile float* arr, float* smem, int ty, int tx, int blockDimX) { 33 | // warpReduceSum(arr); 34 | 35 | // if (blockDimX <= warpSize) { 36 | // // for small block sizes, single warp reduction is sufficient 37 | // return; 38 | // } 39 | 40 | // int cols = blockDimX / warpSize; 41 | // int lane = tx % warpSize; 42 | // int wid = tx / warpSize; 43 | 44 | // if (lane == 0) { 45 | // #pragma unroll 46 | // smem[ty*cols + wid] = arr[wid]; 47 | // } 48 | // __syncthreads(); 49 | 50 | // if (tx < warpSize) { 51 | // #pragma unroll 52 | // for(int i = 0; i < TILE_SIZE; i++) { 53 | // arr[i] = smem[i*cols + tx]; 54 | // } 55 | // warpReduceSum(arr); 56 | // } 57 | // } 58 | 59 | /* 60 | How this works: 61 | Instead of having one block calculate only one row of the output matrix, one block 62 | will compute `TILE_SIZE` number of rows. This way we will have fewer blocks, and more 63 | computations per block. 2D threads will process elements. `tx` will process elements and 64 | `ty` will process rows in a block 65 | 66 | We will need partial block-wide reduction with width `tx` threads participating in the reduction 67 | for each row's maximum and norm value. 
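A small sketch of the mapping this describes, as I read the description above (`num_cols`, the number of tx threads cooperating on one row, is an assumption and not a symbol from this kernel):

// Sketch: each block owns TILE_SIZE consecutive rows; ty selects the row
// within the tile and tx strides across its columns, feeding the same online
// max/norm update as the earlier kernels.
int row = blockIdx.x * TILE_SIZE + ty;
float running_max = -INFINITY, running_norm = 0.0f;
for (int col = tx; col < N; col += num_cols) {
    float x = xd[row * N + col];
    if (x > running_max) {
        running_norm *= expf(running_max - x);
        running_max = x;
    }
    running_norm += expf(x - running_max);
}
// a per-row partial reduction over the num_cols threads follows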
68 | */ 69 | __global__ void softmax_kernel_5(float* __restrict__ xd, float* __restrict__ resd, int M, int N) { 70 | int bx = blockDim.x; 71 | 72 | // ty equals TILE_SIZE 73 | int ty = threadIdx.y; 74 | int tx = threadIdx.x; 75 | 76 | // result matrix's row 77 | int row = (bx * TILE_SIZE + ty); 78 | if (row >= M) return; 79 | 80 | // one for each row 81 | float local_maxs[TILE_SIZE] = {-INFINITY}; 82 | float local_norms[TILE_SIZE] = {0.f}; 83 | float x[TILE_SIZE] = {0.f}; 84 | 85 | for (int i = tx; i < N; i += bx) { 86 | #pragma unroll 87 | for (int j = 0; j < TILE_SIZE; j++) { 88 | x[j] = xd[row * N + i]; 89 | if (x[j] > local_maxs[j]) { 90 | local_norms[j] *= expf(local_maxs[j] - x[j]); 91 | local_maxs[j] = x[j]; 92 | } 93 | local_norms[j] += expf(x[j] - local_maxs[j]); 94 | } 95 | } 96 | __syncthreads(); 97 | 98 | for(int tile = 0; tile < TILE_SIZE; tile++) { 99 | float lm = local_maxs[tile]; 100 | local_maxs[tile] = warpReduceMax(lm); 101 | local_norms[tile] *= expf(lm - local_maxs[tile]); 102 | local_norms[tile] = warpReduceSum(local_norms[tile]); 103 | } 104 | 105 | // finally, compute softmax 106 | for (int i = tx; i < N; i += bx) 107 | for (int tile = 0; tile < TILE_SIZE; tile++) 108 | resd[row * N + i] = expf(xd[row * N + i] - local_maxs[tile]) / local_norms[tile]; 109 | } 110 | 111 | /* 112 | Runs the online softmax kernel: `id = 5` 113 | */ 114 | void run_kernel_5(float* __restrict__ matd, float* __restrict__ resd, int M, int N) { 115 | // grid size and block size for this kernel 116 | // change as necessary 117 | int num_threads_x = 32; 118 | 119 | dim3 block_size(TILE_SIZE, num_threads_x); 120 | dim3 grid_size(CEIL_DIV(M, TILE_SIZE)); 121 | 122 | cudaEvent_t start, stop; 123 | CUDA_CHECK(cudaEventCreate(&start)); 124 | CUDA_CHECK(cudaEventCreate(&stop)); 125 | float ms = 0.f; 126 | 127 | CUDA_CHECK(cudaEventRecord(start)); 128 | softmax_kernel_5<<>>(matd, resd, M, N); 129 | CUDA_CHECK(cudaEventRecord(stop)); 130 | CUDA_CHECK(cudaEventSynchronize(stop)); 131 | CUDA_CHECK(cudaEventElapsedTime(&ms, start, stop)); 132 | printf(">> Kernel execution time: %f ms\n", ms); 133 | 134 | CUDA_CHECK(cudaEventDestroy(start)); 135 | CUDA_CHECK(cudaEventDestroy(stop)); 136 | } -------------------------------------------------------------------------------- /matvec/src.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | #include "coalesced_warp_2.cuh" 8 | #include "coalesced_warpblock_3.cuh" 9 | #include "cublas_0.cuh" 10 | #include "naive_1.cuh" 11 | #include "utils.cuh" 12 | #include "vectorized_4.cuh" 13 | 14 | /* 15 | Benchmarks a kernel against cuBLAS for different sizes 16 | */ 17 | void benchmark_kernel_for_sizes(int minM, int maxM, float THEORETICAL_MAX_GFLOPS, float THEORETICAL_MAX_MEMORY_BANDWIDTH) { 18 | FILE *gflops_file = fopen("benchmarks/kernel_4_vs_cublas-gflops.txt", "w"); 19 | FILE *memory_file = fopen("benchmarks/kernel_4_vs_cublas-memory.txt", "w"); 20 | 21 | if (gflops_file == NULL) { 22 | perror("Error opening the file for GFLOPS.\n"); 23 | } 24 | if (memory_file == NULL) { 25 | perror("Error opening the file for Memory Bandwidth.\n"); 26 | } 27 | 28 | for (int M = minM; M <= maxM; M *= 2) { 29 | int N = 2 * M; // matrix size (M, N) 30 | 31 | printf("------------ Running benchmark for M = %d ---------------\n", M); 32 | 33 | size_t matsize = M * N; // (M, N) 34 | size_t vecsize = N; // (N, 1) 35 | size_t mat_totalsize = matsize * sizeof(float); 36 | size_t 
vec_totalsize = vecsize * sizeof(float); 37 | 38 | // allocate host 39 | float *mat = (float *)malloc(mat_totalsize); 40 | float *vec = (float *)malloc(vec_totalsize); 41 | float *res = (float *)malloc(M * sizeof(float)); 42 | 43 | for (size_t i = 0; i < matsize; i++) { 44 | mat[i] = random_normal_clamped(-10.f, 10.f); 45 | // hacky way to init the vector as well 46 | if (i < vecsize) { 47 | vec[i] = random_normal_clamped(-10.f, 10.f); 48 | } 49 | } 50 | 51 | cudaEvent_t start, stop; 52 | CUDA_CHECK(cudaEventCreate(&start)); 53 | CUDA_CHECK(cudaEventCreate(&stop)); 54 | float ms = 0.0f; 55 | 56 | // allocate device 57 | float *matd, *vecd, *resd; 58 | cudaEventRecord(start); 59 | CUDA_CHECK(cudaMalloc((void **)&matd, mat_totalsize)); 60 | CUDA_CHECK(cudaMalloc((void **)&vecd, vec_totalsize)); 61 | CUDA_CHECK(cudaMalloc((void **)&resd, M * sizeof(float))); 62 | cudaEventRecord(stop); 63 | cudaEventSynchronize(stop); 64 | cudaEventElapsedTime(&ms, start, stop); 65 | printf(">> GPU allocation time: %f ms\n", ms); 66 | 67 | // copy host to device 68 | cudaEventRecord(start); 69 | CUDA_CHECK(cudaMemcpy(matd, mat, mat_totalsize, cudaMemcpyHostToDevice)); 70 | CUDA_CHECK(cudaMemcpy(vecd, vec, vec_totalsize, cudaMemcpyHostToDevice)); 71 | CUDA_CHECK(cudaMemcpy(resd, res, M * sizeof(float), cudaMemcpyHostToDevice)); 72 | cudaEventRecord(stop); 73 | cudaEventSynchronize(stop); 74 | cudaEventElapsedTime(&ms, start, stop); 75 | printf(">> Host to device transfer time: %f ms\n", ms); 76 | 77 | // run cuBLAS kernel and write results to file 78 | float mscub = run_kernel_cublas_sgemv(matd, vecd, resd, M, N, THEORETICAL_MAX_GFLOPS, THEORETICAL_MAX_MEMORY_BANDWIDTH); 79 | float gflopscub = compute_gflops(M, N, mscub); 80 | float mem_bandcub = compute_peak_memory_bandwidth(M, N, mscub, THEORETICAL_MAX_MEMORY_BANDWIDTH); 81 | 82 | // run custom kernel and write results to file 83 | ms = run_kernel_vectorized_sgmev(matd, vecd, resd, M, N, THEORETICAL_MAX_GFLOPS, THEORETICAL_MAX_MEMORY_BANDWIDTH); 84 | float gflops = compute_gflops(M, N, ms); 85 | float mem_band = compute_peak_memory_bandwidth(M, N, ms, THEORETICAL_MAX_MEMORY_BANDWIDTH); 86 | 87 | fprintf(gflops_file, "%d %f %f\n", M, gflops, gflopscub); 88 | fprintf(memory_file, "%d %f %f\n", M, mem_band, mem_bandcub); 89 | 90 | // copy device to host 91 | cudaEventRecord(start); 92 | CUDA_CHECK(cudaMemcpy(res, resd, M * sizeof(float), cudaMemcpyDeviceToHost)); 93 | cudaEventRecord(stop); 94 | cudaEventSynchronize(stop); 95 | cudaEventElapsedTime(&ms, start, stop); 96 | printf(">> Device to host transfer time: %f ms\n", ms); 97 | 98 | // cleanup 99 | cudaFree(matd); 100 | cudaFree(vecd); 101 | cudaFree(resd); 102 | free(mat); 103 | free(vec); 104 | free(res); 105 | } 106 | 107 | fclose(gflops_file); 108 | fclose(memory_file); 109 | } 110 | 111 | int main() { 112 | cudaDeviceProp prop; 113 | cudaGetDeviceProperties(&prop, 0); 114 | int cudaCores = prop.multiProcessorCount * 128; 115 | float clockGHz = prop.clockRate / 1e6; 116 | 117 | float THEORETICAL_MAX_GFLOPS = cudaCores * clockGHz * 2; 118 | float THEORETICAL_MAX_MEMORY_BANDWIDTH = (2 * prop.memoryClockRate * prop.memoryBusWidth) / (8.0 * 1e6); 119 | 120 | benchmark_kernel_for_sizes(4096, 4096, THEORETICAL_MAX_GFLOPS, THEORETICAL_MAX_MEMORY_BANDWIDTH); 121 | } -------------------------------------------------------------------------------- /softmax/kernels/vectorized_4.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 
| 6 | #include "cuda_utils.cuh" 7 | 8 | /* 9 | This kernel implements an online softmax operation on a matrix of size (M, N). 10 | The softmax operation is performed on the last dimension of the matrix. 11 | 12 | How this works: 13 | Instead of accessing shared memory and having sync barrier overhead, we will use warp-level primitives (then 14 | block-level) for performing max and sum reductions. The benefit is: it is faster than shared 15 | memory access and also does not need syncing since each warp (group of 32 threads) execute 16 | an instuction parallely on GPU so no chance of race conditions. 17 | 18 | We will also use vectorized loads and stores. 19 | */ 20 | __global__ void softmax_kernel_4(float* __restrict__ xd, float* __restrict__ resd, int M, int N) { 21 | // max and norm reduction will happen in shared memory (static) 22 | extern __shared__ float smem[]; 23 | 24 | int row = blockIdx.x; 25 | int tid = threadIdx.x; 26 | if (row >= M) return; 27 | 28 | float* input_row = xd + row * N; 29 | float* output_row = resd + row * N; 30 | float local_max = -INFINITY; 31 | float local_norm = 0.0f; 32 | 33 | // cast as float4 34 | int n_float4s = N / 4; 35 | int tail = N % 4; 36 | float4* input_row_vec = reinterpret_cast(input_row); 37 | float4* output_row_vec = reinterpret_cast(output_row); 38 | float maxval = -INFINITY; 39 | 40 | #pragma unroll 41 | for (int i = tid; i < n_float4s; i += blockDim.x) { 42 | float4 elem = input_row_vec[i]; 43 | 44 | maxval = fmaxf(maxval, elem.x); 45 | maxval = fmaxf(maxval, elem.y); 46 | maxval = fmaxf(maxval, elem.z); 47 | maxval = fmaxf(maxval, elem.w); 48 | if (maxval > local_max) { 49 | local_norm *= __expf(local_max - maxval); 50 | local_max = maxval; 51 | } 52 | local_norm += __expf(elem.x - maxval); 53 | local_norm += __expf(elem.y - maxval); 54 | local_norm += __expf(elem.z - maxval); 55 | local_norm += __expf(elem.w - maxval); 56 | } 57 | 58 | // handle extra row elements 59 | if (tail && tid < tail) { 60 | float val = input_row[n_float4s * 4 + tid]; 61 | if (val > local_max) { 62 | local_norm *= __expf(local_max - val); 63 | local_max = val; 64 | } 65 | local_norm += __expf(val - local_max); 66 | } 67 | __syncthreads(); 68 | 69 | // warp level reduction using XOR shuffle ('exchanges' the values in the threads) 70 | // note: if there are 256 threads in one block (8 warps of 32 threads each) 71 | // the following for loop reduces the value in all the 8 warps 72 | // the 8 warps contain the 8 maximum values of the 32 threads that reside in those warps 73 | // float val = smem[tid]; 74 | blockReduceMax(local_max, smem, -INFINITY); 75 | __syncthreads(); 76 | 77 | // we got the global row max now 78 | float row_max = smem[0]; 79 | __syncthreads(); 80 | 81 | // each thread will have its own local_norm 82 | // we will store the corrected local_norm and reduce it 83 | // same reduction algorithm as above, but instead of max reduction 84 | // we do a sum reduction i.e. 
we accumulate the values 85 | float val = local_norm * expf(local_max - row_max); 86 | blockReduceSum(val, smem, 0.0f); 87 | __syncthreads(); 88 | 89 | float row_norm = smem[0]; 90 | __syncthreads(); 91 | 92 | // finally, compute softmax 93 | #pragma unroll 94 | for (int i = tid; i < n_float4s; i += blockDim.x) { 95 | float4 elem = input_row_vec[i]; 96 | elem.x = __expf(elem.x - row_max) / row_norm; 97 | elem.y = __expf(elem.y - row_max) / row_norm; 98 | elem.z = __expf(elem.z - row_max) / row_norm; 99 | elem.w = __expf(elem.w - row_max) / row_norm; 100 | 101 | output_row_vec[i] = elem; 102 | } 103 | // write tail elements 104 | if (tail && tid < tail) 105 | { 106 | float val = input_row[n_float4s * 4 + tid]; 107 | output_row[n_float4s * 4 + tid] = __expf(val - row_max) / row_norm; 108 | } 109 | } 110 | 111 | /* 112 | Runs the online softmax kernel: `id = 4` 113 | */ 114 | float run_kernel_4(float* __restrict__ matd, float* __restrict__ resd, int M, int N) { 115 | // grid size and block size for this kernel 116 | // change as necessary 117 | dim3 block_size(1024); 118 | dim3 grid_size(M); 119 | 120 | int warp_size = 32; 121 | size_t smem_size = CEIL_DIV(block_size.x, warp_size) * sizeof(float); 122 | 123 | cudaEvent_t start, stop; 124 | CUDA_CHECK(cudaEventCreate(&start)); 125 | CUDA_CHECK(cudaEventCreate(&stop)); 126 | float ms = 0.f; 127 | 128 | CUDA_CHECK(cudaEventRecord(start)); 129 | softmax_kernel_4<<>>(matd, resd, M, N); 130 | CUDA_CHECK(cudaEventRecord(stop)); 131 | CUDA_CHECK(cudaEventSynchronize(stop)); 132 | CUDA_CHECK(cudaEventElapsedTime(&ms, start, stop)); 133 | printf(">> Kernel execution time: %f ms\n", ms); 134 | 135 | CUDA_CHECK(cudaEventDestroy(start)); 136 | CUDA_CHECK(cudaEventDestroy(stop)); 137 | 138 | return ms; 139 | } -------------------------------------------------------------------------------- /softmax/kernels/shfl_3.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "cuda_utils.cuh" 6 | 7 | /* 8 | This kernel implements an online softmax operation on a matrix of size (M, N). 9 | The softmax operation is performed on the last dimension of the matrix. 10 | 11 | How this works: 12 | This one is largely similar to the above kernel. The difference is instead of accessing 13 | shared memory and having sync barrier overhead, we will use warp-level primitives (then 14 | block-level) for performing max and sum reductions. The benefit is: it is faster than shared 15 | memory access and also does not need syncing since each warp (group of 32 threads) execute 16 | an instuction parallely on GPU so no chance of race conditions. 
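For readers new to shuffle intrinsics, a minimal sketch of the primitive this kernel builds on (the helper name is mine; the kernel below inlines the same loop): __shfl_down_sync(mask, val, offset) hands each lane the value held by the lane `offset` positions above it, so halving the offset from 16 down to 1 folds all 32 lane values into lane 0 with no shared memory and no __syncthreads().

// After the loop, lane 0 of the warp holds the maximum over all 32 lanes.
__device__ __forceinline__ float warp_max_sketch(float val) {
    for (int offset = 16; offset > 0; offset /= 2) {
        val = fmaxf(val, __shfl_down_sync(0xffffffff, val, offset));
    }
    return val;
}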
17 | */ 18 | __global__ void softmax_kernel_3(float* xd, float* resd, int M, int N) { 19 | // max and norm reduction will happen in shared memory (static) 20 | __shared__ float smem[1024]; 21 | 22 | int row = blockIdx.x; 23 | int tid = threadIdx.x; 24 | // number of threads in a warp 25 | unsigned int warp_size = 32; 26 | if (row >= M) return; 27 | 28 | float* input_row = xd + row * N; 29 | float* output_row = resd + row * N; 30 | float local_max = -INFINITY; 31 | float local_norm = 0.0f; 32 | 33 | for (int i = tid; i < N; i += blockDim.x) { 34 | float x = input_row[i]; 35 | if (x > local_max) { 36 | local_norm *= expf(local_max - x); 37 | local_max = x; 38 | } 39 | local_norm += expf(x - local_max); 40 | } 41 | __syncthreads(); 42 | 43 | // each thread will have its own local max 44 | // we store it in shared memory for reduction 45 | // smem[tid] = local_max; 46 | // __syncthreads(); 47 | 48 | // warp level reduction using XOR shuffle ('exchanges' the values in the threads) 49 | // note: if there are 256 threads in one block (8 warps of 32 threads each) 50 | // the following for loop reduces the value in all the 8 warps 51 | // the 8 warps contain the 8 maximum values of the 32 threads that reside in those warps 52 | // float val = smem[tid]; 53 | float val = local_max; 54 | for (int offset = warp_size / 2; offset > 0; offset /= 2) { 55 | val = fmaxf(val, __shfl_down_sync(0xffffffff, val, offset)); 56 | } 57 | 58 | // when blockDim is greater than 32, we need to do a block level reduction 59 | // AFTER warp level reductions since we have the 8 maximum values that needs to be reduced again 60 | // the global max will be stored in the first warp 61 | if (blockDim.x > warp_size) { 62 | if (tid % warp_size == 0) { 63 | // which warp are we at? 64 | // store the value in its first thread index 65 | smem[tid / warp_size] = val; 66 | } 67 | __syncthreads(); 68 | 69 | // first warp will do global reduction only 70 | // this is possible because we stored the values in the shared memory 71 | // so the threads in the first warp will read from it and then reduce 72 | if (tid < warp_size) { 73 | val = (tid < CEIL_DIV(blockDim.x, warp_size)) ? smem[tid] : -INFINITY; 74 | for (int offset = warp_size / 2; offset > 0; offset /= 2) { 75 | val = fmaxf(val, __shfl_down_sync(0xffffffff, val, offset)); 76 | } 77 | if (tid == 0) smem[0] = val; 78 | } 79 | } else { 80 | // this is for when the number of threads in a block are not 81 | // greater than the warp size, in that case we already reduced 82 | // so we can store the value 83 | if (tid == 0) smem[0] = val; 84 | } 85 | __syncthreads(); 86 | 87 | // we got the global row max now 88 | float row_max = smem[0]; 89 | __syncthreads(); 90 | 91 | // each thread will have its own local_norm 92 | // we will store the corrected local_norm in the shared memory 93 | // smem[tid] = local_norm * expf(local_max - row_max); 94 | // __syncthreads(); 95 | 96 | // same reduction algorithm as above, but instead of max reduction 97 | // we do a sum reduction i.e. we accumulate the values 98 | // val = smem[tid]; 99 | val = local_norm * expf(local_max - row_max); 100 | for (int offset = warp_size / 2; offset > 0; offset /= 2) { 101 | val += __shfl_down_sync(0xffffffff, val, offset); 102 | } 103 | 104 | if (blockDim.x > warp_size) { 105 | if (tid % warp_size == 0) { 106 | smem[tid / warp_size] = val; 107 | } 108 | __syncthreads(); 109 | 110 | // first warp will do global reduction 111 | if (tid < warp_size) { 112 | val = (tid < CEIL_DIV(blockDim.x, warp_size)) ? 
smem[tid] : 0.0f; 113 | for (int offset = warp_size / 2; offset > 0; offset /= 2) { 114 | val += __shfl_down_sync(0xffffffff, val, offset); 115 | } 116 | if (tid == 0) smem[0] = val; 117 | } 118 | } else { 119 | if (tid == 0) smem[0] = val; 120 | } 121 | __syncthreads(); 122 | 123 | float row_norm = smem[0]; 124 | __syncthreads(); 125 | 126 | // finally, compute softmax 127 | for (int i = tid; i < N; i += blockDim.x) { 128 | output_row[i] = expf(input_row[i] - row_max) / row_norm; 129 | } 130 | } 131 | 132 | /* 133 | Runs the online softmax kernel: `id = 3` 134 | */ 135 | void run_kernel_3(float* __restrict__ matd, float* __restrict__ resd, int M, int N) { 136 | // grid size and block size for this kernel 137 | // change as necessary 138 | dim3 block_size(1024); 139 | dim3 grid_size(M); 140 | 141 | cudaEvent_t start, stop; 142 | CUDA_CHECK(cudaEventCreate(&start)); 143 | CUDA_CHECK(cudaEventCreate(&stop)); 144 | float ms = 0.f; 145 | 146 | CUDA_CHECK(cudaEventRecord(start)); 147 | softmax_kernel_3<<>>(matd, resd, M, N); 148 | CUDA_CHECK(cudaEventRecord(stop)); 149 | CUDA_CHECK(cudaEventSynchronize(stop)); 150 | CUDA_CHECK(cudaEventElapsedTime(&ms, start, stop)); 151 | printf(">> Kernel execution time: %f ms\n", ms); 152 | 153 | CUDA_CHECK(cudaEventDestroy(start)); 154 | CUDA_CHECK(cudaEventDestroy(stop)); 155 | } -------------------------------------------------------------------------------- /attention/attn.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #define CUDA_CHECK(ans) \ 10 | { \ 11 | cudaAssert((ans), __FILE__, __LINE__); \ 12 | } 13 | inline void cudaAssert(cudaError_t code, const char *file, int line) { 14 | if (code != cudaSuccess) { 15 | fprintf(stderr, "CUDA error %s: %s at %s: %d\n", 16 | cudaGetErrorName(code), cudaGetErrorString(code), 17 | file, line); 18 | exit(code); 19 | } 20 | } 21 | #define CEIL_DIV(x, y) ((x) >= 0 ? (((x) + (y) - 1) / (y)) : ((x) / (y))) 22 | 23 | struct SumOp { 24 | __device__ __forceinline__ float operator()(float a, float b) const { 25 | return a + b; 26 | } 27 | __device__ __forceinline__ float identity() const { 28 | return 0.0f; 29 | } 30 | }; 31 | 32 | struct MaxOp { 33 | __device__ __forceinline__ float operator()(float a, float b) const { 34 | return fmaxf(a, b); 35 | } 36 | __device__ __forceinline__ float identity() const { 37 | return -INFINITY; 38 | } 39 | }; 40 | 41 | template 42 | __device__ __forceinline__ float warpReduce(float val, Op op) { 43 | for (int offset = warpSize / 2; offset > 0; offset /= 2) { 44 | val = op(val, __shfl_down_sync(0xffffffff, val, offset)); 45 | } 46 | 47 | return val; 48 | } 49 | 50 | template 51 | __device__ __forceinline__ void blockReduce(float val, float *smem, int tid, int blockDimX, Op op) { 52 | // 1. do warpReduce sum 53 | val = warpReduce(val, op); 54 | 55 | // 2. do blockReduce sum 56 | if (blockDimX > warpSize) { 57 | int lane = tid % warpSize; 58 | int wid = tid / warpSize; 59 | if (lane == 0) { 60 | smem[wid] = val; 61 | } 62 | __syncthreads(); 63 | 64 | if (tid < warpSize) { 65 | val = tid < CEIL_DIV(blockDimX, warpSize) ? 
smem[tid] : op.identity(); 66 | val = warpReduce(val, op); 67 | if (tid == 0) smem[0] = val; 68 | } 69 | } else { 70 | if (tid == 0) smem[0] = val; 71 | } 72 | __syncthreads(); 73 | } 74 | 75 | __global__ void scale_kernel_inplace(float *X, float scale_factor, int bs, int nh, int sl, int ed, int n_elements) { 76 | assert(n_elements % 4 == 0); 77 | 78 | int idx = blockDim.x * blockIdx.x + threadIdx.x; 79 | int n_float4s = n_elements / 4; 80 | 81 | float4 *inputs = reinterpret_cast(X); 82 | 83 | if (idx < n_float4s) { 84 | float4 elem = inputs[idx]; 85 | elem.x *= scale_factor; 86 | elem.y *= scale_factor; 87 | elem.z *= scale_factor; 88 | elem.w *= scale_factor; 89 | 90 | inputs[idx] = elem; 91 | } 92 | } 93 | 94 | __global__ void softmax_kernel_inplace(float *__restrict__ X, int M, int N) { 95 | assert(N % 4 == 0); 96 | 97 | // max and norm reduction will happen in shared memory (static) 98 | extern __shared__ float smem[]; 99 | 100 | int row = blockIdx.x; 101 | int tid = threadIdx.x; 102 | if (row >= M) return; 103 | 104 | float *input_row = X + row * N; 105 | float local_max = -INFINITY; 106 | float local_norm = 0.0f; 107 | 108 | // cast as float4 109 | int n_float4s = N / 4; 110 | float4 *input_row_vec = reinterpret_cast(input_row); 111 | 112 | float maxval = -INFINITY; 113 | #pragma unroll 114 | for (int i = tid; i < n_float4s; i += blockDim.x) { 115 | float4 elem = input_row_vec[i]; 116 | 117 | maxval = fmaxf(maxval, elem.x); 118 | maxval = fmaxf(maxval, elem.y); 119 | maxval = fmaxf(maxval, elem.z); 120 | maxval = fmaxf(maxval, elem.w); 121 | if (maxval > local_max) { 122 | local_norm *= __expf(local_max - maxval); 123 | local_max = maxval; 124 | } 125 | local_norm += __expf(elem.x - maxval); 126 | local_norm += __expf(elem.y - maxval); 127 | local_norm += __expf(elem.z - maxval); 128 | local_norm += __expf(elem.w - maxval); 129 | } 130 | 131 | blockReduce(local_max, smem, tid, blockDim.x, MaxOp()); 132 | float row_max = smem[0]; 133 | 134 | float adjusted = local_norm * expf(local_max - row_max); 135 | blockReduce(adjusted, smem, tid, blockDim.x, SumOp()); 136 | float row_norm = smem[0]; 137 | 138 | // finally, compute softmax 139 | #pragma unroll 140 | for (int i = tid; i < n_float4s; i += blockDim.x) { 141 | float4 elem = input_row_vec[i]; 142 | elem.x = __expf(elem.x - row_max) / row_norm; 143 | elem.y = __expf(elem.y - row_max) / row_norm; 144 | elem.z = __expf(elem.z - row_max) / row_norm; 145 | elem.w = __expf(elem.w - row_max) / row_norm; 146 | 147 | input_row_vec[i] = elem; 148 | } 149 | } 150 | 151 | torch::Tensor attention_forward(uint64_t handle, torch::Tensor Q, torch::Tensor K, torch::Tensor V) { 152 | cublasHandle_t cu_handle = reinterpret_cast(handle); 153 | cublasSetMathMode(cu_handle, CUBLAS_TF32_TENSOR_OP_MATH); 154 | 155 | int bs = Q.size(0); 156 | int nh = Q.size(1); 157 | int sl = Q.size(2); 158 | int ed = Q.size(3); 159 | 160 | int n_elements = bs * nh * sl * ed; 161 | torch::Tensor pre = torch::empty({bs, nh, sl, sl}, torch::TensorOptions().device(torch::kCUDA)); 162 | torch::Tensor out = torch::zeros_like(Q); 163 | 164 | const float alpha = (1 / sqrt(ed)); 165 | const float beta = 0.0f; 166 | cublasSgemmStridedBatched( 167 | cu_handle, 168 | CUBLAS_OP_T, CUBLAS_OP_N, 169 | sl, sl, ed, 170 | &alpha, 171 | K.data_ptr(), ed, sl * ed, 172 | Q.data_ptr(), ed, sl * ed, 173 | &beta, 174 | pre.data_ptr(), sl, sl * sl, 175 | bs * nh); 176 | 177 | // softmax 178 | dim3 block_dim(128); 179 | dim3 grid_dim_softmax(bs * nh * sl); 180 | size_t smem_size = 
CEIL_DIV(block_dim.x, 32) * sizeof(float); 181 | softmax_kernel_inplace<<>>(pre.data_ptr(), bs * nh * sl, sl); 182 | 183 | // update alpha here for no scaling 184 | const float alpha_sv = 1.0f; 185 | cublasSgemmStridedBatched( 186 | cu_handle, 187 | CUBLAS_OP_N, CUBLAS_OP_N, 188 | ed, sl, sl, 189 | &alpha_sv, 190 | V.data_ptr(), ed, sl * ed, 191 | pre.data_ptr(), sl, sl * sl, 192 | &beta, 193 | out.data_ptr(), ed, sl * ed, 194 | bs * nh); 195 | 196 | return out; 197 | } -------------------------------------------------------------------------------- /matmul/main.cu: -------------------------------------------------------------------------------- 1 | // Matrix Multiplication (xGEMM) kernels 2 | // Note: this file might change often as i learn more about CUDA and kernels in general 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | #define CUDA_CHECK(ans) \ 13 | { \ 14 | cudaAssert((ans), __FILE__, __LINE__); \ 15 | } 16 | inline void cudaAssert(cudaError_t code, const char* file, int line) { 17 | if (code != cudaSuccess) { 18 | fprintf(stderr, "CUDA error %s: %s at %s: %d\n", 19 | cudaGetErrorName(code), cudaGetErrorString(code), 20 | file, line); 21 | exit(code); 22 | } 23 | } 24 | #define CEIL_DIV(x, y) ((x) >= 0 ? (((x) + (y) - 1) / (y)) : ((x) / (y))) 25 | #define M_PI 3.14159265358979323846f 26 | #define TILE_SIZE 32 27 | 28 | #define BM 64 29 | #define BK 8 30 | #define BN 64 31 | #define COARSE_FACTOR 8 32 | 33 | /* 34 | Naive xGEMM kernel: 35 | 36 | - 2D blocks, 2D threads 37 | - Each thread calculates one element of the output matrix C 38 | - No shared memory, only global memory access 39 | */ 40 | __global__ void naive_xgemm_kernel(float* __restrict__ Ad, float* __restrict__ Bd, float* __restrict__ Cd, int M, int N, int K) { 41 | // for coalesced memory access: 42 | // maps rows to y-direction, and cols to x-direction 43 | int row = blockDim.y * blockIdx.y + threadIdx.y; 44 | int col = blockDim.x * blockIdx.x + threadIdx.x; 45 | 46 | if (row < M && col < N) { 47 | float acc = 0.0f; 48 | for (int k = 0; k < K; k++) { 49 | acc += Ad[row * K + k] * Bd[k * N + col]; 50 | } 51 | Cd[row * N + col] = acc; 52 | } 53 | } 54 | 55 | /* 56 | Tiled xGEMM kernel: 57 | 58 | - Each block calculates a "tile" of the output matrix C 59 | > Here the indices for C, that each block (bx, by) computes would be: 60 | row = by * TILE_SIZE + ty; 61 | col = bx * TILE_SIZE + tx; 62 | 63 | - Each block will loop over the tiles in the common dimension. 64 | 65 | - The threads within each block loads the elements in shared memory 66 | > Thread (tx, ty) will load the corresponding elements from A and B 67 | shared_A[ty][tx] = A[row * K + (tile_num * TILE_SIZE + tx)] 68 | shared_B[ty][tx] = B[(tile_num * TILE_SIZE + ty) * N + col] 69 | 70 | Note: from A, the same row is loaded and from B the same column is loaded 71 | 72 | - Then they accumulate the dot product in a variable for the common dimension 73 | - So block (bx, by) has completed computing the tile (bx, by) of C. 
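A concrete instance of the index formulas above, with TILE_SIZE = 32 as defined in this file: block (bx, by) = (2, 1) and thread (tx, ty) = (5, 7) compute C[row, col] with row = 1*32 + 7 = 39 and col = 2*32 + 5 = 69. On tile_num = 3 that thread loads A[39*K + (3*32 + 5)] = A[39*K + 101] into shared_A[7][5] and B[(3*32 + 7)*N + 69] = B[103*N + 69] into shared_B[7][5], and the dot product over that tile accumulates a_smem[7][i] * b_smem[i][5] for i = 0..31.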
74 | */ 75 | __global__ void tiled_xgemm_kernel(float* __restrict__ Ad, float* __restrict__ Bd, float* __restrict__ Cd, int M, int N, int K) { 76 | int ty = threadIdx.y; 77 | int tx = threadIdx.x; 78 | 79 | int by = blockIdx.y; 80 | int bx = blockIdx.x; 81 | 82 | // indices of C[row, col] 83 | int row = by * TILE_SIZE + ty; 84 | int col = bx * TILE_SIZE + tx; 85 | 86 | // tile that will be loaded by THIS block 87 | __shared__ float a_smem[TILE_SIZE][TILE_SIZE]; 88 | __shared__ float b_smem[TILE_SIZE][TILE_SIZE]; 89 | 90 | // final dot product sum 91 | float acc = 0.f; 92 | 93 | // THIS block will loop over the tiles in common dimension 94 | for (int tile_num = 0; tile_num < CEIL_DIV(K, TILE_SIZE); tile_num++) { 95 | int offset = tile_num * TILE_SIZE; 96 | 97 | // out of bounds check 98 | // same row, different column for A 99 | if (row < M && (offset + tx) < K) 100 | a_smem[ty][tx] = Ad[row * K + offset + tx]; 101 | else 102 | a_smem[ty][tx] = 0.f; 103 | 104 | // different row, same column for B 105 | if ((offset + ty) < K && col < N) 106 | b_smem[ty][tx] = Bd[(offset + ty) * N + col]; 107 | else 108 | b_smem[ty][tx] = 0.f; 109 | __syncthreads(); 110 | 111 | // dot product and accumulate 112 | for (int i = 0; i < TILE_SIZE; i++) { 113 | acc += a_smem[ty][i] * b_smem[i][tx]; 114 | } 115 | __syncthreads(); 116 | } 117 | 118 | // write the final output after looping over all tiles 119 | if (row < M && col < N) { 120 | Cd[row * N + col] = acc; 121 | } 122 | } 123 | 124 | /* 125 | Tiled xGEMM kernel with + 1D blocktiling 126 | 127 | - Each thread calculates more than one element (one column of output matrix C) 128 | - Tiles of A has shape (BM, BK) and tile of B has shape (BK, BN) 129 | - Threads process COARSE_FACTOR rows at a time 130 | */ 131 | __global__ void tiled_xgemm_1d_coarse_kernel(float* __restrict__ Ad, float* __restrict__ Bd, float* __restrict__ Cd, int M, int N, int K) { 132 | int by = blockIdx.y; 133 | int bx = blockIdx.x; 134 | 135 | // for within each tile + for loading B's tile 136 | int ty = threadIdx.x / BN; 137 | int tx = threadIdx.x % BN; 138 | 139 | // for loading A's tile 140 | int aty = threadIdx.x / BK; 141 | int atx = threadIdx.x % BK; 142 | 143 | // working on C[row, col] 144 | int row = by * BM + (ty * COARSE_FACTOR); 145 | int col = bx * BN + tx; 146 | 147 | // shared memory for A and B for computing tiles 148 | __shared__ float a_smem[BM * BK]; 149 | __shared__ float b_smem[BK * BN]; 150 | 151 | float acc[COARSE_FACTOR] = {0.f}; 152 | 153 | for (int tile = 0; tile < K; tile += BK) { 154 | // load tiles into shared memory for both A and B 155 | if ((by * BM + aty) < M && (tile + atx) < K) 156 | a_smem[aty * BK + atx] = Ad[(by * BM + aty) * K + (tile + atx)]; 157 | else 158 | a_smem[aty * BK + atx] = 0.f; 159 | if ((tile + ty) < K && (bx * BN + tx) < N) 160 | b_smem[ty * BN + tx] = Bd[(tile + ty) * N + (bx * BN + tx)]; 161 | else 162 | b_smem[ty * BN + tx] = 0.f; 163 | __syncthreads(); 164 | 165 | // inner loop: 166 | // each thread computes 8 elements 167 | for (int k = 0; k < BK; k++) { 168 | float b_reg = b_smem[k * BN + tx]; 169 | for (int c = 0; c < COARSE_FACTOR; c++) 170 | acc[c] += a_smem[(ty * COARSE_FACTOR + c) * BK + k] * b_reg; 171 | } 172 | __syncthreads(); 173 | } 174 | 175 | for (int c = 0; c < COARSE_FACTOR; c++) { 176 | if ((row + c) < M && col < N) 177 | Cd[(row + c) * N + col] = acc[c]; 178 | } 179 | } 180 | 181 | void gemm_cpu_naive(float* A, float* B, float* C, int M, int N, int K) { 182 | for (int i = 0; i < M; i++) { 183 | for (int j = 0; j < 
N; j++) { 184 | float sum = 0.f; 185 | for (int k = 0; k < K; k++) { 186 | sum += (A[i * K + k] * B[k * N + j]); 187 | } 188 | C[i * N + j] = sum; 189 | } 190 | } 191 | } 192 | 193 | /* 194 | Helper function to generate a clamped random number sampled from a 195 | normal distribution with mean 0 and std 1 196 | */ 197 | float random_normal_clamped(float min, float max) { 198 | float u1 = (float)rand() / RAND_MAX; 199 | float u2 = (float)rand() / RAND_MAX; 200 | float num = sqrtf(-2.0f * logf(u1)) * cosf(2.0f * M_PI * u2); 201 | if (num < min) 202 | return min; 203 | if (num > max) 204 | return max; 205 | return num; 206 | } 207 | 208 | int main() { 209 | int M = 1024; 210 | int N = 1024; 211 | int K = 1024; 212 | 213 | int a_size = M * K; 214 | int b_size = K * N; 215 | int c_size = M * N; 216 | 217 | printf("Shape A: (%d, %d)\n", M, K); 218 | printf("Shape B: (%d, %d)\n", K, N); 219 | printf("Shape C: (%d, %d)\n", M, N); 220 | 221 | float* A = (float*)malloc(a_size * sizeof(float)); 222 | float* B = (float*)malloc(b_size * sizeof(float)); 223 | float* C = (float*)malloc(c_size * sizeof(float)); 224 | float* C_cpu = (float*)malloc(c_size * sizeof(float)); 225 | 226 | // init the matrices with random values 227 | for (int i = 0; i < a_size; i++) { 228 | A[i] = random_normal_clamped(-10, 10); 229 | } 230 | for (int i = 0; i < b_size; i++) { 231 | B[i] = random_normal_clamped(-10, 10); 232 | } 233 | for (int i = 0; i < b_size; i++) { 234 | C[i] = 0.f; 235 | } 236 | 237 | float *Ad, *Bd, *Cd; 238 | 239 | // uncomment the below lines for tiled_xgemm without coalesce 240 | // dim3 block_size(TILE_SIZE, TILE_SIZE); 241 | // dim3 grid_size(CEIL_DIV(N, block_size.x), CEIL_DIV(M, block_size.y)); 242 | 243 | dim3 block_size(BM * BN / COARSE_FACTOR); 244 | dim3 grid_size(CEIL_DIV(N, BN), CEIL_DIV(M, BM)); 245 | 246 | cudaEvent_t start, stop; 247 | CUDA_CHECK(cudaEventCreate(&start)); 248 | CUDA_CHECK(cudaEventCreate(&stop)); 249 | float ms = 0.0f; 250 | 251 | cudaEventRecord(start); 252 | CUDA_CHECK(cudaMalloc(&Ad, a_size * sizeof(float))); 253 | CUDA_CHECK(cudaMalloc(&Bd, b_size * sizeof(float))); 254 | CUDA_CHECK(cudaMalloc(&Cd, c_size * sizeof(float))); 255 | cudaEventRecord(stop); 256 | cudaEventSynchronize(stop); 257 | cudaEventElapsedTime(&ms, start, stop); 258 | printf(">> GPU allocation time: %f ms\n", ms); 259 | 260 | cudaEventRecord(start); 261 | CUDA_CHECK(cudaMemcpy(Ad, A, a_size * sizeof(float), cudaMemcpyHostToDevice)); 262 | CUDA_CHECK(cudaMemcpy(Bd, B, b_size * sizeof(float), cudaMemcpyHostToDevice)); 263 | cudaEventRecord(stop); 264 | cudaEventSynchronize(stop); 265 | cudaEventElapsedTime(&ms, start, stop); 266 | printf(">> Host to device transfer time: %f ms\n", ms); 267 | 268 | cudaEventRecord(start); 269 | tiled_xgemm_1d_coarse_kernel<<>>(Ad, Bd, Cd, M, N, K); 270 | cudaEventRecord(stop); 271 | cudaEventSynchronize(stop); 272 | cudaEventElapsedTime(&ms, start, stop); 273 | printf(">> Kernel execution time: %f ms\n", ms); 274 | 275 | cudaEventRecord(start); 276 | CUDA_CHECK(cudaMemcpy(C, Cd, c_size * sizeof(float), cudaMemcpyDeviceToHost)); 277 | cudaEventRecord(stop); 278 | cudaEventSynchronize(stop); 279 | cudaEventElapsedTime(&ms, start, stop); 280 | printf(">> Device to host transfer time: %f ms\n", ms); 281 | 282 | printf("\n>> Running GEMM on CPU...\n"); 283 | clock_t ts = clock(); 284 | gemm_cpu_naive(A, B, C_cpu, M, N, K); 285 | clock_t te = clock(); 286 | printf(">> Done\n"); 287 | 288 | float elapsed_time = (te - ts) * 1000 / CLOCKS_PER_SEC; 289 | printf("Elapsed time: 
%.6f ms\n", elapsed_time); 290 | 291 | // check if results match within an error tolerance (eps) 292 | bool match = true; 293 | float eps = 0.0001; 294 | for (int i = 0; i < c_size; i++) { 295 | if (fabs(C_cpu[i] - C[i]) > eps) { 296 | match = false; 297 | break; 298 | } 299 | } 300 | printf("\n>> Results match for CPU and GPU? "); 301 | printf("%s\n", match ? "true" : "false"); 302 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /flash-attention/fa2/flash-attn-2.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #define CUDA_CHECK(ans) \ 9 | { \ 10 | cudaAssert((ans), __FILE__, __LINE__); \ 11 | } 12 | inline void cudaAssert(cudaError_t code, const char* file, int line) { 13 | if (code != cudaSuccess) { 14 | fprintf(stderr, "CUDA error %s: %s at %s: %d\n", 15 | cudaGetErrorName(code), cudaGetErrorString(code), 16 | file, line); 17 | exit(code); 18 | } 19 | } 20 | #define CEIL_DIV(x, y) ((x) >= 0 ? (((x) + (y) - 1) / (y)) : ((x) / (y))) 21 | #define PI 3.1415 22 | 23 | float random_normal_clamped(float min, float max) { 24 | float u1 = (float)rand() / RAND_MAX; 25 | float u2 = (float)rand() / RAND_MAX; 26 | float num = sqrtf(-2.0f * logf(u1)) * cosf(2.0f * PI * u2); 27 | if (num < min) 28 | return min; 29 | if (num > max) 30 | return max; 31 | return num; 32 | } 33 | 34 | /* 35 | Reduction functions on device. These will be inline: 36 | The compiler will replace the call with the code instead of calling the function (overhead) 37 | */ 38 | /* 39 | Utility warp level sum reduction with shuffle instructions 40 | */ 41 | 42 | __device__ __forceinline__ float warpReduceSum(float val, int width) { 43 | for (int offset = width / 2; offset > 0; offset /= 2) { 44 | val += __shfl_down_sync(0xffffffff, val, offset); 45 | } 46 | 47 | return val; 48 | } 49 | 50 | __device__ __forceinline__ float warpReduceMax(float val, int width) { 51 | for (int offset = width / 2; offset > 0; offset /= 2) { 52 | val = max(val, __shfl_down_sync(0xffffffff, val, offset)); 53 | } 54 | 55 | return val; 56 | } 57 | 58 | /* 59 | This kernel uses flash attention algorithm to compute multi-head attention. 60 | Q, K, and V are 4D tensors of shape (batch_size, n_heads, seq_len, embed_dim). 61 | Additional inputs are Tr and Tc which are the tiles each block computes. 62 | The arrays l and m are to save the norm and maximum for the ith tile. 63 | SRAM will have size M and Br = ceil(M / 4d) and Bc = min(ceil(M / 4d), d) 64 | where M is the size of the SRAM. 
65 | */ 66 | template 67 | __global__ void flash_attn_2_kernel(float* Q, float* K, float* V, int N, int d, int Tr, int Tc, float scale, float* L, float* O) { 68 | int tx = threadIdx.x; // Br * Bc threads 69 | 70 | int bx = blockIdx.x; // Batch index 71 | int by = blockIdx.y; // Head index 72 | 73 | // tip to calculate offset: 74 | // count how many elements to skip in the array to reach an index 75 | int qkv_off = (bx * gridDim.y * N * d) + (by * N * d); 76 | int lm_off = (bx * gridDim.y * N) + (by * N); 77 | 78 | // TODO: remove too much shared memory usage 79 | extern __shared__ float smem[]; 80 | float* Qi = smem; 81 | float* Kj = Qi + Br * d; 82 | float* Vj = Kj + Bc * d; 83 | float* Sij = Vj + Bc * d; 84 | float* Oi = Sij + Br * Bc; 85 | float* li = Oi + Br * d; 86 | float* mi = li + Br; 87 | float* mi_new = mi + Br; 88 | 89 | int loads_per_thread_Brd = CEIL_DIV(d, Bc); 90 | int loads_per_thread_Bcd = CEIL_DIV(d, Br); 91 | 92 | for (int i = 0; i < Tr; i++) { 93 | // load Qi and Oi from HBM into SMEM 94 | for (int e = 0; e < loads_per_thread_Brd; e++) { 95 | int idx = e * (Br * Bc) + tx; 96 | int row = idx / d; 97 | if (idx < Br * d && i * Br + row < N) { 98 | int col = idx % d; 99 | Qi[row * d + col] = Q[qkv_off + (i * Br + row) * d + col]; 100 | Oi[row * d + col] = 0.0f; 101 | } 102 | } 103 | 104 | int s_row = tx / Bc; 105 | int s_col = tx % Bc; 106 | 107 | int global_row = (i * Br) + s_row; 108 | 109 | // init li and mi for each row 110 | if (s_col == 0) { 111 | li[s_row] = 0.f; 112 | mi[s_row] = -INFINITY; 113 | mi_new[s_row] = -INFINITY; 114 | } 115 | __syncthreads(); 116 | 117 | for (int j = 0; j < Tc; j++) { 118 | // load Kj and Vj into SMEM from HBM 119 | for (int e = 0; e < loads_per_thread_Bcd; e++) { 120 | int idx = e * (Br * Bc) + tx; 121 | int row = idx / d; 122 | if (idx < Bc * d && j * Bc + row < N) { 123 | int col = idx % d; 124 | Kj[row * d + col] = K[qkv_off + (j * Bc + row) * d + col]; 125 | Vj[row * d + col] = V[qkv_off + (j * Bc + row) * d + col]; 126 | } 127 | } 128 | __syncthreads(); 129 | 130 | // compute S = Qi * Kj^T where shape of S: (Br, Bc) 131 | // TODO: reduce shared memory bank conflicts 132 | float acc = 0.f; 133 | for (int k = 0; k < d; k++) 134 | acc += Qi[s_row * d + k] * Kj[s_col * d + k]; 135 | 136 | acc *= scale; 137 | Sij[s_row * Bc + s_col] = acc; 138 | 139 | // rowmax(S) and rowsum(S) (only one thread per row) 140 | // computes both in a single pass 141 | if (s_col == 0) { 142 | mi[s_row] = mi_new[s_row]; 143 | float row_m = -INFINITY, row_l = 0.f; 144 | for (int c = 0; c < Bc; c++) { 145 | float val = Sij[s_row * Bc + c]; 146 | if (val > row_m) { 147 | row_m = val; 148 | } 149 | } 150 | float maxval = max(mi[s_row], row_m); 151 | 152 | float kahan_comp = 0.0f; 153 | for (int c = 0; c < Bc; c++) { 154 | float exp_val = expf(Sij[s_row * Bc + c] - maxval); 155 | Sij[s_row * Bc + c] = exp_val; 156 | 157 | float y = exp_val - kahan_comp; // subtract the compensation from the new value 158 | float t = row_l + y; // add the compensated value to the sum 159 | kahan_comp = (t - row_l) - y; // compute the new compensation. 
160 | row_l = t; 161 | } 162 | 163 | mi_new[s_row] = maxval; 164 | li[s_row] = expf(mi[s_row] - maxval) * li[s_row] + row_l; 165 | } 166 | __syncthreads(); 167 | 168 | // compute Sij * Vj and do a roll-forward update to O 169 | // Sij (Br, Bc) and Vj (Bc, d) and we have Br * Bc threads 170 | // a thread may compute more than one element's dot product 171 | for (int col = s_col; col < d; col += Bc) { 172 | float acc = 0.f; 173 | float kahan_comp = 0.f; 174 | for (int c = 0; c < Bc; c++) { 175 | float term = Sij[s_row * Bc + c] * Vj[c * d + col]; 176 | 177 | float y = term - kahan_comp; 178 | float t = acc + y; 179 | kahan_comp = (t - acc) - y; 180 | acc = t; 181 | } 182 | 183 | Oi[s_row * d + col] = (mi[s_row] == -INFINITY || mi_new[s_row] == -INFINITY) ? acc : expf(mi[s_row] - mi_new[s_row]) * Oi[s_row * d + col] + acc; 184 | } 185 | } 186 | 187 | for (int col = s_col; col < d; col += Bc) { 188 | if (global_row < N) 189 | O[qkv_off + global_row * d + col] = (1 / li[s_row]) * Oi[s_row * d + col]; 190 | } 191 | 192 | if (s_col == 0) { 193 | L[lm_off + (i * Br) + s_row] = mi_new[s_row] + logf(li[s_row]); 194 | } 195 | __syncthreads(); 196 | } 197 | } 198 | 199 | // Comment the below function to compile and run this file as executable 200 | torch::Tensor fa2_forward(torch::Tensor Q, torch::Tensor K, torch::Tensor V) { 201 | // TODO: determine Bc, Br dynamically 202 | const int Bc = 16; 203 | const int Br = 16; 204 | 205 | int B = Q.size(0); 206 | int nh = Q.size(1); 207 | int N = Q.size(2); 208 | int d = Q.size(3); 209 | 210 | int Tc = ceil((float)N / Bc); 211 | int Tr = ceil((float)N / Br); 212 | float softmax_scale = 1.0 / sqrt(d); 213 | 214 | // Initialize O, l, m to HBM 215 | auto O = torch::zeros_like(Q); 216 | auto L = torch::zeros({B, nh, N}); 217 | torch::Device device(torch::kCUDA); 218 | L = L.to(device); 219 | 220 | const int smem_size = ((Br * Bc) + (2 * Br * d) + (2 * Bc * d) + (3 * Br)) * sizeof(float); 221 | int max_sram_size; 222 | cudaDeviceGetAttribute(&max_sram_size, cudaDevAttrMaxSharedMemoryPerBlock, 0); 223 | printf("Max shared memory: %d, requested shared memory: %d \n", max_sram_size, smem_size); 224 | 225 | dim3 grid_size(B, nh); // batch_size x num_heads 226 | dim3 block_size(Br * Bc); // Br * Bc threads per block 227 | 228 | flash_attn_2_kernel<<>>( 229 | Q.data_ptr(), K.data_ptr(), V.data_ptr(), 230 | N, d, Tr, Tc, softmax_scale, 231 | L.data_ptr(), O.data_ptr()); 232 | return O; 233 | } 234 | 235 | int main() { 236 | int batch_size = 16; 237 | int n_head = 8; 238 | int seq_len = 512; 239 | int head_embd = 64; 240 | 241 | int qkv_size = batch_size * n_head * seq_len * head_embd; 242 | int lm_size = batch_size * n_head * seq_len; 243 | 244 | float *Qh, *Kh, *Vh, *Oh, *Lh; 245 | Qh = (float*)malloc(qkv_size * sizeof(float)); 246 | Kh = (float*)malloc(qkv_size * sizeof(float)); 247 | Vh = (float*)malloc(qkv_size * sizeof(float)); 248 | Oh = (float*)malloc(qkv_size * sizeof(float)); 249 | Lh = (float*)malloc(lm_size * sizeof(float)); 250 | 251 | for (int i = 0; i < qkv_size; i++) { 252 | Qh[i] = 1.0; // random_normal_clamped(-1, 1); 253 | Kh[i] = 2.0; // random_normal_clamped(-1, 1); 254 | Vh[i] = 3.0; // random_normal_clamped(-1, 1); 255 | Oh[i] = 0.0f; 256 | } 257 | for (int i = 0; i < lm_size; i++) { 258 | Lh[i] = 0.0f; 259 | } 260 | 261 | const int Br = 16, Bc = 16; 262 | int Tc = ceil((float)seq_len / Bc); 263 | int Tr = ceil((float)seq_len / Br); 264 | float softmax_scale = 1.0 / sqrt(head_embd); 265 | 266 | const int smem_size = ((Br * Bc) + (2 * Br * head_embd) 
+ (2 * Bc * head_embd) + (3 * Br)) * sizeof(float); 267 | int max_sram_size; 268 | cudaDeviceGetAttribute(&max_sram_size, cudaDevAttrMaxSharedMemoryPerBlock, 0); 269 | printf("Max shared memory: %d, requested shared memory: %d \n", max_sram_size, smem_size); 270 | 271 | dim3 grid_dim(batch_size, n_head); // batch_size x num_heads 272 | dim3 block_dim(Br * Bc); // Br * Bc threads per block 273 | 274 | cudaEvent_t start, stop; 275 | CUDA_CHECK(cudaEventCreate(&start)); 276 | CUDA_CHECK(cudaEventCreate(&stop)); 277 | float ms = 0.0f; 278 | 279 | float *Q, *K, *V, *O, *L; 280 | 281 | cudaEventRecord(start); 282 | CUDA_CHECK(cudaMalloc(&Q, qkv_size * sizeof(float))); 283 | CUDA_CHECK(cudaMalloc(&K, qkv_size * sizeof(float))); 284 | CUDA_CHECK(cudaMalloc(&V, qkv_size * sizeof(float))); 285 | CUDA_CHECK(cudaMalloc(&O, qkv_size * sizeof(float))); 286 | CUDA_CHECK(cudaMalloc(&L, lm_size * sizeof(float))); 287 | cudaEventRecord(stop); 288 | cudaEventSynchronize(stop); 289 | cudaEventElapsedTime(&ms, start, stop); 290 | printf(">> GPU allocation time: %f ms\n", ms); 291 | 292 | cudaEventRecord(start); 293 | CUDA_CHECK(cudaMemcpy(Q, Qh, qkv_size * sizeof(float), cudaMemcpyHostToDevice)); 294 | CUDA_CHECK(cudaMemcpy(K, Kh, qkv_size * sizeof(float), cudaMemcpyHostToDevice)); 295 | CUDA_CHECK(cudaMemcpy(V, Vh, qkv_size * sizeof(float), cudaMemcpyHostToDevice)); 296 | CUDA_CHECK(cudaMemcpy(O, Oh, qkv_size * sizeof(float), cudaMemcpyHostToDevice)); 297 | CUDA_CHECK(cudaMemcpy(L, Lh, lm_size * sizeof(float), cudaMemcpyHostToDevice)); 298 | cudaEventRecord(stop); 299 | cudaEventSynchronize(stop); 300 | cudaEventElapsedTime(&ms, start, stop); 301 | printf(">> Host to device transfer time: %f ms\n", ms); 302 | 303 | cudaEventRecord(start); 304 | flash_attn_2_kernel<<>>( 305 | Q, K, V, seq_len, head_embd, Tr, Tc, softmax_scale, L, O); 306 | cudaEventRecord(stop); 307 | cudaEventSynchronize(stop); 308 | cudaEventElapsedTime(&ms, start, stop); 309 | printf(">> Flash-Attention 1 kernel execution time: %f ms\n", ms); 310 | 311 | cudaEventRecord(start); 312 | CUDA_CHECK(cudaMemcpy(Oh, O, qkv_size * sizeof(float), cudaMemcpyDeviceToHost)); 313 | cudaEventRecord(stop); 314 | cudaEventSynchronize(stop); 315 | cudaEventElapsedTime(&ms, start, stop); 316 | printf(">> Device to host transfer time: %f ms\n", ms); 317 | 318 | CUDA_CHECK(cudaEventDestroy(start)); 319 | CUDA_CHECK(cudaEventDestroy(stop)); 320 | 321 | printf("\nFirst and Last value in Output:\n"); 322 | printf("%f and %f\n", Oh[0], Oh[qkv_size - 1]); 323 | 324 | cudaFree(Q); 325 | cudaFree(K); 326 | cudaFree(V); 327 | cudaFree(O); 328 | cudaFree(L); 329 | free(Qh); 330 | free(Kh); 331 | free(Vh); 332 | free(Oh); 333 | free(Lh); 334 | 335 | return 0; 336 | } -------------------------------------------------------------------------------- /flash-attention/fa1/flash-attn-1.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | #define CUDA_CHECK(ans) \ 8 | { \ 9 | cudaAssert((ans), __FILE__, __LINE__); \ 10 | } 11 | inline void cudaAssert(cudaError_t code, const char* file, int line) { 12 | if (code != cudaSuccess) { 13 | fprintf(stderr, "CUDA error %s: %s at %s: %d\n", 14 | cudaGetErrorName(code), cudaGetErrorString(code), 15 | file, line); 16 | exit(code); 17 | } 18 | } 19 | #define CEIL_DIV(x, y) ((x) >= 0 ? 
(((x) + (y) - 1) / (y)) : ((x) / (y))) 20 | #define PI 3.1415 21 | 22 | float random_normal_clamped(float min, float max) { 23 | float u1 = (float)rand() / RAND_MAX; 24 | float u2 = (float)rand() / RAND_MAX; 25 | float num = sqrtf(-2.0f * logf(u1)) * cosf(2.0f * PI * u2); 26 | if (num < min) 27 | return min; 28 | if (num > max) 29 | return max; 30 | return num; 31 | } 32 | 33 | /* 34 | Reduction functions on device. These will be inline: 35 | The compiler will replace the call with the code instead of calling the function (overhead) 36 | */ 37 | /* 38 | Utility warp level sum reduction with shuffle instructions 39 | */ 40 | 41 | __device__ __forceinline__ float warpReduceSum(float val, int width) { 42 | for (int offset = width / 2; offset > 0; offset /= 2) { 43 | val += __shfl_down_sync(0xffffffff, val, offset); 44 | } 45 | 46 | return val; 47 | } 48 | 49 | __device__ __forceinline__ float warpReduceMax(float val, int width) { 50 | for (int offset = width / 2; offset > 0; offset /= 2) { 51 | val = max(val, __shfl_down_sync(0xffffffff, val, offset)); 52 | } 53 | 54 | return val; 55 | } 56 | 57 | /* 58 | This kernel uses flash attention algorithm to compute multi-head attention. 59 | Q, K, and V are 4D tensors of shape (batch_size, n_heads, seq_len, embed_dim). 60 | Additional inputs are Tr and Tc which are the tiles each block computes. 61 | The arrays l and m are to save the norm and maximum for the ith tile. 62 | SRAM will have size M and Br = ceil(M / 4d) and Bc = min(ceil(M / 4d), d) 63 | where M is the size of the SRAM. 64 | */ 65 | template 66 | __global__ void flash_attn_1_kernel(float* Q, float* K, float* V, int N, int d, int Tr, int Tc, float scale, float* l, float* m, float* O) { 67 | int tx = threadIdx.x; // Br * Bc threads 68 | 69 | int bx = blockIdx.x; // Batch index 70 | int by = blockIdx.y; // Head index 71 | 72 | // tip to calculate offset: 73 | // count how many elements to skip in the array to reach an index 74 | int qkv_off = (bx * gridDim.y * N * d) + (by * N * d); 75 | int lm_off = (bx * gridDim.y * N) + (by * N); 76 | 77 | // TODO: remove too much shared memory usage 78 | extern __shared__ float smem[]; 79 | float* Qi = smem; 80 | float* Kj = Qi + Br * d; 81 | float* Vj = Kj + Bc * d; 82 | float* Sij = Vj + Bc * d; 83 | float* Oi = Sij + Br * Bc; 84 | float* li = Oi + Br * d; 85 | float* li_new = li + Br; 86 | float* mi = li_new + Br; 87 | float* mi_new = mi + Br; 88 | float* mij_dash = mi_new + Br; 89 | 90 | for (int j = 0; j < Tc; j++) { 91 | // load Kj and Vj into SMEM 92 | // a thread may load multiple elements 93 | int loads_per_thread = CEIL_DIV(d, Br); 94 | for (int e = 0; e < loads_per_thread; e++) { 95 | int idx = e * (Br * Bc) + tx; 96 | if (idx < Bc * d) { 97 | int row = idx / d; 98 | int col = idx % d; 99 | 100 | if (j * Bc + row < N) { 101 | Kj[row * d + col] = K[qkv_off + (j * Bc + row) * d + col]; 102 | Vj[row * d + col] = V[qkv_off + (j * Bc + row) * d + col]; 103 | } 104 | } 105 | } 106 | __syncthreads(); // barrier here for correct Kj and Vj values in inner loop 107 | 108 | for (int i = 0; i < Tr; i++) { 109 | // load Qi and Oi into smem similar to Kj 110 | // a thread may load multiple elements 111 | int loads_per_thread = CEIL_DIV(d, Bc); 112 | for (int e = 0; e < loads_per_thread; e++) { 113 | int idx = e * (Br * Bc) + tx; 114 | if (idx < Br * d) { 115 | int row = idx / d; 116 | int col = idx % d; 117 | if (i * Br + row < N) { 118 | Qi[row * d + col] = Q[qkv_off + (i * Br + row) * d + col]; 119 | Oi[row * d + col] = O[qkv_off + (i * Br + row) * d + 
col]; 120 | } 121 | } 122 | } 123 | 124 | int s_row = tx / Bc; 125 | int s_col = tx % Bc; 126 | 127 | if (s_col == 0) { 128 | mi[s_row] = m[lm_off + (i * Br) + s_row]; 129 | li[s_row] = l[lm_off + (i * Br) + s_row]; 130 | } 131 | __syncthreads(); 132 | 133 | // compute S = Qi * Kj^T where shape of S: (Br, Bc) 134 | // TODO: reduce shared memory bank conflicts 135 | float acc = 0.f; 136 | for (int k = 0; k < d; k++) 137 | acc += Qi[s_row * d + k] * Kj[s_col * d + k]; 138 | 139 | acc *= scale; 140 | Sij[s_row * Bc + s_col] = acc; 141 | 142 | // rowmax(S) and rowsum(S) (only one thread per row) 143 | // computes both in a single pass 144 | if (s_col == 0) { 145 | float row_m = -INFINITY, row_l = 0.f; 146 | for (int c = 0; c < Bc; c++) { 147 | float val = Sij[s_row * Bc + c]; 148 | if (val > row_m) { 149 | row_m = val; 150 | } 151 | } 152 | for (int c = 0; c < Bc; c++) { 153 | float exp_val = expf(Sij[s_row * Bc + c] - row_m); 154 | Sij[s_row * Bc + c] = exp_val; 155 | row_l += exp_val; 156 | } 157 | 158 | mij_dash[s_row] = row_m; 159 | mi_new[s_row] = max(mi[s_row], row_m); 160 | li_new[s_row] = expf(mi[s_row] - mi_new[s_row]) * li[s_row] + expf(row_m - mi_new[s_row]) * row_l; 161 | } 162 | __syncthreads(); 163 | 164 | // compute Sij * Vj and do a roll-forward update to O 165 | // Sij (Br, Bc) and Vj (Bc, d) and we have Br * Bc threads 166 | // a thread may compute more than one element's dot product 167 | for (int col = s_col; col < d; col += Bc) { 168 | float acc = 0.f; 169 | for (int c = 0; c < Bc; c++) 170 | acc += Sij[s_row * Bc + c] * Vj[c * d + col]; 171 | 172 | int global_row = (i * Br) + s_row; 173 | if (global_row < N) { 174 | Oi[s_row * d + col] = (1 / li_new[s_row]) * ((li[s_row] * expf(mi[s_row] - mi_new[s_row]) * Oi[s_row * d + col]) + (expf(mij_dash[s_row] - mi_new[s_row]) * acc)); 175 | O[qkv_off + global_row * d + col] = Oi[s_row * d + col]; 176 | } 177 | } 178 | 179 | // update max and norm for next iteration 180 | m[lm_off + (i * Br) + s_row] = mi_new[s_row]; 181 | l[lm_off + (i * Br) + s_row] = li_new[s_row]; 182 | } 183 | __syncthreads(); 184 | } 185 | } 186 | 187 | // Comment the below function to compile and run this file as executable 188 | torch::Tensor fa_forward(torch::Tensor Q, torch::Tensor K, torch::Tensor V) { 189 | // TODO: determine Bc, Br dynamically 190 | const int Bc = 16; 191 | const int Br = 16; 192 | 193 | int B = Q.size(0); 194 | int nh = Q.size(1); 195 | int N = Q.size(2); 196 | int d = Q.size(3); 197 | 198 | int Tc = ceil((float)N / Bc); 199 | int Tr = ceil((float)N / Br); 200 | float softmax_scale = 1.0 / sqrt(d); 201 | 202 | // Initialize O, l, m to HBM 203 | auto O = torch::zeros_like(Q); 204 | auto l = torch::zeros({B, nh, N}); 205 | auto m = torch::full({B, nh, N}, -INFINITY); 206 | torch::Device device(torch::kCUDA); 207 | l = l.to(device); 208 | m = m.to(device); 209 | 210 | const int smem_size = ((Br * Bc) + (2 * Br * d) + (2 * Bc * d) + (5 * Br)) * sizeof(float); 211 | int max_sram_size; 212 | cudaDeviceGetAttribute(&max_sram_size, cudaDevAttrMaxSharedMemoryPerBlock, 0); 213 | printf("Max shared memory: %d, requested shared memory: %d \n", max_sram_size, smem_size); 214 | 215 | dim3 grid_size(B, nh); // batch_size x num_heads 216 | dim3 block_size(Br * Bc); // Br * Bc threads per block 217 | 218 | flash_attn_1_kernel<<>>( 219 | Q.data_ptr(), K.data_ptr(), V.data_ptr(), 220 | N, d, Tr, Tc, softmax_scale, 221 | l.data_ptr(), m.data_ptr(), O.data_ptr()); 222 | return O; 223 | } 224 | 225 | int main() { 226 | int batch_size = 16; 227 | int n_head 
= 8; 228 | int seq_len = 512; 229 | int head_embd = 64; 230 | 231 | int qkv_size = batch_size * n_head * seq_len * head_embd; 232 | int lm_size = batch_size * n_head * seq_len; 233 | 234 | float *Qh, *Kh, *Vh, *Oh, *lh, *mh; 235 | Qh = (float*)malloc(qkv_size * sizeof(float)); 236 | Kh = (float*)malloc(qkv_size * sizeof(float)); 237 | Vh = (float*)malloc(qkv_size * sizeof(float)); 238 | Oh = (float*)malloc(qkv_size * sizeof(float)); 239 | lh = (float*)malloc(lm_size * sizeof(float)); 240 | mh = (float*)malloc(lm_size * sizeof(float)); 241 | 242 | for (int i = 0; i < qkv_size; i++) { 243 | Qh[i] = random_normal_clamped(-1, 1); 244 | Kh[i] = random_normal_clamped(-1, 1); 245 | Vh[i] = random_normal_clamped(-1, 1); 246 | Oh[i] = 0.0f; 247 | } 248 | for (int i = 0; i < lm_size; i++) { 249 | lh[i] = 0.0f; 250 | mh[i] = -INFINITY; 251 | } 252 | 253 | const int Br = 16, Bc = 16; 254 | int Tc = ceil((float)seq_len / Bc); 255 | int Tr = ceil((float)seq_len / Br); 256 | float softmax_scale = 1.0 / sqrt(head_embd); 257 | 258 | const int smem_size = ((Br * Bc) + (2 * Br * head_embd) + (2 * Bc * head_embd) + (5 * Br)) * sizeof(float); 259 | int max_sram_size; 260 | cudaDeviceGetAttribute(&max_sram_size, cudaDevAttrMaxSharedMemoryPerBlock, 0); 261 | printf("Max shared memory: %d, requested shared memory: %d \n", max_sram_size, smem_size); 262 | 263 | dim3 grid_dim(batch_size, n_head); // batch_size x num_heads 264 | dim3 block_dim(Br * Bc); // Br * Bc threads per block 265 | 266 | cudaEvent_t start, stop; 267 | CUDA_CHECK(cudaEventCreate(&start)); 268 | CUDA_CHECK(cudaEventCreate(&stop)); 269 | float ms = 0.0f; 270 | 271 | float *Q, *K, *V, *O, *l, *m; 272 | 273 | cudaEventRecord(start); 274 | CUDA_CHECK(cudaMalloc(&Q, qkv_size * sizeof(float))); 275 | CUDA_CHECK(cudaMalloc(&K, qkv_size * sizeof(float))); 276 | CUDA_CHECK(cudaMalloc(&V, qkv_size * sizeof(float))); 277 | CUDA_CHECK(cudaMalloc(&O, qkv_size * sizeof(float))); 278 | CUDA_CHECK(cudaMalloc(&l, lm_size * sizeof(float))); 279 | CUDA_CHECK(cudaMalloc(&m, lm_size * sizeof(float))); 280 | cudaEventRecord(stop); 281 | cudaEventSynchronize(stop); 282 | cudaEventElapsedTime(&ms, start, stop); 283 | printf(">> GPU allocation time: %f ms\n", ms); 284 | 285 | cudaEventRecord(start); 286 | CUDA_CHECK(cudaMemcpy(Q, Qh, qkv_size * sizeof(float), cudaMemcpyHostToDevice)); 287 | CUDA_CHECK(cudaMemcpy(K, Kh, qkv_size * sizeof(float), cudaMemcpyHostToDevice)); 288 | CUDA_CHECK(cudaMemcpy(V, Vh, qkv_size * sizeof(float), cudaMemcpyHostToDevice)); 289 | CUDA_CHECK(cudaMemcpy(O, Oh, qkv_size * sizeof(float), cudaMemcpyHostToDevice)); 290 | CUDA_CHECK(cudaMemcpy(l, lh, lm_size * sizeof(float), cudaMemcpyHostToDevice)); 291 | CUDA_CHECK(cudaMemcpy(m, mh, lm_size * sizeof(float), cudaMemcpyHostToDevice)); 292 | cudaEventRecord(stop); 293 | cudaEventSynchronize(stop); 294 | cudaEventElapsedTime(&ms, start, stop); 295 | printf(">> Host to device transfer time: %f ms\n", ms); 296 | 297 | cudaEventRecord(start); 298 | flash_attn_1_kernel<<>>( 299 | Q, K, V, seq_len, head_embd, Tr, Tc, softmax_scale, 300 | l, m, O); 301 | cudaEventRecord(stop); 302 | cudaEventSynchronize(stop); 303 | cudaEventElapsedTime(&ms, start, stop); 304 | printf(">> Flash-Attention 1 kernel execution time: %f ms\n", ms); 305 | 306 | cudaEventRecord(start); 307 | CUDA_CHECK(cudaMemcpy(Oh, O, qkv_size * sizeof(float), cudaMemcpyDeviceToHost)); 308 | cudaEventRecord(stop); 309 | cudaEventSynchronize(stop); 310 | cudaEventElapsedTime(&ms, start, stop); 311 | printf(">> Device to host transfer time: %f 
ms\n", ms); 312 | 313 | CUDA_CHECK(cudaEventDestroy(start)); 314 | CUDA_CHECK(cudaEventDestroy(stop)); 315 | 316 | printf("\nFirst and Last value in Output:\n"); 317 | printf("%f and %f\n", Oh[0], Oh[qkv_size - 1]); 318 | 319 | cudaFree(Q); 320 | cudaFree(K); 321 | cudaFree(V); 322 | cudaFree(O); 323 | cudaFree(l); 324 | cudaFree(m); 325 | free(Qh); 326 | free(Kh); 327 | free(Vh); 328 | free(Oh); 329 | free(lh); 330 | free(mh); 331 | 332 | return 0; 333 | } -------------------------------------------------------------------------------- /flash-attention/fa2/flash2_fwd_fp16.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #define CUDA_CHECK(ans) \ 10 | { \ 11 | cudaAssert((ans), __FILE__, __LINE__); \ 12 | } 13 | inline void cudaAssert(cudaError_t code, const char* file, int line) { 14 | if (code != cudaSuccess) { 15 | fprintf(stderr, "CUDA error %s: %s at %s: %d\n", 16 | cudaGetErrorName(code), cudaGetErrorString(code), 17 | file, line); 18 | exit(code); 19 | } 20 | } 21 | #define CEIL_DIV(x, y) ((x) >= 0 ? (((x) + (y) - 1) / (y)) : ((x) / (y))) 22 | #define PI 3.1415 23 | 24 | float random_normal_clamped(float min, float max) { 25 | float u1 = (float)rand() / RAND_MAX; 26 | float u2 = (float)rand() / RAND_MAX; 27 | float num = sqrtf(-2.0f * logf(u1)) * cosf(2.0f * PI * u2); 28 | if (num < min) 29 | return min; 30 | if (num > max) 31 | return max; 32 | return num; 33 | } 34 | 35 | __device__ __forceinline__ float warpReduceSum(float val, int width) { 36 | for (int offset = width / 2; offset > 0; offset /= 2) { 37 | val += __shfl_down_sync(0xffffffff, val, offset); 38 | } 39 | 40 | return val; 41 | } 42 | 43 | __device__ __forceinline__ float warpReduceMax(float val, int width) { 44 | for (int offset = width / 2; offset > 0; offset /= 2) { 45 | val = max(val, __shfl_down_sync(0xffffffff, val, offset)); 46 | } 47 | 48 | return val; 49 | } 50 | 51 | /* 52 | FP16 VERSION: 53 | 54 | This kernel uses flash attention algorithm to compute multi-head attention. 55 | Q, K, and V are 4D tensors of shape (batch_size, n_heads, seq_len, embed_dim). 56 | Additional inputs are Tr and Tc which are the tiles each block computes. 57 | The arrays l and m are to save the norm and maximum for the ith tile. 58 | SRAM will have size M and Br = ceil(M / 4d) and Bc = min(ceil(M / 4d), d) 59 | where M is the size of the SRAM. 
60 | */ 61 | template 62 | __global__ void flash_attn_2_kernel_fp16(__half* Q, __half* K, __half* V, int N, int d, int Tr, int Tc, float scale, __half* L, __half* O) { 63 | int tx = threadIdx.x; // Br * Bc threads 64 | 65 | int bx = blockIdx.x; // Batch index 66 | int by = blockIdx.y; // Head index 67 | 68 | // tip to calculate offset: 69 | // count how many elements to skip in the array to reach an index 70 | int qkv_off = (bx * gridDim.y * N * d) + (by * N * d); 71 | int lm_off = (bx * gridDim.y * N) + (by * N); 72 | 73 | // TODO: remove too much shared memory usage 74 | extern __shared__ __half smem[]; 75 | __half* Qi = smem; 76 | __half* Kj = Qi + Br * d; 77 | __half* Vj = Kj + Bc * d; 78 | __half* Sij = Vj + Bc * d; 79 | __half* Oi = Sij + Br * Bc; 80 | __half* li = Oi + Br * d; 81 | __half* mi = li + Br; 82 | __half* mi_new = mi + Br; 83 | 84 | int loads_per_thread_Brd = CEIL_DIV(d, Bc); 85 | int loads_per_thread_Bcd = CEIL_DIV(d, Br); 86 | 87 | for (int i = 0; i < Tr; i++) { 88 | // load Qi and Oi from HBM into SMEM 89 | for (int e = 0; e < loads_per_thread_Brd; e++) { 90 | int idx = e * (Br * Bc) + tx; 91 | int row = idx / d; 92 | if (idx < Br * d && i * Br + row < N) { 93 | int col = idx % d; 94 | Qi[row * d + col] = Q[qkv_off + (i * Br + row) * d + col]; 95 | Oi[row * d + col] = __float2half(0.0f); 96 | } 97 | } 98 | 99 | int s_row = tx / Bc; 100 | int s_col = tx % Bc; 101 | 102 | int global_row = (i * Br) + s_row; 103 | 104 | // init li and mi for each row 105 | if (s_col == 0) { 106 | li[s_row] = __float2half(0.f); 107 | mi[s_row] = __float2half(-INFINITY); 108 | mi_new[s_row] = __float2half(-INFINITY); 109 | } 110 | __syncthreads(); 111 | 112 | for (int j = 0; j < Tc; j++) { 113 | // load Kj and Vj into SMEM from HBM 114 | for (int e = 0; e < loads_per_thread_Bcd; e++) { 115 | int idx = e * (Br * Bc) + tx; 116 | int row = idx / d; 117 | if (idx < Bc * d && j * Bc + row < N) { 118 | int col = idx % d; 119 | Kj[row * d + col] = K[qkv_off + (j * Bc + row) * d + col]; 120 | Vj[row * d + col] = V[qkv_off + (j * Bc + row) * d + col]; 121 | } 122 | } 123 | __syncthreads(); 124 | 125 | // compute S = Qi * Kj^T where shape of S: (Br, Bc) 126 | // TODO: reduce shared memory bank conflicts 127 | float acc = 0.f; 128 | for (int k = 0; k < d; k++) 129 | acc += __half2float(__hmul(Qi[s_row * d + k], Kj[s_col * d + k])); 130 | 131 | acc *= scale; 132 | Sij[s_row * Bc + s_col] = __float2half(acc); 133 | 134 | // rowmax(S) and rowsum(S) (only one thread per row) 135 | // computes both in a single pass 136 | if (s_col == 0) { 137 | mi[s_row] = mi_new[s_row]; 138 | __half row_m = __float2half(-INFINITY); 139 | float row_l = 0.f; 140 | for (int c = 0; c < Bc; c++) { 141 | __half val = Sij[s_row * Bc + c]; 142 | if (__hgt(val, row_m)) { 143 | row_m = val; 144 | } 145 | } 146 | __half maxval = __hmax(mi[s_row], row_m); 147 | 148 | float kahan_comp = 0.0f; 149 | for (int c = 0; c < Bc; c++) { 150 | float exp_val = expf(__half2float(__hsub(Sij[s_row * Bc + c], maxval))); 151 | Sij[s_row * Bc + c] = __float2half(exp_val); 152 | 153 | float y = exp_val - kahan_comp; // subtract the compensation from the new value 154 | float t = row_l + y; // add the compensated value to the sum 155 | kahan_comp = (t - row_l) - y; // compute the new compensation. 
156 | row_l = t; 157 | } 158 | 159 | mi_new[s_row] = maxval; 160 | li[s_row] = __float2half(expf(__half2float(__hsub(mi[s_row], maxval))) * __half2float(li[s_row]) + row_l); 161 | } 162 | __syncthreads(); 163 | 164 | // compute Sij * Vj and do a roll-forward update to O 165 | // Sij (Br, Bc) and Vj (Bc, d) and we have Br * Bc threads 166 | // a thread may compute more than one element's dot product 167 | for (int col = s_col; col < d; col += Bc) { 168 | float acc = 0.f; 169 | float kahan_comp = 0.f; 170 | for (int c = 0; c < Bc; c++) { 171 | float term = __half2float(__hmul(Sij[s_row * Bc + c], Vj[c * d + col])); 172 | 173 | float y = term - kahan_comp; 174 | float t = acc + y; 175 | kahan_comp = (t - acc) - y; 176 | acc = t; 177 | } 178 | 179 | Oi[s_row * d + col] = (__heq(mi[s_row], __float2half(-INFINITY)) || __heq(mi_new[s_row], __float2half(-INFINITY))) ? __float2half(acc) : __float2half(expf(__half2float(__hsub(mi[s_row], mi_new[s_row]))) * __half2float(Oi[s_row * d + col]) + acc); 180 | } 181 | } 182 | 183 | for (int col = s_col; col < d; col += Bc) { 184 | if (global_row < N) 185 | O[qkv_off + global_row * d + col] = __float2half((1 / __half2float(li[s_row])) * __half2float(Oi[s_row * d + col])); 186 | } 187 | 188 | if (s_col == 0) { 189 | L[lm_off + (i * Br) + s_row] = __float2half(__half2float(mi_new[s_row]) + logf(__half2float(li[s_row]))); 190 | } 191 | __syncthreads(); 192 | } 193 | } 194 | 195 | // Comment the below function to compile and run this file as executable 196 | torch::Tensor fa2_forward(torch::Tensor Q, torch::Tensor K, torch::Tensor V) { 197 | // TODO: determine Bc, Br dynamically 198 | const int Bc = 16; 199 | const int Br = 16; 200 | 201 | int B = Q.size(0); 202 | int nh = Q.size(1); 203 | int N = Q.size(2); 204 | int d = Q.size(3); 205 | 206 | int Tc = ceil((float)N / Bc); 207 | int Tr = ceil((float)N / Br); 208 | float softmax_scale = 1.0 / sqrt(d); 209 | 210 | // Initialize O, l, m to HBM 211 | auto O = torch::zeros_like(Q); 212 | auto L = torch::zeros({B, nh, N}); 213 | torch::Device device(torch::kCUDA); 214 | L = L.to(device).to(torch::kHalf); 215 | 216 | const int smem_size = ((Br * Bc) + (2 * Br * d) + (2 * Bc * d) + (3 * Br)) * sizeof(__half); 217 | int max_sram_size; 218 | cudaDeviceGetAttribute(&max_sram_size, cudaDevAttrMaxSharedMemoryPerBlock, 0); 219 | printf("Max shared memory: %d, requested shared memory: %d \n", max_sram_size, smem_size); 220 | 221 | dim3 grid_size(B, nh); // batch_size x num_heads 222 | dim3 block_size(Br * Bc); // Br * Bc threads per block 223 | 224 | flash_attn_2_kernel_fp16<<>>( 225 | reinterpret_cast<__half*>(Q.data_ptr()), 226 | reinterpret_cast<__half*>(K.data_ptr()), 227 | reinterpret_cast<__half*>(V.data_ptr()), 228 | N, d, Tr, Tc, softmax_scale, 229 | reinterpret_cast<__half*>(L.data_ptr()), 230 | reinterpret_cast<__half*>(O.data_ptr())); 231 | return O; 232 | } 233 | 234 | int main() { 235 | int batch_size = 16; 236 | int n_head = 8; 237 | int seq_len = 512; 238 | int head_embd = 64; 239 | 240 | int qkv_size = batch_size * n_head * seq_len * head_embd; 241 | int lm_size = batch_size * n_head * seq_len; 242 | 243 | __half *Qh, *Kh, *Vh, *Oh, *Lh; 244 | Qh = (__half*)malloc(qkv_size * sizeof(__half)); 245 | Kh = (__half*)malloc(qkv_size * sizeof(__half)); 246 | Vh = (__half*)malloc(qkv_size * sizeof(__half)); 247 | Oh = (__half*)malloc(qkv_size * sizeof(__half)); 248 | Lh = (__half*)malloc(lm_size * sizeof(__half)); 249 | 250 | for (int i = 0; i < qkv_size; i++) { 251 | Qh[i] = __float2half(1.0); // 
random_normal_clamped(-1, 1); 252 | Kh[i] = __float2half(2.0); // random_normal_clamped(-1, 1); 253 | Vh[i] = __float2half(3.0); // random_normal_clamped(-1, 1); 254 | Oh[i] = __float2half(0.0f); 255 | } 256 | for (int i = 0; i < lm_size; i++) { 257 | Lh[i] = __float2half(0.0f); 258 | } 259 | 260 | const int Br = 16, Bc = 16; 261 | int Tc = ceil((float)seq_len / Bc); 262 | int Tr = ceil((float)seq_len / Br); 263 | float softmax_scale = 1.0 / sqrt(head_embd); 264 | 265 | const int smem_size = ((Br * Bc) + (2 * Br * head_embd) + (2 * Bc * head_embd) + (3 * Br)) * sizeof(__half); 266 | int max_sram_size; 267 | cudaDeviceGetAttribute(&max_sram_size, cudaDevAttrMaxSharedMemoryPerBlock, 0); 268 | printf("Max shared memory: %d, requested shared memory: %d \n", max_sram_size, smem_size); 269 | 270 | dim3 grid_dim(batch_size, n_head); // batch_size x num_heads 271 | dim3 block_dim(Br * Bc); // Br * Bc threads per block 272 | 273 | cudaEvent_t start, stop; 274 | CUDA_CHECK(cudaEventCreate(&start)); 275 | CUDA_CHECK(cudaEventCreate(&stop)); 276 | float ms = 0.0f; 277 | 278 | __half *Q, *K, *V, *O, *L; 279 | 280 | cudaEventRecord(start); 281 | CUDA_CHECK(cudaMalloc(&Q, qkv_size * sizeof(__half))); 282 | CUDA_CHECK(cudaMalloc(&K, qkv_size * sizeof(__half))); 283 | CUDA_CHECK(cudaMalloc(&V, qkv_size * sizeof(__half))); 284 | CUDA_CHECK(cudaMalloc(&O, qkv_size * sizeof(__half))); 285 | CUDA_CHECK(cudaMalloc(&L, lm_size * sizeof(__half))); 286 | cudaEventRecord(stop); 287 | cudaEventSynchronize(stop); 288 | cudaEventElapsedTime(&ms, start, stop); 289 | printf(">> GPU allocation time: %f ms\n", ms); 290 | 291 | cudaEventRecord(start); 292 | CUDA_CHECK(cudaMemcpy(Q, Qh, qkv_size * sizeof(__half), cudaMemcpyHostToDevice)); 293 | CUDA_CHECK(cudaMemcpy(K, Kh, qkv_size * sizeof(__half), cudaMemcpyHostToDevice)); 294 | CUDA_CHECK(cudaMemcpy(V, Vh, qkv_size * sizeof(__half), cudaMemcpyHostToDevice)); 295 | CUDA_CHECK(cudaMemcpy(O, Oh, qkv_size * sizeof(__half), cudaMemcpyHostToDevice)); 296 | CUDA_CHECK(cudaMemcpy(L, Lh, lm_size * sizeof(__half), cudaMemcpyHostToDevice)); 297 | cudaEventRecord(stop); 298 | cudaEventSynchronize(stop); 299 | cudaEventElapsedTime(&ms, start, stop); 300 | printf(">> Host to device transfer time: %f ms\n", ms); 301 | 302 | cudaEventRecord(start); 303 | flash_attn_2_kernel_fp16<<>>( 304 | Q, K, V, seq_len, head_embd, Tr, Tc, softmax_scale, L, O); 305 | cudaEventRecord(stop); 306 | cudaEventSynchronize(stop); 307 | cudaEventElapsedTime(&ms, start, stop); 308 | printf(">> Flash-Attention 1 kernel execution time: %f ms\n", ms); 309 | 310 | cudaEventRecord(start); 311 | CUDA_CHECK(cudaMemcpy(Oh, O, qkv_size * sizeof(__half), cudaMemcpyDeviceToHost)); 312 | cudaEventRecord(stop); 313 | cudaEventSynchronize(stop); 314 | cudaEventElapsedTime(&ms, start, stop); 315 | printf(">> Device to host transfer time: %f ms\n", ms); 316 | 317 | CUDA_CHECK(cudaEventDestroy(start)); 318 | CUDA_CHECK(cudaEventDestroy(stop)); 319 | 320 | printf("\nFirst and Last value in Output:\n"); 321 | printf("%f and %f\n", __half2float(Oh[0]), __half2float(Oh[qkv_size - 1])); 322 | 323 | cudaFree(Q); 324 | cudaFree(K); 325 | cudaFree(V); 326 | cudaFree(O); 327 | cudaFree(L); 328 | free(Qh); 329 | free(Kh); 330 | free(Vh); 331 | free(Oh); 332 | free(Lh); 333 | 334 | return 0; 335 | } -------------------------------------------------------------------------------- /matvec/README.md: -------------------------------------------------------------------------------- 1 | # Learning CUDA by optimizing matrix-vector 
multiplication (SGEMV) for cuBLAS-like performance - A worklog 2 | 3 | Matrix-vector multiplication is a foundational operation in linear algebra, where a matrix transforms an input vector into an output vector. This operation basically powers numerous fields including computer science and deep learning. Optimizing matrix-vector multiplication, especially in the context of GPU programming and CUDA can help us learn many new things. 4 | 5 | In this worklog, we will start by benchmarking [cuBLAS](https://developer.nvidia.com/cublas)'s matrix-vector multiplication performance then we will iteratively optimize it in CUDA to see how close we can get to cuBLAS. The intention is to not replace it, but to learn from it. The NVIDIA GPU used for this worklog is one **GTX 1050Ti** (that's all I have got right now). By the end of this worklog, we will achieve what the following figure shows: 6 | 7 | ![Benchmark results](https://raw.githubusercontent.com/Maharshi-Pandya/cudacodes/refs/heads/master/matvec/media/benchmark_results.png) 8 | 9 | The full code is available on the GitHub repository: [Optimizing SGEMV in CUDA](https://github.com/Maharshi-Pandya/cudacodes/tree/master/matvec) 10 | 11 | > If you want me to work together with you on deep learning models, inference, training, software development, custom CUDA kernels or something else then you can shoot a direct message (DM) to me here: [me on X (formerly Twitter)](https://x.com/mrsiipa) 12 | 13 | Let's start! 14 | 15 | 16 | ## Some background first 17 | 18 | From now on, we will call matrix-vector multiplication as **SGEMV** which stands for **Single-Precision General Matrix-Vector multiplication**. The breakdown of this term is: 19 | 20 | 21 | - **S**: Indicates single-precision ($32$-bit floating-point numbers). 22 | - **GE**: Refers to a general matrix, meaning the matrix can have any shape or content (not restricted to special forms like symmetric or diagonal matrices). 23 | - **MV**: Stands for matrix-vector multiplication, the core operation the function performs. 24 | 25 | In essence, given a matrix $\textbf{A}$ of shape $(M, N)$ and an input vector $\textbf{x}$ of shape $(N, 1)$, SGEMV computes an output vector $\textbf{y}$ given as: 26 | 27 | $$ 28 | \textbf{y} = \alpha \cdot \textbf{A} \cdot \textbf{x} + \beta \cdot \textbf{y} 29 | $$ 30 | 31 | Here the terms $\alpha$ and $\beta$ are some scalar coefficients (floating point numbers). In this worklog, we will assume the following for simplicity: 32 | 33 | - The shape of the matrix $\textbf{A}$ will be $(4096, 8192)$ 34 | - The shape of the vector $\textbf{x}$ will be $(8192, 1)$ 35 | - The scalars $\alpha = 1$ and $\beta = 0$ 36 | 37 | And that leaves us with: 38 | 39 | $$ 40 | \textbf{y} = \textbf{A} \cdot \textbf{x} 41 | $$ 42 | 43 | ![SGEMV computation](https://raw.githubusercontent.com/Maharshi-Pandya/cudacodes/refs/heads/master/matvec/media/sgemv-computation.png) 44 | 45 | The figure above and the pseudocode below shows this computation. Note that each row of the matrix $\textbf{A}$ performs a dot product with the entire input vector $\textbf{x}$ to compute one element of the output vector $\textbf{y}$. 
46 | 47 | ```cpp 48 | function sgemv_basic(A, x, y, M, N) { 49 | // Initialize output vector 50 | for (i = 0; i < M; i++) { 51 | y[i] = 0; 52 | } 53 | 54 | // Perform the computation 55 | for (i = 0; i < M; i++) { 56 | for (j = 0; j < N; j++) { 57 | y[i] += A[i][j] * x[j]; 58 | } 59 | } 60 | } 61 | ``` 62 | 63 | ## Memory-bound tasks and FLOPS 64 | 65 | This section is very important in terms of comparing performance characteristics of operations like matrix-vector multiplication (SGEMV) and matrix-matrix multiplication (SGEMM). First, we need to define the words **FLOPs** (notice the small 's') and **FLOPS**. 66 | 67 | Basically: 68 | 69 | - **FLOPs** stands for the total number of floating-point operations performed in a computation. The operation can be anything like addition, subtraction, multiplication, and so on. 70 | - **FLOPS** measures the **rate of floating-point operations** that a system can perform in one second. If the system performs $F$ FLOPs in $T$ seconds then FLOPS is given by $(F / T)$. Also, $1$ GFLOPS = $10^9$ FLOPS. 71 | 72 | Now, even though matrix-vector multiplication can be thought of as a special case of matrix-matrix multiplication, there are some differences when it comes to measuring performance of both the operations in CUDA. 73 | 74 | SGEMV is a **memory-bound** operation whereas, SGEMM is a **compute-bound** operation. Let's calculate the FLOPs required for both of these operations: 75 | 76 | **Matrix-vector (SGEMV)** 77 | 78 | 1. Multiplies a matrix $A (M, N)$ with a vector $x (N, 1)$ resulting in a vector $y (M, 1)$. 79 | 2. Memory accesses: 80 | - Reads $A$ ($M \times N$ elements) 81 | - Reads $x$ ($N$ elements) 82 | - Writes $y$ ($M$ elements) 83 | 3. Computations: 84 | - Each row of $A$ is multiplied with $x$, resulting in $M$ dot products. 85 | - Each dot product consists of $N$ multiplications and $N-1$ additions. 86 | - FLOPs: $2 \times MN$ 87 | - $1$ floating-point number is $4$ bytes. 88 | - Bytes transferred: $4 \times (MN + N + M)$ 89 | 90 | 91 | **Matrix-matrix (SGEMM)** 92 | 93 | 1. Multiplies two matrices $A (M, K)$ and $B (K, N)$ resulting in a matrix $C (M, N)$. 94 | 2. Memory accesses: 95 | - Reads $A$ ($M \times K$ elements) 96 | - Reads $B$ ($K \times N$ elements) 97 | - Writes $C$ ($M \times N$ elements) 98 | 3. Computations: 99 | - Dot product of row $i$ of $A$ with column $j$ of $B$ 100 | - Each dot product consists of $K$ multiplications and $K-1$ additions. 101 | - FLOPs: $2 \times MNK$ 102 | - $1$ floating-point number is $4$ bytes. 103 | - Bytes transferred: $4 \times (MK + KN + MN)$ 104 | 105 | 106 | We divide the FLOPs by the total bytes transferred to calculate the computational intensity of the operation. 107 | 108 | Considering $M = 4096$, $N = 4096$ and $K = 4096$, for SGEMV we get: 109 | 110 | $$ 111 | \text{FLOPs per byte} = \dfrac{2MN}{4(MN + N + M)} \approx 0.4998 112 | $$ 113 | 114 | and for SGEMM, we get: 115 | 116 | $$ 117 | \text{FLOPs per byte} = \dfrac{2MNK}{4(MK + KN + MN)} \approx 682.67 118 | $$ 119 | 120 | 121 | As we can see, the computational intensity for SGEMV is very low compared to SGEMM i.e. more time is spent transferring the data from (and to) the global memory of the GPU compared to the actual computation time. Conversly, for SGEMM more time is spent doing the actual computation than transferring the data from (and to) the global memory. 122 | 123 | Thus, SGEMV is a **memory-bound** operation. 
So, we need to make sure that we are maximizing the memory bandwidth that our CUDA kernel achieves, such that it is close to the maximum memory bandwidth of the GPU ($112.1$ GB/s in our case). 124 | 125 | 126 | ## Benchmark - cuBLAS implementation 127 | 128 | Let's benchmark the SGEMV implementation that cuBLAS provides. To do this, we simply use the `cublasSgemv` function. Below is the corresponding code snippet that does this: 129 | 130 | ```cpp 131 | cublasHandle_t handle; 132 | cublasCreate(&handle); 133 | 134 | float alpha = 1.0f, beta = 0.0f; 135 | cublasSgemv( 136 | handle, CUBLAS_OP_T, N, M, 137 | &alpha, matd, N, vecd, 138 | 1, &beta, resd, 1 139 | ); 140 | ``` 141 | 142 | The on-device matrix is defined as `matd`, the input vector as `vecd`, and the resulting vector as `resd`. 143 | The matrix `matd` is stored in the row-major layout in memory and we will use linear indices to access its elements. When we run this we get: 144 | 145 | ``` 146 | >> GPU allocation time: 5.698560 ms 147 | >> Host to device transfer time: 23.588863 ms 148 | ------- cuBLAS sgmev kernel --------- 149 | >> Execution time: 1.532928 ms 150 | >> Achieved (GFLOPS): 43.778225 151 | >> Theoretical max (GFLOPS): 1911.040039 152 | >> Maximum memory bandwidth: 112.127998 GB/s 153 | >> Achieves 2.290806 % of peak GFLOPS 154 | >> Achieves 78.114754 % of peak Memory Bandwidth 155 | --------------------------- 156 | >> Device to host transfer time: 0.042816 ms 157 | ``` 158 | 159 | As expected, we see that the achieved throughput is $43.78$ GFLOPS i.e. much less than the theoretical maximum that the GPU can achieve. But, cuBLAS achieves around $78.1$% of the peak memory bandwidth, which is great for SGEMV since it is a memory-bound operation, as seen above. 160 | 161 | Let's first write a naive kernel in CUDA for SGEMV and iteratively improve it. 162 | 163 | 164 | ## Kernel 1 - Naive SGEMV 165 | 166 | Following the figure above, we can write a naive kernel for SGEMV. Each thread in a thread block will compute one output element of the vector `resd` in this kernel. The index of the current row, and of the corresponding output element, is computed as `row = blockDim.x * blockIdx.x + threadIdx.x`. 167 | 168 | The corresponding code snippet for this kernel looks like: 169 | 170 | ```cpp 171 | int row = blockDim.x * blockIdx.x + threadIdx.x; 172 | 173 | if (row < M) { 174 | float sum = 0.0f; 175 | for (int col = 0; col < N; col++) { 176 | sum += matd[row * N + col] * vecd[col]; 177 | } 178 | resd[row] = sum; 179 | } 180 | ``` 181 | 182 | Running this kernel results in: 183 | 184 | ``` 185 | >> GPU allocation time: 5.820416 ms 186 | >> Host to device transfer time: 24.071072 ms 187 | ------- Naive sgmev kernel --------- 188 | >> Execution time: 8.279040 ms 189 | >> Achieved (GFLOPS): 8.105875 190 | >> Theoretical max (GFLOPS): 1911.040039 191 | >> Maximum memory bandwidth: 112.127998 GB/s 192 | >> Achieves 0.424160 % of peak GFLOPS 193 | >> Achieves 14.463547 % of peak Memory Bandwidth 194 | --------------------------- 195 | >> Device to host transfer time: 0.048992 ms 196 | ``` 197 | 198 | The naive kernel achieves only $8.1$ GFLOPS, which is around $0.42$% of the peak GFLOPS. Apart from that, it achieves only $14.46$% of the peak memory bandwidth. This result is kind of unimpressive. 199 | 200 | But no worries, we can improve this! 
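For reference, the snippet above can be wrapped into a complete kernel and launched roughly as follows. This is a minimal sketch: the kernel name `naive_sgemv_kernel` is an illustrative assumption (not necessarily the exact name used in the code), and the block size of $1024$ threads matches the configuration discussed in the next section.

```cpp
// Minimal sketch of the naive SGEMV kernel (assumed name and launch config).
// Each thread computes one element of the output vector resd.
__global__ void naive_sgemv_kernel(const float* matd, const float* vecd,
                                   float* resd, int M, int N) {
    int row = blockDim.x * blockIdx.x + threadIdx.x;

    if (row < M) {
        float sum = 0.0f;
        for (int col = 0; col < N; col++) {
            sum += matd[row * N + col] * vecd[col];
        }
        resd[row] = sum;
    }
}

// Launch with one thread per row of the matrix:
// dim3 block_size(1024);
// dim3 grid_size((M + block_size.x - 1) / block_size.x);
// naive_sgemv_kernel<<<grid_size, block_size>>>(matd, vecd, resd, M, N);
```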

## Kernel 2 - Coalesced access with reductions

One way to improve the naive kernel is to ensure that the memory accesses to both the matrix `matd` and the vector `vecd` are **coalesced**. Let's understand what that means.

In CUDA, we have a grid of blocks where each block can have $T$ threads. The streaming multiprocessors (SMs) on the GPU process each block in groups of $32$ threads. Such a group of $32$ threads is called a **warp**. In essence, the SM can execute one instruction at a time for all the threads in a warp. Thus, each block consists of $\lceil T / 32 \rceil$ warps.

But what does this have to do with memory accesses? Well, let's see how the memory accesses for the matrix `matd` look in the naive kernel.

**Warp 1 (Threads 0-31):**

Assuming that we are dealing with the first block where `blockIdx.x = 0`, that the number of threads in a block is $1024$, and that all the threads in this warp of block $0$ execute one instruction in parallel, we have:

Thread $0$ executing:

- The value of `row` = $1024 \times 0 + 0$ = $0$
- The thread now enters the for loop.
- At the start, `col` = $0$ for this thread.
- Element of `matd` accessed = $\boxed{0 \times N + 0}$ = $0$

Now, at the same time, we have:

Thread $1$ executing:

- The value of `row` = $1024 \times 0 + 1$ = $1$
- The thread now enters the for loop.
- At the start, `col` = $0$ for this thread.
- Element of `matd` accessed = $\boxed{1 \times N + 0}$ = $N$

Similarly, we can see that for this particular warp the elements of `matd` accessed by the other threads are always **separated by $N$ elements**. But this is actually NOT an optimal way to access data residing in global memory! We need to access data from global memory in a coalesced manner.

> **Coalesced memory access** occurs when threads in a warp ($32$ threads) access consecutive memory addresses, allowing the 'memory controller' to combine these accesses into a single memory transaction. Global memory on GPUs has high latency and limited bandwidth compared to the speed of computation. Coalescing memory accesses minimizes the number of transactions required to fetch data, maximizing the effective bandwidth. Also, the hardware is designed to handle these coalesced accesses efficiently. When the accesses to global memory are scattered or random, it *forces* the 'memory controller' to dispatch multiple memory transactions. This results in a slowdown compared to when the memory accesses are coalesced.

The matrix `matd` is stored in row-major format in global memory, i.e. the elements of each row sit next to each other (in consecutive memory addresses). The figure below shows the difference between coalesced and non-coalesced accesses in the matrix.

![Coalesced access](https://raw.githubusercontent.com/Maharshi-Pandya/cudacodes/refs/heads/master/matvec/media/coalesced-access.png)


In this kernel, each block will operate on one row of the matrix. Also, we will assume that each block consists of only $32$ threads, i.e. $1$ warp. Consecutive threads in the warp will load consecutive elements of the matrix row of `matd` and of the vector `vecd`.
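To make the difference concrete, here is a tiny host-side illustration (not part of any kernel in the repository): it prints the linear index of `matd` touched by the first few threads of a warp on their *first* loop iteration, under the naive mapping versus the one-block-per-row mapping used in this kernel.

```cpp
// Host-side illustration only: first-iteration indices per thread.
#include <cstdio>

int main() {
    const int N = 4096;
    for (int t = 0; t < 4; t++) {
        int naive_idx     = t * N;  // naive: thread t starts at row t, col 0 -> stride of N floats
        int coalesced_idx = t;      // kernel 2: thread t reads row 0, col t -> consecutive floats
        printf("thread %d: naive -> matd[%d], coalesced -> matd[%d]\n",
               t, naive_idx, coalesced_idx);
    }
    return 0;
}
```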
We will also have a variable (private to each thread in the warp) called `partial_sum`, which holds the partial sum of the elements processed by that particular thread. The code snippet below shows this:

```cpp
int bid = blockIdx.x;
if (bid >= M) return;

int tid = threadIdx.x;

// each thread calculates its own partial output
float partial_sum = 0.f;
for (int col = tid; col < N; col += blockDim.x) {
    partial_sum += matd[bid * N + col] * vecd[col];
}
```

For example, if `bid = 0` and `blockDim.x = 32`, then a thread with index `tid` in this block will process the elements `tid`, `tid + blockDim.x`, `tid + 2 * blockDim.x` and so on. Thread $0$ processes elements $0$, $32$, $64$, etc., thread $1$ processes elements $1$, $33$, $65$, etc., and likewise for the remaining threads. In every iteration of the loop, consecutive threads load consecutive elements, which results in coalesced global memory accesses.

But there's a problem! Each thread has its own `partial_sum` variable. To get the final dot product, we need to sum the partial sums of all the threads present in the warp. Note that the final value of the dot product of two vectors is a single floating-point number.

This is where **reductions** can help us. We can essentially 'communicate' values between the threads in a block/warp using **shared memory** or **warp shuffle intrinsics**: every thread in a block has access to the block's shared memory, and threads within a warp can exchange register values directly via shuffle intrinsics. Have a look at my [CUDA softmax worklog](https://maharshi.bearblog.dev/optimizing-softmax-cuda/) which dives deeper into reductions. The figure below can help you understand reduction for finding the maximum value:

![Max reduction](https://raw.githubusercontent.com/Maharshi-Pandya/cudacodes/refs/heads/master/softmax/media/max_reduction.png)

So, the final value of the dot product will be a **sum reduction** over the `partial_sum` variables of all the threads within the warp. The utility function `warpReduceSum` sums up all the partial sums calculated by the threads. Finally, the first thread writes the result into the corresponding index of the output vector:

```cpp
// warp-level sum reduction
float sum = warpReduceSum(partial_sum);

// only first thread writes the output to global memory
if (tid == 0) {
    resd[bid] = sum;
}
```

Running this kernel results in:

```
>> GPU allocation time: 5.827360 ms
>> Host to device transfer time: 23.294975 ms
------- Coalesced warp sgmev kernel ---------
>> Execution time: 2.447360 ms
>> Achieved (GFLOPS): 27.420919
>> Theoretical max (GFLOPS): 1911.040039
>> Maximum memory bandwidth: 112.127998 GB/s
>> Achieves 1.434869 % of peak GFLOPS
>> Achieves 48.927948 % of peak Memory Bandwidth
---------------------------
>> Device to host transfer time: 0.067872 ms
```

With coalesced global memory accesses and a warp-level sum reduction (with block size $32$) we achieve $48.9$% of peak memory bandwidth and $27.4$ GFLOPS!

This is closer to cuBLAS, but we can do even more.


## Kernel 3 - Block-level reductions

In the previous kernel, we were constrained to only $32$ threads per block, because we relied on a warp-level reduction alone.
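The worklog uses `warpReduceSum` (previous kernel) and `blockReduceSum` (this kernel) without showing their definitions. For reference, here is one common way such helpers are written, using warp shuffle intrinsics plus a small shared-memory buffer. Treat this as a sketch of the general shape: the signatures follow the call sites in the kernels, but the actual implementations in the repository may differ in details.

```cpp
// Sketch of the reduction helpers (assumed shapes, not the repo's exact code).
// warpReduceSum sums a value across the 32 lanes of a warp with shuffles;
// blockReduceSum combines the per-warp results through shared memory and
// leaves the block-wide total in smem[0].
__device__ __forceinline__ float warpReduceSum(float val) {
    for (int offset = warpSize / 2; offset > 0; offset /= 2) {
        val += __shfl_down_sync(0xffffffff, val, offset);
    }
    return val;  // lane 0 holds the warp's sum
}

__device__ __forceinline__ void blockReduceSum(float val, float *smem,
                                               int tid, int blockSize) {
    // 1. reduce within each warp
    val = warpReduceSum(val);

    // 2. lane 0 of every warp writes its partial result to shared memory
    if (tid % warpSize == 0) {
        smem[tid / warpSize] = val;
    }
    __syncthreads();

    // 3. the first warp reduces the per-warp partial results
    int num_warps = (blockSize + warpSize - 1) / warpSize;
    if (tid < warpSize) {
        float v = (tid < num_warps) ? smem[tid] : 0.0f;
        v = warpReduceSum(v);
        if (tid == 0) smem[0] = v;
    }
    __syncthreads();
}
```

With this shape, `blockReduceSum` leaves the block-wide total in `smem[0]`, which matches how the kernel shown next reads it.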
In this kernel, in order to have more threads in one block (and thus do more work per row), we perform a block-level reduction AFTER the warp-level reduction.

The idea is simple:

Consider a block with $2$ warps, making the total number of threads $64$. Now, let's say we perform a sum reduction within each of the two warps and store each warp's result in the first thread of that warp.

Thread $0$ will hold the value of the sum reduction performed on warp $1$. Similarly, thread $32$ will hold the value of the sum reduction performed on warp $2$. We now have two values that need to be summed. Since the values to be reduced come from different warps, this type of reduction is called a **block-level reduction**.

Essentially, we sum the values held by thread $0$ and thread $32$, and store the result in the first address of **shared memory**. Then, thread $0$ alone reads the first address of shared memory and writes the final reduced result to the corresponding index of the output vector. The code for this looks like:

```cpp
extern __shared__ float smem[];

int bid = blockIdx.x;
if (bid >= M) return;

int tid = threadIdx.x;
// each thread calculates its own partial output
float partial_sum = 0.f;
for (int col = tid; col < N; col += blockDim.x) {
    partial_sum += matd[bid * N + col] * vecd[col];
}

blockReduceSum(partial_sum, smem, tid, blockDim.x);

// only first thread writes the output to global memory
if (tid == 0) {
    float sum = smem[0];
    resd[bid] = sum;
}
```

The utility function `blockReduceSum` performs a block-level sum reduction on the partial sums computed by the threads.

Running this kernel results in:

```
>> GPU allocation time: 5.870848 ms
>> Host to device transfer time: 27.807808 ms
------- Coalesced warp-block sgmev kernel ---------
>> Execution time: 1.607616 ms
>> Achieved (GFLOPS): 41.744339
>> Theoretical max (GFLOPS): 1911.040039
>> Maximum memory bandwidth: 112.127998 GB/s
>> Achieves 2.184378 % of peak GFLOPS
>> Achieves 74.485626 % of peak Memory Bandwidth
---------------------------
>> Device to host transfer time: 0.059232 ms
```

We are very close to cuBLAS with this kernel now! It achieves $74.48$% of peak memory bandwidth and $41.74$ GFLOPS.

Can we do more? Let's see :)


## Kernel 4 - Vectorized loads

We are already accessing global memory in a coalesced manner while loading the elements of the matrix `matd` and the vector `vecd`. But there is something more we can do, and it is called **vectorized loads**.

In essence, vectorized loads (and writes) can improve the memory bandwidth performance of our kernel. What this means is: instead of loading the elements `matd[i]`, `matd[i + 1]`, `matd[i + 2]`, and `matd[i + 3]` with four load instructions, we load all $4$ floating-point numbers with a single load instruction.

CUDA provides a vector type called `float4` that holds $4$ floats (`x`, `y`, `z`, and `w`), i.e. $16$ bytes of data at FP32 precision.
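One caveat worth keeping in mind (an addition, not from the original worklog): reinterpreting a `float*` as a `float4*`, as done next, is only valid when the address is $16$-byte aligned. Pointers returned by `cudaMalloc` are aligned well beyond that, and each row start stays aligned as long as $N$ is a multiple of $4$, so the kernels here are fine. The small compile-time check below just documents the size and alignment of `float4` that make a single wide (128-bit) load possible.

```cpp
#include <cuda_runtime.h>

// float4 is 16 bytes and 16-byte aligned; that alignment is what allows the
// compiler to turn one access into a single 128-bit load instead of four
// scalar loads. Purely a documentation aid; compiles as part of any .cu file.
static_assert(sizeof(float4) == 16, "float4 should occupy 16 bytes");
static_assert(alignof(float4) == 16, "float4 should be 16-byte aligned");
```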
To use vectorized loads, we need to cast the corresponding matrix row and the input vector to `float4*` so that the compiler knows we will be loading them as `float4` elements. The code snippet that does this is:

```cpp
float4* mat_row = reinterpret_cast<float4*>(matd + bid * N);
float4* vec = reinterpret_cast<float4*>(vecd);
```

Note that we do not recast the entire matrix; we only cast the particular row (the one we are working with) of the matrix to `float4*`. Now, when we write something like:

```cpp
float4 element = vec[i];
```

we get access to $4$ **consecutive** floating-point numbers of `vecd`, which can be accessed like:

```cpp
printf("1st float: %f\n", element.x);
printf("2nd float: %f\n", element.y);
printf("3rd float: %f\n", element.z);
printf("4th float: %f\n", element.w);
```

To calculate the partial sum now, all we do is:

```cpp
float4 matval = mat_row[col];
float4 vecval = vec[col];

partial_sum += (matval.x * vecval.x +
                matval.y * vecval.y +
                matval.z * vecval.z +
                matval.w * vecval.w);
```

Note that `col` now indexes `float4` elements, so the loop runs over $N / 4$ values instead of $N$.

One obvious problem to note here is that the number of columns $N$ of the matrix (and the size of the input vector) must be divisible by $4$ to use vectorized loads. This is solvable by "padding" the matrix columns and the vector with additional zeros whenever $N$ is not divisible by $4$, so that there are no out-of-bounds memory accesses. We won't be doing that in this kernel for now since we are working with $M = 4096$ and $N = 8192$, where both sizes are divisible by $4$.

Running the vectorized loads kernel results in:

```
>> GPU allocation time: 5.848320 ms
>> Host to device transfer time: 24.041409 ms
------- Vectorized sgmev kernel ---------
>> Execution time: 1.356800 ms
>> Achieved (GFLOPS): 49.461132
>> Theoretical max (GFLOPS): 1911.040039
>> Maximum memory bandwidth: 112.127998 GB/s
>> Achieves 2.588179 % of peak GFLOPS
>> Achieves 88.254936 % of peak Memory Bandwidth
---------------------------
>> Device to host transfer time: 0.074464 ms
```

With this kernel, **we achieve $88.25$% of peak memory bandwidth and $49.46$ GFLOPS, which is better than cuBLAS!**

We can plot the performance of the vectorized-loads kernel against cuBLAS for different matrix and vector sizes. Here's what our custom kernel looks like :D

![Vectorized vs cuBLAS](https://raw.githubusercontent.com/Maharshi-Pandya/cudacodes/refs/heads/master/matvec/media/benchmark_results.png)


## Conclusion

In this worklog, we iteratively optimized the SGEMV operation, starting from benchmarking cuBLAS and then writing a custom CUDA kernel whose performance is comparable to cuBLAS, if not better! While cuBLAS achieves $43.7$ GFLOPS and $78.1$% of the peak memory bandwidth, our custom kernel achieves $49.5$ GFLOPS and $88.3$% of the peak memory bandwidth.

The full code is available on the GitHub repository: [Optimizing SGEMV in CUDA](https://github.com/Maharshi-Pandya/cudacodes/tree/master/matvec)

Also, if you liked reading this blog/worklog then you can follow [me on X (formerly Twitter)](https://x.com/mrsiipa) for real-time updates about ML, CUDA, and my life in general.

Thank you for reading!
427 | -------------------------------------------------------------------------------- /flash-attention/triton_attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | from torch.nn import functional as F 5 | from torch.utils.cpp_extension import load 6 | 7 | import triton 8 | import triton.language as tl 9 | 10 | import os 11 | 12 | # Set CUDA architecture for RTX 4090 13 | os.environ["TORCH_CUDA_ARCH_LIST"] = "8.9" 14 | 15 | # Load CUDA kernel 16 | smolattn = load(name='smolattn', sources=['fa1/build.cpp', 'fa1/flash-attn-1.cu'], extra_cuda_cflags=['-O3']) 17 | 18 | 19 | @triton.jit 20 | def _attn_fwd_inner( 21 | O_block, 22 | l_i, 23 | m_i, 24 | Q_block, 25 | K_block_ptr, 26 | V_block_ptr, 27 | block_index_q, 28 | softmax_scale, 29 | BLOCK_SIZE_Q: tl.constexpr, 30 | BLOCK_SIZE_KV: tl.constexpr, 31 | STAGE: tl.constexpr, 32 | offs_q: tl.constexpr, 33 | offs_kv: tl.constexpr, 34 | SEQ_LEN: tl.constexpr, 35 | ): 36 | # range of values handled by this stage 37 | if STAGE == 1: 38 | # From 0 to the left of the diagonal 39 | lo, hi = 0, block_index_q * BLOCK_SIZE_Q 40 | elif STAGE == 2: 41 | # Used only for the block in which there is transition between non-masked and masked keys 42 | lo, hi = block_index_q * BLOCK_SIZE_Q, (block_index_q + 1) * BLOCK_SIZE_Q 43 | lo = tl.multiple_of(lo, BLOCK_SIZE_Q) 44 | else: 45 | # Only used for non-causal attention 46 | lo, hi = 0, SEQ_LEN 47 | 48 | K_block_ptr = tl.advance(K_block_ptr, (0, lo)) 49 | V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) 50 | 51 | # loop over k, v and update accumulator 52 | for start_kv in range(lo, hi, BLOCK_SIZE_KV): 53 | # Just let the compiler know that start_n is a multiple of BLOCK_N, so the compiler can do optimizations 54 | start_kv = tl.multiple_of(start_kv, BLOCK_SIZE_KV) 55 | 56 | # -- compute qk ---- 57 | K_block = tl.load(K_block_ptr) 58 | QK_block = tl.dot(Q_block, K_block) 59 | 60 | if STAGE == 2: 61 | mask = offs_q[:, None] >= (start_kv + offs_kv[None, :]) 62 | QK_block = QK_block * softmax_scale + tl.where(mask, 0, -1.0e6) 63 | m_ij = tl.maximum(m_i, tl.max(QK_block, 1)) 64 | QK_block -= m_ij[:, None] 65 | else: 66 | # Compute the maximum value of qk or keep the old max value 67 | m_ij = tl.maximum(m_i, tl.max(QK_block, 1) * softmax_scale) 68 | QK_block = QK_block * softmax_scale - m_ij[:, None] 69 | 70 | # Compute the exponential of each dot product, so now we are computing exp(qk_ij - m_ij) 71 | P_block = tl.math.exp(QK_block) 72 | # Compute the sum by rows of the attention scores 73 | l_ij = tl.sum(P_block, 1) 74 | 75 | # This is the correction factor for the previous l_i 76 | alpha = tl.math.exp(m_i - m_ij) 77 | # Apply the correction factor to the previous l_i and add the new l_ij 78 | l_i = l_i * alpha + l_ij 79 | 80 | V_block = tl.load(V_block_ptr) 81 | P_block = P_block.to(tl.float32) 82 | # This computes the following: O_new = P x V + O_old * alpha 83 | O_block = O_block * alpha[:, None] 84 | O_block = tl.dot(P_block, V_block, O_block) 85 | 86 | m_i = m_ij 87 | 88 | # Move to the next block of K and V 89 | V_block_ptr = tl.advance(V_block_ptr, (BLOCK_SIZE_KV, 0)) 90 | K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_SIZE_KV)) 91 | return O_block, l_i, m_i 92 | 93 | 94 | @triton.autotune( 95 | [ 96 | triton.Config( 97 | {"BLOCK_SIZE_Q": BLOCK_SIZE_Q, "BLOCK_SIZE_KV": BLOCK_SIZE_KV}, 98 | num_stages=num_stages, 99 | num_warps=num_warps, 100 | ) 101 | for BLOCK_SIZE_Q in [64, 128] 102 | for BLOCK_SIZE_KV in [32, 64] 103 | for 
num_stages in ([3, 4, 7]) 104 | for num_warps in [2, 4] 105 | ], 106 | key=["SEQ_LEN", "HEAD_DIM"], 107 | ) 108 | @triton.jit 109 | def _attn_fwd( 110 | Q, # BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM 111 | K, # BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM 112 | V, # BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM 113 | softmax_scale, 114 | M, # BATCH_SIZE, NUM_HEADS, SEQ_LEN 115 | O, # BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM 116 | stride_Q_batch, 117 | stride_Q_head, 118 | stride_Q_seq, 119 | stride_Q_dim, 120 | stride_K_batch, 121 | stride_K_head, 122 | stride_K_seq, 123 | stride_K_dim, 124 | stride_V_batch, 125 | stride_V_head, 126 | stride_V_seq, 127 | stride_V_dim, 128 | stride_O_batch, 129 | stride_O_head, 130 | stride_O_seq, 131 | stride_O_dim, 132 | BATCH_SIZE, 133 | NUM_HEADS: tl.constexpr, 134 | SEQ_LEN: tl.constexpr, 135 | HEAD_DIM: tl.constexpr, 136 | BLOCK_SIZE_Q: tl.constexpr, 137 | BLOCK_SIZE_KV: tl.constexpr, 138 | STAGE: tl.constexpr, 139 | ): 140 | tl.static_assert(BLOCK_SIZE_KV <= HEAD_DIM) 141 | 142 | # This indicate which block in the sequence length to process 143 | block_index_q = tl.program_id(0) 144 | 145 | # This indicates which head and batch to process. Each program is associated with a single head of a single batch 146 | index_batch_head = tl.program_id(1) 147 | # This indicate which batch this program is associated with (each batch has NUM_HEADS heads) 148 | index_batch = index_batch_head // NUM_HEADS 149 | # This indicate the position of the head in the batch 150 | index_head = index_batch_head % NUM_HEADS 151 | 152 | # This allows to get the (N_CTX, HEAD_DIM) block in the Q, K, V by selecting indexing it by batch and head 153 | qvk_offset = ( 154 | index_batch.to(tl.int64) * stride_Q_batch 155 | + index_head.to(tl.int64) * stride_Q_head 156 | ) 157 | 158 | Q_block_ptr = tl.make_block_ptr( 159 | base=Q + qvk_offset, 160 | shape=(SEQ_LEN, HEAD_DIM), 161 | strides=(stride_Q_seq, stride_Q_dim), 162 | offsets=(block_index_q * BLOCK_SIZE_Q, 0), 163 | block_shape=(BLOCK_SIZE_Q, HEAD_DIM), 164 | order=(1, 0), 165 | ) 166 | 167 | V_block_ptr = tl.make_block_ptr( 168 | base=V + qvk_offset, 169 | shape=(SEQ_LEN, HEAD_DIM), 170 | strides=(stride_V_seq, stride_V_dim), 171 | offsets=(0, 0), 172 | block_shape=(BLOCK_SIZE_KV, HEAD_DIM), 173 | order=(1, 0), 174 | ) 175 | 176 | K_block_ptr = tl.make_block_ptr( 177 | base=K + qvk_offset, 178 | shape=(HEAD_DIM, SEQ_LEN), 179 | strides=( 180 | stride_K_dim, 181 | stride_K_seq, 182 | ), # We invert the strides w.r.t Q, so we transpose the matrix 183 | offsets=(0, 0), 184 | block_shape=(HEAD_DIM, BLOCK_SIZE_KV), 185 | order=(0, 1), 186 | ) 187 | 188 | O_block_ptr = tl.make_block_ptr( 189 | base=O + qvk_offset, 190 | shape=(SEQ_LEN, HEAD_DIM), 191 | strides=(stride_O_seq, stride_O_dim), 192 | offsets=(block_index_q * BLOCK_SIZE_Q, 0), 193 | block_shape=(BLOCK_SIZE_Q, HEAD_DIM), 194 | order=(1, 0), 195 | ) 196 | 197 | # offs_q: the offsets for the tokens in the Q to process 198 | offs_q = block_index_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q) 199 | # offs_kv: the offsets for the tokens in the K and V sequence to process 200 | offs_kv = tl.arange(0, BLOCK_SIZE_KV) 201 | 202 | # m_i: the running maximum. We have one for each query 203 | m_i = tl.zeros([BLOCK_SIZE_Q], dtype=tl.float32) - float("inf") 204 | # l_i: the running sum. 
We have one for each query (as we sum the attention scores by rows) 205 | l_i = tl.zeros([BLOCK_SIZE_Q], dtype=tl.float32) + 1.0 206 | # acc: the accumulator for the output, which is a group of rows of the O matrix 207 | O_block = tl.zeros([BLOCK_SIZE_Q, HEAD_DIM], dtype=tl.float32) 208 | 209 | # load the blocks of Q: it will stay in SRAM throughout 210 | Q_block = tl.load(Q_block_ptr) 211 | 212 | # Stage: 3 if causal, else 1 213 | 214 | if STAGE == 1 or STAGE == 3: 215 | # This step runs for non-causal attention or for the blocks to the left of the diagonal in the causal attention 216 | O_block, l_i, m_i = _attn_fwd_inner( 217 | O_block, 218 | l_i, 219 | m_i, 220 | Q_block, 221 | K_block_ptr, 222 | V_block_ptr, 223 | block_index_q, 224 | softmax_scale, 225 | BLOCK_SIZE_Q, 226 | BLOCK_SIZE_KV, 227 | 4 - STAGE, 228 | offs_q, 229 | offs_kv, 230 | SEQ_LEN, 231 | ) 232 | 233 | if STAGE == 3: 234 | # This step runs for the blocks to the right of the diagonal in the causal attention 235 | O_block, l_i, m_i = _attn_fwd_inner( 236 | O_block, 237 | l_i, 238 | m_i, 239 | Q_block, 240 | K_block_ptr, 241 | V_block_ptr, 242 | block_index_q, 243 | softmax_scale, 244 | BLOCK_SIZE_Q, 245 | BLOCK_SIZE_KV, 246 | 2, 247 | offs_q, 248 | offs_kv, 249 | SEQ_LEN, 250 | ) 251 | # epilogue 252 | m_i += tl.math.log( 253 | l_i 254 | ) # This is needed to compute the logsumexp for the backwards pass 255 | O_block = O_block / l_i[:, None] 256 | m_ptrs = M + index_batch_head * SEQ_LEN + offs_q 257 | tl.store(m_ptrs, m_i) 258 | tl.store(O_block_ptr, O_block.to(O.type.element_ty)) 259 | 260 | 261 | @triton.jit 262 | def _attn_bwd_preprocess( 263 | O, 264 | dO, 265 | D, 266 | SEQ_LEN, 267 | BLOCK_SIZE_Q: tl.constexpr, 268 | HEAD_DIM: tl.constexpr, 269 | ): 270 | block_index_q = tl.program_id(0) 271 | offs_q = block_index_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q) 272 | index_batch_head = tl.program_id(1) 273 | offs_dim = tl.arange(0, HEAD_DIM) 274 | # Load a single block of BLOCK_SIZE_Q rows of O 275 | O_block = tl.load( 276 | O 277 | + index_batch_head * HEAD_DIM * SEQ_LEN 278 | + offs_q[:, None] * HEAD_DIM 279 | + offs_dim[None, :] 280 | ) 281 | # Load a single block of BLOCK_SIZE_Q rows of dO 282 | dO_block = tl.load( 283 | dO 284 | + index_batch_head * HEAD_DIM * SEQ_LEN 285 | + offs_q[:, None] * HEAD_DIM 286 | + offs_dim[None, :] 287 | ).to(tl.float32) 288 | # Compute the D block 289 | D_block = tl.sum(dO_block * O_block, axis=1) # Shape: (BLOCK_SIZE_Q,) 290 | # Store the D block 291 | D_block_ptrs = D + index_batch_head * SEQ_LEN + offs_q 292 | tl.store(D_block_ptrs, D_block) 293 | 294 | 295 | @triton.jit 296 | def _attn_bwd_dq( 297 | Q, 298 | K, 299 | V, 300 | softmax_scale, 301 | dO, 302 | dQ, 303 | dK, 304 | dV, 305 | M, 306 | D, 307 | stride_batch, 308 | stride_head, 309 | stride_seq, 310 | stride_dim, 311 | NUM_HEADS, 312 | SEQ_LEN, 313 | BLOCK_Q: tl.constexpr, 314 | BLOCK_KV: tl.constexpr, 315 | HEAD_DIM: tl.constexpr, 316 | STAGE: tl.constexpr, 317 | ): 318 | index_batch_head = tl.program_id(2) 319 | index_batch = index_batch_head // NUM_HEADS 320 | index_head = index_batch_head % NUM_HEADS 321 | offset_batch_head = (stride_batch * index_batch + stride_head * index_head).to( 322 | tl.int64 323 | ) 324 | # This is the offset that allows us to select the right sequence given the batch and head. 
325 | offset_batch_head_seq = (index_batch_head * SEQ_LEN).to(tl.int64) 326 | 327 | # Make sure the pointers are in the right place w.r.t batch and head 328 | # The reason we don't access the blocks through make_block_ptr is because we need to use the range of offsets to apply the masking 329 | Q += offset_batch_head 330 | K += offset_batch_head 331 | V += offset_batch_head 332 | dO += offset_batch_head 333 | dQ += offset_batch_head 334 | dK += offset_batch_head 335 | dV += offset_batch_head 336 | 337 | # Make sure the pointers are in the right place w.r.t batch, head and sequence 338 | M += offset_batch_head_seq 339 | D += offset_batch_head_seq 340 | 341 | # load scales 342 | offs_dim = tl.arange(0, HEAD_DIM) 343 | 344 | index_block_kv = tl.program_id(0) 345 | 346 | start_q = index_block_kv * BLOCK_Q 347 | offs_q = start_q + tl.arange(0, BLOCK_Q) 348 | 349 | Q_block = tl.load(Q + offs_q[:, None] * stride_seq + offs_dim[None, :] * stride_dim) 350 | dQ_block = tl.zeros([BLOCK_Q, HEAD_DIM], dtype=tl.float32) 351 | dO_block = tl.load( 352 | dO + offs_q[:, None] * stride_seq + offs_dim[None, :] * stride_dim 353 | ) 354 | 355 | M_block = tl.load(M + offs_q) 356 | M_block = M_block[:, None] 357 | 358 | offs_kv = tl.arange(0, BLOCK_KV) 359 | 360 | # We access the K and V as transposed blocks 361 | kT_ptrs = K + offs_kv[None, :] * stride_seq + offs_dim[:, None] * stride_dim 362 | vT_ptrs = V + offs_kv[None, :] * stride_seq + offs_dim[:, None] * stride_dim 363 | 364 | Di = tl.load(D + offs_q) 365 | 366 | curr_kv = 0 367 | num_steps = SEQ_LEN // BLOCK_KV 368 | for blk_idx in range(num_steps): 369 | K_T_block = tl.load(kT_ptrs) 370 | V_T_block = tl.load(vT_ptrs) 371 | QK_block = softmax_scale * tl.dot(Q_block, K_T_block) 372 | P_block = tl.math.exp(QK_block - M_block) 373 | 374 | if STAGE == 3: 375 | # Autoregressive masking. 376 | offs_kv = curr_kv + tl.arange(0, BLOCK_KV) 377 | mask_block = offs_q[:, None] >= offs_kv[None, :] 378 | P_block = tl.where(mask_block, P_block, 0.0) 379 | 380 | # Compute dP and dS. 381 | dP_block = tl.dot(dO_block, V_T_block).to(tl.float32) 382 | dS_block = P_block * (dP_block - Di[:, None]) 383 | dS_block = dS_block.to(tl.float16) 384 | # Compute dQ. 385 | # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. 386 | dQ_block += softmax_scale * tl.dot(dS_block, tl.trans(K_T_block)) 387 | # Increment pointers. 388 | curr_kv += BLOCK_KV 389 | kT_ptrs += BLOCK_KV * stride_seq 390 | vT_ptrs += BLOCK_KV * stride_seq 391 | 392 | dQ_block_ptrs = dQ + offs_q[:, None] * stride_seq + offs_dim[None, :] * stride_dim 393 | tl.store(dQ_block_ptrs, dQ_block) 394 | 395 | 396 | @triton.jit 397 | def _attn_bwd_dk_dv( 398 | Q, 399 | K, 400 | V, 401 | softmax_scale, 402 | dO, 403 | dQ, 404 | dK, 405 | dV, 406 | M, 407 | D, 408 | stride_batch, 409 | stride_head, 410 | stride_seq, 411 | stride_dim, 412 | NUM_HEADS, 413 | SEQ_LEN, 414 | BLOCK_Q: tl.constexpr, 415 | BLOCK_KV: tl.constexpr, 416 | HEAD_DIM: tl.constexpr, 417 | STAGE: tl.constexpr, 418 | ): 419 | index_batch_head = tl.program_id(2) 420 | index_batch = index_batch_head // NUM_HEADS 421 | index_head = index_batch_head % NUM_HEADS 422 | offset_batch_head = (stride_batch * index_batch + stride_head * index_head).to( 423 | tl.int64 424 | ) 425 | # This is the offset that allows us to select the right sequence given the batch and head. 
426 | offset_batch_head_seq = (index_batch_head * SEQ_LEN).to(tl.int64) 427 | 428 | # Make sure the pointers are in the right place w.r.t batch and head 429 | # The reason we don't access the blocks through make_block_ptr is because we need to use the range of offsets to apply the masking 430 | Q += offset_batch_head 431 | K += offset_batch_head 432 | V += offset_batch_head 433 | dO += offset_batch_head 434 | dQ += offset_batch_head 435 | dK += offset_batch_head 436 | dV += offset_batch_head 437 | 438 | # Make sure the pointers are in the right place w.r.t batch, head and sequence 439 | M += offset_batch_head_seq 440 | D += offset_batch_head_seq 441 | 442 | # load scales 443 | offs_dim = tl.arange(0, HEAD_DIM) 444 | 445 | index_block_kv = tl.program_id(0) 446 | start_kv = index_block_kv * BLOCK_KV 447 | 448 | offs_kv = start_kv + tl.arange(0, BLOCK_KV) 449 | 450 | dV_block = tl.zeros([BLOCK_KV, HEAD_DIM], dtype=tl.float32) 451 | dK_block = tl.zeros([BLOCK_KV, HEAD_DIM], dtype=tl.float32) 452 | 453 | # load K and V: they stay in SRAM throughout the inner loop. 454 | K_block = tl.load( 455 | K + offs_kv[:, None] * stride_seq + offs_dim[None, :] * stride_dim 456 | ) # Shape: (BLOCK_KV1, HEAD_DIM) 457 | V_block = tl.load( 458 | V + offs_kv[:, None] * stride_seq + offs_dim[None, :] * stride_dim 459 | ) # Shape: (BLOCK_KV1, HEAD_DIM) 460 | 461 | offs_q = tl.arange(0, BLOCK_Q) 462 | 463 | # We access the Q as a transposed array, so that's why we treat offs_q as a column vector ans offs_dim as a row vector 464 | # This is equivalent to doing: 465 | # q_ptrs = Q + offs_q[:, None] * stride_seq + offs_dim[None, :] * stride_dim 466 | # qT_ptrs = tl.trans(q_ptrs) 467 | # We point to the first BLOCK_Q rows of Q for both the qT and dO pointers, inside the for loop we will move forward by BLOCK_Q rows at each iteration. 468 | qT_ptrs = Q + offs_q[None, :] * stride_seq + offs_dim[:, None] * stride_dim 469 | dO_ptrs = dO + offs_q[:, None] * stride_seq + offs_dim[None, :] * stride_dim 470 | 471 | # Iterates over the sequence dimension of the query 472 | curr_q = 0 473 | num_steps = SEQ_LEN // BLOCK_Q 474 | for blk_idx in range(num_steps): 475 | # Load a block of Q 476 | qT_block = tl.load(qT_ptrs) 477 | # Load the logsumexp values for the queries in the current block 478 | offs_q = curr_q + tl.arange(0, BLOCK_Q) 479 | m = tl.load(M + offs_q) 480 | 481 | # This gives us (QK^T)^T = (K^T)^T(Q^T) = K(Q^T) = P^T 482 | QK_T_block = softmax_scale * tl.dot(K_block, qT_block) 483 | # We apply the softmax by using the logsumexp trick 484 | P_T_block = tl.math.exp(QK_T_block - m[None, :]) 485 | 486 | if STAGE == 3: 487 | # Autoregressive masking. 488 | # mask is True for all values that DO NOT NEED TO BE MASKED 489 | mask_block = ( 490 | offs_q[None, :] >= offs_kv[:, None] 491 | ) # Shape: (BLOCK_KV1, BLOCK_Q1) 492 | # Replace all the masked values with 0. 
493 | # In this case we do not need to mask with -Inf before applying the softmax since we already computed the normalization factors (stored in "m") 494 | P_T_block = tl.where(mask_block, P_T_block, 0.0) 495 | 496 | dO_block = tl.load(dO_ptrs) 497 | # According to the formula: dV_new = dV_old + P^T x dO, where x is the matrix multiplication 498 | dV_block += tl.dot(P_T_block.to(tl.float16), dO_block) 499 | 500 | # Delta = rowsum(O * dO) where * is the element-wise product 501 | Di = tl.load(D + offs_q) 502 | 503 | # dP = dO x V^T, so dP^T = V x dO^T 504 | # Where x is the matrix multiplication 505 | dpT_block = tl.dot(V_block, tl.trans(dO_block)).to(tl.float32) 506 | 507 | # We know that dS = P * (dP - Delta), so dS^T = P^T * (dP^T - Delta^T) 508 | 509 | dS_T_block = P_T_block * (dpT_block - Di[None, :]) 510 | dS_T_block = dS_T_block.to(tl.float16) 511 | 512 | # According to the formula on the paper: dK_new = dK_old + dS^T x Q 513 | dK_block += softmax_scale * tl.dot(dS_T_block, tl.trans(qT_block)) 514 | # Increment pointers. 515 | curr_q += BLOCK_Q 516 | qT_ptrs += BLOCK_Q * stride_seq 517 | dO_ptrs += BLOCK_Q * stride_seq 518 | 519 | # Write back dV. 520 | dV_block_ptrs = dV + offs_kv[:, None] * stride_seq + offs_dim[None, :] * stride_dim 521 | tl.store(dV_block_ptrs, dV_block) 522 | 523 | # Write back dK. 524 | dK_block_ptrs = dK + offs_kv[:, None] * stride_seq + offs_dim[None, :] * stride_dim 525 | tl.store(dK_block_ptrs, dK_block) 526 | 527 | 528 | class TritonAttention(torch.autograd.Function): 529 | 530 | @staticmethod 531 | def forward(ctx, Q, K, V, causal, softmax_scale): 532 | HEAD_DIM_Q, HEAD_DIM_K = Q.shape[-1], K.shape[-1] 533 | HEAD_DIM_V = V.shape[-1] 534 | 535 | BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM = Q.shape 536 | 537 | assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V 538 | 539 | O = torch.empty_like(Q) 540 | stage = 3 if causal else 1 541 | 542 | grid = lambda args: ( 543 | triton.cdiv(SEQ_LEN, args["BLOCK_SIZE_Q"]), 544 | BATCH_SIZE * NUM_HEADS, 545 | 1, 546 | ) 547 | 548 | # M is the logsumexp for the backward pass, one for each query 549 | M = torch.empty( 550 | (BATCH_SIZE, NUM_HEADS, SEQ_LEN), device=Q.device, dtype=torch.float32 551 | ) 552 | 553 | _attn_fwd[grid]( 554 | Q=Q, 555 | K=K, 556 | V=V, 557 | softmax_scale=softmax_scale, 558 | M=M, 559 | O=O, 560 | stride_Q_batch=Q.stride(0), 561 | stride_Q_head=Q.stride(1), 562 | stride_Q_seq=Q.stride(2), 563 | stride_Q_dim=Q.stride(3), 564 | stride_K_batch=K.stride(0), 565 | stride_K_head=K.stride(1), 566 | stride_K_seq=K.stride(2), 567 | stride_K_dim=K.stride(3), 568 | stride_V_batch=V.stride(0), 569 | stride_V_head=V.stride(1), 570 | stride_V_seq=V.stride(2), 571 | stride_V_dim=V.stride(3), 572 | stride_O_batch=O.stride(0), 573 | stride_O_head=O.stride(1), 574 | stride_O_seq=O.stride(2), 575 | stride_O_dim=O.stride(3), 576 | BATCH_SIZE=Q.shape[0], 577 | NUM_HEADS=Q.shape[1], 578 | SEQ_LEN=Q.shape[2], 579 | HEAD_DIM=HEAD_DIM_K, 580 | STAGE=stage, 581 | ) 582 | 583 | ctx.save_for_backward(Q, K, V, O, M) 584 | ctx.grid = grid 585 | ctx.softmax_scale = softmax_scale 586 | ctx.HEAD_DIM = HEAD_DIM_K 587 | ctx.causal = causal 588 | return O 589 | 590 | @staticmethod 591 | def backward(ctx, dO): 592 | Q, K, V, O, M = ctx.saved_tensors 593 | 594 | assert dO.is_contiguous() 595 | assert Q.stride() == K.stride() == V.stride() == O.stride() == dO.stride() 596 | dQ = torch.empty_like(Q) 597 | dK = torch.empty_like(K) 598 | dV = torch.empty_like(V) 599 | 600 | BATCH_SIZE, NUM_HEADS, SEQ_LEN = Q.shape[:3] 601 | 
NUM_WARPS, NUM_STAGES = 4, 3 602 | BLOCK_SIZE_MICRO, BLOCK_SIZE_MACRO = 32, 128 603 | 604 | preprocess_grid = (SEQ_LEN // BLOCK_SIZE_MACRO, BATCH_SIZE * NUM_HEADS) 605 | D = torch.empty_like(M) # Shape: (BATCH_SIZE, NUM_HEADS, SEQ_LEN) 606 | 607 | # Compute all the elements Di 608 | _attn_bwd_preprocess[preprocess_grid]( 609 | O=O, 610 | dO=dO, 611 | D=D, 612 | SEQ_LEN=SEQ_LEN, 613 | BLOCK_SIZE_Q=BLOCK_SIZE_MACRO, 614 | HEAD_DIM=ctx.HEAD_DIM, 615 | ) 616 | 617 | grid = (SEQ_LEN // BLOCK_SIZE_MACRO, 1, BATCH_SIZE * NUM_HEADS) 618 | 619 | stage = 3 if ctx.causal else 1 620 | 621 | # Fix KV and iterate through all the Q blocks 622 | _attn_bwd_dk_dv[grid]( 623 | Q=Q, 624 | K=K, 625 | V=V, 626 | softmax_scale=ctx.softmax_scale, 627 | dO=dO, 628 | dQ=dQ, 629 | dK=dK, 630 | dV=dV, 631 | M=M, 632 | D=D, 633 | stride_batch=Q.stride(0), 634 | stride_head=Q.stride(1), 635 | stride_seq=Q.stride(2), 636 | stride_dim=Q.stride(3), 637 | NUM_HEADS=NUM_HEADS, 638 | SEQ_LEN=SEQ_LEN, 639 | BLOCK_Q=BLOCK_SIZE_MICRO, 640 | BLOCK_KV=BLOCK_SIZE_MACRO, 641 | HEAD_DIM=ctx.HEAD_DIM, 642 | STAGE=stage, 643 | num_warps=NUM_WARPS, 644 | num_stages=NUM_STAGES, 645 | ) 646 | 647 | # Fix Q and iterate through all the KV block 648 | _attn_bwd_dq[grid]( 649 | Q=Q, 650 | K=K, 651 | V=V, 652 | softmax_scale=ctx.softmax_scale, 653 | dO=dO, 654 | dQ=dQ, 655 | dK=dK, 656 | dV=dV, 657 | M=M, 658 | D=D, 659 | stride_batch=Q.stride(0), 660 | stride_head=Q.stride(1), 661 | stride_seq=Q.stride(2), 662 | stride_dim=Q.stride(3), 663 | NUM_HEADS=NUM_HEADS, 664 | SEQ_LEN=SEQ_LEN, 665 | BLOCK_Q=BLOCK_SIZE_MACRO, 666 | BLOCK_KV=BLOCK_SIZE_MICRO, 667 | HEAD_DIM=ctx.HEAD_DIM, 668 | STAGE=stage, 669 | num_warps=NUM_WARPS, 670 | num_stages=NUM_STAGES, 671 | ) 672 | 673 | return dQ, dK, dV, None, None 674 | 675 | 676 | def benchmark(func, *args): 677 | torch.cuda.synchronize() 678 | start = torch.cuda.Event(enable_timing=True) 679 | end = torch.cuda.Event(enable_timing=True) 680 | 681 | start.record() 682 | output = func(*args) 683 | end.record() 684 | 685 | torch.cuda.synchronize() 686 | return start.elapsed_time(end), output 687 | 688 | 689 | # Manual attention function 690 | def manual_attn(q, k, v): 691 | att = (q @ k.transpose(-2, -1) * (1.0 / math.sqrt(k.size(-1)))) 692 | att = F.softmax(att, dim=-1) 693 | y = att @ v 694 | return y 695 | 696 | 697 | def triton_forward(Q, K, V, causal, dtype=torch.float32): 698 | softmax_scale = 1 / (HEAD_DIM**0.5) 699 | tri_out = TritonAttention.apply(Q, K, V, causal, softmax_scale) 700 | return tri_out 701 | 702 | 703 | if __name__ == "__main__": 704 | BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM = 16, 8, 512, 64 705 | causal = False 706 | 707 | # Create test tensors 708 | Q = torch.randn((BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM), device="cuda") 709 | K = torch.randn((BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM), device="cuda") 710 | V = torch.randn((BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM), device="cuda") 711 | 712 | # Benchmark manual attention 713 | manual_time, manual_output = benchmark(manual_attn, Q, K, V) 714 | 715 | # Benchmark Triton attention 716 | triton_time, triton_output = benchmark(triton_forward, Q, K, V, causal) 717 | 718 | # Benchmark CUDA attention (smolattn) 719 | smolattn_time, smolattn_output = benchmark(smolattn.fa_forward, Q, K, V) 720 | 721 | # Check results match 722 | match_triton = torch.allclose(manual_output, triton_output, atol=1e-2) 723 | match_smolattn = torch.allclose(manual_output, smolattn_output, atol=1e-2) 724 | 725 | # Print results 726 | print(f"\nBatch size: 
{BATCH_SIZE}, Num heads: {NUM_HEADS}, Sequence length: {SEQ_LEN}, Head dims: {HEAD_DIM}\n") 727 | 728 | print(f">> Manual PyTorch Attention: {manual_time:.3f} ms") 729 | print(f">> Triton (flash attention): {triton_time:.3f} ms") 730 | print(f">> Smolattn vs. Triton speedup: {triton_time / smolattn_time:.5f}x") 731 | print(f">> Smolattn vs. Pytorch speedup: {manual_time / smolattn_time:.5f}x") 732 | print(f">> Smolattn results match: {match_smolattn}") 733 | print(f">> Triton results match: {match_triton}\n") 734 | --------------------------------------------------------------------------------