├── .gitignore
├── include
│   ├── test_gemm.h
│   ├── ape.h
│   ├── kernel.h
│   └── common.h
├── src
│   ├── gemm_fp32_cublas.cpp
│   ├── gemm_fp64_cublas.cpp
│   ├── kernel
│   │   ├── int16_auto.cu
│   │   ├── convert_fp32_to_fp64.cu
│   │   ├── convert_fp64_to_fp32.cu
│   │   ├── convert_int16_to_int32.cu
│   │   ├── merge_tf32_to_fp32.cu
│   │   ├── convert_int32_to_int16.cu
│   │   ├── merge_fp16_to_fp32.cu
│   │   ├── split_fp32_to_tf32.cu
│   │   ├── merge_int8_to_int16.cu
│   │   ├── split_fp32_to_fp16.cu
│   │   ├── merge_bf16_to_fp32.cu
│   │   ├── split_int16_to_int8.cu
│   │   ├── split_fp32_to_bf16.cu
│   │   ├── compare_fp32_to_fp64.cu
│   │   └── fp32_auto.cu
│   ├── gemm_fp32_auto.cpp
│   ├── gemm_fp32_fp32t.cpp
│   ├── gemm_fp32_fp32f.cpp
│   ├── gemm_int16_emu.cpp
│   ├── gemm_fp32_fp32b.cpp
│   ├── ape.cpp
│   └── gemm_fp32_test.cpp
├── test
│   ├── test_gemm_fp32_fp32b.cpp
│   ├── test_gemm_fp32_fp32f.cpp
│   ├── test_gemm_fp32_fp32t.cpp
│   ├── test_gemm_fp32_cublas.cpp
│   ├── test_gemm_fp32_auto.cpp
│   ├── test_count_overflow_fp32.cpp
│   ├── test_count_overflow_int16.cpp
│   └── test_create_mask_fp32.cpp
├── CMakeLists.txt
├── README.md
└── LICENSE

/.gitignore:
--------------------------------------------------------------------------------
.vscode/
build/
env*
.clang-format
--------------------------------------------------------------------------------
/include/test_gemm.h:
--------------------------------------------------------------------------------
#pragma once

#include "ape.h"

namespace ape {
void test_gemm_fp32(int m, int n, int k, ape::ApeAlgo algo);
} // namespace ape
--------------------------------------------------------------------------------
/src/gemm_fp32_cublas.cpp:
--------------------------------------------------------------------------------
#include "common.h"
#include "kernel.h"

namespace ape {
void gemm_fp32_cublas(ApeTrans transa, ApeTrans transb, int m, int n, int k, const float *alpha, const float *A, int lda,
                      const float *B, int ldb, const float *beta, float *C, int ldc) {
    cublasSafeCall(cublasSgemm(APEHandler::getCublasHandle(), cublasOperation_t(transa), cublasOperation_t(transb), m, n, k,
                               alpha, A, lda, B, ldb, beta, C, ldc));
}

} // namespace ape
--------------------------------------------------------------------------------
/src/gemm_fp64_cublas.cpp:
--------------------------------------------------------------------------------
#include "common.h"
#include "kernel.h"

namespace ape {
void gemm_fp64_cublas(ApeTrans transa, ApeTrans transb, int m, int n, int k, const double *alpha, const double *A, int lda,
                      const double *B, int ldb, const double *beta, double *C, int ldc) {
    cublasSafeCall(cublasDgemm(APEHandler::getCublasHandle(), cublasOperation_t(transa), cublasOperation_t(transb), m, n, k,
                               alpha, A, lda, B, ldb, beta, C, ldc));
}

} // namespace ape
--------------------------------------------------------------------------------
/test/test_gemm_fp32_fp32b.cpp:
--------------------------------------------------------------------------------
#include "ape.h"
#include "test_gemm.h"

int main() {
    ape::test_gemm_fp32(128, 128, 128, ape::APE_ALGO_FP32B);
    ape::test_gemm_fp32(256, 256, 256, ape::APE_ALGO_FP32B);
    ape::test_gemm_fp32(512, 512, 512, ape::APE_ALGO_FP32B);
    ape::test_gemm_fp32(1024, 1024, 1024, ape::APE_ALGO_FP32B);
    ape::test_gemm_fp32(2048, 2048, 2048, ape::APE_ALGO_FP32B);
    ape::test_gemm_fp32(4096, 4096, 4096, ape::APE_ALGO_FP32B);
    ape::test_gemm_fp32(8192, 8192, 8192, ape::APE_ALGO_FP32B);
}
--------------------------------------------------------------------------------
/test/test_gemm_fp32_fp32f.cpp:
--------------------------------------------------------------------------------
#include "ape.h"
#include "test_gemm.h"

int main() {
    ape::test_gemm_fp32(128, 128, 128, ape::APE_ALGO_FP32F);
    ape::test_gemm_fp32(256, 256, 256, ape::APE_ALGO_FP32F);
    ape::test_gemm_fp32(512, 512, 512, ape::APE_ALGO_FP32F);
    ape::test_gemm_fp32(1024, 1024, 1024, ape::APE_ALGO_FP32F);
    ape::test_gemm_fp32(2048, 2048, 2048, ape::APE_ALGO_FP32F);
    ape::test_gemm_fp32(4096, 4096, 4096, ape::APE_ALGO_FP32F);
    ape::test_gemm_fp32(8192, 8192, 8192, ape::APE_ALGO_FP32F);
}
--------------------------------------------------------------------------------
/test/test_gemm_fp32_fp32t.cpp:
--------------------------------------------------------------------------------
#include "ape.h"
#include "test_gemm.h"

int main() {
    ape::test_gemm_fp32(128, 128, 128, ape::APE_ALGO_FP32T);
    ape::test_gemm_fp32(256, 256, 256, ape::APE_ALGO_FP32T);
    ape::test_gemm_fp32(512, 512, 512, ape::APE_ALGO_FP32T);
    ape::test_gemm_fp32(1024, 1024, 1024, ape::APE_ALGO_FP32T);
    ape::test_gemm_fp32(2048, 2048, 2048, ape::APE_ALGO_FP32T);
    ape::test_gemm_fp32(4096, 4096, 4096, ape::APE_ALGO_FP32T);
    ape::test_gemm_fp32(8192, 8192, 8192, ape::APE_ALGO_FP32T);
}
--------------------------------------------------------------------------------
/test/test_gemm_fp32_cublas.cpp:
--------------------------------------------------------------------------------
#include "ape.h"
#include "test_gemm.h"

int main() {
    ape::test_gemm_fp32(128, 128, 128, ape::APE_ALGO_CUBLAS);
    ape::test_gemm_fp32(256, 256, 256, ape::APE_ALGO_CUBLAS);
    ape::test_gemm_fp32(512, 512, 512, ape::APE_ALGO_CUBLAS);
    ape::test_gemm_fp32(1024, 1024, 1024, ape::APE_ALGO_CUBLAS);
    ape::test_gemm_fp32(2048, 2048, 2048, ape::APE_ALGO_CUBLAS);
    ape::test_gemm_fp32(4096, 4096, 4096, ape::APE_ALGO_CUBLAS);
    ape::test_gemm_fp32(8192, 8192, 8192, ape::APE_ALGO_CUBLAS);
}
--------------------------------------------------------------------------------
/src/kernel/int16_auto.cu:
--------------------------------------------------------------------------------
#include "kernel.h"
#include "thrust/device_vector.h"
#include <thrust/functional.h>
#include <thrust/transform_reduce.h>

namespace ape {

struct OpINT16 {
    __host__ __device__ int operator()(int16_t x) { return (x > INT16C_MAX); }
};

int count_overflow_int16emu(const int16_t *src, size_t row, size_t col) {
    thrust::device_ptr<int16_t> d_src(const_cast<int16_t *>(src));
    return thrust::transform_reduce(d_src, d_src + row * col, OpINT16(), 0, thrust::plus<int>());
}
} // namespace ape
--------------------------------------------------------------------------------
/src/kernel/convert_fp32_to_fp64.cu:
--------------------------------------------------------------------------------
#include "kernel.h"

namespace ape {

__global__ void kernel_convert_fp32_to_fp64(double *dst, const float *src, size_t size) {
    uint32_t base = 2 * (blockIdx.x * blockDim.x + threadIdx.x);
    uint32_t step = 2 * blockDim.x * gridDim.x;
    for (uint32_t i = base; i < size; i += step) {
        float2 base = (float2 &)src[i];
        double2 buf;
        buf.x = double(base.x);
        buf.y = double(base.y);
        (double2 &)dst[i] = (double2 &)buf;
    }
}

void convert_fp32_to_fp64(double *dst, const float *src, size_t size) {
    dim3 grid_size(NUM_SM, 1);
    dim3 block_size(MAX_THREAD, 1);
    kernel_convert_fp32_to_fp64<<<grid_size, block_size>>>(dst, src, size);
    cudaCheckError();
}

} // namespace ape
--------------------------------------------------------------------------------
/src/kernel/convert_fp64_to_fp32.cu:
--------------------------------------------------------------------------------
#include "common.h"
#include "kernel.h"

namespace ape {

__global__ void kernel_convert_fp64_to_fp32(float *dst, const double *src, size_t size) {
    size_t base = 2 * (blockIdx.x * blockDim.x + threadIdx.x);
    size_t step = 2 * blockDim.x * gridDim.x;
    for (size_t i = base; i < size; i += step) {
        double2 base = (double2 &)src[i];
        float2 buf;
        buf.x = float(base.x);
        buf.y = float(base.y);
        (float2 &)dst[i] = (float2 &)buf;
    }
}

void convert_fp64_to_fp32(float *dst, const double *src, size_t size) {
    dim3 grid_size(NUM_SM, 1);
    dim3 block_size(MAX_THREAD, 1);
    kernel_convert_fp64_to_fp32<<<grid_size, block_size>>>(dst, src, size);
    cudaCheckError();
}

} // namespace ape
--------------------------------------------------------------------------------
/src/kernel/convert_int16_to_int32.cu:
--------------------------------------------------------------------------------
#include "common.h"
#include "kernel.h"

namespace ape {
__global__ void kernel_convert_int16_to_int32(int32_t *dst, const int16_t *src, size_t size) {
    uint32_t base = 4 * (blockIdx.x * blockDim.x + threadIdx.x);
    uint32_t step = 4 * blockDim.x * gridDim.x;
    for (uint32_t i = base; i < size; i += step) {
        short4 base = (short4 &)src[i];
        int4 buf;
        buf.x = int(base.x);
        buf.y = int(base.y);
        buf.z = int(base.z);
        buf.w = int(base.w);
        (int4 &)dst[i] = buf;
    }
    return;
}

void convert_int16_to_int32(int32_t *dst, const int16_t *src, size_t size) {
    dim3 grid(NUM_SM, 1, 1);
    dim3 block(MAX_THREAD, 1, 1);
    kernel_convert_int16_to_int32<<<grid, block>>>(dst, src, size);
    cudaCheckError();
}
} // namespace ape
--------------------------------------------------------------------------------
/src/kernel/merge_tf32_to_fp32.cu:
--------------------------------------------------------------------------------
#include "common.h"
#include "kernel.h"

namespace ape {
// Kernel renamed from a copy-pasted kernel_merge_fp16_to_fp32 to match this file.
__global__ void kernel_merge_tf32_to_fp32(float *dst, const float *src, size_t size) {
    uint32_t base = 2 * (blockIdx.x * blockDim.x + threadIdx.x);
    uint32_t step = 2 * blockDim.x * gridDim.x;
    for (uint32_t i = base; i < size; i += step) {
        float2 tmp[2];
        tmp[0] = (float2 &)src[i];
        tmp[1] = (float2 &)src[size + i];
        float2 buf;
        buf.x = tmp[0].x + tmp[1].x / 4096.0f;
        buf.y = tmp[0].y + tmp[1].y / 4096.0f;
        (float2 &)dst[i] = buf;
    }
    return;
}

void merge_tf32_to_fp32(float *dst, const float *src, size_t size) {
    dim3 grid_size(NUM_SM, 1);
    dim3 block_size(MAX_THREAD, 1);
    kernel_merge_tf32_to_fp32<<<grid_size, block_size>>>(dst, src, size);
    cudaCheckError();
}
} // namespace ape
--------------------------------------------------------------------------------
/src/kernel/convert_int32_to_int16.cu:
--------------------------------------------------------------------------------
#include "common.h"
#include "kernel.h"

namespace ape {
// TODO: change the names of the other converts into merge or split, because this
// convert has a different meaning from the others
__global__ void kernel_convert_int32_to_int16(int16_t *dst, const int32_t *src, size_t size) {
    uint32_t base = 2 * (blockIdx.x * blockDim.x + threadIdx.x);
    uint32_t step = 2 * blockDim.x * gridDim.x;
    for (uint32_t i = base; i < size; i += step) {
        int2 base = (int2 &)src[i];
        short2 buf;
        buf.x = short(base.x);
        buf.y = short(base.y);
        (short2 &)dst[i] = buf;
    }
    return;
}

void convert_int32_to_int16(int16_t *dst, const int32_t *src, size_t size) {
    dim3 grid(NUM_SM, 1, 1);
    dim3 block(MAX_THREAD, 1, 1);
    kernel_convert_int32_to_int16<<<grid, block>>>(dst, src, size);
    cudaCheckError();
}
} // namespace ape
--------------------------------------------------------------------------------
/src/kernel/merge_fp16_to_fp32.cu:
--------------------------------------------------------------------------------
#include "common.h"
#include "kernel.h"

namespace ape {
__global__ void kernel_merge_fp16_to_fp32(float *dst, const half *src, size_t size) {
    uint32_t base = 2 * (blockIdx.x * blockDim.x + threadIdx.x);
    uint32_t step = 2 * blockDim.x * gridDim.x;
    for (uint32_t i = base; i < size; i += step) {
        half2 tmp[2];
        tmp[0] = (half2 &)src[i];
        tmp[1] = (half2 &)src[size + i];
        float2 buf;
        buf.x = float(tmp[0].x) + float(tmp[1].x) / 4096.0f;
        buf.y = float(tmp[0].y) + float(tmp[1].y) / 4096.0f;
        (float2 &)dst[i] = buf;
    }
    return;
}

void merge_fp16_to_fp32(float *dst, const half *src, size_t size) {
    dim3 grid_size(NUM_SM, 1);
    dim3 block_size(MAX_THREAD, 1);
    kernel_merge_fp16_to_fp32<<<grid_size, block_size>>>(dst, src, size);
    cudaCheckError();
}
} // namespace ape
--------------------------------------------------------------------------------
/src/kernel/split_fp32_to_tf32.cu:
--------------------------------------------------------------------------------
#include "common.h"
#include "kernel.h"

namespace ape {
__global__ void kernel_split_fp32_to_tf32(float *dst, const float *src, size_t size) {
    uint32_t base = 2 * (blockIdx.x * blockDim.x + threadIdx.x);
    uint32_t step = 2 * blockDim.x * gridDim.x;
    for (uint32_t i = base; i < size; i += step) {
        float2 base = (float2 &)src[i];
        float2 buf[2];
        // Truncate the high part to TF32's 10 explicit mantissa bits by masking
        // off the low 13 mantissa bits; without this truncation the residual in
        // buf[1] would always be zero and the correction GEMMs in
        // gemm_fp32_fp32t would contribute nothing.
        buf[0].x = __uint_as_float(__float_as_uint(base.x) & 0xffffe000u);
        buf[0].y = __uint_as_float(__float_as_uint(base.y) & 0xffffe000u);
        buf[1].x = (base.x - buf[0].x) * 4096.0f;
        buf[1].y = (base.y - buf[0].y) * 4096.0f;
        (float2 &)dst[i] = buf[0];
        (float2 &)dst[size + i] = buf[1];
    }
    return;
}

void split_fp32_to_tf32(float *dst, const float *src, size_t size) {
    dim3 grid_size(NUM_SM, 1);
    dim3 block_size(MAX_THREAD, 1);
    kernel_split_fp32_to_tf32<<<grid_size, block_size>>>(dst, src, size);
    cudaCheckError();
}
} // namespace ape
--------------------------------------------------------------------------------
/src/kernel/merge_int8_to_int16.cu:
--------------------------------------------------------------------------------
#include "common.h"
#include "kernel.h"

namespace ape {
__global__ void kernel_merge_int8_to_int16(int16_t *dst, const int8_t *src, size_t size) {
    uint32_t base = 4 * (blockIdx.x * blockDim.x + threadIdx.x);
    uint32_t step = 4 * blockDim.x * gridDim.x;
    for (uint32_t i = base; i < size; i += step) {
        char4 tmp[2];
        tmp[0] = (char4 &)src[size + i]; // low-byte plane, written second by split_int16_to_int8
        tmp[1] = (char4 &)src[i];        // high-byte plane (fixed: was src[size], dropping the index)
        short4 buf;
        buf.x = tmp[0].x + tmp[1].x * 256;
        buf.y = tmp[0].y + tmp[1].y * 256;
        buf.z = tmp[0].z + tmp[1].z * 256;
        buf.w = tmp[0].w + tmp[1].w * 256;
        (short4 &)dst[i] = buf;
    }
    return;
}

void merge_int8_to_int16(int16_t *dst, const int8_t *src, size_t size) {
    dim3 grid(NUM_SM, 1, 1);
    dim3 block(MAX_THREAD, 1, 1);
    kernel_merge_int8_to_int16<<<grid, block>>>(dst, src, size);
    cudaCheckError();
}
} // namespace ape
--------------------------------------------------------------------------------
/src/kernel/split_fp32_to_fp16.cu:
--------------------------------------------------------------------------------
#include "common.h"
#include "kernel.h"

namespace ape {
__global__ void kernel_split_fp32_to_fp16(half *dst, const float *src, size_t size) {
    uint32_t base = 2 * (blockIdx.x * blockDim.x + threadIdx.x);
    uint32_t step = 2 * blockDim.x * gridDim.x;
    for (uint32_t i = base; i < size; i += step) {
        float2 base = (float2 &)src[i];
        half2 buf[2];
        buf[0].x = half(base.x);
        buf[0].y = half(base.y);
        buf[1].x = half((base.x - float(buf[0].x)) * 4096.0f);
        buf[1].y = half((base.y - float(buf[0].y)) * 4096.0f);
        (half2 &)dst[i] = buf[0];
        (half2 &)dst[size + i] = buf[1];
    }
    return;
}

void split_fp32_to_fp16(half *dst, const float *src, size_t size) {
    dim3 grid_size(NUM_SM, 1);
    dim3 block_size(MAX_THREAD, 1);
    kernel_split_fp32_to_fp16<<<grid_size, block_size>>>(dst, src, size);
    cudaCheckError();
}
} // namespace ape
--------------------------------------------------------------------------------
/test/test_gemm_fp32_auto.cpp:
--------------------------------------------------------------------------------
#include "ape.h"
#include "test_gemm.h"

int main() {
    ape::test_gemm_fp32(128, 128, 128, ape::APE_ALGO_AUTO);
    ape::test_gemm_fp32(128, 128, 128, ape::APE_ALGO_AUTO_STRICT);
    ape::test_gemm_fp32(256, 256, 256, ape::APE_ALGO_AUTO);
    ape::test_gemm_fp32(256, 256, 256, ape::APE_ALGO_AUTO_STRICT);
    ape::test_gemm_fp32(512, 512, 512, ape::APE_ALGO_AUTO);
    ape::test_gemm_fp32(512, 512, 512, ape::APE_ALGO_AUTO_STRICT);
    ape::test_gemm_fp32(1024, 1024, 1024, ape::APE_ALGO_AUTO);
    ape::test_gemm_fp32(1024, 1024, 1024, ape::APE_ALGO_AUTO_STRICT);
    ape::test_gemm_fp32(2048, 2048, 2048, ape::APE_ALGO_AUTO);
    ape::test_gemm_fp32(2048, 2048, 2048, ape::APE_ALGO_AUTO_STRICT);
    ape::test_gemm_fp32(4096, 4096, 4096, ape::APE_ALGO_AUTO);
    ape::test_gemm_fp32(4096, 4096, 4096, ape::APE_ALGO_AUTO_STRICT);
    ape::test_gemm_fp32(8192, 8192, 8192, ape::APE_ALGO_AUTO);
    ape::test_gemm_fp32(8192, 8192, 8192, ape::APE_ALGO_AUTO_STRICT);
}
--------------------------------------------------------------------------------
/src/kernel/merge_bf16_to_fp32.cu:
--------------------------------------------------------------------------------
#include "common.h"
#include "kernel.h"

namespace ape {

__global__ void kernel_merge_bf16_to_fp32(float *dst, const __nv_bfloat16 *src, uint32_t size) {
    uint32_t base = 2 * (blockIdx.x * blockDim.x + threadIdx.x);
    uint32_t step = 2 * blockDim.x * gridDim.x;
    for (uint32_t i = base; i < size; i += step) {
        __nv_bfloat162 a0, a1, a2;
        a0 = (__nv_bfloat162 &)src[i];
        a1 = (__nv_bfloat162 &)src[i + size];
        a2 = (__nv_bfloat162 &)src[i + size * 2];
        float2 b;
        b.x = float(a0.x) + float(a1.x) + float(a2.x);
        b.y = float(a0.y) + float(a1.y) + float(a2.y);
        (float2 &)dst[i] = b;
    }
    return;
}

void merge_bf16_to_fp32(float *dst, const __nv_bfloat16 *src, uint32_t size) {
    dim3 grid_size(NUM_SM, 1);
    dim3 block_size(MAX_THREAD, 1);
    kernel_merge_bf16_to_fp32<<<grid_size, block_size>>>(dst, src, size);
    cudaCheckError();
}

} // namespace ape
--------------------------------------------------------------------------------
/src/gemm_fp32_auto.cpp:
--------------------------------------------------------------------------------
#include "common.h"
#include "kernel.h"

namespace ape {
void gemm_fp32_auto(ApeTrans transa, ApeTrans transb, int m, int n, int k, const float *alpha, const float *A, int lda,
                    const float *B, int ldb, const float *beta, float *C, int ldc) {
    if (count_overflow_fp32f(A, m, k) > 0 || count_overflow_fp32f(B, k, n) > 0) {
        gemm_fp32_fp32b(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
    } else {
        gemm_fp32_fp32f(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
    }
}

void gemm_fp32_auto_strict(ApeTrans transa, ApeTrans transb, int m, int n, int k, const float *alpha, const float *A, int lda,
                           const float *B, int ldb, const float *beta, float *C, int ldc) {
    if (count_overflow_fp32f_strict(A, m, k) > 0 || count_overflow_fp32f_strict(B, k, n) > 0) {
        gemm_fp32_fp32b(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
    } else {
        gemm_fp32_fp32f(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
    }
}

} // namespace ape
--------------------------------------------------------------------------------
/src/kernel/split_int16_to_int8.cu:
--------------------------------------------------------------------------------
#include "common.h"
#include "kernel.h"

namespace ape {
__global__ void kernel_split_int16_to_int8(int8_t *dst, const int16_t *src, size_t size) {
    uint32_t base = 4 * (blockIdx.x * blockDim.x + threadIdx.x);
    uint32_t step = 4 * blockDim.x * gridDim.x;
    for (uint32_t i = base; i < size; i += step) {
        short4 base = (short4 &)src[i];
        char4 buf[2];
        buf[0].x = char(base.x);
        buf[0].y = char(base.y);
        buf[0].z = char(base.z);
        buf[0].w = char(base.w);
        buf[1].x = char((base.x - int16_t(buf[0].x)) / 256);
        buf[1].y = char((base.y - int16_t(buf[0].y)) / 256);
        buf[1].z = char((base.z - int16_t(buf[0].z)) / 256);
        buf[1].w = char((base.w - int16_t(buf[0].w)) / 256);
        (char4 &)dst[i] = buf[1];
        (char4 &)dst[size + i] = buf[0];
    }
    return;
}

void split_int16_to_int8(int8_t *dst, const int16_t *src, size_t size) {
    dim3 grid(NUM_SM, 1, 1);
    dim3 block(MAX_THREAD, 1, 1);
    kernel_split_int16_to_int8<<<grid, block>>>(dst, src, size);
    cudaCheckError();
}
} // namespace ape
--------------------------------------------------------------------------------
/src/kernel/split_fp32_to_bf16.cu:
--------------------------------------------------------------------------------
#include "common.h"
#include "kernel.h"

namespace ape {
__global__ void kernel_split_fp32_to_bf16(__nv_bfloat16 *dst, const float *src, uint32_t size) {
    uint32_t base = 2 * (blockIdx.x * blockDim.x + threadIdx.x);
    uint32_t step = 2 * blockDim.x * gridDim.x;
    for (uint32_t i = base; i < size; i += step) {
        float2 base = (float2 &)src[i];
        __nv_bfloat162 buf[3];
        buf[0].x = __float2bfloat16(base.x);
        buf[0].y = __float2bfloat16(base.y);
        buf[1].x = __float2bfloat16(base.x - float(buf[0].x));
        buf[1].y = __float2bfloat16(base.y - float(buf[0].y));
        buf[2].x = __float2bfloat16(base.x - float(buf[0].x) - float(buf[1].x));
        buf[2].y = __float2bfloat16(base.y - float(buf[0].y) - float(buf[1].y));
        (__nv_bfloat162 &)dst[i] = buf[0];
        (__nv_bfloat162 &)dst[size + i] = buf[1];
        (__nv_bfloat162 &)dst[size * 2 + i] = buf[2];
    }
    return;
}

void split_fp32_to_bf16(__nv_bfloat16 *dst, const float *src, uint32_t size) {
    dim3 grid_size(NUM_SM, 1);
    dim3 block_size(MAX_THREAD, 1);
    kernel_split_fp32_to_bf16<<<grid_size, block_size>>>(dst, src, size);
    cudaCheckError();
}
} // namespace ape
--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
cmake_minimum_required(VERSION 3.9)
project(APE)

set(CMAKE_CXX_STANDARD 20)

find_package(CUDA REQUIRED)
enable_language(CUDA)
find_package(CUDAToolkit)

set(CUDA_NVCC_FLAGS "-use_fast_math") # --ptxas-options=-v
include_directories(include)

file(GLOB_RECURSE SRC_CUDA src/kernel/*.cu)
add_library(ape_cuda SHARED ${SRC_CUDA})

file(GLOB_RECURSE SRC src/*.cpp)
add_library(ape SHARED ${SRC})
target_link_libraries(ape ape_cuda CUDA::cublas CUDA::curand CUDA::cudart)


if ("${ARCH}" STREQUAL "")
    set(ARCH "80")
endif()
if ("${TEST}" STREQUAL "")
    set(TEST "ON")
endif()
if ("${ARCH}" STREQUAL "80")
    message("APE ARCH set to sm80")
    add_compile_definitions(ARCH_SM80=true)
    set_target_properties(ape_cuda PROPERTIES CUDA_ARCHITECTURES "80")
elseif ("${ARCH}" STREQUAL "70")
    message("APE ARCH set to sm70")
    add_compile_definitions(ARCH_SM70=true)
    set_target_properties(ape_cuda PROPERTIES CUDA_ARCHITECTURES "70")
else()
    message(FATAL_ERROR "Invalid ARCH (set to 70 or 80)")
endif()

set(TESTS test_gemm_fp32_auto test_gemm_fp32_cublas test_gemm_fp32_fp32f test_gemm_fp32_fp32b test_gemm_fp32_fp32t test_count_overflow_fp32 test_count_overflow_int16 test_create_mask_fp32)
foreach(TEST IN LISTS TESTS)
    cuda_add_executable(${TEST} test/${TEST}.cpp)
    target_link_libraries(${TEST} ape)
endforeach()
--------------------------------------------------------------------------------
/src/gemm_fp32_fp32t.cpp:
--------------------------------------------------------------------------------
#include "common.h"
#include "kernel.h"

namespace ape {
void gemm_fp32_fp32t(ApeTrans transa, ApeTrans transb, int m, int n, int k, const float *alpha, const float *A, int lda,
                     const float *B, int ldb, const float *beta, float *C, int ldc) {
    assert((m * k + k * n) * 8 <= APEHandler::getBufSize());
    float *buf = (float *)APEHandler::getBuf();
    float *buf_a, *buf_b;
    buf_a = buf;
    buf_b = buf + m * k * 2;

    split_fp32_to_tf32(buf_a, A, m * k);
    split_fp32_to_tf32(buf_b, B, k * n);

    float alpha0 = *alpha, alpha1 = *alpha / 4096.0f, beta0 = *beta, beta1 = 1;
    cublasSafeCall(cublasGemmEx(APEHandler::getCublasHandle(), cublasOperation_t(transa), cublasOperation_t(transb), m, n, k,
                                &alpha0, buf_a, CUDA_R_32F, lda, buf_b, CUDA_R_32F, ldb, &beta0, C, CUDA_R_32F, ldc,
                                CUBLAS_COMPUTE_32F_FAST_TF32, CUBLAS_GEMM_DEFAULT));
    cublasSafeCall(cublasGemmEx(APEHandler::getCublasHandle(), cublasOperation_t(transa), cublasOperation_t(transb), m, n, k,
                                &alpha1, buf_a + m * k, CUDA_R_32F, lda, buf_b, CUDA_R_32F, ldb, &beta1, C, CUDA_R_32F, ldc,
                                CUBLAS_COMPUTE_32F_FAST_TF32, CUBLAS_GEMM_DEFAULT));
    cublasSafeCall(cublasGemmEx(APEHandler::getCublasHandle(), cublasOperation_t(transa), cublasOperation_t(transb), m, n, k,
                                &alpha1, buf_a, CUDA_R_32F, lda, buf_b + k * n, CUDA_R_32F, ldb, &beta1, C, CUDA_R_32F, ldc,
                                CUBLAS_COMPUTE_32F_FAST_TF32, CUBLAS_GEMM_DEFAULT));
}

} // namespace ape
--------------------------------------------------------------------------------
/include/ape.h:
--------------------------------------------------------------------------------
#pragma once

#include <cstdint>
#include <cstdlib>
#include <string>

namespace ape {

enum ApeTrans {
    APE_TRANS_N = 0,
    APE_TRANS_T,
};

enum ApeAlgo {
    APE_ALGO_AUTO = 1,
    APE_ALGO_AUTO_STRICT,
    APE_ALGO_CUBLAS,
    APE_ALGO_FP32F,
    APE_ALGO_FP32B,
    APE_ALGO_FP32T,
    APE_ALGO_INT16,
};

inline std::string getApeAlgoName(ApeAlgo algo) {
    switch (algo) {
    case APE_ALGO_AUTO:
        return "AUTO";
    case APE_ALGO_AUTO_STRICT:
        return "AUTO_STRICT";
    case APE_ALGO_CUBLAS:
        return "CUBLAS";
    case APE_ALGO_FP32F:
        return "FP32F";
    case APE_ALGO_FP32B:
        return "FP32B";
    case APE_ALGO_FP32T:
        return "FP32T";
    case APE_ALGO_INT16:
        return "INT16";
    default:
        return "Invalid";
    }
}

void apeInit(const size_t buf_size = 0);

void apeGemmFP32(ApeTrans transa, ApeTrans transb, int m, int n, int k, const float *alpha, const float *A, int lda,
                 const float *B, int ldb, const float *beta, float *C, int ldc, const ApeAlgo algo = APE_ALGO_AUTO);

void apeGemmFP64(ApeTrans transa, ApeTrans transb, int m, int n, int k, const double *alpha, const double *A, int lda,
                 const double *B, int ldb, const double *beta, double *C, int ldc, ApeAlgo algo = APE_ALGO_AUTO);

void apeGemmINT16(ApeTrans transa, ApeTrans transb, int m, int n, int k, const int16_t *alpha, const int16_t *A, int lda,
                  const int16_t *B, int ldb, const int32_t *beta, int32_t *C, int ldc, ApeAlgo algo = APE_ALGO_AUTO);

} // namespace ape
--------------------------------------------------------------------------------
/src/gemm_fp32_fp32f.cpp:
--------------------------------------------------------------------------------
#include "common.h"
#include "kernel.h"

namespace ape {
void gemm_fp32_fp32f(ApeTrans transa, ApeTrans transb, int m, int n, int k, const float *alpha, const float *A, int lda,
                     const float *B, int ldb, const float *beta, float *C, int ldc) {
    assert((m * k + k * n) * 4 <= APEHandler::getBufSize());
    half *buf = (half *)APEHandler::getBuf();
    half *buf_a, *buf_b;
    buf_a = buf;
    buf_b = buf + m * k * 2;

    split_fp32_to_fp16(buf_a, A, m * k);
    split_fp32_to_fp16(buf_b, B, k * n);

    float alpha0 = *alpha, alpha1 = *alpha / 4096.0f, beta0 = *beta, beta1 = 1;
    cublasSafeCall(cublasGemmEx(APEHandler::getCublasHandle(), cublasOperation_t(transa), cublasOperation_t(transb), m, n, k,
                                &alpha0, buf_a, CUDA_R_16F, lda, buf_b, CUDA_R_16F, ldb, &beta0, C, CUDA_R_32F, ldc,
                                CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT));
    cublasSafeCall(cublasGemmEx(APEHandler::getCublasHandle(), cublasOperation_t(transa), cublasOperation_t(transb), m, n, k,
                                &alpha1, buf_a + m * k, CUDA_R_16F, lda, buf_b, CUDA_R_16F, ldb, &beta1, C, CUDA_R_32F, ldc,
                                CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT));
    cublasSafeCall(cublasGemmEx(APEHandler::getCublasHandle(), cublasOperation_t(transa), cublasOperation_t(transb), m, n, k,
                                &alpha1, buf_a, CUDA_R_16F, lda, buf_b + k * n, CUDA_R_16F, ldb, &beta1, C, CUDA_R_32F, ldc,
                                CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT));
}

} // namespace ape
--------------------------------------------------------------------------------
/src/gemm_int16_emu.cpp:
--------------------------------------------------------------------------------
#include "common.h"
#include "kernel.h"

namespace ape {
// TODO: the interface of apeGemmINT16 seems contradictory with current cublasGemmEx
void gemm_int16_emu(ApeTrans transa, ApeTrans transb, int m, int n, int k, const int16_t *alpha, const int16_t *A, int lda,
                    const int16_t *B, int ldb, const int32_t *beta, int32_t *C, int ldc) {
    assert((m * k + k * n) * 2 <= APEHandler::getBufSize());
    int8_t *buf = (int8_t *)APEHandler::getBuf();
    int8_t *buf_a, *buf_b;
    buf_a = buf;
    buf_b = buf + m * k * 2;

    split_int16_to_int8(buf_a, A, m * k);
    split_int16_to_int8(buf_b, B, k * n);

    int alpha0 = *alpha * 256 * 256, alpha1 = *alpha * 256, alpha2 = *alpha;
    int beta0 = *beta, beta1 = 1;
    cublasSafeCall(cublasGemmEx(APEHandler::getCublasHandle(), cublasOperation_t(transa), cublasOperation_t(transb), m, n, k,
                                &alpha0, buf_a, CUDA_R_8I, lda, buf_b, CUDA_R_8I, ldb, &beta0, C, CUDA_R_32I, ldc,
                                CUBLAS_COMPUTE_32I, CUBLAS_GEMM_DEFAULT));
    cublasSafeCall(cublasGemmEx(APEHandler::getCublasHandle(), cublasOperation_t(transa), cublasOperation_t(transb), m, n, k,
                                &alpha1, buf_a + m * k, CUDA_R_8I, lda, buf_b, CUDA_R_8I, ldb, &beta1, C, CUDA_R_32I, ldc,
                                CUBLAS_COMPUTE_32I, CUBLAS_GEMM_DEFAULT));
    cublasSafeCall(cublasGemmEx(APEHandler::getCublasHandle(), cublasOperation_t(transa), cublasOperation_t(transb), m, n, k,
                                &alpha1, buf_a, CUDA_R_8I, lda, buf_b + k * n, CUDA_R_8I, ldb, &beta1, C, CUDA_R_32I, ldc,
                                CUBLAS_COMPUTE_32I, CUBLAS_GEMM_DEFAULT));
    cublasSafeCall(cublasGemmEx(APEHandler::getCublasHandle(), cublasOperation_t(transa), cublasOperation_t(transb), m, n, k,
                                &alpha2, buf_a + m * k, CUDA_R_8I, lda, buf_b + k * n, CUDA_R_8I, ldb, &beta1, C, CUDA_R_32I,
                                ldc, CUBLAS_COMPUTE_32I, CUBLAS_GEMM_DEFAULT));
}
} // namespace ape
--------------------------------------------------------------------------------
/test/test_count_overflow_fp32.cpp:
--------------------------------------------------------------------------------
#include <iostream>

#include "kernel.h"

int host_count_overflow(const float *src, int size) {
    int count = 0;

    for (int i = 0; i < size; i++) {
        // Match OpFP32F in src/kernel/fp32_auto.cu, which counts only magnitude
        // overflow; the underflow check belongs to count_overflow_fp32f_strict.
        count += (src[i] < -ape::FP32F_MAX || src[i] > ape::FP32F_MAX);
    }

    return count;
}

//#define DEBUG

int main() {
    float *d_array;
    float *h_array;

    curandGenerator_t gen;
    curandSafeCall(curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT));

    cudaEvent_t st, ed;
    cudaEventCreate(&st);
    cudaEventCreate(&ed);

    for (int size = 128; size < 32768; size <<= 1) {
        h_array = (float *)malloc(sizeof(float) * size * size);
        cudaSafeCall(cudaMalloc((void **)&d_array, sizeof(float) * size * size));
        curandSafeCall(curandGenerateLogNormal(gen, d_array, size * size, 0, 10));
        cudaSafeCall(cudaMemcpy(h_array, d_array, size * size * sizeof(float), cudaMemcpyDeviceToHost));

#ifdef DEBUG
        for (int i = 0; i < size; i++)
            std::cout << h_array[i] << std::endl;
#endif

        int h_count, d_count;
        h_count = host_count_overflow(h_array, size * size);

        for (int i = 0; i < 128; i++)
            ape::count_overflow_fp32f(d_array, size, size);

        float ms;
        cudaEventRecord(st);
        for (int i = 0; i < 128; i++)
            d_count = ape::count_overflow_fp32f(d_array, size, size);
        cudaEventRecord(ed);
        cudaEventSynchronize(st);
        cudaEventSynchronize(ed);
        cudaEventElapsedTime(&ms, st, ed);

        if (h_count == d_count) {
            std::cout << "correct" << std::endl;
        } else {
            std::cout << "unmatched: " << h_count << " " << d_count << std::endl;
        }

        std::cout << size << "x" << size << ": " << size * size * sizeof(float) * 128.0 / (ms * 1024.0 * 1024.0) << " GB/s"
                  << std::endl;
        free(h_array);
        cudaFree(d_array);
    }

    return 0;
}
--------------------------------------------------------------------------------
/test/test_count_overflow_int16.cpp:
--------------------------------------------------------------------------------
#include <cstdlib>
#include <iostream>

#include "kernel.h"

int host_count_overflow(const int16_t *src, int size) {
    int count = 0;

    for (int i = 0; i < size; i++) {
        count += (src[i] > ape::INT16C_MAX);
    }

    return count;
}

void host_randint(int16_t *dst, int size) {
    for (int i = 0; i < size; i++) {
        dst[i] = rand() % 65536 - 32768;
    }
}

//#define DEBUG

int main() {
    int16_t *d_array;
    int16_t *h_array;

    cudaEvent_t st, ed;
    cudaEventCreate(&st);
    cudaEventCreate(&ed);

    for (int size = 128; size < 32768; size <<= 1) {
        h_array = (int16_t *)malloc(sizeof(int16_t) * size * size);
        host_randint(h_array, size * size);
        cudaSafeCall(cudaMalloc((void **)&d_array, sizeof(int16_t) * size * size));
        cudaSafeCall(cudaMemcpy(d_array, h_array, size * size * sizeof(int16_t), cudaMemcpyHostToDevice));

#ifdef DEBUG
        for (int i = 0; i < size; i++)
            std::cout << h_array[i] << std::endl;
#endif

        int h_count, d_count;
        h_count = host_count_overflow(h_array, size * size);

        for (int i = 0; i < 128; i++)
            ape::count_overflow_int16emu(d_array, size, size);

        float ms;
        cudaEventRecord(st);
        for (int i = 0; i < 128; i++)
            d_count = ape::count_overflow_int16emu(d_array, size, size);
        cudaEventRecord(ed);
        cudaEventSynchronize(st);
        cudaEventSynchronize(ed);
        cudaEventElapsedTime(&ms, st, ed);

        if (h_count == d_count) {
            std::cout << "correct" << std::endl;
        } else {
            std::cout << "unmatched: " << h_count << " " << d_count << std::endl;
        }

        std::cout << size << "x" << size << ": " << size * size * sizeof(int16_t) * 128.0 / (ms * 1024.0 * 1024.0) << " GB/s"
                  << std::endl;
        free(h_array);
        cudaFree(d_array);
    }

    return 0;
}
--------------------------------------------------------------------------------
/src/kernel/compare_fp32_to_fp64.cu:
--------------------------------------------------------------------------------
#include "kernel.h"

namespace ape {

__global__ void kernel_calc_error(double *buf_sum, double *buf_max, const float *src, const double *dst, size_t size) {
    int lane_id = threadIdx.x % 32;
    int warp_id = threadIdx.x / 32;
    __shared__ double sbuf_sum[32];
    __shared__ double sbuf_max[32];

    uint32_t base = (blockIdx.x * blockDim.x + threadIdx.x);
    uint32_t step = blockDim.x * gridDim.x;
    double sum = 0, max = 0;
    for (uint32_t i = base; i < size; i += step) {
        double err = fabs(double(src[i]) - dst[i]) / fabs(dst[i]);
        sum += err;
        max = fmax(max, err);
    }

    for (int offset = 16; offset > 0; offset /= 2) {
        sum += __shfl_down_sync(0xffffffff, sum, offset);
        max = fmax(max, __shfl_down_sync(0xffffffff, max, offset));
    }

    sbuf_sum[warp_id] = sum;
    sbuf_max[warp_id] = max;
    __syncthreads();
    if (warp_id == 0) {
        sum = sbuf_sum[lane_id];
        max = sbuf_max[lane_id];
        for (int offset = 16; offset > 0; offset /= 2) {
            sum += __shfl_down_sync(0xffffffff, sum, offset);
            max = fmax(max, __shfl_down_sync(0xffffffff, max, offset));
        }
    }
    __syncthreads();

    if (threadIdx.x == 0) {
        buf_sum[blockIdx.x] = sum;
        buf_max[blockIdx.x] = max;
    }
}

void compare_fp32_to_fp64(const float *src, const double *dst, size_t size, double &max_error, double &mean_error) {
    // One partial sum/max per block; use NUM_SM/MAX_THREAD instead of the
    // previously hardcoded 108/1024 so the buffers always match the grid.
    double *buf_sum, *buf_max;
    cudaSafeCall(cudaMalloc((void **)&buf_sum, NUM_SM * sizeof(double)));
    cudaSafeCall(cudaMalloc((void **)&buf_max, NUM_SM * sizeof(double)));

    dim3 grid_size(NUM_SM, 1);
    dim3 block_size(MAX_THREAD, 1);
    kernel_calc_error<<<grid_size, block_size>>>(buf_sum, buf_max, src, dst, size);
    cudaCheckError();

    double *buf_sum_host = (double *)malloc(NUM_SM * sizeof(double));
    double *buf_max_host = (double *)malloc(NUM_SM * sizeof(double));
    cudaSafeCall(cudaMemcpy(buf_sum_host, buf_sum, NUM_SM * sizeof(double), cudaMemcpyDeviceToHost));
    cudaSafeCall(cudaMemcpy(buf_max_host, buf_max, NUM_SM * sizeof(double), cudaMemcpyDeviceToHost));

    double sum = 0, max = 0;
    for (int i = 0; i < NUM_SM; i++) {
        sum += buf_sum_host[i];
        max = std::max(max, buf_max_host[i]);
    }
    max_error = max;
    mean_error = sum / size;

    free(buf_sum_host);
    free(buf_max_host);
    cudaSafeCall(cudaFree(buf_sum));
    cudaSafeCall(cudaFree(buf_max));
}

} // namespace ape
--------------------------------------------------------------------------------
/src/gemm_fp32_fp32b.cpp:
--------------------------------------------------------------------------------
#include "common.h"
#include "kernel.h"

namespace ape {
void gemm_fp32_fp32b(ApeTrans transa, ApeTrans transb, int m, int n, int k, const float *alpha, const float *A, int lda,
                     const float *B, int ldb, const float *beta, float *C, int ldc) {
    assert((m * k + k * n) * 6 <= APEHandler::getBufSize());
    __nv_bfloat16 *buf = (__nv_bfloat16 *)APEHandler::getBuf();
    __nv_bfloat16 *buf_a, *buf_b;
    buf_a = buf;
    buf_b = buf + m * k * 3;

    split_fp32_to_bf16(buf_a, A, m * k);
    split_fp32_to_bf16(buf_b, B, k * n);

    float beta0 = *beta, beta1 = 1;
    cublasSafeCall(cublasGemmEx(APEHandler::getCublasHandle(), cublasOperation_t(transa), cublasOperation_t(transb), m, n, k,
                                alpha, buf_a, CUDA_R_16BF, lda, buf_b, CUDA_R_16BF, ldb, &beta0, C, CUDA_R_32F, ldc,
                                CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT));
    cublasSafeCall(cublasGemmEx(APEHandler::getCublasHandle(), cublasOperation_t(transa), cublasOperation_t(transb), m, n, k,
                                alpha, buf_a + m * k, CUDA_R_16BF, lda, buf_b, CUDA_R_16BF, ldb, &beta1, C, CUDA_R_32F, ldc,
                                CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT));
    cublasSafeCall(cublasGemmEx(APEHandler::getCublasHandle(), cublasOperation_t(transa), cublasOperation_t(transb), m, n, k,
                                alpha, buf_a, CUDA_R_16BF, lda, buf_b + k * n, CUDA_R_16BF, ldb, &beta1, C, CUDA_R_32F, ldc,
                                CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT));
    cublasSafeCall(cublasGemmEx(APEHandler::getCublasHandle(), cublasOperation_t(transa), cublasOperation_t(transb), m, n, k,
                                alpha, buf_a + m * k, CUDA_R_16BF, lda, buf_b + k * n, CUDA_R_16BF, ldb, &beta1, C, CUDA_R_32F,
                                ldc, CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT));
    cublasSafeCall(cublasGemmEx(APEHandler::getCublasHandle(), cublasOperation_t(transa), cublasOperation_t(transb), m, n, k,
                                alpha, buf_a + m * k * 2, CUDA_R_16BF, lda, buf_b, CUDA_R_16BF, ldb, &beta1, C, CUDA_R_32F,
                                ldc, CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT));
    cublasSafeCall(cublasGemmEx(APEHandler::getCublasHandle(), cublasOperation_t(transa), cublasOperation_t(transb), m, n, k,
                                alpha, buf_a, CUDA_R_16BF, lda, buf_b + k * n * 2, CUDA_R_16BF, ldb, &beta1, C, CUDA_R_32F,
                                ldc, CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT));
}
} // namespace ape
--------------------------------------------------------------------------------
/src/ape.cpp:
--------------------------------------------------------------------------------
#include <cassert>

#include "ape.h"
#include "common.h"
#include "kernel.h"

namespace ape {

void apeInit(const size_t buf_size) {
    APEHandler::initCublas();
    if (buf_size > 0) {
        APEHandler::initBuffer(buf_size);
    }
}

void apeGemmFP32(ApeTrans transa, ApeTrans transb, int m, int n, int k, const float *alpha, const float *A, int lda,
                 const float *B, int ldb, const float *beta, float *C, int ldc, const ApeAlgo algo) {
    switch (algo) {
    case APE_ALGO_AUTO:
        gemm_fp32_auto(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
        break;
    case APE_ALGO_AUTO_STRICT:
        gemm_fp32_auto_strict(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
        break;
    case APE_ALGO_CUBLAS:
        gemm_fp32_cublas(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
        break;
    case APE_ALGO_FP32F:
        gemm_fp32_fp32f(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
        break;
    case APE_ALGO_FP32B:
        gemm_fp32_fp32b(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
        break;
    case APE_ALGO_FP32T:
        gemm_fp32_fp32t(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
        break;
    default:
        assert(false);
    }
}

void apeGemmFP64(ApeTrans transa, ApeTrans transb, int m, int n, int k, const double *alpha, const double *A, int lda,
                 const double *B, int ldb, const double *beta, double *C, int ldc, ApeAlgo algo) {
    switch (algo) {
    case APE_ALGO_AUTO:
        gemm_fp64_cublas(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
        break;
    case APE_ALGO_CUBLAS:
        gemm_fp64_cublas(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
        break;
    default:
        assert(false);
    }
}

void apeGemmINT16(ApeTrans transa, ApeTrans transb, int m, int n, int k, const int16_t *alpha, const int16_t *A, int lda,
                  const int16_t *B, int ldb, const int32_t *beta, int32_t *C, int ldc, ApeAlgo algo) {
    switch (algo) {
    case APE_ALGO_AUTO:
        // TODO: check layout
        assert(count_overflow_int16emu(A, m, k) ==
0);
        assert(count_overflow_int16emu(B, k, n) == 0);
        gemm_int16_emu(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
        break;
    case APE_ALGO_INT16:
        gemm_int16_emu(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
        break;
    default:
        assert(false);
    }
}

} // namespace ape
--------------------------------------------------------------------------------
/src/kernel/fp32_auto.cu:
--------------------------------------------------------------------------------
#include "kernel.h"
#include "thrust/device_vector.h"
#include <thrust/functional.h>
#include <thrust/transform_reduce.h>

namespace ape {

struct OpFP32F {
    __device__ int operator()(float x) { return (fabs(x) > FP32F_MAX); }
};

struct OpFP32FStrict {
    __device__ int operator()(float x) { return (fabs(x) < FP32F_MIN || fabs(x) > FP32F_MAX); }
};

__global__ void kernel_create_mask_fp32(const float *src, size_t row, size_t col, ApeTrans trans, int8_t *mask) {
    extern __shared__ char smem[];
    int8_t *shmem = (int8_t *)&smem[0];
    uint32_t warp_id = threadIdx.y / 32;
    uint32_t lane_id = threadIdx.y % 32;
    if (blockIdx.y * blockDim.y + threadIdx.y >= col) {
        return;
    }

    const float *base;
    size_t step;
    if (trans == APE_TRANS_T) {
        base = src + blockIdx.x * AUTO_BLOCK * col + blockIdx.y * AUTO_BLOCK + threadIdx.y;
        step = col;
    } else {
        base = src + blockIdx.y * AUTO_BLOCK * row + blockIdx.x * AUTO_BLOCK + threadIdx.y * row;
        step = 1;
    }

    int8_t flag = 1;
    for (int i = 0; i < AUTO_BLOCK && blockIdx.x * AUTO_BLOCK + i < row; i++) {
        float tmp = fabs(base[i * step]);
        if (tmp < FP32F_MIN || tmp > FP32F_MAX) {
            flag = 0;
        }
    }

    for (int i = 16; i > 0; i >>= 1) {
        int tmp = __shfl_down_sync(0xffffffff, flag, i); // warp shuffle
        if (tmp == 0) {
            flag = 0;
        }
    }

    if (lane_id == 0) {
        shmem[warp_id] = flag;
    }
    __syncthreads();

    if (threadIdx.y == 0) {
        for (int i = 0; i < (blockDim.y - 1) / 32 + 1; i++) {
            if (shmem[i] == 0) {
                flag = 0;
                break;
            }
        }

        if (trans == APE_TRANS_T) {
            mask[blockIdx.x * gridDim.y + blockIdx.y] = flag;
        } else {
            mask[blockIdx.y * gridDim.x + blockIdx.x] = flag;
        }
    }
}

void create_mask_fp32(const float *src, size_t row, size_t col, ApeTrans trans, int8_t *mask) {
    dim3 grid((row - 1) / AUTO_BLOCK + 1, (col - 1) / AUTO_BLOCK + 1, 1);
    dim3 block(1, AUTO_BLOCK, 1);
    // The kernel uses dynamic shared memory: one int8_t flag per warp of the block.
    size_t smem_size = ((AUTO_BLOCK - 1) / 32 + 1) * sizeof(int8_t);
    kernel_create_mask_fp32<<<grid, block, smem_size>>>(src, row, col, trans, mask);
    cudaCheckError();
}

int count_overflow_fp32f(const float *src, size_t row, size_t col) {
    thrust::device_ptr<float> d_src(const_cast<float *>(src));
    return thrust::transform_reduce(d_src, d_src + row * col, OpFP32F(), 0, thrust::plus<int>());
}

int count_overflow_fp32f_strict(const float *src, size_t row, size_t col) {
    thrust::device_ptr<float> d_src(const_cast<float *>(src));
    return thrust::transform_reduce(d_src, d_src + row * col, OpFP32FStrict(), 0, thrust::plus<int>());
}

} // namespace ape
--------------------------------------------------------------------------------
/src/gemm_fp32_test.cpp:
--------------------------------------------------------------------------------
#include "ape.h"
#include "common.h"
#include "kernel.h"

namespace ape {
void test_gemm_fp32(int m, int n, int k, ape::ApeAlgo algo) {
    int width;
    switch (algo) {
    case APE_ALGO_AUTO:
    case APE_ALGO_AUTO_STRICT: // AUTO paths may fall back to FP32F or FP32B; 8 covers both.
        width = 8;
        break;
    case APE_ALGO_FP32F:
        width = 4;
        break;
    case APE_ALGO_FP32B:
        width = 6;
        break;
    case APE_ALGO_FP32T:
        width = 8;
        break;
    default:
        width = 0;
    }
    apeInit((m * k + k * n) * width);
    float *data_eval_a = 0, *data_eval_b = 0, *data_eval_c = 0;
    cudaSafeCall(cudaMalloc((void **)&data_eval_a, m * k * sizeof(float)));
    cudaSafeCall(cudaMalloc((void **)&data_eval_b, k * n * sizeof(float)));
    cudaSafeCall(cudaMalloc((void **)&data_eval_c, m * n * sizeof(float)));
    double *data_res_a = 0, *data_res_b = 0, *data_res_c = 0;
    cudaSafeCall(cudaMalloc((void **)&data_res_a, m * k * sizeof(double)));
    cudaSafeCall(cudaMalloc((void **)&data_res_b, k * n * sizeof(double)));
    cudaSafeCall(cudaMalloc((void **)&data_res_c, m * n * sizeof(double)));

    curandGenerator_t gen;
    curandSafeCall(curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT));
    curandSafeCall(curandGenerateUniformDouble(gen, data_res_a, m * k));
    curandSafeCall(curandGenerateUniformDouble(gen, data_res_b, k * n));
    curandSafeCall(curandGenerateUniformDouble(gen, data_res_c, m * n));

    ape::convert_fp64_to_fp32(data_eval_a, data_res_a, m * k);
    ape::convert_fp64_to_fp32(data_eval_b, data_res_b, k * n);
    ape::convert_fp64_to_fp32(data_eval_c, data_res_c, m * n);

    double alpha_res = 1.0, beta_res = 0;
    ape::apeGemmFP64(ape::APE_TRANS_N, ape::APE_TRANS_N, m, n, k, &alpha_res, data_res_a, m, data_res_b, k, &beta_res,
                     data_res_c, m, ape::APE_ALGO_CUBLAS);
    float alpha_eval = 1.0f, beta_eval = 0.0f;
    ape::apeGemmFP32(ape::APE_TRANS_N, ape::APE_TRANS_N, m, n, k, &alpha_eval, data_eval_a, m, data_eval_b, k, &beta_eval,
                     data_eval_c, m, algo);
    double max_error, mean_error;
    ape::compare_fp32_to_fp64(data_eval_c, data_res_c, m * n, max_error, mean_error);

    float duration = 0;
    cudaEvent_t st, ed;
    cudaEventCreate(&st);
    cudaEventCreate(&ed);
    for (int i = 0; i < 128; i++) {
        ape::apeGemmFP32(ape::APE_TRANS_N, ape::APE_TRANS_N, m, n, k, &alpha_eval, data_eval_a, m, data_eval_b, k, &beta_eval,
                         data_eval_c, m, algo);
    }
    cudaEventRecord(st, 0);
    for (int i = 0; i < 128; i++) {
        ape::apeGemmFP32(ape::APE_TRANS_N, ape::APE_TRANS_N, m, n, k, &alpha_eval, data_eval_a, m, data_eval_b, k, &beta_eval,
                         data_eval_c, m, algo);
    }
    cudaEventRecord(ed, 0);
    cudaEventSynchronize(st);
    cudaEventSynchronize(ed);
    cudaEventElapsedTime(&duration, st, ed);
    double perf = double(m) * double(n) * double(k) * 2.0f * 128.0f / duration / 1e9;

    std::cout << "[TEST] " << getApeAlgoName(algo) << " (" << m << " " << n << " " << k << ") max_error: " << max_error
              << " mean_error: " << mean_error << " perf(TFLOPS): " << perf << std::endl;
    cudaEventDestroy(st);
    cudaEventDestroy(ed);
    curandSafeCall(curandDestroyGenerator(gen));
    cudaFree(data_eval_a);
    cudaFree(data_eval_b);
    cudaFree(data_eval_c);
    cudaFree(data_res_a);
    cudaFree(data_res_b);
    cudaFree(data_res_c);
}

} // namespace ape
--------------------------------------------------------------------------------
/include/kernel.h:
--------------------------------------------------------------------------------
#pragma once

#include "ape.h"
#include "common.h"

namespace ape {

#ifdef ARCH_SM80
constexpr int NUM_SM = 108;
constexpr int MAX_THREAD = 1024;
#endif
#ifdef ARCH_SM70
constexpr int NUM_SM = 80;
constexpr int MAX_THREAD = 1024;
#endif

constexpr int AUTO_BLOCK = 128;

constexpr float FP32F_MAX = 65504.0f;
constexpr float FP32F_MIN = 3.1e-5f;
constexpr float FP32B_MAX = 3.38e38f;
constexpr float FP32B_MIN = 3.9e-34f;
constexpr int INT16C_MAX = 32639;

void gemm_fp32_auto(ApeTrans transa, ApeTrans transb, int m, int n, int k, const float *alpha, const float *A, int lda,
                    const float *B, int ldb, const float *beta, float *C, int ldc);
void gemm_fp32_auto_strict(ApeTrans transa, ApeTrans transb, int m, int n, int k, const float *alpha, const float *A, int lda,
                           const float *B, int ldb, const float *beta, float *C, int ldc);
void gemm_fp32_cublas(ApeTrans transa, ApeTrans transb, int m, int n, int k, const float *alpha, const float *A, int lda,
                      const float *B, int ldb, const float *beta, float *C, int ldc);
void gemm_fp32_fp32f(ApeTrans transa, ApeTrans transb, int m, int n, int k, const float *alpha, const float *A, int lda,
                     const float *B, int ldb, const float *beta, float *C, int ldc);
void gemm_fp32_fp32b(ApeTrans transa, ApeTrans transb, int m, int n, int k, const float *alpha, const float *A, int lda,
                     const float *B, int ldb, const float *beta, float *C, int ldc);
void gemm_fp32_fp32t(ApeTrans transa, ApeTrans transb, int m, int n, int k, const float *alpha, const float *A, int lda,
                     const float *B, int ldb, const float *beta, float *C, int ldc);
void gemm_fp64_cublas(ApeTrans transa, ApeTrans transb, int m, int n, int k, const double *alpha, const double *A, int lda,
                      const double *B, int ldb, const double *beta, double *C, int ldc);
void gemm_int16_emu(ApeTrans transa, ApeTrans transb, int m, int n, int k, const int16_t *alpha, const int16_t *A, int lda,
                    const int16_t *B, int ldb, const int32_t *beta, int32_t *C, int ldc);

void convert_fp64_to_fp32(float *dst, const double *src, size_t size);
void convert_fp32_to_fp64(double *dst, const float *src, size_t size);
void convert_int32_to_int16(int16_t *dst, const int32_t *src, size_t size);
void convert_int16_to_int32(int32_t *dst, const int16_t *src, size_t size);
void compare_fp32_to_fp64(const float *src, const double *dst, size_t size, double &max_error, double &mean_error);

void split_fp32_to_fp16(half *dst, const float *src, size_t size);
void merge_fp16_to_fp32(float *dst, const half *src, size_t size);
void split_fp32_to_bf16(__nv_bfloat16 *dst, const float *src, uint32_t size);
void merge_bf16_to_fp32(float *dst, const __nv_bfloat16 *src, uint32_t size);
void split_fp32_to_tf32(float *dst, const float *src, size_t size);
void merge_tf32_to_fp32(float *dst, const float *src, size_t size);
void split_int16_to_int8(int8_t *dst, const int16_t *src, size_t size);
void merge_int8_to_int16(int16_t *dst, const int8_t *src, size_t size);
void create_mask_fp32(const float *src, size_t row, size_t col, ApeTrans trans, int8_t *mask);
int count_overflow_fp32f(const float *src, size_t row, size_t col);
int count_overflow_fp32f_strict(const float *src, size_t row, size_t col);
int count_overflow_int16emu(const int16_t *src, size_t row, size_t col);

__device__ inline double fmax(double a, double b) { return (a > b) ? a : b; }
__device__ inline double fabs(double a) { return (a > 0) ? a : -a; }

} // namespace ape
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# APE on CUDA

> **Integrated into cuBLAS 13.1**
> The floating-point emulation method in this project is now integrated as a basic feature of NVIDIA cuBLAS in CUDA 13.1 and later. You can read the details in the [cuBLAS documentation](https://docs.nvidia.com/cuda/cublas/#floating-point-emulation).

This project is an APE implementation on NVIDIA GPUs using the cuBLAS backend.

APE is a method of emulating high-bitwidth computation with low-bitwidth data types.
For example, APE can use $3$ or $6$ low-bitwidth Tensor Core computations to emulate one FP32 computation, with up to $5.3\times$ theoretical speedup.
This project provides the following:

* GEMM implementations using Tensor Cores with FP32-precision and various representation ranges.
* Auto-adapted algorithm selection that guarantees end-to-end correctness.
* INT16 GEMM implementation using Tensor Cores.

For more details, please see our [paper](https://dl.acm.org/doi/abs/10.1145/3524059.3532377).
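
In sketch form, the emulation splits each FP32 value into a high part and a scaled low part stored in a narrower type, then distributes the product over the parts. For the FP16-based FP32F algorithm (see `src/kernel/split_fp32_to_fp16.cu` and `src/gemm_fp32_fp32f.cpp`):

$$x \approx x_{hi} + 2^{-12} x_{lo}, \qquad x_{hi} = \mathrm{fp16}(x), \quad x_{lo} = \mathrm{fp16}\big((x - x_{hi}) \cdot 2^{12}\big)$$

$$A B \approx A_{hi} B_{hi} + 2^{-12} \left(A_{hi} B_{lo} + A_{lo} B_{hi}\right)$$

Each partial product is a single Tensor Core GEMM, and the $2^{-24}$-scaled term $A_{lo} B_{lo}$ is dropped; FP32B instead keeps three unscaled BF16 parts per value and accumulates six partial products.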
## Usage

### Build

```shell
mkdir build && cd build
cmake ..
make -j
```

### API

APE provides a BLAS-like API; to accelerate FP32 applications directly, users only need to include `ape.h`.

```c++
void apeGemmFP32(ApeTrans transa, ApeTrans transb, int m, int n, int k, const float *alpha, const float *A, int lda, const float *B, int ldb, const float *beta, float *C, int ldc, const ApeAlgo algo = APE_ALGO_AUTO);
```

FP32 GEMM supports $5$ algorithms:

* **APE_ALGO_AUTO**: Select the fastest algorithm without overflow.

* **APE_ALGO_AUTO_STRICT**: Select the fastest algorithm without overflow and underflow.

* **APE_ALGO_FP32F**: Use FP16 to emulate FP32. (1-bit precision loss, narrow representation range; overflow may occur.)

* **APE_ALGO_FP32B**: Use BF16 to emulate FP32. (No precision loss, large representation range; overflow does not occur.)

* **APE_ALGO_FP32T**: Use TF32 to emulate FP32. (1-bit precision loss, large representation range; overflow does not occur.)
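
A minimal usage sketch (here `d_A`, `d_B`, and `d_C` are assumed to be pre-allocated, column-major device buffers, following the conventions in `src/gemm_fp32_test.cpp`; they are not part of the API):

```c++
#include "ape.h"

// Sketch: d_A is m x k, d_B is k x n, d_C is m x n; all are device pointers
// allocated elsewhere, stored column-major as in cuBLAS.
void gemm_fp32b_example(const float *d_A, const float *d_B, float *d_C, int m, int n, int k) {
    // FP32B keeps three BF16 copies of each operand, so it needs a workspace
    // of (m*k + k*n) * 6 bytes (see the assert in src/gemm_fp32_fp32b.cpp).
    ape::apeInit((m * k + k * n) * 6);
    float alpha = 1.0f, beta = 0.0f;
    ape::apeGemmFP32(ape::APE_TRANS_N, ape::APE_TRANS_N, m, n, k, &alpha,
                     d_A, m, d_B, k, &beta, d_C, m, ape::APE_ALGO_FP32B);
}
```

The `apeInit` workspace size follows the asserts in `src/`: FP32F needs $(mk + kn) \cdot 4$ bytes, FP32B $(mk + kn) \cdot 6$, and FP32T $(mk + kn) \cdot 8$.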
```c++
void apeGemmINT16(ApeTrans transa, ApeTrans transb, int m, int n, int k, const int16_t *alpha, const int16_t *A, int lda, const int16_t *B, int ldb, const int32_t *beta, int32_t *C, int ldc, ApeAlgo algo = APE_ALGO_AUTO);
```

INT16 GEMM supports $2$ algorithms:

* **APE_ALGO_AUTO**: Select the algorithm without overflow.

* **APE_ALGO_INT16**: Use INT8 to emulate INT16. (The upper bound is $32639$, while native INT16's is $32767$; overflow may occur.)
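
Note the mixed scalar types: `alpha` and the inputs are `int16_t`, while `beta` and `C` are 32-bit (a known quirk; see the TODO in `src/gemm_int16_emu.cpp`). A minimal sketch along the same lines as above, with `d_A`, `d_B`, and `d_C` again assumed to be pre-allocated device buffers:

```c++
#include "ape.h"

// Sketch: d_A (m x k) and d_B (k x n) hold int16_t, d_C (m x n) holds int32_t.
void gemm_int16_example(const int16_t *d_A, const int16_t *d_B, int32_t *d_C, int m, int n, int k) {
    // INT16 emulation splits each operand into two INT8 planes, so the
    // workspace is (m*k + k*n) * 2 bytes (see the assert in src/gemm_int16_emu.cpp).
    ape::apeInit((m * k + k * n) * 2);
    int16_t alpha = 1;
    int32_t beta = 0;
    ape::apeGemmINT16(ape::APE_TRANS_N, ape::APE_TRANS_N, m, n, k, &alpha,
                      d_A, m, d_B, k, &beta, d_C, m, ape::APE_ALGO_INT16);
}
```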
## Authors
- [Zixuan Ma](https://github.com/JohndeVostok)
- [Yanzhuo Chen](https://github.com/yz-chen18)


## Citation

Ma, Zixuan, et al. "Efficiently emulating high-bitwidth computation with low-bitwidth hardware." Proceedings of the 36th ACM International Conference on Supercomputing. 2022.

If you find this work useful in your research, please cite it using the following BibTeX:

```bibtex
@inproceedings{ma2022efficiently,
  author = {Ma, Zixuan and Wang, Haojie and Feng, Guanyu and Zhang, Chen and Xie, Lei and He, Jiaao and Chen, Shengqi and Zhai, Jidong},
  title = {Efficiently Emulating High-Bitwidth Computation with Low-Bitwidth Hardware},
  year = {2022},
  isbn = {9781450392815},
  publisher = {Association for Computing Machinery},
  address = {New York, NY, USA},
  url = {https://doi.org/10.1145/3524059.3532377},
  doi = {10.1145/3524059.3532377},
  booktitle = {Proceedings of the 36th ACM International Conference on Supercomputing},
  articleno = {5},
  numpages = {12},
  keywords = {emulation, tensor core, domain specific accelerator},
  location = {Virtual Event},
  series = {ICS '22}
}
```
--------------------------------------------------------------------------------
/include/common.h:
--------------------------------------------------------------------------------
#pragma once

#include <cassert>
#include <cublas_v2.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <curand.h>
#include <iostream>
#include <string>

#define ape_error(str) __ape_error(str, __FILE__, __LINE__)
#define ape_warning(str) __ape_warning(str, __FILE__, __LINE__)
#define ape_info(str) __ape_info(str, __FILE__, __LINE__)
#define cudaSafeCall(err) __cudaSafeCall(err, __FILE__, __LINE__)
#define cudaCheckError() __cudaCheckError(__FILE__, __LINE__)
#define cublasSafeCall(err) __cublasSafeCall(err, __FILE__, __LINE__)
#define curandSafeCall(err) __curandSafeCall(err, __FILE__, __LINE__)

inline void __ape_error(std::string str, const char *file, const int line) {
    std::cout << "[ERROR] " << file << "::" << line << " " << str << std::endl;
    exit(-1);
}

inline void __ape_warning(std::string str, const char *file, const int line) {
    std::cout << "[WARNING] " << file << "::" << line << " " << str << std::endl;
#if DEBUG
    exit(-1);
#endif
}

inline void __ape_info(std::string str, const char *file, const int line) {
    std::cout << "[INFO] " << file << "::" << line << " " << str << std::endl;
}

inline void __cudaSafeCall(cudaError err, const char *file, const int line) {
    if (err != cudaSuccess) {
        std::cout << "[ERROR] " << file << "::" << line << ": cudaSafeCall() failed. " << cudaGetErrorString(err) << std::endl;
        exit(-1);
    }
    return;
}

inline void __cudaCheckError(const char *file, const int line) {
    auto err = cudaGetLastError();
    if (err != cudaSuccess) {
        std::cout << "[ERROR] " << file << "::" << line << ": cudaCheckError() failed. " << cudaGetErrorString(err)
                  << std::endl;
        exit(-1);
    }

#ifdef DEBUG
    // This checking will affect performance.
    err = cudaDeviceSynchronize();
    if (err != cudaSuccess) {
        std::cout << "[ERROR] " << file << "::" << line << ": cudaCheckError() with sync failed.
" << cudaGetErrorString(err) 56 | << std::endl; 57 | exit(-1); 58 | } 59 | #endif 60 | 61 | return; 62 | } 63 | 64 | inline const char *cublasGetErrorString(cublasStatus_t err) { 65 | switch (err) { 66 | case CUBLAS_STATUS_SUCCESS: 67 | return "CUBLAS_STATUS_SUCCESS"; 68 | case CUBLAS_STATUS_NOT_INITIALIZED: 69 | return "CUBLAS_STATUS_NOT_INITIALIZED"; 70 | case CUBLAS_STATUS_ALLOC_FAILED: 71 | return "CUBLAS_STATUS_ALLOC_FAILED"; 72 | case CUBLAS_STATUS_INVALID_VALUE: 73 | return "CUBLAS_STATUS_INVALID_VALUE"; 74 | case CUBLAS_STATUS_ARCH_MISMATCH: 75 | return "CUBLAS_STATUS_ARCH_MISMATCH"; 76 | case CUBLAS_STATUS_MAPPING_ERROR: 77 | return "CUBLAS_STATUS_MAPPING_ERROR"; 78 | case CUBLAS_STATUS_EXECUTION_FAILED: 79 | return "CUBLAS_STATUS_EXECUTION_FAILED"; 80 | case CUBLAS_STATUS_INTERNAL_ERROR: 81 | return "CUBLAS_STATUS_INTERNAL_ERROR"; 82 | case CUBLAS_STATUS_NOT_SUPPORTED: 83 | return "CUBLAS_STATUS_NOT_SUPPORTED"; 84 | case CUBLAS_STATUS_LICENSE_ERROR: 85 | return "CUBLAS_STATUS_LICENSE_ERROR"; 86 | } 87 | return ""; 88 | } 89 | 90 | inline void __cublasSafeCall(cublasStatus_t err, const char *file, const int line) { 91 | if (err != CUBLAS_STATUS_SUCCESS) { 92 | std::cout << "[ERROR]" << file << "::" << line << ": cublasSafeCall() failed. " << cublasGetErrorString(err) 93 | << std::endl; 94 | exit(-1); 95 | } 96 | } 97 | 98 | inline void __curandSafeCall(curandStatus_t err, const char *file, const int line) { 99 | if (err != CURAND_STATUS_SUCCESS) { 100 | std::cout << "[ERROR]" << file << "::" << line << ": curandSafeCall() failed. " << err << std::endl; 101 | exit(-1); 102 | } 103 | } 104 | 105 | namespace ape { 106 | 107 | class APEHandler { 108 | private: 109 | APEHandler() {} 110 | static APEHandler *instance; 111 | cublasHandle_t ape_cublas_handle; 112 | void *buf; 113 | size_t buf_size; 114 | 115 | public: 116 | static inline APEHandler *getInstance() { 117 | static APEHandler instance; 118 | return &instance; 119 | } 120 | static inline cublasHandle_t getCublasHandle() { return getInstance()->ape_cublas_handle; } 121 | static inline void initCublas() { 122 | cublasHandle_t handle; 123 | cublasSafeCall(cublasCreate(&handle)); 124 | cublasSafeCall(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); 125 | getInstance()->ape_cublas_handle = handle; 126 | } 127 | static inline void initBuffer(size_t buf_size) { 128 | void *buf; 129 | cudaSafeCall(cudaMalloc((void **)&buf, buf_size)); 130 | getInstance()->buf_size = buf_size; 131 | getInstance()->buf = buf; 132 | } 133 | static inline size_t getBufSize() { return getInstance()->buf_size; } 134 | static inline void *getBuf() { return getInstance()->buf; } 135 | }; 136 | 137 | } // namespace ape -------------------------------------------------------------------------------- /test/test_create_mask_fp32.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "kernel.h" 4 | 5 | void host_create_mask_fp32(const float *src, int m, int n, ape::ApeTrans trans, int8_t *mask) { 6 | int row_blocks = (m - 1) / ape::AUTO_BLOCK + 1; 7 | int col_blocks = (n - 1) / ape::AUTO_BLOCK + 1; 8 | 9 | for (int block_m = 0; block_m < row_blocks; block_m++) { 10 | for (int block_n = 0; block_n < col_blocks; block_n++) { 11 | if (trans == ape::APE_TRANS_T) { 12 | mask[block_m * col_blocks + block_n] = 1; 13 | } else { 14 | mask[block_n * row_blocks + block_m] = 1; 15 | } 16 | for (int i = 0; i < ape::AUTO_BLOCK; i++) { // row 17 | for (int j = 0; j < ape::AUTO_BLOCK; j++) { // col 18 | if (trans == 
--------------------------------------------------------------------------------
/test/test_create_mask_fp32.cpp:
--------------------------------------------------------------------------------
1 | #include <iostream>
2 | 
3 | #include "kernel.h"
4 | 
5 | void host_create_mask_fp32(const float *src, int m, int n, ape::ApeTrans trans, int8_t *mask) {
6 |     int row_blocks = (m - 1) / ape::AUTO_BLOCK + 1;
7 |     int col_blocks = (n - 1) / ape::AUTO_BLOCK + 1;
8 | 
9 |     for (int block_m = 0; block_m < row_blocks; block_m++) {
10 |         for (int block_n = 0; block_n < col_blocks; block_n++) {
11 |             if (trans == ape::APE_TRANS_T) {
12 |                 mask[block_m * col_blocks + block_n] = 1; // assume the block is safe; cleared below if any value is out of range
13 |             } else {
14 |                 mask[block_n * row_blocks + block_m] = 1;
15 |             }
16 |             for (int i = 0; i < ape::AUTO_BLOCK; i++) {     // row
17 |                 for (int j = 0; j < ape::AUTO_BLOCK; j++) { // col
18 |                     if (trans == ape::APE_TRANS_T) {
19 |                         int index = (i + block_m * ape::AUTO_BLOCK) * n + block_n * ape::AUTO_BLOCK + j;
20 |                         if (index >= m * n)
21 |                             goto kernelend;
22 |                         if (src[index] < ape::FP32F_MIN || src[index] > ape::FP32F_MAX) {
23 |                             mask[block_m * col_blocks + block_n] = 0;
24 |                             goto kernelend;
25 |                         }
26 |                     } else {
27 |                         int index = (j + block_n * ape::AUTO_BLOCK) * m + block_m * ape::AUTO_BLOCK + i;
28 |                         if (index >= m * n)
29 |                             goto kernelend;
30 |                         if (src[index] < ape::FP32F_MIN || src[index] > ape::FP32F_MAX) {
31 |                             mask[block_n * row_blocks + block_m] = 0;
32 |                             goto kernelend;
33 |                         }
34 |                     }
35 |                 }
36 |             }
37 | 
38 |         kernelend:
39 |             continue;
40 |         }
41 |     }
42 | }
43 | 
44 | //#define DEBUG
45 | //#define RESULT
46 | 
47 | int main() {
48 |     float *d_array, *h_array;
49 |     int8_t *h_mask, *d_mask;
50 | 
51 |     curandGenerator_t gen;
52 |     curandSafeCall(curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT));
53 | 
54 |     cudaEvent_t st, ed;
55 |     cudaEventCreate(&st);
56 |     cudaEventCreate(&ed);
57 | 
58 |     for (int size = 128; size <= 8192; size <<= 1) {
59 |         int num_blocks = (size - 1) / ape::AUTO_BLOCK + 1;
60 |         h_array = (float *)malloc(sizeof(float) * size * size);
61 |         h_mask = (int8_t *)malloc(sizeof(int8_t) * (num_blocks * num_blocks));
62 |         int8_t *tmp = (int8_t *)malloc(sizeof(int8_t) * (num_blocks * num_blocks));
63 |         memset(h_mask, 0, sizeof(int8_t) * (num_blocks * num_blocks));
64 |         cudaSafeCall(cudaMalloc((void **)&d_array, sizeof(float) * size * size));
65 |         curandSafeCall(curandGenerateUniform(gen, d_array, size * size));
66 |         cudaSafeCall(cudaMemcpy(h_array, d_array, size * size * sizeof(float), cudaMemcpyDeviceToHost));
67 |         cudaSafeCall(cudaMalloc((void **)&d_mask, sizeof(int8_t) * (num_blocks * num_blocks)));
68 |         cudaSafeCall(cudaMemset(d_mask, 0, sizeof(int8_t) * (num_blocks * num_blocks)));
69 | 
70 | #ifdef DEBUG
71 |         for (int i = 0; i < size; i++) {
72 |             for (int j = 0; j < size; j++) {
73 |                 std::cout << h_array[i * size + j] << " ";
74 |             }
75 | 
76 |             std::cout << std::endl;
77 |         }
78 | #endif
79 | 
80 |         host_create_mask_fp32(h_array, size, size, ape::APE_TRANS_N, h_mask);
81 | 
82 |         for (int i = 0; i < 128; i++) // warm-up
83 |             ape::create_mask_fp32(d_array, size, size, ape::APE_TRANS_N, d_mask);
84 | 
85 |         float ms;
86 |         cudaEventRecord(st);
87 |         for (int i = 0; i < 128; i++) // timed runs
88 |             ape::create_mask_fp32(d_array, size, size, ape::APE_TRANS_N, d_mask);
89 |         cudaEventRecord(ed);
90 |         cudaEventSynchronize(st);
91 |         cudaEventSynchronize(ed);
92 |         cudaEventElapsedTime(&ms, st, ed);
93 | 
94 |         cudaSafeCall(cudaMemcpy(tmp, d_mask, sizeof(int8_t) * (num_blocks * num_blocks), cudaMemcpyDeviceToHost));
95 | 
96 |         int i;
97 |         for (i = 0; i < num_blocks * num_blocks; i++) {
98 |             if (tmp[i] != h_mask[i]) {
99 |                 std::cout << "unmatched at " << i << ": " << int(tmp[i]) << " " << int(h_mask[i]) << std::endl;
100 |                 break;
101 |             }
102 |         }
103 | 
104 | #ifdef RESULT
105 |         for (int i = 0; i < num_blocks; i++) {
106 |             for (int j = 0; j < num_blocks; j++) {
107 |                 std::cout << int(tmp[i * num_blocks + j]) << " " << int(h_mask[i * num_blocks + j]) << std::endl;
108 |             }
109 |         }
110 | #endif
111 | 
112 |         if (i == num_blocks * num_blocks)
113 |             std::cout << "correct" << std::endl;
114 | 
115 |         std::cout << size << "x" << size << ": " << size * size * sizeof(float) * 128.0 / (ms * 1024.0 * 1024.0) << " GB/s"
116 |                   << std::endl;
117 |         free(h_array);
118 |         free(h_mask);
119 |         free(tmp);
120 |         cudaFree(d_mask);
121 |         cudaFree(d_array);
122 |     }
123 |     cudaEventDestroy(st); cudaEventDestroy(ed); curandSafeCall(curandDestroyGenerator(gen));
124 |     return 0;
125 | }
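The reference checker above also pins down the mask layout: transposed operands get a row-block-major mask, non-transposed ones a column-block-major mask matching the column-major matrix. A standalone sketch of just that indexing follows; `mask_index` is a hypothetical helper introduced here, and `BLOCK` stands in for `ape::AUTO_BLOCK`, whose real value is defined in the project headers:

```c++
#include <cstdio>

// Assumed block size for this illustration only.
constexpr int BLOCK = 128;

// Map an element (row, col) of an m x n matrix to its mask slot.
int mask_index(int row, int col, int m, int n, bool transposed) {
    int row_blocks = (m - 1) / BLOCK + 1;
    int col_blocks = (n - 1) / BLOCK + 1;
    int bm = row / BLOCK, bn = col / BLOCK;
    // Transposed: row-block-major; otherwise column-block-major.
    return transposed ? bm * col_blocks + bn : bn * row_blocks + bm;
}

int main() {
    // Element (200, 300) of a 512 x 512 matrix lies in block (1, 2).
    std::printf("N layout: %d, T layout: %d\n",
                mask_index(200, 300, 512, 512, false),
                mask_index(200, 300, 512, 512, true));
}
```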
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | --------------------------------------------------------------------------------