├── .gitignore
├── allreduce
│   ├── csrc
│   │   ├── my_sync
│   │   │   ├── sync_torch_bindings.cu
│   │   │   ├── format.bash
│   │   │   ├── compile_combined.bash
│   │   │   ├── sync_test.cu
│   │   │   └── sync.cuh
│   │   ├── my_allreduce
│   │   │   ├── format.bash
│   │   │   ├── compile_combined.bash
│   │   │   ├── allreduce_torch_bindings.cu
│   │   │   ├── allreduce_test.cu
│   │   │   └── allreduce.cuh
│   │   ├── add_one
│   │   │   ├── add_one.h
│   │   │   └── add_one.cu
│   │   ├── pybind.cpp
│   │   ├── mpi_cuda_helloworld
│   │   │   ├── split_kernel.cu
│   │   │   ├── compile_split.bash
│   │   │   ├── compile_combined.bash
│   │   │   ├── split_driver.cpp
│   │   │   └── combined.cu
│   │   ├── cuda_exp
│   │   │   ├── compile.bash
│   │   │   └── main.cu
│   │   └── reference_allreduce
│   │       ├── compile_combined.bash
│   │       ├── fast_allreduce.cu
│   │       ├── fast_allreduce_test.cu
│   │       └── fast_allreduce.cuh
│   ├── requirements.txt
│   ├── scripts
│   │   ├── glibc_version.bash
│   │   ├── cuda-memcheck-script.bash
│   │   └── strip_glibc_version.bash
│   ├── mpi_tutorials
│   │   ├── run_hello_world.bash
│   │   └── hello_world.c
│   ├── add_one.py
│   ├── .gitignore
│   ├── modal_runner.py
│   ├── setup.py
│   └── notes.md
├── assets
│   ├── awq.png
│   └── PagedAttention.png
├── awq
│   ├── matmul-performance.png
│   ├── matmul-performance.csv
│   ├── gemm_v1_benchmark.py
│   ├── torch_reference.py
│   ├── gemm_kernel_v1.py
│   └── custom_autotune.py
├── LICENSE
├── README.md
├── misc.ipynb
└── paged_attention_triton
    └── attention_kernel.py

/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | 
3 | 
--------------------------------------------------------------------------------
/allreduce/csrc/my_sync/sync_torch_bindings.cu:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/allreduce/requirements.txt:
--------------------------------------------------------------------------------
1 | # Copied from VLLM
2 | ninja
3 | torch == 2.1.2
--------------------------------------------------------------------------------
/assets/awq.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vedantroy/gpu_kernels/HEAD/assets/awq.png
--------------------------------------------------------------------------------
/allreduce/scripts/glibc_version.bash:
--------------------------------------------------------------------------------
1 | docker run --rm debian:bookworm-slim dpkg -l | grep libc6
--------------------------------------------------------------------------------
/assets/PagedAttention.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vedantroy/gpu_kernels/HEAD/assets/PagedAttention.png
--------------------------------------------------------------------------------
/allreduce/csrc/my_sync/format.bash:
--------------------------------------------------------------------------------
1 | find csrc/my_sync -iname "*.cu" -o -iname "*.cuh" | xargs clang-format -i
2 | 
--------------------------------------------------------------------------------
/awq/matmul-performance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vedantroy/gpu_kernels/HEAD/awq/matmul-performance.png
--------------------------------------------------------------------------------
/allreduce/csrc/my_allreduce/format.bash:
--------------------------------------------------------------------------------
1 | find csrc/my_allreduce -iname "*.cu" -o -iname "*.cuh" | xargs clang-format -i
2 | 
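The add_one sources shown next, together with csrc/pybind.cpp and setup.py (both appear later in this dump), form a minimal PyTorch CUDA extension named `cuda_experiments`. A minimal build-and-run sketch, assuming a CUDA-capable machine with the packages from requirements.txt installed (the commands mirror modal_runner.py and add_one.py):

    cd allreduce
    python3 setup.py install   # builds the cuda_experiments extension (see setup.py)
    python3 add_one.py         # asserts cuda_experiments.add_one(x) == x + 1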
-------------------------------------------------------------------------------- /allreduce/csrc/add_one/add_one.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | torch::Tensor add_one(torch::Tensor tensor); 5 | 6 | -------------------------------------------------------------------------------- /allreduce/mpi_tutorials/run_hello_world.bash: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env bash 2 | set -euxo pipefail 3 | mpicc -o hello_world.bin hello_world.c 4 | mpirun -np 8 ./hello_world.bin -------------------------------------------------------------------------------- /allreduce/add_one.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cuda_experiments 3 | 4 | tensor = torch.randn(5, device='cuda') 5 | result = cuda_experiments.add_one(tensor) 6 | torch.testing.assert_close(result, tensor + 1) 7 | -------------------------------------------------------------------------------- /allreduce/.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | *.egg-info/ 3 | dist/ 4 | build/ 5 | *.pyc 6 | __pycache__/ 7 | 8 | # Convention for compiled binaries 9 | *.bin 10 | 11 | *.o 12 | 13 | .ninja_deps 14 | .ninja_log 15 | 16 | *.log 17 | *.memcheck 18 | -------------------------------------------------------------------------------- /allreduce/scripts/cuda-memcheck-script.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | LOG=$1.$OMPI_COMM_WORLD_RANK 3 | cuda-memcheck --log-file $LOG.log --save $LOG.memcheck $* 4 | 5 | # use w/ 6 | # mpiexec -np 2 cuda-memcheck-script.bash ./myapp 7 | # mpiexec --allow-run-as-root -np 2 ./scripts/cuda-memcheck-script.bash ./sync_test.bin -------------------------------------------------------------------------------- /allreduce/csrc/pybind.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "add_one/add_one.h" 4 | #include "my_allreduce/allreduce_torch_bindings.cu" 5 | 6 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 7 | m.def("add_one", &add_one, "Add one to all elements of the tensor"); 8 | 9 | m.def("init_ar", &init_ar, "init_ar"); 10 | m.def("allreduce", &allreduce, "allreduce"); 11 | m.def("register_buffer", ®ister_buffer, "register_buffer"); 12 | } 13 | -------------------------------------------------------------------------------- /allreduce/csrc/mpi_cuda_helloworld/split_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | __global__ void writeData(int* data, int rank, int world_size) { 5 | if (threadIdx.x < world_size) { 6 | data[rank] = rank; 7 | printf("Rank %d writing to position %d\n", rank, rank); 8 | } 9 | } 10 | 11 | void runWriteData(int* data, int rank, int world_size) { 12 | writeData<<<1, world_size>>>(data, rank, world_size); 13 | } -------------------------------------------------------------------------------- /allreduce/csrc/mpi_cuda_helloworld/compile_split.bash: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env bash 2 | BASE_PATH=csrc/mpi_cuda_helloworld 3 | KERNEL_PATH=$BASE_PATH/split_kernel 4 | KERNEL_OBJ_FILE=$KERNEL_PATH.o 5 | DRIVER_OBJ_FILE=$BASE_PATH/split_driver.o 6 | nvcc -c $KERNEL_PATH.cu -o $KERNEL_OBJ_FILE -gencode=arch=compute_86,code=sm_86 7 | mpic++ -c $BASE_PATH/split_driver.cpp -o $DRIVER_OBJ_FILE -I/usr/local/cuda/include 8 | mpic++ $KERNEL_OBJ_FILE $DRIVER_OBJ_FILE -lcudart -L/usr/local/cuda/lib64/ -o split.bin -------------------------------------------------------------------------------- /allreduce/csrc/mpi_cuda_helloworld/compile_combined.bash: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env bash 2 | 3 | # Run `mpicc -show` to find include directories 4 | 5 | # https://docs.ccv.brown.edu/oscar/gpu-computing/mpi-cuda 6 | # https://anhnguyen.me/2013/12/how-to-mix-mpi-and-cuda-in-a-single-program/ 7 | # nvcc -I/usr/mpi/gcc/openmpi-1.4.6/include -L/usr/mpi/gcc/openmpi-1.4.6/lib64 -lmpi spaghetti.cu -o program 8 | script_dir=$(dirname "$0") 9 | nvcc -I/usr/lib/x86_64-linux-gnu/openmpi/include/openmpi -I/usr/lib/x86_64-linux-gnu/openmpi/include -L/usr/lib/x86_64-linux-gnu/openmpi/lib -lmpi "$script_dir/combined.cu" -o combined.bin -------------------------------------------------------------------------------- /awq/matmul-performance.csv: -------------------------------------------------------------------------------- 1 | M,cuda,Triton 2 | 256.0,67.7,83.9 3 | 384.0,75.3,102.4 4 | 512.0,79.5,113.4 5 | 640.0,82.2,104.3 6 | 768.0,83.6,124.6 7 | 896.0,83.9,111.2 8 | 1024.0,81.4,124.7 9 | 1152.0,79.1,124.2 10 | 1280.0,78.5,135.7 11 | 1408.0,77.5,123.7 12 | 1536.0,76.8,133.5 13 | 1664.0,76.6,127.4 14 | 1792.0,76.0,135.6 15 | 1920.0,75.8,131.1 16 | 2048.0,75.7,138.1 17 | 2176.0,75.5,131.8 18 | 2304.0,75.8,138.3 19 | 2432.0,75.7,133.3 20 | 2560.0,75.7,139.3 21 | 2688.0,75.8,134.3 22 | 2816.0,75.8,140.0 23 | 2944.0,75.8,135.7 24 | 3072.0,75.9,140.8 25 | 3200.0,76.0,136.7 26 | 3328.0,76.2,141.3 27 | 3456.0,76.2,137.4 28 | 3584.0,76.4,141.5 29 | 3712.0,76.3,138.1 30 | 3840.0,76.4,141.7 31 | 3968.0,76.4,138.8 32 | 4096.0,76.3,141.7 33 | -------------------------------------------------------------------------------- /allreduce/csrc/add_one/add_one.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include "add_one.h" 3 | 4 | __global__ void add_one_kernel(const float* in_data, float* out_data, int size) { 5 | int index = threadIdx.x + blockIdx.x * blockDim.x; 6 | if (index < size) { 7 | out_data[index] = in_data[index] + 1.0; 8 | } 9 | } 10 | 11 | torch::Tensor add_one(torch::Tensor input) { 12 | TORCH_CHECK(input.is_cuda() && input.type().scalarType() == at::ScalarType::Float, "input must be a CUDA float tensor"); 13 | auto output = torch::empty_like(input); 14 | const auto size = input.numel(); 15 | const int threads = 1024; 16 | const int blocks = (size + threads - 1) / threads; 17 | 18 | add_one_kernel<<>>(input.data_ptr(), output.data_ptr(), size); 19 | return output; 20 | } 21 | 22 | 23 | 24 | -------------------------------------------------------------------------------- /allreduce/mpi_tutorials/hello_world.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | int main(int argc, char** argv) { 5 | // Initialize the MPI environment 6 | MPI_Init(NULL, NULL); 7 | 8 | // Get the number of processes 9 | int world_size; 10 | MPI_Comm_size(MPI_COMM_WORLD, &world_size); 11 | 12 | // Get the rank of the process 
13 | int world_rank; 14 | MPI_Comm_rank(MPI_COMM_WORLD, &world_rank); 15 | 16 | // Get the name of the processor 17 | char processor_name[MPI_MAX_PROCESSOR_NAME]; 18 | int name_len; 19 | MPI_Get_processor_name(processor_name, &name_len); 20 | 21 | // Print off a hello world message 22 | printf("Hello world from processor %s, rank %d out of %d processors\n", 23 | processor_name, world_rank, world_size); 24 | 25 | // Finalize the MPI environment. 26 | MPI_Finalize(); 27 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Vedant Roy 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /allreduce/scripts/strip_glibc_version.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -euo pipefail 3 | 4 | # Check if the file path is provided 5 | if [ "$#" -ne 1 ]; then 6 | echo "Usage: $0 " 7 | exit 1 8 | fi 9 | 10 | # Assign the file path to a variable 11 | so_file="$1" 12 | to_remove="GLIBC_2.32" 13 | 14 | # Check if the file exists 15 | if [ ! -f "$so_file" ]; then 16 | echo "Error: File '$so_file' not found." 
17 | exit 2 18 | fi 19 | 20 | # Print the glibc versions 21 | versions=$(nm --dynamic --undefined-only --with-symbol-versions "$so_file" \ 22 | | grep GLIBC | sed -e 's#.\+@##' | sort --unique) 23 | echo "$versions" 24 | 25 | # Extract the symbols that use GLIBC_2.29 and clear their version 26 | nm --dynamic --undefined-only --with-symbol-versions "$so_file" | grep ${to_remove} | awk '{print $3}' | \ 27 | while read -r symbol; do 28 | echo "Clearing version for symbol: $symbol" 29 | patchelf --clear-symbol-version "$symbol" "$so_file" 30 | done 31 | 32 | echo "FINISHED" 33 | # Print new versions 34 | versions=$(nm --dynamic --undefined-only --with-symbol-versions "$so_file" \ 35 | | grep GLIBC | sed -e 's#.\+@##' | sort --unique) 36 | echo "$versions" -------------------------------------------------------------------------------- /allreduce/csrc/mpi_cuda_helloworld/split_driver.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | extern void runWriteData(int* data, int rank, int world_size); 6 | 7 | int main(int argc, char** argv) { 8 | MPI_Init(&argc, &argv); 9 | 10 | int world_size, rank; 11 | MPI_Comm_size(MPI_COMM_WORLD, &world_size); 12 | MPI_Comm_rank(MPI_COMM_WORLD, &rank); 13 | 14 | // Initialize CUDA 15 | cudaSetDevice(rank); 16 | 17 | // Allocate memory and IPC handles 18 | int* data; 19 | cudaMalloc((void**)&data, world_size * sizeof(int)); 20 | cudaIpcMemHandle_t handle; 21 | cudaIpcGetMemHandle(&handle, data); 22 | 23 | // Gather all handles 24 | cudaIpcMemHandle_t* handles = new cudaIpcMemHandle_t[world_size]; 25 | MPI_Allgather(&handle, sizeof(cudaIpcMemHandle_t), MPI_BYTE, handles, sizeof(cudaIpcMemHandle_t), MPI_BYTE, MPI_COMM_WORLD); 26 | 27 | // Write rank to each device's memory 28 | for (int i = 0; i < world_size; i++) { 29 | if (i != rank) { 30 | int* remoteData; 31 | cudaIpcOpenMemHandle((void**)&remoteData, handles[i], cudaIpcMemLazyEnablePeerAccess); 32 | runWriteData(remoteData, rank, world_size); 33 | cudaDeviceSynchronize(); 34 | cudaIpcCloseMemHandle(remoteData); 35 | } 36 | } 37 | 38 | // Cleanup 39 | cudaFree(data); 40 | delete[] handles; 41 | 42 | MPI_Finalize(); 43 | return 0; 44 | } 45 | -------------------------------------------------------------------------------- /allreduce/csrc/cuda_exp/compile.bash: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env bash 2 | set -euxo pipefail 3 | 4 | if command -v nvidia-smi &> /dev/null 5 | then 6 | # Get GPU name 7 | gpu_name=$(nvidia-smi --query-gpu=name --format=csv,noheader) 8 | 9 | # Set architecture based on GPU name 10 | case $gpu_name in 11 | *T4*) 12 | arch=compute_75 13 | code=sm_75 14 | ;; 15 | *A10*) 16 | arch=compute_80 17 | code=sm_80 18 | ;; 19 | *V100*) 20 | arch=compute_70 21 | code=sm_70 22 | ;; 23 | *A2000*) 24 | arch=compute_86 25 | code=sm_86 26 | ;; 27 | *4090*) 28 | arch=compute_90 29 | code=sm_90 30 | ;; 31 | *A4000*) 32 | arch=compute_86 33 | code=sm_86 34 | ;; 35 | *) 36 | echo "Unsupported GPU: $gpu_name" 37 | exit 1 38 | ;; 39 | esac 40 | else 41 | # Default to A2000 if nvidia-smi does not exist 42 | gpu_name="A2000" 43 | arch=compute_86 44 | code=sm_86 45 | fi 46 | 47 | echo "GPU: $gpu_name" 48 | echo "Architecture: $arch" 49 | echo "Compute capability: $code" 50 | 51 | # Directory of the script 52 | script_dir=$(dirname "$0") 53 | 54 | # nvcc command with dynamic architecture 55 | nvcc -I/usr/include -I/usr/lib/x86_64-linux-gnu/openmpi/include/openmpi -I/usr/lib/x86_64-linux-gnu/openmpi/include -L/usr/lib/x86_64-linux-gnu/openmpi/lib -lmpi -L/usr/lib/x86_64-linux-gnu -lnccl -gencode=arch=$arch,code=$code "$script_dir/main.cu" -o cuda_exp.bin --expt-relaxed-constexpr -G 56 | -------------------------------------------------------------------------------- /allreduce/csrc/my_sync/compile_combined.bash: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env bash 2 | set -euxo pipefail 3 | 4 | if command -v nvidia-smi &> /dev/null 5 | then 6 | # Get GPU name 7 | gpu_name=$(nvidia-smi --query-gpu=name --format=csv,noheader) 8 | 9 | # Set architecture based on GPU name 10 | case $gpu_name in 11 | *T4*) 12 | arch=compute_75 13 | code=sm_75 14 | ;; 15 | *A10*) 16 | arch=compute_80 17 | code=sm_80 18 | ;; 19 | *V100*) 20 | arch=compute_70 21 | code=sm_70 22 | ;; 23 | *A2000*) 24 | arch=compute_86 25 | code=sm_86 26 | ;; 27 | *4090*) 28 | arch=compute_90 29 | code=sm_90 30 | ;; 31 | *A4000*) 32 | arch=compute_86 33 | code=sm_86 34 | ;; 35 | *) 36 | echo "Unsupported GPU: $gpu_name" 37 | exit 1 38 | ;; 39 | esac 40 | else 41 | # Default to A2000 if nvidia-smi does not exist 42 | gpu_name="A2000" 43 | arch=compute_86 44 | code=sm_86 45 | fi 46 | 47 | echo "GPU: $gpu_name" 48 | echo "Architecture: $arch" 49 | echo "Compute capability: $code" 50 | 51 | # Directory of the script 52 | script_dir=$(dirname "$0") 53 | 54 | # nvcc command with dynamic architecture 55 | nvcc -I/usr/include -I/usr/lib/x86_64-linux-gnu/openmpi/include/openmpi -I/usr/lib/x86_64-linux-gnu/openmpi/include -L/usr/lib/x86_64-linux-gnu/openmpi/lib -lmpi -L/usr/lib/x86_64-linux-gnu -lnccl -gencode=arch=$arch,code=$code "$script_dir/sync_test.cu" -o sync_test.bin --expt-relaxed-constexpr -G 56 | -------------------------------------------------------------------------------- /allreduce/csrc/my_allreduce/compile_combined.bash: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env bash 2 | set -euxo pipefail 3 | 4 | if command -v nvidia-smi &> /dev/null 5 | then 6 | # Get GPU name 7 | gpu_name=$(nvidia-smi --query-gpu=name --format=csv,noheader) 8 | 9 | # Set architecture based on GPU name 10 | case $gpu_name in 11 | *T4*) 12 | arch=compute_75 13 | code=sm_75 14 | ;; 15 | *A10*) 16 | arch=compute_80 17 | code=sm_80 18 | ;; 19 | *V100*) 20 | arch=compute_70 21 | code=sm_70 22 | ;; 23 | *A2000*) 24 | arch=compute_86 25 | code=sm_86 26 | ;; 27 | *4090*) 28 | arch=compute_90 29 | code=sm_90 30 | ;; 31 | *A4000*) 32 | arch=compute_86 33 | code=sm_86 34 | ;; 35 | *) 36 | echo "Unsupported GPU: $gpu_name" 37 | exit 1 38 | ;; 39 | esac 40 | else 41 | # Default to A2000 if nvidia-smi does not exist 42 | gpu_name="A2000" 43 | arch=compute_86 44 | code=sm_86 45 | fi 46 | 47 | echo "GPU: $gpu_name" 48 | echo "Architecture: $arch" 49 | echo "Compute capability: $code" 50 | 51 | # Directory of the script 52 | script_dir=$(dirname "$0") 53 | 54 | # nvcc command with dynamic architecture 55 | nvcc -I/usr/include -I/usr/lib/x86_64-linux-gnu/openmpi/include/openmpi -I/usr/lib/x86_64-linux-gnu/openmpi/include -L/usr/lib/x86_64-linux-gnu/openmpi/lib -lmpi -L/usr/lib/x86_64-linux-gnu -lnccl -gencode=arch=$arch,code=$code "$script_dir/allreduce_test.cu" -o allreduce_test.bin --expt-relaxed-constexpr -G 56 | -------------------------------------------------------------------------------- /allreduce/csrc/mpi_cuda_helloworld/combined.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | __global__ void writeData(int* data, int rank, int world_size) { 6 | if (threadIdx.x < world_size) { 7 | data[rank] = rank; // Write rank to its own position 8 | printf("Rank %d writing to position %d\n", rank, rank); 9 | } 10 | } 11 | 12 | int main(int argc, char** argv) { 13 | MPI_Init(&argc, &argv); 14 | 15 | int world_size, rank; 16 | MPI_Comm_size(MPI_COMM_WORLD, &world_size); 17 | MPI_Comm_rank(MPI_COMM_WORLD, &rank); 18 | 19 | // Initialize CUDA 20 | cudaSetDevice(rank); 21 | 22 | // Allocate memory for all ranks on each device 23 | int* data; 24 | cudaMalloc((void**)&data, world_size * sizeof(int)); 25 | cudaIpcMemHandle_t handle; 26 | cudaIpcGetMemHandle(&handle, data); 27 | 28 | // Gather all handles 29 | cudaIpcMemHandle_t* handles = new cudaIpcMemHandle_t[world_size]; 30 | MPI_Allgather(&handle, sizeof(cudaIpcMemHandle_t), MPI_BYTE, handles, sizeof(cudaIpcMemHandle_t), MPI_BYTE, MPI_COMM_WORLD); 31 | 32 | // Write rank to each device's memory 33 | for (int i = 0; i < world_size; i++) { 34 | if (i != rank) { // Skip own memory 35 | int* remoteData; 36 | cudaIpcOpenMemHandle((void**)&remoteData, handles[i], cudaIpcMemLazyEnablePeerAccess); 37 | writeData<<<1, world_size>>>(remoteData, rank, world_size); 38 | cudaDeviceSynchronize(); 39 | cudaIpcCloseMemHandle(remoteData); 40 | } 41 | } 42 | 43 | // Cleanup 44 | cudaFree(data); 45 | delete[] handles; 46 | 47 | MPI_Finalize(); 48 | return 0; 49 | } -------------------------------------------------------------------------------- /allreduce/csrc/cuda_exp/main.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include 6 | #include 7 | 8 | #define CUDACHECK(cmd) \ 9 | do { \ 10 | cudaError_t e = cmd; \ 11 | if (e != cudaSuccess) { \ 12 | printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \ 13 | cudaGetErrorString(e)); \ 14 | exit(EXIT_FAILURE); \ 15 | } \ 16 | } while (0) 
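// Note: the do { ... } while (0) wrapper is the standard C macro idiom; it makes
// CUDACHECK(...) expand to a single statement (still requiring a trailing
// semicolon), so it composes safely with un-braced if/else blocks.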
17 | 
18 | int main(int argc, char **argv) {
19 | #define N_ELEMENTS (1024 * 8)
20 | #define DTYPE half
21 | 
22 |   DTYPE *input_buf;
23 |   // allocate N elements
24 |   CUDACHECK(cudaMalloc(&input_buf, N_ELEMENTS * sizeof(DTYPE)));
25 | 
26 |   // This is wrong b/c we are setting each byte to 1,
27 |   // but fp16 values are 2 bytes
28 |   // CUDACHECK(cudaMemset(input_buf, 1, N_ELEMENTS));
29 | 
30 |   DTYPE *input_buf_cpu = new DTYPE[N_ELEMENTS];
31 |   for (int i = 0; i < N_ELEMENTS; i++) {
32 |     input_buf_cpu[i] = __float2half(1.0f);
33 |   }
34 | 
35 |   // Copy from CPU to GPU
36 |   CUDACHECK(cudaMemcpy(input_buf, input_buf_cpu, N_ELEMENTS * sizeof(DTYPE), cudaMemcpyHostToDevice));
37 | 
38 |   // Copy back to the CPU and print the first element
39 |   DTYPE *input_buf_cpu2 = new DTYPE[N_ELEMENTS];
40 |   CUDACHECK(cudaMemcpy(input_buf_cpu2, input_buf, N_ELEMENTS * sizeof(DTYPE), cudaMemcpyDeviceToHost));
41 |   printf("input_buf[0] = %f\n", __half2float(input_buf_cpu2[0]));
42 | 
43 |   return EXIT_SUCCESS;
44 | }
45 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GPU Kernels
2 | This project implements GPU kernels in CUDA/Triton for Allreduce, [PagedAttention](https://arxiv.org/abs/2309.06180), and [Activation-aware Weight Quantization](https://arxiv.org/abs/2306.00978).
3 | 
4 | ### Allreduce
5 | There's an implementation of a one-pass allreduce (all ranks read/write from other ranks). The implementation is largely a stripped-down version of https://github.com/vllm-project/vllm/pull/2192. I rewrote parts from scratch, but also copy-pasted a fair bit. It's also similar to https://github.com/pytorch/pytorch/pull/114001, which itself is inspired by FasterTransformer. In the process of writing the code, I learned a bunch about CUDA, MPI, etc.
6 | 
7 | ### PagedAttention
8 | ![Paged Attention](./assets/PagedAttention.png)
9 | 
10 | Paged attention stores KV vectors in a cache instead of recomputing them.
11 | 
12 | The PagedAttention kernel is not faster than the existing CUDA kernel because Triton has limitations that prevent it from doing the necessary tensor operations. See
13 | 1. https://github.com/openai/triton/issues/2488
14 | 2. https://github.com/openai/triton/issues/2522
15 | 
16 | ### AWQ
17 | ![AWQ](./assets/awq.png)
18 | 
19 | AWQ is a quantization method. This kernel implements fast inference using the quantized weights.
20 | 
21 | Roughly, the AWQ kernel dequantizes a matrix using the formula `scale * (weight - zero_point)` before doing a standard FP16 matmul.
22 | 
23 | The AWQ kernel is much faster than the existing CUDA implementation, in addition to being simpler (~300 lines of C + inline assembly vs. ~50 lines of Triton).
24 | 
25 | Here's a performance comparison:
26 | ![Performance Graph](./awq/matmul-performance.png)
27 | 
28 | Credit to
29 | - The Triton matmul tutorial
30 | - [GPTQ-Triton](https://github.com/fpgaminer/GPTQ-triton) for discovering a few clever tricks I used in this kernel and making me realize that using Triton for quantization inference was possible
--------------------------------------------------------------------------------
/allreduce/csrc/reference_allreduce/compile_combined.bash:
--------------------------------------------------------------------------------
1 | #!
/usr/bin/env bash 2 | set -euxo pipefail 3 | 4 | # Get GPU name 5 | gpu_name=$(nvidia-smi --query-gpu=name --format=csv,noheader) 6 | 7 | # Set architecture based on GPU name 8 | case $gpu_name in 9 | *T4*) 10 | arch=compute_75 11 | code=sm_75 12 | ;; 13 | *A10*) 14 | arch=compute_80 15 | code=sm_80 16 | ;; 17 | *V100*) 18 | arch=compute_70 19 | code=sm_70 20 | ;; 21 | *A2000*) 22 | arch=compute_86 23 | code=sm_86 24 | ;; 25 | *V100*) 26 | arch=compute_70 27 | code=sm_70 28 | ;; 29 | *4090*) 30 | arch=compute_90 31 | code=sm_90 32 | ;; 33 | *) 34 | echo "Unsupported GPU: $gpu_name" 35 | exit 1 36 | ;; 37 | esac 38 | 39 | echo "GPU: $gpu_name" 40 | echo "Architecture: $arch" 41 | echo "Compute capability: $code" 42 | 43 | # Directory of the script 44 | script_dir=$(dirname "$0") 45 | 46 | # nvcc command with dynamic architecture 47 | nvcc -I/usr/include -I/usr/lib/x86_64-linux-gnu/openmpi/include/openmpi -I/usr/lib/x86_64-linux-gnu/openmpi/include -L/usr/lib/x86_64-linux-gnu/openmpi/lib -lmpi -L/usr/lib/x86_64-linux-gnu -lnccl -gencode=arch=$arch,code=$code "$script_dir/fast_allreduce_test.cu" -o fastallreduce_test.bin 48 | 49 | 50 | # # Run `mpicc -show` to find include directories 51 | # 52 | # # https://docs.ccv.brown.edu/oscar/gpu-computing/mpi-cuda 53 | # # https://anhnguyen.me/2013/12/how-to-mix-mpi-and-cuda-in-a-single-program/ 54 | # # nvcc -I/usr/mpi/gcc/openmpi-1.4.6/include -L/usr/mpi/gcc/openmpi-1.4.6/lib64 -lmpi spaghetti.cu -o program 55 | # script_dir=$(dirname "$0") 56 | # # -I/usr/include => path to nccl include directory 57 | # # -L/usr/lib/x86_64-linux-gnu => path to ncl libraries 58 | # nvcc -I/usr/include -I/usr/lib/x86_64-linux-gnu/openmpi/include/openmpi -I/usr/lib/x86_64-linux-gnu/openmpi/include -L/usr/lib/x86_64-linux-gnu/openmpi/lib -lmpi -L/usr/lib/x86_64-linux-gnu -lnccl -gencode=arch=compute_86,code=sm_86 "$script_dir/fast_allreduce_test.cu" -o fastallreduce_test.bin 59 | -------------------------------------------------------------------------------- /allreduce/csrc/my_allreduce/allreduce_torch_bindings.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "allreduce.cuh" 7 | 8 | // we are returning a pointer as a uint64 9 | // the static assert validates that pointers on this system 10 | // are infact 64 bits 11 | using fptr_t = uint64_t; 12 | static_assert(sizeof(void *) == sizeof(fptr_t)); 13 | 14 | fptr_t init_ar(torch::Tensor &bstate, torch::Tensor &rank_data, 15 | const std::vector &handles, 16 | const std::vector &offsets, int rank) { 17 | 18 | int world_size = offsets.size(); 19 | if (world_size > 8) 20 | throw std::invalid_argument("world size > 8 is not supported"); 21 | if (world_size % 2 != 0) 22 | throw std::invalid_argument("Odd num gpus is not supported for now"); 23 | if (world_size != handles.size()) 24 | throw std::invalid_argument( 25 | "handles length should equal to offsets length"); 26 | if (rank < 0 || rank >= world_size) 27 | throw std::invalid_argument("invalid rank passed in"); 28 | 29 | cudaIpcMemHandle_t ipc_handles[8]; 30 | for (int i = 0; i < world_size; i++) { 31 | std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t)); 32 | } 33 | return (fptr_t) new mysync::Sync( 34 | reinterpret_cast(bstate.data_ptr()), ipc_handles, offsets, rank); 35 | } 36 | 37 | void register_buffer(fptr_t _fa, torch::Tensor &t, 38 | const std::vector &handles, 39 | const std::vector &offsets) { 40 | auto fa = reinterpret_cast(_fa); 41 | 
fa->register_buffer(handles, offsets, t.data_ptr()); 42 | } 43 | 44 | void allreduce(fptr_t _fa, torch::Tensor &out) { 45 | auto fa = reinterpret_cast(_fa); 46 | switch (out.scalar_type()) { 47 | case at::ScalarType::Half: { 48 | fa->allreduce(out.numel(), reinterpret_cast(out.data_ptr())); 49 | break; 50 | } 51 | default: 52 | throw std::runtime_error( 53 | "allreduce only supports float16"); 54 | } 55 | } -------------------------------------------------------------------------------- /allreduce/csrc/my_sync/sync_test.cu: -------------------------------------------------------------------------------- 1 | #include "mpi.h" 2 | #include "sync.cuh" 3 | 4 | #define MPICHECK(cmd) \ 5 | do { \ 6 | int e = cmd; \ 7 | if (e != MPI_SUCCESS) { \ 8 | printf("Failed: MPI error %s:%d '%d'\n", __FILE__, __LINE__, e); \ 9 | exit(EXIT_FAILURE); \ 10 | } \ 11 | } while (0) 12 | 13 | int main(int argc, char **argv) { 14 | MPI_Init(NULL, NULL); 15 | 16 | int world_size; 17 | MPI_Comm_size(MPI_COMM_WORLD, &world_size); 18 | 19 | int world_rank; 20 | MPI_Comm_rank(MPI_COMM_WORLD, &world_rank); 21 | 22 | CUDACHECK(cudaSetDevice(world_rank)); 23 | 24 | mysync::BarrierState *state; 25 | CUDACHECK(cudaMalloc(&state, sizeof(mysync::BarrierState))); 26 | 27 | cudaIpcMemHandle_t cur_rank_handle; 28 | cudaIpcMemHandle_t rank_handles[8]; 29 | 30 | CUDACHECK(cudaIpcGetMemHandle(&cur_rank_handle, state)); 31 | MPICHECK(MPI_Allgather(&cur_rank_handle, // void* send_data, 32 | sizeof(cudaIpcMemHandle_t), // int send_count, 33 | MPI_BYTE, // MPI_Datatype send_datatype, 34 | rank_handles, // void* recv_data, 35 | sizeof(cudaIpcMemHandle_t), // int recv_count, 36 | MPI_BYTE, // MPI_Datatype recv_datatype, 37 | MPI_COMM_WORLD // MPI_Comm communicator 38 | )); 39 | 40 | // Offsets are only necessary for Pytorch bindings 41 | // (where tensors are not allocated at the start of a cudaIpcMemHandle) 42 | // (that's why we set them to 0 here) 43 | std::vector offsets(world_size, 0); 44 | 45 | mysync::Sync sync(state, rank_handles, offsets, world_rank); 46 | 47 | // 4 blocks, 64 threads per block 48 | sync.sync_test(4, 64); 49 | 50 | MPI_Finalize(); 51 | return EXIT_SUCCESS; 52 | } 53 | -------------------------------------------------------------------------------- /awq/gemm_v1_benchmark.py: -------------------------------------------------------------------------------- 1 | # Taken from the Triton matmul tutorial 2 | import os 3 | import torch 4 | import triton 5 | 6 | import awq_inference_engine as ie 7 | from gemm_kernel_v1 import quant_matmul 8 | 9 | @triton.testing.perf_report( 10 | triton.testing.Benchmark( 11 | x_names=['M'], # Argument names to use as an x-axis for the plot 12 | x_vals=[128 * i for i in range(2, 33)], # Different possible values for `x_name` 13 | line_arg='provider', # Argument name whose value corresponds to a different line in the plot 14 | # Possible values for `line_arg` 15 | line_vals=['cuda', 'triton'], 16 | # Label name for the lines 17 | line_names=["cuda", "Triton"], 18 | # Line styles 19 | styles=[('green', '-'), ('blue', '-')], 20 | ylabel="TFLOPS", # Label name for the y-axis 21 | plot_name="matmul-performance", # Name for the plot, used also as a file name for saving the plot. 
22 | args={}, 23 | ) 24 | ) 25 | def benchmark(M, provider): 26 | N = K = 4096 27 | pack_num = 8 28 | group_size = 128 29 | int32_bounds = (torch.iinfo(torch.int32).min, torch.iinfo(torch.int32).max) 30 | inputs = torch.randn((M, K), dtype=torch.float16, device="cuda") 31 | qweight = torch.randint(*int32_bounds, (N, K // pack_num), dtype=torch.int32, device="cuda") 32 | scales = 0.001 * torch.abs(torch.randn((N, K // group_size), dtype=torch.float16, device="cuda")) 33 | qzeros = torch.randint(*int32_bounds, (N, K // group_size // pack_num), dtype=torch.int32, device="cuda") 34 | 35 | if provider == 'triton': 36 | trans = lambda x: x.T.contiguous() 37 | qweight = trans(qweight) 38 | qzeros = trans(qzeros) 39 | scales = trans(scales) 40 | 41 | quantiles = [0.5, 0.2, 0.8] 42 | if provider == 'cuda': 43 | ms, min_ms, max_ms = triton.testing.do_bench(lambda: ie.gemm_forward_cuda(inputs, qweight, scales, qzeros, group_size, 8), quantiles=quantiles) 44 | if provider == 'triton': 45 | ms, min_ms, max_ms = triton.testing.do_bench(lambda: quant_matmul(inputs, qweight, qzeros, scales, M=M, N=N, K=K, pack_num=pack_num, group_size=group_size), quantiles=quantiles) 46 | perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) 47 | return perf(ms), perf(max_ms), perf(min_ms) 48 | 49 | parent_dir = os.path.dirname(os.path.realpath(__file__)) 50 | benchmark.run(show_plots=True, print_data=True, save_path=parent_dir) -------------------------------------------------------------------------------- /awq/torch_reference.py: -------------------------------------------------------------------------------- 1 | # A reference implementation of the dequantization in pure Pytorch 2 | import torch 3 | 4 | def generate_random_data(M, N, K): 5 | pack_num = 8 6 | group_size = 128 7 | print(f"Testing with M={M}") 8 | int32_bounds = (torch.iinfo(torch.int32).min, torch.iinfo(torch.int32).max) 9 | inputs = torch.randn((M, K), dtype=torch.float16, device="cuda") 10 | qweight = torch.randint(*int32_bounds, (N, K // pack_num), dtype=torch.int32, device="cuda") 11 | scales = 0.001 * torch.abs(torch.randn((N, K // group_size), dtype=torch.float16, device="cuda")) 12 | qzeros = torch.randint(*int32_bounds, (N, K // group_size // pack_num), dtype=torch.int32, device="cuda") 13 | return inputs, qweight, scales, qzeros 14 | 15 | def matmul_simple( 16 | a, qw, qzeros, scales 17 | ): 18 | N = 4096 19 | pack_num = 8 20 | group_size = 128 21 | 22 | M, K = a.shape 23 | # ASSUMPTION: 24 | # all quantization / packing is done along the channel dimension 25 | # (channels become lower-resolution, but the # of channels is the same) 26 | assert qw.shape == (N, K // pack_num) 27 | assert qzeros.shape == (N, K // group_size // pack_num) 28 | assert scales.shape == (N, K // group_size) 29 | 30 | print("dequantizing matrix ...") 31 | 32 | # dequant small tile 33 | n_rows_to_dequant = 64 34 | K2 = 32 35 | 36 | # dequant full matrix 37 | # n_rows_to_dequant = N 38 | # K2 = K 39 | dequant_matrix = torch.zeros((n_rows_to_dequant, K2), dtype=torch.float32, device=a.device) 40 | 41 | from tqdm import tqdm 42 | 43 | assert 0xF == 0b1111 44 | 45 | for row in tqdm(range(n_rows_to_dequant)): 46 | dequant_row = torch.zeros((K2, ), dtype=torch.float32, device=a.device) 47 | for col in range(K2): 48 | group_idx = col // group_size 49 | scale = scales[row][group_idx].to(torch.float32) 50 | qzero = qzeros[row][group_idx // pack_num] 51 | qweight = qw[row][col // pack_num] 52 | 53 | # assert col // group_size == 0 54 | # assert scale == scales[row][0] 55 | # assert 
qzero == qzeros[row][0] 56 | # assert qweight in [qw[row][0], qw[row][1], qw[row][2], qw[row][3]] 57 | 58 | # This makes sense b/c the 0th value is in the rightmost section of the packed number 59 | # so it needs to be shifted the least 60 | qzero_unpacked = ((qzero >> (4 * (group_idx % pack_num))) & 0xF).to(torch.float32) 61 | qweight_unpacked = ((qweight >> (4 * (col % pack_num))) & 0xF).to(torch.float32) 62 | dequant = scale * (qweight_unpacked - qzero_unpacked) 63 | dequant_row[col] = dequant 64 | dequant_matrix[row] = dequant_row 65 | torch.cuda.synchronize() 66 | print("finished dequantizing ...") 67 | return dequant_matrix 68 | 69 | if __name__ == "__main__": 70 | M = 128 71 | N = K = 4096 72 | inputs, qweight, scales, qzeros = generate_random_data(M, N, K) 73 | manual_dequant = matmul_simple(inputs, qweight, qzeros, scales) 74 | -------------------------------------------------------------------------------- /allreduce/csrc/reference_allreduce/fast_allreduce.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "fast_allreduce.cuh" 7 | 8 | // fake pointer type 9 | using fptr_t = uint64_t; 10 | static_assert(sizeof(void *) == sizeof(fptr_t)); 11 | 12 | fptr_t init_fast_ar(torch::Tensor &meta, torch::Tensor &rank_data, 13 | const std::vector &handles, 14 | const std::vector &offsets, int rank, 15 | bool full_nvlink) { 16 | int world_size = offsets.size(); 17 | if (world_size > 8) 18 | throw std::invalid_argument("world size > 8 is not supported"); 19 | if (world_size % 2 != 0) 20 | throw std::invalid_argument("Odd num gpus is not supported for now"); 21 | if (world_size != handles.size()) 22 | throw std::invalid_argument( 23 | "handles length should equal to offsets length"); 24 | if (rank < 0 || rank >= world_size) 25 | throw std::invalid_argument("invalid rank passed in"); 26 | 27 | cudaIpcMemHandle_t ipc_handles[8]; 28 | for (int i = 0; i < world_size; i++) { 29 | std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t)); 30 | } 31 | return (fptr_t) new vllm::FastAllreduce( 32 | reinterpret_cast(meta.data_ptr()), rank_data.data_ptr(), 33 | rank_data.numel(), ipc_handles, offsets, rank, full_nvlink); 34 | } 35 | 36 | void allreduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) { 37 | auto fa = reinterpret_cast(_fa); 38 | auto stream = c10::cuda::getCurrentCUDAStream().stream(); 39 | TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); 40 | TORCH_CHECK_EQ(inp.numel(), out.numel()); 41 | switch (inp.scalar_type()) { 42 | case at::ScalarType::Float: { 43 | fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), 44 | reinterpret_cast(out.data_ptr()), 45 | inp.numel()); 46 | break; 47 | } 48 | case at::ScalarType::Half: { 49 | fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), 50 | reinterpret_cast(out.data_ptr()), 51 | inp.numel()); 52 | break; 53 | } 54 | #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) 55 | case at::ScalarType::BFloat16: { 56 | fa->allreduce( 57 | stream, reinterpret_cast(inp.data_ptr()), 58 | reinterpret_cast(out.data_ptr()), inp.numel()); 59 | break; 60 | } 61 | #endif 62 | default: 63 | throw std::runtime_error( 64 | "Fast allreduce only supports float32, float16 and bfloat16"); 65 | } 66 | } 67 | 68 | void dispose(fptr_t _fa) { 69 | auto fa = reinterpret_cast(_fa); 70 | delete fa; 71 | } 72 | 73 | int meta_size() { return sizeof(vllm::Metadata); } 74 | 75 | void register_buffer(fptr_t _fa, torch::Tensor &t, 76 | const std::vector 
&handles, 77 | const std::vector &offsets) { 78 | auto fa = reinterpret_cast(_fa); 79 | fa->register_buffer(handles, offsets, t.data_ptr()); 80 | } 81 | 82 | std::pair, std::vector> get_graph_buffer_ipc_meta( 83 | fptr_t _fa) { 84 | auto fa = reinterpret_cast(_fa); 85 | return fa->get_graph_buffer_ipc_meta(); 86 | } 87 | 88 | void register_graph_buffers(fptr_t _fa, const std::vector &handles, 89 | const std::vector> &offsets) { 90 | auto fa = reinterpret_cast(_fa); 91 | fa->register_graph_buffers(handles, offsets); 92 | } -------------------------------------------------------------------------------- /allreduce/modal_runner.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import os 3 | import time 4 | import subprocess 5 | from textwrap import dedent 6 | 7 | import modal 8 | 9 | r = lambda *args, **kwargs: subprocess.run(*args, shell=True, **kwargs) 10 | 11 | image = ( 12 | # Use a version of CUDA that's compatible w/ torch 13 | modal.Image.from_registry( 14 | "nvidia/cuda:12.2.2-devel-ubuntu22.04", add_python="3.11") 15 | .pip_install("sh", "torch==2.1.2", "ninja") 16 | # build-essential is probably not needed 17 | .apt_install("git", "build-essential", "clang") 18 | # openmpi (nccl is already installed, so skip "libnccl-dev", "libnccl2") 19 | .apt_install("openmpi-bin", "openmpi-common", "libopenmpi-dev") 20 | ) 21 | 22 | openmpi_src_image = ( 23 | # Use a version of CUDA that's compatible w/ torch 24 | modal.Image.from_registry( 25 | "nvidia/cuda:12.2.2-devel-ubuntu22.04", add_python="3.11") 26 | .pip_install("sh", "torch==2.1.2", "ninja") 27 | .apt_install("git", "build-essential", "clang", "autotools-dev", "autoconf", "libtool") 28 | .run_commands( 29 | "git clone https://github.com/open-mpi/ompi.git", 30 | "cd ompi && git checkout v2.x && ./autogen.pl && ./configure", 31 | "cd ompi && make all && sudo make install" 32 | ) 33 | ) 34 | 35 | stub = modal.Stub() 36 | 37 | # T4 (turing) has ~ instant results for 2,4 GPU count 38 | t4_2 = modal.gpu.T4(count=2) 39 | # A10G (ampere) has ~ instant for 2 GPU, ~ 1 minute for 4 GPU 40 | # a10g = modal.gpu.A10G(count=4) 41 | a10g = modal.gpu.A10G(count=2) 42 | 43 | dirname = os.path.dirname(__file__) 44 | csrc_dir = Path(dirname) / "csrc" 45 | 46 | @stub.function(gpu=a10g, image=openmpi_src_image, cpu=8, 47 | mounts=[modal.Mount.from_local_dir(csrc_dir, remote_path="/root/csrc")]) 48 | def build_pure_cuda_kernel(): 49 | t0 = time.time() 50 | r("cd csrc/reference_allreduce && ./compile_combined.bash") 51 | print(f"Build time: {time.time() - t0:.2f}s") 52 | with open("hostfile.txt", "w") as f: 53 | f.write("localhost slots=2 max_slots=2") 54 | # print the mpi version 55 | r("mpirun --hostfile hostfile.txt --mca btl ^vader --allow-run-as-root -np 2 csrc/reference_allreduce/fastallreduce_test.bin") 56 | # r("mpirun --hostfile hostfile.txt --allow-run-as-root -np 2 csrc/reference_allreduce/fastallreduce_test.bin") 57 | # r("mpirun --hostfile hostfile.txt --mca btl ^vader --allow-run-as-root -np 2 csrc/reference_allreduce/fastallreduce_test.bin") 58 | # r("mpirun --mca btl ^vader --allow-run-as-root -np 2 csrc/reference_allreduce/fastallreduce_test.bin") 59 | print(f"Total time: {time.time() - t0:.2f}s") 60 | 61 | 62 | @stub.function(gpu="any", image=image) 63 | def run_torch(): 64 | print("Adding 1 + 1 on GPU") 65 | import torch 66 | x = torch.tensor([1.0]).cuda() 67 | r = x + x 68 | print(f"Finished: {r}") 69 | 70 | 71 | @stub.function(gpu=t4_2, image=image, cpu=8) 72 | def 
build_kernel_with_torch_bindings(): 73 | t0 = time.time() 74 | r("git clone --depth 1 https://github.com/vedantroy/gpu_kernels.git") 75 | print(f"Clone time: {time.time() - t0:.2f}s") # ~ 0.5s 76 | r("cd gpu_kernels/allreduce && python3 setup.py install") 77 | # modal (8 cpu) = ~70s 78 | # laptop = ~56s 79 | # vast (ryzen 9) = ~45s 80 | # vast (ryzen 9, no optimization) = ~40s 81 | print(f"Build time: {time.time() - t0:.2f}s") 82 | 83 | code = dedent(""" 84 | import torch 85 | import cuda_experiments 86 | 87 | x = torch.ones(2, device="cuda") 88 | x_plus_x = cuda_experiments.add_one(x) 89 | torch.testing.assert_close(x_plus_x, x + x) 90 | """) 91 | 92 | r(f"echo '{code}' > gpu_kernels/allreduce/test.py") 93 | r("cd gpu_kernels/allreduce && python3 test.py") # ~ 5s 94 | 95 | print(f"All time: {time.time() - t0:.2f}s") 96 | 97 | # Topology, SMI, etc. 98 | 99 | # result = subprocess.run(["nvidia-smi", "topo", "-m"], stdout=subprocess.PIPE) 100 | # subprocess.run( 101 | # [ 102 | # "nvidia-smi", 103 | # "--query-gpu=index,name,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used", 104 | # "--format=csv", 105 | # ] 106 | # ) 107 | -------------------------------------------------------------------------------- /allreduce/setup.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import os 4 | from setuptools import find_packages, setup 5 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 6 | 7 | # T4 GPU 8 | # os.environ['TORCH_CUDA_ARCH_LIST'] = '7.5' 9 | # A2000 10 | os.environ['TORCH_CUDA_ARCH_LIST'] = '8.6' 11 | 12 | ROOT_DIR = os.path.dirname(__file__) 13 | 14 | print(f"# CPUs: {os.cpu_count()}") 15 | 16 | extra_compile_args = { 17 | # "cxx": ["-g", "-O2", "-std=c++17"], 18 | # "nvcc": ["-O2", "-std=c++17", f"--threads={os.cpu_count()}"], 19 | "cxx": ["-g", "-std=c++17"], 20 | "nvcc": ["-std=c++17", f"--threads={os.cpu_count()}"], 21 | } 22 | 23 | 24 | def get_path(*filepath) -> str: 25 | return os.path.join(ROOT_DIR, *filepath) 26 | 27 | 28 | def get_requirements() -> List[str]: 29 | with open(get_path("requirements.txt")) as f: 30 | requirements = f.read().strip().split("\n") 31 | return requirements 32 | 33 | 34 | setup( 35 | name="cuda_experiments", 36 | packages=find_packages(), 37 | ext_modules=[ 38 | CUDAExtension( 39 | name="cuda_experiments", 40 | sources=[ 41 | "csrc/pybind.cpp", 42 | "csrc/add_one/add_one.cu", 43 | "csrc/my_allreduce/allreduce_torch_bindings.cu", 44 | # "csrc/my_allreduce/allreduce.cuh", 45 | # "csrc/reference_allreduce/fast_allreduce.cu", 46 | ], 47 | extra_compile_args=extra_compile_args, 48 | ), 49 | ], 50 | cmdclass={"build_ext": BuildExtension}, 51 | install_requires=get_requirements(), 52 | ) 53 | 54 | # 38 secs nvcc 55 | # pybind.cpp: 20s 56 | # 57 | 58 | # delete object files 59 | # find ./build/temp.linux-x86_64-cpython-311/csrc -name "*.o" | xargs rm 60 | 61 | # ninja file: 62 | # ninja_required_version = 1.3 63 | # cxx = c++ 64 | # nvcc = /usr/local/cuda/bin/nvcc 65 | # 66 | # cflags = -pthread -B /root/micromamba/envs/allreduce/compiler_compat -DNDEBUG -fwrapv -O2 -Wall -fPIC -O2 -isystem /root/micromamba/envs/allreduce/include -fPIC -O2 -isystem /root/micromamba/envs/allreduce/include -fPIC -I/root/micromamba/envs/allreduce/lib/python3.11/site-packages/torch/include -I/root/micromamba/envs/allreduce/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -I/root/micromamba/envs/allreduce/lib/python3.11/site-packages/torch/include/TH 
-I/root/micromamba/envs/allreduce/lib/python3.11/site-packages/torch/include/THC -I/usr/local/cuda/include -I/root/micromamba/envs/allreduce/include/python3.11 -c 67 | # post_cflags = -g -std=c++17 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=cuda_experiments -D_GLIBCXX_USE_CXX11_ABI=0 68 | # cuda_cflags = -I/root/micromamba/envs/allreduce/lib/python3.11/site-packages/torch/include -I/root/micromamba/envs/allreduce/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -I/root/micromamba/envs/allreduce/lib/python3.11/site-packages/torch/include/TH -I/root/micromamba/envs/allreduce/lib/python3.11/site-packages/torch/include/THC -I/usr/local/cuda/include -I/root/micromamba/envs/allreduce/include/python3.11 -c 69 | # cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -std=c++17 --threads=24 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=cuda_experiments -D_GLIBCXX_USE_CXX11_ABI=0 -gencode=arch=compute_86,code=sm_86 70 | # cuda_dlink_post_cflags = 71 | # ldflags = 72 | # 73 | # rule compile 74 | # command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags 75 | # depfile = $out.d 76 | # deps = gcc 77 | # 78 | # rule cuda_compile 79 | # depfile = $out.d 80 | # deps = gcc 81 | # command = $nvcc $cuda_cflags -c $in -o $out $cuda_post_cflags 82 | # 83 | # 84 | # 85 | # 86 | # 87 | # build /root/gpu_kernels/allreduce/build/temp.linux-x86_64-cpython-311/csrc/add_one/add_one.o: cuda_compile /root/gpu_kernels/allreduce/csrc/add_one/add_one.cu 88 | # build /root/gpu_kernels/allreduce/build/temp.linux-x86_64-cpython-311/csrc/pybind.o: compile /root/gpu_kernels/allreduce/csrc/pybind.cpp -------------------------------------------------------------------------------- /misc.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 12, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# Simple example of packing\n", 10 | "vals = [0, 4, 8, 14, 2, 6, 10, 12]\n", 11 | "assert len(vals) == 8\n", 12 | "\n", 13 | "# Original AWQ packing code\n", 14 | "\"\"\"\n", 15 | "for i in range(pack_num):\n", 16 | " qweight_col = intweight[:, col * pack_num + order_map[i]]\n", 17 | " qweight[:, col] |= qweight_col << (i * awq_linear.w_bit)\n", 18 | "\"\"\"\n", 19 | "\n", 20 | "# Binary values are packed right to left\n", 21 | "# as 4-bit values into a single 32-bit value\n", 22 | "packed = 0\n", 23 | "for idx in range(len(vals)):\n", 24 | " packed |= vals[idx] << (idx * 4)\n", 25 | "\n", 26 | "assert packed == 0b11001010011000101110100001000000\n", 27 | "\n", 28 | "packed_str = format(packed, 'b')\n", 29 | "assert len(packed_str) == 32\n", 30 | "\n", 31 | "assert packed_str[28:] == '0000'\n", 32 | "assert packed_str[24:28] == '0100'\n", 33 | "assert packed_str[20:24] == '1000'\n", 34 | "assert packed_str[0:4] == '1100'" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 22, 40 | "metadata": {}, 41 | "outputs": [ 42 | { 43 | "name": "stdout", 44 | "output_type": "stream", 45 | "text": [ 46 | "[[[1, 4], [0, 0]], [[1, 1], [1, 2]]]\n" 47 | ] 48 | } 49 | ], 50 | "source": [ 51 | "# This is a breakdown of 
(numpy broadcasting\n", 52 | "# + unpacking inside the original VLLM AWQ implementation)\n", 53 | "# I think the original implementation is wrong\n", 54 | "import numpy as np\n", 55 | "\n", 56 | "BLOCK_K = 2\n", 57 | "PACKED_BLOCK_N = 2\n", 58 | "BLOCK_N = 4\n", 59 | "P = 2\n", 60 | "# Packed:\n", 61 | "# [1, 4]\n", 62 | "# [17, 33]\n", 63 | "# Unpacked:\n", 64 | "# [0, 1, 0, 4]\n", 65 | "# [1, 1, 2, 1]\n", 66 | "\n", 67 | "b = np.array([[0b00000001, 0b00000100], [0b00010001, 0b00100001]], dtype=np.uint8)\n", 68 | "shifter = np.array([0, 1]) * 4\n", 69 | "AWQ_MASK = 0b1111 # Set the mask value to select the lower 4 bits of an 8-bit int\n", 70 | "\n", 71 | "assert b.shape == (BLOCK_K, PACKED_BLOCK_N)\n", 72 | "expanded_b = b[:, None, :] >> np.array([0, 0])[None, :, None]\n", 73 | "assert expanded_b.shape == (BLOCK_K, P, PACKED_BLOCK_N)\n", 74 | "expanded_b = expanded_b.tolist()\n", 75 | "assert expanded_b == [\n", 76 | " [[1, 4], [1, 4]], \n", 77 | " [[17, 33], [17, 33]]\n", 78 | " ]\n", 79 | "\n", 80 | "empty_b = np.zeros_like(b)\n", 81 | "expanded_shifter = shifter[None, :, None] >> np.zeros_like(b)[:, None, :]\n", 82 | "assert expanded_shifter.shape == (BLOCK_K, P, PACKED_BLOCK_N)\n", 83 | "expanded_shifter = expanded_shifter.tolist()\n", 84 | "assert expanded_shifter == [[[0, 0], [4, 4]], \n", 85 | " [[0, 0], [4, 4]]]\n", 86 | "\n", 87 | "# convert expanded_b and expanded_shifter to numpy arrays\n", 88 | "expanded_b = np.array(expanded_b)\n", 89 | "expanded_shifter = np.array(expanded_shifter)\n", 90 | "\n", 91 | "shifted = expanded_b >> expanded_shifter\n", 92 | "masked_out = shifted & AWQ_MASK\n", 93 | "assert masked_out.shape == (BLOCK_K, P, PACKED_BLOCK_N)\n", 94 | "masked_out = masked_out.tolist()\n", 95 | "print(masked_out) # wrong" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 21, 101 | "metadata": {}, 102 | "outputs": [ 103 | { 104 | "name": "stdout", 105 | "output_type": "stream", 106 | "text": [ 107 | "[[[1 1 1]\n", 108 | " [2 2 2]\n", 109 | " [3 3 3]]\n", 110 | "\n", 111 | " [[1 1 1]\n", 112 | " [2 2 2]\n", 113 | " [3 3 3]]]\n" 114 | ] 115 | } 116 | ], 117 | "source": [ 118 | "# A breakdown of broadcasting semantics\n", 119 | "# The TL;DR is:\n", 120 | "# - `None` adds a new axis (dimension)\n", 121 | "# - Concretely, this means adding a *single* pair of brackets\n", 122 | "# This can be visualized as either \n", 123 | "# - adding brackets inside the previous dimension (e.g b[:, None])\n", 124 | "# - adding brackets outside the next dimension (e.g. 
b[None, :])\n", 125 | "# - Once the brackets are added, duplicate the values to match the other array's shape\n", 126 | "b = np.array([\n", 127 | " [1, 2, 3],\n", 128 | " [4, 5, 6]\n", 129 | "])\n", 130 | "\n", 131 | "b_expanded = b[:, None, :] + np.array([0, 0])[None, :, None]\n", 132 | "assert b_expanded.tolist() == [\n", 133 | " [[1, 2, 3], [1, 2, 3]],\n", 134 | " [[4, 5, 6], [4, 5, 6]]\n", 135 | "]\n", 136 | "\n", 137 | "b2 = np.array([1, 2, 3])\n", 138 | "b_expanded = b2[None, :, None] + np.zeros_like(b)[:, None, :]\n", 139 | "assert b_expanded.tolist() == [\n", 140 | " [\n", 141 | " [1, 1, 1],\n", 142 | " [2, 2, 2],\n", 143 | " [3, 3, 3]\n", 144 | " ],\n", 145 | " [\n", 146 | " [1, 1, 1],\n", 147 | " [2, 2, 2],\n", 148 | " [3, 3, 3]\n", 149 | " ]\n", 150 | "]" 151 | ] 152 | } 153 | ], 154 | "metadata": { 155 | "kernelspec": { 156 | "display_name": "Python 3", 157 | "language": "python", 158 | "name": "python3" 159 | }, 160 | "language_info": { 161 | "codemirror_mode": { 162 | "name": "ipython", 163 | "version": 3 164 | }, 165 | "file_extension": ".py", 166 | "mimetype": "text/x-python", 167 | "name": "python", 168 | "nbconvert_exporter": "python", 169 | "pygments_lexer": "ipython3", 170 | "version": "3.11.6" 171 | } 172 | }, 173 | "nbformat": 4, 174 | "nbformat_minor": 2 175 | } 176 | -------------------------------------------------------------------------------- /allreduce/csrc/my_allreduce/allreduce_test.cu: -------------------------------------------------------------------------------- 1 | #include "allreduce.cuh" 2 | #include "mpi.h" 3 | #include 4 | 5 | #define MPICHECK(cmd) \ 6 | do { \ 7 | int e = cmd; \ 8 | if (e != MPI_SUCCESS) { \ 9 | printf("Failed: MPI error %s:%d '%d'\n", __FILE__, __LINE__, e); \ 10 | exit(EXIT_FAILURE); \ 11 | } \ 12 | } while (0) 13 | 14 | int main(int argc, char **argv) { 15 | MPI_Init(NULL, NULL); 16 | 17 | int world_size; 18 | MPI_Comm_size(MPI_COMM_WORLD, &world_size); 19 | 20 | int world_rank; 21 | MPI_Comm_rank(MPI_COMM_WORLD, &world_rank); 22 | 23 | CUDACHECK(cudaSetDevice(world_rank)); 24 | 25 | #define N_ELEMENTS (1024 * 8) 26 | #define DTYPE half 27 | 28 | // We allocate the barrier state, input buffer, and output buffer, in one pass, so we only need to 29 | // cudaMalloc once, and we can use the same cudaIpcMemHandle_t for all of them. 
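// Sketch of the resulting single allocation (N = N_ELEMENTS, DTYPE = half):
//   [ BarrierState | input buffer: N halfs | output buffer: N halfs ]
// input_buf starts at (state + 1) and output_buf at input_buf + N_ELEMENTS, so the
// one IPC handle obtained for `state` covers all three regions.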
30 | 31 | auto *state_cpu = (mysync::BarrierState *)malloc(sizeof(mysync::BarrierState) + (2 * sizeof(DTYPE) * N_ELEMENTS)); 32 | assert(state_cpu != NULL); 33 | 34 | DTYPE *input_buf_cpu = reinterpret_cast(state_cpu + 1); 35 | DTYPE *output_buf_cpu = input_buf_cpu + N_ELEMENTS; 36 | 37 | // set input_buf_cpu to all 1s 38 | for (int i = 0; i < N_ELEMENTS; ++i) { 39 | input_buf_cpu[i] = __float2half(1.0f); 40 | } 41 | 42 | // set output buf to all 0s 43 | // (not that this should matter) 44 | for (int i = 0; i < N_ELEMENTS; ++i) { 45 | output_buf_cpu[i] = __float2half(0.0f); 46 | } 47 | 48 | // copy the entire thing to the GPU 49 | mysync::BarrierState *state; 50 | CUDACHECK(cudaMalloc(&state, sizeof(mysync::BarrierState) + (2 * sizeof(DTYPE) * N_ELEMENTS))); 51 | CUDACHECK(cudaMemcpy(state, state_cpu, sizeof(mysync::BarrierState) + (2 * sizeof(DTYPE) * N_ELEMENTS), cudaMemcpyHostToDevice)); 52 | 53 | DTYPE *input_buf = reinterpret_cast(state + 1); 54 | DTYPE *output_buf = input_buf + N_ELEMENTS; 55 | 56 | // mem copy to cpu and print first element 57 | DTYPE *input_buf_cpu2 = new DTYPE[N_ELEMENTS]; 58 | CUDACHECK(cudaMemcpy(input_buf_cpu2, input_buf, N_ELEMENTS * sizeof(DTYPE), cudaMemcpyDeviceToHost)); 59 | printf("Rank %d: input_buf[0] = %f\n", world_rank, __half2float(input_buf_cpu2[0])); 60 | 61 | DTYPE *output_buf_cpu2 = new DTYPE[N_ELEMENTS]; 62 | CUDACHECK(cudaMemcpy(output_buf_cpu2, output_buf, N_ELEMENTS * sizeof(DTYPE), cudaMemcpyDeviceToHost)); 63 | printf("Rank %d: output_buf[0] = %f\n", world_rank, __half2float(output_buf_cpu2[0])); 64 | 65 | cudaIpcMemHandle_t cur_rank_handle; 66 | cudaIpcMemHandle_t rank_handles[8]; 67 | 68 | CUDACHECK(cudaIpcGetMemHandle(&cur_rank_handle, state)); 69 | MPICHECK(MPI_Allgather(&cur_rank_handle, // void* send_data, 70 | sizeof(cudaIpcMemHandle_t), // int send_count, 71 | MPI_BYTE, // MPI_Datatype send_datatype, 72 | rank_handles, // void* recv_data, 73 | sizeof(cudaIpcMemHandle_t), // int recv_count, 74 | MPI_BYTE, // MPI_Datatype recv_datatype, 75 | MPI_COMM_WORLD // MPI_Comm communicator 76 | )); 77 | 78 | 79 | // Offsets are only necessary for Pytorch bindings 80 | // (where tensors are not allocated at the start of a cudaIpcMemHandle) 81 | // (that's why we set them to 0 here) 82 | std::vector offsets(world_size, 0); 83 | mysync::Sync sync(state, rank_handles, offsets, world_rank); 84 | 85 | { 86 | // register the buffer 87 | std::vector handles(world_size); 88 | handles.reserve(world_size); 89 | for (int i = 0; i < world_size; ++i) { 90 | char buffer[sizeof(cudaIpcMemHandle_t)]; 91 | memcpy(buffer, &rank_handles[i], sizeof(cudaIpcMemHandle_t)); 92 | handles[i] = std::string(buffer, sizeof(cudaIpcMemHandle_t)); 93 | 94 | /* 95 | char *begin = (char *)(&rank_handles[i]); 96 | char *end1 = (char *)(&rank_handles[i + 1]); 97 | char *end2 = begin + sizeof(cudaIpcMemHandle_t); 98 | assert(end1 == end2); 99 | std::string handle_str(begin, end1); 100 | handles.push_back(handle_str); 101 | handles.emplace_back(begin, end1); 102 | handles.push_back(begin, end1); 103 | */ 104 | } 105 | 106 | // { 107 | // for (int i = 0; i < world_size; ++i) { 108 | // if (i == world_rank) continue; // skip self (otherwise we get an 'invalid context' error) 109 | // cudaIpcMemHandle_t handle = rank_handles[i]; 110 | // // printf("Rank %d: opening handle %d before registration\n", world_rank, i); 111 | // char* ptr; 112 | // CUDACHECK(cudaIpcOpenMemHandle((void **)&ptr, handle, cudaIpcMemLazyEnablePeerAccess)); 113 | // printf("Rank %d: opened handle %d before 
registration\n", world_rank, i); 114 | // } 115 | // } 116 | 117 | // for (int i = 0; i < world_size; ++i) { 118 | // if (i == world_rank) continue; // skip self (otherwise we get an 'invalid context' error) 119 | // // cudaIpcMemHandle_t handle; 120 | // // memcpy(&handle, handles[i].data(), sizeof(cudaIpcMemHandle_t)); 121 | 122 | // char* ptr; 123 | // // CUDACHECK(cudaIpcOpenMemHandle((void **)&ptr, handle, cudaIpcMemLazyEnablePeerAccess)); 124 | // CUDACHECK(cudaIpcOpenMemHandle( 125 | // (void**)&ptr, *((const cudaIpcMemHandle_t *)handles[i].data()), 126 | // cudaIpcMemLazyEnablePeerAccess)); 127 | // printf("Rank %d: opened handle %d before registration (v2)\n", world_rank, i); 128 | // } 129 | 130 | std::vector buffer_offsets(world_size, sizeof(mysync::BarrierState)); 131 | sync.register_buffer(handles, buffer_offsets, input_buf); 132 | } 133 | 134 | sync.sync_test(N_ELEMENTS, output_buf); 135 | 136 | DTYPE *output_buf_cpu3 = new DTYPE[N_ELEMENTS]; 137 | CUDACHECK(cudaMemcpy(output_buf_cpu3, output_buf, N_ELEMENTS * sizeof(DTYPE), cudaMemcpyDeviceToHost)); 138 | printf("Rank %d: output_buf[0] = %f\n", world_rank, __half2float(output_buf_cpu3[0])); 139 | 140 | 141 | MPI_Finalize(); 142 | return EXIT_SUCCESS; 143 | } 144 | -------------------------------------------------------------------------------- /awq/gemm_kernel_v1.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2023 - Vedant Roy 2 | import torch 3 | import triton 4 | import triton.language as tl 5 | 6 | # from . import custom_autotune 7 | import custom_autotune 8 | 9 | 10 | # Auottuner configs: 11 | # https://github.com/fpgaminer/GPTQ-triton/blob/main/src/gptq_triton/quant_linear.py 12 | # https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html#sphx-glr-getting-started-tutorials-03-matrix-multiplication-py 13 | # Autotuner source: 14 | # https://github.com/openai/triton/blob/main/python/triton/runtime/autotuner.py 15 | # Custom autotuner: 16 | # https://github.com/fpgaminer/GPTQ-triton/blob/main/src/gptq_triton/custom_autotune.py 17 | @custom_autotune.autotune( 18 | configs=[ 19 | triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, 20 | num_warps=8), 21 | triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, 22 | num_warps=4), 23 | triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, 24 | num_warps=4), 25 | triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, 26 | num_warps=4), 27 | triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, 28 | num_warps=4), 29 | triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, 30 | num_warps=4), 31 | triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, 32 | num_warps=2), 33 | triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, 34 | num_warps=2), 35 | ], 36 | key=['M', 'N', 'K'], 37 | ) 38 | @triton.jit 39 | def quant_matmul_kernel( 40 | # Pointers to matrices 41 | a_ptr, qw_ptr, c_ptr, scales_ptr, zeros_ptr, 42 | # Matrix dimensions 43 | M, N, K, 44 | # Quantization parameters 45 | group_size, 46 | # Meta-parameters 47 | BLOCK_SIZE_M: tl.constexpr, 
BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, 48 | GROUP_SIZE_M: tl.constexpr, 49 | ): 50 | """ 51 | Kernel for computing the matmul C = A x qw 52 | 53 | a: (M, K) 54 | qw: (K // pack_num, N) 55 | scales: (K // group_size, N) 56 | qzeros: (K // group_size // pack_num, N) 57 | """ 58 | 59 | stride_zeros_k = N 60 | stride_scales_k = N 61 | stride_a_m = K 62 | stride_qw_k = N 63 | 64 | pid = tl.program_id(axis=0) 65 | num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) 66 | num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) 67 | num_pid_in_group = GROUP_SIZE_M * num_pid_n 68 | group_id = pid // num_pid_in_group 69 | first_pid_m = group_id * GROUP_SIZE_M 70 | group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) 71 | 72 | pid_m = first_pid_m + (pid % group_size_m) 73 | pid_n = (pid % num_pid_in_group) // group_size_m 74 | 75 | offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M 76 | offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) 77 | offs_k = tl.arange(0, BLOCK_SIZE_K) # (K,) 78 | qw_shifter = (offs_k % 8) * 4 79 | 80 | accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) 81 | for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): 82 | a_offs = (k * BLOCK_SIZE_K) + (offs_am[:, None] * stride_a_m + offs_k[None, :]) # (M, K) 83 | a = tl.load(a_ptr + a_offs) 84 | 85 | qw_offs = (((k * BLOCK_SIZE_K) + offs_k[:, None]) // 8) * stride_qw_k + offs_bn[ 86 | None, : 87 | ] # (K, N) 88 | qw_packed = tl.load(qw_ptr + qw_offs) # (K, N) 89 | 90 | qw_unpacked = (qw_packed >> qw_shifter[:, None]) & 0xF 91 | 92 | k_iters_per_quant_group = group_size // BLOCK_SIZE_K 93 | grp_idx = k // k_iters_per_quant_group 94 | 95 | col_offs = offs_bn 96 | scales = tl.load(scales_ptr + (stride_scales_k * grp_idx) + col_offs) # (N,) 97 | 98 | packed_zeros = tl.load( 99 | zeros_ptr + stride_zeros_k * (grp_idx // 8) + col_offs 100 | ) # (N,) 101 | unpacked_zeros = (packed_zeros >> ((grp_idx % 8) * 4)) & 0xF 102 | 103 | dequantized = scales[None, :].to(tl.float32) * ( 104 | qw_unpacked.to(tl.float32) - unpacked_zeros[None, :].to(tl.float32) 105 | ) 106 | accumulator += tl.dot(a, dequantized.to(tl.float16)) 107 | c = accumulator.to(tl.float16) 108 | 109 | offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) 110 | offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) 111 | stride_cm = N 112 | c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + offs_cn[None, :] 113 | c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) 114 | tl.store(c_ptrs, c, mask=c_mask) 115 | 116 | 117 | def quant_matmul(a, qw, qzeros, scales, *, M, N, K, pack_num, group_size): 118 | c = torch.empty((M, N), dtype=torch.float16, device=a.device) 119 | 120 | assert qw.shape == (K // pack_num, N) 121 | assert qzeros.shape == (K // group_size // pack_num, N) 122 | assert scales.shape == (K // group_size, N) 123 | assert all(x.is_contiguous() for x in [a, qw, c, qzeros, scales]) 124 | # BLOCK_SIZE_K has possible values of 32, 64 125 | # group_size, K must be divisible by BLOCK_SIZE_K 126 | assert group_size % 64 == 0, f"group_size {group_size} is not a multiple of 64" 127 | assert K % 64 == 0, f"K {K} is not a multiple of 64" 128 | # BLOCK_SIZE_N has possible values of 32, 64, 128, 256 129 | # N must be divisible by BLOCK_SIZE_N 130 | assert N % 256 == 0, f"N {N} is not a multiple of 256" 131 | 132 | grid_1d = lambda META: ( 133 | triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), 134 | ) 135 | quant_matmul_kernel[grid_1d]( 136 | a_ptr=a, 137 | qw_ptr=qw, 138 | c_ptr=c, 139 | scales_ptr=scales, 140 | 
zeros_ptr=qzeros, 141 | M=M, 142 | N=N, 143 | K=K, 144 | group_size=group_size, 145 | ) 146 | return c 147 | 148 | if __name__ == "__main__": 149 | # Tested with AWQ commit f0b4b68004f76d562658143cddea5aad8c1b8266 150 | import awq_inference_engine as ie 151 | 152 | M = torch.randint(0, 1000, (1,)).item() 153 | 154 | # M = 128 155 | N = K = 4096 156 | pack_num = 8 157 | group_size = 128 158 | 159 | print(f"Testing with M={M}") 160 | 161 | int32_bounds = (torch.iinfo(torch.int32).min, torch.iinfo(torch.int32).max) 162 | inputs = torch.randn((M, K), dtype=torch.float16, device="cuda") 163 | qweight = torch.randint(*int32_bounds, (N, K // pack_num), dtype=torch.int32, device="cuda") 164 | # TODO: Check what the proper magnitude of scales is 165 | # scales is always positive & needs to be very small (otherwise this test fails) ?? 166 | # Sample stats of scales from a single real layer 167 | # min: 0.0004451274871826172 168 | # max: 0.00801849365234375 169 | # std: 0.0006480216979980469 170 | # mean: 0.0015420913696289062) 171 | scales = 0.001 * torch.abs(torch.randn((N, K // group_size), dtype=torch.float16, device="cuda")) 172 | qzeros = torch.randint(*int32_bounds, (N, K // group_size // pack_num), dtype=torch.int32, device="cuda") 173 | 174 | out_cuda = ie.gemm_forward_cuda(inputs, qweight, scales, qzeros, group_size, 8) 175 | trans = lambda x: x.T.contiguous() 176 | out_triton = quant_matmul(inputs, trans(qweight), trans(qzeros), trans(scales), M=M, N=N, K=K, pack_num=pack_num, group_size=group_size) 177 | 178 | torch.testing.assert_close(out_cuda, out_triton, rtol=1e-3, atol=1e-3) -------------------------------------------------------------------------------- /allreduce/csrc/my_sync/sync.cuh: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | #define CUDACHECK(cmd) \ 10 | do { \ 11 | cudaError_t e = cmd; \ 12 | if (e != cudaSuccess) { \ 13 | printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \ 14 | cudaGetErrorString(e)); \ 15 | exit(EXIT_FAILURE); \ 16 | } \ 17 | } while (0) 18 | 19 | // #define CUDACHECK(cmd) cmd 20 | 21 | namespace mysync { 22 | struct Signal { 23 | alignas(64) union { 24 | uint64_t flag; 25 | unsigned char data[8]; 26 | } start; 27 | alignas(64) union { 28 | uint64_t flag; 29 | unsigned char data[8]; 30 | } end; 31 | }; 32 | 33 | struct BarrierState { 34 | alignas(128) Signal sg; 35 | alignas(128) int counter; 36 | }; 37 | static_assert(offsetof(BarrierState, counter) == 128); 38 | static_assert(sizeof(BarrierState) == 256); 39 | 40 | struct RankSignals { 41 | volatile Signal *signals[8]; 42 | }; 43 | 44 | __device__ uint64_t get_target_flag(int world_size) { 45 | // 64 1s 46 | auto m = std::numeric_limits::max(); 47 | // Each GPU gets 8 bits in the flag 48 | // E.g, if there are 4 GPUs, the target flag is 49 | // 32 0s followed by 32 1s 50 | return m >> ((8 - world_size) * 8); 51 | } 52 | 53 | __device__ void start_sync(const RankSignals &sg, volatile BarrierState *bstate, 54 | int rank, int world_size) { 55 | bool first_block_in_rank = blockIdx.x == 0; 56 | if (first_block_in_rank) { 57 | if (threadIdx.x < world_size) { 58 | int other_rank = threadIdx.x; 59 | // warp 1: notify all other ranks that this rank has reached the sync 60 | // point 61 | sg.signals[other_rank]->start.data[rank] = 255; 62 | } else if (threadIdx.x == 32) { 63 | // warp 2: reset the end signal 64 | bstate->sg.end.flag = 0; 65 | } 66 | } 67 | 68 | // busy-wait until 
the current rank's signal 69 | // has been written to by all ranks 70 | if (threadIdx.x == 0) { 71 | uint64_t target_flag = get_target_flag(world_size); 72 | while (bstate->sg.start.flag != target_flag) 73 | ; 74 | } 75 | if (threadIdx.x == 0 && first_block_in_rank) 76 | printf("1st block rank %d done busy-wait\n", rank); 77 | __syncthreads(); 78 | } 79 | 80 | __device__ void end_sync(const RankSignals &sg, volatile BarrierState *bstate, 81 | int rank, int world_size) { 82 | __shared__ int blocks_at_sync_point; 83 | if (threadIdx.x == 0) 84 | blocks_at_sync_point = atomicAdd((int *)&bstate->counter, 1); 85 | __syncthreads(); // (I think) this ensures `blocks_at_sync_point` is assigned 86 | 87 | bool last_block_at_sync_point = (blocks_at_sync_point == gridDim.x - 1); 88 | if (last_block_at_sync_point) { 89 | if (threadIdx.x < world_size) { 90 | int other_rank = threadIdx.x; 91 | // warp 1: notify all other ranks that this rank has reached the sync 92 | // point 93 | sg.signals[other_rank]->end.data[rank] = 255; 94 | } else if (threadIdx.x == 32) { 95 | // warp 2: reset the start signal + counter 96 | bstate->sg.start.flag = 0; 97 | bstate->counter = 0; 98 | } 99 | } 100 | 101 | // busy-wait until the current rank's signal 102 | // has been written to by all ranks 103 | if (threadIdx.x == 0) { 104 | uint64_t target_flag = get_target_flag(world_size); 105 | while (bstate->sg.end.flag != target_flag) 106 | ; 107 | } 108 | __syncthreads(); 109 | } 110 | 111 | // Any code using nanosleep here is incorrect 112 | // 1. nanosleep takes in a 32bit int 113 | // 2. nanosleep has a delay ranging from 0 to 2x the specified amount 114 | // 3. nanosleep can only delay a max of 1ms 115 | #define NS_PER_S (uint64_t)1000000000 116 | 117 | __global__ void sleepKernel() { 118 | uint64_t start, end; 119 | uint64_t sleepTime = 5 * NS_PER_S; // Sleep for 5 seconds 120 | 121 | if (threadIdx.x == 0) { 122 | // Record start time 123 | asm volatile("mov.u64 %0, %%globaltimer;" : "=l"(start)); 124 | 125 | // Sleep for 5 seconds 126 | __nanosleep(sleepTime); 127 | 128 | // Record end time 129 | asm volatile("mov.u64 %0, %%globaltimer;" : "=l"(end)); 130 | 131 | // Calculate and print the elapsed time in nanoseconds and milliseconds 132 | uint64_t elapsedNs = end - start; 133 | double elapsedMs = (double)elapsedNs / 1000000.0; 134 | printf("Slept for %llu nanoseconds (%.3f milliseconds)\n", elapsedNs, 135 | elapsedMs); 136 | } 137 | } 138 | 139 | // The %globaltimer register seems to not be working 140 | __global__ void sync_test_kernel(RankSignals sg, volatile BarrierState *bstate, 141 | int rank, int world_size) { 142 | 143 | int sleep_time = (rank * NS_PER_S) + (blockIdx.x * NS_PER_S * 0.1); 144 | uint64_t start, end; 145 | if (threadIdx.x == 0) { 146 | printf("rank %d, block %d, sleep time: %d\n", rank, blockIdx.x, sleep_time); 147 | asm volatile("mov.u64 %0, %%globaltimer;" : "=l"(start)); 148 | __nanosleep((rank * NS_PER_S) + (blockIdx.x * NS_PER_S * 0.1)); 149 | } 150 | __syncthreads(); 151 | 152 | start_sync(sg, bstate, rank, world_size); 153 | 154 | if (threadIdx.x == 0) { 155 | asm volatile("mov.u64 %0, %%globaltimer;" : "=l"(end)); 156 | printf("[start_sync] Hello from rank %d, block %d, elapsed time: %llu ns\n", 157 | rank, blockIdx.x, end - start); 158 | __nanosleep((rank * NS_PER_S) + (blockIdx.x * NS_PER_S * 0.1)); 159 | } 160 | __syncthreads(); 161 | 162 | end_sync(sg, bstate, rank, world_size); 163 | 164 | if (threadIdx.x == 0) { 165 | asm volatile("mov.u64 %0, %%globaltimer;" : "=l"(end)); 166 | 
printf("[end_sync] Hello from rank %d, block %d, elapsed time: %llu ns\n", 167 | rank, blockIdx.x, end - start); 168 | } 169 | __syncthreads(); 170 | } 171 | 172 | class Sync { 173 | public: 174 | int rank_; 175 | int world_size_; 176 | 177 | // below are device pointers 178 | RankSignals sg_; 179 | BarrierState *barrier_state_; 180 | 181 | std::vector ipc_handles_; 182 | 183 | Sync(BarrierState *barrier_state, const cudaIpcMemHandle_t *handles, 184 | const std::vector &offsets, int rank) 185 | : rank_(rank), world_size_(offsets.size()), 186 | barrier_state_(barrier_state) { 187 | for (int i = 0; i < world_size_; i++) { 188 | BarrierState *rank_barrier_state; 189 | if (i != rank_) { 190 | char *handle; 191 | CUDACHECK(cudaIpcOpenMemHandle((void **)&handle, handles[i], 192 | cudaIpcMemLazyEnablePeerAccess)); 193 | ipc_handles_.push_back(handle); 194 | handle += offsets[i]; 195 | rank_barrier_state = (BarrierState *)handle; 196 | } else { 197 | rank_barrier_state = barrier_state_; 198 | } 199 | // This is pure pointer math (no access to on-device memory) 200 | sg_.signals[i] = &rank_barrier_state->sg; 201 | } 202 | } 203 | 204 | void sync_test(int blocks, int threads) { 205 | if (threads % 32 != 0 || threads <= 32) { 206 | throw std::runtime_error( 207 | "Threads must be a multiple of 32 greater than 32"); 208 | } 209 | sleepKernel<<<1, 1>>>(); 210 | cudaDeviceSynchronize(); 211 | 212 | sync_test_kernel<<>>(sg_, barrier_state_, rank_, 213 | world_size_); 214 | cudaDeviceSynchronize(); 215 | } 216 | 217 | ~Sync() { 218 | printf("Rank %d calling destructor\n", rank_); 219 | for (auto ptr : ipc_handles_) { 220 | CUDACHECK(cudaIpcCloseMemHandle(ptr)); 221 | } 222 | } 223 | }; 224 | } // namespace mysync 225 | -------------------------------------------------------------------------------- /allreduce/notes.md: -------------------------------------------------------------------------------- 1 | ## Compilation 2 | ### add_one 3 | - modal (8 cpu) = ~70s 4 | - laptop = ~56s 5 | - vast (ryzen 9) = ~45s 6 | - vast (ryzen 9, no optimization) = ~40s 7 | - ~ 38s to run add_one.o 8 | - ~ 20s to compile pybind.o 9 | 10 | I suspect most of the compilation time for the very simple add_one extension comes from the torch headers. Let me benchmark compiling a simple kernel that doesn't use torch at all. 11 | 12 | ### mpi_cuda_helloworld 13 | Installation: 14 | - `sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev` 15 | - `sudo apt install libnccl-dev libnccl2` 16 | - Use `mpicc -show` to find where your mpi headers are 17 | 18 | Runtime is near instant. *Torch extension process must have been adding a lot of overhead* 19 | Run with `mpirun --allow-run-as-root -np 2 ./fastallreduce_test.bin` 20 | 21 | ## File Structure 22 | - fast_allreduce.cuh => implements allreduce w/o touching Pytorch 23 | - fast_allreduce.cu => torch bindings 24 | - fast_all_reduce_test.cu => driver to test (does not include fast_allreduce.cu) 25 | 26 | ## TODOs 27 | [01-01-24] 28 | - [ ] Get CUDA running on Modal 29 | - Abandoned, too much effort 30 | - [X] Figure out if register_graph_buffer is only used in tests 31 | - No, it's used in the Python bindings 32 | 33 | - Each *buffer* has a `RankData` 34 | - > The rank data is an array to store the registered IPC addresses for all ranks for each input address passed in. 35 | - `buffers_` is (pointer => `RankData`), the pointer is to an input on the current rank? 
36 | - The input in the CUDA test is `self_data`, which is the second section of the main test buffer 37 | - > The first section is a temporary buffer for storing intermediate allreduce results, if a particular algorithm requires it. The second section is for the input to the allreduce 38 | - Thus, `std::unordered_map buffers_;` is map of (pointer on current rank) => (pointers on all ranks) 39 | - Also, `RankData *` is a pointer to `RankData` stored on GPU memory 40 | - `*d_rank_data_base` and `*d_rank_data_end` are pointers that mark the start and end, respectively, of a segment of GPU memory allocated for storing `RankData` instances. As new `RankData` instances are copied to the device, `*d_rank_data_base` is incremented, effectively moving the 'start' pointer forward. This means that `*d_rank_data_base` always points to the next available location within the allocated memory segment where new RankData can be copied. 41 | - `ipc_handles_` stores all the ipc handles so they can be closed once the `FastAllreduce` class is destroyed 42 | - `Metadata` stores a `Signal` + a counter (both for the current rank). The `Signal` contains a start/end field, both of which are a `union` of 64-bit int and 8-bytes. The `Signal` is a synchronization primitive. 43 | - `RankSignals` consists of an array of 8 device-pointers, each pointing to a `Signal` on a different rank 44 | - `RankSignals` itself + the device-pointers are stored in CPU memory 45 | 46 | - [X] Understand the test data 47 | - nccl reduces `self_data_copy` => `self_data` 48 | - custom impl reduces `self_data` => `result` 49 | - `self_data` is the 2nd section of the buffer 50 | - [ ] Understand synchronization primitive 51 | 52 | - [X] Roughly understand the barrier code 53 | - Grid shape: `<<>>` 54 | - 1D block grid + 1D thread grid per block (evenly splitting `size` over the # of blocks) 55 | - `packed_t` explanation: 56 | ```cpp 57 | /* 58 | Maximize memory efficiency w/ ld.128,st.128 instructions 59 | The GPU cannot load more than 128 bytes at a time 60 | CUDA threads can read a maximum of 16 bytes at once 61 | So, each thread loads as many elements of type T as can be accommodated within 16 bytes 62 | https://stackoverflow.com/questions/72147025/what-are-cuda-global-memory-32-64-and-128-byte-transactions 63 | */ 64 | template 65 | struct packed_t { 66 | // the (P)acked type for load/store 67 | using P = array_t; 68 | // the (A)ccumulator type for reduction 69 | using A = array_t; 70 | }; 71 | ``` 72 | 73 | Start sync: 74 | - The 1st block in each rank's grid uses its 1st warp to write to the signal in all other ranks and its 2nd warp to reset the end sync flag 75 | - All threads busy-wait until the 1st block per rank has written to all other ranks 76 | - *I guess this is 1st block synchronization?* 77 | 78 | End sync: 79 | - The final block to *reach* the sync point uses its 2nd warp to reset the start flag and its 1st warp to write to the signal across all other ranks 80 | - Its important we use the final block to *reach* the sync point for writing to the signal across all other ranks. 81 | - To illustrate the above point, consider a scenario where a rank has three blocks: block0, block1, and block2. Let's assume that block0 is used to write to the signal and the blocks reach the sync point in the order of block0, block1, block2. In this case, block0 and block1 would proceed to the next task before block2 has necessarily completed its previous task. 
Similarly, block0 would move on to the next task before block1 has necessarily completed its previous task. 82 | - **QUESTION**: If the kernel exit can serve as a sync, then why even sync at all? 83 | 84 | - [X] Understand where the metadata is stored for the Pytorch bindings 85 | - [X] Read the torch binding functions 86 | - Expectation (seems confirmed?): 87 | - 1 metadata: place to store signals + counter for current rank 88 | - a register_buffer function / something to create new rank data (requires passing in IPC handles to torch) 89 | 90 | ```python 91 | def manual_registration(world_size, rank, distributed_init_port): 92 | init_test_distributed_environment(1, world_size, rank, 93 | distributed_init_port) 94 | sz = 1024 95 | fast_ar.init_fast_ar() 96 | fa = fast_ar.get_handle() 97 | inp = torch.ones(sz, 98 | dtype=torch.float32, 99 | device=torch.cuda.current_device()) 100 | fa.register_buffer(inp) 101 | out = fa.all_reduce(inp) 102 | assert torch.allclose(out, inp * world_size) 103 | ``` 104 | - [X] Understand where `init_fast_ar` and `get_handle` are called 105 | ```python 106 | class FastAllreduce: 107 | 108 | # max_size: max supported allreduce size 109 | def __init__(self, rank, world_size, max_size=8192 * 1024) -> None: 110 | # buffers memory are owned by this Python class and passed to C++ 111 | self.meta = torch.zeros(fast_ar.meta_size() + max_size, 112 | dtype=torch.uint8, 113 | device="cuda") 114 | self.rank_data = torch.empty(16 * 1024 * 1024, 115 | dtype=torch.uint8, 116 | device="cuda") 117 | self.max_size = max_size 118 | self.world_size = world_size 119 | handles, offsets = self._get_ipc_meta(self.meta) 120 | self.full_nvlink = _is_full_nvlink(rank, world_size) 121 | self._ptr = fast_ar.init_fast_ar(self.meta, self.rank_data, handles, 122 | offsets, rank, self.full_nvlink) 123 | self.fast_cond = self.full_nvlink or world_size <= 2 124 | 125 | def _get_ipc_meta(self, inp: torch.Tensor): 126 | data = inp.storage()._share_cuda_() 127 | shard_data = ( 128 | data[1], # ipc handle to base ptr 129 | data[3], # offset of base ptr 130 | ) 131 | return self._gather_ipc_meta(shard_data) 132 | 133 | def _gather_ipc_meta(self, shard_data): 134 | all_data = [None] * self.world_size 135 | dist.all_gather_object(all_data, shard_data) 136 | 137 | handles = [] 138 | offsets = [] 139 | for i in range(len(all_data)): 140 | handles.append(all_data[i][0]) 141 | offsets.append(all_data[i][1]) 142 | return handles, offsets 143 | 144 | def register_buffer(self, inp: torch.Tensor): 145 | handles, offsets = self._get_ipc_meta(inp) 146 | fast_ar.register_buffer(self._ptr, inp, handles, offsets) 147 | ``` 148 | - [X] Understand the pytorch bindings for single buffer registration 149 | - Pytorch offers an API, `_share_cuda_`, to get the IPC handle + offset of a given tensor. (I'm guessing CUDA memory is allocated in large blocks, and tensors are given subsections) 150 | 151 | Return values of `_share_cuda_`: 152 | ```cpp 153 | PyTuple_SET_ITEM(tuple.get(), 0, device.release()); 154 | // cudaIpcMemHandle_t(of basePtr) 155 | PyTuple_SET_ITEM(tuple.get(), 1, _handle.release()); 156 | // Size(in bytes) of the real storage, note this is not the size of basePtr 157 | // memory block. 158 | PyTuple_SET_ITEM(tuple.get(), 2, size_bytes.release()); 159 | // Offset(in bytes) of the real storage in the basePtr memory block. 160 | // NB: this offset MUST be in bytes instead of numel, since we use 161 | // (storage_handle, offset) 162 | // as key in shared_cache(multiprocessing/reduction.py). 
163 | // Offset in numel cannot uniquely represent a storage. 164 | PyTuple_SET_ITEM(tuple.get(), 3, _offset_bytes.release()); 165 | PyTuple_SET_ITEM(tuple.get(), 4, _ref_counter.release()); 166 | PyTuple_SET_ITEM(tuple.get(), 5, _ref_counter_offset.release()); 167 | PyTuple_SET_ITEM(tuple.get(), 6, _event_handle.release()); 168 | PyTuple_SET_ITEM(tuple.get(), 7, _event_sync_required.release()); 169 | ``` 170 | 171 | - [X] Email author asking why `end_sync` is needed? *Anything else I should email him on?* 172 | - [X] Implement the synchronization primitive 173 | - [X] Debug the sync primitive 174 | - [X] Try tools like cuda-memcheck 175 | - [X] Narrowed it down to 1-line 176 | - CUDA memcheck is fine & all, but be methodical. There are very few places in this kernel where memory is read/written 177 | - [X] **Do not pass references to CPU memory to CUDA kernels** 178 | - We can write through `BarrierState` but not `RankSignals`? 179 | - The signal's memory address is being set correctly ... 180 | - (Not sure if I actually tried this), we can't even print out the memory address of `RankSignals` in the kernel?? 181 | - [X] Sanity-check by adding print statements + nanosleep -- *nevermind, `__nanosleep` is not useful* 182 | - Each rank delays by a second, each block delays by a second 183 | - We print the total # of clock cycles ?? before reaching the start sync point 184 | - We print the total # of clock cycles after the sync point 185 | - Each rank delays by a second, each block delays by a 0.1 second 186 | - We print the total # of clock cycles ?? before reaching the end sync point 187 | - We print the total # of clock cycles after the end sync point 188 | - Expectation: after the start sync, the minimum clock cycles taken = (secs * # ranks) 189 | - Expectation: after the end sync, the min clock cycles taken = 2 * (secs * #ranks) + (2 * 0.1secs * #blocks) -- *the end sync should be equal to time of last block on last rank* -------------------------------------------------------------------------------- /allreduce/csrc/reference_allreduce/fast_allreduce_test.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include 7 | #include 8 | 9 | #include "cuda_profiler_api.h" 10 | #include "fast_allreduce.cuh" 11 | #include "mpi.h" 12 | #include "nccl.h" 13 | 14 | #define MPICHECK(cmd) \ 15 | do { \ 16 | int e = cmd; \ 17 | if (e != MPI_SUCCESS) { \ 18 | printf("Failed: MPI error %s:%d '%d'\n", __FILE__, __LINE__, e); \ 19 | exit(EXIT_FAILURE); \ 20 | } \ 21 | } while (0) 22 | 23 | #define NCCLCHECK(cmd) \ 24 | do { \ 25 | ncclResult_t r = cmd; \ 26 | if (r != ncclSuccess) { \ 27 | printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, \ 28 | ncclGetErrorString(r)); \ 29 | exit(EXIT_FAILURE); \ 30 | } \ 31 | } while (0) 32 | 33 | __global__ void dummy_kernel() { 34 | for (int i = 0; i < 100; i++) __nanosleep(1000000); // 100ms 35 | } 36 | 37 | template 38 | __global__ void set_data(T *data, int size, int myRank) { 39 | for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; 40 | idx += gridDim.x * blockDim.x) { 41 | data[idx] = myRank * 0.11f; 42 | } 43 | } 44 | 45 | template 46 | __global__ void convert_data(const T *data1, const T *data2, double *fdata1, 47 | double *fdata2, int size) { 48 | for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; 49 | idx += gridDim.x * blockDim.x) { 50 | fdata1[idx] = data1[idx]; 51 | fdata2[idx] = data2[idx]; 52 | } 53 | } 54 | 55 | 
__global__ void init_rand(curandState_t *state, int size, int nRanks) { 56 | for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; 57 | idx += gridDim.x * blockDim.x) { 58 | for (int i = 0; i < nRanks; i++) { 59 | curand_init(i + 1, idx, 0, &state[idx * nRanks + i]); 60 | } 61 | } 62 | } 63 | 64 | template 65 | __global__ void gen_data(curandState_t *state, T *data, double *ground_truth, 66 | int myRank, int nRanks, int size) { 67 | for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; 68 | idx += gridDim.x * blockDim.x) { 69 | double sum = 0.0; 70 | for (int i = 0; i < nRanks; i++) { 71 | double val = curand_uniform_double(&state[idx * nRanks + i]) * 4; 72 | T hval = val; // downcast first 73 | sum += static_cast(hval); 74 | if (i == myRank) data[idx] = hval; 75 | } 76 | ground_truth[idx] = sum; 77 | } 78 | } 79 | 80 | // T is half/etc. 81 | template 82 | void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, 83 | int data_size) { 84 | T *result; 85 | cudaStream_t stream; 86 | CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); 87 | CUDACHECK(cudaMalloc(&result, data_size * sizeof(T))); 88 | CUDACHECK(cudaMemset(result, 0, data_size * sizeof(T))); 89 | 90 | cudaIpcMemHandle_t self_data_handle; 91 | cudaIpcMemHandle_t data_handles[8]; 92 | 93 | // 256 bytes => signal + counter (128 byte aligned) 94 | vllm::Metadata *buffer; 95 | T *self_data_copy; 96 | // Allocate 2 * data_size + space for vllm:Metadata 97 | CUDACHECK( 98 | cudaMalloc(&buffer, 2 * data_size * sizeof(T) + sizeof(vllm::Metadata))); 99 | CUDACHECK(cudaMemset(buffer, 0, 100 | 2 * data_size * sizeof(T) + sizeof(vllm::Metadata))); 101 | CUDACHECK(cudaMalloc(&self_data_copy, data_size * sizeof(T))); 102 | 103 | // Create a handle to the start of the buffer 104 | CUDACHECK(cudaIpcGetMemHandle(&self_data_handle, buffer)); 105 | 106 | // MPI_Allgather( 107 | // void* send_data, 108 | // int send_count, 109 | // MPI_Datatype send_datatype, 110 | // void* recv_data, 111 | // int recv_count, 112 | // MPI_Datatype recv_datatype, 113 | // MPI_Comm communicator) 114 | 115 | MPICHECK(MPI_Allgather(&self_data_handle, sizeof(cudaIpcMemHandle_t), 116 | MPI_BYTE, data_handles, sizeof(cudaIpcMemHandle_t), 117 | MPI_BYTE, MPI_COMM_WORLD)); 118 | 119 | void *rank_data; 120 | size_t rank_data_sz = 16 * 1024 * 1024; 121 | CUDACHECK(cudaMalloc(&rank_data, rank_data_sz)); 122 | std::vector offsets(nRanks, 0); 123 | vllm::FastAllreduce fa(buffer, rank_data, rank_data_sz, data_handles, offsets, 124 | myRank); 125 | auto *self_data = 126 | reinterpret_cast(reinterpret_cast(buffer) + 127 | sizeof(vllm::Metadata) + data_size * sizeof(T)); 128 | // hack buffer registration 129 | { 130 | std::vector handles; 131 | handles.reserve(nRanks); 132 | for (int i = 0; i < nRanks; i++) { 133 | char *begin = (char *)&data_handles[i]; 134 | char *end = (char *)&data_handles[i + 1]; 135 | // handles.emplace_back(begin, end); 136 | handles.emplace_back(begin, end); 137 | } 138 | std::vector offsets( 139 | nRanks, sizeof(vllm::Metadata) + data_size * sizeof(T)); 140 | fa.register_buffer(handles, offsets, self_data); 141 | } 142 | 143 | // Ground truth data 144 | double *verification_buffer; 145 | CUDACHECK(cudaMallocHost(&verification_buffer, data_size * sizeof(double))); 146 | 147 | curandState_t *states; 148 | CUDACHECK(cudaMalloc(&states, sizeof(curandState_t) * nRanks * data_size)); 149 | 150 | // blocks, threads per block 151 | init_rand<<<108, 1024, 0, stream>>>(states, data_size, nRanks); 152 | 
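// What follows for each data size:
//   1. gen_data fills self_data with this rank's random half inputs and
//      verification_buffer with the double-precision ground-truth sums.
//   2. self_data is copied to self_data_copy so the NCCL correctness run and
//      the custom allreduce see identical inputs.
//   3. NCCL allreduce is benchmarked (5 warmup + 25 timed iterations), then
//      fa.allreduce is benchmarked with the same counts.
//   4. NCCL reduces self_data_copy into self_data; convert_data casts both
//      results to double so they can be compared element-wise and against the
//      ground truth.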
153 | // self_data is inputs for current rank 154 | gen_data<<<108, 1024, 0, stream>>>(states, self_data, verification_buffer, 155 | myRank, nRanks, data_size); 156 | 157 | CUDACHECK(cudaMemcpyAsync(self_data_copy, self_data, data_size * sizeof(T), 158 | cudaMemcpyDeviceToDevice, stream)); 159 | cudaEvent_t start, stop; 160 | CUDACHECK(cudaEventCreate(&start)); 161 | CUDACHECK(cudaEventCreate(&stop)); 162 | 163 | ncclDataType_t ncclDtype; 164 | if (std::is_same::value) { 165 | ncclDtype = ncclFloat16; 166 | } else if (std::is_same::value) { 167 | ncclDtype = ncclBfloat16; 168 | } else { 169 | ncclDtype = ncclFloat; 170 | } 171 | 172 | dummy_kernel<<<1, 1, 0, stream>>>(); 173 | constexpr int warmup_iters = 5; 174 | constexpr int num_iters = 25; 175 | // warmup 176 | for (int i = 0; i < warmup_iters; i++) { 177 | NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum, comm, 178 | stream)); 179 | } 180 | CUDACHECK(cudaEventRecord(start, stream)); 181 | for (int i = 0; i < num_iters; i++) { 182 | NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum, comm, 183 | stream)); 184 | } 185 | CUDACHECK(cudaEventRecord(stop, stream)); 186 | CUDACHECK(cudaStreamSynchronize(stream)); 187 | float allreduce_ms = 0; 188 | cudaEventElapsedTime(&allreduce_ms, start, stop); 189 | 190 | // if (myRank == 1) dummy_kernel<<<1, 1, 0, stream>>>(); 191 | // set_data<<<16, 1024, 0, stream>>>(self_data, data_size, myRank); 192 | 193 | dummy_kernel<<<1, 1, 0, stream>>>(); 194 | // warm up 195 | for (int i = 0; i < warmup_iters; i++) { 196 | // Method sig: 197 | // cudaStream_t, half *, half *, int, int, int 198 | fa.allreduce(stream, self_data, result, data_size, threads, block_limit); 199 | } 200 | CUDACHECK(cudaEventRecord(start, stream)); 201 | for (int i = 0; i < num_iters; i++) { 202 | fa.allreduce(stream, self_data, result, data_size, threads, block_limit); 203 | } 204 | CUDACHECK(cudaEventRecord(stop, stream)); 205 | CUDACHECK(cudaStreamSynchronize(stream)); 206 | 207 | float duration_ms = 0; 208 | cudaEventElapsedTime(&duration_ms, start, stop); 209 | if (myRank == 0) 210 | printf( 211 | "Rank %d done, nGPUs:%d, sz (kb): %d, %d, %d, my time:%.2fus, nccl " 212 | "time:%.2fus\n", 213 | myRank, nRanks, data_size * sizeof(T) / 1024, threads, block_limit, 214 | duration_ms * 1e3 / num_iters, allreduce_ms * 1e3 / num_iters); 215 | 216 | // And wait for all the queued up work to complete 217 | CUDACHECK(cudaStreamSynchronize(stream)); 218 | 219 | // nccl result is in self_data 220 | NCCLCHECK(ncclAllReduce(self_data_copy, self_data, data_size, ncclDtype, 221 | ncclSum, comm, stream)); 222 | 223 | double *nccl_result, *my_result; 224 | CUDACHECK(cudaMallocHost(&nccl_result, data_size * sizeof(double))); 225 | CUDACHECK(cudaMallocHost(&my_result, data_size * sizeof(double))); 226 | 227 | // copy self_data -> nccl_result 228 | // copy result -> my_result 229 | convert_data<<<108, 1024, 0, stream>>>(self_data, result, nccl_result, 230 | my_result, data_size); 231 | CUDACHECK(cudaStreamSynchronize(stream)); 232 | 233 | for (unsigned long j = 0; j < data_size; j++) { 234 | auto diff = abs(nccl_result[j] - my_result[j]); 235 | if (diff >= 1e-2) { 236 | printf("Rank %d: Verification mismatch at %lld: %f != (my) %f\n", myRank, 237 | j, nccl_result[j], my_result[j]); 238 | break; 239 | } 240 | } 241 | 242 | long double nccl_diffs = 0.0; 243 | long double my_diffs = 0.0; 244 | for (int j = 0; j < data_size; j++) { 245 | nccl_diffs += abs(nccl_result[j] - verification_buffer[j]); 246 | my_diffs += 
abs(my_result[j] - verification_buffer[j]); 247 | } 248 | if (myRank == 0) 249 | std::cout << "average abs diffs: nccl: " << nccl_diffs / data_size 250 | << " me: " << my_diffs / data_size << std::endl; 251 | 252 | CUDACHECK(cudaFree(result)); 253 | CUDACHECK(cudaFree(self_data_copy)); 254 | CUDACHECK(cudaFree(rank_data)); 255 | CUDACHECK(cudaFree(buffer)); 256 | CUDACHECK(cudaFree(states)); 257 | CUDACHECK(cudaFreeHost(verification_buffer)); 258 | CUDACHECK(cudaFreeHost(nccl_result)); 259 | CUDACHECK(cudaFreeHost(my_result)); 260 | CUDACHECK(cudaStreamDestroy(stream)); 261 | } 262 | 263 | int main(int argc, char **argv) { 264 | int nRanks, myRank; 265 | MPICHECK(MPI_Init(&argc, &argv)); 266 | MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank)); 267 | MPICHECK(MPI_Comm_size(MPI_COMM_WORLD, &nRanks)); 268 | CUDACHECK(cudaSetDevice(myRank)); 269 | ncclUniqueId id; 270 | ncclComm_t comm; 271 | if (myRank == 0) ncclGetUniqueId(&id); 272 | MPICHECK(MPI_Bcast(static_cast(&id), sizeof(id), MPI_BYTE, 0, 273 | MPI_COMM_WORLD)); 274 | NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank)); 275 | 276 | cudaProfilerStart(); 277 | // for (int threads : {256, 512}) { 278 | // for (int block_limit = 16; block_limit < 112; block_limit += 4) { 279 | // run(myRank, nRanks, comm, threads, block_limit, 4096 * 1024); 280 | // } 281 | // } 282 | for (int sz = 512; sz <= (4 << 20); sz *= 2) { 283 | run(myRank, nRanks, comm, 512, 36, sz + 8 * 50); 284 | } 285 | 286 | cudaProfilerStop(); 287 | MPI_Finalize(); // prevents MPI error 288 | return EXIT_SUCCESS; 289 | } -------------------------------------------------------------------------------- /awq/custom_autotune.py: -------------------------------------------------------------------------------- 1 | # Exactly the same as the normal autotuner 2 | # but it prints the best config 3 | from __future__ import annotations 4 | 5 | import builtins 6 | import time 7 | from typing import Dict 8 | 9 | from triton import KernelInterface 10 | from triton.testing import do_bench 11 | 12 | 13 | class OutOfResources(Exception): 14 | def __init__(self, required, limit, name): 15 | self.message = f'out of resource: {name}, '\ 16 | f'Required: {required}, '\ 17 | f'Hardware limit: {limit}' 18 | self.message += '. Reducing block sizes or `num_stages` may help.' 19 | self.required = required 20 | self.limit = limit 21 | self.name = name 22 | super().__init__(self.message) 23 | 24 | def __reduce__(self): 25 | # this is necessary to make CompilationError picklable 26 | return (type(self), (self.required, self.limit, self.name)) 27 | 28 | 29 | class Autotuner(KernelInterface): 30 | def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None, warmup=25, rep=100): 31 | ''' 32 | :param prune_configs_by: a dict of functions that are used to prune configs, fields: 33 | 'perf_model': performance model used to predicate running time with different configs, returns running time 34 | 'top_k': number of configs to bench 35 | 'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs. 
36 | ''' 37 | if not configs: 38 | self.configs = [Config({}, num_warps=4, num_stages=2)] 39 | else: 40 | self.configs = configs 41 | self.key_idx = [arg_names.index(k) for k in key] 42 | self.cache = {} 43 | # hook to reset all required tensor to zeros before relaunching a kernel 44 | self.hook = lambda args: 0 45 | if reset_to_zero is not None: 46 | self.reset_idx = [arg_names.index(k) for k in reset_to_zero] 47 | 48 | def _hook(args): 49 | for i in self.reset_idx: 50 | args[i].zero_() 51 | self.hook = _hook 52 | self.arg_names = arg_names 53 | # prune configs 54 | if prune_configs_by: 55 | perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k'] 56 | if 'early_config_prune' in prune_configs_by: 57 | early_config_prune = prune_configs_by['early_config_prune'] 58 | else: 59 | perf_model, top_k, early_config_prune = None, None, None 60 | self.perf_model, self.configs_top_k = perf_model, top_k 61 | self.early_config_prune = early_config_prune 62 | self.fn = fn 63 | self.warmup = warmup 64 | self.rep = rep 65 | 66 | def _bench(self, *args, config, **meta): 67 | # check for conflicts, i.e. meta-parameters both provided 68 | # as kwargs and by the autotuner 69 | conflicts = meta.keys() & config.kwargs.keys() 70 | if conflicts: 71 | raise ValueError( 72 | f"Conflicting meta-parameters: {', '.join(conflicts)}." 73 | " Make sure that you don't re-define auto-tuned symbols." 74 | ) 75 | # augment meta-parameters with tunable ones 76 | current = dict(meta, **config.kwargs) 77 | full_nargs = {**self.nargs, **current} 78 | 79 | def kernel_call(): 80 | if config.pre_hook: 81 | config.pre_hook(full_nargs) 82 | self.hook(args) 83 | self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current) 84 | try: 85 | return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8)) 86 | except OutOfResources: 87 | return [float('inf'), float('inf'), float('inf')] 88 | 89 | def run(self, *args, **kwargs): 90 | self.nargs = dict(zip(self.arg_names, args)) 91 | if len(self.configs) > 1: 92 | all_args = {**self.nargs, **kwargs} 93 | _args = [] 94 | for name in self.arg_names: 95 | if name in all_args: 96 | _args.append(all_args[name]) 97 | key = tuple(_args[i] for i in self.key_idx) 98 | if key not in self.cache: 99 | # prune configs 100 | pruned_configs = self.prune_configs(kwargs) 101 | bench_start = time.time() 102 | timings = {config: self._bench(*args, config=config, **kwargs) 103 | for config in pruned_configs} 104 | bench_end = time.time() 105 | self.bench_time = bench_end - bench_start 106 | self.cache[key] = builtins.min(timings, key=timings.get) 107 | self.hook(args) 108 | self.configs_timings = timings 109 | config = self.cache[key] 110 | else: 111 | config = self.configs[0] 112 | self.best_config = config 113 | print(f"Best config: {config}") 114 | if config.pre_hook is not None: 115 | full_nargs = {**self.nargs, **kwargs, **self.best_config.kwargs} 116 | config.pre_hook(full_nargs) 117 | ret = self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs) 118 | self.nargs = None 119 | return ret 120 | 121 | def prune_configs(self, kwargs): 122 | pruned_configs = self.configs 123 | if self.early_config_prune: 124 | pruned_configs = self.early_config_prune(self.configs, self.nargs) 125 | if self.perf_model: 126 | top_k = self.configs_top_k 127 | if isinstance(top_k, float) and top_k <= 1.0: 128 | top_k = int(len(self.configs) * top_k) 129 | if len(pruned_configs) > top_k: 130 | est_timing = { 131 
| config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, 132 | num_warps=config.num_warps) 133 | for config in pruned_configs 134 | } 135 | pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] 136 | return pruned_configs 137 | 138 | def warmup(self, *args, **kwargs): 139 | self.nargs = dict(zip(self.arg_names, args)) 140 | for config in self.prune_configs(kwargs): 141 | self.fn.warmup( 142 | *args, 143 | num_warps=config.num_warps, 144 | num_stages=config.num_stages, 145 | **kwargs, 146 | **config.kwargs, 147 | ) 148 | self.nargs = None 149 | 150 | 151 | class Config: 152 | """ 153 | An object that represents a possible kernel configuration for the auto-tuner to try. 154 | 155 | :ivar meta: a dictionary of meta-parameters to pass to the kernel as keyword arguments. 156 | :type meta: dict[Str, Any] 157 | :ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if 158 | `num_warps=8`, then each kernel instance will be automatically parallelized to 159 | cooperatively execute using `8 * 32 = 256` threads. 160 | :type num_warps: int 161 | :ivar num_stages: the number of stages that the compiler should use when software-pipelining loops. 162 | Mostly useful for matrix multiplication workloads on SM80+ GPUs. 163 | :type num_stages: int 164 | :ivar pre_hook: a function that will be called before the kernel is called. Parameters of this 165 | function are args. 166 | """ 167 | 168 | def __init__(self, kwargs, num_warps=4, num_stages=2, pre_hook=None): 169 | self.kwargs = kwargs 170 | self.num_warps = num_warps 171 | self.num_stages = num_stages 172 | self.pre_hook = pre_hook 173 | 174 | def __str__(self): 175 | res = [] 176 | for k, v in self.kwargs.items(): 177 | res.append(f'{k}: {v}') 178 | res.append(f'num_warps: {self.num_warps}') 179 | res.append(f'num_stages: {self.num_stages}') 180 | return ', '.join(res) 181 | 182 | 183 | def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, warmup=25, rep=100): 184 | """ 185 | Decorator for auto-tuning a :code:`triton.jit`'d function. 186 | 187 | .. highlight:: python 188 | .. code-block:: python 189 | 190 | @triton.autotune(configs=[ 191 | triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4), 192 | triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8), 193 | ], 194 | key=['x_size'] # the two above configs will be evaluated anytime 195 | # the value of x_size changes 196 | ) 197 | @triton.jit 198 | def kernel(x_ptr, x_size, **META): 199 | BLOCK_SIZE = META['BLOCK_SIZE'] 200 | :note: When all the configurations are evaluated, the kernel will run multiple times. 201 | This means that whatever value the kernel updates will be updated multiple times. 202 | To avoid this undesired behavior, you can use the `reset_to_zero` argument, which 203 | resets the value of the provided tensor to `zero` before running any configuration. 204 | :param configs: a list of :code:`triton.Config` objects 205 | :type configs: list[triton.Config] 206 | :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. 207 | :type key: list[str] 208 | :param prune_configs_by: a dict of functions that are used to prune configs, fields: 209 | 'perf_model': performance model used to predicate running time with different configs, returns running time 210 | 'top_k': number of configs to bench 211 | 'early_config_prune'(optional): a function used to do early prune (eg, num_stages). 
It takes configs:List[Config] as its input, and returns pruned configs. 212 | :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs. 213 | :type reset_to_zero: list[str] 214 | :param warmup: Warmup time (in ms) to pass to benchmarking, defaults to 25. 215 | :type warmup: int 216 | :param rep: Repetition time (in ms) to pass to benchmarking, defaults to 100. 217 | :type rep: int 218 | """ 219 | def decorator(fn): 220 | return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by, warmup, rep) 221 | 222 | return decorator 223 | 224 | 225 | class Heuristics(KernelInterface): 226 | 227 | def __init__(self, fn, arg_names, values) -> None: 228 | self.fn = fn 229 | self.values = values 230 | self.arg_names = arg_names 231 | 232 | def run(self, *args, **kwargs): 233 | for v, heur in self.values.items(): 234 | kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs}) 235 | return self.fn.run(*args, **kwargs) 236 | 237 | 238 | def heuristics(values): 239 | """ 240 | Decorator for specifying how the values of certain meta-parameters may be computed. 241 | This is useful for cases where auto-tuning is prohibitevely expensive, or just not applicable. 242 | 243 | .. highlight:: python 244 | .. code-block:: python 245 | 246 | @triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))}) 247 | @triton.jit 248 | def kernel(x_ptr, x_size, **META): 249 | BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size 250 | :param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter. 251 | each such function takes a list of positional arguments as input. 252 | :type values: dict[str, Callable[[list[Any]], Any]] 253 | """ 254 | def decorator(fn): 255 | return Heuristics(fn, fn.arg_names, values) 256 | 257 | return decorator 258 | -------------------------------------------------------------------------------- /allreduce/csrc/my_allreduce/allreduce.cuh: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | #define CUDACHECK(cmd) \ 10 | do { \ 11 | cudaError_t e = cmd; \ 12 | if (e != cudaSuccess) { \ 13 | printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \ 14 | cudaGetErrorString(e)); \ 15 | exit(EXIT_FAILURE); \ 16 | } \ 17 | } while (0) 18 | 19 | #define CHECK_NOT_NULL(ptr) \ 20 | do { \ 21 | if (ptr == NULL) { \ 22 | printf("Failed: Null pointer error %s:%d\n", __FILE__, __LINE__); \ 23 | exit(EXIT_FAILURE); \ 24 | } \ 25 | } while (0) 26 | 27 | // #define CUDACHECK(cmd) cmd 28 | 29 | namespace mysync { 30 | struct Signal { 31 | alignas(64) union { 32 | uint64_t flag; 33 | unsigned char data[8]; 34 | } start; 35 | alignas(64) union { 36 | uint64_t flag; 37 | unsigned char data[8]; 38 | } end; 39 | }; 40 | 41 | struct BarrierState { 42 | alignas(128) Signal sg; 43 | alignas(128) int counter; 44 | }; 45 | static_assert(offsetof(BarrierState, counter) == 128); 46 | static_assert(sizeof(BarrierState) == 256); 47 | 48 | struct RankSignals { 49 | volatile Signal *signals[8]; 50 | }; 51 | 52 | // TODO: Why alignment of 16? 53 | // TODO: Why even put this on the GPU at all? 
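// How the Signal flags above are used: each rank owns one byte of the 64-bit
// start/end flag. A rank announces itself by writing 255 into data[rank] of
// every rank's Signal (including its own), and a barrier is complete once the
// low `world_size` bytes of the local flag are all 0xFF, i.e. once the flag
// equals get_target_flag(world_size) defined below:
//
//   world_size = 2  ->  0x000000000000FFFF
//   world_size = 4  ->  0x00000000FFFFFFFF
//   world_size = 8  ->  0xFFFFFFFFFFFFFFFF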
54 | 55 | // struct __align__(16) RankPtrs { 56 | struct __align__(16) RankPtrs { 57 | const void *__restrict__ ptrs[8]; 58 | }; 59 | 60 | __device__ uint64_t get_target_flag(int world_size) { 61 | // 64 1s 62 | auto m = std::numeric_limits::max(); 63 | // Each GPU gets 8 bits in the flag 64 | // E.g, if there are 4 GPUs, the target flag is 65 | // 32 0s followed by 32 1s 66 | return m >> ((8 - world_size) * 8); 67 | } 68 | 69 | __device__ void start_sync(const RankSignals &sg, volatile BarrierState *bstate, 70 | int rank, int world_size) { 71 | bool first_block_in_rank = blockIdx.x == 0; 72 | if (first_block_in_rank) { 73 | if (threadIdx.x < world_size) { 74 | int other_rank = threadIdx.x; 75 | // warp 1: notify all other ranks that this rank has reached the sync 76 | // point 77 | sg.signals[other_rank]->start.data[rank] = 255; 78 | } else if (threadIdx.x == 32) { 79 | // warp 2: reset the end signal 80 | bstate->sg.end.flag = 0; 81 | } 82 | } 83 | 84 | // busy-wait until the current rank's signal 85 | // has been written to by all ranks 86 | if (threadIdx.x == 0) { 87 | uint64_t target_flag = get_target_flag(world_size); 88 | while (bstate->sg.start.flag != target_flag) 89 | ; 90 | } 91 | __syncthreads(); 92 | } 93 | 94 | __device__ void end_sync(const RankSignals &sg, volatile BarrierState *bstate, 95 | int rank, int world_size) { 96 | __shared__ int blocks_at_sync_point; 97 | if (threadIdx.x == 0) 98 | blocks_at_sync_point = atomicAdd((int *)&bstate->counter, 1); 99 | __syncthreads(); // (I think) this ensures `blocks_at_sync_point` is assigned 100 | 101 | bool last_block_at_sync_point = (blocks_at_sync_point == gridDim.x - 1); 102 | if (last_block_at_sync_point) { 103 | if (threadIdx.x < world_size) { 104 | int other_rank = threadIdx.x; 105 | // warp 1: notify all other ranks that this rank has reached the sync 106 | // point 107 | sg.signals[other_rank]->end.data[rank] = 255; 108 | } else if (threadIdx.x == 32) { 109 | // warp 2: reset the start signal + counter 110 | bstate->sg.start.flag = 0; 111 | bstate->counter = 0; 112 | } 113 | } 114 | 115 | // busy-wait until the current rank's signal 116 | // has been written to by all ranks 117 | 118 | // For simplicity, we do a cross-gpu sync on all blocks 119 | // But, in practice, if this is the last sync, we only need to do a cross-gpu sync 120 | // on 1 block per rank, and let the other blocks use the kernel exit to sync 121 | if (threadIdx.x == 0) { 122 | uint64_t target_flag = get_target_flag(world_size); 123 | while (bstate->sg.end.flag != target_flag) 124 | ; 125 | } 126 | __syncthreads(); 127 | } 128 | 129 | // like std::array, but aligned 130 | // TODO: Exactly why does this matter? 
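// e.g. for T = half the packed type below holds 16 / sizeof(half) = 8
// elements, so array_t<half, 8> is 16 bytes and 16-byte aligned and a
// load/store of it can be issued as a single 128-bit (ld.128/st.128)
// transaction instead of being split; the accumulator array_t<float, 8>
// keeps the reduction in fp32 (see upcast_s/downcast_s below).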
131 | template struct __align__(alignof(T) * sz) array_t { 132 | T data[sz]; 133 | using type = T; 134 | static constexpr int size = sz; 135 | }; 136 | 137 | /* 138 | Maximize memory efficiency w/ ld.128,st.128 instructions 139 | The GPU cannot load more than 128 bytes at a time 140 | CUDA threads can read a maximum of 16 bytes at once 141 | So, each thread loads as many elements of type T as can be accommodated within 142 | 16 bytes 143 | https://stackoverflow.com/questions/72147025/what-are-cuda-global-memory-32-64-and-128-byte-transactions 144 | */ 145 | template struct packed_t { 146 | // the (P)acked type for load/store 147 | using P = array_t; 148 | // the (A)ccumulator type for reduction 149 | using A = array_t; 150 | }; 151 | 152 | #define DINLINE __device__ __forceinline__ 153 | 154 | // scalar cast functions 155 | DINLINE float upcast_s(half val) { return __half2float(val); } 156 | 157 | // downcast_s should never be called when the input is a `float` 158 | template 159 | DINLINE T downcast_s(float val); 160 | template <> 161 | DINLINE half downcast_s(float val) { 162 | return __float2half(val); 163 | } 164 | 165 | // scalar add functions 166 | // for some reason when compiling with Pytorch, the + operator for half and 167 | // bfloat is disabled so we call the intrinsics directly 168 | DINLINE half &assign_add(half &a, half b) { 169 | a = __hadd(a, b); 170 | return a; 171 | } 172 | DINLINE float &assign_add(float &a, float b) { return a += b; } 173 | 174 | template 175 | DINLINE array_t &packed_assign_add(array_t &a, array_t b) { 176 | #pragma unroll 177 | for (int i = 0; i < N; i++) { 178 | assign_add(a.data[i], b.data[i]); 179 | } 180 | return a; 181 | } 182 | 183 | template 184 | DINLINE array_t upcast(array_t val) { 185 | if constexpr (std::is_same::value) { 186 | return val; 187 | } else { 188 | array_t out; 189 | #pragma unroll 190 | for (int i = 0; i < N; i++) { 191 | out.data[i] = upcast_s(val.data[i]); 192 | } 193 | return out; 194 | } 195 | } 196 | 197 | template DINLINE O downcast(array_t val) { 198 | if constexpr (std::is_same::value) { 199 | return val; 200 | } else { 201 | O out; 202 | #pragma unroll 203 | for (int i = 0; i < O::size; i++) { 204 | out.data[i] = downcast_s(val.data[i]); 205 | } 206 | return out; 207 | } 208 | } 209 | 210 | template 211 | DINLINE P packed_reduce(const P *ptrs[], int idx) { 212 | A tmp = upcast(ptrs[0][idx]); 213 | #pragma unroll 214 | for (int i = 1; i < ngpus; i++) { 215 | packed_assign_add(tmp, upcast(ptrs[i][idx])); 216 | } 217 | return downcast
<P>
(tmp); 218 | } 219 | 220 | template 221 | __global__ void allreduce_kernel( 222 | RankPtrs buffer_ptrs, // Pointers to the buffer to reduce, 1 for each GPU 223 | RankSignals sg, volatile BarrierState *bstate, T *__restrict__ result, 224 | int rank, int world_size, int nelements) { 225 | // Both P,A are array_t 226 | using P = typename packed_t::P; 227 | using A = typename packed_t::A; 228 | const P *ptrs[ngpus]; 229 | 230 | // TODO: Why are we loading the pointers in a circular fashion? 231 | // Since every allreduce sums across all ranks, we should be able to 232 | // load the pointers inorder 233 | #pragma unroll 234 | for (int i = 0; i < world_size; i++) { 235 | int target = (rank + i) % world_size; 236 | ptrs[i] = (P *)buffer_ptrs.ptrs[target]; 237 | } 238 | 239 | start_sync(sg, bstate, rank, world_size); 240 | 241 | // This is summing across all the ranks 242 | // at the given index 243 | // it's basically `result[idx] = rank1[idx] + rank2[idx] + rank3[idx]` 244 | // All complexity comes from 245 | // the packed type -- which means idx is actually a range of indices 246 | for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < nelements; 247 | idx += gridDim.x * blockDim.x) { 248 | // packed_reduce is just iterating over all the ranks 249 | ((P *)result)[idx] = packed_reduce(ptrs, idx); 250 | } 251 | 252 | end_sync(sg, bstate, rank, world_size); 253 | } 254 | 255 | class Sync { 256 | public: 257 | int rank_; 258 | int world_size_; 259 | 260 | RankPtrs *buffer_ptrs_; 261 | 262 | // Contains pointers to GPU memory 263 | RankSignals sg_; 264 | 265 | BarrierState *barrier_state_; 266 | 267 | std::vector ipc_handles_; 268 | 269 | Sync(BarrierState *barrier_state, const cudaIpcMemHandle_t *handles, 270 | const std::vector &offsets, int rank) 271 | : rank_(rank), world_size_(offsets.size()), 272 | barrier_state_(barrier_state), buffer_ptrs_(nullptr) { 273 | for (int i = 0; i < world_size_; i++) { 274 | BarrierState *rank_barrier_state; 275 | if (i != rank_) { 276 | char *handle; 277 | CUDACHECK(cudaIpcOpenMemHandle((void **)&handle, handles[i], 278 | cudaIpcMemLazyEnablePeerAccess)); 279 | ipc_handles_.push_back(handle); 280 | handle += offsets[i]; 281 | rank_barrier_state = (BarrierState *)handle; 282 | printf("Rank %d: opened handle %d in constructor\n", rank_, i); 283 | } else { 284 | rank_barrier_state = barrier_state_; 285 | } 286 | // This is pure pointer math (no access to on-device memory) 287 | sg_.signals[i] = &rank_barrier_state->sg; 288 | } 289 | } 290 | 291 | void register_buffer(const std::vector &handles, 292 | const std::vector &offsets, void *rank_ptr) { 293 | if (buffer_ptrs_ != nullptr) { 294 | throw std::runtime_error("register_buffer() called twice"); 295 | } 296 | 297 | buffer_ptrs_ = (RankPtrs*)malloc(sizeof(RankPtrs)); 298 | CHECK_NOT_NULL(buffer_ptrs_); 299 | 300 | for (int i = 0; i < world_size_; i++) { 301 | if (i != rank_) { 302 | char *ptr; 303 | CUDACHECK(cudaIpcOpenMemHandle( 304 | (void**)&ptr, *((const cudaIpcMemHandle_t *)handles[i].data()), 305 | cudaIpcMemLazyEnablePeerAccess)); 306 | ipc_handles_.push_back(ptr); 307 | (buffer_ptrs_)->ptrs[i] = ptr + offsets[i]; 308 | } else { 309 | (buffer_ptrs_)->ptrs[i] = rank_ptr; 310 | } 311 | } 312 | } 313 | 314 | template 315 | void allreduce(int num_elements, T *output) { 316 | if (buffer_ptrs_ == nullptr) { 317 | throw std::runtime_error("register_buffer() must be called first"); 318 | } 319 | 320 | int threads = 64; 321 | auto num_packed_elements = packed_t::P::size; 322 | if (num_elements % 
num_packed_elements != 0) { 323 | throw std::runtime_error( 324 | "Number of elements must be a multiple of " + 325 | std::to_string(num_packed_elements)); 326 | } 327 | 328 | #define CEIL_DIV(x, y) (((x) + (y) - 1) / (y)) 329 | int blocks = CEIL_DIV(num_packed_elements, threads); 330 | 331 | switch (world_size_) { 332 | case 2: 333 | allreduce_kernel<<>>( 334 | *buffer_ptrs_, sg_, barrier_state_, output, rank_, world_size_, num_elements); 335 | break; 336 | case 4: 337 | allreduce_kernel<<>>( 338 | *buffer_ptrs_, sg_, barrier_state_, output, rank_, world_size_, num_elements); 339 | break; 340 | default: 341 | throw std::runtime_error("Unsupported world size"); 342 | } 343 | cudaDeviceSynchronize(); 344 | } 345 | 346 | ~Sync() { 347 | printf("Rank %d calling destructor\n", rank_); 348 | for (auto ptr : ipc_handles_) { 349 | CUDACHECK(cudaIpcCloseMemHandle(ptr)); 350 | } 351 | // free buffer_ptrs_ 352 | free(buffer_ptrs_); 353 | 354 | /* 355 | if (buffer_ptrs_ != nullptr) { 356 | CUDACHECK(cudaFree(buffer_ptrs_)); 357 | } 358 | */ 359 | } 360 | }; 361 | } // namespace mysync 362 | -------------------------------------------------------------------------------- /paged_attention_triton/attention_kernel.py: -------------------------------------------------------------------------------- 1 | from einops import rearrange 2 | import triton 3 | import triton.language as tl 4 | 5 | # Expect block table to map 6 | # logical bid (block id) -> (physical bid, # filled) 7 | # In tests, it maps: logical bid -> physical bid 8 | 9 | 10 | @triton.jit 11 | def paged_attention( 12 | debug_block_idxs_ptr, 13 | debug_key_cache_load_ptr, 14 | debug_key_cache_load_ptr2, 15 | debug_block_idx_ptr2, 16 | debug_key_cache_load_ptr3, 17 | debug_key_cache_load_ptr4, 18 | debug_key_cache_load_ptr5, 19 | debug_scores_ptr, 20 | debug_softmax_ptr, 21 | debug_output_ptr, 22 | 23 | # need these b/c we can't use view/reshape 24 | scratchpad_key_ptr, # [num_seqs, max_context_len, num_heads, head_size] 25 | scratchpad_value_ptr, # [num_seqs, max_context_len, num_heads, head_size] 26 | output_ptr, # [num_seqs, num_query_heads, head_size] 27 | query_ptr, # [num_seqs, num_query_heads, head_size] 28 | key_cache_ptr, # [num_blocks, num_kv_heads, head_size, block_size] 29 | value_cache_ptr, # [num_blocks, num_kv_heads, head_size, block_size] 30 | block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] 31 | context_lens_ptr, # [num_seqs] 32 | scale, # float32 33 | num_seqs, # int 34 | num_heads, # int 35 | cache_block_stride, # int 36 | MAX_CONTEXT_LEN: tl.constexpr, # int 37 | BLOCK_SIZE: tl.constexpr, # int 38 | HEAD_SIZE: tl.constexpr, # int, must be power of 2 39 | MAX_NUM_BLOCKS_PER_SEQ: tl.constexpr, # int, must be power of 2 40 | ): 41 | seq_idx = tl.program_id(0) 42 | head_idx = tl.program_id(1) 43 | 44 | query_offset = seq_idx * num_seqs + head_idx * HEAD_SIZE 45 | query_head = tl.load(query_ptr + query_offset + tl.arange(0, HEAD_SIZE)) 46 | block_table_offset = seq_idx * MAX_NUM_BLOCKS_PER_SEQ 47 | 48 | # physical_block_idxs = tl.load( 49 | # block_tables_ptr + block_table_offset + tl.arange(0, MAX_NUM_BLOCKS_PER_SEQ) 50 | # ) 51 | 52 | # if seq_idx == 0 and head_idx == 0: 53 | # tl.store( 54 | # debug_block_idxs_ptr + tl.arange(0, MAX_NUM_BLOCKS_PER_SEQ), 55 | # physical_block_idxs, 56 | # ) 57 | 58 | # start_of_block_offsets = (physical_block_idxs * cache_block_stride) + ( 59 | # head_idx * HEAD_SIZE * BLOCK_SIZE 60 | # ) 61 | 62 | # # Test #2 with: 63 | # if seq_idx == 0 and head_idx == 1: 64 | # sample_values = 
tl.load(key_cache_ptr + start_of_block_offsets) 65 | # tl.store( 66 | # debug_key_cache_load_ptr + tl.arange(0, MAX_NUM_BLOCKS_PER_SEQ), 67 | # sample_values, 68 | # ) 69 | 70 | # Works, but can't transform [max_num_blocks_per_seq, head_size, block_size] => [seq_len, head_size] 71 | # https://github.com/openai/triton/discussions/2666 72 | # https://github.com/openai/triton/issues/2522 73 | 74 | # key_block_offsets = ( 75 | # start_of_block_offsets[:, None, None] 76 | # + (BLOCK_SIZE * tl.arange(0, HEAD_SIZE)[None, :, None]) 77 | # + (1 * tl.arange(0, BLOCK_SIZE)[None, None, :]) 78 | # ) 79 | 80 | # # shape = [max_num_blocks_per_seq, head_size, block_size] 81 | # key_block = tl.load(key_cache_ptr + key_block_offsets) 82 | # if seq_idx == 0 and head_idx == 1: 83 | # store_offsets = ( 84 | # (BLOCK_SIZE * HEAD_SIZE * tl.arange(0, MAX_NUM_BLOCKS_PER_SEQ)[:, None, None]) 85 | # + (BLOCK_SIZE * tl.arange(0, HEAD_SIZE)[None, :, None]) 86 | # + (1 * tl.arange(0, BLOCK_SIZE)[None, None, :]) 87 | # ) 88 | # tl.store( 89 | # debug_key_cache_load_ptr2 + store_offsets, 90 | # key_block 91 | # ) 92 | 93 | context_len = tl.load(context_lens_ptr + seq_idx) 94 | 95 | # Can't allocate memory that's not known at compile time 96 | # (We could make it known @ compile time by making context_len a tl.constexpr) 97 | # seq_keys = tl.zeros((context_len, HEAD_SIZE), dtype=tl.float32) 98 | # seq_values = tl.zeros((context_len, HEAD_SIZE), dtype=tl.float32) 99 | 100 | for tok_idx in range(0, context_len): 101 | logical_block_idx = tok_idx // BLOCK_SIZE 102 | physical_block_idx = tl.load( 103 | block_tables_ptr + block_table_offset + logical_block_idx 104 | ) 105 | 106 | if (tok_idx == 0 and seq_idx == 0) and (head_idx == 1): 107 | tl.store(debug_block_idx_ptr2, physical_block_idx) 108 | 109 | start_of_block_offset = ( 110 | physical_block_idx * cache_block_stride + head_idx * HEAD_SIZE * BLOCK_SIZE 111 | ) 112 | tok_idx_within_block = tok_idx % BLOCK_SIZE 113 | tok_offsets = ( 114 | start_of_block_offset 115 | + BLOCK_SIZE * tl.arange(0, HEAD_SIZE) 116 | + tok_idx_within_block 117 | ) 118 | 119 | tok_key = tl.load(key_cache_ptr + tok_offsets) 120 | tok_value = tl.load(value_cache_ptr + tok_offsets) 121 | 122 | if (tok_idx == 0 and seq_idx == 0) and (head_idx == 0): 123 | tl.store(debug_key_cache_load_ptr3 + tl.arange(0, HEAD_SIZE), tok_key) 124 | 125 | if (tok_idx == 1 and seq_idx == 0) and (head_idx == 0): 126 | tl.store(debug_key_cache_load_ptr4 + tl.arange(0, HEAD_SIZE), tok_key) 127 | 128 | if (tok_idx == 7 and seq_idx == num_seqs - 1) and (head_idx == 0): 129 | tl.store(debug_key_cache_load_ptr5 + tl.arange(0, HEAD_SIZE), tok_key) 130 | 131 | scratchpad_offset = ( 132 | seq_idx * (MAX_CONTEXT_LEN * num_heads * HEAD_SIZE) 133 | + tok_idx * (num_heads * HEAD_SIZE) 134 | + head_idx * HEAD_SIZE 135 | ) 136 | tl.store( 137 | scratchpad_key_ptr + scratchpad_offset + tl.arange(0, HEAD_SIZE), tok_key 138 | ) 139 | tl.store( 140 | scratchpad_value_ptr + scratchpad_offset + tl.arange(0, HEAD_SIZE), 141 | tok_value, 142 | ) 143 | 144 | # TODO: Not sure if this is necessary 145 | tl.debug_barrier() 146 | 147 | # scratchpad_key_ptr, # [num_seqs, max_context_len, num_heads, head_size] 148 | start_seq_offset = (MAX_CONTEXT_LEN * num_heads * HEAD_SIZE) * seq_idx 149 | start_tok_offset = start_seq_offset + tl.arange(0, MAX_CONTEXT_LEN) * (num_heads * HEAD_SIZE) + head_idx * HEAD_SIZE 150 | 151 | # [seq_len, head_size] 152 | # zero out keys that aren't part of the sequence 153 | mask = tl.arange(0, MAX_CONTEXT_LEN)[:, None] < 
context_len 154 | kv_offs = start_tok_offset[:, None] + tl.arange(0, HEAD_SIZE)[None, :] 155 | keys = tl.load(scratchpad_key_ptr + kv_offs, mask=mask, other=0.0) 156 | values = tl.load(scratchpad_value_ptr + kv_offs, mask=mask, other=0.0) 157 | 158 | # keys shape [seq_len x head_size], query shape = [head_size] 159 | 160 | # Can't do below b/c minimum size on all dimensions is 16 161 | # scores = tl.dot(query_head[None, :], keys.T) 162 | 163 | # Workaround for matrix, vector dot product 164 | # shape = [seq_len] 165 | # tmp_scores = tl.zeros([MAX_CONTEXT_LEN], dtype=tl.float32) 166 | scores = tl.sum(keys * query_head[None, :], axis=1) 167 | 168 | # This mask is necessary b/c even though we mask out the keys on load 169 | # that just results in 0s in the attention dot product, 170 | # which then get softmaxed and result in non-zero values 171 | # in the softmax output (which is wrong) 172 | # -inf guarantees that the softmax output will be 0 for masked values 173 | mask = tl.full([MAX_CONTEXT_LEN], -float('inf'), dtype=tl.float32) 174 | cond = tl.arange(0, MAX_CONTEXT_LEN) < context_len 175 | scores_masked = tl.where(cond, scores, mask) 176 | 177 | if seq_idx == 0 and head_idx == 0: 178 | # tl.store(debug_scores_ptr + tl.arange(0, MAX_CONTEXT_LEN), scores) 179 | tl.store(debug_scores_ptr + tl.arange(0, MAX_CONTEXT_LEN), scores_masked) 180 | 181 | # do a numerically stable softmax on the scores 182 | scores_minus_max = scores_masked - tl.max(scores_masked, axis=0) 183 | numerator = tl.exp(scores_minus_max) 184 | denominator = tl.sum(numerator, axis=0) 185 | logits = numerator / denominator 186 | 187 | if seq_idx == 0 and head_idx == 0: 188 | tl.store(debug_softmax_ptr + tl.arange(0, MAX_CONTEXT_LEN), logits) 189 | 190 | weighted_values = tl.sum(values * logits[:, None], axis=0) 191 | 192 | if seq_idx == 0 and head_idx == 0: 193 | tl.store(debug_output_ptr + tl.arange(0, HEAD_SIZE), weighted_values) 194 | 195 | output_offset = seq_idx * (num_heads * HEAD_SIZE) + head_idx * HEAD_SIZE 196 | tl.store(output_ptr + output_offset + tl.arange(0, HEAD_SIZE), weighted_values) 197 | 198 | 199 | def test_triton_paged_attention(): 200 | import random 201 | import torch 202 | 203 | # num_blocks_in_cache = 32 204 | 205 | # block_size = 32 206 | # seed = 0 207 | # head_size = 64 208 | # num_heads = (8, 8) 209 | # num_seqs = 8 210 | # max_seq_len = 512 211 | 212 | num_blocks_in_cache = 8 213 | 214 | block_size = 2 215 | seed = 0 216 | head_size = 4 217 | num_heads = (2, 2) 218 | num_seqs = 2 219 | max_seq_len = 8 220 | 221 | random.seed(seed) 222 | torch.random.manual_seed(seed) 223 | torch.cuda.manual_seed(seed) 224 | 225 | scale = float(1.0 / (head_size**0.5)) 226 | num_query_heads, num_kv_heads = num_heads 227 | query = torch.empty( 228 | num_seqs, num_query_heads, head_size, dtype=torch.float32, device="cuda" 229 | ) 230 | query.uniform_(-scale, scale) 231 | output = torch.empty_like(query, device="cuda") 232 | 233 | cache_shape = (num_blocks_in_cache, num_query_heads, head_size, block_size) 234 | 235 | key_cache = torch.empty(cache_shape, dtype=torch.float32, device="cuda") 236 | key_cache.uniform_(-scale, scale) 237 | assert key_cache.stride(0) == num_query_heads * head_size * block_size 238 | 239 | value_cache = torch.empty(cache_shape, dtype=torch.float32, device="cuda") 240 | value_cache.uniform_(-scale, scale) 241 | 242 | context_lens = torch.tensor( 243 | [random.randint(1, max_seq_len) for _ in range(num_seqs)], device="cuda" 244 | ) 245 | context_lens[-1] = max_seq_len 246 | 
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size 247 | block_tables = [ 248 | [ 249 | random.randint(0, num_blocks_in_cache - 1) 250 | for _ in range(max_num_blocks_per_seq) 251 | ] 252 | for _ in range(num_seqs) 253 | ] 254 | block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") 255 | 256 | # create tensor of all 0s of size 16 257 | debug_block_idxs = torch.zeros( 258 | max_num_blocks_per_seq, dtype=torch.int, device="cuda" 259 | ) 260 | debug_key_cache_load = torch.zeros( 261 | max_num_blocks_per_seq, dtype=key_cache.dtype, device="cuda" 262 | ) 263 | debug_key_cache_load2 = torch.zeros( 264 | max_num_blocks_per_seq, 265 | head_size, 266 | block_size, 267 | dtype=torch.float32, 268 | device="cuda", 269 | ) 270 | debug_block_idx_ptr2 = torch.zeros(1, dtype=torch.int, device="cuda") 271 | debug_key_cache_load3 = torch.zeros(head_size, dtype=torch.float32, device="cuda") 272 | debug_key_cache_load4 = torch.zeros(head_size, dtype=torch.float32, device="cuda") 273 | debug_key_cache_load5 = torch.zeros(head_size, dtype=torch.float32, device="cuda") 274 | debug_scores = torch.zeros(max_seq_len, dtype=torch.float32, device="cuda") 275 | debug_softmax = torch.zeros(max_seq_len, dtype=torch.float32, device="cuda") 276 | debug_output_ptr = torch.zeros(head_size, dtype=torch.float32, device="cuda") 277 | 278 | scratchpad_key = torch.zeros( 279 | (num_seqs, max_seq_len, num_query_heads, head_size), 280 | dtype=torch.float32, 281 | device="cuda", 282 | ) 283 | scratchpad_value = torch.zeros_like(scratchpad_key) 284 | 285 | paged_attention[(num_seqs, num_query_heads)]( 286 | debug_block_idxs_ptr=debug_block_idxs, 287 | debug_key_cache_load_ptr=debug_key_cache_load, 288 | debug_key_cache_load_ptr2=debug_key_cache_load2, 289 | debug_block_idx_ptr2=debug_block_idx_ptr2, 290 | debug_key_cache_load_ptr3=debug_key_cache_load3, 291 | debug_key_cache_load_ptr4=debug_key_cache_load4, 292 | debug_key_cache_load_ptr5=debug_key_cache_load5, 293 | debug_scores_ptr=debug_scores, 294 | debug_softmax_ptr=debug_softmax, 295 | debug_output_ptr=debug_output_ptr, 296 | 297 | scratchpad_key_ptr=scratchpad_key, 298 | scratchpad_value_ptr=scratchpad_value, 299 | output_ptr=output, 300 | query_ptr=query, 301 | key_cache_ptr=key_cache, 302 | value_cache_ptr=value_cache, 303 | block_tables_ptr=block_tables, 304 | context_lens_ptr=context_lens, 305 | scale=scale, 306 | num_seqs=num_seqs, 307 | num_heads=num_query_heads, 308 | cache_block_stride=key_cache.stride(0), 309 | MAX_CONTEXT_LEN=max_seq_len, 310 | BLOCK_SIZE=block_size, 311 | HEAD_SIZE=head_size, 312 | MAX_NUM_BLOCKS_PER_SEQ=max_num_blocks_per_seq, 313 | ) 314 | 315 | torch.cuda.synchronize() 316 | # torch.testing.assert_close(debug_block_idxs, block_tables[0]) 317 | # torch.testing.assert_close( 318 | # debug_key_cache_load, key_cache[block_tables[0], 1, 0, 0] 319 | # ) 320 | 321 | # seq_0_head_1_keys = key_cache[block_tables[0], 1] 322 | # torch.testing.assert_close(seq_0_head_1_keys, debug_key_cache_load2) 323 | 324 | assert debug_block_idx_ptr2[0] == block_tables[0, 0] 325 | 326 | seq0_tok0_head0_key = key_cache[block_tables[0, 0], 0, :, 0] 327 | torch.testing.assert_close(debug_key_cache_load3, seq0_tok0_head0_key) 328 | 329 | seq0_tok1_head0_key = key_cache[block_tables[0, 0], 0, :, 1] 330 | torch.testing.assert_close(debug_key_cache_load4, seq0_tok1_head0_key) 331 | 332 | last_seq_tok7_head0_key = key_cache[ 333 | block_tables[num_seqs - 1, 7 // block_size], 0, :, 7 % block_size 334 | ] 335 | 
torch.testing.assert_close(debug_key_cache_load5, last_seq_tok7_head0_key) 336 | 337 | seq0_len = context_lens[0] 338 | seq0_head0_keys = key_cache[block_tables[0], 0] 339 | divide_round_up = lambda x, y: (x + y - 1) // y 340 | seq0_num_blocks = divide_round_up(seq0_len, block_size) 341 | assert seq0_head0_keys.shape == (seq0_num_blocks, head_size, block_size) 342 | seq0_head0_keys = rearrange( 343 | seq0_head0_keys, 344 | "num_blocks head_size block_size -> (num_blocks block_size) head_size", 345 | ) 346 | assert seq0_head0_keys.shape == (seq0_num_blocks * block_size, head_size) 347 | seq0_head0_keys_clipped = seq0_head0_keys[:seq0_len] 348 | assert seq0_head0_keys_clipped.shape == (seq0_len, head_size) 349 | torch.testing.assert_close(seq0_head0_keys_clipped, scratchpad_key[0, :seq0_len, 0, :]) 350 | 351 | # do dot product of query & keys 352 | scores = seq0_head0_keys @ query[0, 0] 353 | assert scores.shape == debug_scores.shape 354 | # emulate triton's masking 355 | scores[-1] = -float('inf') 356 | torch.testing.assert_close(scores[:-1], debug_scores[:-1]) 357 | 358 | expected_softmax = torch.softmax(scores, dim=0) 359 | torch.testing.assert_close(debug_softmax, expected_softmax) 360 | 361 | seq0_head0_values = value_cache[block_tables[0], 0] 362 | seq0_head0_values = rearrange( 363 | seq0_head0_values, 364 | "num_blocks head_size block_size -> (num_blocks block_size) head_size", 365 | ) 366 | assert seq0_head0_values.shape == (seq0_num_blocks * block_size, head_size) 367 | seq0_head0_values_clipped = seq0_head0_values[:seq0_len] 368 | assert seq0_head0_values_clipped.shape == (seq0_len, head_size) 369 | torch.testing.assert_close( 370 | seq0_head0_values_clipped, scratchpad_value[0, :seq0_len, 0, :] 371 | ) 372 | 373 | expected_output = seq0_head0_values.T @ expected_softmax 374 | torch.testing.assert_close(expected_output, debug_output_ptr) 375 | 376 | # load output from correct place in output_ptr (ensure location is right) 377 | loaded_output = output[0, 0] 378 | torch.testing.assert_close(loaded_output, expected_output) 379 | print("KERNEL RAN SUCCESSFULLY ...") 380 | 381 | 382 | if __name__ == "__main__": 383 | test_triton_paged_attention() 384 | -------------------------------------------------------------------------------- /allreduce/csrc/reference_allreduce/fast_allreduce.cuh: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #define CUDACHECK(cmd) \ 12 | do { \ 13 | cudaError_t e = cmd; \ 14 | if (e != cudaSuccess) { \ 15 | printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \ 16 | cudaGetErrorString(e)); \ 17 | exit(EXIT_FAILURE); \ 18 | } \ 19 | } while (0) 20 | 21 | namespace vllm { 22 | 23 | struct Signal { 24 | alignas(64) union { 25 | uint64_t flag; 26 | unsigned char data[8]; 27 | } start; 28 | alignas(64) union { 29 | uint64_t flag; 30 | unsigned char data[8]; 31 | } end; 32 | }; 33 | 34 | struct Metadata { 35 | alignas(128) Signal sg; 36 | alignas(128) int counter; 37 | }; 38 | static_assert(offsetof(Metadata, counter) == 128); 39 | static_assert(sizeof(Metadata) == 256); 40 | 41 | struct __align__(16) RankData { const void *__restrict__ ptrs[8]; }; 42 | 43 | struct RankSignals { 44 | volatile Signal *signals[8]; 45 | }; 46 | 47 | // like std::array, but aligned 48 | template 49 | struct __align__(alignof(T) * sz) array_t { 50 | T data[sz]; 51 | using type = T; 52 | static constexpr int size = sz; 53 | }; 54 | 
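// Illustration added alongside the header (not part of the upstream file): with the
// two-parameter array_t<T, sz> above, the half-precision packed type array_t<half, 8>
// is 16 bytes and 16-byte aligned, which is what lets a single thread move one packed
// element with a 128-bit load/store, e.g.
//   static_assert(sizeof(array_t<half, 8>) == 16, "one ld.128/st.128 per thread");
//   static_assert(alignof(array_t<half, 8>) == 16, "alignment needed for 128-bit access");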
55 | // use packed type to maximize memory efficiency 56 | // goal: generate ld.128 and st.128 instructions 57 | template <typename T> 58 | struct packed_t { 59 | // the (P)acked type for load/store 60 | using P = array_t<T, 16 / sizeof(T)>; 61 | // the (A)ccumulator type for reduction 62 | using A = array_t<float, 16 / sizeof(T)>; 63 | }; 64 | 65 | #define DINLINE __device__ __forceinline__ 66 | 67 | // scalar cast functions 68 | DINLINE float upcast_s(half val) { return __half2float(val); } 69 | 70 | template <typename T> 71 | DINLINE T downcast_s(float val); 72 | template <> 73 | DINLINE half downcast_s(float val) { 74 | return __float2half(val); 75 | } 76 |
77 | // scalar add functions 78 | // for some reason when compiling with Pytorch, the + operator for half and 79 | // bfloat is disabled so we call the intrinsics directly 80 | DINLINE half &assign_add(half &a, half b) { 81 | a = __hadd(a, b); 82 | return a; 83 | } 84 | DINLINE float &assign_add(float &a, float b) { return a += b; } 85 | 86 | #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) 87 | DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); } 88 | template <> 89 | DINLINE nv_bfloat16 downcast_s(float val) { 90 | return __float2bfloat16(val); 91 | } 92 | DINLINE nv_bfloat16 &assign_add(nv_bfloat16 &a, nv_bfloat16 b) { 93 | a = __hadd(a, b); 94 | return a; 95 | } 96 | #endif 97 |
98 | template <typename T, int N> 99 | DINLINE array_t<T, N> &packed_assign_add(array_t<T, N> &a, array_t<T, N> b) { 100 | #pragma unroll 101 | for (int i = 0; i < N; i++) { 102 | assign_add(a.data[i], b.data[i]); 103 | } 104 | return a; 105 | } 106 | 107 | template <typename T, int N> 108 | DINLINE array_t<float, N> upcast(array_t<T, N> val) { 109 | if constexpr (std::is_same<T, float>::value) { 110 | return val; 111 | } else { 112 | array_t<float, N> out; 113 | #pragma unroll 114 | for (int i = 0; i < N; i++) { 115 | out.data[i] = upcast_s(val.data[i]); 116 | } 117 | return out; 118 | } 119 | } 120 |
121 | template <typename O> 122 | DINLINE O downcast(array_t<float, O::size> val) { 123 | if constexpr (std::is_same<typename O::type, float>::value) { 124 | return val; 125 | } else { 126 | O out; 127 | #pragma unroll 128 | for (int i = 0; i < O::size; i++) { 129 | out.data[i] = downcast_s<typename O::type>(val.data[i]); 130 | } 131 | return out; 132 | } 133 | } 134 |
135 | // compute flag at compile time 136 | __host__ __device__ constexpr uint64_t compute_flag(int ngpus) { 137 | auto m = std::numeric_limits<uint64_t>::max(); 138 | return m >> ((8 - ngpus) * 8); 139 | } 140 | 141 | template <int ngpus> 142 | __device__ __forceinline__ void start_sync(const RankSignals &sg, 143 | volatile Metadata *meta, int rank) { 144 | constexpr auto FLAG = compute_flag(ngpus); 145 | if (blockIdx.x == 0) { 146 | if (threadIdx.x < ngpus) 147 | // simultaneously write to the corresponding byte to all other ranks.
148 | // Latency = 1 p2p write 149 | sg.signals[threadIdx.x]->start.data[rank] = 255; 150 | else if (threadIdx.x == 32) 151 | // reset 152 | meta->sg.end.flag = 0; 153 | } 154 | 155 | if (threadIdx.x == 0) { 156 | while (meta->sg.start.flag != FLAG) 157 | ; 158 | } 159 | 160 | __syncthreads(); 161 | } 162 | 163 | template 164 | __device__ __forceinline__ void end_sync(const RankSignals &sg, 165 | volatile Metadata *meta, int rank) { 166 | constexpr auto FLAG = compute_flag(ngpus); 167 | __syncthreads(); 168 | __shared__ int num; 169 | if (threadIdx.x == 0) num = atomicAdd((int *)&meta->counter, 1); 170 | __syncthreads(); // I'm guessing this is necessary for all threads to see the updated value of `num` 171 | 172 | // Only the last completing block can perform the end synchronization 173 | // This can ensures when the final busy wait ends, all ranks must have 174 | // finished reading each other's buffer. 175 | if (num == gridDim.x - 1) { 176 | if (threadIdx.x == 32) { 177 | // reset in a different warp 178 | meta->counter = 0; 179 | meta->sg.start.flag = 0; 180 | } else if (threadIdx.x < ngpus) { 181 | // simultaneously write to the corresponding byte to all other ranks. 182 | // Latency = 1 p2p write 183 | sg.signals[threadIdx.x]->end.data[rank] = 255; 184 | } 185 | // if this is the final sync, only one block needs it 186 | // because kernel exit can serve as sync 187 | if constexpr (final_sync) { 188 | if (threadIdx.x == 0) { 189 | while (meta->sg.end.flag != FLAG) 190 | ; 191 | } 192 | } 193 | } 194 | if constexpr (!final_sync) { 195 | if (threadIdx.x == 0) { 196 | while (meta->sg.end.flag != FLAG) 197 | ; 198 | } 199 | __syncthreads(); 200 | } 201 | } 202 | 203 | template 204 | DINLINE P packed_reduce(const P *ptrs[], int idx) { 205 | A tmp = upcast(ptrs[0][idx]); 206 | #pragma unroll 207 | for (int i = 1; i < ngpus; i++) { 208 | packed_assign_add(tmp, upcast(ptrs[i][idx])); 209 | } 210 | return downcast
<P>
(tmp); 211 | } 212 | 213 | template 214 | __global__ void __launch_bounds__(512, 1) 215 | cross_device_reduce_1stage(RankData *_dp, RankSignals sg, 216 | volatile Metadata *meta, T *__restrict__ result, 217 | int rank, int size) { 218 | // Both P,A are array_t 219 | using P = typename packed_t::P; 220 | using A = typename packed_t::A; 221 | const P *ptrs[ngpus]; 222 | #pragma unroll 223 | for (int i = 0; i < ngpus; i++) { 224 | int target = (rank + i) % ngpus; 225 | ptrs[i] = (P *)_dp->ptrs[target]; 226 | } 227 | start_sync(sg, meta, rank); 228 | // do the actual reduction 229 | for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; 230 | idx += gridDim.x * blockDim.x) { 231 | ((P *)result)[idx] = packed_reduce(ptrs, idx); 232 | } 233 | end_sync(sg, meta, rank); 234 | } 235 | 236 | template 237 | DINLINE P *get_tmp_buf(volatile Signal *sg) { 238 | return (P *)(((Metadata *)sg) + 1); 239 | } 240 | 241 | template 242 | __global__ void __launch_bounds__(512, 1) 243 | cross_device_reduce_2stage(RankData *_dp, RankSignals sg, 244 | volatile Metadata *meta, T *__restrict__ result, 245 | int rank, int size) { 246 | int tid = blockIdx.x * blockDim.x + threadIdx.x; 247 | int stride = gridDim.x * blockDim.x; 248 | using P = typename packed_t::P; 249 | using A = typename packed_t::A; 250 | int part = size / ngpus; 251 | int start = rank * part; 252 | int end = rank == ngpus - 1 ? size : start + part; 253 | const P *ptrs[ngpus]; 254 | P *tmps[ngpus]; 255 | #pragma unroll 256 | for (int i = 0; i < ngpus; i++) { 257 | int target = (rank + i) % ngpus; 258 | ptrs[i] = (const P *)_dp->ptrs[target]; 259 | tmps[i] = get_tmp_buf
<P>
(sg.signals[target]); 260 | } 261 | auto tmp_out = tmps[0]; 262 | start_sync(sg, meta, rank); 263 | // stage 1: reduce scatter 264 | for (int idx = start + tid; idx < end; idx += stride) { 265 | tmp_out[idx - start] = packed_reduce(ptrs, idx); 266 | } 267 | // Maybe TODO: replace this with per-block release-acquire 268 | // can save about 1-2us (not a lot though) 269 | end_sync(sg, meta, rank); 270 | 271 | // stage 2: allgather 272 | for (int idx = tid; idx < part; idx += stride) { 273 | #pragma unroll 274 | for (int i = 0; i < ngpus; i++) { 275 | int dst_idx = ((rank + i) % ngpus) * part + idx; 276 | ((P *)result)[dst_idx] = tmps[i][idx]; 277 | } 278 | } 279 | // process the last larger partition 280 | int remaining = size - part * ngpus; 281 | if (tid < remaining) { 282 | int dst_idx = tid + part * ngpus; 283 | ((P *)result)[dst_idx] = get_tmp_buf
<P>
(sg.signals[ngpus - 1])[part + tid]; 284 | } 285 | 286 | // faster than this 287 | // for (int idx = tid; idx < size; idx += stride) { 288 | // int target_rank = idx / part; 289 | // if (target_rank == ngpus) target_rank -= 1; 290 | // ((P *)result)[idx] = tmps[target_rank][idx - target_rank * part]; 291 | // } 292 | } 293 | 294 | template 295 | __global__ void __launch_bounds__(512, 1) 296 | cross_device_reduce_half_butterfly(RankData *_dp, RankSignals sg, 297 | volatile Metadata *meta, 298 | T *__restrict__ result, int rank, 299 | int size) { 300 | int tid = blockIdx.x * blockDim.x + threadIdx.x; 301 | int stride = gridDim.x * blockDim.x; 302 | using P = typename packed_t::P; 303 | using A = typename packed_t::A; 304 | auto tmp_out = get_tmp_buf
<P>
(sg.signals[rank]); 305 | constexpr int hg = ngpus / 2; 306 | // Actually not quite half butterfly. 307 | // This is an all-to-all within each group containing half of the ranks 308 | // followed by cross-group add. Equivalent to half butterfly when there 309 | // are 4 GPUs, a common case for PCIe cards like T4 and A10. 310 | const P *ptrs[hg]; 311 | { 312 | int start = rank - rank % hg; 313 | #pragma unroll 314 | for (int i = 0; i < hg; i++) { 315 | ptrs[i] = (const P *)_dp->ptrs[i + start]; 316 | } 317 | } 318 | start_sync(sg, meta, rank); 319 | for (int idx = tid; idx < size; idx += stride) { 320 | tmp_out[idx] = packed_reduce(ptrs, idx); 321 | } 322 | end_sync(sg, meta, rank); 323 | 324 | auto src = get_tmp_buf
<P>
(sg.signals[(ngpus - 1) - rank % ngpus]); 325 | // do the actual reduction 326 | for (int idx = tid; idx < size; idx += stride) { 327 | auto tmp = tmp_out[idx]; 328 | packed_assign_add(tmp, src[idx]); 329 | ((P *)result)[idx] = tmp; 330 | } 331 | } 332 | class FastAllreduce { 333 | public: 334 | int rank_; 335 | int world_size_; 336 | bool full_nvlink_; 337 | 338 | // below are device pointers 339 | RankSignals sg_; 340 | std::unordered_map buffers_; 341 | Metadata *meta_; 342 | 343 | // stores the registered device pointers from all ranks 344 | RankData *d_rank_data_base_, *d_rank_data_end_; 345 | std::vector graph_unreg_buffers_; 346 | std::vector ipc_handles_; 347 | 348 | /** 349 | * meta is a pointer to device metadata and temporary buffer for allreduce. 350 | * 351 | * There's a total of sizeof(Metadata) of prefix before the actual data, 352 | * so meta + 1 points to actual temporary buffer. 353 | * 354 | * note: this class does not own any device memory. Any required buffers 355 | * are passed in from the constructor 356 | */ 357 | FastAllreduce(Metadata *meta, void *rank_data, size_t rank_data_sz, 358 | const cudaIpcMemHandle_t *handles, 359 | const std::vector &offsets, int rank, 360 | bool full_nvlink = true) 361 | : rank_(rank), 362 | world_size_(offsets.size()), 363 | full_nvlink_(full_nvlink), 364 | meta_(meta), 365 | d_rank_data_base_(reinterpret_cast(rank_data)), 366 | d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) { 367 | for (int i = 0; i < world_size_; i++) { 368 | Metadata *rank_meta; 369 | if (i != rank_) { 370 | char *handle; 371 | CUDACHECK(cudaIpcOpenMemHandle((void **)&handle, handles[i], 372 | cudaIpcMemLazyEnablePeerAccess)); 373 | ipc_handles_.push_back(handle); 374 | handle += offsets[i]; 375 | rank_meta = (Metadata *)handle; 376 | } else { 377 | rank_meta = meta_; 378 | } 379 | sg_.signals[i] = &rank_meta->sg; 380 | } 381 | } 382 | 383 | std::pair, std::vector> 384 | get_graph_buffer_ipc_meta() { 385 | auto num_buffers = graph_unreg_buffers_.size(); 386 | auto handle_sz = sizeof(cudaIpcMemHandle_t); 387 | std::vector handles(handle_sz * num_buffers, 0); 388 | std::vector offsets(num_buffers); 389 | for (int i = 0; i < num_buffers; i++) { 390 | auto ptr = graph_unreg_buffers_[i]; 391 | void *base_ptr; 392 | // note: must share the base address of each allocation, or we get wrong 393 | // address 394 | if (cuPointerGetAttribute(&base_ptr, 395 | CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, 396 | (CUdeviceptr)ptr) != CUDA_SUCCESS) 397 | throw std::runtime_error("failed to get pointer attr"); 398 | CUDACHECK(cudaIpcGetMemHandle( 399 | (cudaIpcMemHandle_t *)&handles[i * handle_sz], base_ptr)); 400 | offsets[i] = ((char *)ptr) - ((char *)base_ptr); 401 | } 402 | return std::make_pair(handles, offsets); 403 | } 404 | 405 | void check_rank_data_capacity(size_t num = 1) { 406 | if (d_rank_data_base_ + num > d_rank_data_end_) 407 | throw std::runtime_error( 408 | "Rank data buffer is overflowed by " + 409 | std::to_string(d_rank_data_base_ + num - d_rank_data_end_)); 410 | } 411 | 412 | // I would rename as: 413 | // void *self => void *cur_rank_ptr 414 | void register_buffer(const std::vector &handles, 415 | const std::vector &offsets, void *self) { 416 | check_rank_data_capacity(); 417 | RankData data; 418 | for (int i = 0; i < world_size_; i++) { 419 | if (i != rank_) { 420 | char *handle; 421 | CUDACHECK(cudaIpcOpenMemHandle( 422 | (void **)&handle, *((const cudaIpcMemHandle_t *)handles[i].data()), 423 | cudaIpcMemLazyEnablePeerAccess)); 424 | 
ipc_handles_.push_back(handle); 425 | handle += offsets[i]; 426 | data.ptrs[i] = handle; 427 | } else { 428 | data.ptrs[i] = self; 429 | } 430 | } 431 | auto d_data = d_rank_data_base_++; 432 | CUDACHECK( 433 | cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice)); 434 | buffers_[self] = d_data; 435 | } 436 | 437 | // note: when registering graph buffers, we intentionally choose to not 438 | // deduplicate the addresses. That means if the allocator reuses some 439 | // addresses, they will be registered again. This is to account for the remote 440 | // possibility of different allocation patterns between ranks. For example, 441 | // rank 1 may get the same input address for the second allreduce, but rank 2 442 | // got a different address. IPC handles have internal reference counting 443 | // mechanism so overhead should be small. 444 | void register_graph_buffers( 445 | const std::vector &handles, 446 | const std::vector> &offsets) { 447 | auto num_buffers = graph_unreg_buffers_.size(); 448 | check_rank_data_capacity(num_buffers); 449 | 450 | // Each buffer is registered across (up to) 8 ranks 451 | std::vector rank_data(num_buffers); 452 | 453 | for (int i = 0; i < num_buffers; i++) { 454 | auto self_ptr = graph_unreg_buffers_[i]; 455 | auto &rd = rank_data[i]; 456 | for (int j = 0; j < world_size_; j++) { 457 | if (j != rank_) { 458 | char *handle; 459 | CUDACHECK(cudaIpcOpenMemHandle( 460 | (void **)&handle, 461 | *((cudaIpcMemHandle_t *)&handles[j] 462 | [i * sizeof(cudaIpcMemHandle_t)]), 463 | cudaIpcMemLazyEnablePeerAccess)); 464 | ipc_handles_.push_back(handle); 465 | handle += offsets[j][i]; 466 | rd.ptrs[j] = handle; 467 | } else { 468 | rd.ptrs[j] = self_ptr; 469 | } 470 | } 471 | } 472 | CUDACHECK(cudaMemcpy(d_rank_data_base_, rank_data.data(), 473 | sizeof(RankData) * num_buffers, 474 | cudaMemcpyHostToDevice)); 475 | d_rank_data_base_ += num_buffers; 476 | graph_unreg_buffers_.clear(); 477 | } 478 | 479 | // note: 512, 36 is good for most cases 480 | template 481 | void allreduce(cudaStream_t stream, T *input, T *output, int size, 482 | int threads = 512, int block_limit = 36) { 483 | auto d = packed_t::P::size; 484 | if (size % d != 0) 485 | throw std::runtime_error( 486 | "fast allreduce currently requires input length to be multiple of " + 487 | std::to_string(d)); 488 | 489 | RankData *ptrs; 490 | cudaStreamCaptureStatus status; 491 | CUDACHECK(cudaStreamIsCapturing(stream, &status)); 492 | if (status == cudaStreamCaptureStatusActive) { 493 | ptrs = d_rank_data_base_ + graph_unreg_buffers_.size(); 494 | graph_unreg_buffers_.push_back(input); 495 | } else { 496 | auto it = buffers_.find(input); 497 | if (it == buffers_.end()) 498 | throw std::runtime_error( 499 | "buffer address " + 500 | std::to_string(reinterpret_cast(input)) + 501 | " is not registered!"); 502 | ptrs = it->second; 503 | } 504 | 505 | size /= d; 506 | auto bytes = size * sizeof(typename packed_t::P); 507 | int blocks = std::min(block_limit, (size + threads - 1) / threads); 508 | #define KL(ngpus, name) \ 509 | name \ 510 | <<>>(ptrs, sg_, meta_, output, rank_, size); 511 | #define REDUCE_CASE(ngpus) \ 512 | case ngpus: { \ 513 | if (world_size_ == 2) { \ 514 | KL(ngpus, cross_device_reduce_1stage); \ 515 | } else if (full_nvlink_) { \ 516 | if ((world_size_ <= 4 && bytes < 512 * 1024) || \ 517 | (world_size_ <= 8 && bytes < 256 * 1024)) { \ 518 | KL(ngpus, cross_device_reduce_1stage); \ 519 | } else { \ 520 | KL(ngpus, cross_device_reduce_2stage); \ 521 | } \ 522 | } else { \ 523 | 
KL(ngpus, cross_device_reduce_half_butterfly); \ 524 | } \ 525 | break; \ 526 | } 527 | 528 | switch (world_size_) { 529 | REDUCE_CASE(2) 530 | REDUCE_CASE(4) 531 | REDUCE_CASE(6) 532 | REDUCE_CASE(8) 533 | default: 534 | throw std::runtime_error( 535 | "Fast allreduce only supports num gpus in (2,4,6,8). Actual num " 536 | "gpus = " + 537 | std::to_string(world_size_)); 538 | } 539 | #undef REDUCE_CASE 540 | #undef KL 541 | } 542 | 543 | ~FastAllreduce() { 544 | for (auto ptr : ipc_handles_) { 545 | CUDACHECK(cudaIpcCloseMemHandle(ptr)); 546 | } 547 | } 548 | }; 549 | template void FastAllreduce::allreduce<half>(cudaStream_t, half *, half *, int, int, int); 550 | } // namespace vllm --------------------------------------------------------------------------------
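
For orientation, here is a minimal host-side sketch of how the FastAllreduce class in fast_allreduce.cuh could be driven on one rank. It assumes the IPC handles and byte offsets for the metadata block and the input buffer were already exchanged out of band (e.g. over MPI); the wrapper name run_fast_allreduce and its parameter list are hypothetical and not defined anywhere in this repository.

// Hypothetical driver; every rank runs this with its own rank id and the
// handles/offsets gathered from all peers beforehand.
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <string>
#include <vector>
// #include "fast_allreduce.cuh"  // the header shown above

void run_fast_allreduce(int rank, vllm::Metadata *meta, void *rank_data,
                        size_t rank_data_sz,
                        const cudaIpcMemHandle_t *meta_handles,
                        const std::vector<int64_t> &meta_offsets,
                        const std::vector<std::string> &buf_handles,
                        const std::vector<int64_t> &buf_offsets,
                        half *input, half *output, int nelem,
                        cudaStream_t stream) {
  // world_size is taken from meta_offsets.size() inside the constructor.
  vllm::FastAllreduce fa(meta, rank_data, rank_data_sz, meta_handles,
                         meta_offsets, rank);
  // Every rank must register the same logical input buffer before reducing it.
  fa.register_buffer(buf_handles, buf_offsets, input);
  // nelem must be a multiple of packed_t<half>::P::size (8 for half); the
  // switch above picks the 1-stage, 2-stage, or half-butterfly kernel based
  // on world size, NVLink topology, and message size.
  fa.allreduce<half>(stream, input, output, nelem);
  CUDACHECK(cudaStreamSynchronize(stream));
}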