├── .gitignore ├── CMakeLists.txt ├── LICENSE ├── README.md ├── src ├── CMakeLists.txt ├── kernels │ ├── CMakeLists.txt │ ├── act_kernel.cu │ ├── act_kernel.h │ ├── attn_softmax_kernel.cu │ ├── attn_softmax_kernel.h │ ├── build_casual_mask.cu │ ├── build_casual_mask.h │ ├── cal_paddingoffset.cu │ ├── cal_paddingoffset.h │ ├── concat_past_kv.cu │ ├── concat_past_kv.h │ ├── cublas_utils.cc │ ├── cublas_utils.h │ ├── fused_addresidual_norm.cu │ ├── fused_addresidual_norm.h │ ├── fused_decoder_self_attention.cu │ ├── fused_decoder_self_attention.h │ ├── fused_transpose_and_remv_pad.cu │ ├── fused_transpose_and_remv_pad.h │ ├── input_embedding.cu │ ├── input_embedding.h │ ├── linear.cu │ ├── linear.h │ ├── qkv_bias_and_RoPE.cu │ ├── qkv_bias_and_RoPE.h │ ├── repeat_kv.cu │ ├── repeat_kv.h │ ├── topK.cu │ └── topK.h ├── memory │ └── allocator │ │ ├── base_allocator.h │ │ └── cuda_allocator.h ├── models │ ├── basemodel.h │ ├── common_params.h │ └── tokenizer.h ├── utils │ ├── CMakeLists.txt │ ├── cuda_debug_utils.cuh │ ├── debug_utils.h │ ├── macro.h │ ├── model_utils.h │ ├── params.h │ ├── string_utils.h │ ├── tensor.h │ ├── vectorize_utils.h │ ├── weight_utils.cu │ └── weight_utils.h └── weights │ ├── CMakeLists.txt │ ├── base_weights.h │ ├── llama │ ├── CMakeLists.txt │ ├── attention_weights.h │ ├── embedding_weights.h │ ├── ffn_weights.h │ ├── layer_weights.cc │ ├── layer_weights.h │ ├── llama_weights.cc │ ├── llama_weights.h │ └── norm_weights.h │ └── weight.h ├── tests ├── CMakeLists.txt └── unittests │ ├── CMakeLists.txt │ ├── test_act.cu │ ├── test_bias_and_RoPE.cu │ ├── test_bmm.cu │ ├── test_cal_paddingoffset.cu │ ├── test_casual_mask.cu │ ├── test_concat_kv.cu │ ├── test_fused_addresidual_norm.cu │ ├── test_fused_decoder_attention.cu │ ├── test_fused_trans_remv_pad.cu │ ├── test_input_embedding.cu │ ├── test_linear.cu │ ├── test_mask_softmax.cu │ └── test_repeat_kv.cu └── tools ├── convert_downloaded_llama_weights.py └── weights_convert.py /.gitignore: -------------------------------------------------------------------------------- 1 | build/ -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.8 FATAL_ERROR) 2 | project(InferSpore LANGUAGES CXX CUDA) 3 | 4 | find_package(CUDA 10.0 REQUIRED) 5 | 6 | set(CUDA_PATH ${CUDA_TOOLKIT_ROOT_DIR}) 7 | 8 | 9 | list(APPEND CMAKE_MODULE_PATH ${CUDA_PATH}/lib64) 10 | find_package(CUDA REQUIRED) 11 | 12 | # setting compiler flags 13 | set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}") 14 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") 15 | set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler -Wall") 16 | 17 | set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} \ 18 | -gencode=arch=compute_70,code=\\\"sm_70,compute_70\\\" \ 19 | -gencode=arch=compute_75,code=\\\"sm_75,compute_75\\\" \ 20 | -gencode=arch=compute_80,code=\\\"sm_80,compute_80\\\" \ 21 | -gencode=arch=compute_86,code=\\\"sm_86,compute_86\\\" \ 22 | ") 23 | # -rdc=true") # not sure the effect of this option, retain it temply 24 | 25 | set(CMAKE_CUDA_ARCHITECTURES 70 75 80 86) 26 | message("-- Assign GPU architecture (sm=70 75 80 86)") 27 | 28 | set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} -Wall -O0") 29 | set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -Wall -O0") 30 | set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -O0 -G -Xcompiler -Wall") 31 | 32 | message(STATUS "CMAKE_CXX_FLAGS" ${CMAKE_CXX_FLAGS}) 33 | 34 | set(CMAKE_CXX_STANDARD 11) 35 
| set(CMAKE_CXX_STANDARD_REQUIRED ON) 36 | 37 | if(CMAKE_CXX_STANDARD STREQUAL "11") 38 | set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda") 39 | set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") 40 | set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --std=c++11") 41 | endif() 42 | 43 | set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3") 44 | set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -Xcompiler -O3") 45 | 46 | set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) 47 | set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) 48 | set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) 49 | 50 | set(COMMON_HEADER_DIRS 51 | ${PROJECT_SOURCE_DIR} 52 | ${CUDA_PATH}/include 53 | ) 54 | 55 | set(COMMON_LIB_DIRS 56 | ${CUDA_PATH}/lib64 57 | ) 58 | 59 | include_directories( 60 | ${COMMON_HEADER_DIRS} 61 | ) 62 | 63 | link_directories( 64 | ${COMMON_LIB_DIRS} 65 | ) 66 | option (PERF 67 | "measure model inference performance" 68 | OFF 69 | ) 70 | option (PRINT_DATA 71 | "print kernel output to debug" 72 | OFF 73 | ) 74 | option (SAVE_DATA 75 | "save kernel output to debug" 76 | OFF 77 | ) 78 | if (PERF) 79 | add_compile_options(-DPERF) 80 | endif() 81 | if (PRINT_DATA) 82 | add_compile_options(-DPRINT_DATA) 83 | endif() 84 | if (SAVE_DATA) 85 | add_compile_options(-DSAVE_DATA) 86 | endif() 87 | #cmake .. -DPRINT_DATA=ON && make 88 | #cmake .. -DPRINT_DATA=ON -DSAVE_DATA=ON && make 89 | #cmake .. -DPERF=ON && make 90 | #cmake .. && make 91 | file(GLOB_RECURSE LLM_CXX_SOURCES ${PROJECT_SOURCE_DIR}/src/*.cpp ${PROJECT_SOURCE_DIR}/src/*.cc) 92 | file(GLOB_RECURSE LLM_CUDA_SOURCES ${PROJECT_SOURCE_DIR}/src/*.cu) 93 | 94 | add_library(llmengine OBJECT 95 | ${LLM_CXX_SOURCES} 96 | ${LLM_CUDA_SOURCES} 97 | ) 98 | 99 | add_subdirectory(src) 100 | add_subdirectory(tests) 101 | # add_subdirectory(examples) 102 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 m0dulo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🌱 InferSpore 🧩 2 | A fully independent Large Language Model (LLM) inference engine, built leveraging cuBLAS and cub. 
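## Build

A minimal out-of-source build sketch, based on the options declared in the top-level CMakeLists.txt (CMake >= 3.8, a CUDA toolkit, and an sm_70/75/80/86 GPU are assumed; the command variants are the ones listed in that file's comments):

```bash
mkdir -p build && cd build
cmake .. && make                                   # default build
cmake .. -DPRINT_DATA=ON && make                   # print kernel outputs for debugging
cmake .. -DPRINT_DATA=ON -DSAVE_DATA=ON && make    # additionally save kernel outputs
cmake .. -DPERF=ON && make                         # measure model inference performance
```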
3 | -------------------------------------------------------------------------------- /src/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(weights) 2 | add_subdirectory(kernels) 3 | add_subdirectory(utils) -------------------------------------------------------------------------------- /src/kernels/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_library(embeddingFunctor STATIC input_embedding.cu) 2 | set_property(TARGET embeddingFunctor PROPERTY CUDA_SEPARABLE_COMPILATION ON) 3 | set_property(TARGET embeddingFunctor PROPERTY POSITION_INDEPENDENT_CODE ON) 4 | set_property(TARGET embeddingFunctor PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) 5 | 6 | add_library(cal_paddingoffset STATIC cal_paddingoffset.cu) 7 | set_property(TARGET cal_paddingoffset PROPERTY CUDA_SEPARABLE_COMPILATION ON) 8 | set_property(TARGET cal_paddingoffset PROPERTY POSITION_INDEPENDENT_CODE ON) 9 | set_property(TARGET cal_paddingoffset PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) 10 | 11 | add_library(build_casual_mask STATIC build_casual_mask.cu) 12 | set_property(TARGET build_casual_mask PROPERTY CUDA_SEPARABLE_COMPILATION ON) 13 | set_property(TARGET build_casual_mask PROPERTY POSITION_INDEPENDENT_CODE ON) 14 | set_property(TARGET build_casual_mask PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) 15 | 16 | add_library(cublasWrapper STATIC cublas_utils.cc) 17 | set_property(TARGET cublasWrapper PROPERTY POSITION_INDEPENDENT_CODE ON) 18 | set_property(TARGET cublasWrapper PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) 19 | 20 | add_library(linear STATIC linear.cu) 21 | set_property(TARGET linear PROPERTY CUDA_SEPARABLE_COMPILATION ON) 22 | set_property(TARGET linear PROPERTY POSITION_INDEPENDENT_CODE ON) 23 | set_property(TARGET linear PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) 24 | target_link_libraries(linear PUBLIC -lcudart -lcublas cublasWrapper) 25 | 26 | add_library(qkv_bias_and_rope STATIC qkv_bias_and_RoPE.cu) 27 | set_property(TARGET qkv_bias_and_rope PROPERTY CUDA_SEPARABLE_COMPILATION ON) 28 | set_property(TARGET qkv_bias_and_rope PROPERTY POSITION_INDEPENDENT_CODE ON) 29 | set_property(TARGET qkv_bias_and_rope PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) 30 | 31 | add_library(concat_kv STATIC concat_past_kv.cu) 32 | set_property(TARGET concat_kv PROPERTY CUDA_SEPARABLE_COMPILATION ON) 33 | set_property(TARGET concat_kv PROPERTY POSITION_INDEPENDENT_CODE ON) 34 | set_property(TARGET concat_kv PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) 35 | 36 | add_library(repeat_kv STATIC repeat_kv.cu) 37 | set_property(TARGET repeat_kv PROPERTY CUDA_SEPARABLE_COMPILATION ON) 38 | set_property(TARGET repeat_kv PROPERTY POSITION_INDEPENDENT_CODE ON) 39 | set_property(TARGET repeat_kv PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) 40 | 41 | add_library(mask_softmax STATIC attn_softmax_kernel.cu) 42 | set_property(TARGET mask_softmax PROPERTY CUDA_SEPARABLE_COMPILATION ON) 43 | set_property(TARGET mask_softmax PROPERTY POSITION_INDEPENDENT_CODE ON) 44 | set_property(TARGET mask_softmax PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) 45 | 46 | add_library(fused_transpose_and_remv_pad STATIC fused_transpose_and_remv_pad.cu) 47 | set_property(TARGET fused_transpose_and_remv_pad PROPERTY CUDA_SEPARABLE_COMPILATION ON) 48 | set_property(TARGET fused_transpose_and_remv_pad PROPERTY POSITION_INDEPENDENT_CODE ON) 49 | set_property(TARGET fused_transpose_and_remv_pad PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) 50 | 51 | add_library(fused_addresidual_norm 
STATIC fused_addresidual_norm.cu) 52 | set_property(TARGET fused_addresidual_norm PROPERTY CUDA_SEPARABLE_COMPILATION ON) 53 | set_property(TARGET fused_addresidual_norm PROPERTY POSITION_INDEPENDENT_CODE ON) 54 | set_property(TARGET fused_addresidual_norm PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) 55 | 56 | add_library(act STATIC act_kernel.cu) 57 | set_property(TARGET act PROPERTY CUDA_SEPARABLE_COMPILATION ON) 58 | set_property(TARGET act PROPERTY POSITION_INDEPENDENT_CODE ON) 59 | set_property(TARGET act PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) 60 | 61 | add_library(topk STATIC topK.cu) 62 | set_property(TARGET topk PROPERTY CUDA_SEPARABLE_COMPILATION ON) 63 | set_property(TARGET topk PROPERTY POSITION_INDEPENDENT_CODE ON) 64 | set_property(TARGET topk PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) 65 | 66 | add_library(fused_decoder_self_attention STATIC fused_decoder_self_attention.cu) 67 | set_property(TARGET fused_decoder_self_attention PROPERTY CUDA_SEPARABLE_COMPILATION ON) 68 | set_property(TARGET fused_decoder_self_attention PROPERTY POSITION_INDEPENDENT_CODE ON) 69 | set_property(TARGET fused_decoder_self_attention PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -------------------------------------------------------------------------------- /src/kernels/act_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include "src/kernels/act_kernel.h" 3 | #include "src/utils/cuda_debug_utils.cuh" 4 | #include "src/utils/macro.h" 5 | // fp32 silu version 6 | template 7 | __device__ __forceinline__ T silu(const T& in) { 8 | // x * sigmoid(x) 9 | return (T) (((float) in) / (1.0f + expf((float) -in))); 10 | } 11 | // fp16 silu version 12 | template<> 13 | __device__ __forceinline__ half2 silu(const half2& in) { 14 | return make_half2(__float2half(silu((float)(in.x))), __float2half(silu((float)(in.y)))); 15 | } 16 | 17 | // 代码逻辑:第一个intermediate 去做silu,其结果与第二个intermediate 做点乘mul 18 | template 19 | __global__ void silu_and_mul_kernel( 20 | T* out, // shape: [bs, intermedia size] 21 | const T* input, // shape: [bs, 2, intermedia size] 22 | const int intermedia_size) { 23 | const int batch_idx = blockIdx.x; 24 | // 循环处理intermedia_size,当线程不够时,使得也能处理完 25 | for (int idx = threadIdx.x; idx < intermedia_size; idx += blockDim.x) { 26 | // 第一个和第二个intermediate处于同一buffer: input 27 | // 根据shape索引第一个intermediate 28 | const T x = input[batch_idx * 2 * intermedia_size + idx]; 29 | // 根据shape索引第二个intermediate 30 | const T y = input[batch_idx * 2 * intermedia_size + intermedia_size + idx]; 31 | // 索引到了后做计算,把计算结果写回output 32 | out[batch_idx * intermedia_size + idx] = silu(x) * y; 33 | } 34 | } 35 | 36 | template<> 37 | __global__ void silu_and_mul_kernel( 38 | half* out, // [bs, intermedia size] 39 | const half* input, // [bs, 2, intermedia size] 40 | const int intermedia_size) { 41 | const int batch_idx = blockIdx.x; 42 | // 获取fp16的向量大小 43 | int vec_size = Vec::size; 44 | // 获取fp16的向量类型half2 45 | using Vec_t = typename Vec::Type; 46 | for (int idx = threadIdx.x * vec_size; idx < intermedia_size; idx += blockDim.x) { 47 | // 与fp32实现的不同在于 48 | // 1.向量化读取 49 | // 2.使用hmul2向量化计算 50 | // 3.向量化写入 51 | const Vec_t x = *reinterpret_cast(const_cast(&input[batch_idx * 2 * intermedia_size + idx])); 52 | const Vec_t y = *reinterpret_cast(const_cast(&input[batch_idx * 2 * intermedia_size + intermedia_size + idx])); 53 | *reinterpret_cast(&out[batch_idx * intermedia_size + idx]) = __hmul2(silu(x), y); 54 | } 55 | } 56 | 57 | template 58 | void launchAct(TensorWrapper* input, 
TensorWrapper* out) { 59 | int batch_size = input->shape[0]; 60 | // 预防性检查,主要是防止shape的定义写错,导致不是我们expect的,那就比较难debug了 61 | LLM_CHECK(input->shape[1] == 2); 62 | int intermedia_size = input->shape[2]; 63 | dim3 grid(batch_size); 64 | dim3 block(256); 65 | silu_and_mul_kernel<<>>(out->data, input->data, intermedia_size); 66 | // for debug,打印swiglu这个kernel的输出结果 67 | #ifdef PRINT_DATA 68 | print_data<<<1, 1>>>(out->data); 69 | #else 70 | #endif 71 | } 72 | // We must instancite the template, if not, will report linking issue 73 | template void launchAct(TensorWrapper* input, TensorWrapper* output); 74 | template void launchAct(TensorWrapper* input, TensorWrapper* output); 75 | -------------------------------------------------------------------------------- /src/kernels/act_kernel.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include "src/utils/tensor.h" 6 | #include "src/utils/vectorize_utils.h" 7 | 8 | template 9 | void launchAct(TensorWrapper* input, TensorWrapper* out); -------------------------------------------------------------------------------- /src/kernels/attn_softmax_kernel.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include "src/utils/tensor.h" 6 | #include "src/utils/vectorize_utils.h" 7 | template 8 | void launchScaleMaskAndSoftmax(TensorWrapper* qk, 9 | TensorWrapper* mask, 10 | TensorWrapper* attn_score, 11 | float scale); -------------------------------------------------------------------------------- /src/kernels/build_casual_mask.cu: -------------------------------------------------------------------------------- 1 | #include "src/kernels/build_casual_mask.h" 2 | // 此算子仅使用在context decoder阶段,用于遮盖掉seq当前位置之后的信息,防止模型使用未来的信息 3 | // 而self decoder是一个自回归模型,本就没有未来的信息 4 | // mask shape = [bs, max_q_len, max_k_len] 5 | template 6 | __global__ void BuildCausalMasksConsideringContextPastKV(T* mask, 7 | const int* q_lens, //input lens, shape=[batch size] 8 | const int* k_lens, //context lens, shape=[batch size] 9 | int max_q_len, // max(q_lens) 10 | int max_k_len){ // max(k_lens) 11 | int tid = threadIdx.x; 12 | // 核函数共分配了bs个block,可以方便得通过block id来访问q_lens和k_lens数组中的值 13 | // 一个block负责处理一个bs大小中的数 14 | int qlen = q_lens[blockIdx.x]; 15 | int klen = k_lens[blockIdx.x]; 16 | // 偏移一个bs大小的空间 17 | // 即blockIdx.x==0时,指向mask数组的开头;blockIdx.x==1时,指向mask数组偏移了max_q_len * max_k_len大小后的位置 18 | mask += blockIdx.x * max_q_len * max_k_len; 19 | // offset用于表示每个bs内部的偏移 20 | int offset = threadIdx.x; 21 | // note: this judgement confirms we dont exceed data boundry 22 | while (offset < max_q_len * max_k_len){ 23 | // 分别求出行号q和列号k 24 | int q = offset / max_k_len; 25 | int k = offset % max_k_len; 26 | // k考虑了多轮对话的上下文序列,但设置mask时 k >= klen - qlen 将旧序列一并遮去了 27 | // 下图为支持多轮对话的mask 28 | // 1 1 1 | 1 -inf -inf 29 | // 1 1 1 | 1 1 -inf 30 | // 1 1 1 | 1 1 1 31 | // 下图为不支持多轮对话的mask 32 | // -inf -inf -inf | 1 -inf -inf 33 | // -inf -inf -inf | 1 1 -inf 34 | // -inf -inf -inf | 1 1 1 35 | // "|"符号前表示旧的对话序列,符号后表示当前轮的对话序列 36 | bool is_one = q < qlen && k < klen && k <= q + (klen - qlen) && k >= klen - qlen; 37 | mask[offset] = static_cast(is_one); 38 | 39 | // 保证遍历完一个bs中所有的空间 40 | offset += blockDim.x; 41 | } 42 | } 43 | 44 | template 45 | void launchBuildCausalMasks(TensorWrapper* mask, 46 | TensorWrapper* q_lens, 47 | TensorWrapper* k_lens) 48 | { 49 | int batch_size = mask->shape[0]; 50 | int max_q_len = mask->shape[1]; 51 | 
int max_k_len = mask->shape[2]; 52 | // XuLin-1017: max_q_len and max_k_len here are external inputs, computed beforehand as the maxima over the batch 53 | BuildCausalMasksConsideringContextPastKV<<>>(mask->data, q_lens->data, k_lens->data, max_q_len, max_k_len); 54 | } 55 | 56 | template void launchBuildCausalMasks(TensorWrapper* mask, 57 | TensorWrapper* q_lens, 58 | TensorWrapper* k_lens); 59 | 60 | template void launchBuildCausalMasks(TensorWrapper* mask, 61 | TensorWrapper* q_lens, 62 | TensorWrapper* k_lens); 63 | -------------------------------------------------------------------------------- /src/kernels/build_casual_mask.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include "src/utils/tensor.h" 6 | #include "src/utils/macro.h" 7 | template 8 | void launchBuildCausalMasks(TensorWrapper* mask, 9 | TensorWrapper* q_lens, 10 | TensorWrapper* k_lens); -------------------------------------------------------------------------------- /src/kernels/cal_paddingoffset.cu: -------------------------------------------------------------------------------- 1 | #include "src/kernels/cal_paddingoffset.h" 2 | // shape: 3 | //seq_lengths:[batch size] 4 | //cum_seqlens:[batch size + 1],first ele is 0 5 | //padding_offset:[batch size * max q len] 6 | // note: the point is to calc padding offset and cum offset 7 | // TODO: we first use serial algo, then can enhance to CUDA scan algo 8 | 9 | // bs = 3, seqlen = [3,2,5], max_q_len = 5 10 | // 1 1 1 0 0 11 | // 1 1 0 0 0 12 | // 1 1 1 1 1 13 | // cum_seqlens: [0,3,5,10] 14 | // padding_offset (one entry per real token) is 15 | // 0 0 0 16 | // 2 2 17 | // 5 5 5 5 5 18 | __global__ void CalPaddingoffset(int* padding_offset, 19 | int* cum_seqlens, 20 | const int* input_lengths, //actual input lens 21 | const int batch_size, 22 | const int max_q_len) { 23 | int ind = 0; 24 | int cum_offset = 0; 25 | int total_seqlen = 0; 26 | // iterate over every sequence in the batch 27 | for(int b = 0; b < batch_size; b++) { 28 | // length of the current sequence 29 | int seqlen = input_lengths[b]; 30 | // cumulative sequence length so far 31 | cum_seqlens[b] = total_seqlen; 32 | // iterate over all token positions of this sequence 33 | // each token in one seq has same cum offset 34 | for (int i = 0; i < seqlen; i++) { 35 | // ind indexes tokens; every token gets its own padding offset 36 | padding_offset[ind] = cum_offset; 37 | ind++; 38 | } 39 | // accumulate the padding offset and the total sequence length 40 | cum_offset += max_q_len - seqlen; 41 | total_seqlen += seqlen; 42 | } 43 | // note the shape of cum_seqlens: append the final cumulative length (the total length) 44 | cum_seqlens[batch_size] = total_seqlen; 45 | } 46 | 47 | // After attention, padding can be removed conveniently. 48 | // The padding operation is tied to the seq-len dimension, so it should be added after computations that do not involve this dimension. 49 | void launchCalPaddingoffset(TensorWrapper* padding_offset, 50 | TensorWrapper* cum_seqlens, 51 | TensorWrapper* input_lengths)//actual input lens 52 | { 53 | const int batch_size = padding_offset->shape[0]; 54 | const int max_q_len = padding_offset->shape[1]; 55 | LLM_CHECK_WITH_INFO(batch_size == input_lengths->shape[0], "input lengths numbers should equal to padding offset bs dim!") ; 56 | LLM_CHECK_WITH_INFO(batch_size == cum_seqlens->shape[0] - 1, "cum seqlen numbers should equal to padding offset bs dim + 1!") ; 57 | CalPaddingoffset<<<1, 1>>>( 58 | padding_offset->data, cum_seqlens->data, input_lengths->data, batch_size, max_q_len 59 | ); 60 | } -------------------------------------------------------------------------------- /src/kernels/cal_paddingoffset.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include "src/utils/macro.h" 6 |
#include "src/utils/tensor.h" 7 | 8 | void launchCalPaddingoffset(TensorWrapper* padding_offset, 9 | TensorWrapper* cum_seqlens, 10 | TensorWrapper* input_lengths //actual input lens 11 | ); -------------------------------------------------------------------------------- /src/kernels/concat_past_kv.cu: -------------------------------------------------------------------------------- 1 | // k/v shape = [bs, kv_head num, max_q_len, head size] // 为什么这里不是max_k_len?因为q k v=w * x,此时x中seqlen维度为max_q_len 2 | // kv cache shape = [num layers, bs, kv_head num, max_seq_len, head size] = >[bs, kv_head num, seqlen[history_len: history_len + max q len] , head size] 3 | 4 | #include "src/kernels/concat_past_kv.h" 5 | #include "src/utils/cuda_debug_utils.cuh" 6 | #include 7 | template 8 | __global__ void append_key_cache(T *k_dst, // [num layers, bs, kv head num, max_q_len, head size] 9 | const size_t layer_offset, 10 | const T *k_src, // [bs, kv_head num, max_q_len, head size] 11 | const int kv_head_num, 12 | const int head_size, 13 | const int *cur_query_length, 14 | const int *history_length, 15 | const int max_q_len, 16 | const int max_seq_len) 17 | { 18 | int batch_id = blockIdx.y; 19 | int head_id = blockIdx.z; 20 | int tid = threadIdx.x; 21 | int token_id = blockIdx.x; 22 | 23 | // 指针偏移到当前layer的k cache 24 | T *k_cache_dst = k_dst + layer_offset; 25 | int cur_seq_len = cur_query_length[batch_id]; 26 | int cumsum_seq_len = history_length[batch_id]; 27 | // note: the if judge is a must, because the max_q_len is GTE than cur_seq_len. 28 | if (token_id < cur_seq_len) 29 | { 30 | // [batch, head num, max_q_len, head size] -> [batch, head num, maxseqlen[cumsum_seq_len:cumsum_seq_len + max q len], head size] 31 | int src_offset = batch_id * kv_head_num * max_q_len * head_size + 32 | head_id * max_q_len * head_size + 33 | token_id * head_size + tid; 34 | int dst_offset = batch_id * kv_head_num * max_seq_len * head_size + 35 | head_id * max_seq_len * head_size + 36 | (cumsum_seq_len + token_id) * head_size + tid; 37 | k_cache_dst[dst_offset] = k_src[src_offset]; 38 | } 39 | } 40 | 41 | template 42 | __global__ void append_value_cache(T *v_dst, 43 | const size_t layer_offset, 44 | const T *v_src, 45 | const int kv_head_num, 46 | const int head_size, 47 | const int *cur_query_length, 48 | const int *history_length, 49 | const int max_q_len, 50 | const int max_seq_len) 51 | { 52 | int batch_id = blockIdx.y; 53 | int head_id = blockIdx.z; 54 | int tid = threadIdx.x; 55 | int token_id = blockIdx.x; 56 | 57 | // (m0dulo) notes:指针偏移到v cache在当前layer的起始地址 58 | T *v_cache_dst = v_dst + layer_offset; 59 | int cur_seq_len = cur_query_length[batch_id]; 60 | int cumsum_seq_len = history_length[batch_id]; 61 | // note: the if judge is a must, because the max_q_len is greater than or equal to cur_seq_len. 
62 | if (token_id < cur_seq_len) 63 | { 64 | // [batch, head num, max_q_len, head size] -> [batch, head num, maxseqlen[cumsum_seq_len:cumsum_seq_len+cur_seq_len], head size] 65 | int src_offset = batch_id * kv_head_num * max_q_len * head_size + 66 | head_id * max_q_len * head_size + 67 | token_id * head_size + tid; 68 | int dst_offset = batch_id * kv_head_num * max_seq_len * head_size + 69 | head_id * max_seq_len * head_size + 70 | (cumsum_seq_len + token_id) * head_size + tid; 71 | v_cache_dst[dst_offset] = v_src[src_offset]; 72 | } 73 | } 74 | 75 | template 76 | void launchConcatKVCache(TensorWrapper *k_src, // from qkv bias and rope {batch_size, kv_head_num, max_q_len, head_size} 77 | TensorWrapper *v_src, 78 | TensorWrapper *layer_id, // layer offset = layer_id * batchxbeam * max_seq_len * kv_head_num * head_size 79 | TensorWrapper *cur_query_length, // current epoch or local input length,[batchsize] 80 | TensorWrapper *history_length, 81 | TensorWrapper *k_dst, //{num_layers, batch_size, kv_head_num, max_seq_len, head_size} 82 | TensorWrapper *v_dst) 83 | { 84 | int batch_size = k_src->shape[0]; 85 | int max_seq_len = k_dst->shape[3]; 86 | int kv_head_num = k_src->shape[1]; 87 | int max_q_len = k_src->shape[2]; 88 | int head_size = k_src->shape[3]; 89 | int blockSize = head_size; 90 | int layer = layer_id->getVal(); 91 | size_t layer_offset = layer * batch_size * kv_head_num * max_seq_len * head_size; 92 | dim3 grid(max_q_len, batch_size, kv_head_num); 93 | append_key_cache<<>>(k_dst->data, 94 | layer_offset, 95 | k_src->data, 96 | kv_head_num, 97 | head_size, 98 | /*(int*)*/ cur_query_length->data, 99 | /*(int*)*/ history_length->data, 100 | max_q_len, 101 | max_seq_len); 102 | 103 | append_value_cache<<>>(v_dst->data, 104 | layer_offset, 105 | v_src->data, 106 | kv_head_num, 107 | head_size, 108 | /*(int*)*/ cur_query_length->data, 109 | /*(int*)*/ history_length->data, 110 | max_q_len, 111 | max_seq_len); 112 | 113 | } 114 | 115 | template void launchConcatKVCache(TensorWrapper *k_src, // from qkv bias and rope 116 | TensorWrapper *v_src, 117 | TensorWrapper *layer_id, // layer offset = layer_id * batchxbeam * max_seq_len * kv_head_num * head_size 118 | TensorWrapper *cur_query_length, // current epoch or local input length,[batchsize] 119 | TensorWrapper *history_length, 120 | TensorWrapper *k_dst, 121 | TensorWrapper *v_dst); 122 | 123 | template void launchConcatKVCache(TensorWrapper *k_src, // from qkv bias and rope 124 | TensorWrapper *v_src, 125 | TensorWrapper *layer_id, // layer offset = layer_id * batchxbeam * max_seq_len * kv_head_num * head_size 126 | TensorWrapper *cur_query_length, // current epoch or local input length,[batchsize] 127 | TensorWrapper *history_length, 128 | TensorWrapper *k_dst, 129 | TensorWrapper *v_dst); -------------------------------------------------------------------------------- /src/kernels/concat_past_kv.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include "src/utils/tensor.h" 6 | 7 | template 8 | void launchConcatKVCache(TensorWrapper *k_src, 9 | TensorWrapper *v_src, 10 | TensorWrapper *layer_id, 11 | TensorWrapper *cur_query_length, 12 | TensorWrapper *history_length, 13 | TensorWrapper *k_dst, 14 | TensorWrapper *v_dst); -------------------------------------------------------------------------------- /src/kernels/cublas_utils.cc: -------------------------------------------------------------------------------- 1 | #include "cublas_utils.h" 2 
| #include 3 | 4 | cublasWrapper::cublasWrapper(cublasHandle_t cublas_handle, 5 | cublasLtHandle_t cublaslt_handle): 6 | cublas_handle_(cublas_handle), 7 | cublaslt_handle_(cublaslt_handle) 8 | { 9 | } 10 | 11 | cublasWrapper::~cublasWrapper() 12 | { 13 | } 14 | // invoked in model example main function after initialize cublas wrapper 15 | void cublasWrapper::setFP32GemmConfig() 16 | { 17 | Atype_ = CUDA_R_32F; 18 | Btype_ = CUDA_R_32F; 19 | Ctype_ = CUDA_R_32F; 20 | computeType_ = CUDA_R_32F; 21 | } 22 | 23 | void cublasWrapper::setFP16GemmConfig() 24 | { 25 | Atype_ = CUDA_R_16F; 26 | Btype_ = CUDA_R_16F; 27 | Ctype_ = CUDA_R_16F; 28 | computeType_ = CUDA_R_32F; 29 | } 30 | 31 | //fp32 gemm and fp16 gemm 32 | void cublasWrapper::Gemm(cublasOperation_t transa, 33 | cublasOperation_t transb, 34 | const int m, 35 | const int n, 36 | const int k, 37 | const void* A, 38 | const int lda, 39 | const void* B, 40 | const int ldb, 41 | void* C, 42 | const int ldc, 43 | float f_alpha = 1.0f, 44 | float f_beta = 0.0f) 45 | { 46 | half h_alpha = (half)(f_alpha); 47 | half h_beta = (half)(f_beta); 48 | int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0; //之前是CUDA_R_16F 49 | const void* alpha = is_fp16_computeType ? reinterpret_cast(&(h_alpha)) : reinterpret_cast(&f_alpha); 50 | const void* beta = is_fp16_computeType ? reinterpret_cast(&(h_beta)) : reinterpret_cast(&f_beta); 51 | CHECK_CUBLAS(cublasGemmEx(cublas_handle_, 52 | transa, 53 | transb, 54 | m, 55 | n, 56 | k, 57 | alpha, 58 | A, 59 | Atype_, 60 | lda, 61 | B, 62 | Btype_, 63 | ldb, 64 | beta, 65 | C, 66 | Ctype_, 67 | ldc, 68 | computeType_, 69 | CUBLAS_GEMM_DEFAULT)); 70 | } 71 | 72 | void cublasWrapper::stridedBatchedGemm(cublasOperation_t transa, 73 | cublasOperation_t transb, 74 | const int m, 75 | const int n, 76 | const int k, 77 | const void* A, 78 | const int lda, 79 | const int64_t strideA, 80 | const void* B, 81 | const int ldb, 82 | const int64_t strideB, 83 | void* C, 84 | const int ldc, 85 | const int64_t strideC, 86 | const int batchCount, 87 | float f_alpha = 1.0f, 88 | float f_beta = 0.0f) 89 | { 90 | int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0; 91 | const void* alpha = 92 | is_fp16_computeType ? reinterpret_cast(&(f_alpha)) : reinterpret_cast(&f_alpha); 93 | const void* beta = is_fp16_computeType ? reinterpret_cast(&(f_beta)) : reinterpret_cast(&f_beta); 94 | CHECK_CUBLAS(cublasGemmStridedBatchedEx(cublas_handle_, 95 | transa, 96 | transb, 97 | m, 98 | n, 99 | k, 100 | alpha, 101 | A, 102 | Atype_, 103 | lda, 104 | strideA, 105 | B, 106 | Btype_, 107 | ldb, 108 | strideB, 109 | beta, 110 | C, 111 | Ctype_, 112 | ldc, 113 | strideC, 114 | batchCount, 115 | computeType_, 116 | CUBLAS_GEMM_DEFAULT)); 117 | } 118 | -------------------------------------------------------------------------------- /src/kernels/cublas_utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include "src/utils/macro.h" 8 | //1.cublas API: must allocate the required matrices in the GPU memory space, 9 | // fill them with data, call the sequence of desired cuBLAS functions, and then upload the results back to the host. 10 | //2.cublasXt API: have the data on the Host 11 | //3.cuBLASLt API: lightweight library dedicated to GEMM with a new flexible API. 
12 | // adds flexibility in matrix data layouts, input types, compute types, and also in choosing the algorithmic implementations and heuristics through parameter programmability 13 | class cublasWrapper { 14 | private: 15 | cublasHandle_t cublas_handle_; 16 | cublasLtHandle_t cublaslt_handle_; 17 | 18 | cudaDataType_t Atype_; 19 | cudaDataType_t Btype_; 20 | cudaDataType_t Ctype_; 21 | cudaDataType_t computeType_; 22 | 23 | public: 24 | cublasWrapper(cublasHandle_t cublas_handle_, 25 | cublasLtHandle_t cublaslt_handle_); 26 | // BaseAllocator* allocator); enable it when we use cublasLt API 27 | 28 | ~cublasWrapper(); 29 | void setFP32GemmConfig(); 30 | void setFP16GemmConfig(); 31 | //for proj matmul 32 | void Gemm(cublasOperation_t transa, 33 | cublasOperation_t transb, 34 | const int m, 35 | const int n, 36 | const int k, 37 | const void* A, 38 | const int lda, 39 | const void* B, 40 | const int ldb, 41 | void* C, 42 | const int ldc, 43 | float alpha, 44 | float beta); 45 | // for qk*v and q*k 46 | void stridedBatchedGemm(cublasOperation_t transa, 47 | cublasOperation_t transb, 48 | const int m, 49 | const int n, 50 | const int k, 51 | const void* A, 52 | const int lda, 53 | const int64_t strideA, 54 | const void* B, 55 | const int ldb, 56 | const int64_t strideB, 57 | void* C, 58 | const int ldc, 59 | const int64_t strideC, 60 | const int batchCount, 61 | float f_alpha, 62 | float f_beta); 63 | }; 64 | -------------------------------------------------------------------------------- /src/kernels/fused_addresidual_norm.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include "src/utils/cuda_debug_utils.cuh" 3 | #include "src/kernels/fused_addresidual_norm.h" 4 | 5 | template 6 | __device__ T warpReduceSum(T val){ 7 | for(int i = 32 / 2; i > 0; i >>= 1){ 8 | val += __shfl_xor_sync(0xffffffff, val, i); 9 | } 10 | return val; // 32 threads return val, but only 0th thread is sum val 11 | } 12 | // (m0dulo) notes:!!!when blocksize < 32, use blockDim.x/32 to get warp nums is wrong, we should ceil it instead 13 | template 14 | __device__ T blockReduceSum(T val){ 15 | int tid = threadIdx.x; 16 | int wid = tid / 32; 17 | int laneid = tid % 32; 18 | int warpnum = (blockDim.x + 31) / 32; 19 | static __shared__ T warpsum[64]; 20 | val = warpReduceSum(val); 21 | if(laneid == 0){ 22 | warpsum[wid] = val; 23 | } 24 | __syncthreads(); 25 | 26 | T sum = tid < warpnum ? 
warpsum[tid] : (T)0.0f; 27 | sum = warpReduceSum(sum); //though 0th own the sum, but dont need to shfl sync 28 | return sum; 29 | } 30 | // 1.this kernel is used after self attention in every layer 31 | // 2.I allocate threads number by assuming head size can be divided by 4 and 2 32 | template 33 | __global__ void FusedAddBiasResidualRMSNorm( // residual.shape = [num tokens, hidden_units] 34 | T* residual, 35 | T* decoder_out, // [num tokens, hidden_units] 36 | /*optional*/const T* bias, // [hidden_units] 37 | const T* scale, // [hidden_units], RMSNorm weights 38 | float eps, // RMSNorm eps 39 | int num_tokens, 40 | int hidden_units){ 41 | int vec_size = Vec::size; 42 | using Vec_t = typename Vec::Type; 43 | int batch_id = blockIdx.x; 44 | int tid = threadIdx.x; 45 | Vec_t *rsd, *bia, *s; 46 | Vec_t tmp; 47 | Vec_t* de_out = reinterpret_cast(decoder_out + batch_id * hidden_units);// note the offset should divide vec size 48 | 49 | T thread_accm = static_cast(0); 50 | rsd = reinterpret_cast(residual + batch_id * hidden_units);//note the offset should divide vec size 51 | if (bias != nullptr){ 52 | bia = reinterpret_cast(const_cast(bias)); 53 | } 54 | 55 | for(int i = tid; i < hidden_units / vec_size; i += blockDim.x) { 56 | if (residual != nullptr) { 57 | de_out[i].x += rsd[i].x; 58 | de_out[i].y += rsd[i].y; 59 | de_out[i].z += rsd[i].z; 60 | de_out[i].w += rsd[i].w; 61 | // update residual to be used in add residual kernel at the end of every decoder layer 62 | rsd[i].x = de_out[i].x; 63 | rsd[i].y = de_out[i].y; 64 | rsd[i].z = de_out[i].z; 65 | rsd[i].w = de_out[i].w; 66 | } 67 | //TODO: to update rsd by rsd + bias when bias is valid 68 | if (bias != nullptr) { 69 | de_out[i].x += bia[i].x; 70 | de_out[i].y += bia[i].y; 71 | de_out[i].z += bia[i].z; 72 | de_out[i].w += bia[i].w; 73 | } 74 | thread_accm += de_out[i].x * de_out[i].x; 75 | thread_accm += de_out[i].y * de_out[i].y; 76 | thread_accm += de_out[i].z * de_out[i].z; 77 | thread_accm += de_out[i].w * de_out[i].w; 78 | } // addresidual 79 | // mean(x^2) 80 | T blocksum = blockReduceSum(thread_accm); 81 | __shared__ float inv_fenmu; 82 | if(tid == 0){ 83 | inv_fenmu = rsqrt(blocksum / hidden_units + eps); 84 | //debug info printf("inv_fenmu on GPU is %f\n", inv_fenmu); 85 | } 86 | __syncthreads(); 87 | // rmsnorm 88 | if (scale != nullptr){ 89 | s = reinterpret_cast(const_cast(scale)); 90 | } 91 | for(int i = tid; i < hidden_units / vec_size; i += blockDim.x) { 92 | //s = reinterpret_cast(const_cast(scale))[i]; 93 | de_out[i].x = s[i].x * de_out[i].x * inv_fenmu; 94 | de_out[i].y = s[i].y * de_out[i].y * inv_fenmu; 95 | de_out[i].z = s[i].z * de_out[i].z * inv_fenmu; 96 | de_out[i].w = s[i].w * de_out[i].w * inv_fenmu; 97 | } 98 | } 99 | 100 | template<> 101 | __global__ void FusedAddBiasResidualRMSNorm( // residual.shape = [num tokens, hidden_units] 102 | half* residual, 103 | half* decoder_out, // [num tokens, hidden_units] 104 | const half* bias, //[hidden_units] 105 | const half* scale, //[hidden_units], RMSNorm weights 106 | float eps, //RMSNorm eps 107 | int num_tokens, 108 | int hidden_units){ 109 | int vec_size = Vec::size; 110 | using Vec_t = typename Vec::Type; 111 | int batch_id = blockIdx.x; 112 | int tid = threadIdx.x; 113 | Vec_t *rsd, *bia, *s; 114 | Vec_t dout, tmp; 115 | float thread_accm = 0.0f; 116 | if (residual != nullptr && bias != nullptr){ 117 | rsd = reinterpret_cast(residual + batch_id * hidden_units);//note the offset should divide vec size 118 | bia = reinterpret_cast(const_cast(bias)); 119 | } 120 | 
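// Vectorized half2 path: each iteration below loads one half2 lane of decoder_out, adds the
// matching residual and bias lanes with __hadd2, and accumulates the sum of squares in fp32
// (thread_accm) so that the RMS statistic is computed in higher precision.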
for(int i = tid; i < hidden_units / vec_size; i += blockDim.x) { 121 | dout = reinterpret_cast(decoder_out)[batch_id * hidden_units / vec_size + i];// note the offset should divide vec size 122 | tmp = __hadd2(__hadd2(dout, rsd[i]), bia[i]); 123 | thread_accm += __half2float(tmp.x) * __half2float(tmp.x) + __half2float(tmp.y) * __half2float(tmp.y); 124 | } // addresidual 125 | // mean(x^2) 126 | float blocksum = blockReduceSum(thread_accm); 127 | __shared__ Vec_t inv_fenmu; 128 | if(tid == 0){ 129 | //debug info printf("blocksum on GPU is %f\n", blocksum); 130 | inv_fenmu = scalar_cast_vec(__float2half(rsqrt(blocksum / hidden_units + eps))); 131 | //debug info printf("inv_fenmu on GPU is %f\n", inv_fenmu); 132 | } __syncthreads(); // all threads must wait until thread 0 has written the shared inv_fenmu before reading it below 133 | // rmsnorm 134 | Vec_t* out = reinterpret_cast(decoder_out + batch_id * hidden_units);// note before vec the stride is batch_id * hiddenunits w/o / vecsize 135 | if (scale != nullptr){ 136 | s = reinterpret_cast(const_cast(scale)); 137 | } 138 | for(int i = tid; i < hidden_units / vec_size; i += blockDim.x) { 139 | out[i] = __hmul2(__hmul2(s[i], out[i]), inv_fenmu); 140 | } 141 | } 142 | 143 | template 144 | void launchFusedAddBiasResidualRMSNorm( // residual.shape = [num tokens, hidden_units] 145 | TensorWrapper* residual, 146 | TensorWrapper* decoder_out, // [num tokens, hidden_units] 147 | BaseWeight& norm, 148 | T* scale, //RMSNorm weights 149 | float eps) //RMSNorm eps 150 | { 151 | int batch_size = decoder_out->shape[0]; 152 | int hidden_units = decoder_out->shape[1]; 153 | T* bias = norm.bias; 154 | T* gamma = scale; 155 | int vec_size = Vec::size; 156 | int num_threads = hidden_units / vec_size; // assume head size can be divided by 4 and 2 157 | dim3 grid(batch_size); 158 | dim3 block(num_threads); 159 | FusedAddBiasResidualRMSNorm<<>>(residual->data, 160 | decoder_out->data, 161 | bias, 162 | gamma, 163 | eps, 164 | batch_size, 165 | hidden_units); 166 | #ifdef PRINT_DATA 167 | print_data<<<1, 1>>>(decoder_out->data); 168 | #else 169 | #endif 170 | } 171 | template void launchFusedAddBiasResidualRMSNorm( // residual.shape = [num tokens, hidden_units] 172 | TensorWrapper* residual, 173 | TensorWrapper* decoder_out, // [num tokens, hidden_units] 174 | BaseWeight& norm, 175 | float* scale, //RMSNorm weights 176 | float eps); 177 | template void launchFusedAddBiasResidualRMSNorm( // residual.shape = [num tokens, hidden_units] 178 | TensorWrapper* residual, 179 | TensorWrapper* decoder_out, // [num tokens, hidden_units] 180 | BaseWeight& norm, 181 | half* scale, //RMSNorm weights 182 | float eps); -------------------------------------------------------------------------------- /src/kernels/fused_addresidual_norm.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include "src/weights/base_weights.h" 6 | #include "src/weights/llama/norm_weights.h" 7 | #include "src/utils/tensor.h" 8 | #include "src/utils/vectorize_utils.h" 9 | template 10 | void launchFusedAddBiasResidualRMSNorm( 11 | TensorWrapper* residual, 12 | TensorWrapper* decoder_out, 13 | BaseWeight& norm, 14 | T* scale, 15 | float eps); -------------------------------------------------------------------------------- /src/kernels/fused_decoder_self_attention.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include "src/utils/tensor.h" 6 | #include "src/models/llama/llama_params.h" 7 | #include "src/weights/base_weights.h" 8 |
#include "src/utils/vectorize_utils.h" 9 | 10 | template 11 | void launchDecoderMaskedMHA(TensorWrapper* qkv_buf, 12 | BaseWeight& qkv, 13 | TensorWrapper* layer_id, 14 | TensorWrapper* k_cache, 15 | TensorWrapper* v_cache, 16 | TensorWrapper* finished, 17 | TensorWrapper* step, 18 | TensorWrapper* mha_output, 19 | LLaMAAttentionStaticParams& static_params); -------------------------------------------------------------------------------- /src/kernels/fused_transpose_and_remv_pad.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include "src/utils/tensor.h" 6 | 7 | template 8 | void launchTransposeOutRemovePadding(TensorWrapper* qkv_buf_w_pad, 9 | TensorWrapper* padding_offset, 10 | TensorWrapper* qkv_buf_wo_pad_1); -------------------------------------------------------------------------------- /src/kernels/input_embedding.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include "src/kernels/input_embedding.h" 3 | #include "src/utils/cuda_debug_utils.cuh" 4 | template 5 | __global__ void embeddingFunctor(const int* input_ids, 6 | T* output, 7 | const T* embed_table, 8 | const int max_context_token_num, 9 | const int hidden_size) 10 | { 11 | // zhaziqwe:得到全局线程id 12 | int index = blockIdx.x * blockDim.x + threadIdx.x; 13 | // 如果分配的总线程数不够,那么在输出范围内的线程id再来多轮 14 | while (index < max_context_token_num * hidden_size) { 15 | // 拿到token_num这一个维度的序号,输出的行号 16 | int id = input_ids[index / hidden_size]; 17 | // 通过token id索引到对应embedding table的行号,此行号乘上embedding的列数,得到该token id的hidden units 18 | // 每个线程并行地把hidden size个值读取并写到对应线程id位置的output 19 | output[index] = embed_table[id * hidden_size + index % hidden_size]; 20 | // 当前线程处理完一轮,累加index到下一轮,防止总线程数不够。 21 | index += blockDim.x * gridDim.x; 22 | } 23 | } 24 | 25 | template 26 | void launchInputEmbedding(TensorWrapper* input_ids, // INT [token num] 27 | TensorWrapper* output, // FP32 [token num, hidden_size] = [token num, 4096] 28 | EmbeddingWeight* embed_table// FP32 [vocab_size, hidden_size] 29 | ) { 30 | // zhaziqwe:分配线程块,核函数需要的维度信息 31 | const int blockSize = 256; 32 | const int max_context_token_num = output->shape[0]; // token num 33 | const int hidden_size = output->shape[1]; 34 | const int gridSize = 2048; 35 | LLM_CHECK_WITH_INFO(max_context_token_num == input_ids->shape[0], "input ids 1st shape should equal to 1st shape of output"); 36 | embeddingFunctor<<>>(input_ids->data, 37 | output->data, 38 | embed_table->data, 39 | max_context_token_num, 40 | hidden_size); 41 | #ifdef PRINT_DATA 42 | print_data<<<1, 1>>>(output->data); 43 | #else 44 | #endif 45 | } 46 | 47 | // zhaziqwe: 显式实例化模版函数,由于cuda的语法规则,不能存在.cpp文件里,因此只能在此实例化 48 | template void launchInputEmbedding(TensorWrapper* input_ids, 49 | TensorWrapper* output, 50 | EmbeddingWeight* embed_table); 51 | template void launchInputEmbedding(TensorWrapper* input_ids, 52 | TensorWrapper* output, 53 | EmbeddingWeight* embed_table); 54 | -------------------------------------------------------------------------------- /src/kernels/input_embedding.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include "src/utils/tensor.h" 5 | #include "src/weights/llama/embedding_weights.h" 6 | template 7 | void launchInputEmbedding(TensorWrapper* input_ids, 8 | TensorWrapper* output, 9 | EmbeddingWeight* embed_table); -------------------------------------------------------------------------------- 
/src/kernels/linear.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "src/utils/cuda_debug_utils.cuh" 4 | #include "src/kernels/linear.h" 5 | // TODO: when abstracted weight class, replace T with class 6 | // all matmul cases: 7 | // ctx qkv lienar: [num_tokens, qhiddenunits] * [qhiddenunits, hiddenunits] = {num_tokens, qkv_head_num, head_size} 8 | // ctx attn output linear: {num_tokens, head_num, head_size} * {q hidden units, q hidden units} = {num_tokens, q hidden units} 9 | // self qkv linear: [bs, q hidden units] * [qhiddenunits, hiddenunits] = {bs, qkv_head_num, head_size}} 10 | // self attn output linear: {batch_size, q hidden_units} * [qhiddenunits, qhiddenunits] = [bs, q hiddenunits] 11 | // lmhead linear: [bs, q hidden units] * [vocab size, q hiden units], need transpose B 12 | // gate:[bs/token nums, q hidden units] * [q hidden units, inter size] = [bs/token nums, inter size] 13 | // up:[bs/token nums, q hidden units] * [q hidden units, inter size] = [bs/token nums, inter size] 14 | // fusedGateUpGemm: [bs/token nums, q hidden units] * [q hidden units, 2 * inter size] = [bs/token nums, 2 * inter size] 15 | // down:[bs/token nums, inter size] * [q hidden units, inter size] = [bs/token nums, q hidden units] 16 | template 17 | void launchLinearGemm(TensorWrapper *input, 18 | BaseWeight &weight, 19 | TensorWrapper *output, 20 | cublasWrapper *cublas_wrapper, 21 | bool trans_a, 22 | bool trans_b) 23 | { 24 | int Am = weight.shape[1]; 25 | int Ak = weight.shape[0]; 26 | int Bk = input->shape[1]; 27 | int Bn = input->shape[0]; 28 | int Cm = output->shape[1]; 29 | int Cn = output->shape[0]; 30 | // for ctx attn and self attn qkv linear, assume [bs/token nums, qkv h ead num, head size] 31 | // for gate & up linear, assume weight.shape=[hidden,2*intersize], output.shape=[bs, 2, inter size] 32 | Cm = output->shape.size() == 3 ? output->shape[1] * output->shape[2] : output->shape[1]; 33 | // for ctx attn output linear 34 | Bk = input->shape.size() == 3 ? input->shape[1] * input->shape[2] : input->shape[1]; 35 | int lda = Am; 36 | int ldb = Bk; 37 | int ldc = Cm; 38 | 39 | // for lmhead linear and ffn all lieanrs 40 | cublasOperation_t transA = trans_b ? CUBLAS_OP_T : CUBLAS_OP_N; 41 | cublasOperation_t transB = trans_a ? CUBLAS_OP_T : CUBLAS_OP_N; 42 | if (!trans_a && !trans_b) 43 | { 44 | LLM_CHECK_WITH_INFO(Ak == Bk, "2nd dim of input MUST = 1st dim of weight"); 45 | } 46 | cublas_wrapper->Gemm(transA, 47 | transB, 48 | trans_b ? 
Ak : Am, // m 49 | Cn, // n, when load real weight, lmhead weight is same as pre embedding, which shape = [vocab, hidden], so here should transpose b 50 | Bk, 51 | weight.data, // A, cur_input_len is for context decoder lmhead 52 | lda, // lda 53 | input->data, // B 54 | ldb, // ldb 55 | output->data, // C 56 | ldc, // ldc 57 | 1.0f, 58 | 0.0f); 59 | #ifdef PRINT_DATA 60 | print_data<<<1, 1>>>(output->data); 61 | #else 62 | #endif 63 | } 64 | 65 | template 66 | void launchLinearStridedBatchGemm(TensorWrapper *input1, 67 | TensorWrapper *input2, 68 | TensorWrapper *output, 69 | cublasWrapper *cublas_wrapper, 70 | bool trans_a, 71 | bool trans_b) 72 | { 73 | // B.T A.T = C.T 74 | 75 | int Bm = input1->shape[2]; // len q // len q 76 | int Bk = input1->shape[3]; // head size // len k 77 | int Ak = input2->shape[2]; // len k // len k 78 | int An = input2->shape[3]; // head size // head size 79 | int Cm = output->shape[2]; // len q // len q 80 | int Cn = output->shape[3]; // len k // head size 81 | int lda = An; 82 | int ldb = Bk; // ld should be val before transpose 83 | int ldc = Cn; 84 | int64_t strideA = Ak * An; // stride should be val after transpose 85 | int64_t strideB = Bm * Bk; 86 | int64_t strideC = Cm * Cn; 87 | 88 | int batchCount = input1->shape[0] * input1->shape[1]; 89 | 90 | cublasOperation_t transA = trans_b ? CUBLAS_OP_T : CUBLAS_OP_N; 91 | cublasOperation_t transB = trans_a ? CUBLAS_OP_T : CUBLAS_OP_N; 92 | 93 | cublas_wrapper->stridedBatchedGemm(transA, 94 | transB, 95 | Cn, // m 96 | Cm, // n 97 | Bk, // k 98 | input2->data, // A,[Bk, Bn]=[bs, head num, head size,max k len] 99 | lda, 100 | strideA, 101 | input1->data, // B [Ak, An]=[bs, head num, head size,max q len] 102 | ldb, 103 | strideB, 104 | output->data, // C [[bs, head num, max k len, max q len] 105 | ldc, 106 | strideC, 107 | batchCount, 108 | 1.0f, 109 | 0.0f); 110 | #ifdef PRINT_DATA 111 | print_data<<<1, 1>>>(output->data); 112 | #else 113 | #endif 114 | } 115 | 116 | template void launchLinearGemm(TensorWrapper *input, 117 | BaseWeight &weight, 118 | TensorWrapper *output, 119 | cublasWrapper *cublas_wrapper, 120 | bool trans_a, 121 | bool trans_b); 122 | 123 | template void launchLinearGemm(TensorWrapper *input, 124 | BaseWeight &weight, 125 | TensorWrapper *output, 126 | cublasWrapper *cublas_wrapper, 127 | bool trans_a, 128 | bool trans_b); 129 | 130 | template void launchLinearStridedBatchGemm(TensorWrapper *input1, 131 | TensorWrapper *input2, 132 | TensorWrapper *output, 133 | cublasWrapper *cublas_wrapper, 134 | bool trans_a, 135 | bool trans_b); 136 | 137 | template void launchLinearStridedBatchGemm(TensorWrapper *input1, 138 | TensorWrapper *input2, 139 | TensorWrapper *output, 140 | cublasWrapper *cublas_wrapper, 141 | bool trans_a, 142 | bool trans_b); 143 | -------------------------------------------------------------------------------- /src/kernels/linear.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include "src/kernels/cublas_utils.h" 7 | #include "src/utils/tensor.h" 8 | #include "src/weights/base_weights.h" 9 | #include "src/utils/macro.h" 10 | 11 | template 12 | void launchLinearGemm(TensorWrapper* input, 13 | BaseWeight& weight, 14 | TensorWrapper* output, 15 | cublasWrapper* cublas_wrapper, 16 | bool trans_a = false, 17 | bool trans_b = false); 18 | template 19 | void launchLinearStridedBatchGemm(TensorWrapper* input1, 20 | TensorWrapper* input2, 21 | TensorWrapper* output, 22 | 
cublasWrapper* cublas_wrapper, 23 | bool trans_a = false, 24 | bool trans_b = false); 25 | -------------------------------------------------------------------------------- /src/kernels/qkv_bias_and_RoPE.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include "src/models/llama/llama_params.h" 6 | #include "src/utils/tensor.h" 7 | #include "src/weights/base_weights.h" 8 | #include "src/utils/vectorize_utils.h" 9 | 10 | template 11 | void launchAddFusedQKVBiasTransposeAndRoPE(TensorWrapper* q_buf, 12 | TensorWrapper* k_buf, 13 | TensorWrapper* v_buf, 14 | TensorWrapper* QKV, 15 | BaseWeight& qkv, 16 | //Tensor* qkv_bias, 17 | TensorWrapper* padding_offset, 18 | TensorWrapper* history_length, 19 | TensorWrapper* input_length, 20 | LLaMAAttentionStaticParams& params); 21 | 22 | template 23 | void launchRoPE(TensorWrapper* qkv_buf, 24 | TensorWrapper* step, 25 | LLaMAAttentionStaticParams& static_params); -------------------------------------------------------------------------------- /src/kernels/repeat_kv.cu: -------------------------------------------------------------------------------- 1 | #include "src/kernels/repeat_kv.h" 2 | #include "src/utils/cuda_debug_utils.cuh" 3 | #include 4 | // if MQA or GQA, we should use this repeat kv kernel to broadcast kv head num to q head num 5 | // 此kernel的输入输出维度变化: [num layers, bs, kv head num, max_seq_len, head size]=>[bs, q head num, max_k_len, head size] 6 | // context_length.shape = [bs] 7 | // bugs1: when k_dst.shape = [1,32,13,128],现在这个k_dst以13*128为单位循环第一个13*128的值 8 | // solu1: launcher函数里面获取kv cache的shape出错,需要仔细核对各个TensorWrapper的shape再通过正确索引获取 9 | // fp32和fp16 kernel都是下面这个模板实现 10 | template 11 | __global__ void repeat_value_cache(T *v_dst, 12 | const T *v_src, 13 | const size_t layer_offset, 14 | const int head_num, 15 | const int q_head_per_kv, 16 | const int head_size, 17 | const int *context_length, 18 | const int max_k_len, 19 | const int max_seq_len) 20 | { 21 | const int batch_id = blockIdx.y; 22 | const int head_id = blockIdx.z; 23 | // x block维度上的全局线程id 24 | const int idx = blockIdx.x * blockDim.x + threadIdx.x; 25 | // 当前layer的key/value偏移 26 | const auto val_src = v_src + layer_offset; 27 | const auto val_dst = v_dst; 28 | // 当前句子的长度 29 | const auto seq_len = context_length[batch_id]; 30 | // headsize维度上的数据偏移 31 | const int v_head_size_id = idx % head_size; 32 | // seqlen维度上的数据偏移 33 | const int v_seq_len_id = idx / head_size; 34 | // only fetch context_length( 52 | void launchRepeatKVCache(TensorWrapper *k_cache_src, //{num_layers, batch_size, kv_head_num, max_seq_len, head_size} 53 | TensorWrapper *v_cache_src, //{num_layers, batch_size, kv_head_num, max_seq_len, head_size} 54 | TensorWrapper *context_length, 55 | TensorWrapper *layer_id, 56 | TensorWrapper *k_cache_dst, //{batch_size, head_num, max_k_len, head_size} 57 | TensorWrapper *v_cache_dst) 58 | { 59 | int batch_size = context_length->shape[0]; 60 | int kv_head_num = k_cache_src->shape[2]; // (m0dulo)note: we should carefully access the shape value, corresponding to the place where tensorwapper is defined 61 | int max_seq_len = k_cache_src->shape[3]; 62 | int head_num = k_cache_dst->shape[1]; 63 | 64 | int max_k_len = k_cache_dst->shape[2]; 65 | int head_size = k_cache_dst->shape[3]; 66 | // (m0dulo)note: if layer id is on GPU, here MUSTN'T use layer_id->getVal(), because we cant access GPU memory directly by [] if data is on GPU 67 | // (m0dulo)note: so we can make layer data locate 
on CPU, so that we can access data by [] 68 | int layer = layer_id->getVal(); 69 | 70 | size_t layer_offset = layer * batch_size * kv_head_num * max_seq_len * head_size; 71 | int q_head_per_kv = head_num / kv_head_num; 72 | int blockSize = 128; 73 | dim3 block(blockSize); 74 | // 这里分配了三维block,方便匹配输入输出的多维shape 75 | dim3 grid((max_k_len * head_size + blockSize - 1) / blockSize, batch_size, head_num); 76 | repeat_value_cache<<>>(v_cache_dst->data, 77 | v_cache_src->data, 78 | layer_offset, 79 | head_num, 80 | q_head_per_kv, 81 | head_size, 82 | context_length->data, 83 | max_k_len, 84 | max_seq_len); 85 | 86 | repeat_value_cache<<>>(k_cache_dst->data, 87 | k_cache_src->data, 88 | layer_offset, 89 | head_num, 90 | q_head_per_kv, 91 | head_size, 92 | context_length->data, 93 | max_k_len, 94 | max_seq_len); 95 | // for debug,打印repeat_kv这个kernel的输出结果 96 | #ifdef PRINT_DATA 97 | print_data<<<1, 1>>>(k_cache_dst->data); 98 | #else 99 | #endif 100 | } 101 | 102 | template void launchRepeatKVCache(TensorWrapper *k_cache_src, 103 | TensorWrapper *v_cache_src, 104 | TensorWrapper *context_length, 105 | TensorWrapper *layer_id, 106 | TensorWrapper *k_cache_dst, 107 | TensorWrapper *v_cache_dst); 108 | template void launchRepeatKVCache(TensorWrapper *k_cache_src, 109 | TensorWrapper *v_cache_src, 110 | TensorWrapper *context_length, 111 | TensorWrapper *layer_id, 112 | TensorWrapper *k_cache_dst, 113 | TensorWrapper *v_cache_dst); -------------------------------------------------------------------------------- /src/kernels/repeat_kv.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include "src/utils/tensor.h" 6 | 7 | template 8 | void launchRepeatKVCache(TensorWrapper *k_cache_src, 9 | TensorWrapper *v_cache_src, 10 | TensorWrapper *context_length, 11 | TensorWrapper *layer_id, 12 | TensorWrapper *k_cache_dst, 13 | TensorWrapper *v_cache_dst); -------------------------------------------------------------------------------- /src/kernels/topK.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include "src/kernels/topK.h" 5 | #include 6 | 7 | 8 | template 9 | __device__ topK reduce_functor(const topK& a, const topK& b) { 10 | topK res = a; 11 | for(int i = 0; i < K; i++){ 12 | res.insertHeap(b.val[i], b.id[i]); 13 | } 14 | return res; 15 | } 16 | 17 | template 18 | __global__ void topK_kernel_round1(const T* probs, const int vocab_size, 19 | int* topK_ids, T* topK_vals) 20 | { 21 | typedef cub::BlockReduce, blockSize> blockreduce; 22 | __shared__ typename blockreduce::TempStorage temp_storage; 23 | 24 | int tid = threadIdx.x; 25 | int bid = blockIdx.x; 26 | int gid = blockIdx.x * blockDim.x + threadIdx.x; 27 | int row_id = bid / BlockPerBeam; 28 | int block_lane = bid % BlockPerBeam; 29 | topK thread_topK; 30 | thread_topK.init(); 31 | 32 | for(int data_id = tid + block_lane * blockSize; data_id < vocab_size; data_id += BlockPerBeam * blockSize){ 33 | int data_offset = data_id + row_id * vocab_size; 34 | T data = probs[data_offset]; 35 | thread_topK.insertHeap(data, data_id); 36 | } 37 | 38 | topK block_topK = blockreduce(temp_storage).Reduce(thread_topK, reduce_functor); 39 | 40 | if(tid == 0){ 41 | for(int k_offset = 0; k_offset < K; k_offset++) { 42 | topK_vals[row_id * BlockPerBeam * K + block_lane * K + k_offset] = block_topK.val[k_offset]; 43 | topK_ids[row_id * BlockPerBeam * K + block_lane * K + k_offset] = 
block_topK.id[k_offset]; 44 | } 45 | } 46 | } 47 | 48 | template 49 | __global__ void topK_kernel_round2(const int* topK_ids, const T* topK_vals, 50 | int* final_topK_ids, T* final_topK_vals) 51 | { 52 | typedef cub::BlockReduce, blockSize> blockreduce; 53 | __shared__ typename blockreduce::TempStorage temp_storage; 54 | 55 | int tid = threadIdx.x; 56 | int bid = blockIdx.x; 57 | int gid = blockIdx.x * blockDim.x + threadIdx.x; 58 | int row_id = bid; 59 | topK thread_topK; 60 | thread_topK.init(); 61 | 62 | for(int i = tid; i < BlockPerBeam * K; i += blockDim.x) { 63 | int data_offset = bid * BlockPerBeam * K + i; 64 | thread_topK.insertHeap(topK_vals[data_offset], topK_ids[i]); 65 | } 66 | 67 | topK block_topK = blockreduce(temp_storage).Reduce(thread_topK, reduce_functor); 68 | if(tid == 0){ 69 | for(int k_offset = 0; k_offset < K; k_offset++) { 70 | final_topK_vals[bid * K + k_offset] = block_topK.val[k_offset]; 71 | final_topK_ids[bid * K + k_offset] = block_topK.id[k_offset]; 72 | } 73 | } 74 | } 75 | 76 | template 77 | void launchTopKforBeamSearch(TensorWrapper *probs, 78 | TensorWrapper *topk_ids, 79 | TensorWrapper *topk_vals, 80 | TensorWrapper *final_topk_ids, 81 | TensorWrapper *final_topk_vals) 82 | { 83 | int bsxbm = probs->shape[0]; 84 | int vocab_size = probs->shape[1]; 85 | constexpr int BlockPerBeam = 8; 86 | constexpr int beamwidth = 1; 87 | constexpr int K = 5; 88 | 89 | int topK_val_buf_size = bsxbm * BlockPerBeam * K; 90 | int topK_ids_buf_size = bsxbm * BlockPerBeam * K; 91 | int final_topK_val_buf_size = bsxbm * K; 92 | 93 | T* topK_vals_data = topk_vals->data; 94 | int* topK_ids_data = topk_ids->data; 95 | T* final_topK_vals_data = final_topk_vals->data; 96 | int* final_topK_ids_data = final_topk_ids->data; 97 | 98 | int maxBlockNums = 1024; 99 | int BlockNums1 = std::min(bsxbm * BlockPerBeam, maxBlockNums); 100 | int BlockNums2 = std::min(bsxbm, maxBlockNums); 101 | dim3 grid_round1(BlockNums1); 102 | dim3 block_round1(256); 103 | dim3 grid_round2(BlockNums2); 104 | dim3 block_round2(256); 105 | 106 | topK_kernel_round1 107 | <<>>(probs->data, vocab_size, topK_ids_data, topK_vals_data); 108 | topK_kernel_round2 109 | <<>>(topK_ids_data, topK_vals_data, final_topK_ids_data, final_topK_vals_data); 110 | } 111 | 112 | 113 | template void launchTopKforBeamSearch(TensorWrapper *probs, 114 | TensorWrapper *topk_ids, 115 | TensorWrapper *topk_vals, 116 | TensorWrapper *final_topk_ids, 117 | TensorWrapper *final_topk_vals); 118 | 119 | template void launchTopKforBeamSearch(TensorWrapper *probs, 120 | TensorWrapper *topk_ids, 121 | TensorWrapper *topk_vals, 122 | TensorWrapper *final_topk_ids, 123 | TensorWrapper *final_topk_vals); -------------------------------------------------------------------------------- /src/kernels/topK.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include "src/utils/tensor.h" 7 | 8 | template 9 | struct topK 10 | { 11 | T val[K]; 12 | int id[K]; 13 | 14 | __device__ void init(){ 15 | for (int i = 0; i < K; i++) { 16 | id[i] = -1; 17 | val[i] = -1e-20; 18 | } 19 | } 20 | 21 | __device__ void insertHeap(T data, int data_id){ 22 | float v = (float)val[K-1]; 23 | if(id[K-1] == -1 || v < (float)data){ 24 | id[K-1] = data_id; 25 | val[K-1] = data; 26 | } 27 | for (int i = K - 2; i >= 0; i--){ 28 | if(val[i + 1] > val[i] || id[i] == -1) { 29 | T tmp = val[i]; 30 | val[i] = val[i + 1]; 31 | val[i + 1] = tmp; 32 | int tmp_id = id[i]; 33 | id[i] = 
id[i + 1]; 34 | id[i + 1] = tmp_id; 35 | } 36 | } 37 | } 38 | }; 39 | 40 | 41 | template 42 | void launchTopKforBeamSearch(TensorWrapper *probs, 43 | TensorWrapper *topk_ids, 44 | TensorWrapper *topk_vals, 45 | TensorWrapper *final_topk_ids, 46 | TensorWrapper *final_topk_vals); -------------------------------------------------------------------------------- /src/memory/allocator/base_allocator.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | class BaseAllocator 6 | { 7 | public: 8 | virtual ~BaseAllocator(){}; 9 | 10 | template 11 | T* Malloc(T* ptr, size_t size, bool is_host){ 12 | return (T*)UnifyMalloc((void*)ptr, size, is_host); 13 | } 14 | virtual void* UnifyMalloc(void* ptr, size_t size, bool is_host = false) = 0; 15 | template 16 | void Free(T* ptr, bool is_host = false){ 17 | UnifyFree((void*)ptr, is_host); 18 | } 19 | virtual void UnifyFree(void* ptr, bool is_host = false) = 0; 20 | }; -------------------------------------------------------------------------------- /src/memory/allocator/cuda_allocator.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include "src/memory/allocator/base_allocator.h" 7 | #include "src/utils/macro.h" 8 | 9 | // I use Bytes to printf buffer size msg, because sometime I allocate <1KB buffer, which causes that display 0KB 10 | // 分了两种块 11 | struct CudaBigBlock { 12 | void *data; 13 | size_t size; 14 | bool is_allocated; 15 | 16 | CudaBigBlock() = default; 17 | CudaBigBlock(void* data_, int size_, bool is_allocated_): 18 | data(data_), 19 | size(size_), 20 | is_allocated(is_allocated_){} 21 | }; 22 | 23 | struct CudaSmallBlock { 24 | void *data; 25 | size_t size; 26 | bool is_allocated; 27 | 28 | CudaSmallBlock() = default; 29 | CudaSmallBlock(void* data_, int size_, bool is_allocated_): 30 | data(data_), 31 | size(size_), 32 | is_allocated(is_allocated_){} 33 | }; 34 | 35 | /* 36 | * 分配块的好处:在linux中需要进入内核态才能真正分配cudaMalloc,耗时很大,如果能维护块表 37 | * 则可以大大减少分配的开销提高块的利用率 38 | */ 39 | class CudaAllocator: public BaseAllocator { 40 | private: 41 | //{device id: block} 42 | std::map > cudaSmallBlocksMap; 43 | std::map > cudaBigBlocksMap; 44 | std::map FreeSize; 45 | size_t total_allocated_size = 0; 46 | int dev_id; 47 | public: 48 | CudaAllocator() { 49 | cudaGetDevice(&dev_id); 50 | } 51 | ~CudaAllocator() { 52 | for (auto &it: cudaSmallBlocksMap) { 53 | auto &cudaBlocks = it.second; //vector 54 | for (int i = 0; i < cudaBlocks.size(); i++) { 55 | cudaFree(cudaBlocks[i].data); 56 | } 57 | auto &bigBlocks = cudaBigBlocksMap[it.first]; 58 | for (int i = 0; i < bigBlocks.size(); i++) { 59 | cudaFree(bigBlocks[i].data); 60 | } 61 | } 62 | } 63 | 64 | void* UnifyMalloc(void* ptr, size_t size, bool is_host) { 65 | // 1. 
host malloc 66 | if (is_host) { 67 | //CHECK(cudaMallocHost(&ptr, size)); // for cuda stream async 68 | ptr = malloc(size); 69 | memset(ptr, 0, size); 70 | return ptr; 71 | } 72 | // 2.big buf, 先去bigblocks里面找空闲的(free出来且未归还到OS的) 73 | if (size > 1024 * 1024) { // > 1M 74 | auto &BigBlocks = cudaBigBlocksMap[dev_id]; 75 | int blockID = -1; 76 | for (int i = 0; i < BigBlocks.size(); i++) { 77 | // the freed bigblock by free method 78 | // 朴素的分配算法 分配策略为首次分配并且块空间-分配空间大于1M 79 | // 会造成碎片化 80 | if (BigBlocks[i].size >= size && !BigBlocks[i].is_allocated 81 | && BigBlocks[i].size - size < 1 * 1024 * 1024) { 82 | if (blockID == -1 || BigBlocks[blockID].size > BigBlocks[i].size) { 83 | blockID = i; 84 | } 85 | } 86 | } 87 | // the allocated big block id 88 | if (blockID != -1) { 89 | BigBlocks[blockID].is_allocated = true; 90 | // std::cout << "allocate a existed big block, id = " << blockID 91 | // <<", size = "<< size << "B" 92 | // <<", block size = "<< BigBlocks[blockID].size << "B" 93 | // << std::endl; 94 | 95 | return BigBlocks[blockID].data; 96 | } 97 | // 没找到空闲的再cudaMalloc,并插进block pool 98 | void* new_buffer; 99 | cudaMalloc(&new_buffer, size); 100 | total_allocated_size += size; 101 | // std::cout << "allocate a new big block from OS using cudaMalloc, size = " 102 | // << size << "B, total allocated size " << total_allocated_size << "B" 103 | // << std::endl; 104 | BigBlocks.push_back(CudaBigBlock(new_buffer, size, true)); 105 | return new_buffer; 106 | } 107 | // 3.small buf, 先去smallblocks里面找空闲的(free出来且未归还到OS的) 108 | // 问题: 为什么要分成大小block? 答: 用free size记录碎片 109 | // 匹配策略:简单首次匹配 110 | auto &SmallBlocks = cudaSmallBlocksMap[dev_id]; 111 | for (int i = 0; i < SmallBlocks.size(); i++) { 112 | if (SmallBlocks[i].size >= size && !SmallBlocks[i].is_allocated) { 113 | SmallBlocks[i].is_allocated = true; 114 | FreeSize[i] += SmallBlocks[i].size;//小buf size 115 | // std::cout << "allocate a existed small block, id = " << i 116 | // <<", size = "<< size << "B" 117 | // <<", block size = "<< SmallBlocks[i].size << "B" 118 | // << std::endl; 119 | return SmallBlocks[i].data; 120 | } 121 | } 122 | // 4.没找到空闲的再cudaMalloc 123 | void* new_buffer = (void*)ptr; 124 | CHECK(cudaMalloc(&new_buffer, size)); 125 | CHECK(cudaMemset(new_buffer, 0, size)); 126 | // std::cout << "allocate a new small block from OS using cudaMalloc, size = " 127 | // << size << "B, total allocated size " << total_allocated_size << "B" 128 | // << std::endl; 129 | 130 | SmallBlocks.push_back(CudaSmallBlock(new_buffer, size, true)); 131 | return new_buffer; 132 | } 133 | 134 | void UnifyFree(void* ptr, bool is_host) { 135 | if (ptr == nullptr) { 136 | return; 137 | } 138 | // 1.host free 139 | if (is_host) { 140 | free(ptr); 141 | return; 142 | } 143 | // 2.清理碎片:当累计的小buf超出了1G时,清理未分配出去的smallblocks, 已分配的还是保留在smallmap 144 | for (auto &it: cudaSmallBlocksMap) { 145 | if (FreeSize[it.first] > 1024 * 1024 * 1024) { 146 | auto &cudaBlocks = it.second; 147 | std::vector temp; 148 | for (int i = 0; i < cudaBlocks.size(); i++) { 149 | if (!cudaBlocks[i].is_allocated) { 150 | cudaSetDevice(it.first); 151 | // std::cout << "free a small block to OS using cudaFree, block id = " 152 | // << i 153 | // << ",size = " 154 | // << cudaBlocks[i].size << "B" 155 | // << std::endl; 156 | cudaFree(cudaBlocks[i].data); 157 | } else { 158 | temp.push_back(cudaBlocks[i]); 159 | } 160 | } 161 | cudaBlocks.clear(); 162 | it.second = temp; 163 | FreeSize[it.first] = 0; 164 | } 165 | } 166 | // 3.找到待free的buffer的位置,设is_allocated = 
false,大小block都不归还到OS,除非没有在大小block里面找到待free的ptr 167 | // 大块清理分配比较耗时,为了降低损耗,用标记为标记为已经清除即可 168 | for (auto &it: cudaSmallBlocksMap) { 169 | auto &cudaBlocks = it.second; 170 | for (int i = 0; i < cudaBlocks.size(); i++) { 171 | if (cudaBlocks[i].data == ptr) { 172 | FreeSize[it.first] += cudaBlocks[i].size; 173 | cudaBlocks[i].is_allocated = false; 174 | // std::cout << "free a small block but not to OS, block id = " 175 | // << i 176 | // << ",size = " 177 | // << cudaBlocks[i].size << "B" 178 | // << std::endl; 179 | return; 180 | } 181 | } 182 | //若是大block,那不归还到OS 183 | auto &bigBlocks = cudaBigBlocksMap[it.first]; 184 | for (int i = 0; i < bigBlocks.size(); i++) { 185 | if (bigBlocks[i].data == ptr) { 186 | // std::cout << "free a big block but not to OS, block id = " 187 | // << i 188 | // << ",size = " 189 | // << cudaBlocks[i].size << "B" 190 | // << std::endl; 191 | bigBlocks[i].is_allocated = false; 192 | return; 193 | } 194 | } 195 | } 196 | // std::cout << "NOT found the ptr in blocks, so free the ptr to OS using cudaFree" 197 | // << std::endl; 198 | cudaFree(ptr); 199 | } 200 | }; -------------------------------------------------------------------------------- /src/models/basemodel.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include "src/utils/tensor.h" 5 | #include "src/models/common_params.h" 6 | #include "src/memory/allocator/base_allocator.h" 7 | #include "src/kernels/cublas_utils.h" 8 | // (m0dulo)note: 回调函数, 用于打印当前轮次对话的LLM生成内容 9 | using CallBack = std::function; 10 | 11 | class BaseModel{ 12 | public: 13 | std::string model_name; 14 | // (m0dulo)note: 必需的且所有模型子类都共有的4个数据成员 15 | cudaStream_t stream; 16 | cublasWrapper* cublas_wrapper; 17 | BaseAllocator* allocator; 18 | cudaDeviceProp* cuda_device_prop; 19 | BaseModel(cudaStream_t stream, 20 | cublasWrapper* cublas_wrapper, 21 | BaseAllocator* allocator, 22 | cudaDeviceProp* cuda_device_prop = nullptr): 23 | stream(stream), 24 | cublas_wrapper(cublas_wrapper), 25 | allocator(allocator), 26 | cuda_device_prop(cuda_device_prop){}; 27 | // (m0dulo)note: 3个纯虚函数API, 每个具体模型子类需要实现 28 | virtual void loadTokenizer(std::string file) = 0; 29 | virtual void loadWeights(std::string file) = 0; 30 | virtual void loadWeightsFromDummy() = 0; 31 | // (m0dulo)note: 3个纯虚函数API, 用于定义每轮对话的输入、历史记录和回复API 32 | // 根据历史信息和当前输入生成当前轮次的prompt 33 | virtual std::vector MakeInput(const std::string &history, int round, const std::string &input) = 0; 34 | // 根据当前轮次回复更新到history string 35 | virtual std::string MakeHistory(const std::string &history, int round, const std::string &input, const std::string &output) = 0; 36 | // 回复内容的返回接口 37 | virtual std::string Response(const std::vector& input, CallBack PrintRes) = 0; 38 | }; -------------------------------------------------------------------------------- /src/models/common_params.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/m0dulo/InferSpore/406aae6052b3dc51688e29b91797febb22ad6150/src/models/common_params.h -------------------------------------------------------------------------------- /src/utils/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_library(weightutils STATIC weight_utils.cu) 2 | set_property(TARGET weightutils PROPERTY CUDA_SEPARABLE_COMPILATION ON) 3 | set_property(TARGET weightutils PROPERTY POSITION_INDEPENDENT_CODE ON) 4 | set_property(TARGET weightutils PROPERTY 
CUDA_RESOLVE_DEVICE_SYMBOLS ON) -------------------------------------------------------------------------------- /src/utils/cuda_debug_utils.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | // usage: print_data<<<1, 1>>>() 6 | 7 | template 8 | __global__ void print_data(T* src1, bool is_target=false) { 9 | int tid = threadIdx.x; 10 | if(tid == 0) { 11 | printf("%dth = %f\n", tid, src1[tid]); 12 | printf("%dth = %f\n", tid + 1, src1[tid + 1]); 13 | // is_target is used to print the info for specified function, to avoid too much print info in screen. 14 | if (is_target){ 15 | printf("%dth = %f\n", tid + 128, src1[tid + 128]); 16 | printf("%dth = %f\n", tid + 129, src1[tid + 129]); 17 | printf("%dth = %f\n", tid + 130, src1[tid + 130]); 18 | printf("%dth = %f\n", tid + 131, src1[tid + 131]); 19 | printf("%dth = %f\n", tid + 1024, src1[tid + 1024]); 20 | } 21 | // printf("from_tensor/outlinearin data[%d] = %f\n", tid, src3[tid]); 22 | // printf("from_tensor/outlinearin data[%d] = %f\n", tid + 1, src3[tid+1]); 23 | // printf("from_tensor/outlinearin data[%d] = %f\n", tid + 128, src3[tid+128]); 24 | // printf("from_tensor/outlinearin data[%d] = %f\n", tid + 129, src3[tid+129]); 25 | 26 | // printf("qkvweight/outweight data[%d] = %f\n", tid, src2[tid]); 27 | // printf("qkvweight/outweight data[%d] = %f\n", tid + 1, src2[tid+1]); 28 | // printf("qkvweight/outweight data[%d] = %f\n", tid + 128, src2[tid+128]); 29 | // printf("qkvweight/outweight data[%d] = %f\n", tid + 129, src2[tid +129]); 30 | // printf("linear done\n"); 31 | 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /src/utils/debug_utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include "src/utils/tensor.h" 7 | #include "src/weights/base_weights.h" 8 | #include "src/utils/macro.h" 9 | // (m0dulo)note: overloaded 3 different function for saving intermediate output tensor to debug 10 | // because LLMs have many layers, so I provide some overloaded function to specify layer id to print specify layer output tensor to debug 11 | // after you save tensor into specified file ,you can turn to tests/unitests/test_data_compare.cu to specify file path to compare res with HF. 
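// Usage sketch (illustrative only; `q_buf` and `layer_id_tensor` are placeholder
// names, not symbols defined in this repo). All three overloads below copy the
// tensor to host and dump its raw row-major bytes under the hard-coded
// "/home/data/" prefix, so that directory must exist and be writable:
//
//   save_tensor(q_buf, "q_buf_after_rope.bin");                  // no layer id in filename
//   save_tensor(q_buf, "q_buf_after_rope.bin", layer_id_tensor); // TensorWrapper<int>* on CPU;
//                                                                // only layers 0..2 are saved
//   save_tensor(q_buf, "q_buf_after_rope.bin", 0);               // plain int layer id
//
// The layer-id overloads prepend "<id>_" to the filename; the dumped .bin can then
// be compared element-wise against a HuggingFace reference dump of the same tensor.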
12 | template 13 | void save_tensor(TensorWrapper* input, std::string filename){ 14 | int Bm = 0; 15 | int Bk = 0; 16 | if (input->shape.size() == 4){ 17 | Bm = input->shape[0] * input->shape[1]; 18 | Bk = input->shape[3] * input->shape[2]; 19 | } else if (input->shape.size() == 3){ 20 | Bm = input->shape[0]; 21 | Bk = input->shape[1] * input->shape[2]; 22 | } else if (input->shape.size() == 2){ 23 | Bm = input->shape[0]; 24 | Bk = input->shape[1]; 25 | } 26 | T* icpu = (T*)malloc(sizeof(T) * Bm * Bk); 27 | cudaMemcpy(icpu, input->data, sizeof(T) * Bm * Bk, cudaMemcpyDeviceToHost); 28 | std::ofstream F; 29 | std::cout << "saving intermediate tensor in " << filename << "\n"; 30 | F.open("/home/data/"+ filename, std::ofstream::binary); 31 | F.write(reinterpret_cast(icpu), sizeof(T)*Bm*Bk); 32 | F.close(); 33 | } 34 | 35 | template 36 | void save_tensor(TensorWrapper* input, std::string filename, TensorWrapper* layer_id){ 37 | int id = layer_id->getVal(); 38 | if (id > 2) { 39 | return; 40 | } 41 | int Bm = 0; 42 | int Bk = 0; 43 | if (input->shape.size() == 4){ 44 | Bm = input->shape[0] * input->shape[1]; 45 | Bk = input->shape[3] * input->shape[2]; 46 | } else if (input->shape.size() == 3){ 47 | Bm = input->shape[0]; 48 | Bk = input->shape[1] * input->shape[2]; 49 | } else if (input->shape.size() == 2){ 50 | Bm = input->shape[0]; 51 | Bk = input->shape[1]; 52 | } 53 | T* icpu = (T*)malloc(sizeof(T) * Bm * Bk); 54 | cudaMemcpy(icpu, input->data, sizeof(T) * Bm * Bk, cudaMemcpyDeviceToHost); 55 | std::ofstream F; 56 | std::cout << "saving intermediate tensor in " << filename << "\n"; 57 | F.open("/home/data/" + std::to_string(id) + "_" + filename, std::ofstream::binary); 58 | F.write(reinterpret_cast(icpu), sizeof(T)*Bm*Bk); 59 | F.close(); 60 | } 61 | 62 | template 63 | void save_tensor(TensorWrapper* input, std::string filename, int layer_id){ 64 | int id = layer_id; 65 | if (id > 2) { 66 | return; 67 | } 68 | int Bm = 0; 69 | int Bk = 0; 70 | if (input->shape.size() == 4){ 71 | Bm = input->shape[0] * input->shape[1]; 72 | Bk = input->shape[3] * input->shape[2]; 73 | } else if (input->shape.size() == 3){ 74 | Bm = input->shape[0]; 75 | Bk = input->shape[1] * input->shape[2]; 76 | } else if (input->shape.size() == 2){ 77 | Bm = input->shape[0]; 78 | Bk = input->shape[1]; 79 | } 80 | T* icpu = (T*)malloc(sizeof(T) * Bm * Bk); 81 | cudaMemcpy(icpu, input->data, sizeof(T) * Bm * Bk, cudaMemcpyDeviceToHost); 82 | std::ofstream F; 83 | std::cout << "saving intermediate tensor in " << filename << "\n"; 84 | F.open("/home/data/" + std::to_string(id) + "_" + filename, std::ofstream::binary); 85 | F.write(reinterpret_cast(icpu), sizeof(T)*Bm*Bk); 86 | F.close(); 87 | } 88 | -------------------------------------------------------------------------------- /src/utils/macro.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | // find the bugs faster 10 | #define CHECK(call) \ 11 | do \ 12 | { \ 13 | const cudaError_t error_code = call; \ 14 | if (error_code != cudaSuccess) \ 15 | { \ 16 | printf("CUDA Error:\n"); \ 17 | printf(" File: %s\n", __FILE__); \ 18 | printf(" Line: %d\n", __LINE__); \ 19 | printf(" Error code: %d\n", error_code); \ 20 | printf(" Error text: %s\n", \ 21 | cudaGetErrorString(error_code)); \ 22 | exit(1); \ 23 | } \ 24 | } while (0) 25 | 26 | static const char* _cudaGetErrorEnum(cudaError_t error) 27 | { 28 | return cudaGetErrorString(error); 29 | } 
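// Usage sketch (illustrative; d_buf, num_bytes, cublas_handle and my_kernel are
// caller-side placeholder names). The error-handling helpers in this header fall
// into two styles: CHECK() prints file/line and calls exit(1) on a failed CUDA
// runtime call, while CHECK_CUBLAS(), DeviceSyncAndCheckCudaError() and the
// LLM_CHECK* asserts (all defined further down) throw std::runtime_error instead:
//
//   CHECK(cudaMalloc(&d_buf, num_bytes));        // hard-exit on CUDA runtime errors
//   CHECK_CUBLAS(cublasCreate(&cublas_handle));  // throws, resolving the status text via
//                                                // the cublasStatus_t overload below
//   my_kernel<<<grid, block>>>(d_buf);
//   DeviceSyncAndCheckCudaError();               // sync, then surface async kernel errors
//   LLM_CHECK_WITH_INFO(d_buf != nullptr, "buffer must not be null");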
30 | 31 | static const char* _cudaGetErrorEnum(cublasStatus_t error) 32 | { 33 | switch (error) { 34 | case CUBLAS_STATUS_SUCCESS: 35 | return "CUBLAS_STATUS_SUCCESS"; 36 | 37 | case CUBLAS_STATUS_NOT_INITIALIZED: 38 | return "CUBLAS_STATUS_NOT_INITIALIZED"; 39 | 40 | case CUBLAS_STATUS_ALLOC_FAILED: 41 | return "CUBLAS_STATUS_ALLOC_FAILED"; 42 | 43 | case CUBLAS_STATUS_INVALID_VALUE: 44 | return "CUBLAS_STATUS_INVALID_VALUE"; 45 | 46 | case CUBLAS_STATUS_ARCH_MISMATCH: 47 | return "CUBLAS_STATUS_ARCH_MISMATCH"; 48 | 49 | case CUBLAS_STATUS_MAPPING_ERROR: 50 | return "CUBLAS_STATUS_MAPPING_ERROR"; 51 | 52 | case CUBLAS_STATUS_EXECUTION_FAILED: 53 | return "CUBLAS_STATUS_EXECUTION_FAILED"; 54 | 55 | case CUBLAS_STATUS_INTERNAL_ERROR: 56 | return "CUBLAS_STATUS_INTERNAL_ERROR"; 57 | 58 | case CUBLAS_STATUS_NOT_SUPPORTED: 59 | return "CUBLAS_STATUS_NOT_SUPPORTED"; 60 | 61 | case CUBLAS_STATUS_LICENSE_ERROR: 62 | return "CUBLAS_STATUS_LICENSE_ERROR"; 63 | } 64 | return ""; 65 | } 66 | 67 | template 68 | void check(T result, char const* const func, const char* const file, int const line) 69 | { 70 | if (result) { 71 | throw std::runtime_error(std::string("[TM][ERROR] CUDA runtime error: ") + (_cudaGetErrorEnum(result)) + " " 72 | + file + ":" + std::to_string(line) + " \n"); 73 | } 74 | } 75 | 76 | #define CHECK_CUBLAS(val) check((val), #val, __FILE__, __LINE__) 77 | 78 | inline void syncAndCheck(const char* const file, int const line) 79 | { 80 | cudaDeviceSynchronize(); 81 | cudaError_t result = cudaGetLastError(); 82 | if (result) { 83 | throw std::runtime_error(std::string("[TM][ERROR] CUDA runtime error: ") + (_cudaGetErrorEnum(result)) + " " 84 | + file + ":" + std::to_string(line) + " \n"); 85 | } 86 | } 87 | 88 | #define DeviceSyncAndCheckCudaError() syncAndCheck(__FILE__, __LINE__) 89 | 90 | [[noreturn]] inline void throwRuntimeError(const char* const file, int const line, std::string const& info = "") 91 | { 92 | throw std::runtime_error(std::string("[oneLLM][ERROR] ") + info + " Assertion fail: " + file + ":" 93 | + std::to_string(line) + " \n"); 94 | } 95 | 96 | inline void llmAssert(bool result, const char* const file, int const line, std::string const& info = "") 97 | { 98 | if (!result) { 99 | throwRuntimeError(file, line, info); 100 | } 101 | } 102 | 103 | #define LLM_CHECK(val) llmAssert(val, __FILE__, __LINE__) 104 | #define LLM_CHECK_WITH_INFO(val, info) \ 105 | do { \ 106 | bool is_valid_val = (val); \ 107 | if (!is_valid_val) { \ 108 | llmAssert(is_valid_val, __FILE__, __LINE__, (info)); \ 109 | } \ 110 | } while (0) 111 | -------------------------------------------------------------------------------- /src/utils/model_utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include "src/models/basemodel.h" 6 | #include "src/models/llama/llama.h" 7 | #include "src/utils/macro.h" 8 | #include "src/memory/allocator/cuda_allocator.h" 9 | #include "src/models/llama/llama_params.h" 10 | // (m0dulo) note: all LLM models are created in the header file, and I provided two ways, one is real weight model, the other is dummy weight model for functionality 11 | namespace llm { 12 | template 13 | BaseModel *CreateModelWithName(const std::string& model_name) { 14 | LLM_CHECK_WITH_INFO(model_name == "llama", "dont support other models except llama yet!"); 15 | int head_num = 32; 16 | int kv_head_num = 32; 17 | int head_size = 128; 18 | int inter_size = 11008; 19 | int num_layers = 32; 20 | int 
max_seq_len = 64; 21 | int vocab_size = 32000; 22 | int hidden_units = (head_num + 2 * kv_head_num) * head_size; 23 | int q_hidden_units = head_num * head_size; 24 | bool attn_bias = false; 25 | LLaMAAttentionStaticParams attn_static_params; 26 | attn_static_params.rotary_embedding_dim = 128; 27 | attn_static_params.rotary_embedding_base = 10000; 28 | attn_static_params.max_position_embeddings = 4096; 29 | attn_static_params.use_dynamic_ntk = false; // true is for dyn scaling rope 30 | cublasHandle_t cublas_handle; 31 | cublasLtHandle_t cublaslt_handle; 32 | cudaStream_t stream; 33 | cublasCreate(&cublas_handle); 34 | cublasSetMathMode(cublas_handle, CUBLAS_DEFAULT_MATH); 35 | cublasWrapper* cublas_wrapper = new cublasWrapper(cublas_handle, cublaslt_handle); 36 | cublas_wrapper->setFP32GemmConfig(); 37 | BaseAllocator* allocator = new CudaAllocator; 38 | cudaDeviceProp deviceProp; 39 | cudaGetDeviceProperties(&deviceProp, 0); 40 | BaseModel *model = new Llama(head_num, 41 | kv_head_num, 42 | head_size, 43 | inter_size, 44 | num_layers, 45 | vocab_size, 46 | attn_static_params, 47 | max_seq_len, 48 | stream, 49 | cublas_wrapper, 50 | allocator, 51 | &deviceProp); 52 | return model; 53 | } 54 | 55 | template 56 | std::unique_ptr CreateDummyLLMModel(std::string tokenizer_file){ 57 | BaseModel *model = CreateModelWithName("llama"); 58 | model->loadTokenizer(tokenizer_file); 59 | model->loadWeightsFromDummy(); 60 | return std::unique_ptr (model); 61 | } 62 | 63 | template 64 | std::unique_ptr CreateRealLLMModel(std::string model_dir, std::string tokenizer_file){ 65 | BaseModel *model = CreateModelWithName("llama"); 66 | std::cout << "start creating model..." << "\n"; 67 | model->loadTokenizer(tokenizer_file); 68 | model->loadWeights(model_dir); 69 | std::cout << "finish creating model..." << "\n"; 70 | return std::unique_ptr (model); 71 | } 72 | } // namespace llm -------------------------------------------------------------------------------- /src/utils/params.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | // (m0dulo) notes: some data structure to wrap many arguements of a function for simplicity 5 | using IntDict = std::unordered_map; 6 | using floatDict = std::unordered_map; -------------------------------------------------------------------------------- /src/utils/string_utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include // std::make_unique 3 | #include // std::stringstream 4 | #include 5 | #include 6 | //(m0dulo)note: this function allow us can self define print string 7 | template 8 | inline std::string fmtstr(const std::string& format, Args... args) 9 | { 10 | // This function came from a code snippet in stackoverflow under cc-by-1.0 11 | // https://stackoverflow.com/questions/2342162/stdstring-formatting-like-sprintf 12 | 13 | // Disable format-security warning in this function. 14 | int size_s = std::snprintf(nullptr, 0, format.c_str(), args...) 
+ 1; // Extra space for '\0' 15 | if (size_s <= 0) { 16 | throw std::runtime_error("Error during formatting."); 17 | } 18 | auto size = static_cast(size_s); 19 | std::unique_ptr buf(new char[size]); 20 | std::snprintf(buf.get(), size, format.c_str(), args...); 21 | return std::string(buf.get(), buf.get() + size - 1); // We don't want the '\0' inside 22 | } 23 | //(m0dulo)note: below two functions allow us can convert elements in vector or pointer to string 24 | template 25 | inline std::string vec2str(std::vector vec) 26 | { 27 | std::stringstream ss; 28 | ss << "("; 29 | if (!vec.empty()) { 30 | for (size_t i = 0; i < vec.size() - 1; ++i) { 31 | ss << vec[i] << ", "; 32 | } 33 | ss << vec.back(); 34 | } 35 | ss << ")"; 36 | return ss.str(); 37 | } 38 | 39 | template 40 | inline std::string arr2str(T* arr, size_t size) 41 | { 42 | std::stringstream ss; 43 | ss << "("; 44 | for (size_t i = 0; i < size - 1; ++i) { 45 | ss << arr[i] << ", "; 46 | } 47 | if (size > 0) { 48 | ss << arr[size - 1]; 49 | } 50 | ss << ")"; 51 | return ss.str(); 52 | } 53 | -------------------------------------------------------------------------------- /src/utils/tensor.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include "src/utils/string_utils.h" 11 | #include "src/utils/macro.h" 12 | enum Device 13 | { 14 | CPU_PINNED, 15 | CPU, 16 | GPU 17 | }; 18 | 19 | enum DataType 20 | { 21 | FP32, 22 | FP16, 23 | INT8, 24 | INT32, 25 | BOOL, 26 | BYTES, 27 | UNSUPPORTED 28 | }; 29 | 30 | template 31 | DataType getTensorType() 32 | { 33 | if (std::is_same::value || std::is_same::value) { 34 | return FP32; 35 | } 36 | else if (std::is_same::value || std::is_same::value) { 37 | return FP16; 38 | } 39 | else if (std::is_same::value || std::is_same::value) { 40 | return INT32; 41 | } 42 | else if (std::is_same::value || std::is_same::value) { 43 | return INT8; 44 | } 45 | else if (std::is_same::value || std::is_same::value) { 46 | return BOOL; 47 | } 48 | else if (std::is_same::value || std::is_same::value) { 49 | return BYTES; 50 | } 51 | else { 52 | return UNSUPPORTED; 53 | } 54 | } 55 | template 56 | class TensorWrapper; 57 | 58 | struct Tensor { 59 | Device location; 60 | DataType dtype; 61 | std::vector shape; 62 | 63 | Tensor() = default; 64 | 65 | Tensor(const Device location_, 66 | const DataType dtype_, 67 | const std::vector shape_): 68 | location(location_), 69 | dtype(dtype_), 70 | shape(shape_){} 71 | virtual int size() const { 72 | if (shape.size() == 0) { 73 | // TODO: add an reminder info 74 | return 0; 75 | } 76 | return std::accumulate(shape.begin(), shape.end(), (int)1, std::multiplies()); 77 | } 78 | // note: data's destructor is invoked by allocator's free API or cudaFree API 79 | // ~Tensor() { 80 | // if(data!=nullptr) { 81 | // delete data; 82 | // data = nullptr; 83 | // } 84 | // } 85 | template 86 | TensorWrapper* as(){ 87 | return static_cast*>(this); 88 | } 89 | 90 | std::string DeviceString() const 91 | { 92 | static const std::unordered_map devicetring{ 93 | {CPU, "CPU"}, {CPU_PINNED, "CPU_PINNED"}, {GPU, "GPU"}}; 94 | return devicetring.at(location); 95 | } 96 | 97 | virtual std::string toString() const 98 | { 99 | std::string device_str = DeviceString(); 100 | 101 | static const std::unordered_map type_to_string{ 102 | {INT8, "INT8"}, 103 | {INT32,"INT32"}, 104 | {FP16, "FP16"}, 105 | {FP32, "FP32"}, 106 | 107 | }; 108 | return 
fmtstr("Tensor[where=%s, type=%s, shape=%s]", 109 | device_str.c_str(), 110 | type_to_string.at(dtype).c_str(), 111 | vec2str(shape).c_str()); 112 | } 113 | }; 114 | 115 | template 116 | class TensorWrapper: public Tensor { 117 | public: 118 | T* data; 119 | // cant declare shape's type to std::vector&, because we usually pass a tmp var, which cant be non-const refer 120 | TensorWrapper(Device location, DataType dtype, std::vector shape): 121 | Tensor(location, dtype, shape){} 122 | 123 | TensorWrapper(Device location, DataType dtype, std::vector shape, T* data): 124 | Tensor(location, dtype, shape), 125 | data(data){ 126 | DataType in_dtype = getTensorType(); 127 | LLM_CHECK_WITH_INFO(in_dtype == dtype, "when build TensorWrapper, the passed in data type should be same as dtype in params"); 128 | } 129 | 130 | // friend bool operator==(Tensor& t1, Tensor& t2); 131 | virtual int size() const { 132 | if (data == nullptr || shape.size() == 0) { 133 | // TODO: add an reminder info 134 | return 0; 135 | } 136 | return std::accumulate(shape.begin(), shape.end(), (int)1, std::multiplies()); 137 | } 138 | 139 | inline T getVal(int id) const { 140 | //TODO: need some boundry and device check 141 | LLM_CHECK(location == CPU); 142 | return data[id]; 143 | } // only available on CPU by [] 144 | 145 | inline T getVal() const 146 | { 147 | // TODO: add type check, this is very important, because we often naturally access GPU data, which is wrong 148 | // for example, I am in transpose kernel to use layer_id->getVal(), which is wrong 149 | LLM_CHECK(location == CPU); 150 | return getVal(0); 151 | } 152 | 153 | inline T* getPtr() const { 154 | //TODO: need some boundry check 155 | return (T*)data; 156 | } 157 | 158 | inline T* getPtrByOffset(int offset) const { 159 | //TODO: need some boundry check 160 | return (T*)data + offset; 161 | } 162 | // for debug 163 | virtual std::string toString() const 164 | { 165 | std::string device_str = DeviceString(); 166 | 167 | static const std::unordered_map type_to_string{ 168 | {INT8, "INT8"}, 169 | {FP16, "FP16"}, 170 | {FP32, "FP32"}, 171 | 172 | }; 173 | return fmtstr("Tensor[where=%s, type=%s, shape=%s, data=%p]", 174 | device_str.c_str(), 175 | type_to_string.at(dtype).c_str(), 176 | vec2str(shape).c_str(), 177 | data); 178 | } 179 | }; 180 | 181 | 182 | 183 | //I cant check if the data pointer in TensorWrapper is nullptr, because the val in tensormap is tensor* 184 | //so I must check the data pointer using LLM_CHECK_WITH_INFO before insert into tensormap. 
185 | struct TensorMap { 186 | std::unordered_map tensor_map_; 187 | 188 | TensorMap() = default; 189 | TensorMap(std::initializer_list> tensor_map){ 190 | for (auto& pair : tensor_map) { 191 | if (isValid(pair.second)) { 192 | insert(pair.first, pair.second); 193 | } 194 | else { 195 | // std::cout << "this is not a valid tensor, skip to insert into tensormap" << std::endl; 196 | LLM_CHECK_WITH_INFO(isValid(pair.second),fmtstr("%s is not a valid tensor, skipping insert into TensorMap", pair.first.c_str())); 197 | } 198 | } 199 | } 200 | 201 | TensorMap(const std::unordered_map& tensor_map) { 202 | // C++ 11 traverse 203 | // for (auto& kv : tensor_map) { 204 | // C++ 98 traverse 205 | for(auto it = tensor_map_.begin(); it != tensor_map_.end(); it++) { 206 | // if (isValid(kv.second)) { 207 | // insert(kv.first, kv.second); 208 | // } 209 | if (isValid(it->second)) { 210 | insert(it->first, it->second); 211 | } 212 | else { 213 | // TODO: add a reminder info 214 | } 215 | } 216 | }; 217 | 218 | ~TensorMap(){ 219 | tensor_map_.clear(); 220 | } 221 | 222 | inline size_t size() const 223 | { 224 | return tensor_map_.size(); 225 | } 226 | 227 | inline bool isExist(const std::string& key) const 228 | { 229 | return tensor_map_.find(key) != tensor_map_.end(); 230 | } 231 | 232 | inline bool isValid(const Tensor* tensor) 233 | { 234 | return tensor->size() > 0; 235 | } 236 | // 增 237 | inline void insert(const std::string& key, Tensor* value) 238 | { 239 | // TODO: add a check to check key is unique and value is valid 240 | // tensor_map_.insert({key, value}); 241 | tensor_map_[key] = value; 242 | } 243 | 244 | inline void insert(std::pair p) 245 | { 246 | tensor_map_.insert(p); 247 | } 248 | //删 249 | 250 | //改 251 | 252 | //查 253 | inline Tensor* at(const std::string& key) 254 | { 255 | LLM_CHECK_WITH_INFO(isExist(key), fmtstr("Cannot find a tensor of name %s in the tensor map (keys: %s)", 256 | key.c_str(), 257 | vec2str(keys()).c_str())); 258 | return tensor_map_.at(key); 259 | 260 | } 261 | 262 | inline Tensor* operator[](const std::string& key) 263 | { 264 | LLM_CHECK_WITH_INFO(isExist(key), fmtstr("Cannot find a tensor of name %s in the tensor map (keys: %s)", 265 | key.c_str(), 266 | vec2str(keys()).c_str())); 267 | return tensor_map_.at(key); 268 | 269 | } 270 | // TODO: Now cant use get* function in TensorMap struct, cause the value is Tensor*, not TensorWrapper,need to enhance 271 | // template 272 | // inline T getVal(const std::string& key) const 273 | // { 274 | // // TODO: add a check to check key is existed 275 | // return tensor_map_.at(key).getVal(); 276 | // } 277 | // template 278 | // inline T getValByOffset(const std::string& key, int index) const 279 | // { 280 | // // TODO: add a check to check key is existed 281 | // return tensor_map_.at(key).getVal(index); 282 | // } 283 | // //default get ptr with offset 0 284 | // template 285 | // inline T* getPtr(const std::string& key) const 286 | // { 287 | // // TODO: add a check to check key is existed 288 | // return tensor_map_.at(key).getPtr(); 289 | // } 290 | // //get ptr with specified offset 291 | // template 292 | // inline T* getPtrWithOffset(const std::string& key, int index) const 293 | // { 294 | // // TODO: add a check to check key is existed 295 | // return tensor_map_.at(key).getPtrByOffset(index); 296 | // } 297 | 298 | //for debug 299 | std::vector keys() const 300 | { 301 | std::vector key_names; 302 | for (auto& kv : tensor_map_) { 303 | key_names.push_back(kv.first); 304 | } 305 | return key_names; 306 | } 
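// Usage sketch for TensorMap (illustrative; mask_tensor and layer_id_tensor are
// placeholders). Values are stored as plain Tensor*, so callers downcast with
// Tensor::as<T>() before touching typed data; invalid (size() == 0) tensors are
// rejected at construction and missing keys raise std::runtime_error via
// LLM_CHECK_WITH_INFO:
//
//   TensorMap inputs{{"attention_mask", mask_tensor},
//                    {"layer_id",       layer_id_tensor}};
//   Tensor* t    = inputs["attention_mask"];   // or inputs.at("attention_mask")
//   auto*   mask = t->as<float>();             // back to TensorWrapper<float>*
//   float*  mask_data = mask->getPtr();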
307 | // 打印出tensormap中的所有key 308 | std::string toString() 309 | { 310 | std::stringstream ss; 311 | ss << "{"; 312 | std::vector key_names = keys(); 313 | for (size_t i = 0; i < tensor_map_.size(); ++i) { 314 | ss << key_names[i] << ": " << at(key_names[i])->toString(); 315 | if (i < tensor_map_.size() - 1) { 316 | ss << ", "; 317 | } 318 | } 319 | ss << "}"; 320 | return ss.str(); 321 | } 322 | }; 323 | -------------------------------------------------------------------------------- /src/utils/vectorize_utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | //(m0dulo)note: below 5 overloaded function can convert different scalar type data to specified vector type data. 6 | template 7 | inline __device__ T_OUT scalar_cast_vec(T_IN val) 8 | { 9 | return val; 10 | } 11 | 12 | template<> 13 | inline __device__ half2 scalar_cast_vec(float val) 14 | { 15 | return __float2half2_rn(val); 16 | } 17 | 18 | template<> 19 | inline __device__ float4 scalar_cast_vec(float val) 20 | { 21 | return make_float4(val, val, val, val); 22 | } 23 | 24 | template<> 25 | inline __device__ float2 scalar_cast_vec(float val) 26 | { 27 | return make_float2(val, val); 28 | } 29 | 30 | template<> 31 | inline __device__ half2 scalar_cast_vec(half val) 32 | { 33 | //(m0dulo)note: __half2half2 cant be parsed by my nvcc compiler, so I give it up 34 | //return __half2half2(val); 35 | half2 res; 36 | res.x = val; 37 | res.y = val; 38 | return res; 39 | } 40 | 41 | template 42 | struct Vec { 43 | using Type = T; 44 | static constexpr int size = 0; 45 | }; 46 | template<> 47 | struct Vec { 48 | using Type = half2; 49 | static constexpr int size = 2; 50 | }; 51 | template<> 52 | struct Vec { 53 | using Type = float4; 54 | static constexpr int size = 4; 55 | }; 56 | //(m0dulo)note: temply dont know which LLM use two continuous elements do RoPE 57 | struct TwoFloat2{ 58 | float2 x; 59 | float2 y; 60 | }; 61 | -------------------------------------------------------------------------------- /src/utils/weight_utils.cu: -------------------------------------------------------------------------------- 1 | #include "src/utils/weight_utils.h" 2 | 3 | template 4 | inline __device__ T_OUT type_cast(T_IN val) { 5 | return val; 6 | } 7 | template<> 8 | inline __device__ float type_cast(half val) { 9 | return __half2float(val); 10 | } 11 | 12 | template<> 13 | inline __device__ half type_cast(float val) { 14 | return __float2half(val); 15 | } 16 | 17 | template 18 | void GPUMalloc(T** ptr, size_t size) 19 | { 20 | LLM_CHECK_WITH_INFO(size >= ((size_t)0), "Ask cudaMalloc size " + std::to_string(size) + "< 0 is invalid."); 21 | CHECK(cudaMalloc((void**)(ptr), sizeof(T) * size)); 22 | } 23 | template void GPUMalloc(float** ptr, size_t size); 24 | template void GPUMalloc(half** ptr, size_t size); 25 | 26 | template 27 | void GPUFree(T* ptr) 28 | { 29 | if (ptr != NULL) { 30 | CHECK(cudaFree(ptr)); 31 | ptr = NULL; 32 | } 33 | } 34 | template void GPUFree(float* ptr); 35 | template void GPUFree(half* ptr); 36 | 37 | template 38 | void cudaH2Dcpy(T* tgt, const T* src, const size_t size) 39 | { 40 | CHECK(cudaMemcpy(tgt, src, sizeof(T) * size, cudaMemcpyHostToDevice)); 41 | } 42 | 43 | template void cudaH2Dcpy(float* tgt, const float* src, const size_t size); 44 | template void cudaH2Dcpy(half* tgt, const half* src, const size_t size); 45 | 46 | template 47 | __global__ void type_conversion(T_OUT* dst, const T_IN* src, const int size) 48 | { 49 | int gtid 
= threadIdx.x + blockIdx.x * blockDim.x; 50 | int total_thread_nums = blockDim.x * gridDim.x; 51 | for (int index = gtid; index < size; index += total_thread_nums) { 52 | dst[index] = type_cast(src[index]); 53 | } 54 | } 55 | 56 | template 57 | void cuda_type_conversion(T_OUT* dst, const T_IN* src, const int size) 58 | { 59 | dim3 grid(128); 60 | dim3 block(128); 61 | type_conversion<<>>(dst, src, size); 62 | } 63 | 64 | template void cuda_type_conversion(float* dst, const half* src, const int size); 65 | template void cuda_type_conversion(half* dst, const float* src, const int size); 66 | 67 | 68 | template 69 | std::vector loadWeightFromBinHelper(std::vector shape, std::string filename) 70 | { 71 | if (shape.size() > 2) { 72 | printf("[ERROR] shape should have less than two dims \n"); 73 | return std::vector(); 74 | } 75 | size_t dim0 = shape[0], dim1 = 1; 76 | if (shape.size() == 2) { 77 | dim1 = shape[1]; 78 | } 79 | size_t size = dim0 * dim1; 80 | if (size == 0) { 81 | std::cout << "shape is zero, skip loading weight from file: " << filename << std::endl; 82 | return std::vector(); 83 | } 84 | 85 | std::vector host_array(size); 86 | std::ifstream in(filename, std::ios::in | std::ios::binary); 87 | if (!in.is_open()) { 88 | std::cout << "file" << filename << "cannot be opened, loading model fails!" << std::endl; 89 | return std::vector(); 90 | } 91 | 92 | size_t loaded_data_size = sizeof(T) * size; 93 | in.seekg(0, in.end); 94 | in.seekg(0, in.beg); 95 | 96 | std::cout << "Read " << std::to_string(loaded_data_size) << " bytes from " << filename << std::endl; 97 | in.read((char*)host_array.data(), loaded_data_size); 98 | 99 | size_t in_get_size = in.gcount(); 100 | if (in_get_size != loaded_data_size) { 101 | return std::vector(); 102 | } 103 | in.close(); 104 | 105 | return host_array; 106 | } 107 | 108 | template 109 | struct loadWeightFromBin 110 | { 111 | public: 112 | static void internalFunc(T_OUT* ptr, std::vector shape, std::string filename) { 113 | std::vector host_array = loadWeightFromBinHelper(shape, filename); 114 | if (host_array.empty()) { 115 | return; 116 | } 117 | 118 | cudaH2Dcpy(ptr, host_array.data(), host_array.size()); 119 | return; 120 | } 121 | }; 122 | 123 | template 124 | struct loadWeightFromBin 125 | { 126 | public: 127 | static void internalFunc(T_OUT* ptr, std::vector shape, std::string filename) { 128 | std::vector host_array = loadWeightFromBinHelper(shape, filename); 129 | if (host_array.empty()) { 130 | return; 131 | } 132 | 133 | T_FILE* ptr_tmp; 134 | GPUMalloc(&ptr_tmp, host_array.size()); 135 | cudaH2Dcpy(ptr_tmp, host_array.data(), host_array.size()); 136 | cuda_type_conversion(ptr, ptr_tmp, host_array.size()); 137 | GPUFree(ptr_tmp); 138 | return; 139 | } 140 | }; 141 | 142 | 143 | // template 144 | // typename std::enable_if::value, int>::type loadWeightFromBin(T_OUT* ptr, std::vector shape, std::string filename) 145 | // { 146 | // std::vector host_array = loadWeightFromBinHelper(shape, filename); 147 | 148 | // if (host_array.empty()) { 149 | // return 0; 150 | // } 151 | 152 | // cudaH2Dcpy(ptr, host_array.data(), host_array.size()); 153 | // return 0; 154 | // } 155 | 156 | 157 | template struct loadWeightFromBin; 158 | template struct loadWeightFromBin; 159 | template struct loadWeightFromBin; 160 | template struct loadWeightFromBin; 161 | -------------------------------------------------------------------------------- /src/utils/weight_utils.h: -------------------------------------------------------------------------------- 1 | #pragma 
once 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include "src/utils/macro.h" 8 | 9 | template 10 | void GPUMalloc(T** ptr, size_t size); 11 | 12 | template 13 | void GPUFree(T* ptr); 14 | 15 | template ::value> struct loadWeightFromBin{ 16 | public: 17 | static void internalFunc(T_OUT* ptr, std::vector shape, std::string filename); 18 | }; 19 | -------------------------------------------------------------------------------- /src/weights/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(llama) -------------------------------------------------------------------------------- /src/weights/base_weights.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | enum class WeightType 6 | { 7 | FP32_W, 8 | FP16_W, 9 | INT8_W, 10 | UNSUPPORTED_W 11 | }; 12 | 13 | template 14 | inline WeightType getWeightType() 15 | { 16 | if (std::is_same::value || std::is_same::value) { 17 | return WeightType::FP32_W; 18 | } 19 | else if (std::is_same::value || std::is_same::value) { 20 | return WeightType::FP16_W; 21 | } 22 | else if (std::is_same::value || std::is_same::value) { 23 | return WeightType::INT8_W; 24 | } 25 | else { 26 | return WeightType::UNSUPPORTED_W; 27 | } 28 | } 29 | template 30 | struct BaseWeight { 31 | std::vector shape; 32 | T* data; 33 | WeightType type; 34 | T* bias; 35 | }; -------------------------------------------------------------------------------- /src/weights/llama/CMakeLists.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/m0dulo/InferSpore/406aae6052b3dc51688e29b91797febb22ad6150/src/weights/llama/CMakeLists.txt -------------------------------------------------------------------------------- /src/weights/llama/attention_weights.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "src/weights/base_weights.h" 3 | template 4 | struct LLaMAattentionWeights { 5 | BaseWeight q; 6 | BaseWeight k; 7 | BaseWeight v; 8 | BaseWeight qkv; 9 | BaseWeight output; 10 | }; 11 | -------------------------------------------------------------------------------- /src/weights/llama/embedding_weights.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "src/weights/base_weights.h" 3 | template 4 | struct EmbeddingWeight: public BaseWeight { 5 | }; 6 | -------------------------------------------------------------------------------- /src/weights/llama/ffn_weights.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "src/weights/base_weights.h" 3 | template 4 | struct LLaMAFFNWeights { 5 | BaseWeight gate; 6 | BaseWeight up; 7 | BaseWeight down; 8 | BaseWeight gateAndup; 9 | }; 10 | -------------------------------------------------------------------------------- /src/weights/llama/layer_weights.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include "src/weights/llama/layer_weights.h" 3 | #include "src/utils/macro.h" 4 | template 5 | LlamaLayerWeight::LlamaLayerWeight(int head_num, 6 | int kv_head_num, 7 | int head_size, 8 | int inter_size, 9 | WeightType weight_type, 10 | bool attn_bias): 11 | head_num(head_num), 12 | kv_head_num(kv_head_num), 13 | head_size(head_size), 14 | hidden_units(head_num * head_size), 15 | 
inter_size(inter_size), 16 | weight_type(weight_type), 17 | attn_bias(attn_bias) 18 | { 19 | // init weights structure and cudamalloc for weights 20 | CHECK(cudaMalloc((void**)&attn_norm_weight.gamma, sizeof(T) * hidden_units)); 21 | CHECK(cudaMalloc((void**)&ffn_norm_weight.gamma, sizeof(T) * hidden_units)); 22 | self_attn_weight.qkv.type = weight_type; 23 | self_attn_weight.qkv.shape = {(head_num + 2 * kv_head_num) * head_size, hidden_units}; 24 | CHECK(cudaMalloc((void**)&self_attn_weight.qkv.data, sizeof(T) * hidden_units * (head_num + 2 * kv_head_num) * head_size)); 25 | self_attn_weight.output.type = weight_type; 26 | self_attn_weight.output.shape = {hidden_units, hidden_units}; 27 | CHECK(cudaMalloc((void**)&self_attn_weight.output.data, sizeof(T) * hidden_units * hidden_units)); 28 | if (attn_bias) { 29 | CHECK(cudaMalloc((void**)&self_attn_weight.qkv.bias, sizeof(T) * (head_num + 2 * kv_head_num) * head_size)); 30 | CHECK(cudaMalloc((void**)&self_attn_weight.output.bias, sizeof(T) * hidden_units)); 31 | } 32 | // (m0dulo)note: we concat gate linear weight and up linear weight to one weight tensor for performance improvement 33 | ffn_weight.gateAndup.type = weight_type; 34 | ffn_weight.down.type = weight_type; 35 | ffn_weight.gateAndup.shape = {2 * inter_size, hidden_units}; 36 | // ffn_weight.up.shape = {hidden_units, inter_size}; 37 | ffn_weight.down.shape = {hidden_units, inter_size}; 38 | CHECK(cudaMalloc((void**)&ffn_weight.gateAndup.data, sizeof(T) * hidden_units * 2 * inter_size)); 39 | // CHECK(cudaMalloc((void**)&ffn_weight.up.data, hidden_units * inter_size)); 40 | CHECK(cudaMalloc((void**)&ffn_weight.down.data, sizeof(T) * hidden_units * inter_size)); 41 | } 42 | // (m0dulo)note: weight from HF is always half type, and if we want run fp32 inference, we should convert half weight to fp32 weight in tools/weights_convert.py 43 | // (m0dulo)note: shape and data of ffn weight downloaded form HF are transposed, so we should carefully declare shape here 44 | template 45 | void LlamaLayerWeight::loadWeights(std::string weight_path, WeightType weight_type) // weighttype参数比较多余 46 | { 47 | loadWeightFromBin::internalFunc(attn_norm_weight.gamma, {hidden_units}, weight_path + ".input_layernorm.weight.bin"); 48 | loadWeightFromBin::internalFunc(ffn_norm_weight.gamma, {hidden_units}, weight_path + ".post_attention_layernorm.weight.bin"); 49 | 50 | loadWeightFromBin::internalFunc(self_attn_weight.qkv.data, {(head_num + 2 * kv_head_num) * head_size, hidden_units}, weight_path + ".self_attn.qkv.weight.bin"); 51 | loadWeightFromBin::internalFunc(self_attn_weight.output.data, {hidden_units, hidden_units}, weight_path + ".self_attn.o_proj.weight.bin"); 52 | loadWeightFromBin::internalFunc(ffn_weight.gateAndup.data, {2 * inter_size, hidden_units}, weight_path + ".mlp.gate_up_proj.weight.bin"); 53 | // loadWeightFromBin::internalFunc(ffn_weight.up.data, {hidden_units, inter_size}, weight_path + ".mlp.up_proj.weight.bin"); 54 | loadWeightFromBin::internalFunc(ffn_weight.down.data, {hidden_units, inter_size}, weight_path + ".mlp.down_proj.weight.bin"); 55 | if (attn_bias) {//TODO 56 | loadWeightFromBin::internalFunc(self_attn_weight.qkv.bias, {(head_num + 2 * kv_head_num) * head_size}, weight_path + ".attention.wqkv.bias.bin"); 57 | loadWeightFromBin::internalFunc(self_attn_weight.output.bias, {head_num * head_size}, weight_path + ".attention.wo.bias.bin"); 58 | } else { 59 | self_attn_weight.qkv.bias = nullptr; 60 | self_attn_weight.output.bias = nullptr; 61 | ffn_weight.down.bias = nullptr; 62 
| } 63 | // (m0dulo)note: below code lines can be enabled when I dont support qkvbiasandrope and fusedbiasaddresidual's bias nullptr case. 64 | //T* d_dummy_qkv_bias; 65 | //GPUMalloc(&d_dummy_qkv_bias, sizeof(T) * (head_num + 2 * kv_head_num) * head_size); 66 | //cudaMemset((void*)d_dummy_qkv_bias, 0, sizeof(T) * (head_num + 2 * kv_head_num) * head_size); 67 | //self_attn_weight.qkv.bias = (T*)d_dummy_qkv_bias; 68 | 69 | //T* d_dummy_output_bias; 70 | //GPUMalloc(&d_dummy_output_bias, sizeof(T) * head_num * head_size); 71 | //cudaMemset((void*)d_dummy_output_bias, 0, sizeof(T) * head_num * head_size); 72 | //self_attn_weight.output.bias = (T*)d_dummy_output_bias; 73 | 74 | //T* d_dummy_ffn_down_bias; 75 | //GPUMalloc(&d_dummy_ffn_down_bias, sizeof(T) * hidden_units); 76 | //cudaMemset((void*)d_dummy_ffn_down_bias, 0, sizeof(T) * hidden_units); 77 | //ffn_weight.down.bias = (T*)d_dummy_ffn_down_bias; 78 | } 79 | 80 | // (m0dulo)note: load dummy model/weight API, is used to the time when you want test inference performance only 81 | template 82 | void LlamaLayerWeight::loadWeights() 83 | { 84 | T* d_dummy_attn_norm_weight; 85 | T* d_dummy_ffn_norm_weight; 86 | T* d_dummy_qkv_weights; 87 | //T* d_dummy_qkv_bias; 88 | T* d_dummy_output_weights; 89 | T* d_dummy_output_bias; 90 | T* d_dummy_ffn_down; 91 | T* d_dummy_ffn_down_bias; 92 | T* d_dummy_ffn_gate_up; 93 | // T* d_dummy_ffn_up; 94 | CHECK(cudaMalloc((void**)&d_dummy_attn_norm_weight, sizeof(T) * hidden_units)); 95 | CHECK(cudaMalloc((void**)&d_dummy_ffn_norm_weight, sizeof(T) * hidden_units)); 96 | CHECK(cudaMalloc((void**)&d_dummy_qkv_weights, sizeof(T) * hidden_units * (head_num + 2 * kv_head_num) * head_size)); 97 | // CHECK(cudaMalloc((void**)&d_dummy_qkv_bias, sizeof(T) * (head_num + 2 * kv_head_num) * head_size)); 98 | CHECK(cudaMalloc((void**)&d_dummy_output_weights, sizeof(T) * hidden_units * hidden_units)); 99 | CHECK(cudaMalloc((void**)&d_dummy_output_bias, sizeof(T) * hidden_units)); 100 | CHECK(cudaMalloc((void**)&d_dummy_ffn_down, sizeof(T) * hidden_units * inter_size)); 101 | CHECK(cudaMalloc((void**)&d_dummy_ffn_down_bias, sizeof(T) * hidden_units)); 102 | CHECK(cudaMalloc((void**)&d_dummy_ffn_gate_up, sizeof(T) * hidden_units * 2 * inter_size)); 103 | // CHECK(cudaMalloc(&d_dummy_ffn_up, sizeof(T) * hidden_units * inter_size)); 104 | 105 | T* h_dummy_attn_norm_weight = (T*)malloc(sizeof(T) * hidden_units); 106 | T* h_dummy_ffn_norm_weight = (T*)malloc(sizeof(T) * hidden_units); 107 | T* h_dummy_qkv_weights = (T*)malloc(sizeof(T) * hidden_units * (head_num + 2 * kv_head_num) * head_size); 108 | // T* h_dummy_qkv_bias = (T*)malloc(sizeof(T) * (head_num + 2 * kv_head_num) * head_size); 109 | T* h_dummy_output_weights = (T*)malloc(sizeof(T) * hidden_units * hidden_units); 110 | T* h_dummy_output_bias = (T*)malloc(sizeof(T) * hidden_units); 111 | T* h_dummy_ffn_down = (T*)malloc(sizeof(T) * hidden_units * inter_size); 112 | T* h_dummy_ffn_down_bias = (T*)malloc(sizeof(T) * hidden_units); 113 | T* h_dummy_ffn_gate_up = (T*)malloc(sizeof(T) * hidden_units * 2 * inter_size); 114 | // T* h_dummy_ffn_up = (T*)malloc(sizeof(T) * hidden_units * inter_size); 115 | 116 | for (int i = 0; i < hidden_units; i++){ 117 | h_dummy_attn_norm_weight[i] = (T)(rand() % 100 / (float)100000); 118 | h_dummy_ffn_norm_weight[i] = (T)(rand() % 100 / (float)100000); 119 | h_dummy_output_bias[i] = (T)(rand() % 100 / (float)100000); 120 | h_dummy_ffn_down_bias[i] = (T)(rand() % 100 / (float)100000); 121 | } 122 | //for (int i = 0; i < (head_num + 2 * 
kv_head_num) * head_size; i++) { 123 | // h_dummy_qkv_bias[i] = (T)(rand() % 100 / (float)100000); 124 | //} 125 | for (int i = 0; i < hidden_units * inter_size; i++) { 126 | h_dummy_ffn_down[i] = (T)(rand() % 100 / (float)100000); 127 | } 128 | for (int i = 0; i < hidden_units * 2 * inter_size; i++) { 129 | h_dummy_ffn_gate_up[i] = (T)(rand() % 100 / (float)100000); 130 | // h_dummy_ffn_up[i] = (T)1.0f; 131 | } 132 | for (int i = 0; i < hidden_units * hidden_units; i++) { 133 | h_dummy_output_weights[i] = (T)(rand() % 100 / (float)100000); 134 | } 135 | for (int i = 0; i < hidden_units * (head_num + 2 * kv_head_num) * head_size; i++) { 136 | h_dummy_qkv_weights[i] = (T)(rand() % 100 / (float)100000); 137 | } 138 | CHECK(cudaMemcpy(d_dummy_attn_norm_weight, h_dummy_attn_norm_weight, sizeof(T) * hidden_units, cudaMemcpyHostToDevice)); 139 | CHECK(cudaMemcpy(d_dummy_ffn_norm_weight, h_dummy_ffn_norm_weight, sizeof(T) * hidden_units, cudaMemcpyHostToDevice)); 140 | CHECK(cudaMemcpy(d_dummy_qkv_weights, h_dummy_qkv_weights, sizeof(T) * hidden_units * (head_num + 2 * kv_head_num) * head_size, cudaMemcpyHostToDevice)); 141 | //CHECK(cudaMemcpy(d_dummy_qkv_bias, h_dummy_qkv_bias, sizeof(T) * (head_num + 2 * kv_head_num) * head_size, cudaMemcpyHostToDevice)); 142 | CHECK(cudaMemcpy(d_dummy_output_weights, h_dummy_output_weights, sizeof(T) * hidden_units * hidden_units, cudaMemcpyHostToDevice)); 143 | CHECK(cudaMemcpy(d_dummy_output_bias, h_dummy_output_bias, sizeof(T) * hidden_units, cudaMemcpyHostToDevice)); 144 | CHECK(cudaMemcpy(d_dummy_ffn_down, h_dummy_ffn_down, sizeof(T) * hidden_units * inter_size, cudaMemcpyHostToDevice)); 145 | CHECK(cudaMemcpy(d_dummy_ffn_down_bias, h_dummy_ffn_down_bias, sizeof(T) * hidden_units, cudaMemcpyHostToDevice)); 146 | CHECK(cudaMemcpy(d_dummy_ffn_gate_up, h_dummy_ffn_gate_up, sizeof(T) * hidden_units * 2 * inter_size, cudaMemcpyHostToDevice)); 147 | // CHECK(cudaMemcpy(d_dummy_ffn_up, h_dummy_ffn_up, sizeof(T) * hidden_units * inter_size, cudaMemcpyHostToDevice)); 148 | // before kernel launch, the ptr is always void*, when luanching kernel, ptr type will be cast to float* or T* 149 | attn_norm_weight.gamma = d_dummy_attn_norm_weight; 150 | ffn_norm_weight.gamma = d_dummy_ffn_norm_weight; 151 | self_attn_weight.qkv.data = d_dummy_qkv_weights; 152 | self_attn_weight.qkv.bias = nullptr; 153 | self_attn_weight.output.data = d_dummy_output_weights; 154 | self_attn_weight.output.bias = d_dummy_output_bias; 155 | ffn_weight.gateAndup.data = d_dummy_ffn_gate_up; 156 | //ffn_weight.up.data = d_dummy_ffn_up; 157 | ffn_weight.down.data = d_dummy_ffn_down; 158 | ffn_weight.down.bias = d_dummy_ffn_down_bias; 159 | } 160 | 161 | template 162 | void freeWeights(BaseWeight& weights) 163 | { 164 | cudaFree(weights.data); 165 | if(weights.bias != nullptr) { 166 | cudaFree(weights.bias); 167 | } 168 | 169 | weights.data = nullptr; 170 | weights.bias = nullptr; 171 | } 172 | template 173 | LlamaLayerWeight::~LlamaLayerWeight() 174 | { 175 | // free norm weights ptr 176 | cudaFree(attn_norm_weight.gamma); 177 | cudaFree(ffn_norm_weight.gamma); 178 | // free other weights, including data and bias 179 | freeWeights(self_attn_weight.qkv); 180 | freeWeights(self_attn_weight.output); 181 | freeWeights(ffn_weight.gateAndup); 182 | // freeWeights(ffn_weight.up); 183 | freeWeights(ffn_weight.down); 184 | } 185 | // template instantial required in linking time 186 | template class LlamaLayerWeight; 187 | template class LlamaLayerWeight; 
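// Usage sketch (illustrative, not invoked anywhere in this file). The real driver is
// LlamaWeight in llama_weights.cc, which constructs one LlamaLayerWeight per layer and
// calls loadWeights() with a per-layer prefix of the form "<model_dir>model.layers.<i>".
// With that prefix, loadWeights(weight_path, weight_type) expects the following per-layer
// binaries (produced offline by the conversion scripts under tools/):
//
//   <prefix>.input_layernorm.weight.bin
//   <prefix>.post_attention_layernorm.weight.bin
//   <prefix>.self_attn.qkv.weight.bin
//   <prefix>.self_attn.o_proj.weight.bin
//   <prefix>.mlp.gate_up_proj.weight.bin
//   <prefix>.mlp.down_proj.weight.bin
//
// For a quick functional test without any files, the no-argument overload fills the layer
// with small random values (LLaMA-7B-like sizes shown, matching model_utils.h):
//
//   auto* layer0 = new LlamaLayerWeight<float>(/*head_num*/ 32, /*kv_head_num*/ 32,
//                                              /*head_size*/ 128, /*inter_size*/ 11008,
//                                              WeightType::FP32_W, /*attn_bias*/ false);
//   layer0->loadWeights();   // dummy weights, no .bin files required
//   delete layer0;           // ~LlamaLayerWeight frees the device buffers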
-------------------------------------------------------------------------------- /src/weights/llama/layer_weights.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "src/weights/llama/norm_weights.h" 3 | #include "src/weights/llama/attention_weights.h" 4 | #include "src/weights/llama/ffn_weights.h" 5 | #include "src/utils/weight_utils.h" 6 | template 7 | class LlamaLayerWeight { 8 | private: 9 | int head_num; 10 | int kv_head_num; 11 | int head_size; 12 | int hidden_units; 13 | int inter_size; 14 | WeightType weight_type; 15 | int bit_size; 16 | bool attn_bias; 17 | 18 | public: 19 | LlamaLayerWeight() = delete; 20 | LlamaLayerWeight(int head_num, 21 | int kv_head_num, 22 | int head_size, 23 | int inter_size, 24 | WeightType weight_type, 25 | bool attn_bias); 26 | ~LlamaLayerWeight(); 27 | 28 | void loadWeights(std::string weight_path, WeightType weight_type); 29 | 30 | void loadWeights(); 31 | 32 | LayerNormWeight attn_norm_weight; 33 | LayerNormWeight ffn_norm_weight; 34 | LLaMAattentionWeights self_attn_weight; 35 | LLaMAFFNWeights ffn_weight; 36 | }; -------------------------------------------------------------------------------- /src/weights/llama/llama_weights.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include "src/weights/llama/llama_weights.h" 3 | template 4 | LlamaWeight::LlamaWeight( 5 | int head_num, 6 | int kv_head_num, 7 | int head_size, 8 | int inter_size, 9 | int vocab_size, 10 | int num_layer, 11 | bool attn_bias, 12 | WeightType weight_type 13 | ): 14 | hidden_units(head_num * head_size), 15 | inter_size(inter_size), 16 | vocab_size(vocab_size), 17 | vocab_size_padded(vocab_size), 18 | num_layer(num_layer), 19 | weight_type(weight_type) 20 | { 21 | llama_layer_weight.reserve(num_layer); 22 | for (int l = 0; l < num_layer; ++l) { 23 | llama_layer_weight.push_back(new LlamaLayerWeight(head_num, 24 | kv_head_num, 25 | head_size, 26 | inter_size, 27 | weight_type, 28 | //group_size, 29 | attn_bias)); 30 | } 31 | GPUMalloc(&out_rmsnorm_weight.gamma, hidden_units); 32 | GPUMalloc(&post_decoder_embedding_weight.data, vocab_size * hidden_units); 33 | GPUMalloc(&pre_decoder_embedding_weight.data, vocab_size * hidden_units); 34 | pre_decoder_embedding_weight.shape = {vocab_size, hidden_units}; 35 | post_decoder_embedding_weight.shape = {vocab_size, hidden_units}; 36 | pre_decoder_embedding_weight.type = weight_type; 37 | post_decoder_embedding_weight.type = weight_type; 38 | } 39 | // (m0dulo)note: weight from HF is always half type, and if we want run fp32 inference, we should convert half weight to fp32 weight in tools/weights_convert.py 40 | // (m0dulo)note: shape and data of embedding and LMHead weight downloaded form HF are transposed, so we should carefully declare shape here 41 | template 42 | void LlamaWeight::loadWeights(std::string weight_path) { 43 | loadWeightFromBin::internalFunc(out_rmsnorm_weight.gamma, {(size_t)hidden_units}, weight_path + "model.norm.weight.bin"); 44 | loadWeightFromBin::internalFunc(post_decoder_embedding_weight.data, {(size_t)vocab_size, (size_t)hidden_units}, weight_path + "lm_head.weight.bin"); 45 | loadWeightFromBin::internalFunc(pre_decoder_embedding_weight.data, {(size_t)vocab_size, (size_t)hidden_units}, weight_path + "model.embed_tokens.weight.bin"); 46 | for (int layer = 0; layer < num_layer; ++layer) { 47 | llama_layer_weight[layer]->loadWeights(weight_path + "model.layers." 
+ std::to_string(layer), weight_type); 48 | } 49 | } 50 | 51 | template 52 | void LlamaWeight::loadWeightsFromDummy() { 53 | T* d_dummy_out_rmsnorm_weight_gamma; 54 | T* d_dummy_post_decoder_embedding_weight; 55 | T* d_dummy_pre_decoder_embedding_weight; 56 | GPUMalloc(&d_dummy_out_rmsnorm_weight_gamma, sizeof(T) * hidden_units); 57 | GPUMalloc(&d_dummy_post_decoder_embedding_weight, sizeof(T) * hidden_units * vocab_size); 58 | GPUMalloc(&d_dummy_pre_decoder_embedding_weight, sizeof(T) * hidden_units * vocab_size); 59 | T* h_dummy_out_rmsnorm_weight_gamma = (T*)malloc(sizeof(T) * hidden_units); 60 | T* h_dummy_post_decoder_embedding_weight = (T*)malloc(sizeof(T) * hidden_units * vocab_size); 61 | T* h_dummy_pre_decoder_embedding_weight = (T*)malloc(sizeof(T) * hidden_units * vocab_size); 62 | for (int i = 0; i < hidden_units; i++){ 63 | h_dummy_out_rmsnorm_weight_gamma[i] = (T)1.0f; 64 | } 65 | for (int i = 0; i < hidden_units * vocab_size; i++) { 66 | h_dummy_post_decoder_embedding_weight[i] = (T)1.0f; 67 | h_dummy_pre_decoder_embedding_weight[i] = (T)1.0f; 68 | } 69 | cudaMemcpy(d_dummy_out_rmsnorm_weight_gamma, h_dummy_out_rmsnorm_weight_gamma, sizeof(T) * hidden_units, cudaMemcpyHostToDevice); 70 | cudaMemcpy(d_dummy_post_decoder_embedding_weight, h_dummy_post_decoder_embedding_weight, sizeof(T) * hidden_units * vocab_size, cudaMemcpyHostToDevice); 71 | cudaMemcpy(d_dummy_pre_decoder_embedding_weight, h_dummy_pre_decoder_embedding_weight, sizeof(T) * hidden_units * vocab_size, cudaMemcpyHostToDevice); 72 | 73 | out_rmsnorm_weight.gamma = d_dummy_out_rmsnorm_weight_gamma; 74 | post_decoder_embedding_weight.data = d_dummy_post_decoder_embedding_weight; 75 | pre_decoder_embedding_weight.data = d_dummy_pre_decoder_embedding_weight; 76 | for (int layer = 0; layer < num_layer; ++layer) { 77 | llama_layer_weight[layer]->loadWeights(); 78 | } 79 | } 80 | 81 | template 82 | LlamaWeight::~LlamaWeight() 83 | { 84 | cudaFree(pre_decoder_embedding_weight.data); 85 | cudaFree(out_rmsnorm_weight.gamma); 86 | cudaFree(post_decoder_embedding_weight.data); 87 | 88 | for (auto& p : llama_layer_weight) { 89 | delete p; 90 | } 91 | } 92 | 93 | template struct LlamaWeight; 94 | template struct LlamaWeight; -------------------------------------------------------------------------------- /src/weights/llama/llama_weights.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include "src/weights/weight.h" 4 | #include "src/weights/base_weights.h" 5 | #include "src/weights/llama/embedding_weights.h" 6 | #include "src/weights/llama/layer_weights.h" 7 | template 8 | struct LlamaWeight : public Weight { 9 | private: 10 | int hidden_units; 11 | int inter_size; 12 | int vocab_size; 13 | int vocab_size_padded; 14 | int num_layer; 15 | WeightType weight_type; 16 | 17 | public: 18 | std::vector*> llama_layer_weight; 19 | LayerNormWeight out_rmsnorm_weight; 20 | EmbeddingWeight post_decoder_embedding_weight; 21 | EmbeddingWeight pre_decoder_embedding_weight; 22 | 23 | LlamaWeight() = default; 24 | LlamaWeight( 25 | int head_num, 26 | int kv_head_num, 27 | int head_size, 28 | int inter_size, 29 | int vocab_size, 30 | int num_layer, 31 | bool attn_bias, 32 | WeightType weight_type 33 | ); 34 | ~LlamaWeight(); 35 | void loadWeights(std::string weight_path); 36 | void loadWeightsFromDummy(); 37 | }; -------------------------------------------------------------------------------- /src/weights/llama/norm_weights.h: 
-------------------------------------------------------------------------------- 1 | #pragma once 2 | template<typename T> 3 | struct LayerNormWeight { 4 | T* gamma; 5 | }; -------------------------------------------------------------------------------- /src/weights/weight.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | struct Weight { 3 | virtual void loadWeights(std::string weight_path) = 0; 4 | }; -------------------------------------------------------------------------------- /tests/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(unittests) -------------------------------------------------------------------------------- /tests/unittests/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_executable(embedding 2 | test_input_embedding.cu 3 | ) 4 | target_link_libraries( 5 | embedding PUBLIC 6 | -lcudart 7 | -lcudadevrt 8 | embeddingFunctor) 9 | 10 | add_executable(paddingoffset 11 | test_cal_paddingoffset.cu 12 | ) 13 | target_link_libraries( 14 | paddingoffset PUBLIC 15 | -lcudart 16 | -lcudadevrt 17 | cal_paddingoffset) 18 | 19 | add_executable(causalmask 20 | test_casual_mask.cu 21 | ) 22 | target_link_libraries(causalmask PUBLIC 23 | -lcudart 24 | -lcudadevrt 25 | build_casual_mask) 26 | 27 | add_executable(biasRope 28 | test_bias_and_RoPE.cu 29 | ) 30 | target_link_libraries( 31 | biasRope PUBLIC 32 | -lcudart 33 | -lcudadevrt 34 | qkv_bias_and_rope) 35 | 36 | add_executable(test_concat_kv 37 | test_concat_kv.cu 38 | ) 39 | target_link_libraries( 40 | test_concat_kv PUBLIC 41 | -lcudart 42 | -lcudadevrt 43 | concat_kv) 44 | 45 | add_executable(testlinear 46 | test_linear.cu 47 | ) 48 | target_link_libraries( 49 | testlinear PUBLIC 50 | -lcudart 51 | -lcudadevrt 52 | linear) 53 | 54 | add_executable(bmm 55 | test_bmm.cu 56 | ) 57 | target_link_libraries( 58 | bmm PUBLIC 59 | -lcudart 60 | -lcudadevrt 61 | linear) 62 | 63 | add_executable(test_repeat_kv 64 | test_repeat_kv.cu 65 | ) 66 | target_link_libraries( 67 | test_repeat_kv PUBLIC 68 | -lcudart 69 | -lcudadevrt 70 | repeat_kv) 71 | 72 | add_executable(test_mask_softmax 73 | test_mask_softmax.cu 74 | ) 75 | target_link_libraries( 76 | test_mask_softmax PUBLIC 77 | -lcudart 78 | -lcudadevrt 79 | mask_softmax) 80 | 81 | add_executable(test_fused_trans_remv_pad 82 | test_fused_trans_remv_pad.cu 83 | ) 84 | target_link_libraries( 85 | test_fused_trans_remv_pad PUBLIC 86 | -lcudart 87 | -lcudadevrt 88 | fused_transpose_and_remv_pad) 89 | 90 | add_executable(test_act 91 | test_act.cu 92 | ) 93 | target_link_libraries( 94 | test_act PUBLIC 95 | -lcudart 96 | -lcudadevrt 97 | act) 98 | 99 | add_executable(test_topk 100 | test_topk.cu 101 | ) 102 | target_link_libraries( 103 | test_topk PUBLIC 104 | -lcudart 105 | -lcudadevrt 106 | topk) 107 | 108 | add_executable(test_fused_decoder_attention 109 | test_fused_decoder_attention.cu 110 | ) 111 | target_link_libraries( 112 | test_fused_decoder_attention PUBLIC 113 | -lcudart 114 | -lcudadevrt 115 | fused_decoder_self_attention) -------------------------------------------------------------------------------- /tests/unittests/test_act.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include "src/kernels/act_kernel.h" 8 | // (m0dulo)note: not sure CPU implementation is absolutely right and the GPU kernel is right compared with 
HF. 9 | // when you are implementing LLMs inference on CPU, you can reuse the CPU kernel and test its correctness 10 | // (m0dulo)note: 11 | // `./test_act 1` to test half GPU kernel 12 | // `./test_act` to test fp32 GPU kernel 13 | template 14 | void CPUSwiGLU(T* input, T* output, int batch_size, int intermedia_size){ 15 | float silu_out = 0.0f; 16 | for(int batch_id = 0; batch_id < batch_size; batch_id++){ 17 | for(int i = 0; i < intermedia_size; i++) { 18 | int offset1 = batch_id * 2 * intermedia_size + i; 19 | int offset2 = batch_id * 2 * intermedia_size + i + intermedia_size; 20 | int out_offset = batch_id * intermedia_size + i; 21 | silu_out = (float)input[offset1] / (1.0f + expf(-1 * (float)input[offset1])); 22 | output[out_offset] = static_cast(silu_out * (float)input[offset2]); 23 | } 24 | } 25 | } 26 | template 27 | bool CheckResult(T* CPUoutput, T* GPUoutput, int output_size) { 28 | for(int i = 0; i < output_size; i++) { 29 | if(fabs((float)CPUoutput[i] - (float)GPUoutput[i]) > 1e-6){ 30 | printf("the %dth res is wrong, CPUoutput = %f, GPUoutput = %f\n", i, (float)CPUoutput[i], (float)GPUoutput[i]); 31 | } 32 | } 33 | return true; 34 | } 35 | 36 | template 37 | void test_act(int batch_size, int intermedia_size, int input_size , int output_size) { 38 | T* h_input; 39 | T* d_input; 40 | h_input = (T*)malloc(sizeof(T) * input_size); 41 | cudaMalloc((void**)&d_input, sizeof(T) * input_size); 42 | T* h_output; 43 | T* d_output; 44 | h_output = (T*)malloc(sizeof(T) * output_size); 45 | cudaMalloc((void**)&d_output, sizeof(T) * output_size); 46 | for(int i = 0; i < input_size; i++) { // initialize host data 47 | h_input[i] = (T)1; 48 | } 49 | cudaMemcpy(d_input, h_input, sizeof(T) * input_size, cudaMemcpyHostToDevice); 50 | DataType type = getTensorType(); 51 | TensorWrapper* input_tensor = new TensorWrapper(GPU, type, {batch_size, 2, intermedia_size}, d_input); 52 | TensorWrapper* output_tensor = new TensorWrapper(GPU, type, {batch_size, intermedia_size}, d_output); 53 | launchAct(input_tensor, output_tensor); 54 | cudaMemcpy(h_output, d_output, sizeof(T) * output_size, cudaMemcpyDeviceToHost); 55 | T* CPU_output = (T*)malloc(sizeof(T) * output_size); 56 | CPUSwiGLU(h_input, CPU_output, batch_size, intermedia_size); 57 | bool is_true = CheckResult(CPU_output, h_output, output_size); 58 | if(is_true){ 59 | printf("test passed"); 60 | } else { 61 | printf("test failed"); 62 | } 63 | 64 | free(h_input); 65 | free(h_output); 66 | free(CPU_output); 67 | cudaFree(d_input); 68 | cudaFree(d_output); 69 | } 70 | 71 | int main(int argc, char** argv) { 72 | constexpr int batch_size = 16; 73 | constexpr int intermedia_size = 11008; 74 | constexpr int input_size = batch_size * intermedia_size * 2; 75 | constexpr int output_size = batch_size * intermedia_size; 76 | if (argv[1]){ 77 | test_act(batch_size, intermedia_size, input_size, output_size); 78 | } else { 79 | test_act(batch_size, intermedia_size, input_size, output_size); 80 | } 81 | 82 | } 83 | -------------------------------------------------------------------------------- /tests/unittests/test_bias_and_RoPE.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include "src/kernels/qkv_bias_and_RoPE.h" 9 | #include "src/weights/llama/attention_weights.h" 10 | #include "src/utils/macro.h" 11 | 12 | void CPUfunc(float* q, 13 | float* k, 14 | float* v, 15 | float* QKV, 16 | const float* qkv_bias, 17 | const int* padding_offset, 
18 | const int* history_length, 19 | const int* input_length, 20 | const int batch_size, 21 | const int seq_len, 22 | const int token_num, 23 | const int head_num, 24 | const int kv_head_num, 25 | const int head_size, 26 | const int rotary_embedding_dim, 27 | float rotary_embedding_base) { 28 | int qbatchstride = seq_len * head_num * head_size; 29 | int kvbatchstride = seq_len * kv_head_num * head_size; 30 | for (int b = 0; b < batch_size; b++) { 31 | for (int s = 0; s < seq_len; s++) { 32 | int timestep = history_length[b] + s; 33 | for (int head = 0; head < head_num; head++) { 34 | for (int d = 0; d < head_size; d++) { 35 | //q bias 36 | q[b * qbatchstride + s * head_num * head_size + head * head_size + d] = 37 | QKV[b * qbatchstride + s * head_num * head_size + head * head_size + d]; 38 | } 39 | //RoPE 40 | for (int d = 0; d < head_size / 2; d++) { 41 | float x0 = q[b * qbatchstride + s * head_num * head_size + head * head_size + d]; 42 | float x1 = q[b * qbatchstride + s * head_num * head_size + head * head_size + d + 64]; 43 | 44 | float inv_freq = timestep / powf(rotary_embedding_base, (d * 2) / (float)rotary_embedding_dim); 45 | q[b * qbatchstride + s * head_num * head_size + head * head_size + d] = 46 | x0 * cos(inv_freq) - x1 * sin(inv_freq); 47 | 48 | q[b * qbatchstride + s * head_num * head_size + head * head_size + d + 64] = 49 | x1 * cos(inv_freq) + x0 * sin(inv_freq); 50 | 51 | } 52 | } 53 | for (int head = 0; head < kv_head_num; head++) { 54 | for (int d = 0; d < head_size; d++) { 55 | //k bias 56 | k[b * kvbatchstride + s * kv_head_num * head_size + head * head_size + d] = 57 | QKV[b * kvbatchstride + s * (head_num + kv_head_num) * head_size + head * head_size + d];// + qkv_bias[(head_num + kv_head_num) * head_size + d]; 58 | v[b * kvbatchstride + s * kv_head_num * head_size + head * head_size + d] = 59 | QKV[b * kvbatchstride + s * (head_num + kv_head_num * 2) * head_size + head * head_size + d];// + qkv_bias[(head_num + 2 * kv_head_num) * head_size + d]; 60 | } 61 | //RoPE 62 | for (int d = 0; d < head_size / 2; d++) { 63 | float x0 = k[b * kvbatchstride + s * kv_head_num * head_size + head * head_size + d]; 64 | float x1 = k[b * kvbatchstride + s * kv_head_num * head_size + head * head_size + d + 64]; 65 | float inv_freq = timestep / powf(rotary_embedding_base, (d * 2) / (float)rotary_embedding_dim); 66 | k[b * kvbatchstride + s * kv_head_num * head_size + head * head_size + d] = 67 | x0 * cos(inv_freq) - x1 * sin(inv_freq); 68 | 69 | k[b * kvbatchstride + s * kv_head_num * head_size + head * head_size + d + 64] = 70 | x1 * cos(inv_freq) + x0 * sin(inv_freq); 71 | 72 | } 73 | } 74 | } 75 | } 76 | } 77 | 78 | bool CheckResult(float* q, float* k, float* hq, float* hk, 79 | const int q_size, const int k_size) { 80 | for(int i = 0; i < q_size; i++) { 81 | if(fabs(q[i] - hq[i]) > 1e-6){ 82 | printf("the %dth q is wrong, q = %f, hq = %f\n", i, q[i], hq[i]); 83 | return false; 84 | } 85 | } 86 | for(int i = 0; i < k_size; i++) { 87 | if(fabs(k[i] - hk[i]) > 1e-6){ 88 | printf("the %dth k is wrong, k = %f, hk = %f\n", i, k[i], hk[i]); 89 | return false; 90 | } 91 | } 92 | return true; 93 | } 94 | 95 | int main() { 96 | const int batch_size = 1; 97 | const int seq_len = 32; 98 | int* padding_offset = (int*)malloc(sizeof(int) * batch_size * seq_len); 99 | int* history_length = (int*)malloc(sizeof(int) * batch_size); 100 | int* input_length = (int*)malloc(sizeof(int) * batch_size); 101 | const int token_num = batch_size * seq_len; 102 | const int head_num = 32; 103 | const int 
kv_head_num = 32; 104 | const int head_size = 128; 105 | const int rotary_embedding_dim = 128; 106 | const int rotary_embedding_base = 10000; 107 | const int max_position_embeddings = 2048; 108 | 109 | float* q = (float*)malloc(sizeof(float) * batch_size * seq_len * head_num * head_size); //output 110 | float* k = (float*)malloc(sizeof(float) * batch_size * seq_len * kv_head_num * head_size); //output 111 | float* v = (float*)malloc(sizeof(float) * batch_size * seq_len * kv_head_num * head_size); //output 112 | float* QKV = (float*)malloc(sizeof(float) * token_num * (head_num + 2 * kv_head_num) * head_size); 113 | float* qkv_bias = (float*)malloc(sizeof(float) * (head_num + 2 * kv_head_num) * head_size); 114 | for(int i = 0; i < token_num * (head_num + 2 * kv_head_num) * head_size; i++){ 115 | QKV[i] = 32.0f; 116 | } 117 | for(int i = 0; i < (head_num + 2 * kv_head_num) * head_size; i++){ 118 | qkv_bias[i] = 2.0f; 119 | } 120 | for(int i = 0; i < batch_size; i++){ 121 | input_length[i] = 7; 122 | history_length[i] = 0; 123 | } 124 | for(int i = 0; i < batch_size * seq_len; i++){ 125 | padding_offset[i] = 0; 126 | } 127 | 128 | int* dpadding_offset; 129 | int* dhistory_length; 130 | int* dinput_length; 131 | float* dq; 132 | float* dk; 133 | float* dv; 134 | float* dQKV; 135 | float* dqkv_bias; 136 | cudaMalloc((void**)&dpadding_offset, sizeof(int) * batch_size * seq_len); 137 | cudaMalloc((void**)&dhistory_length, sizeof(int) * batch_size); 138 | cudaMalloc((void**)&dinput_length, sizeof(int) * batch_size); 139 | cudaMalloc((void**)&dq, sizeof(float) * batch_size * seq_len * head_num * head_size); 140 | cudaMalloc((void**)&dk, sizeof(float) * batch_size * seq_len * kv_head_num * head_size); 141 | cudaMalloc((void**)&dv, sizeof(float) * batch_size * seq_len * kv_head_num * head_size); 142 | cudaMalloc((void**)&dQKV, sizeof(float) * token_num * (head_num + 2 * kv_head_num) * head_size); 143 | cudaMalloc((void**)&dqkv_bias, sizeof(float) * (head_num + 2 * kv_head_num) * head_size); 144 | 145 | cudaMemcpy(dinput_length, input_length, sizeof(int) * batch_size, cudaMemcpyHostToDevice); 146 | cudaMemcpy(dhistory_length, history_length, sizeof(int) * batch_size, cudaMemcpyHostToDevice); 147 | cudaMemcpy(dpadding_offset, padding_offset, sizeof(int) * seq_len * batch_size, cudaMemcpyHostToDevice); 148 | cudaMemcpy(dQKV, QKV, sizeof(float) * token_num * (head_num + 2 * kv_head_num) * head_size, cudaMemcpyHostToDevice); 149 | cudaMemcpy(dqkv_bias, qkv_bias, sizeof(float) * (head_num + 2 * kv_head_num) * head_size, cudaMemcpyHostToDevice); 150 | 151 | DataType type = getTensorType(); 152 | TensorWrapper* q_buf = new TensorWrapper(Device::GPU, type, {batch_size, head_num, seq_len, head_size}, dq); 153 | TensorWrapper* k_buf = new TensorWrapper(Device::GPU, type, {batch_size, kv_head_num, seq_len, head_size}, dk); 154 | TensorWrapper* v_buf = new TensorWrapper(Device::GPU, type, {batch_size, kv_head_num, seq_len, head_size}, dv); 155 | TensorWrapper* QKV_buf = new TensorWrapper(Device::GPU, type, {token_num, head_num + 2 * kv_head_num, head_size}, dQKV); 156 | 157 | LLaMAattentionWeights attn_weights; 158 | attn_weights.qkv.bias = dqkv_bias; 159 | DataType type_int = getTensorType(); 160 | TensorWrapper* input_length_buf = new TensorWrapper(Device::GPU, type_int, {batch_size}, dinput_length); 161 | TensorWrapper* history_length_buf = new TensorWrapper(Device::GPU, type_int, {batch_size}, dhistory_length); 162 | TensorWrapper* padding_offset_buf = new TensorWrapper(Device::GPU, type_int, {batch_size, 
seq_len}, dpadding_offset); 163 | LLaMAAttentionStaticParams params; 164 | params.rotary_embedding_dim = rotary_embedding_dim; 165 | params.rotary_embedding_base = rotary_embedding_base; 166 | params.max_position_embeddings = max_position_embeddings; 167 | params.use_dynamic_ntk = false; 168 | 169 | std::cout << "before launch kernel" << std::endl; 170 | launchAddFusedQKVBiasTransposeAndRoPE(q_buf, 171 | k_buf, 172 | v_buf, 173 | QKV_buf, 174 | attn_weights.qkv, 175 | padding_offset_buf, 176 | history_length_buf, 177 | input_length_buf, 178 | params); 179 | 180 | std::cout << "after launch kernel" << std::endl; 181 | 182 | std::cout << "cuda memcpy device to host" << std::endl; 183 | 184 | CHECK(cudaMemcpy(q, dq, sizeof(float) * batch_size * seq_len * head_num * head_size, cudaMemcpyDeviceToHost)); 185 | CHECK(cudaMemcpy(k, dk, sizeof(float) * batch_size * seq_len * kv_head_num * head_size, cudaMemcpyDeviceToHost)); 186 | 187 | std::cout << "after memcpyd2h, dq[0] = " << q[0] << std::endl; 188 | std::cout << "before CPU function" << std::endl; 189 | float* hq = (float*)malloc(sizeof(float) * batch_size * seq_len * head_num * head_size); 190 | float* hk = (float*)malloc(sizeof(float) * batch_size * seq_len * kv_head_num * head_size); 191 | CPUfunc(hq, 192 | hk, 193 | v, 194 | QKV, 195 | qkv_bias, 196 | padding_offset, 197 | history_length, 198 | input_length, 199 | batch_size, 200 | seq_len, 201 | token_num, 202 | head_num, 203 | kv_head_num, 204 | head_size, 205 | rotary_embedding_dim, 206 | rotary_embedding_base); 207 | std::cout << "after CPU function" << std::endl; 208 | bool is_right = CheckResult(q, k, hq, hk, 209 | batch_size * seq_len * head_num * head_size, 210 | batch_size * seq_len * kv_head_num * head_size); 211 | 212 | std::cout << "before free" << std::endl; 213 | std::cout << "passed" << std::endl; 214 | free(q); 215 | free(k); 216 | free(v); 217 | free(QKV); 218 | free(qkv_bias); 219 | free(padding_offset); 220 | free(history_length); 221 | free(input_length); 222 | free(hq); 223 | free(hk); 224 | cudaFree(dq); 225 | cudaFree(dk); 226 | cudaFree(dv); 227 | cudaFree(dQKV); 228 | cudaFree(dqkv_bias); 229 | cudaFree(dpadding_offset); 230 | cudaFree(dhistory_length); 231 | cudaFree(dinput_length); 232 | } 233 | -------------------------------------------------------------------------------- /tests/unittests/test_bmm.cu: -------------------------------------------------------------------------------- 1 | 2 | #include // std::fill_n 3 | #include // snprintf 4 | #include // expf, log 5 | #include // rand 6 | #include // std::string 7 | #include // std::vector 8 | #include 9 | #include "src/utils/macro.h" 10 | #include "src/kernels/linear.h" 11 | #include "src/weights/base_weights.h" 12 | // (m0dulo)note: this kernel's CPU implementation is absolutely right. 
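// (editor's note) shape sketch for the two paths exercised below, derived from the tensor shapes this
// test sets up: with "./bmm 1" (trans_b = true) each of the batch_size * head_num batches computes
//   Q[seqlen_in, head_size] x K^T[head_size, seqlen_w] -> scores[seqlen_in, seqlen_w],
// and with "./bmm" (trans_b = false) it computes
//   scores[seqlen_in, seqlen_w] x V[seqlen_w, head_size] -> out[seqlen_in, head_size].
// Some of the inline "q*k" / "qk*v" comments further down appear to be swapped relative to these shapes.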
13 | // But when you are implementing LLMs inference on CPU, I dont recommend to reuse the CPU kernel, because its performance is bad 14 | void CPUlinear(float* input, float* weight, float* output, 15 | int m, int k, int n, int batch) { 16 | for(int b = 0; b < batch; b++) { 17 | for(int i = 0; i < m; i++) { 18 | for(int j = 0; j < n; j++) { 19 | for(int l = 0; l < k; l++) { 20 | output[b * m * n + i * n + j] += input[b * m * k + i * k + l] * weight[b * k * n + l * n + j]; 21 | } 22 | } 23 | } 24 | } 25 | } 26 | 27 | bool CheckResult(float* CPUoutput, float* GPUoutput, int output_size) { 28 | for(int i = 0; i < output_size; i++) { 29 | if(fabs(CPUoutput[i] - GPUoutput[i]) > 1e-6){ 30 | printf("the %dth res is wrong, CPUoutput = %f, GPUoutput = %f\n", i, CPUoutput[i], GPUoutput[i]); 31 | return false; 32 | } 33 | } 34 | return true; 35 | } 36 | // (m0dulo)note: 37 | // `./bmm 1` to test fp32 GPU batch matmul with trans_b = true 38 | // `./bmm` to test fp32 GPU batch matmul with trans_b = false 39 | int main(int argc, char *argv[]) { 40 | const int batch_size = 1; 41 | const int seqlen_in = 16; 42 | const int seqlen_w = 16; 43 | const int hidden_units = 4096; 44 | const int head_num = 32; 45 | const int head_size = 128; 46 | int in_size = 0; 47 | int w_size = 0; 48 | int output_size = 0; 49 | if (argv[1]) {// enable trans_b for test lmhead linear 50 | in_size = batch_size * head_num * seqlen_in * head_size; // q 51 | w_size = batch_size * head_num * seqlen_w * head_size; // k 52 | output_size = batch_size * head_num * seqlen_in * seqlen_w; //q k 53 | } else { 54 | in_size = batch_size * head_num * seqlen_in * seqlen_w; //qk 55 | w_size = batch_size * head_num * seqlen_w * head_size; // v 56 | output_size = batch_size * head_num * seqlen_in * head_size; 57 | } 58 | // debug info, better to retain: std::cout <<"batch_size=" << batch_size << " vocab_size=" << vocab_size << std::endl; 59 | float* h_w; 60 | float* d_w; 61 | h_w = (float*)malloc(sizeof(float) * w_size); 62 | cudaMalloc((void**)&d_w, sizeof(float) * w_size); 63 | for(int i = 0; i < w_size; i++) { 64 | h_w[i] = (float)(i % 2 + 1); 65 | //h_w[i] = 1.0f; // simple data 66 | } 67 | 68 | float* h_in = (float*) malloc(sizeof(float) * in_size); 69 | float* d_in; 70 | cudaMalloc((void**)&d_in, sizeof(float) * in_size); 71 | for(int i = 0; i < in_size; i++) { 72 | h_in[i] = (float)(i % 2 + 1); 73 | //h_in[i] = 1.0f; // simple data 74 | } 75 | 76 | float* h_out = (float*) malloc(sizeof(float) * output_size); 77 | float* d_out; 78 | cudaMalloc((void**)&d_out, sizeof(float) * output_size); 79 | 80 | CHECK(cudaMemcpy(d_in, h_in, sizeof(float) * in_size, cudaMemcpyHostToDevice)); 81 | CHECK(cudaMemcpy(d_w, h_w, sizeof(float) * w_size, cudaMemcpyHostToDevice)); 82 | DataType type = getTensorType(); 83 | WeightType wtype = getWeightType(); 84 | TensorWrapper* in; 85 | if (argv[1]) {// enable trans_b for test qk*v 86 | in = new TensorWrapper(Device::GPU, type, {batch_size, head_num, seqlen_in, head_size}, d_in); 87 | } else {// disable trans_b for test q*k 88 | in = new TensorWrapper(Device::GPU, type, {batch_size, head_num, seqlen_in, seqlen_w}, d_in); 89 | } 90 | TensorWrapper* weight = new TensorWrapper(Device::GPU, type, {batch_size, head_num, seqlen_w, head_size}, d_w); 91 | TensorWrapper* out; 92 | if (argv[1]) {// enable trans_b for test qk*v 93 | out = new TensorWrapper(Device::GPU, type, {batch_size, head_num, seqlen_in, seqlen_w}, d_out); 94 | } else {// disable trans_b for test q*k 95 | out = new TensorWrapper(Device::GPU, type, 
{batch_size, head_num, seqlen_in, head_size}, d_out); 96 | } 97 | cublasHandle_t cublas_handle; 98 | cublasLtHandle_t cublaslt_handle; 99 | cublasCreate(&cublas_handle); 100 | cublasSetMathMode(cublas_handle, CUBLAS_DEFAULT_MATH); 101 | cublasWrapper* cublas_wrapper = new cublasWrapper(cublas_handle, cublaslt_handle); 102 | cublas_wrapper->setFP32GemmConfig(); 103 | // debug info, better to retain: 104 | std::cout << "before launch kernel" << std::endl; 105 | if (argv[1]) {// enable trans_b for test qk*v 106 | launchLinearStridedBatchGemm(in, weight, out, cublas_wrapper, false, true); 107 | } else {// disable trans_b for test q*k 108 | launchLinearStridedBatchGemm(in, weight, out, cublas_wrapper); 109 | } 110 | // debug info, better to retain: 111 | std::cout << "after launch kernel" << std::endl; 112 | // debug info, better to retain: 113 | std::cout << "cuda memcpy device to host" << std::endl; 114 | // Note: remember to memcpy from device to host and define the correct copy size(mul the sizeof(dtype)), or will cause segment fault 115 | CHECK(cudaMemcpy(h_out, d_out, sizeof(float) * output_size, cudaMemcpyDeviceToHost)); 116 | float* CPUout = (float*) malloc(sizeof(float) * output_size); 117 | if (argv[1]) {// enable trans_b for ttest qk*v 118 | CPUlinear(h_in, h_w, CPUout, seqlen_in, head_size, seqlen_w, batch_size * head_num); 119 | } else {// disable trans_b for test q*k 120 | CPUlinear(h_in, h_w, CPUout, seqlen_in, seqlen_w, head_size, batch_size * head_num); 121 | } 122 | 123 | bool is_right = CheckResult(CPUout, h_out, output_size); 124 | // debug info, better to retain: 125 | std::cout << "before free" << std::endl; 126 | std::cout << "linear passed" << std::endl; 127 | free(h_in); 128 | free(h_w); 129 | free(h_out); 130 | free(CPUout); 131 | cudaFree(d_in); 132 | cudaFree(d_w); 133 | cudaFree(d_out); 134 | } -------------------------------------------------------------------------------- /tests/unittests/test_cal_paddingoffset.cu: -------------------------------------------------------------------------------- 1 | #include // std::fill_n 2 | #include // snprintf 3 | #include // expf, log 4 | #include // rand 5 | #include // std::string 6 | #include // std::vector 7 | 8 | #include "src/kernels/cal_paddingoffset.h" 9 | // note: this kernel is only int type input and output, not fp32 or half 10 | // we compare the kernel correctnesss by eyes and result print infos 11 | // `./paddingoffset` to run 12 | int main() { 13 | const int batch_size = 3; 14 | const int max_q_len = 5; 15 | // debug info, better to retain: std::cout <<"batch_size=" << batch_size << " vocab_size=" << vocab_size << std::endl; 16 | int* h_seq_lens; 17 | int *d_seq_lens; 18 | h_seq_lens = (int*)malloc(sizeof(int) * batch_size); 19 | cudaMalloc((void**)&d_seq_lens, sizeof(int) * batch_size); 20 | 21 | int* h_cum_seqlens; 22 | int* d_cum_seqlens; 23 | h_cum_seqlens = (int*)malloc(sizeof(int) * (batch_size + 1)); 24 | cudaMalloc((void**)&d_cum_seqlens, sizeof(int) * (batch_size + 1)); 25 | 26 | int* h_padding_offset; 27 | int* d_padding_offset; 28 | h_padding_offset = (int*)malloc(sizeof(int) * batch_size * max_q_len); 29 | cudaMalloc((void**)&d_padding_offset, sizeof(int) * batch_size * max_q_len); 30 | 31 | for(int i = 0; i < batch_size; i++) { // 3 32 | h_seq_lens[i] = batch_size; 33 | } 34 | cudaMemcpy(d_seq_lens, h_seq_lens, sizeof(int) * batch_size, cudaMemcpyHostToDevice); 35 | DataType type_int = getTensorType(); 36 | TensorWrapper* padding_offset = new TensorWrapper(Device::GPU, type_int, {batch_size, 
max_q_len}, d_padding_offset); 37 | TensorWrapper* cum_seqlens = new TensorWrapper(Device::GPU, type_int, {batch_size + 1}, d_cum_seqlens); 38 | TensorWrapper* input_lengths = new TensorWrapper(Device::GPU, type_int, {batch_size}, d_seq_lens); 39 | // debug info, better to retain: std::cout << "before launch kernel" << std::endl; 40 | launchCalPaddingoffset(padding_offset, 41 | cum_seqlens, 42 | input_lengths); 43 | // debug info, better to retain: std::cout << "after launch kernel" << std::endl; 44 | // Note: remember to memcpy from device to host and define the correct copy size(mul the sizeof(dtype)), or will cause segment fault 45 | cudaMemcpy(h_padding_offset, d_padding_offset, sizeof(int) * batch_size * max_q_len, cudaMemcpyDeviceToHost); 46 | cudaMemcpy(h_cum_seqlens, d_cum_seqlens, sizeof(int) * (batch_size + 1), cudaMemcpyDeviceToHost); 47 | // debug info, better to retain: std::cout << "cuda memcpy device to host" << std::endl; 48 | for(int i = 0; i < batch_size * max_q_len; i++) { 49 | printf("padding_offset = %d\n", h_padding_offset[i]); 50 | } 51 | for(int i = 0; i < batch_size + 1; i++){ 52 | printf("cum_seqlens =%d\n", h_cum_seqlens[i]); 53 | } 54 | /*11100 55 | //11100 56 | /11100*/ 57 | //expected result is: 58 | // padding_offset: 0,0,0,2,2,2,4,4,4,0.... shape = [batchsize, max_q_len] 59 | // cum_seqlens: 0,3,6,9. shape=[batchsize+1] 60 | // debug info, better to retain: std::cout << "before free" << std::endl; 61 | free(h_seq_lens); 62 | free(h_padding_offset); 63 | free(h_cum_seqlens); 64 | cudaFree(d_seq_lens); 65 | cudaFree(d_padding_offset); 66 | cudaFree(d_cum_seqlens); 67 | } 68 | -------------------------------------------------------------------------------- /tests/unittests/test_casual_mask.cu: -------------------------------------------------------------------------------- 1 | #include // std::fill_n 2 | #include // snprintf 3 | #include // expf, log 4 | #include // rand 5 | #include // std::string 6 | #include // std::vector 7 | 8 | #include "src/kernels/build_casual_mask.h" 9 | // note: this kernel's CPU implementation is absolutely right. 
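// (editor's note) worked example of the mask rule checked below (mask[i][j] = 1 iff j <= i + (k - q)
// && i < q && j < k): with this test's settings max_q_len = max_k_len = 5 and q_lens[b] = k_lens[b] = 3,
// the expected 5x5 mask for each batch is
//   1 0 0 0 0
//   1 1 0 0 0
//   1 1 1 0 0
//   0 0 0 0 0
//   0 0 0 0 0
// i.e. a lower-triangular causal mask over the 3 valid tokens, with zeros for the padded positions.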
10 | // when you are implementing LLMs inference on CPU, you can reuse the CPU kernel 11 | // we compare the kernel correctnesss by eyes and result print infos 12 | void CPUbuildCasualMask(float* mask, 13 | const int* q_lens, //input lens, shape=[batch size] 14 | const int* k_lens, //context lens, shape=[batch size] 15 | int max_q_len, 16 | int max_k_len, 17 | int batch_size) { 18 | for(int b = 0; b < batch_size; b++){ 19 | int start = b * max_q_len * max_k_len; 20 | int q = q_lens[b]; 21 | int k = k_lens[b]; 22 | for(int i = 0; i < max_q_len; i++) { 23 | for(int j = 0; j < max_k_len; j++) { 24 | if(j <= i + (k - q) && i < q && j < k) { 25 | mask[start + i * max_k_len + j] = 1.0f; 26 | } else { 27 | mask[start + i * max_k_len + j] = 0.0f; 28 | } 29 | } 30 | } 31 | } 32 | } 33 | bool CheckResult(float* CPUres, float* GPUres, const int size) { 34 | for(int i = 0; i < size; i++) { 35 | if(fabs(CPUres[i] - GPUres[i]) > 1e-6){ 36 | printf("the %dth res is wrong, CPU mask = %f, GPU mask = %f\n", i, CPUres[i], GPUres[i]); 37 | return false; 38 | } 39 | } 40 | return true; 41 | } 42 | // note: 43 | // `./causalmask` to test fp32 GPU build causal mask kernel 44 | int main() { 45 | const int batch_size = 1; 46 | const int max_q_len = 5; 47 | const int max_k_len = 5; 48 | // debug info, better to retain: std::cout <<"batch_size=" << batch_size << " vocab_size=" << vocab_size << std::endl; 49 | const int mask_size = batch_size * max_q_len * max_k_len; 50 | int* h_q_lens; 51 | int* d_q_lens; 52 | h_q_lens = (int*)malloc(sizeof(int) * batch_size); 53 | cudaMalloc((void**)&d_q_lens, sizeof(int) * batch_size); 54 | int* h_k_lens; 55 | int* d_k_lens; 56 | h_k_lens = (int*)malloc(sizeof(int) * batch_size); 57 | cudaMalloc((void**)&d_k_lens, sizeof(int) * batch_size); 58 | 59 | float* d_mask; 60 | float* h_mask = (float*)malloc(sizeof(float) * mask_size); 61 | cudaMalloc((void**)&d_mask, sizeof(float) * mask_size); 62 | 63 | for(int i = 0; i < batch_size; i++) { 64 | h_q_lens[i] = 3; 65 | } 66 | for(int i = 0; i < batch_size; i++) { 67 | h_k_lens[i] = 3; 68 | } 69 | CHECK(cudaMemcpy(d_q_lens, h_q_lens, sizeof(int) * batch_size, cudaMemcpyHostToDevice)); 70 | CHECK(cudaMemcpy(d_k_lens, h_k_lens, sizeof(int) * batch_size, cudaMemcpyHostToDevice)); 71 | DataType type_float = getTensorType(); 72 | DataType type_int = getTensorType(); 73 | TensorWrapper* mask = new TensorWrapper(Device::GPU, 74 | type_float, 75 | {batch_size, max_q_len, max_k_len}, 76 | d_mask); 77 | TensorWrapper* q_lens = new TensorWrapper(Device::GPU, 78 | type_int, 79 | {batch_size}, 80 | d_q_lens); 81 | TensorWrapper* k_lens = new TensorWrapper(Device::GPU, 82 | type_int, 83 | {batch_size}, 84 | d_k_lens); 85 | launchBuildCausalMasks(mask, q_lens, k_lens); 86 | // debug info, better to retain: std::cout << "after launch kernel" << std::endl; 87 | // Note: remember to memcpy from device to host and define the correct copy size(mul the sizeof(dtype)), or will cause segment fault 88 | CHECK(cudaMemcpy(h_mask, d_mask, sizeof(float) * mask_size, cudaMemcpyDeviceToHost)); 89 | float* CPUmask = (float*)malloc(sizeof(float) * mask_size); 90 | CPUbuildCasualMask(CPUmask, h_q_lens, h_k_lens, max_q_len, max_k_len, batch_size); 91 | if (CheckResult(CPUmask, h_mask, mask_size)) { 92 | printf("test passed!\n"); 93 | } 94 | 95 | // debug info, better to retain: std::cout << "before free" << std::endl; 96 | free(h_q_lens); 97 | free(h_k_lens); 98 | free(h_mask); 99 | free(CPUmask); 100 | cudaFree(d_q_lens); 101 | cudaFree(d_k_lens); 102 | cudaFree(d_mask); 
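    // (editor's note) the TensorWrapper objects created above with `new` (mask, q_lens, k_lens) are
    // never deleted; deleting them here as well would keep this test leak-free if it is ever run in a loop.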
103 | } 104 | -------------------------------------------------------------------------------- /tests/unittests/test_concat_kv.cu: -------------------------------------------------------------------------------- 1 | #include // std::fill_n 2 | #include // snprintf 3 | #include // expf, log 4 | #include // rand 5 | #include // std::string 6 | #include // std::vector 7 | 8 | #include 9 | #include "src/kernels/concat_past_kv.h" 10 | // (m0dulo)note: 11 | // there is no concat kv cpu kernel implementation now 12 | // we compare the kernel correctnesss by eyes and result print infos 13 | // `./test_concat_kv` to test fp32 GPU kernel 14 | int main() 15 | { 16 | const int batch_size = 1; 17 | const int max_q_len = 16; 18 | const int max_seq_len = 32; 19 | const int head_size = 8; 20 | const int kv_head_num = 2; 21 | const int kv_size = 1 * batch_size * max_q_len * kv_head_num * head_size; 22 | const int layer_offset = 1 * batch_size * max_seq_len * kv_head_num * head_size; 23 | const int kvcache_size = layer_offset; 24 | // (m0dulo)note: we plan to place layer id on CPU 25 | // const int layer_id = 0; 26 | 27 | float *h_k_src; 28 | float *d_k_src; 29 | h_k_src = (float *)malloc(sizeof(float) * kv_size); 30 | cudaMalloc((void **)&d_k_src, sizeof(float) * kv_size); 31 | 32 | float *h_v_src; 33 | float *d_v_src; 34 | h_v_src = (float *)malloc(sizeof(float) * kv_size); 35 | cudaMalloc((void **)&d_v_src, sizeof(float) * kv_size); 36 | 37 | int *cur_query_length = (int *)malloc(sizeof(int) * batch_size); 38 | int *history_length = (int *)malloc(sizeof(int) * batch_size); 39 | int *dcur_query_length; 40 | int *dhistory_length; 41 | cudaMalloc((void **)&dcur_query_length, sizeof(int) * batch_size); 42 | cudaMalloc((void **)&dhistory_length, sizeof(int) * batch_size); 43 | 44 | float *h_k_dst = (float *)malloc(sizeof(float) * kvcache_size); 45 | float *h_v_dst = (float *)malloc(sizeof(float) * kvcache_size); 46 | float *d_k_dst; 47 | float *d_v_dst; 48 | cudaMalloc((void **)&d_k_dst, sizeof(float) * kvcache_size); 49 | cudaMalloc((void **)&d_v_dst, sizeof(float) * kvcache_size); 50 | float *kv_scale; 51 | cudaMalloc((void **)&kv_scale, sizeof(float)); 52 | int *h_layer_id = (int *)malloc(sizeof(int) * batch_size); 53 | // (m0dulo)note: we plan to place layer id on CPU 54 | // int *d_layer_id; 55 | // cudaMalloc((void **)&d_layer_id, sizeof(int) * batch_size); 56 | 57 | for (int i = 0; i < kv_size; i++) 58 | { 59 | h_k_src[i] = 1.0f; 60 | h_v_src[i] = 1.0f; 61 | } 62 | for (int i = 0; i < batch_size; i++) 63 | { 64 | cur_query_length[i] = 16; 65 | history_length[i] = 1; 66 | h_layer_id[i] = 0; 67 | } 68 | cudaMemcpy(d_v_src, h_v_src, sizeof(float) * kv_size, cudaMemcpyHostToDevice); 69 | cudaMemcpy(d_k_src, h_k_src, sizeof(float) * kv_size, cudaMemcpyHostToDevice); 70 | cudaMemcpy(dcur_query_length, cur_query_length, sizeof(int) * batch_size, cudaMemcpyHostToDevice); 71 | cudaMemcpy(dhistory_length, history_length, sizeof(int) * batch_size, cudaMemcpyHostToDevice); 72 | // cudaMemcpy(d_layer_id, h_layer_id, sizeof(int) * batch_size, cudaMemcpyHostToDevice); 73 | 74 | DataType type = getTensorType(); 75 | DataType type_int = getTensorType(); 76 | TensorWrapper *in_ksrc = new TensorWrapper(Device::GPU, type, {batch_size, kv_head_num, max_q_len, head_size}, d_k_src); 77 | TensorWrapper *in_vsrc = new TensorWrapper(Device::GPU, type, {batch_size, kv_head_num, max_q_len, head_size}, d_v_src); 78 | TensorWrapper *layer_id = new TensorWrapper(Device::CPU, type_int, {batch_size}, h_layer_id); 79 | TensorWrapper 
*cur_q_len = new TensorWrapper(Device::GPU, type_int, {batch_size}, dcur_query_length); 80 | TensorWrapper *history_len = new TensorWrapper(Device::GPU, type_int, {batch_size}, dhistory_length); 81 | TensorWrapper *out_kdst = new TensorWrapper(Device::GPU, type, {batch_size, kv_head_num, max_seq_len, head_size}, d_k_dst); 82 | TensorWrapper *out_vdst = new TensorWrapper(Device::GPU, type, {batch_size, kv_head_num, max_seq_len, head_size}, d_v_dst); 83 | // debug info, better to retain: std::cout << "before launch kernel" << std::endl; 84 | launchConcatKVCache(in_ksrc, in_vsrc, layer_id, cur_q_len, history_len, out_kdst, out_vdst); 85 | // debug info, better to retain: std::cout << "after launch kernel" << std::endl; 86 | // Note: remember to memcpy from device to host and define the correct copy size(mul the sizeof(dtype)), or will cause segment fault 87 | cudaMemcpy(h_v_dst, d_v_dst, sizeof(float) * kvcache_size, cudaMemcpyDeviceToHost); 88 | cudaMemcpy(h_k_dst, d_k_dst, sizeof(float) * kvcache_size, cudaMemcpyDeviceToHost); 89 | // debug info, better to retain: std::cout << "cuda memcpy device to host" << std::endl; 90 | // note: need to add offset2index and index2offset API to help us program and check result 91 | for (int i = batch_size * (1) * kv_head_num * head_size; i < batch_size * max_seq_len * kv_head_num * head_size; i++) 92 | { 93 | printf("index = %d\n", i); 94 | printf("res k = %f\n", h_k_dst[i]); 95 | // debug info, better to retain: printf("topK id = %d\n", id); 96 | printf("res v = %f\n", h_v_dst[i]); 97 | printf("===============\n"); 98 | // debug info, better to retain: printf("topK val =%f\n", val); 99 | } 100 | // debug info, better to retain: std::cout << "before free" << std::endl; 101 | free(h_k_src); 102 | free(h_v_src); 103 | free(h_k_dst); 104 | free(h_v_dst); 105 | free(cur_query_length); 106 | free(history_length); 107 | free(h_layer_id); 108 | cudaFree(d_k_src); 109 | cudaFree(d_v_src); 110 | cudaFree(d_k_dst); 111 | cudaFree(d_v_dst); 112 | cudaFree(dcur_query_length); 113 | cudaFree(dhistory_length); 114 | cudaFree(kv_scale); 115 | } -------------------------------------------------------------------------------- /tests/unittests/test_fused_addresidual_norm.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | #include "src/kernels/fused_addresidual_norm.h" 10 | 11 | #include 12 | // (m0dulo)note: 13 | // `./test_fused_addresidual_norm` to test fp32 GPU kernel 14 | // (m0dulo)note: this kernel's CPU implementation is absolutely right. 
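// (editor's note) for reference, the fused op under test is usually defined as x = decoder_out +
// residual (+ bias), followed by RMSNorm: out[i] = x[i] * gamma[i] / sqrt(mean(x^2) + eps).
// The CPU check below only approximates this formula (its first inner loop overwrites `input`
// instead of accumulating per element), so treat it as a rough sanity check rather than a golden reference.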
15 | // when you are implementing LLMs inference on CPU, you can reuse the CPU kernel 16 | 17 | #define CHECK(call) \ 18 | do \ 19 | { \ 20 | const cudaError_t error_code = call; \ 21 | if (error_code != cudaSuccess) \ 22 | { \ 23 | printf("CUDA Error:\n"); \ 24 | printf(" File: %s\n", __FILE__); \ 25 | printf(" Line: %d\n", __LINE__); \ 26 | printf(" Error code: %d\n", error_code); \ 27 | printf(" Error text: %s\n", \ 28 | cudaGetErrorString(error_code)); \ 29 | exit(1); \ 30 | } \ 31 | } while (0) 32 | 33 | void CPUfusedresidandRMSNorm(float* h_residual, float* h_decoder_out, float* h_bias, 34 | float* h_scale, float eps, int hidden_units, int num_tokens) { 35 | for(int b = 0; b < num_tokens; b++) { 36 | float inv_fenmu = 0.0f; 37 | float mean = 0.0f; 38 | float input = 0.0f; 39 | for (int i = 0; i < hidden_units; i++) { 40 | input = h_decoder_out[b * hidden_units + i] + 41 | h_residual[b * hidden_units + i] + h_bias[i]; 42 | } 43 | float sum = 0.0f; 44 | for (int i = 0; i < hidden_units; i++) { 45 | sum += input * input; 46 | } 47 | 48 | mean = (float)(sum / hidden_units); 49 | inv_fenmu = rsqrt(mean + eps); 50 | 51 | for (int i = 0; i < hidden_units; i++) { 52 | h_decoder_out[b * hidden_units + i] = h_decoder_out[b * hidden_units + i] * inv_fenmu * h_scale[i]; 53 | } 54 | } 55 | } 56 | 57 | bool CheckResult(float* CPUoutput, float* GPUoutput, int output_size) { 58 | for(int i = 0; i < output_size; i++) { 59 | if(fabs(CPUoutput[i] - GPUoutput[i]) > 1e-6){ 60 | printf("the %dth res is wrong, CPUoutput = %f, GPUoutput = %f\n", i, CPUoutput[i], GPUoutput[i]); 61 | return false; 62 | } 63 | 64 | } 65 | return true; 66 | } 67 | 68 | int main() { 69 | const int num_tokens = 2; 70 | const int hidden_units = 32; 71 | const int total_size = num_tokens * hidden_units; 72 | float eps = 0.5f; 73 | // debug info, better to retain: std::cout <<"batch_size=" << batch_size << " vocab_size=" << vocab_size << std::endl; 74 | float* h_residual; 75 | float* d_residual; 76 | h_residual = (float*)malloc(sizeof(float) * total_size); 77 | cudaMalloc((void**)&d_residual, sizeof(float) * total_size); 78 | for(int i = 0; i < total_size; i++) { 79 | h_residual[i] = 0.0f; 80 | } 81 | 82 | float* h_decoder_out = (float*) malloc(sizeof(float) * total_size); 83 | float* decoder_out = (float*) malloc(sizeof(float) * total_size); 84 | float* d_decoder_out; 85 | cudaMalloc((void**)&d_decoder_out, sizeof(float) * total_size); 86 | for(int i = 0; i < total_size; i++) { 87 | h_decoder_out[i] = 1.0f; 88 | } 89 | //bias 90 | float* h_bias = (float*) malloc(sizeof(float) * hidden_units); 91 | float* d_bias; 92 | cudaMalloc((void**)&d_bias, sizeof(float) * hidden_units); 93 | for(int i = 0; i < hidden_units; i++) { 94 | h_bias[i] = 0.0f; 95 | } 96 | //rmsnorm weights 97 | float* h_scale = (float*) malloc(sizeof(float) * hidden_units); 98 | float* d_scale; 99 | cudaMalloc((void**)&d_scale, sizeof(float) * hidden_units); 100 | for(int i = 0; i < hidden_units; i++) { 101 | h_scale[i] = 1.0f; 102 | } 103 | 104 | CHECK(cudaMemcpy(d_residual, h_residual, sizeof(float) * total_size, cudaMemcpyHostToDevice)); 105 | CHECK(cudaMemcpy(d_decoder_out, h_decoder_out, sizeof(float) * total_size, cudaMemcpyHostToDevice)); 106 | CHECK(cudaMemcpy(d_bias, h_bias, sizeof(float) * hidden_units, cudaMemcpyHostToDevice)); 107 | CHECK(cudaMemcpy(d_scale, h_scale, sizeof(float) * hidden_units, cudaMemcpyHostToDevice)); 108 | DataType type_float = getTensorType(); 109 | DataType type_int = getTensorType(); 110 | TensorWrapper* decoder_out_tensor = new 
TensorWrapper(Device::GPU, 111 | type_float, 112 | {num_tokens, hidden_units}, 113 | d_decoder_out); 114 | TensorWrapper* residual_tensor = new TensorWrapper(Device::GPU, 115 | type_float, 116 | {num_tokens, hidden_units}, 117 | d_residual); 118 | BaseWeight norm; 119 | // norm.bias = d_bias; 120 | LayerNormWeight scale; 121 | scale.gamma = d_scale; 122 | // debug info, better to retain: 123 | std::cout << "before launch kernel" << std::endl; 124 | launchFusedAddBiasResidualRMSNorm(residual_tensor, 125 | decoder_out_tensor, 126 | norm, 127 | d_scale, 128 | eps); 129 | // debug info, better to retain: 130 | std::cout << "after launch kernel" << std::endl; 131 | // debug info, better to retain: 132 | std::cout << "cuda memcpy device to host" << std::endl; 133 | // Note: remember to memcpy from device to host and define the correct copy size(mul the sizeof(dtype)), or will cause segment fault 134 | CHECK(cudaMemcpy(decoder_out, d_decoder_out, sizeof(float) * total_size, cudaMemcpyDeviceToHost)); 135 | float* CPUout = (float*) malloc(sizeof(float) * total_size); 136 | for(int i = 0; i < total_size; i++){ 137 | CPUout[i] = 1.0f; 138 | } 139 | CPUfusedresidandRMSNorm(h_residual, CPUout, h_bias, 140 | h_scale, eps, hidden_units, num_tokens); 141 | bool is_right = CheckResult(CPUout, decoder_out, total_size); 142 | // debug info, better to retain: 143 | std::cout << "before free" << std::endl; 144 | std::cout << "fused addres and rmsnorm passed" << std::endl; 145 | free(h_residual); 146 | free(h_decoder_out); 147 | free(h_bias); 148 | free(h_scale); 149 | free(CPUout); 150 | free(decoder_out); 151 | cudaFree(d_residual); 152 | cudaFree(d_decoder_out); 153 | cudaFree(d_bias); 154 | cudaFree(d_scale); 155 | } -------------------------------------------------------------------------------- /tests/unittests/test_fused_trans_remv_pad.cu: -------------------------------------------------------------------------------- 1 | #include "src/kernels/fused_transpose_and_remv_pad.h" 2 | #include 3 | // [b,h,s,d]=>[b,s,h,d]=>[num tokens,h,d] 4 | // padding_offset.shape = [num_tokens] 5 | // `./test_fused_trans_remv_pad` to test fp32 kernel 6 | int main() { 7 | const int batch_size = 2; 8 | const int head_num = 2; 9 | const int max_seq_len = 4; 10 | const int head_size = 2; 11 | const int num_tokens = 5; 12 | // debug info, better to retain: std::cout <<"batch_size=" << batch_size << " vocab_size=" << vocab_size << std::endl; 13 | const int in_size = batch_size * head_num * max_seq_len * head_size; 14 | const int out_size = num_tokens * head_num * head_size; 15 | float* h_in; 16 | float* d_in; 17 | h_in = (float*)malloc(sizeof(float) * in_size); 18 | cudaMalloc((void**)&d_in, sizeof(float) * in_size); 19 | float* h_out; 20 | float* d_out; 21 | h_out = (float*)malloc(sizeof(float) * out_size); 22 | cudaMalloc((void**)&d_out, sizeof(float) * out_size); 23 | int* h_padding_offset; 24 | int* d_padding_offset; 25 | h_padding_offset = (int*)malloc(sizeof(int) * num_tokens); 26 | cudaMalloc((void**)&d_padding_offset, sizeof(int) * num_tokens); 27 | 28 | //1st seqlen: 2, due to 1st seq, so its padding offset are all 0 29 | //2nd seqlen: 3, so its padding offset are all 4-2=2 30 | for(int i = 0; i < in_size; i++) { 31 | h_in[i] = i; 32 | } 33 | for(int i = 0; i < 2; i++) { 34 | h_padding_offset[i] = 0; 35 | } 36 | h_padding_offset[2] = 2; 37 | h_padding_offset[3] = 2; 38 | h_padding_offset[4] = 2; 39 | 40 | cudaMemcpy(d_in, h_in, sizeof(float) * in_size, cudaMemcpyHostToDevice); 41 | cudaMemcpy(d_padding_offset, 
h_padding_offset, sizeof(int) * num_tokens, cudaMemcpyHostToDevice); 42 | 43 | DataType type = getTensorType(); 44 | DataType type_pad = getTensorType(); 45 | TensorWrapper* in = new TensorWrapper(Device::GPU, type, {batch_size, head_num, max_seq_len, head_size}, d_in); 46 | TensorWrapper* in_pad = new TensorWrapper(Device::GPU, type_pad, {num_tokens}, d_padding_offset); 47 | TensorWrapper* out = new TensorWrapper(Device::GPU, type, {num_tokens, head_num, head_size}, d_out); 48 | std::cout << "before launch softmax kernel" << std::endl; 49 | launchTransposeOutRemovePadding(in, in_pad, out); 50 | std::cout << "after launch softmax kernel" << std::endl; 51 | std::cout << "cuda memcpy device to host" << std::endl; 52 | // Note: remember to memcpy from device to host and define the correct copy size(mul the sizeof(dtype)), or will cause segment fault 53 | cudaMemcpy(h_out, out->data, sizeof(float) * out_size, cudaMemcpyDeviceToHost); 54 | for(int i = 0; i < out_size; i++) { 55 | printf("after trans and remv pad, out[%d] = %f\n", i, h_out[i]); 56 | } 57 | // debug info, better to retain: std::cout << "before free" << std::endl; 58 | free(h_in); 59 | free(h_out); 60 | free(h_padding_offset); 61 | cudaFree(d_in); 62 | cudaFree(d_out); 63 | cudaFree(d_padding_offset); 64 | } -------------------------------------------------------------------------------- /tests/unittests/test_input_embedding.cu: -------------------------------------------------------------------------------- 1 | #include // std::fill_n 2 | #include // snprintf 3 | #include // expf, log 4 | #include // rand 5 | #include // std::string 6 | #include // std::vector 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #include "src/kernels/input_embedding.h" 14 | // note: 15 | // there is no embedding cpu kernel implementation now 16 | // `./embedding` to test fp16 GPU kernel 17 | // `./embedding 1` to test fp32 GPU kernel 18 | 19 | #define CHECK(call) \ 20 | do \ 21 | { \ 22 | const cudaError_t error_code = call; \ 23 | if (error_code != cudaSuccess) \ 24 | { \ 25 | printf("CUDA Error:\n"); \ 26 | printf(" File: %s\n", __FILE__); \ 27 | printf(" Line: %d\n", __LINE__); \ 28 | printf(" Error code: %d\n", error_code); \ 29 | printf(" Error text: %s\n", \ 30 | cudaGetErrorString(error_code)); \ 31 | exit(1); \ 32 | } \ 33 | } while (0) 34 | 35 | void cpuEmbedding(const int* input_ids, float* output, float* embed_table, const int max_context_token_num, const int hidden_size, const int vocab_size) { 36 | for (int i = 0; i < max_context_token_num; ++i) { 37 | for (int j = 0; j < hidden_size; ++j) { 38 | output[j + i * hidden_size] = embed_table[j + input_ids[i] * hidden_size]; 39 | } 40 | } 41 | } 42 | 43 | bool checkResults(float* h_output, float* d_output, const int output_size) { 44 | float* d_output_cpu = (float*) malloc(output_size * sizeof(float)); // prepare for cpu check 45 | CHECK(cudaMemcpy(d_output_cpu, d_output, output_size * sizeof(float), cudaMemcpyDeviceToHost)); 46 | for (int i = 0; i < output_size; ++i) { 47 | if (fabs(d_output_cpu[i] - h_output[i]) > 1e5) { 48 | std::cout << "Dev : "; 49 | for (int j = max(0, i - 10); j < min(output_size, i + 10); ++j) { 50 | std::cout << d_output_cpu[i]; 51 | } 52 | std::cout << std::endl; 53 | std::cout << "Cpu : "; 54 | for (int j = max(0, i - 10); j < min(output_size, i + 10); ++j) { 55 | std::cout << h_output[i]; 56 | } 57 | std::cout << std::endl; 58 | free(d_output_cpu); 59 | return false; 60 | } 61 | } 62 | free(d_output_cpu); 63 | return true; 64 | } 65 | 66 | int 
main(int argc, char *argv[]) {
67 |     const int max_context_token_num = 64;
68 |     const int hidden_size = 4096;
69 |     const int vocab_size = 32000;
70 |     const int input_size = max_context_token_num;
71 |     const int table_size = vocab_size * hidden_size;
72 |     const int output_size = max_context_token_num * hidden_size;
73 |
74 |     int* h_input = (int*) malloc(input_size * sizeof(int));
75 |     if (argv[1]) {
76 |         float* h_table = (float*) malloc(table_size * sizeof(float));
77 |         float* h_output = (float*) malloc(output_size * sizeof(float));
78 |
79 |         // debug info, better to retain:
80 |         std::cout << "init memory on host" << std::endl;
81 |
82 |         std::random_device rd;
83 |         std::mt19937 gen(rd());
84 |         std::uniform_int_distribution<> dis_int(0, vocab_size - 1);
85 |         std::uniform_real_distribution<> dis_real(1.0, 2.0);
86 |
87 |         for (int i = 0; i < max_context_token_num; ++i) {
88 |             h_input[i] = dis_int(gen);
89 |             printf("h_input[%d] = %d\n", i, h_input[i]);
90 |         }
91 |         for (int i = 0; i < table_size; ++i) {
92 |             h_table[i] = (float)(i / hidden_size);
93 |         }
94 |
95 |         int* d_input;
96 |         float *d_table, *d_output;
97 |         cudaMalloc((void**)&d_input, input_size * sizeof(int));
98 |         cudaMalloc((void**)&d_table, table_size * sizeof(float));
99 |         cudaMalloc((void**)&d_output, output_size * sizeof(float));
100 |         // debug info, better to retain:
101 |         std::cout << "init memory on device" << std::endl;
102 |
103 |         CHECK(cudaMemcpy(d_input, h_input, input_size * sizeof(int), cudaMemcpyHostToDevice));
104 |         CHECK(cudaMemcpy(d_table, h_table, table_size * sizeof(float), cudaMemcpyHostToDevice));
105 |         // debug info, better to retain:
106 |         std::cout << "copy to device" << std::endl;
107 |
108 |         DataType type_float = getTensorType<float>();
109 |         DataType type_int = getTensorType<int>();
110 |         TensorWrapper<int>* input_ids = new TensorWrapper<int>(Device::GPU, type_int, {max_context_token_num}, d_input);
111 |         TensorWrapper<float>* output = new TensorWrapper<float>(Device::GPU, type_float, {max_context_token_num, hidden_size}, d_output);
112 |         EmbeddingWeight<float> emb_table;
113 |         emb_table.data = d_table;
114 |         launchInputEmbedding(input_ids, output, &emb_table);
115 |         CHECK(cudaMemcpy(h_output, output->data, output_size * sizeof(float), cudaMemcpyDeviceToHost));
116 |         std::cout << "print h_output for check" << std::endl;
117 |         for (int i = 0; i < max_context_token_num; i++) {
118 |             std::cout << (float)h_output[i * hidden_size] << std::endl;
119 |         }
120 |
121 |         cudaFree(d_output);
122 |         cudaFree(d_table);
123 |         cudaFree(d_input);
124 |         free(h_output);
125 |         free(h_table);
126 |         free(h_input);
127 |     } else {
128 |         half* h_table = (half*) malloc(table_size * sizeof(half));
129 |         half* h_output = (half*) malloc(output_size * sizeof(half));
130 |
131 |         // debug info, better to retain:
132 |         std::cout << "init memory on host" << std::endl;
133 |
134 |         std::random_device rd;
135 |         std::mt19937 gen(rd());
136 |         std::uniform_int_distribution<> dis_int(0, vocab_size - 1);
137 |         std::uniform_real_distribution<> dis_real(1.0, 2.0);
138 |
139 |         for (int i = 0; i < max_context_token_num; ++i) {
140 |             h_input[i] = dis_int(gen);
141 |         }
142 |         printf("h_input[0] = %d\n", h_input[0]);
143 |         for (int i = 0; i < table_size; ++i) {
144 |             h_table[i] = (half)(i / hidden_size);
145 |         }
146 |
147 |         int* d_input;
148 |
149 |         half *d_table, *d_output;
150 |         cudaMalloc((void**)&d_input, input_size * sizeof(int));
151 |         cudaMalloc((void**)&d_table, table_size * sizeof(half));
152 |         cudaMalloc((void**)&d_output, output_size * sizeof(half));
153 |         // debug info, better to retain:
154 |         std::cout << "init memory on device" << std::endl;
155 |
156 |         CHECK(cudaMemcpy(d_input, h_input, input_size * sizeof(int), cudaMemcpyHostToDevice));
157 |         CHECK(cudaMemcpy(d_table, h_table, table_size * sizeof(half), cudaMemcpyHostToDevice));
158 |         // debug info, better to retain:
159 |         std::cout << "copy to device" << std::endl;
160 |
161 |         DataType type_float = getTensorType<float>();
162 |         DataType type_half = getTensorType<half>();
163 |         DataType type_int = getTensorType<int>();
164 |         TensorWrapper<int>* input_ids = new TensorWrapper<int>(Device::GPU, type_int, {max_context_token_num}, d_input);
165 |         TensorWrapper<half>* output = new TensorWrapper<half>(Device::GPU, type_half, {max_context_token_num, hidden_size}, d_output);
166 |         EmbeddingWeight<half> emb_table;
167 |         emb_table.data = d_table;
168 |         launchInputEmbedding(input_ids, output, &emb_table);
169 |         CHECK(cudaMemcpy(h_output, output->data, output_size * sizeof(half), cudaMemcpyDeviceToHost));
170 |         std::cout << "print h_output for check" << std::endl;
171 |         std::cout << (float)h_output[0] << std::endl;
172 |         std::cout << (float)h_output[1] << std::endl;
173 |         cudaFree(d_output);
174 |         cudaFree(d_table);
175 |         cudaFree(d_input);
176 |         free(h_output);
177 |         free(h_table);
178 |         free(h_input);
179 |     }
180 | }
181 |
--------------------------------------------------------------------------------
/tests/unittests/test_linear.cu:
--------------------------------------------------------------------------------
1 | #include <algorithm>    // std::fill_n
2 | #include <cstdio>       // snprintf
3 | #include <cmath>        // expf, log
4 | #include <cstdlib>      // rand
5 | #include <string>       // std::string
6 | #include <vector>       // std::vector
7 | #include <iostream>
8 | #include <fstream>
9 | #include "src/utils/macro.h"
10 | #include "src/kernels/linear.h"
11 | #include "src/weights/base_weights.h"
12 |
13 | void CPUlinear(float* input, float* weight, float* output,
14 |                int m, int k, int n) {
15 |     for(int i = 0; i < m; i++) {
16 |         for(int j = 0; j < n; j++) {
17 |             for(int l = 0; l < k; l++) {
18 |                 output[i * n + j] += input[i * k + l] * weight[l * n + j];
19 |             }
20 |         }
21 |     }
22 | }
23 |
24 | bool CheckResult(float* CPUoutput, float* GPUoutput, int output_size) {
25 |     for(int i = 0; i < output_size; i++) {
26 |         if (i < 5) {
27 |             printf("the %dth res, CPUoutput = %f, GPUoutput = %f\n", i, CPUoutput[i], GPUoutput[i]);
28 |         }
29 |         if(fabs(CPUoutput[i] - GPUoutput[i]) > 1e-6){
30 |             printf("the %dth res is wrong, CPUoutput = %f, GPUoutput = %f\n", i, CPUoutput[i], GPUoutput[i]);
31 |             return false;
32 |         }
33 |
34 |     }
35 |     return true;
36 | }
37 |
38 | int main(int argc, char *argv[]) {
39 |     const int seqlen = 13;
40 |     const int hidden_units = 4096;
41 |     const int vocab_size = 32;
42 |     const int inter_size = 10;
43 |     int hidden_units_2 = 0;
44 |     int output_size = 0;
45 |
46 |     hidden_units_2 = hidden_units * hidden_units;
47 |     output_size = seqlen * hidden_units;
48 |     // debug info, better to retain: std::cout <<"batch_size=" << batch_size << " vocab_size=" << vocab_size << std::endl;
49 |     float* h_w;
50 |     float* d_w;
51 |     h_w = (float*)malloc(sizeof(float) * hidden_units_2);
52 |     cudaMalloc((void**)&d_w, sizeof(float) * hidden_units_2);
53 |     for(int i = 0; i < hidden_units_2; i++) {
54 |         h_w[i] = (float)(i % 3); // 0 1 2 0 1 2 ...
55 |     }
56 |
57 |     float* h_in = (float*) malloc(sizeof(float) * hidden_units * seqlen);
58 |     float* d_in;
59 |     cudaMalloc((void**)&d_in, sizeof(float) * seqlen * hidden_units);
60 |     for(int i = 0; i < hidden_units * seqlen; i++) {
61 |         h_in[i] = (float)(i % 3);
62 |     }
63 |
64 |     float* h_out = (float*) malloc(sizeof(float) * output_size);
65 |     float* d_out;
66 |     cudaMalloc((void**)&d_out, sizeof(float) * output_size);
67 |     CHECK(cudaMemcpy(d_in, h_in, sizeof(float) * hidden_units * seqlen, cudaMemcpyHostToDevice));
68 |     CHECK(cudaMemcpy(d_w, h_w, sizeof(float) * hidden_units_2, cudaMemcpyHostToDevice));
69 |     DataType type = getTensorType<float>();
70 |     WeightType wtype = getWeightType<float>();
71 |     TensorWrapper<float>* in = new TensorWrapper<float>(Device::GPU, type, {seqlen, hidden_units}, d_in);
72 |     BaseWeight<float> weight;
73 |     weight.shape = {hidden_units, hidden_units};
74 |     weight.data = d_w;
75 |     weight.type = wtype;
76 |     TensorWrapper<float>* out;
77 |     out = new TensorWrapper<float>(Device::GPU, type, {seqlen, hidden_units}, d_out);
78 |     cublasHandle_t cublas_handle;
79 |     cublasLtHandle_t cublaslt_handle;
80 |     cublasCreate(&cublas_handle);
81 |     cublasSetMathMode(cublas_handle, CUBLAS_DEFAULT_MATH);
82 |     cublasWrapper* cublas_wrapper = new cublasWrapper(cublas_handle, cublaslt_handle);
83 |     cublas_wrapper->setFP32GemmConfig();
84 |     // debug info, better to retain:
85 |     std::cout << "before launch kernel" << std::endl;
86 |     launchLinearGemm(in, weight, out, cublas_wrapper);
87 |     // debug info, better to retain:
88 |     std::cout << "after launch kernel" << std::endl;
89 |     // debug info, better to retain:
90 |     std::cout << "cuda memcpy device to host" << std::endl;
91 |     // Note: remember to memcpy from device to host with the correct copy size (multiply by sizeof(dtype)), or it will cause a segmentation fault
92 |     CHECK(cudaMemcpy(h_out, d_out, sizeof(float) * output_size, cudaMemcpyDeviceToHost));
93 |     float* CPUout = (float*) calloc(output_size, sizeof(float)); // zero-initialized because CPUlinear accumulates with +=
94 |     CPUlinear(h_in, h_w, CPUout, seqlen, hidden_units, hidden_units);
95 |
96 |     bool is_right = CheckResult(CPUout, h_out, output_size);
97 |     // debug info, better to retain:
98 |     std::cout << "before free" << std::endl;
99 |     if (is_right) { std::cout << "linear passed" << std::endl; }
100 |     free(h_in);
101 |     free(h_w);
102 |     free(h_out);
103 |     free(CPUout);
104 |     cudaFree(d_in);
105 |     cudaFree(d_w);
106 |     cudaFree(d_out);
107 | }
--------------------------------------------------------------------------------
/tests/unittests/test_mask_softmax.cu:
--------------------------------------------------------------------------------
1 | #include <algorithm>    // std::fill_n
2 | #include <cstdio>       // snprintf
3 | #include <cmath>        // expf, log
4 | #include <cstdlib>      // rand
5 | #include <string>       // std::string
6 | #include <vector>       // std::vector
7 |
8 | #include <iostream>
9 | #include "src/kernels/attn_softmax_kernel.h"
10 | // (m0dulo)note:
11 | // there is no CPU kernel implementation here; if you bought my CUDA lesson, you can find a CPU softmax kernel there.
12 | // we compare the kernel correctness by eye, using the printed results
13 | // `./test_mask_softmax 1` to test the half GPU kernel
14 | // `./test_mask_softmax` to test the fp32 GPU kernel
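// Since the note above says there is no CPU reference in this repo, the following is a minimal CPU
// sketch (an editor's illustration, not part of the original test) of the scale-mask-softmax that
// launchScaleMaskAndSoftmax is expected to compute. It assumes qk has shape [batch, head, q_len, k_len],
// mask has shape [batch, q_len, k_len] with 1 = visible and 0 = masked, and softmax runs over k_len;
// the name CPUMaskedSoftmax and the masking convention are assumptions, not the kernel's definition.
void CPUMaskedSoftmax(const float* qk, const float* mask, float* score,
                      int batch, int head, int q_len, int k_len, float scale) {
    for (int b = 0; b < batch; b++) {
        for (int h = 0; h < head; h++) {
            for (int q = 0; q < q_len; q++) {
                const float* row = qk + ((b * head + h) * q_len + q) * k_len;
                const float* m   = mask + (b * q_len + q) * k_len;
                float* out       = score + ((b * head + h) * q_len + q) * k_len;
                // scale the logits and push masked positions toward -inf
                float max_val = -1e20f;
                for (int k = 0; k < k_len; k++) {
                    out[k] = row[k] * scale + (1.0f - m[k]) * (-10000.0f);
                    if (out[k] > max_val) max_val = out[k];
                }
                // numerically stable softmax over the k_len axis
                float sum = 0.0f;
                for (int k = 0; k < k_len; k++) {
                    out[k] = expf(out[k] - max_val);
                    sum += out[k];
                }
                for (int k = 0; k < k_len; k++) {
                    out[k] /= sum;
                }
            }
        }
    }
}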
15 | #define TEST_MASKED_SOFTMAX(dtype) \
16 |     dtype *h_qk; \
17 |     dtype *d_qk; \
18 |     h_qk = (dtype *)malloc(sizeof(dtype) * qk_size); \
19 |     cudaMalloc((void **)&d_qk, sizeof(dtype) * qk_size); \
20 |     dtype *h_score; \
21 |     dtype *d_score; \
22 |     h_score = (dtype *)malloc(sizeof(dtype) * qk_size); \
23 |     cudaMalloc((void **)&d_score, sizeof(dtype) * qk_size); \
24 |     dtype *h_mask; \
25 |     dtype *d_mask; \
26 |     h_mask = (dtype *)malloc(sizeof(dtype) * batch_size * q_length * k_length); \
27 |     cudaMalloc((void **)&d_mask, sizeof(dtype) * batch_size * q_length * k_length); \
28 |     for (int i = 0; i < qk_size; i++) \
29 |     { \
30 |         h_qk[i] = i % 8; \
31 |     } \
32 |     for (int i = 0; i < batch_size * q_length * k_length; i++) \
33 |     { \
34 |         h_mask[i] = (dtype)(1); \
35 |     } \
36 |     cudaMemcpy(d_qk, h_qk, sizeof(dtype) * qk_size, cudaMemcpyHostToDevice); \
37 |     cudaMemcpy(d_mask, h_mask, sizeof(dtype) * batch_size * q_length * k_length, cudaMemcpyHostToDevice); \
38 |     DataType type = getTensorType<dtype>(); \
39 |     TensorWrapper<dtype> *qk = new TensorWrapper<dtype>(Device::GPU, type, {batch_size, head_num, q_length, k_length}, d_qk); \
40 |     TensorWrapper<dtype> *mask = new TensorWrapper<dtype>(Device::GPU, type, {batch_size, q_length, k_length}, d_mask); \
41 |     TensorWrapper<dtype> *score = new TensorWrapper<dtype>(Device::GPU, type, {batch_size, head_num, q_length, k_length}, d_score); \
42 |     std::cout << "before launch softmax kernel" << std::endl; \
43 |     launchScaleMaskAndSoftmax(qk, mask, score, scale); \
44 |     std::cout << "after launch softmax kernel" << std::endl; \
45 |     std::cout << "cuda memcpy device to host" << std::endl; \
46 |     cudaMemcpy(h_score, score->data, sizeof(dtype) * qk_size, cudaMemcpyDeviceToHost); \
47 |     for (int i = 0; i < qk_size; i++) \
48 |     { \
49 |         printf("attn score[%d] = %f\n", i, (float)h_score[i]); \
50 |     } \
51 |     free(h_qk); \
52 |     free(h_score); \
53 |     free(h_mask); \
54 |     cudaFree(d_qk); \
55 |     cudaFree(d_score); \
56 |     cudaFree(d_mask);
57 |
58 | int main(int argc, char *argv[])
59 | {
60 |     const int batch_size = 1;
61 |     const int head_num = 2;
62 |     const int q_length = 8;
63 |     const int k_length = 8;
64 |     const int head_size = 4;
65 |     float scale = rsqrtf(float(head_size));
66 |     // debug info, better to retain: std::cout <<"batch_size=" << batch_size << " vocab_size=" << vocab_size << std::endl;
67 |     const int qk_size = batch_size * head_num * q_length * k_length;
68 |     if (argv[1])
69 |     {
70 |         TEST_MASKED_SOFTMAX(half);
71 |     }
72 |     else
73 |     {
74 |         TEST_MASKED_SOFTMAX(float);
75 |     }
76 | }
77 |
--------------------------------------------------------------------------------
/tests/unittests/test_repeat_kv.cu:
--------------------------------------------------------------------------------
1 | #include <algorithm>    // std::fill_n
2 | #include <cstdio>       // snprintf
3 | #include <cmath>        // expf, log
4 | #include <cstdlib>      // rand
5 | #include <string>       // std::string
6 | #include <vector>       // std::vector
7 |
8 | #include <iostream>
9 | #include "src/kernels/repeat_kv.h"
10 | // (m0dulo)note:
11 | // there is no repeat-kv CPU kernel implementation here
12 | // we compare the kernel correctness by eye
13 | // `./test_repeat_kv` to test the fp32 GPU kernel
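// As with the softmax test above, there is no CPU reference in the repo. The sketch below is an
// editor's illustration (not part of the original test) of what launchRepeatKVCache is expected to
// do, assuming the cache layout [num_layers, batch, kv_head_num, max_seq_len, head_size], the output
// layout [batch, head_num, max_k_len, head_size], per-batch context_length and layer_id, and that
// query head h reads kv head h / (head_num / kv_head_num). The name CPURepeatKV and the exact index
// mapping are assumptions, not the kernel's definition.
void CPURepeatKV(const float* in, float* out, const int* ctx_len, const int* layer_id,
                 int batch, int kv_head_num, int head_num,
                 int max_seq_len, int max_k_len, int head_size) {
    int q_per_kv = head_num / kv_head_num;
    for (int b = 0; b < batch; b++) {
        for (int h = 0; h < head_num; h++) {
            int kv_h = h / q_per_kv; // each kv head is broadcast to q_per_kv query heads
            for (int s = 0; s < max_k_len && s < ctx_len[b]; s++) {
                for (int d = 0; d < head_size; d++) {
                    int src = (((layer_id[b] * batch + b) * kv_head_num + kv_h) * max_seq_len + s) * head_size + d;
                    int dst = ((b * head_num + h) * max_k_len + s) * head_size + d;
                    out[dst] = in[src];
                }
            }
        }
    }
}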
14 | int main() {
15 |     const int batch_size = 1;
16 |     const int head_num = 2;
17 |     const int kv_head_num = 2;
18 |     const int max_seq_len = 4;
19 |     const int max_k_len = 2;
20 |     const int head_size = 2;
21 |     const int num_layers = 2;
22 |     const int k_size = num_layers * batch_size * kv_head_num * max_seq_len * head_size;
23 |     const int out_k_size = batch_size * head_num * max_k_len * head_size;
24 |     float* h_k;
25 |     float* d_k;
26 |     h_k = (float*)malloc(sizeof(float) * k_size);
27 |     cudaMalloc((void**)&d_k, sizeof(float) * k_size);
28 |     float* h_v;
29 |     float* d_v;
30 |     h_v = (float*)malloc(sizeof(float) * k_size);
31 |     cudaMalloc((void**)&d_v, sizeof(float) * k_size);
32 |     int* h_ctx_len;
33 |     int* d_ctx_len;
34 |     h_ctx_len = (int*)malloc(sizeof(int) * batch_size);
35 |     cudaMalloc((void**)&d_ctx_len, sizeof(int) * batch_size);
36 |     float* h_trans_k;
37 |     float* d_trans_k;
38 |     h_trans_k = (float*)malloc(sizeof(float) * out_k_size);
39 |     cudaMalloc((void**)&d_trans_k, sizeof(float) * out_k_size);
40 |     float* h_trans_v;
41 |     float* d_trans_v;
42 |     h_trans_v = (float*)malloc(sizeof(float) * out_k_size);
43 |     cudaMalloc((void**)&d_trans_v, sizeof(float) * out_k_size);
44 |
45 |     for(int i = 0; i < k_size; i++) {
46 |         h_v[i] = i;
47 |         h_k[i] = i;
48 |     }
49 |     int* h_layer_id = (int*)malloc(sizeof(int) * batch_size);
50 |
51 |     for(int i = 0; i < batch_size; i++) {
52 |         h_ctx_len[i] = 2;
53 |         h_layer_id[i] = 0;
54 |     }
55 |
56 |     cudaMemcpy(d_k, h_k, sizeof(float) * k_size, cudaMemcpyHostToDevice);
57 |     cudaMemcpy(d_v, h_v, sizeof(float) * k_size, cudaMemcpyHostToDevice);
58 |     cudaMemcpy(d_ctx_len, h_ctx_len, sizeof(int) * batch_size, cudaMemcpyHostToDevice);
59 |     DataType type = getTensorType<float>();
60 |     DataType type_int = getTensorType<int>();
61 |     TensorWrapper<float>* in_k = new TensorWrapper<float>(Device::GPU, type, {num_layers, batch_size, kv_head_num, max_seq_len, head_size}, d_k);
62 |     TensorWrapper<float>* in_v = new TensorWrapper<float>(Device::GPU, type, {num_layers, batch_size, kv_head_num, max_seq_len, head_size}, d_v);
63 |     TensorWrapper<int>* ctx_len = new TensorWrapper<int>(Device::GPU, type_int, {batch_size}, d_ctx_len);
64 |     TensorWrapper<float>* out_k = new TensorWrapper<float>(Device::GPU, type, {batch_size, head_num, max_k_len, head_size}, d_trans_k);
65 |     TensorWrapper<float>* out_v = new TensorWrapper<float>(Device::GPU, type, {batch_size, head_num, max_k_len, head_size}, d_trans_v);
66 |     TensorWrapper<int>* layer_id = new TensorWrapper<int>(Device::CPU, type_int, {batch_size}, h_layer_id);
67 |
68 |     std::cout << "before launch repeat kv kernel" << std::endl;
69 |     launchRepeatKVCache(in_k, in_v, ctx_len, layer_id, out_k, out_v);
70 |     std::cout << "after launch repeat kv kernel" << std::endl;
71 |     std::cout << "cuda memcpy device to host" << std::endl;
72 |     // Note: remember to memcpy from device to host with the correct copy size (multiply by sizeof(dtype)), or it will cause a segmentation fault
73 |     cudaMemcpy(h_trans_k, out_k->data, sizeof(float) * out_k_size, cudaMemcpyDeviceToHost);
74 |     for(int i = 0; i < out_k_size; i++) {
75 |         printf("k trans[%d] = %f\n", i, h_trans_k[i]);
76 |     }
77 |     // debug info, better to retain: std::cout << "before free" << std::endl;
78 |     free(h_k);
79 |     free(h_v);
80 |     free(h_ctx_len);
81 |     free(h_trans_k);
82 |     free(h_trans_v);
83 |     free(h_layer_id);
84 |     cudaFree(d_k);
85 |     cudaFree(d_v);
86 |     cudaFree(d_ctx_len);
87 |     cudaFree(d_trans_k);
88 |     cudaFree(d_trans_v);
89 | }
90 |
--------------------------------------------------------------------------------
/tools/weights_convert.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import configparser
3 | import os
4 | from pathlib import Path
5 | import numpy as np
6 | import torch
7 |
8 | from transformers import LlamaForCausalLM, LlamaTokenizer
9 | from transformers.models.llama.modeling_llama import LlamaDecoderLayer
10 |
11 | def get_weight_data_type(data_type):
12 |     if data_type == "fp32":
13 |         return np.float32
14 |     elif data_type == "fp16":
15 |         return np.float16
16 |     else:
17 |         assert False, f"Invalid weight data type {data_type}"
18 |
19 | if __name__ == "__main__":
20 |     parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
21 |     parser.add_argument('-saved_dir', '-o', type=str, help='file name of output file', required=True)
22 |     parser.add_argument('-in_file', '-i', type=str, help='file name of input checkpoint file', required=True)
23 |     # parser.add_argument('-trained_gpu_num', '-t_g', type=int, help='How many gpus for inference', default=1)
24 |     # parser.add_argument('-infer_gpu_num', '-i_g', type=int, help='How many gpus for inference', required=True)
25 |     # parser.add_argument("-processes", "-p", type=int, help="How many processes to spawn for conversion (default: 4)", default=4)
26 |     parser.add_argument("-weight_data_type", type=str, default="fp32", choices=["fp32", "fp16", "bf16"])
27 |     parser.add_argument("--load-model-on-cpu", action="store_true")
28 |     parser.add_argument("--convert-model-on-cpu", action="store_true")
29 |     args = parser.parse_args()
30 |     print("\n=============== Argument ===============")
31 |     for key in vars(args):
32 |         print("{}: {}".format(key, vars(args)[key]))  # vars(args) returns args as a dict
33 |     print("========================================")
34 |     torch_dtype = torch.float16 if args.weight_data_type == 'fp16' else torch.float32
35 |     model = LlamaForCausalLM.from_pretrained(args.in_file,
36 |                                              torch_dtype=torch_dtype,
37 |                                              device_map="auto",
38 |                                              trust_remote_code=True)
39 |     if args.load_model_on_cpu:
40 |         model = model.float()
41 |         model = model.cpu()
42 |         torch.cuda.empty_cache()
43 |
44 |     saved_dir = args.saved_dir + "/1-gpu"
45 |     hf_config = vars(model.config)
46 |     print("\n=============== HF model config ===============")
47 |     print(hf_config)
48 |     print("\n===============================================")
49 |     # write config.ini based on the HF model config
50 |     # import pdb; pdb.set_trace()
51 |     config = configparser.ConfigParser()
52 |     config["llama"] = {}
53 |     config["llama"]["model_name"] = "llama-2-7b-chat" if hf_config["_name_or_path"] == '' else hf_config["_name_or_path"]
54 |     config["llama"]["head_num"] = str(hf_config["num_attention_heads"])
55 |     config["llama"]["kv_head_num"] = str(hf_config["num_key_value_heads"])
56 |     config["llama"]["hidden_size"] = str(hf_config["hidden_size"])
57 |     config["llama"]["head_size"] = str(hf_config["hidden_size"] // hf_config["num_attention_heads"])
58 |     config["llama"]["inter_size"] = str(hf_config["intermediate_size"])  # 11008
59 |     # config['llama']['max_pos_seq_len'] = str(hf_config['n_positions'])
60 |     config["llama"]["num_layer"] = str(hf_config["num_hidden_layers"])
61 |     config["llama"]["vocab_size"] = str(hf_config["vocab_size"])  # 32000
62 |     config["llama"]["bos_token_id"] = str(hf_config["bos_token_id"])  # 1
63 |     config["llama"]["eos_token_id"] = str(hf_config["eos_token_id"])  # 2
64 |     config['llama']['weight_data_type'] = args.weight_data_type
65 |     config['llama']['max_position_embeddings'] = str(hf_config["max_position_embeddings"])
66 |     config['llama']['rope_theta'] = str(hf_config["rope_theta"])
67 |     config['llama']['rms_norm_eps'] = str(hf_config["rms_norm_eps"])
68 |     config['llama']['attention_bias'] = str(hf_config["attention_bias"])  # false
69 |     config['llama']['top_k'] = str(hf_config["top_k"])  # 50
70 |     with open(saved_dir + "/config.ini", 'w') as configfile:
71 |         config.write(configfile)
72 |     # except:
73 |     #     print(f"Fail to save the config in config.ini.")
74 |     np_weight_data_type = get_weight_data_type(args.weight_data_type)
75 |     cur_layer = 0
76 |     q = 0
77 |     k = 0
78 |     for name, param in model.named_parameters():
79 |         # model.embed_tokens.weight [32000, 4096]
80 |         # import pdb; pdb.set_trace()
81 |         # if name.find("weight") == -1 and name.find("bias") == -1:
82 |         #     continue
83 |         # import pdb; pdb.set_trace()
84 |         if name.find('model.embed_tokens.weight') != -1:
85 |             param.detach().cpu().float().numpy().astype(np_weight_data_type).tofile(f"model.embed_tokens.weight.bin")
86 |         elif name.find('model.norm.weight') != -1:
87 |             param.detach().cpu().float().numpy().astype(np_weight_data_type).tofile(f"model.norm.weight.bin")
88 |         elif name.find('lm_head.weight') != -1:
89 |             param.detach().cpu().float().numpy().astype(np_weight_data_type).tofile(f"lm_head.weight.bin")
90 |         elif name.find('self_attn.q_proj.weight') != -1 or name.find('self_attn.k_proj.weight') != -1 or name.find('self_attn.v_proj.weight') != -1:
91 |             layer = name.split(".")[2]
92 |             if name.find('self_attn.q_proj.weight') != -1:
93 |                 q = param.detach().cpu().float().numpy()
94 |             elif name.find('self_attn.k_proj.weight') != -1:
95 |                 k = param.detach().cpu().float().numpy()
96 |             elif name.find('self_attn.v_proj.weight') != -1:
97 |                 v = param.detach().cpu().float().numpy()
98 |                 qkv = np.hstack((q, k, v))
99 |                 qkv.astype(np_weight_data_type).tofile(f"model.layers.{layer}.self_attn.qkv.weight.bin")
100 |                 print("qkv shape: ", qkv.shape)
101 |         # if cur_layer == layer:
102 |         #     np.concat(param.detach().cpu().float().numpy())
103 |         # else:
104 |         #     cur_layer = layer
105 |
106 |         elif name.find('self_attn.o_proj.weight') != -1:
107 |             layer = name.split(".")[2]
108 |             param.detach().cpu().float().numpy().astype(np_weight_data_type).tofile(f"model.layers.{layer}.self_attn.o_proj.weight.bin")
109 |
110 |         elif name.find('mlp.gate_proj.weight') != -1 or name.find('mlp.up_proj.weight') != -1:
111 |             layer = name.split(".")[2]
112 |             if name.find('mlp.gate_proj.weight') != -1:
113 |                 gate = param.detach().cpu().float().numpy()
114 |             elif name.find('mlp.up_proj.weight') != -1:
115 |                 up = param.detach().cpu().float().numpy()
116 |                 gate_up = np.hstack((gate, up))
117 |                 gate_up.astype(np_weight_data_type).tofile(f"model.layers.{layer}.mlp.gate_up_proj.weight.bin")
118 |                 print("fused gate_up shape: ", gate_up.shape)
119 |         # elif name.find('mlp.up_proj.weight') != -1:
120 |         #     layer = name.split(".")[2]
121 |         #     param.detach().cpu().float().numpy().astype(np_weight_data_type).tofile(f"model.layers.{layer}.mlp.up_proj.weight.bin")
122 |         elif name.find('mlp.down_proj.weight') != -1:
123 |             layer = name.split(".")[2]
124 |             param.detach().cpu().float().numpy().astype(np_weight_data_type).tofile(f"model.layers.{layer}.mlp.down_proj.weight.bin")
125 |         elif name.find('input_layernorm.weight') != -1:
126 |             layer = name.split(".")[2]
127 |             param.detach().cpu().float().numpy().astype(np_weight_data_type).tofile(f"model.layers.{layer}.input_layernorm.weight.bin")
128 |         elif name.find('post_attention_layernorm.weight') != -1:
129 |             layer = name.split(".")[2]
130 |             param.detach().cpu().float().numpy().astype(np_weight_data_type).tofile(f"model.layers.{layer}.post_attention_layernorm.weight.bin")
131 |
132 |     # else:
133 |
134 |     # for i in range(len(huggingface_model_name_pattern)):
135 |     #     if name.find(huggingface_model_name_pattern[i]) != -1:
"layers.").replace(huggingface_model_name_pattern[i], ft_model_name_pattern[i]) 137 | # param.detach().cpu().numpy().astype(np_weight_data_type).tofile(saved_dir + "model..bin") 138 | --------------------------------------------------------------------------------