├── .gitignore ├── .vscode ├── c_cpp_properties.json └── settings.json ├── CMakeLists.txt ├── README.md ├── docs └── 【CUDA编程】手撸大模型推理框架FasterLlama.md ├── fasterLlama ├── CMakeLists.txt ├── cuda │ ├── CMakeLists.txt │ ├── allocator.h │ ├── common.h │ ├── cuda_kernels.cuh │ ├── decoder_kernels.cu │ ├── decoder_kernels.cuh │ ├── decoding_kernels.cu │ ├── decoding_kernels.cuh │ ├── decoding_sampling.cu │ ├── open_decoder.cu │ └── utils.h ├── decoding_sampling.h ├── lib │ ├── libfldecoderkernel.so │ ├── libfldecodersampling.so │ ├── libfldecodingkernel.so │ └── libflopendecoder.so └── open_decoder.h └── samples ├── CMakeLists.txt ├── llama_fp16.cu ├── llama_fp32.cu └── test.cu /.gitignore: -------------------------------------------------------------------------------- 1 | # ignore all files in the build/ directory 2 | build/ 3 | -------------------------------------------------------------------------------- /.vscode/c_cpp_properties.json: -------------------------------------------------------------------------------- 1 | { 2 | "configurations": [ 3 | { 4 | "name": "Linux", 5 | "includePath": [ 6 | "${workspaceFolder}/**", 7 | "/usr/local/cuda-11.7/include/**" 8 | ], 9 | "defines": [], 10 | "cStandard": "c11", 11 | "cppStandard": "c++11", 12 | "intelliSenseMode": "linux-gcc-x64" 13 | } 14 | ], 15 | "version": 4 16 | } -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "files.associations": { 3 | "array": "cpp", 4 | "chrono": "cpp", 5 | "functional": "cpp", 6 | "istream": "cpp", 7 | "ostream": "cpp", 8 | "ratio": "cpp", 9 | "tuple": "cpp", 10 | "type_traits": "cpp", 11 | "utility": "cpp", 12 | "__functional_base": "cpp", 13 | "__functional_base_03": "cpp", 14 | "__hash_table": "cpp", 15 | "__tree": "cpp", 16 | "__tuple": "cpp", 17 | "algorithm": "cpp", 18 | "filesystem": "cpp", 19 | "limits": "cpp", 20 | "memory": "cpp", 21 | 
"random": "cpp", 22 | "string_view": "cpp", 23 | "__locale": "cpp", 24 | "__string": "cpp", 25 | "string": "cpp", 26 | "*.tcc": "cpp", 27 | "iosfwd": "cpp", 28 | "cstdint": "cpp", 29 | "cmath": "cpp", 30 | "sstream": "cpp", 31 | "stdexcept": "cpp" 32 | } 33 | } -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Top-level build script: sets the shared include paths and global definitions, then adds the fasterLlama library and samples subdirectories. 2 | 3 | cmake_minimum_required(VERSION 3.16 FATAL_ERROR) 4 | project(FasterLlama LANGUAGES CXX CUDA) 5 | # GPU architecture (compute capability 7.5) used by fasterLlama/cuda/CMakeLists.txt to build its -gencode flag. 6 | set(CMAKE_CUDA_ARCHITECTURES 75) 7 | # Include paths shared by every target: the repository root and the CUDA toolkit headers. 8 | set(COMMON_HEADER_DIRS 9 | ${PROJECT_SOURCE_DIR} 10 | ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} 11 | ) 12 | # NOTE(review): hard-coded CUDA 11.7 install path; it is not referenced in this file, so confirm whether samples/ needs it before keeping it. 13 | set(COMMON_LIB_DIRS 14 | /usr/local/cuda-11.7/lib64 15 | ) 16 | 17 | include_directories( 18 | ${COMMON_HEADER_DIRS} 19 | ) 20 | 21 | message(STATUS "Assign include directories (include_directories=${COMMON_HEADER_DIRS})") 22 | # NDEBUG is defined for all targets regardless of build configuration. 23 | add_compile_definitions(NDEBUG) 24 | 25 | add_subdirectory(fasterLlama) 26 | add_subdirectory(samples) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 使用 CUDA C++ 实现的大模型推理框架 FasterLLaMA 2 | 3 | ## 1 版本发布背景 4 | 在 FasterLLaMA v1.0 中,笔者提供了一个 Decoder 模块和一套推理方案 Decoding 模型,目前 FasterLLaMA v1.0 仅适配 LLaMA2,至于LLaMA3 及其他开源大模型的适配工作,将在后续版本逐步加入。其中,Decoder 相当于我们常说的 decoder layer;而 Decoding 则包含了整个解码的流程,包括词嵌入、解码层和采样解码等过程。 5 | 6 | 针对 Decoder 模块的 GEMM 场景,笔者提供了基于 cuBLAS 的 INT8 量化实现,对模型权重和激活值进行 INT8 量化,量化粒度均为 per-channel,通过 INT8 量化的矩阵运算可以高效地利用 GPU 中的 INT8 Tensor Core,在保证低精度损失的前提下,取得较好的加速比(对比 FP16 运算精度而言),要注意的是 FasterLLaMA v1.0 仅支持在计算能力不低于 7.5 的设备上运行。另外,`Q*K` 乘法和 `QK*V` 乘法部分在 v1.0 版本仍然还是使用的 FP32 类型,没有实现低精度量化。 7 | 8 | 针对 Decoding 模型的解码场景,笔者参考了 Faster Transformer,提供了两种基于采样解码的实现:top-k 解码和 top-p 解码。 9 | 10 | 数据类型方面,目前 FasterLLaMA v1.0 支持 FP32 和 FP16 两种类型,笔者针对 FP16 类型对相关 Kernel 函数模板进行了特化。 11 | 12 | 注意力机制方面,目前 FasterLLaMA v1.0 仅支持 MHA,计划在后续版本加入对 MQA 
和 GQA 的支持。 13 | 14 | ## 2 整体架构 15 | FasterLLaMA v1.0 基于 CUDA、cuBLAS、CUB 等 Nvidia 官方库实现,目前仅提供 C++ API,用户可以将它集成到本机 C++ 中构建的推理服务代码中。此外笔者还提供了一些简单的示例代码来演示如何在 C++ 中执行 Decoding 过程。 16 | 17 | 下面是 Decoder 模块的整体架构图: 18 | ![](https://mmbiz.qpic.cn/sz_mmbiz_png/GJUG0H1sS5qX4u3gKYjsOZ7r3ib6Jk02RkszQibYbxMpzTOPryIsOxonbFgQicponrNVqWCrIvZiasb0heJcevSic3g/640?wx_fmt=png&from=appmsg) 19 | 20 | 下面是 Decoding 模型的整体架构图: 21 | ![](https://mmbiz.qpic.cn/sz_mmbiz_png/GJUG0H1sS5okzmlo35c3o3ibDdV7jLkLp6WL1ibGpZemlnWpgZaXxJjeTicicbzK2bQu5gqfq6SUTRbYXx7ibKAtYwg/640?wx_fmt=png&from=appmsg) -------------------------------------------------------------------------------- /fasterLlama/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | set(FASTER_LLAMA_HEADER_DIR 2 | ${PROJECT_SOURCE_DIR}/fasterLlama 3 | ${PROJECT_SOURCE_DIR}/fasterLlama/cuda 4 | ) 5 | # Directory-scoped include paths: they apply to targets created here and in the cuda/ subdirectory added below. 6 | include_directories( 7 | ${FASTER_LLAMA_HEADER_DIR} 8 | ) 9 | 10 | add_subdirectory(cuda) 11 | 12 | -------------------------------------------------------------------------------- /fasterLlama/cuda/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | set(decoder_kernel_files 2 | decoder_kernels.cu 3 | ) 4 | 5 | set(decoding_kernel_files 6 | decoding_kernels.cu 7 | ) 8 | 9 | set(open_decoder_files 10 | open_decoder.cu 11 | ) 12 | 13 | set(decoding_sampling_files 14 | decoding_sampling.cu 15 | ) 16 | # Same two directories that fasterLlama/CMakeLists.txt already adds with include_directories(). 17 | set(FASTER_LLAMA_CUDA_HEADER_DIR 18 | ${PROJECT_SOURCE_DIR}/fasterLlama 19 | ${PROJECT_SOURCE_DIR}/fasterLlama/cuda 20 | ) 21 | 22 | include_directories( 23 | ${FASTER_LLAMA_CUDA_HEADER_DIR} 24 | ) 25 | # NOTE(review): this emits the built shared libraries into the source tree (fasterLlama/lib); a directory under the build tree would keep the checkout clean. 26 | set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${PROJECT_SOURCE_DIR}/fasterLlama/lib) 27 | 28 | message(STATUS "Assign fasterLlama include directories (include_directories=${FASTER_LLAMA_CUDA_HEADER_DIR})") 29 | message(STATUS "in fasterLlama cuda Assign arch (arch=${CMAKE_CUDA_ARCHITECTURES})") 30 | # One shared library per source list; each target is configured right after its add_library() call. 31 | add_library(fldecoderkernel SHARED ${decoder_kernel_files}) 
32 | # Build fldecoderkernel as C++14; the feature is PUBLIC, so targets linking to it 33 | # are built with at least C++14 too. The -gencode flag is wrapped in 34 | # $<COMPILE_LANGUAGE:CUDA> so it is only passed when compiling CUDA sources. 35 | target_compile_features(fldecoderkernel PUBLIC cxx_std_14) 36 | target_compile_options(fldecoderkernel PUBLIC "$<$<COMPILE_LANGUAGE:CUDA>:-gencode=arch=compute_${CMAKE_CUDA_ARCHITECTURES},code=sm_${CMAKE_CUDA_ARCHITECTURES}>") 37 | set_target_properties(fldecoderkernel PROPERTIES CUDA_SEPARABLE_COMPILATION ON) 38 | target_link_libraries(fldecoderkernel PUBLIC -lcublas -lcudart -lcurand) 39 | 40 | add_library(fldecodingkernel SHARED ${decoding_kernel_files}) 41 | # Build fldecodingkernel as C++14; the feature is PUBLIC, so targets linking to it 42 | # are built with at least C++14 too. The -gencode flag is wrapped in 43 | # $<COMPILE_LANGUAGE:CUDA> so it is only passed when compiling CUDA sources. 44 | target_compile_features(fldecodingkernel PUBLIC cxx_std_14) 45 | target_compile_options(fldecodingkernel PUBLIC "$<$<COMPILE_LANGUAGE:CUDA>:-gencode=arch=compute_${CMAKE_CUDA_ARCHITECTURES},code=sm_${CMAKE_CUDA_ARCHITECTURES}>") 46 | set_target_properties(fldecodingkernel PROPERTIES CUDA_SEPARABLE_COMPILATION ON) 47 | target_link_libraries(fldecodingkernel PUBLIC -lcublas -lcudart -lcurand) 48 | 49 | add_library(flopendecoder SHARED ${open_decoder_files}) 50 | # Build flopendecoder as C++14; the feature is PUBLIC, so targets linking to it 51 | # are built with at least C++14 too. The -gencode flag is wrapped in 52 | # $<COMPILE_LANGUAGE:CUDA> so it is only passed when compiling CUDA sources. 53 | target_compile_features(flopendecoder PUBLIC cxx_std_14) 54 | target_compile_options(flopendecoder PUBLIC "$<$<COMPILE_LANGUAGE:CUDA>:-gencode=arch=compute_${CMAKE_CUDA_ARCHITECTURES},code=sm_${CMAKE_CUDA_ARCHITECTURES}>") 55 | set_target_properties(flopendecoder PROPERTIES CUDA_SEPARABLE_COMPILATION ON) 56 | target_link_libraries(flopendecoder PUBLIC -lcublas -lcudart -lcurand fldecoderkernel) 57 | 58 | add_library(fldecodersampling SHARED ${decoding_sampling_files}) 59 | # Request that fldecodersampling be built with -std=c++14 60 | # As this is a public compile feature anything that links to fldecodersampling 61 | # will also build with 
-std=c++14 62 | target_compile_features(fldecodersampling PUBLIC cxx_std_14) 63 | target_compile_options(fldecodersampling PUBLIC "-gencode=arch=compute_${CMAKE_CUDA_ARCHITECTURES},code=sm_${CMAKE_CUDA_ARCHITECTURES}") 64 | set_target_properties(fldecodersampling PROPERTIES CUDA_SEPARABLE_COMPILATION ON) 65 | target_link_libraries(fldecodersampling PUBLIC -lcublas -lcudart -lcurand flopendecoder fldecodingkernel fldecoderkernel) -------------------------------------------------------------------------------- /fasterLlama/cuda/allocator.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "utils.h" 4 | #include 5 | #include 6 | 7 | namespace FasterLLaMA 8 | { 9 | 10 | /** 11 | * Pop current cuda device and set new device 12 | * i_device - device ID to set 13 | * o_device - device ID to pop 14 | * ret - return code (the same as cudaError_t) 15 | */ 16 | 17 | inline cudaError_t get_set_device(int i_device, int *o_device = NULL) 18 | { 19 | int current_dev_id = 0; 20 | cudaError_t err = cudaSuccess; 21 | 22 | if (o_device != NULL) 23 | { 24 | err = cudaGetDevice(¤t_dev_id); 25 | if (err != cudaSuccess) 26 | return err; 27 | if (current_dev_id == i_device) 28 | { 29 | *o_device = i_device; 30 | } 31 | else 32 | { 33 | err = cudaSetDevice(i_device); 34 | if (err != cudaSuccess) 35 | { 36 | return err; 37 | } 38 | *o_device = current_dev_id; 39 | } 40 | } 41 | else 42 | { 43 | err = cudaSetDevice(i_device); 44 | if (err != cudaSuccess) 45 | { 46 | return err; 47 | } 48 | } 49 | 50 | return cudaSuccess; 51 | } 52 | 53 | enum class AllocatorType 54 | { 55 | CUDA, 56 | TF, 57 | TH 58 | }; 59 | 60 | class IAllocator 61 | { 62 | public: 63 | virtual void *malloc(size_t size, const bool is_set_zero = true) const = 0; 64 | virtual void free(void *ptr) const = 0; 65 | }; 66 | 67 | template 68 | class Allocator; 69 | 70 | template <> 71 | class Allocator : public IAllocator 72 | { 73 | const int device_id_; 74 | 75 | 
public: 76 | Allocator(int device_id) : device_id_(device_id) {} 77 | 78 | // Allocate size bytes on device device_id_ and return the device pointer. 79 | // The buffer is zero-filled unless is_set_zero is false; the device that was 80 | // current on entry is made current again before returning. 81 | void *malloc(size_t size, const bool is_set_zero = true) const override 82 | { 83 | void *ptr = nullptr; 84 | int o_device = 0; 85 | CHECK_CUDA_ERROR(get_set_device(device_id_, &o_device)); 86 | CHECK_CUDA_ERROR(cudaMalloc(&ptr, size)); 87 | // cudaMalloc does not clear memory, so zero it here when requested. 88 | if (is_set_zero && ptr != nullptr) 89 | { 90 | CHECK_CUDA_ERROR(cudaMemset(ptr, 0, size)); 91 | } 92 | CHECK_CUDA_ERROR(get_set_device(o_device)); 93 | return ptr; 94 | } 95 | 96 | // Release ptr with cudaFree while device_id_ is current, then switch back to the previously current device. 97 | void free(void *ptr) const override 98 | { 99 | int o_device = 0; 100 | CHECK_CUDA_ERROR(get_set_device(device_id_, &o_device)); 101 | CHECK_CUDA_ERROR(cudaFree(ptr)); 102 | CHECK_CUDA_ERROR(get_set_device(o_device)); 103 | } 104 | }; 105 | 106 | } // namespace FasterLLaMA 107 | -------------------------------------------------------------------------------- /fasterLlama/cuda/common.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | namespace FasterLLaMA 7 | { 8 | 9 | enum class OperationType 10 | { 11 | FP32, 12 | FP16, 13 | INT8 14 | }; 15 | 16 | template 17 | struct ResNormWeight 18 | { 19 | T *gamma = nullptr; 20 | float eps = 1e-5f; 21 | }; 22 | 23 | template 24 | struct DenseWeight 25 | { 26 | WeightType *kernel = nullptr; 27 | T *bias = nullptr; 28 | float *weight_scale = nullptr; 29 | }; 30 | 31 | template 32 | struct AttentionWeight 33 | { 34 | DenseWeight query_weight; 35 | DenseWeight key_weight; 36 | DenseWeight value_weight; 37 | DenseWeight attention_output_weight; 38 | }; 39 | 40 | template 41 | struct FFNWeight 42 | { 43 | DenseWeight w1_weight; 44 | DenseWeight w2_weight; 45 | DenseWeight w3_weight; 46 | }; 47 | 48 | template 49 | class TransformerTraits; 50 | 51 | template <> 52 | class TransformerTraits 53 | { 54 | public: 55 | typedef int8_t DataType; 56 | typedef int32_t AlphaType; 57 | static const OperationType OpType = OperationType::INT8; 58 | static cublasComputeType_t const computeType = CUBLAS_COMPUTE_32I; 59 | static cudaDataType_t const AType = CUDA_R_8I; 60 | static 
cudaDataType_t const BType = CUDA_R_8I; 61 | static cudaDataType_t const CType = CUDA_R_32I; 62 | }; 63 | /* FP32 traits: float data and alpha types, CUDA_R_32F for A/B/C. NOTE(review): CUBLAS_COMPUTE_32F_FAST_16F presumably lets cuBLAS use reduced-precision math for FP32 GEMMs - confirm this is intended. */ 64 | template <> 65 | class TransformerTraits 66 | { 67 | public: 68 | typedef float DataType; 69 | typedef float AlphaType; 70 | static const OperationType OpType = OperationType::FP32; 71 | static cublasComputeType_t const computeType = CUBLAS_COMPUTE_32F_FAST_16F; 72 | static cudaDataType_t const AType = CUDA_R_32F; 73 | static cudaDataType_t const BType = CUDA_R_32F; 74 | static cudaDataType_t const CType = CUDA_R_32F; 75 | }; 76 | /* FP16 traits: half data and alpha types, CUBLAS_COMPUTE_16F, CUDA_R_16F for A/B/C. */ 77 | template <> 78 | class TransformerTraits 79 | { 80 | public: 81 | typedef half DataType; 82 | typedef half AlphaType; 83 | static const OperationType OpType = OperationType::FP16; 84 | static cublasComputeType_t const computeType = CUBLAS_COMPUTE_16F; 85 | static cudaDataType_t const AType = CUDA_R_16F; 86 | static cudaDataType_t const BType = CUDA_R_16F; 87 | static cudaDataType_t const CType = CUDA_R_16F; 88 | }; 89 | /* DecoderTransformerTraits: each specialization below derives from TransformerTraits and adds no members. */ 90 | template 91 | class DecoderTransformerTraits; 92 | 93 | template <> 94 | class DecoderTransformerTraits : public TransformerTraits 95 | { 96 | }; 97 | 98 | template <> 99 | class DecoderTransformerTraits : public TransformerTraits 100 | { 101 | }; 102 | 103 | template <> 104 | class DecoderTransformerTraits : public TransformerTraits 105 | { 106 | }; 107 | 108 | } /* namespace FasterLLaMA */ 109 | -------------------------------------------------------------------------------- /fasterLlama/cuda/cuda_kernels.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "common.h" 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | namespace FasterLLaMA 10 | { 11 | /* Declaration only; the definition is not visible in this part of the file. */ 12 | static inline __device__ int8_t float_to_int8_rn(float x); 13 | template 14 | struct SumOp 15 | { 16 | __device__ __forceinline__ T operator()(const T &a, const T &b) const { return a + b; } 17 | }; /* SumOp: binary functor returning a + b. */ 18 | /* MaxOp: binary functor returning max(a, b). */ 19 | template 20 | struct MaxOp 21 | { 22 | __device__ __forceinline__ T operator()(const T 
&a, const T &b) const { return max(a, b); } 23 | }; 24 | 25 | template