├── .gitignore ├── .gitmodules ├── CMakeLists.txt ├── README.md ├── cmake └── Modules │ └── FindNCCL.cmake ├── scripts ├── converter.py ├── decode_output.py ├── encode_input.py └── lib │ ├── converter_lib.py │ └── gpt_token_encoder.py └── src ├── CMakeLists.txt ├── csrc ├── CMakeLists.txt ├── kernel │ ├── CMakeLists.txt │ ├── activation_types.h │ ├── activations.cuh │ ├── addbias.cu │ ├── addbias.h │ ├── count_nan.cu │ ├── count_nan.h │ ├── embedding.cu │ ├── embedding.h │ ├── findmax.cu │ ├── findmax.h │ ├── fused_activ_multiply.cu │ ├── fused_activ_multiply.h │ ├── fused_addbias_activ.cu │ ├── fused_addbias_activ.h │ ├── fused_context_stage_attention.cu │ ├── fused_context_stage_attention.h │ ├── fused_decoding_stage_attention.cu │ ├── fused_decoding_stage_attention.h │ ├── fused_decoding_stage_attention_mha.cu │ ├── gather_last_tokens.cu │ ├── gather_last_tokens.h │ ├── kvcache_mgmt.cu │ ├── kvcache_mgmt.h │ ├── layernorm.cu │ ├── layernorm.h │ ├── reduction.cuh │ ├── rmsnorm.cu │ ├── rmsnorm.h │ ├── rotary_posi_embedding.cu │ ├── rotary_posi_embedding.h │ ├── softmax.cu │ ├── softmax.h │ ├── unfused_attention.cu │ ├── unfused_attention.h │ ├── xformers_attention.cu │ └── xformers_attention.h ├── layer │ ├── CMakeLists.txt │ ├── attention.cc │ ├── attention.h │ ├── ffn.cc │ ├── ffn.h │ ├── gated_ffn.cc │ └── gated_ffn.h ├── model │ ├── CMakeLists.txt │ └── gpt │ │ ├── CMakeLists.txt │ │ ├── gpt.cc │ │ ├── gpt.h │ │ ├── gpt2 │ │ ├── CMakeLists.txt │ │ ├── gpt2op.cc │ │ └── gpt2op.h │ │ ├── gpt_base.h │ │ ├── gpt_hyper_param.h │ │ ├── gpt_pagedattn_param.h │ │ ├── gpt_parallelism_param.h │ │ ├── gpt_weight.cc │ │ ├── gpt_weight.h │ │ ├── gptop_base.cc │ │ ├── gptop_base.h │ │ ├── llama2 │ │ ├── CMakeLists.txt │ │ ├── llama2op.cc │ │ └── llama2op.h │ │ └── opt │ │ ├── CMakeLists.txt │ │ ├── optop.cc │ │ └── optop.h ├── pybinding.cc └── util │ ├── CMakeLists.txt │ ├── cublas_wrapper.cc │ ├── cublas_wrapper.h │ ├── cuda_utils.h │ ├── debug_utils.h │ ├── nccl_utils.cc │ ├── nccl_utils.h │ ├── py_block_migration.cc │ ├── py_block_migration.h │ ├── py_nccl.cc │ ├── py_nccl.h │ ├── py_swapping.cc │ ├── py_swapping.h │ ├── st_datatypes.h │ └── torch_utils.h ├── examples ├── CMakeLists.txt ├── benchmark_all_input_same.cc ├── lib │ ├── common_gpt_hyper_params.h │ ├── inference_batch.cc │ ├── inference_batch.h │ ├── simple_vocab_decoder.h │ ├── st_args.cc │ ├── st_args.h │ └── utils.h └── run_gpt.cc └── unittest ├── CMakeLists.txt ├── kernel ├── CMakeLists.txt ├── addbias.cc ├── attention_ref.cc ├── attention_ref.h ├── findmax.cc ├── fused_activ_multiply.cc ├── fused_addbias_activ.cc ├── fused_attention.cc ├── kvcache_mgmt_ref.cc ├── kvcache_mgmt_ref.h ├── layernorm.cc ├── rmsnorm.cc ├── rotary_posi_embedding.cc ├── rotary_posi_embedding_ref.cc ├── rotary_posi_embedding_ref.h └── softmax.cc ├── layer ├── CMakeLists.txt ├── attention_ref.cc ├── attention_ref.h ├── attention_utils.h ├── parallel_attention.cc └── parallel_ffn.cc ├── model └── CMakeLists.txt ├── unittest_torch_utils.h ├── unittest_utils.h └── util ├── CMakeLists.txt └── cublas_wrapper.cc /.gitignore: -------------------------------------------------------------------------------- 1 | /build 2 | __pycache__ 3 | 4 | /temp/** 5 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "src/csrc/kernel/xformers"] 2 | path = src/csrc/kernel/xformers 3 | url = https://github.com/facebookresearch/xformers.git 4 
| -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.18 FATAL_ERROR) 2 | project(SwiftTransformer LANGUAGES CXX CUDA) 3 | 4 | if (DEFINED ENV{CONDA_PREFIX}) 5 | # use conda environment 6 | link_directories($ENV{CONDA_PREFIX}/lib) 7 | include_directories($ENV{CONDA_PREFIX}/include) 8 | set(CMAKE_PREFIX_PATH $ENV{CONDA_PREFIX} ${CMAKE_PREFIX_PATH}) 9 | set(CMAKE_CUDA_COMPILER $ENV{CONDA_PREFIX}/bin/nvcc) 10 | set(CUDAToolkit_ROOT $ENV{CONDA_PREFIX}) 11 | endif() 12 | 13 | find_package(CUDAToolkit 11.4 REQUIRED) 14 | 15 | # gcc >= 8 is required, we do not support other compilers 16 | if ((NOT CMAKE_CXX_COMPILER_ID STREQUAL "GNU") OR (CMAKE_CXX_COMPILER_VERSION VERSION_LESS 8.0)) 17 | message(FATAL_ERROR "GCC 8.0 or higher is required") 18 | endif() 19 | # Add filesystem library for gcc < 9 20 | link_libraries( "$<$,$,9.0>>:-lstdc++fs>" ) 21 | 22 | # Set up C++ standard 23 | set(CXX_STD "17" CACHE STRING "C++ standard") 24 | set(CMAKE_CXX_STANDARD ${CXX_STD}) 25 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 26 | 27 | # Switch between release mode and debug mode 28 | # The user can use `-DBUILD_MODE=DEBUG` or `-DBUILD_MODE=RELEASE` to 29 | # choose the build mode. 30 | # If no option is provided, default to debug mode 31 | if (BUILD_MODE) 32 | string(TOUPPER ${BUILD_MODE} BUILD_MODE) 33 | if (BUILD_MODE STREQUAL "DEBUG") 34 | set(DEBUG ON) 35 | elseif (BUILD_MODE STREQUAL "RELEASE") 36 | set(RELEASE ON) 37 | else() 38 | message(FATAL_ERROR "Unknown build mode: ${BUILD_MODE}") 39 | endif() 40 | else() 41 | message("No build type selected, defaulting to RELEASE mode") 42 | message("Use -DBUILD_MODE=DEBUG or -DBUILD_MODE=RELEASE to specify build type") 43 | set(RELEASE ON) 44 | endif() 45 | 46 | # Set up C++ flag and CUDA flag 47 | if (DEBUG) 48 | message("Building in debug mode") 49 | set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -g -G -DDEBUG") 50 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Wall -Wextra -Wno-unused-parameter -Wno-unused-function -DDEBUG") 51 | else() 52 | message("Building in release mode") 53 | set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3 -DRELEASE -lineinfo --prec-div=false") 54 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Ofast -Wall -Wextra -Wno-unused-parameter -Wno-unused-function -DRELEASE") 55 | endif() 56 | 57 | # Set up COMMON_HEADER_DIRS and COMMON_LIB_DIRS 58 | set(COMMON_HEADER_DIRS 59 | ${PROJECT_SOURCE_DIR} 60 | ${PROJECT_SOURCE_DIR}/src/csrc 61 | ) 62 | set(COMMON_LIB_DIRS "") 63 | 64 | # Set up MPI and NCCL for multi-GPU communication 65 | message("Building with MPI and NCCL") 66 | set(CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/Modules) 67 | set(MKL_MPI "openmpi") 68 | find_package(NCCL REQUIRED) 69 | find_package(MPI REQUIRED) 70 | set(CMAKE_MODULE_PATH "") # prevent the bugs for pytorch building 71 | 72 | # Add MPI and NCCL into COMMON_HEADER_DIRS & COMMON_LIB_DIRS 73 | list(APPEND COMMON_HEADER_DIRS ${MPI_INCLUDE_PATH} ${NCCL_INCLUDE_DIR}) 74 | list(APPEND COMMON_LIB_DIRS ${MPI_LIBRARIES} ${NCCL_LIBRARIES}) 75 | 76 | set(COMMON_LIBS CUDA::cudart) 77 | 78 | # Add Python into COMMON_HEADER_DIRS & COMMON_LIB_DIRS 79 | set(PYTHON_PATH "python" CACHE STRING "Python path") 80 | execute_process(COMMAND ${PYTHON_PATH} "-c" "import sysconfig; 81 | print(sysconfig.get_paths()['include']);" 82 | RESULT_VARIABLE _PYTHON_SUCCESS 83 | OUTPUT_VARIABLE PY_INCLUDE_DIR) 84 | if (NOT _PYTHON_SUCCESS MATCHES 0) 85 | 
message(FATAL_ERROR "Python config Error.") 86 | endif() 87 | list(APPEND COMMON_HEADER_DIRS ${PY_INCLUDE_DIR}) 88 | 89 | 90 | # Add LibTorch into COMMON_HEADER_DIRS & COMMON_LIB_DIRS 91 | execute_process(COMMAND ${PYTHON_PATH} "-c" "import os; import torch; 92 | print(os.path.dirname(torch.__file__), end='');" 93 | RESULT_VARIABLE _PYTHON_SUCCESS 94 | OUTPUT_VARIABLE TORCH_DIR) 95 | if (NOT _PYTHON_SUCCESS MATCHES 0) 96 | message(FATAL_ERROR "Torch config Error.") 97 | endif() 98 | list(APPEND CMAKE_PREFIX_PATH ${TORCH_DIR}) 99 | set(CAFFE2_USE_CUDNN 1) 100 | find_package(Torch REQUIRED) 101 | list(APPEND COMMON_HEADER_DIRS "${TORCH_INCLUDE_DIRS}") 102 | list(APPEND COMMON_LIBS "${TORCH_LIBRARIES}") 103 | 104 | 105 | # Let COMMON_HEADER_DIRS & COMMON_LIB_DIRS take effect 106 | include_directories(${COMMON_HEADER_DIRS}) 107 | link_directories(${COMMON_LIB_DIRS}) 108 | link_libraries(${COMMON_LIBS}) 109 | 110 | 111 | # Should turn off CXX11 ABI if pytorch is built with CXX11 ABI off 112 | execute_process(COMMAND ${PYTHON_PATH} "-c" "import torch; 113 | print(torch._C._GLIBCXX_USE_CXX11_ABI,end='');" 114 | RESULT_VARIABLE _PYTHON_SUCCESS 115 | OUTPUT_VARIABLE USE_CXX11_ABI) 116 | message("-- USE_CXX11_ABI=${USE_CXX11_ABI}") 117 | if (USE_CXX11_ABI) 118 | set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=1") 119 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=1") 120 | else() 121 | set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0") 122 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0") 123 | endif() 124 | 125 | 126 | # GoogleTest Preparation - Code block copied from 127 | # https://google.github.io/googletest/quickstart-cmake.html 128 | include(FetchContent) 129 | FetchContent_Declare( 130 | googletest 131 | GIT_REPOSITORY https://github.com/google/googletest.git 132 | GIT_TAG release-1.12.1 133 | ) 134 | FetchContent_MakeAvailable(googletest) 135 | 136 | 137 | # nlohmann_json Preparation - Code block copied from 138 | # https://github.com/nlohmann/json#cmake 139 | FetchContent_Declare( 140 | json 141 | URL https://github.com/nlohmann/json/releases/download/v3.11.2/json.tar.xz 142 | ) 143 | FetchContent_MakeAvailable(json) 144 | 145 | 146 | # fetch latest argparse 147 | FetchContent_Declare( 148 | argparse 149 | GIT_REPOSITORY https://github.com/p-ranav/argparse.git 150 | ) 151 | FetchContent_MakeAvailable(argparse) 152 | 153 | # Let all executable targets go to bin 154 | set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) 155 | 156 | # Add subdirectories 157 | add_subdirectory(src) 158 | -------------------------------------------------------------------------------- /scripts/converter.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is used for converting weights from other models (e.g. OPT, LLaMA2) 3 | into SwiftTransformer's format. 
4 | 5 | Example usage: 6 | - Converting a single, unsharded weight: 7 | python3 converter.py --input /path/to/weight.pt --output /path/to/output --dtype fp16 --model opt 8 | - Converting a sharded weight: 9 | python3 converter.py --input /path/to/weight_*.pt --output /path/to/output --dtype fp16 --model llama2 10 | 11 | For the detailed workflow, please refer to comments in `converter_lib.py` 12 | """ 13 | import os, sys, argparse, re 14 | from glob import glob 15 | from typing import List, Optional 16 | 17 | import torch 18 | import lib.converter_lib as converter_lib 19 | 20 | assert __name__ == "__main__" 21 | 22 | def load_opt_weight(input: str) -> dict[str, torch.Tensor]: 23 | files = glob(input) 24 | if len(files) == 1: 25 | # unsharded weight. Load it directly 26 | return torch.load(files[0], torch.device("cpu"))["model"] 27 | 28 | def tensorMergeFunc(key: str, tensor_list: List[torch.Tensor]) -> Optional[torch.Tensor]: 29 | dim0_shard_regex = re.compile("embed_tokens|ffn_layernorm|fc1") 30 | dim1_shard_regex = re.compile("(fc2|out_proj).weight") 31 | shared_regex = re.compile( 32 | "embed_positions|layer_norm|(fc2|out_proj).bias|output_projection|version" 33 | ) 34 | to_ignore_regex = re.compile("decoder.version") 35 | if to_ignore_regex.search(key): 36 | # This weight should be ignored 37 | return None 38 | elif "qkv_proj.weight" in key: 39 | hidden_size = tensor_list[0].size(-1) 40 | return torch.cat(list(map(lambda x: x.view(3, -1, hidden_size), tensor_list)), dim=1).view(-1, hidden_size) 41 | elif "qkv_proj.bias" in key: 42 | return torch.cat(list(map(lambda x: x.view(3, -1), tensor_list)), dim = 1).view(-1) 43 | elif dim0_shard_regex.search(key): 44 | # This weight is sharded along dim 0 45 | return torch.cat(tensor_list, dim=0) 46 | elif dim1_shard_regex.search(key): 47 | # This weight is sharded along dim 1 48 | return torch.cat(tensor_list, dim=1) 49 | elif shared_regex.search(key): 50 | # This weight is shared across all shards 51 | return tensor_list[0] 52 | else: 53 | raise ValueError(f"Unrecognized weight key: {key}") 54 | 55 | result = converter_lib.reshardWeight( 56 | files, 57 | lambda x: x["model"], 58 | tensorMergeFunc 59 | ) 60 | 61 | return result 62 | 63 | def load_llama2_weight(input: str) -> dict[str, torch.Tensor]: 64 | files = glob(input) 65 | files = glob(input) 66 | if len(files) == 1: 67 | # unsharded weight. 
Load it directly 68 | return torch.load(files[0], torch.device("cpu"))["model"] 69 | 70 | def tensorMergeFunc(key: str, tensor_list: List[torch.Tensor]) -> Optional[torch.Tensor]: 71 | dim0_shard_regex = re.compile("\ 72 | layers.(\d+).feed_forward.w1.weight|\ 73 | layers.(\d+).feed_forward.w3.weight|\ 74 | layers.(\d+).attention.w(q|k|v).weight|\ 75 | output.weight") 76 | dim1_shard_regex = re.compile("\ 77 | layers.(\d+).feed_forward.w2.weight|\ 78 | layers.(\d+).attention.wo.weight|\ 79 | tok_embeddings.weight") 80 | shared_regex = re.compile("\ 81 | layers.(\d+).attention_norm.weight|\ 82 | layers.(\d+).ffn_norm.weight|\ 83 | norm.weight") 84 | to_ignore_regex = re.compile("rope.freqs") 85 | if to_ignore_regex.search(key): 86 | return None 87 | elif dim0_shard_regex.search(key): 88 | # This weight is sharded along dim 0 89 | return torch.cat(tensor_list, dim=0) 90 | elif dim1_shard_regex.search(key): 91 | # This weight is sharded along dim 1 92 | return torch.cat(tensor_list, dim=1) 93 | elif shared_regex.search(key): 94 | # This weight is shared across all shards 95 | return tensor_list[0] 96 | else: 97 | raise ValueError(f"Unrecognized weight key: {key}") 98 | 99 | result = converter_lib.reshardWeight( 100 | files, 101 | lambda x: x, 102 | tensorMergeFunc 103 | ) 104 | 105 | return result 106 | 107 | if __name__ == "__main__": 108 | parser = argparse.ArgumentParser(description="Convert weights from other models into SwiftTransformer's format.\ 109 | For example usage please refer to comments at the top of this file.") 110 | parser.add_argument("--input", type=str, required=True, help="Input checkpoint path or glob") 111 | parser.add_argument("--output", type=str, required=True, help="Output checkpoint path") 112 | parser.add_argument("--dtype", type=str, required=True, help="Output dtype") 113 | parser.add_argument("--model", type=str, required=True, help="Model name") 114 | 115 | args = parser.parse_args() 116 | input = args.input 117 | output = args.output 118 | dtype = args.dtype 119 | os.makedirs(output, exist_ok=True) 120 | 121 | torch_dtype = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16} 122 | assert dtype in torch_dtype, f"Unknown dtype {dtype}, expected one of {torch_dtype.keys()}" 123 | dtype = torch_dtype[dtype] 124 | 125 | supported_models = {"opt", "llama2"} 126 | assert args.model in supported_models, f"Unknown model {args.model}, expected one of {supported_models}" 127 | 128 | print(f"Converting {input} into torch.jit.script format") 129 | 130 | # Load the state dict (tensor_dict) 131 | # If the whole model is saved in a single file, then load the state dict directly 132 | # otherwise, load them separately and merge them into a single state dict 133 | if len(glob(input)) == 0: 134 | ValueError(f"Input {input} does not match any files") 135 | print(f"Input {input} does not match any files") 136 | exit(1) 137 | 138 | if args.model == "opt": 139 | state_dict = load_opt_weight(input) 140 | elif args.model == "llama2": 141 | state_dict = load_llama2_weight(input) 142 | else: 143 | raise ValueError(f"Unknown model {args.model}") 144 | 145 | print("Resharding and saving weights") 146 | converter_lib.convertWeight(output, state_dict, dtype, args.model) 147 | -------------------------------------------------------------------------------- /scripts/decode_output.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import lib.gpt_token_encoder as encoder 3 | 4 | if __name__ == "__main__": 5 | if 
len(sys.argv) != 3: 6 | print("Usage: python3 decode_output.py ") 7 | sys.exit(1) 8 | 9 | vocab_file = sys.argv[1] 10 | bpe_file = sys.argv[2] 11 | enc = encoder.get_encoder(vocab_file, bpe_file) 12 | for line in sys.stdin: 13 | line = line.strip() 14 | print(enc.decode(list(map(int, line.replace(',', ' ').split())))) 15 | -------------------------------------------------------------------------------- /scripts/encode_input.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import lib.gpt_token_encoder as encoder 3 | 4 | if __name__ == "__main__": 5 | if len(sys.argv) != 3: 6 | print("Usage: python3 encode_input.py ") 7 | sys.exit(1) 8 | 9 | vocab_file = sys.argv[1] 10 | bpe_file = sys.argv[2] 11 | enc = encoder.get_encoder(vocab_file, bpe_file) 12 | for line in sys.stdin: 13 | line = line.strip() 14 | print(" ".join(map(str, enc.encode(line)))) 15 | -------------------------------------------------------------------------------- /scripts/lib/gpt_token_encoder.py: -------------------------------------------------------------------------------- 1 | """Byte pair encoding utilities""" 2 | 3 | # Modified MIT License 4 | 5 | # Software Copyright (c) 2019 OpenAI 6 | 7 | # We don’t claim ownership of the content you create with GPT-2, so it is yours to do with as you please. 8 | # We only ask that you use GPT-2 responsibly and clearly indicate your content was created using GPT-2. 9 | 10 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and 11 | # associated documentation files (the "Software"), to deal in the Software without restriction, 12 | # including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, 13 | # and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, 14 | # subject to the following conditions: 15 | 16 | # The above copyright notice and this permission notice shall be included 17 | # in all copies or substantial portions of the Software. 18 | # The above copyright notice and this permission notice need not be included 19 | # with content created by the Software. 20 | 21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 22 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS 24 | # BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 25 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE 26 | # OR OTHER DEALINGS IN THE SOFTWARE. 27 | # Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved. 28 | # 29 | # Licensed under the Apache License, Version 2.0 (the "License"); 30 | # you may not use this file except in compliance with the License. 31 | # You may obtain a copy of the License at 32 | # 33 | #     http://www.apache.org/licenses/LICENSE-2.0 34 | # 35 | # Unless required by applicable law or agreed to in writing, software 36 | # distributed under the License is distributed on an "AS IS" BASIS, 37 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 38 | # See the License for the specific language governing permissions and 39 | # limitations under the License. 
40 | 41 | import os 42 | import json 43 | import regex as re 44 | from functools import lru_cache 45 | 46 | @lru_cache() 47 | def bytes_to_unicode(): 48 | """ 49 | Returns list of utf-8 byte and a corresponding list of unicode strings. 50 | The reversible bpe codes work on unicode strings. 51 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 52 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 53 | This is a significant percentage of your normal, say, 32K bpe vocab. 54 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 55 | And avoids mapping to whitespace/control characters the bpe code barfs on. 56 | """ 57 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 58 | cs = bs[:] 59 | n = 0 60 | for b in range(2**8): 61 | if b not in bs: 62 | bs.append(b) 63 | cs.append(2**8+n) 64 | n += 1 65 | cs = [chr(n) for n in cs] 66 | return dict(zip(bs, cs)) 67 | 68 | def get_pairs(word): 69 | """Return set of symbol pairs in a word. 70 | 71 | Word is represented as tuple of symbols (symbols being variable-length strings). 72 | """ 73 | pairs = set() 74 | prev_char = word[0] 75 | for char in word[1:]: 76 | pairs.add((prev_char, char)) 77 | prev_char = char 78 | return pairs 79 | 80 | class Encoder: 81 | def __init__(self, encoder, bpe_merges, errors='replace'): 82 | self.encoder = encoder 83 | self.decoder = {v:k for k,v in self.encoder.items()} 84 | self.errors = errors # how to handle errors in decoding 85 | self.byte_encoder = bytes_to_unicode() 86 | self.byte_decoder = {v:k for k, v in self.byte_encoder.items()} 87 | self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) 88 | self.cache = {} 89 | 90 | # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions 91 | self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") 92 | 93 | def bpe(self, token): 94 | if token in self.cache: 95 | return self.cache[token] 96 | word = tuple(token) 97 | pairs = get_pairs(word) 98 | 99 | if not pairs: 100 | return token 101 | 102 | while True: 103 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 104 | if bigram not in self.bpe_ranks: 105 | break 106 | first, second = bigram 107 | new_word = [] 108 | i = 0 109 | while i < len(word): 110 | try: 111 | j = word.index(first, i) 112 | new_word.extend(word[i:j]) 113 | i = j 114 | except: 115 | new_word.extend(word[i:]) 116 | break 117 | 118 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 119 | new_word.append(first+second) 120 | i += 2 121 | else: 122 | new_word.append(word[i]) 123 | i += 1 124 | new_word = tuple(new_word) 125 | word = new_word 126 | if len(word) == 1: 127 | break 128 | else: 129 | pairs = get_pairs(word) 130 | word = ' '.join(word) 131 | self.cache[token] = word 132 | return word 133 | 134 | def encode(self, text): 135 | bpe_tokens = [] 136 | for token in re.findall(self.pat, text): 137 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 138 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 139 | return bpe_tokens 140 | 141 | def decode(self, tokens): 142 | text = ''.join([self.decoder[token] for token in tokens]) 143 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) 144 | return text 145 | 146 | def get_encoder(vocab_file, 
bpe_file): 147 | with open(vocab_file, 'r') as f: 148 | encoder = json.load(f) 149 | with open(bpe_file, 'r', encoding="utf-8") as f: 150 | bpe_data = f.read() 151 | bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]] 152 | return Encoder( 153 | encoder=encoder, 154 | bpe_merges=bpe_merges, 155 | ) 156 | -------------------------------------------------------------------------------- /src/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(csrc) 2 | add_subdirectory(examples) 3 | add_subdirectory(unittest) 4 | -------------------------------------------------------------------------------- /src/csrc/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(kernel) 2 | add_subdirectory(layer) 3 | add_subdirectory(model) 4 | add_subdirectory(util) 5 | 6 | add_library(st_pybinding SHARED pybinding.cc) 7 | 8 | target_link_libraries(st_pybinding 9 | model_gpt 10 | model_opt 11 | model_llama2 12 | model_gpt2 13 | py_nccl_utils 14 | py_swapping 15 | py_block_migration 16 | ) 17 | 18 | # Set the output directory for the shared library 19 | set_target_properties(st_pybinding PROPERTIES 20 | LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/lib 21 | ) -------------------------------------------------------------------------------- /src/csrc/kernel/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_library(kernel STATIC 2 | addbias.cu 3 | count_nan.cu 4 | embedding.cu 5 | findmax.cu 6 | fused_activ_multiply.cu 7 | fused_addbias_activ.cu 8 | fused_context_stage_attention.cu 9 | fused_decoding_stage_attention.cu 10 | fused_decoding_stage_attention_mha.cu 11 | gather_last_tokens.cu 12 | kvcache_mgmt.cu 13 | layernorm.cu 14 | rmsnorm.cu 15 | rotary_posi_embedding.cu 16 | softmax.cu 17 | unfused_attention.cu 18 | ) 19 | target_link_libraries(kernel PUBLIC util) 20 | set_property(TARGET kernel PROPERTY POSITION_INDEPENDENT_CODE ON) 21 | 22 | file(GLOB xformers_autogen_impl_files ${CMAKE_CURRENT_SOURCE_DIR}/xformers/xformers/csrc/attention/cuda/fmha/autogen/impl/*.cu) 23 | add_library(xformers_autogen_impl STATIC ${xformers_autogen_impl_files}) 24 | target_include_directories(xformers_autogen_impl PUBLIC xformers/third_party/cutlass/include) 25 | set_property(TARGET xformers_autogen_impl PROPERTY POSITION_INDEPENDENT_CODE ON) 26 | 27 | add_library(xformers_kernel STATIC 28 | xformers_attention.cu 29 | ) 30 | target_include_directories(xformers_kernel PUBLIC xformers/third_party/cutlass/include) 31 | target_link_libraries(xformers_kernel PUBLIC util xformers_autogen_impl) 32 | set_property(TARGET xformers_kernel PROPERTY POSITION_INDEPENDENT_CODE ON) 33 | 34 | -------------------------------------------------------------------------------- /src/csrc/kernel/activation_types.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | /* 4 | activation_types.h & activations.cuh 5 | 6 | Activation functions for neural networks. 7 | 8 | We put the definition of ActivationType in activtion_types.h and the 9 | implementation of the activation functions in activations.cuh. This is 10 | because we want the activation functions to be inlined into our kernels 11 | (so we need to prepend `__forceinline__ __device__` to the function), while 12 | we want to be able to use the ActivationType enum in both the host and device. 
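   As a rough usage sketch (illustrative only; `x` is assumed to be a value a kernel has already
   loaded into a register and `T` its element type, both hypothetical here):

       #include "kernel/activation_types.h"   // the ActivationType enum, usable on host and device
       #include "kernel/activations.cuh"      // applyActivation, device-side only

       // inside a __global__ / __device__ function in namespace st::kernel:
       T y = applyActivation<T, ActivationType::SILU>(x);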
13 | 14 | If you are writing a kernel that uses an activation function, you should 15 | include activation_types.h and activations.cuh in your kernel file. If 16 | you are writing host code that only use the ActivationType enum, you should 17 | only include activation_types.h. 18 | */ 19 | 20 | namespace st { 21 | namespace kernel { 22 | enum class ActivationType { 23 | RELU, 24 | SILU, 25 | GELU 26 | }; 27 | } 28 | 29 | using ActivationType = kernel::ActivationType; 30 | } 31 | -------------------------------------------------------------------------------- /src/csrc/kernel/activations.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include "activation_types.h" 6 | #include "util/cuda_utils.h" 7 | 8 | // See comments in activation_types.h 9 | namespace st::kernel { 10 | template 11 | __forceinline__ __device__ T applyActivation(const T &x) { 12 | if constexpr (activation_type == ActivationType::RELU) { 13 | return x > (T)0 ? x : (T)0; 14 | } 15 | else if constexpr (activation_type == ActivationType::SILU) { 16 | return (T)((float)x / (1.0f + __expf((float)-x))); 17 | } 18 | else if constexpr (activation_type == ActivationType::GELU) { 19 | // NOTE. GELU has many different implementations, 20 | // this is the one currently used in vllm-project/vllm repo (gelu_new_kernel). 21 | // file url: https://github.com/vllm-project/vllm/blob/main/csrc/activation_kernels.cu 22 | const float x3 = (float) (x * x * x); 23 | const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3)))); 24 | return ((T) 0.5) * x * (((T) 1.0) + t); 25 | } 26 | else { 27 | // No activation matches, raise an error 28 | assert(false); 29 | } 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /src/csrc/kernel/addbias.cu: -------------------------------------------------------------------------------- 1 | #include "addbias.h" 2 | 3 | #include "util/debug_utils.h" 4 | #include "util/cuda_utils.h" 5 | 6 | namespace st::kernel { 7 | 8 | /* 9 | addbiasKernel & addbias 10 | 11 | This performs point-wise addition of two arrays, `input` and `bias`, and 12 | store the output to `output`. 
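   For illustration, a host-side call could look like the following sketch (hypothetical buffer
   names; d_out, d_in and d_bias are assumed to be device pointers of `size` elements, and `size`
   must be even because the kernel walks the data as half2/float2 pairs):

       st::kernel::addbias<half>(d_out, d_in, d_bias, size);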
13 | 14 | Input: 15 | - input: the input array, [size] 16 | - bias: the bias array, [size] 17 | - size: the size of the input and bias array 18 | Output: 19 | - output: the output array, [size] 20 | output[i] = input[i] + bias[i] 21 | */ 22 | template 23 | __global__ void addbiasKernel( 24 | T* output, 25 | const T* input, 26 | const T* bias, 27 | const int64_t size 28 | ) { 29 | typedef std::conditional_t, half2, float2> T2; 30 | #pragma unroll 4 31 | for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size/2; i += blockDim.x * gridDim.x) { 32 | T2 input_elem = ((const T2*)input)[i]; 33 | T2 bias_elem = ((const T2*)bias)[i]; 34 | T2 result_elem = {input_elem.x + bias_elem.x, input_elem.y + bias_elem.y}; 35 | ((T2*)output)[i] = result_elem; 36 | } 37 | } 38 | 39 | template 40 | void addbias( 41 | T* output, 42 | const T* input, 43 | const T* bias, 44 | const int64_t size 45 | ) { 46 | assert_whenever (size%2 == 0); 47 | const uint32_t blockSize = 256; 48 | const uint32_t gridSize = std::min((size/2 + blockSize - 1) / blockSize, 16384l); 49 | addbiasKernel<<>>(output, input, bias, size); 50 | } 51 | 52 | template void addbias(half* output, const half* input, const half* bias, const int64_t size); 53 | template void addbias(float* output, const float* input, const float* bias, const int64_t size); 54 | 55 | 56 | /* 57 | addbiasBatchedKernel & addbiasBatched 58 | 59 | This performs batched addbias. 60 | 61 | Input: 62 | - input: the input array, [batch, size] 63 | - bias: the bias array, [size] 64 | - batch: the batch size 65 | - size: the size of the input and bias array 66 | Output: 67 | - output: the output array, [batch, size] 68 | output[b][i] = input[b][i] + bias[i] 69 | TODO Optimize this kernel 70 | */ 71 | template 72 | __global__ void addbiasBatchedKernel( 73 | T* output, 74 | const T* input, 75 | const T* bias, 76 | const int64_t batch, 77 | const int64_t size 78 | ) { 79 | typedef std::conditional_t, half2, float2> T2; 80 | #pragma unroll 4 81 | for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size * batch/2; i += blockDim.x * gridDim.x) { 82 | const int64_t s = i % (size/2); 83 | T2 input_elem = ((const T2*)input)[i]; 84 | T2 bias_elem = ((const T2*)bias)[s]; 85 | T2 result_elem = {input_elem.x + bias_elem.x, input_elem.y + bias_elem.y}; 86 | ((T2*)output)[i] = result_elem; 87 | } 88 | } 89 | 90 | template 91 | void addbiasBatched( 92 | T* output, 93 | const T* input, 94 | const T* bias, 95 | const int64_t batch, 96 | const int64_t size 97 | ) { 98 | assert_whenever (size%2 == 0); 99 | const uint32_t blockSize = 256; 100 | const uint32_t gridSize = std::min((size*batch/2 + blockSize - 1) / blockSize, 16384l); 101 | addbiasBatchedKernel<<>>(output, input, bias, batch, size); 102 | } 103 | 104 | template void addbiasBatched(half* output, const half* input, const half* bias, const int64_t batch, const int64_t size); 105 | template void addbiasBatched(float* output, const float* input, const float* bias, const int64_t batch, const int64_t size); 106 | 107 | } // namespace st::kernel 108 | -------------------------------------------------------------------------------- /src/csrc/kernel/addbias.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | namespace st::kernel { 6 | 7 | template 8 | void addbias( 9 | T* output, 10 | const T* input, 11 | const T* bias, 12 | const int64_t size 13 | ); 14 | 15 | template 16 | void addbiasBatched( 17 | T* output, 18 | const T* input, 19 | const T* 
bias, 20 | const int64_t batch_size, 21 | const int64_t size 22 | ); 23 | 24 | } // namespace st::kernel -------------------------------------------------------------------------------- /src/csrc/kernel/count_nan.cu: -------------------------------------------------------------------------------- 1 | #include "count_nan.h" 2 | 3 | #include "util/cuda_utils.h" 4 | 5 | namespace st::kernel { 6 | 7 | template 8 | __global__ void countNanKernel( 9 | int* count, 10 | const T* arr, 11 | int n 12 | ) { 13 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) { 14 | if (arr[i] != arr[i]) { 15 | atomicAdd(count, 1); 16 | } 17 | } 18 | } 19 | 20 | template 21 | int countNan( 22 | const T* arr, 23 | int n 24 | ) { 25 | int* count; 26 | cudaMalloc(&count, sizeof(int)); 27 | cudaMemset(count, 0, sizeof(int)); 28 | 29 | int blockSize = 256; 30 | int numBlocks = (n + blockSize - 1) / blockSize; 31 | countNanKernel<<>> (count, arr, n); 32 | 33 | int res; 34 | cudaMemcpy(&res, count, sizeof(int), cudaMemcpyDeviceToHost); 35 | cudaFree(count); 36 | return res; 37 | } 38 | 39 | #define INSTANTIATE(T) \ 40 | template int countNan(const T* arr, int n); 41 | 42 | INSTANTIATE(float) 43 | INSTANTIATE(half) 44 | 45 | } -------------------------------------------------------------------------------- /src/csrc/kernel/count_nan.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace st::kernel { 4 | 5 | template 6 | int countNan( 7 | const T* arr, 8 | int n 9 | ); 10 | 11 | } -------------------------------------------------------------------------------- /src/csrc/kernel/embedding.cu: -------------------------------------------------------------------------------- 1 | #include "embedding.h" 2 | 3 | namespace st::kernel { 4 | 5 | /* 6 | embedAndPosiEncodeBatched 7 | 8 | Perform batched input embedding & positional encoding. 9 | 10 | This kernel applies embedding and positional encoding on a batch of input sequences. 11 | 12 | Note. If you do not want to perform positional encoding (e.g. when using models that 13 | adopts rotary embedding), pass a nullptr to embed_positions_weight. 14 | */ 15 | 16 | template 17 | __global__ void embedAndPosiEncodeBatchedKernel ( 18 | T* __restrict__ result, 19 | const int64_t* __restrict__ token_ids, 20 | const int64_t* __restrict__ position_ids, 21 | const T* __restrict__ embed_tokens_weight, 22 | const T* __restrict__ embed_positions_weight, 23 | const int64_t hidden_size 24 | ) { 25 | const int64_t my_token_id = token_ids[blockIdx.x]; 26 | const int64_t my_position_id = DO_POSI_ENCODING ? 
position_ids[blockIdx.x] : 0; 27 | #pragma unroll 4 28 | for (int64_t hidden_size_index = threadIdx.x; hidden_size_index < hidden_size; hidden_size_index += blockDim.x) { 29 | T cur_result = embed_tokens_weight[my_token_id * hidden_size + hidden_size_index]; 30 | if constexpr (DO_POSI_ENCODING) { 31 | cur_result += embed_positions_weight[my_position_id * hidden_size + hidden_size_index]; 32 | } 33 | result[blockIdx.x * hidden_size + hidden_size_index] = cur_result; 34 | } 35 | } 36 | 37 | template 38 | void embedAndPosiEncodeBatched( 39 | T* result, 40 | const int64_t* token_ids, // [num_tokens] 41 | const int64_t* position_ids, // [num_tokens] 42 | const T* embed_tokens_weight, // [vocab_size, hidden_size] 43 | const T* embed_positions_weight, // [max_position_embeddings, hidden_size] 44 | const int64_t num_tokens, 45 | const int64_t hidden_size 46 | ) { 47 | bool perform_posi_encoding = embed_positions_weight != nullptr; 48 | if (perform_posi_encoding) { 49 | embedAndPosiEncodeBatchedKernel<<>>( 50 | result, 51 | token_ids, 52 | position_ids, 53 | embed_tokens_weight, 54 | embed_positions_weight, 55 | hidden_size 56 | ); 57 | } else { 58 | embedAndPosiEncodeBatchedKernel<<>>( 59 | result, 60 | token_ids, 61 | position_ids, 62 | embed_tokens_weight, 63 | embed_positions_weight, 64 | hidden_size 65 | ); 66 | } 67 | } 68 | 69 | #define INSTANTIATE(T) \ 70 | template void embedAndPosiEncodeBatched( \ 71 | T* result, \ 72 | const int64_t* token_ids, \ 73 | const int64_t* position_ids, \ 74 | const T* embed_tokens_weight, \ 75 | const T* embed_positions_weight, \ 76 | const int64_t num_tokens, \ 77 | const int64_t hidden_size \ 78 | ); 79 | 80 | INSTANTIATE(half) 81 | INSTANTIATE(float) 82 | 83 | } // namespace st::kernel -------------------------------------------------------------------------------- /src/csrc/kernel/embedding.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "util/cuda_utils.h" 4 | 5 | namespace st::kernel { 6 | 7 | template 8 | void embedAndPosiEncodeBatched( 9 | T* result, 10 | const int64_t* token_ids, 11 | const int64_t* position_ids, 12 | const T* embed_tokens_weight, 13 | const T* embed_positions_weight, 14 | const int64_t num_tokens, 15 | const int64_t hidden_size 16 | ); 17 | 18 | } // namespace st::kernel -------------------------------------------------------------------------------- /src/csrc/kernel/findmax.cu: -------------------------------------------------------------------------------- 1 | #include "findmax.h" 2 | 3 | #include 4 | 5 | #include "util/cuda_utils.h" 6 | #include "reduction.cuh" 7 | 8 | namespace st::kernel { 9 | 10 | /* 11 | findmaxBatchedKernel & findmaxBatched 12 | 13 | For each of the `batch_size` rows of `input` (each row has `length` elements), find the index of the row's maximum value and write it to `max_indices`. 14 | */ 15 | template 16 | __global__ void findmaxBatchedKernel( 17 | int64_t* max_indices, 18 | const T* input, 19 | const int64_t batch_size, 20 | const int64_t length 21 | ) { 22 | __shared__ T s_max; 23 | T local_max = -65400; // sentinel close to the lowest finite half value 24 | int64_t local_max_index = 0; 25 | for (int64_t i = threadIdx.x; i < length; i += blockDim.x) { 26 | if (input[i + length*blockIdx.x] > local_max) { 27 | local_max = input[i + length*blockIdx.x]; 28 | local_max_index = i; 29 | } 30 | } 31 | T max_val = (T)(blockDim.x <= 32 ? 
warpReduceMax((float)local_max) : blockReduceMax((float)local_max)); 32 | if (threadIdx.x == 0) { 33 | s_max = max_val; 34 | } 35 | __syncthreads(); 36 | if (local_max == s_max) { 37 | max_indices[blockIdx.x] = local_max_index; 38 | } 39 | } 40 | 41 | template 42 | void findmaxBatched( 43 | int64_t* max_indices, // [batch_size] 44 | const T* input, // [batch_size, length] 45 | const int64_t batch_size, 46 | const int64_t length 47 | ) { 48 | int64_t threads = 1024; 49 | findmaxBatchedKernel<<>>(max_indices, input, batch_size, length); 50 | } 51 | 52 | template void findmaxBatched(int64_t* max_indices, const half* input, const int64_t batch_size, const int64_t length); 53 | template void findmaxBatched(int64_t* max_indices, const float* input, const int64_t batch_size, const int64_t length); 54 | 55 | 56 | } // namespace st::kernel -------------------------------------------------------------------------------- /src/csrc/kernel/findmax.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace st::kernel { 4 | 5 | template 6 | void findmaxBatched( 7 | int64_t* max_indices, 8 | const T* input, 9 | const int64_t batch_size, 10 | const int64_t length 11 | ); 12 | 13 | } // namespace st::kernel -------------------------------------------------------------------------------- /src/csrc/kernel/fused_activ_multiply.cu: -------------------------------------------------------------------------------- 1 | #include "fused_activ_multiply.h" 2 | 3 | #include 4 | 5 | #include "util/cuda_utils.h" 6 | #include "util/debug_utils.h" 7 | #include "kernel/activations.cuh" 8 | 9 | namespace st::kernel { 10 | 11 | /* 12 | fusedSiluMultiplyKernel 13 | 14 | Given two arrays, input1 and input2, compute the following: 15 | output[i] = silu(input1[i]) * input2[i] 16 | */ 17 | 18 | template 19 | __global__ void fusedActivationMultiplyKernel( 20 | T* output, 21 | const T* input1, 22 | const T* input2, 23 | int64_t n 24 | ) { 25 | typedef std::conditional_t, half2, float2> T2; 26 | #pragma unroll 4 27 | for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n/2; i += blockDim.x * gridDim.x) { 28 | T2 input1_elem = ((const T2*)input1)[i]; 29 | T2 input2_elem = ((const T2*)input2)[i]; 30 | T2 output_elem = { 31 | applyActivation(input1_elem.x) * input2_elem.x, 32 | applyActivation(input1_elem.y) * input2_elem.y 33 | }; 34 | ((T2*)output)[i] = output_elem; 35 | } 36 | } 37 | 38 | template 39 | void fusedActivationMultiply( 40 | T* output, 41 | const T* input1, 42 | const T* input2, 43 | int64_t n, 44 | ActivationType activation_type 45 | ) { 46 | assert_whenever (n%2 == 0); 47 | int blockSize = 256; 48 | int gridSize = (n/2 + blockSize - 1) / blockSize; 49 | switch (activation_type) { 50 | case ActivationType::RELU: 51 | fusedActivationMultiplyKernel<<>>(output, input1, input2, n); 52 | return; 53 | case ActivationType::SILU: 54 | fusedActivationMultiplyKernel<<>>(output, input1, input2, n); 55 | return; 56 | case ActivationType::GELU: 57 | fusedActivationMultiplyKernel<<>>(output, input1, input2, n); 58 | return; 59 | default: 60 | assert (false); 61 | } 62 | } 63 | 64 | #define INSTANTIALIZE(T) \ 65 | template void fusedActivationMultiply( \ 66 | T* output, \ 67 | const T* input1, \ 68 | const T* input2, \ 69 | int64_t n, \ 70 | ActivationType activationType \ 71 | ); 72 | 73 | INSTANTIALIZE(half) 74 | INSTANTIALIZE(float) 75 | 76 | } // namespace st::kernel -------------------------------------------------------------------------------- 
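For context: this "activation(gate) * up" element-wise product is the step a gated FFN (e.g. the SwiGLU variant used by LLaMA-style models) needs between its two input GEMMs, and the gated_ffn layer in this repository presumably calls it for that purpose. A minimal host-side sketch, with hypothetical buffer names and `n` assumed even (the kernel asserts n % 2 == 0):

    // gate_out, up_out: [n] device buffers produced by the two up-projection GEMMs
    st::kernel::fusedActivationMultiply<half>(
        ffn_intermediate,          // output[i] = silu(gate_out[i]) * up_out[i]
        gate_out,                  // input1, passed through the activation
        up_out,                    // input2, multiplied element-wise
        n,
        st::ActivationType::SILU);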
/src/csrc/kernel/fused_activ_multiply.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | #include "kernel/activation_types.h" 6 | namespace st::kernel { 7 | 8 | template 9 | void fusedActivationMultiply( 10 | T* output, 11 | const T* input1, 12 | const T* input2, 13 | int64_t n, 14 | ActivationType activationType 15 | ); 16 | 17 | } // namespace st::kernel -------------------------------------------------------------------------------- /src/csrc/kernel/fused_addbias_activ.cu: -------------------------------------------------------------------------------- 1 | #include "fused_addbias_activ.h" 2 | 3 | #include 4 | 5 | #include "activations.cuh" 6 | #include "util/debug_utils.h" 7 | #include "util/cuda_utils.h" 8 | 9 | namespace st::kernel { 10 | 11 | /* 12 | fusedAddbiasBatchedActivation 13 | 14 | This kernel is used to add bias to input and then apply the activation function. 15 | We fuse the two kernels into one to improve performance (by reduce the number of memory accesses). 16 | 17 | Input: 18 | - input: the input array, [size] 19 | - bias: the bias array, [size] 20 | - size: the size of input and bias 21 | Output: 22 | - output: the output array, [size] 23 | output[i] = activation(input[i] + bias[i]) 24 | */ 25 | template 26 | __global__ void fusedAddbiasBatchedActivationKernel( 27 | T* output, 28 | const T* input, 29 | const T* bias, 30 | const int64_t batch, 31 | const int64_t size 32 | ) { 33 | typedef std::conditional_t, half2, float2> T2; 34 | #pragma unroll 4 35 | for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size * batch / 2; i += blockDim.x * gridDim.x) { 36 | const int64_t s = i % (size/2); 37 | T2 input_elem = ((const T2*)input)[i]; 38 | T2 bias_elem = ((const T2*)bias)[s]; 39 | T2 output_elem = { 40 | applyActivation(input_elem.x + bias_elem.x), 41 | applyActivation(input_elem.y + bias_elem.y) 42 | }; 43 | ((T2*)output)[i] = output_elem; 44 | } 45 | } 46 | 47 | 48 | template 49 | void fusedAddbiasBatchedActivation( 50 | T* output, 51 | const T* input, 52 | const T* bias, 53 | const int64_t batch, 54 | const int64_t size, 55 | ActivationType activation_type 56 | ) { 57 | assert_whenever (size%2 == 0); 58 | const uint32_t blockSize = 256; 59 | const uint32_t gridSize = (size*batch/2 + blockSize - 1) / blockSize; 60 | switch (activation_type) { 61 | case ActivationType::RELU: 62 | fusedAddbiasBatchedActivationKernel<<>>(output, input, bias, batch, size); 63 | break; 64 | case ActivationType::SILU: 65 | fusedAddbiasBatchedActivationKernel<<>>(output, input, bias, batch, size); 66 | break; 67 | case ActivationType::GELU: 68 | fusedAddbiasBatchedActivationKernel<<>>(output, input, bias, batch, size); 69 | break; 70 | default: 71 | assert(false); 72 | } 73 | } 74 | 75 | template void fusedAddbiasBatchedActivation(half* output, const half* input, const half* bias, const int64_t batch, const int64_t size, ActivationType activation_type); 76 | template void fusedAddbiasBatchedActivation(float* output, const float* input, const float* bias, const int64_t batch, const int64_t size, ActivationType activation_type); 77 | 78 | } // namespace st::kernel -------------------------------------------------------------------------------- /src/csrc/kernel/fused_addbias_activ.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #include "activation_types.h" 7 | 8 | namespace st::kernel { 9 | 10 | template 11 | void 
fusedAddbiasBatchedActivation( 12 | T* output, 13 | const T* input, 14 | const T* bias, 15 | const int64_t batch, 16 | const int64_t size, 17 | ActivationType activation_type 18 | ); 19 | 20 | } // namespace st::kernel -------------------------------------------------------------------------------- /src/csrc/kernel/fused_context_stage_attention.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "util/cuda_utils.h" 4 | 5 | namespace st::kernel { 6 | 7 | template 8 | void fusedContextStageAttention( 9 | T* __restrict__ result, 10 | const T* __restrict__ qkvs, 11 | const float qk_scale, 12 | const int64_t* __restrict__ input_lens, 13 | const int64_t num_context_reqs, 14 | const int64_t* __restrict__ ith_context_req_req_index, 15 | const int32_t* __restrict__ ith_context_req_token_index, 16 | const int64_t num_q_heads, 17 | const int64_t num_kv_heads, 18 | const int64_t head_dim, 19 | const int64_t num_tokens, 20 | float* __restrict__ m_buf, 21 | float* __restrict__ l_buf 22 | ); 23 | 24 | } // namespace st::kernel -------------------------------------------------------------------------------- /src/csrc/kernel/fused_decoding_stage_attention.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace st::kernel { 4 | 5 | template 6 | void fusedDecodingStageAttention( 7 | T* __restrict__ result, 8 | const T* __restrict__ qkvs, 9 | T* k_cache, 10 | T* v_cache, 11 | const float scale, 12 | const int64_t* __restrict__ block_table, 13 | const int64_t* __restrict__ input_lens, 14 | const int64_t num_decoding_reqs, 15 | const int64_t* __restrict__ ith_decoding_req_req_index, 16 | const int64_t* __restrict__ ith_decoding_req_token_index, 17 | const int64_t max_decoding_req_len, 18 | const int64_t num_layers, 19 | const int64_t num_q_heads, 20 | const int64_t num_kv_heads, 21 | const int64_t head_dim, 22 | const int64_t layer_id, 23 | const int64_t block_size, 24 | const int64_t max_num_block_per_seq 25 | ); 26 | 27 | template 28 | void fusedDecodingStageAttentionMHA( 29 | T* __restrict__ result, 30 | const T* __restrict__ qkvs, 31 | T* k_cache, 32 | T* v_cache, 33 | const float scale, 34 | const int64_t* __restrict__ block_table, 35 | const int64_t* __restrict__ input_lens, 36 | const int64_t num_decoding_reqs, 37 | const int64_t* __restrict__ ith_decoding_req_req_index, 38 | const int64_t* __restrict__ ith_decoding_req_token_index, 39 | const int64_t max_decoding_req_len, 40 | const int64_t num_layers, 41 | const int64_t num_heads, 42 | const int64_t head_dim, 43 | const int64_t layer_id, 44 | const int64_t block_size, 45 | const int64_t max_num_block_per_seq 46 | ); 47 | 48 | } // namespace st::kernel -------------------------------------------------------------------------------- /src/csrc/kernel/gather_last_tokens.cu: -------------------------------------------------------------------------------- 1 | #include "gather_last_tokens.h" 2 | 3 | namespace st::kernel { 4 | 5 | /* 6 | gatherLastTokens 7 | 8 | Gather the last token from each request into a new array. 9 | 10 | When we just finished forwardDecoder, the layout of the decoder output contains a bunch of 11 | tokens, including prompts of requests that are in context stage and the last token of the 12 | requests that are in decoding stage. 
For example, assume request 0 and 2 are in context 13 | stage while request 1, 3, 4 are in decoding stage, then the array looks like: 14 | | t00 | t01 | t02 | t03 | t12 | t20 | t21 | t22 | t23 | t24 | t34 | t42 | 15 | Where tij means the j-th token of request i. 16 | 17 | However, only the last token from each request needs to be sampled, so we need to gather 18 | the last token of each request together. For the above example, the result should be: 19 | | t03 | t12 | t24 | t34 | t42 | 20 | 21 | That is what this kernel performs. 22 | */ 23 | template 24 | __global__ void gatherLastTokensKernel( 25 | T* __restrict__ result, // [batch_size, hidden_dim] 26 | const T* __restrict__ tokens, // [num_tokens, hidden_dim] 27 | const int64_t num_tokens, 28 | const int64_t batch_size, 29 | const int64_t hidden_dim, 30 | const int64_t* __restrict__ sum_prev_input_lens 31 | ) { 32 | int64_t token_index = blockIdx.x == batch_size-1 ? num_tokens-1 : sum_prev_input_lens[blockIdx.x+1]-1; 33 | #pragma unroll 4 34 | for (int64_t hidden_dim_index = threadIdx.x; hidden_dim_index < hidden_dim; hidden_dim_index += blockDim.x) { 35 | result[blockIdx.x*hidden_dim + hidden_dim_index] = tokens[token_index*hidden_dim + hidden_dim_index]; 36 | } 37 | } 38 | 39 | template 40 | void gatherLastTokens( 41 | T* result, 42 | const T* tokens, 43 | const int64_t num_tokens, 44 | const int64_t batch_size, 45 | const int64_t hidden_dim, 46 | const int64_t* sum_prev_input_lens 47 | ) { 48 | gatherLastTokensKernel<<>>(result, tokens, num_tokens, batch_size, hidden_dim, sum_prev_input_lens); 49 | } 50 | 51 | #define INSTANTIATE(T) \ 52 | template void gatherLastTokens( \ 53 | T* result, \ 54 | const T* tokens, \ 55 | const int64_t num_tokens, \ 56 | const int64_t batch_size, \ 57 | const int64_t hidden_dim, \ 58 | const int64_t* sum_prev_input_lens \ 59 | ); 60 | 61 | INSTANTIATE(half) 62 | INSTANTIATE(float) 63 | 64 | } // namespace st::kernel -------------------------------------------------------------------------------- /src/csrc/kernel/gather_last_tokens.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #include "util/cuda_utils.h" 7 | 8 | namespace st::kernel { 9 | 10 | template 11 | void gatherLastTokens( 12 | T* result, 13 | const T* tokens, 14 | const int64_t num_tokens, 15 | const int64_t batch_size, 16 | const int64_t hidden_dim, 17 | const int64_t* sum_prev_input_lens 18 | ); 19 | 20 | } // namespace st::kernel -------------------------------------------------------------------------------- /src/csrc/kernel/kvcache_mgmt.cu: -------------------------------------------------------------------------------- 1 | #include "kvcache_mgmt.h" 2 | 3 | #include 4 | #include 5 | 6 | #include "util/cuda_utils.h" 7 | 8 | namespace st::kernel { 9 | 10 | #define WARP_SIZE 32 11 | 12 | // Tuneable parameters 13 | constexpr int64_t DEFAULT_THREAD_BLOCK_SIZE = 512; 14 | 15 | /* 16 | saveContextStageKVCache 17 | 18 | This kernel takes q/k/vs that are just calculated and stores them into 19 | k/v_cache. 20 | 21 | ## Implementation Details 22 | 23 | The size of the grid is (num_kv_heads, num_context_reqs). In other words, 24 | each thread block is assigned to a particular request and head. 25 | 26 | Each token ([HEAD_DIM]) is assigned to a particular warp. 
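   A concrete illustration (made-up sizes): with num_kv_heads = 8 and num_context_reqs = 4,
   the host code below launches grid = dim3(8, 4), and the thread block with blockIdx.x = h,
   blockIdx.y = r copies the K/V vectors of every prompt token of the r-th context-stage
   request for KV head h, roughly as follows (assuming the INDEX_2D helper is a plain
   row-major index):

       for (int64_t token = warp_id; token < input_len; token += NUM_WARPS) {
           int64_t block  = block_table[req_index * max_num_block_per_seq + token / BLOCK_SIZE];
           int64_t offset = token % BLOCK_SIZE;
           // the lanes of the warp then copy the HEAD_DIM values of K and V for this token
           // into (block, head_id, offset) of the cache as half2/float2 pairs
       }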
27 | */ 28 | 29 | template< 30 | typename T, 31 | int64_t HEAD_DIM, 32 | int64_t BLOCK_SIZE, 33 | int64_t THREAD_BLOCK_SIZE 34 | > __global__ void saveContextStageKVCacheKernel( 35 | T* __restrict__ k_cache_offseted, // [num_block, num_layers, num_kv_heads, block_size, head_dim] 36 | // The OFFSETed k_cache. 37 | // The shape of k_cache is [num_blocks, num_layers, num_kv_heads, block_size, head_dim] 38 | // This k_cache_offseted is real k_cache + layer_id*num_kv_heads*block_size*head_dim 39 | // So we does not need another register for storing layer_id 40 | T* __restrict__ v_cache_offseted, 41 | 42 | const T* __restrict__ qkvs, // [num_tokens, num_q_heads+2*num_kv_heads, head_dim] 43 | const int64_t* __restrict__ block_table, // [num_reqs, max_num_block_per_seq] 44 | 45 | const int64_t* __restrict__ input_lens, // [num_reqs] 46 | const int64_t* __restrict__ ith_context_req_req_index, // [num_context_reqs] 47 | const int32_t* __restrict__ ith_context_req_token_index, // [num_context_reqs] 48 | 49 | const int64_t max_num_block_per_seq, 50 | const int64_t num_layers, 51 | const int64_t num_q_heads 52 | ) { 53 | typedef std::conditional_t::value, half2, float2> T2; 54 | constexpr int64_t NUM_WARPS = THREAD_BLOCK_SIZE / WARP_SIZE; 55 | 56 | const int64_t num_kv_heads = gridDim.x; 57 | const int64_t head_id = blockIdx.x; 58 | 59 | const int64_t req_index = ith_context_req_req_index[blockIdx.y]; 60 | const int64_t first_token_index = ith_context_req_token_index[blockIdx.y]; 61 | const int64_t input_len = input_lens[req_index]; 62 | 63 | const int64_t warp_id = threadIdx.x / WARP_SIZE; 64 | const int64_t lane_id = threadIdx.x % WARP_SIZE; 65 | 66 | for (int64_t token_index = warp_id; token_index < input_len; token_index += NUM_WARPS) { 67 | int64_t block_index = block_table[INDEX_2D(0, max_num_block_per_seq, req_index, token_index/BLOCK_SIZE)]; 68 | int64_t offset_in_block = token_index % BLOCK_SIZE; 69 | #pragma unroll 70 | for (int64_t hd_index = lane_id; hd_index < HEAD_DIM/2; hd_index += WARP_SIZE) { 71 | // "hd" stands for "head dim" 72 | int64_t kvcache_index = INDEX_5D( 73 | 0, num_layers, num_kv_heads, BLOCK_SIZE, HEAD_DIM/2, 74 | block_index, 0, head_id, offset_in_block, hd_index 75 | ); 76 | ((T2*)k_cache_offseted)[kvcache_index] = ((T2*)qkvs)[INDEX_3D(0, num_q_heads+2*num_kv_heads, HEAD_DIM/2, first_token_index+token_index, num_q_heads+head_id, hd_index)]; 77 | ((T2*)v_cache_offseted)[kvcache_index] = ((T2*)qkvs)[INDEX_3D(0, num_q_heads+2*num_kv_heads, HEAD_DIM/2, first_token_index+token_index, num_q_heads+num_kv_heads+head_id, hd_index)]; 78 | } 79 | } 80 | } 81 | 82 | #define LAUNCH_SAVE_CONTEXT_STAGE_KVCACHE_KERNEL(T, HEAD_DIM, BLOCK_SIZE) \ 83 | saveContextStageKVCacheKernel \ 84 | <<>> \ 85 | (k_cache_offseted, v_cache_offseted, qkvs, block_table, input_lens, \ 86 | ith_context_req_req_index, ith_context_req_token_index, max_num_block_per_seq, num_layers, num_q_heads) 87 | 88 | #define DISPATCH_SAVE_CONTEXT_STAGE_KVCACHE_KERNEL_HEAD_DIM_BLOCK_SIZE(T, HEAD_DIM) \ 89 | switch (block_size) { \ 90 | case 1: LAUNCH_SAVE_CONTEXT_STAGE_KVCACHE_KERNEL(T, HEAD_DIM, 1); break; \ 91 | case 2: LAUNCH_SAVE_CONTEXT_STAGE_KVCACHE_KERNEL(T, HEAD_DIM, 2); break; \ 92 | case 4: LAUNCH_SAVE_CONTEXT_STAGE_KVCACHE_KERNEL(T, HEAD_DIM, 4); break; \ 93 | case 8: LAUNCH_SAVE_CONTEXT_STAGE_KVCACHE_KERNEL(T, HEAD_DIM, 8); break; \ 94 | case 16: LAUNCH_SAVE_CONTEXT_STAGE_KVCACHE_KERNEL(T, HEAD_DIM, 16); break; \ 95 | case 32: LAUNCH_SAVE_CONTEXT_STAGE_KVCACHE_KERNEL(T, HEAD_DIM, 32); break; \ 96 | default: 
fprintf(stderr, "Unsupported block_size: %ld\n", block_size); assert(0); \ 97 | } 98 | 99 | #define DISPATCH_SAVE_CONTEXT_STAGE_KVCACHE_KERNEL_HEAD_DIM(T) \ 100 | switch (head_dim) { \ 101 | case 64: DISPATCH_SAVE_CONTEXT_STAGE_KVCACHE_KERNEL_HEAD_DIM_BLOCK_SIZE(T, 64); break; \ 102 | case 80: DISPATCH_SAVE_CONTEXT_STAGE_KVCACHE_KERNEL_HEAD_DIM_BLOCK_SIZE(T, 80); break; \ 103 | case 96: DISPATCH_SAVE_CONTEXT_STAGE_KVCACHE_KERNEL_HEAD_DIM_BLOCK_SIZE(T, 96); break; \ 104 | case 112: DISPATCH_SAVE_CONTEXT_STAGE_KVCACHE_KERNEL_HEAD_DIM_BLOCK_SIZE(T, 112); break; \ 105 | case 128: DISPATCH_SAVE_CONTEXT_STAGE_KVCACHE_KERNEL_HEAD_DIM_BLOCK_SIZE(T, 128); break; \ 106 | default: fprintf(stderr, "Unsupported head_dim: %ld\n", head_dim); assert(0); \ 107 | } 108 | 109 | template 110 | void saveContextStageKVCache( 111 | T* k_cache, 112 | T* v_cache, 113 | 114 | const T* qkvs, 115 | const int64_t* block_table, 116 | 117 | const int64_t* input_lens, 118 | const int64_t num_context_reqs, 119 | const int64_t* ith_context_req_req_index, 120 | const int32_t* ith_context_req_token_index, 121 | 122 | const int64_t block_size, 123 | const int64_t max_num_block_per_seq, 124 | const int64_t num_layers, 125 | const int64_t num_q_heads, 126 | const int64_t num_kv_heads, 127 | const int64_t head_dim, 128 | const int64_t layer_id 129 | ) { 130 | T* k_cache_offseted = k_cache + layer_id * num_kv_heads * block_size * head_dim; 131 | T* v_cache_offseted = v_cache + layer_id * num_kv_heads * block_size * head_dim; 132 | dim3 grid_dim(num_kv_heads, num_context_reqs); 133 | DISPATCH_SAVE_CONTEXT_STAGE_KVCACHE_KERNEL_HEAD_DIM(T); 134 | } 135 | 136 | #define INSTANTIATE_SAVE_CONTEXT_STAGE_KVCACHE(T) \ 137 | template void saveContextStageKVCache( \ 138 | T* k_cache, \ 139 | T* v_cache, \ 140 | const T* qkvs, \ 141 | const int64_t* block_table, \ 142 | const int64_t* input_lens, \ 143 | const int64_t num_context_reqs, \ 144 | const int64_t* ith_context_req_req_index, \ 145 | const int32_t* ith_context_req_token_index, \ 146 | const int64_t block_size, \ 147 | const int64_t max_num_block_per_seq, \ 148 | const int64_t num_layers, \ 149 | const int64_t num_q_heads, \ 150 | const int64_t num_kv_heads, \ 151 | const int64_t head_dim, \ 152 | const int64_t layer_id \ 153 | ); 154 | 155 | INSTANTIATE_SAVE_CONTEXT_STAGE_KVCACHE(float); 156 | INSTANTIATE_SAVE_CONTEXT_STAGE_KVCACHE(half); 157 | 158 | } // st::kernel -------------------------------------------------------------------------------- /src/csrc/kernel/kvcache_mgmt.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace st::kernel { 4 | 5 | template 6 | void saveContextStageKVCache( 7 | T* k_cache, 8 | T* v_cache, 9 | 10 | const T* qkvs, 11 | const int64_t* block_table, 12 | 13 | const int64_t* input_lens, 14 | const int64_t num_context_reqs, 15 | const int64_t* ith_context_req_req_index, 16 | const int32_t* ith_context_req_token_index, 17 | 18 | const int64_t block_size, 19 | const int64_t max_num_block_per_seq, 20 | const int64_t num_layers, 21 | const int64_t num_q_heads, 22 | const int64_t num_kv_heads, 23 | const int64_t head_dim, 24 | const int64_t layer_id 25 | ); 26 | 27 | } // namespace st::kernel -------------------------------------------------------------------------------- /src/csrc/kernel/layernorm.cu: -------------------------------------------------------------------------------- 1 | #include "layernorm.h" 2 | 3 | #include 4 | #include 5 | 6 | #include "util/cuda_utils.h" 7 | #include 
"kernel/reduction.cuh" 8 | #include "util/debug_utils.h" 9 | 10 | namespace st::kernel { 11 | 12 | constexpr int WARP_SIZE = 32; 13 | constexpr int NUM_THREADS = 256; 14 | constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; 15 | 16 | template 17 | __global__ void layernormKernel( 18 | T* __restrict__ out, // [num_tokens, hidden_size] 19 | const T* __restrict__ input, // [num_tokens, hidden_size] 20 | const T* __restrict__ weight, // [hidden_size] 21 | const T* __restrict__ bias, // [hidden_size] 22 | const float epsilon, 23 | const int64_t num_tokens, 24 | const int64_t hidden_size, 25 | T* __restrict__ biased_input, // [num_tokens, hidden_size] 26 | const T* __restrict__ pre_layernorm_bias // [hidden_size] 27 | ) { 28 | typedef std::conditional_t::value, half2, float2> T2; 29 | 30 | extern __shared__ float shared_mem[]; // Have hidden_size * sizeof(T) bytes 31 | T2* input_buf = (T2*)shared_mem; // [hidden_size/2], a cache for input[] 32 | 33 | // Step 1. Calculate (local) mean and variance 34 | __shared__ float s_mean, s_variance; // We use float here, or the value may exceed the range of half 35 | float mean = 0.0, variance = 0.0; 36 | 37 | #pragma unroll 4 38 | for (int64_t idx = threadIdx.x; idx < hidden_size/2; idx += blockDim.x) { 39 | T2 elem = ((T2*)input)[blockIdx.x * hidden_size/2 + idx]; 40 | if constexpr(HAVE_PRE_LAYERNORM_BIAS) { 41 | const T2 pre_layernorm_bias_elem = ((T2*)pre_layernorm_bias)[blockIdx.x * hidden_size/2 + idx]; 42 | elem.x += pre_layernorm_bias_elem.x; 43 | elem.y += pre_layernorm_bias_elem.y; 44 | ((T2*)biased_input)[blockIdx.x * hidden_size/2 + idx] = elem; 45 | } 46 | input_buf[idx] = elem; 47 | const float x = (float)elem.x; 48 | const float y = (float)elem.y; 49 | mean += x+y; 50 | variance += x*x + y*y; 51 | } 52 | 53 | // Step 2. Reduce mean and variance 54 | // Step 2.1 Reduce within the warp 55 | #pragma unroll 56 | for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) { 57 | mean += __shfl_down_sync(0xffffffff, mean, offset); 58 | variance += __shfl_down_sync(0xffffffff, variance, offset); 59 | } 60 | static __shared__ float reduction_wksp[2][NUM_WARPS]; // 32 = max block size (1024) / WARP_SIZE (32) 61 | if ((threadIdx.x & 31) == 0) { 62 | reduction_wksp[0][threadIdx.x >> 5] = mean; 63 | reduction_wksp[1][threadIdx.x >> 5] = variance; 64 | } 65 | __syncthreads(); 66 | 67 | // Step 2.2 Reduce within the block 68 | if (threadIdx.x < NUM_WARPS) { 69 | mean = reduction_wksp[0][threadIdx.x]; 70 | variance = reduction_wksp[1][threadIdx.x]; 71 | } 72 | #pragma unroll 73 | for (int offset = NUM_WARPS/2; offset > 0; offset /= 2) { 74 | mean += __shfl_down_sync(0xffffffff, mean, offset); 75 | variance += __shfl_down_sync(0xffffffff, variance, offset); 76 | } 77 | 78 | if (threadIdx.x == 0) { 79 | float hidden_size_fp = (float)hidden_size; 80 | s_mean = mean / hidden_size_fp; 81 | s_variance = rsqrtf(variance / hidden_size_fp - s_mean * s_mean + epsilon); 82 | } 83 | __syncthreads(); 84 | 85 | // Step 3. 
Normalize 86 | T final_mean = (T)s_mean; 87 | T final_variance = (T)s_variance; 88 | #pragma unroll 4 89 | for (int64_t idx = threadIdx.x; idx < hidden_size/2; idx += blockDim.x) { 90 | T2 x = input_buf[idx]; 91 | T2 weight_elem = ((T2*)weight)[idx]; 92 | T2 bias_elem = ((T2*)bias)[idx]; 93 | ((T2*)out)[blockIdx.x * hidden_size/2 + idx] = { 94 | ((x.x - final_mean) * final_variance) * weight_elem.x + bias_elem.x, 95 | ((x.y - final_mean) * final_variance) * weight_elem.y + bias_elem.y 96 | }; 97 | } 98 | } 99 | 100 | template 101 | void layernorm( 102 | T* out, 103 | const T* input, 104 | 105 | const T* weight, 106 | const T* bias, 107 | const float epsilon, 108 | 109 | const int64_t num_tokens, 110 | const int64_t hidden_size, 111 | 112 | T* biased_input, // Default: nullptr 113 | const T* pre_layernorm_bias // Default: nullptr 114 | ) { 115 | dim3 grid(num_tokens); 116 | dim3 block(NUM_THREADS); 117 | assert_whenever (hidden_size % NUM_THREADS == 0); 118 | 119 | if (pre_layernorm_bias == nullptr) { 120 | assert_whenever (biased_input == nullptr); 121 | layernormKernel<<>>( 122 | out, 123 | input, 124 | weight, 125 | bias, 126 | epsilon, 127 | num_tokens, 128 | hidden_size, 129 | nullptr, 130 | nullptr); 131 | } else { 132 | assert_whenever (biased_input != nullptr); 133 | layernormKernel<<>>( 134 | out, 135 | input, 136 | weight, 137 | bias, 138 | epsilon, 139 | num_tokens, 140 | hidden_size, 141 | biased_input, 142 | pre_layernorm_bias); 143 | } 144 | } 145 | 146 | template void layernorm( 147 | float* out, const float* input, 148 | const float* weight, const float* bias, const float epsilon, 149 | const int64_t num_tokens, const int64_t hidden_size, float* biased_input, const float* pre_layernorm_bias 150 | ); 151 | template void layernorm( 152 | half* out, const half* input, 153 | const half* weight, const half* bias, const float epsilon, 154 | const int64_t num_tokens, const int64_t hidden_size, half* biased_input, const half* pre_layernorm_bias 155 | ); 156 | 157 | } // namespace st::kernel -------------------------------------------------------------------------------- /src/csrc/kernel/layernorm.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace st::kernel { 4 | 5 | template 6 | void layernorm( 7 | T* out, 8 | const T* input, 9 | 10 | const T* weight, 11 | const T* bias, 12 | const float epsilon, 13 | 14 | const int64_t num_tokens, 15 | const int64_t hidden_size, 16 | 17 | T* biased_input = nullptr, 18 | const T* pre_layernorm_bias = nullptr 19 | ); 20 | 21 | } // namespace st::kernel -------------------------------------------------------------------------------- /src/csrc/kernel/reduction.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh 3 | * Copyright (c) 2023, The swiftGPT team. 4 | * Copyright (c) 2023, The CacheFlow team. 5 | * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. 6 | * 7 | * Licensed under the Apache License, Version 2.0 (the "License"); 8 | * you may not use this file except in compliance with the License. 
9 | * You may obtain a copy of the License at 10 | * 11 | * http://www.apache.org/licenses/LICENSE-2.0 12 | * 13 | * Unless required by applicable law or agreed to in writing, software 14 | * distributed under the License is distributed on an "AS IS" BASIS, 15 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | * See the License for the specific language governing permissions and 17 | * limitations under the License. 18 | */ 19 | #pragma once 20 | 21 | namespace st::kernel { 22 | 23 | static const float HALF_FLT_MAX = 65504.F; 24 | #define FINAL_MASK 0xffffffffu 25 | 26 | template 27 | __inline__ __device__ T warpReduceSum(T val) { 28 | #pragma unroll 29 | for (int mask = 16; mask > 0; mask >>= 1) 30 | val += __shfl_xor_sync(0xffffffffu, val, mask, 32); 31 | return val; 32 | } 33 | 34 | /* Calculate the sum of all elements in a block */ 35 | template 36 | __inline__ __device__ T blockReduceSum(T val) { 37 | static __shared__ T shared[32]; 38 | int64_t lane = threadIdx.x & 0x1fu; 39 | int64_t wid = threadIdx.x >> 5; 40 | 41 | val = warpReduceSum(val); 42 | 43 | if (lane == 0) 44 | shared[wid] = val; 45 | 46 | __syncthreads(); 47 | 48 | // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent 49 | // blockDim.x is not divided by 32 50 | val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f); 51 | val = warpReduceSum(val); 52 | return val; 53 | } 54 | 55 | template 56 | __inline__ __device__ T warpReduceMax(T val) 57 | { 58 | #pragma unroll 59 | for (int mask = 16; mask > 0; mask >>= 1) 60 | val = max(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32)); 61 | return val; 62 | } 63 | 64 | /* Calculate the maximum of all elements in a block */ 65 | template 66 | __inline__ __device__ T blockReduceMax(T val) 67 | { 68 | static __shared__ T shared[32]; 69 | int64_t lane = threadIdx.x & 0x1fu; // in-warp idx 70 | int64_t wid = threadIdx.x >> 5; // warp idx 71 | 72 | val = warpReduceMax(val); // get maxx in each warp 73 | 74 | if (lane == 0) // record in-warp maxx by warp Idx 75 | shared[wid] = val; 76 | 77 | __syncthreads(); 78 | 79 | // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent 80 | // blockDim.x is not divided by 32 81 | val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : -1e20f; 82 | val = warpReduceMax(val); 83 | 84 | return val; 85 | } 86 | 87 | 88 | } // namespace st::kernel -------------------------------------------------------------------------------- /src/csrc/kernel/rmsnorm.cu: -------------------------------------------------------------------------------- 1 | #include "rmsnorm.h" 2 | 3 | #include "util/cuda_utils.h" 4 | #include "kernel/reduction.cuh" 5 | 6 | namespace st::kernel { 7 | 8 | template 9 | __global__ void rmsnormKernel( 10 | T* output, // [num_tokens, hidden_size] 11 | const T* input, // [num_tokens, hidden_size] 12 | const T* weight, // [hidden_size] 13 | const float epsilon, 14 | const int64_t hidden_size 15 | ) { 16 | // Step 1. Every thread computes some part of the sum of squares 17 | float square_sum = 0.0; 18 | __shared__ float inv_rms; 19 | for (int64_t i = threadIdx.x; i < hidden_size; i += blockDim.x) { 20 | const float x = input[blockIdx.x * hidden_size + i]; 21 | square_sum += x * x; 22 | } 23 | // Step 2. Sum the squares across threads 24 | square_sum = blockReduceSum(square_sum); 25 | // Step 3. Compute the inverse root mean square 26 | if (threadIdx.x == 0) { 27 | inv_rms = rsqrtf(square_sum / hidden_size + epsilon); 28 | } 29 | __syncthreads(); 30 | // Step 4. 
Compute the output 31 | for (int64_t i = threadIdx.x; i < hidden_size; i += blockDim.x) { 32 | const float x = input[blockIdx.x * hidden_size + i]; 33 | const float w = weight[i]; 34 | output[blockIdx.x * hidden_size + i] = x * w * inv_rms; 35 | } 36 | } 37 | 38 | template<typename T> 39 | void rmsnorm( 40 | T* out, 41 | const T* input, 42 | 43 | const T* weight, 44 | const float epsilon, 45 | 46 | const int64_t num_tokens, 47 | const int64_t hidden_size 48 | ) { 49 | const int64_t block_size = std::min(hidden_size, 1024L); 50 | const int64_t grid_size = num_tokens; 51 | rmsnormKernel<T><<<grid_size, block_size>>>(out, input, weight, epsilon, hidden_size); 52 | } 53 | 54 | #define INSTANTIATE(T) \ 55 | template void rmsnorm<T>( \ 56 | T* out, \ 57 | const T* input, \ 58 | const T* weight, \ 59 | const float epsilon, \ 60 | const int64_t num_tokens, \ 61 | const int64_t hidden_size \ 62 | ); 63 | 64 | INSTANTIATE(float) 65 | INSTANTIATE(half) 66 | 67 | } // namespace st::kernel -------------------------------------------------------------------------------- /src/csrc/kernel/rmsnorm.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace st::kernel { 6 | 7 | template<typename T> 8 | void rmsnorm( 9 | T* out, 10 | const T* input, 11 | 12 | const T* weight, 13 | const float epsilon, 14 | 15 | const int64_t num_tokens, 16 | const int64_t hidden_size 17 | ); 18 | 19 | } // namespace st::kernel -------------------------------------------------------------------------------- /src/csrc/kernel/rotary_posi_embedding.cu: -------------------------------------------------------------------------------- 1 | #include "rotary_posi_embedding.h" 2 | 3 | namespace st::kernel { 4 | 5 | /* 6 | rotaryPosiEmbeddingBatched 7 | 8 | Perform rotary positional embedding on a batch of tokens. 9 | 10 | ## Background 11 | 12 | Rotary positional embedding (RoPE), as proposed in "ROFORMER : ENHANCED 13 | TRANSFORMER WITH ROTARY POSITION EMBEDDING", is a method of positional 14 | embedding that encodes the absolute position while incorporating the 15 | relative position between tokens. Models like LLaMA and LLaMA2 are based 16 | on RoPE. 17 | 18 | ## Introduction 19 | 20 | This kernel takes a batch of tokens and their absolute positions within the 21 | request, and performs RoPE on them. 22 | 23 | ## Implementation Details 24 | 25 | We launch a grid of shape (num_tokens), i.e. each thread block is 26 | responsible for one token. Each thread block has head_dim/2 threads. The 27 | i-th thread handles the (2i)-th and (2i+1)-th elements of the head_dim 28 | in every head. 29 | 30 | ## Notes 31 | 32 | In practice we perform RoPE on both the query and the key. Note that when 33 | performing RoPE on the key, we need to pass num_local_kv_heads as num_heads, 34 | while performing on the query we need to pass num_local_q_heads as num_heads. 35 | */ 36 | 37 | template<typename T> 38 | __global__ void rotaryPosiEmbeddingBatchedKernel ( 39 | T* __restrict__ target, // [num_tokens, target_1st_dim_size, head_dim].
We will only use [num_tokens, :num_heads, head_dim] 40 | const int64_t* __restrict__ token_indexes, // [num_tokens] 41 | const int64_t num_heads, 42 | const int64_t target_1st_dim_size, 43 | const int64_t head_dim 44 | ) { 45 | const int64_t rel_pos = token_indexes[blockIdx.x]; 46 | float cur_sin_f, cur_cos_f; 47 | __sincosf(rel_pos*__powf(10000.0f, -2.0f*threadIdx.x/head_dim), &cur_sin_f, &cur_cos_f); 48 | const T cur_sin = (T)cur_sin_f, cur_cos = (T)cur_cos_f; 49 | 50 | typedef typename std::conditional::value, float2, half2>::type T2; 51 | for (int64_t head_id = 0; head_id < num_heads; head_id += 1) { 52 | // Read x1 and x2 in pack 53 | const T2 x1_x2 = reinterpret_cast(target)[INDEX_3D( 54 | 0, target_1st_dim_size, head_dim/2, 55 | blockIdx.x, head_id, threadIdx.x 56 | )]; 57 | const T x1 = x1_x2.x, x2 = x1_x2.y; 58 | const T new_x1 = x1*cur_cos - x2*cur_sin; 59 | const T new_x2 = x1*cur_sin + x2*cur_cos; 60 | // Write back 61 | reinterpret_cast(target)[INDEX_3D( 62 | 0, target_1st_dim_size, head_dim/2, 63 | blockIdx.x, head_id, threadIdx.x 64 | )] = T2{new_x1, new_x2}; 65 | } 66 | } 67 | 68 | template 69 | void rotaryPosiEmbeddingBatched( 70 | T* __restrict__ target, 71 | const int64_t* __restrict__ token_indices, 72 | const int64_t num_tokens, 73 | const int64_t target_1st_dim_size, 74 | const int64_t num_heads, 75 | const int64_t head_dim 76 | ) { 77 | rotaryPosiEmbeddingBatchedKernel<<>>( 78 | target, token_indices, num_heads, target_1st_dim_size, head_dim 79 | ); 80 | } 81 | 82 | #define INTANTIATE(T) \ 83 | template void rotaryPosiEmbeddingBatched( \ 84 | T* __restrict__, \ 85 | const int64_t* __restrict__, \ 86 | const int64_t, \ 87 | const int64_t, \ 88 | const int64_t, \ 89 | const int64_t \ 90 | ); 91 | 92 | INTANTIATE(half) 93 | INTANTIATE(float) 94 | 95 | } // namespace st::kernel 96 | -------------------------------------------------------------------------------- /src/csrc/kernel/rotary_posi_embedding.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "util/cuda_utils.h" 4 | 5 | namespace st::kernel { 6 | 7 | template 8 | void rotaryPosiEmbeddingBatched( 9 | T* __restrict__ target, 10 | const int64_t* __restrict__ token_indices, 11 | const int64_t num_tokens, 12 | const int64_t target_1st_dim_size, 13 | const int64_t num_heads, 14 | const int64_t head_dim 15 | ); 16 | 17 | } // namespace st::kernel 18 | -------------------------------------------------------------------------------- /src/csrc/kernel/softmax.cu: -------------------------------------------------------------------------------- 1 | #include "softmax.h" 2 | 3 | #include "reduction.cuh" 4 | #include "util/cuda_utils.h" 5 | 6 | namespace st::kernel { 7 | 8 | /* 9 | scaleMaskSoftmaxKernel & 10 | scaleMaskSoftmax 11 | 12 | This kernel applies scaling (*1/sqrt(dk)), masking, and softmax to the input matrix (attention matrix). 
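	Masking is causal: the element at (row, col) keeps its value when col <= row and is pushed
	towards -infinity (by adding -10000 before the softmax) when col > row, so each token only
	attends to itself and to earlier tokens.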
13 | 14 | Input: 15 | - input: [num_heads, input_len, input_len] 16 | - scale: 1/sqrt(dk) 17 | Output: 18 | - output: [num_heads, input_len, input_len] 19 | output[head][row] = softmax(masking(input[head][row] * scale)) 20 | */ 21 | template 22 | __global__ void scaleMaskSoftmaxKernel( 23 | T* output, 24 | const T* input, 25 | const float scale, 26 | const int64_t num_heads, 27 | const int64_t input_len 28 | ) { 29 | const int64_t h = blockIdx.x; 30 | for (int64_t r = 0; r < input_len; ++r) { 31 | float local_max = -1e20f, local_sum = 0.0; 32 | __shared__ float s_max, s_sum; 33 | for (int64_t c = threadIdx.x; c < input_len; c += blockDim.x) { 34 | float val = input[INDEX_3D(num_heads, input_len, input_len, h, r, c)]; 35 | val *= scale; 36 | val += r >= c ? 0 : -10000.0; 37 | output[INDEX_3D(num_heads, input_len, input_len, h, r, c)] = val; 38 | local_max = local_max > val ? local_max : val; 39 | } 40 | float max_val = blockDim.x <= 32 ? warpReduceMax(local_max) : blockReduceMax(local_max); 41 | if (threadIdx.x == 0) { 42 | s_max = max_val; 43 | } 44 | __syncthreads(); 45 | for (int64_t c = threadIdx.x; c < input_len; c += blockDim.x) { 46 | float val = output[INDEX_3D(num_heads, input_len, input_len, h, r, c)]; 47 | val = __expf(val - s_max); 48 | output[INDEX_3D(num_heads, input_len, input_len, h, r, c)] = val; 49 | local_sum += val; 50 | } 51 | float sum = blockDim.x <= 32 ? warpReduceSum(local_sum) : blockReduceSum(local_sum); 52 | if (threadIdx.x == 0) { 53 | s_sum = sum; 54 | } 55 | __syncthreads(); 56 | float to_mult = __fdividef((float)1.0, s_sum+(float)(1e-6)); 57 | for (int64_t c = threadIdx.x; c < input_len; c += blockDim.x) { 58 | float val = output[INDEX_3D(num_heads, input_len, input_len, h, r, c)]; 59 | val *= to_mult; 60 | output[INDEX_3D(num_heads, input_len, input_len, h, r, c)] = val; 61 | } 62 | } 63 | } 64 | 65 | template 66 | void scaleMaskSoftmax( 67 | T* output, 68 | const T* input, 69 | const float scale, 70 | const int64_t num_heads, 71 | const int64_t input_len 72 | ) { 73 | uint32_t block_dim = std::min(input_len, 256l); 74 | scaleMaskSoftmaxKernel<<>>( 75 | output, 76 | input, 77 | scale, 78 | num_heads, 79 | input_len 80 | ); 81 | } 82 | 83 | template void scaleMaskSoftmax( 84 | float* output, const float* input, 85 | const float scale, 86 | const int64_t num_heads, const int64_t input_len 87 | ); 88 | template void scaleMaskSoftmax( 89 | half* output, const half* input, 90 | const float scale, 91 | const int64_t num_heads, const int64_t input_len 92 | ); 93 | 94 | 95 | 96 | /* 97 | scaleSoftmaxKernel & scaleSoftmax 98 | 99 | This performs scale & softmax on a batch of 1-D array 100 | This function is used in the regression stage 101 | 102 | Input: 103 | - input: the input array, typically it is the last row of the attention matrix, [num_heads, seq_len] 104 | - scale: the scale factor, typically it is 1/sqrt(head_dim) 105 | Output: 106 | - output: the output array, [num_heads, seq_len] 107 | output[head] = softmax(input[head] * scale) 108 | */ 109 | template 110 | __global__ void scaleSoftmaxKernel( 111 | T* output, 112 | const T* input, 113 | const float scale, 114 | const int64_t seq_len 115 | ) { 116 | __shared__ float s_max, s_sum; 117 | const int64_t index_start = seq_len*blockIdx.x + threadIdx.x; 118 | const int64_t index_end = seq_len*blockIdx.x + seq_len; 119 | 120 | float local_max = -1e20f; 121 | for (int64_t index = index_start; index < index_end; index += blockDim.x) { 122 | float val = input[index]; 123 | val *= scale; 124 | local_max = local_max > 
val ? local_max : val; 125 | } 126 | 127 | float max_val = blockDim.x <= 32 ? warpReduceMax(local_max) : blockReduceMax(local_max); 128 | if (threadIdx.x == 0) { 129 | s_max = max_val; 130 | } 131 | __syncthreads(); 132 | 133 | float local_sum = 0; 134 | for (int64_t index = index_start; index < index_end; index += blockDim.x) { 135 | float val = input[index]; 136 | val *= scale; 137 | val = __expf(val - s_max); 138 | local_sum += val; 139 | output[index] = val; 140 | } 141 | 142 | float sum = blockDim.x <= 32 ? warpReduceSum(local_sum) : blockReduceSum(local_sum); 143 | if (threadIdx.x == 0) { 144 | s_sum = sum; 145 | } 146 | __syncthreads(); 147 | 148 | float to_mult = __fdividef((float)1.0, s_sum+1e-6f); 149 | for (int64_t index = index_start; index < index_end; index += blockDim.x) { 150 | float val = output[index]; 151 | val *= to_mult; 152 | output[index] = (T)val; 153 | } 154 | } 155 | 156 | template 157 | void scaleSoftmax( 158 | T* output, 159 | const T* input, 160 | const float scale, 161 | const int64_t num_heads, 162 | const int64_t seq_len 163 | ) { 164 | uint32_t block_dim = std::min(seq_len, 256l); 165 | scaleSoftmaxKernel<<>>(output, input, scale, seq_len); 166 | } 167 | 168 | template void scaleSoftmax( 169 | float* output, const float* input, 170 | const float scale, 171 | const int64_t num_heads, const int64_t seq_len 172 | ); 173 | template void scaleSoftmax( 174 | half* output, const half* input, 175 | const float scale, 176 | const int64_t num_heads, const int64_t seq_len 177 | ); 178 | 179 | } // namespace st::kernel -------------------------------------------------------------------------------- /src/csrc/kernel/softmax.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace st::kernel { 4 | 5 | template 6 | void scaleMaskSoftmax( 7 | T* output, 8 | const T* input, 9 | const float scale, 10 | const int64_t num_heads, 11 | const int64_t input_len 12 | ); 13 | 14 | template 15 | void scaleSoftmax( 16 | T* output, 17 | const T* input, 18 | const float scale, 19 | const int64_t num_heads, 20 | const int64_t seq_len 21 | ); 22 | 23 | } // namespace st::kernel -------------------------------------------------------------------------------- /src/csrc/kernel/unfused_attention.cu: -------------------------------------------------------------------------------- 1 | #include "unfused_attention.h" 2 | 3 | #include "util/cuda_utils.h" 4 | 5 | namespace st::kernel { 6 | 7 | /* 8 | transposeQKVKernel & 9 | transposeQKV 10 | 11 | This kernel retrieves q, k, and v in the particular input from QKV buffer 12 | (generated by previous step, qkv_gemm), and saves them 13 | to q_buf, k_buf, and v_buf respectively. 14 | 15 | The input buffer, QKV, has a shape of [num_tokens, 3, num_heads, head_dim] 16 | The output buffer, q_buf, k_buf, and v_buf, has a shape of [num_heads, cur_input_len, head_dim] 17 | 18 | The selected input (request) is indicated by cur_input_start. 
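	For example, with cur_input_start = 3, the element QKV[3 + t][0][h][d] is copied to q_buf[h][t][d],
	QKV[3 + t][1][h][d] to k_buf[h][t][d], and QKV[3 + t][2][h][d] to v_buf[h][t][d], for every
	token t in [0, input_len) of the selected request.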
19 | */ 20 | template 21 | __global__ void transposeQKVKernel( 22 | T* q_buf, 23 | T* k_buf, 24 | T* v_buf, 25 | const T* QKV, 26 | const int64_t cur_input_start, 27 | const int64_t input_len, 28 | const int64_t num_heads, 29 | const int64_t head_dim 30 | ) { 31 | const int64_t size_index = threadIdx.x; 32 | const int64_t head_index = blockIdx.x; 33 | for (int64_t token_index = 0; token_index < input_len; ++token_index) { 34 | int64_t qkv_q_index = INDEX_4D(input_len, 3, num_heads, head_dim, token_index+cur_input_start, 0, head_index, size_index); 35 | int64_t qkv_k_index = qkv_q_index + num_heads*head_dim; 36 | int64_t qkv_v_index = qkv_k_index + num_heads*head_dim; 37 | int64_t q_buf_index = INDEX_3D(num_heads, input_len, head_dim, head_index, token_index, size_index); 38 | int64_t k_buf_index = q_buf_index; 39 | int64_t v_buf_index = q_buf_index; 40 | q_buf[q_buf_index] = QKV[qkv_q_index]; 41 | k_buf[k_buf_index] = QKV[qkv_k_index]; 42 | v_buf[v_buf_index] = QKV[qkv_v_index]; 43 | } 44 | } 45 | 46 | template 47 | void transposeQKV( 48 | T* q_buf, 49 | T* k_buf, 50 | T* v_buf, 51 | const T* QKV, 52 | const int64_t cur_input_start, 53 | const int64_t input_len, 54 | const int64_t num_heads, 55 | const int64_t head_dim 56 | ) { 57 | transposeQKVKernel<<>>( 58 | q_buf, 59 | k_buf, 60 | v_buf, 61 | QKV, 62 | cur_input_start, 63 | input_len, 64 | num_heads, 65 | head_dim 66 | ); 67 | } 68 | 69 | template void transposeQKV( 70 | float* q_buf, float* k_buf, float* v_buf, 71 | const float* QKV, 72 | const int64_t cur_input_start, const int64_t input_len, const int64_t num_heads, const int64_t head_dim 73 | ); 74 | 75 | template void transposeQKV( 76 | half* q_buf, half* k_buf, half* v_buf, 77 | const half* QKV, 78 | const int64_t cur_input_start, const int64_t input_len, const int64_t num_heads, const int64_t head_dim 79 | ); 80 | 81 | 82 | /* 83 | mergeOutputKernel & mergeOutput 84 | 85 | This kernel is used in ContextDecoder to merge the output of the attention matrix into a single matrix. 
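	In effect it transposes the per-request result from the head-major layout
	[num_heads, cur_input_num_tokens, head_dim] back to the token-major layout
	[cur_input_num_tokens, num_heads, head_dim], writing it at row offset cur_input_start of the
	batched output buffer.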
86 | 87 | Input: 88 | - cur_input: [num_heads, cur_input_num_tokens, head_dim] 89 | Output: 90 | - output: [~, num_heads, head_dim,] 91 | output[cur_input_start + i][head_index][size_index] = input[head_index][i][size_index] 92 | */ 93 | template 94 | __global__ void mergeOutputKernel( 95 | T* output, 96 | const T* cur_input, 97 | const int64_t cur_input_num_tokens, 98 | const int64_t cur_input_start, 99 | const int64_t num_heads, 100 | const int64_t head_dim 101 | ) { 102 | const int64_t size_index = threadIdx.x; 103 | const int64_t head_index = blockIdx.x; 104 | for (int64_t token_index = 0; token_index < cur_input_num_tokens; ++token_index) { 105 | int64_t output_index = INDEX_3D(0, num_heads, head_dim, cur_input_start+token_index, head_index, size_index); 106 | int64_t input_index = INDEX_3D(num_heads, cur_input_num_tokens, head_dim, head_index, token_index, size_index); 107 | output[output_index] = cur_input[input_index]; 108 | } 109 | } 110 | 111 | template 112 | void mergeOutput( 113 | T* output, 114 | const T* cur_input, 115 | const int64_t cur_input_num_tokens, 116 | const int64_t cur_input_start, 117 | const int64_t num_heads, 118 | const int64_t head_dim 119 | ) { 120 | mergeOutputKernel<<>>(output, cur_input, cur_input_num_tokens, cur_input_start, num_heads, head_dim); 121 | } 122 | 123 | template void mergeOutput( 124 | float* output, const float* cur_input, 125 | const int64_t cur_input_num_tokens, const int64_t cur_input_start, const int64_t num_heads, const int64_t head_dim 126 | ); 127 | template void mergeOutput( 128 | half* output, const half* cur_input, 129 | const int64_t cur_input_num_tokens, const int64_t cur_input_start, const int64_t num_heads, const int64_t head_dim 130 | ); 131 | 132 | 133 | } // namespace st::kernel 134 | -------------------------------------------------------------------------------- /src/csrc/kernel/unfused_attention.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace st::kernel { 4 | 5 | template 6 | void transposeQKV( 7 | T* q_buf, 8 | T* k_buf, 9 | T* v_buf, 10 | const T* QKV, 11 | const int64_t cur_input_start, 12 | const int64_t input_len, 13 | const int64_t num_heads, 14 | const int64_t head_dim 15 | ); 16 | 17 | 18 | template 19 | void mergeOutput( 20 | T* output, 21 | const T* cur_input, 22 | const int64_t cur_input_len, 23 | const int64_t cur_input_start, 24 | const int64_t num_heads, 25 | const int64_t head_dim 26 | ); 27 | 28 | } // namespace st::kernel 29 | -------------------------------------------------------------------------------- /src/csrc/kernel/xformers_attention.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace st::kernel { 6 | 7 | template 8 | void xformersContextStageAttention( 9 | T* __restrict__ result, 10 | const T* __restrict__ qkvs, 11 | const float qk_scale, 12 | const int64_t* __restrict__ input_lens, 13 | const int64_t num_context_reqs, 14 | const int64_t* __restrict__ ith_context_req_req_index, 15 | const int32_t* __restrict__ ith_context_req_token_index, 16 | const int64_t num_q_heads, 17 | const int64_t num_kv_heads, 18 | const int64_t head_dim, 19 | const int64_t num_tokens, 20 | const int64_t max_context_req_len 21 | ); 22 | 23 | } -------------------------------------------------------------------------------- /src/csrc/layer/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_library(layer STATIC 2 | 
attention.cc 3 | ffn.cc 4 | gated_ffn.cc 5 | ) 6 | set_property(TARGET layer PROPERTY POSITION_INDEPENDENT_CODE ON) 7 | target_link_libraries(layer PUBLIC kernel util xformers_kernel) 8 | 9 | if (MPI_FOUND AND NCCL_FOUND) 10 | target_link_libraries(layer PUBLIC nccl_utils) 11 | endif() 12 | -------------------------------------------------------------------------------- /src/csrc/layer/attention.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "util/cublas_wrapper.h" 4 | #include "util/nccl_utils.h" 5 | 6 | namespace st::layer { 7 | 8 | template 9 | void attention( 10 | T* output, 11 | T* k_cache, 12 | T* v_cache, 13 | 14 | const T* input, 15 | const int64_t* input_len, 16 | const bool* is_context_stage_cpu, 17 | const int64_t* block_table, 18 | const int64_t* d_token_indexes, 19 | 20 | int64_t num_context_reqs, 21 | int64_t num_decoding_reqs, 22 | const int64_t* ith_context_req_req_index, 23 | const int32_t* ith_context_req_token_index, 24 | const int64_t* ith_decoding_req_req_index, 25 | const int64_t* ith_decoding_req_token_index, 26 | const int64_t max_context_req_len, 27 | const int64_t max_decoding_req_len, 28 | 29 | const T* qkv_weight_kernel, 30 | const T* qkv_weight_bias, 31 | const T* out_weight_kernel, 32 | const T* out_weight_bias, 33 | 34 | const int64_t batch_size, 35 | const int64_t num_tokens, 36 | const int64_t hidden_size, 37 | const int64_t num_layers, 38 | const int64_t num_q_heads, 39 | const int64_t num_kv_heads, 40 | const int64_t head_dim, 41 | const bool perform_rotary_embedding, 42 | const int64_t layer_id, 43 | const int64_t max_num_block_per_req, 44 | const int64_t block_size, 45 | 46 | T* qkv_buf, 47 | T* attn_out_buf, 48 | float* context_stage_kernel_m_buf, 49 | float* context_stage_kernel_l_buf, 50 | 51 | util::CublasWrapper cublas_wrapper, 52 | util::NcclComm nccl_comm 53 | ); 54 | 55 | } // namespace st::layer 56 | -------------------------------------------------------------------------------- /src/csrc/layer/ffn.cc: -------------------------------------------------------------------------------- 1 | #include "ffn.h" 2 | 3 | #include "util/cuda_utils.h" 4 | #include "kernel/activation_types.h" 5 | 6 | namespace st::layer { 7 | 8 | // ffn - Feed forward network with tensor parallelism 9 | // 10 | // Parallel parameters are passed by NcclComm struct 11 | // - size: size of tensor parallel 12 | // - rank: rank of current process 13 | // - comm: nccl communicator, initialized 14 | // 15 | // Weight addr should be prepared by the caller 16 | // 17 | // Architecture: 18 | // 19 | // input [batch_size, input_dim] 20 | // | 21 | // | Linear 1 22 | // | Weights: 23 | // | - fc1_weight [inter_dim / tensor_para_size, input_dim] 24 | // | - fc1_bias [inter_dim / tensor_para_size] 25 | // | inter = input * fc1_weight^T + fc1_bias 26 | // | 27 | // V 28 | // inter [batch_size, inter_dim / tensor_para_size] 29 | // | 30 | // | Activation 31 | // | inter = max(inter, 0) 32 | // | 33 | // V 34 | // inter [batch_size, inter_dim / tensor_para_size] 35 | // | 36 | // | Linear 2 37 | // | Weights: 38 | // | - fc2_weight [output_dim, inter_dim / tensor_para_size] 39 | // | ToReduce = inter * fc2_weight 40 | // | 41 | // V 42 | // ToReduce [batch_size, output_dim] 43 | // | 44 | // V AllReduce 45 | // | Weights: 46 | // | - fc2_bias [output_dim] 47 | // | output = AllReduce(ToReduce) + fc2_bias 48 | 49 | template 50 | void ffn( 51 | T* output, // [batch_size, output_dim] 52 | T* input, // [batch_size, 
input_dim] 53 | 54 | T* fc1_weight, // [inter_dim / tensor_para_size, input_dim] 55 | T* fc1_bias, // [inter_dim / tensor_para_size] 56 | T* fc2_weight, // [output_dim, inter_dim / tensor_para_size] 57 | T* fc2_bias, // [output_dim] 58 | 59 | int64_t batch_size, 60 | int64_t input_dim, 61 | int64_t inter_dim, 62 | int64_t output_dim, 63 | ActivationType activation_type, 64 | 65 | T* inter_buf, // [batch_size, inter_dim / tensor_para_size] 66 | 67 | util::CublasWrapper cublas_wrapper, 68 | util::NcclComm nccl_comm 69 | ){ 70 | // Linear1 71 | cublas_wrapper.gemm( 72 | CUBLAS_OP_N, 73 | CUBLAS_OP_T, 74 | batch_size, 75 | inter_dim / nccl_comm.size, 76 | input_dim, 77 | input, 78 | fc1_weight, 79 | inter_buf 80 | ); 81 | sync_check_cuda_error(); 82 | 83 | // Addbias & Relu 84 | // Use fused kernel to improve performance 85 | kernel::fusedAddbiasBatchedActivation( 86 | inter_buf, 87 | inter_buf, 88 | fc1_bias, 89 | batch_size, 90 | inter_dim / nccl_comm.size, 91 | activation_type 92 | ); 93 | sync_check_cuda_error(); 94 | 95 | // Linear2 96 | cublas_wrapper.gemm( 97 | CUBLAS_OP_N, 98 | CUBLAS_OP_T, 99 | batch_size, 100 | output_dim, 101 | inter_dim / nccl_comm.size, 102 | inter_buf, 103 | fc2_weight, 104 | output 105 | ); 106 | 107 | sync_check_cuda_error(); 108 | 109 | if (nccl_comm.size != 1) { 110 | st::util::stNcclAllReduce( 111 | output, 112 | output, 113 | batch_size * output_dim, 114 | util::stNcclGetDataType(), 115 | ncclSum, 116 | nccl_comm.comm, 117 | nccl_comm.stream 118 | ); 119 | } 120 | 121 | sync_check_cuda_error(); 122 | 123 | // Addbias 124 | kernel::addbiasBatched(output, output, fc2_bias, batch_size, output_dim); 125 | sync_check_cuda_error(); 126 | } 127 | 128 | template void ffn( 129 | float* output, float* input, 130 | float* fc1_weight, float* fc1_bias, 131 | float* fc2_weight, float* fc2_bias, 132 | int64_t batch_size, int64_t input_dim, int64_t inter_dim, int64_t output_dim, 133 | ActivationType activation_type, 134 | float* inter_buf, util::CublasWrapper cublas_wrapper, 135 | util::NcclComm nccl_comm 136 | ); 137 | 138 | template void ffn( 139 | half* output, half* input, 140 | half* fc1_weight, half* fc1_bias, 141 | half* fc2_weight, half* fc2_bias, 142 | int64_t batch_size, int64_t input_dim, int64_t inter_dim, int64_t output_dim, 143 | ActivationType activation_type, 144 | half* inter_buf, util::CublasWrapper cublas_wrapper, 145 | util::NcclComm nccl_comm 146 | ); 147 | 148 | } // namespace st::layer -------------------------------------------------------------------------------- /src/csrc/layer/ffn.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "kernel/addbias.h" 4 | #include "kernel/fused_addbias_activ.h" 5 | #include "util/cublas_wrapper.h" 6 | 7 | #include "util/nccl_utils.h" 8 | 9 | namespace st::layer { 10 | 11 | template 12 | void ffn( 13 | T* output, 14 | T* input, 15 | 16 | T* fc1_weight, 17 | T* fc1_bias, 18 | T* fc2_weight, 19 | T* fc2_bias, 20 | 21 | int64_t batch_size, 22 | int64_t input_dim, 23 | int64_t inter_dim, 24 | int64_t output_dim, 25 | ActivationType activation_type, 26 | 27 | T* inter_buf, 28 | 29 | util::CublasWrapper cublas_wrapper, 30 | util::NcclComm nccl_comm 31 | ); 32 | 33 | } // namespace st::layer -------------------------------------------------------------------------------- /src/csrc/layer/gated_ffn.cc: -------------------------------------------------------------------------------- 1 | #include "ffn.h" 2 | 3 | #include 4 | #include 5 | 6 | #include 
"util/cuda_utils.h" 7 | #include "kernel/fused_activ_multiply.h" 8 | #include "kernel/activation_types.h" 9 | namespace st::layer { 10 | 11 | // gatedFfn - FFN_GeXXX (e.g. FFN_GeGLU) with tensor parallelism 12 | // 13 | // Parallel parameters are passed by NcclComm struct 14 | // - size: size of tensor parallel 15 | // - rank: rank of current process 16 | // - comm: nccl communicator, initialized 17 | // 18 | // Weight addr should be prepared by the caller 19 | // 20 | // This layer takes input of shape [num_tokens, input_dim] and weights including: 21 | // - w1: [inter_dim / tensor_para_size, input_dim] 22 | // - w2: [output_dim, inter_dim / tensor_para_size] 23 | // - w3: [inter_dim / tensor_para_size, input_dim] 24 | // (w1, w2, and w3 corresponds to https://github.com/facebookresearch/llama/blob/main/llama/model.py#L212C4-L212C4) 25 | // 26 | // The output is of shape [num_tokens, output_dim] 27 | // output = (activation(input•w1^T) * (input•w3^T)) • w2^T, where • is matrix multiplication and * is element-wise multiplication 28 | 29 | template 30 | void gatedFfn( 31 | T* output, // [num_tokens, output_dim] 32 | T* input, // [num_tokens, output_dim] 33 | 34 | T* w1_weight, // [inter_dim / tensor_para_size, input_dim] 35 | T* w2_weight, // [output_dim, inter_dim / tensor_para_size] 36 | T* w3_weight, // [inter_dim / tensor_para_size, input_dim] 37 | 38 | int64_t num_tokens, 39 | int64_t input_dim, 40 | int64_t inter_dim, 41 | int64_t output_dim, 42 | ActivationType activation_type, 43 | 44 | T* inter_buf1, // [num_tokens, inter_dim / tensor_para_size] 45 | T* inter_buf2, // [num_tokens, inter_dim / tensor_para_size] 46 | 47 | util::CublasWrapper cublas_wrapper, 48 | util::NcclComm nccl_comm 49 | ) { 50 | assert (inter_dim % nccl_comm.size == 0); 51 | 52 | // Calculate input • w1_T 53 | cublas_wrapper.gemm( 54 | CUBLAS_OP_N, 55 | CUBLAS_OP_T, 56 | num_tokens, 57 | inter_dim / nccl_comm.size, 58 | input_dim, 59 | input, 60 | w1_weight, 61 | inter_buf1 62 | ); 63 | sync_check_cuda_error(); 64 | 65 | // Calculate input • w3_T 66 | cublas_wrapper.gemm( 67 | CUBLAS_OP_N, 68 | CUBLAS_OP_T, 69 | num_tokens, 70 | inter_dim / nccl_comm.size, 71 | input_dim, 72 | input, 73 | w3_weight, 74 | inter_buf2 75 | ); 76 | sync_check_cuda_error(); 77 | 78 | // Calculate silu(input • w1_T) * (input • w3_T) 79 | st::kernel::fusedActivationMultiply( 80 | inter_buf1, 81 | inter_buf1, 82 | inter_buf2, 83 | num_tokens * (inter_dim / nccl_comm.size), 84 | activation_type 85 | ); 86 | sync_check_cuda_error(); 87 | 88 | // Calculate (silu(input • w1_T) * (input • w3_T)) • w2_T 89 | cublas_wrapper.gemm( 90 | CUBLAS_OP_N, 91 | CUBLAS_OP_T, 92 | num_tokens, 93 | output_dim, 94 | inter_dim / nccl_comm.size, 95 | inter_buf1, 96 | w2_weight, 97 | output 98 | ); 99 | sync_check_cuda_error(); 100 | 101 | if (nccl_comm.size != 1) { 102 | st::util::stNcclAllReduce( 103 | output, 104 | output, 105 | num_tokens * output_dim, 106 | util::stNcclGetDataType(), 107 | ncclSum, 108 | nccl_comm.comm, 109 | nccl_comm.stream 110 | ); 111 | sync_check_cuda_error(); 112 | } 113 | } 114 | 115 | #define INSTANTIAL_GATED_FFN(T) \ 116 | template void gatedFfn( \ 117 | T* output, \ 118 | T* input, \ 119 | T* fc1_weight, \ 120 | T* fc2_weight, \ 121 | T* fc3_weight, \ 122 | int64_t num_tokens, \ 123 | int64_t input_dim, \ 124 | int64_t inter_dim, \ 125 | int64_t output_dim, \ 126 | ActivationType activation_type, \ 127 | T* inter_buf1, \ 128 | T* inter_buf2, \ 129 | util::CublasWrapper cublas_wrapper, \ 130 | util::NcclComm nccl_comm \ 131 | ); 
132 | 133 | INSTANTIAL_GATED_FFN(half) 134 | INSTANTIAL_GATED_FFN(float) 135 | 136 | } // namespace st::layer -------------------------------------------------------------------------------- /src/csrc/layer/gated_ffn.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "kernel/addbias.h" 4 | #include "kernel/activation_types.h" 5 | #include "util/cublas_wrapper.h" 6 | 7 | #include "util/nccl_utils.h" 8 | namespace st::layer { 9 | 10 | template 11 | void gatedFfn( 12 | T* output, 13 | T* input, 14 | 15 | T* fc1_weight, 16 | T* fc2_weight, 17 | T* fc3_weight, 18 | 19 | int64_t num_tokens, 20 | int64_t input_dim, 21 | int64_t inter_dim, 22 | int64_t output_dim, 23 | ActivationType activation_type, 24 | 25 | T* inter_buf1, 26 | T* inter_buf2, 27 | 28 | util::CublasWrapper cublas_wrapper, 29 | util::NcclComm nccl_comm 30 | ); 31 | 32 | } // namespace st::layer -------------------------------------------------------------------------------- /src/csrc/model/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(gpt) -------------------------------------------------------------------------------- /src/csrc/model/gpt/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_library(model_gpt gpt_weight.cc gpt.cc gptop_base.cc) 2 | target_link_libraries(model_gpt layer kernel util) 3 | set_property(TARGET model_gpt PROPERTY POSITION_INDEPENDENT_CODE ON) 4 | 5 | add_subdirectory(opt) 6 | add_subdirectory(llama2) 7 | add_subdirectory(gpt2) 8 | -------------------------------------------------------------------------------- /src/csrc/model/gpt/gpt.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include "util/cublas_wrapper.h" 6 | #include "util/cuda_utils.h" 7 | #include "util/nccl_utils.h" 8 | 9 | #include "gpt_weight.h" 10 | #include "gpt_base.h" 11 | 12 | #include 13 | 14 | namespace st::model { 15 | 16 | // Please refer to gpt_base.h for the design of GptBase, Gpt, GptOpBase, and XXXop. 17 | template 18 | class Gpt : public GptBase { 19 | private: 20 | GptWeight weight; 21 | util::CublasWrapper cublas_wrapper; 22 | util::NcclComm tensor_para_comm, pipeline_para_comm; 23 | 24 | // Buffers for inputs & input metadata 25 | RemallocableArray d_decoder_input; // [num_tokens, hidden_size] 26 | RemallocableArray d_decoder_output; // [num_tokens, hidden_size] 27 | RemallocableArray d_input_lens; // [batch_size] 28 | RemallocableArray d_sum_prev_input_lens; // [batch_size] 29 | 30 | // Buffers for input embedding 31 | RemallocableArray d_token_ids; // [num_tokens] 32 | RemallocableArray d_position_ids; // [num_tokens] 33 | 34 | // Buffers for input indexing 35 | RemallocableArray ith_context_req_req_index; // [batch_size] 36 | RemallocableArray ith_context_req_token_index; // [batch_size] 37 | RemallocableArray ith_decoding_req_req_index; // [batch_size] 38 | RemallocableArray ith_decoding_req_token_index;// [batch_size] 39 | 40 | // Buffers for each layer's internal computation 41 | RemallocableArray qkv_buf; // [num_tokens+15, local_q_head_num + 2*local_kv_head_num, head_dim]. 
Please refer to fused_context_stage_attention.cu for the reason of +15 here 42 | RemallocableArray attn_out_buf; // [num_tokens, local_q_head_num, head_dim] 43 | RemallocableArray ffn_inter_buf_1; // [num_tokens, local_ffn_inter_dim] 44 | RemallocableArray ffn_inter_buf_2; // [num_tokens, local_ffn_inter_dim], only used when is_gated_ffn = true 45 | RemallocableArray context_stage_kernel_m_buf; // [local_q_head_num, num_tokens] 46 | RemallocableArray context_stage_kernel_l_buf; // [local_q_head_num, num_tokens] 47 | 48 | // Buffers for forwardDecoder 49 | RemallocableArray attention_out; // [num_tokens, hidden_size] 50 | 51 | // Buffers for output projection 52 | RemallocableArray output_projection_last_tokens_buf; // [batch_size, hidden_dim] 53 | RemallocableArray output_projection_buf; // [batch_size, vocab_size] 54 | RemallocableArray output_projection_result_buf; // [batch_size] 55 | 56 | public: 57 | Gpt(const GptHyperParam& hyper_param, 58 | const GptPagedAttnParam& pagedattn_param, 59 | GptParallelismParam& parallelism_param = GptParallelismParam() 60 | ); 61 | ~Gpt() override; 62 | 63 | void setPagedattnParam(const GptPagedAttnParam& pagedattn_param); 64 | void setParallelismParam(const GptParallelismParam& parallelism_param); 65 | 66 | // Init communicator for NCCL. 67 | // Args: 68 | // tp_id: NCCL unique ID for tensor parallelism. 69 | // pp_id: NCCL unique ID for pipeline parallelism. 70 | void init_communicator(const ncclUniqueId& tp_id, const ncclUniqueId& pp_id); 71 | 72 | void getInputPosiIds( 73 | const std::vector> &input_tokens_batched, 74 | const std::vector &first_token_indexes, 75 | const int64_t num_tokens 76 | ); 77 | 78 | void inputBatchEmbedAndPosiEncode( 79 | T* d_output, 80 | const std::vector> &input_tokens_batched, 81 | const int64_t num_tokens 82 | ); 83 | 84 | void selectOutputTokenBatched( 85 | int64_t* h_result_token, 86 | const T* d_input, 87 | int64_t num_tokens, 88 | const int64_t* first_token_indexes, 89 | int64_t batch_size 90 | ); 91 | 92 | void forwardDecoder( 93 | T* d_output, 94 | const T* d_input, 95 | T* d_k_cache, 96 | T* d_v_cache, 97 | int64_t* d_block_table, 98 | const int64_t* d_input_len, 99 | 100 | const int64_t* h_input_len, 101 | const bool* h_is_context_stage, 102 | const int64_t batch_size 103 | ); 104 | 105 | void loadWeight(const std::string& model_path) override; 106 | void initDummyWeight() override; 107 | 108 | std::vector forward( 109 | const std::vector> &input_tokens_batched, 110 | const std::vector &first_token_indexes, 111 | void* d_k_cache, 112 | void* d_v_cache, 113 | int64_t* d_block_table 114 | ) override; 115 | }; 116 | 117 | } // namespace st::model -------------------------------------------------------------------------------- /src/csrc/model/gpt/gpt2/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_library(model_gpt2 gpt2op.cc) 2 | target_link_libraries(model_gpt2 layer kernel util model_gpt "${TORCH_LIBRARIES}") 3 | set_property(TARGET model_gpt2 PROPERTY POSITION_INDEPENDENT_CODE ON) 4 | -------------------------------------------------------------------------------- /src/csrc/model/gpt/gpt2/gpt2op.cc: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "gpt2op.h" 4 | 5 | #include "util/torch_utils.h" 6 | 7 | namespace st::model { 8 | 9 | Gpt2Op::Gpt2Op(const int64_t vocab_size, 10 | const int64_t max_position_embeddings, 11 | const int64_t hidden_size, 12 | const int64_t num_layers, 13 | const 
int64_t num_heads, 14 | const int64_t head_dim, 15 | std::string inference_dtype, 16 | const int64_t block_size, 17 | const int64_t max_num_block_per_req, 18 | const std::vector parallel_config 19 | ): 20 | GptOpBase(inference_dtype, 21 | GptHyperParam::GetGpt2HyperParam( 22 | vocab_size, 23 | max_position_embeddings, 24 | hidden_size, 25 | num_layers, 26 | num_heads, 27 | head_dim, 28 | 4 * hidden_size 29 | ), 30 | GptPagedAttnParam{ 31 | .block_size = block_size, 32 | .max_num_block_per_req = max_num_block_per_req, 33 | }, 34 | GptParallelismParam(parallel_config) 35 | ) { 36 | } 37 | 38 | } // namespace st::model 39 | -------------------------------------------------------------------------------- /src/csrc/model/gpt/gpt2/gpt2op.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include "model/gpt/gptop_base.h" 6 | 7 | namespace st::model { 8 | 9 | // Please refer to gpt_base.h for the design of GPTBase, Gpt, GptOpBase, and XXXop. 10 | class Gpt2Op : public GptOpBase { 11 | public: 12 | Gpt2Op(const int64_t vocab_size, // GPT Params 13 | const int64_t max_position_embeddings, 14 | const int64_t hidden_size, 15 | const int64_t num_layers, 16 | const int64_t num_heads, 17 | const int64_t head_dim, 18 | const std::string inference_dtype, 19 | const int64_t block_size, // Cache Params 20 | const int64_t max_num_block_per_req, 21 | const std::vector parallel_config); 22 | }; 23 | 24 | } // namespace st::model -------------------------------------------------------------------------------- /src/csrc/model/gpt/gpt_base.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #include 7 | 8 | #include "gpt_hyper_param.h" 9 | #include "gpt_pagedattn_param.h" 10 | #include "gpt_parallelism_param.h" 11 | 12 | namespace st::model { 13 | 14 | /* 15 | GptBase: Abstract base class for GPT. The actual GPT class is based on this class. 16 | 17 | We have this class because PyTorch binding requires a non-template class, so 18 | our XXXop (e.g. OptOp) classes cannot have template parameters. So the following 19 | implementation won't work: 20 | 21 | ``` 22 | template 23 | class OptOp { 24 | Gpt gpt; 25 | }; 26 | ``` 27 | 28 | Instead, we have to use a non-template base class, and leverage virtual function 29 | and polymorphism in C++ to implement the actual GPT and XXXop class. For example: 30 | 31 | ``` 32 | class GptBase { 33 | // Virtual functions like loadWeight and forward 34 | }; 35 | 36 | template 37 | class Gpt : GPTBase { 38 | // Implementations of virtual functions 39 | }; 40 | 41 | class GptOpBase { 42 | GptBase* gpt; // A pointer to GPTBase, which can be Gpt for any T. 43 | // Implementations of loadWeight() and forward(), which takes torch::Tensor 44 | // as input and calls Gpt::loadWeight() and Gpt::forward(). 45 | }; 46 | 47 | class OptOp : GptOpBase { 48 | // Implementation of constructor 49 | }; 50 | ``` 51 | */ 52 | 53 | class GptBase { 54 | public: 55 | GptHyperParam hyper_param; 56 | GptPagedAttnParam pagedattn_param; 57 | GptParallelismParam parallelism_param; 58 | 59 | virtual ~GptBase() {} 60 | virtual void loadWeight(const std::string&) = 0; 61 | virtual void initDummyWeight() = 0; 62 | 63 | // Forward function for GPT. 64 | // Args: 65 | // input_tokens_batched: a batch of requests, where each element is vector of tokens for corresponding request. 66 | // input_tokens_batched may contain requests in context or decoding phase. 
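	//                       e.g. {{5, 8, 2, 9}, {17}}: one context-phase request with a 4-token prompt and one
	//                       decoding-phase request that typically passes only its most recently generated token
	//                       (the token ids here are arbitrary and only for illustration).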
67 | // first_token_indexes: the index of the first token in each request's input_tokens. For example, if request i is 68 | // in decoding phase, and it has generated 5 tokens, then first_token_indexes[i] = 5. 69 | // if first_token_indexes[j] == 0, then request j is in context phase. 70 | // d_k_cache: the overall key cache. [num_blocks, num_layers, num_local_heads, block_size, head_dim] 71 | // d_v_cache: the overall value cache. [num_blocks, num_layers, num_local_heads, block_size, head_dim] 72 | // block_table: block_table[i][j] = k means the j-th logical block for request i is the k-th physical block in the overall key/value cache. 73 | // Note: here request i is the i-th request in input_tokens_batched, not the i-th request in the whole dataset. 74 | virtual std::vector forward( 75 | // input data && metadata 76 | const std::vector> &input_tokens_batched, 77 | const std::vector &first_token_indexes, 78 | 79 | // key-value management 80 | void* d_k_cache, 81 | void* d_v_cache, 82 | int64_t* d_block_table 83 | ) = 0; 84 | 85 | virtual void init_communicator(const ncclUniqueId& tp_id, const ncclUniqueId& pp_id) = 0; 86 | 87 | }; 88 | 89 | } // namespace st::model -------------------------------------------------------------------------------- /src/csrc/model/gpt/gpt_hyper_param.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include "kernel/activation_types.h" 6 | namespace st::model { 7 | 8 | struct GptHyperParam { 9 | // Hyper-parameters 10 | int64_t vocab_size; // The size of the vocabulary 11 | int64_t max_position_embeddings; // The maximum length of the input sequence 12 | int64_t hidden_size; // The length of the embedded vector 13 | int64_t num_layers; // The number of layers (transformer blocks) 14 | int64_t num_q_heads; // The number of query heads in the multi-head attention 15 | int64_t num_kv_heads; // The number of key/value heads in the multi-head attention. 
16 | // If the model does not use GQA (Grouped Query Attention), just 17 | // set num_kv_heads = num_q_heads 18 | int64_t head_dim; // The dimension of each head (length of the key, query, and value vectors) 19 | int64_t ffn_inter_dim; // The intermediate dimension of the feed-forward network 20 | 21 | // Model configurations 22 | bool is_pre_layernorm; // Perform layernorm before/after the self-attention and feed-forward network 23 | bool is_rotary_posi_embedding; // Use rotary position embedding instead of absolute position embedding 24 | bool is_gated_ffn; // Use gated feed-forward network 25 | ActivationType ffn_activation_type; // The activation function of the feed-forward network 26 | bool is_rmsnorm; // Use RMSNorm instead of LayerNorm 27 | bool is_attn_qkv_biased; 28 | bool is_attn_out_biased; 29 | 30 | friend std::ostream& operator<<(std::ostream& os, const GptHyperParam& params) { 31 | os << "GptHyperParam {\n" 32 | << "\tvocab_size = " << params.vocab_size << "\n" 33 | << "\tmax_position_embeddings = " << params.max_position_embeddings << "\n" 34 | << "\thidden_size = " << params.hidden_size << "\n" 35 | << "\tnum_layers = " << params.num_layers << "\n" 36 | << "\tnum_q_heads = " << params.num_q_heads << "\n" 37 | << "\tnum_kv_heads = " << params.num_kv_heads << "\n" 38 | << "\thead_dim = " << params.head_dim << "\n" 39 | << "\tffn_inter_dim = " << params.ffn_inter_dim << "\n" 40 | << "\tis_pre_layernorm = " << params.is_pre_layernorm << "\n" 41 | << "\tis_rotary_posi_embedding = " << params.is_rotary_posi_embedding << "\n" 42 | << "\tis_gated_ffn = " << params.is_gated_ffn << "\n" 43 | << "\tffn_activation_type = " << static_cast(params.ffn_activation_type) << "\n" 44 | << "\tis_rmsnorm = " << params.is_rmsnorm << "\n" 45 | << "\tis_attn_qkv_biased = " << params.is_attn_qkv_biased << "\n" 46 | << "\tis_attn_out_bias = " << params.is_attn_out_biased << "\n" 47 | << "}"; 48 | return os; 49 | } 50 | 51 | static GptHyperParam GetOptHyperParam ( 52 | int64_t vocab_size, 53 | int64_t max_position_embeddings, 54 | int64_t hidden_size, 55 | int64_t num_layers, 56 | int64_t num_heads, 57 | int64_t head_dim, 58 | int64_t ffn_inter_dim 59 | ) { 60 | return GptHyperParam{ 61 | .vocab_size = vocab_size, 62 | .max_position_embeddings = max_position_embeddings, 63 | .hidden_size = hidden_size, 64 | .num_layers = num_layers, 65 | .num_q_heads = num_heads, 66 | .num_kv_heads = num_heads, 67 | .head_dim = head_dim, 68 | .ffn_inter_dim = ffn_inter_dim, 69 | .is_pre_layernorm = true, 70 | .is_rotary_posi_embedding = false, 71 | .is_gated_ffn = false, 72 | .ffn_activation_type = ActivationType::RELU, 73 | .is_rmsnorm = false, 74 | .is_attn_qkv_biased = true, 75 | .is_attn_out_biased = true 76 | }; 77 | } 78 | 79 | static GptHyperParam GetLlama2HyperParam ( 80 | int64_t vocab_size, 81 | int64_t max_position_embeddings, 82 | int64_t hidden_size, 83 | int64_t num_layers, 84 | int64_t num_q_heads, 85 | int64_t num_kv_heads, 86 | int64_t head_dim, 87 | int64_t ffn_inter_dim 88 | ) { 89 | return GptHyperParam{ 90 | .vocab_size = vocab_size, 91 | .max_position_embeddings = max_position_embeddings, 92 | .hidden_size = hidden_size, 93 | .num_layers = num_layers, 94 | .num_q_heads = num_q_heads, 95 | .num_kv_heads = num_kv_heads, 96 | .head_dim = head_dim, 97 | .ffn_inter_dim = ffn_inter_dim, 98 | .is_pre_layernorm = true, 99 | .is_rotary_posi_embedding = true, 100 | .is_gated_ffn = true, 101 | .ffn_activation_type = ActivationType::SILU, 102 | .is_rmsnorm = true, 103 | .is_attn_qkv_biased = false, 104 | 
.is_attn_out_biased = false 105 | }; 106 | } 107 | 108 | static GptHyperParam GetGpt2HyperParam ( 109 | int64_t vocab_size, 110 | int64_t max_position_embeddings, 111 | int64_t hidden_size, 112 | int64_t num_layers, 113 | int64_t num_heads, 114 | int64_t head_dim, 115 | int64_t ffn_inter_dim 116 | ) { 117 | return GptHyperParam{ 118 | .vocab_size = vocab_size, 119 | .max_position_embeddings = max_position_embeddings, 120 | .hidden_size = hidden_size, 121 | .num_layers = num_layers, 122 | .num_q_heads = num_heads, 123 | .num_kv_heads = num_heads, 124 | .head_dim = head_dim, 125 | .ffn_inter_dim = ffn_inter_dim, 126 | .is_pre_layernorm = true, 127 | .is_rotary_posi_embedding = false, 128 | .is_gated_ffn = false, 129 | .ffn_activation_type = ActivationType::GELU, 130 | .is_rmsnorm = false, 131 | .is_attn_qkv_biased = true, 132 | .is_attn_out_biased = true 133 | }; 134 | } 135 | }; 136 | 137 | } // namespace st::model 138 | -------------------------------------------------------------------------------- /src/csrc/model/gpt/gpt_pagedattn_param.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace st::model { 6 | 7 | struct GptPagedAttnParam { 8 | // Hyperparameters related to PagedAttention 9 | int64_t block_size; 10 | int64_t max_num_block_per_req; 11 | 12 | friend std::ostream& operator<<(std::ostream& os, const GptPagedAttnParam& params) { 13 | os << "GptPagedAttnParam {\n" 14 | << "\tblock_size = " << params.block_size << "\n" 15 | << "\tmax_num_block_per_req = " << params.max_num_block_per_req << "\n" 16 | << "}"; 17 | return os; 18 | } 19 | }; 20 | 21 | } // namespace st::model 22 | -------------------------------------------------------------------------------- /src/csrc/model/gpt/gpt_parallelism_param.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #include "model/gpt/gpt_hyper_param.h" 7 | 8 | namespace st::model { 9 | 10 | struct GptParallelismParam { 11 | // Hyper parameters related to parallelism 12 | int64_t tensor_para_size = 1; 13 | int64_t tensor_para_rank = 0; 14 | 15 | int64_t pipeline_para_size = 1; 16 | int64_t pipeline_para_rank = 0; 17 | 18 | bool hyper_inited = false; 19 | 20 | // The following two parameters are used for pipeline parallelism 21 | // The layer range of the current pipeline stage is [layer_begin, layer_end) 22 | int64_t layer_begin = 0, layer_end = 0, local_layer_num = 0; 23 | 24 | GptParallelismParam(int64_t tensor_para_size = 1, int64_t tensor_para_rank = 0, int64_t pipeline_para_size = 1, int64_t pipeline_para_rank = 0) 25 | : tensor_para_size(tensor_para_size) 26 | , tensor_para_rank(tensor_para_rank) 27 | , pipeline_para_size(pipeline_para_size) 28 | , pipeline_para_rank(pipeline_para_rank) 29 | { 30 | } 31 | 32 | GptParallelismParam(const std::vector parallel_config) 33 | : GptParallelismParam(parallel_config[0], parallel_config[1], parallel_config[2], parallel_config[3]) 34 | { 35 | } 36 | 37 | void init_by_hyper_param(const GptHyperParam& hyper_param) 38 | { 39 | if (hyper_inited) { 40 | return; 41 | } 42 | hyper_inited = true; 43 | if (hyper_param.num_layers % pipeline_para_size != 0) { 44 | throw std::invalid_argument("The number of layers must be divisible by the pipeline parallelism size."); 45 | } 46 | local_layer_num = hyper_param.num_layers / pipeline_para_size; 47 | layer_begin = pipeline_para_rank * local_layer_num; 48 | layer_end = layer_begin + local_layer_num; 49 | } 
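	// Example (assuming num_layers = 32 and pipeline_para_size = 4): rank 0 owns layers [0, 8),
	// rank 1 owns [8, 16), rank 2 owns [16, 24), and rank 3 owns [24, 32); local_layer_num is 8
	// on every rank. A num_layers that is not divisible by pipeline_para_size is rejected by the
	// check above.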
50 | 51 | inline bool is_parallel() const 52 | { 53 | return tensor_para_size > 1 || pipeline_para_size > 1; 54 | } 55 | 56 | inline bool is_last_stage() const 57 | { 58 | return pipeline_para_rank == pipeline_para_size - 1; 59 | } 60 | 61 | inline bool is_first_stage() const 62 | { 63 | return pipeline_para_rank == 0; 64 | } 65 | 66 | inline bool is_stage_leader() const 67 | { 68 | return tensor_para_rank == 0; 69 | } 70 | 71 | void set_parallelism(int64_t tensor_para_size, int64_t tensor_para_rank, int64_t pipeline_para_size, int64_t pipeline_para_rank) 72 | { 73 | this->tensor_para_size = tensor_para_size; 74 | this->tensor_para_rank = tensor_para_rank; 75 | this->pipeline_para_size = pipeline_para_size; 76 | this->pipeline_para_rank = pipeline_para_rank; 77 | } 78 | 79 | friend std::ostream& operator<<(std::ostream& os, const GptParallelismParam& param) 80 | { 81 | os << "tensor_para_size: " << param.tensor_para_size << std::endl; 82 | os << "tensor_para_rank: " << param.tensor_para_rank << std::endl; 83 | os << "pipeline_para_size: " << param.pipeline_para_size << std::endl; 84 | os << "pipeline_para_rank: " << param.pipeline_para_rank << std::endl; 85 | return os; 86 | } 87 | }; 88 | 89 | } // namespace st::model 90 | -------------------------------------------------------------------------------- /src/csrc/model/gpt/gpt_weight.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #include 7 | 8 | #include "gpt_hyper_param.h" 9 | #include "gpt_parallelism_param.h" 10 | 11 | namespace st::model { 12 | 13 | // GptLayerWeight - weights of a single GPT layer 14 | // This struct contains the weights of a single GPT layer. The weights are 15 | // loaded from a pt file. 16 | // All pointers are owned (allocated and freed) by GptWeight. 
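// Size sketch (assuming attention heads and the FFN intermediate dimension are split
// evenly across tensor-parallel ranks, i.e. local_q_head_num = num_q_heads / tensor_para_size):
// with hidden_size = 4096, num_q_heads = num_kv_heads = 32, head_dim = 128 and
// tensor_para_size = 2, attn_qkv_kernel holds 4096 * (16 + 2 * 16) * 128 elements and
// ffn_fc1_weight holds (inter_dim / 2) * 4096 elements on each rank.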
17 | template 18 | struct GptLayerWeight { 19 | T* attn_qkv_kernel = nullptr; // [hidden_size, local_q_head_num+2*local_kv_head_num, head_dim] 20 | T* attn_qkv_bias = nullptr; // [local_q_head_num+2*local_kv_head_num, head_dim] 21 | T* attn_out_kernel = nullptr; // [local_q_head_num, head_dim, hidden_size] 22 | T* attn_out_bias = nullptr; // [hidden_size] 23 | 24 | T* attn_layernorm_weight = nullptr; // [hidden_size] 25 | T* attn_layernorm_bias = nullptr; // [hidden_size] 26 | 27 | T* ffn_fc1_weight = nullptr; // [inter_dim / tensor_para_size, hidden_size] 28 | T* ffn_fc1_bias = nullptr; // [inter_dim / tensor_para_size], will be used only when is_gated_ffn = false 29 | T* ffn_fc2_weight = nullptr; // [hidden_size, inter_dim / tensor_para_size] 30 | T* ffn_fc2_bias = nullptr; // [hidden_size], will be used only when is_gated_ffn = false 31 | T* ffn_fc3_weight = nullptr; // [inter_dim / tensor_para_size, hidden_size], will be used only when is_gated_ffn = true 32 | 33 | T* final_layernorm_weight = nullptr;// [hidden_size] 34 | T* final_layernorm_bias = nullptr; // [hidden_size] 35 | }; 36 | 37 | 38 | // GptWeight - weights of the GPT model 39 | // All pointers are owned (allocated and freed) by itself 40 | template 41 | class GptWeight { 42 | private: 43 | void allocateWeightArray(); 44 | void freeWeightArray(); 45 | void loadTensor_qkv_weight_kernel_or_bias(const uint32_t dim, T* to_ptr, const std::string model_dir, const std::string key, const int64_t expect_size); 46 | void loadTensor_tp(const uint32_t dim, T* to_ptr, const std::string model_dir, const std::string key, const int64_t expect_size); 47 | void loadTensor_all(T* to_ptr, const std::string model_dir, const std::string key, const int64_t expect_size); 48 | bool contain_embedding_layer = false; 49 | 50 | public: 51 | GptHyperParam hyper_param; 52 | GptParallelismParam parallelism_param; 53 | 54 | T* embed_tokens_weight; // [vocab_size, hidden_size] 55 | T* embed_positions_weight; // [max_position_embeddings, hidden_size], will be used only when is_rotary_embedding = false 56 | 57 | std::vector> layer_weights; 58 | 59 | T* final_layernorm_weight; // [hidden_size] 60 | T* final_layernorm_bias; // [hidden_size] 61 | T* output_proj_weight; // [vocab_size, hidden_size] 62 | 63 | T layernorm_epsilon = (T)1e-5; 64 | 65 | bool initialized = false; 66 | 67 | GptWeight(); 68 | ~GptWeight(); 69 | 70 | void init(const GptHyperParam& hyper_param, GptParallelismParam& parallelism_param = GptParallelismParam()); 71 | 72 | void loadWeight(const std::string& weight_path); 73 | void initDummyWeight(); 74 | }; 75 | 76 | 77 | } // namespace st::model 78 | -------------------------------------------------------------------------------- /src/csrc/model/gpt/gptop_base.cc: -------------------------------------------------------------------------------- 1 | #include "gptop_base.h" 2 | 3 | #include "util/torch_utils.h" 4 | #include "util/cuda_utils.h" 5 | #include "util/nccl_utils.h" 6 | 7 | namespace st::model { 8 | 9 | GptOpBase::GptOpBase( 10 | std::string inference_dtype, 11 | GptHyperParam hyper_param, 12 | GptPagedAttnParam pagedattn_param, 13 | GptParallelismParam parallelism_param 14 | ) { 15 | if (inference_dtype == "fp32") { 16 | gpt = new Gpt(hyper_param, pagedattn_param, parallelism_param); 17 | } else if (inference_dtype == "fp16") { 18 | gpt = new Gpt<__half>(hyper_param, pagedattn_param, parallelism_param); 19 | } else { 20 | throw std::runtime_error("Unsupported inference_dtype: " + inference_dtype); 21 | } 22 | 23 | weight_loaded = false; 
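	// Weights are not loaded here; they are filled in later by loadWeight() or
	// initDummyWeight(), both of which set weight_loaded to true.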
24 | } 25 | 26 | GptOpBase::~GptOpBase() { 27 | delete gpt; 28 | } 29 | 30 | void GptOpBase::loadWeight(const std::string& weight_path) { 31 | this->gpt->loadWeight(weight_path); 32 | this->weight_loaded = true; 33 | }; 34 | 35 | void GptOpBase::initDummyWeight() { 36 | this->gpt->initDummyWeight(); 37 | this->weight_loaded = true; 38 | }; 39 | 40 | std::vector GptOpBase::forward( 41 | const std::vector> &input_tokens_batched, 42 | const std::vector &first_token_indexes, // [batchsize] 43 | torch::Tensor &k_cache, // [num_blocks, num_heads, block_size, head_dim] 44 | torch::Tensor &v_cache, // [num_blocks, num_heads, block_size, head_dim] 45 | const std::vector> &block_table) 46 | { 47 | if (!this->weight_loaded) { 48 | throw std::runtime_error("Please load the weight before inference."); 49 | } 50 | 51 | int64_t batch_size = input_tokens_batched.size(); 52 | if (batch_size == 0) { 53 | return std::vector(); 54 | } 55 | 56 | // Prepare block_table 57 | int64_t* h_block_table = new int64_t[batch_size * this->gpt->pagedattn_param.max_num_block_per_req]; 58 | for (int64_t i = 0; i < batch_size; i++) { 59 | for (int64_t j = 0; j < (int64_t)block_table[i].size(); j++) { 60 | h_block_table[i * this->gpt->pagedattn_param.max_num_block_per_req + j] = block_table[i][j]; 61 | } 62 | } 63 | int64_t *d_block_table; 64 | CUDA_CHECK(cudaMalloc(&d_block_table, sizeof(int64_t) * batch_size * this->gpt->pagedattn_param.max_num_block_per_req)); 65 | CUDA_FREE_AT_RETURN(d_block_table); 66 | cudaMemcpy(d_block_table, h_block_table, sizeof(int64_t) * batch_size * this->gpt->pagedattn_param.max_num_block_per_req, cudaMemcpyHostToDevice); 67 | delete[] h_block_table; 68 | sync_check_cuda_error(); 69 | 70 | return this->gpt->forward(input_tokens_batched, 71 | first_token_indexes, 72 | st::util::convertTensorToRawPtr(k_cache), 73 | st::util::convertTensorToRawPtr(v_cache), 74 | d_block_table); 75 | } 76 | 77 | void GptOpBase::init_communicator(const std::vector tp_id, const std::vector pp_id){ 78 | ncclUniqueId tp_uid, pp_uid; 79 | memcpy(tp_uid.internal, &tp_id[0], NCCL_UNIQUE_ID_BYTES); 80 | memcpy(pp_uid.internal, &pp_id[0], NCCL_UNIQUE_ID_BYTES); 81 | this->gpt->init_communicator(tp_uid, pp_uid); 82 | } 83 | 84 | } 85 | -------------------------------------------------------------------------------- /src/csrc/model/gpt/gptop_base.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include 6 | 7 | #include "model/gpt/gpt_base.h" 8 | #include "model/gpt/gpt.h" 9 | 10 | namespace st::model { 11 | 12 | // Please refer to gpt_base.h for the design of GPTBase, Gpt, and GptOpBase. 13 | class GptOpBase : public torch::CustomClassHolder { 14 | public: 15 | GptBase* gpt; // A pointer to GptBase, which can be Gpt for any T. 
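	                // Owned by this class: allocated in the constructor according to
	                // inference_dtype ("fp32" or "fp16") and released in the destructor.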
16 | bool weight_loaded; 17 | 18 | GptOpBase( 19 | std::string inference_dtype, 20 | GptHyperParam hyper_param, 21 | GptPagedAttnParam pagedattn_param, 22 | GptParallelismParam parallelism_param 23 | ); 24 | 25 | ~GptOpBase(); 26 | 27 | void loadWeight(const std::string& weight_path); 28 | void initDummyWeight(); 29 | 30 | std::vector forward( 31 | const std::vector> &input_tokens_batched, 32 | const std::vector &first_token_indexes, 33 | torch::Tensor &k_cache, 34 | torch::Tensor &v_cache, 35 | const std::vector> &block_table 36 | ); 37 | 38 | void init_communicator(const std::vector tp_id, const std::vector pp_id); 39 | }; 40 | 41 | } -------------------------------------------------------------------------------- /src/csrc/model/gpt/llama2/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_library(model_llama2 llama2op.cc) 2 | target_link_libraries(model_llama2 layer kernel util model_gpt "${TORCH_LIBRARIES}") 3 | set_property(TARGET model_llama2 PROPERTY POSITION_INDEPENDENT_CODE ON) 4 | -------------------------------------------------------------------------------- /src/csrc/model/gpt/llama2/llama2op.cc: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "llama2op.h" 4 | 5 | #include "util/torch_utils.h" 6 | 7 | namespace st::model { 8 | 9 | Llama2Op::Llama2Op(const int64_t vocab_size, 10 | const int64_t max_position_embeddings, 11 | const int64_t hidden_size, 12 | const int64_t num_layers, 13 | const int64_t num_q_heads, 14 | const int64_t num_kv_heads, 15 | const int64_t head_dim, 16 | const int64_t ffn_inter_dim, 17 | std::string inference_dtype, 18 | const int64_t block_size, 19 | const int64_t max_num_block_per_req, 20 | const std::vector parallel_config): 21 | GptOpBase(inference_dtype, 22 | GptHyperParam::GetLlama2HyperParam( 23 | vocab_size, 24 | max_position_embeddings, 25 | hidden_size, 26 | num_layers, 27 | num_q_heads, 28 | num_kv_heads, 29 | head_dim, 30 | ffn_inter_dim 31 | ), 32 | GptPagedAttnParam{ 33 | .block_size = block_size, 34 | .max_num_block_per_req = max_num_block_per_req, 35 | }, 36 | GptParallelismParam(parallel_config) 37 | ) { 38 | }; 39 | 40 | } // namespace st::model 41 | -------------------------------------------------------------------------------- /src/csrc/model/gpt/llama2/llama2op.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include "model/gpt/gptop_base.h" 6 | 7 | namespace st::model { 8 | 9 | // Please refer to gpt_base.h for the design of GPTBase, Gpt, GptOpBase, and XXXop. 
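// Construction sketch (hyper parameters of llama2-7b as used in the examples; block_size,
// max_num_block_per_req and parallel_config are deployment-specific placeholders):
//   Llama2Op op(32000, 4096, 4096, 32, 32, 32, 128, 11008, "fp16",
//               block_size, max_num_block_per_req, parallel_config);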
10 | class Llama2Op : public GptOpBase { 11 | public: 12 | Llama2Op(const int64_t vocab_size, // GPT Params 13 | const int64_t max_position_embeddings, 14 | const int64_t hidden_size, 15 | const int64_t num_layers, 16 | const int64_t num_q_heads, 17 | const int64_t num_kv_heads, 18 | const int64_t head_dim, 19 | const int64_t ffn_inter_dim, 20 | const std::string inference_dtype, 21 | const int64_t block_size, // Cache Params 22 | const int64_t max_num_block_per_req, 23 | const std::vector parallel_config); 24 | 25 | // Inherit loadweight() and forward() from GptOpBase 26 | }; 27 | 28 | } // namespace st::model -------------------------------------------------------------------------------- /src/csrc/model/gpt/opt/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_library(model_opt optop.cc) 2 | target_link_libraries(model_opt layer kernel util model_gpt "${TORCH_LIBRARIES}") 3 | set_property(TARGET model_opt PROPERTY POSITION_INDEPENDENT_CODE ON) 4 | -------------------------------------------------------------------------------- /src/csrc/model/gpt/opt/optop.cc: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "optop.h" 4 | 5 | #include "util/torch_utils.h" 6 | 7 | namespace st::model { 8 | 9 | OptOp::OptOp(const int64_t vocab_size, 10 | const int64_t max_position_embeddings, 11 | const int64_t hidden_size, 12 | const int64_t num_layers, 13 | const int64_t num_heads, 14 | const int64_t head_dim, 15 | std::string inference_dtype, 16 | const int64_t block_size, 17 | const int64_t max_num_block_per_req, 18 | const std::vector parallel_config 19 | ): 20 | GptOpBase(inference_dtype, 21 | GptHyperParam::GetOptHyperParam( 22 | vocab_size, 23 | max_position_embeddings, 24 | hidden_size, 25 | num_layers, 26 | num_heads, 27 | head_dim, 28 | 4 * hidden_size 29 | ), 30 | GptPagedAttnParam{ 31 | .block_size = block_size, 32 | .max_num_block_per_req = max_num_block_per_req, 33 | }, 34 | GptParallelismParam(parallel_config) 35 | ) { 36 | } 37 | 38 | } // namespace st::model 39 | -------------------------------------------------------------------------------- /src/csrc/model/gpt/opt/optop.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include "model/gpt/gptop_base.h" 6 | 7 | namespace st::model { 8 | 9 | // Please refer to gpt_base.h for the design of GPTBase, Gpt, GptOpBase, and XXXop. 
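// Construction sketch (hyper parameters of opt-13b as used in the examples; the FFN
// intermediate size is derived internally as 4 * hidden_size; block_size,
// max_num_block_per_req and parallel_config are deployment-specific placeholders):
//   OptOp op(50272, 2048, 5120, 40, 40, 128, "fp16",
//            block_size, max_num_block_per_req, parallel_config);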
10 | class OptOp : public GptOpBase { 11 | public: 12 | OptOp(const int64_t vocab_size, // GPT Params 13 | const int64_t max_position_embeddings, 14 | const int64_t hidden_size, 15 | const int64_t num_layers, 16 | const int64_t num_heads, 17 | const int64_t head_dim, 18 | const std::string inference_dtype, 19 | const int64_t block_size, // Cache Params 20 | const int64_t max_num_block_per_req, 21 | const std::vector parallel_config); 22 | }; 23 | 24 | } // namespace st::model -------------------------------------------------------------------------------- /src/csrc/pybinding.cc: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "util/py_nccl.h" 4 | #include "util/py_swapping.h" 5 | #include "util/py_block_migration.h" 6 | 7 | #include "model/gpt/opt/optop.h" 8 | #include "model/gpt/llama2/llama2op.h" 9 | #include "model/gpt/gpt2/gpt2op.h" 10 | 11 | /* 12 | The two function wrappers below are needed to avoid the following error: 13 | 14 | RuntimeError: Tried to convert an IValue of type __torch__.torch.classes.gpt_ops.OptOp (of Python compilation unit at: 0) to custom class type __torch__.torch.classes.gpt_ops.GptOpBase (of Python compilation unit at: 0) 15 | 16 | We encounter the error above because of that, when we create a OptOp class in python and 17 | call its load_weight() method, we are actually calling the load_weight() method of the 18 | GptOpBase class, which is the base class of OptOp. However, PyTorch thinks it needs to 19 | convert the OptOp object to a GptOpBase object (since the first argument of loadWeight 20 | is GptOpBase*), which is not possible because we didn't defined that. 21 | 22 | The solution is to define a wrapper function that takes a OptOp object as the first 23 | argument and calls the loadWeight() method of the OptOp object, which avoid type 24 | conversion. 25 | */ 26 | template 27 | void loadWeightWrapper(const c10::intrusive_ptr& self, const std::string& path) { 28 | self->loadWeight(path); 29 | } 30 | 31 | template 32 | void initDummyWeightWrapper(const c10::intrusive_ptr& self) { 33 | self->initDummyWeight(); 34 | } 35 | 36 | template 37 | std::vector forwardWrapper(const c10::intrusive_ptr& self, 38 | const std::vector> &input_tokens_batched, 39 | const std::vector &first_token_indexes, 40 | torch::Tensor &k_cache, 41 | torch::Tensor &v_cache, 42 | const std::vector> &block_table) { 43 | return self->forward(input_tokens_batched, first_token_indexes, k_cache, v_cache, block_table); 44 | } 45 | 46 | template 47 | void initCommunicatorWrapper(const c10::intrusive_ptr& self, 48 | const std::vector tp_id, 49 | const std::vector pp_id) { 50 | self->init_communicator(tp_id, pp_id); 51 | } 52 | 53 | TORCH_LIBRARY(gpt_ops, m) { 54 | m.class_("GptOpBase"); // Must add this class or will get error: "c10::intrusive_ptr<...> could not be converted to any of the known types." 
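	// Python-side sketch (illustrative; the exact loading call depends on how the
	// extension is built): after the shared library is loaded, e.g. with
	// torch.ops.load_library(...), the classes are reachable as
	// torch.classes.gpt_ops.OptOp / Llama2Op / Gpt2Op, and the methods registered
	// below ("load_weight", "init_dummy_weights", "forward", "init_communicator")
	// dispatch through the wrapper functions defined above.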
55 | m.class_("OptOp") 56 | .def(torch::init >()) 58 | .def("load_weight", &loadWeightWrapper) 59 | .def("init_dummy_weights", &initDummyWeightWrapper) 60 | .def("forward", &forwardWrapper) 61 | .def("init_communicator", &initCommunicatorWrapper) 62 | 63 | ; 64 | m.class_("Llama2Op") 65 | .def(torch::init >()) 67 | .def("load_weight", &loadWeightWrapper) 68 | .def("init_dummy_weights", &initDummyWeightWrapper) 69 | .def("forward", &forwardWrapper) 70 | .def("init_communicator", &initCommunicatorWrapper) 71 | ; 72 | m.class_("Gpt2Op") 73 | .def(torch::init >()) 75 | .def("load_weight", &loadWeightWrapper) 76 | .def("init_dummy_weights", &initDummyWeightWrapper) 77 | .def("forward", &forwardWrapper) 78 | .def("init_communicator", &initCommunicatorWrapper) 79 | 80 | ; 81 | } 82 | 83 | TORCH_LIBRARY(nccl_ops, m) 84 | { 85 | m.def("generate_nccl_id", &st::util::generate_nccl_id); 86 | } 87 | 88 | TORCH_LIBRARY(swapping_ops, m) { 89 | m.def("swap", &st::util::swap); 90 | } 91 | 92 | TORCH_LIBRARY(block_migration_ops, m) { 93 | m.def("get_ipc_mem_handle", &st::util::get_ipc_mem_handle); 94 | m.def("register_ipc_mem_handle", &st::util::register_ipc_mem_handle); 95 | m.def("migrate_blocks", &st::util::migrate_blocks); 96 | } -------------------------------------------------------------------------------- /src/csrc/util/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_library(util STATIC cublas_wrapper.cc) 2 | target_link_libraries(util PUBLIC CUDA::cublas) 3 | set_property(TARGET util PROPERTY POSITION_INDEPENDENT_CODE ON) 4 | 5 | add_library(nccl_utils STATIC nccl_utils.cc) 6 | set_property(TARGET nccl_utils PROPERTY POSITION_INDEPENDENT_CODE ON) 7 | set_property(TARGET nccl_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) 8 | target_link_libraries(nccl_utils PUBLIC ${MPI_CXX_LIBRARIES} ${NCCL_LIBRARIES}) 9 | target_include_directories(nccl_utils PUBLIC ${MPI_CXX_INCLUDE_DIRS} ${NCCL_INCLUDE_DIRS}) 10 | 11 | add_library(py_nccl_utils STATIC py_nccl.cc) 12 | set_property(TARGET py_nccl_utils PROPERTY POSITION_INDEPENDENT_CODE ON) 13 | set_property(TARGET py_nccl_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) 14 | target_link_libraries(py_nccl_utils PUBLIC ${MPI_CXX_LIBRARIES} ${NCCL_LIBRARIES}) 15 | 16 | add_library(py_swapping STATIC py_swapping.cc) 17 | target_link_libraries(py_swapping PUBLIC util) 18 | set_property(TARGET py_swapping PROPERTY POSITION_INDEPENDENT_CODE ON) 19 | 20 | add_library(py_block_migration STATIC py_block_migration.cc) 21 | target_link_libraries(py_block_migration PUBLIC util) 22 | set_property(TARGET py_block_migration PROPERTY POSITION_INDEPENDENT_CODE ON) 23 | -------------------------------------------------------------------------------- /src/csrc/util/cublas_wrapper.cc: -------------------------------------------------------------------------------- 1 | #include "cublas_wrapper.h" 2 | 3 | #include 4 | #include 5 | 6 | #include "util/cuda_utils.h" 7 | 8 | namespace st::util { 9 | 10 | template 11 | void CublasWrapper::gemmStridedBatched( 12 | cublasOperation_t transa, 13 | cublasOperation_t transb, 14 | int m, 15 | int n, 16 | int k, 17 | const T alpha, 18 | const T* Aarray, 19 | long long int stride_a, 20 | const T* Barray, 21 | long long int stride_b, 22 | const T beta, 23 | T* Carray, 24 | long long int stride_c, 25 | int batchCount 26 | ) { 27 | cudaDataType_t cuda_datatype = getCudaDataType(); 28 | cublasComputeType_t compute_type = getCublasComputeType(); 29 | cublasStatus_t status = 
cublasGemmStridedBatchedEx( 30 | *handle_.get(), 31 | transb, 32 | transa, 33 | n, 34 | m, 35 | k, 36 | &alpha, 37 | Barray, 38 | cuda_datatype, 39 | transb == CUBLAS_OP_N ? n : k, 40 | stride_b, 41 | Aarray, 42 | cuda_datatype, 43 | transa == CUBLAS_OP_N ? k : m, 44 | stride_a, 45 | &beta, 46 | Carray, 47 | cuda_datatype, 48 | n, 49 | stride_c, 50 | batchCount, 51 | compute_type, 52 | algo_ 53 | ); 54 | if (status != CUBLAS_STATUS_SUCCESS) { 55 | std::cerr << "CublasWrapper::gemmStridedBatched failed: " << status << std::endl; 56 | throw std::runtime_error("CublasWrapper::gemmStridedBatched failed"); 57 | } 58 | } 59 | 60 | template void CublasWrapper::gemmStridedBatched( 61 | cublasOperation_t transa, cublasOperation_t transb, 62 | int m, int n, int k, 63 | const half alpha, 64 | const half* Aarray, long long int stride_a, 65 | const half* Barray, long long int stride_b, 66 | const half beta, 67 | half* Carray, long long int stride_c, 68 | int batchCount 69 | ); 70 | template void CublasWrapper::gemmStridedBatched( 71 | cublasOperation_t transa, cublasOperation_t transb, 72 | int m, int n, int k, 73 | const float alpha, 74 | const float* Aarray, long long int stride_a, 75 | const float* Barray, long long int stride_b, 76 | const float beta, 77 | float* Carray, long long int stride_c, 78 | int batchCount 79 | ); 80 | 81 | template 82 | void CublasWrapper::gemmStridedBatched( 83 | cublasOperation_t transa, 84 | cublasOperation_t transb, 85 | int m, 86 | int n, 87 | int k, 88 | const T* Aarray, 89 | long long int stride_a, 90 | const T* Barray, 91 | long long int stride_b, 92 | T* Carray, 93 | long long int stride_c, 94 | int batchCount 95 | ) { 96 | gemmStridedBatched( 97 | transa, 98 | transb, 99 | m, 100 | n, 101 | k, 102 | (T)1.0, 103 | Aarray, 104 | stride_a, 105 | Barray, 106 | stride_b, 107 | (T)0.0, 108 | Carray, 109 | stride_c, 110 | batchCount 111 | ); 112 | } 113 | 114 | template void CublasWrapper::gemmStridedBatched( 115 | cublasOperation_t transa, cublasOperation_t transb, 116 | int m, int n, int k, 117 | const half* Aarray, long long int stride_a, 118 | const half* Barray, long long int stride_b, 119 | half* Carray, long long int stride_c, 120 | int batchCount 121 | ); 122 | template void CublasWrapper::gemmStridedBatched( 123 | cublasOperation_t transa, cublasOperation_t transb, 124 | int m, int n, int k, 125 | const float* Aarray, long long int stride_a, 126 | const float* Barray, long long int stride_b, 127 | float* Carray, long long int stride_c, 128 | int batchCount 129 | ); 130 | 131 | template 132 | void CublasWrapper::gemmBatched( 133 | cublasOperation_t transa, 134 | cublasOperation_t transb, 135 | int m, 136 | int n, 137 | int k, 138 | const T* Aarray, 139 | const T* Barray, 140 | T* Carray, 141 | int batchCount 142 | ) { 143 | gemmStridedBatched( 144 | transa, 145 | transb, 146 | m, 147 | n, 148 | k, 149 | (T)1.0, 150 | Aarray, 151 | 1LL*m*k, 152 | Barray, 153 | 1LL*n*k, 154 | (T)0.0, 155 | Carray, 156 | 1LL*m*n, 157 | batchCount 158 | ); 159 | } 160 | 161 | template void CublasWrapper::gemmBatched( 162 | cublasOperation_t transa, cublasOperation_t transb, 163 | int m, int n, int k, 164 | const half* Aarray, const half* Barray, half* Carray, 165 | int batchCount 166 | ); 167 | template void CublasWrapper::gemmBatched( 168 | cublasOperation_t transa, cublasOperation_t transb, 169 | int m, int n, int k, 170 | const float* Aarray, const float* Barray, float* Carray, 171 | int batchCount 172 | ); 173 | 174 | template 175 | void CublasWrapper::gemm( 176 | 
cublasOperation_t transa, 177 | cublasOperation_t transb, 178 | int m, 179 | int n, 180 | int k, 181 | const T* Aarray, 182 | const T* Barray, 183 | T* Carray 184 | ) { 185 | gemmBatched( 186 | transa, 187 | transb, 188 | m, 189 | n, 190 | k, 191 | Aarray, 192 | Barray, 193 | Carray, 194 | 1 195 | ); 196 | } 197 | 198 | template void CublasWrapper::gemm( 199 | cublasOperation_t transa, cublasOperation_t transb, 200 | int m, int n, int k, 201 | const half* Aarray, const half* Barray, half* Carray 202 | ); 203 | template void CublasWrapper::gemm( 204 | cublasOperation_t transa, cublasOperation_t transb, 205 | int m, int n, int k, 206 | const float* Aarray, const float* Barray, float* Carray 207 | ); 208 | 209 | inline void checkCublasStatus_line(cublasStatus_t status, const char* file, int line){ 210 | if (status != CUBLAS_STATUS_SUCCESS) { 211 | std::cerr << "cublasLt failed: " << status << " " << file << ":" << line << std::endl; 212 | throw std::runtime_error("cublasLt failed:"); 213 | abort(); 214 | } 215 | } 216 | 217 | } -------------------------------------------------------------------------------- /src/csrc/util/cublas_wrapper.h: -------------------------------------------------------------------------------- 1 | /* 2 | cublas_wrapper.h - cublas GEMM wrapper 3 | 4 | This file contains CublasWrapper, a wrapper for cuBLAS functions (mainly GEMM). 5 | 6 | We wrap cublas mainly for the following reasons: 7 | - C++ stores matrixes in row-major order, while cuBLAS accepts matrix in column-major order, 8 | which is confusing and error-prone. So in this wrapper, every input matrix is ROW-MAJOR ORDER. 9 | - StridedBatchedGEMM in cuBLAS contains some seldom-used parameters, which makes the interface 10 | of cuBLAS very complicated. So we wrap it to make it easier to use. 11 | - cuBLAS supports many algorithms for GEMM, and some of them are faster than the default one. 12 | Currently it is using the default algo (CUBLAS_GEMM_DEFAULT) to do GEMM, in the 13 | future we may run a small benchmark ahead of time, and then pick the fastest algo. 14 | */ 15 | 16 | #pragma once 17 | 18 | #include 19 | #include 20 | 21 | #include 22 | 23 | #define checkCublasStatus(status) checkCublasStatus_line((status), __FILE__, __LINE__) 24 | 25 | namespace st::util { 26 | 27 | class CublasWrapper { 28 | private: 29 | // Need to use a shared_ptr here, or when passing CublasWrapper to a function 30 | // and the function returns, the destructor of CublasWrapper will be called, 31 | // which destorys the handle, and then the handle in the function will be invalid. 32 | std::shared_ptr handle_; 33 | cublasGemmAlgo_t algo_; 34 | 35 | public: 36 | CublasWrapper(): 37 | handle_(std::make_shared()) { 38 | cublasCreate(handle_.get()); 39 | algo_ = CUBLAS_GEMM_DEFAULT; 40 | } 41 | 42 | ~CublasWrapper() { 43 | if (handle_.use_count() == 1) { 44 | // I am the last one who uses the handle, so I should destroy it 45 | cublasDestroy(*handle_.get()); 46 | } 47 | } 48 | 49 | /* 50 | gemmStridedBatched - Calculate C = A @ B 51 | 52 | PLEASE KEEP IN MIND THAT A, B AND C SHOULD BE STORED IN ROW MAJOR! 53 | 54 | The size of A is (m, k) (or (k, m) if transa is CUBLAS_OP_T) 55 | The size of B is (k, n) (or (n, k) if transb is CUBLAS_OP_T) 56 | The size of C is (m, n) 57 | 58 | stride: How much elements between two adjacent matrices in A, B and C 59 | In detail, we consider Aarray[0:m*k] as the first matrix in A, 60 | Aarray[stride_a+m*k : stride_a+2*m*k] as the second matrix in A, and so on. 61 | The same for stride_b and stride_c. 
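		Implementation note: the row-major convention is handled internally by calling
		cublasGemmStridedBatchedEx with A and B swapped and m/n exchanged, which computes
		C^T = B^T @ A^T in cuBLAS's column-major view.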
62 | 63 | batchCount: How many matrices in A, B and C 64 | 65 | alpha and beta: The same as cuBLAS. C = alpha*(A@B) + beta*C 66 | */ 67 | template 68 | void gemmStridedBatched( 69 | cublasOperation_t transa, 70 | cublasOperation_t transb, 71 | int m, 72 | int n, 73 | int k, 74 | const T alpha, 75 | const T* Aarray, 76 | long long int stride_a, 77 | const T* Barray, 78 | long long int stride_b, 79 | const T beta, 80 | T* Carray, 81 | long long int stride_c, 82 | int batchCount 83 | ); 84 | 85 | /* 86 | gemmStridedBatched - A simplified version of gemmStridedBatched 87 | It omits alpha (set to 1) and beta (set to 0). 88 | */ 89 | template 90 | void gemmStridedBatched( 91 | cublasOperation_t transa, 92 | cublasOperation_t transb, 93 | int m, 94 | int n, 95 | int k, 96 | const T* Aarray, 97 | long long int stride_a, 98 | const T* Barray, 99 | long long int stride_b, 100 | T* Carray, 101 | long long int stride_c, 102 | int batchCount 103 | ); 104 | 105 | /* 106 | gemmBatched - A simplified version of gemmStridedBatched 107 | It omits stride_a, stride_b and stride_c by assuming that 108 | A, B and C are stored continuously in memory. 109 | */ 110 | template 111 | void gemmBatched( 112 | cublasOperation_t transa, 113 | cublasOperation_t transb, 114 | int m, 115 | int n, 116 | int k, 117 | const T* Aarray, 118 | const T* Barray, 119 | T* Carray, 120 | int batchCount 121 | ); 122 | 123 | /* 124 | gemm - Calculate the product of two matrixes 125 | Its function is exactly gemmBatched + batchCount=1 126 | */ 127 | template 128 | void gemm( 129 | cublasOperation_t transa, 130 | cublasOperation_t transb, 131 | int m, 132 | int n, 133 | int k, 134 | const T* Aarray, 135 | const T* Barray, 136 | T* Carray 137 | ); 138 | }; 139 | 140 | template 141 | cublasComputeType_t getCublasComputeType() { 142 | if (std::is_same::value) { 143 | return CUBLAS_COMPUTE_16F; 144 | } 145 | if (std::is_same::value) { 146 | return CUBLAS_COMPUTE_32F; // TODO(sunyh): Maybe try CUBLAS_COMPUTE_32F_FAST_16F? 
147 | } 148 | if (std::is_same::value) { 149 | return CUBLAS_COMPUTE_64F; 150 | } 151 | if (std::is_same::value) { 152 | return CUBLAS_COMPUTE_32I; 153 | } 154 | throw std::runtime_error("Cublas compute type: Unsupported type"); 155 | } 156 | 157 | } // namespace st::kernel -------------------------------------------------------------------------------- /src/csrc/util/cuda_utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | #include 10 | 11 | #define CUDA_CHECK(cmd) do { \ 12 | cudaError_t result = cmd; \ 13 | if (result != cudaSuccess) { \ 14 | printf("[ERROR] CUDA error %s:%d '%s': (%d) %s\n", __FILE__, __LINE__, #cmd, (int)result, cudaGetErrorString(result)); \ 15 | exit(-1); \ 16 | } \ 17 | } while(0) 18 | 19 | inline void syncAndCheck(const char* const file, int const line, bool force_check = false) { 20 | #ifdef DEBUG 21 | force_check = true; 22 | #endif 23 | if (force_check) { 24 | cudaDeviceSynchronize(); 25 | cudaError_t result = cudaGetLastError(); 26 | if (result) { 27 | throw std::runtime_error(std::string("[ST] CUDA runtime error: ") + cudaGetErrorString(result) + " " 28 | + file + ":" + std::to_string(line) + " \n"); 29 | } 30 | } 31 | } 32 | 33 | #define sync_check_cuda_error() syncAndCheck(__FILE__, __LINE__, false) 34 | #define sync_check_cuda_error_force() syncAndCheck(__FILE__, __LINE__, true) 35 | 36 | // Some stuff for indexing into an 1-D array 37 | #define INDEX_2D(dim1, dim2, index1, index2) \ 38 | (((int64_t)index1) * (dim2) + (index2)) 39 | #define INDEX_3D(dim1, dim2, dim3, index1, index2, index3) \ 40 | (((int64_t)index1) * (dim2) * (dim3) + ((int64_t)index2) * (dim3) + (index3)) 41 | #define INDEX_4D(dim1, dim2, dim3, dim4, index1, index2, index3, index4) \ 42 | (((int64_t)index1) * (dim2) * (dim3) * (dim4) + ((int64_t)index2) * (dim3) * (dim4) + ((int64_t)index3) * (dim4) + (index4)) 43 | #define INDEX_5D(dim1, dim2, dim3, dim4, dim5, index1, index2, index3, index4, index5) \ 44 | (((int64_t)index1) * (dim2) * (dim3) * (dim4) * (dim5) + ((int64_t)index2) * (dim3) * (dim4) * (dim5) + ((int64_t)index3) * (dim4) * (dim5) + (index4) * (dim5) + (index5)) 45 | 46 | // A tiny stuff that supports remalloc on GPU 47 | template 48 | struct RemallocableArray { 49 | T* ptr; 50 | int64_t size; 51 | 52 | RemallocableArray() { 53 | ptr = nullptr; 54 | size = 0; 55 | } 56 | 57 | ~RemallocableArray() { 58 | if (ptr != nullptr) { 59 | CUDA_CHECK(cudaFree(ptr)); 60 | } 61 | } 62 | 63 | void remalloc(int64_t target_size) { 64 | if (target_size > size) { 65 | int64_t new_size = size ? 
size*2 : 64; 66 | while (new_size < target_size) { 67 | new_size *= 2; 68 | } 69 | if (ptr != nullptr) { 70 | CUDA_CHECK(cudaFree(ptr)); 71 | } 72 | CUDA_CHECK(cudaMalloc(&ptr, new_size * sizeof(T))); 73 | size = new_size; 74 | } 75 | } 76 | }; 77 | 78 | template 79 | inline void printGpuArrayHelper(const T* array, int64_t size, const char* arr_name) { 80 | T* array_cpu = new T[size]; 81 | CUDA_CHECK(cudaMemcpy(array_cpu, array, sizeof(T) * size, cudaMemcpyDeviceToHost)); 82 | for (int64_t i = 0; i < size; i++) { 83 | printf("%f ", (float)array_cpu[i]); 84 | } 85 | printf("\n"); 86 | delete[] array_cpu; 87 | } 88 | 89 | #define printGpuArray(array, size) printGpuArrayHelper(array, size, #array) 90 | 91 | // A util to check cuda memory usage 92 | inline int64_t cuda_memory_size() { 93 | size_t free_byte; 94 | size_t total_byte; 95 | cudaMemGetInfo(&free_byte, &total_byte); 96 | return total_byte - free_byte; 97 | } 98 | 99 | // CUDAFreeAtReturn - A tiny macro to call cudaFree when the point goes out of scope 100 | template 101 | class CUDAFreeAtReturnHelper { 102 | private: 103 | PTR_T ptr; 104 | std::string pointer_name; 105 | public: 106 | CUDAFreeAtReturnHelper(PTR_T ptr, std::string pointer_name): 107 | pointer_name(pointer_name) { this->ptr = ptr; } 108 | ~CUDAFreeAtReturnHelper() { 109 | if (ptr != nullptr) { 110 | cudaFree(ptr); 111 | cudaDeviceSynchronize(); 112 | cudaError_t result = cudaGetLastError(); 113 | if (result) { 114 | fprintf(stderr, "Error occured when freeing pointer %s\n", pointer_name.c_str()); 115 | fprintf(stderr, "%s\n", (std::string("[ST] CUDA runtime error: ") + cudaGetErrorString(result) + " " 116 | + __FILE__ + ":" + std::to_string(__LINE__) + " \n").c_str()); 117 | exit(1); 118 | } 119 | } 120 | } 121 | }; 122 | #define CUDA_FREE_AT_RETURN(ptr) CUDAFreeAtReturnHelper ptr##_cuda_free_at_return(ptr, #ptr) 123 | 124 | template 125 | cudaDataType_t getCudaDataType() { 126 | if (std::is_same::value) { 127 | return CUDA_R_16F; 128 | } 129 | #ifdef ENABLE_BF16 130 | else if (std::is_same::value) { 131 | return CUDA_R_16BF; 132 | } 133 | #endif 134 | else if (std::is_same::value) { 135 | return CUDA_R_32F; 136 | } 137 | else { 138 | throw std::runtime_error("Cuda data type: Unsupported type"); 139 | } 140 | } 141 | -------------------------------------------------------------------------------- /src/csrc/util/debug_utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | /* 6 | assert_whenever: assertion which ignore whether NDEBUG is set 7 | 8 | In C++, assert() is evaluated only when NDEBUG is not set. This is 9 | inconvenient when we want to check the assertion even in release mode. 10 | This macro is a workaround for this problem. 11 | */ 12 | 13 | extern "C" { 14 | // Copied from assert.h 15 | extern void __assert_fail (const char *__assertion, const char *__file, 16 | unsigned int __line, const char *__function) 17 | __THROW __attribute__ ((__noreturn__)); 18 | 19 | #define __ST_ASSERT_FUNCTION __extension__ __PRETTY_FUNCTION__ 20 | # define assert_whenever(expr) \ 21 | (static_cast (expr) \ 22 | ? 
void (0) \ 23 | : __assert_fail (#expr, __FILE__, __LINE__, __ST_ASSERT_FUNCTION)) 24 | } 25 | -------------------------------------------------------------------------------- /src/csrc/util/nccl_utils.cc: -------------------------------------------------------------------------------- 1 | #include "nccl_utils.h" 2 | #include 3 | 4 | #define NCCL_CHECK(cmd) \ 5 | do { \ 6 | ncclResult_t result = cmd; \ 7 | if (result != ncclSuccess) { \ 8 | printf("[ERROR] NCCL error %s:%d '%s' : %s\n", __FILE__, __LINE__, #cmd, ncclGetErrorString(result)); \ 9 | exit(-1); \ 10 | } \ 11 | } while (0) 12 | 13 | namespace st::util { 14 | 15 | void stNcclErrorCheck(ncclResult_t result, const char* func, const char* file, int line) 16 | { 17 | if (result != ncclSuccess) { 18 | printf("[ERROR] NCCL error %s:%d '%s' : %s\n", file, line, func, ncclGetErrorString(result)); 19 | exit(-1); 20 | } 21 | } 22 | 23 | void stNcclGetUniqueId(ncclUniqueId& nccl_id) 24 | { 25 | NCCL_CHECK(ncclGetUniqueId(&nccl_id)); 26 | } 27 | 28 | NcclComm stNcclInit(int64_t world_size, int64_t rank, const ncclUniqueId& nccl_id, cudaStream_t stream, bool real_init) 29 | { 30 | NcclComm nccl_comm; 31 | nccl_comm.rank = rank; 32 | nccl_comm.size = world_size; 33 | nccl_comm.stream = stream; 34 | if (world_size == 1 || !real_init) { 35 | nccl_comm.comm = nullptr; 36 | return nccl_comm; 37 | } 38 | NCCL_CHECK(ncclCommInitRank(&nccl_comm.comm, nccl_comm.size, nccl_id, nccl_comm.rank)); 39 | return nccl_comm; 40 | } 41 | 42 | void stNcclDestroy(NcclComm& nccl_comm) 43 | { 44 | NCCL_CHECK(ncclCommDestroy(nccl_comm.comm)); 45 | } 46 | 47 | void stNcclAllReduce( 48 | void* sendbuff, 49 | void* recvbuff, 50 | int64_t count, 51 | ncclDataType_t datatype, 52 | ncclRedOp_t op, 53 | ncclComm_t comm, 54 | cudaStream_t stream) 55 | { 56 | if (comm == nullptr) { 57 | return; 58 | } 59 | NCCL_CHECK(ncclAllReduce(sendbuff, recvbuff, count, datatype, op, comm, stream)); 60 | } 61 | 62 | void stNcclSend( 63 | void* buff, 64 | int64_t count, 65 | ncclDataType_t datatype, 66 | int64_t send_to, 67 | NcclComm comm, 68 | cudaStream_t stream) 69 | { 70 | if (comm.comm == nullptr) { 71 | printf("[ERROR] NCCL comm is null\n"); 72 | return; 73 | } 74 | 75 | if (send_to == comm.rank) { 76 | printf("[ERROR] Send rank and recv rank are the same\n"); 77 | exit(-1); 78 | } 79 | 80 | NCCL_CHECK(ncclSend(buff, count, datatype, send_to, comm.comm, stream)); 81 | } 82 | 83 | void stNcclRecv( 84 | void* buff, 85 | int64_t count, 86 | ncclDataType_t datatype, 87 | int64_t recv_from, 88 | NcclComm comm, 89 | cudaStream_t stream) 90 | { 91 | if (comm.comm == nullptr) { 92 | printf("[ERROR] NCCL comm is null\n"); 93 | return; 94 | } 95 | 96 | if (recv_from == comm.rank) { 97 | printf("[ERROR] Send rank and recv rank are the same\n"); 98 | exit(-1); 99 | } 100 | 101 | NCCL_CHECK(ncclRecv(buff, count, datatype, recv_from, comm.comm, stream)); 102 | } 103 | 104 | void stNcclSendRecv( 105 | void* sendbuff, 106 | void* recvbuff, 107 | int64_t count, 108 | ncclDataType_t datatype, 109 | int64_t send_rank, 110 | int64_t recv_rank, 111 | NcclComm comm, 112 | cudaStream_t stream) 113 | { 114 | if (comm.comm == nullptr) { 115 | printf("[ERROR] NCCL comm is null\n"); 116 | return; 117 | } 118 | 119 | if (send_rank == recv_rank) { 120 | printf("[ERROR] Send rank and recv rank are the same\n"); 121 | exit(-1); 122 | } 123 | 124 | if (send_rank == comm.rank) { 125 | NCCL_CHECK(ncclSend(sendbuff, count, datatype, recv_rank, comm.comm, stream)); 126 | } else if (recv_rank == comm.rank) { 127 | 
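		// This rank is the receiving side: post the matching ncclRecv from the sender.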
NCCL_CHECK(ncclRecv(recvbuff, count, datatype, send_rank, comm.comm, stream)); 128 | } else { 129 | printf("[ERROR] Rank %ld is not involved in the send/recv\n", comm.rank); 130 | exit(-1); 131 | } 132 | } 133 | void stNcclBcast( 134 | void* buff, 135 | int64_t count, 136 | ncclDataType_t datatype, 137 | int64_t root, 138 | NcclComm comm, 139 | cudaStream_t stream) 140 | { 141 | if (comm.comm == nullptr) { 142 | printf("[ERROR] NCCL comm is null\n"); 143 | return; 144 | } 145 | NCCL_CHECK(ncclBcast(buff, count, datatype, root, comm.comm, stream)); 146 | } 147 | 148 | } -------------------------------------------------------------------------------- /src/csrc/util/nccl_utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | namespace st::util { 9 | 10 | struct NcclComm { 11 | ncclComm_t comm; 12 | int64_t rank; 13 | int64_t size; 14 | cudaStream_t stream; 15 | 16 | NcclComm() { 17 | comm = nullptr; 18 | rank = 0; 19 | size = 1; 20 | stream = 0; 21 | } 22 | }; 23 | 24 | void stNcclGetUniqueId(ncclUniqueId &nccl_id); 25 | 26 | NcclComm stNcclInit(int64_t world_size, int64_t rank, const ncclUniqueId &nccl_id, cudaStream_t stream = 0, bool real_init = true); 27 | void stNcclDestroy(NcclComm &nccl_comm); 28 | void stNcclAllReduce( 29 | void* sendbuff, 30 | void* recvbuff, 31 | int64_t count, 32 | ncclDataType_t datatype, 33 | ncclRedOp_t op, 34 | ncclComm_t comm, 35 | cudaStream_t stream = 0 36 | ); 37 | 38 | void stNcclSend( 39 | void* buff, 40 | int64_t count, 41 | ncclDataType_t datatype, 42 | int64_t send_to, 43 | NcclComm comm, 44 | cudaStream_t stream = 0 45 | ); 46 | 47 | void stNcclRecv( 48 | void* buff, 49 | int64_t count, 50 | ncclDataType_t datatype, 51 | int64_t recv_from, 52 | NcclComm comm, 53 | cudaStream_t stream = 0 54 | ); 55 | 56 | void stNcclSendRecv( 57 | void* sendbuff, 58 | void* recvbuff, 59 | int64_t count, 60 | ncclDataType_t datatype, 61 | int64_t send_rank, 62 | int64_t recv_rank, 63 | NcclComm comm, 64 | cudaStream_t stream = 0 65 | ); 66 | 67 | void stNcclBcast( 68 | void* buff, 69 | int64_t count, 70 | ncclDataType_t datatype, 71 | int64_t root, 72 | NcclComm comm, 73 | cudaStream_t stream = 0 74 | ); 75 | 76 | template 77 | ncclDataType_t stNcclGetDataType() 78 | { 79 | ncclDataType_t nccl_data_type; 80 | if (std::is_same::value) { 81 | nccl_data_type = ncclFloat32; 82 | } 83 | else if (std::is_same::value) { 84 | nccl_data_type = ncclHalf; 85 | } 86 | else if (std::is_same::value) { 87 | nccl_data_type = ncclInt; 88 | } 89 | else if (std::is_same::value) { 90 | nccl_data_type = ncclChar; 91 | } 92 | else if (std::is_same::value) { 93 | nccl_data_type = ncclInt8; 94 | } 95 | else { 96 | printf("[ERROR] NCCL only support float, half, int, char, and bool. 
\n"); 97 | exit(-1); 98 | } 99 | return nccl_data_type; 100 | } 101 | 102 | } // namespace st::util -------------------------------------------------------------------------------- /src/csrc/util/py_block_migration.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include 8 | 9 | #include 10 | 11 | namespace st::util { 12 | 13 | std::vector get_ipc_mem_handle(torch::Tensor tensor); 14 | 15 | bool register_ipc_mem_handle( 16 | std::vector k_cache_handle_vec, 17 | std::vector v_cache_handle_vec, 18 | int64_t num_layers, 19 | int64_t num_heads, 20 | const std::vector &context_parallel_config, 21 | const std::vector &decoding_parallel_config 22 | ); 23 | 24 | void migrate_blocks( 25 | const int64_t context_pp_size, 26 | const int64_t context_tp_size, 27 | 28 | const std::vector &context_block_indexes, 29 | 30 | const int64_t decoding_pp_size, 31 | const int64_t decoding_tp_size, 32 | 33 | const int64_t decoding_pp_rank, 34 | const int64_t decoding_tp_rank, 35 | 36 | const std::vector &decoding_block_indexes, 37 | 38 | torch::Tensor decoding_worker_k_cache, 39 | torch::Tensor decoding_worker_v_cache 40 | ); 41 | 42 | } // namespace st::util 43 | -------------------------------------------------------------------------------- /src/csrc/util/py_nccl.cc: -------------------------------------------------------------------------------- 1 | #include "py_nccl.h" 2 | 3 | namespace st::util { 4 | 5 | std::vector generate_nccl_id() { 6 | ncclUniqueId nccl_id; 7 | ncclGetUniqueId(&nccl_id); 8 | std::vector ret; 9 | ret.resize(NCCL_UNIQUE_ID_BYTES / sizeof(int64_t)); 10 | memcpy(ret.data(), nccl_id.internal, NCCL_UNIQUE_ID_BYTES); 11 | return ret; 12 | } 13 | 14 | } // namespace st::util -------------------------------------------------------------------------------- /src/csrc/util/py_nccl.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace st::util { 8 | 9 | // torch function to generate nccl_id 10 | std::vector generate_nccl_id(); 11 | 12 | } // namespace st::util -------------------------------------------------------------------------------- /src/csrc/util/py_swapping.cc: -------------------------------------------------------------------------------- 1 | #include "py_swapping.h" 2 | 3 | #include "torch_utils.h" 4 | #include 5 | #include // for at::cuda::getCurrentCUDAStream() 6 | 7 | namespace st::util { 8 | 9 | // swap - Perform swapping between GPU blocks and CPU blocks 10 | // The source_block_ids and target_block_ids are the block ids of the blocks to be swapped. 11 | // source_block_ids[0] will be copied to target_block_ids[0] and so on 12 | // `is_swap_in` defines whether the swap is a swap-in or swap-out (swap-in means 13 | // to swap from CPU to GPU, swap-out means to swap from GPU to CPU) 14 | // 15 | // Here we do not pass a cudaStream to the function. Instead we use the current 16 | // stream indicated by at::cuda::getCurrentCUDAStream(). So it is python's 17 | // responsibility to set the current stream before calling this function. 18 | // 19 | // Future work: Now the number of cudaMemcpyAsync calls is equal to 2x the number 20 | // of blocks to swap. 
We can reduce the number of cudaMemcpyAsync calls by 21 | // grouping nearby blocks together and perform a single invocation 22 | void swap( 23 | const std::vector &source_block_ids, 24 | const std::vector &target_block_ids, 25 | const bool is_swap_in, 26 | 27 | torch::Tensor k_cache, 28 | torch::Tensor v_cache, 29 | torch::Tensor k_swap, 30 | torch::Tensor v_swap 31 | ) { 32 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 33 | size_t block_size_in_bytes = getTensorSizeInBytes(k_cache) / k_cache.size(0); 34 | int num_blocks_to_swap = source_block_ids.size(); 35 | for (int i = 0; i < num_blocks_to_swap; i++) { 36 | int64_t source_block_id = source_block_ids[i]; 37 | int64_t target_block_id = target_block_ids[i]; 38 | 39 | if (is_swap_in) { 40 | // Copy from CPU to GPU 41 | cudaMemcpyAsync( 42 | (char*)k_cache.data_ptr() + target_block_id * block_size_in_bytes, 43 | (char*)k_swap.data_ptr() + source_block_id * block_size_in_bytes, 44 | block_size_in_bytes, 45 | cudaMemcpyHostToDevice, 46 | stream 47 | ); 48 | cudaMemcpyAsync( 49 | (char*)v_cache.data_ptr() + target_block_id * block_size_in_bytes, 50 | (char*)v_swap.data_ptr() + source_block_id * block_size_in_bytes, 51 | block_size_in_bytes, 52 | cudaMemcpyHostToDevice, 53 | stream 54 | ); 55 | } else { 56 | // Copy from GPU to CPU 57 | cudaMemcpyAsync( 58 | (char*)k_swap.data_ptr() + target_block_id * block_size_in_bytes, 59 | (char*)k_cache.data_ptr() + source_block_id * block_size_in_bytes, 60 | block_size_in_bytes, 61 | cudaMemcpyDeviceToHost, 62 | stream 63 | ); 64 | cudaMemcpyAsync( 65 | (char*)v_swap.data_ptr() + target_block_id * block_size_in_bytes, 66 | (char*)v_cache.data_ptr() + source_block_id * block_size_in_bytes, 67 | block_size_in_bytes, 68 | cudaMemcpyDeviceToHost, 69 | stream 70 | ); 71 | } 72 | } 73 | } 74 | 75 | } // namespace st::util -------------------------------------------------------------------------------- /src/csrc/util/py_swapping.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | 8 | namespace st::util { 9 | 10 | void swap( 11 | const std::vector &source_block_ids, 12 | const std::vector &target_block_ids, 13 | const bool is_swap_in, 14 | 15 | torch::Tensor k_cache, 16 | torch::Tensor v_cache, 17 | torch::Tensor k_swap, 18 | torch::Tensor v_swap 19 | ); 20 | 21 | } // namespace st::util 22 | -------------------------------------------------------------------------------- /src/csrc/util/st_datatypes.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace st::util { 8 | // Only type with fixed size is supported 9 | class StDataType { 10 | public: 11 | enum Type { 12 | INT8, 13 | INT32, 14 | INT64, 15 | FLOAT16, 16 | FLOAT32, 17 | FLOAT64, 18 | BOOL, 19 | }; 20 | StDataType() = default; 21 | constexpr StDataType(Type type) 22 | : type(type) 23 | { 24 | } 25 | constexpr int64_t get_size() const 26 | { 27 | switch (type) { 28 | case INT8: 29 | return 1; 30 | case INT32: 31 | return 4; 32 | case INT64: 33 | return 8; 34 | case FLOAT16: 35 | return 2; 36 | case FLOAT32: 37 | return 4; 38 | case FLOAT64: 39 | return 8; 40 | case BOOL: 41 | return 1; 42 | } 43 | printf("[ERROR] Unsupported data type\n"); 44 | exit(-1); 45 | } 46 | constexpr Type get_type() const { return type; } 47 | 48 | // TODO(sunyh): spaceship operator after C++20 49 | constexpr bool operator==(const StDataType& other) const { return 
type == other.type; } 50 | constexpr bool operator!=(const StDataType& other) const { return type != other.type; } 51 | 52 | constexpr ncclDataType_t get_nccl_type() const 53 | { 54 | switch (type) { 55 | case StDataType::INT8: 56 | return ncclInt8; 57 | case StDataType::INT32: 58 | return ncclInt32; 59 | case StDataType::INT64: 60 | return ncclInt64; 61 | case StDataType::FLOAT16: 62 | return ncclFloat16; 63 | case StDataType::FLOAT32: 64 | return ncclFloat32; 65 | case StDataType::FLOAT64: 66 | return ncclFloat64; 67 | case StDataType::BOOL: 68 | return ncclInt8; 69 | default: 70 | printf("[ERROR] Unsupported data type\n"); 71 | exit(-1); 72 | } 73 | } 74 | 75 | constexpr MPI_Datatype get_mpi_type() const 76 | { 77 | switch (type) { 78 | case INT8: 79 | return MPI_INT8_T; 80 | case INT32: 81 | return MPI_INT32_T; 82 | case INT64: 83 | return MPI_INT64_T; 84 | case FLOAT16: 85 | return MPI_INT16_T; // MPI_FLOAT16_T is not supported, use MPI_INT16_T instead 86 | case FLOAT32: 87 | return MPI_FLOAT; 88 | case FLOAT64: 89 | return MPI_DOUBLE; 90 | case BOOL: 91 | return MPI_C_BOOL; 92 | } 93 | printf("[ERROR] Unsupported data type\n"); 94 | exit(-1); 95 | } 96 | 97 | private: 98 | Type type; 99 | }; 100 | 101 | template 102 | StDataType stGetDataType(){ 103 | if (std::is_same::value) { 104 | return StDataType(StDataType::INT8); 105 | } else if (std::is_same::value) { 106 | return StDataType(StDataType::INT32); 107 | } else if (std::is_same::value) { 108 | return StDataType(StDataType::INT64); 109 | } else if (std::is_same::value) { 110 | return StDataType(StDataType::FLOAT16); 111 | } else if (std::is_same::value) { 112 | return StDataType(StDataType::FLOAT32); 113 | } else if (std::is_same::value) { 114 | return StDataType(StDataType::FLOAT64); 115 | } else if (std::is_same::value) { 116 | return StDataType(StDataType::BOOL); 117 | } else { 118 | printf("[ERROR] Unsupported data type\n"); 119 | exit(-1); 120 | } 121 | } 122 | 123 | } -------------------------------------------------------------------------------- /src/csrc/util/torch_utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | namespace st::util { 7 | 8 | template 9 | inline torch::ScalarType getTorchScalarType() { 10 | if (std::is_same::value) { 11 | return torch::kFloat; 12 | } else if (std::is_same::value) { 13 | return torch::kHalf; 14 | } else { 15 | throw std::runtime_error("Unsupported type"); 16 | } 17 | } 18 | 19 | inline void* convertTensorToRawPtr(torch::Tensor& tensor) { 20 | if (tensor.scalar_type() == torch::kFloat) { 21 | return tensor.data_ptr(); 22 | } else if (tensor.scalar_type() == torch::kHalf) { 23 | return tensor.data_ptr(); 24 | } else { 25 | throw std::runtime_error("Unsupported type"); 26 | } 27 | } 28 | 29 | inline size_t getTensorSizeInBytes(torch::Tensor tensor) 30 | { 31 | return tensor.numel() * torch::elementSize(torch::typeMetaToScalarType(tensor.dtype())); 32 | } 33 | 34 | } // namespace st::util 35 | -------------------------------------------------------------------------------- /src/examples/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_library(st_args STATIC lib/st_args.cc) 2 | target_link_libraries(st_args PRIVATE) 3 | 4 | add_executable(run_gpt run_gpt.cc lib/inference_batch.cc) 5 | target_link_libraries(run_gpt model_gpt nlohmann_json::nlohmann_json argparse st_args) 6 | 7 | add_executable(benchmark_all_input_same benchmark_all_input_same.cc 
lib/inference_batch.cc) 8 | target_link_libraries(benchmark_all_input_same model_gpt nlohmann_json::nlohmann_json argparse) 9 | -------------------------------------------------------------------------------- /src/examples/lib/common_gpt_hyper_params.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #include "model/gpt/gpt_hyper_param.h" 7 | #include "kernel/activation_types.h" 8 | 9 | // https://huggingface.co/facebook/opt-125m/blob/main/config.json 10 | const st::model::GptHyperParam HYPERPARAM_OPT_125M = st::model::GptHyperParam::GetOptHyperParam( // opt-125m 11 | 50272, 12 | 2048, 13 | 768, 14 | 12, 15 | 12, 16 | 64, 17 | 3072 18 | ); 19 | 20 | const st::model::GptHyperParam HYPERPARAM_OPT_1P3B = st::model::GptHyperParam::GetOptHyperParam( // opt-1.3b 21 | 50272, 22 | 2048, 23 | 2048, 24 | 24, 25 | 32, 26 | 64, 27 | 8192 28 | ); 29 | 30 | const st::model::GptHyperParam HYPERPARAM_OPT_2P7B = st::model::GptHyperParam::GetOptHyperParam( // opt-2.7b 31 | 50272, 32 | 2048, 33 | 2560, 34 | 32, 35 | 32, 36 | 80, 37 | 10240 38 | ); 39 | 40 | const st::model::GptHyperParam HYPERPARAM_OPT_6P7B = st::model::GptHyperParam::GetOptHyperParam( // opt-6.7b 41 | 50272, 42 | 2048, 43 | 4096, 44 | 32, 45 | 32, 46 | 128, 47 | 16384 48 | ); 49 | 50 | const st::model::GptHyperParam HYPERPARAM_OPT_13B = st::model::GptHyperParam::GetOptHyperParam( // opt-13b 51 | 50272, 52 | 2048, 53 | 5120, 54 | 40, 55 | 40, 56 | 128, 57 | 20480 58 | ); 59 | 60 | const st::model::GptHyperParam HYPERPARAM_OPT_30B = st::model::GptHyperParam::GetOptHyperParam( // opt-30b 61 | 50272, 62 | 2048, 63 | 7168, 64 | 48, 65 | 56, 66 | 128, 67 | 28672 68 | ); 69 | 70 | const st::model::GptHyperParam HYPERPARAM_LLAMA2_7B = st::model::GptHyperParam::GetLlama2HyperParam( // llama2-7b 71 | 32000, 72 | 4096, 73 | 4096, 74 | 32, 75 | 32, 76 | 32, 77 | 128, 78 | 11008 79 | ); 80 | 81 | const st::model::GptHyperParam HYPERPARAM_LLAMA2_13B = st::model::GptHyperParam::GetLlama2HyperParam( // llama2-13b 82 | 32000, 83 | 4096, 84 | 5120, 85 | 40, 86 | 40, 87 | 40, 88 | 128, 89 | 13824 90 | ); 91 | 92 | const st::model::GptHyperParam HYPERPARAM_LLAMA2_70B = st::model::GptHyperParam::GetLlama2HyperParam( // llama2-70b 93 | 32000, 94 | 4096, 95 | 8192, 96 | 80, 97 | 64, 98 | 8, 99 | 128, 100 | 28672 101 | ); 102 | 103 | // str2hyperparam - Return the correct hyperparam based on the string. 104 | // If the string is invalid, print the valid hyperparam and return a hyperparam with vocab_size = -1. 
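// Example: str2hyperparam("llama2_13b") returns HYPERPARAM_LLAMA2_13B, while an
// unrecognized name (say, "gpt3_175b") prints the list of valid names.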
105 | inline st::model::GptHyperParam str2hyperparam(const std::string &str) { 106 | static const std::unordered_map hyper_param_map = { 107 | {"opt_125m", HYPERPARAM_OPT_125M}, 108 | {"opt_1.3b", HYPERPARAM_OPT_1P3B}, 109 | {"opt_2.7b", HYPERPARAM_OPT_2P7B}, 110 | {"opt_6.7b", HYPERPARAM_OPT_6P7B}, 111 | {"opt_13b", HYPERPARAM_OPT_13B}, 112 | {"opt_30b", HYPERPARAM_OPT_30B}, 113 | {"llama2_7b", HYPERPARAM_LLAMA2_7B}, 114 | {"llama2_13b", HYPERPARAM_LLAMA2_13B}, 115 | {"llama2_70b", HYPERPARAM_LLAMA2_70B} 116 | }; 117 | 118 | if (hyper_param_map.find(str) == hyper_param_map.end()) { 119 | printf("Invalid number of parameters: %s\n", str.c_str()); 120 | printf("Valid number of parameters: "); 121 | for (auto it = hyper_param_map.begin(); it != hyper_param_map.end(); ++it) { 122 | printf("%s ", it->first.c_str()); 123 | } 124 | exit(1); 125 | st::model::GptHyperParam res = HYPERPARAM_OPT_125M; 126 | res.vocab_size = -1; 127 | return res; 128 | } 129 | 130 | return hyper_param_map.at(str); 131 | } 132 | -------------------------------------------------------------------------------- /src/examples/lib/inference_batch.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #include "model/gpt/opt/optop.h" 7 | #include "simple_vocab_decoder.h" 8 | 9 | struct RuntimeUsage { 10 | double context_stage_time; 11 | double decoding_stage_time; 12 | double total_time; 13 | }; 14 | 15 | template 16 | RuntimeUsage run_batched_inference( 17 | std::vector> &output_tokens_batched, 18 | const std::vector> &input_tokens_batched, 19 | st::model::Gpt &gpt, 20 | const st::model::GptHyperParam &hyper_param, 21 | const st::model::GptPagedAttnParam &pagedattn_param, 22 | const st::model::GptParallelismParam ¶llel_param, 23 | const int64_t max_decoding_step, 24 | const int64_t end_token, 25 | const int64_t num_total_blocks, 26 | bool print_debug_info, 27 | bool exit_after_context_stage = false, 28 | std::optional vocab_decoder = std::nullopt 29 | ); 30 | -------------------------------------------------------------------------------- /src/examples/lib/simple_vocab_decoder.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | using json = nlohmann::json; 10 | 11 | // SimpleVocabDecoder - A tiny (inaccurate) vocab decoder 12 | // It just maps token ids to tokens according to the vocab json file 13 | // Mainly for debugging 14 | class SimpleVocabDecoder { 15 | private: 16 | std::unordered_map vocab_map; 17 | public: 18 | SimpleVocabDecoder(const std::string &vocab_json_path) { 19 | std::ifstream vocab_json_file(vocab_json_path); 20 | if (!vocab_json_file.is_open()) { 21 | printf("Failed to open vocab json file: %s\n", vocab_json_path.c_str()); 22 | exit(1); 23 | } 24 | json vocab_json; 25 | vocab_json_file >> vocab_json; 26 | for (auto it = vocab_json.begin(); it != vocab_json.end(); it++) { 27 | vocab_map[it.value().get()] = it.key(); 28 | } 29 | } 30 | 31 | inline std::string decode(int64_t token) const { 32 | if (vocab_map.find(token) == vocab_map.end()) { 33 | return ""; 34 | } 35 | std::string result = vocab_map.at(token); 36 | if ((int)result[0] == -60 && (int)result[1] == -96) { // Ġ 37 | result = " " + result.substr(2); 38 | } 39 | return result; 40 | } 41 | }; 42 | -------------------------------------------------------------------------------- /src/examples/lib/st_args.cc: 
-------------------------------------------------------------------------------- 1 | #include "st_args.h" 2 | 3 | #include 4 | 5 | #include "common_gpt_hyper_params.h" 6 | 7 | namespace st::example { 8 | 9 | Precision precision_from_string(const std::string& precision_str){ 10 | if (precision_map.find(precision_str) == precision_map.end()){ 11 | std::cerr << "Invalid precision string: " + precision_str << std::endl; 12 | return Precision::INVALID; 13 | } 14 | return precision_map.at(precision_str); 15 | } 16 | 17 | std::string precision_to_string(Precision precision){ 18 | for (auto it = precision_map.begin(); it != precision_map.end(); ++it){ 19 | if (it->second == precision){ 20 | return it->first; 21 | } 22 | } 23 | std::cerr << "Invalid precision: " + std::to_string(static_cast(precision)) << std::endl; 24 | return ""; 25 | } 26 | 27 | RunArgs::RunArgs( 28 | const std::string& model_weight_path, const std::string& vocab_json_path, const std::string& input_path, 29 | const std::string& str_hyper_param, const std::string& str_precision, 30 | bool is_debug 31 | ) 32 | : model_weight_path(model_weight_path), vocab_json_path(vocab_json_path), input_path(input_path), 33 | hyper_param(str2hyperparam(str_hyper_param)), precision(precision_from_string(str_precision)), 34 | is_debug(is_debug) 35 | { 36 | } 37 | 38 | } -------------------------------------------------------------------------------- /src/examples/lib/st_args.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #include "model/gpt/gpt_hyper_param.h" 7 | 8 | namespace st::example { 9 | 10 | enum class Precision { 11 | FP32, 12 | FP16, 13 | INVALID 14 | }; 15 | 16 | const std::unordered_map precision_map = { 17 | {"fp32", Precision::FP32}, 18 | {"fp16", Precision::FP16}, 19 | {"FP32", Precision::FP32}, 20 | {"FP16", Precision::FP16}, 21 | {"", Precision::INVALID} 22 | }; 23 | 24 | Precision precision_from_string(const std::string& precision_str); 25 | std::string precision_to_string(Precision precision); 26 | 27 | struct RunArgs { 28 | std::string model_weight_path, vocab_json_path, input_path; 29 | st::model::GptHyperParam hyper_param; 30 | Precision precision; 31 | bool is_debug = false; 32 | 33 | RunArgs() = default; 34 | RunArgs( 35 | const std::string& model_weight_path, const std::string& vocab_json_path, const std::string& input_path, 36 | const std::string& str_hyper_param, const std::string& str_precision, 37 | bool is_debug = false 38 | ); 39 | }; 40 | 41 | } // namespace st::example -------------------------------------------------------------------------------- /src/examples/lib/utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #include "simple_vocab_decoder.h" 7 | 8 | inline void print_prompt_and_output( 9 | const std::vector &prompt_tokens, 10 | const std::vector &output_tokens, 11 | const SimpleVocabDecoder &decoder) { 12 | printf("("); 13 | for (auto token : prompt_tokens) 14 | printf("%s ", decoder.decode(token).c_str()); 15 | printf(") "); 16 | for (auto token : output_tokens) 17 | printf("%s ", decoder.decode(token).c_str()); 18 | printf(" (prompt len = %ld, %ld tokens generated)", prompt_tokens.size(), output_tokens.size()); 19 | printf("\n"); 20 | } -------------------------------------------------------------------------------- /src/unittest/CMakeLists.txt: 
-------------------------------------------------------------------------------- 1 | add_subdirectory(kernel) 2 | add_subdirectory(layer) 3 | add_subdirectory(model) 4 | add_subdirectory(util) 5 | -------------------------------------------------------------------------------- /src/unittest/kernel/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_library(unittest_kernel_ref 2 | attention_ref.cc 3 | kvcache_mgmt_ref.cc 4 | rotary_posi_embedding_ref.cc 5 | ) 6 | target_link_libraries(unittest_kernel_ref PUBLIC kernel) 7 | 8 | add_executable(unittest_kernel 9 | addbias.cc 10 | findmax.cc 11 | fused_activ_multiply.cc 12 | fused_addbias_activ.cc 13 | layernorm.cc 14 | rmsnorm.cc 15 | rotary_posi_embedding.cc 16 | softmax.cc 17 | ) 18 | target_link_libraries(unittest_kernel PUBLIC kernel unittest_kernel_ref gtest_main) 19 | -------------------------------------------------------------------------------- /src/unittest/kernel/addbias.cc: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | 6 | #include "../unittest_utils.h" 7 | #include "../unittest_torch_utils.h" 8 | #include "util/cuda_utils.h" 9 | #include "util/torch_utils.h" 10 | #include "kernel/addbias.h" 11 | 12 | template 13 | class AddbiasTestSuite : public ::testing::Test { 14 | public: 15 | void SetUp() override { 16 | setupTorch(); 17 | } 18 | void TearDown() override { 19 | } 20 | }; 21 | 22 | TYPED_TEST_SUITE(AddbiasTestSuite, SupportTypes); 23 | 24 | TYPED_TEST(AddbiasTestSuite, AddbiasTest) { 25 | typedef TypeParam T; 26 | 27 | const int64_t ROUND = 16; 28 | for (int64_t i = 0; i < ROUND; ++i) { 29 | const int64_t N = 1024 + 2*i; // N need to be even, proposed by addbias.cu 30 | torch::Tensor input = torch::rand({N}, torch::kCUDA); 31 | torch::Tensor bias = torch::rand({N}, torch::kCUDA); 32 | torch::Tensor ref_output = input + bias; 33 | torch::Tensor ans_output = torch::empty({N}, torch::kCUDA); 34 | 35 | st::kernel::addbias( 36 | (T*)ans_output.data_ptr(), 37 | (T*)input.data_ptr(), 38 | (T*)bias.data_ptr(), 39 | N 40 | ); 41 | sync_check_cuda_error(); 42 | 43 | bool is_pass = isArrayAlmostEqual((T*)ans_output.data_ptr(), (T*)ref_output.data_ptr(), N, true, true); 44 | sync_check_cuda_error(); 45 | ASSERT_TRUE(is_pass); 46 | } 47 | } 48 | 49 | TYPED_TEST(AddbiasTestSuite, AddbiasBatchedTest) { 50 | typedef TypeParam T; 51 | 52 | const int64_t BATCH_SIZE = 14; 53 | const int64_t N = 1024; 54 | torch::Tensor input = torch::rand({BATCH_SIZE, N}, torch::kCUDA); 55 | torch::Tensor bias = torch::rand({N}, torch::kCUDA); 56 | torch::Tensor ref_output = input + bias; 57 | torch::Tensor ans_output = torch::empty({BATCH_SIZE, N}, torch::kCUDA); 58 | 59 | st::kernel::addbiasBatched( 60 | (T*)ans_output.data_ptr(), 61 | (T*)input.data_ptr(), 62 | (T*)bias.data_ptr(), 63 | BATCH_SIZE, 64 | N 65 | ); 66 | sync_check_cuda_error(); 67 | 68 | bool is_pass = isArrayAlmostEqual((T*)ans_output.data_ptr(), (T*)ref_output.data_ptr(), N*BATCH_SIZE, true, true); 69 | sync_check_cuda_error(); 70 | ASSERT_TRUE(is_pass); 71 | } -------------------------------------------------------------------------------- /src/unittest/kernel/attention_ref.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace st::reference::kernel { 6 | 7 | using torch::Tensor; 8 | 9 | void attentionKernelRef( 10 | Tensor &result, // [num_tokens, hidden_size] 11 | Tensor 
&k_cache, // [num_blocks, num_layers, num_heads, block_size, head_dim] 12 | Tensor &v_cache, // [num_blocks, num_layers, num_heads, block_size, head_dim] 13 | 14 | const Tensor &qkvs, // [num_tokens, 3, num_heads, head_dim] 15 | const float qk_scale, 16 | const Tensor &block_table_cpu, // [num_reqs, max_num_block_per_seq] 17 | const Tensor &input_len_cpu, // [num_reqs] 18 | const Tensor &is_context_stage_cpu, 19 | 20 | bool run_context_stage, 21 | bool run_decoding_stage 22 | ); 23 | 24 | } // namespace st::reference::kernel 25 | -------------------------------------------------------------------------------- /src/unittest/kernel/findmax.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "../unittest_utils.h" 5 | #include "../unittest_torch_utils.h" 6 | #include "util/cuda_utils.h" 7 | #include "util/torch_utils.h" 8 | #include "kernel/findmax.h" 9 | 10 | template 11 | class FindmaxTestSuite : public ::testing::Test { 12 | public: 13 | void SetUp() override { 14 | setupTorch(); 15 | } 16 | void TearDown() override { 17 | } 18 | }; 19 | 20 | TYPED_TEST_SUITE(FindmaxTestSuite, SupportTypes); 21 | 22 | TYPED_TEST(FindmaxTestSuite, FindmaxBatchedTest) { 23 | typedef TypeParam T; 24 | 25 | const int64_t BATCH_SIZE = 1892; 26 | const int64_t LENGTH = 2345; 27 | torch::Tensor input = torch::rand({BATCH_SIZE, LENGTH}, torch::kCUDA); 28 | torch::Tensor ref_output = torch::argmax(input, 1).cpu(); 29 | torch::Tensor ans_output = torch::empty({BATCH_SIZE}, torch::kInt64).cuda(); 30 | st::kernel::findmaxBatched( 31 | (int64_t*)ans_output.data_ptr(), 32 | (T*)input.data_ptr(), 33 | BATCH_SIZE, 34 | LENGTH 35 | ); 36 | ans_output = ans_output.cpu(); 37 | 38 | for (int64_t i = 0; i < BATCH_SIZE; ++i) { 39 | T ref_max = ((T*)input[i].cpu().data_ptr())[ref_output[i].item()]; 40 | T ans_max = ((T*)input[i].cpu().data_ptr())[ans_output[i].item()]; 41 | ASSERT_EQ(ref_max, ans_max); 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /src/unittest/kernel/fused_activ_multiply.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | #include 6 | 7 | #include "../unittest_utils.h" 8 | #include "../unittest_torch_utils.h" 9 | #include "util/cuda_utils.h" 10 | #include "util/torch_utils.h" 11 | #include "kernel/fused_activ_multiply.h" 12 | 13 | template 14 | class FusedActivMultiplyTestSuite : public ::testing::Test { 15 | public: 16 | void SetUp() override { 17 | setupTorch(); 18 | } 19 | void TearDown() override { 20 | } 21 | }; 22 | 23 | TYPED_TEST_SUITE(FusedActivMultiplyTestSuite, SupportTypes); 24 | 25 | TYPED_TEST(FusedActivMultiplyTestSuite, FusedActivMultiplyTest) { 26 | typedef TypeParam T; 27 | using st::ActivationType; 28 | 29 | const int64_t SIZE = 1057362; 30 | 31 | torch::Tensor input1 = torch::rand({SIZE}, torch::kCUDA); 32 | torch::Tensor input2 = torch::rand({SIZE}, torch::kCUDA); 33 | 34 | for (ActivationType activation_type : std::vector{ 35 | ActivationType::RELU, 36 | ActivationType::SILU, 37 | ActivationType::GELU 38 | }) { 39 | torch::Tensor ref_output = ( 40 | activation_type == ActivationType::RELU ? torch::relu(input1) : 41 | activation_type == ActivationType::SILU ? torch::silu(input1) : 42 | activation_type == ActivationType::GELU ? 
torch::gelu(input1) : 43 | input1 ) * input2; 44 | 45 | torch::Tensor ans_output = torch::empty({SIZE}, torch::kCUDA); 46 | st::kernel::fusedActivationMultiply( 47 | (T*)ans_output.data_ptr(), 48 | (T*)input1.data_ptr(), 49 | (T*)input2.data_ptr(), 50 | SIZE, 51 | activation_type 52 | ); 53 | sync_check_cuda_error_force(); 54 | 55 | bool is_pass = isArrayAlmostEqual((T*)ans_output.data_ptr(), (T*)ref_output.data_ptr(), SIZE, true, true); 56 | ASSERT_TRUE(is_pass); 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /src/unittest/kernel/fused_addbias_activ.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | #include 6 | 7 | #include "../unittest_utils.h" 8 | #include "../unittest_torch_utils.h" 9 | #include "util/cuda_utils.h" 10 | #include "util/torch_utils.h" 11 | #include "kernel/fused_addbias_activ.h" 12 | 13 | template 14 | class FusedAddbiasTestSuite : public ::testing::Test { 15 | public: 16 | void SetUp() override { 17 | setupTorch(); 18 | } 19 | void TearDown() override { 20 | } 21 | }; 22 | 23 | TYPED_TEST_SUITE(FusedAddbiasTestSuite, SupportTypes); 24 | 25 | TYPED_TEST(FusedAddbiasTestSuite, FusedAddbiasBatchedActivTest) { 26 | typedef TypeParam T; 27 | using st::ActivationType; 28 | 29 | const int64_t BATCH_SIZE = 243; 30 | const int64_t SIZE = 1890; 31 | 32 | torch::Tensor input = torch::rand({BATCH_SIZE, SIZE}, torch::kCUDA); 33 | torch::Tensor bias = torch::rand({SIZE}, torch::kCUDA); 34 | 35 | for (ActivationType activation_type : std::vector{ 36 | ActivationType::RELU, 37 | ActivationType::SILU, 38 | ActivationType::GELU 39 | }) { 40 | torch::Tensor ref_output = 41 | activation_type == ActivationType::RELU ? torch::relu(input + bias) : 42 | activation_type == ActivationType::SILU ? torch::silu(input + bias) : 43 | activation_type == ActivationType::GELU ? 
torch::gelu(input + bias) : 44 | input + bias; 45 | torch::Tensor ans_output = torch::empty({BATCH_SIZE, SIZE}, torch::kCUDA); 46 | st::kernel::fusedAddbiasBatchedActivation( 47 | (T*)ans_output.data_ptr(), 48 | (T*)input.data_ptr(), 49 | (T*)bias.data_ptr(), 50 | BATCH_SIZE, 51 | SIZE, 52 | activation_type 53 | ); 54 | sync_check_cuda_error_force(); 55 | 56 | bool is_pass = isArrayAlmostEqual((T*)ans_output.data_ptr(), (T*)ref_output.data_ptr(), BATCH_SIZE*SIZE, true, true); 57 | ASSERT_TRUE(is_pass); 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /src/unittest/kernel/kvcache_mgmt_ref.cc: -------------------------------------------------------------------------------- 1 | #include "kvcache_mgmt_ref.h" 2 | 3 | #include 4 | 5 | namespace st::reference::kernel { 6 | 7 | using torch::Tensor; 8 | 9 | void saveContextStageKVCacheKernelRef( 10 | Tensor &k_cache, // [num_blocks, num_layers, num_kv_heads, block_size, head_dim] 11 | Tensor &v_cache, // [num_blocks, num_layers, num_kv_heads, block_size, head_dim] 12 | const Tensor &qkvs, // [num_tokens, num_q_heads+2*num_kv_heads, head_dim] 13 | const Tensor &block_table_cpu, // [num_reqs, max_num_block_per_seq] 14 | const Tensor &input_len_cpu, // [num_reqs] 15 | const Tensor &is_context_stage_cpu, // [num_reqs] 16 | const int64_t layer_id 17 | ) { 18 | const int64_t num_reqs = input_len_cpu.size(0); 19 | const int64_t block_size = k_cache.size(3); 20 | const int64_t num_kv_heads = k_cache.size(2); 21 | const int64_t num_q_heads = qkvs.size(1) - 2 * num_kv_heads; 22 | 23 | int64_t first_token_index = 0; 24 | for (int64_t req_index = 0; req_index < num_reqs; ++req_index) { 25 | const int64_t input_len = input_len_cpu[req_index].item(); 26 | const bool is_context_stage = is_context_stage_cpu[req_index].item(); 27 | 28 | if (is_context_stage) { 29 | for (int64_t token_index = 0; token_index < input_len; ++token_index) { 30 | int64_t block_index = block_table_cpu[req_index][token_index / block_size].item(); 31 | int64_t block_offset = token_index % block_size; 32 | k_cache[block_index][layer_id].select(1, block_offset) = qkvs[first_token_index + token_index].slice(0, num_q_heads, num_q_heads+num_kv_heads); 33 | v_cache[block_index][layer_id].select(1, block_offset) = qkvs[first_token_index + token_index].slice(0, num_q_heads+num_kv_heads, num_q_heads+2*num_kv_heads); 34 | } 35 | } 36 | 37 | first_token_index += is_context_stage ? 
input_len : 1; 38 | } 39 | } 40 | 41 | } -------------------------------------------------------------------------------- /src/unittest/kernel/kvcache_mgmt_ref.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace st::reference::kernel { 6 | 7 | using torch::Tensor; 8 | 9 | void saveContextStageKVCacheKernelRef( 10 | Tensor &k_cache, 11 | Tensor &v_cache, 12 | const Tensor &qkvs, 13 | const Tensor &block_table_cpu, 14 | const Tensor &input_len_cpu, 15 | const Tensor &is_context_stage_cpu, 16 | const int64_t layer_id 17 | ); 18 | 19 | } // namespace st::reference::kernel 20 | -------------------------------------------------------------------------------- /src/unittest/kernel/layernorm.cc: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | 6 | #include "../unittest_utils.h" 7 | #include "../unittest_torch_utils.h" 8 | #include "util/cuda_utils.h" 9 | #include "util/torch_utils.h" 10 | #include "kernel/layernorm.h" 11 | 12 | template 13 | class LayernormTestSuite : public ::testing::Test { 14 | public: 15 | void SetUp() override { 16 | setupTorch(); 17 | } 18 | void TearDown() override { 19 | } 20 | }; 21 | 22 | TYPED_TEST_SUITE(LayernormTestSuite, SupportTypes); 23 | 24 | TYPED_TEST(LayernormTestSuite, LayernormTest) { 25 | typedef TypeParam T; 26 | 27 | const int64_t BATCH_SIZE = 259; 28 | const int64_t HIDDEN_SIZE = 8192; 29 | 30 | torch::Tensor input = torch::rand({BATCH_SIZE, HIDDEN_SIZE}, torch::kCUDA); 31 | torch::Tensor weight = torch::rand({HIDDEN_SIZE}, torch::kCUDA); 32 | torch::Tensor bias = torch::rand({HIDDEN_SIZE}, torch::kCUDA); 33 | const T epsilon = 1e-4; 34 | 35 | torch::Tensor ans_output = torch::empty({BATCH_SIZE, HIDDEN_SIZE}, torch::kCUDA); 36 | st::kernel::layernorm( 37 | (T*)ans_output.data_ptr(), 38 | (T*)input.data_ptr(), 39 | 40 | (T*)weight.data_ptr(), 41 | (T*)bias.data_ptr(), 42 | epsilon, 43 | 44 | BATCH_SIZE, 45 | HIDDEN_SIZE 46 | ); 47 | 48 | torch::Tensor ref_output = torch::layer_norm(input, {HIDDEN_SIZE}, weight, bias, epsilon); 49 | 50 | bool is_passed = isArrayAlmostEqual((T*)ans_output.data_ptr(), (T*)ref_output.data_ptr(), BATCH_SIZE * HIDDEN_SIZE, true, true); 51 | ASSERT_TRUE(is_passed); 52 | } 53 | -------------------------------------------------------------------------------- /src/unittest/kernel/rmsnorm.cc: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | 6 | #include "../unittest_utils.h" 7 | #include "../unittest_torch_utils.h" 8 | #include "util/cuda_utils.h" 9 | #include "util/torch_utils.h" 10 | #include "kernel/rmsnorm.h" 11 | 12 | template 13 | class RmsnormTestSuite : public ::testing::Test { 14 | public: 15 | void SetUp() override { 16 | setupTorch(); 17 | } 18 | void TearDown() override { 19 | } 20 | }; 21 | 22 | TYPED_TEST_SUITE(RmsnormTestSuite, SupportTypes); 23 | 24 | TYPED_TEST(RmsnormTestSuite, RmsnormTest) { 25 | typedef TypeParam T; 26 | 27 | const int64_t BATCH_SIZE = 259; 28 | const int64_t HIDDEN_SIZE = 40761; 29 | 30 | torch::Tensor input = torch::rand({BATCH_SIZE, HIDDEN_SIZE}, torch::kCUDA); 31 | torch::Tensor weight = torch::rand({HIDDEN_SIZE}, torch::kCUDA); 32 | const T epsilon = 1e-4; 33 | 34 | torch::Tensor ans_output = torch::empty({BATCH_SIZE, HIDDEN_SIZE}, torch::kCUDA); 35 | st::kernel::rmsnorm( 36 | (T*)ans_output.data_ptr(), 37 | (T*)input.data_ptr(), 38 | 39 | 
(T*)weight.data_ptr(), 40 | epsilon, 41 | 42 | BATCH_SIZE, 43 | HIDDEN_SIZE 44 | ); 45 | 46 | torch::Tensor t = input.pow(2).mean(-1, true) + (float)epsilon; 47 | torch::Tensor ref_output = input * torch::rsqrt(t) * weight; 48 | 49 | bool is_passed = isArrayAlmostEqual((T*)ans_output.data_ptr(), (T*)ref_output.data_ptr(), BATCH_SIZE * HIDDEN_SIZE, true, true); 50 | ASSERT_TRUE(is_passed); 51 | } 52 | -------------------------------------------------------------------------------- /src/unittest/kernel/rotary_posi_embedding.cc: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | 6 | #include "../unittest_utils.h" 7 | #include "../unittest_torch_utils.h" 8 | #include "util/cuda_utils.h" 9 | #include "util/torch_utils.h" 10 | #include "kernel/rotary_posi_embedding.h" 11 | #include "rotary_posi_embedding_ref.h" 12 | 13 | template 14 | class RotaryPosiEmbeddingTestSuite : public ::testing::Test { 15 | public: 16 | void SetUp() override { 17 | setupTorch(); 18 | } 19 | void TearDown() override { 20 | } 21 | }; 22 | 23 | TYPED_TEST_SUITE(RotaryPosiEmbeddingTestSuite, SupportTypes); 24 | 25 | TYPED_TEST(RotaryPosiEmbeddingTestSuite, RotaryPosiEmbeddingTest) { 26 | typedef TypeParam T; 27 | std::mt19937 gen(0); 28 | 29 | const int64_t NUM_TOKENS = 243; 30 | const int64_t NUM_HEADS = 64; 31 | const int64_t HEAD_DIM = 128; 32 | 33 | torch::Tensor input = torch::rand({NUM_TOKENS, NUM_HEADS, HEAD_DIM}, torch::kCUDA); 34 | std::vector indexes(NUM_TOKENS); 35 | for (int64_t i = 0; i < NUM_TOKENS; ++i) { 36 | indexes[i] = gen() % NUM_TOKENS; 37 | } 38 | 39 | torch::Tensor ans_output = input.clone(); 40 | GpuArray d_indexes(indexes); 41 | st::kernel::rotaryPosiEmbeddingBatched( 42 | (T*)ans_output.data_ptr(), 43 | d_indexes.data, 44 | NUM_TOKENS, 45 | NUM_HEADS, 46 | NUM_HEADS, 47 | HEAD_DIM 48 | ); 49 | sync_check_cuda_error_force(); 50 | 51 | torch::Tensor ref_output = input.clone(); 52 | st::reference::kernel::rotaryPosiEmbeddingKernelRef( 53 | ref_output, 54 | indexes 55 | ); 56 | sync_check_cuda_error_force(); 57 | 58 | bool is_pass = isArrayAlmostEqual((T*)ans_output.data_ptr(), (T*)ref_output.data_ptr(), NUM_TOKENS*NUM_HEADS*HEAD_DIM, true, true); 59 | sync_check_cuda_error(); 60 | ASSERT_TRUE(is_pass); 61 | } 62 | -------------------------------------------------------------------------------- /src/unittest/kernel/rotary_posi_embedding_ref.cc: -------------------------------------------------------------------------------- 1 | #include "rotary_posi_embedding_ref.h" 2 | 3 | #include 4 | 5 | namespace st::reference::kernel { 6 | 7 | using torch::Tensor; 8 | 9 | void rotaryPosiEmbeddingKernelRef( 10 | Tensor &target, // [num_tokens, num_heads, head_dim] 11 | const std::vector &indexes // [num_tokens] 12 | ) { 13 | const int64_t NUM_TOKENS = target.size(0); 14 | const int64_t NUM_HEADS = target.size(1); 15 | const int64_t HEAD_DIM = target.size(2); 16 | 17 | // To guarantee precision, we use float32 instead of half here. 
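// RoPE treats each (even, odd) channel pair of a head as a 2-D point and rotates it by an
// angle proportional to the token position m (the "index" below):
//   x'_{2i}   = x_{2i} * cos(m * theta_i) - x_{2i+1} * sin(m * theta_i)
//   x'_{2i+1} = x_{2i} * sin(m * theta_i) + x_{2i+1} * cos(m * theta_i)
// with theta_i = 10000^(-2i / head_dim); the tensors below build exactly these cos/sin factors.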
18 | torch::Tensor angles = torch::arange(0, HEAD_DIM/2, 1, torch::kCUDA).to(torch::kFloat32).mul(-2).div(HEAD_DIM); 19 | torch::Tensor thetas = torch::pow(10000, angles); // thetas[i] = 10000 ** (-2i/d), float32 20 | 21 | // Deal each token one by one 22 | for (int64_t token_index = 0; token_index < NUM_TOKENS; ++token_index) { 23 | int64_t index = indexes[token_index]; 24 | torch::Tensor coses = torch::cos(thetas.mul(index)); // coses[i] = cos(10000 ** (-2i/d) * index) 25 | torch::Tensor sins = torch::sin(thetas.mul(index)); // sins[i] = sin(10000 ** (-2i/d) * index) 26 | coses = coses.unsqueeze(0).repeat_interleave(NUM_HEADS, 0); // [num_heads, head_dim/2] 27 | sins = sins.unsqueeze(0).repeat_interleave(NUM_HEADS, 0); // [num_heads, head_dim/2] 28 | 29 | torch::Tensor cur_token = target[token_index].to(torch::kFloat32); // [num_heads, head_dim] 30 | torch::Tensor cur_token_even = cur_token.index({torch::indexing::Slice(), torch::indexing::Slice(0, HEAD_DIM, 2)}); // [num_heads, head_dim/2] 31 | torch::Tensor cur_token_odd = cur_token.index({torch::indexing::Slice(), torch::indexing::Slice(1, HEAD_DIM, 2)}); // [num_heads, head_dim/2] 32 | 33 | torch::Tensor cur_token_even_rot = cur_token_even.mul(coses) - cur_token_odd.mul(sins); // [num_heads, head_dim/2] 34 | torch::Tensor cur_token_odd_rot = cur_token_even.mul(sins) + cur_token_odd.mul(coses); // [num_heads, head_dim/2] 35 | 36 | target[token_index].index_put_({torch::indexing::Slice(), torch::indexing::Slice(0, HEAD_DIM, 2)}, cur_token_even_rot.to(torch::kFloat16)); 37 | target[token_index].index_put_({torch::indexing::Slice(), torch::indexing::Slice(1, HEAD_DIM, 2)}, cur_token_odd_rot.to(torch::kFloat16)); 38 | // Thanks to CoPilot, otherwise I would never be able to write the code above... 
39 | } 40 | } 41 | 42 | } // namespace st::reference::kernel -------------------------------------------------------------------------------- /src/unittest/kernel/rotary_posi_embedding_ref.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | namespace st::reference::kernel { 7 | 8 | using torch::Tensor; 9 | 10 | void rotaryPosiEmbeddingKernelRef( 11 | Tensor &target, // [num_tokens, num_heads, head_dim] 12 | const std::vector &indexes // [num_tokens] 13 | ); 14 | 15 | } // namespace st::reference::kernel 16 | -------------------------------------------------------------------------------- /src/unittest/kernel/softmax.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "../unittest_utils.h" 5 | #include "../unittest_torch_utils.h" 6 | #include "util/cuda_utils.h" 7 | #include "util/torch_utils.h" 8 | #include "kernel/softmax.h" 9 | 10 | template 11 | class SoftmaxTestSuite : public ::testing::Test { 12 | public: 13 | void SetUp() override { 14 | setupTorch(); 15 | } 16 | void TearDown() override { 17 | } 18 | }; 19 | 20 | TYPED_TEST_SUITE(SoftmaxTestSuite, SupportTypes); 21 | 22 | TYPED_TEST(SoftmaxTestSuite, ScaleMaskSoftmax) { 23 | typedef TypeParam T; 24 | 25 | const int64_t NUM_HEADS = 128; 26 | const int64_t INPUT_LEN = 556; 27 | const float scale = 1.0 / sqrt(INPUT_LEN); 28 | 29 | torch::Tensor input = getRandomTensor({NUM_HEADS, INPUT_LEN, INPUT_LEN}); 30 | torch::Tensor ans_output = torch::zeros_like(input); 31 | torch::Tensor ref_output = input.to(torch::kFloat) * scale; 32 | 33 | st::kernel::scaleMaskSoftmax( 34 | (T*)ans_output.data_ptr(), 35 | (T*)input.data_ptr(), 36 | scale, 37 | NUM_HEADS, 38 | INPUT_LEN 39 | ); 40 | 41 | torch::Tensor attn_mask = torch::zeros({INPUT_LEN, INPUT_LEN}, torch::kInt64); 42 | for (int64_t i = 0; i < INPUT_LEN; ++i) { 43 | for (int64_t j = i+1; j < INPUT_LEN; ++j) { 44 | attn_mask.accessor()[i][j] = -10000; 45 | } 46 | } 47 | ref_output = ref_output + attn_mask.to(std::is_same() ? at::kHalf : at::kFloat).to(torch::kCUDA); 48 | ref_output = torch::softmax(ref_output, 2); 49 | ref_output = ref_output.to(std::is_same() ? at::kHalf : at::kFloat); 50 | 51 | bool is_passed = isArrayAlmostEqual((T*)ans_output.data_ptr(), (T*)ref_output.data_ptr(), NUM_HEADS*INPUT_LEN*INPUT_LEN, true, true); 52 | ASSERT_TRUE(is_passed); 53 | } 54 | 55 | TYPED_TEST(SoftmaxTestSuite, ScaleSoftmax) { 56 | typedef TypeParam T; 57 | 58 | const int64_t NUM_HEADS = 128; 59 | const int64_t INPUT_LEN = 1950; 60 | const float scale = 1.0 / sqrt(INPUT_LEN); 61 | 62 | torch::Tensor input = getRandomTensor({NUM_HEADS, INPUT_LEN}); 63 | torch::Tensor ans_output = torch::zeros_like(input); 64 | 65 | st::kernel::scaleSoftmax( 66 | (T*)ans_output.data_ptr(), 67 | (T*)input.data_ptr(), 68 | scale, 69 | NUM_HEADS, 70 | INPUT_LEN 71 | ); 72 | 73 | torch::Tensor ref_output = input.to(torch::kFloat) * scale; 74 | ref_output = torch::softmax(ref_output, 1); 75 | ref_output = ref_output.to(std::is_same() ? 
at::kHalf : at::kFloat); 76 | 77 | bool is_passed = isArrayAlmostEqual((T*)ans_output.data_ptr(), (T*)ref_output.data_ptr(), NUM_HEADS*INPUT_LEN, true, true); 78 | ASSERT_TRUE(is_passed); 79 | } -------------------------------------------------------------------------------- /src/unittest/layer/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_library(unittest_layer_ref 2 | attention_ref.cc 3 | ) 4 | target_link_libraries(unittest_layer_ref PUBLIC unittest_kernel_ref) 5 | 6 | add_executable(unittest_layer_para 7 | parallel_ffn.cc 8 | parallel_attention.cc 9 | ) 10 | target_link_libraries(unittest_layer_para PUBLIC unittest_layer_ref layer util kernel gtest_main nccl_utils) -------------------------------------------------------------------------------- /src/unittest/layer/attention_ref.cc: -------------------------------------------------------------------------------- 1 | #include "attention_ref.h" 2 | 3 | #include 4 | #include 5 | 6 | #include "../kernel/attention_ref.h" 7 | #include "../kernel/kvcache_mgmt_ref.h" 8 | 9 | namespace st::reference::layer { 10 | 11 | using torch::Tensor; 12 | 13 | void attentionLayerRef( 14 | Tensor &result, // [num_tokens, hidden_size] 15 | Tensor &k_cache, // [num_blocks, num_layers, num_kv_heads, block_size, head_dim] 16 | Tensor &v_cache, // [num_blocks, num_layers, num_kv_heads, block_size, head_dim] 17 | 18 | const Tensor &input, // [num_tokens, hidden_size] 19 | const Tensor &input_len_cpu, // [num_reqs] 20 | const Tensor &is_context_stage_cpu, // [num_reqs] 21 | const Tensor &block_table_cpu, // [num_reqs, max_num_block_per_seq] 22 | 23 | const float qk_scale, 24 | 25 | const Tensor &qkv_weight_kernel, // [hidden_size, num_q_heads + 2*num_kv_heads, head_dim] 26 | const Tensor &qkv_weight_bias, // [num_q_heads+2*num_kv_heads, head_dim] 27 | const Tensor &out_weight_kernel, // [num_q_heads, head_dim, hidden_size] 28 | const Tensor &out_weight_bias, // [hidden_size] 29 | 30 | const int64_t layer_id 31 | ) { 32 | const int64_t num_tokens = input.size(0); 33 | const int64_t num_q_heads = out_weight_kernel.size(0); 34 | const int64_t num_kv_heads = k_cache.size(2); 35 | const int64_t head_dim = qkv_weight_kernel.size(2); 36 | const int64_t hidden_size = qkv_weight_kernel.size(0); 37 | 38 | // Step 1. QKV GEMM 39 | Tensor qkvs = torch::matmul(input, qkv_weight_kernel.view({hidden_size, (num_q_heads+2*num_kv_heads)*head_dim})); 40 | qkvs += qkv_weight_bias.view({(num_q_heads+2*num_kv_heads)*head_dim}); 41 | qkvs = qkvs.view({num_tokens, num_q_heads+2*num_kv_heads, head_dim}); // [num_tokens, num_q_heads+2*num_kv_heads, head_dim] 42 | 43 | // Step 2. Attention 44 | // result: [num_tokens, hidden_size] 45 | st::reference::kernel::attentionKernelRef( 46 | result, 47 | k_cache, 48 | v_cache, 49 | 50 | qkvs, 51 | qk_scale, 52 | block_table_cpu, 53 | input_len_cpu, 54 | is_context_stage_cpu, 55 | 56 | true, true 57 | ); 58 | 59 | // Step 3. Save KV Cache 60 | st::reference::kernel::saveContextStageKVCacheKernelRef( 61 | k_cache, 62 | v_cache, 63 | qkvs, 64 | block_table_cpu, 65 | input_len_cpu, 66 | is_context_stage_cpu, 67 | layer_id 68 | ); 69 | 70 | // Step 4. 
Output GEMM 71 | result = torch::matmul(result, out_weight_kernel.view({hidden_size, hidden_size})) + out_weight_bias; 72 | } 73 | 74 | } -------------------------------------------------------------------------------- /src/unittest/layer/attention_ref.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace st::reference::layer { 6 | 7 | using torch::Tensor; 8 | 9 | void attentionLayerRef( 10 | Tensor &result, 11 | Tensor &k_cache, 12 | Tensor &v_cache, 13 | 14 | const Tensor &input, 15 | const Tensor &input_len_cpu, 16 | const Tensor &is_context_stage_cpu, 17 | const Tensor &block_table_cpu, 18 | 19 | const float qk_scale, 20 | 21 | const Tensor &qkv_weight_kernel, 22 | const Tensor &qkv_weight_bias, 23 | const Tensor &out_weight_kernel, 24 | const Tensor &out_weight_bias, 25 | 26 | const int64_t layer_id 27 | ); 28 | 29 | } -------------------------------------------------------------------------------- /src/unittest/layer/attention_utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "../unittest_utils.h" 8 | #include "../unittest_torch_utils.h" 9 | #include "util/cuda_utils.h" 10 | #include "util/torch_utils.h" 11 | #include "util/cublas_wrapper.h" 12 | #include "layer/attention.h" 13 | 14 | constexpr int64_t NUM_TOTAL_BLOCKS = 2048; 15 | 16 | struct PagedAttnParam { 17 | int64_t block_size; 18 | int64_t max_num_block_per_req; 19 | }; 20 | 21 | class Indexes { 22 | public: 23 | // CPU 24 | int64_t* ith_context_req_req_index_cpu; 25 | int32_t* ith_context_req_token_index_cpu; 26 | int64_t* ith_decoding_req_req_index_cpu; 27 | int64_t* ith_decoding_req_token_index_cpu; 28 | int64_t batch_size; 29 | 30 | // GPU 31 | int64_t* ith_context_req_req_index; 32 | int32_t* ith_context_req_token_index; 33 | int64_t* ith_decoding_req_req_index; 34 | int64_t* ith_decoding_req_token_index; 35 | 36 | Indexes(){ 37 | ith_context_req_req_index_cpu = nullptr; 38 | ith_context_req_token_index_cpu = nullptr; 39 | ith_decoding_req_req_index_cpu = nullptr; 40 | ith_decoding_req_token_index_cpu = nullptr; 41 | } 42 | 43 | Indexes(const int64_t batch_size): batch_size(batch_size) { 44 | // CPU 45 | ith_context_req_req_index_cpu = new int64_t[batch_size]; 46 | ith_context_req_token_index_cpu = new int32_t[batch_size+1]; 47 | ith_decoding_req_req_index_cpu = new int64_t[batch_size]; 48 | ith_decoding_req_token_index_cpu = new int64_t[batch_size]; 49 | 50 | // GPU 51 | CUDA_CHECK(cudaMalloc(&ith_context_req_req_index, batch_size * sizeof(int64_t))); 52 | CUDA_CHECK(cudaMalloc(&ith_context_req_token_index, (batch_size+1) * sizeof(int32_t))); 53 | CUDA_CHECK(cudaMalloc(&ith_decoding_req_req_index, batch_size * sizeof(int64_t))); 54 | CUDA_CHECK(cudaMalloc(&ith_decoding_req_token_index, batch_size * sizeof(int64_t))); 55 | } 56 | 57 | ~Indexes() { 58 | if (ith_context_req_req_index_cpu != nullptr) { 59 | delete[] ith_context_req_req_index_cpu; 60 | delete[] ith_context_req_token_index_cpu; 61 | delete[] ith_decoding_req_req_index_cpu; 62 | delete[] ith_decoding_req_token_index_cpu; 63 | CUDA_CHECK(cudaFree(ith_context_req_req_index)); 64 | CUDA_CHECK(cudaFree(ith_context_req_token_index)); 65 | CUDA_CHECK(cudaFree(ith_decoding_req_req_index)); 66 | CUDA_CHECK(cudaFree(ith_decoding_req_token_index)); 67 | } 68 | } 69 | 70 | void toGPU(){ 71 | CUDA_CHECK(cudaMemcpy( 72 | ith_context_req_req_index, 73 | 
ith_context_req_req_index_cpu, 74 | sizeof(int64_t) * batch_size, 75 | cudaMemcpyHostToDevice 76 | )); 77 | CUDA_CHECK(cudaMemcpy( 78 | ith_context_req_token_index, 79 | ith_context_req_token_index_cpu, 80 | sizeof(int32_t) * (batch_size+1), 81 | cudaMemcpyHostToDevice 82 | )); 83 | CUDA_CHECK(cudaMemcpy( 84 | ith_decoding_req_req_index, 85 | ith_decoding_req_req_index_cpu, 86 | sizeof(int64_t) * batch_size, 87 | cudaMemcpyHostToDevice 88 | )); 89 | CUDA_CHECK(cudaMemcpy( 90 | ith_decoding_req_token_index, 91 | ith_decoding_req_token_index_cpu, 92 | sizeof(int64_t) * batch_size, 93 | cudaMemcpyHostToDevice 94 | )); 95 | } 96 | }; 97 | 98 | void build_block_table( 99 | int64_t* block_table, 100 | const int64_t batch_size, 101 | const PagedAttnParam pagedattn_param, 102 | const int64_t *input_len_cpu 103 | ){ 104 | constexpr int64_t num_total_blocks = NUM_TOTAL_BLOCKS; 105 | // Construct the initial block table 106 | int64_t num_allocated_blocks = 0; 107 | std::function allocateNewBlock = [&]() -> int64_t { 108 | num_allocated_blocks += 1; 109 | assert(num_allocated_blocks < num_total_blocks); 110 | return num_allocated_blocks; 111 | }; 112 | int64_t* allocated_block_cnt = new int64_t[batch_size]; 113 | int64_t* block_table_cpu = new int64_t[batch_size * pagedattn_param.max_num_block_per_req]; 114 | for (int64_t i = 0; i < batch_size; i++) { 115 | int64_t block_needed = (input_len_cpu[i]+1 + pagedattn_param.block_size-1) / pagedattn_param.block_size; 116 | allocated_block_cnt[i] = block_needed; 117 | assert (block_needed <= pagedattn_param.max_num_block_per_req); 118 | for (int64_t j = 0; j < block_needed; j++) { 119 | block_table_cpu[i * pagedattn_param.max_num_block_per_req + j] = allocateNewBlock(); 120 | } 121 | for (int64_t j = block_needed; j < pagedattn_param.max_num_block_per_req; ++j) { 122 | block_table_cpu[i * pagedattn_param.max_num_block_per_req + j] = -10000000; 123 | } 124 | } 125 | CUDA_CHECK(cudaMemcpy(block_table, block_table_cpu, sizeof(int64_t) * batch_size * pagedattn_param.max_num_block_per_req, cudaMemcpyHostToDevice)); 126 | delete[] block_table_cpu; 127 | sync_check_cuda_error(); 128 | } 129 | 130 | Indexes get_req_index(const int64_t batch_size, const int64_t* input_len_cpu, const bool* is_context_stage_cpu){ 131 | // Calculate indexes of requests in context stage and regression stage 132 | // Will be used in the attention layer (fusedDecodingStageAttentionKernel and fusedContextStageAttentionKernel) 133 | int64_t num_context_reqs = 0, num_decoding_reqs = 0; 134 | Indexes indexes(batch_size); 135 | int64_t cur_token_index = 0; 136 | for (int64_t i = 0; i < batch_size; ++i) { 137 | if (is_context_stage_cpu[i]) { 138 | indexes.ith_context_req_req_index_cpu[num_context_reqs] = i; 139 | indexes.ith_context_req_token_index_cpu[num_context_reqs] = cur_token_index; 140 | num_context_reqs += 1; 141 | cur_token_index += input_len_cpu[i]; 142 | } else { 143 | indexes.ith_decoding_req_req_index_cpu[num_decoding_reqs] = i; 144 | indexes.ith_decoding_req_token_index_cpu[num_decoding_reqs] = cur_token_index; 145 | num_decoding_reqs += 1; 146 | cur_token_index += 1; 147 | } 148 | } 149 | indexes.ith_context_req_token_index_cpu[num_context_reqs] = cur_token_index; 150 | return indexes; 151 | } -------------------------------------------------------------------------------- /src/unittest/model/CMakeLists.txt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/LLMServe/SwiftTransformer/9948e31b371aac93caff73b8c0dc50cf77d890cb/src/unittest/model/CMakeLists.txt -------------------------------------------------------------------------------- /src/unittest/unittest_torch_utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "util/cuda_utils.h" 8 | #include "util/torch_utils.h" 9 | 10 | // setupTorch - Set up pytorch's seed and default datatype 11 | template 12 | inline void setupTorch() { 13 | torch::manual_seed(0); 14 | // torch::Device device(torch::kCUDA); 15 | torch::set_default_dtype(caffe2::TypeMeta::fromScalarType(st::util::getTorchScalarType())); 16 | } 17 | 18 | inline torch::Tensor getRandomTensor(at::IntArrayRef shape, float lower = -1, float upper = +1, torch::Device device = torch::kCUDA) { 19 | auto options = torch::TensorOptions().requires_grad(false).device(device); 20 | return torch::rand(shape, options) * (upper - lower) + lower; 21 | } 22 | -------------------------------------------------------------------------------- /src/unittest/unittest_utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | // Types for TYPED_TEST 10 | typedef testing::Types FloatAndHalfTypes; 11 | #ifndef ENABLE_BF16 12 | typedef FloatAndHalfTypes SupportTypes; 13 | #else 14 | typedef testing::Types FloatHalfBf16Types; 15 | typedef FloatHalfBf16Types SupportTypes; 16 | #endif 17 | 18 | 19 | // Helpers for random number generating 20 | // Generate a random integer from interval [min, max] 21 | template 22 | inline T randNumber(std::mt19937 &rng, T min, T max) { 23 | std::uniform_int_distribution dist(min, max); 24 | return dist(rng); 25 | } 26 | 27 | // isAlmostEqual - Check if two floats are equal under given precision 28 | // It accomplishes this by checking whether fabs(answer-reference) <= ans_tol + rel_tol*fabs(reference) 29 | // When both answer & reference are NaN, return true 30 | inline bool isFloatAlmostEqual(float answer, float reference, const float abs_tol, const float rel_tol) { 31 | if (std::isnan(answer) && std::isnan(reference)) { 32 | return true; 33 | } 34 | if (std::isnan(answer) || std::isnan(reference)) { 35 | return false; 36 | } 37 | return fabs(answer-reference) <= abs_tol + rel_tol*fabs(reference); 38 | } 39 | 40 | 41 | // isArrayAlmostEqual - Check whether two float/half arrays are equal 42 | // Precisions are: 43 | // - abs_tol = 1e-4, rel_tol = 1e-2 for Float (FP32) 44 | // - abs_tol = 1e-3, rel_tol = 1e-1 for Half (FP16) and bfloat16 45 | // If answer[] or reference[] is on device, please set is_answer_on_device or is_reference_on_device to true 46 | template 47 | inline bool isArrayAlmostEqual( 48 | const T* answer_ptr, const T* reference_ptr, const int64_t n, 49 | const bool is_answer_on_device, const bool is_reference_on_device, 50 | const float max_allow_unmatch = -1, 51 | const bool record_pos = false 52 | ) { 53 | bool is_fp32 = std::is_same::value; 54 | float abs_tol = is_fp32 ? 1e-4f : 1e-3f; 55 | float rel_tol = is_fp32 ? 1e-2f : 1e-1f; 56 | int64_t max_non_match = max_allow_unmatch != -1 ? max_allow_unmatch*n : (is_fp32 ? 
0.002*n : 0.01*n); // Allow up to 0.2% mismatch for FP32, 1% for FP16/bfloat16 57 | 58 | // Copy the array to host if necessary 59 | T* answer = (T*)answer_ptr; 60 | if (is_answer_on_device) { 61 | answer = (T*)malloc(n*sizeof(T)); 62 | if (!answer) { 63 | printf("Failed to allocate memory for answer in isArrayAlmostEqual\n"); 64 | assert(false); 65 | } 66 | cudaMemcpy(answer, answer_ptr, n*sizeof(T), cudaMemcpyDeviceToHost); 67 | } 68 | T* reference = (T*)reference_ptr; 69 | if (is_reference_on_device) { 70 | reference = (T*)malloc(n*sizeof(T)); 71 | if (!reference) { 72 | printf("Failed to allocate memory for reference in isArrayAlmostEqual\n"); 73 | assert(false); 74 | } 75 | cudaMemcpy(reference, reference_ptr, n*sizeof(T), cudaMemcpyDeviceToHost); 76 | } 77 | 78 | // Compare the two arrays, and output the difference 79 | std::vector error_pos; 80 | int64_t error_count = 0; 81 | int64_t first_error_pos = -1; 82 | for (int64_t i = 0; i < n; ++i) { 83 | bool ok = isFloatAlmostEqual(answer[i], reference[i], abs_tol, rel_tol); 84 | if (!ok) { 85 | if (record_pos){ 86 | error_pos.push_back(i); 87 | } 88 | error_count += 1; 89 | if (error_count == 1) first_error_pos = i; 90 | if (error_count > max_non_match && error_count < max_non_match+4) { 91 | printf("Invalid result: answer[%ld] = %f, reference[%ld] = %f, abs_err = %f, rel_err = %f\n", 92 | i, (float)answer[i], i, (float)reference[i], 93 | fabs(answer[i]-reference[i]), fabs(answer[i]-reference[i])/fabs(reference[i])); 94 | } 95 | } 96 | } 97 | if (error_count != 0) { 98 | printf("Total %ld/%ld (%.2f%%) errors (1st error at pos #%ld)\n", error_count, n, 100.0*error_count/n, first_error_pos); 99 | } 100 | 101 | if (record_pos && error_count > max_non_match) { 102 | std::map cnts; 103 | for (auto x: error_pos) { 104 | cnts[x] ++; 105 | } 106 | for (auto x: cnts) { 107 | printf("Error pos: %ld: %ld\n", x.first, x.second); 108 | } 109 | printf("\n"); 110 | } 111 | 112 | // Free if necessary 113 | if (is_answer_on_device) { 114 | free(answer); 115 | } 116 | if (is_reference_on_device) { 117 | free(reference); 118 | } 119 | 120 | return error_count <= max_non_match; 121 | } 122 | 123 | 124 | // A tiny class for managing an array on GPU 125 | // When being constructed, it allocates a space on GPU and copies the data to it. 126 | // When being destructed, it frees the space on GPU. 
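// For example, the rotary embedding unittest wraps its position indexes as
//   GpuArray<int64_t> d_indexes(indexes);   // cudaMalloc + host-to-device copy
// and passes d_indexes.data to st::kernel::rotaryPosiEmbeddingBatched; the device buffer
// is released automatically when d_indexes goes out of scope.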
127 | // (Something likes std::unique_ptr) 128 | template 129 | class GpuArray { 130 | public: 131 | T* data; 132 | GpuArray(const std::vector &host_data) { 133 | cudaMalloc(&data, host_data.size() * sizeof(T)); 134 | cudaMemcpy(data, host_data.data(), host_data.size() * sizeof(T), cudaMemcpyHostToDevice); 135 | } 136 | ~GpuArray() { 137 | cudaFree(data); 138 | } 139 | }; 140 | -------------------------------------------------------------------------------- /src/unittest/util/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_executable(unittest_util cublas_wrapper.cc) 2 | target_link_libraries(unittest_util PUBLIC util gtest_main) 3 | -------------------------------------------------------------------------------- /src/unittest/util/cublas_wrapper.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | 6 | #include "../unittest_utils.h" 7 | #include "util/cublas_wrapper.h" 8 | #include "util/cuda_utils.h" 9 | 10 | 11 | // naiveGemmStridedBatched - Perform StridedBatchedGEMM on CPU 12 | // When transa = transb = CUBLAS_OP_N, each matrix in A has a shape of m x k, 13 | // and each matrix in B has a shape of k x n, and the result matrix C has a shape of m x n 14 | // There are totally batchCount matrices in A, B and C 15 | // A, B and C are stored in row major 16 | 17 | template 18 | void naiveGemmStridedBatched( 19 | cublasOperation_t transa, 20 | cublasOperation_t transb, 21 | int m, 22 | int n, 23 | int k, 24 | const T alpha, 25 | const T* Aarray, 26 | long long int stride_a, 27 | const T* Barray, 28 | long long int stride_b, 29 | const T beta, 30 | T* Carray, 31 | long long int stride_c, 32 | int batchCount 33 | ) { 34 | int lda = transa == CUBLAS_OP_N ? k : m; 35 | int ldb = transb == CUBLAS_OP_N ? n : k; 36 | int ldc = n; 37 | for (int batch = 0; batch < batchCount; batch++) { 38 | for (int i = 0; i < m; i++) { 39 | for (int j = 0; j < n; j++) { 40 | // We use float here since 41 | // - It is more accurate when accumulating a_{i, k} b_{k, j} 42 | // - It is faster (2x speedup) on CPU, so our test can run faster 43 | float sum = 0.0; 44 | for (int l = 0; l < k; l++) { 45 | T a_elem = transa == CUBLAS_OP_N ? Aarray[batch * stride_a + i * lda + l] : Aarray[batch * stride_a + l * lda + i]; 46 | T b_elem = transb == CUBLAS_OP_N ? 
Barray[batch * stride_b + l * ldb + j] : Barray[batch * stride_b + j * ldb + l]; 47 | sum = sum + (float)a_elem*(float)b_elem; 48 | } 49 | Carray[batch * stride_c + i * ldc + j] = alpha * sum + beta * Carray[batch * stride_c + i * ldc + j]; 50 | } 51 | } 52 | } 53 | } 54 | 55 | template 56 | class CublasWrapperTestSuite : public ::testing::Test { 57 | protected: 58 | st::util::CublasWrapper wrapper; 59 | 60 | public: 61 | void SetUp() override { 62 | } 63 | void TearDown() override { 64 | } 65 | }; 66 | 67 | TYPED_TEST_SUITE(CublasWrapperTestSuite, SupportTypes); 68 | 69 | TYPED_TEST(CublasWrapperTestSuite, gemmStridedBatched) { 70 | typedef TypeParam T; 71 | const int M = 61; 72 | const int N = 29; 73 | const int K = 124; 74 | const int BATCH_COUNT = 97; 75 | 76 | const int STRIDE_A = M*K + 12; 77 | const int STRIDE_B = K*N + 9; 78 | const int STRIDE_C = M*N + 13; 79 | 80 | // Alloc arrays on CPU 81 | T* Aarray = new T[STRIDE_A * BATCH_COUNT]; 82 | T* Barray = new T[STRIDE_B * BATCH_COUNT]; 83 | T* ref_Carray = new T[STRIDE_C * BATCH_COUNT]; 84 | 85 | // Alloc arrays on GPU 86 | T* Aarray_gpu, *Barray_gpu, *ans_Carray_gpu; 87 | CUDA_CHECK(cudaMalloc(&Aarray_gpu, STRIDE_A * BATCH_COUNT * sizeof(T))); 88 | CUDA_CHECK(cudaMalloc(&Barray_gpu, STRIDE_B * BATCH_COUNT * sizeof(T))); 89 | CUDA_CHECK(cudaMalloc(&ans_Carray_gpu, STRIDE_C * BATCH_COUNT * sizeof(T))); 90 | 91 | for (int is_transa = 0; is_transa <= 1; ++is_transa) { 92 | for (int is_transb = 0; is_transb <= 1; ++is_transb) { 93 | cublasOperation_t transa = is_transa ? CUBLAS_OP_T : CUBLAS_OP_N; 94 | cublasOperation_t transb = is_transb ? CUBLAS_OP_T : CUBLAS_OP_N; 95 | 96 | // Fill in random numbers 97 | std::mt19937 gen(0); 98 | std::uniform_real_distribution dist(-2, 2); 99 | for (int i = 0; i < STRIDE_A * BATCH_COUNT; i++) { 100 | Aarray[i] = dist(gen); 101 | } 102 | for (int i = 0; i < STRIDE_B * BATCH_COUNT; i++) { 103 | Barray[i] = dist(gen); 104 | } 105 | for (int i = 0; i < STRIDE_C * BATCH_COUNT; i++) { 106 | ref_Carray[i] = dist(gen); 107 | } 108 | T alpha = dist(gen); 109 | T beta = dist(gen); 110 | 111 | // Copy A, B and C to GPU 112 | CUDA_CHECK(cudaMemcpy(Aarray_gpu, Aarray, STRIDE_A * BATCH_COUNT * sizeof(T), cudaMemcpyHostToDevice)); 113 | CUDA_CHECK(cudaMemcpy(Barray_gpu, Barray, STRIDE_B * BATCH_COUNT * sizeof(T), cudaMemcpyHostToDevice)); 114 | CUDA_CHECK(cudaMemcpy(ans_Carray_gpu, ref_Carray, STRIDE_C * BATCH_COUNT * sizeof(T), cudaMemcpyHostToDevice)); 115 | 116 | sync_check_cuda_error(); 117 | 118 | // Calculate ref answers 119 | naiveGemmStridedBatched( 120 | transa, 121 | transb, 122 | M, 123 | N, 124 | K, 125 | alpha, 126 | Aarray, 127 | STRIDE_A, 128 | Barray, 129 | STRIDE_B, 130 | beta, 131 | ref_Carray, 132 | STRIDE_C, 133 | BATCH_COUNT 134 | ); 135 | 136 | // Calculate ans 137 | this->wrapper.gemmStridedBatched( 138 | transa, 139 | transb, 140 | M, 141 | N, 142 | K, 143 | alpha, 144 | Aarray_gpu, 145 | STRIDE_A, 146 | Barray_gpu, 147 | STRIDE_B, 148 | beta, 149 | ans_Carray_gpu, 150 | STRIDE_C, 151 | BATCH_COUNT 152 | ); 153 | sync_check_cuda_error(); 154 | 155 | // Compare 156 | bool is_pass = isArrayAlmostEqual(ans_Carray_gpu, ref_Carray, STRIDE_C * BATCH_COUNT, true, false); 157 | ASSERT_TRUE(is_pass); 158 | } 159 | } 160 | 161 | // Free 162 | CUDA_CHECK(cudaFree(Aarray_gpu)); 163 | CUDA_CHECK(cudaFree(Barray_gpu)); 164 | CUDA_CHECK(cudaFree(ans_Carray_gpu)); 165 | delete[] Aarray; 166 | delete[] Barray; 167 | delete[] ref_Carray; 168 | sync_check_cuda_error(); 169 | } 
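// Note on the CPU reference above: naiveGemmStridedBatched stores every matrix row-major,
// so with transa == transb == CUBLAS_OP_N (lda = k, ldb = n, ldc = n) the test checks
//   C_b(i, j) = alpha * sum_l A_b(i, l) * B_b(l, j) + beta * C_b(i, j)
// where A_b(i, l) = Aarray[b * stride_a + i * lda + l], and the same comparison is repeated
// for all four (transa, transb) combinations against CublasWrapper::gemmStridedBatched.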
--------------------------------------------------------------------------------